Merge remote-tracking branch 'origin/develop' into develop

This commit is contained in:
2024-12-02 18:22:14 +08:00
19 changed files with 545 additions and 90 deletions

View File

@@ -2,13 +2,13 @@ import json
import logging import logging
import os import os
from fastapi import APIRouter, HTTPException, UploadFile, File, Form from fastapi import APIRouter, HTTPException, UploadFile, File, Form, BackgroundTasks
from app.schemas.design import DesignModel, DesignProgressModel, ModelProgressModel, DBGConfigModel from app.schemas.design import DesignModel, DesignProgressModel, ModelProgressModel, DBGConfigModel
from app.schemas.response_template import ResponseModel from app.schemas.response_template import ResponseModel
from app.service.design.model_process_service import model_transpose from app.service.design.model_process_service import model_transpose
from app.service.design_batch.service import start_design_batch_generate from app.service.design_batch.service import start_design_batch_generate
from app.service.design_fast.design_generate import design_generate from app.service.design_fast.design_generate import design_generate, design_generate_v2
from app.service.design_fast.utils.redis_utils import Redis from app.service.design_fast.utils.redis_utils import Redis
router = APIRouter() router = APIRouter()
@@ -16,7 +16,7 @@ logger = logging.getLogger()
@router.post("/design") @router.post("/design")
def design(request_data: DesignModel): def design(request_data: DesignModel, background_tasks: BackgroundTasks):
""" """
创建一个具有以下参数的请求体: 创建一个具有以下参数的请求体:
示例参数: 示例参数:
@@ -67,7 +67,6 @@ def design(request_data: DesignModel):
0 0
], ],
"path": "aida-sys-image/images/female/trousers/0825000630.jpg", "path": "aida-sys-image/images/female/trousers/0825000630.jpg",
"seg_mask_url": "test/result.png",
"print": { "print": {
"element": { "element": {
"element_angle_list": [], "element_angle_list": [],
@@ -104,7 +103,6 @@ def design(request_data: DesignModel):
0 0
], ],
"path": "aida-sys-image/images/female/blouse/0902003811.jpg", "path": "aida-sys-image/images/female/blouse/0902003811.jpg",
"seg_mask_url": "test/result.png",
"print": { "print": {
"element": { "element": {
"element_angle_list": [], "element_angle_list": [],
@@ -141,7 +139,6 @@ def design(request_data: DesignModel):
0 0
], ],
"path": "aida-sys-image/images/female/outwear/0825000410.jpg", "path": "aida-sys-image/images/female/outwear/0825000410.jpg",
"seg_mask_url": "test/result.png",
"print": { "print": {
"element": { "element": {
"element_angle_list": [], "element_angle_list": [],
@@ -167,6 +164,10 @@ def design(request_data: DesignModel):
1.0, 1.0,
1.0 1.0
], ],
"transparent":{
"mask_url":"test/transparent_test/transparent_mask.png",
"scale":0.1
},
"type": "Outwear" "type": "Outwear"
}, },
{ {
@@ -195,6 +196,182 @@ def design(request_data: DesignModel):
return ResponseModel(data=data) return ResponseModel(data=data)
@router.post("/design_v2")
async def design_v2(request_data: DesignModel, background_tasks: BackgroundTasks):
"""
创建一个具有以下参数的请求体:
示例参数:
{
"objects": [
{
"basic": {
"body_point_test": {
"waistband_right": [
200,
241
],
"hand_point_right": [
223,
297
],
"waistband_left": [
112,
241
],
"hand_point_left": [
92,
305
],
"shoulder_left": [
99,
116
],
"shoulder_right": [
215,
116
]
},
"layer_order": true,
"scale_bag": 0.7,
"scale_earrings": 0.16,
"self_template": true,
"single_overall": "overall",
"switch_category": ""
},
"items": [
{
"businessId": 270372,
"color": "30 28 28",
"image_id": 69780,
"offset": [
0,
0
],
"path": "aida-sys-image/images/female/trousers/0825000630.jpg",
"print": {
"element": {
"element_angle_list": [],
"element_path_list": [],
"element_scale_list": [],
"location": []
},
"overall": {
"location": [],
"print_angle_list": [],
"print_path_list": [],
"print_scale_list": []
},
"single": {
"location": [],
"print_angle_list": [],
"print_path_list": [],
"print_scale_list": []
}
},
"priority": 10,
"resize_scale": [
1.0,
1.0
],
"type": "Trousers"
},
{
"businessId": 270373,
"color": "30 28 28",
"image_id": 98243,
"offset": [
0,
0
],
"path": "aida-sys-image/images/female/blouse/0902003811.jpg",
"print": {
"element": {
"element_angle_list": [],
"element_path_list": [],
"element_scale_list": [],
"location": []
},
"overall": {
"location": [],
"print_angle_list": [],
"print_path_list": [],
"print_scale_list": []
},
"single": {
"location": [],
"print_angle_list": [],
"print_path_list": [],
"print_scale_list": []
}
},
"priority": 11,
"resize_scale": [
1.0,
1.0
],
"type": "Blouse"
},
{
"businessId": 270374,
"color": "172 68 68",
"image_id": 98244,
"offset": [
0,
0
],
"path": "aida-sys-image/images/female/outwear/0825000410.jpg",
"print": {
"element": {
"element_angle_list": [],
"element_path_list": [],
"element_scale_list": [],
"location": []
},
"overall": {
"location": [],
"print_angle_list": [],
"print_path_list": [],
"print_scale_list": []
},
"single": {
"location": [],
"print_angle_list": [],
"print_path_list": [],
"print_scale_list": []
}
},
"priority": 12,
"resize_scale": [
1.0,
1.0
],
"transparent":{
"mask_url":"test/transparent_test/transparent_mask.png",
"scale":0.1
},
"type": "Outwear"
},
{
"body_path": "aida-sys-image/models/female/5bdfe7ca-64eb-44e4-b03d-8e517520c795.png",
"image_id": 96090,
"type": "Body"
}
]
}
],
"process_id": "83"
}
"""
try:
# 异步
logger.info(f"generate_image request item is : @@@@@@:{json.dumps(request_data.dict())}")
background_tasks.add_task(design_generate_v2, request_data)
except Exception as e:
logger.warning(f"design Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))
return ResponseModel()
@router.post('/get_progress') @router.post('/get_progress')
def get_progress(request_data: DesignProgressModel): def get_progress(request_data: DesignProgressModel):
""" """

View File

@@ -26,6 +26,7 @@ def generate_image(request_item: GenerateImageModel, background_tasks: Backgroun
- **mode**: 生成模式img2img或者txt2img - **mode**: 生成模式img2img或者txt2img
- **category**: 生成图片的类别sketch print 等等 - **category**: 生成图片的类别sketch print 等等
- **gender**: 生成sketch专用服装类别 - **gender**: 生成sketch专用服装类别
- **version**: 使用模型版本 fast 或者 high
示例参数: 示例参数:
{ {
@@ -34,7 +35,8 @@ def generate_image(request_item: GenerateImageModel, background_tasks: Backgroun
"image_url": "aida-collection-element/87/Printboard/842c09cf-7297-42d9-9e6e-9c17d4a13cb5.jpg", "image_url": "aida-collection-element/87/Printboard/842c09cf-7297-42d9-9e6e-9c17d4a13cb5.jpg",
"mode": "img2img", "mode": "img2img",
"category": "sketch", "category": "sketch",
"gender": "male" "gender": "male",
"version": "fast"
} }
""" """
try: try:

View File

@@ -93,9 +93,6 @@ OPENAI_MODEL_LIST = {"gpt-3.5-turbo-0613",
"gpt-4-0613", "gpt-4-0613",
"gpt-4-32k-0613", } "gpt-4-32k-0613", }
# attribute service config
ATT_TRITON_URL = "10.1.1.240:10000"
# SR service config # SR service config
SR_MODEL_NAME = "super_resolution" SR_MODEL_NAME = "super_resolution"
SR_TRITON_URL = "10.1.1.240:10031" SR_TRITON_URL = "10.1.1.240:10031"
@@ -103,8 +100,12 @@ SR_MINIO_BUCKET = "aida-users"
SR_RABBITMQ_QUEUES = os.getenv("SR_RABBITMQ_QUEUES", f"SuperResolution{RABBITMQ_ENV}") SR_RABBITMQ_QUEUES = os.getenv("SR_RABBITMQ_QUEUES", f"SuperResolution{RABBITMQ_ENV}")
# GenerateImage service config # GenerateImage service config
GI_MODEL_NAME = 'stable_diffusion_xl' FAST_GI_MODEL_URL = '10.1.1.243:10011'
GI_MODEL_URL = '10.1.1.240:10041' FAST_GI_MODEL_NAME = 'stable_diffusion_xl'
GI_MODEL_URL = '10.1.1.240:10061'
GI_MODEL_NAME = 'flux'
GI_MINIO_BUCKET = "aida-users" GI_MINIO_BUCKET = "aida-users"
GI_RABBITMQ_QUEUES = os.getenv("GI_RABBITMQ_QUEUES", f"GenerateImage{RABBITMQ_ENV}") GI_RABBITMQ_QUEUES = os.getenv("GI_RABBITMQ_QUEUES", f"GenerateImage{RABBITMQ_ENV}")
GI_SYS_IMAGE_URL = "aida-sys-image/generate_image/white_image.jpg" GI_SYS_IMAGE_URL = "aida-sys-image/generate_image/white_image.jpg"
@@ -113,17 +114,15 @@ GI_SYS_IMAGE_URL = "aida-sys-image/generate_image/white_image.jpg"
SLOGAN_RABBITMQ_QUEUES = os.getenv("SLOGAN_RABBITMQ_QUEUES", f"Slogan{RABBITMQ_ENV}") SLOGAN_RABBITMQ_QUEUES = os.getenv("SLOGAN_RABBITMQ_QUEUES", f"Slogan{RABBITMQ_ENV}")
# Generate Single Logo service config # Generate Single Logo service config
GSL_MODEL_URL = '10.1.1.240:10041' GSL_MODEL_URL = '10.1.1.243:10041'
GSL_MINIO_BUCKET = "aida-users" GSL_MINIO_BUCKET = "aida-users"
GSL_MODEL_NAME = 'stable_diffusion_xl_transparent' GSL_MODEL_NAME = 'stable_diffusion_xl_transparent'
GEN_SINGLE_LOGO_RABBITMQ_QUEUES = os.getenv("GEN_SINGLE_LOGO_RABBITMQ_QUEUES", f"GenSingleLogo{RABBITMQ_ENV}") GEN_SINGLE_LOGO_RABBITMQ_QUEUES = os.getenv("GEN_SINGLE_LOGO_RABBITMQ_QUEUES", f"GenSingleLogo{RABBITMQ_ENV}")
# Generate Product service config # Generate Product service config
GPI_RABBITMQ_QUEUES = os.getenv("GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES", f"ToProductImage{RABBITMQ_ENV}") GPI_RABBITMQ_QUEUES = os.getenv("GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES", f"ToProductImage{RABBITMQ_ENV}")
GPI_MODEL_NAME_OVERALL = 'diffusion_ensemble_all' GPI_MODEL_NAME_OVERALL = 'sdxl_ensemble_all'
GPI_MODEL_NAME_SINGLE = 'stable_diffusion_1_5_cnet' GPI_MODEL_URL = '10.1.1.243:10051'
GPI_MODEL_URL = '10.1.1.240:10041'
# Generate Single Logo service config # Generate Single Logo service config
GRI_RABBITMQ_QUEUES = os.getenv("GEN_RELIGHT_IMAGE_RABBITMQ_QUEUES", f"Relight{RABBITMQ_ENV}") GRI_RABBITMQ_QUEUES = os.getenv("GEN_RELIGHT_IMAGE_RABBITMQ_QUEUES", f"Relight{RABBITMQ_ENV}")
@@ -132,14 +131,14 @@ GRI_MODEL_NAME_SINGLE = 'stable_diffusion_1_5_relight'
GRI_MODEL_URL = '10.1.1.240:10051' GRI_MODEL_URL = '10.1.1.240:10051'
# SEG service config # SEG service config
SEG_MODEL_URL = '10.1.1.240:10000'
SEGMENTATION = { SEGMENTATION = {
"new_model_name": "seg_knet", "new_model_name": "seg_knet",
"name": "seg_ocrnet_hr18", "name": "seg_ocrnet_hr18",
"input": "seg_input__0", "input": "seg_input__0",
"output": "seg_output__0", "output": "seg_output__0",
} }
# ollama config
OLLAMA_URL = "http://10.1.1.243:11434/api/embeddings"
# DESIGN config # DESIGN config
DESIGN_MODEL_URL = '10.1.1.240:10000' DESIGN_MODEL_URL = '10.1.1.240:10000'
AIDA_CLOTHING = "aida-clothing" AIDA_CLOTHING = "aida-clothing"

View File

@@ -8,6 +8,7 @@ class GenerateImageModel(BaseModel):
mode: str mode: str
category: str category: str
gender: str gender: str
version: str
class GenerateSingleLogoImageModel(BaseModel): class GenerateSingleLogoImageModel(BaseModel):

View File

@@ -28,7 +28,7 @@ class AttributeRecognition:
} }
) )
self.const = const self.const = const
self.triton_client = httpclient.InferenceServerClient(url=f"{ATT_TRITON_URL}") self.triton_client = httpclient.InferenceServerClient(url=f"{DESIGN_MODEL_URL}")
def get_result(self): def get_result(self):
for sketch in self.request_data: for sketch in self.request_data:

View File

@@ -26,7 +26,7 @@ class CategoryRecognition:
self.attr_type = pd.read_csv(CATEGORY_PATH) self.attr_type = pd.read_csv(CATEGORY_PATH)
# self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) # self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
self.request_data = [] self.request_data = []
self.triton_client = httpclient.InferenceServerClient(url=ATT_TRITON_URL) self.triton_client = httpclient.InferenceServerClient(url=DESIGN_MODEL_URL)
for sketch in request_data: for sketch in request_data:
self.request_data.append( self.request_data.append(
{ {

View File

@@ -2,11 +2,12 @@ import logging
import threading import threading
import time import time
import requests
from minio import Minio from minio import Minio
from app.core.config import * from app.core.config import *
from app.service.design_fast.item import BodyItem, TopItem, BottomItem from app.service.design_fast.item import BodyItem, TopItem, BottomItem, AccessoriesItem
from app.service.design_fast.utils.organize import organize_body, organize_clothing from app.service.design_fast.utils.organize import organize_body, organize_clothing, organize_accessories
from app.service.design_fast.utils.progress import final_progress, update_progress from app.service.design_fast.utils.progress import final_progress, update_progress
from app.service.design_fast.utils.synthesis_item import synthesis, synthesis_single, update_base_size_priority from app.service.design_fast.utils.synthesis_item import synthesis, synthesis_single, update_base_size_priority
from app.service.utils.decorator import RunTime from app.service.utils.decorator import RunTime
@@ -26,9 +27,14 @@ def process_item(item, basic):
elif item['type'].lower() in ['blouse', 'outwear', 'dress', 'tops']: elif item['type'].lower() in ['blouse', 'outwear', 'dress', 'tops']:
top_server = TopItem(data=item, basic=basic, minio_client=minio_client) top_server = TopItem(data=item, basic=basic, minio_client=minio_client)
item_data = top_server.process() item_data = top_server.process()
else: elif item['type'].lower() in ['skirt', 'trousers', 'bottoms']:
bottom_server = BottomItem(data=item, basic=basic, minio_client=minio_client) bottom_server = BottomItem(data=item, basic=basic, minio_client=minio_client)
item_data = bottom_server.process() item_data = bottom_server.process()
elif item['type'].lower() in ['accessories']:
bottom_server = AccessoriesItem(data=item, basic=basic, minio_client=minio_client)
item_data = bottom_server.process()
else:
raise NotImplementedError(f"Item type {item['type']} not implemented")
return item_data return item_data
@@ -38,6 +44,10 @@ def process_layer(item, layers):
body_layer = organize_body(item) body_layer = organize_body(item)
layers.append(body_layer) layers.append(body_layer)
return item['body_image'].size return item['body_image'].size
elif item['name'] == 'accessories':
front_layer, back_layer = organize_accessories(item)
layers.append(front_layer)
layers.append(back_layer)
else: else:
front_layer, back_layer = organize_clothing(item) front_layer, back_layer = organize_clothing(item)
layers.append(front_layer) layers.append(front_layer)
@@ -57,7 +67,7 @@ def design_generate(request_data):
def process_object(step, object): def process_object(step, object):
nonlocal active_threads nonlocal active_threads
basic = object['basic'] basic = object['basic']
items_response = {'layers': []} items_response = {'layers': [], 'objectSign': object['objectSign'] if 'objectSign' in object.keys() else ""}
if basic['single_overall'] == "overall": if basic['single_overall'] == "overall":
item_results = [] item_results = []
for item in object['items']: for item in object['items']:
@@ -126,6 +136,117 @@ def design_generate(request_data):
return object_response return object_response
@RunTime
def design_generate_v2(request_data):
objects_data = request_data.dict()['objects']
threads = []
def process_object(step, object):
basic = object['basic']
items_response = {
'layers': [],
'objectSign': object['objectSign'] if 'objectSign' in object.keys() else "",
'requestId': object['requestId'] if 'requestId' in object.keys() else ""
}
if basic['single_overall'] == "overall":
item_results = []
for item in object['items']:
item_results.append(process_item(item, basic))
layers = []
body_size = None
for item in item_results:
body_size = process_layer(item, layers)
layers = sorted(layers, key=lambda s: s.get("priority", float('inf')))
layers, new_size = update_base_size_priority(layers, body_size)
for lay in layers:
items_response['layers'].append({
'image_category': "body" if lay['name'] == 'mannequin' else lay['name'],
'position': lay['position'],
'priority': lay.get("priority", None),
'resize_scale': lay['resize_scale'] if "resize_scale" in lay.keys() else None,
'image_size': lay['image'] if lay['image'] is None else lay['image'].size,
'gradient_string': lay['gradient_string'] if 'gradient_string' in lay.keys() else "",
'mask_url': lay['mask_url'],
'image_url': lay['image_url'] if 'image_url' in lay.keys() else None,
'pattern_image_url': lay['pattern_image_url'] if 'pattern_image_url' in lay.keys() else None,
# 'back_perspective_url': lay['back_perspective_url'] if 'back_perspective_url' in lay.keys() else None,
})
items_response['synthesis_url'] = synthesis(layers, new_size, basic)
else:
item_result = process_item(object['items'][0], basic)
items_response['layers'].append({
'image_category': f"{item_result['name']}_front",
'image_size': item_result['back_image'].size if item_result['back_image'] else None,
'position': None,
'priority': 0,
'image_url': item_result['front_image_url'],
'mask_url': item_result['mask_url'],
"gradient_string": item_result['gradient_string'] if 'gradient_string' in item_result.keys() else "",
'pattern_image_url': item_result['pattern_image_url'] if 'pattern_image_url' in item_result.keys() else None,
})
items_response['layers'].append({
'image_category': f"{item_result['name']}_back",
'image_size': item_result['front_image'].size if item_result['front_image'] else None,
'position': None,
'priority': 0,
'image_url': item_result['back_image_url'],
'mask_url': item_result['mask_url'],
"gradient_string": item_result['gradient_string'] if 'gradient_string' in item_result.keys() else "",
'pattern_image_url': item_result['pattern_image_url'] if 'pattern_image_url' in item_result.keys() else None,
})
items_response['synthesis_url'] = synthesis_single(item_result['front_image'], item_result['back_image'])
# 发送结果给java端
url = "https://3998-117-143-125-51.ngrok-free.app/api/third/party/receiveDesignResults"
headers = {
'Accept': "*/*",
'Accept-Encoding': "gzip, deflate, br",
'User-Agent': "PostmanRuntime-ApipostRuntime/1.1.0",
'Connection': "keep-alive",
'Content-Type': "application/json"
}
response = post_request(url, json_data=items_response, headers=headers)
if response:
# 打印结果
logger.info(response.text)
logger.info(items_response)
for step, object in enumerate(objects_data):
t = threading.Thread(target=process_object, args=(step, object))
threads.append(t)
t.start()
def post_request(url, data=None, json_data=None, headers=None, auth=None, timeout=5):
"""
发送POST请求的封装函数
:param url: 接口的URL地址
:param data: 要发送的数据(字典形式,用于表单数据等,会自动编码)
:param json_data: 要发送的JSON数据字典形式会自动转换为JSON字符串
:param headers: 请求头字典
:param auth: 认证信息(如 ('username', 'password') 形式用于基本认证)
:param timeout: 超时时间,单位为秒
:return: 返回接口的响应对象
"""
try:
response = requests.post(
url,
data=data,
json=json_data,
headers=headers,
auth=auth,
timeout=timeout
)
response.raise_for_status() # 如果请求失败,抛出异常
return response
except requests.RequestException as e:
print(f"POST请求出错: {e}")
return None
if __name__ == '__main__': if __name__ == '__main__':
object_data = { object_data = {
"objects": [ "objects": [

View File

@@ -1,4 +1,4 @@
from app.service.design_fast.pipeline import LoadImage, KeyPoint, Segmentation, Color, PrintPainting, Scaling, Split, LoadBodyImage, ContourDetection, BackPerspective from app.service.design_fast.pipeline import LoadImage, KeyPoint, Segmentation, Color, PrintPainting, Scaling, Split, LoadBodyImage, ContourDetection
class BaseItem: class BaseItem:
@@ -9,6 +9,27 @@ class BaseItem:
self.result.update(basic) self.result.update(basic)
class AccessoriesItem(BaseItem):
def __init__(self, data, basic, minio_client):
super().__init__(data, basic)
self.Accessories_pipeline = [
LoadImage(minio_client),
# KeyPoint(),
ContourDetection(),
# Segmentation(minio_client),
# BackPerspective(minio_client),
Color(minio_client),
PrintPainting(minio_client),
Scaling(),
Split(minio_client)
]
def process(self):
for item in self.Accessories_pipeline:
self.result = item(self.result)
return self.result
class TopItem(BaseItem): class TopItem(BaseItem):
def __init__(self, data, basic, minio_client): def __init__(self, data, basic, minio_client):
super().__init__(data, basic) super().__init__(data, basic)

View File

@@ -74,6 +74,8 @@ class LoadImage:
keypoint = 'head_point' keypoint = 'head_point'
elif name == 'earring': elif name == 'earring':
keypoint = 'ear_point' keypoint = 'ear_point'
elif name == 'accessories':
keypoint = "accessories"
else: else:
raise KeyError(f"{name} does not belong to item category list: blouse, outwear, dress, trousers, skirt, " raise KeyError(f"{name} does not belong to item category list: blouse, outwear, dress, trousers, skirt, "
f"bag, shoes, hairstyle, earring.") f"bag, shoes, hairstyle, earring.")

View File

@@ -18,7 +18,7 @@ class Scaling:
- -
int(result['body_point_test'][result['keypoint'] + '_right'][0])) ** 2 + 1 int(result['body_point_test'][result['keypoint'] + '_right'][0])) ** 2 + 1
) )
if distance_clo == 0: if distance_clo == 0:
result['scale'] = 1 result['scale'] = 1
else: else:
@@ -46,4 +46,16 @@ class Scaling:
result['scale'] = result['scale_bag'] result['scale'] = result['scale_bag']
elif result['keypoint'] == 'ear_point': elif result['keypoint'] == 'ear_point':
result['scale'] = result['scale_earrings'] result['scale'] = result['scale_earrings']
elif result['keypoint'] == 'accessories':
# 由于没有识别配饰keypoint的模型 所以统一将配饰的两个关键点设定为 (0,0) (0,img.width)
# 模特的关键点设定为(0,0) (0,320/2) 距离比例简写为 160 / img.width
distance_clo = result['img_shape'][1]
distance_bdy = 320 / 2
if distance_clo == 0:
result['scale'] = 1
else:
result['scale'] = distance_bdy / distance_clo
else:
result['scale'] = 1
return result return result

View File

@@ -8,9 +8,10 @@ from cv2 import cvtColor, COLOR_BGR2RGBA
from app.core.config import AIDA_CLOTHING from app.core.config import AIDA_CLOTHING
from app.service.design_fast.utils.conversion_image import rgb_to_rgba from app.service.design_fast.utils.conversion_image import rgb_to_rgba
from app.service.design_fast.utils.transparent import sketch_to_transparent
from app.service.design_fast.utils.upload_image import upload_png_mask from app.service.design_fast.utils.upload_image import upload_png_mask
from app.service.utils.generate_uuid import generate_uuid from app.service.utils.generate_uuid import generate_uuid
from app.service.utils.new_oss_client import oss_upload_image from app.service.utils.new_oss_client import oss_upload_image, oss_get_image
class Split(object): class Split(object):
@@ -20,7 +21,7 @@ class Split(object):
def __call__(self, result): def __call__(self, result):
try: try:
if result['name'] in ('outwear', 'dress', 'blouse', 'skirt', 'trousers', 'tops', 'bottoms'): if result['name'] in ('outwear', 'dress', 'blouse', 'skirt', 'trousers', 'tops', 'bottoms','accessories'):
front_mask = result['front_mask'] front_mask = result['front_mask']
back_mask = result['back_mask'] back_mask = result['back_mask']
rgba_image = rgb_to_rgba(result['final_image'], front_mask + back_mask) rgba_image = rgb_to_rgba(result['final_image'], front_mask + back_mask)
@@ -30,6 +31,24 @@ class Split(object):
front_mask = cv2.resize(front_mask, new_size) front_mask = cv2.resize(front_mask, new_size)
result_front_image[front_mask != 0] = rgba_image[front_mask != 0] result_front_image[front_mask != 0] = rgba_image[front_mask != 0]
result_front_image_pil = Image.fromarray(cvtColor(result_front_image, COLOR_BGR2RGBA)) result_front_image_pil = Image.fromarray(cvtColor(result_front_image, COLOR_BGR2RGBA))
if 'transparent' in result.keys():
# 用户自选区域transparent
transparent = result['transparent']
if transparent['mask_url'] is not None and transparent['mask_url'] != "":
# 预处理用户自选区mask
seg_mask = oss_get_image(oss_client=self.minio_client, bucket=transparent['mask_url'].split('/')[0], object_name=transparent['mask_url'][transparent['mask_url'].find('/') + 1:], data_type="cv2")
seg_mask = cv2.resize(seg_mask, new_size, interpolation=cv2.INTER_NEAREST)
# 转换颜色空间为 RGBOpenCV 默认是 BGR
image_rgb = cv2.cvtColor(seg_mask, cv2.COLOR_BGR2RGB)
r, g, b = cv2.split(image_rgb)
blue_mask = b > r
# 创建红色和绿色掩码
transparent_mask = np.array(blue_mask, dtype=np.uint8) * 255
result_front_image_pil = sketch_to_transparent(result_front_image_pil, transparent_mask, transparent["scale"])
else:
result_front_image_pil = sketch_to_transparent(result_front_image_pil, front_mask, transparent["scale"])
result['front_image'], result["front_image_url"], _ = upload_png_mask(self.minio_client, result_front_image_pil, f'{generate_uuid()}', mask=None) result['front_image'], result["front_image_url"], _ = upload_png_mask(self.minio_client, result_front_image_pil, f'{generate_uuid()}', mask=None)
height, width = front_mask.shape height, width = front_mask.shape

View File

@@ -55,6 +55,45 @@ def organize_clothing(layer):
return front_layer, back_layer return front_layer, back_layer
def organize_accessories(layer):
# 起始坐标
start_point = (0, 0)
# 前片数据
front_layer = dict(priority=layer['priority'] if layer.get("layer_order", False) else PRIORITY_DICT.get(f'{layer["name"].lower()}_front', None),
name=f'{layer["name"].lower()}_front',
image=layer["front_image"],
# mask_image=layer['front_mask_image'],
image_url=layer['front_image_url'],
mask_url=layer['mask_url'],
sacle=layer['scale'],
clothes_keypoint=(0, 0),
position=start_point,
resize_scale=layer["resize_scale"],
mask=cv2.resize(layer['mask'], layer["front_image"].size),
gradient_string=layer['gradient_string'] if 'gradient_string' in layer.keys() else "",
pattern_image_url=layer['pattern_image_url'],
pattern_image=layer['pattern_image'],
# back_perspective_url=layer['back_perspective_url'] if 'back_perspective_url' in layer.keys() else ""
)
# 后片数据
back_layer = dict(priority=-layer.get("priority", 0) if layer.get("layer_order", False) else PRIORITY_DICT.get(f'{layer["name"].lower()}_back', None),
name=f'{layer["name"].lower()}_back',
image=layer["back_image"],
# mask_image=layer['back_mask_image'],
image_url=layer['back_image_url'],
mask_url=layer['mask_url'],
sacle=layer['scale'],
clothes_keypoint=(0, 0),
position=start_point,
resize_scale=layer["resize_scale"],
mask=cv2.resize(layer['mask'], layer["front_image"].size),
gradient_string=layer['gradient_string'] if 'gradient_string' in layer.keys() else "",
pattern_image_url=layer['pattern_image_url'],
# back_perspective_url=layer['back_perspective_url'] if 'back_perspective_url' in layer.keys() else ""
)
return front_layer, back_layer
def calculate_start_point(keypoint_type, scale, clothes_point, body_point, offset, resize_scale): def calculate_start_point(keypoint_type, scale, clothes_point, body_point, offset, resize_scale):
""" """
Align left Align left

View File

@@ -79,9 +79,11 @@ def synthesis(data, size, basic_info):
_, binary_body_mask = cv2.threshold(body_mask, 127, 255, cv2.THRESH_BINARY) _, binary_body_mask = cv2.threshold(body_mask, 127, 255, cv2.THRESH_BINARY)
top_outer_mask = np.array(binary_body_mask) top_outer_mask = np.array(binary_body_mask)
bottom_outer_mask = np.array(binary_body_mask) bottom_outer_mask = np.array(binary_body_mask)
accessories_outer_mask = np.array(binary_body_mask)
top = True top = True
bottom = True bottom = True
accessories = True
i = len(data) i = len(data)
while i: while i:
i -= 1 i -= 1
@@ -109,10 +111,23 @@ def synthesis(data, size, basic_info):
background = np.zeros_like(top_outer_mask) background = np.zeros_like(top_outer_mask)
background[all_y_start:all_y_end, all_x_start:all_x_end] = sketch_mask[mask_y_start:mask_y_end, mask_x_start:mask_x_end] background[all_y_start:all_y_end, all_x_start:all_x_end] = sketch_mask[mask_y_start:mask_y_end, mask_x_start:mask_x_end]
bottom_outer_mask = background + bottom_outer_mask bottom_outer_mask = background + bottom_outer_mask
elif accessories and data[i]['name'] in ['accessories_front']:
mask_shape = data[i]['mask'].shape
y_offset, x_offset = data[i]['adaptive_position']
# 初始化叠加区域的起始和结束位置
all_y_start, all_y_end, mask_y_start, mask_y_end = positioning(all_mask_shape=all_mask_shape[0], mask_shape=mask_shape[0], offset=y_offset)
all_x_start, all_x_end, mask_x_start, mask_x_end = positioning(all_mask_shape=all_mask_shape[1], mask_shape=mask_shape[1], offset=x_offset)
# 将叠加区域赋值为相应的像素值
_, sketch_mask = cv2.threshold(data[i]['mask'], 127, 255, cv2.THRESH_BINARY)
background = np.zeros_like(top_outer_mask)
background[all_y_start:all_y_end, all_x_start:all_x_end] = sketch_mask[mask_y_start:mask_y_end, mask_x_start:mask_x_end]
accessories_outer_mask = background + accessories_outer_mask
pass
elif bottom is False and top is False: elif bottom is False and top is False:
break break
all_mask = cv2.bitwise_or(top_outer_mask, bottom_outer_mask) all_mask = cv2.bitwise_or(top_outer_mask, bottom_outer_mask)
all_mask = cv2.bitwise_or(all_mask, accessories_outer_mask)
for layer in data: for layer in data:
if layer['image'] is not None: if layer['image'] is not None:

View File

@@ -0,0 +1,26 @@
from PIL import Image
def sketch_to_transparent(image, mask, transparency):
# 打开原始图片
image = image.convert("RGBA")
# 打开mask图片假设mask图片是灰度图白色区域为要处理的区域黑色区域为保留的区域
mask = Image.fromarray(mask)
# 根据透明度调整因子将透明度转换为0-255之间的值
alpha_value = int((1 - transparency) * 255.0)
# 获取图片的像素数据
image_pixels = image.load()
mask_pixels = mask.load()
width, height = image.size
for y in range(height):
for x in range(width):
# 如果mask区域对应的像素为白色值大于128这里假设白色为要处理的区域可根据实际情况调整
if mask_pixels[x, y] > 128:
r, g, b, a = image_pixels[x, y]
image_pixels[x, y] = (r, g, b, alpha_value)
return image

View File

@@ -35,7 +35,12 @@ class GenerateImage:
# self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS)) # self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
# self.channel = self.connection.channel() # self.channel = self.connection.channel()
# self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) # self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
self.grpc_client = grpcclient.InferenceServerClient(url=GI_MODEL_URL) self.version = request_data.version
if request_data.version == "fast":
self.grpc_client = grpcclient.InferenceServerClient(url=FAST_GI_MODEL_URL)
else:
self.grpc_client = grpcclient.InferenceServerClient(url=GI_MODEL_URL)
self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True) self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
if request_data.mode == "img2img": if request_data.mode == "img2img":
# cv2 读图片是BGR PIL读图片是RGB # cv2 读图片是BGR PIL读图片是RGB
@@ -87,23 +92,28 @@ class GenerateImage:
image_result = cv2.cvtColor(np.squeeze(image.astype(np.uint8)), cv2.COLOR_RGB2BGR) image_result = cv2.cvtColor(np.squeeze(image.astype(np.uint8)), cv2.COLOR_RGB2BGR)
is_smudge = True is_smudge = True
if self.category == "sketch": if self.category == "sketch":
# 色阶调整 if self.version == "fast":
cutoff = 1 # 色阶调整
levels_img = autoLevels(image_result, cutoff) cutoff = 1
# 亮度调整 levels_img = autoLevels(image_result, cutoff)
luminance = luminance_adjust(0.3, levels_img) # 亮度调整
# 去背景 luminance = luminance_adjust(0.3, levels_img)
remove_bg_image = remove_background(luminance) # 去背景
# 人脸检测 remove_bg_image = remove_background(luminance)
# if face_detect_pic(remove_bg_image, self.user_id, self.category, self.tasks_id) > 0: # 人脸检测
# is_smudge = False # if face_detect_pic(remove_bg_image, self.user_id, self.category, self.tasks_id) > 0:
# else: # is_smudge = False
# 污点/ # else:
is_smudge, not_smudge_image = stain_detection(remove_bg_image, self.user_id, self.category, self.tasks_id) # 污点/
# 类型识别 is_smudge, not_smudge_image = stain_detection(remove_bg_image, self.user_id, self.category, self.tasks_id)
category, scores, not_smudge_image = generate_category_recognition(image=remove_bg_image, gender=self.gender) # 类型识别
self.generate_data['category'] = str(category) category, scores, not_smudge_image = generate_category_recognition(image=remove_bg_image, gender=self.gender)
image_result = not_smudge_image self.generate_data['category'] = str(category)
image_result = not_smudge_image
else:
category, scores, not_smudge_image = generate_category_recognition(image=image_result, gender=self.gender)
self.generate_data['category'] = str(category)
image_result = not_smudge_image
if is_smudge: # 无污点 if is_smudge: # 无污点
# image_result = adjust_contrast(image_result) # image_result = adjust_contrast(image_result)
image_url = upload_png_sd(image_result, user_id=self.user_id, category=f"{self.category}", file_name=f"{self.tasks_id}.png") image_url = upload_png_sd(image_result, user_id=self.user_id, category=f"{self.category}", file_name=f"{self.tasks_id}.png")
@@ -134,15 +144,19 @@ class GenerateImage:
image_obj = np.array(images, dtype=np.float16).reshape((-1, 1024, 1024, 3)) image_obj = np.array(images, dtype=np.float16).reshape((-1, 1024, 1024, 3))
input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype)) input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype))
input_image = grpcclient.InferInput("input_image", image_obj.shape, "FP16") input_image = grpcclient.InferInput("input_image", image_obj.shape, np_to_triton_dtype(image_obj.dtype))
input_mode = grpcclient.InferInput("mode", mode_obj.shape, np_to_triton_dtype(text_obj.dtype)) input_mode = grpcclient.InferInput("mode", mode_obj.shape, np_to_triton_dtype(mode_obj.dtype))
input_text.set_data_from_numpy(text_obj) input_text.set_data_from_numpy(text_obj)
input_image.set_data_from_numpy(image_obj) input_image.set_data_from_numpy(image_obj)
input_mode.set_data_from_numpy(mode_obj) input_mode.set_data_from_numpy(mode_obj)
inputs = [input_text, input_image, input_mode] inputs = [input_text, input_image, input_mode]
ctx = self.grpc_client.async_infer(model_name=GI_MODEL_NAME, inputs=inputs, callback=self.callback) if self.version == "fast":
ctx = self.grpc_client.async_infer(model_name=FAST_GI_MODEL_NAME, inputs=inputs, callback=self.callback)
else:
ctx = self.grpc_client.async_infer(model_name=GI_MODEL_NAME, inputs=inputs, callback=self.callback)
time_out = 600 time_out = 600
generate_data = None generate_data = None
while time_out > 0: while time_out > 0:
@@ -181,11 +195,12 @@ def infer_cancel(tasks_id):
if __name__ == '__main__': if __name__ == '__main__':
rd = GenerateImageModel( rd = GenerateImageModel(
tasks_id="123-89", tasks_id="123-89",
prompt='skeleton sitting by the side of a river looking soulful, concert poster, 4k, artistic', prompt='a single item of sketch of Wabi-sabi, skirt, tiered, 4k, white background',
image_url="aida-collection-element/87/Printboard/842c09cf-7297-42d9-9e6e-9c17d4a13cb5.jpg", image_url="aida-collection-element/87/Printboard/842c09cf-7297-42d9-9e6e-9c17d4a13cb5.jpg",
mode='txt2img', mode='txt2img',
category="test", category="test",
gender="male" gender="male",
version="high"
) )
server = GenerateImage(rd) server = GenerateImage(rd)
print(server.get_result()) print(server.get_result())

View File

@@ -15,7 +15,7 @@ import cv2
import numpy as np import numpy as np
import redis import redis
import tritonclient.grpc as grpcclient import tritonclient.grpc as grpcclient
from PIL import Image, ImageOps from PIL import Image
from tritonclient.utils import np_to_triton_dtype from tritonclient.utils import np_to_triton_dtype
from app.core.config import * from app.core.config import *
@@ -41,7 +41,7 @@ class GenerateProductImage:
self.batch_size = 1 self.batch_size = 1
self.product_type = request_data.product_type self.product_type = request_data.product_type
self.prompt = request_data.prompt self.prompt = request_data.prompt
self.image, self.image_size = pre_processing_image(request_data.image_url) self.image, self.image_size, self.left, self.top = pre_processing_image(request_data.image_url)
self.tasks_id = request_data.tasks_id self.tasks_id = request_data.tasks_id
self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:] self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:]
self.gen_product_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'image_url': ''} self.gen_product_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'image_url': ''}
@@ -55,12 +55,10 @@ class GenerateProductImage:
self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data)) self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data))
else: else:
# pil图像转成numpy数组 # pil图像转成numpy数组
if self.product_type == "single": image = result.as_numpy("generated_inpaint_image")
image = result.as_numpy("generated_cnet_image")
else:
image = result.as_numpy("generated_inpaint_image")
image_result = Image.fromarray(np.squeeze(image.astype(np.uint8))).resize(self.image_size) image_result = Image.fromarray(np.squeeze(image.astype(np.uint8))).resize(self.image_size)
image_url = upload_SDXL_image(image_result, user_id=self.user_id, category=f"{self.category}", file_name=f"{self.tasks_id}.png") cropped_image = post_processing_image(image_result, self.left, self.top)
image_url = upload_SDXL_image(cropped_image, user_id=self.user_id, category=f"{self.category}", file_name=f"{self.tasks_id}.png")
self.gen_product_data['status'] = "SUCCESS" self.gen_product_data['status'] = "SUCCESS"
self.gen_product_data['message'] = "success" self.gen_product_data['message'] = "success"
self.gen_product_data['image_url'] = str(image_url) self.gen_product_data['image_url'] = str(image_url)
@@ -74,16 +72,16 @@ class GenerateProductImage:
try: try:
prompts = [self.prompt] * self.batch_size prompts = [self.prompt] * self.batch_size
self.image = cv2.cvtColor(self.image, cv2.COLOR_BGR2RGB) self.image = cv2.cvtColor(self.image, cv2.COLOR_BGR2RGB)
self.image = cv2.resize(self.image, (512, 768)) self.image = cv2.resize(self.image, (1024, 1024))
images = [self.image.astype(np.uint8)] * self.batch_size images = [self.image.astype(np.uint8)] * self.batch_size
if self.product_type == "single": if self.product_type == "single":
text_obj = np.array(prompts, dtype="object").reshape(-1, 1) text_obj = np.array(prompts, dtype="object").reshape(-1, 1)
image_obj = np.array(images, dtype=np.uint8).reshape((-1, 768, 512, 3)) image_obj = np.array(images, dtype=np.uint8).reshape((-1, 1024, 1024, 3))
image_strength_obj = np.array(self.image_strength, dtype=np.float32).reshape(-1, 1) image_strength_obj = np.array(self.image_strength, dtype=np.float32).reshape(-1, 1)
else: else:
text_obj = np.array(prompts, dtype="object").reshape(1) text_obj = np.array(prompts, dtype="object").reshape((1))
image_obj = np.array(images, dtype=np.uint8).reshape((768, 512, 3)) image_obj = np.array(images, dtype=np.uint8).reshape((1024, 1024, 3))
image_strength_obj = np.array(self.image_strength, dtype=np.float32).reshape((1)) image_strength_obj = np.array(self.image_strength, dtype=np.float32).reshape((1))
# 假设 prompts、images 和 self.image_strength 已经定义 # 假设 prompts、images 和 self.image_strength 已经定义
@@ -94,11 +92,12 @@ class GenerateProductImage:
input_text.set_data_from_numpy(text_obj) input_text.set_data_from_numpy(text_obj)
input_image.set_data_from_numpy(image_obj) input_image.set_data_from_numpy(image_obj)
inputs = [input_text, input_image, input_image_strength]
input_image_strength.set_data_from_numpy(image_strength_obj) input_image_strength.set_data_from_numpy(image_strength_obj)
inputs = [input_text, input_image, input_image_strength]
if self.product_type == "single": if self.product_type == "single":
ctx = self.grpc_client.async_infer(model_name=GPI_MODEL_NAME_SINGLE, inputs=inputs, callback=self.callback) ctx = self.grpc_client.async_infer(model_name="stable_diffusion_xl_cnet_inpaint", inputs=inputs, callback=self.callback)
else: else:
ctx = self.grpc_client.async_infer(model_name=GPI_MODEL_NAME_OVERALL, inputs=inputs, callback=self.callback) ctx = self.grpc_client.async_infer(model_name=GPI_MODEL_NAME_OVERALL, inputs=inputs, callback=self.callback)
@@ -136,22 +135,13 @@ def infer_cancel(tasks_id):
def pre_processing_image(image_url): def pre_processing_image(image_url):
image = oss_get_image(bucket=image_url.split('/')[0], object_name=image_url[image_url.find('/') + 1:], data_type="PIL") image = oss_get_image(bucket=image_url.split('/')[0], object_name=image_url[image_url.find('/') + 1:], data_type="PIL")
# resize 原图至1024*1024
image = image.resize((int(1024 / image.height * image.width), 1024))
# 原始图片的尺寸 # 原始图片的尺寸
width, height = image.size width, height = image.size
# 计算长宽比为 3:2 的新尺寸 new_height, new_width = 1024, 1024
desired_ratio = 2 / 3
current_ratio = width / height
if current_ratio > desired_ratio:
# 原始图片更宽,需要在上下添加 padding
new_width = width
new_height = int(width / desired_ratio)
else:
# 原始图片更高或者长宽比已经为 3:2
new_height = height
new_width = int(height * desired_ratio)
# 创建一个新的画布,大小为添加 padding 后的尺寸,并设置为白色背景 # 创建一个新的画布,大小为添加 padding 后的尺寸,并设置为白色背景
pad_image = Image.new('RGBA', (new_width, new_height), (0, 0, 0, 0)) pad_image = Image.new('RGBA', (new_width, new_height), (0, 0, 0, 0))
@@ -160,9 +150,9 @@ def pre_processing_image(image_url):
top = (new_height - height) // 2 top = (new_height - height) // 2
pad_image.paste(image, (left, top)) pad_image.paste(image, (left, top))
# 将画布 resize 成宽度 500长度 750 # 将画布 resize 成宽度 1024长度 1024
resized_image = pad_image.resize((500, 750)) resized_image = pad_image.resize((1024, 1024))
image_size = (512, 768) image_size = (1024, 1024)
if resized_image.mode in ('RGBA', 'LA') or (resized_image.mode == 'P' and 'transparency' in resized_image.info): if resized_image.mode in ('RGBA', 'LA') or (resized_image.mode == 'P' and 'transparency' in resized_image.info):
# 创建白色背景 # 创建白色背景
@@ -171,16 +161,29 @@ def pre_processing_image(image_url):
background.paste(resized_image, mask=resized_image.split()[3]) background.paste(resized_image, mask=resized_image.split()[3])
image = np.array(background) image = np.array(background)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
return image, image_size return image, image_size, left, top
def post_processing_image(image, left, top):
resized_image = image.resize((int(image.width * (768 / image.height)), 768))
# 计算裁剪的坐标
left = (resized_image.width - 512) // 2
upper = 0
right = left + 512
lower = 768
# 进行裁剪
cropped_image = resized_image.crop((left, upper, right, lower))
return cropped_image
if __name__ == '__main__': if __name__ == '__main__':
rd = GenerateProductImageModel( rd = GenerateProductImageModel(
tasks_id="123-89", tasks_id="123-89",
# prompt="", # prompt="",
image_strength=0.9, image_strength=0.7,
prompt=" the best quality, masterpiece. detailed, high-res, simple background, studio photography, extremely detailed, updo, detailed face, face, close-up, HDR, UHD, 8K realistic, Highly detailed, simple background, Studio lighting", prompt="The best quality, masterpiece,outwear, 8K realistic, HUD",
image_url="aida-results/result_00097282-ebb2-11ee-a822-b48351119060.png", image_url="aida-results/result_53381ada-ac64-11ef-ae9d-0242ac150002.png",
product_type="overall" product_type="overall"
) )
server = GenerateProductImage(rd) server = GenerateProductImage(rd)

View File

@@ -81,7 +81,7 @@ def get_contours(image):
def seg_infer_image(image_obj): def seg_infer_image(image_obj):
image, ori_shape = seg_preprocess(image_obj) image, ori_shape = seg_preprocess(image_obj)
client = httpclient.InferenceServerClient(url=f"{SEG_MODEL_URL}") client = httpclient.InferenceServerClient(url=f"{DESIGN_MODEL_URL}")
transformed_img = image.astype(np.float32) transformed_img = image.astype(np.float32)
# 输入集 # 输入集
inputs = [ inputs = [
@@ -250,7 +250,7 @@ def generate_category_recognition(image, gender):
return preprocessed_img return preprocessed_img
preprocessed_img = preprocess(image) preprocessed_img = preprocess(image)
triton_client = httpclient.InferenceServerClient(url=ATT_TRITON_URL) triton_client = httpclient.InferenceServerClient(url=DESIGN_MODEL_URL)
inputs = [ inputs = [
httpclient.InferInput("input__0", preprocessed_img.shape, datatype="FP32") httpclient.InferInput("input__0", preprocessed_img.shape, datatype="FP32")

View File

@@ -6,6 +6,8 @@ from chromadb.config import Settings
from chromadb.utils.embedding_functions.ollama_embedding_function import OllamaEmbeddingFunction from chromadb.utils.embedding_functions.ollama_embedding_function import OllamaEmbeddingFunction
from tqdm import tqdm from tqdm import tqdm
from app.core.config import OLLAMA_URL
# 读取 csv 文件 # 读取 csv 文件
# csv_file_path = r'D:/Files/csv/output/output.csv' # csv_file_path = r'D:/Files/csv/output/output.csv'
# image_path = r'D:/images-clean' # image_path = r'D:/images-clean'
@@ -18,7 +20,7 @@ client = chromadb.Client(Settings(is_persistent=True, persist_directory="/vector
# client = chromadb.Client(Settings(is_persistent=True, persist_directory="D:/workspace/AiDLab/vector_db")) # client = chromadb.Client(Settings(is_persistent=True, persist_directory="D:/workspace/AiDLab/vector_db"))
# 创建集合 # 创建集合
# embedding_fn = OllamaEmbeddingFunction(url="http://localhost:11434/api/embeddings", model_name="mxbai-embed-large") # embedding_fn = OllamaEmbeddingFunction(url="http://localhost:11434/api/embeddings", model_name="mxbai-embed-large")
embedding_fn = OllamaEmbeddingFunction(url="http://10.1.1.240:11434/api/embeddings", model_name="mxbai-embed-large") embedding_fn = OllamaEmbeddingFunction(url=OLLAMA_URL, model_name="mxbai-embed-large")
# def create_collection(): # def create_collection():

View File

@@ -82,9 +82,10 @@ if __name__ == '__main__':
# url = "aida-users/89/sketchboard/female/Dress/e6724ab7-8d3f-4677-abe0-c3e42ab7af85.jpeg" # url = "aida-users/89/sketchboard/female/Dress/e6724ab7-8d3f-4677-abe0-c3e42ab7af85.jpeg"
# url = "aida-users/87/print/956614a2-7e75-4fbe-9ed0-c1831e37a2c9-4-87.png" # url = "aida-users/87/print/956614a2-7e75-4fbe-9ed0-c1831e37a2c9-4-87.png"
# url = "aida-users/89/single_logo/123-89.png" # url = "aida-users/89/single_logo/123-89.png"
url = "aida-results/result_e961eed6-9278-11ef-a957-0826ae3ad6b3.png" url = "aida-users/89/test/123-89.png"
# url = "aida-collection-element/12148/Sketchboard/95ea577b-305b-4a62-b30a-39c0dd3ddb3f.png" # url = "aida-collection-element/12148/Sketchboard/95ea577b-305b-4a62-b30a-39c0dd3ddb3f.png"
read_type = "cv2" read_type = "2"
if read_type == "cv2": if read_type == "cv2":
img = oss_get_image(oss_client=minio_client, bucket=url.split('/')[0], object_name=url[url.find('/') + 1:], data_type=read_type) img = oss_get_image(oss_client=minio_client, bucket=url.split('/')[0], object_name=url[url.find('/') + 1:], data_type=read_type)
cv2.imshow("", img) cv2.imshow("", img)