#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project :trinity_client
@File    :service_att_recognition.py
@Author  :周成融
@Date    :2023/7/26 12:01:05
@detail  :Submits image-generation requests to a Triton inference server and
          tracks the status of each task in Redis.
"""
import json
import logging
import time

import cv2
import minio
import numpy as np
import redis
import tritonclient.grpc as grpcclient
from tritonclient.utils import np_to_triton_dtype

from app.core.config import *
from app.schemas.generate_image import GenerateImageModel
from app.service.generate_image.utils.image_processing import (
    remove_background,
    stain_detection,
    generate_category_recognition,
    autoLevels,
    luminance_adjust,
)
from app.service.generate_image.utils.mq import publish_status
from app.service.generate_image.utils.upload_sd_image import upload_png_sd
from app.service.utils.oss_client import oss_get_image

logger = logging.getLogger()

# Side length (pixels) of the square image sent to the model.
_IMAGE_SIZE = 1024
# Lifetime of a task-status key in Redis. A plain SET drops any TTL already on
# the key, so the TTL has to be supplied again on every write.
_STATUS_TTL_SECONDS = 600
# get_result() polls Redis at most _POLL_ATTEMPTS times, sleeping
# _POLL_INTERVAL_SECONDS between polls (about 60 seconds in total).
_POLL_ATTEMPTS = 600
_POLL_INTERVAL_SECONDS = 0.1


class GenerateImage:
    """Run one image-generation task and mirror its status into Redis.

    The status record is a JSON object stored under ``tasks_id`` with the keys
    ``tasks_id``, ``status``, ``message``, ``image_url`` and ``category``.
    """

    def __init__(self, request_data: GenerateImageModel):
        """Create the clients, load the input image and publish a PENDING status.

        Args:
            request_data: request object; the attributes ``version``, ``mode``,
                ``image_url``, ``prompt``, ``tasks_id``, ``category`` and
                ``gender`` are read from it.
        """
        self.version = request_data.version
        model_url = FAST_GI_MODEL_URL if request_data.version == "fast" else GI_MODEL_URL
        self.grpc_client = grpcclient.InferenceServerClient(url=model_url)
        self.redis_client = redis.StrictRedis(
            host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True
        )
        if request_data.mode == "img2img":
            # cv2 reads images as BGR while PIL reads them as RGB.
            self.image = self.get_image(request_data.image_url)
        else:
            # No source image in this mode: the image input is filled with noise.
            self.image = self._random_image()
        self.prompt = request_data.prompt
        self.tasks_id = request_data.tasks_id
        # The user id is whatever follows the last '-' of the task id
        # (the whole task id when it contains no '-').
        self.user_id = self.tasks_id.rpartition('-')[2]
        self.mode = request_data.mode
        self.batch_size = 1
        self.category = request_data.category
        if self.category == "sketch":
            self.prompt = f"{self.category},{self.prompt}"
        self.index = 0
        self.gender = request_data.gender
        self.generate_data = {
            'tasks_id': self.tasks_id,
            'status': 'PENDING',
            'message': "pending",
            'image_url': '',
            'category': '',
        }
        self._save_status()

    @staticmethod
    def _random_image():
        """Return a uint8 noise image of shape (_IMAGE_SIZE, _IMAGE_SIZE, 3)."""
        return np.random.randint(0, 256, (_IMAGE_SIZE, _IMAGE_SIZE, 3), dtype=np.uint8)

    def _save_status(self):
        """Write ``self.generate_data`` to Redis and (re)apply the key TTL."""
        self.redis_client.set(
            self.tasks_id, json.dumps(self.generate_data), ex=_STATUS_TTL_SECONDS
        )

    def _fail(self, message):
        """Record a FAILURE status carrying ``message``."""
        self.generate_data['status'] = "FAILURE"
        self.generate_data['message'] = message
        self._save_status()

    def get_image(self, image_url):
        """Fetch the source image from object storage as a resized RGB array.

        Args:
            image_url: ``"<bucket>/<object name>"``; the text before the first
                '/' is used as the bucket and the remainder as the object name.

        Returns:
            The image converted BGR -> RGB and resized to
            (_IMAGE_SIZE, _IMAGE_SIZE). If storage raises ``S3Error`` a random
            noise image of the same size is returned instead.
        """
        try:
            image_bgr = oss_get_image(
                bucket=image_url.split('/')[0],
                object_name=image_url[image_url.find('/') + 1:],
                data_type="cv2",
            )
            image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
            return cv2.resize(image_rgb, (_IMAGE_SIZE, _IMAGE_SIZE))
        except minio.error.S3Error:
            # Deliberate best effort: keep the task going, but leave a trace.
            logger.warning(
                "could not load %s from object storage, using a random image",
                image_url,
                exc_info=True,
            )
            return self._random_image()

    def callback(self, result, error):
        """Completion callback handed to ``async_infer``.

        Args:
            result: inference result; its ``generated_image`` output is read.
            error: falsy on success, otherwise the inference error.

        An exception raised while post-processing or uploading is recorded as
        FAILURE rather than propagated, because this method is invoked by the
        inference client and nobody upstream would see the exception; the task
        would otherwise stay PENDING.
        """
        if error:
            self._fail(str(error))
            return
        try:
            self._store_result(result)
        except Exception as exc:
            logger.exception("post-processing failed for task %s", self.tasks_id)
            self._fail(str(exc))

    def _store_result(self, result):
        """Post-process the generated image, upload it and record SUCCESS."""
        generated = result.as_numpy("generated_image")
        # The rest of the pipeline works with cv2, so convert RGB -> BGR.
        image_result = cv2.cvtColor(np.squeeze(generated.astype(np.uint8)), cv2.COLOR_RGB2BGR)

        # True means the image may be uploaded. Judging by the branch comments
        # of the original code, stain_detection() reports True for a clean
        # image -- TODO confirm against its implementation.
        passed_stain_check = True
        if self.category == "sketch":
            recognition_input = image_result
            if self.version == "fast":
                # Levels adjustment -> brightness adjustment -> background removal.
                levels_img = autoLevels(image_result, 1)
                luminance = luminance_adjust(0.3, levels_img)
                recognition_input = remove_background(luminance)
                # Stain detection (face detection is currently disabled).
                passed_stain_check, _ = stain_detection(
                    recognition_input, self.user_id, self.category, self.tasks_id
                )
            # NOTE(review): sketches from the non-fast model skip clean-up and
            # stain detection and only go through category recognition --
            # confirm this is intended.
            category, _, image_result = generate_category_recognition(
                image=recognition_input, gender=self.gender
            )
            self.generate_data['category'] = str(category)

        if passed_stain_check:
            image_url = upload_png_sd(
                image_result,
                user_id=self.user_id,
                category=f"{self.category}",
                file_name=f"{self.tasks_id}.png",
            )
        else:
            # Stained image: report the system placeholder image instead.
            image_url = GI_SYS_IMAGE_URL
        self.generate_data['status'] = "SUCCESS"
        self.generate_data['message'] = "success"
        self.generate_data['image_url'] = str(image_url)
        self._save_status()

    def read_tasks_status(self):
        """Read the status record from Redis.

        Returns:
            ``(parsed, raw)``: the decoded JSON object and the raw string.

        Raises:
            TypeError: if the key no longer exists (Redis returned ``None``).
        """
        status_data = self.redis_client.get(self.tasks_id)
        return json.loads(status_data), status_data

    def get_result(self):
        """Submit the inference request and wait for the task to finish.

        Returns:
            The last status record read from Redis. It can still be PENDING if
            the task did not finish within the polling window.

        Raises:
            Exception: on any failure; the status is set to FAILURE first.

        Whatever the outcome, the final status is published to the message
        queue unless ``DEBUG`` is set.
        """
        try:
            text_obj = np.array([self.prompt] * self.batch_size, dtype="object").reshape((-1, 1))
            mode_obj = np.array([self.mode] * self.batch_size, dtype="object").reshape((-1, 1))
            image_obj = np.array(
                [self.image.astype(np.float16)] * self.batch_size, dtype=np.float16
            ).reshape((-1, _IMAGE_SIZE, _IMAGE_SIZE, 3))

            inputs = []
            for input_name, array in (
                ("prompt", text_obj),
                ("input_image", image_obj),
                ("mode", mode_obj),
            ):
                infer_input = grpcclient.InferInput(
                    input_name, array.shape, np_to_triton_dtype(array.dtype)
                )
                infer_input.set_data_from_numpy(array)
                inputs.append(infer_input)

            model_name = FAST_GI_MODEL_NAME if self.version == "fast" else GI_MODEL_NAME
            ctx = self.grpc_client.async_infer(
                model_name=model_name, inputs=inputs, callback=self.callback, priority=1
            )

            # NOTE(review): when the polling window elapses the request is not
            # cancelled and the PENDING record is returned -- confirm intended.
            generate_data = None
            for _attempt in range(_POLL_ATTEMPTS):
                generate_data, _raw = self.read_tasks_status()
                status = generate_data['status']
                if status in ("REVOKED", "FAILURE"):
                    ctx.cancel()
                    break
                if status == "SUCCESS":
                    break
                time.sleep(_POLL_INTERVAL_SECONDS)
            return generate_data
        except Exception as e:
            self._fail(str(e))
            # Same exception type and message as before; chaining keeps the
            # original traceback available.
            raise Exception(str(e)) from e
        finally:
            _, str_generate_data = self.read_tasks_status()
            if not DEBUG:
                publish_status(str_generate_data, GI_RABBITMQ_QUEUES)


def infer_cancel(tasks_id):
    """Mark the task as REVOKED in Redis and return the record written.

    ``GenerateImage.get_result`` cancels the in-flight request when it reads
    this status.
    """
    redis_client = redis.StrictRedis(
        host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True
    )
    data = {'tasks_id': tasks_id, 'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'}
    # Apply the TTL here as well so a revoked record does not live forever.
    redis_client.set(tasks_id, json.dumps(data), ex=_STATUS_TTL_SECONDS)
    return data


if __name__ == '__main__':
    rd = GenerateImageModel(
        tasks_id="123-89",
        prompt="Women's clothing ,dress,technical drawing style, clean line art, no shading, no texture, flat sketch, no human body, no face, centered composition, pure white background, single garmentsingle garment only, front flat view",
        image_url="aida-collection-element/87/Printboard/842c09cf-7297-42d9-9e6e-9c17d4a13cb5.jpg",
        mode='txt2img',
        category="test",
        gender="male",
        version="high"
    )
    server = GenerateImage(rd)
    print(server.get_result())