feat image2sketch 新增风格上传 自定义风格

fix
This commit is contained in:
zhouchengrong
2024-09-20 17:03:48 +08:00
parent 1385fde9ce
commit 2e07bc2de9
9 changed files with 28 additions and 9 deletions

View File

@@ -17,14 +17,18 @@ def image2sketch(request_item: Image2SketchModel):
""" """
创建一个具有以下参数的请求体: 创建一个具有以下参数的请求体:
- **image_url**: 提取图片url - **image_url**: 提取图片url
- **style_image_url**: 被模仿sketch图片url
- **default_style**: 默认风格 粗1、中2、细3
- **sketch_bucket**: sketch保存的bucket - **sketch_bucket**: sketch保存的bucket
- **sketch_name**: sketch保存的object name - **sketch_name**: sketch保存的object name
示例参数: 示例参数:
{ {
"image_url": "test/real_Top_971fe3085a69f31f3e66c225eabb0eea.jpg_Img.jpg", "image_url": "test/image2sketch/real_Dress_3200fecdc83d0c556c2bd96aedbd7fbf.jpg_Img.jpg",
"style_image_url": "test/image2sketch/style_3.png",
"default_style": "1",
"sketch_bucket": "test", "sketch_bucket": "test",
"sketch_name": "12341556-89.jpg" "sketch_name": "image2sketch/test.png"
} }
""" """
try: try:

View File

@@ -20,7 +20,7 @@ class Settings(BaseSettings):
OSS = "minio" OSS = "minio"
DEBUG = False DEBUG = True
if DEBUG: if DEBUG:
LOGS_PATH = "logs/" LOGS_PATH = "logs/"
CATEGORY_PATH = "service/attribute/config/descriptor/category/category_dis.csv" CATEGORY_PATH = "service/attribute/config/descriptor/category/category_dis.csv"

View File

@@ -3,5 +3,7 @@ from pydantic import BaseModel
class Image2SketchModel(BaseModel): class Image2SketchModel(BaseModel):
image_url: str image_url: str
style_image_url: str
default_style: str
sketch_bucket: str sketch_bucket: str
sketch_name: str sketch_name: str

Binary file not shown.

After

Width:  |  Height:  |  Size: 101 KiB

View File

Before

Width:  |  Height:  |  Size: 376 KiB

After

Width:  |  Height:  |  Size: 376 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 57 KiB

View File

@@ -54,7 +54,7 @@ def load_img(filepath):
if __name__ == '__main__': if __name__ == '__main__':
img_A = "/workspace/Semi_ref2sketch_code/datasets/ref_unpair/testA/real_Dress_732caedc416a0cbfedd0e6528040eac7.jpg_Img.jpg" img_A = "/workspace/Semi_ref2sketch_code/datasets/ref_unpair/testA/real_Dress_732caedc416a0cbfedd0e6528040eac7.jpg_Img.jpg"
img_B = "/workspace/Semi_ref2sketch_code/datasets/ref_unpair/testC/styleA.png" img_B = "/workspace/Semi_ref2sketch_code/datasets/ref_unpair/testC/style_3.png"
from opt import Config from opt import Config
opt = Config() # get test options opt = Config() # get test options
@@ -73,7 +73,7 @@ if __name__ == '__main__':
model.eval() model.eval()
data = {} data = {}
print(os.getcwd()) print(os.getcwd())
B = reference, _, _ = load_img(r"E:\workspace\trinity_client_aida\app\service\image2sketch\datasets\ref_unpair\testC\styleA.png") B = reference, _, _ = load_img(r"/app/service/image2sketch/datasets/ref_unpair/testC/style_3.png")
style_img = transform(reference) style_img = transform(reference)
data['B'] = style_img data['B'] = style_img
data['B'] = data['B'].unsqueeze(0).to(device) data['B'] = data['B'].unsqueeze(0).to(device)

View File

@@ -46,8 +46,12 @@ class Config:
self.num_test = 1000 self.num_test = 1000
self.morm = 'batch' self.morm = 'batch'
if DEBUG: if DEBUG:
self.style_image = "service/image2sketch/datasets/ref_unpair/testC/20180422151845_stEe4.jpeg" self.style_image1 = "service/image2sketch/datasets/ref_unpair/testC/style_1.jpg"
self.style_image2 = "service/image2sketch/datasets/ref_unpair/testC/style_2.jpeg"
self.style_image3 = "service/image2sketch/datasets/ref_unpair/testC/style_3.png"
self.checkpoints_dir = 'service/image2sketch/checkpoints/' self.checkpoints_dir = 'service/image2sketch/checkpoints/'
else: else:
self.checkpoints_dir = 'app/service/image2sketch/checkpoints/' self.checkpoints_dir = 'app/service/image2sketch/checkpoints/'
self.style_image = "app/service/image2sketch/datasets/ref_unpair/testC/20180422151845_stEe4.jpeg" self.style_image1 = "app/service/image2sketch/datasets/ref_unpair/testC/style_1.jpg"
self.style_image2 = "app/service/image2sketch/datasets/ref_unpair/testC/style_2.jpeg"
self.style_image3 = "app/service/image2sketch/datasets/ref_unpair/testC/style_3.png"

View File

@@ -33,6 +33,7 @@ def tensor2im(input_image, imtype=np.uint8):
class Image2SketchServer: class Image2SketchServer:
def __init__(self, request_data): def __init__(self, request_data):
self.image_url = request_data.image_url self.image_url = request_data.image_url
self.style_image_url = request_data.style_image_url
self.sketch_bucket = request_data.sketch_bucket self.sketch_bucket = request_data.sketch_bucket
self.sketch_name = request_data.sketch_name self.sketch_name = request_data.sketch_name
self.opt = Config() self.opt = Config()
@@ -47,7 +48,15 @@ class Image2SketchServer:
self.model.setup(self.opt) self.model.setup(self.opt)
transform_list = [transforms.ToTensor(), transforms.Normalize([0.5], [0.5])] transform_list = [transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
transform = transforms.Compose(transform_list) transform = transforms.Compose(transform_list)
style_img = Image.open(self.opt.style_image).convert('L') if request_data.default_style == "1":
style_img = Image.open(self.opt.style_image1).convert('L')
elif request_data.default_style == "2":
style_img = Image.open(self.opt.style_image2).convert('L')
elif request_data.default_style == "3":
style_img = Image.open(self.opt.style_image3).convert('L')
else:
style_img = oss_get_image(bucket=self.style_image_url.split('/')[0], object_name=self.style_image_url[self.style_image_url.find('/') + 1:], data_type="PIL")
style_img = style_img.convert('L')
style_img = transform(style_img) style_img = transform(style_img)
self.data['B'] = style_img self.data['B'] = style_img
self.data['B'] = self.data['B'].unsqueeze(0).to(device) self.data['B'] = self.data['B'].unsqueeze(0).to(device)