diff --git a/app/schemas/generate_image.py b/app/schemas/generate_image.py
index 4f85002..29f34d6 100644
--- a/app/schemas/generate_image.py
+++ b/app/schemas/generate_image.py
@@ -20,6 +20,7 @@ class GenerateProductImageModel(BaseModel):
     tasks_id: str
     prompt: str
     image_url: str
+    image_strength: float = 0.9
 
 
 class GenerateRelightImageModel(BaseModel):
diff --git a/app/service/generate_image/service_generate_product_image.py b/app/service/generate_image/service_generate_product_image.py
index dcdf09f..6ee1bc6 100644
--- a/app/service/generate_image/service_generate_product_image.py
+++ b/app/service/generate_image/service_generate_product_image.py
@@ -7,17 +7,17 @@
 @Date :2023/7/26 12:01:05
 @detail :
 """
-import io
 import json
 import logging
 import time
+
 import cv2
+import numpy as np
 import redis
 import tritonclient.grpc as grpcclient
-import numpy as np
 from PIL import Image, ImageOps
-from minio import Minio
 from tritonclient.utils import np_to_triton_dtype
+
 from app.core.config import *
 from app.schemas.generate_image import GenerateProductImageModel
 from app.service.generate_image.utils.upload_sd_image import upload_SDXL_image
@@ -37,6 +37,7 @@ class GenerateProductImage:
         self.grpc_client = grpcclient.InferenceServerClient(url=GPI_MODEL_URL)
         self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
         self.category = "product_image"
+        self.image_strength = request_data.image_strength
         self.batch_size = 1
         self.prompt = request_data.prompt
         self.image, self.image_size = pre_processing_image(request_data.image_url)
@@ -74,13 +75,16 @@
         text_obj = np.array(prompts, dtype="object").reshape(1)
         image_obj = np.array(images, dtype=np.uint8).reshape((768, 512, 3))
+        image_strength_obj = np.array(self.image_strength, dtype=np.float32).reshape((1,))
 
         input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype))
         input_image = grpcclient.InferInput("input_image", image_obj.shape, "UINT8")
+        input_image_strength = grpcclient.InferInput("image_strength", image_strength_obj.shape, np_to_triton_dtype(image_strength_obj.dtype))
 
         input_text.set_data_from_numpy(text_obj)
         input_image.set_data_from_numpy(image_obj)
+        input_image_strength.set_data_from_numpy(image_strength_obj)
 
-        inputs = [input_text, input_image]
+        inputs = [input_text, input_image, input_image_strength]
 
         ctx = self.grpc_client.async_infer(model_name=GPI_MODEL_NAME, inputs=inputs, callback=self.callback)
         time_out = 600
@@ -144,6 +148,7 @@ if __name__ == '__main__':
     rd = GenerateProductImageModel(
         tasks_id="123-89",
         prompt="",
+        image_strength=0.9,
         # prompt=" the best quality, masterpiece. detailed, high-res, simple background, studio photography, extremely detailed, updo, detailed face, face, close-up, HDR, UHD, 8K realistic, Highly detailed, simple background, Studio lighting",
         image_url="aida-results/result_00097282-ebb2-11ee-a822-b48351119060.png",
     )