diff --git a/app/api/api_attribute_retrieve.py b/app/api/api_attribute_retrieve.py index 5c15efe..2a5ad4d 100644 --- a/app/api/api_attribute_retrieve.py +++ b/app/api/api_attribute_retrieve.py @@ -35,13 +35,13 @@ def attribute_recognition(request_item: list[AttributeRecognitionModel]): """ try: for item in request_item: - logger.info(f"attribute_recognition request item is : @@@@@@:{json.dumps(item.dict())}") + logger.debug(f"attribute_recognition request item is : @@@@@@:{json.dumps(item.dict())}") if DEBUG: service = AttributeRecognition(const=local_debug_const, request_data=request_item) else: service = AttributeRecognition(const=const, request_data=request_item) data = service.get_result() - logger.info(f"attribute_recognition response @@@@@@:{json.dumps(data)}") + logger.debug(f"attribute_recognition response @@@@@@:{json.dumps(data)}") except Exception as e: logger.warning(f"attribute_recognition Run Exception @@@@@@:{e}") raise HTTPException(status_code=404, detail=str(e)) diff --git a/app/api/api_generate_image.py b/app/api/api_generate_image.py index 53790a3..a37bec3 100644 --- a/app/api/api_generate_image.py +++ b/app/api/api_generate_image.py @@ -3,9 +3,10 @@ import logging from fastapi import APIRouter, BackgroundTasks, HTTPException -from app.schemas.generate_image import GenerateImageModel, GenerateProductImageModel, GenerateSingleLogoImageModel, GenerateRelightImageModel +from app.schemas.generate_image import GenerateImageModel, GenerateProductImageModel, GenerateSingleLogoImageModel, GenerateRelightImageModel, GenerateMultiViewModel from app.schemas.response_template import ResponseModel from app.service.generate_image.service_generate_image import GenerateImage, infer_cancel as generate_image_infer_cancel +from app.service.generate_image.service_generate_multi_view import GenerateMultiView, infer_cancel as generate_multi_view_cancel from app.service.generate_image.service_generate_product_image import GenerateProductImage, infer_cancel as generate_product_image_cancel from app.service.generate_image.service_generate_relight_image import GenerateRelightImage, infer_cancel as generate_relight_image_cancel from app.service.generate_image.service_generate_single_logo import GenerateSingleLogoImage, infer_cancel as generate_single_logo_cancel @@ -61,6 +62,44 @@ def generate_image(tasks_id: str): return ResponseModel(data=data['data']) +'''multi view''' + + +@router.post("/generate_multi_view") +def generate_multi_view(request_item: GenerateMultiViewModel, background_tasks: BackgroundTasks): + """ + 创建一个具有以下参数的请求体: + - **tasks_id**: 任务id 用于取消生成任务和获取生成结果 + - **image_url**: 前视角图的输入,minio或S3 url 地址 + + 示例参数: + { + "tasks_id": "123-89", + "image_url": "aida-collection-element/87/Printboard/842c09cf-7297-42d9-9e6e-9c17d4a13cb5.jpg" + } + """ + try: + logger.info(f"generate_multi_view request item is : @@@@@@:{json.dumps(request_item.dict())}") + service = GenerateMultiView(request_item) + background_tasks.add_task(service.get_result) + except Exception as e: + logger.warning(f"generate_multi_view Run Exception @@@@@@:{e}") + raise HTTPException(status_code=404, detail=str(e)) + return ResponseModel() + + +@router.get("/generate_multi_view_cancel/{tasks_id}") +def generate_multi_view(tasks_id: str): + try: + logger.info(f"generate_cancel request item is : @@@@@@:{tasks_id}") + data = generate_multi_view_cancel(tasks_id) + logger.info(f"generate_cancel response @@@@@@:{data}") + except Exception as e: + logger.warning(f"generate_cancel Run Exception @@@@@@:{e}") + raise HTTPException(status_code=404, detail=str(e)) + return ResponseModel(data=data['data']) + + '''single logo''' diff --git a/app/core/config.py b/app/core/config.py index 1d11a0b..7456912 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -109,6 +109,12 @@ FAST_GI_MODEL_NAME = 'stable_diffusion_xl' GI_MODEL_URL = '10.1.1.240:10061' GI_MODEL_NAME = 'flux' +GMV_MODEL_URL = '10.1.1.243:10081' +GMV_MODEL_NAME = 'multi_view' + +GMV_RABBITMQ_QUEUES = os.getenv("GMV_RABBITMQ_QUEUES", f"GenerateMultiView{RABBITMQ_ENV}") + + GI_MINIO_BUCKET = "aida-users" GI_RABBITMQ_QUEUES = os.getenv("GI_RABBITMQ_QUEUES", f"GenerateImage{RABBITMQ_ENV}") GI_SYS_IMAGE_URL = "aida-sys-image/generate_image/white_image.jpg" diff --git a/app/schemas/generate_image.py b/app/schemas/generate_image.py index 11e295f..7181418 100644 --- a/app/schemas/generate_image.py +++ b/app/schemas/generate_image.py @@ -1,6 +1,11 @@ from pydantic import BaseModel +class GenerateMultiViewModel(BaseModel): + tasks_id: str + image_url: str + + class GenerateImageModel(BaseModel): tasks_id: str prompt: str diff --git a/app/service/design/items/pipelines/segmentation.py b/app/service/design/items/pipelines/segmentation.py index 19eb1fd..7ed43e5 100644 --- a/app/service/design/items/pipelines/segmentation.py +++ b/app/service/design/items/pipelines/segmentation.py @@ -53,7 +53,7 @@ class Segmentation(object): file_path = f"{SEG_CACHE_PATH}{image_id}.npy" try: np.save(file_path, seg_result) - logger.info(f"保存成功 :{os.path.abspath(file_path)}") + logger.debug(f"保存成功 :{os.path.abspath(file_path)}") except Exception as e: logger.error(f"保存失败: {e}") @@ -64,7 +64,7 @@ class Segmentation(object): seg_result = np.load(file_path) return True, seg_result except FileNotFoundError: - logger.warning("文件不存在") + # logger.warning("文件不存在") return False, None except Exception as e: logger.error(f"加载失败: {e}") diff --git a/app/service/design_batch/pipeline/segmentation.py b/app/service/design_batch/pipeline/segmentation.py index cba3446..aa05c0d 100644 --- a/app/service/design_batch/pipeline/segmentation.py +++ b/app/service/design_batch/pipeline/segmentation.py @@ -51,19 +51,19 @@ class Segmentation: file_path = f"seg_cache/{image_id}.npy" try: np.save(file_path, seg_result) - logger.info(f"保存成功 :{os.path.abspath(file_path)}") + logger.debug(f"保存成功 :{os.path.abspath(file_path)}") except Exception as e: logger.error(f"保存失败: {e}") @staticmethod def load_seg_result(image_id): file_path = f"seg_cache/{image_id}.npy" - logger.info(f"load seg file name is :{SEG_CACHE_PATH}{image_id}.npy") + # logger.info(f"load seg file name is :{SEG_CACHE_PATH}{image_id}.npy") try: seg_result = np.load(file_path) return True, seg_result except FileNotFoundError: - logger.warning("文件不存在") + # logger.warning("文件不存在") return False, None except Exception as e: logger.error(f"加载失败: {e}") diff --git a/app/service/design_fast/design_generate.py b/app/service/design_fast/design_generate.py index 4fc0726..2f7fa93 100644 --- a/app/service/design_fast/design_generate.py +++ b/app/service/design_fast/design_generate.py @@ -207,7 +207,7 @@ def design_generate_v2(request_data): 'Connection': "keep-alive", 'Content-Type': "application/json" } - logger.info(items_response) + # logger.info(items_response) response = post_request(url, json_data=items_response, headers=headers) if response: # 打印结果 diff --git a/app/service/design_fast/pipeline/segmentation.py b/app/service/design_fast/pipeline/segmentation.py index ebf02b4..0c9c51e 100644 --- a/app/service/design_fast/pipeline/segmentation.py +++ b/app/service/design_fast/pipeline/segmentation.py @@ -36,11 +36,11 @@ class Segmentation: # preview 过模型 不缓存 if "preview_submit" in result.keys() and result['preview_submit'] == "preview": # 推理获得seg 结果 - seg_result = get_seg_result(result["image_id"], result['image'])[0] + seg_result = get_seg_result(result["image_id"], result['image']) # submit 过模型 缓存 elif "preview_submit" in result.keys() and result['preview_submit'] == "submit": # 推理获得seg 结果 - seg_result = get_seg_result(result["image_id"], result['image'])[0] + seg_result = get_seg_result(result["image_id"], result['image']) self.save_seg_result(seg_result, result['image_id']) # null 正常流程 加载本地缓存 无缓存则过模型 else: @@ -49,14 +49,14 @@ class Segmentation: # 判断缓存和实际图片size是否相同 if not _ or result["image"].shape[:2] != seg_result.shape: # 推理获得seg 结果 - seg_result = get_seg_result(result["image_id"], result['image'])[0] + seg_result = get_seg_result(result["image_id"], result['image']) self.save_seg_result(seg_result, result['image_id']) result['seg_result'] = seg_result # 处理前片后片 - temp_front = seg_result == 1.0 + temp_front = seg_result == 1 result['front_mask'] = (255 * (temp_front + 0).astype(np.uint8)) - temp_back = seg_result == 2.0 + temp_back = seg_result == 2 result['back_mask'] = (255 * (temp_back + 0).astype(np.uint8)) result['mask'] = result['front_mask'] + result['back_mask'] return result @@ -66,19 +66,19 @@ class Segmentation: file_path = f"{SEG_CACHE_PATH}{image_id}.npy" try: np.save(file_path, seg_result) - logger.info(f"保存成功 :{os.path.abspath(file_path)}") + logger.debug(f"保存成功 :{os.path.abspath(file_path)}") except Exception as e: logger.error(f"保存失败: {e}") @staticmethod def load_seg_result(image_id): file_path = f"{SEG_CACHE_PATH}{image_id}.npy" - logger.info(f"load seg file name is :{SEG_CACHE_PATH}{image_id}.npy") + # logger.info(f"load seg file name is :{SEG_CACHE_PATH}{image_id}.npy") try: seg_result = np.load(file_path) return True, seg_result except FileNotFoundError: - logger.warning("文件不存在") + # logger.warning("文件不存在") return False, None except Exception as e: logger.error(f"加载失败: {e}") diff --git a/app/service/design_fast/utils/design_ensemble.py b/app/service/design_fast/utils/design_ensemble.py index 267ea00..bfc50c6 100644 --- a/app/service/design_fast/utils/design_ensemble.py +++ b/app/service/design_fast/utils/design_ensemble.py @@ -13,7 +13,6 @@ import cv2 import mmcv import numpy as np import torch -import torch.nn.functional as F import tritonclient.http as httpclient from app.core.config import * @@ -85,7 +84,10 @@ def seg_preprocess(img_path): if ori_shape != (img_scale_w, img_scale_h): # mmcv.imresize(img, img_scale_h, img_scale_w) # 老代码 引以为戒!哈哈哈~ h和w写反了 img = cv2.resize(img, (img_scale_h, img_scale_w)) - # img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True) + + # 扩充25的白边 + img = cv2.copyMakeBorder(img, 25, 25, 25, 25, cv2.BORDER_CONSTANT, value=[255, 255, 255]) + img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True) preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0) return preprocessed_img, ori_shape @@ -114,9 +116,9 @@ def get_seg_result(image_id, image): # no cache def seg_postprocess(image_id, output, ori_shape): - seg_logit = F.interpolate(torch.tensor(output).float(), size=ori_shape, scale_factor=None, mode='bilinear', align_corners=False) - seg_pred = seg_logit.cpu().numpy() - return seg_pred[0] + seg_logit = cv2.resize(output[0][0].astype(np.uint8), (ori_shape[1] + 50, ori_shape[0] + 50)) + seg_logit = seg_logit[25: - 25, 25: - 25] + return seg_logit def key_point_show(image_path, key_point_result=None): diff --git a/app/service/design_pre_processing/service.py b/app/service/design_pre_processing/service.py index 16ca870..636360c 100644 --- a/app/service/design_pre_processing/service.py +++ b/app/service/design_pre_processing/service.py @@ -266,7 +266,7 @@ class DesignPreprocessing: seg_result = np.load(file_path) return True, seg_result except FileNotFoundError: - logging.info("文件不存在") + # logging.info("文件不存在") return False, None except Exception as e: logging.warning(f"加载失败: {e}") @@ -277,7 +277,7 @@ class DesignPreprocessing: file_path = f"{SEG_CACHE_PATH}{image_id}.npy" try: np.save(file_path, seg_result) - logging.info(f"保存成功,{os.path.abspath(file_path)}") + logging.debug(f"保存成功,{os.path.abspath(file_path)}") except Exception as e: logging.warning(f"保存失败: {e}") diff --git a/app/service/generate_image/service_generate_multi_view.py b/app/service/generate_image/service_generate_multi_view.py new file mode 100644 index 0000000..c930ab2 --- /dev/null +++ b/app/service/generate_image/service_generate_multi_view.py @@ -0,0 +1,126 @@ +#!/usr/bin/env python +# -*- coding: UTF-8 -*- +""" +@Project :trinity_client +@File :service_att_recognition.py +@Author :周成融 +@Date :2023/7/26 12:01:05 +@detail : +""" +import json +import logging +import time + +import numpy as np +import redis +import tritonclient.grpc as grpcclient + +from app.core.config import * +from app.schemas.generate_image import GenerateMultiViewModel +from app.service.generate_image.utils.upload_sd_image import upload_png_sd +from app.service.utils.oss_client import oss_get_image + +logger = logging.getLogger() + + +class GenerateMultiView: + def __init__(self, request_data): + if DEBUG is False: + self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS)) + self.channel = self.connection.channel() + # self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS)) + # self.channel = self.connection.channel() + # self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) + self.grpc_client = grpcclient.InferenceServerClient(url=GMV_MODEL_URL) + + self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True) + self.image = self.get_image(request_data.image_url) + self.tasks_id = request_data.tasks_id + self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:] + self.generate_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'image_url': ''} + self.redis_client.set(self.tasks_id, json.dumps(self.generate_data)) + self.redis_client.expire(self.tasks_id, 600) + + def get_image(self, image_url): + try: + image = oss_get_image(bucket=image_url.split('/')[0], object_name=image_url[image_url.find('/') + 1:], data_type="PIL") + return image + except Exception as e: + logger.error(e) + + def callback(self, result, error): + if error: + self.generate_data['status'] = "FAILURE" + self.generate_data['message'] = str(error) + # self.generate_data['data'] = str(error) + self.redis_client.set(self.tasks_id, json.dumps(self.generate_data)) + else: + # pil图像转成numpy数组 + images = result.as_numpy("generated_image") + # for id, img in enumerate(images): + # cv2.imwrite(f"{id}.png", img) + # image_url = "" + image_url = upload_png_sd(images[6], user_id=self.user_id, category="multi_view", file_name=f"{self.tasks_id}.png") + # logger.info(f"upload image SUCCESS : {image_url}") + self.generate_data['status'] = "SUCCESS" + self.generate_data['message'] = "success" + self.generate_data['image_url'] = str(image_url) + self.redis_client.set(self.tasks_id, json.dumps(self.generate_data)) + + def read_tasks_status(self): + status_data = self.redis_client.get(self.tasks_id) + return json.loads(status_data), status_data + + def get_result(self): + try: + images = [np.array(self.image).astype(np.uint8)] * 1 + + image_obj = np.array(images, dtype=np.uint8) + + input_image = grpcclient.InferInput("input_image", image_obj.shape, "UINT8") + + input_image.set_data_from_numpy(image_obj) + + inputs = [input_image] + ctx = self.grpc_client.async_infer(model_name=GMV_MODEL_NAME, inputs=inputs, callback=self.callback) + + time_out = 600 + generate_data = None + while time_out > 0: + generate_data, _ = self.read_tasks_status() + if generate_data['status'] in ["REVOKED", "FAILURE"]: + ctx.cancel() + break + elif generate_data['status'] == "SUCCESS": + break + time_out -= 1 + time.sleep(0.1) + return generate_data + except Exception as e: + self.generate_data['status'] = "FAILURE" + self.generate_data['message'] = str(e) + self.redis_client.set(self.tasks_id, json.dumps(self.generate_data)) + raise Exception(str(e)) + finally: + dict_generate_data, str_generate_data = self.read_tasks_status() + if DEBUG is False: + self.channel.basic_publish(exchange='', routing_key=GMV_RABBITMQ_QUEUES, body=str_generate_data) + # self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=str_generate_data) + logger.info(f" [x] Sent {json.dumps(dict_generate_data, indent=4)}") + + +def infer_cancel(tasks_id): + redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True) + data = {'tasks_id': tasks_id, 'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'} + generate_data = json.dumps(data) + redis_client.set(tasks_id, generate_data) + return data + + +if __name__ == '__main__': + rd = GenerateMultiViewModel( + tasks_id="123-89", + image_url="aida-sys-image/images/female/outwear/0628000123.jpg", + ) + server = GenerateMultiView(rd) + print(server.get_result()) diff --git a/app/service/utils/decorator.py b/app/service/utils/decorator.py index 3e86182..c0164ab 100644 --- a/app/service/utils/decorator.py +++ b/app/service/utils/decorator.py @@ -7,9 +7,9 @@ def RunTime(func): t1 = time.time() res = func(*args, **kwargs) t2 = time.time() - # if t2 - t1 > 0.05: - # logging.info(f"function:【{func.__name__}】,runtime:【{str(t2 - t1)}】s") - logging.info(f"function:【{func.__name__}】,runtime:【{str(t2 - t1)}】s") + if t2 - t1 > 0.05: + logging.info(f"function:【{func.__name__}】,runtime:【{str(t2 - t1)}】s") + # logging.info(f"function:【{func.__name__}】,runtime:【{str(t2 - t1)}】s") return res return wrapper @@ -22,7 +22,8 @@ def ClassCallRunTime(func): end_time = time.time() execution_time = end_time - start_time class_name = args[0].__class__.__name__ # 获取类名 - print(f"class name: {class_name} , run time is : {execution_time} s") + if execution_time > 0.05: + logging.info(f"class name: {class_name} , run time is : {execution_time} s") return result return wrapper diff --git a/app/service/utils/new_oss_client.py b/app/service/utils/new_oss_client.py index 4b3cbb1..7939333 100644 --- a/app/service/utils/new_oss_client.py +++ b/app/service/utils/new_oss_client.py @@ -82,7 +82,7 @@ if __name__ == '__main__': # url = "aida-users/89/sketchboard/female/Dress/e6724ab7-8d3f-4677-abe0-c3e42ab7af85.jpeg" # url = "aida-users/87/print/956614a2-7e75-4fbe-9ed0-c1831e37a2c9-4-87.png" # url = "aida-users/89/single_logo/123-89.png" - url = "aida-users/89/test/123-89.png" + url = "aida-users/89/123-89.png" # url = "aida-collection-element/12148/Sketchboard/95ea577b-305b-4a62-b30a-39c0dd3ddb3f.png" read_type = "2"