Merge branch 'refs/heads/local' into develop
This commit is contained in:
@@ -20,6 +20,7 @@ class GenerateProductImageModel(BaseModel):
|
|||||||
tasks_id: str
|
tasks_id: str
|
||||||
prompt: str
|
prompt: str
|
||||||
image_url: str
|
image_url: str
|
||||||
|
image_strength: float
|
||||||
|
|
||||||
|
|
||||||
class GenerateRelightImageModel(BaseModel):
|
class GenerateRelightImageModel(BaseModel):
|
||||||
|
|||||||
@@ -7,17 +7,17 @@
|
|||||||
@Date :2023/7/26 12:01:05
|
@Date :2023/7/26 12:01:05
|
||||||
@detail :
|
@detail :
|
||||||
"""
|
"""
|
||||||
import io
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
|
import numpy as np
|
||||||
import redis
|
import redis
|
||||||
import tritonclient.grpc as grpcclient
|
import tritonclient.grpc as grpcclient
|
||||||
import numpy as np
|
|
||||||
from PIL import Image, ImageOps
|
from PIL import Image, ImageOps
|
||||||
from minio import Minio
|
|
||||||
from tritonclient.utils import np_to_triton_dtype
|
from tritonclient.utils import np_to_triton_dtype
|
||||||
|
|
||||||
from app.core.config import *
|
from app.core.config import *
|
||||||
from app.schemas.generate_image import GenerateProductImageModel
|
from app.schemas.generate_image import GenerateProductImageModel
|
||||||
from app.service.generate_image.utils.upload_sd_image import upload_SDXL_image
|
from app.service.generate_image.utils.upload_sd_image import upload_SDXL_image
|
||||||
@@ -37,6 +37,7 @@ class GenerateProductImage:
|
|||||||
self.grpc_client = grpcclient.InferenceServerClient(url=GPI_MODEL_URL)
|
self.grpc_client = grpcclient.InferenceServerClient(url=GPI_MODEL_URL)
|
||||||
self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
|
self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
|
||||||
self.category = "product_image"
|
self.category = "product_image"
|
||||||
|
self.image_strength = request_data.image_strength
|
||||||
self.batch_size = 1
|
self.batch_size = 1
|
||||||
self.prompt = request_data.prompt
|
self.prompt = request_data.prompt
|
||||||
self.image, self.image_size = pre_processing_image(request_data.image_url)
|
self.image, self.image_size = pre_processing_image(request_data.image_url)
|
||||||
@@ -74,13 +75,16 @@ class GenerateProductImage:
|
|||||||
|
|
||||||
text_obj = np.array(prompts, dtype="object").reshape(1)
|
text_obj = np.array(prompts, dtype="object").reshape(1)
|
||||||
image_obj = np.array(images, dtype=np.uint8).reshape((768, 512, 3))
|
image_obj = np.array(images, dtype=np.uint8).reshape((768, 512, 3))
|
||||||
|
image_strength_obj = np.array(self.image_strength, dtype=np.float32).reshape((1))
|
||||||
|
|
||||||
input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype))
|
input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype))
|
||||||
input_image = grpcclient.InferInput("input_image", image_obj.shape, "UINT8")
|
input_image = grpcclient.InferInput("input_image", image_obj.shape, "UINT8")
|
||||||
|
input_image_strength = grpcclient.InferInput("image_strength", image_strength_obj.shape, np_to_triton_dtype(image_strength_obj.dtype))
|
||||||
|
|
||||||
input_text.set_data_from_numpy(text_obj)
|
input_text.set_data_from_numpy(text_obj)
|
||||||
input_image.set_data_from_numpy(image_obj)
|
input_image.set_data_from_numpy(image_obj)
|
||||||
inputs = [input_text, input_image]
|
inputs = [input_text, input_image, input_image_strength]
|
||||||
|
input_image_strength.set_data_from_numpy(image_strength_obj)
|
||||||
|
|
||||||
ctx = self.grpc_client.async_infer(model_name=GPI_MODEL_NAME, inputs=inputs, callback=self.callback)
|
ctx = self.grpc_client.async_infer(model_name=GPI_MODEL_NAME, inputs=inputs, callback=self.callback)
|
||||||
time_out = 600
|
time_out = 600
|
||||||
@@ -144,6 +148,7 @@ if __name__ == '__main__':
|
|||||||
rd = GenerateProductImageModel(
|
rd = GenerateProductImageModel(
|
||||||
tasks_id="123-89",
|
tasks_id="123-89",
|
||||||
prompt="",
|
prompt="",
|
||||||
|
image_strength=0.9,
|
||||||
# prompt=" the best quality, masterpiece. detailed, high-res, simple background, studio photography, extremely detailed, updo, detailed face, face, close-up, HDR, UHD, 8K realistic, Highly detailed, simple background, Studio lighting",
|
# prompt=" the best quality, masterpiece. detailed, high-res, simple background, studio photography, extremely detailed, updo, detailed face, face, close-up, HDR, UHD, 8K realistic, Highly detailed, simple background, Studio lighting",
|
||||||
image_url="aida-results/result_00097282-ebb2-11ee-a822-b48351119060.png",
|
image_url="aida-results/result_00097282-ebb2-11ee-a822-b48351119060.png",
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user