diff --git a/app/service/generate_image/service.py b/app/service/generate_image/service.py index 45346bf..6f8d092 100644 --- a/app/service/generate_image/service.py +++ b/app/service/generate_image/service.py @@ -22,19 +22,25 @@ from tritonclient.utils import np_to_triton_dtype from app.core.config import * from app.schemas.generate_image import GenerateImageModel -from app.service.generate_image.utils.upload_sd_image import upload_png_sd from app.service.generate_image.utils.adjust_contrast import adjust_contrast +from app.service.generate_image.utils.image_processing import remove_background, stain_detection, generate_category_recognition, autoLevels, luminance_adjust, face_detect_pic +from app.service.generate_image.utils.upload_sd_image import upload_png_sd, upload_stain_png_sd + logger = logging.getLogger() class GenerateImage: def __init__(self, request_data): + if DEBUG is False: + self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS)) + self.channel = self.connection.channel() + # self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS)) + # self.channel = self.connection.channel() self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) self.grpc_client = grpcclient.InferenceServerClient(url=GI_MODEL_URL) self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True) - self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS)) - self.channel = self.connection.channel() if request_data.mode == "img2img": + # cv2 读图片是BGR PIL读图片是RGB self.image = self.get_image(request_data.image_url) self.prompt = request_data.prompt else: @@ -47,36 +53,69 @@ class GenerateImage: self.batch_size = 1 self.category = request_data.category self.index = 0 - self.generate_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'data': ''} + self.gender = request_data.gender + self.generate_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'image_url': '', 'category': ''} self.redis_client.set(self.tasks_id, json.dumps(self.generate_data)) self.redis_client.expire(self.tasks_id, 600) def get_image(self, image_url): # Get data of an object. # Read data from response. + # read image use cv2 try: response = self.minio_client.get_object(image_url.split('/')[0], image_url[image_url.find('/') + 1:]) image_file = BytesIO(response.data) image_array = np.asarray(bytearray(image_file.read()), dtype=np.uint8) image_cv2 = cv2.imdecode(image_array, cv2.IMREAD_COLOR) + image_rbg = cv2.cvtColor(image_cv2, cv2.COLOR_BGR2RGB) + image = cv2.resize(image_rbg, (1024, 1024)) except minio.error.S3Error: - image_cv2 = np.random.randint(0, 256, (1024, 1024, 3), dtype=np.uint8) - return image_cv2 + image = np.random.randint(0, 256, (1024, 1024, 3), dtype=np.uint8) + return image def callback(self, result, error): if error: self.generate_data['status'] = "FAILURE" self.generate_data['message'] = str(error) - self.generate_data['data'] = str(error) + # self.generate_data['data'] = str(error) self.redis_client.set(self.tasks_id, json.dumps(self.generate_data)) else: - image_result = result.as_numpy("generated_image")[0] - image_result = adjust_contrast(image_result) - image_url = upload_png_sd(image_result, user_id=self.user_id, category=f"{self.category}", object_name=f"{self.tasks_id}.png") - self.generate_data['status'] = "SUCCESS" - self.generate_data['message'] = "success" - self.generate_data['data'] = str(image_url) - self.redis_client.set(self.tasks_id, json.dumps(self.generate_data)) + # pil图像转成numpy数组 + image = result.as_numpy("generated_image") + image_result = cv2.cvtColor(np.squeeze(image.astype(np.uint8)), cv2.COLOR_RGB2BGR) + is_smudge = True + if self.category == "sketch": + # 色阶调整 + cutoff = 1 + levels_img = autoLevels(image_result, cutoff) + # 亮度调整 + luminance = luminance_adjust(0.3, levels_img) + # 去背景 + remove_bg_image = remove_background(luminance) + # 人脸检测 + # if face_detect_pic(remove_bg_image, self.user_id, self.category, self.tasks_id) > 0: + # is_smudge = False + # else: + # 污点/ + is_smudge, not_smudge_image = stain_detection(remove_bg_image, self.user_id, self.category, self.tasks_id) + # 类型识别 + category, scores, not_smudge_image = generate_category_recognition(image=remove_bg_image, gender=self.gender) + self.generate_data['category'] = str(category) + image_result = not_smudge_image + if is_smudge: # 无污点 + # image_result = adjust_contrast(image_result) + image_url = upload_png_sd(image_result, user_id=self.user_id, category=f"{self.category}", object_name=f"{self.tasks_id}.png") + # logger.info(f"upload image SUCCESS : {image_url}") + self.generate_data['status'] = "SUCCESS" + self.generate_data['message'] = "success" + self.generate_data['image_url'] = str(image_url) + self.redis_client.set(self.tasks_id, json.dumps(self.generate_data)) + else: # 有污点 保存图片到本地 测试用 + self.generate_data['status'] = "SUCCESS" + self.generate_data['message'] = "success" + self.generate_data['image_url'] = str(GI_SYS_IMAGE_URL) + self.redis_client.set(self.tasks_id, json.dumps(self.generate_data)) + # logger.info(f"stain_detection result : {self.generate_data}") def read_tasks_status(self): status_data = self.redis_client.get(self.tasks_id) @@ -109,10 +148,11 @@ class GenerateImage: inputs = [input_text, input_image, input_mode] ctx = self.infer(inputs) - time_out = 60 + time_out = 600 generate_data = None while time_out > 0: generate_data, _ = self.read_tasks_status() + # logger.info(generate_data) if generate_data['status'] in ["REVOKED", "FAILURE"]: ctx.cancel() break @@ -120,16 +160,18 @@ class GenerateImage: break time_out -= 1 time.sleep(0.1) + # logger.info(time_out, generate_data) return generate_data except Exception as e: self.generate_data['status'] = "FAILURE" - self.generate_data['message'] = "failure" - self.generate_data['data'] = str(e) + self.generate_data['message'] = str(e) self.redis_client.set(self.tasks_id, json.dumps(self.generate_data)) raise Exception(str(e)) finally: dict_generate_data, str_generate_data = self.read_tasks_status() - self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=str_generate_data) + if DEBUG is False: + self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=str_generate_data) + # self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=str_generate_data) logger.info(f" [x] Sent {json.dumps(dict_generate_data, indent=4)}")