From 7ff3a72d8c1231e6444c9db34187166a60a0dd7a Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Thu, 15 Aug 2024 10:24:19 +0800 Subject: [PATCH] =?UTF-8?q?feat=20=20=20sketch=20=E6=8F=90=E5=8F=96?= =?UTF-8?q?=E6=8E=A5=E5=8F=A3=20fix?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/api/api_image2sketch.py | 3 +++ app/service/image2sketch/opt.py | 3 ++- app/service/image2sketch/server.py | 3 --- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/app/api/api_image2sketch.py b/app/api/api_image2sketch.py index 414361d..5bc6191 100644 --- a/app/api/api_image2sketch.py +++ b/app/api/api_image2sketch.py @@ -1,5 +1,6 @@ import json import logging +import time from fastapi import APIRouter, HTTPException @@ -27,9 +28,11 @@ def image2sketch(request_item: Image2SketchModel): } """ try: + start_time = time.time() logger.info(f"image2sketch request item is : @@@@@@:{json.dumps(request_item.dict())}") service = Image2SketchServer(request_item) sketch_url = service.get_result() + logger.info(f"run time is : {time.time() - start_time}") except Exception as e: logger.warning(f"image2sketch Run Exception @@@@@@:{e}") raise HTTPException(status_code=404, detail=str(e)) diff --git a/app/service/image2sketch/opt.py b/app/service/image2sketch/opt.py index 03cf7a3..8f33b9c 100644 --- a/app/service/image2sketch/opt.py +++ b/app/service/image2sketch/opt.py @@ -7,7 +7,6 @@ class Config: self.dataroot = "app/service/image2sketch/datasets/ref_unpair" self.name = 'semi_unpair' self.gpu_ids = [0] - self.checkpoints_dir = 'app/service/image2sketch/checkpoints/' # 模型参数 self.model = 'unpaired' self.input_nc = 3 @@ -48,5 +47,7 @@ class Config: self.morm = 'batch' if DEBUG: self.style_image = "service/image2sketch/datasets/ref_unpair/testC/20180422151845_stEe4.jpeg" + self.checkpoints_dir = 'service/image2sketch/checkpoints/' else: + self.checkpoints_dir = 'app/service/image2sketch/checkpoints/' self.style_image = "app/service/image2sketch/datasets/ref_unpair/testC/20180422151845_stEe4.jpeg" diff --git a/app/service/image2sketch/server.py b/app/service/image2sketch/server.py index 7a15b55..ebd363e 100644 --- a/app/service/image2sketch/server.py +++ b/app/service/image2sketch/server.py @@ -55,9 +55,6 @@ class Image2SketchServer: self.data['A'] = transform(A) self.data['A'] = self.data['A'].unsqueeze(0).to(device) - def __del__(self): - torch.cuda.empty_cache() - def get_result(self): self.model.set_input(self.data) self.model.test() # run inference