feat design batch

fix
This commit is contained in:
zhouchengrong
2024-09-12 10:05:38 +08:00
parent cfa2cd1987
commit 44bb380944
18 changed files with 1889 additions and 0 deletions

View File

@@ -0,0 +1,31 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project trinity_client
@File conversion_image.py
@Author :周成融
@Date 2023/8/21 10:40:29
@detail
"""
import numpy as np
# def rgb_to_rgba(rgb_size, rgb_image, mask):
# alpha_channel = np.full(rgb_size, 255, dtype=np.uint8)
# # 创建四通道的结果图像
# rgba_image = np.dstack((rgb_image, alpha_channel))
# alpha_channel = np.where(mask > 0, 255, 0)
# # 更新RGBA图像的透明度通道
# rgba_image[:, :, 3] = alpha_channel
# return rgba_image
def rgb_to_rgba(rgb_image, mask):
    """Attach an alpha channel derived from ``mask`` to an RGB image.

    Pixels where ``mask > 0`` become fully opaque (255); all other pixels
    become fully transparent (0).

    Args:
        rgb_image: HxWx3 array (any uint8-compatible dtype).
        mask: HxW array; nonzero entries mark the opaque region.

    Returns:
        HxWx4 uint8-alpha RGBA array.
    """
    alpha = np.where(mask > 0, 255, 0).astype(np.uint8)
    return np.dstack((rgb_image, alpha))
if __name__ == '__main__':
    # Smoke test. The original did `image = open("")`, which always raises
    # FileNotFoundError; build a tiny in-memory image instead.
    demo_rgb = np.zeros((2, 2, 3), dtype=np.uint8)
    demo_mask = np.array([[0, 1], [1, 0]], dtype=np.uint8)
    print(rgb_to_rgba(demo_rgb, demo_mask).shape)

View File

@@ -0,0 +1,143 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project trinity_client
@File design_ensemble.py
@Author :周成融
@Date 2023/8/16 19:36:21
@detail :发起请求 获取推理结果
"""
import logging
import cv2
import mmcv
import numpy as np
import torch
import torch.nn.functional as F
import tritonclient.http as httpclient
from app.core.config import *
"""
keypoint
预处理 推理 后处理
"""
def keypoint_preprocess(img_path):
    """Load an image and prepare it for the keypoint model.

    Resizes to 256x256, normalizes with ImageNet-style mean/std (BGR->RGB),
    and adds a batch dimension (NCHW).

    Returns:
        (preprocessed NCHW array, (w_scale, h_scale)) where the scales map
        original coordinates into the resized 256x256 space.
    """
    target_size = (256, 256)
    image = mmcv.imread(img_path)
    orig_h, orig_w = image.shape[:2]
    image = cv2.resize(image, target_size)
    scales = (target_size[0] / orig_w, target_size[1] / orig_h)
    image = mmcv.imnormalize(image, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
    batched = image.transpose(2, 0, 1)[np.newaxis, ...]
    return batched, scales
# @ RunTime
# 推理
def get_keypoint_result(image, site):
    """Run keypoint inference against the Triton server for one image.

    Args:
        image: image path/array accepted by keypoint_preprocess.
        site: model selector, interpolated into the model name
            ``keypoint_{site}_ocrnet_hr18``.

    Returns:
        Post-processed keypoint coordinates, or None if any step failed
        (the failure is only logged as a warning).
    """
    keypoint_result = None
    try:
        preprocessed, scale_factor = keypoint_preprocess(image)
        tensor = preprocessed.astype(np.float32)
        client = httpclient.InferenceServerClient(url=DESIGN_MODEL_URL)
        request_inputs = [httpclient.InferInput("input", tensor.shape, datatype="FP32")]
        request_inputs[0].set_data_from_numpy(tensor, binary_data=True)
        request_outputs = [httpclient.InferRequestedOutput("output", binary_data=True)]
        results = client.infer(model_name=f"keypoint_{site}_ocrnet_hr18", inputs=request_inputs, outputs=request_outputs)
        raw_output = torch.from_numpy(results.as_numpy('output'))
        keypoint_result = keypoint_postprocess(raw_output, scale_factor)
    except Exception as e:
        logging.warning(f"get_keypoint_result : {e}")
    return keypoint_result
def keypoint_postprocess(output, scale_factor):
    """Convert raw heatmap logits into keypoint coordinates in the
    original image space.

    Args:
        output: (N, K, H, W) tensor of heatmaps.
        scale_factor: (w_scale, h_scale) returned by the preprocessor.

    Returns:
        (N, K, 2) numpy array of coordinates, rescaled and multiplied by 4
        (heatmap stride), then ceiled.
    """
    flat = output.view(output.size(0), output.size(1), -1)
    peak_idx = flat.argmax(dim=2).unsqueeze(dim=2)
    # NOTE(review): `/` is true division, so the row coordinate is a float
    # before the final ceil — confirm this (rather than `//`) is intended.
    coords = torch.cat((peak_idx / output.size(3), peak_idx % output.size(3)), dim=2).numpy()
    # Invert the (h, w)-ordered scales; zero scales would produce inf, which
    # is zeroed out below.
    inv_scales = [1 / s for s in scale_factor[::-1]]
    scale_matrix = np.diag(inv_scales)
    scale_matrix[np.isinf(scale_matrix)] = 0
    return np.ceil(np.dot(coords, scale_matrix) * 4)
"""
seg
预处理 推理 后处理
"""
# KNet
def seg_preprocess(img_path):
    """Load an image and prepare it for the segmentation (KNet) model.

    Each side is independently capped at 1024 pixels; normalization matches
    the keypoint preprocessor. Returns NCHW input plus the original (h, w)
    so the logits can be resized back afterwards.
    """
    image = mmcv.imread(img_path)
    ori_shape = image.shape[:2]
    capped_h = min(ori_shape[0], 1024)
    capped_w = min(ori_shape[1], 1024)
    if ori_shape != (capped_h, capped_w):
        # cv2.resize takes (width, height) — this ordering bit the previous
        # mmcv.imresize-based version.
        image = cv2.resize(image, (capped_w, capped_h))
    image = mmcv.imnormalize(image, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
    batched = np.expand_dims(image.transpose(2, 0, 1), axis=0)
    return batched, ori_shape
# @ RunTime
def get_seg_result(image_id, image):
    """Run segmentation inference against the Triton server.

    Args:
        image_id: identifier forwarded (as int) to the postprocessor.
        image: image path/array accepted by seg_preprocess.

    Returns:
        Per-pixel logits resized to the original image shape.
    """
    preprocessed, ori_shape = seg_preprocess(image)
    tensor = preprocessed.astype(np.float32)
    client = httpclient.InferenceServerClient(url=f"{DESIGN_MODEL_URL}")
    # Request payload: one FP32 input, one binary output.
    request_inputs = [
        httpclient.InferInput(SEGMENTATION['input'], tensor.shape, datatype="FP32"),
    ]
    request_inputs[0].set_data_from_numpy(tensor, binary_data=True)
    request_outputs = [
        httpclient.InferRequestedOutput(SEGMENTATION['output'], binary_data=True),
    ]
    results = client.infer(model_name=SEGMENTATION['new_model_name'], inputs=request_inputs, outputs=request_outputs)
    raw_output = results.as_numpy(SEGMENTATION['output'])
    return seg_postprocess(int(image_id), raw_output, ori_shape)
# no cache
def seg_postprocess(image_id, output, ori_shape):
    """Bilinearly resize raw segmentation logits back to the original size.

    Args:
        image_id: unused here; kept for interface compatibility with callers.
        output: (1, C, h, w) numpy array of logits.
        ori_shape: target (height, width).

    Returns:
        (C, H, W) numpy array (batch dimension stripped).
    """
    logits = torch.tensor(output).float()
    restored = F.interpolate(logits, size=ori_shape, scale_factor=None, mode='bilinear', align_corners=False)
    return restored.cpu().numpy()[0]
def key_point_show(image_path, key_point_result=None):
    """Draw predicted keypoints on an image and display it.

    Points are given (row, col) and reversed to OpenCV's (x, y) order.
    Blocks until a key is pressed in the preview window.
    """
    canvas = cv2.imread(image_path)
    radius = 1
    bgr_red = (0, 0, 255)
    stroke = 4  # circle thickness; 0, 4 or 8 work
    for point in key_point_result:
        cv2.circle(canvas, point[::-1], radius, bgr_red, stroke)
    cv2.imshow("0", canvas)
    cv2.waitKey(0)
if __name__ == '__main__':
    # Manual smoke test: run keypoint inference on a sample image and draw
    # the predicted points (requires a reachable Triton server).
    # The original also did `print(list)`, which printed the builtin `list`
    # class — removed as a debug leftover.
    image = cv2.imread("9070101c-e5be-49b5-9602-4113a968969b.png")
    a = get_keypoint_result(image, "up")
    new_list = []
    for i in a[0]:
        new_list.append((int(i[0]), int(i[1])))
    key_point_show("9070101c-e5be-49b5-9602-4113a968969b.png", new_list)
    # a = get_seg_result(1, image)
    print(a)

View File

@@ -0,0 +1,99 @@
import redis
from app.core.config import REDIS_HOST, REDIS_PORT
class Redis(object):
    """Thin convenience wrapper around redis-py's StrictRedis.

    Every call opens a fresh connection to db 0 on the configured
    host/port. Write/expire operations fall back to a 100-second TTL
    when no (truthy) expire value is supplied.
    """

    @staticmethod
    def _get_r():
        """Create a new StrictRedis connection (db fixed to 0)."""
        return redis.StrictRedis(REDIS_HOST, REDIS_PORT, 0)

    @classmethod
    def write(cls, key, value, expire=None):
        """Store a key/value pair with a TTL (default 100 seconds)."""
        ttl = expire if expire else 100
        cls._get_r().set(key, value, ex=ttl)

    @classmethod
    def read(cls, key):
        """Fetch a key's value as str, or None when the key is absent."""
        raw = cls._get_r().get(key)
        return raw.decode('utf-8') if raw else raw

    @classmethod
    def hset(cls, name, key, value):
        """Write one field of a hash."""
        cls._get_r().hset(name, key, value)

    @classmethod
    def hget(cls, name, key):
        """Read one field of a hash as str, or None when absent."""
        raw = cls._get_r().hget(name, key)
        return raw.decode('utf-8') if raw else raw

    @classmethod
    def hgetall(cls, name):
        """Return the entire hash (raw bytes, as redis-py delivers it)."""
        return cls._get_r().hgetall(name)

    @classmethod
    def delete(cls, *names):
        """Delete one or more keys."""
        cls._get_r().delete(*names)

    @classmethod
    def hdel(cls, name, key):
        """Remove one field from a hash."""
        cls._get_r().hdel(name, key)

    @classmethod
    def expire(cls, name, expire=None):
        """Set a TTL on a key (default 100 seconds)."""
        ttl = expire if expire else 100
        cls._get_r().expire(name, ttl)
if __name__ == '__main__':
    # Manual smoke test — requires a live Redis at REDIS_HOST:REDIS_PORT.
    client = Redis()
    client.write(key="1230", value=10)

View File

@@ -0,0 +1,181 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project trinity_client
@File synthesis_item.py
@Author :周成融
@Date 2023/8/26 14:13:04
@detail
"""
import io
import logging
import cv2
import numpy as np
from PIL import Image
from app.service.utils.generate_uuid import generate_uuid
from app.service.utils.oss_client import oss_upload_image
def positioning(all_mask_shape, mask_shape, offset):
    """Compute overlap slice bounds for pasting a 1-D span of length
    ``mask_shape`` onto a canvas of length ``all_mask_shape`` at ``offset``
    (which may be negative, i.e. the span starts before the canvas).

    Returns:
        (canvas_start, canvas_end, mask_start, mask_end). Both ranges have
        equal length; they are empty when there is no overlap.
    """
    if offset == 0:
        span = min(all_mask_shape, mask_shape)
        return 0, span, 0, span
    if offset > 0:
        canvas_start = min(offset, all_mask_shape)
        canvas_end = min(offset + mask_shape, all_mask_shape)
        mask_end = 0 if offset > all_mask_shape else min(all_mask_shape - offset, mask_shape)
        return canvas_start, canvas_end, 0, mask_end
    # offset < 0: the span begins `shift` units before the canvas.
    shift = -offset
    if shift > mask_shape:
        # Entirely off-canvas: both ranges empty.
        return 0, 0, mask_shape, mask_shape
    canvas_end = min(mask_shape - shift, all_mask_shape)
    mask_end = min(all_mask_shape + shift, mask_shape)
    return 0, canvas_end, shift, mask_end
# @RunTime
def synthesis(data, size, basic_info):
    """Composite the layered body/garment images into one RGBA image and
    upload it to OSS.

    Args:
        data: list of layer dicts, each with 'name', 'image' (PIL image or
            None), 'mask' (numpy array) and 'adaptive_position' — presumably
            a (y, x) offset; confirm against callers.
        size: output canvas size as (width, height) for PIL.
        basic_info: carries shoulder keypoints under
            basic_info['body_point_test'] ('shoulder_left'/'shoulder_right').

    Returns:
        "bucket/object" path of the uploaded PNG, or None if any step raised
        (the exception is only logged as a warning).
    """
    # Base canvas: fully transparent RGBA image.
    base_image = Image.new('RGBA', size, (0, 0, 0, 0))
    try:
        # Numpy masks are indexed (rows, cols) = (height, width), the reverse
        # of PIL's (width, height).
        all_mask_shape = (size[1], size[0])
        body_mask = None
        for d in data:
            if d['name'] == 'body' or d['name'] == 'mannequin':
                # Paste the model onto a full-size transparent canvas so its
                # alpha channel becomes a mask in canvas coordinates.
                transparent_image = Image.new("RGBA", size, (0, 0, 0, 0))
                transparent_image.paste(d['image'], (d['adaptive_position'][1], d['adaptive_position'][0]), d['image'])  # index the position list here: paste would mutate a shared mutable list
                body_mask = np.array(transparent_image.split()[3])
                # Translate the shoulder keypoints into canvas coordinates.
                left_shoulder = [x + y for x, y in zip(basic_info['body_point_test']['shoulder_left'], [d['adaptive_position'][1], d['adaptive_position'][0]])]
                right_shoulder = [x + y for x, y in zip(basic_info['body_point_test']['shoulder_right'], [d['adaptive_position'][1], d['adaptive_position'][0]])]
                # Fill the band above the shoulders so that region is opaque.
                body_mask[:min(left_shoulder[1], right_shoulder[1]), left_shoulder[0]:right_shoulder[0]] = 255
        _, binary_body_mask = cv2.threshold(body_mask, 127, 255, cv2.THRESH_BINARY)
        top_outer_mask = np.array(binary_body_mask)
        bottom_outer_mask = np.array(binary_body_mask)
        top = True
        bottom = True
        # Walk layers from last to first; the first top-garment and first
        # bottom-garment found (i.e. the topmost of each) extend the mask.
        i = len(data)
        while i:
            i -= 1
            if top and data[i]['name'] in ["blouse_front", "outwear_front", "dress_front", "tops_front"]:
                top = False
                mask_shape = data[i]['mask'].shape
                y_offset, x_offset = data[i]['adaptive_position']
                # Clip the paste region to the canvas along each axis.
                all_y_start, all_y_end, mask_y_start, mask_y_end = positioning(all_mask_shape=all_mask_shape[0], mask_shape=mask_shape[0], offset=y_offset)
                all_x_start, all_x_end, mask_x_start, mask_x_end = positioning(all_mask_shape=all_mask_shape[1], mask_shape=mask_shape[1], offset=x_offset)
                # Binarize the garment mask and accumulate it into the
                # running top mask on a canvas-sized background.
                _, sketch_mask = cv2.threshold(data[i]['mask'], 127, 255, cv2.THRESH_BINARY)
                background = np.zeros_like(top_outer_mask)
                background[all_y_start:all_y_end, all_x_start:all_x_end] = sketch_mask[mask_y_start:mask_y_end, mask_x_start:mask_x_end]
                top_outer_mask = background + top_outer_mask
            elif bottom and data[i]['name'] in ["trousers_front", "skirt_front", "bottoms_front", "dress_front"]:
                bottom = False
                mask_shape = data[i]['mask'].shape
                y_offset, x_offset = data[i]['adaptive_position']
                # Clip the paste region to the canvas along each axis.
                all_y_start, all_y_end, mask_y_start, mask_y_end = positioning(all_mask_shape=all_mask_shape[0], mask_shape=mask_shape[0], offset=y_offset)
                all_x_start, all_x_end, mask_x_start, mask_x_end = positioning(all_mask_shape=all_mask_shape[1], mask_shape=mask_shape[1], offset=x_offset)
                # Same accumulation for the bottom garment.
                _, sketch_mask = cv2.threshold(data[i]['mask'], 127, 255, cv2.THRESH_BINARY)
                background = np.zeros_like(top_outer_mask)
                background[all_y_start:all_y_end, all_x_start:all_x_end] = sketch_mask[mask_y_start:mask_y_end, mask_x_start:mask_x_end]
                bottom_outer_mask = background + bottom_outer_mask
            elif bottom is False and top is False:
                break
        all_mask = cv2.bitwise_or(top_outer_mask, bottom_outer_mask)
        for layer in data:
            if layer['image'] is not None:
                # NOTE(review): only 'body' skips the masked branch —
                # 'mannequin' layers are cropped by all_mask; confirm intended.
                if layer['name'] != "body":
                    test_image = Image.new('RGBA', size, (0, 0, 0, 0))
                    test_image.paste(layer['image'], (layer['adaptive_position'][1], layer['adaptive_position'][0]), layer['image'])
                    mask_data = np.where(all_mask > 0, 255, 0).astype(np.uint8)
                    mask_alpha = Image.fromarray(mask_data)
                    cropped_image = Image.composite(test_image, Image.new("RGBA", test_image.size, (255, 255, 255, 0)), mask_alpha)
                    base_image.paste(test_image, (0, 0), cropped_image)  # test_image is already positioned on a full-size canvas, so paste at (0, 0)
                else:
                    base_image.paste(layer['image'], (layer['adaptive_position'][1], layer['adaptive_position'][0]), layer['image'])
        result_image = base_image
        image_data = io.BytesIO()
        result_image.save(image_data, format='PNG')
        image_data.seek(0)
        # Upload the PNG to OSS and return its bucket-relative path.
        image_bytes = image_data.read()
        bucket_name = "aida-results"
        object_name = f'result_{generate_uuid()}.png'
        req = oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes)
        return f"{bucket_name}/{object_name}"
    except Exception as e:
        logging.warning(f"synthesis runtime exception : {e}")
def synthesis_single(front_image, back_image):
    """Overlay back_image onto front_image and upload the result to OSS.

    Args:
        front_image: base PIL image used as the canvas (mutated in place
            when back_image is pasted).
        back_image: optional PIL image pasted at (0, 0) using its own alpha.

    Returns:
        "bucket/object" path of the uploaded PNG.

    NOTE(review): if front_image is falsy, result_image stays None and the
    later .paste/.save calls raise AttributeError — confirm callers always
    pass a front image.
    """
    result_image = None
    if front_image:
        result_image = front_image
    if back_image:
        result_image.paste(back_image, (0, 0), back_image)
    # Serialize to PNG in memory.
    image_data = io.BytesIO()
    result_image.save(image_data, format='PNG')
    image_data.seek(0)
    image_bytes = image_data.read()
    # Upload to OSS and return the bucket-relative object path.
    bucket_name = 'aida-results'
    object_name = f'result_{generate_uuid()}.png'
    req = oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes)
    return f"{bucket_name}/{object_name}"

View File

@@ -0,0 +1,55 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project trinity_client
@File upload_image.py
@Author :周成融
@Date 2023/8/28 13:49:20
@detail
"""
import io
import logging
import cv2
from app.core.config import *
from app.service.utils.decorator import RunTime
from app.service.utils.oss_client import oss_upload_image
# @RunTime
def upload_png_mask(front_image, object_name, mask=None):
    """Upload a garment PNG (and optionally its mask) to OSS.

    Args:
        front_image: PIL image, saved as PNG under image/image_<object_name>.png.
        object_name: base name used in both object keys.
        mask: optional OpenCV mask (BGR); it is inverted, converted to BGRA
            with black made fully transparent, and stored under
            mask/mask_<object_name>.png.

    Returns:
        (front_image, image_url, mask_url) — mask_url is None when no mask
        was given — or None if anything raised (the exception is only logged).
    """
    try:
        mask_url = None
        if mask is not None:
            mask_inverted = cv2.bitwise_not(mask)
            # Promote the 3-channel mask to 4 channels: white stays opaque,
            # black becomes transparent (pixels whose blue channel is 0 are
            # zeroed out entirely, alpha included).
            rgba_image = cv2.cvtColor(mask_inverted, cv2.COLOR_BGR2BGRA)
            rgba_image[rgba_image[:, :, 0] == 0] = [0, 0, 0, 0]
            # Encode and upload the mask PNG to OSS.
            req = oss_upload_image(bucket=AIDA_CLOTHING, object_name=f"mask/mask_{object_name}.png", image_bytes=cv2.imencode('.png', rgba_image)[1])
            mask_url = f"{AIDA_CLOTHING}/mask/mask_{object_name}.png"
        # Serialize the main image to PNG in memory and upload it.
        image_data = io.BytesIO()
        front_image.save(image_data, format='PNG')
        image_data.seek(0)
        image_bytes = image_data.read()
        req = oss_upload_image(bucket=AIDA_CLOTHING, object_name=f"image/image_{object_name}.png", image_bytes=image_bytes)
        image_url = f"{AIDA_CLOTHING}/image/image_{object_name}.png"
        return front_image, image_url, mask_url
    except Exception as e:
        logging.warning(f"upload_png_mask runtime exception : {e}")
# @RunTime
# def upload_png_mask(front_image, object_name, mask=None):
# mask_url = None
# if mask is not None:
# mask_url = f"{AIDA_CLOTHING}/mask/mask_{object_name}.png"
# image_url = f"{AIDA_CLOTHING}/image/image_{object_name}.png"
# return front_image, image_url, mask_url