Merge branch 'refs/heads/develop'
This commit is contained in:
7
.gitignore
vendored
7
.gitignore
vendored
@@ -136,4 +136,9 @@ app/logs/*
|
|||||||
/qodana.yaml
|
/qodana.yaml
|
||||||
.pth
|
.pth
|
||||||
.pytorch
|
.pytorch
|
||||||
*.png
|
*.png
|
||||||
|
*.pth
|
||||||
|
*.db
|
||||||
|
*.npy
|
||||||
|
*.pytorch
|
||||||
|
*.jpg
|
||||||
@@ -35,13 +35,13 @@ def attribute_recognition(request_item: list[AttributeRecognitionModel]):
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
for item in request_item:
|
for item in request_item:
|
||||||
logger.info(f"attribute_recognition request item is : @@@@@@:{json.dumps(item.dict())}")
|
logger.debug(f"attribute_recognition request item is : @@@@@@:{json.dumps(item.dict())}")
|
||||||
if DEBUG:
|
if DEBUG:
|
||||||
service = AttributeRecognition(const=local_debug_const, request_data=request_item)
|
service = AttributeRecognition(const=local_debug_const, request_data=request_item)
|
||||||
else:
|
else:
|
||||||
service = AttributeRecognition(const=const, request_data=request_item)
|
service = AttributeRecognition(const=const, request_data=request_item)
|
||||||
data = service.get_result()
|
data = service.get_result()
|
||||||
logger.info(f"attribute_recognition response @@@@@@:{json.dumps(data)}")
|
logger.debug(f"attribute_recognition response @@@@@@:{json.dumps(data)}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"attribute_recognition Run Exception @@@@@@:{e}")
|
logger.warning(f"attribute_recognition Run Exception @@@@@@:{e}")
|
||||||
raise HTTPException(status_code=404, detail=str(e))
|
raise HTTPException(status_code=404, detail=str(e))
|
||||||
|
|||||||
34
app/api/api_brand_dna.py
Normal file
34
app/api/api_brand_dna.py
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
import json
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from fastapi import APIRouter, HTTPException
|
||||||
|
|
||||||
|
from app.schemas.brand_dna import BrandDnaModel
|
||||||
|
from app.schemas.response_template import ResponseModel
|
||||||
|
from app.service.brand_dna.service import BrandDna
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
logger = logging.getLogger()
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/seg_product")
|
||||||
|
def image2sketch(request_item: BrandDnaModel):
|
||||||
|
"""
|
||||||
|
创建一个具有以下参数的请求体:
|
||||||
|
- **image_url**: 提取图片url
|
||||||
|
- **is_brand_dna**: 是否提取属性
|
||||||
|
|
||||||
|
示例参数:
|
||||||
|
{
|
||||||
|
"image_url": "test/image2sketch/real_Dress_3200fecdc83d0c556c2bd96aedbd7fbf.jpg_Img.jpg",
|
||||||
|
"is_brand_dna": False
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
logger.info(f"brand dna request item is : @@@@@@:{json.dumps(request_item.dict())}")
|
||||||
|
service = BrandDna(request_item)
|
||||||
|
result_url = service.get_result()
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"brand dna Run Exception @@@@@@:{e}")
|
||||||
|
raise HTTPException(status_code=404, detail=str(e))
|
||||||
|
return ResponseModel(data=result_url)
|
||||||
@@ -4,7 +4,7 @@ import os
|
|||||||
|
|
||||||
from fastapi import APIRouter, HTTPException, UploadFile, File, Form, BackgroundTasks
|
from fastapi import APIRouter, HTTPException, UploadFile, File, Form, BackgroundTasks
|
||||||
|
|
||||||
from app.schemas.design import DesignModel, DesignProgressModel, ModelProgressModel, DBGConfigModel
|
from app.schemas.design import DesignModel, DesignProgressModel, ModelProgressModel, DBGConfigModel, DesignStreamModel
|
||||||
from app.schemas.response_template import ResponseModel
|
from app.schemas.response_template import ResponseModel
|
||||||
from app.service.design.model_process_service import model_transpose
|
from app.service.design.model_process_service import model_transpose
|
||||||
from app.service.design_batch.service import start_design_batch_generate
|
from app.service.design_batch.service import start_design_batch_generate
|
||||||
@@ -18,6 +18,14 @@ logger = logging.getLogger()
|
|||||||
@router.post("/design")
|
@router.post("/design")
|
||||||
def design(request_data: DesignModel, background_tasks: BackgroundTasks):
|
def design(request_data: DesignModel, background_tasks: BackgroundTasks):
|
||||||
"""
|
"""
|
||||||
|
objects.items.transparent:
|
||||||
|
"transparent":{
|
||||||
|
"mask_url":"test/transparent_test/transparent_mask.png",
|
||||||
|
"scale":0.1
|
||||||
|
},
|
||||||
|
mask_url 为空"" -> 单件衣服透明
|
||||||
|
mask_url 非空"mask_url" -> 区域透明
|
||||||
|
|
||||||
创建一个具有以下参数的请求体:
|
创建一个具有以下参数的请求体:
|
||||||
示例参数:
|
示例参数:
|
||||||
{
|
{
|
||||||
@@ -197,7 +205,7 @@ def design(request_data: DesignModel, background_tasks: BackgroundTasks):
|
|||||||
|
|
||||||
|
|
||||||
@router.post("/design_v2")
|
@router.post("/design_v2")
|
||||||
async def design_v2(request_data: DesignModel, background_tasks: BackgroundTasks):
|
async def design_v2(request_data: DesignStreamModel, background_tasks: BackgroundTasks):
|
||||||
"""
|
"""
|
||||||
创建一个具有以下参数的请求体:
|
创建一个具有以下参数的请求体:
|
||||||
示例参数:
|
示例参数:
|
||||||
@@ -445,7 +453,7 @@ async def design(file: UploadFile = File(...),
|
|||||||
|
|
||||||
async def save_request_file(contents, file_name):
|
async def save_request_file(contents, file_name):
|
||||||
# 创建保存文件的目录(如果不存在)
|
# 创建保存文件的目录(如果不存在)
|
||||||
save_dir = os.path.join(os.getcwd(), "design_batch", "request_data")
|
save_dir = os.path.join(os.getcwd(), "service/design_batch", "request_data")
|
||||||
if not os.path.exists(save_dir):
|
if not os.path.exists(save_dir):
|
||||||
os.makedirs(save_dir)
|
os.makedirs(save_dir)
|
||||||
# 处理文件
|
# 处理文件
|
||||||
|
|||||||
@@ -3,9 +3,10 @@ import logging
|
|||||||
|
|
||||||
from fastapi import APIRouter, BackgroundTasks, HTTPException
|
from fastapi import APIRouter, BackgroundTasks, HTTPException
|
||||||
|
|
||||||
from app.schemas.generate_image import GenerateImageModel, GenerateProductImageModel, GenerateSingleLogoImageModel, GenerateRelightImageModel
|
from app.schemas.generate_image import GenerateImageModel, GenerateProductImageModel, GenerateSingleLogoImageModel, GenerateRelightImageModel, GenerateMultiViewModel
|
||||||
from app.schemas.response_template import ResponseModel
|
from app.schemas.response_template import ResponseModel
|
||||||
from app.service.generate_image.service_generate_image import GenerateImage, infer_cancel as generate_image_infer_cancel
|
from app.service.generate_image.service_generate_image import GenerateImage, infer_cancel as generate_image_infer_cancel
|
||||||
|
from app.service.generate_image.service_generate_multi_view import GenerateMultiView, infer_cancel as generate_multi_view_cancel
|
||||||
from app.service.generate_image.service_generate_product_image import GenerateProductImage, infer_cancel as generate_product_image_cancel
|
from app.service.generate_image.service_generate_product_image import GenerateProductImage, infer_cancel as generate_product_image_cancel
|
||||||
from app.service.generate_image.service_generate_relight_image import GenerateRelightImage, infer_cancel as generate_relight_image_cancel
|
from app.service.generate_image.service_generate_relight_image import GenerateRelightImage, infer_cancel as generate_relight_image_cancel
|
||||||
from app.service.generate_image.service_generate_single_logo import GenerateSingleLogoImage, infer_cancel as generate_single_logo_cancel
|
from app.service.generate_image.service_generate_single_logo import GenerateSingleLogoImage, infer_cancel as generate_single_logo_cancel
|
||||||
@@ -61,6 +62,44 @@ def generate_image(tasks_id: str):
|
|||||||
return ResponseModel(data=data['data'])
|
return ResponseModel(data=data['data'])
|
||||||
|
|
||||||
|
|
||||||
|
'''multi view'''
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/generate_multi_view")
|
||||||
|
def generate_multi_view(request_item: GenerateMultiViewModel, background_tasks: BackgroundTasks):
|
||||||
|
"""
|
||||||
|
创建一个具有以下参数的请求体:
|
||||||
|
- **tasks_id**: 任务id 用于取消生成任务和获取生成结果
|
||||||
|
- **image_url**: 前视角图的输入,minio或S3 url 地址
|
||||||
|
|
||||||
|
示例参数:
|
||||||
|
{
|
||||||
|
"tasks_id": "123-89",
|
||||||
|
"image_url": "aida-collection-element/87/Printboard/842c09cf-7297-42d9-9e6e-9c17d4a13cb5.jpg"
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
logger.info(f"generate_multi_view request item is : @@@@@@:{json.dumps(request_item.dict())}")
|
||||||
|
service = GenerateMultiView(request_item)
|
||||||
|
background_tasks.add_task(service.get_result)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"generate_multi_view Run Exception @@@@@@:{e}")
|
||||||
|
raise HTTPException(status_code=404, detail=str(e))
|
||||||
|
return ResponseModel()
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/generate_multi_view_cancel/{tasks_id}")
|
||||||
|
def generate_multi_view(tasks_id: str):
|
||||||
|
try:
|
||||||
|
logger.info(f"generate_cancel request item is : @@@@@@:{tasks_id}")
|
||||||
|
data = generate_multi_view_cancel(tasks_id)
|
||||||
|
logger.info(f"generate_cancel response @@@@@@:{data}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"generate_cancel Run Exception @@@@@@:{e}")
|
||||||
|
raise HTTPException(status_code=404, detail=str(e))
|
||||||
|
return ResponseModel(data=data['data'])
|
||||||
|
|
||||||
|
|
||||||
'''single logo'''
|
'''single logo'''
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ def prompt_generation(request_data: PromptGenerationImageModel):
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
logger.info(f"prompt_generation request item is : @@@@@@:{request_data}")
|
logger.info(f"prompt_generation request item is : @@@@@@:{request_data}")
|
||||||
data = get_translation_from_llama3("[" + request_data.text + "]")
|
data = get_translation_from_llama3(request_data.text)
|
||||||
logger.info(f"prompt_generation response @@@@@@:{data}")
|
logger.info(f"prompt_generation response @@@@@@:{data}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"prompt_generation Run Exception @@@@@@:{e}")
|
logger.warning(f"prompt_generation Run Exception @@@@@@:{e}")
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
from fastapi import APIRouter
|
from fastapi import APIRouter
|
||||||
|
|
||||||
from app.api import api_attribute_retrieve, api_query_image
|
from app.api import api_attribute_retrieve, api_query_image
|
||||||
|
from app.api import api_brand_dna
|
||||||
from app.api import api_brighten
|
from app.api import api_brighten
|
||||||
from app.api import api_chat_robot
|
from app.api import api_chat_robot
|
||||||
from app.api import api_design
|
from app.api import api_design
|
||||||
@@ -23,4 +24,5 @@ router.include_router(api_prompt_generation.router, tags=['prompt_generation'],
|
|||||||
router.include_router(api_design_pre_processing.router, tags=['design_pre_processing'], prefix="/api")
|
router.include_router(api_design_pre_processing.router, tags=['design_pre_processing'], prefix="/api")
|
||||||
router.include_router(api_image2sketch.router, tags=['api_image2sketch'], prefix="/api")
|
router.include_router(api_image2sketch.router, tags=['api_image2sketch'], prefix="/api")
|
||||||
router.include_router(api_brighten.router, tags=['api_brighten'], prefix="/api")
|
router.include_router(api_brighten.router, tags=['api_brighten'], prefix="/api")
|
||||||
router.include_router(api_query_image.router, tags=['api_query_image'], prefix="/api")
|
router.include_router(api_query_image.router, tags=['api_query_image'], prefix="/api")
|
||||||
|
router.include_router(api_brand_dna.router, tags=['api_brand_dna'], prefix="/api")
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import logging
|
|||||||
from fastapi import APIRouter
|
from fastapi import APIRouter
|
||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
|
|
||||||
from app.core.config import SR_RABBITMQ_QUEUES, GI_RABBITMQ_QUEUES, GPI_RABBITMQ_QUEUES, GRI_RABBITMQ_QUEUES, OSS
|
from app.core.config import SR_RABBITMQ_QUEUES, GI_RABBITMQ_QUEUES, GPI_RABBITMQ_QUEUES, GRI_RABBITMQ_QUEUES, OSS, JAVA_STREAM_API_URL
|
||||||
from app.schemas.response_template import ResponseModel
|
from app.schemas.response_template import ResponseModel
|
||||||
|
|
||||||
logger = logging.getLogger()
|
logger = logging.getLogger()
|
||||||
@@ -18,6 +18,7 @@ def test(id: int):
|
|||||||
"GI_RABBITMQ_QUEUES": GI_RABBITMQ_QUEUES,
|
"GI_RABBITMQ_QUEUES": GI_RABBITMQ_QUEUES,
|
||||||
"GPI_RABBITMQ_QUEUES": GPI_RABBITMQ_QUEUES,
|
"GPI_RABBITMQ_QUEUES": GPI_RABBITMQ_QUEUES,
|
||||||
"GRI_RABBITMQ_QUEUES": GRI_RABBITMQ_QUEUES,
|
"GRI_RABBITMQ_QUEUES": GRI_RABBITMQ_QUEUES,
|
||||||
|
"JAVA_STREAM_API_URL": JAVA_STREAM_API_URL,
|
||||||
"local_oss_server": OSS
|
"local_oss_server": OSS
|
||||||
}
|
}
|
||||||
logger.info(json.dumps(data))
|
logger.info(json.dumps(data))
|
||||||
|
|||||||
@@ -34,6 +34,9 @@ else:
|
|||||||
RABBITMQ_ENV = "-dev" # 开发环境
|
RABBITMQ_ENV = "-dev" # 开发环境
|
||||||
# RABBITMQ_ENV = "-local" # 本地测试环境
|
# RABBITMQ_ENV = "-local" # 本地测试环境
|
||||||
|
|
||||||
|
JAVA_STREAM_API_URL = os.getenv("JAVA_STREAM_API_URL", "https://api.aida.com.hk/api/third/party/receiveDesignResults")
|
||||||
|
|
||||||
|
|
||||||
settings = Settings()
|
settings = Settings()
|
||||||
|
|
||||||
# minio 配置
|
# minio 配置
|
||||||
@@ -106,6 +109,12 @@ FAST_GI_MODEL_NAME = 'stable_diffusion_xl'
|
|||||||
GI_MODEL_URL = '10.1.1.240:10061'
|
GI_MODEL_URL = '10.1.1.240:10061'
|
||||||
GI_MODEL_NAME = 'flux'
|
GI_MODEL_NAME = 'flux'
|
||||||
|
|
||||||
|
GMV_MODEL_URL = '10.1.1.243:10081'
|
||||||
|
GMV_MODEL_NAME = 'multi_view'
|
||||||
|
|
||||||
|
GMV_RABBITMQ_QUEUES = os.getenv("GMV_RABBITMQ_QUEUES", f"GenerateMultiView{RABBITMQ_ENV}")
|
||||||
|
|
||||||
|
|
||||||
GI_MINIO_BUCKET = "aida-users"
|
GI_MINIO_BUCKET = "aida-users"
|
||||||
GI_RABBITMQ_QUEUES = os.getenv("GI_RABBITMQ_QUEUES", f"GenerateImage{RABBITMQ_ENV}")
|
GI_RABBITMQ_QUEUES = os.getenv("GI_RABBITMQ_QUEUES", f"GenerateImage{RABBITMQ_ENV}")
|
||||||
GI_SYS_IMAGE_URL = "aida-sys-image/generate_image/white_image.jpg"
|
GI_SYS_IMAGE_URL = "aida-sys-image/generate_image/white_image.jpg"
|
||||||
|
|||||||
6
app/schemas/brand_dna.py
Normal file
6
app/schemas/brand_dna.py
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class BrandDnaModel(BaseModel):
|
||||||
|
image_url: str
|
||||||
|
is_brand_dna: bool
|
||||||
@@ -6,6 +6,12 @@ class DesignModel(BaseModel):
|
|||||||
process_id: str
|
process_id: str
|
||||||
|
|
||||||
|
|
||||||
|
class DesignStreamModel(BaseModel):
|
||||||
|
objects: list[dict]
|
||||||
|
process_id: str
|
||||||
|
requestId: str
|
||||||
|
|
||||||
|
|
||||||
class DesignProgressModel(BaseModel):
|
class DesignProgressModel(BaseModel):
|
||||||
process_id: str
|
process_id: str
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,11 @@
|
|||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class GenerateMultiViewModel(BaseModel):
|
||||||
|
tasks_id: str
|
||||||
|
image_url: str
|
||||||
|
|
||||||
|
|
||||||
class GenerateImageModel(BaseModel):
|
class GenerateImageModel(BaseModel):
|
||||||
tasks_id: str
|
tasks_id: str
|
||||||
prompt: str
|
prompt: str
|
||||||
|
|||||||
335
app/service/brand_dna/service.py
Normal file
335
app/service/brand_dna/service.py
Normal file
@@ -0,0 +1,335 @@
|
|||||||
|
import logging
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import mmcv
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import tritonclient.http as httpclient
|
||||||
|
from minio import Minio
|
||||||
|
|
||||||
|
from app.core.config import MINIO_URL, MINIO_ACCESS, MINIO_SECRET, MINIO_SECURE, DESIGN_MODEL_URL
|
||||||
|
from app.schemas.brand_dna import BrandDnaModel
|
||||||
|
from app.service.attribute.config import local_debug_const
|
||||||
|
from app.service.utils.generate_uuid import generate_uuid
|
||||||
|
from app.service.utils.new_oss_client import oss_upload_image, oss_get_image
|
||||||
|
|
||||||
|
logger = logging.getLogger()
|
||||||
|
|
||||||
|
minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
||||||
|
|
||||||
|
|
||||||
|
class BrandDna:
|
||||||
|
def __init__(self, request_item):
|
||||||
|
self.sketch_bucket = "test"
|
||||||
|
self.image_url = request_item.image_url
|
||||||
|
self.is_brand_dna = request_item.is_brand_dna
|
||||||
|
# self.attr_type = pd.read_csv(CATEGORY_PATH)
|
||||||
|
self.attr_type = pd.read_csv(r"E:\workspace\trinity_client_aida\app\service\attribute\config\descriptor\category\category_dis.csv")
|
||||||
|
self.att_client = httpclient.InferenceServerClient(url=DESIGN_MODEL_URL)
|
||||||
|
self.seg_client = httpclient.InferenceServerClient(url='10.1.1.243:30000')
|
||||||
|
# self.const = const
|
||||||
|
self.const = local_debug_const
|
||||||
|
|
||||||
|
# 获取结果
|
||||||
|
def get_result(self):
|
||||||
|
mask, image = self.get_seg_mask()
|
||||||
|
cv2.imshow("", image)
|
||||||
|
cv2.waitKey(0)
|
||||||
|
|
||||||
|
height, width, channels = image.shape
|
||||||
|
result_dict = []
|
||||||
|
white_img = np.ones((height, width, channels), dtype=image.dtype) * 255
|
||||||
|
mask_image = np.zeros((height, width, 3))
|
||||||
|
|
||||||
|
for value in np.unique(mask):
|
||||||
|
if value == 1:
|
||||||
|
outwear_img = white_img.copy()
|
||||||
|
outwear_mask_img = mask_image.copy()
|
||||||
|
outwear_img[mask == value] = image[mask == value]
|
||||||
|
outwear_mask_img[mask == value] = [0, 0, 255]
|
||||||
|
|
||||||
|
cv2.imshow("", outwear_img)
|
||||||
|
cv2.waitKey(0)
|
||||||
|
|
||||||
|
# 预处理之后的input img
|
||||||
|
preprocess_img = self.category_preprocess(outwear_img)
|
||||||
|
# 类别检测
|
||||||
|
category = self.recognition_category(preprocess_img)
|
||||||
|
if category == 'Trousers' or category == 'Skirt':
|
||||||
|
male_category = 'Bottoms'
|
||||||
|
elif category == 'Blouse' or category == 'Dress':
|
||||||
|
male_category = 'Tops'
|
||||||
|
else:
|
||||||
|
male_category = 'Outwear'
|
||||||
|
|
||||||
|
attribute = {}
|
||||||
|
mask_url = ""
|
||||||
|
img_url = ""
|
||||||
|
# 属性检测
|
||||||
|
if self.is_brand_dna:
|
||||||
|
attribute = self.get_recognition_attribute_result(category, preprocess_img)
|
||||||
|
else:
|
||||||
|
img_url = self.put_image(outwear_img, f"img/{generate_uuid()}")
|
||||||
|
mask_url = self.put_image(outwear_mask_img, f"mask/{generate_uuid()}")
|
||||||
|
|
||||||
|
result_dict.append(
|
||||||
|
{
|
||||||
|
'category_female': category,
|
||||||
|
'category_male': male_category,
|
||||||
|
'mask_url': mask_url,
|
||||||
|
'img_url': img_url,
|
||||||
|
'attribute': attribute
|
||||||
|
}
|
||||||
|
)
|
||||||
|
if value == 2:
|
||||||
|
tops_img = white_img.copy()
|
||||||
|
tops_mask_img = mask_image.copy()
|
||||||
|
tops_img[mask == value] = image[mask == value]
|
||||||
|
tops_mask_img[mask == value] = [0, 0, 255]
|
||||||
|
|
||||||
|
cv2.imshow("", tops_img)
|
||||||
|
cv2.waitKey(0)
|
||||||
|
|
||||||
|
# 预处理之后的input img
|
||||||
|
preprocess_img = self.category_preprocess(tops_img)
|
||||||
|
# 类别检测
|
||||||
|
category = self.recognition_category(preprocess_img)
|
||||||
|
if category == 'Trousers' or category == 'Skirt':
|
||||||
|
male_category = 'Bottoms'
|
||||||
|
elif category == 'Blouse' or category == 'Dress':
|
||||||
|
male_category = 'Tops'
|
||||||
|
else:
|
||||||
|
male_category = 'Outwear'
|
||||||
|
|
||||||
|
# 属性检测
|
||||||
|
attribute = {}
|
||||||
|
img_url = ""
|
||||||
|
mask_url = ""
|
||||||
|
# 属性检测
|
||||||
|
if self.is_brand_dna:
|
||||||
|
attribute = self.get_recognition_attribute_result(category, preprocess_img)
|
||||||
|
else:
|
||||||
|
mask_url = self.put_image(tops_mask_img, f"mask/{generate_uuid()}")
|
||||||
|
img_url = self.put_image(tops_img, f"img/{generate_uuid()}")
|
||||||
|
|
||||||
|
result_dict.append(
|
||||||
|
{
|
||||||
|
'category_female': category,
|
||||||
|
'category_male': male_category,
|
||||||
|
'mask_url': mask_url,
|
||||||
|
'img_url': img_url,
|
||||||
|
'attribute': attribute
|
||||||
|
}
|
||||||
|
)
|
||||||
|
if value == 3:
|
||||||
|
bottoms_img = white_img.copy()
|
||||||
|
bottoms_mask_img = mask_image.copy()
|
||||||
|
bottoms_img[mask == value] = image[mask == value]
|
||||||
|
bottoms_mask_img[mask == value] = [0, 0, 255]
|
||||||
|
|
||||||
|
cv2.imshow("", bottoms_img)
|
||||||
|
cv2.waitKey(0)
|
||||||
|
|
||||||
|
# 预处理之后的input img
|
||||||
|
preprocess_img = self.category_preprocess(bottoms_img)
|
||||||
|
# 类别检测
|
||||||
|
category = self.recognition_category(preprocess_img)
|
||||||
|
if category == 'Trousers' or category == 'Skirt':
|
||||||
|
male_category = 'Bottoms'
|
||||||
|
elif category == 'Blouse' or category == 'Dress':
|
||||||
|
male_category = 'Tops'
|
||||||
|
else:
|
||||||
|
male_category = 'Outwear'
|
||||||
|
|
||||||
|
attribute = {}
|
||||||
|
img_url = ""
|
||||||
|
mask_url = ""
|
||||||
|
# 属性检测
|
||||||
|
if self.is_brand_dna:
|
||||||
|
attribute = self.get_recognition_attribute_result(category, preprocess_img)
|
||||||
|
else:
|
||||||
|
img_url = self.put_image(bottoms_img, f"img/{generate_uuid()}")
|
||||||
|
mask_url = self.put_image(bottoms_mask_img, f"mask/{generate_uuid()}")
|
||||||
|
|
||||||
|
result_dict.append(
|
||||||
|
{
|
||||||
|
'category_female': category,
|
||||||
|
'category_male': male_category,
|
||||||
|
'mask_url': mask_url,
|
||||||
|
'img_url': img_url,
|
||||||
|
'attribute': attribute
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return result_dict
|
||||||
|
|
||||||
|
# 获取product mask
|
||||||
|
def get_seg_mask(self):
|
||||||
|
input_image = self.get_image()
|
||||||
|
input_img, ori_shape = self.seg_product_preprocess(input_image)
|
||||||
|
transformed_img = input_img.astype(np.float32)
|
||||||
|
|
||||||
|
inputs = [httpclient.InferInput(f"seg_input__0", transformed_img.shape, datatype="FP32")]
|
||||||
|
inputs[0].set_data_from_numpy(transformed_img, binary_data=True)
|
||||||
|
outputs = [httpclient.InferRequestedOutput(f"seg_output__0", binary_data=True)]
|
||||||
|
results = self.seg_client.infer(model_name=f"seg_product", inputs=inputs, outputs=outputs)
|
||||||
|
inference_output1 = results.as_numpy("seg_output__0")
|
||||||
|
mask = self.product_postprocess(inference_output1, ori_shape)[0]
|
||||||
|
return mask, input_image
|
||||||
|
|
||||||
|
# 获取图片
|
||||||
|
def get_image(self):
|
||||||
|
image = oss_get_image(oss_client=minio_client, bucket=self.image_url.split('/')[0], object_name=self.image_url[self.image_url.find('/') + 1:], data_type="cv2")
|
||||||
|
# 将其转换为彩色图像
|
||||||
|
if len(image.shape) == 3 and image.shape[2] == 4:
|
||||||
|
image = cv2.cvtColor(image, cv2.COLOR_BGRA2BGR)
|
||||||
|
elif len(image.shape) == 2:
|
||||||
|
image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
|
||||||
|
return image
|
||||||
|
# return cv2.imread(self.image_url)
|
||||||
|
|
||||||
|
# 图片上传
|
||||||
|
def put_image(self, image, object_name):
|
||||||
|
try:
|
||||||
|
image_bytes = cv2.imencode('.jpg', image)[1].tobytes()
|
||||||
|
oss_upload_image(oss_client=minio_client, bucket=self.sketch_bucket, object_name=f"{object_name}.jpg", image_bytes=image_bytes)
|
||||||
|
return f"{self.sketch_bucket}/{object_name}.jpg"
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(e)
|
||||||
|
|
||||||
|
# 服装分割预处理
|
||||||
|
@staticmethod
|
||||||
|
def seg_product_preprocess(image):
|
||||||
|
img = mmcv.imread(image)
|
||||||
|
ori_shape = img.shape[:2]
|
||||||
|
img_scale_w, img_scale_h = ori_shape
|
||||||
|
if ori_shape[0] > 1024:
|
||||||
|
img_scale_w = 1024
|
||||||
|
if ori_shape[1] > 1024:
|
||||||
|
img_scale_h = 1024
|
||||||
|
# 如果图片size任意一边 大于 1024, 则会resize 成1024
|
||||||
|
if ori_shape != (img_scale_w, img_scale_h):
|
||||||
|
# mmcv.imresize(img, img_scale_h, img_scale_w) # 老代码 引以为戒!哈哈哈~ h和w写反了
|
||||||
|
img = cv2.resize(img, (img_scale_h, img_scale_w))
|
||||||
|
img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
|
||||||
|
preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0)
|
||||||
|
return preprocessed_img, ori_shape
|
||||||
|
|
||||||
|
# 类别检测后处理
|
||||||
|
@staticmethod
|
||||||
|
def product_postprocess(output, ori_shape):
|
||||||
|
seg_logit = F.interpolate(torch.tensor(output).float(), size=ori_shape, scale_factor=None, mode='bilinear', align_corners=False)
|
||||||
|
seg_pred = seg_logit.cpu().numpy()
|
||||||
|
return seg_pred[0]
|
||||||
|
|
||||||
|
# 类别检测模型预处理
|
||||||
|
@staticmethod
|
||||||
|
def category_preprocess(img):
|
||||||
|
img = mmcv.imread(img)
|
||||||
|
# ori_shape = img.shape[:2]
|
||||||
|
img_scale = (224, 224)
|
||||||
|
img = cv2.resize(img, img_scale)
|
||||||
|
img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
|
||||||
|
preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0)
|
||||||
|
return preprocessed_img
|
||||||
|
|
||||||
|
# 类别检测
|
||||||
|
def recognition_category(self, image):
|
||||||
|
inputs = [
|
||||||
|
httpclient.InferInput("input__0", image.shape, datatype="FP32")
|
||||||
|
]
|
||||||
|
inputs[0].set_data_from_numpy(image, binary_data=True)
|
||||||
|
results = self.att_client.infer(model_name="attr_retrieve_category", inputs=inputs)
|
||||||
|
inference_output = torch.from_numpy(results.as_numpy(f'output__0'))
|
||||||
|
|
||||||
|
scores = inference_output.detach().numpy()
|
||||||
|
|
||||||
|
colattr = list(self.attr_type['labelName'])
|
||||||
|
maxsc = np.max(scores[0][:5])
|
||||||
|
indexs = np.argwhere(scores == maxsc)[:, 1]
|
||||||
|
|
||||||
|
return colattr[indexs[0]]
|
||||||
|
|
||||||
|
# 属性检测
|
||||||
|
def recognition_attribute(self, model_name, description, image):
|
||||||
|
attr_type = pd.read_csv(description)
|
||||||
|
inputs = [
|
||||||
|
httpclient.InferInput("input__0", image.shape, datatype="FP32")
|
||||||
|
]
|
||||||
|
inputs[0].set_data_from_numpy(image, binary_data=True)
|
||||||
|
results = self.att_client.infer(model_name=model_name, inputs=inputs)
|
||||||
|
inference_output = torch.from_numpy(results.as_numpy(f"output__0"))
|
||||||
|
scores = inference_output.detach().numpy()
|
||||||
|
colattr = list(attr_type['labelName'])
|
||||||
|
task = description.split('/')[-1][:-4]
|
||||||
|
maxsc = np.max(scores[0][:5])
|
||||||
|
indexs = np.argwhere(scores == maxsc)[:, 1]
|
||||||
|
attr = {
|
||||||
|
task: []
|
||||||
|
}
|
||||||
|
for i in range(len(indexs)):
|
||||||
|
atr = colattr[indexs[i]]
|
||||||
|
attr[task].append(atr)
|
||||||
|
return attr
|
||||||
|
|
||||||
|
# 获取属性检测结果
|
||||||
|
def get_recognition_attribute_result(self, category, input_img):
|
||||||
|
if category == "Blouse":
|
||||||
|
attr_dict = {}
|
||||||
|
for i in range(len(self.const.top_description_list)):
|
||||||
|
attr_description = self.const.top_description_list[i]
|
||||||
|
attr_model_path = self.const.top_model_list[i]
|
||||||
|
present_dict = self.recognition_attribute(attr_model_path, attr_description, input_img)
|
||||||
|
attr_dict = self.merge(attr_dict, present_dict)
|
||||||
|
|
||||||
|
elif category == 'Trousers' or category == "Skirt":
|
||||||
|
attr_dict = {}
|
||||||
|
for i in range(len(self.const.bottom_description_list)):
|
||||||
|
attr_description = self.const.bottom_description_list[i]
|
||||||
|
attr_model_path = self.const.bottom_model_list[i]
|
||||||
|
present_dict = self.recognition_attribute(attr_model_path, attr_description, input_img)
|
||||||
|
attr_dict = self.merge(attr_dict, present_dict)
|
||||||
|
|
||||||
|
elif category == 'Dress':
|
||||||
|
attr_dict = {}
|
||||||
|
for i in range(len(self.const.dress_description_list)):
|
||||||
|
attr_description = self.const.dress_description_list[i]
|
||||||
|
attr_model_path = self.const.dress_model_list[i]
|
||||||
|
present_dict = self.recognition_attribute(attr_model_path, attr_description, input_img)
|
||||||
|
attr_dict = self.merge(attr_dict, present_dict)
|
||||||
|
|
||||||
|
elif category == 'Outwear':
|
||||||
|
attr_dict = {}
|
||||||
|
for i in range(len(self.const.outwear_description_list)):
|
||||||
|
attr_description = self.const.outwear_description_list[i]
|
||||||
|
attr_model_path = self.const.outwear_model_list[i]
|
||||||
|
present_dict = self.recognition_attribute(attr_model_path, attr_description, input_img)
|
||||||
|
attr_dict = self.merge(attr_dict, present_dict)
|
||||||
|
else:
|
||||||
|
attr_dict = {}
|
||||||
|
return attr_dict
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def merge(dict1, dict2):
|
||||||
|
res = {**dict1, **dict2}
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
# for path in os.listdir('./test_img'):
|
||||||
|
# img_path = os.path.join(r'./test_img', path)
|
||||||
|
# request_item = BrandDnaModel(
|
||||||
|
# image_url=img_path,
|
||||||
|
# is_brand_dna=True
|
||||||
|
# )
|
||||||
|
# service = BrandDna(request_item)
|
||||||
|
# result_url = service.get_result()
|
||||||
|
# print(result_url)
|
||||||
|
request_item = BrandDnaModel(
|
||||||
|
image_url="aida-users/60/product_image/07cb5d5d-5022-44cc-b0d3-cc986cfebad1-2-60.png",
|
||||||
|
is_brand_dna=True
|
||||||
|
)
|
||||||
|
service = BrandDna(request_item)
|
||||||
|
result_url = service.get_result()
|
||||||
|
print(result_url)
|
||||||
@@ -69,3 +69,5 @@ TOOLS_FUNCTIONS_SUFFIX = (
|
|||||||
)
|
)
|
||||||
|
|
||||||
TUTORIAL_TOOL_RETURN = "Commencing the systematic tutorial guide now."
|
TUTORIAL_TOOL_RETURN = "Commencing the systematic tutorial guide now."
|
||||||
|
|
||||||
|
GET_LANGUAGE_PREFIX = "Please identify the language. Only output the language name"
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
|
|
||||||
from dashscope import Generation
|
from dashscope import Generation
|
||||||
from retry import retry
|
from retry import retry
|
||||||
@@ -7,7 +8,8 @@ from urllib3.exceptions import NewConnectionError
|
|||||||
from app.core.config import *
|
from app.core.config import *
|
||||||
from app.service.chat_robot.script.callbacks.qwen_callback_handler import QWenCallbackHandler
|
from app.service.chat_robot.script.callbacks.qwen_callback_handler import QWenCallbackHandler
|
||||||
from app.service.chat_robot.script.database import CustomDatabase
|
from app.service.chat_robot.script.database import CustomDatabase
|
||||||
from app.service.chat_robot.script.prompt import FASHION_CHAT_BOT_PREFIX, TOOLS_FUNCTIONS_SUFFIX, TUTORIAL_TOOL_RETURN
|
from app.service.chat_robot.script.prompt import FASHION_CHAT_BOT_PREFIX, TOOLS_FUNCTIONS_SUFFIX, TUTORIAL_TOOL_RETURN, \
|
||||||
|
GET_LANGUAGE_PREFIX
|
||||||
from app.service.search_image_with_text.service import query
|
from app.service.search_image_with_text.service import query
|
||||||
|
|
||||||
get_database_table_description = "Input is an empty string, output is a comma separated list of tables in the database."
|
get_database_table_description = "Input is an empty string, output is a comma separated list of tables in the database."
|
||||||
@@ -159,10 +161,11 @@ def search_from_internet(message):
|
|||||||
model='qwen-turbo',
|
model='qwen-turbo',
|
||||||
api_key=QWEN_API_KEY,
|
api_key=QWEN_API_KEY,
|
||||||
messages=message,
|
messages=message,
|
||||||
tools=tools,
|
prompt='The output must be in English.Keep the final result under 200 words.'
|
||||||
|
# tools=tools,
|
||||||
# seed=random.randint(1, 10000), # 设置随机数种子seed,如果没有设置,则随机数种子默认为1234
|
# seed=random.randint(1, 10000), # 设置随机数种子seed,如果没有设置,则随机数种子默认为1234
|
||||||
result_format='message', # 将输出设置为message形式
|
# result_format='message', # 将输出设置为message形式
|
||||||
enable_search='True'
|
# enable_search='True'
|
||||||
)
|
)
|
||||||
return response
|
return response
|
||||||
|
|
||||||
@@ -198,14 +201,9 @@ def get_response(messages):
|
|||||||
|
|
||||||
|
|
||||||
def call_with_messages(message, gender):
|
def call_with_messages(message, gender):
|
||||||
|
global tool_info
|
||||||
user_input = message
|
user_input = message
|
||||||
print('\n')
|
print('\n')
|
||||||
# messages = [
|
|
||||||
# {
|
|
||||||
# "content": input('请输入:'), # 提问示例:"现在几点了?" "一个小时后几点" "北京天气如何?"
|
|
||||||
# "role": "user"
|
|
||||||
# }
|
|
||||||
# ]
|
|
||||||
|
|
||||||
messages = [
|
messages = [
|
||||||
{
|
{
|
||||||
@@ -223,14 +221,10 @@ def call_with_messages(message, gender):
|
|||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|
||||||
# 模型的第一轮调用
|
|
||||||
# first_response = get_response(messages)
|
|
||||||
# assistant_output = first_response.output.choices[0].message
|
|
||||||
# print(f"\n大模型第一轮输出信息:{first_response}\n")
|
|
||||||
# messages.append(assistant_output)
|
|
||||||
flag = True
|
flag = True
|
||||||
count = 1
|
count = 1
|
||||||
result_content = "我是一个时尚AI助手,请问有什么可以帮您"
|
# result_content = "我是一个时尚AI助手,请问有什么可以帮您"
|
||||||
|
result_content = "I am a fashion AI assistant, how can I help you?"
|
||||||
response_type = "chat"
|
response_type = "chat"
|
||||||
|
|
||||||
while flag and count <= 3:
|
while flag and count <= 3:
|
||||||
@@ -244,44 +238,53 @@ def call_with_messages(message, gender):
|
|||||||
print(f"最终答案:{assistant_output.content}") # 此处直接返回模型的回复,您可以根据您的业务,选择当无需调用工具时最终回复的内容
|
print(f"最终答案:{assistant_output.content}") # 此处直接返回模型的回复,您可以根据您的业务,选择当无需调用工具时最终回复的内容
|
||||||
result_content = assistant_output.content
|
result_content = assistant_output.content
|
||||||
break
|
break
|
||||||
# 如果模型选择的工具是search_from_internet
|
# 如果模型选择的工具是internet_search
|
||||||
# elif assistant_output.tool_calls[0]['function']['name'] == 'search_from_internet':
|
elif assistant_output.tool_calls[0]['function']['name'] == 'internet_search':
|
||||||
# tool_info = {"name": "search_from_internet", "role": "tool"}
|
tool_info = {"name": "search_from_internet", "role": "tool"}
|
||||||
# user_input = json.loads(assistant_output.tool_calls[0]['function']['arguments'])['user_input']
|
content = json.loads(assistant_output.tool_calls[0]['function']['arguments'])
|
||||||
# tool_info['content'] = search_from_internet(user_input)
|
message = [
|
||||||
# 如果模型选择的工具是get_database_table
|
{'role': 'assistant', 'content': content['query']}
|
||||||
elif assistant_output.tool_calls[0]['function']['name'] == 'get_database_table':
|
]
|
||||||
tool_info = {"name": "get_database_table", "role": "tool", 'content': get_database_table()}
|
tool_info['content'] = search_from_internet(message)
|
||||||
# 如果模型选择的工具是get_table_info
|
|
||||||
elif assistant_output.tool_calls[0]['function']['name'] == 'get_table_info':
|
|
||||||
tool_info = {"name": "get_table_info", "role": "tool"}
|
|
||||||
table_names = json.loads(assistant_output.tool_calls[0]['function']['arguments'])['table_names']
|
|
||||||
tool_info['content'] = get_table_info(table_names)
|
|
||||||
# 如果模型选择的工具是query_database
|
|
||||||
elif assistant_output.tool_calls[0]['function']['name'] == 'query_database':
|
|
||||||
tool_info = {"name": "query_database", "role": "tool"}
|
|
||||||
sql_string = json.loads(assistant_output.tool_calls[0]['function']['arguments'])['sql_string']
|
|
||||||
tool_info['content'] = query_database(sql_string)
|
|
||||||
flag = False
|
flag = False
|
||||||
result_content = tool_info['content']
|
result_content = tool_info['content'].output.text
|
||||||
response_type = "image"
|
# 如果模型选择的工具是get_database_table
|
||||||
|
# elif assistant_output.tool_calls[0]['function']['name'] == 'get_database_table':
|
||||||
|
# tool_info = {"name": "get_database_table", "role": "tool", 'content': get_database_table()}
|
||||||
|
# # 如果模型选择的工具是get_table_info
|
||||||
|
# elif assistant_output.tool_calls[0]['function']['name'] == 'get_table_info':
|
||||||
|
# tool_info = {"name": "get_table_info", "role": "tool"}
|
||||||
|
# table_names = json.loads(assistant_output.tool_calls[0]['function']['arguments'])['table_names']
|
||||||
|
# tool_info['content'] = get_table_info(table_names)
|
||||||
|
# # 如果模型选择的工具是query_database
|
||||||
|
# elif assistant_output.tool_calls[0]['function']['name'] == 'query_database':
|
||||||
|
# tool_info = {"name": "query_database", "role": "tool"}
|
||||||
|
# sql_string = json.loads(assistant_output.tool_calls[0]['function']['arguments'])['sql_string']
|
||||||
|
# tool_info['content'] = query_database(sql_string)
|
||||||
|
# flag = False
|
||||||
|
# result_content = tool_info['content']
|
||||||
|
# response_type = "image"
|
||||||
elif assistant_output.tool_calls[0]['function']['name'] == 'tutorial_tool':
|
elif assistant_output.tool_calls[0]['function']['name'] == 'tutorial_tool':
|
||||||
tool_info = {"name": "tutorial_tool", "role": "tool", 'content': tutorial_tool()}
|
tool_info = {"name": "tutorial_tool", "role": "tool", 'content': tutorial_tool()}
|
||||||
flag = False
|
flag = False
|
||||||
result_content = tool_info['content']
|
result_content = tool_info['content']
|
||||||
elif assistant_output.tool_calls[0]['function']['name'] == 'get_image_from_vector_db':
|
elif assistant_output.tool_calls[0]['function']['name'] == 'get_image_from_vector_db':
|
||||||
|
content = json.loads(assistant_output.tool_calls[0]['function']['arguments'])
|
||||||
tool_info = {"name": "get_image_from_vector_db", "role": "tool",
|
tool_info = {"name": "get_image_from_vector_db", "role": "tool",
|
||||||
'content': get_image_from_vector_db(gender, user_input)}
|
'content': get_image_from_vector_db(gender, content['parameters']['content'])}
|
||||||
flag = False
|
flag = False
|
||||||
result_content = tool_info['content']
|
result_content = tool_info['content']
|
||||||
response_type = "image"
|
response_type = "image"
|
||||||
|
else:
|
||||||
|
tool_info = {"name": assistant_output.tool_calls[0]['function']['name'], 'content': 'null'}
|
||||||
|
logging.info(assistant_output.tool_calls[0]['function']['name'] + "(unknown tools)")
|
||||||
|
flag = False
|
||||||
|
|
||||||
print(f"工具输出信息:{tool_info['content']}\n")
|
print(f"工具输出信息:{tool_info['content']}\n")
|
||||||
messages.append(tool_info)
|
messages.append(tool_info)
|
||||||
count += 1
|
count += 1
|
||||||
|
|
||||||
final_output = {"output": result_content}
|
final_output = {"output": result_content, "response_type": response_type}
|
||||||
final_output["response_type"] = response_type
|
|
||||||
QWenCallbackHandler.on_chain_end(qwen, final_output)
|
QWenCallbackHandler.on_chain_end(qwen, final_output)
|
||||||
|
|
||||||
# 模型的第二轮调用,对工具的输出进行总结
|
# 模型的第二轮调用,对工具的输出进行总结
|
||||||
@@ -298,5 +301,23 @@ def tutorial_tool():
|
|||||||
return TUTORIAL_TOOL_RETURN
|
return TUTORIAL_TOOL_RETURN
|
||||||
|
|
||||||
|
|
||||||
|
def get_language(message: str) -> str:
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"content": message, # 用户message
|
||||||
|
"role": "user"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"content": GET_LANGUAGE_PREFIX, # ai message
|
||||||
|
"role": "assistant"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
first_response = get_response(messages)
|
||||||
|
assistant_output = first_response.output.choices[0].message.content
|
||||||
|
logging.info(f"大模型输出信息:{first_response}\n判断用户输入的语言为:{assistant_output}")
|
||||||
|
return assistant_output
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
call_with_messages()
|
get_language("")
|
||||||
|
|||||||
@@ -53,7 +53,7 @@ class Segmentation(object):
|
|||||||
file_path = f"{SEG_CACHE_PATH}{image_id}.npy"
|
file_path = f"{SEG_CACHE_PATH}{image_id}.npy"
|
||||||
try:
|
try:
|
||||||
np.save(file_path, seg_result)
|
np.save(file_path, seg_result)
|
||||||
logger.info(f"保存成功 :{os.path.abspath(file_path)}")
|
logger.debug(f"保存成功 :{os.path.abspath(file_path)}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"保存失败: {e}")
|
logger.error(f"保存失败: {e}")
|
||||||
|
|
||||||
@@ -64,7 +64,7 @@ class Segmentation(object):
|
|||||||
seg_result = np.load(file_path)
|
seg_result = np.load(file_path)
|
||||||
return True, seg_result
|
return True, seg_result
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
logger.warning("文件不存在")
|
# logger.warning("文件不存在")
|
||||||
return False, None
|
return False, None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"加载失败: {e}")
|
logger.error(f"加载失败: {e}")
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ from app.service.design_batch.utils.save_json import oss_upload_json
|
|||||||
from app.service.design_batch.utils.synthesis_item import update_base_size_priority, synthesis, synthesis_single
|
from app.service.design_batch.utils.synthesis_item import update_base_size_priority, synthesis, synthesis_single
|
||||||
|
|
||||||
id_lock = threading.Lock()
|
id_lock = threading.Lock()
|
||||||
celery_app = Celery('tasks', broker='amqp://guest:guest@10.1.2.213:5672//', backend='rpc://')
|
celery_app = Celery('tasks', broker=f'amqp://rabbit:123456@18.167.251.121:5672//', backend='rpc://', BROKER_CONNECTION_RETRY_ON_STARTUP=True)
|
||||||
celery_app.conf.worker_log_format = '%(asctime)s %(filename)s [line:%(lineno)d] %(levelname)s %(message)s'
|
celery_app.conf.worker_log_format = '%(asctime)s %(filename)s [line:%(lineno)d] %(levelname)s %(message)s'
|
||||||
celery_app.conf.worker_hijack_root_logger = False
|
celery_app.conf.worker_hijack_root_logger = False
|
||||||
logging.getLogger('pika').setLevel(logging.WARNING)
|
logging.getLogger('pika').setLevel(logging.WARNING)
|
||||||
@@ -120,7 +120,7 @@ def batch_design(objects_data, tasks_id, json_name):
|
|||||||
|
|
||||||
for t in threads:
|
for t in threads:
|
||||||
t.join()
|
t.join()
|
||||||
|
logger.debug(object_response)
|
||||||
oss_upload_json(minio_client, object_response, json_name)
|
oss_upload_json(minio_client, object_response, json_name)
|
||||||
publish_status(tasks_id, "ok", json_name)
|
publish_status(tasks_id, "ok", json_name)
|
||||||
return object_response
|
return object_response
|
||||||
|
|||||||
@@ -51,19 +51,19 @@ class Segmentation:
|
|||||||
file_path = f"seg_cache/{image_id}.npy"
|
file_path = f"seg_cache/{image_id}.npy"
|
||||||
try:
|
try:
|
||||||
np.save(file_path, seg_result)
|
np.save(file_path, seg_result)
|
||||||
logger.info(f"保存成功 :{os.path.abspath(file_path)}")
|
logger.debug(f"保存成功 :{os.path.abspath(file_path)}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"保存失败: {e}")
|
logger.error(f"保存失败: {e}")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def load_seg_result(image_id):
|
def load_seg_result(image_id):
|
||||||
file_path = f"seg_cache/{image_id}.npy"
|
file_path = f"seg_cache/{image_id}.npy"
|
||||||
logger.info(f"load seg file name is :{SEG_CACHE_PATH}{image_id}.npy")
|
# logger.info(f"load seg file name is :{SEG_CACHE_PATH}{image_id}.npy")
|
||||||
try:
|
try:
|
||||||
seg_result = np.load(file_path)
|
seg_result = np.load(file_path)
|
||||||
return True, seg_result
|
return True, seg_result
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
logger.warning("文件不存在")
|
# logger.warning("文件不存在")
|
||||||
return False, None
|
return False, None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"加载失败: {e}")
|
logger.error(f"加载失败: {e}")
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ from app.service.design_batch.utils.MQ import publish_status
|
|||||||
|
|
||||||
|
|
||||||
async def start_design_batch_generate(data, file):
|
async def start_design_batch_generate(data, file):
|
||||||
generate_clothes_task = batch_design.delay(json.loads(file.decode())['objects'], data.total, data.tasks_id)
|
generate_clothes_task = batch_design.delay(json.loads(file.decode())['objects'], data.tasks_id, data.file_name)
|
||||||
print(generate_clothes_task)
|
print(generate_clothes_task)
|
||||||
publish_status(data.tasks_id, "0/100", "")
|
publish_status(data.tasks_id, "0/100", "")
|
||||||
return {"task_id": data.tasks_id}
|
return {"task_id": data.tasks_id}
|
||||||
|
|||||||
@@ -157,6 +157,6 @@ if __name__ == '__main__':
|
|||||||
],
|
],
|
||||||
"process_id": "83"
|
"process_id": "83"
|
||||||
}
|
}
|
||||||
task_id = 1
|
task_id = 10086
|
||||||
json_name = "test.json"
|
json_name = "test.json"
|
||||||
batch_design.delay(data['objects'], task_id, json_name)
|
batch_design.delay(data['objects'], task_id, json_name)
|
||||||
|
|||||||
@@ -2,9 +2,11 @@ import json
|
|||||||
|
|
||||||
import pika
|
import pika
|
||||||
|
|
||||||
|
from app.core.config import RABBITMQ_PARAMS
|
||||||
|
|
||||||
|
|
||||||
def publish_status(task_id, progress, result):
|
def publish_status(task_id, progress, result):
|
||||||
connection = pika.BlockingConnection(pika.ConnectionParameters('10.1.2.213'))
|
connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
|
||||||
channel = connection.channel()
|
channel = connection.channel()
|
||||||
channel.queue_declare(queue='DesignBatch', durable=True)
|
channel.queue_declare(queue='DesignBatch', durable=True)
|
||||||
message = {'task_id': task_id, 'progress': progress, "result": result}
|
message = {'task_id': task_id, 'progress': progress, "result": result}
|
||||||
|
|||||||
@@ -1,13 +1,21 @@
|
|||||||
|
import io
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
|
|
||||||
logger = logging.getLogger()
|
logger = logging.getLogger()
|
||||||
|
|
||||||
|
|
||||||
def oss_upload_json(oss_client, json_data, object_name):
|
def oss_upload_json(oss_client, json_data, object_name):
|
||||||
try:
|
try:
|
||||||
with open(f"app/service/design_batch/response_json/{object_name}", 'w') as file:
|
save_dir = os.path.join(os.getcwd(), "service/design_batch", "response_data")
|
||||||
|
if not os.path.exists(save_dir):
|
||||||
|
os.makedirs(save_dir)
|
||||||
|
# 处理文件
|
||||||
|
file_path = os.path.join(save_dir, object_name)
|
||||||
|
with open(file_path, 'w') as file:
|
||||||
json.dump(json_data, file, indent=4)
|
json.dump(json_data, file, indent=4)
|
||||||
oss_client.fput_object("test", object_name, f"app/service/design_batch/response_json/{object_name}")
|
json_bytes = json.dumps(json_data).encode('utf-8')
|
||||||
|
oss_client.put_object("test", object_name, io.BytesIO(json_bytes), length=len(json_bytes), content_type="application/json")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(str(e))
|
logger.warning(str(e))
|
||||||
|
|||||||
@@ -139,6 +139,7 @@ def design_generate(request_data):
|
|||||||
@RunTime
|
@RunTime
|
||||||
def design_generate_v2(request_data):
|
def design_generate_v2(request_data):
|
||||||
objects_data = request_data.dict()['objects']
|
objects_data = request_data.dict()['objects']
|
||||||
|
request_id = request_data.requestId
|
||||||
threads = []
|
threads = []
|
||||||
|
|
||||||
def process_object(step, object):
|
def process_object(step, object):
|
||||||
@@ -146,7 +147,7 @@ def design_generate_v2(request_data):
|
|||||||
items_response = {
|
items_response = {
|
||||||
'layers': [],
|
'layers': [],
|
||||||
'objectSign': object['objectSign'] if 'objectSign' in object.keys() else "",
|
'objectSign': object['objectSign'] if 'objectSign' in object.keys() else "",
|
||||||
'requestId': object['requestId'] if 'requestId' in object.keys() else ""
|
'requestId': request_id
|
||||||
}
|
}
|
||||||
if basic['single_overall'] == "overall":
|
if basic['single_overall'] == "overall":
|
||||||
item_results = []
|
item_results = []
|
||||||
@@ -197,9 +198,8 @@ def design_generate_v2(request_data):
|
|||||||
'pattern_image_url': item_result['pattern_image_url'] if 'pattern_image_url' in item_result.keys() else None,
|
'pattern_image_url': item_result['pattern_image_url'] if 'pattern_image_url' in item_result.keys() else None,
|
||||||
})
|
})
|
||||||
items_response['synthesis_url'] = synthesis_single(item_result['front_image'], item_result['back_image'])
|
items_response['synthesis_url'] = synthesis_single(item_result['front_image'], item_result['back_image'])
|
||||||
|
|
||||||
# 发送结果给java端
|
# 发送结果给java端
|
||||||
url = "https://3998-117-143-125-51.ngrok-free.app/api/third/party/receiveDesignResults"
|
url = JAVA_STREAM_API_URL
|
||||||
headers = {
|
headers = {
|
||||||
'Accept': "*/*",
|
'Accept': "*/*",
|
||||||
'Accept-Encoding': "gzip, deflate, br",
|
'Accept-Encoding': "gzip, deflate, br",
|
||||||
@@ -207,11 +207,11 @@ def design_generate_v2(request_data):
|
|||||||
'Connection': "keep-alive",
|
'Connection': "keep-alive",
|
||||||
'Content-Type': "application/json"
|
'Content-Type': "application/json"
|
||||||
}
|
}
|
||||||
|
# logger.info(items_response)
|
||||||
response = post_request(url, json_data=items_response, headers=headers)
|
response = post_request(url, json_data=items_response, headers=headers)
|
||||||
if response:
|
if response:
|
||||||
# 打印结果
|
# 打印结果
|
||||||
logger.info(response.text)
|
logger.info(response.text)
|
||||||
logger.info(items_response)
|
|
||||||
|
|
||||||
for step, object in enumerate(objects_data):
|
for step, object in enumerate(objects_data):
|
||||||
t = threading.Thread(target=process_object, args=(step, object))
|
t = threading.Thread(target=process_object, args=(step, object))
|
||||||
|
|||||||
@@ -66,19 +66,19 @@ class Segmentation:
|
|||||||
file_path = f"{SEG_CACHE_PATH}{image_id}.npy"
|
file_path = f"{SEG_CACHE_PATH}{image_id}.npy"
|
||||||
try:
|
try:
|
||||||
np.save(file_path, seg_result)
|
np.save(file_path, seg_result)
|
||||||
logger.info(f"保存成功 :{os.path.abspath(file_path)}")
|
logger.debug(f"保存成功 :{os.path.abspath(file_path)}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"保存失败: {e}")
|
logger.error(f"保存失败: {e}")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def load_seg_result(image_id):
|
def load_seg_result(image_id):
|
||||||
file_path = f"{SEG_CACHE_PATH}{image_id}.npy"
|
file_path = f"{SEG_CACHE_PATH}{image_id}.npy"
|
||||||
logger.info(f"load seg file name is :{SEG_CACHE_PATH}{image_id}.npy")
|
# logger.info(f"load seg file name is :{SEG_CACHE_PATH}{image_id}.npy")
|
||||||
try:
|
try:
|
||||||
seg_result = np.load(file_path)
|
seg_result = np.load(file_path)
|
||||||
return True, seg_result
|
return True, seg_result
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
logger.warning("文件不存在")
|
# logger.warning("文件不存在")
|
||||||
return False, None
|
return False, None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"加载失败: {e}")
|
logger.error(f"加载失败: {e}")
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ from app.core.config import *
|
|||||||
|
|
||||||
def keypoint_preprocess(img_path):
|
def keypoint_preprocess(img_path):
|
||||||
img = mmcv.imread(img_path)
|
img = mmcv.imread(img_path)
|
||||||
|
img = cv2.copyMakeBorder(img, 25, 25, 25, 25, cv2.BORDER_CONSTANT, value=[255, 255, 255])
|
||||||
img_scale = (256, 256)
|
img_scale = (256, 256)
|
||||||
h, w = img.shape[:2]
|
h, w = img.shape[:2]
|
||||||
img = cv2.resize(img, img_scale)
|
img = cv2.resize(img, img_scale)
|
||||||
@@ -62,7 +63,11 @@ def keypoint_postprocess(output, scale_factor):
|
|||||||
scale_matrix = np.diag(scale_factor)
|
scale_matrix = np.diag(scale_factor)
|
||||||
nan = np.isinf(scale_matrix)
|
nan = np.isinf(scale_matrix)
|
||||||
scale_matrix[nan] = 0
|
scale_matrix[nan] = 0
|
||||||
return np.ceil(np.dot(segment_result, scale_matrix) * 4)
|
# 应用缩放因子
|
||||||
|
scaled_result = np.ceil(np.dot(segment_result, scale_matrix) * 4)
|
||||||
|
# 补偿边框偏移
|
||||||
|
compensated_result = scaled_result - 25
|
||||||
|
return compensated_result
|
||||||
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -266,7 +266,7 @@ class DesignPreprocessing:
|
|||||||
seg_result = np.load(file_path)
|
seg_result = np.load(file_path)
|
||||||
return True, seg_result
|
return True, seg_result
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
logging.info("文件不存在")
|
# logging.info("文件不存在")
|
||||||
return False, None
|
return False, None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.warning(f"加载失败: {e}")
|
logging.warning(f"加载失败: {e}")
|
||||||
@@ -277,7 +277,7 @@ class DesignPreprocessing:
|
|||||||
file_path = f"{SEG_CACHE_PATH}{image_id}.npy"
|
file_path = f"{SEG_CACHE_PATH}{image_id}.npy"
|
||||||
try:
|
try:
|
||||||
np.save(file_path, seg_result)
|
np.save(file_path, seg_result)
|
||||||
logging.info(f"保存成功,{os.path.abspath(file_path)}")
|
logging.debug(f"保存成功,{os.path.abspath(file_path)}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.warning(f"保存失败: {e}")
|
logging.warning(f"保存失败: {e}")
|
||||||
|
|
||||||
|
|||||||
126
app/service/generate_image/service_generate_multi_view.py
Normal file
126
app/service/generate_image/service_generate_multi_view.py
Normal file
@@ -0,0 +1,126 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
# -*- coding: UTF-8 -*-
|
||||||
|
"""
|
||||||
|
@Project :trinity_client
|
||||||
|
@File :service_att_recognition.py
|
||||||
|
@Author :周成融
|
||||||
|
@Date :2023/7/26 12:01:05
|
||||||
|
@detail :
|
||||||
|
"""
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import redis
|
||||||
|
import tritonclient.grpc as grpcclient
|
||||||
|
|
||||||
|
from app.core.config import *
|
||||||
|
from app.schemas.generate_image import GenerateMultiViewModel
|
||||||
|
from app.service.generate_image.utils.upload_sd_image import upload_png_sd
|
||||||
|
from app.service.utils.oss_client import oss_get_image
|
||||||
|
|
||||||
|
logger = logging.getLogger()
|
||||||
|
|
||||||
|
|
||||||
|
class GenerateMultiView:
|
||||||
|
def __init__(self, request_data):
|
||||||
|
if DEBUG is False:
|
||||||
|
self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
|
||||||
|
self.channel = self.connection.channel()
|
||||||
|
# self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
|
||||||
|
# self.channel = self.connection.channel()
|
||||||
|
# self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
||||||
|
self.grpc_client = grpcclient.InferenceServerClient(url=GMV_MODEL_URL)
|
||||||
|
|
||||||
|
self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
|
||||||
|
self.image = self.get_image(request_data.image_url)
|
||||||
|
self.tasks_id = request_data.tasks_id
|
||||||
|
self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:]
|
||||||
|
self.generate_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'image_url': ''}
|
||||||
|
self.redis_client.set(self.tasks_id, json.dumps(self.generate_data))
|
||||||
|
self.redis_client.expire(self.tasks_id, 600)
|
||||||
|
|
||||||
|
def get_image(self, image_url):
|
||||||
|
try:
|
||||||
|
image = oss_get_image(bucket=image_url.split('/')[0], object_name=image_url[image_url.find('/') + 1:], data_type="PIL")
|
||||||
|
return image
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(e)
|
||||||
|
|
||||||
|
def callback(self, result, error):
|
||||||
|
if error:
|
||||||
|
self.generate_data['status'] = "FAILURE"
|
||||||
|
self.generate_data['message'] = str(error)
|
||||||
|
# self.generate_data['data'] = str(error)
|
||||||
|
self.redis_client.set(self.tasks_id, json.dumps(self.generate_data))
|
||||||
|
else:
|
||||||
|
# pil图像转成numpy数组
|
||||||
|
images = result.as_numpy("generated_image")
|
||||||
|
# for id, img in enumerate(images):
|
||||||
|
# cv2.imwrite(f"{id}.png", img)
|
||||||
|
# image_url = ""
|
||||||
|
image_url = upload_png_sd(images[6], user_id=self.user_id, category="multi_view", file_name=f"{self.tasks_id}.png")
|
||||||
|
# logger.info(f"upload image SUCCESS : {image_url}")
|
||||||
|
self.generate_data['status'] = "SUCCESS"
|
||||||
|
self.generate_data['message'] = "success"
|
||||||
|
self.generate_data['image_url'] = str(image_url)
|
||||||
|
self.redis_client.set(self.tasks_id, json.dumps(self.generate_data))
|
||||||
|
|
||||||
|
def read_tasks_status(self):
|
||||||
|
status_data = self.redis_client.get(self.tasks_id)
|
||||||
|
return json.loads(status_data), status_data
|
||||||
|
|
||||||
|
def get_result(self):
|
||||||
|
try:
|
||||||
|
images = [np.array(self.image).astype(np.uint8)] * 1
|
||||||
|
|
||||||
|
image_obj = np.array(images, dtype=np.uint8)
|
||||||
|
|
||||||
|
input_image = grpcclient.InferInput("input_image", image_obj.shape, "UINT8")
|
||||||
|
|
||||||
|
input_image.set_data_from_numpy(image_obj)
|
||||||
|
|
||||||
|
inputs = [input_image]
|
||||||
|
ctx = self.grpc_client.async_infer(model_name=GMV_MODEL_NAME, inputs=inputs, callback=self.callback)
|
||||||
|
|
||||||
|
time_out = 600
|
||||||
|
generate_data = None
|
||||||
|
while time_out > 0:
|
||||||
|
generate_data, _ = self.read_tasks_status()
|
||||||
|
if generate_data['status'] in ["REVOKED", "FAILURE"]:
|
||||||
|
ctx.cancel()
|
||||||
|
break
|
||||||
|
elif generate_data['status'] == "SUCCESS":
|
||||||
|
break
|
||||||
|
time_out -= 1
|
||||||
|
time.sleep(0.1)
|
||||||
|
return generate_data
|
||||||
|
except Exception as e:
|
||||||
|
self.generate_data['status'] = "FAILURE"
|
||||||
|
self.generate_data['message'] = str(e)
|
||||||
|
self.redis_client.set(self.tasks_id, json.dumps(self.generate_data))
|
||||||
|
raise Exception(str(e))
|
||||||
|
finally:
|
||||||
|
dict_generate_data, str_generate_data = self.read_tasks_status()
|
||||||
|
if DEBUG is False:
|
||||||
|
self.channel.basic_publish(exchange='', routing_key=GMV_RABBITMQ_QUEUES, body=str_generate_data)
|
||||||
|
# self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=str_generate_data)
|
||||||
|
logger.info(f" [x] Sent {json.dumps(dict_generate_data, indent=4)}")
|
||||||
|
|
||||||
|
|
||||||
|
def infer_cancel(tasks_id):
|
||||||
|
redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
|
||||||
|
data = {'tasks_id': tasks_id, 'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'}
|
||||||
|
generate_data = json.dumps(data)
|
||||||
|
redis_client.set(tasks_id, generate_data)
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
rd = GenerateMultiViewModel(
|
||||||
|
tasks_id="123-89",
|
||||||
|
image_url="aida-sys-image/images/female/outwear/0628000123.jpg",
|
||||||
|
)
|
||||||
|
server = GenerateMultiView(rd)
|
||||||
|
print(server.get_result())
|
||||||
@@ -1,4 +1,196 @@
|
|||||||
#!/usr/bin/env python
|
# #!/usr/bin/env python
|
||||||
|
# # -*- coding: UTF-8 -*-
|
||||||
|
# """
|
||||||
|
# @Project :trinity_client
|
||||||
|
# @File :service_att_recognition.py
|
||||||
|
# @Author :周成融
|
||||||
|
# @Date :2023/7/26 12:01:05
|
||||||
|
# @detail :
|
||||||
|
# """
|
||||||
|
# import json
|
||||||
|
# import logging
|
||||||
|
# import time
|
||||||
|
#
|
||||||
|
# import cv2
|
||||||
|
# import numpy as np
|
||||||
|
# import redis
|
||||||
|
# import tritonclient.grpc as grpcclient
|
||||||
|
# from PIL import Image
|
||||||
|
# from tritonclient.utils import np_to_triton_dtype
|
||||||
|
#
|
||||||
|
# from app.core.config import *
|
||||||
|
# from app.schemas.generate_image import GenerateProductImageModel
|
||||||
|
# from app.service.generate_image.utils.upload_sd_image import upload_SDXL_image
|
||||||
|
# from app.service.utils.oss_client import oss_get_image
|
||||||
|
#
|
||||||
|
# logger = logging.getLogger()
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# class GenerateProductImage:
|
||||||
|
# def __init__(self, request_data):
|
||||||
|
# if DEBUG is False:
|
||||||
|
# self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
|
||||||
|
# self.channel = self.connection.channel()
|
||||||
|
# # self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
|
||||||
|
# # self.channel = self.connection.channel()
|
||||||
|
# # self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
||||||
|
# self.grpc_client = grpcclient.InferenceServerClient(url=GPI_MODEL_URL)
|
||||||
|
# self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
|
||||||
|
# self.category = "product_image"
|
||||||
|
# self.image_strength = request_data.image_strength
|
||||||
|
# self.batch_size = 1
|
||||||
|
# self.product_type = request_data.product_type
|
||||||
|
# self.prompt = request_data.prompt
|
||||||
|
# self.image, self.image_size, self.left, self.top = pre_processing_image(request_data.image_url)
|
||||||
|
# self.tasks_id = request_data.tasks_id
|
||||||
|
# self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:]
|
||||||
|
# self.gen_product_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'image_url': ''}
|
||||||
|
# self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data))
|
||||||
|
# self.redis_client.expire(self.tasks_id, 600)
|
||||||
|
#
|
||||||
|
# def callback(self, result, error):
|
||||||
|
# if error:
|
||||||
|
# self.gen_product_data['status'] = "FAILURE"
|
||||||
|
# self.gen_product_data['message'] = str(error)
|
||||||
|
# self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data))
|
||||||
|
# else:
|
||||||
|
# # pil图像转成numpy数组
|
||||||
|
# image = result.as_numpy("generated_inpaint_image")
|
||||||
|
# image_result = Image.fromarray(np.squeeze(image.astype(np.uint8))).resize(self.image_size)
|
||||||
|
# cropped_image = post_processing_image(image_result, self.left, self.top)
|
||||||
|
# image_url = upload_SDXL_image(cropped_image, user_id=self.user_id, category=f"{self.category}", file_name=f"{self.tasks_id}.png")
|
||||||
|
# self.gen_product_data['status'] = "SUCCESS"
|
||||||
|
# self.gen_product_data['message'] = "success"
|
||||||
|
# self.gen_product_data['image_url'] = str(image_url)
|
||||||
|
# self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data))
|
||||||
|
#
|
||||||
|
# def read_tasks_status(self):
|
||||||
|
# status_data = self.redis_client.get(self.tasks_id)
|
||||||
|
# return json.loads(status_data), status_data
|
||||||
|
#
|
||||||
|
# def get_result(self):
|
||||||
|
# try:
|
||||||
|
# prompts = [self.prompt] * self.batch_size
|
||||||
|
# self.image = cv2.cvtColor(self.image, cv2.COLOR_BGR2RGB)
|
||||||
|
# self.image = cv2.resize(self.image, (1024, 1024))
|
||||||
|
# images = [self.image.astype(np.uint8)] * self.batch_size
|
||||||
|
#
|
||||||
|
# if self.product_type == "single":
|
||||||
|
# text_obj = np.array(prompts, dtype="object").reshape(-1, 1)
|
||||||
|
# image_obj = np.array(images, dtype=np.uint8).reshape((-1, 1024, 1024, 3))
|
||||||
|
# image_strength_obj = np.array(self.image_strength, dtype=np.float32).reshape(-1, 1)
|
||||||
|
# else:
|
||||||
|
# text_obj = np.array(prompts, dtype="object").reshape((1))
|
||||||
|
# image_obj = np.array(images, dtype=np.uint8).reshape((1024, 1024, 3))
|
||||||
|
# image_strength_obj = np.array(self.image_strength, dtype=np.float32).reshape((1))
|
||||||
|
#
|
||||||
|
# # 假设 prompts、images 和 self.image_strength 已经定义
|
||||||
|
#
|
||||||
|
# input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype))
|
||||||
|
# input_image = grpcclient.InferInput("input_image", image_obj.shape, "UINT8")
|
||||||
|
# input_image_strength = grpcclient.InferInput("image_strength", image_strength_obj.shape, np_to_triton_dtype(image_strength_obj.dtype))
|
||||||
|
#
|
||||||
|
# input_text.set_data_from_numpy(text_obj)
|
||||||
|
# input_image.set_data_from_numpy(image_obj)
|
||||||
|
# input_image_strength.set_data_from_numpy(image_strength_obj)
|
||||||
|
#
|
||||||
|
# inputs = [input_text, input_image, input_image_strength]
|
||||||
|
#
|
||||||
|
# if self.product_type == "single":
|
||||||
|
# ctx = self.grpc_client.async_infer(model_name="stable_diffusion_xl_cnet_inpaint", inputs=inputs, callback=self.callback)
|
||||||
|
# else:
|
||||||
|
# ctx = self.grpc_client.async_infer(model_name=GPI_MODEL_NAME_OVERALL, inputs=inputs, callback=self.callback)
|
||||||
|
#
|
||||||
|
# time_out = 600
|
||||||
|
# while time_out > 0:
|
||||||
|
# gen_product_data, _ = self.read_tasks_status()
|
||||||
|
# if gen_product_data['status'] in ["REVOKED", "FAILURE"]:
|
||||||
|
# ctx.cancel()
|
||||||
|
# break
|
||||||
|
# elif gen_product_data['status'] == "SUCCESS":
|
||||||
|
# break
|
||||||
|
# time_out -= 1
|
||||||
|
# time.sleep(0.1)
|
||||||
|
# gen_product_data, _ = self.read_tasks_status()
|
||||||
|
# return gen_product_data
|
||||||
|
# except Exception as e:
|
||||||
|
# self.gen_product_data['status'] = "FAILURE"
|
||||||
|
# self.gen_product_data['message'] = str(e)
|
||||||
|
# self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data))
|
||||||
|
# raise Exception(str(e))
|
||||||
|
# finally:
|
||||||
|
# dict_gen_product_data, str_gen_product_data = self.read_tasks_status()
|
||||||
|
# if DEBUG is False:
|
||||||
|
# self.channel.basic_publish(exchange='', routing_key=GPI_RABBITMQ_QUEUES, body=str_gen_product_data)
|
||||||
|
# logger.info(f" [x] Sent to: {GPI_RABBITMQ_QUEUES} data:@@@@ {json.dumps(dict_gen_product_data, indent=4)}")
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# def infer_cancel(tasks_id):
|
||||||
|
# redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
|
||||||
|
# data = {'tasks_id': tasks_id, 'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'}
|
||||||
|
# gen_product_data = json.dumps(data)
|
||||||
|
# redis_client.set(tasks_id, gen_product_data)
|
||||||
|
# return data
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# def pre_processing_image(image_url):
|
||||||
|
# image = oss_get_image(bucket=image_url.split('/')[0], object_name=image_url[image_url.find('/') + 1:], data_type="PIL")
|
||||||
|
# # resize 原图至1024*1024
|
||||||
|
# image = image.resize((int(1024 / image.height * image.width), 1024))
|
||||||
|
#
|
||||||
|
# # 原始图片的尺寸
|
||||||
|
# width, height = image.size
|
||||||
|
#
|
||||||
|
# new_height, new_width = 1024, 1024
|
||||||
|
# # 创建一个新的画布,大小为添加 padding 后的尺寸,并设置为白色背景
|
||||||
|
# pad_image = Image.new('RGBA', (new_width, new_height), (0, 0, 0, 0))
|
||||||
|
#
|
||||||
|
# # 将原始图片粘贴到新的画布中心
|
||||||
|
# left = (new_width - width) // 2
|
||||||
|
# top = (new_height - height) // 2
|
||||||
|
# pad_image.paste(image, (left, top))
|
||||||
|
#
|
||||||
|
# # 将画布 resize 成宽度 1024,长度 1024
|
||||||
|
# resized_image = pad_image.resize((1024, 1024))
|
||||||
|
# image_size = (1024, 1024)
|
||||||
|
#
|
||||||
|
# if resized_image.mode in ('RGBA', 'LA') or (resized_image.mode == 'P' and 'transparency' in resized_image.info):
|
||||||
|
# # 创建白色背景
|
||||||
|
# background = Image.new("RGB", image_size, (255, 255, 255))
|
||||||
|
# # 将图片粘贴到白色背景上
|
||||||
|
# background.paste(resized_image, mask=resized_image.split()[3])
|
||||||
|
# image = np.array(background)
|
||||||
|
# image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||||
|
# return image, image_size, left, top
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# def post_processing_image(image, left, top):
|
||||||
|
# resized_image = image.resize((int(image.width * (768 / image.height)), 768))
|
||||||
|
# # 计算裁剪的坐标
|
||||||
|
# left = (resized_image.width - 512) // 2
|
||||||
|
# upper = 0
|
||||||
|
# right = left + 512
|
||||||
|
# lower = 768
|
||||||
|
#
|
||||||
|
# # 进行裁剪
|
||||||
|
# cropped_image = resized_image.crop((left, upper, right, lower))
|
||||||
|
# return cropped_image
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# if __name__ == '__main__':
|
||||||
|
# rd = GenerateProductImageModel(
|
||||||
|
# tasks_id="123-89",
|
||||||
|
# # prompt="",
|
||||||
|
# image_strength=0.7,
|
||||||
|
# prompt="The best quality, masterpiece,outwear, 8K realistic, HUD",
|
||||||
|
# image_url="aida-results/result_53381ada-ac64-11ef-ae9d-0242ac150002.png",
|
||||||
|
# product_type="overall"
|
||||||
|
# )
|
||||||
|
# server = GenerateProductImage(rd)
|
||||||
|
# print(server.get_result())
|
||||||
|
|
||||||
|
# 旧版product
|
||||||
|
# !/usr/bin/env python
|
||||||
# -*- coding: UTF-8 -*-
|
# -*- coding: UTF-8 -*-
|
||||||
"""
|
"""
|
||||||
@Project :trinity_client
|
@Project :trinity_client
|
||||||
@@ -34,14 +226,14 @@ class GenerateProductImage:
|
|||||||
# self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
|
# self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
|
||||||
# self.channel = self.connection.channel()
|
# self.channel = self.connection.channel()
|
||||||
# self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
# self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
||||||
self.grpc_client = grpcclient.InferenceServerClient(url=GPI_MODEL_URL)
|
self.grpc_client = grpcclient.InferenceServerClient(url="10.1.1.243:18001")
|
||||||
self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
|
self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
|
||||||
self.category = "product_image"
|
self.category = "product_image"
|
||||||
self.image_strength = request_data.image_strength
|
self.image_strength = request_data.image_strength
|
||||||
self.batch_size = 1
|
self.batch_size = 1
|
||||||
self.product_type = request_data.product_type
|
self.product_type = request_data.product_type
|
||||||
self.prompt = request_data.prompt
|
self.prompt = request_data.prompt
|
||||||
self.image, self.image_size, self.left, self.top = pre_processing_image(request_data.image_url)
|
self.image = pre_processing_image(request_data.image_url)
|
||||||
self.tasks_id = request_data.tasks_id
|
self.tasks_id = request_data.tasks_id
|
||||||
self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:]
|
self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:]
|
||||||
self.gen_product_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'image_url': ''}
|
self.gen_product_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'image_url': ''}
|
||||||
@@ -55,10 +247,13 @@ class GenerateProductImage:
|
|||||||
self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data))
|
self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data))
|
||||||
else:
|
else:
|
||||||
# pil图像转成numpy数组
|
# pil图像转成numpy数组
|
||||||
image = result.as_numpy("generated_inpaint_image")
|
if self.product_type == "single":
|
||||||
image_result = Image.fromarray(np.squeeze(image.astype(np.uint8))).resize(self.image_size)
|
image = result.as_numpy("generated_cnet_image")
|
||||||
cropped_image = post_processing_image(image_result, self.left, self.top)
|
else:
|
||||||
image_url = upload_SDXL_image(cropped_image, user_id=self.user_id, category=f"{self.category}", file_name=f"{self.tasks_id}.png")
|
image = result.as_numpy("generated_inpaint_image")
|
||||||
|
image_result = Image.fromarray(np.squeeze(image.astype(np.uint8)))
|
||||||
|
# cropped_image = post_processing_image(image_result, self.left, self.top)
|
||||||
|
image_url = upload_SDXL_image(image_result, user_id=self.user_id, category=f"{self.category}", file_name=f"{self.tasks_id}.png")
|
||||||
self.gen_product_data['status'] = "SUCCESS"
|
self.gen_product_data['status'] = "SUCCESS"
|
||||||
self.gen_product_data['message'] = "success"
|
self.gen_product_data['message'] = "success"
|
||||||
self.gen_product_data['image_url'] = str(image_url)
|
self.gen_product_data['image_url'] = str(image_url)
|
||||||
@@ -71,17 +266,18 @@ class GenerateProductImage:
|
|||||||
def get_result(self):
|
def get_result(self):
|
||||||
try:
|
try:
|
||||||
prompts = [self.prompt] * self.batch_size
|
prompts = [self.prompt] * self.batch_size
|
||||||
self.image = cv2.cvtColor(self.image, cv2.COLOR_BGR2RGB)
|
|
||||||
self.image = cv2.resize(self.image, (1024, 1024))
|
self.image = cv2.cvtColor(self.image, cv2.COLOR_RGBA2RGB)
|
||||||
|
# self.image = cv2.resize(self.image, (1024, 1024))
|
||||||
images = [self.image.astype(np.uint8)] * self.batch_size
|
images = [self.image.astype(np.uint8)] * self.batch_size
|
||||||
|
|
||||||
if self.product_type == "single":
|
if self.product_type == "single":
|
||||||
text_obj = np.array(prompts, dtype="object").reshape(-1, 1)
|
text_obj = np.array(prompts, dtype="object").reshape((-1, 1))
|
||||||
image_obj = np.array(images, dtype=np.uint8).reshape((-1, 1024, 1024, 3))
|
image_obj = np.array(images, dtype=np.uint8).reshape((-1, 768, 512, 3))
|
||||||
image_strength_obj = np.array(self.image_strength, dtype=np.float32).reshape(-1, 1)
|
image_strength_obj = np.array(self.image_strength, dtype=np.float32).reshape((-1, 1))
|
||||||
else:
|
else:
|
||||||
text_obj = np.array(prompts, dtype="object").reshape((1))
|
text_obj = np.array(prompts, dtype="object").reshape((1))
|
||||||
image_obj = np.array(images, dtype=np.uint8).reshape((1024, 1024, 3))
|
image_obj = np.array(images, dtype=np.uint8).reshape((768, 512, 3))
|
||||||
image_strength_obj = np.array(self.image_strength, dtype=np.float32).reshape((1))
|
image_strength_obj = np.array(self.image_strength, dtype=np.float32).reshape((1))
|
||||||
|
|
||||||
# 假设 prompts、images 和 self.image_strength 已经定义
|
# 假设 prompts、images 和 self.image_strength 已经定义
|
||||||
@@ -97,9 +293,9 @@ class GenerateProductImage:
|
|||||||
inputs = [input_text, input_image, input_image_strength]
|
inputs = [input_text, input_image, input_image_strength]
|
||||||
|
|
||||||
if self.product_type == "single":
|
if self.product_type == "single":
|
||||||
ctx = self.grpc_client.async_infer(model_name="stable_diffusion_xl_cnet_inpaint", inputs=inputs, callback=self.callback)
|
ctx = self.grpc_client.async_infer(model_name="stable_diffusion_1_5_cnet", inputs=inputs, callback=self.callback)
|
||||||
else:
|
else:
|
||||||
ctx = self.grpc_client.async_infer(model_name=GPI_MODEL_NAME_OVERALL, inputs=inputs, callback=self.callback)
|
ctx = self.grpc_client.async_infer(model_name="diffusion_ensemble_all", inputs=inputs, callback=self.callback)
|
||||||
|
|
||||||
time_out = 600
|
time_out = 600
|
||||||
while time_out > 0:
|
while time_out > 0:
|
||||||
@@ -135,33 +331,26 @@ def infer_cancel(tasks_id):
|
|||||||
|
|
||||||
def pre_processing_image(image_url):
|
def pre_processing_image(image_url):
|
||||||
image = oss_get_image(bucket=image_url.split('/')[0], object_name=image_url[image_url.find('/') + 1:], data_type="PIL")
|
image = oss_get_image(bucket=image_url.split('/')[0], object_name=image_url[image_url.find('/') + 1:], data_type="PIL")
|
||||||
# resize 原图至1024*1024
|
# 调整图片高度为768像素,保持宽高比
|
||||||
image = image.resize((int(1024 / image.height * image.width), 1024))
|
|
||||||
|
|
||||||
# 原始图片的尺寸
|
|
||||||
width, height = image.size
|
width, height = image.size
|
||||||
|
new_height = 768
|
||||||
|
new_width = int(width * (new_height / height))
|
||||||
|
resized_image = image.resize((new_width, new_height))
|
||||||
|
|
||||||
new_height, new_width = 1024, 1024
|
# 创建一个512x768的透明图片
|
||||||
# 创建一个新的画布,大小为添加 padding 后的尺寸,并设置为白色背景
|
result_image = Image.new("RGBA", (512, 768), (255, 255, 255, 255))
|
||||||
pad_image = Image.new('RGBA', (new_width, new_height), (0, 0, 0, 0))
|
|
||||||
|
|
||||||
# 将原始图片粘贴到新的画布中心
|
# 计算需要粘贴的位置,使图片居中
|
||||||
left = (new_width - width) // 2
|
x_offset = (512 - new_width) // 2
|
||||||
top = (new_height - height) // 2
|
y_offset = 0
|
||||||
pad_image.paste(image, (left, top))
|
|
||||||
|
|
||||||
# 将画布 resize 成宽度 1024,长度 1024
|
# 将调整大小后的图片粘贴到透明图片上
|
||||||
resized_image = pad_image.resize((1024, 1024))
|
result_image.paste(resized_image, (x_offset, y_offset), mask=resized_image.split()[3])
|
||||||
image_size = (1024, 1024)
|
|
||||||
|
|
||||||
if resized_image.mode in ('RGBA', 'LA') or (resized_image.mode == 'P' and 'transparency' in resized_image.info):
|
image = np.array(result_image)
|
||||||
# 创建白色背景
|
|
||||||
background = Image.new("RGB", image_size, (255, 255, 255))
|
# image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA)
|
||||||
# 将图片粘贴到白色背景上
|
return image
|
||||||
background.paste(resized_image, mask=resized_image.split()[3])
|
|
||||||
image = np.array(background)
|
|
||||||
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
|
||||||
return image, image_size, left, top
|
|
||||||
|
|
||||||
|
|
||||||
def post_processing_image(image, left, top):
|
def post_processing_image(image, left, top):
|
||||||
@@ -182,9 +371,9 @@ if __name__ == '__main__':
|
|||||||
tasks_id="123-89",
|
tasks_id="123-89",
|
||||||
# prompt="",
|
# prompt="",
|
||||||
image_strength=0.7,
|
image_strength=0.7,
|
||||||
prompt="The best quality, masterpiece,outwear, 8K realistic, HUD",
|
prompt="The best quality, masterpiece, real image.,high quality clothing details,8K realistic,HDR",
|
||||||
image_url="aida-results/result_53381ada-ac64-11ef-ae9d-0242ac150002.png",
|
image_url="aida-results/result_40c7924e-e220-11ef-8ea2-0242ac150003.png",
|
||||||
product_type="overall"
|
product_type="single"
|
||||||
)
|
)
|
||||||
server = GenerateProductImage(rd)
|
server = GenerateProductImage(rd)
|
||||||
print(server.get_result())
|
print(server.get_result())
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ from requests import RequestException
|
|||||||
from retry import retry
|
from retry import retry
|
||||||
|
|
||||||
from app.core.config import QWEN_API_KEY
|
from app.core.config import QWEN_API_KEY
|
||||||
|
from app.service.chat_robot.script.service.CallQWen import get_language
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -93,7 +94,13 @@ def get_translation_from_llama3(text):
|
|||||||
|
|
||||||
# prompt = f"System: {prefix_for_llama}\nUser:[{text}]"
|
# prompt = f"System: {prefix_for_llama}\nUser:[{text}]"
|
||||||
|
|
||||||
# 创建请求的负载
|
# 先获取用户输入文本的语言
|
||||||
|
language = get_language(text)
|
||||||
|
|
||||||
|
if 'English' in language:
|
||||||
|
return text
|
||||||
|
|
||||||
|
# 创建请求的负载 translator是自定义的翻译模型
|
||||||
payload = {
|
payload = {
|
||||||
"model": "translator",
|
"model": "translator",
|
||||||
"prompt": f"[{text}]",
|
"prompt": f"[{text}]",
|
||||||
@@ -117,6 +124,26 @@ def get_translation_from_llama3(text):
|
|||||||
print(response.text)
|
print(response.text)
|
||||||
|
|
||||||
|
|
||||||
|
# 在llama3中创建一个翻译模型
|
||||||
|
# def create_model_with_llama(text):
|
||||||
|
# url = "http://localhost:11434/api/create"
|
||||||
|
# # url = "http://10.1.1.240:1143/api/generate"
|
||||||
|
#
|
||||||
|
# # prompt = f"System: {prefix_for_llama}\nUser:[{text}]"
|
||||||
|
#
|
||||||
|
# # 创建翻译器的配置文件
|
||||||
|
# payload = {
|
||||||
|
# "model": "translator",
|
||||||
|
# "modelfile": "FROM llama3\nSYSTEM Translate everything within the brackets [] into English."
|
||||||
|
# "Never translate or modify any English input."
|
||||||
|
# "The input must be fully translated into coherent English sentences."
|
||||||
|
# }
|
||||||
|
#
|
||||||
|
# # 将负载转换为 JSON 格式
|
||||||
|
# headers = {'Content-Type': 'application/json'}
|
||||||
|
# response = requests.post(url, data=json.dumps(payload), headers=headers)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
"""Main function"""
|
"""Main function"""
|
||||||
text = get_translation_from_llama3("[火焰]")
|
text = get_translation_from_llama3("[火焰]")
|
||||||
|
|||||||
@@ -7,9 +7,9 @@ def RunTime(func):
|
|||||||
t1 = time.time()
|
t1 = time.time()
|
||||||
res = func(*args, **kwargs)
|
res = func(*args, **kwargs)
|
||||||
t2 = time.time()
|
t2 = time.time()
|
||||||
# if t2 - t1 > 0.05:
|
if t2 - t1 > 0.05:
|
||||||
# logging.info(f"function:【{func.__name__}】,runtime:【{str(t2 - t1)}】s")
|
logging.info(f"function:【{func.__name__}】,runtime:【{str(t2 - t1)}】s")
|
||||||
logging.info(f"function:【{func.__name__}】,runtime:【{str(t2 - t1)}】s")
|
# logging.info(f"function:【{func.__name__}】,runtime:【{str(t2 - t1)}】s")
|
||||||
return res
|
return res
|
||||||
|
|
||||||
return wrapper
|
return wrapper
|
||||||
@@ -22,7 +22,8 @@ def ClassCallRunTime(func):
|
|||||||
end_time = time.time()
|
end_time = time.time()
|
||||||
execution_time = end_time - start_time
|
execution_time = end_time - start_time
|
||||||
class_name = args[0].__class__.__name__ # 获取类名
|
class_name = args[0].__class__.__name__ # 获取类名
|
||||||
print(f"class name: {class_name} , run time is : {execution_time} s")
|
if execution_time > 0.05:
|
||||||
|
logging.info(f"class name: {class_name} , run time is : {execution_time} s")
|
||||||
return result
|
return result
|
||||||
|
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|||||||
@@ -82,7 +82,7 @@ if __name__ == '__main__':
|
|||||||
# url = "aida-users/89/sketchboard/female/Dress/e6724ab7-8d3f-4677-abe0-c3e42ab7af85.jpeg"
|
# url = "aida-users/89/sketchboard/female/Dress/e6724ab7-8d3f-4677-abe0-c3e42ab7af85.jpeg"
|
||||||
# url = "aida-users/87/print/956614a2-7e75-4fbe-9ed0-c1831e37a2c9-4-87.png"
|
# url = "aida-users/87/print/956614a2-7e75-4fbe-9ed0-c1831e37a2c9-4-87.png"
|
||||||
# url = "aida-users/89/single_logo/123-89.png"
|
# url = "aida-users/89/single_logo/123-89.png"
|
||||||
url = "aida-users/89/test/123-89.png"
|
url = "aida-users/89/123-89.png"
|
||||||
|
|
||||||
# url = "aida-collection-element/12148/Sketchboard/95ea577b-305b-4a62-b30a-39c0dd3ddb3f.png"
|
# url = "aida-collection-element/12148/Sketchboard/95ea577b-305b-4a62-b30a-39c0dd3ddb3f.png"
|
||||||
read_type = "2"
|
read_type = "2"
|
||||||
|
|||||||
Reference in New Issue
Block a user