diff --git a/.gitignore b/.gitignore index 3f9e525..5a29a1f 100644 --- a/.gitignore +++ b/.gitignore @@ -136,4 +136,9 @@ app/logs/* /qodana.yaml .pth .pytorch -*.png \ No newline at end of file +*.png +*.pth +*.db +*.npy +*.pytorch +*.jpg \ No newline at end of file diff --git a/app/api/api_attribute_retrieve.py b/app/api/api_attribute_retrieve.py index 5c15efe..2a5ad4d 100644 --- a/app/api/api_attribute_retrieve.py +++ b/app/api/api_attribute_retrieve.py @@ -35,13 +35,13 @@ def attribute_recognition(request_item: list[AttributeRecognitionModel]): """ try: for item in request_item: - logger.info(f"attribute_recognition request item is : @@@@@@:{json.dumps(item.dict())}") + logger.debug(f"attribute_recognition request item is : @@@@@@:{json.dumps(item.dict())}") if DEBUG: service = AttributeRecognition(const=local_debug_const, request_data=request_item) else: service = AttributeRecognition(const=const, request_data=request_item) data = service.get_result() - logger.info(f"attribute_recognition response @@@@@@:{json.dumps(data)}") + logger.debug(f"attribute_recognition response @@@@@@:{json.dumps(data)}") except Exception as e: logger.warning(f"attribute_recognition Run Exception @@@@@@:{e}") raise HTTPException(status_code=404, detail=str(e)) diff --git a/app/api/api_brand_dna.py b/app/api/api_brand_dna.py new file mode 100644 index 0000000..6b19416 --- /dev/null +++ b/app/api/api_brand_dna.py @@ -0,0 +1,34 @@ +import json +import logging + +from fastapi import APIRouter, HTTPException + +from app.schemas.brand_dna import BrandDnaModel +from app.schemas.response_template import ResponseModel +from app.service.brand_dna.service import BrandDna + +router = APIRouter() +logger = logging.getLogger() + + +@router.post("/seg_product") +def image2sketch(request_item: BrandDnaModel): + """ + 创建一个具有以下参数的请求体: + - **image_url**: 提取图片url + - **is_brand_dna**: 是否提取属性 + + 示例参数: + { + "image_url": "test/image2sketch/real_Dress_3200fecdc83d0c556c2bd96aedbd7fbf.jpg_Img.jpg", + "is_brand_dna": False + } + """ + try: + logger.info(f"brand dna request item is : @@@@@@:{json.dumps(request_item.dict())}") + service = BrandDna(request_item) + result_url = service.get_result() + except Exception as e: + logger.warning(f"brand dna Run Exception @@@@@@:{e}") + raise HTTPException(status_code=404, detail=str(e)) + return ResponseModel(data=result_url) diff --git a/app/api/api_design.py b/app/api/api_design.py index 665d544..1c77ed8 100644 --- a/app/api/api_design.py +++ b/app/api/api_design.py @@ -4,7 +4,7 @@ import os from fastapi import APIRouter, HTTPException, UploadFile, File, Form, BackgroundTasks -from app.schemas.design import DesignModel, DesignProgressModel, ModelProgressModel, DBGConfigModel +from app.schemas.design import DesignModel, DesignProgressModel, ModelProgressModel, DBGConfigModel, DesignStreamModel from app.schemas.response_template import ResponseModel from app.service.design.model_process_service import model_transpose from app.service.design_batch.service import start_design_batch_generate @@ -18,6 +18,14 @@ logger = logging.getLogger() @router.post("/design") def design(request_data: DesignModel, background_tasks: BackgroundTasks): """ + objects.items.transparent: + "transparent":{ + "mask_url":"test/transparent_test/transparent_mask.png", + "scale":0.1 + }, + mask_url 为空"" -> 单件衣服透明 + mask_url 非空"mask_url" -> 区域透明 + 创建一个具有以下参数的请求体: 示例参数: { @@ -197,7 +205,7 @@ def design(request_data: DesignModel, background_tasks: BackgroundTasks): @router.post("/design_v2") -async def design_v2(request_data: DesignModel, background_tasks: BackgroundTasks): +async def design_v2(request_data: DesignStreamModel, background_tasks: BackgroundTasks): """ 创建一个具有以下参数的请求体: 示例参数: @@ -445,7 +453,7 @@ async def design(file: UploadFile = File(...), async def save_request_file(contents, file_name): # 创建保存文件的目录(如果不存在) - save_dir = os.path.join(os.getcwd(), "design_batch", "request_data") + save_dir = os.path.join(os.getcwd(), "service/design_batch", "request_data") if not os.path.exists(save_dir): os.makedirs(save_dir) # 处理文件 diff --git a/app/api/api_generate_image.py b/app/api/api_generate_image.py index 53790a3..a37bec3 100644 --- a/app/api/api_generate_image.py +++ b/app/api/api_generate_image.py @@ -3,9 +3,10 @@ import logging from fastapi import APIRouter, BackgroundTasks, HTTPException -from app.schemas.generate_image import GenerateImageModel, GenerateProductImageModel, GenerateSingleLogoImageModel, GenerateRelightImageModel +from app.schemas.generate_image import GenerateImageModel, GenerateProductImageModel, GenerateSingleLogoImageModel, GenerateRelightImageModel, GenerateMultiViewModel from app.schemas.response_template import ResponseModel from app.service.generate_image.service_generate_image import GenerateImage, infer_cancel as generate_image_infer_cancel +from app.service.generate_image.service_generate_multi_view import GenerateMultiView, infer_cancel as generate_multi_view_cancel from app.service.generate_image.service_generate_product_image import GenerateProductImage, infer_cancel as generate_product_image_cancel from app.service.generate_image.service_generate_relight_image import GenerateRelightImage, infer_cancel as generate_relight_image_cancel from app.service.generate_image.service_generate_single_logo import GenerateSingleLogoImage, infer_cancel as generate_single_logo_cancel @@ -61,6 +62,44 @@ def generate_image(tasks_id: str): return ResponseModel(data=data['data']) +'''multi view''' + + +@router.post("/generate_multi_view") +def generate_multi_view(request_item: GenerateMultiViewModel, background_tasks: BackgroundTasks): + """ + 创建一个具有以下参数的请求体: + - **tasks_id**: 任务id 用于取消生成任务和获取生成结果 + - **image_url**: 前视角图的输入,minio或S3 url 地址 + + 示例参数: + { + "tasks_id": "123-89", + "image_url": "aida-collection-element/87/Printboard/842c09cf-7297-42d9-9e6e-9c17d4a13cb5.jpg" + } + """ + try: + logger.info(f"generate_multi_view request item is : @@@@@@:{json.dumps(request_item.dict())}") + service = GenerateMultiView(request_item) + background_tasks.add_task(service.get_result) + except Exception as e: + logger.warning(f"generate_multi_view Run Exception @@@@@@:{e}") + raise HTTPException(status_code=404, detail=str(e)) + return ResponseModel() + + +@router.get("/generate_multi_view_cancel/{tasks_id}") +def generate_multi_view(tasks_id: str): + try: + logger.info(f"generate_cancel request item is : @@@@@@:{tasks_id}") + data = generate_multi_view_cancel(tasks_id) + logger.info(f"generate_cancel response @@@@@@:{data}") + except Exception as e: + logger.warning(f"generate_cancel Run Exception @@@@@@:{e}") + raise HTTPException(status_code=404, detail=str(e)) + return ResponseModel(data=data['data']) + + '''single logo''' diff --git a/app/api/api_prompt_generation.py b/app/api/api_prompt_generation.py index 11733e8..b731e33 100644 --- a/app/api/api_prompt_generation.py +++ b/app/api/api_prompt_generation.py @@ -26,7 +26,7 @@ def prompt_generation(request_data: PromptGenerationImageModel): """ try: logger.info(f"prompt_generation request item is : @@@@@@:{request_data}") - data = get_translation_from_llama3("[" + request_data.text + "]") + data = get_translation_from_llama3(request_data.text) logger.info(f"prompt_generation response @@@@@@:{data}") except Exception as e: logger.warning(f"prompt_generation Run Exception @@@@@@:{e}") diff --git a/app/api/api_route.py b/app/api/api_route.py index 0da3a66..973a940 100644 --- a/app/api/api_route.py +++ b/app/api/api_route.py @@ -1,6 +1,7 @@ from fastapi import APIRouter from app.api import api_attribute_retrieve, api_query_image +from app.api import api_brand_dna from app.api import api_brighten from app.api import api_chat_robot from app.api import api_design @@ -23,4 +24,5 @@ router.include_router(api_prompt_generation.router, tags=['prompt_generation'], router.include_router(api_design_pre_processing.router, tags=['design_pre_processing'], prefix="/api") router.include_router(api_image2sketch.router, tags=['api_image2sketch'], prefix="/api") router.include_router(api_brighten.router, tags=['api_brighten'], prefix="/api") -router.include_router(api_query_image.router, tags=['api_query_image'], prefix="/api") \ No newline at end of file +router.include_router(api_query_image.router, tags=['api_query_image'], prefix="/api") +router.include_router(api_brand_dna.router, tags=['api_brand_dna'], prefix="/api") diff --git a/app/api/api_test.py b/app/api/api_test.py index 1271f95..a273b11 100644 --- a/app/api/api_test.py +++ b/app/api/api_test.py @@ -4,7 +4,7 @@ import logging from fastapi import APIRouter from fastapi import HTTPException -from app.core.config import SR_RABBITMQ_QUEUES, GI_RABBITMQ_QUEUES, GPI_RABBITMQ_QUEUES, GRI_RABBITMQ_QUEUES, OSS +from app.core.config import SR_RABBITMQ_QUEUES, GI_RABBITMQ_QUEUES, GPI_RABBITMQ_QUEUES, GRI_RABBITMQ_QUEUES, OSS, JAVA_STREAM_API_URL from app.schemas.response_template import ResponseModel logger = logging.getLogger() @@ -18,6 +18,7 @@ def test(id: int): "GI_RABBITMQ_QUEUES": GI_RABBITMQ_QUEUES, "GPI_RABBITMQ_QUEUES": GPI_RABBITMQ_QUEUES, "GRI_RABBITMQ_QUEUES": GRI_RABBITMQ_QUEUES, + "JAVA_STREAM_API_URL": JAVA_STREAM_API_URL, "local_oss_server": OSS } logger.info(json.dumps(data)) diff --git a/app/core/config.py b/app/core/config.py index 7629429..7456912 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -34,6 +34,9 @@ else: RABBITMQ_ENV = "-dev" # 开发环境 # RABBITMQ_ENV = "-local" # 本地测试环境 +JAVA_STREAM_API_URL = os.getenv("JAVA_STREAM_API_URL", "https://api.aida.com.hk/api/third/party/receiveDesignResults") + + settings = Settings() # minio 配置 @@ -106,6 +109,12 @@ FAST_GI_MODEL_NAME = 'stable_diffusion_xl' GI_MODEL_URL = '10.1.1.240:10061' GI_MODEL_NAME = 'flux' +GMV_MODEL_URL = '10.1.1.243:10081' +GMV_MODEL_NAME = 'multi_view' + +GMV_RABBITMQ_QUEUES = os.getenv("GMV_RABBITMQ_QUEUES", f"GenerateMultiView{RABBITMQ_ENV}") + + GI_MINIO_BUCKET = "aida-users" GI_RABBITMQ_QUEUES = os.getenv("GI_RABBITMQ_QUEUES", f"GenerateImage{RABBITMQ_ENV}") GI_SYS_IMAGE_URL = "aida-sys-image/generate_image/white_image.jpg" diff --git a/app/schemas/brand_dna.py b/app/schemas/brand_dna.py new file mode 100644 index 0000000..c5ae2ab --- /dev/null +++ b/app/schemas/brand_dna.py @@ -0,0 +1,6 @@ +from pydantic import BaseModel + + +class BrandDnaModel(BaseModel): + image_url: str + is_brand_dna: bool diff --git a/app/schemas/design.py b/app/schemas/design.py index 7ebd8e6..98d0a29 100644 --- a/app/schemas/design.py +++ b/app/schemas/design.py @@ -6,6 +6,12 @@ class DesignModel(BaseModel): process_id: str +class DesignStreamModel(BaseModel): + objects: list[dict] + process_id: str + requestId: str + + class DesignProgressModel(BaseModel): process_id: str diff --git a/app/schemas/generate_image.py b/app/schemas/generate_image.py index 11e295f..7181418 100644 --- a/app/schemas/generate_image.py +++ b/app/schemas/generate_image.py @@ -1,6 +1,11 @@ from pydantic import BaseModel +class GenerateMultiViewModel(BaseModel): + tasks_id: str + image_url: str + + class GenerateImageModel(BaseModel): tasks_id: str prompt: str diff --git a/app/service/brand_dna/service.py b/app/service/brand_dna/service.py new file mode 100644 index 0000000..012e682 --- /dev/null +++ b/app/service/brand_dna/service.py @@ -0,0 +1,335 @@ +import logging + +import cv2 +import mmcv +import numpy as np +import pandas as pd +import torch +import torch.nn.functional as F +import tritonclient.http as httpclient +from minio import Minio + +from app.core.config import MINIO_URL, MINIO_ACCESS, MINIO_SECRET, MINIO_SECURE, DESIGN_MODEL_URL +from app.schemas.brand_dna import BrandDnaModel +from app.service.attribute.config import local_debug_const +from app.service.utils.generate_uuid import generate_uuid +from app.service.utils.new_oss_client import oss_upload_image, oss_get_image + +logger = logging.getLogger() + +minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) + + +class BrandDna: + def __init__(self, request_item): + self.sketch_bucket = "test" + self.image_url = request_item.image_url + self.is_brand_dna = request_item.is_brand_dna + # self.attr_type = pd.read_csv(CATEGORY_PATH) + self.attr_type = pd.read_csv(r"E:\workspace\trinity_client_aida\app\service\attribute\config\descriptor\category\category_dis.csv") + self.att_client = httpclient.InferenceServerClient(url=DESIGN_MODEL_URL) + self.seg_client = httpclient.InferenceServerClient(url='10.1.1.243:30000') + # self.const = const + self.const = local_debug_const + + # 获取结果 + def get_result(self): + mask, image = self.get_seg_mask() + cv2.imshow("", image) + cv2.waitKey(0) + + height, width, channels = image.shape + result_dict = [] + white_img = np.ones((height, width, channels), dtype=image.dtype) * 255 + mask_image = np.zeros((height, width, 3)) + + for value in np.unique(mask): + if value == 1: + outwear_img = white_img.copy() + outwear_mask_img = mask_image.copy() + outwear_img[mask == value] = image[mask == value] + outwear_mask_img[mask == value] = [0, 0, 255] + + cv2.imshow("", outwear_img) + cv2.waitKey(0) + + # 预处理之后的input img + preprocess_img = self.category_preprocess(outwear_img) + # 类别检测 + category = self.recognition_category(preprocess_img) + if category == 'Trousers' or category == 'Skirt': + male_category = 'Bottoms' + elif category == 'Blouse' or category == 'Dress': + male_category = 'Tops' + else: + male_category = 'Outwear' + + attribute = {} + mask_url = "" + img_url = "" + # 属性检测 + if self.is_brand_dna: + attribute = self.get_recognition_attribute_result(category, preprocess_img) + else: + img_url = self.put_image(outwear_img, f"img/{generate_uuid()}") + mask_url = self.put_image(outwear_mask_img, f"mask/{generate_uuid()}") + + result_dict.append( + { + 'category_female': category, + 'category_male': male_category, + 'mask_url': mask_url, + 'img_url': img_url, + 'attribute': attribute + } + ) + if value == 2: + tops_img = white_img.copy() + tops_mask_img = mask_image.copy() + tops_img[mask == value] = image[mask == value] + tops_mask_img[mask == value] = [0, 0, 255] + + cv2.imshow("", tops_img) + cv2.waitKey(0) + + # 预处理之后的input img + preprocess_img = self.category_preprocess(tops_img) + # 类别检测 + category = self.recognition_category(preprocess_img) + if category == 'Trousers' or category == 'Skirt': + male_category = 'Bottoms' + elif category == 'Blouse' or category == 'Dress': + male_category = 'Tops' + else: + male_category = 'Outwear' + + # 属性检测 + attribute = {} + img_url = "" + mask_url = "" + # 属性检测 + if self.is_brand_dna: + attribute = self.get_recognition_attribute_result(category, preprocess_img) + else: + mask_url = self.put_image(tops_mask_img, f"mask/{generate_uuid()}") + img_url = self.put_image(tops_img, f"img/{generate_uuid()}") + + result_dict.append( + { + 'category_female': category, + 'category_male': male_category, + 'mask_url': mask_url, + 'img_url': img_url, + 'attribute': attribute + } + ) + if value == 3: + bottoms_img = white_img.copy() + bottoms_mask_img = mask_image.copy() + bottoms_img[mask == value] = image[mask == value] + bottoms_mask_img[mask == value] = [0, 0, 255] + + cv2.imshow("", bottoms_img) + cv2.waitKey(0) + + # 预处理之后的input img + preprocess_img = self.category_preprocess(bottoms_img) + # 类别检测 + category = self.recognition_category(preprocess_img) + if category == 'Trousers' or category == 'Skirt': + male_category = 'Bottoms' + elif category == 'Blouse' or category == 'Dress': + male_category = 'Tops' + else: + male_category = 'Outwear' + + attribute = {} + img_url = "" + mask_url = "" + # 属性检测 + if self.is_brand_dna: + attribute = self.get_recognition_attribute_result(category, preprocess_img) + else: + img_url = self.put_image(bottoms_img, f"img/{generate_uuid()}") + mask_url = self.put_image(bottoms_mask_img, f"mask/{generate_uuid()}") + + result_dict.append( + { + 'category_female': category, + 'category_male': male_category, + 'mask_url': mask_url, + 'img_url': img_url, + 'attribute': attribute + } + ) + return result_dict + + # 获取product mask + def get_seg_mask(self): + input_image = self.get_image() + input_img, ori_shape = self.seg_product_preprocess(input_image) + transformed_img = input_img.astype(np.float32) + + inputs = [httpclient.InferInput(f"seg_input__0", transformed_img.shape, datatype="FP32")] + inputs[0].set_data_from_numpy(transformed_img, binary_data=True) + outputs = [httpclient.InferRequestedOutput(f"seg_output__0", binary_data=True)] + results = self.seg_client.infer(model_name=f"seg_product", inputs=inputs, outputs=outputs) + inference_output1 = results.as_numpy("seg_output__0") + mask = self.product_postprocess(inference_output1, ori_shape)[0] + return mask, input_image + + # 获取图片 + def get_image(self): + image = oss_get_image(oss_client=minio_client, bucket=self.image_url.split('/')[0], object_name=self.image_url[self.image_url.find('/') + 1:], data_type="cv2") + # 将其转换为彩色图像 + if len(image.shape) == 3 and image.shape[2] == 4: + image = cv2.cvtColor(image, cv2.COLOR_BGRA2BGR) + elif len(image.shape) == 2: + image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR) + return image + # return cv2.imread(self.image_url) + + # 图片上传 + def put_image(self, image, object_name): + try: + image_bytes = cv2.imencode('.jpg', image)[1].tobytes() + oss_upload_image(oss_client=minio_client, bucket=self.sketch_bucket, object_name=f"{object_name}.jpg", image_bytes=image_bytes) + return f"{self.sketch_bucket}/{object_name}.jpg" + except Exception as e: + logger.warning(e) + + # 服装分割预处理 + @staticmethod + def seg_product_preprocess(image): + img = mmcv.imread(image) + ori_shape = img.shape[:2] + img_scale_w, img_scale_h = ori_shape + if ori_shape[0] > 1024: + img_scale_w = 1024 + if ori_shape[1] > 1024: + img_scale_h = 1024 + # 如果图片size任意一边 大于 1024, 则会resize 成1024 + if ori_shape != (img_scale_w, img_scale_h): + # mmcv.imresize(img, img_scale_h, img_scale_w) # 老代码 引以为戒!哈哈哈~ h和w写反了 + img = cv2.resize(img, (img_scale_h, img_scale_w)) + img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True) + preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0) + return preprocessed_img, ori_shape + + # 类别检测后处理 + @staticmethod + def product_postprocess(output, ori_shape): + seg_logit = F.interpolate(torch.tensor(output).float(), size=ori_shape, scale_factor=None, mode='bilinear', align_corners=False) + seg_pred = seg_logit.cpu().numpy() + return seg_pred[0] + + # 类别检测模型预处理 + @staticmethod + def category_preprocess(img): + img = mmcv.imread(img) + # ori_shape = img.shape[:2] + img_scale = (224, 224) + img = cv2.resize(img, img_scale) + img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True) + preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0) + return preprocessed_img + + # 类别检测 + def recognition_category(self, image): + inputs = [ + httpclient.InferInput("input__0", image.shape, datatype="FP32") + ] + inputs[0].set_data_from_numpy(image, binary_data=True) + results = self.att_client.infer(model_name="attr_retrieve_category", inputs=inputs) + inference_output = torch.from_numpy(results.as_numpy(f'output__0')) + + scores = inference_output.detach().numpy() + + colattr = list(self.attr_type['labelName']) + maxsc = np.max(scores[0][:5]) + indexs = np.argwhere(scores == maxsc)[:, 1] + + return colattr[indexs[0]] + + # 属性检测 + def recognition_attribute(self, model_name, description, image): + attr_type = pd.read_csv(description) + inputs = [ + httpclient.InferInput("input__0", image.shape, datatype="FP32") + ] + inputs[0].set_data_from_numpy(image, binary_data=True) + results = self.att_client.infer(model_name=model_name, inputs=inputs) + inference_output = torch.from_numpy(results.as_numpy(f"output__0")) + scores = inference_output.detach().numpy() + colattr = list(attr_type['labelName']) + task = description.split('/')[-1][:-4] + maxsc = np.max(scores[0][:5]) + indexs = np.argwhere(scores == maxsc)[:, 1] + attr = { + task: [] + } + for i in range(len(indexs)): + atr = colattr[indexs[i]] + attr[task].append(atr) + return attr + + # 获取属性检测结果 + def get_recognition_attribute_result(self, category, input_img): + if category == "Blouse": + attr_dict = {} + for i in range(len(self.const.top_description_list)): + attr_description = self.const.top_description_list[i] + attr_model_path = self.const.top_model_list[i] + present_dict = self.recognition_attribute(attr_model_path, attr_description, input_img) + attr_dict = self.merge(attr_dict, present_dict) + + elif category == 'Trousers' or category == "Skirt": + attr_dict = {} + for i in range(len(self.const.bottom_description_list)): + attr_description = self.const.bottom_description_list[i] + attr_model_path = self.const.bottom_model_list[i] + present_dict = self.recognition_attribute(attr_model_path, attr_description, input_img) + attr_dict = self.merge(attr_dict, present_dict) + + elif category == 'Dress': + attr_dict = {} + for i in range(len(self.const.dress_description_list)): + attr_description = self.const.dress_description_list[i] + attr_model_path = self.const.dress_model_list[i] + present_dict = self.recognition_attribute(attr_model_path, attr_description, input_img) + attr_dict = self.merge(attr_dict, present_dict) + + elif category == 'Outwear': + attr_dict = {} + for i in range(len(self.const.outwear_description_list)): + attr_description = self.const.outwear_description_list[i] + attr_model_path = self.const.outwear_model_list[i] + present_dict = self.recognition_attribute(attr_model_path, attr_description, input_img) + attr_dict = self.merge(attr_dict, present_dict) + else: + attr_dict = {} + return attr_dict + + @staticmethod + def merge(dict1, dict2): + res = {**dict1, **dict2} + return res + + +if __name__ == '__main__': + # for path in os.listdir('./test_img'): + # img_path = os.path.join(r'./test_img', path) + # request_item = BrandDnaModel( + # image_url=img_path, + # is_brand_dna=True + # ) + # service = BrandDna(request_item) + # result_url = service.get_result() + # print(result_url) + request_item = BrandDnaModel( + image_url="aida-users/60/product_image/07cb5d5d-5022-44cc-b0d3-cc986cfebad1-2-60.png", + is_brand_dna=True + ) + service = BrandDna(request_item) + result_url = service.get_result() + print(result_url) diff --git a/app/service/chat_robot/script/prompt.py b/app/service/chat_robot/script/prompt.py index ad6ac9e..121e57d 100644 --- a/app/service/chat_robot/script/prompt.py +++ b/app/service/chat_robot/script/prompt.py @@ -69,3 +69,5 @@ TOOLS_FUNCTIONS_SUFFIX = ( ) TUTORIAL_TOOL_RETURN = "Commencing the systematic tutorial guide now." + +GET_LANGUAGE_PREFIX = "Please identify the language. Only output the language name" diff --git a/app/service/chat_robot/script/service/CallQWen.py b/app/service/chat_robot/script/service/CallQWen.py index 33dcd04..cb7669b 100644 --- a/app/service/chat_robot/script/service/CallQWen.py +++ b/app/service/chat_robot/script/service/CallQWen.py @@ -1,4 +1,5 @@ import json +import logging from dashscope import Generation from retry import retry @@ -7,7 +8,8 @@ from urllib3.exceptions import NewConnectionError from app.core.config import * from app.service.chat_robot.script.callbacks.qwen_callback_handler import QWenCallbackHandler from app.service.chat_robot.script.database import CustomDatabase -from app.service.chat_robot.script.prompt import FASHION_CHAT_BOT_PREFIX, TOOLS_FUNCTIONS_SUFFIX, TUTORIAL_TOOL_RETURN +from app.service.chat_robot.script.prompt import FASHION_CHAT_BOT_PREFIX, TOOLS_FUNCTIONS_SUFFIX, TUTORIAL_TOOL_RETURN, \ + GET_LANGUAGE_PREFIX from app.service.search_image_with_text.service import query get_database_table_description = "Input is an empty string, output is a comma separated list of tables in the database." @@ -159,10 +161,11 @@ def search_from_internet(message): model='qwen-turbo', api_key=QWEN_API_KEY, messages=message, - tools=tools, + prompt='The output must be in English.Keep the final result under 200 words.' + # tools=tools, # seed=random.randint(1, 10000), # 设置随机数种子seed,如果没有设置,则随机数种子默认为1234 - result_format='message', # 将输出设置为message形式 - enable_search='True' + # result_format='message', # 将输出设置为message形式 + # enable_search='True' ) return response @@ -198,14 +201,9 @@ def get_response(messages): def call_with_messages(message, gender): + global tool_info user_input = message print('\n') - # messages = [ - # { - # "content": input('请输入:'), # 提问示例:"现在几点了?" "一个小时后几点" "北京天气如何?" - # "role": "user" - # } - # ] messages = [ { @@ -223,14 +221,10 @@ def call_with_messages(message, gender): } ] - # 模型的第一轮调用 - # first_response = get_response(messages) - # assistant_output = first_response.output.choices[0].message - # print(f"\n大模型第一轮输出信息:{first_response}\n") - # messages.append(assistant_output) flag = True count = 1 - result_content = "我是一个时尚AI助手,请问有什么可以帮您" + # result_content = "我是一个时尚AI助手,请问有什么可以帮您" + result_content = "I am a fashion AI assistant, how can I help you?" response_type = "chat" while flag and count <= 3: @@ -244,44 +238,53 @@ def call_with_messages(message, gender): print(f"最终答案:{assistant_output.content}") # 此处直接返回模型的回复,您可以根据您的业务,选择当无需调用工具时最终回复的内容 result_content = assistant_output.content break - # 如果模型选择的工具是search_from_internet - # elif assistant_output.tool_calls[0]['function']['name'] == 'search_from_internet': - # tool_info = {"name": "search_from_internet", "role": "tool"} - # user_input = json.loads(assistant_output.tool_calls[0]['function']['arguments'])['user_input'] - # tool_info['content'] = search_from_internet(user_input) - # 如果模型选择的工具是get_database_table - elif assistant_output.tool_calls[0]['function']['name'] == 'get_database_table': - tool_info = {"name": "get_database_table", "role": "tool", 'content': get_database_table()} - # 如果模型选择的工具是get_table_info - elif assistant_output.tool_calls[0]['function']['name'] == 'get_table_info': - tool_info = {"name": "get_table_info", "role": "tool"} - table_names = json.loads(assistant_output.tool_calls[0]['function']['arguments'])['table_names'] - tool_info['content'] = get_table_info(table_names) - # 如果模型选择的工具是query_database - elif assistant_output.tool_calls[0]['function']['name'] == 'query_database': - tool_info = {"name": "query_database", "role": "tool"} - sql_string = json.loads(assistant_output.tool_calls[0]['function']['arguments'])['sql_string'] - tool_info['content'] = query_database(sql_string) + # 如果模型选择的工具是internet_search + elif assistant_output.tool_calls[0]['function']['name'] == 'internet_search': + tool_info = {"name": "search_from_internet", "role": "tool"} + content = json.loads(assistant_output.tool_calls[0]['function']['arguments']) + message = [ + {'role': 'assistant', 'content': content['query']} + ] + tool_info['content'] = search_from_internet(message) flag = False - result_content = tool_info['content'] - response_type = "image" + result_content = tool_info['content'].output.text + # 如果模型选择的工具是get_database_table + # elif assistant_output.tool_calls[0]['function']['name'] == 'get_database_table': + # tool_info = {"name": "get_database_table", "role": "tool", 'content': get_database_table()} + # # 如果模型选择的工具是get_table_info + # elif assistant_output.tool_calls[0]['function']['name'] == 'get_table_info': + # tool_info = {"name": "get_table_info", "role": "tool"} + # table_names = json.loads(assistant_output.tool_calls[0]['function']['arguments'])['table_names'] + # tool_info['content'] = get_table_info(table_names) + # # 如果模型选择的工具是query_database + # elif assistant_output.tool_calls[0]['function']['name'] == 'query_database': + # tool_info = {"name": "query_database", "role": "tool"} + # sql_string = json.loads(assistant_output.tool_calls[0]['function']['arguments'])['sql_string'] + # tool_info['content'] = query_database(sql_string) + # flag = False + # result_content = tool_info['content'] + # response_type = "image" elif assistant_output.tool_calls[0]['function']['name'] == 'tutorial_tool': tool_info = {"name": "tutorial_tool", "role": "tool", 'content': tutorial_tool()} flag = False result_content = tool_info['content'] elif assistant_output.tool_calls[0]['function']['name'] == 'get_image_from_vector_db': + content = json.loads(assistant_output.tool_calls[0]['function']['arguments']) tool_info = {"name": "get_image_from_vector_db", "role": "tool", - 'content': get_image_from_vector_db(gender, user_input)} + 'content': get_image_from_vector_db(gender, content['parameters']['content'])} flag = False result_content = tool_info['content'] response_type = "image" + else: + tool_info = {"name": assistant_output.tool_calls[0]['function']['name'], 'content': 'null'} + logging.info(assistant_output.tool_calls[0]['function']['name'] + "(unknown tools)") + flag = False print(f"工具输出信息:{tool_info['content']}\n") messages.append(tool_info) count += 1 - final_output = {"output": result_content} - final_output["response_type"] = response_type + final_output = {"output": result_content, "response_type": response_type} QWenCallbackHandler.on_chain_end(qwen, final_output) # 模型的第二轮调用,对工具的输出进行总结 @@ -298,5 +301,23 @@ def tutorial_tool(): return TUTORIAL_TOOL_RETURN +def get_language(message: str) -> str: + messages = [ + { + "content": message, # 用户message + "role": "user" + }, + { + "content": GET_LANGUAGE_PREFIX, # ai message + "role": "assistant" + } + ] + + first_response = get_response(messages) + assistant_output = first_response.output.choices[0].message.content + logging.info(f"大模型输出信息:{first_response}\n判断用户输入的语言为:{assistant_output}") + return assistant_output + + if __name__ == '__main__': - call_with_messages() + get_language("") diff --git a/app/service/design/items/pipelines/segmentation.py b/app/service/design/items/pipelines/segmentation.py index 19eb1fd..7ed43e5 100644 --- a/app/service/design/items/pipelines/segmentation.py +++ b/app/service/design/items/pipelines/segmentation.py @@ -53,7 +53,7 @@ class Segmentation(object): file_path = f"{SEG_CACHE_PATH}{image_id}.npy" try: np.save(file_path, seg_result) - logger.info(f"保存成功 :{os.path.abspath(file_path)}") + logger.debug(f"保存成功 :{os.path.abspath(file_path)}") except Exception as e: logger.error(f"保存失败: {e}") @@ -64,7 +64,7 @@ class Segmentation(object): seg_result = np.load(file_path) return True, seg_result except FileNotFoundError: - logger.warning("文件不存在") + # logger.warning("文件不存在") return False, None except Exception as e: logger.error(f"加载失败: {e}") diff --git a/app/service/design_batch/design_batch_celery.py b/app/service/design_batch/design_batch_celery.py index 3f12862..06ccc5e 100644 --- a/app/service/design_batch/design_batch_celery.py +++ b/app/service/design_batch/design_batch_celery.py @@ -12,7 +12,7 @@ from app.service.design_batch.utils.save_json import oss_upload_json from app.service.design_batch.utils.synthesis_item import update_base_size_priority, synthesis, synthesis_single id_lock = threading.Lock() -celery_app = Celery('tasks', broker='amqp://guest:guest@10.1.2.213:5672//', backend='rpc://') +celery_app = Celery('tasks', broker=f'amqp://rabbit:123456@18.167.251.121:5672//', backend='rpc://', BROKER_CONNECTION_RETRY_ON_STARTUP=True) celery_app.conf.worker_log_format = '%(asctime)s %(filename)s [line:%(lineno)d] %(levelname)s %(message)s' celery_app.conf.worker_hijack_root_logger = False logging.getLogger('pika').setLevel(logging.WARNING) @@ -120,7 +120,7 @@ def batch_design(objects_data, tasks_id, json_name): for t in threads: t.join() - + logger.debug(object_response) oss_upload_json(minio_client, object_response, json_name) publish_status(tasks_id, "ok", json_name) return object_response diff --git a/app/service/design_batch/pipeline/segmentation.py b/app/service/design_batch/pipeline/segmentation.py index cba3446..aa05c0d 100644 --- a/app/service/design_batch/pipeline/segmentation.py +++ b/app/service/design_batch/pipeline/segmentation.py @@ -51,19 +51,19 @@ class Segmentation: file_path = f"seg_cache/{image_id}.npy" try: np.save(file_path, seg_result) - logger.info(f"保存成功 :{os.path.abspath(file_path)}") + logger.debug(f"保存成功 :{os.path.abspath(file_path)}") except Exception as e: logger.error(f"保存失败: {e}") @staticmethod def load_seg_result(image_id): file_path = f"seg_cache/{image_id}.npy" - logger.info(f"load seg file name is :{SEG_CACHE_PATH}{image_id}.npy") + # logger.info(f"load seg file name is :{SEG_CACHE_PATH}{image_id}.npy") try: seg_result = np.load(file_path) return True, seg_result except FileNotFoundError: - logger.warning("文件不存在") + # logger.warning("文件不存在") return False, None except Exception as e: logger.error(f"加载失败: {e}") diff --git a/app/service/design_batch/service.py b/app/service/design_batch/service.py index ca6908e..e9fb814 100644 --- a/app/service/design_batch/service.py +++ b/app/service/design_batch/service.py @@ -5,7 +5,7 @@ from app.service.design_batch.utils.MQ import publish_status async def start_design_batch_generate(data, file): - generate_clothes_task = batch_design.delay(json.loads(file.decode())['objects'], data.total, data.tasks_id) + generate_clothes_task = batch_design.delay(json.loads(file.decode())['objects'], data.tasks_id, data.file_name) print(generate_clothes_task) publish_status(data.tasks_id, "0/100", "") return {"task_id": data.tasks_id} diff --git a/app/service/design_batch/test.py b/app/service/design_batch/test.py index 6b94bc6..2e74cc9 100644 --- a/app/service/design_batch/test.py +++ b/app/service/design_batch/test.py @@ -157,6 +157,6 @@ if __name__ == '__main__': ], "process_id": "83" } - task_id = 1 + task_id = 10086 json_name = "test.json" batch_design.delay(data['objects'], task_id, json_name) diff --git a/app/service/design_batch/utils/MQ.py b/app/service/design_batch/utils/MQ.py index 50e98c2..1b64bf3 100644 --- a/app/service/design_batch/utils/MQ.py +++ b/app/service/design_batch/utils/MQ.py @@ -2,9 +2,11 @@ import json import pika +from app.core.config import RABBITMQ_PARAMS + def publish_status(task_id, progress, result): - connection = pika.BlockingConnection(pika.ConnectionParameters('10.1.2.213')) + connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS)) channel = connection.channel() channel.queue_declare(queue='DesignBatch', durable=True) message = {'task_id': task_id, 'progress': progress, "result": result} diff --git a/app/service/design_batch/utils/save_json.py b/app/service/design_batch/utils/save_json.py index 9acd916..915df69 100644 --- a/app/service/design_batch/utils/save_json.py +++ b/app/service/design_batch/utils/save_json.py @@ -1,13 +1,21 @@ +import io import json import logging +import os logger = logging.getLogger() def oss_upload_json(oss_client, json_data, object_name): try: - with open(f"app/service/design_batch/response_json/{object_name}", 'w') as file: + save_dir = os.path.join(os.getcwd(), "service/design_batch", "response_data") + if not os.path.exists(save_dir): + os.makedirs(save_dir) + # 处理文件 + file_path = os.path.join(save_dir, object_name) + with open(file_path, 'w') as file: json.dump(json_data, file, indent=4) - oss_client.fput_object("test", object_name, f"app/service/design_batch/response_json/{object_name}") + json_bytes = json.dumps(json_data).encode('utf-8') + oss_client.put_object("test", object_name, io.BytesIO(json_bytes), length=len(json_bytes), content_type="application/json") except Exception as e: logger.warning(str(e)) diff --git a/app/service/design_fast/design_generate.py b/app/service/design_fast/design_generate.py index f4012cf..2f7fa93 100644 --- a/app/service/design_fast/design_generate.py +++ b/app/service/design_fast/design_generate.py @@ -139,6 +139,7 @@ def design_generate(request_data): @RunTime def design_generate_v2(request_data): objects_data = request_data.dict()['objects'] + request_id = request_data.requestId threads = [] def process_object(step, object): @@ -146,7 +147,7 @@ def design_generate_v2(request_data): items_response = { 'layers': [], 'objectSign': object['objectSign'] if 'objectSign' in object.keys() else "", - 'requestId': object['requestId'] if 'requestId' in object.keys() else "" + 'requestId': request_id } if basic['single_overall'] == "overall": item_results = [] @@ -197,9 +198,8 @@ def design_generate_v2(request_data): 'pattern_image_url': item_result['pattern_image_url'] if 'pattern_image_url' in item_result.keys() else None, }) items_response['synthesis_url'] = synthesis_single(item_result['front_image'], item_result['back_image']) - # 发送结果给java端 - url = "https://3998-117-143-125-51.ngrok-free.app/api/third/party/receiveDesignResults" + url = JAVA_STREAM_API_URL headers = { 'Accept': "*/*", 'Accept-Encoding': "gzip, deflate, br", @@ -207,11 +207,11 @@ def design_generate_v2(request_data): 'Connection': "keep-alive", 'Content-Type': "application/json" } + # logger.info(items_response) response = post_request(url, json_data=items_response, headers=headers) if response: # 打印结果 logger.info(response.text) - logger.info(items_response) for step, object in enumerate(objects_data): t = threading.Thread(target=process_object, args=(step, object)) diff --git a/app/service/design_fast/pipeline/segmentation.py b/app/service/design_fast/pipeline/segmentation.py index 9cc53a3..0c9c51e 100644 --- a/app/service/design_fast/pipeline/segmentation.py +++ b/app/service/design_fast/pipeline/segmentation.py @@ -66,19 +66,19 @@ class Segmentation: file_path = f"{SEG_CACHE_PATH}{image_id}.npy" try: np.save(file_path, seg_result) - logger.info(f"保存成功 :{os.path.abspath(file_path)}") + logger.debug(f"保存成功 :{os.path.abspath(file_path)}") except Exception as e: logger.error(f"保存失败: {e}") @staticmethod def load_seg_result(image_id): file_path = f"{SEG_CACHE_PATH}{image_id}.npy" - logger.info(f"load seg file name is :{SEG_CACHE_PATH}{image_id}.npy") + # logger.info(f"load seg file name is :{SEG_CACHE_PATH}{image_id}.npy") try: seg_result = np.load(file_path) return True, seg_result except FileNotFoundError: - logger.warning("文件不存在") + # logger.warning("文件不存在") return False, None except Exception as e: logger.error(f"加载失败: {e}") diff --git a/app/service/design_fast/utils/design_ensemble.py b/app/service/design_fast/utils/design_ensemble.py index bfc50c6..8eef4f2 100644 --- a/app/service/design_fast/utils/design_ensemble.py +++ b/app/service/design_fast/utils/design_ensemble.py @@ -25,6 +25,7 @@ from app.core.config import * def keypoint_preprocess(img_path): img = mmcv.imread(img_path) + img = cv2.copyMakeBorder(img, 25, 25, 25, 25, cv2.BORDER_CONSTANT, value=[255, 255, 255]) img_scale = (256, 256) h, w = img.shape[:2] img = cv2.resize(img, img_scale) @@ -62,7 +63,11 @@ def keypoint_postprocess(output, scale_factor): scale_matrix = np.diag(scale_factor) nan = np.isinf(scale_matrix) scale_matrix[nan] = 0 - return np.ceil(np.dot(segment_result, scale_matrix) * 4) + # 应用缩放因子 + scaled_result = np.ceil(np.dot(segment_result, scale_matrix) * 4) + # 补偿边框偏移 + compensated_result = scaled_result - 25 + return compensated_result """ diff --git a/app/service/design_pre_processing/service.py b/app/service/design_pre_processing/service.py index 16ca870..636360c 100644 --- a/app/service/design_pre_processing/service.py +++ b/app/service/design_pre_processing/service.py @@ -266,7 +266,7 @@ class DesignPreprocessing: seg_result = np.load(file_path) return True, seg_result except FileNotFoundError: - logging.info("文件不存在") + # logging.info("文件不存在") return False, None except Exception as e: logging.warning(f"加载失败: {e}") @@ -277,7 +277,7 @@ class DesignPreprocessing: file_path = f"{SEG_CACHE_PATH}{image_id}.npy" try: np.save(file_path, seg_result) - logging.info(f"保存成功,{os.path.abspath(file_path)}") + logging.debug(f"保存成功,{os.path.abspath(file_path)}") except Exception as e: logging.warning(f"保存失败: {e}") diff --git a/app/service/generate_image/service_generate_multi_view.py b/app/service/generate_image/service_generate_multi_view.py new file mode 100644 index 0000000..c930ab2 --- /dev/null +++ b/app/service/generate_image/service_generate_multi_view.py @@ -0,0 +1,126 @@ +#!/usr/bin/env python +# -*- coding: UTF-8 -*- +""" +@Project :trinity_client +@File :service_att_recognition.py +@Author :周成融 +@Date :2023/7/26 12:01:05 +@detail : +""" +import json +import logging +import time + +import numpy as np +import redis +import tritonclient.grpc as grpcclient + +from app.core.config import * +from app.schemas.generate_image import GenerateMultiViewModel +from app.service.generate_image.utils.upload_sd_image import upload_png_sd +from app.service.utils.oss_client import oss_get_image + +logger = logging.getLogger() + + +class GenerateMultiView: + def __init__(self, request_data): + if DEBUG is False: + self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS)) + self.channel = self.connection.channel() + # self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS)) + # self.channel = self.connection.channel() + # self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) + self.grpc_client = grpcclient.InferenceServerClient(url=GMV_MODEL_URL) + + self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True) + self.image = self.get_image(request_data.image_url) + self.tasks_id = request_data.tasks_id + self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:] + self.generate_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'image_url': ''} + self.redis_client.set(self.tasks_id, json.dumps(self.generate_data)) + self.redis_client.expire(self.tasks_id, 600) + + def get_image(self, image_url): + try: + image = oss_get_image(bucket=image_url.split('/')[0], object_name=image_url[image_url.find('/') + 1:], data_type="PIL") + return image + except Exception as e: + logger.error(e) + + def callback(self, result, error): + if error: + self.generate_data['status'] = "FAILURE" + self.generate_data['message'] = str(error) + # self.generate_data['data'] = str(error) + self.redis_client.set(self.tasks_id, json.dumps(self.generate_data)) + else: + # pil图像转成numpy数组 + images = result.as_numpy("generated_image") + # for id, img in enumerate(images): + # cv2.imwrite(f"{id}.png", img) + # image_url = "" + image_url = upload_png_sd(images[6], user_id=self.user_id, category="multi_view", file_name=f"{self.tasks_id}.png") + # logger.info(f"upload image SUCCESS : {image_url}") + self.generate_data['status'] = "SUCCESS" + self.generate_data['message'] = "success" + self.generate_data['image_url'] = str(image_url) + self.redis_client.set(self.tasks_id, json.dumps(self.generate_data)) + + def read_tasks_status(self): + status_data = self.redis_client.get(self.tasks_id) + return json.loads(status_data), status_data + + def get_result(self): + try: + images = [np.array(self.image).astype(np.uint8)] * 1 + + image_obj = np.array(images, dtype=np.uint8) + + input_image = grpcclient.InferInput("input_image", image_obj.shape, "UINT8") + + input_image.set_data_from_numpy(image_obj) + + inputs = [input_image] + ctx = self.grpc_client.async_infer(model_name=GMV_MODEL_NAME, inputs=inputs, callback=self.callback) + + time_out = 600 + generate_data = None + while time_out > 0: + generate_data, _ = self.read_tasks_status() + if generate_data['status'] in ["REVOKED", "FAILURE"]: + ctx.cancel() + break + elif generate_data['status'] == "SUCCESS": + break + time_out -= 1 + time.sleep(0.1) + return generate_data + except Exception as e: + self.generate_data['status'] = "FAILURE" + self.generate_data['message'] = str(e) + self.redis_client.set(self.tasks_id, json.dumps(self.generate_data)) + raise Exception(str(e)) + finally: + dict_generate_data, str_generate_data = self.read_tasks_status() + if DEBUG is False: + self.channel.basic_publish(exchange='', routing_key=GMV_RABBITMQ_QUEUES, body=str_generate_data) + # self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=str_generate_data) + logger.info(f" [x] Sent {json.dumps(dict_generate_data, indent=4)}") + + +def infer_cancel(tasks_id): + redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True) + data = {'tasks_id': tasks_id, 'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'} + generate_data = json.dumps(data) + redis_client.set(tasks_id, generate_data) + return data + + +if __name__ == '__main__': + rd = GenerateMultiViewModel( + tasks_id="123-89", + image_url="aida-sys-image/images/female/outwear/0628000123.jpg", + ) + server = GenerateMultiView(rd) + print(server.get_result()) diff --git a/app/service/generate_image/service_generate_product_image.py b/app/service/generate_image/service_generate_product_image.py index 1c20a13..287a983 100644 --- a/app/service/generate_image/service_generate_product_image.py +++ b/app/service/generate_image/service_generate_product_image.py @@ -1,4 +1,196 @@ -#!/usr/bin/env python +# #!/usr/bin/env python +# # -*- coding: UTF-8 -*- +# """ +# @Project :trinity_client +# @File :service_att_recognition.py +# @Author :周成融 +# @Date :2023/7/26 12:01:05 +# @detail : +# """ +# import json +# import logging +# import time +# +# import cv2 +# import numpy as np +# import redis +# import tritonclient.grpc as grpcclient +# from PIL import Image +# from tritonclient.utils import np_to_triton_dtype +# +# from app.core.config import * +# from app.schemas.generate_image import GenerateProductImageModel +# from app.service.generate_image.utils.upload_sd_image import upload_SDXL_image +# from app.service.utils.oss_client import oss_get_image +# +# logger = logging.getLogger() +# +# +# class GenerateProductImage: +# def __init__(self, request_data): +# if DEBUG is False: +# self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS)) +# self.channel = self.connection.channel() +# # self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS)) +# # self.channel = self.connection.channel() +# # self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) +# self.grpc_client = grpcclient.InferenceServerClient(url=GPI_MODEL_URL) +# self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True) +# self.category = "product_image" +# self.image_strength = request_data.image_strength +# self.batch_size = 1 +# self.product_type = request_data.product_type +# self.prompt = request_data.prompt +# self.image, self.image_size, self.left, self.top = pre_processing_image(request_data.image_url) +# self.tasks_id = request_data.tasks_id +# self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:] +# self.gen_product_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'image_url': ''} +# self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data)) +# self.redis_client.expire(self.tasks_id, 600) +# +# def callback(self, result, error): +# if error: +# self.gen_product_data['status'] = "FAILURE" +# self.gen_product_data['message'] = str(error) +# self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data)) +# else: +# # pil图像转成numpy数组 +# image = result.as_numpy("generated_inpaint_image") +# image_result = Image.fromarray(np.squeeze(image.astype(np.uint8))).resize(self.image_size) +# cropped_image = post_processing_image(image_result, self.left, self.top) +# image_url = upload_SDXL_image(cropped_image, user_id=self.user_id, category=f"{self.category}", file_name=f"{self.tasks_id}.png") +# self.gen_product_data['status'] = "SUCCESS" +# self.gen_product_data['message'] = "success" +# self.gen_product_data['image_url'] = str(image_url) +# self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data)) +# +# def read_tasks_status(self): +# status_data = self.redis_client.get(self.tasks_id) +# return json.loads(status_data), status_data +# +# def get_result(self): +# try: +# prompts = [self.prompt] * self.batch_size +# self.image = cv2.cvtColor(self.image, cv2.COLOR_BGR2RGB) +# self.image = cv2.resize(self.image, (1024, 1024)) +# images = [self.image.astype(np.uint8)] * self.batch_size +# +# if self.product_type == "single": +# text_obj = np.array(prompts, dtype="object").reshape(-1, 1) +# image_obj = np.array(images, dtype=np.uint8).reshape((-1, 1024, 1024, 3)) +# image_strength_obj = np.array(self.image_strength, dtype=np.float32).reshape(-1, 1) +# else: +# text_obj = np.array(prompts, dtype="object").reshape((1)) +# image_obj = np.array(images, dtype=np.uint8).reshape((1024, 1024, 3)) +# image_strength_obj = np.array(self.image_strength, dtype=np.float32).reshape((1)) +# +# # 假设 prompts、images 和 self.image_strength 已经定义 +# +# input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype)) +# input_image = grpcclient.InferInput("input_image", image_obj.shape, "UINT8") +# input_image_strength = grpcclient.InferInput("image_strength", image_strength_obj.shape, np_to_triton_dtype(image_strength_obj.dtype)) +# +# input_text.set_data_from_numpy(text_obj) +# input_image.set_data_from_numpy(image_obj) +# input_image_strength.set_data_from_numpy(image_strength_obj) +# +# inputs = [input_text, input_image, input_image_strength] +# +# if self.product_type == "single": +# ctx = self.grpc_client.async_infer(model_name="stable_diffusion_xl_cnet_inpaint", inputs=inputs, callback=self.callback) +# else: +# ctx = self.grpc_client.async_infer(model_name=GPI_MODEL_NAME_OVERALL, inputs=inputs, callback=self.callback) +# +# time_out = 600 +# while time_out > 0: +# gen_product_data, _ = self.read_tasks_status() +# if gen_product_data['status'] in ["REVOKED", "FAILURE"]: +# ctx.cancel() +# break +# elif gen_product_data['status'] == "SUCCESS": +# break +# time_out -= 1 +# time.sleep(0.1) +# gen_product_data, _ = self.read_tasks_status() +# return gen_product_data +# except Exception as e: +# self.gen_product_data['status'] = "FAILURE" +# self.gen_product_data['message'] = str(e) +# self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data)) +# raise Exception(str(e)) +# finally: +# dict_gen_product_data, str_gen_product_data = self.read_tasks_status() +# if DEBUG is False: +# self.channel.basic_publish(exchange='', routing_key=GPI_RABBITMQ_QUEUES, body=str_gen_product_data) +# logger.info(f" [x] Sent to: {GPI_RABBITMQ_QUEUES} data:@@@@ {json.dumps(dict_gen_product_data, indent=4)}") +# +# +# def infer_cancel(tasks_id): +# redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True) +# data = {'tasks_id': tasks_id, 'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'} +# gen_product_data = json.dumps(data) +# redis_client.set(tasks_id, gen_product_data) +# return data +# +# +# def pre_processing_image(image_url): +# image = oss_get_image(bucket=image_url.split('/')[0], object_name=image_url[image_url.find('/') + 1:], data_type="PIL") +# # resize 原图至1024*1024 +# image = image.resize((int(1024 / image.height * image.width), 1024)) +# +# # 原始图片的尺寸 +# width, height = image.size +# +# new_height, new_width = 1024, 1024 +# # 创建一个新的画布,大小为添加 padding 后的尺寸,并设置为白色背景 +# pad_image = Image.new('RGBA', (new_width, new_height), (0, 0, 0, 0)) +# +# # 将原始图片粘贴到新的画布中心 +# left = (new_width - width) // 2 +# top = (new_height - height) // 2 +# pad_image.paste(image, (left, top)) +# +# # 将画布 resize 成宽度 1024,长度 1024 +# resized_image = pad_image.resize((1024, 1024)) +# image_size = (1024, 1024) +# +# if resized_image.mode in ('RGBA', 'LA') or (resized_image.mode == 'P' and 'transparency' in resized_image.info): +# # 创建白色背景 +# background = Image.new("RGB", image_size, (255, 255, 255)) +# # 将图片粘贴到白色背景上 +# background.paste(resized_image, mask=resized_image.split()[3]) +# image = np.array(background) +# image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) +# return image, image_size, left, top +# +# +# def post_processing_image(image, left, top): +# resized_image = image.resize((int(image.width * (768 / image.height)), 768)) +# # 计算裁剪的坐标 +# left = (resized_image.width - 512) // 2 +# upper = 0 +# right = left + 512 +# lower = 768 +# +# # 进行裁剪 +# cropped_image = resized_image.crop((left, upper, right, lower)) +# return cropped_image +# +# +# if __name__ == '__main__': +# rd = GenerateProductImageModel( +# tasks_id="123-89", +# # prompt="", +# image_strength=0.7, +# prompt="The best quality, masterpiece,outwear, 8K realistic, HUD", +# image_url="aida-results/result_53381ada-ac64-11ef-ae9d-0242ac150002.png", +# product_type="overall" +# ) +# server = GenerateProductImage(rd) +# print(server.get_result()) + +# 旧版product +# !/usr/bin/env python # -*- coding: UTF-8 -*- """ @Project :trinity_client @@ -34,14 +226,14 @@ class GenerateProductImage: # self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS)) # self.channel = self.connection.channel() # self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) - self.grpc_client = grpcclient.InferenceServerClient(url=GPI_MODEL_URL) + self.grpc_client = grpcclient.InferenceServerClient(url="10.1.1.243:18001") self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True) self.category = "product_image" self.image_strength = request_data.image_strength self.batch_size = 1 self.product_type = request_data.product_type self.prompt = request_data.prompt - self.image, self.image_size, self.left, self.top = pre_processing_image(request_data.image_url) + self.image = pre_processing_image(request_data.image_url) self.tasks_id = request_data.tasks_id self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:] self.gen_product_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'image_url': ''} @@ -55,10 +247,13 @@ class GenerateProductImage: self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data)) else: # pil图像转成numpy数组 - image = result.as_numpy("generated_inpaint_image") - image_result = Image.fromarray(np.squeeze(image.astype(np.uint8))).resize(self.image_size) - cropped_image = post_processing_image(image_result, self.left, self.top) - image_url = upload_SDXL_image(cropped_image, user_id=self.user_id, category=f"{self.category}", file_name=f"{self.tasks_id}.png") + if self.product_type == "single": + image = result.as_numpy("generated_cnet_image") + else: + image = result.as_numpy("generated_inpaint_image") + image_result = Image.fromarray(np.squeeze(image.astype(np.uint8))) + # cropped_image = post_processing_image(image_result, self.left, self.top) + image_url = upload_SDXL_image(image_result, user_id=self.user_id, category=f"{self.category}", file_name=f"{self.tasks_id}.png") self.gen_product_data['status'] = "SUCCESS" self.gen_product_data['message'] = "success" self.gen_product_data['image_url'] = str(image_url) @@ -71,17 +266,18 @@ class GenerateProductImage: def get_result(self): try: prompts = [self.prompt] * self.batch_size - self.image = cv2.cvtColor(self.image, cv2.COLOR_BGR2RGB) - self.image = cv2.resize(self.image, (1024, 1024)) + + self.image = cv2.cvtColor(self.image, cv2.COLOR_RGBA2RGB) + # self.image = cv2.resize(self.image, (1024, 1024)) images = [self.image.astype(np.uint8)] * self.batch_size if self.product_type == "single": - text_obj = np.array(prompts, dtype="object").reshape(-1, 1) - image_obj = np.array(images, dtype=np.uint8).reshape((-1, 1024, 1024, 3)) - image_strength_obj = np.array(self.image_strength, dtype=np.float32).reshape(-1, 1) + text_obj = np.array(prompts, dtype="object").reshape((-1, 1)) + image_obj = np.array(images, dtype=np.uint8).reshape((-1, 768, 512, 3)) + image_strength_obj = np.array(self.image_strength, dtype=np.float32).reshape((-1, 1)) else: text_obj = np.array(prompts, dtype="object").reshape((1)) - image_obj = np.array(images, dtype=np.uint8).reshape((1024, 1024, 3)) + image_obj = np.array(images, dtype=np.uint8).reshape((768, 512, 3)) image_strength_obj = np.array(self.image_strength, dtype=np.float32).reshape((1)) # 假设 prompts、images 和 self.image_strength 已经定义 @@ -97,9 +293,9 @@ class GenerateProductImage: inputs = [input_text, input_image, input_image_strength] if self.product_type == "single": - ctx = self.grpc_client.async_infer(model_name="stable_diffusion_xl_cnet_inpaint", inputs=inputs, callback=self.callback) + ctx = self.grpc_client.async_infer(model_name="stable_diffusion_1_5_cnet", inputs=inputs, callback=self.callback) else: - ctx = self.grpc_client.async_infer(model_name=GPI_MODEL_NAME_OVERALL, inputs=inputs, callback=self.callback) + ctx = self.grpc_client.async_infer(model_name="diffusion_ensemble_all", inputs=inputs, callback=self.callback) time_out = 600 while time_out > 0: @@ -135,33 +331,26 @@ def infer_cancel(tasks_id): def pre_processing_image(image_url): image = oss_get_image(bucket=image_url.split('/')[0], object_name=image_url[image_url.find('/') + 1:], data_type="PIL") - # resize 原图至1024*1024 - image = image.resize((int(1024 / image.height * image.width), 1024)) - - # 原始图片的尺寸 + # 调整图片高度为768像素,保持宽高比 width, height = image.size + new_height = 768 + new_width = int(width * (new_height / height)) + resized_image = image.resize((new_width, new_height)) - new_height, new_width = 1024, 1024 - # 创建一个新的画布,大小为添加 padding 后的尺寸,并设置为白色背景 - pad_image = Image.new('RGBA', (new_width, new_height), (0, 0, 0, 0)) + # 创建一个512x768的透明图片 + result_image = Image.new("RGBA", (512, 768), (255, 255, 255, 255)) - # 将原始图片粘贴到新的画布中心 - left = (new_width - width) // 2 - top = (new_height - height) // 2 - pad_image.paste(image, (left, top)) + # 计算需要粘贴的位置,使图片居中 + x_offset = (512 - new_width) // 2 + y_offset = 0 - # 将画布 resize 成宽度 1024,长度 1024 - resized_image = pad_image.resize((1024, 1024)) - image_size = (1024, 1024) + # 将调整大小后的图片粘贴到透明图片上 + result_image.paste(resized_image, (x_offset, y_offset), mask=resized_image.split()[3]) - if resized_image.mode in ('RGBA', 'LA') or (resized_image.mode == 'P' and 'transparency' in resized_image.info): - # 创建白色背景 - background = Image.new("RGB", image_size, (255, 255, 255)) - # 将图片粘贴到白色背景上 - background.paste(resized_image, mask=resized_image.split()[3]) - image = np.array(background) - image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) - return image, image_size, left, top + image = np.array(result_image) + + # image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA) + return image def post_processing_image(image, left, top): @@ -182,9 +371,9 @@ if __name__ == '__main__': tasks_id="123-89", # prompt="", image_strength=0.7, - prompt="The best quality, masterpiece,outwear, 8K realistic, HUD", - image_url="aida-results/result_53381ada-ac64-11ef-ae9d-0242ac150002.png", - product_type="overall" + prompt="The best quality, masterpiece, real image.,high quality clothing details,8K realistic,HDR", + image_url="aida-results/result_40c7924e-e220-11ef-8ea2-0242ac150003.png", + product_type="single" ) server = GenerateProductImage(rd) print(server.get_result()) diff --git a/app/service/prompt_generation/chatgpt_for_translation.py b/app/service/prompt_generation/chatgpt_for_translation.py index e541781..e668500 100644 --- a/app/service/prompt_generation/chatgpt_for_translation.py +++ b/app/service/prompt_generation/chatgpt_for_translation.py @@ -8,6 +8,7 @@ from requests import RequestException from retry import retry from app.core.config import QWEN_API_KEY +from app.service.chat_robot.script.service.CallQWen import get_language logger = logging.getLogger(__name__) @@ -93,7 +94,13 @@ def get_translation_from_llama3(text): # prompt = f"System: {prefix_for_llama}\nUser:[{text}]" - # 创建请求的负载 + # 先获取用户输入文本的语言 + language = get_language(text) + + if 'English' in language: + return text + + # 创建请求的负载 translator是自定义的翻译模型 payload = { "model": "translator", "prompt": f"[{text}]", @@ -117,6 +124,26 @@ def get_translation_from_llama3(text): print(response.text) +# 在llama3中创建一个翻译模型 +# def create_model_with_llama(text): +# url = "http://localhost:11434/api/create" +# # url = "http://10.1.1.240:1143/api/generate" +# +# # prompt = f"System: {prefix_for_llama}\nUser:[{text}]" +# +# # 创建翻译器的配置文件 +# payload = { +# "model": "translator", +# "modelfile": "FROM llama3\nSYSTEM Translate everything within the brackets [] into English." +# "Never translate or modify any English input." +# "The input must be fully translated into coherent English sentences." +# } +# +# # 将负载转换为 JSON 格式 +# headers = {'Content-Type': 'application/json'} +# response = requests.post(url, data=json.dumps(payload), headers=headers) + + def main(): """Main function""" text = get_translation_from_llama3("[火焰]") diff --git a/app/service/utils/decorator.py b/app/service/utils/decorator.py index 3e86182..c0164ab 100644 --- a/app/service/utils/decorator.py +++ b/app/service/utils/decorator.py @@ -7,9 +7,9 @@ def RunTime(func): t1 = time.time() res = func(*args, **kwargs) t2 = time.time() - # if t2 - t1 > 0.05: - # logging.info(f"function:【{func.__name__}】,runtime:【{str(t2 - t1)}】s") - logging.info(f"function:【{func.__name__}】,runtime:【{str(t2 - t1)}】s") + if t2 - t1 > 0.05: + logging.info(f"function:【{func.__name__}】,runtime:【{str(t2 - t1)}】s") + # logging.info(f"function:【{func.__name__}】,runtime:【{str(t2 - t1)}】s") return res return wrapper @@ -22,7 +22,8 @@ def ClassCallRunTime(func): end_time = time.time() execution_time = end_time - start_time class_name = args[0].__class__.__name__ # 获取类名 - print(f"class name: {class_name} , run time is : {execution_time} s") + if execution_time > 0.05: + logging.info(f"class name: {class_name} , run time is : {execution_time} s") return result return wrapper diff --git a/app/service/utils/new_oss_client.py b/app/service/utils/new_oss_client.py index 4b3cbb1..7939333 100644 --- a/app/service/utils/new_oss_client.py +++ b/app/service/utils/new_oss_client.py @@ -82,7 +82,7 @@ if __name__ == '__main__': # url = "aida-users/89/sketchboard/female/Dress/e6724ab7-8d3f-4677-abe0-c3e42ab7af85.jpeg" # url = "aida-users/87/print/956614a2-7e75-4fbe-9ed0-c1831e37a2c9-4-87.png" # url = "aida-users/89/single_logo/123-89.png" - url = "aida-users/89/test/123-89.png" + url = "aida-users/89/123-89.png" # url = "aida-collection-element/12148/Sketchboard/95ea577b-305b-4a62-b30a-39c0dd3ddb3f.png" read_type = "2"