diff --git a/app/api/api_image2sketch.py b/app/api/api_image2sketch.py index 5d15daa..cf8df13 100644 --- a/app/api/api_image2sketch.py +++ b/app/api/api_image2sketch.py @@ -17,14 +17,18 @@ def image2sketch(request_item: Image2SketchModel): """ 创建一个具有以下参数的请求体: - **image_url**: 提取图片url + - **style_image_url**: 被模仿sketch图片url + - **default_style**: 默认风格 粗1,、中2、细3 - **sketch_bucket**: sketch保存的bucket - **sketch_name**: sketch保存的object name 示例参数: { - "image_url": "test/real_Top_971fe3085a69f31f3e66c225eabb0eea.jpg_Img.jpg", - "sketch_bucket": "test", - "sketch_name": "12341556-89.jpg" + "image_url": "test/image2sketch/real_Dress_3200fecdc83d0c556c2bd96aedbd7fbf.jpg_Img.jpg", + "style_image_url": "test/image2sketch/style_3.png", + "default_style": "1", + "sketch_bucket": "test", + "sketch_name": "image2sketch/test.png" } """ try: diff --git a/app/core/config.py b/app/core/config.py index 35c12b7..2e4d7bd 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -20,7 +20,7 @@ class Settings(BaseSettings): OSS = "minio" -DEBUG = False +DEBUG = True if DEBUG: LOGS_PATH = "logs/" CATEGORY_PATH = "service/attribute/config/descriptor/category/category_dis.csv" diff --git a/app/schemas/image2sketch.py b/app/schemas/image2sketch.py index a124739..b4650b9 100644 --- a/app/schemas/image2sketch.py +++ b/app/schemas/image2sketch.py @@ -3,5 +3,7 @@ from pydantic import BaseModel class Image2SketchModel(BaseModel): image_url: str + style_image_url: str + default_style: str sketch_bucket: str sketch_name: str diff --git a/app/service/image2sketch/datasets/ref_unpair/testC/style_1.jpg b/app/service/image2sketch/datasets/ref_unpair/testC/style_1.jpg new file mode 100644 index 0000000..3a66b7f Binary files /dev/null and b/app/service/image2sketch/datasets/ref_unpair/testC/style_1.jpg differ diff --git a/app/service/image2sketch/datasets/ref_unpair/testC/20180422151845_stEe4.jpeg b/app/service/image2sketch/datasets/ref_unpair/testC/style_2.jpeg similarity index 100% rename from app/service/image2sketch/datasets/ref_unpair/testC/20180422151845_stEe4.jpeg rename to app/service/image2sketch/datasets/ref_unpair/testC/style_2.jpeg diff --git a/app/service/image2sketch/datasets/ref_unpair/testC/style_3.png b/app/service/image2sketch/datasets/ref_unpair/testC/style_3.png new file mode 100644 index 0000000..8d8bcf4 Binary files /dev/null and b/app/service/image2sketch/datasets/ref_unpair/testC/style_3.png differ diff --git a/app/service/image2sketch/infer.py b/app/service/image2sketch/infer.py index 266b37c..8ec241f 100644 --- a/app/service/image2sketch/infer.py +++ b/app/service/image2sketch/infer.py @@ -54,7 +54,7 @@ def load_img(filepath): if __name__ == '__main__': img_A = "/workspace/Semi_ref2sketch_code/datasets/ref_unpair/testA/real_Dress_732caedc416a0cbfedd0e6528040eac7.jpg_Img.jpg" - img_B = "/workspace/Semi_ref2sketch_code/datasets/ref_unpair/testC/styleA.png" + img_B = "/workspace/Semi_ref2sketch_code/datasets/ref_unpair/testC/style_3.png" from opt import Config opt = Config() # get test options @@ -73,7 +73,7 @@ if __name__ == '__main__': model.eval() data = {} print(os.getcwd()) - B = reference, _, _ = load_img(r"E:\workspace\trinity_client_aida\app\service\image2sketch\datasets\ref_unpair\testC\styleA.png") + B = reference, _, _ = load_img(r"/app/service/image2sketch/datasets/ref_unpair/testC/style_3.png") style_img = transform(reference) data['B'] = style_img data['B'] = data['B'].unsqueeze(0).to(device) diff --git a/app/service/image2sketch/opt.py b/app/service/image2sketch/opt.py index 8f33b9c..eb453fb 100644 --- a/app/service/image2sketch/opt.py +++ b/app/service/image2sketch/opt.py @@ -46,8 +46,12 @@ class Config: self.num_test = 1000 self.morm = 'batch' if DEBUG: - self.style_image = "service/image2sketch/datasets/ref_unpair/testC/20180422151845_stEe4.jpeg" + self.style_image1 = "service/image2sketch/datasets/ref_unpair/testC/style_1.jpg" + self.style_image2 = "service/image2sketch/datasets/ref_unpair/testC/style_2.jpeg" + self.style_image3 = "service/image2sketch/datasets/ref_unpair/testC/style_3.png" self.checkpoints_dir = 'service/image2sketch/checkpoints/' else: self.checkpoints_dir = 'app/service/image2sketch/checkpoints/' - self.style_image = "app/service/image2sketch/datasets/ref_unpair/testC/20180422151845_stEe4.jpeg" + self.style_image1 = "app/service/image2sketch/datasets/ref_unpair/testC/style_1.jpg" + self.style_image2 = "app/service/image2sketch/datasets/ref_unpair/testC/style_2.jpeg" + self.style_image3 = "app/service/image2sketch/datasets/ref_unpair/testC/style_3.png" diff --git a/app/service/image2sketch/server.py b/app/service/image2sketch/server.py index ebd363e..3094eea 100644 --- a/app/service/image2sketch/server.py +++ b/app/service/image2sketch/server.py @@ -33,6 +33,7 @@ def tensor2im(input_image, imtype=np.uint8): class Image2SketchServer: def __init__(self, request_data): self.image_url = request_data.image_url + self.style_image_url = request_data.style_image_url self.sketch_bucket = request_data.sketch_bucket self.sketch_name = request_data.sketch_name self.opt = Config() @@ -47,7 +48,15 @@ class Image2SketchServer: self.model.setup(self.opt) transform_list = [transforms.ToTensor(), transforms.Normalize([0.5], [0.5])] transform = transforms.Compose(transform_list) - style_img = Image.open(self.opt.style_image).convert('L') + if request_data.default_style == "1": + style_img = Image.open(self.opt.style_image1).convert('L') + elif request_data.default_style == "2": + style_img = Image.open(self.opt.style_image2).convert('L') + elif request_data.default_style == "3": + style_img = Image.open(self.opt.style_image3).convert('L') + else: + style_img = oss_get_image(bucket=self.style_image_url.split('/')[0], object_name=self.style_image_url[self.style_image_url.find('/') + 1:], data_type="PIL") + style_img = style_img.convert('L') style_img = transform(style_img) self.data['B'] = style_img self.data['B'] = self.data['B'].unsqueeze(0).to(device)