Merge branch 'local' into develop

This commit is contained in:
zhouchengrong
2024-06-17 10:46:09 +08:00
72 changed files with 5667 additions and 45 deletions

View File

@@ -1,8 +1,9 @@
import json
import logging
from fastapi import APIRouter
from fastapi import APIRouter, HTTPException
from app.schemas.attribute_retrieve import *
from app.schemas.response_template import ResponseModel
from app.service.attribute.config import const
from app.service.attribute.service_att_recognition import AttributeRecognition
from app.service.attribute.service_category_recognition import CategoryRecognition
@@ -12,34 +13,28 @@ logger = logging.getLogger()
# 属性识别
@router.post("/attribute_recognition")
@router.post("/attribute_recognition", response_model=ResponseModel)
def attribute_recognition(request_item: list[AttributeRecognitionModel]):
try:
logger.info(f"attribute_recognition request item is : @@@@@@:{request_item}")
service = AttributeRecognition(const=const, request_data=request_item)
data = service.get_result()
code = 200
message = "access"
logger.info(f"attribute_recognition response @@@@@@:{json.dumps(data, indent=4)}")
except Exception as e:
code = 400
message = e
data = e
logger.warning(f"attribute_recognition Run Exception @@@@@@:{e}")
return {"code": code, "message": message, "data": data}
raise HTTPException(status_code=404, detail=str(e))
return ResponseModel(data=data)
# 类别识别
@router.post("/category_recognition")
def category_recognition(request_item: list[CategoryRecognitionModel]):
try:
logger.info(f"category_recognition request item is : @@@@@@:{request_item}")
service = CategoryRecognition(request_data=request_item)
data = service.get_result()
code = 200
message = "access"
logger.info(f"category_recognition response @@@@@@:{json.dumps(data, indent=4)}")
except Exception as e:
code = 400
message = e
data = e
logger.warning(f"category_recognition Run Exception @@@@@@:{e}")
return {"code": code, "message": message, "data": data}
raise HTTPException(status_code=404, detail=str(e))
return ResponseModel(data=data)

23
app/api/api_chat_robot.py Normal file
View File

@@ -0,0 +1,23 @@
import json
import logging
import time
from fastapi import APIRouter, HTTPException
from app.schemas.chat_robot import ChatRobotModel
from app.schemas.response_template import ResponseModel
from app.service.chat_robot.script.main import chat
router = APIRouter()
logger = logging.getLogger()
@router.post("/chat_robot")
def chat_robot(request_data: ChatRobotModel):
try:
logger.info(f"chat_robot request item is : @@@@@@:{request_data}")
data = chat(post_data=request_data)
logger.info(f"chat_robot response @@@@@@:{json.dumps(data, indent=4)}")
except Exception as e:
logger.warning(f"chat_robot Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))
return ResponseModel(data=data)

24
app/api/api_design.py Normal file
View File

@@ -0,0 +1,24 @@
import json
import logging
import time
from fastapi import APIRouter, HTTPException
from app.schemas.design import DesignModel
from app.schemas.response_template import ResponseModel
from app.service.design.service import generate
router = APIRouter()
logger = logging.getLogger()
@router.post("/design")
def design(request_data: DesignModel):
try:
logger.info(f"design request item is : @@@@@@:{request_data}")
data = generate(request_data=request_data)
logger.info(f"design response @@@@@@:{json.dumps(data, indent=4)}")
except Exception as e:
logger.warning(f"design Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))
return ResponseModel(data=data)

View File

@@ -0,0 +1,22 @@
import json
import logging
from fastapi import APIRouter, HTTPException
from app.schemas.pre_processing import DesignPreProcessingModel
from app.schemas.response_template import ResponseModel
from app.service.design_pre_processing.service import DesignPreprocessing
router = APIRouter()
logger = logging.getLogger()
@router.post("/design_pre_processing")
def design_pre_processing(request_data: DesignPreProcessingModel):
try:
logger.info(f"design_pre_processing request item is : @@@@@@:{request_data}")
server = DesignPreprocessing()
data = server.pipeline(image_list=request_data.sketches)
logger.info(f"design response @@@@@@:{json.dumps(data, indent=4)}")
except Exception as e:
logger.warning(f"design Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))
return ResponseModel(data=data)

View File

@@ -1,28 +1,119 @@
import json
import logging
from fastapi import APIRouter, BackgroundTasks
from app.schemas.generate_image import GenerateImageModel
from app.service.generate_image.service import GenerateImage, infer_cancel
from fastapi import APIRouter, BackgroundTasks, HTTPException
from app.schemas.generate_image import GenerateImageModel, GenerateProductImageModel, GenerateSingleLogoImageModel
from app.schemas.response_template import ResponseModel
from app.service.generate_image.service_generate_image import GenerateImage, infer_cancel as generate_image_infer_cancel
from app.service.generate_image.service_generate_product_image import GenerateProductImage, infer_cancel as generate_product_image_cancel
from app.service.generate_image.service_generate_relight_image import GenerateRelightImage, infer_cancel as generate_relight_image_cancel
from app.service.generate_image.service_generate_single_logo import GenerateSingleLogoImage, infer_cancel as generate_single_logo_cancel
router = APIRouter()
logger = logging.getLogger()
'''generate image'''
@router.post("/generate_image")
def generate_image(request_item: GenerateImageModel, background_tasks: BackgroundTasks):
try:
logger.info(f"request data ### : {request_item}")
logger.info(f"generate_image request item is : @@@@@@:{request_item}")
service = GenerateImage(request_item)
background_tasks.add_task(service.get_result)
code = 200
message = "access"
except Exception as e:
code = 400
message = e
logger.warning(e)
return {"code": code, "message": message}
logger.warning(f"generate_image Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))
return ResponseModel()
@router.get("/generate_cancel/{tasks_id}>")
def generate_image(tasks_id):
result = infer_cancel(tasks_id)
return {"code": 200, "message": result['message'], "data": result['data']}
def generate_image(tasks_id: str):
try:
logger.info(f"generate_cancel request item is : @@@@@@:{tasks_id}")
data = generate_image_infer_cancel(tasks_id)
logger.info(f"generate_cancel response @@@@@@:{json.dumps(data, indent=4)}")
except Exception as e:
logger.warning(f"generate_cancel Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))
return ResponseModel(data=data['data'])
'''single logo'''
@router.post("/generate_single_logo")
def generate_single_logo(request_item: GenerateSingleLogoImageModel, background_tasks: BackgroundTasks):
try:
logger.info(f"generate_single_logo request item is : @@@@@@:{request_item}")
service = GenerateSingleLogoImage(request_item)
background_tasks.add_task(service.get_result)
except Exception as e:
logger.warning(f"generate_single_logo Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))
return ResponseModel()
@router.get("/generate_single_logo_cancel/{tasks_id}>")
def generate_single_logo_image(tasks_id: str):
try:
logger.info(f"generate_single_logo_cancel request item is : @@@@@@:{tasks_id}")
data = generate_single_logo_cancel(tasks_id)
logger.info(f"generate_single_logo_cancel response @@@@@@:{json.dumps(data, indent=4)}")
except Exception as e:
logger.warning(f"generate_single_logo_cancel Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))
return ResponseModel(data=data['data'])
'''product image'''
@router.post("/generate_product_image")
def generate_product_image(request_item: GenerateProductImageModel, background_tasks: BackgroundTasks):
try:
logger.info(f"generate_product_image request item is : @@@@@@:{request_item}")
service = GenerateProductImage(request_item)
background_tasks.add_task(service.get_result)
except Exception as e:
logger.warning(f"generate_product_image Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))
return ResponseModel()
@router.get("/generate_product_image_cancel_cancel/{tasks_id}>")
def generate_product_image(tasks_id: str):
try:
logger.info(f"generate_product_image_cancel_cancel request item is : @@@@@@:{tasks_id}")
data = generate_product_image_cancel(tasks_id)
logger.info(f"generate_product_image_cancel_cancel response @@@@@@:{json.dumps(data, indent=4)}")
except Exception as e:
logger.warning(f"generate_product_image_cancel_cancel Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))
return ResponseModel(data=data['data'])
'''relight image'''
@router.post("/generate_relight_image")
def generate_relight_image(request_item: GenerateProductImageModel, background_tasks: BackgroundTasks):
try:
logger.info(f"generate_relight_image request item is : @@@@@@:{request_item}")
service = GenerateRelightImage(request_item)
background_tasks.add_task(service.get_result)
except Exception as e:
logger.warning(f"generate_relight_image Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))
return ResponseModel()
@router.get("/generate_relight_image_cancel_cancel/{tasks_id}>")
def generate_relight_image(tasks_id: str):
try:
logger.info(f"generate_relight_image_cancel_cancel request item is : @@@@@@:{tasks_id}")
data = generate_relight_image_cancel(tasks_id)
logger.info(f"generate_relight_image_cancel_cancel response @@@@@@:{json.dumps(data, indent=4)}")
except Exception as e:
logger.warning(f"generate_relight_image_cancel_cancel Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))
return ResponseModel(data=data['data'])

View File

@@ -0,0 +1,24 @@
import json
import logging
import time
from fastapi import APIRouter, HTTPException
from app.schemas.prompt_generation import PromptGenerationImageModel
from app.schemas.response_template import ResponseModel
from app.service.prompt_generation.chatgpt_for_translation import translate_to_en
router = APIRouter()
logger = logging.getLogger()
@router.post("/translateToEN")
def prompt_generation(request_data: PromptGenerationImageModel):
try:
logger.info(f"prompt_generation request item is : @@@@@@:{request_data}")
data = translate_to_en(request_data.text)
logger.info(f"prompt_generation response @@@@@@:{json.dumps(data, indent=4)}")
except Exception as e:
logger.warning(f"prompt_generation Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))
return ResponseModel(data=data)

View File

@@ -4,6 +4,11 @@ from app.api import api_test
from app.api import api_super_resolution
from app.api import api_generate_image
from app.api import api_attribute_retrieve
from app.api import api_design
from app.api import api_chat_robot
from app.api import api_prompt_generation
from app.api import api_design_pre_processing
router = APIRouter()
@@ -11,3 +16,7 @@ router.include_router(api_test.router, tags=["test"], prefix="/test")
router.include_router(api_super_resolution.router, tags=["super_resolution"], prefix="/api")
router.include_router(api_generate_image.router, tags=["generate_image"], prefix="/api")
router.include_router(api_attribute_retrieve.router, tags=["attribute_retrieve"], prefix="/api")
router.include_router(api_design.router, tags=['design'], prefix="/api")
router.include_router(api_chat_robot.router, tags=['chat_robot'], prefix="/api")
router.include_router(api_prompt_generation.router, tags=['prompt_generation'], prefix="/api")
router.include_router(api_design_pre_processing.router, tags=['design_pre_processing'], prefix="/api")

View File

@@ -1,8 +1,9 @@
import json
import logging
from fastapi import APIRouter, BackgroundTasks
from fastapi import APIRouter, BackgroundTasks, HTTPException
from app.schemas.response_template import ResponseModel
from app.schemas.super_resolution import SuperResolutionModel
from app.service.super_resolution.service import SuperResolution, infer_cancel
@@ -13,18 +14,22 @@ logger = logging.getLogger()
@router.post("/super_resolution")
def super_resolution(request_item: SuperResolutionModel, background_tasks: BackgroundTasks):
try:
logger.info(f"super_resolution request item is : @@@@@@:{request_item}")
service = SuperResolution(request_item)
background_tasks.add_task(service.sr_result)
code = 200
message = "access"
except Exception as e:
code = 400
message = e
logger.warning(e)
return {"code": code, "message": message}
logger.warning(f"super_resolution Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))
return ResponseModel()
@router.get("/sr_cancel/{tasks_id}>")
def super_resolution(tasks_id):
result = infer_cancel(tasks_id)
return {"code": 200, "message": result['message'], "data": result['data']}
def super_resolution(tasks_id: str):
try:
logger.info(f"sr_cancel request item is : @@@@@@:{tasks_id}")
data = infer_cancel(tasks_id)
logger.info(f"sr_cancel response @@@@@@:{json.dumps(data, indent=4)}")
except Exception as e:
logger.warning(f"sr_cancel Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))
return ResponseModel(data=data['data'])

View File

@@ -1,13 +1,20 @@
import logging
from fastapi import APIRouter
from app.core.config import SR_RABBITMQ_QUEUES, GI_RABBITMQ_QUEUES
from fastapi import FastAPI, HTTPException
from app.schemas.response_template import ResponseModel
logger = logging.getLogger()
router = APIRouter()
@router.get("")
def test():
@router.get("{id}")
def test(id: int):
logger.info(SR_RABBITMQ_QUEUES)
logger.info("test")
return {"SR_RABBITMQ_QUEUES message": SR_RABBITMQ_QUEUES, "GI_RABBITMQ_QUEUES": GI_RABBITMQ_QUEUES}
data = {"SR_RABBITMQ_QUEUES message": SR_RABBITMQ_QUEUES, "GI_RABBITMQ_QUEUES": GI_RABBITMQ_QUEUES}
if id == 1:
raise HTTPException(status_code=404, detail="Item not found")
return ResponseModel(data=data)

View File

@@ -41,6 +41,11 @@ MINIO_ACCESS = 'vXKFLSJkYeEq2DrSZvkB'
MINIO_SECRET = 'uKTZT3x7C43WvPN9QTc99DiRkwddWZrG9Uh3JVlR'
MINIO_SECURE = True
# S3 配置
S3_ACCESS_KEY = "AKIAVD3OJIMF6UJFLSHZ"
S3_AWS_SECRET_ACCESS_KEY = "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"
S3_REGION_NAME = "ap-east-1"
# redis 配置
REDIS_HOST = "10.1.1.240"
REDIS_PORT = "6379"
@@ -61,6 +66,32 @@ MILVUS_PORT = "19530"
MILVUS_TABLE_KEYPOINT = "keypoint_cache"
MILVUS_TABLE_SEG = "seg_cache"
# Mysql 配置
DB_HOST = '18.167.251.121' # 数据库主机地址
# DB_PORT = int( 33006)
DB_PORT = 33008 # 数据库端口
DB_USERNAME = 'aida_con_python' # 数据库用户名
DB_PASSWORD = '123456' # 数据库密码
DB_NAME = 'aida' # 数据库库名
# openai
os.environ['SERPAPI_API_KEY'] = "a793513017b0718db7966207c31703d280d12435c982f1e67bbcbffa52e7632c"
OPENAI_STREAM = True
BUFFER_THRESHOLD = 6 # must be even number
SINGLE_TOKEN_THRESHOLD = 200
TOKEN_THRESHOLD = 600
OPENAI_TEMPERATURE = 0
# OPENAI_API_KEY = "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"
OPENAI_API_KEY = "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"
OPENAI_MODEL = "gpt-3.5-turbo-0613"
OPENAI_MODEL_LIST = {"gpt-3.5-turbo-0613",
"gpt-3.5-turbo-16k-0613",
"gpt-4-0314",
"gpt-4-32k-0314",
"gpt-4-0613",
"gpt-4-32k-0613", }
# attribute service config
ATT_TRITON_URL = "10.1.1.240:10000"
@@ -77,6 +108,25 @@ GI_MINIO_BUCKET = "aida-users"
GI_RABBITMQ_QUEUES = os.getenv("GI_RABBITMQ_QUEUES", f"GenerateImage{RABBITMQ_ENV}")
GI_SYS_IMAGE_URL = "aida-sys-image/generate_image/white_image.jpg"
# SLOGAN service config
SLOGAN_RABBITMQ_QUEUES = os.getenv("SLOGAN_RABBITMQ_QUEUES", f"Slogan{RABBITMQ_ENV}")
# Generate Single Logo service config
GSL_MODEL_URL = '10.1.1.240:10051'
GSL_MINIO_BUCKET = "aida-users"
GSL_MODEL_NAME = 'stable_diffusion_xl'
GEN_SINGLE_LOGO_RABBITMQ_QUEUES = os.getenv("GEN_SINGLE_LOGO_RABBITMQ_QUEUES", f"GenSingleLogo{RABBITMQ_ENV}")
# Generate Single Logo service config
GPI_RABBITMQ_QUEUES = os.getenv("GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES", f"ToProductImage{RABBITMQ_ENV}")
GPI_MODEL_NAME = 'diffusion_ensemble_all'
GPI_MODEL_URL = '10.1.1.240:10061'
# Generate Single Logo service config
GRI_RABBITMQ_QUEUES = os.getenv("GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES", f"Relight{RABBITMQ_ENV}")
GRI_MODEL_NAME = 'stable_diffusion_1_5'
GRI_MODEL_URL = '10.1.1.150:8001'
# SEG service config
SEG_MODEL_URL = '10.1.1.240:10000'
SEGMENTATION = {
@@ -87,9 +137,13 @@ SEGMENTATION = {
}
# DESIGN config
DESIGN_MODEL_URL = '10.1.1.240:9000'
DESIGN_MODEL_URL = '10.1.1.240:10000'
AIDA_CLOTHING = "aida-clothing"
KEYPOINT_RESULT_TABLE_FIELD_SET = ('neckline_left', 'neckline_right', 'shoulder_left', 'shoulder_right', 'armpit_left', 'armpit_right',
'cuff_left_in', 'cuff_left_out', 'cuff_right_in', 'cuff_right_out', 'waistband_left', 'waistband_right')
# DESIGN 预处理
IF_DEBUG_SHOW = False
# 优先级
PRIORITY_DICT = {
@@ -117,4 +171,3 @@ PRIORITY_DICT = {
'bag_back': -98,
'earring_back': -99,
}

View File

@@ -1,14 +1,18 @@
import logging.config
from http.client import HTTPException
from fastapi.responses import JSONResponse
from fastapi import FastAPI, HTTPException, Request
import uvicorn
from fastapi import FastAPI
from app.api.api_route import router
from app.core.config import settings
from app.schemas.response_template import ResponseModel
from logging_env import LOGGER_CONFIG_DICT
logging.config.dictConfig(LOGGER_CONFIG_DICT)
logging.getLogger("pika").setLevel(logging.WARNING)
from starlette.middleware.cors import CORSMiddleware
@@ -35,5 +39,15 @@ def get_application() -> FastAPI:
app = get_application()
@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException):
return JSONResponse(
status_code=exc.status_code,
content=ResponseModel(code=exc.status_code, msg=exc.detail, data=exc.detail).dict()
)
if __name__ == '__main__':
uvicorn.run(app, host="0.0.0.0", port=8000)

View File

@@ -0,0 +1,8 @@
from pydantic import BaseModel
class ChatRobotModel(BaseModel):
gender: str
message: str
session_id: str
user_id: int

50
app/schemas/design.py Normal file
View File

@@ -0,0 +1,50 @@
from pydantic import BaseModel
# class BodyPointModel(BaseModel):
# waistband_right: list[int]
# hand_point_right: list[int]
# waistband_left: list[int]
# hand_point_left: list[int]
# shoulder_left: list[int]
# shoulder_right: list[int]
#
#
# class BasicModel(BaseModel):
# body_point: BodyPointModel
# layer_order: bool
# scale_bag: float
# scale_earrings: float
# self_template: bool
# single_overall: str
# switch_category: str
# body_path: str
#
#
# class PrintModel(BaseModel):
# if_single: bool
# print_path_list: list[str]
#
#
# class ItemModel(BaseModel):
# color: str
# image_id: str
# offset: list[int]
# path: str
# print: PrintModel
# resize_scale: float
# type: str
#
#
# class CollocationModel(BaseModel):
# basic: BasicModel
# item: list[ItemModel]
#
#
# class DesignModel(BaseModel):
# object: list[CollocationModel]
# process_id: str
class DesignModel(BaseModel):
objects: list[dict]
process_id: str

View File

@@ -8,3 +8,21 @@ class GenerateImageModel(BaseModel):
mode: str
category: str
gender: str
class GenerateSingleLogoImageModel(BaseModel):
tasks_id: str
prompt: str
seed: str
class GenerateProductImageModel(BaseModel):
tasks_id: str
prompt: str
image_url: str
class GenerateRelightImageModel(BaseModel):
tasks_id: str
prompt: str
image_url: str

View File

@@ -0,0 +1,5 @@
from pydantic import BaseModel
class DesignPreProcessingModel(BaseModel):
sketches: list[dict]

View File

@@ -0,0 +1,5 @@
from pydantic import BaseModel
class PromptGenerationImageModel(BaseModel):
text: str

View File

@@ -0,0 +1,8 @@
from pydantic import BaseModel
from typing import Any, Optional
class ResponseModel(BaseModel):
code: int = 200
msg: str = "OK!"
data: Optional[Any] = None

View File

@@ -0,0 +1,7 @@
from .agent_executor import CustomAgentExecutor
from .conversational_functions_agent import ConversationalFunctionsAgent
__all__ = [
"CustomAgentExecutor",
"ConversationalFunctionsAgent"
]

View File

@@ -0,0 +1,132 @@
import inspect
import json
import logging
from typing import Any, Dict, List, Optional, Union, Tuple
from langchain.agents import AgentExecutor
from langchain.callbacks.manager import Callbacks, CallbackManager
from langchain.load.dump import dumpd
from langchain.schema import RUN_KEY, RunInfo
from langchain_core.agents import AgentAction, AgentFinish
class CustomAgentExecutor(AgentExecutor):
def __call__(
self,
inputs: Union[Dict[str, Any], Any],
return_only_outputs: bool = False,
callbacks: Callbacks = None,
session_key: str = "",
*,
tags: Optional[List[str]] = None,
include_run_info: bool = False,
) -> Dict[str, Any]:
"""Run the logic of this chain and add to output if desired.
Args:
inputs: Dictionary of inputs, or single input if chain expects
only one param.
return_only_outputs: boolean for whether to return only outputs in the
response. If True, only new keys generated by this chain will be
returned. If False, both input keys and new keys generated by this
chain will be returned. Defaults to False.
callbacks: Callbacks to use for this chain run. If not provided, will
use the callbacks provided to the chain.
include_run_info: Whether to include run info in the response. Defaults
to False.
"""
inputs = self.prep_inputs(inputs, session_key)
callback_manager = CallbackManager.configure(
callbacks, self.callbacks, self.verbose, tags, self.tags
)
new_arg_supported = inspect.signature(self._call).parameters.get("run_manager")
run_manager = callback_manager.on_chain_start(
dumpd(self),
inputs,
)
try:
outputs = (
self._call(inputs, run_manager=run_manager)
if new_arg_supported
else self._call(inputs)
)
except (KeyboardInterrupt, Exception) as e:
logging.exception(e)
run_manager.on_chain_error(e)
raise e
run_manager.on_chain_end(outputs)
final_outputs: Dict[str, Any] = self.prep_outputs(
inputs, outputs, return_only_outputs, session_key
)
if include_run_info:
final_outputs[RUN_KEY] = RunInfo(run_id=run_manager.run_id)
return final_outputs
def prep_outputs(
self,
inputs: Dict[str, str],
outputs: Dict[str, str],
return_only_outputs: bool = False,
session_key: str = ""
) -> Dict[str, str]:
"""Validate and prep outputs."""
self._validate_outputs(outputs)
if self.memory is not None and outputs['need_record']:
self.memory.save_context(inputs, outputs, session_key)
if return_only_outputs:
return outputs
else:
return {**inputs, **outputs}
def prep_inputs(self, inputs: Union[Dict[str, Any], Any], session_key: str = "") -> Dict[str, str]:
"""Validate and prep inputs."""
if not isinstance(inputs, dict):
_input_keys = set(self.input_keys)
if self.memory is not None:
# If there are multiple input keys, but some get set by memory so that
# only one is not set, we can still figure out which key it is.
_input_keys = _input_keys.difference(self.memory.memory_variables)
if len(_input_keys) != 1:
raise ValueError(
f"A single string input was passed in, but this chain expects "
f"multiple inputs ({_input_keys}). When a chain expects "
f"multiple inputs, please call it by passing in a dictionary, "
"eg `chain({'foo': 1, 'bar': 2})`"
)
inputs = {list(_input_keys)[0]: inputs}
if self.memory is not None:
external_context = self.memory.load_memory_variables(inputs, session_key)
inputs = dict(inputs, **external_context)
self._validate_inputs(inputs)
return inputs
def _get_tool_return(
self, next_step_output: Tuple[AgentAction, str]
) -> Optional[AgentFinish]:
"""Check if the tool is a returning tool."""
agent_action, observation = next_step_output
name_to_tool_map = {tool.name: tool for tool in self.tools}
return_value_key = "output"
if len(self.agent.return_values) > 0:
return_value_key = self.agent.return_values[0]
try:
observation_list = json.loads(observation)
if agent_action.tool == "sql_db_query" and isinstance(observation_list,
list) and observation_list.__len__() != 0:
return AgentFinish(
{return_value_key: observation},
"",
)
except:
pass
# Invalid tools won't be in the map, so we return False.
if agent_action.tool in name_to_tool_map:
if name_to_tool_map[agent_action.tool].return_direct:
return AgentFinish(
{return_value_key: observation},
"",
)
return None

View File

@@ -0,0 +1,198 @@
import json
import re
from json import JSONDecodeError
from typing import List, Tuple, Any, Union
from dataclasses import dataclass
from langchain.callbacks.manager import Callbacks
from langchain.agents import (
OpenAIFunctionsAgent,
)
from langchain.schema import (
AgentAction,
AgentFinish,
BaseMessage,
OutputParserException
)
from langchain.schema.messages import (
AIMessage,
FunctionMessage
)
from langchain.tools import BaseTool, StructuredTool
# from langchain.tools.convert_to_openai import FunctionDescription
from langchain.utils.openai_functions import FunctionDescription
@dataclass
class _FunctionsAgentAction(AgentAction):
"""Add message_log to AgentAction class for the _FunctionAgentAction
"""
message_log: List[BaseMessage]
def __init__(
self, tool: str, tool_input: Union[str, dict], log: str, **kwargs: Any
):
"""Override init to support instantiation by position for backward compat."""
super().__init__(tool=tool, tool_input=tool_input, log=log, **kwargs)
def _convert_agent_action_to_messages(
agent_action: AgentAction, observation: str
) -> List[BaseMessage]:
"""Convert an agents action to a message.
This code is used to reconstruct the original AI message from the agents action.
Args:
agent_action: Agent action to convert.
Returns:
AIMessage that corresponds to the original tools invocation.
"""
if isinstance(agent_action, _FunctionsAgentAction):
return agent_action.message_log + [
_create_function_message(agent_action, observation)
]
else:
return [AIMessage(content=agent_action.log)]
def _create_function_message(
agent_action: AgentAction, observation: str
) -> FunctionMessage:
"""Convert agents action and observation into a function message.
Args:
agent_action: the tools invocation request from the agents
observation: the result of the tools invocation
Returns:
FunctionMessage that corresponds to the original tools invocation
"""
if not isinstance(observation, str):
try:
content = json.dumps(observation, ensure_ascii=False)
except Exception:
content = str(observation)
else:
content = observation
return FunctionMessage(
name=agent_action.tool,
content=content,
)
def _format_intermediate_steps(
intermediate_steps: List[Tuple[AgentAction, str]],
) -> List[BaseMessage]:
"""Format intermediate steps.
Args:
intermediate_steps: Steps the LLM has taken to date, along with observations
Returns:
list of messages to send to the LLM for the next prediction
"""
messages = []
for intermediate_step in intermediate_steps:
agent_action, observation = intermediate_step
messages.extend(_convert_agent_action_to_messages(agent_action, observation))
return messages
def _format_tool_to_openai_function(tool: BaseTool) -> FunctionDescription:
"""Format tools into the OpenAI function API."""
parameters = {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": tool.param_description if hasattr(tool, 'param_description') else "",
},
},
"required": ["query"],
}
return {
"name": tool.name,
"description": tool.description,
"parameters": parameters,
}
def _parse_ai_message(message: BaseMessage) -> Union[AgentAction, AgentFinish]:
if not isinstance(message, AIMessage):
raise TypeError(f"Expected an AI message but got {type(message)}")
function_call = message.additional_kwargs.get("function_call", {})
if function_call:
function_call = message.additional_kwargs["function_call"]
function_name = function_call["name"]
try:
_tool_input = json.loads(function_call["arguments"])
except JSONDecodeError:
raise OutputParserException(
f"Could not parse tools input: {function_call} because"
f"the `arguments` is not valid JSON."
)
if "query" in _tool_input:
tool_input = _tool_input["query"]
else:
tool_input = _tool_input
return _FunctionsAgentAction(
tool=function_name,
tool_input=tool_input,
log=f"\nInvoking: `{function_name}` with `{tool_input}`\n",
message_log=[message]
)
# pattern = r'\((.*?)\)'
# matches = re.findall(pattern, message.content)
# result = []
#
# for match in matches:
# result.append(match)
#
# if result:
# output = result
# else:
# output = message.content
return AgentFinish(return_values={"output": message.content}, log=message.content)
class ConversationalFunctionsAgent(OpenAIFunctionsAgent):
@property
def functions(self) -> List[dict]:
return [dict(_format_tool_to_openai_function(t)) for t in self.tools]
def plan(self,
intermediate_steps: List[Tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any
) -> Union[AgentAction, AgentFinish]:
"""Decide how agents should move after receiving an input. The difference between
OpenAIFunctionsAgent lies in the '_parse_ai_message' function. We add an OutputParser
into it.
Args:
intermediate_steps: Steps the LLM has taken to date, along with observations
**kwargs: User inputs.
**kwargs: Including user's input string
Returns:
Action specifying what tools to use.
"""
agent_scratchpad: List[BaseMessage] = _format_intermediate_steps(intermediate_steps)
selected_inputs = {
k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
}
full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad)
prompt = self.prompt.format_prompt(**full_inputs)
messages: List[BaseMessage] = prompt.to_messages()
predicted_message = self.llm.predict_messages(
messages, functions=self.functions, callbacks=callbacks
)
agent_decision = _parse_ai_message(predicted_message)
return agent_decision

View File

@@ -0,0 +1,6 @@
from .openai_token_record_callback import OpenAITokenRecordCallbackHandler
__all__ = [
'OpenAITokenRecordCallbackHandler'
]

View File

@@ -0,0 +1,46 @@
"""Callback Handler that add on_chain_end function to record Token usage."""
from typing import Any, Dict
from langchain.callbacks import OpenAICallbackHandler
from langchain.schema import LLMResult
from langchain.callbacks.openai_info import standardize_model_name, MODEL_COST_PER_1K_TOKENS, get_openai_token_cost_for_model
class OpenAITokenRecordCallbackHandler(OpenAICallbackHandler):
need_record: bool = True
response_type: str = "string"
"""Callback Handler that tracks OpenAI info and write to redis after agent finish"""
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""Collect token usage."""
if response.llm_output is None:
return None
self.successful_requests += 1
if "token_usage" not in response.llm_output:
return None
if "function_call" in response.generations[0][0].message.additional_kwargs:
if response.generations[0][0].message.additional_kwargs["function_call"]["name"] in ["sql_db_query", "sql_db_schema","tutorial_tool"]:
self.need_record = False
if response.generations[0][0].message.additional_kwargs["function_call"]["name"] == "sql_db_query":
self.response_type = "image"
token_usage = response.llm_output["token_usage"]
completion_tokens = token_usage.get("completion_tokens", 0)
prompt_tokens = token_usage.get("prompt_tokens", 0)
model_name = standardize_model_name(response.llm_output.get("model_name", ""))
if model_name in MODEL_COST_PER_1K_TOKENS:
completion_cost = get_openai_token_cost_for_model(
model_name, completion_tokens, is_completion=True
)
prompt_cost = get_openai_token_cost_for_model(model_name, prompt_tokens)
self.total_cost += prompt_cost + completion_cost
self.total_tokens += token_usage.get("total_tokens", 0)
self.prompt_tokens += prompt_tokens
self.completion_tokens += completion_tokens
def on_chain_end(self, outputs: Dict, **kwargs: Any) -> None:
"""Write token usage to redis."""
outputs["total_tokens"] = self.total_tokens
outputs["total_cost"] = self.total_cost
outputs["prompt_tokens"] = self.prompt_tokens
outputs["completion_tokens"] = self.completion_tokens
outputs["need_record"] = self.need_record
outputs["response_type"] = self.response_type

View File

@@ -0,0 +1,79 @@
from typing import Optional, List
import json
from sqlalchemy import text
# from langchain import SQLDatabase
from langchain.utilities import SQLDatabase
class CustomDatabase(SQLDatabase):
def get_table_info_no_throw(self, table_names: Optional[List[str]] = None) -> str:
# def get_table_info(self, table_names: Optional[List[str]] = None) -> str:
connection = self._engine.connect()
all_table_names = self.get_usable_table_names()
if table_names is not None:
missing_tables = set(table_names).difference(all_table_names)
if missing_tables:
# raise ValueError(f"table_names {missing_tables} not found in database")
return f"Table {','.join(missing_tables)} can not be found in the database"
all_table_names = table_names
meta_tables = [
tbl
for tbl in self._metadata.sorted_tables
if tbl.name in set(all_table_names)
]
tables = []
for table in meta_tables:
table_name = table.name
column_names = table.columns.keys()
table_info = f"Table: {table_name}\nColumns: \nID, \nimg_name\n"
for column_name in column_names:
if column_name not in ["ID", "img_name"]:
query = text(f"SELECT DISTINCT {column_name} FROM {table_name}")
result = connection.execute(query)
enum_values: List[str] = [row[0] for row in result.fetchall()]
column_info = f"{column_name}: {', '.join(enum_values)}\n"
table_info += column_info
# table_info = f"Table: {table_name}\n"
#
# if self._sample_rows_in_table_info:
# table_info += f"{self._get_sample_rows(table)}\n"
tables.append(table_info)
final_str = "\n\n".join(tables)
return final_str
def run(self, command: str, fetch: str = "all") -> str:
"""Execute a SQL command and return a string representing the results.
If the statement returns rows, a string of the results is returned.
If the statement returns no rows, an empty string is returned.
"""
with self._engine.begin() as connection:
if self._schema is not None:
if self.dialect == "snowflake":
connection.exec_driver_sql(
f"ALTER SESSION SET search_path='{self._schema}'"
)
elif self.dialect == "bigquery":
connection.exec_driver_sql(f"SET @@dataset_id='{self._schema}'")
else:
connection.exec_driver_sql(f"SET search_path TO {self._schema}")
cursor = connection.execute(text(command))
if cursor.rowcount:
if fetch == "all":
result = cursor.fetchall()
elif fetch == "one":
result = cursor.fetchone() # type: ignore
else:
raise ValueError("Fetch parameter must be either 'one' or 'all'")
# Convert columns values to string to avoid issues with sqlalchmey
# trunacating text
if isinstance(result, list):
return json.dumps([r[0] for r in result])
return json.dumps([result[0]])
return ""

View File

@@ -0,0 +1,114 @@
import logging
from loguru import logger
from langchain.agents import Tool
from langchain.utilities import SerpAPIWrapper
from langchain.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate, MessagesPlaceholder
from langchain.schema import SystemMessage, AIMessage
from langchain.chat_models import ChatOpenAI
from langchain.llms.openai import OpenAI
from langchain.callbacks import FileCallbackHandler
from app.service.chat_robot.script.agents import CustomAgentExecutor, ConversationalFunctionsAgent
from app.service.chat_robot.script.callbacks import OpenAITokenRecordCallbackHandler
from app.service.chat_robot.script.database import CustomDatabase
from app.service.chat_robot.script.prompt import FASHION_CHAT_BOT_PREFIX, TOOLS_FUNCTIONS_SUFFIX
from app.service.chat_robot.script.tools import (QuerySQLDataBaseTool, InfoSQLDatabaseTool, QuerySQLCheckerTool, ListSQLDatabaseTool)
from app.service.chat_robot.script.memory import UserConversationBufferWindowMemory
from app.service.chat_robot.script.tools.tutorial_tool import CustomTutorialTool
from app.core.config import *
import os
# os.environ["http_proxy"] = "http://127.0.0.1:7890"
# os.environ["https_proxy"] = "http://127.0.0.1:7890"
# log callbacks
logfile = "logs/chat_debug.log"
logger.add(logfile, colorize=True, enqueue=True)
log_handler = FileCallbackHandler(logfile)
# Initiate our LLM 'gpt-3.5-turbo'
llm = ChatOpenAI(temperature=0.1,
openai_api_key=OPENAI_API_KEY,
# callbacks=[OpenAICallbackHandler()]
)
search = SerpAPIWrapper()
db = CustomDatabase.from_uri(f'mysql+pymysql://{DB_USERNAME}:{DB_PASSWORD}@{DB_HOST}:{DB_PORT}/attribute_retrieval_V3',
include_tables=['female_top', 'female_skirt', 'female_pants', 'female_dress',
'female_outwear', 'male_bottom', 'male_top', 'male_outwear'],
engine_args={"pool_recycle": 7200})
tools = [
Tool(
name="internet_search",
description="Can be used to perform Internet searches",
func=search.run
),
QuerySQLDataBaseTool(db=db, return_direct=False),
InfoSQLDatabaseTool(db=db),
ListSQLDatabaseTool(db=db),
QuerySQLCheckerTool(db=db, llm=OpenAI(temperature=0, openai_api_key=OPENAI_API_KEY)),
Tool(
name="tutorial_tool",
description="Utilize this tool to retrieve specific statements related to user guidance tutorials."
"Input is an empty string",
func=CustomTutorialTool(),
return_direct=True
)
]
messages = [
SystemMessage(content=FASHION_CHAT_BOT_PREFIX),
MessagesPlaceholder(variable_name="history"),
HumanMessagePromptTemplate.from_template(
"{input} "
"Question from a {gender}."
),
AIMessage(content=TOOLS_FUNCTIONS_SUFFIX),
MessagesPlaceholder(variable_name="agent_scratchpad"),
]
prompt = ChatPromptTemplate(input_variables=["input", "gender", "agent_scratchpad", "history"], messages=messages)
agent = ConversationalFunctionsAgent(
llm=llm,
tools=tools,
prompt=prompt
)
memory = UserConversationBufferWindowMemory.from_redis(
return_messages=True, k=2, input_key='input', output_key='output'
)
agent_executor = CustomAgentExecutor.from_agent_and_tools(
agent=agent,
tools=tools,
verbose=True,
memory=memory,
)
def chat(post_data):
user_id = post_data.user_id
session_id = post_data.session_id
input_message = post_data.message
gender = post_data.gender
final_outputs = agent_executor(
{"input": input_message, "gender": gender},
callbacks=[OpenAITokenRecordCallbackHandler(), log_handler],
session_key=f"buffer:{user_id}:{session_id}",
)
api_response = {
'user_id': user_id,
'session_id': session_id,
# 'message_id': message_id,
# 'create_time': created_time,
'input': final_outputs['input'],
# 'conversion': messages,
'output': final_outputs['output'],
# 'gpt_response_time': gpt_response_time,
'total_tokens': final_outputs['total_tokens'],
'total_cost': final_outputs['total_cost'],
'prompt_tokens': final_outputs['prompt_tokens'],
'completion_tokens': final_outputs['completion_tokens'],
'response_type': final_outputs['response_type']
}
logging.info(api_response)
return api_response

View File

@@ -0,0 +1,3 @@
from .user_buffer_window import UserConversationBufferWindowMemory
__all__ = ['UserConversationBufferWindowMemory']

View File

@@ -0,0 +1,93 @@
import logging
from typing import Any, Dict, List, Tuple
import json
import redis
from redis import Redis
from langchain.memory import RedisChatMessageHistory
from langchain.memory.chat_memory import BaseChatMemory
from langchain.schema.messages import BaseMessage, get_buffer_string, HumanMessage, AIMessage
from langchain.schema.messages import _message_to_dict, messages_from_dict
from langchain.memory.utils import get_prompt_input_key
from app.core.config import *
class UserConversationBufferWindowMemory(BaseChatMemory):
"""Buffer for storing conversation memory."""
redis_client: Redis
human_prefix: str = "Human"
ai_prefix: str = "AI"
memory_key: str = "history" #: :meta private:
k: int = 5
@classmethod
def from_redis(
cls,
host: str = REDIS_HOST,
port: int = REDIS_PORT,
db: int = 3,
**kwargs
):
redis_client = Redis(host=host, port=port, db=db)
try:
response = redis_client.ping()
if response:
print("Connect to redis server successfully.")
logging.info("Connect to redis server successfully.")
else:
print("Fail to connect to redis server")
logging.info("Fail to connect to redis server")
except redis.RedisError as e:
logging.info(f"Error occurs when connecting to redis server: {str(e)}")
return cls(redis_client=redis_client, **kwargs)
@property
def memory_variables(self) -> List[str]:
"""Will always return list of memory variables.
:meta private:
"""
return [self.memory_key]
def load_memory_variables(self, inputs: Dict[str, Any], key: str = "") -> Dict[str, str]:
"""Return history buffer."""
_items: Any = self.redis_client.lrange(key, 0, self.k * 2) if self.k > 0 else []
items = [json.loads(m.decode("utf-8")) for m in _items[::-1]]
buffer = messages_from_dict(items)
if not self.return_messages:
buffer = get_buffer_string(
buffer,
human_prefix=self.human_prefix,
ai_prefix=self.ai_prefix,
)
return {self.memory_key: buffer}
def _get_input_output(
self, inputs: Dict[str, Any], outputs: Dict[str, str]
) -> Tuple[str, str]:
if self.input_key is None:
prompt_input_key = get_prompt_input_key(inputs, self.memory_variables)
else:
prompt_input_key = self.input_key
if self.output_key is None:
if len(outputs) != 1:
raise ValueError(f"One output key expected, got {outputs.keys()}")
output_key = list(outputs.keys())[0]
else:
output_key = self.output_key
return inputs[prompt_input_key], outputs[output_key]
def add_message(self, key: str, message: BaseMessage) -> None:
self.redis_client.lpush(key, json.dumps(_message_to_dict(message)))
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str], key: str = "") -> None:
"""Save context from this conversation to buffer."""
input_str, output_str = self._get_input_output(inputs, outputs)
self.add_message(key, HumanMessage(content=input_str))
self.add_message(key, AIMessage(content=output_str))
# def clear(self, key) -> None:
# """Clear memory contents."""
# self.redis_client.delete(key)

View File

@@ -0,0 +1,52 @@
FASHION_CHAT_BOT_PREFIX = """
You are a helpful assistant for fashion designers. You can chat with the users or answer their query as much as you can.
The most crucial aspect is to accurately determine whether the user's inquiry requires a internet search or querying the database.
Remember your answer should be very precise and the final output answer should not exceed 20 words.
You may encounter the following types of questions:
1) If the query related to clothing retrieval, you are an agent designed to interact with a SQL database.
Given an input question, create a syntactically correct mysql query to run, always fetching random data from tables.
Unless the user specifies a specific number of examples they wish to obtain,always limit your query to at most 4 results.
Never query for all the columns from a specific table, only ask for the relevant columns given the question.
You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.
DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.
If the question does not seem related to the database, just return "I don't know" as the answer.
2) If the query related to current events, you should use internet_search to seek help from the internet.
3) If the query is just casual conversation, engage in the conversation as a fashion designer assistant.
Be careful to use the tools, since you are actually a chat bot. Tools can only be used when essential.
"""
TOOL_SELECT_SUFFIX = """
Prior to proceeding, it is essential to carefully assess the question and select the appropriate tools or approach accordingly.
For database-related questions, use SQL tools to identify relevant tables and query their schemas.
The use of online resources should be limited to inquiry pertaining to current subjects.
"""
SQL_FUNCTIONS_SUFFIX = """
For database-related questions, use SQL tools to identify relevant tables and query their schemas.
"""
INTERNET_SEARCH_SUFFIX = """
If the question should be answered using internet search tools, I should seek help from the internet.
"""
ANSWER_FORMAT_SUFFIX = """
My final answer are limited to 20 words and be as much precise as possible.
"""
TOOLS_FUNCTIONS_SUFFIX = (
"If the input involves clothing queries,"
"I should look at the tables in the database to see what I can query. Then I should query the schema of the most relevant tables."
"All SQL statements must use 'ORDER BY RAND()', for example:"
"Example Input 1: 'SELECT img_name FROM skirt WHERE opening_type = 'Button' ORDER BY RAND() LIMIT 1'"
"Example Input 2: 'SELECT img_name FROM top WHERE sleeve_length = 'Long' AND type = 'Blouse' ORDER BY RAND() LIMIT 2'"
"If the input does not involve clothing queries, "
"I should engage in conversation as an assistant or search from internet with internet_search tool."
"If the database query returns no results, please respond directly with: 'Apologies, I couldn't find any images that match your description. Could you please give me more details about the clothing you're searching for?'"
"Upon mentioning words related to 'tutorial' in the input, I should use tutorial_tool "
)
TUTORIAL_TOOL_RETURN = "Commencing the systematic tutorial guide now."

View File

@@ -0,0 +1,10 @@
from .sql_tools import (
QuerySQLDataBaseTool,
InfoSQLDatabaseTool,
ListSQLDatabaseTool,
QuerySQLCheckerTool
)
__all__ = [
"QuerySQLCheckerTool", "InfoSQLDatabaseTool", "ListSQLDatabaseTool", "QuerySQLDataBaseTool"
]

View File

@@ -0,0 +1,183 @@
# flake8: noqa
"""Tools for interacting with a SQL database."""
from typing import Any, Dict, Optional
from pydantic import BaseModel, Extra, Field, root_validator
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import (
AsyncCallbackManagerForToolRun,
CallbackManagerForToolRun,
)
from langchain.chains.llm import LLMChain
from langchain.prompts import PromptTemplate
# from langchain.sql_database import SQLDatabase
from langchain.utilities import SQLDatabase
from langchain.tools.base import BaseTool
from langchain.tools.sql_database.prompt import QUERY_CHECKER
class BaseSQLDatabaseTool(BaseModel):
"""Base tools for interacting with a SQL database."""
db: SQLDatabase = Field(exclude=True)
param_description: str = ""
# Override BaseTool.Config to appease mypy
# See https://github.com/pydantic/pydantic/issues/4173
class Config(BaseTool.Config):
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
extra = Extra.forbid
class QuerySQLDataBaseTool(BaseSQLDatabaseTool, BaseTool):
"""Tool for querying a SQL database."""
name = "sql_db_query"
# description = """
# Before use this tool, another tool named sql_db_schema must be used first to find the schema of interested tables.
# This tool is designed exclusively for generating SELECT queries to retrieve clothing's img_name randomly from a MySQL database.
# You should always use order by rand() to randomly select data.
# If the query is not correct, an error message will be returned.
# If an error is returned, rewrite the query, check the query, and try again.
# Always limit your query to at most 4 results.
# Never query for all the columns from a specific table, only ask for the relevant columns given the question.
# You MUST double check your query before executing it.
# DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.
# """
description: str = (
"The input of this tool is a detailed and correct SQL select query statement, "
"and the output is the result of the database, and it can only return up to 4 results."
"If the query is not correct, an error message will be returned."
"If an error is returned, rewrite the query, check the query, and try again."
"If you encounter an issue with Unknown column 'xxxx' in 'field list' or Table 'attribute_retrieval.xxxx' doesn't exist,"
"use sql_db_schema to query the correct table fields."
"Example Input: 'SELECT img_name FROM skirt WHERE opening_type = 'Button' ORDER BY RAND() "
"LIMIT 1'"
"Example Input 2: 'SELECT img_name FROM top WHERE sleeve_length = 'Long' AND type = 'Blouse' "
"order by rand() LIMIT 2'"
)
def _run(
self,
query: str,
run_manager: Optional[CallbackManagerForToolRun] = None,
) -> str:
"""Execute the query, return the results or an error message."""
result = self.db.run_no_throw(query)
return result
async def _arun(
self,
query: str,
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
) -> str:
raise NotImplementedError("QuerySqlDbTool does not support async")
class InfoSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool):
"""Tool for getting metadata about a SQL database."""
name = "sql_db_schema"
# description = """
# The database contains information of lots of fashion items, such as item name, their fashion attributes.
# There are five tables covering five fashion categories: top, pants, dress, skirt, and outwear.
# Find the most relevant tables with the query, and output the schema of these tables.
# """
description: str = (
"Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables."
"There are eight tables covering eight fashion categories: female_top, female_pants, female_dress,"
"female_skirt, female_outwear, male_bottom, male_top, and male_outwear."
"Example Input: 'female_outwear, male_top'"
)
def _run(
self,
table_names: str,
run_manager: Optional[CallbackManagerForToolRun] = None,
) -> str:
"""Get the schema for tables in a comma-separated list."""
return self.db.get_table_info_no_throw(table_names.split(", "))
async def _arun(
self,
table_name: str,
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
) -> str:
raise NotImplementedError("SchemaSqlDbTool does not support async")
class ListSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool):
"""Tool for getting tables names."""
name = "sql_db_list_tables"
description = "Input is an empty string, output is a comma separated list of tables in the database."
def _run(
self,
tool_input: str = "",
run_manager: Optional[CallbackManagerForToolRun] = None,
) -> str:
"""Get the schema for a specific table."""
return ", ".join(self.db.get_usable_table_names())
async def _arun(
self,
tool_input: str = "",
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
) -> str:
raise NotImplementedError("ListTablesSqlDbTool does not support async")
class QuerySQLCheckerTool(BaseSQLDatabaseTool, BaseTool):
"""Use an LLM to check if a query is correct.
Adapted from https://www.patterns.app/blog/2023/01/18/crunchbot-sql-analyst-gpt/"""
template: str = QUERY_CHECKER
llm: BaseLanguageModel
llm_chain: LLMChain = Field(init=False)
name = "sql_db_query_checker"
description = (
"Use this tools to double check if your query is correct before executing it."
"Always use this tools before executing a query with sql_db_query!"
)
@root_validator(pre=True)
def initialize_llm_chain(cls, values: Dict[str, Any]) -> Dict[str, Any]:
if "llm_chain" not in values:
values["llm_chain"] = LLMChain(
llm=values.get("llm"),
prompt=PromptTemplate(
template=QUERY_CHECKER,
input_variables=["query", "dialect"]
),
)
if values["llm_chain"].prompt.input_variables != ["dialect", "query"]:
# if values["llm_chain"].prompt.input_variables != ["query", "dialect"]:
raise ValueError(
"LLM chain for QueryCheckerTool must have input variables ['query', 'dialect']"
)
return values
def _run(
self,
query: str,
run_manager: Optional[CallbackManagerForToolRun] = None,
) -> str:
"""Use the LLM to check the query."""
return self.llm_chain.predict(query=query, dialect=self.db.dialect)
async def _arun(
self,
query: str,
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
) -> str:
return await self.llm_chain.apredict(query=query, dialect=self.db.dialect)

View File

@@ -0,0 +1,19 @@
from typing import Any
from langchain.tools.base import BaseTool
from app.service.chat_robot.script.prompt import TUTORIAL_TOOL_RETURN
# 处理系统引导教程相关的输入
class CustomTutorialTool(BaseTool):
name = "tutorial_tool"
description = ("Utilize this tool to retrieve specific statements related to user guidance tutorials."
"Input is an empty string")
def _run(self, tool_input, **kwargs: Any) -> str:
return TUTORIAL_TOOL_RETURN
async def _arun(self, tool_input, **kwargs: Any) -> str:
raise NotImplementedError("CustomTutorialTool does not support async")

View File

@@ -0,0 +1 @@
from .logger import Logger

View File

@@ -0,0 +1,26 @@
import logging
from logging import handlers
class Logger(object):
level_relations = {
'debug': logging.DEBUG,
'info': logging.INFO,
'warning': logging.WARNING,
'error': logging.ERROR,
'crit': logging.CRITICAL
}
def __init__(self, filename, level='info', when='D', backCount=3,
fmt='%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s'):
self.logger = logging.getLogger(filename)
format_str = logging.Formatter(fmt) # set log format
self.logger.setLevel(self.level_relations.get(level)) # set log level
sh = logging.StreamHandler() # output to terminal
sh.setFormatter(format_str) # set format for terminal log
th = handlers.TimedRotatingFileHandler(filename=filename, when=when, backupCount=backCount,
encoding='utf-8') # log into file
th.setFormatter(format_str) # set format for file log
self.logger.addHandler(sh) # output to terminal
self.logger.addHandler(th) # output to file

View File

View File

@@ -0,0 +1,116 @@
import logging
import numpy as np
import cv2
from matplotlib import pyplot as plt
from PIL import Image
def show(img, win_name="temp"):
cv2.imshow(win_name, img)
cv2.waitKey(0)
def crop(img):
mid_point_h, mid_point_w = int(img.shape[0] / 2 + 30), int(img.shape[1] / 2)
img_roi = img[mid_point_h - 520: mid_point_h + 520, mid_point_w - 340: mid_point_w + 340]
return img_roi
class Layer(object):
def __init__(self):
self._layer = []
@property
def layer(self):
return self._layer
def insert(self, layer_instance):
if layer_instance['name'] == 'body':
self._body = layer_instance
self._layer.append(layer_instance)
def sort(self, priority):
self._layer.sort(key=lambda x: priority[x['name']])
# def merge(self, cfg):
# """
# opencv shape order (height, width, channel)
# image coordinate system:
# |------------->x (width)
# |
# |
# |
# y (height)
# Returns:
#
#
# """
# base_image = Image.new('RGBA', self._layer[1]['image'].size, (0, 0, 0, 0))
# for layer in self._layer:
# y, x = layer['position']
# base_image.paste(layer['image'], (x, y), layer['image'])
# # base_image.show()
#
# for x in self._layer:
# if np.all(x['mask'] == 0):
# continue
# # obtain region of interest about roi(roi) and item-image(roi_image, roi_mask)
# roi, roi_mask, roi_image, signal = self.get_roi(dst=dst, image=x)
# temp_bg = np.expand_dims(cv2.bitwise_not(roi_mask), axis=2).repeat(3, axis=2)
# tmp1 = (roi * (temp_bg / 255)).astype(np.uint8)
# temp_fg = np.expand_dims(roi_mask, axis=2).repeat(3, axis=2)
# tmp2 = (roi_image * (temp_fg / 255)).astype(np.uint8)
#
# roi[:] = cv2.add(tmp1, tmp2)
# # show(cv2.resize(dst, (int(dst.shape[1] * 0.5), int(dst.shape[0] * 0.5)), interpolation=cv2.INTER_AREA),
# # win_name=x.get('name'))
# # crop image and get the central part
# if cfg.get('basic')['self_template'] == False:
# dst_roi = crop(dst)
# else:
# dst_roi = dst
# return dst_roi, signal
#
# @staticmethod
# def get_roi(dst, image):
# signal = False
# dst_y, dst_x = dst.shape[:2]
# roi_height, roi_width = image['mask'].shape
# roi_y0, roi_x0 = image['position']
#
# if roi_y0 < 0:
# roi_yin = 0
# mask_yin = -roi_y0
# signal = True
# else:
# roi_yin = roi_y0
# mask_yin = 0
# if roi_y0 + roi_height > dst_y:
# roi_yout = dst_y
# mask_yout = dst_y - roi_y0
# signal = True
# else:
# roi_yout = roi_height + roi_y0
# mask_yout = roi_height
# # x part
# if roi_x0 < 0:
# roi_xin = 0
# mask_xin = -roi_x0
# signal = True
# else:
# roi_xin = roi_x0
# mask_xin = 0
# if roi_x0 + roi_width > dst_x:
# roi_xout = dst_x
# mask_xout = dst_x - roi_x0
# signal = True
# else:
# roi_xout = roi_width + roi_x0
# mask_xout = roi_width
#
# roi = dst[roi_yin: roi_yout, roi_xin: roi_xout]
# roi_mask = image['mask'][mask_yin: mask_yout, mask_xin: mask_xout]
# roi_image = image['image'][mask_yin: mask_yout, mask_xin: mask_xout]
# return roi, roi_mask, roi_image, signal

View File

@@ -0,0 +1,45 @@
class Priority(object):
"""Item layer priority levels.
"""
def __init__(self, item_list):
self._priority = dict(
earring_front=99,
bag_front=98,
hairstyle_front=97,
outwear_front=20,
bottoms_front=19,
dress_front=18,
blouse_front=17,
skirt_front=16,
trousers_front=15,
tops_front=14,
shoes_right=1,
shoes_left=1,
body=0,
tops_back=-14,
trousers_back=-15,
skirt_back=-16,
blouse_back=-17,
dress_back=-18,
bottoms_back=-19,
outwear_back=-20,
hairstyle_back=-97,
bag_back=-98,
earring_back=-99,
)
self.clothing_start_num = 10
if not isinstance(item_list, list):
raise ValueError('item_list must be a list!')
for cate in item_list:
cate = cate.lower()
if cate not in ('outwear', 'dress', 'blouse', 'skirt', 'trousers', 'tops', 'bottoms'):
raise ValueError(f'Item type error. Cannot recognize {cate}')
for i, cate in enumerate(item_list):
cate = cate.lower()
self._priority[f'{cate}_front'] = self.clothing_start_num - i
self._priority[f'{cate}_back'] = -(self.clothing_start_num - i)
@property
def priority(self):
return self._priority

View File

@@ -0,0 +1,101 @@
{
"objects": [
{
"basic": {
"body_point": {
"waistband_right": [
1081,
1318
],
"hand_point_right": [
1200,
1857
],
"waistband_left": [
639,
1315
],
"hand_point_left": [
493,
1808
],
"shoulder_left": [
556,
582
],
"shoulder_right": [
1130,
576
]
},
"layer_order": false,
"scale_bag": 0.7,
"scale_earrings": 0.16,
"self_template": true,
"single_overall": "overall",
"switch_category": ""
},
"items": [
{
"color": "151 78 78",
"icon": "none",
"image_id": 67315,
"offset": [
1,
1
],
"path": "aida-sys-image/images/female/trousers/0628000325.jpg",
"print": {
"IfSingle": false,
"print_path_list": []
},
"resize_scale": 1.0,
"type": "Trousers"
},
{
"color": "151 78 78",
"icon": "none",
"image_id": 92912,
"offset": [
1,
1
],
"path": "aida-sys-image/images/female/blouse/0825001943.jpg",
"print": {
"IfSingle": false,
"print_path_list": []
},
"resize_scale": 1.0,
"type": "Blouse"
},
{
"color": "151 78 78",
"icon": "none",
"image_id": 91430,
"offset": [
1,
1
],
"path": "aida-sys-image/images/female/outwear/0825000856.jpg",
"print": {
"IfSingle": false,
"print_path_list": []
},
"resize_scale": 1.0,
"type": "Outwear"
},
{
"body_path": "aida-users/89/models/female/5d39394e-9809-43c2-80b8-4e96497b1974.png",
"image_id": 69331,
"offset": [
1,
1
],
"resize_scale": 1.0,
"type": "Body"
}
]
}
],
"process_id": "7296013643475027"
}

View File

@@ -0,0 +1,684 @@
{
"objects": [
{
"basic": {
"body_point_test": {
"waistband_right": [
1081,
1318
],
"hand_point_right": [
1200,
1857
],
"waistband_left": [
639,
1315
],
"hand_point_left": [
493,
1808
],
"shoulder_left": [
556,
582
],
"shoulder_right": [
1130,
576
]
},
"layer_order": false,
"scale_bag": 0.7,
"scale_earrings": 0.16,
"self_template": true,
"single_overall": "overall",
"switch_category": ""
},
"items": [
{
"color": "151 78 78",
"icon": "none",
"image_id": 67315,
"offset": [
1,
1
],
"path": "aida-sys-image/images/female/trousers/0628000325.jpg",
"print": {
"IfSingle": false,
"print_path_list": []
},
"resize_scale": 1.0,
"type": "Trousers"
},
{
"color": "151 78 78",
"icon": "none",
"image_id": 92912,
"offset": [
1,
1
],
"path": "aida-sys-image/images/female/blouse/0825001943.jpg",
"print": {
"IfSingle": false,
"print_path_list": []
},
"resize_scale": 1.0,
"type": "Blouse"
},
{
"color": "151 78 78",
"icon": "none",
"image_id": 91430,
"offset": [
1,
1
],
"path": "aida-sys-image/images/female/outwear/0825000856.jpg",
"print": {
"IfSingle": false,
"print_path_list": []
},
"resize_scale": 1.0,
"type": "Outwear"
},
{
"body_path": "aida-users/89/models/female/5d39394e-9809-43c2-80b8-4e96497b1974.png",
"image_id": 69331,
"offset": [
1,
1
],
"resize_scale": 1.0,
"type": "Body"
}
]
}
,
{
"basic": {
"body_point_test": {
"waistband_right": [
1081,
1318
],
"hand_point_right": [
1200,
1857
],
"waistband_left": [
639,
1315
],
"hand_point_left": [
493,
1808
],
"shoulder_left": [
556,
582
],
"shoulder_right": [
1130,
576
]
},
"layer_order": false,
"scale_bag": 0.7,
"scale_earrings": 0.16,
"self_template": true,
"single_overall": "overall",
"switch_category": ""
},
"items": [
{
"color": "151 78 78",
"icon": "none",
"image_id": 92913,
"offset": [
1,
1
],
"path": "aida-sys-image/images/female/dress/826000033.jpg",
"print": {
"IfSingle": false,
"print_path_list": []
},
"resize_scale": 1.0,
"type": "Dress"
},
{
"body_path": "aida-users/89/models/female/5d39394e-9809-43c2-80b8-4e96497b1974.png",
"image_id": 69331,
"offset": [
1,
1
],
"resize_scale": 1.0,
"type": "Body"
}
]
}
,
{
"basic": {
"body_point_test": {
"waistband_right": [
1081,
1318
],
"hand_point_right": [
1200,
1857
],
"waistband_left": [
639,
1315
],
"hand_point_left": [
493,
1808
],
"shoulder_left": [
556,
582
],
"shoulder_right": [
1130,
576
]
},
"layer_order": false,
"scale_bag": 0.7,
"scale_earrings": 0.16,
"self_template": true,
"single_overall": "overall",
"switch_category": ""
},
"items": [
{
"color": "151 78 78",
"icon": "none",
"image_id": 92914,
"offset": [
1,
1
],
"path": "aida-sys-image/images/female/skirt/0902001788.jpg",
"print": {
"IfSingle": false,
"print_path_list": []
},
"resize_scale": 1.0,
"type": "Skirt"
},
{
"color": "151 78 78",
"icon": "none",
"image_id": 92915,
"offset": [
1,
1
],
"path": "aida-sys-image/images/female/blouse/0902003817.jpg",
"print": {
"IfSingle": false,
"print_path_list": []
},
"resize_scale": 1.0,
"type": "Blouse"
},
{
"body_path": "aida-users/89/models/female/5d39394e-9809-43c2-80b8-4e96497b1974.png",
"image_id": 69331,
"offset": [
1,
1
],
"resize_scale": 1.0,
"type": "Body"
}
]
}
,
{
"basic": {
"body_point_test": {
"waistband_right": [
1081,
1318
],
"hand_point_right": [
1200,
1857
],
"waistband_left": [
639,
1315
],
"hand_point_left": [
493,
1808
],
"shoulder_left": [
556,
582
],
"shoulder_right": [
1130,
576
]
},
"layer_order": false,
"scale_bag": 0.7,
"scale_earrings": 0.16,
"self_template": true,
"single_overall": "overall",
"switch_category": ""
},
"items": [
{
"color": "151 78 78",
"icon": "none",
"image_id": 92916,
"offset": [
1,
1
],
"path": "aida-sys-image/images/female/skirt/skirt_p4_838.jpg",
"print": {
"IfSingle": false,
"print_path_list": []
},
"resize_scale": 1.0,
"type": "Skirt"
},
{
"color": "151 78 78",
"icon": "none",
"image_id": 84210,
"offset": [
1,
1
],
"path": "aida-sys-image/images/female/blouse/0916000703.jpg",
"print": {
"IfSingle": false,
"print_path_list": []
},
"resize_scale": 1.0,
"type": "Blouse"
},
{
"body_path": "aida-users/89/models/female/5d39394e-9809-43c2-80b8-4e96497b1974.png",
"image_id": 69331,
"offset": [
1,
1
],
"resize_scale": 1.0,
"type": "Body"
}
]
}
,
{
"basic": {
"body_point_test": {
"waistband_right": [
1081,
1318
],
"hand_point_right": [
1200,
1857
],
"waistband_left": [
639,
1315
],
"hand_point_left": [
493,
1808
],
"shoulder_left": [
556,
582
],
"shoulder_right": [
1130,
576
]
},
"layer_order": false,
"scale_bag": 0.7,
"scale_earrings": 0.16,
"self_template": true,
"single_overall": "overall",
"switch_category": ""
},
"items": [
{
"color": "151 78 78",
"icon": "none",
"image_id": 62041,
"offset": [
1,
1
],
"path": "aida-sys-image/images/female/outwear/0902000232.jpg",
"print": {
"IfSingle": false,
"print_path_list": []
},
"resize_scale": 1.0,
"type": "Outwear"
},
{
"color": "151 78 78",
"icon": "none",
"image_id": 67039,
"offset": [
1,
1
],
"path": "aida-sys-image/images/female/blouse/0902002591.jpg",
"print": {
"IfSingle": false,
"print_path_list": []
},
"resize_scale": 1.0,
"type": "Blouse"
},
{
"color": "151 78 78",
"icon": "none",
"image_id": 78016,
"offset": [
1,
1
],
"path": "aida-sys-image/images/female/trousers/trousers_p4_302.jpg",
"print": {
"IfSingle": false,
"print_path_list": []
},
"resize_scale": 1.0,
"type": "Trousers"
},
{
"body_path": "aida-users/89/models/female/5d39394e-9809-43c2-80b8-4e96497b1974.png",
"image_id": 69331,
"offset": [
1,
1
],
"resize_scale": 1.0,
"type": "Body"
}
]
}
,
{
"basic": {
"body_point_test": {
"waistband_right": [
1081,
1318
],
"hand_point_right": [
1200,
1857
],
"waistband_left": [
639,
1315
],
"hand_point_left": [
493,
1808
],
"shoulder_left": [
556,
582
],
"shoulder_right": [
1130,
576
]
},
"layer_order": false,
"scale_bag": 0.7,
"scale_earrings": 0.16,
"self_template": true,
"single_overall": "overall",
"switch_category": ""
},
"items": [
{
"color": "151 78 78",
"icon": "none",
"image_id": 92917,
"offset": [
1,
1
],
"path": "aida-sys-image/images/female/trousers/0902001403.jpg",
"print": {
"IfSingle": false,
"print_path_list": []
},
"resize_scale": 1.0,
"type": "Trousers"
},
{
"color": "151 78 78",
"icon": "none",
"image_id": 92306,
"offset": [
1,
1
],
"path": "aida-sys-image/images/female/blouse/0902001766.jpg",
"print": {
"IfSingle": false,
"print_path_list": []
},
"resize_scale": 1.0,
"type": "Blouse"
},
{
"body_path": "aida-users/89/models/female/5d39394e-9809-43c2-80b8-4e96497b1974.png",
"image_id": 69331,
"offset": [
1,
1
],
"resize_scale": 1.0,
"type": "Body"
}
]
}
,
{
"basic": {
"body_point_test": {
"waistband_right": [
1081,
1318
],
"hand_point_right": [
1200,
1857
],
"waistband_left": [
639,
1315
],
"hand_point_left": [
493,
1808
],
"shoulder_left": [
556,
582
],
"shoulder_right": [
1130,
576
]
},
"layer_order": false,
"scale_bag": 0.7,
"scale_earrings": 0.16,
"self_template": true,
"single_overall": "overall",
"switch_category": ""
},
"items": [
{
"color": "151 78 78",
"icon": "none",
"image_id": 86564,
"offset": [
1,
1
],
"path": "aida-sys-image/images/female/blouse/0916000038.jpg",
"print": {
"IfSingle": false,
"print_path_list": []
},
"resize_scale": 1.0,
"type": "Blouse"
},
{
"color": "151 78 78",
"icon": "none",
"image_id": 92918,
"offset": [
1,
1
],
"path": "aida-sys-image/images/female/trousers/0628001561.jpeg",
"print": {
"IfSingle": false,
"print_path_list": []
},
"resize_scale": 1.0,
"type": "Trousers"
},
{
"color": "151 78 78",
"icon": "none",
"image_id": 92919,
"offset": [
1,
1
],
"path": "aida-sys-image/images/female/outwear/outwear_p3186.jpg",
"print": {
"IfSingle": false,
"print_path_list": []
},
"resize_scale": 1.0,
"type": "Outwear"
},
{
"body_path": "aida-users/89/models/female/5d39394e-9809-43c2-80b8-4e96497b1974.png",
"image_id": 69331,
"offset": [
1,
1
],
"resize_scale": 1.0,
"type": "Body"
}
]
}
,
{
"basic": {
"body_point_test": {
"waistband_right": [
1081,
1318
],
"hand_point_right": [
1200,
1857
],
"waistband_left": [
639,
1315
],
"hand_point_left": [
493,
1808
],
"shoulder_left": [
556,
582
],
"shoulder_right": [
1130,
576
]
},
"layer_order": false,
"scale_bag": 0.7,
"scale_earrings": 0.16,
"self_template": true,
"single_overall": "overall",
"switch_category": ""
},
"items": [
{
"color": "151 78 78",
"icon": "none",
"image_id": 67009,
"offset": [
1,
1
],
"path": "aida-sys-image/images/female/blouse/0902002051.jpg",
"print": {
"IfSingle": false,
"print_path_list": []
},
"resize_scale": 1.0,
"type": "Blouse"
},
{
"color": "151 78 78",
"icon": "none",
"image_id": 85028,
"offset": [
1,
1
],
"path": "aida-sys-image/images/female/skirt/903000142.jpg",
"print": {
"IfSingle": false,
"print_path_list": []
},
"resize_scale": 1.0,
"type": "Skirt"
},
{
"body_path": "aida-users/89/models/female/5d39394e-9809-43c2-80b8-4e96497b1974.png",
"image_id": 69331,
"offset": [
1,
1
],
"resize_scale": 1.0,
"type": "Body"
}
]
}
],
"process_id": "7296013643475027"
}

View File

@@ -0,0 +1,69 @@
{
"basic": {
"body_point": {
"waistband_right": [
1081,
1318
],
"hand_point_right": [
1200,
1857
],
"waistband_left": [
639,
1315
],
"hand_point_left": [
493,
1808
],
"shoulder_left": [
556,
582
],
"shoulder_right": [
1130,
576
]
},
"layer_order": false,
"scale_bag": 0.7,
"scale_earrings": 0.16,
"self_template": true,
"single_overall": "single",
"switch_category": "Trousers",
"body_path": "aida-users/89/models/female/5d39394e-9809-43c2-80b8-4e96497b1974.png"
},
"item": [
{
"color": "151 78 78",
"image_id": "67315",
"offset": [
1,
1
],
"path": "aida-sys-image/images/female/trousers/0628000325.jpg",
"print": {
"if_single": false,
"print_path_list": []
},
"resize_scale": 1.0,
"type": "Trousers"
},
{
"color": "151 78 78",
"path": "aida-users/89/models/female/5d39394e-9809-43c2-80b8-4e96497b1974.png",
"image_id": 69331,
"offset": [
1,
1
],
"print": {
"if_single": false,
"print_path_list": []
},
"resize_scale": 1.0,
"type": "Body"
}
]
}

View File

@@ -0,0 +1,16 @@
from .builder import ITEMS, build_item
from .clothing import Clothing # 4.0 sec
from .body import Body
from .top import Top, Blouse, Outwear, Dress
from .bottom import Bottom, Trousers, Skirt
from .shoes import Shoes
from .bag import Bag
from .accessories import Hairstyle, Earring
__all__ = [
'ITEMS', 'build_item',
'Clothing', 'Body',
'Top', 'Blouse', 'Outwear', 'Dress',
'Bottom', 'Trousers', 'Skirt',
'Shoes', 'Bag', 'Hairstyle', 'Earring'
]

View File

@@ -0,0 +1,59 @@
from .builder import ITEMS
from .clothing import Clothing
@ITEMS.register_module()
class Hairstyle(Clothing):
def __init__(self, **kwargs):
pipeline = [
dict(type='LoadImageFromFile', path=kwargs['path']),
dict(type='KeypointDetection'),
dict(type='ContourDetection'),
dict(type='Painting'),
dict(type='Scaling'),
dict(type='Split'),
# dict(type='ImageShow', key=['image', 'mask', 'pattern_image']),
]
kwargs.update(pipeline=pipeline)
super(Hairstyle, self).__init__(**kwargs)
@staticmethod
def calculate_start_point(keypoint_type, scale, clothes_point, body_point):
"""
align up
Args:
keypoint_type: string, "head_point"
scale: float
clothes_point: dict{'left': [x1, y1, z1], 'right': [x2, y2, z2]}
body_point: dict, containing keypoint data of body figure
Returns:
start_point: tuple (x', y')
x' = y_body - y1 * scale
y' = x_body - x1 * scale
"""
side_indicator = f'{keypoint_type}_up'
# clothes_point = {k: tuple(map(lambda x: int(scale * x), v[0: 2])) for k, v in clothes_point.items()}
# logging.info(clothes_point[side_indicator])
start_point = (
int(body_point[side_indicator][1] - int(clothes_point[side_indicator].split("_")[1] * scale)),
int(body_point[side_indicator][0] - int(clothes_point[side_indicator].split("_")[0] * scale))
)
return start_point
@ITEMS.register_module()
class Earring(Clothing):
def __init__(self, **kwargs):
pipeline = [
dict(type='LoadImageFromFile', path=kwargs['path']),
dict(type='KeypointDetection'),
dict(type='ContourDetection'),
dict(type='Painting'),
dict(type='Scaling'),
dict(type='Split'),
# dict(type='ImageShow', key=['image', 'mask', 'pattern_image']),
]
kwargs.update(pipeline=pipeline)
super(Earring, self).__init__(**kwargs)

View File

@@ -0,0 +1,44 @@
from .builder import ITEMS
from .clothing import Clothing
import random
@ITEMS.register_module()
class Bag(Clothing):
def __init__(self, **kwargs):
pipeline = [
dict(type='LoadImageFromFile', path=kwargs['path'], color=kwargs['color']),
dict(type='KeypointDetection'),
dict(type='ContourDetection'),
dict(type='Painting'),
dict(type='Scaling'),
dict(type='Split'),
# dict(type='ImageShow', key=['image', 'mask', 'pattern_image']),
]
kwargs.update(pipeline=pipeline)
super(Bag, self).__init__(**kwargs)
@staticmethod
def calculate_start_point(keypoint_type, scale, clothes_point, body_point):
"""
align left
Args:
keypoint_type: string, "hand_point"
scale: float
clothes_point: dict{'left': [x1, y1, z1], 'right': [x2, y2, z2]}
body_point: dict, containing keypoint data of body figure
Returns:
start_point: tuple (y', x')
x' = y_body - y1 * scale
y' = x_body - x1 * scale
"""
location = random.choice(seq=['left', 'right'])
if location == 'left':
side_indicator = f'{keypoint_type}_left'
else:
side_indicator = f'{keypoint_type}_right'
# clothes_point = {k: tuple(map(lambda x: int(scale * x), v[0: 2])) for k, v in clothes_point.items()}
start_point = (body_point[side_indicator][1] - int(int(clothes_point[keypoint_type].split("_")[1]) * scale),
body_point[side_indicator][0] - int(int(clothes_point[keypoint_type].split("_")[0]) * scale))
return start_point

View File

@@ -0,0 +1,35 @@
import cv2
from .builder import ITEMS
from .pipelines import Compose
@ITEMS.register_module()
class Body(object):
def __init__(self, **kwargs):
pipeline = [
dict(type='LoadBodyImageFromFile', body_path=kwargs['body_path']),
# dict(type='ImageShow', key=['body_image', "body_mask"])
]
self.pipeline = Compose(pipeline)
self.result = dict()
def process(self):
self.pipeline(self.result)
pass
def organize(self, layer):
body_layer = dict(priority=0,
name=type(self).__name__.lower(),
image=self.result['body_image'],
image_url=self.result['image_url'],
mask_image=None,
mask_url=None,
sacle=1,
# mask=self.result['body_mask'],
position=(0, 0))
layer.insert(body_layer)
@staticmethod
def show(img):
cv2.imshow('', img)
cv2.waitKey(0)

View File

@@ -0,0 +1,38 @@
from .builder import ITEMS
from .clothing import Clothing
@ITEMS.register_module()
class Bottom(Clothing):
def __init__(self, pipeline, **kwargs):
if pipeline is None:
pipeline = [
dict(type='LoadImageFromFile', path=kwargs['path'], color=kwargs['color'], print_dict=kwargs['print']),
dict(type='KeypointDetection'),
dict(type='ContourDetection'),
dict(type='Painting', painting_flag=True),
dict(type='PrintPainting', print_flag=True),
dict(type='Scaling'),
dict(type='Split'),
# dict(type='ImageShow', key=['image', 'mask', 'pattern_image', 'print_image']),
]
kwargs.update(pipeline=pipeline)
super(Bottom, self).__init__(**kwargs)
@ITEMS.register_module()
class Trousers(Bottom):
def __init__(self, pipeline=None, **kwargs):
super(Trousers, self).__init__(pipeline, **kwargs)
@ITEMS.register_module()
class Skirt(Bottom):
def __init__(self, pipeline=None, **kwargs):
super(Skirt, self).__init__(pipeline, **kwargs)
@ITEMS.register_module()
class Bottoms(Bottom):
def __init__(self, pipeline=None, **kwargs):
super(Bottoms, self).__init__(pipeline, **kwargs)

View File

@@ -0,0 +1,9 @@
from mmcv.utils import Registry, build_from_cfg
ITEMS = Registry('item')
PIPELINES = Registry('pipeline')
def build_item(cfg, default_args=None):
item = build_from_cfg(cfg, ITEMS, default_args)
return item

View File

@@ -0,0 +1,96 @@
import cv2
from app.core.config import PRIORITY_DICT
from .builder import ITEMS
from .pipelines import Compose
@ITEMS.register_module()
class Clothing(object):
def __init__(self, pipeline, **kwargs):
self.pipeline = Compose(pipeline)
self.result = dict(name=type(self).__name__.lower(), **kwargs)
def process(self):
self.pipeline(self.result)
def apply_scale(self, img):
scale = self.result['scale']
height, width = img.shape[0: 2]
if len(img.shape) > 2:
height, width = img.shape[0: 2]
scaled_img = cv2.resize(img, (int(width * scale), int(height * scale)), interpolation=cv2.INTER_AREA)
return scaled_img
def organize(self, layer):
start_point = self.calculate_start_point(self.result['keypoint'], self.result['scale'], self.result['clothes_keypoint'], self.result['body_point_test'], self.result["offset"], self.result["resize_scale"])
front_layer = dict(priority=self.result.get("priority", None) if self.result.get("layer_order", False) else PRIORITY_DICT.get(f'{type(self).__name__.lower()}_front', None),
name=f'{type(self).__name__.lower()}_front',
image=self.result["front_image"],
# mask_image=self.result['front_mask_image'],
image_url=self.result['front_image_url'],
mask_url=self.result['front_mask_url'],
sacle=self.result['scale'],
clothes_keypoint=self.result['clothes_keypoint'],
position=start_point,
resize_scale=self.result["resize_scale"],
mask=cv2.resize(self.result['mask'], self.result["front_image"].size),
gradient_string=self.result['gradient_string'] if 'gradient_string' in self.result.keys() else ""
)
layer.insert(front_layer)
back_layer = dict(priority=-self.result.get("priority", 0) if self.result.get("layer_order", False) else PRIORITY_DICT.get(f'{type(self).__name__.lower()}_back', None),
name=f'{type(self).__name__.lower()}_back',
image=self.result["back_image"],
# mask_image=self.result['back_mask_image'],
image_url=self.result['back_image_url'],
mask_url=self.result['back_mask_url'],
sacle=self.result['scale'],
clothes_keypoint=self.result['clothes_keypoint'],
position=start_point,
resize_scale=self.result["resize_scale"],
mask=cv2.resize(self.result['mask'], self.result["front_image"].size),
gradient_string=self.result['gradient_string'] if 'gradient_string' in self.result.keys() else ""
)
layer.insert(back_layer)
@staticmethod
def calculate_start_point(keypoint_type, scale, clothes_point, body_point, offset, resize_scale):
"""
Align left
Args:
keypoint_type: string, "waistband" | "shoulder" | "ear_point"
scale: float
clothes_point: dict{'left': [x1, y1, z1], 'right': [x2, y2, z2]}
body_point: dict, containing keypoint data of body figure
Returns:
start_point: tuple (x', y')
x' = y_body - y1 * scale + offset
y' = x_body - x1 * scale + offset
"""
side_indicator = f'{keypoint_type}_left'
# if keypoint_type == "ear_point":
# start_point = (body_point[side_indicator][1] - int(int(clothes_point[side_indicator].split("_")[1]) * scale),
# body_point[side_indicator][0] - int(int(clothes_point[side_indicator].split("_")[0]) * scale))
# else:
# start_point = (
# int(body_point[side_indicator][1] + offset[1] - int(clothes_point[side_indicator].split("_")[0]) * scale), # y
# int(body_point[side_indicator][0] + offset[0] - int(clothes_point[side_indicator].split("_")[1]) * scale) # x
# )
# milvus_DB_keypoint_cache:
start_point = (
int(body_point[side_indicator][1] + offset[1] - int(clothes_point[side_indicator][0]) * scale), # y
int(body_point[side_indicator][0] + offset[0] - int(clothes_point[side_indicator][1]) * scale) # x
)
# start_point = (
# int(body_point[side_indicator][1] + offset[1] - int(clothes_point[side_indicator].split("_")[0]) * scale), # y
# int(body_point[side_indicator][0] + offset[0] - int(clothes_point[side_indicator].split("_")[1]) * scale) # x
# )
return start_point

View File

@@ -0,0 +1,19 @@
from .compose import Compose
from .loading import LoadImageFromFile, LoadBodyImageFromFile, ImageShow
from .keypoints import KeypointDetection
from .segmentation import Segmentation
from .painting import Painting, PrintPainting
from .scale import Scaling
from .contour_detection import ContourDetection
from .split import Split
__all__ = [
'Compose',
'LoadImageFromFile', 'LoadBodyImageFromFile', 'ImageShow',
'KeypointDetection',
'Segmentation',
'Painting', 'PrintPainting',
'Scaling',
'ContourDetection',
'split',
]

View File

@@ -0,0 +1,36 @@
import collections
from mmcv.utils import build_from_cfg
from ..builder import PIPELINES
@PIPELINES.register_module()
class Compose(object):
def __init__(self, transforms):
assert isinstance(transforms, collections.abc.Sequence)
self.transforms = []
for transform in transforms:
if isinstance(transform, dict):
transform = build_from_cfg(transform, PIPELINES)
self.transforms.append(transform)
elif callable(transform):
self.transforms.append(transform)
else:
raise TypeError('transform must be callable or a dict')
def __call__(self, data):
"""Call function to apply transforms sequentially.
Args:
data (dict): A result dict contains the data to transform.
Returns:
dict: Transformed data.
"""
for t in self.transforms:
data = t(data)
if data is None:
return None
return data

View File

@@ -0,0 +1,59 @@
import logging
from ..builder import PIPELINES
import cv2
import numpy as np
@PIPELINES.register_module()
class ContourDetection(object):
def __init__(self):
# logging.info("ContourDetection run ")
pass
#@ RunTime
def __call__(self, result):
# shoe diff
if result['name'] == 'shoes':
Contour = self.get_contours(result['image'])
Mask = np.zeros(result['image'].shape[:2], np.uint8)
for i in range(2):
Max_contour = Contour[i]
Epsilon = 0.001 * cv2.arcLength(Max_contour, True)
Approx = cv2.approxPolyDP(Max_contour, Epsilon, True)
cv2.drawContours(Mask, [Approx], -1, 255, -1)
if result['pre_mask'] is None:
result['mask'] = Mask
else:
result['mask'] = cv2.bitwise_and(Mask, result['pre_mask'])
else:
Contour = self.get_contours(result['image'])
Mask = np.zeros(result['image'].shape[:2], np.uint8)
if len(Contour):
Max_contour = Contour[0]
Epsilon = 0.001 * cv2.arcLength(Max_contour, True)
Approx = cv2.approxPolyDP(Max_contour, Epsilon, True)
cv2.drawContours(Mask, [Approx], -1, 255, -1)
else:
Mask = np.ones(result['image'].shape[:2], np.uint8) * 255
# TODO 修复部分图片出现透明的情况 下版本上线
# img2gray = cv2.cvtColor(result['image'], cv2.COLOR_BGR2GRAY)
# ret, Mask = cv2.threshold(img2gray, 126, 255, cv2.THRESH_BINARY)
# Mask = cv2.bitwise_not(Mask)
if result['pre_mask'] is None:
result['mask'] = Mask
else:
result['mask'] = cv2.bitwise_and(Mask, result['pre_mask'])
return result
@staticmethod
def get_contours(image):
gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
Edge = cv2.Canny(gray, 10, 150)
kernel = np.ones((5, 5), np.uint8)
Edge = cv2.dilate(Edge, kernel=kernel, iterations=1)
Edge = cv2.erode(Edge, kernel=kernel, iterations=1)
Contour, _ = cv2.findContours(Edge, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
Contour = sorted(Contour, key=cv2.contourArea, reverse=True)
return Contour

View File

@@ -0,0 +1,148 @@
import logging
import time
import numpy as np
from pymilvus import MilvusClient
from app.core.config import *
from ..builder import PIPELINES
from ...utils.design_ensemble import get_keypoint_result
@PIPELINES.register_module()
class KeypointDetection(object):
"""
path here: abstract path
"""
def __init__(self):
self.client = MilvusClient(
uri="http://10.1.1.240:19530",
token="root:Milvus",
db_name=MILVUS_ALIAS
)
def __del__(self):
# start_time = time.time()
self.client.close()
# print(f"client close time : {time.time() - start_time}")
# @ RunTime
def __call__(self, result):
# logging.info("KeypointDetection run ")
if result['name'] in ['blouse', 'skirt', 'dress', 'outwear', 'trousers', 'tops', 'bottoms']: # 查询是否有数据 且类别相同 相同则直接读 不同则推理后更新
# result['clothes_keypoint'] = self.infer_keypoint_result(result)
site = 'up' if result['name'] in ['blouse', 'outwear', 'dress', 'tops'] else 'down'
# keypoint_cache = search_keypoint_cache(result["image_id"], site)
keypoint_cache = self.keypoint_cache(result, site)
# 取消向量查询 直接过模型推理
# keypoint_cache = False
if keypoint_cache is False:
keypoint_infer_result, site = self.infer_keypoint_result(result)
result['clothes_keypoint'] = self.save_keypoint_cache(result["image_id"], keypoint_infer_result, site)
else:
result['clothes_keypoint'] = keypoint_cache
return result
@staticmethod
def infer_keypoint_result(result):
site = 'up' if result['name'] in ['blouse', 'outwear', 'dress', 'tops'] else 'down'
start_time = time.time()
keypoint_infer_result = get_keypoint_result(result["image"], site) # 推理结果
# logging.info(f"infer keypoint time : {time.time() - start_time}")
return keypoint_infer_result, site
@staticmethod
# @ RunTime
def save_keypoint_cache(keypoint_id, cache, site, KEYPOINT_RESULT_TABLE_FIELD_SET=None):
if site == "down":
zeros = np.zeros(20, dtype=int)
result = np.concatenate([zeros, cache.flatten()])
else:
zeros = np.zeros(4, dtype=int)
result = np.concatenate([cache.flatten(), zeros])
# 取消向量保存 直接拿结果
data = [
{"keypoint_id": keypoint_id,
"keypoint_site": site,
"keypoint_vector": result.tolist()
}
]
client = MilvusClient(
uri="http://10.1.1.240:19530",
token="root:Milvus",
db_name=MILVUS_ALIAS
)
try:
start_time = time.time()
res = client.upsert(
collection_name=MILVUS_TABLE_KEYPOINT,
data=data,
)
# logging.info(f"save keypoint time : {time.time() - start_time}")
return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
except Exception as e:
logging.info(f"save keypoint cache milvus error : {e}")
return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
finally:
client.close()
@staticmethod
def update_keypoint_cache(keypoint_id, infer_result, search_result, site):
if site == "up":
# 需要的是up 即推理出来的是up 那么查询的就是down
result = np.concatenate([infer_result.flatten(), search_result[-4:]])
else:
# 需要的是down 即推理出来的是down 那么查询的就是up
result = np.concatenate([search_result[:20], infer_result.flatten()])
data = [
{"keypoint_id": keypoint_id,
"keypoint_site": "all",
"keypoint_vector": result.tolist()
}
]
client = MilvusClient(
uri="http://10.1.1.240:19530",
token="root:Milvus",
db_name=MILVUS_ALIAS
)
try:
# connections.connect(alias=MILVUS_ALIAS, host=MILVUS_DB_HOST, port=MILVUS_PORT)
start_time = time.time()
# collection = Collection(MILVUS_TABLE_KEYPOINT) # Get an existing collection.
# mr = collection.upsert(data)
client.upsert(
collection_name=MILVUS_TABLE_KEYPOINT,
data=data
)
# logging.info(f"save keypoint time : {time.time() - start_time}")
return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
except Exception as e:
logging.info(f"save keypoint cache milvus error : {e}")
return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
# @ RunTime
def keypoint_cache(self, result, site):
try:
keypoint_id = result['image_id']
res = self.client.query(
collection_name=MILVUS_TABLE_KEYPOINT,
# ids=[keypoint_id],
filter=f"keypoint_id == {keypoint_id}",
output_fields=['keypoint_vector', 'keypoint_site']
)
if len(res) == 0:
# 没有结果 直接推理拿结果 并保存
keypoint_infer_result, site = self.infer_keypoint_result(result)
return self.save_keypoint_cache(result['image_id'], keypoint_infer_result, site)
elif res[0]["keypoint_site"] == "all" or res[0]["keypoint_site"] == site:
# 需要的类型和查询的类型一致或者查询的类型为all 则直接返回查询的结果
return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, np.array(res[0]['keypoint_vector']).astype(int).reshape(12, 2).tolist()))
elif res[0]["keypoint_site"] != site:
# 需要的类型和查询到的不一致则更新类型为all
keypoint_infer_result, site = self.infer_keypoint_result(result)
return self.update_keypoint_cache(result["image_id"], keypoint_infer_result, res[0]['keypoint_vector'], site)
except Exception as e:
logging.info(f"search keypoint cache milvus error {e}")
return False

View File

@@ -0,0 +1,143 @@
import io
import logging
import time
import cv2
import numpy as np
from PIL import Image
from minio import Minio
from app.core.config import *
from ..builder import PIPELINES
@PIPELINES.register_module()
class LoadImageFromFile(object):
def __init__(self, path, color=None, print_dict=None):
self.path = path
self.color = color
self.print_dict = print_dict
self.minio_client = Minio(
f"{MINIO_URL}",
access_key=MINIO_ACCESS,
secret_key=MINIO_SECRET,
secure=MINIO_SECURE)
def __call__(self, result):
result['image'], result['pre_mask'] = self.read_image(self.path)
result['gray'] = cv2.cvtColor(result['image'], cv2.COLOR_BGR2GRAY)
result['keypoint'] = self.get_keypoint(result['name'])
result['path'] = self.path
result['img_shape'] = result['image'].shape
result['ori_shape'] = result['image'].shape
result['color'] = self.color if self.color is not None else None
result['print_dict'] = self.print_dict
return result
@staticmethod
def get_keypoint(name):
if name == 'blouse' or name == 'outwear' or name == 'dress' or name == 'tops':
keypoint = 'shoulder'
elif name == 'trousers' or name == 'skirt' or name == 'bottoms':
keypoint = 'waistband'
elif name == 'bag':
keypoint = 'hand_point'
elif name == 'shoes':
keypoint = 'toe'
elif name == 'hairstyle':
keypoint = 'head_point'
elif name == 'earring':
keypoint = 'ear_point'
else:
raise KeyError(f"{name} does not belong to item category list: blouse, outwear, dress, trousers, skirt, "
f"bag, shoes, hairstyle, earring.")
return keypoint
def read_image(self, image_path):
image_mask = None
file = self.minio_client.get_object(image_path.split("/", 1)[0], image_path.split("/", 1)[1]).data
image = cv2.imdecode(np.frombuffer(file, np.uint8), 1)
if len(image.shape) == 2:
image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
if image.shape[2] == 4: # 如果是四通道 mask
image_mask = image[:, :, 3]
image = image[:, :, :3]
return image, image_mask
@PIPELINES.register_module()
class LoadBodyImageFromFile(object):
def __init__(self, body_path):
self.body_path = body_path
self.minioClient = Minio(
f"{MINIO_URL}",
access_key=MINIO_ACCESS,
secret_key=MINIO_SECRET,
secure=MINIO_SECURE)
# response = self.minioClient.get_object("aida-mannequins", "model_1693218345.2714431.png")
# @ RunTime
def __call__(self, result):
result["image_url"] = result['body_path'] = self.body_path
result["name"] = "mannequin"
if not result['image_url'].lower().endswith(".png"):
logging.info(1)
bucket = self.body_path.split("/", 1)[0]
object_name = self.body_path.split("/", 1)[1]
new_object_name = f'{object_name[:object_name.rfind(".")]}.png'
image = self.minioClient.get_object(bucket, object_name)
image = Image.open(io.BytesIO(image.data))
image = image.convert("RGBA")
data = image.getdata()
#
new_data = []
for item in data:
if item[0] >= 230 and item[1] >= 230 and item[2] >= 230:
new_data.append((255, 255, 255, 0))
else:
new_data.append(item)
image.putdata(new_data)
image_data = io.BytesIO()
image.save(image_data, format='PNG')
image_data.seek(0)
image_bytes = image_data.read()
image_path = f"{bucket}/{self.minioClient.put_object(bucket, new_object_name, io.BytesIO(image_bytes), len(image_bytes), content_type='image/png').object_name}"
self.body_path = image_path
result["image_url"] = result['body_path'] = self.body_path
response = self.minioClient.get_object(self.body_path.split("/", 1)[0], self.body_path.split("/", 1)[1])
# put_image_time = time.time()
result['body_image'] = Image.open(io.BytesIO(response.read()))
# logging.info(f"Image.open time is : {time.time() - put_image_time}")
return result
@PIPELINES.register_module()
class ImageShow(object):
def __init__(self, key):
self.key = key
# @ RunTime
def __call__(self, result):
import matplotlib.pyplot as plt
if isinstance(self.key, list):
for key in self.key:
plt.imshow(result[key])
plt.title(key)
plt.show()
elif isinstance(self.key, str):
img = self._resize_img(result[self.key])
cv2.imshow(self.key, img)
cv2.waitKey(0)
else:
raise TypeError(f'key should be string but got type {type(self.key)}.')
return result
@staticmethod
def _resize_img(img):
shape = img.shape
if shape[0] > 400 or shape[1] > 400:
ratio = min(400 / shape[0], 400 / shape[1])
img = cv2.resize(img, (int(ratio * shape[1]), int(ratio * shape[0])))
return img

View File

@@ -0,0 +1,602 @@
import random
from io import BytesIO
import boto3
import cv2
import numpy as np
from PIL import Image
from minio import Minio
from app.core.config import *
from ..builder import PIPELINES
minio_client = Minio(
MINIO_URL,
access_key=MINIO_ACCESS,
secret_key=MINIO_SECRET,
secure=MINIO_SECURE)
s3 = boto3.client('s3', aws_access_key_id=S3_ACCESS_KEY, aws_secret_access_key=S3_AWS_SECRET_ACCESS_KEY, region_name=S3_REGION_NAME)
@PIPELINES.register_module()
class Painting(object):
def __init__(self, painting_flag=True):
self.painting_flag = painting_flag
# @ RunTime
def __call__(self, result):
if result['name'] not in ['hairstyle', 'earring'] and self.painting_flag and result['color'] != 'none':
dim_image_h, dim_image_w = result['image'].shape[0:2]
if "gradient" in result.keys() and result['gradient'] != "":
bucket_name = result['gradient'].split('/')[0]
object_name = result['gradient'][result['gradient'].find('/') + 1:]
pattern = self.get_gradient(bucket_name=bucket_name, object_name=object_name)
resize_pattern = cv2.resize(pattern, (dim_image_w, dim_image_h), interpolation=cv2.INTER_AREA)
else:
pattern = self.get_pattern(result['color'])
resize_pattern = cv2.resize(pattern, (dim_image_w, dim_image_h), interpolation=cv2.INTER_AREA)
closed_mo = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2)
gray_mo = np.expand_dims(result['gray'], axis=2).repeat(3, axis=2)
get_image_fir = resize_pattern * (closed_mo / 255) * (gray_mo / 255)
result['pattern_image'] = get_image_fir.astype(np.uint8)
result['final_image'] = result['pattern_image']
canvas = np.full_like(result['final_image'], 255)
temp_bg = np.expand_dims(cv2.bitwise_not(result['mask']), axis=2).repeat(3, axis=2)
tmp1 = (canvas * (temp_bg / 255)).astype(np.uint8)
temp_fg = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2)
tmp2 = (result['final_image'] * (temp_fg / 255)).astype(np.uint8)
result['single_image'] = cv2.add(tmp1, tmp2)
result['alpha'] = 100 / 255.0
else:
closed_mo = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2)
get_image_fir = result['image'] * (closed_mo / 255)
result['pattern_image'] = get_image_fir.astype(np.uint8)
result['final_image'] = result['pattern_image']
return result
@staticmethod
def get_gradient(bucket_name, object_name):
image_data = minio_client.get_object(bucket_name, object_name)
# image_data = s3.get_object(Bucket=bucket_name, Key=object_name)['Body']
# 从数据流中读取图像
image_bytes = image_data.read()
# 将图像数据转换为numpy数组
image_array = np.asarray(bytearray(image_bytes), dtype=np.uint8)
# 使用OpenCV解码图像数组
image = cv2.imdecode(image_array, cv2.IMREAD_COLOR)
return image
@staticmethod
def crop_image(image, image_size_h, image_size_w):
x_offset = np.random.randint(low=0, high=int(image_size_h / 5) - 6)
y_offset = np.random.randint(low=0, high=int(image_size_w / 5) - 6)
image = image[x_offset: x_offset + image_size_h, y_offset: y_offset + image_size_w, :]
return image
@staticmethod
def get_pattern(single_color):
if single_color is None:
raise False
R, G, B = single_color.split(' ')
pattern = np.zeros([1, 1, 3], np.uint8)
pattern[0, 0, 0] = int(B)
pattern[0, 0, 1] = int(G)
pattern[0, 0, 2] = int(R)
return pattern
@staticmethod
def gradient(image, angle_degrees, start_color, end_color):
height, width = image.shape[0], image.shape[1]
# 创建一个空白的图像
gradient_image = np.zeros((height, width, 3), dtype=np.uint8)
# 将角度限制在 0 到 360 度之间
angle_degrees = np.clip(angle_degrees, 0, 360)
# 将角度转换为弧度
angle_radians = np.radians(angle_degrees)
# 计算渐变的方向
dx = np.cos(angle_radians)
dy = np.sin(angle_radians)
# 创建网格
x_grid, y_grid = np.meshgrid(np.arange(width), np.arange(height))
# 计算每个像素在渐变方向上的位置
distance_along_gradient = (x_grid * dx + y_grid * dy) / np.sqrt(dx ** 2 + dy ** 2)
# 计算渐变的权重
weight = np.clip(distance_along_gradient / max(width, height), 0, 1)
# 计算渐变的颜色
gradient_image[:, :, 0] = (1 - weight) * start_color[0] + weight * end_color[0]
gradient_image[:, :, 1] = (1 - weight) * start_color[1] + weight * end_color[1]
gradient_image[:, :, 2] = (1 - weight) * start_color[2] + weight * end_color[2]
return gradient_image
@PIPELINES.register_module()
class PrintPainting(object):
def __init__(self, print_flag=True):
self.print_flag = print_flag
# @ RunTime
def __call__(self, result):
if "location" not in result['print'].keys():
result['print']["location"] = [[0, 0]]
elif result['print']["location"] == [] or result['print']["location"] is None:
result['print']["location"] = [[0, 0]]
if result['print']['IfSingle']:
if len(result['print']['print_path_list']) == 0:
raise ValueError('When there is no printing, ifsingle must be false')
print_background = np.zeros((result['pattern_image'].shape[0], result['pattern_image'].shape[1], 3), dtype=np.uint8)
mask_background = np.zeros((result['pattern_image'].shape[0], result['pattern_image'].shape[1], 3), dtype=np.uint8)
# print_background = np.full((result['pattern_image'].shape[0], result['pattern_image'].shape[1], 3), 255, dtype=np.uint8)
for i in range(len(result['print']['print_path_list'])):
image, image_mode = self.read_image(result['print']['print_path_list'][i])
if image_mode == "RGBA":
new_size = (int(image.width * result['print']['print_scale_list'][i]), int(image.height * result['print']['print_scale_list'][i]))
mask = image.split()[3]
resized_source = image.resize(new_size)
resized_source_mask = mask.resize(new_size)
rotated_resized_source = resized_source.rotate(result['print']['print_angle_list'][i])
rotated_resized_source_mask = resized_source_mask.rotate(result['print']['print_angle_list'][i])
source_image_pil = Image.fromarray(cv2.cvtColor(print_background, cv2.COLOR_BGR2RGB))
source_image_pil_mask = Image.fromarray(cv2.cvtColor(mask_background, cv2.COLOR_BGR2RGB))
source_image_pil.paste(rotated_resized_source, (int(result['print']['location'][i][0]), int(result['print']['location'][i][1])), rotated_resized_source)
source_image_pil_mask.paste(rotated_resized_source_mask, (int(result['print']['location'][i][0]), int(result['print']['location'][i][1])), rotated_resized_source_mask)
print_background = cv2.cvtColor(np.array(source_image_pil), cv2.COLOR_RGBA2BGR)
mask_background = cv2.cvtColor(np.array(source_image_pil_mask), cv2.COLOR_RGBA2BGR)
else:
mask = self.get_mask_inv(image)
mask = np.expand_dims(mask, axis=2)
mask = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR)
mask = cv2.bitwise_not(mask)
# 旋转后的坐标需要重新算
rotate_mask, _ = self.img_rotate(mask, result['print']['print_angle_list'][i], result['print']['print_scale_list'][i])
rotate_image, rotated_new_size = self.img_rotate(image, result['print']['print_angle_list'][i], result['print']['print_scale_list'][i])
# x, y = int(result['print']['location'][i][0] - rotated_new_size[0] - (rotate_mask.shape[0] - image.shape[0]) / 2), int(result['print']['location'][i][1] - rotated_new_size[1] - (rotate_mask.shape[1] - image.shape[1]) / 2)
x, y = int(result['print']['location'][i][0] - rotated_new_size[0]), int(result['print']['location'][i][1] - rotated_new_size[1])
image_x = print_background.shape[1]
image_y = print_background.shape[0]
print_x = rotate_image.shape[1]
print_y = rotate_image.shape[0]
# 有bug
# if x + print_x > image_x:
# rotate_image = rotate_image[:, :x + print_x - image_x]
# rotate_mask = rotate_mask[:, :x + print_x - image_x]
# #
# if y + print_y > image_y:
# rotate_image = rotate_image[:y + print_y - image_y]
# rotate_mask = rotate_mask[:y + print_y - image_y]
# 不能是并行
# 当前第一轮的if 108以及115是判断有没有过下界和右界。第二轮的是判断左上有没有超出。 如果这个样子的话先裁了右边再左移region就会有问题
# 先挪 再判断 最后裁剪
# 如果print旋转了 或者 print贴边了 则需要判断 判断左界和上界是否小于0
if x <= 0:
rotate_image = rotate_image[:, -x:]
rotate_mask = rotate_mask[:, -x:]
start_x = x = 0
else:
start_x = x
if y <= 0:
rotate_image = rotate_image[-y:, :]
rotate_mask = rotate_mask[-y:, :]
start_y = y = 0
else:
start_y = y
# ------------------
# 如果print-size大于image-size 则需要裁剪print
if x + print_x > image_x:
rotate_image = rotate_image[:, :image_x - x]
rotate_mask = rotate_mask[:, :image_x - x]
if y + print_y > image_y:
rotate_image = rotate_image[:image_y - y, :]
rotate_mask = rotate_mask[:image_y - y, :]
# mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]] = cv2.bitwise_xor(mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]], rotate_mask)
# print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]] = cv2.add(print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]], rotate_image)
# mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]] = rotate_mask
# print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]] = rotate_image
mask_background = self.stack_prin(mask_background, result['pattern_image'], rotate_mask, start_y, y, start_x, x)
print_background = self.stack_prin(print_background, result['pattern_image'], rotate_image, start_y, y, start_x, x)
# gray_image = cv2.cvtColor(mask_background, cv2.COLOR_BGR2GRAY)
# print_background = cv2.bitwise_and(print_background, print_background, mask=gray_image)
print_mask = cv2.bitwise_and(result['mask'], cv2.cvtColor(mask_background, cv2.COLOR_BGR2GRAY))
img_fg = cv2.bitwise_or(print_background, print_background, mask=print_mask)
img_bg = cv2.bitwise_and(result['pattern_image'], result['pattern_image'], mask=cv2.bitwise_not(print_mask))
mask_mo = np.expand_dims(print_mask, axis=2).repeat(3, axis=2)
gray_mo = np.expand_dims(result['gray'], axis=2).repeat(3, axis=2)
img_fg = (img_fg * (mask_mo / 255) * (gray_mo / 255)).astype(np.uint8)
result['final_image'] = cv2.add(img_bg, img_fg)
canvas = np.full_like(result['final_image'], 255)
temp_bg = np.expand_dims(cv2.bitwise_not(result['mask']), axis=2).repeat(3, axis=2)
tmp1 = (canvas * (temp_bg / 255)).astype(np.uint8)
temp_fg = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2)
tmp2 = (result['final_image'] * (temp_fg / 255)).astype(np.uint8)
result['single_image'] = cv2.add(tmp1, tmp2)
else:
painting_dict = {}
painting_dict['dim_image_h'], painting_dict['dim_image_w'] = result['pattern_image'].shape[0:2]
# no print
if len(result['print_dict']['print_path_list']) == 0 or not self.print_flag:
result['print_image'] = result['pattern_image']
# print
else:
painting_dict = self.painting_collection(painting_dict, result, print_trigger=True)
result['print_image'] = self.printpaint(result, painting_dict, print_=True)
result['final_image'] = result['print_image']
canvas = np.full_like(result['final_image'], 255)
temp_bg = np.expand_dims(cv2.bitwise_not(result['mask']), axis=2).repeat(3, axis=2)
tmp1 = (canvas * (temp_bg / 255)).astype(np.uint8)
temp_fg = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2)
tmp2 = (result['final_image'] * (temp_fg / 255)).astype(np.uint8)
result['single_image'] = cv2.add(tmp1, tmp2)
if "element" in result.keys():
print_background = np.zeros((result['final_image'].shape[0], result['final_image'].shape[1], 3), dtype=np.uint8)
mask_background = np.zeros((result['final_image'].shape[0], result['final_image'].shape[1], 3), dtype=np.uint8)
for i in range(len(result['element']['element_path_list'])):
image, image_mode = self.read_image(result['element']['element_path_list'][i])
if image_mode == "RGBA":
new_size = (int(image.width * result['element']['element_scale_list'][i]), int(image.height * result['element']['element_scale_list'][i]))
mask = image.split()[3]
resized_source = image.resize(new_size)
resized_source_mask = mask.resize(new_size)
rotated_resized_source = resized_source.rotate(result['element']['element_angle_list'][i])
rotated_resized_source_mask = resized_source_mask.rotate(result['element']['element_angle_list'][i])
source_image_pil = Image.fromarray(cv2.cvtColor(print_background, cv2.COLOR_BGR2RGB))
source_image_pil_mask = Image.fromarray(cv2.cvtColor(mask_background, cv2.COLOR_BGR2RGB))
source_image_pil.paste(rotated_resized_source, (int(result['element']['location'][i][0]), int(result['element']['location'][i][1])), rotated_resized_source)
source_image_pil_mask.paste(rotated_resized_source_mask, (int(result['element']['location'][i][0]), int(result['element']['location'][i][1])), rotated_resized_source_mask)
print_background = cv2.cvtColor(np.array(source_image_pil), cv2.COLOR_RGBA2BGR)
mask_background = cv2.cvtColor(np.array(source_image_pil_mask), cv2.COLOR_RGBA2BGR)
print(1)
else:
mask = self.get_mask_inv(image)
mask = np.expand_dims(mask, axis=2)
mask = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR)
mask = cv2.bitwise_not(mask)
# 旋转后的坐标需要重新算
rotate_mask, _ = self.img_rotate(mask, result['element']['element_angle_list'][i], result['element']['element_scale_list'][i])
rotate_image, rotated_new_size = self.img_rotate(image, result['element']['element_angle_list'][i], result['element']['element_scale_list'][i])
# x, y = int(result['print']['location'][i][0] - rotated_new_size[0] - (rotate_mask.shape[0] - image.shape[0]) / 2), int(result['print']['location'][i][1] - rotated_new_size[1] - (rotate_mask.shape[1] - image.shape[1]) / 2)
x, y = int(result['element']['location'][i][0] - rotated_new_size[0]), int(result['element']['location'][i][1] - rotated_new_size[1])
image_x = print_background.shape[1]
image_y = print_background.shape[0]
print_x = rotate_image.shape[1]
print_y = rotate_image.shape[0]
# 有bug
# if x + print_x > image_x:
# rotate_image = rotate_image[:, :x + print_x - image_x]
# rotate_mask = rotate_mask[:, :x + print_x - image_x]
# #
# if y + print_y > image_y:
# rotate_image = rotate_image[:y + print_y - image_y]
# rotate_mask = rotate_mask[:y + print_y - image_y]
# 不能是并行
# 当前第一轮的if 108以及115是判断有没有过下界和右界。第二轮的是判断左上有没有超出。 如果这个样子的话先裁了右边再左移region就会有问题
# 先挪 再判断 最后裁剪
# 如果print旋转了 或者 print贴边了 则需要判断 判断左界和上界是否小于0
if x <= 0:
rotate_image = rotate_image[:, -x:]
rotate_mask = rotate_mask[:, -x:]
start_x = x = 0
else:
start_x = x
if y <= 0:
rotate_image = rotate_image[-y:, :]
rotate_mask = rotate_mask[-y:, :]
start_y = y = 0
else:
start_y = y
# ------------------
# 如果print-size大于image-size 则需要裁剪print
if x + print_x > image_x:
rotate_image = rotate_image[:, :image_x - x]
rotate_mask = rotate_mask[:, :image_x - x]
if y + print_y > image_y:
rotate_image = rotate_image[:image_y - y, :]
rotate_mask = rotate_mask[:image_y - y, :]
# mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]] = cv2.bitwise_xor(mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]], rotate_mask)
# print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]] = cv2.add(print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]], rotate_image)
# mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]] = rotate_mask
# print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]] = rotate_image
mask_background = self.stack_prin(mask_background, result['pattern_image'], rotate_mask, start_y, y, start_x, x)
print_background = self.stack_prin(print_background, result['pattern_image'], rotate_image, start_y, y, start_x, x)
# gray_image = cv2.cvtColor(mask_background, cv2.COLOR_BGR2GRAY)
# print_background = cv2.bitwise_and(print_background, print_background, mask=gray_image)
print_mask = cv2.bitwise_and(result['mask'], cv2.cvtColor(mask_background, cv2.COLOR_BGR2GRAY))
img_fg = cv2.bitwise_or(print_background, print_background, mask=print_mask)
# TODO element 丢失信息
three_channel_image = cv2.merge([cv2.bitwise_not(print_mask), cv2.bitwise_not(print_mask), cv2.bitwise_not(print_mask)])
img_bg = cv2.bitwise_and(result['final_image'], three_channel_image)
# mask_mo = np.expand_dims(print_mask, axis=2).repeat(3, axis=2)
# gray_mo = np.expand_dims(result['gray'], axis=2).repeat(3, axis=2)
# img_fg = (img_fg * (mask_mo / 255) * (gray_mo / 255)).astype(np.uint8)
result['final_image'] = cv2.add(img_bg, img_fg)
canvas = np.full_like(result['final_image'], 255)
temp_bg = np.expand_dims(cv2.bitwise_not(result['mask']), axis=2).repeat(3, axis=2)
tmp1 = (canvas * (temp_bg / 255)).astype(np.uint8)
temp_fg = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2)
tmp2 = (result['final_image'] * (temp_fg / 255)).astype(np.uint8)
result['single_image'] = cv2.add(tmp1, tmp2)
return result
@staticmethod
def stack_prin(print_background, pattern_image, rotate_image, start_y, y, start_x, x):
temp_print = np.zeros((pattern_image.shape[0], pattern_image.shape[1], 3), dtype=np.uint8)
temp_print[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]] = rotate_image
img2gray = cv2.cvtColor(print_background, cv2.COLOR_BGR2GRAY)
ret, mask_ = cv2.threshold(img2gray, 1, 255, cv2.THRESH_BINARY)
mask_inv = cv2.bitwise_not(mask_)
img1_bg = cv2.bitwise_and(print_background, print_background, mask=mask_)
img2_fg = cv2.bitwise_and(temp_print, temp_print, mask=mask_inv)
print_background = img1_bg + img2_fg
return print_background
def painting_collection(self, painting_dict, result, print_trigger=False):
if print_trigger:
print_ = self.get_print(result['print_dict'])
painting_dict['Trigger'] = not print_['IfSingle']
painting_dict['location'] = print_['location'] if 'location' in print_.keys() else None
single_mask_inv_print = self.get_mask_inv(print_['image'])
dim_max = max(painting_dict['dim_image_h'], painting_dict['dim_image_w'])
dim_pattern = (int(dim_max * print_['scale'] / 5), int(dim_max * print_['scale'] / 5))
if not print_['IfSingle']:
self.random_seed = random.randint(0, 1000)
painting_dict['mask_inv_print'] = self.tile_image(single_mask_inv_print, dim_pattern, print_['scale'], painting_dict['dim_image_h'], painting_dict['dim_image_w'], painting_dict['location'], trigger=True)
painting_dict['tile_print'] = self.tile_image(print_['image'], dim_pattern, print_['scale'], painting_dict['dim_image_h'], painting_dict['dim_image_w'], painting_dict['location'], trigger=True)
else:
painting_dict['mask_inv_print'] = self.tile_image(single_mask_inv_print, dim_pattern, print_['scale'], painting_dict['dim_image_h'], painting_dict['dim_image_w'], painting_dict['location'])
painting_dict['tile_print'] = self.tile_image(print_['image'], dim_pattern, print_['scale'], painting_dict['dim_image_h'], painting_dict['dim_image_w'], painting_dict['location'])
painting_dict['dim_print_h'], painting_dict['dim_print_w'] = dim_pattern
return painting_dict
def tile_image(self, pattern, dim, scale, dim_image_h, dim_image_w, location, trigger=False):
tile = None
if not trigger:
tile = cv2.resize(pattern, dim, interpolation=cv2.INTER_AREA)
else:
resize_pattern = cv2.resize(pattern, dim, interpolation=cv2.INTER_AREA)
if len(pattern.shape) == 2:
tile = np.tile(resize_pattern, (int((5 + 1) / scale) + 4, int((5 + 1) / scale) + 4))
if len(pattern.shape) == 3:
tile = np.tile(resize_pattern, (int((5 + 1) / scale) + 4, int((5 + 1) / scale) + 4, 1))
tile = self.crop_image(tile, dim_image_h, dim_image_w, location, resize_pattern.shape)
return tile
def get_mask_inv(self, print_):
if print_[0][0][0] == 255 and print_[0][0][1] == 255 and print_[0][0][2] == 255:
bg_color = cv2.cvtColor(print_, cv2.COLOR_BGR2LAB)[0][0]
print_tile = cv2.cvtColor(print_, cv2.COLOR_BGR2LAB)
bg_l, bg_a, bg_b = bg_color[0], bg_color[1], bg_color[2]
bg_L_high, bg_L_low = self.get_low_high_lab(bg_l, L=True)
bg_a_high, bg_a_low = self.get_low_high_lab(bg_a)
bg_b_high, bg_b_low = self.get_low_high_lab(bg_b)
lower = np.array([bg_L_low, bg_a_low, bg_b_low])
upper = np.array([bg_L_high, bg_a_high, bg_b_high])
mask_inv = cv2.inRange(print_tile, lower, upper)
return mask_inv
else:
# bg_color = cv2.cvtColor(print_, cv2.COLOR_BGR2LAB)[0][0]
# print_tile = cv2.cvtColor(print_, cv2.COLOR_BGR2LAB)
# bg_l, bg_a, bg_b = bg_color[0], bg_color[1], bg_color[2]
# bg_L_high, bg_L_low = self.get_low_high_lab(bg_l, L=True)
# bg_a_high, bg_a_low = self.get_low_high_lab(bg_a)
# bg_b_high, bg_b_low = self.get_low_high_lab(bg_b)
# lower = np.array([bg_L_low, bg_a_low, bg_b_low])
# upper = np.array([bg_L_high, bg_a_high, bg_b_high])
# print_tile = cv2.cvtColor(print_, cv2.COLOR_BGR2LAB)
# mask_inv = cv2.cvtColor(print_tile, cv2.COLOR_BGR2GRAY)
# mask_inv = cv2.cvtColor(print_, cv2.COLOR_BGR2GRAY)
mask_inv = np.zeros(print_.shape[:2], dtype=np.uint8)
return mask_inv
@staticmethod
def printpaint(result, painting_dict, print_=False):
if print_ and painting_dict['Trigger']:
print_mask = cv2.bitwise_and(result['mask'], cv2.bitwise_not(painting_dict['mask_inv_print']))
img_fg = cv2.bitwise_and(painting_dict['tile_print'], painting_dict['tile_print'], mask=print_mask)
else:
print_mask = result['mask']
img_fg = result['final_image']
if print_ and not painting_dict['Trigger']:
index_ = None
try:
index_ = len(painting_dict['location'])
except:
assert f'there must be parameter of location if choose IfSingle'
for i in range(index_):
start_h, start_w = int(painting_dict['location'][i][1]), int(painting_dict['location'][i][0])
length_h = min(start_h + painting_dict['dim_print_h'], img_fg.shape[0])
length_w = min(start_w + painting_dict['dim_print_w'], img_fg.shape[1])
change_region = img_fg[start_h: length_h, start_w: length_w, :]
# problem in change_mask
change_mask = print_mask[start_h: length_h, start_w: length_w]
# get real part into change mask
_, change_mask = cv2.threshold(change_mask, 220, 255, cv2.THRESH_BINARY)
mask = cv2.bitwise_not(painting_dict['mask_inv_print'])
img_fg[start_h:start_h + painting_dict['dim_print_h'], start_w:start_w + painting_dict['dim_print_w'], :] = change_region
clothes_mask_print = cv2.bitwise_not(print_mask)
img_bg = cv2.bitwise_and(result['pattern_image'], result['pattern_image'], mask=clothes_mask_print)
mask_mo = np.expand_dims(print_mask, axis=2).repeat(3, axis=2)
gray_mo = np.expand_dims(result['gray'], axis=2).repeat(3, axis=2)
img_fg = (img_fg * (mask_mo / 255) * (gray_mo / 255)).astype(np.uint8)
print_image = cv2.add(img_bg, img_fg)
return print_image
@staticmethod
def get_print(print_dict):
if not 'print_scale_list' in print_dict.keys() or print_dict['print_scale_list'][0] < 0.3:
print_dict['scale'] = 0.3
else:
print_dict['scale'] = print_dict['print_scale_list'][0]
if not 'IfSingle' in print_dict.keys():
print_dict['IfSingle'] = False
data = minio_client.get_object(print_dict['print_path_list'][0].split("/", 1)[0], print_dict['print_path_list'][0].split("/", 1)[1])
# data = s3.get_object(Bucket=print_dict['print_path_list'][0].split("/", 1)[0], Key=print_dict['print_path_list'][0].split("/", 1)[1])['Body']
data_bytes = BytesIO(data.read())
image = Image.open(data_bytes)
image_mode = image.mode
# 判断图片格式如果是RGBA 则贴在一张纯白图片上 防止透明转黑
if image_mode == "RGBA":
new_background = Image.new('RGB', image.size, (255, 255, 255))
new_background.paste(image, mask=image.split()[3])
image = new_background
print_dict['image'] = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
# file = minio_client.get_object(print_dict['print_path_list'][0].split("/", 1)[0], print_dict['print_path_list'][0].split("/", 1)[1]).data
# print_dict['image'] = cv2.imdecode(np.fromstring(file, np.uint8), 1)
# image = cv2.imdecode(np.frombuffer(file, np.uint8), 1)
# return image
return print_dict
def crop_image(self, image, image_size_h, image_size_w, location, print_shape):
print_w = print_shape[1]
print_h = print_shape[0]
random.seed(self.random_seed)
# logging.info(f'overall print location : {location}')
# x_offset = random.randint(0, image.shape[0] - image_size_h)
# y_offset = random.randint(0, image.shape[1] - image_size_w)
# 1.拿到偏移量后和resize后的print宽高取余 得到真正偏移量
x_offset = print_w - int(location[0][1] % print_w)
y_offset = print_w - int(location[0][0] % print_h)
# y_offset = int(location[0][0])
# x_offset = int(location[0][1])
if len(image.shape) == 2:
image = image[x_offset: x_offset + image_size_h, y_offset: y_offset + image_size_w]
elif len(image.shape) == 3:
image = image[x_offset: x_offset + image_size_h, y_offset: y_offset + image_size_w, :]
return image
@staticmethod
def get_low_high_lab(Lab_value, L=False):
if L:
high = Lab_value + 30 if Lab_value + 30 < 255 else 255
low = Lab_value - 30 if Lab_value - 30 > 0 else 0
else:
high = Lab_value + 30 if Lab_value + 30 < 255 else 255
low = Lab_value - 30 if Lab_value - 30 > 0 else 0
return high, low
@staticmethod
def img_rotate(image, angel, scale):
"""顺时针旋转图像任意角度
Args:
image (np.array): [原始图像]
angel (float): [逆时针旋转的角度]
Returns:
[array]: [旋转后的图像]
"""
h, w = image.shape[:2]
center = (w // 2, h // 2)
# if type(angel) is not int:
# angel = 0
M = cv2.getRotationMatrix2D(center, -angel, scale)
# 调整旋转后的图像长宽
rotated_h = int((w * np.abs(M[0, 1]) + (h * np.abs(M[0, 0]))))
rotated_w = int((h * np.abs(M[0, 1]) + (w * np.abs(M[0, 0]))))
M[0, 2] += (rotated_w - w) // 2
M[1, 2] += (rotated_h - h) // 2
# 旋转图像
rotated_img = cv2.warpAffine(image, M, (rotated_w, rotated_h))
return rotated_img, ((rotated_img.shape[1] - image.shape[1] * scale) // 2, (rotated_img.shape[0] - image.shape[0] * scale) // 2)
# return rotated_img, (0, 0)
@staticmethod
def read_image(image_url):
data = minio_client.get_object(image_url.split("/", 1)[0], image_url.split("/", 1)[1])
# data = s3.get_object(Bucket=image_url.split("/", 1)[0], Key=image_url.split("/", 1)[1])['Body']
data_bytes = BytesIO(data.read())
image = Image.open(data_bytes)
image_mode = image.mode
# 判断图片格式如果是RGBA 则贴在一张纯白图片上 防止透明转黑
if image_mode == "RGBA":
# new_background = Image.new('RGB', image.size, (255, 255, 255))
# new_background.paste(image, mask=image.split()[3])
# image = new_background
return image, image_mode
image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
return image, image_mode
# @staticmethod
# def read_image(image_url):
# response = requests.get(image_url)
# image_data = np.frombuffer(response.content, np.uint8)
#
# # 解码图像
# image = cv2.imdecode(image_data, 3)
# return image

View File

@@ -0,0 +1,54 @@
from ..builder import PIPELINES
import math
import cv2
@PIPELINES.register_module()
class Scaling(object):
def __init__(self):
pass
# @ RunTime
def __call__(self, result):
if result['keypoint'] in ['waistband', 'shoulder', 'head_point']:
# milvus_db_keypoint_cache
distance_clo = math.sqrt(
(int(result['clothes_keypoint'][result['keypoint'] + '_left'][0]) - int(result['clothes_keypoint'][result['keypoint'] + '_right'][0])) ** 2
+
(int(result['clothes_keypoint'][result['keypoint'] + '_left'][1]) - int(result['clothes_keypoint'][result['keypoint'] + '_right'][1])) ** 2)
distance_bdy = math.sqrt((int(result['body_point_test'][result['keypoint'] + '_left'][0]) - int(result['body_point_test'][result['keypoint'] + '_right'][0])) ** 2 + 1)
# distance_clo = math.sqrt(
# (int(result['clothes_keypoint'][result['keypoint'] + '_left'].split("_")[0]) - int(result['clothes_keypoint'][result['keypoint'] + '_right'].split("_")[0])) ** 2
# +
# (int(result['clothes_keypoint'][result['keypoint'] + '_left'].split("_")[1]) - int(result['clothes_keypoint'][result['keypoint'] + '_right'].split("_")[1])) ** 2)
#
# distance_bdy = math.sqrt((int(result['body_point_test'][result['keypoint'] + '_left'][0]) - int(result['body_point_test'][result['keypoint'] + '_right'][0])) ** 2 + 1)
if distance_clo == 0:
result['scale'] = 1
else:
result['scale'] = distance_bdy / distance_clo
elif result['keypoint'] == 'toe':
distance_bdy = math.sqrt(
(int(result['body_point_test']['foot_length'][0]) - int(result['body_point_test']['foot_length'][2])) ** 2
+
(int(result['body_point_test']['foot_length'][1]) - int(result['body_point_test']['foot_length'][3])) ** 2
)
Blur = cv2.GaussianBlur(result['gray'], (3, 3), 0)
Edge = cv2.Canny(Blur, 10, 200)
Edge = cv2.dilate(Edge, None)
Edge = cv2.erode(Edge, None)
Contour, _ = cv2.findContours(Edge, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
Contours = sorted(Contour, key=cv2.contourArea, reverse=True)
Max_contour = Contours[0]
x, y, w, h = cv2.boundingRect(Max_contour)
width = w
distance_clo = width
result['scale'] = distance_bdy / distance_clo
elif result['keypoint'] == 'hand_point':
result['scale'] = result['scale_bag']
elif result['keypoint'] == 'ear_point':
result['scale'] = result['scale_earrings']
return result

View File

@@ -0,0 +1,14 @@
from ..builder import PIPELINES
from ...utils.design_ensemble import get_seg_result
@PIPELINES.register_module()
class Segmentation(object):
def __init__(self, device='cpu', show=False, debug=None):
self.show = show
self.device = device
self.debug = debug
def __call__(self, result):
result['seg_result'] = get_seg_result(result["image_id"], result['image'])
return result

View File

@@ -0,0 +1,77 @@
import logging
import cv2
import numpy as np
from cv2 import cvtColor, COLOR_BGR2RGBA
from app.service.utils.generate_uuid import generate_uuid
from ..builder import PIPELINES
from PIL import Image
from ...utils.conversion_image import rgb_to_rgba
from ...utils.upload_image import upload_png_mask
@PIPELINES.register_module()
class Split(object):
"""
Split image into front and back layer according to the segmentation result
"""
# KNet
def __call__(self, result):
try:
if 'mask' not in result.keys():
raise KeyError(f'Cannot find mask in result dict, please check ContourDetection is included in process pipelines.')
if 'seg_result' not in result.keys(): # 没过seg模型
result['front_mask'] = result['mask'].copy()
result['back_mask'] = np.zeros_like(result['mask'])
else:
temp_front = result['seg_result'] == 1
result['front_mask'] = (result['mask'] * (temp_front + 0).astype(np.uint8))
temp_back = result['seg_result'] == 2
result['back_mask'] = (result['mask'] * (temp_back + 0).astype(np.uint8))
if result['name'] in ('outwear', 'dress', 'blouse', 'skirt', 'trousers', 'tops', 'bottoms'):
if len(result['front_mask'].shape) > 2:
front_mask = result['front_mask'][0]
else:
front_mask = result['front_mask']
if len(result['back_mask'].shape) > 2:
back_mask = result['back_mask'][0]
else:
back_mask = result['back_mask']
rgba_image = rgb_to_rgba((result['final_image'].shape[0], result['final_image'].shape[1]), result['final_image'], result['mask'])
result_front_image = np.zeros_like(rgba_image)
result_front_image[front_mask != 0] = rgba_image[front_mask != 0]
result_front_image_pil = Image.fromarray(cvtColor(result_front_image, COLOR_BGR2RGBA))
front_new_size = (int(result_front_image_pil.width * result["scale"] * result["resize_scale"][0]), int(result_front_image_pil.height * result["scale"] * result["resize_scale"][1]))
result_front_image_pil = result_front_image_pil.resize(front_new_size, Image.LANCZOS)
# TODO 多线程外部上传图片到minio
# result['front_mask_image'] = cv2.resize(front_mask, front_new_size)
# result['front_image'] = result_front_image_pil
front_mask = cv2.resize(front_mask, front_new_size)
result['front_image'], result["front_image_url"], result["front_mask_url"] = upload_png_mask(result_front_image_pil, f'{generate_uuid()}', mask=front_mask)
if result["name"] in ('blouse', 'dress', 'outwear', 'tops'):
result_back_image = np.zeros_like(rgba_image)
result_back_image[back_mask != 0] = rgba_image[back_mask != 0]
result_back_image_pil = Image.fromarray(cvtColor(result_back_image, COLOR_BGR2RGBA))
back_new_size = (int(result_back_image_pil.width * result["scale"] * result["resize_scale"][0]), int(result_back_image_pil.height * result["scale"] * result["resize_scale"][1]))
result_back_image_pil = result_back_image_pil.resize(back_new_size, Image.LANCZOS)
# TODO 多线程外部上传图片到minio
# result['back_mask_image'] = cv2.resize(back_mask, back_new_size)
# result['back_image'] = result_back_image_pil
back_mask = cv2.resize(back_mask, back_new_size)
result['back_image'], result["back_image_url"], result["back_mask_url"] = upload_png_mask(result_back_image_pil, f'{generate_uuid()}', mask=back_mask)
else:
result['back_image'] = None
result["back_image_url"] = None
result["back_mask_url"] = None
result['back_mask_image'] = None
return result
except Exception as e:
logging.warning(f"split runtime exception : {e} image_id : {result['image_id']}")

View File

@@ -0,0 +1,126 @@
import io
import logging
import time
import cv2
import numpy as np
from .builder import ITEMS
from .clothing import Clothing
from PIL import Image
from ..utils.conversion_image import rgb_to_rgba
from ..utils.upload_image import upload_png_mask
from ...utils.generate_uuid import generate_uuid
@ITEMS.register_module()
class Shoes(Clothing):
# TODO location of shoes has little mismatch
def __init__(self, **kwargs):
pipeline = [
dict(type='LoadImageFromFile', path=kwargs['path'], color=kwargs['color']),
dict(type='KeypointDetection'),
dict(type='ContourDetection'),
dict(type='Painting'),
dict(type='Scaling'),
dict(type='Split'),
# dict(type='ImageShow', key=['image', 'mask', 'pattern_image']),
]
kwargs.update(pipeline=pipeline)
super(Shoes, self).__init__(**kwargs)
def organize(self, layer):
left_shoe_mask, right_shoe_mask = self.cut()
left_layer = dict(name=f'{type(self).__name__.lower()}_left',
image=self.result['shoes_left'],
image_url=self.result['left_image_url'],
mask_url=self.result['left_mask_url'],
sacle=self.result['scale'],
clothes_keypoint=self.result['clothes_keypoint'],
position=self.calculate_start_point(self.result['keypoint'],
self.result['scale'],
self.result['clothes_keypoint'],
self.result['body_point'],
'left'))
layer.insert(left_layer)
right_layer = dict(name=f'{type(self).__name__.lower()}_right',
image=self.result['shoes_right'],
image_url=self.result['right_image_url'],
mask_url=self.result['right_mask_url'],
sacle=self.result['scale'],
clothes_keypoint=self.result['clothes_keypoint'],
position=self.calculate_start_point(self.result['keypoint'],
self.result['scale'],
self.result['clothes_keypoint'],
self.result['body_point'],
'right'))
layer.insert(right_layer)
def cut(self):
"""
Cut shoes mask into two pieces
Returns:
"""
contour, _ = cv2.findContours(self.result['mask'], cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
contours = sorted(contour, key=cv2.contourArea, reverse=True)
bounding_boxes = [cv2.boundingRect(c) for c in contours[:2]]
(contours, bounding_boxes) = zip(*sorted(zip(contours[:2], bounding_boxes), key=lambda x: x[1][0], reverse=False))
epsilon_left = 0.001 * cv2.arcLength(contours[0], True)
approx_left = cv2.approxPolyDP(contours[0], epsilon_left, True)
mask_left = np.zeros(self.result['final_image'].shape[:2], np.uint8)
cv2.drawContours(mask_left, [approx_left], -1, 255, -1)
item_mask_left = cv2.GaussianBlur(mask_left, (5, 5), 0)
rgba_image = rgb_to_rgba((self.result['final_image'].shape[0], self.result['final_image'].shape[1]), self.result['final_image'], item_mask_left)
result_image = np.zeros_like(rgba_image)
result_image[self.result['front_mask'] != 0] = rgba_image[self.result['front_mask'] != 0]
result_left_image_pil = Image.fromarray(result_image, 'RGBA')
result_left_image_pil = result_left_image_pil.resize((int(result_left_image_pil.width * self.result["scale"]), int(result_left_image_pil.height * self.result["scale"])), Image.LANCZOS)
self.result['shoes_left'], self.result["left_image_url"], self.result["left_mask_url"] = upload_png_mask(result_left_image_pil, f"{generate_uuid()}")
epsilon_right = 0.001 * cv2.arcLength(contours[1], True)
approx_right = cv2.approxPolyDP(contours[1], epsilon_right, True)
mask_right = np.zeros(self.result['final_image'].shape[:2], np.uint8)
cv2.drawContours(mask_right, [approx_right], -1, 255, -1)
item_mask_right = cv2.GaussianBlur(mask_right, (5, 5), 0)
rgba_image = rgb_to_rgba((self.result['final_image'].shape[0], self.result['final_image'].shape[1]), self.result['final_image'], item_mask_right)
result_image = np.zeros_like(rgba_image)
result_image[self.result['front_mask'] != 0] = rgba_image[self.result['front_mask'] != 0]
result_right_image_pil = Image.fromarray(result_image, 'RGBA')
result_right_image_pil = result_right_image_pil.resize((int(result_right_image_pil.width * self.result["scale"]), int(result_right_image_pil.height * self.result["scale"])), Image.LANCZOS)
self.result['shoes_right'], self.result["right_image_url"], self.result["right_mask_url"] = upload_png_mask(result_right_image_pil, f"{generate_uuid()}")
return item_mask_left, item_mask_right
@staticmethod
def calculate_start_point(keypoint_type, scale, clothes_point, body_point, location):
"""
left shoes align left
right shoes align right
Args:
keypoint_type: string, "toe"
scale: float
clothes_point: dict{'left': [x1, y1, z1], 'right': [x2, y2, z2]}
body_point: dict, containing keypoint data of body figure
location: string, indicates whether the start point belongs to right or left shoe
Returns:
start_point: tuple (x', y')
x' = y_body - y1 * scale
y' = x_body - x1 * scale
"""
if location not in ['left', 'right']:
raise KeyError(f'location value must be left or right but got {location}')
side_indicator = f'{keypoint_type}_{location}'
# clothes_point = {k: tuple(map(lambda x: int(scale * x), v[0: 2])) for k, v in clothes_point.items()}
start_point = (body_point[side_indicator][1] - int(int(clothes_point[side_indicator].split("_")[1]) * scale),
body_point[side_indicator][0] - int(int(clothes_point[side_indicator].split("_")[0]) * scale))
return start_point

View File

@@ -0,0 +1,46 @@
from .builder import ITEMS
from .clothing import Clothing
@ITEMS.register_module()
class Top(Clothing):
def __init__(self, pipeline, **kwargs):
if pipeline is None:
pipeline = [
dict(type='LoadImageFromFile', path=kwargs['path'], color=kwargs['color'], print_dict=kwargs['print']),
dict(type='KeypointDetection'),
dict(type='ContourDetection'),
dict(type='Segmentation', device='cpu', show=False, debug=kwargs['debug']),
dict(type='Painting', painting_flag=True),
dict(type='PrintPainting', print_flag=True),
# dict(type='ImageShow', key=['image', 'mask', 'seg_visualize', 'pattern_image']),
dict(type='Scaling'),
dict(type='Split'),
]
kwargs.update(pipeline=pipeline)
super(Top, self).__init__(**kwargs)
@ITEMS.register_module()
class Blouse(Top):
def __init__(self, pipeline=None, **kwargs):
super(Blouse, self).__init__(pipeline, **kwargs)
@ITEMS.register_module()
class Outwear(Top):
def __init__(self, pipeline=None, **kwargs):
super(Outwear, self).__init__(pipeline, **kwargs)
@ITEMS.register_module()
class Dress(Top):
def __init__(self, pipeline=None, **kwargs):
super(Dress, self).__init__(pipeline, **kwargs)
# Men's clothing
@ITEMS.register_module()
class Tops(Top):
def __init__(self, pipeline=None, **kwargs):
super(Tops, self).__init__(pipeline, **kwargs)

View File

@@ -0,0 +1,130 @@
from app.core.config import PRIORITY_DICT
from app.service.design.core.layer import Layer
from app.service.design.items import build_item
from app.service.design.utils.redis_utils import Redis
from app.service.design.utils.synthesis_item import synthesis, synthesis_single
import concurrent.futures
def process_item(item, layers):
# logging.info("process running.........")
item.process()
item.organize(layers)
if item.result['name'] == "mannequin":
return item.result['body_image'].size
def update_progress(process_id, total):
r = Redis()
progress = r.read(key=process_id)
if progress and total != 1:
if int(progress) <= 100:
r.write(key=process_id, value=int(progress) + int(100 / total))
else:
r.write(key=process_id, value=100)
return progress
elif total == 1:
r.write(key=process_id, value=100)
return progress
else:
r.write(key=process_id, value=int(100 / total))
return progress
def final_progress(process_id):
r = Redis()
progress = r.read(key=process_id)
r.write(key=process_id, value=100)
return progress
def generate(request_data):
return_response = {}
request_data = request_data.dict()
assert "process_id" in request_data.keys(), "Need process_id parameters"
objects = request_data['objects']
# insert_keypoint_cache(objects)
process_id = request_data['process_id']
with concurrent.futures.ThreadPoolExecutor() as executor:
# 提交每个对象的处理任务
futures = {executor.submit(process_object, cfg, process_id, len(objects)): obj for obj, cfg in enumerate(objects)}
# 获取处理结果
for future in concurrent.futures.as_completed(futures):
obj = futures[future]
result = future.result()
return_response[obj] = result
final_progress(process_id)
return return_response
def process_object(cfg, process_id, total):
basic_info = cfg.get('basic')
items_response = {
'layers': []
}
if cfg.get('basic')['single_overall'] == 'overall':
basic_info['debug'] = False
items = [build_item(x, default_args=basic_info) for x in cfg.get('items')]
layers = Layer()
body_size = None
futures = []
for item in items:
futures = [process_item(item, layers)]
for future in futures:
if future is not None:
body_size = future
# 是否自定义排序
if basic_info.get('layer_order', False):
layers = sorted(layers.layer, key=lambda s: s.get("priority", float('inf')))
else:
layers = sorted(layers.layer, key=lambda x: PRIORITY_DICT.get(x['name'], float('inf')))
# 合成
items_response['synthesis_url'] = synthesis(layers, body_size)
for lay in layers:
items_response['layers'].append({
'image_category': lay['name'],
'position': lay['position'],
'priority': lay.get("priority", None),
'resize_scale': lay['resize_scale'] if "resize_scale" in lay.keys() else None,
'image_size': lay['image'] if lay['image'] is None else lay['image'].size,
'gradient_string': lay['gradient_string'] if 'gradient_string' in lay.keys() else "",
'mask_url': lay['mask_url'],
'image_url': lay['image_url'] if 'image_url' in lay.keys() else None,
# 'image': lay['image'],
# 'mask_image': lay['mask_image'],
})
elif cfg.get('basic')['single_overall'] == 'single':
assert cfg.get('basic')['switch_category'] in [x['type'] for x in cfg.get('items')], "Lack of switch_category parameters "
basic_info['debug'] = False
for item in cfg.get('items'):
if item['type'] == cfg.get('basic')['switch_category']:
item = build_item(item, default_args=cfg.get('basic'))
item.process()
items_response['layers'].append({
'image_category': f"{item.result['name']}_front",
'image_size': item.result['back_image'].size if item.result['back_image'] else None,
'position': None,
'priority': 0,
'image_url': item.result['front_image_url'],
'mask_url': item.result['front_mask_url'],
"gradient_string": item.result['gradient_string'] if 'gradient_string' in item.result.keys() else ""
})
items_response['layers'].append({
'image_category': f"{item.result['name']}_back",
'image_size': item.result['front_image'].size if item.result['front_image'] else None,
'position': None,
'priority': 0,
'image_url': item.result['back_image_url'],
'mask_url': item.result['back_mask_url'],
"gradient_string": item.result['gradient_string'] if 'gradient_string' in item.result.keys() else ""
})
items_response['synthesis_url'] = synthesis_single(item.result['front_image'], item.result['back_image'])
break
update_progress(process_id, total)
return items_response

View File

View File

@@ -0,0 +1,23 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project trinity_client
@File conversion_image.py
@Author :周成融
@Date 2023/8/21 10:40:29
@detail
"""
import numpy as np
def rgb_to_rgba(rgb_size, rgb_image, mask):
alpha_channel = np.full(rgb_size, 255, dtype=np.uint8)
# 创建四通道的结果图像
rgba_image = np.dstack((rgb_image, alpha_channel))
alpha_channel = np.where(mask > 0, 255, 0)
# 更新RGBA图像的透明度通道
rgba_image[:, :, 3] = alpha_channel
return rgba_image
if __name__ == '__main__':
image = open("")

View File

@@ -0,0 +1,138 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project trinity_client
@File design_ensemble.py
@Author :周成融
@Date 2023/8/16 19:36:21
@detail :发起请求 获取推理结果
"""
import logging
import cv2
import mmcv
import numpy as np
import tritonclient.http as httpclient
import torch
import torch.nn.functional as F
from app.core.config import *
"""
keypoint
预处理 推理 后处理
"""
def keypoint_preprocess(img_path):
img = mmcv.imread(img_path)
img_scale = (256, 256)
img, w_scale, h_scale = mmcv.imresize(img, img_scale, return_scale=True)
img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0)
return preprocessed_img, (w_scale, h_scale)
# @ RunTime
# 推理
def get_keypoint_result(image, site):
keypoint_result = None
try:
image, scale_factor = keypoint_preprocess(image)
client = httpclient.InferenceServerClient(url=DESIGN_MODEL_URL)
transformed_img = image.astype(np.float32)
inputs = [httpclient.InferInput(f"input", transformed_img.shape, datatype="FP32")]
inputs[0].set_data_from_numpy(transformed_img, binary_data=True)
outputs = [httpclient.InferRequestedOutput(f"output", binary_data=True)]
results = client.infer(model_name=f"keypoint_{site}_ocrnet_hr18", inputs=inputs, outputs=outputs)
inference_output = torch.from_numpy(results.as_numpy(f'output'))
keypoint_result = keypoint_postprocess(inference_output, scale_factor)
except Exception as e:
logging.warning(f"get_keypoint_result : {e}")
return keypoint_result
def keypoint_postprocess(output, scale_factor):
max_indices = torch.argmax(output.view(output.size(0), output.size(1), -1), dim=2).unsqueeze(dim=2)
max_coords = torch.cat((max_indices / output.size(3), max_indices % output.size(3)), dim=2)
segment_result = max_coords.numpy()
scale_factor = [1 / x for x in scale_factor[::-1]]
scale_matrix = np.diag(scale_factor)
nan = np.isinf(scale_matrix)
scale_matrix[nan] = 0
return np.ceil(np.dot(segment_result, scale_matrix) * 4)
"""
seg
预处理 推理 后处理
"""
# KNet
def seg_preprocess(img_path):
img = mmcv.imread(img_path)
ori_shape = img.shape[:2]
img_scale_w, img_scale_h = ori_shape
if ori_shape[0] > 1024:
img_scale_w = 1024
if ori_shape[1] > 1024:
img_scale_h = 1024
scale_factor = []
img, x, y = mmcv.imresize(img, (img_scale_w, img_scale_h), return_scale=True)
scale_factor.append(x)
scale_factor.append(y)
img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0)
return preprocessed_img, ori_shape
# @ RunTime
def get_seg_result(image_id, image):
image, ori_shape = seg_preprocess(image)
client = httpclient.InferenceServerClient(url=f"{DESIGN_MODEL_URL}")
transformed_img = image.astype(np.float32)
# 输入集
inputs = [
httpclient.InferInput(SEGMENTATION['input'], transformed_img.shape, datatype="FP32")
]
inputs[0].set_data_from_numpy(transformed_img, binary_data=True)
# 输出集
outputs = [
httpclient.InferRequestedOutput(SEGMENTATION['output'], binary_data=True),
]
results = client.infer(model_name=SEGMENTATION['new_model_name'], inputs=inputs, outputs=outputs)
# 推理
# 取结果
inference_output1 = results.as_numpy(SEGMENTATION['output'])
seg_result = seg_postprocess(int(image_id), inference_output1, ori_shape)
return seg_result
# no cache
def seg_postprocess(image_id, output, ori_shape):
seg_logit = F.interpolate(torch.tensor(output).float(), size=ori_shape, scale_factor=None, mode='bilinear', align_corners=False)
seg_pred = seg_logit.cpu().numpy()
return seg_pred[0]
def key_point_show(image_path, key_point_result=None):
img = cv2.imread(image_path)
points_list = key_point_result
point_size = 1
point_color = (0, 0, 255) # BGR
thickness = 4 # 可以为 0 、4、8
for point in points_list:
cv2.circle(img, point[::-1], point_size, point_color, thickness)
cv2.imshow("0", img)
cv2.waitKey(0)
if __name__ == '__main__':
image = cv2.imread("./14162b58-f259-4833-98cb-89b9b496b251.jfif")
a = get_keypoint_result(image, "up")
new_list = []
print(list)
for i in a[0]:
new_list.append((int(i[0]), int(i[1])))
key_point_show("./14162b58-f259-4833-98cb-89b9b496b251.jfif", new_list)
# a = get_seg_result(1, image)
print(a)

View File

@@ -0,0 +1,99 @@
import redis
from app.core.config import REDIS_HOST, REDIS_PORT
class Redis(object):
"""
redis数据库操作
"""
@staticmethod
def _get_r():
host = REDIS_HOST
port = REDIS_PORT
db = 0
r = redis.StrictRedis(host, port, db)
return r
@classmethod
def write(cls, key, value, expire=None):
"""
写入键值对
"""
# 判断是否有过期时间,没有就设置默认值
if expire:
expire_in_seconds = expire
else:
expire_in_seconds = 100
r = cls._get_r()
r.set(key, value, ex=expire_in_seconds)
@classmethod
def read(cls, key):
"""
读取键值对内容
"""
r = cls._get_r()
value = r.get(key)
return value.decode('utf-8') if value else value
@classmethod
def hset(cls, name, key, value):
"""
写入hash表
"""
r = cls._get_r()
r.hset(name, key, value)
@classmethod
def hget(cls, name, key):
"""
读取指定hash表的键值
"""
r = cls._get_r()
value = r.hget(name, key)
return value.decode('utf-8') if value else value
@classmethod
def hgetall(cls, name):
"""
获取指定hash表所有的值
"""
r = cls._get_r()
return r.hgetall(name)
@classmethod
def delete(cls, *names):
"""
删除一个或者多个
"""
r = cls._get_r()
r.delete(*names)
@classmethod
def hdel(cls, name, key):
"""
删除指定hash表的键值
"""
r = cls._get_r()
r.hdel(name, key)
@classmethod
def expire(cls, name, expire=None):
"""
设置过期时间
"""
if expire:
expire_in_seconds = expire
else:
expire_in_seconds = 100
r = cls._get_r()
r.expire(name, expire_in_seconds)
if __name__ == '__main__':
redis_client = Redis()
# print(redis_client.write(key="1230", value=0))
redis_client.write(key="1230", value=10)
# print(redis_client.read(key="1230"))

View File

@@ -0,0 +1,173 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project trinity_client
@File synthesis_item.py
@Author :周成融
@Date 2023/8/26 14:13:04
@detail
"""
import io
import logging
import time
import boto3
import cv2
import numpy as np
from PIL import Image
from minio import Minio
from app.core.config import *
from app.service.utils.decorator import RunTime
from app.service.utils.generate_uuid import generate_uuid
minio_client = Minio(
MINIO_URL,
access_key=MINIO_ACCESS,
secret_key=MINIO_SECRET,
secure=MINIO_SECURE)
s3 = boto3.client(
's3',
aws_access_key_id=S3_ACCESS_KEY,
aws_secret_access_key=S3_AWS_SECRET_ACCESS_KEY,
region_name=S3_REGION_NAME
)
def positioning(all_mask_shape, mask_shape, offset):
all_start = 0
all_end = 0
mask_start = 0
mask_end = 0
if offset == 0:
all_start = 0
all_end = min(all_mask_shape, mask_shape)
mask_start = 0
mask_end = min(all_mask_shape, mask_shape)
elif offset > 0:
all_start = min(offset, all_mask_shape)
all_end = min(offset + mask_shape, all_mask_shape)
mask_start = 0
mask_end = 0 if offset > all_mask_shape else min(all_mask_shape - offset, mask_shape)
elif offset < 0:
if abs(offset) > mask_shape:
all_start = 0
all_end = 0
else:
all_start = 0
if mask_shape - abs(offset) > all_mask_shape:
all_end = min(mask_shape - abs(offset), all_mask_shape)
else:
all_end = mask_shape - abs(offset)
if abs(offset) > mask_shape:
mask_start = mask_shape
mask_end = mask_shape
else:
mask_start = abs(offset)
if mask_shape - abs(offset) >= all_mask_shape:
mask_end = all_mask_shape + abs(offset)
else:
mask_end = mask_shape
return all_start, all_end, mask_start, mask_end
@RunTime
def synthesis(data, size):
# 创建底图
base_image = Image.new('RGBA', size, (0, 0, 0, 0))
try:
all_mask_shape = (size[1], size[0])
top_outer_mask = np.zeros(all_mask_shape, dtype=np.uint8)
bottom_outer_mask = np.zeros(all_mask_shape, dtype=np.uint8)
top = True
bottom = True
i = len(data)
while i:
i -= 1
if top and data[i]['name'] in ["blouse_front", "outwear_front", "dress_front", "tops_front"]:
top = False
mask_shape = data[i]['mask'].shape
y_offset, x_offset = data[i]['position']
# 初始化叠加区域的起始和结束位置
all_y_start, all_y_end, mask_y_start, mask_y_end = positioning(all_mask_shape=all_mask_shape[0], mask_shape=mask_shape[0], offset=y_offset)
all_x_start, all_x_end, mask_x_start, mask_x_end = positioning(all_mask_shape=all_mask_shape[1], mask_shape=mask_shape[1], offset=x_offset)
# 将叠加区域赋值为相应的像素值
top_outer_mask[all_y_start:all_y_end, all_x_start:all_x_end] = data[i]['mask'][mask_y_start:mask_y_end, mask_x_start:mask_x_end]
elif bottom and data[i]['name'] in ["trousers_front", "skirt_front", "bottoms_front"]:
bottom = False
mask_shape = data[i]['mask'].shape
y_offset, x_offset = data[i]['position']
# 初始化叠加区域的起始和结束位置
all_y_start, all_y_end, mask_y_start, mask_y_end = positioning(all_mask_shape=all_mask_shape[0], mask_shape=mask_shape[0], offset=y_offset)
all_x_start, all_x_end, mask_x_start, mask_x_end = positioning(all_mask_shape=all_mask_shape[1], mask_shape=mask_shape[1], offset=x_offset)
# 将叠加区域赋值为相应的像素值
bottom_outer_mask[all_y_start:all_y_end, all_x_start:all_x_end] = data[i]['mask'][mask_y_start:mask_y_end, mask_x_start:mask_x_end]
elif bottom is False and top is False:
break
all_mask = cv2.bitwise_or(top_outer_mask, bottom_outer_mask)
for layer in data:
if layer['image'] is not None:
if layer['name'] != "body":
test_image = Image.new('RGBA', size, (0, 0, 0, 0))
test_image.paste(layer['image'], (layer['position'][1], layer['position'][0]), layer['image'])
mask_data = np.where(all_mask > 0, 255, 0).astype(np.uint8)
mask_alpha = Image.fromarray(mask_data)
cropped_image = Image.composite(test_image, Image.new("RGBA", test_image.size, (255, 255, 255, 0)), mask_alpha)
base_image.paste(cropped_image, (0, 0), cropped_image)
else:
base_image.paste(layer['image'], (layer['position'][1], layer['position'][0]), layer['image'])
result_image = base_image
with io.BytesIO() as output:
result_image.save(output, format='PNG')
data = output.getvalue()
image_data = io.BytesIO()
result_image.save(image_data, format='PNG')
image_data.seek(0)
image_bytes = image_data.read()
return f"aida-results/{minio_client.put_object('aida-results', f'result_{generate_uuid()}.png', io.BytesIO(image_bytes), len(image_bytes), content_type='image/png').object_name}"
# object_name = f'result_{generate_uuid()}.png'
# response = s3.put_object(Bucket="aida-results", Key=object_name, Body=data, ContentType='image/png')
# object_url = f"aida-results/{object_name}"
# if response['ResponseMetadata']['HTTPStatusCode'] == 200:
# return object_url
# else:
# return ""
except Exception as e:
logging.warning(f"synthesis runtime exception : {e}")
def synthesis_single(front_image, back_image):
result_image = None
if front_image:
result_image = front_image
if back_image:
result_image.paste(back_image, (0, 0), back_image)
# with io.BytesIO() as output:
# result_image.save(output, format='PNG')
# data = output.getvalue()
# object_name = f'result_{generate_uuid()}.png'
# response = s3.put_object(Bucket="aida-results", Key=object_name, Body=data, ContentType='image/png')
# object_url = f"aida-results/{object_name}"
# if response['ResponseMetadata']['HTTPStatusCode'] == 200:
# return object_url
# else:
# return ""
image_data = io.BytesIO()
result_image.save(image_data, format='PNG')
image_data.seek(0)
image_bytes = image_data.read()
return f"aida-results/{minio_client.put_object('aida-results', f'result_{generate_uuid()}.png', io.BytesIO(image_bytes), len(image_bytes), content_type='image/png').object_name}"

View File

@@ -0,0 +1,155 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project trinity_client
@File upload_image.py
@Author :周成融
@Date 2023/8/28 13:49:20
@detail
"""
import io
import logging
import time
import boto3
import cv2
from minio import Minio
from app.core.config import *
from app.service.utils.decorator import RunTime
minio_client = Minio(
f"{MINIO_URL}",
access_key=MINIO_ACCESS,
secret_key=MINIO_SECRET,
secure=MINIO_SECURE)
"""S3 上传"""
s3 = boto3.client('s3', aws_access_key_id=S3_ACCESS_KEY, aws_secret_access_key=S3_AWS_SECRET_ACCESS_KEY, region_name=S3_REGION_NAME)
@RunTime
def upload_png_mask(front_image, object_name, mask=None):
mask_url = None
if mask is not None:
# 反转掩模
mask_inverted = cv2.bitwise_not(mask)
# 将掩模转换为 RGBA 格式
rgba_image = cv2.cvtColor(mask_inverted, cv2.COLOR_BGR2BGRA)
rgba_image[rgba_image[:, :, 0] == 0] = [0, 0, 0, 0]
# 将图像数据保存到内存中的 BytesIO 对象中
image_bytes = io.BytesIO()
image_bytes.write(cv2.imencode('.png', rgba_image)[1].tobytes())
image_bytes.seek(0)
try:
key = f"mask/mask_{object_name}.png"
mask_url = f"{AIDA_CLOTHING}/{key}"
s3.put_object(Bucket=AIDA_CLOTHING, Key=key, Body=image_bytes, ContentType='image/png')
except Exception as e:
print(f'上传到 S3 失败: {e}')
with io.BytesIO() as output:
front_image.save(output, format='PNG')
data = output.getvalue()
# 创建一个 S3 客户端
try:
key = f"image/image_{object_name}.png"
image_url = f"{AIDA_CLOTHING}/{key}"
s3.put_object(Bucket=AIDA_CLOTHING, Key=key, Body=data, ContentType='image/png')
return front_image, image_url, mask_url
except Exception as e:
print(f'上传到 S3 失败: {e}')
@RunTime
def upload_layer_image(image, object_name):
with io.BytesIO() as output:
image.save(output, format='PNG')
data = output.getvalue()
# 创建一个 S3 客户端
try:
key = f"image/image_{object_name}.png"
image_url = f"{AIDA_CLOTHING}/{key}"
s3.put_object(Bucket=AIDA_CLOTHING, Key=key, Body=data, ContentType='image/png')
return image_url
except Exception as e:
print(f'上传到 S3 失败: {e}')
@RunTime
def upload_mask_image(mask, object_name):
# 反转掩模
mask_inverted = cv2.bitwise_not(mask)
# 将掩模转换为 RGBA 格式
rgba_image = cv2.cvtColor(mask_inverted, cv2.COLOR_BGR2BGRA)
rgba_image[rgba_image[:, :, 0] == 0] = [0, 0, 0, 0]
# 将图像数据保存到内存中的 BytesIO 对象中
image_bytes = io.BytesIO()
image_bytes.write(cv2.imencode('.png', rgba_image)[1].tobytes())
image_bytes.seek(0)
try:
key = f"mask/mask_{object_name}.png"
mask_url = f"{AIDA_CLOTHING}/{key}"
s3.put_object(Bucket=AIDA_CLOTHING, Key=key, Body=image_bytes, ContentType='image/png')
return mask_url
except Exception as e:
print(f'上传到 S3 失败: {e}')
"""minio 上传"""
# @RunTime
# def upload_png_mask(front_image, object_name, mask=None):
# start_time = time.time()
# try:
# mask_url = None
# if mask is not None:
# mask_inverted = cv2.bitwise_not(mask)
# # 将掩模的3通道转换为4通道白色部分不透明黑色部分透明
# rgba_image = cv2.cvtColor(mask_inverted, cv2.COLOR_BGR2BGRA)
# rgba_image[rgba_image[:, :, 0] == 0] = [0, 0, 0, 0]
# image_bytes = io.BytesIO()
# image_bytes.write(cv2.imencode('.png', rgba_image)[1].tobytes())
#
# image_bytes.seek(0)
# mask_url = f"{AIDA_CLOTHING}/{minio_client.put_object('aida-clothing', f'mask/mask_{object_name}.png', image_bytes, len(image_bytes.getvalue()), content_type='image/png').object_name}"
#
# image_data = io.BytesIO()
# front_image.save(image_data, format='PNG')
# image_data.seek(0)
# image_bytes = image_data.read()
# image_url = f"{AIDA_CLOTHING}/{minio_client.put_object('aida-clothing', f'image/image_{object_name}.png', io.BytesIO(image_bytes), len(image_bytes), content_type='image/png').object_name}"
# # print(f"upload_png_mask {object_name} = {time.time() - start_time}")
# return front_image, image_url, mask_url
# except Exception as e:
# logging.warning(f"upload_png_mask runtime exception : {e}")
#
#
# @RunTime
# def upload_layer_image(image, object_name):
# try:
# image_data = io.BytesIO()
# image.save(image_data, format='PNG')
# image_data.seek(0)
# image_bytes = image_data.read()
# image_url = f"{AIDA_CLOTHING}/{minio_client.put_object('aida-clothing', f'image/image_{object_name}.png', io.BytesIO(image_bytes), len(image_bytes), content_type='image/png').object_name}"
# return image_url
# except Exception as e:
# logging.warning(f"upload_png_mask runtime exception : {e}")
#
#
# @RunTime
# def upload_mask_image(mask, object_name):
# try:
# mask_inverted = cv2.bitwise_not(mask)
# # 将掩模的3通道转换为4通道白色部分不透明黑色部分透明
# rgba_image = cv2.cvtColor(mask_inverted, cv2.COLOR_BGR2BGRA)
# rgba_image[rgba_image[:, :, 0] == 0] = [0, 0, 0, 0]
# image_bytes = io.BytesIO()
# image_bytes.write(cv2.imencode('.png', rgba_image)[1].tobytes())
#
# image_bytes.seek(0)
# mask_url = f"{AIDA_CLOTHING}/{minio_client.put_object('aida-clothing', f'mask/mask_{object_name}.png', image_bytes, len(image_bytes.getvalue()), content_type='image/png').object_name}"
# return mask_url
# except Exception as e:
# logging.warning(f"upload_png_mask runtime exception : {e}")

View File

@@ -0,0 +1,320 @@
import logging
import time
import cv2
import numpy as np
import torch
from minio import Minio
from pymilvus import connections, Collection
from urllib3.exceptions import ResponseError
import torch.nn.functional as F
import tritonclient.grpc as grpcclient
import io
from app.core.config import *
from app.service.design.utils.design_ensemble import get_keypoint_result
class DesignPreprocessing:
def __init__(self):
self.minio_client = Minio(
MINIO_URL,
access_key=MINIO_ACCESS,
secret_key=MINIO_SECRET,
secure=MINIO_SECURE)
# @ RunTime
def pipeline(self, image_list):
sketches_list = self.read_image(image_list)
logging.info("read image success")
bounding_box_sketches_list = self.bounding_box(sketches_list)
logging.info("bounding box image success")
super_resolution_list = self.super_resolution(bounding_box_sketches_list)
logging.info("super_resolution_list image success")
infer_sketches_list = self.infer_image(super_resolution_list)
logging.info("infer image success")
result = self.composing_image(infer_sketches_list)
logging.info("Replenish white edge image success")
for d in result:
if 'image_obj' in d:
del d['image_obj']
if 'obj' in d:
del d['obj']
if 'keypoint_result' in d:
del d['keypoint_result']
return result
def read_image(self, image_list):
for obj in image_list:
file = self.minio_client.get_object(obj['image_url'].split("/", 1)[0], obj['image_url'].split("/", 1)[1]).data
image = cv2.imdecode(np.frombuffer(file, np.uint8), 1)
if len(image.shape) == 2:
image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
elif image.shape[2] == 4: # 如果是四通道 mask
image = image[:, :, :3]
obj["image_obj"] = image
return image_list
# @ RunTime
def bounding_box(self, image_list):
for item in image_list:
image = item['image_obj']
# 使用Canny边缘检测来检测物体的轮廓
edges = cv2.Canny(image, 50, 150)
# 查找轮廓
contours, _ = cv2.findContours(edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
# 初始化包围所有外接矩形的大矩形的坐标
x_min, y_min, x_max, y_max = float('inf'), float('inf'), -1, -1
# 遍历所有外接矩形,更新大矩形的坐标
for contour in contours:
x, y, w, h = cv2.boundingRect(contour)
x_min = min(x_min, x)
y_min = min(y_min, y)
x_max = max(x_max, x + w)
y_max = max(y_max, y + h)
if IF_DEBUG_SHOW:
image_with_big_rect = cv2.rectangle(image.copy(), (x_min, y_min), (x_max, y_max), (0, 255, 0), 2)
cv2.imshow("bounding_box image", image_with_big_rect)
cv2.waitKey(0)
# 根据大矩形的坐标来裁剪原始图像
if len(contours) > 0:
cropped_image = image[y_min:y_max, x_min:x_max]
item['obj'] = cropped_image # 新shape图像
# 取消直接覆盖新增size判断
# try:
# # 覆盖到minio
# image_bytes = cv2.imencode(".jpg", cropped_image)[1].tobytes()
# self.minio_client.put_object(item['image_url'].split("/", 1)[0], item['image_url'].split("/", 1)[1], io.BytesIO(image_bytes), len(image_bytes), content_type="image/jpeg", )
# print(f"Object '{item['image_url'].split('/', 1)[1]}' overwritten successfully.")
# except ResponseError as err:
# print(f"Error: {err}")
else:
item['obj'] = image
return image_list
def super_resolution(self, image_list):
for item in image_list:
# 判断 两边是否同时都小于512 因为此处做四倍超分
if item['obj'].shape[0] <= 512 and item['obj'].shape[1] <= 512:
# 如果任意一边小于256则超分
if item['obj'].shape[0] <= 256 or item['obj'].shape[1] <= 256:
# 超分
img = item['obj'].astype(np.float32) / 255.
sample = np.transpose(img if img.shape[2] == 1 else img[:, :, [2, 1, 0]], (2, 0, 1))
sample = torch.from_numpy(sample).float().unsqueeze(0).numpy()
inputs = [
grpcclient.InferInput("input", sample.shape, datatype="FP32")
]
inputs[0].set_data_from_numpy(sample)
triton_client = grpcclient.InferenceServerClient(url=SR_TRITON_URL)
result = triton_client.infer(model_name=SR_MODEL_NAME, inputs=inputs)
result_image = result.as_numpy(f'output')[0]
sr_output = torch.tensor(result_image)
output = sr_output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
if output.ndim == 3:
output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0)) # CHW-RGB to HCW-BGR
output = (output * 255.0).round().astype(np.uint8)
item['obj'] = output
try:
# 覆盖到minio
image_bytes = cv2.imencode(".jpg", item['obj'])[1].tobytes()
self.minio_client.put_object(item['image_url'].split("/", 1)[0], item['image_url'].split("/", 1)[1], io.BytesIO(image_bytes), len(image_bytes), content_type="image/jpeg", )
print(f"Object '{item['image_url'].split('/', 1)[1]}' overwritten successfully.")
except ResponseError as err:
print(f"Error: {err}")
return image_list
# @ RunTime
def infer_image(self, image_list):
for sketch in image_list:
# 小写
image_category = sketch['image_category'].lower()
# 判断上下装
sketch['site'] = 'up' if image_category in ['blouse', 'outwear', 'dress', 'tops'] else 'down'
# 推理得到keypoint
sketch['keypoint_result'] = self.keypoint_cache(sketch)
if IF_DEBUG_SHOW:
debug_show_image = sketch['obj'].copy()
points_list = []
point_size = 1
point_color = (0, 0, 255) # BGR
thickness = 4 # 可以为 0 、4、8
for i in sketch['keypoint_result'].values():
points_list.append((int(i[1]), int(i[0])))
for point in points_list:
cv2.circle(debug_show_image, point, point_size, point_color, thickness)
cv2.imshow("", debug_show_image)
cv2.waitKey(0)
# # 关键点在上部则推理seg
# if sketch["site"] == "up":
# # 判断seg缓存是否存在,是否与当前图片shape一致
# seg_result = self.search_seg_result(sketch["image_id"], sketch["obj"].shape)
# if seg_result is False:
# # 推理seg + 保存
# seg_result = get_seg_result(sketch['image_id'], sketch['obj'])
return image_list
# @ RunTime
def composing_image(self, image_list):
for image in image_list:
if image['site'] == 'down':
image_width = image['obj'].shape[1]
waist_width = image['keypoint_result']['waistband_right'][1] - image['keypoint_result']['waistband_left'][1]
scale = 0.4
if waist_width / scale >= image['obj'].shape[1]:
add_width = int((waist_width / scale - image_width) / 2)
ret = cv2.copyMakeBorder(image['obj'], 0, 0, add_width, add_width, cv2.BORDER_CONSTANT, value=(256, 256, 256))
if IF_DEBUG_SHOW:
cv2.imshow("composing_image", ret)
cv2.waitKey(0)
image_bytes = cv2.imencode(".jpg", ret)[1].tobytes()
image['show_image_url'] = f"{image['image_url'].split('/', 1)[0]}/{self.minio_client.put_object(image['image_url'].split('/', 1)[0], image['image_url'].split('/', 1)[1].replace('.', '-show.'), io.BytesIO(image_bytes), len(image_bytes), content_type='image/jpeg').object_name}"
else:
image_bytes = cv2.imencode(".jpg", image['obj'])[1].tobytes()
image['show_image_url'] = f"{image['image_url'].split('/', 1)[0]}/{self.minio_client.put_object(image['image_url'].split('/', 1)[0], image['image_url'].split('/', 1)[1].replace('.', '-show.'), io.BytesIO(image_bytes), len(image_bytes), content_type='image/jpeg').object_name}"
else:
scale = 0.4
image_width = image['obj'].shape[1]
waist_width = image['keypoint_result']['armpit_right'][1] - image['keypoint_result']['armpit_left'][1]
if waist_width / scale >= image_width:
add_width = int((waist_width / scale - image_width) / 2)
ret = cv2.copyMakeBorder(image['obj'], 0, 0, add_width, add_width, cv2.BORDER_CONSTANT, value=(256, 256, 256))
if IF_DEBUG_SHOW:
cv2.imshow("composing_image", ret)
cv2.waitKey(0)
image_bytes = cv2.imencode(".jpg", ret)[1].tobytes()
image['show_image_url'] = f"{image['image_url'].split('/', 1)[0]}/{self.minio_client.put_object(image['image_url'].split('/', 1)[0], image['image_url'].split('/', 1)[1].replace('.', '-show.'), io.BytesIO(image_bytes), len(image_bytes), content_type='image/jpeg').object_name}"
else:
image_bytes = cv2.imencode(".jpg", image['obj'])[1].tobytes()
image['show_image_url'] = f"{image['image_url'].split('/', 1)[0]}/{self.minio_client.put_object(image['image_url'].split('/', 1)[0], image['image_url'].split('/', 1)[1].replace('.', '-show.'), io.BytesIO(image_bytes), len(image_bytes), content_type='image/jpeg').object_name}"
return image_list
@staticmethod
def select_seg_result(image_id, image_obj):
try:
# 如果shape不匹配 返回false
result = np.load(f"seg_result/{image_id}.npy").astype(np.int64)
if result.shape[1] == image_obj.shape[0] and result.shape[2] == image_obj.shape[1]:
return result
else:
return False
except FileNotFoundError as e:
logging.warning(f"{image_id} Image segmentation results cache file does not exist : {e}")
return False
@staticmethod
def search_seg_result(image_id, ori_shape):
try:
# connections.connect(alias=MILVUS_ALIAS, host=MILVUS_DB_HOST, port=MILVUS_PORT)
# collection = Collection(MILVUS_TABLE_SEG) # Get an existing collection.
# collection.load()
# start_time = time.time()
# res = collection.query(
# expr=f"seg_id == {image_id}",
# offset=0,
# limit=10,
# output_fields=["seg_cache"],
# )
# logging.info(f"search seg cache time {time.time() - start_time}")
# if len(res):
# vector = np.reshape(res[0]['seg_cache'] + res[1]['seg_cache'], (224, 224))
# array_2d_exact = F.interpolate(torch.tensor(vector).unsqueeze(0).unsqueeze(0), size=ori_shape, mode='bilinear', align_corners=False)
# array_2d_exact = array_2d_exact.squeeze().numpy()
# return array_2d_exact
# else:
return False
except Exception as e:
logging.warning(f"{image_id} Image segmentation results cache file does not exist : {e}")
return False
def keypoint_cache(self, sketch):
try:
# connections.connect(alias=MILVUS_ALIAS, host=MILVUS_DB_HOST, port=MILVUS_PORT)
# collection = Collection(MILVUS_TABLE_KEYPOINT) # Get an existing collection.
# collection.load()
start_time = time.time()
# res = collection.query(
# expr=f"keypoint_id == {sketch['image_id']}",
# offset=0,
# limit=1,
# output_fields=["keypoint_cache", "keypoint_site"],
# )
res = []
logging.info(f"search keypoint time : {time.time() - start_time}")
if len(res) == 0:
# 没有结果 直接推理拿结果 并保存
keypoint_infer_result = self.infer_keypoint_result(sketch)
return self.save_keypoint_cache(sketch, keypoint_infer_result)
elif res[0]["keypoint_site"] == "all" or res[0]["keypoint_site"] == sketch['site']:
# 需要的类型和查询的类型一致或者查询的类型为all 则直接返回查询的结果
return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, np.array(res[0]['keypoint_vector']).astype(int).reshape(12, 2).tolist()))
elif res[0]["keypoint_site"] != sketch['site']:
# 需要的类型和查询到的不一致则更新类型为all
keypoint_infer_result = self.infer_keypoint_result(sketch)
return self.update_keypoint_cache(sketch, keypoint_infer_result, res[0]['keypoint_vector'])
except Exception as e:
logging.info(f"search keypoint cache milvus error {e}")
return False
# @ RunTime
def infer_keypoint_result(self, sketch):
keypoint_infer_result = get_keypoint_result(sketch["obj"], sketch['site']) # 推理结果
return keypoint_infer_result
@staticmethod
# @ RunTime
def save_keypoint_cache(sketch, keypoint_infer_result):
if sketch['site'] == "down":
zeros = np.zeros(20, dtype=int)
result = np.concatenate([zeros, keypoint_infer_result.flatten()])
else:
zeros = np.zeros(4, dtype=int)
result = np.concatenate([keypoint_infer_result.flatten(), zeros])
data = [
[int(sketch['image_id'])],
[sketch['site']],
[result.tolist()]
]
try:
# connections.connect(alias=MILVUS_ALIAS, host=MILVUS_DB_HOST, port=MILVUS_PORT)
start_time = time.time()
# collection = Collection(MILVUS_TABLE_KEYPOINT) # Get an existing collection.
# mr = collection.insert(data)
# logging.info(f"save keypoint time : {time.time() - start_time}")
return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
except Exception as e:
logging.info(f"save keypoint cache milvus error : {e}")
return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
@staticmethod
def update_keypoint_cache(sketch, infer_result, search_result):
if sketch['site'] == "up":
# 需要的是up 即推理出来的是up 那么查询的就是down
result = np.concatenate([infer_result.flatten(), search_result[-4:]])
else:
# 需要的是down 即推理出来的是down 那么查询的就是up
result = np.concatenate([search_result[:20], infer_result.flatten()])
data = [
[int(sketch['image_id'])],
["all"],
[result.tolist()]
]
try:
# connections.connect(alias=MILVUS_ALIAS, host=MILVUS_DB_HOST, port=MILVUS_PORT)
start_time = time.time()
# collection = Collection(MILVUS_TABLE_KEYPOINT) # Get an existing collection.
# mr = collection.upsert(data)
# logging.info(f"save keypoint time : {time.time() - start_time}")
return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
except Exception as e:
logging.info(f"save keypoint cache milvus error : {e}")
return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))

View File

@@ -0,0 +1,175 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project trinity_client
@File service_att_recognition.py
@Author :周成融
@Date 2023/7/26 12:01:05
@detail
"""
import io
import json
import logging
import time
import cv2
import redis
import tritonclient.grpc as grpcclient
import numpy as np
from PIL import Image, ImageOps
from minio import Minio
from tritonclient.utils import np_to_triton_dtype
from app.core.config import *
from app.schemas.generate_image import GenerateImageModel
from app.service.generate_image.utils.upload_sd_image import upload_SDXL_image
logger = logging.getLogger()
class GenerateProductImage:
def __init__(self, request_data):
if DEBUG is False:
self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
self.channel = self.connection.channel()
# self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
# self.channel = self.connection.channel()
self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
self.grpc_client = grpcclient.InferenceServerClient(url=GPI_MODEL_URL)
self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
self.category = "product_image"
self.batch_size = 1
self.prompt = request_data.prompt
# TODO aida design 结果图背景改为白色
self.image, self.image_size = self.get_image(request_data.image_url)
# TODO image 填充并resize成512*768
self.tasks_id = request_data.tasks_id
self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:]
self.gen_product_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'image_url': ''}
self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data))
self.redis_client.expire(self.tasks_id, 600)
def get_image(self, image_url):
response = self.minio_client.get_object(image_url.split('/')[0], image_url[image_url.find('/') + 1:])
image_bytes = io.BytesIO(response.read())
# 转换为PIL图像对象
image = Image.open(image_bytes)
target_height = 768
target_width = 512
aspect_ratio = image.width / image.height
new_width = int(target_height * aspect_ratio)
resized_image = image.resize((new_width, target_height))
left = (target_width - resized_image.width) // 2
top = (target_height - resized_image.height) // 2
right = target_width - resized_image.width - left
bottom = target_height - resized_image.height - top
image = ImageOps.expand(resized_image, (left, top, right, bottom), fill="white")
image_size = image.size
if image.mode in ('RGBA', 'LA') or (image.mode == 'P' and 'transparency' in image.info):
# 创建白色背景
background = Image.new("RGB", image.size, (255, 255, 255))
# 将图片粘贴到白色背景上
background.paste(image, mask=image.split()[3])
image = np.array(background)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# image_file = BytesIO(response.data)
# image_array = np.asarray(bytearray(image_file.read()), dtype=np.uint8)
# image_cv2 = cv2.imdecode(image_array, cv2.IMREAD_COLOR)
# image = cv2.cvtColor(image_cv2, cv2.COLOR_BGR2RGB)
# image = cv2.resize(image_rbg, (1024, 1024))
return image, image_size
def callback(self, result, error):
if error:
self.gen_product_data['status'] = "FAILURE"
self.gen_product_data['message'] = str(error)
# self.gen_product_data['data'] = str(error)
self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data))
else:
# pil图像转成numpy数组
image = result.as_numpy("generated_inpaint_image")
image_result = Image.fromarray(np.squeeze(image.astype(np.uint8))).resize(self.image_size)
image_url = upload_SDXL_image(image_result, user_id=self.user_id, category=f"{self.category}", object_name=f"{self.tasks_id}.png")
# logger.info(f"upload image SUCCESS {image_url}")
self.gen_product_data['status'] = "SUCCESS"
self.gen_product_data['message'] = "success"
self.gen_product_data['image_url'] = str(image_url)
self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data))
def read_tasks_status(self):
status_data = self.redis_client.get(self.tasks_id)
return json.loads(status_data), status_data
def infer(self, inputs):
return self.grpc_client.async_infer(
model_name=GPI_MODEL_NAME,
inputs=inputs,
callback=self.callback
)
def get_result(self):
try:
prompts = [self.prompt] * self.batch_size
self.image = cv2.cvtColor(self.image, cv2.COLOR_BGR2RGB)
self.image = cv2.resize(self.image, (512, 768))
images = [self.image.astype(np.uint8)] * self.batch_size
text_obj = np.array(prompts, dtype="object").reshape(1)
image_obj = np.array(images, dtype=np.uint8).reshape((768, 512, 3))
input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype))
input_image = grpcclient.InferInput("input_image", image_obj.shape, "UINT8")
input_text.set_data_from_numpy(text_obj)
input_image.set_data_from_numpy(image_obj)
inputs = [input_text, input_image]
ctx = self.infer(inputs)
time_out = 600
while time_out > 0:
gen_product_data, _ = self.read_tasks_status()
# logger.info(gen_product_data)
if gen_product_data['status'] in ["REVOKED", "FAILURE"]:
ctx.cancel()
break
elif gen_product_data['status'] == "SUCCESS":
break
time_out -= 1
time.sleep(0.1)
# logger.info(time_out, gen_product_data)
gen_product_data, _ = self.read_tasks_status()
return gen_product_data
except Exception as e:
self.gen_product_data['status'] = "FAILURE"
self.gen_product_data['message'] = str(e)
self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data))
raise Exception(str(e))
finally:
dict_gen_product_data, str_gen_product_data = self.read_tasks_status()
if DEBUG is False:
self.channel.basic_publish(exchange='', routing_key=GPI_RABBITMQ_QUEUES, body=str_gen_product_data)
# self.channel.basic_publish(exchange='', routing_key=GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES, body=str_gen_product_data)
logger.info(f" [x] Sent to {GPI_RABBITMQ_QUEUES} data@@@@ {json.dumps(dict_gen_product_data, indent=4)}")
def infer_cancel(tasks_id):
redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
data = {'tasks_id': tasks_id, 'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'}
gen_product_data = json.dumps(data)
redis_client.set(tasks_id, gen_product_data)
return data
if __name__ == '__main__':
rd = GenerateImageModel(
tasks_id="123-89",
prompt="best quality, masterpiece. detailed, high-res, simple background, studio photography, extremely detailed, updo, detailed face, face, close-up, HDR, UHD, 8K realistic, Highly detailed, simple background, Studio lighting",
image_url="aida-results/result_067f2f7e-21ba-11ef-8cf5-0242ac170002.png",
)
server = GenerateProductImage(rd)
print(server.get_result())

View File

@@ -0,0 +1,202 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project trinity_client
@File service_att_recognition.py
@Author :周成融
@Date 2023/7/26 12:01:05
@detail
"""
import io
import json
import logging
import time
import cv2
import redis
import tritonclient.grpc as grpcclient
import numpy as np
from PIL import Image, ImageOps
from minio import Minio
from tritonclient.utils import np_to_triton_dtype
from app.core.config import *
from app.schemas.generate_image import GenerateRelightImageModel
from app.service.generate_image.utils.upload_sd_image import upload_SDXL_image
logger = logging.getLogger()
class GenerateRelightImage:
def __init__(self, request_data):
if DEBUG is False:
self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
self.channel = self.connection.channel()
self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
self.grpc_client = grpcclient.InferenceServerClient(url=GRI_MODEL_URL)
self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
self.category = "relight_image"
self.batch_size = 1
self.prompt = request_data.prompt
self.seed = "12345"
# TODO aida design 结果图背景改为白色
# self.image, self.image_size = self.get_image(request_data.image_url)
self.image = request_data.image_url
# TODO image 填充并resize成512*768
self.tasks_id = request_data.tasks_id
self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:]
self.gen_product_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'image_url': ''}
self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data))
self.redis_client.expire(self.tasks_id, 600)
def get_image(self, image_url):
response = self.minio_client.get_object(image_url.split('/')[0], image_url[image_url.find('/') + 1:])
image_bytes = io.BytesIO(response.read())
# 转换为PIL图像对象
image = Image.open(image_bytes)
target_height = 768
target_width = 512
aspect_ratio = image.width / image.height
new_width = int(target_height * aspect_ratio)
resized_image = image.resize((new_width, target_height))
left = (target_width - resized_image.width) // 2
top = (target_height - resized_image.height) // 2
right = target_width - resized_image.width - left
bottom = target_height - resized_image.height - top
image = ImageOps.expand(resized_image, (left, top, right, bottom), fill="white")
image_size = image.size
if image.mode in ('RGBA', 'LA') or (image.mode == 'P' and 'transparency' in image.info):
# 创建白色背景
background = Image.new("RGB", image.size, (255, 255, 255))
# 将图片粘贴到白色背景上
background.paste(image, mask=image.split()[3])
image = np.array(background)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# image_file = BytesIO(response.data)
# image_array = np.asarray(bytearray(image_file.read()), dtype=np.uint8)
# image_cv2 = cv2.imdecode(image_array, cv2.IMREAD_COLOR)
# image = cv2.cvtColor(image_cv2, cv2.COLOR_BGR2RGB)
# image = cv2.resize(image_rbg, (1024, 1024))
return image, image_size
def callback(self, result, error):
if error:
self.gen_product_data['status'] = "FAILURE"
self.gen_product_data['message'] = str(error)
# self.gen_product_data['data'] = str(error)
self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data))
else:
# pil图像转成numpy数组
image = result.as_numpy("generated_inpaint_image")
image_result = Image.fromarray(np.squeeze(image.astype(np.uint8))).resize(self.image_size)
image_url = upload_SDXL_image(image_result, user_id=self.user_id, category=f"{self.category}", object_name=f"{self.tasks_id}.png")
# logger.info(f"upload image SUCCESS {image_url}")
self.gen_product_data['status'] = "SUCCESS"
self.gen_product_data['message'] = "success"
self.gen_product_data['image_url'] = str(image_url)
self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data))
def read_tasks_status(self):
status_data = self.redis_client.get(self.tasks_id)
return json.loads(status_data), status_data
def infer(self, inputs):
return self.grpc_client.async_infer(
model_name=GRI_MODEL_NAME,
inputs=inputs,
callback=self.callback
)
def get_result(self):
try:
direction = "Right Light"
negative_prompt = 'lowres, bad anatomy, bad hands, cropped, worst quality'
self.prompt = 'beautiful woman, detailed face, sunshine, outdoor, warm atmosphere'
prompts = [self.prompt] * self.batch_size
text_obj = np.array(prompts, dtype="object").reshape((-1, 1))
input_text = grpcclient.InferInput(
"prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype)
)
input_text.set_data_from_numpy(text_obj)
negative_prompts = [negative_prompt] * self.batch_size
text_obj_neg = np.array(negative_prompts, dtype="object").reshape((-1, 1))
input_text_neg = grpcclient.InferInput(
"negative_prompt", text_obj_neg.shape, np_to_triton_dtype(text_obj_neg.dtype)
)
input_text_neg.set_data_from_numpy(text_obj_neg)
seed = np.array(self.seed, dtype="object").reshape((-1, 1))
input_seed = grpcclient.InferInput(
"seed", seed.shape, np_to_triton_dtype(seed.dtype)
)
input_seed.set_data_from_numpy(seed)
input_images = [self.image] * self.batch_size
text_obj_images = np.array(input_images, dtype="object").reshape((-1, 1))
input_input_images = grpcclient.InferInput(
"input_image", text_obj_images.shape, np_to_triton_dtype(text_obj_images.dtype)
)
input_input_images.set_data_from_numpy(text_obj_images)
directions = [direction] * self.batch_size
text_obj_directions = np.array(directions, dtype="object").reshape((-1, 1))
input_directions = grpcclient.InferInput(
"direction", text_obj_directions.shape, np_to_triton_dtype(text_obj_directions.dtype)
)
input_directions.set_data_from_numpy(text_obj_directions)
output_img = grpcclient.InferRequestedOutput("generated_image")
request_start = time.time()
inputs = [input_text, input_text_neg, input_input_images, input_seed, input_directions]
ctx = self.infer(inputs)
time_out = 600
while time_out > 0:
gen_product_data, _ = self.read_tasks_status()
# logger.info(gen_product_data)
if gen_product_data['status'] in ["REVOKED", "FAILURE"]:
ctx.cancel()
break
elif gen_product_data['status'] == "SUCCESS":
break
time_out -= 1
time.sleep(0.1)
# logger.info(time_out, gen_product_data)
gen_product_data, _ = self.read_tasks_status()
return gen_product_data
except Exception as e:
self.gen_product_data['status'] = "FAILURE"
self.gen_product_data['message'] = str(e)
self.redis_client.set(self.tasks_id, json.dumps(self.gen_product_data))
raise Exception(str(e))
finally:
dict_gen_product_data, str_gen_product_data = self.read_tasks_status()
if DEBUG is False:
self.channel.basic_publish(exchange='', routing_key=GPI_RABBITMQ_QUEUES, body=str_gen_product_data)
# self.channel.basic_publish(exchange='', routing_key=GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES, body=str_gen_product_data)
logger.info(f" [x] Sent to {GPI_RABBITMQ_QUEUES} data@@@@ {json.dumps(dict_gen_product_data, indent=4)}")
def infer_cancel(tasks_id):
redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
data = {'tasks_id': tasks_id, 'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'}
gen_product_data = json.dumps(data)
redis_client.set(tasks_id, gen_product_data)
return data
if __name__ == '__main__':
rd = GenerateRelightImageModel(
tasks_id="123-89",
prompt="beautiful woman, detailed face, sunshine, outdoor, warm atmosphere",
image_url="/workspace/i3.png",
)
server = GenerateRelightImage(rd)
print(server.get_result())

View File

@@ -0,0 +1,137 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project trinity_client
@File service_att_recognition.py
@Author :周成融
@Date 2023/7/26 12:01:05
@detail
"""
import json
import logging
import time
import cv2
import numpy as np
import redis
from PIL import Image
from minio import Minio
from tritonclient.utils import np_to_triton_dtype
from app.core.config import *
import tritonclient.grpc as grpcclient
from app.schemas.generate_image import GenerateSingleLogoImageModel
from app.service.generate_image.utils.upload_sd_image import upload_png_sd, upload_SDXL_image
logger = logging.getLogger()
class GenerateSingleLogoImage:
def __init__(self, request_data):
if DEBUG is False:
self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
self.channel = self.connection.channel()
# self.connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
# self.channel = self.connection.channel()
self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
self.grpc_client = grpcclient.InferenceServerClient(url=GSL_MODEL_URL)
self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
self.batch_size = 1
self.category = "single_logo"
self.negative_prompts = "bad, ugly"
self.seed = request_data.seed
self.tasks_id = request_data.tasks_id
self.prompt = request_data.prompt
self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:]
self.gen_single_logo_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'image_url': ''}
self.redis_client.set(self.tasks_id, json.dumps(self.gen_single_logo_data))
self.redis_client.expire(self.tasks_id, 600)
def read_tasks_status(self):
status_data = self.redis_client.get(self.tasks_id)
return json.loads(status_data), status_data
def infer(self, inputs):
return self.grpc_client.async_infer(
model_name=GSL_MODEL_NAME,
inputs=inputs,
callback=self.callback
)
def callback(self, result, error):
if error:
self.gen_single_logo_data['status'] = "FAILURE"
self.gen_single_logo_data['message'] = str(error)
# self.generate_data['data'] = str(error)
self.redis_client.set(self.tasks_id, json.dumps(self.gen_single_logo_data))
else:
image = result.as_numpy("generated_image")
image_result = Image.fromarray(np.squeeze(image.astype(np.uint8)))
image_url = upload_SDXL_image(image_result, user_id=self.user_id, category=f"{self.category}", object_name=f"{self.tasks_id}.png")
self.gen_single_logo_data['status'] = "SUCCESS"
self.gen_single_logo_data['message'] = "success"
self.gen_single_logo_data['image_url'] = str(image_url)
self.redis_client.set(self.tasks_id, json.dumps(self.gen_single_logo_data))
def get_result(self):
try:
# prompt
prompts = [self.prompt] * self.batch_size
text_obj = np.array(prompts, dtype="object").reshape((-1, 1))
input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype))
input_text.set_data_from_numpy(text_obj)
# negative_prompts
text_obj_neg = np.array(self.negative_prompts, dtype="object").reshape((-1, 1))
# print('text obj neg: ', text_obj_neg)
input_text_neg = grpcclient.InferInput("negative_prompt", text_obj_neg.shape, np_to_triton_dtype(text_obj_neg.dtype))
input_text_neg.set_data_from_numpy(text_obj_neg)
# seed
seed = np.array(self.seed, dtype="object").reshape((-1, 1))
input_seed = grpcclient.InferInput("seed", seed.shape, np_to_triton_dtype(seed.dtype))
input_seed.set_data_from_numpy(seed)
inputs = [input_text, input_text_neg, input_seed]
ctx = self.infer(inputs)
time_out = 600
generate_data = None
while time_out > 0:
generate_data, _ = self.read_tasks_status()
# logger.info(generate_data)
if generate_data['status'] in ["REVOKED", "FAILURE"]:
ctx.cancel()
break
elif generate_data['status'] == "SUCCESS":
break
time_out -= 1
time.sleep(0.1)
# logger.info(time_out, generate_data)
return generate_data
except Exception as e:
raise Exception(str(e))
finally:
dict_generate_data, str_generate_data = self.read_tasks_status()
if DEBUG is False:
self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=str_generate_data)
# self.channel.basic_publish(exchange='', routing_key=GI_RABBITMQ_QUEUES, body=str_generate_data)
logger.info(f" [x] Sent {json.dumps(dict_generate_data, indent=4)}")
def infer_cancel(tasks_id):
redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
data = {'tasks_id': tasks_id, 'status': 'REVOKED', 'message': "revoked", 'data': 'revoked'}
generate_data = json.dumps(data)
redis_client.set(tasks_id, generate_data)
return data
if __name__ == '__main__':
rd = GenerateSingleLogoImageModel(
tasks_id="123-89",
prompt='an apple',
seed="2",
)
server = GenerateSingleLogoImage(rd)
print(server.get_result())

View File

@@ -382,7 +382,7 @@ if __name__ == '__main__':
remove_bg_img = remove_background(luminance)
# cv2.imwrite("remove_bg_img.png", remove_bg_img)
print(1)
# print(1)
cv2.imshow("source", img)
cv2.imshow("levels", equAuto)
cv2.imshow("luminance", luminance)

View File

@@ -10,6 +10,7 @@
import io
import logging
import boto3
import cv2
from PIL import Image
from minio import Minio
@@ -17,6 +18,39 @@ from minio import Minio
from app.core.config import *
minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
s3 = boto3.client('s3', aws_access_key_id=S3_ACCESS_KEY, aws_secret_access_key=S3_AWS_SECRET_ACCESS_KEY, region_name=S3_REGION_NAME)
# def upload_single_logo(image, user_id, category, object_name):
# with io.BytesIO() as output:
# image.save(output, format='PNG')
# data = output.getvalue()
# # 创建一个 S3 客户端
# try:
# key = f'{user_id}/{category}/{object_name}'
# image_url = f"{AIDA_CLOTHING}/{key}"
# s3.put_object(Bucket=GSL_MINIO_BUCKET, Key=key, Body=data, ContentType='image/png')
# return image_url
# except Exception as e:
# print(f'上传到 S3 失败: {e}')
def upload_SDXL_image(image, user_id, category, object_name):
try:
image_data = io.BytesIO()
image.save(image_data, format='PNG')
image_data.seek(0)
image_bytes = image_data.read()
minio_req = minio_client.put_object(
GI_MINIO_BUCKET,
f'{user_id}/{category}/{object_name}',
io.BytesIO(image_bytes),
len(image_bytes),
content_type='image/jpeg'
)
image_url = f"aida-users/{minio_req.object_name}"
return image_url
except Exception as e:
logging.warning(f"upload_png_mask runtime exception : {e}")
def upload_png_sd(image, user_id, category, object_name):

View File

@@ -0,0 +1,70 @@
import os
from langchain.chains import LLMChain
from langchain.chat_models import ChatOpenAI
from langchain_core.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate, \
PromptTemplate
from app.core.config import OPENAI_MODEL, OPENAI_API_KEY
# os.environ["http_proxy"] = "http://127.0.0.1:7890"
# os.environ["https_proxy"] = "http://127.0.0.1:7890"
llm = ChatOpenAI(model_name=OPENAI_MODEL,
openai_api_key=OPENAI_API_KEY,
temperature=0)
def translate_to_en(text):
template = (
"""You are a translation expert, proficient in various languages.
And can translate various languages into English.
Please translate to grammatically correct English regardless of the input language.
If the input is in English or numbers, check for grammatical errors. If there are no errors, output the input directly.
If there are grammatical errors, correct them and then output the sentence."""
)
system_message_prompt = SystemMessagePromptTemplate.from_template(template)
# 待翻译文本由 Human 角色输入
human_template = "User input : {text}"
human_message_prompt = HumanMessagePromptTemplate.from_template(input_variables=["text"], template=human_template)
# 使用 System 和 Human 角色的提示模板构造 ChatPromptTemplate
chat_prompt_template = ChatPromptTemplate.from_messages(
[system_message_prompt, human_message_prompt]
)
translate_chain = LLMChain(llm=llm, prompt=chat_prompt_template)
template = (
"""
Input sentence:
{translate}
1. Based on the input,adjust the input sentence to make it more suitable for prompts for generating images,
ensuring all key nouns or adjectives related to the image are retained.
2. Simplify complex sentence structures and clarify ambiguous expressions.
3. Only Output the adjusted English sentence.
Output :
"""
)
# "Based on the input sentence, extract key adjectives and nouns.Only Output extracted key words."
# 1. Check if the input sentence contains any grammatical errors. If there are errors, please correct them before proceeding.
prompt_template = PromptTemplate(input_variables=["translate"], template=template)
prompt_chain = LLMChain(llm=llm, prompt=prompt_template)
from langchain.chains import SimpleSequentialChain
overall_chain = SimpleSequentialChain(chains=[translate_chain, prompt_chain], verbose=True)
response = overall_chain.run(text)
return response
def main():
"""Main function"""
translate_to_en("生成一件运动风格的夹克,带有拉链和口袋,适合休闲穿着")
if __name__ == "__main__":
main()

Binary file not shown.