Files
AiDA_Python/app/service/design_pre_processing/service.py
zhouchengrong 2df1518a99 feat
fix  minio and s3
2024-06-21 17:13:39 +08:00

362 lines
20 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import logging
import time
import cv2
import numpy as np
import torch
from minio import Minio
from pymilvus import connections, Collection
from urllib3.exceptions import ResponseError
import torch.nn.functional as F
import tritonclient.grpc as grpcclient
import io
from app.core.config import *
from app.service.design.utils.design_ensemble import get_keypoint_result
from app.service.utils.oss_client import oss_get_image, oss_upload_image
class DesignPreprocessing:
    """Pre-process design sketches before they are used downstream.

    ``pipeline`` runs, in order: download from object storage
    (``read_image``), crop to the box enclosing detected edge contours
    (``bounding_box``), optional 4x super-resolution through Triton
    (``super_resolution``), keypoint inference (``infer_image``) and white
    horizontal padding driven by the waistband width (``composing_image``).
    Processed images are uploaded back over the object named by each item's
    ``image_url`` (``"<bucket>/<object name>"``).
    """

    # Keys added during processing and stripped from the pipeline's output.
    _INTERMEDIATE_KEYS = ('image_obj', 'obj', 'keypoint_result')
    # Lower-cased categories treated as upper-body ('up'); everything else is 'down'.
    _UPPER_BODY_CATEGORIES = frozenset({'blouse', 'outwear', 'dress', 'tops'})
    # Target ratio of waistband width to final image width used when padding.
    _WAIST_TO_WIDTH_RATIO = 0.4

    def pipeline(self, image_list):
        """Run every preprocessing stage over *image_list*.

        Items are mutated in place. On return each item has gained ``site``
        and ``show_image_url``, and the keys in ``_INTERMEDIATE_KEYS`` have
        been removed.

        :param image_list: list of dicts with at least ``image_url`` and
            ``image_category``.
        :return: the same list.
        """
        sketches_list = self.read_image(image_list)
        logging.info("read image success")
        bounding_box_sketches_list = self.bounding_box(sketches_list)
        logging.info("bounding box image success")
        super_resolution_list = self.super_resolution(bounding_box_sketches_list)
        logging.info("super_resolution_list image success")
        infer_sketches_list = self.infer_image(super_resolution_list)
        logging.info("infer image success")
        result = self.composing_image(infer_sketches_list)
        logging.info("Replenish white edge image success")
        return self._strip_intermediate_keys(result)

    @classmethod
    def _strip_intermediate_keys(cls, items):
        """Remove intermediate processing keys from every dict in *items* (in place)."""
        for item in items:
            for key in cls._INTERMEDIATE_KEYS:
                item.pop(key, None)
        return items

    @staticmethod
    def _split_object_path(image_url):
        """Split ``"<bucket>/<object name>"`` into ``(bucket, object_name)``.

        Raises IndexError when *image_url* contains no ``/``.
        """
        parts = image_url.split("/", 1)
        return parts[0], parts[1]

    @classmethod
    def _upload_jpeg(cls, image_url, image):
        """JPEG-encode *image* and upload it over the object named by *image_url*.

        :return: ``(bucket_name, object_name)`` that was written.
        :raises ValueError: if OpenCV fails to encode the image (instead of
            silently uploading empty bytes over the original object).
        """
        ok, buffer = cv2.imencode(".jpg", image)
        if not ok:
            raise ValueError(f"failed to JPEG-encode image for {image_url!r}")
        bucket_name, object_name = cls._split_object_path(image_url)
        oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=buffer.tobytes())
        return bucket_name, object_name

    def read_image(self, image_list):
        """Download each item's image and store it as ``item['image_obj']``.

        Images are fetched with ``oss_get_image(..., data_type="cv2")``.
        Single-channel images are expanded to 3 channels and 4-channel images
        have their 4th channel dropped.
        """
        for obj in image_list:
            bucket_name, object_name = self._split_object_path(obj['image_url'])
            image = oss_get_image(bucket=bucket_name, object_name=object_name, data_type="cv2")
            if image is None:
                # Fail with a clear message instead of an AttributeError below.
                raise ValueError(f"failed to load image {obj['image_url']!r}")
            if image.ndim == 2:
                image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
            elif image.shape[2] == 4:
                # Four-channel (mask) image: keep only the first three channels.
                image = image[:, :, :3]
            obj["image_obj"] = image
        return image_list

    def bounding_box(self, image_list):
        """Crop each ``image_obj`` to the box enclosing all detected edge contours.

        The crop is stored as ``item['obj']``. When no contours are found, the
        uncropped image is used. The original object in storage is not
        overwritten here.
        """
        for item in image_list:
            image = item['image_obj']
            # Canny edges, then the outer contours of whatever is drawn.
            edges = cv2.Canny(image, 50, 150)
            contours, _ = cv2.findContours(edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            if len(contours) == 0:
                item['obj'] = image
                continue
            # Union of the bounding rectangles of every contour.
            x_min = y_min = float('inf')
            x_max = y_max = -1
            for contour in contours:
                x, y, w, h = cv2.boundingRect(contour)
                x_min = min(x_min, x)
                y_min = min(y_min, y)
                x_max = max(x_max, x + w)
                y_max = max(y_max, y + h)
            if IF_DEBUG_SHOW:
                image_with_big_rect = cv2.rectangle(image.copy(), (x_min, y_min), (x_max, y_max), (0, 255, 0), 2)
                cv2.imshow("bounding_box image", image_with_big_rect)
                cv2.waitKey(0)
            item['obj'] = image[y_min:y_max, x_min:x_max]
        return image_list

    def super_resolution(self, image_list):
        """Upscale small crops with the Triton super-resolution model.

        An image is upscaled only when both sides are <= 512 px (the model
        upscales 4x) and at least one side is <= 256 px. The upscaled image
        replaces ``item['obj']`` and is uploaded over ``item['image_url']``.
        A single Triton client is opened lazily and closed before returning.
        """
        triton_client = None
        try:
            for item in image_list:
                height, width = item['obj'].shape[:2]
                if not (height <= 512 and width <= 512):
                    continue
                if not (height <= 256 or width <= 256):
                    continue
                if triton_client is None:
                    triton_client = grpcclient.InferenceServerClient(url=SR_TRITON_URL)
                item['obj'] = self._infer_super_resolution(triton_client, item['obj'])
                try:
                    _, object_name = self._upload_jpeg(item['image_url'], item['obj'])
                    logging.info(f"Object '{object_name}' overwritten successfully.")
                except ResponseError as err:
                    # NOTE(review): ResponseError is imported from urllib3.exceptions;
                    # confirm this matches what oss_upload_image actually raises.
                    logging.error(f"Error: {err}")
        finally:
            if triton_client is not None:
                triton_client.close()
        return image_list

    @staticmethod
    def _infer_super_resolution(triton_client, image):
        """Run the SR model on one BGR uint8 image and return a BGR uint8 result."""
        img = image.astype(np.float32) / 255.
        # HWC-BGR -> CHW-RGB, plus a leading batch axis.
        chw = np.transpose(img if img.shape[2] == 1 else img[:, :, [2, 1, 0]], (2, 0, 1))
        sample = np.ascontiguousarray(chw[np.newaxis], dtype=np.float32)
        inputs = [grpcclient.InferInput("input", sample.shape, datatype="FP32")]
        inputs[0].set_data_from_numpy(sample)
        result = triton_client.infer(model_name=SR_MODEL_NAME, inputs=inputs)
        output = np.clip(np.squeeze(result.as_numpy('output')[0]).astype(np.float32), 0, 1)
        if output.ndim == 3:
            output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))  # CHW-RGB to HWC-BGR
        return (output * 255.0).round().astype(np.uint8)

    def infer_image(self, image_list):
        """Classify each sketch as 'up'/'down' and attach its keypoints.

        Sets ``sketch['site']`` from ``image_category`` and
        ``sketch['keypoint_result']`` from ``keypoint_cache`` (which is
        ``False`` on failure). Segmentation inference for 'up' items is
        currently disabled.
        """
        for sketch in image_list:
            sketch['site'] = self._site_for_category(sketch['image_category'])
            sketch['keypoint_result'] = self.keypoint_cache(sketch)
            if IF_DEBUG_SHOW and sketch['keypoint_result']:
                debug_show_image = sketch['obj'].copy()
                for point in sketch['keypoint_result'].values():
                    # Keypoints appear to be (row, col); cv2.circle expects (x, y).
                    cv2.circle(debug_show_image, (int(point[1]), int(point[0])), 1, (0, 0, 255), 4)
                cv2.imshow("", debug_show_image)
                cv2.waitKey(0)
        return image_list

    @classmethod
    def _site_for_category(cls, image_category):
        """Return 'up' for upper-body categories (case-insensitive), else 'down'."""
        return 'up' if image_category.lower() in cls._UPPER_BODY_CATEGORIES else 'down'

    def composing_image(self, image_list):
        """Pad each image with white columns based on its waistband width, then upload it.

        When ``waist_width / _WAIST_TO_WIDTH_RATIO`` is at least the image
        width, equal white borders are added left and right so the waistband
        spans roughly that ratio of the new width. The (possibly padded) image
        is uploaded over ``image_url`` and ``show_image_url`` is set to
        ``"<bucket>/<object name>"``. Items without keypoints are uploaded
        unpadded.
        """
        for image in image_list:
            obj = image['obj']
            keypoints = image.get('keypoint_result')
            if keypoints:
                add_width = self._waist_padding(obj.shape[1], keypoints, self._WAIST_TO_WIDTH_RATIO)
            else:
                logging.warning("keypoints missing for %s; uploading without white-edge padding",
                                image.get('image_url'))
                add_width = 0
            if add_width > 0:
                obj = cv2.copyMakeBorder(obj, 0, 0, add_width, add_width, cv2.BORDER_CONSTANT,
                                         value=(255, 255, 255))
                if IF_DEBUG_SHOW:
                    cv2.imshow("composing_image", obj)
                    cv2.waitKey(0)
            bucket_name, object_name = self._upload_jpeg(image['image_url'], obj)
            image['show_image_url'] = f"{bucket_name}/{object_name}"
        return image_list

    @staticmethod
    def _waist_padding(image_width, keypoint_result, scale=0.4):
        """Return the white border width (px) to add on each side, or 0.

        :param image_width: current image width in pixels.
        :param keypoint_result: mapping with ``waistband_left`` and
            ``waistband_right`` points; index 1 of each point is used as the
            horizontal coordinate.
        :param scale: target ratio of waistband width to padded image width.
        """
        waist_width = keypoint_result['waistband_right'][1] - keypoint_result['waistband_left'][1]
        target_width = waist_width / scale
        if target_width >= image_width:
            return int((target_width - image_width) / 2)
        return 0

    @staticmethod
    def select_seg_result(image_id, image_obj):
        """Load a cached segmentation result from ``seg_result/<image_id>.npy``.

        :return: the cached array (cast to int64) when its dimensions 1 and 2
            match the height and width of *image_obj*; otherwise ``False``.
            A missing cache file also yields ``False``.
        """
        try:
            result = np.load(f"seg_result/{image_id}.npy").astype(np.int64)
        except FileNotFoundError as e:
            logging.warning(f"{image_id} Image segmentation results cache file does not exist : {e}")
            return False
        # Cached arrays presumably have shape (C, H, W) — TODO confirm.
        if result.ndim >= 3 and result.shape[1] == image_obj.shape[0] and result.shape[2] == image_obj.shape[1]:
            return result
        return False

    @staticmethod
    def search_seg_result(image_id, ori_shape):
        """Look up a cached segmentation mask for *image_id* in Milvus.

        The Milvus lookup is currently disabled, so this always reports a
        cache miss by returning ``False``.
        """
        # Disabled Milvus lookup, kept for reference:
        # connections.connect(alias=MILVUS_ALIAS, host=MILVUS_DB_HOST, port=MILVUS_PORT)
        # collection = Collection(MILVUS_TABLE_SEG)
        # collection.load()
        # res = collection.query(expr=f"seg_id == {image_id}", offset=0, limit=10, output_fields=["seg_cache"])
        # if len(res):
        #     vector = np.reshape(res[0]['seg_cache'] + res[1]['seg_cache'], (224, 224))
        #     array_2d_exact = F.interpolate(torch.tensor(vector).unsqueeze(0).unsqueeze(0), size=ori_shape,
        #                                    mode='bilinear', align_corners=False)
        #     return array_2d_exact.squeeze().numpy()
        return False

    def keypoint_cache(self, sketch):
        """Return keypoints for *sketch*, inferring them on a cache miss.

        The Milvus lookup is currently disabled (``res`` is always empty), so
        every call runs ``infer_keypoint_result`` and converts the result
        with ``save_keypoint_cache``.

        :return: dict mapping each name in ``KEYPOINT_RESULT_TABLE_FIELD_SET``
            to a 2-element int list, or ``False`` if anything fails.
        """
        try:
            start_time = time.time()
            # Disabled Milvus query:
            # collection = Collection(MILVUS_TABLE_KEYPOINT); collection.load()
            # res = collection.query(expr=f"keypoint_id == {sketch['image_id']}", offset=0, limit=1,
            #                        output_fields=["keypoint_cache", "keypoint_site"])
            res = []
            logging.info(f"search keypoint time : {time.time() - start_time}")
            if not res:
                # Cache miss: infer directly and store.
                keypoint_infer_result = self.infer_keypoint_result(sketch)
                return self.save_keypoint_cache(sketch, keypoint_infer_result)
            # NOTE(review): the disabled query requests 'keypoint_cache' but the
            # branches below read 'keypoint_vector'; reconcile before re-enabling.
            cached_site = res[0]["keypoint_site"]
            if cached_site == "all" or cached_site == sketch['site']:
                # Cached entry already covers the requested site.
                return self._to_keypoint_dict(res[0]['keypoint_vector'])
            # Cached entry is for the other site: infer this site and merge to 'all'.
            keypoint_infer_result = self.infer_keypoint_result(sketch)
            return self.update_keypoint_cache(sketch, keypoint_infer_result, res[0]['keypoint_vector'])
        except Exception as e:
            # Downstream code degrades gracefully on False; log loudly with traceback.
            logging.exception(f"search keypoint cache milvus error {e}")
            return False

    def infer_keypoint_result(self, sketch):
        """Run keypoint inference on ``sketch['obj']`` for ``sketch['site']``."""
        return get_keypoint_result(sketch["obj"], sketch['site'])

    @staticmethod
    def _to_keypoint_dict(flat_keypoints):
        """Map a flat 24-value keypoint vector to ``{name: [a, b]}`` with int values.

        NOTE(review): relies on KEYPOINT_RESULT_TABLE_FIELD_SET being an
        ordered sequence of 12 names; if it is a ``set``, the pairing of names
        to points is arbitrary — confirm in app.core.config.
        """
        pairs = np.asarray(flat_keypoints).reshape(12, 2).astype(int).tolist()
        return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, pairs))

    @staticmethod
    def save_keypoint_cache(sketch, keypoint_infer_result):
        """Lay inferred keypoints out in the shared 24-value vector and return them as a dict.

        For site 'down' the inferred values fill the tail (20 leading zeros).
        For any other site they fill the head (4 trailing zeros). Persisting
        the vector to Milvus is currently disabled.
        """
        flat = np.asarray(keypoint_infer_result).flatten()
        if sketch['site'] == "down":
            result = np.concatenate([np.zeros(20, dtype=int), flat])
        else:
            result = np.concatenate([flat, np.zeros(4, dtype=int)])
        # Disabled Milvus insert:
        # Collection(MILVUS_TABLE_KEYPOINT).insert([[int(sketch['image_id'])], [sketch['site']], [result.tolist()]])
        return DesignPreprocessing._to_keypoint_dict(result)

    @staticmethod
    def update_keypoint_cache(sketch, infer_result, search_result):
        """Merge freshly inferred keypoints with a cached vector for the other site.

        Uses the same 24-value layout as ``save_keypoint_cache``: 'up' values
        occupy the first 20 entries and 'down' values the last 4. Upserting
        the merged ('all') vector to Milvus is currently disabled.
        """
        cached = np.asarray(search_result)
        inferred = np.asarray(infer_result).flatten()
        if sketch['site'] == "up":
            # Inferred 'up' values + cached 'down' tail.
            result = np.concatenate([inferred, cached[-4:]])
        else:
            # Cached 'up' head + inferred 'down' values.
            result = np.concatenate([cached[:20], inferred])
        # Disabled Milvus upsert:
        # Collection(MILVUS_TABLE_KEYPOINT).upsert([[int(sketch['image_id'])], ["all"], [result.tolist()]])
        return DesignPreprocessing._to_keypoint_dict(result)