Files
AiDA_Python/app/service/generate_image/service_generate_image.py

198 lines
9.1 KiB
Python
Raw Normal View History

2024-04-15 18:07:25 +08:00
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project trinity_client
@File service_generate_image.py
2024-04-15 18:07:25 +08:00
@Author 周成融
@Date 2023/7/26 12:01:05
@detail
"""
import json
import logging
import time
2024-04-16 16:36:17 +08:00
import cv2
import minio
import numpy as np
2024-04-15 18:07:25 +08:00
import redis
import tritonclient.grpc as grpcclient
from tritonclient.utils import np_to_triton_dtype
2024-04-15 18:07:25 +08:00
from app.core.config import *
from app.schemas.generate_image import GenerateImageModel
from app.service.generate_image.utils.image_processing import remove_background, stain_detection, generate_category_recognition, autoLevels, luminance_adjust
from app.service.generate_image.utils.mq import publish_status
2024-06-21 17:13:39 +08:00
from app.service.generate_image.utils.upload_sd_image import upload_png_sd
from app.service.utils.oss_client import oss_get_image
2024-04-15 18:07:25 +08:00
# Module-level logger named after this module so it can be configured on its own;
# records still propagate to the root logger's handlers.
logger = logging.getLogger(__name__)
class GenerateImage:
    """Run one image-generation job on Triton and track its status in Redis.

    The status is stored in Redis under ``tasks_id`` as a JSON object with the keys
    ``tasks_id``, ``status``, ``message``, ``image_url`` and ``category``.
    ``get_result`` submits the request asynchronously and polls that entry until the
    job succeeds, fails, is revoked, or the polling budget runs out. Unless ``DEBUG``
    is set, it then publishes the final status string with ``publish_status``.
    """

    # Lifetime of the Redis status entry. It is re-applied on every write because a
    # plain Redis SET discards any TTL previously set on the key.
    _STATUS_TTL_SECONDS = 600
    # Polling budget for get_result: 600 polls x 0.1 s, i.e. roughly 60 seconds.
    _MAX_POLLS = 600
    _POLL_INTERVAL_SECONDS = 0.1

    def __init__(self, request_data):
        """Create the Triton and Redis clients, prepare the input image, and mark the task PENDING.

        Args:
            request_data: Request object (e.g. ``GenerateImageModel``) providing ``version``,
                ``mode``, ``image_url``, ``prompt``, ``tasks_id``, ``category`` and ``gender``.
        """
        self.version = request_data.version
        model_url = FAST_GI_MODEL_URL if request_data.version == "fast" else GI_MODEL_URL
        self.grpc_client = grpcclient.InferenceServerClient(url=model_url)
        self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
        if request_data.mode == "img2img":
            # cv2 decodes images as BGR (PIL would give RGB); get_image returns RGB.
            self.image = self.get_image(request_data.image_url)
        else:
            # The model always receives an input_image tensor; other modes send random noise.
            self.image = np.random.randint(0, 256, (1024, 1024, 3), dtype=np.uint8)
        self.prompt = request_data.prompt
        self.tasks_id = request_data.tasks_id
        # User id = text after the last '-' of the task id (the whole id if it has no '-').
        self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:]
        self.mode = request_data.mode
        self.batch_size = 1
        self.category = request_data.category
        if self.category == "sketch":
            self.prompt = f"{self.category},{self.prompt}"
        self.index = 0  # not read anywhere in this file
        self.gender = request_data.gender
        self.generate_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'image_url': '', 'category': ''}
        self._save_status()

    def _save_status(self, keep_revoked=False):
        """Write ``self.generate_data`` to Redis and (re)apply the status TTL.

        Args:
            keep_revoked: If True, leave the entry untouched when it already says REVOKED,
                so a late inference callback cannot overwrite a user cancellation.
        """
        if keep_revoked and self._is_revoked():
            return
        self.redis_client.set(self.tasks_id, json.dumps(self.generate_data), ex=self._STATUS_TTL_SECONDS)

    def _is_revoked(self):
        """Return True if the Redis status entry for this task currently says REVOKED."""
        raw = self.redis_client.get(self.tasks_id)
        if not raw:
            return False
        try:
            stored = json.loads(raw)
        except ValueError:
            return False
        return isinstance(stored, dict) and stored.get('status') == "REVOKED"

    def get_image(self, image_url):
        """Download an image from object storage and return it as a 1024x1024 RGB array.

        Args:
            image_url: ``"<bucket>/<object name>"``; the text before the first '/' is the bucket.

        Returns:
            The resized RGB image. If storage raises ``minio.error.S3Error``, returns
            1024x1024x3 random uint8 noise instead.
        """
        try:
            image_bgr = oss_get_image(bucket=image_url.split('/')[0], object_name=image_url[image_url.find('/') + 1:], data_type="cv2")
            image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
            image = cv2.resize(image_rgb, (1024, 1024))
        except minio.error.S3Error:
            image = np.random.randint(0, 256, (1024, 1024, 3), dtype=np.uint8)
        return image

    def callback(self, result, error):
        """Completion callback for ``async_infer``: post-process, upload, and record the outcome.

        The Triton client calls this when the request completes, presumably on a gRPC worker
        thread. Exceptions raised here would never reach ``get_result``, which would then poll
        a PENDING status until its budget runs out. They are therefore caught and recorded as
        FAILURE. A status that is already REVOKED is never overwritten.

        Args:
            result: Inference result exposing ``as_numpy("generated_image")``; unused on error.
            error: The inference error, or a falsy value on success.
        """
        if error:
            self.generate_data['status'] = "FAILURE"
            self.generate_data['message'] = str(error)
            self._save_status(keep_revoked=True)
            return
        try:
            image_result, is_smudge = self._postprocess(result)
            if is_smudge:
                # Per the original inline comment, is_smudge == True means "no stains": upload it.
                image_url = upload_png_sd(image_result, user_id=self.user_id, category=f"{self.category}", file_name=f"{self.tasks_id}.png")
            else:
                # Stains detected: report the configured system placeholder image instead.
                image_url = GI_SYS_IMAGE_URL
        except Exception as e:
            logger.exception("post-processing failed for task %s", self.tasks_id)
            self.generate_data['status'] = "FAILURE"
            self.generate_data['message'] = str(e)
        else:
            self.generate_data['status'] = "SUCCESS"
            self.generate_data['message'] = "success"
            self.generate_data['image_url'] = str(image_url)
        self._save_status(keep_revoked=True)

    def _postprocess(self, result):
        """Convert the model output to a BGR uint8 image and apply category-specific clean-up.

        When category recognition runs, its result is stored in ``generate_data['category']``.

        Returns:
            ``(image, is_smudge)``. ``is_smudge`` stays True unless stain detection runs.
        """
        image = result.as_numpy("generated_image")
        # Model output is RGB. Drop singleton axes (the batch of 1) and convert to BGR.
        image_result = cv2.cvtColor(np.squeeze(image.astype(np.uint8)), cv2.COLOR_RGB2BGR)
        is_smudge = True
        if self.category == "sketch":
            if self.version == "fast":
                # Auto levels (cutoff 1), brightness adjustment (0.3), then background removal.
                levels_img = autoLevels(image_result, 1)
                luminance = luminance_adjust(0.3, levels_img)
                remove_bg_image = remove_background(luminance)
                # Stain detection only decides is_smudge. NOTE(review): its cleaned image is
                # discarded in favour of the image returned by category recognition; confirm.
                is_smudge, _ = stain_detection(remove_bg_image, self.user_id, self.category, self.tasks_id)
                category, _, image_result = generate_category_recognition(image=remove_bg_image, gender=self.gender)
                self.generate_data['category'] = str(category)
            else:
                # NOTE(review): the reviewed copy had lost its indentation. This `else` is paired
                # with the `fast` check (non-fast sketches get recognition on the raw image,
                # non-sketch categories get no post-processing). Verify against version control.
                category, _, image_result = generate_category_recognition(image=image_result, gender=self.gender)
                self.generate_data['category'] = str(category)
        return image_result, is_smudge

    def read_tasks_status(self):
        """Return the current task status as ``(status_dict, status_json_string)``.

        If the Redis entry is missing (expired or deleted), fall back to this instance's
        in-memory ``generate_data`` instead of failing in ``json.loads(None)``.
        """
        status_data = self.redis_client.get(self.tasks_id)
        if status_data is None:
            status_data = json.dumps(self.generate_data)
        return json.loads(status_data), status_data

    def _build_inputs(self):
        """Build the Triton ``prompt``, ``input_image`` and ``mode`` inputs for one batch."""
        prompts = [self.prompt] * self.batch_size
        modes = [self.mode] * self.batch_size
        images = [self.image.astype(np.float16)] * self.batch_size
        text_obj = np.array(prompts, dtype="object").reshape((-1, 1))
        mode_obj = np.array(modes, dtype="object").reshape((-1, 1))
        image_obj = np.array(images, dtype=np.float16).reshape((-1, 1024, 1024, 3))
        inputs = []
        for name, data in (("prompt", text_obj), ("input_image", image_obj), ("mode", mode_obj)):
            tensor = grpcclient.InferInput(name, data.shape, np_to_triton_dtype(data.dtype))
            tensor.set_data_from_numpy(data)
            inputs.append(tensor)
        return inputs

    def get_result(self):
        """Submit the inference request and wait for the task to finish.

        Polls the Redis status every 0.1 s, up to 600 times (about 60 s). The in-flight
        inference is cancelled once the status becomes REVOKED or FAILURE. Whatever the
        outcome, the final Redis status string is then published to ``GI_RABBITMQ_QUEUES``
        unless ``DEBUG`` is set.

        Returns:
            The last status dict read. NOTE(review): on timeout this is typically still
            PENDING, and the inference is not cancelled.

        Raises:
            Exception: If building/submitting the request or polling fails. The task is
                marked FAILURE first, and the original error is chained as the cause.
        """
        try:
            inputs = self._build_inputs()
            model_name = FAST_GI_MODEL_NAME if self.version == "fast" else GI_MODEL_NAME
            ctx = self.grpc_client.async_infer(model_name=model_name, inputs=inputs, callback=self.callback, priority=1)
            generate_data = None
            for _poll in range(self._MAX_POLLS):
                generate_data, _raw = self.read_tasks_status()
                if generate_data['status'] in ["REVOKED", "FAILURE"]:
                    ctx.cancel()
                    break
                if generate_data['status'] == "SUCCESS":
                    break
                time.sleep(self._POLL_INTERVAL_SECONDS)
            return generate_data
        except Exception as e:
            self.generate_data['status'] = "FAILURE"
            self.generate_data['message'] = str(e)
            self._save_status()
            raise Exception(str(e)) from e
        finally:
            _status, status_json = self.read_tasks_status()
            if not DEBUG:
                publish_status(status_json, GI_RABBITMQ_QUEUES)
2024-04-15 18:07:25 +08:00
def infer_cancel(tasks_id):
    """Mark a generation task as REVOKED in Redis.

    ``GenerateImage.get_result`` polls this entry and cancels the in-flight inference when
    it sees REVOKED. The key gets the same 600 s TTL that ``GenerateImage`` applies when it
    creates the entry; a plain SET would otherwise leave it with no expiry at all.

    Args:
        tasks_id: Id of the task to revoke (the Redis key).

    Returns:
        The status dict that was stored.
    """
    redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
    data = {'tasks_id': tasks_id, 'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'}
    redis_client.set(tasks_id, json.dumps(data), ex=600)
    return data
if __name__ == '__main__':
    # Manual smoke test: run one txt2img job on the non-fast ("high") model and print the
    # final status. image_url is only downloaded in img2img mode, so it is unused here.
    rd = GenerateImageModel(
        tasks_id="123-89",
        prompt="Women's clothing ,dress,technical drawing style, clean line art, no shading, no texture, flat sketch, no human body, no face, centered composition, pure white background, single garment only, front flat view",
        image_url="aida-collection-element/87/Printboard/842c09cf-7297-42d9-9e6e-9c17d4a13cb5.jpg",
        mode='txt2img',
        category="test",
        gender="male",
        version="high",
    )
    server = GenerateImage(rd)
    print(server.get_result())