feat
fix design 恢复seg模型推理,并保存本地
This commit is contained in:
@@ -5,11 +5,12 @@ import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
import tritonclient.grpc as grpcclient
|
||||
from pymilvus import MilvusClient
|
||||
from urllib3.exceptions import ResponseError
|
||||
|
||||
from app.core.config import *
|
||||
from app.schemas.pre_processing import DesignPreProcessingModel
|
||||
from app.service.design.utils.design_ensemble import get_keypoint_result
|
||||
from app.service.design.utils.design_ensemble import get_keypoint_result, get_seg_result
|
||||
from app.service.utils.oss_client import oss_get_image, oss_upload_image
|
||||
|
||||
|
||||
@@ -124,9 +125,9 @@ class DesignPreprocessing:
|
||||
bucket_name = item['image_url'].split("/", 1)[0]
|
||||
object_name = item['image_url'].split("/", 1)[1]
|
||||
oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes)
|
||||
print(f"Object '{item['image_url'].split('/', 1)[1]}' overwritten successfully.")
|
||||
logging.info(f"Object '{item['image_url'].split('/', 1)[1]}' overwritten successfully.")
|
||||
except ResponseError as err:
|
||||
print(f"Error: {err}")
|
||||
logging.warning(f"Error: {err}")
|
||||
return image_list
|
||||
|
||||
# @ RunTime
|
||||
@@ -138,6 +139,12 @@ class DesignPreprocessing:
|
||||
sketch['site'] = 'up' if image_category in ['blouse', 'outwear', 'dress', 'tops'] else 'down'
|
||||
# 推理得到keypoint
|
||||
sketch['keypoint_result'] = self.keypoint_cache(sketch)
|
||||
if sketch['site'] == 'up':
|
||||
_, seg_cache = self.load_seg_result(sketch['image_id'])
|
||||
if not _:
|
||||
# 推理获得seg 结果
|
||||
seg_result = get_seg_result(sketch["image_id"], sketch['image_obj'])[0]
|
||||
self.save_seg_result(seg_result, sketch['image_id'])
|
||||
|
||||
if IF_DEBUG_SHOW:
|
||||
debug_show_image = sketch['obj'].copy()
|
||||
@@ -236,58 +243,37 @@ class DesignPreprocessing:
|
||||
return image_list
|
||||
|
||||
@staticmethod
|
||||
def select_seg_result(image_id, image_obj):
|
||||
def load_seg_result(image_id):
|
||||
file_path = f"{SEG_CACHE_PATH}{image_id}.npy"
|
||||
try:
|
||||
# 如果shape不匹配 返回false
|
||||
result = np.load(f"seg_result/{image_id}.npy").astype(np.int64)
|
||||
if result.shape[1] == image_obj.shape[0] and result.shape[2] == image_obj.shape[1]:
|
||||
return result
|
||||
else:
|
||||
return False
|
||||
except FileNotFoundError as e:
|
||||
logging.warning(f"{image_id} Image segmentation results cache file does not exist : {e}")
|
||||
return False
|
||||
seg_result = np.load(file_path)
|
||||
return True, seg_result
|
||||
except FileNotFoundError:
|
||||
logging.info("文件不存在")
|
||||
return False, None
|
||||
except Exception as e:
|
||||
logging.warning(f"加载失败: {e}")
|
||||
return False, None
|
||||
|
||||
@staticmethod
|
||||
def search_seg_result(image_id, ori_shape):
|
||||
def save_seg_result(seg_result, image_id):
|
||||
file_path = f"{SEG_CACHE_PATH}{image_id}.npy"
|
||||
try:
|
||||
# connections.connect(alias=MILVUS_ALIAS, host=MILVUS_DB_HOST, port=MILVUS_PORT)
|
||||
# collection = Collection(MILVUS_TABLE_SEG) # Get an existing collection.
|
||||
# collection.load()
|
||||
# start_time = time.time()
|
||||
# res = collection.query(
|
||||
# expr=f"seg_id == {image_id}",
|
||||
# offset=0,
|
||||
# limit=10,
|
||||
# output_fields=["seg_cache"],
|
||||
# )
|
||||
# logging.info(f"search seg cache time : {time.time() - start_time}")
|
||||
|
||||
# if len(res):
|
||||
# vector = np.reshape(res[0]['seg_cache'] + res[1]['seg_cache'], (224, 224))
|
||||
# array_2d_exact = F.interpolate(torch.tensor(vector).unsqueeze(0).unsqueeze(0), size=ori_shape, mode='bilinear', align_corners=False)
|
||||
# array_2d_exact = array_2d_exact.squeeze().numpy()
|
||||
# return array_2d_exact
|
||||
# else:
|
||||
return False
|
||||
np.save(file_path, seg_result)
|
||||
logging.info(f"保存成功,{os.path.abspath(file_path)}")
|
||||
except Exception as e:
|
||||
logging.warning(f"{image_id} Image segmentation results cache file does not exist : {e}")
|
||||
return False
|
||||
logging.warning(f"保存失败: {e}")
|
||||
|
||||
def keypoint_cache(self, sketch):
|
||||
try:
|
||||
# connections.connect(alias=MILVUS_ALIAS, host=MILVUS_DB_HOST, port=MILVUS_PORT)
|
||||
# collection = Collection(MILVUS_TABLE_KEYPOINT) # Get an existing collection.
|
||||
# collection.load()
|
||||
start_time = time.time()
|
||||
# res = collection.query(
|
||||
# expr=f"keypoint_id == {sketch['image_id']}",
|
||||
# offset=0,
|
||||
# limit=1,
|
||||
# output_fields=["keypoint_cache", "keypoint_site"],
|
||||
# )
|
||||
res = []
|
||||
logging.info(f"search keypoint time : {time.time() - start_time}")
|
||||
client = MilvusClient(uri=MILVUS_URL, token=MILVUS_TOKEN, db_name=MILVUS_ALIAS)
|
||||
keypoint_id = sketch['image_id']
|
||||
res = client.query(
|
||||
collection_name=MILVUS_TABLE_KEYPOINT,
|
||||
# ids=[keypoint_id],
|
||||
filter=f"keypoint_id == {keypoint_id}",
|
||||
output_fields=['keypoint_vector', 'keypoint_site']
|
||||
)
|
||||
if len(res) == 0:
|
||||
# 没有结果 直接推理拿结果 并保存
|
||||
keypoint_infer_result = self.infer_keypoint_result(sketch)
|
||||
|
||||
Reference in New Issue
Block a user