Merge remote-tracking branch 'origin/develop' into develop
This commit is contained in:
39
app/api/api_extraction_project_info.py
Normal file
39
app/api/api_extraction_project_info.py
Normal file
@@ -0,0 +1,39 @@
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
|
||||
from app.schemas.project_info_extraction import ProjectInfoExtractionModel
|
||||
from app.schemas.response_template import ResponseModel
|
||||
from app.service.project_info_extraction.service import ProjectInfoExtraction
|
||||
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
@router.post("/extraction_project_info")
|
||||
def extraction_project_info(request_data: ProjectInfoExtractionModel):
|
||||
"""
|
||||
通过prompt 提取project_name,role ,gender ,style。
|
||||
创建一个具有以下参数的请求体:
|
||||
- **prompt**:
|
||||
|
||||
示例参数:
|
||||
{
|
||||
"prompt": "海边派对主题的系列设计",
|
||||
"image_list": [
|
||||
"https://www.minio-api.aida.com.hk/test/test123.jpg?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=vXKFLSJkYeEq2DrSZvkB%2F20250519%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20250519T050808Z&X-Amz-Expires=7200&X-Amz-SignedHeaders=host&X-Amz-Signature=296ff07cc4692d0a26ddffac582064f036494af343389fe60193dc2c5dc883ff"
|
||||
],
|
||||
"file_list": [
|
||||
""
|
||||
]
|
||||
}
|
||||
"""
|
||||
try:
|
||||
logger.info(f"extraction_project_info request item is : @@@@@@:{request_data}")
|
||||
service = ProjectInfoExtraction(request_data)
|
||||
data = service.get_result()
|
||||
logger.info(f"extraction_project_info response @@@@@@:{data}")
|
||||
except Exception as e:
|
||||
logger.warning(f"extraction_project_info Run Exception @@@@@@:{e}")
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
return ResponseModel(data=data)
|
||||
@@ -5,16 +5,17 @@ from app.api import api_attribute_retrieve, api_query_image
|
||||
from app.api import api_brand_dna
|
||||
from app.api import api_brighten
|
||||
from app.api import api_chat_robot
|
||||
from app.api import api_clothing_seg
|
||||
from app.api import api_design
|
||||
from app.api import api_design_pre_processing
|
||||
from app.api import api_extraction_project_info
|
||||
from app.api import api_generate_image
|
||||
from app.api import api_image2sketch
|
||||
from app.api import api_mannequins_edit
|
||||
from app.api import api_pose_transform
|
||||
from app.api import api_prompt_generation
|
||||
from app.api import api_clothing_seg
|
||||
from app.api import api_super_resolution
|
||||
from app.api import api_recommendation
|
||||
from app.api import api_super_resolution
|
||||
from app.api import api_test
|
||||
|
||||
router = APIRouter()
|
||||
@@ -36,3 +37,4 @@ router.include_router(api_mannequins_edit.router, tags=['api_mannequins_edit'],
|
||||
router.include_router(api_agent_generate_image.router, tags=['api_agent_generate_image'], prefix="/api")
|
||||
router.include_router(api_pose_transform.router, tags=['api_pose_transform'], prefix="/api")
|
||||
router.include_router(api_clothing_seg.router, tags=['api_clothing_seg'], prefix="/api")
|
||||
router.include_router(api_extraction_project_info.router, tags=['api_extraction_project_info'], prefix="/api")
|
||||
|
||||
@@ -43,7 +43,7 @@ JAVA_STREAM_API_URL = os.getenv("JAVA_STREAM_API_URL", "https://api.aida.com.hk/
|
||||
settings = Settings()
|
||||
|
||||
# minio 配置
|
||||
MINIO_URL = "www.minio.aida.com.hk:12024"
|
||||
MINIO_URL = "www.minio-api.aida.com.hk"
|
||||
MINIO_ACCESS = 'vXKFLSJkYeEq2DrSZvkB'
|
||||
MINIO_SECRET = 'uKTZT3x7C43WvPN9QTc99DiRkwddWZrG9Uh3JVlR'
|
||||
MINIO_SECURE = True
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from typing import List
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
@@ -43,19 +45,33 @@ class GenerateRelightImageModel(BaseModel):
|
||||
"""
|
||||
|
||||
|
||||
class BatchGenerateProductImageModel(BaseModel):
|
||||
# product任务子项
|
||||
class ProductItemModel(BaseModel):
|
||||
tasks_id: str
|
||||
image_strength: float
|
||||
prompt: str
|
||||
image_url: str
|
||||
image_strength: float
|
||||
product_type: str
|
||||
batch_size: int
|
||||
|
||||
|
||||
class BatchGenerateRelightImageModel(BaseModel):
|
||||
# product批处理 集合
|
||||
class BatchGenerateProductImageModel(BaseModel):
|
||||
batch_tasks_id: str
|
||||
user_id: str
|
||||
batch_data_list: List[ProductItemModel]
|
||||
|
||||
|
||||
# relight任务子项
|
||||
class RelightItemModel(BaseModel):
|
||||
tasks_id: str
|
||||
prompt: str
|
||||
image_url: str
|
||||
direction: str
|
||||
product_type: str
|
||||
batch_size: int
|
||||
|
||||
|
||||
# relight批处理集合
|
||||
class BatchGenerateRelightImageModel(BaseModel):
|
||||
batch_tasks_id: str
|
||||
user_id: str
|
||||
batch_data_list: List[RelightItemModel]
|
||||
|
||||
7
app/schemas/project_info_extraction.py
Normal file
7
app/schemas/project_info_extraction.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class ProjectInfoExtractionModel(BaseModel):
|
||||
prompt: str
|
||||
image_list: list
|
||||
file_list: list
|
||||
@@ -6,15 +6,15 @@ from app.service.generate_batch_image.service_batch_pose_transform import batch_
|
||||
async def start_product_batch_generate(data):
|
||||
generate_clothes_task = batch_generate_product.delay(data.dict())
|
||||
print(generate_clothes_task)
|
||||
product_publish_status(data.tasks_id, f"0/{data.batch_size}", "")
|
||||
return {"task_id": data.tasks_id, "state": generate_clothes_task.state}
|
||||
product_publish_status(data.batch_tasks_id, f"0/{len(data.batch_data_list)}", "")
|
||||
return {"task_id": data.batch_tasks_id, "state": generate_clothes_task.state}
|
||||
|
||||
|
||||
async def start_relight_batch_generate(data):
|
||||
generate_clothes_task = batch_generate_relight.delay(data.dict())
|
||||
print(generate_clothes_task)
|
||||
relight_publish_status(data.tasks_id, f"0/{data.batch_size}", "")
|
||||
return {"task_id": data.tasks_id, "state": generate_clothes_task.state}
|
||||
relight_publish_status(data.batch_tasks_id, f"0/{len(data.batch_data_list)}", "")
|
||||
return {"task_id": data.batch_tasks_id, "state": generate_clothes_task.state}
|
||||
|
||||
|
||||
async def start_pose_transform_batch_generate(data):
|
||||
|
||||
@@ -19,7 +19,7 @@ from celery import Celery
|
||||
from tritonclient.utils import np_to_triton_dtype
|
||||
|
||||
from app.core.config import *
|
||||
from app.schemas.generate_image import BatchGenerateProductImageModel
|
||||
from app.schemas.generate_image import BatchGenerateProductImageModel, ProductItemModel
|
||||
from app.service.generate_image.utils.upload_sd_image import upload_SDXL_image
|
||||
from app.service.utils.oss_client import oss_get_image
|
||||
|
||||
@@ -35,38 +35,38 @@ category = "product_image"
|
||||
|
||||
@celery_app.task
|
||||
def batch_generate_product(batch_request_data):
|
||||
batch_size = len(batch_request_data['batch_data_list'])
|
||||
logger.info(f"batch_generate_product batch_request_data:{json.dumps(batch_request_data, indent=4)}")
|
||||
tasks_id = batch_request_data['tasks_id']
|
||||
user_id = tasks_id.rsplit('-', 1)[1]
|
||||
batch_size = batch_request_data['batch_size']
|
||||
image = pre_processing_image(batch_request_data['image_url'])
|
||||
image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
|
||||
images = [image.astype(np.uint8)] * 1
|
||||
batch_tasks_id = batch_request_data['batch_tasks_id']
|
||||
user_id = batch_request_data['user_id']
|
||||
result_data_list = []
|
||||
|
||||
prompts = [batch_request_data['prompt']] * 1
|
||||
for i, data in enumerate(batch_request_data['batch_data_list']):
|
||||
tasks_id = data['tasks_id']
|
||||
image = pre_processing_image(data['image_url'])
|
||||
image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
|
||||
images = [image.astype(np.uint8)] * 1
|
||||
prompts = [data['prompt']] * 1
|
||||
if data['product_type'] == "single":
|
||||
text_obj = np.array(prompts, dtype="object").reshape((-1, 1))
|
||||
image_obj = np.array(images, dtype=np.uint8).reshape((-1, 768, 512, 3))
|
||||
image_strength_obj = np.array(data['image_strength'], dtype=np.float32).reshape((-1, 1))
|
||||
else:
|
||||
text_obj = np.array(prompts, dtype="object").reshape((1))
|
||||
image_obj = np.array(images, dtype=np.uint8).reshape((768, 512, 3))
|
||||
image_strength_obj = np.array(data['image_strength'], dtype=np.float32).reshape((1))
|
||||
input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype))
|
||||
input_image = grpcclient.InferInput("input_image", image_obj.shape, "UINT8")
|
||||
input_image_strength = grpcclient.InferInput("image_strength", image_strength_obj.shape, np_to_triton_dtype(image_strength_obj.dtype))
|
||||
|
||||
if batch_request_data['product_type'] == "single":
|
||||
text_obj = np.array(prompts, dtype="object").reshape((-1, 1))
|
||||
image_obj = np.array(images, dtype=np.uint8).reshape((-1, 768, 512, 3))
|
||||
image_strength_obj = np.array(batch_request_data['image_strength'], dtype=np.float32).reshape((-1, 1))
|
||||
else:
|
||||
text_obj = np.array(prompts, dtype="object").reshape((1))
|
||||
image_obj = np.array(images, dtype=np.uint8).reshape((768, 512, 3))
|
||||
image_strength_obj = np.array(batch_request_data['image_strength'], dtype=np.float32).reshape((1))
|
||||
input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype))
|
||||
input_image = grpcclient.InferInput("input_image", image_obj.shape, "UINT8")
|
||||
input_image_strength = grpcclient.InferInput("image_strength", image_strength_obj.shape, np_to_triton_dtype(image_strength_obj.dtype))
|
||||
input_text.set_data_from_numpy(text_obj)
|
||||
input_image.set_data_from_numpy(image_obj)
|
||||
input_image_strength.set_data_from_numpy(image_strength_obj)
|
||||
|
||||
input_text.set_data_from_numpy(text_obj)
|
||||
input_image.set_data_from_numpy(image_obj)
|
||||
input_image_strength.set_data_from_numpy(image_strength_obj)
|
||||
inputs = [input_text, input_image, input_image_strength]
|
||||
|
||||
inputs = [input_text, input_image, input_image_strength]
|
||||
|
||||
image_url_list = []
|
||||
for i in range(batch_size):
|
||||
try:
|
||||
if batch_request_data['product_type'] == "single":
|
||||
if data['product_type'] == "single":
|
||||
result = grpc_client.infer(model_name=GPI_MODEL_NAME_SINGLE, inputs=inputs, priority=100)
|
||||
image = result.as_numpy("generated_cnet_image")
|
||||
else:
|
||||
@@ -77,7 +77,7 @@ def batch_generate_product(batch_request_data):
|
||||
if 'mask_list' in str(e):
|
||||
e_text_obj = np.array(prompts, dtype="object").reshape((-1, 1))
|
||||
e_image_obj = np.array(images, dtype=np.uint8).reshape((-1, 768, 512, 3))
|
||||
e_image_strength_obj = np.array(batch_request_data['image_strength'], dtype=np.float32).reshape((-1, 1))
|
||||
e_image_strength_obj = np.array(data['image_strength'], dtype=np.float32).reshape((-1, 1))
|
||||
|
||||
e_input_text = grpcclient.InferInput("prompt", e_text_obj.shape, np_to_triton_dtype(e_text_obj.dtype))
|
||||
e_input_image = grpcclient.InferInput("input_image", e_image_obj.shape, "UINT8")
|
||||
@@ -96,18 +96,29 @@ def batch_generate_product(batch_request_data):
|
||||
|
||||
if isinstance(image_result, Image.Image):
|
||||
image_url = upload_SDXL_image(image_result, user_id=user_id, category=f"{category}", file_name=f"{tasks_id}-batch-{i}.png")
|
||||
image_url_list.append(image_url)
|
||||
data['product_img'] = image_url
|
||||
result_data_list.append(data)
|
||||
else:
|
||||
image_url = image_result
|
||||
if DEBUG is False:
|
||||
if i + 1 < batch_size:
|
||||
publish_status(tasks_id, f"{i + 1}/{batch_size}", image_url)
|
||||
logger.info(f" [x]Queue : {BATCH_GPI_RABBITMQ_QUEUES} | tasks_id:{tasks_id} | progress:{i + 1}/{batch_size} | image_url:{image_url}")
|
||||
# print(f" [x]Queue : {BATCH_GPI_RABBITMQ_QUEUES} | tasks_id:{tasks_id} | progress:{i + 1}/{batch_size} | image_url:{image_url}")
|
||||
else:
|
||||
publish_status(tasks_id, f"OK", image_url_list)
|
||||
logger.info(f" [x]Queue : {BATCH_GPI_RABBITMQ_QUEUES} | tasks_id:{tasks_id} | progress:OK | image_url:{image_url}")
|
||||
# print(f" [x]Queue : {BATCH_GPI_RABBITMQ_QUEUES} | tasks_id:{tasks_id} | progress:OK | image_url:{image_url}")
|
||||
data['product_img'] = image_url
|
||||
result_data_list.append(data)
|
||||
|
||||
# 发送每条结果
|
||||
if DEBUG:
|
||||
logger.info(f" [x]Queue : {BATCH_GPI_RABBITMQ_QUEUES} | tasks_id:{tasks_id} | progress:{i + 1}/{batch_size} | result_data:{data}")
|
||||
print(f" [x]Queue : {BATCH_GPI_RABBITMQ_QUEUES} | tasks_id:{tasks_id} | progress:{i + 1}/{batch_size} | result_data:{data}")
|
||||
else:
|
||||
publish_status(tasks_id, f"{i + 1}/{batch_size}", data)
|
||||
logger.info(f" [x]Queue : {BATCH_GPI_RABBITMQ_QUEUES} | tasks_id:{tasks_id} | progress:{i + 1}/{batch_size} | result_data:{data}")
|
||||
|
||||
# 任务完成,发送所有数据结果
|
||||
if DEBUG:
|
||||
print(result_data_list)
|
||||
logger.info(f" [x]Queue : {BATCH_GPI_RABBITMQ_QUEUES} | batch_tasks_id:{batch_tasks_id} | progress:OK | result_data_list:{result_data_list}")
|
||||
print(f" [x]Queue : {BATCH_GPI_RABBITMQ_QUEUES} | batch_tasks_id:{batch_tasks_id} | progress:OK | result_data_list:{result_data_list}")
|
||||
else:
|
||||
publish_status(batch_tasks_id, f"OK", result_data_list)
|
||||
logger.info(f" [x]Queue : {BATCH_GPI_RABBITMQ_QUEUES} | batch_tasks_id:{batch_tasks_id} | progress:OK | result_data_list:{result_data_list}")
|
||||
|
||||
|
||||
def pre_processing_image(image_url):
|
||||
@@ -180,12 +191,52 @@ def publish_status(task_id, progress, result):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# rd = BatchGenerateProductImageModel(
|
||||
# tasks_id="123-15-51-89",
|
||||
# image_strength=0.7,
|
||||
# prompt=" The best quality, masterpiece, real image.Outwear,high quality clothing details,8K realistic,HDR",
|
||||
# image_url="aida-results/result_40b1a2fe-e220-11ef-9bfa-0242ac150003.png",
|
||||
# product_type="overall",
|
||||
# batch_size=20
|
||||
# )
|
||||
# batch_generate_product(rd.dict())
|
||||
# rd = {
|
||||
# "user_id": "89",
|
||||
# "batch_data_list": [
|
||||
# {
|
||||
# "tasks_id": "A-123-15-51-89",
|
||||
# "image_strength": 0.7,
|
||||
# "prompt": " The best quality, ma123sterpiece, real image.Outwear,high quality clothing details,8K realistic,HDR",
|
||||
# "image_url": "aida-results/result_40b1a2fe-e220-11ef-9bfa-0242ac150003.png",
|
||||
# "product_type": "overall",
|
||||
# },
|
||||
# {
|
||||
# "tasks_id": "B-123-15-51-89",
|
||||
# "image_strength": 0.7,
|
||||
# "prompt": " The best quality, masterpiece, real image.Outwear123,high quality clothing details,8K realistic,HDR",
|
||||
# "image_url": "aida-results/result_40b1a2fe-e220-11ef-9bfa-0242ac150003.png",
|
||||
# "product_type": "overall",
|
||||
# }
|
||||
# ]
|
||||
# }
|
||||
rd = BatchGenerateProductImageModel(
|
||||
tasks_id="123-15-51-89",
|
||||
image_strength=0.7,
|
||||
prompt=" The best quality, masterpiece, real image.Outwear,high quality clothing details,8K realistic,HDR",
|
||||
image_url="aida-results/result_40b1a2fe-e220-11ef-9bfa-0242ac150003.png",
|
||||
product_type="overall",
|
||||
batch_size=20
|
||||
batch_tasks_id="abcd",
|
||||
user_id="89",
|
||||
batch_data_list=[
|
||||
ProductItemModel(
|
||||
tasks_id="123-5464",
|
||||
image_strength=0.7,
|
||||
product_type="overall",
|
||||
image_url="aida-results/result_40b1a2fe-e220-11ef-9bfa-0242ac150003.png",
|
||||
prompt="123"
|
||||
),
|
||||
ProductItemModel(
|
||||
tasks_id="123-5464123",
|
||||
image_strength=0.7,
|
||||
product_type="overall",
|
||||
image_url="aida-results/result_40b1a2fe-e220-11ef-9bfa-0242ac150003.png",
|
||||
prompt="123"
|
||||
)
|
||||
]
|
||||
)
|
||||
batch_generate_product(rd.dict())
|
||||
|
||||
@@ -18,7 +18,7 @@ from celery import Celery
|
||||
from tritonclient.utils import np_to_triton_dtype
|
||||
|
||||
from app.core.config import *
|
||||
from app.schemas.generate_image import BatchGenerateRelightImageModel
|
||||
from app.schemas.generate_image import BatchGenerateRelightImageModel, RelightItemModel
|
||||
from app.service.generate_image.utils.upload_sd_image import upload_SDXL_image
|
||||
from app.service.utils.oss_client import oss_get_image
|
||||
|
||||
@@ -34,55 +34,58 @@ category = "relight_image"
|
||||
|
||||
@celery_app.task
|
||||
def batch_generate_relight(batch_request_data):
|
||||
batch_size = len(batch_request_data['batch_data_list'])
|
||||
logger.info(f"batch_generate_relight batch_request_data: {json.dumps(batch_request_data, indent=4)}")
|
||||
batch_tasks_id = batch_request_data['batch_tasks_id']
|
||||
user_id = batch_request_data['user_id']
|
||||
result_data_list = []
|
||||
negative_prompt = 'lowres, bad anatomy, bad hands, cropped, worst quality'
|
||||
direction = batch_request_data['direction']
|
||||
seed = "1"
|
||||
prompt = batch_request_data['prompt']
|
||||
product_type = batch_request_data['product_type']
|
||||
image_url = batch_request_data['image_url']
|
||||
image = oss_get_image(bucket=image_url.split('/')[0], object_name=image_url.split('/', 1)[1], data_type="cv2")
|
||||
tasks_id = batch_request_data['tasks_id']
|
||||
user_id = tasks_id.rsplit('-', 1)[1]
|
||||
batch_size = batch_request_data['batch_size']
|
||||
|
||||
prompts = [prompt] * 1
|
||||
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||
image = cv2.resize(image, (512, 768))
|
||||
images = [image.astype(np.uint8)] * 1
|
||||
seeds = [seed] * 1
|
||||
nagetive_prompts = [negative_prompt] * 1
|
||||
directions = [direction] * 1
|
||||
for i, data in enumerate(batch_request_data['batch_data_list']):
|
||||
direction = data['direction']
|
||||
|
||||
if product_type == 'single':
|
||||
text_obj = np.array(prompts, dtype="object").reshape((-1, 1))
|
||||
image_obj = np.array(images, dtype=np.uint8).reshape((-1, 768, 512, 3))
|
||||
na_text_obj = np.array(nagetive_prompts, dtype="object").reshape((-1, 1))
|
||||
seed_obj = np.array(seeds, dtype="object").reshape((-1, 1))
|
||||
direction_obj = np.array(directions, dtype="object").reshape((-1, 1))
|
||||
else:
|
||||
text_obj = np.array(prompts, dtype="object").reshape((1))
|
||||
image_obj = np.array(images, dtype=np.uint8).reshape((768, 512, 3))
|
||||
na_text_obj = np.array(nagetive_prompts, dtype="object").reshape((1))
|
||||
seed_obj = np.array(seeds, dtype="object").reshape((1))
|
||||
direction_obj = np.array(directions, dtype="object").reshape((1))
|
||||
input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype))
|
||||
input_image = grpcclient.InferInput("input_image", image_obj.shape, "UINT8")
|
||||
input_natext = grpcclient.InferInput("negative_prompt", na_text_obj.shape, np_to_triton_dtype(na_text_obj.dtype))
|
||||
input_seed = grpcclient.InferInput("seed", seed_obj.shape, np_to_triton_dtype(seed_obj.dtype))
|
||||
input_direction = grpcclient.InferInput("direction", direction_obj.shape, np_to_triton_dtype(direction_obj.dtype))
|
||||
prompt = data['prompt']
|
||||
product_type = data['product_type']
|
||||
image_url = data['image_url']
|
||||
image = oss_get_image(bucket=image_url.split('/')[0], object_name=image_url.split('/', 1)[1], data_type="cv2")
|
||||
tasks_id = data['tasks_id']
|
||||
|
||||
input_text.set_data_from_numpy(text_obj)
|
||||
input_image.set_data_from_numpy(image_obj)
|
||||
input_natext.set_data_from_numpy(na_text_obj)
|
||||
input_seed.set_data_from_numpy(seed_obj)
|
||||
input_direction.set_data_from_numpy(direction_obj)
|
||||
prompts = [prompt] * 1
|
||||
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||
image = cv2.resize(image, (512, 768))
|
||||
images = [image.astype(np.uint8)] * 1
|
||||
seeds = [seed] * 1
|
||||
nagetive_prompts = [negative_prompt] * 1
|
||||
directions = [direction] * 1
|
||||
|
||||
inputs = [input_text, input_natext, input_image, input_seed, input_direction]
|
||||
image_url_list = []
|
||||
for i in range(batch_size):
|
||||
if product_type == 'single':
|
||||
text_obj = np.array(prompts, dtype="object").reshape((-1, 1))
|
||||
image_obj = np.array(images, dtype=np.uint8).reshape((-1, 768, 512, 3))
|
||||
na_text_obj = np.array(nagetive_prompts, dtype="object").reshape((-1, 1))
|
||||
seed_obj = np.array(seeds, dtype="object").reshape((-1, 1))
|
||||
direction_obj = np.array(directions, dtype="object").reshape((-1, 1))
|
||||
else:
|
||||
text_obj = np.array(prompts, dtype="object").reshape((1))
|
||||
image_obj = np.array(images, dtype=np.uint8).reshape((768, 512, 3))
|
||||
na_text_obj = np.array(nagetive_prompts, dtype="object").reshape((1))
|
||||
seed_obj = np.array(seeds, dtype="object").reshape((1))
|
||||
direction_obj = np.array(directions, dtype="object").reshape((1))
|
||||
input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype))
|
||||
input_image = grpcclient.InferInput("input_image", image_obj.shape, "UINT8")
|
||||
input_natext = grpcclient.InferInput("negative_prompt", na_text_obj.shape, np_to_triton_dtype(na_text_obj.dtype))
|
||||
input_seed = grpcclient.InferInput("seed", seed_obj.shape, np_to_triton_dtype(seed_obj.dtype))
|
||||
input_direction = grpcclient.InferInput("direction", direction_obj.shape, np_to_triton_dtype(direction_obj.dtype))
|
||||
|
||||
input_text.set_data_from_numpy(text_obj)
|
||||
input_image.set_data_from_numpy(image_obj)
|
||||
input_natext.set_data_from_numpy(na_text_obj)
|
||||
input_seed.set_data_from_numpy(seed_obj)
|
||||
input_direction.set_data_from_numpy(direction_obj)
|
||||
|
||||
inputs = [input_text, input_natext, input_image, input_seed, input_direction]
|
||||
try:
|
||||
if batch_request_data['product_type'] == "single":
|
||||
if data['product_type'] == "single":
|
||||
result = grpc_client.infer(model_name=GRI_MODEL_NAME_SINGLE, inputs=inputs, priority=100)
|
||||
image = result.as_numpy("generated_relight_image")
|
||||
else:
|
||||
@@ -121,18 +124,29 @@ def batch_generate_relight(batch_request_data):
|
||||
logger.error(e)
|
||||
if isinstance(image_result, Image.Image):
|
||||
image_url = upload_SDXL_image(image_result, user_id=user_id, category=f"{category}", file_name=f"{tasks_id}-batch-{i}.png")
|
||||
image_url_list.append(image_url)
|
||||
data['relight_img'] = image_url
|
||||
|
||||
result_data_list.append(data)
|
||||
else:
|
||||
image_url = image_result
|
||||
if DEBUG is False:
|
||||
if i + 1 < batch_size:
|
||||
publish_status(tasks_id, f"{i + 1}/{batch_size}", image_url)
|
||||
logger.info(f" [x]Queue : {BATCH_GRI_RABBITMQ_QUEUES} | tasks_id:{tasks_id} | progress:{i + 1}/{batch_size} | image_url:{image_url}")
|
||||
# print(f" [x]Queue : {BATCH_GRI_RABBITMQ_QUEUES} | tasks_id:{tasks_id} | progress:{i + 1}/{batch_size} | image_url:{image_url}")
|
||||
else:
|
||||
publish_status(tasks_id, f"OK", image_url_list)
|
||||
logger.info(f" [x]Queue : {BATCH_GRI_RABBITMQ_QUEUES} | tasks_id:{tasks_id} | progress:OK | image_url:{image_url}")
|
||||
# print(f" [x]Queue : {BATCH_GRI_RABBITMQ_QUEUES} | tasks_id:{tasks_id} | progress:OK | image_url:{image_url}")
|
||||
data['relight_img'] = image_url
|
||||
result_data_list.append(data)
|
||||
|
||||
# 发送每条结果
|
||||
if DEBUG:
|
||||
logger.info(f" [x]Queue : {BATCH_GRI_RABBITMQ_QUEUES} | tasks_id:{tasks_id} | progress:{i + 1}/{batch_size} | result_data:{data}")
|
||||
print(f" [x]Queue : {BATCH_GRI_RABBITMQ_QUEUES} | tasks_id:{tasks_id} | progress:{i + 1}/{batch_size} | result_data:{data}")
|
||||
else:
|
||||
publish_status(tasks_id, f"{i + 1}/{batch_size}", data)
|
||||
logger.info(f" [x]Queue : {BATCH_GRI_RABBITMQ_QUEUES} | tasks_id:{tasks_id} | progress:{i + 1}/{batch_size} | result_data:{data}")
|
||||
# 任务完成,发送所有数据结果
|
||||
if DEBUG:
|
||||
print(result_data_list)
|
||||
logger.info(f" [x]Queue : {BATCH_GRI_RABBITMQ_QUEUES} | batch_tasks_id:{batch_tasks_id} | progress:OK | result_data_list:{result_data_list}")
|
||||
print(f" [x]Queue : {BATCH_GRI_RABBITMQ_QUEUES} | batch_tasks_id:{batch_tasks_id} | progress:OK | result_data_list:{result_data_list}")
|
||||
else:
|
||||
publish_status(batch_tasks_id, f"OK", result_data_list)
|
||||
logger.info(f" [x]Queue : {BATCH_GRI_RABBITMQ_QUEUES} | batch_tasks_id:{batch_tasks_id} | progress:OK | result_data_list:{result_data_list}")
|
||||
|
||||
|
||||
def publish_status(task_id, progress, result):
|
||||
@@ -151,12 +165,44 @@ def publish_status(task_id, progress, result):
|
||||
|
||||
if __name__ == '__main__':
|
||||
rd = BatchGenerateRelightImageModel(
|
||||
tasks_id="123-89",
|
||||
# prompt="beautiful woman, detailed face, sunshine, outdoor, warm atmosphere",
|
||||
prompt="Colorful black",
|
||||
image_url='aida-users/89/clothing_seg/283c5c82-1a92-11f0-b72a-0242ac150002.png',
|
||||
direction="Right Light",
|
||||
product_type="overall",
|
||||
batch_size=10
|
||||
batch_tasks_id="abcd",
|
||||
user_id="89",
|
||||
batch_data_list=[
|
||||
RelightItemModel(
|
||||
tasks_id="123-5464",
|
||||
product_type="overall",
|
||||
image_url="aida-users/89/product_image/02894523-19b5-46eb-a9c6-2f512f5fec84-0-89.png",
|
||||
prompt="Colorful black",
|
||||
direction="Right Light",
|
||||
),
|
||||
RelightItemModel(
|
||||
tasks_id="123-5464123",
|
||||
product_type="overall",
|
||||
image_url="aida-users/89/product_image/02894523-19b5-46eb-a9c6-2f512f5fec84-0-89.png",
|
||||
direction="Right Light",
|
||||
prompt="Colorful black",
|
||||
)
|
||||
]
|
||||
)
|
||||
batch_generate_relight(rd.dict())
|
||||
# X = {
|
||||
# "batch_tasks_id": "abcd",
|
||||
# "user_id": "89",
|
||||
# "batch_data_list": [
|
||||
# {
|
||||
# "tasks_id": "123-5464",
|
||||
# "product_type": "overall",
|
||||
# "image_url": "aida-users/89/product_image/02894523-19b5-46eb-a9c6-2f512f5fec84-0-89.png",
|
||||
# "prompt": "Colorful black",
|
||||
# "direction": "Right Light",
|
||||
# },
|
||||
# {
|
||||
# "tasks_id": "123-5464",
|
||||
# "product_type": "overall",
|
||||
# "image_url": "aida-users/89/product_image/02894523-19b5-46eb-a9c6-2f512f5fec84-0-89.png",
|
||||
# "prompt": "Colorful black",
|
||||
# "direction": "Right Light",
|
||||
# }
|
||||
#
|
||||
# ]
|
||||
# }
|
||||
|
||||
@@ -29,9 +29,6 @@ logger = logging.getLogger()
|
||||
|
||||
class PoseTransformService:
|
||||
def __init__(self, request_data):
|
||||
if DEBUG is False:
|
||||
self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
|
||||
self.channel = self.connection.channel()
|
||||
self.grpc_client = grpcclient.InferenceServerClient(url=PT_MODEL_URL)
|
||||
self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
|
||||
self.category = "pose_transform"
|
||||
@@ -40,7 +37,8 @@ class PoseTransformService:
|
||||
self.image = pre_processing_image(request_data.image_url)
|
||||
self.tasks_id = request_data.tasks_id
|
||||
self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:]
|
||||
self.pose_transform_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'gif_url': '', 'video_url': '', 'image_url': ''}
|
||||
self.pose_transform_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'gif_url': '',
|
||||
'video_url': '', 'image_url': ''}
|
||||
|
||||
self.redis_client.set(self.tasks_id, json.dumps(self.pose_transform_data))
|
||||
self.redis_client.expire(self.tasks_id, 600)
|
||||
@@ -55,16 +53,20 @@ class PoseTransformService:
|
||||
|
||||
# 第一帧图像
|
||||
first_image = Image.fromarray(result_data[0])
|
||||
first_image_url = upload_first_image(first_image, user_id=self.user_id, category=f"{self.category}_first_img", file_name=f"{self.tasks_id}.png")
|
||||
first_image_url = upload_first_image(first_image, user_id=self.user_id,
|
||||
category=f"{self.category}_first_img",
|
||||
file_name=f"{self.tasks_id}.png")
|
||||
|
||||
# 上传GIF
|
||||
gif_buffer = BytesIO()
|
||||
imageio.mimsave(gif_buffer, result_data, format='GIF', fps=5)
|
||||
gif_buffer.seek(0)
|
||||
gif_url = upload_gif(gif_buffer=gif_buffer, user_id=self.user_id, category=f"{self.category}_gif", file_name=f"{self.tasks_id}.gif")
|
||||
gif_url = upload_gif(gif_buffer=gif_buffer, user_id=self.user_id, category=f"{self.category}_gif",
|
||||
file_name=f"{self.tasks_id}.gif")
|
||||
|
||||
# 上传video
|
||||
video_url = upload_video(frames=result_data, user_id=self.user_id, category=f"{self.category}_video", file_name=f"{self.tasks_id}.mp4")
|
||||
video_url = upload_video(frames=result_data, user_id=self.user_id, category=f"{self.category}_video",
|
||||
file_name=f"{self.tasks_id}.mp4")
|
||||
|
||||
self.pose_transform_data['status'] = "SUCCESS"
|
||||
self.pose_transform_data['message'] = "success"
|
||||
@@ -82,7 +84,8 @@ class PoseTransformService:
|
||||
try:
|
||||
pose_num = [self.pose_num] * 1
|
||||
pose_num_obj = np.array(pose_num, dtype="object").reshape((-1, 1))
|
||||
input_pose_num = grpcclient.InferInput("pose_num", pose_num_obj.shape, np_to_triton_dtype(pose_num_obj.dtype))
|
||||
input_pose_num = grpcclient.InferInput("pose_num", pose_num_obj.shape,
|
||||
np_to_triton_dtype(pose_num_obj.dtype))
|
||||
input_pose_num.set_data_from_numpy(pose_num_obj)
|
||||
|
||||
image_files = [self.image.astype(np.uint8)] * 1
|
||||
@@ -90,7 +93,8 @@ class PoseTransformService:
|
||||
input_image_files = grpcclient.InferInput("image_file", image_files_obj.shape, "UINT8")
|
||||
input_image_files.set_data_from_numpy(image_files_obj)
|
||||
|
||||
ctx = self.grpc_client.async_infer(model_name="animatex_1", inputs=[input_pose_num, input_image_files], callback=self.callback, client_timeout=60000)
|
||||
ctx = self.grpc_client.async_infer(model_name="animatex_1", inputs=[input_pose_num, input_image_files],
|
||||
callback=self.callback, client_timeout=60000)
|
||||
time_out = 60000
|
||||
while time_out > 0:
|
||||
pose_transform_data, _ = self.read_tasks_status()
|
||||
@@ -100,7 +104,7 @@ class PoseTransformService:
|
||||
elif pose_transform_data['status'] == "SUCCESS":
|
||||
break
|
||||
time_out -= 1
|
||||
time.sleep(0.1)
|
||||
time.sleep(1)
|
||||
pose_transform_data, _ = self.read_tasks_status()
|
||||
return pose_transform_data
|
||||
except Exception as e:
|
||||
@@ -111,9 +115,22 @@ class PoseTransformService:
|
||||
finally:
|
||||
dict_pose_transform_data, str_pose_transform_data = self.read_tasks_status()
|
||||
if DEBUG is False:
|
||||
self.channel.basic_publish(exchange='', routing_key=PS_RABBITMQ_QUEUES, body=str_pose_transform_data)
|
||||
self.connection.close()
|
||||
logger.info(f" [x] Sent to: {PS_RABBITMQ_QUEUES} data:@@@@ {json.dumps(dict_pose_transform_data, indent=4)}")
|
||||
publish_status(str_pose_transform_data)
|
||||
logger.info(
|
||||
f" [x] Sent to: {PS_RABBITMQ_QUEUES} data:@@@@ {json.dumps(dict_pose_transform_data, indent=4)}")
|
||||
|
||||
|
||||
def publish_status(message):
|
||||
connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
|
||||
channel = connection.channel()
|
||||
channel.queue_declare(queue=PS_RABBITMQ_QUEUES, durable=True)
|
||||
channel.basic_publish(exchange='',
|
||||
routing_key=PS_RABBITMQ_QUEUES,
|
||||
body=json.dumps(message),
|
||||
properties=pika.BasicProperties(
|
||||
delivery_mode=2,
|
||||
))
|
||||
connection.close()
|
||||
|
||||
|
||||
def infer_cancel(tasks_id):
|
||||
@@ -125,7 +142,8 @@ def infer_cancel(tasks_id):
|
||||
|
||||
|
||||
def pre_processing_image(image_url):
|
||||
image = oss_get_image(bucket=image_url.split('/')[0], object_name=image_url[image_url.find('/') + 1:], data_type="PIL")
|
||||
image = oss_get_image(bucket=image_url.split('/')[0], object_name=image_url[image_url.find('/') + 1:],
|
||||
data_type="PIL")
|
||||
# 目标图片的尺寸
|
||||
target_width = 512
|
||||
target_height = 768
|
||||
|
||||
61
app/service/project_info_extraction/service.py
Normal file
61
app/service/project_info_extraction/service.py
Normal file
@@ -0,0 +1,61 @@
|
||||
from langchain.output_parsers import ResponseSchema, StructuredOutputParser
|
||||
from langchain_community.chat_models import ChatTongyi
|
||||
from langchain_core.prompts import PromptTemplate
|
||||
|
||||
from app.schemas.project_info_extraction import ProjectInfoExtractionModel
|
||||
|
||||
style = ['NEW_CHINESE', 'COUNTRY_STYLE', 'FUTURISM', 'MINIMALISM', 'LOLITA', 'Y2K', 'BUSINESS', 'MERLAD',
|
||||
'OUTDOOR_FUNCTIONAL', 'ROCK', 'DOPAMINE', 'GOTHIC', 'POST_APOCALYPTIC', 'ROMANTIC', 'WABI_SABI']
|
||||
position = ['Overall', 'Tops', 'Bottoms', 'Outwear', 'Blouse', 'Dress', 'Trousers', 'Skirt']
|
||||
gender = ['Female', 'Male']
|
||||
age_group = ['Adult', 'Child']
|
||||
process = ['SERIES_DESIGN', 'SINGLE_DESIGN']
|
||||
|
||||
|
||||
class ProjectInfoExtraction:
|
||||
def __init__(self, request_data):
|
||||
# llm generate brand info init
|
||||
self.model = ChatTongyi(model="qwen2.5-14b-instruct", api_key="sk-7658298c6b99443c98184a5e634fe6ab")
|
||||
|
||||
self.response_schemas = [
|
||||
ResponseSchema(name="project_name", description="项目的名称."),
|
||||
ResponseSchema(name="process", description="项目的类型,单品还是系列."),
|
||||
ResponseSchema(name="ageGroup", description="项目设计服装的受众对象."),
|
||||
ResponseSchema(name="gender", description="项目设计服装的受众性别."),
|
||||
ResponseSchema(name="position", description="项目单品设计的部位."),
|
||||
ResponseSchema(name="style", description="项目的设计风格.")
|
||||
]
|
||||
self.output_parser = StructuredOutputParser.from_response_schemas(self.response_schemas)
|
||||
self.format_instructions = self.output_parser.get_format_instructions()
|
||||
self.prompt = PromptTemplate(
|
||||
template="你是一个时装品牌的设计师助理。根据用户输入提取出"
|
||||
"[project_name] : 项目的名称,"
|
||||
f"[process] : 项目的类型,从{process}选择."
|
||||
f"[ageGroup] : 服装的受众,从{age_group}选择."
|
||||
f"[gender] : 服装的适用性别,从{gender}选择"
|
||||
f"[position] : single_design的部位,如果[process]是SINGLE_DESIGN,从{position}中选择,如果[process]是SERIES_DESIGN,这项为空"
|
||||
f"[style] : 设计的风格,从{style}中选择"
|
||||
".\n{format_instructions}\n{question}",
|
||||
input_variables=["question"],
|
||||
partial_variables={"format_instructions": self.format_instructions}
|
||||
)
|
||||
self._input = self.prompt.format_prompt(question=request_data.prompt)
|
||||
|
||||
self.result_data = {}
|
||||
|
||||
def get_result(self):
|
||||
self.llm_extraction_project_info()
|
||||
return self.result_data
|
||||
|
||||
def llm_extraction_project_info(self):
|
||||
output = self.model(self._input.to_messages())
|
||||
project_info = self.output_parser.parse(output.content)
|
||||
self.result_data = project_info
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
request_data = ProjectInfoExtractionModel(
|
||||
prompt="海边派对主题的衬衫设计"
|
||||
)
|
||||
service = ProjectInfoExtraction(request_data)
|
||||
print(service.get_result())
|
||||
Reference in New Issue
Block a user