feat: product image 新增 product_type 参数,解决 single item 无法检测头部的问题

fix
This commit is contained in:
zhouchengrong
2024-07-04 14:14:57 +08:00
parent 24142a01cc
commit eede159507
13 changed files with 163 additions and 101 deletions

View File

@@ -34,13 +34,14 @@ def attribute_recognition(request_item: list[AttributeRecognitionModel]):
]
"""
try:
logger.info(f"attribute_recognition request item is : @@@@@@:{request_item}")
for item in request_item:
logger.info(f"attribute_recognition request item is : @@@@@@:{json.dumps(item.dict())}")
if DEBUG:
service = AttributeRecognition(const=local_debug_const, request_data=request_item)
else:
service = AttributeRecognition(const=const, request_data=request_item)
data = service.get_result()
logger.info(f"attribute_recognition response @@@@@@:{json.dumps(data, indent=4)}")
logger.info(f"attribute_recognition response @@@@@@:{json.dumps(data)}")
except Exception as e:
logger.warning(f"attribute_recognition Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))
@@ -65,10 +66,11 @@ def category_recognition(request_item: list[CategoryRecognitionModel]):
]
"""
try:
logger.info(f"category_recognition request item is : @@@@@@:{request_item}")
for item in request_item:
logger.info(f"category_recognition request item is : @@@@@@:{json.dumps(item.dict())}")
service = CategoryRecognition(request_data=request_item)
data = service.get_result()
logger.info(f"category_recognition response @@@@@@:{json.dumps(data, indent=4)}")
logger.info(f"category_recognition response @@@@@@:{json.dumps(data)}")
except Exception as e:
logger.warning(f"category_recognition Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))

View File

@@ -30,9 +30,9 @@ def chat_robot(request_data: ChatRobotModel):
}
"""
try:
logger.info(f"chat_robot request item is : @@@@@@:{request_data}")
logger.info(f"chat_robot request item is : @@@@@@:{json.dumps(request_data.dict())}")
data = chat(post_data=request_data)
logger.info(f"chat_robot response @@@@@@:{json.dumps(data, indent=4)}")
logger.info(f"chat_robot response @@@@@@:{json.dumps(data)}")
except Exception as e:
logger.warning(f"chat_robot Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))

View File

@@ -66,7 +66,24 @@ def design(request_data: DesignModel):
],
"path": "aida-users/89/sketch/c89d75f3-581f-4edd-9f8e-b08e84a2cbe7-3-89.png",
"print": {
"IfSingle": false,
"single": {
"location": [
[
200.0,
200.0
]
],
"print_angle_list": [
0.0
],
"print_path_list": [
"aida-users/89/slogan_image/ce0b2423-9e5a-466f-9611-c254940a7819-1-89.png"
],
"print_scale_list": [
1.0
]
},
"overall": {
"location": [
[
512.0,
@@ -83,6 +100,24 @@ def design(request_data: DesignModel):
1.0
]
},
"element": {
"element_angle_list": [
0.0
],
"element_path_list": [
"aida-users/88/designelements/Embroidery/a4d9605a-675e-4606-93e0-77ca6baaf55f.png"
],
"element_scale_list": [
0.2731036750637755
],
"location": [
[
228.63694825464364,
406.4843844199667
]
]
}
},
"priority": 10,
"resize_scale": [
1.0,
@@ -102,9 +137,9 @@ def design(request_data: DesignModel):
}
"""
try:
logger.info(f"design request item is : @@@@@@:{request_data.dict()}")
logger.info(f"design request item is : @@@@@@:{json.dumps(request_data.dict())}")
data = generate(request_data=request_data)
logger.info(f"design response @@@@@@:{json.dumps(data, indent=4)}")
logger.info(f"design response @@@@@@:{json.dumps(data)}")
except Exception as e:
logger.warning(f"design Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))
@@ -124,13 +159,13 @@ def get_progress(request_data: DesignProgressModel):
}
"""
try:
logger.info(f"get_progress request item is : @@@@@@:{request_data.dict()}")
logger.info(f"get_progress request item is : @@@@@@:{json.dumps(request_data.dict())}")
process_id = request_data.process_id
r = Redis()
data = r.read(key=process_id)
if data is None:
raise ValueError(f"No progress ID: {process_id}")
logging.info(f"get_progress process_id @@@@@@ : {process_id} , progress : {data}")
logging.info(f"get_progress process_id @@@@@@ : {process_id} , progress : {json.dumps(data)}")
except Exception as e:
logger.warning(f"get_progress Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))
@@ -150,10 +185,10 @@ def model_process(request_data: ModelProgressModel):
}
"""
try:
logger.info(f"model_process request item is : @@@@@@:{request_data.dict()}")
logger.info(f"model_process request item is : @@@@@@:{json.dumps(request_data.dict())}")
data = model_transpose(image_path=request_data.model_path)
logger.info(f"model_process response @@@@@@:{json.dumps(data, indent=4)}")
logger.info(f"model_process response @@@@@@:{json.dumps(data)}")
except Exception as e:
logger.warning(f"model_process Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))

View File

@@ -30,10 +30,10 @@ def design_pre_processing(request_data: DesignPreProcessingModel):
}
"""
try:
logger.info(f"design_pre_processing request item is : @@@@@@:{request_data}")
logger.info(f"design_pre_processing request item is : @@@@@@:{json.dumps(request_data)}")
server = DesignPreprocessing()
data = server.pipeline(image_list=request_data.sketches)
logger.info(f"design response @@@@@@:{json.dumps(data, indent=4)}")
logger.info(f"design response @@@@@@:{json.dumps(data)}")
except Exception as e:
logger.warning(f"design Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))

View File

@@ -38,7 +38,7 @@ def generate_image(request_item: GenerateImageModel, background_tasks: Backgroun
}
"""
try:
logger.info(f"generate_image request item is : @@@@@@:{request_item}")
logger.info(f"generate_image request item is : @@@@@@:{json.dumps(request_item.dict())}")
service = GenerateImage(request_item)
background_tasks.add_task(service.get_result)
except Exception as e:
@@ -52,7 +52,7 @@ def generate_image(tasks_id: str):
try:
logger.info(f"generate_cancel request item is : @@@@@@:{tasks_id}")
data = generate_image_infer_cancel(tasks_id)
logger.info(f"generate_cancel response @@@@@@:{json.dumps(data, indent=4)}")
logger.info(f"generate_cancel response @@@@@@:{data}")
except Exception as e:
logger.warning(f"generate_cancel Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))
@@ -78,7 +78,7 @@ def generate_single_logo(request_item: GenerateSingleLogoImageModel, background_
}
"""
try:
logger.info(f"generate_single_logo request item is : @@@@@@:{request_item}")
logger.info(f"generate_single_logo request item is : @@@@@@:{json.dumps(request_item.dict())}")
service = GenerateSingleLogoImage(request_item)
background_tasks.add_task(service.get_result)
except Exception as e:
@@ -92,7 +92,7 @@ def generate_single_logo_image(tasks_id: str):
try:
logger.info(f"generate_single_logo_cancel request item is : @@@@@@:{tasks_id}")
data = generate_single_logo_cancel(tasks_id)
logger.info(f"generate_single_logo_cancel response @@@@@@:{json.dumps(data, indent=4)}")
logger.info(f"generate_single_logo_cancel response @@@@@@:{data}")
except Exception as e:
logger.warning(f"generate_single_logo_cancel Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))
@@ -110,17 +110,20 @@ def generate_product_image(request_item: GenerateProductImageModel, background_t
- **prompt**: 想要生成图片的描述词
- **image_url**: 被生成图片的S3或minio url地址
- **image_strength**: 生成强度,越低越接近原图
- **product_type**: 输入single item 还是 overall item
示例参数:
{
"tasks_id": "123-89",
"prompt": "the best quality, masterpiece. detailed, high-res, simple background, studio photography, extremely detailed, updo, detailed face, face, close-up, HDR, UHD, 8K realistic, Highly detailed, simple background, Studio lighting",
"image_url": "aida-results/result_00097282-ebb2-11ee-a822-b48351119060.png",
"image_strength": 0.8
"image_strength": 0.8,
"product_type": "overall"
}
"""
try:
logger.info(f"generate_product_image request item is : @@@@@@:{request_item}")
logger.info(f"generate_product_image request item is : @@@@@@:{json.dumps(request_item.dict())}")
service = GenerateProductImage(request_item)
background_tasks.add_task(service.get_result)
except Exception as e:
@@ -134,7 +137,7 @@ def generate_product_image(tasks_id: str):
try:
logger.info(f"generate_product_image_cancel_cancel request item is : @@@@@@:{tasks_id}")
data = generate_product_image_cancel(tasks_id)
logger.info(f"generate_product_image_cancel_cancel response @@@@@@:{json.dumps(data, indent=4)}")
logger.info(f"generate_product_image_cancel_cancel response @@@@@@:{data}")
except Exception as e:
logger.warning(f"generate_product_image_cancel_cancel Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))
@@ -162,7 +165,7 @@ def generate_relight_image(request_item: GenerateRelightImageModel, background_t
}
"""
try:
logger.info(f"generate_relight_image request item is : @@@@@@:{request_item}")
logger.info(f"generate_relight_image request item is : @@@@@@:{json.dumps(request_item.dict())}")
service = GenerateRelightImage(request_item)
background_tasks.add_task(service.get_result)
except Exception as e:
@@ -176,7 +179,7 @@ def generate_relight_image(tasks_id: str):
try:
logger.info(f"generate_relight_image_cancel_cancel request item is : @@@@@@:{tasks_id}")
data = generate_relight_image_cancel(tasks_id)
logger.info(f"generate_relight_image_cancel_cancel response @@@@@@:{json.dumps(data, indent=4)}")
logger.info(f"generate_relight_image_cancel_cancel response @@@@@@:{data}")
except Exception as e:
logger.warning(f"generate_relight_image_cancel_cancel Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))

View File

@@ -27,7 +27,7 @@ def prompt_generation(request_data: PromptGenerationImageModel):
try:
logger.info(f"prompt_generation request item is : @@@@@@:{request_data}")
data = translate_to_en(request_data.text)
logger.info(f"prompt_generation response @@@@@@:{json.dumps(data, indent=4)}")
logger.info(f"prompt_generation response @@@@@@:{data}")
except Exception as e:
logger.warning(f"prompt_generation Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))

View File

@@ -27,7 +27,7 @@ def super_resolution(request_item: SuperResolutionModel, background_tasks: Backg
}
"""
try:
logger.info(f"super_resolution request item is : @@@@@@:{request_item}")
logger.info(f"super_resolution request item is : @@@@@@:{json.dumps(request_item.dict())}")
service = SuperResolution(request_item)
background_tasks.add_task(service.sr_result)
except Exception as e:
@@ -41,7 +41,7 @@ def super_resolution(tasks_id: str):
try:
logger.info(f"sr_cancel request item is : @@@@@@:{tasks_id}")
data = infer_cancel(tasks_id)
logger.info(f"sr_cancel response @@@@@@:{json.dumps(data, indent=4)}")
logger.info(f"sr_cancel response @@@@@@:{data}")
except Exception as e:
logger.warning(f"sr_cancel Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))

View File

@@ -1,8 +1,10 @@
import json
import logging
from fastapi import APIRouter
from app.core.config import SR_RABBITMQ_QUEUES, GI_RABBITMQ_QUEUES, GPI_RABBITMQ_QUEUES, GRI_RABBITMQ_QUEUES, OSS
from fastapi import FastAPI, HTTPException
from fastapi import APIRouter
from fastapi import HTTPException
from app.core.config import SR_RABBITMQ_QUEUES, GI_RABBITMQ_QUEUES, GPI_RABBITMQ_QUEUES, GRI_RABBITMQ_QUEUES, OSS
from app.schemas.response_template import ResponseModel
logger = logging.getLogger()
@@ -18,7 +20,7 @@ def test(id: int):
"GRI_RABBITMQ_QUEUES": GRI_RABBITMQ_QUEUES,
"local_oss_server": OSS
}
logger.info(data)
logger.info(json.dumps(data))
if id == 1:
raise HTTPException(status_code=404, detail="Item not found")

View File

@@ -118,9 +118,11 @@ GSL_MINIO_BUCKET = "aida-users"
GSL_MODEL_NAME = 'stable_diffusion_xl_transparent'
GEN_SINGLE_LOGO_RABBITMQ_QUEUES = os.getenv("GEN_SINGLE_LOGO_RABBITMQ_QUEUES", f"GenSingleLogo{RABBITMQ_ENV}")
# Generate Single Logo service config
# Generate Product service config
GPI_RABBITMQ_QUEUES = os.getenv("GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES", f"ToProductImage{RABBITMQ_ENV}")
GPI_MODEL_NAME = 'diffusion_ensemble_all'
GPI_MODEL_NAME_OVERALL = 'diffusion_ensemble_all'
GPI_MODEL_NAME_SINGLE = 'stable_diffusion_1_5_cnet'
GPI_MODEL_URL = '10.1.1.240:10041'
# Generate Single Logo service config

View File

@@ -21,6 +21,7 @@ class GenerateProductImageModel(BaseModel):
prompt: str
image_url: str
image_strength: float
product_type: str
class GenerateRelightImageModel(BaseModel):

View File

@@ -1,22 +1,23 @@
import json
import logging
from loguru import logger
from langchain.agents import Tool
from langchain.utilities import SerpAPIWrapper
from langchain.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate, MessagesPlaceholder
from langchain.schema import SystemMessage, AIMessage
from langchain.callbacks import FileCallbackHandler
from langchain.chat_models import ChatOpenAI
from langchain.llms.openai import OpenAI
from langchain.callbacks import FileCallbackHandler
from langchain.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate, MessagesPlaceholder
from langchain.schema import SystemMessage, AIMessage
from langchain.utilities import SerpAPIWrapper
from loguru import logger
from app.core.config import *
from app.service.chat_robot.script.agents import CustomAgentExecutor, ConversationalFunctionsAgent
from app.service.chat_robot.script.callbacks import OpenAITokenRecordCallbackHandler
from app.service.chat_robot.script.database import CustomDatabase
from app.service.chat_robot.script.memory import UserConversationBufferWindowMemory
from app.service.chat_robot.script.prompt import FASHION_CHAT_BOT_PREFIX, TOOLS_FUNCTIONS_SUFFIX
from app.service.chat_robot.script.tools import (QuerySQLDataBaseTool, InfoSQLDatabaseTool, QuerySQLCheckerTool, ListSQLDatabaseTool)
from app.service.chat_robot.script.memory import UserConversationBufferWindowMemory
from app.service.chat_robot.script.tools.tutorial_tool import CustomTutorialTool
from app.core.config import *
import os
# os.environ["http_proxy"] = "http://127.0.0.1:7890"
# os.environ["https_proxy"] = "http://127.0.0.1:7890"
@@ -110,5 +111,5 @@ def chat(post_data):
'completion_tokens': final_outputs['completion_tokens'],
'response_type': final_outputs['response_type']
}
logging.info(api_response)
logging.info(json.dumps(api_response))
return api_response

View File

@@ -39,6 +39,7 @@ class GenerateProductImage:
self.category = "product_image"
self.image_strength = request_data.image_strength
self.batch_size = 1
self.product_type = request_data.product_type
self.prompt = request_data.prompt
self.image, self.image_size = pre_processing_image(request_data.image_url)
self.tasks_id = request_data.tasks_id
@@ -54,6 +55,9 @@ class GenerateProductImage:
self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data))
else:
# pil图像转成numpy数组
if self.product_type == "single":
image = result.as_numpy("generated_cnet_image")
else:
image = result.as_numpy("generated_inpaint_image")
image_result = Image.fromarray(np.squeeze(image.astype(np.uint8))).resize(self.image_size)
image_url = upload_SDXL_image(image_result, user_id=self.user_id, category=f"{self.category}", file_name=f"{self.tasks_id}.png")
@@ -73,10 +77,17 @@ class GenerateProductImage:
self.image = cv2.resize(self.image, (512, 768))
images = [self.image.astype(np.uint8)] * self.batch_size
if self.product_type == "single":
text_obj = np.array(prompts, dtype="object").reshape(-1, 1)
image_obj = np.array(images, dtype=np.uint8).reshape((-1, 768, 512, 3))
image_strength_obj = np.array(self.image_strength, dtype=np.float32).reshape(-1, 1)
else:
text_obj = np.array(prompts, dtype="object").reshape(1)
image_obj = np.array(images, dtype=np.uint8).reshape((768, 512, 3))
image_strength_obj = np.array(self.image_strength, dtype=np.float32).reshape((1))
# 假设 prompts、images 和 self.image_strength 已经定义
input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype))
input_image = grpcclient.InferInput("input_image", image_obj.shape, "UINT8")
input_image_strength = grpcclient.InferInput("image_strength", image_strength_obj.shape, np_to_triton_dtype(image_strength_obj.dtype))
@@ -86,7 +97,11 @@ class GenerateProductImage:
inputs = [input_text, input_image, input_image_strength]
input_image_strength.set_data_from_numpy(image_strength_obj)
ctx = self.grpc_client.async_infer(model_name=GPI_MODEL_NAME, inputs=inputs, callback=self.callback)
if self.product_type == "single":
ctx = self.grpc_client.async_infer(model_name=GPI_MODEL_NAME_SINGLE, inputs=inputs, callback=self.callback)
else:
ctx = self.grpc_client.async_infer(model_name=GPI_MODEL_NAME_OVERALL, inputs=inputs, callback=self.callback)
time_out = 600
while time_out > 0:
gen_product_data, _ = self.read_tasks_status()
@@ -151,6 +166,7 @@ if __name__ == '__main__':
image_strength=0.9,
# prompt=" the best quality, masterpiece. detailed, high-res, simple background, studio photography, extremely detailed, updo, detailed face, face, close-up, HDR, UHD, 8K realistic, Highly detailed, simple background, Studio lighting",
image_url="aida-results/result_00097282-ebb2-11ee-a822-b48351119060.png",
product_type="single"
)
server = GenerateProductImage(rd)
print(server.get_result())

View File

@@ -1,51 +1,51 @@
from app.core.config import LOGS_PATH
LOGGER_CONFIG_DICT = {
"version": 1,
"disable_existing_loggers": False,
"formatters": {
"simple": {"format": "%(asctime)s %(filename)s [line:%(lineno)d] %(levelname)s %(message)s"}
'version': 1,
'disable_existing_loggers': False,
'formatters': {
'simple': {'format': '%(asctime)s %(filename)s [line:%(lineno)d] %(levelname)s %(message)s'}
},
"handlers": {
"console": {
"class": "logging.StreamHandler",
"level": "INFO",
"formatter": "simple",
"stream": "ext://sys.stdout",
'handlers': {
'console': {
'class': 'logging.StreamHandler',
'level': 'INFO',
'formatter': 'simple',
'stream': 'ext://sys.stdout',
},
"info_file_handler": {
"class": "logging.handlers.RotatingFileHandler",
"level": "INFO",
"formatter": "simple",
"filename": f"{LOGS_PATH}info.log",
"maxBytes": 10485760,
"backupCount": 50,
"encoding": "utf8",
'info_file_handler': {
'class': 'logging.handlers.RotatingFileHandler',
'level': 'INFO',
'formatter': 'simple',
'filename': f'{LOGS_PATH}info.log',
'maxBytes': 10485760,
'backupCount': 50,
'encoding': 'utf8',
},
"error_file_handler": {
"class": "logging.handlers.RotatingFileHandler",
"level": "ERROR",
"formatter": "simple",
"filename": f"{LOGS_PATH}error.log",
"maxBytes": 10485760,
"backupCount": 20,
"encoding": "utf8",
'error_file_handler': {
'class': 'logging.handlers.RotatingFileHandler',
'level': 'ERROR',
'formatter': 'simple',
'filename': f'{LOGS_PATH}error.log',
'maxBytes': 10485760,
'backupCount': 20,
'encoding': 'utf8',
},
"debug_file_handler": {
"class": "logging.handlers.RotatingFileHandler",
"level": "DEBUG",
"formatter": "simple",
"filename": f"{LOGS_PATH}debug.log",
"maxBytes": 10485760,
"backupCount": 50,
"encoding": "utf8",
'debug_file_handler': {
'class': 'logging.handlers.RotatingFileHandler',
'level': 'DEBUG',
'formatter': 'simple',
'filename': f'{LOGS_PATH}debug.log',
'maxBytes': 10485760,
'backupCount': 50,
'encoding': 'utf8',
},
},
"loggers": {
"my_module": {"level": "INFO", "handlers": ["console"], "propagate": "no"}
'loggers': {
'my_module': {'level': 'INFO', 'handlers': ['console'], 'propagate': 'no'}
},
"root": {
"level": "INFO",
"handlers": ["error_file_handler", "info_file_handler", "debug_file_handler", "console"],
'root': {
'level': 'INFO',
'handlers': ['error_file_handler', 'info_file_handler', 'debug_file_handler', 'console'],
},
}