From 5ed53a1e7c452602be8f8a633e0b3bb39183fb8a Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Wed, 17 Apr 2024 17:37:51 +0800 Subject: [PATCH] =?UTF-8?q?feat=20generate=20image=20=E9=80=BB=E8=BE=91?= =?UTF-8?q?=E8=A1=A5=E5=85=85?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/service/generate_image/service.py | 94 +++++++++++++++------------ 1 file changed, 52 insertions(+), 42 deletions(-) diff --git a/app/service/generate_image/service.py b/app/service/generate_image/service.py index 46c96a9..d0bda9a 100644 --- a/app/service/generate_image/service.py +++ b/app/service/generate_image/service.py @@ -47,6 +47,9 @@ class GenerateImage: self.batch_size = 1 self.category = request_data.category self.index = 0 + self.generate_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'data': ''} + self.redis_client.set(self.tasks_id, json.dumps(self.generate_data)) + self.redis_client.expire(self.tasks_id, 600) def get_image(self, image_url): # Get data of an object. @@ -60,24 +63,23 @@ class GenerateImage: image_cv2 = np.random.randint(0, 256, (1024, 1024, 3), dtype=np.uint8) return image_cv2 - def __call__(self, *args, **kwargs): - self.generate_data = json.dumps({'status': 'PENDING', 'message': "pending", 'data': ''}) - self.redis_client.set(self.tasks_id, self.generate_data) - self.redis_client.expire(self.tasks_id, 600) - def callback(self, result, error): if error: - generate_data = json.dumps({'status': 'FAILURE', 'message': f"{error}", 'data': f"{error}"}) - self.redis_client.set(self.tasks_id, generate_data) + self.generate_data['status'] = "FAILURE" + self.generate_data['message'] = str(error) + self.generate_data['data'] = str(error) + self.redis_client.set(self.tasks_id, json.dumps(self.generate_data)) else: image_result = result.as_numpy("generated_image")[0] image_url = upload_png_sd(image_result, user_id=self.user_id, category=f"{self.category}", object_name=f"{self.tasks_id}.png") - generate_data = json.dumps({'status': 'SUCCESS', 'message': 'success', 'data': f'{image_url}'}) - self.redis_client.set(self.tasks_id, generate_data) + self.generate_data['status'] = "SUCCESS" + self.generate_data['message'] = "success" + self.generate_data['data'] = str(image_url) + self.redis_client.set(self.tasks_id, json.dumps(self.generate_data)) def read_tasks_status(self): - status_data = json.loads(self.redis_client.get(self.tasks_id)) - return status_data + status_data = self.redis_client.get(self.tasks_id) + return json.loads(status_data), status_data def infer(self, inputs): return self.grpc_client.async_infer( @@ -87,45 +89,53 @@ class GenerateImage: ) def get_result(self): - prompts = [self.prompt] * self.batch_size - modes = [self.mode] * self.batch_size - images = [self.image.astype(np.float16)] * self.batch_size + try: + prompts = [self.prompt] * self.batch_size + modes = [self.mode] * self.batch_size + images = [self.image.astype(np.float16)] * self.batch_size - text_obj = np.array(prompts, dtype="object").reshape((-1, 1)) - mode_obj = np.array(modes, dtype="object").reshape((-1, 1)) - image_obj = np.array(images, dtype=np.float16).reshape((-1, 1024, 1024, 3)) + text_obj = np.array(prompts, dtype="object").reshape((-1, 1)) + mode_obj = np.array(modes, dtype="object").reshape((-1, 1)) + image_obj = np.array(images, dtype=np.float16).reshape((-1, 1024, 1024, 3)) - input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype)) - input_image = grpcclient.InferInput("input_image", image_obj.shape, "FP16") - input_mode = grpcclient.InferInput("mode", mode_obj.shape, np_to_triton_dtype(text_obj.dtype)) + input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype)) + input_image = grpcclient.InferInput("input_image", image_obj.shape, "FP16") + input_mode = grpcclient.InferInput("mode", mode_obj.shape, np_to_triton_dtype(text_obj.dtype)) - input_text.set_data_from_numpy(text_obj) - input_image.set_data_from_numpy(image_obj) - input_mode.set_data_from_numpy(mode_obj) + input_text.set_data_from_numpy(text_obj) + input_image.set_data_from_numpy(image_obj) + input_mode.set_data_from_numpy(mode_obj) - inputs = [input_text, input_image, input_mode] - ctx = self.infer(inputs) - time_out = 60 - while time_out > 0: - generate_data = self.read_tasks_status() - if generate_data['status'] in ["REVOKED", "FAILURE"]: - ctx.cancel() - self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=json.dumps(generate_data)) - logger.info(f" [x] Sent {json.dumps(generate_data, indent=4)}") - break - elif generate_data['status'] == "SUCCESS": - self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=json.dumps(generate_data)) - logger.info(f" [x] Sent {json.dumps(generate_data, indent=4)}") - break - time_out -= 1 - time.sleep(0.1) - return self.read_tasks_status() + inputs = [input_text, input_image, input_mode] + ctx = self.infer(inputs) + time_out = 60 + generate_data = None + while time_out > 0: + generate_data, _ = self.read_tasks_status() + if generate_data['status'] in ["REVOKED", "FAILURE"]: + ctx.cancel() + break + elif generate_data['status'] == "SUCCESS": + break + time_out -= 1 + time.sleep(0.1) + return generate_data + except Exception as e: + self.generate_data['status'] = "FAILURE" + self.generate_data['message'] = "failure" + self.generate_data['data'] = str(e) + self.redis_client.set(self.tasks_id, json.dumps(self.generate_data)) + raise Exception(str(e)) + finally: + dict_generate_data, str_generate_data = self.read_tasks_status() + self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=str_generate_data) + logger.info(f" [x] Sent {json.dumps(dict_generate_data, indent=4)}") def infer_cancel(tasks_id): redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True) - data = {'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'} - generate_data = json.dumps({'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'}) + data = {'tasks_id': tasks_id, 'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'} + generate_data = json.dumps(data) redis_client.set(tasks_id, generate_data) return data