Merge branch 'refs/heads/develop'

This commit is contained in:
zhouchengrong
2025-02-04 09:37:55 +08:00
31 changed files with 949 additions and 118 deletions

7
.gitignore vendored
View File

@@ -136,4 +136,9 @@ app/logs/*
/qodana.yaml
.pth
.pytorch
*.png
*.png
*.pth
*.db
*.npy
*.pytorch
*.jpg

View File

@@ -35,13 +35,13 @@ def attribute_recognition(request_item: list[AttributeRecognitionModel]):
"""
try:
for item in request_item:
logger.info(f"attribute_recognition request item is : @@@@@@:{json.dumps(item.dict())}")
logger.debug(f"attribute_recognition request item is : @@@@@@:{json.dumps(item.dict())}")
if DEBUG:
service = AttributeRecognition(const=local_debug_const, request_data=request_item)
else:
service = AttributeRecognition(const=const, request_data=request_item)
data = service.get_result()
logger.info(f"attribute_recognition response @@@@@@:{json.dumps(data)}")
logger.debug(f"attribute_recognition response @@@@@@:{json.dumps(data)}")
except Exception as e:
logger.warning(f"attribute_recognition Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))

34
app/api/api_brand_dna.py Normal file
View File

@@ -0,0 +1,34 @@
import json
import logging
from fastapi import APIRouter, HTTPException
from app.schemas.brand_dna import BrandDnaModel
from app.schemas.response_template import ResponseModel
from app.service.brand_dna.service import BrandDna
router = APIRouter()
logger = logging.getLogger()
@router.post("/seg_product")
def image2sketch(request_item: BrandDnaModel):
"""
创建一个具有以下参数的请求体:
- **image_url**: 提取图片url
- **is_brand_dna**: 是否提取属性
示例参数:
{
"image_url": "test/image2sketch/real_Dress_3200fecdc83d0c556c2bd96aedbd7fbf.jpg_Img.jpg",
"is_brand_dna": False
}
"""
try:
logger.info(f"brand dna request item is : @@@@@@:{json.dumps(request_item.dict())}")
service = BrandDna(request_item)
result_url = service.get_result()
except Exception as e:
logger.warning(f"brand dna Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))
return ResponseModel(data=result_url)

View File

@@ -4,7 +4,7 @@ import os
from fastapi import APIRouter, HTTPException, UploadFile, File, Form, BackgroundTasks
from app.schemas.design import DesignModel, DesignProgressModel, ModelProgressModel, DBGConfigModel
from app.schemas.design import DesignModel, DesignProgressModel, ModelProgressModel, DBGConfigModel, DesignStreamModel
from app.schemas.response_template import ResponseModel
from app.service.design.model_process_service import model_transpose
from app.service.design_batch.service import start_design_batch_generate
@@ -18,6 +18,14 @@ logger = logging.getLogger()
@router.post("/design")
def design(request_data: DesignModel, background_tasks: BackgroundTasks):
"""
objects.items.transparent:
"transparent":{
"mask_url":"test/transparent_test/transparent_mask.png",
"scale":0.1
},
mask_url 为空"" -> 单件衣服透明
mask_url 非空"mask_url" -> 区域透明
创建一个具有以下参数的请求体:
示例参数:
{
@@ -197,7 +205,7 @@ def design(request_data: DesignModel, background_tasks: BackgroundTasks):
@router.post("/design_v2")
async def design_v2(request_data: DesignModel, background_tasks: BackgroundTasks):
async def design_v2(request_data: DesignStreamModel, background_tasks: BackgroundTasks):
"""
创建一个具有以下参数的请求体:
示例参数:
@@ -445,7 +453,7 @@ async def design(file: UploadFile = File(...),
async def save_request_file(contents, file_name):
# 创建保存文件的目录(如果不存在)
save_dir = os.path.join(os.getcwd(), "design_batch", "request_data")
save_dir = os.path.join(os.getcwd(), "service/design_batch", "request_data")
if not os.path.exists(save_dir):
os.makedirs(save_dir)
# 处理文件

View File

@@ -3,9 +3,10 @@ import logging
from fastapi import APIRouter, BackgroundTasks, HTTPException
from app.schemas.generate_image import GenerateImageModel, GenerateProductImageModel, GenerateSingleLogoImageModel, GenerateRelightImageModel
from app.schemas.generate_image import GenerateImageModel, GenerateProductImageModel, GenerateSingleLogoImageModel, GenerateRelightImageModel, GenerateMultiViewModel
from app.schemas.response_template import ResponseModel
from app.service.generate_image.service_generate_image import GenerateImage, infer_cancel as generate_image_infer_cancel
from app.service.generate_image.service_generate_multi_view import GenerateMultiView, infer_cancel as generate_multi_view_cancel
from app.service.generate_image.service_generate_product_image import GenerateProductImage, infer_cancel as generate_product_image_cancel
from app.service.generate_image.service_generate_relight_image import GenerateRelightImage, infer_cancel as generate_relight_image_cancel
from app.service.generate_image.service_generate_single_logo import GenerateSingleLogoImage, infer_cancel as generate_single_logo_cancel
@@ -61,6 +62,44 @@ def generate_image(tasks_id: str):
return ResponseModel(data=data['data'])
'''multi view'''
@router.post("/generate_multi_view")
def generate_multi_view(request_item: GenerateMultiViewModel, background_tasks: BackgroundTasks):
"""
创建一个具有以下参数的请求体:
- **tasks_id**: 任务id 用于取消生成任务和获取生成结果
- **image_url**: 前视角图的输入minio或S3 url 地址
示例参数:
{
"tasks_id": "123-89",
"image_url": "aida-collection-element/87/Printboard/842c09cf-7297-42d9-9e6e-9c17d4a13cb5.jpg"
}
"""
try:
logger.info(f"generate_multi_view request item is : @@@@@@:{json.dumps(request_item.dict())}")
service = GenerateMultiView(request_item)
background_tasks.add_task(service.get_result)
except Exception as e:
logger.warning(f"generate_multi_view Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))
return ResponseModel()
@router.get("/generate_multi_view_cancel/{tasks_id}")
def generate_multi_view(tasks_id: str):
try:
logger.info(f"generate_cancel request item is : @@@@@@:{tasks_id}")
data = generate_multi_view_cancel(tasks_id)
logger.info(f"generate_cancel response @@@@@@:{data}")
except Exception as e:
logger.warning(f"generate_cancel Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))
return ResponseModel(data=data['data'])
'''single logo'''

View File

@@ -26,7 +26,7 @@ def prompt_generation(request_data: PromptGenerationImageModel):
"""
try:
logger.info(f"prompt_generation request item is : @@@@@@:{request_data}")
data = get_translation_from_llama3("[" + request_data.text + "]")
data = get_translation_from_llama3(request_data.text)
logger.info(f"prompt_generation response @@@@@@:{data}")
except Exception as e:
logger.warning(f"prompt_generation Run Exception @@@@@@:{e}")

View File

@@ -1,6 +1,7 @@
from fastapi import APIRouter
from app.api import api_attribute_retrieve, api_query_image
from app.api import api_brand_dna
from app.api import api_brighten
from app.api import api_chat_robot
from app.api import api_design
@@ -23,4 +24,5 @@ router.include_router(api_prompt_generation.router, tags=['prompt_generation'],
router.include_router(api_design_pre_processing.router, tags=['design_pre_processing'], prefix="/api")
router.include_router(api_image2sketch.router, tags=['api_image2sketch'], prefix="/api")
router.include_router(api_brighten.router, tags=['api_brighten'], prefix="/api")
router.include_router(api_query_image.router, tags=['api_query_image'], prefix="/api")
router.include_router(api_query_image.router, tags=['api_query_image'], prefix="/api")
router.include_router(api_brand_dna.router, tags=['api_brand_dna'], prefix="/api")

View File

@@ -4,7 +4,7 @@ import logging
from fastapi import APIRouter
from fastapi import HTTPException
from app.core.config import SR_RABBITMQ_QUEUES, GI_RABBITMQ_QUEUES, GPI_RABBITMQ_QUEUES, GRI_RABBITMQ_QUEUES, OSS
from app.core.config import SR_RABBITMQ_QUEUES, GI_RABBITMQ_QUEUES, GPI_RABBITMQ_QUEUES, GRI_RABBITMQ_QUEUES, OSS, JAVA_STREAM_API_URL
from app.schemas.response_template import ResponseModel
logger = logging.getLogger()
@@ -18,6 +18,7 @@ def test(id: int):
"GI_RABBITMQ_QUEUES": GI_RABBITMQ_QUEUES,
"GPI_RABBITMQ_QUEUES": GPI_RABBITMQ_QUEUES,
"GRI_RABBITMQ_QUEUES": GRI_RABBITMQ_QUEUES,
"JAVA_STREAM_API_URL": JAVA_STREAM_API_URL,
"local_oss_server": OSS
}
logger.info(json.dumps(data))

View File

@@ -34,6 +34,9 @@ else:
RABBITMQ_ENV = "-dev" # 开发环境
# RABBITMQ_ENV = "-local" # 本地测试环境
JAVA_STREAM_API_URL = os.getenv("JAVA_STREAM_API_URL", "https://api.aida.com.hk/api/third/party/receiveDesignResults")
settings = Settings()
# minio 配置
@@ -106,6 +109,12 @@ FAST_GI_MODEL_NAME = 'stable_diffusion_xl'
GI_MODEL_URL = '10.1.1.240:10061'
GI_MODEL_NAME = 'flux'
GMV_MODEL_URL = '10.1.1.243:10081'
GMV_MODEL_NAME = 'multi_view'
GMV_RABBITMQ_QUEUES = os.getenv("GMV_RABBITMQ_QUEUES", f"GenerateMultiView{RABBITMQ_ENV}")
GI_MINIO_BUCKET = "aida-users"
GI_RABBITMQ_QUEUES = os.getenv("GI_RABBITMQ_QUEUES", f"GenerateImage{RABBITMQ_ENV}")
GI_SYS_IMAGE_URL = "aida-sys-image/generate_image/white_image.jpg"

6
app/schemas/brand_dna.py Normal file
View File

@@ -0,0 +1,6 @@
from pydantic import BaseModel
class BrandDnaModel(BaseModel):
image_url: str
is_brand_dna: bool

View File

@@ -6,6 +6,12 @@ class DesignModel(BaseModel):
process_id: str
class DesignStreamModel(BaseModel):
objects: list[dict]
process_id: str
requestId: str
class DesignProgressModel(BaseModel):
process_id: str

View File

@@ -1,6 +1,11 @@
from pydantic import BaseModel
class GenerateMultiViewModel(BaseModel):
tasks_id: str
image_url: str
class GenerateImageModel(BaseModel):
tasks_id: str
prompt: str

View File

@@ -0,0 +1,335 @@
import logging
import cv2
import mmcv
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import tritonclient.http as httpclient
from minio import Minio
from app.core.config import MINIO_URL, MINIO_ACCESS, MINIO_SECRET, MINIO_SECURE, DESIGN_MODEL_URL
from app.schemas.brand_dna import BrandDnaModel
from app.service.attribute.config import local_debug_const
from app.service.utils.generate_uuid import generate_uuid
from app.service.utils.new_oss_client import oss_upload_image, oss_get_image
logger = logging.getLogger()
minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
class BrandDna:
def __init__(self, request_item):
self.sketch_bucket = "test"
self.image_url = request_item.image_url
self.is_brand_dna = request_item.is_brand_dna
# self.attr_type = pd.read_csv(CATEGORY_PATH)
self.attr_type = pd.read_csv(r"E:\workspace\trinity_client_aida\app\service\attribute\config\descriptor\category\category_dis.csv")
self.att_client = httpclient.InferenceServerClient(url=DESIGN_MODEL_URL)
self.seg_client = httpclient.InferenceServerClient(url='10.1.1.243:30000')
# self.const = const
self.const = local_debug_const
# 获取结果
def get_result(self):
mask, image = self.get_seg_mask()
cv2.imshow("", image)
cv2.waitKey(0)
height, width, channels = image.shape
result_dict = []
white_img = np.ones((height, width, channels), dtype=image.dtype) * 255
mask_image = np.zeros((height, width, 3))
for value in np.unique(mask):
if value == 1:
outwear_img = white_img.copy()
outwear_mask_img = mask_image.copy()
outwear_img[mask == value] = image[mask == value]
outwear_mask_img[mask == value] = [0, 0, 255]
cv2.imshow("", outwear_img)
cv2.waitKey(0)
# 预处理之后的input img
preprocess_img = self.category_preprocess(outwear_img)
# 类别检测
category = self.recognition_category(preprocess_img)
if category == 'Trousers' or category == 'Skirt':
male_category = 'Bottoms'
elif category == 'Blouse' or category == 'Dress':
male_category = 'Tops'
else:
male_category = 'Outwear'
attribute = {}
mask_url = ""
img_url = ""
# 属性检测
if self.is_brand_dna:
attribute = self.get_recognition_attribute_result(category, preprocess_img)
else:
img_url = self.put_image(outwear_img, f"img/{generate_uuid()}")
mask_url = self.put_image(outwear_mask_img, f"mask/{generate_uuid()}")
result_dict.append(
{
'category_female': category,
'category_male': male_category,
'mask_url': mask_url,
'img_url': img_url,
'attribute': attribute
}
)
if value == 2:
tops_img = white_img.copy()
tops_mask_img = mask_image.copy()
tops_img[mask == value] = image[mask == value]
tops_mask_img[mask == value] = [0, 0, 255]
cv2.imshow("", tops_img)
cv2.waitKey(0)
# 预处理之后的input img
preprocess_img = self.category_preprocess(tops_img)
# 类别检测
category = self.recognition_category(preprocess_img)
if category == 'Trousers' or category == 'Skirt':
male_category = 'Bottoms'
elif category == 'Blouse' or category == 'Dress':
male_category = 'Tops'
else:
male_category = 'Outwear'
# 属性检测
attribute = {}
img_url = ""
mask_url = ""
# 属性检测
if self.is_brand_dna:
attribute = self.get_recognition_attribute_result(category, preprocess_img)
else:
mask_url = self.put_image(tops_mask_img, f"mask/{generate_uuid()}")
img_url = self.put_image(tops_img, f"img/{generate_uuid()}")
result_dict.append(
{
'category_female': category,
'category_male': male_category,
'mask_url': mask_url,
'img_url': img_url,
'attribute': attribute
}
)
if value == 3:
bottoms_img = white_img.copy()
bottoms_mask_img = mask_image.copy()
bottoms_img[mask == value] = image[mask == value]
bottoms_mask_img[mask == value] = [0, 0, 255]
cv2.imshow("", bottoms_img)
cv2.waitKey(0)
# 预处理之后的input img
preprocess_img = self.category_preprocess(bottoms_img)
# 类别检测
category = self.recognition_category(preprocess_img)
if category == 'Trousers' or category == 'Skirt':
male_category = 'Bottoms'
elif category == 'Blouse' or category == 'Dress':
male_category = 'Tops'
else:
male_category = 'Outwear'
attribute = {}
img_url = ""
mask_url = ""
# 属性检测
if self.is_brand_dna:
attribute = self.get_recognition_attribute_result(category, preprocess_img)
else:
img_url = self.put_image(bottoms_img, f"img/{generate_uuid()}")
mask_url = self.put_image(bottoms_mask_img, f"mask/{generate_uuid()}")
result_dict.append(
{
'category_female': category,
'category_male': male_category,
'mask_url': mask_url,
'img_url': img_url,
'attribute': attribute
}
)
return result_dict
# 获取product mask
def get_seg_mask(self):
input_image = self.get_image()
input_img, ori_shape = self.seg_product_preprocess(input_image)
transformed_img = input_img.astype(np.float32)
inputs = [httpclient.InferInput(f"seg_input__0", transformed_img.shape, datatype="FP32")]
inputs[0].set_data_from_numpy(transformed_img, binary_data=True)
outputs = [httpclient.InferRequestedOutput(f"seg_output__0", binary_data=True)]
results = self.seg_client.infer(model_name=f"seg_product", inputs=inputs, outputs=outputs)
inference_output1 = results.as_numpy("seg_output__0")
mask = self.product_postprocess(inference_output1, ori_shape)[0]
return mask, input_image
# 获取图片
def get_image(self):
image = oss_get_image(oss_client=minio_client, bucket=self.image_url.split('/')[0], object_name=self.image_url[self.image_url.find('/') + 1:], data_type="cv2")
# 将其转换为彩色图像
if len(image.shape) == 3 and image.shape[2] == 4:
image = cv2.cvtColor(image, cv2.COLOR_BGRA2BGR)
elif len(image.shape) == 2:
image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
return image
# return cv2.imread(self.image_url)
# 图片上传
def put_image(self, image, object_name):
try:
image_bytes = cv2.imencode('.jpg', image)[1].tobytes()
oss_upload_image(oss_client=minio_client, bucket=self.sketch_bucket, object_name=f"{object_name}.jpg", image_bytes=image_bytes)
return f"{self.sketch_bucket}/{object_name}.jpg"
except Exception as e:
logger.warning(e)
# 服装分割预处理
@staticmethod
def seg_product_preprocess(image):
img = mmcv.imread(image)
ori_shape = img.shape[:2]
img_scale_w, img_scale_h = ori_shape
if ori_shape[0] > 1024:
img_scale_w = 1024
if ori_shape[1] > 1024:
img_scale_h = 1024
# 如果图片size任意一边 大于 1024 则会resize 成1024
if ori_shape != (img_scale_w, img_scale_h):
# mmcv.imresize(img, img_scale_h, img_scale_w) # 老代码 引以为戒!哈哈哈~ h和w写反了
img = cv2.resize(img, (img_scale_h, img_scale_w))
img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0)
return preprocessed_img, ori_shape
# 类别检测后处理
@staticmethod
def product_postprocess(output, ori_shape):
seg_logit = F.interpolate(torch.tensor(output).float(), size=ori_shape, scale_factor=None, mode='bilinear', align_corners=False)
seg_pred = seg_logit.cpu().numpy()
return seg_pred[0]
# 类别检测模型预处理
@staticmethod
def category_preprocess(img):
img = mmcv.imread(img)
# ori_shape = img.shape[:2]
img_scale = (224, 224)
img = cv2.resize(img, img_scale)
img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0)
return preprocessed_img
# 类别检测
def recognition_category(self, image):
inputs = [
httpclient.InferInput("input__0", image.shape, datatype="FP32")
]
inputs[0].set_data_from_numpy(image, binary_data=True)
results = self.att_client.infer(model_name="attr_retrieve_category", inputs=inputs)
inference_output = torch.from_numpy(results.as_numpy(f'output__0'))
scores = inference_output.detach().numpy()
colattr = list(self.attr_type['labelName'])
maxsc = np.max(scores[0][:5])
indexs = np.argwhere(scores == maxsc)[:, 1]
return colattr[indexs[0]]
# 属性检测
def recognition_attribute(self, model_name, description, image):
attr_type = pd.read_csv(description)
inputs = [
httpclient.InferInput("input__0", image.shape, datatype="FP32")
]
inputs[0].set_data_from_numpy(image, binary_data=True)
results = self.att_client.infer(model_name=model_name, inputs=inputs)
inference_output = torch.from_numpy(results.as_numpy(f"output__0"))
scores = inference_output.detach().numpy()
colattr = list(attr_type['labelName'])
task = description.split('/')[-1][:-4]
maxsc = np.max(scores[0][:5])
indexs = np.argwhere(scores == maxsc)[:, 1]
attr = {
task: []
}
for i in range(len(indexs)):
atr = colattr[indexs[i]]
attr[task].append(atr)
return attr
# 获取属性检测结果
def get_recognition_attribute_result(self, category, input_img):
if category == "Blouse":
attr_dict = {}
for i in range(len(self.const.top_description_list)):
attr_description = self.const.top_description_list[i]
attr_model_path = self.const.top_model_list[i]
present_dict = self.recognition_attribute(attr_model_path, attr_description, input_img)
attr_dict = self.merge(attr_dict, present_dict)
elif category == 'Trousers' or category == "Skirt":
attr_dict = {}
for i in range(len(self.const.bottom_description_list)):
attr_description = self.const.bottom_description_list[i]
attr_model_path = self.const.bottom_model_list[i]
present_dict = self.recognition_attribute(attr_model_path, attr_description, input_img)
attr_dict = self.merge(attr_dict, present_dict)
elif category == 'Dress':
attr_dict = {}
for i in range(len(self.const.dress_description_list)):
attr_description = self.const.dress_description_list[i]
attr_model_path = self.const.dress_model_list[i]
present_dict = self.recognition_attribute(attr_model_path, attr_description, input_img)
attr_dict = self.merge(attr_dict, present_dict)
elif category == 'Outwear':
attr_dict = {}
for i in range(len(self.const.outwear_description_list)):
attr_description = self.const.outwear_description_list[i]
attr_model_path = self.const.outwear_model_list[i]
present_dict = self.recognition_attribute(attr_model_path, attr_description, input_img)
attr_dict = self.merge(attr_dict, present_dict)
else:
attr_dict = {}
return attr_dict
@staticmethod
def merge(dict1, dict2):
res = {**dict1, **dict2}
return res
if __name__ == '__main__':
# for path in os.listdir('./test_img'):
# img_path = os.path.join(r'./test_img', path)
# request_item = BrandDnaModel(
# image_url=img_path,
# is_brand_dna=True
# )
# service = BrandDna(request_item)
# result_url = service.get_result()
# print(result_url)
request_item = BrandDnaModel(
image_url="aida-users/60/product_image/07cb5d5d-5022-44cc-b0d3-cc986cfebad1-2-60.png",
is_brand_dna=True
)
service = BrandDna(request_item)
result_url = service.get_result()
print(result_url)

View File

@@ -69,3 +69,5 @@ TOOLS_FUNCTIONS_SUFFIX = (
)
TUTORIAL_TOOL_RETURN = "Commencing the systematic tutorial guide now."
GET_LANGUAGE_PREFIX = "Please identify the language. Only output the language name"

View File

@@ -1,4 +1,5 @@
import json
import logging
from dashscope import Generation
from retry import retry
@@ -7,7 +8,8 @@ from urllib3.exceptions import NewConnectionError
from app.core.config import *
from app.service.chat_robot.script.callbacks.qwen_callback_handler import QWenCallbackHandler
from app.service.chat_robot.script.database import CustomDatabase
from app.service.chat_robot.script.prompt import FASHION_CHAT_BOT_PREFIX, TOOLS_FUNCTIONS_SUFFIX, TUTORIAL_TOOL_RETURN
from app.service.chat_robot.script.prompt import FASHION_CHAT_BOT_PREFIX, TOOLS_FUNCTIONS_SUFFIX, TUTORIAL_TOOL_RETURN, \
GET_LANGUAGE_PREFIX
from app.service.search_image_with_text.service import query
get_database_table_description = "Input is an empty string, output is a comma separated list of tables in the database."
@@ -159,10 +161,11 @@ def search_from_internet(message):
model='qwen-turbo',
api_key=QWEN_API_KEY,
messages=message,
tools=tools,
prompt='The output must be in English.Keep the final result under 200 words.'
# tools=tools,
# seed=random.randint(1, 10000), # 设置随机数种子seed如果没有设置则随机数种子默认为1234
result_format='message', # 将输出设置为message形式
enable_search='True'
# result_format='message', # 将输出设置为message形式
# enable_search='True'
)
return response
@@ -198,14 +201,9 @@ def get_response(messages):
def call_with_messages(message, gender):
global tool_info
user_input = message
print('\n')
# messages = [
# {
# "content": input('请输入:'), # 提问示例:"现在几点了?" "一个小时后几点" "北京天气如何?"
# "role": "user"
# }
# ]
messages = [
{
@@ -223,14 +221,10 @@ def call_with_messages(message, gender):
}
]
# 模型的第一轮调用
# first_response = get_response(messages)
# assistant_output = first_response.output.choices[0].message
# print(f"\n大模型第一轮输出信息{first_response}\n")
# messages.append(assistant_output)
flag = True
count = 1
result_content = "我是一个时尚AI助手请问有什么可以帮您"
# result_content = "我是一个时尚AI助手请问有什么可以帮您"
result_content = "I am a fashion AI assistant, how can I help you?"
response_type = "chat"
while flag and count <= 3:
@@ -244,44 +238,53 @@ def call_with_messages(message, gender):
print(f"最终答案:{assistant_output.content}") # 此处直接返回模型的回复,您可以根据您的业务,选择当无需调用工具时最终回复的内容
result_content = assistant_output.content
break
# 如果模型选择的工具是search_from_internet
# elif assistant_output.tool_calls[0]['function']['name'] == 'search_from_internet':
# tool_info = {"name": "search_from_internet", "role": "tool"}
# user_input = json.loads(assistant_output.tool_calls[0]['function']['arguments'])['user_input']
# tool_info['content'] = search_from_internet(user_input)
# 如果模型选择的工具是get_database_table
elif assistant_output.tool_calls[0]['function']['name'] == 'get_database_table':
tool_info = {"name": "get_database_table", "role": "tool", 'content': get_database_table()}
# 如果模型选择的工具是get_table_info
elif assistant_output.tool_calls[0]['function']['name'] == 'get_table_info':
tool_info = {"name": "get_table_info", "role": "tool"}
table_names = json.loads(assistant_output.tool_calls[0]['function']['arguments'])['table_names']
tool_info['content'] = get_table_info(table_names)
# 如果模型选择的工具是query_database
elif assistant_output.tool_calls[0]['function']['name'] == 'query_database':
tool_info = {"name": "query_database", "role": "tool"}
sql_string = json.loads(assistant_output.tool_calls[0]['function']['arguments'])['sql_string']
tool_info['content'] = query_database(sql_string)
# 如果模型选择的工具是internet_search
elif assistant_output.tool_calls[0]['function']['name'] == 'internet_search':
tool_info = {"name": "search_from_internet", "role": "tool"}
content = json.loads(assistant_output.tool_calls[0]['function']['arguments'])
message = [
{'role': 'assistant', 'content': content['query']}
]
tool_info['content'] = search_from_internet(message)
flag = False
result_content = tool_info['content']
response_type = "image"
result_content = tool_info['content'].output.text
# 如果模型选择的工具是get_database_table
# elif assistant_output.tool_calls[0]['function']['name'] == 'get_database_table':
# tool_info = {"name": "get_database_table", "role": "tool", 'content': get_database_table()}
# # 如果模型选择的工具是get_table_info
# elif assistant_output.tool_calls[0]['function']['name'] == 'get_table_info':
# tool_info = {"name": "get_table_info", "role": "tool"}
# table_names = json.loads(assistant_output.tool_calls[0]['function']['arguments'])['table_names']
# tool_info['content'] = get_table_info(table_names)
# # 如果模型选择的工具是query_database
# elif assistant_output.tool_calls[0]['function']['name'] == 'query_database':
# tool_info = {"name": "query_database", "role": "tool"}
# sql_string = json.loads(assistant_output.tool_calls[0]['function']['arguments'])['sql_string']
# tool_info['content'] = query_database(sql_string)
# flag = False
# result_content = tool_info['content']
# response_type = "image"
elif assistant_output.tool_calls[0]['function']['name'] == 'tutorial_tool':
tool_info = {"name": "tutorial_tool", "role": "tool", 'content': tutorial_tool()}
flag = False
result_content = tool_info['content']
elif assistant_output.tool_calls[0]['function']['name'] == 'get_image_from_vector_db':
content = json.loads(assistant_output.tool_calls[0]['function']['arguments'])
tool_info = {"name": "get_image_from_vector_db", "role": "tool",
'content': get_image_from_vector_db(gender, user_input)}
'content': get_image_from_vector_db(gender, content['parameters']['content'])}
flag = False
result_content = tool_info['content']
response_type = "image"
else:
tool_info = {"name": assistant_output.tool_calls[0]['function']['name'], 'content': 'null'}
logging.info(assistant_output.tool_calls[0]['function']['name'] + "(unknown tools)")
flag = False
print(f"工具输出信息:{tool_info['content']}\n")
messages.append(tool_info)
count += 1
final_output = {"output": result_content}
final_output["response_type"] = response_type
final_output = {"output": result_content, "response_type": response_type}
QWenCallbackHandler.on_chain_end(qwen, final_output)
# 模型的第二轮调用,对工具的输出进行总结
@@ -298,5 +301,23 @@ def tutorial_tool():
return TUTORIAL_TOOL_RETURN
def get_language(message: str) -> str:
messages = [
{
"content": message, # 用户message
"role": "user"
},
{
"content": GET_LANGUAGE_PREFIX, # ai message
"role": "assistant"
}
]
first_response = get_response(messages)
assistant_output = first_response.output.choices[0].message.content
logging.info(f"大模型输出信息:{first_response}\n判断用户输入的语言为:{assistant_output}")
return assistant_output
if __name__ == '__main__':
call_with_messages()
get_language("")

View File

@@ -53,7 +53,7 @@ class Segmentation(object):
file_path = f"{SEG_CACHE_PATH}{image_id}.npy"
try:
np.save(file_path, seg_result)
logger.info(f"保存成功 {os.path.abspath(file_path)}")
logger.debug(f"保存成功 {os.path.abspath(file_path)}")
except Exception as e:
logger.error(f"保存失败: {e}")
@@ -64,7 +64,7 @@ class Segmentation(object):
seg_result = np.load(file_path)
return True, seg_result
except FileNotFoundError:
logger.warning("文件不存在")
# logger.warning("文件不存在")
return False, None
except Exception as e:
logger.error(f"加载失败: {e}")

View File

@@ -12,7 +12,7 @@ from app.service.design_batch.utils.save_json import oss_upload_json
from app.service.design_batch.utils.synthesis_item import update_base_size_priority, synthesis, synthesis_single
id_lock = threading.Lock()
celery_app = Celery('tasks', broker='amqp://guest:guest@10.1.2.213:5672//', backend='rpc://')
celery_app = Celery('tasks', broker=f'amqp://rabbit:123456@18.167.251.121:5672//', backend='rpc://', BROKER_CONNECTION_RETRY_ON_STARTUP=True)
celery_app.conf.worker_log_format = '%(asctime)s %(filename)s [line:%(lineno)d] %(levelname)s %(message)s'
celery_app.conf.worker_hijack_root_logger = False
logging.getLogger('pika').setLevel(logging.WARNING)
@@ -120,7 +120,7 @@ def batch_design(objects_data, tasks_id, json_name):
for t in threads:
t.join()
logger.debug(object_response)
oss_upload_json(minio_client, object_response, json_name)
publish_status(tasks_id, "ok", json_name)
return object_response

View File

@@ -51,19 +51,19 @@ class Segmentation:
file_path = f"seg_cache/{image_id}.npy"
try:
np.save(file_path, seg_result)
logger.info(f"保存成功 {os.path.abspath(file_path)}")
logger.debug(f"保存成功 {os.path.abspath(file_path)}")
except Exception as e:
logger.error(f"保存失败: {e}")
@staticmethod
def load_seg_result(image_id):
file_path = f"seg_cache/{image_id}.npy"
logger.info(f"load seg file name is :{SEG_CACHE_PATH}{image_id}.npy")
# logger.info(f"load seg file name is :{SEG_CACHE_PATH}{image_id}.npy")
try:
seg_result = np.load(file_path)
return True, seg_result
except FileNotFoundError:
logger.warning("文件不存在")
# logger.warning("文件不存在")
return False, None
except Exception as e:
logger.error(f"加载失败: {e}")

View File

@@ -5,7 +5,7 @@ from app.service.design_batch.utils.MQ import publish_status
async def start_design_batch_generate(data, file):
generate_clothes_task = batch_design.delay(json.loads(file.decode())['objects'], data.total, data.tasks_id)
generate_clothes_task = batch_design.delay(json.loads(file.decode())['objects'], data.tasks_id, data.file_name)
print(generate_clothes_task)
publish_status(data.tasks_id, "0/100", "")
return {"task_id": data.tasks_id}

View File

@@ -157,6 +157,6 @@ if __name__ == '__main__':
],
"process_id": "83"
}
task_id = 1
task_id = 10086
json_name = "test.json"
batch_design.delay(data['objects'], task_id, json_name)

View File

@@ -2,9 +2,11 @@ import json
import pika
from app.core.config import RABBITMQ_PARAMS
def publish_status(task_id, progress, result):
connection = pika.BlockingConnection(pika.ConnectionParameters('10.1.2.213'))
connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
channel = connection.channel()
channel.queue_declare(queue='DesignBatch', durable=True)
message = {'task_id': task_id, 'progress': progress, "result": result}

View File

@@ -1,13 +1,21 @@
import io
import json
import logging
import os
logger = logging.getLogger()
def oss_upload_json(oss_client, json_data, object_name):
try:
with open(f"app/service/design_batch/response_json/{object_name}", 'w') as file:
save_dir = os.path.join(os.getcwd(), "service/design_batch", "response_data")
if not os.path.exists(save_dir):
os.makedirs(save_dir)
# 处理文件
file_path = os.path.join(save_dir, object_name)
with open(file_path, 'w') as file:
json.dump(json_data, file, indent=4)
oss_client.fput_object("test", object_name, f"app/service/design_batch/response_json/{object_name}")
json_bytes = json.dumps(json_data).encode('utf-8')
oss_client.put_object("test", object_name, io.BytesIO(json_bytes), length=len(json_bytes), content_type="application/json")
except Exception as e:
logger.warning(str(e))

View File

@@ -139,6 +139,7 @@ def design_generate(request_data):
@RunTime
def design_generate_v2(request_data):
objects_data = request_data.dict()['objects']
request_id = request_data.requestId
threads = []
def process_object(step, object):
@@ -146,7 +147,7 @@ def design_generate_v2(request_data):
items_response = {
'layers': [],
'objectSign': object['objectSign'] if 'objectSign' in object.keys() else "",
'requestId': object['requestId'] if 'requestId' in object.keys() else ""
'requestId': request_id
}
if basic['single_overall'] == "overall":
item_results = []
@@ -197,9 +198,8 @@ def design_generate_v2(request_data):
'pattern_image_url': item_result['pattern_image_url'] if 'pattern_image_url' in item_result.keys() else None,
})
items_response['synthesis_url'] = synthesis_single(item_result['front_image'], item_result['back_image'])
# 发送结果给java端
url = "https://3998-117-143-125-51.ngrok-free.app/api/third/party/receiveDesignResults"
url = JAVA_STREAM_API_URL
headers = {
'Accept': "*/*",
'Accept-Encoding': "gzip, deflate, br",
@@ -207,11 +207,11 @@ def design_generate_v2(request_data):
'Connection': "keep-alive",
'Content-Type': "application/json"
}
# logger.info(items_response)
response = post_request(url, json_data=items_response, headers=headers)
if response:
# 打印结果
logger.info(response.text)
logger.info(items_response)
for step, object in enumerate(objects_data):
t = threading.Thread(target=process_object, args=(step, object))

View File

@@ -66,19 +66,19 @@ class Segmentation:
file_path = f"{SEG_CACHE_PATH}{image_id}.npy"
try:
np.save(file_path, seg_result)
logger.info(f"保存成功 {os.path.abspath(file_path)}")
logger.debug(f"保存成功 {os.path.abspath(file_path)}")
except Exception as e:
logger.error(f"保存失败: {e}")
@staticmethod
def load_seg_result(image_id):
file_path = f"{SEG_CACHE_PATH}{image_id}.npy"
logger.info(f"load seg file name is :{SEG_CACHE_PATH}{image_id}.npy")
# logger.info(f"load seg file name is :{SEG_CACHE_PATH}{image_id}.npy")
try:
seg_result = np.load(file_path)
return True, seg_result
except FileNotFoundError:
logger.warning("文件不存在")
# logger.warning("文件不存在")
return False, None
except Exception as e:
logger.error(f"加载失败: {e}")

View File

@@ -25,6 +25,7 @@ from app.core.config import *
def keypoint_preprocess(img_path):
img = mmcv.imread(img_path)
img = cv2.copyMakeBorder(img, 25, 25, 25, 25, cv2.BORDER_CONSTANT, value=[255, 255, 255])
img_scale = (256, 256)
h, w = img.shape[:2]
img = cv2.resize(img, img_scale)
@@ -62,7 +63,11 @@ def keypoint_postprocess(output, scale_factor):
scale_matrix = np.diag(scale_factor)
nan = np.isinf(scale_matrix)
scale_matrix[nan] = 0
return np.ceil(np.dot(segment_result, scale_matrix) * 4)
# 应用缩放因子
scaled_result = np.ceil(np.dot(segment_result, scale_matrix) * 4)
# 补偿边框偏移
compensated_result = scaled_result - 25
return compensated_result
"""

View File

@@ -266,7 +266,7 @@ class DesignPreprocessing:
seg_result = np.load(file_path)
return True, seg_result
except FileNotFoundError:
logging.info("文件不存在")
# logging.info("文件不存在")
return False, None
except Exception as e:
logging.warning(f"加载失败: {e}")
@@ -277,7 +277,7 @@ class DesignPreprocessing:
file_path = f"{SEG_CACHE_PATH}{image_id}.npy"
try:
np.save(file_path, seg_result)
logging.info(f"保存成功,{os.path.abspath(file_path)}")
logging.debug(f"保存成功,{os.path.abspath(file_path)}")
except Exception as e:
logging.warning(f"保存失败: {e}")

View File

@@ -0,0 +1,126 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project trinity_client
@File service_att_recognition.py
@Author :周成融
@Date 2023/7/26 12:01:05
@detail
"""
import json
import logging
import time
import numpy as np
import redis
import tritonclient.grpc as grpcclient
from app.core.config import *
from app.schemas.generate_image import GenerateMultiViewModel
from app.service.generate_image.utils.upload_sd_image import upload_png_sd
from app.service.utils.oss_client import oss_get_image
logger = logging.getLogger()
class GenerateMultiView:
def __init__(self, request_data):
if DEBUG is False:
self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
self.channel = self.connection.channel()
# self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
# self.channel = self.connection.channel()
# self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
self.grpc_client = grpcclient.InferenceServerClient(url=GMV_MODEL_URL)
self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
self.image = self.get_image(request_data.image_url)
self.tasks_id = request_data.tasks_id
self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:]
self.generate_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'image_url': ''}
self.redis_client.set(self.tasks_id, json.dumps(self.generate_data))
self.redis_client.expire(self.tasks_id, 600)
def get_image(self, image_url):
try:
image = oss_get_image(bucket=image_url.split('/')[0], object_name=image_url[image_url.find('/') + 1:], data_type="PIL")
return image
except Exception as e:
logger.error(e)
def callback(self, result, error):
if error:
self.generate_data['status'] = "FAILURE"
self.generate_data['message'] = str(error)
# self.generate_data['data'] = str(error)
self.redis_client.set(self.tasks_id, json.dumps(self.generate_data))
else:
# pil图像转成numpy数组
images = result.as_numpy("generated_image")
# for id, img in enumerate(images):
# cv2.imwrite(f"{id}.png", img)
# image_url = ""
image_url = upload_png_sd(images[6], user_id=self.user_id, category="multi_view", file_name=f"{self.tasks_id}.png")
# logger.info(f"upload image SUCCESS {image_url}")
self.generate_data['status'] = "SUCCESS"
self.generate_data['message'] = "success"
self.generate_data['image_url'] = str(image_url)
self.redis_client.set(self.tasks_id, json.dumps(self.generate_data))
def read_tasks_status(self):
status_data = self.redis_client.get(self.tasks_id)
return json.loads(status_data), status_data
def get_result(self):
try:
images = [np.array(self.image).astype(np.uint8)] * 1
image_obj = np.array(images, dtype=np.uint8)
input_image = grpcclient.InferInput("input_image", image_obj.shape, "UINT8")
input_image.set_data_from_numpy(image_obj)
inputs = [input_image]
ctx = self.grpc_client.async_infer(model_name=GMV_MODEL_NAME, inputs=inputs, callback=self.callback)
time_out = 600
generate_data = None
while time_out > 0:
generate_data, _ = self.read_tasks_status()
if generate_data['status'] in ["REVOKED", "FAILURE"]:
ctx.cancel()
break
elif generate_data['status'] == "SUCCESS":
break
time_out -= 1
time.sleep(0.1)
return generate_data
except Exception as e:
self.generate_data['status'] = "FAILURE"
self.generate_data['message'] = str(e)
self.redis_client.set(self.tasks_id, json.dumps(self.generate_data))
raise Exception(str(e))
finally:
dict_generate_data, str_generate_data = self.read_tasks_status()
if DEBUG is False:
self.channel.basic_publish(exchange='', routing_key=GMV_RABBITMQ_QUEUES, body=str_generate_data)
# self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=str_generate_data)
logger.info(f" [x] Sent {json.dumps(dict_generate_data, indent=4)}")
def infer_cancel(tasks_id):
redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
data = {'tasks_id': tasks_id, 'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'}
generate_data = json.dumps(data)
redis_client.set(tasks_id, generate_data)
return data
if __name__ == '__main__':
rd = GenerateMultiViewModel(
tasks_id="123-89",
image_url="aida-sys-image/images/female/outwear/0628000123.jpg",
)
server = GenerateMultiView(rd)
print(server.get_result())

View File

@@ -1,4 +1,196 @@
#!/usr/bin/env python
# #!/usr/bin/env python
# # -*- coding: UTF-8 -*-
# """
# @Project trinity_client
# @File service_att_recognition.py
# @Author :周成融
# @Date 2023/7/26 12:01:05
# @detail
# """
# import json
# import logging
# import time
#
# import cv2
# import numpy as np
# import redis
# import tritonclient.grpc as grpcclient
# from PIL import Image
# from tritonclient.utils import np_to_triton_dtype
#
# from app.core.config import *
# from app.schemas.generate_image import GenerateProductImageModel
# from app.service.generate_image.utils.upload_sd_image import upload_SDXL_image
# from app.service.utils.oss_client import oss_get_image
#
# logger = logging.getLogger()
#
#
# class GenerateProductImage:
# def __init__(self, request_data):
# if DEBUG is False:
# self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
# self.channel = self.connection.channel()
# # self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
# # self.channel = self.connection.channel()
# # self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
# self.grpc_client = grpcclient.InferenceServerClient(url=GPI_MODEL_URL)
# self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
# self.category = "product_image"
# self.image_strength = request_data.image_strength
# self.batch_size = 1
# self.product_type = request_data.product_type
# self.prompt = request_data.prompt
# self.image, self.image_size, self.left, self.top = pre_processing_image(request_data.image_url)
# self.tasks_id = request_data.tasks_id
# self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:]
# self.gen_product_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'image_url': ''}
# self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data))
# self.redis_client.expire(self.tasks_id, 600)
#
# def callback(self, result, error):
# if error:
# self.gen_product_data['status'] = "FAILURE"
# self.gen_product_data['message'] = str(error)
# self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data))
# else:
# # pil图像转成numpy数组
# image = result.as_numpy("generated_inpaint_image")
# image_result = Image.fromarray(np.squeeze(image.astype(np.uint8))).resize(self.image_size)
# cropped_image = post_processing_image(image_result, self.left, self.top)
# image_url = upload_SDXL_image(cropped_image, user_id=self.user_id, category=f"{self.category}", file_name=f"{self.tasks_id}.png")
# self.gen_product_data['status'] = "SUCCESS"
# self.gen_product_data['message'] = "success"
# self.gen_product_data['image_url'] = str(image_url)
# self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data))
#
# def read_tasks_status(self):
# status_data = self.redis_client.get(self.tasks_id)
# return json.loads(status_data), status_data
#
# def get_result(self):
# try:
# prompts = [self.prompt] * self.batch_size
# self.image = cv2.cvtColor(self.image, cv2.COLOR_BGR2RGB)
# self.image = cv2.resize(self.image, (1024, 1024))
# images = [self.image.astype(np.uint8)] * self.batch_size
#
# if self.product_type == "single":
# text_obj = np.array(prompts, dtype="object").reshape(-1, 1)
# image_obj = np.array(images, dtype=np.uint8).reshape((-1, 1024, 1024, 3))
# image_strength_obj = np.array(self.image_strength, dtype=np.float32).reshape(-1, 1)
# else:
# text_obj = np.array(prompts, dtype="object").reshape((1))
# image_obj = np.array(images, dtype=np.uint8).reshape((1024, 1024, 3))
# image_strength_obj = np.array(self.image_strength, dtype=np.float32).reshape((1))
#
# # 假设 prompts、images 和 self.image_strength 已经定义
#
# input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype))
# input_image = grpcclient.InferInput("input_image", image_obj.shape, "UINT8")
# input_image_strength = grpcclient.InferInput("image_strength", image_strength_obj.shape, np_to_triton_dtype(image_strength_obj.dtype))
#
# input_text.set_data_from_numpy(text_obj)
# input_image.set_data_from_numpy(image_obj)
# input_image_strength.set_data_from_numpy(image_strength_obj)
#
# inputs = [input_text, input_image, input_image_strength]
#
# if self.product_type == "single":
# ctx = self.grpc_client.async_infer(model_name="stable_diffusion_xl_cnet_inpaint", inputs=inputs, callback=self.callback)
# else:
# ctx = self.grpc_client.async_infer(model_name=GPI_MODEL_NAME_OVERALL, inputs=inputs, callback=self.callback)
#
# time_out = 600
# while time_out > 0:
# gen_product_data, _ = self.read_tasks_status()
# if gen_product_data['status'] in ["REVOKED", "FAILURE"]:
# ctx.cancel()
# break
# elif gen_product_data['status'] == "SUCCESS":
# break
# time_out -= 1
# time.sleep(0.1)
# gen_product_data, _ = self.read_tasks_status()
# return gen_product_data
# except Exception as e:
# self.gen_product_data['status'] = "FAILURE"
# self.gen_product_data['message'] = str(e)
# self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data))
# raise Exception(str(e))
# finally:
# dict_gen_product_data, str_gen_product_data = self.read_tasks_status()
# if DEBUG is False:
# self.channel.basic_publish(exchange='', routing_key=GPI_RABBITMQ_QUEUES, body=str_gen_product_data)
# logger.info(f" [x] Sent to {GPI_RABBITMQ_QUEUES} data@@@@ {json.dumps(dict_gen_product_data, indent=4)}")
#
#
# def infer_cancel(tasks_id):
# redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
# data = {'tasks_id': tasks_id, 'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'}
# gen_product_data = json.dumps(data)
# redis_client.set(tasks_id, gen_product_data)
# return data
#
#
# def pre_processing_image(image_url):
# image = oss_get_image(bucket=image_url.split('/')[0], object_name=image_url[image_url.find('/') + 1:], data_type="PIL")
# # resize 原图至1024*1024
# image = image.resize((int(1024 / image.height * image.width), 1024))
#
# # 原始图片的尺寸
# width, height = image.size
#
# new_height, new_width = 1024, 1024
# # 创建一个新的画布,大小为添加 padding 后的尺寸,并设置为白色背景
# pad_image = Image.new('RGBA', (new_width, new_height), (0, 0, 0, 0))
#
# # 将原始图片粘贴到新的画布中心
# left = (new_width - width) // 2
# top = (new_height - height) // 2
# pad_image.paste(image, (left, top))
#
# # 将画布 resize 成宽度 1024长度 1024
# resized_image = pad_image.resize((1024, 1024))
# image_size = (1024, 1024)
#
# if resized_image.mode in ('RGBA', 'LA') or (resized_image.mode == 'P' and 'transparency' in resized_image.info):
# # 创建白色背景
# background = Image.new("RGB", image_size, (255, 255, 255))
# # 将图片粘贴到白色背景上
# background.paste(resized_image, mask=resized_image.split()[3])
# image = np.array(background)
# image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# return image, image_size, left, top
#
#
# def post_processing_image(image, left, top):
# resized_image = image.resize((int(image.width * (768 / image.height)), 768))
# # 计算裁剪的坐标
# left = (resized_image.width - 512) // 2
# upper = 0
# right = left + 512
# lower = 768
#
# # 进行裁剪
# cropped_image = resized_image.crop((left, upper, right, lower))
# return cropped_image
#
#
# if __name__ == '__main__':
# rd = GenerateProductImageModel(
# tasks_id="123-89",
# # prompt="",
# image_strength=0.7,
# prompt="The best quality, masterpiece,outwear, 8K realistic, HUD",
# image_url="aida-results/result_53381ada-ac64-11ef-ae9d-0242ac150002.png",
# product_type="overall"
# )
# server = GenerateProductImage(rd)
# print(server.get_result())
# 旧版product
# !/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project trinity_client
@@ -34,14 +226,14 @@ class GenerateProductImage:
# self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
# self.channel = self.connection.channel()
# self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
self.grpc_client = grpcclient.InferenceServerClient(url=GPI_MODEL_URL)
self.grpc_client = grpcclient.InferenceServerClient(url="10.1.1.243:18001")
self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
self.category = "product_image"
self.image_strength = request_data.image_strength
self.batch_size = 1
self.product_type = request_data.product_type
self.prompt = request_data.prompt
self.image, self.image_size, self.left, self.top = pre_processing_image(request_data.image_url)
self.image = pre_processing_image(request_data.image_url)
self.tasks_id = request_data.tasks_id
self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:]
self.gen_product_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'image_url': ''}
@@ -55,10 +247,13 @@ class GenerateProductImage:
self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data))
else:
# pil图像转成numpy数组
image = result.as_numpy("generated_inpaint_image")
image_result = Image.fromarray(np.squeeze(image.astype(np.uint8))).resize(self.image_size)
cropped_image = post_processing_image(image_result, self.left, self.top)
image_url = upload_SDXL_image(cropped_image, user_id=self.user_id, category=f"{self.category}", file_name=f"{self.tasks_id}.png")
if self.product_type == "single":
image = result.as_numpy("generated_cnet_image")
else:
image = result.as_numpy("generated_inpaint_image")
image_result = Image.fromarray(np.squeeze(image.astype(np.uint8)))
# cropped_image = post_processing_image(image_result, self.left, self.top)
image_url = upload_SDXL_image(image_result, user_id=self.user_id, category=f"{self.category}", file_name=f"{self.tasks_id}.png")
self.gen_product_data['status'] = "SUCCESS"
self.gen_product_data['message'] = "success"
self.gen_product_data['image_url'] = str(image_url)
@@ -71,17 +266,18 @@ class GenerateProductImage:
def get_result(self):
try:
prompts = [self.prompt] * self.batch_size
self.image = cv2.cvtColor(self.image, cv2.COLOR_BGR2RGB)
self.image = cv2.resize(self.image, (1024, 1024))
self.image = cv2.cvtColor(self.image, cv2.COLOR_RGBA2RGB)
# self.image = cv2.resize(self.image, (1024, 1024))
images = [self.image.astype(np.uint8)] * self.batch_size
if self.product_type == "single":
text_obj = np.array(prompts, dtype="object").reshape(-1, 1)
image_obj = np.array(images, dtype=np.uint8).reshape((-1, 1024, 1024, 3))
image_strength_obj = np.array(self.image_strength, dtype=np.float32).reshape(-1, 1)
text_obj = np.array(prompts, dtype="object").reshape((-1, 1))
image_obj = np.array(images, dtype=np.uint8).reshape((-1, 768, 512, 3))
image_strength_obj = np.array(self.image_strength, dtype=np.float32).reshape((-1, 1))
else:
text_obj = np.array(prompts, dtype="object").reshape((1))
image_obj = np.array(images, dtype=np.uint8).reshape((1024, 1024, 3))
image_obj = np.array(images, dtype=np.uint8).reshape((768, 512, 3))
image_strength_obj = np.array(self.image_strength, dtype=np.float32).reshape((1))
# 假设 prompts、images 和 self.image_strength 已经定义
@@ -97,9 +293,9 @@ class GenerateProductImage:
inputs = [input_text, input_image, input_image_strength]
if self.product_type == "single":
ctx = self.grpc_client.async_infer(model_name="stable_diffusion_xl_cnet_inpaint", inputs=inputs, callback=self.callback)
ctx = self.grpc_client.async_infer(model_name="stable_diffusion_1_5_cnet", inputs=inputs, callback=self.callback)
else:
ctx = self.grpc_client.async_infer(model_name=GPI_MODEL_NAME_OVERALL, inputs=inputs, callback=self.callback)
ctx = self.grpc_client.async_infer(model_name="diffusion_ensemble_all", inputs=inputs, callback=self.callback)
time_out = 600
while time_out > 0:
@@ -135,33 +331,26 @@ def infer_cancel(tasks_id):
def pre_processing_image(image_url):
image = oss_get_image(bucket=image_url.split('/')[0], object_name=image_url[image_url.find('/') + 1:], data_type="PIL")
# resize 原图至1024*1024
image = image.resize((int(1024 / image.height * image.width), 1024))
# 原始图片的尺寸
# 调整图片高度为768像素保持宽高比
width, height = image.size
new_height = 768
new_width = int(width * (new_height / height))
resized_image = image.resize((new_width, new_height))
new_height, new_width = 1024, 1024
# 创建一个新的画布,大小为添加 padding 后的尺寸,并设置为白色背景
pad_image = Image.new('RGBA', (new_width, new_height), (0, 0, 0, 0))
# 创建一个512x768的透明图片
result_image = Image.new("RGBA", (512, 768), (255, 255, 255, 255))
# 将原始图片粘贴到新的画布中心
left = (new_width - width) // 2
top = (new_height - height) // 2
pad_image.paste(image, (left, top))
# 计算需要粘贴的位置,使图片居中
x_offset = (512 - new_width) // 2
y_offset = 0
# 将画布 resize 成宽度 1024长度 1024
resized_image = pad_image.resize((1024, 1024))
image_size = (1024, 1024)
# 将调整大小后的图片粘贴到透明图片上
result_image.paste(resized_image, (x_offset, y_offset), mask=resized_image.split()[3])
if resized_image.mode in ('RGBA', 'LA') or (resized_image.mode == 'P' and 'transparency' in resized_image.info):
# 创建白色背景
background = Image.new("RGB", image_size, (255, 255, 255))
# 将图片粘贴到白色背景上
background.paste(resized_image, mask=resized_image.split()[3])
image = np.array(background)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
return image, image_size, left, top
image = np.array(result_image)
# image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA)
return image
def post_processing_image(image, left, top):
@@ -182,9 +371,9 @@ if __name__ == '__main__':
tasks_id="123-89",
# prompt="",
image_strength=0.7,
prompt="The best quality, masterpiece,outwear, 8K realistic, HUD",
image_url="aida-results/result_53381ada-ac64-11ef-ae9d-0242ac150002.png",
product_type="overall"
prompt="The best quality, masterpiece, real image.,high quality clothing details,8K realistic,HDR",
image_url="aida-results/result_40c7924e-e220-11ef-8ea2-0242ac150003.png",
product_type="single"
)
server = GenerateProductImage(rd)
print(server.get_result())

View File

@@ -8,6 +8,7 @@ from requests import RequestException
from retry import retry
from app.core.config import QWEN_API_KEY
from app.service.chat_robot.script.service.CallQWen import get_language
logger = logging.getLogger(__name__)
@@ -93,7 +94,13 @@ def get_translation_from_llama3(text):
# prompt = f"System: {prefix_for_llama}\nUser:[{text}]"
# 创建请求的负载
# 先获取用户输入文本的语言
language = get_language(text)
if 'English' in language:
return text
# 创建请求的负载 translator是自定义的翻译模型
payload = {
"model": "translator",
"prompt": f"[{text}]",
@@ -117,6 +124,26 @@ def get_translation_from_llama3(text):
print(response.text)
# 在llama3中创建一个翻译模型
# def create_model_with_llama(text):
# url = "http://localhost:11434/api/create"
# # url = "http://10.1.1.240:1143/api/generate"
#
# # prompt = f"System: {prefix_for_llama}\nUser:[{text}]"
#
# # 创建翻译器的配置文件
# payload = {
# "model": "translator",
# "modelfile": "FROM llama3\nSYSTEM Translate everything within the brackets [] into English."
# "Never translate or modify any English input."
# "The input must be fully translated into coherent English sentences."
# }
#
# # 将负载转换为 JSON 格式
# headers = {'Content-Type': 'application/json'}
# response = requests.post(url, data=json.dumps(payload), headers=headers)
def main():
"""Main function"""
text = get_translation_from_llama3("[火焰]")

View File

@@ -7,9 +7,9 @@ def RunTime(func):
t1 = time.time()
res = func(*args, **kwargs)
t2 = time.time()
# if t2 - t1 > 0.05:
# logging.info(f"function【{func.__name__}】,runtime{str(t2 - t1)}】s")
logging.info(f"function{func.__name__}】,runtime{str(t2 - t1)}】s")
if t2 - t1 > 0.05:
logging.info(f"function{func.__name__}】,runtime{str(t2 - t1)}】s")
# logging.info(f"function【{func.__name__}】,runtime{str(t2 - t1)}】s")
return res
return wrapper
@@ -22,7 +22,8 @@ def ClassCallRunTime(func):
end_time = time.time()
execution_time = end_time - start_time
class_name = args[0].__class__.__name__ # 获取类名
print(f"class name: {class_name} , run time is : {execution_time} s")
if execution_time > 0.05:
logging.info(f"class name: {class_name} , run time is : {execution_time} s")
return result
return wrapper

View File

@@ -82,7 +82,7 @@ if __name__ == '__main__':
# url = "aida-users/89/sketchboard/female/Dress/e6724ab7-8d3f-4677-abe0-c3e42ab7af85.jpeg"
# url = "aida-users/87/print/956614a2-7e75-4fbe-9ed0-c1831e37a2c9-4-87.png"
# url = "aida-users/89/single_logo/123-89.png"
url = "aida-users/89/test/123-89.png"
url = "aida-users/89/123-89.png"
# url = "aida-collection-element/12148/Sketchboard/95ea577b-305b-4a62-b30a-39c0dd3ddb3f.png"
read_type = "2"