diff --git a/app/core/config.py b/app/core/config.py index 48af4a1..ff6f0a1 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -117,9 +117,7 @@ GEN_SINGLE_LOGO_RABBITMQ_QUEUES = os.getenv("GEN_SINGLE_LOGO_RABBITMQ_QUEUES", f # Generate Product service config GPI_RABBITMQ_QUEUES = os.getenv("GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES", f"ToProductImage{RABBITMQ_ENV}") -GPI_MODEL_NAME_OVERALL = 'stable_diffusion_xl_cnet' -GPI_MODEL_NAME_SINGLE = 'stable_diffusion_xl_cnet' - +GPI_MODEL_NAME_OVERALL = 'sdxl_ensemble_all' GPI_MODEL_URL = '10.1.1.243:10051' # Generate Single Logo service config diff --git a/app/service/generate_image/service_generate_product_image.py b/app/service/generate_image/service_generate_product_image.py index da4bb4b..16b814b 100644 --- a/app/service/generate_image/service_generate_product_image.py +++ b/app/service/generate_image/service_generate_product_image.py @@ -55,10 +55,7 @@ class GenerateProductImage: self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data)) else: # pil图像转成numpy数组 - if self.product_type == "single": - image = result.as_numpy("generated_cnet_image") - else: - image = result.as_numpy("generated_cnet_image") + image = result.as_numpy("generated_inpaint_image") image_result = Image.fromarray(np.squeeze(image.astype(np.uint8))).resize(self.image_size) cropped_image = post_processing_image(image_result, self.left, self.top) image_url = upload_SDXL_image(cropped_image, user_id=self.user_id, category=f"{self.category}", file_name=f"{self.tasks_id}.png") @@ -78,14 +75,9 @@ class GenerateProductImage: self.image = cv2.resize(self.image, (1024, 1024)) images = [self.image.astype(np.uint8)] * self.batch_size - if self.product_type == "single": - text_obj = np.array(prompts, dtype="object").reshape(-1, 1) - image_obj = np.array(images, dtype=np.uint8).reshape((-1, 768, 512, 3)) - image_strength_obj = np.array(self.image_strength, dtype=np.float32).reshape(-1, 1) - else: - text_obj = np.array(prompts, dtype="object").reshape((-1, 1)) - image_obj = np.array(images, dtype=np.uint8).reshape((-1, 1024, 1024, 3)) - image_strength_obj = np.array(self.image_strength, dtype=np.float32).reshape((-1, 1)) + text_obj = np.array(prompts, dtype="object").reshape((1)) + image_obj = np.array(images, dtype=np.uint8).reshape((1024, 1024, 3)) + image_strength_obj = np.array(self.image_strength, dtype=np.float32).reshape((1)) # 假设 prompts、images 和 self.image_strength 已经定义 @@ -95,13 +87,11 @@ class GenerateProductImage: input_text.set_data_from_numpy(text_obj) input_image.set_data_from_numpy(image_obj) - inputs = [input_text, input_image, input_image_strength] input_image_strength.set_data_from_numpy(image_strength_obj) - if self.product_type == "single": - ctx = self.grpc_client.async_infer(model_name=GPI_MODEL_NAME_SINGLE, inputs=inputs, callback=self.callback) - else: - ctx = self.grpc_client.async_infer(model_name=GPI_MODEL_NAME_OVERALL, inputs=inputs, callback=self.callback) + inputs = [input_text, input_image, input_image_strength] + + ctx = self.grpc_client.async_infer(model_name=GPI_MODEL_NAME_OVERALL, inputs=inputs, callback=self.callback) time_out = 600 while time_out > 0: diff --git a/app/service/utils/new_oss_client.py b/app/service/utils/new_oss_client.py index 0ead375..5067d15 100644 --- a/app/service/utils/new_oss_client.py +++ b/app/service/utils/new_oss_client.py @@ -82,7 +82,7 @@ if __name__ == '__main__': # url = "aida-users/89/sketchboard/female/Dress/e6724ab7-8d3f-4677-abe0-c3e42ab7af85.jpeg" # url = "aida-users/87/print/956614a2-7e75-4fbe-9ed0-c1831e37a2c9-4-87.png" # url = "aida-users/89/single_logo/123-89.png" - url ="aida-results/result_68756122-ac6b-11ef-8bf8-0826ae3ad6b3.png" + url ="aida-results/result_00097282-ebb2-11ee-a822-b48351119060.png" # url = "aida-collection-element/12148/Sketchboard/95ea577b-305b-4a62-b30a-39c0dd3ddb3f.png" read_type = "2"