Merge remote-tracking branch 'refs/remotes/origin/develop'

zhouchengrong
2024-07-09 09:38:36 +08:00
77 changed files with 6095 additions and 205 deletions

@@ -1,9 +1,12 @@
import json
import logging
from fastapi import APIRouter
from fastapi import APIRouter, HTTPException
from app.core.config import DEBUG
from app.schemas.attribute_retrieve import *
from app.service.attribute.config import const
from app.schemas.response_template import ResponseModel
from app.service.attribute.config import const, local_debug_const
from app.service.attribute.service_att_recognition import AttributeRecognition
from app.service.attribute.service_category_recognition import CategoryRecognition
@@ -12,34 +15,63 @@ logger = logging.getLogger()
# Attribute recognition
@router.post("/attribute_recognition")
@router.post("/attribute_recognition", response_model=ResponseModel)
def attribute_recognition(request_item: list[AttributeRecognitionModel]):
"""
Get the attributes of a sketch: collar, sleeve_length, etc.
Create a request body with the following parameters:
- **category**: category of the sketch, e.g. Dress
- **colony**: clothing group, menswear or womenswear
- **sketch_img_url**: S3 or MinIO URL of the image whose attributes are extracted
Example parameters:
[
{
"category": "Dress",
"colony": "Female",
"sketch_img_url": "aida-users/89/sketchboard/female/Dress/ae976103-d7ec-4eed-b5d1-3e5f04d8be26.jpg"
}
]
"""
try:
service = AttributeRecognition(const=const, request_data=request_item)
for item in request_item:
logger.info(f"attribute_recognition request item is : @@@@@@:{json.dumps(item.dict())}")
if DEBUG:
service = AttributeRecognition(const=local_debug_const, request_data=request_item)
else:
service = AttributeRecognition(const=const, request_data=request_item)
data = service.get_result()
code = 200
message = "access"
logger.info(f"attribute_recognition response @@@@@@:{json.dumps(data, indent=4)}")
logger.info(f"attribute_recognition response @@@@@@:{json.dumps(data)}")
except Exception as e:
code = 400
message = e
data = e
logger.warning(f"attribute_recognition Run Exception @@@@@@:{e}")
return {"code": code, "message": message, "data": data}
raise HTTPException(status_code=404, detail=str(e))
return ResponseModel(data={"list": data})
# Category recognition
@router.post("/category_recognition")
def category_recognition(request_item: list[CategoryRecognitionModel]):
"""
Get the category of a sketch: dress, blouse, etc.
Create a request body with the following parameters:
- **colony**: clothing group, Male or Female
- **sketch_img_url**: S3 or MinIO URL of the sketch whose category is extracted
Example parameters:
[
{
"colony": "Female",
"sketch_img_url": "aida-users/89/sketchboard/female/Dress/ae976103-d7ec-4eed-b5d1-3e5f04d8be26.jpg"
}
]
"""
try:
for item in request_item:
logger.info(f"category_recognition request item is : @@@@@@:{json.dumps(item.dict())}")
service = CategoryRecognition(request_data=request_item)
data = service.get_result()
code = 200
message = "access"
logger.info(f"category_recognition response @@@@@@:{json.dumps(data, indent=4)}")
logger.info(f"category_recognition response @@@@@@:{json.dumps(data)}")
except Exception as e:
code = 400
message = e
data = e
logger.warning(f"category_recognition Run Exception @@@@@@:{e}")
return {"code": code, "message": message, "data": data}
raise HTTPException(status_code=404, detail=str(e))
return ResponseModel(data=data)
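For reference, a minimal client-side call to the attribute endpoint above might look like the sketch below; the base URL is an assumption derived from the uvicorn setup later in this commit (host 0.0.0.0, port 8000, router prefix /api).

# Hypothetical client sketch; base URL and timeout are assumptions.
import requests

payload = [{
    "category": "Dress",
    "colony": "Female",
    "sketch_img_url": "aida-users/89/sketchboard/female/Dress/ae976103-d7ec-4eed-b5d1-3e5f04d8be26.jpg",
}]
resp = requests.post("http://localhost:8000/api/attribute_recognition", json=payload, timeout=60)
resp.raise_for_status()
print(resp.json())  # envelope shape: {"code": 200, "msg": "OK!", "data": {"list": [...]}}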

app/api/api_chat_robot.py Normal file
@@ -0,0 +1,39 @@
import json
import logging
from fastapi import APIRouter, HTTPException
from app.schemas.chat_robot import ChatRobotModel
from app.schemas.response_template import ResponseModel
from app.service.chat_robot.script.main import chat
router = APIRouter()
logger = logging.getLogger()
@router.post("/chat_robot")
def chat_robot(request_data: ChatRobotModel):
"""
Chat robot.
Create a request body with the following parameters:
- **gender**: clothing group
- **message**: message text
- **session_id**: session ID
- **user_id**: user ID
Example parameters:
{
"gender": "male",
"message": "你好",
"session_id": "string-89",
"user_id": 89
}
"""
try:
logger.info(f"chat_robot request item is : @@@@@@:{json.dumps(request_data.dict())}")
data = chat(post_data=request_data)
logger.info(f"chat_robot response @@@@@@:{json.dumps(data)}")
except Exception as e:
logger.warning(f"chat_robot Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))
return ResponseModel(data=data)
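A hedged client sketch for this endpoint; the response fields are the ones assembled by chat() in app/service/chat_robot/script/main.py further down, and the base URL is an assumption.

# Hypothetical call to /api/chat_robot; base URL assumed.
import requests

body = {"gender": "male", "message": "你好", "session_id": "string-89", "user_id": 89}
r = requests.post("http://localhost:8000/api/chat_robot", json=body, timeout=120)
data = r.json()["data"]
print(data["output"], data["total_tokens"])  # keys set in script/main.py's chat()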

app/api/api_design.py Normal file
@@ -0,0 +1,195 @@
import json
import logging
from fastapi import APIRouter, HTTPException
from app.schemas.design import DesignModel, DesignProgressModel, ModelProgressModel
from app.schemas.response_template import ResponseModel
from app.service.design.model_process_service import model_transpose
from app.service.design.service import generate
from app.service.design.utils.redis_utils import Redis
router = APIRouter()
logger = logging.getLogger()
@router.post("/design")
def design(request_data: DesignModel):
"""
Create a request body with the following parameters:
Example parameters:
{
"objects": [
{
"basic": {
"body_point_test": {
"waistband_right": [
203,
249
],
"hand_point_right": [
229,
343
],
"waistband_left": [
119,
248
],
"hand_point_left": [
97,
343
],
"shoulder_left": [
108,
107
],
"shoulder_right": [
212,
107
]
},
"layer_order": true,
"scale_bag": 0.7,
"scale_earrings": 0.16,
"self_template": true,
"single_overall": "overall",
"switch_category": ""
},
"items": [
{
"businessId": 255303,
"color": "139 148 156",
"image_id": 95159,
"offset": [
0,
0
],
"path": "aida-users/89/sketch/c89d75f3-581f-4edd-9f8e-b08e84a2cbe7-3-89.png",
"print": {
"single": {
"location": [
[
200.0,
200.0
]
],
"print_angle_list": [
0.0
],
"print_path_list": [
"aida-users/89/slogan_image/ce0b2423-9e5a-466f-9611-c254940a7819-1-89.png"
],
"print_scale_list": [
1.0
]
},
"overall": {
"location": [
[
512.0,
512.0
]
],
"print_angle_list": [
0.0
],
"print_path_list": [
"aida-users/89/print/468643b4-bc2d-41b2-9a16-79766606a2db-3-89.png"
],
"print_scale_list": [
1.0
]
},
"element": {
"element_angle_list": [
0.0
],
"element_path_list": [
"aida-users/88/designelements/Embroidery/a4d9605a-675e-4606-93e0-77ca6baaf55f.png"
],
"element_scale_list": [
0.2731036750637755
],
"location": [
[
228.63694825464364,
406.4843844199667
]
]
}
},
"priority": 10,
"resize_scale": [
1.0,
1.0
],
"type": "Dress"
},
{
"body_path": "aida-sys-image/models/female/2e4815b9-1191-419d-94ed-5771239ca4a5.png",
"image_id": 67277,
"type": "Body"
}
]
}
],
"process_id": "89"
}
"""
try:
logger.info(f"design request item is : @@@@@@:{json.dumps(request_data.dict())}")
data = generate(request_data=request_data)
logger.info(f"design response @@@@@@:{json.dumps(data)}")
except Exception as e:
logger.warning(f"design Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))
return ResponseModel(data=data)
@router.post('/get_progress')
def get_progress(request_data: DesignProgressModel):
"""
Get the design progress.
Create a request body with the following parameters:
- **process_id**: progress ID
Example parameters:
{
"process_id": "6878547032381675"
}
"""
try:
logger.info(f"get_progress request item is : @@@@@@:{json.dumps(request_data.dict())}")
process_id = request_data.process_id
r = Redis()
data = r.read(key=process_id)
if data is None:
raise ValueError(f"No progress ID: {process_id}")
logging.info(f"get_progress process_id @@@@@@ : {process_id} , progress : {json.dumps(data)}")
except Exception as e:
logger.warning(f"get_progress Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))
return ResponseModel(data=data)
@router.post('/model_process')
def model_process(request_data: ModelProgressModel):
"""
Preprocess the model image.
Create a request body with the following parameters:
- **model_path**: MinIO or S3 URL of the model image
Example parameters:
{
"model_path": "aida-users/10/models/female/9c788f5b-b8c7-424c-b149-025aeb0bda51model.jpg"
}
"""
try:
logger.info(f"model_process request item is : @@@@@@:{json.dumps(request_data.dict())}")
data = model_transpose(image_path=request_data.model_path)
logger.info(f"model_process response @@@@@@:{json.dumps(data)}")
except Exception as e:
logger.warning(f"model_process Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))
return ResponseModel(data=data)
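Because /design runs asynchronously and reports progress through Redis, a client submits the job and then polls /get_progress with the same process_id. A minimal polling sketch, assuming the base URL; the shape of the progress payload is not shown in this diff, so the caller decides when to stop iterating.

# Hedged polling sketch for /api/get_progress; base URL and interval are assumptions.
import time
import requests

def poll_design_progress(process_id: str, base="http://localhost:8000/api", interval=2.0):
    """Yield successive progress payloads for a running design job."""
    while True:
        r = requests.post(f"{base}/get_progress", json={"process_id": process_id})
        r.raise_for_status()
        yield r.json()["data"]
        time.sleep(interval)

A caller iterates over this generator and breaks once the payload signals completion.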

@@ -0,0 +1,40 @@
import json
import logging
from fastapi import APIRouter, HTTPException
from app.schemas.pre_processing import DesignPreProcessingModel
from app.schemas.response_template import ResponseModel
from app.service.design_pre_processing.service import DesignPreprocessing
router = APIRouter()
logger = logging.getLogger()
@router.post("/design_pre_processing")
def design_pre_processing(request_data: DesignPreProcessingModel):
"""
Design preprocessing: get basic information about the sketches.
Create a request body with the following parameters:
- **sketches**: sketch URLs and related info
Example parameters:
{
"sketches": [
{
"image_category": "dress",
"image_id": "107903",
"image_url": "aida-sys-image/images/female/dress/0628000000.jpg"
}
]
}
"""
try:
logger.info(f"design_pre_processing request item is : @@@@@@:{json.dumps(request_data.dict())}")
server = DesignPreprocessing()
data = server.pipeline(image_list=request_data.sketches)
logger.info(f"design response @@@@@@:{json.dumps(data)}")
except Exception as e:
logger.warning(f"design Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))
return ResponseModel(data=data)

@@ -1,28 +1,189 @@
import json
import logging
from fastapi import APIRouter, BackgroundTasks
from app.schemas.generate_image import GenerateImageModel
from app.service.generate_image.service import GenerateImage, infer_cancel
from fastapi import APIRouter, BackgroundTasks, HTTPException
from app.schemas.generate_image import GenerateImageModel, GenerateProductImageModel, GenerateSingleLogoImageModel, GenerateRelightImageModel
from app.schemas.response_template import ResponseModel
from app.service.generate_image.service_generate_image import GenerateImage, infer_cancel as generate_image_infer_cancel
from app.service.generate_image.service_generate_product_image import GenerateProductImage, infer_cancel as generate_product_image_cancel
from app.service.generate_image.service_generate_relight_image import GenerateRelightImage, infer_cancel as generate_relight_image_cancel
from app.service.generate_image.service_generate_single_logo import GenerateSingleLogoImage, infer_cancel as generate_single_logo_cancel
router = APIRouter()
logger = logging.getLogger()
'''generate image'''
@router.post("/generate_image")
def generate_image(request_item: GenerateImageModel, background_tasks: BackgroundTasks):
"""
Create a request body with the following parameters:
- **tasks_id**: task ID, used to cancel the generation task and fetch the result
- **prompt**: description of the image to generate
- **image_url**: MinIO or S3 URL of the img2img input image
- **mode**: generation mode, img2img or txt2img
- **category**: category of the generated image: sketch, print, etc.
- **gender**: clothing group, used when generating sketches
Example parameters:
{
"tasks_id": "123-89",
"prompt": "skeleton sitting by the side of a river looking soulful, concert poster, 4k, artistic",
"image_url": "aida-collection-element/87/Printboard/842c09cf-7297-42d9-9e6e-9c17d4a13cb5.jpg",
"mode": "img2img",
"category": "sketch",
"gender": "male"
}
"""
try:
logger.info(f"request data ### : {request_item}")
logger.info(f"generate_image request item is : @@@@@@:{json.dumps(request_item.dict())}")
service = GenerateImage(request_item)
background_tasks.add_task(service.get_result)
code = 200
message = "access"
except Exception as e:
code = 400
message = e
logger.warning(e)
return {"code": code, "message": message}
logger.warning(f"generate_image Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))
return ResponseModel()
@router.get("/generate_cancel/{tasks_id}>")
def generate_image(tasks_id):
result = infer_cancel(tasks_id)
return {"code": 200, "message": result['message'], "data": result['data']}
@router.get("/generate_cancel/{tasks_id}")
def generate_image_cancel(tasks_id: str):
try:
logger.info(f"generate_cancel request item is : @@@@@@:{tasks_id}")
data = generate_image_infer_cancel(tasks_id)
logger.info(f"generate_cancel response @@@@@@:{data}")
except Exception as e:
logger.warning(f"generate_cancel Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))
return ResponseModel(data=data['data'])
'''single logo'''
@router.post("/generate_single_logo")
def generate_single_logo(request_item: GenerateSingleLogoImageModel, background_tasks: BackgroundTasks):
"""
Create a request body with the following parameters:
- **tasks_id**: task ID, used to cancel the generation task and fetch the result
- **prompt**: description of the image to generate
- **seed**: with a fixed prompt and a fixed seed, every run produces the same result
Example parameters:
{
"tasks_id": "123-89",
"prompt": "an apple",
"seed": "2"
}
"""
try:
logger.info(f"generate_single_logo request item is : @@@@@@:{json.dumps(request_item.dict())}")
service = GenerateSingleLogoImage(request_item)
background_tasks.add_task(service.get_result)
except Exception as e:
logger.warning(f"generate_single_logo Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))
return ResponseModel()
@router.get("/generate_single_logo_cancel/{tasks_id}")
def generate_single_logo_image(tasks_id: str):
try:
logger.info(f"generate_single_logo_cancel request item is : @@@@@@:{tasks_id}")
data = generate_single_logo_cancel(tasks_id)
logger.info(f"generate_single_logo_cancel response @@@@@@:{data}")
except Exception as e:
logger.warning(f"generate_single_logo_cancel Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))
return ResponseModel(data=data['data'])
'''product image'''
@router.post("/generate_product_image")
def generate_product_image(request_item: GenerateProductImageModel, background_tasks: BackgroundTasks):
"""
Create a request body with the following parameters:
- **tasks_id**: task ID, used to cancel the generation task and fetch the result
- **prompt**: description of the image to generate
- **image_url**: S3 or MinIO URL of the source image
- **image_strength**: generation strength; lower values stay closer to the original image
- **product_type**: whether the input is a single item or an overall item
Example parameters:
{
"tasks_id": "123-89",
"prompt": "the best quality, masterpiece. detailed, high-res, simple background, studio photography, extremely detailed, updo, detailed face, face, close-up, HDR, UHD, 8K realistic, Highly detailed, simple background, Studio lighting",
"image_url": "aida-results/result_00097282-ebb2-11ee-a822-b48351119060.png",
"image_strength": 0.8,
"product_type": "overall"
}
"""
try:
logger.info(f"generate_product_image request item is : @@@@@@:{json.dumps(request_item.dict())}")
service = GenerateProductImage(request_item)
background_tasks.add_task(service.get_result)
except Exception as e:
logger.warning(f"generate_product_image Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))
return ResponseModel()
@router.get("/generate_product_image_cancel_cancel/{tasks_id}")
def generate_product_image(tasks_id: str):
try:
logger.info(f"generate_product_image_cancel_cancel request item is : @@@@@@:{tasks_id}")
data = generate_product_image_cancel(tasks_id)
logger.info(f"generate_product_image_cancel_cancel response @@@@@@:{data}")
except Exception as e:
logger.warning(f"generate_product_image_cancel_cancel Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))
return ResponseModel(data=data['data'])
'''relight image'''
@router.post("/generate_relight_image")
def generate_relight_image(request_item: GenerateRelightImageModel, background_tasks: BackgroundTasks):
"""
Create a request body with the following parameters:
- **tasks_id**: task ID, used to cancel the generation task and fetch the result
- **prompt**: description of the image to generate
- **image_url**: S3 or MinIO URL of the source image
- **direction**: light direction: Right Light, Left Light, Top Light, or Bottom Light
- **product_type**: whether the input is a single item or an overall item
Example parameters:
{
"tasks_id": "123-89",
"prompt": "beautiful woman, detailed face, sunshine, outdoor, warm atmosphere",
"image_url": "aida-results/result_0000b606-1902-11ef-9424-0242ac180002.png",
"direction": "Right Light",
"product_type": "overall"
}
"""
try:
logger.info(f"generate_relight_image request item is : @@@@@@:{json.dumps(request_item.dict())}")
service = GenerateRelightImage(request_item)
background_tasks.add_task(service.get_result)
except Exception as e:
logger.warning(f"generate_relight_image Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))
return ResponseModel()
@router.get("/generate_relight_image_cancel_cancel/{tasks_id}")
def generate_relight_image(tasks_id: str):
try:
logger.info(f"generate_relight_image_cancel_cancel request item is : @@@@@@:{tasks_id}")
data = generate_relight_image_cancel(tasks_id)
logger.info(f"generate_relight_image_cancel_cancel response @@@@@@:{data}")
except Exception as e:
logger.warning(f"generate_relight_image_cancel_cancel Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))
return ResponseModel(data=data['data'])
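All four generation endpoints in this file share one fire-and-forget pattern: the POST handler only schedules service.get_result on FastAPI's BackgroundTasks and returns an empty ResponseModel immediately, results are delivered out of band (via the RabbitMQ queues configured in app/core/config.py), and each GET .../{tasks_id} route cancels by task ID. A stripped-down reproduction of the pattern, not the project's actual services:

# Minimal sketch of the submit-then-cancel pattern; the dict is a stand-in for the real queue/Redis bookkeeping.
from fastapi import BackgroundTasks, FastAPI

app = FastAPI()
tasks_state: dict[str, str] = {}

def run_generation(tasks_id: str) -> None:
    if tasks_state.get(tasks_id) != "cancelled":
        tasks_state[tasks_id] = "done"  # the real services publish results to a queue instead

@app.post("/generate")
def generate(tasks_id: str, background_tasks: BackgroundTasks):
    tasks_state[tasks_id] = "pending"
    background_tasks.add_task(run_generation, tasks_id)
    return {"code": 200, "msg": "OK!"}  # returned before the background work runs

@app.get("/generate_cancel/{tasks_id}")
def cancel(tasks_id: str):
    tasks_state[tasks_id] = "cancelled"
    return {"code": 200, "msg": "OK!", "data": tasks_id}

BackgroundTasks runs the scheduled callable only after the response has been sent, which is why the POST can return before any GPU work starts.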

@@ -0,0 +1,34 @@
import json
import logging
import time
from fastapi import APIRouter, HTTPException
from app.schemas.prompt_generation import PromptGenerationImageModel
from app.schemas.response_template import ResponseModel
from app.service.prompt_generation.chatgpt_for_translation import translate_to_en
router = APIRouter()
logger = logging.getLogger()
@router.post("/translateToEN")
def prompt_generation(request_data: PromptGenerationImageModel):
"""
Prompt translation endpoint.
Create a request body with the following parameters:
- **text**: text to translate
Example parameters:
{
"text": "你好"
}
"""
try:
logger.info(f"prompt_generation request item is : @@@@@@:{request_data}")
data = translate_to_en(request_data.text)
logger.info(f"prompt_generation response @@@@@@:{data}")
except Exception as e:
logger.warning(f"prompt_generation Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))
return ResponseModel(data=data)

@@ -4,6 +4,11 @@ from app.api import api_test
from app.api import api_super_resolution
from app.api import api_generate_image
from app.api import api_attribute_retrieve
from app.api import api_design
from app.api import api_chat_robot
from app.api import api_prompt_generation
from app.api import api_design_pre_processing
router = APIRouter()
@@ -11,3 +16,7 @@ router.include_router(api_test.router, tags=["test"], prefix="/test")
router.include_router(api_super_resolution.router, tags=["super_resolution"], prefix="/api")
router.include_router(api_generate_image.router, tags=["generate_image"], prefix="/api")
router.include_router(api_attribute_retrieve.router, tags=["attribute_retrieve"], prefix="/api")
router.include_router(api_design.router, tags=['design'], prefix="/api")
router.include_router(api_chat_robot.router, tags=['chat_robot'], prefix="/api")
router.include_router(api_prompt_generation.router, tags=['prompt_generation'], prefix="/api")
router.include_router(api_design_pre_processing.router, tags=['design_pre_processing'], prefix="/api")

@@ -1,8 +1,9 @@
import json
import logging
from fastapi import APIRouter, BackgroundTasks
from fastapi import APIRouter, BackgroundTasks, HTTPException
from app.schemas.response_template import ResponseModel
from app.schemas.super_resolution import SuperResolutionModel
from app.service.super_resolution.service import SuperResolution, infer_cancel
@@ -12,19 +13,36 @@ logger = logging.getLogger()
@router.post("/super_resolution")
def super_resolution(request_item: SuperResolutionModel, background_tasks: BackgroundTasks):
"""
Create a request body with the following parameters:
- **sr_image_url**: MinIO or S3 URL of the image to upscale
- **sr_xn**: upscaling factor; only 2 or 4 is accepted
- **sr_tasks_id**: task ID, used to cancel the upscaling task and fetch the result
Example parameters:
{
"sr_image_url": "aida-sys-image/images/female/blouse/0628000098.jpg",
"sr_xn": 2,
"sr_tasks_id": "12341556-89"
}
"""
try:
logger.info(f"super_resolution request item is : @@@@@@:{json.dumps(request_item.dict())}")
service = SuperResolution(request_item)
background_tasks.add_task(service.sr_result)
code = 200
message = "access"
except Exception as e:
code = 400
message = e
logger.warning(e)
return {"code": code, "message": message}
logger.warning(f"super_resolution Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))
return ResponseModel()
@router.get("/sr_cancel/{tasks_id}>")
def super_resolution(tasks_id):
result = infer_cancel(tasks_id)
return {"code": 200, "message": result['message'], "data": result['data']}
def super_resolution_cancel(tasks_id: str):
try:
logger.info(f"sr_cancel request item is : @@@@@@:{tasks_id}")
data = infer_cancel(tasks_id)
logger.info(f"sr_cancel response @@@@@@:{data}")
except Exception as e:
logger.warning(f"sr_cancel Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))
return ResponseModel(data=data['data'])
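The docstring states that sr_xn only accepts 2 or 4, but nothing shown in this diff enforces it. A hedged schema-level guard, assuming the field set from the docstring (the real SuperResolutionModel in app/schemas/super_resolution.py is not shown here):

# Assumed fields; the validator is illustrative (pydantic v1 style, matching this codebase).
from pydantic import BaseModel, validator

class SuperResolutionModel(BaseModel):
    sr_image_url: str
    sr_xn: int
    sr_tasks_id: str

    @validator("sr_xn")
    def sr_xn_must_be_2_or_4(cls, v: int) -> int:
        if v not in (2, 4):
            raise ValueError("sr_xn only accepts 2 or 4")
        return v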

@@ -1,13 +1,27 @@
import json
import logging
from fastapi import APIRouter
from app.core.config import SR_RABBITMQ_QUEUES, GI_RABBITMQ_QUEUES
from fastapi import HTTPException
from app.core.config import SR_RABBITMQ_QUEUES, GI_RABBITMQ_QUEUES, GPI_RABBITMQ_QUEUES, GRI_RABBITMQ_QUEUES, OSS
from app.schemas.response_template import ResponseModel
logger = logging.getLogger()
router = APIRouter()
@router.get("")
def test():
logger.info(SR_RABBITMQ_QUEUES)
logger.info("test")
return {"SR_RABBITMQ_QUEUES message": SR_RABBITMQ_QUEUES, "GI_RABBITMQ_QUEUES": GI_RABBITMQ_QUEUES}
@router.get("{id}")
def test(id: int):
data = {
"SR_RABBITMQ_QUEUES message": SR_RABBITMQ_QUEUES,
"GI_RABBITMQ_QUEUES": GI_RABBITMQ_QUEUES,
"GPI_RABBITMQ_QUEUES": GPI_RABBITMQ_QUEUES,
"GRI_RABBITMQ_QUEUES": GRI_RABBITMQ_QUEUES,
"local_oss_server": OSS
}
logger.info(json.dumps(data))
if id == 1:
raise HTTPException(status_code=404, detail="Item not found")
return ResponseModel(data=data)

@@ -19,15 +19,16 @@ class Settings(BaseSettings):
LOGGING_CONFIG_FILE = os.path.join(BASE_DIR, 'logging_env.py')
OSS = "minio"
DEBUG = False
if DEBUG:
LOGS_PATH = "logs/"
CATEGORY_PATH = "service/attribute/config/descriptor/category/category_dis.csv"
FACE_CLASSIFIER = "service/generate_image/utils/haarcascade_frontalface_alt.xml"
# FACE_CLASSIFIER = "service/generate_image/utils/haarcascade_frontalface_alt.xml"
else:
LOGS_PATH = "app/logs/"
CATEGORY_PATH = "app/service/attribute/config/descriptor/category/category_dis.csv"
FACE_CLASSIFIER = 'app/service/generate_image/utils/haarcascade_frontalface_alt.xml'
# FACE_CLASSIFIER = 'app/service/generate_image/utils/haarcascade_frontalface_alt.xml'
RABBITMQ_ENV = "" # 生产环境
# RABBITMQ_ENV = "-dev" # 开发环境
@@ -41,6 +42,11 @@ MINIO_ACCESS = 'vXKFLSJkYeEq2DrSZvkB'
MINIO_SECRET = 'uKTZT3x7C43WvPN9QTc99DiRkwddWZrG9Uh3JVlR'
MINIO_SECURE = True
# S3 configuration
S3_ACCESS_KEY = "AKIAVD3OJIMF6UJFLSHZ"
S3_AWS_SECRET_ACCESS_KEY = "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"
S3_REGION_NAME = "ap-east-1"
# Redis configuration
REDIS_HOST = "10.1.1.240"
REDIS_PORT = "6379"
@@ -55,12 +61,38 @@ RABBITMQ_PARAMS = {
}
# Milvus configuration
MILVUS_DB_HOST = "10.1.1.240"
MILVUS_URL = "http://10.1.1.240:19530"
MILVUS_TOKEN = "root:Milvus"
MILVUS_ALIAS = "default"
MILVUS_PORT = "19530"
MILVUS_TABLE_KEYPOINT = "keypoint_cache"
MILVUS_TABLE_SEG = "seg_cache"
# MySQL configuration
DB_HOST = '18.167.251.121' # database host
# DB_PORT = int( 33006)
DB_PORT = 33008 # database port
DB_USERNAME = 'aida_con_python' # database username
DB_PASSWORD = '123456' # database password
DB_NAME = 'aida' # database name
# openai
os.environ['SERPAPI_API_KEY'] = "a793513017b0718db7966207c31703d280d12435c982f1e67bbcbffa52e7632c"
OPENAI_STREAM = True
BUFFER_THRESHOLD = 6 # must be even number
SINGLE_TOKEN_THRESHOLD = 200
TOKEN_THRESHOLD = 600
OPENAI_TEMPERATURE = 0
# OPENAI_API_KEY = "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"
OPENAI_API_KEY = "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"
OPENAI_MODEL = "gpt-3.5-turbo-0613"
OPENAI_MODEL_LIST = {"gpt-3.5-turbo-0613",
"gpt-3.5-turbo-16k-0613",
"gpt-4-0314",
"gpt-4-32k-0314",
"gpt-4-0613",
"gpt-4-32k-0613", }
# attribute service config
ATT_TRITON_URL = "10.1.1.240:10000"
@@ -77,6 +109,28 @@ GI_MINIO_BUCKET = "aida-users"
GI_RABBITMQ_QUEUES = os.getenv("GI_RABBITMQ_QUEUES", f"GenerateImage{RABBITMQ_ENV}")
GI_SYS_IMAGE_URL = "aida-sys-image/generate_image/white_image.jpg"
# SLOGAN service config
SLOGAN_RABBITMQ_QUEUES = os.getenv("SLOGAN_RABBITMQ_QUEUES", f"Slogan{RABBITMQ_ENV}")
# Generate Single Logo service config
GSL_MODEL_URL = '10.1.1.240:10041'
GSL_MINIO_BUCKET = "aida-users"
GSL_MODEL_NAME = 'stable_diffusion_xl_transparent'
GEN_SINGLE_LOGO_RABBITMQ_QUEUES = os.getenv("GEN_SINGLE_LOGO_RABBITMQ_QUEUES", f"GenSingleLogo{RABBITMQ_ENV}")
# Generate Product service config
GPI_RABBITMQ_QUEUES = os.getenv("GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES", f"ToProductImage{RABBITMQ_ENV}")
GPI_MODEL_NAME_OVERALL = 'diffusion_ensemble_all'
GPI_MODEL_NAME_SINGLE = 'stable_diffusion_1_5_cnet'
GPI_MODEL_URL = '10.1.1.240:10041'
# Generate Single Logo service config
GRI_RABBITMQ_QUEUES = os.getenv("GEN_RELIGHT_IMAGE_RABBITMQ_QUEUES", f"Relight{RABBITMQ_ENV}")
GRI_MODEL_NAME_OVERALL = 'diffusion_relight_ensemble'
GRI_MODEL_NAME_SINGLE = 'stable_diffusion_1_5_relight'
GRI_MODEL_URL = '10.1.1.240:10051'
# SEG service config
SEG_MODEL_URL = '10.1.1.240:10000'
SEGMENTATION = {
@@ -87,9 +141,13 @@ SEGMENTATION = {
}
# DESIGN config
DESIGN_MODEL_URL = '10.1.1.240:9000'
DESIGN_MODEL_URL = '10.1.1.240:10000'
AIDA_CLOTHING = "aida-clothing"
KEYPOINT_RESULT_TABLE_FIELD_SET = ('neckline_left', 'neckline_right', 'shoulder_left', 'shoulder_right', 'armpit_left', 'armpit_right',
'cuff_left_in', 'cuff_left_out', 'cuff_right_in', 'cuff_right_out', 'waistband_left', 'waistband_right')
# DESIGN preprocessing
IF_DEBUG_SHOW = False
# priority
PRIORITY_DICT = {
@@ -117,4 +175,3 @@ PRIORITY_DICT = {
'bag_back': -98,
'earring_back': -99,
}

@@ -1,14 +1,18 @@
import logging.config
from http.client import HTTPException
from fastapi.responses import JSONResponse
from fastapi import FastAPI, HTTPException, Request
import uvicorn
from fastapi import FastAPI
from app.api.api_route import router
from app.core.config import settings
from app.schemas.response_template import ResponseModel
from logging_env import LOGGER_CONFIG_DICT
logging.config.dictConfig(LOGGER_CONFIG_DICT)
logging.getLogger("pika").setLevel(logging.WARNING)
from starlette.middleware.cors import CORSMiddleware
@@ -35,5 +39,15 @@ def get_application() -> FastAPI:
app = get_application()
@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException):
return JSONResponse(
status_code=exc.status_code,
content=ResponseModel(code=exc.status_code, msg=exc.detail, data=exc.detail).dict()
)
if __name__ == '__main__':
uvicorn.run(app, host="0.0.0.0", port=8000)
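With this handler installed, every HTTPException raised by the routes above is serialized into the same ResponseModel envelope, so a failed call returns, for example, {"code": 404, "msg": "Item not found", "data": "Item not found"} together with HTTP status 404.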

@@ -0,0 +1,8 @@
from pydantic import BaseModel
class ChatRobotModel(BaseModel):
gender: str
message: str
session_id: str
user_id: int

app/schemas/design.py Normal file
@@ -0,0 +1,58 @@
from pydantic import BaseModel
# class BodyPointModel(BaseModel):
# waistband_right: list[int]
# hand_point_right: list[int]
# waistband_left: list[int]
# hand_point_left: list[int]
# shoulder_left: list[int]
# shoulder_right: list[int]
#
#
# class BasicModel(BaseModel):
# body_point: BodyPointModel
# layer_order: bool
# scale_bag: float
# scale_earrings: float
# self_template: bool
# single_overall: str
# switch_category: str
# body_path: str
#
#
# class PrintModel(BaseModel):
# if_single: bool
# print_path_list: list[str]
#
#
# class ItemModel(BaseModel):
# color: str
# image_id: str
# offset: list[int]
# path: str
# print: PrintModel
# resize_scale: float
# type: str
#
#
# class CollocationModel(BaseModel):
# basic: BasicModel
# item: list[ItemModel]
#
#
# class DesignModel(BaseModel):
# object: list[CollocationModel]
# process_id: str
class DesignModel(BaseModel):
objects: list[dict]
process_id: str
class DesignProgressModel(BaseModel):
process_id: str
class ModelProgressModel(BaseModel):
model_path: str

@@ -8,3 +8,25 @@ class GenerateImageModel(BaseModel):
mode: str
category: str
gender: str
class GenerateSingleLogoImageModel(BaseModel):
tasks_id: str
prompt: str
seed: str
class GenerateProductImageModel(BaseModel):
tasks_id: str
prompt: str
image_url: str
image_strength: float
product_type: str
class GenerateRelightImageModel(BaseModel):
tasks_id: str
prompt: str
image_url: str
direction: str
product_type: str

@@ -0,0 +1,5 @@
from pydantic import BaseModel
class DesignPreProcessingModel(BaseModel):
sketches: list[dict]

@@ -0,0 +1,5 @@
from pydantic import BaseModel
class PromptGenerationImageModel(BaseModel):
text: str

@@ -0,0 +1,8 @@
from pydantic import BaseModel
from typing import Any, Optional
class ResponseModel(BaseModel):
code: int = 200
msg: str = "OK!"
data: Optional[Any] = None
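Every endpoint in this commit wraps its payload in this envelope; with no arguments the defaults apply:

# Defaults of the envelope above (pydantic v1 .dict()).
ResponseModel().dict()
# -> {"code": 200, "msg": "OK!", "data": None}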

@@ -1,13 +1,13 @@
top_description_list = ['service/attribute/config/descriptor/top/length.csv',
'service/attribute/config/descriptor/top/type.csv',
'service/attribute/config/descriptor/top/sleeve_length.csv',
'service/attribute/config/descriptor/top/sleeve_shape.csv',
'service/attribute/config/descriptor/top/sleeve_shoulder.csv',
'service/attribute/config/descriptor/top/neckline.csv',
'service/attribute/config/descriptor/top/design.csv',
'service/attribute/config/descriptor/top/opening_type.csv',
'service/attribute/config/descriptor/top/silhouette.csv',
'service/attribute/config/descriptor/top/collar.csv']
top_description_list = ['app/service/attribute/config/descriptor/top/length.csv',
'app/service/attribute/config/descriptor/top/type.csv',
'app/service/attribute/config/descriptor/top/sleeve_length.csv',
'app/service/attribute/config/descriptor/top/sleeve_shape.csv',
'app/service/attribute/config/descriptor/top/sleeve_shoulder.csv',
'app/service/attribute/config/descriptor/top/neckline.csv',
'app/service/attribute/config/descriptor/top/design.csv',
'app/service/attribute/config/descriptor/top/opening_type.csv',
'app/service/attribute/config/descriptor/top/silhouette.csv',
'app/service/attribute/config/descriptor/top/collar.csv']
top_model_list = ['attr_retrieve_T_length',
'attr_retrieve_T_type',
@@ -22,11 +22,11 @@ top_model_list = ['attr_retrieve_T_length',
]
bottom_description_list = [
'service/attribute/config/descriptor/bottom/subtype.csv',
'service/attribute/config/descriptor/bottom/length.csv',
'service/attribute/config/descriptor/bottom/silhouette.csv',
'service/attribute/config/descriptor/bottom/opening_type.csv',
'service/attribute/config/descriptor/bottom/design.csv']
'app/service/attribute/config/descriptor/bottom/subtype.csv',
'app/service/attribute/config/descriptor/bottom/length.csv',
'app/service/attribute/config/descriptor/bottom/silhouette.csv',
'app/service/attribute/config/descriptor/bottom/opening_type.csv',
'app/service/attribute/config/descriptor/bottom/design.csv']
bottom_model_list = [
'attr_retrieve_B_subtype',
@@ -35,14 +35,14 @@ bottom_model_list = [
'attr_recong_B_optype',
'attr_retrieve_B_design']
outwear_description_list = ['service/attribute/config/descriptor/outwear/length.csv',
'service/attribute/config/descriptor/outwear/sleeve_length.csv',
'service/attribute/config/descriptor/outwear/sleeve_shape.csv',
'service/attribute/config/descriptor/outwear/sleeve_shoulder.csv',
'service/attribute/config/descriptor/outwear/collar.csv',
'service/attribute/config/descriptor/outwear/design.csv',
'service/attribute/config/descriptor/outwear/opening_type.csv',
'service/attribute/config/descriptor/outwear/silhouette.csv', ]
outwear_description_list = ['app/service/attribute/config/descriptor/outwear/length.csv',
'app/service/attribute/config/descriptor/outwear/sleeve_length.csv',
'app/service/attribute/config/descriptor/outwear/sleeve_shape.csv',
'app/service/attribute/config/descriptor/outwear/sleeve_shoulder.csv',
'app/service/attribute/config/descriptor/outwear/collar.csv',
'app/service/attribute/config/descriptor/outwear/design.csv',
'app/service/attribute/config/descriptor/outwear/opening_type.csv',
'app/service/attribute/config/descriptor/outwear/silhouette.csv', ]
outwear_model_list = ['attr_recong_O_length',
'attr_retrieve_O_sleeve_length',
@@ -53,15 +53,15 @@ outwear_model_list = ['attr_recong_O_length',
'attr_recong_O_optype',
'attr_retrieve_O_silhouette']
dress_description_list = [ # 'service/attribute/config/descriptor/dress/D_length.csv',
'service/attribute/config/descriptor/dress/sleeve_length.csv',
'service/attribute/config/descriptor/dress/sleeve_shape.csv',
# 'service/attribute/config/descriptor/dress/D_sleeve_shoulder.csv',
'service/attribute/config/descriptor/dress/neckline.csv',
'service/attribute/config/descriptor/dress/collar.csv',
'service/attribute/config/descriptor/dress/design.csv',
'service/attribute/config/descriptor/dress/silhouette.csv',
'service/attribute/config/descriptor/dress/type.csv']
dress_description_list = [ # 'app/service/attribute/config/descriptor/dress/D_length.csv',
'app/service/attribute/config/descriptor/dress/sleeve_length.csv',
'app/service/attribute/config/descriptor/dress/sleeve_shape.csv',
# 'app/service/attribute/config/descriptor/dress/D_sleeve_shoulder.csv',
'app/service/attribute/config/descriptor/dress/neckline.csv',
'app/service/attribute/config/descriptor/dress/collar.csv',
'app/service/attribute/config/descriptor/dress/design.csv',
'app/service/attribute/config/descriptor/dress/silhouette.csv',
'app/service/attribute/config/descriptor/dress/type.csv']
dress_model_list = [ # 'attr_recong_D_length',
'attr_retrieve_D_sleeve_length',

@@ -11,12 +11,12 @@ from minio import Minio
import tritonclient.http as httpclient
from app.core.config import *
from app.schemas.attribute_retrieve import AttributeRecognitionModel
from app.service.utils.oss_client import oss_get_image
class AttributeRecognition:
def __init__(self, const, request_data):
self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
logging.info("实例化完成")
# self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
self.request_data = []
for i, sketch in enumerate(request_data):
self.request_data.append(
@@ -97,9 +97,10 @@ class AttributeRecognition:
return res
def get_image(self, url):
response = self.minio_client.get_object(url.split("/", 1)[0], url.split("/", 1)[1])
img = np.frombuffer(response.data, np.uint8) # convert to 8-bit unsigned integers
img = cv2.imdecode(img, cv2.IMREAD_COLOR) # decode
# response = self.minio_client.get_object(url.split("/", 1)[0], url.split("/", 1)[1])
# img = np.frombuffer(response.data, np.uint8) # convert to 8-bit unsigned integers
# img = cv2.imdecode(img, cv2.IMREAD_COLOR) # decode
img = oss_get_image(bucket=url.split("/", 1)[0], object_name=url.split("/", 1)[1], data_type="cv2")
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
return img

@@ -18,12 +18,13 @@ import torch
from app.core.config import *
from app.schemas.attribute_retrieve import CategoryRecognitionModel
from app.service.utils.oss_client import oss_get_image
class CategoryRecognition:
def __init__(self, request_data):
self.attr_type = pd.read_csv(CATEGORY_PATH)
self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
# self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
self.request_data = []
self.triton_client = httpclient.InferenceServerClient(url=ATT_TRITON_URL)
for sketch in request_data:
@@ -51,9 +52,10 @@ class CategoryRecognition:
def get_image(self, url):
# Get data of an object.
# Read data from response.
response = self.minio_client.get_object(url.split("/", 1)[0], url.split("/", 1)[1])
img = np.frombuffer(response.data, np.uint8) # convert to 8-bit unsigned integers
img = cv2.imdecode(img, cv2.IMREAD_COLOR) # decode
# response = self.minio_client.get_object(url.split("/", 1)[0], url.split("/", 1)[1])
# img = np.frombuffer(response.data, np.uint8) # convert to 8-bit unsigned integers
# img = cv2.imdecode(img, cv2.IMREAD_COLOR) # decode
img = oss_get_image(bucket=url.split("/", 1)[0], object_name=url.split("/", 1)[1], data_type="cv2")
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
return img

@@ -0,0 +1,7 @@
from .agent_executor import CustomAgentExecutor
from .conversational_functions_agent import ConversationalFunctionsAgent
__all__ = [
"CustomAgentExecutor",
"ConversationalFunctionsAgent"
]

@@ -0,0 +1,132 @@
import inspect
import json
import logging
from typing import Any, Dict, List, Optional, Union, Tuple
from langchain.agents import AgentExecutor
from langchain.callbacks.manager import Callbacks, CallbackManager
from langchain.load.dump import dumpd
from langchain.schema import RUN_KEY, RunInfo
from langchain_core.agents import AgentAction, AgentFinish
class CustomAgentExecutor(AgentExecutor):
def __call__(
self,
inputs: Union[Dict[str, Any], Any],
return_only_outputs: bool = False,
callbacks: Callbacks = None,
session_key: str = "",
*,
tags: Optional[List[str]] = None,
include_run_info: bool = False,
) -> Dict[str, Any]:
"""Run the logic of this chain and add to output if desired.
Args:
inputs: Dictionary of inputs, or single input if chain expects
only one param.
return_only_outputs: boolean for whether to return only outputs in the
response. If True, only new keys generated by this chain will be
returned. If False, both input keys and new keys generated by this
chain will be returned. Defaults to False.
callbacks: Callbacks to use for this chain run. If not provided, will
use the callbacks provided to the chain.
include_run_info: Whether to include run info in the response. Defaults
to False.
"""
inputs = self.prep_inputs(inputs, session_key)
callback_manager = CallbackManager.configure(
callbacks, self.callbacks, self.verbose, tags, self.tags
)
new_arg_supported = inspect.signature(self._call).parameters.get("run_manager")
run_manager = callback_manager.on_chain_start(
dumpd(self),
inputs,
)
try:
outputs = (
self._call(inputs, run_manager=run_manager)
if new_arg_supported
else self._call(inputs)
)
except (KeyboardInterrupt, Exception) as e:
logging.exception(e)
run_manager.on_chain_error(e)
raise e
run_manager.on_chain_end(outputs)
final_outputs: Dict[str, Any] = self.prep_outputs(
inputs, outputs, return_only_outputs, session_key
)
if include_run_info:
final_outputs[RUN_KEY] = RunInfo(run_id=run_manager.run_id)
return final_outputs
def prep_outputs(
self,
inputs: Dict[str, str],
outputs: Dict[str, str],
return_only_outputs: bool = False,
session_key: str = ""
) -> Dict[str, str]:
"""Validate and prep outputs."""
self._validate_outputs(outputs)
if self.memory is not None and outputs['need_record']:
self.memory.save_context(inputs, outputs, session_key)
if return_only_outputs:
return outputs
else:
return {**inputs, **outputs}
def prep_inputs(self, inputs: Union[Dict[str, Any], Any], session_key: str = "") -> Dict[str, str]:
"""Validate and prep inputs."""
if not isinstance(inputs, dict):
_input_keys = set(self.input_keys)
if self.memory is not None:
# If there are multiple input keys, but some get set by memory so that
# only one is not set, we can still figure out which key it is.
_input_keys = _input_keys.difference(self.memory.memory_variables)
if len(_input_keys) != 1:
raise ValueError(
f"A single string input was passed in, but this chain expects "
f"multiple inputs ({_input_keys}). When a chain expects "
f"multiple inputs, please call it by passing in a dictionary, "
"eg `chain({'foo': 1, 'bar': 2})`"
)
inputs = {list(_input_keys)[0]: inputs}
if self.memory is not None:
external_context = self.memory.load_memory_variables(inputs, session_key)
inputs = dict(inputs, **external_context)
self._validate_inputs(inputs)
return inputs
def _get_tool_return(
self, next_step_output: Tuple[AgentAction, str]
) -> Optional[AgentFinish]:
"""Check if the tool is a returning tool."""
agent_action, observation = next_step_output
name_to_tool_map = {tool.name: tool for tool in self.tools}
return_value_key = "output"
if len(self.agent.return_values) > 0:
return_value_key = self.agent.return_values[0]
try:
observation_list = json.loads(observation)
if agent_action.tool == "sql_db_query" and isinstance(observation_list, list) and len(observation_list) != 0:
return AgentFinish(
{return_value_key: observation},
"",
)
except Exception:
pass
# Invalid tools won't be in the map, so we return False.
if agent_action.tool in name_to_tool_map:
if name_to_tool_map[agent_action.tool].return_direct:
return AgentFinish(
{return_value_key: observation},
"",
)
return None

@@ -0,0 +1,198 @@
import json
import re
from json import JSONDecodeError
from typing import List, Tuple, Any, Union
from dataclasses import dataclass
from langchain.callbacks.manager import Callbacks
from langchain.agents import (
OpenAIFunctionsAgent,
)
from langchain.schema import (
AgentAction,
AgentFinish,
BaseMessage,
OutputParserException
)
from langchain.schema.messages import (
AIMessage,
FunctionMessage
)
from langchain.tools import BaseTool, StructuredTool
# from langchain.tools.convert_to_openai import FunctionDescription
from langchain.utils.openai_functions import FunctionDescription
@dataclass
class _FunctionsAgentAction(AgentAction):
"""Add message_log to AgentAction class for the _FunctionAgentAction
"""
message_log: List[BaseMessage]
def __init__(
self, tool: str, tool_input: Union[str, dict], log: str, **kwargs: Any
):
"""Override init to support instantiation by position for backward compat."""
super().__init__(tool=tool, tool_input=tool_input, log=log, **kwargs)
def _convert_agent_action_to_messages(
agent_action: AgentAction, observation: str
) -> List[BaseMessage]:
"""Convert an agents action to a message.
This code is used to reconstruct the original AI message from the agents action.
Args:
agent_action: Agent action to convert.
Returns:
AIMessage that corresponds to the original tool invocation.
"""
if isinstance(agent_action, _FunctionsAgentAction):
return agent_action.message_log + [
_create_function_message(agent_action, observation)
]
else:
return [AIMessage(content=agent_action.log)]
def _create_function_message(
agent_action: AgentAction, observation: str
) -> FunctionMessage:
"""Convert agents action and observation into a function message.
Args:
agent_action: the tools invocation request from the agents
observation: the result of the tools invocation
Returns:
FunctionMessage that corresponds to the original tools invocation
"""
if not isinstance(observation, str):
try:
content = json.dumps(observation, ensure_ascii=False)
except Exception:
content = str(observation)
else:
content = observation
return FunctionMessage(
name=agent_action.tool,
content=content,
)
def _format_intermediate_steps(
intermediate_steps: List[Tuple[AgentAction, str]],
) -> List[BaseMessage]:
"""Format intermediate steps.
Args:
intermediate_steps: Steps the LLM has taken to date, along with observations
Returns:
list of messages to send to the LLM for the next prediction
"""
messages = []
for intermediate_step in intermediate_steps:
agent_action, observation = intermediate_step
messages.extend(_convert_agent_action_to_messages(agent_action, observation))
return messages
def _format_tool_to_openai_function(tool: BaseTool) -> FunctionDescription:
"""Format tools into the OpenAI function API."""
parameters = {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": tool.param_description if hasattr(tool, 'param_description') else "",
},
},
"required": ["query"],
}
return {
"name": tool.name,
"description": tool.description,
"parameters": parameters,
}
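Unlike LangChain's stock converter, this version forces every tool into a single required string parameter named "query". For the internet_search tool defined later in script/main.py, the emitted schema would look roughly like this (a worked example, not captured output; param_description is absent on plain Tool objects):

# Illustrative output of _format_tool_to_openai_function for a tool without param_description.
{
    "name": "internet_search",
    "description": "Can be used to perform Internet searches",
    "parameters": {
        "type": "object",
        "properties": {"query": {"type": "string", "description": ""}},
        "required": ["query"],
    },
}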
def _parse_ai_message(message: BaseMessage) -> Union[AgentAction, AgentFinish]:
if not isinstance(message, AIMessage):
raise TypeError(f"Expected an AI message but got {type(message)}")
function_call = message.additional_kwargs.get("function_call", {})
if function_call:
function_call = message.additional_kwargs["function_call"]
function_name = function_call["name"]
try:
_tool_input = json.loads(function_call["arguments"])
except JSONDecodeError:
raise OutputParserException(
f"Could not parse tools input: {function_call} because"
f"the `arguments` is not valid JSON."
)
if "query" in _tool_input:
tool_input = _tool_input["query"]
else:
tool_input = _tool_input
return _FunctionsAgentAction(
tool=function_name,
tool_input=tool_input,
log=f"\nInvoking: `{function_name}` with `{tool_input}`\n",
message_log=[message]
)
# pattern = r'\((.*?)\)'
# matches = re.findall(pattern, message.content)
# result = []
#
# for match in matches:
# result.append(match)
#
# if result:
# output = result
# else:
# output = message.content
return AgentFinish(return_values={"output": message.content}, log=message.content)
class ConversationalFunctionsAgent(OpenAIFunctionsAgent):
@property
def functions(self) -> List[dict]:
return [dict(_format_tool_to_openai_function(t)) for t in self.tools]
def plan(self,
intermediate_steps: List[Tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any
) -> Union[AgentAction, AgentFinish]:
"""Decide how agents should move after receiving an input. The difference between
OpenAIFunctionsAgent lies in the '_parse_ai_message' function. We add an OutputParser
into it.
Args:
intermediate_steps: Steps the LLM has taken to date, along with observations
**kwargs: User inputs, including the user's input string.
Returns:
Action specifying what tool to use.
"""
agent_scratchpad: List[BaseMessage] = _format_intermediate_steps(intermediate_steps)
selected_inputs = {
k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
}
full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad)
prompt = self.prompt.format_prompt(**full_inputs)
messages: List[BaseMessage] = prompt.to_messages()
predicted_message = self.llm.predict_messages(
messages, functions=self.functions, callbacks=callbacks
)
agent_decision = _parse_ai_message(predicted_message)
return agent_decision

@@ -0,0 +1,6 @@
from .openai_token_record_callback import OpenAITokenRecordCallbackHandler
__all__ = [
'OpenAITokenRecordCallbackHandler'
]

@@ -0,0 +1,46 @@
"""Callback Handler that add on_chain_end function to record Token usage."""
from typing import Any, Dict
from langchain.callbacks import OpenAICallbackHandler
from langchain.schema import LLMResult
from langchain.callbacks.openai_info import standardize_model_name, MODEL_COST_PER_1K_TOKENS, get_openai_token_cost_for_model
class OpenAITokenRecordCallbackHandler(OpenAICallbackHandler):
need_record: bool = True
response_type: str = "string"
"""Callback Handler that tracks OpenAI info and write to redis after agent finish"""
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""Collect token usage."""
if response.llm_output is None:
return None
self.successful_requests += 1
if "token_usage" not in response.llm_output:
return None
if "function_call" in response.generations[0][0].message.additional_kwargs:
if response.generations[0][0].message.additional_kwargs["function_call"]["name"] in ["sql_db_query", "sql_db_schema","tutorial_tool"]:
self.need_record = False
if response.generations[0][0].message.additional_kwargs["function_call"]["name"] == "sql_db_query":
self.response_type = "image"
token_usage = response.llm_output["token_usage"]
completion_tokens = token_usage.get("completion_tokens", 0)
prompt_tokens = token_usage.get("prompt_tokens", 0)
model_name = standardize_model_name(response.llm_output.get("model_name", ""))
if model_name in MODEL_COST_PER_1K_TOKENS:
completion_cost = get_openai_token_cost_for_model(
model_name, completion_tokens, is_completion=True
)
prompt_cost = get_openai_token_cost_for_model(model_name, prompt_tokens)
self.total_cost += prompt_cost + completion_cost
self.total_tokens += token_usage.get("total_tokens", 0)
self.prompt_tokens += prompt_tokens
self.completion_tokens += completion_tokens
def on_chain_end(self, outputs: Dict, **kwargs: Any) -> None:
"""Write token usage to redis."""
outputs["total_tokens"] = self.total_tokens
outputs["total_cost"] = self.total_cost
outputs["prompt_tokens"] = self.prompt_tokens
outputs["completion_tokens"] = self.completion_tokens
outputs["need_record"] = self.need_record
outputs["response_type"] = self.response_type

@@ -0,0 +1,79 @@
from typing import Optional, List
import json
from sqlalchemy import text
# from langchain import SQLDatabase
from langchain.utilities import SQLDatabase
class CustomDatabase(SQLDatabase):
def get_table_info_no_throw(self, table_names: Optional[List[str]] = None) -> str:
# def get_table_info(self, table_names: Optional[List[str]] = None) -> str:
connection = self._engine.connect()
all_table_names = self.get_usable_table_names()
if table_names is not None:
missing_tables = set(table_names).difference(all_table_names)
if missing_tables:
# raise ValueError(f"table_names {missing_tables} not found in database")
return f"Table {','.join(missing_tables)} can not be found in the database"
all_table_names = table_names
meta_tables = [
tbl
for tbl in self._metadata.sorted_tables
if tbl.name in set(all_table_names)
]
tables = []
for table in meta_tables:
table_name = table.name
column_names = table.columns.keys()
table_info = f"Table: {table_name}\nColumns: \nID, \nimg_name\n"
for column_name in column_names:
if column_name not in ["ID", "img_name"]:
query = text(f"SELECT DISTINCT {column_name} FROM {table_name}")
result = connection.execute(query)
enum_values: List[str] = [row[0] for row in result.fetchall()]
column_info = f"{column_name}: {', '.join(enum_values)}\n"
table_info += column_info
# table_info = f"Table: {table_name}\n"
#
# if self._sample_rows_in_table_info:
# table_info += f"{self._get_sample_rows(table)}\n"
tables.append(table_info)
final_str = "\n\n".join(tables)
return final_str
def run(self, command: str, fetch: str = "all") -> str:
"""Execute a SQL command and return a string representing the results.
If the statement returns rows, a string of the results is returned.
If the statement returns no rows, an empty string is returned.
"""
with self._engine.begin() as connection:
if self._schema is not None:
if self.dialect == "snowflake":
connection.exec_driver_sql(
f"ALTER SESSION SET search_path='{self._schema}'"
)
elif self.dialect == "bigquery":
connection.exec_driver_sql(f"SET @@dataset_id='{self._schema}'")
else:
connection.exec_driver_sql(f"SET search_path TO {self._schema}")
cursor = connection.execute(text(command))
if cursor.rowcount:
if fetch == "all":
result = cursor.fetchall()
elif fetch == "one":
result = cursor.fetchone() # type: ignore
else:
raise ValueError("Fetch parameter must be either 'one' or 'all'")
# Convert column values to strings to avoid issues with sqlalchemy
# truncating text
if isinstance(result, list):
return json.dumps([r[0] for r in result])
return json.dumps([result[0]])
return ""

@@ -0,0 +1,115 @@
import json
import logging
from langchain.agents import Tool
from langchain.callbacks import FileCallbackHandler
from langchain.chat_models import ChatOpenAI
from langchain.llms.openai import OpenAI
from langchain.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate, MessagesPlaceholder
from langchain.schema import SystemMessage, AIMessage
from langchain.utilities import SerpAPIWrapper
from loguru import logger
from app.core.config import *
from app.service.chat_robot.script.agents import CustomAgentExecutor, ConversationalFunctionsAgent
from app.service.chat_robot.script.callbacks import OpenAITokenRecordCallbackHandler
from app.service.chat_robot.script.database import CustomDatabase
from app.service.chat_robot.script.memory import UserConversationBufferWindowMemory
from app.service.chat_robot.script.prompt import FASHION_CHAT_BOT_PREFIX, TOOLS_FUNCTIONS_SUFFIX
from app.service.chat_robot.script.tools import (QuerySQLDataBaseTool, InfoSQLDatabaseTool, QuerySQLCheckerTool, ListSQLDatabaseTool)
from app.service.chat_robot.script.tools.tutorial_tool import CustomTutorialTool
# os.environ["http_proxy"] = "http://127.0.0.1:7890"
# os.environ["https_proxy"] = "http://127.0.0.1:7890"
# log callbacks
logfile = "logs/chat_debug.log"
logger.add(logfile, colorize=True, enqueue=True)
log_handler = FileCallbackHandler(logfile)
# Initialize our LLM 'gpt-3.5-turbo'
llm = ChatOpenAI(temperature=0.1,
openai_api_key=OPENAI_API_KEY,
# callbacks=[OpenAICallbackHandler()]
)
search = SerpAPIWrapper()
db = CustomDatabase.from_uri(f'mysql+pymysql://{DB_USERNAME}:{DB_PASSWORD}@{DB_HOST}:{DB_PORT}/attribute_retrieval_V3',
include_tables=['female_top', 'female_skirt', 'female_pants', 'female_dress',
'female_outwear', 'male_bottom', 'male_top', 'male_outwear'],
engine_args={"pool_recycle": 7200})
tools = [
Tool(
name="internet_search",
description="Can be used to perform Internet searches",
func=search.run
),
QuerySQLDataBaseTool(db=db, return_direct=False),
InfoSQLDatabaseTool(db=db),
ListSQLDatabaseTool(db=db),
QuerySQLCheckerTool(db=db, llm=OpenAI(temperature=0, openai_api_key=OPENAI_API_KEY)),
Tool(
name="tutorial_tool",
description="Utilize this tool to retrieve specific statements related to user guidance tutorials."
"Input is an empty string",
func=CustomTutorialTool(),
return_direct=True
)
]
messages = [
SystemMessage(content=FASHION_CHAT_BOT_PREFIX),
MessagesPlaceholder(variable_name="history"),
HumanMessagePromptTemplate.from_template(
"{input} "
"Question from a {gender}."
),
AIMessage(content=TOOLS_FUNCTIONS_SUFFIX),
MessagesPlaceholder(variable_name="agent_scratchpad"),
]
prompt = ChatPromptTemplate(input_variables=["input", "gender", "agent_scratchpad", "history"], messages=messages)
agent = ConversationalFunctionsAgent(
llm=llm,
tools=tools,
prompt=prompt
)
memory = UserConversationBufferWindowMemory.from_redis(
return_messages=True, k=2, input_key='input', output_key='output'
)
agent_executor = CustomAgentExecutor.from_agent_and_tools(
agent=agent,
tools=tools,
verbose=True,
memory=memory,
)
def chat(post_data):
user_id = post_data.user_id
session_id = post_data.session_id
input_message = post_data.message
gender = post_data.gender
final_outputs = agent_executor(
{"input": input_message, "gender": gender},
callbacks=[OpenAITokenRecordCallbackHandler(), log_handler],
session_key=f"buffer:{user_id}:{session_id}",
)
api_response = {
'user_id': user_id,
'session_id': session_id,
# 'message_id': message_id,
# 'create_time': created_time,
'input': final_outputs['input'],
# 'conversion': messages,
'output': final_outputs['output'],
# 'gpt_response_time': gpt_response_time,
'total_tokens': final_outputs['total_tokens'],
'total_cost': final_outputs['total_cost'],
'prompt_tokens': final_outputs['prompt_tokens'],
'completion_tokens': final_outputs['completion_tokens'],
'response_type': final_outputs['response_type']
}
logging.info(json.dumps(api_response))
return api_response

@@ -0,0 +1,3 @@
from .user_buffer_window import UserConversationBufferWindowMemory
__all__ = ['UserConversationBufferWindowMemory']

@@ -0,0 +1,93 @@
import logging
from typing import Any, Dict, List, Tuple
import json
import redis
from redis import Redis
from langchain.memory import RedisChatMessageHistory
from langchain.memory.chat_memory import BaseChatMemory
from langchain.schema.messages import BaseMessage, get_buffer_string, HumanMessage, AIMessage
from langchain.schema.messages import _message_to_dict, messages_from_dict
from langchain.memory.utils import get_prompt_input_key
from app.core.config import *
class UserConversationBufferWindowMemory(BaseChatMemory):
"""Buffer for storing conversation memory."""
redis_client: Redis
human_prefix: str = "Human"
ai_prefix: str = "AI"
memory_key: str = "history" #: :meta private:
k: int = 5
@classmethod
def from_redis(
cls,
host: str = REDIS_HOST,
port: int = REDIS_PORT,
db: int = 3,
**kwargs
):
redis_client = Redis(host=host, port=port, db=db)
try:
response = redis_client.ping()
if response:
print("Connect to redis server successfully.")
logging.info("Connect to redis server successfully.")
else:
print("Fail to connect to redis server")
logging.info("Fail to connect to redis server")
except redis.RedisError as e:
logging.info(f"Error occurs when connecting to redis server: {str(e)}")
return cls(redis_client=redis_client, **kwargs)
@property
def memory_variables(self) -> List[str]:
"""Will always return list of memory variables.
:meta private:
"""
return [self.memory_key]
def load_memory_variables(self, inputs: Dict[str, Any], key: str = "") -> Dict[str, str]:
"""Return history buffer."""
# lrange's end index is inclusive, so 0 .. k*2-1 fetches exactly k question/answer pairs
_items: Any = self.redis_client.lrange(key, 0, self.k * 2 - 1) if self.k > 0 else []
items = [json.loads(m.decode("utf-8")) for m in _items[::-1]]
buffer = messages_from_dict(items)
if not self.return_messages:
buffer = get_buffer_string(
buffer,
human_prefix=self.human_prefix,
ai_prefix=self.ai_prefix,
)
return {self.memory_key: buffer}
def _get_input_output(
self, inputs: Dict[str, Any], outputs: Dict[str, str]
) -> Tuple[str, str]:
if self.input_key is None:
prompt_input_key = get_prompt_input_key(inputs, self.memory_variables)
else:
prompt_input_key = self.input_key
if self.output_key is None:
if len(outputs) != 1:
raise ValueError(f"One output key expected, got {outputs.keys()}")
output_key = list(outputs.keys())[0]
else:
output_key = self.output_key
return inputs[prompt_input_key], outputs[output_key]
def add_message(self, key: str, message: BaseMessage) -> None:
self.redis_client.lpush(key, json.dumps(_message_to_dict(message)))
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str], key: str = "") -> None:
"""Save context from this conversation to buffer."""
input_str, output_str = self._get_input_output(inputs, outputs)
self.add_message(key, HumanMessage(content=input_str))
self.add_message(key, AIMessage(content=output_str))
# def clear(self, key) -> None:
# """Clear memory contents."""
# self.redis_client.delete(key)
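Messages are LPUSHed onto a per-session Redis list, so the newest entries sit at index 0; the key format "buffer:<user_id>:<session_id>" comes from the session_key built in script/main.py. A hedged inspection sketch, reusing the host from app/core/config.py and the db=3 default of from_redis:

# Hedged sketch: peeking at a stored conversation buffer.
import json
from redis import Redis

r = Redis(host="10.1.1.240", port=6379, db=3)
for raw in r.lrange("buffer:89:string-89", 0, 3):  # newest first
    msg = json.loads(raw.decode("utf-8"))
    print(msg["type"], msg["data"]["content"])  # e.g. "human" / "ai" plus the text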

@@ -0,0 +1,52 @@
FASHION_CHAT_BOT_PREFIX = """
You are a helpful assistant for fashion designers. You can chat with the users or answer their queries as best you can.
The most crucial aspect is to accurately determine whether the user's inquiry requires an internet search or querying the database.
Remember your answer should be very precise and the final output answer should not exceed 20 words.
You may encounter the following types of questions:
1) If the query is related to clothing retrieval, you are an agent designed to interact with a SQL database.
Given an input question, create a syntactically correct MySQL query to run, always fetching random data from tables.
Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most 4 results.
Never query for all the columns from a specific table, only ask for the relevant columns given the question.
You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.
DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.
If the question does not seem related to the database, just return "I don't know" as the answer.
2) If the query is related to current events, you should use internet_search to seek help from the internet.
3) If the query is just casual conversation, engage in the conversation as a fashion designer assistant.
Be careful when using the tools, since you are actually a chat bot. Tools should only be used when essential.
"""
TOOL_SELECT_SUFFIX = """
Prior to proceeding, it is essential to carefully assess the question and select the appropriate tools or approach accordingly.
For database-related questions, use SQL tools to identify relevant tables and query their schemas.
The use of online resources should be limited to inquiries pertaining to current subjects.
"""
SQL_FUNCTIONS_SUFFIX = """
For database-related questions, use SQL tools to identify relevant tables and query their schemas.
"""
INTERNET_SEARCH_SUFFIX = """
If the question should be answered using internet search tools, I should seek help from the internet.
"""
ANSWER_FORMAT_SUFFIX = """
My final answer is limited to 20 words and is as precise as possible.
"""
TOOLS_FUNCTIONS_SUFFIX = (
"If the input involves clothing queries, "
"I should look at the tables in the database to see what I can query. Then I should query the schema of the most relevant tables. "
"All SQL statements must use 'ORDER BY RAND()', for example: "
"Example Input 1: 'SELECT img_name FROM skirt WHERE opening_type = 'Button' ORDER BY RAND() LIMIT 1' "
"Example Input 2: 'SELECT img_name FROM top WHERE sleeve_length = 'Long' AND type = 'Blouse' ORDER BY RAND() LIMIT 2' "
"If the input does not involve clothing queries, "
"I should engage in conversation as an assistant or search the internet with the internet_search tool. "
"If the database query returns no results, please respond directly with: 'Apologies, I couldn't find any images that match your description. Could you please give me more details about the clothing you're searching for?' "
"Upon mentioning words related to 'tutorial' in the input, I should use tutorial_tool."
)
TUTORIAL_TOOL_RETURN = "Commencing the systematic tutorial guide now."

View File

@@ -0,0 +1,10 @@
from .sql_tools import (
QuerySQLDataBaseTool,
InfoSQLDatabaseTool,
ListSQLDatabaseTool,
QuerySQLCheckerTool
)
__all__ = [
"QuerySQLCheckerTool", "InfoSQLDatabaseTool", "ListSQLDatabaseTool", "QuerySQLDataBaseTool"
]

View File

@@ -0,0 +1,183 @@
# flake8: noqa
"""Tools for interacting with a SQL database."""
from typing import Any, Dict, Optional
from pydantic import BaseModel, Extra, Field, root_validator
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import (
AsyncCallbackManagerForToolRun,
CallbackManagerForToolRun,
)
from langchain.chains.llm import LLMChain
from langchain.prompts import PromptTemplate
# from langchain.sql_database import SQLDatabase
from langchain.utilities import SQLDatabase
from langchain.tools.base import BaseTool
from langchain.tools.sql_database.prompt import QUERY_CHECKER
class BaseSQLDatabaseTool(BaseModel):
"""Base tools for interacting with a SQL database."""
db: SQLDatabase = Field(exclude=True)
param_description: str = ""
# Override BaseTool.Config to appease mypy
# See https://github.com/pydantic/pydantic/issues/4173
class Config(BaseTool.Config):
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
extra = Extra.forbid
class QuerySQLDataBaseTool(BaseSQLDatabaseTool, BaseTool):
"""Tool for querying a SQL database."""
name = "sql_db_query"
# description = """
# Before use this tool, another tool named sql_db_schema must be used first to find the schema of interested tables.
# This tool is designed exclusively for generating SELECT queries to retrieve clothing's img_name randomly from a MySQL database.
# You should always use order by rand() to randomly select data.
# If the query is not correct, an error message will be returned.
# If an error is returned, rewrite the query, check the query, and try again.
# Always limit your query to at most 4 results.
# Never query for all the columns from a specific table, only ask for the relevant columns given the question.
# You MUST double check your query before executing it.
# DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.
# """
description: str = (
"The input of this tool is a detailed and correct SQL SELECT statement, "
"and the output is the result from the database; it can only return up to 4 results. "
"If the query is not correct, an error message will be returned. "
"If an error is returned, rewrite the query, check the query, and try again. "
"If you encounter an issue with Unknown column 'xxxx' in 'field list' or Table 'attribute_retrieval.xxxx' doesn't exist, "
"use sql_db_schema to query the correct table fields. "
"Example Input: 'SELECT img_name FROM skirt WHERE opening_type = 'Button' ORDER BY RAND() LIMIT 1' "
"Example Input 2: 'SELECT img_name FROM top WHERE sleeve_length = 'Long' AND type = 'Blouse' ORDER BY RAND() LIMIT 2'"
)
def _run(
self,
query: str,
run_manager: Optional[CallbackManagerForToolRun] = None,
) -> str:
"""Execute the query, return the results or an error message."""
result = self.db.run_no_throw(query)
return result
async def _arun(
self,
query: str,
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
) -> str:
raise NotImplementedError("QuerySqlDbTool does not support async")
class InfoSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool):
"""Tool for getting metadata about a SQL database."""
name = "sql_db_schema"
# description = """
# The database contains information of lots of fashion items, such as item name, their fashion attributes.
# There are five tables covering five fashion categories: top, pants, dress, skirt, and outwear.
# Find the most relevant tables with the query, and output the schema of these tables.
# """
description: str = (
"Input to this tool is a comma-separated list of tables; output is the schema and sample rows for those tables. "
"There are eight tables covering eight fashion categories: female_top, female_pants, female_dress, "
"female_skirt, female_outwear, male_bottom, male_top, and male_outwear. "
"Example Input: 'female_outwear, male_top'"
)
def _run(
self,
table_names: str,
run_manager: Optional[CallbackManagerForToolRun] = None,
) -> str:
"""Get the schema for tables in a comma-separated list."""
return self.db.get_table_info_no_throw(table_names.split(", "))
async def _arun(
self,
table_name: str,
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
) -> str:
raise NotImplementedError("SchemaSqlDbTool does not support async")
class ListSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool):
"""Tool for getting tables names."""
name = "sql_db_list_tables"
description = "Input is an empty string, output is a comma separated list of tables in the database."
def _run(
self,
tool_input: str = "",
run_manager: Optional[CallbackManagerForToolRun] = None,
) -> str:
"""Get the schema for a specific table."""
return ", ".join(self.db.get_usable_table_names())
async def _arun(
self,
tool_input: str = "",
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
) -> str:
raise NotImplementedError("ListTablesSqlDbTool does not support async")
class QuerySQLCheckerTool(BaseSQLDatabaseTool, BaseTool):
"""Use an LLM to check if a query is correct.
Adapted from https://www.patterns.app/blog/2023/01/18/crunchbot-sql-analyst-gpt/"""
template: str = QUERY_CHECKER
llm: BaseLanguageModel
llm_chain: LLMChain = Field(init=False)
name = "sql_db_query_checker"
description = (
"Use this tool to double check if your query is correct before executing it. "
"Always use this tool before executing a query with sql_db_query!"
)
@root_validator(pre=True)
def initialize_llm_chain(cls, values: Dict[str, Any]) -> Dict[str, Any]:
if "llm_chain" not in values:
values["llm_chain"] = LLMChain(
llm=values.get("llm"),
prompt=PromptTemplate(
template=QUERY_CHECKER,
input_variables=["query", "dialect"]
),
)
if values["llm_chain"].prompt.input_variables != ["dialect", "query"]:
# if values["llm_chain"].prompt.input_variables != ["query", "dialect"]:
raise ValueError(
"LLM chain for QueryCheckerTool must have input variables ['query', 'dialect']"
)
return values
def _run(
self,
query: str,
run_manager: Optional[CallbackManagerForToolRun] = None,
) -> str:
"""Use the LLM to check the query."""
return self.llm_chain.predict(query=query, dialect=self.db.dialect)
async def _arun(
self,
query: str,
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
) -> str:
return await self.llm_chain.apredict(query=query, dialect=self.db.dialect)
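# Usage sketch (illustrative only; the DSN and LLM below are placeholders, and the
# real service presumably points these tools at the attribute_retrieval MySQL database):
if __name__ == "__main__":
    from langchain.chat_models import ChatOpenAI

    db = SQLDatabase.from_uri("sqlite:///:memory:")
    print(ListSQLDatabaseTool(db=db).run(""))  # empty on a fresh in-memory database
    checker = QuerySQLCheckerTool(db=db, llm=ChatOpenAI(temperature=0))
    print(checker.run("SELECT img_name FROM skirt ORDER BY RANDOM() LIMIT 1"))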

View File

@@ -0,0 +1,19 @@
from typing import Any
from langchain.tools.base import BaseTool
from app.service.chat_robot.script.prompt import TUTORIAL_TOOL_RETURN
# Handles inputs related to the system's guided tutorial
class CustomTutorialTool(BaseTool):
name = "tutorial_tool"
description = ("Utilize this tool to retrieve specific statements related to user guidance tutorials."
"Input is an empty string")
def _run(self, tool_input, **kwargs: Any) -> str:
return TUTORIAL_TOOL_RETURN
async def _arun(self, tool_input, **kwargs: Any) -> str:
raise NotImplementedError("CustomTutorialTool does not support async")

View File

@@ -0,0 +1 @@
from .logger import Logger

View File

@@ -0,0 +1,26 @@
import logging
from logging import handlers
class Logger(object):
level_relations = {
'debug': logging.DEBUG,
'info': logging.INFO,
'warning': logging.WARNING,
'error': logging.ERROR,
'crit': logging.CRITICAL
}
def __init__(self, filename, level='info', when='D', backCount=3,
fmt='%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s'):
self.logger = logging.getLogger(filename)
format_str = logging.Formatter(fmt) # set log format
self.logger.setLevel(self.level_relations.get(level)) # set log level
sh = logging.StreamHandler() # output to terminal
sh.setFormatter(format_str) # set format for terminal log
th = handlers.TimedRotatingFileHandler(filename=filename, when=when, backupCount=backCount,
encoding='utf-8') # log into file
th.setFormatter(format_str) # set format for file log
self.logger.addHandler(sh) # output to terminal
self.logger.addHandler(th) # output to file
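# Usage sketch: one daily-rotating log file per logger name, mirrored to the terminal.
if __name__ == "__main__":
    log = Logger("all.log", level="debug")
    log.logger.debug("debug message")
    log.logger.warning("warning message")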

View File

View File

@@ -0,0 +1,116 @@
import logging
import numpy as np
import cv2
from matplotlib import pyplot as plt
from PIL import Image
def show(img, win_name="temp"):
cv2.imshow(win_name, img)
cv2.waitKey(0)
def crop(img):
mid_point_h, mid_point_w = int(img.shape[0] / 2 + 30), int(img.shape[1] / 2)
img_roi = img[mid_point_h - 520: mid_point_h + 520, mid_point_w - 340: mid_point_w + 340]
return img_roi
class Layer(object):
def __init__(self):
self._layer = []
@property
def layer(self):
return self._layer
def insert(self, layer_instance):
if layer_instance['name'] == 'body':
self._body = layer_instance
self._layer.append(layer_instance)
def sort(self, priority):
self._layer.sort(key=lambda x: priority[x['name']])
# def merge(self, cfg):
# """
# opencv shape order (height, width, channel)
# image coordinate system:
# |------------->x (width)
# |
# |
# |
# y (height)
# Returns:
#
#
# """
# base_image = Image.new('RGBA', self._layer[1]['image'].size, (0, 0, 0, 0))
# for layer in self._layer:
# y, x = layer['position']
# base_image.paste(layer['image'], (x, y), layer['image'])
# # base_image.show()
#
# for x in self._layer:
# if np.all(x['mask'] == 0):
# continue
# # obtain region of interest about roi(roi) and item-image(roi_image, roi_mask)
# roi, roi_mask, roi_image, signal = self.get_roi(dst=dst, image=x)
# temp_bg = np.expand_dims(cv2.bitwise_not(roi_mask), axis=2).repeat(3, axis=2)
# tmp1 = (roi * (temp_bg / 255)).astype(np.uint8)
# temp_fg = np.expand_dims(roi_mask, axis=2).repeat(3, axis=2)
# tmp2 = (roi_image * (temp_fg / 255)).astype(np.uint8)
#
# roi[:] = cv2.add(tmp1, tmp2)
# # show(cv2.resize(dst, (int(dst.shape[1] * 0.5), int(dst.shape[0] * 0.5)), interpolation=cv2.INTER_AREA),
# # win_name=x.get('name'))
# # crop image and get the central part
# if cfg.get('basic')['self_template'] == False:
# dst_roi = crop(dst)
# else:
# dst_roi = dst
# return dst_roi, signal
#
# @staticmethod
# def get_roi(dst, image):
# signal = False
# dst_y, dst_x = dst.shape[:2]
# roi_height, roi_width = image['mask'].shape
# roi_y0, roi_x0 = image['position']
#
# if roi_y0 < 0:
# roi_yin = 0
# mask_yin = -roi_y0
# signal = True
# else:
# roi_yin = roi_y0
# mask_yin = 0
# if roi_y0 + roi_height > dst_y:
# roi_yout = dst_y
# mask_yout = dst_y - roi_y0
# signal = True
# else:
# roi_yout = roi_height + roi_y0
# mask_yout = roi_height
# # x part
# if roi_x0 < 0:
# roi_xin = 0
# mask_xin = -roi_x0
# signal = True
# else:
# roi_xin = roi_x0
# mask_xin = 0
# if roi_x0 + roi_width > dst_x:
# roi_xout = dst_x
# mask_xout = dst_x - roi_x0
# signal = True
# else:
# roi_xout = roi_width + roi_x0
# mask_xout = roi_width
#
# roi = dst[roi_yin: roi_yout, roi_xin: roi_xout]
# roi_mask = image['mask'][mask_yin: mask_yout, mask_xin: mask_xout]
# roi_image = image['image'][mask_yin: mask_yout, mask_xin: mask_xout]
# return roi, roi_mask, roi_image, signal

View File

@@ -0,0 +1,45 @@
class Priority(object):
"""Item layer priority levels.
"""
def __init__(self, item_list):
self._priority = dict(
earring_front=99,
bag_front=98,
hairstyle_front=97,
outwear_front=20,
bottoms_front=19,
dress_front=18,
blouse_front=17,
skirt_front=16,
trousers_front=15,
tops_front=14,
shoes_right=1,
shoes_left=1,
body=0,
tops_back=-14,
trousers_back=-15,
skirt_back=-16,
blouse_back=-17,
dress_back=-18,
bottoms_back=-19,
outwear_back=-20,
hairstyle_back=-97,
bag_back=-98,
earring_back=-99,
)
self.clothing_start_num = 10
if not isinstance(item_list, list):
raise ValueError('item_list must be a list!')
for cate in item_list:
cate = cate.lower()
if cate not in ('outwear', 'dress', 'blouse', 'skirt', 'trousers', 'tops', 'bottoms'):
raise ValueError(f'Item type error. Cannot recognize {cate}')
for i, cate in enumerate(item_list):
cate = cate.lower()
self._priority[f'{cate}_front'] = self.clothing_start_num - i
self._priority[f'{cate}_back'] = -(self.clothing_start_num - i)
@property
def priority(self):
return self._priority
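# Usage sketch: derive an ordering for a concrete outfit (item names illustrative);
# Layer.sort consumes exactly this mapping.
if __name__ == "__main__":
    priority = Priority(["Outwear", "Blouse", "Trousers"]).priority
    layers = ["body", "trousers_back", "blouse_front", "outwear_front"]
    print(sorted(layers, key=lambda name: priority[name]))
    # body (0) renders between *_back layers (negative) and *_front layers (positive)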

View File

@@ -0,0 +1,771 @@
{
"objects": [
{
"basic": {
"body_point_test": {
"waistband_right": [
336,
264
],
"hand_point_right": [
350,
303
],
"waistband_left": [
245,
274
],
"hand_point_left": [
219,
315
],
"shoulder_left": [
227,
155
],
"shoulder_right": [
338,
149
]
},
"layer_order": false,
"scale_bag": 0.7,
"scale_earrings": 0.16,
"self_template": true,
"single_overall": "overall",
"switch_category": ""
},
"items": [
{
"businessId": 493827,
"color": "127 61 21",
"elementId": 493827,
"icon": "none",
"image_id": 110201,
"offset": [
1,
1
],
"path": "aida-users/31/sketch/62302527-2910-4740-808d-2cb8221daa34-3-31.png",
"print": {
"IfSingle": false,
"print_path_list": []
},
"resize_scale": [
1.0,
1.0
],
"type": "Dress"
},
{
"body_path": "aida-users/31/models/female/845046c7-4f62-4f54-a4a9-c26d49c6969335b5b3a9-d335-4871-a46c-3cc3caf07da259629dfd1f1f555a2e2a9def7e719366.png",
"image_id": 82966,
"offset": [
1,
1
],
"resize_scale": [
1.0,
1.0
],
"type": "Body"
}
]
},
{
"basic": {
"body_point_test": {
"waistband_right": [
336,
264
],
"hand_point_right": [
350,
303
],
"waistband_left": [
245,
274
],
"hand_point_left": [
219,
315
],
"shoulder_left": [
227,
155
],
"shoulder_right": [
338,
149
]
},
"layer_order": false,
"scale_bag": 0.7,
"scale_earrings": 0.16,
"self_template": true,
"single_overall": "overall",
"switch_category": ""
},
"items": [
{
"color": "27 25 23",
"icon": "none",
"image_id": 110202,
"offset": [
1,
1
],
"path": "aida-sys-image/images/female/skirt/0916000602.jpg",
"print": {
"IfSingle": false,
"print_path_list": []
},
"resize_scale": [
1.0,
1.0
],
"type": "Skirt"
},
{
"businessId": 493825,
"color": "229 214 200",
"elementId": 493825,
"icon": "none",
"image_id": 107101,
"offset": [
1,
1
],
"path": "aida-users/31/sketchboard/female/Blouse/de8f5656-d7ae-4642-bc90-f7f9d85da09b.jpg",
"print": {
"IfSingle": false,
"print_path_list": []
},
"resize_scale": [
1.0,
1.0
],
"type": "Blouse"
},
{
"businessId": 493824,
"color": "76 124 124",
"elementId": 493824,
"icon": "none",
"image_id": 104522,
"offset": [
1,
1
],
"path": "aida-users/31/sketch/3e82214a-0191-11ef-96d2-b48351119060_1.png",
"print": {
"IfSingle": false,
"print_path_list": []
},
"resize_scale": [
1.0,
1.0
],
"type": "Outwear"
},
{
"body_path": "aida-users/31/models/female/845046c7-4f62-4f54-a4a9-c26d49c6969335b5b3a9-d335-4871-a46c-3cc3caf07da259629dfd1f1f555a2e2a9def7e719366.png",
"image_id": 82966,
"offset": [
1,
1
],
"resize_scale": [
1.0,
1.0
],
"type": "Body"
}
]
},
{
"basic": {
"body_point_test": {
"waistband_right": [
336,
264
],
"hand_point_right": [
350,
303
],
"waistband_left": [
245,
274
],
"hand_point_left": [
219,
315
],
"shoulder_left": [
227,
155
],
"shoulder_right": [
338,
149
]
},
"layer_order": false,
"scale_bag": 0.7,
"scale_earrings": 0.16,
"self_template": true,
"single_overall": "overall",
"switch_category": ""
},
"items": [
{
"color": "229 214 200",
"icon": "none",
"image_id": 110203,
"offset": [
1,
1
],
"path": "aida-sys-image/images/female/blouse/0825001576.jpg",
"print": {
"IfSingle": false,
"print_path_list": []
},
"resize_scale": [
1.0,
1.0
],
"type": "Blouse"
},
{
"color": "76 124 124",
"icon": "none",
"image_id": 96071,
"offset": [
1,
1
],
"path": "aida-sys-image/images/female/skirt/903000097.jpg",
"print": {
"IfSingle": false,
"print_path_list": []
},
"resize_scale": [
1.0,
1.0
],
"type": "Skirt"
},
{
"color": "209 125 29",
"icon": "none",
"image_id": 93798,
"offset": [
1,
1
],
"path": "aida-sys-image/images/female/outwear/outwear_p4_561.jpg",
"print": {
"IfSingle": false,
"print_path_list": []
},
"resize_scale": [
1.0,
1.0
],
"type": "Outwear"
},
{
"body_path": "aida-users/31/models/female/845046c7-4f62-4f54-a4a9-c26d49c6969335b5b3a9-d335-4871-a46c-3cc3caf07da259629dfd1f1f555a2e2a9def7e719366.png",
"image_id": 82966,
"offset": [
1,
1
],
"resize_scale": [
1.0,
1.0
],
"type": "Body"
}
]
},
{
"basic": {
"body_point_test": {
"waistband_right": [
336,
264
],
"hand_point_right": [
350,
303
],
"waistband_left": [
245,
274
],
"hand_point_left": [
219,
315
],
"shoulder_left": [
227,
155
],
"shoulder_right": [
338,
149
]
},
"layer_order": false,
"scale_bag": 0.7,
"scale_earrings": 0.16,
"self_template": true,
"single_overall": "overall",
"switch_category": ""
},
"items": [
{
"businessId": 493824,
"color": "209 125 29",
"elementId": 493824,
"icon": "none",
"image_id": 104522,
"offset": [
1,
1
],
"path": "aida-users/31/sketch/3e82214a-0191-11ef-96d2-b48351119060_1.png",
"print": {
"IfSingle": false,
"print_path_list": []
},
"resize_scale": [
1.0,
1.0
],
"type": "Outwear"
},
{
"color": "118 123 115",
"icon": "none",
"image_id": 110204,
"offset": [
1,
1
],
"path": "aida-sys-image/images/female/blouse/0902000457.jpg",
"print": {
"IfSingle": false,
"print_path_list": []
},
"resize_scale": [
1.0,
1.0
],
"type": "Blouse"
},
{
"color": "118 123 115",
"icon": "none",
"image_id": 79259,
"offset": [
1,
1
],
"path": "aida-sys-image/images/female/trousers/826000094.jpg",
"print": {
"IfSingle": false,
"print_path_list": []
},
"resize_scale": [
1.0,
1.0
],
"type": "Trousers"
},
{
"body_path": "aida-users/31/models/female/845046c7-4f62-4f54-a4a9-c26d49c6969335b5b3a9-d335-4871-a46c-3cc3caf07da259629dfd1f1f555a2e2a9def7e719366.png",
"image_id": 82966,
"offset": [
1,
1
],
"resize_scale": [
1.0,
1.0
],
"type": "Body"
}
]
},
{
"basic": {
"body_point_test": {
"waistband_right": [
336,
264
],
"hand_point_right": [
350,
303
],
"waistband_left": [
245,
274
],
"hand_point_left": [
219,
315
],
"shoulder_left": [
227,
155
],
"shoulder_right": [
338,
149
]
},
"layer_order": false,
"scale_bag": 0.7,
"scale_earrings": 0.16,
"self_template": true,
"single_overall": "overall",
"switch_category": ""
},
"items": [
{
"color": "127 61 21",
"icon": "none",
"image_id": 96038,
"offset": [
1,
1
],
"path": "aida-sys-image/images/female/dress/0902003549.jpg",
"print": {
"IfSingle": false,
"print_path_list": []
},
"resize_scale": [
1.0,
1.0
],
"type": "Dress"
},
{
"body_path": "aida-users/31/models/female/845046c7-4f62-4f54-a4a9-c26d49c6969335b5b3a9-d335-4871-a46c-3cc3caf07da259629dfd1f1f555a2e2a9def7e719366.png",
"image_id": 82966,
"offset": [
1,
1
],
"resize_scale": [
1.0,
1.0
],
"type": "Body"
}
]
},
{
"basic": {
"body_point_test": {
"waistband_right": [
336,
264
],
"hand_point_right": [
350,
303
],
"waistband_left": [
245,
274
],
"hand_point_left": [
219,
315
],
"shoulder_left": [
227,
155
],
"shoulder_right": [
338,
149
]
},
"layer_order": false,
"scale_bag": 0.7,
"scale_earrings": 0.16,
"self_template": true,
"single_overall": "overall",
"switch_category": ""
},
"items": [
{
"businessId": 493822,
"color": "127 61 21",
"elementId": 493822,
"icon": "none",
"image_id": 62309,
"offset": [
1,
1
],
"path": "aida-users/31/sketchboard/female/trousers/c37c2ea6-8955-4b40-8339-c737e672ca3d.jpg",
"print": {
"IfSingle": false,
"print_path_list": []
},
"resize_scale": [
1.0,
1.0
],
"type": "Trousers"
},
{
"businessId": 493825,
"color": "118 123 115",
"elementId": 493825,
"icon": "none",
"image_id": 107101,
"offset": [
1,
1
],
"path": "aida-users/31/sketchboard/female/Blouse/de8f5656-d7ae-4642-bc90-f7f9d85da09b.jpg",
"print": {
"IfSingle": false,
"print_path_list": []
},
"resize_scale": [
1.0,
1.0
],
"type": "Blouse"
},
{
"body_path": "aida-users/31/models/female/845046c7-4f62-4f54-a4a9-c26d49c6969335b5b3a9-d335-4871-a46c-3cc3caf07da259629dfd1f1f555a2e2a9def7e719366.png",
"image_id": 82966,
"offset": [
1,
1
],
"resize_scale": [
1.0,
1.0
],
"type": "Body"
}
]
},
{
"basic": {
"body_point_test": {
"waistband_right": [
336,
264
],
"hand_point_right": [
350,
303
],
"waistband_left": [
245,
274
],
"hand_point_left": [
219,
315
],
"shoulder_left": [
227,
155
],
"shoulder_right": [
338,
149
]
},
"layer_order": false,
"scale_bag": 0.7,
"scale_earrings": 0.16,
"self_template": true,
"single_overall": "overall",
"switch_category": ""
},
"items": [
{
"businessId": 493826,
"color": "127 61 21",
"elementId": 493826,
"icon": "none",
"image_id": 107105,
"offset": [
1,
1
],
"path": "aida-users/31/sketchboard/female/Skirt/58710352-6301-450d-b69a-fb2922b5429a.png",
"print": {
"IfSingle": false,
"print_path_list": []
},
"resize_scale": [
1.0,
1.0
],
"type": "Skirt"
},
{
"color": "118 123 115",
"icon": "none",
"image_id": 79114,
"offset": [
1,
1
],
"path": "aida-sys-image/images/female/blouse/903000169.jpg",
"print": {
"IfSingle": false,
"print_path_list": []
},
"resize_scale": [
1.0,
1.0
],
"type": "Blouse"
},
{
"color": "229 214 200",
"icon": "none",
"image_id": 90573,
"offset": [
1,
1
],
"path": "aida-sys-image/images/female/outwear/0628000541.jpg",
"print": {
"IfSingle": false,
"print_path_list": []
},
"resize_scale": [
1.0,
1.0
],
"type": "Outwear"
},
{
"body_path": "aida-users/31/models/female/845046c7-4f62-4f54-a4a9-c26d49c6969335b5b3a9-d335-4871-a46c-3cc3caf07da259629dfd1f1f555a2e2a9def7e719366.png",
"image_id": 82966,
"offset": [
1,
1
],
"resize_scale": [
1.0,
1.0
],
"type": "Body"
}
]
},
{
"basic": {
"body_point_test": {
"waistband_right": [
336,
264
],
"hand_point_right": [
350,
303
],
"waistband_left": [
245,
274
],
"hand_point_left": [
219,
315
],
"shoulder_left": [
227,
155
],
"shoulder_right": [
338,
149
]
},
"layer_order": false,
"scale_bag": 0.7,
"scale_earrings": 0.16,
"self_template": true,
"single_overall": "overall",
"switch_category": ""
},
"items": [
{
"color": "229 214 200",
"icon": "none",
"image_id": 110205,
"offset": [
1,
1
],
"path": "aida-sys-image/images/female/trousers/0916000217.jpg",
"print": {
"IfSingle": false,
"print_path_list": []
},
"resize_scale": [
1.0,
1.0
],
"type": "Trousers"
},
{
"businessId": 493825,
"color": "209 125 29",
"elementId": 493825,
"icon": "none",
"image_id": 107101,
"offset": [
1,
1
],
"path": "aida-users/31/sketchboard/female/Blouse/de8f5656-d7ae-4642-bc90-f7f9d85da09b.jpg",
"print": {
"IfSingle": false,
"print_path_list": []
},
"resize_scale": [
1.0,
1.0
],
"type": "Blouse"
},
{
"body_path": "aida-users/31/models/female/845046c7-4f62-4f54-a4a9-c26d49c6969335b5b3a9-d335-4871-a46c-3cc3caf07da259629dfd1f1f555a2e2a9def7e719366.png",
"image_id": 82966,
"offset": [
1,
1
],
"resize_scale": [
1.0,
1.0
],
"type": "Body"
}
]
}
],
"process_id": "6878547032381675"
}

View File

@@ -0,0 +1,16 @@
from .builder import ITEMS, build_item
from .clothing import Clothing # 4.0 sec
from .body import Body
from .top import Top, Blouse, Outwear, Dress
from .bottom import Bottom, Trousers, Skirt
from .shoes import Shoes
from .bag import Bag
from .accessories import Hairstyle, Earring
__all__ = [
'ITEMS', 'build_item',
'Clothing', 'Body',
'Top', 'Blouse', 'Outwear', 'Dress',
'Bottom', 'Trousers', 'Skirt',
'Shoes', 'Bag', 'Hairstyle', 'Earring'
]

View File

@@ -0,0 +1,59 @@
from .builder import ITEMS
from .clothing import Clothing
@ITEMS.register_module()
class Hairstyle(Clothing):
def __init__(self, **kwargs):
pipeline = [
dict(type='LoadImageFromFile', path=kwargs['path']),
dict(type='KeypointDetection'),
dict(type='ContourDetection'),
dict(type='Painting'),
dict(type='Scaling'),
dict(type='Split'),
# dict(type='ImageShow', key=['image', 'mask', 'pattern_image']),
]
kwargs.update(pipeline=pipeline)
super(Hairstyle, self).__init__(**kwargs)
@staticmethod
def calculate_start_point(keypoint_type, scale, clothes_point, body_point):
"""
align up
Args:
keypoint_type: string, "head_point"
scale: float
clothes_point: dict{'left': [x1, y1, z1], 'right': [x2, y2, z2]}
body_point: dict, containing keypoint data of body figure
Returns:
start_point: tuple (y', x')
y' = body_y - clothes_y * scale
x' = body_x - clothes_x * scale
"""
side_indicator = f'{keypoint_type}_up'
# clothes_point = {k: tuple(map(lambda x: int(scale * x), v[0: 2])) for k, v in clothes_point.items()}
# logging.info(clothes_point[side_indicator])
start_point = (
int(body_point[side_indicator][1] - int(clothes_point[side_indicator].split("_")[1]) * scale),
int(body_point[side_indicator][0] - int(clothes_point[side_indicator].split("_")[0]) * scale)
)
return start_point
@ITEMS.register_module()
class Earring(Clothing):
def __init__(self, **kwargs):
pipeline = [
dict(type='LoadImageFromFile', path=kwargs['path']),
dict(type='KeypointDetection'),
dict(type='ContourDetection'),
dict(type='Painting'),
dict(type='Scaling'),
dict(type='Split'),
# dict(type='ImageShow', key=['image', 'mask', 'pattern_image']),
]
kwargs.update(pipeline=pipeline)
super(Earring, self).__init__(**kwargs)

View File

@@ -0,0 +1,45 @@
import random
from .builder import ITEMS
from .clothing import Clothing
@ITEMS.register_module()
class Bag(Clothing):
def __init__(self, **kwargs):
pipeline = [
dict(type='LoadImageFromFile', path=kwargs['path'], color=kwargs['color']),
dict(type='KeypointDetection'),
dict(type='ContourDetection'),
dict(type='Painting'),
dict(type='Scaling'),
dict(type='Split'),
# dict(type='ImageShow', key=['image', 'mask', 'pattern_image']),
]
kwargs.update(pipeline=pipeline)
super(Bag, self).__init__(**kwargs)
@staticmethod
def calculate_start_point(keypoint_type, scale, clothes_point, body_point):
"""
align left
Args:
keypoint_type: string, "hand_point"
scale: float
clothes_point: dict{'left': [x1, y1, z1], 'right': [x2, y2, z2]}
body_point: dict, containing keypoint data of body figure
Returns:
start_point: tuple (y', x')
y' = body_y - clothes_y * scale
x' = body_x - clothes_x * scale
"""
location = random.choice(['left', 'right'])
side_indicator = f'{keypoint_type}_{location}'
# clothes_point = {k: tuple(map(lambda x: int(scale * x), v[0: 2])) for k, v in clothes_point.items()}
start_point = (body_point[side_indicator][1] - int(int(clothes_point[keypoint_type].split("_")[1]) * scale),
body_point[side_indicator][0] - int(int(clothes_point[keypoint_type].split("_")[0]) * scale))
return start_point

View File

@@ -0,0 +1,36 @@
import cv2
from .builder import ITEMS
from .pipelines import Compose
@ITEMS.register_module()
class Body(object):
def __init__(self, **kwargs):
pipeline = [
dict(type='LoadBodyImageFromFile', body_path=kwargs['body_path']),
# dict(type='ImageShow', key=['body_image', "body_mask"])
]
self.pipeline = Compose(pipeline)
self.result = dict()
def process(self):
self.pipeline(self.result)
def organize(self, layer):
body_layer = dict(priority=0,
name=type(self).__name__.lower(),
image=self.result['body_image'],
image_url=self.result['image_url'],
mask_image=None,
mask_url=None,
scale=1,
# mask=self.result['body_mask'],
position=(0, 0))
layer.insert(body_layer)
@staticmethod
def show(img):
cv2.imshow('', img)
cv2.waitKey(0)

View File

@@ -0,0 +1,38 @@
from .builder import ITEMS
from .clothing import Clothing
@ITEMS.register_module()
class Bottom(Clothing):
def __init__(self, pipeline, **kwargs):
if pipeline is None:
pipeline = [
dict(type='LoadImageFromFile', path=kwargs['path'], color=kwargs['color'], print_dict=kwargs['print']),
dict(type='KeypointDetection'),
dict(type='ContourDetection'),
dict(type='Painting', painting_flag=True),
dict(type='PrintPainting', print_flag=True),
dict(type='Scaling'),
dict(type='Split'),
# dict(type='ImageShow', key=['image', 'mask', 'pattern_image', 'print_image']),
]
kwargs.update(pipeline=pipeline)
super(Bottom, self).__init__(**kwargs)
@ITEMS.register_module()
class Trousers(Bottom):
def __init__(self, pipeline=None, **kwargs):
super(Trousers, self).__init__(pipeline, **kwargs)
@ITEMS.register_module()
class Skirt(Bottom):
def __init__(self, pipeline=None, **kwargs):
super(Skirt, self).__init__(pipeline, **kwargs)
@ITEMS.register_module()
class Bottoms(Bottom):
def __init__(self, pipeline=None, **kwargs):
super(Bottoms, self).__init__(pipeline, **kwargs)

View File

@@ -0,0 +1,9 @@
from mmcv.utils import Registry, build_from_cfg
ITEMS = Registry('item')
PIPELINES = Registry('pipeline')
def build_item(cfg, default_args=None):
item = build_from_cfg(cfg, ITEMS, default_args)
return item
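# Usage sketch: the registry turns config dicts into item instances. A trivial class
# is registered here for illustration; real configs use types such as 'Blouse' or 'Skirt'.
if __name__ == "__main__":
    @ITEMS.register_module()
    class Dummy:
        def __init__(self, path):
            self.path = path

    item = build_item(dict(type="Dummy", path="bucket/sketch.png"))
    print(type(item).__name__, item.path)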

View File

@@ -0,0 +1,99 @@
import cv2
from app.core.config import PRIORITY_DICT
from .builder import ITEMS
from .pipelines import Compose
@ITEMS.register_module()
class Clothing(object):
def __init__(self, pipeline, **kwargs):
self.pipeline = Compose(pipeline)
self.result = dict(name=type(self).__name__.lower(), **kwargs)
def process(self):
self.pipeline(self.result)
def apply_scale(self, img):
scale = self.result['scale']
height, width = img.shape[0: 2]
scaled_img = cv2.resize(img, (int(width * scale), int(height * scale)), interpolation=cv2.INTER_AREA)
return scaled_img
def organize(self, layer):
start_point = self.calculate_start_point(self.result['keypoint'], self.result['scale'], self.result['clothes_keypoint'], self.result['body_point_test'], self.result["offset"], self.result["resize_scale"])
front_layer = dict(priority=self.result.get("priority", None) if self.result.get("layer_order", False) else PRIORITY_DICT.get(f'{type(self).__name__.lower()}_front', None),
name=f'{type(self).__name__.lower()}_front',
image=self.result["front_image"],
# mask_image=self.result['front_mask_image'],
image_url=self.result['front_image_url'],
mask_url=self.result['front_mask_url'],
scale=self.result['scale'],
clothes_keypoint=self.result['clothes_keypoint'],
position=start_point,
resize_scale=self.result["resize_scale"],
mask=cv2.resize(self.result['mask'], self.result["front_image"].size),
gradient_string=self.result.get('gradient_string', ""),
pattern_image_url=self.result['pattern_image_url']
)
layer.insert(front_layer)
back_layer = dict(priority=-self.result.get("priority", 0) if self.result.get("layer_order", False) else PRIORITY_DICT.get(f'{type(self).__name__.lower()}_back', None),
name=f'{type(self).__name__.lower()}_back',
image=self.result["back_image"],
# mask_image=self.result['back_mask_image'],
image_url=self.result['back_image_url'],
mask_url=self.result['back_mask_url'],
scale=self.result['scale'],
clothes_keypoint=self.result['clothes_keypoint'],
position=start_point,
resize_scale=self.result["resize_scale"],
mask=cv2.resize(self.result['mask'], self.result["front_image"].size),
gradient_string=self.result.get('gradient_string', ""),
pattern_image_url=self.result['pattern_image_url']
)
layer.insert(back_layer)
@staticmethod
def calculate_start_point(keypoint_type, scale, clothes_point, body_point, offset, resize_scale):
"""
Align left
Args:
keypoint_type: string, "waistband" | "shoulder" | "ear_point"
scale: float
clothes_point: dict{'left': [x1, y1, z1], 'right': [x2, y2, z2]}
body_point: dict, containing keypoint data of body figure
Returns:
start_point: tuple (y', x')
y' = body_y + offset_y - clothes_point_y * scale
x' = body_x + offset_x - clothes_point_x * scale
"""
side_indicator = f'{keypoint_type}_left'
# if keypoint_type == "ear_point":
# start_point = (body_point[side_indicator][1] - int(int(clothes_point[side_indicator].split("_")[1]) * scale),
# body_point[side_indicator][0] - int(int(clothes_point[side_indicator].split("_")[0]) * scale))
# else:
# start_point = (
# int(body_point[side_indicator][1] + offset[1] - int(clothes_point[side_indicator].split("_")[0]) * scale), # y
# int(body_point[side_indicator][0] + offset[0] - int(clothes_point[side_indicator].split("_")[1]) * scale) # x
# )
# milvus_DB_keypoint_cache:
start_point = (
int(body_point[side_indicator][1] + offset[1] - int(clothes_point[side_indicator][0]) * scale), # y
int(body_point[side_indicator][0] + offset[0] - int(clothes_point[side_indicator][1]) * scale) # x
)
# start_point = (
# int(body_point[side_indicator][1] + offset[1] - int(clothes_point[side_indicator].split("_")[0]) * scale), # y
# int(body_point[side_indicator][0] + offset[0] - int(clothes_point[side_indicator].split("_")[1]) * scale) # x
# )
return start_point

View File

@@ -0,0 +1,19 @@
from .compose import Compose
from .loading import LoadImageFromFile, LoadBodyImageFromFile, ImageShow
from .keypoints import KeypointDetection
from .segmentation import Segmentation
from .painting import Painting, PrintPainting
from .scale import Scaling
from .contour_detection import ContourDetection
from .split import Split
__all__ = [
'Compose',
'LoadImageFromFile', 'LoadBodyImageFromFile', 'ImageShow',
'KeypointDetection',
'Segmentation',
'Painting', 'PrintPainting',
'Scaling',
'ContourDetection',
'Split',
]

View File

@@ -0,0 +1,36 @@
import collections
from mmcv.utils import build_from_cfg
from ..builder import PIPELINES
@PIPELINES.register_module()
class Compose(object):
def __init__(self, transforms):
assert isinstance(transforms, collections.abc.Sequence)
self.transforms = []
for transform in transforms:
if isinstance(transform, dict):
transform = build_from_cfg(transform, PIPELINES)
self.transforms.append(transform)
elif callable(transform):
self.transforms.append(transform)
else:
raise TypeError('transform must be callable or a dict')
def __call__(self, data):
"""Call function to apply transforms sequentially.
Args:
data (dict): A result dict contains the data to transform.
Returns:
dict: Transformed data.
"""
for t in self.transforms:
data = t(data)
if data is None:
return None
return data
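# Usage sketch: transforms may be config dicts (resolved through PIPELINES) or plain
# callables; each one receives and returns the shared result dict.
if __name__ == "__main__":
    def tag(data):
        data["tagged"] = True
        return data

    pipeline = Compose([tag])
    print(pipeline({"name": "blouse"}))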

View File

@@ -0,0 +1,58 @@
import cv2
import numpy as np
from ..builder import PIPELINES
@PIPELINES.register_module()
class ContourDetection(object):
def __init__(self):
# logging.info("ContourDetection run ")
pass
# @ RunTime
def __call__(self, result):
# shoe diff
if result['name'] == 'shoes':
Contour = self.get_contours(result['image'])
Mask = np.zeros(result['image'].shape[:2], np.uint8)
for i in range(min(2, len(Contour))):  # one contour per shoe, guarding against missing detections
Max_contour = Contour[i]
Epsilon = 0.001 * cv2.arcLength(Max_contour, True)
Approx = cv2.approxPolyDP(Max_contour, Epsilon, True)
cv2.drawContours(Mask, [Approx], -1, 255, -1)
if result['pre_mask'] is None:
result['mask'] = Mask
else:
result['mask'] = cv2.bitwise_and(Mask, result['pre_mask'])
else:
Contour = self.get_contours(result['image'])
Mask = np.zeros(result['image'].shape[:2], np.uint8)
if len(Contour):
Max_contour = Contour[0]
Epsilon = 0.001 * cv2.arcLength(Max_contour, True)
Approx = cv2.approxPolyDP(Max_contour, Epsilon, True)
cv2.drawContours(Mask, [Approx], -1, 255, -1)
else:
Mask = np.ones(result['image'].shape[:2], np.uint8) * 255
# TODO: fix some images coming out partially transparent; ship in the next release
# img2gray = cv2.cvtColor(result['image'], cv2.COLOR_BGR2GRAY)
# ret, Mask = cv2.threshold(img2gray, 126, 255, cv2.THRESH_BINARY)
# Mask = cv2.bitwise_not(Mask)
if result['pre_mask'] is None:
result['mask'] = Mask
else:
result['mask'] = cv2.bitwise_and(Mask, result['pre_mask'])
return result
@staticmethod
def get_contours(image):
gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
Edge = cv2.Canny(gray, 10, 150)
kernel = np.ones((5, 5), np.uint8)
Edge = cv2.dilate(Edge, kernel=kernel, iterations=1)
Edge = cv2.erode(Edge, kernel=kernel, iterations=1)
Contour, _ = cv2.findContours(Edge, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
Contour = sorted(Contour, key=cv2.contourArea, reverse=True)
return Contour
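# Usage sketch on a synthetic image: a filled rectangle yields one dominant contour.
if __name__ == "__main__":
    canvas = np.zeros((200, 200, 3), np.uint8)
    cv2.rectangle(canvas, (50, 50), (150, 150), (255, 255, 255), -1)
    contours = ContourDetection.get_contours(canvas)
    print(len(contours), cv2.contourArea(contours[0]))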

View File

@@ -0,0 +1,139 @@
import logging
import time
import numpy as np
from pymilvus import MilvusClient
from app.core.config import *
from ..builder import PIPELINES
from ...utils.design_ensemble import get_keypoint_result
@PIPELINES.register_module()
class KeypointDetection(object):
"""
path here: abstract path
"""
# def __init__(self):
# self.client = MilvusClient(
# uri="http://10.1.1.240:19530",
# token="root:Milvus",
# db_name=MILVUS_ALIAS
# )
# def __del__(self):
# start_time = time.time()
# self.client.close()
# print(f"client close time : {time.time() - start_time}")
# @ RunTime
def __call__(self, result):
# logging.info("KeypointDetection run ")
if result['name'] in ['blouse', 'skirt', 'dress', 'outwear', 'trousers', 'tops', 'bottoms']:  # check the cache: same category -> read directly, otherwise infer and update
# result['clothes_keypoint'] = self.infer_keypoint_result(result)
site = 'up' if result['name'] in ['blouse', 'outwear', 'dress', 'tops'] else 'down'
# keypoint_cache = search_keypoint_cache(result["image_id"], site)
keypoint_cache = self.keypoint_cache(result, site)
# to bypass the vector lookup and always run model inference, force:
# keypoint_cache = False
if keypoint_cache is False:
keypoint_infer_result, site = self.infer_keypoint_result(result)
result['clothes_keypoint'] = self.save_keypoint_cache(result["image_id"], keypoint_infer_result, site)
else:
result['clothes_keypoint'] = keypoint_cache
return result
@staticmethod
def infer_keypoint_result(result):
site = 'up' if result['name'] in ['blouse', 'outwear', 'dress', 'tops'] else 'down'
start_time = time.time()
keypoint_infer_result = get_keypoint_result(result["image"], site) # 推理结果
# logging.info(f"infer keypoint time : {time.time() - start_time}")
return keypoint_infer_result, site
@staticmethod
# @ RunTime
def save_keypoint_cache(keypoint_id, cache, site):
# the cached vector always holds 12 (x, y) pairs flattened to 24 ints:
# the 10 "up" keypoints come first, the 2 "down" keypoints last
if site == "down":
zeros = np.zeros(20, dtype=int)
result = np.concatenate([zeros, cache.flatten()])
else:
zeros = np.zeros(4, dtype=int)
result = np.concatenate([cache.flatten(), zeros])
# to skip persisting the vector and just return the result, drop the upsert below
data = [
{"keypoint_id": keypoint_id,
"keypoint_site": site,
"keypoint_vector": result.tolist()
}
]
try:
client = MilvusClient(uri=MILVUS_URL, token=MILVUS_TOKEN, db_name=MILVUS_ALIAS)
# start_time = time.time()
res = client.upsert(collection_name=MILVUS_TABLE_KEYPOINT, data=data)
# logging.info(f"save keypoint time : {time.time() - start_time}")
client.close()
return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
except Exception as e:
logging.info(f"save keypoint cache milvus error : {e}")
return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
@staticmethod
def update_keypoint_cache(keypoint_id, infer_result, search_result, site):
if site == "up":
# 需要的是up 即推理出来的是up 那么查询的就是down
result = np.concatenate([infer_result.flatten(), search_result[-4:]])
else:
# 需要的是down 即推理出来的是down 那么查询的就是up
result = np.concatenate([search_result[:20], infer_result.flatten()])
data = [
{"keypoint_id": keypoint_id,
"keypoint_site": "all",
"keypoint_vector": result.tolist()
}
]
try:
client = MilvusClient(uri=MILVUS_URL, token=MILVUS_TOKEN, db_name=MILVUS_ALIAS)
# connections.connect(alias=MILVUS_ALIAS, host=MILVUS_DB_HOST, port=MILVUS_PORT)
start_time = time.time()
# collection = Collection(MILVUS_TABLE_KEYPOINT) # Get an existing collection.
# mr = collection.upsert(data)
client.upsert(
collection_name=MILVUS_TABLE_KEYPOINT,
data=data
)
# logging.info(f"save keypoint time : {time.time() - start_time}")
return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
except Exception as e:
logging.info(f"save keypoint cache milvus error : {e}")
return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
# @ RunTime
def keypoint_cache(self, result, site):
try:
client = MilvusClient(uri=MILVUS_URL, token=MILVUS_TOKEN, db_name=MILVUS_ALIAS)
keypoint_id = result['image_id']
res = client.query(
collection_name=MILVUS_TABLE_KEYPOINT,
# ids=[keypoint_id],
filter=f"keypoint_id == {keypoint_id}",
output_fields=['keypoint_vector', 'keypoint_site']
)
if len(res) == 0:
# no cache hit: run inference and persist the result
keypoint_infer_result, site = self.infer_keypoint_result(result)
return self.save_keypoint_cache(result['image_id'], keypoint_infer_result, site)
elif res[0]["keypoint_site"] == "all" or res[0]["keypoint_site"] == site:
# the cached site matches the requested one (or covers "all"): return the cached result directly
return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, np.array(res[0]['keypoint_vector']).astype(int).reshape(12, 2).tolist()))
elif res[0]["keypoint_site"] != site:
# the cached site differs from the requested one: infer the missing half and upgrade the entry to "all"
keypoint_infer_result, site = self.infer_keypoint_result(result)
return self.update_keypoint_cache(result["image_id"], keypoint_infer_result, res[0]['keypoint_vector'], site)
except Exception as e:
logging.info(f"search keypoint cache milvus error {e}")
return False
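# Round-trip sketch of the cache vector layout used by save_keypoint_cache: 24 ints,
# i.e. 12 (x, y) pairs, "up" keypoints first and "down" last. Assumes
# KEYPOINT_RESULT_TABLE_FIELD_SET (from app.core.config) names the 12 keypoints.
if __name__ == "__main__":
    up = np.arange(20)  # stand-in for an "up" inference result (10 keypoints)
    vector = np.concatenate([up, np.zeros(4, dtype=int)])
    pairs = vector.reshape(12, 2).astype(int).tolist()
    print(dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, pairs)))  # what keypoint_cache returns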

View File

@@ -0,0 +1,130 @@
import cv2
from app.service.utils.oss_client import oss_get_image
from ..builder import PIPELINES
@PIPELINES.register_module()
class LoadImageFromFile(object):
def __init__(self, path, color=None, print_dict=None):
self.path = path
self.color = color
self.print_dict = print_dict
# self.minio_client = Minio(f"{MINIO_URL}", access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
def __call__(self, result):
result['image'], result['pre_mask'] = self.read_image(self.path)
result['gray'] = cv2.cvtColor(result['image'], cv2.COLOR_BGR2GRAY)
result['keypoint'] = self.get_keypoint(result['name'])
result['path'] = self.path
result['img_shape'] = result['image'].shape
result['ori_shape'] = result['image'].shape
result['color'] = self.color
result['print_dict'] = self.print_dict
return result
@staticmethod
def get_keypoint(name):
if name in ('blouse', 'outwear', 'dress', 'tops'):
keypoint = 'shoulder'
elif name in ('trousers', 'skirt', 'bottoms'):
keypoint = 'waistband'
elif name == 'bag':
keypoint = 'hand_point'
elif name == 'shoes':
keypoint = 'toe'
elif name == 'hairstyle':
keypoint = 'head_point'
elif name == 'earring':
keypoint = 'ear_point'
else:
raise KeyError(f"{name} does not belong to item category list: blouse, outwear, dress, trousers, skirt, "
f"bag, shoes, hairstyle, earring.")
return keypoint
@staticmethod
def read_image(image_path):
image_mask = None
# file = self.minio_client.get_object(image_path.split("/", 1)[0], image_path.split("/", 1)[1]).data
# image = cv2.imdecode(np.frombuffer(file, np.uint8), 1)
image = oss_get_image(bucket=image_path.split("/", 1)[0], object_name=image_path.split("/", 1)[1], data_type="cv2")
if len(image.shape) == 2:
image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
if image.shape[2] == 4:  # four-channel image: split off the alpha channel as the mask
image_mask = image[:, :, 3]
image = image[:, :, :3]
return image, image_mask
@PIPELINES.register_module()
class LoadBodyImageFromFile(object):
def __init__(self, body_path):
self.body_path = body_path
# self.minioClient = Minio(f"{MINIO_URL}", access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
# response = self.minioClient.get_object("aida-mannequins", "model_1693218345.2714431.png")
# @ RunTime
def __call__(self, result):
result["image_url"] = result['body_path'] = self.body_path
result["name"] = "mannequin"
# if not result['image_url'].lower().endswith(".png"):
# bucket = self.body_path.split("/", 1)[0]
# object_name = self.body_path.split("/", 1)[1]
# new_object_name = f'{object_name[:object_name.rfind(".")]}.png'
# image = self.minioClient.get_object(bucket, object_name)
# image = Image.open(io.BytesIO(image.data))
# image = image.convert("RGBA")
# data = image.getdata()
# #
# new_data = []
# for item in data:
# if item[0] >= 230 and item[1] >= 230 and item[2] >= 230:
# new_data.append((255, 255, 255, 0))
# else:
# new_data.append(item)
# image.putdata(new_data)
# image_data = io.BytesIO()
# image.save(image_data, format='PNG')
# image_data.seek(0)
# image_bytes = image_data.read()
# image_path = f"{bucket}/{self.minioClient.put_object(bucket, new_object_name, io.BytesIO(image_bytes), len(image_bytes), content_type='image/png').object_name}"
# self.body_path = image_path
# result["image_url"] = result['body_path'] = self.body_path
# response = self.minioClient.get_object(self.body_path.split("/", 1)[0], self.body_path.split("/", 1)[1])
# put_image_time = time.time()
# result['body_image'] = Image.open(io.BytesIO(response.read()))
result['body_image'] = oss_get_image(bucket=self.body_path.split("/", 1)[0], object_name=self.body_path.split("/", 1)[1], data_type="PIL")
# logging.info(f"Image.open time is : {time.time() - put_image_time}")
return result
@PIPELINES.register_module()
class ImageShow(object):
def __init__(self, key):
self.key = key
# @ RunTime
def __call__(self, result):
import matplotlib.pyplot as plt
if isinstance(self.key, list):
for key in self.key:
plt.imshow(result[key])
plt.title(key)
plt.show()
elif isinstance(self.key, str):
img = self._resize_img(result[self.key])
cv2.imshow(self.key, img)
cv2.waitKey(0)
else:
raise TypeError(f'key should be string but got type {type(self.key)}.')
return result
@staticmethod
def _resize_img(img):
shape = img.shape
if shape[0] > 400 or shape[1] > 400:
ratio = min(400 / shape[0], 400 / shape[1])
img = cv2.resize(img, (int(ratio * shape[1]), int(ratio * shape[0])))
return img

View File

@@ -0,0 +1,611 @@
import random
import cv2
import numpy as np
from PIL import Image
from app.service.utils.oss_client import oss_get_image
from ..builder import PIPELINES
@PIPELINES.register_module()
class Painting(object):
def __init__(self, painting_flag=True):
self.painting_flag = painting_flag
# @ RunTime
def __call__(self, result):
if result['name'] not in ['hairstyle', 'earring'] and self.painting_flag and result['color'] != 'none':
dim_image_h, dim_image_w = result['image'].shape[0:2]
if "gradient" in result.keys() and result['gradient'] != "":
bucket_name = result['gradient'].split('/')[0]
object_name = result['gradient'][result['gradient'].find('/') + 1:]
pattern = self.get_gradient(bucket_name=bucket_name, object_name=object_name)
resize_pattern = cv2.resize(pattern, (dim_image_w, dim_image_h), interpolation=cv2.INTER_AREA)
else:
pattern = self.get_pattern(result['color'])
resize_pattern = cv2.resize(pattern, (dim_image_w, dim_image_h), interpolation=cv2.INTER_AREA)
closed_mo = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2)
gray_mo = np.expand_dims(result['gray'], axis=2).repeat(3, axis=2)
get_image_fir = resize_pattern * (closed_mo / 255) * (gray_mo / 255)
result['pattern_image'] = get_image_fir.astype(np.uint8)
result['final_image'] = result['pattern_image']
canvas = np.full_like(result['final_image'], 255)
temp_bg = np.expand_dims(cv2.bitwise_not(result['mask']), axis=2).repeat(3, axis=2)
tmp1 = (canvas * (temp_bg / 255)).astype(np.uint8)
temp_fg = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2)
tmp2 = (result['final_image'] * (temp_fg / 255)).astype(np.uint8)
result['single_image'] = cv2.add(tmp1, tmp2)
result['alpha'] = 100 / 255.0
else:
closed_mo = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2)
get_image_fir = result['image'] * (closed_mo / 255)
result['pattern_image'] = get_image_fir.astype(np.uint8)
result['final_image'] = result['pattern_image']
return result
@staticmethod
def get_gradient(bucket_name, object_name):
# image_data = minio_client.get_object(bucket_name, object_name)
# image_data = s3.get_object(Bucket=bucket_name, Key=object_name)['Body']
# read the image from the data stream
# image_bytes = image_data.read()
# convert the image bytes into a numpy array
# image_array = np.asarray(bytearray(image_bytes), dtype=np.uint8)
# decode the image array with OpenCV
# image = cv2.imdecode(image_array, cv2.IMREAD_COLOR)
image = oss_get_image(bucket=bucket_name, object_name=object_name, data_type="cv2")
if image.shape[2] == 4:
image = cv2.cvtColor(image, cv2.COLOR_BGRA2BGR)
return image
@staticmethod
def crop_image(image, image_size_h, image_size_w):
x_offset = np.random.randint(low=0, high=int(image_size_h / 5) - 6)
y_offset = np.random.randint(low=0, high=int(image_size_w / 5) - 6)
image = image[x_offset: x_offset + image_size_h, y_offset: y_offset + image_size_w, :]
return image
@staticmethod
def get_pattern(single_color):
if single_color is None:
raise ValueError("single_color must be an 'R G B' string, got None")
R, G, B = single_color.split(' ')
pattern = np.zeros([1, 1, 3], np.uint8)  # OpenCV uses BGR channel order
pattern[0, 0, 0] = int(B)
pattern[0, 0, 1] = int(G)
pattern[0, 0, 2] = int(R)
return pattern
@PIPELINES.register_module()
class PrintPainting(object):
def __init__(self, print_flag=True):
self.print_flag = print_flag
# @ RunTime
def __call__(self, result):
single_print = result['print']['single']
overall_print = result['print']['overall']
element_print = result['print']['element']
result['single_image'] = None
result['print_image'] = None
if overall_print['print_path_list']:
painting_dict = {'dim_image_h': result['pattern_image'].shape[0], 'dim_image_w': result['pattern_image'].shape[1]}
result['print_image'] = result['pattern_image']
if "print_angle_list" in overall_print.keys() and overall_print['print_angle_list'][0] != 0:
painting_dict = self.painting_collection(painting_dict, overall_print, print_trigger=True)
painting_dict['tile_print'] = self.rotate_crop_image(img=painting_dict['tile_print'], angle=-overall_print['print_angle_list'][0], crop=True)
painting_dict['mask_inv_print'] = self.rotate_crop_image(img=painting_dict['mask_inv_print'], angle=-overall_print['print_angle_list'][0], crop=True)
# resize to the sketch dimensions
painting_dict['tile_print'] = self.resize_and_crop(img=painting_dict['tile_print'], target_width=painting_dict['dim_image_w'], target_height=painting_dict['dim_image_h'])
painting_dict['mask_inv_print'] = self.resize_and_crop(img=painting_dict['mask_inv_print'], target_width=painting_dict['dim_image_w'], target_height=painting_dict['dim_image_h'])
else:
painting_dict = self.painting_collection(painting_dict, overall_print, print_trigger=True, is_single=False)
result['print_image'] = self.printpaint(result, painting_dict, print_=True)
result['single_image'] = result['final_image'] = result['pattern_image'] = result['print_image']
if single_print['print_path_list']:
print_background = np.zeros((result['pattern_image'].shape[0], result['pattern_image'].shape[1], 3), dtype=np.uint8)
mask_background = np.zeros((result['pattern_image'].shape[0], result['pattern_image'].shape[1], 3), dtype=np.uint8)
for i in range(len(single_print['print_path_list'])):
image, image_mode = self.read_image(single_print['print_path_list'][i])
if image_mode == "RGBA":
new_size = (int(image.width * single_print['print_scale_list'][i]), int(image.height * single_print['print_scale_list'][i]))
mask = image.split()[3]
resized_source = image.resize(new_size)
resized_source_mask = mask.resize(new_size)
rotated_resized_source = resized_source.rotate(-single_print['print_angle_list'][i])
rotated_resized_source_mask = resized_source_mask.rotate(-single_print['print_angle_list'][i])
source_image_pil = Image.fromarray(cv2.cvtColor(print_background, cv2.COLOR_BGR2RGB))
source_image_pil_mask = Image.fromarray(cv2.cvtColor(mask_background, cv2.COLOR_BGR2RGB))
source_image_pil.paste(rotated_resized_source, (int(single_print['location'][i][0]), int(single_print['location'][i][1])), rotated_resized_source)
source_image_pil_mask.paste(rotated_resized_source_mask, (int(single_print['location'][i][0]), int(single_print['location'][i][1])), rotated_resized_source_mask)
print_background = cv2.cvtColor(np.array(source_image_pil), cv2.COLOR_RGBA2BGR)
mask_background = cv2.cvtColor(np.array(source_image_pil_mask), cv2.COLOR_RGBA2BGR)
ret, mask_background = cv2.threshold(mask_background, 124, 255, cv2.THRESH_BINARY)
else:
mask = self.get_mask_inv(image)
mask = np.expand_dims(mask, axis=2)
mask = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR)
mask = cv2.bitwise_not(mask)
# coordinates must be recomputed after rotation
rotate_mask, _ = self.img_rotate(mask, single_print['print_angle_list'][i], single_print['print_scale_list'][i])
rotate_image, rotated_new_size = self.img_rotate(image, single_print['print_angle_list'][i], single_print['print_scale_list'][i])
# x, y = int(result['print']['location'][i][0] - rotated_new_size[0] - (rotate_mask.shape[0] - image.shape[0]) / 2), int(result['print']['location'][i][1] - rotated_new_size[1] - (rotate_mask.shape[1] - image.shape[1]) / 2)
x, y = int(single_print['location'][i][0] - rotated_new_size[0]), int(single_print['location'][i][1] - rotated_new_size[1])
image_x = print_background.shape[1]
image_y = print_background.shape[0]
print_x = rotate_image.shape[1]
print_y = rotate_image.shape[0]
# buggy earlier version, kept for reference:
# if x + print_x > image_x:
# rotate_image = rotate_image[:, :x + print_x - image_x]
# rotate_mask = rotate_mask[:, :x + print_x - image_x]
# #
# if y + print_y > image_y:
# rotate_image = rotate_image[:y + print_y - image_y]
# rotate_mask = rotate_mask[:y + print_y - image_y]
# these checks cannot run independently:
# in the old version, the first pair of ifs (lines 108 and 115) checked the bottom/right bounds and the second pair checked the top/left; cropping the right edge first and then shifting the region left corrupted the result
# so: shift first, then bounds-check, and crop last
# if the print was rotated or sits flush against an edge, check whether the left/top bounds fall below 0
if x <= 0:
rotate_image = rotate_image[:, -x:]
rotate_mask = rotate_mask[:, -x:]
start_x = x = 0
else:
start_x = x
if y <= 0:
rotate_image = rotate_image[-y:, :]
rotate_mask = rotate_mask[-y:, :]
start_y = y = 0
else:
start_y = y
# ------------------
# if the print is larger than the image, crop the print
if x + print_x > image_x:
rotate_image = rotate_image[:, :image_x - x]
rotate_mask = rotate_mask[:, :image_x - x]
if y + print_y > image_y:
rotate_image = rotate_image[:image_y - y, :]
rotate_mask = rotate_mask[:image_y - y, :]
# mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]] = cv2.bitwise_xor(mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]], rotate_mask)
# print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]] = cv2.add(print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]], rotate_image)
# mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]] = rotate_mask
# print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]] = rotate_image
mask_background = self.stack_prin(mask_background, result['pattern_image'], rotate_mask, start_y, y, start_x, x)
print_background = self.stack_prin(print_background, result['pattern_image'], rotate_image, start_y, y, start_x, x)
# gray_image = cv2.cvtColor(mask_background, cv2.COLOR_BGR2GRAY)
# print_background = cv2.bitwise_and(print_background, print_background, mask=gray_image)
print_mask = cv2.bitwise_and(result['mask'], cv2.cvtColor(mask_background, cv2.COLOR_BGR2GRAY))
img_fg = cv2.bitwise_or(print_background, print_background, mask=print_mask)
img_bg = cv2.bitwise_and(result['pattern_image'], result['pattern_image'], mask=cv2.bitwise_not(print_mask))
mask_mo = np.expand_dims(print_mask, axis=2).repeat(3, axis=2)
gray_mo = np.expand_dims(result['gray'], axis=2).repeat(3, axis=2)
img_fg = (img_fg * (mask_mo / 255) * (gray_mo / 255)).astype(np.uint8)
result['final_image'] = cv2.add(img_bg, img_fg)
canvas = np.full_like(result['final_image'], 255)
temp_bg = np.expand_dims(cv2.bitwise_not(result['mask']), axis=2).repeat(3, axis=2)
tmp1 = (canvas * (temp_bg / 255)).astype(np.uint8)
temp_fg = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2)
tmp2 = (result['final_image'] * (temp_fg / 255)).astype(np.uint8)
result['single_image'] = cv2.add(tmp1, tmp2)
if element_print['element_path_list']:
print_background = np.zeros((result['final_image'].shape[0], result['final_image'].shape[1], 3), dtype=np.uint8)
mask_background = np.zeros((result['final_image'].shape[0], result['final_image'].shape[1], 3), dtype=np.uint8)
for i in range(len(element_print['element_path_list'])):
image, image_mode = self.read_image(element_print['element_path_list'][i])
if image_mode == "RGBA":
new_size = (int(image.width * element_print['element_scale_list'][i]), int(image.height * element_print['element_scale_list'][i]))
mask = image.split()[3]
resized_source = image.resize(new_size)
resized_source_mask = mask.resize(new_size)
rotated_resized_source = resized_source.rotate(-element_print['element_angle_list'][i])
rotated_resized_source_mask = resized_source_mask.rotate(-element_print['element_angle_list'][i])
source_image_pil = Image.fromarray(cv2.cvtColor(print_background, cv2.COLOR_BGR2RGB))
source_image_pil_mask = Image.fromarray(cv2.cvtColor(mask_background, cv2.COLOR_BGR2RGB))
source_image_pil.paste(rotated_resized_source, (int(element_print['location'][i][0]), int(element_print['location'][i][1])), rotated_resized_source)
source_image_pil_mask.paste(rotated_resized_source_mask, (int(element_print['location'][i][0]), int(element_print['location'][i][1])), rotated_resized_source_mask)
print_background = cv2.cvtColor(np.array(source_image_pil), cv2.COLOR_RGBA2BGR)
mask_background = cv2.cvtColor(np.array(source_image_pil_mask), cv2.COLOR_RGBA2BGR)
else:
mask = self.get_mask_inv(image)
mask = np.expand_dims(mask, axis=2)
mask = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR)
mask = cv2.bitwise_not(mask)
# Recompute the paste coordinates after rotation
rotate_mask, _ = self.img_rotate(mask, element_print['element_angle_list'][i], element_print['element_scale_list'][i])
rotate_image, rotated_new_size = self.img_rotate(image, element_print['element_angle_list'][i], element_print['element_scale_list'][i])
# x, y = int(result['print']['location'][i][0] - rotated_new_size[0] - (rotate_mask.shape[0] - image.shape[0]) / 2), int(result['print']['location'][i][1] - rotated_new_size[1] - (rotate_mask.shape[1] - image.shape[1]) / 2)
x, y = int(element_print['location'][i][0] - rotated_new_size[0]), int(element_print['location'][i][1] - rotated_new_size[1])
image_x = print_background.shape[1]
image_y = print_background.shape[0]
print_x = rotate_image.shape[1]
print_y = rotate_image.shape[0]
# Known-buggy legacy clipping (kept for reference):
# if x + print_x > image_x:
#     rotate_image = rotate_image[:, :x + print_x - image_x]
#     rotate_mask = rotate_mask[:, :x + print_x - image_x]
# if y + print_y > image_y:
#     rotate_image = rotate_image[:y + print_y - image_y]
#     rotate_mask = rotate_mask[:y + print_y - image_y]
# The two boundary checks cannot run independently: the first-round ifs (source lines 108 and 115) check the bottom/right bounds, the second round checks the top/left; clipping the right edge first and then shifting the region left corrupts it.
# Order: shift first, then check bounds, then clip.
# If the print is rotated, or the print sits on an edge, check whether the left and top bounds fall below 0
if x <= 0:
rotate_image = rotate_image[:, -x:]
rotate_mask = rotate_mask[:, -x:]
start_x = x = 0
else:
start_x = x
if y <= 0:
rotate_image = rotate_image[-y:, :]
rotate_mask = rotate_mask[-y:, :]
start_y = y = 0
else:
start_y = y
# ------------------
# If the print extends past the image, clip the print
if x + print_x > image_x:
rotate_image = rotate_image[:, :image_x - x]
rotate_mask = rotate_mask[:, :image_x - x]
if y + print_y > image_y:
rotate_image = rotate_image[:image_y - y, :]
rotate_mask = rotate_mask[:image_y - y, :]
# mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]] = cv2.bitwise_xor(mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]], rotate_mask)
# print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]] = cv2.add(print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]], rotate_image)
# mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]] = rotate_mask
# print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]] = rotate_image
mask_background = self.stack_prin(mask_background, result['pattern_image'], rotate_mask, start_y, y, start_x, x)
print_background = self.stack_prin(print_background, result['pattern_image'], rotate_image, start_y, y, start_x, x)
# gray_image = cv2.cvtColor(mask_background, cv2.COLOR_BGR2GRAY)
# print_background = cv2.bitwise_and(print_background, print_background, mask=gray_image)
print_mask = cv2.bitwise_and(result['mask'], cv2.cvtColor(mask_background, cv2.COLOR_BGR2GRAY))
img_fg = cv2.bitwise_or(print_background, print_background, mask=print_mask)
# TODO: element compositing loses information here
three_channel_image = cv2.merge([cv2.bitwise_not(print_mask), cv2.bitwise_not(print_mask), cv2.bitwise_not(print_mask)])
img_bg = cv2.bitwise_and(result['final_image'], three_channel_image)
# mask_mo = np.expand_dims(print_mask, axis=2).repeat(3, axis=2)
# gray_mo = np.expand_dims(result['gray'], axis=2).repeat(3, axis=2)
# img_fg = (img_fg * (mask_mo / 255) * (gray_mo / 255)).astype(np.uint8)
result['final_image'] = cv2.add(img_bg, img_fg)
canvas = np.full_like(result['final_image'], 255)
temp_bg = np.expand_dims(cv2.bitwise_not(result['mask']), axis=2).repeat(3, axis=2)
tmp1 = (canvas * (temp_bg / 255)).astype(np.uint8)
temp_fg = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2)
tmp2 = (result['final_image'] * (temp_fg / 255)).astype(np.uint8)
result['single_image'] = cv2.add(tmp1, tmp2)
return result
@staticmethod
def stack_prin(print_background, pattern_image, rotate_image, start_y, y, start_x, x):
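# Composites a newly placed print onto the accumulated print layer: pixels that
# print_background already covers are kept (mask_), and the new rotate_image
# only fills the still-black empty regions (mask_inv), so earlier prints are
# not overwritten by later ones.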
temp_print = np.zeros((pattern_image.shape[0], pattern_image.shape[1], 3), dtype=np.uint8)
temp_print[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]] = rotate_image
img2gray = cv2.cvtColor(print_background, cv2.COLOR_BGR2GRAY)
ret, mask_ = cv2.threshold(img2gray, 1, 255, cv2.THRESH_BINARY)
mask_inv = cv2.bitwise_not(mask_)
img1_bg = cv2.bitwise_and(print_background, print_background, mask=mask_)
img2_fg = cv2.bitwise_and(temp_print, temp_print, mask=mask_inv)
print_background = img1_bg + img2_fg
return print_background
def painting_collection(self, painting_dict, print_dict, print_trigger=False, is_single=False):
if print_trigger:
print_ = self.get_print(print_dict)
painting_dict['Trigger'] = not is_single
painting_dict['location'] = print_['location']
single_mask_inv_print = self.get_mask_inv(print_['image'])
dim_max = max(painting_dict['dim_image_h'], painting_dict['dim_image_w'])
dim_pattern = (int(dim_max * print_['scale'] / 5), int(dim_max * print_['scale'] / 5))
if not is_single:
self.random_seed = random.randint(0, 1000)
# If the print mode is overall and an angle is set, tile the print as a square to make cropping easier
if "print_angle_list" in print_dict.keys() and print_dict['print_angle_list'][0] != 0:
painting_dict['mask_inv_print'] = self.tile_image(single_mask_inv_print, dim_pattern, print_['scale'], dim_max, dim_max, painting_dict['location'], trigger=True)
painting_dict['tile_print'] = self.tile_image(print_['image'], dim_pattern, print_['scale'], dim_max, dim_max, painting_dict['location'], trigger=True)
else:
painting_dict['mask_inv_print'] = self.tile_image(single_mask_inv_print, dim_pattern, print_['scale'], painting_dict['dim_image_h'], painting_dict['dim_image_w'], painting_dict['location'], trigger=True)
painting_dict['tile_print'] = self.tile_image(print_['image'], dim_pattern, print_['scale'], painting_dict['dim_image_h'], painting_dict['dim_image_w'], painting_dict['location'], trigger=True)
else:
painting_dict['mask_inv_print'] = self.tile_image(single_mask_inv_print, dim_pattern, print_['scale'], painting_dict['dim_image_h'], painting_dict['dim_image_w'], painting_dict['location'])
painting_dict['tile_print'] = self.tile_image(print_['image'], dim_pattern, print_['scale'], painting_dict['dim_image_h'], painting_dict['dim_image_w'], painting_dict['location'])
painting_dict['dim_print_h'], painting_dict['dim_print_w'] = dim_pattern
return painting_dict
def tile_image(self, pattern, dim, scale, dim_image_h, dim_image_w, location, trigger=False):
tile = None
if not trigger:
tile = cv2.resize(pattern, dim, interpolation=cv2.INTER_AREA)
else:
resize_pattern = cv2.resize(pattern, dim, interpolation=cv2.INTER_AREA)
if len(pattern.shape) == 2:
tile = np.tile(resize_pattern, (int((5 + 1) / scale) + 4, int((5 + 1) / scale) + 4))
if len(pattern.shape) == 3:
tile = np.tile(resize_pattern, (int((5 + 1) / scale) + 4, int((5 + 1) / scale) + 4, 1))
tile = self.crop_image(tile, dim_image_h, dim_image_w, location, resize_pattern.shape)
return tile
def get_mask_inv(self, print_):
if print_[0][0][0] == 255 and print_[0][0][1] == 255 and print_[0][0][2] == 255:
bg_color = cv2.cvtColor(print_, cv2.COLOR_BGR2LAB)[0][0]
print_tile = cv2.cvtColor(print_, cv2.COLOR_BGR2LAB)
bg_l, bg_a, bg_b = bg_color[0], bg_color[1], bg_color[2]
bg_L_high, bg_L_low = self.get_low_high_lab(bg_l, L=True)
bg_a_high, bg_a_low = self.get_low_high_lab(bg_a)
bg_b_high, bg_b_low = self.get_low_high_lab(bg_b)
lower = np.array([bg_L_low, bg_a_low, bg_b_low])
upper = np.array([bg_L_high, bg_a_high, bg_b_high])
mask_inv = cv2.inRange(print_tile, lower, upper)
return mask_inv
else:
# Non-white background: nothing is keyed out
mask_inv = np.zeros(print_.shape[:2], dtype=np.uint8)
return mask_inv
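# Keying sketch: when the top-left pixel is pure white, every pixel whose LAB
# values fall within +/-30 of that background's L/a/b values is marked as
# background in mask_inv; for any other background color nothing is keyed out
# (all-zero mask).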
@staticmethod
def printpaint(result, painting_dict, print_=False):
if print_ and painting_dict['Trigger']:
print_mask = cv2.bitwise_and(result['mask'], cv2.bitwise_not(painting_dict['mask_inv_print']))
img_fg = cv2.bitwise_and(painting_dict['tile_print'], painting_dict['tile_print'], mask=print_mask)
else:
print_mask = result['mask']
img_fg = result['final_image']
if print_ and not painting_dict['Trigger']:
try:
index_ = len(painting_dict['location'])
except (KeyError, TypeError):
raise KeyError('a location parameter is required when IfSingle is chosen')
for i in range(index_):
start_h, start_w = int(painting_dict['location'][i][1]), int(painting_dict['location'][i][0])
length_h = min(start_h + painting_dict['dim_print_h'], img_fg.shape[0])
length_w = min(start_w + painting_dict['dim_print_w'], img_fg.shape[1])
change_region = img_fg[start_h: length_h, start_w: length_w, :]
change_mask = print_mask[start_h: length_h, start_w: length_w]
# Keep only the solid part of the mask; soft edges below 220 are dropped
_, change_mask = cv2.threshold(change_mask, 220, 255, cv2.THRESH_BINARY)
# NOTE: change_mask is currently unused and the region is written back unchanged
img_fg[start_h: length_h, start_w: length_w, :] = change_region
clothes_mask_print = cv2.bitwise_not(print_mask)
img_bg = cv2.bitwise_and(result['pattern_image'], result['pattern_image'], mask=clothes_mask_print)
mask_mo = np.expand_dims(print_mask, axis=2).repeat(3, axis=2)
gray_mo = np.expand_dims(result['gray'], axis=2).repeat(3, axis=2)
img_fg = (img_fg * (mask_mo / 255) * (gray_mo / 255)).astype(np.uint8)
print_image = cv2.add(img_bg, img_fg)
return print_image
@staticmethod
def get_print(print_dict):
if 'print_scale_list' not in print_dict.keys() or print_dict['print_scale_list'][0] < 0.3:
print_dict['scale'] = 0.3
else:
print_dict['scale'] = print_dict['print_scale_list'][0]
bucket_name = print_dict['print_path_list'][0].split("/", 1)[0]
object_name = print_dict['print_path_list'][0].split("/", 1)[1]
image = oss_get_image(bucket=bucket_name, object_name=object_name, data_type="PIL")
# If the image is RGBA, paste it onto a pure-white background to keep transparent areas from turning black
if image.mode == "RGBA":
new_background = Image.new('RGB', image.size, (255, 255, 255))
new_background.paste(image, mask=image.split()[3])
image = new_background
print_dict['image'] = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
return print_dict
def crop_image(self, image, image_size_h, image_size_w, location, print_shape):
print_w = print_shape[1]
print_h = print_shape[0]
random.seed(self.random_seed)
# 1. Take the paste location modulo the resized tile size to get the true tile offset
# (tiles are square here, so print_w == print_h; the axes are matched for clarity)
x_offset = print_h - int(location[0][1] % print_h)
y_offset = print_w - int(location[0][0] % print_w)
if len(image.shape) == 2:
image = image[x_offset: x_offset + image_size_h, y_offset: y_offset + image_size_w]
elif len(image.shape) == 3:
image = image[x_offset: x_offset + image_size_h, y_offset: y_offset + image_size_w, :]
return image
@staticmethod
def get_low_high_lab(Lab_value, L=False):
# The L flag is currently unused: all three LAB channels use the same +/-30 window
high = Lab_value + 30 if Lab_value + 30 < 255 else 255
low = Lab_value - 30 if Lab_value - 30 > 0 else 0
return high, low
@staticmethod
def img_rotate(image, angle, scale):
"""Rotate an image clockwise by an arbitrary angle.
Args:
image (np.array): source image
angle (float): rotation angle in degrees (the sign is negated below, so a positive value rotates clockwise)
scale (float): isotropic scale factor
Returns:
(np.array, tuple): the rotated image and the (x, y) offset between the enlarged canvas and the scaled source
"""
h, w = image.shape[:2]
center = (w // 2, h // 2)
M = cv2.getRotationMatrix2D(center, -angle, scale)
# Enlarge the output canvas so the rotated image is not clipped
rotated_h = int((w * np.abs(M[0, 1]) + (h * np.abs(M[0, 0]))))
rotated_w = int((h * np.abs(M[0, 1]) + (w * np.abs(M[0, 0]))))
M[0, 2] += (rotated_w - w) // 2
M[1, 2] += (rotated_h - h) // 2
# Rotate the image
rotated_img = cv2.warpAffine(image, M, (rotated_w, rotated_h))
return rotated_img, ((rotated_img.shape[1] - image.shape[1] * scale) // 2, (rotated_img.shape[0] - image.shape[0] * scale) // 2)
# return rotated_img, (0, 0)
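# Sanity check: rotating a 100x200 (h x w) image by angle=90 with scale=1
# yields a 200x100 canvas, i.e. width and height swap as expected.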
@staticmethod
def rotate_crop_image(img, angle, crop):
"""
angle: 旋转的角度
crop: 是否需要进行裁剪,布尔向量
"""
crop_image = lambda img, x0, y0, w, h: img[y0:y0 + h, x0:x0 + w]
w, h = img.shape[:2]
# 旋转角度的周期是360°
angle %= 360
# 计算仿射变换矩阵
M_rotation = cv2.getRotationMatrix2D((w / 2, h / 2), angle, 1)
# 得到旋转后的图像
img_rotated = cv2.warpAffine(img, M_rotation, (w, h))
# 如果需要去除黑边
if crop:
# 裁剪角度的等效周期是180°
angle_crop = angle % 180
if angle > 90:
angle_crop = 180 - angle_crop
# 转化角度为弧度
theta = angle_crop * np.pi / 180
# 计算高宽比
hw_ratio = float(h) / float(w)
# 计算裁剪边长系数的分子项
tan_theta = np.tan(theta)
numerator = np.cos(theta) + np.sin(theta) * np.tan(theta)
# 计算分母中和高宽比相关的项
r = hw_ratio if h > w else 1 / hw_ratio
# 计算分母项
denominator = r * tan_theta + 1
# 最终的边长系数
crop_mult = numerator / denominator
# 得到裁剪区域
w_crop = int(crop_mult * w)
h_crop = int(crop_mult * h)
x0 = int((w - w_crop) / 2)
y0 = int((h - h_crop) / 2)
img_rotated = crop_image(img_rotated, x0, y0, w_crop, h_crop)
return img_rotated
@staticmethod
def read_image(image_url):
image = oss_get_image(bucket=image_url.split("/", 1)[0], object_name=image_url.split("/", 1)[1], data_type="cv2")
if image.shape[2] == 4:
image_rgb = cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA)
image = Image.fromarray(image_rgb)
image_mode = "RGBA"
else:
image_mode = "RGB"
return image, image_mode
@staticmethod
def resize_and_crop(img, target_width, target_height):
# Original image size
original_height, original_width = img.shape[:2]
# Aspect ratio of the target size
target_ratio = target_width / target_height
# Aspect ratio of the original image
original_ratio = original_width / original_height
# Resize, then center-crop the overflowing dimension
if original_ratio > target_ratio:
# Original is wider: resize by height, then crop the width
new_height = target_height
new_width = int(original_width * (target_height / original_height))
resized_img = cv2.resize(img, (new_width, new_height))
# Center-crop the width
start_x = (new_width - target_width) // 2
cropped_img = resized_img[:, start_x:start_x + target_width]
else:
# Original is taller: resize by width, then crop the height
new_width = target_width
new_height = int(original_height * (target_width / original_width))
resized_img = cv2.resize(img, (new_width, new_height))
# Center-crop the height
start_y = (new_height - target_height) // 2
cropped_img = resized_img[start_y:start_y + target_height, :]
return cropped_img
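# Sanity check: resize_and_crop(img, 200, 200) on a 300x400 (h x w) image
# resizes to 200x266 (h x w), then center-crops the width with
# start_x = (266 - 200) // 2 = 33.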

View File

@@ -0,0 +1,56 @@
import math
import cv2
from ..builder import PIPELINES
@PIPELINES.register_module()
class Scaling(object):
def __init__(self):
pass
# @ RunTime
def __call__(self, result):
if result['keypoint'] in ['waistband', 'shoulder', 'head_point']:
# milvus_db_keypoint_cache
distance_clo = math.sqrt(
(int(result['clothes_keypoint'][result['keypoint'] + '_left'][0]) - int(result['clothes_keypoint'][result['keypoint'] + '_right'][0])) ** 2
+
(int(result['clothes_keypoint'][result['keypoint'] + '_left'][1]) - int(result['clothes_keypoint'][result['keypoint'] + '_right'][1])) ** 2)
distance_bdy = math.sqrt((int(result['body_point_test'][result['keypoint'] + '_left'][0]) - int(result['body_point_test'][result['keypoint'] + '_right'][0])) ** 2 + 1)  # +1 keeps the distance strictly positive
if distance_clo == 0:
result['scale'] = 1
else:
result['scale'] = distance_bdy / distance_clo
elif result['keypoint'] == 'toe':
distance_bdy = math.sqrt(
(int(result['body_point_test']['foot_length'][0]) - int(result['body_point_test']['foot_length'][2])) ** 2
+
(int(result['body_point_test']['foot_length'][1]) - int(result['body_point_test']['foot_length'][3])) ** 2
)
Blur = cv2.GaussianBlur(result['gray'], (3, 3), 0)
Edge = cv2.Canny(Blur, 10, 200)
Edge = cv2.dilate(Edge, None)
Edge = cv2.erode(Edge, None)
Contour, _ = cv2.findContours(Edge, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
Contours = sorted(Contour, key=cv2.contourArea, reverse=True)
Max_contour = Contours[0]
x, y, w, h = cv2.boundingRect(Max_contour)
distance_clo = w  # garment width from the largest contour's bounding box
result['scale'] = distance_bdy / distance_clo
elif result['keypoint'] == 'hand_point':
result['scale'] = result['scale_bag']
elif result['keypoint'] == 'ear_point':
result['scale'] = result['scale_earrings']
return result
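# Scale semantics (sketch): result['scale'] converts clothes-pixel distances
# into body-pixel distances, e.g. a 300 px garment waistband against a 150 px
# body waistband gives scale ≈ 150 / 300 = 0.5.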

View File

@@ -0,0 +1,14 @@
from ..builder import PIPELINES
from ...utils.design_ensemble import get_seg_result
@PIPELINES.register_module()
class Segmentation(object):
def __init__(self, device='cpu', show=False, debug=None):
self.show = show
self.device = device
self.debug = debug
def __call__(self, result):
result['seg_result'] = get_seg_result(result["image_id"], result['image'])
return result

View File

@@ -0,0 +1,81 @@
import logging
import cv2
import numpy as np
from PIL import Image
from cv2 import cvtColor, COLOR_BGR2RGBA
from app.service.utils.generate_uuid import generate_uuid
from ..builder import PIPELINES
from ...utils.conversion_image import rgb_to_rgba
from ...utils.upload_image import upload_png_mask
@PIPELINES.register_module()
class Split(object):
"""
Split image into front and back layer according to the segmentation result
"""
# KNet
def __call__(self, result):
try:
if 'mask' not in result.keys():
raise KeyError(f'Cannot find mask in result dict, please check ContourDetection is included in process pipelines.')
if 'seg_result' not in result.keys():  # the seg model was not run
result['front_mask'] = result['mask'].copy()
result['back_mask'] = np.zeros_like(result['mask'])
else:
temp_front = result['seg_result'] == 1
result['front_mask'] = (result['mask'] * (temp_front + 0).astype(np.uint8))
temp_back = result['seg_result'] == 2
result['back_mask'] = (result['mask'] * (temp_back + 0).astype(np.uint8))
if result['name'] in ('outwear', 'dress', 'blouse', 'skirt', 'trousers', 'tops', 'bottoms'):
if len(result['front_mask'].shape) > 2:
front_mask = result['front_mask'][0]
else:
front_mask = result['front_mask']
if len(result['back_mask'].shape) > 2:
back_mask = result['back_mask'][0]
else:
back_mask = result['back_mask']
rgba_image = rgb_to_rgba((result['final_image'].shape[0], result['final_image'].shape[1]), result['final_image'], result['mask'])
result_front_image = np.zeros_like(rgba_image)
result_front_image[front_mask != 0] = rgba_image[front_mask != 0]
result_front_image_pil = Image.fromarray(cvtColor(result_front_image, COLOR_BGR2RGBA))
front_new_size = (int(result_front_image_pil.width * result["scale"] * result["resize_scale"][0]), int(result_front_image_pil.height * result["scale"] * result["resize_scale"][1]))
result_front_image_pil = result_front_image_pil.resize(front_new_size, Image.LANCZOS)
# result['front_mask_image'] = cv2.resize(front_mask, front_new_size)
# result['front_image'] = result_front_image_pil
front_mask = cv2.resize(front_mask, front_new_size)
result['front_image'], result["front_image_url"], result["front_mask_url"] = upload_png_mask(result_front_image_pil, f'{generate_uuid()}', mask=front_mask)
if result["name"] in ('blouse', 'dress', 'outwear', 'tops'):
result_back_image = np.zeros_like(rgba_image)
result_back_image[back_mask != 0] = rgba_image[back_mask != 0]
result_back_image_pil = Image.fromarray(cvtColor(result_back_image, COLOR_BGR2RGBA))
back_new_size = (int(result_back_image_pil.width * result["scale"] * result["resize_scale"][0]), int(result_back_image_pil.height * result["scale"] * result["resize_scale"][1]))
result_back_image_pil = result_back_image_pil.resize(back_new_size, Image.LANCZOS)
# result['back_mask_image'] = cv2.resize(back_mask, back_new_size)
# result['back_image'] = result_back_image_pil
back_mask = cv2.resize(back_mask, back_new_size)
result['back_image'], result["back_image_url"], result["back_mask_url"] = upload_png_mask(result_back_image_pil, f'{generate_uuid()}', mask=back_mask)
else:
result['back_image'] = None
result["back_image_url"] = None
result["back_mask_url"] = None
result['back_mask_image'] = None
# Create the intermediate (pattern) layer
result_pattern_image_rgba = rgb_to_rgba((result['pattern_image'].shape[0], result['pattern_image'].shape[1]), result['pattern_image'], result['mask'])
result_pattern_image_pil = Image.fromarray(cvtColor(result_pattern_image_rgba, COLOR_BGR2RGBA))
_, result['pattern_image_url'], _ = upload_png_mask(result_pattern_image_pil, f'{generate_uuid()}')
return result
except Exception as e:
logging.warning(f"split runtime exception : {e} image_id : {result['image_id']}")

View File

@@ -0,0 +1,121 @@
import cv2
import numpy as np
from PIL import Image
from .builder import ITEMS
from .clothing import Clothing
from ..utils.conversion_image import rgb_to_rgba
from ..utils.upload_image import upload_png_mask
from ...utils.generate_uuid import generate_uuid
@ITEMS.register_module()
class Shoes(Clothing):
# TODO: shoe placement is slightly mismatched
def __init__(self, **kwargs):
pipeline = [
dict(type='LoadImageFromFile', path=kwargs['path'], color=kwargs['color']),
dict(type='KeypointDetection'),
dict(type='ContourDetection'),
dict(type='Painting'),
dict(type='Scaling'),
dict(type='Split'),
# dict(type='ImageShow', key=['image', 'mask', 'pattern_image']),
]
kwargs.update(pipeline=pipeline)
super(Shoes, self).__init__(**kwargs)
def organize(self, layer):
left_shoe_mask, right_shoe_mask = self.cut()
left_layer = dict(name=f'{type(self).__name__.lower()}_left',
image=self.result['shoes_left'],
image_url=self.result['left_image_url'],
mask_url=self.result['left_mask_url'],
scale=self.result['scale'],
clothes_keypoint=self.result['clothes_keypoint'],
position=self.calculate_start_point(self.result['keypoint'],
self.result['scale'],
self.result['clothes_keypoint'],
self.result['body_point'],
'left'))
layer.insert(left_layer)
right_layer = dict(name=f'{type(self).__name__.lower()}_right',
image=self.result['shoes_right'],
image_url=self.result['right_image_url'],
mask_url=self.result['right_mask_url'],
scale=self.result['scale'],
clothes_keypoint=self.result['clothes_keypoint'],
position=self.calculate_start_point(self.result['keypoint'],
self.result['scale'],
self.result['clothes_keypoint'],
self.result['body_point'],
'right'))
layer.insert(right_layer)
def cut(self):
"""
Cut shoes mask into two pieces
Returns:
"""
contour, _ = cv2.findContours(self.result['mask'], cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
contours = sorted(contour, key=cv2.contourArea, reverse=True)
bounding_boxes = [cv2.boundingRect(c) for c in contours[:2]]
(contours, bounding_boxes) = zip(*sorted(zip(contours[:2], bounding_boxes), key=lambda x: x[1][0], reverse=False))
epsilon_left = 0.001 * cv2.arcLength(contours[0], True)
approx_left = cv2.approxPolyDP(contours[0], epsilon_left, True)
mask_left = np.zeros(self.result['final_image'].shape[:2], np.uint8)
cv2.drawContours(mask_left, [approx_left], -1, 255, -1)
item_mask_left = cv2.GaussianBlur(mask_left, (5, 5), 0)
rgba_image = rgb_to_rgba((self.result['final_image'].shape[0], self.result['final_image'].shape[1]), self.result['final_image'], item_mask_left)
result_image = np.zeros_like(rgba_image)
result_image[self.result['front_mask'] != 0] = rgba_image[self.result['front_mask'] != 0]
result_left_image_pil = Image.fromarray(result_image, 'RGBA')
result_left_image_pil = result_left_image_pil.resize((int(result_left_image_pil.width * self.result["scale"]), int(result_left_image_pil.height * self.result["scale"])), Image.LANCZOS)
self.result['shoes_left'], self.result["left_image_url"], self.result["left_mask_url"] = upload_png_mask(result_left_image_pil, f"{generate_uuid()}")
epsilon_right = 0.001 * cv2.arcLength(contours[1], True)
approx_right = cv2.approxPolyDP(contours[1], epsilon_right, True)
mask_right = np.zeros(self.result['final_image'].shape[:2], np.uint8)
cv2.drawContours(mask_right, [approx_right], -1, 255, -1)
item_mask_right = cv2.GaussianBlur(mask_right, (5, 5), 0)
rgba_image = rgb_to_rgba((self.result['final_image'].shape[0], self.result['final_image'].shape[1]), self.result['final_image'], item_mask_right)
result_image = np.zeros_like(rgba_image)
result_image[self.result['front_mask'] != 0] = rgba_image[self.result['front_mask'] != 0]
result_right_image_pil = Image.fromarray(result_image, 'RGBA')
result_right_image_pil = result_right_image_pil.resize((int(result_right_image_pil.width * self.result["scale"]), int(result_right_image_pil.height * self.result["scale"])), Image.LANCZOS)
self.result['shoes_right'], self.result["right_image_url"], self.result["right_mask_url"] = upload_png_mask(result_right_image_pil, f"{generate_uuid()}")
return item_mask_left, item_mask_right
@staticmethod
def calculate_start_point(keypoint_type, scale, clothes_point, body_point, location):
"""
left shoes align left
right shoes align right
Args:
keypoint_type: string, "toe"
scale: float
clothes_point: dict{'left': [x1, y1, z1], 'right': [x2, y2, z2]}
body_point: dict, containing keypoint data of body figure
location: string, indicates whether the start point belongs to right or left shoe
Returns:
start_point: tuple (x', y')
x' = y_body - y1 * scale
y' = x_body - x1 * scale
"""
if location not in ['left', 'right']:
raise KeyError(f'location value must be left or right but got {location}')
side_indicator = f'{keypoint_type}_{location}'
# clothes_point = {k: tuple(map(lambda x: int(scale * x), v[0: 2])) for k, v in clothes_point.items()}
start_point = (body_point[side_indicator][1] - int(int(clothes_point[side_indicator].split("_")[1]) * scale),
body_point[side_indicator][0] - int(int(clothes_point[side_indicator].split("_")[0]) * scale))
return start_point

View File

@@ -0,0 +1,46 @@
from .builder import ITEMS
from .clothing import Clothing
@ITEMS.register_module()
class Top(Clothing):
def __init__(self, pipeline, **kwargs):
if pipeline is None:
pipeline = [
dict(type='LoadImageFromFile', path=kwargs['path'], color=kwargs['color'], print_dict=kwargs['print']),
dict(type='KeypointDetection'),
dict(type='ContourDetection'),
dict(type='Segmentation', device='cpu', show=False, debug=kwargs['debug']),
dict(type='Painting', painting_flag=True),
dict(type='PrintPainting', print_flag=True),
# dict(type='ImageShow', key=['image', 'mask', 'seg_visualize', 'pattern_image']),
dict(type='Scaling'),
dict(type='Split'),
]
kwargs.update(pipeline=pipeline)
super(Top, self).__init__(**kwargs)
@ITEMS.register_module()
class Blouse(Top):
def __init__(self, pipeline=None, **kwargs):
super(Blouse, self).__init__(pipeline, **kwargs)
@ITEMS.register_module()
class Outwear(Top):
def __init__(self, pipeline=None, **kwargs):
super(Outwear, self).__init__(pipeline, **kwargs)
@ITEMS.register_module()
class Dress(Top):
def __init__(self, pipeline=None, **kwargs):
super(Dress, self).__init__(pipeline, **kwargs)
# Men's clothing
@ITEMS.register_module()
class Tops(Top):
def __init__(self, pipeline=None, **kwargs):
super(Tops, self).__init__(pipeline, **kwargs)

View File

@@ -0,0 +1,28 @@
import io
from app.service.utils.oss_client import oss_get_image, oss_upload_image
def model_transpose(image_path):
bucket = image_path.split("/", 1)[0]
object_name = image_path.split("/", 1)[1]
new_object_name = f'{object_name[:object_name.rfind(".")]}.png'
image = oss_get_image(bucket=bucket, object_name=object_name, data_type="PIL")
image = image.convert("RGBA")
data = image.getdata()
# Make pure-white pixels (255, 255, 255) fully transparent
new_data = []
for item in data:
if item[0] >= 255 and item[1] >= 255 and item[2] >= 255:
new_data.append((255, 255, 255, 0))
else:
new_data.append(item)
image.putdata(new_data)
image_data = io.BytesIO()
image.save(image_data, format='PNG')
image_data.seek(0)
image_bytes = image_data.read()
oss_upload_image(bucket=bucket, object_name=new_object_name, image_bytes=image_bytes)
image_path = f"{bucket}/{new_object_name}"
return image_path
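# Usage sketch (hypothetical path): uploads a PNG copy with a transparent
# background next to the original object, e.g.
# model_transpose("aida-users/89/model.jpg") -> "aida-users/89/model.png"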

View File

@@ -0,0 +1,134 @@
import concurrent.futures
from app.core.config import PRIORITY_DICT
from app.service.design.core.layer import Layer
from app.service.design.items import build_item
from app.service.design.utils.redis_utils import Redis
from app.service.design.utils.synthesis_item import synthesis, synthesis_single
from app.service.utils.decorator import RunTime
def process_item(item, layers):
# logging.info("process running.........")
item.process()
item.organize(layers)
if item.result['name'] == "mannequin":
return item.result['body_image'].size
def update_progress(process_id, total):
r = Redis()
progress = r.read(key=process_id)
if progress and total != 1:
if int(progress) <= 100:
r.write(key=process_id, value=int(progress) + int(100 / total))
else:
r.write(key=process_id, value=100)
return progress
elif total == 1:
r.write(key=process_id, value=100)
return progress
else:
r.write(key=process_id, value=int(100 / total))
return progress
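# Progress semantics (sketch): each finished object advances the Redis value by
# int(100 / total) (e.g. total=4 -> +25 per object); final_progress() then
# forces the value to 100 once every future has completed.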
def final_progress(process_id):
r = Redis()
progress = r.read(key=process_id)
r.write(key=process_id, value=100)
return progress
@RunTime
def generate(request_data):
return_response = {}
request_data = request_data.dict()
assert "process_id" in request_data.keys(), "Need process_id parameters"
objects = request_data['objects']
# insert_keypoint_cache(objects)
process_id = request_data['process_id']
with concurrent.futures.ThreadPoolExecutor() as executor:
# Submit a processing task for each object
futures = {executor.submit(process_object, cfg, process_id, len(objects)): obj for obj, cfg in enumerate(objects)}
# Collect the results
for future in concurrent.futures.as_completed(futures):
obj = futures[future]
result = future.result()
return_response[obj] = result
final_progress(process_id)
return return_response
def process_object(cfg, process_id, total):
basic_info = cfg.get('basic')
items_response = {
'layers': []
}
if cfg.get('basic')['single_overall'] == 'overall':
basic_info['debug'] = False
items = [build_item(x, default_args=basic_info) for x in cfg.get('items')]
layers = Layer()
body_size = None
futures = []
for item in items:
futures.append(process_item(item, layers))
for future in futures:
if future is not None:
body_size = future
# Use a custom layer order if provided
if basic_info.get('layer_order', False):
layers = sorted(layers.layer, key=lambda s: s.get("priority", float('inf')))
else:
layers = sorted(layers.layer, key=lambda x: PRIORITY_DICT.get(x['name'], float('inf')))
# Composite the layers
items_response['synthesis_url'] = synthesis(layers, body_size)
for lay in layers:
items_response['layers'].append({
'image_category': lay['name'],
'position': lay['position'],
'priority': lay.get("priority", None),
'resize_scale': lay['resize_scale'] if "resize_scale" in lay.keys() else None,
'image_size': lay['image'] if lay['image'] is None else lay['image'].size,
'gradient_string': lay['gradient_string'] if 'gradient_string' in lay.keys() else "",
'mask_url': lay['mask_url'],
'image_url': lay['image_url'] if 'image_url' in lay.keys() else None,
'pattern_image_url': lay['pattern_image_url'] if 'pattern_image_url' in lay.keys() else None,
# 'image': lay['image'],
# 'mask_image': lay['mask_image'],
})
elif cfg.get('basic')['single_overall'] == 'single':
assert cfg.get('basic')['switch_category'] in [x['type'] for x in cfg.get('items')], "Lack of switch_category parameters "
basic_info['debug'] = False
for item in cfg.get('items'):
if item['type'] == cfg.get('basic')['switch_category']:
item = build_item(item, default_args=cfg.get('basic'))
item.process()
items_response['layers'].append({
'image_category': f"{item.result['name']}_front",
'image_size': item.result['front_image'].size if item.result['front_image'] else None,
'position': None,
'priority': 0,
'image_url': item.result['front_image_url'],
'mask_url': item.result['front_mask_url'],
"gradient_string": item.result['gradient_string'] if 'gradient_string' in item.result.keys() else ""
})
items_response['layers'].append({
'image_category': f"{item.result['name']}_back",
'image_size': item.result['back_image'].size if item.result['back_image'] else None,
'position': None,
'priority': 0,
'image_url': item.result['back_image_url'],
'mask_url': item.result['back_mask_url'],
"gradient_string": item.result['gradient_string'] if 'gradient_string' in item.result.keys() else ""
})
items_response['synthesis_url'] = synthesis_single(item.result['front_image'], item.result['back_image'])
break
update_progress(process_id, total)
return items_response

View File

View File

@@ -0,0 +1,24 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project trinity_client
@File conversion_image.py
@Author :周成融
@Date 2023/8/21 10:40:29
@detail
"""
import numpy as np
def rgb_to_rgba(rgb_size, rgb_image, mask):
alpha_channel = np.full(rgb_size, 255, dtype=np.uint8)
# Build the 4-channel result image
rgba_image = np.dstack((rgb_image, alpha_channel))
alpha_channel = np.where(mask > 0, 255, 0)
# Update the alpha channel of the RGBA image
rgba_image[:, :, 3] = alpha_channel
return rgba_image
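# Sanity check: for an (h, w) binary mask, the returned array has shape
# (h, w, 4); alpha is 255 wherever mask > 0 and 0 elsewhere.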

View File

@@ -0,0 +1,140 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project trinity_client
@File design_ensemble.py
@Author :周成融
@Date 2023/8/16 19:36:21
@detail : send inference requests and fetch the results
"""
import logging
import cv2
import mmcv
import numpy as np
import torch
import torch.nn.functional as F
import tritonclient.http as httpclient
from app.core.config import *
"""
keypoint
预处理 推理 后处理
"""
def keypoint_preprocess(img_path):
img = mmcv.imread(img_path)
img_scale = (256, 256)
img, w_scale, h_scale = mmcv.imresize(img, img_scale, return_scale=True)
img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0)
return preprocessed_img, (w_scale, h_scale)
# @ RunTime
# Inference
def get_keypoint_result(image, site):
keypoint_result = None
try:
image, scale_factor = keypoint_preprocess(image)
client = httpclient.InferenceServerClient(url=DESIGN_MODEL_URL)
transformed_img = image.astype(np.float32)
inputs = [httpclient.InferInput(f"input", transformed_img.shape, datatype="FP32")]
inputs[0].set_data_from_numpy(transformed_img, binary_data=True)
outputs = [httpclient.InferRequestedOutput(f"output", binary_data=True)]
results = client.infer(model_name=f"keypoint_{site}_ocrnet_hr18", inputs=inputs, outputs=outputs)
inference_output = torch.from_numpy(results.as_numpy(f'output'))
keypoint_result = keypoint_postprocess(inference_output, scale_factor)
except Exception as e:
logging.warning(f"get_keypoint_result : {e}")
return keypoint_result
def keypoint_postprocess(output, scale_factor):
max_indices = torch.argmax(output.view(output.size(0), output.size(1), -1), dim=2).unsqueeze(dim=2)
max_coords = torch.cat((max_indices / output.size(3), max_indices % output.size(3)), dim=2)
segment_result = max_coords.numpy()
scale_factor = [1 / x for x in scale_factor[::-1]]
scale_matrix = np.diag(scale_factor)
nan = np.isinf(scale_matrix)
scale_matrix[nan] = 0
return np.ceil(np.dot(segment_result, scale_matrix) * 4)
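# Decoding sketch: argmax over each flattened heatmap yields a linear index;
# division / modulo by the heatmap width recovers (row, col), which is then
# mapped back to the input image via the inverse resize factors and the
# heatmap stride (x4).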
"""
seg
预处理 推理 后处理
"""
# KNet
def seg_preprocess(img_path):
img = mmcv.imread(img_path)
ori_shape = img.shape[:2]
# img.shape[:2] is (height, width)
img_scale_h, img_scale_w = ori_shape
if ori_shape[0] > 1024:
img_scale_h = 1024
if ori_shape[1] > 1024:
img_scale_w = 1024
scale_factor = []
img, x, y = mmcv.imresize(img, (img_scale_w, img_scale_h), return_scale=True)
scale_factor.append(x)
scale_factor.append(y)
img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0)
return preprocessed_img, ori_shape
# @ RunTime
def get_seg_result(image_id, image):
image, ori_shape = seg_preprocess(image)
client = httpclient.InferenceServerClient(url=f"{DESIGN_MODEL_URL}")
transformed_img = image.astype(np.float32)
# Input set
inputs = [
httpclient.InferInput(SEGMENTATION['input'], transformed_img.shape, datatype="FP32")
]
inputs[0].set_data_from_numpy(transformed_img, binary_data=True)
# Output set
outputs = [
httpclient.InferRequestedOutput(SEGMENTATION['output'], binary_data=True),
]
results = client.infer(model_name=SEGMENTATION['new_model_name'], inputs=inputs, outputs=outputs)
# Fetch the inference output
inference_output1 = results.as_numpy(SEGMENTATION['output'])
seg_result = seg_postprocess(int(image_id), inference_output1, ori_shape)
return seg_result
# no cache
def seg_postprocess(image_id, output, ori_shape):
seg_logit = F.interpolate(torch.tensor(output).float(), size=ori_shape, scale_factor=None, mode='bilinear', align_corners=False)
seg_pred = seg_logit.cpu().numpy()
return seg_pred[0]
def key_point_show(image_path, key_point_result=None):
img = cv2.imread(image_path)
points_list = key_point_result
point_size = 1
point_color = (0, 0, 255) # BGR
thickness = 4  # can be 0, 4, or 8
for point in points_list:
cv2.circle(img, point[::-1], point_size, point_color, thickness)
cv2.imshow("0", img)
cv2.waitKey(0)
if __name__ == '__main__':
image = cv2.imread("./14162b58-f259-4833-98cb-89b9b496b251.jfif")
a = get_keypoint_result(image, "up")
new_list = []
for i in a[0]:
new_list.append((int(i[0]), int(i[1])))
print(new_list)
key_point_show("./14162b58-f259-4833-98cb-89b9b496b251.jfif", new_list)
# a = get_seg_result(1, image)
print(a)

View File

@@ -0,0 +1,99 @@
import redis
from app.core.config import REDIS_HOST, REDIS_PORT
class Redis(object):
"""
Redis database operations
"""
@staticmethod
def _get_r():
host = REDIS_HOST
port = REDIS_PORT
db = 0
r = redis.StrictRedis(host, port, db)
return r
@classmethod
def write(cls, key, value, expire=None):
"""
Write a key-value pair
"""
# Use the given expiry; default to 100 seconds otherwise
if expire:
expire_in_seconds = expire
else:
expire_in_seconds = 100
r = cls._get_r()
r.set(key, value, ex=expire_in_seconds)
@classmethod
def read(cls, key):
"""
Read the value of a key
"""
r = cls._get_r()
value = r.get(key)
return value.decode('utf-8') if value else value
@classmethod
def hset(cls, name, key, value):
"""
Write a field into a hash
"""
r = cls._get_r()
r.hset(name, key, value)
@classmethod
def hget(cls, name, key):
"""
Read a field from the given hash
"""
r = cls._get_r()
value = r.hget(name, key)
return value.decode('utf-8') if value else value
@classmethod
def hgetall(cls, name):
"""
Get all fields and values of the given hash
"""
r = cls._get_r()
return r.hgetall(name)
@classmethod
def delete(cls, *names):
"""
Delete one or more keys
"""
r = cls._get_r()
r.delete(*names)
@classmethod
def hdel(cls, name, key):
"""
Delete a field from the given hash
"""
r = cls._get_r()
r.hdel(name, key)
@classmethod
def expire(cls, name, expire=None):
"""
Set an expiry time
"""
if expire:
expire_in_seconds = expire
else:
expire_in_seconds = 100
r = cls._get_r()
r.expire(name, expire_in_seconds)
if __name__ == '__main__':
redis_client = Redis()
# print(redis_client.write(key="1230", value=0))
redis_client.write(key="1230", value=10)
# print(redis_client.read(key="1230"))

View File

@@ -0,0 +1,167 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project trinity_client
@File synthesis_item.py
@Author :周成融
@Date 2023/8/26 14:13:04
@detail
"""
import io
import logging
import cv2
import numpy as np
from PIL import Image
from app.service.utils.generate_uuid import generate_uuid
from app.service.utils.oss_client import oss_upload_image
def positioning(all_mask_shape, mask_shape, offset):
all_start = 0
all_end = 0
mask_start = 0
mask_end = 0
if offset == 0:
all_start = 0
all_end = min(all_mask_shape, mask_shape)
mask_start = 0
mask_end = min(all_mask_shape, mask_shape)
elif offset > 0:
all_start = min(offset, all_mask_shape)
all_end = min(offset + mask_shape, all_mask_shape)
mask_start = 0
mask_end = 0 if offset > all_mask_shape else min(all_mask_shape - offset, mask_shape)
elif offset < 0:
if abs(offset) > mask_shape:
all_start = 0
all_end = 0
else:
all_start = 0
if mask_shape - abs(offset) > all_mask_shape:
all_end = min(mask_shape - abs(offset), all_mask_shape)
else:
all_end = mask_shape - abs(offset)
if abs(offset) > mask_shape:
mask_start = mask_shape
mask_end = mask_shape
else:
mask_start = abs(offset)
if mask_shape - abs(offset) >= all_mask_shape:
mask_end = all_mask_shape + abs(offset)
else:
mask_end = mask_shape
return all_start, all_end, mask_start, mask_end
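# Worked examples, returning (all_start, all_end, mask_start, mask_end):
# positioning(all_mask_shape=100, mask_shape=50, offset=70)  -> (70, 100, 0, 30)
# positioning(all_mask_shape=100, mask_shape=50, offset=-10) -> (0, 40, 10, 50)
# In both cases the destination span and the source span have equal length.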
# @RunTime
def synthesis(data, size):
# Create the base canvas
base_image = Image.new('RGBA', size, (0, 0, 0, 0))
try:
all_mask_shape = (size[1], size[0])
top_outer_mask = np.zeros(all_mask_shape, dtype=np.uint8)
bottom_outer_mask = np.zeros(all_mask_shape, dtype=np.uint8)
top = True
bottom = True
i = len(data)
while i:
i -= 1
if top and data[i]['name'] in ["blouse_front", "outwear_front", "dress_front", "tops_front"]:
top = False
mask_shape = data[i]['mask'].shape
y_offset, x_offset = data[i]['position']
# Compute the start/end of the overlay region
all_y_start, all_y_end, mask_y_start, mask_y_end = positioning(all_mask_shape=all_mask_shape[0], mask_shape=mask_shape[0], offset=y_offset)
all_x_start, all_x_end, mask_x_start, mask_x_end = positioning(all_mask_shape=all_mask_shape[1], mask_shape=mask_shape[1], offset=x_offset)
# Copy the mask pixels into the overlay region
top_outer_mask[all_y_start:all_y_end, all_x_start:all_x_end] = data[i]['mask'][mask_y_start:mask_y_end, mask_x_start:mask_x_end]
elif bottom and data[i]['name'] in ["trousers_front", "skirt_front", "bottoms_front"]:
bottom = False
mask_shape = data[i]['mask'].shape
y_offset, x_offset = data[i]['position']
# Compute the start/end of the overlay region
all_y_start, all_y_end, mask_y_start, mask_y_end = positioning(all_mask_shape=all_mask_shape[0], mask_shape=mask_shape[0], offset=y_offset)
all_x_start, all_x_end, mask_x_start, mask_x_end = positioning(all_mask_shape=all_mask_shape[1], mask_shape=mask_shape[1], offset=x_offset)
# Copy the mask pixels into the overlay region
bottom_outer_mask[all_y_start:all_y_end, all_x_start:all_x_end] = data[i]['mask'][mask_y_start:mask_y_end, mask_x_start:mask_x_end]
elif bottom is False and top is False:
break
all_mask = cv2.bitwise_or(top_outer_mask, bottom_outer_mask)
for layer in data:
if layer['image'] is not None:
if layer['name'] != "body":
test_image = Image.new('RGBA', size, (0, 0, 0, 0))
test_image.paste(layer['image'], (layer['position'][1], layer['position'][0]), layer['image'])
mask_data = np.where(all_mask > 0, 255, 0).astype(np.uint8)
mask_alpha = Image.fromarray(mask_data)
cropped_image = Image.composite(test_image, Image.new("RGBA", test_image.size, (255, 255, 255, 0)), mask_alpha)
base_image.paste(cropped_image, (0, 0), cropped_image)
else:
base_image.paste(layer['image'], (layer['position'][1], layer['position'][0]), layer['image'])
result_image = base_image
image_data = io.BytesIO()
result_image.save(image_data, format='PNG')
image_data.seek(0)
# oss upload
image_bytes = image_data.read()
bucket_name = 'aida-results'
object_name = f'result_{generate_uuid()}.png'
req = oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes)
return f"{bucket_name}/{object_name}"
# return f"aida-results/{minio_client.put_object('aida-results', f'result_{generate_uuid()}.png', io.BytesIO(image_bytes), len(image_bytes), content_type='image/png').object_name}"
# object_name = f'result_{generate_uuid()}.png'
# response = s3.put_object(Bucket="aida-results", Key=object_name, Body=data, ContentType='image/png')
# object_url = f"aida-results/{object_name}"
# if response['ResponseMetadata']['HTTPStatusCode'] == 200:
# return object_url
# else:
# return ""
except Exception as e:
logging.warning(f"synthesis runtime exception : {e}")
def synthesis_single(front_image, back_image):
result_image = None
if front_image:
result_image = front_image
if back_image:
result_image.paste(back_image, (0, 0), back_image)
image_data = io.BytesIO()
result_image.save(image_data, format='PNG')
image_data.seek(0)
image_bytes = image_data.read()
# return f"aida-results/{minio_client.put_object('aida-results', f'result_{generate_uuid()}.png', io.BytesIO(image_bytes), len(image_bytes), content_type='image/png').object_name}"
# oss upload
bucket_name = 'aida-results'
object_name = f'result_{generate_uuid()}.png'
req = oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes)
return f"{bucket_name}/{object_name}"

View File

@@ -0,0 +1,45 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project trinity_client
@File upload_image.py
@Author :周成融
@Date 2023/8/28 13:49:20
@detail
"""
import io
import logging
import cv2
from app.core.config import *
from app.service.utils.oss_client import oss_upload_image
# @RunTime
def upload_png_mask(front_image, object_name, mask=None):
try:
mask_url = None
if mask is not None:
mask_inverted = cv2.bitwise_not(mask)
# Convert the 3-channel mask to 4 channels: white areas opaque, black areas transparent
rgba_image = cv2.cvtColor(mask_inverted, cv2.COLOR_BGR2BGRA)
rgba_image[rgba_image[:, :, 0] == 0] = [0, 0, 0, 0]
# oss upload
req = oss_upload_image(bucket=AIDA_CLOTHING, object_name=f"mask/mask_{object_name}.png", image_bytes=cv2.imencode('.png', rgba_image)[1])
mask_url = f"{AIDA_CLOTHING}/mask/mask_{object_name}.png"
image_data = io.BytesIO()
front_image.save(image_data, format='PNG')
image_data.seek(0)
image_bytes = image_data.read()
# image_url = f"{AIDA_CLOTHING}/{minio_client.put_object('aida-clothing', f'image/image_{object_name}.png', io.BytesIO(image_bytes), len(image_bytes), content_type='image/png').object_name}"
req = oss_upload_image(bucket=AIDA_CLOTHING, object_name=f"image/image_{object_name}.png", image_bytes=image_bytes)
image_url = f"{AIDA_CLOTHING}/image/image_{object_name}.png"
return front_image, image_url, mask_url
except Exception as e:
logging.warning(f"upload_png_mask runtime exception : {e}")

View File

@@ -0,0 +1,374 @@
import logging
import time
import cv2
import numpy as np
import torch
import tritonclient.grpc as grpcclient
from urllib3.exceptions import ResponseError
from app.core.config import *
from app.schemas.pre_processing import DesignPreProcessingModel
from app.service.design.utils.design_ensemble import get_keypoint_result
from app.service.utils.oss_client import oss_get_image, oss_upload_image
class DesignPreprocessing:
# @ RunTime
def pipeline(self, image_list):
sketches_list = self.read_image(image_list)
logging.info("read image success")
bounding_box_sketches_list = self.bounding_box(sketches_list)
logging.info("bounding box image success")
super_resolution_list = self.super_resolution(bounding_box_sketches_list)
logging.info("super_resolution_list image success")
infer_sketches_list = self.infer_image(super_resolution_list)
logging.info("infer image success")
result = self.composing_image(infer_sketches_list)
logging.info("Replenish white edge image success")
for d in result:
if 'image_obj' in d:
del d['image_obj']
if 'obj' in d:
del d['obj']
if 'keypoint_result' in d:
del d['keypoint_result']
return result
def read_image(self, image_list):
for obj in image_list:
image = oss_get_image(bucket=obj['image_url'].split("/", 1)[0], object_name=obj['image_url'].split("/", 1)[1], data_type="cv2")
if len(image.shape) == 2:
image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
elif image.shape[2] == 4:  # 4-channel image: drop the alpha channel
image = image[:, :, :3]
obj["image_obj"] = image
return image_list
# @ RunTime
def bounding_box(self, image_list):
for item in image_list:
image = item['image_obj']
# Detect object edges with Canny edge detection
edges = cv2.Canny(image, 50, 150)
# Find contours
contours, _ = cv2.findContours(edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
# Initialize the rectangle that encloses all bounding boxes
x_min, y_min, x_max, y_max = float('inf'), float('inf'), -1, -1
# Update the enclosing rectangle over all bounding boxes
for contour in contours:
x, y, w, h = cv2.boundingRect(contour)
x_min = min(x_min, x)
y_min = min(y_min, y)
x_max = max(x_max, x + w)
y_max = max(y_max, y + h)
if IF_DEBUG_SHOW:
image_with_big_rect = cv2.rectangle(image.copy(), (x_min, y_min), (x_max, y_max), (0, 255, 0), 2)
cv2.imshow("bounding_box image", image_with_big_rect)
cv2.waitKey(0)
# Crop the original image to the enclosing rectangle
if len(contours) > 0:
cropped_image = image[y_min:y_max, x_min:x_max]
item['obj'] = cropped_image  # cropped image with the new shape
# Direct overwrite of the stored object was removed; a size check was added instead
else:
item['obj'] = image
return image_list
def super_resolution(self, image_list):
for item in image_list:
# Only handle images whose sides are both <= 512, since 4x super-resolution is applied here
if item['obj'].shape[0] <= 512 and item['obj'].shape[1] <= 512:
# Run super-resolution if either side is <= 256
if item['obj'].shape[0] <= 256 or item['obj'].shape[1] <= 256:
# Super-resolution inference
img = item['obj'].astype(np.float32) / 255.
sample = np.transpose(img if img.shape[2] == 1 else img[:, :, [2, 1, 0]], (2, 0, 1))
sample = torch.from_numpy(sample).float().unsqueeze(0).numpy()
inputs = [
grpcclient.InferInput("input", sample.shape, datatype="FP32")
]
inputs[0].set_data_from_numpy(sample)
triton_client = grpcclient.InferenceServerClient(url=SR_TRITON_URL)
result = triton_client.infer(model_name=SR_MODEL_NAME, inputs=inputs)
result_image = result.as_numpy(f'output')[0]
sr_output = torch.tensor(result_image)
output = sr_output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
if output.ndim == 3:
output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0)) # CHW-RGB to HCW-BGR
output = (output * 255.0).round().astype(np.uint8)
item['obj'] = output
try:
# Overwrite the stored object
image_bytes = cv2.imencode(".jpg", item['obj'])[1].tobytes()
bucket_name = item['image_url'].split("/", 1)[0]
object_name = item['image_url'].split("/", 1)[1]
oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes)
print(f"Object '{item['image_url'].split('/', 1)[1]}' overwritten successfully.")
except ResponseError as err:
print(f"Error: {err}")
return image_list
# @ RunTime
def infer_image(self, image_list):
for sketch in image_list:
# lower-case the category
image_category = sketch['image_category'].lower()
# Decide whether this is an upper or lower garment
sketch['site'] = 'up' if image_category in ['blouse', 'outwear', 'dress', 'tops'] else 'down'
# Infer the keypoints
sketch['keypoint_result'] = self.keypoint_cache(sketch)
if IF_DEBUG_SHOW:
debug_show_image = sketch['obj'].copy()
points_list = []
point_size = 1
point_color = (0, 0, 255) # BGR
thickness = 4  # can be 0, 4, or 8
for i in sketch['keypoint_result'].values():
points_list.append((int(i[1]), int(i[0])))
for point in points_list:
cv2.circle(debug_show_image, point, point_size, point_color, thickness)
cv2.imshow("", debug_show_image)
cv2.waitKey(0)
# # For upper garments, also run segmentation:
# if sketch["site"] == "up":
#     # Check whether a cached seg result exists with the current image shape
#     seg_result = self.search_seg_result(sketch["image_id"], sketch["obj"].shape)
#     if seg_result is False:
#         # Run segmentation and cache the result
#         seg_result = get_seg_result(sketch['image_id'], sketch['obj'])
return image_list
# @ RunTime
def composing_image(self, image_list):
for image in image_list:
'''Same scale for both; merged upper/lower garment code'''
image_width = image['obj'].shape[1]
waist_width = image['keypoint_result']['waistband_right'][1] - image['keypoint_result']['waistband_left'][1]
scale = 0.4
if waist_width / scale >= image_width:
add_width = int((waist_width / scale - image_width) / 2)
ret = cv2.copyMakeBorder(image['obj'], 0, 0, add_width, add_width, cv2.BORDER_CONSTANT, value=(255, 255, 255))  # pad with white; 255 is the uint8 maximum
if IF_DEBUG_SHOW:
cv2.imshow("composing_image", ret)
cv2.waitKey(0)
image_bytes = cv2.imencode(".jpg", ret)[1].tobytes()
bucket_name = image['image_url'].split('/', 1)[0]
object_name = image['image_url'].split('/', 1)[1].replace('.', '-show.')
oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes)
image['show_image_url'] = f"{bucket_name}/{object_name}"
else:
image_bytes = cv2.imencode(".jpg", image['obj'])[1].tobytes()
bucket_name = image['image_url'].split('/', 1)[0]
object_name = image['image_url'].split('/', 1)[1].replace('.', '-show.')
oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes)
image['show_image_url'] = f"{bucket_name}/{object_name}"
return image_list
@staticmethod
def select_seg_result(image_id, image_obj):
try:
            # Return False if the cached shape does not match the image
result = np.load(f"seg_result/{image_id}.npy").astype(np.int64)
if result.shape[1] == image_obj.shape[0] and result.shape[2] == image_obj.shape[1]:
return result
else:
return False
except FileNotFoundError as e:
logging.warning(f"{image_id} Image segmentation results cache file does not exist : {e}")
return False
@staticmethod
def search_seg_result(image_id, ori_shape):
try:
# connections.connect(alias=MILVUS_ALIAS, host=MILVUS_DB_HOST, port=MILVUS_PORT)
# collection = Collection(MILVUS_TABLE_SEG) # Get an existing collection.
# collection.load()
# start_time = time.time()
# res = collection.query(
# expr=f"seg_id == {image_id}",
# offset=0,
# limit=10,
# output_fields=["seg_cache"],
# )
# logging.info(f"search seg cache time {time.time() - start_time}")
# if len(res):
# vector = np.reshape(res[0]['seg_cache'] + res[1]['seg_cache'], (224, 224))
# array_2d_exact = F.interpolate(torch.tensor(vector).unsqueeze(0).unsqueeze(0), size=ori_shape, mode='bilinear', align_corners=False)
# array_2d_exact = array_2d_exact.squeeze().numpy()
# return array_2d_exact
# else:
            # The Milvus seg-cache lookup above is disabled, so this always falls through
            return False
except Exception as e:
logging.warning(f"{image_id} Image segmentation results cache file does not exist : {e}")
return False
def keypoint_cache(self, sketch):
try:
# connections.connect(alias=MILVUS_ALIAS, host=MILVUS_DB_HOST, port=MILVUS_PORT)
# collection = Collection(MILVUS_TABLE_KEYPOINT) # Get an existing collection.
# collection.load()
start_time = time.time()
# res = collection.query(
# expr=f"keypoint_id == {sketch['image_id']}",
# offset=0,
# limit=1,
# output_fields=["keypoint_cache", "keypoint_site"],
# )
            res = []  # the Milvus lookup above is disabled, so we always re-infer
logging.info(f"search keypoint time : {time.time() - start_time}")
if len(res) == 0:
                # No cached result: run inference and save it
keypoint_infer_result = self.infer_keypoint_result(sketch)
return self.save_keypoint_cache(sketch, keypoint_infer_result)
elif res[0]["keypoint_site"] == "all" or res[0]["keypoint_site"] == sketch['site']:
                # Requested site matches the cached site (or the cache covers "all"): return the cached result
return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, np.array(res[0]['keypoint_vector']).astype(int).reshape(12, 2).tolist()))
elif res[0]["keypoint_site"] != sketch['site']:
                # Requested site differs from the cached one: infer the missing half and update the cache to "all"
keypoint_infer_result = self.infer_keypoint_result(sketch)
return self.update_keypoint_cache(sketch, keypoint_infer_result, res[0]['keypoint_vector'])
except Exception as e:
logging.info(f"search keypoint cache milvus error {e}")
return False
# @ RunTime
def infer_keypoint_result(self, sketch):
        keypoint_infer_result = get_keypoint_result(sketch["obj"], sketch['site'])  # run keypoint inference
return keypoint_infer_result
@staticmethod
# @ RunTime
def save_keypoint_cache(sketch, keypoint_infer_result):
if sketch['site'] == "down":
zeros = np.zeros(20, dtype=int)
result = np.concatenate([zeros, keypoint_infer_result.flatten()])
else:
zeros = np.zeros(4, dtype=int)
result = np.concatenate([keypoint_infer_result.flatten(), zeros])
data = [
[int(sketch['image_id'])],
[sketch['site']],
[result.tolist()]
]
try:
# connections.connect(alias=MILVUS_ALIAS, host=MILVUS_DB_HOST, port=MILVUS_PORT)
start_time = time.time()
# collection = Collection(MILVUS_TABLE_KEYPOINT) # Get an existing collection.
# mr = collection.insert(data)
# logging.info(f"save keypoint time : {time.time() - start_time}")
return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
except Exception as e:
logging.info(f"save keypoint cache milvus error : {e}")
return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
@staticmethod
def update_keypoint_cache(sketch, infer_result, search_result):
if sketch['site'] == "up":
# 需要的是up 即推理出来的是up 那么查询的就是down
result = np.concatenate([infer_result.flatten(), search_result[-4:]])
else:
# 需要的是down 即推理出来的是down 那么查询的就是up
result = np.concatenate([search_result[:20], infer_result.flatten()])
data = [
[int(sketch['image_id'])],
["all"],
[result.tolist()]
]
try:
# connections.connect(alias=MILVUS_ALIAS, host=MILVUS_DB_HOST, port=MILVUS_PORT)
start_time = time.time()
# collection = Collection(MILVUS_TABLE_KEYPOINT) # Get an existing collection.
# mr = collection.upsert(data)
# logging.info(f"save keypoint time : {time.time() - start_time}")
return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
except Exception as e:
logging.info(f"save keypoint cache milvus error : {e}")
return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
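# Illustrative sketch (not part of the original file): the keypoint cache packs
# all keypoints into one flat vector of 24 ints -- the first 20 values are the
# 10 upper-garment points, the last 4 the 2 lower-garment points -- which
# save_keypoint_cache and update_keypoint_cache reshape into 12 (x, y) pairs.
import numpy as np

up = np.arange(20)                    # flattened "up" inference: 10 (x, y) points
down = np.zeros(4, dtype=int)         # zero-filled "down" slots
cache = np.concatenate([up, down])    # shape (24,), matching the cached vector
assert cache.reshape(12, 2).shape == (12, 2)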
if __name__ == '__main__':
data = {
"sketches": [
{
"image_category": "dress",
"image_id": "107903",
"image_url": "aida-sys-image/images/female/dress/0628000000.jpg"
}
]
}
request_data = DesignPreProcessingModel(sketches=data["sketches"])
server = DesignPreprocessing()
data = server.pipeline(image_list=request_data.sketches)
print(data)
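# Illustrative standalone sketch (assumed values) of the padding rule in
# composing_image above: widen the canvas with white borders until the
# waistband spans roughly `scale` (40%) of the image width.
import cv2
import numpy as np

sketch = np.full((600, 400, 3), 255, dtype=np.uint8)   # stand-in 400px-wide sketch
waist_width = 220                                      # px between the waistband keypoints
scale = 0.4
if waist_width / scale >= sketch.shape[1]:
    add_width = int((waist_width / scale - sketch.shape[1]) / 2)
    # 220 / 0.4 = 550px target width, so 75px of white is added on each side
    sketch = cv2.copyMakeBorder(sketch, 0, 0, add_width, add_width,
                                cv2.BORDER_CONSTANT, value=(255, 255, 255))
print(sketch.shape)  # (600, 550, 3)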

View File

@@ -10,21 +10,19 @@
import json
import logging
import time
from io import BytesIO
import cv2
import minio
import numpy as np
import redis
import tritonclient.grpc as grpcclient
import numpy as np
from minio import Minio
from tritonclient.utils import np_to_triton_dtype
from app.core.config import *
from app.schemas.generate_image import GenerateImageModel
from app.service.generate_image.utils.adjust_contrast import adjust_contrast
from app.service.generate_image.utils.image_processing import remove_background, stain_detection, generate_category_recognition, autoLevels, luminance_adjust, face_detect_pic
from app.service.generate_image.utils.upload_sd_image import upload_png_sd, upload_stain_png_sd
from app.service.generate_image.utils.image_processing import remove_background, stain_detection, generate_category_recognition, autoLevels, luminance_adjust
from app.service.generate_image.utils.upload_sd_image import upload_png_sd
from app.service.utils.oss_client import oss_get_image
logger = logging.getLogger()
@@ -36,22 +34,23 @@ class GenerateImage:
self.channel = self.connection.channel()
# self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
# self.channel = self.connection.channel()
self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
# self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
self.grpc_client = grpcclient.InferenceServerClient(url=GI_MODEL_URL)
self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
if request_data.mode == "img2img":
            # cv2 reads images as BGR; PIL reads them as RGB
self.image = self.get_image(request_data.image_url)
self.prompt = request_data.prompt
else:
self.image = np.random.randint(0, 256, (1024, 1024, 3), dtype=np.uint8)
self.prompt = request_data.prompt
self.prompt = request_data.prompt
self.tasks_id = request_data.tasks_id
self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:]
self.mode = request_data.mode
self.batch_size = 1
self.category = request_data.category
if self.category == "sketch":
self.prompt = f"{self.category},{self.prompt}"
self.index = 0
self.gender = request_data.gender
self.generate_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'image_url': '', 'category': ''}
@@ -63,10 +62,13 @@ class GenerateImage:
# Read data from response.
# read image use cv2
try:
response = self.minio_client.get_object(image_url.split('/')[0], image_url[image_url.find('/') + 1:])
image_file = BytesIO(response.data)
image_array = np.asarray(bytearray(image_file.read()), dtype=np.uint8)
image_cv2 = cv2.imdecode(image_array, cv2.IMREAD_COLOR)
# response = self.minio_client.get_object(image_url.split('/')[0], image_url[image_url.find('/') + 1:])
# image_file = BytesIO(response.data)
# image_array = np.asarray(bytearray(image_file.read()), dtype=np.uint8)
# image_cv2 = cv2.imdecode(image_array, cv2.IMREAD_COLOR)
# image_rbg = cv2.cvtColor(image_cv2, cv2.COLOR_BGR2RGB)
image_cv2 = oss_get_image(bucket=image_url.split('/')[0], object_name=image_url[image_url.find('/') + 1:], data_type="cv2")
            image_rgb = cv2.cvtColor(image_cv2, cv2.COLOR_BGR2RGB)
            image = cv2.resize(image_rgb, (1024, 1024))
except minio.error.S3Error:
@@ -104,7 +106,7 @@ class GenerateImage:
image_result = not_smudge_image
if is_smudge: # 无污点
# image_result = adjust_contrast(image_result)
image_url = upload_png_sd(image_result, user_id=self.user_id, category=f"{self.category}", object_name=f"{self.tasks_id}.png")
image_url = upload_png_sd(image_result, user_id=self.user_id, category=f"{self.category}", file_name=f"{self.tasks_id}.png")
# logger.info(f"upload image SUCCESS {image_url}")
self.generate_data['status'] = "SUCCESS"
self.generate_data['message'] = "success"
@@ -121,13 +123,6 @@ class GenerateImage:
status_data = self.redis_client.get(self.tasks_id)
return json.loads(status_data), status_data
def infer(self, inputs):
return self.grpc_client.async_infer(
model_name=GI_MODEL_NAME,
inputs=inputs,
callback=self.callback
)
def get_result(self):
try:
prompts = [self.prompt] * self.batch_size
@@ -147,7 +142,7 @@ class GenerateImage:
input_mode.set_data_from_numpy(mode_obj)
inputs = [input_text, input_image, input_mode]
ctx = self.infer(inputs)
ctx = self.grpc_client.async_infer(model_name=GI_MODEL_NAME, inputs=inputs, callback=self.callback)
time_out = 600
generate_data = None
while time_out > 0:
@@ -187,9 +182,10 @@ if __name__ == '__main__':
rd = GenerateImageModel(
tasks_id="123-89",
prompt='skeleton sitting by the side of a river looking soulful, concert poster, 4k, artistic',
image_url="",
image_url="aida-collection-element/87/Printboard/842c09cf-7297-42d9-9e6e-9c17d4a13cb5.jpg",
mode='txt2img',
category="test"
category="test",
gender="male"
)
server = GenerateImage(rd)
print(server.get_result())
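# Illustrative sketch of the control flow shared by the generate services:
# fire a Triton async_infer, let the gRPC callback write the terminal status
# into Redis, and poll Redis until SUCCESS/FAILURE/REVOKED or timeout. The
# client objects, inputs and task_id are assumed to be set up as in the class above.
import json
import time

def wait_for_task(grpc_client, redis_client, model_name, inputs, callback,
                  task_id, ticks=600, tick_seconds=0.1):
    ctx = grpc_client.async_infer(model_name=model_name, inputs=inputs,
                                  callback=callback)
    while ticks > 0:
        status = json.loads(redis_client.get(task_id))
        if status['status'] in ("REVOKED", "FAILURE"):
            ctx.cancel()  # abort the in-flight inference
            break
        if status['status'] == "SUCCESS":
            break
        ticks -= 1
        time.sleep(tick_seconds)
    return json.loads(redis_client.get(task_id))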

View File

@@ -0,0 +1,187 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project trinity_client
@File service_att_recognition.py
@Author :周成融
@Date 2023/7/26 12:01:05
@detail
"""
import json
import logging
import time
import cv2
import numpy as np
import redis
import tritonclient.grpc as grpcclient
from PIL import Image, ImageOps
from tritonclient.utils import np_to_triton_dtype
from app.core.config import *
from app.schemas.generate_image import GenerateProductImageModel
from app.service.generate_image.utils.upload_sd_image import upload_SDXL_image
from app.service.utils.oss_client import oss_get_image
logger = logging.getLogger()
class GenerateProductImage:
def __init__(self, request_data):
if DEBUG is False:
self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
self.channel = self.connection.channel()
# self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
# self.channel = self.connection.channel()
# self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
self.grpc_client = grpcclient.InferenceServerClient(url=GPI_MODEL_URL)
self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
self.category = "product_image"
self.image_strength = request_data.image_strength
self.batch_size = 1
self.product_type = request_data.product_type
self.prompt = request_data.prompt
self.image, self.image_size = pre_processing_image(request_data.image_url)
self.tasks_id = request_data.tasks_id
self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:]
self.gen_product_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'image_url': ''}
self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data))
self.redis_client.expire(self.tasks_id, 600)
def callback(self, result, error):
if error:
self.gen_product_data['status'] = "FAILURE"
self.gen_product_data['message'] = str(error)
self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data))
else:
            # convert the numpy tensor returned by Triton back to a PIL image
if self.product_type == "single":
image = result.as_numpy("generated_cnet_image")
else:
image = result.as_numpy("generated_inpaint_image")
image_result = Image.fromarray(np.squeeze(image.astype(np.uint8))).resize(self.image_size)
image_url = upload_SDXL_image(image_result, user_id=self.user_id, category=f"{self.category}", file_name=f"{self.tasks_id}.png")
self.gen_product_data['status'] = "SUCCESS"
self.gen_product_data['message'] = "success"
self.gen_product_data['image_url'] = str(image_url)
self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data))
def read_tasks_status(self):
status_data = self.redis_client.get(self.tasks_id)
return json.loads(status_data), status_data
def get_result(self):
try:
prompts = [self.prompt] * self.batch_size
self.image = cv2.cvtColor(self.image, cv2.COLOR_BGR2RGB)
self.image = cv2.resize(self.image, (512, 768))
images = [self.image.astype(np.uint8)] * self.batch_size
if self.product_type == "single":
text_obj = np.array(prompts, dtype="object").reshape(-1, 1)
image_obj = np.array(images, dtype=np.uint8).reshape((-1, 768, 512, 3))
image_strength_obj = np.array(self.image_strength, dtype=np.float32).reshape(-1, 1)
else:
text_obj = np.array(prompts, dtype="object").reshape(1)
image_obj = np.array(images, dtype=np.uint8).reshape((768, 512, 3))
image_strength_obj = np.array(self.image_strength, dtype=np.float32).reshape((1))
            # prompts, images and self.image_strength are prepared above
input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype))
input_image = grpcclient.InferInput("input_image", image_obj.shape, "UINT8")
input_image_strength = grpcclient.InferInput("image_strength", image_strength_obj.shape, np_to_triton_dtype(image_strength_obj.dtype))
input_text.set_data_from_numpy(text_obj)
input_image.set_data_from_numpy(image_obj)
            input_image_strength.set_data_from_numpy(image_strength_obj)
            inputs = [input_text, input_image, input_image_strength]
if self.product_type == "single":
ctx = self.grpc_client.async_infer(model_name=GPI_MODEL_NAME_SINGLE, inputs=inputs, callback=self.callback)
else:
ctx = self.grpc_client.async_infer(model_name=GPI_MODEL_NAME_OVERALL, inputs=inputs, callback=self.callback)
time_out = 600
while time_out > 0:
gen_product_data, _ = self.read_tasks_status()
if gen_product_data['status'] in ["REVOKED", "FAILURE"]:
ctx.cancel()
break
elif gen_product_data['status'] == "SUCCESS":
break
time_out -= 1
time.sleep(0.1)
gen_product_data, _ = self.read_tasks_status()
return gen_product_data
except Exception as e:
self.gen_product_data['status'] = "FAILURE"
self.gen_product_data['message'] = str(e)
self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data))
raise Exception(str(e))
finally:
dict_gen_product_data, str_gen_product_data = self.read_tasks_status()
if DEBUG is False:
self.channel.basic_publish(exchange='', routing_key=GPI_RABBITMQ_QUEUES, body=str_gen_product_data)
logger.info(f" [x] Sent to {GPI_RABBITMQ_QUEUES} data@@@@ {json.dumps(dict_gen_product_data, indent=4)}")
def infer_cancel(tasks_id):
redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
data = {'tasks_id': tasks_id, 'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'}
gen_product_data = json.dumps(data)
redis_client.set(tasks_id, gen_product_data)
return data
def pre_processing_image(image_url):
image = oss_get_image(bucket=image_url.split('/')[0], object_name=image_url[image_url.find('/') + 1:], data_type="PIL")
    # Original image size
    width, height = image.size
    # Compute a new size with a 2:3 (width:height) aspect ratio
    desired_ratio = 2 / 3
    current_ratio = width / height
    if current_ratio > desired_ratio:
        # The image is too wide: pad top and bottom
        new_width = width
        new_height = int(width / desired_ratio)
    else:
        # The image is too tall (or already 2:3): pad left and right
        new_height = height
        new_width = int(height * desired_ratio)
    # Create a transparent canvas at the padded size
    pad_image = Image.new('RGBA', (new_width, new_height), (0, 0, 0, 0))
    # Paste the original image at the center of the new canvas
    left = (new_width - width) // 2
    top = (new_height - height) // 2
    pad_image.paste(image, (left, top))
    # Resize the canvas to 500 x 750
resized_image = pad_image.resize((500, 750))
image_size = (512, 768)
    if resized_image.mode in ('RGBA', 'LA') or (resized_image.mode == 'P' and 'transparency' in resized_image.info):
        # Create a white background
        background = Image.new("RGB", image_size, (255, 255, 255))
        # Paste the image onto the white background, using its alpha channel as the mask
        background.paste(resized_image, mask=resized_image.split()[3])
        image = np.array(background)
        # PIL arrays are RGB; swap to BGR so downstream cv2 code sees its native order
        image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
return image, image_size
if __name__ == '__main__':
rd = GenerateProductImageModel(
tasks_id="123-89",
# prompt="",
image_strength=0.9,
prompt=" the best quality, masterpiece. detailed, high-res, simple background, studio photography, extremely detailed, updo, detailed face, face, close-up, HDR, UHD, 8K realistic, Highly detailed, simple background, Studio lighting",
image_url="aida-results/result_00097282-ebb2-11ee-a822-b48351119060.png",
product_type="overall"
)
server = GenerateProductImage(rd)
print(server.get_result())
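# Worked example (hypothetical 900x600 input) of the 2:3 padding performed by
# pre_processing_image above: width/height = 1.5 > 2/3, so padding goes on the
# top and bottom until the canvas is 900x1350, which is then resized to 500x750.
from PIL import Image

img = Image.new('RGB', (900, 600), (128, 128, 128))    # stand-in input image
desired_ratio = 2 / 3
w, h = img.size
if w / h > desired_ratio:
    new_w, new_h = w, int(w / desired_ratio)           # 900 x 1350
else:
    new_w, new_h = int(h * desired_ratio), h
canvas = Image.new('RGBA', (new_w, new_h), (0, 0, 0, 0))
canvas.paste(img, ((new_w - w) // 2, (new_h - h) // 2))
print(canvas.resize((500, 750)).size)                  # (500, 750)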

View File

@@ -0,0 +1,159 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project trinity_client
@File service_att_recognition.py
@Author :周成融
@Date 2023/7/26 12:01:05
@detail
"""
import json
import logging
import time
import cv2
import numpy as np
import redis
import tritonclient.grpc as grpcclient
from PIL import Image
from tritonclient.utils import np_to_triton_dtype
from app.core.config import *
from app.schemas.generate_image import GenerateRelightImageModel
from app.service.generate_image.utils.upload_sd_image import upload_SDXL_image
from app.service.utils.oss_client import oss_get_image
logger = logging.getLogger()
class GenerateRelightImage:
def __init__(self, request_data):
if DEBUG is False:
self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
self.channel = self.connection.channel()
# self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
self.grpc_client = grpcclient.InferenceServerClient(url=GRI_MODEL_URL)
self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
self.category = "relight_image"
self.batch_size = 1
self.prompt = request_data.prompt
self.seed = "1"
self.product_type = request_data.product_type
self.negative_prompt = 'lowres, bad anatomy, bad hands, cropped, worst quality'
self.direction = request_data.direction
self.image_url = request_data.image_url
self.image = oss_get_image(bucket=self.image_url.split('/')[0], object_name=self.image_url[self.image_url.find('/') + 1:], data_type="cv2")
self.tasks_id = request_data.tasks_id
self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:]
self.gen_product_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'image_url': ''}
self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data))
self.redis_client.expire(self.tasks_id, 600)
def callback(self, result, error):
if error:
self.gen_product_data['status'] = "FAILURE"
self.gen_product_data['message'] = str(error)
self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data))
else:
            # convert the numpy tensor returned by Triton back to a PIL image
if self.product_type == 'single':
image = result.as_numpy("generated_relight_image")
else:
image = result.as_numpy("generated_inpaint_image")
image_result = Image.fromarray(np.squeeze(image.astype(np.uint8)))
image_url = upload_SDXL_image(image_result, user_id=self.user_id, category=f"{self.category}", file_name=f"{self.tasks_id}.png")
self.gen_product_data['status'] = "SUCCESS"
self.gen_product_data['message'] = "success"
self.gen_product_data['image_url'] = str(image_url)
self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data))
def read_tasks_status(self):
status_data = self.redis_client.get(self.tasks_id)
return json.loads(status_data), status_data
def get_result(self):
try:
prompts = [self.prompt] * self.batch_size
image = cv2.cvtColor(self.image, cv2.COLOR_BGR2RGB)
image = cv2.resize(image, (512, 768))
images = [image.astype(np.uint8)] * self.batch_size
seeds = [self.seed] * self.batch_size
            negative_prompts = [self.negative_prompt] * self.batch_size
directions = [self.direction] * self.batch_size
if self.product_type == 'single':
text_obj = np.array(prompts, dtype="object").reshape((-1, 1))
image_obj = np.array(images, dtype=np.uint8).reshape((-1, 768, 512, 3))
                na_text_obj = np.array(negative_prompts, dtype="object").reshape((-1, 1))
seed_obj = np.array(seeds, dtype="object").reshape((-1, 1))
direction_obj = np.array(directions, dtype="object").reshape((-1, 1))
else:
text_obj = np.array(prompts, dtype="object").reshape((1))
image_obj = np.array(images, dtype=np.uint8).reshape((768, 512, 3))
                na_text_obj = np.array(negative_prompts, dtype="object").reshape((1))
seed_obj = np.array(seeds, dtype="object").reshape((1))
direction_obj = np.array(directions, dtype="object").reshape((1))
input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype))
input_image = grpcclient.InferInput("input_image", image_obj.shape, "UINT8")
input_natext = grpcclient.InferInput("negative_prompt", na_text_obj.shape, np_to_triton_dtype(na_text_obj.dtype))
input_seed = grpcclient.InferInput("seed", seed_obj.shape, np_to_triton_dtype(seed_obj.dtype))
input_direction = grpcclient.InferInput("direction", direction_obj.shape, np_to_triton_dtype(direction_obj.dtype))
input_text.set_data_from_numpy(text_obj)
input_image.set_data_from_numpy(image_obj)
input_natext.set_data_from_numpy(na_text_obj)
input_seed.set_data_from_numpy(seed_obj)
input_direction.set_data_from_numpy(direction_obj)
inputs = [input_text, input_natext, input_image, input_seed, input_direction]
if self.product_type == 'single':
ctx = self.grpc_client.async_infer(model_name=GRI_MODEL_NAME_SINGLE, inputs=inputs, callback=self.callback)
else:
ctx = self.grpc_client.async_infer(model_name=GRI_MODEL_NAME_OVERALL, inputs=inputs, callback=self.callback)
time_out = 600
while time_out > 0:
gen_product_data, _ = self.read_tasks_status()
if gen_product_data['status'] in ["REVOKED", "FAILURE"]:
ctx.cancel()
break
elif gen_product_data['status'] == "SUCCESS":
break
time_out -= 1
time.sleep(0.1)
gen_product_data, _ = self.read_tasks_status()
return gen_product_data
except Exception as e:
self.gen_product_data['status'] = "FAILURE"
self.gen_product_data['message'] = str(e)
self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data))
raise Exception(str(e))
finally:
dict_gen_product_data, str_gen_product_data = self.read_tasks_status()
if DEBUG is False:
self.channel.basic_publish(exchange='', routing_key=GRI_RABBITMQ_QUEUES, body=str_gen_product_data)
logger.info(f" [x] Sent to {GRI_RABBITMQ_QUEUES} data@@@@ {json.dumps(dict_gen_product_data, indent=4)}")
def infer_cancel(tasks_id):
redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
data = {'tasks_id': tasks_id, 'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'}
gen_product_data = json.dumps(data)
redis_client.set(tasks_id, gen_product_data)
return data
if __name__ == '__main__':
rd = GenerateRelightImageModel(
tasks_id="123-89",
# prompt="beautiful woman, detailed face, sunshine, outdoor, warm atmosphere",
prompt="Colorful black",
image_url='aida-results/result_0000b606-1902-11ef-9424-0242ac180002.png',
direction="Right Light",
product_type="single"
)
server = GenerateRelightImage(rd)
print(server.get_result())
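# Illustrative sketch of how the string inputs above reach Triton: any numpy
# array with dtype "object" maps to the BYTES tensor type, so prompts, seeds
# and directions are all shipped the same way (the tensor name "prompt" is
# taken from the code above).
import numpy as np
import tritonclient.grpc as grpcclient
from tritonclient.utils import np_to_triton_dtype

text_obj = np.array(["Colorful black"], dtype="object").reshape((-1, 1))
input_text = grpcclient.InferInput("prompt", text_obj.shape,
                                   np_to_triton_dtype(text_obj.dtype))  # "BYTES"
input_text.set_data_from_numpy(text_obj)  # a (1, 1) BYTES tensor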

View File

@@ -0,0 +1,119 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project trinity_client
@File service_att_recognition.py
@Author :周成融
@Date 2023/7/26 12:01:05
@detail
"""
import json
import logging
import time
import cv2
import numpy as np
import redis
from PIL import Image
from minio import Minio
from tritonclient.utils import np_to_triton_dtype
from app.core.config import *
import tritonclient.grpc as grpcclient
from app.schemas.generate_image import GenerateSingleLogoImageModel
from app.service.generate_image.utils.upload_sd_image import upload_png_sd, upload_SDXL_image
logger = logging.getLogger()
class GenerateSingleLogoImage:
def __init__(self, request_data):
if DEBUG is False:
self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
self.channel = self.connection.channel()
self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
self.grpc_client = grpcclient.InferenceServerClient(url=GSL_MODEL_URL)
self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
self.batch_size = 1
self.category = "single_logo"
self.negative_prompts = "bad, ugly"
self.seed = request_data.seed
self.tasks_id = request_data.tasks_id
self.prompt = request_data.prompt
self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:]
self.gen_single_logo_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'image_url': ''}
self.redis_client.set(self.tasks_id, json.dumps(self.gen_single_logo_data))
self.redis_client.expire(self.tasks_id, 600)
def read_tasks_status(self):
status_data = self.redis_client.get(self.tasks_id)
return json.loads(status_data), status_data
def callback(self, result, error):
if error:
self.gen_single_logo_data['status'] = "FAILURE"
self.gen_single_logo_data['message'] = str(error)
self.redis_client.set(self.tasks_id, json.dumps(self.gen_single_logo_data))
else:
image = result.as_numpy("generated_image")
image_result = Image.fromarray(np.squeeze(image.astype(np.uint8)))
image_url = upload_SDXL_image(image_result, user_id=self.user_id, category=f"{self.category}", file_name=f"{self.tasks_id}.png")
self.gen_single_logo_data['status'] = "SUCCESS"
self.gen_single_logo_data['message'] = "success"
self.gen_single_logo_data['image_url'] = str(image_url)
self.redis_client.set(self.tasks_id, json.dumps(self.gen_single_logo_data))
def get_result(self):
try:
# prompt
prompts = [self.prompt] * self.batch_size
text_obj = np.array(prompts, dtype="object").reshape((-1, 1))
input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype))
input_text.set_data_from_numpy(text_obj)
text_obj_neg = np.array(self.negative_prompts, dtype="object").reshape((-1, 1))
input_text_neg = grpcclient.InferInput("negative_prompt", text_obj_neg.shape, np_to_triton_dtype(text_obj_neg.dtype))
input_text_neg.set_data_from_numpy(text_obj_neg)
seed = np.array(self.seed, dtype="object").reshape((-1, 1))
input_seed = grpcclient.InferInput("seed", seed.shape, np_to_triton_dtype(seed.dtype))
input_seed.set_data_from_numpy(seed)
inputs = [input_text, input_text_neg, input_seed]
ctx = self.grpc_client.async_infer(model_name=GSL_MODEL_NAME, inputs=inputs, callback=self.callback)
time_out = 600
generate_data = None
while time_out > 0:
generate_data, _ = self.read_tasks_status()
if generate_data['status'] in ["REVOKED", "FAILURE"]:
ctx.cancel()
break
elif generate_data['status'] == "SUCCESS":
break
time_out -= 1
time.sleep(0.1)
return generate_data
except Exception as e:
raise Exception(str(e))
finally:
dict_generate_data, str_generate_data = self.read_tasks_status()
if DEBUG is False:
self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=str_generate_data)
logger.info(f" [x] Sent {json.dumps(dict_generate_data, indent=4)}")
def infer_cancel(tasks_id):
redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
data = {'tasks_id': tasks_id, 'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'}
generate_data = json.dumps(data)
redis_client.set(tasks_id, generate_data)
return data
if __name__ == '__main__':
rd = GenerateSingleLogoImageModel(
tasks_id="123-89",
prompt='an apple',
seed="2",
)
server = GenerateSingleLogoImage(rd)
print(server.get_result())
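# Illustrative sketch of the cooperative cancellation above: any process can
# overwrite a task's Redis entry with REVOKED, and the polling loop in
# get_result then calls ctx.cancel() on the in-flight request. Host and port
# here are placeholders for the configured REDIS_HOST/REDIS_PORT.
import json
import redis

r = redis.StrictRedis(host="localhost", port=6379, db=0, decode_responses=True)
r.set("123-89", json.dumps({'tasks_id': "123-89", 'status': 'REVOKED',
                            'message': "revoked", 'data': 'revoked'}))
# get_result notices status == "REVOKED" on its next 0.1s poll and cancels.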

View File

@@ -381,7 +381,7 @@ if __name__ == '__main__':
remove_bg_img = remove_background(luminance)
# cv2.imwrite("remove_bg_img.png", remove_bg_img)
print(1)
# print(1)
cv2.imshow("source", img)
cv2.imshow("levels", equAuto)
cv2.imshow("luminance", luminance)

View File

@@ -10,26 +10,61 @@
import io
import logging
# import boto3
import cv2
from PIL import Image
from minio import Minio
from app.core.config import *
from app.service.utils.oss_client import oss_upload_image
minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
def upload_png_sd(image, user_id, category, object_name):
# s3 = boto3.client('s3', aws_access_key_id=S3_ACCESS_KEY, aws_secret_access_key=S3_AWS_SECRET_ACCESS_KEY, region_name=S3_REGION_NAME)
# def upload_single_logo(image, user_id, category, object_name):
# with io.BytesIO() as output:
# image.save(output, format='PNG')
# data = output.getvalue()
# # 创建一个 S3 客户端
# try:
# key = f'{user_id}/{category}/{object_name}'
# image_url = f"{AIDA_CLOTHING}/{key}"
# s3.put_object(Bucket=GSL_MINIO_BUCKET, Key=key, Body=data, ContentType='image/png')
# return image_url
# except Exception as e:
# print(f'上传到 S3 失败: {e}')
def upload_SDXL_image(image, user_id, category, file_name):
try:
image_data = io.BytesIO()
image.save(image_data, format='PNG')
image_data.seek(0)
image_bytes = image_data.read()
# minio_req = minio_client.put_object(
# GI_MINIO_BUCKET,
# f'{user_id}/{category}/{file_name}',
# io.BytesIO(image_bytes),
# len(image_bytes),
# content_type='image/jpeg'
# )
object_name = f'{user_id}/{category}/{file_name}'
req = oss_upload_image(bucket=GI_MINIO_BUCKET, object_name=object_name, image_bytes=image_bytes)
image_url = f"aida-users/{object_name}"
return image_url
except Exception as e:
logging.warning(f"upload_png_mask runtime exception : {e}")
def upload_png_sd(image, user_id, category, file_name):
try:
_, img_byte_array = cv2.imencode('.jpg', image)
minio_req = minio_client.put_object(
GI_MINIO_BUCKET,
f'{user_id}/{category}/{object_name}',
io.BytesIO(img_byte_array),
len(img_byte_array),
content_type='image/jpeg'
)
image_url = f"aida-users/{minio_req.object_name}"
object_name = f'{user_id}/{category}/{file_name}'
req = oss_upload_image(bucket=GI_MINIO_BUCKET, object_name=object_name, image_bytes=img_byte_array)
image_url = f"aida-users/{object_name}"
return image_url
except Exception as e:
logging.warning(f"upload_png_mask runtime exception : {e}")

View File

@@ -0,0 +1,77 @@
import logging
from langchain.chains import LLMChain
from langchain.chat_models import ChatOpenAI
from langchain_core.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
from app.core.config import OPENAI_MODEL, OPENAI_API_KEY
# os.environ["http_proxy"] = "http://127.0.0.1:7890"
# os.environ["https_proxy"] = "http://127.0.0.1:7890"
llm = ChatOpenAI(model_name=OPENAI_MODEL,
openai_api_key=OPENAI_API_KEY,
temperature=0)
def translate_to_en(text):
template = (
"""You are a translation expert, proficient in various languages.
And can translate various languages into English.
Please translate to grammatically correct English regardless of the input language.
If the input is already in English, or consists of letters or numbers such as "cat", "abc", or "1",
output the input text exactly as it is without any modifications or additions.
If there are grammatical errors, correct them and then output the sentence."""
)
system_message_prompt = SystemMessagePromptTemplate.from_template(template)
    # The text to translate is supplied through the Human role
human_template = "User input : {text}"
human_message_prompt = HumanMessagePromptTemplate.from_template(input_variables=["text"], template=human_template)
    # Build the ChatPromptTemplate from the System and Human message templates
chat_prompt_template = ChatPromptTemplate.from_messages(
[system_message_prompt, human_message_prompt]
)
translate_chain = LLMChain(llm=llm, prompt=chat_prompt_template)
result = translate_chain.invoke(text)
logging.info("translate result : " + result.get('text'))
# print("translate result : " + result.get('text'))
return result.get('text')
# template = (
# """
# Input sentence:
# {translate}
# 1. Based on the input,adjust the input sentence to make it more suitable for prompts for generating images,
# ensuring all key nouns or adjectives related to the image are retained.
# 2. Simplify complex sentence structures and clarify ambiguous expressions.
# 3. Only Output the adjusted English sentence.
#
# Output :
# """
# )
# # "Based on the input sentence, extract key adjectives and nouns.Only Output extracted key words."
# # 1. Check if the input sentence contains any grammatical errors. If there are errors, please correct them before proceeding.
#
# prompt_template = PromptTemplate(input_variables=["translate"], template=template)
# prompt_chain = LLMChain(llm=llm, prompt=prompt_template)
#
# from langchain.chains import SimpleSequentialChain
# overall_chain = SimpleSequentialChain(chains=[translate_chain, prompt_chain], verbose=True)
#
# response = overall_chain.run(text)
# return response
def main():
"""Main function"""
text = translate_to_en("fire")
print(text)
if __name__ == "__main__":
main()
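# LLMChain is deprecated in recent LangChain releases; a minimal sketch of the
# same translation call using the runnable pipe operator instead (assumes
# langchain-core is installed; the model name and API key are placeholders):
from langchain.chat_models import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate

prompt = ChatPromptTemplate.from_messages([
    ("system", "You are a translation expert; translate the user text to English."),
    ("human", "User input : {text}"),
])
chain = prompt | ChatOpenAI(model_name="gpt-3.5-turbo",
                            openai_api_key="sk-placeholder", temperature=0)
print(chain.invoke({"text": "fire"}).content)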

View File

@@ -1,17 +1,17 @@
import io
import json
import logging
import time
import minio.error
import redis
import json
import cv2
import minio.error
import numpy as np
import redis
import torch
import tritonclient.grpc as grpcclient
from minio import Minio
from app.core.config import *
from app.schemas.super_resolution import SuperResolutionModel
from app.service.utils.decorator import RunTime
from app.service.utils.oss_client import oss_get_image, oss_upload_image
logger = logging.getLogger()
@@ -24,7 +24,7 @@ class SuperResolution:
self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:]
self.sr_image_url = data.sr_image_url
self.sr_xn = data.sr_xn
self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
# self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
self.redis_client.set(self.tasks_id, json.dumps({'status': 'PENDING', 'message': "pending", 'data': ''}))
self.redis_client.expire(self.tasks_id, 600)
self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
@@ -33,36 +33,37 @@ class SuperResolution:
# @RunTime
def read_image(self):
try:
image_data = self.minio_client.get_object(self.sr_image_url.split("/", 1)[0], self.sr_image_url.split("/", 1)[1])
img = oss_get_image(bucket=self.sr_image_url.split("/", 1)[0], object_name=self.sr_image_url.split("/", 1)[1], data_type="cv2")
            img = img.astype(np.float32) / 255.  # normalize to [0, 1]
except minio.error.S3Error as e:
sr_data = json.dumps({'tasks_id': self.tasks_id, 'status': 'ERROR', 'message': f'{e}'})
self.channel.basic_publish(exchange='', routing_key=SR_RABBITMQ_QUEUES, body=sr_data)
logger.info(f" [x] Sent {sr_data}")
raise FileNotFoundError(f"Image '{self.sr_image_url.split('/', 1)[1]}' not found in bucket '{self.sr_image_url.split('/', 1)[0]}'")
        img = np.frombuffer(image_data.data, np.uint8)  # convert to 8-bit unsigned integers
        img = cv2.imdecode(img, cv2.IMREAD_COLOR).astype(np.float32) / 255.  # decode and normalize
return img
# try:
# image_data = self.minio_client.get_object(self.sr_image_url.split("/", 1)[0], self.sr_image_url.split("/", 1)[1])
# except minio.error.S3Error as e:
# sr_data = json.dumps({'tasks_id': self.tasks_id, 'status': 'ERROR', 'message': f'{e}'})
# self.channel.basic_publish(exchange='', routing_key=SR_RABBITMQ_QUEUES, body=sr_data)
# logger.info(f" [x] Sent {sr_data}")
# raise FileNotFoundError(f"Image '{self.sr_image_url.split('/', 1)[1]}' not found in bucket '{self.sr_image_url.split('/', 1)[0]}'")
# img = np.frombuffer(image_data.data, np.uint8) # 转成8位无符号整型
# img = cv2.imdecode(img, cv2.IMREAD_COLOR).astype(np.float32) / 255. # 解码
# return img
def read_tasks_status(self):
status_data = json.loads(self.redis_client.get(self.tasks_id))
logging.info(f"{self.tasks_id} ===> {status_data}")
return status_data
# @RunTime
def infer(self, inputs):
return self.triton_client.async_infer(
model_name=SR_MODEL_NAME,
inputs=inputs,
callback=self.callback
)
# @RunTime
def sr_result(self):
sample = self.read_image()
if self.sr_xn == 2:
new_shape = (sample.shape[0] // self.sr_xn, sample.shape[1] // self.sr_xn)
sample = cv2.resize(sample, new_shape)
            # print(new_shape)
sample = np.transpose(sample if sample.shape[2] == 1 else sample[:, :, [2, 1, 0]], (2, 0, 1))
sample = torch.from_numpy(sample).float().unsqueeze(0).numpy()
inputs = [
@@ -72,13 +73,16 @@ class SuperResolution:
# , binary_data=True
)
ctx = self.infer(inputs)
ctx = self.triton_client.async_infer(
model_name=SR_MODEL_NAME,
inputs=inputs,
callback=self.callback
)
time_out = 60
while time_out > 0:
generate_data = self.read_tasks_status()
if generate_data['status'] in ["REVOKED", "FAILURE"]:
ctx.cancel()
# noinspection PyTypeChecker
self.channel.basic_publish(exchange='', routing_key=SR_RABBITMQ_QUEUES, body=json.dumps(generate_data))
logger.info(f" [x] Sent {generate_data}")
break
@@ -88,28 +92,19 @@ class SuperResolution:
time.sleep(1)
return self.read_tasks_status()
# results = self.triton_client.infer(model_name=SR_MODEL_NAME, inputs=inputs)
# sr_output = torch.from_numpy(results.as_numpy(f"output"))
# output = sr_output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
# if output.ndim == 3:
# output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0)) # CHW-RGB to HCW-BGR
# output = (output * 255.0).round().astype(np.uint8)
# output_url = self.upload_img_sr(output, generate_uuid())
# return output_url
def upload_img_sr(self, image):
try:
image_bytes = cv2.imencode('.jpg', image)[1].tobytes()
res = self.minio_client.put_object(f'{SR_MINIO_BUCKET}', f'{self.user_id}/sr/output/{self.tasks_id}.jpg', io.BytesIO(image_bytes), len(image_bytes), content_type='image/png')
image_url = f"aida-users/{res.object_name}"
# res = self.minio_client.put_object(f'{SR_MINIO_BUCKET}', f'{self.user_id}/sr/output/{self.tasks_id}.jpg', io.BytesIO(image_bytes), len(image_bytes), content_type='image/png')
object_name = f'{self.user_id}/sr/output/{self.tasks_id}.jpg'
oss_upload_image(bucket=SR_MINIO_BUCKET, object_name=object_name, image_bytes=image_bytes)
image_url = f"{SR_MINIO_BUCKET}/{object_name}"
return image_url
except Exception as e:
logger.warning(f"upload_png_mask runtime exception : {e}")
def callback(self, result, error):
if error:
print(error)
sr_info_data = json.dumps({'status': 'FAILURE', 'message': f"{error}", 'data': f"{error}"})
self.redis_client.set(self.tasks_id, sr_info_data)
else:
@@ -135,6 +130,6 @@ def infer_cancel(tasks_id):
if __name__ == '__main__':
request_data = SuperResolutionModel(sr_image_url="test/512_image/15.png", sr_xn=2, sr_tasks_id="123")
request_data = SuperResolutionModel(sr_image_url="aida-users/83/print/b77bf4ca-6ca2-44a1-9040-505f359a974c-3-83.png", sr_xn=2, sr_tasks_id="12341556")
service = SuperResolution(request_data)
result_url = service.sr_result()

View File

@@ -0,0 +1,74 @@
import io
import logging
from io import BytesIO
import boto3
import cv2
import numpy as np
from PIL import Image
from minio import Minio
from app.core.config import *
logger = logging.getLogger()
# Fetch an image from object storage
def oss_get_image(bucket, object_name, data_type):
    # the cv2 path reads all channels by default (IMREAD_UNCHANGED)
image_object = None
try:
if OSS == "minio":
oss_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
image_data = oss_client.get_object(bucket_name=bucket, object_name=object_name)
else:
oss_client = boto3.client('s3', aws_access_key_id=S3_ACCESS_KEY, aws_secret_access_key=S3_AWS_SECRET_ACCESS_KEY, region_name=S3_REGION_NAME)
image_data = oss_client.get_object(Bucket=bucket, Key=object_name)['Body']
if data_type == "cv2":
image_bytes = image_data.read()
            image_array = np.frombuffer(image_bytes, np.uint8)  # convert to 8-bit unsigned integers
image_object = cv2.imdecode(image_array, cv2.IMREAD_UNCHANGED)
else:
data_bytes = BytesIO(image_data.read())
image_object = Image.open(data_bytes)
except Exception as e:
logger.warning(f"{OSS} | 获取图片出现异常 ######: {e}")
return image_object
def oss_upload_image(bucket, object_name, image_bytes):
req = None
try:
if OSS == "minio":
oss_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
req = oss_client.put_object(bucket_name=bucket, object_name=object_name, data=io.BytesIO(image_bytes), length=len(image_bytes), content_type='image/png')
else:
oss_client = boto3.client('s3', aws_access_key_id=S3_ACCESS_KEY, aws_secret_access_key=S3_AWS_SECRET_ACCESS_KEY, region_name=S3_REGION_NAME)
req = oss_client.put_object(Bucket=bucket, Key=object_name, Body=io.BytesIO(image_bytes), ContentType='image/png')
except Exception as e:
logger.warning(f"{OSS} | 上传图片出现异常 ######: {e}")
return req
if __name__ == '__main__':
# url = "aida-results/result_0002186a-e631-11ee-86a6-b48351119060.png"
# url = "aida-collection-element/11523/Moodboard/f60af0d2-94c2-48f9-90ff-74b8e8a481b5.jpg"
# url = "aida-sys-image/images/female/outwear/0628000054.jpg"
# url = "aida-users/89/product_image/string-89.png"
# url = "test/845046c7-4f62-4f54-a4a9-c26d49c6969335b5b3a9-d335-4871-a46c-3cc3caf07da259629dfd1f1f555a2e2a9def7e719366.png"
    # url = 'aida-users/89/relight_image/123-89.png'
# url = "aida-users/89/sketchboard/female/Dress/e6724ab7-8d3f-4677-abe0-c3e42ab7af85.jpeg"
# url = "aida-users/87/print/956614a2-7e75-4fbe-9ed0-c1831e37a2c9-4-87.png"
# url = "aida-users/89/single_logo/123-89.png"
# url = "aida-users/89/product_image/string-89.png"
url = "aida-results/result_c6520ce7-33a1-11ef-a8d3-b0dcefbff887.png"
read_type = "PIL"
if read_type == "cv2":
img = oss_get_image(bucket=url.split('/')[0], object_name=url[url.find('/') + 1:], data_type=read_type)
cv2.imshow("", img)
cv2.waitKey(0)
else:
img = oss_get_image(bucket=url.split('/')[0], object_name=url[url.find('/') + 1:], data_type=read_type)
img.show()
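# Round-trip sketch for the helpers above; the bucket and object names are
# made up, and valid OSS credentials from app.core.config are assumed.
import cv2
import numpy as np

img = np.full((64, 64, 3), 255, dtype=np.uint8)                # tiny white image
image_bytes = cv2.imencode(".png", img)[1].tobytes()
oss_upload_image(bucket="aida-users", object_name="demo/roundtrip.png",
                 image_bytes=image_bytes)
back = oss_get_image(bucket="aida-users", object_name="demo/roundtrip.png",
                     data_type="cv2")
assert back.shape[:2] == (64, 64)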

View File

@@ -1,51 +1,51 @@
from app.core.config import LOGS_PATH
LOGGER_CONFIG_DICT = {
"version": 1,
"disable_existing_loggers": False,
"formatters": {
"simple": {"format": "%(asctime)s %(filename)s [line:%(lineno)d] %(levelname)s %(message)s"}
'version': 1,
'disable_existing_loggers': False,
'formatters': {
'simple': {'format': '%(asctime)s %(filename)s [line:%(lineno)d] %(levelname)s %(message)s'}
},
"handlers": {
"console": {
"class": "logging.StreamHandler",
"level": "DEBUG",
"formatter": "simple",
"stream": "ext://sys.stdout",
'handlers': {
'console': {
'class': 'logging.StreamHandler',
'level': 'INFO',
'formatter': 'simple',
'stream': 'ext://sys.stdout',
},
"info_file_handler": {
"class": "logging.handlers.RotatingFileHandler",
"level": "INFO",
"formatter": "simple",
"filename": f"{LOGS_PATH}info.log",
"maxBytes": 10485760,
"backupCount": 50,
"encoding": "utf8",
'info_file_handler': {
'class': 'logging.handlers.RotatingFileHandler',
'level': 'INFO',
'formatter': 'simple',
'filename': f'{LOGS_PATH}info.log',
'maxBytes': 10485760,
'backupCount': 50,
'encoding': 'utf8',
},
"error_file_handler": {
"class": "logging.handlers.RotatingFileHandler",
"level": "ERROR",
"formatter": "simple",
"filename": f"{LOGS_PATH}error.log",
"maxBytes": 10485760,
"backupCount": 20,
"encoding": "utf8",
'error_file_handler': {
'class': 'logging.handlers.RotatingFileHandler',
'level': 'ERROR',
'formatter': 'simple',
'filename': f'{LOGS_PATH}error.log',
'maxBytes': 10485760,
'backupCount': 20,
'encoding': 'utf8',
},
"debug_file_handler": {
"class": "logging.handlers.RotatingFileHandler",
"level": "DEBUG",
"formatter": "simple",
"filename": f"{LOGS_PATH}debug.log",
"maxBytes": 10485760,
"backupCount": 50,
"encoding": "utf8",
'debug_file_handler': {
'class': 'logging.handlers.RotatingFileHandler',
'level': 'DEBUG',
'formatter': 'simple',
'filename': f'{LOGS_PATH}debug.log',
'maxBytes': 10485760,
'backupCount': 50,
'encoding': 'utf8',
},
},
"loggers": {
"my_module": {"level": "INFO", "handlers": ["console"], "propagate": "no"}
'loggers': {
        'my_module': {'level': 'INFO', 'handlers': ['console'], 'propagate': False}
},
"root": {
"level": "INFO",
"handlers": ["error_file_handler", "info_file_handler", "debug_file_handler", "console"],
'root': {
'level': 'INFO',
'handlers': ['error_file_handler', 'info_file_handler', 'debug_file_handler', 'console'],
},
}
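# The dict above is consumed by logging.config.dictConfig; a minimal sketch of
# wiring it up at application startup (assumes LOGS_PATH exists and is writable):
import logging
import logging.config

logging.config.dictConfig(LOGGER_CONFIG_DICT)
logging.getLogger().info("logging configured")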

Binary file not shown.