fix  design 恢复seg模型推理,并保存本地
This commit is contained in:
zhouchengrong
2024-07-24 15:20:18 +08:00
parent dbd75a0d84
commit b7e9131cf7

View File

@@ -5,11 +5,12 @@ import cv2
import numpy as np import numpy as np
import torch import torch
import tritonclient.grpc as grpcclient import tritonclient.grpc as grpcclient
from pymilvus import MilvusClient
from urllib3.exceptions import ResponseError from urllib3.exceptions import ResponseError
from app.core.config import * from app.core.config import *
from app.schemas.pre_processing import DesignPreProcessingModel from app.schemas.pre_processing import DesignPreProcessingModel
from app.service.design.utils.design_ensemble import get_keypoint_result from app.service.design.utils.design_ensemble import get_keypoint_result, get_seg_result
from app.service.utils.oss_client import oss_get_image, oss_upload_image from app.service.utils.oss_client import oss_get_image, oss_upload_image
@@ -124,9 +125,9 @@ class DesignPreprocessing:
bucket_name = item['image_url'].split("/", 1)[0] bucket_name = item['image_url'].split("/", 1)[0]
object_name = item['image_url'].split("/", 1)[1] object_name = item['image_url'].split("/", 1)[1]
oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes) oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes)
print(f"Object '{item['image_url'].split('/', 1)[1]}' overwritten successfully.") logging.info(f"Object '{item['image_url'].split('/', 1)[1]}' overwritten successfully.")
except ResponseError as err: except ResponseError as err:
print(f"Error: {err}") logging.warning(f"Error: {err}")
return image_list return image_list
# @ RunTime # @ RunTime
@@ -138,6 +139,12 @@ class DesignPreprocessing:
sketch['site'] = 'up' if image_category in ['blouse', 'outwear', 'dress', 'tops'] else 'down' sketch['site'] = 'up' if image_category in ['blouse', 'outwear', 'dress', 'tops'] else 'down'
# 推理得到keypoint # 推理得到keypoint
sketch['keypoint_result'] = self.keypoint_cache(sketch) sketch['keypoint_result'] = self.keypoint_cache(sketch)
if sketch['site'] == 'up':
_, seg_cache = self.load_seg_result(sketch['image_id'])
if not _:
# 推理获得seg 结果
seg_result = get_seg_result(sketch["image_id"], sketch['image_obj'])[0]
self.save_seg_result(seg_result, sketch['image_id'])
if IF_DEBUG_SHOW: if IF_DEBUG_SHOW:
debug_show_image = sketch['obj'].copy() debug_show_image = sketch['obj'].copy()
@@ -236,58 +243,37 @@ class DesignPreprocessing:
return image_list return image_list
@staticmethod @staticmethod
def select_seg_result(image_id, image_obj): def load_seg_result(image_id):
file_path = f"{SEG_CACHE_PATH}{image_id}.npy"
try: try:
# 如果shape不匹配 返回false seg_result = np.load(file_path)
result = np.load(f"seg_result/{image_id}.npy").astype(np.int64) return True, seg_result
if result.shape[1] == image_obj.shape[0] and result.shape[2] == image_obj.shape[1]: except FileNotFoundError:
return result logging.info("文件不存在")
else: return False, None
return False except Exception as e:
except FileNotFoundError as e: logging.warning(f"加载失败: {e}")
logging.warning(f"{image_id} Image segmentation results cache file does not exist : {e}") return False, None
return False
@staticmethod @staticmethod
def search_seg_result(image_id, ori_shape): def save_seg_result(seg_result, image_id):
file_path = f"{SEG_CACHE_PATH}{image_id}.npy"
try: try:
# connections.connect(alias=MILVUS_ALIAS, host=MILVUS_DB_HOST, port=MILVUS_PORT) np.save(file_path, seg_result)
# collection = Collection(MILVUS_TABLE_SEG) # Get an existing collection. logging.info(f"保存成功,{os.path.abspath(file_path)}")
# collection.load()
# start_time = time.time()
# res = collection.query(
# expr=f"seg_id == {image_id}",
# offset=0,
# limit=10,
# output_fields=["seg_cache"],
# )
# logging.info(f"search seg cache time {time.time() - start_time}")
# if len(res):
# vector = np.reshape(res[0]['seg_cache'] + res[1]['seg_cache'], (224, 224))
# array_2d_exact = F.interpolate(torch.tensor(vector).unsqueeze(0).unsqueeze(0), size=ori_shape, mode='bilinear', align_corners=False)
# array_2d_exact = array_2d_exact.squeeze().numpy()
# return array_2d_exact
# else:
return False
except Exception as e: except Exception as e:
logging.warning(f"{image_id} Image segmentation results cache file does not exist : {e}") logging.warning(f"保存失败: {e}")
return False
def keypoint_cache(self, sketch): def keypoint_cache(self, sketch):
try: try:
# connections.connect(alias=MILVUS_ALIAS, host=MILVUS_DB_HOST, port=MILVUS_PORT) client = MilvusClient(uri=MILVUS_URL, token=MILVUS_TOKEN, db_name=MILVUS_ALIAS)
# collection = Collection(MILVUS_TABLE_KEYPOINT) # Get an existing collection. keypoint_id = sketch['image_id']
# collection.load() res = client.query(
start_time = time.time() collection_name=MILVUS_TABLE_KEYPOINT,
# res = collection.query( # ids=[keypoint_id],
# expr=f"keypoint_id == {sketch['image_id']}", filter=f"keypoint_id == {keypoint_id}",
# offset=0, output_fields=['keypoint_vector', 'keypoint_site']
# limit=1, )
# output_fields=["keypoint_cache", "keypoint_site"],
# )
res = []
logging.info(f"search keypoint time : {time.time() - start_time}")
if len(res) == 0: if len(res) == 0:
# 没有结果 直接推理拿结果 并保存 # 没有结果 直接推理拿结果 并保存
keypoint_infer_result = self.infer_keypoint_result(sketch) keypoint_infer_result = self.infer_keypoint_result(sketch)