Merge remote-tracking branch 'refs/remotes/origin/develop'
This commit is contained in:
@@ -1,9 +1,12 @@
|
||||
import json
|
||||
import logging
|
||||
from fastapi import APIRouter
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
|
||||
from app.core.config import DEBUG
|
||||
from app.schemas.attribute_retrieve import *
|
||||
from app.service.attribute.config import const
|
||||
from app.schemas.response_template import ResponseModel
|
||||
from app.service.attribute.config import const, local_debug_const
|
||||
from app.service.attribute.service_att_recognition import AttributeRecognition
|
||||
from app.service.attribute.service_category_recognition import CategoryRecognition
|
||||
|
||||
@@ -12,34 +15,63 @@ logger = logging.getLogger()
|
||||
|
||||
|
||||
# 属性识别
|
||||
@router.post("/attribute_recognition")
|
||||
@router.post("/attribute_recognition", response_model=ResponseModel)
|
||||
def attribute_recognition(request_item: list[AttributeRecognitionModel]):
|
||||
"""
|
||||
获取sketch的属性,collar sleeve_length 等等
|
||||
创建一个具有以下参数的请求体:
|
||||
- **category**: sketch的类别 ,Dress
|
||||
- **colony**: 服装类别,男装或女装
|
||||
- **sketch_img_url**: 被提取属性的S3或minio url地址
|
||||
|
||||
示例参数:
|
||||
[
|
||||
{
|
||||
"category": "Dress",
|
||||
"colony": "Female",
|
||||
"sketch_img_url": "aida-users/89/sketchboard/female/Dress/ae976103-d7ec-4eed-b5d1-3e5f04d8be26.jpg"
|
||||
}
|
||||
]
|
||||
"""
|
||||
try:
|
||||
service = AttributeRecognition(const=const, request_data=request_item)
|
||||
for item in request_item:
|
||||
logger.info(f"attribute_recognition request item is : @@@@@@:{json.dumps(item.dict())}")
|
||||
if DEBUG:
|
||||
service = AttributeRecognition(const=local_debug_const, request_data=request_item)
|
||||
else:
|
||||
service = AttributeRecognition(const=const, request_data=request_item)
|
||||
data = service.get_result()
|
||||
code = 200
|
||||
message = "access"
|
||||
logger.info(f"attribute_recognition response @@@@@@:{json.dumps(data, indent=4)}")
|
||||
logger.info(f"attribute_recognition response @@@@@@:{json.dumps(data)}")
|
||||
except Exception as e:
|
||||
code = 400
|
||||
message = e
|
||||
data = e
|
||||
logger.warning(f"attribute_recognition Run Exception @@@@@@:{e}")
|
||||
return {"code": code, "message": message, "data": data}
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
return ResponseModel(data={"list": data})
|
||||
|
||||
|
||||
# 类别识别
|
||||
@router.post("/category_recognition")
|
||||
def category_recognition(request_item: list[CategoryRecognitionModel]):
|
||||
"""
|
||||
获取sketch的类别,dress blouse 等等
|
||||
创建一个具有以下参数的请求体:
|
||||
- **colony**: 服装类别,male或Female
|
||||
- **sketch_img_url**: 被提取sketch类别的S3或minio url地址
|
||||
|
||||
示例参数:
|
||||
[
|
||||
{
|
||||
"colony": "Female",
|
||||
"sketch_img_url": "aida-users/89/sketchboard/female/Dress/ae976103-d7ec-4eed-b5d1-3e5f04d8be26.jpg"
|
||||
}
|
||||
]
|
||||
"""
|
||||
try:
|
||||
for item in request_item:
|
||||
logger.info(f"category_recognition request item is : @@@@@@:{json.dumps(item.dict())}")
|
||||
service = CategoryRecognition(request_data=request_item)
|
||||
data = service.get_result()
|
||||
code = 200
|
||||
message = "access"
|
||||
logger.info(f"category_recognition response @@@@@@:{json.dumps(data, indent=4)}")
|
||||
logger.info(f"category_recognition response @@@@@@:{json.dumps(data)}")
|
||||
except Exception as e:
|
||||
code = 400
|
||||
message = e
|
||||
data = e
|
||||
logger.warning(f"category_recognition Run Exception @@@@@@:{e}")
|
||||
return {"code": code, "message": message, "data": data}
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
return ResponseModel(data=data)
|
||||
|
||||
39
app/api/api_chat_robot.py
Normal file
39
app/api/api_chat_robot.py
Normal file
@@ -0,0 +1,39 @@
|
||||
import json
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
|
||||
from app.schemas.chat_robot import ChatRobotModel
|
||||
from app.schemas.response_template import ResponseModel
|
||||
from app.service.chat_robot.script.main import chat
|
||||
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
@router.post("/chat_robot")
|
||||
def chat_robot(request_data: ChatRobotModel):
|
||||
"""
|
||||
对话机器人
|
||||
创建一个具有以下参数的请求体:
|
||||
- **gender**: 服装类别
|
||||
- **message**: 消息
|
||||
- **session_id**: 会话id
|
||||
- **user_id**: 用户id
|
||||
|
||||
示例参数:
|
||||
{
|
||||
"gender": "male",
|
||||
"message": "你好",
|
||||
"session_id": "string-89",
|
||||
"user_id": 89
|
||||
}
|
||||
"""
|
||||
try:
|
||||
logger.info(f"chat_robot request item is : @@@@@@:{json.dumps(request_data.dict())}")
|
||||
data = chat(post_data=request_data)
|
||||
logger.info(f"chat_robot response @@@@@@:{json.dumps(data)}")
|
||||
except Exception as e:
|
||||
logger.warning(f"chat_robot Run Exception @@@@@@:{e}")
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
return ResponseModel(data=data)
|
||||
195
app/api/api_design.py
Normal file
195
app/api/api_design.py
Normal file
@@ -0,0 +1,195 @@
|
||||
import json
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
|
||||
from app.schemas.design import DesignModel, DesignProgressModel, ModelProgressModel
|
||||
from app.schemas.response_template import ResponseModel
|
||||
from app.service.design.model_process_service import model_transpose
|
||||
from app.service.design.service import generate
|
||||
from app.service.design.utils.redis_utils import Redis
|
||||
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
@router.post("/design")
|
||||
def design(request_data: DesignModel):
|
||||
"""
|
||||
创建一个具有以下参数的请求体:
|
||||
示例参数:
|
||||
{
|
||||
"objects": [
|
||||
{
|
||||
"basic": {
|
||||
"body_point_test": {
|
||||
"waistband_right": [
|
||||
203,
|
||||
249
|
||||
],
|
||||
"hand_point_right": [
|
||||
229,
|
||||
343
|
||||
],
|
||||
"waistband_left": [
|
||||
119,
|
||||
248
|
||||
],
|
||||
"hand_point_left": [
|
||||
97,
|
||||
343
|
||||
],
|
||||
"shoulder_left": [
|
||||
108,
|
||||
107
|
||||
],
|
||||
"shoulder_right": [
|
||||
212,
|
||||
107
|
||||
]
|
||||
},
|
||||
"layer_order": true,
|
||||
"scale_bag": 0.7,
|
||||
"scale_earrings": 0.16,
|
||||
"self_template": true,
|
||||
"single_overall": "overall",
|
||||
"switch_category": ""
|
||||
},
|
||||
"items": [
|
||||
{
|
||||
"businessId": 255303,
|
||||
"color": "139 148 156",
|
||||
"image_id": 95159,
|
||||
"offset": [
|
||||
0,
|
||||
0
|
||||
],
|
||||
"path": "aida-users/89/sketch/c89d75f3-581f-4edd-9f8e-b08e84a2cbe7-3-89.png",
|
||||
"print": {
|
||||
"single": {
|
||||
"location": [
|
||||
[
|
||||
200.0,
|
||||
200.0
|
||||
]
|
||||
],
|
||||
"print_angle_list": [
|
||||
0.0
|
||||
],
|
||||
"print_path_list": [
|
||||
"aida-users/89/slogan_image/ce0b2423-9e5a-466f-9611-c254940a7819-1-89.png"
|
||||
],
|
||||
"print_scale_list": [
|
||||
1.0
|
||||
]
|
||||
},
|
||||
"overall": {
|
||||
"location": [
|
||||
[
|
||||
512.0,
|
||||
512.0
|
||||
]
|
||||
],
|
||||
"print_angle_list": [
|
||||
0.0
|
||||
],
|
||||
"print_path_list": [
|
||||
"aida-users/89/print/468643b4-bc2d-41b2-9a16-79766606a2db-3-89.png"
|
||||
],
|
||||
"print_scale_list": [
|
||||
1.0
|
||||
]
|
||||
},
|
||||
"element": {
|
||||
"element_angle_list": [
|
||||
0.0
|
||||
],
|
||||
"element_path_list": [
|
||||
"aida-users/88/designelements/Embroidery/a4d9605a-675e-4606-93e0-77ca6baaf55f.png"
|
||||
],
|
||||
"element_scale_list": [
|
||||
0.2731036750637755
|
||||
],
|
||||
"location": [
|
||||
[
|
||||
228.63694825464364,
|
||||
406.4843844199667
|
||||
]
|
||||
]
|
||||
}
|
||||
},
|
||||
"priority": 10,
|
||||
"resize_scale": [
|
||||
1.0,
|
||||
1.0
|
||||
],
|
||||
"type": "Dress"
|
||||
},
|
||||
{
|
||||
"body_path": "aida-sys-image/models/female/2e4815b9-1191-419d-94ed-5771239ca4a5.png",
|
||||
"image_id": 67277,
|
||||
"type": "Body"
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
"process_id": "89"
|
||||
}
|
||||
"""
|
||||
try:
|
||||
logger.info(f"design request item is : @@@@@@:{json.dumps(request_data.dict())}")
|
||||
data = generate(request_data=request_data)
|
||||
logger.info(f"design response @@@@@@:{json.dumps(data)}")
|
||||
except Exception as e:
|
||||
logger.warning(f"design Run Exception @@@@@@:{e}")
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
return ResponseModel(data=data)
|
||||
|
||||
|
||||
@router.post('/get_progress')
|
||||
def get_progress(request_data: DesignProgressModel):
|
||||
"""
|
||||
获取design 进度
|
||||
创建一个具有以下参数的请求体:
|
||||
- **process_id**: 进度id
|
||||
|
||||
示例参数:
|
||||
{
|
||||
"process_id": "6878547032381675"
|
||||
}
|
||||
"""
|
||||
try:
|
||||
logger.info(f"get_progress request item is : @@@@@@:{json.dumps(request_data.dict())}")
|
||||
process_id = request_data.process_id
|
||||
r = Redis()
|
||||
data = r.read(key=process_id)
|
||||
if data is None:
|
||||
raise ValueError(f"No progress ID: {process_id}")
|
||||
logging.info(f"get_progress process_id @@@@@@ : {process_id} , progress : {json.dumps(data)}")
|
||||
except Exception as e:
|
||||
logger.warning(f"get_progress Run Exception @@@@@@:{e}")
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
return ResponseModel(data=data)
|
||||
|
||||
|
||||
@router.post('/model_process')
|
||||
def model_process(request_data: ModelProgressModel):
|
||||
"""
|
||||
获取模特图片预处理
|
||||
创建一个具有以下参数的请求体:
|
||||
- **model_path**: 模特图片的minio或s3 url地址
|
||||
|
||||
示例参数:
|
||||
{
|
||||
"model_path": "aida-users/10/models/female/9c788f5b-b8c7-424c-b149-025aeb0bda51model.jpg"
|
||||
}
|
||||
"""
|
||||
try:
|
||||
logger.info(f"model_process request item is : @@@@@@:{json.dumps(request_data.dict())}")
|
||||
|
||||
data = model_transpose(image_path=request_data.model_path)
|
||||
logger.info(f"model_process response @@@@@@:{json.dumps(data)}")
|
||||
except Exception as e:
|
||||
logger.warning(f"model_process Run Exception @@@@@@:{e}")
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
return ResponseModel(data=data)
|
||||
40
app/api/api_design_pre_processing.py
Normal file
40
app/api/api_design_pre_processing.py
Normal file
@@ -0,0 +1,40 @@
|
||||
import json
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
|
||||
from app.schemas.pre_processing import DesignPreProcessingModel
|
||||
from app.schemas.response_template import ResponseModel
|
||||
from app.service.design_pre_processing.service import DesignPreprocessing
|
||||
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
@router.post("/design_pre_processing")
|
||||
def design_pre_processing(request_data: DesignPreProcessingModel):
|
||||
"""
|
||||
design 预处理 获取sketch的基本信息
|
||||
创建一个具有以下参数的请求体:
|
||||
- **sketches**: sketch url等信息
|
||||
|
||||
示例参数:
|
||||
{
|
||||
"sketches": [
|
||||
{
|
||||
"image_category": "dress",
|
||||
"image_id": "107903",
|
||||
"image_url": "aida-sys-image/images/female/dress/0628000000.jpg"
|
||||
}
|
||||
]
|
||||
}
|
||||
"""
|
||||
try:
|
||||
logger.info(f"design_pre_processing request item is : @@@@@@:{json.dumps(request_data.dict())}")
|
||||
server = DesignPreprocessing()
|
||||
data = server.pipeline(image_list=request_data.sketches)
|
||||
logger.info(f"design response @@@@@@:{json.dumps(data)}")
|
||||
except Exception as e:
|
||||
logger.warning(f"design Run Exception @@@@@@:{e}")
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
return ResponseModel(data=data)
|
||||
@@ -1,28 +1,189 @@
|
||||
import json
|
||||
import logging
|
||||
from fastapi import APIRouter, BackgroundTasks
|
||||
from app.schemas.generate_image import GenerateImageModel
|
||||
from app.service.generate_image.service import GenerateImage, infer_cancel
|
||||
|
||||
from fastapi import APIRouter, BackgroundTasks, HTTPException
|
||||
|
||||
from app.schemas.generate_image import GenerateImageModel, GenerateProductImageModel, GenerateSingleLogoImageModel, GenerateRelightImageModel
|
||||
from app.schemas.response_template import ResponseModel
|
||||
from app.service.generate_image.service_generate_image import GenerateImage, infer_cancel as generate_image_infer_cancel
|
||||
from app.service.generate_image.service_generate_product_image import GenerateProductImage, infer_cancel as generate_product_image_cancel
|
||||
from app.service.generate_image.service_generate_relight_image import GenerateRelightImage, infer_cancel as generate_relight_image_cancel
|
||||
from app.service.generate_image.service_generate_single_logo import GenerateSingleLogoImage, infer_cancel as generate_single_logo_cancel
|
||||
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger()
|
||||
|
||||
'''generate image'''
|
||||
|
||||
|
||||
@router.post("/generate_image")
|
||||
def generate_image(request_item: GenerateImageModel, background_tasks: BackgroundTasks):
|
||||
"""
|
||||
创建一个具有以下参数的请求体:
|
||||
- **tasks_id**: 任务id 用于取消生成任务和获取生成结果
|
||||
- **prompt**: 想要生成图片的描述词
|
||||
- **image_url**: 图生图的输入,minio或S3 url 地址
|
||||
- **mode**: 生成模式,img2img或者txt2img
|
||||
- **category**: 生成图片的类别,sketch print 等等
|
||||
- **gender**: 生成sketch专用,服装类别
|
||||
|
||||
示例参数:
|
||||
{
|
||||
"tasks_id": "123-89",
|
||||
"prompt": "skeleton sitting by the side of a river looking soulful, concert poster, 4k, artistic",
|
||||
"image_url": "aida-collection-element/87/Printboard/842c09cf-7297-42d9-9e6e-9c17d4a13cb5.jpg",
|
||||
"mode": "img2img",
|
||||
"category": "sketch",
|
||||
"gender": "male"
|
||||
}
|
||||
"""
|
||||
try:
|
||||
logger.info(f"request data ### : {request_item}")
|
||||
logger.info(f"generate_image request item is : @@@@@@:{json.dumps(request_item.dict())}")
|
||||
service = GenerateImage(request_item)
|
||||
background_tasks.add_task(service.get_result)
|
||||
code = 200
|
||||
message = "access"
|
||||
except Exception as e:
|
||||
code = 400
|
||||
message = e
|
||||
logger.warning(e)
|
||||
return {"code": code, "message": message}
|
||||
logger.warning(f"generate_image Run Exception @@@@@@:{e}")
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
return ResponseModel()
|
||||
|
||||
|
||||
@router.get("/generate_cancel/{tasks_id}>")
|
||||
def generate_image(tasks_id):
|
||||
result = infer_cancel(tasks_id)
|
||||
return {"code": 200, "message": result['message'], "data": result['data']}
|
||||
@router.get("/generate_cancel/{tasks_id}")
|
||||
def generate_image(tasks_id: str):
|
||||
try:
|
||||
logger.info(f"generate_cancel request item is : @@@@@@:{tasks_id}")
|
||||
data = generate_image_infer_cancel(tasks_id)
|
||||
logger.info(f"generate_cancel response @@@@@@:{data}")
|
||||
except Exception as e:
|
||||
logger.warning(f"generate_cancel Run Exception @@@@@@:{e}")
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
return ResponseModel(data=data['data'])
|
||||
|
||||
|
||||
'''single logo'''
|
||||
|
||||
|
||||
@router.post("/generate_single_logo")
|
||||
def generate_single_logo(request_item: GenerateSingleLogoImageModel, background_tasks: BackgroundTasks):
|
||||
"""
|
||||
创建一个具有以下参数的请求体:
|
||||
- **tasks_id**: 任务id 用于取消生成任务和获取生成结果
|
||||
- **prompt**: 想要生成图片的描述词
|
||||
- **seed**: 固定的prompt和固定的seed 每次的生成结果都是一样的
|
||||
|
||||
示例参数:
|
||||
{
|
||||
"tasks_id": "123-89",
|
||||
"prompt": "an apple",
|
||||
"seed": "2"
|
||||
}
|
||||
"""
|
||||
try:
|
||||
logger.info(f"generate_single_logo request item is : @@@@@@:{json.dumps(request_item.dict())}")
|
||||
service = GenerateSingleLogoImage(request_item)
|
||||
background_tasks.add_task(service.get_result)
|
||||
except Exception as e:
|
||||
logger.warning(f"generate_single_logo Run Exception @@@@@@:{e}")
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
return ResponseModel()
|
||||
|
||||
|
||||
@router.get("/generate_single_logo_cancel/{tasks_id}")
|
||||
def generate_single_logo_image(tasks_id: str):
|
||||
try:
|
||||
logger.info(f"generate_single_logo_cancel request item is : @@@@@@:{tasks_id}")
|
||||
data = generate_single_logo_cancel(tasks_id)
|
||||
logger.info(f"generate_single_logo_cancel response @@@@@@:{data}")
|
||||
except Exception as e:
|
||||
logger.warning(f"generate_single_logo_cancel Run Exception @@@@@@:{e}")
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
return ResponseModel(data=data['data'])
|
||||
|
||||
|
||||
'''product image'''
|
||||
|
||||
|
||||
@router.post("/generate_product_image")
|
||||
def generate_product_image(request_item: GenerateProductImageModel, background_tasks: BackgroundTasks):
|
||||
"""
|
||||
创建一个具有以下参数的请求体:
|
||||
- **tasks_id**: 任务id 用于取消生成任务和获取生成结果
|
||||
- **prompt**: 想要生成图片的描述词
|
||||
- **image_url**: 被生成图片的S3或minio url地址
|
||||
- **image_strength**: 生成强度,越低越接近原图
|
||||
- **product_type**: 输入single item 还是 overall item
|
||||
|
||||
|
||||
示例参数:
|
||||
{
|
||||
"tasks_id": "123-89",
|
||||
"prompt": "the best quality, masterpiece. detailed, high-res, simple background, studio photography, extremely detailed, updo, detailed face, face, close-up, HDR, UHD, 8K realistic, Highly detailed, simple background, Studio lighting",
|
||||
"image_url": "aida-results/result_00097282-ebb2-11ee-a822-b48351119060.png",
|
||||
"image_strength": 0.8,
|
||||
"product_type": "overall"
|
||||
}
|
||||
"""
|
||||
try:
|
||||
logger.info(f"generate_product_image request item is : @@@@@@:{json.dumps(request_item.dict())}")
|
||||
service = GenerateProductImage(request_item)
|
||||
background_tasks.add_task(service.get_result)
|
||||
except Exception as e:
|
||||
logger.warning(f"generate_product_image Run Exception @@@@@@:{e}")
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
return ResponseModel()
|
||||
|
||||
|
||||
@router.get("/generate_product_image_cancel_cancel/{tasks_id}")
|
||||
def generate_product_image(tasks_id: str):
|
||||
try:
|
||||
logger.info(f"generate_product_image_cancel_cancel request item is : @@@@@@:{tasks_id}")
|
||||
data = generate_product_image_cancel(tasks_id)
|
||||
logger.info(f"generate_product_image_cancel_cancel response @@@@@@:{data}")
|
||||
except Exception as e:
|
||||
logger.warning(f"generate_product_image_cancel_cancel Run Exception @@@@@@:{e}")
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
return ResponseModel(data=data['data'])
|
||||
|
||||
|
||||
'''relight image'''
|
||||
|
||||
|
||||
@router.post("/generate_relight_image")
|
||||
def generate_relight_image(request_item: GenerateRelightImageModel, background_tasks: BackgroundTasks):
|
||||
"""
|
||||
创建一个具有以下参数的请求体:
|
||||
- **tasks_id**: 任务id 用于取消生成任务和获取生成结果
|
||||
- **prompt**: 想要生成图片的描述词
|
||||
- **image_url**: 被生成图片的S3或minio url地址
|
||||
- **direction**: 光源方向 Right Light Left Light Top Light Bottom Light
|
||||
- **product_type**: 输入single item 还是 overall item
|
||||
|
||||
|
||||
示例参数:
|
||||
{
|
||||
"tasks_id": "123-89",
|
||||
"prompt": "beautiful woman, detailed face, sunshine, outdoor, warm atmosphere",
|
||||
"image_url": "aida-results/result_0000b606-1902-11ef-9424-0242ac180002.png",
|
||||
"direction": "Right Light",
|
||||
"product_type": "overall"
|
||||
}
|
||||
"""
|
||||
try:
|
||||
logger.info(f"generate_relight_image request item is : @@@@@@:{json.dumps(request_item.dict())}")
|
||||
service = GenerateRelightImage(request_item)
|
||||
background_tasks.add_task(service.get_result)
|
||||
except Exception as e:
|
||||
logger.warning(f"generate_relight_image Run Exception @@@@@@:{e}")
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
return ResponseModel()
|
||||
|
||||
|
||||
@router.get("/generate_relight_image_cancel_cancel/{tasks_id}")
|
||||
def generate_relight_image(tasks_id: str):
|
||||
try:
|
||||
logger.info(f"generate_relight_image_cancel_cancel request item is : @@@@@@:{tasks_id}")
|
||||
data = generate_relight_image_cancel(tasks_id)
|
||||
logger.info(f"generate_relight_image_cancel_cancel response @@@@@@:{data}")
|
||||
except Exception as e:
|
||||
logger.warning(f"generate_relight_image_cancel_cancel Run Exception @@@@@@:{e}")
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
return ResponseModel(data=data['data'])
|
||||
|
||||
34
app/api/api_prompt_generation.py
Normal file
34
app/api/api_prompt_generation.py
Normal file
@@ -0,0 +1,34 @@
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
|
||||
from app.schemas.prompt_generation import PromptGenerationImageModel
|
||||
from app.schemas.response_template import ResponseModel
|
||||
from app.service.prompt_generation.chatgpt_for_translation import translate_to_en
|
||||
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
@router.post("/translateToEN")
|
||||
def prompt_generation(request_data: PromptGenerationImageModel):
|
||||
"""
|
||||
翻译prompt接口
|
||||
创建一个具有以下参数的请求体:
|
||||
- **text**: 待翻译语句
|
||||
|
||||
示例参数:
|
||||
{
|
||||
"text": "你好"
|
||||
}
|
||||
"""
|
||||
try:
|
||||
logger.info(f"prompt_generation request item is : @@@@@@:{request_data}")
|
||||
data = translate_to_en(request_data.text)
|
||||
logger.info(f"prompt_generation response @@@@@@:{data}")
|
||||
except Exception as e:
|
||||
logger.warning(f"prompt_generation Run Exception @@@@@@:{e}")
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
return ResponseModel(data=data)
|
||||
@@ -4,6 +4,11 @@ from app.api import api_test
|
||||
from app.api import api_super_resolution
|
||||
from app.api import api_generate_image
|
||||
from app.api import api_attribute_retrieve
|
||||
from app.api import api_design
|
||||
from app.api import api_chat_robot
|
||||
from app.api import api_prompt_generation
|
||||
from app.api import api_design_pre_processing
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -11,3 +16,7 @@ router.include_router(api_test.router, tags=["test"], prefix="/test")
|
||||
router.include_router(api_super_resolution.router, tags=["super_resolution"], prefix="/api")
|
||||
router.include_router(api_generate_image.router, tags=["generate_image"], prefix="/api")
|
||||
router.include_router(api_attribute_retrieve.router, tags=["attribute_retrieve"], prefix="/api")
|
||||
router.include_router(api_design.router, tags=['design'], prefix="/api")
|
||||
router.include_router(api_chat_robot.router, tags=['chat_robot'], prefix="/api")
|
||||
router.include_router(api_prompt_generation.router, tags=['prompt_generation'], prefix="/api")
|
||||
router.include_router(api_design_pre_processing.router, tags=['design_pre_processing'], prefix="/api")
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
import json
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, BackgroundTasks
|
||||
from fastapi import APIRouter, BackgroundTasks, HTTPException
|
||||
|
||||
from app.schemas.response_template import ResponseModel
|
||||
from app.schemas.super_resolution import SuperResolutionModel
|
||||
from app.service.super_resolution.service import SuperResolution, infer_cancel
|
||||
|
||||
@@ -12,19 +13,36 @@ logger = logging.getLogger()
|
||||
|
||||
@router.post("/super_resolution")
|
||||
def super_resolution(request_item: SuperResolutionModel, background_tasks: BackgroundTasks):
|
||||
"""
|
||||
创建一个具有以下参数的请求体:
|
||||
- **sr_image_url**: 超分图片的minio或s3 url地址
|
||||
- **sr_xn**: 超分的倍数,只接受2或4
|
||||
- **sr_tasks_id**: 任务id 用于取消超分任务和获取超分结果
|
||||
|
||||
示例参数:
|
||||
{
|
||||
"sr_image_url": "aida-sys-image/images/female/blouse/0628000098.jpg",
|
||||
"sr_xn": 2,
|
||||
"sr_tasks_id": "12341556-89"
|
||||
}
|
||||
"""
|
||||
try:
|
||||
logger.info(f"super_resolution request item is : @@@@@@:{json.dumps(request_item.dict())}")
|
||||
service = SuperResolution(request_item)
|
||||
background_tasks.add_task(service.sr_result)
|
||||
code = 200
|
||||
message = "access"
|
||||
except Exception as e:
|
||||
code = 400
|
||||
message = e
|
||||
logger.warning(e)
|
||||
return {"code": code, "message": message}
|
||||
logger.warning(f"super_resolution Run Exception @@@@@@:{e}")
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
return ResponseModel()
|
||||
|
||||
|
||||
@router.get("/sr_cancel/{tasks_id}>")
|
||||
def super_resolution(tasks_id):
|
||||
result = infer_cancel(tasks_id)
|
||||
return {"code": 200, "message": result['message'], "data": result['data']}
|
||||
def super_resolution(tasks_id: str):
|
||||
try:
|
||||
logger.info(f"sr_cancel request item is : @@@@@@:{tasks_id}")
|
||||
data = infer_cancel(tasks_id)
|
||||
logger.info(f"sr_cancel response @@@@@@:{data}")
|
||||
except Exception as e:
|
||||
logger.warning(f"sr_cancel Run Exception @@@@@@:{e}")
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
return ResponseModel(data=data['data'])
|
||||
|
||||
@@ -1,13 +1,27 @@
|
||||
import json
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter
|
||||
from app.core.config import SR_RABBITMQ_QUEUES, GI_RABBITMQ_QUEUES
|
||||
from fastapi import HTTPException
|
||||
|
||||
from app.core.config import SR_RABBITMQ_QUEUES, GI_RABBITMQ_QUEUES, GPI_RABBITMQ_QUEUES, GRI_RABBITMQ_QUEUES, OSS
|
||||
from app.schemas.response_template import ResponseModel
|
||||
|
||||
logger = logging.getLogger()
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("")
|
||||
def test():
|
||||
logger.info(SR_RABBITMQ_QUEUES)
|
||||
logger.info("test")
|
||||
return {"SR_RABBITMQ_QUEUES message": SR_RABBITMQ_QUEUES, "GI_RABBITMQ_QUEUES": GI_RABBITMQ_QUEUES}
|
||||
@router.get("{id}")
|
||||
def test(id: int):
|
||||
data = {
|
||||
"SR_RABBITMQ_QUEUES message": SR_RABBITMQ_QUEUES,
|
||||
"GI_RABBITMQ_QUEUES": GI_RABBITMQ_QUEUES,
|
||||
"GPI_RABBITMQ_QUEUES": GPI_RABBITMQ_QUEUES,
|
||||
"GRI_RABBITMQ_QUEUES": GRI_RABBITMQ_QUEUES,
|
||||
"local_oss_server": OSS
|
||||
}
|
||||
logger.info(json.dumps(data))
|
||||
if id == 1:
|
||||
raise HTTPException(status_code=404, detail="Item not found")
|
||||
|
||||
return ResponseModel(data=data)
|
||||
|
||||
@@ -19,15 +19,16 @@ class Settings(BaseSettings):
|
||||
LOGGING_CONFIG_FILE = os.path.join(BASE_DIR, 'logging_env.py')
|
||||
|
||||
|
||||
OSS = "minio"
|
||||
DEBUG = False
|
||||
if DEBUG:
|
||||
LOGS_PATH = "logs/"
|
||||
CATEGORY_PATH = "service/attribute/config/descriptor/category/category_dis.csv"
|
||||
FACE_CLASSIFIER = "service/generate_image/utils/haarcascade_frontalface_alt.xml"
|
||||
# FACE_CLASSIFIER = "service/generate_image/utils/haarcascade_frontalface_alt.xml"
|
||||
else:
|
||||
LOGS_PATH = "app/logs/"
|
||||
CATEGORY_PATH = "app/service/attribute/config/descriptor/category/category_dis.csv"
|
||||
FACE_CLASSIFIER = 'app/service/generate_image/utils/haarcascade_frontalface_alt.xml'
|
||||
# FACE_CLASSIFIER = 'app/service/generate_image/utils/haarcascade_frontalface_alt.xml'
|
||||
|
||||
RABBITMQ_ENV = "" # 生产环境
|
||||
# RABBITMQ_ENV = "-dev" # 开发环境
|
||||
@@ -41,6 +42,11 @@ MINIO_ACCESS = 'vXKFLSJkYeEq2DrSZvkB'
|
||||
MINIO_SECRET = 'uKTZT3x7C43WvPN9QTc99DiRkwddWZrG9Uh3JVlR'
|
||||
MINIO_SECURE = True
|
||||
|
||||
# S3 配置
|
||||
S3_ACCESS_KEY = "AKIAVD3OJIMF6UJFLSHZ"
|
||||
S3_AWS_SECRET_ACCESS_KEY = "LNIwFFB27/QedtZ+Q/viVUoX9F5x1DbuM8N0DkD8"
|
||||
S3_REGION_NAME = "ap-east-1"
|
||||
|
||||
# redis 配置
|
||||
REDIS_HOST = "10.1.1.240"
|
||||
REDIS_PORT = "6379"
|
||||
@@ -55,12 +61,38 @@ RABBITMQ_PARAMS = {
|
||||
}
|
||||
|
||||
# milvus 配置
|
||||
MILVUS_DB_HOST = "10.1.1.240"
|
||||
MILVUS_URL = "http://10.1.1.240:19530"
|
||||
MILVUS_TOKEN = "root:Milvus"
|
||||
MILVUS_ALIAS = "default"
|
||||
MILVUS_PORT = "19530"
|
||||
MILVUS_TABLE_KEYPOINT = "keypoint_cache"
|
||||
MILVUS_TABLE_SEG = "seg_cache"
|
||||
|
||||
# Mysql 配置
|
||||
DB_HOST = '18.167.251.121' # 数据库主机地址
|
||||
# DB_PORT = int( 33006)
|
||||
DB_PORT = 33008 # 数据库端口
|
||||
DB_USERNAME = 'aida_con_python' # 数据库用户名
|
||||
DB_PASSWORD = '123456' # 数据库密码
|
||||
DB_NAME = 'aida' # 数据库库名
|
||||
|
||||
# openai
|
||||
os.environ['SERPAPI_API_KEY'] = "a793513017b0718db7966207c31703d280d12435c982f1e67bbcbffa52e7632c"
|
||||
OPENAI_STREAM = True
|
||||
BUFFER_THRESHOLD = 6 # must be even number
|
||||
SINGLE_TOKEN_THRESHOLD = 200
|
||||
TOKEN_THRESHOLD = 600
|
||||
OPENAI_TEMPERATURE = 0
|
||||
|
||||
# OPENAI_API_KEY = "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"
|
||||
OPENAI_API_KEY = "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"
|
||||
OPENAI_MODEL = "gpt-3.5-turbo-0613"
|
||||
OPENAI_MODEL_LIST = {"gpt-3.5-turbo-0613",
|
||||
"gpt-3.5-turbo-16k-0613",
|
||||
"gpt-4-0314",
|
||||
"gpt-4-32k-0314",
|
||||
"gpt-4-0613",
|
||||
"gpt-4-32k-0613", }
|
||||
|
||||
# attribute service config
|
||||
ATT_TRITON_URL = "10.1.1.240:10000"
|
||||
|
||||
@@ -77,6 +109,28 @@ GI_MINIO_BUCKET = "aida-users"
|
||||
GI_RABBITMQ_QUEUES = os.getenv("GI_RABBITMQ_QUEUES", f"GenerateImage{RABBITMQ_ENV}")
|
||||
GI_SYS_IMAGE_URL = "aida-sys-image/generate_image/white_image.jpg"
|
||||
|
||||
# SLOGAN service config
|
||||
SLOGAN_RABBITMQ_QUEUES = os.getenv("SLOGAN_RABBITMQ_QUEUES", f"Slogan{RABBITMQ_ENV}")
|
||||
|
||||
# Generate Single Logo service config
|
||||
GSL_MODEL_URL = '10.1.1.240:10041'
|
||||
GSL_MINIO_BUCKET = "aida-users"
|
||||
GSL_MODEL_NAME = 'stable_diffusion_xl_transparent'
|
||||
GEN_SINGLE_LOGO_RABBITMQ_QUEUES = os.getenv("GEN_SINGLE_LOGO_RABBITMQ_QUEUES", f"GenSingleLogo{RABBITMQ_ENV}")
|
||||
|
||||
# Generate Product service config
|
||||
GPI_RABBITMQ_QUEUES = os.getenv("GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES", f"ToProductImage{RABBITMQ_ENV}")
|
||||
GPI_MODEL_NAME_OVERALL = 'diffusion_ensemble_all'
|
||||
GPI_MODEL_NAME_SINGLE = 'stable_diffusion_1_5_cnet'
|
||||
|
||||
GPI_MODEL_URL = '10.1.1.240:10041'
|
||||
|
||||
# Generate Single Logo service config
|
||||
GRI_RABBITMQ_QUEUES = os.getenv("GEN_RELIGHT_IMAGE_RABBITMQ_QUEUES", f"Relight{RABBITMQ_ENV}")
|
||||
GRI_MODEL_NAME_OVERALL = 'diffusion_relight_ensemble'
|
||||
GRI_MODEL_NAME_SINGLE = 'stable_diffusion_1_5_relight'
|
||||
GRI_MODEL_URL = '10.1.1.240:10051'
|
||||
|
||||
# SEG service config
|
||||
SEG_MODEL_URL = '10.1.1.240:10000'
|
||||
SEGMENTATION = {
|
||||
@@ -87,9 +141,13 @@ SEGMENTATION = {
|
||||
}
|
||||
|
||||
# DESIGN config
|
||||
DESIGN_MODEL_URL = '10.1.1.240:9000'
|
||||
|
||||
DESIGN_MODEL_URL = '10.1.1.240:10000'
|
||||
AIDA_CLOTHING = "aida-clothing"
|
||||
KEYPOINT_RESULT_TABLE_FIELD_SET = ('neckline_left', 'neckline_right', 'shoulder_left', 'shoulder_right', 'armpit_left', 'armpit_right',
|
||||
'cuff_left_in', 'cuff_left_out', 'cuff_right_in', 'cuff_right_out', 'waistband_left', 'waistband_right')
|
||||
|
||||
# DESIGN 预处理
|
||||
IF_DEBUG_SHOW = False
|
||||
|
||||
# 优先级
|
||||
PRIORITY_DICT = {
|
||||
@@ -117,4 +175,3 @@ PRIORITY_DICT = {
|
||||
'bag_back': -98,
|
||||
'earring_back': -99,
|
||||
}
|
||||
|
||||
|
||||
16
app/main.py
16
app/main.py
@@ -1,14 +1,18 @@
|
||||
import logging.config
|
||||
from http.client import HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
|
||||
import uvicorn
|
||||
from fastapi import FastAPI
|
||||
|
||||
from app.api.api_route import router
|
||||
from app.core.config import settings
|
||||
from app.schemas.response_template import ResponseModel
|
||||
from logging_env import LOGGER_CONFIG_DICT
|
||||
|
||||
|
||||
logging.config.dictConfig(LOGGER_CONFIG_DICT)
|
||||
logging.getLogger("pika").setLevel(logging.WARNING)
|
||||
|
||||
from starlette.middleware.cors import CORSMiddleware
|
||||
|
||||
@@ -35,5 +39,15 @@ def get_application() -> FastAPI:
|
||||
|
||||
|
||||
app = get_application()
|
||||
|
||||
|
||||
@app.exception_handler(HTTPException)
|
||||
async def http_exception_handler(request: Request, exc: HTTPException):
|
||||
return JSONResponse(
|
||||
status_code=exc.status_code,
|
||||
content=ResponseModel(code=exc.status_code, msg=exc.detail, data=exc.detail).dict()
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
uvicorn.run(app, host="0.0.0.0", port=8000)
|
||||
|
||||
8
app/schemas/chat_robot.py
Normal file
8
app/schemas/chat_robot.py
Normal file
@@ -0,0 +1,8 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class ChatRobotModel(BaseModel):
|
||||
gender: str
|
||||
message: str
|
||||
session_id: str
|
||||
user_id: int
|
||||
58
app/schemas/design.py
Normal file
58
app/schemas/design.py
Normal file
@@ -0,0 +1,58 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
# class BodyPointModel(BaseModel):
|
||||
# waistband_right: list[int]
|
||||
# hand_point_right: list[int]
|
||||
# waistband_left: list[int]
|
||||
# hand_point_left: list[int]
|
||||
# shoulder_left: list[int]
|
||||
# shoulder_right: list[int]
|
||||
#
|
||||
#
|
||||
# class BasicModel(BaseModel):
|
||||
# body_point: BodyPointModel
|
||||
# layer_order: bool
|
||||
# scale_bag: float
|
||||
# scale_earrings: float
|
||||
# self_template: bool
|
||||
# single_overall: str
|
||||
# switch_category: str
|
||||
# body_path: str
|
||||
#
|
||||
#
|
||||
# class PrintModel(BaseModel):
|
||||
# if_single: bool
|
||||
# print_path_list: list[str]
|
||||
#
|
||||
#
|
||||
# class ItemModel(BaseModel):
|
||||
# color: str
|
||||
# image_id: str
|
||||
# offset: list[int]
|
||||
# path: str
|
||||
# print: PrintModel
|
||||
# resize_scale: float
|
||||
# type: str
|
||||
#
|
||||
#
|
||||
# class CollocationModel(BaseModel):
|
||||
# basic: BasicModel
|
||||
# item: list[ItemModel]
|
||||
#
|
||||
#
|
||||
# class DesignModel(BaseModel):
|
||||
# object: list[CollocationModel]
|
||||
# process_id: str
|
||||
|
||||
class DesignModel(BaseModel):
|
||||
objects: list[dict]
|
||||
process_id: str
|
||||
|
||||
|
||||
class DesignProgressModel(BaseModel):
|
||||
process_id: str
|
||||
|
||||
|
||||
class ModelProgressModel(BaseModel):
|
||||
model_path: str
|
||||
@@ -8,3 +8,25 @@ class GenerateImageModel(BaseModel):
|
||||
mode: str
|
||||
category: str
|
||||
gender: str
|
||||
|
||||
|
||||
class GenerateSingleLogoImageModel(BaseModel):
|
||||
tasks_id: str
|
||||
prompt: str
|
||||
seed: str
|
||||
|
||||
|
||||
class GenerateProductImageModel(BaseModel):
|
||||
tasks_id: str
|
||||
prompt: str
|
||||
image_url: str
|
||||
image_strength: float
|
||||
product_type: str
|
||||
|
||||
|
||||
class GenerateRelightImageModel(BaseModel):
|
||||
tasks_id: str
|
||||
prompt: str
|
||||
image_url: str
|
||||
direction: str
|
||||
product_type: str
|
||||
|
||||
5
app/schemas/pre_processing.py
Normal file
5
app/schemas/pre_processing.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class DesignPreProcessingModel(BaseModel):
|
||||
sketches: list[dict]
|
||||
5
app/schemas/prompt_generation.py
Normal file
5
app/schemas/prompt_generation.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class PromptGenerationImageModel(BaseModel):
|
||||
text: str
|
||||
8
app/schemas/response_template.py
Normal file
8
app/schemas/response_template.py
Normal file
@@ -0,0 +1,8 @@
|
||||
from pydantic import BaseModel
|
||||
from typing import Any, Optional
|
||||
|
||||
|
||||
class ResponseModel(BaseModel):
|
||||
code: int = 200
|
||||
msg: str = "OK!"
|
||||
data: Optional[Any] = None
|
||||
@@ -1,13 +1,13 @@
|
||||
top_description_list = ['service/attribute/config/descriptor/top/length.csv',
|
||||
'service/attribute/config/descriptor/top/type.csv',
|
||||
'service/attribute/config/descriptor/top/sleeve_length.csv',
|
||||
'service/attribute/config/descriptor/top/sleeve_shape.csv',
|
||||
'service/attribute/config/descriptor/top/sleeve_shoulder.csv',
|
||||
'service/attribute/config/descriptor/top/neckline.csv',
|
||||
'service/attribute/config/descriptor/top/design.csv',
|
||||
'service/attribute/config/descriptor/top/opening_type.csv',
|
||||
'service/attribute/config/descriptor/top/silhouette.csv',
|
||||
'service/attribute/config/descriptor/top/collar.csv']
|
||||
top_description_list = ['app/service/attribute/config/descriptor/top/length.csv',
|
||||
'app/service/attribute/config/descriptor/top/type.csv',
|
||||
'app/service/attribute/config/descriptor/top/sleeve_length.csv',
|
||||
'app/service/attribute/config/descriptor/top/sleeve_shape.csv',
|
||||
'app/service/attribute/config/descriptor/top/sleeve_shoulder.csv',
|
||||
'app/service/attribute/config/descriptor/top/neckline.csv',
|
||||
'app/service/attribute/config/descriptor/top/design.csv',
|
||||
'app/service/attribute/config/descriptor/top/opening_type.csv',
|
||||
'app/service/attribute/config/descriptor/top/silhouette.csv',
|
||||
'app/service/attribute/config/descriptor/top/collar.csv']
|
||||
|
||||
top_model_list = ['attr_retrieve_T_length',
|
||||
'attr_retrieve_T_type',
|
||||
@@ -22,11 +22,11 @@ top_model_list = ['attr_retrieve_T_length',
|
||||
]
|
||||
|
||||
bottom_description_list = [
|
||||
'service/attribute/config/descriptor/bottom/subtype.csv',
|
||||
'service/attribute/config/descriptor/bottom/length.csv',
|
||||
'service/attribute/config/descriptor/bottom/silhouette.csv',
|
||||
'service/attribute/config/descriptor/bottom/opening_type.csv',
|
||||
'service/attribute/config/descriptor/bottom/design.csv']
|
||||
'app/service/attribute/config/descriptor/bottom/subtype.csv',
|
||||
'app/service/attribute/config/descriptor/bottom/length.csv',
|
||||
'app/service/attribute/config/descriptor/bottom/silhouette.csv',
|
||||
'app/service/attribute/config/descriptor/bottom/opening_type.csv',
|
||||
'app/service/attribute/config/descriptor/bottom/design.csv']
|
||||
|
||||
bottom_model_list = [
|
||||
'attr_retrieve_B_subtype',
|
||||
@@ -35,14 +35,14 @@ bottom_model_list = [
|
||||
'attr_recong_B_optype',
|
||||
'attr_retrieve_B_design']
|
||||
|
||||
outwear_description_list = ['service/attribute/config/descriptor/outwear/length.csv',
|
||||
'service/attribute/config/descriptor/outwear/sleeve_length.csv',
|
||||
'service/attribute/config/descriptor/outwear/sleeve_shape.csv',
|
||||
'service/attribute/config/descriptor/outwear/sleeve_shoulder.csv',
|
||||
'service/attribute/config/descriptor/outwear/collar.csv',
|
||||
'service/attribute/config/descriptor/outwear/design.csv',
|
||||
'service/attribute/config/descriptor/outwear/opening_type.csv',
|
||||
'service/attribute/config/descriptor/outwear/silhouette.csv', ]
|
||||
outwear_description_list = ['app/service/attribute/config/descriptor/outwear/length.csv',
|
||||
'app/service/attribute/config/descriptor/outwear/sleeve_length.csv',
|
||||
'app/service/attribute/config/descriptor/outwear/sleeve_shape.csv',
|
||||
'app/service/attribute/config/descriptor/outwear/sleeve_shoulder.csv',
|
||||
'app/service/attribute/config/descriptor/outwear/collar.csv',
|
||||
'app/service/attribute/config/descriptor/outwear/design.csv',
|
||||
'app/service/attribute/config/descriptor/outwear/opening_type.csv',
|
||||
'app/service/attribute/config/descriptor/outwear/silhouette.csv', ]
|
||||
|
||||
outwear_model_list = ['attr_recong_O_length',
|
||||
'attr_retrieve_O_sleeve_length',
|
||||
@@ -53,15 +53,15 @@ outwear_model_list = ['attr_recong_O_length',
|
||||
'attr_recong_O_optype',
|
||||
'attr_retrieve_O_silhouette']
|
||||
|
||||
dress_description_list = [ # 'service/attribute/config/descriptor/dress/D_length.csv',
|
||||
'service/attribute/config/descriptor/dress/sleeve_length.csv',
|
||||
'service/attribute/config/descriptor/dress/sleeve_shape.csv',
|
||||
# 'service/attribute/config/descriptor/dress/D_sleeve_shoulder.csv',
|
||||
'service/attribute/config/descriptor/dress/neckline.csv',
|
||||
'service/attribute/config/descriptor/dress/collar.csv',
|
||||
'service/attribute/config/descriptor/dress/design.csv',
|
||||
'service/attribute/config/descriptor/dress/silhouette.csv',
|
||||
'service/attribute/config/descriptor/dress/type.csv']
|
||||
dress_description_list = [ # 'app/service/attribute/config/descriptor/dress/D_length.csv',
|
||||
'app/service/attribute/config/descriptor/dress/sleeve_length.csv',
|
||||
'app/service/attribute/config/descriptor/dress/sleeve_shape.csv',
|
||||
# 'app/service/attribute/config/descriptor/dress/D_sleeve_shoulder.csv',
|
||||
'app/service/attribute/config/descriptor/dress/neckline.csv',
|
||||
'app/service/attribute/config/descriptor/dress/collar.csv',
|
||||
'app/service/attribute/config/descriptor/dress/design.csv',
|
||||
'app/service/attribute/config/descriptor/dress/silhouette.csv',
|
||||
'app/service/attribute/config/descriptor/dress/type.csv']
|
||||
|
||||
dress_model_list = [ # 'attr_recong_D_length',
|
||||
'attr_retrieve_D_sleeve_length',
|
||||
|
||||
@@ -11,12 +11,12 @@ from minio import Minio
|
||||
import tritonclient.http as httpclient
|
||||
from app.core.config import *
|
||||
from app.schemas.attribute_retrieve import AttributeRecognitionModel
|
||||
from app.service.utils.oss_client import oss_get_image
|
||||
|
||||
|
||||
class AttributeRecognition:
|
||||
def __init__(self, const, request_data):
|
||||
self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
||||
logging.info("实例化完成")
|
||||
# self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
||||
self.request_data = []
|
||||
for i, sketch in enumerate(request_data):
|
||||
self.request_data.append(
|
||||
@@ -97,9 +97,10 @@ class AttributeRecognition:
|
||||
return res
|
||||
|
||||
def get_image(self, url):
|
||||
response = self.minio_client.get_object(url.split("/", 1)[0], url.split("/", 1)[1])
|
||||
img = np.frombuffer(response.data, np.uint8) # 转成8位无符号整型
|
||||
img = cv2.imdecode(img, cv2.IMREAD_COLOR) # 解码
|
||||
# response = self.minio_client.get_object(url.split("/", 1)[0], url.split("/", 1)[1])
|
||||
# img = np.frombuffer(response.data, np.uint8) # 转成8位无符号整型
|
||||
# img = cv2.imdecode(img, cv2.IMREAD_COLOR) #
|
||||
img = oss_get_image(bucket=url.split("/", 1)[0], object_name=url.split("/", 1)[1], data_type="cv2")
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||
return img
|
||||
|
||||
|
||||
@@ -18,12 +18,13 @@ import torch
|
||||
|
||||
from app.core.config import *
|
||||
from app.schemas.attribute_retrieve import CategoryRecognitionModel
|
||||
from app.service.utils.oss_client import oss_get_image
|
||||
|
||||
|
||||
class CategoryRecognition:
|
||||
def __init__(self, request_data):
|
||||
self.attr_type = pd.read_csv(CATEGORY_PATH)
|
||||
self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
||||
# self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
||||
self.request_data = []
|
||||
self.triton_client = httpclient.InferenceServerClient(url=ATT_TRITON_URL)
|
||||
for sketch in request_data:
|
||||
@@ -51,9 +52,10 @@ class CategoryRecognition:
|
||||
def get_image(self, url):
|
||||
# Get data of an object.
|
||||
# Read data from response.
|
||||
response = self.minio_client.get_object(url.split("/", 1)[0], url.split("/", 1)[1])
|
||||
img = np.frombuffer(response.data, np.uint8) # 转成8位无符号整型
|
||||
img = cv2.imdecode(img, cv2.IMREAD_COLOR) # 解码
|
||||
# response = self.minio_client.get_object(url.split("/", 1)[0], url.split("/", 1)[1])
|
||||
# img = np.frombuffer(response.data, np.uint8) # 转成8位无符号整型
|
||||
# img = cv2.imdecode(img, cv2.IMREAD_COLOR) # 解码
|
||||
img = oss_get_image(bucket=url.split("/", 1)[0], object_name=url.split("/", 1)[1], data_type="cv2")
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||
return img
|
||||
|
||||
|
||||
7
app/service/chat_robot/script/agents/__init__.py
Normal file
7
app/service/chat_robot/script/agents/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from .agent_executor import CustomAgentExecutor
|
||||
from .conversational_functions_agent import ConversationalFunctionsAgent
|
||||
|
||||
__all__ = [
|
||||
"CustomAgentExecutor",
|
||||
"ConversationalFunctionsAgent"
|
||||
]
|
||||
132
app/service/chat_robot/script/agents/agent_executor.py
Normal file
132
app/service/chat_robot/script/agents/agent_executor.py
Normal file
@@ -0,0 +1,132 @@
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional, Union, Tuple
|
||||
|
||||
from langchain.agents import AgentExecutor
|
||||
from langchain.callbacks.manager import Callbacks, CallbackManager
|
||||
from langchain.load.dump import dumpd
|
||||
from langchain.schema import RUN_KEY, RunInfo
|
||||
from langchain_core.agents import AgentAction, AgentFinish
|
||||
|
||||
|
||||
class CustomAgentExecutor(AgentExecutor):
|
||||
def __call__(
|
||||
self,
|
||||
inputs: Union[Dict[str, Any], Any],
|
||||
return_only_outputs: bool = False,
|
||||
callbacks: Callbacks = None,
|
||||
session_key: str = "",
|
||||
*,
|
||||
tags: Optional[List[str]] = None,
|
||||
include_run_info: bool = False,
|
||||
) -> Dict[str, Any]:
|
||||
"""Run the logic of this chain and add to output if desired.
|
||||
|
||||
Args:
|
||||
inputs: Dictionary of inputs, or single input if chain expects
|
||||
only one param.
|
||||
return_only_outputs: boolean for whether to return only outputs in the
|
||||
response. If True, only new keys generated by this chain will be
|
||||
returned. If False, both input keys and new keys generated by this
|
||||
chain will be returned. Defaults to False.
|
||||
callbacks: Callbacks to use for this chain run. If not provided, will
|
||||
use the callbacks provided to the chain.
|
||||
include_run_info: Whether to include run info in the response. Defaults
|
||||
to False.
|
||||
"""
|
||||
inputs = self.prep_inputs(inputs, session_key)
|
||||
callback_manager = CallbackManager.configure(
|
||||
callbacks, self.callbacks, self.verbose, tags, self.tags
|
||||
)
|
||||
new_arg_supported = inspect.signature(self._call).parameters.get("run_manager")
|
||||
run_manager = callback_manager.on_chain_start(
|
||||
dumpd(self),
|
||||
inputs,
|
||||
)
|
||||
try:
|
||||
outputs = (
|
||||
self._call(inputs, run_manager=run_manager)
|
||||
if new_arg_supported
|
||||
else self._call(inputs)
|
||||
)
|
||||
except (KeyboardInterrupt, Exception) as e:
|
||||
logging.exception(e)
|
||||
run_manager.on_chain_error(e)
|
||||
raise e
|
||||
run_manager.on_chain_end(outputs)
|
||||
final_outputs: Dict[str, Any] = self.prep_outputs(
|
||||
inputs, outputs, return_only_outputs, session_key
|
||||
)
|
||||
if include_run_info:
|
||||
final_outputs[RUN_KEY] = RunInfo(run_id=run_manager.run_id)
|
||||
return final_outputs
|
||||
|
||||
def prep_outputs(
|
||||
self,
|
||||
inputs: Dict[str, str],
|
||||
outputs: Dict[str, str],
|
||||
return_only_outputs: bool = False,
|
||||
session_key: str = ""
|
||||
) -> Dict[str, str]:
|
||||
"""Validate and prep outputs."""
|
||||
self._validate_outputs(outputs)
|
||||
if self.memory is not None and outputs['need_record']:
|
||||
self.memory.save_context(inputs, outputs, session_key)
|
||||
if return_only_outputs:
|
||||
return outputs
|
||||
else:
|
||||
return {**inputs, **outputs}
|
||||
|
||||
def prep_inputs(self, inputs: Union[Dict[str, Any], Any], session_key: str = "") -> Dict[str, str]:
|
||||
"""Validate and prep inputs."""
|
||||
if not isinstance(inputs, dict):
|
||||
_input_keys = set(self.input_keys)
|
||||
if self.memory is not None:
|
||||
# If there are multiple input keys, but some get set by memory so that
|
||||
# only one is not set, we can still figure out which key it is.
|
||||
_input_keys = _input_keys.difference(self.memory.memory_variables)
|
||||
if len(_input_keys) != 1:
|
||||
raise ValueError(
|
||||
f"A single string input was passed in, but this chain expects "
|
||||
f"multiple inputs ({_input_keys}). When a chain expects "
|
||||
f"multiple inputs, please call it by passing in a dictionary, "
|
||||
"eg `chain({'foo': 1, 'bar': 2})`"
|
||||
)
|
||||
inputs = {list(_input_keys)[0]: inputs}
|
||||
if self.memory is not None:
|
||||
external_context = self.memory.load_memory_variables(inputs, session_key)
|
||||
inputs = dict(inputs, **external_context)
|
||||
self._validate_inputs(inputs)
|
||||
return inputs
|
||||
|
||||
def _get_tool_return(
|
||||
self, next_step_output: Tuple[AgentAction, str]
|
||||
) -> Optional[AgentFinish]:
|
||||
"""Check if the tool is a returning tool."""
|
||||
agent_action, observation = next_step_output
|
||||
name_to_tool_map = {tool.name: tool for tool in self.tools}
|
||||
return_value_key = "output"
|
||||
|
||||
if len(self.agent.return_values) > 0:
|
||||
return_value_key = self.agent.return_values[0]
|
||||
|
||||
try:
|
||||
observation_list = json.loads(observation)
|
||||
if agent_action.tool == "sql_db_query" and isinstance(observation_list,
|
||||
list) and observation_list.__len__() != 0:
|
||||
return AgentFinish(
|
||||
{return_value_key: observation},
|
||||
"",
|
||||
)
|
||||
except:
|
||||
pass
|
||||
|
||||
# Invalid tools won't be in the map, so we return False.
|
||||
if agent_action.tool in name_to_tool_map:
|
||||
if name_to_tool_map[agent_action.tool].return_direct:
|
||||
return AgentFinish(
|
||||
{return_value_key: observation},
|
||||
"",
|
||||
)
|
||||
return None
|
||||
@@ -0,0 +1,198 @@
|
||||
import json
|
||||
import re
|
||||
from json import JSONDecodeError
|
||||
from typing import List, Tuple, Any, Union
|
||||
from dataclasses import dataclass
|
||||
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.agents import (
|
||||
OpenAIFunctionsAgent,
|
||||
)
|
||||
from langchain.schema import (
|
||||
AgentAction,
|
||||
AgentFinish,
|
||||
BaseMessage,
|
||||
OutputParserException
|
||||
)
|
||||
from langchain.schema.messages import (
|
||||
AIMessage,
|
||||
FunctionMessage
|
||||
)
|
||||
from langchain.tools import BaseTool, StructuredTool
|
||||
# from langchain.tools.convert_to_openai import FunctionDescription
|
||||
from langchain.utils.openai_functions import FunctionDescription
|
||||
|
||||
|
||||
@dataclass
|
||||
class _FunctionsAgentAction(AgentAction):
|
||||
"""Add message_log to AgentAction class for the _FunctionAgentAction
|
||||
"""
|
||||
message_log: List[BaseMessage]
|
||||
|
||||
def __init__(
|
||||
self, tool: str, tool_input: Union[str, dict], log: str, **kwargs: Any
|
||||
):
|
||||
"""Override init to support instantiation by position for backward compat."""
|
||||
super().__init__(tool=tool, tool_input=tool_input, log=log, **kwargs)
|
||||
|
||||
|
||||
def _convert_agent_action_to_messages(
|
||||
agent_action: AgentAction, observation: str
|
||||
) -> List[BaseMessage]:
|
||||
"""Convert an agents action to a message.
|
||||
|
||||
This code is used to reconstruct the original AI message from the agents action.
|
||||
|
||||
Args:
|
||||
agent_action: Agent action to convert.
|
||||
|
||||
Returns:
|
||||
AIMessage that corresponds to the original tools invocation.
|
||||
"""
|
||||
if isinstance(agent_action, _FunctionsAgentAction):
|
||||
return agent_action.message_log + [
|
||||
_create_function_message(agent_action, observation)
|
||||
]
|
||||
else:
|
||||
return [AIMessage(content=agent_action.log)]
|
||||
|
||||
|
||||
def _create_function_message(
|
||||
agent_action: AgentAction, observation: str
|
||||
) -> FunctionMessage:
|
||||
"""Convert agents action and observation into a function message.
|
||||
Args:
|
||||
agent_action: the tools invocation request from the agents
|
||||
observation: the result of the tools invocation
|
||||
Returns:
|
||||
FunctionMessage that corresponds to the original tools invocation
|
||||
"""
|
||||
if not isinstance(observation, str):
|
||||
try:
|
||||
content = json.dumps(observation, ensure_ascii=False)
|
||||
except Exception:
|
||||
content = str(observation)
|
||||
else:
|
||||
content = observation
|
||||
return FunctionMessage(
|
||||
name=agent_action.tool,
|
||||
content=content,
|
||||
)
|
||||
|
||||
|
||||
def _format_intermediate_steps(
|
||||
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||
) -> List[BaseMessage]:
|
||||
"""Format intermediate steps.
|
||||
Args:
|
||||
intermediate_steps: Steps the LLM has taken to date, along with observations
|
||||
Returns:
|
||||
list of messages to send to the LLM for the next prediction
|
||||
"""
|
||||
messages = []
|
||||
|
||||
for intermediate_step in intermediate_steps:
|
||||
agent_action, observation = intermediate_step
|
||||
messages.extend(_convert_agent_action_to_messages(agent_action, observation))
|
||||
|
||||
return messages
|
||||
|
||||
|
||||
def _format_tool_to_openai_function(tool: BaseTool) -> FunctionDescription:
|
||||
"""Format tools into the OpenAI function API."""
|
||||
parameters = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": tool.param_description if hasattr(tool, 'param_description') else "",
|
||||
},
|
||||
},
|
||||
"required": ["query"],
|
||||
}
|
||||
|
||||
return {
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"parameters": parameters,
|
||||
}
|
||||
|
||||
|
||||
def _parse_ai_message(message: BaseMessage) -> Union[AgentAction, AgentFinish]:
|
||||
if not isinstance(message, AIMessage):
|
||||
raise TypeError(f"Expected an AI message but got {type(message)}")
|
||||
|
||||
function_call = message.additional_kwargs.get("function_call", {})
|
||||
|
||||
if function_call:
|
||||
function_call = message.additional_kwargs["function_call"]
|
||||
function_name = function_call["name"]
|
||||
try:
|
||||
_tool_input = json.loads(function_call["arguments"])
|
||||
except JSONDecodeError:
|
||||
raise OutputParserException(
|
||||
f"Could not parse tools input: {function_call} because"
|
||||
f"the `arguments` is not valid JSON."
|
||||
)
|
||||
|
||||
if "query" in _tool_input:
|
||||
tool_input = _tool_input["query"]
|
||||
else:
|
||||
tool_input = _tool_input
|
||||
|
||||
return _FunctionsAgentAction(
|
||||
tool=function_name,
|
||||
tool_input=tool_input,
|
||||
log=f"\nInvoking: `{function_name}` with `{tool_input}`\n",
|
||||
message_log=[message]
|
||||
)
|
||||
|
||||
# pattern = r'\((.*?)\)'
|
||||
# matches = re.findall(pattern, message.content)
|
||||
# result = []
|
||||
#
|
||||
# for match in matches:
|
||||
# result.append(match)
|
||||
#
|
||||
# if result:
|
||||
# output = result
|
||||
# else:
|
||||
# output = message.content
|
||||
|
||||
return AgentFinish(return_values={"output": message.content}, log=message.content)
|
||||
|
||||
|
||||
class ConversationalFunctionsAgent(OpenAIFunctionsAgent):
|
||||
@property
|
||||
def functions(self) -> List[dict]:
|
||||
return [dict(_format_tool_to_openai_function(t)) for t in self.tools]
|
||||
|
||||
def plan(self,
|
||||
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any
|
||||
) -> Union[AgentAction, AgentFinish]:
|
||||
"""Decide how agents should move after receiving an input. The difference between
|
||||
OpenAIFunctionsAgent lies in the '_parse_ai_message' function. We add an OutputParser
|
||||
into it.
|
||||
|
||||
Args:
|
||||
intermediate_steps: Steps the LLM has taken to date, along with observations
|
||||
**kwargs: User inputs.
|
||||
**kwargs: Including user's input string
|
||||
|
||||
Returns:
|
||||
Action specifying what tools to use.
|
||||
"""
|
||||
agent_scratchpad: List[BaseMessage] = _format_intermediate_steps(intermediate_steps)
|
||||
selected_inputs = {
|
||||
k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
|
||||
}
|
||||
full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad)
|
||||
prompt = self.prompt.format_prompt(**full_inputs)
|
||||
messages: List[BaseMessage] = prompt.to_messages()
|
||||
predicted_message = self.llm.predict_messages(
|
||||
messages, functions=self.functions, callbacks=callbacks
|
||||
)
|
||||
agent_decision = _parse_ai_message(predicted_message)
|
||||
return agent_decision
|
||||
6
app/service/chat_robot/script/callbacks/__init__.py
Normal file
6
app/service/chat_robot/script/callbacks/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from .openai_token_record_callback import OpenAITokenRecordCallbackHandler
|
||||
|
||||
|
||||
__all__ = [
|
||||
'OpenAITokenRecordCallbackHandler'
|
||||
]
|
||||
@@ -0,0 +1,46 @@
|
||||
"""Callback Handler that add on_chain_end function to record Token usage."""
|
||||
from typing import Any, Dict
|
||||
|
||||
from langchain.callbacks import OpenAICallbackHandler
|
||||
from langchain.schema import LLMResult
|
||||
from langchain.callbacks.openai_info import standardize_model_name, MODEL_COST_PER_1K_TOKENS, get_openai_token_cost_for_model
|
||||
|
||||
|
||||
class OpenAITokenRecordCallbackHandler(OpenAICallbackHandler):
|
||||
need_record: bool = True
|
||||
response_type: str = "string"
|
||||
"""Callback Handler that tracks OpenAI info and write to redis after agent finish"""
|
||||
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
||||
"""Collect token usage."""
|
||||
if response.llm_output is None:
|
||||
return None
|
||||
self.successful_requests += 1
|
||||
if "token_usage" not in response.llm_output:
|
||||
return None
|
||||
if "function_call" in response.generations[0][0].message.additional_kwargs:
|
||||
if response.generations[0][0].message.additional_kwargs["function_call"]["name"] in ["sql_db_query", "sql_db_schema","tutorial_tool"]:
|
||||
self.need_record = False
|
||||
if response.generations[0][0].message.additional_kwargs["function_call"]["name"] == "sql_db_query":
|
||||
self.response_type = "image"
|
||||
token_usage = response.llm_output["token_usage"]
|
||||
completion_tokens = token_usage.get("completion_tokens", 0)
|
||||
prompt_tokens = token_usage.get("prompt_tokens", 0)
|
||||
model_name = standardize_model_name(response.llm_output.get("model_name", ""))
|
||||
if model_name in MODEL_COST_PER_1K_TOKENS:
|
||||
completion_cost = get_openai_token_cost_for_model(
|
||||
model_name, completion_tokens, is_completion=True
|
||||
)
|
||||
prompt_cost = get_openai_token_cost_for_model(model_name, prompt_tokens)
|
||||
self.total_cost += prompt_cost + completion_cost
|
||||
self.total_tokens += token_usage.get("total_tokens", 0)
|
||||
self.prompt_tokens += prompt_tokens
|
||||
self.completion_tokens += completion_tokens
|
||||
|
||||
def on_chain_end(self, outputs: Dict, **kwargs: Any) -> None:
|
||||
"""Write token usage to redis."""
|
||||
outputs["total_tokens"] = self.total_tokens
|
||||
outputs["total_cost"] = self.total_cost
|
||||
outputs["prompt_tokens"] = self.prompt_tokens
|
||||
outputs["completion_tokens"] = self.completion_tokens
|
||||
outputs["need_record"] = self.need_record
|
||||
outputs["response_type"] = self.response_type
|
||||
79
app/service/chat_robot/script/database.py
Normal file
79
app/service/chat_robot/script/database.py
Normal file
@@ -0,0 +1,79 @@
|
||||
from typing import Optional, List
|
||||
import json
|
||||
|
||||
from sqlalchemy import text
|
||||
# from langchain import SQLDatabase
|
||||
from langchain.utilities import SQLDatabase
|
||||
|
||||
|
||||
class CustomDatabase(SQLDatabase):
|
||||
def get_table_info_no_throw(self, table_names: Optional[List[str]] = None) -> str:
|
||||
# def get_table_info(self, table_names: Optional[List[str]] = None) -> str:
|
||||
connection = self._engine.connect()
|
||||
all_table_names = self.get_usable_table_names()
|
||||
if table_names is not None:
|
||||
missing_tables = set(table_names).difference(all_table_names)
|
||||
if missing_tables:
|
||||
# raise ValueError(f"table_names {missing_tables} not found in database")
|
||||
return f"Table {','.join(missing_tables)} can not be found in the database"
|
||||
all_table_names = table_names
|
||||
meta_tables = [
|
||||
tbl
|
||||
for tbl in self._metadata.sorted_tables
|
||||
if tbl.name in set(all_table_names)
|
||||
]
|
||||
|
||||
tables = []
|
||||
for table in meta_tables:
|
||||
table_name = table.name
|
||||
column_names = table.columns.keys()
|
||||
table_info = f"Table: {table_name}\nColumns: \nID, \nimg_name\n"
|
||||
for column_name in column_names:
|
||||
if column_name not in ["ID", "img_name"]:
|
||||
query = text(f"SELECT DISTINCT {column_name} FROM {table_name}")
|
||||
result = connection.execute(query)
|
||||
enum_values: List[str] = [row[0] for row in result.fetchall()]
|
||||
column_info = f"{column_name}: {', '.join(enum_values)}\n"
|
||||
table_info += column_info
|
||||
|
||||
# table_info = f"Table: {table_name}\n"
|
||||
#
|
||||
# if self._sample_rows_in_table_info:
|
||||
# table_info += f"{self._get_sample_rows(table)}\n"
|
||||
tables.append(table_info)
|
||||
final_str = "\n\n".join(tables)
|
||||
return final_str
|
||||
|
||||
def run(self, command: str, fetch: str = "all") -> str:
|
||||
"""Execute a SQL command and return a string representing the results.
|
||||
|
||||
If the statement returns rows, a string of the results is returned.
|
||||
If the statement returns no rows, an empty string is returned.
|
||||
|
||||
"""
|
||||
with self._engine.begin() as connection:
|
||||
if self._schema is not None:
|
||||
if self.dialect == "snowflake":
|
||||
connection.exec_driver_sql(
|
||||
f"ALTER SESSION SET search_path='{self._schema}'"
|
||||
)
|
||||
elif self.dialect == "bigquery":
|
||||
connection.exec_driver_sql(f"SET @@dataset_id='{self._schema}'")
|
||||
else:
|
||||
connection.exec_driver_sql(f"SET search_path TO {self._schema}")
|
||||
cursor = connection.execute(text(command))
|
||||
if cursor.rowcount:
|
||||
if fetch == "all":
|
||||
result = cursor.fetchall()
|
||||
elif fetch == "one":
|
||||
result = cursor.fetchone() # type: ignore
|
||||
else:
|
||||
raise ValueError("Fetch parameter must be either 'one' or 'all'")
|
||||
|
||||
# Convert columns values to string to avoid issues with sqlalchmey
|
||||
# trunacating text
|
||||
if isinstance(result, list):
|
||||
return json.dumps([r[0] for r in result])
|
||||
|
||||
return json.dumps([result[0]])
|
||||
return ""
|
||||
115
app/service/chat_robot/script/main.py
Normal file
115
app/service/chat_robot/script/main.py
Normal file
@@ -0,0 +1,115 @@
|
||||
import json
|
||||
import logging
|
||||
|
||||
from langchain.agents import Tool
|
||||
from langchain.callbacks import FileCallbackHandler
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
from langchain.llms.openai import OpenAI
|
||||
from langchain.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate, MessagesPlaceholder
|
||||
from langchain.schema import SystemMessage, AIMessage
|
||||
from langchain.utilities import SerpAPIWrapper
|
||||
from loguru import logger
|
||||
|
||||
from app.core.config import *
|
||||
from app.service.chat_robot.script.agents import CustomAgentExecutor, ConversationalFunctionsAgent
|
||||
from app.service.chat_robot.script.callbacks import OpenAITokenRecordCallbackHandler
|
||||
from app.service.chat_robot.script.database import CustomDatabase
|
||||
from app.service.chat_robot.script.memory import UserConversationBufferWindowMemory
|
||||
from app.service.chat_robot.script.prompt import FASHION_CHAT_BOT_PREFIX, TOOLS_FUNCTIONS_SUFFIX
|
||||
from app.service.chat_robot.script.tools import (QuerySQLDataBaseTool, InfoSQLDatabaseTool, QuerySQLCheckerTool, ListSQLDatabaseTool)
|
||||
from app.service.chat_robot.script.tools.tutorial_tool import CustomTutorialTool
|
||||
|
||||
# os.environ["http_proxy"] = "http://127.0.0.1:7890"
|
||||
# os.environ["https_proxy"] = "http://127.0.0.1:7890"
|
||||
# log callbacks
|
||||
logfile = "logs/chat_debug.log"
|
||||
logger.add(logfile, colorize=True, enqueue=True)
|
||||
log_handler = FileCallbackHandler(logfile)
|
||||
|
||||
# Initiate our LLM 'gpt-3.5-turbo'
|
||||
llm = ChatOpenAI(temperature=0.1,
|
||||
openai_api_key=OPENAI_API_KEY,
|
||||
# callbacks=[OpenAICallbackHandler()]
|
||||
)
|
||||
|
||||
search = SerpAPIWrapper()
|
||||
db = CustomDatabase.from_uri(f'mysql+pymysql://{DB_USERNAME}:{DB_PASSWORD}@{DB_HOST}:{DB_PORT}/attribute_retrieval_V3',
|
||||
include_tables=['female_top', 'female_skirt', 'female_pants', 'female_dress',
|
||||
'female_outwear', 'male_bottom', 'male_top', 'male_outwear'],
|
||||
engine_args={"pool_recycle": 7200})
|
||||
tools = [
|
||||
Tool(
|
||||
name="internet_search",
|
||||
description="Can be used to perform Internet searches",
|
||||
func=search.run
|
||||
),
|
||||
QuerySQLDataBaseTool(db=db, return_direct=False),
|
||||
InfoSQLDatabaseTool(db=db),
|
||||
ListSQLDatabaseTool(db=db),
|
||||
QuerySQLCheckerTool(db=db, llm=OpenAI(temperature=0, openai_api_key=OPENAI_API_KEY)),
|
||||
Tool(
|
||||
name="tutorial_tool",
|
||||
description="Utilize this tool to retrieve specific statements related to user guidance tutorials."
|
||||
"Input is an empty string",
|
||||
func=CustomTutorialTool(),
|
||||
return_direct=True
|
||||
)
|
||||
]
|
||||
|
||||
messages = [
|
||||
SystemMessage(content=FASHION_CHAT_BOT_PREFIX),
|
||||
MessagesPlaceholder(variable_name="history"),
|
||||
HumanMessagePromptTemplate.from_template(
|
||||
"{input} "
|
||||
"Question from a {gender}."
|
||||
),
|
||||
AIMessage(content=TOOLS_FUNCTIONS_SUFFIX),
|
||||
MessagesPlaceholder(variable_name="agent_scratchpad"),
|
||||
]
|
||||
|
||||
prompt = ChatPromptTemplate(input_variables=["input", "gender", "agent_scratchpad", "history"], messages=messages)
|
||||
agent = ConversationalFunctionsAgent(
|
||||
llm=llm,
|
||||
tools=tools,
|
||||
prompt=prompt
|
||||
)
|
||||
|
||||
memory = UserConversationBufferWindowMemory.from_redis(
|
||||
return_messages=True, k=2, input_key='input', output_key='output'
|
||||
)
|
||||
agent_executor = CustomAgentExecutor.from_agent_and_tools(
|
||||
agent=agent,
|
||||
tools=tools,
|
||||
verbose=True,
|
||||
memory=memory,
|
||||
)
|
||||
|
||||
|
||||
def chat(post_data):
|
||||
user_id = post_data.user_id
|
||||
session_id = post_data.session_id
|
||||
input_message = post_data.message
|
||||
gender = post_data.gender
|
||||
|
||||
final_outputs = agent_executor(
|
||||
{"input": input_message, "gender": gender},
|
||||
callbacks=[OpenAITokenRecordCallbackHandler(), log_handler],
|
||||
session_key=f"buffer:{user_id}:{session_id}",
|
||||
)
|
||||
api_response = {
|
||||
'user_id': user_id,
|
||||
'session_id': session_id,
|
||||
# 'message_id': message_id,
|
||||
# 'create_time': created_time,
|
||||
'input': final_outputs['input'],
|
||||
# 'conversion': messages,
|
||||
'output': final_outputs['output'],
|
||||
# 'gpt_response_time': gpt_response_time,
|
||||
'total_tokens': final_outputs['total_tokens'],
|
||||
'total_cost': final_outputs['total_cost'],
|
||||
'prompt_tokens': final_outputs['prompt_tokens'],
|
||||
'completion_tokens': final_outputs['completion_tokens'],
|
||||
'response_type': final_outputs['response_type']
|
||||
}
|
||||
logging.info(json.dumps(api_response))
|
||||
return api_response
|
||||
3
app/service/chat_robot/script/memory/__init__.py
Normal file
3
app/service/chat_robot/script/memory/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .user_buffer_window import UserConversationBufferWindowMemory
|
||||
|
||||
__all__ = ['UserConversationBufferWindowMemory']
|
||||
93
app/service/chat_robot/script/memory/user_buffer_window.py
Normal file
93
app/service/chat_robot/script/memory/user_buffer_window.py
Normal file
@@ -0,0 +1,93 @@
|
||||
import logging
|
||||
from typing import Any, Dict, List, Tuple
|
||||
import json
|
||||
|
||||
import redis
|
||||
from redis import Redis
|
||||
from langchain.memory import RedisChatMessageHistory
|
||||
from langchain.memory.chat_memory import BaseChatMemory
|
||||
from langchain.schema.messages import BaseMessage, get_buffer_string, HumanMessage, AIMessage
|
||||
from langchain.schema.messages import _message_to_dict, messages_from_dict
|
||||
from langchain.memory.utils import get_prompt_input_key
|
||||
|
||||
from app.core.config import *
|
||||
|
||||
|
||||
class UserConversationBufferWindowMemory(BaseChatMemory):
|
||||
"""Buffer for storing conversation memory."""
|
||||
|
||||
redis_client: Redis
|
||||
human_prefix: str = "Human"
|
||||
ai_prefix: str = "AI"
|
||||
memory_key: str = "history" #: :meta private:
|
||||
k: int = 5
|
||||
|
||||
@classmethod
|
||||
def from_redis(
|
||||
cls,
|
||||
host: str = REDIS_HOST,
|
||||
port: int = REDIS_PORT,
|
||||
db: int = 3,
|
||||
**kwargs
|
||||
):
|
||||
redis_client = Redis(host=host, port=port, db=db)
|
||||
try:
|
||||
response = redis_client.ping()
|
||||
if response:
|
||||
print("Connect to redis server successfully.")
|
||||
logging.info("Connect to redis server successfully.")
|
||||
else:
|
||||
print("Fail to connect to redis server")
|
||||
logging.info("Fail to connect to redis server")
|
||||
except redis.RedisError as e:
|
||||
logging.info(f"Error occurs when connecting to redis server: {str(e)}")
|
||||
return cls(redis_client=redis_client, **kwargs)
|
||||
|
||||
@property
|
||||
def memory_variables(self) -> List[str]:
|
||||
"""Will always return list of memory variables.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
return [self.memory_key]
|
||||
|
||||
def load_memory_variables(self, inputs: Dict[str, Any], key: str = "") -> Dict[str, str]:
|
||||
"""Return history buffer."""
|
||||
_items: Any = self.redis_client.lrange(key, 0, self.k * 2) if self.k > 0 else []
|
||||
items = [json.loads(m.decode("utf-8")) for m in _items[::-1]]
|
||||
buffer = messages_from_dict(items)
|
||||
if not self.return_messages:
|
||||
buffer = get_buffer_string(
|
||||
buffer,
|
||||
human_prefix=self.human_prefix,
|
||||
ai_prefix=self.ai_prefix,
|
||||
)
|
||||
return {self.memory_key: buffer}
|
||||
|
||||
def _get_input_output(
|
||||
self, inputs: Dict[str, Any], outputs: Dict[str, str]
|
||||
) -> Tuple[str, str]:
|
||||
if self.input_key is None:
|
||||
prompt_input_key = get_prompt_input_key(inputs, self.memory_variables)
|
||||
else:
|
||||
prompt_input_key = self.input_key
|
||||
if self.output_key is None:
|
||||
if len(outputs) != 1:
|
||||
raise ValueError(f"One output key expected, got {outputs.keys()}")
|
||||
output_key = list(outputs.keys())[0]
|
||||
else:
|
||||
output_key = self.output_key
|
||||
return inputs[prompt_input_key], outputs[output_key]
|
||||
|
||||
def add_message(self, key: str, message: BaseMessage) -> None:
|
||||
self.redis_client.lpush(key, json.dumps(_message_to_dict(message)))
|
||||
|
||||
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str], key: str = "") -> None:
|
||||
"""Save context from this conversation to buffer."""
|
||||
input_str, output_str = self._get_input_output(inputs, outputs)
|
||||
self.add_message(key, HumanMessage(content=input_str))
|
||||
self.add_message(key, AIMessage(content=output_str))
|
||||
|
||||
# def clear(self, key) -> None:
|
||||
# """Clear memory contents."""
|
||||
# self.redis_client.delete(key)
|
||||
52
app/service/chat_robot/script/prompt.py
Normal file
52
app/service/chat_robot/script/prompt.py
Normal file
@@ -0,0 +1,52 @@
|
||||
FASHION_CHAT_BOT_PREFIX = """
|
||||
You are a helpful assistant for fashion designers. You can chat with the users or answer their query as much as you can.
|
||||
The most crucial aspect is to accurately determine whether the user's inquiry requires a internet search or querying the database.
|
||||
Remember your answer should be very precise and the final output answer should not exceed 20 words.
|
||||
|
||||
You may encounter the following types of questions:
|
||||
1) If the query related to clothing retrieval, you are an agent designed to interact with a SQL database.
|
||||
Given an input question, create a syntactically correct mysql query to run, always fetching random data from tables.
|
||||
Unless the user specifies a specific number of examples they wish to obtain,always limit your query to at most 4 results.
|
||||
Never query for all the columns from a specific table, only ask for the relevant columns given the question.
|
||||
You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.
|
||||
DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.
|
||||
If the question does not seem related to the database, just return "I don't know" as the answer.
|
||||
|
||||
2) If the query related to current events, you should use internet_search to seek help from the internet.
|
||||
|
||||
3) If the query is just casual conversation, engage in the conversation as a fashion designer assistant.
|
||||
|
||||
Be careful to use the tools, since you are actually a chat bot. Tools can only be used when essential.
|
||||
"""
|
||||
|
||||
TOOL_SELECT_SUFFIX = """
|
||||
Prior to proceeding, it is essential to carefully assess the question and select the appropriate tools or approach accordingly.
|
||||
For database-related questions, use SQL tools to identify relevant tables and query their schemas.
|
||||
The use of online resources should be limited to inquiry pertaining to current subjects.
|
||||
"""
|
||||
|
||||
SQL_FUNCTIONS_SUFFIX = """
|
||||
For database-related questions, use SQL tools to identify relevant tables and query their schemas.
|
||||
"""
|
||||
|
||||
INTERNET_SEARCH_SUFFIX = """
|
||||
If the question should be answered using internet search tools, I should seek help from the internet.
|
||||
"""
|
||||
|
||||
ANSWER_FORMAT_SUFFIX = """
|
||||
My final answer are limited to 20 words and be as much precise as possible.
|
||||
"""
|
||||
|
||||
TOOLS_FUNCTIONS_SUFFIX = (
|
||||
"If the input involves clothing queries,"
|
||||
"I should look at the tables in the database to see what I can query. Then I should query the schema of the most relevant tables."
|
||||
"All SQL statements must use 'ORDER BY RAND()', for example:"
|
||||
"Example Input 1: 'SELECT img_name FROM skirt WHERE opening_type = 'Button' ORDER BY RAND() LIMIT 1'"
|
||||
"Example Input 2: 'SELECT img_name FROM top WHERE sleeve_length = 'Long' AND type = 'Blouse' ORDER BY RAND() LIMIT 2'"
|
||||
"If the input does not involve clothing queries, "
|
||||
"I should engage in conversation as an assistant or search from internet with internet_search tool."
|
||||
"If the database query returns no results, please respond directly with: 'Apologies, I couldn't find any images that match your description. Could you please give me more details about the clothing you're searching for?'"
|
||||
"Upon mentioning words related to 'tutorial' in the input, I should use tutorial_tool "
|
||||
)
|
||||
|
||||
TUTORIAL_TOOL_RETURN = "Commencing the systematic tutorial guide now."
|
||||
10
app/service/chat_robot/script/tools/__init__.py
Normal file
10
app/service/chat_robot/script/tools/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from .sql_tools import (
|
||||
QuerySQLDataBaseTool,
|
||||
InfoSQLDatabaseTool,
|
||||
ListSQLDatabaseTool,
|
||||
QuerySQLCheckerTool
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"QuerySQLCheckerTool", "InfoSQLDatabaseTool", "ListSQLDatabaseTool", "QuerySQLDataBaseTool"
|
||||
]
|
||||
183
app/service/chat_robot/script/tools/sql_tools.py
Normal file
183
app/service/chat_robot/script/tools/sql_tools.py
Normal file
@@ -0,0 +1,183 @@
|
||||
# flake8: noqa
|
||||
"""Tools for interacting with a SQL database."""
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from pydantic import BaseModel, Extra, Field, root_validator
|
||||
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForToolRun,
|
||||
CallbackManagerForToolRun,
|
||||
)
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.prompts import PromptTemplate
|
||||
# from langchain.sql_database import SQLDatabase
|
||||
from langchain.utilities import SQLDatabase
|
||||
from langchain.tools.base import BaseTool
|
||||
from langchain.tools.sql_database.prompt import QUERY_CHECKER
|
||||
|
||||
|
||||
class BaseSQLDatabaseTool(BaseModel):
|
||||
"""Base tools for interacting with a SQL database."""
|
||||
|
||||
db: SQLDatabase = Field(exclude=True)
|
||||
param_description: str = ""
|
||||
|
||||
# Override BaseTool.Config to appease mypy
|
||||
# See https://github.com/pydantic/pydantic/issues/4173
|
||||
class Config(BaseTool.Config):
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
extra = Extra.forbid
|
||||
|
||||
|
||||
class QuerySQLDataBaseTool(BaseSQLDatabaseTool, BaseTool):
|
||||
"""Tool for querying a SQL database."""
|
||||
|
||||
name = "sql_db_query"
|
||||
# description = """
|
||||
# Before use this tool, another tool named sql_db_schema must be used first to find the schema of interested tables.
|
||||
# This tool is designed exclusively for generating SELECT queries to retrieve clothing's img_name randomly from a MySQL database.
|
||||
# You should always use ‘order by rand()’ to randomly select data.
|
||||
# If the query is not correct, an error message will be returned.
|
||||
# If an error is returned, rewrite the query, check the query, and try again.
|
||||
# Always limit your query to at most 4 results.
|
||||
# Never query for all the columns from a specific table, only ask for the relevant columns given the question.
|
||||
# You MUST double check your query before executing it.
|
||||
# DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.
|
||||
# """
|
||||
|
||||
description: str = (
|
||||
"The input of this tool is a detailed and correct SQL select query statement, "
|
||||
"and the output is the result of the database, and it can only return up to 4 results."
|
||||
"If the query is not correct, an error message will be returned."
|
||||
"If an error is returned, rewrite the query, check the query, and try again."
|
||||
"If you encounter an issue with Unknown column 'xxxx' in 'field list' or Table 'attribute_retrieval.xxxx' doesn't exist,"
|
||||
"use sql_db_schema to query the correct table fields."
|
||||
|
||||
"Example Input: 'SELECT img_name FROM skirt WHERE opening_type = 'Button' ORDER BY RAND() "
|
||||
"LIMIT 1'"
|
||||
"Example Input 2: 'SELECT img_name FROM top WHERE sleeve_length = 'Long' AND type = 'Blouse' "
|
||||
"order by rand() LIMIT 2'"
|
||||
)
|
||||
|
||||
def _run(
|
||||
self,
|
||||
query: str,
|
||||
run_manager: Optional[CallbackManagerForToolRun] = None,
|
||||
) -> str:
|
||||
"""Execute the query, return the results or an error message."""
|
||||
result = self.db.run_no_throw(query)
|
||||
return result
|
||||
|
||||
async def _arun(
|
||||
self,
|
||||
query: str,
|
||||
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
|
||||
) -> str:
|
||||
raise NotImplementedError("QuerySqlDbTool does not support async")
|
||||
|
||||
|
||||
class InfoSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool):
|
||||
"""Tool for getting metadata about a SQL database."""
|
||||
|
||||
name = "sql_db_schema"
|
||||
# description = """
|
||||
# The database contains information of lots of fashion items, such as item name, their fashion attributes.
|
||||
# There are five tables covering five fashion categories: top, pants, dress, skirt, and outwear.
|
||||
# Find the most relevant tables with the query, and output the schema of these tables.
|
||||
# """
|
||||
|
||||
description: str = (
|
||||
"Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables."
|
||||
"There are eight tables covering eight fashion categories: female_top, female_pants, female_dress,"
|
||||
"female_skirt, female_outwear, male_bottom, male_top, and male_outwear."
|
||||
|
||||
"Example Input: 'female_outwear, male_top'"
|
||||
)
|
||||
|
||||
def _run(
|
||||
self,
|
||||
table_names: str,
|
||||
run_manager: Optional[CallbackManagerForToolRun] = None,
|
||||
) -> str:
|
||||
"""Get the schema for tables in a comma-separated list."""
|
||||
return self.db.get_table_info_no_throw(table_names.split(", "))
|
||||
|
||||
async def _arun(
|
||||
self,
|
||||
table_name: str,
|
||||
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
|
||||
) -> str:
|
||||
raise NotImplementedError("SchemaSqlDbTool does not support async")
|
||||
|
||||
|
||||
class ListSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool):
|
||||
"""Tool for getting tables names."""
|
||||
|
||||
name = "sql_db_list_tables"
|
||||
description = "Input is an empty string, output is a comma separated list of tables in the database."
|
||||
|
||||
def _run(
|
||||
self,
|
||||
tool_input: str = "",
|
||||
run_manager: Optional[CallbackManagerForToolRun] = None,
|
||||
) -> str:
|
||||
"""Get the schema for a specific table."""
|
||||
return ", ".join(self.db.get_usable_table_names())
|
||||
|
||||
async def _arun(
|
||||
self,
|
||||
tool_input: str = "",
|
||||
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
|
||||
) -> str:
|
||||
raise NotImplementedError("ListTablesSqlDbTool does not support async")
|
||||
|
||||
|
||||
class QuerySQLCheckerTool(BaseSQLDatabaseTool, BaseTool):
|
||||
"""Use an LLM to check if a query is correct.
|
||||
Adapted from https://www.patterns.app/blog/2023/01/18/crunchbot-sql-analyst-gpt/"""
|
||||
|
||||
template: str = QUERY_CHECKER
|
||||
llm: BaseLanguageModel
|
||||
llm_chain: LLMChain = Field(init=False)
|
||||
name = "sql_db_query_checker"
|
||||
description = (
|
||||
"Use this tools to double check if your query is correct before executing it."
|
||||
"Always use this tools before executing a query with sql_db_query!"
|
||||
)
|
||||
|
||||
@root_validator(pre=True)
|
||||
def initialize_llm_chain(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
if "llm_chain" not in values:
|
||||
values["llm_chain"] = LLMChain(
|
||||
llm=values.get("llm"),
|
||||
prompt=PromptTemplate(
|
||||
template=QUERY_CHECKER,
|
||||
input_variables=["query", "dialect"]
|
||||
),
|
||||
)
|
||||
|
||||
if values["llm_chain"].prompt.input_variables != ["dialect", "query"]:
|
||||
# if values["llm_chain"].prompt.input_variables != ["query", "dialect"]:
|
||||
raise ValueError(
|
||||
"LLM chain for QueryCheckerTool must have input variables ['query', 'dialect']"
|
||||
)
|
||||
|
||||
return values
|
||||
|
||||
def _run(
|
||||
self,
|
||||
query: str,
|
||||
run_manager: Optional[CallbackManagerForToolRun] = None,
|
||||
) -> str:
|
||||
"""Use the LLM to check the query."""
|
||||
return self.llm_chain.predict(query=query, dialect=self.db.dialect)
|
||||
|
||||
async def _arun(
|
||||
self,
|
||||
query: str,
|
||||
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
|
||||
) -> str:
|
||||
return await self.llm_chain.apredict(query=query, dialect=self.db.dialect)
|
||||
19
app/service/chat_robot/script/tools/tutorial_tool.py
Normal file
19
app/service/chat_robot/script/tools/tutorial_tool.py
Normal file
@@ -0,0 +1,19 @@
|
||||
from typing import Any
|
||||
|
||||
from langchain.tools.base import BaseTool
|
||||
|
||||
from app.service.chat_robot.script.prompt import TUTORIAL_TOOL_RETURN
|
||||
|
||||
|
||||
# 处理系统引导教程相关的输入
|
||||
class CustomTutorialTool(BaseTool):
|
||||
name = "tutorial_tool"
|
||||
|
||||
description = ("Utilize this tool to retrieve specific statements related to user guidance tutorials."
|
||||
"Input is an empty string")
|
||||
|
||||
def _run(self, tool_input, **kwargs: Any) -> str:
|
||||
return TUTORIAL_TOOL_RETURN
|
||||
|
||||
async def _arun(self, tool_input, **kwargs: Any) -> str:
|
||||
raise NotImplementedError("CustomTutorialTool does not support async")
|
||||
1
app/service/chat_robot/script/utils/__init__.py
Normal file
1
app/service/chat_robot/script/utils/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .logger import Logger
|
||||
26
app/service/chat_robot/script/utils/logger.py
Normal file
26
app/service/chat_robot/script/utils/logger.py
Normal file
@@ -0,0 +1,26 @@
|
||||
import logging
|
||||
from logging import handlers
|
||||
|
||||
|
||||
class Logger(object):
|
||||
level_relations = {
|
||||
'debug': logging.DEBUG,
|
||||
'info': logging.INFO,
|
||||
'warning': logging.WARNING,
|
||||
'error': logging.ERROR,
|
||||
'crit': logging.CRITICAL
|
||||
}
|
||||
|
||||
def __init__(self, filename, level='info', when='D', backCount=3,
|
||||
fmt='%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s'):
|
||||
self.logger = logging.getLogger(filename)
|
||||
format_str = logging.Formatter(fmt) # set log format
|
||||
self.logger.setLevel(self.level_relations.get(level)) # set log level
|
||||
sh = logging.StreamHandler() # output to terminal
|
||||
sh.setFormatter(format_str) # set format for terminal log
|
||||
th = handlers.TimedRotatingFileHandler(filename=filename, when=when, backupCount=backCount,
|
||||
encoding='utf-8') # log into file
|
||||
|
||||
th.setFormatter(format_str) # set format for file log
|
||||
self.logger.addHandler(sh) # output to terminal
|
||||
self.logger.addHandler(th) # output to file
|
||||
0
app/service/design/core/__init__.py
Normal file
0
app/service/design/core/__init__.py
Normal file
116
app/service/design/core/layer.py
Normal file
116
app/service/design/core/layer.py
Normal file
@@ -0,0 +1,116 @@
|
||||
import logging
|
||||
|
||||
import numpy as np
|
||||
import cv2
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def show(img, win_name="temp"):
|
||||
cv2.imshow(win_name, img)
|
||||
cv2.waitKey(0)
|
||||
|
||||
|
||||
def crop(img):
|
||||
mid_point_h, mid_point_w = int(img.shape[0] / 2 + 30), int(img.shape[1] / 2)
|
||||
img_roi = img[mid_point_h - 520: mid_point_h + 520, mid_point_w - 340: mid_point_w + 340]
|
||||
return img_roi
|
||||
|
||||
|
||||
class Layer(object):
|
||||
def __init__(self):
|
||||
self._layer = []
|
||||
|
||||
@property
|
||||
def layer(self):
|
||||
return self._layer
|
||||
|
||||
def insert(self, layer_instance):
|
||||
if layer_instance['name'] == 'body':
|
||||
self._body = layer_instance
|
||||
self._layer.append(layer_instance)
|
||||
|
||||
def sort(self, priority):
|
||||
self._layer.sort(key=lambda x: priority[x['name']])
|
||||
|
||||
# def merge(self, cfg):
|
||||
# """
|
||||
# opencv shape order (height, width, channel)
|
||||
# image coordinate system:
|
||||
# |------------->x (width)
|
||||
# |
|
||||
# |
|
||||
# |
|
||||
# y (height)
|
||||
# Returns:
|
||||
#
|
||||
#
|
||||
# """
|
||||
# base_image = Image.new('RGBA', self._layer[1]['image'].size, (0, 0, 0, 0))
|
||||
# for layer in self._layer:
|
||||
# y, x = layer['position']
|
||||
# base_image.paste(layer['image'], (x, y), layer['image'])
|
||||
# # base_image.show()
|
||||
#
|
||||
# for x in self._layer:
|
||||
# if np.all(x['mask'] == 0):
|
||||
# continue
|
||||
# # obtain region of interest about roi(roi) and item-image(roi_image, roi_mask)
|
||||
# roi, roi_mask, roi_image, signal = self.get_roi(dst=dst, image=x)
|
||||
# temp_bg = np.expand_dims(cv2.bitwise_not(roi_mask), axis=2).repeat(3, axis=2)
|
||||
# tmp1 = (roi * (temp_bg / 255)).astype(np.uint8)
|
||||
# temp_fg = np.expand_dims(roi_mask, axis=2).repeat(3, axis=2)
|
||||
# tmp2 = (roi_image * (temp_fg / 255)).astype(np.uint8)
|
||||
#
|
||||
# roi[:] = cv2.add(tmp1, tmp2)
|
||||
# # show(cv2.resize(dst, (int(dst.shape[1] * 0.5), int(dst.shape[0] * 0.5)), interpolation=cv2.INTER_AREA),
|
||||
# # win_name=x.get('name'))
|
||||
# # crop image and get the central part
|
||||
# if cfg.get('basic')['self_template'] == False:
|
||||
# dst_roi = crop(dst)
|
||||
# else:
|
||||
# dst_roi = dst
|
||||
# return dst_roi, signal
|
||||
#
|
||||
# @staticmethod
|
||||
# def get_roi(dst, image):
|
||||
# signal = False
|
||||
# dst_y, dst_x = dst.shape[:2]
|
||||
# roi_height, roi_width = image['mask'].shape
|
||||
# roi_y0, roi_x0 = image['position']
|
||||
#
|
||||
# if roi_y0 < 0:
|
||||
# roi_yin = 0
|
||||
# mask_yin = -roi_y0
|
||||
# signal = True
|
||||
# else:
|
||||
# roi_yin = roi_y0
|
||||
# mask_yin = 0
|
||||
# if roi_y0 + roi_height > dst_y:
|
||||
# roi_yout = dst_y
|
||||
# mask_yout = dst_y - roi_y0
|
||||
# signal = True
|
||||
# else:
|
||||
# roi_yout = roi_height + roi_y0
|
||||
# mask_yout = roi_height
|
||||
# # x part
|
||||
# if roi_x0 < 0:
|
||||
# roi_xin = 0
|
||||
# mask_xin = -roi_x0
|
||||
# signal = True
|
||||
# else:
|
||||
# roi_xin = roi_x0
|
||||
# mask_xin = 0
|
||||
# if roi_x0 + roi_width > dst_x:
|
||||
# roi_xout = dst_x
|
||||
# mask_xout = dst_x - roi_x0
|
||||
# signal = True
|
||||
# else:
|
||||
# roi_xout = roi_width + roi_x0
|
||||
# mask_xout = roi_width
|
||||
#
|
||||
# roi = dst[roi_yin: roi_yout, roi_xin: roi_xout]
|
||||
# roi_mask = image['mask'][mask_yin: mask_yout, mask_xin: mask_xout]
|
||||
# roi_image = image['image'][mask_yin: mask_yout, mask_xin: mask_xout]
|
||||
# return roi, roi_mask, roi_image, signal
|
||||
45
app/service/design/core/priority.py
Normal file
45
app/service/design/core/priority.py
Normal file
@@ -0,0 +1,45 @@
|
||||
class Priority(object):
|
||||
"""Item layer priority levels.
|
||||
"""
|
||||
|
||||
def __init__(self, item_list):
|
||||
self._priority = dict(
|
||||
earring_front=99,
|
||||
bag_front=98,
|
||||
hairstyle_front=97,
|
||||
outwear_front=20,
|
||||
bottoms_front=19,
|
||||
dress_front=18,
|
||||
blouse_front=17,
|
||||
skirt_front=16,
|
||||
trousers_front=15,
|
||||
tops_front=14,
|
||||
shoes_right=1,
|
||||
shoes_left=1,
|
||||
body=0,
|
||||
tops_back=-14,
|
||||
trousers_back=-15,
|
||||
skirt_back=-16,
|
||||
blouse_back=-17,
|
||||
dress_back=-18,
|
||||
bottoms_back=-19,
|
||||
outwear_back=-20,
|
||||
hairstyle_back=-97,
|
||||
bag_back=-98,
|
||||
earring_back=-99,
|
||||
)
|
||||
self.clothing_start_num = 10
|
||||
if not isinstance(item_list, list):
|
||||
raise ValueError('item_list must be a list!')
|
||||
for cate in item_list:
|
||||
cate = cate.lower()
|
||||
if cate not in ('outwear', 'dress', 'blouse', 'skirt', 'trousers', 'tops', 'bottoms'):
|
||||
raise ValueError(f'Item type error. Cannot recognize {cate}')
|
||||
for i, cate in enumerate(item_list):
|
||||
cate = cate.lower()
|
||||
self._priority[f'{cate}_front'] = self.clothing_start_num - i
|
||||
self._priority[f'{cate}_back'] = -(self.clothing_start_num - i)
|
||||
|
||||
@property
|
||||
def priority(self):
|
||||
return self._priority
|
||||
771
app/service/design/fastapi_request.json
Normal file
771
app/service/design/fastapi_request.json
Normal file
@@ -0,0 +1,771 @@
|
||||
{
|
||||
"objects": [
|
||||
{
|
||||
"basic": {
|
||||
"body_point_test": {
|
||||
"waistband_right": [
|
||||
336,
|
||||
264
|
||||
],
|
||||
"hand_point_right": [
|
||||
350,
|
||||
303
|
||||
],
|
||||
"waistband_left": [
|
||||
245,
|
||||
274
|
||||
],
|
||||
"hand_point_left": [
|
||||
219,
|
||||
315
|
||||
],
|
||||
"shoulder_left": [
|
||||
227,
|
||||
155
|
||||
],
|
||||
"shoulder_right": [
|
||||
338,
|
||||
149
|
||||
]
|
||||
},
|
||||
"layer_order": false,
|
||||
"scale_bag": 0.7,
|
||||
"scale_earrings": 0.16,
|
||||
"self_template": true,
|
||||
"single_overall": "overall",
|
||||
"switch_category": ""
|
||||
},
|
||||
"items": [
|
||||
{
|
||||
"businessId": 493827,
|
||||
"color": "127 61 21",
|
||||
"elementId": 493827,
|
||||
"icon": "none",
|
||||
"image_id": 110201,
|
||||
"offset": [
|
||||
1,
|
||||
1
|
||||
],
|
||||
"path": "aida-users/31/sketch/62302527-2910-4740-808d-2cb8221daa34-3-31.png",
|
||||
"print": {
|
||||
"IfSingle": false,
|
||||
"print_path_list": []
|
||||
},
|
||||
"resize_scale": [
|
||||
1.0,
|
||||
1.0
|
||||
],
|
||||
"type": "Dress"
|
||||
},
|
||||
{
|
||||
"body_path": "aida-users/31/models/female/845046c7-4f62-4f54-a4a9-c26d49c6969335b5b3a9-d335-4871-a46c-3cc3caf07da259629dfd1f1f555a2e2a9def7e719366.png",
|
||||
"image_id": 82966,
|
||||
"offset": [
|
||||
1,
|
||||
1
|
||||
],
|
||||
"resize_scale": [
|
||||
1.0,
|
||||
1.0
|
||||
],
|
||||
"type": "Body"
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"basic": {
|
||||
"body_point_test": {
|
||||
"waistband_right": [
|
||||
336,
|
||||
264
|
||||
],
|
||||
"hand_point_right": [
|
||||
350,
|
||||
303
|
||||
],
|
||||
"waistband_left": [
|
||||
245,
|
||||
274
|
||||
],
|
||||
"hand_point_left": [
|
||||
219,
|
||||
315
|
||||
],
|
||||
"shoulder_left": [
|
||||
227,
|
||||
155
|
||||
],
|
||||
"shoulder_right": [
|
||||
338,
|
||||
149
|
||||
]
|
||||
},
|
||||
"layer_order": false,
|
||||
"scale_bag": 0.7,
|
||||
"scale_earrings": 0.16,
|
||||
"self_template": true,
|
||||
"single_overall": "overall",
|
||||
"switch_category": ""
|
||||
},
|
||||
"items": [
|
||||
{
|
||||
"color": "27 25 23",
|
||||
"icon": "none",
|
||||
"image_id": 110202,
|
||||
"offset": [
|
||||
1,
|
||||
1
|
||||
],
|
||||
"path": "aida-sys-image/images/female/skirt/0916000602.jpg",
|
||||
"print": {
|
||||
"IfSingle": false,
|
||||
"print_path_list": []
|
||||
},
|
||||
"resize_scale": [
|
||||
1.0,
|
||||
1.0
|
||||
],
|
||||
"type": "Skirt"
|
||||
},
|
||||
{
|
||||
"businessId": 493825,
|
||||
"color": "229 214 200",
|
||||
"elementId": 493825,
|
||||
"icon": "none",
|
||||
"image_id": 107101,
|
||||
"offset": [
|
||||
1,
|
||||
1
|
||||
],
|
||||
"path": "aida-users/31/sketchboard/female/Blouse/de8f5656-d7ae-4642-bc90-f7f9d85da09b.jpg",
|
||||
"print": {
|
||||
"IfSingle": false,
|
||||
"print_path_list": []
|
||||
},
|
||||
"resize_scale": [
|
||||
1.0,
|
||||
1.0
|
||||
],
|
||||
"type": "Blouse"
|
||||
},
|
||||
{
|
||||
"businessId": 493824,
|
||||
"color": "76 124 124",
|
||||
"elementId": 493824,
|
||||
"icon": "none",
|
||||
"image_id": 104522,
|
||||
"offset": [
|
||||
1,
|
||||
1
|
||||
],
|
||||
"path": "aida-users/31/sketch/3e82214a-0191-11ef-96d2-b48351119060_1.png",
|
||||
"print": {
|
||||
"IfSingle": false,
|
||||
"print_path_list": []
|
||||
},
|
||||
"resize_scale": [
|
||||
1.0,
|
||||
1.0
|
||||
],
|
||||
"type": "Outwear"
|
||||
},
|
||||
{
|
||||
"body_path": "aida-users/31/models/female/845046c7-4f62-4f54-a4a9-c26d49c6969335b5b3a9-d335-4871-a46c-3cc3caf07da259629dfd1f1f555a2e2a9def7e719366.png",
|
||||
"image_id": 82966,
|
||||
"offset": [
|
||||
1,
|
||||
1
|
||||
],
|
||||
"resize_scale": [
|
||||
1.0,
|
||||
1.0
|
||||
],
|
||||
"type": "Body"
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"basic": {
|
||||
"body_point_test": {
|
||||
"waistband_right": [
|
||||
336,
|
||||
264
|
||||
],
|
||||
"hand_point_right": [
|
||||
350,
|
||||
303
|
||||
],
|
||||
"waistband_left": [
|
||||
245,
|
||||
274
|
||||
],
|
||||
"hand_point_left": [
|
||||
219,
|
||||
315
|
||||
],
|
||||
"shoulder_left": [
|
||||
227,
|
||||
155
|
||||
],
|
||||
"shoulder_right": [
|
||||
338,
|
||||
149
|
||||
]
|
||||
},
|
||||
"layer_order": false,
|
||||
"scale_bag": 0.7,
|
||||
"scale_earrings": 0.16,
|
||||
"self_template": true,
|
||||
"single_overall": "overall",
|
||||
"switch_category": ""
|
||||
},
|
||||
"items": [
|
||||
{
|
||||
"color": "229 214 200",
|
||||
"icon": "none",
|
||||
"image_id": 110203,
|
||||
"offset": [
|
||||
1,
|
||||
1
|
||||
],
|
||||
"path": "aida-sys-image/images/female/blouse/0825001576.jpg",
|
||||
"print": {
|
||||
"IfSingle": false,
|
||||
"print_path_list": []
|
||||
},
|
||||
"resize_scale": [
|
||||
1.0,
|
||||
1.0
|
||||
],
|
||||
"type": "Blouse"
|
||||
},
|
||||
{
|
||||
"color": "76 124 124",
|
||||
"icon": "none",
|
||||
"image_id": 96071,
|
||||
"offset": [
|
||||
1,
|
||||
1
|
||||
],
|
||||
"path": "aida-sys-image/images/female/skirt/903000097.jpg",
|
||||
"print": {
|
||||
"IfSingle": false,
|
||||
"print_path_list": []
|
||||
},
|
||||
"resize_scale": [
|
||||
1.0,
|
||||
1.0
|
||||
],
|
||||
"type": "Skirt"
|
||||
},
|
||||
{
|
||||
"color": "209 125 29",
|
||||
"icon": "none",
|
||||
"image_id": 93798,
|
||||
"offset": [
|
||||
1,
|
||||
1
|
||||
],
|
||||
"path": "aida-sys-image/images/female/outwear/outwear_p4_561.jpg",
|
||||
"print": {
|
||||
"IfSingle": false,
|
||||
"print_path_list": []
|
||||
},
|
||||
"resize_scale": [
|
||||
1.0,
|
||||
1.0
|
||||
],
|
||||
"type": "Outwear"
|
||||
},
|
||||
{
|
||||
"body_path": "aida-users/31/models/female/845046c7-4f62-4f54-a4a9-c26d49c6969335b5b3a9-d335-4871-a46c-3cc3caf07da259629dfd1f1f555a2e2a9def7e719366.png",
|
||||
"image_id": 82966,
|
||||
"offset": [
|
||||
1,
|
||||
1
|
||||
],
|
||||
"resize_scale": [
|
||||
1.0,
|
||||
1.0
|
||||
],
|
||||
"type": "Body"
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"basic": {
|
||||
"body_point_test": {
|
||||
"waistband_right": [
|
||||
336,
|
||||
264
|
||||
],
|
||||
"hand_point_right": [
|
||||
350,
|
||||
303
|
||||
],
|
||||
"waistband_left": [
|
||||
245,
|
||||
274
|
||||
],
|
||||
"hand_point_left": [
|
||||
219,
|
||||
315
|
||||
],
|
||||
"shoulder_left": [
|
||||
227,
|
||||
155
|
||||
],
|
||||
"shoulder_right": [
|
||||
338,
|
||||
149
|
||||
]
|
||||
},
|
||||
"layer_order": false,
|
||||
"scale_bag": 0.7,
|
||||
"scale_earrings": 0.16,
|
||||
"self_template": true,
|
||||
"single_overall": "overall",
|
||||
"switch_category": ""
|
||||
},
|
||||
"items": [
|
||||
{
|
||||
"businessId": 493824,
|
||||
"color": "209 125 29",
|
||||
"elementId": 493824,
|
||||
"icon": "none",
|
||||
"image_id": 104522,
|
||||
"offset": [
|
||||
1,
|
||||
1
|
||||
],
|
||||
"path": "aida-users/31/sketch/3e82214a-0191-11ef-96d2-b48351119060_1.png",
|
||||
"print": {
|
||||
"IfSingle": false,
|
||||
"print_path_list": []
|
||||
},
|
||||
"resize_scale": [
|
||||
1.0,
|
||||
1.0
|
||||
],
|
||||
"type": "Outwear"
|
||||
},
|
||||
{
|
||||
"color": "118 123 115",
|
||||
"icon": "none",
|
||||
"image_id": 110204,
|
||||
"offset": [
|
||||
1,
|
||||
1
|
||||
],
|
||||
"path": "aida-sys-image/images/female/blouse/0902000457.jpg",
|
||||
"print": {
|
||||
"IfSingle": false,
|
||||
"print_path_list": []
|
||||
},
|
||||
"resize_scale": [
|
||||
1.0,
|
||||
1.0
|
||||
],
|
||||
"type": "Blouse"
|
||||
},
|
||||
{
|
||||
"color": "118 123 115",
|
||||
"icon": "none",
|
||||
"image_id": 79259,
|
||||
"offset": [
|
||||
1,
|
||||
1
|
||||
],
|
||||
"path": "aida-sys-image/images/female/trousers/826000094.jpg",
|
||||
"print": {
|
||||
"IfSingle": false,
|
||||
"print_path_list": []
|
||||
},
|
||||
"resize_scale": [
|
||||
1.0,
|
||||
1.0
|
||||
],
|
||||
"type": "Trousers"
|
||||
},
|
||||
{
|
||||
"body_path": "aida-users/31/models/female/845046c7-4f62-4f54-a4a9-c26d49c6969335b5b3a9-d335-4871-a46c-3cc3caf07da259629dfd1f1f555a2e2a9def7e719366.png",
|
||||
"image_id": 82966,
|
||||
"offset": [
|
||||
1,
|
||||
1
|
||||
],
|
||||
"resize_scale": [
|
||||
1.0,
|
||||
1.0
|
||||
],
|
||||
"type": "Body"
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"basic": {
|
||||
"body_point_test": {
|
||||
"waistband_right": [
|
||||
336,
|
||||
264
|
||||
],
|
||||
"hand_point_right": [
|
||||
350,
|
||||
303
|
||||
],
|
||||
"waistband_left": [
|
||||
245,
|
||||
274
|
||||
],
|
||||
"hand_point_left": [
|
||||
219,
|
||||
315
|
||||
],
|
||||
"shoulder_left": [
|
||||
227,
|
||||
155
|
||||
],
|
||||
"shoulder_right": [
|
||||
338,
|
||||
149
|
||||
]
|
||||
},
|
||||
"layer_order": false,
|
||||
"scale_bag": 0.7,
|
||||
"scale_earrings": 0.16,
|
||||
"self_template": true,
|
||||
"single_overall": "overall",
|
||||
"switch_category": ""
|
||||
},
|
||||
"items": [
|
||||
{
|
||||
"color": "127 61 21",
|
||||
"icon": "none",
|
||||
"image_id": 96038,
|
||||
"offset": [
|
||||
1,
|
||||
1
|
||||
],
|
||||
"path": "aida-sys-image/images/female/dress/0902003549.jpg",
|
||||
"print": {
|
||||
"IfSingle": false,
|
||||
"print_path_list": []
|
||||
},
|
||||
"resize_scale": [
|
||||
1.0,
|
||||
1.0
|
||||
],
|
||||
"type": "Dress"
|
||||
},
|
||||
{
|
||||
"body_path": "aida-users/31/models/female/845046c7-4f62-4f54-a4a9-c26d49c6969335b5b3a9-d335-4871-a46c-3cc3caf07da259629dfd1f1f555a2e2a9def7e719366.png",
|
||||
"image_id": 82966,
|
||||
"offset": [
|
||||
1,
|
||||
1
|
||||
],
|
||||
"resize_scale": [
|
||||
1.0,
|
||||
1.0
|
||||
],
|
||||
"type": "Body"
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"basic": {
|
||||
"body_point_test": {
|
||||
"waistband_right": [
|
||||
336,
|
||||
264
|
||||
],
|
||||
"hand_point_right": [
|
||||
350,
|
||||
303
|
||||
],
|
||||
"waistband_left": [
|
||||
245,
|
||||
274
|
||||
],
|
||||
"hand_point_left": [
|
||||
219,
|
||||
315
|
||||
],
|
||||
"shoulder_left": [
|
||||
227,
|
||||
155
|
||||
],
|
||||
"shoulder_right": [
|
||||
338,
|
||||
149
|
||||
]
|
||||
},
|
||||
"layer_order": false,
|
||||
"scale_bag": 0.7,
|
||||
"scale_earrings": 0.16,
|
||||
"self_template": true,
|
||||
"single_overall": "overall",
|
||||
"switch_category": ""
|
||||
},
|
||||
"items": [
|
||||
{
|
||||
"businessId": 493822,
|
||||
"color": "127 61 21",
|
||||
"elementId": 493822,
|
||||
"icon": "none",
|
||||
"image_id": 62309,
|
||||
"offset": [
|
||||
1,
|
||||
1
|
||||
],
|
||||
"path": "aida-users/31/sketchboard/female/trousers/c37c2ea6-8955-4b40-8339-c737e672ca3d.jpg",
|
||||
"print": {
|
||||
"IfSingle": false,
|
||||
"print_path_list": []
|
||||
},
|
||||
"resize_scale": [
|
||||
1.0,
|
||||
1.0
|
||||
],
|
||||
"type": "Trousers"
|
||||
},
|
||||
{
|
||||
"businessId": 493825,
|
||||
"color": "118 123 115",
|
||||
"elementId": 493825,
|
||||
"icon": "none",
|
||||
"image_id": 107101,
|
||||
"offset": [
|
||||
1,
|
||||
1
|
||||
],
|
||||
"path": "aida-users/31/sketchboard/female/Blouse/de8f5656-d7ae-4642-bc90-f7f9d85da09b.jpg",
|
||||
"print": {
|
||||
"IfSingle": false,
|
||||
"print_path_list": []
|
||||
},
|
||||
"resize_scale": [
|
||||
1.0,
|
||||
1.0
|
||||
],
|
||||
"type": "Blouse"
|
||||
},
|
||||
{
|
||||
"body_path": "aida-users/31/models/female/845046c7-4f62-4f54-a4a9-c26d49c6969335b5b3a9-d335-4871-a46c-3cc3caf07da259629dfd1f1f555a2e2a9def7e719366.png",
|
||||
"image_id": 82966,
|
||||
"offset": [
|
||||
1,
|
||||
1
|
||||
],
|
||||
"resize_scale": [
|
||||
1.0,
|
||||
1.0
|
||||
],
|
||||
"type": "Body"
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"basic": {
|
||||
"body_point_test": {
|
||||
"waistband_right": [
|
||||
336,
|
||||
264
|
||||
],
|
||||
"hand_point_right": [
|
||||
350,
|
||||
303
|
||||
],
|
||||
"waistband_left": [
|
||||
245,
|
||||
274
|
||||
],
|
||||
"hand_point_left": [
|
||||
219,
|
||||
315
|
||||
],
|
||||
"shoulder_left": [
|
||||
227,
|
||||
155
|
||||
],
|
||||
"shoulder_right": [
|
||||
338,
|
||||
149
|
||||
]
|
||||
},
|
||||
"layer_order": false,
|
||||
"scale_bag": 0.7,
|
||||
"scale_earrings": 0.16,
|
||||
"self_template": true,
|
||||
"single_overall": "overall",
|
||||
"switch_category": ""
|
||||
},
|
||||
"items": [
|
||||
{
|
||||
"businessId": 493826,
|
||||
"color": "127 61 21",
|
||||
"elementId": 493826,
|
||||
"icon": "none",
|
||||
"image_id": 107105,
|
||||
"offset": [
|
||||
1,
|
||||
1
|
||||
],
|
||||
"path": "aida-users/31/sketchboard/female/Skirt/58710352-6301-450d-b69a-fb2922b5429a.png",
|
||||
"print": {
|
||||
"IfSingle": false,
|
||||
"print_path_list": []
|
||||
},
|
||||
"resize_scale": [
|
||||
1.0,
|
||||
1.0
|
||||
],
|
||||
"type": "Skirt"
|
||||
},
|
||||
{
|
||||
"color": "118 123 115",
|
||||
"icon": "none",
|
||||
"image_id": 79114,
|
||||
"offset": [
|
||||
1,
|
||||
1
|
||||
],
|
||||
"path": "aida-sys-image/images/female/blouse/903000169.jpg",
|
||||
"print": {
|
||||
"IfSingle": false,
|
||||
"print_path_list": []
|
||||
},
|
||||
"resize_scale": [
|
||||
1.0,
|
||||
1.0
|
||||
],
|
||||
"type": "Blouse"
|
||||
},
|
||||
{
|
||||
"color": "229 214 200",
|
||||
"icon": "none",
|
||||
"image_id": 90573,
|
||||
"offset": [
|
||||
1,
|
||||
1
|
||||
],
|
||||
"path": "aida-sys-image/images/female/outwear/0628000541.jpg",
|
||||
"print": {
|
||||
"IfSingle": false,
|
||||
"print_path_list": []
|
||||
},
|
||||
"resize_scale": [
|
||||
1.0,
|
||||
1.0
|
||||
],
|
||||
"type": "Outwear"
|
||||
},
|
||||
{
|
||||
"body_path": "aida-users/31/models/female/845046c7-4f62-4f54-a4a9-c26d49c6969335b5b3a9-d335-4871-a46c-3cc3caf07da259629dfd1f1f555a2e2a9def7e719366.png",
|
||||
"image_id": 82966,
|
||||
"offset": [
|
||||
1,
|
||||
1
|
||||
],
|
||||
"resize_scale": [
|
||||
1.0,
|
||||
1.0
|
||||
],
|
||||
"type": "Body"
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"basic": {
|
||||
"body_point_test": {
|
||||
"waistband_right": [
|
||||
336,
|
||||
264
|
||||
],
|
||||
"hand_point_right": [
|
||||
350,
|
||||
303
|
||||
],
|
||||
"waistband_left": [
|
||||
245,
|
||||
274
|
||||
],
|
||||
"hand_point_left": [
|
||||
219,
|
||||
315
|
||||
],
|
||||
"shoulder_left": [
|
||||
227,
|
||||
155
|
||||
],
|
||||
"shoulder_right": [
|
||||
338,
|
||||
149
|
||||
]
|
||||
},
|
||||
"layer_order": false,
|
||||
"scale_bag": 0.7,
|
||||
"scale_earrings": 0.16,
|
||||
"self_template": true,
|
||||
"single_overall": "overall",
|
||||
"switch_category": ""
|
||||
},
|
||||
"items": [
|
||||
{
|
||||
"color": "229 214 200",
|
||||
"icon": "none",
|
||||
"image_id": 110205,
|
||||
"offset": [
|
||||
1,
|
||||
1
|
||||
],
|
||||
"path": "aida-sys-image/images/female/trousers/0916000217.jpg",
|
||||
"print": {
|
||||
"IfSingle": false,
|
||||
"print_path_list": []
|
||||
},
|
||||
"resize_scale": [
|
||||
1.0,
|
||||
1.0
|
||||
],
|
||||
"type": "Trousers"
|
||||
},
|
||||
{
|
||||
"businessId": 493825,
|
||||
"color": "209 125 29",
|
||||
"elementId": 493825,
|
||||
"icon": "none",
|
||||
"image_id": 107101,
|
||||
"offset": [
|
||||
1,
|
||||
1
|
||||
],
|
||||
"path": "aida-users/31/sketchboard/female/Blouse/de8f5656-d7ae-4642-bc90-f7f9d85da09b.jpg",
|
||||
"print": {
|
||||
"IfSingle": false,
|
||||
"print_path_list": []
|
||||
},
|
||||
"resize_scale": [
|
||||
1.0,
|
||||
1.0
|
||||
],
|
||||
"type": "Blouse"
|
||||
},
|
||||
{
|
||||
"body_path": "aida-users/31/models/female/845046c7-4f62-4f54-a4a9-c26d49c6969335b5b3a9-d335-4871-a46c-3cc3caf07da259629dfd1f1f555a2e2a9def7e719366.png",
|
||||
"image_id": 82966,
|
||||
"offset": [
|
||||
1,
|
||||
1
|
||||
],
|
||||
"resize_scale": [
|
||||
1.0,
|
||||
1.0
|
||||
],
|
||||
"type": "Body"
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
"process_id": "6878547032381675"
|
||||
}
|
||||
16
app/service/design/items/__init__.py
Normal file
16
app/service/design/items/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
from .builder import ITEMS, build_item
|
||||
from .clothing import Clothing # 4.0 sec
|
||||
from .body import Body
|
||||
from .top import Top, Blouse, Outwear, Dress
|
||||
from .bottom import Bottom, Trousers, Skirt
|
||||
from .shoes import Shoes
|
||||
from .bag import Bag
|
||||
from .accessories import Hairstyle, Earring
|
||||
|
||||
__all__ = [
|
||||
'ITEMS', 'build_item',
|
||||
'Clothing', 'Body',
|
||||
'Top', 'Blouse', 'Outwear', 'Dress',
|
||||
'Bottom', 'Trousers', 'Skirt',
|
||||
'Shoes', 'Bag', 'Hairstyle', 'Earring'
|
||||
]
|
||||
59
app/service/design/items/accessories.py
Normal file
59
app/service/design/items/accessories.py
Normal file
@@ -0,0 +1,59 @@
|
||||
from .builder import ITEMS
|
||||
from .clothing import Clothing
|
||||
|
||||
|
||||
@ITEMS.register_module()
|
||||
class Hairstyle(Clothing):
|
||||
def __init__(self, **kwargs):
|
||||
pipeline = [
|
||||
dict(type='LoadImageFromFile', path=kwargs['path']),
|
||||
dict(type='KeypointDetection'),
|
||||
dict(type='ContourDetection'),
|
||||
dict(type='Painting'),
|
||||
dict(type='Scaling'),
|
||||
dict(type='Split'),
|
||||
# dict(type='ImageShow', key=['image', 'mask', 'pattern_image']),
|
||||
]
|
||||
kwargs.update(pipeline=pipeline)
|
||||
super(Hairstyle, self).__init__(**kwargs)
|
||||
|
||||
@staticmethod
|
||||
def calculate_start_point(keypoint_type, scale, clothes_point, body_point):
|
||||
"""
|
||||
align up
|
||||
Args:
|
||||
keypoint_type: string, "head_point"
|
||||
scale: float
|
||||
clothes_point: dict{'left': [x1, y1, z1], 'right': [x2, y2, z2]}
|
||||
body_point: dict, containing keypoint data of body figure
|
||||
|
||||
Returns:
|
||||
start_point: tuple (x', y')
|
||||
x' = y_body - y1 * scale
|
||||
y' = x_body - x1 * scale
|
||||
"""
|
||||
side_indicator = f'{keypoint_type}_up'
|
||||
# clothes_point = {k: tuple(map(lambda x: int(scale * x), v[0: 2])) for k, v in clothes_point.items()}
|
||||
# logging.info(clothes_point[side_indicator])
|
||||
|
||||
start_point = (
|
||||
int(body_point[side_indicator][1] - int(clothes_point[side_indicator].split("_")[1] * scale)),
|
||||
int(body_point[side_indicator][0] - int(clothes_point[side_indicator].split("_")[0] * scale))
|
||||
)
|
||||
return start_point
|
||||
|
||||
|
||||
@ITEMS.register_module()
|
||||
class Earring(Clothing):
|
||||
def __init__(self, **kwargs):
|
||||
pipeline = [
|
||||
dict(type='LoadImageFromFile', path=kwargs['path']),
|
||||
dict(type='KeypointDetection'),
|
||||
dict(type='ContourDetection'),
|
||||
dict(type='Painting'),
|
||||
dict(type='Scaling'),
|
||||
dict(type='Split'),
|
||||
# dict(type='ImageShow', key=['image', 'mask', 'pattern_image']),
|
||||
]
|
||||
kwargs.update(pipeline=pipeline)
|
||||
super(Earring, self).__init__(**kwargs)
|
||||
45
app/service/design/items/bag.py
Normal file
45
app/service/design/items/bag.py
Normal file
@@ -0,0 +1,45 @@
|
||||
import random
|
||||
|
||||
from .builder import ITEMS
|
||||
from .clothing import Clothing
|
||||
|
||||
|
||||
@ITEMS.register_module()
|
||||
class Bag(Clothing):
|
||||
def __init__(self, **kwargs):
|
||||
pipeline = [
|
||||
dict(type='LoadImageFromFile', path=kwargs['path'], color=kwargs['color']),
|
||||
dict(type='KeypointDetection'),
|
||||
dict(type='ContourDetection'),
|
||||
dict(type='Painting'),
|
||||
dict(type='Scaling'),
|
||||
dict(type='Split'),
|
||||
# dict(type='ImageShow', key=['image', 'mask', 'pattern_image']),
|
||||
]
|
||||
kwargs.update(pipeline=pipeline)
|
||||
super(Bag, self).__init__(**kwargs)
|
||||
|
||||
@staticmethod
|
||||
def calculate_start_point(keypoint_type, scale, clothes_point, body_point):
|
||||
"""
|
||||
align left
|
||||
Args:
|
||||
keypoint_type: string, "hand_point"
|
||||
scale: float
|
||||
clothes_point: dict{'left': [x1, y1, z1], 'right': [x2, y2, z2]}
|
||||
body_point: dict, containing keypoint data of body figure
|
||||
|
||||
Returns:
|
||||
start_point: tuple (y', x')
|
||||
x' = y_body - y1 * scale
|
||||
y' = x_body - x1 * scale
|
||||
"""
|
||||
location = random.choice(seq=['left', 'right'])
|
||||
if location == 'left':
|
||||
side_indicator = f'{keypoint_type}_left'
|
||||
else:
|
||||
side_indicator = f'{keypoint_type}_right'
|
||||
# clothes_point = {k: tuple(map(lambda x: int(scale * x), v[0: 2])) for k, v in clothes_point.items()}
|
||||
start_point = (body_point[side_indicator][1] - int(int(clothes_point[keypoint_type].split("_")[1]) * scale),
|
||||
body_point[side_indicator][0] - int(int(clothes_point[keypoint_type].split("_")[0]) * scale))
|
||||
return start_point
|
||||
36
app/service/design/items/body.py
Normal file
36
app/service/design/items/body.py
Normal file
@@ -0,0 +1,36 @@
|
||||
import cv2
|
||||
|
||||
from .builder import ITEMS
|
||||
from .pipelines import Compose
|
||||
|
||||
|
||||
@ITEMS.register_module()
|
||||
class Body(object):
|
||||
def __init__(self, **kwargs):
|
||||
pipeline = [
|
||||
dict(type='LoadBodyImageFromFile', body_path=kwargs['body_path']),
|
||||
# dict(type='ImageShow', key=['body_image', "body_mask"])
|
||||
]
|
||||
self.pipeline = Compose(pipeline)
|
||||
self.result = dict()
|
||||
|
||||
def process(self):
|
||||
self.pipeline(self.result)
|
||||
pass
|
||||
|
||||
def organize(self, layer):
|
||||
body_layer = dict(priority=0,
|
||||
name=type(self).__name__.lower(),
|
||||
image=self.result['body_image'],
|
||||
image_url=self.result['image_url'],
|
||||
mask_image=None,
|
||||
mask_url=None,
|
||||
sacle=1,
|
||||
# mask=self.result['body_mask'],
|
||||
position=(0, 0))
|
||||
layer.insert(body_layer)
|
||||
|
||||
@staticmethod
|
||||
def show(img):
|
||||
cv2.imshow('', img)
|
||||
cv2.waitKey(0)
|
||||
38
app/service/design/items/bottom.py
Normal file
38
app/service/design/items/bottom.py
Normal file
@@ -0,0 +1,38 @@
|
||||
from .builder import ITEMS
|
||||
from .clothing import Clothing
|
||||
|
||||
|
||||
@ITEMS.register_module()
|
||||
class Bottom(Clothing):
|
||||
def __init__(self, pipeline, **kwargs):
|
||||
if pipeline is None:
|
||||
pipeline = [
|
||||
dict(type='LoadImageFromFile', path=kwargs['path'], color=kwargs['color'], print_dict=kwargs['print']),
|
||||
dict(type='KeypointDetection'),
|
||||
dict(type='ContourDetection'),
|
||||
dict(type='Painting', painting_flag=True),
|
||||
dict(type='PrintPainting', print_flag=True),
|
||||
dict(type='Scaling'),
|
||||
dict(type='Split'),
|
||||
# dict(type='ImageShow', key=['image', 'mask', 'pattern_image', 'print_image']),
|
||||
]
|
||||
kwargs.update(pipeline=pipeline)
|
||||
super(Bottom, self).__init__(**kwargs)
|
||||
|
||||
|
||||
@ITEMS.register_module()
|
||||
class Trousers(Bottom):
|
||||
def __init__(self, pipeline=None, **kwargs):
|
||||
super(Trousers, self).__init__(pipeline, **kwargs)
|
||||
|
||||
|
||||
@ITEMS.register_module()
|
||||
class Skirt(Bottom):
|
||||
def __init__(self, pipeline=None, **kwargs):
|
||||
super(Skirt, self).__init__(pipeline, **kwargs)
|
||||
|
||||
|
||||
@ITEMS.register_module()
|
||||
class Bottoms(Bottom):
|
||||
def __init__(self, pipeline=None, **kwargs):
|
||||
super(Bottoms, self).__init__(pipeline, **kwargs)
|
||||
9
app/service/design/items/builder.py
Normal file
9
app/service/design/items/builder.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from mmcv.utils import Registry, build_from_cfg
|
||||
|
||||
ITEMS = Registry('item')
|
||||
PIPELINES = Registry('pipeline')
|
||||
|
||||
|
||||
def build_item(cfg, default_args=None):
|
||||
item = build_from_cfg(cfg, ITEMS, default_args)
|
||||
return item
|
||||
99
app/service/design/items/clothing.py
Normal file
99
app/service/design/items/clothing.py
Normal file
@@ -0,0 +1,99 @@
|
||||
import cv2
|
||||
|
||||
from app.core.config import PRIORITY_DICT
|
||||
from .builder import ITEMS
|
||||
from .pipelines import Compose
|
||||
|
||||
|
||||
@ITEMS.register_module()
|
||||
class Clothing(object):
|
||||
def __init__(self, pipeline, **kwargs):
|
||||
self.pipeline = Compose(pipeline)
|
||||
self.result = dict(name=type(self).__name__.lower(), **kwargs)
|
||||
|
||||
def process(self):
|
||||
self.pipeline(self.result)
|
||||
|
||||
def apply_scale(self, img):
|
||||
scale = self.result['scale']
|
||||
height, width = img.shape[0: 2]
|
||||
if len(img.shape) > 2:
|
||||
height, width = img.shape[0: 2]
|
||||
scaled_img = cv2.resize(img, (int(width * scale), int(height * scale)), interpolation=cv2.INTER_AREA)
|
||||
return scaled_img
|
||||
|
||||
def organize(self, layer):
|
||||
start_point = self.calculate_start_point(self.result['keypoint'], self.result['scale'], self.result['clothes_keypoint'], self.result['body_point_test'], self.result["offset"], self.result["resize_scale"])
|
||||
|
||||
front_layer = dict(priority=self.result.get("priority", None) if self.result.get("layer_order", False) else PRIORITY_DICT.get(f'{type(self).__name__.lower()}_front', None),
|
||||
name=f'{type(self).__name__.lower()}_front',
|
||||
image=self.result["front_image"],
|
||||
# mask_image=self.result['front_mask_image'],
|
||||
image_url=self.result['front_image_url'],
|
||||
mask_url=self.result['front_mask_url'],
|
||||
sacle=self.result['scale'],
|
||||
clothes_keypoint=self.result['clothes_keypoint'],
|
||||
position=start_point,
|
||||
resize_scale=self.result["resize_scale"],
|
||||
mask=cv2.resize(self.result['mask'], self.result["front_image"].size),
|
||||
gradient_string=self.result['gradient_string'] if 'gradient_string' in self.result.keys() else "",
|
||||
pattern_image_url=self.result['pattern_image_url']
|
||||
|
||||
)
|
||||
layer.insert(front_layer)
|
||||
|
||||
back_layer = dict(priority=-self.result.get("priority", 0) if self.result.get("layer_order", False) else PRIORITY_DICT.get(f'{type(self).__name__.lower()}_back', None),
|
||||
name=f'{type(self).__name__.lower()}_back',
|
||||
image=self.result["back_image"],
|
||||
# mask_image=self.result['back_mask_image'],
|
||||
image_url=self.result['back_image_url'],
|
||||
mask_url=self.result['back_mask_url'],
|
||||
sacle=self.result['scale'],
|
||||
clothes_keypoint=self.result['clothes_keypoint'],
|
||||
position=start_point,
|
||||
resize_scale=self.result["resize_scale"],
|
||||
mask=cv2.resize(self.result['mask'], self.result["front_image"].size),
|
||||
gradient_string=self.result['gradient_string'] if 'gradient_string' in self.result.keys() else "",
|
||||
pattern_image_url=self.result['pattern_image_url']
|
||||
)
|
||||
layer.insert(back_layer)
|
||||
|
||||
@staticmethod
|
||||
def calculate_start_point(keypoint_type, scale, clothes_point, body_point, offset, resize_scale):
|
||||
"""
|
||||
Align left
|
||||
Args:
|
||||
keypoint_type: string, "waistband" | "shoulder" | "ear_point"
|
||||
scale: float
|
||||
clothes_point: dict{'left': [x1, y1, z1], 'right': [x2, y2, z2]}
|
||||
body_point: dict, containing keypoint data of body figure
|
||||
|
||||
Returns:
|
||||
start_point: tuple (x', y')
|
||||
x' = y_body - y1 * scale + offset
|
||||
y' = x_body - x1 * scale + offset
|
||||
|
||||
"""
|
||||
|
||||
side_indicator = f'{keypoint_type}_left'
|
||||
|
||||
# if keypoint_type == "ear_point":
|
||||
# start_point = (body_point[side_indicator][1] - int(int(clothes_point[side_indicator].split("_")[1]) * scale),
|
||||
# body_point[side_indicator][0] - int(int(clothes_point[side_indicator].split("_")[0]) * scale))
|
||||
# else:
|
||||
# start_point = (
|
||||
# int(body_point[side_indicator][1] + offset[1] - int(clothes_point[side_indicator].split("_")[0]) * scale), # y
|
||||
# int(body_point[side_indicator][0] + offset[0] - int(clothes_point[side_indicator].split("_")[1]) * scale) # x
|
||||
# )
|
||||
|
||||
# milvus_DB_keypoint_cache:
|
||||
start_point = (
|
||||
int(body_point[side_indicator][1] + offset[1] - int(clothes_point[side_indicator][0]) * scale), # y
|
||||
int(body_point[side_indicator][0] + offset[0] - int(clothes_point[side_indicator][1]) * scale) # x
|
||||
)
|
||||
# start_point = (
|
||||
# int(body_point[side_indicator][1] + offset[1] - int(clothes_point[side_indicator].split("_")[0]) * scale), # y
|
||||
# int(body_point[side_indicator][0] + offset[0] - int(clothes_point[side_indicator].split("_")[1]) * scale) # x
|
||||
# )
|
||||
|
||||
return start_point
|
||||
19
app/service/design/items/pipelines/__init__.py
Normal file
19
app/service/design/items/pipelines/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
||||
from .compose import Compose
|
||||
from .loading import LoadImageFromFile, LoadBodyImageFromFile, ImageShow
|
||||
from .keypoints import KeypointDetection
|
||||
from .segmentation import Segmentation
|
||||
from .painting import Painting, PrintPainting
|
||||
from .scale import Scaling
|
||||
from .contour_detection import ContourDetection
|
||||
from .split import Split
|
||||
|
||||
__all__ = [
|
||||
'Compose',
|
||||
'LoadImageFromFile', 'LoadBodyImageFromFile', 'ImageShow',
|
||||
'KeypointDetection',
|
||||
'Segmentation',
|
||||
'Painting', 'PrintPainting',
|
||||
'Scaling',
|
||||
'ContourDetection',
|
||||
'split',
|
||||
]
|
||||
36
app/service/design/items/pipelines/compose.py
Normal file
36
app/service/design/items/pipelines/compose.py
Normal file
@@ -0,0 +1,36 @@
|
||||
import collections
|
||||
|
||||
from mmcv.utils import build_from_cfg
|
||||
|
||||
from ..builder import PIPELINES
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
class Compose(object):
|
||||
def __init__(self, transforms):
|
||||
assert isinstance(transforms, collections.abc.Sequence)
|
||||
self.transforms = []
|
||||
for transform in transforms:
|
||||
if isinstance(transform, dict):
|
||||
transform = build_from_cfg(transform, PIPELINES)
|
||||
self.transforms.append(transform)
|
||||
elif callable(transform):
|
||||
self.transforms.append(transform)
|
||||
else:
|
||||
raise TypeError('transform must be callable or a dict')
|
||||
|
||||
def __call__(self, data):
|
||||
"""Call function to apply transforms sequentially.
|
||||
|
||||
Args:
|
||||
data (dict): A result dict contains the data to transform.
|
||||
|
||||
Returns:
|
||||
dict: Transformed data.
|
||||
"""
|
||||
|
||||
for t in self.transforms:
|
||||
data = t(data)
|
||||
if data is None:
|
||||
return None
|
||||
return data
|
||||
58
app/service/design/items/pipelines/contour_detection.py
Normal file
58
app/service/design/items/pipelines/contour_detection.py
Normal file
@@ -0,0 +1,58 @@
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from ..builder import PIPELINES
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
class ContourDetection(object):
|
||||
def __init__(self):
|
||||
# logging.info("ContourDetection run ")
|
||||
pass
|
||||
|
||||
# @ RunTime
|
||||
def __call__(self, result):
|
||||
# shoe diff
|
||||
if result['name'] == 'shoes':
|
||||
Contour = self.get_contours(result['image'])
|
||||
Mask = np.zeros(result['image'].shape[:2], np.uint8)
|
||||
for i in range(2):
|
||||
Max_contour = Contour[i]
|
||||
Epsilon = 0.001 * cv2.arcLength(Max_contour, True)
|
||||
Approx = cv2.approxPolyDP(Max_contour, Epsilon, True)
|
||||
cv2.drawContours(Mask, [Approx], -1, 255, -1)
|
||||
if result['pre_mask'] is None:
|
||||
result['mask'] = Mask
|
||||
else:
|
||||
result['mask'] = cv2.bitwise_and(Mask, result['pre_mask'])
|
||||
else:
|
||||
Contour = self.get_contours(result['image'])
|
||||
Mask = np.zeros(result['image'].shape[:2], np.uint8)
|
||||
if len(Contour):
|
||||
Max_contour = Contour[0]
|
||||
Epsilon = 0.001 * cv2.arcLength(Max_contour, True)
|
||||
Approx = cv2.approxPolyDP(Max_contour, Epsilon, True)
|
||||
cv2.drawContours(Mask, [Approx], -1, 255, -1)
|
||||
else:
|
||||
Mask = np.ones(result['image'].shape[:2], np.uint8) * 255
|
||||
# TODO 修复部分图片出现透明的情况 下版本上线
|
||||
# img2gray = cv2.cvtColor(result['image'], cv2.COLOR_BGR2GRAY)
|
||||
# ret, Mask = cv2.threshold(img2gray, 126, 255, cv2.THRESH_BINARY)
|
||||
# Mask = cv2.bitwise_not(Mask)
|
||||
if result['pre_mask'] is None:
|
||||
result['mask'] = Mask
|
||||
else:
|
||||
result['mask'] = cv2.bitwise_and(Mask, result['pre_mask'])
|
||||
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def get_contours(image):
|
||||
gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
|
||||
Edge = cv2.Canny(gray, 10, 150)
|
||||
kernel = np.ones((5, 5), np.uint8)
|
||||
Edge = cv2.dilate(Edge, kernel=kernel, iterations=1)
|
||||
Edge = cv2.erode(Edge, kernel=kernel, iterations=1)
|
||||
Contour, _ = cv2.findContours(Edge, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
||||
Contour = sorted(Contour, key=cv2.contourArea, reverse=True)
|
||||
return Contour
|
||||
139
app/service/design/items/pipelines/keypoints.py
Normal file
139
app/service/design/items/pipelines/keypoints.py
Normal file
@@ -0,0 +1,139 @@
|
||||
import logging
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
from pymilvus import MilvusClient
|
||||
|
||||
from app.core.config import *
|
||||
from ..builder import PIPELINES
|
||||
from ...utils.design_ensemble import get_keypoint_result
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
class KeypointDetection(object):
|
||||
"""
|
||||
path here: abstract path
|
||||
"""
|
||||
|
||||
# def __init__(self):
|
||||
# self.client = MilvusClient(
|
||||
# uri="http://10.1.1.240:19530",
|
||||
# token="root:Milvus",
|
||||
# db_name=MILVUS_ALIAS
|
||||
# )
|
||||
|
||||
# def __del__(self):
|
||||
# start_time = time.time()
|
||||
# self.client.close()
|
||||
# print(f"client close time : {time.time() - start_time}")
|
||||
|
||||
# @ RunTime
|
||||
def __call__(self, result):
|
||||
# logging.info("KeypointDetection run ")
|
||||
if result['name'] in ['blouse', 'skirt', 'dress', 'outwear', 'trousers', 'tops', 'bottoms']: # 查询是否有数据 且类别相同 相同则直接读 不同则推理后更新
|
||||
# result['clothes_keypoint'] = self.infer_keypoint_result(result)
|
||||
site = 'up' if result['name'] in ['blouse', 'outwear', 'dress', 'tops'] else 'down'
|
||||
# keypoint_cache = search_keypoint_cache(result["image_id"], site)
|
||||
|
||||
keypoint_cache = self.keypoint_cache(result, site)
|
||||
# 取消向量查询 直接过模型推理
|
||||
# keypoint_cache = False
|
||||
|
||||
if keypoint_cache is False:
|
||||
keypoint_infer_result, site = self.infer_keypoint_result(result)
|
||||
result['clothes_keypoint'] = self.save_keypoint_cache(result["image_id"], keypoint_infer_result, site)
|
||||
else:
|
||||
result['clothes_keypoint'] = keypoint_cache
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def infer_keypoint_result(result):
|
||||
site = 'up' if result['name'] in ['blouse', 'outwear', 'dress', 'tops'] else 'down'
|
||||
start_time = time.time()
|
||||
keypoint_infer_result = get_keypoint_result(result["image"], site) # 推理结果
|
||||
# logging.info(f"infer keypoint time : {time.time() - start_time}")
|
||||
return keypoint_infer_result, site
|
||||
|
||||
@staticmethod
|
||||
# @ RunTime
|
||||
def save_keypoint_cache(keypoint_id, cache, site):
|
||||
if site == "down":
|
||||
zeros = np.zeros(20, dtype=int)
|
||||
result = np.concatenate([zeros, cache.flatten()])
|
||||
else:
|
||||
zeros = np.zeros(4, dtype=int)
|
||||
result = np.concatenate([cache.flatten(), zeros])
|
||||
# 取消向量保存 直接拿结果
|
||||
data = [
|
||||
{"keypoint_id": keypoint_id,
|
||||
"keypoint_site": site,
|
||||
"keypoint_vector": result.tolist()
|
||||
}
|
||||
]
|
||||
try:
|
||||
client = MilvusClient(uri=MILVUS_URL, token=MILVUS_TOKEN, db_name=MILVUS_ALIAS)
|
||||
# start_time = time.time()
|
||||
res = client.upsert(collection_name=MILVUS_TABLE_KEYPOINT, data=data)
|
||||
# logging.info(f"save keypoint time : {time.time() - start_time}")
|
||||
client.close()
|
||||
return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
|
||||
except Exception as e:
|
||||
logging.info(f"save keypoint cache milvus error : {e}")
|
||||
return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
|
||||
|
||||
@staticmethod
|
||||
def update_keypoint_cache(keypoint_id, infer_result, search_result, site):
|
||||
if site == "up":
|
||||
# 需要的是up 即推理出来的是up 那么查询的就是down
|
||||
result = np.concatenate([infer_result.flatten(), search_result[-4:]])
|
||||
else:
|
||||
# 需要的是down 即推理出来的是down 那么查询的就是up
|
||||
result = np.concatenate([search_result[:20], infer_result.flatten()])
|
||||
data = [
|
||||
{"keypoint_id": keypoint_id,
|
||||
"keypoint_site": "all",
|
||||
"keypoint_vector": result.tolist()
|
||||
}
|
||||
]
|
||||
|
||||
try:
|
||||
client = MilvusClient(uri=MILVUS_URL, token=MILVUS_TOKEN, db_name=MILVUS_ALIAS)
|
||||
# connections.connect(alias=MILVUS_ALIAS, host=MILVUS_DB_HOST, port=MILVUS_PORT)
|
||||
start_time = time.time()
|
||||
# collection = Collection(MILVUS_TABLE_KEYPOINT) # Get an existing collection.
|
||||
# mr = collection.upsert(data)
|
||||
client.upsert(
|
||||
collection_name=MILVUS_TABLE_KEYPOINT,
|
||||
data=data
|
||||
)
|
||||
# logging.info(f"save keypoint time : {time.time() - start_time}")
|
||||
return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
|
||||
except Exception as e:
|
||||
logging.info(f"save keypoint cache milvus error : {e}")
|
||||
return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
|
||||
|
||||
# @ RunTime
|
||||
def keypoint_cache(self, result, site):
|
||||
try:
|
||||
client = MilvusClient(uri=MILVUS_URL, token=MILVUS_TOKEN, db_name=MILVUS_ALIAS)
|
||||
keypoint_id = result['image_id']
|
||||
res = client.query(
|
||||
collection_name=MILVUS_TABLE_KEYPOINT,
|
||||
# ids=[keypoint_id],
|
||||
filter=f"keypoint_id == {keypoint_id}",
|
||||
output_fields=['keypoint_vector', 'keypoint_site']
|
||||
)
|
||||
if len(res) == 0:
|
||||
# 没有结果 直接推理拿结果 并保存
|
||||
keypoint_infer_result, site = self.infer_keypoint_result(result)
|
||||
return self.save_keypoint_cache(result['image_id'], keypoint_infer_result, site)
|
||||
elif res[0]["keypoint_site"] == "all" or res[0]["keypoint_site"] == site:
|
||||
# 需要的类型和查询的类型一致,或者查询的类型为all 则直接返回查询的结果
|
||||
return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, np.array(res[0]['keypoint_vector']).astype(int).reshape(12, 2).tolist()))
|
||||
elif res[0]["keypoint_site"] != site:
|
||||
# 需要的类型和查询到的不一致,则更新类型为all
|
||||
keypoint_infer_result, site = self.infer_keypoint_result(result)
|
||||
return self.update_keypoint_cache(result["image_id"], keypoint_infer_result, res[0]['keypoint_vector'], site)
|
||||
except Exception as e:
|
||||
logging.info(f"search keypoint cache milvus error {e}")
|
||||
return False
|
||||
130
app/service/design/items/pipelines/loading.py
Normal file
130
app/service/design/items/pipelines/loading.py
Normal file
@@ -0,0 +1,130 @@
|
||||
import cv2
|
||||
|
||||
from app.service.utils.oss_client import oss_get_image
|
||||
from ..builder import PIPELINES
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
class LoadImageFromFile(object):
|
||||
def __init__(self, path, color=None, print_dict=None):
|
||||
self.path = path
|
||||
self.color = color
|
||||
self.print_dict = print_dict
|
||||
# self.minio_client = Minio(f"{MINIO_URL}", access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
||||
|
||||
def __call__(self, result):
|
||||
result['image'], result['pre_mask'] = self.read_image(self.path)
|
||||
result['gray'] = cv2.cvtColor(result['image'], cv2.COLOR_BGR2GRAY)
|
||||
result['keypoint'] = self.get_keypoint(result['name'])
|
||||
result['path'] = self.path
|
||||
result['img_shape'] = result['image'].shape
|
||||
result['ori_shape'] = result['image'].shape
|
||||
result['color'] = self.color if self.color is not None else None
|
||||
result['print_dict'] = self.print_dict
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def get_keypoint(name):
|
||||
if name == 'blouse' or name == 'outwear' or name == 'dress' or name == 'tops':
|
||||
keypoint = 'shoulder'
|
||||
elif name == 'trousers' or name == 'skirt' or name == 'bottoms':
|
||||
keypoint = 'waistband'
|
||||
elif name == 'bag':
|
||||
keypoint = 'hand_point'
|
||||
elif name == 'shoes':
|
||||
keypoint = 'toe'
|
||||
elif name == 'hairstyle':
|
||||
keypoint = 'head_point'
|
||||
elif name == 'earring':
|
||||
keypoint = 'ear_point'
|
||||
else:
|
||||
raise KeyError(f"{name} does not belong to item category list: blouse, outwear, dress, trousers, skirt, "
|
||||
f"bag, shoes, hairstyle, earring.")
|
||||
return keypoint
|
||||
|
||||
@staticmethod
|
||||
def read_image(image_path):
|
||||
image_mask = None
|
||||
# file = self.minio_client.get_object(image_path.split("/", 1)[0], image_path.split("/", 1)[1]).data
|
||||
# image = cv2.imdecode(np.frombuffer(file, np.uint8), 1)
|
||||
|
||||
image = oss_get_image(bucket=image_path.split("/", 1)[0], object_name=image_path.split("/", 1)[1], data_type="cv2")
|
||||
if len(image.shape) == 2:
|
||||
image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
|
||||
if image.shape[2] == 4: # 如果是四通道 mask
|
||||
image_mask = image[:, :, 3]
|
||||
image = image[:, :, :3]
|
||||
return image, image_mask
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
class LoadBodyImageFromFile(object):
|
||||
def __init__(self, body_path):
|
||||
self.body_path = body_path
|
||||
# self.minioClient = Minio(f"{MINIO_URL}", access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
||||
|
||||
# response = self.minioClient.get_object("aida-mannequins", "model_1693218345.2714431.png")
|
||||
|
||||
# @ RunTime
|
||||
def __call__(self, result):
|
||||
result["image_url"] = result['body_path'] = self.body_path
|
||||
result["name"] = "mannequin"
|
||||
# if not result['image_url'].lower().endswith(".png"):
|
||||
# bucket = self.body_path.split("/", 1)[0]
|
||||
# object_name = self.body_path.split("/", 1)[1]
|
||||
# new_object_name = f'{object_name[:object_name.rfind(".")]}.png'
|
||||
# image = self.minioClient.get_object(bucket, object_name)
|
||||
# image = Image.open(io.BytesIO(image.data))
|
||||
# image = image.convert("RGBA")
|
||||
# data = image.getdata()
|
||||
# #
|
||||
# new_data = []
|
||||
# for item in data:
|
||||
# if item[0] >= 230 and item[1] >= 230 and item[2] >= 230:
|
||||
# new_data.append((255, 255, 255, 0))
|
||||
# else:
|
||||
# new_data.append(item)
|
||||
# image.putdata(new_data)
|
||||
# image_data = io.BytesIO()
|
||||
# image.save(image_data, format='PNG')
|
||||
# image_data.seek(0)
|
||||
# image_bytes = image_data.read()
|
||||
# image_path = f"{bucket}/{self.minioClient.put_object(bucket, new_object_name, io.BytesIO(image_bytes), len(image_bytes), content_type='image/png').object_name}"
|
||||
# self.body_path = image_path
|
||||
# result["image_url"] = result['body_path'] = self.body_path
|
||||
# response = self.minioClient.get_object(self.body_path.split("/", 1)[0], self.body_path.split("/", 1)[1])
|
||||
# put_image_time = time.time()
|
||||
# result['body_image'] = Image.open(io.BytesIO(response.read()))
|
||||
result['body_image'] = oss_get_image(bucket=self.body_path.split("/", 1)[0], object_name=self.body_path.split("/", 1)[1], data_type="PIL")
|
||||
# logging.info(f"Image.open time is : {time.time() - put_image_time}")
|
||||
return result
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
class ImageShow(object):
|
||||
def __init__(self, key):
|
||||
self.key = key
|
||||
|
||||
# @ RunTime
|
||||
def __call__(self, result):
|
||||
import matplotlib.pyplot as plt
|
||||
if isinstance(self.key, list):
|
||||
for key in self.key:
|
||||
plt.imshow(result[key])
|
||||
plt.title(key)
|
||||
plt.show()
|
||||
elif isinstance(self.key, str):
|
||||
img = self._resize_img(result[self.key])
|
||||
cv2.imshow(self.key, img)
|
||||
cv2.waitKey(0)
|
||||
else:
|
||||
raise TypeError(f'key should be string but got type {type(self.key)}.')
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _resize_img(img):
|
||||
shape = img.shape
|
||||
if shape[0] > 400 or shape[1] > 400:
|
||||
ratio = min(400 / shape[0], 400 / shape[1])
|
||||
img = cv2.resize(img, (int(ratio * shape[1]), int(ratio * shape[0])))
|
||||
return img
|
||||
611
app/service/design/items/pipelines/painting.py
Normal file
611
app/service/design/items/pipelines/painting.py
Normal file
@@ -0,0 +1,611 @@
|
||||
import random
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from app.service.utils.oss_client import oss_get_image
|
||||
from ..builder import PIPELINES
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
class Painting(object):
|
||||
def __init__(self, painting_flag=True):
|
||||
self.painting_flag = painting_flag
|
||||
|
||||
# @ RunTime
|
||||
def __call__(self, result):
|
||||
if result['name'] not in ['hairstyle', 'earring'] and self.painting_flag and result['color'] != 'none':
|
||||
dim_image_h, dim_image_w = result['image'].shape[0:2]
|
||||
if "gradient" in result.keys() and result['gradient'] != "":
|
||||
bucket_name = result['gradient'].split('/')[0]
|
||||
object_name = result['gradient'][result['gradient'].find('/') + 1:]
|
||||
pattern = self.get_gradient(bucket_name=bucket_name, object_name=object_name)
|
||||
resize_pattern = cv2.resize(pattern, (dim_image_w, dim_image_h), interpolation=cv2.INTER_AREA)
|
||||
else:
|
||||
pattern = self.get_pattern(result['color'])
|
||||
resize_pattern = cv2.resize(pattern, (dim_image_w, dim_image_h), interpolation=cv2.INTER_AREA)
|
||||
closed_mo = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2)
|
||||
gray_mo = np.expand_dims(result['gray'], axis=2).repeat(3, axis=2)
|
||||
get_image_fir = resize_pattern * (closed_mo / 255) * (gray_mo / 255)
|
||||
result['pattern_image'] = get_image_fir.astype(np.uint8)
|
||||
result['final_image'] = result['pattern_image']
|
||||
canvas = np.full_like(result['final_image'], 255)
|
||||
temp_bg = np.expand_dims(cv2.bitwise_not(result['mask']), axis=2).repeat(3, axis=2)
|
||||
tmp1 = (canvas * (temp_bg / 255)).astype(np.uint8)
|
||||
temp_fg = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2)
|
||||
tmp2 = (result['final_image'] * (temp_fg / 255)).astype(np.uint8)
|
||||
result['single_image'] = cv2.add(tmp1, tmp2)
|
||||
result['alpha'] = 100 / 255.0
|
||||
else:
|
||||
closed_mo = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2)
|
||||
get_image_fir = result['image'] * (closed_mo / 255)
|
||||
result['pattern_image'] = get_image_fir.astype(np.uint8)
|
||||
result['final_image'] = result['pattern_image']
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def get_gradient(bucket_name, object_name):
|
||||
# image_data = minio_client.get_object(bucket_name, object_name)
|
||||
# image_data = s3.get_object(Bucket=bucket_name, Key=object_name)['Body']
|
||||
|
||||
# 从数据流中读取图像
|
||||
# image_bytes = image_data.read()
|
||||
|
||||
# 将图像数据转换为numpy数组
|
||||
# image_array = np.asarray(bytearray(image_bytes), dtype=np.uint8)
|
||||
|
||||
# 使用OpenCV解码图像数组
|
||||
# image = cv2.imdecode(image_array, cv2.IMREAD_COLOR)
|
||||
image = oss_get_image(bucket=bucket_name, object_name=object_name, data_type="cv2")
|
||||
if image.shape[2] == 4:
|
||||
image = cv2.cvtColor(image, cv2.COLOR_BGRA2BGR)
|
||||
return image
|
||||
|
||||
@staticmethod
|
||||
def crop_image(image, image_size_h, image_size_w):
|
||||
x_offset = np.random.randint(low=0, high=int(image_size_h / 5) - 6)
|
||||
y_offset = np.random.randint(low=0, high=int(image_size_w / 5) - 6)
|
||||
image = image[x_offset: x_offset + image_size_h, y_offset: y_offset + image_size_w, :]
|
||||
return image
|
||||
|
||||
@staticmethod
|
||||
def get_pattern(single_color):
|
||||
if single_color is None:
|
||||
raise False
|
||||
R, G, B = single_color.split(' ')
|
||||
pattern = np.zeros([1, 1, 3], np.uint8)
|
||||
pattern[0, 0, 0] = int(B)
|
||||
pattern[0, 0, 1] = int(G)
|
||||
pattern[0, 0, 2] = int(R)
|
||||
return pattern
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
class PrintPainting(object):
|
||||
def __init__(self, print_flag=True):
|
||||
self.print_flag = print_flag
|
||||
|
||||
# @ RunTime
|
||||
def __call__(self, result):
|
||||
single_print = result['print']['single']
|
||||
overall_print = result['print']['overall']
|
||||
element_print = result['print']['element']
|
||||
result['single_image'] = None
|
||||
result['print_image'] = None
|
||||
if overall_print['print_path_list']:
|
||||
painting_dict = {'dim_image_h': result['pattern_image'].shape[0], 'dim_image_w': result['pattern_image'].shape[1]}
|
||||
result['print_image'] = result['pattern_image']
|
||||
if "print_angle_list" in overall_print.keys() and overall_print['print_angle_list'][0] != 0:
|
||||
painting_dict = self.painting_collection(painting_dict, overall_print, print_trigger=True)
|
||||
painting_dict['tile_print'] = self.rotate_crop_image(img=painting_dict['tile_print'], angle=-result['print']['print_angle_list'][0], crop=True)
|
||||
painting_dict['mask_inv_print'] = self.rotate_crop_image(img=painting_dict['mask_inv_print'], angle=-result['print']['print_angle_list'][0], crop=True)
|
||||
|
||||
# resize 到sketch大小
|
||||
painting_dict['tile_print'] = self.resize_and_crop(img=painting_dict['tile_print'], target_width=painting_dict['dim_image_w'], target_height=painting_dict['dim_image_h'])
|
||||
painting_dict['mask_inv_print'] = self.resize_and_crop(img=painting_dict['mask_inv_print'], target_width=painting_dict['dim_image_w'], target_height=painting_dict['dim_image_h'])
|
||||
else:
|
||||
painting_dict = self.painting_collection(painting_dict, overall_print, print_trigger=True, is_single=False)
|
||||
result['print_image'] = self.printpaint(result, painting_dict, print_=True)
|
||||
result['single_image'] = result['final_image'] = result['pattern_image'] = result['print_image']
|
||||
|
||||
if single_print['print_path_list']:
|
||||
print_background = np.zeros((result['pattern_image'].shape[0], result['pattern_image'].shape[1], 3), dtype=np.uint8)
|
||||
mask_background = np.zeros((result['pattern_image'].shape[0], result['pattern_image'].shape[1], 3), dtype=np.uint8)
|
||||
for i in range(len(single_print['print_path_list'])):
|
||||
image, image_mode = self.read_image(single_print['print_path_list'][i])
|
||||
if image_mode == "RGBA":
|
||||
new_size = (int(image.width * single_print['print_scale_list'][i]), int(image.height * single_print['print_scale_list'][i]))
|
||||
|
||||
mask = image.split()[3]
|
||||
resized_source = image.resize(new_size)
|
||||
resized_source_mask = mask.resize(new_size)
|
||||
|
||||
rotated_resized_source = resized_source.rotate(-single_print['print_angle_list'][i])
|
||||
rotated_resized_source_mask = resized_source_mask.rotate(-single_print['print_angle_list'][i])
|
||||
|
||||
source_image_pil = Image.fromarray(cv2.cvtColor(print_background, cv2.COLOR_BGR2RGB))
|
||||
source_image_pil_mask = Image.fromarray(cv2.cvtColor(mask_background, cv2.COLOR_BGR2RGB))
|
||||
|
||||
source_image_pil.paste(rotated_resized_source, (int(single_print['location'][i][0]), int(single_print['location'][i][1])), rotated_resized_source)
|
||||
source_image_pil_mask.paste(rotated_resized_source_mask, (int(single_print['location'][i][0]), int(single_print['location'][i][1])), rotated_resized_source_mask)
|
||||
|
||||
print_background = cv2.cvtColor(np.array(source_image_pil), cv2.COLOR_RGBA2BGR)
|
||||
mask_background = cv2.cvtColor(np.array(source_image_pil_mask), cv2.COLOR_RGBA2BGR)
|
||||
ret, mask_background = cv2.threshold(mask_background, 124, 255, cv2.THRESH_BINARY)
|
||||
else:
|
||||
mask = self.get_mask_inv(image)
|
||||
mask = np.expand_dims(mask, axis=2)
|
||||
mask = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR)
|
||||
mask = cv2.bitwise_not(mask)
|
||||
# 旋转后的坐标需要重新算
|
||||
rotate_mask, _ = self.img_rotate(mask, single_print['print_angle_list'][i], single_print['print_scale_list'][i])
|
||||
rotate_image, rotated_new_size = self.img_rotate(image, single_print['print_angle_list'][i], single_print['print_scale_list'][i])
|
||||
# x, y = int(result['print']['location'][i][0] - rotated_new_size[0] - (rotate_mask.shape[0] - image.shape[0]) / 2), int(result['print']['location'][i][1] - rotated_new_size[1] - (rotate_mask.shape[1] - image.shape[1]) / 2)
|
||||
x, y = int(single_print['location'][i][0] - rotated_new_size[0]), int(single_print['location'][i][1] - rotated_new_size[1])
|
||||
|
||||
image_x = print_background.shape[1]
|
||||
image_y = print_background.shape[0]
|
||||
print_x = rotate_image.shape[1]
|
||||
print_y = rotate_image.shape[0]
|
||||
|
||||
# 有bug
|
||||
# if x + print_x > image_x:
|
||||
# rotate_image = rotate_image[:, :x + print_x - image_x]
|
||||
# rotate_mask = rotate_mask[:, :x + print_x - image_x]
|
||||
# #
|
||||
# if y + print_y > image_y:
|
||||
# rotate_image = rotate_image[:y + print_y - image_y]
|
||||
# rotate_mask = rotate_mask[:y + print_y - image_y]
|
||||
|
||||
# 不能是并行
|
||||
# 当前第一轮的if (108以及115)是判断有没有过下界和右界。第二轮的是判断左上有没有超出。 如果这个样子的话,先裁了右边,再左移,region就会有问题
|
||||
# 先挪 再判断 最后裁剪
|
||||
|
||||
# 如果print旋转了 或者 print贴边了 则需要判断 判断左界和上界是否小于0
|
||||
if x <= 0:
|
||||
rotate_image = rotate_image[:, -x:]
|
||||
rotate_mask = rotate_mask[:, -x:]
|
||||
start_x = x = 0
|
||||
else:
|
||||
start_x = x
|
||||
|
||||
if y <= 0:
|
||||
rotate_image = rotate_image[-y:, :]
|
||||
rotate_mask = rotate_mask[-y:, :]
|
||||
start_y = y = 0
|
||||
else:
|
||||
start_y = y
|
||||
|
||||
# ------------------
|
||||
# 如果print-size大于image-size 则需要裁剪print
|
||||
|
||||
if x + print_x > image_x:
|
||||
rotate_image = rotate_image[:, :image_x - x]
|
||||
rotate_mask = rotate_mask[:, :image_x - x]
|
||||
|
||||
if y + print_y > image_y:
|
||||
rotate_image = rotate_image[:image_y - y, :]
|
||||
rotate_mask = rotate_mask[:image_y - y, :]
|
||||
|
||||
# mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]] = cv2.bitwise_xor(mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]], rotate_mask)
|
||||
# print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]] = cv2.add(print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]], rotate_image)
|
||||
|
||||
# mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]] = rotate_mask
|
||||
# print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]] = rotate_image
|
||||
mask_background = self.stack_prin(mask_background, result['pattern_image'], rotate_mask, start_y, y, start_x, x)
|
||||
print_background = self.stack_prin(print_background, result['pattern_image'], rotate_image, start_y, y, start_x, x)
|
||||
|
||||
# gray_image = cv2.cvtColor(mask_background, cv2.COLOR_BGR2GRAY)
|
||||
# print_background = cv2.bitwise_and(print_background, print_background, mask=gray_image)
|
||||
|
||||
print_mask = cv2.bitwise_and(result['mask'], cv2.cvtColor(mask_background, cv2.COLOR_BGR2GRAY))
|
||||
img_fg = cv2.bitwise_or(print_background, print_background, mask=print_mask)
|
||||
img_bg = cv2.bitwise_and(result['pattern_image'], result['pattern_image'], mask=cv2.bitwise_not(print_mask))
|
||||
mask_mo = np.expand_dims(print_mask, axis=2).repeat(3, axis=2)
|
||||
gray_mo = np.expand_dims(result['gray'], axis=2).repeat(3, axis=2)
|
||||
img_fg = (img_fg * (mask_mo / 255) * (gray_mo / 255)).astype(np.uint8)
|
||||
result['final_image'] = cv2.add(img_bg, img_fg)
|
||||
canvas = np.full_like(result['final_image'], 255)
|
||||
temp_bg = np.expand_dims(cv2.bitwise_not(result['mask']), axis=2).repeat(3, axis=2)
|
||||
tmp1 = (canvas * (temp_bg / 255)).astype(np.uint8)
|
||||
temp_fg = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2)
|
||||
tmp2 = (result['final_image'] * (temp_fg / 255)).astype(np.uint8)
|
||||
result['single_image'] = cv2.add(tmp1, tmp2)
|
||||
|
||||
if element_print['element_path_list']:
|
||||
print_background = np.zeros((result['final_image'].shape[0], result['final_image'].shape[1], 3), dtype=np.uint8)
|
||||
mask_background = np.zeros((result['final_image'].shape[0], result['final_image'].shape[1], 3), dtype=np.uint8)
|
||||
for i in range(len(element_print['element_path_list'])):
|
||||
image, image_mode = self.read_image(element_print['element_path_list'][i])
|
||||
if image_mode == "RGBA":
|
||||
new_size = (int(image.width * element_print['element_scale_list'][i]), int(image.height * element_print['element_scale_list'][i]))
|
||||
|
||||
mask = image.split()[3]
|
||||
resized_source = image.resize(new_size)
|
||||
resized_source_mask = mask.resize(new_size)
|
||||
|
||||
rotated_resized_source = resized_source.rotate(-element_print['element_angle_list'][i])
|
||||
rotated_resized_source_mask = resized_source_mask.rotate(-element_print['element_angle_list'][i])
|
||||
|
||||
source_image_pil = Image.fromarray(cv2.cvtColor(print_background, cv2.COLOR_BGR2RGB))
|
||||
source_image_pil_mask = Image.fromarray(cv2.cvtColor(mask_background, cv2.COLOR_BGR2RGB))
|
||||
|
||||
source_image_pil.paste(rotated_resized_source, (int(element_print['location'][i][0]), int(element_print['location'][i][1])), rotated_resized_source)
|
||||
source_image_pil_mask.paste(rotated_resized_source_mask, (int(element_print['location'][i][0]), int(element_print['location'][i][1])), rotated_resized_source_mask)
|
||||
|
||||
print_background = cv2.cvtColor(np.array(source_image_pil), cv2.COLOR_RGBA2BGR)
|
||||
mask_background = cv2.cvtColor(np.array(source_image_pil_mask), cv2.COLOR_RGBA2BGR)
|
||||
print(1)
|
||||
else:
|
||||
mask = self.get_mask_inv(image)
|
||||
mask = np.expand_dims(mask, axis=2)
|
||||
mask = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR)
|
||||
mask = cv2.bitwise_not(mask)
|
||||
# 旋转后的坐标需要重新算
|
||||
rotate_mask, _ = self.img_rotate(mask, element_print['element_angle_list'][i], element_print['element_scale_list'][i])
|
||||
rotate_image, rotated_new_size = self.img_rotate(image, element_print['element_angle_list'][i], element_print['element_scale_list'][i])
|
||||
# x, y = int(result['print']['location'][i][0] - rotated_new_size[0] - (rotate_mask.shape[0] - image.shape[0]) / 2), int(result['print']['location'][i][1] - rotated_new_size[1] - (rotate_mask.shape[1] - image.shape[1]) / 2)
|
||||
x, y = int(element_print['location'][i][0] - rotated_new_size[0]), int(element_print['location'][i][1] - rotated_new_size[1])
|
||||
|
||||
image_x = print_background.shape[1]
|
||||
image_y = print_background.shape[0]
|
||||
print_x = rotate_image.shape[1]
|
||||
print_y = rotate_image.shape[0]
|
||||
|
||||
# 有bug
|
||||
# if x + print_x > image_x:
|
||||
# rotate_image = rotate_image[:, :x + print_x - image_x]
|
||||
# rotate_mask = rotate_mask[:, :x + print_x - image_x]
|
||||
# #
|
||||
# if y + print_y > image_y:
|
||||
# rotate_image = rotate_image[:y + print_y - image_y]
|
||||
# rotate_mask = rotate_mask[:y + print_y - image_y]
|
||||
|
||||
# 不能是并行
|
||||
# 当前第一轮的if (108以及115)是判断有没有过下界和右界。第二轮的是判断左上有没有超出。 如果这个样子的话,先裁了右边,再左移,region就会有问题
|
||||
# 先挪 再判断 最后裁剪
|
||||
|
||||
# 如果print旋转了 或者 print贴边了 则需要判断 判断左界和上界是否小于0
|
||||
if x <= 0:
|
||||
rotate_image = rotate_image[:, -x:]
|
||||
rotate_mask = rotate_mask[:, -x:]
|
||||
start_x = x = 0
|
||||
else:
|
||||
start_x = x
|
||||
|
||||
if y <= 0:
|
||||
rotate_image = rotate_image[-y:, :]
|
||||
rotate_mask = rotate_mask[-y:, :]
|
||||
start_y = y = 0
|
||||
else:
|
||||
start_y = y
|
||||
|
||||
# ------------------
|
||||
# 如果print-size大于image-size 则需要裁剪print
|
||||
|
||||
if x + print_x > image_x:
|
||||
rotate_image = rotate_image[:, :image_x - x]
|
||||
rotate_mask = rotate_mask[:, :image_x - x]
|
||||
|
||||
if y + print_y > image_y:
|
||||
rotate_image = rotate_image[:image_y - y, :]
|
||||
rotate_mask = rotate_mask[:image_y - y, :]
|
||||
|
||||
# mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]] = cv2.bitwise_xor(mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]], rotate_mask)
|
||||
# print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]] = cv2.add(print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]], rotate_image)
|
||||
|
||||
# mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]] = rotate_mask
|
||||
# print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]] = rotate_image
|
||||
mask_background = self.stack_prin(mask_background, result['pattern_image'], rotate_mask, start_y, y, start_x, x)
|
||||
print_background = self.stack_prin(print_background, result['pattern_image'], rotate_image, start_y, y, start_x, x)
|
||||
|
||||
# gray_image = cv2.cvtColor(mask_background, cv2.COLOR_BGR2GRAY)
|
||||
# print_background = cv2.bitwise_and(print_background, print_background, mask=gray_image)
|
||||
|
||||
print_mask = cv2.bitwise_and(result['mask'], cv2.cvtColor(mask_background, cv2.COLOR_BGR2GRAY))
|
||||
img_fg = cv2.bitwise_or(print_background, print_background, mask=print_mask)
|
||||
# TODO element 丢失信息
|
||||
three_channel_image = cv2.merge([cv2.bitwise_not(print_mask), cv2.bitwise_not(print_mask), cv2.bitwise_not(print_mask)])
|
||||
img_bg = cv2.bitwise_and(result['final_image'], three_channel_image)
|
||||
# mask_mo = np.expand_dims(print_mask, axis=2).repeat(3, axis=2)
|
||||
# gray_mo = np.expand_dims(result['gray'], axis=2).repeat(3, axis=2)
|
||||
# img_fg = (img_fg * (mask_mo / 255) * (gray_mo / 255)).astype(np.uint8)
|
||||
result['final_image'] = cv2.add(img_bg, img_fg)
|
||||
canvas = np.full_like(result['final_image'], 255)
|
||||
temp_bg = np.expand_dims(cv2.bitwise_not(result['mask']), axis=2).repeat(3, axis=2)
|
||||
tmp1 = (canvas * (temp_bg / 255)).astype(np.uint8)
|
||||
temp_fg = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2)
|
||||
tmp2 = (result['final_image'] * (temp_fg / 255)).astype(np.uint8)
|
||||
result['single_image'] = cv2.add(tmp1, tmp2)
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def stack_prin(print_background, pattern_image, rotate_image, start_y, y, start_x, x):
|
||||
temp_print = np.zeros((pattern_image.shape[0], pattern_image.shape[1], 3), dtype=np.uint8)
|
||||
|
||||
temp_print[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]] = rotate_image
|
||||
|
||||
img2gray = cv2.cvtColor(print_background, cv2.COLOR_BGR2GRAY)
|
||||
|
||||
ret, mask_ = cv2.threshold(img2gray, 1, 255, cv2.THRESH_BINARY)
|
||||
|
||||
mask_inv = cv2.bitwise_not(mask_)
|
||||
|
||||
img1_bg = cv2.bitwise_and(print_background, print_background, mask=mask_)
|
||||
|
||||
img2_fg = cv2.bitwise_and(temp_print, temp_print, mask=mask_inv)
|
||||
|
||||
print_background = img1_bg + img2_fg
|
||||
|
||||
return print_background
|
||||
|
||||
def painting_collection(self, painting_dict, print_dict, print_trigger=False, is_single=False):
|
||||
if print_trigger:
|
||||
print_ = self.get_print(print_dict)
|
||||
painting_dict['Trigger'] = not is_single
|
||||
painting_dict['location'] = print_['location']
|
||||
single_mask_inv_print = self.get_mask_inv(print_['image'])
|
||||
dim_max = max(painting_dict['dim_image_h'], painting_dict['dim_image_w'])
|
||||
dim_pattern = (int(dim_max * print_['scale'] / 5), int(dim_max * print_['scale'] / 5))
|
||||
if not is_single:
|
||||
self.random_seed = random.randint(0, 1000)
|
||||
# 如果print 模式为overall 且 有角度的话 , 组合的print为正方形,方便裁剪
|
||||
if "print_angle_list" in print_dict.keys() and print_dict['print_angle_list'][0] != 0:
|
||||
painting_dict['mask_inv_print'] = self.tile_image(single_mask_inv_print, dim_pattern, print_['scale'], dim_max, dim_max, painting_dict['location'], trigger=True)
|
||||
painting_dict['tile_print'] = self.tile_image(print_['image'], dim_pattern, print_['scale'], dim_max, dim_max, painting_dict['location'], trigger=True)
|
||||
else:
|
||||
painting_dict['mask_inv_print'] = self.tile_image(single_mask_inv_print, dim_pattern, print_['scale'], painting_dict['dim_image_h'], painting_dict['dim_image_w'], painting_dict['location'], trigger=True)
|
||||
painting_dict['tile_print'] = self.tile_image(print_['image'], dim_pattern, print_['scale'], painting_dict['dim_image_h'], painting_dict['dim_image_w'], painting_dict['location'], trigger=True)
|
||||
else:
|
||||
painting_dict['mask_inv_print'] = self.tile_image(single_mask_inv_print, dim_pattern, print_['scale'], painting_dict['dim_image_h'], painting_dict['dim_image_w'], painting_dict['location'])
|
||||
painting_dict['tile_print'] = self.tile_image(print_['image'], dim_pattern, print_['scale'], painting_dict['dim_image_h'], painting_dict['dim_image_w'], painting_dict['location'])
|
||||
painting_dict['dim_print_h'], painting_dict['dim_print_w'] = dim_pattern
|
||||
return painting_dict
|
||||
|
||||
def tile_image(self, pattern, dim, scale, dim_image_h, dim_image_w, location, trigger=False):
|
||||
tile = None
|
||||
if not trigger:
|
||||
tile = cv2.resize(pattern, dim, interpolation=cv2.INTER_AREA)
|
||||
else:
|
||||
resize_pattern = cv2.resize(pattern, dim, interpolation=cv2.INTER_AREA)
|
||||
if len(pattern.shape) == 2:
|
||||
tile = np.tile(resize_pattern, (int((5 + 1) / scale) + 4, int((5 + 1) / scale) + 4))
|
||||
if len(pattern.shape) == 3:
|
||||
tile = np.tile(resize_pattern, (int((5 + 1) / scale) + 4, int((5 + 1) / scale) + 4, 1))
|
||||
tile = self.crop_image(tile, dim_image_h, dim_image_w, location, resize_pattern.shape)
|
||||
return tile
|
||||
|
||||
def get_mask_inv(self, print_):
|
||||
if print_[0][0][0] == 255 and print_[0][0][1] == 255 and print_[0][0][2] == 255:
|
||||
bg_color = cv2.cvtColor(print_, cv2.COLOR_BGR2LAB)[0][0]
|
||||
print_tile = cv2.cvtColor(print_, cv2.COLOR_BGR2LAB)
|
||||
bg_l, bg_a, bg_b = bg_color[0], bg_color[1], bg_color[2]
|
||||
bg_L_high, bg_L_low = self.get_low_high_lab(bg_l, L=True)
|
||||
bg_a_high, bg_a_low = self.get_low_high_lab(bg_a)
|
||||
bg_b_high, bg_b_low = self.get_low_high_lab(bg_b)
|
||||
lower = np.array([bg_L_low, bg_a_low, bg_b_low])
|
||||
upper = np.array([bg_L_high, bg_a_high, bg_b_high])
|
||||
mask_inv = cv2.inRange(print_tile, lower, upper)
|
||||
return mask_inv
|
||||
else:
|
||||
# bg_color = cv2.cvtColor(print_, cv2.COLOR_BGR2LAB)[0][0]
|
||||
# print_tile = cv2.cvtColor(print_, cv2.COLOR_BGR2LAB)
|
||||
# bg_l, bg_a, bg_b = bg_color[0], bg_color[1], bg_color[2]
|
||||
# bg_L_high, bg_L_low = self.get_low_high_lab(bg_l, L=True)
|
||||
# bg_a_high, bg_a_low = self.get_low_high_lab(bg_a)
|
||||
# bg_b_high, bg_b_low = self.get_low_high_lab(bg_b)
|
||||
# lower = np.array([bg_L_low, bg_a_low, bg_b_low])
|
||||
# upper = np.array([bg_L_high, bg_a_high, bg_b_high])
|
||||
|
||||
# print_tile = cv2.cvtColor(print_, cv2.COLOR_BGR2LAB)
|
||||
# mask_inv = cv2.cvtColor(print_tile, cv2.COLOR_BGR2GRAY)
|
||||
|
||||
# mask_inv = cv2.cvtColor(print_, cv2.COLOR_BGR2GRAY)
|
||||
mask_inv = np.zeros(print_.shape[:2], dtype=np.uint8)
|
||||
return mask_inv
|
||||
|
||||
@staticmethod
|
||||
def printpaint(result, painting_dict, print_=False):
|
||||
|
||||
if print_ and painting_dict['Trigger']:
|
||||
print_mask = cv2.bitwise_and(result['mask'], cv2.bitwise_not(painting_dict['mask_inv_print']))
|
||||
img_fg = cv2.bitwise_and(painting_dict['tile_print'], painting_dict['tile_print'], mask=print_mask)
|
||||
else:
|
||||
print_mask = result['mask']
|
||||
img_fg = result['final_image']
|
||||
if print_ and not painting_dict['Trigger']:
|
||||
index_ = None
|
||||
try:
|
||||
index_ = len(painting_dict['location'])
|
||||
except:
|
||||
assert f'there must be parameter of location if choose IfSingle'
|
||||
|
||||
for i in range(index_):
|
||||
start_h, start_w = int(painting_dict['location'][i][1]), int(painting_dict['location'][i][0])
|
||||
|
||||
length_h = min(start_h + painting_dict['dim_print_h'], img_fg.shape[0])
|
||||
length_w = min(start_w + painting_dict['dim_print_w'], img_fg.shape[1])
|
||||
|
||||
change_region = img_fg[start_h: length_h, start_w: length_w, :]
|
||||
# problem in change_mask
|
||||
change_mask = print_mask[start_h: length_h, start_w: length_w]
|
||||
# get real part into change mask
|
||||
_, change_mask = cv2.threshold(change_mask, 220, 255, cv2.THRESH_BINARY)
|
||||
mask = cv2.bitwise_not(painting_dict['mask_inv_print'])
|
||||
img_fg[start_h:start_h + painting_dict['dim_print_h'], start_w:start_w + painting_dict['dim_print_w'], :] = change_region
|
||||
|
||||
clothes_mask_print = cv2.bitwise_not(print_mask)
|
||||
|
||||
img_bg = cv2.bitwise_and(result['pattern_image'], result['pattern_image'], mask=clothes_mask_print)
|
||||
mask_mo = np.expand_dims(print_mask, axis=2).repeat(3, axis=2)
|
||||
gray_mo = np.expand_dims(result['gray'], axis=2).repeat(3, axis=2)
|
||||
img_fg = (img_fg * (mask_mo / 255) * (gray_mo / 255)).astype(np.uint8)
|
||||
print_image = cv2.add(img_bg, img_fg)
|
||||
return print_image
|
||||
|
||||
@staticmethod
|
||||
def get_print(print_dict):
|
||||
if 'print_scale_list' not in print_dict.keys() or print_dict['print_scale_list'][0] < 0.3:
|
||||
print_dict['scale'] = 0.3
|
||||
else:
|
||||
print_dict['scale'] = print_dict['print_scale_list'][0]
|
||||
|
||||
bucket_name = print_dict['print_path_list'][0].split("/", 1)[0]
|
||||
object_name = print_dict['print_path_list'][0].split("/", 1)[1]
|
||||
image = oss_get_image(bucket=bucket_name, object_name=object_name, data_type="PIL")
|
||||
# 判断图片格式,如果是RGBA 则贴在一张纯白图片上 防止透明转黑
|
||||
if image.mode == "RGBA":
|
||||
new_background = Image.new('RGB', image.size, (255, 255, 255))
|
||||
new_background.paste(image, mask=image.split()[3])
|
||||
image = new_background
|
||||
print_dict['image'] = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
|
||||
return print_dict
|
||||
|
||||
def crop_image(self, image, image_size_h, image_size_w, location, print_shape):
|
||||
print_w = print_shape[1]
|
||||
print_h = print_shape[0]
|
||||
|
||||
random.seed(self.random_seed)
|
||||
# logging.info(f'overall print location : {location}')
|
||||
# x_offset = random.randint(0, image.shape[0] - image_size_h)
|
||||
# y_offset = random.randint(0, image.shape[1] - image_size_w)
|
||||
|
||||
# 1.拿到偏移量后和resize后的print宽高取余 得到真正偏移量
|
||||
x_offset = print_w - int(location[0][1] % print_w)
|
||||
y_offset = print_w - int(location[0][0] % print_h)
|
||||
|
||||
# y_offset = int(location[0][0])
|
||||
# x_offset = int(location[0][1])
|
||||
|
||||
if len(image.shape) == 2:
|
||||
image = image[x_offset: x_offset + image_size_h, y_offset: y_offset + image_size_w]
|
||||
elif len(image.shape) == 3:
|
||||
image = image[x_offset: x_offset + image_size_h, y_offset: y_offset + image_size_w, :]
|
||||
return image
|
||||
|
||||
@staticmethod
|
||||
def get_low_high_lab(Lab_value, L=False):
|
||||
if L:
|
||||
high = Lab_value + 30 if Lab_value + 30 < 255 else 255
|
||||
low = Lab_value - 30 if Lab_value - 30 > 0 else 0
|
||||
else:
|
||||
high = Lab_value + 30 if Lab_value + 30 < 255 else 255
|
||||
low = Lab_value - 30 if Lab_value - 30 > 0 else 0
|
||||
return high, low
|
||||
|
||||
@staticmethod
|
||||
def img_rotate(image, angel, scale):
|
||||
"""顺时针旋转图像任意角度
|
||||
|
||||
Args:
|
||||
image (np.array): [原始图像]
|
||||
angel (float): [逆时针旋转的角度]
|
||||
|
||||
Returns:
|
||||
[array]: [旋转后的图像]
|
||||
"""
|
||||
|
||||
h, w = image.shape[:2]
|
||||
center = (w // 2, h // 2)
|
||||
# if type(angel) is not int:
|
||||
# angel = 0
|
||||
M = cv2.getRotationMatrix2D(center, -angel, scale)
|
||||
# 调整旋转后的图像长宽
|
||||
rotated_h = int((w * np.abs(M[0, 1]) + (h * np.abs(M[0, 0]))))
|
||||
rotated_w = int((h * np.abs(M[0, 1]) + (w * np.abs(M[0, 0]))))
|
||||
M[0, 2] += (rotated_w - w) // 2
|
||||
M[1, 2] += (rotated_h - h) // 2
|
||||
# 旋转图像
|
||||
rotated_img = cv2.warpAffine(image, M, (rotated_w, rotated_h))
|
||||
|
||||
return rotated_img, ((rotated_img.shape[1] - image.shape[1] * scale) // 2, (rotated_img.shape[0] - image.shape[0] * scale) // 2)
|
||||
# return rotated_img, (0, 0)
|
||||
|
||||
@staticmethod
|
||||
def rotate_crop_image(img, angle, crop):
|
||||
"""
|
||||
angle: 旋转的角度
|
||||
crop: 是否需要进行裁剪,布尔向量
|
||||
"""
|
||||
crop_image = lambda img, x0, y0, w, h: img[y0:y0 + h, x0:x0 + w]
|
||||
w, h = img.shape[:2]
|
||||
# 旋转角度的周期是360°
|
||||
angle %= 360
|
||||
# 计算仿射变换矩阵
|
||||
M_rotation = cv2.getRotationMatrix2D((w / 2, h / 2), angle, 1)
|
||||
# 得到旋转后的图像
|
||||
img_rotated = cv2.warpAffine(img, M_rotation, (w, h))
|
||||
|
||||
# 如果需要去除黑边
|
||||
if crop:
|
||||
# 裁剪角度的等效周期是180°
|
||||
angle_crop = angle % 180
|
||||
if angle > 90:
|
||||
angle_crop = 180 - angle_crop
|
||||
# 转化角度为弧度
|
||||
theta = angle_crop * np.pi / 180
|
||||
# 计算高宽比
|
||||
hw_ratio = float(h) / float(w)
|
||||
# 计算裁剪边长系数的分子项
|
||||
tan_theta = np.tan(theta)
|
||||
numerator = np.cos(theta) + np.sin(theta) * np.tan(theta)
|
||||
|
||||
# 计算分母中和高宽比相关的项
|
||||
r = hw_ratio if h > w else 1 / hw_ratio
|
||||
# 计算分母项
|
||||
denominator = r * tan_theta + 1
|
||||
# 最终的边长系数
|
||||
crop_mult = numerator / denominator
|
||||
|
||||
# 得到裁剪区域
|
||||
w_crop = int(crop_mult * w)
|
||||
h_crop = int(crop_mult * h)
|
||||
x0 = int((w - w_crop) / 2)
|
||||
y0 = int((h - h_crop) / 2)
|
||||
|
||||
img_rotated = crop_image(img_rotated, x0, y0, w_crop, h_crop)
|
||||
|
||||
return img_rotated
|
||||
|
||||
@staticmethod
|
||||
def read_image(image_url):
|
||||
image = oss_get_image(bucket=image_url.split("/", 1)[0], object_name=image_url.split("/", 1)[1], data_type="cv2")
|
||||
if image.shape[2] == 4:
|
||||
image_rgb = cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA)
|
||||
image = Image.fromarray(image_rgb)
|
||||
image_mode = "RGBA"
|
||||
else:
|
||||
image_mode = "RGB"
|
||||
return image, image_mode
|
||||
|
||||
@staticmethod
|
||||
def resize_and_crop(img, target_width, target_height):
|
||||
# 获取原始图像的尺寸
|
||||
original_height, original_width = img.shape[:2]
|
||||
|
||||
# 计算目标尺寸的宽高比
|
||||
target_ratio = target_width / target_height
|
||||
|
||||
# 计算原始图像的宽高比
|
||||
original_ratio = original_width / original_height
|
||||
|
||||
# 调整尺寸
|
||||
if original_ratio > target_ratio:
|
||||
# 原始图像更宽,按高度resize,然后裁剪宽度
|
||||
new_height = target_height
|
||||
new_width = int(original_width * (target_height / original_height))
|
||||
resized_img = cv2.resize(img, (new_width, new_height))
|
||||
# 裁剪宽度
|
||||
start_x = (new_width - target_width) // 2
|
||||
cropped_img = resized_img[:, start_x:start_x + target_width]
|
||||
else:
|
||||
# 原始图像更高,按宽度resize,然后裁剪高度
|
||||
new_width = target_width
|
||||
new_height = int(original_height * (target_width / original_width))
|
||||
resized_img = cv2.resize(img, (new_width, new_height))
|
||||
# 裁剪高度
|
||||
start_y = (new_height - target_height) // 2
|
||||
cropped_img = resized_img[start_y:start_y + target_height, :]
|
||||
|
||||
return cropped_img
|
||||
56
app/service/design/items/pipelines/scale.py
Normal file
56
app/service/design/items/pipelines/scale.py
Normal file
@@ -0,0 +1,56 @@
|
||||
import math
|
||||
|
||||
import cv2
|
||||
|
||||
from ..builder import PIPELINES
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
class Scaling(object):
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
# @ RunTime
|
||||
def __call__(self, result):
|
||||
if result['keypoint'] in ['waistband', 'shoulder', 'head_point']:
|
||||
# milvus_db_keypoint_cache
|
||||
distance_clo = math.sqrt(
|
||||
(int(result['clothes_keypoint'][result['keypoint'] + '_left'][0]) - int(result['clothes_keypoint'][result['keypoint'] + '_right'][0])) ** 2
|
||||
+
|
||||
(int(result['clothes_keypoint'][result['keypoint'] + '_left'][1]) - int(result['clothes_keypoint'][result['keypoint'] + '_right'][1])) ** 2)
|
||||
|
||||
distance_bdy = math.sqrt((int(result['body_point_test'][result['keypoint'] + '_left'][0]) - int(result['body_point_test'][result['keypoint'] + '_right'][0])) ** 2 + 1)
|
||||
# distance_clo = math.sqrt(
|
||||
# (int(result['clothes_keypoint'][result['keypoint'] + '_left'].split("_")[0]) - int(result['clothes_keypoint'][result['keypoint'] + '_right'].split("_")[0])) ** 2
|
||||
# +
|
||||
# (int(result['clothes_keypoint'][result['keypoint'] + '_left'].split("_")[1]) - int(result['clothes_keypoint'][result['keypoint'] + '_right'].split("_")[1])) ** 2)
|
||||
#
|
||||
# distance_bdy = math.sqrt((int(result['body_point_test'][result['keypoint'] + '_left'][0]) - int(result['body_point_test'][result['keypoint'] + '_right'][0])) ** 2 + 1)
|
||||
if distance_clo == 0:
|
||||
result['scale'] = 1
|
||||
else:
|
||||
result['scale'] = distance_bdy / distance_clo
|
||||
elif result['keypoint'] == 'toe':
|
||||
distance_bdy = math.sqrt(
|
||||
(int(result['body_point_test']['foot_length'][0]) - int(result['body_point_test']['foot_length'][2])) ** 2
|
||||
+
|
||||
(int(result['body_point_test']['foot_length'][1]) - int(result['body_point_test']['foot_length'][3])) ** 2
|
||||
)
|
||||
|
||||
Blur = cv2.GaussianBlur(result['gray'], (3, 3), 0)
|
||||
Edge = cv2.Canny(Blur, 10, 200)
|
||||
Edge = cv2.dilate(Edge, None)
|
||||
Edge = cv2.erode(Edge, None)
|
||||
Contour, _ = cv2.findContours(Edge, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
||||
Contours = sorted(Contour, key=cv2.contourArea, reverse=True)
|
||||
|
||||
Max_contour = Contours[0]
|
||||
x, y, w, h = cv2.boundingRect(Max_contour)
|
||||
width = w
|
||||
distance_clo = width
|
||||
result['scale'] = distance_bdy / distance_clo
|
||||
elif result['keypoint'] == 'hand_point':
|
||||
result['scale'] = result['scale_bag']
|
||||
elif result['keypoint'] == 'ear_point':
|
||||
result['scale'] = result['scale_earrings']
|
||||
return result
|
||||
14
app/service/design/items/pipelines/segmentation.py
Normal file
14
app/service/design/items/pipelines/segmentation.py
Normal file
@@ -0,0 +1,14 @@
|
||||
from ..builder import PIPELINES
|
||||
from ...utils.design_ensemble import get_seg_result
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
class Segmentation(object):
|
||||
def __init__(self, device='cpu', show=False, debug=None):
|
||||
self.show = show
|
||||
self.device = device
|
||||
self.debug = debug
|
||||
|
||||
def __call__(self, result):
|
||||
result['seg_result'] = get_seg_result(result["image_id"], result['image'])
|
||||
return result
|
||||
81
app/service/design/items/pipelines/split.py
Normal file
81
app/service/design/items/pipelines/split.py
Normal file
@@ -0,0 +1,81 @@
|
||||
import logging
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from cv2 import cvtColor, COLOR_BGR2RGBA
|
||||
|
||||
from app.service.utils.generate_uuid import generate_uuid
|
||||
from ..builder import PIPELINES
|
||||
from ...utils.conversion_image import rgb_to_rgba
|
||||
from ...utils.upload_image import upload_png_mask
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
class Split(object):
|
||||
"""
|
||||
Split image into front and back layer according to the segmentation result
|
||||
"""
|
||||
|
||||
# KNet
|
||||
def __call__(self, result):
|
||||
try:
|
||||
if 'mask' not in result.keys():
|
||||
raise KeyError(f'Cannot find mask in result dict, please check ContourDetection is included in process pipelines.')
|
||||
if 'seg_result' not in result.keys(): # 没过seg模型
|
||||
result['front_mask'] = result['mask'].copy()
|
||||
result['back_mask'] = np.zeros_like(result['mask'])
|
||||
else:
|
||||
temp_front = result['seg_result'] == 1
|
||||
result['front_mask'] = (result['mask'] * (temp_front + 0).astype(np.uint8))
|
||||
temp_back = result['seg_result'] == 2
|
||||
result['back_mask'] = (result['mask'] * (temp_back + 0).astype(np.uint8))
|
||||
|
||||
if result['name'] in ('outwear', 'dress', 'blouse', 'skirt', 'trousers', 'tops', 'bottoms'):
|
||||
if len(result['front_mask'].shape) > 2:
|
||||
front_mask = result['front_mask'][0]
|
||||
else:
|
||||
front_mask = result['front_mask']
|
||||
|
||||
if len(result['back_mask'].shape) > 2:
|
||||
back_mask = result['back_mask'][0]
|
||||
else:
|
||||
back_mask = result['back_mask']
|
||||
|
||||
rgba_image = rgb_to_rgba((result['final_image'].shape[0], result['final_image'].shape[1]), result['final_image'], result['mask'])
|
||||
result_front_image = np.zeros_like(rgba_image)
|
||||
result_front_image[front_mask != 0] = rgba_image[front_mask != 0]
|
||||
|
||||
result_front_image_pil = Image.fromarray(cvtColor(result_front_image, COLOR_BGR2RGBA))
|
||||
front_new_size = (int(result_front_image_pil.width * result["scale"] * result["resize_scale"][0]), int(result_front_image_pil.height * result["scale"] * result["resize_scale"][1]))
|
||||
result_front_image_pil = result_front_image_pil.resize(front_new_size, Image.LANCZOS)
|
||||
# result['front_mask_image'] = cv2.resize(front_mask, front_new_size)
|
||||
# result['front_image'] = result_front_image_pil
|
||||
front_mask = cv2.resize(front_mask, front_new_size)
|
||||
result['front_image'], result["front_image_url"], result["front_mask_url"] = upload_png_mask(result_front_image_pil, f'{generate_uuid()}', mask=front_mask)
|
||||
|
||||
if result["name"] in ('blouse', 'dress', 'outwear', 'tops'):
|
||||
result_back_image = np.zeros_like(rgba_image)
|
||||
result_back_image[back_mask != 0] = rgba_image[back_mask != 0]
|
||||
|
||||
result_back_image_pil = Image.fromarray(cvtColor(result_back_image, COLOR_BGR2RGBA))
|
||||
back_new_size = (int(result_back_image_pil.width * result["scale"] * result["resize_scale"][0]), int(result_back_image_pil.height * result["scale"] * result["resize_scale"][1]))
|
||||
result_back_image_pil = result_back_image_pil.resize(back_new_size, Image.LANCZOS)
|
||||
# result['back_mask_image'] = cv2.resize(back_mask, back_new_size)
|
||||
# result['back_image'] = result_back_image_pil
|
||||
|
||||
back_mask = cv2.resize(back_mask, back_new_size)
|
||||
result['back_image'], result["back_image_url"], result["back_mask_url"] = upload_png_mask(result_back_image_pil, f'{generate_uuid()}', mask=back_mask)
|
||||
else:
|
||||
result['back_image'] = None
|
||||
result["back_image_url"] = None
|
||||
result["back_mask_url"] = None
|
||||
result['back_mask_image'] = None
|
||||
|
||||
# 创建中间图层
|
||||
result_pattern_image_rgba = rgb_to_rgba((result['pattern_image'].shape[0], result['pattern_image'].shape[1]), result['pattern_image'], result['mask'])
|
||||
result_pattern_image_pil = Image.fromarray(cvtColor(result_pattern_image_rgba, COLOR_BGR2RGBA))
|
||||
_, result['pattern_image_url'], _ = upload_png_mask(result_pattern_image_pil, f'{generate_uuid()}')
|
||||
return result
|
||||
except Exception as e:
|
||||
logging.warning(f"split runtime exception : {e} image_id : {result['image_id']}")
|
||||
121
app/service/design/items/shoes.py
Normal file
121
app/service/design/items/shoes.py
Normal file
@@ -0,0 +1,121 @@
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from .builder import ITEMS
|
||||
from .clothing import Clothing
|
||||
from ..utils.conversion_image import rgb_to_rgba
|
||||
from ..utils.upload_image import upload_png_mask
|
||||
from ...utils.generate_uuid import generate_uuid
|
||||
|
||||
|
||||
@ITEMS.register_module()
|
||||
class Shoes(Clothing):
|
||||
# TODO location of shoes has little mismatch
|
||||
def __init__(self, **kwargs):
|
||||
pipeline = [
|
||||
dict(type='LoadImageFromFile', path=kwargs['path'], color=kwargs['color']),
|
||||
dict(type='KeypointDetection'),
|
||||
dict(type='ContourDetection'),
|
||||
dict(type='Painting'),
|
||||
dict(type='Scaling'),
|
||||
dict(type='Split'),
|
||||
# dict(type='ImageShow', key=['image', 'mask', 'pattern_image']),
|
||||
]
|
||||
kwargs.update(pipeline=pipeline)
|
||||
super(Shoes, self).__init__(**kwargs)
|
||||
|
||||
def organize(self, layer):
|
||||
left_shoe_mask, right_shoe_mask = self.cut()
|
||||
|
||||
left_layer = dict(name=f'{type(self).__name__.lower()}_left',
|
||||
image=self.result['shoes_left'],
|
||||
image_url=self.result['left_image_url'],
|
||||
mask_url=self.result['left_mask_url'],
|
||||
sacle=self.result['scale'],
|
||||
clothes_keypoint=self.result['clothes_keypoint'],
|
||||
position=self.calculate_start_point(self.result['keypoint'],
|
||||
self.result['scale'],
|
||||
self.result['clothes_keypoint'],
|
||||
self.result['body_point'],
|
||||
'left'))
|
||||
layer.insert(left_layer)
|
||||
|
||||
right_layer = dict(name=f'{type(self).__name__.lower()}_right',
|
||||
image=self.result['shoes_right'],
|
||||
image_url=self.result['right_image_url'],
|
||||
mask_url=self.result['right_mask_url'],
|
||||
sacle=self.result['scale'],
|
||||
clothes_keypoint=self.result['clothes_keypoint'],
|
||||
position=self.calculate_start_point(self.result['keypoint'],
|
||||
self.result['scale'],
|
||||
self.result['clothes_keypoint'],
|
||||
self.result['body_point'],
|
||||
'right'))
|
||||
|
||||
layer.insert(right_layer)
|
||||
|
||||
def cut(self):
|
||||
"""
|
||||
Cut shoes mask into two pieces
|
||||
Returns:
|
||||
"""
|
||||
contour, _ = cv2.findContours(self.result['mask'], cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
||||
contours = sorted(contour, key=cv2.contourArea, reverse=True)
|
||||
|
||||
bounding_boxes = [cv2.boundingRect(c) for c in contours[:2]]
|
||||
(contours, bounding_boxes) = zip(*sorted(zip(contours[:2], bounding_boxes), key=lambda x: x[1][0], reverse=False))
|
||||
|
||||
epsilon_left = 0.001 * cv2.arcLength(contours[0], True)
|
||||
|
||||
approx_left = cv2.approxPolyDP(contours[0], epsilon_left, True)
|
||||
mask_left = np.zeros(self.result['final_image'].shape[:2], np.uint8)
|
||||
cv2.drawContours(mask_left, [approx_left], -1, 255, -1)
|
||||
item_mask_left = cv2.GaussianBlur(mask_left, (5, 5), 0)
|
||||
|
||||
rgba_image = rgb_to_rgba((self.result['final_image'].shape[0], self.result['final_image'].shape[1]), self.result['final_image'], item_mask_left)
|
||||
result_image = np.zeros_like(rgba_image)
|
||||
result_image[self.result['front_mask'] != 0] = rgba_image[self.result['front_mask'] != 0]
|
||||
result_left_image_pil = Image.fromarray(result_image, 'RGBA')
|
||||
result_left_image_pil = result_left_image_pil.resize((int(result_left_image_pil.width * self.result["scale"]), int(result_left_image_pil.height * self.result["scale"])), Image.LANCZOS)
|
||||
self.result['shoes_left'], self.result["left_image_url"], self.result["left_mask_url"] = upload_png_mask(result_left_image_pil, f"{generate_uuid()}")
|
||||
|
||||
epsilon_right = 0.001 * cv2.arcLength(contours[1], True)
|
||||
approx_right = cv2.approxPolyDP(contours[1], epsilon_right, True)
|
||||
mask_right = np.zeros(self.result['final_image'].shape[:2], np.uint8)
|
||||
cv2.drawContours(mask_right, [approx_right], -1, 255, -1)
|
||||
item_mask_right = cv2.GaussianBlur(mask_right, (5, 5), 0)
|
||||
|
||||
rgba_image = rgb_to_rgba((self.result['final_image'].shape[0], self.result['final_image'].shape[1]), self.result['final_image'], item_mask_right)
|
||||
result_image = np.zeros_like(rgba_image)
|
||||
result_image[self.result['front_mask'] != 0] = rgba_image[self.result['front_mask'] != 0]
|
||||
result_right_image_pil = Image.fromarray(result_image, 'RGBA')
|
||||
result_right_image_pil = result_right_image_pil.resize((int(result_right_image_pil.width * self.result["scale"]), int(result_right_image_pil.height * self.result["scale"])), Image.LANCZOS)
|
||||
self.result['shoes_right'], self.result["right_image_url"], self.result["right_mask_url"] = upload_png_mask(result_right_image_pil, f"{generate_uuid()}")
|
||||
|
||||
return item_mask_left, item_mask_right
|
||||
|
||||
@staticmethod
|
||||
def calculate_start_point(keypoint_type, scale, clothes_point, body_point, location):
|
||||
"""
|
||||
left shoes align left
|
||||
right shoes align right
|
||||
Args:
|
||||
keypoint_type: string, "toe"
|
||||
scale: float
|
||||
clothes_point: dict{'left': [x1, y1, z1], 'right': [x2, y2, z2]}
|
||||
body_point: dict, containing keypoint data of body figure
|
||||
location: string, indicates whether the start point belongs to right or left shoe
|
||||
|
||||
Returns:
|
||||
start_point: tuple (x', y')
|
||||
x' = y_body - y1 * scale
|
||||
y' = x_body - x1 * scale
|
||||
"""
|
||||
if location not in ['left', 'right']:
|
||||
raise KeyError(f'location value must be left or right but got {location}')
|
||||
side_indicator = f'{keypoint_type}_{location}'
|
||||
# clothes_point = {k: tuple(map(lambda x: int(scale * x), v[0: 2])) for k, v in clothes_point.items()}
|
||||
start_point = (body_point[side_indicator][1] - int(int(clothes_point[side_indicator].split("_")[1]) * scale),
|
||||
body_point[side_indicator][0] - int(int(clothes_point[side_indicator].split("_")[0]) * scale))
|
||||
return start_point
|
||||
46
app/service/design/items/top.py
Normal file
46
app/service/design/items/top.py
Normal file
@@ -0,0 +1,46 @@
|
||||
from .builder import ITEMS
|
||||
from .clothing import Clothing
|
||||
|
||||
|
||||
@ITEMS.register_module()
|
||||
class Top(Clothing):
|
||||
def __init__(self, pipeline, **kwargs):
|
||||
if pipeline is None:
|
||||
pipeline = [
|
||||
dict(type='LoadImageFromFile', path=kwargs['path'], color=kwargs['color'], print_dict=kwargs['print']),
|
||||
dict(type='KeypointDetection'),
|
||||
dict(type='ContourDetection'),
|
||||
dict(type='Segmentation', device='cpu', show=False, debug=kwargs['debug']),
|
||||
dict(type='Painting', painting_flag=True),
|
||||
dict(type='PrintPainting', print_flag=True),
|
||||
# dict(type='ImageShow', key=['image', 'mask', 'seg_visualize', 'pattern_image']),
|
||||
dict(type='Scaling'),
|
||||
dict(type='Split'),
|
||||
]
|
||||
kwargs.update(pipeline=pipeline)
|
||||
super(Top, self).__init__(**kwargs)
|
||||
|
||||
|
||||
@ITEMS.register_module()
|
||||
class Blouse(Top):
|
||||
def __init__(self, pipeline=None, **kwargs):
|
||||
super(Blouse, self).__init__(pipeline, **kwargs)
|
||||
|
||||
|
||||
@ITEMS.register_module()
|
||||
class Outwear(Top):
|
||||
def __init__(self, pipeline=None, **kwargs):
|
||||
super(Outwear, self).__init__(pipeline, **kwargs)
|
||||
|
||||
|
||||
@ITEMS.register_module()
|
||||
class Dress(Top):
|
||||
def __init__(self, pipeline=None, **kwargs):
|
||||
super(Dress, self).__init__(pipeline, **kwargs)
|
||||
|
||||
|
||||
# Men's clothing
|
||||
@ITEMS.register_module()
|
||||
class Tops(Top):
|
||||
def __init__(self, pipeline=None, **kwargs):
|
||||
super(Tops, self).__init__(pipeline, **kwargs)
|
||||
28
app/service/design/model_process_service.py
Normal file
28
app/service/design/model_process_service.py
Normal file
@@ -0,0 +1,28 @@
|
||||
import io
|
||||
|
||||
from app.service.utils.oss_client import oss_get_image, oss_upload_image
|
||||
|
||||
|
||||
def model_transpose(image_path):
|
||||
bucket = image_path.split("/", 1)[0]
|
||||
object_name = image_path.split("/", 1)[1]
|
||||
new_object_name = f'{object_name[:object_name.rfind(".")]}.png'
|
||||
image = oss_get_image(bucket=bucket, object_name=object_name, data_type="PIL")
|
||||
image = image.convert("RGBA")
|
||||
data = image.getdata()
|
||||
#
|
||||
new_data = []
|
||||
for item in data:
|
||||
if item[0] >= 256 and item[1] >= 256 and item[2] >= 256:
|
||||
new_data.append((255, 255, 255, 0))
|
||||
else:
|
||||
new_data.append(item)
|
||||
image.putdata(new_data)
|
||||
|
||||
image_data = io.BytesIO()
|
||||
image.save(image_data, format='PNG')
|
||||
image_data.seek(0)
|
||||
image_bytes = image_data.read()
|
||||
oss_upload_image(bucket=bucket, object_name=new_object_name, image_bytes=image_bytes)
|
||||
image_path = f"{bucket}/{new_object_name}"
|
||||
return image_path
|
||||
134
app/service/design/service.py
Normal file
134
app/service/design/service.py
Normal file
@@ -0,0 +1,134 @@
|
||||
import concurrent.futures
|
||||
|
||||
from app.core.config import PRIORITY_DICT
|
||||
from app.service.design.core.layer import Layer
|
||||
from app.service.design.items import build_item
|
||||
from app.service.design.utils.redis_utils import Redis
|
||||
from app.service.design.utils.synthesis_item import synthesis, synthesis_single
|
||||
from app.service.utils.decorator import RunTime
|
||||
|
||||
|
||||
def process_item(item, layers):
|
||||
# logging.info("process running.........")
|
||||
item.process()
|
||||
item.organize(layers)
|
||||
if item.result['name'] == "mannequin":
|
||||
return item.result['body_image'].size
|
||||
|
||||
|
||||
def update_progress(process_id, total):
|
||||
r = Redis()
|
||||
progress = r.read(key=process_id)
|
||||
if progress and total != 1:
|
||||
if int(progress) <= 100:
|
||||
r.write(key=process_id, value=int(progress) + int(100 / total))
|
||||
else:
|
||||
r.write(key=process_id, value=100)
|
||||
return progress
|
||||
elif total == 1:
|
||||
r.write(key=process_id, value=100)
|
||||
return progress
|
||||
else:
|
||||
r.write(key=process_id, value=int(100 / total))
|
||||
return progress
|
||||
|
||||
|
||||
def final_progress(process_id):
|
||||
r = Redis()
|
||||
progress = r.read(key=process_id)
|
||||
r.write(key=process_id, value=100)
|
||||
return progress
|
||||
|
||||
|
||||
@RunTime
|
||||
def generate(request_data):
|
||||
return_response = {}
|
||||
request_data = request_data.dict()
|
||||
assert "process_id" in request_data.keys(), "Need process_id parameters"
|
||||
|
||||
objects = request_data['objects']
|
||||
# insert_keypoint_cache(objects)
|
||||
process_id = request_data['process_id']
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
# 提交每个对象的处理任务
|
||||
futures = {executor.submit(process_object, cfg, process_id, len(objects)): obj for obj, cfg in enumerate(objects)}
|
||||
# 获取处理结果
|
||||
for future in concurrent.futures.as_completed(futures):
|
||||
obj = futures[future]
|
||||
|
||||
result = future.result()
|
||||
return_response[obj] = result
|
||||
final_progress(process_id)
|
||||
return return_response
|
||||
|
||||
|
||||
def process_object(cfg, process_id, total):
|
||||
basic_info = cfg.get('basic')
|
||||
items_response = {
|
||||
'layers': []
|
||||
}
|
||||
if cfg.get('basic')['single_overall'] == 'overall':
|
||||
basic_info['debug'] = False
|
||||
items = [build_item(x, default_args=basic_info) for x in cfg.get('items')]
|
||||
layers = Layer()
|
||||
body_size = None
|
||||
futures = []
|
||||
for item in items:
|
||||
futures = [process_item(item, layers)]
|
||||
for future in futures:
|
||||
if future is not None:
|
||||
body_size = future
|
||||
# 是否自定义排序
|
||||
if basic_info.get('layer_order', False):
|
||||
layers = sorted(layers.layer, key=lambda s: s.get("priority", float('inf')))
|
||||
else:
|
||||
layers = sorted(layers.layer, key=lambda x: PRIORITY_DICT.get(x['name'], float('inf')))
|
||||
# 合成
|
||||
items_response['synthesis_url'] = synthesis(layers, body_size)
|
||||
|
||||
for lay in layers:
|
||||
items_response['layers'].append({
|
||||
'image_category': lay['name'],
|
||||
'position': lay['position'],
|
||||
'priority': lay.get("priority", None),
|
||||
'resize_scale': lay['resize_scale'] if "resize_scale" in lay.keys() else None,
|
||||
'image_size': lay['image'] if lay['image'] is None else lay['image'].size,
|
||||
'gradient_string': lay['gradient_string'] if 'gradient_string' in lay.keys() else "",
|
||||
'mask_url': lay['mask_url'],
|
||||
'image_url': lay['image_url'] if 'image_url' in lay.keys() else None,
|
||||
'pattern_image_url': lay['pattern_image_url'] if 'pattern_image_url' in lay.keys() else None,
|
||||
|
||||
# 'image': lay['image'],
|
||||
# 'mask_image': lay['mask_image'],
|
||||
})
|
||||
elif cfg.get('basic')['single_overall'] == 'single':
|
||||
assert cfg.get('basic')['switch_category'] in [x['type'] for x in cfg.get('items')], "Lack of switch_category parameters "
|
||||
basic_info['debug'] = False
|
||||
for item in cfg.get('items'):
|
||||
if item['type'] == cfg.get('basic')['switch_category']:
|
||||
item = build_item(item, default_args=cfg.get('basic'))
|
||||
item.process()
|
||||
items_response['layers'].append({
|
||||
'image_category': f"{item.result['name']}_front",
|
||||
'image_size': item.result['back_image'].size if item.result['back_image'] else None,
|
||||
'position': None,
|
||||
'priority': 0,
|
||||
'image_url': item.result['front_image_url'],
|
||||
'mask_url': item.result['front_mask_url'],
|
||||
"gradient_string": item.result['gradient_string'] if 'gradient_string' in item.result.keys() else ""
|
||||
|
||||
})
|
||||
items_response['layers'].append({
|
||||
'image_category': f"{item.result['name']}_back",
|
||||
'image_size': item.result['front_image'].size if item.result['front_image'] else None,
|
||||
'position': None,
|
||||
'priority': 0,
|
||||
'image_url': item.result['back_image_url'],
|
||||
'mask_url': item.result['back_mask_url'],
|
||||
"gradient_string": item.result['gradient_string'] if 'gradient_string' in item.result.keys() else ""
|
||||
|
||||
})
|
||||
items_response['synthesis_url'] = synthesis_single(item.result['front_image'], item.result['back_image'])
|
||||
break
|
||||
update_progress(process_id, total)
|
||||
return items_response
|
||||
0
app/service/design/utils/__init__.py
Normal file
0
app/service/design/utils/__init__.py
Normal file
24
app/service/design/utils/conversion_image.py
Normal file
24
app/service/design/utils/conversion_image.py
Normal file
@@ -0,0 +1,24 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: UTF-8 -*-
|
||||
"""
|
||||
@Project :trinity_client
|
||||
@File :conversion_image.py
|
||||
@Author :周成融
|
||||
@Date :2023/8/21 10:40:29
|
||||
@detail :
|
||||
"""
|
||||
import numpy as np
|
||||
|
||||
|
||||
def rgb_to_rgba(rgb_size, rgb_image, mask):
|
||||
alpha_channel = np.full(rgb_size, 255, dtype=np.uint8)
|
||||
# 创建四通道的结果图像
|
||||
rgba_image = np.dstack((rgb_image, alpha_channel))
|
||||
alpha_channel = np.where(mask > 0, 255, 0)
|
||||
# 更新RGBA图像的透明度通道
|
||||
rgba_image[:, :, 3] = alpha_channel
|
||||
return rgba_image
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
image = open("")
|
||||
140
app/service/design/utils/design_ensemble.py
Normal file
140
app/service/design/utils/design_ensemble.py
Normal file
@@ -0,0 +1,140 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: UTF-8 -*-
|
||||
"""
|
||||
@Project :trinity_client
|
||||
@File :design_ensemble.py
|
||||
@Author :周成融
|
||||
@Date :2023/8/16 19:36:21
|
||||
@detail :发起请求 获取推理结果
|
||||
"""
|
||||
import logging
|
||||
|
||||
import cv2
|
||||
import mmcv
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import tritonclient.http as httpclient
|
||||
|
||||
from app.core.config import *
|
||||
|
||||
"""
|
||||
keypoint
|
||||
预处理 推理 后处理
|
||||
"""
|
||||
|
||||
|
||||
def keypoint_preprocess(img_path):
|
||||
img = mmcv.imread(img_path)
|
||||
img_scale = (256, 256)
|
||||
img, w_scale, h_scale = mmcv.imresize(img, img_scale, return_scale=True)
|
||||
img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
|
||||
preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0)
|
||||
return preprocessed_img, (w_scale, h_scale)
|
||||
|
||||
|
||||
# @ RunTime
|
||||
# 推理
|
||||
def get_keypoint_result(image, site):
|
||||
keypoint_result = None
|
||||
try:
|
||||
image, scale_factor = keypoint_preprocess(image)
|
||||
client = httpclient.InferenceServerClient(url=DESIGN_MODEL_URL)
|
||||
transformed_img = image.astype(np.float32)
|
||||
inputs = [httpclient.InferInput(f"input", transformed_img.shape, datatype="FP32")]
|
||||
inputs[0].set_data_from_numpy(transformed_img, binary_data=True)
|
||||
outputs = [httpclient.InferRequestedOutput(f"output", binary_data=True)]
|
||||
results = client.infer(model_name=f"keypoint_{site}_ocrnet_hr18", inputs=inputs, outputs=outputs)
|
||||
inference_output = torch.from_numpy(results.as_numpy(f'output'))
|
||||
keypoint_result = keypoint_postprocess(inference_output, scale_factor)
|
||||
except Exception as e:
|
||||
logging.warning(f"get_keypoint_result : {e}")
|
||||
return keypoint_result
|
||||
|
||||
|
||||
def keypoint_postprocess(output, scale_factor):
|
||||
max_indices = torch.argmax(output.view(output.size(0), output.size(1), -1), dim=2).unsqueeze(dim=2)
|
||||
max_coords = torch.cat((max_indices / output.size(3), max_indices % output.size(3)), dim=2)
|
||||
segment_result = max_coords.numpy()
|
||||
scale_factor = [1 / x for x in scale_factor[::-1]]
|
||||
scale_matrix = np.diag(scale_factor)
|
||||
nan = np.isinf(scale_matrix)
|
||||
scale_matrix[nan] = 0
|
||||
return np.ceil(np.dot(segment_result, scale_matrix) * 4)
|
||||
|
||||
|
||||
"""
|
||||
seg
|
||||
预处理 推理 后处理
|
||||
"""
|
||||
|
||||
|
||||
# KNet
|
||||
def seg_preprocess(img_path):
|
||||
img = mmcv.imread(img_path)
|
||||
ori_shape = img.shape[:2]
|
||||
img_scale_w, img_scale_h = ori_shape
|
||||
if ori_shape[0] > 1024:
|
||||
img_scale_w = 1024
|
||||
if ori_shape[1] > 1024:
|
||||
img_scale_h = 1024
|
||||
scale_factor = []
|
||||
img, x, y = mmcv.imresize(img, (img_scale_w, img_scale_h), return_scale=True)
|
||||
scale_factor.append(x)
|
||||
scale_factor.append(y)
|
||||
img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
|
||||
preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0)
|
||||
return preprocessed_img, ori_shape
|
||||
|
||||
|
||||
# @ RunTime
|
||||
def get_seg_result(image_id, image):
|
||||
image, ori_shape = seg_preprocess(image)
|
||||
client = httpclient.InferenceServerClient(url=f"{DESIGN_MODEL_URL}")
|
||||
transformed_img = image.astype(np.float32)
|
||||
# 输入集
|
||||
inputs = [
|
||||
httpclient.InferInput(SEGMENTATION['input'], transformed_img.shape, datatype="FP32")
|
||||
]
|
||||
inputs[0].set_data_from_numpy(transformed_img, binary_data=True)
|
||||
# 输出集
|
||||
outputs = [
|
||||
httpclient.InferRequestedOutput(SEGMENTATION['output'], binary_data=True),
|
||||
]
|
||||
results = client.infer(model_name=SEGMENTATION['new_model_name'], inputs=inputs, outputs=outputs)
|
||||
# 推理
|
||||
# 取结果
|
||||
inference_output1 = results.as_numpy(SEGMENTATION['output'])
|
||||
seg_result = seg_postprocess(int(image_id), inference_output1, ori_shape)
|
||||
return seg_result
|
||||
|
||||
|
||||
# no cache
|
||||
def seg_postprocess(image_id, output, ori_shape):
|
||||
seg_logit = F.interpolate(torch.tensor(output).float(), size=ori_shape, scale_factor=None, mode='bilinear', align_corners=False)
|
||||
seg_pred = seg_logit.cpu().numpy()
|
||||
return seg_pred[0]
|
||||
|
||||
|
||||
def key_point_show(image_path, key_point_result=None):
|
||||
img = cv2.imread(image_path)
|
||||
points_list = key_point_result
|
||||
point_size = 1
|
||||
point_color = (0, 0, 255) # BGR
|
||||
thickness = 4 # 可以为 0 、4、8
|
||||
for point in points_list:
|
||||
cv2.circle(img, point[::-1], point_size, point_color, thickness)
|
||||
cv2.imshow("0", img)
|
||||
cv2.waitKey(0)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
image = cv2.imread("./14162b58-f259-4833-98cb-89b9b496b251.jfif")
|
||||
a = get_keypoint_result(image, "up")
|
||||
new_list = []
|
||||
print(list)
|
||||
for i in a[0]:
|
||||
new_list.append((int(i[0]), int(i[1])))
|
||||
key_point_show("./14162b58-f259-4833-98cb-89b9b496b251.jfif", new_list)
|
||||
# a = get_seg_result(1, image)
|
||||
print(a)
|
||||
99
app/service/design/utils/redis_utils.py
Normal file
99
app/service/design/utils/redis_utils.py
Normal file
@@ -0,0 +1,99 @@
|
||||
import redis
|
||||
|
||||
from app.core.config import REDIS_HOST, REDIS_PORT
|
||||
|
||||
|
||||
class Redis(object):
|
||||
"""
|
||||
redis数据库操作
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _get_r():
|
||||
host = REDIS_HOST
|
||||
port = REDIS_PORT
|
||||
db = 0
|
||||
r = redis.StrictRedis(host, port, db)
|
||||
return r
|
||||
|
||||
@classmethod
|
||||
def write(cls, key, value, expire=None):
|
||||
"""
|
||||
写入键值对
|
||||
"""
|
||||
# 判断是否有过期时间,没有就设置默认值
|
||||
if expire:
|
||||
expire_in_seconds = expire
|
||||
else:
|
||||
expire_in_seconds = 100
|
||||
r = cls._get_r()
|
||||
r.set(key, value, ex=expire_in_seconds)
|
||||
|
||||
@classmethod
|
||||
def read(cls, key):
|
||||
"""
|
||||
读取键值对内容
|
||||
"""
|
||||
r = cls._get_r()
|
||||
value = r.get(key)
|
||||
return value.decode('utf-8') if value else value
|
||||
|
||||
@classmethod
|
||||
def hset(cls, name, key, value):
|
||||
"""
|
||||
写入hash表
|
||||
"""
|
||||
r = cls._get_r()
|
||||
r.hset(name, key, value)
|
||||
|
||||
@classmethod
|
||||
def hget(cls, name, key):
|
||||
"""
|
||||
读取指定hash表的键值
|
||||
"""
|
||||
r = cls._get_r()
|
||||
value = r.hget(name, key)
|
||||
return value.decode('utf-8') if value else value
|
||||
|
||||
@classmethod
|
||||
def hgetall(cls, name):
|
||||
"""
|
||||
获取指定hash表所有的值
|
||||
"""
|
||||
r = cls._get_r()
|
||||
return r.hgetall(name)
|
||||
|
||||
@classmethod
|
||||
def delete(cls, *names):
|
||||
"""
|
||||
删除一个或者多个
|
||||
"""
|
||||
r = cls._get_r()
|
||||
r.delete(*names)
|
||||
|
||||
@classmethod
|
||||
def hdel(cls, name, key):
|
||||
"""
|
||||
删除指定hash表的键值
|
||||
"""
|
||||
r = cls._get_r()
|
||||
r.hdel(name, key)
|
||||
|
||||
@classmethod
|
||||
def expire(cls, name, expire=None):
|
||||
"""
|
||||
设置过期时间
|
||||
"""
|
||||
if expire:
|
||||
expire_in_seconds = expire
|
||||
else:
|
||||
expire_in_seconds = 100
|
||||
r = cls._get_r()
|
||||
r.expire(name, expire_in_seconds)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
redis_client = Redis()
|
||||
# print(redis_client.write(key="1230", value=0))
|
||||
redis_client.write(key="1230", value=10)
|
||||
# print(redis_client.read(key="1230"))
|
||||
167
app/service/design/utils/synthesis_item.py
Normal file
167
app/service/design/utils/synthesis_item.py
Normal file
@@ -0,0 +1,167 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: UTF-8 -*-
|
||||
"""
|
||||
@Project :trinity_client
|
||||
@File :synthesis_item.py
|
||||
@Author :周成融
|
||||
@Date :2023/8/26 14:13:04
|
||||
@detail :
|
||||
"""
|
||||
import io
|
||||
import logging
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from app.service.utils.generate_uuid import generate_uuid
|
||||
from app.service.utils.oss_client import oss_upload_image
|
||||
|
||||
|
||||
def positioning(all_mask_shape, mask_shape, offset):
|
||||
all_start = 0
|
||||
all_end = 0
|
||||
mask_start = 0
|
||||
mask_end = 0
|
||||
if offset == 0:
|
||||
all_start = 0
|
||||
all_end = min(all_mask_shape, mask_shape)
|
||||
|
||||
mask_start = 0
|
||||
mask_end = min(all_mask_shape, mask_shape)
|
||||
elif offset > 0:
|
||||
all_start = min(offset, all_mask_shape)
|
||||
all_end = min(offset + mask_shape, all_mask_shape)
|
||||
|
||||
mask_start = 0
|
||||
mask_end = 0 if offset > all_mask_shape else min(all_mask_shape - offset, mask_shape)
|
||||
elif offset < 0:
|
||||
if abs(offset) > mask_shape:
|
||||
all_start = 0
|
||||
all_end = 0
|
||||
else:
|
||||
all_start = 0
|
||||
if mask_shape - abs(offset) > all_mask_shape:
|
||||
all_end = min(mask_shape - abs(offset), all_mask_shape)
|
||||
else:
|
||||
all_end = mask_shape - abs(offset)
|
||||
|
||||
if abs(offset) > mask_shape:
|
||||
mask_start = mask_shape
|
||||
mask_end = mask_shape
|
||||
else:
|
||||
mask_start = abs(offset)
|
||||
if mask_shape - abs(offset) >= all_mask_shape:
|
||||
mask_end = all_mask_shape + abs(offset)
|
||||
else:
|
||||
mask_end = mask_shape
|
||||
return all_start, all_end, mask_start, mask_end
|
||||
|
||||
|
||||
# @RunTime
|
||||
def synthesis(data, size):
|
||||
# 创建底图
|
||||
base_image = Image.new('RGBA', size, (0, 0, 0, 0))
|
||||
try:
|
||||
|
||||
all_mask_shape = (size[1], size[0])
|
||||
top_outer_mask = np.zeros(all_mask_shape, dtype=np.uint8)
|
||||
bottom_outer_mask = np.zeros(all_mask_shape, dtype=np.uint8)
|
||||
|
||||
top = True
|
||||
bottom = True
|
||||
i = len(data)
|
||||
while i:
|
||||
i -= 1
|
||||
if top and data[i]['name'] in ["blouse_front", "outwear_front", "dress_front", "tops_front"]:
|
||||
top = False
|
||||
mask_shape = data[i]['mask'].shape
|
||||
y_offset, x_offset = data[i]['position']
|
||||
# 初始化叠加区域的起始和结束位置
|
||||
all_y_start, all_y_end, mask_y_start, mask_y_end = positioning(all_mask_shape=all_mask_shape[0], mask_shape=mask_shape[0], offset=y_offset)
|
||||
all_x_start, all_x_end, mask_x_start, mask_x_end = positioning(all_mask_shape=all_mask_shape[1], mask_shape=mask_shape[1], offset=x_offset)
|
||||
# 将叠加区域赋值为相应的像素值
|
||||
top_outer_mask[all_y_start:all_y_end, all_x_start:all_x_end] = data[i]['mask'][mask_y_start:mask_y_end, mask_x_start:mask_x_end]
|
||||
elif bottom and data[i]['name'] in ["trousers_front", "skirt_front", "bottoms_front"]:
|
||||
bottom = False
|
||||
mask_shape = data[i]['mask'].shape
|
||||
y_offset, x_offset = data[i]['position']
|
||||
# 初始化叠加区域的起始和结束位置
|
||||
all_y_start, all_y_end, mask_y_start, mask_y_end = positioning(all_mask_shape=all_mask_shape[0], mask_shape=mask_shape[0], offset=y_offset)
|
||||
all_x_start, all_x_end, mask_x_start, mask_x_end = positioning(all_mask_shape=all_mask_shape[1], mask_shape=mask_shape[1], offset=x_offset)
|
||||
# 将叠加区域赋值为相应的像素值
|
||||
bottom_outer_mask[all_y_start:all_y_end, all_x_start:all_x_end] = data[i]['mask'][mask_y_start:mask_y_end, mask_x_start:mask_x_end]
|
||||
elif bottom is False and top is False:
|
||||
break
|
||||
|
||||
all_mask = cv2.bitwise_or(top_outer_mask, bottom_outer_mask)
|
||||
|
||||
for layer in data:
|
||||
if layer['image'] is not None:
|
||||
if layer['name'] != "body":
|
||||
test_image = Image.new('RGBA', size, (0, 0, 0, 0))
|
||||
test_image.paste(layer['image'], (layer['position'][1], layer['position'][0]), layer['image'])
|
||||
mask_data = np.where(all_mask > 0, 255, 0).astype(np.uint8)
|
||||
mask_alpha = Image.fromarray(mask_data)
|
||||
cropped_image = Image.composite(test_image, Image.new("RGBA", test_image.size, (255, 255, 255, 0)), mask_alpha)
|
||||
base_image.paste(cropped_image, (0, 0), cropped_image)
|
||||
else:
|
||||
base_image.paste(layer['image'], (layer['position'][1], layer['position'][0]), layer['image'])
|
||||
|
||||
result_image = base_image
|
||||
|
||||
with io.BytesIO() as output:
|
||||
result_image.save(output, format='PNG')
|
||||
data = output.getvalue()
|
||||
|
||||
image_data = io.BytesIO()
|
||||
result_image.save(image_data, format='PNG')
|
||||
image_data.seek(0)
|
||||
|
||||
# oss upload
|
||||
image_bytes = image_data.read()
|
||||
bucket_name = 'aida-results'
|
||||
object_name = f'result_{generate_uuid()}.png'
|
||||
req = oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes)
|
||||
return f"{bucket_name}/{object_name}"
|
||||
# return f"aida-results/{minio_client.put_object('aida-results', f'result_{generate_uuid()}.png', io.BytesIO(image_bytes), len(image_bytes), content_type='image/png').object_name}"
|
||||
|
||||
# object_name = f'result_{generate_uuid()}.png'
|
||||
# response = s3.put_object(Bucket="aida-results", Key=object_name, Body=data, ContentType='image/png')
|
||||
# object_url = f"aida-results/{object_name}"
|
||||
# if response['ResponseMetadata']['HTTPStatusCode'] == 200:
|
||||
# return object_url
|
||||
# else:
|
||||
# return ""
|
||||
|
||||
except Exception as e:
|
||||
logging.warning(f"synthesis runtime exception : {e}")
|
||||
|
||||
|
||||
def synthesis_single(front_image, back_image):
|
||||
result_image = None
|
||||
if front_image:
|
||||
result_image = front_image
|
||||
if back_image:
|
||||
result_image.paste(back_image, (0, 0), back_image)
|
||||
|
||||
# with io.BytesIO() as output:
|
||||
# result_image.save(output, format='PNG')
|
||||
# data = output.getvalue()
|
||||
# object_name = f'result_{generate_uuid()}.png'
|
||||
# response = s3.put_object(Bucket="aida-results", Key=object_name, Body=data, ContentType='image/png')
|
||||
# object_url = f"aida-results/{object_name}"
|
||||
# if response['ResponseMetadata']['HTTPStatusCode'] == 200:
|
||||
# return object_url
|
||||
# else:
|
||||
# return ""
|
||||
image_data = io.BytesIO()
|
||||
result_image.save(image_data, format='PNG')
|
||||
image_data.seek(0)
|
||||
image_bytes = image_data.read()
|
||||
# return f"aida-results/{minio_client.put_object('aida-results', f'result_{generate_uuid()}.png', io.BytesIO(image_bytes), len(image_bytes), content_type='image/png').object_name}"
|
||||
# oss upload
|
||||
bucket_name = 'aida-results'
|
||||
object_name = f'result_{generate_uuid()}.png'
|
||||
req = oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes)
|
||||
return f"{bucket_name}/{object_name}"
|
||||
45
app/service/design/utils/upload_image.py
Normal file
45
app/service/design/utils/upload_image.py
Normal file
@@ -0,0 +1,45 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: UTF-8 -*-
|
||||
"""
|
||||
@Project :trinity_client
|
||||
@File :upload_image.py
|
||||
@Author :周成融
|
||||
@Date :2023/8/28 13:49:20
|
||||
@detail :
|
||||
"""
|
||||
import io
|
||||
import logging
|
||||
|
||||
import cv2
|
||||
|
||||
from app.core.config import *
|
||||
from app.service.utils.oss_client import oss_upload_image
|
||||
|
||||
|
||||
# @RunTime
|
||||
def upload_png_mask(front_image, object_name, mask=None):
|
||||
try:
|
||||
mask_url = None
|
||||
if mask is not None:
|
||||
mask_inverted = cv2.bitwise_not(mask)
|
||||
# 将掩模的3通道转换为4通道,白色部分不透明,黑色部分透明
|
||||
rgba_image = cv2.cvtColor(mask_inverted, cv2.COLOR_BGR2BGRA)
|
||||
rgba_image[rgba_image[:, :, 0] == 0] = [0, 0, 0, 0]
|
||||
# image_bytes = io.BytesIO()
|
||||
# image_bytes.write(cv2.imencode('.png', rgba_image)[1].tobytes())
|
||||
# image_bytes.seek(0)
|
||||
# mask_url = f"{AIDA_CLOTHING}/{minio_client.put_object('aida-clothing', f'mask/mask_{object_name}.png', image_bytes, len(image_bytes.getvalue()), content_type='image/png').object_name}"
|
||||
# oss upload ####################
|
||||
req = oss_upload_image(bucket=AIDA_CLOTHING, object_name=f"mask/mask_{object_name}.png", image_bytes=cv2.imencode('.png', rgba_image)[1])
|
||||
mask_url = f"{AIDA_CLOTHING}/mask/mask_{object_name}.png"
|
||||
|
||||
image_data = io.BytesIO()
|
||||
front_image.save(image_data, format='PNG')
|
||||
image_data.seek(0)
|
||||
image_bytes = image_data.read()
|
||||
# image_url = f"{AIDA_CLOTHING}/{minio_client.put_object('aida-clothing', f'image/image_{object_name}.png', io.BytesIO(image_bytes), len(image_bytes), content_type='image/png').object_name}"
|
||||
req = oss_upload_image(bucket=AIDA_CLOTHING, object_name=f"image/image_{object_name}.png", image_bytes=image_bytes)
|
||||
image_url = f"{AIDA_CLOTHING}/image/image_{object_name}.png"
|
||||
return front_image, image_url, mask_url
|
||||
except Exception as e:
|
||||
logging.warning(f"upload_png_mask runtime exception : {e}")
|
||||
374
app/service/design_pre_processing/service.py
Normal file
374
app/service/design_pre_processing/service.py
Normal file
@@ -0,0 +1,374 @@
|
||||
import logging
|
||||
import time
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
import tritonclient.grpc as grpcclient
|
||||
from urllib3.exceptions import ResponseError
|
||||
|
||||
from app.core.config import *
|
||||
from app.schemas.pre_processing import DesignPreProcessingModel
|
||||
from app.service.design.utils.design_ensemble import get_keypoint_result
|
||||
from app.service.utils.oss_client import oss_get_image, oss_upload_image
|
||||
|
||||
|
||||
class DesignPreprocessing:
|
||||
# def __init__(self):
|
||||
# self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
||||
|
||||
# @ RunTime
|
||||
def pipeline(self, image_list):
|
||||
sketches_list = self.read_image(image_list)
|
||||
logging.info("read image success")
|
||||
|
||||
bounding_box_sketches_list = self.bounding_box(sketches_list)
|
||||
logging.info("bounding box image success")
|
||||
|
||||
super_resolution_list = self.super_resolution(bounding_box_sketches_list)
|
||||
logging.info("super_resolution_list image success")
|
||||
|
||||
infer_sketches_list = self.infer_image(super_resolution_list)
|
||||
logging.info("infer image success")
|
||||
|
||||
result = self.composing_image(infer_sketches_list)
|
||||
logging.info("Replenish white edge image success")
|
||||
|
||||
for d in result:
|
||||
if 'image_obj' in d:
|
||||
del d['image_obj']
|
||||
if 'obj' in d:
|
||||
del d['obj']
|
||||
if 'keypoint_result' in d:
|
||||
del d['keypoint_result']
|
||||
return result
|
||||
|
||||
def read_image(self, image_list):
|
||||
for obj in image_list:
|
||||
# file = self.minio_client.get_object(obj['image_url'].split("/", 1)[0], obj['image_url'].split("/", 1)[1]).data
|
||||
# image = cv2.imdecode(np.frombuffer(file, np.uint8), 1)
|
||||
image = oss_get_image(bucket=obj['image_url'].split("/", 1)[0], object_name=obj['image_url'].split("/", 1)[1], data_type="cv2")
|
||||
if len(image.shape) == 2:
|
||||
image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
|
||||
elif image.shape[2] == 4: # 如果是四通道 mask
|
||||
image = image[:, :, :3]
|
||||
obj["image_obj"] = image
|
||||
return image_list
|
||||
|
||||
# @ RunTime
|
||||
def bounding_box(self, image_list):
|
||||
for item in image_list:
|
||||
image = item['image_obj']
|
||||
# 使用Canny边缘检测来检测物体的轮廓
|
||||
edges = cv2.Canny(image, 50, 150)
|
||||
# 查找轮廓
|
||||
contours, _ = cv2.findContours(edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
||||
# 初始化包围所有外接矩形的大矩形的坐标
|
||||
x_min, y_min, x_max, y_max = float('inf'), float('inf'), -1, -1
|
||||
# 遍历所有外接矩形,更新大矩形的坐标
|
||||
for contour in contours:
|
||||
x, y, w, h = cv2.boundingRect(contour)
|
||||
x_min = min(x_min, x)
|
||||
y_min = min(y_min, y)
|
||||
x_max = max(x_max, x + w)
|
||||
y_max = max(y_max, y + h)
|
||||
|
||||
if IF_DEBUG_SHOW:
|
||||
image_with_big_rect = cv2.rectangle(image.copy(), (x_min, y_min), (x_max, y_max), (0, 255, 0), 2)
|
||||
cv2.imshow("bounding_box image", image_with_big_rect)
|
||||
cv2.waitKey(0)
|
||||
|
||||
# 根据大矩形的坐标来裁剪原始图像
|
||||
if len(contours) > 0:
|
||||
cropped_image = image[y_min:y_max, x_min:x_max]
|
||||
item['obj'] = cropped_image # 新shape图像
|
||||
# 取消直接覆盖,新增size判断
|
||||
# try:
|
||||
# # 覆盖到minio
|
||||
# image_bytes = cv2.imencode(".jpg", cropped_image)[1].tobytes()
|
||||
# self.minio_client.put_object(item['image_url'].split("/", 1)[0], item['image_url'].split("/", 1)[1], io.BytesIO(image_bytes), len(image_bytes), content_type="image/jpeg", )
|
||||
# print(f"Object '{item['image_url'].split('/', 1)[1]}' overwritten successfully.")
|
||||
# except ResponseError as err:
|
||||
# print(f"Error: {err}")
|
||||
else:
|
||||
item['obj'] = image
|
||||
return image_list
|
||||
|
||||
def super_resolution(self, image_list):
|
||||
for item in image_list:
|
||||
# 判断 两边是否同时都小于512 因为此处做四倍超分
|
||||
if item['obj'].shape[0] <= 512 and item['obj'].shape[1] <= 512:
|
||||
# 如果任意一边小于256则超分
|
||||
if item['obj'].shape[0] <= 256 or item['obj'].shape[1] <= 256:
|
||||
# 超分
|
||||
img = item['obj'].astype(np.float32) / 255.
|
||||
sample = np.transpose(img if img.shape[2] == 1 else img[:, :, [2, 1, 0]], (2, 0, 1))
|
||||
sample = torch.from_numpy(sample).float().unsqueeze(0).numpy()
|
||||
inputs = [
|
||||
grpcclient.InferInput("input", sample.shape, datatype="FP32")
|
||||
]
|
||||
inputs[0].set_data_from_numpy(sample)
|
||||
triton_client = grpcclient.InferenceServerClient(url=SR_TRITON_URL)
|
||||
result = triton_client.infer(model_name=SR_MODEL_NAME, inputs=inputs)
|
||||
result_image = result.as_numpy(f'output')[0]
|
||||
sr_output = torch.tensor(result_image)
|
||||
output = sr_output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
|
||||
if output.ndim == 3:
|
||||
output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0)) # CHW-RGB to HCW-BGR
|
||||
output = (output * 255.0).round().astype(np.uint8)
|
||||
item['obj'] = output
|
||||
try:
|
||||
# 覆盖到minio
|
||||
image_bytes = cv2.imencode(".jpg", item['obj'])[1].tobytes()
|
||||
# self.minio_client.put_object(item['image_url'].split("/", 1)[0], item['image_url'].split("/", 1)[1], io.BytesIO(image_bytes), len(image_bytes), content_type="image/jpeg", )
|
||||
bucket_name = item['image_url'].split("/", 1)[0]
|
||||
object_name = item['image_url'].split("/", 1)[1]
|
||||
oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes)
|
||||
print(f"Object '{item['image_url'].split('/', 1)[1]}' overwritten successfully.")
|
||||
except ResponseError as err:
|
||||
print(f"Error: {err}")
|
||||
return image_list
|
||||
|
||||
# @ RunTime
|
||||
def infer_image(self, image_list):
|
||||
for sketch in image_list:
|
||||
# 小写
|
||||
image_category = sketch['image_category'].lower()
|
||||
# 判断上下装
|
||||
sketch['site'] = 'up' if image_category in ['blouse', 'outwear', 'dress', 'tops'] else 'down'
|
||||
# 推理得到keypoint
|
||||
sketch['keypoint_result'] = self.keypoint_cache(sketch)
|
||||
|
||||
if IF_DEBUG_SHOW:
|
||||
debug_show_image = sketch['obj'].copy()
|
||||
points_list = []
|
||||
point_size = 1
|
||||
point_color = (0, 0, 255) # BGR
|
||||
thickness = 4 # 可以为 0 、4、8
|
||||
for i in sketch['keypoint_result'].values():
|
||||
points_list.append((int(i[1]), int(i[0])))
|
||||
for point in points_list:
|
||||
cv2.circle(debug_show_image, point, point_size, point_color, thickness)
|
||||
cv2.imshow("", debug_show_image)
|
||||
cv2.waitKey(0)
|
||||
# # 关键点在上部则推理seg
|
||||
# if sketch["site"] == "up":
|
||||
# # 判断seg缓存是否存在,是否与当前图片shape一致
|
||||
# seg_result = self.search_seg_result(sketch["image_id"], sketch["obj"].shape)
|
||||
# if seg_result is False:
|
||||
# # 推理seg + 保存
|
||||
# seg_result = get_seg_result(sketch['image_id'], sketch['obj'])
|
||||
return image_list
|
||||
|
||||
# @ RunTime
|
||||
def composing_image(self, image_list):
|
||||
for image in image_list:
|
||||
''' 比例相同 整合上下装代码'''
|
||||
image_width = image['obj'].shape[1]
|
||||
waist_width = image['keypoint_result']['waistband_right'][1] - image['keypoint_result']['waistband_left'][1]
|
||||
scale = 0.4
|
||||
if waist_width / scale >= image_width:
|
||||
add_width = int((waist_width / scale - image_width) / 2)
|
||||
ret = cv2.copyMakeBorder(image['obj'], 0, 0, add_width, add_width, cv2.BORDER_CONSTANT, value=(256, 256, 256))
|
||||
if IF_DEBUG_SHOW:
|
||||
cv2.imshow("composing_image", ret)
|
||||
cv2.waitKey(0)
|
||||
image_bytes = cv2.imencode(".jpg", ret)[1].tobytes()
|
||||
# image['show_image_url'] = f"{image['image_url'].split('/', 1)[0]}/{self.minio_client.put_object(image['image_url'].split('/', 1)[0], image['image_url'].split('/', 1)[1].replace('.', '-show.'), io.BytesIO(image_bytes), len(image_bytes), content_type='image/jpeg').object_name}"
|
||||
bucket_name = image['image_url'].split('/', 1)[0]
|
||||
object_name = image['image_url'].split('/', 1)[1].replace('.', '-show.')
|
||||
oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes)
|
||||
image['show_image_url'] = f"{bucket_name}/{object_name}"
|
||||
else:
|
||||
image_bytes = cv2.imencode(".jpg", image['obj'])[1].tobytes()
|
||||
# image['show_image_url'] = f"{image['image_url'].split('/', 1)[0]}/{self.minio_client.put_object(image['image_url'].split('/', 1)[0], image['image_url'].split('/', 1)[1].replace('.', '-show.'), io.BytesIO(image_bytes), len(image_bytes), content_type='image/jpeg').object_name}"
|
||||
bucket_name = image['image_url'].split('/', 1)[0]
|
||||
object_name = image['image_url'].split('/', 1)[1].replace('.', '-show.')
|
||||
oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes)
|
||||
image['show_image_url'] = f"{bucket_name}/{object_name}"
|
||||
|
||||
# if image['site'] == 'down':
|
||||
# image_width = image['obj'].shape[1]
|
||||
# waist_width = image['keypoint_result']['waistband_right'][1] - image['keypoint_result']['waistband_left'][1]
|
||||
# scale = 0.4
|
||||
# if waist_width / scale >= image_width:
|
||||
# add_width = int((waist_width / scale - image_width) / 2)
|
||||
# ret = cv2.copyMakeBorder(image['obj'], 0, 0, add_width, add_width, cv2.BORDER_CONSTANT, value=(256, 256, 256))
|
||||
# if IF_DEBUG_SHOW:
|
||||
# cv2.imshow("composing_image", ret)
|
||||
# cv2.waitKey(0)
|
||||
# image_bytes = cv2.imencode(".jpg", ret)[1].tobytes()
|
||||
# # image['show_image_url'] = f"{image['image_url'].split('/', 1)[0]}/{self.minio_client.put_object(image['image_url'].split('/', 1)[0], image['image_url'].split('/', 1)[1].replace('.', '-show.'), io.BytesIO(image_bytes), len(image_bytes), content_type='image/jpeg').object_name}"
|
||||
# bucket_name = image['image_url'].split('/', 1)[0]
|
||||
# object_name = image['image_url'].split('/', 1)[1]
|
||||
# oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes)
|
||||
# image['show_image_url'] = f"{bucket_name}/{object_name}"
|
||||
# else:
|
||||
# image_bytes = cv2.imencode(".jpg", image['obj'])[1].tobytes()
|
||||
# # image['show_image_url'] = f"{image['image_url'].split('/', 1)[0]}/{self.minio_client.put_object(image['image_url'].split('/', 1)[0], image['image_url'].split('/', 1)[1].replace('.', '-show.'), io.BytesIO(image_bytes), len(image_bytes), content_type='image/jpeg').object_name}"
|
||||
# bucket_name = image['image_url'].split('/', 1)[0]
|
||||
# object_name = image['image_url'].split('/', 1)[1]
|
||||
# oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes)
|
||||
# image['show_image_url'] = f"{bucket_name}/{object_name}"
|
||||
# else:
|
||||
# image_width = image['obj'].shape[1]
|
||||
# waist_width = image['keypoint_result']['waistband_right'][1] - image['keypoint_result']['waistband_left'][1]
|
||||
# scale = 0.4
|
||||
# if waist_width / scale >= image_width:
|
||||
# add_width = int((waist_width / scale - image_width) / 2)
|
||||
# ret = cv2.copyMakeBorder(image['obj'], 0, 0, add_width, add_width, cv2.BORDER_CONSTANT, value=(256, 256, 256))
|
||||
# if IF_DEBUG_SHOW:
|
||||
# cv2.imshow("composing_image", ret)
|
||||
# cv2.waitKey(0)
|
||||
# image_bytes = cv2.imencode(".jpg", ret)[1].tobytes()
|
||||
# # image['show_image_url'] = f"{image['image_url'].split('/', 1)[0]}/{self.minio_client.put_object(image['image_url'].split('/', 1)[0], image['image_url'].split('/', 1)[1].replace('.', '-show.'), io.BytesIO(image_bytes), len(image_bytes), content_type='image/jpeg').object_name}"
|
||||
# bucket_name = image['image_url'].split('/', 1)[0]
|
||||
# object_name = image['image_url'].split('/', 1)[1]
|
||||
# oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes)
|
||||
# image['show_image_url'] = f"{bucket_name}/{object_name}"
|
||||
# else:
|
||||
# image_bytes = cv2.imencode(".jpg", image['obj'])[1].tobytes()
|
||||
# # image['show_image_url'] = f"{image['image_url'].split('/', 1)[0]}/{self.minio_client.put_object(image['image_url'].split('/', 1)[0], image['image_url'].split('/', 1)[1].replace('.', '-show.'), io.BytesIO(image_bytes), len(image_bytes), content_type='image/jpeg').object_name}"
|
||||
# bucket_name = image['image_url'].split('/', 1)[0]
|
||||
# object_name = image['image_url'].split('/', 1)[1]
|
||||
# oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes)
|
||||
# image['show_image_url'] = f"{bucket_name}/{object_name}"
|
||||
return image_list
|
||||
|
||||
@staticmethod
|
||||
def select_seg_result(image_id, image_obj):
|
||||
try:
|
||||
# 如果shape不匹配 返回false
|
||||
result = np.load(f"seg_result/{image_id}.npy").astype(np.int64)
|
||||
if result.shape[1] == image_obj.shape[0] and result.shape[2] == image_obj.shape[1]:
|
||||
return result
|
||||
else:
|
||||
return False
|
||||
except FileNotFoundError as e:
|
||||
logging.warning(f"{image_id} Image segmentation results cache file does not exist : {e}")
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def search_seg_result(image_id, ori_shape):
|
||||
try:
|
||||
# connections.connect(alias=MILVUS_ALIAS, host=MILVUS_DB_HOST, port=MILVUS_PORT)
|
||||
# collection = Collection(MILVUS_TABLE_SEG) # Get an existing collection.
|
||||
# collection.load()
|
||||
# start_time = time.time()
|
||||
# res = collection.query(
|
||||
# expr=f"seg_id == {image_id}",
|
||||
# offset=0,
|
||||
# limit=10,
|
||||
# output_fields=["seg_cache"],
|
||||
# )
|
||||
# logging.info(f"search seg cache time : {time.time() - start_time}")
|
||||
|
||||
# if len(res):
|
||||
# vector = np.reshape(res[0]['seg_cache'] + res[1]['seg_cache'], (224, 224))
|
||||
# array_2d_exact = F.interpolate(torch.tensor(vector).unsqueeze(0).unsqueeze(0), size=ori_shape, mode='bilinear', align_corners=False)
|
||||
# array_2d_exact = array_2d_exact.squeeze().numpy()
|
||||
# return array_2d_exact
|
||||
# else:
|
||||
return False
|
||||
except Exception as e:
|
||||
logging.warning(f"{image_id} Image segmentation results cache file does not exist : {e}")
|
||||
return False
|
||||
|
||||
def keypoint_cache(self, sketch):
|
||||
try:
|
||||
# connections.connect(alias=MILVUS_ALIAS, host=MILVUS_DB_HOST, port=MILVUS_PORT)
|
||||
# collection = Collection(MILVUS_TABLE_KEYPOINT) # Get an existing collection.
|
||||
# collection.load()
|
||||
start_time = time.time()
|
||||
# res = collection.query(
|
||||
# expr=f"keypoint_id == {sketch['image_id']}",
|
||||
# offset=0,
|
||||
# limit=1,
|
||||
# output_fields=["keypoint_cache", "keypoint_site"],
|
||||
# )
|
||||
res = []
|
||||
logging.info(f"search keypoint time : {time.time() - start_time}")
|
||||
if len(res) == 0:
|
||||
# 没有结果 直接推理拿结果 并保存
|
||||
keypoint_infer_result = self.infer_keypoint_result(sketch)
|
||||
return self.save_keypoint_cache(sketch, keypoint_infer_result)
|
||||
elif res[0]["keypoint_site"] == "all" or res[0]["keypoint_site"] == sketch['site']:
|
||||
# 需要的类型和查询的类型一致,或者查询的类型为all 则直接返回查询的结果
|
||||
return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, np.array(res[0]['keypoint_vector']).astype(int).reshape(12, 2).tolist()))
|
||||
elif res[0]["keypoint_site"] != sketch['site']:
|
||||
# 需要的类型和查询到的不一致,则更新类型为all
|
||||
keypoint_infer_result = self.infer_keypoint_result(sketch)
|
||||
return self.update_keypoint_cache(sketch, keypoint_infer_result, res[0]['keypoint_vector'])
|
||||
except Exception as e:
|
||||
logging.info(f"search keypoint cache milvus error {e}")
|
||||
return False
|
||||
|
||||
# @ RunTime
|
||||
def infer_keypoint_result(self, sketch):
|
||||
keypoint_infer_result = get_keypoint_result(sketch["obj"], sketch['site']) # 推理结果
|
||||
return keypoint_infer_result
|
||||
|
||||
@staticmethod
|
||||
# @ RunTime
|
||||
def save_keypoint_cache(sketch, keypoint_infer_result):
|
||||
if sketch['site'] == "down":
|
||||
zeros = np.zeros(20, dtype=int)
|
||||
result = np.concatenate([zeros, keypoint_infer_result.flatten()])
|
||||
else:
|
||||
zeros = np.zeros(4, dtype=int)
|
||||
result = np.concatenate([keypoint_infer_result.flatten(), zeros])
|
||||
data = [
|
||||
[int(sketch['image_id'])],
|
||||
[sketch['site']],
|
||||
[result.tolist()]
|
||||
]
|
||||
try:
|
||||
# connections.connect(alias=MILVUS_ALIAS, host=MILVUS_DB_HOST, port=MILVUS_PORT)
|
||||
start_time = time.time()
|
||||
# collection = Collection(MILVUS_TABLE_KEYPOINT) # Get an existing collection.
|
||||
# mr = collection.insert(data)
|
||||
# logging.info(f"save keypoint time : {time.time() - start_time}")
|
||||
return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
|
||||
except Exception as e:
|
||||
logging.info(f"save keypoint cache milvus error : {e}")
|
||||
return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
|
||||
|
||||
@staticmethod
|
||||
def update_keypoint_cache(sketch, infer_result, search_result):
|
||||
if sketch['site'] == "up":
|
||||
# 需要的是up 即推理出来的是up 那么查询的就是down
|
||||
result = np.concatenate([infer_result.flatten(), search_result[-4:]])
|
||||
else:
|
||||
# 需要的是down 即推理出来的是down 那么查询的就是up
|
||||
result = np.concatenate([search_result[:20], infer_result.flatten()])
|
||||
data = [
|
||||
[int(sketch['image_id'])],
|
||||
["all"],
|
||||
[result.tolist()]
|
||||
]
|
||||
try:
|
||||
# connections.connect(alias=MILVUS_ALIAS, host=MILVUS_DB_HOST, port=MILVUS_PORT)
|
||||
start_time = time.time()
|
||||
# collection = Collection(MILVUS_TABLE_KEYPOINT) # Get an existing collection.
|
||||
# mr = collection.upsert(data)
|
||||
# logging.info(f"save keypoint time : {time.time() - start_time}")
|
||||
return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
|
||||
except Exception as e:
|
||||
logging.info(f"save keypoint cache milvus error : {e}")
|
||||
return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
data = {
|
||||
"sketches": [
|
||||
{
|
||||
"image_category": "dress",
|
||||
"image_id": "107903",
|
||||
"image_url": "aida-sys-image/images/female/dress/0628000000.jpg"
|
||||
}
|
||||
]
|
||||
}
|
||||
request_data = DesignPreProcessingModel(sketches=data["sketches"])
|
||||
server = DesignPreprocessing()
|
||||
data = server.pipeline(image_list=request_data.sketches)
|
||||
print(data)
|
||||
@@ -10,21 +10,19 @@
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from io import BytesIO
|
||||
|
||||
import cv2
|
||||
import minio
|
||||
import numpy as np
|
||||
import redis
|
||||
import tritonclient.grpc as grpcclient
|
||||
import numpy as np
|
||||
from minio import Minio
|
||||
from tritonclient.utils import np_to_triton_dtype
|
||||
|
||||
from app.core.config import *
|
||||
from app.schemas.generate_image import GenerateImageModel
|
||||
from app.service.generate_image.utils.adjust_contrast import adjust_contrast
|
||||
from app.service.generate_image.utils.image_processing import remove_background, stain_detection, generate_category_recognition, autoLevels, luminance_adjust, face_detect_pic
|
||||
from app.service.generate_image.utils.upload_sd_image import upload_png_sd, upload_stain_png_sd
|
||||
from app.service.generate_image.utils.image_processing import remove_background, stain_detection, generate_category_recognition, autoLevels, luminance_adjust
|
||||
from app.service.generate_image.utils.upload_sd_image import upload_png_sd
|
||||
from app.service.utils.oss_client import oss_get_image
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
@@ -36,22 +34,23 @@ class GenerateImage:
|
||||
self.channel = self.connection.channel()
|
||||
# self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
|
||||
# self.channel = self.connection.channel()
|
||||
self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
||||
# self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
||||
self.grpc_client = grpcclient.InferenceServerClient(url=GI_MODEL_URL)
|
||||
self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
|
||||
if request_data.mode == "img2img":
|
||||
# cv2 读图片是BGR PIL读图片是RGB
|
||||
self.image = self.get_image(request_data.image_url)
|
||||
self.prompt = request_data.prompt
|
||||
else:
|
||||
self.image = np.random.randint(0, 256, (1024, 1024, 3), dtype=np.uint8)
|
||||
self.prompt = request_data.prompt
|
||||
|
||||
self.prompt = request_data.prompt
|
||||
self.tasks_id = request_data.tasks_id
|
||||
self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:]
|
||||
self.mode = request_data.mode
|
||||
self.batch_size = 1
|
||||
self.category = request_data.category
|
||||
if self.category == "sketch":
|
||||
self.prompt = f"{self.category},{self.prompt}"
|
||||
self.index = 0
|
||||
self.gender = request_data.gender
|
||||
self.generate_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'image_url': '', 'category': ''}
|
||||
@@ -63,10 +62,13 @@ class GenerateImage:
|
||||
# Read data from response.
|
||||
# read image use cv2
|
||||
try:
|
||||
response = self.minio_client.get_object(image_url.split('/')[0], image_url[image_url.find('/') + 1:])
|
||||
image_file = BytesIO(response.data)
|
||||
image_array = np.asarray(bytearray(image_file.read()), dtype=np.uint8)
|
||||
image_cv2 = cv2.imdecode(image_array, cv2.IMREAD_COLOR)
|
||||
# response = self.minio_client.get_object(image_url.split('/')[0], image_url[image_url.find('/') + 1:])
|
||||
# image_file = BytesIO(response.data)
|
||||
# image_array = np.asarray(bytearray(image_file.read()), dtype=np.uint8)
|
||||
# image_cv2 = cv2.imdecode(image_array, cv2.IMREAD_COLOR)
|
||||
# image_rbg = cv2.cvtColor(image_cv2, cv2.COLOR_BGR2RGB)
|
||||
|
||||
image_cv2 = oss_get_image(bucket=image_url.split('/')[0], object_name=image_url[image_url.find('/') + 1:], data_type="cv2")
|
||||
image_rbg = cv2.cvtColor(image_cv2, cv2.COLOR_BGR2RGB)
|
||||
image = cv2.resize(image_rbg, (1024, 1024))
|
||||
except minio.error.S3Error:
|
||||
@@ -104,7 +106,7 @@ class GenerateImage:
|
||||
image_result = not_smudge_image
|
||||
if is_smudge: # 无污点
|
||||
# image_result = adjust_contrast(image_result)
|
||||
image_url = upload_png_sd(image_result, user_id=self.user_id, category=f"{self.category}", object_name=f"{self.tasks_id}.png")
|
||||
image_url = upload_png_sd(image_result, user_id=self.user_id, category=f"{self.category}", file_name=f"{self.tasks_id}.png")
|
||||
# logger.info(f"upload image SUCCESS : {image_url}")
|
||||
self.generate_data['status'] = "SUCCESS"
|
||||
self.generate_data['message'] = "success"
|
||||
@@ -121,13 +123,6 @@ class GenerateImage:
|
||||
status_data = self.redis_client.get(self.tasks_id)
|
||||
return json.loads(status_data), status_data
|
||||
|
||||
def infer(self, inputs):
|
||||
return self.grpc_client.async_infer(
|
||||
model_name=GI_MODEL_NAME,
|
||||
inputs=inputs,
|
||||
callback=self.callback
|
||||
)
|
||||
|
||||
def get_result(self):
|
||||
try:
|
||||
prompts = [self.prompt] * self.batch_size
|
||||
@@ -147,7 +142,7 @@ class GenerateImage:
|
||||
input_mode.set_data_from_numpy(mode_obj)
|
||||
|
||||
inputs = [input_text, input_image, input_mode]
|
||||
ctx = self.infer(inputs)
|
||||
ctx = self.grpc_client.async_infer(model_name=GI_MODEL_NAME, inputs=inputs, callback=self.callback)
|
||||
time_out = 600
|
||||
generate_data = None
|
||||
while time_out > 0:
|
||||
@@ -187,9 +182,10 @@ if __name__ == '__main__':
|
||||
rd = GenerateImageModel(
|
||||
tasks_id="123-89",
|
||||
prompt='skeleton sitting by the side of a river looking soulful, concert poster, 4k, artistic',
|
||||
image_url="",
|
||||
image_url="aida-collection-element/87/Printboard/842c09cf-7297-42d9-9e6e-9c17d4a13cb5.jpg",
|
||||
mode='txt2img',
|
||||
category="test"
|
||||
category="test",
|
||||
gender="male"
|
||||
)
|
||||
server = GenerateImage(rd)
|
||||
print(server.get_result())
|
||||
187
app/service/generate_image/service_generate_product_image.py
Normal file
187
app/service/generate_image/service_generate_product_image.py
Normal file
@@ -0,0 +1,187 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: UTF-8 -*-
|
||||
"""
|
||||
@Project :trinity_client
|
||||
@File :service_att_recognition.py
|
||||
@Author :周成融
|
||||
@Date :2023/7/26 12:01:05
|
||||
@detail :
|
||||
"""
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import redis
|
||||
import tritonclient.grpc as grpcclient
|
||||
from PIL import Image, ImageOps
|
||||
from tritonclient.utils import np_to_triton_dtype
|
||||
|
||||
from app.core.config import *
|
||||
from app.schemas.generate_image import GenerateProductImageModel
|
||||
from app.service.generate_image.utils.upload_sd_image import upload_SDXL_image
|
||||
from app.service.utils.oss_client import oss_get_image
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
class GenerateProductImage:
|
||||
def __init__(self, request_data):
|
||||
if DEBUG is False:
|
||||
self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
|
||||
self.channel = self.connection.channel()
|
||||
# self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
|
||||
# self.channel = self.connection.channel()
|
||||
# self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
||||
self.grpc_client = grpcclient.InferenceServerClient(url=GPI_MODEL_URL)
|
||||
self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
|
||||
self.category = "product_image"
|
||||
self.image_strength = request_data.image_strength
|
||||
self.batch_size = 1
|
||||
self.product_type = request_data.product_type
|
||||
self.prompt = request_data.prompt
|
||||
self.image, self.image_size = pre_processing_image(request_data.image_url)
|
||||
self.tasks_id = request_data.tasks_id
|
||||
self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:]
|
||||
self.gen_product_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'image_url': ''}
|
||||
self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data))
|
||||
self.redis_client.expire(self.tasks_id, 600)
|
||||
|
||||
def callback(self, result, error):
|
||||
if error:
|
||||
self.gen_product_data['status'] = "FAILURE"
|
||||
self.gen_product_data['message'] = str(error)
|
||||
self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data))
|
||||
else:
|
||||
# pil图像转成numpy数组
|
||||
if self.product_type == "single":
|
||||
image = result.as_numpy("generated_cnet_image")
|
||||
else:
|
||||
image = result.as_numpy("generated_inpaint_image")
|
||||
image_result = Image.fromarray(np.squeeze(image.astype(np.uint8))).resize(self.image_size)
|
||||
image_url = upload_SDXL_image(image_result, user_id=self.user_id, category=f"{self.category}", file_name=f"{self.tasks_id}.png")
|
||||
self.gen_product_data['status'] = "SUCCESS"
|
||||
self.gen_product_data['message'] = "success"
|
||||
self.gen_product_data['image_url'] = str(image_url)
|
||||
self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data))
|
||||
|
||||
def read_tasks_status(self):
|
||||
status_data = self.redis_client.get(self.tasks_id)
|
||||
return json.loads(status_data), status_data
|
||||
|
||||
def get_result(self):
|
||||
try:
|
||||
prompts = [self.prompt] * self.batch_size
|
||||
self.image = cv2.cvtColor(self.image, cv2.COLOR_BGR2RGB)
|
||||
self.image = cv2.resize(self.image, (512, 768))
|
||||
images = [self.image.astype(np.uint8)] * self.batch_size
|
||||
|
||||
if self.product_type == "single":
|
||||
text_obj = np.array(prompts, dtype="object").reshape(-1, 1)
|
||||
image_obj = np.array(images, dtype=np.uint8).reshape((-1, 768, 512, 3))
|
||||
image_strength_obj = np.array(self.image_strength, dtype=np.float32).reshape(-1, 1)
|
||||
else:
|
||||
text_obj = np.array(prompts, dtype="object").reshape(1)
|
||||
image_obj = np.array(images, dtype=np.uint8).reshape((768, 512, 3))
|
||||
image_strength_obj = np.array(self.image_strength, dtype=np.float32).reshape((1))
|
||||
|
||||
# 假设 prompts、images 和 self.image_strength 已经定义
|
||||
|
||||
input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype))
|
||||
input_image = grpcclient.InferInput("input_image", image_obj.shape, "UINT8")
|
||||
input_image_strength = grpcclient.InferInput("image_strength", image_strength_obj.shape, np_to_triton_dtype(image_strength_obj.dtype))
|
||||
|
||||
input_text.set_data_from_numpy(text_obj)
|
||||
input_image.set_data_from_numpy(image_obj)
|
||||
inputs = [input_text, input_image, input_image_strength]
|
||||
input_image_strength.set_data_from_numpy(image_strength_obj)
|
||||
|
||||
if self.product_type == "single":
|
||||
ctx = self.grpc_client.async_infer(model_name=GPI_MODEL_NAME_SINGLE, inputs=inputs, callback=self.callback)
|
||||
else:
|
||||
ctx = self.grpc_client.async_infer(model_name=GPI_MODEL_NAME_OVERALL, inputs=inputs, callback=self.callback)
|
||||
|
||||
time_out = 600
|
||||
while time_out > 0:
|
||||
gen_product_data, _ = self.read_tasks_status()
|
||||
if gen_product_data['status'] in ["REVOKED", "FAILURE"]:
|
||||
ctx.cancel()
|
||||
break
|
||||
elif gen_product_data['status'] == "SUCCESS":
|
||||
break
|
||||
time_out -= 1
|
||||
time.sleep(0.1)
|
||||
gen_product_data, _ = self.read_tasks_status()
|
||||
return gen_product_data
|
||||
except Exception as e:
|
||||
self.gen_product_data['status'] = "FAILURE"
|
||||
self.gen_product_data['message'] = str(e)
|
||||
self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data))
|
||||
raise Exception(str(e))
|
||||
finally:
|
||||
dict_gen_product_data, str_gen_product_data = self.read_tasks_status()
|
||||
if DEBUG is False:
|
||||
self.channel.basic_publish(exchange='', routing_key=GPI_RABBITMQ_QUEUES, body=str_gen_product_data)
|
||||
logger.info(f" [x] Sent to: {GPI_RABBITMQ_QUEUES} data:@@@@ {json.dumps(dict_gen_product_data, indent=4)}")
|
||||
|
||||
|
||||
def infer_cancel(tasks_id):
|
||||
redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
|
||||
data = {'tasks_id': tasks_id, 'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'}
|
||||
gen_product_data = json.dumps(data)
|
||||
redis_client.set(tasks_id, gen_product_data)
|
||||
return data
|
||||
|
||||
|
||||
def pre_processing_image(image_url):
|
||||
image = oss_get_image(bucket=image_url.split('/')[0], object_name=image_url[image_url.find('/') + 1:], data_type="PIL")
|
||||
# 原始图片的尺寸
|
||||
width, height = image.size
|
||||
|
||||
# 计算长宽比为 3:2 的新尺寸
|
||||
desired_ratio = 2 / 3
|
||||
current_ratio = width / height
|
||||
|
||||
if current_ratio > desired_ratio:
|
||||
# 原始图片更宽,需要在上下添加 padding
|
||||
new_width = width
|
||||
new_height = int(width / desired_ratio)
|
||||
else:
|
||||
# 原始图片更高或者长宽比已经为 3:2
|
||||
new_height = height
|
||||
new_width = int(height * desired_ratio)
|
||||
|
||||
# 创建一个新的画布,大小为添加 padding 后的尺寸,并设置为白色背景
|
||||
pad_image = Image.new('RGBA', (new_width, new_height), (0, 0, 0, 0))
|
||||
|
||||
# 将原始图片粘贴到新的画布中心
|
||||
left = (new_width - width) // 2
|
||||
top = (new_height - height) // 2
|
||||
pad_image.paste(image, (left, top))
|
||||
|
||||
# 将画布 resize 成宽度 500,长度 750
|
||||
resized_image = pad_image.resize((500, 750))
|
||||
image_size = (512, 768)
|
||||
|
||||
if resized_image.mode in ('RGBA', 'LA') or (resized_image.mode == 'P' and 'transparency' in resized_image.info):
|
||||
# 创建白色背景
|
||||
background = Image.new("RGB", image_size, (255, 255, 255))
|
||||
# 将图片粘贴到白色背景上
|
||||
background.paste(resized_image, mask=resized_image.split()[3])
|
||||
image = np.array(background)
|
||||
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||
return image, image_size
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
rd = GenerateProductImageModel(
|
||||
tasks_id="123-89",
|
||||
# prompt="",
|
||||
image_strength=0.9,
|
||||
prompt=" the best quality, masterpiece. detailed, high-res, simple background, studio photography, extremely detailed, updo, detailed face, face, close-up, HDR, UHD, 8K realistic, Highly detailed, simple background, Studio lighting",
|
||||
image_url="aida-results/result_00097282-ebb2-11ee-a822-b48351119060.png",
|
||||
product_type="overall"
|
||||
)
|
||||
server = GenerateProductImage(rd)
|
||||
print(server.get_result())
|
||||
159
app/service/generate_image/service_generate_relight_image.py
Normal file
159
app/service/generate_image/service_generate_relight_image.py
Normal file
@@ -0,0 +1,159 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: UTF-8 -*-
|
||||
"""
|
||||
@Project :trinity_client
|
||||
@File :service_att_recognition.py
|
||||
@Author :周成融
|
||||
@Date :2023/7/26 12:01:05
|
||||
@detail :
|
||||
"""
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import redis
|
||||
import tritonclient.grpc as grpcclient
|
||||
from PIL import Image
|
||||
from tritonclient.utils import np_to_triton_dtype
|
||||
|
||||
from app.core.config import *
|
||||
from app.schemas.generate_image import GenerateRelightImageModel
|
||||
from app.service.generate_image.utils.upload_sd_image import upload_SDXL_image
|
||||
from app.service.utils.oss_client import oss_get_image
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
class GenerateRelightImage:
|
||||
def __init__(self, request_data):
|
||||
if DEBUG is False:
|
||||
self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
|
||||
self.channel = self.connection.channel()
|
||||
# self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
||||
self.grpc_client = grpcclient.InferenceServerClient(url=GRI_MODEL_URL)
|
||||
self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
|
||||
self.category = "relight_image"
|
||||
self.batch_size = 1
|
||||
self.prompt = request_data.prompt
|
||||
self.seed = "1"
|
||||
self.product_type = request_data.product_type
|
||||
self.negative_prompt = 'lowres, bad anatomy, bad hands, cropped, worst quality'
|
||||
self.direction = request_data.direction
|
||||
self.image_url = request_data.image_url
|
||||
self.image = oss_get_image(bucket=self.image_url.split('/')[0], object_name=self.image_url[self.image_url.find('/') + 1:], data_type="cv2")
|
||||
self.tasks_id = request_data.tasks_id
|
||||
self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:]
|
||||
self.gen_product_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'image_url': ''}
|
||||
self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data))
|
||||
self.redis_client.expire(self.tasks_id, 600)
|
||||
|
||||
def callback(self, result, error):
|
||||
if error:
|
||||
self.gen_product_data['status'] = "FAILURE"
|
||||
self.gen_product_data['message'] = str(error)
|
||||
self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data))
|
||||
else:
|
||||
# pil图像转成numpy数组
|
||||
if self.product_type == 'single':
|
||||
image = result.as_numpy("generated_relight_image")
|
||||
else:
|
||||
image = result.as_numpy("generated_inpaint_image")
|
||||
|
||||
image_result = Image.fromarray(np.squeeze(image.astype(np.uint8)))
|
||||
|
||||
image_url = upload_SDXL_image(image_result, user_id=self.user_id, category=f"{self.category}", file_name=f"{self.tasks_id}.png")
|
||||
self.gen_product_data['status'] = "SUCCESS"
|
||||
self.gen_product_data['message'] = "success"
|
||||
self.gen_product_data['image_url'] = str(image_url)
|
||||
self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data))
|
||||
|
||||
def read_tasks_status(self):
|
||||
status_data = self.redis_client.get(self.tasks_id)
|
||||
return json.loads(status_data), status_data
|
||||
|
||||
def get_result(self):
|
||||
try:
|
||||
prompts = [self.prompt] * self.batch_size
|
||||
image = cv2.cvtColor(self.image, cv2.COLOR_BGR2RGB)
|
||||
image = cv2.resize(image, (512, 768))
|
||||
images = [image.astype(np.uint8)] * self.batch_size
|
||||
seeds = [self.seed] * self.batch_size
|
||||
nagetive_prompts = [self.negative_prompt] * self.batch_size
|
||||
directions = [self.direction] * self.batch_size
|
||||
|
||||
if self.product_type == 'single':
|
||||
text_obj = np.array(prompts, dtype="object").reshape((-1, 1))
|
||||
image_obj = np.array(images, dtype=np.uint8).reshape((-1, 768, 512, 3))
|
||||
na_text_obj = np.array(nagetive_prompts, dtype="object").reshape((-1, 1))
|
||||
seed_obj = np.array(seeds, dtype="object").reshape((-1, 1))
|
||||
direction_obj = np.array(directions, dtype="object").reshape((-1, 1))
|
||||
else:
|
||||
text_obj = np.array(prompts, dtype="object").reshape((1))
|
||||
image_obj = np.array(images, dtype=np.uint8).reshape((768, 512, 3))
|
||||
na_text_obj = np.array(nagetive_prompts, dtype="object").reshape((1))
|
||||
seed_obj = np.array(seeds, dtype="object").reshape((1))
|
||||
direction_obj = np.array(directions, dtype="object").reshape((1))
|
||||
|
||||
input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype))
|
||||
input_image = grpcclient.InferInput("input_image", image_obj.shape, "UINT8")
|
||||
input_natext = grpcclient.InferInput("negative_prompt", na_text_obj.shape, np_to_triton_dtype(na_text_obj.dtype))
|
||||
input_seed = grpcclient.InferInput("seed", seed_obj.shape, np_to_triton_dtype(seed_obj.dtype))
|
||||
input_direction = grpcclient.InferInput("direction", direction_obj.shape, np_to_triton_dtype(direction_obj.dtype))
|
||||
|
||||
input_text.set_data_from_numpy(text_obj)
|
||||
input_image.set_data_from_numpy(image_obj)
|
||||
input_natext.set_data_from_numpy(na_text_obj)
|
||||
input_seed.set_data_from_numpy(seed_obj)
|
||||
input_direction.set_data_from_numpy(direction_obj)
|
||||
|
||||
inputs = [input_text, input_natext, input_image, input_seed, input_direction]
|
||||
if self.product_type == 'single':
|
||||
ctx = self.grpc_client.async_infer(model_name=GRI_MODEL_NAME_SINGLE, inputs=inputs, callback=self.callback)
|
||||
else:
|
||||
ctx = self.grpc_client.async_infer(model_name=GRI_MODEL_NAME_OVERALL, inputs=inputs, callback=self.callback)
|
||||
|
||||
time_out = 600
|
||||
while time_out > 0:
|
||||
gen_product_data, _ = self.read_tasks_status()
|
||||
if gen_product_data['status'] in ["REVOKED", "FAILURE"]:
|
||||
ctx.cancel()
|
||||
break
|
||||
elif gen_product_data['status'] == "SUCCESS":
|
||||
break
|
||||
time_out -= 1
|
||||
time.sleep(0.1)
|
||||
gen_product_data, _ = self.read_tasks_status()
|
||||
return gen_product_data
|
||||
except Exception as e:
|
||||
self.gen_product_data['status'] = "FAILURE"
|
||||
self.gen_product_data['message'] = str(e)
|
||||
self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data))
|
||||
raise Exception(str(e))
|
||||
finally:
|
||||
dict_gen_product_data, str_gen_product_data = self.read_tasks_status()
|
||||
if DEBUG is False:
|
||||
self.channel.basic_publish(exchange='', routing_key=GRI_RABBITMQ_QUEUES, body=str_gen_product_data)
|
||||
logger.info(f" [x] Sent to: {GRI_RABBITMQ_QUEUES} data:@@@@ {json.dumps(dict_gen_product_data, indent=4)}")
|
||||
|
||||
|
||||
def infer_cancel(tasks_id):
|
||||
redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
|
||||
data = {'tasks_id': tasks_id, 'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'}
|
||||
gen_product_data = json.dumps(data)
|
||||
redis_client.set(tasks_id, gen_product_data)
|
||||
return data
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
rd = GenerateRelightImageModel(
|
||||
tasks_id="123-89",
|
||||
# prompt="beautiful woman, detailed face, sunshine, outdoor, warm atmosphere",
|
||||
prompt="Colorful black",
|
||||
image_url='aida-results/result_0000b606-1902-11ef-9424-0242ac180002.png',
|
||||
direction="Right Light",
|
||||
product_type="single"
|
||||
)
|
||||
server = GenerateRelightImage(rd)
|
||||
print(server.get_result())
|
||||
119
app/service/generate_image/service_generate_single_logo.py
Normal file
119
app/service/generate_image/service_generate_single_logo.py
Normal file
@@ -0,0 +1,119 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: UTF-8 -*-
|
||||
"""
|
||||
@Project :trinity_client
|
||||
@File :service_att_recognition.py
|
||||
@Author :周成融
|
||||
@Date :2023/7/26 12:01:05
|
||||
@detail :
|
||||
"""
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import redis
|
||||
from PIL import Image
|
||||
from minio import Minio
|
||||
from tritonclient.utils import np_to_triton_dtype
|
||||
|
||||
from app.core.config import *
|
||||
import tritonclient.grpc as grpcclient
|
||||
from app.schemas.generate_image import GenerateSingleLogoImageModel
|
||||
from app.service.generate_image.utils.upload_sd_image import upload_png_sd, upload_SDXL_image
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
class GenerateSingleLogoImage:
|
||||
def __init__(self, request_data):
|
||||
if DEBUG is False:
|
||||
self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
|
||||
self.channel = self.connection.channel()
|
||||
self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
||||
self.grpc_client = grpcclient.InferenceServerClient(url=GSL_MODEL_URL)
|
||||
self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
|
||||
self.batch_size = 1
|
||||
self.category = "single_logo"
|
||||
self.negative_prompts = "bad, ugly"
|
||||
self.seed = request_data.seed
|
||||
self.tasks_id = request_data.tasks_id
|
||||
self.prompt = request_data.prompt
|
||||
self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:]
|
||||
self.gen_single_logo_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'image_url': ''}
|
||||
self.redis_client.set(self.tasks_id, json.dumps(self.gen_single_logo_data))
|
||||
self.redis_client.expire(self.tasks_id, 600)
|
||||
|
||||
def read_tasks_status(self):
|
||||
status_data = self.redis_client.get(self.tasks_id)
|
||||
return json.loads(status_data), status_data
|
||||
|
||||
def callback(self, result, error):
|
||||
if error:
|
||||
self.gen_single_logo_data['status'] = "FAILURE"
|
||||
self.gen_single_logo_data['message'] = str(error)
|
||||
self.redis_client.set(self.tasks_id, json.dumps(self.gen_single_logo_data))
|
||||
else:
|
||||
image = result.as_numpy("generated_image")
|
||||
image_result = Image.fromarray(np.squeeze(image.astype(np.uint8)))
|
||||
image_url = upload_SDXL_image(image_result, user_id=self.user_id, category=f"{self.category}", file_name=f"{self.tasks_id}.png")
|
||||
self.gen_single_logo_data['status'] = "SUCCESS"
|
||||
self.gen_single_logo_data['message'] = "success"
|
||||
self.gen_single_logo_data['image_url'] = str(image_url)
|
||||
self.redis_client.set(self.tasks_id, json.dumps(self.gen_single_logo_data))
|
||||
|
||||
def get_result(self):
|
||||
try:
|
||||
# prompt
|
||||
prompts = [self.prompt] * self.batch_size
|
||||
text_obj = np.array(prompts, dtype="object").reshape((-1, 1))
|
||||
input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype))
|
||||
input_text.set_data_from_numpy(text_obj)
|
||||
|
||||
text_obj_neg = np.array(self.negative_prompts, dtype="object").reshape((-1, 1))
|
||||
input_text_neg = grpcclient.InferInput("negative_prompt", text_obj_neg.shape, np_to_triton_dtype(text_obj_neg.dtype))
|
||||
input_text_neg.set_data_from_numpy(text_obj_neg)
|
||||
|
||||
seed = np.array(self.seed, dtype="object").reshape((-1, 1))
|
||||
input_seed = grpcclient.InferInput("seed", seed.shape, np_to_triton_dtype(seed.dtype))
|
||||
input_seed.set_data_from_numpy(seed)
|
||||
inputs = [input_text, input_text_neg, input_seed]
|
||||
ctx = self.grpc_client.async_infer(model_name=GSL_MODEL_NAME, inputs=inputs, callback=self.callback)
|
||||
time_out = 600
|
||||
generate_data = None
|
||||
while time_out > 0:
|
||||
generate_data, _ = self.read_tasks_status()
|
||||
if generate_data['status'] in ["REVOKED", "FAILURE"]:
|
||||
ctx.cancel()
|
||||
break
|
||||
elif generate_data['status'] == "SUCCESS":
|
||||
break
|
||||
time_out -= 1
|
||||
time.sleep(0.1)
|
||||
return generate_data
|
||||
except Exception as e:
|
||||
raise Exception(str(e))
|
||||
finally:
|
||||
dict_generate_data, str_generate_data = self.read_tasks_status()
|
||||
if DEBUG is False:
|
||||
self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=str_generate_data)
|
||||
logger.info(f" [x] Sent {json.dumps(dict_generate_data, indent=4)}")
|
||||
|
||||
|
||||
def infer_cancel(tasks_id):
|
||||
redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
|
||||
data = {'tasks_id': tasks_id, 'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'}
|
||||
generate_data = json.dumps(data)
|
||||
redis_client.set(tasks_id, generate_data)
|
||||
return data
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
rd = GenerateSingleLogoImageModel(
|
||||
tasks_id="123-89",
|
||||
prompt='an apple',
|
||||
seed="2",
|
||||
)
|
||||
server = GenerateSingleLogoImage(rd)
|
||||
print(server.get_result())
|
||||
@@ -381,7 +381,7 @@ if __name__ == '__main__':
|
||||
remove_bg_img = remove_background(luminance)
|
||||
# cv2.imwrite("remove_bg_img.png", remove_bg_img)
|
||||
|
||||
print(1)
|
||||
# print(1)
|
||||
cv2.imshow("source", img)
|
||||
cv2.imshow("levels", equAuto)
|
||||
cv2.imshow("luminance", luminance)
|
||||
|
||||
@@ -10,26 +10,61 @@
|
||||
import io
|
||||
import logging
|
||||
|
||||
# import boto3
|
||||
import cv2
|
||||
from PIL import Image
|
||||
from minio import Minio
|
||||
|
||||
from app.core.config import *
|
||||
from app.service.utils.oss_client import oss_upload_image
|
||||
|
||||
minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
||||
|
||||
|
||||
def upload_png_sd(image, user_id, category, object_name):
|
||||
# s3 = boto3.client('s3', aws_access_key_id=S3_ACCESS_KEY, aws_secret_access_key=S3_AWS_SECRET_ACCESS_KEY, region_name=S3_REGION_NAME)
|
||||
|
||||
|
||||
# def upload_single_logo(image, user_id, category, object_name):
|
||||
# with io.BytesIO() as output:
|
||||
# image.save(output, format='PNG')
|
||||
# data = output.getvalue()
|
||||
# # 创建一个 S3 客户端
|
||||
# try:
|
||||
# key = f'{user_id}/{category}/{object_name}'
|
||||
# image_url = f"{AIDA_CLOTHING}/{key}"
|
||||
# s3.put_object(Bucket=GSL_MINIO_BUCKET, Key=key, Body=data, ContentType='image/png')
|
||||
# return image_url
|
||||
# except Exception as e:
|
||||
# print(f'上传到 S3 失败: {e}')
|
||||
|
||||
def upload_SDXL_image(image, user_id, category, file_name):
|
||||
try:
|
||||
image_data = io.BytesIO()
|
||||
image.save(image_data, format='PNG')
|
||||
image_data.seek(0)
|
||||
image_bytes = image_data.read()
|
||||
|
||||
# minio_req = minio_client.put_object(
|
||||
# GI_MINIO_BUCKET,
|
||||
# f'{user_id}/{category}/{file_name}',
|
||||
# io.BytesIO(image_bytes),
|
||||
# len(image_bytes),
|
||||
# content_type='image/jpeg'
|
||||
# )
|
||||
object_name = f'{user_id}/{category}/{file_name}'
|
||||
req = oss_upload_image(bucket=GI_MINIO_BUCKET, object_name=object_name, image_bytes=image_bytes)
|
||||
image_url = f"aida-users/{object_name}"
|
||||
return image_url
|
||||
except Exception as e:
|
||||
logging.warning(f"upload_png_mask runtime exception : {e}")
|
||||
|
||||
|
||||
def upload_png_sd(image, user_id, category, file_name):
|
||||
try:
|
||||
_, img_byte_array = cv2.imencode('.jpg', image)
|
||||
minio_req = minio_client.put_object(
|
||||
GI_MINIO_BUCKET,
|
||||
f'{user_id}/{category}/{object_name}',
|
||||
io.BytesIO(img_byte_array),
|
||||
len(img_byte_array),
|
||||
content_type='image/jpeg'
|
||||
)
|
||||
image_url = f"aida-users/{minio_req.object_name}"
|
||||
object_name = f'{user_id}/{category}/{file_name}'
|
||||
req = oss_upload_image(bucket=GI_MINIO_BUCKET, object_name=object_name, image_bytes=img_byte_array)
|
||||
image_url = f"aida-users/{object_name}"
|
||||
return image_url
|
||||
except Exception as e:
|
||||
logging.warning(f"upload_png_mask runtime exception : {e}")
|
||||
|
||||
77
app/service/prompt_generation/chatgpt_for_translation.py
Normal file
77
app/service/prompt_generation/chatgpt_for_translation.py
Normal file
@@ -0,0 +1,77 @@
|
||||
import logging
|
||||
|
||||
from langchain.chains import LLMChain
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
from langchain_core.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
|
||||
|
||||
from app.core.config import OPENAI_MODEL, OPENAI_API_KEY
|
||||
|
||||
# os.environ["http_proxy"] = "http://127.0.0.1:7890"
|
||||
# os.environ["https_proxy"] = "http://127.0.0.1:7890"
|
||||
|
||||
|
||||
llm = ChatOpenAI(model_name=OPENAI_MODEL,
|
||||
openai_api_key=OPENAI_API_KEY,
|
||||
temperature=0)
|
||||
|
||||
|
||||
def translate_to_en(text):
|
||||
template = (
|
||||
"""You are a translation expert, proficient in various languages.
|
||||
And can translate various languages into English.
|
||||
Please translate to grammatically correct English regardless of the input language.
|
||||
If the input is already in English, or consists of letters or numbers such as "cat", "abc", or "1",
|
||||
output the input text exactly as it is without any modifications or additions.
|
||||
If there are grammatical errors, correct them and then output the sentence."""
|
||||
)
|
||||
system_message_prompt = SystemMessagePromptTemplate.from_template(template)
|
||||
|
||||
# 待翻译文本由 Human 角色输入
|
||||
human_template = "User input : {text}"
|
||||
human_message_prompt = HumanMessagePromptTemplate.from_template(input_variables=["text"], template=human_template)
|
||||
|
||||
# 使用 System 和 Human 角色的提示模板构造 ChatPromptTemplate
|
||||
chat_prompt_template = ChatPromptTemplate.from_messages(
|
||||
[system_message_prompt, human_message_prompt]
|
||||
)
|
||||
translate_chain = LLMChain(llm=llm, prompt=chat_prompt_template)
|
||||
|
||||
result = translate_chain.invoke(text)
|
||||
|
||||
logging.info("translate result : " + result.get('text'))
|
||||
# print("translate result : " + result.get('text'))
|
||||
return result.get('text')
|
||||
|
||||
# template = (
|
||||
# """
|
||||
# Input sentence:
|
||||
# {translate}
|
||||
# 1. Based on the input,adjust the input sentence to make it more suitable for prompts for generating images,
|
||||
# ensuring all key nouns or adjectives related to the image are retained.
|
||||
# 2. Simplify complex sentence structures and clarify ambiguous expressions.
|
||||
# 3. Only Output the adjusted English sentence.
|
||||
#
|
||||
# Output :
|
||||
# """
|
||||
# )
|
||||
# # "Based on the input sentence, extract key adjectives and nouns.Only Output extracted key words."
|
||||
# # 1. Check if the input sentence contains any grammatical errors. If there are errors, please correct them before proceeding.
|
||||
#
|
||||
# prompt_template = PromptTemplate(input_variables=["translate"], template=template)
|
||||
# prompt_chain = LLMChain(llm=llm, prompt=prompt_template)
|
||||
#
|
||||
# from langchain.chains import SimpleSequentialChain
|
||||
# overall_chain = SimpleSequentialChain(chains=[translate_chain, prompt_chain], verbose=True)
|
||||
#
|
||||
# response = overall_chain.run(text)
|
||||
# return response
|
||||
|
||||
|
||||
def main():
|
||||
"""Main function"""
|
||||
text = translate_to_en("fire")
|
||||
print(text)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,17 +1,17 @@
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
import minio.error
|
||||
import redis
|
||||
import json
|
||||
|
||||
import cv2
|
||||
import minio.error
|
||||
import numpy as np
|
||||
import redis
|
||||
import torch
|
||||
import tritonclient.grpc as grpcclient
|
||||
from minio import Minio
|
||||
|
||||
from app.core.config import *
|
||||
from app.schemas.super_resolution import SuperResolutionModel
|
||||
from app.service.utils.decorator import RunTime
|
||||
from app.service.utils.oss_client import oss_get_image, oss_upload_image
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
@@ -24,7 +24,7 @@ class SuperResolution:
|
||||
self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:]
|
||||
self.sr_image_url = data.sr_image_url
|
||||
self.sr_xn = data.sr_xn
|
||||
self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
||||
# self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
||||
self.redis_client.set(self.tasks_id, json.dumps({'status': 'PENDING', 'message': "pending", 'data': ''}))
|
||||
self.redis_client.expire(self.tasks_id, 600)
|
||||
self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
|
||||
@@ -33,36 +33,37 @@ class SuperResolution:
|
||||
# @RunTime
|
||||
def read_image(self):
|
||||
try:
|
||||
image_data = self.minio_client.get_object(self.sr_image_url.split("/", 1)[0], self.sr_image_url.split("/", 1)[1])
|
||||
img = oss_get_image(bucket=self.sr_image_url.split("/", 1)[0], object_name=self.sr_image_url.split("/", 1)[1], data_type="cv2")
|
||||
img = img.astype(np.float32) / 255. # 解码
|
||||
except minio.error.S3Error as e:
|
||||
sr_data = json.dumps({'tasks_id': self.tasks_id, 'status': 'ERROR', 'message': f'{e}'})
|
||||
self.channel.basic_publish(exchange='', routing_key=SR_RABBITMQ_QUEUES, body=sr_data)
|
||||
logger.info(f" [x] Sent {sr_data}")
|
||||
raise FileNotFoundError(f"Image '{self.sr_image_url.split('/', 1)[1]}' not found in bucket '{self.sr_image_url.split('/', 1)[0]}'")
|
||||
img = np.frombuffer(image_data.data, np.uint8) # 转成8位无符号整型
|
||||
img = cv2.imdecode(img, cv2.IMREAD_COLOR).astype(np.float32) / 255. # 解码
|
||||
return img
|
||||
|
||||
# try:
|
||||
# image_data = self.minio_client.get_object(self.sr_image_url.split("/", 1)[0], self.sr_image_url.split("/", 1)[1])
|
||||
# except minio.error.S3Error as e:
|
||||
# sr_data = json.dumps({'tasks_id': self.tasks_id, 'status': 'ERROR', 'message': f'{e}'})
|
||||
# self.channel.basic_publish(exchange='', routing_key=SR_RABBITMQ_QUEUES, body=sr_data)
|
||||
# logger.info(f" [x] Sent {sr_data}")
|
||||
# raise FileNotFoundError(f"Image '{self.sr_image_url.split('/', 1)[1]}' not found in bucket '{self.sr_image_url.split('/', 1)[0]}'")
|
||||
# img = np.frombuffer(image_data.data, np.uint8) # 转成8位无符号整型
|
||||
# img = cv2.imdecode(img, cv2.IMREAD_COLOR).astype(np.float32) / 255. # 解码
|
||||
# return img
|
||||
|
||||
def read_tasks_status(self):
|
||||
status_data = json.loads(self.redis_client.get(self.tasks_id))
|
||||
logging.info(f"{self.tasks_id} ===> {status_data}")
|
||||
return status_data
|
||||
|
||||
# @RunTime
|
||||
def infer(self, inputs):
|
||||
return self.triton_client.async_infer(
|
||||
model_name=SR_MODEL_NAME,
|
||||
inputs=inputs,
|
||||
callback=self.callback
|
||||
)
|
||||
|
||||
# @RunTime
|
||||
def sr_result(self):
|
||||
sample = self.read_image()
|
||||
if self.sr_xn == 2:
|
||||
new_shape = (sample.shape[0] // self.sr_xn, sample.shape[1] // self.sr_xn)
|
||||
sample = cv2.resize(sample, new_shape)
|
||||
print(new_shape)
|
||||
sample = np.transpose(sample if sample.shape[2] == 1 else sample[:, :, [2, 1, 0]], (2, 0, 1))
|
||||
sample = torch.from_numpy(sample).float().unsqueeze(0).numpy()
|
||||
inputs = [
|
||||
@@ -72,13 +73,16 @@ class SuperResolution:
|
||||
# , binary_data=True
|
||||
)
|
||||
|
||||
ctx = self.infer(inputs)
|
||||
ctx = self.triton_client.async_infer(
|
||||
model_name=SR_MODEL_NAME,
|
||||
inputs=inputs,
|
||||
callback=self.callback
|
||||
)
|
||||
time_out = 60
|
||||
while time_out > 0:
|
||||
generate_data = self.read_tasks_status()
|
||||
if generate_data['status'] in ["REVOKED", "FAILURE"]:
|
||||
ctx.cancel()
|
||||
# noinspection PyTypeChecker
|
||||
self.channel.basic_publish(exchange='', routing_key=SR_RABBITMQ_QUEUES, body=json.dumps(generate_data))
|
||||
logger.info(f" [x] Sent {generate_data}")
|
||||
break
|
||||
@@ -88,28 +92,19 @@ class SuperResolution:
|
||||
time.sleep(1)
|
||||
return self.read_tasks_status()
|
||||
|
||||
# results = self.triton_client.infer(model_name=SR_MODEL_NAME, inputs=inputs)
|
||||
|
||||
# sr_output = torch.from_numpy(results.as_numpy(f"output"))
|
||||
# output = sr_output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
|
||||
# if output.ndim == 3:
|
||||
# output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0)) # CHW-RGB to HCW-BGR
|
||||
# output = (output * 255.0).round().astype(np.uint8)
|
||||
# output_url = self.upload_img_sr(output, generate_uuid())
|
||||
# return output_url
|
||||
|
||||
def upload_img_sr(self, image):
|
||||
try:
|
||||
image_bytes = cv2.imencode('.jpg', image)[1].tobytes()
|
||||
res = self.minio_client.put_object(f'{SR_MINIO_BUCKET}', f'{self.user_id}/sr/output/{self.tasks_id}.jpg', io.BytesIO(image_bytes), len(image_bytes), content_type='image/png')
|
||||
image_url = f"aida-users/{res.object_name}"
|
||||
# res = self.minio_client.put_object(f'{SR_MINIO_BUCKET}', f'{self.user_id}/sr/output/{self.tasks_id}.jpg', io.BytesIO(image_bytes), len(image_bytes), content_type='image/png')
|
||||
object_name = f'{self.user_id}/sr/output/{self.tasks_id}.jpg'
|
||||
oss_upload_image(bucket=SR_MINIO_BUCKET, object_name=object_name, image_bytes=image_bytes)
|
||||
image_url = f"{SR_MINIO_BUCKET}/{object_name}"
|
||||
return image_url
|
||||
except Exception as e:
|
||||
logger.warning(f"upload_png_mask runtime exception : {e}")
|
||||
|
||||
def callback(self, result, error):
|
||||
if error:
|
||||
print(error)
|
||||
sr_info_data = json.dumps({'status': 'FAILURE', 'message': f"{error}", 'data': f"{error}"})
|
||||
self.redis_client.set(self.tasks_id, sr_info_data)
|
||||
else:
|
||||
@@ -135,6 +130,6 @@ def infer_cancel(tasks_id):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
request_data = SuperResolutionModel(sr_image_url="test/512_image/15.png", sr_xn=2, sr_tasks_id="123")
|
||||
request_data = SuperResolutionModel(sr_image_url="aida-users/83/print/b77bf4ca-6ca2-44a1-9040-505f359a974c-3-83.png", sr_xn=2, sr_tasks_id="12341556")
|
||||
service = SuperResolution(request_data)
|
||||
result_url = service.sr_result()
|
||||
|
||||
74
app/service/utils/oss_client.py
Normal file
74
app/service/utils/oss_client.py
Normal file
@@ -0,0 +1,74 @@
|
||||
import io
|
||||
import logging
|
||||
from io import BytesIO
|
||||
|
||||
import boto3
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from minio import Minio
|
||||
|
||||
from app.core.config import *
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
# 获取图片
|
||||
def oss_get_image(bucket, object_name, data_type):
|
||||
# cv2 默认全通道读取
|
||||
image_object = None
|
||||
try:
|
||||
if OSS == "minio":
|
||||
oss_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
||||
image_data = oss_client.get_object(bucket_name=bucket, object_name=object_name)
|
||||
else:
|
||||
oss_client = boto3.client('s3', aws_access_key_id=S3_ACCESS_KEY, aws_secret_access_key=S3_AWS_SECRET_ACCESS_KEY, region_name=S3_REGION_NAME)
|
||||
image_data = oss_client.get_object(Bucket=bucket, Key=object_name)['Body']
|
||||
if data_type == "cv2":
|
||||
image_bytes = image_data.read()
|
||||
image_array = np.frombuffer(image_bytes, np.uint8) # 转成8位无符号整型
|
||||
image_object = cv2.imdecode(image_array, cv2.IMREAD_UNCHANGED)
|
||||
else:
|
||||
data_bytes = BytesIO(image_data.read())
|
||||
image_object = Image.open(data_bytes)
|
||||
except Exception as e:
|
||||
logger.warning(f"{OSS} | 获取图片出现异常 ######: {e}")
|
||||
return image_object
|
||||
|
||||
|
||||
def oss_upload_image(bucket, object_name, image_bytes):
|
||||
req = None
|
||||
try:
|
||||
if OSS == "minio":
|
||||
oss_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
||||
req = oss_client.put_object(bucket_name=bucket, object_name=object_name, data=io.BytesIO(image_bytes), length=len(image_bytes), content_type='image/png')
|
||||
else:
|
||||
oss_client = boto3.client('s3', aws_access_key_id=S3_ACCESS_KEY, aws_secret_access_key=S3_AWS_SECRET_ACCESS_KEY, region_name=S3_REGION_NAME)
|
||||
req = oss_client.put_object(Bucket=bucket, Key=object_name, Body=io.BytesIO(image_bytes), ContentType='image/png')
|
||||
except Exception as e:
|
||||
logger.warning(f"{OSS} | 上传图片出现异常 ######: {e}")
|
||||
return req
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# url = "aida-results/result_0002186a-e631-11ee-86a6-b48351119060.png"
|
||||
# url = "aida-collection-element/11523/Moodboard/f60af0d2-94c2-48f9-90ff-74b8e8a481b5.jpg"
|
||||
# url = "aida-sys-image/images/female/outwear/0628000054.jpg"
|
||||
# url = "aida-users/89/product_image/string-89.png"
|
||||
# url = "test/845046c7-4f62-4f54-a4a9-c26d49c6969335b5b3a9-d335-4871-a46c-3cc3caf07da259629dfd1f1f555a2e2a9def7e719366.png"
|
||||
# url = 'aida-users/89/relight_image/123-89.png'
|
||||
# url = 'aida-users/89/relight_image/123-89.png'
|
||||
# url = 'aida-users/89/relight_image/123-89.png'
|
||||
# url = "aida-users/89/sketchboard/female/Dress/e6724ab7-8d3f-4677-abe0-c3e42ab7af85.jpeg"
|
||||
# url = "aida-users/87/print/956614a2-7e75-4fbe-9ed0-c1831e37a2c9-4-87.png"
|
||||
# url = "aida-users/89/single_logo/123-89.png"
|
||||
# url = "aida-users/89/product_image/string-89.png"
|
||||
url = "aida-results/result_c6520ce7-33a1-11ef-a8d3-b0dcefbff887.png"
|
||||
read_type = "PIL"
|
||||
if read_type == "cv2":
|
||||
img = oss_get_image(bucket=url.split('/')[0], object_name=url[url.find('/') + 1:], data_type=read_type)
|
||||
cv2.imshow("", img)
|
||||
cv2.waitKey(0)
|
||||
else:
|
||||
img = oss_get_image(bucket=url.split('/')[0], object_name=url[url.find('/') + 1:], data_type=read_type)
|
||||
img.show()
|
||||
Reference in New Issue
Block a user