Merge remote-tracking branch 'origin/develop' into develop

This commit is contained in:
2025-01-14 09:41:46 +08:00
13 changed files with 208 additions and 29 deletions

View File

@@ -35,13 +35,13 @@ def attribute_recognition(request_item: list[AttributeRecognitionModel]):
""" """
try: try:
for item in request_item: for item in request_item:
logger.info(f"attribute_recognition request item is : @@@@@@:{json.dumps(item.dict())}") logger.debug(f"attribute_recognition request item is : @@@@@@:{json.dumps(item.dict())}")
if DEBUG: if DEBUG:
service = AttributeRecognition(const=local_debug_const, request_data=request_item) service = AttributeRecognition(const=local_debug_const, request_data=request_item)
else: else:
service = AttributeRecognition(const=const, request_data=request_item) service = AttributeRecognition(const=const, request_data=request_item)
data = service.get_result() data = service.get_result()
logger.info(f"attribute_recognition response @@@@@@:{json.dumps(data)}") logger.debug(f"attribute_recognition response @@@@@@:{json.dumps(data)}")
except Exception as e: except Exception as e:
logger.warning(f"attribute_recognition Run Exception @@@@@@:{e}") logger.warning(f"attribute_recognition Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e)) raise HTTPException(status_code=404, detail=str(e))

View File

@@ -3,9 +3,10 @@ import logging
from fastapi import APIRouter, BackgroundTasks, HTTPException from fastapi import APIRouter, BackgroundTasks, HTTPException
from app.schemas.generate_image import GenerateImageModel, GenerateProductImageModel, GenerateSingleLogoImageModel, GenerateRelightImageModel from app.schemas.generate_image import GenerateImageModel, GenerateProductImageModel, GenerateSingleLogoImageModel, GenerateRelightImageModel, GenerateMultiViewModel
from app.schemas.response_template import ResponseModel from app.schemas.response_template import ResponseModel
from app.service.generate_image.service_generate_image import GenerateImage, infer_cancel as generate_image_infer_cancel from app.service.generate_image.service_generate_image import GenerateImage, infer_cancel as generate_image_infer_cancel
from app.service.generate_image.service_generate_multi_view import GenerateMultiView, infer_cancel as generate_multi_view_cancel
from app.service.generate_image.service_generate_product_image import GenerateProductImage, infer_cancel as generate_product_image_cancel from app.service.generate_image.service_generate_product_image import GenerateProductImage, infer_cancel as generate_product_image_cancel
from app.service.generate_image.service_generate_relight_image import GenerateRelightImage, infer_cancel as generate_relight_image_cancel from app.service.generate_image.service_generate_relight_image import GenerateRelightImage, infer_cancel as generate_relight_image_cancel
from app.service.generate_image.service_generate_single_logo import GenerateSingleLogoImage, infer_cancel as generate_single_logo_cancel from app.service.generate_image.service_generate_single_logo import GenerateSingleLogoImage, infer_cancel as generate_single_logo_cancel
@@ -61,6 +62,44 @@ def generate_image(tasks_id: str):
return ResponseModel(data=data['data']) return ResponseModel(data=data['data'])
'''multi view'''
@router.post("/generate_multi_view")
def generate_multi_view(request_item: GenerateMultiViewModel, background_tasks: BackgroundTasks):
"""
创建一个具有以下参数的请求体:
- **tasks_id**: 任务id 用于取消生成任务和获取生成结果
- **image_url**: 前视角图的输入minio或S3 url 地址
示例参数:
{
"tasks_id": "123-89",
"image_url": "aida-collection-element/87/Printboard/842c09cf-7297-42d9-9e6e-9c17d4a13cb5.jpg"
}
"""
try:
logger.info(f"generate_multi_view request item is : @@@@@@:{json.dumps(request_item.dict())}")
service = GenerateMultiView(request_item)
background_tasks.add_task(service.get_result)
except Exception as e:
logger.warning(f"generate_multi_view Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))
return ResponseModel()
@router.get("/generate_multi_view_cancel/{tasks_id}")
def generate_multi_view(tasks_id: str):
try:
logger.info(f"generate_cancel request item is : @@@@@@:{tasks_id}")
data = generate_multi_view_cancel(tasks_id)
logger.info(f"generate_cancel response @@@@@@:{data}")
except Exception as e:
logger.warning(f"generate_cancel Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))
return ResponseModel(data=data['data'])
'''single logo''' '''single logo'''

View File

@@ -109,6 +109,12 @@ FAST_GI_MODEL_NAME = 'stable_diffusion_xl'
GI_MODEL_URL = '10.1.1.240:10061' GI_MODEL_URL = '10.1.1.240:10061'
GI_MODEL_NAME = 'flux' GI_MODEL_NAME = 'flux'
GMV_MODEL_URL = '10.1.1.243:10081'
GMV_MODEL_NAME = 'multi_view'
GMV_RABBITMQ_QUEUES = os.getenv("GMV_RABBITMQ_QUEUES", f"GenerateMultiView{RABBITMQ_ENV}")
GI_MINIO_BUCKET = "aida-users" GI_MINIO_BUCKET = "aida-users"
GI_RABBITMQ_QUEUES = os.getenv("GI_RABBITMQ_QUEUES", f"GenerateImage{RABBITMQ_ENV}") GI_RABBITMQ_QUEUES = os.getenv("GI_RABBITMQ_QUEUES", f"GenerateImage{RABBITMQ_ENV}")
GI_SYS_IMAGE_URL = "aida-sys-image/generate_image/white_image.jpg" GI_SYS_IMAGE_URL = "aida-sys-image/generate_image/white_image.jpg"

View File

@@ -1,6 +1,11 @@
from pydantic import BaseModel from pydantic import BaseModel
class GenerateMultiViewModel(BaseModel):
tasks_id: str
image_url: str
class GenerateImageModel(BaseModel): class GenerateImageModel(BaseModel):
tasks_id: str tasks_id: str
prompt: str prompt: str

View File

@@ -53,7 +53,7 @@ class Segmentation(object):
file_path = f"{SEG_CACHE_PATH}{image_id}.npy" file_path = f"{SEG_CACHE_PATH}{image_id}.npy"
try: try:
np.save(file_path, seg_result) np.save(file_path, seg_result)
logger.info(f"保存成功 {os.path.abspath(file_path)}") logger.debug(f"保存成功 {os.path.abspath(file_path)}")
except Exception as e: except Exception as e:
logger.error(f"保存失败: {e}") logger.error(f"保存失败: {e}")
@@ -64,7 +64,7 @@ class Segmentation(object):
seg_result = np.load(file_path) seg_result = np.load(file_path)
return True, seg_result return True, seg_result
except FileNotFoundError: except FileNotFoundError:
logger.warning("文件不存在") # logger.warning("文件不存在")
return False, None return False, None
except Exception as e: except Exception as e:
logger.error(f"加载失败: {e}") logger.error(f"加载失败: {e}")

View File

@@ -51,19 +51,19 @@ class Segmentation:
file_path = f"seg_cache/{image_id}.npy" file_path = f"seg_cache/{image_id}.npy"
try: try:
np.save(file_path, seg_result) np.save(file_path, seg_result)
logger.info(f"保存成功 {os.path.abspath(file_path)}") logger.debug(f"保存成功 {os.path.abspath(file_path)}")
except Exception as e: except Exception as e:
logger.error(f"保存失败: {e}") logger.error(f"保存失败: {e}")
@staticmethod @staticmethod
def load_seg_result(image_id): def load_seg_result(image_id):
file_path = f"seg_cache/{image_id}.npy" file_path = f"seg_cache/{image_id}.npy"
logger.info(f"load seg file name is :{SEG_CACHE_PATH}{image_id}.npy") # logger.info(f"load seg file name is :{SEG_CACHE_PATH}{image_id}.npy")
try: try:
seg_result = np.load(file_path) seg_result = np.load(file_path)
return True, seg_result return True, seg_result
except FileNotFoundError: except FileNotFoundError:
logger.warning("文件不存在") # logger.warning("文件不存在")
return False, None return False, None
except Exception as e: except Exception as e:
logger.error(f"加载失败: {e}") logger.error(f"加载失败: {e}")

View File

@@ -207,7 +207,7 @@ def design_generate_v2(request_data):
'Connection': "keep-alive", 'Connection': "keep-alive",
'Content-Type': "application/json" 'Content-Type': "application/json"
} }
logger.info(items_response) # logger.info(items_response)
response = post_request(url, json_data=items_response, headers=headers) response = post_request(url, json_data=items_response, headers=headers)
if response: if response:
# 打印结果 # 打印结果

View File

@@ -36,11 +36,11 @@ class Segmentation:
# preview 过模型 不缓存 # preview 过模型 不缓存
if "preview_submit" in result.keys() and result['preview_submit'] == "preview": if "preview_submit" in result.keys() and result['preview_submit'] == "preview":
# 推理获得seg 结果 # 推理获得seg 结果
seg_result = get_seg_result(result["image_id"], result['image'])[0] seg_result = get_seg_result(result["image_id"], result['image'])
# submit 过模型 缓存 # submit 过模型 缓存
elif "preview_submit" in result.keys() and result['preview_submit'] == "submit": elif "preview_submit" in result.keys() and result['preview_submit'] == "submit":
# 推理获得seg 结果 # 推理获得seg 结果
seg_result = get_seg_result(result["image_id"], result['image'])[0] seg_result = get_seg_result(result["image_id"], result['image'])
self.save_seg_result(seg_result, result['image_id']) self.save_seg_result(seg_result, result['image_id'])
# null 正常流程 加载本地缓存 无缓存则过模型 # null 正常流程 加载本地缓存 无缓存则过模型
else: else:
@@ -49,14 +49,14 @@ class Segmentation:
# 判断缓存和实际图片size是否相同 # 判断缓存和实际图片size是否相同
if not _ or result["image"].shape[:2] != seg_result.shape: if not _ or result["image"].shape[:2] != seg_result.shape:
# 推理获得seg 结果 # 推理获得seg 结果
seg_result = get_seg_result(result["image_id"], result['image'])[0] seg_result = get_seg_result(result["image_id"], result['image'])
self.save_seg_result(seg_result, result['image_id']) self.save_seg_result(seg_result, result['image_id'])
result['seg_result'] = seg_result result['seg_result'] = seg_result
# 处理前片后片 # 处理前片后片
temp_front = seg_result == 1.0 temp_front = seg_result == 1
result['front_mask'] = (255 * (temp_front + 0).astype(np.uint8)) result['front_mask'] = (255 * (temp_front + 0).astype(np.uint8))
temp_back = seg_result == 2.0 temp_back = seg_result == 2
result['back_mask'] = (255 * (temp_back + 0).astype(np.uint8)) result['back_mask'] = (255 * (temp_back + 0).astype(np.uint8))
result['mask'] = result['front_mask'] + result['back_mask'] result['mask'] = result['front_mask'] + result['back_mask']
return result return result
@@ -66,19 +66,19 @@ class Segmentation:
file_path = f"{SEG_CACHE_PATH}{image_id}.npy" file_path = f"{SEG_CACHE_PATH}{image_id}.npy"
try: try:
np.save(file_path, seg_result) np.save(file_path, seg_result)
logger.info(f"保存成功 {os.path.abspath(file_path)}") logger.debug(f"保存成功 {os.path.abspath(file_path)}")
except Exception as e: except Exception as e:
logger.error(f"保存失败: {e}") logger.error(f"保存失败: {e}")
@staticmethod @staticmethod
def load_seg_result(image_id): def load_seg_result(image_id):
file_path = f"{SEG_CACHE_PATH}{image_id}.npy" file_path = f"{SEG_CACHE_PATH}{image_id}.npy"
logger.info(f"load seg file name is :{SEG_CACHE_PATH}{image_id}.npy") # logger.info(f"load seg file name is :{SEG_CACHE_PATH}{image_id}.npy")
try: try:
seg_result = np.load(file_path) seg_result = np.load(file_path)
return True, seg_result return True, seg_result
except FileNotFoundError: except FileNotFoundError:
logger.warning("文件不存在") # logger.warning("文件不存在")
return False, None return False, None
except Exception as e: except Exception as e:
logger.error(f"加载失败: {e}") logger.error(f"加载失败: {e}")

View File

@@ -13,7 +13,6 @@ import cv2
import mmcv import mmcv
import numpy as np import numpy as np
import torch import torch
import torch.nn.functional as F
import tritonclient.http as httpclient import tritonclient.http as httpclient
from app.core.config import * from app.core.config import *
@@ -85,7 +84,10 @@ def seg_preprocess(img_path):
if ori_shape != (img_scale_w, img_scale_h): if ori_shape != (img_scale_w, img_scale_h):
# mmcv.imresize(img, img_scale_h, img_scale_w) # 老代码 引以为戒!哈哈哈~ h和w写反了 # mmcv.imresize(img, img_scale_h, img_scale_w) # 老代码 引以为戒!哈哈哈~ h和w写反了
img = cv2.resize(img, (img_scale_h, img_scale_w)) img = cv2.resize(img, (img_scale_h, img_scale_w))
# img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
# 扩充25的白边
img = cv2.copyMakeBorder(img, 25, 25, 25, 25, cv2.BORDER_CONSTANT, value=[255, 255, 255])
img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0) preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0)
return preprocessed_img, ori_shape return preprocessed_img, ori_shape
@@ -114,9 +116,9 @@ def get_seg_result(image_id, image):
# no cache # no cache
def seg_postprocess(image_id, output, ori_shape): def seg_postprocess(image_id, output, ori_shape):
seg_logit = F.interpolate(torch.tensor(output).float(), size=ori_shape, scale_factor=None, mode='bilinear', align_corners=False) seg_logit = cv2.resize(output[0][0].astype(np.uint8), (ori_shape[1] + 50, ori_shape[0] + 50))
seg_pred = seg_logit.cpu().numpy() seg_logit = seg_logit[25: - 25, 25: - 25]
return seg_pred[0] return seg_logit
def key_point_show(image_path, key_point_result=None): def key_point_show(image_path, key_point_result=None):

View File

@@ -266,7 +266,7 @@ class DesignPreprocessing:
seg_result = np.load(file_path) seg_result = np.load(file_path)
return True, seg_result return True, seg_result
except FileNotFoundError: except FileNotFoundError:
logging.info("文件不存在") # logging.info("文件不存在")
return False, None return False, None
except Exception as e: except Exception as e:
logging.warning(f"加载失败: {e}") logging.warning(f"加载失败: {e}")
@@ -277,7 +277,7 @@ class DesignPreprocessing:
file_path = f"{SEG_CACHE_PATH}{image_id}.npy" file_path = f"{SEG_CACHE_PATH}{image_id}.npy"
try: try:
np.save(file_path, seg_result) np.save(file_path, seg_result)
logging.info(f"保存成功,{os.path.abspath(file_path)}") logging.debug(f"保存成功,{os.path.abspath(file_path)}")
except Exception as e: except Exception as e:
logging.warning(f"保存失败: {e}") logging.warning(f"保存失败: {e}")

View File

@@ -0,0 +1,126 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project trinity_client
@File service_att_recognition.py
@Author :周成融
@Date 2023/7/26 12:01:05
@detail
"""
import json
import logging
import time
import numpy as np
import redis
import tritonclient.grpc as grpcclient
from app.core.config import *
from app.schemas.generate_image import GenerateMultiViewModel
from app.service.generate_image.utils.upload_sd_image import upload_png_sd
from app.service.utils.oss_client import oss_get_image
logger = logging.getLogger()
class GenerateMultiView:
def __init__(self, request_data):
if DEBUG is False:
self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
self.channel = self.connection.channel()
# self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
# self.channel = self.connection.channel()
# self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
self.grpc_client = grpcclient.InferenceServerClient(url=GMV_MODEL_URL)
self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
self.image = self.get_image(request_data.image_url)
self.tasks_id = request_data.tasks_id
self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:]
self.generate_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'image_url': ''}
self.redis_client.set(self.tasks_id, json.dumps(self.generate_data))
self.redis_client.expire(self.tasks_id, 600)
def get_image(self, image_url):
try:
image = oss_get_image(bucket=image_url.split('/')[0], object_name=image_url[image_url.find('/') + 1:], data_type="PIL")
return image
except Exception as e:
logger.error(e)
def callback(self, result, error):
if error:
self.generate_data['status'] = "FAILURE"
self.generate_data['message'] = str(error)
# self.generate_data['data'] = str(error)
self.redis_client.set(self.tasks_id, json.dumps(self.generate_data))
else:
# pil图像转成numpy数组
images = result.as_numpy("generated_image")
# for id, img in enumerate(images):
# cv2.imwrite(f"{id}.png", img)
# image_url = ""
image_url = upload_png_sd(images[6], user_id=self.user_id, category="multi_view", file_name=f"{self.tasks_id}.png")
# logger.info(f"upload image SUCCESS {image_url}")
self.generate_data['status'] = "SUCCESS"
self.generate_data['message'] = "success"
self.generate_data['image_url'] = str(image_url)
self.redis_client.set(self.tasks_id, json.dumps(self.generate_data))
def read_tasks_status(self):
status_data = self.redis_client.get(self.tasks_id)
return json.loads(status_data), status_data
def get_result(self):
try:
images = [np.array(self.image).astype(np.uint8)] * 1
image_obj = np.array(images, dtype=np.uint8)
input_image = grpcclient.InferInput("input_image", image_obj.shape, "UINT8")
input_image.set_data_from_numpy(image_obj)
inputs = [input_image]
ctx = self.grpc_client.async_infer(model_name=GMV_MODEL_NAME, inputs=inputs, callback=self.callback)
time_out = 600
generate_data = None
while time_out > 0:
generate_data, _ = self.read_tasks_status()
if generate_data['status'] in ["REVOKED", "FAILURE"]:
ctx.cancel()
break
elif generate_data['status'] == "SUCCESS":
break
time_out -= 1
time.sleep(0.1)
return generate_data
except Exception as e:
self.generate_data['status'] = "FAILURE"
self.generate_data['message'] = str(e)
self.redis_client.set(self.tasks_id, json.dumps(self.generate_data))
raise Exception(str(e))
finally:
dict_generate_data, str_generate_data = self.read_tasks_status()
if DEBUG is False:
self.channel.basic_publish(exchange='', routing_key=GMV_RABBITMQ_QUEUES, body=str_generate_data)
# self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=str_generate_data)
logger.info(f" [x] Sent {json.dumps(dict_generate_data, indent=4)}")
def infer_cancel(tasks_id):
redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
data = {'tasks_id': tasks_id, 'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'}
generate_data = json.dumps(data)
redis_client.set(tasks_id, generate_data)
return data
if __name__ == '__main__':
rd = GenerateMultiViewModel(
tasks_id="123-89",
image_url="aida-sys-image/images/female/outwear/0628000123.jpg",
)
server = GenerateMultiView(rd)
print(server.get_result())

View File

@@ -7,9 +7,9 @@ def RunTime(func):
t1 = time.time() t1 = time.time()
res = func(*args, **kwargs) res = func(*args, **kwargs)
t2 = time.time() t2 = time.time()
# if t2 - t1 > 0.05: if t2 - t1 > 0.05:
# logging.info(f"function【{func.__name__}】,runtime【{str(t2 - t1)}】s")
logging.info(f"function{func.__name__}】,runtime{str(t2 - t1)}】s") logging.info(f"function{func.__name__}】,runtime{str(t2 - t1)}】s")
# logging.info(f"function【{func.__name__}】,runtime【{str(t2 - t1)}】s")
return res return res
return wrapper return wrapper
@@ -22,7 +22,8 @@ def ClassCallRunTime(func):
end_time = time.time() end_time = time.time()
execution_time = end_time - start_time execution_time = end_time - start_time
class_name = args[0].__class__.__name__ # 获取类名 class_name = args[0].__class__.__name__ # 获取类名
print(f"class name: {class_name} , run time is : {execution_time} s") if execution_time > 0.05:
logging.info(f"class name: {class_name} , run time is : {execution_time} s")
return result return result
return wrapper return wrapper

View File

@@ -82,7 +82,7 @@ if __name__ == '__main__':
# url = "aida-users/89/sketchboard/female/Dress/e6724ab7-8d3f-4677-abe0-c3e42ab7af85.jpeg" # url = "aida-users/89/sketchboard/female/Dress/e6724ab7-8d3f-4677-abe0-c3e42ab7af85.jpeg"
# url = "aida-users/87/print/956614a2-7e75-4fbe-9ed0-c1831e37a2c9-4-87.png" # url = "aida-users/87/print/956614a2-7e75-4fbe-9ed0-c1831e37a2c9-4-87.png"
# url = "aida-users/89/single_logo/123-89.png" # url = "aida-users/89/single_logo/123-89.png"
url = "aida-users/89/test/123-89.png" url = "aida-users/89/123-89.png"
# url = "aida-collection-element/12148/Sketchboard/95ea577b-305b-4a62-b30a-39c0dd3ddb3f.png" # url = "aida-collection-element/12148/Sketchboard/95ea577b-305b-4a62-b30a-39c0dd3ddb3f.png"
read_type = "2" read_type = "2"