Merge branch 'refs/heads/develop'
This commit is contained in:
7
.gitignore
vendored
7
.gitignore
vendored
@@ -136,4 +136,9 @@ app/logs/*
|
||||
/qodana.yaml
|
||||
.pth
|
||||
.pytorch
|
||||
*.png
|
||||
*.png
|
||||
*.pth
|
||||
*.db
|
||||
*.npy
|
||||
*.pytorch
|
||||
*.jpg
|
||||
@@ -35,13 +35,13 @@ def attribute_recognition(request_item: list[AttributeRecognitionModel]):
|
||||
"""
|
||||
try:
|
||||
for item in request_item:
|
||||
logger.info(f"attribute_recognition request item is : @@@@@@:{json.dumps(item.dict())}")
|
||||
logger.debug(f"attribute_recognition request item is : @@@@@@:{json.dumps(item.dict())}")
|
||||
if DEBUG:
|
||||
service = AttributeRecognition(const=local_debug_const, request_data=request_item)
|
||||
else:
|
||||
service = AttributeRecognition(const=const, request_data=request_item)
|
||||
data = service.get_result()
|
||||
logger.info(f"attribute_recognition response @@@@@@:{json.dumps(data)}")
|
||||
logger.debug(f"attribute_recognition response @@@@@@:{json.dumps(data)}")
|
||||
except Exception as e:
|
||||
logger.warning(f"attribute_recognition Run Exception @@@@@@:{e}")
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
34
app/api/api_brand_dna.py
Normal file
34
app/api/api_brand_dna.py
Normal file
@@ -0,0 +1,34 @@
|
||||
import json
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
|
||||
from app.schemas.brand_dna import BrandDnaModel
|
||||
from app.schemas.response_template import ResponseModel
|
||||
from app.service.brand_dna.service import BrandDna
|
||||
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
@router.post("/seg_product")
|
||||
def image2sketch(request_item: BrandDnaModel):
|
||||
"""
|
||||
创建一个具有以下参数的请求体:
|
||||
- **image_url**: 提取图片url
|
||||
- **is_brand_dna**: 是否提取属性
|
||||
|
||||
示例参数:
|
||||
{
|
||||
"image_url": "test/image2sketch/real_Dress_3200fecdc83d0c556c2bd96aedbd7fbf.jpg_Img.jpg",
|
||||
"is_brand_dna": False
|
||||
}
|
||||
"""
|
||||
try:
|
||||
logger.info(f"brand dna request item is : @@@@@@:{json.dumps(request_item.dict())}")
|
||||
service = BrandDna(request_item)
|
||||
result_url = service.get_result()
|
||||
except Exception as e:
|
||||
logger.warning(f"brand dna Run Exception @@@@@@:{e}")
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
return ResponseModel(data=result_url)
|
||||
@@ -4,7 +4,7 @@ import os
|
||||
|
||||
from fastapi import APIRouter, HTTPException, UploadFile, File, Form, BackgroundTasks
|
||||
|
||||
from app.schemas.design import DesignModel, DesignProgressModel, ModelProgressModel, DBGConfigModel
|
||||
from app.schemas.design import DesignModel, DesignProgressModel, ModelProgressModel, DBGConfigModel, DesignStreamModel
|
||||
from app.schemas.response_template import ResponseModel
|
||||
from app.service.design.model_process_service import model_transpose
|
||||
from app.service.design_batch.service import start_design_batch_generate
|
||||
@@ -18,6 +18,14 @@ logger = logging.getLogger()
|
||||
@router.post("/design")
|
||||
def design(request_data: DesignModel, background_tasks: BackgroundTasks):
|
||||
"""
|
||||
objects.items.transparent:
|
||||
"transparent":{
|
||||
"mask_url":"test/transparent_test/transparent_mask.png",
|
||||
"scale":0.1
|
||||
},
|
||||
mask_url 为空"" -> 单件衣服透明
|
||||
mask_url 非空"mask_url" -> 区域透明
|
||||
|
||||
创建一个具有以下参数的请求体:
|
||||
示例参数:
|
||||
{
|
||||
@@ -197,7 +205,7 @@ def design(request_data: DesignModel, background_tasks: BackgroundTasks):
|
||||
|
||||
|
||||
@router.post("/design_v2")
|
||||
async def design_v2(request_data: DesignModel, background_tasks: BackgroundTasks):
|
||||
async def design_v2(request_data: DesignStreamModel, background_tasks: BackgroundTasks):
|
||||
"""
|
||||
创建一个具有以下参数的请求体:
|
||||
示例参数:
|
||||
@@ -445,7 +453,7 @@ async def design(file: UploadFile = File(...),
|
||||
|
||||
async def save_request_file(contents, file_name):
|
||||
# 创建保存文件的目录(如果不存在)
|
||||
save_dir = os.path.join(os.getcwd(), "design_batch", "request_data")
|
||||
save_dir = os.path.join(os.getcwd(), "service/design_batch", "request_data")
|
||||
if not os.path.exists(save_dir):
|
||||
os.makedirs(save_dir)
|
||||
# 处理文件
|
||||
|
||||
@@ -3,9 +3,10 @@ import logging
|
||||
|
||||
from fastapi import APIRouter, BackgroundTasks, HTTPException
|
||||
|
||||
from app.schemas.generate_image import GenerateImageModel, GenerateProductImageModel, GenerateSingleLogoImageModel, GenerateRelightImageModel
|
||||
from app.schemas.generate_image import GenerateImageModel, GenerateProductImageModel, GenerateSingleLogoImageModel, GenerateRelightImageModel, GenerateMultiViewModel
|
||||
from app.schemas.response_template import ResponseModel
|
||||
from app.service.generate_image.service_generate_image import GenerateImage, infer_cancel as generate_image_infer_cancel
|
||||
from app.service.generate_image.service_generate_multi_view import GenerateMultiView, infer_cancel as generate_multi_view_cancel
|
||||
from app.service.generate_image.service_generate_product_image import GenerateProductImage, infer_cancel as generate_product_image_cancel
|
||||
from app.service.generate_image.service_generate_relight_image import GenerateRelightImage, infer_cancel as generate_relight_image_cancel
|
||||
from app.service.generate_image.service_generate_single_logo import GenerateSingleLogoImage, infer_cancel as generate_single_logo_cancel
|
||||
@@ -61,6 +62,44 @@ def generate_image(tasks_id: str):
|
||||
return ResponseModel(data=data['data'])
|
||||
|
||||
|
||||
'''multi view'''
|
||||
|
||||
|
||||
@router.post("/generate_multi_view")
|
||||
def generate_multi_view(request_item: GenerateMultiViewModel, background_tasks: BackgroundTasks):
|
||||
"""
|
||||
创建一个具有以下参数的请求体:
|
||||
- **tasks_id**: 任务id 用于取消生成任务和获取生成结果
|
||||
- **image_url**: 前视角图的输入,minio或S3 url 地址
|
||||
|
||||
示例参数:
|
||||
{
|
||||
"tasks_id": "123-89",
|
||||
"image_url": "aida-collection-element/87/Printboard/842c09cf-7297-42d9-9e6e-9c17d4a13cb5.jpg"
|
||||
}
|
||||
"""
|
||||
try:
|
||||
logger.info(f"generate_multi_view request item is : @@@@@@:{json.dumps(request_item.dict())}")
|
||||
service = GenerateMultiView(request_item)
|
||||
background_tasks.add_task(service.get_result)
|
||||
except Exception as e:
|
||||
logger.warning(f"generate_multi_view Run Exception @@@@@@:{e}")
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
return ResponseModel()
|
||||
|
||||
|
||||
@router.get("/generate_multi_view_cancel/{tasks_id}")
|
||||
def generate_multi_view(tasks_id: str):
|
||||
try:
|
||||
logger.info(f"generate_cancel request item is : @@@@@@:{tasks_id}")
|
||||
data = generate_multi_view_cancel(tasks_id)
|
||||
logger.info(f"generate_cancel response @@@@@@:{data}")
|
||||
except Exception as e:
|
||||
logger.warning(f"generate_cancel Run Exception @@@@@@:{e}")
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
return ResponseModel(data=data['data'])
|
||||
|
||||
|
||||
'''single logo'''
|
||||
|
||||
|
||||
|
||||
@@ -26,7 +26,7 @@ def prompt_generation(request_data: PromptGenerationImageModel):
|
||||
"""
|
||||
try:
|
||||
logger.info(f"prompt_generation request item is : @@@@@@:{request_data}")
|
||||
data = get_translation_from_llama3("[" + request_data.text + "]")
|
||||
data = get_translation_from_llama3(request_data.text)
|
||||
logger.info(f"prompt_generation response @@@@@@:{data}")
|
||||
except Exception as e:
|
||||
logger.warning(f"prompt_generation Run Exception @@@@@@:{e}")
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from fastapi import APIRouter
|
||||
|
||||
from app.api import api_attribute_retrieve, api_query_image
|
||||
from app.api import api_brand_dna
|
||||
from app.api import api_brighten
|
||||
from app.api import api_chat_robot
|
||||
from app.api import api_design
|
||||
@@ -23,4 +24,5 @@ router.include_router(api_prompt_generation.router, tags=['prompt_generation'],
|
||||
router.include_router(api_design_pre_processing.router, tags=['design_pre_processing'], prefix="/api")
|
||||
router.include_router(api_image2sketch.router, tags=['api_image2sketch'], prefix="/api")
|
||||
router.include_router(api_brighten.router, tags=['api_brighten'], prefix="/api")
|
||||
router.include_router(api_query_image.router, tags=['api_query_image'], prefix="/api")
|
||||
router.include_router(api_query_image.router, tags=['api_query_image'], prefix="/api")
|
||||
router.include_router(api_brand_dna.router, tags=['api_brand_dna'], prefix="/api")
|
||||
|
||||
@@ -4,7 +4,7 @@ import logging
|
||||
from fastapi import APIRouter
|
||||
from fastapi import HTTPException
|
||||
|
||||
from app.core.config import SR_RABBITMQ_QUEUES, GI_RABBITMQ_QUEUES, GPI_RABBITMQ_QUEUES, GRI_RABBITMQ_QUEUES, OSS
|
||||
from app.core.config import SR_RABBITMQ_QUEUES, GI_RABBITMQ_QUEUES, GPI_RABBITMQ_QUEUES, GRI_RABBITMQ_QUEUES, OSS, JAVA_STREAM_API_URL
|
||||
from app.schemas.response_template import ResponseModel
|
||||
|
||||
logger = logging.getLogger()
|
||||
@@ -18,6 +18,7 @@ def test(id: int):
|
||||
"GI_RABBITMQ_QUEUES": GI_RABBITMQ_QUEUES,
|
||||
"GPI_RABBITMQ_QUEUES": GPI_RABBITMQ_QUEUES,
|
||||
"GRI_RABBITMQ_QUEUES": GRI_RABBITMQ_QUEUES,
|
||||
"JAVA_STREAM_API_URL": JAVA_STREAM_API_URL,
|
||||
"local_oss_server": OSS
|
||||
}
|
||||
logger.info(json.dumps(data))
|
||||
|
||||
@@ -34,6 +34,9 @@ else:
|
||||
RABBITMQ_ENV = "-dev" # 开发环境
|
||||
# RABBITMQ_ENV = "-local" # 本地测试环境
|
||||
|
||||
JAVA_STREAM_API_URL = os.getenv("JAVA_STREAM_API_URL", "https://api.aida.com.hk/api/third/party/receiveDesignResults")
|
||||
|
||||
|
||||
settings = Settings()
|
||||
|
||||
# minio 配置
|
||||
@@ -106,6 +109,12 @@ FAST_GI_MODEL_NAME = 'stable_diffusion_xl'
|
||||
GI_MODEL_URL = '10.1.1.240:10061'
|
||||
GI_MODEL_NAME = 'flux'
|
||||
|
||||
GMV_MODEL_URL = '10.1.1.243:10081'
|
||||
GMV_MODEL_NAME = 'multi_view'
|
||||
|
||||
GMV_RABBITMQ_QUEUES = os.getenv("GMV_RABBITMQ_QUEUES", f"GenerateMultiView{RABBITMQ_ENV}")
|
||||
|
||||
|
||||
GI_MINIO_BUCKET = "aida-users"
|
||||
GI_RABBITMQ_QUEUES = os.getenv("GI_RABBITMQ_QUEUES", f"GenerateImage{RABBITMQ_ENV}")
|
||||
GI_SYS_IMAGE_URL = "aida-sys-image/generate_image/white_image.jpg"
|
||||
|
||||
6
app/schemas/brand_dna.py
Normal file
6
app/schemas/brand_dna.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class BrandDnaModel(BaseModel):
|
||||
image_url: str
|
||||
is_brand_dna: bool
|
||||
@@ -6,6 +6,12 @@ class DesignModel(BaseModel):
|
||||
process_id: str
|
||||
|
||||
|
||||
class DesignStreamModel(BaseModel):
|
||||
objects: list[dict]
|
||||
process_id: str
|
||||
requestId: str
|
||||
|
||||
|
||||
class DesignProgressModel(BaseModel):
|
||||
process_id: str
|
||||
|
||||
|
||||
@@ -1,6 +1,11 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class GenerateMultiViewModel(BaseModel):
|
||||
tasks_id: str
|
||||
image_url: str
|
||||
|
||||
|
||||
class GenerateImageModel(BaseModel):
|
||||
tasks_id: str
|
||||
prompt: str
|
||||
|
||||
335
app/service/brand_dna/service.py
Normal file
335
app/service/brand_dna/service.py
Normal file
@@ -0,0 +1,335 @@
|
||||
import logging
|
||||
|
||||
import cv2
|
||||
import mmcv
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import tritonclient.http as httpclient
|
||||
from minio import Minio
|
||||
|
||||
from app.core.config import MINIO_URL, MINIO_ACCESS, MINIO_SECRET, MINIO_SECURE, DESIGN_MODEL_URL
|
||||
from app.schemas.brand_dna import BrandDnaModel
|
||||
from app.service.attribute.config import local_debug_const
|
||||
from app.service.utils.generate_uuid import generate_uuid
|
||||
from app.service.utils.new_oss_client import oss_upload_image, oss_get_image
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
||||
|
||||
|
||||
class BrandDna:
|
||||
def __init__(self, request_item):
|
||||
self.sketch_bucket = "test"
|
||||
self.image_url = request_item.image_url
|
||||
self.is_brand_dna = request_item.is_brand_dna
|
||||
# self.attr_type = pd.read_csv(CATEGORY_PATH)
|
||||
self.attr_type = pd.read_csv(r"E:\workspace\trinity_client_aida\app\service\attribute\config\descriptor\category\category_dis.csv")
|
||||
self.att_client = httpclient.InferenceServerClient(url=DESIGN_MODEL_URL)
|
||||
self.seg_client = httpclient.InferenceServerClient(url='10.1.1.243:30000')
|
||||
# self.const = const
|
||||
self.const = local_debug_const
|
||||
|
||||
# 获取结果
|
||||
def get_result(self):
|
||||
mask, image = self.get_seg_mask()
|
||||
cv2.imshow("", image)
|
||||
cv2.waitKey(0)
|
||||
|
||||
height, width, channels = image.shape
|
||||
result_dict = []
|
||||
white_img = np.ones((height, width, channels), dtype=image.dtype) * 255
|
||||
mask_image = np.zeros((height, width, 3))
|
||||
|
||||
for value in np.unique(mask):
|
||||
if value == 1:
|
||||
outwear_img = white_img.copy()
|
||||
outwear_mask_img = mask_image.copy()
|
||||
outwear_img[mask == value] = image[mask == value]
|
||||
outwear_mask_img[mask == value] = [0, 0, 255]
|
||||
|
||||
cv2.imshow("", outwear_img)
|
||||
cv2.waitKey(0)
|
||||
|
||||
# 预处理之后的input img
|
||||
preprocess_img = self.category_preprocess(outwear_img)
|
||||
# 类别检测
|
||||
category = self.recognition_category(preprocess_img)
|
||||
if category == 'Trousers' or category == 'Skirt':
|
||||
male_category = 'Bottoms'
|
||||
elif category == 'Blouse' or category == 'Dress':
|
||||
male_category = 'Tops'
|
||||
else:
|
||||
male_category = 'Outwear'
|
||||
|
||||
attribute = {}
|
||||
mask_url = ""
|
||||
img_url = ""
|
||||
# 属性检测
|
||||
if self.is_brand_dna:
|
||||
attribute = self.get_recognition_attribute_result(category, preprocess_img)
|
||||
else:
|
||||
img_url = self.put_image(outwear_img, f"img/{generate_uuid()}")
|
||||
mask_url = self.put_image(outwear_mask_img, f"mask/{generate_uuid()}")
|
||||
|
||||
result_dict.append(
|
||||
{
|
||||
'category_female': category,
|
||||
'category_male': male_category,
|
||||
'mask_url': mask_url,
|
||||
'img_url': img_url,
|
||||
'attribute': attribute
|
||||
}
|
||||
)
|
||||
if value == 2:
|
||||
tops_img = white_img.copy()
|
||||
tops_mask_img = mask_image.copy()
|
||||
tops_img[mask == value] = image[mask == value]
|
||||
tops_mask_img[mask == value] = [0, 0, 255]
|
||||
|
||||
cv2.imshow("", tops_img)
|
||||
cv2.waitKey(0)
|
||||
|
||||
# 预处理之后的input img
|
||||
preprocess_img = self.category_preprocess(tops_img)
|
||||
# 类别检测
|
||||
category = self.recognition_category(preprocess_img)
|
||||
if category == 'Trousers' or category == 'Skirt':
|
||||
male_category = 'Bottoms'
|
||||
elif category == 'Blouse' or category == 'Dress':
|
||||
male_category = 'Tops'
|
||||
else:
|
||||
male_category = 'Outwear'
|
||||
|
||||
# 属性检测
|
||||
attribute = {}
|
||||
img_url = ""
|
||||
mask_url = ""
|
||||
# 属性检测
|
||||
if self.is_brand_dna:
|
||||
attribute = self.get_recognition_attribute_result(category, preprocess_img)
|
||||
else:
|
||||
mask_url = self.put_image(tops_mask_img, f"mask/{generate_uuid()}")
|
||||
img_url = self.put_image(tops_img, f"img/{generate_uuid()}")
|
||||
|
||||
result_dict.append(
|
||||
{
|
||||
'category_female': category,
|
||||
'category_male': male_category,
|
||||
'mask_url': mask_url,
|
||||
'img_url': img_url,
|
||||
'attribute': attribute
|
||||
}
|
||||
)
|
||||
if value == 3:
|
||||
bottoms_img = white_img.copy()
|
||||
bottoms_mask_img = mask_image.copy()
|
||||
bottoms_img[mask == value] = image[mask == value]
|
||||
bottoms_mask_img[mask == value] = [0, 0, 255]
|
||||
|
||||
cv2.imshow("", bottoms_img)
|
||||
cv2.waitKey(0)
|
||||
|
||||
# 预处理之后的input img
|
||||
preprocess_img = self.category_preprocess(bottoms_img)
|
||||
# 类别检测
|
||||
category = self.recognition_category(preprocess_img)
|
||||
if category == 'Trousers' or category == 'Skirt':
|
||||
male_category = 'Bottoms'
|
||||
elif category == 'Blouse' or category == 'Dress':
|
||||
male_category = 'Tops'
|
||||
else:
|
||||
male_category = 'Outwear'
|
||||
|
||||
attribute = {}
|
||||
img_url = ""
|
||||
mask_url = ""
|
||||
# 属性检测
|
||||
if self.is_brand_dna:
|
||||
attribute = self.get_recognition_attribute_result(category, preprocess_img)
|
||||
else:
|
||||
img_url = self.put_image(bottoms_img, f"img/{generate_uuid()}")
|
||||
mask_url = self.put_image(bottoms_mask_img, f"mask/{generate_uuid()}")
|
||||
|
||||
result_dict.append(
|
||||
{
|
||||
'category_female': category,
|
||||
'category_male': male_category,
|
||||
'mask_url': mask_url,
|
||||
'img_url': img_url,
|
||||
'attribute': attribute
|
||||
}
|
||||
)
|
||||
return result_dict
|
||||
|
||||
# 获取product mask
|
||||
def get_seg_mask(self):
|
||||
input_image = self.get_image()
|
||||
input_img, ori_shape = self.seg_product_preprocess(input_image)
|
||||
transformed_img = input_img.astype(np.float32)
|
||||
|
||||
inputs = [httpclient.InferInput(f"seg_input__0", transformed_img.shape, datatype="FP32")]
|
||||
inputs[0].set_data_from_numpy(transformed_img, binary_data=True)
|
||||
outputs = [httpclient.InferRequestedOutput(f"seg_output__0", binary_data=True)]
|
||||
results = self.seg_client.infer(model_name=f"seg_product", inputs=inputs, outputs=outputs)
|
||||
inference_output1 = results.as_numpy("seg_output__0")
|
||||
mask = self.product_postprocess(inference_output1, ori_shape)[0]
|
||||
return mask, input_image
|
||||
|
||||
# 获取图片
|
||||
def get_image(self):
|
||||
image = oss_get_image(oss_client=minio_client, bucket=self.image_url.split('/')[0], object_name=self.image_url[self.image_url.find('/') + 1:], data_type="cv2")
|
||||
# 将其转换为彩色图像
|
||||
if len(image.shape) == 3 and image.shape[2] == 4:
|
||||
image = cv2.cvtColor(image, cv2.COLOR_BGRA2BGR)
|
||||
elif len(image.shape) == 2:
|
||||
image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
|
||||
return image
|
||||
# return cv2.imread(self.image_url)
|
||||
|
||||
# 图片上传
|
||||
def put_image(self, image, object_name):
|
||||
try:
|
||||
image_bytes = cv2.imencode('.jpg', image)[1].tobytes()
|
||||
oss_upload_image(oss_client=minio_client, bucket=self.sketch_bucket, object_name=f"{object_name}.jpg", image_bytes=image_bytes)
|
||||
return f"{self.sketch_bucket}/{object_name}.jpg"
|
||||
except Exception as e:
|
||||
logger.warning(e)
|
||||
|
||||
# 服装分割预处理
|
||||
@staticmethod
|
||||
def seg_product_preprocess(image):
|
||||
img = mmcv.imread(image)
|
||||
ori_shape = img.shape[:2]
|
||||
img_scale_w, img_scale_h = ori_shape
|
||||
if ori_shape[0] > 1024:
|
||||
img_scale_w = 1024
|
||||
if ori_shape[1] > 1024:
|
||||
img_scale_h = 1024
|
||||
# 如果图片size任意一边 大于 1024, 则会resize 成1024
|
||||
if ori_shape != (img_scale_w, img_scale_h):
|
||||
# mmcv.imresize(img, img_scale_h, img_scale_w) # 老代码 引以为戒!哈哈哈~ h和w写反了
|
||||
img = cv2.resize(img, (img_scale_h, img_scale_w))
|
||||
img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
|
||||
preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0)
|
||||
return preprocessed_img, ori_shape
|
||||
|
||||
# 类别检测后处理
|
||||
@staticmethod
|
||||
def product_postprocess(output, ori_shape):
|
||||
seg_logit = F.interpolate(torch.tensor(output).float(), size=ori_shape, scale_factor=None, mode='bilinear', align_corners=False)
|
||||
seg_pred = seg_logit.cpu().numpy()
|
||||
return seg_pred[0]
|
||||
|
||||
# 类别检测模型预处理
|
||||
@staticmethod
|
||||
def category_preprocess(img):
|
||||
img = mmcv.imread(img)
|
||||
# ori_shape = img.shape[:2]
|
||||
img_scale = (224, 224)
|
||||
img = cv2.resize(img, img_scale)
|
||||
img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
|
||||
preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0)
|
||||
return preprocessed_img
|
||||
|
||||
# 类别检测
|
||||
def recognition_category(self, image):
|
||||
inputs = [
|
||||
httpclient.InferInput("input__0", image.shape, datatype="FP32")
|
||||
]
|
||||
inputs[0].set_data_from_numpy(image, binary_data=True)
|
||||
results = self.att_client.infer(model_name="attr_retrieve_category", inputs=inputs)
|
||||
inference_output = torch.from_numpy(results.as_numpy(f'output__0'))
|
||||
|
||||
scores = inference_output.detach().numpy()
|
||||
|
||||
colattr = list(self.attr_type['labelName'])
|
||||
maxsc = np.max(scores[0][:5])
|
||||
indexs = np.argwhere(scores == maxsc)[:, 1]
|
||||
|
||||
return colattr[indexs[0]]
|
||||
|
||||
# 属性检测
|
||||
def recognition_attribute(self, model_name, description, image):
|
||||
attr_type = pd.read_csv(description)
|
||||
inputs = [
|
||||
httpclient.InferInput("input__0", image.shape, datatype="FP32")
|
||||
]
|
||||
inputs[0].set_data_from_numpy(image, binary_data=True)
|
||||
results = self.att_client.infer(model_name=model_name, inputs=inputs)
|
||||
inference_output = torch.from_numpy(results.as_numpy(f"output__0"))
|
||||
scores = inference_output.detach().numpy()
|
||||
colattr = list(attr_type['labelName'])
|
||||
task = description.split('/')[-1][:-4]
|
||||
maxsc = np.max(scores[0][:5])
|
||||
indexs = np.argwhere(scores == maxsc)[:, 1]
|
||||
attr = {
|
||||
task: []
|
||||
}
|
||||
for i in range(len(indexs)):
|
||||
atr = colattr[indexs[i]]
|
||||
attr[task].append(atr)
|
||||
return attr
|
||||
|
||||
# 获取属性检测结果
|
||||
def get_recognition_attribute_result(self, category, input_img):
|
||||
if category == "Blouse":
|
||||
attr_dict = {}
|
||||
for i in range(len(self.const.top_description_list)):
|
||||
attr_description = self.const.top_description_list[i]
|
||||
attr_model_path = self.const.top_model_list[i]
|
||||
present_dict = self.recognition_attribute(attr_model_path, attr_description, input_img)
|
||||
attr_dict = self.merge(attr_dict, present_dict)
|
||||
|
||||
elif category == 'Trousers' or category == "Skirt":
|
||||
attr_dict = {}
|
||||
for i in range(len(self.const.bottom_description_list)):
|
||||
attr_description = self.const.bottom_description_list[i]
|
||||
attr_model_path = self.const.bottom_model_list[i]
|
||||
present_dict = self.recognition_attribute(attr_model_path, attr_description, input_img)
|
||||
attr_dict = self.merge(attr_dict, present_dict)
|
||||
|
||||
elif category == 'Dress':
|
||||
attr_dict = {}
|
||||
for i in range(len(self.const.dress_description_list)):
|
||||
attr_description = self.const.dress_description_list[i]
|
||||
attr_model_path = self.const.dress_model_list[i]
|
||||
present_dict = self.recognition_attribute(attr_model_path, attr_description, input_img)
|
||||
attr_dict = self.merge(attr_dict, present_dict)
|
||||
|
||||
elif category == 'Outwear':
|
||||
attr_dict = {}
|
||||
for i in range(len(self.const.outwear_description_list)):
|
||||
attr_description = self.const.outwear_description_list[i]
|
||||
attr_model_path = self.const.outwear_model_list[i]
|
||||
present_dict = self.recognition_attribute(attr_model_path, attr_description, input_img)
|
||||
attr_dict = self.merge(attr_dict, present_dict)
|
||||
else:
|
||||
attr_dict = {}
|
||||
return attr_dict
|
||||
|
||||
@staticmethod
|
||||
def merge(dict1, dict2):
|
||||
res = {**dict1, **dict2}
|
||||
return res
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# for path in os.listdir('./test_img'):
|
||||
# img_path = os.path.join(r'./test_img', path)
|
||||
# request_item = BrandDnaModel(
|
||||
# image_url=img_path,
|
||||
# is_brand_dna=True
|
||||
# )
|
||||
# service = BrandDna(request_item)
|
||||
# result_url = service.get_result()
|
||||
# print(result_url)
|
||||
request_item = BrandDnaModel(
|
||||
image_url="aida-users/60/product_image/07cb5d5d-5022-44cc-b0d3-cc986cfebad1-2-60.png",
|
||||
is_brand_dna=True
|
||||
)
|
||||
service = BrandDna(request_item)
|
||||
result_url = service.get_result()
|
||||
print(result_url)
|
||||
@@ -69,3 +69,5 @@ TOOLS_FUNCTIONS_SUFFIX = (
|
||||
)
|
||||
|
||||
TUTORIAL_TOOL_RETURN = "Commencing the systematic tutorial guide now."
|
||||
|
||||
GET_LANGUAGE_PREFIX = "Please identify the language. Only output the language name"
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import json
|
||||
import logging
|
||||
|
||||
from dashscope import Generation
|
||||
from retry import retry
|
||||
@@ -7,7 +8,8 @@ from urllib3.exceptions import NewConnectionError
|
||||
from app.core.config import *
|
||||
from app.service.chat_robot.script.callbacks.qwen_callback_handler import QWenCallbackHandler
|
||||
from app.service.chat_robot.script.database import CustomDatabase
|
||||
from app.service.chat_robot.script.prompt import FASHION_CHAT_BOT_PREFIX, TOOLS_FUNCTIONS_SUFFIX, TUTORIAL_TOOL_RETURN
|
||||
from app.service.chat_robot.script.prompt import FASHION_CHAT_BOT_PREFIX, TOOLS_FUNCTIONS_SUFFIX, TUTORIAL_TOOL_RETURN, \
|
||||
GET_LANGUAGE_PREFIX
|
||||
from app.service.search_image_with_text.service import query
|
||||
|
||||
get_database_table_description = "Input is an empty string, output is a comma separated list of tables in the database."
|
||||
@@ -159,10 +161,11 @@ def search_from_internet(message):
|
||||
model='qwen-turbo',
|
||||
api_key=QWEN_API_KEY,
|
||||
messages=message,
|
||||
tools=tools,
|
||||
prompt='The output must be in English.Keep the final result under 200 words.'
|
||||
# tools=tools,
|
||||
# seed=random.randint(1, 10000), # 设置随机数种子seed,如果没有设置,则随机数种子默认为1234
|
||||
result_format='message', # 将输出设置为message形式
|
||||
enable_search='True'
|
||||
# result_format='message', # 将输出设置为message形式
|
||||
# enable_search='True'
|
||||
)
|
||||
return response
|
||||
|
||||
@@ -198,14 +201,9 @@ def get_response(messages):
|
||||
|
||||
|
||||
def call_with_messages(message, gender):
|
||||
global tool_info
|
||||
user_input = message
|
||||
print('\n')
|
||||
# messages = [
|
||||
# {
|
||||
# "content": input('请输入:'), # 提问示例:"现在几点了?" "一个小时后几点" "北京天气如何?"
|
||||
# "role": "user"
|
||||
# }
|
||||
# ]
|
||||
|
||||
messages = [
|
||||
{
|
||||
@@ -223,14 +221,10 @@ def call_with_messages(message, gender):
|
||||
}
|
||||
]
|
||||
|
||||
# 模型的第一轮调用
|
||||
# first_response = get_response(messages)
|
||||
# assistant_output = first_response.output.choices[0].message
|
||||
# print(f"\n大模型第一轮输出信息:{first_response}\n")
|
||||
# messages.append(assistant_output)
|
||||
flag = True
|
||||
count = 1
|
||||
result_content = "我是一个时尚AI助手,请问有什么可以帮您"
|
||||
# result_content = "我是一个时尚AI助手,请问有什么可以帮您"
|
||||
result_content = "I am a fashion AI assistant, how can I help you?"
|
||||
response_type = "chat"
|
||||
|
||||
while flag and count <= 3:
|
||||
@@ -244,44 +238,53 @@ def call_with_messages(message, gender):
|
||||
print(f"最终答案:{assistant_output.content}") # 此处直接返回模型的回复,您可以根据您的业务,选择当无需调用工具时最终回复的内容
|
||||
result_content = assistant_output.content
|
||||
break
|
||||
# 如果模型选择的工具是search_from_internet
|
||||
# elif assistant_output.tool_calls[0]['function']['name'] == 'search_from_internet':
|
||||
# tool_info = {"name": "search_from_internet", "role": "tool"}
|
||||
# user_input = json.loads(assistant_output.tool_calls[0]['function']['arguments'])['user_input']
|
||||
# tool_info['content'] = search_from_internet(user_input)
|
||||
# 如果模型选择的工具是get_database_table
|
||||
elif assistant_output.tool_calls[0]['function']['name'] == 'get_database_table':
|
||||
tool_info = {"name": "get_database_table", "role": "tool", 'content': get_database_table()}
|
||||
# 如果模型选择的工具是get_table_info
|
||||
elif assistant_output.tool_calls[0]['function']['name'] == 'get_table_info':
|
||||
tool_info = {"name": "get_table_info", "role": "tool"}
|
||||
table_names = json.loads(assistant_output.tool_calls[0]['function']['arguments'])['table_names']
|
||||
tool_info['content'] = get_table_info(table_names)
|
||||
# 如果模型选择的工具是query_database
|
||||
elif assistant_output.tool_calls[0]['function']['name'] == 'query_database':
|
||||
tool_info = {"name": "query_database", "role": "tool"}
|
||||
sql_string = json.loads(assistant_output.tool_calls[0]['function']['arguments'])['sql_string']
|
||||
tool_info['content'] = query_database(sql_string)
|
||||
# 如果模型选择的工具是internet_search
|
||||
elif assistant_output.tool_calls[0]['function']['name'] == 'internet_search':
|
||||
tool_info = {"name": "search_from_internet", "role": "tool"}
|
||||
content = json.loads(assistant_output.tool_calls[0]['function']['arguments'])
|
||||
message = [
|
||||
{'role': 'assistant', 'content': content['query']}
|
||||
]
|
||||
tool_info['content'] = search_from_internet(message)
|
||||
flag = False
|
||||
result_content = tool_info['content']
|
||||
response_type = "image"
|
||||
result_content = tool_info['content'].output.text
|
||||
# 如果模型选择的工具是get_database_table
|
||||
# elif assistant_output.tool_calls[0]['function']['name'] == 'get_database_table':
|
||||
# tool_info = {"name": "get_database_table", "role": "tool", 'content': get_database_table()}
|
||||
# # 如果模型选择的工具是get_table_info
|
||||
# elif assistant_output.tool_calls[0]['function']['name'] == 'get_table_info':
|
||||
# tool_info = {"name": "get_table_info", "role": "tool"}
|
||||
# table_names = json.loads(assistant_output.tool_calls[0]['function']['arguments'])['table_names']
|
||||
# tool_info['content'] = get_table_info(table_names)
|
||||
# # 如果模型选择的工具是query_database
|
||||
# elif assistant_output.tool_calls[0]['function']['name'] == 'query_database':
|
||||
# tool_info = {"name": "query_database", "role": "tool"}
|
||||
# sql_string = json.loads(assistant_output.tool_calls[0]['function']['arguments'])['sql_string']
|
||||
# tool_info['content'] = query_database(sql_string)
|
||||
# flag = False
|
||||
# result_content = tool_info['content']
|
||||
# response_type = "image"
|
||||
elif assistant_output.tool_calls[0]['function']['name'] == 'tutorial_tool':
|
||||
tool_info = {"name": "tutorial_tool", "role": "tool", 'content': tutorial_tool()}
|
||||
flag = False
|
||||
result_content = tool_info['content']
|
||||
elif assistant_output.tool_calls[0]['function']['name'] == 'get_image_from_vector_db':
|
||||
content = json.loads(assistant_output.tool_calls[0]['function']['arguments'])
|
||||
tool_info = {"name": "get_image_from_vector_db", "role": "tool",
|
||||
'content': get_image_from_vector_db(gender, user_input)}
|
||||
'content': get_image_from_vector_db(gender, content['parameters']['content'])}
|
||||
flag = False
|
||||
result_content = tool_info['content']
|
||||
response_type = "image"
|
||||
else:
|
||||
tool_info = {"name": assistant_output.tool_calls[0]['function']['name'], 'content': 'null'}
|
||||
logging.info(assistant_output.tool_calls[0]['function']['name'] + "(unknown tools)")
|
||||
flag = False
|
||||
|
||||
print(f"工具输出信息:{tool_info['content']}\n")
|
||||
messages.append(tool_info)
|
||||
count += 1
|
||||
|
||||
final_output = {"output": result_content}
|
||||
final_output["response_type"] = response_type
|
||||
final_output = {"output": result_content, "response_type": response_type}
|
||||
QWenCallbackHandler.on_chain_end(qwen, final_output)
|
||||
|
||||
# 模型的第二轮调用,对工具的输出进行总结
|
||||
@@ -298,5 +301,23 @@ def tutorial_tool():
|
||||
return TUTORIAL_TOOL_RETURN
|
||||
|
||||
|
||||
def get_language(message: str) -> str:
|
||||
messages = [
|
||||
{
|
||||
"content": message, # 用户message
|
||||
"role": "user"
|
||||
},
|
||||
{
|
||||
"content": GET_LANGUAGE_PREFIX, # ai message
|
||||
"role": "assistant"
|
||||
}
|
||||
]
|
||||
|
||||
first_response = get_response(messages)
|
||||
assistant_output = first_response.output.choices[0].message.content
|
||||
logging.info(f"大模型输出信息:{first_response}\n判断用户输入的语言为:{assistant_output}")
|
||||
return assistant_output
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
call_with_messages()
|
||||
get_language("")
|
||||
|
||||
@@ -53,7 +53,7 @@ class Segmentation(object):
|
||||
file_path = f"{SEG_CACHE_PATH}{image_id}.npy"
|
||||
try:
|
||||
np.save(file_path, seg_result)
|
||||
logger.info(f"保存成功 :{os.path.abspath(file_path)}")
|
||||
logger.debug(f"保存成功 :{os.path.abspath(file_path)}")
|
||||
except Exception as e:
|
||||
logger.error(f"保存失败: {e}")
|
||||
|
||||
@@ -64,7 +64,7 @@ class Segmentation(object):
|
||||
seg_result = np.load(file_path)
|
||||
return True, seg_result
|
||||
except FileNotFoundError:
|
||||
logger.warning("文件不存在")
|
||||
# logger.warning("文件不存在")
|
||||
return False, None
|
||||
except Exception as e:
|
||||
logger.error(f"加载失败: {e}")
|
||||
|
||||
@@ -12,7 +12,7 @@ from app.service.design_batch.utils.save_json import oss_upload_json
|
||||
from app.service.design_batch.utils.synthesis_item import update_base_size_priority, synthesis, synthesis_single
|
||||
|
||||
id_lock = threading.Lock()
|
||||
celery_app = Celery('tasks', broker='amqp://guest:guest@10.1.2.213:5672//', backend='rpc://')
|
||||
celery_app = Celery('tasks', broker=f'amqp://rabbit:123456@18.167.251.121:5672//', backend='rpc://', BROKER_CONNECTION_RETRY_ON_STARTUP=True)
|
||||
celery_app.conf.worker_log_format = '%(asctime)s %(filename)s [line:%(lineno)d] %(levelname)s %(message)s'
|
||||
celery_app.conf.worker_hijack_root_logger = False
|
||||
logging.getLogger('pika').setLevel(logging.WARNING)
|
||||
@@ -120,7 +120,7 @@ def batch_design(objects_data, tasks_id, json_name):
|
||||
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
logger.debug(object_response)
|
||||
oss_upload_json(minio_client, object_response, json_name)
|
||||
publish_status(tasks_id, "ok", json_name)
|
||||
return object_response
|
||||
|
||||
@@ -51,19 +51,19 @@ class Segmentation:
|
||||
file_path = f"seg_cache/{image_id}.npy"
|
||||
try:
|
||||
np.save(file_path, seg_result)
|
||||
logger.info(f"保存成功 :{os.path.abspath(file_path)}")
|
||||
logger.debug(f"保存成功 :{os.path.abspath(file_path)}")
|
||||
except Exception as e:
|
||||
logger.error(f"保存失败: {e}")
|
||||
|
||||
@staticmethod
|
||||
def load_seg_result(image_id):
|
||||
file_path = f"seg_cache/{image_id}.npy"
|
||||
logger.info(f"load seg file name is :{SEG_CACHE_PATH}{image_id}.npy")
|
||||
# logger.info(f"load seg file name is :{SEG_CACHE_PATH}{image_id}.npy")
|
||||
try:
|
||||
seg_result = np.load(file_path)
|
||||
return True, seg_result
|
||||
except FileNotFoundError:
|
||||
logger.warning("文件不存在")
|
||||
# logger.warning("文件不存在")
|
||||
return False, None
|
||||
except Exception as e:
|
||||
logger.error(f"加载失败: {e}")
|
||||
|
||||
@@ -5,7 +5,7 @@ from app.service.design_batch.utils.MQ import publish_status
|
||||
|
||||
|
||||
async def start_design_batch_generate(data, file):
|
||||
generate_clothes_task = batch_design.delay(json.loads(file.decode())['objects'], data.total, data.tasks_id)
|
||||
generate_clothes_task = batch_design.delay(json.loads(file.decode())['objects'], data.tasks_id, data.file_name)
|
||||
print(generate_clothes_task)
|
||||
publish_status(data.tasks_id, "0/100", "")
|
||||
return {"task_id": data.tasks_id}
|
||||
|
||||
@@ -157,6 +157,6 @@ if __name__ == '__main__':
|
||||
],
|
||||
"process_id": "83"
|
||||
}
|
||||
task_id = 1
|
||||
task_id = 10086
|
||||
json_name = "test.json"
|
||||
batch_design.delay(data['objects'], task_id, json_name)
|
||||
|
||||
@@ -2,9 +2,11 @@ import json
|
||||
|
||||
import pika
|
||||
|
||||
from app.core.config import RABBITMQ_PARAMS
|
||||
|
||||
|
||||
def publish_status(task_id, progress, result):
|
||||
connection = pika.BlockingConnection(pika.ConnectionParameters('10.1.2.213'))
|
||||
connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
|
||||
channel = connection.channel()
|
||||
channel.queue_declare(queue='DesignBatch', durable=True)
|
||||
message = {'task_id': task_id, 'progress': progress, "result": result}
|
||||
|
||||
@@ -1,13 +1,21 @@
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
def oss_upload_json(oss_client, json_data, object_name):
|
||||
try:
|
||||
with open(f"app/service/design_batch/response_json/{object_name}", 'w') as file:
|
||||
save_dir = os.path.join(os.getcwd(), "service/design_batch", "response_data")
|
||||
if not os.path.exists(save_dir):
|
||||
os.makedirs(save_dir)
|
||||
# 处理文件
|
||||
file_path = os.path.join(save_dir, object_name)
|
||||
with open(file_path, 'w') as file:
|
||||
json.dump(json_data, file, indent=4)
|
||||
oss_client.fput_object("test", object_name, f"app/service/design_batch/response_json/{object_name}")
|
||||
json_bytes = json.dumps(json_data).encode('utf-8')
|
||||
oss_client.put_object("test", object_name, io.BytesIO(json_bytes), length=len(json_bytes), content_type="application/json")
|
||||
except Exception as e:
|
||||
logger.warning(str(e))
|
||||
|
||||
@@ -139,6 +139,7 @@ def design_generate(request_data):
|
||||
@RunTime
|
||||
def design_generate_v2(request_data):
|
||||
objects_data = request_data.dict()['objects']
|
||||
request_id = request_data.requestId
|
||||
threads = []
|
||||
|
||||
def process_object(step, object):
|
||||
@@ -146,7 +147,7 @@ def design_generate_v2(request_data):
|
||||
items_response = {
|
||||
'layers': [],
|
||||
'objectSign': object['objectSign'] if 'objectSign' in object.keys() else "",
|
||||
'requestId': object['requestId'] if 'requestId' in object.keys() else ""
|
||||
'requestId': request_id
|
||||
}
|
||||
if basic['single_overall'] == "overall":
|
||||
item_results = []
|
||||
@@ -197,9 +198,8 @@ def design_generate_v2(request_data):
|
||||
'pattern_image_url': item_result['pattern_image_url'] if 'pattern_image_url' in item_result.keys() else None,
|
||||
})
|
||||
items_response['synthesis_url'] = synthesis_single(item_result['front_image'], item_result['back_image'])
|
||||
|
||||
# 发送结果给java端
|
||||
url = "https://3998-117-143-125-51.ngrok-free.app/api/third/party/receiveDesignResults"
|
||||
url = JAVA_STREAM_API_URL
|
||||
headers = {
|
||||
'Accept': "*/*",
|
||||
'Accept-Encoding': "gzip, deflate, br",
|
||||
@@ -207,11 +207,11 @@ def design_generate_v2(request_data):
|
||||
'Connection': "keep-alive",
|
||||
'Content-Type': "application/json"
|
||||
}
|
||||
# logger.info(items_response)
|
||||
response = post_request(url, json_data=items_response, headers=headers)
|
||||
if response:
|
||||
# 打印结果
|
||||
logger.info(response.text)
|
||||
logger.info(items_response)
|
||||
|
||||
for step, object in enumerate(objects_data):
|
||||
t = threading.Thread(target=process_object, args=(step, object))
|
||||
|
||||
@@ -66,19 +66,19 @@ class Segmentation:
|
||||
file_path = f"{SEG_CACHE_PATH}{image_id}.npy"
|
||||
try:
|
||||
np.save(file_path, seg_result)
|
||||
logger.info(f"保存成功 :{os.path.abspath(file_path)}")
|
||||
logger.debug(f"保存成功 :{os.path.abspath(file_path)}")
|
||||
except Exception as e:
|
||||
logger.error(f"保存失败: {e}")
|
||||
|
||||
@staticmethod
|
||||
def load_seg_result(image_id):
|
||||
file_path = f"{SEG_CACHE_PATH}{image_id}.npy"
|
||||
logger.info(f"load seg file name is :{SEG_CACHE_PATH}{image_id}.npy")
|
||||
# logger.info(f"load seg file name is :{SEG_CACHE_PATH}{image_id}.npy")
|
||||
try:
|
||||
seg_result = np.load(file_path)
|
||||
return True, seg_result
|
||||
except FileNotFoundError:
|
||||
logger.warning("文件不存在")
|
||||
# logger.warning("文件不存在")
|
||||
return False, None
|
||||
except Exception as e:
|
||||
logger.error(f"加载失败: {e}")
|
||||
|
||||
@@ -25,6 +25,7 @@ from app.core.config import *
|
||||
|
||||
def keypoint_preprocess(img_path):
|
||||
img = mmcv.imread(img_path)
|
||||
img = cv2.copyMakeBorder(img, 25, 25, 25, 25, cv2.BORDER_CONSTANT, value=[255, 255, 255])
|
||||
img_scale = (256, 256)
|
||||
h, w = img.shape[:2]
|
||||
img = cv2.resize(img, img_scale)
|
||||
@@ -62,7 +63,11 @@ def keypoint_postprocess(output, scale_factor):
|
||||
scale_matrix = np.diag(scale_factor)
|
||||
nan = np.isinf(scale_matrix)
|
||||
scale_matrix[nan] = 0
|
||||
return np.ceil(np.dot(segment_result, scale_matrix) * 4)
|
||||
# 应用缩放因子
|
||||
scaled_result = np.ceil(np.dot(segment_result, scale_matrix) * 4)
|
||||
# 补偿边框偏移
|
||||
compensated_result = scaled_result - 25
|
||||
return compensated_result
|
||||
|
||||
|
||||
"""
|
||||
|
||||
@@ -266,7 +266,7 @@ class DesignPreprocessing:
|
||||
seg_result = np.load(file_path)
|
||||
return True, seg_result
|
||||
except FileNotFoundError:
|
||||
logging.info("文件不存在")
|
||||
# logging.info("文件不存在")
|
||||
return False, None
|
||||
except Exception as e:
|
||||
logging.warning(f"加载失败: {e}")
|
||||
@@ -277,7 +277,7 @@ class DesignPreprocessing:
|
||||
file_path = f"{SEG_CACHE_PATH}{image_id}.npy"
|
||||
try:
|
||||
np.save(file_path, seg_result)
|
||||
logging.info(f"保存成功,{os.path.abspath(file_path)}")
|
||||
logging.debug(f"保存成功,{os.path.abspath(file_path)}")
|
||||
except Exception as e:
|
||||
logging.warning(f"保存失败: {e}")
|
||||
|
||||
|
||||
126
app/service/generate_image/service_generate_multi_view.py
Normal file
126
app/service/generate_image/service_generate_multi_view.py
Normal file
@@ -0,0 +1,126 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: UTF-8 -*-
|
||||
"""
|
||||
@Project :trinity_client
|
||||
@File :service_att_recognition.py
|
||||
@Author :周成融
|
||||
@Date :2023/7/26 12:01:05
|
||||
@detail :
|
||||
"""
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import redis
|
||||
import tritonclient.grpc as grpcclient
|
||||
|
||||
from app.core.config import *
|
||||
from app.schemas.generate_image import GenerateMultiViewModel
|
||||
from app.service.generate_image.utils.upload_sd_image import upload_png_sd
|
||||
from app.service.utils.oss_client import oss_get_image
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
class GenerateMultiView:
|
||||
def __init__(self, request_data):
|
||||
if DEBUG is False:
|
||||
self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
|
||||
self.channel = self.connection.channel()
|
||||
# self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
|
||||
# self.channel = self.connection.channel()
|
||||
# self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
||||
self.grpc_client = grpcclient.InferenceServerClient(url=GMV_MODEL_URL)
|
||||
|
||||
self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
|
||||
self.image = self.get_image(request_data.image_url)
|
||||
self.tasks_id = request_data.tasks_id
|
||||
self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:]
|
||||
self.generate_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'image_url': ''}
|
||||
self.redis_client.set(self.tasks_id, json.dumps(self.generate_data))
|
||||
self.redis_client.expire(self.tasks_id, 600)
|
||||
|
||||
def get_image(self, image_url):
|
||||
try:
|
||||
image = oss_get_image(bucket=image_url.split('/')[0], object_name=image_url[image_url.find('/') + 1:], data_type="PIL")
|
||||
return image
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
|
||||
def callback(self, result, error):
|
||||
if error:
|
||||
self.generate_data['status'] = "FAILURE"
|
||||
self.generate_data['message'] = str(error)
|
||||
# self.generate_data['data'] = str(error)
|
||||
self.redis_client.set(self.tasks_id, json.dumps(self.generate_data))
|
||||
else:
|
||||
# pil图像转成numpy数组
|
||||
images = result.as_numpy("generated_image")
|
||||
# for id, img in enumerate(images):
|
||||
# cv2.imwrite(f"{id}.png", img)
|
||||
# image_url = ""
|
||||
image_url = upload_png_sd(images[6], user_id=self.user_id, category="multi_view", file_name=f"{self.tasks_id}.png")
|
||||
# logger.info(f"upload image SUCCESS : {image_url}")
|
||||
self.generate_data['status'] = "SUCCESS"
|
||||
self.generate_data['message'] = "success"
|
||||
self.generate_data['image_url'] = str(image_url)
|
||||
self.redis_client.set(self.tasks_id, json.dumps(self.generate_data))
|
||||
|
||||
def read_tasks_status(self):
|
||||
status_data = self.redis_client.get(self.tasks_id)
|
||||
return json.loads(status_data), status_data
|
||||
|
||||
def get_result(self):
|
||||
try:
|
||||
images = [np.array(self.image).astype(np.uint8)] * 1
|
||||
|
||||
image_obj = np.array(images, dtype=np.uint8)
|
||||
|
||||
input_image = grpcclient.InferInput("input_image", image_obj.shape, "UINT8")
|
||||
|
||||
input_image.set_data_from_numpy(image_obj)
|
||||
|
||||
inputs = [input_image]
|
||||
ctx = self.grpc_client.async_infer(model_name=GMV_MODEL_NAME, inputs=inputs, callback=self.callback)
|
||||
|
||||
time_out = 600
|
||||
generate_data = None
|
||||
while time_out > 0:
|
||||
generate_data, _ = self.read_tasks_status()
|
||||
if generate_data['status'] in ["REVOKED", "FAILURE"]:
|
||||
ctx.cancel()
|
||||
break
|
||||
elif generate_data['status'] == "SUCCESS":
|
||||
break
|
||||
time_out -= 1
|
||||
time.sleep(0.1)
|
||||
return generate_data
|
||||
except Exception as e:
|
||||
self.generate_data['status'] = "FAILURE"
|
||||
self.generate_data['message'] = str(e)
|
||||
self.redis_client.set(self.tasks_id, json.dumps(self.generate_data))
|
||||
raise Exception(str(e))
|
||||
finally:
|
||||
dict_generate_data, str_generate_data = self.read_tasks_status()
|
||||
if DEBUG is False:
|
||||
self.channel.basic_publish(exchange='', routing_key=GMV_RABBITMQ_QUEUES, body=str_generate_data)
|
||||
# self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=str_generate_data)
|
||||
logger.info(f" [x] Sent {json.dumps(dict_generate_data, indent=4)}")
|
||||
|
||||
|
||||
def infer_cancel(tasks_id):
|
||||
redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
|
||||
data = {'tasks_id': tasks_id, 'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'}
|
||||
generate_data = json.dumps(data)
|
||||
redis_client.set(tasks_id, generate_data)
|
||||
return data
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
rd = GenerateMultiViewModel(
|
||||
tasks_id="123-89",
|
||||
image_url="aida-sys-image/images/female/outwear/0628000123.jpg",
|
||||
)
|
||||
server = GenerateMultiView(rd)
|
||||
print(server.get_result())
|
||||
@@ -1,4 +1,196 @@
|
||||
#!/usr/bin/env python
|
||||
# #!/usr/bin/env python
|
||||
# # -*- coding: UTF-8 -*-
|
||||
# """
|
||||
# @Project :trinity_client
|
||||
# @File :service_att_recognition.py
|
||||
# @Author :周成融
|
||||
# @Date :2023/7/26 12:01:05
|
||||
# @detail :
|
||||
# """
|
||||
# import json
|
||||
# import logging
|
||||
# import time
|
||||
#
|
||||
# import cv2
|
||||
# import numpy as np
|
||||
# import redis
|
||||
# import tritonclient.grpc as grpcclient
|
||||
# from PIL import Image
|
||||
# from tritonclient.utils import np_to_triton_dtype
|
||||
#
|
||||
# from app.core.config import *
|
||||
# from app.schemas.generate_image import GenerateProductImageModel
|
||||
# from app.service.generate_image.utils.upload_sd_image import upload_SDXL_image
|
||||
# from app.service.utils.oss_client import oss_get_image
|
||||
#
|
||||
# logger = logging.getLogger()
|
||||
#
|
||||
#
|
||||
# class GenerateProductImage:
|
||||
# def __init__(self, request_data):
|
||||
# if DEBUG is False:
|
||||
# self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
|
||||
# self.channel = self.connection.channel()
|
||||
# # self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
|
||||
# # self.channel = self.connection.channel()
|
||||
# # self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
||||
# self.grpc_client = grpcclient.InferenceServerClient(url=GPI_MODEL_URL)
|
||||
# self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
|
||||
# self.category = "product_image"
|
||||
# self.image_strength = request_data.image_strength
|
||||
# self.batch_size = 1
|
||||
# self.product_type = request_data.product_type
|
||||
# self.prompt = request_data.prompt
|
||||
# self.image, self.image_size, self.left, self.top = pre_processing_image(request_data.image_url)
|
||||
# self.tasks_id = request_data.tasks_id
|
||||
# self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:]
|
||||
# self.gen_product_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'image_url': ''}
|
||||
# self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data))
|
||||
# self.redis_client.expire(self.tasks_id, 600)
|
||||
#
|
||||
# def callback(self, result, error):
|
||||
# if error:
|
||||
# self.gen_product_data['status'] = "FAILURE"
|
||||
# self.gen_product_data['message'] = str(error)
|
||||
# self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data))
|
||||
# else:
|
||||
# # pil图像转成numpy数组
|
||||
# image = result.as_numpy("generated_inpaint_image")
|
||||
# image_result = Image.fromarray(np.squeeze(image.astype(np.uint8))).resize(self.image_size)
|
||||
# cropped_image = post_processing_image(image_result, self.left, self.top)
|
||||
# image_url = upload_SDXL_image(cropped_image, user_id=self.user_id, category=f"{self.category}", file_name=f"{self.tasks_id}.png")
|
||||
# self.gen_product_data['status'] = "SUCCESS"
|
||||
# self.gen_product_data['message'] = "success"
|
||||
# self.gen_product_data['image_url'] = str(image_url)
|
||||
# self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data))
|
||||
#
|
||||
# def read_tasks_status(self):
|
||||
# status_data = self.redis_client.get(self.tasks_id)
|
||||
# return json.loads(status_data), status_data
|
||||
#
|
||||
# def get_result(self):
|
||||
# try:
|
||||
# prompts = [self.prompt] * self.batch_size
|
||||
# self.image = cv2.cvtColor(self.image, cv2.COLOR_BGR2RGB)
|
||||
# self.image = cv2.resize(self.image, (1024, 1024))
|
||||
# images = [self.image.astype(np.uint8)] * self.batch_size
|
||||
#
|
||||
# if self.product_type == "single":
|
||||
# text_obj = np.array(prompts, dtype="object").reshape(-1, 1)
|
||||
# image_obj = np.array(images, dtype=np.uint8).reshape((-1, 1024, 1024, 3))
|
||||
# image_strength_obj = np.array(self.image_strength, dtype=np.float32).reshape(-1, 1)
|
||||
# else:
|
||||
# text_obj = np.array(prompts, dtype="object").reshape((1))
|
||||
# image_obj = np.array(images, dtype=np.uint8).reshape((1024, 1024, 3))
|
||||
# image_strength_obj = np.array(self.image_strength, dtype=np.float32).reshape((1))
|
||||
#
|
||||
# # 假设 prompts、images 和 self.image_strength 已经定义
|
||||
#
|
||||
# input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype))
|
||||
# input_image = grpcclient.InferInput("input_image", image_obj.shape, "UINT8")
|
||||
# input_image_strength = grpcclient.InferInput("image_strength", image_strength_obj.shape, np_to_triton_dtype(image_strength_obj.dtype))
|
||||
#
|
||||
# input_text.set_data_from_numpy(text_obj)
|
||||
# input_image.set_data_from_numpy(image_obj)
|
||||
# input_image_strength.set_data_from_numpy(image_strength_obj)
|
||||
#
|
||||
# inputs = [input_text, input_image, input_image_strength]
|
||||
#
|
||||
# if self.product_type == "single":
|
||||
# ctx = self.grpc_client.async_infer(model_name="stable_diffusion_xl_cnet_inpaint", inputs=inputs, callback=self.callback)
|
||||
# else:
|
||||
# ctx = self.grpc_client.async_infer(model_name=GPI_MODEL_NAME_OVERALL, inputs=inputs, callback=self.callback)
|
||||
#
|
||||
# time_out = 600
|
||||
# while time_out > 0:
|
||||
# gen_product_data, _ = self.read_tasks_status()
|
||||
# if gen_product_data['status'] in ["REVOKED", "FAILURE"]:
|
||||
# ctx.cancel()
|
||||
# break
|
||||
# elif gen_product_data['status'] == "SUCCESS":
|
||||
# break
|
||||
# time_out -= 1
|
||||
# time.sleep(0.1)
|
||||
# gen_product_data, _ = self.read_tasks_status()
|
||||
# return gen_product_data
|
||||
# except Exception as e:
|
||||
# self.gen_product_data['status'] = "FAILURE"
|
||||
# self.gen_product_data['message'] = str(e)
|
||||
# self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data))
|
||||
# raise Exception(str(e))
|
||||
# finally:
|
||||
# dict_gen_product_data, str_gen_product_data = self.read_tasks_status()
|
||||
# if DEBUG is False:
|
||||
# self.channel.basic_publish(exchange='', routing_key=GPI_RABBITMQ_QUEUES, body=str_gen_product_data)
|
||||
# logger.info(f" [x] Sent to: {GPI_RABBITMQ_QUEUES} data:@@@@ {json.dumps(dict_gen_product_data, indent=4)}")
|
||||
#
|
||||
#
|
||||
# def infer_cancel(tasks_id):
|
||||
# redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
|
||||
# data = {'tasks_id': tasks_id, 'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'}
|
||||
# gen_product_data = json.dumps(data)
|
||||
# redis_client.set(tasks_id, gen_product_data)
|
||||
# return data
|
||||
#
|
||||
#
|
||||
# def pre_processing_image(image_url):
|
||||
# image = oss_get_image(bucket=image_url.split('/')[0], object_name=image_url[image_url.find('/') + 1:], data_type="PIL")
|
||||
# # resize 原图至1024*1024
|
||||
# image = image.resize((int(1024 / image.height * image.width), 1024))
|
||||
#
|
||||
# # 原始图片的尺寸
|
||||
# width, height = image.size
|
||||
#
|
||||
# new_height, new_width = 1024, 1024
|
||||
# # 创建一个新的画布,大小为添加 padding 后的尺寸,并设置为白色背景
|
||||
# pad_image = Image.new('RGBA', (new_width, new_height), (0, 0, 0, 0))
|
||||
#
|
||||
# # 将原始图片粘贴到新的画布中心
|
||||
# left = (new_width - width) // 2
|
||||
# top = (new_height - height) // 2
|
||||
# pad_image.paste(image, (left, top))
|
||||
#
|
||||
# # 将画布 resize 成宽度 1024,长度 1024
|
||||
# resized_image = pad_image.resize((1024, 1024))
|
||||
# image_size = (1024, 1024)
|
||||
#
|
||||
# if resized_image.mode in ('RGBA', 'LA') or (resized_image.mode == 'P' and 'transparency' in resized_image.info):
|
||||
# # 创建白色背景
|
||||
# background = Image.new("RGB", image_size, (255, 255, 255))
|
||||
# # 将图片粘贴到白色背景上
|
||||
# background.paste(resized_image, mask=resized_image.split()[3])
|
||||
# image = np.array(background)
|
||||
# image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||
# return image, image_size, left, top
|
||||
#
|
||||
#
|
||||
# def post_processing_image(image, left, top):
|
||||
# resized_image = image.resize((int(image.width * (768 / image.height)), 768))
|
||||
# # 计算裁剪的坐标
|
||||
# left = (resized_image.width - 512) // 2
|
||||
# upper = 0
|
||||
# right = left + 512
|
||||
# lower = 768
|
||||
#
|
||||
# # 进行裁剪
|
||||
# cropped_image = resized_image.crop((left, upper, right, lower))
|
||||
# return cropped_image
|
||||
#
|
||||
#
|
||||
# if __name__ == '__main__':
|
||||
# rd = GenerateProductImageModel(
|
||||
# tasks_id="123-89",
|
||||
# # prompt="",
|
||||
# image_strength=0.7,
|
||||
# prompt="The best quality, masterpiece,outwear, 8K realistic, HUD",
|
||||
# image_url="aida-results/result_53381ada-ac64-11ef-ae9d-0242ac150002.png",
|
||||
# product_type="overall"
|
||||
# )
|
||||
# server = GenerateProductImage(rd)
|
||||
# print(server.get_result())
|
||||
|
||||
# 旧版product
|
||||
# !/usr/bin/env python
|
||||
# -*- coding: UTF-8 -*-
|
||||
"""
|
||||
@Project :trinity_client
|
||||
@@ -34,14 +226,14 @@ class GenerateProductImage:
|
||||
# self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
|
||||
# self.channel = self.connection.channel()
|
||||
# self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
||||
self.grpc_client = grpcclient.InferenceServerClient(url=GPI_MODEL_URL)
|
||||
self.grpc_client = grpcclient.InferenceServerClient(url="10.1.1.243:18001")
|
||||
self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
|
||||
self.category = "product_image"
|
||||
self.image_strength = request_data.image_strength
|
||||
self.batch_size = 1
|
||||
self.product_type = request_data.product_type
|
||||
self.prompt = request_data.prompt
|
||||
self.image, self.image_size, self.left, self.top = pre_processing_image(request_data.image_url)
|
||||
self.image = pre_processing_image(request_data.image_url)
|
||||
self.tasks_id = request_data.tasks_id
|
||||
self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:]
|
||||
self.gen_product_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'image_url': ''}
|
||||
@@ -55,10 +247,13 @@ class GenerateProductImage:
|
||||
self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data))
|
||||
else:
|
||||
# pil图像转成numpy数组
|
||||
image = result.as_numpy("generated_inpaint_image")
|
||||
image_result = Image.fromarray(np.squeeze(image.astype(np.uint8))).resize(self.image_size)
|
||||
cropped_image = post_processing_image(image_result, self.left, self.top)
|
||||
image_url = upload_SDXL_image(cropped_image, user_id=self.user_id, category=f"{self.category}", file_name=f"{self.tasks_id}.png")
|
||||
if self.product_type == "single":
|
||||
image = result.as_numpy("generated_cnet_image")
|
||||
else:
|
||||
image = result.as_numpy("generated_inpaint_image")
|
||||
image_result = Image.fromarray(np.squeeze(image.astype(np.uint8)))
|
||||
# cropped_image = post_processing_image(image_result, self.left, self.top)
|
||||
image_url = upload_SDXL_image(image_result, user_id=self.user_id, category=f"{self.category}", file_name=f"{self.tasks_id}.png")
|
||||
self.gen_product_data['status'] = "SUCCESS"
|
||||
self.gen_product_data['message'] = "success"
|
||||
self.gen_product_data['image_url'] = str(image_url)
|
||||
@@ -71,17 +266,18 @@ class GenerateProductImage:
|
||||
def get_result(self):
|
||||
try:
|
||||
prompts = [self.prompt] * self.batch_size
|
||||
self.image = cv2.cvtColor(self.image, cv2.COLOR_BGR2RGB)
|
||||
self.image = cv2.resize(self.image, (1024, 1024))
|
||||
|
||||
self.image = cv2.cvtColor(self.image, cv2.COLOR_RGBA2RGB)
|
||||
# self.image = cv2.resize(self.image, (1024, 1024))
|
||||
images = [self.image.astype(np.uint8)] * self.batch_size
|
||||
|
||||
if self.product_type == "single":
|
||||
text_obj = np.array(prompts, dtype="object").reshape(-1, 1)
|
||||
image_obj = np.array(images, dtype=np.uint8).reshape((-1, 1024, 1024, 3))
|
||||
image_strength_obj = np.array(self.image_strength, dtype=np.float32).reshape(-1, 1)
|
||||
text_obj = np.array(prompts, dtype="object").reshape((-1, 1))
|
||||
image_obj = np.array(images, dtype=np.uint8).reshape((-1, 768, 512, 3))
|
||||
image_strength_obj = np.array(self.image_strength, dtype=np.float32).reshape((-1, 1))
|
||||
else:
|
||||
text_obj = np.array(prompts, dtype="object").reshape((1))
|
||||
image_obj = np.array(images, dtype=np.uint8).reshape((1024, 1024, 3))
|
||||
image_obj = np.array(images, dtype=np.uint8).reshape((768, 512, 3))
|
||||
image_strength_obj = np.array(self.image_strength, dtype=np.float32).reshape((1))
|
||||
|
||||
# 假设 prompts、images 和 self.image_strength 已经定义
|
||||
@@ -97,9 +293,9 @@ class GenerateProductImage:
|
||||
inputs = [input_text, input_image, input_image_strength]
|
||||
|
||||
if self.product_type == "single":
|
||||
ctx = self.grpc_client.async_infer(model_name="stable_diffusion_xl_cnet_inpaint", inputs=inputs, callback=self.callback)
|
||||
ctx = self.grpc_client.async_infer(model_name="stable_diffusion_1_5_cnet", inputs=inputs, callback=self.callback)
|
||||
else:
|
||||
ctx = self.grpc_client.async_infer(model_name=GPI_MODEL_NAME_OVERALL, inputs=inputs, callback=self.callback)
|
||||
ctx = self.grpc_client.async_infer(model_name="diffusion_ensemble_all", inputs=inputs, callback=self.callback)
|
||||
|
||||
time_out = 600
|
||||
while time_out > 0:
|
||||
@@ -135,33 +331,26 @@ def infer_cancel(tasks_id):
|
||||
|
||||
def pre_processing_image(image_url):
|
||||
image = oss_get_image(bucket=image_url.split('/')[0], object_name=image_url[image_url.find('/') + 1:], data_type="PIL")
|
||||
# resize 原图至1024*1024
|
||||
image = image.resize((int(1024 / image.height * image.width), 1024))
|
||||
|
||||
# 原始图片的尺寸
|
||||
# 调整图片高度为768像素,保持宽高比
|
||||
width, height = image.size
|
||||
new_height = 768
|
||||
new_width = int(width * (new_height / height))
|
||||
resized_image = image.resize((new_width, new_height))
|
||||
|
||||
new_height, new_width = 1024, 1024
|
||||
# 创建一个新的画布,大小为添加 padding 后的尺寸,并设置为白色背景
|
||||
pad_image = Image.new('RGBA', (new_width, new_height), (0, 0, 0, 0))
|
||||
# 创建一个512x768的透明图片
|
||||
result_image = Image.new("RGBA", (512, 768), (255, 255, 255, 255))
|
||||
|
||||
# 将原始图片粘贴到新的画布中心
|
||||
left = (new_width - width) // 2
|
||||
top = (new_height - height) // 2
|
||||
pad_image.paste(image, (left, top))
|
||||
# 计算需要粘贴的位置,使图片居中
|
||||
x_offset = (512 - new_width) // 2
|
||||
y_offset = 0
|
||||
|
||||
# 将画布 resize 成宽度 1024,长度 1024
|
||||
resized_image = pad_image.resize((1024, 1024))
|
||||
image_size = (1024, 1024)
|
||||
# 将调整大小后的图片粘贴到透明图片上
|
||||
result_image.paste(resized_image, (x_offset, y_offset), mask=resized_image.split()[3])
|
||||
|
||||
if resized_image.mode in ('RGBA', 'LA') or (resized_image.mode == 'P' and 'transparency' in resized_image.info):
|
||||
# 创建白色背景
|
||||
background = Image.new("RGB", image_size, (255, 255, 255))
|
||||
# 将图片粘贴到白色背景上
|
||||
background.paste(resized_image, mask=resized_image.split()[3])
|
||||
image = np.array(background)
|
||||
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||
return image, image_size, left, top
|
||||
image = np.array(result_image)
|
||||
|
||||
# image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA)
|
||||
return image
|
||||
|
||||
|
||||
def post_processing_image(image, left, top):
|
||||
@@ -182,9 +371,9 @@ if __name__ == '__main__':
|
||||
tasks_id="123-89",
|
||||
# prompt="",
|
||||
image_strength=0.7,
|
||||
prompt="The best quality, masterpiece,outwear, 8K realistic, HUD",
|
||||
image_url="aida-results/result_53381ada-ac64-11ef-ae9d-0242ac150002.png",
|
||||
product_type="overall"
|
||||
prompt="The best quality, masterpiece, real image.,high quality clothing details,8K realistic,HDR",
|
||||
image_url="aida-results/result_40c7924e-e220-11ef-8ea2-0242ac150003.png",
|
||||
product_type="single"
|
||||
)
|
||||
server = GenerateProductImage(rd)
|
||||
print(server.get_result())
|
||||
|
||||
@@ -8,6 +8,7 @@ from requests import RequestException
|
||||
from retry import retry
|
||||
|
||||
from app.core.config import QWEN_API_KEY
|
||||
from app.service.chat_robot.script.service.CallQWen import get_language
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -93,7 +94,13 @@ def get_translation_from_llama3(text):
|
||||
|
||||
# prompt = f"System: {prefix_for_llama}\nUser:[{text}]"
|
||||
|
||||
# 创建请求的负载
|
||||
# 先获取用户输入文本的语言
|
||||
language = get_language(text)
|
||||
|
||||
if 'English' in language:
|
||||
return text
|
||||
|
||||
# 创建请求的负载 translator是自定义的翻译模型
|
||||
payload = {
|
||||
"model": "translator",
|
||||
"prompt": f"[{text}]",
|
||||
@@ -117,6 +124,26 @@ def get_translation_from_llama3(text):
|
||||
print(response.text)
|
||||
|
||||
|
||||
# 在llama3中创建一个翻译模型
|
||||
# def create_model_with_llama(text):
|
||||
# url = "http://localhost:11434/api/create"
|
||||
# # url = "http://10.1.1.240:1143/api/generate"
|
||||
#
|
||||
# # prompt = f"System: {prefix_for_llama}\nUser:[{text}]"
|
||||
#
|
||||
# # 创建翻译器的配置文件
|
||||
# payload = {
|
||||
# "model": "translator",
|
||||
# "modelfile": "FROM llama3\nSYSTEM Translate everything within the brackets [] into English."
|
||||
# "Never translate or modify any English input."
|
||||
# "The input must be fully translated into coherent English sentences."
|
||||
# }
|
||||
#
|
||||
# # 将负载转换为 JSON 格式
|
||||
# headers = {'Content-Type': 'application/json'}
|
||||
# response = requests.post(url, data=json.dumps(payload), headers=headers)
|
||||
|
||||
|
||||
def main():
|
||||
"""Main function"""
|
||||
text = get_translation_from_llama3("[火焰]")
|
||||
|
||||
@@ -7,9 +7,9 @@ def RunTime(func):
|
||||
t1 = time.time()
|
||||
res = func(*args, **kwargs)
|
||||
t2 = time.time()
|
||||
# if t2 - t1 > 0.05:
|
||||
# logging.info(f"function:【{func.__name__}】,runtime:【{str(t2 - t1)}】s")
|
||||
logging.info(f"function:【{func.__name__}】,runtime:【{str(t2 - t1)}】s")
|
||||
if t2 - t1 > 0.05:
|
||||
logging.info(f"function:【{func.__name__}】,runtime:【{str(t2 - t1)}】s")
|
||||
# logging.info(f"function:【{func.__name__}】,runtime:【{str(t2 - t1)}】s")
|
||||
return res
|
||||
|
||||
return wrapper
|
||||
@@ -22,7 +22,8 @@ def ClassCallRunTime(func):
|
||||
end_time = time.time()
|
||||
execution_time = end_time - start_time
|
||||
class_name = args[0].__class__.__name__ # 获取类名
|
||||
print(f"class name: {class_name} , run time is : {execution_time} s")
|
||||
if execution_time > 0.05:
|
||||
logging.info(f"class name: {class_name} , run time is : {execution_time} s")
|
||||
return result
|
||||
|
||||
return wrapper
|
||||
|
||||
@@ -82,7 +82,7 @@ if __name__ == '__main__':
|
||||
# url = "aida-users/89/sketchboard/female/Dress/e6724ab7-8d3f-4677-abe0-c3e42ab7af85.jpeg"
|
||||
# url = "aida-users/87/print/956614a2-7e75-4fbe-9ed0-c1831e37a2c9-4-87.png"
|
||||
# url = "aida-users/89/single_logo/123-89.png"
|
||||
url = "aida-users/89/test/123-89.png"
|
||||
url = "aida-users/89/123-89.png"
|
||||
|
||||
# url = "aida-collection-element/12148/Sketchboard/95ea577b-305b-4a62-b30a-39c0dd3ddb3f.png"
|
||||
read_type = "2"
|
||||
|
||||
Reference in New Issue
Block a user