diff --git a/app/api/api_agent_generate_image.py b/app/api/api_agent_generate_image.py new file mode 100644 index 0000000..e4001ff --- /dev/null +++ b/app/api/api_agent_generate_image.py @@ -0,0 +1,19 @@ +import io +import logging + +from fastapi import APIRouter, HTTPException +from starlette.responses import StreamingResponse + +from app.schemas.response_template import ResponseModel +from app.service.generate_image.agent_generate import GenerateImage + +router = APIRouter() +logger = logging.getLogger() + + +@router.get("/agent_generate_image") +def generate_image(prompt: str): + server = GenerateImage() + byte_stream = server.get_result(prompt) + # 返回流式响应 + return StreamingResponse(byte_stream, media_type="image/png") diff --git a/app/api/api_mannequins_edit.py b/app/api/api_mannequins_edit.py new file mode 100644 index 0000000..5cfaf3d --- /dev/null +++ b/app/api/api_mannequins_edit.py @@ -0,0 +1,40 @@ +import json +import logging + +from fastapi import APIRouter, HTTPException + +from app.schemas.mannequin_edit import MannequinModel +from app.schemas.response_template import ResponseModel +from app.service.mannequins_edit.service import MannequinEditService + +router = APIRouter() +logger = logging.getLogger() + + +@router.post("/mannequins_edit") +def mannequins_edit(request_data: MannequinModel): + """ + 模特腿长调整 + 创建一个具有以下参数的请求体: + - **mannequins**: mannequins url等信息 + - **scale**: 大腿小腿比例 + - **bucket_name**: bucket name + - **mannequin_name**: 模特名称 + + 示例参数: + - **{ + "mannequins": "aida-sys-image/models/male/dc36ce58-46c3-4b6f-8787-5ca7d6fc26e6.png", + "scale": [0.75, 0.75], + "bucket_name": "test", + "mannequin_name": "mannequin_name" + }** + """ + try: + logger.info(f"mannequins_edit request item is : @@@@@@:{json.dumps(request_data.dict())}") + service = MannequinEditService(request_data) + data = service() + logger.info(f"mannequins_edit response @@@@@@:{json.dumps(data)}") + except Exception as e: + logger.warning(f"mannequins_edit Run Exception @@@@@@:{e}") + raise HTTPException(status_code=404, detail=str(e)) + return ResponseModel(data=data) diff --git a/app/api/api_pose_transform.py b/app/api/api_pose_transform.py new file mode 100644 index 0000000..fe5fc5a --- /dev/null +++ b/app/api/api_pose_transform.py @@ -0,0 +1,49 @@ +import json +import logging + +from fastapi import APIRouter, BackgroundTasks, HTTPException + +from app.schemas.pose_transform import PoseTransformModel +from app.schemas.response_template import ResponseModel +from app.service.generate_image.service_pose_transform import PoseTransformService, infer_cancel as pose_transform_infer_cancel + +router = APIRouter() +logger = logging.getLogger() + + +@router.post("/pose_transform") +def pose_transform(request_item: PoseTransformModel, background_tasks: BackgroundTasks): + """ + 创建一个具有以下参数的请求体: + - **tasks_id**: 任务id 用于取消生成任务和获取生成结果 + - **image_url**: 被生成图片的S3或minio url地址 + - **pose_id**: 1 + + + 示例参数: + { + "tasks_id": "123-89", + "image_url": "aida-results/result_0000b606-1902-11ef-9424-0242ac180002.png", + "pose_id": "1" + } + """ + try: + logger.info(f"pose_transform request item is : @@@@@@:{json.dumps(request_item.dict())}") + service = PoseTransformService(request_item) + background_tasks.add_task(service.get_result) + except Exception as e: + logger.warning(f"pose_transform Run Exception @@@@@@:{e}") + raise HTTPException(status_code=404, detail=str(e)) + return ResponseModel() + + +@router.get("/pose_transform_cancel/{tasks_id}") +def pose_transform_cancel(tasks_id: str): + try: + logger.info(f"pose_transform_cancel request item is : @@@@@@:{tasks_id}") + data = pose_transform_infer_cancel(tasks_id) + logger.info(f"pose_transform_cancel response @@@@@@:{data}") + except Exception as e: + logger.warning(f"pose_transform_cancel Run Exception @@@@@@:{e}") + raise HTTPException(status_code=404, detail=str(e)) + return ResponseModel(data=data['data']) diff --git a/app/api/api_route.py b/app/api/api_route.py index 3890316..9858ba6 100644 --- a/app/api/api_route.py +++ b/app/api/api_route.py @@ -1,5 +1,6 @@ from fastapi import APIRouter +from app.api import api_agent_generate_image, api_recommendation from app.api import api_attribute_retrieve, api_query_image from app.api import api_brand_dna from app.api import api_brighten @@ -8,9 +9,10 @@ from app.api import api_design from app.api import api_design_pre_processing from app.api import api_generate_image from app.api import api_image2sketch +from app.api import api_mannequins_edit +from app.api import api_pose_transform from app.api import api_prompt_generation from app.api import api_super_resolution -from app.api import api_recommendation from app.api import api_test router = APIRouter() @@ -28,3 +30,6 @@ router.include_router(api_brighten.router, tags=['api_brighten'], prefix="/api") router.include_router(api_query_image.router, tags=['api_query_image'], prefix="/api") router.include_router(api_brand_dna.router, tags=['api_brand_dna'], prefix="/api") router.include_router(api_recommendation.router, tags=['api_recommendation'], prefix="/api") +router.include_router(api_mannequins_edit.router, tags=['api_mannequins_edit'], prefix="/api") +router.include_router(api_agent_generate_image.router, tags=['api_agent_generate_image'], prefix="/api") +router.include_router(api_pose_transform.router, tags=['api_pose_transform'], prefix="/api") diff --git a/app/core/config.py b/app/core/config.py index df4702b..5a1e2a3 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -138,7 +138,7 @@ GEN_SINGLE_LOGO_RABBITMQ_QUEUES = os.getenv("GEN_SINGLE_LOGO_RABBITMQ_QUEUES", f GPI_RABBITMQ_QUEUES = os.getenv("GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES", f"ToProductImage{RABBITMQ_ENV}") GPI_MODEL_NAME_OVERALL = 'diffusion_ensemble_all' GPI_MODEL_NAME_SINGLE = 'stable_diffusion_1_5_cnet' -GPI_MODEL_URL = '10.1.1.243:15551' +GPI_MODEL_URL = '10.1.1.243:10051' # Generate Single Logo service config GRI_RABBITMQ_QUEUES = os.getenv("GEN_RELIGHT_IMAGE_RABBITMQ_QUEUES", f"Relight{RABBITMQ_ENV}") diff --git a/app/schemas/mannequin_edit.py b/app/schemas/mannequin_edit.py new file mode 100644 index 0000000..c5514d6 --- /dev/null +++ b/app/schemas/mannequin_edit.py @@ -0,0 +1,8 @@ +from pydantic import BaseModel + + +class MannequinModel(BaseModel): + mannequins: str + scale: list[float, float] + bucket_name: str + mannequin_name: str diff --git a/app/schemas/pose_transform.py b/app/schemas/pose_transform.py new file mode 100644 index 0000000..045d8b9 --- /dev/null +++ b/app/schemas/pose_transform.py @@ -0,0 +1,7 @@ +from pydantic import BaseModel + + +class PoseTransformModel(BaseModel): + image_url: str + tasks_id: str + pose_id: str diff --git a/app/service/design_fast/pipeline/color.py b/app/service/design_fast/pipeline/color.py index 3033bb5..d6c84e4 100644 --- a/app/service/design_fast/pipeline/color.py +++ b/app/service/design_fast/pipeline/color.py @@ -29,6 +29,24 @@ class Color: else: pattern = self.get_pattern(result['color']) resize_pattern = cv2.resize(pattern, (dim_image_w, dim_image_h), interpolation=cv2.INTER_AREA) + + if "partial_color" in result.keys() and result['partial_color'] != "": + bucket_name = result['partial_color'].split('/')[0] + object_name = result['partial_color'][result['partial_color'].find('/') + 1:] + partial_color = oss_get_image(oss_client=self.minio_client, bucket=bucket_name, object_name=object_name, data_type="cv2") + h, w = partial_color.shape[0:2] + resize_pattern = cv2.resize(resize_pattern, (w, h), interpolation=cv2.INTER_AREA) + # 分离出 png 图的 alpha 通道 + alpha_channel = partial_color[:, :, 3] + # 提取 png 图的 RGB 通道 + png_rgb = partial_color[:, :, :3] + # 创建一个与 cv 图大小相同的掩码,用于指示哪些像素需要替换 + mask = alpha_channel > 0 + # 将掩码扩展为 3 通道,以便与 cv 图进行逐元素操作 + mask_3ch = np.stack([mask] * 3, axis=-1) + # 根据掩码将 png 图的颜色覆盖到 cv 图上 + resize_pattern[mask_3ch] = png_rgb[mask_3ch] + resize_pattern = cv2.resize(resize_pattern, (dim_image_w, dim_image_h), interpolation=cv2.INTER_AREA) closed_mo = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2) gray_mo = np.expand_dims(result['gray'], axis=2).repeat(3, axis=2) get_image_fir = resize_pattern * (closed_mo / 255) * (gray_mo / 255) diff --git a/app/service/generate_image/agent_generate.py b/app/service/generate_image/agent_generate.py new file mode 100644 index 0000000..58ac869 --- /dev/null +++ b/app/service/generate_image/agent_generate.py @@ -0,0 +1,72 @@ +#!/usr/bin/env python +# -*- coding: UTF-8 -*- +""" +@Project :trinity_client +@File :service_att_recognition.py +@Author :周成融 +@Date :2023/7/26 12:01:05 +@detail : +""" +import io +import logging +from datetime import timedelta + +import cv2 +import numpy as np +import tritonclient.grpc as grpcclient +from minio import Minio +from tritonclient.utils import np_to_triton_dtype + +from app.core.config import * +from app.service.utils.oss_client import oss_upload_image + +logger = logging.getLogger() + + +class GenerateImage: + def __init__(self): + self.grpc_client = grpcclient.InferenceServerClient(url=GI_MODEL_URL) + self.image = np.random.randint(0, 256, (1024, 1024, 3), dtype=np.uint8) + self.batch_size = 1 + self.mode = 'txt2img' + self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) + + def get_result(self, prompt): + prompts = [prompt] * self.batch_size + modes = [self.mode] * self.batch_size + images = [self.image.astype(np.float16)] * self.batch_size + + text_obj = np.array(prompts, dtype="object").reshape((-1, 1)) + mode_obj = np.array(modes, dtype="object").reshape((-1, 1)) + image_obj = np.array(images, dtype=np.float16).reshape((-1, 1024, 1024, 3)) + + input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype)) + input_image = grpcclient.InferInput("input_image", image_obj.shape, np_to_triton_dtype(image_obj.dtype)) + input_mode = grpcclient.InferInput("mode", mode_obj.shape, np_to_triton_dtype(mode_obj.dtype)) + + input_text.set_data_from_numpy(text_obj) + input_image.set_data_from_numpy(image_obj) + input_mode.set_data_from_numpy(mode_obj) + + inputs = [input_text, input_image, input_mode] + result = self.grpc_client.infer(model_name=GI_MODEL_NAME, inputs=inputs) + image = result.as_numpy("generated_image") + image_result = cv2.cvtColor(np.squeeze(image.astype(np.uint8)), cv2.COLOR_RGB2BGR) + _, img_byte_array = cv2.imencode('.jpg', image_result) + byte_stream = io.BytesIO(img_byte_array) + byte_stream.seek(0) + + # object_name = f'test.jpg' + # req = oss_upload_image(bucket='test', object_name=object_name, image_bytes=img_byte_array) + # url = self.minio_client.get_presigned_url( + # "GET", + # "test", + # object_name, + # expires=timedelta(hours=2), + # ) + return byte_stream + + +if __name__ == '__main__': + server = GenerateImage() + print(server.get_result("rabbit")) diff --git a/app/service/generate_image/service_pose_transform.py b/app/service/generate_image/service_pose_transform.py new file mode 100644 index 0000000..f2948b3 --- /dev/null +++ b/app/service/generate_image/service_pose_transform.py @@ -0,0 +1,117 @@ +#!/usr/bin/env python +# -*- coding: UTF-8 -*- +""" +@Project :trinity_client +@File :service_pose_transform.py +@Author :周成融 +@Date :2023/7/26 12:01:05 +@detail : +""" +import json +import logging + +import cv2 +import numpy as np +import redis +import tritonclient.grpc as grpcclient +from PIL import Image + +from app.core.config import * +from app.schemas.pose_transform import PoseTransformModel +from app.service.generate_image.utils.upload_sd_image import upload_SDXL_image +from app.service.utils.oss_client import oss_get_image + +logger = logging.getLogger() + + +class PoseTransformService: + def __init__(self, request_data): + if DEBUG is False: + self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS)) + self.channel = self.connection.channel() + self.grpc_client = grpcclient.InferenceServerClient(url=GRI_MODEL_URL) + self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True) + self.category = "pose_transform" + self.batch_size = 1 + self.seed = "1" + self.image_url = request_data.image_url + self.image = oss_get_image(bucket=self.image_url.split('/')[0], object_name=self.image_url[self.image_url.find('/') + 1:], data_type="cv2") + self.tasks_id = request_data.tasks_id + self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:] + self.gen_product_data = {'tasks_id': self.tasks_id, 'status': 'SUCCESS', 'message': "success", 'image_url': ''} + self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data)) + self.redis_client.expire(self.tasks_id, 600) + + def callback(self, result, error): + if error: + self.gen_product_data['status'] = "FAILURE" + self.gen_product_data['message'] = str(error) + self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data)) + else: + image = result.as_numpy("generated_inpaint_image") + image_result = Image.fromarray(np.squeeze(image.astype(np.uint8))) + image_url = upload_SDXL_image(image_result, user_id=self.user_id, category=f"{self.category}", file_name=f"{self.tasks_id}.png") + self.gen_product_data['status'] = "SUCCESS" + self.gen_product_data['message'] = "success" + self.gen_product_data['image_url'] = str(image_url) + self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data)) + + def read_tasks_status(self): + status_data = self.redis_client.get(self.tasks_id) + return json.loads(status_data), status_data + + def get_result(self): + try: + image = cv2.cvtColor(self.image, cv2.COLOR_BGR2RGB) + image = cv2.resize(image, (512, 768)) + images = [image.astype(np.uint8)] * self.batch_size + + image_obj = np.array(images, dtype=np.uint8).reshape((768, 512, 3)) + + input_image = grpcclient.InferInput("input_image", image_obj.shape, "UINT8") + + input_image.set_data_from_numpy(image_obj) + + inputs = [input_image] + # ctx = self.grpc_client.async_infer(model_name=GRI_MODEL_NAME_OVERALL, inputs=inputs, callback=self.callback) + + # time_out = 600 + # while time_out > 0: + # gen_product_data, _ = self.read_tasks_status() + # if gen_product_data['status'] in ["REVOKED", "FAILURE", "NO_FACE"]: + # ctx.cancel() + # break + # elif gen_product_data['status'] == "SUCCESS": + # break + # time_out -= 1 + # time.sleep(0.1) + gen_product_data, _ = self.read_tasks_status() + return gen_product_data + except Exception as e: + self.gen_product_data['status'] = "FAILURE" + self.gen_product_data['message'] = str(e) + self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data)) + raise Exception(str(e)) + finally: + dict_gen_product_data, str_gen_product_data = self.read_tasks_status() + if DEBUG is False: + self.channel.basic_publish(exchange='', routing_key=GRI_RABBITMQ_QUEUES, body=str_gen_product_data) + logger.info(f" [x] Sent to: {GRI_RABBITMQ_QUEUES} data:@@@@ {json.dumps(dict_gen_product_data, indent=4)}") + + +def infer_cancel(tasks_id): + redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True) + data = {'tasks_id': tasks_id, 'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'} + gen_product_data = json.dumps(data) + redis_client.set(tasks_id, gen_product_data) + return data + + +if __name__ == '__main__': + rd = PoseTransformModel( + tasks_id="123-89", + image_url='aida-results/result_0000b606-1902-11ef-9424-0242ac180002.png', + pose_id="1" + ) + server = PoseTransformService(rd) + print(server.get_result()) diff --git a/app/service/mannequins_edit/service.py b/app/service/mannequins_edit/service.py new file mode 100644 index 0000000..bbb6cc5 --- /dev/null +++ b/app/service/mannequins_edit/service.py @@ -0,0 +1,101 @@ +import cv2 +import mediapipe as mp +import numpy as np +from minio import Minio + +from app.core.config import MINIO_URL, MINIO_ACCESS, MINIO_SECRET, MINIO_SECURE +from app.schemas.mannequin_edit import MannequinModel +from app.service.utils.new_oss_client import oss_get_image, oss_upload_image + +minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) + + +class MannequinEditService(): + def __init__(self, request_data): + self.scale = request_data.scale + self.image = oss_get_image(oss_client=minio_client, bucket=request_data.mannequins.split('/')[0], object_name=request_data.mannequins[request_data.mannequins.find('/') + 1:], data_type="cv2") + self.mannequin_name = request_data.mannequin_name + self.bucket_name = request_data.bucket_name + if self.image.shape[2] == 4: + self.bgr = self.image[:, :, :3] + self.alpha = self.image[:, :, 3] + self.bgr = cv2.bitwise_and(self.bgr, self.bgr, mask=cv2.normalize(self.alpha, None, 0, 1, cv2.NORM_MINMAX)) + self.h, self.w, _ = self.bgr.shape + else: + self.bgr = self.image + self.h, self.w, _ = self.bgr.shape + self.alpha = None + + def __call__(self, *args, **kwargs): + leg_top, leg_bottom = self.attitude_detection() + if leg_top and leg_bottom: + new_mannequin = self.resize_leg(leg_top, leg_bottom) + _, encoded_image = cv2.imencode('.png', new_mannequin) + image_bytes = encoded_image.tobytes() + req = oss_upload_image(oss_client=minio_client, bucket=self.bucket_name, object_name=f"{self.mannequin_name}.png", image_bytes=image_bytes) + return req.bucket_name + "/" + req.object_name + else: + return "No leg detected" + + def attitude_detection(self): + mp_pose = mp.solutions.pose + pose = mp_pose.Pose() + + # 将 BGR 图像转换为 RGB 格式 + image_rgb = cv2.cvtColor(self.bgr, cv2.COLOR_BGR2RGB) + leg_top, leg_bottom = None, None + # 进行姿态检测 + results = pose.process(image_rgb) + if results.pose_landmarks: + # 获取腿部关键点 + landmarks = results.pose_landmarks.landmark + + # 找到腿部上边界和下边界 + leg_top = int(landmarks[mp_pose.PoseLandmark.LEFT_HIP].y * self.h) + leg_bottom = int(max(landmarks[mp_pose.PoseLandmark.LEFT_ANKLE].y, + landmarks[mp_pose.PoseLandmark.RIGHT_ANKLE].y) * self.h) + + return leg_top, leg_bottom + + def resize_leg(self, leg_top, leg_bottom): + # 上半身 + top_part_bgr = self.bgr[:leg_top, :] + top_part_bgr_alpha = self.alpha[:leg_top, :] + + # 小腿 + part_thigh = self.bgr[leg_top:leg_bottom, :] + part_thigh_alpha = self.alpha[leg_top:leg_bottom, :] + + # 大腿 + part_calf = self.bgr[leg_bottom:, :] + part_calf_alpha = self.alpha[leg_bottom:, :] + + new_thigh_height = int((leg_bottom - leg_top) * self.scale[0]) + new_calf_height = int((self.h - leg_bottom) * self.scale[1]) + + resized_thigh = cv2.resize(part_thigh, (self.w, new_thigh_height), interpolation=cv2.INTER_LINEAR) + resized_thigh_alpha = cv2.resize(part_thigh_alpha, (self.w, new_thigh_height), interpolation=cv2.INTER_LINEAR) + resized_calf = cv2.resize(part_calf, (self.w, new_calf_height), interpolation=cv2.INTER_LINEAR) + resized_calf_alpha = cv2.resize(part_calf_alpha, (self.w, new_calf_height), interpolation=cv2.INTER_LINEAR) + + new_bgr = np.vstack((top_part_bgr, resized_thigh, resized_calf)) + new_bgr_alpha = np.vstack((top_part_bgr_alpha, resized_thigh_alpha, resized_calf_alpha)) + + if self.alpha is not None: + # 拼接 alpha 通道 + # 合并 BGR 通道和 alpha 通道 + new_image = np.dstack((new_bgr, new_bgr_alpha)) + else: + new_image = new_bgr + return new_image + + +if __name__ == '__main__': + request_data = MannequinModel( + mannequins="aida-sys-image/models/male/dc36ce58-46c3-4b6f-8787-5ca7d6fc26e6.png", + scale=[0.75, 0.75], + bucket_name="test", + mannequin_name="mannequin_name" + ) + service = MannequinEditService(request_data) + print(service()) diff --git a/requirements.txt b/requirements.txt index 7350714..909fb21 100644 Binary files a/requirements.txt and b/requirements.txt differ