diff --git a/app/api/api_image2sketch.py b/app/api/api_image2sketch.py index 5bc6191..5d15daa 100644 --- a/app/api/api_image2sketch.py +++ b/app/api/api_image2sketch.py @@ -16,9 +16,9 @@ logger = logging.getLogger() def image2sketch(request_item: Image2SketchModel): """ 创建一个具有以下参数的请求体: - - **sr_image_url**: 超分图片的minio或s3 url地址 - - **sr_xn**: 超分的倍数,只接受2或4 - - **sr_tasks_id**: 任务id 用于取消超分任务和获取超分结果 + - **image_url**: 提取图片url + - **sketch_bucket**: sketch保存的bucket + - **sketch_name**: sketch保存的object name 示例参数: { diff --git a/app/service/image2sketch/server.py b/app/service/image2sketch/server.py index 82f1843..ebd363e 100644 --- a/app/service/image2sketch/server.py +++ b/app/service/image2sketch/server.py @@ -52,7 +52,8 @@ class Image2SketchServer: self.data['B'] = style_img self.data['B'] = self.data['B'].unsqueeze(0).to(device) A, self.width, self.height = self.get_image(self.image_url) - + self.data['A'] = transform(A) + self.data['A'] = self.data['A'].unsqueeze(0).to(device) def get_result(self): self.model.set_input(self.data) diff --git a/app/service/utils/oss_client.py b/app/service/utils/oss_client.py index 370cd7c..20794b9 100644 --- a/app/service/utils/oss_client.py +++ b/app/service/utils/oss_client.py @@ -5,12 +5,23 @@ from io import BytesIO import boto3 import cv2 import numpy as np +import urllib3 from PIL import Image from minio import Minio from app.core.config import * logger = logging.getLogger() +timeout = urllib3.Timeout(connect=1, read=10.0) # 连接超时 5 秒,读取超时 10 秒 +http_client = urllib3.PoolManager( + timeout=timeout, + cert_reqs='CERT_REQUIRED', # 需要证书验证 + retries=urllib3.Retry( + total=5, + backoff_factor=0.2, + status_forcelist=[500, 502, 503, 504], + ), +) # 获取图片 @@ -19,7 +30,7 @@ def oss_get_image(bucket, object_name, data_type): image_object = None try: if OSS == "minio": - oss_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) + oss_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE, http_client=http_client) image_data = oss_client.get_object(bucket_name=bucket, object_name=object_name) else: oss_client = boto3.client('s3', aws_access_key_id=S3_ACCESS_KEY, aws_secret_access_key=S3_AWS_SECRET_ACCESS_KEY, region_name=S3_REGION_NAME) @@ -64,8 +75,8 @@ if __name__ == '__main__': # url = "aida-users/89/sketchboard/female/Dress/e6724ab7-8d3f-4677-abe0-c3e42ab7af85.jpeg" # url = "aida-users/87/print/956614a2-7e75-4fbe-9ed0-c1831e37a2c9-4-87.png" # url = "aida-users/89/single_logo/123-89.png" - # url = "aida-users/89/product_image/string-89.png" - url = "aida-collection-element/12148/Sketchboard/95ea577b-305b-4a62-b30a-39c0dd3ddb3f.png" + url = "aida-users/89/product_image/string-89.png" + # url = "aida-collection-element/12148/Sketchboard/95ea577b-305b-4a62-b30a-39c0dd3ddb3f.png" read_type = "cv2" if read_type == "cv2": img = oss_get_image(bucket=url.split('/')[0], object_name=url[url.find('/') + 1:], data_type=read_type)