From 674514ec115366680b8fb6bc4004aa309802cd03 Mon Sep 17 00:00:00 2001 From: zcr Date: Mon, 23 Mar 2026 11:21:50 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20brand=20dna=20logo=E7=94=9F=E6=88=90?= =?UTF-8?q?=E6=9B=BF=E6=8D=A2flux2klein=20;=20fix:?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../brand_dna/service_generate_brand_info.py | 66 +++++-------------- 1 file changed, 17 insertions(+), 49 deletions(-) diff --git a/app/service/brand_dna/service_generate_brand_info.py b/app/service/brand_dna/service_generate_brand_info.py index fa8d2e5..345b855 100644 --- a/app/service/brand_dna/service_generate_brand_info.py +++ b/app/service/brand_dna/service_generate_brand_info.py @@ -1,19 +1,10 @@ -import logging - -import cv2 -import numpy as np -import tritonclient.grpc as grpcclient +import uuid +import httpx from langchain_classic.output_parsers import ResponseSchema, StructuredOutputParser from langchain_community.chat_models import ChatTongyi from langchain_core.prompts import PromptTemplate from minio import Minio -from tritonclient.utils import np_to_triton_dtype - -from app.core.config import GI_MODEL_URL, GI_MODEL_NAME from app.schemas.brand_dna import GenerateBrandModel -from app.service.utils.generate_uuid import generate_uuid -from app.service.utils.new_oss_client import oss_upload_image - from app.core.config import settings @@ -26,14 +17,9 @@ class GenerateBrandInfo: # user info init self.user_id = request_data.user_id self.category = "brand_logo" - # generate logo init - self.grpc_client = grpcclient.InferenceServerClient(url=GI_MODEL_URL) - self.image = np.random.randint(0, 256, (1024, 1024, 3), dtype=np.uint8) - self.batch_size = 1 - self.mode = 'txt2img' # llm generate brand info init - self.model = ChatTongyi(model="qwen2.5-14b-instruct", api_key="sk-7658298c6b99443c98184a5e634fe6ab") + self.model = ChatTongyi(model="qwen2.5-14b-instruct", api_key=settings.QWEN_API_KEY) self.response_schemas = [ ResponseSchema(name="brand_name", description="Brand name."), @@ -63,38 +49,20 @@ class GenerateBrandInfo: self.generate_logo_prompt = brand_data['brand_logo_prompt'] def generate_brand_logo(self): - prompts = [self.generate_logo_prompt] * self.batch_size - modes = [self.mode] * self.batch_size - images = [self.image.astype(np.float16)] * self.batch_size - - text_obj = np.array(prompts, dtype="object").reshape((-1, 1)) - mode_obj = np.array(modes, dtype="object").reshape((-1, 1)) - image_obj = np.array(images, dtype=np.float16).reshape((-1, 1024, 1024, 3)) - - input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype)) - input_image = grpcclient.InferInput("input_image", image_obj.shape, np_to_triton_dtype(image_obj.dtype)) - input_mode = grpcclient.InferInput("mode", mode_obj.shape, np_to_triton_dtype(mode_obj.dtype)) - - input_text.set_data_from_numpy(text_obj) - input_image.set_data_from_numpy(image_obj) - input_mode.set_data_from_numpy(mode_obj) - - inputs = [input_text, input_image, input_mode] - result = self.grpc_client.infer(model_name=GI_MODEL_NAME, inputs=inputs) - image = result.as_numpy("generated_image") - image_result = cv2.cvtColor(np.squeeze(image.astype(np.uint8)), cv2.COLOR_RGB2BGR) - logo_url = self.upload_logo_image(image_result, generate_uuid()) - self.result_data['brand_logo'] = logo_url - - def upload_logo_image(self, image, object_name): - try: - _, img_byte_array = cv2.imencode('.jpg', image) - object_name = f'{self.user_id}/{self.category}/{object_name}.jpg' - oss_upload_image(oss_client=self.minio_client, bucket="aida-users", object_name=object_name, image_bytes=img_byte_array) - image_url = f"aida-users/{object_name}" - return image_url - except Exception as e: - logging.warning(f"upload_png_mask runtime exception : {e}") + request_item = { + "bucket_name": "aida-users", + "object_name": f'{self.user_id}/{self.category}/{uuid.uuid4().hex}.png', + "prompt": self.generate_logo_prompt, + "height": 1024, + "width": 1024 + } + with httpx.Client(timeout=120) as client: + resp = client.post( + f"http://{settings.FLUX2_GEN_IMG_MODEL_URL}/predict", + json=request_item, + ) + result = resp.json() + self.result_data['brand_logo'] = result.get("output_path", "") if __name__ == '__main__':