From fabe64785ef244054435346a4c75cb66696ec595 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Thu, 19 Sep 2024 15:10:50 +0800 Subject: [PATCH] =?UTF-8?q?feat=20=20design=20=20=E6=8F=90=E9=80=9F?= =?UTF-8?q?=E6=B5=8B=E8=AF=95=20fix?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/api/api_design.py | 3 +- app/service/design/utils/upload_image.py | 20 +--- app/service/design_test/batch_design.py | 8 +- app/service/design_test/pipeline/color.py | 2 +- app/service/design_test/pipeline/loading.py | 2 +- .../design_test/pipeline/print_painting.py | 2 +- .../design_test/pipeline/segmentation.py | 2 +- app/service/design_test/pipeline/split.py | 4 +- app/service/design_test/utils/upload_image.py | 24 +---- app/service/utils/new_oss_client.py | 91 ++++++++++++++++++ app/service/utils/oss_client.py | 27 ++++-- requirements.txt | Bin 1842 -> 1828 bytes 12 files changed, 129 insertions(+), 56 deletions(-) create mode 100644 app/service/utils/new_oss_client.py diff --git a/app/api/api_design.py b/app/api/api_design.py index 4db9fc2..ba8f04d 100644 --- a/app/api/api_design.py +++ b/app/api/api_design.py @@ -7,6 +7,7 @@ from fastapi import APIRouter, HTTPException, UploadFile, File, Form from app.schemas.design import DesignModel, DesignProgressModel, ModelProgressModel, DBGConfigModel from app.schemas.response_template import ResponseModel from app.service.design.model_process_service import model_transpose +from app.service.design.service import generate from app.service.design.service_design_batch_generate import start_design_batch_generate from app.service.design.utils.redis_utils import Redis from app.service.design_test.batch_design import design_generate @@ -183,7 +184,7 @@ def design(request_data: DesignModel): # logger.info(f"design request item is : @@@@@@:{json.dumps(request_data.dict())}") # data = generate(request_data=request_data) # logger.info(f"design response @@@@@@:{json.dumps(data)}") - + # logger.info(f"design request item is : @@@@@@:{json.dumps(request_data.dict())}") data = design_generate(request_data=request_data) logger.info(f"design response @@@@@@:{json.dumps(data)}") diff --git a/app/service/design/utils/upload_image.py b/app/service/design/utils/upload_image.py index 8d20061..388f8b8 100644 --- a/app/service/design/utils/upload_image.py +++ b/app/service/design/utils/upload_image.py @@ -17,7 +17,7 @@ from app.service.utils.oss_client import oss_upload_image # @RunTime -def upload_png_mask(minio_client, front_image, object_name, mask=None): +def upload_png_mask(front_image, object_name, mask=None): try: mask_url = None if mask is not None: @@ -25,29 +25,15 @@ def upload_png_mask(minio_client, front_image, object_name, mask=None): # 将掩模的3通道转换为4通道,白色部分不透明,黑色部分透明 rgba_image = cv2.cvtColor(mask_inverted, cv2.COLOR_BGR2BGRA) rgba_image[rgba_image[:, :, 0] == 0] = [0, 0, 0, 0] - # image_bytes = io.BytesIO() - # image_bytes.write(cv2.imencode('.png', rgba_image)[1].tobytes()) - # image_bytes.seek(0) - # mask_url = f"{AIDA_CLOTHING}/{minio_client.put_object('aida-clothing', f'mask/mask_{object_name}.png', image_bytes, len(image_bytes.getvalue()), content_type='image/png').object_name}" - # oss upload #################### - req = oss_upload_image(oss_client=minio_client, bucket=AIDA_CLOTHING, object_name=f"mask/mask_{object_name}.png", image_bytes=cv2.imencode('.png', rgba_image)[1]) + req = oss_upload_image(bucket=AIDA_CLOTHING, object_name=f"mask/mask_{object_name}.png", image_bytes=cv2.imencode('.png', rgba_image)[1]) mask_url = f"{AIDA_CLOTHING}/mask/mask_{object_name}.png" image_data = io.BytesIO() front_image.save(image_data, format='PNG') image_data.seek(0) image_bytes = image_data.read() - # image_url = f"{AIDA_CLOTHING}/{minio_client.put_object('aida-clothing', f'image/image_{object_name}.png', io.BytesIO(image_bytes), len(image_bytes), content_type='image/png').object_name}" - req = oss_upload_image(oss_client=minio_client, bucket=AIDA_CLOTHING, object_name=f"image/image_{object_name}.png", image_bytes=image_bytes) + req = oss_upload_image(bucket=AIDA_CLOTHING, object_name=f"image/image_{object_name}.png", image_bytes=image_bytes) image_url = f"{AIDA_CLOTHING}/image/image_{object_name}.png" return front_image, image_url, mask_url except Exception as e: logging.warning(f"upload_png_mask runtime exception : {e}") - -# @RunTime -# def upload_png_mask(front_image, object_name, mask=None): -# mask_url = None -# if mask is not None: -# mask_url = f"{AIDA_CLOTHING}/mask/mask_{object_name}.png" -# image_url = f"{AIDA_CLOTHING}/image/image_{object_name}.png" -# return front_image, image_url, mask_url diff --git a/app/service/design_test/batch_design.py b/app/service/design_test/batch_design.py index 5a07429..27846cb 100644 --- a/app/service/design_test/batch_design.py +++ b/app/service/design_test/batch_design.py @@ -13,7 +13,8 @@ from minio import Minio from app.core.config import PRIORITY_DICT from app.service.design.utils.redis_utils import Redis from app.service.design_test.item import BodyItem, TopItem, BottomItem -from app.service.utils.oss_client import oss_upload_image +from app.service.utils.decorator import RunTime +from app.service.utils.new_oss_client import oss_upload_image id_lock = threading.Lock() @@ -298,10 +299,11 @@ def synthesis(data, size, basic_info): logging.warning(f"synthesis runtime exception : {e}") +@RunTime def design_generate(request_data): objects_data = request_data.dict()['objects'] process_id = request_data.dict()['process_id'] - object_response = [] + object_response = {} threads = [] active_threads = 0 lock = threading.Lock() @@ -362,7 +364,7 @@ def design_generate(request_data): update_progress(process_id, total) with lock: - object_response.append(items_response) + object_response[step] = items_response active_threads -= 1 for step, object in enumerate(objects_data): diff --git a/app/service/design_test/pipeline/color.py b/app/service/design_test/pipeline/color.py index d065aba..546c671 100644 --- a/app/service/design_test/pipeline/color.py +++ b/app/service/design_test/pipeline/color.py @@ -3,7 +3,7 @@ import logging import cv2 import numpy as np -from app.service.utils.oss_client import oss_get_image +from app.service.utils.new_oss_client import oss_get_image logger = logging.getLogger() diff --git a/app/service/design_test/pipeline/loading.py b/app/service/design_test/pipeline/loading.py index 7175881..0ce0dfa 100644 --- a/app/service/design_test/pipeline/loading.py +++ b/app/service/design_test/pipeline/loading.py @@ -5,7 +5,7 @@ import cv2 import numpy as np from PIL import Image -from app.service.utils.oss_client import oss_get_image +from app.service.utils.new_oss_client import oss_get_image logger = logging.getLogger() diff --git a/app/service/design_test/pipeline/print_painting.py b/app/service/design_test/pipeline/print_painting.py index 4a85399..6fe40d8 100644 --- a/app/service/design_test/pipeline/print_painting.py +++ b/app/service/design_test/pipeline/print_painting.py @@ -4,7 +4,7 @@ import cv2 import numpy as np from PIL import Image -from app.service.utils.oss_client import oss_get_image +from app.service.utils.new_oss_client import oss_get_image class PrintPainting: diff --git a/app/service/design_test/pipeline/segmentation.py b/app/service/design_test/pipeline/segmentation.py index 156742f..5c248b2 100644 --- a/app/service/design_test/pipeline/segmentation.py +++ b/app/service/design_test/pipeline/segmentation.py @@ -6,7 +6,7 @@ import numpy as np from app.core.config import SEG_CACHE_PATH from app.service.design.utils.design_ensemble import get_seg_result -from app.service.utils.oss_client import oss_get_image +from app.service.utils.new_oss_client import oss_get_image logger = logging.getLogger() diff --git a/app/service/design_test/pipeline/split.py b/app/service/design_test/pipeline/split.py index 1fa4215..50e167d 100644 --- a/app/service/design_test/pipeline/split.py +++ b/app/service/design_test/pipeline/split.py @@ -8,9 +8,9 @@ from cv2 import cvtColor, COLOR_BGR2RGBA from app.core.config import AIDA_CLOTHING from app.service.design.utils.conversion_image import rgb_to_rgba -from app.service.design.utils.upload_image import upload_png_mask +from app.service.design_test.utils.upload_image import upload_png_mask from app.service.utils.generate_uuid import generate_uuid -from app.service.utils.oss_client import oss_upload_image +from app.service.utils.new_oss_client import oss_upload_image class Split(object): diff --git a/app/service/design_test/utils/upload_image.py b/app/service/design_test/utils/upload_image.py index 9039ce7..2c79f9f 100644 --- a/app/service/design_test/utils/upload_image.py +++ b/app/service/design_test/utils/upload_image.py @@ -13,12 +13,11 @@ import logging import cv2 from app.core.config import * -from app.service.utils.decorator import RunTime -from app.service.utils.oss_client import oss_upload_image +from app.service.utils.new_oss_client import oss_upload_image # @RunTime -def upload_png_mask(front_image, object_name, mask=None): +def upload_png_mask(minio_client, front_image, object_name, mask=None): try: mask_url = None if mask is not None: @@ -26,30 +25,15 @@ def upload_png_mask(front_image, object_name, mask=None): # 将掩模的3通道转换为4通道,白色部分不透明,黑色部分透明 rgba_image = cv2.cvtColor(mask_inverted, cv2.COLOR_BGR2BGRA) rgba_image[rgba_image[:, :, 0] == 0] = [0, 0, 0, 0] - # image_bytes = io.BytesIO() - # image_bytes.write(cv2.imencode('.png', rgba_image)[1].tobytes()) - # image_bytes.seek(0) - # mask_url = f"{AIDA_CLOTHING}/{minio_client.put_object('aida-clothing', f'mask/mask_{object_name}.png', image_bytes, len(image_bytes.getvalue()), content_type='image/png').object_name}" - # oss upload #################### - req = oss_upload_image(bucket=AIDA_CLOTHING, object_name=f"mask/mask_{object_name}.png", image_bytes=cv2.imencode('.png', rgba_image)[1]) + req = oss_upload_image(oss_client=minio_client, bucket=AIDA_CLOTHING, object_name=f"mask/mask_{object_name}.png", image_bytes=cv2.imencode('.png', rgba_image)[1]) mask_url = f"{AIDA_CLOTHING}/mask/mask_{object_name}.png" image_data = io.BytesIO() front_image.save(image_data, format='PNG') image_data.seek(0) image_bytes = image_data.read() - # image_url = f"{AIDA_CLOTHING}/{minio_client.put_object('aida-clothing', f'image/image_{object_name}.png', io.BytesIO(image_bytes), len(image_bytes), content_type='image/png').object_name}" - req = oss_upload_image(bucket=AIDA_CLOTHING, object_name=f"image/image_{object_name}.png", image_bytes=image_bytes) + req = oss_upload_image(oss_client=minio_client, bucket=AIDA_CLOTHING, object_name=f"image/image_{object_name}.png", image_bytes=image_bytes) image_url = f"{AIDA_CLOTHING}/image/image_{object_name}.png" return front_image, image_url, mask_url except Exception as e: logging.warning(f"upload_png_mask runtime exception : {e}") - - -# @RunTime -# def upload_png_mask(front_image, object_name, mask=None): -# mask_url = None -# if mask is not None: -# mask_url = f"{AIDA_CLOTHING}/mask/mask_{object_name}.png" -# image_url = f"{AIDA_CLOTHING}/image/image_{object_name}.png" -# return front_image, image_url, mask_url diff --git a/app/service/utils/new_oss_client.py b/app/service/utils/new_oss_client.py new file mode 100644 index 0000000..28015e9 --- /dev/null +++ b/app/service/utils/new_oss_client.py @@ -0,0 +1,91 @@ +import io +import logging +from io import BytesIO + +import cv2 +import numpy as np +import urllib3 +from PIL import Image +from minio import Minio + +from app.core.config import * + +minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) + + +# 自定义 Retry 类 +class CustomRetry(urllib3.Retry): + def increment(self, method=None, url=None, response=None, error=None, **kwargs): + # 调用父类的 increment 方法 + new_retry = super(CustomRetry, self).increment(method, url, response, error, **kwargs) + # 打印重试信息 + logger.info(f"重试连接: {method} {url},错误: {error},重试次数: {self.total - new_retry.total}") + return new_retry + + +logger = logging.getLogger() +timeout = urllib3.Timeout(connect=1, read=10.0) # 连接超时 5 秒,读取超时 10 秒 +http_client = urllib3.PoolManager( + num_pools=10, # 设置连接池大小 + maxsize=10, + timeout=timeout, + cert_reqs='CERT_REQUIRED', # 需要证书验证 + retries=CustomRetry( + total=5, + backoff_factor=0.2, + status_forcelist=[500, 502, 503, 504], + ), +) + + +# 获取图片 +def oss_get_image(oss_client, bucket, object_name, data_type): + # cv2 默认全通道读取 + image_object = None + try: + image_data = oss_client.get_object(bucket_name=bucket, object_name=object_name) + if data_type == "cv2": + image_bytes = image_data.read() + image_array = np.frombuffer(image_bytes, np.uint8) # 转成8位无符号整型 + image_object = cv2.imdecode(image_array, cv2.IMREAD_UNCHANGED) + if image_object.dtype == np.uint16: + image_object = (image_object / 256).astype('uint8') + else: + data_bytes = BytesIO(image_data.read()) + image_object = Image.open(data_bytes) + except Exception as e: + logger.warning(f"{OSS} | 获取图片出现异常 ######: {e}") + return image_object + + +def oss_upload_image(oss_client, bucket, object_name, image_bytes): + req = None + try: + req = oss_client.put_object(bucket_name=bucket, object_name=object_name, data=io.BytesIO(image_bytes), length=len(image_bytes), content_type='image/png') + except Exception as e: + logger.warning(f"{OSS} | 上传图片出现异常 ######: {e}") + return req + + +if __name__ == '__main__': + # url = "aida-results/result_0002186a-e631-11ee-86a6-b48351119060.png" + # url = "aida-collection-element/11523/Moodboard/f60af0d2-94c2-48f9-90ff-74b8e8a481b5.jpg" + # url = "aida-sys-image/images/female/outwear/0628000054.jpg" + # url = "aida-users/89/product_image/string-89.png" + # url = "test/845046c7-4f62-4f54-a4a9-c26d49c6969335b5b3a9-d335-4871-a46c-3cc3caf07da259629dfd1f1f555a2e2a9def7e719366.png" + # url = 'aida-users/89/relight_image/123-89.png' + # url = 'aida-users/89/relight_image/123-89.png' + # url = 'aida-users/89/relight_image/123-89.png' + # url = "aida-users/89/sketchboard/female/Dress/e6724ab7-8d3f-4677-abe0-c3e42ab7af85.jpeg" + # url = "aida-users/87/print/956614a2-7e75-4fbe-9ed0-c1831e37a2c9-4-87.png" + # url = "aida-users/89/single_logo/123-89.png" + url = "aida-users/31/sketchboard/female/dress/6edcbf92-7da9-4809-a0a8-a4b4f06dec1e0628000041.jpg" + # url = "aida-collection-element/12148/Sketchboard/95ea577b-305b-4a62-b30a-39c0dd3ddb3f.png" + read_type = "cv2" + if read_type == "cv2": + img = oss_get_image(oss_client=minio_client, bucket=url.split('/')[0], object_name=url[url.find('/') + 1:], data_type=read_type) + cv2.imshow("", img) + cv2.waitKey(0) + else: + img = oss_get_image(oss_client=minio_client, bucket=url.split('/')[0], object_name=url[url.find('/') + 1:], data_type=read_type) + img.show() diff --git a/app/service/utils/oss_client.py b/app/service/utils/oss_client.py index 28015e9..65ce3a2 100644 --- a/app/service/utils/oss_client.py +++ b/app/service/utils/oss_client.py @@ -2,6 +2,7 @@ import io import logging from io import BytesIO +import boto3 import cv2 import numpy as np import urllib3 @@ -10,8 +11,6 @@ from minio import Minio from app.core.config import * -minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) - # 自定义 Retry 类 class CustomRetry(urllib3.Retry): @@ -39,11 +38,16 @@ http_client = urllib3.PoolManager( # 获取图片 -def oss_get_image(oss_client, bucket, object_name, data_type): +def oss_get_image(bucket, object_name, data_type): # cv2 默认全通道读取 image_object = None try: - image_data = oss_client.get_object(bucket_name=bucket, object_name=object_name) + if OSS == "minio": + oss_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE, http_client=http_client) + image_data = oss_client.get_object(bucket_name=bucket, object_name=object_name) + else: + oss_client = boto3.client('s3', aws_access_key_id=S3_ACCESS_KEY, aws_secret_access_key=S3_AWS_SECRET_ACCESS_KEY, region_name=S3_REGION_NAME) + image_data = oss_client.get_object(Bucket=bucket, Key=object_name)['Body'] if data_type == "cv2": image_bytes = image_data.read() image_array = np.frombuffer(image_bytes, np.uint8) # 转成8位无符号整型 @@ -58,10 +62,15 @@ def oss_get_image(oss_client, bucket, object_name, data_type): return image_object -def oss_upload_image(oss_client, bucket, object_name, image_bytes): +def oss_upload_image(bucket, object_name, image_bytes): req = None try: - req = oss_client.put_object(bucket_name=bucket, object_name=object_name, data=io.BytesIO(image_bytes), length=len(image_bytes), content_type='image/png') + if OSS == "minio": + oss_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) + req = oss_client.put_object(bucket_name=bucket, object_name=object_name, data=io.BytesIO(image_bytes), length=len(image_bytes), content_type='image/png') + else: + oss_client = boto3.client('s3', aws_access_key_id=S3_ACCESS_KEY, aws_secret_access_key=S3_AWS_SECRET_ACCESS_KEY, region_name=S3_REGION_NAME) + req = oss_client.put_object(Bucket=bucket, Key=object_name, Body=io.BytesIO(image_bytes), ContentType='image/png') except Exception as e: logger.warning(f"{OSS} | 上传图片出现异常 ######: {e}") return req @@ -79,13 +88,13 @@ if __name__ == '__main__': # url = "aida-users/89/sketchboard/female/Dress/e6724ab7-8d3f-4677-abe0-c3e42ab7af85.jpeg" # url = "aida-users/87/print/956614a2-7e75-4fbe-9ed0-c1831e37a2c9-4-87.png" # url = "aida-users/89/single_logo/123-89.png" - url = "aida-users/31/sketchboard/female/dress/6edcbf92-7da9-4809-a0a8-a4b4f06dec1e0628000041.jpg" + url = "aida-clothing/mask/mask_f354afb5-6423-11ef-8b08-0826ae3ad6b3.png" # url = "aida-collection-element/12148/Sketchboard/95ea577b-305b-4a62-b30a-39c0dd3ddb3f.png" read_type = "cv2" if read_type == "cv2": - img = oss_get_image(oss_client=minio_client, bucket=url.split('/')[0], object_name=url[url.find('/') + 1:], data_type=read_type) + img = oss_get_image(bucket=url.split('/')[0], object_name=url[url.find('/') + 1:], data_type=read_type) cv2.imshow("", img) cv2.waitKey(0) else: - img = oss_get_image(oss_client=minio_client, bucket=url.split('/')[0], object_name=url[url.find('/') + 1:], data_type=read_type) + img = oss_get_image(bucket=url.split('/')[0], object_name=url[url.find('/') + 1:], data_type=read_type) img.show() diff --git a/requirements.txt b/requirements.txt index 21dbcb5b32cf478aad263cb15d648e0c87fc4c31..6c9e38f1ded86de71e2126d5c357903ea0d08a05 100644 GIT binary patch delta 12 TcmdnQw}fxQEY{5{Si2YjAo~QZ delta 22 dcmZ3&w~24VELP5BhE#?eh9ZW_%~M(X839pz2FCyZ