feat(新功能): brand name slogan logo 生成服务

fix(修复bug):
docs(文档变更):
refactor(重构):
test(增加测试):
This commit is contained in:
zhouchengrong
2025-03-25 15:55:52 +08:00
parent b69aadbcfb
commit d029bdb944
5 changed files with 175 additions and 11 deletions

View File

@@ -0,0 +1,104 @@
import logging
import cv2
import numpy as np
import tritonclient.grpc as grpcclient
from langchain.output_parsers import ResponseSchema, StructuredOutputParser
from langchain_community.chat_models import ChatTongyi
from langchain_core.prompts import PromptTemplate
# from langchain_openai import ChatOpenAI
from minio import Minio
from tritonclient.utils import np_to_triton_dtype
from app.core.config import GI_MODEL_URL, MINIO_URL, MINIO_ACCESS, MINIO_SECRET, MINIO_SECURE, GI_MODEL_NAME
from app.schemas.brand_dna import GenerateBrandModel
from app.service.utils.generate_uuid import generate_uuid
from app.service.utils.new_oss_client import oss_upload_image
class GenerateBrandInfo:
    """Generate a brand name, slogan and logo from a free-form user prompt.

    Workflow (see ``get_result``):
      1. Ask a Tongyi chat model for structured brand information
         (``brand_name``, ``brand_slogan``, ``brand_logo_prompt``).
      2. Send the logo prompt to the Triton-served image model named by
         ``GI_MODEL_NAME``.
      3. Encode the generated image as JPEG and upload it through
         ``oss_upload_image``.
    """

    # Bucket that receives generated logos; also the prefix of the returned path.
    LOGO_BUCKET = "aida-users"
    # Environment variable that must hold the DashScope API key.
    API_KEY_ENV = "DASHSCOPE_API_KEY"

    def __init__(self, request_data):
        """Set up storage, inference and LLM clients for one request.

        Args:
            request_data: Request object exposing ``user_id`` and ``prompt``.

        Raises:
            RuntimeError: If the DashScope API key environment variable is
                missing or empty.
        """
        import os  # local import keeps this block self-contained

        # minio client init
        self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)

        # user info init
        self.user_id = request_data.user_id
        self.category = "brand_logo"

        # generate logo init
        self.grpc_client = grpcclient.InferenceServerClient(url=GI_MODEL_URL)
        # Random placeholder fed to the "input_image" tensor.
        # NOTE(review): presumably ignored by the model in 'txt2img' mode - confirm.
        self.image = np.random.randint(0, 256, (1024, 1024, 3), dtype=np.uint8)
        self.batch_size = 1
        self.mode = 'txt2img'
        # Filled in by llm_generate_brand_info(); None until then.
        self.generate_logo_prompt = None

        # llm generate brand info init
        # The API key must never live in source control: read it from the
        # environment. The key that was previously committed here should be
        # treated as leaked and rotated.
        api_key = os.getenv(self.API_KEY_ENV)
        if not api_key:
            raise RuntimeError(f"{self.API_KEY_ENV} environment variable is not set")
        self.model = ChatTongyi(model="qwen2.5-14b-instruct", api_key=api_key)

        self.response_schemas = [
            ResponseSchema(name="brand_name", description="Brand name."),
            ResponseSchema(name="brand_slogan", description="Brand slogan."),
            ResponseSchema(name="brand_logo_prompt", description="prompt required for brand logo generation."),
        ]
        self.output_parser = StructuredOutputParser.from_response_schemas(self.response_schemas)
        self.format_instructions = self.output_parser.get_format_instructions()
        # NOTE(review): the template reads "brand namebrand slogan" (missing
        # separator). Left unchanged because editing it alters model input.
        self.prompt = PromptTemplate(
            template="你是一个时装品牌的设计师。根据用户输入提取出brand namebrand slogan,brand logo 描述。如果没有以上内容需要你根据用户输入随意发挥。随后根据brand logo 描述生成一个prompt这个prompt用于生成模型.\n{format_instructions}\n{question}",
            input_variables=["question"],
            partial_variables={"format_instructions": self.format_instructions},
        )
        self._input = self.prompt.format_prompt(question=request_data.prompt)
        self.result_data = {}

    def get_result(self):
        """Run the full pipeline and return the collected brand data.

        Returns:
            dict: The parsed LLM fields plus ``brand_logo`` (the uploaded
            object path, or ``None`` when the upload failed).
        """
        self.llm_generate_brand_info()
        self.generate_brand_logo()
        return self.result_data

    def llm_generate_brand_info(self):
        """Query the chat model and store the parsed brand fields.

        Sets ``self.result_data`` to the parsed dict and
        ``self.generate_logo_prompt`` to its ``brand_logo_prompt`` entry.
        """
        # ``invoke`` replaces the deprecated direct call ``model(messages)``.
        output = self.model.invoke(self._input.to_messages())
        brand_data = self.output_parser.parse(output.content)
        self.result_data = brand_data
        self.generate_logo_prompt = brand_data['brand_logo_prompt']

    def generate_brand_logo(self):
        """Generate the logo image via Triton and record its upload path.

        Stores the result of ``upload_logo_image`` under
        ``self.result_data['brand_logo']``.

        Raises:
            RuntimeError: If no logo prompt has been produced yet, or the
                inference result has no ``generated_image`` output.
        """
        logo_prompt = getattr(self, "generate_logo_prompt", None)
        if not logo_prompt:
            raise RuntimeError("llm_generate_brand_info() must run before generate_brand_logo()")

        batch = self.batch_size
        text_obj = np.array([logo_prompt] * batch, dtype="object").reshape((-1, 1))
        mode_obj = np.array([self.mode] * batch, dtype="object").reshape((-1, 1))
        # Shape is derived from the placeholder instead of being hard-coded twice.
        image_obj = np.array(
            [self.image.astype(np.float16)] * batch, dtype=np.float16
        ).reshape((-1, *self.image.shape))

        # Build one Triton input per tensor, in the original order.
        inputs = []
        for tensor_name, data in (("prompt", text_obj), ("input_image", image_obj), ("mode", mode_obj)):
            infer_input = grpcclient.InferInput(tensor_name, data.shape, np_to_triton_dtype(data.dtype))
            infer_input.set_data_from_numpy(data)
            inputs.append(infer_input)

        result = self.grpc_client.infer(model_name=GI_MODEL_NAME, inputs=inputs)
        image = result.as_numpy("generated_image")
        if image is None:
            raise RuntimeError("inference result has no 'generated_image' output")

        # Swap RGB -> BGR before handing the array to cv2 for encoding.
        image_result = cv2.cvtColor(np.squeeze(image.astype(np.uint8)), cv2.COLOR_RGB2BGR)
        logo_url = self.upload_logo_image(image_result, generate_uuid())
        self.result_data['brand_logo'] = logo_url

    def upload_logo_image(self, image, object_name):
        """JPEG-encode ``image`` and upload it under the user's logo folder.

        Args:
            image: Image array accepted by ``cv2.imencode``.
            object_name: Final path component of the stored object.

        Returns:
            str | None: ``"aida-users/<user_id>/<category>/<object_name>"``
            on success; ``None`` when encoding or upload fails (best-effort,
            the failure is logged rather than raised).
        """
        try:
            encoded_ok, img_byte_array = cv2.imencode('.jpg', image)
            if not encoded_ok:
                logging.warning("upload_logo_image: JPEG encoding failed")
                return None
            object_name = f'{self.user_id}/{self.category}/{object_name}'
            oss_upload_image(oss_client=self.minio_client, bucket=self.LOGO_BUCKET, object_name=object_name, image_bytes=img_byte_array)
            return f"{self.LOGO_BUCKET}/{object_name}"
        except Exception as e:
            # Best-effort boundary: keep the caller alive, but log with traceback.
            logging.warning("upload_logo_image runtime exception : %s", e, exc_info=True)
            return None
if __name__ == '__main__':
    # Manual smoke test: run the whole brand pipeline for a sample request
    # and print the resulting dict.
    demo_request = GenerateBrandModel(user_id="89", prompt="xiaomi")
    brand_info = GenerateBrandInfo(demo_request).get_result()
    print(brand_info)

View File

@@ -0,0 +1,32 @@
from dotenv import load_dotenv
from langchain.output_parsers import StructuredOutputParser, ResponseSchema
from langchain_core.prompts import PromptTemplate
from langchain_openai import ChatOpenAI
# Load environment variables from the .env file.
load_dotenv()

# Create the chat model; the model name selects which LLM is used.
model = ChatOpenAI(model="qwen2.5-14b-instruct")

# Fields we want back in the structured response.
response_schemas = [
    ResponseSchema(name="brand_name", description="Brand name."),
    ResponseSchema(name="brand_slogan", description="Brand slogan."),
    ResponseSchema(name="brand_logo_prompt", description="prompt required for brand logo generation."),
]
output_parser = StructuredOutputParser.from_response_schemas(response_schemas)
format_instructions = output_parser.get_format_instructions()

# NOTE(review): the template reads "brand namebrand slogan" (missing
# separator). Left unchanged because editing it alters model input.
prompt = PromptTemplate(
    template="你是一个时装品牌的设计师。根据用户输入提取出brand namebrand slogan,brand logo 描述。如果没有以上内容需要你根据用户输入随意发挥。随后根据brand logo 描述生成一个prompt这个prompt用于生成模型.\n{format_instructions}\n{question}",
    input_variables=["question"],
    partial_variables={"format_instructions": format_instructions},
)
_input = prompt.format_prompt(question="brand name: cat home")

# ``invoke`` replaces the deprecated direct call ``model(messages)``.
output = model.invoke(_input.to_messages())
brand_data = output_parser.parse(output.content)
def generate_logo(bucket_name, object_name, prompt):
    """Placeholder for logo generation; not implemented yet.

    All three arguments are currently ignored and the function always
    returns ``None``.
    NOTE(review): presumably meant to generate a logo from ``prompt`` and
    store it at ``bucket_name``/``object_name`` - confirm before implementing.
    """
    return None