Merge remote-tracking branch 'origin/develop' into develop

# Conflicts:
#	app/service/chat_robot/script/main.py
This commit is contained in:
2024-07-08 19:00:08 +08:00
19 changed files with 371 additions and 282 deletions

View File

@@ -34,13 +34,14 @@ def attribute_recognition(request_item: list[AttributeRecognitionModel]):
]
"""
try:
logger.info(f"attribute_recognition request item is : @@@@@@:{request_item}")
for item in request_item:
logger.info(f"attribute_recognition request item is : @@@@@@:{json.dumps(item.dict())}")
if DEBUG:
service = AttributeRecognition(const=local_debug_const, request_data=request_item)
else:
service = AttributeRecognition(const=const, request_data=request_item)
data = service.get_result()
logger.info(f"attribute_recognition response @@@@@@:{json.dumps(data, indent=4)}")
logger.info(f"attribute_recognition response @@@@@@:{json.dumps(data)}")
except Exception as e:
logger.warning(f"attribute_recognition Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))
@@ -65,10 +66,11 @@ def category_recognition(request_item: list[CategoryRecognitionModel]):
]
"""
try:
logger.info(f"category_recognition request item is : @@@@@@:{request_item}")
for item in request_item:
logger.info(f"category_recognition request item is : @@@@@@:{json.dumps(item.dict())}")
service = CategoryRecognition(request_data=request_item)
data = service.get_result()
logger.info(f"category_recognition response @@@@@@:{json.dumps(data, indent=4)}")
logger.info(f"category_recognition response @@@@@@:{json.dumps(data)}")
except Exception as e:
logger.warning(f"category_recognition Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))

View File

@@ -30,9 +30,9 @@ def chat_robot(request_data: ChatRobotModel):
}
"""
try:
logger.info(f"chat_robot request item is : @@@@@@:{request_data}")
logger.info(f"chat_robot request item is : @@@@@@:{json.dumps(request_data.dict())}")
data = chat(post_data=request_data)
logger.info(f"chat_robot response @@@@@@:{json.dumps(data, indent=4)}")
logger.info(f"chat_robot response @@@@@@:{json.dumps(data)}")
except Exception as e:
logger.warning(f"chat_robot Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))

View File

@@ -66,22 +66,57 @@ def design(request_data: DesignModel):
],
"path": "aida-users/89/sketch/c89d75f3-581f-4edd-9f8e-b08e84a2cbe7-3-89.png",
"print": {
"IfSingle": false,
"location": [
[
512.0,
512.0
"single": {
"location": [
[
200.0,
200.0
]
],
"print_angle_list": [
0.0
],
"print_path_list": [
"aida-users/89/slogan_image/ce0b2423-9e5a-466f-9611-c254940a7819-1-89.png"
],
"print_scale_list": [
1.0
]
],
"print_angle_list": [
0.0
],
"print_path_list": [
"aida-users/89/print/468643b4-bc2d-41b2-9a16-79766606a2db-3-89.png"
],
"print_scale_list": [
1.0
]
},
"overall": {
"location": [
[
512.0,
512.0
]
],
"print_angle_list": [
0.0
],
"print_path_list": [
"aida-users/89/print/468643b4-bc2d-41b2-9a16-79766606a2db-3-89.png"
],
"print_scale_list": [
1.0
]
},
"element": {
"element_angle_list": [
0.0
],
"element_path_list": [
"aida-users/88/designelements/Embroidery/a4d9605a-675e-4606-93e0-77ca6baaf55f.png"
],
"element_scale_list": [
0.2731036750637755
],
"location": [
[
228.63694825464364,
406.4843844199667
]
]
}
},
"priority": 10,
"resize_scale": [
@@ -102,9 +137,9 @@ def design(request_data: DesignModel):
}
"""
try:
logger.info(f"design request item is : @@@@@@:{request_data.dict()}")
logger.info(f"design request item is : @@@@@@:{json.dumps(request_data.dict())}")
data = generate(request_data=request_data)
logger.info(f"design response @@@@@@:{json.dumps(data, indent=4)}")
logger.info(f"design response @@@@@@:{json.dumps(data)}")
except Exception as e:
logger.warning(f"design Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))
@@ -124,13 +159,13 @@ def get_progress(request_data: DesignProgressModel):
}
"""
try:
logger.info(f"get_progress request item is : @@@@@@:{request_data.dict()}")
logger.info(f"get_progress request item is : @@@@@@:{json.dumps(request_data.dict())}")
process_id = request_data.process_id
r = Redis()
data = r.read(key=process_id)
if data is None:
raise ValueError(f"No progress ID: {process_id}")
logging.info(f"get_progress process_id @@@@@@ : {process_id} , progress : {data}")
logging.info(f"get_progress process_id @@@@@@ : {process_id} , progress : {json.dumps(data)}")
except Exception as e:
logger.warning(f"get_progress Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))
@@ -150,10 +185,10 @@ def model_process(request_data: ModelProgressModel):
}
"""
try:
logger.info(f"model_process request item is : @@@@@@:{request_data.dict()}")
logger.info(f"model_process request item is : @@@@@@:{json.dumps(request_data.dict())}")
data = model_transpose(image_path=request_data.model_path)
logger.info(f"model_process response @@@@@@:{json.dumps(data, indent=4)}")
logger.info(f"model_process response @@@@@@:{json.dumps(data)}")
except Exception as e:
logger.warning(f"model_process Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))

View File

@@ -30,10 +30,10 @@ def design_pre_processing(request_data: DesignPreProcessingModel):
}
"""
try:
logger.info(f"design_pre_processing request item is : @@@@@@:{request_data}")
logger.info(f"design_pre_processing request item is : @@@@@@:{json.dumps(request_data.dict())}")
server = DesignPreprocessing()
data = server.pipeline(image_list=request_data.sketches)
logger.info(f"design response @@@@@@:{json.dumps(data, indent=4)}")
logger.info(f"design response @@@@@@:{json.dumps(data)}")
except Exception as e:
logger.warning(f"design Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))

View File

@@ -3,7 +3,7 @@ import logging
from fastapi import APIRouter, BackgroundTasks, HTTPException
from app.schemas.generate_image import GenerateImageModel, GenerateProductImageModel, GenerateSingleLogoImageModel
from app.schemas.generate_image import GenerateImageModel, GenerateProductImageModel, GenerateSingleLogoImageModel, GenerateRelightImageModel
from app.schemas.response_template import ResponseModel
from app.service.generate_image.service_generate_image import GenerateImage, infer_cancel as generate_image_infer_cancel
from app.service.generate_image.service_generate_product_image import GenerateProductImage, infer_cancel as generate_product_image_cancel
@@ -38,7 +38,7 @@ def generate_image(request_item: GenerateImageModel, background_tasks: Backgroun
}
"""
try:
logger.info(f"generate_image request item is : @@@@@@:{request_item}")
logger.info(f"generate_image request item is : @@@@@@:{json.dumps(request_item.dict())}")
service = GenerateImage(request_item)
background_tasks.add_task(service.get_result)
except Exception as e:
@@ -52,7 +52,7 @@ def generate_image(tasks_id: str):
try:
logger.info(f"generate_cancel request item is : @@@@@@:{tasks_id}")
data = generate_image_infer_cancel(tasks_id)
logger.info(f"generate_cancel response @@@@@@:{json.dumps(data, indent=4)}")
logger.info(f"generate_cancel response @@@@@@:{data}")
except Exception as e:
logger.warning(f"generate_cancel Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))
@@ -78,7 +78,7 @@ def generate_single_logo(request_item: GenerateSingleLogoImageModel, background_
}
"""
try:
logger.info(f"generate_single_logo request item is : @@@@@@:{request_item}")
logger.info(f"generate_single_logo request item is : @@@@@@:{json.dumps(request_item.dict())}")
service = GenerateSingleLogoImage(request_item)
background_tasks.add_task(service.get_result)
except Exception as e:
@@ -92,7 +92,7 @@ def generate_single_logo_image(tasks_id: str):
try:
logger.info(f"generate_single_logo_cancel request item is : @@@@@@:{tasks_id}")
data = generate_single_logo_cancel(tasks_id)
logger.info(f"generate_single_logo_cancel response @@@@@@:{json.dumps(data, indent=4)}")
logger.info(f"generate_single_logo_cancel response @@@@@@:{data}")
except Exception as e:
logger.warning(f"generate_single_logo_cancel Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))
@@ -109,16 +109,21 @@ def generate_product_image(request_item: GenerateProductImageModel, background_t
- **tasks_id**: 任务id 用于取消生成任务和获取生成结果
- **prompt**: 想要生成图片的描述词
- **image_url**: 被生成图片的S3或minio url地址
- **image_strength**: 生成强度,越低越接近原图
- **product_type**: 输入single item 还是 overall item
示例参数:
{
"tasks_id": "123-89",
"prompt": "the best quality, masterpiece. detailed, high-res, simple background, studio photography, extremely detailed, updo, detailed face, face, close-up, HDR, UHD, 8K realistic, Highly detailed, simple background, Studio lighting",
"image_url": "aida-results/result_00097282-ebb2-11ee-a822-b48351119060.png"
"image_url": "aida-results/result_00097282-ebb2-11ee-a822-b48351119060.png",
"image_strength": 0.8,
"product_type": "overall"
}
"""
try:
logger.info(f"generate_product_image request item is : @@@@@@:{request_item}")
logger.info(f"generate_product_image request item is : @@@@@@:{json.dumps(request_item.dict())}")
service = GenerateProductImage(request_item)
background_tasks.add_task(service.get_result)
except Exception as e:
@@ -132,7 +137,7 @@ def generate_product_image(tasks_id: str):
try:
logger.info(f"generate_product_image_cancel_cancel request item is : @@@@@@:{tasks_id}")
data = generate_product_image_cancel(tasks_id)
logger.info(f"generate_product_image_cancel_cancel response @@@@@@:{json.dumps(data, indent=4)}")
logger.info(f"generate_product_image_cancel_cancel response @@@@@@:{data}")
except Exception as e:
logger.warning(f"generate_product_image_cancel_cancel Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))
@@ -143,22 +148,27 @@ def generate_product_image(tasks_id: str):
@router.post("/generate_relight_image")
def generate_relight_image(request_item: GenerateProductImageModel, background_tasks: BackgroundTasks):
def generate_relight_image(request_item: GenerateRelightImageModel, background_tasks: BackgroundTasks):
"""
创建一个具有以下参数的请求体:
- **tasks_id**: 任务id 用于取消生成任务和获取生成结果
- **prompt**: 想要生成图片的描述词
- **image_url**: 被生成图片的S3或minio url地址
- **direction**: 光源方向 Right Light Left Light Top Light Bottom Light
- **product_type**: 输入single item 还是 overall item
示例参数:
{
"tasks_id": "123-89",
"prompt": "beautiful woman, detailed face, sunshine, outdoor, warm atmosphere",
"image_url": "aida-results/result_0000b606-1902-11ef-9424-0242ac180002.png"
"image_url": "aida-results/result_0000b606-1902-11ef-9424-0242ac180002.png",
"direction": "Right Light",
"product_type": "overall"
}
"""
try:
logger.info(f"generate_relight_image request item is : @@@@@@:{request_item}")
logger.info(f"generate_relight_image request item is : @@@@@@:{json.dumps(request_item.dict())}")
service = GenerateRelightImage(request_item)
background_tasks.add_task(service.get_result)
except Exception as e:
@@ -172,7 +182,7 @@ def generate_relight_image(tasks_id: str):
try:
logger.info(f"generate_relight_image_cancel_cancel request item is : @@@@@@:{tasks_id}")
data = generate_relight_image_cancel(tasks_id)
logger.info(f"generate_relight_image_cancel_cancel response @@@@@@:{json.dumps(data, indent=4)}")
logger.info(f"generate_relight_image_cancel_cancel response @@@@@@:{data}")
except Exception as e:
logger.warning(f"generate_relight_image_cancel_cancel Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))

View File

@@ -27,7 +27,7 @@ def prompt_generation(request_data: PromptGenerationImageModel):
try:
logger.info(f"prompt_generation request item is : @@@@@@:{request_data}")
data = translate_to_en(request_data.text)
logger.info(f"prompt_generation response @@@@@@:{json.dumps(data, indent=4)}")
logger.info(f"prompt_generation response @@@@@@:{data}")
except Exception as e:
logger.warning(f"prompt_generation Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))

View File

@@ -27,7 +27,7 @@ def super_resolution(request_item: SuperResolutionModel, background_tasks: Backg
}
"""
try:
logger.info(f"super_resolution request item is : @@@@@@:{request_item}")
logger.info(f"super_resolution request item is : @@@@@@:{json.dumps(request_item.dict())}")
service = SuperResolution(request_item)
background_tasks.add_task(service.sr_result)
except Exception as e:
@@ -41,7 +41,7 @@ def super_resolution(tasks_id: str):
try:
logger.info(f"sr_cancel request item is : @@@@@@:{tasks_id}")
data = infer_cancel(tasks_id)
logger.info(f"sr_cancel response @@@@@@:{json.dumps(data, indent=4)}")
logger.info(f"sr_cancel response @@@@@@:{data}")
except Exception as e:
logger.warning(f"sr_cancel Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))

View File

@@ -1,8 +1,10 @@
import json
import logging
from fastapi import APIRouter
from app.core.config import SR_RABBITMQ_QUEUES, GI_RABBITMQ_QUEUES, GPI_RABBITMQ_QUEUES, GRI_RABBITMQ_QUEUES, OSS
from fastapi import FastAPI, HTTPException
from fastapi import APIRouter
from fastapi import HTTPException
from app.core.config import SR_RABBITMQ_QUEUES, GI_RABBITMQ_QUEUES, GPI_RABBITMQ_QUEUES, GRI_RABBITMQ_QUEUES, OSS
from app.schemas.response_template import ResponseModel
logger = logging.getLogger()
@@ -18,7 +20,7 @@ def test(id: int):
"GRI_RABBITMQ_QUEUES": GRI_RABBITMQ_QUEUES,
"local_oss_server": OSS
}
logger.info(data)
logger.info(json.dumps(data))
if id == 1:
raise HTTPException(status_code=404, detail="Item not found")

View File

@@ -118,14 +118,17 @@ GSL_MINIO_BUCKET = "aida-users"
GSL_MODEL_NAME = 'stable_diffusion_xl_transparent'
GEN_SINGLE_LOGO_RABBITMQ_QUEUES = os.getenv("GEN_SINGLE_LOGO_RABBITMQ_QUEUES", f"GenSingleLogo{RABBITMQ_ENV}")
# Generate Single Logo service config
# Generate Product service config
GPI_RABBITMQ_QUEUES = os.getenv("GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES", f"ToProductImage{RABBITMQ_ENV}")
GPI_MODEL_NAME = 'diffusion_ensemble_all'
GPI_MODEL_NAME_OVERALL = 'diffusion_ensemble_all'
GPI_MODEL_NAME_SINGLE = 'stable_diffusion_1_5_cnet'
GPI_MODEL_URL = '10.1.1.240:10041'
# Generate Single Logo service config
GRI_RABBITMQ_QUEUES = os.getenv("GEN_RELIGHT_IMAGE_RABBITMQ_QUEUES", f"Relight{RABBITMQ_ENV}")
GRI_MODEL_NAME = 'diffusion_relight_ensemble'
GRI_MODEL_NAME_OVERALL = 'diffusion_relight_ensemble'
GRI_MODEL_NAME_SINGLE = 'stable_diffusion_1_5_relight'
GRI_MODEL_URL = '10.1.1.240:10051'
# SEG service config

View File

@@ -20,9 +20,13 @@ class GenerateProductImageModel(BaseModel):
tasks_id: str
prompt: str
image_url: str
image_strength: float
product_type: str
class GenerateRelightImageModel(BaseModel):
tasks_id: str
prompt: str
image_url: str
direction: str
product_type: str

View File

@@ -1,26 +1,24 @@
import json
import logging
from langchain_community.chat_models import ChatTongyi
from loguru import logger
from langchain.agents import Tool
from langchain_community.utilities import SerpAPIWrapper
from langchain.callbacks import FileCallbackHandler
from langchain.utilities import SerpAPIWrapper
from langchain.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate, MessagesPlaceholder
from langchain.schema import SystemMessage, AIMessage
# from langchain_community.chat_models import ChatOpenAI
# from langchain_community.llms import OpenAI
from langchain.chat_models import ChatOpenAI
from langchain.llms.openai import OpenAI
from langchain.callbacks import FileCallbackHandler
from app.service.chat_robot.script.agents import CustomAgentExecutor, ConversationalFunctionsAgent
from app.service.chat_robot.script.callbacks import OpenAITokenRecordCallbackHandler
from app.service.chat_robot.script.database import CustomDatabase
from app.service.chat_robot.script.memory import UserConversationBufferWindowMemory
from app.service.chat_robot.script.prompt import FASHION_CHAT_BOT_PREFIX, TOOLS_FUNCTIONS_SUFFIX
from app.service.chat_robot.script.service import CallQWen
from app.service.chat_robot.script.tools import (QuerySQLDataBaseTool, InfoSQLDatabaseTool, QuerySQLCheckerTool, ListSQLDatabaseTool)
from app.service.chat_robot.script.memory import UserConversationBufferWindowMemory
from app.service.chat_robot.script.tools.tutorial_tool import CustomTutorialTool
from app.core.config import *
import os
# os.environ["http_proxy"] = "http://127.0.0.1:7890"
# os.environ["https_proxy"] = "http://127.0.0.1:7890"
# log callbacks
@@ -138,5 +136,5 @@ def chat(post_data):
'completion_tokens': final_outputs['completion_tokens'],
'response_type': final_outputs["response_type"]
}
logging.info(api_response)
logging.info(json.dumps(api_response))
return api_response

View File

@@ -36,7 +36,9 @@ class Clothing(object):
position=start_point,
resize_scale=self.result["resize_scale"],
mask=cv2.resize(self.result['mask'], self.result["front_image"].size),
gradient_string=self.result['gradient_string'] if 'gradient_string' in self.result.keys() else ""
gradient_string=self.result['gradient_string'] if 'gradient_string' in self.result.keys() else "",
pattern_image_url=self.result['pattern_image_url']
)
layer.insert(front_layer)
@@ -51,7 +53,8 @@ class Clothing(object):
position=start_point,
resize_scale=self.result["resize_scale"],
mask=cv2.resize(self.result['mask'], self.result["front_image"].size),
gradient_string=self.result['gradient_string'] if 'gradient_string' in self.result.keys() else ""
gradient_string=self.result['gradient_string'] if 'gradient_string' in self.result.keys() else "",
pattern_image_url=self.result['pattern_image_url']
)
layer.insert(back_layer)

View File

@@ -88,98 +88,113 @@ class PrintPainting(object):
# @ RunTime
def __call__(self, result):
single_print = result['print']['single']
overall_print = result['print']['overall']
element_print = result['print']['element']
result['single_image'] = None
result['print_image'] = None
if overall_print['print_path_list']:
painting_dict = {'dim_image_h': result['pattern_image'].shape[0], 'dim_image_w': result['pattern_image'].shape[1]}
result['print_image'] = result['pattern_image']
if "print_angle_list" in overall_print.keys() and overall_print['print_angle_list'][0] != 0:
painting_dict = self.painting_collection(painting_dict, overall_print, print_trigger=True)
painting_dict['tile_print'] = self.rotate_crop_image(img=painting_dict['tile_print'], angle=-result['print']['print_angle_list'][0], crop=True)
painting_dict['mask_inv_print'] = self.rotate_crop_image(img=painting_dict['mask_inv_print'], angle=-result['print']['print_angle_list'][0], crop=True)
if "location" not in result['print'].keys():
result['print']["location"] = [[0, 0]]
elif result['print']["location"] == [] or result['print']["location"] is None:
result['print']["location"] = [[0, 0]]
if result['print']['IfSingle']:
if len(result['print']['print_path_list']) > 0:
print_background = np.zeros((result['pattern_image'].shape[0], result['pattern_image'].shape[1], 3), dtype=np.uint8)
mask_background = np.zeros((result['pattern_image'].shape[0], result['pattern_image'].shape[1], 3), dtype=np.uint8)
# print_background = np.full((result['pattern_image'].shape[0], result['pattern_image'].shape[1], 3), 255, dtype=np.uint8)
for i in range(len(result['print']['print_path_list'])):
image, image_mode = self.read_image(result['print']['print_path_list'][i])
if image_mode == "RGBA":
new_size = (int(image.width * result['print']['print_scale_list'][i]), int(image.height * result['print']['print_scale_list'][i]))
# resize 到sketch大小
painting_dict['tile_print'] = self.resize_and_crop(img=painting_dict['tile_print'], target_width=painting_dict['dim_image_w'], target_height=painting_dict['dim_image_h'])
painting_dict['mask_inv_print'] = self.resize_and_crop(img=painting_dict['mask_inv_print'], target_width=painting_dict['dim_image_w'], target_height=painting_dict['dim_image_h'])
else:
painting_dict = self.painting_collection(painting_dict, overall_print, print_trigger=True, is_single=False)
result['print_image'] = self.printpaint(result, painting_dict, print_=True)
result['single_image'] = result['final_image'] = result['pattern_image'] = result['print_image']
mask = image.split()[3]
resized_source = image.resize(new_size)
resized_source_mask = mask.resize(new_size)
if single_print['print_path_list']:
print_background = np.zeros((result['pattern_image'].shape[0], result['pattern_image'].shape[1], 3), dtype=np.uint8)
mask_background = np.zeros((result['pattern_image'].shape[0], result['pattern_image'].shape[1], 3), dtype=np.uint8)
for i in range(len(single_print['print_path_list'])):
image, image_mode = self.read_image(single_print['print_path_list'][i])
if image_mode == "RGBA":
new_size = (int(image.width * single_print['print_scale_list'][i]), int(image.height * single_print['print_scale_list'][i]))
rotated_resized_source = resized_source.rotate(-result['print']['print_angle_list'][i])
rotated_resized_source_mask = resized_source_mask.rotate(-result['print']['print_angle_list'][i])
mask = image.split()[3]
resized_source = image.resize(new_size)
resized_source_mask = mask.resize(new_size)
source_image_pil = Image.fromarray(cv2.cvtColor(print_background, cv2.COLOR_BGR2RGB))
source_image_pil_mask = Image.fromarray(cv2.cvtColor(mask_background, cv2.COLOR_BGR2RGB))
rotated_resized_source = resized_source.rotate(-single_print['print_angle_list'][i])
rotated_resized_source_mask = resized_source_mask.rotate(-single_print['print_angle_list'][i])
source_image_pil.paste(rotated_resized_source, (int(result['print']['location'][i][0]), int(result['print']['location'][i][1])), rotated_resized_source)
source_image_pil_mask.paste(rotated_resized_source_mask, (int(result['print']['location'][i][0]), int(result['print']['location'][i][1])), rotated_resized_source_mask)
source_image_pil = Image.fromarray(cv2.cvtColor(print_background, cv2.COLOR_BGR2RGB))
source_image_pil_mask = Image.fromarray(cv2.cvtColor(mask_background, cv2.COLOR_BGR2RGB))
print_background = cv2.cvtColor(np.array(source_image_pil), cv2.COLOR_RGBA2BGR)
mask_background = cv2.cvtColor(np.array(source_image_pil_mask), cv2.COLOR_RGBA2BGR)
source_image_pil.paste(rotated_resized_source, (int(single_print['location'][i][0]), int(single_print['location'][i][1])), rotated_resized_source)
source_image_pil_mask.paste(rotated_resized_source_mask, (int(single_print['location'][i][0]), int(single_print['location'][i][1])), rotated_resized_source_mask)
print_background = cv2.cvtColor(np.array(source_image_pil), cv2.COLOR_RGBA2BGR)
mask_background = cv2.cvtColor(np.array(source_image_pil_mask), cv2.COLOR_RGBA2BGR)
ret, mask_background = cv2.threshold(mask_background, 124, 255, cv2.THRESH_BINARY)
else:
mask = self.get_mask_inv(image)
mask = np.expand_dims(mask, axis=2)
mask = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR)
mask = cv2.bitwise_not(mask)
# 旋转后的坐标需要重新算
rotate_mask, _ = self.img_rotate(mask, single_print['print_angle_list'][i], single_print['print_scale_list'][i])
rotate_image, rotated_new_size = self.img_rotate(image, single_print['print_angle_list'][i], single_print['print_scale_list'][i])
# x, y = int(result['print']['location'][i][0] - rotated_new_size[0] - (rotate_mask.shape[0] - image.shape[0]) / 2), int(result['print']['location'][i][1] - rotated_new_size[1] - (rotate_mask.shape[1] - image.shape[1]) / 2)
x, y = int(single_print['location'][i][0] - rotated_new_size[0]), int(single_print['location'][i][1] - rotated_new_size[1])
image_x = print_background.shape[1]
image_y = print_background.shape[0]
print_x = rotate_image.shape[1]
print_y = rotate_image.shape[0]
# 有bug
# if x + print_x > image_x:
# rotate_image = rotate_image[:, :x + print_x - image_x]
# rotate_mask = rotate_mask[:, :x + print_x - image_x]
# #
# if y + print_y > image_y:
# rotate_image = rotate_image[:y + print_y - image_y]
# rotate_mask = rotate_mask[:y + print_y - image_y]
# 不能是并行
# 当前第一轮的if 108以及115是判断有没有过下界和右界。第二轮的是判断左上有没有超出。 如果这个样子的话先裁了右边再左移region就会有问题
# 先挪 再判断 最后裁剪
# 如果print旋转了 或者 print贴边了 则需要判断 判断左界和上界是否小于0
if x <= 0:
rotate_image = rotate_image[:, -x:]
rotate_mask = rotate_mask[:, -x:]
start_x = x = 0
else:
mask = self.get_mask_inv(image)
mask = np.expand_dims(mask, axis=2)
mask = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR)
mask = cv2.bitwise_not(mask)
# 旋转后的坐标需要重新算
rotate_mask, _ = self.img_rotate(mask, result['print']['print_angle_list'][i], result['print']['print_scale_list'][i])
rotate_image, rotated_new_size = self.img_rotate(image, result['print']['print_angle_list'][i], result['print']['print_scale_list'][i])
# x, y = int(result['print']['location'][i][0] - rotated_new_size[0] - (rotate_mask.shape[0] - image.shape[0]) / 2), int(result['print']['location'][i][1] - rotated_new_size[1] - (rotate_mask.shape[1] - image.shape[1]) / 2)
x, y = int(result['print']['location'][i][0] - rotated_new_size[0]), int(result['print']['location'][i][1] - rotated_new_size[1])
start_x = x
image_x = print_background.shape[1]
image_y = print_background.shape[0]
print_x = rotate_image.shape[1]
print_y = rotate_image.shape[0]
if y <= 0:
rotate_image = rotate_image[-y:, :]
rotate_mask = rotate_mask[-y:, :]
start_y = y = 0
else:
start_y = y
# 有bug
# if x + print_x > image_x:
# rotate_image = rotate_image[:, :x + print_x - image_x]
# rotate_mask = rotate_mask[:, :x + print_x - image_x]
# #
# if y + print_y > image_y:
# rotate_image = rotate_image[:y + print_y - image_y]
# rotate_mask = rotate_mask[:y + print_y - image_y]
# ------------------
# 如果print-size大于image-size 则需要裁剪print
# 不能是并行
# 当前第一轮的if 108以及115是判断有没有过下界和右界。第二轮的是判断左上有没有超出。 如果这个样子的话先裁了右边再左移region就会有问题
# 先挪 再判断 最后裁剪
if x + print_x > image_x:
rotate_image = rotate_image[:, :image_x - x]
rotate_mask = rotate_mask[:, :image_x - x]
# 如果print旋转了 或者 print贴边了 则需要判断 判断左界和上界是否小于0
if x <= 0:
rotate_image = rotate_image[:, -x:]
rotate_mask = rotate_mask[:, -x:]
start_x = x = 0
else:
start_x = x
if y + print_y > image_y:
rotate_image = rotate_image[:image_y - y, :]
rotate_mask = rotate_mask[:image_y - y, :]
if y <= 0:
rotate_image = rotate_image[-y:, :]
rotate_mask = rotate_mask[-y:, :]
start_y = y = 0
else:
start_y = y
# mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]] = cv2.bitwise_xor(mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]], rotate_mask)
# print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]] = cv2.add(print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]], rotate_image)
# ------------------
# 如果print-size大于image-size 则需要裁剪print
if x + print_x > image_x:
rotate_image = rotate_image[:, :image_x - x]
rotate_mask = rotate_mask[:, :image_x - x]
if y + print_y > image_y:
rotate_image = rotate_image[:image_y - y, :]
rotate_mask = rotate_mask[:image_y - y, :]
# mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]] = cv2.bitwise_xor(mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]], rotate_mask)
# print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]] = cv2.add(print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]], rotate_image)
# mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]] = rotate_mask
# print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]] = rotate_image
mask_background = self.stack_prin(mask_background, result['pattern_image'], rotate_mask, start_y, y, start_x, x)
print_background = self.stack_prin(print_background, result['pattern_image'], rotate_image, start_y, y, start_x, x)
# mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]] = rotate_mask
# print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]] = rotate_image
mask_background = self.stack_prin(mask_background, result['pattern_image'], rotate_mask, start_y, y, start_x, x)
print_background = self.stack_prin(print_background, result['pattern_image'], rotate_image, start_y, y, start_x, x)
# gray_image = cv2.cvtColor(mask_background, cv2.COLOR_BGR2GRAY)
# print_background = cv2.bitwise_and(print_background, print_background, mask=gray_image)
@@ -197,54 +212,27 @@ class PrintPainting(object):
temp_fg = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2)
tmp2 = (result['final_image'] * (temp_fg / 255)).astype(np.uint8)
result['single_image'] = cv2.add(tmp1, tmp2)
else:
painting_dict = {}
painting_dict['dim_image_h'], painting_dict['dim_image_w'] = result['pattern_image'].shape[0:2]
# no print
if len(result['print_dict']['print_path_list']) == 0 or not self.print_flag:
result['print_image'] = result['pattern_image']
# print
else:
if "print_angle_list" in result['print'].keys() and result['print']['print_angle_list'][0] != 0:
painting_dict = self.painting_collection(painting_dict, result, print_trigger=True)
painting_dict['tile_print'] = self.rotate_crop_image(img=painting_dict['tile_print'], angle=-result['print']['print_angle_list'][0], crop=True)
painting_dict['mask_inv_print'] = self.rotate_crop_image(img=painting_dict['mask_inv_print'], angle=-result['print']['print_angle_list'][0], crop=True)
# resize 到sketch大小
painting_dict['tile_print'] = self.resize_and_crop(img=painting_dict['tile_print'], target_width=painting_dict['dim_image_w'], target_height=painting_dict['dim_image_h'])
painting_dict['mask_inv_print'] = self.resize_and_crop(img=painting_dict['mask_inv_print'], target_width=painting_dict['dim_image_w'], target_height=painting_dict['dim_image_h'])
else:
painting_dict = self.painting_collection(painting_dict, result, print_trigger=True)
result['print_image'] = self.printpaint(result, painting_dict, print_=True)
result['final_image'] = result['print_image']
canvas = np.full_like(result['final_image'], 255)
temp_bg = np.expand_dims(cv2.bitwise_not(result['mask']), axis=2).repeat(3, axis=2)
tmp1 = (canvas * (temp_bg / 255)).astype(np.uint8)
temp_fg = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2)
tmp2 = (result['final_image'] * (temp_fg / 255)).astype(np.uint8)
result['single_image'] = cv2.add(tmp1, tmp2)
if "element" in result.keys():
if element_print['element_path_list']:
print_background = np.zeros((result['final_image'].shape[0], result['final_image'].shape[1], 3), dtype=np.uint8)
mask_background = np.zeros((result['final_image'].shape[0], result['final_image'].shape[1], 3), dtype=np.uint8)
for i in range(len(result['element']['element_path_list'])):
image, image_mode = self.read_image(result['element']['element_path_list'][i])
for i in range(len(element_print['element_path_list'])):
image, image_mode = self.read_image(element_print['element_path_list'][i])
if image_mode == "RGBA":
new_size = (int(image.width * result['element']['element_scale_list'][i]), int(image.height * result['element']['element_scale_list'][i]))
new_size = (int(image.width * element_print['element_scale_list'][i]), int(image.height * element_print['element_scale_list'][i]))
mask = image.split()[3]
resized_source = image.resize(new_size)
resized_source_mask = mask.resize(new_size)
rotated_resized_source = resized_source.rotate(-result['element']['element_angle_list'][i])
rotated_resized_source_mask = resized_source_mask.rotate(-result['element']['element_angle_list'][i])
rotated_resized_source = resized_source.rotate(-element_print['element_angle_list'][i])
rotated_resized_source_mask = resized_source_mask.rotate(-element_print['element_angle_list'][i])
source_image_pil = Image.fromarray(cv2.cvtColor(print_background, cv2.COLOR_BGR2RGB))
source_image_pil_mask = Image.fromarray(cv2.cvtColor(mask_background, cv2.COLOR_BGR2RGB))
source_image_pil.paste(rotated_resized_source, (int(result['element']['location'][i][0]), int(result['element']['location'][i][1])), rotated_resized_source)
source_image_pil_mask.paste(rotated_resized_source_mask, (int(result['element']['location'][i][0]), int(result['element']['location'][i][1])), rotated_resized_source_mask)
source_image_pil.paste(rotated_resized_source, (int(element_print['location'][i][0]), int(element_print['location'][i][1])), rotated_resized_source)
source_image_pil_mask.paste(rotated_resized_source_mask, (int(element_print['location'][i][0]), int(element_print['location'][i][1])), rotated_resized_source_mask)
print_background = cv2.cvtColor(np.array(source_image_pil), cv2.COLOR_RGBA2BGR)
mask_background = cv2.cvtColor(np.array(source_image_pil_mask), cv2.COLOR_RGBA2BGR)
@@ -255,10 +243,10 @@ class PrintPainting(object):
mask = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR)
mask = cv2.bitwise_not(mask)
# 旋转后的坐标需要重新算
rotate_mask, _ = self.img_rotate(mask, result['element']['element_angle_list'][i], result['element']['element_scale_list'][i])
rotate_image, rotated_new_size = self.img_rotate(image, result['element']['element_angle_list'][i], result['element']['element_scale_list'][i])
rotate_mask, _ = self.img_rotate(mask, element_print['element_angle_list'][i], element_print['element_scale_list'][i])
rotate_image, rotated_new_size = self.img_rotate(image, element_print['element_angle_list'][i], element_print['element_scale_list'][i])
# x, y = int(result['print']['location'][i][0] - rotated_new_size[0] - (rotate_mask.shape[0] - image.shape[0]) / 2), int(result['print']['location'][i][1] - rotated_new_size[1] - (rotate_mask.shape[1] - image.shape[1]) / 2)
x, y = int(result['element']['location'][i][0] - rotated_new_size[0]), int(result['element']['location'][i][1] - rotated_new_size[1])
x, y = int(element_print['location'][i][0] - rotated_new_size[0]), int(element_print['location'][i][1] - rotated_new_size[1])
image_x = print_background.shape[1]
image_y = print_background.shape[0]
@@ -352,18 +340,18 @@ class PrintPainting(object):
return print_background
def painting_collection(self, painting_dict, result, print_trigger=False):
def painting_collection(self, painting_dict, print_dict, print_trigger=False, is_single=False):
if print_trigger:
print_ = self.get_print(result['print_dict'])
painting_dict['Trigger'] = not print_['IfSingle']
painting_dict['location'] = print_['location'] if 'location' in print_.keys() else None
print_ = self.get_print(print_dict)
painting_dict['Trigger'] = not is_single
painting_dict['location'] = print_['location']
single_mask_inv_print = self.get_mask_inv(print_['image'])
dim_max = max(painting_dict['dim_image_h'], painting_dict['dim_image_w'])
dim_pattern = (int(dim_max * print_['scale'] / 5), int(dim_max * print_['scale'] / 5))
if not print_['IfSingle']:
if not is_single:
self.random_seed = random.randint(0, 1000)
# 如果print 模式为overall 且 有角度的话 组合的print为正方形方便裁剪
if "print_angle_list" in result['print'].keys() and result['print']['print_angle_list'][0] != 0:
if "print_angle_list" in print_dict.keys() and print_dict['print_angle_list'][0] != 0:
painting_dict['mask_inv_print'] = self.tile_image(single_mask_inv_print, dim_pattern, print_['scale'], dim_max, dim_max, painting_dict['location'], trigger=True)
painting_dict['tile_print'] = self.tile_image(print_['image'], dim_pattern, print_['scale'], dim_max, dim_max, painting_dict['location'], trigger=True)
else:
@@ -458,19 +446,11 @@ class PrintPainting(object):
@staticmethod
def get_print(print_dict):
if not 'print_scale_list' in print_dict.keys() or print_dict['print_scale_list'][0] < 0.3:
if 'print_scale_list' not in print_dict.keys() or print_dict['print_scale_list'][0] < 0.3:
print_dict['scale'] = 0.3
else:
print_dict['scale'] = print_dict['print_scale_list'][0]
if not 'IfSingle' in print_dict.keys():
print_dict['IfSingle'] = False
# data = minio_client.get_object(print_dict['print_path_list'][0].split("/", 1)[0], print_dict['print_path_list'][0].split("/", 1)[1])
# data_bytes = BytesIO(data.read())
# image = Image.open(data_bytes)
# image_mode = image.mode
bucket_name = print_dict['print_path_list'][0].split("/", 1)[0]
object_name = print_dict['print_path_list'][0].split("/", 1)[1]
image = oss_get_image(bucket=bucket_name, object_name=object_name, data_type="PIL")
@@ -480,13 +460,6 @@ class PrintPainting(object):
new_background.paste(image, mask=image.split()[3])
image = new_background
print_dict['image'] = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
# file = minio_client.get_object(print_dict['print_path_list'][0].split("/", 1)[0], print_dict['print_path_list'][0].split("/", 1)[1]).data
# print_dict['image'] = cv2.imdecode(np.fromstring(file, np.uint8), 1)
# image = cv2.imdecode(np.frombuffer(file, np.uint8), 1)
# return image
return print_dict
def crop_image(self, image, image_size_h, image_size_w, location, print_shape):

View File

@@ -71,6 +71,11 @@ class Split(object):
result["back_image_url"] = None
result["back_mask_url"] = None
result['back_mask_image'] = None
# 创建中间图层
result_pattern_image_rgba = rgb_to_rgba((result['pattern_image'].shape[0], result['pattern_image'].shape[1]), result['pattern_image'], result['mask'])
result_pattern_image_pil = Image.fromarray(cvtColor(result_pattern_image_rgba, COLOR_BGR2RGBA))
_, result['pattern_image_url'], _ = upload_png_mask(result_pattern_image_pil, f'{generate_uuid()}')
return result
except Exception as e:
logging.warning(f"split runtime exception : {e} image_id : {result['image_id']}")

View File

@@ -1,10 +1,10 @@
import concurrent.futures
from app.core.config import PRIORITY_DICT
from app.service.design.core.layer import Layer
from app.service.design.items import build_item
from app.service.design.utils.redis_utils import Redis
from app.service.design.utils.synthesis_item import synthesis, synthesis_single
import concurrent.futures
from app.service.utils.decorator import RunTime
@@ -96,6 +96,7 @@ def process_object(cfg, process_id, total):
'gradient_string': lay['gradient_string'] if 'gradient_string' in lay.keys() else "",
'mask_url': lay['mask_url'],
'image_url': lay['image_url'] if 'image_url' in lay.keys() else None,
'pattern_image_url': lay['pattern_image_url'] if 'pattern_image_url' in lay.keys() else None,
# 'image': lay['image'],
# 'mask_image': lay['mask_image'],

View File

@@ -40,16 +40,17 @@ class GenerateImage:
if request_data.mode == "img2img":
# cv2 读图片是BGR PIL读图片是RGB
self.image = self.get_image(request_data.image_url)
self.prompt = request_data.prompt
else:
self.image = np.random.randint(0, 256, (1024, 1024, 3), dtype=np.uint8)
self.prompt = request_data.prompt
self.prompt = request_data.prompt
self.tasks_id = request_data.tasks_id
self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:]
self.mode = request_data.mode
self.batch_size = 1
self.category = request_data.category
if self.category == "sketch":
self.prompt = f"{self.category},{self.prompt}"
self.index = 0
self.gender = request_data.gender
self.generate_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'image_url': '', 'category': ''}

View File

@@ -7,17 +7,17 @@
@Date 2023/7/26 12:01:05
@detail
"""
import io
import json
import logging
import time
import cv2
import numpy as np
import redis
import tritonclient.grpc as grpcclient
import numpy as np
from PIL import Image, ImageOps
from minio import Minio
from tritonclient.utils import np_to_triton_dtype
from app.core.config import *
from app.schemas.generate_image import GenerateProductImageModel
from app.service.generate_image.utils.upload_sd_image import upload_SDXL_image
@@ -37,7 +37,9 @@ class GenerateProductImage:
self.grpc_client = grpcclient.InferenceServerClient(url=GPI_MODEL_URL)
self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
self.category = "product_image"
self.image_strength = request_data.image_strength
self.batch_size = 1
self.product_type = request_data.product_type
self.prompt = request_data.prompt
self.image, self.image_size = pre_processing_image(request_data.image_url)
self.tasks_id = request_data.tasks_id
@@ -53,7 +55,10 @@ class GenerateProductImage:
self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data))
else:
# pil图像转成numpy数组
image = result.as_numpy("generated_inpaint_image")
if self.product_type == "single":
image = result.as_numpy("generated_cnet_image")
else:
image = result.as_numpy("generated_inpaint_image")
image_result = Image.fromarray(np.squeeze(image.astype(np.uint8))).resize(self.image_size)
image_url = upload_SDXL_image(image_result, user_id=self.user_id, category=f"{self.category}", file_name=f"{self.tasks_id}.png")
self.gen_product_data['status'] = "SUCCESS"
@@ -72,17 +77,31 @@ class GenerateProductImage:
self.image = cv2.resize(self.image, (512, 768))
images = [self.image.astype(np.uint8)] * self.batch_size
text_obj = np.array(prompts, dtype="object").reshape(1)
image_obj = np.array(images, dtype=np.uint8).reshape((768, 512, 3))
if self.product_type == "single":
text_obj = np.array(prompts, dtype="object").reshape(-1, 1)
image_obj = np.array(images, dtype=np.uint8).reshape((-1, 768, 512, 3))
image_strength_obj = np.array(self.image_strength, dtype=np.float32).reshape(-1, 1)
else:
text_obj = np.array(prompts, dtype="object").reshape(1)
image_obj = np.array(images, dtype=np.uint8).reshape((768, 512, 3))
image_strength_obj = np.array(self.image_strength, dtype=np.float32).reshape((1))
# 假设 prompts、images 和 self.image_strength 已经定义
input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype))
input_image = grpcclient.InferInput("input_image", image_obj.shape, "UINT8")
input_image_strength = grpcclient.InferInput("image_strength", image_strength_obj.shape, np_to_triton_dtype(image_strength_obj.dtype))
input_text.set_data_from_numpy(text_obj)
input_image.set_data_from_numpy(image_obj)
inputs = [input_text, input_image]
inputs = [input_text, input_image, input_image_strength]
input_image_strength.set_data_from_numpy(image_strength_obj)
if self.product_type == "single":
ctx = self.grpc_client.async_infer(model_name=GPI_MODEL_NAME_SINGLE, inputs=inputs, callback=self.callback)
else:
ctx = self.grpc_client.async_infer(model_name=GPI_MODEL_NAME_OVERALL, inputs=inputs, callback=self.callback)
ctx = self.grpc_client.async_infer(model_name=GPI_MODEL_NAME, inputs=inputs, callback=self.callback)
time_out = 600
while time_out > 0:
gen_product_data, _ = self.read_tasks_status()
@@ -117,24 +136,39 @@ def infer_cancel(tasks_id):
def pre_processing_image(image_url):
image = oss_get_image(bucket=image_url.split('/')[0], object_name=image_url[image_url.find('/') + 1:], data_type="PIL")
# 原始图片的尺寸
width, height = image.size
# resize 图片内尺寸 并贴到768-512的纯白图像上
target_height = 768
target_width = 512
aspect_ratio = image.width / image.height
new_width = int(target_height * aspect_ratio)
resized_image = image.resize((new_width, target_height))
left = (target_width - resized_image.width) // 2
top = (target_height - resized_image.height) // 2
right = target_width - resized_image.width - left
bottom = target_height - resized_image.height - top
image = ImageOps.expand(resized_image, (left, top, right, bottom), fill="white")
image_size = image.size
if image.mode in ('RGBA', 'LA') or (image.mode == 'P' and 'transparency' in image.info):
# 计算长宽比为 3:2 的新尺寸
desired_ratio = 2 / 3
current_ratio = width / height
if current_ratio > desired_ratio:
# 原始图片更宽,需要在上下添加 padding
new_width = width
new_height = int(width / desired_ratio)
else:
# 原始图片更高或者长宽比已经为 3:2
new_height = height
new_width = int(height * desired_ratio)
# 创建一个新的画布,大小为添加 padding 后的尺寸,并设置为白色背景
pad_image = Image.new('RGBA', (new_width, new_height), (0, 0, 0, 0))
# 将原始图片粘贴到新的画布中心
left = (new_width - width) // 2
top = (new_height - height) // 2
pad_image.paste(image, (left, top))
# 将画布 resize 成宽度 500长度 750
resized_image = pad_image.resize((500, 750))
image_size = (512, 768)
if resized_image.mode in ('RGBA', 'LA') or (resized_image.mode == 'P' and 'transparency' in resized_image.info):
# 创建白色背景
background = Image.new("RGB", image.size, (255, 255, 255))
background = Image.new("RGB", image_size, (255, 255, 255))
# 将图片粘贴到白色背景上
background.paste(image, mask=image.split()[3])
background.paste(resized_image, mask=resized_image.split()[3])
image = np.array(background)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
return image, image_size
@@ -143,9 +177,11 @@ def pre_processing_image(image_url):
if __name__ == '__main__':
rd = GenerateProductImageModel(
tasks_id="123-89",
prompt="",
# prompt=" the best quality, masterpiece. detailed, high-res, simple background, studio photography, extremely detailed, updo, detailed face, face, close-up, HDR, UHD, 8K realistic, Highly detailed, simple background, Studio lighting",
# prompt="",
image_strength=0.9,
prompt=" the best quality, masterpiece. detailed, high-res, simple background, studio photography, extremely detailed, updo, detailed face, face, close-up, HDR, UHD, 8K realistic, Highly detailed, simple background, Studio lighting",
image_url="aida-results/result_00097282-ebb2-11ee-a822-b48351119060.png",
product_type="overall"
)
server = GenerateProductImage(rd)
print(server.get_result())

View File

@@ -7,16 +7,15 @@
@Date 2023/7/26 12:01:05
@detail
"""
import io
import json
import logging
import time
import cv2
import numpy as np
import redis
import tritonclient.grpc as grpcclient
import numpy as np
from PIL import Image, ImageOps
from minio import Minio
from PIL import Image
from tritonclient.utils import np_to_triton_dtype
from app.core.config import *
@@ -39,8 +38,9 @@ class GenerateRelightImage:
self.batch_size = 1
self.prompt = request_data.prompt
self.seed = "1"
self.product_type = request_data.product_type
self.negative_prompt = 'lowres, bad anatomy, bad hands, cropped, worst quality'
self.direction = "Right Light"
self.direction = request_data.direction
self.image_url = request_data.image_url
self.image = oss_get_image(bucket=self.image_url.split('/')[0], object_name=self.image_url[self.image_url.find('/') + 1:], data_type="cv2")
self.tasks_id = request_data.tasks_id
@@ -56,7 +56,11 @@ class GenerateRelightImage:
self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data))
else:
# pil图像转成numpy数组
image = result.as_numpy("generated_inpaint_image")
if self.product_type == 'single':
image = result.as_numpy("generated_relight_image")
else:
image = result.as_numpy("generated_inpaint_image")
image_result = Image.fromarray(np.squeeze(image.astype(np.uint8)))
image_url = upload_SDXL_image(image_result, user_id=self.user_id, category=f"{self.category}", file_name=f"{self.tasks_id}.png")
@@ -79,11 +83,18 @@ class GenerateRelightImage:
nagetive_prompts = [self.negative_prompt] * self.batch_size
directions = [self.direction] * self.batch_size
text_obj = np.array(prompts, dtype="object").reshape((1))
image_obj = np.array(images, dtype=np.uint8).reshape((768, 512, 3))
na_text_obj = np.array(nagetive_prompts, dtype="object").reshape((1))
seed_obj = np.array(seeds, dtype="object").reshape((1))
direction_obj = np.array(directions, dtype="object").reshape((1))
if self.product_type == 'single':
text_obj = np.array(prompts, dtype="object").reshape((-1, 1))
image_obj = np.array(images, dtype=np.uint8).reshape((-1, 768, 512, 3))
na_text_obj = np.array(nagetive_prompts, dtype="object").reshape((-1, 1))
seed_obj = np.array(seeds, dtype="object").reshape((-1, 1))
direction_obj = np.array(directions, dtype="object").reshape((-1, 1))
else:
text_obj = np.array(prompts, dtype="object").reshape((1))
image_obj = np.array(images, dtype=np.uint8).reshape((768, 512, 3))
na_text_obj = np.array(nagetive_prompts, dtype="object").reshape((1))
seed_obj = np.array(seeds, dtype="object").reshape((1))
direction_obj = np.array(directions, dtype="object").reshape((1))
input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype))
input_image = grpcclient.InferInput("input_image", image_obj.shape, "UINT8")
@@ -98,8 +109,11 @@ class GenerateRelightImage:
input_direction.set_data_from_numpy(direction_obj)
inputs = [input_text, input_natext, input_image, input_seed, input_direction]
if self.product_type == 'single':
ctx = self.grpc_client.async_infer(model_name=GRI_MODEL_NAME_SINGLE, inputs=inputs, callback=self.callback)
else:
ctx = self.grpc_client.async_infer(model_name=GRI_MODEL_NAME_OVERALL, inputs=inputs, callback=self.callback)
ctx = self.grpc_client.async_infer(model_name=GRI_MODEL_NAME, inputs=inputs, callback=self.callback)
time_out = 600
while time_out > 0:
gen_product_data, _ = self.read_tasks_status()
@@ -137,7 +151,9 @@ if __name__ == '__main__':
tasks_id="123-89",
# prompt="beautiful woman, detailed face, sunshine, outdoor, warm atmosphere",
prompt="Colorful black",
image_url='aida-results/result_0000b606-1902-11ef-9424-0242ac180002.png'
image_url='aida-results/result_0000b606-1902-11ef-9424-0242ac180002.png',
direction="Right Light",
product_type="single"
)
server = GenerateRelightImage(rd)
print(server.get_result())