feat(新功能):
fix(修复bug): relight 图片尺寸自适应 docs(文档变更): refactor(重构): test(增加测试):
This commit is contained in:
@@ -4,7 +4,7 @@ import logging
|
|||||||
from fastapi import APIRouter
|
from fastapi import APIRouter
|
||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
|
|
||||||
from app.core.config import SR_RABBITMQ_QUEUES, GI_RABBITMQ_QUEUES, GPI_RABBITMQ_QUEUES, GRI_RABBITMQ_QUEUES, OSS, JAVA_STREAM_API_URL, GMV_RABBITMQ_QUEUES, SLOGAN_RABBITMQ_QUEUES, GEN_SINGLE_LOGO_RABBITMQ_QUEUES
|
from app.core.config import SR_RABBITMQ_QUEUES, GI_RABBITMQ_QUEUES, GPI_RABBITMQ_QUEUES, GRI_RABBITMQ_QUEUES, OSS, JAVA_STREAM_API_URL, GMV_RABBITMQ_QUEUES, SLOGAN_RABBITMQ_QUEUES, GEN_SINGLE_LOGO_RABBITMQ_QUEUES, PS_RABBITMQ_QUEUES, BATCH_GPI_RABBITMQ_QUEUES, BATCH_GRI_RABBITMQ_QUEUES, BATCH_PS_RABBITMQ_QUEUES
|
||||||
from app.schemas.response_template import ResponseModel
|
from app.schemas.response_template import ResponseModel
|
||||||
|
|
||||||
logger = logging.getLogger()
|
logger = logging.getLogger()
|
||||||
@@ -16,10 +16,17 @@ def test(id: int):
|
|||||||
data = {
|
data = {
|
||||||
"超分 SR_RABBITMQ_QUEUES": SR_RABBITMQ_QUEUES,
|
"超分 SR_RABBITMQ_QUEUES": SR_RABBITMQ_QUEUES,
|
||||||
"多视角 GMV_RABBITMQ_QUEUES": GMV_RABBITMQ_QUEUES,
|
"多视角 GMV_RABBITMQ_QUEUES": GMV_RABBITMQ_QUEUES,
|
||||||
|
"pose transform PS_RABBITMQ_QUEUES": PS_RABBITMQ_QUEUES,
|
||||||
"logan SLOGAN_RABBITMQ_QUEUES": SLOGAN_RABBITMQ_QUEUES,
|
"logan SLOGAN_RABBITMQ_QUEUES": SLOGAN_RABBITMQ_QUEUES,
|
||||||
"image and single logo GI_RABBITMQ_QUEUES": GI_RABBITMQ_QUEUES,
|
"image and single logo GI_RABBITMQ_QUEUES": GI_RABBITMQ_QUEUES,
|
||||||
"to product image GPI_RABBITMQ_QUEUES": GPI_RABBITMQ_QUEUES,
|
"to product image GPI_RABBITMQ_QUEUES": GPI_RABBITMQ_QUEUES,
|
||||||
"relight GRI_RABBITMQ_QUEUES": GRI_RABBITMQ_QUEUES,
|
"relight GRI_RABBITMQ_QUEUES": GRI_RABBITMQ_QUEUES,
|
||||||
|
|
||||||
|
# batch
|
||||||
|
"batch product BATCH_GPI_RABBITMQ_QUEUES": BATCH_GPI_RABBITMQ_QUEUES,
|
||||||
|
"batch relight BATCH_GRI_RABBITMQ_QUEUES": BATCH_GRI_RABBITMQ_QUEUES,
|
||||||
|
"batch pose transform BATCH_PS_RABBITMQ_QUEUES": BATCH_PS_RABBITMQ_QUEUES,
|
||||||
|
|
||||||
"JAVA_STREAM_API_URL": JAVA_STREAM_API_URL,
|
"JAVA_STREAM_API_URL": JAVA_STREAM_API_URL,
|
||||||
"local_oss_server": OSS
|
"local_oss_server": OSS
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -39,7 +39,7 @@ class GenerateRelightImage:
|
|||||||
self.negative_prompt = 'lowres, bad anatomy, bad hands, cropped, worst quality'
|
self.negative_prompt = 'lowres, bad anatomy, bad hands, cropped, worst quality'
|
||||||
self.direction = request_data.direction
|
self.direction = request_data.direction
|
||||||
self.image_url = request_data.image_url
|
self.image_url = request_data.image_url
|
||||||
self.image = oss_get_image(bucket=self.image_url.split('/')[0], object_name=self.image_url[self.image_url.find('/') + 1:], data_type="cv2")
|
self.image = pre_processing_image(self.image_url)
|
||||||
self.tasks_id = request_data.tasks_id
|
self.tasks_id = request_data.tasks_id
|
||||||
self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:]
|
self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:]
|
||||||
self.gen_product_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'image_url': ''}
|
self.gen_product_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'image_url': ''}
|
||||||
@@ -137,6 +137,46 @@ class GenerateRelightImage:
|
|||||||
if not DEBUG:
|
if not DEBUG:
|
||||||
publish_status(str_gen_product_data, GRI_RABBITMQ_QUEUES)
|
publish_status(str_gen_product_data, GRI_RABBITMQ_QUEUES)
|
||||||
|
|
||||||
|
def pre_processing_image(image_url):
|
||||||
|
image = oss_get_image(bucket=image_url.split('/')[0], object_name=image_url[image_url.find('/') + 1:], data_type="PIL")
|
||||||
|
# 目标图片的尺寸
|
||||||
|
target_width = 512
|
||||||
|
target_height = 768
|
||||||
|
|
||||||
|
# 原始图片的尺寸
|
||||||
|
original_width, original_height = image.size
|
||||||
|
|
||||||
|
# 计算宽度和高度的缩放比例
|
||||||
|
width_ratio = target_width / original_width
|
||||||
|
height_ratio = target_height / original_height
|
||||||
|
|
||||||
|
# 选择较小的缩放比例,确保图片能完整放入目标图片中
|
||||||
|
scale_ratio = min(width_ratio, height_ratio)
|
||||||
|
|
||||||
|
# 计算调整后的尺寸
|
||||||
|
new_width = int(original_width * scale_ratio)
|
||||||
|
new_height = int(original_height * scale_ratio)
|
||||||
|
|
||||||
|
# 调整图片大小
|
||||||
|
resized_image = image.resize((new_width, new_height))
|
||||||
|
|
||||||
|
# 创建一个 512x768 的透明图片
|
||||||
|
result_image = Image.new("RGBA", (target_width, target_height), (255, 255, 255, 0))
|
||||||
|
|
||||||
|
# 计算需要粘贴的位置,使图片居中
|
||||||
|
x_offset = (target_width - new_width) // 2
|
||||||
|
y_offset = (target_height - new_height) // 2
|
||||||
|
|
||||||
|
# 将调整大小后的图片粘贴到透明图片上
|
||||||
|
if resized_image.mode == "RGBA":
|
||||||
|
result_image.paste(resized_image, (x_offset, y_offset), mask=resized_image.split()[3])
|
||||||
|
else:
|
||||||
|
result_image.paste(resized_image, (x_offset, y_offset))
|
||||||
|
|
||||||
|
image = np.array(result_image)
|
||||||
|
|
||||||
|
# image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA)
|
||||||
|
return image
|
||||||
|
|
||||||
def infer_cancel(tasks_id):
|
def infer_cancel(tasks_id):
|
||||||
redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
|
redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
|
||||||
@@ -153,7 +193,7 @@ if __name__ == '__main__':
|
|||||||
prompt="Colorful black",
|
prompt="Colorful black",
|
||||||
image_url='aida-results/result_0000b606-1902-11ef-9424-0242ac180002.png',
|
image_url='aida-results/result_0000b606-1902-11ef-9424-0242ac180002.png',
|
||||||
direction="Right Light",
|
direction="Right Light",
|
||||||
product_type="single"
|
product_type="overall"
|
||||||
)
|
)
|
||||||
server = GenerateRelightImage(rd)
|
server = GenerateRelightImage(rd)
|
||||||
print(server.get_result())
|
print(server.get_result())
|
||||||
|
|||||||
Reference in New Issue
Block a user