Merge branch 'develop'

This commit is contained in:
zhouchengrong
2024-05-17 18:22:35 +08:00
11 changed files with 25049 additions and 322 deletions

View File

@@ -10,6 +10,7 @@ logger = logging.getLogger()
@router.post("/generate_image")
def generate_image(request_item: GenerateImageModel, background_tasks: BackgroundTasks):
try:
logger.info(f"request data ### : {request_item}")
service = GenerateImage(request_item)
background_tasks.add_task(service.get_result)
code = 200

View File

@@ -23,9 +23,11 @@ DEBUG = False
if DEBUG:
LOGS_PATH = "logs/"
CATEGORY_PATH = "service/attribute/config/descriptor/category/category_dis.csv"
FACE_CLASSIFIER = "service/generate_image/utils/haarcascade_frontalface_alt.xml"
else:
LOGS_PATH = "app/logs/"
CATEGORY_PATH = "app/service/attribute/config/descriptor/category/category_dis.csv"
FACE_CLASSIFIER = 'app/service/generate_image/utils/haarcascade_frontalface_alt.xml'
RABBITMQ_ENV = "" # 生产环境
# RABBITMQ_ENV = "-dev" # 开发环境
@@ -52,6 +54,13 @@ RABBITMQ_PARAMS = {
"virtual_host": "/"
}
# milvus 配置
MILVUS_DB_HOST = "10.1.1.240"
MILVUS_ALIAS = "default"
MILVUS_PORT = "19530"
MILVUS_TABLE_KEYPOINT = "keypoint_cache"
MILVUS_TABLE_SEG = "seg_cache"
# attribute service config
ATT_TRITON_URL = "10.1.1.240:8020"
@@ -62,15 +71,50 @@ SR_MINIO_BUCKET = "aida-users"
SR_RABBITMQ_QUEUES = os.getenv("SR_RABBITMQ_QUEUES", f"SuperResolution{RABBITMQ_ENV}")
# GenerateImage service config
GI_MODEL_NAME = 'stable_diffusion_xl_lcm'
GI_MODEL_URL = '10.1.1.150:8001'
GI_MODEL_NAME = 'stable_diffusion_xl'
GI_MODEL_URL = '10.1.1.240:10041'
GI_MINIO_BUCKET = "aida-users"
GI_RABBITMQ_QUEUES = os.getenv("GI_RABBITMQ_QUEUES", f"GenerateImage{RABBITMQ_ENV}")
GI_SYS_IMAGE_URL = "aida-sys-image/generate_image/white_image.jpg"
# SEG service config
SEG_MODEL_URL = '10.1.1.240:10000'
SEGMENTATION = {
"new_model_name": "seg_knet",
"name": "seg_ocrnet_hr18",
"input": "seg_input__0",
"output": "seg_output__0",
}
# DESIGN config
DESIGN_MODEL_URL = '10.1.1.240:9000'
AIDA_CLOTHING = "aida-clothing"
# 优先级
PRIORITY_DICT = {
'earring_front': 99,
'bag_front': 98,
'hairstyle_front': 97,
'outwear_front': 20,
'tops_front': 19,
'dress_front': 18,
'blouse_front': 17,
'skirt_front': 16,
'trousers_front': 15,
'bottoms_front': 14,
'shoes_right': 1,
'shoes_left': 1,
'body': 0,
'bottoms_back': -14,
'trousers_back': -15,
'skirt_back': -16,
'blouse_back': -17,
'dress_back': -18,
'tops_back': -19,
'outwear_back': -20,
'hairstyle_back': -97,
'bag_back': -98,
'earring_back': -99,
}

View File

@@ -7,3 +7,4 @@ class GenerateImageModel(BaseModel):
image_url: str
mode: str
category: str
gender: str

View File

@@ -1,7 +1,7 @@
labelName,join_attr,taskName,taskId
top,attr_top,category,1
pants,attr_pants,category,1
skirt,attr_skirt,category,1
dress,attr_dress,category,1
outwear,attr_outwear,category,1
jumpsuit,attr_jumpsuit,category,1
Blouse,attr_top,category,1
Trousers,attr_pants,category,1
Skirt,attr_skirt,category,1
Dress,attr_dress,category,1
Outwear,attr_outwear,category,1
Jumpsuit,attr_jumpsuit,category,1
1 labelName join_attr taskName taskId
2 top Blouse attr_top category 1
3 pants Trousers attr_pants category 1
4 skirt Skirt attr_skirt category 1
5 dress Dress attr_dress category 1
6 outwear Outwear attr_outwear category 1
7 jumpsuit Jumpsuit attr_jumpsuit category 1

View File

@@ -22,19 +22,25 @@ from tritonclient.utils import np_to_triton_dtype
from app.core.config import *
from app.schemas.generate_image import GenerateImageModel
from app.service.generate_image.utils.upload_sd_image import upload_png_sd
from app.service.generate_image.utils.adjust_contrast import adjust_contrast
from app.service.generate_image.utils.image_processing import remove_background, stain_detection, generate_category_recognition, autoLevels, luminance_adjust, face_detect_pic
from app.service.generate_image.utils.upload_sd_image import upload_png_sd, upload_stain_png_sd
logger = logging.getLogger()
class GenerateImage:
def __init__(self, request_data):
if DEBUG is False:
self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
self.channel = self.connection.channel()
# self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
# self.channel = self.connection.channel()
self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
self.grpc_client = grpcclient.InferenceServerClient(url=GI_MODEL_URL)
self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
self.channel = self.connection.channel()
if request_data.mode == "img2img":
# cv2 读图片是BGR PIL读图片是RGB
self.image = self.get_image(request_data.image_url)
self.prompt = request_data.prompt
else:
@@ -47,35 +53,69 @@ class GenerateImage:
self.batch_size = 1
self.category = request_data.category
self.index = 0
self.generate_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'data': ''}
self.gender = request_data.gender
self.generate_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'image_url': '', 'category': ''}
self.redis_client.set(self.tasks_id, json.dumps(self.generate_data))
self.redis_client.expire(self.tasks_id, 600)
def get_image(self, image_url):
# Get data of an object.
# Read data from response.
# read image use cv2
try:
response = self.minio_client.get_object(image_url.split('/')[0], image_url[image_url.find('/') + 1:])
image_file = BytesIO(response.data)
image_array = np.asarray(bytearray(image_file.read()), dtype=np.uint8)
image_cv2 = cv2.imdecode(image_array, cv2.IMREAD_COLOR)
image_rbg = cv2.cvtColor(image_cv2, cv2.COLOR_BGR2RGB)
image = cv2.resize(image_rbg, (1024, 1024))
except minio.error.S3Error:
image_cv2 = np.random.randint(0, 256, (1024, 1024, 3), dtype=np.uint8)
return image_cv2
image = np.random.randint(0, 256, (1024, 1024, 3), dtype=np.uint8)
return image
def callback(self, result, error):
if error:
self.generate_data['status'] = "FAILURE"
self.generate_data['message'] = str(error)
self.generate_data['data'] = str(error)
# self.generate_data['data'] = str(error)
self.redis_client.set(self.tasks_id, json.dumps(self.generate_data))
else:
image_result = result.as_numpy("generated_image")[0]
image_url = upload_png_sd(image_result, user_id=self.user_id, category=f"{self.category}", object_name=f"{self.tasks_id}.png")
self.generate_data['status'] = "SUCCESS"
self.generate_data['message'] = "success"
self.generate_data['data'] = str(image_url)
self.redis_client.set(self.tasks_id, json.dumps(self.generate_data))
# pil图像转成numpy数组
image = result.as_numpy("generated_image")
image_result = cv2.cvtColor(np.squeeze(image.astype(np.uint8)), cv2.COLOR_RGB2BGR)
is_smudge = True
if self.category == "sketch":
# 色阶调整
cutoff = 1
levels_img = autoLevels(image_result, cutoff)
# 亮度调整
luminance = luminance_adjust(0.3, levels_img)
# 去背景
remove_bg_image = remove_background(luminance)
# 人脸检测
if face_detect_pic(remove_bg_image, self.user_id, self.category, self.tasks_id) > 0:
is_smudge = False
else:
# 污点/
is_smudge, not_smudge_image = stain_detection(remove_bg_image, self.user_id, self.category, self.tasks_id)
# 类型识别
category, scores, not_smudge_image = generate_category_recognition(image=remove_bg_image, gender=self.gender)
self.generate_data['category'] = str(category)
image_result = not_smudge_image
if is_smudge: # 无污点
# image_result = adjust_contrast(image_result)
image_url = upload_png_sd(image_result, user_id=self.user_id, category=f"{self.category}", object_name=f"{self.tasks_id}.png")
# logger.info(f"upload image SUCCESS {image_url}")
self.generate_data['status'] = "SUCCESS"
self.generate_data['message'] = "success"
self.generate_data['image_url'] = str(image_url)
self.redis_client.set(self.tasks_id, json.dumps(self.generate_data))
else: # 有污点 保存图片到本地 测试用
self.generate_data['status'] = "SUCCESS"
self.generate_data['message'] = "success"
self.generate_data['image_url'] = str(GI_SYS_IMAGE_URL)
self.redis_client.set(self.tasks_id, json.dumps(self.generate_data))
# logger.info(f"stain_detection result : {self.generate_data}")
def read_tasks_status(self):
status_data = self.redis_client.get(self.tasks_id)
@@ -108,10 +148,11 @@ class GenerateImage:
inputs = [input_text, input_image, input_mode]
ctx = self.infer(inputs)
time_out = 60
time_out = 600
generate_data = None
while time_out > 0:
generate_data, _ = self.read_tasks_status()
# logger.info(generate_data)
if generate_data['status'] in ["REVOKED", "FAILURE"]:
ctx.cancel()
break
@@ -119,16 +160,18 @@ class GenerateImage:
break
time_out -= 1
time.sleep(0.1)
# logger.info(time_out, generate_data)
return generate_data
except Exception as e:
self.generate_data['status'] = "FAILURE"
self.generate_data['message'] = "failure"
self.generate_data['data'] = str(e)
self.generate_data['message'] = str(e)
self.redis_client.set(self.tasks_id, json.dumps(self.generate_data))
raise Exception(str(e))
finally:
dict_generate_data, str_generate_data = self.read_tasks_status()
self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=str_generate_data)
if DEBUG is False:
self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=str_generate_data)
# self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=str_generate_data)
logger.info(f" [x] Sent {json.dumps(dict_generate_data, indent=4)}")

View File

@@ -9,224 +9,169 @@
"""
import json
import logging
import minio
import numpy as np
import random
import redis
import tritonclient
import tritonclient.grpc as grpc_client
from minio import Minio
import cv2
from PIL import Image
import time
from io import BytesIO
import cv2
import minio
import redis
import tritonclient.grpc as grpcclient
import numpy as np
from minio import Minio
from tritonclient.utils import np_to_triton_dtype
from app.core.config import *
from app.schemas.generate_image import GenerateImageModel
from app.service.generate_image.utils.remove_background import remove_background
from app.service.generate_image.utils.adjust_contrast import adjust_contrast
from app.service.generate_image.utils.image_processing import remove_background, stain_detection
from app.service.generate_image.utils.upload_sd_image import upload_png_sd
from app.service.utils.decorator import RunTime
from app.service.utils.generate_uuid import generate_uuid
logger = logging.getLogger()
class GenerateImage:
def __init__(self, request_data):
self.tasks_id = request_data.tasks_id
self.model = request_data.model
self.request_count = request_data.request_count
self.prompt = request_data.prompt
self.image = request_data.image
self.mode = request_data.mode
self.batch_size = request_data.batch_size
self.image_url = request_data.image_url
self.user_id = request_data.user_id
self.content = request_data.content
self.category = request_data.category
self.model_name = f"{self.category}{GI_MODEL_NAME}"
self.mode = request_data.mode
self.version = request_data.version
self.triton_client = grpc_client.InferenceServerClient(url="1")
self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
self.channel = self.connection.channel()
if DEBUG is False:
self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
self.channel = self.connection.channel()
# self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
# self.channel = self.connection.channel()
self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
self.samples = 4 # no.of images to generate
self.steps = 24
self.guidance_scale = 7
self.seed = random.randint(0, 2000000000)
self.batch_size = 1
self.generate_data = json.dumps({'status': 'PENDING', 'message': "pending", 'data': ''})
self.redis_client.set(self.tasks_id, self.generate_data)
def get_result(self):
pass
@staticmethod
def image_grid(imgs, rows, cols):
assert len(imgs) == rows * cols
w, h = imgs[0].size
grid = Image.new('RGB', size=(cols * w, rows * h))
for i, img in enumerate(imgs):
grid.paste(img, box=(i % cols * w, i // cols * h))
return grid
@staticmethod
def preprocess_image(image, category):
height, width, _ = image.shape
if category == "print" or category == "moodboard":
square_size = min(height, width)
start_x = (width - square_size) // 2
start_y = (height - square_size) // 2
cropped = image[start_y: start_y + square_size, start_x: start_x + square_size]
resized_image = cv2.resize(cropped, (512, 512))
elif category == "sketch":
# below is the way that get "bigger" square image.
max_dimension = max(height, width)
square_image = np.ones((max_dimension, max_dimension, 3), dtype=np.uint8) * 255
start_h = (max_dimension - height) // 2
start_w = (max_dimension - width) // 2
square_image[start_h:start_h + height, start_w:start_w + width] = image
resized_image = cv2.resize(square_image, (512, 512))
self.grpc_client = grpcclient.InferenceServerClient(url=GI_MODEL_URL)
self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
if request_data.mode == "img2img":
self.image = self.get_image(request_data.image_url)
self.prompt = request_data.prompt
else:
raise ValueError(f"wrong category {category}, only in moodboard, print and sketch!")
self.image = np.random.randint(0, 256, (1024, 1024, 3), dtype=np.uint8)
self.prompt = request_data.prompt
return resized_image
self.tasks_id = request_data.tasks_id
self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:]
self.mode = request_data.mode
self.batch_size = 1
self.category = request_data.category
self.index = 0
self.generate_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'data': ''}
self.redis_client.set(self.tasks_id, json.dumps(self.generate_data))
self.redis_client.expire(self.tasks_id, 600)
def get_image(self):
def get_image(self, image_url):
# Get data of an object.
# Read data from response.
try:
response = self.minio_client.get_object(self.image_url.split('/')[0], self.image_url[self.image_url.find('/') + 1:])
img = np.frombuffer(response.data, np.uint8) # 转成8位无符号整型
img = cv2.imdecode(img, cv2.IMREAD_COLOR) # 解码
img = self.preprocess_image(img, self.category)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
response = self.minio_client.get_object(image_url.split('/')[0], image_url[image_url.find('/') + 1:])
image_file = BytesIO(response.data)
image_array = np.asarray(bytearray(image_file.read()), dtype=np.uint8)
image_cv2 = cv2.imdecode(image_array, cv2.IMREAD_COLOR)
image = cv2.resize(image_cv2, (1024, 1024))
except minio.error.S3Error:
img = np.random.randn(512, 512, 3)
return img
image = np.random.randint(0, 256, (1024, 1024, 3), dtype=np.uint8)
return image
def callback(self, result, error):
if error:
generate_data = json.dumps({'status': 'FAILURE', 'message': f"{error}", 'data': f"{error}"})
self.redis_client.set(self.tasks_id, generate_data)
self.generate_data['status'] = "FAILURE"
self.generate_data['message'] = str(error)
self.generate_data['data'] = str(error)
self.redis_client.set(self.tasks_id, json.dumps(self.generate_data))
else:
images = result.as_numpy("IMAGES")
if images.ndim == 3:
images = images[None, ...]
images = (images * 255).round().astype("uint8")
pil_images = [Image.fromarray(image) for image in images]
# for i in range(len(pil_images)):
# pil = pil_images[i]
# pil.save(f'./temp_i2_{i}.png')
# self.image_grid(pil_images, rows, cols)
url_list = []
for i, image in enumerate(pil_images):
if self.category == "sketch":
image = remove_background(np.asarray(image))
image_url = upload_png_sd(image, user_id=self.user_id, category=f"{self.category}", object_name=f"{generate_uuid()}_{i}.png", )
url_list.append(image_url)
generate_data = json.dumps({'status': 'SUCCESS', 'message': 'success', 'data': f'{url_list}'})
self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=generate_data)
logger.info(f" [x] Sent {generate_data}")
self.redis_client.set(self.tasks_id, generate_data)
image_result = result.as_numpy("generated_image")[0]
is_smudge = True
if self.category == "sketch":
# 去背景
remove_bg_image = remove_background(np.asarray(image_result))
# 污点检测
is_smudge, not_smudge_image = stain_detection(remove_bg_image)
image_result = not_smudge_image
if is_smudge: # 无污点
image_result = adjust_contrast(image_result)
image_url = upload_png_sd(image_result, user_id=self.user_id, category=f"{self.category}", object_name=f"{self.tasks_id}.png")
# logger.info(f"upload image SUCCESS {image_url}")
self.generate_data['status'] = "SUCCESS"
self.generate_data['message'] = "success"
self.generate_data['data'] = str(image_url)
self.redis_client.set(self.tasks_id, json.dumps(self.generate_data))
else: # 有污点
self.generate_data['status'] = "SUCCESS"
self.generate_data['message'] = "success"
self.generate_data['data'] = str(GI_SYS_IMAGE_URL)
self.redis_client.set(self.tasks_id, json.dumps(self.generate_data))
# logger.info(f"stain_detection result : {self.generate_data}")
def read_tasks_status(self):
status_data = json.loads(self.redis_client.get(self.tasks_id))
logging.info(f"{self.tasks_id} ===> {status_data}")
return status_data
status_data = self.redis_client.get(self.tasks_id)
return json.loads(status_data), status_data
def infer(self, inputs):
return self.grpc_client.infer(
model_name=GI_MODEL_NAME,
inputs=inputs,
# callback=self.callback
)
# @RunTime
def get_result(self):
self.triton_client.get_model_metadata(model_name=self.model_name, model_version=self.version)
self.triton_client.get_model_config(model_name=self.model_name, model_version=self.version)
try:
prompts = [self.prompt] * self.batch_size
modes = [self.mode] * self.batch_size
images = [self.image.astype(np.float16)] * self.batch_size
image = self.get_image()
text_obj = np.array(prompts, dtype="object").reshape((-1, 1))
mode_obj = np.array(modes, dtype="object").reshape((-1, 1))
image_obj = np.array(images, dtype=np.float16).reshape((-1, 1024, 1024, 3))
# Input placeholder
prompt_in = tritonclient.grpc.InferInput(name="PROMPT", shape=(self.batch_size,), datatype="BYTES")
samples_in = tritonclient.grpc.InferInput("SAMPLES", (self.batch_size,), "INT32")
steps_in = tritonclient.grpc.InferInput("STEPS", (self.batch_size,), "INT32")
guidance_scale_in = tritonclient.grpc.InferInput("GUIDANCE_SCALE", (self.batch_size,), "FP32")
seed_in = tritonclient.grpc.InferInput("SEED", (self.batch_size,), "INT64")
input_images_in = tritonclient.grpc.InferInput("INPUT_IMAGES", image.shape, "FP16")
images = tritonclient.grpc.InferRequestedOutput(name="IMAGES",
# binary_data=False
)
mode_in = tritonclient.grpc.InferInput("MODE", (self.batch_size,), "INT32")
input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype))
input_image = grpcclient.InferInput("input_image", image_obj.shape, "FP16")
input_mode = grpcclient.InferInput("mode", mode_obj.shape, np_to_triton_dtype(text_obj.dtype))
# Setting inputs
prompt_in.set_data_from_numpy(np.asarray([self.content] * self.batch_size, dtype=object))
samples_in.set_data_from_numpy(np.asarray([self.samples], dtype=np.int32))
steps_in.set_data_from_numpy(np.asarray([self.steps], dtype=np.int32))
guidance_scale_in.set_data_from_numpy(np.asarray([self.guidance_scale], dtype=np.float32))
seed_in.set_data_from_numpy(np.asarray([self.seed], dtype=np.int64))
input_images_in.set_data_from_numpy(image.astype(np.float16))
mode_in.set_data_from_numpy(np.asarray([self.mode], dtype=np.int32))
input_text.set_data_from_numpy(text_obj)
input_image.set_data_from_numpy(image_obj)
input_mode.set_data_from_numpy(mode_obj)
# inference
# @RunTime
def infer():
return self.triton_client.async_infer(
model_name=self.model_name,
model_version=self.version,
inputs=[prompt_in, samples_in, steps_in, guidance_scale_in, seed_in, input_images_in, mode_in],
outputs=[images],
callback=self.callback
)
ctx = infer()
time_out = 60
while time_out > 0:
generate_data = self.read_tasks_status()
if generate_data['status'] in ["REVOKED", "FAILURE"]:
ctx.cancel()
self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=json.dumps(generate_data))
logger.info(f" [x] Sent {generate_data}")
break
elif generate_data['status'] == "SUCCESS":
break
time_out -= 1
time.sleep(1)
return self.read_tasks_status()
inputs = [input_text, input_image, input_mode]
ctx = self.infer(inputs)
time_out = 600
generate_data = None
while time_out > 0:
generate_data, _ = self.read_tasks_status()
# logger.info(generate_data)
if generate_data['status'] in ["REVOKED", "FAILURE"]:
ctx.cancel()
break
elif generate_data['status'] == "SUCCESS":
break
time_out -= 1
time.sleep(0.1)
# logger.info(time_out, generate_data)
return generate_data
except Exception as e:
# self.generate_data['status'] = "FAILURE"
# self.generate_data['message'] = "failure"
# self.generate_data['data'] = str(e)
# self.redis_client.set(self.tasks_id, json.dumps(self.generate_data))
raise Exception(str(e))
# finally:
# dict_generate_data, str_generate_data = self.read_tasks_status()
# if DEBUG is False:
# self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=str_generate_data)
# logger.info(f" [x] Sent {json.dumps(dict_generate_data, indent=4)}")
def infer_cancel(tasks_id):
redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
data = {'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'}
generate_data = json.dumps({'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'})
data = {'tasks_id': tasks_id, 'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'}
generate_data = json.dumps(data)
redis_client.set(tasks_id, generate_data)
return data
if __name__ == '__main__':
# request_data = {
# "user_id": 78,
# "image_url": "123_123.png",
# "category": "print",
# "mode": 1,
# "str": "a simple print",
# "version": "1"
# }
rd = GenerateImageModel(
mode=1,
content='a blouse',
gender='',
user_id=89,
image_url='test/微信图片_20231206133428.jpg',
category='sketch',
version='1',
tasks_id='123456'
tasks_id="123-89",
prompt='skeleton sitting by the side of a river looking soulful, concert poster, 4k, artistic',
image_url="",
mode='txt2img',
category="test"
)
server = GenerateImage(rd)
server.get_result()
# print(infer_cancel(123456))
print(server.get_result())

View File

@@ -0,0 +1,30 @@
import cv2
def adjust_contrast(image, alpha=1.5, beta=-60):
    """Adjust the contrast and brightness of an image.

    Args:
        image (ndarray): input image array (not a path — the original
            docstring mislabeled this parameter).
        alpha (float): contrast gain; alpha > 1 increases contrast,
            alpha < 1 reduces it.
        beta (int): brightness offset, positive or negative.

    Returns:
        ndarray: the adjusted image.
    """
    return cv2.convertScaleAbs(image, alpha=alpha, beta=beta)
# Usage example
if __name__ == "__main__":
    image = cv2.imread('output_6.png')  # replace with your own image path
    # NOTE(review): img_rgb is computed but never used below — confirm intent
    img_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    alpha = 1.5  # contrast gain; > 1 increases contrast
    beta = -60  # brightness offset; 0 leaves brightness unchanged
    # adjust the image contrast
    result_image = adjust_contrast(image, alpha, beta)
    # optionally save the adjusted image
    cv2.imwrite('adjusted_image.jpg', result_image)  # write the adjusted image

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,395 @@
import logging
import time
import mmcv
import numpy as np
import torch
import tritonclient.http as httpclient
import torch.nn.functional as F
from app.core.config import *
import cv2
from app.service.generate_image.utils.upload_sd_image import upload_png_sd, upload_stain_png_sd, upload_face_png_sd
logger = logging.getLogger()
def seg_preprocess(img_path):
    """Load an image and build an NCHW float batch for the seg model.

    Returns:
        tuple: (preprocessed 1xCxHxW array, original (h, w) shape).
    """
    image = mmcv.imread(img_path)
    ori_shape = image.shape[:2]
    # resize to the original shape (identity scale, but keeps the mmcv pipeline)
    image, w_scale, h_scale = mmcv.imresize(image, ori_shape, return_scale=True)
    scale_factor = [w_scale, h_scale]  # kept for parity with the pipeline; unused
    mean = np.array([123.675, 116.28, 103.53])
    std = np.array([58.395, 57.12, 57.375])
    image = mmcv.imnormalize(image, mean=mean, std=std, to_rgb=True)
    batch = np.expand_dims(image.transpose(2, 0, 1), axis=0)
    return batch, ori_shape
def get_mask(image_obj):
    """Derive a binary foreground mask for ``image_obj``.

    Grayscale input is promoted to 3 channels. A 4th (alpha) channel, when
    present, is split off and intersected with the contour-derived mask.

    Returns:
        tuple: (3-channel image, uint8 mask).
    """
    alpha_mask = None
    if len(image_obj.shape) == 2:
        image_obj = cv2.cvtColor(image_obj, cv2.COLOR_GRAY2RGB)
    if image_obj.shape[2] == 4:
        # 4-channel input: treat the alpha plane as a pre-existing mask
        alpha_mask = image_obj[:, :, 3]
        image_obj = image_obj[:, :, :3]
    contours = get_contours(image_obj)
    contour_mask = np.zeros(image_obj.shape[:2], np.uint8)
    if len(contours):
        # fill a simplified polygon of the largest contour
        largest = contours[0]
        epsilon = 0.001 * cv2.arcLength(largest, True)
        approx = cv2.approxPolyDP(largest, epsilon, True)
        cv2.drawContours(contour_mask, [approx], -1, 255, -1)
    else:
        # no contours found: treat the whole image as foreground
        contour_mask = np.ones(image_obj.shape[:2], np.uint8) * 255
    if alpha_mask is None:
        mask = contour_mask
    else:
        mask = cv2.bitwise_and(contour_mask, alpha_mask)
    return image_obj, mask
def get_contours(image):
    """Return the external contours of ``image``, sorted by area (largest first)."""
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    edges = cv2.Canny(gray, 10, 150)
    # close small gaps in the edge map before extracting contours
    kernel = np.ones((5, 5), np.uint8)
    edges = cv2.dilate(edges, kernel=kernel, iterations=1)
    edges = cv2.erode(edges, kernel=kernel, iterations=1)
    contours, _ = cv2.findContours(edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    return sorted(contours, key=cv2.contourArea, reverse=True)
# def seg_infer_image(image_obj):
# image, ori_shape = seg_preprocess(image_obj)
# client = httpclient.InferenceServerClient(url=f"{SEG_MODEL_URL}")
# transformed_img = image.astype(np.float32)
# # 输入集
# inputs = [
# httpclient.InferInput(SEGMENTATION['input'], transformed_img.shape, datatype="FP32")
# ]
# inputs[0].set_data_from_numpy(transformed_img, binary_data=True)
# # 输出集
# outputs = [
# httpclient.InferRequestedOutput(SEGMENTATION['output'], binary_data=True),
# ]
# results = client.infer(model_name=SEGMENTATION['name'], inputs=inputs, outputs=outputs)
# # 推理
# # 取结果
# inference_output1 = torch.from_numpy(results.as_numpy(SEGMENTATION['output']))
# seg_result = seg_postprocess(inference_output1, ori_shape)
# return seg_result
def seg_infer_image(image_obj):
    """Run the KNet segmentation model on ``image_obj`` via Triton HTTP.

    Returns:
        The post-processed segmentation map matching the original image shape.
    """
    image, ori_shape = seg_preprocess(image_obj)
    client = httpclient.InferenceServerClient(url=f"{SEG_MODEL_URL}")
    transformed_img = image.astype(np.float32)
    # input tensors
    inputs = [
        httpclient.InferInput(SEGMENTATION['input'], transformed_img.shape, datatype="FP32")
    ]
    inputs[0].set_data_from_numpy(transformed_img, binary_data=True)
    # requested outputs
    outputs = [
        httpclient.InferRequestedOutput(SEGMENTATION['output'], binary_data=True),
    ]
    start_time = time.time()
    results = client.infer(model_name=SEGMENTATION['new_model_name'], inputs=inputs, outputs=outputs)
    # fix: use the module logger (lazy %-args) instead of print so the timing
    # is captured by the service's logging configuration
    logger.info("KNet infer time is :%s", time.time() - start_time)
    # extract the raw output and map it back to the original shape
    inference_output1 = results.as_numpy(SEGMENTATION['output'])
    seg_result = seg_postprocess(inference_output1, ori_shape)
    return seg_result
# def seg_postprocess(output, ori_shape):
# seg_logit = F.interpolate(output, size=ori_shape, scale_factor=None, mode='bilinear', align_corners=False)
# seg_logit = F.softmax(seg_logit, dim=1)
# seg_pred = seg_logit.argmax(dim=1)
# seg_pred = seg_pred.cpu().numpy()
# return seg_pred
# KNet
def seg_postprocess(output, ori_shape):
    """Return the segmentation label map for the first batch element.

    ``ori_shape`` is unused for the KNet model output, but the parameter is
    kept so the signature matches the earlier post-processing routine.
    """
    return output[0]
def remove_background(image):
    """Replace the background of ``image`` with white using the seg model.

    Segmentation labels 1 (front) and 2 (back), gated by the contour/alpha
    mask from get_mask, are treated as foreground; all other pixels are
    painted white.
    """
    image_obj, mask = get_mask(image)
    seg_result = seg_infer_image(image_obj)
    # per-label masks gated by the contour/alpha mask
    front_mask = mask * ((seg_result == 1) + 0).astype(np.uint8)
    back_mask = mask * ((seg_result == 2) + 0).astype(np.uint8)
    # seg output may carry a leading batch axis; drop it
    # (the original no-op ``else: front_mask = front_mask`` branches removed)
    if len(front_mask.shape) > 2:
        front_mask = front_mask[0]
    if len(back_mask.shape) > 2:
        back_mask = back_mask[0]
    result_mask = front_mask + back_mask
    white_background = np.ones_like(image_obj) * 255
    # keep foreground pixels, fill everything else with white
    remove_bg_image = np.where(result_mask[:, :, None].astype(bool), image_obj, white_background)
    return remove_bg_image
def bounding_box(image):
    """Crop ``image`` to the union bounding box of all Canny-edge contours.

    Fix: when no contours are found, return the image unchanged — the
    original left the box at (inf, inf, -1, -1) and slicing with float
    ``inf`` raises a TypeError.
    """
    edges = cv2.Canny(image, 50, 150)
    # find external contours of the edge map
    contours, _ = cv2.findContours(edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        return image
    # accumulate the rectangle enclosing every contour's bounding rect
    x_min, y_min, x_max, y_max = float('inf'), float('inf'), -1, -1
    for contour in contours:
        x, y, w, h = cv2.boundingRect(contour)
        x_min = min(x_min, x)
        y_min = min(y_min, y)
        x_max = max(x_max, x + w)
        y_max = max(y_max, y + h)
    # crop the original image to the accumulated box
    result_image = image[y_min:y_max, x_min:x_max]
    return result_image
def stain_detection(image, user_id, category, tasks_id, spot_size=100):
    """Heuristically check ``image`` for stains/smudges on a white background.

    Two checks are applied:
      1. the top-left and top-right ``spot_size`` x ``spot_size`` corners must
         be pure white;
      2. the central 100x100 patch must NOT contain a connected pure-white
         region of area >= 300 (which would indicate a hole in the garment).

    Side effect: when the image passes, an annotated copy is uploaded via
    upload_stain_png_sd.

    Returns:
        tuple: (True, image) when considered clean, otherwise (False, None).
    """
    height, width, _ = image.shape
    corners = [
        image[0:spot_size, 0:spot_size],  # top left
        image[0:spot_size, width - spot_size:width],  # top right
        # image[height - spot_size:height, 0:spot_size],  # bottom left
        # image[height - spot_size:height, width - spot_size:width]  # bottom right
    ]
    for index, corner in enumerate(corners):
        # a clean corner must be entirely pure white
        num_white_pixels = (corner == [255, 255, 255]).all(axis=2).sum()
        if num_white_pixels != spot_size * spot_size:
            logger.info(f"{index + 1}发现了污点")
            return False, None
    # Center-region check.
    # Convert the image to grayscale.
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    # Locate the image center.
    center_x, center_y = image.shape[1] // 2, image.shape[0] // 2
    # Size of the central region.
    patch_size = 100
    half_patch = patch_size // 2
    # Extract the central patch.
    center_patch = gray[center_y - half_patch:center_y + half_patch, center_x - half_patch:center_x + half_patch]
    # Threshold to detect pure-white areas.
    _, thresh = cv2.threshold(center_patch, 254, 255, cv2.THRESH_BINARY)
    # Find contours of the white areas.
    contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    # Keep only connected white regions; area >= 300 counts as connected here.
    filtered_contours = [cnt for cnt in contours if cv2.contourArea(cnt) >= 300]
    # If a connected pure-white region exists in the center...
    if filtered_contours:
        # Replace the white region with gray (visual debug only).
        if DEBUG:
            for cnt in filtered_contours:
                x, y, w, h = cv2.boundingRect(cnt)
                # Paint the matching pixels on the original image.
                image[y + center_y - half_patch:y + center_y - half_patch + h, x + center_x - half_patch:x + center_x - half_patch + w][thresh[y:y + h, x:x + w] == 255] = (128, 128, 128)
            # Show the marked image.
            cv2.imshow('Marked Image', image)
            cv2.waitKey(0)
        logger.info("中心区域存在连续的纯白区域")
        is_pure_white = True
    else:
        logger.info("中心区域不存在连续的纯白区域")
        is_pure_white = False
    if is_pure_white:
        return False, None
    if DEBUG:
        for corner_coords in [
            (0, 0),
            # (0, width - spot_size),
            (height - spot_size, 0),
            # (height - spot_size, width - spot_size)
            # center point
        ]:
            cv2.rectangle(image, corner_coords, (corner_coords[0] + spot_size, corner_coords[1] + spot_size), (0, 0, 255), 2)
        cv2.rectangle(image, (center_x - spot_size // 2, center_y - spot_size // 2), (center_x + spot_size // 2, center_y + spot_size // 2), (0, 255, 0), 2)  # draw the rectangles on the original image
    dst = image.copy()
    for corner_coords in [
        (0, 0),
        # (0, width - spot_size),
        (height - spot_size, 0),
        # (height - spot_size, width - spot_size)
        # center point
    ]:
        cv2.rectangle(dst, corner_coords, (corner_coords[0] + spot_size, corner_coords[1] + spot_size), (0, 0, 255), 2)
    cv2.rectangle(dst, (center_x - spot_size // 2, center_y - spot_size // 2), (center_x + spot_size // 2, center_y + spot_size // 2), (0, 255, 0), 2)  # draw the rectangles on the annotated copy
    # upload the annotated copy (side effect); the returned URL is unused
    image_url = upload_stain_png_sd(dst, user_id=user_id, category=f"{category}", object_name=f"{tasks_id}.png")
    return True, image
def generate_category_recognition(image, gender):
    """Classify the garment category of ``image`` via the Triton attribute model.

    Returns:
        tuple: (category label, raw score array, the input image unchanged).
    """
    def preprocess(img):
        # Resize to 224x224 and normalize for the classifier.
        img = mmcv.imread(img)
        # ori_shape = img.shape[:2]
        img_scale = (224, 224)
        scale_factor = []
        img, x, y = mmcv.imresize(img, img_scale, return_scale=True)
        scale_factor.append(x)
        scale_factor.append(y)
        img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
        preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0)
        return preprocessed_img
    preprocessed_img = preprocess(image)
    triton_client = httpclient.InferenceServerClient(url=ATT_TRITON_URL)
    inputs = [
        httpclient.InferInput("input__0", preprocessed_img.shape, datatype="FP32")
    ]
    inputs[0].set_data_from_numpy(preprocessed_img, binary_data=True)
    results = triton_client.infer(model_name="attr_retrieve_category", inputs=inputs)
    inference_output = torch.from_numpy(results.as_numpy(f'output__0'))
    scores = inference_output.detach().numpy()
    import pandas as pd
    # Label names come from the category CSV; row order must match the
    # model's output index order — TODO confirm against CATEGORY_PATH.
    attr_type = pd.read_csv(CATEGORY_PATH)
    colattr = list(attr_type['labelName'])
    task = attr_type['taskName'][0]  # NOTE(review): ``task`` is unused
    # Pick the best of the first 5 scores, then find its column index.
    maxsc = np.max(scores[0][:5])
    indexs = np.argwhere(scores == maxsc)[:, 1]
    category = colattr[indexs[0]]
    if gender == "Male":
        # Map female-oriented labels to male categories.
        # NOTE(review): the final ``else`` forces every remaining label to
        # 'Outwear' — verify this catch-all is intentional.
        if category == 'Trousers' or category == 'Skirt':
            category = 'Bottoms'
        elif category == 'Blouse' or category == 'Dress':
            category = 'Tops'
        else:
            category = 'Outwear'
    return category, scores, image
def autoLevels(img, cutoff=0.1):
    """Photoshop-style auto-levels: per-channel contrast stretch plus gamma.

    Args:
        img: HxWxC uint8 image.
        cutoff: percentile (in %) clipped from each end of the histogram.

    Returns:
        Remapped uint8 image of the same shape.
    """
    channels = img.shape[2]  # h, w, ch
    # One 256-entry LUT per channel.
    # Fixed: the table was hard-coded to 3 channels while the loop runs over
    # img.shape[2]; size it from the actual channel count.
    table = np.zeros((1, 256, channels), np.uint8)
    grades = np.arange(256, dtype=np.float64)
    for ch in range(channels):
        # cutoff=0.1 -> use the 0.1% and 99.9% percentile grey values.
        low = np.percentile(img[:, :, ch], q=cutoff)
        high = np.percentile(img[:, :, ch], q=100 - cutoff)
        # Input linear stretch: Sin = black point, Hin = white point (Sin < Hin).
        Sin = min(max(low, 0), high - 2)
        Hin = min(high, 255)
        difIn = Hin - Sin
        V1 = np.clip(255 * (grades - Sin) / difIn, 0, 255)  # vectorised stretch
        # Mid-tone gamma so the pre-stretch median maps towards 128.
        gradMed = np.median(img[:, :, ch])
        Mt = V1[int(gradMed)] / 128.
        # Fixed: Mt == 0 (e.g. an almost-black channel) previously caused a
        # division by zero in the gamma exponent.
        Mt = max(Mt, 1e-6)
        V2 = 255 * np.power(V1 / 255, 1 / Mt)
        # Output linear stretch into [Sout, Hout].
        Sout, Hout = 5, 250
        table[0, :, ch] = np.clip(Sout + (Hout - Sout) * V2 / 255, 0, 255)
    return cv2.LUT(img, table)
def luminance_adjust(alpha, img):
    """Adjust brightness: positive *alpha* blends the image towards white,
    non-positive *alpha* scales it towards black."""
    towards_white = alpha > 0
    adjusted = img * (1 - alpha) + alpha * 255.0 if towards_white else img * (1 + alpha)
    return np.array(adjusted, dtype='uint8')
# 14.14 Photoshop auto-levels adjustment algorithm (refers to autoLevels above)
def face_detect_pic(image, user_id, category, tasks_id):
    """Count faces in *image* with a Haar cascade and upload a debug
    visualisation (detected faces boxed in red) to object storage.

    Args:
        image: colour ndarray; converted to grayscale for detection.
        user_id, category, tasks_id: used to build the uploaded object name.

    Returns:
        Number of detected face rectangles.
    """
    # 1. Detection runs on the grayscale image.
    gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    # 2. Load the pre-trained frontal-face cascade.
    face_detector = cv2.CascadeClassifier(FACE_CLASSIFIER)
    # 3. scaleFactor=1.05, minNeighbors=3 -> rectangles as (x, y, w, h).
    faces_rect = face_detector.detectMultiScale(gray, 1.05, 3)
    # Fixed: the old DEBUG branch drew the same rectangles on a copy that was
    # immediately discarded and redrawn below — dead duplicate work removed.
    dst = image.copy()
    for x, y, w, h in faces_rect:
        cv2.rectangle(dst, (x, y), (x + w, y + h), (0, 0, 255), 3)
    # Upload the annotated image; the returned URL was never used.
    upload_face_png_sd(dst, user_id=user_id, category=f"{category}", object_name=f"{tasks_id}.png")
    return len(faces_rect)
if __name__ == '__main__':
    # Demo pipeline: Photoshop-style auto-levels -> brightness -> background removal.
    img = cv2.imread("2.png", flags=1)  # read colour image
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)  # convert to grayscale
    # Fixed: maxG previously printed gray.min() twice.
    print("cutoff={}, minG={}, maxG={}".format(0.0, gray.min(), gray.max()))
    # Manual levels adjustment (kept for reference):
    # equManual = levelsAdjust(img, 63, 205, 0.8, 10, 245)
    cutoff = 1.0  # clipping percentage, suggested range [0.0, 1.0]
    equAuto = autoLevels(img, cutoff)
    luminance = luminance_adjust(0.3, equAuto)
    remove_bg_img = remove_background(luminance)
    print(1)
    cv2.imshow("source", img)
    cv2.imshow("levels", equAuto)
    cv2.imshow("luminance", luminance)
    cv2.imshow("remove_bg_img", remove_bg_img)
    cv2.waitKey(0)
    image = cv2.imread("1.png")
    remove_background(image)

View File

@@ -1,112 +0,0 @@
import cv2
import mmcv
import numpy as np
import torch
from PIL import Image
import tritonclient.http as httpclient
import torch.nn.functional as F
from app.core.config import *
def seg_preprocess(img_path):
    """Load an image and build an NCHW float batch for the segmentation model.

    Returns (1x3x224x224 preprocessed array, original (h, w) shape).
    """
    loaded = mmcv.imread(img_path)
    original_shape = loaded.shape[:2]
    resized, _, _ = mmcv.imresize(loaded, (224, 224), return_scale=True)
    # Normalise with ImageNet mean/std and convert BGR -> RGB.
    normalized = mmcv.imnormalize(
        resized,
        mean=np.array([123.675, 116.28, 103.53]),
        std=np.array([58.395, 57.12, 57.375]),
        to_rgb=True,
    )
    batch = np.expand_dims(normalized.transpose(2, 0, 1), axis=0)
    return batch, original_shape
def get_mask(image_obj):
    """Build a foreground mask for *image_obj*.

    Grayscale input is promoted to 3 channels; an alpha channel, if present,
    is split off and ANDed into the final mask.

    Returns (3-channel image, uint8 mask).
    """
    alpha_mask = None
    if image_obj.ndim == 2:
        image_obj = cv2.cvtColor(image_obj, cv2.COLOR_GRAY2RGB)
    if image_obj.shape[2] == 4:  # 4-channel input: keep alpha separately
        alpha_mask = image_obj[:, :, 3]
        image_obj = image_obj[:, :, :3]

    contours = get_contours(image_obj)
    contour_mask = np.zeros(image_obj.shape[:2], np.uint8)
    if len(contours):
        # Fill a slightly simplified polygon of the largest contour.
        largest = contours[0]
        epsilon = 0.001 * cv2.arcLength(largest, True)
        polygon = cv2.approxPolyDP(largest, epsilon, True)
        cv2.drawContours(contour_mask, [polygon], -1, 255, -1)
    else:
        # No contours found: treat the whole frame as foreground.
        contour_mask = np.ones(image_obj.shape[:2], np.uint8) * 255

    if alpha_mask is None:
        return image_obj, contour_mask
    return image_obj, cv2.bitwise_and(contour_mask, alpha_mask)
def get_contours(image):
    """Return the external contours of *image*, largest area first."""
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    edges = cv2.Canny(gray, 10, 150)
    # Morphological close (dilate then erode) to seal small edge gaps.
    kernel = np.ones((5, 5), np.uint8)
    edges = cv2.dilate(edges, kernel=kernel, iterations=1)
    edges = cv2.erode(edges, kernel=kernel, iterations=1)
    found, _ = cv2.findContours(edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    return sorted(found, key=cv2.contourArea, reverse=True)
def seg_infer_image(image_obj):
    """Run semantic segmentation on *image_obj* via the Triton server.

    Returns the post-processed per-pixel class prediction.
    """
    batch, original_shape = seg_preprocess(image_obj)
    payload = batch.astype(np.float32)

    # Request input tensor.
    request_inputs = [
        httpclient.InferInput(SEGMENTATION['input'], payload.shape, datatype="FP32")
    ]
    request_inputs[0].set_data_from_numpy(payload, binary_data=True)
    # Requested output tensor.
    request_outputs = [
        httpclient.InferRequestedOutput(SEGMENTATION['output'], binary_data=True),
    ]

    # Run inference and collect the raw logits.
    client = httpclient.InferenceServerClient(url=f"{SEG_MODEL_URL}")
    response = client.infer(
        model_name=SEGMENTATION['name'],
        inputs=request_inputs,
        outputs=request_outputs,
    )
    logits = torch.from_numpy(response.as_numpy(SEGMENTATION['output']))
    return seg_postprocess(logits, original_shape)
def seg_postprocess(output, ori_shape):
    """Upsample raw logits to *ori_shape* and take the per-pixel argmax.

    Returns a numpy array of class indices with shape (N, H, W).
    """
    upsampled = F.interpolate(
        output, size=ori_shape, scale_factor=None, mode='bilinear', align_corners=False
    )
    probabilities = F.softmax(upsampled, dim=1)
    return probabilities.argmax(dim=1).cpu().numpy()
def remove_background(image):
    """Replace everything outside the segmented garment with white.

    Runs segmentation (classes 1 and 2 are foreground front/back),
    intersects each class mask with the contour/alpha mask, and
    composites the foreground onto a white canvas.

    Returns:
        PIL.Image of the composited result.
    """
    image_obj, mask = get_mask(image)
    seg_result = seg_infer_image(image_obj)

    # Per-class binary masks gated by the contour mask.
    front_mask = mask * (seg_result == 1).astype(np.uint8)
    back_mask = mask * (seg_result == 2).astype(np.uint8)

    # seg_result carries a leading batch axis; drop it if present.
    # Fixed: the old code had redundant `else: x = x` branches here.
    if front_mask.ndim > 2:
        front_mask = front_mask[0]
    if back_mask.ndim > 2:
        back_mask = back_mask[0]

    result_mask = front_mask + back_mask
    white_background = np.ones_like(image_obj) * 255
    result_image = np.where(result_mask[:, :, None].astype(bool), image_obj, white_background)
    return Image.fromarray(result_image)

View File

@@ -10,6 +10,7 @@
import io
import logging
import cv2
from PIL import Image
from minio import Minio
@@ -20,18 +21,47 @@ minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET
def upload_png_sd(image, user_id, category, object_name):
    """JPEG-encode *image* and upload it to the user's MinIO bucket.

    Args:
        image: ndarray encodable by ``cv2.imencode``.
        user_id, category, object_name: compose the object key
            ``<user_id>/<category>/<object_name>``.

    Returns:
        Object URL ``aida-users/<key>``, or None on failure (logged).
    """
    try:
        # Reconstructed post-merge version: the cv2.imencode path replaces the
        # old PIL round-trip, consistent with the sibling upload helpers.
        _, img_byte_array = cv2.imencode('.jpg', image)
        minio_req = minio_client.put_object(
            GI_MINIO_BUCKET,
            f'{user_id}/{category}/{object_name}',
            io.BytesIO(img_byte_array),
            len(img_byte_array),
            content_type='image/jpeg'
        )
        image_url = f"aida-users/{minio_req.object_name}"
        return image_url
    except Exception as e:
        # Fixed: message previously said "upload_png_mask" (copy-paste).
        logging.warning(f"upload_png_sd runtime exception : {e}")
def upload_stain_png_sd(image, user_id, category, object_name):
    """Upload a stain-detection debug image to the fixed ``test`` bucket.

    The object key is ``generate_result/stain/<user_id>_<category>_<object_name>``.

    Returns:
        Object URL ``test/<key>``, or None on failure (logged).
    """
    try:
        _, img_byte_array = cv2.imencode('.jpg', image)
        minio_req = minio_client.put_object(
            "test",
            f'generate_result/stain/{user_id}_{category}_{object_name}',
            io.BytesIO(img_byte_array),
            len(img_byte_array),
            content_type='image/jpeg'
        )
        image_url = f"test/{minio_req.object_name}"
        return image_url
    except Exception as e:
        # Fixed: message previously said "upload_png_mask" (copy-paste).
        logging.warning(f"upload_stain_png_sd runtime exception : {e}")
def upload_face_png_sd(image, user_id, category, object_name):
    """Upload a face-detection debug image to the fixed ``test`` bucket.

    The object key is ``generate_result/face/<user_id>_<category>_<object_name>``.

    Returns:
        Object URL ``test/<key>``, or None on failure (logged).
    """
    try:
        _, img_byte_array = cv2.imencode('.jpg', image)
        minio_req = minio_client.put_object(
            "test",
            f'generate_result/face/{user_id}_{category}_{object_name}',
            io.BytesIO(img_byte_array),
            len(img_byte_array),
            content_type='image/jpeg'
        )
        image_url = f"test/{minio_req.object_name}"
        return image_url
    except Exception as e:
        # Fixed: message previously said "upload_png_mask" (copy-paste).
        logging.warning(f"upload_face_png_sd runtime exception : {e}")