feat(新功能): brand name slogan logo 生成服务
fix(修复bug): docs(文档变更): refactor(重构): test(增加测试):
This commit is contained in:
@@ -3,16 +3,17 @@ import logging
|
|||||||
|
|
||||||
from fastapi import APIRouter, HTTPException
|
from fastapi import APIRouter, HTTPException
|
||||||
|
|
||||||
from app.schemas.brand_dna import BrandDnaModel
|
from app.schemas.brand_dna import BrandDnaModel, GenerateBrandModel
|
||||||
from app.schemas.response_template import ResponseModel
|
from app.schemas.response_template import ResponseModel
|
||||||
from app.service.brand_dna.service import BrandDna
|
from app.service.brand_dna.service import BrandDna
|
||||||
|
from app.service.brand_dna.service_generate_brand_info import GenerateBrandInfo
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
logger = logging.getLogger()
|
logger = logging.getLogger()
|
||||||
|
|
||||||
|
|
||||||
@router.post("/seg_product")
|
@router.post("/seg_product")
|
||||||
def image2sketch(request_item: BrandDnaModel):
|
def seg_product(request_item: BrandDnaModel):
|
||||||
"""
|
"""
|
||||||
创建一个具有以下参数的请求体:
|
创建一个具有以下参数的请求体:
|
||||||
- **image_url**: 提取图片url
|
- **image_url**: 提取图片url
|
||||||
@@ -32,3 +33,27 @@ def image2sketch(request_item: BrandDnaModel):
|
|||||||
logger.warning(f"brand dna Run Exception @@@@@@:{e}")
|
logger.warning(f"brand dna Run Exception @@@@@@:{e}")
|
||||||
raise HTTPException(status_code=404, detail=str(e))
|
raise HTTPException(status_code=404, detail=str(e))
|
||||||
return ResponseModel(data=result_url)
|
return ResponseModel(data=result_url)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/GenerateBrand")
|
||||||
|
def GenerateBrand(request_data: GenerateBrandModel):
|
||||||
|
"""
|
||||||
|
通过prompt 生成 brand name ,brand slogan , brand logo。
|
||||||
|
创建一个具有以下参数的请求体:
|
||||||
|
- **prompt**:
|
||||||
|
|
||||||
|
示例参数:
|
||||||
|
{
|
||||||
|
"prompt": "xiaomi",
|
||||||
|
"user_id": "89"
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
logger.info(f"GenerateBrand request item is : @@@@@@:{request_data}")
|
||||||
|
service = GenerateBrandInfo(request_data)
|
||||||
|
data = service.get_result()
|
||||||
|
logger.info(f"GenerateBrand response @@@@@@:{data}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"GenerateBrand Run Exception @@@@@@:{e}")
|
||||||
|
raise HTTPException(status_code=404, detail=str(e))
|
||||||
|
return ResponseModel(data=data)
|
||||||
|
|||||||
@@ -9,14 +9,14 @@ load_dotenv(os.path.join(BASE_DIR, '.env'))
|
|||||||
|
|
||||||
|
|
||||||
class Settings(BaseSettings):
|
class Settings(BaseSettings):
|
||||||
PROJECT_NAME = os.getenv('PROJECT_NAME', 'FASTAPI BASE')
|
PROJECT_NAME: str = os.getenv('PROJECT_NAME', 'FASTAPI BASE')
|
||||||
SECRET_KEY = os.getenv('SECRET_KEY', '')
|
SECRET_KEY: str = os.getenv('SECRET_KEY', '')
|
||||||
API_PREFIX = ''
|
API_PREFIX: str = ''
|
||||||
BACKEND_CORS_ORIGINS = ['*']
|
BACKEND_CORS_ORIGINS: list[str] = ['*']
|
||||||
DATABASE_URL = os.getenv('SQL_DATABASE_URL', '')
|
DATABASE_URL: str = os.getenv('SQL_DATABASE_URL', '')
|
||||||
ACCESS_TOKEN_EXPIRE_SECONDS: int = 60 * 60 * 24 * 7 # Token expired after 7 days
|
ACCESS_TOKEN_EXPIRE_SECONDS: int = 60 * 60 * 24 * 7 # Token expired after 7 days
|
||||||
SECURITY_ALGORITHM = 'HS256'
|
SECURITY_ALGORITHM: str = 'HS256'
|
||||||
LOGGING_CONFIG_FILE = os.path.join(BASE_DIR, 'logging_env.py')
|
LOGGING_CONFIG_FILE: str = os.path.join(BASE_DIR, 'logging_env.py')
|
||||||
|
|
||||||
|
|
||||||
OSS = "minio"
|
OSS = "minio"
|
||||||
@@ -32,7 +32,6 @@ else:
|
|||||||
SEG_CACHE_PATH = "/seg_cache/"
|
SEG_CACHE_PATH = "/seg_cache/"
|
||||||
RECOMMEND_PATH_PREFIX = "app/service/recommend/"
|
RECOMMEND_PATH_PREFIX = "app/service/recommend/"
|
||||||
|
|
||||||
|
|
||||||
# RABBITMQ_ENV = "" # 生产环境
|
# RABBITMQ_ENV = "" # 生产环境
|
||||||
RABBITMQ_ENV = "-dev" # 开发环境
|
RABBITMQ_ENV = "-dev" # 开发环境
|
||||||
# RABBITMQ_ENV = "-local" # 本地测试环境
|
# RABBITMQ_ENV = "-local" # 本地测试环境
|
||||||
@@ -146,7 +145,6 @@ GRI_MODEL_NAME_OVERALL = 'diffusion_relight_ensemble'
|
|||||||
GRI_MODEL_NAME_SINGLE = 'stable_diffusion_1_5_relight'
|
GRI_MODEL_NAME_SINGLE = 'stable_diffusion_1_5_relight'
|
||||||
GRI_MODEL_URL = '10.1.1.240:10051'
|
GRI_MODEL_URL = '10.1.1.240:10051'
|
||||||
|
|
||||||
|
|
||||||
# Pose Transform service config
|
# Pose Transform service config
|
||||||
|
|
||||||
PS_RABBITMQ_QUEUES = os.getenv("PS_RABBITMQ_QUEUES", f"PoseTransform{RABBITMQ_ENV}")
|
PS_RABBITMQ_QUEUES = os.getenv("PS_RABBITMQ_QUEUES", f"PoseTransform{RABBITMQ_ENV}")
|
||||||
|
|||||||
@@ -4,3 +4,8 @@ from pydantic import BaseModel
|
|||||||
class BrandDnaModel(BaseModel):
|
class BrandDnaModel(BaseModel):
|
||||||
image_url: str
|
image_url: str
|
||||||
is_brand_dna: bool
|
is_brand_dna: bool
|
||||||
|
|
||||||
|
|
||||||
|
class GenerateBrandModel(BaseModel):
|
||||||
|
user_id: str
|
||||||
|
prompt: str
|
||||||
|
|||||||
104
app/service/brand_dna/service_generate_brand_info.py
Normal file
104
app/service/brand_dna/service_generate_brand_info.py
Normal file
@@ -0,0 +1,104 @@
|
|||||||
|
import logging
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
import tritonclient.grpc as grpcclient
|
||||||
|
from langchain.output_parsers import ResponseSchema, StructuredOutputParser
|
||||||
|
from langchain_community.chat_models import ChatTongyi
|
||||||
|
from langchain_core.prompts import PromptTemplate
|
||||||
|
# from langchain_openai import ChatOpenAI
|
||||||
|
from minio import Minio
|
||||||
|
from tritonclient.utils import np_to_triton_dtype
|
||||||
|
|
||||||
|
from app.core.config import GI_MODEL_URL, MINIO_URL, MINIO_ACCESS, MINIO_SECRET, MINIO_SECURE, GI_MODEL_NAME
|
||||||
|
from app.schemas.brand_dna import GenerateBrandModel
|
||||||
|
from app.service.utils.generate_uuid import generate_uuid
|
||||||
|
from app.service.utils.new_oss_client import oss_upload_image
|
||||||
|
|
||||||
|
|
||||||
|
class GenerateBrandInfo:
|
||||||
|
def __init__(self, request_data):
|
||||||
|
# minio client init
|
||||||
|
self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
||||||
|
|
||||||
|
# user info init
|
||||||
|
self.user_id = request_data.user_id
|
||||||
|
self.category = "brand_logo"
|
||||||
|
# generate logo init
|
||||||
|
self.grpc_client = grpcclient.InferenceServerClient(url=GI_MODEL_URL)
|
||||||
|
self.image = np.random.randint(0, 256, (1024, 1024, 3), dtype=np.uint8)
|
||||||
|
self.batch_size = 1
|
||||||
|
self.mode = 'txt2img'
|
||||||
|
|
||||||
|
# llm generate brand info init
|
||||||
|
self.model = ChatTongyi(model="qwen2.5-14b-instruct", api_key="sk-7658298c6b99443c98184a5e634fe6ab")
|
||||||
|
|
||||||
|
self.response_schemas = [
|
||||||
|
ResponseSchema(name="brand_name", description="Brand name."),
|
||||||
|
ResponseSchema(name="brand_slogan", description="Brand slogan."),
|
||||||
|
ResponseSchema(name="brand_logo_prompt", description="prompt required for brand logo generation.")
|
||||||
|
]
|
||||||
|
self.output_parser = StructuredOutputParser.from_response_schemas(self.response_schemas)
|
||||||
|
self.format_instructions = self.output_parser.get_format_instructions()
|
||||||
|
self.prompt = PromptTemplate(
|
||||||
|
template="你是一个时装品牌的设计师。根据用户输入提取出brand name,brand slogan,brand logo 描述。如果没有以上内容,需要你根据用户输入随意发挥。随后根据brand logo 描述生成一个prompt,这个prompt用于生成模型.\n{format_instructions}\n{question}",
|
||||||
|
input_variables=["question"],
|
||||||
|
partial_variables={"format_instructions": self.format_instructions}
|
||||||
|
)
|
||||||
|
self._input = self.prompt.format_prompt(question=request_data.prompt)
|
||||||
|
|
||||||
|
self.result_data = {}
|
||||||
|
|
||||||
|
def get_result(self):
|
||||||
|
self.llm_generate_brand_info()
|
||||||
|
self.generate_brand_logo()
|
||||||
|
return self.result_data
|
||||||
|
|
||||||
|
def llm_generate_brand_info(self):
|
||||||
|
output = self.model(self._input.to_messages())
|
||||||
|
brand_data = self.output_parser.parse(output.content)
|
||||||
|
self.result_data = brand_data
|
||||||
|
self.generate_logo_prompt = brand_data['brand_logo_prompt']
|
||||||
|
|
||||||
|
def generate_brand_logo(self):
|
||||||
|
prompts = [self.generate_logo_prompt] * self.batch_size
|
||||||
|
modes = [self.mode] * self.batch_size
|
||||||
|
images = [self.image.astype(np.float16)] * self.batch_size
|
||||||
|
|
||||||
|
text_obj = np.array(prompts, dtype="object").reshape((-1, 1))
|
||||||
|
mode_obj = np.array(modes, dtype="object").reshape((-1, 1))
|
||||||
|
image_obj = np.array(images, dtype=np.float16).reshape((-1, 1024, 1024, 3))
|
||||||
|
|
||||||
|
input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype))
|
||||||
|
input_image = grpcclient.InferInput("input_image", image_obj.shape, np_to_triton_dtype(image_obj.dtype))
|
||||||
|
input_mode = grpcclient.InferInput("mode", mode_obj.shape, np_to_triton_dtype(mode_obj.dtype))
|
||||||
|
|
||||||
|
input_text.set_data_from_numpy(text_obj)
|
||||||
|
input_image.set_data_from_numpy(image_obj)
|
||||||
|
input_mode.set_data_from_numpy(mode_obj)
|
||||||
|
|
||||||
|
inputs = [input_text, input_image, input_mode]
|
||||||
|
result = self.grpc_client.infer(model_name=GI_MODEL_NAME, inputs=inputs)
|
||||||
|
image = result.as_numpy("generated_image")
|
||||||
|
image_result = cv2.cvtColor(np.squeeze(image.astype(np.uint8)), cv2.COLOR_RGB2BGR)
|
||||||
|
logo_url = self.upload_logo_image(image_result, generate_uuid())
|
||||||
|
self.result_data['brand_logo'] = logo_url
|
||||||
|
|
||||||
|
def upload_logo_image(self, image, object_name):
|
||||||
|
try:
|
||||||
|
_, img_byte_array = cv2.imencode('.jpg', image)
|
||||||
|
object_name = f'{self.user_id}/{self.category}/{object_name}'
|
||||||
|
req = oss_upload_image(oss_client=self.minio_client, bucket="aida-users", object_name=object_name, image_bytes=img_byte_array)
|
||||||
|
image_url = f"aida-users/{object_name}"
|
||||||
|
return image_url
|
||||||
|
except Exception as e:
|
||||||
|
logging.warning(f"upload_png_mask runtime exception : {e}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
request_data = GenerateBrandModel(
|
||||||
|
user_id="89",
|
||||||
|
prompt="xiaomi"
|
||||||
|
)
|
||||||
|
service = GenerateBrandInfo(request_data)
|
||||||
|
print(service.get_result())
|
||||||
32
app/service/brand_dna/test.py
Normal file
32
app/service/brand_dna/test.py
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
from dotenv import load_dotenv
|
||||||
|
from langchain.output_parsers import StructuredOutputParser, ResponseSchema
|
||||||
|
from langchain_core.prompts import PromptTemplate
|
||||||
|
from langchain_openai import ChatOpenAI
|
||||||
|
|
||||||
|
# 加载.env文件的环境变量
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
# 创建一个大语言模型,model指定了大语言模型的种类
|
||||||
|
model = ChatOpenAI(model="qwen2.5-14b-instruct")
|
||||||
|
|
||||||
|
# 想要接收的响应模式
|
||||||
|
response_schemas = [
|
||||||
|
ResponseSchema(name="brand_name", description="Brand name."),
|
||||||
|
ResponseSchema(name="brand_slogan", description="Brand slogan."),
|
||||||
|
ResponseSchema(name="brand_logo_prompt", description="prompt required for brand logo generation.")
|
||||||
|
]
|
||||||
|
output_parser = StructuredOutputParser.from_response_schemas(response_schemas)
|
||||||
|
format_instructions = output_parser.get_format_instructions()
|
||||||
|
prompt = PromptTemplate(
|
||||||
|
template="你是一个时装品牌的设计师。根据用户输入提取出brand name,brand slogan,brand logo 描述。如果没有以上内容,需要你根据用户输入随意发挥。随后根据brand logo 描述生成一个prompt,这个prompt用于生成模型.\n{format_instructions}\n{question}",
|
||||||
|
input_variables=["question"],
|
||||||
|
partial_variables={"format_instructions": format_instructions}
|
||||||
|
)
|
||||||
|
_input = prompt.format_prompt(question="brand name: cat home")
|
||||||
|
|
||||||
|
output = model(_input.to_messages())
|
||||||
|
brand_data = output_parser.parse(output.content)
|
||||||
|
|
||||||
|
|
||||||
|
def generate_logo(bucket_name, object_name, prompt):
|
||||||
|
pass
|
||||||
Reference in New Issue
Block a user