feat 新增对比度

This commit is contained in:
zchen
2024-04-23 20:45:34 +08:00
parent abff47b264
commit 04397c3b6e
3 changed files with 156 additions and 179 deletions

View File

@@ -22,6 +22,7 @@ from tritonclient.utils import np_to_triton_dtype
from app.core.config import * from app.core.config import *
from app.schemas.generate_image import GenerateImageModel from app.schemas.generate_image import GenerateImageModel
from app.service.generate_image.utils.adjust_contrast import adjust_contrast
from app.service.generate_image.utils.image_processing import remove_background, stain_detection from app.service.generate_image.utils.image_processing import remove_background, stain_detection
from app.service.generate_image.utils.upload_sd_image import upload_png_sd from app.service.generate_image.utils.upload_sd_image import upload_png_sd
@@ -84,6 +85,7 @@ class GenerateImage:
is_smudge, not_smudge_image = stain_detection(remove_bg_image) is_smudge, not_smudge_image = stain_detection(remove_bg_image)
image_result = not_smudge_image image_result = not_smudge_image
if is_smudge: # 无污点 if is_smudge: # 无污点
image_result = adjust_contrast(image_result)
image_url = upload_png_sd(image_result, user_id=self.user_id, category=f"{self.category}", object_name=f"{self.tasks_id}.png") image_url = upload_png_sd(image_result, user_id=self.user_id, category=f"{self.category}", object_name=f"{self.tasks_id}.png")
# logger.info(f"upload image SUCCESS {image_url}") # logger.info(f"upload image SUCCESS {image_url}")
self.generate_data['status'] = "SUCCESS" self.generate_data['status'] = "SUCCESS"

View File

@@ -9,224 +9,169 @@
""" """
import json import json
import logging import logging
import minio
import numpy as np
import random
import redis
import tritonclient
import tritonclient.grpc as grpc_client
from minio import Minio
import cv2
from PIL import Image
import time import time
from io import BytesIO
import cv2
import minio
import redis
import tritonclient.grpc as grpcclient
import numpy as np
from minio import Minio
from tritonclient.utils import np_to_triton_dtype
from app.core.config import * from app.core.config import *
from app.schemas.generate_image import GenerateImageModel from app.schemas.generate_image import GenerateImageModel
from app.service.generate_image.utils.image_processing import remove_background from app.service.generate_image.utils.adjust_contrast import adjust_contrast
from app.service.generate_image.utils.image_processing import remove_background, stain_detection
from app.service.generate_image.utils.upload_sd_image import upload_png_sd from app.service.generate_image.utils.upload_sd_image import upload_png_sd
from app.service.utils.decorator import RunTime
from app.service.utils.generate_uuid import generate_uuid
logger = logging.getLogger() logger = logging.getLogger()
class GenerateImage: class GenerateImage:
def __init__(self, request_data): def __init__(self, request_data):
self.tasks_id = request_data.tasks_id if DEBUG is False:
self.model = request_data.model self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
self.request_count = request_data.request_count self.channel = self.connection.channel()
self.prompt = request_data.prompt # self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
self.image = request_data.image # self.channel = self.connection.channel()
self.mode = request_data.mode
self.batch_size = request_data.batch_size
self.image_url = request_data.image_url
self.user_id = request_data.user_id
self.content = request_data.content
self.category = request_data.category
self.model_name = f"{self.category}{GI_MODEL_NAME}"
self.mode = request_data.mode
self.version = request_data.version
self.triton_client = grpc_client.InferenceServerClient(url="1")
self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
self.channel = self.connection.channel()
self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
self.samples = 4 # no.of images to generate self.grpc_client = grpcclient.InferenceServerClient(url=GI_MODEL_URL)
self.steps = 24 self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
self.guidance_scale = 7 if request_data.mode == "img2img":
self.seed = random.randint(0, 2000000000) self.image = self.get_image(request_data.image_url)
self.batch_size = 1 self.prompt = request_data.prompt
self.generate_data = json.dumps({'status': 'PENDING', 'message': "pending", 'data': ''})
self.redis_client.set(self.tasks_id, self.generate_data)
def get_result(self):
pass
@staticmethod
def image_grid(imgs, rows, cols):
assert len(imgs) == rows * cols
w, h = imgs[0].size
grid = Image.new('RGB', size=(cols * w, rows * h))
for i, img in enumerate(imgs):
grid.paste(img, box=(i % cols * w, i // cols * h))
return grid
@staticmethod
def preprocess_image(image, category):
height, width, _ = image.shape
if category == "print" or category == "moodboard":
square_size = min(height, width)
start_x = (width - square_size) // 2
start_y = (height - square_size) // 2
cropped = image[start_y: start_y + square_size, start_x: start_x + square_size]
resized_image = cv2.resize(cropped, (512, 512))
elif category == "sketch":
# below is the way that get "bigger" square image.
max_dimension = max(height, width)
square_image = np.ones((max_dimension, max_dimension, 3), dtype=np.uint8) * 255
start_h = (max_dimension - height) // 2
start_w = (max_dimension - width) // 2
square_image[start_h:start_h + height, start_w:start_w + width] = image
resized_image = cv2.resize(square_image, (512, 512))
else: else:
raise ValueError(f"wrong category {category}, only in moodboard, print and sketch!") self.image = np.random.randint(0, 256, (1024, 1024, 3), dtype=np.uint8)
self.prompt = request_data.prompt
return resized_image self.tasks_id = request_data.tasks_id
self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:]
self.mode = request_data.mode
self.batch_size = 1
self.category = request_data.category
self.index = 0
self.generate_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'data': ''}
self.redis_client.set(self.tasks_id, json.dumps(self.generate_data))
self.redis_client.expire(self.tasks_id, 600)
def get_image(self): def get_image(self, image_url):
# Get data of an object. # Get data of an object.
# Read data from response. # Read data from response.
try: try:
response = self.minio_client.get_object(self.image_url.split('/')[0], self.image_url[self.image_url.find('/') + 1:]) response = self.minio_client.get_object(image_url.split('/')[0], image_url[image_url.find('/') + 1:])
img = np.frombuffer(response.data, np.uint8) # 转成8位无符号整型 image_file = BytesIO(response.data)
img = cv2.imdecode(img, cv2.IMREAD_COLOR) # 解码 image_array = np.asarray(bytearray(image_file.read()), dtype=np.uint8)
img = self.preprocess_image(img, self.category) image_cv2 = cv2.imdecode(image_array, cv2.IMREAD_COLOR)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) image = cv2.resize(image_cv2, (1024, 1024))
except minio.error.S3Error: except minio.error.S3Error:
img = np.random.randn(512, 512, 3) image = np.random.randint(0, 256, (1024, 1024, 3), dtype=np.uint8)
return img return image
def callback(self, result, error): def callback(self, result, error):
if error: if error:
generate_data = json.dumps({'status': 'FAILURE', 'message': f"{error}", 'data': f"{error}"}) self.generate_data['status'] = "FAILURE"
self.redis_client.set(self.tasks_id, generate_data) self.generate_data['message'] = str(error)
self.generate_data['data'] = str(error)
self.redis_client.set(self.tasks_id, json.dumps(self.generate_data))
else: else:
images = result.as_numpy("IMAGES") image_result = result.as_numpy("generated_image")[0]
if images.ndim == 3: is_smudge = True
images = images[None, ...] if self.category == "sketch":
images = (images * 255).round().astype("uint8") # 去背景
pil_images = [Image.fromarray(image) for image in images] remove_bg_image = remove_background(np.asarray(image_result))
# 污点检测
# for i in range(len(pil_images)): is_smudge, not_smudge_image = stain_detection(remove_bg_image)
# pil = pil_images[i] image_result = not_smudge_image
# pil.save(f'./temp_i2_{i}.png') if is_smudge: # 无污点
# self.image_grid(pil_images, rows, cols) image_result = adjust_contrast(image_result)
url_list = [] image_url = upload_png_sd(image_result, user_id=self.user_id, category=f"{self.category}", object_name=f"{self.tasks_id}.png")
for i, image in enumerate(pil_images): # logger.info(f"upload image SUCCESS {image_url}")
self.generate_data['status'] = "SUCCESS"
if self.category == "sketch": self.generate_data['message'] = "success"
image = remove_background(np.asarray(image)) self.generate_data['data'] = str(image_url)
image_url = upload_png_sd(image, user_id=self.user_id, category=f"{self.category}", object_name=f"{generate_uuid()}_{i}.png", ) self.redis_client.set(self.tasks_id, json.dumps(self.generate_data))
url_list.append(image_url) else: # 有污点
generate_data = json.dumps({'status': 'SUCCESS', 'message': 'success', 'data': f'{url_list}'}) self.generate_data['status'] = "SUCCESS"
self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=generate_data) self.generate_data['message'] = "success"
logger.info(f" [x] Sent {generate_data}") self.generate_data['data'] = str(GI_SYS_IMAGE_URL)
self.redis_client.set(self.tasks_id, generate_data) self.redis_client.set(self.tasks_id, json.dumps(self.generate_data))
# logger.info(f"stain_detection result : {self.generate_data}")
def read_tasks_status(self): def read_tasks_status(self):
status_data = json.loads(self.redis_client.get(self.tasks_id)) status_data = self.redis_client.get(self.tasks_id)
logging.info(f"{self.tasks_id} ===> {status_data}") return json.loads(status_data), status_data
return status_data
def infer(self, inputs):
return self.grpc_client.infer(
model_name=GI_MODEL_NAME,
inputs=inputs,
# callback=self.callback
)
# @RunTime
def get_result(self): def get_result(self):
self.triton_client.get_model_metadata(model_name=self.model_name, model_version=self.version) try:
self.triton_client.get_model_config(model_name=self.model_name, model_version=self.version) prompts = [self.prompt] * self.batch_size
modes = [self.mode] * self.batch_size
images = [self.image.astype(np.float16)] * self.batch_size
image = self.get_image() text_obj = np.array(prompts, dtype="object").reshape((-1, 1))
mode_obj = np.array(modes, dtype="object").reshape((-1, 1))
image_obj = np.array(images, dtype=np.float16).reshape((-1, 1024, 1024, 3))
# Input placeholder input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype))
prompt_in = tritonclient.grpc.InferInput(name="PROMPT", shape=(self.batch_size,), datatype="BYTES") input_image = grpcclient.InferInput("input_image", image_obj.shape, "FP16")
samples_in = tritonclient.grpc.InferInput("SAMPLES", (self.batch_size,), "INT32") input_mode = grpcclient.InferInput("mode", mode_obj.shape, np_to_triton_dtype(text_obj.dtype))
steps_in = tritonclient.grpc.InferInput("STEPS", (self.batch_size,), "INT32")
guidance_scale_in = tritonclient.grpc.InferInput("GUIDANCE_SCALE", (self.batch_size,), "FP32")
seed_in = tritonclient.grpc.InferInput("SEED", (self.batch_size,), "INT64")
input_images_in = tritonclient.grpc.InferInput("INPUT_IMAGES", image.shape, "FP16")
images = tritonclient.grpc.InferRequestedOutput(name="IMAGES",
# binary_data=False
)
mode_in = tritonclient.grpc.InferInput("MODE", (self.batch_size,), "INT32")
# Setting inputs input_text.set_data_from_numpy(text_obj)
prompt_in.set_data_from_numpy(np.asarray([self.content] * self.batch_size, dtype=object)) input_image.set_data_from_numpy(image_obj)
samples_in.set_data_from_numpy(np.asarray([self.samples], dtype=np.int32)) input_mode.set_data_from_numpy(mode_obj)
steps_in.set_data_from_numpy(np.asarray([self.steps], dtype=np.int32))
guidance_scale_in.set_data_from_numpy(np.asarray([self.guidance_scale], dtype=np.float32))
seed_in.set_data_from_numpy(np.asarray([self.seed], dtype=np.int64))
input_images_in.set_data_from_numpy(image.astype(np.float16))
mode_in.set_data_from_numpy(np.asarray([self.mode], dtype=np.int32))
# inference inputs = [input_text, input_image, input_mode]
# @RunTime ctx = self.infer(inputs)
def infer(): time_out = 600
return self.triton_client.async_infer( generate_data = None
model_name=self.model_name, while time_out > 0:
model_version=self.version, generate_data, _ = self.read_tasks_status()
inputs=[prompt_in, samples_in, steps_in, guidance_scale_in, seed_in, input_images_in, mode_in], # logger.info(generate_data)
outputs=[images], if generate_data['status'] in ["REVOKED", "FAILURE"]:
callback=self.callback ctx.cancel()
) break
elif generate_data['status'] == "SUCCESS":
ctx = infer() break
time_out = 60 time_out -= 1
while time_out > 0: time.sleep(0.1)
generate_data = self.read_tasks_status() # logger.info(time_out, generate_data)
if generate_data['status'] in ["REVOKED", "FAILURE"]: return generate_data
ctx.cancel() except Exception as e:
self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=json.dumps(generate_data)) # self.generate_data['status'] = "FAILURE"
logger.info(f" [x] Sent {generate_data}") # self.generate_data['message'] = "failure"
break # self.generate_data['data'] = str(e)
elif generate_data['status'] == "SUCCESS": # self.redis_client.set(self.tasks_id, json.dumps(self.generate_data))
break raise Exception(str(e))
time_out -= 1 # finally:
time.sleep(1) # dict_generate_data, str_generate_data = self.read_tasks_status()
return self.read_tasks_status() # if DEBUG is False:
# self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=str_generate_data)
# logger.info(f" [x] Sent {json.dumps(dict_generate_data, indent=4)}")
def infer_cancel(tasks_id): def infer_cancel(tasks_id):
redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True) redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
data = {'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'} data = {'tasks_id': tasks_id, 'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'}
generate_data = json.dumps({'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'}) generate_data = json.dumps(data)
redis_client.set(tasks_id, generate_data) redis_client.set(tasks_id, generate_data)
return data return data
if __name__ == '__main__': if __name__ == '__main__':
# request_data = {
# "user_id": 78,
# "image_url": "123_123.png",
# "category": "print",
# "mode": 1,
# "str": "a simple print",
# "version": "1"
# }
rd = GenerateImageModel( rd = GenerateImageModel(
mode=1, tasks_id="123-89",
content='a blouse', prompt='skeleton sitting by the side of a river looking soulful, concert poster, 4k, artistic',
gender='', image_url="",
user_id=89, mode='txt2img',
image_url='test/微信图片_20231206133428.jpg', category="test"
category='sketch',
version='1',
tasks_id='123456'
) )
server = GenerateImage(rd) server = GenerateImage(rd)
server.get_result() print(server.get_result())
# print(infer_cancel(123456))

View File

@@ -0,0 +1,30 @@
import cv2
def adjust_contrast(image, alpha=1.5, beta=-60):
"""
调整图像的对比度和亮度。
参数:
image_path (numpy): 图像的路径。
alpha (float): 控制对比度的系数。alpha > 1 增加对比度alpha < 1 减少对比度。
beta (int): 用于调整亮度的值,可以是正或负。
返回:
adjusted_image (ndarray): 调整对比度后的图像。
"""
adjusted_image = cv2.convertScaleAbs(image, alpha=alpha, beta=beta)
return adjusted_image
# 使用示例
if __name__ == "__main__":
image = cv2.imread('output_6.png') # 替换为你的图片路径
img_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
alpha = 1.5 # 对比度系数大于1增加对比度
beta = -60 # 亮度调整这里设置为0不改变亮度
# 调整图像对比度
result_image = adjust_contrast(image, alpha, beta)
# 可以选择保存调整后的图像
cv2.imwrite('adjusted_image.jpg', result_image) # 保存调整后的图片