diff --git a/app/api/api_brand_dna.py b/app/api/api_brand_dna.py index a250ead..2133ee9 100644 --- a/app/api/api_brand_dna.py +++ b/app/api/api_brand_dna.py @@ -3,16 +3,17 @@ import logging from fastapi import APIRouter, HTTPException -from app.schemas.brand_dna import BrandDnaModel +from app.schemas.brand_dna import BrandDnaModel, GenerateBrandModel from app.schemas.response_template import ResponseModel from app.service.brand_dna.service import BrandDna +from app.service.brand_dna.service_generate_brand_info import GenerateBrandInfo router = APIRouter() logger = logging.getLogger() @router.post("/seg_product") -def image2sketch(request_item: BrandDnaModel): +def seg_product(request_item: BrandDnaModel): """ 创建一个具有以下参数的请求体: - **image_url**: 提取图片url @@ -32,3 +33,27 @@ def image2sketch(request_item: BrandDnaModel): logger.warning(f"brand dna Run Exception @@@@@@:{e}") raise HTTPException(status_code=404, detail=str(e)) return ResponseModel(data=result_url) + + +@router.post("/GenerateBrand") +def GenerateBrand(request_data: GenerateBrandModel): + """ + 通过prompt 生成 brand name ,brand slogan , brand logo。 + 创建一个具有以下参数的请求体: + - **prompt**: + + 示例参数: + { + "prompt": "xiaomi", + "user_id": "89" + } + """ + try: + logger.info(f"GenerateBrand request item is : @@@@@@:{request_data}") + service = GenerateBrandInfo(request_data) + data = service.get_result() + logger.info(f"GenerateBrand response @@@@@@:{data}") + except Exception as e: + logger.warning(f"GenerateBrand Run Exception @@@@@@:{e}") + raise HTTPException(status_code=404, detail=str(e)) + return ResponseModel(data=data) diff --git a/app/core/config.py b/app/core/config.py index 662d7e2..6ac56e3 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -9,14 +9,14 @@ load_dotenv(os.path.join(BASE_DIR, '.env')) class Settings(BaseSettings): - PROJECT_NAME = os.getenv('PROJECT_NAME', 'FASTAPI BASE') - SECRET_KEY = os.getenv('SECRET_KEY', '') - API_PREFIX = '' - BACKEND_CORS_ORIGINS = ['*'] - DATABASE_URL = 
os.getenv('SQL_DATABASE_URL', '') + PROJECT_NAME: str = os.getenv('PROJECT_NAME', 'FASTAPI BASE') + SECRET_KEY: str = os.getenv('SECRET_KEY', '') + API_PREFIX: str = '' + BACKEND_CORS_ORIGINS: list[str] = ['*'] + DATABASE_URL: str = os.getenv('SQL_DATABASE_URL', '') ACCESS_TOKEN_EXPIRE_SECONDS: int = 60 * 60 * 24 * 7 # Token expired after 7 days - SECURITY_ALGORITHM = 'HS256' - LOGGING_CONFIG_FILE = os.path.join(BASE_DIR, 'logging_env.py') + SECURITY_ALGORITHM: str = 'HS256' + LOGGING_CONFIG_FILE: str = os.path.join(BASE_DIR, 'logging_env.py') OSS = "minio" @@ -32,7 +32,6 @@ else: SEG_CACHE_PATH = "/seg_cache/" RECOMMEND_PATH_PREFIX = "app/service/recommend/" - # RABBITMQ_ENV = "" # 生产环境 RABBITMQ_ENV = "-dev" # 开发环境 # RABBITMQ_ENV = "-local" # 本地测试环境 @@ -146,7 +145,6 @@ GRI_MODEL_NAME_OVERALL = 'diffusion_relight_ensemble' GRI_MODEL_NAME_SINGLE = 'stable_diffusion_1_5_relight' GRI_MODEL_URL = '10.1.1.240:10051' - # Pose Transform service config PS_RABBITMQ_QUEUES = os.getenv("PS_RABBITMQ_QUEUES", f"PoseTransform{RABBITMQ_ENV}") diff --git a/app/schemas/brand_dna.py b/app/schemas/brand_dna.py index c5ae2ab..9796195 100644 --- a/app/schemas/brand_dna.py +++ b/app/schemas/brand_dna.py @@ -4,3 +4,8 @@ from pydantic import BaseModel class BrandDnaModel(BaseModel): image_url: str is_brand_dna: bool + + +class GenerateBrandModel(BaseModel): + user_id: str + prompt: str diff --git a/app/service/brand_dna/service_generate_brand_info.py b/app/service/brand_dna/service_generate_brand_info.py new file mode 100644 index 0000000..73c1294 --- /dev/null +++ b/app/service/brand_dna/service_generate_brand_info.py @@ -0,0 +1,104 @@ +import logging + +import cv2 +import numpy as np +import tritonclient.grpc as grpcclient +from langchain.output_parsers import ResponseSchema, StructuredOutputParser +from langchain_community.chat_models import ChatTongyi +from langchain_core.prompts import PromptTemplate +# from langchain_openai import ChatOpenAI +from minio import Minio +from 
tritonclient.utils import np_to_triton_dtype + +from app.core.config import GI_MODEL_URL, MINIO_URL, MINIO_ACCESS, MINIO_SECRET, MINIO_SECURE, GI_MODEL_NAME +from app.schemas.brand_dna import GenerateBrandModel +from app.service.utils.generate_uuid import generate_uuid +from app.service.utils.new_oss_client import oss_upload_image + + +class GenerateBrandInfo: + def __init__(self, request_data): + # minio client init + self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE) + + # user info init + self.user_id = request_data.user_id + self.category = "brand_logo" + # generate logo init + self.grpc_client = grpcclient.InferenceServerClient(url=GI_MODEL_URL) + self.image = np.random.randint(0, 256, (1024, 1024, 3), dtype=np.uint8) + self.batch_size = 1 + self.mode = 'txt2img' + + # llm generate brand info init + # NOTE(review): never commit API keys; ChatTongyi reads DASHSCOPE_API_KEY from the environment when api_key is omitted. + self.model = ChatTongyi(model="qwen2.5-14b-instruct") + + self.response_schemas = [ + ResponseSchema(name="brand_name", description="Brand name."), + ResponseSchema(name="brand_slogan", description="Brand slogan."), + ResponseSchema(name="brand_logo_prompt", description="prompt required for brand logo generation.") + ] + self.output_parser = StructuredOutputParser.from_response_schemas(self.response_schemas) + self.format_instructions = self.output_parser.get_format_instructions() + self.prompt = PromptTemplate( + template="你是一个时装品牌的设计师。根据用户输入提取出brand name,brand slogan,brand logo 描述。如果没有以上内容,需要你根据用户输入随意发挥。随后根据brand logo 描述生成一个prompt,这个prompt用于生成模型.\n{format_instructions}\n{question}", + input_variables=["question"], + partial_variables={"format_instructions": self.format_instructions} + ) + self._input = self.prompt.format_prompt(question=request_data.prompt) + + self.result_data = {} + + def get_result(self): + self.llm_generate_brand_info() + self.generate_brand_logo() + return self.result_data + + def llm_generate_brand_info(self): + output = 
self.model(self._input.to_messages()) + brand_data = self.output_parser.parse(output.content) + self.result_data = brand_data + self.generate_logo_prompt = brand_data['brand_logo_prompt'] + + def generate_brand_logo(self): + prompts = [self.generate_logo_prompt] * self.batch_size + modes = [self.mode] * self.batch_size + images = [self.image.astype(np.float16)] * self.batch_size + + text_obj = np.array(prompts, dtype="object").reshape((-1, 1)) + mode_obj = np.array(modes, dtype="object").reshape((-1, 1)) + image_obj = np.array(images, dtype=np.float16).reshape((-1, 1024, 1024, 3)) + + input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype)) + input_image = grpcclient.InferInput("input_image", image_obj.shape, np_to_triton_dtype(image_obj.dtype)) + input_mode = grpcclient.InferInput("mode", mode_obj.shape, np_to_triton_dtype(mode_obj.dtype)) + + input_text.set_data_from_numpy(text_obj) + input_image.set_data_from_numpy(image_obj) + input_mode.set_data_from_numpy(mode_obj) + + inputs = [input_text, input_image, input_mode] + result = self.grpc_client.infer(model_name=GI_MODEL_NAME, inputs=inputs) + image = result.as_numpy("generated_image") + image_result = cv2.cvtColor(np.squeeze(image.astype(np.uint8)), cv2.COLOR_RGB2BGR) + logo_url = self.upload_logo_image(image_result, generate_uuid()) + self.result_data['brand_logo'] = logo_url + + def upload_logo_image(self, image, object_name): + try: + _, img_byte_array = cv2.imencode('.jpg', image) + object_name = f'{self.user_id}/{self.category}/{object_name}' + req = oss_upload_image(oss_client=self.minio_client, bucket="aida-users", object_name=object_name, image_bytes=img_byte_array) + image_url = f"aida-users/{object_name}" + return image_url + except Exception as e: + logging.warning(f"upload_logo_image runtime exception : {e}") + + +if __name__ == '__main__': + request_data = GenerateBrandModel( + user_id="89", + prompt="xiaomi" + ) + service = GenerateBrandInfo(request_data) + 
print(service.get_result()) diff --git a/app/service/brand_dna/test.py b/app/service/brand_dna/test.py new file mode 100644 index 0000000..966f76e --- /dev/null +++ b/app/service/brand_dna/test.py @@ -0,0 +1,32 @@ +from dotenv import load_dotenv +from langchain.output_parsers import StructuredOutputParser, ResponseSchema +from langchain_core.prompts import PromptTemplate +from langchain_openai import ChatOpenAI + +# 加载.env文件的环境变量 +load_dotenv() + +# 创建一个大语言模型,model指定了大语言模型的种类 +model = ChatOpenAI(model="qwen2.5-14b-instruct") + +# 想要接收的响应模式 +response_schemas = [ + ResponseSchema(name="brand_name", description="Brand name."), + ResponseSchema(name="brand_slogan", description="Brand slogan."), + ResponseSchema(name="brand_logo_prompt", description="prompt required for brand logo generation.") +] +output_parser = StructuredOutputParser.from_response_schemas(response_schemas) +format_instructions = output_parser.get_format_instructions() +prompt = PromptTemplate( + template="你是一个时装品牌的设计师。根据用户输入提取出brand name,brand slogan,brand logo 描述。如果没有以上内容,需要你根据用户输入随意发挥。随后根据brand logo 描述生成一个prompt,这个prompt用于生成模型.\n{format_instructions}\n{question}", + input_variables=["question"], + partial_variables={"format_instructions": format_instructions} +) +_input = prompt.format_prompt(question="brand name: cat home") + +output = model(_input.to_messages()) +brand_data = output_parser.parse(output.content) + + +def generate_logo(bucket_name, object_name, prompt): + pass