Files
AiDA_Python/app/service/design_pre_processing/service.py

373 lines
20 KiB
Python
Raw Normal View History

2024-05-30 09:48:13 +08:00
import logging
import os
2024-05-30 09:48:13 +08:00
import time
import cv2
import numpy as np
import torch
import tritonclient.grpc as grpcclient
from minio import Minio
2026-01-23 17:34:51 +08:00
# from pymilvus import MilvusClient
from urllib3.exceptions import ResponseError
2024-05-30 09:48:13 +08:00
from app.core.config import settings, SR_MODEL_NAME, SR_TRITON_URL, MILVUS_TABLE_KEYPOINT, KEYPOINT_RESULT_TABLE_FIELD_SET
from app.schemas.pre_processing import DesignPreProcessingModel
2024-09-26 06:09:05 +00:00
from app.service.design_fast.utils.design_ensemble import get_seg_result, get_keypoint_result
from app.service.utils.new_oss_client import oss_get_image, oss_upload_image
2024-05-30 09:48:13 +08:00
# Shared module-level logger and MinIO client used by every pipeline step below.
logger = logging.getLogger()
minio_client = Minio(settings.MINIO_URL, access_key=settings.MINIO_ACCESS, secret_key=settings.MINIO_SECRET, secure=settings.MINIO_SECURE)
2024-05-30 09:48:13 +08:00
class DesignPreprocessing:
    """Pre-processing pipeline for design sketch images.

    Downloads sketches from object storage, pads them with a white margin,
    runs keypoint (and, for upper garments, segmentation) inference, and
    uploads a display version back to object storage.
    """

    # @ RunTime
    def pipeline(self, image_list):
        """Run the full pre-processing pipeline over *image_list*.

        Each element is a dict with at least 'image_url', 'image_id' and
        'image_category'. Intermediate image arrays are attached under
        'image_obj' / 'obj' / 'keypoint_result' and stripped again before
        returning, so the result is JSON-serializable.
        """
        sketches_list = self.read_image(image_list)
        bounding_box_sketches_list = self.bounding_box(sketches_list)
        # Super-resolution step is currently disabled:
        # super_resolution_list = self.super_resolution(bounding_box_sketches_list)
        infer_sketches_list = self.infer_image(bounding_box_sketches_list)
        result = self.composing_image(infer_sketches_list)
        # Drop non-serializable intermediate objects before returning.
        for d in result:
            for key in ('image_obj', 'obj', 'keypoint_result'):
                d.pop(key, None)
        return result
@staticmethod
def read_image(image_list):
2024-05-30 09:48:13 +08:00
for obj in image_list:
2024-06-21 17:13:39 +08:00
# file = self.minio_client.get_object(obj['image_url'].split("/", 1)[0], obj['image_url'].split("/", 1)[1]).data
# image = cv2.imdecode(np.frombuffer(file, np.uint8), 1)
image = oss_get_image(oss_client=minio_client, bucket=obj['image_url'].split("/", 1)[0], object_name=obj['image_url'].split("/", 1)[1], data_type="cv2")
2024-05-30 09:48:13 +08:00
if len(image.shape) == 2:
image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
elif image.shape[2] == 4: # 如果是四通道 mask
# 分离RGB和Alpha通道
bgr = image[:, :, :3]
alpha = image[:, :, 3]
# 创建白色背景(也可改为其他颜色,如(255,255,255)就是白色)
background_color = (255, 255, 255)
background = np.full_like(bgr, background_color)
# 将Alpha通道转换为掩码0=透明255=不透明)
alpha_mask = alpha / 255.0 # 归一化到0-1
alpha_mask = np.expand_dims(alpha_mask, axis=-1) # 扩展维度,方便广播计算
# 混合背景和原图:透明区域显示背景色,不透明区域显示原图
image = (bgr * alpha_mask + background * (1 - alpha_mask)).astype(np.uint8)
# 此时image已经是3通道RGB无需再执行image = image[:, :, :3]
2024-05-30 09:48:13 +08:00
obj["image_obj"] = image
return image_list
# @ RunTime
@staticmethod
def bounding_box(image_list):
2024-05-30 09:48:13 +08:00
for item in image_list:
image = item['image_obj']
height, width = image.shape[:2]
2024-05-30 09:48:13 +08:00
# 使用Canny边缘检测来检测物体的轮廓
edges = cv2.Canny(image, 50, 150)
# 查找轮廓
contours, _ = cv2.findContours(edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
# 初始化包围所有外接矩形的大矩形的坐标
x_min, y_min, x_max, y_max = float('inf'), float('inf'), -1, -1
# 遍历所有外接矩形,更新大矩形的坐标
for contour in contours:
x, y, w, h = cv2.boundingRect(contour)
x_min = min(x_min, x)
y_min = min(y_min, y)
x_max = max(x_max, x + w)
y_max = max(y_max, y + h)
# 根据大矩形的坐标来裁剪原始图像
if len(contours) > 0:
cropped_image = image[y_min:y_max, x_min:x_max]
item['obj'] = cropped_image # 新shape图像
else:
item['obj'] = image
padding_top = max(20 - y_min, 0)
padding_bottom = max(20 - (height - y_max), 0)
padding_left = max(20 - x_min, 0)
padding_right = max(20 - (width - x_max), 0)
# 添加padding
padded_image = cv2.copyMakeBorder(
image,
padding_top,
padding_bottom,
padding_left,
padding_right,
cv2.BORDER_CONSTANT,
2024-08-20 11:15:24 +08:00
value=(255, 255, 255)
)
item['obj'] = padded_image
2024-05-30 09:48:13 +08:00
return image_list
    @staticmethod
    def super_resolution(image_list):
        """4x super-resolve small sketches via a Triton-served SR model and
        overwrite the stored object in MinIO with the upscaled image.

        Only images with both sides <= 512 AND at least one side <= 200 are
        upscaled (this step performs 4x SR, so larger inputs would produce
        oversized output).
        """
        for item in image_list:
            # Both sides must be <= 512 because this performs 4x super-resolution.
            if item['obj'].shape[0] <= 512 and item['obj'].shape[1] <= 512:
                # Upscale only when at least one side is <= 200.
                if item['obj'].shape[0] <= 200 or item['obj'].shape[1] <= 200:
                    # Normalise to [0, 1] float32, BGR -> RGB, HWC -> CHW.
                    img = item['obj'].astype(np.float32) / 255.
                    sample = np.transpose(img if img.shape[2] == 1 else img[:, :, [2, 1, 0]], (2, 0, 1))
                    # Add the batch dimension expected by the model input.
                    sample = torch.from_numpy(sample).float().unsqueeze(0).numpy()
                    inputs = [
                        grpcclient.InferInput("input", sample.shape, datatype="FP32")
                    ]
                    inputs[0].set_data_from_numpy(sample)
                    triton_client = grpcclient.InferenceServerClient(url=SR_TRITON_URL)
                    result = triton_client.infer(model_name=SR_MODEL_NAME, inputs=inputs)
                    result_image = result.as_numpy(f'output')[0]
                    sr_output = torch.tensor(result_image)
                    # Clamp to [0, 1] and drop singleton dimensions.
                    output = sr_output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
                    if output.ndim == 3:
                        output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))  # CHW-RGB to HWC-BGR
                    output = (output * 255.0).round().astype(np.uint8)
                    item['obj'] = output
                    # NOTE(review): this upload is nested inside the SR branch, i.e.
                    # only upscaled images are re-uploaded — confirm intended scope.
                    try:
                        # Overwrite the original object in MinIO with the SR result.
                        image_bytes = cv2.imencode(".jpg", item['obj'])[1].tobytes()
                        bucket_name = item['image_url'].split("/", 1)[0]
                        object_name = item['image_url'].split("/", 1)[1]
                        oss_upload_image(oss_client=minio_client, bucket=bucket_name, object_name=object_name, image_bytes=image_bytes)
                        logging.info(f"Object '{item['image_url'].split('/', 1)[1]}' overwritten successfully.")
                    except ResponseError as err:
                        logging.warning(f"Error: {err}")
        return image_list
# @ RunTime
def infer_image(self, image_list):
for sketch in image_list:
# 小写
image_category = sketch['image_category'].lower()
# 判断上下装
sketch['site'] = 'up' if image_category in ['blouse', 'outwear', 'dress', 'tops'] else 'down'
# 推理得到keypoint
sketch['keypoint_result'] = self.keypoint_cache(sketch)
if sketch['site'] == 'up':
_, seg_cache = self.load_seg_result(sketch['image_id'])
if not _:
# 推理获得seg 结果
seg_result = get_seg_result(sketch['obj'])[0]
self.save_seg_result(seg_result, sketch['image_id'])
2024-08-20 11:15:24 +08:00
logger.info(f"{sketch['image_id']} image size is :{sketch['obj'].shape} , seg cache size is :{seg_result.shape}")
else:
logger.info(f"{sketch['image_id']} image size is :{sketch['obj'].shape} , seg cache size is :{seg_cache.shape}")
2024-05-30 09:48:13 +08:00
return image_list
# @ RunTime
@staticmethod
def composing_image(image_list):
2024-05-30 09:48:13 +08:00
for image in image_list:
2024-06-21 17:13:39 +08:00
''' 比例相同 整合上下装代码'''
image_width = image['obj'].shape[1]
waist_width = image['keypoint_result']['waistband_right'][1] - image['keypoint_result']['waistband_left'][1]
scale = 0.4
if waist_width / scale >= image_width:
add_width = int((waist_width / scale - image_width) / 2)
ret = cv2.copyMakeBorder(image['obj'], 0, 0, add_width, add_width, cv2.BORDER_CONSTANT, value=(255, 255, 255))
img_rgba = cv2.cvtColor(ret, cv2.COLOR_RGB2RGBA)
image_bytes = cv2.imencode(".png", img_rgba)[1].tobytes()
2024-06-21 17:13:39 +08:00
# image['show_image_url'] = f"{image['image_url'].split('/', 1)[0]}/{self.minio_client.put_object(image['image_url'].split('/', 1)[0], image['image_url'].split('/', 1)[1].replace('.', '-show.'), io.BytesIO(image_bytes), len(image_bytes), content_type='image/jpeg').object_name}"
bucket_name = image['image_url'].split('/', 1)[0]
object_name = image['image_url'].split('/', 1)[1].replace('.', '-show.')
oss_upload_image(oss_client=minio_client, bucket=bucket_name, object_name=object_name, image_bytes=image_bytes)
2024-06-21 17:13:39 +08:00
image['show_image_url'] = f"{bucket_name}/{object_name}"
2024-05-30 09:48:13 +08:00
else:
2024-06-21 17:13:39 +08:00
image_bytes = cv2.imencode(".jpg", image['obj'])[1].tobytes()
# image['show_image_url'] = f"{image['image_url'].split('/', 1)[0]}/{self.minio_client.put_object(image['image_url'].split('/', 1)[0], image['image_url'].split('/', 1)[1].replace('.', '-show.'), io.BytesIO(image_bytes), len(image_bytes), content_type='image/jpeg').object_name}"
bucket_name = image['image_url'].split('/', 1)[0]
object_name = image['image_url'].split('/', 1)[1].replace('.', '-show.')
oss_upload_image(oss_client=minio_client, bucket=bucket_name, object_name=object_name, image_bytes=image_bytes)
2024-06-21 17:13:39 +08:00
image['show_image_url'] = f"{bucket_name}/{object_name}"
# if image['site'] == 'down':
# image_width = image['obj'].shape[1]
# waist_width = image['keypoint_result']['waistband_right'][1] - image['keypoint_result']['waistband_left'][1]
# scale = 0.4
# if waist_width / scale >= image_width:
# add_width = int((waist_width / scale - image_width) / 2)
# ret = cv2.copyMakeBorder(image['obj'], 0, 0, add_width, add_width, cv2.BORDER_CONSTANT, value=(256, 256, 256))
# if IF_DEBUG_SHOW:
# cv2.imshow("composing_image", ret)
# cv2.waitKey(0)
# image_bytes = cv2.imencode(".jpg", ret)[1].tobytes()
# # image['show_image_url'] = f"{image['image_url'].split('/', 1)[0]}/{self.minio_client.put_object(image['image_url'].split('/', 1)[0], image['image_url'].split('/', 1)[1].replace('.', '-show.'), io.BytesIO(image_bytes), len(image_bytes), content_type='image/jpeg').object_name}"
# bucket_name = image['image_url'].split('/', 1)[0]
# object_name = image['image_url'].split('/', 1)[1]
# oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes)
# image['show_image_url'] = f"{bucket_name}/{object_name}"
# else:
# image_bytes = cv2.imencode(".jpg", image['obj'])[1].tobytes()
# # image['show_image_url'] = f"{image['image_url'].split('/', 1)[0]}/{self.minio_client.put_object(image['image_url'].split('/', 1)[0], image['image_url'].split('/', 1)[1].replace('.', '-show.'), io.BytesIO(image_bytes), len(image_bytes), content_type='image/jpeg').object_name}"
# bucket_name = image['image_url'].split('/', 1)[0]
# object_name = image['image_url'].split('/', 1)[1]
# oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes)
# image['show_image_url'] = f"{bucket_name}/{object_name}"
# else:
# image_width = image['obj'].shape[1]
# waist_width = image['keypoint_result']['waistband_right'][1] - image['keypoint_result']['waistband_left'][1]
# scale = 0.4
# if waist_width / scale >= image_width:
# add_width = int((waist_width / scale - image_width) / 2)
# ret = cv2.copyMakeBorder(image['obj'], 0, 0, add_width, add_width, cv2.BORDER_CONSTANT, value=(256, 256, 256))
# if IF_DEBUG_SHOW:
# cv2.imshow("composing_image", ret)
# cv2.waitKey(0)
# image_bytes = cv2.imencode(".jpg", ret)[1].tobytes()
# # image['show_image_url'] = f"{image['image_url'].split('/', 1)[0]}/{self.minio_client.put_object(image['image_url'].split('/', 1)[0], image['image_url'].split('/', 1)[1].replace('.', '-show.'), io.BytesIO(image_bytes), len(image_bytes), content_type='image/jpeg').object_name}"
# bucket_name = image['image_url'].split('/', 1)[0]
# object_name = image['image_url'].split('/', 1)[1]
# oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes)
# image['show_image_url'] = f"{bucket_name}/{object_name}"
# else:
# image_bytes = cv2.imencode(".jpg", image['obj'])[1].tobytes()
# # image['show_image_url'] = f"{image['image_url'].split('/', 1)[0]}/{self.minio_client.put_object(image['image_url'].split('/', 1)[0], image['image_url'].split('/', 1)[1].replace('.', '-show.'), io.BytesIO(image_bytes), len(image_bytes), content_type='image/jpeg').object_name}"
# bucket_name = image['image_url'].split('/', 1)[0]
# object_name = image['image_url'].split('/', 1)[1]
# oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes)
# image['show_image_url'] = f"{bucket_name}/{object_name}"
2024-05-30 09:48:13 +08:00
return image_list
@staticmethod
def load_seg_result(image_id):
file_path = f"{settings.SEG_CACHE_PATH}{image_id}.npy"
2024-05-30 09:48:13 +08:00
try:
seg_result = np.load(file_path)
return True, seg_result
except FileNotFoundError:
# logging.info("文件不存在")
return False, None
except Exception as e:
logging.warning(f"加载失败: {e}")
return False, None
2024-05-30 09:48:13 +08:00
@staticmethod
def save_seg_result(seg_result, image_id):
file_path = f"{settings.SEG_CACHE_PATH}{image_id}.npy"
2024-05-30 09:48:13 +08:00
try:
np.save(file_path, seg_result)
logging.debug(f"保存成功,{os.path.abspath(file_path)}")
2024-05-30 09:48:13 +08:00
except Exception as e:
logging.warning(f"保存失败: {e}")
2024-05-30 09:48:13 +08:00
    def keypoint_cache(self, sketch):
        """Return the keypoint result for *sketch*, consulting the (currently
        disabled) Milvus cache.

        With the Milvus lookup commented out, ``res`` is always empty, so every
        call currently falls through to fresh inference + save_keypoint_cache.
        Returns a field-name -> coordinate-pair dict on success.

        NOTE(review): on any exception this returns False, but callers such as
        composing_image index the result as a dict — confirm intended behavior.
        """
        try:
            # client = MilvusClient(uri=settings.MILVUS_URL, token=settings.MILVUS_TOKEN, db_name=settings.MILVUS_ALIAS)
            keypoint_id = sketch['image_id']
            # res = client.query(
            #     collection_name=MILVUS_TABLE_KEYPOINT,
            #     # ids=[keypoint_id],
            #     filter=f"keypoint_id == {keypoint_id}",
            #     output_fields=['keypoint_vector', 'keypoint_site']
            # )
            res = []  # stubbed: Milvus query disabled above
            if len(res) == 0:
                # No cached entry: run inference and persist the result.
                keypoint_infer_result = self.infer_keypoint_result(sketch)
                return self.save_keypoint_cache(sketch, keypoint_infer_result)
            elif res[0]["keypoint_site"] == "all" or res[0]["keypoint_site"] == sketch['site']:
                # Cached site matches the request (or covers both sites): use it directly.
                return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, np.array(res[0]['keypoint_vector']).astype(int).reshape(12, 2).tolist()))
            elif res[0]["keypoint_site"] != sketch['site']:
                # Cached site differs: infer the missing site and upgrade the cache entry to 'all'.
                keypoint_infer_result = self.infer_keypoint_result(sketch)
                return self.update_keypoint_cache(sketch, keypoint_infer_result, res[0]['keypoint_vector'])
        except Exception as e:
            logging.info(f"search keypoint cache milvus error {e}")
            return False
# @ RunTime
@staticmethod
def infer_keypoint_result(sketch):
2024-05-30 09:48:13 +08:00
keypoint_infer_result = get_keypoint_result(sketch["obj"], sketch['site']) # 推理结果
return keypoint_infer_result
@staticmethod
# @ RunTime
def save_keypoint_cache(sketch, keypoint_infer_result):
if sketch['site'] == "down":
zeros = np.zeros(20, dtype=int)
result = np.concatenate([zeros, keypoint_infer_result.flatten()])
else:
zeros = np.zeros(4, dtype=int)
result = np.concatenate([keypoint_infer_result.flatten(), zeros])
# [
# [int(sketch['image_id'])],
# [sketch['site']],
# [result.tolist()]
# ]
2024-05-30 09:48:13 +08:00
try:
# connections.connect(alias=MILVUS_ALIAS, host=MILVUS_DB_HOST, port=MILVUS_PORT)
time.time()
2024-05-30 09:48:13 +08:00
# collection = Collection(MILVUS_TABLE_KEYPOINT) # Get an existing collection.
# mr = collection.insert(data)
# logging.info(f"save keypoint time : {time.time() - start_time}")
return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
except Exception as e:
logging.info(f"save keypoint cache milvus error : {e}")
return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
@staticmethod
def update_keypoint_cache(sketch, infer_result, search_result):
if sketch['site'] == "up":
# 需要的是up 即推理出来的是up 那么查询的就是down
result = np.concatenate([infer_result.flatten(), search_result[-4:]])
else:
# 需要的是down 即推理出来的是down 那么查询的就是up
result = np.concatenate([search_result[:20], infer_result.flatten()])
# [
# [int(sketch['image_id'])],
# ["all"],
# [result.tolist()]
# ]
2024-05-30 09:48:13 +08:00
try:
# connections.connect(alias=MILVUS_ALIAS, host=MILVUS_DB_HOST, port=MILVUS_PORT)
# start_time = time.time()
2024-05-30 09:48:13 +08:00
# collection = Collection(MILVUS_TABLE_KEYPOINT) # Get an existing collection.
# mr = collection.upsert(data)
# logging.info(f"save keypoint time : {time.time() - start_time}")
return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
except Exception as e:
logging.info(f"save keypoint cache milvus error : {e}")
return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
if __name__ == '__main__':
    # Ad-hoc smoke test: run one sketch through the full pipeline.
    sample_payload = {
        "sketches": [
            {
                "image_category": "blouse",
                "image_id": "123123123",
                "image_url": "test/0628000198.jpg"
            }
        ]
    }
    request_model = DesignPreProcessingModel(sketches=sample_payload["sketches"])
    processor = DesignPreprocessing()
    output = processor.pipeline(image_list=request_model.sketches)
    print(output)