Merge branch 'refs/heads/develop'

Author: zhouchengrong
Date:   2025-02-04 09:37:55 +08:00

31 changed files with 949 additions and 118 deletions

.gitignore (vendored, 7 changes)
@@ -136,4 +136,9 @@ app/logs/*
 /qodana.yaml
 .pth
 .pytorch
 *.png
+*.pth
+*.db
+*.npy
+*.pytorch
+*.jpg

@@ -35,13 +35,13 @@ def attribute_recognition(request_item: list[AttributeRecognitionModel]):
     """
     try:
         for item in request_item:
-            logger.info(f"attribute_recognition request item is : @@@@@@:{json.dumps(item.dict())}")
+            logger.debug(f"attribute_recognition request item is : @@@@@@:{json.dumps(item.dict())}")
         if DEBUG:
             service = AttributeRecognition(const=local_debug_const, request_data=request_item)
         else:
             service = AttributeRecognition(const=const, request_data=request_item)
         data = service.get_result()
-        logger.info(f"attribute_recognition response @@@@@@:{json.dumps(data)}")
+        logger.debug(f"attribute_recognition response @@@@@@:{json.dumps(data)}")
     except Exception as e:
         logger.warning(f"attribute_recognition Run Exception @@@@@@:{e}")
         raise HTTPException(status_code=404, detail=str(e))

app/api/api_brand_dna.py (new file, 34 lines)
@@ -0,0 +1,34 @@
import json
import logging

from fastapi import APIRouter, HTTPException

from app.schemas.brand_dna import BrandDnaModel
from app.schemas.response_template import ResponseModel
from app.service.brand_dna.service import BrandDna

router = APIRouter()
logger = logging.getLogger()


@router.post("/seg_product")
def image2sketch(request_item: BrandDnaModel):
    """
    Create a request body with the following parameters:
    - **image_url**: URL of the image to extract from
    - **is_brand_dna**: whether to extract attributes
    Example parameters:
    {
        "image_url": "test/image2sketch/real_Dress_3200fecdc83d0c556c2bd96aedbd7fbf.jpg_Img.jpg",
        "is_brand_dna": False
    }
    """
    try:
        logger.info(f"brand dna request item is : @@@@@@:{json.dumps(request_item.dict())}")
        service = BrandDna(request_item)
        result_url = service.get_result()
    except Exception as e:
        logger.warning(f"brand dna Run Exception @@@@@@:{e}")
        raise HTTPException(status_code=404, detail=str(e))
    return ResponseModel(data=result_url)
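For reference, a minimal client call for the new route, as a sketch: the host, port, and timeout are assumptions; the "/api" prefix and payload fields come from router.py and the docstring above.

import requests

# Hypothetical host/port; payload taken from the endpoint docstring.
payload = {
    "image_url": "test/image2sketch/real_Dress_3200fecdc83d0c556c2bd96aedbd7fbf.jpg_Img.jpg",
    "is_brand_dna": False,
}
resp = requests.post("http://localhost:8000/api/seg_product", json=payload, timeout=60)
print(resp.json())  # ResponseModel wrapping the per-garment result list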

@@ -4,7 +4,7 @@ import os
 from fastapi import APIRouter, HTTPException, UploadFile, File, Form, BackgroundTasks
-from app.schemas.design import DesignModel, DesignProgressModel, ModelProgressModel, DBGConfigModel
+from app.schemas.design import DesignModel, DesignProgressModel, ModelProgressModel, DBGConfigModel, DesignStreamModel
 from app.schemas.response_template import ResponseModel
 from app.service.design.model_process_service import model_transpose
 from app.service.design_batch.service import start_design_batch_generate
@@ -18,6 +18,14 @@ logger = logging.getLogger()
 @router.post("/design")
 def design(request_data: DesignModel, background_tasks: BackgroundTasks):
     """
+    objects.items.transparent:
+        "transparent": {
+            "mask_url": "test/transparent_test/transparent_mask.png",
+            "scale": 0.1
+        },
+    empty mask_url ""  -> the single garment is transparent
+    non-empty mask_url -> the masked region is transparent
     Create a request body with the following parameters:
     Example parameters:
     {
@@ -197,7 +205,7 @@ def design(request_data: DesignModel, background_tasks: BackgroundTasks):
 @router.post("/design_v2")
-async def design_v2(request_data: DesignModel, background_tasks: BackgroundTasks):
+async def design_v2(request_data: DesignStreamModel, background_tasks: BackgroundTasks):
     """
     Create a request body with the following parameters:
     Example parameters:
@@ -445,7 +453,7 @@ async def design(file: UploadFile = File(...),
 async def save_request_file(contents, file_name):
     # create the directory for saved files (if it does not exist)
-    save_dir = os.path.join(os.getcwd(), "design_batch", "request_data")
+    save_dir = os.path.join(os.getcwd(), "service/design_batch", "request_data")
     if not os.path.exists(save_dir):
         os.makedirs(save_dir)
     # process the file
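A sketch of the transparent block the updated /design docstring describes; only the transparent fields come from the docstring, the wrapping item structure is an assumption.

# Hypothetical fragments of an objects.items entry, illustrating the two modes.
region_transparent = {
    "transparent": {
        "mask_url": "test/transparent_test/transparent_mask.png",  # non-empty: only the masked region turns transparent
        "scale": 0.1,
    },
}
garment_transparent = {
    "transparent": {
        "mask_url": "",  # empty: the whole single garment turns transparent
        "scale": 0.1,
    },
}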

@@ -3,9 +3,10 @@ import logging
 from fastapi import APIRouter, BackgroundTasks, HTTPException
-from app.schemas.generate_image import GenerateImageModel, GenerateProductImageModel, GenerateSingleLogoImageModel, GenerateRelightImageModel
+from app.schemas.generate_image import GenerateImageModel, GenerateProductImageModel, GenerateSingleLogoImageModel, GenerateRelightImageModel, GenerateMultiViewModel
 from app.schemas.response_template import ResponseModel
 from app.service.generate_image.service_generate_image import GenerateImage, infer_cancel as generate_image_infer_cancel
+from app.service.generate_image.service_generate_multi_view import GenerateMultiView, infer_cancel as generate_multi_view_cancel
 from app.service.generate_image.service_generate_product_image import GenerateProductImage, infer_cancel as generate_product_image_cancel
 from app.service.generate_image.service_generate_relight_image import GenerateRelightImage, infer_cancel as generate_relight_image_cancel
 from app.service.generate_image.service_generate_single_logo import GenerateSingleLogoImage, infer_cancel as generate_single_logo_cancel
@@ -61,6 +62,44 @@ def generate_image(tasks_id: str):
     return ResponseModel(data=data['data'])


+'''multi view'''
+
+
+@router.post("/generate_multi_view")
+def generate_multi_view(request_item: GenerateMultiViewModel, background_tasks: BackgroundTasks):
+    """
+    Create a request body with the following parameters:
+    - **tasks_id**: task id, used to cancel the generation task and to fetch the result
+    - **image_url**: MinIO or S3 URL of the front-view input image
+    Example parameters:
+    {
+        "tasks_id": "123-89",
+        "image_url": "aida-collection-element/87/Printboard/842c09cf-7297-42d9-9e6e-9c17d4a13cb5.jpg"
+    }
+    """
+    try:
+        logger.info(f"generate_multi_view request item is : @@@@@@:{json.dumps(request_item.dict())}")
+        service = GenerateMultiView(request_item)
+        background_tasks.add_task(service.get_result)
+    except Exception as e:
+        logger.warning(f"generate_multi_view Run Exception @@@@@@:{e}")
+        raise HTTPException(status_code=404, detail=str(e))
+    return ResponseModel()
+
+
+@router.get("/generate_multi_view_cancel/{tasks_id}")
+def generate_multi_view(tasks_id: str):
+    try:
+        logger.info(f"generate_cancel request item is : @@@@@@:{tasks_id}")
+        data = generate_multi_view_cancel(tasks_id)
+        logger.info(f"generate_cancel response @@@@@@:{data}")
+    except Exception as e:
+        logger.warning(f"generate_cancel Run Exception @@@@@@:{e}")
+        raise HTTPException(status_code=404, detail=str(e))
+    return ResponseModel(data=data['data'])
+
+
 '''single logo'''
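A minimal client sketch for the two new routes; the base URL and timeouts are assumptions, while the paths come from the decorators above.

import requests

BASE = "http://localhost:8000/api"  # hypothetical host and port; "/api" prefix from router.py

body = {
    "tasks_id": "123-89",
    "image_url": "aida-collection-element/87/Printboard/842c09cf-7297-42d9-9e6e-9c17d4a13cb5.jpg",
}
# enqueue the multi-view generation as a background task
requests.post(f"{BASE}/generate_multi_view", json=body, timeout=30)
# revoke it again by task id if needed
requests.get(f"{BASE}/generate_multi_view_cancel/123-89", timeout=30)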

@@ -26,7 +26,7 @@ def prompt_generation(request_data: PromptGenerationImageModel):
     """
     try:
         logger.info(f"prompt_generation request item is : @@@@@@:{request_data}")
-        data = get_translation_from_llama3("[" + request_data.text + "]")
+        data = get_translation_from_llama3(request_data.text)
         logger.info(f"prompt_generation response @@@@@@:{data}")
     except Exception as e:
         logger.warning(f"prompt_generation Run Exception @@@@@@:{e}")

@@ -1,6 +1,7 @@
 from fastapi import APIRouter

 from app.api import api_attribute_retrieve, api_query_image
+from app.api import api_brand_dna
 from app.api import api_brighten
 from app.api import api_chat_robot
 from app.api import api_design
@@ -23,4 +24,5 @@ router.include_router(api_prompt_generation.router, tags=['prompt_generation'],
 router.include_router(api_design_pre_processing.router, tags=['design_pre_processing'], prefix="/api")
 router.include_router(api_image2sketch.router, tags=['api_image2sketch'], prefix="/api")
 router.include_router(api_brighten.router, tags=['api_brighten'], prefix="/api")
 router.include_router(api_query_image.router, tags=['api_query_image'], prefix="/api")
+router.include_router(api_brand_dna.router, tags=['api_brand_dna'], prefix="/api")

@@ -4,7 +4,7 @@ import logging
 from fastapi import APIRouter
 from fastapi import HTTPException

-from app.core.config import SR_RABBITMQ_QUEUES, GI_RABBITMQ_QUEUES, GPI_RABBITMQ_QUEUES, GRI_RABBITMQ_QUEUES, OSS
+from app.core.config import SR_RABBITMQ_QUEUES, GI_RABBITMQ_QUEUES, GPI_RABBITMQ_QUEUES, GRI_RABBITMQ_QUEUES, OSS, JAVA_STREAM_API_URL
 from app.schemas.response_template import ResponseModel

 logger = logging.getLogger()
@@ -18,6 +18,7 @@ def test(id: int):
         "GI_RABBITMQ_QUEUES": GI_RABBITMQ_QUEUES,
         "GPI_RABBITMQ_QUEUES": GPI_RABBITMQ_QUEUES,
         "GRI_RABBITMQ_QUEUES": GRI_RABBITMQ_QUEUES,
+        "JAVA_STREAM_API_URL": JAVA_STREAM_API_URL,
         "local_oss_server": OSS
     }
     logger.info(json.dumps(data))

@@ -34,6 +34,9 @@ else:
     RABBITMQ_ENV = "-dev"  # dev environment
     # RABBITMQ_ENV = "-local"  # local test environment

+JAVA_STREAM_API_URL = os.getenv("JAVA_STREAM_API_URL", "https://api.aida.com.hk/api/third/party/receiveDesignResults")
+
 settings = Settings()

 # MinIO configuration
@@ -106,6 +109,12 @@ FAST_GI_MODEL_NAME = 'stable_diffusion_xl'
 GI_MODEL_URL = '10.1.1.240:10061'
 GI_MODEL_NAME = 'flux'

+GMV_MODEL_URL = '10.1.1.243:10081'
+GMV_MODEL_NAME = 'multi_view'
+GMV_RABBITMQ_QUEUES = os.getenv("GMV_RABBITMQ_QUEUES", f"GenerateMultiView{RABBITMQ_ENV}")
+
 GI_MINIO_BUCKET = "aida-users"
 GI_RABBITMQ_QUEUES = os.getenv("GI_RABBITMQ_QUEUES", f"GenerateImage{RABBITMQ_ENV}")
 GI_SYS_IMAGE_URL = "aida-sys-image/generate_image/white_image.jpg"

app/schemas/brand_dna.py (new file, 6 lines)
@@ -0,0 +1,6 @@
from pydantic import BaseModel


class BrandDnaModel(BaseModel):
    image_url: str
    is_brand_dna: bool

@@ -6,6 +6,12 @@ class DesignModel(BaseModel):
     process_id: str


+class DesignStreamModel(BaseModel):
+    objects: list[dict]
+    process_id: str
+    requestId: str
+
+
 class DesignProgressModel(BaseModel):
     process_id: str

@@ -1,6 +1,11 @@
 from pydantic import BaseModel


+class GenerateMultiViewModel(BaseModel):
+    tasks_id: str
+    image_url: str
+
+
 class GenerateImageModel(BaseModel):
     tasks_id: str
     prompt: str

app/service/brand_dna/service.py (new file, 335 lines)
@@ -0,0 +1,335 @@
import logging

import cv2
import mmcv
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import tritonclient.http as httpclient
from minio import Minio

from app.core.config import MINIO_URL, MINIO_ACCESS, MINIO_SECRET, MINIO_SECURE, DESIGN_MODEL_URL
from app.schemas.brand_dna import BrandDnaModel
from app.service.attribute.config import local_debug_const
from app.service.utils.generate_uuid import generate_uuid
from app.service.utils.new_oss_client import oss_upload_image, oss_get_image

logger = logging.getLogger()
minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)


class BrandDna:
    def __init__(self, request_item):
        self.sketch_bucket = "test"
        self.image_url = request_item.image_url
        self.is_brand_dna = request_item.is_brand_dna
        # self.attr_type = pd.read_csv(CATEGORY_PATH)
        self.attr_type = pd.read_csv(r"E:\workspace\trinity_client_aida\app\service\attribute\config\descriptor\category\category_dis.csv")
        self.att_client = httpclient.InferenceServerClient(url=DESIGN_MODEL_URL)
        self.seg_client = httpclient.InferenceServerClient(url='10.1.1.243:30000')
        # self.const = const
        self.const = local_debug_const

    # get the result
    def get_result(self):
        mask, image = self.get_seg_mask()
        cv2.imshow("", image)
        cv2.waitKey(0)
        height, width, channels = image.shape
        result_dict = []
        white_img = np.ones((height, width, channels), dtype=image.dtype) * 255
        mask_image = np.zeros((height, width, 3))
        for value in np.unique(mask):
            if value == 1:
                outwear_img = white_img.copy()
                outwear_mask_img = mask_image.copy()
                outwear_img[mask == value] = image[mask == value]
                outwear_mask_img[mask == value] = [0, 0, 255]
                cv2.imshow("", outwear_img)
                cv2.waitKey(0)
                # input img after preprocessing
                preprocess_img = self.category_preprocess(outwear_img)
                # category detection
                category = self.recognition_category(preprocess_img)
                if category == 'Trousers' or category == 'Skirt':
                    male_category = 'Bottoms'
                elif category == 'Blouse' or category == 'Dress':
                    male_category = 'Tops'
                else:
                    male_category = 'Outwear'
                attribute = {}
                mask_url = ""
                img_url = ""
                # attribute detection
                if self.is_brand_dna:
                    attribute = self.get_recognition_attribute_result(category, preprocess_img)
                else:
                    img_url = self.put_image(outwear_img, f"img/{generate_uuid()}")
                    mask_url = self.put_image(outwear_mask_img, f"mask/{generate_uuid()}")
                result_dict.append(
                    {
                        'category_female': category,
                        'category_male': male_category,
                        'mask_url': mask_url,
                        'img_url': img_url,
                        'attribute': attribute
                    }
                )
            if value == 2:
                tops_img = white_img.copy()
                tops_mask_img = mask_image.copy()
                tops_img[mask == value] = image[mask == value]
                tops_mask_img[mask == value] = [0, 0, 255]
                cv2.imshow("", tops_img)
                cv2.waitKey(0)
                # input img after preprocessing
                preprocess_img = self.category_preprocess(tops_img)
                # category detection
                category = self.recognition_category(preprocess_img)
                if category == 'Trousers' or category == 'Skirt':
                    male_category = 'Bottoms'
                elif category == 'Blouse' or category == 'Dress':
                    male_category = 'Tops'
                else:
                    male_category = 'Outwear'
                # attribute detection
                attribute = {}
                img_url = ""
                mask_url = ""
                # attribute detection
                if self.is_brand_dna:
                    attribute = self.get_recognition_attribute_result(category, preprocess_img)
                else:
                    mask_url = self.put_image(tops_mask_img, f"mask/{generate_uuid()}")
                    img_url = self.put_image(tops_img, f"img/{generate_uuid()}")
                result_dict.append(
                    {
                        'category_female': category,
                        'category_male': male_category,
                        'mask_url': mask_url,
                        'img_url': img_url,
                        'attribute': attribute
                    }
                )
            if value == 3:
                bottoms_img = white_img.copy()
                bottoms_mask_img = mask_image.copy()
                bottoms_img[mask == value] = image[mask == value]
                bottoms_mask_img[mask == value] = [0, 0, 255]
                cv2.imshow("", bottoms_img)
                cv2.waitKey(0)
                # input img after preprocessing
                preprocess_img = self.category_preprocess(bottoms_img)
                # category detection
                category = self.recognition_category(preprocess_img)
                if category == 'Trousers' or category == 'Skirt':
                    male_category = 'Bottoms'
                elif category == 'Blouse' or category == 'Dress':
                    male_category = 'Tops'
                else:
                    male_category = 'Outwear'
                attribute = {}
                img_url = ""
                mask_url = ""
                # attribute detection
                if self.is_brand_dna:
                    attribute = self.get_recognition_attribute_result(category, preprocess_img)
                else:
                    img_url = self.put_image(bottoms_img, f"img/{generate_uuid()}")
                    mask_url = self.put_image(bottoms_mask_img, f"mask/{generate_uuid()}")
                result_dict.append(
                    {
                        'category_female': category,
                        'category_male': male_category,
                        'mask_url': mask_url,
                        'img_url': img_url,
                        'attribute': attribute
                    }
                )
        return result_dict

    # get the product mask
    def get_seg_mask(self):
        input_image = self.get_image()
        input_img, ori_shape = self.seg_product_preprocess(input_image)
        transformed_img = input_img.astype(np.float32)
        inputs = [httpclient.InferInput(f"seg_input__0", transformed_img.shape, datatype="FP32")]
        inputs[0].set_data_from_numpy(transformed_img, binary_data=True)
        outputs = [httpclient.InferRequestedOutput(f"seg_output__0", binary_data=True)]
        results = self.seg_client.infer(model_name=f"seg_product", inputs=inputs, outputs=outputs)
        inference_output1 = results.as_numpy("seg_output__0")
        mask = self.product_postprocess(inference_output1, ori_shape)[0]
        return mask, input_image

    # fetch the image
    def get_image(self):
        image = oss_get_image(oss_client=minio_client, bucket=self.image_url.split('/')[0], object_name=self.image_url[self.image_url.find('/') + 1:], data_type="cv2")
        # convert it to a color image
        if len(image.shape) == 3 and image.shape[2] == 4:
            image = cv2.cvtColor(image, cv2.COLOR_BGRA2BGR)
        elif len(image.shape) == 2:
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
        return image
        # return cv2.imread(self.image_url)

    # upload the image
    def put_image(self, image, object_name):
        try:
            image_bytes = cv2.imencode('.jpg', image)[1].tobytes()
            oss_upload_image(oss_client=minio_client, bucket=self.sketch_bucket, object_name=f"{object_name}.jpg", image_bytes=image_bytes)
            return f"{self.sketch_bucket}/{object_name}.jpg"
        except Exception as e:
            logger.warning(e)

    # garment segmentation preprocessing
    @staticmethod
    def seg_product_preprocess(image):
        img = mmcv.imread(image)
        ori_shape = img.shape[:2]
        img_scale_w, img_scale_h = ori_shape
        if ori_shape[0] > 1024:
            img_scale_w = 1024
        if ori_shape[1] > 1024:
            img_scale_h = 1024
        # if either side of the image is larger than 1024, resize it to 1024
        if ori_shape != (img_scale_w, img_scale_h):
            # mmcv.imresize(img, img_scale_h, img_scale_w)  # old code, a lesson learned: h and w were swapped
            img = cv2.resize(img, (img_scale_h, img_scale_w))
        img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
        preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0)
        return preprocessed_img, ori_shape

    # category detection postprocessing
    @staticmethod
    def product_postprocess(output, ori_shape):
        seg_logit = F.interpolate(torch.tensor(output).float(), size=ori_shape, scale_factor=None, mode='bilinear', align_corners=False)
        seg_pred = seg_logit.cpu().numpy()
        return seg_pred[0]

    # category detection model preprocessing
    @staticmethod
    def category_preprocess(img):
        img = mmcv.imread(img)
        # ori_shape = img.shape[:2]
        img_scale = (224, 224)
        img = cv2.resize(img, img_scale)
        img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
        preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0)
        return preprocessed_img

    # category detection
    def recognition_category(self, image):
        inputs = [
            httpclient.InferInput("input__0", image.shape, datatype="FP32")
        ]
        inputs[0].set_data_from_numpy(image, binary_data=True)
        results = self.att_client.infer(model_name="attr_retrieve_category", inputs=inputs)
        inference_output = torch.from_numpy(results.as_numpy(f'output__0'))
        scores = inference_output.detach().numpy()
        colattr = list(self.attr_type['labelName'])
        maxsc = np.max(scores[0][:5])
        indexs = np.argwhere(scores == maxsc)[:, 1]
        return colattr[indexs[0]]

    # attribute detection
    def recognition_attribute(self, model_name, description, image):
        attr_type = pd.read_csv(description)
        inputs = [
            httpclient.InferInput("input__0", image.shape, datatype="FP32")
        ]
        inputs[0].set_data_from_numpy(image, binary_data=True)
        results = self.att_client.infer(model_name=model_name, inputs=inputs)
        inference_output = torch.from_numpy(results.as_numpy(f"output__0"))
        scores = inference_output.detach().numpy()
        colattr = list(attr_type['labelName'])
        task = description.split('/')[-1][:-4]
        maxsc = np.max(scores[0][:5])
        indexs = np.argwhere(scores == maxsc)[:, 1]
        attr = {
            task: []
        }
        for i in range(len(indexs)):
            atr = colattr[indexs[i]]
            attr[task].append(atr)
        return attr

    # get the attribute detection result
    def get_recognition_attribute_result(self, category, input_img):
        if category == "Blouse":
            attr_dict = {}
            for i in range(len(self.const.top_description_list)):
                attr_description = self.const.top_description_list[i]
                attr_model_path = self.const.top_model_list[i]
                present_dict = self.recognition_attribute(attr_model_path, attr_description, input_img)
                attr_dict = self.merge(attr_dict, present_dict)
        elif category == 'Trousers' or category == "Skirt":
            attr_dict = {}
            for i in range(len(self.const.bottom_description_list)):
                attr_description = self.const.bottom_description_list[i]
                attr_model_path = self.const.bottom_model_list[i]
                present_dict = self.recognition_attribute(attr_model_path, attr_description, input_img)
                attr_dict = self.merge(attr_dict, present_dict)
        elif category == 'Dress':
            attr_dict = {}
            for i in range(len(self.const.dress_description_list)):
                attr_description = self.const.dress_description_list[i]
                attr_model_path = self.const.dress_model_list[i]
                present_dict = self.recognition_attribute(attr_model_path, attr_description, input_img)
                attr_dict = self.merge(attr_dict, present_dict)
        elif category == 'Outwear':
            attr_dict = {}
            for i in range(len(self.const.outwear_description_list)):
                attr_description = self.const.outwear_description_list[i]
                attr_model_path = self.const.outwear_model_list[i]
                present_dict = self.recognition_attribute(attr_model_path, attr_description, input_img)
                attr_dict = self.merge(attr_dict, present_dict)
        else:
            attr_dict = {}
        return attr_dict

    @staticmethod
    def merge(dict1, dict2):
        res = {**dict1, **dict2}
        return res


if __name__ == '__main__':
    # for path in os.listdir('./test_img'):
    #     img_path = os.path.join(r'./test_img', path)
    #     request_item = BrandDnaModel(
    #         image_url=img_path,
    #         is_brand_dna=True
    #     )
    #     service = BrandDna(request_item)
    #     result_url = service.get_result()
    #     print(result_url)
    request_item = BrandDnaModel(
        image_url="aida-users/60/product_image/07cb5d5d-5022-44cc-b0d3-cc986cfebad1-2-60.png",
        is_brand_dna=True
    )
    service = BrandDna(request_item)
    result_url = service.get_result()
    print(result_url)
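The three value branches in get_result differ only in variable names; a minimal sketch (not part of the commit) of the same logic as a single loop over the mask labels, assuming the class methods defined above:

# Sketch: one loop replacing the three duplicated branches in get_result.
def build_results(self, mask, image, white_img, mask_image):
    results = []
    for value in (1, 2, 3):  # outwear, tops, bottoms mask labels
        if value not in np.unique(mask):
            continue
        part_img = white_img.copy()
        part_mask_img = mask_image.copy()
        part_img[mask == value] = image[mask == value]
        part_mask_img[mask == value] = [0, 0, 255]
        preprocess_img = self.category_preprocess(part_img)
        category = self.recognition_category(preprocess_img)
        if category in ('Trousers', 'Skirt'):
            male_category = 'Bottoms'
        elif category in ('Blouse', 'Dress'):
            male_category = 'Tops'
        else:
            male_category = 'Outwear'
        attribute, img_url, mask_url = {}, "", ""
        if self.is_brand_dna:
            attribute = self.get_recognition_attribute_result(category, preprocess_img)
        else:
            img_url = self.put_image(part_img, f"img/{generate_uuid()}")
            mask_url = self.put_image(part_mask_img, f"mask/{generate_uuid()}")
        results.append({
            'category_female': category,
            'category_male': male_category,
            'mask_url': mask_url,
            'img_url': img_url,
            'attribute': attribute,
        })
    return results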

@@ -69,3 +69,5 @@ TOOLS_FUNCTIONS_SUFFIX = (
 )

 TUTORIAL_TOOL_RETURN = "Commencing the systematic tutorial guide now."
+
+GET_LANGUAGE_PREFIX = "Please identify the language. Only output the language name"

@@ -1,4 +1,5 @@
 import json
+import logging

 from dashscope import Generation
 from retry import retry
@@ -7,7 +8,8 @@ from urllib3.exceptions import NewConnectionError
 from app.core.config import *
 from app.service.chat_robot.script.callbacks.qwen_callback_handler import QWenCallbackHandler
 from app.service.chat_robot.script.database import CustomDatabase
-from app.service.chat_robot.script.prompt import FASHION_CHAT_BOT_PREFIX, TOOLS_FUNCTIONS_SUFFIX, TUTORIAL_TOOL_RETURN
+from app.service.chat_robot.script.prompt import FASHION_CHAT_BOT_PREFIX, TOOLS_FUNCTIONS_SUFFIX, TUTORIAL_TOOL_RETURN, \
+    GET_LANGUAGE_PREFIX
 from app.service.search_image_with_text.service import query

 get_database_table_description = "Input is an empty string, output is a comma separated list of tables in the database."
@@ -159,10 +161,11 @@ def search_from_internet(message):
         model='qwen-turbo',
         api_key=QWEN_API_KEY,
         messages=message,
-        tools=tools,
+        prompt='The output must be in English.Keep the final result under 200 words.',
+        # tools=tools,
         # seed=random.randint(1, 10000),  # set the random seed; if unset, the seed defaults to 1234
-        result_format='message',  # set the output to message format
-        enable_search='True'
+        # result_format='message',  # set the output to message format
+        # enable_search='True'
     )
     return response
@@ -198,14 +201,9 @@ def get_response(messages):
 def call_with_messages(message, gender):
+    global tool_info
     user_input = message
     print('\n')
-    # messages = [
-    #     {
-    #         "content": input('请输入:'),  # example questions: "What time is it?" "What time is it in an hour?" "How is the weather in Beijing?"
-    #         "role": "user"
-    #     }
-    # ]
     messages = [
         {
@@ -223,14 +221,10 @@ def call_with_messages(message, gender):
         }
     ]

-    # first round of model invocation
-    # first_response = get_response(messages)
-    # assistant_output = first_response.output.choices[0].message
-    # print(f"\n大模型第一轮输出信息:{first_response}\n")
-    # messages.append(assistant_output)
     flag = True
     count = 1
-    result_content = "我是一个时尚AI助手,请问有什么可以帮您?"
+    # result_content = "我是一个时尚AI助手,请问有什么可以帮您?"
+    result_content = "I am a fashion AI assistant, how can I help you?"
    response_type = "chat"

     while flag and count <= 3:
@@ -244,44 +238,53 @@ def call_with_messages(message, gender):
                 print(f"最终答案:{assistant_output.content}")  # return the model's reply directly; choose the final no-tool reply according to your business needs
                 result_content = assistant_output.content
                 break
-            # if the model selected the search_from_internet tool
-            # elif assistant_output.tool_calls[0]['function']['name'] == 'search_from_internet':
-            #     tool_info = {"name": "search_from_internet", "role": "tool"}
-            #     user_input = json.loads(assistant_output.tool_calls[0]['function']['arguments'])['user_input']
-            #     tool_info['content'] = search_from_internet(user_input)
-            # if the model selected the get_database_table tool
-            elif assistant_output.tool_calls[0]['function']['name'] == 'get_database_table':
-                tool_info = {"name": "get_database_table", "role": "tool", 'content': get_database_table()}
-            # if the model selected the get_table_info tool
-            elif assistant_output.tool_calls[0]['function']['name'] == 'get_table_info':
-                tool_info = {"name": "get_table_info", "role": "tool"}
-                table_names = json.loads(assistant_output.tool_calls[0]['function']['arguments'])['table_names']
-                tool_info['content'] = get_table_info(table_names)
-            # if the model selected the query_database tool
-            elif assistant_output.tool_calls[0]['function']['name'] == 'query_database':
-                tool_info = {"name": "query_database", "role": "tool"}
-                sql_string = json.loads(assistant_output.tool_calls[0]['function']['arguments'])['sql_string']
-                tool_info['content'] = query_database(sql_string)
-                flag = False
-                result_content = tool_info['content']
-                response_type = "image"
+            # if the model selected the internet_search tool
+            elif assistant_output.tool_calls[0]['function']['name'] == 'internet_search':
+                tool_info = {"name": "search_from_internet", "role": "tool"}
+                content = json.loads(assistant_output.tool_calls[0]['function']['arguments'])
+                message = [
+                    {'role': 'assistant', 'content': content['query']}
+                ]
+                tool_info['content'] = search_from_internet(message)
+                flag = False
+                result_content = tool_info['content'].output.text
+            # if the model selected the get_database_table tool
+            # elif assistant_output.tool_calls[0]['function']['name'] == 'get_database_table':
+            #     tool_info = {"name": "get_database_table", "role": "tool", 'content': get_database_table()}
+            # # if the model selected the get_table_info tool
+            # elif assistant_output.tool_calls[0]['function']['name'] == 'get_table_info':
+            #     tool_info = {"name": "get_table_info", "role": "tool"}
+            #     table_names = json.loads(assistant_output.tool_calls[0]['function']['arguments'])['table_names']
+            #     tool_info['content'] = get_table_info(table_names)
+            # # if the model selected the query_database tool
+            # elif assistant_output.tool_calls[0]['function']['name'] == 'query_database':
+            #     tool_info = {"name": "query_database", "role": "tool"}
+            #     sql_string = json.loads(assistant_output.tool_calls[0]['function']['arguments'])['sql_string']
+            #     tool_info['content'] = query_database(sql_string)
+            #     flag = False
+            #     result_content = tool_info['content']
+            #     response_type = "image"
             elif assistant_output.tool_calls[0]['function']['name'] == 'tutorial_tool':
                 tool_info = {"name": "tutorial_tool", "role": "tool", 'content': tutorial_tool()}
                 flag = False
                 result_content = tool_info['content']
             elif assistant_output.tool_calls[0]['function']['name'] == 'get_image_from_vector_db':
+                content = json.loads(assistant_output.tool_calls[0]['function']['arguments'])
                 tool_info = {"name": "get_image_from_vector_db", "role": "tool",
-                             'content': get_image_from_vector_db(gender, user_input)}
+                             'content': get_image_from_vector_db(gender, content['parameters']['content'])}
                 flag = False
                 result_content = tool_info['content']
                 response_type = "image"
+            else:
+                tool_info = {"name": assistant_output.tool_calls[0]['function']['name'], 'content': 'null'}
+                logging.info(assistant_output.tool_calls[0]['function']['name'] + "(unknown tools)")
+                flag = False
             print(f"工具输出信息:{tool_info['content']}\n")
             messages.append(tool_info)
             count += 1
-    final_output = {"output": result_content}
-    final_output["response_type"] = response_type
+    final_output = {"output": result_content, "response_type": response_type}
     QWenCallbackHandler.on_chain_end(qwen, final_output)

     # second round of model invocation, summarizing the tool output
@@ -298,5 +301,23 @@ def tutorial_tool():
     return TUTORIAL_TOOL_RETURN


+def get_language(message: str) -> str:
+    messages = [
+        {
+            "content": message,  # user message
+            "role": "user"
+        },
+        {
+            "content": GET_LANGUAGE_PREFIX,  # ai message
+            "role": "assistant"
+        }
+    ]
+    first_response = get_response(messages)
+    assistant_output = first_response.output.choices[0].message.content
+    logging.info(f"大模型输出信息:{first_response}\n判断用户输入的语言为:{assistant_output}")
+    return assistant_output
+
+
 if __name__ == '__main__':
-    call_with_messages()
+    get_language("")
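A usage sketch for the new helper; the input string and the exact reply format are assumptions, get_response wraps the qwen call defined earlier in this file.

# Hypothetical call: the helper returns whatever language name the model outputs.
language = get_language("Bonjour, pouvez-vous m'aider ?")
print(language)  # e.g. "French", depending on the model's reply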

@@ -53,7 +53,7 @@ class Segmentation(object):
         file_path = f"{SEG_CACHE_PATH}{image_id}.npy"
         try:
             np.save(file_path, seg_result)
-            logger.info(f"保存成功 {os.path.abspath(file_path)}")
+            logger.debug(f"保存成功 {os.path.abspath(file_path)}")
         except Exception as e:
             logger.error(f"保存失败: {e}")
@@ -64,7 +64,7 @@ class Segmentation(object):
             seg_result = np.load(file_path)
             return True, seg_result
         except FileNotFoundError:
-            logger.warning("文件不存在")
+            # logger.warning("文件不存在")
             return False, None
         except Exception as e:
             logger.error(f"加载失败: {e}")

@@ -12,7 +12,7 @@ from app.service.design_batch.utils.save_json import oss_upload_json
 from app.service.design_batch.utils.synthesis_item import update_base_size_priority, synthesis, synthesis_single

 id_lock = threading.Lock()
-celery_app = Celery('tasks', broker='amqp://guest:guest@10.1.2.213:5672//', backend='rpc://')
+celery_app = Celery('tasks', broker=f'amqp://rabbit:123456@18.167.251.121:5672//', backend='rpc://', BROKER_CONNECTION_RETRY_ON_STARTUP=True)
 celery_app.conf.worker_log_format = '%(asctime)s %(filename)s [line:%(lineno)d] %(levelname)s %(message)s'
 celery_app.conf.worker_hijack_root_logger = False
 logging.getLogger('pika').setLevel(logging.WARNING)
@@ -120,7 +120,7 @@ def batch_design(objects_data, tasks_id, json_name):
     for t in threads:
         t.join()
+    logger.debug(object_response)
     oss_upload_json(minio_client, object_response, json_name)
     publish_status(tasks_id, "ok", json_name)
     return object_response

@@ -51,19 +51,19 @@ class Segmentation:
         file_path = f"seg_cache/{image_id}.npy"
         try:
             np.save(file_path, seg_result)
-            logger.info(f"保存成功 {os.path.abspath(file_path)}")
+            logger.debug(f"保存成功 {os.path.abspath(file_path)}")
         except Exception as e:
             logger.error(f"保存失败: {e}")

     @staticmethod
     def load_seg_result(image_id):
         file_path = f"seg_cache/{image_id}.npy"
-        logger.info(f"load seg file name is :{SEG_CACHE_PATH}{image_id}.npy")
+        # logger.info(f"load seg file name is :{SEG_CACHE_PATH}{image_id}.npy")
         try:
             seg_result = np.load(file_path)
             return True, seg_result
         except FileNotFoundError:
-            logger.warning("文件不存在")
+            # logger.warning("文件不存在")
             return False, None
         except Exception as e:
             logger.error(f"加载失败: {e}")

@@ -5,7 +5,7 @@ from app.service.design_batch.utils.MQ import publish_status

 async def start_design_batch_generate(data, file):
-    generate_clothes_task = batch_design.delay(json.loads(file.decode())['objects'], data.total, data.tasks_id)
+    generate_clothes_task = batch_design.delay(json.loads(file.decode())['objects'], data.tasks_id, data.file_name)
     print(generate_clothes_task)
     publish_status(data.tasks_id, "0/100", "")
     return {"task_id": data.tasks_id}

@@ -157,6 +157,6 @@ if __name__ == '__main__':
         ],
         "process_id": "83"
     }
-    task_id = 1
+    task_id = 10086
     json_name = "test.json"
     batch_design.delay(data['objects'], task_id, json_name)

@@ -2,9 +2,11 @@ import json

 import pika

+from app.core.config import RABBITMQ_PARAMS
+

 def publish_status(task_id, progress, result):
-    connection = pika.BlockingConnection(pika.ConnectionParameters('10.1.2.213'))
+    connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
     channel = connection.channel()
     channel.queue_declare(queue='DesignBatch', durable=True)
     message = {'task_id': task_id, 'progress': progress, "result": result}
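publish_status now unpacks RABBITMQ_PARAMS from app.core.config into pika.ConnectionParameters; a minimal sketch of the shape that dict could take (host, port, and credentials here are assumptions, the real values live in config.py):

import pika

# Hypothetical shape of RABBITMQ_PARAMS; actual values come from app/core/config.py.
RABBITMQ_PARAMS = {
    "host": "10.1.2.213",
    "port": 5672,
    "credentials": pika.PlainCredentials("guest", "guest"),
}
connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))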

@@ -1,13 +1,21 @@
+import io
 import json
 import logging
+import os

 logger = logging.getLogger()


 def oss_upload_json(oss_client, json_data, object_name):
     try:
-        with open(f"app/service/design_batch/response_json/{object_name}", 'w') as file:
+        save_dir = os.path.join(os.getcwd(), "service/design_batch", "response_data")
+        if not os.path.exists(save_dir):
+            os.makedirs(save_dir)
+        # process the file
+        file_path = os.path.join(save_dir, object_name)
+        with open(file_path, 'w') as file:
             json.dump(json_data, file, indent=4)
-        oss_client.fput_object("test", object_name, f"app/service/design_batch/response_json/{object_name}")
+        json_bytes = json.dumps(json_data).encode('utf-8')
+        oss_client.put_object("test", object_name, io.BytesIO(json_bytes), length=len(json_bytes), content_type="application/json")
     except Exception as e:
         logger.warning(str(e))

@@ -139,6 +139,7 @@ def design_generate(request_data):
 @RunTime
 def design_generate_v2(request_data):
     objects_data = request_data.dict()['objects']
+    request_id = request_data.requestId
     threads = []

     def process_object(step, object):
@@ -146,7 +147,7 @@ def design_generate_v2(request_data):
         items_response = {
             'layers': [],
             'objectSign': object['objectSign'] if 'objectSign' in object.keys() else "",
-            'requestId': object['requestId'] if 'requestId' in object.keys() else ""
+            'requestId': request_id
         }
         if basic['single_overall'] == "overall":
             item_results = []
@@ -197,9 +198,8 @@ def design_generate_v2(request_data):
                 'pattern_image_url': item_result['pattern_image_url'] if 'pattern_image_url' in item_result.keys() else None,
             })
             items_response['synthesis_url'] = synthesis_single(item_result['front_image'], item_result['back_image'])
         # send the result to the Java side
-        url = "https://3998-117-143-125-51.ngrok-free.app/api/third/party/receiveDesignResults"
+        url = JAVA_STREAM_API_URL
         headers = {
             'Accept': "*/*",
             'Accept-Encoding': "gzip, deflate, br",
@@ -207,11 +207,11 @@ def design_generate_v2(request_data):
             'Connection': "keep-alive",
             'Content-Type': "application/json"
         }
+        # logger.info(items_response)
         response = post_request(url, json_data=items_response, headers=headers)
         if response:
             # print the result
             logger.info(response.text)
-        logger.info(items_response)

     for step, object in enumerate(objects_data):
         t = threading.Thread(target=process_object, args=(step, object))

@@ -66,19 +66,19 @@ class Segmentation:
         file_path = f"{SEG_CACHE_PATH}{image_id}.npy"
         try:
             np.save(file_path, seg_result)
-            logger.info(f"保存成功 {os.path.abspath(file_path)}")
+            logger.debug(f"保存成功 {os.path.abspath(file_path)}")
         except Exception as e:
             logger.error(f"保存失败: {e}")

     @staticmethod
     def load_seg_result(image_id):
         file_path = f"{SEG_CACHE_PATH}{image_id}.npy"
-        logger.info(f"load seg file name is :{SEG_CACHE_PATH}{image_id}.npy")
+        # logger.info(f"load seg file name is :{SEG_CACHE_PATH}{image_id}.npy")
         try:
             seg_result = np.load(file_path)
             return True, seg_result
         except FileNotFoundError:
-            logger.warning("文件不存在")
+            # logger.warning("文件不存在")
             return False, None
         except Exception as e:
             logger.error(f"加载失败: {e}")

@@ -25,6 +25,7 @@ from app.core.config import *
 def keypoint_preprocess(img_path):
     img = mmcv.imread(img_path)
+    img = cv2.copyMakeBorder(img, 25, 25, 25, 25, cv2.BORDER_CONSTANT, value=[255, 255, 255])
     img_scale = (256, 256)
     h, w = img.shape[:2]
     img = cv2.resize(img, img_scale)
@@ -62,7 +63,11 @@ def keypoint_postprocess(output, scale_factor):
     scale_matrix = np.diag(scale_factor)
     nan = np.isinf(scale_matrix)
     scale_matrix[nan] = 0
-    return np.ceil(np.dot(segment_result, scale_matrix) * 4)
+    # apply the scale factor
+    scaled_result = np.ceil(np.dot(segment_result, scale_matrix) * 4)
+    # compensate for the border offset
+    compensated_result = scaled_result - 25
+    return compensated_result
 """

@@ -266,7 +266,7 @@ class DesignPreprocessing:
             seg_result = np.load(file_path)
             return True, seg_result
         except FileNotFoundError:
-            logging.info("文件不存在")
+            # logging.info("文件不存在")
             return False, None
         except Exception as e:
             logging.warning(f"加载失败: {e}")
@@ -277,7 +277,7 @@ class DesignPreprocessing:
         file_path = f"{SEG_CACHE_PATH}{image_id}.npy"
         try:
             np.save(file_path, seg_result)
-            logging.info(f"保存成功,{os.path.abspath(file_path)}")
+            logging.debug(f"保存成功,{os.path.abspath(file_path)}")
         except Exception as e:
             logging.warning(f"保存失败: {e}")

app/service/generate_image/service_generate_multi_view.py (new file, 126 lines)
@@ -0,0 +1,126 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project :trinity_client
@File    :service_att_recognition.py
@Author  :周成融
@Date    :2023/7/26 12:01:05
@detail  :
"""
import json
import logging
import time

import numpy as np
import redis
import tritonclient.grpc as grpcclient

from app.core.config import *
from app.schemas.generate_image import GenerateMultiViewModel
from app.service.generate_image.utils.upload_sd_image import upload_png_sd
from app.service.utils.oss_client import oss_get_image

logger = logging.getLogger()


class GenerateMultiView:
    def __init__(self, request_data):
        if DEBUG is False:
            self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
            self.channel = self.connection.channel()
        # self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
        # self.channel = self.connection.channel()
        # self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
        self.grpc_client = grpcclient.InferenceServerClient(url=GMV_MODEL_URL)
        self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
        self.image = self.get_image(request_data.image_url)
        self.tasks_id = request_data.tasks_id
        self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:]
        self.generate_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'image_url': ''}
        self.redis_client.set(self.tasks_id, json.dumps(self.generate_data))
        self.redis_client.expire(self.tasks_id, 600)

    def get_image(self, image_url):
        try:
            image = oss_get_image(bucket=image_url.split('/')[0], object_name=image_url[image_url.find('/') + 1:], data_type="PIL")
            return image
        except Exception as e:
            logger.error(e)

    def callback(self, result, error):
        if error:
            self.generate_data['status'] = "FAILURE"
            self.generate_data['message'] = str(error)
            # self.generate_data['data'] = str(error)
            self.redis_client.set(self.tasks_id, json.dumps(self.generate_data))
        else:
            # convert the PIL image to a numpy array
            images = result.as_numpy("generated_image")
            # for id, img in enumerate(images):
            #     cv2.imwrite(f"{id}.png", img)
            # image_url = ""
            image_url = upload_png_sd(images[6], user_id=self.user_id, category="multi_view", file_name=f"{self.tasks_id}.png")
            # logger.info(f"upload image SUCCESS {image_url}")
            self.generate_data['status'] = "SUCCESS"
            self.generate_data['message'] = "success"
            self.generate_data['image_url'] = str(image_url)
            self.redis_client.set(self.tasks_id, json.dumps(self.generate_data))

    def read_tasks_status(self):
        status_data = self.redis_client.get(self.tasks_id)
        return json.loads(status_data), status_data

    def get_result(self):
        try:
            images = [np.array(self.image).astype(np.uint8)] * 1
            image_obj = np.array(images, dtype=np.uint8)
            input_image = grpcclient.InferInput("input_image", image_obj.shape, "UINT8")
            input_image.set_data_from_numpy(image_obj)
            inputs = [input_image]
            ctx = self.grpc_client.async_infer(model_name=GMV_MODEL_NAME, inputs=inputs, callback=self.callback)
            time_out = 600
            generate_data = None
            while time_out > 0:
                generate_data, _ = self.read_tasks_status()
                if generate_data['status'] in ["REVOKED", "FAILURE"]:
                    ctx.cancel()
                    break
                elif generate_data['status'] == "SUCCESS":
                    break
                time_out -= 1
                time.sleep(0.1)
            return generate_data
        except Exception as e:
            self.generate_data['status'] = "FAILURE"
            self.generate_data['message'] = str(e)
            self.redis_client.set(self.tasks_id, json.dumps(self.generate_data))
            raise Exception(str(e))
        finally:
            dict_generate_data, str_generate_data = self.read_tasks_status()
            if DEBUG is False:
                self.channel.basic_publish(exchange='', routing_key=GMV_RABBITMQ_QUEUES, body=str_generate_data)
                # self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=str_generate_data)
                logger.info(f" [x] Sent {json.dumps(dict_generate_data, indent=4)}")


def infer_cancel(tasks_id):
    redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
    data = {'tasks_id': tasks_id, 'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'}
    generate_data = json.dumps(data)
    redis_client.set(tasks_id, generate_data)
    return data


if __name__ == '__main__':
    rd = GenerateMultiViewModel(
        tasks_id="123-89",
        image_url="aida-sys-image/images/female/outwear/0628000123.jpg",
    )
    server = GenerateMultiView(rd)
    print(server.get_result())

@@ -1,4 +1,196 @@
-#!/usr/bin/env python
+# #!/usr/bin/env python
+# # -*- coding: UTF-8 -*-
+# """
+# @Project :trinity_client
+# @File    :service_att_recognition.py
+# @Author  :周成融
+# @Date    :2023/7/26 12:01:05
+# @detail  :
+# """
+# import json
+# import logging
+# import time
+#
+# import cv2
+# import numpy as np
+# import redis
+# import tritonclient.grpc as grpcclient
+# from PIL import Image
+# from tritonclient.utils import np_to_triton_dtype
+#
+# from app.core.config import *
+# from app.schemas.generate_image import GenerateProductImageModel
+# from app.service.generate_image.utils.upload_sd_image import upload_SDXL_image
+# from app.service.utils.oss_client import oss_get_image
+#
+# logger = logging.getLogger()
+#
+#
+# class GenerateProductImage:
+#     def __init__(self, request_data):
+#         if DEBUG is False:
+#             self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
+#             self.channel = self.connection.channel()
+#         # self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
+#         # self.channel = self.connection.channel()
+#         # self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
+#         self.grpc_client = grpcclient.InferenceServerClient(url=GPI_MODEL_URL)
+#         self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
+#         self.category = "product_image"
+#         self.image_strength = request_data.image_strength
+#         self.batch_size = 1
+#         self.product_type = request_data.product_type
+#         self.prompt = request_data.prompt
+#         self.image, self.image_size, self.left, self.top = pre_processing_image(request_data.image_url)
+#         self.tasks_id = request_data.tasks_id
+#         self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:]
+#         self.gen_product_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'image_url': ''}
+#         self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data))
+#         self.redis_client.expire(self.tasks_id, 600)
+#
+#     def callback(self, result, error):
+#         if error:
+#             self.gen_product_data['status'] = "FAILURE"
+#             self.gen_product_data['message'] = str(error)
+#             self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data))
+#         else:
+#             # convert the PIL image to a numpy array
+#             image = result.as_numpy("generated_inpaint_image")
+#             image_result = Image.fromarray(np.squeeze(image.astype(np.uint8))).resize(self.image_size)
+#             cropped_image = post_processing_image(image_result, self.left, self.top)
+#             image_url = upload_SDXL_image(cropped_image, user_id=self.user_id, category=f"{self.category}", file_name=f"{self.tasks_id}.png")
+#             self.gen_product_data['status'] = "SUCCESS"
+#             self.gen_product_data['message'] = "success"
+#             self.gen_product_data['image_url'] = str(image_url)
+#             self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data))
+#
+#     def read_tasks_status(self):
+#         status_data = self.redis_client.get(self.tasks_id)
+#         return json.loads(status_data), status_data
+#
+#     def get_result(self):
+#         try:
+#             prompts = [self.prompt] * self.batch_size
+#             self.image = cv2.cvtColor(self.image, cv2.COLOR_BGR2RGB)
+#             self.image = cv2.resize(self.image, (1024, 1024))
+#             images = [self.image.astype(np.uint8)] * self.batch_size
+#
+#             if self.product_type == "single":
+#                 text_obj = np.array(prompts, dtype="object").reshape(-1, 1)
+#                 image_obj = np.array(images, dtype=np.uint8).reshape((-1, 1024, 1024, 3))
+#                 image_strength_obj = np.array(self.image_strength, dtype=np.float32).reshape(-1, 1)
+#             else:
+#                 text_obj = np.array(prompts, dtype="object").reshape((1))
+#                 image_obj = np.array(images, dtype=np.uint8).reshape((1024, 1024, 3))
+#                 image_strength_obj = np.array(self.image_strength, dtype=np.float32).reshape((1))
+#
+#             # assume prompts, images and self.image_strength are already defined
+#
+#             input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype))
+#             input_image = grpcclient.InferInput("input_image", image_obj.shape, "UINT8")
+#             input_image_strength = grpcclient.InferInput("image_strength", image_strength_obj.shape, np_to_triton_dtype(image_strength_obj.dtype))
+#
+#             input_text.set_data_from_numpy(text_obj)
+#             input_image.set_data_from_numpy(image_obj)
+#             input_image_strength.set_data_from_numpy(image_strength_obj)
+#
+#             inputs = [input_text, input_image, input_image_strength]
+#
+#             if self.product_type == "single":
+#                 ctx = self.grpc_client.async_infer(model_name="stable_diffusion_xl_cnet_inpaint", inputs=inputs, callback=self.callback)
+#             else:
+#                 ctx = self.grpc_client.async_infer(model_name=GPI_MODEL_NAME_OVERALL, inputs=inputs, callback=self.callback)
+#
+#             time_out = 600
+#             while time_out > 0:
+#                 gen_product_data, _ = self.read_tasks_status()
+#                 if gen_product_data['status'] in ["REVOKED", "FAILURE"]:
+#                     ctx.cancel()
+#                     break
+#                 elif gen_product_data['status'] == "SUCCESS":
+#                     break
+#                 time_out -= 1
+#                 time.sleep(0.1)
+#             gen_product_data, _ = self.read_tasks_status()
+#             return gen_product_data
+#         except Exception as e:
+#             self.gen_product_data['status'] = "FAILURE"
+#             self.gen_product_data['message'] = str(e)
+#             self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data))
+#             raise Exception(str(e))
+#         finally:
+#             dict_gen_product_data, str_gen_product_data = self.read_tasks_status()
+#             if DEBUG is False:
+#                 self.channel.basic_publish(exchange='', routing_key=GPI_RABBITMQ_QUEUES, body=str_gen_product_data)
+#                 logger.info(f" [x] Sent to {GPI_RABBITMQ_QUEUES} data@@@@ {json.dumps(dict_gen_product_data, indent=4)}")
+#
+#
+# def infer_cancel(tasks_id):
+#     redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
+#     data = {'tasks_id': tasks_id, 'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'}
+#     gen_product_data = json.dumps(data)
+#     redis_client.set(tasks_id, gen_product_data)
+#     return data
+#
+#
+# def pre_processing_image(image_url):
+#     image = oss_get_image(bucket=image_url.split('/')[0], object_name=image_url[image_url.find('/') + 1:], data_type="PIL")
+#     # resize the original image to 1024*1024
+#     image = image.resize((int(1024 / image.height * image.width), 1024))
+#
+#     # original image size
+#     width, height = image.size
+#
+#     new_height, new_width = 1024, 1024
+#     # create a new canvas at the padded size, with a white background
+#     pad_image = Image.new('RGBA', (new_width, new_height), (0, 0, 0, 0))
+#
+#     # paste the original image into the center of the new canvas
+#     left = (new_width - width) // 2
+#     top = (new_height - height) // 2
+#     pad_image.paste(image, (left, top))
+#
+#     # resize the canvas to width 1024, height 1024
+#     resized_image = pad_image.resize((1024, 1024))
+#     image_size = (1024, 1024)
+#
+#     if resized_image.mode in ('RGBA', 'LA') or (resized_image.mode == 'P' and 'transparency' in resized_image.info):
+#         # create a white background
+#         background = Image.new("RGB", image_size, (255, 255, 255))
+#         # paste the image onto the white background
+#         background.paste(resized_image, mask=resized_image.split()[3])
+#         image = np.array(background)
+#     image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+#     return image, image_size, left, top
+#
+#
+# def post_processing_image(image, left, top):
+#     resized_image = image.resize((int(image.width * (768 / image.height)), 768))
+#     # compute the crop coordinates
+#     left = (resized_image.width - 512) // 2
+#     upper = 0
+#     right = left + 512
+#     lower = 768
+#
+#     # perform the crop
+#     cropped_image = resized_image.crop((left, upper, right, lower))
+#     return cropped_image
+#
+#
+# if __name__ == '__main__':
+#     rd = GenerateProductImageModel(
+#         tasks_id="123-89",
+#         # prompt="",
+#         image_strength=0.7,
+#         prompt="The best quality, masterpiece,outwear, 8K realistic, HUD",
+#         image_url="aida-results/result_53381ada-ac64-11ef-ae9d-0242ac150002.png",
+#         product_type="overall"
+#     )
+#     server = GenerateProductImage(rd)
+#     print(server.get_result())
+# old product version
+# !/usr/bin/env python
 # -*- coding: UTF-8 -*-
 """
 @Project :trinity_client
@@ -34,14 +226,14 @@ class GenerateProductImage:
         # self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
         # self.channel = self.connection.channel()
         # self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
-        self.grpc_client = grpcclient.InferenceServerClient(url=GPI_MODEL_URL)
+        self.grpc_client = grpcclient.InferenceServerClient(url="10.1.1.243:18001")
         self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
         self.category = "product_image"
         self.image_strength = request_data.image_strength
         self.batch_size = 1
         self.product_type = request_data.product_type
         self.prompt = request_data.prompt
-        self.image, self.image_size, self.left, self.top = pre_processing_image(request_data.image_url)
+        self.image = pre_processing_image(request_data.image_url)
         self.tasks_id = request_data.tasks_id
         self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:]
         self.gen_product_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'image_url': ''}
@@ -55,10 +247,13 @@ class GenerateProductImage:
self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data)) self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data))
else: else:
# pil图像转成numpy数组 # pil图像转成numpy数组
image = result.as_numpy("generated_inpaint_image") if self.product_type == "single":
image_result = Image.fromarray(np.squeeze(image.astype(np.uint8))).resize(self.image_size) image = result.as_numpy("generated_cnet_image")
cropped_image = post_processing_image(image_result, self.left, self.top) else:
image_url = upload_SDXL_image(cropped_image, user_id=self.user_id, category=f"{self.category}", file_name=f"{self.tasks_id}.png") image = result.as_numpy("generated_inpaint_image")
image_result = Image.fromarray(np.squeeze(image.astype(np.uint8)))
# cropped_image = post_processing_image(image_result, self.left, self.top)
image_url = upload_SDXL_image(image_result, user_id=self.user_id, category=f"{self.category}", file_name=f"{self.tasks_id}.png")
self.gen_product_data['status'] = "SUCCESS" self.gen_product_data['status'] = "SUCCESS"
self.gen_product_data['message'] = "success" self.gen_product_data['message'] = "success"
self.gen_product_data['image_url'] = str(image_url) self.gen_product_data['image_url'] = str(image_url)
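A minimal sketch of the tensor-to-image step this callback now performs: pick the output tensor by product type, squeeze away the batch dimension, and wrap it as a PIL image. The two output names come from the diff above; the result object is assumed to be a tritonclient InferResult.

import numpy as np
from PIL import Image

def tensor_to_pil(result, product_type: str) -> Image.Image:
    # output tensor names match the two branches in the callback above
    name = "generated_cnet_image" if product_type == "single" else "generated_inpaint_image"
    array = result.as_numpy(name)  # uint8 batch, e.g. (1, H, W, 3)
    return Image.fromarray(np.squeeze(array.astype(np.uint8)))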
@@ -71,17 +266,18 @@ class GenerateProductImage:
def get_result(self): def get_result(self):
try: try:
prompts = [self.prompt] * self.batch_size prompts = [self.prompt] * self.batch_size
self.image = cv2.cvtColor(self.image, cv2.COLOR_BGR2RGB)
self.image = cv2.resize(self.image, (1024, 1024)) self.image = cv2.cvtColor(self.image, cv2.COLOR_RGBA2RGB)
# self.image = cv2.resize(self.image, (1024, 1024))
images = [self.image.astype(np.uint8)] * self.batch_size images = [self.image.astype(np.uint8)] * self.batch_size
if self.product_type == "single": if self.product_type == "single":
text_obj = np.array(prompts, dtype="object").reshape(-1, 1) text_obj = np.array(prompts, dtype="object").reshape((-1, 1))
image_obj = np.array(images, dtype=np.uint8).reshape((-1, 1024, 1024, 3)) image_obj = np.array(images, dtype=np.uint8).reshape((-1, 768, 512, 3))
image_strength_obj = np.array(self.image_strength, dtype=np.float32).reshape(-1, 1) image_strength_obj = np.array(self.image_strength, dtype=np.float32).reshape((-1, 1))
else: else:
text_obj = np.array(prompts, dtype="object").reshape((1)) text_obj = np.array(prompts, dtype="object").reshape((1))
image_obj = np.array(images, dtype=np.uint8).reshape((1024, 1024, 3)) image_obj = np.array(images, dtype=np.uint8).reshape((768, 512, 3))
image_strength_obj = np.array(self.image_strength, dtype=np.float32).reshape((1)) image_strength_obj = np.array(self.image_strength, dtype=np.float32).reshape((1))
# assume prompts, images and self.image_strength are already defined # assume prompts, images and self.image_strength are already defined
@@ -97,9 +293,9 @@ class GenerateProductImage:
inputs = [input_text, input_image, input_image_strength] inputs = [input_text, input_image, input_image_strength]
if self.product_type == "single": if self.product_type == "single":
ctx = self.grpc_client.async_infer(model_name="stable_diffusion_xl_cnet_inpaint", inputs=inputs, callback=self.callback) ctx = self.grpc_client.async_infer(model_name="stable_diffusion_1_5_cnet", inputs=inputs, callback=self.callback)
else: else:
ctx = self.grpc_client.async_infer(model_name=GPI_MODEL_NAME_OVERALL, inputs=inputs, callback=self.callback) ctx = self.grpc_client.async_infer(model_name="diffusion_ensemble_all", inputs=inputs, callback=self.callback)
time_out = 600 time_out = 600
while time_out > 0: while time_out > 0:
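For context, a hedged sketch of how the three arrays built above are typically wrapped for Triton's gRPC client: an object (BYTES) tensor for the prompt, UINT8 for the image, FP32 for the strength, then a fire-and-forget async_infer with the callback. The tensor names "PROMPT", "IMAGE" and "IMAGE_STRENGTH" are placeholders; the real names must match the model's config.pbtxt.

import tritonclient.grpc as grpcclient

def submit_async(client, text_obj, image_obj, strength_obj, model_name, callback):
    # tensor names here are assumptions, not taken from the real model config
    input_text = grpcclient.InferInput("PROMPT", list(text_obj.shape), "BYTES")
    input_text.set_data_from_numpy(text_obj)
    input_image = grpcclient.InferInput("IMAGE", list(image_obj.shape), "UINT8")
    input_image.set_data_from_numpy(image_obj)
    input_strength = grpcclient.InferInput("IMAGE_STRENGTH", list(strength_obj.shape), "FP32")
    input_strength.set_data_from_numpy(strength_obj)
    return client.async_infer(model_name=model_name,
                              inputs=[input_text, input_image, input_strength],
                              callback=callback)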
@@ -135,33 +331,26 @@ def infer_cancel(tasks_id):
def pre_processing_image(image_url): def pre_processing_image(image_url):
image = oss_get_image(bucket=image_url.split('/')[0], object_name=image_url[image_url.find('/') + 1:], data_type="PIL") image = oss_get_image(bucket=image_url.split('/')[0], object_name=image_url[image_url.find('/') + 1:], data_type="PIL")
# resize the original image to 1024*1024 # resize the image height to 768 px, keeping the aspect ratio
image = image.resize((int(1024 / image.height * image.width), 1024))
# original image dimensions
width, height = image.size width, height = image.size
new_height = 768
new_width = int(width * (new_height / height))
resized_image = image.resize((new_width, new_height))
new_height, new_width = 1024, 1024 # create a 512x768 RGBA canvas with a white background
# create a new canvas at the padded size, with a white background result_image = Image.new("RGBA", (512, 768), (255, 255, 255, 255))
pad_image = Image.new('RGBA', (new_width, new_height), (0, 0, 0, 0))
# paste the original image at the center of the new canvas # compute the paste position so the image is centered
left = (new_width - width) // 2 x_offset = (512 - new_width) // 2
top = (new_height - height) // 2 y_offset = 0
pad_image.paste(image, (left, top))
# resize the canvas to 1024 x 1024 # paste the resized image onto the white canvas
resized_image = pad_image.resize((1024, 1024)) result_image.paste(resized_image, (x_offset, y_offset), mask=resized_image.split()[3])
image_size = (1024, 1024)
if resized_image.mode in ('RGBA', 'LA') or (resized_image.mode == 'P' and 'transparency' in resized_image.info): image = np.array(result_image)
# create a white background
background = Image.new("RGB", image_size, (255, 255, 255)) # image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA)
# paste the image onto the white background return image
background.paste(resized_image, mask=resized_image.split()[3])
image = np.array(background)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
return image, image_size, left, top
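The new pre-processing path above boils down to: load, scale to a height of 768 keeping the aspect ratio, center onto a 512x768 white RGBA canvas using the alpha channel as the paste mask, and return a numpy array (RGBA, converted to RGB later in get_result). A self-contained sketch, with Image.open standing in for the project's oss_get_image helper:

import numpy as np
from PIL import Image

def preprocess_for_diffusion(path: str) -> np.ndarray:
    image = Image.open(path).convert("RGBA")
    # scale to a height of 768, preserving the aspect ratio
    new_width = int(image.width * (768 / image.height))
    resized = image.resize((new_width, 768))
    # paste centered onto a 512x768 white canvas; the alpha channel masks
    # out transparent regions, as in the code above
    canvas = Image.new("RGBA", (512, 768), (255, 255, 255, 255))
    canvas.paste(resized, ((512 - new_width) // 2, 0), mask=resized.split()[3])
    return np.array(canvas)  # shape (768, 512, 4)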
def post_processing_image(image, left, top): def post_processing_image(image, left, top):
@@ -182,9 +371,9 @@ if __name__ == '__main__':
tasks_id="123-89", tasks_id="123-89",
# prompt="", # prompt="",
image_strength=0.7, image_strength=0.7,
prompt="The best quality, masterpiece,outwear, 8K realistic, HUD", prompt="The best quality, masterpiece, real image.,high quality clothing details,8K realistic,HDR",
image_url="aida-results/result_53381ada-ac64-11ef-ae9d-0242ac150002.png", image_url="aida-results/result_40c7924e-e220-11ef-8ea2-0242ac150003.png",
product_type="overall" product_type="single"
) )
server = GenerateProductImage(rd) server = GenerateProductImage(rd)
print(server.get_result()) print(server.get_result())

View File

@@ -8,6 +8,7 @@ from requests import RequestException
from retry import retry from retry import retry
from app.core.config import QWEN_API_KEY from app.core.config import QWEN_API_KEY
from app.service.chat_robot.script.service.CallQWen import get_language
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -93,7 +94,13 @@ def get_translation_from_llama3(text):
# prompt = f"System: {prefix_for_llama}\nUser:[{text}]" # prompt = f"System: {prefix_for_llama}\nUser:[{text}]"
# build the request payload # first detect the language of the user input
language = get_language(text)
if 'English' in language:
return text
# build the request payload; "translator" is a custom translation model
payload = { payload = {
"model": "translator", "model": "translator",
"prompt": f"[{text}]", "prompt": f"[{text}]",
@@ -117,6 +124,26 @@ def get_translation_from_llama3(text):
print(response.text) print(response.text)
# create a translation model inside llama3
# def create_model_with_llama(text):
# url = "http://localhost:11434/api/create"
# # url = "http://10.1.1.240:1143/api/generate"
#
# # prompt = f"System: {prefix_for_llama}\nUser:[{text}]"
#
# # build the translator's model configuration
# payload = {
# "model": "translator",
# "modelfile": "FROM llama3\nSYSTEM Translate everything within the brackets [] into English."
# "Never translate or modify any English input."
# "The input must be fully translated into coherent English sentences."
# }
#
# # convert the payload to JSON
# headers = {'Content-Type': 'application/json'}
# response = requests.post(url, data=json.dumps(payload), headers=headers)
def main(): def main():
"""Main function""" """Main function"""
text = get_translation_from_llama3("[火焰]") text = get_translation_from_llama3("[火焰]")
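Condensed, the translation path introduced in this file is: detect the input language first, short-circuit if it is already English, otherwise POST the bracketed text to the locally served "translator" model. A hedged sketch, assuming an Ollama-style server on localhost:11434 (the endpoint shown in the commented create-model block) and a non-streaming JSON reply carrying a "response" field — neither detail is confirmed by the diff:

import json
import requests

def translate_to_english(text: str, get_language) -> str:
    # skip the round-trip when the input is already English
    if 'English' in get_language(text):
        return text
    payload = {"model": "translator", "prompt": f"[{text}]", "stream": False}
    resp = requests.post("http://localhost:11434/api/generate",
                         data=json.dumps(payload),
                         headers={'Content-Type': 'application/json'})
    resp.raise_for_status()
    # "response" is the assumed field name for the generated text
    return resp.json().get("response", text)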

View File

@@ -7,9 +7,9 @@ def RunTime(func):
t1 = time.time() t1 = time.time()
res = func(*args, **kwargs) res = func(*args, **kwargs)
t2 = time.time() t2 = time.time()
# if t2 - t1 > 0.05: if t2 - t1 > 0.05:
# logging.info(f"function【{func.__name__}】,runtime{str(t2 - t1)}】s") logging.info(f"function{func.__name__}】,runtime{str(t2 - t1)}】s")
logging.info(f"function{func.__name__}】,runtime{str(t2 - t1)}】s") # logging.info(f"function【{func.__name__}】,runtime{str(t2 - t1)}】s")
return res return res
return wrapper return wrapper
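A small usage sketch of the threshold behavior this change introduces: calls faster than 0.05 s are now silent, slower ones are logged. The decorator body is restated so the snippet runs on its own.

import logging
import time

logging.basicConfig(level=logging.INFO)

def RunTime(func):
    # mirrors the decorator above: log only calls slower than 0.05 s
    def wrapper(*args, **kwargs):
        t1 = time.time()
        res = func(*args, **kwargs)
        t2 = time.time()
        if t2 - t1 > 0.05:
            logging.info(f"function【{func.__name__}】, runtime【{t2 - t1}】s")
        return res
    return wrapper

@RunTime
def slow_step():
    time.sleep(0.1)  # above the threshold, so this call is logged

@RunTime
def fast_step():
    pass  # below the threshold, so nothing is logged

slow_step()
fast_step()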
@@ -22,7 +22,8 @@ def ClassCallRunTime(func):
end_time = time.time() end_time = time.time()
execution_time = end_time - start_time execution_time = end_time - start_time
class_name = args[0].__class__.__name__ # get the class name class_name = args[0].__class__.__name__ # get the class name
print(f"class name: {class_name} , run time is : {execution_time} s") if execution_time > 0.05:
logging.info(f"class name: {class_name} , run time is : {execution_time} s")
return result return result
return wrapper return wrapper

View File

@@ -82,7 +82,7 @@ if __name__ == '__main__':
# url = "aida-users/89/sketchboard/female/Dress/e6724ab7-8d3f-4677-abe0-c3e42ab7af85.jpeg" # url = "aida-users/89/sketchboard/female/Dress/e6724ab7-8d3f-4677-abe0-c3e42ab7af85.jpeg"
# url = "aida-users/87/print/956614a2-7e75-4fbe-9ed0-c1831e37a2c9-4-87.png" # url = "aida-users/87/print/956614a2-7e75-4fbe-9ed0-c1831e37a2c9-4-87.png"
# url = "aida-users/89/single_logo/123-89.png" # url = "aida-users/89/single_logo/123-89.png"
url = "aida-users/89/test/123-89.png" url = "aida-users/89/123-89.png"
# url = "aida-collection-element/12148/Sketchboard/95ea577b-305b-4a62-b30a-39c0dd3ddb3f.png" # url = "aida-collection-element/12148/Sketchboard/95ea577b-305b-4a62-b30a-39c0dd3ddb3f.png"
read_type = "2" read_type = "2"