From 5092a8c7bc38c089fa6e327ecb372ad4f0e951e4 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Thu, 30 May 2024 15:02:35 +0800 Subject: [PATCH] =?UTF-8?q?feat=20generate=20slogan=20|=20to=20product=20i?= =?UTF-8?q?mage=20|=20slogan=20=E6=8E=A5=E5=8F=A3=E9=83=A8=E7=BD=B2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../service_generate_product_image.py | 181 ++++++++++++++++++ .../service_generate_single_logo.py | 72 +++++++ 2 files changed, 253 insertions(+) create mode 100644 app/service/generate_image/service_generate_product_image.py create mode 100644 app/service/generate_image/service_generate_single_logo.py diff --git a/app/service/generate_image/service_generate_product_image.py b/app/service/generate_image/service_generate_product_image.py new file mode 100644 index 0000000..ea875bd --- /dev/null +++ b/app/service/generate_image/service_generate_product_image.py @@ -0,0 +1,181 @@ +#!/usr/bin/env python +# -*- coding: UTF-8 -*- +""" +@Project :trinity_client +@File :service_att_recognition.py +@Author :周成融 +@Date :2023/7/26 12:01:05 +@detail : +""" +import json +import logging +import time +from io import BytesIO + +import cv2 +import minio +import redis +import tritonclient.grpc as grpcclient +import numpy as np +from minio import Minio +from tritonclient.utils import np_to_triton_dtype + +from app.core.config import * +from app.schemas.generate_image import GenerateImageModel +from app.service.generate_image.utils.adjust_contrast import adjust_contrast +from app.service.generate_image.utils.image_processing import remove_background, stain_detection, generate_category_recognition, autoLevels, luminance_adjust, face_detect_pic +from app.service.generate_image.utils.upload_sd_image import upload_png_sd, upload_stain_png_sd + +logger = logging.getLogger() + + +class GenerateProductImage: + def __init__(self, request_data): + # if DEBUG is False: + # self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS)) + # self.channel = self.connection.channel() + self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS)) + self.channel = self.connection.channel() + self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) + self.grpc_client = grpcclient.InferenceServerClient(url=GI_MODEL_URL) + self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True) + self.tasks_id = request_data.tasks_id + self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:] + self.gen_product_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'image_url': ''} + self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data)) + self.redis_client.expire(self.tasks_id, 600) + + def get_image(self, image_url): + # Get data of an object. + # Read data from response. + # read image use cv2 + try: + response = self.minio_client.get_object(image_url.split('/')[0], image_url[image_url.find('/') + 1:]) + image_file = BytesIO(response.data) + image_array = np.asarray(bytearray(image_file.read()), dtype=np.uint8) + image_cv2 = cv2.imdecode(image_array, cv2.IMREAD_COLOR) + image_rbg = cv2.cvtColor(image_cv2, cv2.COLOR_BGR2RGB) + image = cv2.resize(image_rbg, (1024, 1024)) + except minio.error.S3Error: + image = np.random.randint(0, 256, (1024, 1024, 3), dtype=np.uint8) + return image + + def callback(self, result, error): + if error: + self.generate_data['status'] = "FAILURE" + self.generate_data['message'] = str(error) + # self.generate_data['data'] = str(error) + self.redis_client.set(self.tasks_id, json.dumps(self.generate_data)) + else: + # pil图像转成numpy数组 + image = result.as_numpy("generated_image") + image_result = cv2.cvtColor(np.squeeze(image.astype(np.uint8)), cv2.COLOR_RGB2BGR) + is_smudge = True + if self.category == "sketch": + # 色阶调整 + cutoff = 1 + levels_img = autoLevels(image_result, cutoff) + # 亮度调整 + luminance = luminance_adjust(0.3, levels_img) + # 去背景 + remove_bg_image = remove_background(luminance) + # 人脸检测 + if face_detect_pic(remove_bg_image, self.user_id, self.category, self.tasks_id) > 0: + is_smudge = False + else: + # 污点/ + is_smudge, not_smudge_image = stain_detection(remove_bg_image, self.user_id, self.category, self.tasks_id) + # 类型识别 + category, scores, not_smudge_image = generate_category_recognition(image=remove_bg_image, gender=self.gender) + self.generate_data['category'] = str(category) + image_result = not_smudge_image + if is_smudge: # 无污点 + # image_result = adjust_contrast(image_result) + image_url = upload_png_sd(image_result, user_id=self.user_id, category=f"{self.category}", object_name=f"{self.tasks_id}.png") + # logger.info(f"upload image SUCCESS : {image_url}") + self.generate_data['status'] = "SUCCESS" + self.generate_data['message'] = "success" + self.generate_data['image_url'] = str(image_url) + self.redis_client.set(self.tasks_id, json.dumps(self.generate_data)) + else: # 有污点 保存图片到本地 测试用 + self.generate_data['status'] = "SUCCESS" + self.generate_data['message'] = "success" + self.generate_data['image_url'] = str(GI_SYS_IMAGE_URL) + self.redis_client.set(self.tasks_id, json.dumps(self.generate_data)) + # logger.info(f"stain_detection result : {self.generate_data}") + + def read_tasks_status(self): + status_data = self.redis_client.get(self.tasks_id) + return json.loads(status_data), status_data + + def infer(self, inputs): + return self.grpc_client.async_infer( + model_name=GI_MODEL_NAME, + inputs=inputs, + callback=self.callback + ) + + def get_result(self): + try: + # prompts = [self.prompt] * self.batch_size + # modes = [self.mode] * self.batch_size + # images = [self.image.astype(np.float16)] * self.batch_size + # + # text_obj = np.array(prompts, dtype="object").reshape((-1, 1)) + # mode_obj = np.array(modes, dtype="object").reshape((-1, 1)) + # image_obj = np.array(images, dtype=np.float16).reshape((-1, 1024, 1024, 3)) + # + # input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype)) + # input_image = grpcclient.InferInput("input_image", image_obj.shape, "FP16") + # input_mode = grpcclient.InferInput("mode", mode_obj.shape, np_to_triton_dtype(text_obj.dtype)) + # + # input_text.set_data_from_numpy(text_obj) + # input_image.set_data_from_numpy(image_obj) + # input_mode.set_data_from_numpy(mode_obj) + # + # inputs = [input_text, input_image, input_mode] + # ctx = self.infer(inputs) + # time_out = 600 + # generate_data = None + # while time_out > 0: + # generate_data, _ = self.read_tasks_status() + # # logger.info(generate_data) + # if generate_data['status'] in ["REVOKED", "FAILURE"]: + # ctx.cancel() + # break + # elif generate_data['status'] == "SUCCESS": + # break + # time_out -= 1 + # time.sleep(0.1) + # # logger.info(time_out, generate_data) + generate_data, _ = self.read_tasks_status() + return generate_data + except Exception as e: + self.gen_product_data['status'] = "FAILURE" + self.gen_product_data['message'] = str(e) + self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data)) + raise Exception(str(e)) + finally: + dict_gen_product_data, str_gen_product_data = self.read_tasks_status() + # if DEBUG is False: + # self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=str_generate_data) + self.channel.basic_publish(exchange='', routing_key=GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES, body=str_gen_product_data) + logger.info(f" [x] Sent {json.dumps(dict_gen_product_data, indent=4)}") + + +def infer_cancel(tasks_id): + redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True) + data = {'tasks_id': tasks_id, 'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'} + generate_data = json.dumps(data) + redis_client.set(tasks_id, generate_data) + return data + + +if __name__ == '__main__': + rd = GenerateImageModel( + tasks_id="123-89", + prompt='skeleton sitting by the side of a river looking soulful, concert poster, 4k, artistic', + image_url="", + ) + server = GenerateImage(rd) + print(server.get_result()) diff --git a/app/service/generate_image/service_generate_single_logo.py b/app/service/generate_image/service_generate_single_logo.py new file mode 100644 index 0000000..0bb38a0 --- /dev/null +++ b/app/service/generate_image/service_generate_single_logo.py @@ -0,0 +1,72 @@ +#!/usr/bin/env python +# -*- coding: UTF-8 -*- +""" +@Project :trinity_client +@File :service_att_recognition.py +@Author :周成融 +@Date :2023/7/26 12:01:05 +@detail : +""" +import json +import logging +import redis +from minio import Minio +from app.core.config import * +from app.schemas.generate_image import GenerateSingleLogoImageModel + +logger = logging.getLogger() + + +class GenerateSingleLogoImage: + def __init__(self, request_data): + # if DEBUG is False: + # self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS)) + # self.channel = self.connection.channel() + self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS)) + self.channel = self.connection.channel() + self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) + # self.grpc_client = grpcclient.InferenceServerClient(url=GI_MODEL_URL) + self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True) + self.tasks_id = request_data.tasks_id + self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:] + self.gen_single_logo_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'image_url': ''} + self.redis_client.set(self.tasks_id, json.dumps(self.gen_single_logo_data)) + self.redis_client.expire(self.tasks_id, 600) + + def read_tasks_status(self): + status_data = self.redis_client.get(self.tasks_id) + return json.loads(status_data), status_data + + def get_result(self): + try: + generate_data, _ = self.read_tasks_status() + return generate_data + except Exception as e: + self.gen_single_logo_data['status'] = "FAILURE" + self.gen_single_logo_data['message'] = str(e) + self.redis_client.set(self.tasks_id, json.dumps(self.gen_single_logo_data)) + raise Exception(str(e)) + finally: + dict_generate_data, str_generate_data = self.read_tasks_status() + # if DEBUG is False: + # self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=str_generate_data) + self.channel.basic_publish(exchange='', routing_key=GEN_SINGLE_LOGO_RABBITMQ_QUEUES, body=str_generate_data) + logger.info(f" [x] Sent {json.dumps(dict_generate_data, indent=4)}") + + +def infer_cancel(tasks_id): + redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True) + data = {'tasks_id': tasks_id, 'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'} + generate_data = json.dumps(data) + redis_client.set(tasks_id, generate_data) + return data + + +if __name__ == '__main__': + rd = GenerateSingleLogoImageModel( + tasks_id="123-8", + prompt='skeleton sitting by the side of a river looking soulful, concert poster, 4k, artistic', + image_url="", + ) + server = GenerateSingleLogoImage(rd) + print(server.get_result())