From 85942167f3f65db4caede62e4945a4d5b60cfecb Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Mon, 30 Sep 2024 10:57:12 +0800 Subject: [PATCH] =?UTF-8?q?feat=20=20image2sketch=20=E5=8F=98=E6=9B=B4?= =?UTF-8?q?=E6=A8=A1=E5=9E=8B=20fix?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 3 +- app/api/api_image2sketch.py | 13 +- app/schemas/image2sketch.py | 1 - .../image2sketch_2/download_checkpoints.py | 45 ++++++ app/service/image2sketch_2/server.py | 142 ++++++++++++++++++ app/service/utils/new_oss_client.py | 3 + 6 files changed, 200 insertions(+), 7 deletions(-) create mode 100644 app/service/image2sketch_2/download_checkpoints.py create mode 100644 app/service/image2sketch_2/server.py diff --git a/.gitignore b/.gitignore index 8fd7817..3f9e525 100644 --- a/.gitignore +++ b/.gitignore @@ -135,4 +135,5 @@ app/logs/* *.log /qodana.yaml .pth -.pytorch \ No newline at end of file +.pytorch +*.png \ No newline at end of file diff --git a/app/api/api_image2sketch.py b/app/api/api_image2sketch.py index cf8df13..24acf46 100644 --- a/app/api/api_image2sketch.py +++ b/app/api/api_image2sketch.py @@ -6,7 +6,7 @@ from fastapi import APIRouter, HTTPException from app.schemas.image2sketch import Image2SketchModel from app.schemas.response_template import ResponseModel -from app.service.image2sketch.server import Image2SketchServer +from app.service.image2sketch_2.server import processing_pipeline router = APIRouter() logger = logging.getLogger() @@ -25,8 +25,7 @@ def image2sketch(request_item: Image2SketchModel): 示例参数: { "image_url": "test/image2sketch/real_Dress_3200fecdc83d0c556c2bd96aedbd7fbf.jpg_Img.jpg", - "style_image_url": "test/image2sketch/style_3.png", - "default_style": "1", + "default_style": 0, "sketch_bucket": "test", "sketch_name": "image2sketch/test.png" } @@ -34,8 +33,12 @@ def image2sketch(request_item: Image2SketchModel): try: start_time = time.time() logger.info(f"image2sketch request item is : @@@@@@:{json.dumps(request_item.dict())}") - service = Image2SketchServer(request_item) - sketch_url = service.get_result() + sketch_url = processing_pipeline( + image_url=request_item.image_url, + thickness=request_item.default_style, + sketch_bucket=request_item.sketch_bucket, + sketch_name=request_item.sketch_name + ) logger.info(f"run time is : {time.time() - start_time}") except Exception as e: logger.warning(f"image2sketch Run Exception @@@@@@:{e}") diff --git a/app/schemas/image2sketch.py b/app/schemas/image2sketch.py index b4650b9..dbbbbb5 100644 --- a/app/schemas/image2sketch.py +++ b/app/schemas/image2sketch.py @@ -3,7 +3,6 @@ from pydantic import BaseModel class Image2SketchModel(BaseModel): image_url: str - style_image_url: str default_style: str sketch_bucket: str sketch_name: str diff --git a/app/service/image2sketch_2/download_checkpoints.py b/app/service/image2sketch_2/download_checkpoints.py new file mode 100644 index 0000000..9048c34 --- /dev/null +++ b/app/service/image2sketch_2/download_checkpoints.py @@ -0,0 +1,45 @@ +import os + +from minio import Minio +from minio.error import S3Error + +MINIO_URL = "www.minio.aida.com.hk:12024" +MINIO_ACCESS = 'vXKFLSJkYeEq2DrSZvkB' +MINIO_SECRET = 'uKTZT3x7C43WvPN9QTc99DiRkwddWZrG9Uh3JVlR' +MINIO_SECURE = True +# 配置MinIO客户端 +minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) + + +# 下载函数 +def download_folder(bucket_name, folder_name, local_dir): + try: + # 确保本地目录存在 + if not os.path.exists(local_dir): + os.makedirs(local_dir) + + # 遍历MinIO中的文件 + objects = minio_client.list_objects(bucket_name, prefix=folder_name, recursive=True) + for obj in objects: + # 构造本地文件路径 + local_file_path = os.path.join(local_dir, obj.object_name[len(folder_name):]) + local_file_dir = os.path.dirname(local_file_path) + + # 确保本地目录存在 + if not os.path.exists(local_file_dir): + os.makedirs(local_file_dir) + + # 下载文件 + minio_client.fget_object(bucket_name, obj.object_name, local_file_path) + print(f"Downloaded {obj.object_name} to {local_file_path}") + + except S3Error as e: + print(f"Error occurred: {e}") + + +# 使用示例 +bucket_name = "test" # 替换成你的bucket名称 +folder_name = "checkpoints/lineart/" # 权重文件夹的路径 +local_dir = "app/service/image2sketch_2" # 替换成你希望保存到的本地目录 + +download_folder(bucket_name, folder_name, local_dir) diff --git a/app/service/image2sketch_2/server.py b/app/service/image2sketch_2/server.py new file mode 100644 index 0000000..93c9574 --- /dev/null +++ b/app/service/image2sketch_2/server.py @@ -0,0 +1,142 @@ +import cv2 +import numpy +import numpy as np +import torch +import torch.nn as nn +import torchvision.transforms as transforms +from PIL import Image + +from app.service.utils.oss_client import oss_get_image, oss_upload_image + +norm_layer = nn.InstanceNorm2d + +weights = [(0.7, 0.3), (0.5, 0.5), (0.3, 0.7), (0.1, 0.9), (0, 1)] +kernel = np.ones((3, 3), np.uint8) + + +class ResidualBlock(nn.Module): + def __init__(self, in_features): + super(ResidualBlock, self).__init__() + + conv_block = [nn.ReflectionPad2d(1), + nn.Conv2d(in_features, in_features, 3), + norm_layer(in_features), + nn.ReLU(inplace=True), + nn.ReflectionPad2d(1), + nn.Conv2d(in_features, in_features, 3), + norm_layer(in_features) + ] + + self.conv_block = nn.Sequential(*conv_block) + + def forward(self, x): + return x + self.conv_block(x) + + +class Generator(nn.Module): + def __init__(self, input_nc, output_nc, n_residual_blocks=9, sigmoid=True): + super(Generator, self).__init__() + + # Initial convolution block + model0 = [nn.ReflectionPad2d(3), + nn.Conv2d(input_nc, 64, 7), + norm_layer(64), + nn.ReLU(inplace=True)] + self.model0 = nn.Sequential(*model0) + + # Downsampling + model1 = [] + in_features = 64 + out_features = in_features * 2 + for _ in range(2): + model1 += [nn.Conv2d(in_features, out_features, 3, stride=2, padding=1), + norm_layer(out_features), + nn.ReLU(inplace=True)] + in_features = out_features + out_features = in_features * 2 + self.model1 = nn.Sequential(*model1) + + model2 = [] + # Residual blocks + for _ in range(n_residual_blocks): + model2 += [ResidualBlock(in_features)] + self.model2 = nn.Sequential(*model2) + + # Upsampling + model3 = [] + out_features = in_features // 2 + for _ in range(2): + model3 += [nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1), + norm_layer(out_features), + nn.ReLU(inplace=True)] + in_features = out_features + out_features = in_features // 2 + self.model3 = nn.Sequential(*model3) + + # Output layer + model4 = [nn.ReflectionPad2d(3), + nn.Conv2d(64, output_nc, 7)] + if sigmoid: + model4 += [nn.Sigmoid()] + + self.model4 = nn.Sequential(*model4) + + def forward(self, x, cond=None): + out = self.model0(x) + out = self.model1(out) + out = self.model2(out) + out = self.model3(out) + out = self.model4(out) + + return out + + +model1 = Generator(3, 1, 3) +model1.load_state_dict(torch.load('service/image2sketch_2/model.pth', map_location=torch.device('cpu'))) +model1.eval() + + +def predict(input_img, width): + transform = transforms.Compose([transforms.Resize(width, Image.BICUBIC), transforms.ToTensor()]) + input_img = transform(input_img) + input_img = torch.unsqueeze(input_img, 0) + + with torch.no_grad(): + drawing = model1(input_img)[0].detach() + + drawing = transforms.ToPILImage()(drawing) + + # 转ndarray + drawing = numpy.array(drawing) + return drawing + + +def get_image(image_url): + image = oss_get_image(bucket=image_url.split('/')[0], object_name=image_url[image_url.find('/') + 1:], data_type="PIL") + image = image.convert('RGB') + width = image.size[0] + height = image.size[1] + return image, width, height + + +def processing_pipeline(image_url, thickness, sketch_bucket, sketch_name): + thickness = int(thickness) + # 提取sketch + image, width, height = get_image(image_url) + sketch_image = predict(image, width) + + # 设定线条粗细 + if thickness != 0: + dilated = cv2.erode(sketch_image, kernel, iterations=1) + # 将原图与膨胀后的图像进行混合,使用不同的权重 + sketch_image = cv2.addWeighted(sketch_image, weights[thickness][0], dilated, weights[thickness][1], 0) + + # 上传minio + image_bytes = cv2.imencode(".jpg", sketch_image)[1].tobytes() + req = oss_upload_image(bucket=sketch_bucket, object_name=sketch_name, image_bytes=image_bytes) + return f"{req.bucket_name}/{req.object_name}" + + +if __name__ == '__main__': + result_url = processing_pipeline("aida-users/89/relight_image/d5f0d967-f8e8-424d-98f9-a8ad8313deec-0-89.png", 1, "test", "test123.jpg") + print(result_url) diff --git a/app/service/utils/new_oss_client.py b/app/service/utils/new_oss_client.py index 28015e9..95a0fbf 100644 --- a/app/service/utils/new_oss_client.py +++ b/app/service/utils/new_oss_client.py @@ -9,6 +9,7 @@ from PIL import Image from minio import Minio from app.core.config import * +from app.service.utils.decorator import RunTime minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) @@ -39,6 +40,7 @@ http_client = urllib3.PoolManager( # 获取图片 +@RunTime def oss_get_image(oss_client, bucket, object_name, data_type): # cv2 默认全通道读取 image_object = None @@ -58,6 +60,7 @@ def oss_get_image(oss_client, bucket, object_name, data_type): return image_object +@RunTime def oss_upload_image(oss_client, bucket, object_name, image_bytes): req = None try: