feat: switch brand DNA logo generation to the flux2klein HTTP service (replaces the Triton gRPC path); fix: read the Qwen API key from settings instead of hard-coding it

This commit is contained in:
zcr
2026-03-23 11:21:50 +08:00
committed by zchen
parent e9ca1d301b
commit 674514ec11

View File

@@ -1,19 +1,10 @@
import logging
import cv2
import numpy as np
import tritonclient.grpc as grpcclient
import uuid
import httpx
from langchain_classic.output_parsers import ResponseSchema, StructuredOutputParser
from langchain_community.chat_models import ChatTongyi
from langchain_core.prompts import PromptTemplate
from minio import Minio
from tritonclient.utils import np_to_triton_dtype
from app.core.config import GI_MODEL_URL, GI_MODEL_NAME
from app.schemas.brand_dna import GenerateBrandModel
from app.service.utils.generate_uuid import generate_uuid
from app.service.utils.new_oss_client import oss_upload_image
from app.core.config import settings
@@ -26,14 +17,9 @@ class GenerateBrandInfo:
# user info init
self.user_id = request_data.user_id
self.category = "brand_logo"  # fixed OSS folder / asset category for logos
# generate logo init
# NOTE(review): the following four lines belong to the removed Triton/gRPC
# path (this is a diff view); the new flux2 HTTP path does not need them.
self.grpc_client = grpcclient.InferenceServerClient(url=GI_MODEL_URL)
# random RGB placeholder image fed to the txt2img model as its image input
self.image = np.random.randint(0, 256, (1024, 1024, 3), dtype=np.uint8)
self.batch_size = 1
self.mode = 'txt2img'
# llm generate brand info init
# SECURITY(review): the removed line below embeds a live-looking API key in
# version control — it must be rotated, not just replaced in code. The second
# assignment (the added line) correctly pulls the key from settings.
self.model = ChatTongyi(model="qwen2.5-14b-instruct", api_key="sk-7658298c6b99443c98184a5e634fe6ab")
self.model = ChatTongyi(model="qwen2.5-14b-instruct", api_key=settings.QWEN_API_KEY)
self.response_schemas = [
ResponseSchema(name="brand_name", description="Brand name."),
@@ -63,38 +49,20 @@ class GenerateBrandInfo:
self.generate_logo_prompt = brand_data['brand_logo_prompt']
def generate_brand_logo(self):
    """Generate a brand logo through the Triton inference server.

    Builds batched prompt/mode/image tensors from instance state, runs the
    text-to-image model over gRPC, converts the RGB result to BGR, uploads
    it, and records the resulting URL under ``self.result_data['brand_logo']``.
    """
    batch = self.batch_size
    prompt_arr = np.array([self.generate_logo_prompt] * batch, dtype="object").reshape((-1, 1))
    mode_arr = np.array([self.mode] * batch, dtype="object").reshape((-1, 1))
    image_arr = np.array(
        [self.image.astype(np.float16)] * batch, dtype=np.float16
    ).reshape((-1, 1024, 1024, 3))

    # Build the three named model inputs in the order the server expects.
    inputs = []
    for tensor_name, payload in (
        ("prompt", prompt_arr),
        ("input_image", image_arr),
        ("mode", mode_arr),
    ):
        infer_input = grpcclient.InferInput(
            tensor_name, payload.shape, np_to_triton_dtype(payload.dtype)
        )
        infer_input.set_data_from_numpy(payload)
        inputs.append(infer_input)

    response = self.grpc_client.infer(model_name=GI_MODEL_NAME, inputs=inputs)
    raw = response.as_numpy("generated_image")
    # The model emits RGB; OpenCV's encoders expect BGR.
    bgr = cv2.cvtColor(np.squeeze(raw.astype(np.uint8)), cv2.COLOR_RGB2BGR)
    self.result_data['brand_logo'] = self.upload_logo_image(bgr, generate_uuid())
def upload_logo_image(self, image, object_name):
    """JPEG-encode ``image`` and upload it to the ``aida-users`` bucket.

    Args:
        image: BGR image array accepted by ``cv2.imencode``.
        object_name: base object name (without extension); stored under
            ``{user_id}/{category}/{object_name}.jpg``.

    Returns:
        The bucket-qualified object path on success, or ``None`` when
        encoding or the upload fails (the error is logged, never raised).
    """
    try:
        # Bug fix: the encode success flag was previously discarded.
        ok, img_byte_array = cv2.imencode('.jpg', image)
        if not ok:
            raise ValueError("cv2.imencode failed to encode image as JPEG")
        object_name = f'{self.user_id}/{self.category}/{object_name}.jpg'
        oss_upload_image(oss_client=self.minio_client, bucket="aida-users", object_name=object_name, image_bytes=img_byte_array)
        return f"aida-users/{object_name}"
    except Exception as e:
        # Best-effort upload: log and signal failure with an explicit None.
        # Bug fix: log message previously named the wrong function
        # ("upload_png_mask"); also use lazy %-style logging args.
        logging.warning("upload_logo_image runtime exception : %s", e)
        return None
# New flux2klein HTTP generation path: delegate logo rendering (and,
# presumably, the bucket upload) to the remote /predict service.
# NOTE(review): assumes the service writes the image to `bucket_name` itself
# and returns its location as "output_path" — confirm against the service API.
request_item = {
"bucket_name": "aida-users",
"object_name": f'{self.user_id}/{self.category}/{uuid.uuid4().hex}.png',
"prompt": self.generate_logo_prompt,
"height": 1024,
"width": 1024
}
# 120 s timeout: image generation is slow; the default httpx timeout is 5 s.
with httpx.Client(timeout=120) as client:
resp = client.post(
f"http://{settings.FLUX2_GEN_IMG_MODEL_URL}/predict",
json=request_item,
)
# NOTE(review): no resp.raise_for_status() — a non-2xx response would still
# be parsed as JSON here and may silently yield an empty logo URL; confirm
# whether errors should propagate instead.
result = resp.json()
self.result_data['brand_logo'] = result.get("output_path", "")
if __name__ == '__main__':