fix  design 恢复seg模型推理,并保存本地
This commit is contained in:
zhouchengrong
2024-07-24 15:20:18 +08:00
parent dbd75a0d84
commit b7e9131cf7

View File

@@ -5,11 +5,12 @@ import cv2
import numpy as np
import torch
import tritonclient.grpc as grpcclient
from pymilvus import MilvusClient
from urllib3.exceptions import ResponseError
from app.core.config import *
from app.schemas.pre_processing import DesignPreProcessingModel
from app.service.design.utils.design_ensemble import get_keypoint_result
from app.service.design.utils.design_ensemble import get_keypoint_result, get_seg_result
from app.service.utils.oss_client import oss_get_image, oss_upload_image
@@ -124,9 +125,9 @@ class DesignPreprocessing:
bucket_name = item['image_url'].split("/", 1)[0]
object_name = item['image_url'].split("/", 1)[1]
oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes)
print(f"Object '{item['image_url'].split('/', 1)[1]}' overwritten successfully.")
logging.info(f"Object '{item['image_url'].split('/', 1)[1]}' overwritten successfully.")
except ResponseError as err:
print(f"Error: {err}")
logging.warning(f"Error: {err}")
return image_list
# @ RunTime
@@ -138,6 +139,12 @@ class DesignPreprocessing:
sketch['site'] = 'up' if image_category in ['blouse', 'outwear', 'dress', 'tops'] else 'down'
# 推理得到keypoint
sketch['keypoint_result'] = self.keypoint_cache(sketch)
if sketch['site'] == 'up':
_, seg_cache = self.load_seg_result(sketch['image_id'])
if not _:
# 推理获得seg 结果
seg_result = get_seg_result(sketch["image_id"], sketch['image_obj'])[0]
self.save_seg_result(seg_result, sketch['image_id'])
if IF_DEBUG_SHOW:
debug_show_image = sketch['obj'].copy()
@@ -236,58 +243,37 @@ class DesignPreprocessing:
return image_list
@staticmethod
def select_seg_result(image_id, image_obj):
def load_seg_result(image_id):
file_path = f"{SEG_CACHE_PATH}{image_id}.npy"
try:
# 如果shape不匹配 返回false
result = np.load(f"seg_result/{image_id}.npy").astype(np.int64)
if result.shape[1] == image_obj.shape[0] and result.shape[2] == image_obj.shape[1]:
return result
else:
return False
except FileNotFoundError as e:
logging.warning(f"{image_id} Image segmentation results cache file does not exist : {e}")
return False
seg_result = np.load(file_path)
return True, seg_result
except FileNotFoundError:
logging.info("文件不存在")
return False, None
except Exception as e:
logging.warning(f"加载失败: {e}")
return False, None
@staticmethod
def search_seg_result(image_id, ori_shape):
def save_seg_result(seg_result, image_id):
file_path = f"{SEG_CACHE_PATH}{image_id}.npy"
try:
# connections.connect(alias=MILVUS_ALIAS, host=MILVUS_DB_HOST, port=MILVUS_PORT)
# collection = Collection(MILVUS_TABLE_SEG) # Get an existing collection.
# collection.load()
# start_time = time.time()
# res = collection.query(
# expr=f"seg_id == {image_id}",
# offset=0,
# limit=10,
# output_fields=["seg_cache"],
# )
# logging.info(f"search seg cache time {time.time() - start_time}")
# if len(res):
# vector = np.reshape(res[0]['seg_cache'] + res[1]['seg_cache'], (224, 224))
# array_2d_exact = F.interpolate(torch.tensor(vector).unsqueeze(0).unsqueeze(0), size=ori_shape, mode='bilinear', align_corners=False)
# array_2d_exact = array_2d_exact.squeeze().numpy()
# return array_2d_exact
# else:
return False
np.save(file_path, seg_result)
logging.info(f"保存成功,{os.path.abspath(file_path)}")
except Exception as e:
logging.warning(f"{image_id} Image segmentation results cache file does not exist : {e}")
return False
logging.warning(f"保存失败: {e}")
def keypoint_cache(self, sketch):
try:
# connections.connect(alias=MILVUS_ALIAS, host=MILVUS_DB_HOST, port=MILVUS_PORT)
# collection = Collection(MILVUS_TABLE_KEYPOINT) # Get an existing collection.
# collection.load()
start_time = time.time()
# res = collection.query(
# expr=f"keypoint_id == {sketch['image_id']}",
# offset=0,
# limit=1,
# output_fields=["keypoint_cache", "keypoint_site"],
# )
res = []
logging.info(f"search keypoint time : {time.time() - start_time}")
client = MilvusClient(uri=MILVUS_URL, token=MILVUS_TOKEN, db_name=MILVUS_ALIAS)
keypoint_id = sketch['image_id']
res = client.query(
collection_name=MILVUS_TABLE_KEYPOINT,
# ids=[keypoint_id],
filter=f"keypoint_id == {keypoint_id}",
output_fields=['keypoint_vector', 'keypoint_site']
)
if len(res) == 0:
# 没有结果 直接推理拿结果 并保存
keypoint_infer_result = self.infer_keypoint_result(sketch)