feat(新功能): sketch 多视角图生成功能接口
fix(修复bug): docs(文档变更): refactor(重构): test(增加测试):
This commit is contained in:
@@ -3,9 +3,10 @@ import logging
|
|||||||
|
|
||||||
from fastapi import APIRouter, BackgroundTasks, HTTPException
|
from fastapi import APIRouter, BackgroundTasks, HTTPException
|
||||||
|
|
||||||
from app.schemas.generate_image import GenerateImageModel, GenerateProductImageModel, GenerateSingleLogoImageModel, GenerateRelightImageModel
|
from app.schemas.generate_image import GenerateImageModel, GenerateProductImageModel, GenerateSingleLogoImageModel, GenerateRelightImageModel, GenerateMultiViewModel
|
||||||
from app.schemas.response_template import ResponseModel
|
from app.schemas.response_template import ResponseModel
|
||||||
from app.service.generate_image.service_generate_image import GenerateImage, infer_cancel as generate_image_infer_cancel
|
from app.service.generate_image.service_generate_image import GenerateImage, infer_cancel as generate_image_infer_cancel
|
||||||
|
from app.service.generate_image.service_generate_multi_view import GenerateMultiView, infer_cancel as generate_multi_view_cancel
|
||||||
from app.service.generate_image.service_generate_product_image import GenerateProductImage, infer_cancel as generate_product_image_cancel
|
from app.service.generate_image.service_generate_product_image import GenerateProductImage, infer_cancel as generate_product_image_cancel
|
||||||
from app.service.generate_image.service_generate_relight_image import GenerateRelightImage, infer_cancel as generate_relight_image_cancel
|
from app.service.generate_image.service_generate_relight_image import GenerateRelightImage, infer_cancel as generate_relight_image_cancel
|
||||||
from app.service.generate_image.service_generate_single_logo import GenerateSingleLogoImage, infer_cancel as generate_single_logo_cancel
|
from app.service.generate_image.service_generate_single_logo import GenerateSingleLogoImage, infer_cancel as generate_single_logo_cancel
|
||||||
@@ -61,6 +62,44 @@ def generate_image(tasks_id: str):
|
|||||||
return ResponseModel(data=data['data'])
|
return ResponseModel(data=data['data'])
|
||||||
|
|
||||||
|
|
||||||
|
'''multi view'''
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/generate_multi_view")
|
||||||
|
def generate_multi_view(request_item: GenerateMultiViewModel, background_tasks: BackgroundTasks):
|
||||||
|
"""
|
||||||
|
创建一个具有以下参数的请求体:
|
||||||
|
- **tasks_id**: 任务id 用于取消生成任务和获取生成结果
|
||||||
|
- **image_url**: 前视角图的输入,minio或S3 url 地址
|
||||||
|
|
||||||
|
示例参数:
|
||||||
|
{
|
||||||
|
"tasks_id": "123-89",
|
||||||
|
"image_url": "aida-collection-element/87/Printboard/842c09cf-7297-42d9-9e6e-9c17d4a13cb5.jpg"
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
logger.info(f"generate_multi_view request item is : @@@@@@:{json.dumps(request_item.dict())}")
|
||||||
|
service = GenerateMultiView(request_item)
|
||||||
|
background_tasks.add_task(service.get_result)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"generate_multi_view Run Exception @@@@@@:{e}")
|
||||||
|
raise HTTPException(status_code=404, detail=str(e))
|
||||||
|
return ResponseModel()
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/generate_multi_view_cancel/{tasks_id}")
|
||||||
|
def generate_multi_view(tasks_id: str):
|
||||||
|
try:
|
||||||
|
logger.info(f"generate_cancel request item is : @@@@@@:{tasks_id}")
|
||||||
|
data = generate_multi_view_cancel(tasks_id)
|
||||||
|
logger.info(f"generate_cancel response @@@@@@:{data}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"generate_cancel Run Exception @@@@@@:{e}")
|
||||||
|
raise HTTPException(status_code=404, detail=str(e))
|
||||||
|
return ResponseModel(data=data['data'])
|
||||||
|
|
||||||
|
|
||||||
'''single logo'''
|
'''single logo'''
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -109,6 +109,12 @@ FAST_GI_MODEL_NAME = 'stable_diffusion_xl'
|
|||||||
GI_MODEL_URL = '10.1.1.240:10061'
|
GI_MODEL_URL = '10.1.1.240:10061'
|
||||||
GI_MODEL_NAME = 'flux'
|
GI_MODEL_NAME = 'flux'
|
||||||
|
|
||||||
|
GMV_MODEL_URL = '10.1.1.243:10081'
|
||||||
|
GMV_MODEL_NAME = 'multi_view'
|
||||||
|
|
||||||
|
GMV_RABBITMQ_QUEUES = os.getenv("GMV_RABBITMQ_QUEUES", f"GenerateMultiView{RABBITMQ_ENV}")
|
||||||
|
|
||||||
|
|
||||||
GI_MINIO_BUCKET = "aida-users"
|
GI_MINIO_BUCKET = "aida-users"
|
||||||
GI_RABBITMQ_QUEUES = os.getenv("GI_RABBITMQ_QUEUES", f"GenerateImage{RABBITMQ_ENV}")
|
GI_RABBITMQ_QUEUES = os.getenv("GI_RABBITMQ_QUEUES", f"GenerateImage{RABBITMQ_ENV}")
|
||||||
GI_SYS_IMAGE_URL = "aida-sys-image/generate_image/white_image.jpg"
|
GI_SYS_IMAGE_URL = "aida-sys-image/generate_image/white_image.jpg"
|
||||||
|
|||||||
@@ -1,6 +1,11 @@
|
|||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class GenerateMultiViewModel(BaseModel):
|
||||||
|
tasks_id: str
|
||||||
|
image_url: str
|
||||||
|
|
||||||
|
|
||||||
class GenerateImageModel(BaseModel):
|
class GenerateImageModel(BaseModel):
|
||||||
tasks_id: str
|
tasks_id: str
|
||||||
prompt: str
|
prompt: str
|
||||||
|
|||||||
126
app/service/generate_image/service_generate_multi_view.py
Normal file
126
app/service/generate_image/service_generate_multi_view.py
Normal file
@@ -0,0 +1,126 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
# -*- coding: UTF-8 -*-
|
||||||
|
"""
|
||||||
|
@Project :trinity_client
|
||||||
|
@File :service_att_recognition.py
|
||||||
|
@Author :周成融
|
||||||
|
@Date :2023/7/26 12:01:05
|
||||||
|
@detail :
|
||||||
|
"""
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import redis
|
||||||
|
import tritonclient.grpc as grpcclient
|
||||||
|
|
||||||
|
from app.core.config import *
|
||||||
|
from app.schemas.generate_image import GenerateMultiViewModel
|
||||||
|
from app.service.generate_image.utils.upload_sd_image import upload_png_sd
|
||||||
|
from app.service.utils.oss_client import oss_get_image
|
||||||
|
|
||||||
|
logger = logging.getLogger()
|
||||||
|
|
||||||
|
|
||||||
|
class GenerateMultiView:
|
||||||
|
def __init__(self, request_data):
|
||||||
|
if DEBUG is False:
|
||||||
|
self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
|
||||||
|
self.channel = self.connection.channel()
|
||||||
|
# self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
|
||||||
|
# self.channel = self.connection.channel()
|
||||||
|
# self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
||||||
|
self.grpc_client = grpcclient.InferenceServerClient(url=GMV_MODEL_URL)
|
||||||
|
|
||||||
|
self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
|
||||||
|
self.image = self.get_image(request_data.image_url)
|
||||||
|
self.tasks_id = request_data.tasks_id
|
||||||
|
self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:]
|
||||||
|
self.generate_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'image_url': ''}
|
||||||
|
self.redis_client.set(self.tasks_id, json.dumps(self.generate_data))
|
||||||
|
self.redis_client.expire(self.tasks_id, 600)
|
||||||
|
|
||||||
|
def get_image(self, image_url):
|
||||||
|
try:
|
||||||
|
image = oss_get_image(bucket=image_url.split('/')[0], object_name=image_url[image_url.find('/') + 1:], data_type="PIL")
|
||||||
|
return image
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(e)
|
||||||
|
|
||||||
|
def callback(self, result, error):
|
||||||
|
if error:
|
||||||
|
self.generate_data['status'] = "FAILURE"
|
||||||
|
self.generate_data['message'] = str(error)
|
||||||
|
# self.generate_data['data'] = str(error)
|
||||||
|
self.redis_client.set(self.tasks_id, json.dumps(self.generate_data))
|
||||||
|
else:
|
||||||
|
# pil图像转成numpy数组
|
||||||
|
images = result.as_numpy("generated_image")
|
||||||
|
# for id, img in enumerate(images):
|
||||||
|
# cv2.imwrite(f"{id}.png", img)
|
||||||
|
# image_url = ""
|
||||||
|
image_url = upload_png_sd(images[6], user_id=self.user_id, category="multi_view", file_name=f"{self.tasks_id}.png")
|
||||||
|
# logger.info(f"upload image SUCCESS : {image_url}")
|
||||||
|
self.generate_data['status'] = "SUCCESS"
|
||||||
|
self.generate_data['message'] = "success"
|
||||||
|
self.generate_data['image_url'] = str(image_url)
|
||||||
|
self.redis_client.set(self.tasks_id, json.dumps(self.generate_data))
|
||||||
|
|
||||||
|
def read_tasks_status(self):
|
||||||
|
status_data = self.redis_client.get(self.tasks_id)
|
||||||
|
return json.loads(status_data), status_data
|
||||||
|
|
||||||
|
def get_result(self):
|
||||||
|
try:
|
||||||
|
images = [np.array(self.image).astype(np.uint8)] * 1
|
||||||
|
|
||||||
|
image_obj = np.array(images, dtype=np.uint8)
|
||||||
|
|
||||||
|
input_image = grpcclient.InferInput("input_image", image_obj.shape, "UINT8")
|
||||||
|
|
||||||
|
input_image.set_data_from_numpy(image_obj)
|
||||||
|
|
||||||
|
inputs = [input_image]
|
||||||
|
ctx = self.grpc_client.async_infer(model_name=GMV_MODEL_NAME, inputs=inputs, callback=self.callback)
|
||||||
|
|
||||||
|
time_out = 600
|
||||||
|
generate_data = None
|
||||||
|
while time_out > 0:
|
||||||
|
generate_data, _ = self.read_tasks_status()
|
||||||
|
if generate_data['status'] in ["REVOKED", "FAILURE"]:
|
||||||
|
ctx.cancel()
|
||||||
|
break
|
||||||
|
elif generate_data['status'] == "SUCCESS":
|
||||||
|
break
|
||||||
|
time_out -= 1
|
||||||
|
time.sleep(0.1)
|
||||||
|
return generate_data
|
||||||
|
except Exception as e:
|
||||||
|
self.generate_data['status'] = "FAILURE"
|
||||||
|
self.generate_data['message'] = str(e)
|
||||||
|
self.redis_client.set(self.tasks_id, json.dumps(self.generate_data))
|
||||||
|
raise Exception(str(e))
|
||||||
|
finally:
|
||||||
|
dict_generate_data, str_generate_data = self.read_tasks_status()
|
||||||
|
if DEBUG is False:
|
||||||
|
self.channel.basic_publish(exchange='', routing_key=GMV_RABBITMQ_QUEUES, body=str_generate_data)
|
||||||
|
# self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=str_generate_data)
|
||||||
|
logger.info(f" [x] Sent {json.dumps(dict_generate_data, indent=4)}")
|
||||||
|
|
||||||
|
|
||||||
|
def infer_cancel(tasks_id):
|
||||||
|
redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
|
||||||
|
data = {'tasks_id': tasks_id, 'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'}
|
||||||
|
generate_data = json.dumps(data)
|
||||||
|
redis_client.set(tasks_id, generate_data)
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
rd = GenerateMultiViewModel(
|
||||||
|
tasks_id="123-89",
|
||||||
|
image_url="aida-sys-image/images/female/outwear/0628000123.jpg",
|
||||||
|
)
|
||||||
|
server = GenerateMultiView(rd)
|
||||||
|
print(server.get_result())
|
||||||
@@ -82,7 +82,7 @@ if __name__ == '__main__':
|
|||||||
# url = "aida-users/89/sketchboard/female/Dress/e6724ab7-8d3f-4677-abe0-c3e42ab7af85.jpeg"
|
# url = "aida-users/89/sketchboard/female/Dress/e6724ab7-8d3f-4677-abe0-c3e42ab7af85.jpeg"
|
||||||
# url = "aida-users/87/print/956614a2-7e75-4fbe-9ed0-c1831e37a2c9-4-87.png"
|
# url = "aida-users/87/print/956614a2-7e75-4fbe-9ed0-c1831e37a2c9-4-87.png"
|
||||||
# url = "aida-users/89/single_logo/123-89.png"
|
# url = "aida-users/89/single_logo/123-89.png"
|
||||||
url = "aida-users/89/test/123-89.png"
|
url = "aida-users/89/123-89.png"
|
||||||
|
|
||||||
# url = "aida-collection-element/12148/Sketchboard/95ea577b-305b-4a62-b30a-39c0dd3ddb3f.png"
|
# url = "aida-collection-element/12148/Sketchboard/95ea577b-305b-4a62-b30a-39c0dd3ddb3f.png"
|
||||||
read_type = "2"
|
read_type = "2"
|
||||||
|
|||||||
Reference in New Issue
Block a user