feat 新增对比度

This commit is contained in:
zchen
2024-04-23 20:45:34 +08:00
parent abff47b264
commit 04397c3b6e
3 changed files with 156 additions and 179 deletions

View File

@@ -22,6 +22,7 @@ from tritonclient.utils import np_to_triton_dtype
from app.core.config import *
from app.schemas.generate_image import GenerateImageModel
from app.service.generate_image.utils.adjust_contrast import adjust_contrast
from app.service.generate_image.utils.image_processing import remove_background, stain_detection
from app.service.generate_image.utils.upload_sd_image import upload_png_sd
@@ -84,6 +85,7 @@ class GenerateImage:
is_smudge, not_smudge_image = stain_detection(remove_bg_image)
image_result = not_smudge_image
if is_smudge: # 无污点
image_result = adjust_contrast(image_result)
image_url = upload_png_sd(image_result, user_id=self.user_id, category=f"{self.category}", object_name=f"{self.tasks_id}.png")
# logger.info(f"upload image SUCCESS {image_url}")
self.generate_data['status'] = "SUCCESS"

View File

@@ -9,224 +9,169 @@
"""
import json
import logging
import minio
import numpy as np
import random
import redis
import tritonclient
import tritonclient.grpc as grpc_client
from minio import Minio
import cv2
from PIL import Image
import time
from io import BytesIO
import cv2
import minio
import redis
import tritonclient.grpc as grpcclient
import numpy as np
from minio import Minio
from tritonclient.utils import np_to_triton_dtype
from app.core.config import *
from app.schemas.generate_image import GenerateImageModel
from app.service.generate_image.utils.image_processing import remove_background
from app.service.generate_image.utils.adjust_contrast import adjust_contrast
from app.service.generate_image.utils.image_processing import remove_background, stain_detection
from app.service.generate_image.utils.upload_sd_image import upload_png_sd
from app.service.utils.decorator import RunTime
from app.service.utils.generate_uuid import generate_uuid
logger = logging.getLogger()
class GenerateImage:
def __init__(self, request_data):
self.tasks_id = request_data.tasks_id
self.model = request_data.model
self.request_count = request_data.request_count
self.prompt = request_data.prompt
self.image = request_data.image
self.mode = request_data.mode
self.batch_size = request_data.batch_size
self.image_url = request_data.image_url
self.user_id = request_data.user_id
self.content = request_data.content
self.category = request_data.category
self.model_name = f"{self.category}{GI_MODEL_NAME}"
self.mode = request_data.mode
self.version = request_data.version
self.triton_client = grpc_client.InferenceServerClient(url="1")
self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
self.channel = self.connection.channel()
if DEBUG is False:
self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
self.channel = self.connection.channel()
# self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
# self.channel = self.connection.channel()
self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
self.samples = 4 # no.of images to generate
self.steps = 24
self.guidance_scale = 7
self.seed = random.randint(0, 2000000000)
self.batch_size = 1
self.generate_data = json.dumps({'status': 'PENDING', 'message': "pending", 'data': ''})
self.redis_client.set(self.tasks_id, self.generate_data)
def get_result(self):
pass
@staticmethod
def image_grid(imgs, rows, cols):
assert len(imgs) == rows * cols
w, h = imgs[0].size
grid = Image.new('RGB', size=(cols * w, rows * h))
for i, img in enumerate(imgs):
grid.paste(img, box=(i % cols * w, i // cols * h))
return grid
@staticmethod
def preprocess_image(image, category):
height, width, _ = image.shape
if category == "print" or category == "moodboard":
square_size = min(height, width)
start_x = (width - square_size) // 2
start_y = (height - square_size) // 2
cropped = image[start_y: start_y + square_size, start_x: start_x + square_size]
resized_image = cv2.resize(cropped, (512, 512))
elif category == "sketch":
# below is the way that get "bigger" square image.
max_dimension = max(height, width)
square_image = np.ones((max_dimension, max_dimension, 3), dtype=np.uint8) * 255
start_h = (max_dimension - height) // 2
start_w = (max_dimension - width) // 2
square_image[start_h:start_h + height, start_w:start_w + width] = image
resized_image = cv2.resize(square_image, (512, 512))
self.grpc_client = grpcclient.InferenceServerClient(url=GI_MODEL_URL)
self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
if request_data.mode == "img2img":
self.image = self.get_image(request_data.image_url)
self.prompt = request_data.prompt
else:
raise ValueError(f"wrong category {category}, only in moodboard, print and sketch!")
self.image = np.random.randint(0, 256, (1024, 1024, 3), dtype=np.uint8)
self.prompt = request_data.prompt
return resized_image
self.tasks_id = request_data.tasks_id
self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:]
self.mode = request_data.mode
self.batch_size = 1
self.category = request_data.category
self.index = 0
self.generate_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'data': ''}
self.redis_client.set(self.tasks_id, json.dumps(self.generate_data))
self.redis_client.expire(self.tasks_id, 600)
def get_image(self):
def get_image(self, image_url):
# Get data of an object.
# Read data from response.
try:
response = self.minio_client.get_object(self.image_url.split('/')[0], self.image_url[self.image_url.find('/') + 1:])
img = np.frombuffer(response.data, np.uint8) # 转成8位无符号整型
img = cv2.imdecode(img, cv2.IMREAD_COLOR) # 解码
img = self.preprocess_image(img, self.category)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
response = self.minio_client.get_object(image_url.split('/')[0], image_url[image_url.find('/') + 1:])
image_file = BytesIO(response.data)
image_array = np.asarray(bytearray(image_file.read()), dtype=np.uint8)
image_cv2 = cv2.imdecode(image_array, cv2.IMREAD_COLOR)
image = cv2.resize(image_cv2, (1024, 1024))
except minio.error.S3Error:
img = np.random.randn(512, 512, 3)
return img
image = np.random.randint(0, 256, (1024, 1024, 3), dtype=np.uint8)
return image
def callback(self, result, error):
if error:
generate_data = json.dumps({'status': 'FAILURE', 'message': f"{error}", 'data': f"{error}"})
self.redis_client.set(self.tasks_id, generate_data)
self.generate_data['status'] = "FAILURE"
self.generate_data['message'] = str(error)
self.generate_data['data'] = str(error)
self.redis_client.set(self.tasks_id, json.dumps(self.generate_data))
else:
images = result.as_numpy("IMAGES")
if images.ndim == 3:
images = images[None, ...]
images = (images * 255).round().astype("uint8")
pil_images = [Image.fromarray(image) for image in images]
# for i in range(len(pil_images)):
# pil = pil_images[i]
# pil.save(f'./temp_i2_{i}.png')
# self.image_grid(pil_images, rows, cols)
url_list = []
for i, image in enumerate(pil_images):
if self.category == "sketch":
image = remove_background(np.asarray(image))
image_url = upload_png_sd(image, user_id=self.user_id, category=f"{self.category}", object_name=f"{generate_uuid()}_{i}.png", )
url_list.append(image_url)
generate_data = json.dumps({'status': 'SUCCESS', 'message': 'success', 'data': f'{url_list}'})
self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=generate_data)
logger.info(f" [x] Sent {generate_data}")
self.redis_client.set(self.tasks_id, generate_data)
image_result = result.as_numpy("generated_image")[0]
is_smudge = True
if self.category == "sketch":
# 去背景
remove_bg_image = remove_background(np.asarray(image_result))
# 污点检测
is_smudge, not_smudge_image = stain_detection(remove_bg_image)
image_result = not_smudge_image
if is_smudge: # 无污点
image_result = adjust_contrast(image_result)
image_url = upload_png_sd(image_result, user_id=self.user_id, category=f"{self.category}", object_name=f"{self.tasks_id}.png")
# logger.info(f"upload image SUCCESS {image_url}")
self.generate_data['status'] = "SUCCESS"
self.generate_data['message'] = "success"
self.generate_data['data'] = str(image_url)
self.redis_client.set(self.tasks_id, json.dumps(self.generate_data))
else: # 有污点
self.generate_data['status'] = "SUCCESS"
self.generate_data['message'] = "success"
self.generate_data['data'] = str(GI_SYS_IMAGE_URL)
self.redis_client.set(self.tasks_id, json.dumps(self.generate_data))
# logger.info(f"stain_detection result : {self.generate_data}")
def read_tasks_status(self):
status_data = json.loads(self.redis_client.get(self.tasks_id))
logging.info(f"{self.tasks_id} ===> {status_data}")
return status_data
status_data = self.redis_client.get(self.tasks_id)
return json.loads(status_data), status_data
def infer(self, inputs):
return self.grpc_client.infer(
model_name=GI_MODEL_NAME,
inputs=inputs,
# callback=self.callback
)
# @RunTime
def get_result(self):
self.triton_client.get_model_metadata(model_name=self.model_name, model_version=self.version)
self.triton_client.get_model_config(model_name=self.model_name, model_version=self.version)
try:
prompts = [self.prompt] * self.batch_size
modes = [self.mode] * self.batch_size
images = [self.image.astype(np.float16)] * self.batch_size
image = self.get_image()
text_obj = np.array(prompts, dtype="object").reshape((-1, 1))
mode_obj = np.array(modes, dtype="object").reshape((-1, 1))
image_obj = np.array(images, dtype=np.float16).reshape((-1, 1024, 1024, 3))
# Input placeholder
prompt_in = tritonclient.grpc.InferInput(name="PROMPT", shape=(self.batch_size,), datatype="BYTES")
samples_in = tritonclient.grpc.InferInput("SAMPLES", (self.batch_size,), "INT32")
steps_in = tritonclient.grpc.InferInput("STEPS", (self.batch_size,), "INT32")
guidance_scale_in = tritonclient.grpc.InferInput("GUIDANCE_SCALE", (self.batch_size,), "FP32")
seed_in = tritonclient.grpc.InferInput("SEED", (self.batch_size,), "INT64")
input_images_in = tritonclient.grpc.InferInput("INPUT_IMAGES", image.shape, "FP16")
images = tritonclient.grpc.InferRequestedOutput(name="IMAGES",
# binary_data=False
)
mode_in = tritonclient.grpc.InferInput("MODE", (self.batch_size,), "INT32")
input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype))
input_image = grpcclient.InferInput("input_image", image_obj.shape, "FP16")
input_mode = grpcclient.InferInput("mode", mode_obj.shape, np_to_triton_dtype(text_obj.dtype))
# Setting inputs
prompt_in.set_data_from_numpy(np.asarray([self.content] * self.batch_size, dtype=object))
samples_in.set_data_from_numpy(np.asarray([self.samples], dtype=np.int32))
steps_in.set_data_from_numpy(np.asarray([self.steps], dtype=np.int32))
guidance_scale_in.set_data_from_numpy(np.asarray([self.guidance_scale], dtype=np.float32))
seed_in.set_data_from_numpy(np.asarray([self.seed], dtype=np.int64))
input_images_in.set_data_from_numpy(image.astype(np.float16))
mode_in.set_data_from_numpy(np.asarray([self.mode], dtype=np.int32))
input_text.set_data_from_numpy(text_obj)
input_image.set_data_from_numpy(image_obj)
input_mode.set_data_from_numpy(mode_obj)
# inference
# @RunTime
def infer():
return self.triton_client.async_infer(
model_name=self.model_name,
model_version=self.version,
inputs=[prompt_in, samples_in, steps_in, guidance_scale_in, seed_in, input_images_in, mode_in],
outputs=[images],
callback=self.callback
)
ctx = infer()
time_out = 60
while time_out > 0:
generate_data = self.read_tasks_status()
if generate_data['status'] in ["REVOKED", "FAILURE"]:
ctx.cancel()
self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=json.dumps(generate_data))
logger.info(f" [x] Sent {generate_data}")
break
elif generate_data['status'] == "SUCCESS":
break
time_out -= 1
time.sleep(1)
return self.read_tasks_status()
inputs = [input_text, input_image, input_mode]
ctx = self.infer(inputs)
time_out = 600
generate_data = None
while time_out > 0:
generate_data, _ = self.read_tasks_status()
# logger.info(generate_data)
if generate_data['status'] in ["REVOKED", "FAILURE"]:
ctx.cancel()
break
elif generate_data['status'] == "SUCCESS":
break
time_out -= 1
time.sleep(0.1)
# logger.info(time_out, generate_data)
return generate_data
except Exception as e:
# self.generate_data['status'] = "FAILURE"
# self.generate_data['message'] = "failure"
# self.generate_data['data'] = str(e)
# self.redis_client.set(self.tasks_id, json.dumps(self.generate_data))
raise Exception(str(e))
# finally:
# dict_generate_data, str_generate_data = self.read_tasks_status()
# if DEBUG is False:
# self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=str_generate_data)
# logger.info(f" [x] Sent {json.dumps(dict_generate_data, indent=4)}")
def infer_cancel(tasks_id):
redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
data = {'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'}
generate_data = json.dumps({'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'})
data = {'tasks_id': tasks_id, 'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'}
generate_data = json.dumps(data)
redis_client.set(tasks_id, generate_data)
return data
if __name__ == '__main__':
# request_data = {
# "user_id": 78,
# "image_url": "123_123.png",
# "category": "print",
# "mode": 1,
# "str": "a simple print",
# "version": "1"
# }
rd = GenerateImageModel(
mode=1,
content='a blouse',
gender='',
user_id=89,
image_url='test/微信图片_20231206133428.jpg',
category='sketch',
version='1',
tasks_id='123456'
tasks_id="123-89",
prompt='skeleton sitting by the side of a river looking soulful, concert poster, 4k, artistic',
image_url="",
mode='txt2img',
category="test"
)
server = GenerateImage(rd)
server.get_result()
# print(infer_cancel(123456))
print(server.get_result())

View File

@@ -0,0 +1,30 @@
import cv2
def adjust_contrast(image, alpha=1.5, beta=-60):
"""
调整图像的对比度和亮度。
参数:
image_path (numpy): 图像的路径。
alpha (float): 控制对比度的系数。alpha > 1 增加对比度alpha < 1 减少对比度。
beta (int): 用于调整亮度的值,可以是正或负。
返回:
adjusted_image (ndarray): 调整对比度后的图像。
"""
adjusted_image = cv2.convertScaleAbs(image, alpha=alpha, beta=beta)
return adjusted_image
# 使用示例
if __name__ == "__main__":
image = cv2.imread('output_6.png') # 替换为你的图片路径
img_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
alpha = 1.5 # 对比度系数大于1增加对比度
beta = -60 # 亮度调整这里设置为0不改变亮度
# 调整图像对比度
result_image = adjust_contrast(image, alpha, beta)
# 可以选择保存调整后的图像
cv2.imwrite('adjusted_image.jpg', result_image) # 保存调整后的图片