Merge remote-tracking branch 'refs/remotes/origin/develop'

# Conflicts:
#	app/core/config.py
#	app/service/chat_robot/script/service/CallQWen.py
This commit is contained in:
zchen
2025-09-02 20:08:54 +08:00
74 changed files with 5502 additions and 310 deletions

View File

@@ -9,9 +9,9 @@ import torch.nn.functional as F
import tritonclient.http as httpclient
from minio import Minio
from app.core.config import MINIO_URL, MINIO_ACCESS, MINIO_SECRET, MINIO_SECURE, DESIGN_MODEL_URL
from app.core.config import MINIO_URL, MINIO_ACCESS, MINIO_SECRET, MINIO_SECURE, DESIGN_MODEL_URL, CATEGORY_PATH
from app.schemas.brand_dna import BrandDnaModel
from app.service.attribute.config import local_debug_const
from app.service.attribute.config import local_debug_const, const
from app.service.utils.generate_uuid import generate_uuid
from app.service.utils.new_oss_client import oss_upload_image, oss_get_image
@@ -25,18 +25,18 @@ class BrandDna:
self.sketch_bucket = "test"
self.image_url = request_item.image_url
self.is_brand_dna = request_item.is_brand_dna
# self.attr_type = pd.read_csv(CATEGORY_PATH)
self.attr_type = pd.read_csv(r"E:\workspace\trinity_client_aida\app\service\attribute\config\descriptor\category\category_dis.csv")
self.attr_type = pd.read_csv(CATEGORY_PATH)
# self.attr_type = pd.read_csv(r"E:\workspace\trinity_client_aida\app\service\attribute\config\descriptor\category\category_dis.csv")
self.att_client = httpclient.InferenceServerClient(url=DESIGN_MODEL_URL)
self.seg_client = httpclient.InferenceServerClient(url='10.1.1.243:30000')
# self.const = const
self.const = local_debug_const
self.const = const
# self.const = local_debug_const
# Get the result
def get_result(self):
mask, image = self.get_seg_mask()
cv2.imshow("", image)
cv2.waitKey(0)
# cv2.imshow("", image)
# cv2.waitKey(0)
height, width, channels = image.shape
result_dict = []
@@ -50,8 +50,8 @@ class BrandDna:
outwear_img[mask == value] = image[mask == value]
outwear_mask_img[mask == value] = [0, 0, 255]
cv2.imshow("", outwear_img)
cv2.waitKey(0)
# cv2.imshow("", outwear_img)
# cv2.waitKey(0)
# input img after preprocessing
preprocess_img = self.category_preprocess(outwear_img)
@@ -89,8 +89,8 @@ class BrandDna:
tops_img[mask == value] = image[mask == value]
tops_mask_img[mask == value] = [0, 0, 255]
cv2.imshow("", tops_img)
cv2.waitKey(0)
# cv2.imshow("", tops_img)
# cv2.waitKey(0)
# input img after preprocessing
preprocess_img = self.category_preprocess(tops_img)
@@ -129,8 +129,8 @@ class BrandDna:
bottoms_img[mask == value] = image[mask == value]
bottoms_mask_img[mask == value] = [0, 0, 255]
cv2.imshow("", bottoms_img)
cv2.waitKey(0)
# cv2.imshow("", bottoms_img)
# cv2.waitKey(0)
# input img after preprocessing
preprocess_img = self.category_preprocess(bottoms_img)
@@ -327,7 +327,7 @@ if __name__ == '__main__':
# result_url = service.get_result()
# print(result_url)
request_item = BrandDnaModel(
image_url="aida-users/60/product_image/07cb5d5d-5022-44cc-b0d3-cc986cfebad1-2-60.png",
image_url="aida-results/result_00006a48-e315-11ee-b7c8-b48351119060.png",
is_brand_dna=True
)
service = BrandDna(request_item)

View File

@@ -0,0 +1,104 @@
import logging
import cv2
import numpy as np
import tritonclient.grpc as grpcclient
from langchain.output_parsers import ResponseSchema, StructuredOutputParser
from langchain_community.chat_models import ChatTongyi
from langchain_core.prompts import PromptTemplate
# from langchain_openai import ChatOpenAI
from minio import Minio
from tritonclient.utils import np_to_triton_dtype
from app.core.config import GI_MODEL_URL, MINIO_URL, MINIO_ACCESS, MINIO_SECRET, MINIO_SECURE, GI_MODEL_NAME
from app.schemas.brand_dna import GenerateBrandModel
from app.service.utils.generate_uuid import generate_uuid
from app.service.utils.new_oss_client import oss_upload_image
class GenerateBrandInfo:
def __init__(self, request_data):
# minio client init
self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
# user info init
self.user_id = request_data.user_id
self.category = "brand_logo"
# generate logo init
self.grpc_client = grpcclient.InferenceServerClient(url=GI_MODEL_URL)
self.image = np.random.randint(0, 256, (1024, 1024, 3), dtype=np.uint8)
self.batch_size = 1
self.mode = 'txt2img'
# llm generate brand info init
self.model = ChatTongyi(model="qwen2.5-14b-instruct", api_key="sk-7658298c6b99443c98184a5e634fe6ab")
self.response_schemas = [
ResponseSchema(name="brand_name", description="Brand name."),
ResponseSchema(name="brand_slogan", description="Brand slogan."),
ResponseSchema(name="brand_logo_prompt", description="prompt required for brand logo generation.")
]
self.output_parser = StructuredOutputParser.from_response_schemas(self.response_schemas)
self.format_instructions = self.output_parser.get_format_instructions()
self.prompt = PromptTemplate(
template="你是一个时装品牌的设计师。根据用户输入提取出brand namebrand slogan,brand logo 描述。如果没有以上内容需要你根据用户输入随意发挥。随后根据brand logo 描述生成一个prompt这个prompt用于生成模型prompt需要完全表达用户的想法并使用英文使用简洁明了的单词不要过长。.\n{format_instructions}\n{question}",
input_variables=["question"],
partial_variables={"format_instructions": self.format_instructions}
)
self._input = self.prompt.format_prompt(question=request_data.prompt)
self.result_data = {}
def get_result(self):
self.llm_generate_brand_info()
self.generate_brand_logo()
return self.result_data
def llm_generate_brand_info(self):
output = self.model(self._input.to_messages())
brand_data = self.output_parser.parse(output.content)
self.result_data = brand_data
self.generate_logo_prompt = brand_data['brand_logo_prompt']
def generate_brand_logo(self):
prompts = [self.generate_logo_prompt] * self.batch_size
modes = [self.mode] * self.batch_size
images = [self.image.astype(np.float16)] * self.batch_size
text_obj = np.array(prompts, dtype="object").reshape((-1, 1))
mode_obj = np.array(modes, dtype="object").reshape((-1, 1))
image_obj = np.array(images, dtype=np.float16).reshape((-1, 1024, 1024, 3))
input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype))
input_image = grpcclient.InferInput("input_image", image_obj.shape, np_to_triton_dtype(image_obj.dtype))
input_mode = grpcclient.InferInput("mode", mode_obj.shape, np_to_triton_dtype(mode_obj.dtype))
input_text.set_data_from_numpy(text_obj)
input_image.set_data_from_numpy(image_obj)
input_mode.set_data_from_numpy(mode_obj)
inputs = [input_text, input_image, input_mode]
result = self.grpc_client.infer(model_name=GI_MODEL_NAME, inputs=inputs)
image = result.as_numpy("generated_image")
image_result = cv2.cvtColor(np.squeeze(image.astype(np.uint8)), cv2.COLOR_RGB2BGR)
logo_url = self.upload_logo_image(image_result, generate_uuid())
self.result_data['brand_logo'] = logo_url
def upload_logo_image(self, image, object_name):
try:
_, img_byte_array = cv2.imencode('.jpg', image)
object_name = f'{self.user_id}/{self.category}/{object_name}'
req = oss_upload_image(oss_client=self.minio_client, bucket="aida-users", object_name=object_name, image_bytes=img_byte_array)
image_url = f"aida-users/{object_name}"
return image_url
except Exception as e:
logging.warning(f"upload_png_mask runtime exception : {e}")
if __name__ == '__main__':
request_data = GenerateBrandModel(
user_id="89",
prompt="华为"
)
service = GenerateBrandInfo(request_data)
print(service.get_result())

View File

@@ -0,0 +1,32 @@
from dotenv import load_dotenv
from langchain.output_parsers import StructuredOutputParser, ResponseSchema
from langchain_core.prompts import PromptTemplate
from langchain_openai import ChatOpenAI
# Load environment variables from the .env file
load_dotenv()
# Create a large language model; `model` specifies which kind of LLM to use
model = ChatOpenAI(model="qwen2.5-14b-instruct")
# Response schemas we want to receive
response_schemas = [
ResponseSchema(name="brand_name", description="Brand name."),
ResponseSchema(name="brand_slogan", description="Brand slogan."),
ResponseSchema(name="brand_logo_prompt", description="prompt required for brand logo generation.")
]
output_parser = StructuredOutputParser.from_response_schemas(response_schemas)
format_instructions = output_parser.get_format_instructions()
prompt = PromptTemplate(
template="你是一个时装品牌的设计师。根据用户输入提取出brand namebrand slogan,brand logo 描述。如果没有以上内容需要你根据用户输入随意发挥。随后根据brand logo 描述生成一个prompt这个prompt用于生成模型.\n{format_instructions}\n{question}",
input_variables=["question"],
partial_variables={"format_instructions": format_instructions}
)
_input = prompt.format_prompt(question="brand name: cat home")
output = model(_input.to_messages())
brand_data = output_parser.parse(output.content)
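# For reference (an editorial sketch, not part of this commit): the parser built
# from these response_schemas expects the model to answer with a fenced JSON
# block, e.g.
# ```json
# {"brand_name": "...", "brand_slogan": "...", "brand_logo_prompt": "..."}
# ```
# output_parser.parse() then returns that dict.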
def generate_logo(bucket_name, object_name, prompt):
pass

View File

@@ -90,7 +90,6 @@ def chat(post_data):
user_id = post_data.user_id
session_id = post_data.session_id
input_message = post_data.message
gender = post_data.gender
# final_outputs = agent_executor(
# {"input": input_message, "gender": gender},
@@ -98,7 +97,7 @@ def chat(post_data):
# session_key=f"buffer:{user_id}:{session_id}",
# )
final_outputs = CallQWen.call_with_messages(input_message, gender)
final_outputs = CallQWen.call_with_messages(input_message)
# api_response = {
# 'user_id': user_id,
# 'session_id': session_id,

View File

@@ -34,6 +34,39 @@ You may encounter the following types of questions:
Be careful when using the tools, since you are actually a chat bot. Tools should only be used when essential.
"""
FASHION_CHAT_BOT_PREFIX_TEMP = """
You are a fashion design assistant with the following capabilities:
1. Direct conversation: Answer general questions (e.g., greetings, opinions).
2. Tool usage:
- `get_image_from_vector_db`: Retrieve clothing items (requires gender parameter).
- `internet_search`: Fetch real-time fashion trends.
- `tutorial_tool`: Provide styling guides.
Key Rules:
1. Tool Selection:
- Use `get_image_from_vector_db` for clothing queries (e.g., "show men's jackets").
- Use `internet_search` for time-sensitive queries (e.g., "2024 Paris Fashion Week trends").
- Use `tutorial_tool` for educational requests (e.g., "how to layer outfits").
2. Gender Handling (for `get_image_from_vector_db` only):
- Step 1: Check the **current user input** for gender keywords (e.g., "women/men/she/he"). If found, extract and pass as `gender`.
- Step 2: If no gender in current input, scan the **chat history** for the most recent gender reference.
- Step 3: If undetermined, default to `"unisex"`.
3. Output Format:
- Direct replies: Keep responses under 20 words.
- Tool calls:
- Always include required parameters (e.g., `gender` for `get_image_from_vector_db`).
- Auto-fill `gender` using the above rules if unspecified.
Examples:
1. User: "Find red dresses for women"
→ `get_image_from_vector_db(gender="female", query="dress")`
2. User: "show men's jackets"
→ `get_image_from_vector_db(gender="male", query="outwear")`
3. User: "Show casual outfits"
→ `get_image_from_vector_db(gender="unisex", query="casual outfits")`"""
TOOL_SELECT_SUFFIX = """
Prior to proceeding, it is essential to carefully assess the question and select the appropriate tools or approach accordingly.
For database-related questions, use SQL tools to identify relevant tables and query their schemas.
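A minimal sketch of the three-step gender fallback described in FASHION_CHAT_BOT_PREFIX_TEMP above; the helper name and keyword table are illustrative assumptions, not part of this commit:
def resolve_gender(user_input: str, chat_history: list) -> str:
    """Step 1: current input; step 2: most recent history mention; step 3: unisex."""
    keywords = {"women": "female", "she": "female", "her": "female",
                "men": "male", "he": "male", "his": "male"}
    # Step 1: check the current user input for gender keywords.
    for word, gender in keywords.items():
        if word in user_input.lower().split():
            return gender
    # Step 2: scan the chat history, most recent message first.
    for message in reversed(chat_history):
        for word, gender in keywords.items():
            if word in message.lower().split():
                return gender
    # Step 3: default when gender cannot be determined.
    return "unisex"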

View File

@@ -9,7 +9,7 @@ from app.core.config import *
from app.service.chat_robot.script.callbacks.qwen_callback_handler import QWenCallbackHandler
from app.service.chat_robot.script.database import CustomDatabase
from app.service.chat_robot.script.prompt import FASHION_CHAT_BOT_PREFIX, TOOLS_FUNCTIONS_SUFFIX, TUTORIAL_TOOL_RETURN, \
GET_LANGUAGE_PREFIX
GET_LANGUAGE_PREFIX, FASHION_CHAT_BOT_PREFIX_TEMP
from app.service.search_image_with_text.service import query
get_database_table_description = "Input is an empty string, output is a comma separated list of tables in the database."
@@ -212,14 +212,15 @@ def get_assistant_response(messages):
return response
def call_with_messages(message, gender):
def call_with_messages(message):
global tool_info
user_input = message
print('\n')
messages = [
{
"content": FASHION_CHAT_BOT_PREFIX, # 系统message
# "content": FASHION_CHAT_BOT_PREFIX, # 系统message
"content": FASHION_CHAT_BOT_PREFIX_TEMP, # 修改后的系统message
"role": "system"
},
{
@@ -255,7 +256,7 @@ def call_with_messages(message, gender):
tool_info = {"name": "search_from_internet", "role": "tool"}
content = json.loads(assistant_output.tool_calls[0]['function']['arguments'])
message = [
{'role': 'assistant', 'content': content['query']}
{'role': 'assistant', 'content': content['query'] if "query" in content.keys() else user_input}
]
tool_info['content'] = search_from_internet(message)
flag = False
@@ -282,6 +283,8 @@ def call_with_messages(message, gender):
result_content = tool_info['content']
elif assistant_output.tool_calls[0]['function']['name'] == 'get_image_from_vector_db':
content = json.loads(assistant_output.tool_calls[0]['function']['arguments'])
# TODO: pull gender from the chat history; for now, default to female when gender cannot be determined
gender = content['gender'] if "gender" in content.keys() and content['gender'] != 'unisex' else 'female'
tool_info = {"name": "get_image_from_vector_db", "role": "tool",
'content': get_image_from_vector_db(gender, content['parameters']['content'] if "parameters" in content.keys() else content['content'])}
flag = False
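The two `content` lookups in this hunk guard against the model omitting keys in its tool-call arguments; a condensed sketch of that defensive pattern (illustrative names, not the committed code):
import json

def parse_tool_args(raw_arguments: str, user_input: str):
    """Fall back to the raw user input / a default gender when keys are missing."""
    content = json.loads(raw_arguments)
    query = content.get("query", user_input)  # missing "query" -> reuse the user input
    gender = content.get("gender", "unisex")
    if gender == "unisex":
        gender = "female"  # interim default, per the TODO in the diff
    return query, gender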

View File

@@ -0,0 +1,161 @@
import io
import time
from pprint import pprint
import cv2
import numpy as np
import tritonclient.grpc as grpcclient
from PIL import Image
from minio import Minio
from tritonclient.utils import np_to_triton_dtype
from app.core.config import *
from app.schemas.clothing_seg import ClothingSegModel
from app.service.design_fast.utils.design_ensemble import get_seg_result
from app.service.utils.decorator import RunTime
from app.service.utils.generate_uuid import generate_uuid
from app.service.utils.new_oss_client import oss_get_image, oss_upload_image
minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
class ClothingSeg:
def __init__(self, request_data):
self.image_data = request_data.image_data
self.user_id = request_data.user_id
self.triton_client = grpcclient.InferenceServerClient(url="10.1.1.243:10071")
@RunTime
def get_result(self):
self.read_image()
self.clothing_seg()
self.upload_image()
for data in self.image_data:
del data["image"]
del data["clothing"]
return self.image_data
@RunTime
def upload_image(self):
for data in self.image_data:
data["clothing_url"] = []
for clothing in data["clothing"]:
object_name = f"{self.user_id}/clothing_seg/{generate_uuid()}.png"
image_data = io.BytesIO()
clothing.save(image_data, format="PNG")
image_data.seek(0)
image_bytes = image_data.read()
oss_upload_image(oss_client=minio_client, bucket="aida-users", object_name=object_name, image_bytes=image_bytes)
data["clothing_url"].append(f"aida-users/{object_name}")
@RunTime
def read_image(self):
for data in self.image_data:
url = data["image_url"]
image = oss_get_image(oss_client=minio_client, bucket=url.split("/", 1)[0], object_name=url.split("/", 1)[1], data_type="cv2")
data["image"] = image
@RunTime
def clothing_seg(self):
for data in self.image_data:
image_type = data["image_type"]
image = data["image"]
clothing_result = []
if image_type == "sketch":
if len(image.shape) == 2:
image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
seg_mask = get_seg_result(1, image[:, :, :3])
else:
seg_mask = get_seg_result(1, image[:, :, :3])
temp = seg_mask != 0.0
mask = (255 * (temp + 0).astype(np.uint8))
x_min, y_min, x_max, y_max = get_bounding_box(mask)
cropped_mask = mask[y_min:y_max + 1, x_min:x_max + 1]
cropped_image = image[y_min:y_max + 1, x_min:x_max + 1]
h, w = cropped_image.shape[:2]
mask_pil = Image.fromarray(cropped_mask).convert("L")
image_pil = Image.fromarray(cv2.cvtColor(cropped_image, cv2.COLOR_BGR2RGB))
transparent_image = Image.new("RGBA", (w, h), (0, 0, 0, 0))
transparent_image.paste(image_pil, (0, 0), mask=mask_pil)
clothing_result.append(transparent_image)
else:
input_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
input0_data = [input_image.astype(np.float32)] * 1
input0_data = np.array(input0_data, dtype=np.float32)
inputs = [
grpcclient.InferInput(
"INPUT0", input0_data.shape, np_to_triton_dtype(input0_data.dtype)
),
]
inputs[0].set_data_from_numpy(input0_data)
outputs = [
# grpcclient.InferRequestedOutput("OUTPUT0"),
grpcclient.InferRequestedOutput("OUTPUT1"),
]
response = self.triton_client.infer("seg_clothing", inputs, request_id=str(1), outputs=outputs)
# output0_data = response.as_numpy("OUTPUT0")
# cv2.imwrite("output02.png", output0_data * 100)
output1_data = response.as_numpy("OUTPUT1")
for alpha in output1_data:
alpha = cv2.resize(alpha, (image.shape[1], image.shape[0]), interpolation=cv2.INTER_CUBIC)
x_min, y_min, x_max, y_max = get_bounding_box(alpha)
cropped_mask = alpha[y_min:y_max + 1, x_min:x_max + 1]
cropped_image = image[y_min:y_max + 1, x_min:x_max + 1]
h, w = cropped_image.shape[:2]
mask_pil = Image.fromarray(cropped_mask).convert("L")
image_pil = Image.fromarray(cv2.cvtColor(cropped_image, cv2.COLOR_BGR2RGB))
transparent_image = Image.new("RGBA", (w, h), (0, 0, 0, 0))
transparent_image.paste(image_pil, (0, 0), mask=mask_pil)
clothing_result.append(transparent_image)
data["clothing"] = clothing_result
@RunTime
def get_bounding_box(mask):
"""
Get the bounding box from a mask image containing only 0s and 1s.
:param mask: input mask image, a 2D numpy array whose elements are 0 or 1
:return: bounding box coordinates (x_min, y_min, x_max, y_max)
"""
# Find the coordinates of all non-zero pixels
rows, cols = np.where(mask != 0)
if len(rows) == 0 or len(cols) == 0:
# If no non-zero pixels are found, return an all-zero bounding box
return 0, 0, 0, 0
# Compute the bounding box coordinates
x_min = np.min(cols)
y_min = np.min(rows)
x_max = np.max(cols)
y_max = np.max(rows)
return x_min, y_min, x_max, y_max
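# A quick sanity check of get_bounding_box on a toy mask (illustrative, not part of this commit):
#   mask = np.zeros((5, 5), dtype=np.uint8)
#   mask[1:3, 2:5] = 1
#   get_bounding_box(mask)  # -> (2, 1, 4, 2): cols 2..4, rows 1..2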
if __name__ == "__main__":
test_data = ClothingSegModel(
user_id=89,
image_data=[
# {
# "image_url": "test/clothing_seg/dress.jpg",
# "image_type": "sketch"
# },
# {
# "image_url": "test/clothing_seg/skirt_559.jpg",
# "image_type": "sketch"
# },
{
"image_url": "aida-collection-element/87/Sketchboard/ab40e035-547a-48c5-9f97-1db7bf56ad77.jpg",
"image_type": "sketch"
}
]
)
start_time = time.time()
server = ClothingSeg(test_data)
pprint(server.get_result())
print(time.time() - start_time)

View File

@@ -5,9 +5,9 @@ from celery import Celery
from minio import Minio
from app.core.config import *
from app.service.design_batch.item import BodyItem, TopItem, BottomItem
from app.service.design_batch.item import BodyItem, TopItem, BottomItem, AccessoriesItem
from app.service.design_batch.utils.MQ import publish_status
from app.service.design_batch.utils.organize import organize_body, organize_clothing
from app.service.design_batch.utils.organize import organize_body, organize_clothing, organize_accessories
from app.service.design_batch.utils.save_json import oss_upload_json
from app.service.design_batch.utils.synthesis_item import update_base_size_priority, synthesis, synthesis_single
@@ -19,6 +19,8 @@ logging.getLogger('pika').setLevel(logging.WARNING)
logger = logging.getLogger()
minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
print("start")
def process_item(item, basic):
# Process a single item within the project
@@ -28,9 +30,14 @@ def process_item(item, basic):
elif item['type'].lower() in ['blouse', 'outwear', 'dress', 'tops']:
top_server = TopItem(data=item, basic=basic, minio_client=minio_client)
item_data = top_server.process()
else:
elif item['type'].lower() in ['skirt', 'trousers', 'bottoms']:
bottom_server = BottomItem(data=item, basic=basic, minio_client=minio_client)
item_data = bottom_server.process()
elif item['type'].lower() in ['accessories']:
bottom_server = AccessoriesItem(data=item, basic=basic, minio_client=minio_client)
item_data = bottom_server.process()
else:
raise NotImplementedError(f"Item type {item['type']} not implemented")
return item_data
@@ -40,6 +47,10 @@ def process_layer(item, layers):
body_layer = organize_body(item)
layers.append(body_layer)
return item['body_image'].size
elif item['name'] == 'accessories':
front_layer, back_layer = organize_accessories(item)
layers.append(front_layer)
layers.append(back_layer)
else:
front_layer, back_layer = organize_clothing(item)
layers.append(front_layer)
@@ -48,6 +59,9 @@ def process_layer(item, layers):
@celery_app.task
def batch_design(objects_data, tasks_id, json_name):
print(objects_data)
print(tasks_id)
print(json_name)
object_response = []
threads = []
active_threads = 0
@@ -71,7 +85,7 @@ def batch_design(objects_data, tasks_id, json_name):
for lay in layers:
items_response['layers'].append({
'image_category': lay['name'],
'image_category': "body" if lay['name'] == 'mannequin' else lay['name'],
'position': lay['position'],
'priority': lay.get("priority", None),
'resize_scale': lay['resize_scale'] if "resize_scale" in lay.keys() else None,
@@ -121,6 +135,7 @@ def batch_design(objects_data, tasks_id, json_name):
for t in threads:
t.join()
logger.debug(object_response)
print(object_response)
oss_upload_json(minio_client, object_response, json_name)
publish_status(tasks_id, "ok", json_name)
return object_response

View File

@@ -1,4 +1,4 @@
from app.service.design_batch.pipeline import *
from app.service.design_fast.pipeline import LoadImage, KeyPoint, Segmentation, Color, PrintPainting, Scaling, Split, LoadBodyImage, ContourDetection
class BaseItem:
@@ -9,6 +9,27 @@ class BaseItem:
self.result.update(basic)
class AccessoriesItem(BaseItem):
def __init__(self, data, basic, minio_client):
super().__init__(data, basic)
self.Accessories_pipeline = [
LoadImage(minio_client),
# KeyPoint(),
ContourDetection(),
# Segmentation(minio_client),
# BackPerspective(minio_client),
Color(minio_client),
PrintPainting(minio_client),
Scaling(),
Split(minio_client)
]
def process(self):
for item in self.Accessories_pipeline:
self.result = item(self.result)
return self.result
class TopItem(BaseItem):
def __init__(self, data, basic, minio_client):
super().__init__(data, basic)
@@ -16,6 +37,7 @@ class TopItem(BaseItem):
LoadImage(minio_client),
KeyPoint(),
Segmentation(minio_client),
# BackPerspective(minio_client),
Color(minio_client),
PrintPainting(minio_client),
Scaling(),
@@ -35,7 +57,8 @@ class BottomItem(BaseItem):
LoadImage(minio_client),
KeyPoint(),
ContourDetection(),
# Segmentation(),
Segmentation(minio_client),
# BackPerspective(minio_client),
Color(minio_client),
PrintPainting(minio_client),
Scaling(),

View File

@@ -1,3 +1,4 @@
from .back_perspective import BackPerspective
from .color import Color
from .contour_detection import ContourDetection
from .keypoint import KeyPoint
@@ -13,6 +14,7 @@ __all__ = [
'KeyPoint',
'ContourDetection',
'Segmentation',
'BackPerspective',
'Color',
'PrintPainting',
'Scaling',

View File

@@ -0,0 +1,79 @@
import cv2
import numpy as np
from app.service.design_fast.utils.design_ensemble import get_seg_result
from app.service.utils.new_oss_client import oss_upload_image
class BackPerspective:
def __init__(self, minio_client):
self.minio_client = minio_client
def __call__(self, result):
# If the sketch is a system image, check whether a matching back-view image exists
if result['path'].split('/')[0] == 'aida-sys-image':
file_path = result['path'].replace("images", 'images_back', 1)
if self.is_file_exists(bucket_name='aida-sys-image', file_name=file_path[file_path.find('/') + 1:]):
result['back_perspective_url'] = file_path
return result
else:
seg_result = get_seg_result("1", result['image'])[0]
elif result['name'] in ['blouse', 'outwear', 'dress', 'tops']:
seg_result = result['seg_result']
else:
seg_result = get_seg_result("1", result['image'])[0]
m = self.thicken_contours_and_display(seg_result, thickness=10, color=(0, 0, 0))
back_sketch = result['image'].copy()
back_sketch[m > 100] = 255
# Upload the back-view image
_, img_encoded = cv2.imencode(".jpg", back_sketch)
resp = oss_upload_image(self.minio_client, bucket='test', object_name=result['path'], image_bytes=img_encoded.tobytes())
result['back_perspective_url'] = f"{resp.bucket_name}/{resp.object_name}"
return result
def thicken_contours_and_display(self, mask, thickness=10, color=(0, 0, 0)):
mask = mask.astype(np.uint8) * 255
# Find contours
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
# Create a color copy for drawing the contours
mask_color = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR)
def thicken_contour_inward(contour, thick):
# Create a blank black image the same size as the original mask
blank = np.zeros_like(mask)
# Draw the contour in white on the blank image
cv2.drawContours(blank, [contour], -1, 255, thickness=thick)
# Find the center of the contour (approximated, e.g., by the centroid)
M = cv2.moments(contour)
cx = int(M['m10'] / M['m00'])
cy = int(M['m01'] / M['m00'])
# Apply a distance transform; values get smaller closer to the center
dist_transform = cv2.distanceTransform(255 - blank, cv2.DIST_L2, 5)
# Decide whether to keep each pixel based on its distance-transform value; pixels near the center are more likely kept
result = np.zeros_like(mask)
for i in range(dist_transform.shape[0]):
for j in range(dist_transform.shape[1]):
if dist_transform[i, j] < thick:
result[i, j] = 255
return result
for contour in contours:
thickened_contour = thicken_contour_inward(contour, thickness)
mask_color[thickened_contour > 0] = color
_, binary_result = cv2.threshold(mask_color, 127, 255, cv2.THRESH_BINARY)
# Convert to mask form
mask_result = cv2.cvtColor(binary_result, cv2.COLOR_BGR2GRAY)
return mask_result
def is_file_exists(self, bucket_name, file_name):
try:
self.minio_client.stat_object(bucket_name, file_name)
return True
except Exception:
return False
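One review note on thicken_contour_inward above: the per-pixel double loop over dist_transform is O(H·W) in pure Python; a behavior-preserving vectorized form (a sketch under the same inputs) is:
import numpy as np

def thicken_mask_vectorized(dist_transform: np.ndarray, thick: int) -> np.ndarray:
    # Equivalent to the nested loops: keep (set to 255) every pixel whose
    # distance to the drawn contour is below `thick`.
    return np.where(dist_transform < thick, 255, 0).astype(np.uint8)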

View File

@@ -14,14 +14,39 @@ class Color:
def __call__(self, result):
dim_image_h, dim_image_w = result['image'].shape[0:2]
# Gradient color
if "gradient" in result.keys() and result['gradient'] != "":
bucket_name = result['gradient'].split('/')[0]
object_name = result['gradient'][result['gradient'].find('/') + 1:]
pattern = self.get_gradient(bucket_name=bucket_name, object_name=object_name)
resize_pattern = cv2.resize(pattern, (dim_image_w, dim_image_h), interpolation=cv2.INTER_AREA)
# No color
elif "color" not in result.keys() or result['color'] == "":
result['final_image'] = result['pattern_image'] = result['single_image'] = result['image']
result['alpha'] = 100 / 255.0
return result
# Normal color
else:
pattern = self.get_pattern(result['color'])
resize_pattern = cv2.resize(pattern, (dim_image_w, dim_image_h), interpolation=cv2.INTER_AREA)
if "partial_color" in result.keys() and result['partial_color'] != "":
bucket_name = result['partial_color'].split('/')[0]
object_name = result['partial_color'][result['partial_color'].find('/') + 1:]
partial_color = oss_get_image(oss_client=self.minio_client, bucket=bucket_name, object_name=object_name, data_type="cv2")
h, w = partial_color.shape[0:2]
resize_pattern = cv2.resize(resize_pattern, (w, h), interpolation=cv2.INTER_AREA)
# Split out the alpha channel of the PNG image
alpha_channel = partial_color[:, :, 3]
# Extract the RGB channels of the PNG image
png_rgb = partial_color[:, :, :3]
# Create a mask the same size as the cv image, marking which pixels to replace
mask = alpha_channel > 0
# Expand the mask to 3 channels for element-wise operations with the cv image
mask_3ch = np.stack([mask] * 3, axis=-1)
# Overlay the PNG colors onto the cv image according to the mask
resize_pattern[mask_3ch] = png_rgb[mask_3ch]
resize_pattern = cv2.resize(resize_pattern, (dim_image_w, dim_image_h), interpolation=cv2.INTER_AREA)
closed_mo = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2)
gray_mo = np.expand_dims(result['gray'], axis=2).repeat(3, axis=2)
get_image_fir = resize_pattern * (closed_mo / 255) * (gray_mo / 255)

View File

@@ -4,7 +4,8 @@ import numpy as np
from pymilvus import MilvusClient
from app.core.config import *
from app.service.design_batch.utils.design_ensemble import get_keypoint_result
from app.service.design_fast.utils.design_ensemble import get_keypoint_result
from app.service.utils.decorator import ClassCallRunTime, RunTime
logger = logging.getLogger(__name__)
@@ -16,14 +17,15 @@ class KeyPoint:
def get_name(cls):
return cls.name
@ClassCallRunTime
def __call__(self, result):
if result['name'] in ['blouse', 'skirt', 'dress', 'outwear', 'trousers', 'tops', 'bottoms']: # Check whether cached data of the same category exists; if so, read it directly, otherwise infer and update
# result['clothes_keypoint'] = self.infer_keypoint_result(result)
site = 'up' if result['name'] in ['blouse', 'outwear', 'dress', 'tops'] else 'down'
# keypoint_cache = search_keypoint_cache(result["image_id"], site)
keypoint_cache = self.keypoint_cache(result, site)
# keypoint_cache = self.keypoint_cache(result, site)
keypoint_cache = False
# Skip the vector lookup and run model inference directly
# keypoint_cache = False
if keypoint_cache is False:
keypoint_infer_result, site = self.infer_keypoint_result(result)
result['clothes_keypoint'] = self.save_keypoint_cache(result["image_id"], keypoint_infer_result, site)
@@ -87,7 +89,7 @@ class KeyPoint:
logger.info(f"save keypoint cache milvus error : {e}")
return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
# @ RunTime
@RunTime
def keypoint_cache(self, result, site):
try:
client = MilvusClient(uri=MILVUS_URL, token=MILVUS_TOKEN, db_name=MILVUS_ALIAS)

View File

@@ -1,6 +1,9 @@
import io
import logging
import cv2
import numpy as np
from PIL import Image
from app.service.utils.new_oss_client import oss_get_image
@@ -71,6 +74,8 @@ class LoadImage:
keypoint = 'head_point'
elif name == 'earring':
keypoint = 'ear_point'
elif name == 'accessories':
keypoint = "accessories"
else:
raise KeyError(f"{name} does not belong to item category list: blouse, outwear, dress, trousers, skirt, "
f"bag, shoes, hairstyle, earring.")

View File

@@ -15,8 +15,25 @@ class PrintPainting:
single_print = result['print']['single']
overall_print = result['print']['overall']
element_print = result['print']['element']
partial_path = result['print']['partial'] if 'partial' in result['print'] else None
result['single_image'] = None
result['print_image'] = None
# TODO: resize result['pattern_image'] to the resize_scale size
# TODO: resize result['mask'] to the resize_scale size
if result['resize_scale'][0] == 1.0 and result['resize_scale'][1] == 1.0:
pass
else:
height, width = result['pattern_image'].shape[:2]
new_width = int(width * result['resize_scale'][0])
new_height = int(height * result['resize_scale'][1])
result['pattern_image'] = cv2.resize(result['pattern_image'], (new_width, new_height))
result['final_image'] = cv2.resize(result['final_image'], (new_width, new_height))
result['mask'] = cv2.resize(result['mask'], (new_width, new_height))
result['gray'] = cv2.resize(result['gray'], (new_width, new_height))
print(1)
if overall_print['print_path_list']:
painting_dict = {'dim_image_h': result['pattern_image'].shape[0], 'dim_image_w': result['pattern_image'].shape[1]}
result['print_image'] = result['pattern_image']
@@ -39,7 +56,7 @@ class PrintPainting:
for i in range(len(single_print['print_path_list'])):
image, image_mode = self.read_image(single_print['print_path_list'][i])
if image_mode == "RGBA":
new_size = (int(image.width * single_print['print_scale_list'][i]), int(image.height * single_print['print_scale_list'][i]))
new_size = (int(result['pattern_image'].shape[1] * single_print['print_scale_list'][i][0]), int(result['pattern_image'].shape[0] * single_print['print_scale_list'][i][1]))
mask = image.split()[3]
resized_source = image.resize(new_size)
@@ -62,9 +79,12 @@ class PrintPainting:
mask = np.expand_dims(mask, axis=2)
mask = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR)
mask = cv2.bitwise_not(mask)
mask = cv2.resize(mask, (int(result['final_image'].shape[1] * single_print['print_scale_list'][i][0]), int(result['final_image'].shape[0] * single_print['print_scale_list'][i][1])))
image = cv2.resize(image, (int(result['final_image'].shape[1] * single_print['print_scale_list'][i][0]), int(result['final_image'].shape[0] * single_print['print_scale_list'][i][1])))
# Coordinates need to be recomputed after rotation
rotate_mask, _ = self.img_rotate(mask, single_print['print_angle_list'][i], single_print['print_scale_list'][i])
rotate_image, rotated_new_size = self.img_rotate(image, single_print['print_angle_list'][i], single_print['print_scale_list'][i])
rotate_mask, _ = self.img_rotate(mask, single_print['print_angle_list'][i])
rotate_image, rotated_new_size = self.img_rotate(image, single_print['print_angle_list'][i])
# x, y = int(result['print']['location'][i][0] - rotated_new_size[0] - (rotate_mask.shape[0] - image.shape[0]) / 2), int(result['print']['location'][i][1] - rotated_new_size[1] - (rotate_mask.shape[1] - image.shape[1]) / 2)
x, y = int(single_print['location'][i][0] - rotated_new_size[0]), int(single_print['location'][i][1] - rotated_new_size[1])
@@ -143,7 +163,7 @@ class PrintPainting:
for i in range(len(element_print['element_path_list'])):
image, image_mode = self.read_image(element_print['element_path_list'][i])
if image_mode == "RGBA":
new_size = (int(image.width * element_print['element_scale_list'][i]), int(image.height * element_print['element_scale_list'][i]))
new_size = (int(result['final_image'].shape[1] * element_print['element_scale_list'][i][0]), int(result['final_image'].shape[0] * element_print['element_scale_list'][i][1]))
mask = image.split()[3]
resized_source = image.resize(new_size)
@@ -165,9 +185,11 @@ class PrintPainting:
mask = np.expand_dims(mask, axis=2)
mask = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR)
mask = cv2.bitwise_not(mask)
mask = cv2.resize(mask, (int(result['final_image'].shape[1] * single_print['print_scale_list'][i][0]), int(result['final_image'].shape[0] * single_print['print_scale_list'][i][1])))
image = cv2.resize(image, (int(result['final_image'].shape[1] * single_print['print_scale_list'][i][0]), int(result['final_image'].shape[0] * single_print['print_scale_list'][i][1])))
# Coordinates need to be recomputed after rotation
rotate_mask, _ = self.img_rotate(mask, element_print['element_angle_list'][i], element_print['element_scale_list'][i])
rotate_image, rotated_new_size = self.img_rotate(image, element_print['element_angle_list'][i], element_print['element_scale_list'][i])
rotate_mask, _ = self.img_rotate(mask, element_print['element_angle_list'][i])
rotate_image, rotated_new_size = self.img_rotate(image, element_print['element_angle_list'][i])
# x, y = int(result['print']['location'][i][0] - rotated_new_size[0] - (rotate_mask.shape[0] - image.shape[0]) / 2), int(result['print']['location'][i][1] - rotated_new_size[1] - (rotate_mask.shape[1] - image.shape[1]) / 2)
x, y = int(element_print['location'][i][0] - rotated_new_size[0]), int(element_print['location'][i][1] - rotated_new_size[1])
@@ -241,6 +263,45 @@ class PrintPainting:
temp_fg = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2)
tmp2 = (result['final_image'] * (temp_fg / 255)).astype(np.uint8)
result['single_image'] = cv2.add(tmp1, tmp2)
if partial_path:
print_background = np.zeros((result['pattern_image'].shape[0], result['pattern_image'].shape[1], 3), dtype=np.uint8)
mask_background = np.zeros((result['pattern_image'].shape[0], result['pattern_image'].shape[1], 3), dtype=np.uint8)
image, image_mode = self.read_image(partial_path)
if image_mode == "RGBA":
new_size = (result['pattern_image'].shape[1], result['pattern_image'].shape[0])
mask = image.split()[3]
resized_source = image.resize(new_size)
resized_source_mask = mask.resize(new_size)
# rotated_resized_source = resized_source.rotate(-partial_print['print_angle_list'][i])
# rotated_resized_source_mask = resized_source_mask.rotate(-partial_print['print_angle_list'][i])
source_image_pil = Image.fromarray(cv2.cvtColor(print_background, cv2.COLOR_BGR2RGB))
source_image_pil_mask = Image.fromarray(cv2.cvtColor(mask_background, cv2.COLOR_BGR2RGB))
source_image_pil.paste(resized_source, (0, 0), resized_source)
source_image_pil_mask.paste(resized_source_mask, (0, 0), resized_source_mask)
print_background = cv2.cvtColor(np.array(source_image_pil), cv2.COLOR_RGBA2BGR)
mask_background = cv2.cvtColor(np.array(source_image_pil_mask), cv2.COLOR_RGBA2BGR)
ret, mask_background = cv2.threshold(mask_background, 124, 255, cv2.THRESH_BINARY)
print_mask = cv2.bitwise_and(result['mask'], cv2.cvtColor(mask_background, cv2.COLOR_BGR2GRAY))
img_fg = cv2.bitwise_or(print_background, print_background, mask=print_mask)
# TODO: element information is lost here
three_channel_image = cv2.merge([cv2.bitwise_not(print_mask), cv2.bitwise_not(print_mask), cv2.bitwise_not(print_mask)])
img_bg = cv2.bitwise_and(result['final_image'], three_channel_image)
# mask_mo = np.expand_dims(print_mask, axis=2).repeat(3, axis=2)
# gray_mo = np.expand_dims(result['gray'], axis=2).repeat(3, axis=2)
# img_fg = (img_fg * (mask_mo / 255) * (gray_mo / 255)).astype(np.uint8)
result['final_image'] = cv2.add(img_bg, img_fg)
canvas = np.full_like(result['final_image'], 255)
temp_bg = np.expand_dims(cv2.bitwise_not(result['mask']), axis=2).repeat(3, axis=2)
tmp1 = (canvas * (temp_bg / 255)).astype(np.uint8)
temp_fg = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2)
tmp2 = (result['final_image'] * (temp_fg / 255)).astype(np.uint8)
result['single_image'] = cv2.add(tmp1, tmp2)
return result
@staticmethod
@@ -360,10 +421,10 @@ class PrintPainting:
return print_image
def get_print(self, print_dict):
if 'print_scale_list' not in print_dict.keys() or print_dict['print_scale_list'][0] < 0.3:
if 'print_scale_list' not in print_dict.keys() or print_dict['print_scale_list'][0][0] < 0.3:
print_dict['scale'] = 0.3
else:
print_dict['scale'] = print_dict['print_scale_list'][0]
print_dict['scale'] = print_dict['print_scale_list'][0][0]
bucket_name = print_dict['print_path_list'][0].split("/", 1)[0]
object_name = print_dict['print_path_list'][0].split("/", 1)[1]
@@ -386,8 +447,9 @@ class PrintPainting:
# y_offset = random.randint(0, image.shape[1] - image_size_w)
# 1. Take the offset modulo the resized print's width/height to get the real offset
x_offset = print_w - int(location[0][1] % print_w)
y_offset = print_w - int(location[0][0] % print_h)
# Add half of print_w to the offset so the coordinates sit in the middle of the image; to anchor at the top-left, just remove + print_w // 2
x_offset = print_w - int(location[0][1] % print_w) + print_w // 2
y_offset = print_h - int(location[0][0] % print_h) + print_h // 2
# y_offset = int(location[0][0])
# x_offset = int(location[0][1])
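# Worked example of the new offset (illustrative numbers): with print_w = 100 and
# location[0][1] = 250, the old x_offset was 100 - 250 % 100 = 50; adding
# print_w // 2 gives 100, so the tile is anchored at the image center rather
# than the top-left corner.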
@@ -409,7 +471,7 @@ class PrintPainting:
return high, low
@staticmethod
def img_rotate(image, angel, scale):
def img_rotate(image, angel):
"""顺时针旋转图像任意角度
Args:
@@ -424,7 +486,7 @@ class PrintPainting:
center = (w // 2, h // 2)
# if type(angel) is not int:
# angel = 0
M = cv2.getRotationMatrix2D(center, -angel, scale)
M = cv2.getRotationMatrix2D(center, -angel, 1)
# Adjust the rotated image's width and height
rotated_h = int((w * np.abs(M[0, 1]) + (h * np.abs(M[0, 0]))))
rotated_w = int((h * np.abs(M[0, 1]) + (w * np.abs(M[0, 0]))))
@@ -433,7 +495,7 @@ class PrintPainting:
# Rotate the image
rotated_img = cv2.warpAffine(image, M, (rotated_w, rotated_h))
return rotated_img, ((rotated_img.shape[1] - image.shape[1] * scale) // 2, (rotated_img.shape[0] - image.shape[0] * scale) // 2)
return rotated_img, ((rotated_img.shape[1] - image.shape[1]) // 2, (rotated_img.shape[0] - image.shape[0]) // 2)
# return rotated_img, (0, 0)
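# Worked example of the enlarged canvas above (illustrative numbers): a 200x100
# image rotated by 30° has |sin| = 0.5 and |cos| ≈ 0.866, so
# rotated_h = int(200 * 0.5 + 100 * 0.866) = 186 and
# rotated_w = int(100 * 0.5 + 200 * 0.866) = 223.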
@staticmethod
@@ -442,8 +504,11 @@ class PrintPainting:
angle: the rotation angle
crop: whether to crop, a boolean flag
"""
if not isinstance(crop, bool):
raise ValueError("The 'crop' parameter must be a boolean.")
crop_image = lambda img, x0, y0, w, h: img[y0:y0 + h, x0:x0 + w]
w, h = img.shape[:2]
h, w = img.shape[:2]
# The rotation angle has a period of 360°
angle %= 360
# Compute the affine transform matrix
@@ -455,7 +520,7 @@ class PrintPainting:
if crop:
# The effective period of the crop angle is 180°
angle_crop = angle % 180
if angle > 90:
if angle_crop > 90:
angle_crop = 180 - angle_crop
# Convert the angle to radians
theta = angle_crop * np.pi / 180
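The hunk is cut off after the radian conversion; for context, the crop factor that the standard version of this rotate-and-crop routine computes next looks like the sketch below. This is an assumption about code not shown in the diff, not the committed implementation:
import numpy as np

def crop_ratio(w: int, h: int, theta: float) -> float:
    """Sketch: side ratio of the largest axis-aligned rectangle inside the
    rotated image, relative to the original size (assumed continuation)."""
    hw_ratio = float(h) / float(w)
    tan_theta = np.tan(theta)
    numerator = np.cos(theta) + np.sin(theta) * tan_theta
    # Use the longer-side ratio so the formula works for both orientations.
    r = hw_ratio if h > w else 1 / hw_ratio
    denominator = r * tan_theta + 1
    return numerator / denominator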

View File

@@ -46,4 +46,16 @@ class Scaling:
result['scale'] = result['scale_bag']
elif result['keypoint'] == 'ear_point':
result['scale'] = result['scale_earrings']
elif result['keypoint'] == 'accessories':
# Since there is no model for detecting accessory keypoints, the accessory's two keypoints are uniformly set to (0, 0) and (0, img.width)
# The mannequin keypoints are set to (0, 0) and (0, 320/2), so the distance ratio simplifies to 160 / img.width
distance_clo = result['img_shape'][1]
distance_bdy = 320 / 2
if distance_clo == 0:
result['scale'] = 1
else:
result['scale'] = distance_bdy / distance_clo
else:
result['scale'] = 1
return result
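A worked example of the accessory branch above, with hypothetical numbers: an accessory image 640 px wide is scaled against the fixed 160 px mannequin keypoint distance, and a zero-width image falls back to scale 1.
# Hypothetical numbers, for illustration only.
img_shape = (480, 640)       # (height, width) of the accessory image
distance_clo = img_shape[1]  # 640
distance_bdy = 320 / 2       # 160, the fixed mannequin keypoint distance
scale = distance_bdy / distance_clo if distance_clo else 1
assert scale == 0.25         # the accessory is shrunk to a quarter of its width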

View File

@@ -5,7 +5,8 @@ import cv2
import numpy as np
from app.core.config import SEG_CACHE_PATH
from app.service.design_batch.utils.design_ensemble import get_seg_result
from app.service.design_fast.utils.design_ensemble import get_seg_result
from app.service.utils.decorator import ClassCallRunTime
from app.service.utils.new_oss_client import oss_get_image
logger = logging.getLogger()
@@ -15,6 +16,7 @@ class Segmentation:
def __init__(self, minio_client):
self.minio_client = minio_client
@ClassCallRunTime
def __call__(self, result):
if "seg_mask_url" in result.keys() and result['seg_mask_url'] != "":
seg_mask = oss_get_image(oss_client=self.minio_client, bucket=result['seg_mask_url'].split('/')[0], object_name=result['seg_mask_url'][result['seg_mask_url'].find('/') + 1:], data_type="cv2")
@@ -31,24 +33,37 @@ class Segmentation:
result['back_mask'] = np.array(green_mask, dtype=np.uint8) * 255
result['mask'] = result['front_mask'] + result['back_mask']
else:
# Check locally whether a seg cache exists
_, seg_result = self.load_seg_result(result["image_id"])
result['seg_result'] = seg_result
if not _:
# preview: run the model without caching
if "preview_submit" in result.keys() and result['preview_submit'] == "preview":
# Run inference to get the seg result
seg_result = get_seg_result(result["image_id"], result['image'])[0]
seg_result = get_seg_result(result["image_id"], result['image'])
# submit: run the model and cache the result
elif "preview_submit" in result.keys() and result['preview_submit'] == "submit":
# Run inference to get the seg result
seg_result = get_seg_result(result["image_id"], result['image'])
self.save_seg_result(seg_result, result['image_id'])
# null: normal flow; load the local cache, or run the model if there is none
else:
# Check locally whether a seg cache exists
_, seg_result = self.load_seg_result(result["image_id"])
# Check whether the cache and the actual image are the same size
if not _ or result["image"].shape[:2] != seg_result.shape:
# Run inference to get the seg result
seg_result = get_seg_result(result["image_id"], result['image'])
self.save_seg_result(seg_result, result['image_id'])
result['seg_result'] = seg_result
# Handle the front and back pieces
temp_front = seg_result == 1.0
temp_front = seg_result == 1
result['front_mask'] = (255 * (temp_front + 0).astype(np.uint8))
temp_back = seg_result == 2.0
temp_back = seg_result == 2
result['back_mask'] = (255 * (temp_back + 0).astype(np.uint8))
result['mask'] = result['front_mask'] + result['back_mask']
return result
@staticmethod
def save_seg_result(seg_result, image_id):
file_path = f"seg_cache/{image_id}.npy"
file_path = f"{SEG_CACHE_PATH}{image_id}.npy"
try:
np.save(file_path, seg_result)
logger.debug(f"保存成功 {os.path.abspath(file_path)}")
@@ -57,7 +72,7 @@ class Segmentation:
@staticmethod
def load_seg_result(image_id):
file_path = f"seg_cache/{image_id}.npy"
file_path = f"{SEG_CACHE_PATH}{image_id}.npy"
# logger.info(f"load seg file name is :{SEG_CACHE_PATH}{image_id}.npy")
try:
seg_result = np.load(file_path)

View File

@@ -7,10 +7,11 @@ from PIL import Image
from cv2 import cvtColor, COLOR_BGR2RGBA
from app.core.config import AIDA_CLOTHING
from app.service.design_batch.utils.conversion_image import rgb_to_rgba
from app.service.design_batch.utils.upload_image import upload_png_mask
from app.service.design_fast.utils.conversion_image import rgb_to_rgba
from app.service.design_fast.utils.transparent import sketch_to_transparent
from app.service.design_fast.utils.upload_image import upload_png_mask
from app.service.utils.generate_uuid import generate_uuid
from app.service.utils.new_oss_client import oss_upload_image
from app.service.utils.new_oss_client import oss_upload_image, oss_get_image
class Split(object):
@@ -20,51 +21,95 @@ class Split(object):
def __call__(self, result):
try:
if result['name'] in ('outwear', 'dress', 'blouse', 'skirt', 'trousers', 'tops', 'bottoms'):
front_mask = result['front_mask']
back_mask = result['back_mask']
if result['name'] in ('outwear', 'dress', 'blouse', 'skirt', 'trousers', 'tops', 'bottoms', 'accessories'):
if result['resize_scale'][0] == 1.0 and result['resize_scale'][1] == 1.0:
front_mask = result['front_mask']
back_mask = result['back_mask']
else:
height, width = result['front_mask'].shape[:2]
new_width = int(width * result['resize_scale'][0])
new_height = int(height * result['resize_scale'][1])
front_mask = cv2.resize(result['front_mask'], (new_width, new_height))
back_mask = cv2.resize(result['back_mask'], (new_width, new_height))
rgba_image = rgb_to_rgba(result['final_image'], front_mask + back_mask)
new_size = (int(rgba_image.shape[1] * result["scale"] * result["resize_scale"][0]), int(rgba_image.shape[0] * result["scale"] * result["resize_scale"][1]))
new_size = (int(rgba_image.shape[1] * result["scale"]), int(rgba_image.shape[0] * result["scale"]))
rgba_image = cv2.resize(rgba_image, new_size)
result_front_image = np.zeros_like(rgba_image)
front_mask = cv2.resize(front_mask, new_size)
result_front_image[front_mask != 0] = rgba_image[front_mask != 0]
result_front_image_pil = Image.fromarray(cvtColor(result_front_image, COLOR_BGR2RGBA))
if 'transparent' in result.keys():
# User-selected transparent region
transparent = result['transparent']
if transparent['mask_url'] is not None and transparent['mask_url'] != "":
# Preprocess the user-selected region mask
seg_mask = oss_get_image(oss_client=self.minio_client, bucket=transparent['mask_url'].split('/')[0], object_name=transparent['mask_url'][transparent['mask_url'].find('/') + 1:], data_type="cv2")
seg_mask = cv2.resize(seg_mask, new_size, interpolation=cv2.INTER_NEAREST)
# Convert the color space to RGB (OpenCV defaults to BGR)
image_rgb = cv2.cvtColor(seg_mask, cv2.COLOR_BGR2RGB)
r, g, b = cv2.split(image_rgb)
blue_mask = b > r
# Create the red and green masks
transparent_mask = np.array(blue_mask, dtype=np.uint8) * 255
result_front_image_pil = sketch_to_transparent(result_front_image_pil, transparent_mask, transparent["scale"])
else:
result_front_image_pil = sketch_to_transparent(result_front_image_pil, front_mask, transparent["scale"])
result['front_image'], result["front_image_url"], _ = upload_png_mask(self.minio_client, result_front_image_pil, f'{generate_uuid()}', mask=None)
height, width = front_mask.shape
mask_image = np.zeros((height, width, 3))
mask_image[front_mask != 0] = [0, 0, 255]
if result["name"] in ('blouse', 'dress', 'outwear', 'tops'):
result_back_image = np.zeros_like(rgba_image)
back_mask = cv2.resize(back_mask, new_size)
result_back_image[back_mask != 0] = rgba_image[back_mask != 0]
result_back_image_pil = Image.fromarray(cvtColor(result_back_image, COLOR_BGR2RGBA))
result['back_image'], result["back_image_url"], _ = upload_png_mask(self.minio_client, result_back_image_pil, f'{generate_uuid()}', mask=None)
mask_image[back_mask != 0] = [0, 255, 0]
# if result["name"] in ('blouse', 'dress', 'outwear', 'tops'):
# result_back_image = np.zeros_like(rgba_image)
# back_mask = cv2.resize(back_mask, new_size)
# result_back_image[back_mask != 0] = rgba_image[back_mask != 0]
# result_back_image_pil = Image.fromarray(cvtColor(result_back_image, COLOR_BGR2RGBA))
# result['back_image'], result["back_image_url"], _ = upload_png_mask(self.minio_client, result_back_image_pil, f'{generate_uuid()}', mask=None)
# mask_image[back_mask != 0] = [0, 255, 0]
#
# rbga_mask = rgb_to_rgba(mask_image, front_mask + back_mask)
# mask_pil = Image.fromarray(cvtColor(rbga_mask.astype(np.uint8), COLOR_BGR2RGBA))
# image_data = io.BytesIO()
# mask_pil.save(image_data, format='PNG')
# image_data.seek(0)
# image_bytes = image_data.read()
# req = oss_upload_image(oss_client=self.minio_client, bucket=AIDA_CLOTHING, object_name=f"mask/mask_{generate_uuid()}.png", image_bytes=image_bytes)
# result['mask_url'] = req.bucket_name + "/" + req.object_name
# else:
# rbga_mask = rgb_to_rgba(mask_image, front_mask)
# mask_pil = Image.fromarray(cvtColor(rbga_mask.astype(np.uint8), COLOR_BGR2RGBA))
# image_data = io.BytesIO()
# mask_pil.save(image_data, format='PNG')
# image_data.seek(0)
# image_bytes = image_data.read()
# req = oss_upload_image(oss_client=self.minio_client, bucket=AIDA_CLOTHING, object_name=f"mask/mask_{generate_uuid()}.png", image_bytes=image_bytes)
# result['mask_url'] = req.bucket_name + "/" + req.object_name
# result['back_image'] = None
# result["back_image_url"] = None
# # result["back_mask_url"] = None
# # result['back_mask_image'] = None
rbga_mask = rgb_to_rgba(mask_image, front_mask + back_mask)
mask_pil = Image.fromarray(cvtColor(rbga_mask.astype(np.uint8), COLOR_BGR2RGBA))
image_data = io.BytesIO()
mask_pil.save(image_data, format='PNG')
image_data.seek(0)
image_bytes = image_data.read()
req = oss_upload_image(oss_client=self.minio_client, bucket=AIDA_CLOTHING, object_name=f"mask/mask_{generate_uuid()}.png", image_bytes=image_bytes)
result['mask_url'] = req.bucket_name + "/" + req.object_name
else:
rbga_mask = rgb_to_rgba(mask_image, front_mask)
mask_pil = Image.fromarray(cvtColor(rbga_mask.astype(np.uint8), COLOR_BGR2RGBA))
image_data = io.BytesIO()
mask_pil.save(image_data, format='PNG')
image_data.seek(0)
image_bytes = image_data.read()
req = oss_upload_image(oss_client=self.minio_client, bucket=AIDA_CLOTHING, object_name=f"mask/mask_{generate_uuid()}.png", image_bytes=image_bytes)
result['mask_url'] = req.bucket_name + "/" + req.object_name
result['back_image'] = None
result["back_image_url"] = None
# result["back_mask_url"] = None
# result['back_mask_image'] = None
result_back_image = np.zeros_like(rgba_image)
back_mask = cv2.resize(back_mask, new_size)
result_back_image[back_mask != 0] = rgba_image[back_mask != 0]
result_back_image_pil = Image.fromarray(cvtColor(result_back_image, COLOR_BGR2RGBA))
result['back_image'], result["back_image_url"], _ = upload_png_mask(self.minio_client, result_back_image_pil, f'{generate_uuid()}', mask=None)
mask_image[back_mask != 0] = [0, 255, 0]
rbga_mask = rgb_to_rgba(mask_image, front_mask + back_mask)
mask_pil = Image.fromarray(cvtColor(rbga_mask.astype(np.uint8), COLOR_BGR2RGBA))
image_data = io.BytesIO()
mask_pil.save(image_data, format='PNG')
image_data.seek(0)
image_bytes = image_data.read()
req = oss_upload_image(oss_client=self.minio_client, bucket=AIDA_CLOTHING, object_name=f"mask/mask_{generate_uuid()}.png", image_bytes=image_bytes)
result['mask_url'] = req.bucket_name + "/" + req.object_name
# Create the middle layer
result_pattern_image_rgba = rgb_to_rgba(result['pattern_image'], result['mask'])
result_pattern_image_pil = Image.fromarray(cvtColor(result_pattern_image_rgba, COLOR_BGR2RGBA))

View File

@@ -2,16 +2,16 @@ import json
import pika
from app.core.config import RABBITMQ_PARAMS
from app.core.config import RABBITMQ_PARAMS, BATCH_DESIGN_RABBITMQ_QUEUES
def publish_status(task_id, progress, result):
connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
channel = connection.channel()
channel.queue_declare(queue='DesignBatch', durable=True)
channel.queue_declare(queue=BATCH_DESIGN_RABBITMQ_QUEUES, durable=True)
message = {'task_id': task_id, 'progress': progress, "result": result}
channel.basic_publish(exchange='',
routing_key='DesignBatch',
routing_key=BATCH_DESIGN_RABBITMQ_QUEUES,
body=json.dumps(message),
properties=pika.BasicProperties(
delivery_mode=2,
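For context, a minimal consumer for the status messages published above; a sketch assuming the same RABBITMQ_PARAMS and BATCH_DESIGN_RABBITMQ_QUEUES config values, not part of this commit:
import json
import pika
from app.core.config import RABBITMQ_PARAMS, BATCH_DESIGN_RABBITMQ_QUEUES

def consume_status():
    connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
    channel = connection.channel()
    channel.queue_declare(queue=BATCH_DESIGN_RABBITMQ_QUEUES, durable=True)

    def on_message(ch, method, properties, body):
        message = json.loads(body)  # {'task_id': ..., 'progress': ..., 'result': ...}
        print(message["task_id"], message["progress"], message["result"])
        ch.basic_ack(delivery_tag=method.delivery_tag)

    channel.basic_consume(queue=BATCH_DESIGN_RABBITMQ_QUEUES, on_message_callback=on_message)
    channel.start_consuming()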

View File

@@ -33,8 +33,8 @@ def organize_clothing(layer):
mask=cv2.resize(layer['mask'], layer["front_image"].size),
gradient_string=layer['gradient_string'] if 'gradient_string' in layer.keys() else "",
pattern_image_url=layer['pattern_image_url'],
pattern_image=layer['pattern_image']
pattern_image=layer['pattern_image'],
# back_perspective_url=layer['back_perspective_url'] if 'back_perspective_url' in layer.keys() else ""
)
# Back piece data
back_layer = dict(priority=-layer.get("priority", 0) if layer.get("layer_order", False) else PRIORITY_DICT.get(f'{layer["name"].lower()}_back', None),
@@ -50,6 +50,46 @@ def organize_clothing(layer):
mask=cv2.resize(layer['mask'], layer["front_image"].size),
gradient_string=layer['gradient_string'] if 'gradient_string' in layer.keys() else "",
pattern_image_url=layer['pattern_image_url'],
# back_perspective_url=layer['back_perspective_url'] if 'back_perspective_url' in layer.keys() else ""
)
return front_layer, back_layer
def organize_accessories(layer):
# Starting coordinates
start_point = (0, 0)
# Front piece data
front_layer = dict(priority=layer['priority'] if layer.get("layer_order", False) else PRIORITY_DICT.get(f'{layer["name"].lower()}_front', None),
name=f'{layer["name"].lower()}_front',
image=layer["front_image"],
# mask_image=layer['front_mask_image'],
image_url=layer['front_image_url'],
mask_url=layer['mask_url'],
sacle=layer['scale'],
clothes_keypoint=(0, 0),
position=start_point,
resize_scale=layer["resize_scale"],
mask=cv2.resize(layer['mask'], layer["front_image"].size),
gradient_string=layer['gradient_string'] if 'gradient_string' in layer.keys() else "",
pattern_image_url=layer['pattern_image_url'],
pattern_image=layer['pattern_image'],
# back_perspective_url=layer['back_perspective_url'] if 'back_perspective_url' in layer.keys() else ""
)
# Back piece data
back_layer = dict(priority=-layer.get("priority", 0) if layer.get("layer_order", False) else PRIORITY_DICT.get(f'{layer["name"].lower()}_back', None),
name=f'{layer["name"].lower()}_back',
image=layer["back_image"],
# mask_image=layer['back_mask_image'],
image_url=layer['back_image_url'],
mask_url=layer['mask_url'],
sacle=layer['scale'],
clothes_keypoint=(0, 0),
position=start_point,
resize_scale=layer["resize_scale"],
mask=cv2.resize(layer['mask'], layer["front_image"].size),
gradient_string=layer['gradient_string'] if 'gradient_string' in layer.keys() else "",
pattern_image_url=layer['pattern_image_url'],
# back_perspective_url=layer['back_perspective_url'] if 'back_perspective_url' in layer.keys() else ""
)
return front_layer, back_layer

View File

@@ -200,6 +200,11 @@ def design_generate_v2(request_data):
items_response['synthesis_url'] = synthesis_single(item_result['front_image'], item_result['back_image'])
# Send the result to the Java side
url = JAVA_STREAM_API_URL
# xu_pei_test_url = "https://cd21b9110505.ngrok-free.app/api/third/party/receiveDesignResults"
logger.info(f"java 回调 -> {url}")
# logger.info(f"xupei java 回调 -> {xu_pei_test_url}")
headers = {
'Accept': "*/*",
'Accept-Encoding': "gzip, deflate, br",
@@ -213,6 +218,11 @@ def design_generate_v2(request_data):
# Print the result
logger.info(response.text)
# response = post_request(xu_pei_test_url, json_data=items_response, headers=headers)
# if response:
# Print the result
# logger.info(f"xupei test response : {response.text}")
for step, object in enumerate(objects_data):
t = threading.Thread(target=process_object, args=(step, object))
threads.append(t)

View File

@@ -57,7 +57,7 @@ class BottomItem(BaseItem):
LoadImage(minio_client),
KeyPoint(),
ContourDetection(),
# Segmentation(),
Segmentation(minio_client),
# BackPerspective(minio_client),
Color(minio_client),
PrintPainting(minio_client),

View File

@@ -29,6 +29,24 @@ class Color:
else:
pattern = self.get_pattern(result['color'])
resize_pattern = cv2.resize(pattern, (dim_image_w, dim_image_h), interpolation=cv2.INTER_AREA)
if "partial_color" in result.keys() and result['partial_color'] != "":
bucket_name = result['partial_color'].split('/')[0]
object_name = result['partial_color'][result['partial_color'].find('/') + 1:]
partial_color = oss_get_image(oss_client=self.minio_client, bucket=bucket_name, object_name=object_name, data_type="cv2")
h, w = partial_color.shape[0:2]
resize_pattern = cv2.resize(resize_pattern, (w, h), interpolation=cv2.INTER_AREA)
# Split out the alpha channel of the PNG image
alpha_channel = partial_color[:, :, 3]
# Extract the RGB channels of the PNG image
png_rgb = partial_color[:, :, :3]
# Create a mask the same size as the cv image, marking which pixels to replace
mask = alpha_channel > 0
# Expand the mask to 3 channels for element-wise operations with the cv image
mask_3ch = np.stack([mask] * 3, axis=-1)
# Overlay the PNG colors onto the cv image according to the mask
resize_pattern[mask_3ch] = png_rgb[mask_3ch]
resize_pattern = cv2.resize(resize_pattern, (dim_image_w, dim_image_h), interpolation=cv2.INTER_AREA)
closed_mo = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2)
gray_mo = np.expand_dims(result['gray'], axis=2).repeat(3, axis=2)
get_image_fir = resize_pattern * (closed_mo / 255) * (gray_mo / 255)

View File

@@ -15,6 +15,7 @@ class PrintPainting:
single_print = result['print']['single']
overall_print = result['print']['overall']
element_print = result['print']['element']
partial_path = result['print']['partial'] if 'partial' in result['print'] else None
result['single_image'] = None
result['print_image'] = None
# TODO: resize result['pattern_image'] to the resize_scale size
@@ -32,7 +33,6 @@ class PrintPainting:
result['mask'] = cv2.resize(result['mask'], (new_width, new_height))
result['gray'] = cv2.resize(result['gray'], (new_width, new_height))
print(1)
if overall_print['print_path_list']:
painting_dict = {'dim_image_h': result['pattern_image'].shape[0], 'dim_image_w': result['pattern_image'].shape[1]}
result['print_image'] = result['pattern_image']
@@ -54,90 +54,89 @@ class PrintPainting:
mask_background = np.zeros((result['pattern_image'].shape[0], result['pattern_image'].shape[1], 3), dtype=np.uint8)
for i in range(len(single_print['print_path_list'])):
image, image_mode = self.read_image(single_print['print_path_list'][i])
if image_mode == "RGBA":
new_size = (int(result['pattern_image'].shape[1] * single_print['print_scale_list'][i][0]), int(result['pattern_image'].shape[0] * single_print['print_scale_list'][i][1]))
mask = image.split()[3]
resized_source = image.resize(new_size)
resized_source_mask = mask.resize(new_size)
if image_mode == "RGB":
image_rgba = cv2.cvtColor(image, cv2.COLOR_BGR2RGBA)
image = Image.fromarray(image_rgba)
rotated_resized_source = resized_source.rotate(-single_print['print_angle_list'][i])
rotated_resized_source_mask = resized_source_mask.rotate(-single_print['print_angle_list'][i])
source_image_pil = Image.fromarray(cv2.cvtColor(print_background, cv2.COLOR_BGR2RGB))
source_image_pil_mask = Image.fromarray(cv2.cvtColor(mask_background, cv2.COLOR_BGR2RGB))
source_image_pil.paste(rotated_resized_source, (int(single_print['location'][i][0]), int(single_print['location'][i][1])), rotated_resized_source)
source_image_pil_mask.paste(rotated_resized_source_mask, (int(single_print['location'][i][0]), int(single_print['location'][i][1])), rotated_resized_source_mask)
print_background = cv2.cvtColor(np.array(source_image_pil), cv2.COLOR_RGBA2BGR)
mask_background = cv2.cvtColor(np.array(source_image_pil_mask), cv2.COLOR_RGBA2BGR)
ret, mask_background = cv2.threshold(mask_background, 124, 255, cv2.THRESH_BINARY)
else:
mask = self.get_mask_inv(image)
mask = np.expand_dims(mask, axis=2)
mask = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR)
mask = cv2.bitwise_not(mask)
mask = cv2.resize(mask, (int(result['final_image'].shape[1] * single_print['print_scale_list'][i][0]), int(result['final_image'].shape[0] * single_print['print_scale_list'][i][1])))
image = cv2.resize(image, (int(result['final_image'].shape[1] * single_print['print_scale_list'][i][0]), int(result['final_image'].shape[0] * single_print['print_scale_list'][i][1])))
# Coordinates must be recomputed after rotation
rotate_mask, _ = self.img_rotate(mask, single_print['print_angle_list'][i])
rotate_image, rotated_new_size = self.img_rotate(image, single_print['print_angle_list'][i])
# x, y = int(result['print']['location'][i][0] - rotated_new_size[0] - (rotate_mask.shape[0] - image.shape[0]) / 2), int(result['print']['location'][i][1] - rotated_new_size[1] - (rotate_mask.shape[1] - image.shape[1]) / 2)
x, y = int(single_print['location'][i][0] - rotated_new_size[0]), int(single_print['location'][i][1] - rotated_new_size[1])
image_x = print_background.shape[1]
image_y = print_background.shape[0]
print_x = rotate_image.shape[1]
print_y = rotate_image.shape[0]
# Buggy earlier version:
# if x + print_x > image_x:
# rotate_image = rotate_image[:, :x + print_x - image_x]
# rotate_mask = rotate_mask[:, :x + print_x - image_x]
# #
# if y + print_y > image_y:
# rotate_image = rotate_image[:y + print_y - image_y]
# rotate_mask = rotate_mask[:y + print_y - image_y]
# The two passes must not run in parallel:
# the first pass (the ifs at former lines 108 and 115) checks the bottom and right bounds, while the second checks whether the top-left overruns; cropping the right edge first and then shifting the region left would corrupt it.
# So: shift first, then check, finally crop.
# If the print is rotated or sits flush against an edge, check whether its left and top bounds fall below 0.
if x <= 0:
rotate_image = rotate_image[:, -x:]
rotate_mask = rotate_mask[:, -x:]
start_x = x = 0
else:
start_x = x
if y <= 0:
rotate_image = rotate_image[-y:, :]
rotate_mask = rotate_mask[-y:, :]
start_y = y = 0
else:
start_y = y
# ------------------
# If the print size exceeds the image size, crop the print
if x + print_x > image_x:
rotate_image = rotate_image[:, :image_x - x]
rotate_mask = rotate_mask[:, :image_x - x]
if y + print_y > image_y:
rotate_image = rotate_image[:image_y - y, :]
rotate_mask = rotate_mask[:image_y - y, :]
# mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]] = cv2.bitwise_xor(mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]], rotate_mask)
# print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]] = cv2.add(print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]], rotate_image)
# mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]] = rotate_mask
# print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]] = rotate_image
mask_background = self.stack_prin(mask_background, result['pattern_image'], rotate_mask, start_y, y, start_x, x)
print_background = self.stack_prin(print_background, result['pattern_image'], rotate_image, start_y, y, start_x, x)
new_size = (int(result['pattern_image'].shape[1] * single_print['print_scale_list'][i][0]), int(result['pattern_image'].shape[0] * single_print['print_scale_list'][i][1]))
mask = image.split()[3]
resized_source = image.resize(new_size)
resized_source_mask = mask.resize(new_size)
rotated_resized_source = resized_source.rotate(-single_print['print_angle_list'][i])
rotated_resized_source_mask = resized_source_mask.rotate(-single_print['print_angle_list'][i])
source_image_pil = Image.fromarray(cv2.cvtColor(print_background, cv2.COLOR_BGR2RGB))
source_image_pil_mask = Image.fromarray(cv2.cvtColor(mask_background, cv2.COLOR_BGR2RGB))
source_image_pil.paste(rotated_resized_source, (int(single_print['location'][i][0]), int(single_print['location'][i][1])), rotated_resized_source)
source_image_pil_mask.paste(rotated_resized_source_mask, (int(single_print['location'][i][0]), int(single_print['location'][i][1])), rotated_resized_source_mask)
print_background = cv2.cvtColor(np.array(source_image_pil), cv2.COLOR_RGBA2BGR)
mask_background = cv2.cvtColor(np.array(source_image_pil_mask), cv2.COLOR_RGBA2BGR)
ret, mask_background = cv2.threshold(mask_background, 124, 255, cv2.THRESH_BINARY)
# else:
# mask = self.get_mask_inv(image)
# mask = np.expand_dims(mask, axis=2)
# mask = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR)
# mask = cv2.bitwise_not(mask)
#
# mask = cv2.resize(mask, (int(result['final_image'].shape[1] * single_print['print_scale_list'][i][0]), int(result['final_image'].shape[0] * single_print['print_scale_list'][i][1])))
# image = cv2.resize(image, (int(result['final_image'].shape[1] * single_print['print_scale_list'][i][0]), int(result['final_image'].shape[0] * single_print['print_scale_list'][i][1])))
# # Coordinates must be recomputed after rotation
# rotate_mask, _ = self.img_rotate(mask, single_print['print_angle_list'][i])
# rotate_image, rotated_new_size = self.img_rotate(image, single_print['print_angle_list'][i])
# # x, y = int(result['print']['location'][i][0] - rotated_new_size[0] - (rotate_mask.shape[0] - image.shape[0]) / 2), int(result['print']['location'][i][1] - rotated_new_size[1] - (rotate_mask.shape[1] - image.shape[1]) / 2)
# x, y = int(single_print['location'][i][0] - rotated_new_size[0]), int(single_print['location'][i][1] - rotated_new_size[1])
#
# image_x = print_background.shape[1]  # background width
# image_y = print_background.shape[0]  # background height
# print_x = rotate_image.shape[1]  # print width
# print_y = rotate_image.shape[0]  # print height
#
# # Buggy earlier version:
# # if x + print_x > image_x:
# # rotate_image = rotate_image[:, :x + print_x - image_x]
# # rotate_mask = rotate_mask[:, :x + print_x - image_x]
# # #
# # if y + print_y > image_y:
# # rotate_image = rotate_image[:y + print_y - image_y]
# # rotate_mask = rotate_mask[:y + print_y - image_y]
#
# # The two passes must not run in parallel:
# # the first pass (the ifs at former lines 108 and 115) checks the bottom and right bounds, while the second checks whether the top-left overruns; cropping the right edge first and then shifting the region left would corrupt it.
# # So: shift first, then check, finally crop.
#
# # If the print is rotated or sits flush against an edge, check whether its left and top bounds fall below 0.
# if x <= 0:  # If the x offset is below 0, the print must be cropped to fit; once the offset exceeds the print width, the cropped width becomes 0
# rotate_image = rotate_image[:, abs(x):]
# rotate_mask = rotate_mask[:, abs(x):]
# start_x = x = 0
# else:
# start_x = x
#
# if y <= 0:  # Likewise for the y offset: below 0 the print is cropped to fit; once the offset exceeds the print height, the cropped height becomes 0
# rotate_image = rotate_image[abs(y):, :]
# rotate_mask = rotate_mask[abs(y):, :]
# start_y = y = 0
# else:
# start_y = y
#
# # ------------------
# # If the print size exceeds the image size, crop the print
#
# if x + print_x > image_x:
# rotate_image = rotate_image[:, :image_x - x]
# rotate_mask = rotate_mask[:, :image_x - x]
#
# if y + print_y > image_y:
# rotate_image = rotate_image[:image_y - y, :]
# rotate_mask = rotate_mask[:image_y - y, :]
#
# # mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]] = cv2.bitwise_xor(mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]], rotate_mask)
# # print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]] = cv2.add(print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]], rotate_image)
#
# # mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]] = rotate_mask
# # print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]] = rotate_image
# mask_background = self.stack_prin(mask_background, result['pattern_image'], rotate_mask, start_y, y, start_x, x)
# print_background = self.stack_prin(print_background, result['pattern_image'], rotate_image, start_y, y, start_x, x)
# gray_image = cv2.cvtColor(mask_background, cv2.COLOR_BGR2GRAY)
# print_background = cv2.bitwise_and(print_background, print_background, mask=gray_image)
@@ -262,6 +261,45 @@ class PrintPainting:
temp_fg = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2)
tmp2 = (result['final_image'] * (temp_fg / 255)).astype(np.uint8)
result['single_image'] = cv2.add(tmp1, tmp2)
if partial_path:
print_background = np.zeros((result['pattern_image'].shape[0], result['pattern_image'].shape[1], 3), dtype=np.uint8)
mask_background = np.zeros((result['pattern_image'].shape[0], result['pattern_image'].shape[1], 3), dtype=np.uint8)
image, image_mode = self.read_image(partial_path)
if image_mode == "RGBA":
new_size = (result['pattern_image'].shape[1], result['pattern_image'].shape[0])
mask = image.split()[3]
resized_source = image.resize(new_size)
resized_source_mask = mask.resize(new_size)
# rotated_resized_source = resized_source.rotate(-partial_print['print_angle_list'][i])
# rotated_resized_source_mask = resized_source_mask.rotate(-partial_print['print_angle_list'][i])
source_image_pil = Image.fromarray(cv2.cvtColor(print_background, cv2.COLOR_BGR2RGB))
source_image_pil_mask = Image.fromarray(cv2.cvtColor(mask_background, cv2.COLOR_BGR2RGB))
source_image_pil.paste(resized_source, (0, 0), resized_source)
source_image_pil_mask.paste(resized_source_mask, (0, 0), resized_source_mask)
print_background = cv2.cvtColor(np.array(source_image_pil), cv2.COLOR_RGBA2BGR)
mask_background = cv2.cvtColor(np.array(source_image_pil_mask), cv2.COLOR_RGBA2BGR)
ret, mask_background = cv2.threshold(mask_background, 124, 255, cv2.THRESH_BINARY)
print_mask = cv2.bitwise_and(result['mask'], cv2.cvtColor(mask_background, cv2.COLOR_BGR2GRAY))
img_fg = cv2.bitwise_or(print_background, print_background, mask=print_mask)
# TODO: element information is lost here
three_channel_image = cv2.merge([cv2.bitwise_not(print_mask), cv2.bitwise_not(print_mask), cv2.bitwise_not(print_mask)])
img_bg = cv2.bitwise_and(result['final_image'], three_channel_image)
# mask_mo = np.expand_dims(print_mask, axis=2).repeat(3, axis=2)
# gray_mo = np.expand_dims(result['gray'], axis=2).repeat(3, axis=2)
# img_fg = (img_fg * (mask_mo / 255) * (gray_mo / 255)).astype(np.uint8)
result['final_image'] = cv2.add(img_bg, img_fg)
canvas = np.full_like(result['final_image'], 255)
temp_bg = np.expand_dims(cv2.bitwise_not(result['mask']), axis=2).repeat(3, axis=2)
tmp1 = (canvas * (temp_bg / 255)).astype(np.uint8)
temp_fg = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2)
tmp2 = (result['final_image'] * (temp_fg / 255)).astype(np.uint8)
result['single_image'] = cv2.add(tmp1, tmp2)
return result
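# The single_image composition above can be read as: paint the garment onto a
# white canvas using its mask. A minimal equivalent sketch (composite_on_white
# is an illustrative name, not part of this codebase):
import cv2
import numpy as np

def composite_on_white(image, mask):
    """image: BGR uint8; mask: single-channel uint8 with 0/255 values."""
    canvas = np.full_like(image, 255)
    fg = np.expand_dims(mask, axis=2).repeat(3, axis=2) / 255
    bg = np.expand_dims(cv2.bitwise_not(mask), axis=2).repeat(3, axis=2) / 255
    return cv2.add((canvas * bg).astype(np.uint8), (image * fg).astype(np.uint8))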
@staticmethod
@@ -414,7 +452,6 @@ class PrintPainting:
# y_offset = int(location[0][0])
# x_offset = int(location[0][1])
if len(image.shape) == 2:
image = image[x_offset: x_offset + image_size_h, y_offset: y_offset + image_size_w]
elif len(image.shape) == 3:

View File

@@ -65,35 +65,51 @@ class Split(object):
mask_image = np.zeros((height, width, 3))
mask_image[front_mask != 0] = [0, 0, 255]
if result["name"] in ('blouse', 'dress', 'outwear', 'tops'):
result_back_image = np.zeros_like(rgba_image)
back_mask = cv2.resize(back_mask, new_size)
result_back_image[back_mask != 0] = rgba_image[back_mask != 0]
result_back_image_pil = Image.fromarray(cvtColor(result_back_image, COLOR_BGR2RGBA))
result['back_image'], result["back_image_url"], _ = upload_png_mask(self.minio_client, result_back_image_pil, f'{generate_uuid()}', mask=None)
mask_image[back_mask != 0] = [0, 255, 0]
# if result["name"] in ('blouse', 'dress', 'outwear', 'tops'):
# result_back_image = np.zeros_like(rgba_image)
# back_mask = cv2.resize(back_mask, new_size)
# result_back_image[back_mask != 0] = rgba_image[back_mask != 0]
# result_back_image_pil = Image.fromarray(cvtColor(result_back_image, COLOR_BGR2RGBA))
# result['back_image'], result["back_image_url"], _ = upload_png_mask(self.minio_client, result_back_image_pil, f'{generate_uuid()}', mask=None)
# mask_image[back_mask != 0] = [0, 255, 0]
#
# rgba_mask = rgb_to_rgba(mask_image, front_mask + back_mask)
# mask_pil = Image.fromarray(cvtColor(rgba_mask.astype(np.uint8), COLOR_BGR2RGBA))
# image_data = io.BytesIO()
# mask_pil.save(image_data, format='PNG')
# image_data.seek(0)
# image_bytes = image_data.read()
# req = oss_upload_image(oss_client=self.minio_client, bucket=AIDA_CLOTHING, object_name=f"mask/mask_{generate_uuid()}.png", image_bytes=image_bytes)
# result['mask_url'] = req.bucket_name + "/" + req.object_name
# else:
# rgba_mask = rgb_to_rgba(mask_image, front_mask)
# mask_pil = Image.fromarray(cvtColor(rgba_mask.astype(np.uint8), COLOR_BGR2RGBA))
# image_data = io.BytesIO()
# mask_pil.save(image_data, format='PNG')
# image_data.seek(0)
# image_bytes = image_data.read()
# req = oss_upload_image(oss_client=self.minio_client, bucket=AIDA_CLOTHING, object_name=f"mask/mask_{generate_uuid()}.png", image_bytes=image_bytes)
# result['mask_url'] = req.bucket_name + "/" + req.object_name
# result['back_image'] = None
# result["back_image_url"] = None
# # result["back_mask_url"] = None
# # result['back_mask_image'] = None
rgba_mask = rgb_to_rgba(mask_image, front_mask + back_mask)
mask_pil = Image.fromarray(cvtColor(rgba_mask.astype(np.uint8), COLOR_BGR2RGBA))
image_data = io.BytesIO()
mask_pil.save(image_data, format='PNG')
image_data.seek(0)
image_bytes = image_data.read()
req = oss_upload_image(oss_client=self.minio_client, bucket=AIDA_CLOTHING, object_name=f"mask/mask_{generate_uuid()}.png", image_bytes=image_bytes)
result['mask_url'] = req.bucket_name + "/" + req.object_name
else:
rgba_mask = rgb_to_rgba(mask_image, front_mask)
mask_pil = Image.fromarray(cvtColor(rgba_mask.astype(np.uint8), COLOR_BGR2RGBA))
image_data = io.BytesIO()
mask_pil.save(image_data, format='PNG')
image_data.seek(0)
image_bytes = image_data.read()
req = oss_upload_image(oss_client=self.minio_client, bucket=AIDA_CLOTHING, object_name=f"mask/mask_{generate_uuid()}.png", image_bytes=image_bytes)
result['mask_url'] = req.bucket_name + "/" + req.object_name
result['back_image'] = None
result["back_image_url"] = None
# result["back_mask_url"] = None
# result['back_mask_image'] = None
result_back_image = np.zeros_like(rgba_image)
back_mask = cv2.resize(back_mask, new_size)
result_back_image[back_mask != 0] = rgba_image[back_mask != 0]
result_back_image_pil = Image.fromarray(cvtColor(result_back_image, COLOR_BGR2RGBA))
result['back_image'], result["back_image_url"], _ = upload_png_mask(self.minio_client, result_back_image_pil, f'{generate_uuid()}', mask=None)
mask_image[back_mask != 0] = [0, 255, 0]
rgba_mask = rgb_to_rgba(mask_image, front_mask + back_mask)
mask_pil = Image.fromarray(cvtColor(rgba_mask.astype(np.uint8), COLOR_BGR2RGBA))
image_data = io.BytesIO()
mask_pil.save(image_data, format='PNG')
image_data.seek(0)
image_bytes = image_data.read()
req = oss_upload_image(oss_client=self.minio_client, bucket=AIDA_CLOTHING, object_name=f"mask/mask_{generate_uuid()}.png", image_bytes=image_bytes)
result['mask_url'] = req.bucket_name + "/" + req.object_name
# Create the intermediate layer
result_pattern_image_rgba = rgb_to_rgba(result['pattern_image'], result['mask'])
result_pattern_image_pil = Image.fromarray(cvtColor(result_pattern_image_rgba, COLOR_BGR2RGBA))
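# Note on the mask convention used above: front pixels are painted [0, 0, 255]
# and back pixels [0, 255, 0] in BGR order (red and green respectively once
# converted for the RGBA upload). A hedged sketch of recovering the two regions
# from such a mask (split_front_back is illustrative, not part of this codebase):
import numpy as np

def split_front_back(mask_bgr):
    front = (mask_bgr == (0, 0, 255)).all(axis=-1).astype(np.uint8) * 255
    back = (mask_bgr == (0, 255, 0)).all(axis=-1).astype(np.uint8) * 255
    return front, back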

View File

@@ -0,0 +1,24 @@
from app.service.generate_batch_image.service_batch_generate_product_image import batch_generate_product, publish_status as product_publish_status
from app.service.generate_batch_image.service_batch_generate_relight_image import batch_generate_relight, publish_status as relight_publish_status
from app.service.generate_batch_image.service_batch_pose_transform import batch_generate_pose_transform, publish_status as pose_transform_publish_status
async def start_product_batch_generate(data):
generate_clothes_task = batch_generate_product.delay(data.dict())
print(generate_clothes_task)
product_publish_status(data.batch_tasks_id, f"0/{len(data.batch_data_list)}", "")
return {"task_id": data.batch_tasks_id, "state": generate_clothes_task.state}
async def start_relight_batch_generate(data):
generate_clothes_task = batch_generate_relight.delay(data.dict())
print(generate_clothes_task)
relight_publish_status(data.batch_tasks_id, f"0/{len(data.batch_data_list)}", "")
return {"task_id": data.batch_tasks_id, "state": generate_clothes_task.state}
async def start_pose_transform_batch_generate(data):
generate_clothes_task = batch_generate_pose_transform.delay(data.dict())
print(generate_clothes_task)
pose_transform_publish_status(data.tasks_id, f"0/{data.batch_size}", "")
return {"task_id": data.tasks_id, "state": generate_clothes_task.state}

View File

@@ -0,0 +1,242 @@
# Legacy product version
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project :trinity_client
@File :service_att_recognition.py
@Author :周成融
@Date :2023/7/26 12:01:05
@detail :
"""
import json
import logging
import pika
import cv2
import numpy as np
import tritonclient.grpc as grpcclient
from PIL import Image
from celery import Celery
from tritonclient.utils import np_to_triton_dtype
from app.core.config import *
from app.schemas.generate_image import BatchGenerateProductImageModel, ProductItemModel
from app.service.generate_image.utils.upload_sd_image import upload_SDXL_image
from app.service.utils.oss_client import oss_get_image
celery_app = Celery('product_tasks', broker=f'amqp://rabbit:123456@18.167.251.121:5672//', backend='rpc://', BROKER_CONNECTION_RETRY_ON_STARTUP=True)
celery_app.conf.task_default_queue = 'queue_product'
celery_app.conf.worker_log_format = '%(asctime)s %(filename)s [line:%(lineno)d] %(levelname)s %(message)s'
celery_app.conf.worker_hijack_root_logger = False
logger = logging.getLogger()
logging.getLogger('pika').setLevel(logging.WARNING)
grpc_client = grpcclient.InferenceServerClient(url=GPI_MODEL_URL)
category = "product_image"
@celery_app.task
def batch_generate_product(batch_request_data):
batch_size = len(batch_request_data['batch_data_list'])
logger.info(f"batch_generate_product batch_request_data:{json.dumps(batch_request_data, indent=4)}")
batch_tasks_id = batch_request_data['batch_tasks_id']
user_id = batch_request_data['user_id']
result_data_list = []
for i, data in enumerate(batch_request_data['batch_data_list']):
tasks_id = data['tasks_id']
image = pre_processing_image(data['image_url'])
image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
images = [image.astype(np.uint8)] * 1
prompts = [data['prompt']] * 1
if data['product_type'] == "single":
text_obj = np.array(prompts, dtype="object").reshape((-1, 1))
image_obj = np.array(images, dtype=np.uint8).reshape((-1, 768, 512, 3))
image_strength_obj = np.array(data['image_strength'], dtype=np.float32).reshape((-1, 1))
else:
text_obj = np.array(prompts, dtype="object").reshape((1))
image_obj = np.array(images, dtype=np.uint8).reshape((768, 512, 3))
image_strength_obj = np.array(data['image_strength'], dtype=np.float32).reshape((1))
input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype))
input_image = grpcclient.InferInput("input_image", image_obj.shape, "UINT8")
input_image_strength = grpcclient.InferInput("image_strength", image_strength_obj.shape, np_to_triton_dtype(image_strength_obj.dtype))
input_text.set_data_from_numpy(text_obj)
input_image.set_data_from_numpy(image_obj)
input_image_strength.set_data_from_numpy(image_strength_obj)
inputs = [input_text, input_image, input_image_strength]
try:
if data['product_type'] == "single":
result = grpc_client.infer(model_name=GPI_MODEL_NAME_SINGLE, inputs=inputs, priority=100)
image = result.as_numpy("generated_cnet_image")
else:
result = grpc_client.infer(model_name=GPI_MODEL_NAME_OVERALL, inputs=inputs, priority=100)
image = result.as_numpy("generated_inpaint_image")
image_result = Image.fromarray(np.squeeze(image.astype(np.uint8)))
except Exception as e:
if 'mask_list' in str(e):
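# Fallback inferred from the retry shapes (not documented in the source): when the
# overall model rejects a request complaining about mask_list, the inputs are
# rebuilt with the batched (-1, ...) shapes and retried against the single model.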
e_text_obj = np.array(prompts, dtype="object").reshape((-1, 1))
e_image_obj = np.array(images, dtype=np.uint8).reshape((-1, 768, 512, 3))
e_image_strength_obj = np.array(data['image_strength'], dtype=np.float32).reshape((-1, 1))
e_input_text = grpcclient.InferInput("prompt", e_text_obj.shape, np_to_triton_dtype(e_text_obj.dtype))
e_input_image = grpcclient.InferInput("input_image", e_image_obj.shape, "UINT8")
e_input_image_strength = grpcclient.InferInput("image_strength", e_image_strength_obj.shape, np_to_triton_dtype(e_image_strength_obj.dtype))
e_input_text.set_data_from_numpy(e_text_obj)
e_input_image.set_data_from_numpy(e_image_obj)
e_input_image_strength.set_data_from_numpy(e_image_strength_obj)
result = grpc_client.infer(model_name=GPI_MODEL_NAME_SINGLE, inputs=[e_input_text, e_input_image, e_input_image_strength], priority=100)
image = result.as_numpy("generated_cnet_image")
image_result = Image.fromarray(np.squeeze(image.astype(np.uint8)))
else:
image_result = str(e)
logger.error(image_result)
if isinstance(image_result, Image.Image):
image_url = upload_SDXL_image(image_result, user_id=user_id, category=f"{category}", file_name=f"{tasks_id}-batch-{i}.png")
data['product_img'] = image_url
result_data_list.append(data)
else:
image_url = image_result
data['product_img'] = image_url
result_data_list.append(data)
# Publish each per-item result
if DEBUG:
logger.info(f" [x]Queue : {BATCH_GPI_RABBITMQ_QUEUES} | tasks_id{tasks_id} | progress{i + 1}/{batch_size} | result_data{data}")
print(f" [x]Queue : {BATCH_GPI_RABBITMQ_QUEUES} | tasks_id{tasks_id} | progress{i + 1}/{batch_size} | result_data{data}")
else:
publish_status(tasks_id, f"{i + 1}/{batch_size}", data)
logger.info(f" [x]Queue : {BATCH_GPI_RABBITMQ_QUEUES} | tasks_id{tasks_id} | progress{i + 1}/{batch_size} | result_data{data}")
# Task complete: publish the full result list
if DEBUG:
print(result_data_list)
logger.info(f" [x]Queue : {BATCH_GPI_RABBITMQ_QUEUES} | batch_tasks_id{batch_tasks_id} | progressOK | result_data_list{result_data_list}")
print(f" [x]Queue : {BATCH_GPI_RABBITMQ_QUEUES} | batch_tasks_id{batch_tasks_id} | progressOK | result_data_list{result_data_list}")
else:
publish_status(batch_tasks_id, f"OK", result_data_list)
logger.info(f" [x]Queue : {BATCH_GPI_RABBITMQ_QUEUES} | batch_tasks_id{batch_tasks_id} | progressOK | result_data_list{result_data_list}")
def pre_processing_image(image_url):
image = oss_get_image(bucket=image_url.split('/')[0], object_name=image_url[image_url.find('/') + 1:], data_type="PIL")
# Target image dimensions
target_width = 512
target_height = 768
# Original image dimensions
original_width, original_height = image.size
# Compute the width and height scaling ratios
width_ratio = target_width / original_width
height_ratio = target_height / original_height
# Use the smaller ratio so the whole image fits inside the target
scale_ratio = min(width_ratio, height_ratio)
# Compute the resized dimensions
new_width = int(original_width * scale_ratio)
new_height = int(original_height * scale_ratio)
# Resize the image
resized_image = image.resize((new_width, new_height))
# Create a transparent 512x768 canvas
result_image = Image.new("RGBA", (target_width, target_height), (255, 255, 255, 0))
# Compute the paste position so the image is centered
x_offset = (target_width - new_width) // 2
y_offset = (target_height - new_height) // 2
# Paste the resized image onto the transparent canvas
if resized_image.mode == "RGBA":
result_image.paste(resized_image, (x_offset, y_offset), mask=resized_image.split()[3])
else:
result_image.paste(resized_image, (x_offset, y_offset))
image = np.array(result_image)
# image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA)
return image
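# Worked example for the letterboxing above: a 1024x1024 input gives
# scale_ratio = min(512/1024, 768/1024) = 0.5, so the image is resized to
# 512x512 and pasted at x_offset = 0, y_offset = (768 - 512) // 2 = 128,
# centered on the transparent 512x768 canvas.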
def post_processing_image(image, left, top):
resized_image = image.resize((int(image.width * (768 / image.height)), 768))
# Compute the crop coordinates
left = (resized_image.width - 512) // 2
upper = 0
right = left + 512
lower = 768
# Crop
cropped_image = resized_image.crop((left, upper, right, lower))
return cropped_image
def publish_status(task_id, progress, result):
connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
channel = connection.channel()
channel.queue_declare(queue=BATCH_GPI_RABBITMQ_QUEUES, durable=True)
message = {'task_id': task_id, 'progress': progress, "result": result}
channel.basic_publish(exchange='',
routing_key=BATCH_GPI_RABBITMQ_QUEUES,
body=json.dumps(message),
properties=pika.BasicProperties(
delivery_mode=2,
))
connection.close()
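# A minimal consumer sketch for the status messages published above; the queue
# name and connection params mirror the producer's config imports, and the
# callback body is illustrative.
def consume_status():
    connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
    channel = connection.channel()
    channel.queue_declare(queue=BATCH_GPI_RABBITMQ_QUEUES, durable=True)

    def on_message(ch, method, properties, body):
        msg = json.loads(body)  # {'task_id': ..., 'progress': 'i/N' or 'OK', 'result': ...}
        print(msg['task_id'], msg['progress'])
        ch.basic_ack(delivery_tag=method.delivery_tag)

    channel.basic_consume(queue=BATCH_GPI_RABBITMQ_QUEUES, on_message_callback=on_message)
    channel.start_consuming()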
if __name__ == '__main__':
# rd = BatchGenerateProductImageModel(
# tasks_id="123-15-51-89",
# image_strength=0.7,
# prompt=" The best quality, masterpiece, real image.Outwear,high quality clothing details,8K realistic,HDR",
# image_url="aida-results/result_40b1a2fe-e220-11ef-9bfa-0242ac150003.png",
# product_type="overall",
# batch_size=20
# )
# batch_generate_product(rd.dict())
# rd = {
# "user_id": "89",
# "batch_data_list": [
# {
# "tasks_id": "A-123-15-51-89",
# "image_strength": 0.7,
# "prompt": " The best quality, ma123sterpiece, real image.Outwear,high quality clothing details,8K realistic,HDR",
# "image_url": "aida-results/result_40b1a2fe-e220-11ef-9bfa-0242ac150003.png",
# "product_type": "overall",
# },
# {
# "tasks_id": "B-123-15-51-89",
# "image_strength": 0.7,
# "prompt": " The best quality, masterpiece, real image.Outwear123,high quality clothing details,8K realistic,HDR",
# "image_url": "aida-results/result_40b1a2fe-e220-11ef-9bfa-0242ac150003.png",
# "product_type": "overall",
# }
# ]
# }
rd = BatchGenerateProductImageModel(
batch_tasks_id="abcd",
user_id="89",
batch_data_list=[
ProductItemModel(
tasks_id="123-5464",
image_strength=0.7,
product_type="overall",
image_url="aida-results/result_40b1a2fe-e220-11ef-9bfa-0242ac150003.png",
prompt="123"
),
ProductItemModel(
tasks_id="123-5464123",
image_strength=0.7,
product_type="overall",
image_url="aida-results/result_40b1a2fe-e220-11ef-9bfa-0242ac150003.png",
prompt="123"
)
]
)
batch_generate_product(rd.dict())

View File

@@ -0,0 +1,250 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project :trinity_client
@File :service_att_recognition.py
@Author :周成融
@Date :2023/7/26 12:01:05
@detail :
"""
import json
import logging
import pika
import cv2
import numpy as np
import tritonclient.grpc as grpcclient
from PIL import Image
from celery import Celery
from tritonclient.utils import np_to_triton_dtype
from app.core.config import *
from app.schemas.generate_image import BatchGenerateRelightImageModel, RelightItemModel
from app.service.generate_image.utils.upload_sd_image import upload_SDXL_image
from app.service.utils.oss_client import oss_get_image
logger = logging.getLogger()
celery_app = Celery('relight_tasks', broker=f'amqp://rabbit:123456@18.167.251.121:5672//', backend='rpc://', BROKER_CONNECTION_RETRY_ON_STARTUP=True)
celery_app.conf.task_default_queue = 'queue_relight'
celery_app.conf.worker_log_format = '%(asctime)s %(filename)s [line:%(lineno)d] %(levelname)s %(message)s'
celery_app.conf.worker_hijack_root_logger = False
logging.getLogger('pika').setLevel(logging.WARNING)
grpc_client = grpcclient.InferenceServerClient(url=GRI_MODEL_URL)
category = "relight_image"
@celery_app.task
def batch_generate_relight(batch_request_data):
batch_size = len(batch_request_data['batch_data_list'])
logger.info(f"batch_generate_relight batch_request_data: {json.dumps(batch_request_data, indent=4)}")
batch_tasks_id = batch_request_data['batch_tasks_id']
user_id = batch_request_data['user_id']
result_data_list = []
negative_prompt = 'lowres, bad anatomy, bad hands, cropped, worst quality'
seed = "1"
for i, data in enumerate(batch_request_data['batch_data_list']):
direction = data['direction']
prompt = data['prompt']
product_type = data['product_type']
image_url = data['image_url']
image = pre_processing_image(image_url)
tasks_id = data['tasks_id']
prompts = [prompt] * 1
image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)  # pre_processing_image returns RGBA; drop alpha before resizing
image = cv2.resize(image, (512, 768))
images = [image.astype(np.uint8)] * 1
seeds = [seed] * 1
negative_prompts = [negative_prompt] * 1
directions = [direction] * 1
if product_type == 'single':
text_obj = np.array(prompts, dtype="object").reshape((-1, 1))
image_obj = np.array(images, dtype=np.uint8).reshape((-1, 768, 512, 3))
na_text_obj = np.array(negative_prompts, dtype="object").reshape((-1, 1))
seed_obj = np.array(seeds, dtype="object").reshape((-1, 1))
direction_obj = np.array(directions, dtype="object").reshape((-1, 1))
else:
text_obj = np.array(prompts, dtype="object").reshape((1))
image_obj = np.array(images, dtype=np.uint8).reshape((768, 512, 3))
na_text_obj = np.array(negative_prompts, dtype="object").reshape((1))
seed_obj = np.array(seeds, dtype="object").reshape((1))
direction_obj = np.array(directions, dtype="object").reshape((1))
input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype))
input_image = grpcclient.InferInput("input_image", image_obj.shape, "UINT8")
input_natext = grpcclient.InferInput("negative_prompt", na_text_obj.shape, np_to_triton_dtype(na_text_obj.dtype))
input_seed = grpcclient.InferInput("seed", seed_obj.shape, np_to_triton_dtype(seed_obj.dtype))
input_direction = grpcclient.InferInput("direction", direction_obj.shape, np_to_triton_dtype(direction_obj.dtype))
input_text.set_data_from_numpy(text_obj)
input_image.set_data_from_numpy(image_obj)
input_natext.set_data_from_numpy(na_text_obj)
input_seed.set_data_from_numpy(seed_obj)
input_direction.set_data_from_numpy(direction_obj)
inputs = [input_text, input_natext, input_image, input_seed, input_direction]
try:
if data['product_type'] == "single":
result = grpc_client.infer(model_name=GRI_MODEL_NAME_SINGLE, inputs=inputs, priority=100)
image = result.as_numpy("generated_relight_image")
else:
result = grpc_client.infer(model_name=GRI_MODEL_NAME_OVERALL, inputs=inputs, priority=100)
image = result.as_numpy("generated_inpaint_image")
image_result = Image.fromarray(np.squeeze(image.astype(np.uint8)))
except Exception as e:
print(e)
if 'mask_list' in str(e):
e_text_obj = np.array(prompts, dtype="object").reshape((-1, 1))
e_image_obj = np.array(images, dtype=np.uint8).reshape((-1, 768, 512, 3))
e_na_text_obj = np.array(negative_prompts, dtype="object").reshape((-1, 1))
e_seed_obj = np.array(seeds, dtype="object").reshape((-1, 1))
e_direction_obj = np.array(directions, dtype="object").reshape((-1, 1))
e_input_text = grpcclient.InferInput("prompt", e_text_obj.shape, np_to_triton_dtype(e_text_obj.dtype))
e_input_image = grpcclient.InferInput("input_image", e_image_obj.shape, "UINT8")
e_input_natext = grpcclient.InferInput("negative_prompt", e_na_text_obj.shape, np_to_triton_dtype(e_na_text_obj.dtype))
e_input_seed = grpcclient.InferInput("seed", e_seed_obj.shape, np_to_triton_dtype(e_seed_obj.dtype))
e_input_direction = grpcclient.InferInput("direction", e_direction_obj.shape, np_to_triton_dtype(e_direction_obj.dtype))
e_input_text.set_data_from_numpy(e_text_obj)
e_input_image.set_data_from_numpy(e_image_obj)
e_input_natext.set_data_from_numpy(e_na_text_obj)
e_input_seed.set_data_from_numpy(e_seed_obj)
e_input_direction.set_data_from_numpy(e_direction_obj)
e_inputs = [e_input_text, e_input_natext, e_input_image, e_input_seed, e_input_direction]
result = grpc_client.infer(model_name=GRI_MODEL_NAME_SINGLE, inputs=e_inputs, priority=100)
image = result.as_numpy("generated_relight_image")
image_result = Image.fromarray(np.squeeze(image.astype(np.uint8)))
else:
image_result = str(e)
logger.error(e)
if isinstance(image_result, Image.Image):
image_url = upload_SDXL_image(image_result, user_id=user_id, category=f"{category}", file_name=f"{tasks_id}-batch-{i}.png")
data['relight_img'] = image_url
result_data_list.append(data)
else:
image_url = image_result
data['relight_img'] = image_url
result_data_list.append(data)
# Publish each per-item result
if DEBUG:
logger.info(f" [x]Queue : {BATCH_GRI_RABBITMQ_QUEUES} | tasks_id{tasks_id} | progress{i + 1}/{batch_size} | result_data{data}")
print(f" [x]Queue : {BATCH_GRI_RABBITMQ_QUEUES} | tasks_id{tasks_id} | progress{i + 1}/{batch_size} | result_data{data}")
else:
publish_status(tasks_id, f"{i + 1}/{batch_size}", data)
logger.info(f" [x]Queue : {BATCH_GRI_RABBITMQ_QUEUES} | tasks_id{tasks_id} | progress{i + 1}/{batch_size} | result_data{data}")
# Task complete: publish the full result list
if DEBUG:
print(result_data_list)
logger.info(f" [x]Queue : {BATCH_GRI_RABBITMQ_QUEUES} | batch_tasks_id{batch_tasks_id} | progressOK | result_data_list{result_data_list}")
print(f" [x]Queue : {BATCH_GRI_RABBITMQ_QUEUES} | batch_tasks_id{batch_tasks_id} | progressOK | result_data_list{result_data_list}")
else:
publish_status(batch_tasks_id, f"OK", result_data_list)
logger.info(f" [x]Queue : {BATCH_GRI_RABBITMQ_QUEUES} | batch_tasks_id{batch_tasks_id} | progressOK | result_data_list{result_data_list}")
def pre_processing_image(image_url):
image = oss_get_image(bucket=image_url.split('/')[0], object_name=image_url[image_url.find('/') + 1:], data_type="PIL")
# Target image dimensions
target_width = 512
target_height = 768
# Original image dimensions
original_width, original_height = image.size
# Compute the width and height scaling ratios
width_ratio = target_width / original_width
height_ratio = target_height / original_height
# Use the smaller ratio so the whole image fits inside the target
scale_ratio = min(width_ratio, height_ratio)
# Compute the resized dimensions
new_width = int(original_width * scale_ratio)
new_height = int(original_height * scale_ratio)
# Resize the image
resized_image = image.resize((new_width, new_height))
# Create a transparent 512x768 canvas
result_image = Image.new("RGBA", (target_width, target_height), (255, 255, 255, 0))
# Compute the paste position so the image is centered
x_offset = (target_width - new_width) // 2
y_offset = (target_height - new_height) // 2
# Paste the resized image onto the transparent canvas
if resized_image.mode == "RGBA":
result_image.paste(resized_image, (x_offset, y_offset), mask=resized_image.split()[3])
else:
result_image.paste(resized_image, (x_offset, y_offset))
image = np.array(result_image)
# image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA)
return image
def publish_status(task_id, progress, result):
connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
channel = connection.channel()
channel.queue_declare(queue=BATCH_GRI_RABBITMQ_QUEUES, durable=True)
message = {'task_id': task_id, 'progress': progress, "result": result}
channel.basic_publish(exchange='',
routing_key=BATCH_GRI_RABBITMQ_QUEUES,
body=json.dumps(message),
properties=pika.BasicProperties(
delivery_mode=2,
))
connection.close()
if __name__ == '__main__':
rd = BatchGenerateRelightImageModel(
batch_tasks_id="abcd",
user_id="89",
batch_data_list=[
RelightItemModel(
tasks_id="123-5464",
product_type="overall",
image_url="test/703190759.png",
prompt="Colorful black",
direction="Right Light",
),
RelightItemModel(
tasks_id="123-5464123",
product_type="overall",
image_url="test/703190759.png",
direction="Right Light",
prompt="Colorful black",
)
]
)
batch_generate_relight(rd.dict())
# X = {
# "batch_tasks_id": "abcd",
# "user_id": "89",
# "batch_data_list": [
# {
# "tasks_id": "123-5464",
# "product_type": "overall",
# "image_url": "aida-users/89/product_image/02894523-19b5-46eb-a9c6-2f512f5fec84-0-89.png",
# "prompt": "Colorful black",
# "direction": "Right Light",
# },
# {
# "tasks_id": "123-5464",
# "product_type": "overall",
# "image_url": "aida-users/89/product_image/02894523-19b5-46eb-a9c6-2f512f5fec84-0-89.png",
# "prompt": "Colorful black",
# "direction": "Right Light",
# }
#
# ]
# }

View File

@@ -0,0 +1,176 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project :trinity_client
@File :service_att_recognition.py
@Author :周成融
@Date :2023/7/26 12:01:05
@detail :
"""
import io
import json
import logging
import pika
from io import BytesIO
import imageio
import numpy as np
import tritonclient.grpc as grpcclient
from PIL import Image
from celery import Celery
from minio import Minio
from tritonclient.utils import np_to_triton_dtype
from app.core.config import *
from app.schemas.pose_transform import BatchPoseTransformModel
from app.service.generate_image.utils.pose_transform_upload import upload_gif, upload_video
from app.service.utils.new_oss_client import oss_upload_image
from app.service.utils.oss_client import oss_get_image
minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
logger = logging.getLogger()
celery_app = Celery('post_transform_tasks', broker=f'amqp://rabbit:123456@18.167.251.121:5672//', backend='rpc://', BROKER_CONNECTION_RETRY_ON_STARTUP=True)
celery_app.conf.task_default_queue = 'queue_post_transform'
celery_app.conf.worker_log_format = '%(asctime)s %(filename)s [line:%(lineno)d] %(levelname)s %(message)s'
celery_app.conf.worker_hijack_root_logger = False
logging.getLogger('pika').setLevel(logging.WARNING)
grpc_client = grpcclient.InferenceServerClient(url=PT_MODEL_URL)
category = "pose_transform"
def upload_first_image(image, user_id, category, file_name):
try:
image_data = io.BytesIO()
image.save(image_data, format='PNG')
image_data.seek(0)
image_bytes = image_data.read()
object_name = f'{user_id}/{category}/{file_name}'
req = oss_upload_image(oss_client=minio_client, bucket=GI_MINIO_BUCKET, object_name=object_name, image_bytes=image_bytes)
image_url = f"aida-users/{object_name}"
return image_url
except Exception as e:
logging.warning(f"upload_png_mask runtime exception : {e}")
def pre_processing_image(image_url):
image = oss_get_image(bucket=image_url.split('/')[0], object_name=image_url[image_url.find('/') + 1:], data_type="PIL")
# Target image dimensions
target_width = 512
target_height = 768
# Original image dimensions
original_width, original_height = image.size
# Compute the width and height scaling ratios
width_ratio = target_width / original_width
height_ratio = target_height / original_height
# Use the smaller ratio so the whole image fits inside the target
scale_ratio = min(width_ratio, height_ratio)
# Compute the resized dimensions
new_width = int(original_width * scale_ratio)
new_height = int(original_height * scale_ratio)
# Resize the image
resized_image = image.resize((new_width, new_height))
# Create a transparent 512x768 canvas
result_image = Image.new("RGBA", (target_width, target_height), (255, 255, 255, 0))
# Compute the paste position so the image is centered
x_offset = (target_width - new_width) // 2
y_offset = (target_height - new_height) // 2
# Paste the resized image onto the transparent canvas
if resized_image.mode == "RGBA":
result_image.paste(resized_image, (x_offset, y_offset), mask=resized_image.split()[3])
else:
result_image.paste(resized_image, (x_offset, y_offset))
result_image = result_image.convert("RGB")
image = np.array(result_image)
# image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA)
return image
@celery_app.task
def batch_generate_pose_transform(batch_request_data):
logger.info(f"batch_generate_pose_transform batch_request_data: {json.dumps(batch_request_data, indent=4)}")
batch_size = batch_request_data['batch_size']
image_url = batch_request_data['image_url']
image = pre_processing_image(image_url)
pose_num = batch_request_data['pose_id']
tasks_id = batch_request_data['tasks_id']
user_id = tasks_id.rsplit('-', 1)[1]
pose_num = [pose_num] * 1
pose_num_obj = np.array(pose_num, dtype="object").reshape((-1, 1))
input_pose_num = grpcclient.InferInput("pose_num", pose_num_obj.shape, np_to_triton_dtype(pose_num_obj.dtype))
input_pose_num.set_data_from_numpy(pose_num_obj)
image_files = [image.astype(np.uint8)] * 1
image_files_obj = np.array(image_files, dtype=np.uint8).reshape((-1, 768, 512, 3))
input_image_files = grpcclient.InferInput("image_file", image_files_obj.shape, "UINT8")
input_image_files.set_data_from_numpy(image_files_obj)
result_url_list = []
for i in range(batch_size):
try:
result = grpc_client.infer(model_name="animatex_1", inputs=[input_pose_num, input_image_files], client_timeout=60000, priority=100)
result_data = np.squeeze(result.as_numpy("generated_image_list").astype(np.uint8))[:, :, :, ::-1]
# First frame of the sequence
first_image = Image.fromarray(result_data[0])
first_image_url = upload_first_image(first_image, user_id=user_id, category=f"{category}_first_img", file_name=f"{tasks_id}_batch_{i}.png")
# Upload the GIF
gif_buffer = BytesIO()
imageio.mimsave(gif_buffer, result_data, format='GIF', fps=5)
gif_buffer.seek(0)
gif_url = upload_gif(gif_buffer=gif_buffer, user_id=user_id, category=f"{category}_gif", file_name=f"{tasks_id}_batch_{i}.gif")
# Upload the video
video_url = upload_video(frames=result_data, user_id=user_id, category=f"{category}_video", file_name=f"{tasks_id}_batch_{i}.mp4")
data = {
"gif_url": gif_url,
"video_url": video_url,
"first_image_url": first_image_url,
}
except Exception as e:
print(e)
data = {}
result_url_list.append(data)
if DEBUG is False:
if i + 1 < batch_size:
publish_status(tasks_id, f"{i + 1}/{batch_size}", data)
logger.info(f" [x]Queue : {BATCH_PS_RABBITMQ_QUEUES} | tasks_id{tasks_id} | progress{i + 1}/{batch_size} | image_url{image_url}")
# print(f" [x]Queue : {BATCH_GRI_RABBITMQ_QUEUES} | tasks_id{tasks_id} | progress{i + 1}/{batch_size} | image_url{image_url}")
else:
publish_status(tasks_id, f"OK", result_url_list)
logger.info(f" [x]Queue : {BATCH_PS_RABBITMQ_QUEUES} | tasks_id{tasks_id} | progressOK | image_url{image_url}")
# print(f" [x]Queue : {BATCH_PS_RABBITMQ_QUEUES} | tasks_id{tasks_id} | progressOK | image_url{image_url}")
def publish_status(task_id, progress, result):
connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
channel = connection.channel()
channel.queue_declare(queue=BATCH_PS_RABBITMQ_QUEUES, durable=True)
message = {'task_id': task_id, 'progress': progress, "result": result}
channel.basic_publish(exchange='',
routing_key=BATCH_PS_RABBITMQ_QUEUES,
body=json.dumps(message),
properties=pika.BasicProperties(
delivery_mode=2,
))
connection.close()
if __name__ == '__main__':
rd = BatchPoseTransformModel(
tasks_id="123-89",
image_url='aida-results/result_0000b606-1902-11ef-9424-0242ac180002.png',
pose_id="1",
batch_size=10
)
batch_generate_pose_transform(rd.dict())

View File

@@ -0,0 +1,28 @@
# import logging
#
# from celery import Celery
#
# from app.service.generate_batch_image.service_batch_generate_product_image import batch_generate_product
# from app.service.generate_batch_image.service_batch_generate_relight_image import batch_generate_relight
# from app.service.generate_batch_image.service_batch_pose_transform import batch_generate_pose_transform
#
# logger = logging.getLogger()
# celery_app = Celery('tasks', broker=f'amqp://rabbit:123456@18.167.251.121:5672//', backend='rpc://', BROKER_CONNECTION_RETRY_ON_STARTUP=True)
# celery_app.conf.worker_log_format = '%(asctime)s %(filename)s [line:%(lineno)d] %(levelname)s %(message)s'
# celery_app.conf.worker_hijack_root_logger = False
# logging.getLogger('pika').setLevel(logging.WARNING)
#
#
# @celery_app.task
# def batch_pose_transform_tasks(batch_request_data):
# batch_generate_pose_transform(batch_request_data)
#
#
# @celery_app.task
# def batch_generate_relight_tasks(batch_request_data):
# batch_generate_relight(batch_request_data)
#
#
# @celery_app.task
# def batch_generate_product_tasks(batch_request_data):
# batch_generate_product(batch_request_data)

View File

@@ -0,0 +1,36 @@
from app.schemas.generate_image import BatchGenerateRelightImageModel, BatchGenerateProductImageModel
from app.service.generate_batch_image.service_batch_generate_product_image import batch_generate_product
from app.service.generate_batch_image.service_batch_generate_relight_image import batch_generate_relight
if __name__ == '__main__':
rd = BatchGenerateProductImageModel(
tasks_id="test1-89",
image_strength=0.7,
prompt=" The best quality, masterpiece, real image.Outwear,high quality clothing details,8K realistic,HDR",
image_url="aida-results/result_40b1a2fe-e220-11ef-9bfa-0242ac150003.png",
product_type="single",
batch_size=2
)
x = batch_generate_product.delay(rd.dict())
print(x)
"""relight"""
# rd = BatchGenerateRelightImageModel(
# tasks_id="123-89",
# # prompt="beautiful woman, detailed face, sunshine, outdoor, warm atmosphere",
# prompt="Colorful black",
# image_url='aida-results/result_0000b606-1902-11ef-9424-0242ac180002.png',
# direction="Right Light",
# product_type="single",
# batch_size=2
# )
# batch_generate_relight.delay(rd.dict())
"""pose transform"""
# rd = BatchPoseTransformModel(
# tasks_id="123-89",
# image_url='aida-results/result_0000b606-1902-11ef-9424-0242ac180002.png',
# pose_id="1",
# batch_size=10
# )
# batch_pose_transform_tasks.delay(rd.dict())

View File

@@ -0,0 +1,149 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project :trinity_client
@File :service_att_recognition.py
@Author :周成融
@Date :2023/7/26 12:01:05
@detail :
"""
import logging
import time
import uuid
import cv2
import mmcv
import numpy as np
import pandas as pd
import torch
import tritonclient.http as httpclient
import tritonclient.grpc as grpcclient
from minio import Minio
from tritonclient.utils import np_to_triton_dtype
from app.core.config import *
from app.service.utils.new_oss_client import oss_upload_image
logger = logging.getLogger()
minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
class AgentToolGenerateImage:
def __init__(self, version):
if version == "fast":
self.grpc_client = grpcclient.InferenceServerClient(url=FAST_GI_MODEL_URL)
else:
self.grpc_client = grpcclient.InferenceServerClient(url=GI_MODEL_URL)
self.image = np.random.randint(0, 256, (1024, 1024, 3), dtype=np.uint8)
self.triton_client = httpclient.InferenceServerClient(url=DESIGN_MODEL_URL)
def get_result(self, prompt, size, version, category, gender):
image_url_list = []
image_result_list = []
clothing_category_list = []
try:
prompts = [prompt] * 1
modes = ["txt2img"] * 1
images = [self.image.astype(np.float16)] * 1
text_obj = np.array(prompts, dtype="object").reshape((-1, 1))
mode_obj = np.array(modes, dtype="object").reshape((-1, 1))
image_obj = np.array(images, dtype=np.float16).reshape((-1, 1024, 1024, 3))
input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype))
input_image = grpcclient.InferInput("input_image", image_obj.shape, np_to_triton_dtype(image_obj.dtype))
input_mode = grpcclient.InferInput("mode", mode_obj.shape, np_to_triton_dtype(mode_obj.dtype))
input_text.set_data_from_numpy(text_obj)
input_image.set_data_from_numpy(image_obj)
input_mode.set_data_from_numpy(mode_obj)
inputs = [input_text, input_image, input_mode]
for i in range(size):
if version == "fast":
response = self.grpc_client.infer(model_name=FAST_GI_MODEL_NAME, inputs=inputs, priority=0)
else:
response = self.grpc_client.infer(model_name=GI_MODEL_NAME, inputs=inputs, priority=0)
image = response.as_numpy("generated_image")
image_result = cv2.cvtColor(np.squeeze(image.astype(np.uint8)), cv2.COLOR_RGB2BGR)
_, img_byte_array = cv2.imencode('.jpg', image_result)
req = oss_upload_image(oss_client=minio_client, bucket='test', object_name=f'{uuid.uuid1()}-{i}.jpg', image_bytes=img_byte_array)
image_url_list.append(f"{req.bucket_name}/{req.object_name}")
image_result_list.append(image_result)
if category == "sketch":
clothing_category_list = self.get_clothing_category(image_result_list, gender)
return image_url_list, clothing_category_list
except Exception as e:
logger.error(e)
return image_url_list, clothing_category_list
finally:
self.grpc_client.close()
self.triton_client.close()
def preprocess(self, img):
img = mmcv.imread(img)
img_scale = (224, 224)
img = cv2.resize(img, img_scale)
img = mmcv.imnormalize(
img,
mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]),
to_rgb=True)
preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0)
return preprocessed_img
def get_category(self, image):
inputs = [httpclient.InferInput("input__0", image.shape, datatype="FP32")]
inputs[0].set_data_from_numpy(image, binary_data=True)
results = self.triton_client.infer(model_name="attr_retrieve_category", inputs=inputs)
inference_output = torch.from_numpy(results.as_numpy('output__0'))
scores = inference_output.detach().numpy()
colattr = list(attr_type['labelName'])
maxsc = np.max(scores[0][:5])
indexs = np.argwhere(scores == maxsc)[:, 1]
return colattr[indexs[0]]
def get_clothing_category(self, images, gender):
category_list = []
for image in images:
sketch = self.preprocess(image)
if gender.lower() == "female":
category_list.append(self.get_category(sketch))
elif gender.lower() == "male":
category = self.get_category(sketch)
if category == 'Trousers' or category == 'Skirt':
category_list.append('Bottoms')
elif category == 'Blouse' or category == 'Dress':
category_list.append('Tops')
else:
category_list.append('Outwear')
else:
category_list.append(self.get_category(sketch))
return category_list
attr_type = pd.read_csv(CATEGORY_PATH)
if __name__ == '__main__':
request_data = {
"prompt": "a single item of sketch of Wabi-sabi, skirt, tiered, 4k, white background",
"category": "sketch",
"version": "high",
"size": 2,
"gender": "Female",
}
server = AgentToolGenerateImage(request_data['version'])
image_url_list, clothing_category_list = server.get_result(
prompt=request_data['prompt'],
size=request_data['size'],
version=request_data['version'],
category=request_data['category'],
gender=request_data['gender']
)
print(image_url_list)
print(clothing_category_list)

View File

@@ -21,6 +21,7 @@ from tritonclient.utils import np_to_triton_dtype
from app.core.config import *
from app.schemas.generate_image import GenerateImageModel
from app.service.generate_image.utils.image_processing import remove_background, stain_detection, generate_category_recognition, autoLevels, luminance_adjust
from app.service.generate_image.utils.mq import publish_status
from app.service.generate_image.utils.upload_sd_image import upload_png_sd
from app.service.utils.oss_client import oss_get_image
@@ -29,12 +30,6 @@ logger = logging.getLogger()
class GenerateImage:
def __init__(self, request_data):
if DEBUG is False:
self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
self.channel = self.connection.channel()
# self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
# self.channel = self.connection.channel()
# self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
self.version = request_data.version
if request_data.version == "fast":
self.grpc_client = grpcclient.InferenceServerClient(url=FAST_GI_MODEL_URL)
@@ -153,15 +148,14 @@ class GenerateImage:
inputs = [input_text, input_image, input_mode]
if self.version == "fast":
ctx = self.grpc_client.async_infer(model_name=FAST_GI_MODEL_NAME, inputs=inputs, callback=self.callback)
ctx = self.grpc_client.async_infer(model_name=FAST_GI_MODEL_NAME, inputs=inputs, callback=self.callback, priority=1)
else:
ctx = self.grpc_client.async_infer(model_name=GI_MODEL_NAME, inputs=inputs, callback=self.callback)
ctx = self.grpc_client.async_infer(model_name=GI_MODEL_NAME, inputs=inputs, callback=self.callback, priority=1)
time_out = 600
generate_data = None
while time_out > 0:
generate_data, _ = self.read_tasks_status()
# logger.info(generate_data)
if generate_data['status'] in ["REVOKED", "FAILURE"]:
ctx.cancel()
break
@@ -169,7 +163,6 @@ class GenerateImage:
break
time_out -= 1
time.sleep(0.1)
# logger.info(time_out, generate_data)
return generate_data
except Exception as e:
self.generate_data['status'] = "FAILURE"
@@ -178,10 +171,8 @@ class GenerateImage:
raise Exception(str(e))
finally:
dict_generate_data, str_generate_data = self.read_tasks_status()
if DEBUG is False:
self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=str_generate_data)
# self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=str_generate_data)
logger.info(f" [x] Sent {json.dumps(dict_generate_data, indent=4)}")
if not DEBUG:
publish_status(str_generate_data, GI_RABBITMQ_QUEUES)
def infer_cancel(tasks_id):
@@ -195,7 +186,7 @@ def infer_cancel(tasks_id):
if __name__ == '__main__':
rd = GenerateImageModel(
tasks_id="123-89",
prompt='a single item of sketch of Wabi-sabi, skirt, tiered, 4k, white background',
prompt="Women's clothing ,dress,technical drawing style, clean line art, no shading, no texture, flat sketch, no human body, no face, centered composition, pure white background, single garmentsingle garment only, front flat view",
image_url="aida-collection-element/87/Printboard/842c09cf-7297-42d9-9e6e-9c17d4a13cb5.jpg",
mode='txt2img',
category="test",

View File

@@ -17,6 +17,7 @@ import tritonclient.grpc as grpcclient
from app.core.config import *
from app.schemas.generate_image import GenerateMultiViewModel
from app.service.generate_image.utils.mq import publish_status
from app.service.generate_image.utils.upload_sd_image import upload_png_sd
from app.service.utils.oss_client import oss_get_image
@@ -25,14 +26,7 @@ logger = logging.getLogger()
class GenerateMultiView:
def __init__(self, request_data):
if DEBUG is False:
self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
self.channel = self.connection.channel()
# self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
# self.channel = self.connection.channel()
# self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
self.grpc_client = grpcclient.InferenceServerClient(url=GMV_MODEL_URL)
self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
self.image = self.get_image(request_data.image_url)
self.tasks_id = request_data.tasks_id
@@ -52,16 +46,11 @@ class GenerateMultiView:
if error:
self.generate_data['status'] = "FAILURE"
self.generate_data['message'] = str(error)
# self.generate_data['data'] = str(error)
self.redis_client.set(self.tasks_id, json.dumps(self.generate_data))
else:
# Convert the PIL image into a numpy array
images = result.as_numpy("generated_image")
# for id, img in enumerate(images):
# cv2.imwrite(f"{id}.png", img)
# image_url = ""
image_url = upload_png_sd(images[6], user_id=self.user_id, category="multi_view", file_name=f"{self.tasks_id}.png")
# logger.info(f"upload image SUCCESS {image_url}")
self.generate_data['status'] = "SUCCESS"
self.generate_data['message'] = "success"
self.generate_data['image_url'] = str(image_url)
@@ -103,10 +92,8 @@ class GenerateMultiView:
raise Exception(str(e))
finally:
dict_generate_data, str_generate_data = self.read_tasks_status()
if DEBUG is False:
self.channel.basic_publish(exchange='', routing_key=GMV_RABBITMQ_QUEUES, body=str_generate_data)
# self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=str_generate_data)
logger.info(f" [x] Sent {json.dumps(dict_generate_data, indent=4)}")
if not DEBUG:
publish_status(str_generate_data, GMV_RABBITMQ_QUEUES)
def infer_cancel(tasks_id):

View File

@@ -212,6 +212,7 @@ from tritonclient.utils import np_to_triton_dtype
from app.core.config import *
from app.schemas.generate_image import GenerateProductImageModel
from app.service.generate_image.utils.mq import publish_status
from app.service.generate_image.utils.upload_sd_image import upload_SDXL_image
from app.service.utils.oss_client import oss_get_image
@@ -220,12 +221,6 @@ logger = logging.getLogger()
class GenerateProductImage:
def __init__(self, request_data):
if DEBUG is False:
self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
self.channel = self.connection.channel()
# self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
# self.channel = self.connection.channel()
# self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
self.grpc_client = grpcclient.InferenceServerClient(url=GPI_MODEL_URL)
self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
self.category = "product_image"
@@ -295,9 +290,9 @@ class GenerateProductImage:
inputs = [input_text, input_image, input_image_strength]
if self.product_type == "single":
ctx = self.grpc_client.async_infer(model_name=GPI_MODEL_NAME_SINGLE, inputs=inputs, callback=self.callback)
ctx = self.grpc_client.async_infer(model_name=GPI_MODEL_NAME_SINGLE, inputs=inputs, callback=self.callback, priority=1)
else:
ctx = self.grpc_client.async_infer(model_name=GPI_MODEL_NAME_OVERALL, inputs=inputs, callback=self.callback)
ctx = self.grpc_client.async_infer(model_name=GPI_MODEL_NAME_OVERALL, inputs=inputs, callback=self.callback, priority=1)
time_out = 600
while time_out > 0:
@@ -318,9 +313,8 @@ class GenerateProductImage:
raise Exception(str(e))
finally:
dict_gen_product_data, str_gen_product_data = self.read_tasks_status()
if DEBUG is False:
self.channel.basic_publish(exchange='', routing_key=GPI_RABBITMQ_QUEUES, body=str_gen_product_data)
logger.info(f" [x] Sent to {GPI_RABBITMQ_QUEUES} data@@@@ {json.dumps(dict_gen_product_data, indent=4)}")
if not DEBUG:
publish_status(str_gen_product_data, GPI_RABBITMQ_QUEUES)
def infer_cancel(tasks_id):

View File

@@ -20,6 +20,7 @@ from tritonclient.utils import np_to_triton_dtype
from app.core.config import *
from app.schemas.generate_image import GenerateRelightImageModel
from app.service.generate_image.utils.mq import publish_status
from app.service.generate_image.utils.upload_sd_image import upload_SDXL_image
from app.service.utils.oss_client import oss_get_image
@@ -28,10 +29,6 @@ logger = logging.getLogger()
class GenerateRelightImage:
def __init__(self, request_data):
if DEBUG is False:
self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
self.channel = self.connection.channel()
# self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
self.grpc_client = grpcclient.InferenceServerClient(url=GRI_MODEL_URL)
self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
self.category = "relight_image"
@@ -42,7 +39,7 @@ class GenerateRelightImage:
self.negative_prompt = 'lowres, bad anatomy, bad hands, cropped, worst quality'
self.direction = request_data.direction
self.image_url = request_data.image_url
self.image = oss_get_image(bucket=self.image_url.split('/')[0], object_name=self.image_url[self.image_url.find('/') + 1:], data_type="cv2")
self.image = pre_processing_image(self.image_url)
self.tasks_id = request_data.tasks_id
self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:]
self.gen_product_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'image_url': ''}
@@ -114,9 +111,9 @@ class GenerateRelightImage:
inputs = [input_text, input_natext, input_image, input_seed, input_direction]
if self.product_type == 'single':
ctx = self.grpc_client.async_infer(model_name=GRI_MODEL_NAME_SINGLE, inputs=inputs, callback=self.callback)
ctx = self.grpc_client.async_infer(model_name=GRI_MODEL_NAME_SINGLE, inputs=inputs, callback=self.callback, priority=1)
else:
ctx = self.grpc_client.async_infer(model_name=GRI_MODEL_NAME_OVERALL, inputs=inputs, callback=self.callback)
ctx = self.grpc_client.async_infer(model_name=GRI_MODEL_NAME_OVERALL, inputs=inputs, callback=self.callback, priority=1)
time_out = 600
while time_out > 0:
@@ -137,10 +134,49 @@ class GenerateRelightImage:
raise Exception(str(e))
finally:
dict_gen_product_data, str_gen_product_data = self.read_tasks_status()
if DEBUG is False:
self.channel.basic_publish(exchange='', routing_key=GRI_RABBITMQ_QUEUES, body=str_gen_product_data)
logger.info(f" [x] Sent to {GRI_RABBITMQ_QUEUES} data@@@@ {json.dumps(dict_gen_product_data, indent=4)}")
if not DEBUG:
publish_status(str_gen_product_data, GRI_RABBITMQ_QUEUES)
def pre_processing_image(image_url):
image = oss_get_image(bucket=image_url.split('/')[0], object_name=image_url[image_url.find('/') + 1:], data_type="PIL")
# Target image size
target_width = 512
target_height = 768
# Original image size
original_width, original_height = image.size
# Compute the width and height scale ratios
width_ratio = target_width / original_width
height_ratio = target_height / original_height
# Use the smaller ratio so the whole image fits inside the target canvas
scale_ratio = min(width_ratio, height_ratio)
# Compute the resized dimensions
new_width = int(original_width * scale_ratio)
new_height = int(original_height * scale_ratio)
# Resize the image
resized_image = image.resize((new_width, new_height))
# Create a 512x768 transparent canvas
result_image = Image.new("RGBA", (target_width, target_height), (255, 255, 255, 0))
# Compute the paste offsets so the image is centered
x_offset = (target_width - new_width) // 2
y_offset = (target_height - new_height) // 2
# Paste the resized image onto the transparent canvas
if resized_image.mode == "RGBA":
result_image.paste(resized_image, (x_offset, y_offset), mask=resized_image.split()[3])
else:
result_image.paste(resized_image, (x_offset, y_offset))
image = np.array(result_image)
# image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA)
return image
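For reference, a minimal sketch of the same fit-and-center (letterbox) logic on a local PIL image; pre_processing_image itself reads from OSS, and letterbox_demo is a hypothetical name used only for illustration:
from PIL import Image
import numpy as np

def letterbox_demo(image, target_width=512, target_height=768):
    # Same fit-and-center logic as pre_processing_image, applied to a local PIL image
    scale = min(target_width / image.width, target_height / image.height)
    resized = image.resize((int(image.width * scale), int(image.height * scale)))
    canvas = Image.new("RGBA", (target_width, target_height), (255, 255, 255, 0))
    canvas.paste(resized, ((target_width - resized.width) // 2, (target_height - resized.height) // 2))
    return np.array(canvas)

print(letterbox_demo(Image.new("RGB", (1024, 1024))).shape)  # (768, 512, 4)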
def infer_cancel(tasks_id):
redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
@@ -157,7 +193,7 @@ if __name__ == '__main__':
prompt="Colorful black",
image_url='aida-results/result_0000b606-1902-11ef-9424-0242ac180002.png',
direction="Right Light",
product_type="single"
product_type="overall"
)
server = GenerateRelightImage(rd)
print(server.get_result())

View File

@@ -21,6 +21,7 @@ from tritonclient.utils import np_to_triton_dtype
from app.core.config import *
import tritonclient.grpc as grpcclient
from app.schemas.generate_image import GenerateSingleLogoImageModel
from app.service.generate_image.utils.mq import publish_status
from app.service.generate_image.utils.upload_sd_image import upload_png_sd, upload_SDXL_image
logger = logging.getLogger()
@@ -28,10 +29,6 @@ logger = logging.getLogger()
class GenerateSingleLogoImage:
def __init__(self, request_data):
if DEBUG is False:
self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
self.channel = self.connection.channel()
self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
self.grpc_client = grpcclient.InferenceServerClient(url=GSL_MODEL_URL)
self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
self.batch_size = 1
@@ -96,9 +93,8 @@ class GenerateSingleLogoImage:
raise Exception(str(e))
finally:
dict_generate_data, str_generate_data = self.read_tasks_status()
if DEBUG is False:
self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=str_generate_data)
logger.info(f" [x] Sent {json.dumps(dict_generate_data, indent=4)}")
if not DEBUG:
publish_status(str_generate_data, GI_RABBITMQ_QUEUES)
def infer_cancel(tasks_id):

View File

@@ -0,0 +1,185 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project trinity_client
@File service_pose_transform.py
@Author :周成融
@Date 2023/7/26 12:01:05
@detail
"""
import json
import logging
import time
from io import BytesIO
import imageio
import numpy as np
import redis
import tritonclient.grpc as grpcclient
from PIL import Image
from tritonclient.utils import np_to_triton_dtype
from app.core.config import *
from app.schemas.pose_transform import PoseTransformModel
from app.service.generate_image.utils.mq import publish_status
from app.service.generate_image.utils.pose_transform_upload import upload_gif, upload_video, upload_first_image
from app.service.utils.oss_client import oss_get_image
logger = logging.getLogger()
class PoseTransformService:
def __init__(self, request_data):
self.grpc_client = grpcclient.InferenceServerClient(url=PT_MODEL_URL)
self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
self.category = "pose_transform"
self.image_url = request_data.image_url
self.pose_num = request_data.pose_id
self.image = pre_processing_image(request_data.image_url)
self.tasks_id = request_data.tasks_id
self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:]
self.pose_transform_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'gif_url': '',
'video_url': '', 'image_url': ''}
self.redis_client.set(self.tasks_id, json.dumps(self.pose_transform_data))
self.redis_client.expire(self.tasks_id, 600)
def callback(self, result, error):
if error:
self.pose_transform_data['status'] = "FAILURE"
self.pose_transform_data['message'] = str(error)
self.redis_client.set(self.tasks_id, json.dumps(self.pose_transform_data))
else:
result_data = np.squeeze(result.as_numpy("generated_image_list").astype(np.uint8))[:, :, :, ::-1]
# First frame of the sequence
first_image = Image.fromarray(result_data[0])
first_image_url = upload_first_image(first_image, user_id=self.user_id,
category=f"{self.category}_first_img",
file_name=f"{self.tasks_id}.png")
# Upload the GIF
gif_buffer = BytesIO()
imageio.mimsave(gif_buffer, result_data, format='GIF', fps=5)
gif_buffer.seek(0)
gif_url = upload_gif(gif_buffer=gif_buffer, user_id=self.user_id, category=f"{self.category}_gif",
file_name=f"{self.tasks_id}.gif")
# Upload the video
video_url = upload_video(frames=result_data, user_id=self.user_id, category=f"{self.category}_video",
file_name=f"{self.tasks_id}.mp4")
self.pose_transform_data['status'] = "SUCCESS"
self.pose_transform_data['message'] = "success"
self.pose_transform_data['gif_url'] = str(gif_url)
self.pose_transform_data['video_url'] = str(video_url)
self.pose_transform_data['image_url'] = str(first_image_url)
self.redis_client.set(self.tasks_id, json.dumps(self.pose_transform_data))
def read_tasks_status(self):
status_data = self.redis_client.get(self.tasks_id)
return json.loads(status_data), status_data
def get_result(self):
try:
pose_num = [self.pose_num] * 1
pose_num_obj = np.array(pose_num, dtype="object").reshape((-1, 1))
input_pose_num = grpcclient.InferInput("pose_num", pose_num_obj.shape,
np_to_triton_dtype(pose_num_obj.dtype))
input_pose_num.set_data_from_numpy(pose_num_obj)
image_files = [self.image.astype(np.uint8)] * 1
image_files_obj = np.array(image_files, dtype=np.uint8).reshape((-1, 768, 512, 3))
input_image_files = grpcclient.InferInput("image_file", image_files_obj.shape, "UINT8")
input_image_files.set_data_from_numpy(image_files_obj)
ctx = self.grpc_client.async_infer(model_name="animatex_1", inputs=[input_pose_num, input_image_files],
callback=self.callback, client_timeout=60000)
time_out = 60000
while time_out > 0:
pose_transform_data, _ = self.read_tasks_status()
if pose_transform_data['status'] in ["REVOKED", "FAILURE"]:
ctx.cancel()
break
elif pose_transform_data['status'] == "SUCCESS":
break
time_out -= 1
time.sleep(1)
pose_transform_data, _ = self.read_tasks_status()
return pose_transform_data
except Exception as e:
self.pose_transform_data['status'] = "FAILURE"
self.pose_transform_data['message'] = str(e)
self.redis_client.set(self.tasks_id, json.dumps(self.pose_transform_data))
raise Exception(str(e))
finally:
dict_pose_transform_data, str_pose_transform_data = self.read_tasks_status()
if not DEBUG:
publish_status(str_pose_transform_data, PS_RABBITMQ_QUEUES)  # already a JSON string; a second json.dumps would double-encode it
logger.info(
f" [x] Sent to {PS_RABBITMQ_QUEUES} data@@@@ {json.dumps(dict_pose_transform_data, indent=4)}")
def infer_cancel(tasks_id):
redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
data = {'tasks_id': tasks_id, 'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'}
pose_transform_data = json.dumps(data)
redis_client.set(tasks_id, pose_transform_data)
return data
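For orientation: cancellation flows through Redis rather than the gRPC handle. infer_cancel flips the task record to REVOKED, and the polling loop in get_result observes it and calls ctx.cancel(). A condensed sketch of that contract (poll_until_done is a hypothetical name):
import json
import time

def poll_until_done(redis_client, tasks_id, ctx, timeout_s=600):
    # Mirrors the loop in get_result: check the shared Redis record once per second
    data = {'status': 'PENDING'}
    while timeout_s > 0:
        data = json.loads(redis_client.get(tasks_id))
        if data['status'] in ("REVOKED", "FAILURE"):
            ctx.cancel()  # abandon the in-flight Triton request
            break
        if data['status'] == "SUCCESS":
            break
        timeout_s -= 1
        time.sleep(1)
    return data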
def pre_processing_image(image_url):
image = oss_get_image(bucket=image_url.split('/')[0], object_name=image_url[image_url.find('/') + 1:],
data_type="PIL")
# Target image size
target_width = 512
target_height = 768
# Original image size
original_width, original_height = image.size
# Compute the width and height scale ratios
width_ratio = target_width / original_width
height_ratio = target_height / original_height
# Use the smaller ratio so the whole image fits inside the target canvas
scale_ratio = min(width_ratio, height_ratio)
# Compute the resized dimensions
new_width = int(original_width * scale_ratio)
new_height = int(original_height * scale_ratio)
# Resize the image
resized_image = image.resize((new_width, new_height))
# Create a 512x768 transparent canvas
result_image = Image.new("RGBA", (target_width, target_height), (255, 255, 255, 0))
# Compute the paste offsets so the image is centered
x_offset = (target_width - new_width) // 2
y_offset = (target_height - new_height) // 2
# Paste the resized image onto the transparent canvas
if resized_image.mode == "RGBA":
result_image.paste(resized_image, (x_offset, y_offset), mask=resized_image.split()[3])
else:
result_image.paste(resized_image, (x_offset, y_offset))
result_image = result_image.convert("RGB")
image = np.array(result_image)
# image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA)
return image
if __name__ == '__main__':
rd = PoseTransformModel(
tasks_id="123-89",
image_url='aida-results/result_0000b606-1902-11ef-9424-0242ac180002.png',
pose_id="1"
)
server = PoseTransformService(rd)
print(server.get_result())

View File

@@ -0,0 +1,23 @@
import json
import pika
import logging
from app.core.config import RABBITMQ_PARAMS
logger = logging.getLogger(__name__)
def publish_status(message, queue_name):
connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
channel = connection.channel()
channel.queue_declare(queue=queue_name, durable=True)
channel.basic_publish(exchange='',
routing_key=queue_name,
body=message,
properties=pika.BasicProperties(
delivery_mode=2,
))
connection.close()
logger.info(f" [x] Queue : {queue_name} | Sent message : {json.dumps(json.loads(message), indent=4)}")

View File

@@ -0,0 +1,75 @@
import io
import logging
import os.path
import numpy as np
# import boto3
from minio import Minio
from moviepy.video.io.ImageSequenceClip import ImageSequenceClip
from app.core.config import *
from app.service.utils.new_oss_client import oss_upload_image
# MinIO configuration
MINIO_URL = "www.minio-api.aida.com.hk"
MINIO_ACCESS = 'vXKFLSJkYeEq2DrSZvkB'
MINIO_SECRET = 'uKTZT3x7C43WvPN9QTc99DiRkwddWZrG9Uh3JVlR'
MINIO_SECURE = True
minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
def upload_first_image(image, user_id, category, file_name):
try:
image_data = io.BytesIO()
image.save(image_data, format='PNG')
image_data.seek(0)
image_bytes = image_data.read()
object_name = f'{user_id}/{category}/{file_name}'
req = oss_upload_image(oss_client=minio_client, bucket=GI_MINIO_BUCKET, object_name=object_name, image_bytes=image_bytes)
image_url = f"aida-users/{object_name}"
return image_url
except Exception as e:
logging.warning(f"upload_png_mask runtime exception : {e}")
def upload_gif(gif_buffer, user_id, category, file_name):
try:
object_name = f'{user_id}/{category}/{file_name}'
req = minio_client.put_object(
"aida-users",
object_name,
gif_buffer,
length=gif_buffer.getbuffer().nbytes,
content_type="image/gif"
)
return f"aida-users/{object_name}"
except Exception as e:
logging.warning(f"upload_gif runtime exception : {e}")
def upload_video(frames, user_id, category, file_name):
try:
save_path = ndarray_to_video(frames, file_name)
object_name = f'{user_id}/{category}/{file_name}'
minio_client.fput_object(
"aida-users",
object_name,
save_path,
content_type="video/mp4" # 指定MIME类型确保可在线播放[9](@ref)
)
return f"aida-users/{object_name}"
except Exception as e:
logging.warning(f"upload_video runtime exception : {e}")
def ndarray_to_video(images, output_path, frame_size=(512, 768), fps=9):
save_path = os.path.join(POSE_TRANSFORM_VIDEO_PATH, output_path)
clip = ImageSequenceClip([frame for frame in images], fps=fps)
clip.write_videofile(save_path, codec='libx264')
return save_path
if __name__ == '__main__':
images = np.random.randint(0, 256, size=(4, 768, 512, 3), dtype=np.uint8)
print(upload_video(images, user_id=89, category='pose_transform_video', file_name="1123123.mp4"))

View File

@@ -0,0 +1,114 @@
import cv2
import numpy as np
from PIL import Image
from minio import Minio
from app.core.config import MINIO_URL, MINIO_ACCESS, MINIO_SECRET, MINIO_SECURE
from app.schemas.mannequin_edit import MannequinModel
from app.service.utils.new_oss_client import oss_get_image, oss_upload_image
minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
class MannequinEditService():
def __init__(self, request_data):
self.resize_pixel = request_data.resize_pixel
self.top = request_data.top
self.bottom = request_data.bottom
self.image = oss_get_image(oss_client=minio_client, bucket=request_data.mannequins.split('/')[0], object_name=request_data.mannequins[request_data.mannequins.find('/') + 1:], data_type="cv2")
self.mannequin_name = request_data.mannequin_name
self.bucket_name = request_data.bucket_name
if self.image.shape[2] == 4:
self.bgr = self.image[:, :, :3]
self.alpha = self.image[:, :, 3]
self.bgr = cv2.bitwise_and(self.bgr, self.bgr, mask=cv2.normalize(self.alpha, None, 0, 1, cv2.NORM_MINMAX))
self.h, self.w, _ = self.bgr.shape
else:
self.bgr = self.image
self.h, self.w, _ = self.bgr.shape
self.alpha = None
def __call__(self, *args, **kwargs):
new_mannequin = self.resize_leg(self.top, self.bottom)
_, encoded_image = cv2.imencode('.png', new_mannequin)
image_bytes = encoded_image.tobytes()
req = oss_upload_image(oss_client=minio_client, bucket=self.bucket_name, object_name=f"{self.mannequin_name}.png", image_bytes=image_bytes)
return req.bucket_name + "/" + req.object_name
def post_processing(self, image):
# Original image size
original_width, original_height = image.size
# Compute the width and height scale ratios
width_ratio = self.w / original_width
height_ratio = self.h / original_height
# Use the smaller ratio so the whole image fits inside the target canvas
scale_ratio = min(width_ratio, height_ratio)
# Compute the resized dimensions
new_width = int(original_width * scale_ratio)
new_height = int(original_height * scale_ratio)
# Resize the image
resized_image = image.resize((new_width, new_height))
# Create a transparent canvas at the original mannequin size (self.w x self.h)
result_image = Image.new("RGBA", (self.w, self.h), (255, 255, 255, 0))
# Compute the paste offsets so the image is centered
x_offset = (self.w - new_width) // 2
y_offset = (self.h - new_height) // 2
# Paste the resized image onto the transparent canvas
if resized_image.mode == "RGBA":
result_image.paste(resized_image, (x_offset, y_offset), mask=resized_image.split()[3])
else:
result_image.paste(resized_image, (x_offset, y_offset))
image = np.array(result_image)
return image
def resize_leg(self, top, bottom):
# Upper section
top_part = self.bgr[:top, :]
# Section to be resized
part_resize = self.bgr[top:bottom, :]
# Lower section
part_bottom = self.bgr[bottom:, :]
new_height = int((bottom - top) + self.resize_pixel)
resized_thigh = cv2.resize(part_resize, (self.w, new_height), interpolation=cv2.INTER_LINEAR)
# Reassemble the BGR channels
new_bgr = np.vstack((top_part, resized_thigh, part_bottom))
if self.alpha is not None:
# Slice, resize and stitch the alpha channel the same way, then merge it with the BGR channels
top_part_alpha = self.alpha[:top, :]
part_resize_alpha = self.alpha[top:bottom, :]
part_bottom_alpha = self.alpha[bottom:, :]
resized_thigh_alpha = cv2.resize(part_resize_alpha, (self.w, new_height), interpolation=cv2.INTER_LINEAR)
new_bgr_alpha = np.vstack((top_part_alpha, resized_thigh_alpha, part_bottom_alpha))
new_image = np.dstack((new_bgr, new_bgr_alpha))
else:
new_image = new_bgr
new_image = self.post_processing(Image.fromarray(new_image))
return new_image
if __name__ == '__main__':
request_data = MannequinModel(
mannequins="aida-sys-image/models/male/dc36ce58-46c3-4b6f-8787-5ca7d6fc26e6.png",
resize_pixel=-100,
bucket_name="test",
mannequin_name="mannequin_name",
top=270,
bottom=432
)
service = MannequinEditService(request_data)
print(service())

View File

@@ -0,0 +1,68 @@
from langchain.output_parsers import ResponseSchema, StructuredOutputParser
from langchain_community.chat_models import ChatTongyi
from langchain_core.prompts import PromptTemplate
from app.schemas.project_info_extraction import ProjectInfoExtractionModel
style = ['NEW_CHINESE', 'COUNTRY_STYLE', 'FUTURISM', 'MINIMALISM', 'LOLITA', 'Y2K', 'BUSINESS', 'MERLAD',
'OUTDOOR_FUNCTIONAL', 'ROCK', 'DOPAMINE', 'GOTHIC', 'POST_APOCALYPTIC', 'ROMANTIC', 'WABI_SABI']
position = ['Overall', 'Tops', 'Bottoms', 'Outwear', 'Blouse', 'Dress', 'Trousers', 'Skirt']
gender = ['Female', 'Male']
age_group = ['Adult', 'Child']
process = ['SERIES_DESIGN', 'SINGLE_DESIGN']
class ProjectInfoExtraction:
def __init__(self, request_data):
# llm generate brand info init
if len(request_data.image_list) or len(request_data.file_list):
self.model = ChatTongyi(model="qwen-vl-plus", api_key="sk-7658298c6b99443c98184a5e634fe6ab")
else:
self.model = ChatTongyi(model="qwen2.5-14b-instruct", api_key="sk-7658298c6b99443c98184a5e634fe6ab")
self.response_schemas = [
ResponseSchema(name="project_name", description="项目的名称."),
ResponseSchema(name="process", description="项目的类型,单品还是系列."),
ResponseSchema(name="ageGroup", description="项目设计服装的受众对象."),
ResponseSchema(name="gender", description="项目设计服装的受众性别."),
ResponseSchema(name="position", description="项目单品设计的部位."),
ResponseSchema(name="style", description="项目的设计风格.")
]
self.output_parser = StructuredOutputParser.from_response_schemas(self.response_schemas)
self.format_instructions = self.output_parser.get_format_instructions()
self.prompt = PromptTemplate(
template="你是一个时装品牌的设计师助理。根据用户输入提取出"
"[project_name] : 项目的名称,"
f"[process] : 项目的类型,从{process}选择."
f"[ageGroup] : 服装的受众,从{age_group}选择."
f"[gender] : 服装的适用性别,从{gender}选择"
f"[position] : single_design的部位如果[process]是SINGLE_DESIGN,从{position}中选择,如果[process]是SERIES_DESIGN这项为空"
f"[style] : 设计的风格,从{style}中选择"
".\n{format_instructions}\n{question}",
input_variables=["question"],
partial_variables={"format_instructions": self.format_instructions}
)
self._input = self.prompt.format_prompt(question=request_data.prompt)
self.result_data = {}
def get_result(self):
self.llm_extraction_project_info()
return self.result_data
def llm_extraction_project_info(self):
output = self.model(self._input.to_messages())
project_info = self.output_parser.parse(output.content)
self.result_data = project_info
if __name__ == '__main__':
request_data = ProjectInfoExtractionModel(
prompt="性别为儿童",
image_list=[
'https://www.minio-api.aida.com.hk/test/019aaeed-3227-11f0-a194-0826ae3ad6b3.jpg?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=vXKFLSJkYeEq2DrSZvkB%2F20250613%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20250613T020236Z&X-Amz-Expires=604800&X-Amz-SignedHeaders=host&X-Amz-Signature=a513b706c24134071a489c34f0fa2c0f510e871b8589dc0c08a0f26ea28ee2ff'
],
file_list=[]
)
service = ProjectInfoExtraction(request_data)
print(service.get_result())
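For orientation, StructuredOutputParser expects the model to answer with a fenced JSON object matching the response schemas; a minimal round-trip with the model call stubbed out:
from langchain.output_parsers import ResponseSchema, StructuredOutputParser

schemas = [ResponseSchema(name="project_name", description="Project name."),
           ResponseSchema(name="process", description="SERIES_DESIGN or SINGLE_DESIGN.")]
parser = StructuredOutputParser.from_response_schemas(schemas)
fake_llm_output = '```json\n{"project_name": "demo", "process": "SINGLE_DESIGN"}\n```'
print(parser.parse(fake_llm_output))  # {'project_name': 'demo', 'process': 'SINGLE_DESIGN'}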

View File

@@ -9,6 +9,7 @@ from retry import retry
from app.core.config import QWEN_API_KEY
from app.service.chat_robot.script.service.CallQWen import get_language
from app.service.prompt_generation.util import minio_util
logger = logging.getLogger(__name__)
@@ -143,6 +144,38 @@ def get_translation_from_llama3(text):
# response = requests.post(url, data=json.dumps(payload), headers=headers)
def get_prompt_from_image(image_path, text):
start_time = time.time()
# url = "http://localhost:11434/api/generate"
url = "http://10.1.1.243:11434/api/generate"
image_base64 = minio_util.minio_url_to_base64(image_path.img)
# image_base64 = minio_url_to_base64(image_path)
# Build the request payload; "translator" is the custom translation model
payload = {
"model": "llama3.2-vision",
"images": [image_base64],
"prompt": f"{text}",
"stream": False
}
# Convert the payload to JSON
headers = {'Content-Type': 'application/json'}
response = requests.post(url, data=json.dumps(payload), headers=headers)
# Handle the response
if response.status_code == 200:
# print("Response from server:")
# print(response.json())
resp = json.loads(response.content).get("response")
logger.info(f"sketch re-generate server runtime is {time.time() - start_time} \n, response is {resp}")
# print("input : {}, sketch re-generate result : {}".format(text, resp))
return resp
else:
logger.info(f"sketch re-generate server runtime is {time.time() - start_time} , response is {response.content}")
print(f"Request failed with status code {response.status_code}")
print(response.text)
def main():
"""Main function"""
text = get_translation_from_llama3("[火焰]")

View File

@@ -0,0 +1,21 @@
import base64
from minio import Minio
from app.core.config import MINIO_URL, MINIO_ACCESS, MINIO_SECRET, MINIO_SECURE
minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
def minio_url_to_base64(minio_url: str) -> str:
bucket_name, object_name = minio_url.split("/", 1)
try:
response = minio_client.get_object(bucket_name, object_name)
image_data = response.read()
return base64.b64encode(image_data).decode('utf-8')
except Exception as e:
raise RuntimeError(f"Failed to get object: {e}")
finally:
if 'response' in locals():
response.close()
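A hedged usage sketch (the bucket and object path here are made up for illustration):
# Hypothetical object path: "<bucket>/<object key>"
b64 = minio_url_to_base64("aida-users/60/product_image/example.png")
print(b64[:32], "...")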

View File

@@ -0,0 +1,539 @@
import pymysql
import numpy as np
from apscheduler.schedulers.blocking import BlockingScheduler
import os
import logging
from collections import defaultdict
import torch
from torchvision import models, transforms
from minio import Minio
from PIL import Image
import io
import seaborn as sns
import matplotlib.pyplot as plt
from scipy.sparse import csr_matrix
import matplotlib.font_manager as fm
from scipy import sparse
import pandas as pd
from datetime import datetime, timedelta
import json
from app.core.config import DB_CONFIG, TABLE_CATEGORIES, RECOMMEND_PATH_PREFIX
# Automatically pick an available font
try:
# Try common Chinese fonts
font_path = fm.findfont(fm.FontProperties(family=['Microsoft YaHei', 'SimHei', 'WenQuanYi Zen Hei']))
plt.rcParams['font.sans-serif'] = fm.FontProperties(fname=font_path).get_name()
except:
# Fall back to an English font
plt.rcParams['font.sans-serif'] = ['DejaVu Sans']
plt.rcParams['axes.unicode_minus'] = False
# Check which fonts are available and pick one that supports Chinese
font_path = fm.findfont(fm.FontProperties(family='Microsoft YaHei')) # or another font with Chinese support
plt.rcParams['font.sans-serif'] = ['Microsoft YaHei'] # set Microsoft YaHei
plt.rcParams['axes.unicode_minus'] = False # render minus signs correctly
# Logging configuration
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
filename='scheduler.log'
)
# MinIO connection settings
minio_client = Minio(
"www.minio.aida.com.hk:12024", # MinIO endpoint
access_key="admin", # access key
secret_key="Aidlab123123!", # secret key
secure=True # use HTTPS
)
# Preload the feature vectors of the system sketches
SYSTEM_FEATURES = np.load(f'{RECOMMEND_PATH_PREFIX}sketch_feature_dict.npy', allow_pickle=True).item()
# 行为权重和衰减系数
BEHAVIOR_CONFIG = {
'portfolioClick': {'weight': 1, 'decay': 0.3},
'portfolioLike': {'weight': 2, 'decay': 0.2},
'secondCreation': {'weight': 3, 'decay': 0.1},
'sketchLike': {'weight': 4, 'decay': 0} # no decay
}
# Persist sketch_to_iid to disk
def save_sketch_to_iid():
"""Save the sketch-to-iid mapping"""
sketch_to_iid = {sketch_path: iid for iid, sketch_path in enumerate(SYSTEM_FEATURES.keys(), start=1)}
np.save('sketch_to_iid.npy', sketch_to_iid)
print("sketch_to_iid saved")
# Load sketch_to_iid from disk
def load_sketch_to_iid():
"""Load the saved sketch-to-iid mapping"""
if os.path.exists('sketch_to_iid.npy'):
sketch_to_iid = np.load('sketch_to_iid.npy', allow_pickle=True).item()
print("sketch_to_iid loaded")
return sketch_to_iid
else:
# Generate and save the mapping if the file does not exist
print("sketch_to_iid file not found, generating and saving it...")
save_sketch_to_iid()
return np.load('sketch_to_iid.npy', allow_pickle=True).item()
# Use load_sketch_to_iid to obtain the mapping
sketch_to_iid = load_sketch_to_iid()
# sketch_to_iid can then be used elsewhere in this module
# print(f"Total sketches: {len(sketch_to_iid)}")
# Image preprocessing, matching the transforms ResNet was trained with
transform = transforms.Compose([
transforms.Resize((224, 224)), # ResNet expects 224x224 input
transforms.ToTensor(), # convert to a Tensor
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), # normalize
])
# Load a pretrained ResNet model (ResNet50)
resnet_model = models.resnet50(pretrained=True)
modules = list(resnet_model.children())[:-1] # drop the final fully connected layer
resnet_model = torch.nn.Sequential(*modules)
resnet_model.eval() # evaluation mode
# Fetch an image from MinIO and preprocess it
def get_sketch_image_from_minio(sketch_path):
"""
Fetch a sketch image from MinIO and preprocess it
"""
# Split the path at the first slash into bucket name and object key
path_parts = sketch_path.split('/', 1)
bucket_name = path_parts[0] # bucket name
file_name = path_parts[1] # object key (everything after the first slash)
try:
# Fetch the object
obj = minio_client.get_object(bucket_name, file_name)
img_data = obj.read() # read the image bytes
img = Image.open(io.BytesIO(img_data)) # decode into a PIL image
img = transform(img) # apply the preprocessing
return img.unsqueeze(0) # add a batch dimension
except Exception as e:
print(f"Error fetching image for {sketch_path}: {e}")
return None
def extract_feature_vector_from_resnet(sketch_path):
"""
Extract the feature vector of a sketch image
"""
# Fetch the image from MinIO and preprocess it
img_tensor = get_sketch_image_from_minio(sketch_path)
if img_tensor is None:
return np.zeros(2048) # return a zero vector if the fetch failed
with torch.no_grad(): # inference only, no gradients needed
feature_vector = resnet_model(img_tensor) # run the ResNet backbone
return feature_vector.squeeze().cpu().numpy() # to NumPy, dropping the batch dimension
def update_user_matrices():
"""Daily refresh of the user interaction-count matrix and the feature-vector matrix"""
conn = None
try:
conn = pymysql.connect(**DB_CONFIG)
cursor = conn.cursor()
# Revised query: the category filter has been removed
cursor.execute("""
SELECT account_id, path, COUNT(*) as like_count
FROM user_preference_log_test
GROUP BY account_id, path
""")
user_data = cursor.fetchall()
logging.info(f"Read {len(user_data)} user preference records")
# Compute the matrices
interaction_matrix, raw_counts_sparse, user_index_interaction_matrix, sketch_index_interaction_matrix, iid_to_category_interaction_matrix = calculate_interaction_matrix(user_data)
# visualize_sparse_matrix(raw_counts_sparse, 'Interaction count matrix', 'interaction_frequency_matrix.png')
# visualize_sparse_matrix(interaction_matrix, 'Interaction score matrix', 'interaction_score_matrix.png')
# plot_interaction_count_matrix(raw_counts_sparse)
# feature_matrix, iid_to_category_feature_matrix, user_index_feature_matrix, sketch_index_feature_matrix = calculate_feature_matrix(user_data)
feature_matrix, user_index_feature_matrix, sketch_index_feature_matrix, iid_to_category_feature_matrix = calculate_feature_matrix(user_data)
# visualize_sparse_matrix(feature_matrix, 'System-sketch vs user-category mean-feature correlation matrix', 'correlation_matrix.png')
# Persist the matrices
np.save(f"{RECOMMEND_PATH_PREFIX}interaction_matrix.npy", interaction_matrix)
np.save(f"{RECOMMEND_PATH_PREFIX}feature_matrix.npy", feature_matrix)
#
np.save(f"{RECOMMEND_PATH_PREFIX}iid_to_category_interaction_matrix.npy", iid_to_category_interaction_matrix)
np.save(f"{RECOMMEND_PATH_PREFIX}user_index_interaction_matrix.npy", user_index_interaction_matrix)
#
np.save(f"{RECOMMEND_PATH_PREFIX}iid_to_category_feature_matrix.npy", iid_to_category_feature_matrix)
np.save(f"{RECOMMEND_PATH_PREFIX}user_index_feature_matrix.npy", user_index_feature_matrix)
#
np.save(f"{RECOMMEND_PATH_PREFIX}sketch_index_interaction_matrix.npy", sketch_index_interaction_matrix)
np.save(f"{RECOMMEND_PATH_PREFIX}sketch_index_feature_matrix.npy", sketch_index_feature_matrix)
# logging.info("Matrix update complete")
except Exception as e:
logging.error(f"Scheduled job failed: {str(e)}", exc_info=True)
finally:
if conn:
conn.close()
def plot_interaction_count_matrix(interaction_count_matrix):
"""Plot the interaction-count matrix as a heatmap, without hiding zero entries"""
try:
if not isinstance(interaction_count_matrix, csr_matrix):
sparse_matrix = csr_matrix(interaction_count_matrix)
else:
sparse_matrix = interaction_count_matrix
# Convert to a dense matrix
try:
dense_matrix = sparse_matrix.toarray()
except MemoryError:
logging.error("Out of memory; cannot convert to a dense matrix")
return
# Automatically detect an available font
try:
font_path = fm.findfont(fm.FontProperties(family='sans-serif', style='normal'))
plt.rcParams['font.sans-serif'] = fm.FontProperties(fname=font_path).get_name()
except:
plt.rcParams['font.sans-serif'] = ['DejaVu Sans'] # fall back to an English font
plt.rcParams['axes.unicode_minus'] = False
# For very large matrices, limit the plotted range
if dense_matrix.shape[0] > 1000 or dense_matrix.shape[1] > 1000:
dense_matrix = dense_matrix[:1000, :1000] # plot only the first 1000 rows/columns
plt.figure(figsize=(15, 10))
# Pick the colormap and set vmin/vmax so zero entries stay visible and the contrast is reasonable
sns.heatmap(
dense_matrix,
cmap="Blues", # other maps such as "YlGnBu" also work
annot=False, # no per-cell annotations
cbar_kws={"label": "Interaction Count"}, # colorbar label
linewidths=0.5,
vmin=0, # minimum value, keeping zeros distinguishable
vmax=np.max(dense_matrix) # maximum value, keeping the mapping sensible
)
plt.title('User-Sketch Interaction Matrix (With Zero Entries)')
plt.xlabel('Sketch Index')
plt.ylabel('User Index')
plt.savefig('interaction_heatmap_with_zeros.png', dpi=150, bbox_inches='tight')
plt.close()
logging.info("Heatmap saved as interaction_heatmap_with_zeros.png")
except Exception as e:
logging.error(f"Plotting failed: {str(e)}", exc_info=True)
def visualize_sparse_matrix(matrix, title='Non-zero Interactions (Scatter Plot)', filename="scatter_figure_interaction.png"):
if not sparse.issparse(matrix):
# Convert to a sparse matrix
matrix = sparse.csr_matrix(matrix)
# Coordinates and values of the non-zero entries
rows, cols = matrix.nonzero()
values = matrix.data
# Draw the scatter plot
plt.figure(figsize=(24, 20))
plt.scatter(cols, rows, c=values, cmap='coolwarm', alpha=0.7, s=1)
plt.colorbar(label='Interaction Count')
plt.title(title)
plt.xlabel('Item Index')
plt.ylabel('Item Index')
plt.savefig(filename)
def calculate_interaction_matrix(user_data):
"""Interaction-count matrix computation for the new table schema (system sketches only)"""
# Collect all user IDs
all_users = set()
for account_id, path, like_count in user_data:
category = get_category_from_path(path)
if category not in TABLE_CATEGORIES.keys():
continue
all_users.add(account_id)
# Collect the iids of all system sketches
system_sketch_iids = {sketch_to_iid[path] for path in SYSTEM_FEATURES.keys() if path in sketch_to_iid}
# Build the index maps
user_index = {uid: idx for idx, uid in enumerate(sorted(all_users))}
sketch_index = {iid: idx for idx, iid in enumerate(sorted(system_sketch_iids))}
# Initialize two matrices: a normalized (dense) matrix and a raw-count (sparse) matrix
interaction_matrix = np.zeros((len(all_users), len(sketch_index))) # normalized matrix
data, rows, cols = [], [], [] # COO-format buffers for the sparse matrix
# Precompute each user's maximum interaction count
user_max_likes = defaultdict(int)
for account_id, path, like_count in user_data:
if sketch_to_iid.get(path) in system_sketch_iids:
user_max_likes[account_id] = max(user_max_likes[account_id], like_count)
# Fill the matrices
for account_id, path, like_count in user_data:
sketch_iid = sketch_to_iid.get(path)
if sketch_iid not in system_sketch_iids:
continue
user_idx = user_index[account_id]
sketch_idx = sketch_index[sketch_iid]
# Append to the sparse-matrix buffers
data.append(like_count)
rows.append(user_idx)
cols.append(sketch_idx)
# Normalize against the user's own maximum
max_like = user_max_likes.get(account_id, 1)
interaction_matrix[user_idx, sketch_idx] = np.log1p(1 + like_count) / np.log1p(1 + max_like)
# Build the sparse matrix (CSR format, suited to fast row operations)
interaction_count_matrix = csr_matrix((data, (rows, cols)), shape=(len(all_users), len(sketch_index)))
return interaction_matrix, interaction_count_matrix, user_index, sketch_index, {iid: get_category_from_path(path) for path, iid in sketch_to_iid.items()}
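As a sanity check, the per-user normalization np.log1p(1 + count) / np.log1p(1 + max) maps each user's counts into (0, 1] relative to their own most-liked sketch; a quick illustration with a per-user maximum of 10:
import numpy as np
# log1p(1 + c) / log1p(1 + 10) for counts 1, 5 and 10
for count in (1, 5, 10):
    print(count, round(np.log1p(1 + count) / np.log1p(1 + 10), 3))  # 0.279, 0.783, 1.0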
def calculate_feature_matrix(user_data):
"""Feature-matrix computation for the new table schema: the (weighted-average) similarity between each user and each system sketch"""
# User feature store: {(account_id, category): {sketch_iid: [(feature_vector, weight)]}}
user_feature_weights = defaultdict(lambda: defaultdict(list))
# Initialize the sets of all users and all system sketches
all_users = set()
all_system_sketches = set(SYSTEM_FEATURES.keys())
# ==== First pass: collect feature vectors and weights ====
for account_id, path, like_count in user_data:
category = get_category_from_path(path)
if category not in TABLE_CATEGORIES.keys():
continue
sketch_iid = sketch_to_iid.get(path)
if not sketch_iid:
continue
# Record the user
all_users.add(account_id)
# Extract the feature and record the weight (like_count)
if path in SYSTEM_FEATURES: # system sketch
feature = SYSTEM_FEATURES[path]
weight = like_count # use like_count as the weight
user_feature_weights[(account_id, category)][sketch_iid].append((feature, weight))
else: # user sketch
feature = extract_feature_vector_from_resnet(path)
weight = like_count
user_feature_weights[(account_id, category)][sketch_iid].append((feature, weight))
# ==== Second pass: collect the iids of all system sketches ====
system_sketch_iids = set()
for sketch_path in SYSTEM_FEATURES:
if iid := sketch_to_iid.get(sketch_path):
system_sketch_iids.add(iid)
# ==== Build the index maps ====
user_list = sorted(all_users)
sketch_list = sorted(system_sketch_iids)
user_index = {uid: idx for idx, uid in enumerate(user_list)}
sketch_index = {iid: idx for idx, iid in enumerate(sketch_list)}
# ==== Initialize the feature matrix ====
feature_matrix = np.zeros((len(user_list), len(sketch_list)))
# ==== Precompute the weighted-average feature vectors ====
user_avg_features = {}
for (account_id, category), sketches in user_feature_weights.items():
try:
# Flatten all feature vectors and weights
all_features_weights = [(vec, weight) for vec_list in sketches.values() for vec, weight in vec_list]
if len(all_features_weights) == 0:
continue
# Total weight
total_weight = sum(weight for _, weight in all_features_weights)
if total_weight <= 0: # guard against division by zero
total_weight = 1.0
# Weighted-average computation
weighted_sum = np.zeros_like(all_features_weights[0][0]) # take the feature dimensionality from the first vector
for vec, weight in all_features_weights:
weighted_sum += vec * weight
avg_vec = weighted_sum / total_weight
user_avg_features[(account_id, category)] = avg_vec
except Exception as e:
logging.warning(f"Weighted feature computation failed for user ({account_id}, {category}): {str(e)}")
continue
# ==== Compute similarities and fill the matrix ====
for sketch_path, sys_vector in SYSTEM_FEATURES.items():
sketch_iid = sketch_to_iid.get(sketch_path)
system_sketch_category = get_category_from_path(sketch_path)
if not sketch_iid or sketch_iid not in sketch_index:
continue
sketch_col = sketch_index[sketch_iid]
# Iterate over all users
for account_id in all_users:
user_row = user_index.get(account_id)
if user_row is None:
continue
# Look up the user's weighted-average feature vector
try:
# Direct lookup via the composite key
user_vec = user_avg_features[(account_id, system_sketch_category)]
except KeyError:
# This user has no feature data for the category
continue
# Cosine similarity
cos_sim = cosine_similarity(user_vec, sys_vector)
feature_matrix[user_row, sketch_col] = cos_sim
return feature_matrix, user_index, sketch_index, {iid: get_category_from_path(path) for path, iid in sketch_to_iid.items()}
def get_category_from_path(path):
"""Parse the category from the path field"""
try:
parts = path.split('/')
if len(parts) >= 4:
return f"{parts[2]}_{parts[3]}"
return "unknown"
except Exception:
return "unknown"
def cosine_similarity(vec1, vec2):
"""Cosine similarity (with zero-norm handling)"""
norm = np.linalg.norm(vec1) * np.linalg.norm(vec2)
return np.dot(vec1, vec2) / (norm + 1e-10) if norm != 0 else 0.0
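A quick check of the zero-norm guard (illustrative values):
import numpy as np
print(cosine_similarity(np.array([1.0, 0.0]), np.array([1.0, 0.0])))  # ~1.0 (the +1e-10 keeps it just under 1)
print(cosine_similarity(np.zeros(2), np.array([1.0, 0.0])))           # 0.0 (zero-norm branch)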
def fetch_user_behavior_data(days=30):
"""Fetch user behavior data from MySQL (merging the old query with the new requirements)"""
conn = None
try:
conn = pymysql.connect(**DB_CONFIG)
# Compute the date range
end_date = datetime.now()
start_date = end_date - timedelta(days=days)
# Combined query (fetch the full behavior data)
query = f"""
SELECT
account_id,
behavior_type,
gender,
category,
url,
create_time
FROM user_behavior
WHERE create_time BETWEEN '{start_date}' AND '{end_date}'
"""
df = pd.read_sql(query, conn)
logging.info(f"Read {len(df)} user behavior records")
return df
except Exception as e:
logging.error(f"Database query failed: {str(e)}")
return pd.DataFrame()
finally:
if conn:
conn.close()
def calculate_heat(row, current_date):
"""Heat value of a single behavior (each event is scored independently, without aggregation counts)"""
# Elapsed time in days
days_passed = (current_date - row['create_time']).days
# Behavior config, defaulting to weight 0
config = BEHAVIOR_CONFIG.get(row['behavior_type'], {'weight': 0, 'decay': 0})
# heat = weight * e^(-decay * days)
return config['weight'] * np.exp(-config['decay'] * days_passed)
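A worked example of the decay, using the weights in BEHAVIOR_CONFIG above:
import numpy as np
# heat = weight * e^(-decay * days_passed)
print(2 * np.exp(-0.2 * 5))   # portfolioLike 5 days old -> ~0.736
print(4 * np.exp(-0 * 30))    # sketchLike never decays -> 4.0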
def load_heat_matrix_as_array(file_path):
"""
Load directly into a 2-D numpy array
Returns: (data_array, row_labels, col_labels)
"""
with open(file_path) as f:
saved = json.load(f)
return (
np.array(saved['data']), # 2-D matrix
saved['row_labels'], # row labels
saved['col_labels'] # column labels
)
def update_heat_matrices():
"""Daily computation and storage of the heat matrix (gender_category x path)"""
current_date = datetime.now()
# Fetch the data
df = fetch_user_behavior_data(30)
if df.empty:
logging.warning("No valid data; skipping today's computation")
return None
# Compute the heat values
df['heat'] = df.apply(calculate_heat, axis=1, current_date=current_date)
df['gender_category'] = df['gender'] + '_' + df['category']
# Build the heat vectors
heat_vectors = {}
grouped = df.groupby(['gender_category', 'url'])['heat'].sum()
for (gender_category, url), heat in grouped.items():
heat_vectors.setdefault(gender_category, {})[url] = heat
# Store the result
save_path = 'heat_vectors_data'
os.makedirs(save_path, exist_ok=True)
date_str = current_date.strftime('%Y%m%d')
# vectors_file = f"{save_path}/heat_vectors_{date_str}.json"
vectors_file = f"{save_path}/heat_vectors.json"
with open(vectors_file, 'w', encoding='utf-8') as f:
json.dump({
'update_time': current_date.strftime('%Y-%m-%d %H:%M:%S'),
'data': heat_vectors
}, f, ensure_ascii=False, indent=2)
logging.info(f"Stored heat vectors for {len(heat_vectors)} groups, date: {date_str}")
return heat_vectors
if __name__ == "__main__":
try:
# update_user_matrices()
# update_heat_matrices()
scheduler = BlockingScheduler()
scheduler.add_job(update_user_matrices, 'cron', hour=12, timezone='Asia/Shanghai')
logging.info("定时任务已启动每天12:00执行")
scheduler.start()
except KeyboardInterrupt:
logging.info("定时任务已停止")
except Exception as e:
logging.error(f"调度器启动失败: {str(e)}", exc_info=True)

View File

@@ -0,0 +1,240 @@
# Preload resources
import logging
import time
from collections import defaultdict
import os
import json
import numpy as np
from app.core.config import DB_CONFIG, RECOMMEND_PATH_PREFIX
logger = logging.getLogger()
import pymysql
from concurrent.futures import ThreadPoolExecutor
HEAT_VECTOR_FILE = 'heat_vectors_data/heat_vectors.json' # could be loaded dynamically or made configurable
matrix_data = {
"interaction_matrix": None,
"feature_matrix": None,
"user_index_interaction": None,
"sketch_index_interaction": None,
"user_index_feature": None,
"sketch_index_feature": None,
"iid_to_sketch": None,
"category_to_iids": None,
"cached_scores": {},
"cached_valid_idxs": {},
"category_sketch_idxs_inter": None,
"category_sketch_idxs_feature": None,
"user_inter_full": dict(),
"user_feat_full": dict(),
"brand_feature_matrix": None,
"brand_index_map": None,
"heat_data": {},
}
def load_resources():
"""Load all matrices and mappings, then trigger precaching"""
try:
start_time = time.time()
# Clear the caches
matrix_data["cached_scores"].clear()
matrix_data["cached_valid_idxs"].clear()
# Load the data
sketch_to_iid = np.load(f'{RECOMMEND_PATH_PREFIX}sketch_to_iid.npy', allow_pickle=True).item()
matrix_data["iid_to_sketch"] = {v: k for k, v in sketch_to_iid.items()}
matrix_data["interaction_matrix"] = np.load(f"{RECOMMEND_PATH_PREFIX}interaction_matrix.npy", allow_pickle=True)
matrix_data["user_index_interaction"] = np.load(f"{RECOMMEND_PATH_PREFIX}user_index_interaction_matrix.npy", allow_pickle=True).item()
matrix_data["sketch_index_interaction"] = np.load(f"{RECOMMEND_PATH_PREFIX}sketch_index_interaction_matrix.npy",
allow_pickle=True).item()
matrix_data["feature_matrix"] = np.load(f"{RECOMMEND_PATH_PREFIX}feature_matrix.npy", allow_pickle=True)
brand_feature_path = f"{RECOMMEND_PATH_PREFIX}brand_feature_matrix.npy"
if os.path.exists(brand_feature_path):
matrix_data["brand_feature_matrix"] = np.load(brand_feature_path, allow_pickle=True)
else:
logger.warning("brand_feature_matrix file not found; using an empty array")
matrix_data["brand_feature_matrix"] = np.array([])
# brand_index_map
brand_index_path = f"{RECOMMEND_PATH_PREFIX}brand_index_map.npy"
if os.path.exists(brand_index_path):
matrix_data["brand_index_map"] = np.load(brand_index_path, allow_pickle=True).item()
else:
logger.warning("brand_index_map file not found; using an empty dict")
matrix_data["brand_index_map"] = {}
matrix_data["user_index_feature"] = np.load(f"{RECOMMEND_PATH_PREFIX}user_index_feature_matrix.npy", allow_pickle=True).item()
matrix_data["sketch_index_feature"] = np.load(f"{RECOMMEND_PATH_PREFIX}sketch_index_feature_matrix.npy", allow_pickle=True).item()
category_to_iid_map = np.load(f"{RECOMMEND_PATH_PREFIX}iid_to_category_interaction_matrix.npy", allow_pickle=True).item()
matrix_data["category_to_iids"] = defaultdict(list)
for iid, cat in category_to_iid_map.items():
matrix_data["category_to_iids"][cat].append(iid)
logger.info(f"Resources loaded in {time.time() - start_time:.2f}s")
# Trigger precaching
precache_user_category()
if os.path.exists(HEAT_VECTOR_FILE):
with open(HEAT_VECTOR_FILE, 'r', encoding='utf-8') as f:
heat_json = json.load(f)
matrix_data["heat_data"] = heat_json.get("data", {})
logger.info(f"Heat vector data loaded: {len(matrix_data['heat_data'])} categories")
else:
matrix_data["heat_data"] = {}
except Exception as e:
logger.error(f"Resource loading failed: {str(e)}")
raise RuntimeError("Initialization failed")
def precache_user_category():
"""Optimized per-user-category precache (with timing statistics)"""
if not all([
matrix_data["interaction_matrix"] is not None,
matrix_data["feature_matrix"] is not None,
matrix_data["user_index_interaction"] is not None
]):
logger.warning("Resources not fully loaded; skipping precache")
return
start_time = time.perf_counter()
time_stats = {
"get_all_user_categories": 0,
"process_user_category": 0,
"thread_execution": 0,
"cache_update": 0,
"total": 0,
}
# Time the user-category fetch
t1 = time.perf_counter()
user_categories = get_all_user_categories()
time_stats["get_all_user_categories"] = time.perf_counter() - t1
precached_count = 0
def process_user_category(user_id, categories):
"""Compute the cache for one user's categories (with timing)"""
local_cache = {}
local_valid_idxs = {}
t_start = time.perf_counter()
for category in categories:
cache_key = (user_id, category)
if cache_key in matrix_data["cached_scores"]:
continue
try:
user_idx_inter = matrix_data["user_index_interaction"].get(user_id)
user_idx_feature = matrix_data["user_index_feature"].get(user_id)
# Time the category-iid lookup
t_iid = time.perf_counter()
category_iids = matrix_data["category_to_iids"].get(category, [])
valid_sketch_idxs_inter = [matrix_data["sketch_index_interaction"][iid]
for iid in category_iids if iid in matrix_data["sketch_index_interaction"]]
valid_sketch_idxs_feature = [matrix_data["sketch_index_feature"][iid]
for iid in category_iids if iid in matrix_data["sketch_index_feature"]]
time_stats["process_user_category"] += time.perf_counter() - t_iid
# Time the matrix computation
t_matrix = time.perf_counter()
processed_inter = np.zeros(len(valid_sketch_idxs_inter))
if user_idx_inter is not None and valid_sketch_idxs_inter:
raw_inter_scores = matrix_data["interaction_matrix"][user_idx_inter, valid_sketch_idxs_inter]
processed_inter = raw_inter_scores * 0.7
processed_feat = np.zeros(len(valid_sketch_idxs_feature))
if user_idx_feature is not None and valid_sketch_idxs_feature:
raw_feat_scores = matrix_data["feature_matrix"][user_idx_feature, valid_sketch_idxs_feature]
raw_feat_scores = (raw_feat_scores - np.min(raw_feat_scores)) / (
np.max(raw_feat_scores) - np.min(raw_feat_scores) + 1e-8)
processed_feat = raw_feat_scores * 0.3
time_stats["process_user_category"] += time.perf_counter() - t_matrix
if len(processed_inter) == len(processed_feat):
local_cache[cache_key] = (processed_inter, processed_feat)
local_valid_idxs[cache_key] = valid_sketch_idxs_inter
except Exception as e:
logger.error(f"预缓存失败 (user={user_id}, category={category}): {str(e)}")
return local_cache, local_valid_idxs
# Time the thread-pool execution
t2 = time.perf_counter()
with ThreadPoolExecutor(max_workers=8) as executor:
futures = {executor.submit(process_user_category, user_id, categories): user_id for user_id, categories in user_categories.items()}
for future in futures:
try:
t_cache = time.perf_counter()
cache_part, valid_idxs_part = future.result()
matrix_data["cached_scores"].update(cache_part)
matrix_data["cached_valid_idxs"].update(valid_idxs_part)
time_stats["cache_update"] += time.perf_counter() - t_cache
precached_count += len(cache_part)
except Exception as e:
logger.error(f"Thread execution error: {str(e)}")
time_stats["thread_execution"] = time.perf_counter() - t2
time_stats["total"] = time.perf_counter() - start_time
# Log the timing summary
logger.info(f"""
Precache complete: {precached_count} entries cached. Timing:
- fetch user categories: {time_stats["get_all_user_categories"]:.2f}s
- per-user cache computation: {time_stats["process_user_category"]:.2f}s
- thread-pool execution: {time_stats["thread_execution"]:.2f}s
- cache updates: {time_stats["cache_update"]:.2f}s
- total: {time_stats["total"]:.2f}s
""")
def get_all_user_categories():
"""Fetch all users and their categories"""
conn = None
try:
conn = pymysql.connect(**DB_CONFIG)
cursor = conn.cursor()
query = """
SELECT DISTINCT account_id, path
FROM user_preference_log_prediction
"""
cursor.execute(query)
results = cursor.fetchall()
user_categories = defaultdict(set)
for account_id, path in results:
category = get_category_from_path(path)
user_categories[account_id].add(category)
return dict(user_categories)
except Exception as e:
logger.error(f"数据库查询失败: {str(e)}")
return {}
finally:
if conn:
conn.close()
def get_category_from_path(path: str) -> str:
"""Parse the category from a path"""
try:
parts = path.split('/')
if len(parts) >= 4:
return f"{parts[2]}_{parts[3]}"
return "unknown"
except Exception:
return "unknown"

View File

@@ -6,7 +6,7 @@ from chromadb.config import Settings
from chromadb.utils.embedding_functions.ollama_embedding_function import OllamaEmbeddingFunction
from tqdm import tqdm
from app.core.config import OLLAMA_URL
from app.core.config import OLLAMA_URL, CHROMADB_PATH
# Read the csv file
# csv_file_path = r'D:/Files/csv/output/output.csv'
@@ -15,7 +15,7 @@ from app.core.config import OLLAMA_URL
# df = pd.read_csv(csv_file_path, encoding='Windows-1252')
# Create the Chroma client
client = chromadb.Client(Settings(is_persistent=True, persist_directory="/vector_db"))
client = chromadb.Client(Settings(is_persistent=True, persist_directory=CHROMADB_PATH))
# client = chromadb.Client(Settings(is_persistent=True, persist_directory="./service/search_image_with_text/vector_db"))
# client = chromadb.Client(Settings(is_persistent=True, persist_directory="D:/workspace/AiDLab/vector_db"))
# Create the collection

View File

@@ -0,0 +1,99 @@
import redis
from app.core.config import REDIS_HOST, REDIS_PORT
class Redis(object):
"""
Redis database helpers
"""
@staticmethod
def _get_r():
host = REDIS_HOST
port = REDIS_PORT
db = 0
r = redis.StrictRedis(host, port, db)
return r
@classmethod
def write(cls, key, value, expire=None):
"""
写入键值对
"""
# 判断是否有过期时间,没有就设置默认值
if expire:
expire_in_seconds = expire
else:
expire_in_seconds = 100
r = cls._get_r()
r.set(key, value, ex=expire_in_seconds)
@classmethod
def read(cls, key):
"""
Read the value for a key
"""
r = cls._get_r()
value = r.get(key)
return value.decode('utf-8') if value else value
@classmethod
def hset(cls, name, key, value):
"""
Set a field in a hash
"""
r = cls._get_r()
r.hset(name, key, value)
@classmethod
def hget(cls, name, key):
"""
Read a field from a hash
"""
r = cls._get_r()
value = r.hget(name, key)
return value.decode('utf-8') if value else value
@classmethod
def hgetall(cls, name):
"""
Get all fields of a hash
"""
r = cls._get_r()
return r.hgetall(name)
@classmethod
def delete(cls, *names):
"""
Delete one or more keys
"""
r = cls._get_r()
r.delete(*names)
@classmethod
def hdel(cls, name, key):
"""
Delete a field from a hash
"""
r = cls._get_r()
r.hdel(name, key)
@classmethod
def expire(cls, name, expire=None):
"""
Set an expiry on a key
"""
if expire:
expire_in_seconds = expire
else:
expire_in_seconds = 100
r = cls._get_r()
r.expire(name, expire_in_seconds)
if __name__ == '__main__':
redis_client = Redis()
# print(redis_client.write(key="1230", value=0))
redis_client.write(key="1230", value=10)
# print(redis_client.read(key="1230"))
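A hypothetical round-trip through the hash helpers (key and field names are made up):
# Hash helpers round-trip
Redis.hset("task:1", "status", "PENDING")
print(Redis.hget("task:1", "status"))  # 'PENDING'
Redis.expire("task:1", 60)
Redis.hdel("task:1", "status")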