更新图形生成工具,优化返回格式并添加新功能

This commit is contained in:
zcr
2026-06-15 17:10:04 +08:00
parent b14ccab723
commit 35e791b4e2
11 changed files with 31 additions and 21 deletions

View File

@@ -1,79 +0,0 @@
import asyncio
import concurrent.futures
import random
import numpy as np
import tritonclient.grpc as grpcclient
from langchain.tools import tool
from PIL import Image
from pydantic import BaseModel, Field
from tritonclient.utils import np_to_triton_dtype
from uuid_utils import uuid7
from app.core.config import settings
from app.service.generate_image.utils.upload_sd_image import upload_SDXL_image
# 模型配置
GSL_MODEL_URL = f"{settings.B_4_X_4090_SERVICE_HOST}:10041"
GSL_MODEL_NAME = "stable_diffusion_xl_transparent"
# 线程池用于执行同步推理
executor = concurrent.futures.ThreadPoolExecutor(max_workers=4)
def _generate_logo_sync(prompt: str) -> Image.Image:
"""同步生成 Logo 的内部函数"""
seed = random.randint(0, 2**32 - 1)
grpc_client = grpcclient.InferenceServerClient(url=GSL_MODEL_URL)
# 准备输入
prompts = [prompt]
text_obj = np.array(prompts, dtype="object").reshape((-1, 1))
input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype))
input_text.set_data_from_numpy(text_obj)
negative_prompts = "bad, ugly"
text_obj_neg = np.array(negative_prompts, dtype="object").reshape((-1, 1))
input_text_neg = grpcclient.InferInput("negative_prompt", text_obj_neg.shape, np_to_triton_dtype(text_obj_neg.dtype))
input_text_neg.set_data_from_numpy(text_obj_neg)
seed_input = np.array(seed, dtype="object").reshape((-1, 1))
input_seed = grpcclient.InferInput("seed", seed_input.shape, np_to_triton_dtype(seed_input.dtype))
input_seed.set_data_from_numpy(seed_input)
inputs = [input_text, input_text_neg, input_seed]
# 同步推理
result = grpc_client.infer(model_name=GSL_MODEL_NAME, inputs=inputs)
image = result.as_numpy("generated_image")
return Image.fromarray(np.squeeze(image.astype(np.uint8)))
async def generate_logo(prompt: str) -> Image.Image:
"""异步生成透明背景的 Logo 图片"""
loop = asyncio.get_event_loop()
return await loop.run_in_executor(executor, _generate_logo_sync, prompt)
class GenerateLogoToolInput(BaseModel):
"""Input schema for the Generate Logo Tool."""
prompt: str = Field(description="Simple keyword for logo generation, e.g., 'cat', 'flower', 'dog'")
user_id: str = Field(description="User ID for image storage", default="agent")
@tool(args_schema=GenerateLogoToolInput)
async def generate_logo_tool(prompt: str, user_id: str = "agent") -> str:
"""Generate a transparent background logo image based on a simple keyword."""
image = await generate_logo(prompt=prompt)
# 上传到 minio使用线程池避免阻塞事件循环
file_name = f"{uuid7()}.png"
loop = asyncio.get_event_loop()
image_url = await loop.run_in_executor(executor, upload_SDXL_image, image, user_id, "logo", file_name)
return image_url
if __name__ == "__main__":
result = asyncio.run(generate_logo_tool.ainvoke({"prompt": "golden retriever"}))
print(f"Logo saved to: {result}")