Merge branch 'local' into develop

# Conflicts:
#	app/core/config.py
This commit is contained in:
zchen
2024-06-23 16:12:08 +08:00
20 changed files with 455 additions and 406 deletions

View File

@@ -2,9 +2,10 @@ import json
import logging
from fastapi import APIRouter, HTTPException
from app.core.config import DEBUG
from app.schemas.attribute_retrieve import *
from app.schemas.response_template import ResponseModel
from app.service.attribute.config import const
from app.service.attribute.config import const, local_debug_const
from app.service.attribute.service_att_recognition import AttributeRecognition
from app.service.attribute.service_category_recognition import CategoryRecognition
@@ -17,13 +18,16 @@ logger = logging.getLogger()
def attribute_recognition(request_item: list[AttributeRecognitionModel]):
try:
logger.info(f"attribute_recognition request item is : @@@@@@:{request_item}")
service = AttributeRecognition(const=const, request_data=request_item)
if DEBUG:
service = AttributeRecognition(const=local_debug_const, request_data=request_item)
else:
service = AttributeRecognition(const=const, request_data=request_item)
data = service.get_result()
logger.info(f"attribute_recognition response @@@@@@:{json.dumps(data, indent=4)}")
except Exception as e:
logger.warning(f"attribute_recognition Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))
return ResponseModel(data=data)
return ResponseModel(data={"list": data})
# 类别识别

View File

@@ -4,9 +4,10 @@ import time
from fastapi import APIRouter, HTTPException
from app.schemas.design import DesignModel
from app.schemas.design import DesignModel, DesignProgressModel
from app.schemas.response_template import ResponseModel
from app.service.design.service import generate
from app.service.design.utils.redis_utils import Redis
router = APIRouter()
logger = logging.getLogger()
@@ -22,3 +23,19 @@ def design(request_data: DesignModel):
logger.warning(f"design Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))
return ResponseModel(data=data)
@router.post('/get_progress')
def get_progress(request_data: DesignProgressModel):
try:
logger.info(f"get_progress request item is : @@@@@@:{request_data}")
process_id = request_data.process_id
r = Redis()
data = r.read(key=process_id)
if data is None:
raise ValueError("The progress must be numbers ")
logging.info(f"get_progress process_id @@@@@@ : {process_id} , progress : {data}")
except Exception as e:
logger.warning(f"get_progress Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))
return ResponseModel(data=data)

View File

@@ -1,6 +1,6 @@
import logging
from fastapi import APIRouter
from app.core.config import SR_RABBITMQ_QUEUES, GI_RABBITMQ_QUEUES, GPI_RABBITMQ_QUEUES
from app.core.config import SR_RABBITMQ_QUEUES, GI_RABBITMQ_QUEUES, GPI_RABBITMQ_QUEUES, GRI_RABBITMQ_QUEUES, OSS
from fastapi import FastAPI, HTTPException
from app.schemas.response_template import ResponseModel
@@ -15,6 +15,8 @@ def test(id: int):
"SR_RABBITMQ_QUEUES message": SR_RABBITMQ_QUEUES,
"GI_RABBITMQ_QUEUES": GI_RABBITMQ_QUEUES,
"GPI_RABBITMQ_QUEUES": GPI_RABBITMQ_QUEUES,
"GRI_RABBITMQ_QUEUES": GRI_RABBITMQ_QUEUES,
"local_oss_server": OSS
}
logger.info(data)
if id == 1:

View File

@@ -19,15 +19,16 @@ class Settings(BaseSettings):
LOGGING_CONFIG_FILE = os.path.join(BASE_DIR, 'logging_env.py')
OSS = "minio"
DEBUG = False
if DEBUG:
LOGS_PATH = "logs/"
CATEGORY_PATH = "service/attribute/config/descriptor/category/category_dis.csv"
FACE_CLASSIFIER = "service/generate_image/utils/haarcascade_frontalface_alt.xml"
# FACE_CLASSIFIER = "service/generate_image/utils/haarcascade_frontalface_alt.xml"
else:
LOGS_PATH = "app/logs/"
CATEGORY_PATH = "app/service/attribute/config/descriptor/category/category_dis.csv"
FACE_CLASSIFIER = 'app/service/generate_image/utils/haarcascade_frontalface_alt.xml'
# FACE_CLASSIFIER = 'app/service/generate_image/utils/haarcascade_frontalface_alt.xml'
# RABBITMQ_ENV = "" # 生产环境
# RABBITMQ_ENV = "-dev" # 开发环境
@@ -47,7 +48,7 @@ S3_AWS_SECRET_ACCESS_KEY = "LNIwFFB27/QedtZ+Q/viVUoX9F5x1DbuM8N0DkD8"
S3_REGION_NAME = "ap-east-1"
# redis 配置
REDIS_HOST = "10.1.1.240"
REDIS_HOST = "10.1.1.150"
REDIS_PORT = "6379"
REDIS_DB = "2"
@@ -60,9 +61,9 @@ RABBITMQ_PARAMS = {
}
# milvus 配置
MILVUS_DB_HOST = "10.1.1.240"
MILVUS_URL = "http://10.1.1.240:19530"
MILVUS_TOKEN = "root:Milvus"
MILVUS_ALIAS = "default"
MILVUS_PORT = "19530"
MILVUS_TABLE_KEYPOINT = "keypoint_cache"
MILVUS_TABLE_SEG = "seg_cache"
@@ -123,8 +124,8 @@ GPI_MODEL_NAME = 'diffusion_ensemble_all'
GPI_MODEL_URL = '10.1.1.240:10041'
# Generate Single Logo service config
GRI_RABBITMQ_QUEUES = os.getenv("GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES", f"Relight{RABBITMQ_ENV}")
GRI_MODEL_NAME = 'stable_diffusion_1_5'
GRI_RABBITMQ_QUEUES = os.getenv("GEN_RELIGHT_IMAGE_RABBITMQ_QUEUES", f"Relight{RABBITMQ_ENV}")
GRI_MODEL_NAME = 'diffusion_relight_ensemble'
GRI_MODEL_URL = '10.1.1.240:10051'
# SEG service config

View File

@@ -48,3 +48,7 @@ from pydantic import BaseModel
class DesignModel(BaseModel):
objects: list[dict]
process_id: str
class DesignProgressModel(BaseModel):
process_id: str

View File

@@ -1,13 +1,13 @@
top_description_list = ['service/attribute/config/descriptor/top/length.csv',
'service/attribute/config/descriptor/top/type.csv',
'service/attribute/config/descriptor/top/sleeve_length.csv',
'service/attribute/config/descriptor/top/sleeve_shape.csv',
'service/attribute/config/descriptor/top/sleeve_shoulder.csv',
'service/attribute/config/descriptor/top/neckline.csv',
'service/attribute/config/descriptor/top/design.csv',
'service/attribute/config/descriptor/top/opening_type.csv',
'service/attribute/config/descriptor/top/silhouette.csv',
'service/attribute/config/descriptor/top/collar.csv']
top_description_list = ['app/service/attribute/config/descriptor/top/length.csv',
'app/service/attribute/config/descriptor/top/type.csv',
'app/service/attribute/config/descriptor/top/sleeve_length.csv',
'app/service/attribute/config/descriptor/top/sleeve_shape.csv',
'app/service/attribute/config/descriptor/top/sleeve_shoulder.csv',
'app/service/attribute/config/descriptor/top/neckline.csv',
'app/service/attribute/config/descriptor/top/design.csv',
'app/service/attribute/config/descriptor/top/opening_type.csv',
'app/service/attribute/config/descriptor/top/silhouette.csv',
'app/service/attribute/config/descriptor/top/collar.csv']
top_model_list = ['attr_retrieve_T_length',
'attr_retrieve_T_type',
@@ -22,11 +22,11 @@ top_model_list = ['attr_retrieve_T_length',
]
bottom_description_list = [
'service/attribute/config/descriptor/bottom/subtype.csv',
'service/attribute/config/descriptor/bottom/length.csv',
'service/attribute/config/descriptor/bottom/silhouette.csv',
'service/attribute/config/descriptor/bottom/opening_type.csv',
'service/attribute/config/descriptor/bottom/design.csv']
'app/service/attribute/config/descriptor/bottom/subtype.csv',
'app/service/attribute/config/descriptor/bottom/length.csv',
'app/service/attribute/config/descriptor/bottom/silhouette.csv',
'app/service/attribute/config/descriptor/bottom/opening_type.csv',
'app/service/attribute/config/descriptor/bottom/design.csv']
bottom_model_list = [
'attr_retrieve_B_subtype',
@@ -35,14 +35,14 @@ bottom_model_list = [
'attr_recong_B_optype',
'attr_retrieve_B_design']
outwear_description_list = ['service/attribute/config/descriptor/outwear/length.csv',
'service/attribute/config/descriptor/outwear/sleeve_length.csv',
'service/attribute/config/descriptor/outwear/sleeve_shape.csv',
'service/attribute/config/descriptor/outwear/sleeve_shoulder.csv',
'service/attribute/config/descriptor/outwear/collar.csv',
'service/attribute/config/descriptor/outwear/design.csv',
'service/attribute/config/descriptor/outwear/opening_type.csv',
'service/attribute/config/descriptor/outwear/silhouette.csv', ]
outwear_description_list = ['app/service/attribute/config/descriptor/outwear/length.csv',
'app/service/attribute/config/descriptor/outwear/sleeve_length.csv',
'app/service/attribute/config/descriptor/outwear/sleeve_shape.csv',
'app/service/attribute/config/descriptor/outwear/sleeve_shoulder.csv',
'app/service/attribute/config/descriptor/outwear/collar.csv',
'app/service/attribute/config/descriptor/outwear/design.csv',
'app/service/attribute/config/descriptor/outwear/opening_type.csv',
'app/service/attribute/config/descriptor/outwear/silhouette.csv', ]
outwear_model_list = ['attr_recong_O_length',
'attr_retrieve_O_sleeve_length',
@@ -53,15 +53,15 @@ outwear_model_list = ['attr_recong_O_length',
'attr_recong_O_optype',
'attr_retrieve_O_silhouette']
dress_description_list = [ # 'service/attribute/config/descriptor/dress/D_length.csv',
'service/attribute/config/descriptor/dress/sleeve_length.csv',
'service/attribute/config/descriptor/dress/sleeve_shape.csv',
# 'service/attribute/config/descriptor/dress/D_sleeve_shoulder.csv',
'service/attribute/config/descriptor/dress/neckline.csv',
'service/attribute/config/descriptor/dress/collar.csv',
'service/attribute/config/descriptor/dress/design.csv',
'service/attribute/config/descriptor/dress/silhouette.csv',
'service/attribute/config/descriptor/dress/type.csv']
dress_description_list = [ # 'app/service/attribute/config/descriptor/dress/D_length.csv',
'app/service/attribute/config/descriptor/dress/sleeve_length.csv',
'app/service/attribute/config/descriptor/dress/sleeve_shape.csv',
# 'app/service/attribute/config/descriptor/dress/D_sleeve_shoulder.csv',
'app/service/attribute/config/descriptor/dress/neckline.csv',
'app/service/attribute/config/descriptor/dress/collar.csv',
'app/service/attribute/config/descriptor/dress/design.csv',
'app/service/attribute/config/descriptor/dress/silhouette.csv',
'app/service/attribute/config/descriptor/dress/type.csv']
dress_model_list = [ # 'attr_recong_D_length',
'attr_retrieve_D_sleeve_length',

View File

@@ -11,12 +11,12 @@ from minio import Minio
import tritonclient.http as httpclient
from app.core.config import *
from app.schemas.attribute_retrieve import AttributeRecognitionModel
from app.service.utils.oss_client import oss_get_image
class AttributeRecognition:
def __init__(self, const, request_data):
self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
logging.info("实例化完成")
# self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
self.request_data = []
for i, sketch in enumerate(request_data):
self.request_data.append(
@@ -97,9 +97,10 @@ class AttributeRecognition:
return res
def get_image(self, url):
response = self.minio_client.get_object(url.split("/", 1)[0], url.split("/", 1)[1])
img = np.frombuffer(response.data, np.uint8) # 转成8位无符号整型
img = cv2.imdecode(img, cv2.IMREAD_COLOR) # 解码
# response = self.minio_client.get_object(url.split("/", 1)[0], url.split("/", 1)[1])
# img = np.frombuffer(response.data, np.uint8) # 转成8位无符号整型
# img = cv2.imdecode(img, cv2.IMREAD_COLOR) #
img = oss_get_image(bucket=url.split("/", 1)[0], object_name=url.split("/", 1)[1], data_type="cv2")
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
return img

View File

@@ -18,12 +18,13 @@ import torch
from app.core.config import *
from app.schemas.attribute_retrieve import CategoryRecognitionModel
from app.service.utils.oss_client import oss_get_image
class CategoryRecognition:
def __init__(self, request_data):
self.attr_type = pd.read_csv(CATEGORY_PATH)
self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
# self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
self.request_data = []
self.triton_client = httpclient.InferenceServerClient(url=ATT_TRITON_URL)
for sketch in request_data:
@@ -51,9 +52,10 @@ class CategoryRecognition:
def get_image(self, url):
# Get data of an object.
# Read data from response.
response = self.minio_client.get_object(url.split("/", 1)[0], url.split("/", 1)[1])
img = np.frombuffer(response.data, np.uint8) # 转成8位无符号整型
img = cv2.imdecode(img, cv2.IMREAD_COLOR) # 解码
# response = self.minio_client.get_object(url.split("/", 1)[0], url.split("/", 1)[1])
# img = np.frombuffer(response.data, np.uint8) # 转成8位无符号整型
# img = cv2.imdecode(img, cv2.IMREAD_COLOR) # 解码
img = oss_get_image(bucket=url.split("/", 1)[0], object_name=url.split("/", 1)[1], data_type="cv2")
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
return img

View File

@@ -1,5 +1,6 @@
import logging
import time
import numpy as np
from pymilvus import MilvusClient
@@ -14,17 +15,17 @@ class KeypointDetection(object):
path here: abstract path
"""
def __init__(self):
self.client = MilvusClient(
uri="http://10.1.1.240:19530",
token="root:Milvus",
db_name=MILVUS_ALIAS
)
# def __init__(self):
# self.client = MilvusClient(
# uri="http://10.1.1.240:19530",
# token="root:Milvus",
# db_name=MILVUS_ALIAS
# )
def __del__(self):
# start_time = time.time()
self.client.close()
# print(f"client close time : {time.time() - start_time}")
# def __del__(self):
# start_time = time.time()
# self.client.close()
# print(f"client close time : {time.time() - start_time}")
# @ RunTime
def __call__(self, result):
@@ -55,7 +56,7 @@ class KeypointDetection(object):
@staticmethod
# @ RunTime
def save_keypoint_cache(keypoint_id, cache, site, KEYPOINT_RESULT_TABLE_FIELD_SET=None):
def save_keypoint_cache(keypoint_id, cache, site):
if site == "down":
zeros = np.zeros(20, dtype=int)
result = np.concatenate([zeros, cache.flatten()])
@@ -69,24 +70,16 @@ class KeypointDetection(object):
"keypoint_vector": result.tolist()
}
]
client = MilvusClient(
uri="http://10.1.1.240:19530",
token="root:Milvus",
db_name=MILVUS_ALIAS
)
try:
start_time = time.time()
res = client.upsert(
collection_name=MILVUS_TABLE_KEYPOINT,
data=data,
)
client = MilvusClient(uri=MILVUS_URL, token=MILVUS_TOKEN, db_name=MILVUS_ALIAS)
# start_time = time.time()
res = client.upsert(collection_name=MILVUS_TABLE_KEYPOINT, data=data)
# logging.info(f"save keypoint time : {time.time() - start_time}")
client.close()
return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
except Exception as e:
logging.info(f"save keypoint cache milvus error : {e}")
return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
finally:
client.close()
@staticmethod
def update_keypoint_cache(keypoint_id, infer_result, search_result, site):
@@ -102,12 +95,9 @@ class KeypointDetection(object):
"keypoint_vector": result.tolist()
}
]
client = MilvusClient(
uri="http://10.1.1.240:19530",
token="root:Milvus",
db_name=MILVUS_ALIAS
)
try:
client = MilvusClient(uri=MILVUS_URL, token=MILVUS_TOKEN, db_name=MILVUS_ALIAS)
# connections.connect(alias=MILVUS_ALIAS, host=MILVUS_DB_HOST, port=MILVUS_PORT)
start_time = time.time()
# collection = Collection(MILVUS_TABLE_KEYPOINT) # Get an existing collection.
@@ -125,8 +115,9 @@ class KeypointDetection(object):
# @ RunTime
def keypoint_cache(self, result, site):
try:
client = MilvusClient(uri=MILVUS_URL, token=MILVUS_TOKEN, db_name=MILVUS_ALIAS)
keypoint_id = result['image_id']
res = self.client.query(
res = client.query(
collection_name=MILVUS_TABLE_KEYPOINT,
# ids=[keypoint_id],
filter=f"keypoint_id == {keypoint_id}",

View File

@@ -1,6 +1,5 @@
import io
import logging
import time
import cv2
import numpy as np
@@ -8,6 +7,7 @@ from PIL import Image
from minio import Minio
from app.core.config import *
from app.service.utils.oss_client import oss_get_image
from ..builder import PIPELINES
@@ -70,11 +70,7 @@ class LoadImageFromFile(object):
class LoadBodyImageFromFile(object):
def __init__(self, body_path):
self.body_path = body_path
self.minioClient = Minio(
f"{MINIO_URL}",
access_key=MINIO_ACCESS,
secret_key=MINIO_SECRET,
secure=MINIO_SECURE)
# self.minioClient = Minio(f"{MINIO_URL}", access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
# response = self.minioClient.get_object("aida-mannequins", "model_1693218345.2714431.png")
@@ -82,33 +78,33 @@ class LoadBodyImageFromFile(object):
def __call__(self, result):
result["image_url"] = result['body_path'] = self.body_path
result["name"] = "mannequin"
if not result['image_url'].lower().endswith(".png"):
logging.info(1)
bucket = self.body_path.split("/", 1)[0]
object_name = self.body_path.split("/", 1)[1]
new_object_name = f'{object_name[:object_name.rfind(".")]}.png'
image = self.minioClient.get_object(bucket, object_name)
image = Image.open(io.BytesIO(image.data))
image = image.convert("RGBA")
data = image.getdata()
#
new_data = []
for item in data:
if item[0] >= 230 and item[1] >= 230 and item[2] >= 230:
new_data.append((255, 255, 255, 0))
else:
new_data.append(item)
image.putdata(new_data)
image_data = io.BytesIO()
image.save(image_data, format='PNG')
image_data.seek(0)
image_bytes = image_data.read()
image_path = f"{bucket}/{self.minioClient.put_object(bucket, new_object_name, io.BytesIO(image_bytes), len(image_bytes), content_type='image/png').object_name}"
self.body_path = image_path
result["image_url"] = result['body_path'] = self.body_path
response = self.minioClient.get_object(self.body_path.split("/", 1)[0], self.body_path.split("/", 1)[1])
# if not result['image_url'].lower().endswith(".png"):
# bucket = self.body_path.split("/", 1)[0]
# object_name = self.body_path.split("/", 1)[1]
# new_object_name = f'{object_name[:object_name.rfind(".")]}.png'
# image = self.minioClient.get_object(bucket, object_name)
# image = Image.open(io.BytesIO(image.data))
# image = image.convert("RGBA")
# data = image.getdata()
# #
# new_data = []
# for item in data:
# if item[0] >= 230 and item[1] >= 230 and item[2] >= 230:
# new_data.append((255, 255, 255, 0))
# else:
# new_data.append(item)
# image.putdata(new_data)
# image_data = io.BytesIO()
# image.save(image_data, format='PNG')
# image_data.seek(0)
# image_bytes = image_data.read()
# image_path = f"{bucket}/{self.minioClient.put_object(bucket, new_object_name, io.BytesIO(image_bytes), len(image_bytes), content_type='image/png').object_name}"
# self.body_path = image_path
# result["image_url"] = result['body_path'] = self.body_path
# response = self.minioClient.get_object(self.body_path.split("/", 1)[0], self.body_path.split("/", 1)[1])
# put_image_time = time.time()
result['body_image'] = Image.open(io.BytesIO(response.read()))
# result['body_image'] = Image.open(io.BytesIO(response.read()))
result['body_image'] = oss_get_image(bucket=self.body_path.split("/", 1)[0], object_name=self.body_path.split("/", 1)[1], data_type="PIL")
# logging.info(f"Image.open time is : {time.time() - put_image_time}")
return result

View File

@@ -1,19 +1,16 @@
import random
from io import BytesIO
# import boto3
import cv2
import numpy as np
from PIL import Image
from minio import Minio
from app.core.config import *
from app.service.utils.oss_client import oss_get_image
from ..builder import PIPELINES
minio_client = Minio(
MINIO_URL,
access_key=MINIO_ACCESS,
secret_key=MINIO_SECRET,
secure=MINIO_SECURE)
# minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
# s3 = boto3.client('s3', aws_access_key_id=S3_ACCESS_KEY, aws_secret_access_key=S3_AWS_SECRET_ACCESS_KEY, region_name=S3_REGION_NAME)
@@ -56,17 +53,18 @@ class Painting(object):
@staticmethod
def get_gradient(bucket_name, object_name):
image_data = minio_client.get_object(bucket_name, object_name)
# image_data = minio_client.get_object(bucket_name, object_name)
# image_data = s3.get_object(Bucket=bucket_name, Key=object_name)['Body']
# 从数据流中读取图像
image_bytes = image_data.read()
# image_bytes = image_data.read()
# 将图像数据转换为numpy数组
image_array = np.asarray(bytearray(image_bytes), dtype=np.uint8)
# image_array = np.asarray(bytearray(image_bytes), dtype=np.uint8)
# 使用OpenCV解码图像数组
image = cv2.imdecode(image_array, cv2.IMREAD_COLOR)
# image = cv2.imdecode(image_array, cv2.IMREAD_COLOR)
image = oss_get_image(bucket=bucket_name, object_name=object_name, data_type="cv2")
return image
@staticmethod
@@ -494,16 +492,20 @@ class PrintPainting(object):
if not 'IfSingle' in print_dict.keys():
print_dict['IfSingle'] = False
data = minio_client.get_object(print_dict['print_path_list'][0].split("/", 1)[0], print_dict['print_path_list'][0].split("/", 1)[1])
# data = s3.get_object(Bucket=print_dict['print_path_list'][0].split("/", 1)[0], Key=print_dict['print_path_list'][0].split("/", 1)[1])['Body']
# data = minio_client.get_object(print_dict['print_path_list'][0].split("/", 1)[0], print_dict['print_path_list'][0].split("/", 1)[1])
# data_bytes = BytesIO(data.read())
# image = Image.open(data_bytes)
# image_mode = image.mode
data_bytes = BytesIO(data.read())
image = Image.open(data_bytes)
image_mode = image.mode
bucket_name = print_dict['print_path_list'][0].split("/", 1)[0]
object_name = print_dict['print_path_list'][0].split("/", 1)[1]
image = oss_get_image(bucket=bucket_name, object_name=object_name, data_type="cv2")
# 判断图片格式如果是RGBA 则贴在一张纯白图片上 防止透明转黑
if image_mode == "RGBA":
new_background = Image.new('RGB', image.size, (255, 255, 255))
new_background.paste(image, mask=image.split()[3])
if image.shape[2] == 4:
image_rgb = cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA)
image_pil = Image.fromarray(image_rgb)
new_background = Image.new('RGB', image_pil.size, (255, 255, 255))
new_background.paste(image_pil, mask=image.split()[3])
image = new_background
print_dict['image'] = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
@@ -577,21 +579,30 @@ class PrintPainting(object):
@staticmethod
def read_image(image_url):
data = minio_client.get_object(image_url.split("/", 1)[0], image_url.split("/", 1)[1])
# data = s3.get_object(Bucket=image_url.split("/", 1)[0], Key=image_url.split("/", 1)[1])['Body']
data_bytes = BytesIO(data.read())
image = Image.open(data_bytes)
image_mode = image.mode
# 判断图片格式如果是RGBA 则贴在一张纯白图片上 防止透明转黑
if image_mode == "RGBA":
# new_background = Image.new('RGB', image.size, (255, 255, 255))
# new_background.paste(image, mask=image.split()[3])
# image = new_background
return image, image_mode
image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
image = oss_get_image(bucket=image_url.split("/", 1)[0], object_name=image_url.split("/", 1)[1], data_type="cv2")
if image.shape[2] == 4:
image_rgb = cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA)
image = Image.fromarray(image_rgb)
image_mode = "RGBA"
else:
image_mode = "RGB"
return image, image_mode
# data = minio_client.get_object(image_url.split("/", 1)[0], image_url.split("/", 1)[1])
# # data = s3.get_object(Bucket=image_url.split("/", 1)[0], Key=image_url.split("/", 1)[1])['Body']
#
# data_bytes = BytesIO(data.read())
# image = Image.open(data_bytes)
# image_mode = image.mode
# # 判断图片格式如果是RGBA 则贴在一张纯白图片上 防止透明转黑
# if image_mode == "RGBA":
# # new_background = Image.new('RGB', image.size, (255, 255, 255))
# # new_background.paste(image, mask=image.split()[3])
# # image = new_background
# return image, image_mode
# image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
# return image, "RGB"
# @staticmethod
# def read_image(image_url):
# response = requests.get(image_url)

View File

@@ -41,7 +41,7 @@ class Split(object):
else:
back_mask = result['back_mask']
rgba_image = rgb_to_rgba((result['final_image'].shape[0], result['final_image'].shape[1]), result['final_image'], result['mask'])
rgba_image = rgb_to_rgba((result['final_image'].shape[0], result['final_image'].shape[1]), re4sult['final_image'], result['mask'])
result_front_image = np.zeros_like(rgba_image)
result_front_image[front_mask != 0] = rgba_image[front_mask != 0]

View File

@@ -13,15 +13,12 @@ import io
from app.core.config import *
from app.service.design.utils.design_ensemble import get_keypoint_result
from app.service.utils.oss_client import oss_get_image, oss_upload_image
class DesignPreprocessing:
def __init__(self):
self.minio_client = Minio(
MINIO_URL,
access_key=MINIO_ACCESS,
secret_key=MINIO_SECRET,
secure=MINIO_SECURE)
# def __init__(self):
# self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
# @ RunTime
def pipeline(self, image_list):
@@ -51,8 +48,9 @@ class DesignPreprocessing:
def read_image(self, image_list):
for obj in image_list:
file = self.minio_client.get_object(obj['image_url'].split("/", 1)[0], obj['image_url'].split("/", 1)[1]).data
image = cv2.imdecode(np.frombuffer(file, np.uint8), 1)
# file = self.minio_client.get_object(obj['image_url'].split("/", 1)[0], obj['image_url'].split("/", 1)[1]).data
# image = cv2.imdecode(np.frombuffer(file, np.uint8), 1)
image = oss_get_image(bucket=obj['image_url'].split("/", 1)[0], object_name=obj['image_url'].split("/", 1)[1], data_type="cv2")
if len(image.shape) == 2:
image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
elif image.shape[2] == 4: # 如果是四通道 mask
@@ -125,7 +123,10 @@ class DesignPreprocessing:
try:
# 覆盖到minio
image_bytes = cv2.imencode(".jpg", item['obj'])[1].tobytes()
self.minio_client.put_object(item['image_url'].split("/", 1)[0], item['image_url'].split("/", 1)[1], io.BytesIO(image_bytes), len(image_bytes), content_type="image/jpeg", )
# self.minio_client.put_object(item['image_url'].split("/", 1)[0], item['image_url'].split("/", 1)[1], io.BytesIO(image_bytes), len(image_bytes), content_type="image/jpeg", )
bucket_name = item['image_url'].split("/", 1)[0]
object_name = item['image_url'].split("/", 1)[1]
oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes)
print(f"Object '{item['image_url'].split('/', 1)[1]}' overwritten successfully.")
except ResponseError as err:
print(f"Error: {err}")
@@ -165,36 +166,76 @@ class DesignPreprocessing:
# @ RunTime
def composing_image(self, image_list):
for image in image_list:
if image['site'] == 'down':
image_width = image['obj'].shape[1]
waist_width = image['keypoint_result']['waistband_right'][1] - image['keypoint_result']['waistband_left'][1]
scale = 0.4
if waist_width / scale >= image['obj'].shape[1]:
add_width = int((waist_width / scale - image_width) / 2)
ret = cv2.copyMakeBorder(image['obj'], 0, 0, add_width, add_width, cv2.BORDER_CONSTANT, value=(256, 256, 256))
if IF_DEBUG_SHOW:
cv2.imshow("composing_image", ret)
cv2.waitKey(0)
image_bytes = cv2.imencode(".jpg", ret)[1].tobytes()
image['show_image_url'] = f"{image['image_url'].split('/', 1)[0]}/{self.minio_client.put_object(image['image_url'].split('/', 1)[0], image['image_url'].split('/', 1)[1].replace('.', '-show.'), io.BytesIO(image_bytes), len(image_bytes), content_type='image/jpeg').object_name}"
else:
image_bytes = cv2.imencode(".jpg", image['obj'])[1].tobytes()
image['show_image_url'] = f"{image['image_url'].split('/', 1)[0]}/{self.minio_client.put_object(image['image_url'].split('/', 1)[0], image['image_url'].split('/', 1)[1].replace('.', '-show.'), io.BytesIO(image_bytes), len(image_bytes), content_type='image/jpeg').object_name}"
''' 比例相同 整合上下装代码'''
image_width = image['obj'].shape[1]
waist_width = image['keypoint_result']['waistband_right'][1] - image['keypoint_result']['waistband_left'][1]
scale = 0.4
if waist_width / scale >= image_width:
add_width = int((waist_width / scale - image_width) / 2)
ret = cv2.copyMakeBorder(image['obj'], 0, 0, add_width, add_width, cv2.BORDER_CONSTANT, value=(256, 256, 256))
if IF_DEBUG_SHOW:
cv2.imshow("composing_image", ret)
cv2.waitKey(0)
image_bytes = cv2.imencode(".jpg", ret)[1].tobytes()
# image['show_image_url'] = f"{image['image_url'].split('/', 1)[0]}/{self.minio_client.put_object(image['image_url'].split('/', 1)[0], image['image_url'].split('/', 1)[1].replace('.', '-show.'), io.BytesIO(image_bytes), len(image_bytes), content_type='image/jpeg').object_name}"
bucket_name = image['image_url'].split('/', 1)[0]
object_name = image['image_url'].split('/', 1)[1]
oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes)
image['show_image_url'] = f"{bucket_name}/{object_name}"
else:
scale = 0.4
image_width = image['obj'].shape[1]
waist_width = image['keypoint_result']['armpit_right'][1] - image['keypoint_result']['armpit_left'][1]
if waist_width / scale >= image_width:
add_width = int((waist_width / scale - image_width) / 2)
ret = cv2.copyMakeBorder(image['obj'], 0, 0, add_width, add_width, cv2.BORDER_CONSTANT, value=(256, 256, 256))
if IF_DEBUG_SHOW:
cv2.imshow("composing_image", ret)
cv2.waitKey(0)
image_bytes = cv2.imencode(".jpg", ret)[1].tobytes()
image['show_image_url'] = f"{image['image_url'].split('/', 1)[0]}/{self.minio_client.put_object(image['image_url'].split('/', 1)[0], image['image_url'].split('/', 1)[1].replace('.', '-show.'), io.BytesIO(image_bytes), len(image_bytes), content_type='image/jpeg').object_name}"
else:
image_bytes = cv2.imencode(".jpg", image['obj'])[1].tobytes()
image['show_image_url'] = f"{image['image_url'].split('/', 1)[0]}/{self.minio_client.put_object(image['image_url'].split('/', 1)[0], image['image_url'].split('/', 1)[1].replace('.', '-show.'), io.BytesIO(image_bytes), len(image_bytes), content_type='image/jpeg').object_name}"
image_bytes = cv2.imencode(".jpg", image['obj'])[1].tobytes()
# image['show_image_url'] = f"{image['image_url'].split('/', 1)[0]}/{self.minio_client.put_object(image['image_url'].split('/', 1)[0], image['image_url'].split('/', 1)[1].replace('.', '-show.'), io.BytesIO(image_bytes), len(image_bytes), content_type='image/jpeg').object_name}"
bucket_name = image['image_url'].split('/', 1)[0]
object_name = image['image_url'].split('/', 1)[1]
oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes)
image['show_image_url'] = f"{bucket_name}/{object_name}"
# if image['site'] == 'down':
# image_width = image['obj'].shape[1]
# waist_width = image['keypoint_result']['waistband_right'][1] - image['keypoint_result']['waistband_left'][1]
# scale = 0.4
# if waist_width / scale >= image_width:
# add_width = int((waist_width / scale - image_width) / 2)
# ret = cv2.copyMakeBorder(image['obj'], 0, 0, add_width, add_width, cv2.BORDER_CONSTANT, value=(256, 256, 256))
# if IF_DEBUG_SHOW:
# cv2.imshow("composing_image", ret)
# cv2.waitKey(0)
# image_bytes = cv2.imencode(".jpg", ret)[1].tobytes()
# # image['show_image_url'] = f"{image['image_url'].split('/', 1)[0]}/{self.minio_client.put_object(image['image_url'].split('/', 1)[0], image['image_url'].split('/', 1)[1].replace('.', '-show.'), io.BytesIO(image_bytes), len(image_bytes), content_type='image/jpeg').object_name}"
# bucket_name = image['image_url'].split('/', 1)[0]
# object_name = image['image_url'].split('/', 1)[1]
# oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes)
# image['show_image_url'] = f"{bucket_name}/{object_name}"
# else:
# image_bytes = cv2.imencode(".jpg", image['obj'])[1].tobytes()
# # image['show_image_url'] = f"{image['image_url'].split('/', 1)[0]}/{self.minio_client.put_object(image['image_url'].split('/', 1)[0], image['image_url'].split('/', 1)[1].replace('.', '-show.'), io.BytesIO(image_bytes), len(image_bytes), content_type='image/jpeg').object_name}"
# bucket_name = image['image_url'].split('/', 1)[0]
# object_name = image['image_url'].split('/', 1)[1]
# oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes)
# image['show_image_url'] = f"{bucket_name}/{object_name}"
# else:
# image_width = image['obj'].shape[1]
# waist_width = image['keypoint_result']['waistband_right'][1] - image['keypoint_result']['waistband_left'][1]
# scale = 0.4
# if waist_width / scale >= image_width:
# add_width = int((waist_width / scale - image_width) / 2)
# ret = cv2.copyMakeBorder(image['obj'], 0, 0, add_width, add_width, cv2.BORDER_CONSTANT, value=(256, 256, 256))
# if IF_DEBUG_SHOW:
# cv2.imshow("composing_image", ret)
# cv2.waitKey(0)
# image_bytes = cv2.imencode(".jpg", ret)[1].tobytes()
# # image['show_image_url'] = f"{image['image_url'].split('/', 1)[0]}/{self.minio_client.put_object(image['image_url'].split('/', 1)[0], image['image_url'].split('/', 1)[1].replace('.', '-show.'), io.BytesIO(image_bytes), len(image_bytes), content_type='image/jpeg').object_name}"
# bucket_name = image['image_url'].split('/', 1)[0]
# object_name = image['image_url'].split('/', 1)[1]
# oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes)
# image['show_image_url'] = f"{bucket_name}/{object_name}"
# else:
# image_bytes = cv2.imencode(".jpg", image['obj'])[1].tobytes()
# # image['show_image_url'] = f"{image['image_url'].split('/', 1)[0]}/{self.minio_client.put_object(image['image_url'].split('/', 1)[0], image['image_url'].split('/', 1)[1].replace('.', '-show.'), io.BytesIO(image_bytes), len(image_bytes), content_type='image/jpeg').object_name}"
# bucket_name = image['image_url'].split('/', 1)[0]
# object_name = image['image_url'].split('/', 1)[1]
# oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes)
# image['show_image_url'] = f"{bucket_name}/{object_name}"
return image_list
@staticmethod

View File

@@ -10,21 +10,17 @@
import json
import logging
import time
from io import BytesIO
import cv2
import minio
import redis
import tritonclient.grpc as grpcclient
import numpy as np
from minio import Minio
from tritonclient.utils import np_to_triton_dtype
from app.core.config import *
from app.schemas.generate_image import GenerateImageModel
from app.service.generate_image.utils.adjust_contrast import adjust_contrast
from app.service.generate_image.utils.image_processing import remove_background, stain_detection, generate_category_recognition, autoLevels, luminance_adjust, face_detect_pic
from app.service.generate_image.utils.upload_sd_image import upload_png_sd, upload_stain_png_sd
from app.service.generate_image.utils.upload_sd_image import upload_png_sd
from app.service.utils.oss_client import oss_get_image
logger = logging.getLogger()
@@ -36,7 +32,7 @@ class GenerateImage:
self.channel = self.connection.channel()
# self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
# self.channel = self.connection.channel()
self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
# self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
self.grpc_client = grpcclient.InferenceServerClient(url=GI_MODEL_URL)
self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
if request_data.mode == "img2img":
@@ -63,10 +59,13 @@ class GenerateImage:
# Read data from response.
# read image use cv2
try:
response = self.minio_client.get_object(image_url.split('/')[0], image_url[image_url.find('/') + 1:])
image_file = BytesIO(response.data)
image_array = np.asarray(bytearray(image_file.read()), dtype=np.uint8)
image_cv2 = cv2.imdecode(image_array, cv2.IMREAD_COLOR)
# response = self.minio_client.get_object(image_url.split('/')[0], image_url[image_url.find('/') + 1:])
# image_file = BytesIO(response.data)
# image_array = np.asarray(bytearray(image_file.read()), dtype=np.uint8)
# image_cv2 = cv2.imdecode(image_array, cv2.IMREAD_COLOR)
# image_rbg = cv2.cvtColor(image_cv2, cv2.COLOR_BGR2RGB)
image_cv2 = oss_get_image(bucket=image_url.split('/')[0], object_name=image_url, data_type="cv2")
image_rbg = cv2.cvtColor(image_cv2, cv2.COLOR_BGR2RGB)
image = cv2.resize(image_rbg, (1024, 1024))
except minio.error.S3Error:
@@ -189,7 +188,8 @@ if __name__ == '__main__':
prompt='skeleton sitting by the side of a river looking soulful, concert poster, 4k, artistic',
image_url="",
mode='txt2img',
category="test"
category="test",
gender="male"
)
server = GenerateImage(rd)
print(server.get_result())
print(server.get_result())

View File

@@ -18,10 +18,10 @@ import numpy as np
from PIL import Image, ImageOps
from minio import Minio
from tritonclient.utils import np_to_triton_dtype
from app.core.config import *
from app.schemas.generate_image import GenerateImageModel
from app.schemas.generate_image import GenerateProductImageModel
from app.service.generate_image.utils.upload_sd_image import upload_SDXL_image
from app.service.utils.oss_client import oss_get_image
logger = logging.getLogger()
@@ -33,69 +33,29 @@ class GenerateProductImage:
self.channel = self.connection.channel()
# self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
# self.channel = self.connection.channel()
self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
# self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
self.grpc_client = grpcclient.InferenceServerClient(url=GPI_MODEL_URL)
self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
self.category = "product_image"
self.batch_size = 1
self.prompt = request_data.prompt
# TODO aida design 结果图背景改为白色
self.image, self.image_size = self.get_image(request_data.image_url)
# TODO image 填充并resize成512*768
self.image, self.image_size = pre_processing_image(request_data.image_url)
self.tasks_id = request_data.tasks_id
self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:]
self.gen_product_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'image_url': ''}
self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data))
self.redis_client.expire(self.tasks_id, 600)
def get_image(self, image_url):
response = self.minio_client.get_object(image_url.split('/')[0], image_url[image_url.find('/') + 1:])
image_bytes = io.BytesIO(response.read())
# 转换为PIL图像对象
image = Image.open(image_bytes)
target_height = 768
target_width = 512
aspect_ratio = image.width / image.height
new_width = int(target_height * aspect_ratio)
resized_image = image.resize((new_width, target_height))
left = (target_width - resized_image.width) // 2
top = (target_height - resized_image.height) // 2
right = target_width - resized_image.width - left
bottom = target_height - resized_image.height - top
image = ImageOps.expand(resized_image, (left, top, right, bottom), fill="white")
image_size = image.size
if image.mode in ('RGBA', 'LA') or (image.mode == 'P' and 'transparency' in image.info):
# 创建白色背景
background = Image.new("RGB", image.size, (255, 255, 255))
# 将图片粘贴到白色背景上
background.paste(image, mask=image.split()[3])
image = np.array(background)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# image_file = BytesIO(response.data)
# image_array = np.asarray(bytearray(image_file.read()), dtype=np.uint8)
# image_cv2 = cv2.imdecode(image_array, cv2.IMREAD_COLOR)
# image = cv2.cvtColor(image_cv2, cv2.COLOR_BGR2RGB)
# image = cv2.resize(image_rbg, (1024, 1024))
return image, image_size
def callback(self, result, error):
if error:
self.gen_product_data['status'] = "FAILURE"
self.gen_product_data['message'] = str(error)
# self.gen_product_data['data'] = str(error)
self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data))
else:
# pil图像转成numpy数组
image = result.as_numpy("generated_inpaint_image")
image_result = Image.fromarray(np.squeeze(image.astype(np.uint8))).resize(self.image_size)
image_url = upload_SDXL_image(image_result, user_id=self.user_id, category=f"{self.category}", object_name=f"{self.tasks_id}.png")
# logger.info(f"upload image SUCCESS {image_url}")
image_url = upload_SDXL_image(image_result, user_id=self.user_id, category=f"{self.category}", file_name=f"{self.tasks_id}.png")
self.gen_product_data['status'] = "SUCCESS"
self.gen_product_data['message'] = "success"
self.gen_product_data['image_url'] = str(image_url)
@@ -105,13 +65,6 @@ class GenerateProductImage:
status_data = self.redis_client.get(self.tasks_id)
return json.loads(status_data), status_data
def infer(self, inputs):
return self.grpc_client.async_infer(
model_name=GPI_MODEL_NAME,
inputs=inputs,
callback=self.callback
)
def get_result(self):
try:
prompts = [self.prompt] * self.batch_size
@@ -129,11 +82,10 @@ class GenerateProductImage:
input_image.set_data_from_numpy(image_obj)
inputs = [input_text, input_image]
ctx = self.infer(inputs)
ctx = self.grpc_client.async_infer(model_name=GPI_MODEL_NAME, inputs=inputs, callback=self.callback)
time_out = 600
while time_out > 0:
gen_product_data, _ = self.read_tasks_status()
# logger.info(gen_product_data)
if gen_product_data['status'] in ["REVOKED", "FAILURE"]:
ctx.cancel()
break
@@ -141,7 +93,6 @@ class GenerateProductImage:
break
time_out -= 1
time.sleep(0.1)
# logger.info(time_out, gen_product_data)
gen_product_data, _ = self.read_tasks_status()
return gen_product_data
except Exception as e:
@@ -153,7 +104,6 @@ class GenerateProductImage:
dict_gen_product_data, str_gen_product_data = self.read_tasks_status()
if DEBUG is False:
self.channel.basic_publish(exchange='', routing_key=GPI_RABBITMQ_QUEUES, body=str_gen_product_data)
# self.channel.basic_publish(exchange='', routing_key=GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES, body=str_gen_product_data)
logger.info(f" [x] Sent to {GPI_RABBITMQ_QUEUES} data@@@@ {json.dumps(dict_gen_product_data, indent=4)}")
@@ -165,11 +115,37 @@ def infer_cancel(tasks_id):
return data
def pre_processing_image(image_url):
    """Fetch an image from OSS and letterbox it onto a 512x768 white canvas.

    ``image_url`` is "<bucket>/<object_name>": the first path segment is the
    bucket name and the remainder is the object key.

    Returns:
        (image, image_size): ``image_size`` is the (width, height) of the
        padded canvas. NOTE(review): ``image`` is a BGR numpy array only when
        the source had transparency (RGBA/LA/paletted-with-transparency);
        otherwise the padded PIL image itself is returned — confirm
        downstream callers accept both types.
    """
    image = oss_get_image(bucket=image_url.split('/')[0], object_name=image_url[image_url.find('/') + 1:], data_type="PIL")
    # Resize to the target height and paste centered onto a 768x512
    # pure-white canvas (original comment, translated).
    target_height = 768
    target_width = 512
    aspect_ratio = image.width / image.height
    new_width = int(target_height * aspect_ratio)
    resized_image = image.resize((new_width, target_height))
    # Symmetric border widths. NOTE(review): these go negative when the
    # resized image is wider than 512px (aspect ratio > 512/768);
    # ImageOps.expand would then misbehave — confirm inputs are always
    # portrait enough.
    left = (target_width - resized_image.width) // 2
    top = (target_height - resized_image.height) // 2
    right = target_width - resized_image.width - left
    bottom = target_height - resized_image.height - top
    image = ImageOps.expand(resized_image, (left, top, right, bottom), fill="white")
    image_size = image.size
    if image.mode in ('RGBA', 'LA') or (image.mode == 'P' and 'transparency' in image.info):
        # Create a white background
        background = Image.new("RGB", image.size, (255, 255, 255))
        # Paste the image onto the white background, using its alpha as mask
        background.paste(image, mask=image.split()[3])
        image = np.array(background)
        # NOTE(review): background is RGB at this point, so BGR2RGB here
        # actually yields BGR channel order — confirm the model expects that.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    return image, image_size
if __name__ == '__main__':
rd = GenerateImageModel(
rd = GenerateProductImageModel(
tasks_id="123-89",
prompt="best quality, masterpiece. detailed, high-res, simple background, studio photography, extremely detailed, updo, detailed face, face, close-up, HDR, UHD, 8K realistic, Highly detailed, simple background, Studio lighting",
image_url="aida-results/result_067f2f7e-21ba-11ef-8cf5-0242ac170002.png",
prompt="",
# prompt=" the best quality, masterpiece. detailed, high-res, simple background, studio photography, extremely detailed, updo, detailed face, face, close-up, HDR, UHD, 8K realistic, Highly detailed, simple background, Studio lighting",
image_url="aida-results/result_00097282-ebb2-11ee-a822-b48351119060.png",
)
server = GenerateProductImage(rd)
print(server.get_result())

View File

@@ -22,6 +22,7 @@ from tritonclient.utils import np_to_triton_dtype
from app.core.config import *
from app.schemas.generate_image import GenerateRelightImageModel
from app.service.generate_image.utils.upload_sd_image import upload_SDXL_image
from app.service.utils.oss_client import oss_get_image
logger = logging.getLogger()
@@ -31,71 +32,34 @@ class GenerateRelightImage:
if DEBUG is False:
self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
self.channel = self.connection.channel()
self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
# self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
self.grpc_client = grpcclient.InferenceServerClient(url=GRI_MODEL_URL)
self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
self.category = "relight_image"
self.batch_size = 1
self.prompt = request_data.prompt
self.seed = "12345"
# TODO aida design 结果图背景改为白色
# self.image, self.image_size = self.get_image(request_data.image_url)
self.image = request_data.image_url
# TODO image 填充并resize成512*768
self.seed = "1"
self.negative_prompt = 'lowres, bad anatomy, bad hands, cropped, worst quality'
self.direction = "Right Light"
self.image_url = request_data.image_url
self.image = oss_get_image(bucket=self.image_url.split('/')[0], object_name=self.image_url[self.image_url.find('/') + 1:], data_type="cv2")
self.tasks_id = request_data.tasks_id
self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:]
self.gen_product_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'image_url': ''}
self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data))
self.redis_client.expire(self.tasks_id, 600)
def get_image(self, image_url):
response = self.minio_client.get_object(image_url.split('/')[0], image_url[image_url.find('/') + 1:])
image_bytes = io.BytesIO(response.read())
# 转换为PIL图像对象
image = Image.open(image_bytes)
target_height = 768
target_width = 512
aspect_ratio = image.width / image.height
new_width = int(target_height * aspect_ratio)
resized_image = image.resize((new_width, target_height))
left = (target_width - resized_image.width) // 2
top = (target_height - resized_image.height) // 2
right = target_width - resized_image.width - left
bottom = target_height - resized_image.height - top
image = ImageOps.expand(resized_image, (left, top, right, bottom), fill="white")
image_size = image.size
if image.mode in ('RGBA', 'LA') or (image.mode == 'P' and 'transparency' in image.info):
# 创建白色背景
background = Image.new("RGB", image.size, (255, 255, 255))
# 将图片粘贴到白色背景上
background.paste(image, mask=image.split()[3])
image = np.array(background)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# image_file = BytesIO(response.data)
# image_array = np.asarray(bytearray(image_file.read()), dtype=np.uint8)
# image_cv2 = cv2.imdecode(image_array, cv2.IMREAD_COLOR)
# image = cv2.cvtColor(image_cv2, cv2.COLOR_BGR2RGB)
# image = cv2.resize(image_rbg, (1024, 1024))
return image, image_size
def callback(self, result, error):
if error:
self.gen_product_data['status'] = "FAILURE"
self.gen_product_data['message'] = str(error)
# self.gen_product_data['data'] = str(error)
self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data))
else:
# pil图像转成numpy数组
image = result.as_numpy("generated_inpaint_image")
image_result = Image.fromarray(np.squeeze(image.astype(np.uint8))).resize(self.image_size)
image_result = Image.fromarray(np.squeeze(image.astype(np.uint8)))
image_url = upload_SDXL_image(image_result, user_id=self.user_id, category=f"{self.category}", object_name=f"{self.tasks_id}.png")
# logger.info(f"upload image SUCCESS {image_url}")
image_url = upload_SDXL_image(image_result, user_id=self.user_id, category=f"{self.category}", file_name=f"{self.tasks_id}.png")
self.gen_product_data['status'] = "SUCCESS"
self.gen_product_data['message'] = "success"
self.gen_product_data['image_url'] = str(image_url)
@@ -105,62 +69,40 @@ class GenerateRelightImage:
status_data = self.redis_client.get(self.tasks_id)
return json.loads(status_data), status_data
def infer(self, inputs):
return self.grpc_client.async_infer(
model_name=GRI_MODEL_NAME,
inputs=inputs,
callback=self.callback
)
def get_result(self):
try:
direction = "Right Light"
negative_prompt = 'lowres, bad anatomy, bad hands, cropped, worst quality'
self.prompt = 'beautiful woman, detailed face, sunshine, outdoor, warm atmosphere'
prompts = [self.prompt] * self.batch_size
text_obj = np.array(prompts, dtype="object").reshape((-1, 1))
input_text = grpcclient.InferInput(
"prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype)
)
image = cv2.cvtColor(self.image, cv2.COLOR_BGR2RGB)
image = cv2.resize(image, (512, 768))
images = [image.astype(np.uint8)] * self.batch_size
seeds = [self.seed] * self.batch_size
nagetive_prompts = [self.negative_prompt] * self.batch_size
directions = [self.direction] * self.batch_size
text_obj = np.array(prompts, dtype="object").reshape((1))
image_obj = np.array(images, dtype=np.uint8).reshape((768, 512, 3))
na_text_obj = np.array(nagetive_prompts, dtype="object").reshape((1))
seed_obj = np.array(seeds, dtype="object").reshape((1))
direction_obj = np.array(directions, dtype="object").reshape((1))
input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype))
input_image = grpcclient.InferInput("input_image", image_obj.shape, "UINT8")
input_natext = grpcclient.InferInput("negative_prompt", na_text_obj.shape, np_to_triton_dtype(na_text_obj.dtype))
input_seed = grpcclient.InferInput("seed", seed_obj.shape, np_to_triton_dtype(seed_obj.dtype))
input_direction = grpcclient.InferInput("direction", direction_obj.shape, np_to_triton_dtype(direction_obj.dtype))
input_text.set_data_from_numpy(text_obj)
input_image.set_data_from_numpy(image_obj)
input_natext.set_data_from_numpy(na_text_obj)
input_seed.set_data_from_numpy(seed_obj)
input_direction.set_data_from_numpy(direction_obj)
negative_prompts = [negative_prompt] * self.batch_size
text_obj_neg = np.array(negative_prompts, dtype="object").reshape((-1, 1))
input_text_neg = grpcclient.InferInput(
"negative_prompt", text_obj_neg.shape, np_to_triton_dtype(text_obj_neg.dtype)
)
input_text_neg.set_data_from_numpy(text_obj_neg)
inputs = [input_text, input_natext, input_image, input_seed, input_direction]
seed = np.array(self.seed, dtype="object").reshape((-1, 1))
input_seed = grpcclient.InferInput(
"seed", seed.shape, np_to_triton_dtype(seed.dtype)
)
input_seed.set_data_from_numpy(seed)
input_images = [self.image] * self.batch_size
text_obj_images = np.array(input_images, dtype="object").reshape((-1, 1))
input_input_images = grpcclient.InferInput(
"input_image", text_obj_images.shape, np_to_triton_dtype(text_obj_images.dtype)
)
input_input_images.set_data_from_numpy(text_obj_images)
directions = [direction] * self.batch_size
text_obj_directions = np.array(directions, dtype="object").reshape((-1, 1))
input_directions = grpcclient.InferInput(
"direction", text_obj_directions.shape, np_to_triton_dtype(text_obj_directions.dtype)
)
input_directions.set_data_from_numpy(text_obj_directions)
output_img = grpcclient.InferRequestedOutput("generated_image")
request_start = time.time()
inputs = [input_text, input_text_neg, input_input_images, input_seed, input_directions]
ctx = self.infer(inputs)
ctx = self.grpc_client.async_infer(model_name=GRI_MODEL_NAME, inputs=inputs, callback=self.callback)
time_out = 600
while time_out > 0:
gen_product_data, _ = self.read_tasks_status()
# logger.info(gen_product_data)
if gen_product_data['status'] in ["REVOKED", "FAILURE"]:
ctx.cancel()
break
@@ -168,7 +110,6 @@ class GenerateRelightImage:
break
time_out -= 1
time.sleep(0.1)
# logger.info(time_out, gen_product_data)
gen_product_data, _ = self.read_tasks_status()
return gen_product_data
except Exception as e:
@@ -179,9 +120,8 @@ class GenerateRelightImage:
finally:
dict_gen_product_data, str_gen_product_data = self.read_tasks_status()
if DEBUG is False:
self.channel.basic_publish(exchange='', routing_key=GPI_RABBITMQ_QUEUES, body=str_gen_product_data)
# self.channel.basic_publish(exchange='', routing_key=GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES, body=str_gen_product_data)
logger.info(f" [x] Sent to {GPI_RABBITMQ_QUEUES} data@@@@ {json.dumps(dict_gen_product_data, indent=4)}")
self.channel.basic_publish(exchange='', routing_key=GRI_RABBITMQ_QUEUES, body=str_gen_product_data)
logger.info(f" [x] Sent to {GRI_RABBITMQ_QUEUES} data@@@@ {json.dumps(dict_gen_product_data, indent=4)}")
def infer_cancel(tasks_id):
@@ -195,8 +135,9 @@ def infer_cancel(tasks_id):
if __name__ == '__main__':
rd = GenerateRelightImageModel(
tasks_id="123-89",
prompt="beautiful woman, detailed face, sunshine, outdoor, warm atmosphere",
image_url="/workspace/i3.png",
# prompt="beautiful woman, detailed face, sunshine, outdoor, warm atmosphere",
prompt="Colorful black",
image_url='aida-users/89/product_image/123-89.png'
)
server = GenerateRelightImage(rd)
print(server.get_result())

View File

@@ -31,8 +31,6 @@ class GenerateSingleLogoImage:
if DEBUG is False:
self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
self.channel = self.connection.channel()
# self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
# self.channel = self.connection.channel()
self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
self.grpc_client = grpcclient.InferenceServerClient(url=GSL_MODEL_URL)
self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
@@ -51,23 +49,15 @@ class GenerateSingleLogoImage:
status_data = self.redis_client.get(self.tasks_id)
return json.loads(status_data), status_data
def infer(self, inputs):
return self.grpc_client.async_infer(
model_name=GSL_MODEL_NAME,
inputs=inputs,
callback=self.callback
)
def callback(self, result, error):
if error:
self.gen_single_logo_data['status'] = "FAILURE"
self.gen_single_logo_data['message'] = str(error)
# self.generate_data['data'] = str(error)
self.redis_client.set(self.tasks_id, json.dumps(self.gen_single_logo_data))
else:
image = result.as_numpy("generated_image")
image_result = Image.fromarray(np.squeeze(image.astype(np.uint8)))
image_url = upload_SDXL_image(image_result, user_id=self.user_id, category=f"{self.category}", object_name=f"{self.tasks_id}.png")
image_url = upload_SDXL_image(image_result, user_id=self.user_id, category=f"{self.category}", file_name=f"{self.tasks_id}.png")
self.gen_single_logo_data['status'] = "SUCCESS"
self.gen_single_logo_data['message'] = "success"
self.gen_single_logo_data['image_url'] = str(image_url)
@@ -81,25 +71,19 @@ class GenerateSingleLogoImage:
input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype))
input_text.set_data_from_numpy(text_obj)
# negative_prompts
text_obj_neg = np.array(self.negative_prompts, dtype="object").reshape((-1, 1))
# print('text obj neg: ', text_obj_neg)
input_text_neg = grpcclient.InferInput("negative_prompt", text_obj_neg.shape, np_to_triton_dtype(text_obj_neg.dtype))
input_text_neg.set_data_from_numpy(text_obj_neg)
# seed
seed = np.array(self.seed, dtype="object").reshape((-1, 1))
input_seed = grpcclient.InferInput("seed", seed.shape, np_to_triton_dtype(seed.dtype))
input_seed.set_data_from_numpy(seed)
inputs = [input_text, input_text_neg, input_seed]
ctx = self.infer(inputs)
ctx = self.grpc_client.async_infer(model_name=GSL_MODEL_NAME, inputs=inputs, callback=self.callback)
time_out = 600
generate_data = None
while time_out > 0:
generate_data, _ = self.read_tasks_status()
# logger.info(generate_data)
if generate_data['status'] in ["REVOKED", "FAILURE"]:
ctx.cancel()
break
@@ -107,7 +91,6 @@ class GenerateSingleLogoImage:
break
time_out -= 1
time.sleep(0.1)
# logger.info(time_out, generate_data)
return generate_data
except Exception as e:
raise Exception(str(e))
@@ -115,7 +98,6 @@ class GenerateSingleLogoImage:
dict_generate_data, str_generate_data = self.read_tasks_status()
if DEBUG is False:
self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=str_generate_data)
# self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=str_generate_data)
logger.info(f" [x] Sent {json.dumps(dict_generate_data, indent=4)}")

View File

@@ -16,8 +16,11 @@ from PIL import Image
from minio import Minio
from app.core.config import *
from app.service.utils.oss_client import oss_upload_image
minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
# s3 = boto3.client('s3', aws_access_key_id=S3_ACCESS_KEY, aws_secret_access_key=S3_AWS_SECRET_ACCESS_KEY, region_name=S3_REGION_NAME)
@@ -34,36 +37,34 @@ minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET
# except Exception as e:
# print(f'上传到 S3 失败: {e}')
def upload_SDXL_image(image, user_id, category, object_name):
def upload_SDXL_image(image, user_id, category, file_name):
try:
image_data = io.BytesIO()
image.save(image_data, format='PNG')
image_data.seek(0)
image_bytes = image_data.read()
minio_req = minio_client.put_object(
GI_MINIO_BUCKET,
f'{user_id}/{category}/{object_name}',
io.BytesIO(image_bytes),
len(image_bytes),
content_type='image/jpeg'
)
image_url = f"aida-users/{minio_req.object_name}"
# minio_req = minio_client.put_object(
# GI_MINIO_BUCKET,
# f'{user_id}/{category}/{file_name}',
# io.BytesIO(image_bytes),
# len(image_bytes),
# content_type='image/jpeg'
# )
object_name = f'{user_id}/{category}/{file_name}'
req = oss_upload_image(bucket=GI_MINIO_BUCKET, object_name=object_name, image_bytes=image_bytes)
image_url = f"aida-users/{object_name}"
return image_url
except Exception as e:
logging.warning(f"upload_png_mask runtime exception : {e}")
def upload_png_sd(image, user_id, category, object_name):
def upload_png_sd(image, user_id, category, file_name):
try:
_, img_byte_array = cv2.imencode('.jpg', image)
minio_req = minio_client.put_object(
GI_MINIO_BUCKET,
f'{user_id}/{category}/{object_name}',
io.BytesIO(img_byte_array),
len(img_byte_array),
content_type='image/jpeg'
)
image_url = f"aida-users/{minio_req.object_name}"
object_name = f'{user_id}/{category}/{file_name}'
req = oss_upload_image(bucket=GI_MINIO_BUCKET, object_name=object_name, image_bytes=img_byte_array)
image_url = f"aida-users/{object_name}"
return image_url
except Exception as e:
logging.warning(f"upload_png_mask runtime exception : {e}")

View File

@@ -1,17 +1,15 @@
import io
import json
import logging
import time
import minio.error
import redis
import json
import cv2
import minio.error
import numpy as np
import redis
import torch
import tritonclient.grpc as grpcclient
from minio import Minio
from app.core.config import *
from app.schemas.super_resolution import SuperResolutionModel
from app.service.utils.decorator import RunTime
from app.service.utils.oss_client import oss_get_image, oss_upload_image
logger = logging.getLogger()
@@ -24,7 +22,7 @@ class SuperResolution:
self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:]
self.sr_image_url = data.sr_image_url
self.sr_xn = data.sr_xn
self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
# self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
self.redis_client.set(self.tasks_id, json.dumps({'status': 'PENDING', 'message': "pending", 'data': ''}))
self.redis_client.expire(self.tasks_id, 600)
self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
@@ -33,16 +31,25 @@ class SuperResolution:
# @RunTime
def read_image(self):
try:
image_data = self.minio_client.get_object(self.sr_image_url.split("/", 1)[0], self.sr_image_url.split("/", 1)[1])
img = oss_get_image(bucket=self.sr_image_url.split("/", 1)[0], object_name=self.sr_image_url.split("/", 1)[1], data_type="cv2")
except minio.error.S3Error as e:
sr_data = json.dumps({'tasks_id': self.tasks_id, 'status': 'ERROR', 'message': f'{e}'})
self.channel.basic_publish(exchange='', routing_key=SR_RABBITMQ_QUEUES, body=sr_data)
logger.info(f" [x] Sent {sr_data}")
raise FileNotFoundError(f"Image '{self.sr_image_url.split('/', 1)[1]}' not found in bucket '{self.sr_image_url.split('/', 1)[0]}'")
img = np.frombuffer(image_data.data, np.uint8) # 转成8位无符号整型
img = cv2.imdecode(img, cv2.IMREAD_COLOR).astype(np.float32) / 255. # 解码
return img
# try:
# image_data = self.minio_client.get_object(self.sr_image_url.split("/", 1)[0], self.sr_image_url.split("/", 1)[1])
# except minio.error.S3Error as e:
# sr_data = json.dumps({'tasks_id': self.tasks_id, 'status': 'ERROR', 'message': f'{e}'})
# self.channel.basic_publish(exchange='', routing_key=SR_RABBITMQ_QUEUES, body=sr_data)
# logger.info(f" [x] Sent {sr_data}")
# raise FileNotFoundError(f"Image '{self.sr_image_url.split('/', 1)[1]}' not found in bucket '{self.sr_image_url.split('/', 1)[0]}'")
# img = np.frombuffer(image_data.data, np.uint8) # 转成8位无符号整型
# img = cv2.imdecode(img, cv2.IMREAD_COLOR).astype(np.float32) / 255. # 解码
# return img
def read_tasks_status(self):
status_data = json.loads(self.redis_client.get(self.tasks_id))
logging.info(f"{self.tasks_id} ===> {status_data}")
@@ -101,8 +108,10 @@ class SuperResolution:
def upload_img_sr(self, image):
try:
image_bytes = cv2.imencode('.jpg', image)[1].tobytes()
res = self.minio_client.put_object(f'{SR_MINIO_BUCKET}', f'{self.user_id}/sr/output/{self.tasks_id}.jpg', io.BytesIO(image_bytes), len(image_bytes), content_type='image/png')
image_url = f"aida-users/{res.object_name}"
# res = self.minio_client.put_object(f'{SR_MINIO_BUCKET}', f'{self.user_id}/sr/output/{self.tasks_id}.jpg', io.BytesIO(image_bytes), len(image_bytes), content_type='image/png')
object_name = f'{self.user_id}/sr/output/{self.tasks_id}.jpg'
oss_upload_image(bucket=SR_MINIO_BUCKET, object_name=object_name, image_bytes=image_bytes)
image_url = f"{SR_MINIO_BUCKET}/{object_name}"
return image_url
except Exception as e:
logger.warning(f"upload_png_mask runtime exception : {e}")

View File

@@ -0,0 +1,70 @@
import io
import logging
from io import BytesIO
import boto3
import cv2
import numpy as np
from PIL import Image
from minio import Minio
from app.core.config import *
logger = logging.getLogger()
# 获取图片
def oss_get_image(bucket, object_name, data_type):
    """Download *object_name* from *bucket* on the configured OSS backend.

    ``data_type == "cv2"`` decodes with every channel preserved (including
    alpha) into a numpy array; any other value yields a PIL image.
    Returns ``None`` when the download or decode fails (error is logged).
    """
    result = None
    try:
        # Select the storage backend from the OSS config flag.
        if OSS == "minio":
            client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
            payload = client.get_object(bucket_name=bucket, object_name=object_name)
        else:
            client = boto3.client('s3', aws_access_key_id=S3_ACCESS_KEY, aws_secret_access_key=S3_AWS_SECRET_ACCESS_KEY, region_name=S3_REGION_NAME)
            payload = client.get_object(Bucket=bucket, Key=object_name)['Body']
        raw = payload.read()
        if data_type == "cv2":
            # 8-bit unsigned buffer, decoded without dropping any channel.
            result = cv2.imdecode(np.frombuffer(raw, np.uint8), cv2.IMREAD_UNCHANGED)
        else:
            result = Image.open(BytesIO(raw))
    except Exception as e:
        logger.warning(f"{OSS} | 获取图片出现异常 ######: {e}")
    return result
def oss_upload_image(bucket, object_name, image_bytes):
    """Upload raw *image_bytes* to the configured OSS backend.

    Returns the backend's put-object response, or ``None`` on failure
    (the exception is only logged, never re-raised).
    """
    req = None
    try:
        if OSS == "minio":
            oss_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
            # NOTE(review): content type is always 'image/png', even when
            # callers pass JPEG-encoded bytes — confirm this is intended.
            req = oss_client.put_object(bucket_name=bucket, object_name=object_name, data=io.BytesIO(image_bytes), length=len(image_bytes), content_type='image/png')
        else:
            oss_client = boto3.client('s3', aws_access_key_id=S3_ACCESS_KEY, aws_secret_access_key=S3_AWS_SECRET_ACCESS_KEY, region_name=S3_REGION_NAME)
            # NOTE(review): the S3 branch ignores the *bucket* argument and
            # always writes to AIDA_CLOTHING — confirm this single-bucket
            # behavior is deliberate and not a copy-paste bug.
            req = oss_client.put_object(Bucket=AIDA_CLOTHING, Key=object_name, Body=image_bytes, ContentType='image/png')
    except Exception as e:
        logger.warning(f"{OSS} | 上传图片出现异常 ######: {e}")
    return req
if __name__ == '__main__':
    # Manual smoke test: fetch one sample object and display it on screen.
    # Other sample object keys kept below for quick switching:
    # url = "aida-results/result_0002186a-e631-11ee-86a6-b48351119060.png"
    # url = "aida-collection-element/11523/Moodboard/f60af0d2-94c2-48f9-90ff-74b8e8a481b5.jpg"
    # url = "aida-sys-image/images/female/outwear/0628000054.jpg"
    # url = "aida-users/89/product_image/string-89.png"
    # url = 'aida-users/89/relight_image/123-89.png'
    # url = "aida-users/89/sketchboard/female/Dress/e6724ab7-8d3f-4677-abe0-c3e42ab7af85.jpeg"
    url = "aida-users/89/single_logo/123-89.png"
    read_type = "cv2"
    img = oss_get_image(bucket=url.split('/')[0], object_name=url[url.find('/') + 1:], data_type=read_type)
    if read_type == "cv2":
        cv2.imshow("", img)
        cv2.waitKey(0)
    else:
        img.show()