45 lines
2.2 KiB
Python
45 lines
2.2 KiB
Python
import logging
|
|
import time
|
|
|
|
from fastapi import APIRouter
|
|
from app.schemas.outfit_matcher import OutfitMatcher
|
|
from app.service.outfit_matcher.dataset import FashionDataset
|
|
from app.service.outfit_matcher.outfit_evaluator import OutfitMaterTypeAware
|
|
from app.service.utils.decorator import RunTime
|
|
|
|
# Module-scoped logger (instead of the root logger) so records carry this
# module's name; they still propagate to the root logger's handlers.
logger = logging.getLogger(__name__)

# Router for this module's endpoints; presumably included by the application
# elsewhere (not visible here).
router = APIRouter()


# NOTE(review): decorators apply bottom-up, so ``router.post`` registers the
# undecorated function and ``RunTime`` only wraps the module-level name; it
# does not run for HTTP requests. Swapping the two lines is only safe if
# ``RunTime`` preserves the wrapped signature (functools.wraps) so FastAPI can
# still introspect ``request_item`` -- confirm before changing.
@RunTime
@router.post("/outfit_matcher")
def outfit_matcher(request_item: OutfitMatcher):
    """Score generated outfits for each query item and return the top or bottom ones.

    For every entry in ``request_item.query``, outfits are generated from
    ``request_item.database`` with ``FashionDataset.generate_outfit``, scored
    with ``OutfitMaterTypeAware.get_result``, and selected with
    ``OutfitMaterTypeAware.visualize`` (``topk`` entries, best or worst
    depending on ``is_best``).

    Args:
        request_item: Validated request body. Fields read here: ``query``,
            ``database``, ``topk``, ``max_outfits`` and ``is_best``.

    Returns:
        On success ``{"message": "ok", "data": [...]}``, with one entry per
        query item. Each entry is ``{"best_outfits": ..., "best_scores": ...}``
        when ``is_best`` is truthy, otherwise
        ``{"bad_outfits": ..., "bad_scores": ...}``.
        On failure ``{"message": "<error text>", "data": None}``; the error is
        reported in the body rather than raised.
    """
    request = dict(request_item)
    # Shallow-convert the nested item models to plain dicts. New lists are
    # built so the validated request model is not mutated in place.
    queries = [dict(entry) for entry in request["query"]]
    database = [dict(entry) for entry in request["database"]]

    try:
        fashion_dataset = FashionDataset(database)
        service = OutfitMaterTypeAware()

        # These values are request-wide, so read them once rather than per item.
        topk = request["topk"]
        max_outfits = request["max_outfits"]
        want_best = bool(request["is_best"])
        if want_best:
            outfits_key, scores_key = "best_outfits", "best_scores"
        else:
            outfits_key, scores_key = "bad_outfits", "bad_scores"

        result = []
        start_time = time.perf_counter()  # monotonic clock for elapsed time
        for item in queries:
            outfits = fashion_dataset.generate_outfit(item, topk, max_outfits)
            scores = service.get_result(outfits)
            chosen_outfits, chosen_scores = service.visualize(
                outfits, scores, topk, best=want_best
            )
            result.append({outfits_key: chosen_outfits, scores_key: chosen_scores})
        logger.info("run time is : %s", time.perf_counter() - start_time)
        return {"message": "ok", "data": result}
    except Exception as e:
        # Top-level request boundary: keep the traceback, which was previously
        # lost. Return a JSON-safe payload instead of the raw exception object,
        # which may not be serializable.
        logger.exception("outfit_matcher failed")
        return {"message": f"{e}", "data": None}