feat
fix design 恢复seg模型推理,并保存本地
This commit is contained in:
@@ -5,11 +5,12 @@ import cv2
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import tritonclient.grpc as grpcclient
|
import tritonclient.grpc as grpcclient
|
||||||
|
from pymilvus import MilvusClient
|
||||||
from urllib3.exceptions import ResponseError
|
from urllib3.exceptions import ResponseError
|
||||||
|
|
||||||
from app.core.config import *
|
from app.core.config import *
|
||||||
from app.schemas.pre_processing import DesignPreProcessingModel
|
from app.schemas.pre_processing import DesignPreProcessingModel
|
||||||
from app.service.design.utils.design_ensemble import get_keypoint_result
|
from app.service.design.utils.design_ensemble import get_keypoint_result, get_seg_result
|
||||||
from app.service.utils.oss_client import oss_get_image, oss_upload_image
|
from app.service.utils.oss_client import oss_get_image, oss_upload_image
|
||||||
|
|
||||||
|
|
||||||
@@ -124,9 +125,9 @@ class DesignPreprocessing:
|
|||||||
bucket_name = item['image_url'].split("/", 1)[0]
|
bucket_name = item['image_url'].split("/", 1)[0]
|
||||||
object_name = item['image_url'].split("/", 1)[1]
|
object_name = item['image_url'].split("/", 1)[1]
|
||||||
oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes)
|
oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes)
|
||||||
print(f"Object '{item['image_url'].split('/', 1)[1]}' overwritten successfully.")
|
logging.info(f"Object '{item['image_url'].split('/', 1)[1]}' overwritten successfully.")
|
||||||
except ResponseError as err:
|
except ResponseError as err:
|
||||||
print(f"Error: {err}")
|
logging.warning(f"Error: {err}")
|
||||||
return image_list
|
return image_list
|
||||||
|
|
||||||
# @ RunTime
|
# @ RunTime
|
||||||
@@ -138,6 +139,12 @@ class DesignPreprocessing:
|
|||||||
sketch['site'] = 'up' if image_category in ['blouse', 'outwear', 'dress', 'tops'] else 'down'
|
sketch['site'] = 'up' if image_category in ['blouse', 'outwear', 'dress', 'tops'] else 'down'
|
||||||
# 推理得到keypoint
|
# 推理得到keypoint
|
||||||
sketch['keypoint_result'] = self.keypoint_cache(sketch)
|
sketch['keypoint_result'] = self.keypoint_cache(sketch)
|
||||||
|
if sketch['site'] == 'up':
|
||||||
|
_, seg_cache = self.load_seg_result(sketch['image_id'])
|
||||||
|
if not _:
|
||||||
|
# 推理获得seg 结果
|
||||||
|
seg_result = get_seg_result(sketch["image_id"], sketch['image_obj'])[0]
|
||||||
|
self.save_seg_result(seg_result, sketch['image_id'])
|
||||||
|
|
||||||
if IF_DEBUG_SHOW:
|
if IF_DEBUG_SHOW:
|
||||||
debug_show_image = sketch['obj'].copy()
|
debug_show_image = sketch['obj'].copy()
|
||||||
@@ -236,58 +243,37 @@ class DesignPreprocessing:
|
|||||||
return image_list
|
return image_list
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def select_seg_result(image_id, image_obj):
|
def load_seg_result(image_id):
|
||||||
|
file_path = f"{SEG_CACHE_PATH}{image_id}.npy"
|
||||||
try:
|
try:
|
||||||
# 如果shape不匹配 返回false
|
seg_result = np.load(file_path)
|
||||||
result = np.load(f"seg_result/{image_id}.npy").astype(np.int64)
|
return True, seg_result
|
||||||
if result.shape[1] == image_obj.shape[0] and result.shape[2] == image_obj.shape[1]:
|
except FileNotFoundError:
|
||||||
return result
|
logging.info("文件不存在")
|
||||||
else:
|
return False, None
|
||||||
return False
|
except Exception as e:
|
||||||
except FileNotFoundError as e:
|
logging.warning(f"加载失败: {e}")
|
||||||
logging.warning(f"{image_id} Image segmentation results cache file does not exist : {e}")
|
return False, None
|
||||||
return False
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def search_seg_result(image_id, ori_shape):
|
def save_seg_result(seg_result, image_id):
|
||||||
|
file_path = f"{SEG_CACHE_PATH}{image_id}.npy"
|
||||||
try:
|
try:
|
||||||
# connections.connect(alias=MILVUS_ALIAS, host=MILVUS_DB_HOST, port=MILVUS_PORT)
|
np.save(file_path, seg_result)
|
||||||
# collection = Collection(MILVUS_TABLE_SEG) # Get an existing collection.
|
logging.info(f"保存成功,{os.path.abspath(file_path)}")
|
||||||
# collection.load()
|
|
||||||
# start_time = time.time()
|
|
||||||
# res = collection.query(
|
|
||||||
# expr=f"seg_id == {image_id}",
|
|
||||||
# offset=0,
|
|
||||||
# limit=10,
|
|
||||||
# output_fields=["seg_cache"],
|
|
||||||
# )
|
|
||||||
# logging.info(f"search seg cache time : {time.time() - start_time}")
|
|
||||||
|
|
||||||
# if len(res):
|
|
||||||
# vector = np.reshape(res[0]['seg_cache'] + res[1]['seg_cache'], (224, 224))
|
|
||||||
# array_2d_exact = F.interpolate(torch.tensor(vector).unsqueeze(0).unsqueeze(0), size=ori_shape, mode='bilinear', align_corners=False)
|
|
||||||
# array_2d_exact = array_2d_exact.squeeze().numpy()
|
|
||||||
# return array_2d_exact
|
|
||||||
# else:
|
|
||||||
return False
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.warning(f"{image_id} Image segmentation results cache file does not exist : {e}")
|
logging.warning(f"保存失败: {e}")
|
||||||
return False
|
|
||||||
|
|
||||||
def keypoint_cache(self, sketch):
|
def keypoint_cache(self, sketch):
|
||||||
try:
|
try:
|
||||||
# connections.connect(alias=MILVUS_ALIAS, host=MILVUS_DB_HOST, port=MILVUS_PORT)
|
client = MilvusClient(uri=MILVUS_URL, token=MILVUS_TOKEN, db_name=MILVUS_ALIAS)
|
||||||
# collection = Collection(MILVUS_TABLE_KEYPOINT) # Get an existing collection.
|
keypoint_id = sketch['image_id']
|
||||||
# collection.load()
|
res = client.query(
|
||||||
start_time = time.time()
|
collection_name=MILVUS_TABLE_KEYPOINT,
|
||||||
# res = collection.query(
|
# ids=[keypoint_id],
|
||||||
# expr=f"keypoint_id == {sketch['image_id']}",
|
filter=f"keypoint_id == {keypoint_id}",
|
||||||
# offset=0,
|
output_fields=['keypoint_vector', 'keypoint_site']
|
||||||
# limit=1,
|
)
|
||||||
# output_fields=["keypoint_cache", "keypoint_site"],
|
|
||||||
# )
|
|
||||||
res = []
|
|
||||||
logging.info(f"search keypoint time : {time.time() - start_time}")
|
|
||||||
if len(res) == 0:
|
if len(res) == 0:
|
||||||
# 没有结果 直接推理拿结果 并保存
|
# 没有结果 直接推理拿结果 并保存
|
||||||
keypoint_infer_result = self.infer_keypoint_result(sketch)
|
keypoint_infer_result = self.infer_keypoint_result(sketch)
|
||||||
|
|||||||
Reference in New Issue
Block a user