diff --git a/app/core/config.py b/app/core/config.py index 97e014d..dd56258 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -100,9 +100,12 @@ SR_MINIO_BUCKET = "aida-users" SR_RABBITMQ_QUEUES = os.getenv("SR_RABBITMQ_QUEUES", f"SuperResolution{RABBITMQ_ENV}") # GenerateImage service config -GI_MODEL_NAME = 'stable_diffusion_xl' FAST_GI_MODEL_URL = '10.1.1.243:10011' -GI_MODEL_URL = '10.1.1.240:10041' +FAST_GI_MODEL_NAME = 'stable_diffusion_xl' + +GI_MODEL_URL = '10.1.1.240:10061' +GI_MODEL_NAME = 'flux' + GI_MINIO_BUCKET = "aida-users" GI_RABBITMQ_QUEUES = os.getenv("GI_RABBITMQ_QUEUES", f"GenerateImage{RABBITMQ_ENV}") GI_SYS_IMAGE_URL = "aida-sys-image/generate_image/white_image.jpg" diff --git a/app/service/generate_image/service_generate_image.py b/app/service/generate_image/service_generate_image.py index 7d2937b..8cf7cf9 100644 --- a/app/service/generate_image/service_generate_image.py +++ b/app/service/generate_image/service_generate_image.py @@ -138,8 +138,8 @@ class GenerateImage: image_obj = np.array(images, dtype=np.float16).reshape((-1, 1024, 1024, 3)) input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype)) - input_image = grpcclient.InferInput("input_image", image_obj.shape, "FP16") - input_mode = grpcclient.InferInput("mode", mode_obj.shape, np_to_triton_dtype(text_obj.dtype)) + input_image = grpcclient.InferInput("input_image", image_obj.shape, np_to_triton_dtype(image_obj.dtype)) + input_mode = grpcclient.InferInput("mode", mode_obj.shape, np_to_triton_dtype(mode_obj.dtype)) input_text.set_data_from_numpy(text_obj) input_image.set_data_from_numpy(image_obj) @@ -185,12 +185,12 @@ def infer_cancel(tasks_id): if __name__ == '__main__': rd = GenerateImageModel( tasks_id="123-89", - prompt='skeleton sitting by the side of a river looking soulful, concert poster, 4k, artistic', + prompt='a fabric print, flower, yellow, 4k, hud', image_url="aida-collection-element/87/Printboard/842c09cf-7297-42d9-9e6e-9c17d4a13cb5.jpg", mode='txt2img', category="test", gender="male", - version="fast" + version="high" ) server = GenerateImage(rd) print(server.get_result())