From 638447f31305a4d660d4727c71ad0752522e53d4 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Fri, 28 Jun 2024 16:58:52 +0800 Subject: [PATCH] =?UTF-8?q?feat=20generate=20to=20product=20image=20?= =?UTF-8?q?=E6=96=B0=E5=A2=9E=20image=5Fstrength=E5=8F=82=E6=95=B0=20fix?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/schemas/generate_image.py | 1 + .../service_generate_product_image.py | 13 +++++++++---- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/app/schemas/generate_image.py b/app/schemas/generate_image.py index 4f85002..29f34d6 100644 --- a/app/schemas/generate_image.py +++ b/app/schemas/generate_image.py @@ -20,6 +20,7 @@ class GenerateProductImageModel(BaseModel): tasks_id: str prompt: str image_url: str + image_strength: float class GenerateRelightImageModel(BaseModel): diff --git a/app/service/generate_image/service_generate_product_image.py b/app/service/generate_image/service_generate_product_image.py index dcdf09f..6ee1bc6 100644 --- a/app/service/generate_image/service_generate_product_image.py +++ b/app/service/generate_image/service_generate_product_image.py @@ -7,17 +7,17 @@ @Date :2023/7/26 12:01:05 @detail : """ -import io import json import logging import time + import cv2 +import numpy as np import redis import tritonclient.grpc as grpcclient -import numpy as np from PIL import Image, ImageOps -from minio import Minio from tritonclient.utils import np_to_triton_dtype + from app.core.config import * from app.schemas.generate_image import GenerateProductImageModel from app.service.generate_image.utils.upload_sd_image import upload_SDXL_image @@ -37,6 +37,7 @@ class GenerateProductImage: self.grpc_client = grpcclient.InferenceServerClient(url=GPI_MODEL_URL) self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True) self.category = "product_image" + self.image_strength = request_data.image_strength self.batch_size = 1 self.prompt = request_data.prompt self.image, self.image_size = pre_processing_image(request_data.image_url) @@ -74,13 +75,16 @@ class GenerateProductImage: text_obj = np.array(prompts, dtype="object").reshape(1) image_obj = np.array(images, dtype=np.uint8).reshape((768, 512, 3)) + image_strength_obj = np.array(self.image_strength, dtype=np.float32).reshape((1)) input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype)) input_image = grpcclient.InferInput("input_image", image_obj.shape, "UINT8") + input_image_strength = grpcclient.InferInput("image_strength", image_strength_obj.shape, np_to_triton_dtype(image_strength_obj.dtype)) input_text.set_data_from_numpy(text_obj) input_image.set_data_from_numpy(image_obj) - inputs = [input_text, input_image] + inputs = [input_text, input_image, input_image_strength] + input_image_strength.set_data_from_numpy(image_strength_obj) ctx = self.grpc_client.async_infer(model_name=GPI_MODEL_NAME, inputs=inputs, callback=self.callback) time_out = 600 @@ -144,6 +148,7 @@ if __name__ == '__main__': rd = GenerateProductImageModel( tasks_id="123-89", prompt="", + image_strength=0.9, # prompt=" the best quality, masterpiece. detailed, high-res, simple background, studio photography, extremely detailed, updo, detailed face, face, close-up, HDR, UHD, 8K realistic, Highly detailed, simple background, Studio lighting", image_url="aida-results/result_00097282-ebb2-11ee-a822-b48351119060.png", )