feat: brand DNA logo generation — replace flux2klein (Triton gRPC) with the FLUX2 HTTP image-generation service; fix: read Qwen API key from settings instead of a hard-coded value

This commit is contained in:
zcr
2026-03-23 11:21:50 +08:00
committed by zchen
parent e9ca1d301b
commit 674514ec11

View File

@@ -1,19 +1,10 @@
import logging import uuid
import httpx
import cv2
import numpy as np
import tritonclient.grpc as grpcclient
from langchain_classic.output_parsers import ResponseSchema, StructuredOutputParser from langchain_classic.output_parsers import ResponseSchema, StructuredOutputParser
from langchain_community.chat_models import ChatTongyi from langchain_community.chat_models import ChatTongyi
from langchain_core.prompts import PromptTemplate from langchain_core.prompts import PromptTemplate
from minio import Minio from minio import Minio
from tritonclient.utils import np_to_triton_dtype
from app.core.config import GI_MODEL_URL, GI_MODEL_NAME
from app.schemas.brand_dna import GenerateBrandModel from app.schemas.brand_dna import GenerateBrandModel
from app.service.utils.generate_uuid import generate_uuid
from app.service.utils.new_oss_client import oss_upload_image
from app.core.config import settings from app.core.config import settings
@@ -26,14 +17,9 @@ class GenerateBrandInfo:
# user info init # user info init
self.user_id = request_data.user_id self.user_id = request_data.user_id
self.category = "brand_logo" self.category = "brand_logo"
# generate logo init
self.grpc_client = grpcclient.InferenceServerClient(url=GI_MODEL_URL)
self.image = np.random.randint(0, 256, (1024, 1024, 3), dtype=np.uint8)
self.batch_size = 1
self.mode = 'txt2img'
# llm generate brand info init # llm generate brand info init
self.model = ChatTongyi(model="qwen2.5-14b-instruct", api_key="sk-7658298c6b99443c98184a5e634fe6ab") self.model = ChatTongyi(model="qwen2.5-14b-instruct", api_key=settings.QWEN_API_KEY)
self.response_schemas = [ self.response_schemas = [
ResponseSchema(name="brand_name", description="Brand name."), ResponseSchema(name="brand_name", description="Brand name."),
@@ -63,38 +49,20 @@ class GenerateBrandInfo:
self.generate_logo_prompt = brand_data['brand_logo_prompt'] self.generate_logo_prompt = brand_data['brand_logo_prompt']
def generate_brand_logo(self):
    """Generate a brand logo image via the FLUX2 image-generation service.

    Builds a request from the LLM-produced logo prompt, POSTs it to the
    ``/predict`` endpoint of the FLUX2 model service, and stores the
    returned object path under ``self.result_data['brand_logo']``
    (empty string when the response carries no ``output_path``).

    Raises:
        httpx.HTTPStatusError: if the model service answers with a
            4xx/5xx status.
        httpx.TimeoutException: if the service does not answer in time.
    """
    request_item = {
        "bucket_name": "aida-users",
        # Unique object key per generation so repeated runs never
        # overwrite an earlier logo.
        "object_name": f'{self.user_id}/{self.category}/{uuid.uuid4().hex}.png',
        "prompt": self.generate_logo_prompt,
        "height": 1024,
        "width": 1024
    }
    # Image generation is slow; allow up to 120 s before giving up.
    with httpx.Client(timeout=120) as client:
        resp = client.post(
            f"http://{settings.FLUX2_GEN_IMG_MODEL_URL}/predict",
            json=request_item,
        )
        # Fail loudly on an HTTP error instead of trying to decode an
        # error page as JSON (which would raise a confusing
        # JSONDecodeError or silently store an empty path).
        resp.raise_for_status()
        result = resp.json()
    self.result_data['brand_logo'] = result.get("output_path", "")
def upload_logo_image(self, image, object_name):
    """Encode *image* as JPEG and upload it to the ``aida-users`` bucket.

    Args:
        image: BGR image array (cv2 convention — TODO confirm with caller)
            to be encoded and uploaded.
        object_name: Unique file stem; the final key becomes
            ``{user_id}/{category}/{object_name}.jpg``.

    Returns:
        The bucket-prefixed object path (``aida-users/...``) on success,
        or ``None`` when encoding or upload fails (best-effort: the
        error is logged, not raised).
    """
    try:
        # Check the encode success flag instead of discarding it — the
        # original ``_, buf = cv2.imencode(...)`` could upload garbage
        # bytes when encoding failed.
        ok, img_byte_array = cv2.imencode('.jpg', image)
        if not ok:
            raise ValueError("cv2.imencode failed to encode logo image")
        object_name = f'{self.user_id}/{self.category}/{object_name}.jpg'
        oss_upload_image(oss_client=self.minio_client, bucket="aida-users", object_name=object_name, image_bytes=img_byte_array)
        image_url = f"aida-users/{object_name}"
        return image_url
    except Exception as e:
        # Best-effort: log and return None rather than aborting brand
        # generation.  (Previous message named the wrong function,
        # "upload_png_mask".)
        logging.warning(f"upload_logo_image runtime exception : {e}")
        return None
if __name__ == '__main__': if __name__ == '__main__':