feat design 预处理接口迁移
This commit is contained in:
29
app/api/api_design_pre_processing.py
Normal file
29
app/api/api_design_pre_processing.py
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
import logging
|
||||||
|
import time
|
||||||
|
|
||||||
|
from fastapi import APIRouter
|
||||||
|
|
||||||
|
from app.schemas.pre_processing import DesignPreProcessingModel
|
||||||
|
from app.service.design_pre_processing.service import DesignPreprocessing
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
logger = logging.getLogger()
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/design_pre_processing")
def design_pre_processing(request_data: DesignPreProcessingModel):
    """Run the design pre-processing pipeline on the submitted sketches.

    Downloads, crops, optionally super-resolves and keypoint-annotates every
    sketch in ``request_data.sketches`` via ``DesignPreprocessing.pipeline``.

    Returns a ``{"code", "message", "data"}`` envelope: 200/"access" with the
    processed sketch list on success, 400 with the exception text on failure
    (the endpoint never raises — errors are reported in-band).
    """
    logger.info(f"design_pre_processing request item is : @@@@@@:{request_data}")
    try:
        start_time = time.time()
        server = DesignPreprocessing()
        data = server.pipeline(image_list=request_data.sketches)
        logger.info(f"design_pre_processing Run time is @@@@@@:{time.time() - start_time}")
        # Only mark success once the pipeline actually completed.
        code = 200
        message = "access"
    except Exception as e:
        # Best-effort endpoint: surface the failure in the response body
        # instead of letting FastAPI return a 500.
        code = 400
        message = str(e)
        data = str(e)
        logger.warning(f"design Run Exception @@@@@@:{e}")
    logger.info({"code": code, "message": message, "data": data})
    return {"code": code, "message": message, "data": data}
|
||||||
@@ -7,6 +7,7 @@ from app.api import api_attribute_retrieve
|
|||||||
from app.api import api_design
|
from app.api import api_design
|
||||||
from app.api import api_chat_robot
|
from app.api import api_chat_robot
|
||||||
from app.api import api_prompt_generation
|
from app.api import api_prompt_generation
|
||||||
|
from app.api import api_design_pre_processing
|
||||||
|
|
||||||
|
|
||||||
# Top-level API router: each feature module mounts its sub-router under /api.
router = APIRouter()
router.include_router(api_design.router, tags=['design'], prefix="/api")
router.include_router(api_chat_robot.router, tags=['chat_robot'], prefix="/api")
router.include_router(api_prompt_generation.router, tags=['prompt_generation'], prefix="/api")
# New in this change: the design pre-processing endpoint.
router.include_router(api_design_pre_processing.router, tags=['design_pre_processing'], prefix="/api")
|
||||||
|
|||||||
@@ -118,6 +118,9 @@ AIDA_CLOTHING = "aida-clothing"
|
|||||||
# Field names for the 12 garment keypoints, in the order the flattened
# 24-value keypoint vector is stored (each field maps to one (y, x) pair;
# see DesignPreprocessing.save_keypoint_cache / update_keypoint_cache).
KEYPOINT_RESULT_TABLE_FIELD_SET = ('neckline_left', 'neckline_right', 'shoulder_left', 'shoulder_right', 'armpit_left', 'armpit_right',
                                   'cuff_left_in', 'cuff_left_out', 'cuff_right_in', 'cuff_right_out', 'waistband_left', 'waistband_right')


# DESIGN pre-processing.
# When True, intermediate images are displayed with cv2.imshow for local debugging.
IF_DEBUG_SHOW = False
||||||
# 优先级
|
# 优先级
|
||||||
PRIORITY_DICT = {
|
PRIORITY_DICT = {
|
||||||
'earring_front': 99,
|
'earring_front': 99,
|
||||||
|
|||||||
5
app/schemas/pre_processing.py
Normal file
5
app/schemas/pre_processing.py
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class DesignPreProcessingModel(BaseModel):
    """Request payload for the /design_pre_processing endpoint."""

    # One dict per sketch; the service layer reads the 'image_url',
    # 'image_id' and 'image_category' keys from each entry — confirm
    # the full schema against callers.
    sketches: list[dict]
|
||||||
@@ -34,8 +34,8 @@ class KeypointDetection(object):
|
|||||||
site = 'up' if result['name'] in ['blouse', 'outwear', 'dress', 'tops'] else 'down'
|
site = 'up' if result['name'] in ['blouse', 'outwear', 'dress', 'tops'] else 'down'
|
||||||
# keypoint_cache = search_keypoint_cache(result["image_id"], site)
|
# keypoint_cache = search_keypoint_cache(result["image_id"], site)
|
||||||
|
|
||||||
# 取消向量查询 直接过模型推理
|
|
||||||
keypoint_cache = self.keypoint_cache(result, site)
|
keypoint_cache = self.keypoint_cache(result, site)
|
||||||
|
# 取消向量查询 直接过模型推理
|
||||||
# keypoint_cache = False
|
# keypoint_cache = False
|
||||||
|
|
||||||
if keypoint_cache is False:
|
if keypoint_cache is False:
|
||||||
|
|||||||
@@ -37,7 +37,7 @@ def get_keypoint_result(image, site):
|
|||||||
keypoint_result = None
|
keypoint_result = None
|
||||||
try:
|
try:
|
||||||
image, scale_factor = keypoint_preprocess(image)
|
image, scale_factor = keypoint_preprocess(image)
|
||||||
client = httpclient.InferenceServerClient(url=KEYPOINT_MODEL_URL)
|
client = httpclient.InferenceServerClient(url=DESIGN_MODEL_URL)
|
||||||
transformed_img = image.astype(np.float32)
|
transformed_img = image.astype(np.float32)
|
||||||
inputs = [httpclient.InferInput(f"input", transformed_img.shape, datatype="FP32")]
|
inputs = [httpclient.InferInput(f"input", transformed_img.shape, datatype="FP32")]
|
||||||
inputs[0].set_data_from_numpy(transformed_img, binary_data=True)
|
inputs[0].set_data_from_numpy(transformed_img, binary_data=True)
|
||||||
|
|||||||
320
app/service/design_pre_processing/service.py
Normal file
320
app/service/design_pre_processing/service.py
Normal file
@@ -0,0 +1,320 @@
|
|||||||
|
import logging
|
||||||
|
import time
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from minio import Minio
|
||||||
|
from pymilvus import connections, Collection
|
||||||
|
from urllib3.exceptions import ResponseError
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import tritonclient.grpc as grpcclient
|
||||||
|
import io
|
||||||
|
|
||||||
|
from app.core.config import *
|
||||||
|
from app.service.design.utils.design_ensemble import get_keypoint_result
|
||||||
|
|
||||||
|
|
||||||
|
class DesignPreprocessing:
    """Pre-processes design sketches before garment model inference.

    Pipeline (see pipeline()): download from MinIO, crop to the content
    bounding box, super-resolve small crops, infer garment keypoints, then
    pad and upload a '-show' preview image.
    """

    def __init__(self):
        # MinIO client used both to fetch the source sketches and to write
        # the processed variants back; settings come from app.core.config.
        self.minio_client = Minio(
            MINIO_URL,
            access_key=MINIO_ACCESS,
            secret_key=MINIO_SECRET,
            secure=MINIO_SECURE)
|
||||||
|
|
||||||
|
# @ RunTime
def pipeline(self, image_list):
    """Run the full pre-processing chain over *image_list*.

    Each stage mutates and extends the per-sketch dicts in turn; transient
    entries (decoded images, keypoints) are stripped before the list is
    handed back to the API layer.
    """
    stage_output = self.read_image(image_list)
    logging.info("read image success")

    stage_output = self.bounding_box(stage_output)
    logging.info("bounding box image success")

    stage_output = self.super_resolution(stage_output)
    logging.info("super_resolution_list image success")

    stage_output = self.infer_image(stage_output)
    logging.info("infer image success")

    result = self.composing_image(stage_output)
    logging.info("Replenish white edge image success")

    # Drop heavy / non-serialisable intermediates before returning.
    for entry in result:
        for transient_key in ('image_obj', 'obj', 'keypoint_result'):
            entry.pop(transient_key, None)
    return result
|
||||||
|
|
||||||
|
def read_image(self, image_list):
    """Download every sketch from MinIO and attach it as 'image_obj' (3-channel ndarray)."""
    for entry in image_list:
        # 'image_url' is '<bucket>/<object path>'.
        bucket, object_name = entry['image_url'].split("/", 1)
        raw_bytes = self.minio_client.get_object(bucket, object_name).data
        decoded = cv2.imdecode(np.frombuffer(raw_bytes, np.uint8), 1)
        if len(decoded.shape) == 2:
            # Grayscale: expand to three channels.
            decoded = cv2.cvtColor(decoded, cv2.COLOR_GRAY2RGB)
        elif decoded.shape[2] == 4:  # four-channel mask: drop the alpha plane
            decoded = decoded[:, :, :3]
        entry["image_obj"] = decoded
    return image_list
|
||||||
|
|
||||||
|
# @ RunTime
def bounding_box(self, image_list):
    """Crop each sketch to the bounding box of its detected edges.

    Adds 'obj' to every entry: the cropped image when contours are found,
    otherwise the original image unchanged.

    Fix: all bounding-box work (including the debug preview) is now guarded
    by the contour check — previously, when no contours were found, the
    debug branch passed the unset coordinates (inf, inf, -1, -1) straight
    into cv2.rectangle.
    """
    for item in image_list:
        image = item['image_obj']
        # Canny edges, then the external contours of the sketch content.
        edges = cv2.Canny(image, 50, 150)
        contours, _ = cv2.findContours(edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

        if len(contours) > 0:
            # Union of all contour bounding rectangles.
            lefts, tops, rights, bottoms = [], [], [], []
            for contour in contours:
                x, y, w, h = cv2.boundingRect(contour)
                lefts.append(x)
                tops.append(y)
                rights.append(x + w)
                bottoms.append(y + h)
            x_min, y_min = min(lefts), min(tops)
            x_max, y_max = max(rights), max(bottoms)

            if IF_DEBUG_SHOW:
                image_with_big_rect = cv2.rectangle(image.copy(), (x_min, y_min), (x_max, y_max), (0, 255, 0), 2)
                cv2.imshow("bounding_box image", image_with_big_rect)
                cv2.waitKey(0)

            # Crop the original image to the union box.
            item['obj'] = image[y_min:y_max, x_min:x_max]
        else:
            # No edges detected: keep the original image untouched.
            item['obj'] = image
    return image_list
|
||||||
|
|
||||||
|
def super_resolution(self, image_list):
    """Super-resolve small crops via a Triton-served SR model (4x).

    Only crops that fit within 512x512 AND have at least one side <= 256
    are upscaled; the upscaled crop then overwrites the original object in
    MinIO. NOTE(review): statement nesting reconstructed from a diff dump —
    confirm the MinIO overwrite is gated on the super-resolution branch.
    """
    for item in image_list:
        # Both sides must fit in 512 because the model performs 4x upscaling.
        if item['obj'].shape[0] <= 512 and item['obj'].shape[1] <= 512:
            # Upscale only when either side is <= 256.
            if item['obj'].shape[0] <= 256 or item['obj'].shape[1] <= 256:
                # Normalise to [0, 1] float32; BGR -> RGB; HWC -> CHW.
                img = item['obj'].astype(np.float32) / 255.
                sample = np.transpose(img if img.shape[2] == 1 else img[:, :, [2, 1, 0]], (2, 0, 1))
                # Add the batch dimension the model expects.
                sample = torch.from_numpy(sample).float().unsqueeze(0).numpy()
                inputs = [
                    grpcclient.InferInput("input", sample.shape, datatype="FP32")
                ]
                inputs[0].set_data_from_numpy(sample)
                triton_client = grpcclient.InferenceServerClient(url=SR_TRITON_URL)
                result = triton_client.infer(model_name=SR_MODEL_NAME, inputs=inputs)
                result_image = result.as_numpy(f'output')[0]
                sr_output = torch.tensor(result_image)
                # Clamp to [0, 1] before converting back to uint8.
                output = sr_output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
                if output.ndim == 3:
                    output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))  # CHW-RGB to HCW-BGR
                output = (output * 255.0).round().astype(np.uint8)
                item['obj'] = output
                try:
                    # Overwrite the source object in MinIO with the upscaled crop.
                    image_bytes = cv2.imencode(".jpg", item['obj'])[1].tobytes()
                    self.minio_client.put_object(item['image_url'].split("/", 1)[0], item['image_url'].split("/", 1)[1], io.BytesIO(image_bytes), len(image_bytes), content_type="image/jpeg", )
                    print(f"Object '{item['image_url'].split('/', 1)[1]}' overwritten successfully.")
                except ResponseError as err:
                    print(f"Error: {err}")
    return image_list
|
||||||
|
|
||||||
|
# @ RunTime
def infer_image(self, image_list):
    """Tag each sketch as upper/lower garment and attach its keypoints.

    Sets 'site' from the (lower-cased) category and 'keypoint_result' from
    the cache-or-infer path in keypoint_cache().
    """
    upper_categories = ('blouse', 'outwear', 'dress', 'tops')
    for sketch in image_list:
        category = sketch['image_category'].lower()
        sketch['site'] = 'up' if category in upper_categories else 'down'
        # Keypoints come from the cache or fresh inference (see keypoint_cache).
        sketch['keypoint_result'] = self.keypoint_cache(sketch)

        if IF_DEBUG_SHOW:
            preview = sketch['obj'].copy()
            # Stored keypoints look like (y, x); cv2 wants (x, y).
            for pair in sketch['keypoint_result'].values():
                cv2.circle(preview, (int(pair[1]), int(pair[0])), 1, (0, 0, 255), 4)
            cv2.imshow("", preview)
            cv2.waitKey(0)
    return image_list
|
||||||
|
|
||||||
|
# @ RunTime
def composing_image(self, image_list):
    """Pad each sketch with white so the garment spans ~`scale` of the width,
    then upload a '-show' preview image back to MinIO.

    Sets 'show_image_url' on every entry. The reference width is the
    waistband span for lower garments and the armpit span for upper ones.

    Fixes: the border colour is now (255, 255, 255) — the original
    (256, 256, 256) is out of range for uint8 and only produced white via
    OpenCV's saturating cast. The duplicated up/down branches are merged.
    """
    scale = 0.4  # target ratio of reference width to final image width
    for image in image_list:
        if image['site'] == 'down':
            left_key, right_key = 'waistband_left', 'waistband_right'
        else:
            left_key, right_key = 'armpit_left', 'armpit_right'
        # Keypoints are stored as (y, x) pairs; index 1 is the x coordinate.
        ref_width = image['keypoint_result'][right_key][1] - image['keypoint_result'][left_key][1]

        image_width = image['obj'].shape[1]
        if ref_width / scale >= image_width:
            # Grow the canvas symmetrically with a white border.
            add_width = int((ref_width / scale - image_width) / 2)
            show_image = cv2.copyMakeBorder(image['obj'], 0, 0, add_width, add_width, cv2.BORDER_CONSTANT, value=(255, 255, 255))
            if IF_DEBUG_SHOW:
                cv2.imshow("composing_image", show_image)
                cv2.waitKey(0)
        else:
            # Already wide enough: upload the crop unchanged.
            show_image = image['obj']

        bucket, object_name = image['image_url'].split('/', 1)
        image_bytes = cv2.imencode(".jpg", show_image)[1].tobytes()
        uploaded = self.minio_client.put_object(bucket, object_name.replace('.', '-show.'), io.BytesIO(image_bytes), len(image_bytes), content_type='image/jpeg')
        image['show_image_url'] = f"{bucket}/{uploaded.object_name}"
    return image_list
|
||||||
|
|
||||||
|
@staticmethod
def select_seg_result(image_id, image_obj):
    """Load the cached segmentation for *image_id*; False on miss or shape mismatch."""
    try:
        cached = np.load(f"seg_result/{image_id}.npy").astype(np.int64)
    except FileNotFoundError as e:
        logging.warning(f"{image_id} Image segmentation results cache file does not exist : {e}")
        return False
    # The cache is only usable when it matches the current image's height/width.
    height_ok = cached.shape[1] == image_obj.shape[0]
    width_ok = cached.shape[2] == image_obj.shape[1]
    return cached if (height_ok and width_ok) else False
|
||||||
|
|
||||||
|
@staticmethod
def search_seg_result(image_id, ori_shape):
    """Look up a cached segmentation in Milvus for *image_id* (currently disabled).

    The Milvus query — fetch 'seg_cache' by seg_id, reshape to 224x224 and
    bilinearly resize to *ori_shape* — is switched off, so this always
    reports a cache miss (False).
    """
    try:
        # Milvus lookup disabled; treat every request as a miss.
        return False
    except Exception as e:
        logging.warning(f"{image_id} Image segmentation results cache file does not exist : {e}")
        return False
|
||||||
|
|
||||||
|
def keypoint_cache(self, sketch):
    """Return the keypoint dict for *sketch*, inferring and caching on a miss.

    The Milvus query is disabled (res is always empty), so every call
    currently falls through to fresh inference + save. Returns False when
    anything in the lookup/inference path raises.
    """
    try:
        start_time = time.time()
        res = []  # Milvus query disabled; behaves as a permanent cache miss
        logging.info(f"search keypoint time : {time.time() - start_time}")

        if not res:
            # Miss: infer and persist the result.
            inferred = self.infer_keypoint_result(sketch)
            return self.save_keypoint_cache(sketch, inferred)

        cached_site = res[0]["keypoint_site"]
        if cached_site in ("all", sketch['site']):
            # Usable hit: decode the flat 24-value vector into 12 (y, x) pairs.
            pairs = np.array(res[0]['keypoint_vector']).astype(int).reshape(12, 2).tolist()
            return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, pairs))

        # Site mismatch: infer the missing half and upgrade the cache to "all".
        inferred = self.infer_keypoint_result(sketch)
        return self.update_keypoint_cache(sketch, inferred, res[0]['keypoint_vector'])
    except Exception as e:
        logging.info(f"search keypoint cache milvus error {e}")
        return False
|
||||||
|
|
||||||
|
# @ RunTime
def infer_keypoint_result(self, sketch):
    """Run keypoint model inference on the sketch crop for its site."""
    return get_keypoint_result(sketch["obj"], sketch['site'])
|
||||||
|
|
||||||
|
@staticmethod
# @ RunTime
def save_keypoint_cache(sketch, keypoint_infer_result):
    """Pad the per-site keypoints to the full 24-value layout and persist them.

    'down' inference fills the trailing slots (leading 20 zeroed); 'up'
    inference fills the leading slots (trailing 4 zeroed). The Milvus insert
    is disabled; the decoded keypoint dict is returned either way.
    """
    flat = keypoint_infer_result.flatten()
    if sketch['site'] == "down":
        result = np.concatenate([np.zeros(20, dtype=int), flat])
    else:
        result = np.concatenate([flat, np.zeros(4, dtype=int)])

    # Row layout expected by the (currently disabled) Milvus collection.
    data = [
        [int(sketch['image_id'])],
        [sketch['site']],
        [result.tolist()]
    ]
    decoded = dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
    try:
        start_time = time.time()
        # collection.insert(data) — persistence currently switched off.
        return decoded
    except Exception as e:
        logging.info(f"save keypoint cache milvus error : {e}")
        return decoded
|
||||||
|
|
||||||
|
@staticmethod
def update_keypoint_cache(sketch, infer_result, search_result):
    """Merge freshly inferred keypoints with the cached opposite half.

    'up' inference keeps the cached trailing 4 values; 'down' inference
    keeps the cached leading 20. The Milvus upsert (site -> "all") is
    disabled; the merged keypoint dict is returned either way.
    """
    flat = infer_result.flatten()
    if sketch['site'] == "up":
        # Inferred the upper half; reuse the cached lower 4 values.
        result = np.concatenate([flat, search_result[-4:]])
    else:
        # Inferred the lower half; reuse the cached upper 20 values.
        result = np.concatenate([search_result[:20], flat])

    data = [
        [int(sketch['image_id'])],
        ["all"],
        [result.tolist()]
    ]
    merged = dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
    try:
        start_time = time.time()
        # collection.upsert(data) — persistence currently switched off.
        return merged
    except Exception as e:
        logging.info(f"save keypoint cache milvus error : {e}")
        return merged
|
||||||
Reference in New Issue
Block a user