diff --git a/app/api/api_generate_image.py b/app/api/api_generate_image.py index 53790a3..a37bec3 100644 --- a/app/api/api_generate_image.py +++ b/app/api/api_generate_image.py @@ -3,9 +3,10 @@ import logging from fastapi import APIRouter, BackgroundTasks, HTTPException -from app.schemas.generate_image import GenerateImageModel, GenerateProductImageModel, GenerateSingleLogoImageModel, GenerateRelightImageModel +from app.schemas.generate_image import GenerateImageModel, GenerateProductImageModel, GenerateSingleLogoImageModel, GenerateRelightImageModel, GenerateMultiViewModel from app.schemas.response_template import ResponseModel from app.service.generate_image.service_generate_image import GenerateImage, infer_cancel as generate_image_infer_cancel +from app.service.generate_image.service_generate_multi_view import GenerateMultiView, infer_cancel as generate_multi_view_cancel from app.service.generate_image.service_generate_product_image import GenerateProductImage, infer_cancel as generate_product_image_cancel from app.service.generate_image.service_generate_relight_image import GenerateRelightImage, infer_cancel as generate_relight_image_cancel from app.service.generate_image.service_generate_single_logo import GenerateSingleLogoImage, infer_cancel as generate_single_logo_cancel @@ -61,6 +62,44 @@ def generate_image(tasks_id: str): return ResponseModel(data=data['data']) +'''multi view''' + + +@router.post("/generate_multi_view") +def generate_multi_view(request_item: GenerateMultiViewModel, background_tasks: BackgroundTasks): + """ + 创建一个具有以下参数的请求体: + - **tasks_id**: 任务id 用于取消生成任务和获取生成结果 + - **image_url**: 前视角图的输入,minio或S3 url 地址 + + 示例参数: + { + "tasks_id": "123-89", + "image_url": "aida-collection-element/87/Printboard/842c09cf-7297-42d9-9e6e-9c17d4a13cb5.jpg" + } + """ + try: + logger.info(f"generate_multi_view request item is : @@@@@@:{json.dumps(request_item.dict())}") + service = GenerateMultiView(request_item) + background_tasks.add_task(service.get_result) + except Exception as e: + logger.warning(f"generate_multi_view Run Exception @@@@@@:{e}") + raise HTTPException(status_code=404, detail=str(e)) + return ResponseModel() + + +@router.get("/generate_multi_view_cancel/{tasks_id}") +def generate_multi_view(tasks_id: str): + try: + logger.info(f"generate_cancel request item is : @@@@@@:{tasks_id}") + data = generate_multi_view_cancel(tasks_id) + logger.info(f"generate_cancel response @@@@@@:{data}") + except Exception as e: + logger.warning(f"generate_cancel Run Exception @@@@@@:{e}") + raise HTTPException(status_code=404, detail=str(e)) + return ResponseModel(data=data['data']) + + '''single logo''' diff --git a/app/core/config.py b/app/core/config.py index 1d11a0b..7456912 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -109,6 +109,12 @@ FAST_GI_MODEL_NAME = 'stable_diffusion_xl' GI_MODEL_URL = '10.1.1.240:10061' GI_MODEL_NAME = 'flux' +GMV_MODEL_URL = '10.1.1.243:10081' +GMV_MODEL_NAME = 'multi_view' + +GMV_RABBITMQ_QUEUES = os.getenv("GMV_RABBITMQ_QUEUES", f"GenerateMultiView{RABBITMQ_ENV}") + + GI_MINIO_BUCKET = "aida-users" GI_RABBITMQ_QUEUES = os.getenv("GI_RABBITMQ_QUEUES", f"GenerateImage{RABBITMQ_ENV}") GI_SYS_IMAGE_URL = "aida-sys-image/generate_image/white_image.jpg" diff --git a/app/schemas/generate_image.py b/app/schemas/generate_image.py index 11e295f..7181418 100644 --- a/app/schemas/generate_image.py +++ b/app/schemas/generate_image.py @@ -1,6 +1,11 @@ from pydantic import BaseModel +class GenerateMultiViewModel(BaseModel): + tasks_id: str + image_url: str + + class GenerateImageModel(BaseModel): tasks_id: str prompt: str diff --git a/app/service/generate_image/service_generate_multi_view.py b/app/service/generate_image/service_generate_multi_view.py new file mode 100644 index 0000000..c930ab2 --- /dev/null +++ b/app/service/generate_image/service_generate_multi_view.py @@ -0,0 +1,126 @@ +#!/usr/bin/env python +# -*- coding: UTF-8 -*- +""" +@Project :trinity_client +@File :service_att_recognition.py +@Author :周成融 +@Date :2023/7/26 12:01:05 +@detail : +""" +import json +import logging +import time + +import numpy as np +import redis +import tritonclient.grpc as grpcclient + +from app.core.config import * +from app.schemas.generate_image import GenerateMultiViewModel +from app.service.generate_image.utils.upload_sd_image import upload_png_sd +from app.service.utils.oss_client import oss_get_image + +logger = logging.getLogger() + + +class GenerateMultiView: + def __init__(self, request_data): + if DEBUG is False: + self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS)) + self.channel = self.connection.channel() + # self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS)) + # self.channel = self.connection.channel() + # self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) + self.grpc_client = grpcclient.InferenceServerClient(url=GMV_MODEL_URL) + + self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True) + self.image = self.get_image(request_data.image_url) + self.tasks_id = request_data.tasks_id + self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:] + self.generate_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'image_url': ''} + self.redis_client.set(self.tasks_id, json.dumps(self.generate_data)) + self.redis_client.expire(self.tasks_id, 600) + + def get_image(self, image_url): + try: + image = oss_get_image(bucket=image_url.split('/')[0], object_name=image_url[image_url.find('/') + 1:], data_type="PIL") + return image + except Exception as e: + logger.error(e) + + def callback(self, result, error): + if error: + self.generate_data['status'] = "FAILURE" + self.generate_data['message'] = str(error) + # self.generate_data['data'] = str(error) + self.redis_client.set(self.tasks_id, json.dumps(self.generate_data)) + else: + # pil图像转成numpy数组 + images = result.as_numpy("generated_image") + # for id, img in enumerate(images): + # cv2.imwrite(f"{id}.png", img) + # image_url = "" + image_url = upload_png_sd(images[6], user_id=self.user_id, category="multi_view", file_name=f"{self.tasks_id}.png") + # logger.info(f"upload image SUCCESS : {image_url}") + self.generate_data['status'] = "SUCCESS" + self.generate_data['message'] = "success" + self.generate_data['image_url'] = str(image_url) + self.redis_client.set(self.tasks_id, json.dumps(self.generate_data)) + + def read_tasks_status(self): + status_data = self.redis_client.get(self.tasks_id) + return json.loads(status_data), status_data + + def get_result(self): + try: + images = [np.array(self.image).astype(np.uint8)] * 1 + + image_obj = np.array(images, dtype=np.uint8) + + input_image = grpcclient.InferInput("input_image", image_obj.shape, "UINT8") + + input_image.set_data_from_numpy(image_obj) + + inputs = [input_image] + ctx = self.grpc_client.async_infer(model_name=GMV_MODEL_NAME, inputs=inputs, callback=self.callback) + + time_out = 600 + generate_data = None + while time_out > 0: + generate_data, _ = self.read_tasks_status() + if generate_data['status'] in ["REVOKED", "FAILURE"]: + ctx.cancel() + break + elif generate_data['status'] == "SUCCESS": + break + time_out -= 1 + time.sleep(0.1) + return generate_data + except Exception as e: + self.generate_data['status'] = "FAILURE" + self.generate_data['message'] = str(e) + self.redis_client.set(self.tasks_id, json.dumps(self.generate_data)) + raise Exception(str(e)) + finally: + dict_generate_data, str_generate_data = self.read_tasks_status() + if DEBUG is False: + self.channel.basic_publish(exchange='', routing_key=GMV_RABBITMQ_QUEUES, body=str_generate_data) + # self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=str_generate_data) + logger.info(f" [x] Sent {json.dumps(dict_generate_data, indent=4)}") + + +def infer_cancel(tasks_id): + redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True) + data = {'tasks_id': tasks_id, 'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'} + generate_data = json.dumps(data) + redis_client.set(tasks_id, generate_data) + return data + + +if __name__ == '__main__': + rd = GenerateMultiViewModel( + tasks_id="123-89", + image_url="aida-sys-image/images/female/outwear/0628000123.jpg", + ) + server = GenerateMultiView(rd) + print(server.get_result()) diff --git a/app/service/utils/new_oss_client.py b/app/service/utils/new_oss_client.py index 4b3cbb1..7939333 100644 --- a/app/service/utils/new_oss_client.py +++ b/app/service/utils/new_oss_client.py @@ -82,7 +82,7 @@ if __name__ == '__main__': # url = "aida-users/89/sketchboard/female/Dress/e6724ab7-8d3f-4677-abe0-c3e42ab7af85.jpeg" # url = "aida-users/87/print/956614a2-7e75-4fbe-9ed0-c1831e37a2c9-4-87.png" # url = "aida-users/89/single_logo/123-89.png" - url = "aida-users/89/test/123-89.png" + url = "aida-users/89/123-89.png" # url = "aida-collection-element/12148/Sketchboard/95ea577b-305b-4a62-b30a-39c0dd3ddb3f.png" read_type = "2"