更新图形生成工具,优化返回格式并添加新功能
This commit is contained in:
@@ -1,79 +0,0 @@
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import tritonclient.grpc as grpcclient
|
||||
from langchain.tools import tool
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel, Field
|
||||
from tritonclient.utils import np_to_triton_dtype
|
||||
from uuid_utils import uuid7
|
||||
from app.core.config import settings
|
||||
from app.service.generate_image.utils.upload_sd_image import upload_SDXL_image
|
||||
|
||||
# 模型配置
|
||||
GSL_MODEL_URL = f"{settings.B_4_X_4090_SERVICE_HOST}:10041"
|
||||
GSL_MODEL_NAME = "stable_diffusion_xl_transparent"
|
||||
|
||||
# 线程池用于执行同步推理
|
||||
executor = concurrent.futures.ThreadPoolExecutor(max_workers=4)
|
||||
|
||||
|
||||
def _generate_logo_sync(prompt: str) -> Image.Image:
|
||||
"""同步生成 Logo 的内部函数"""
|
||||
seed = random.randint(0, 2**32 - 1)
|
||||
grpc_client = grpcclient.InferenceServerClient(url=GSL_MODEL_URL)
|
||||
|
||||
# 准备输入
|
||||
prompts = [prompt]
|
||||
text_obj = np.array(prompts, dtype="object").reshape((-1, 1))
|
||||
input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype))
|
||||
input_text.set_data_from_numpy(text_obj)
|
||||
|
||||
negative_prompts = "bad, ugly"
|
||||
text_obj_neg = np.array(negative_prompts, dtype="object").reshape((-1, 1))
|
||||
input_text_neg = grpcclient.InferInput("negative_prompt", text_obj_neg.shape, np_to_triton_dtype(text_obj_neg.dtype))
|
||||
input_text_neg.set_data_from_numpy(text_obj_neg)
|
||||
|
||||
seed_input = np.array(seed, dtype="object").reshape((-1, 1))
|
||||
input_seed = grpcclient.InferInput("seed", seed_input.shape, np_to_triton_dtype(seed_input.dtype))
|
||||
input_seed.set_data_from_numpy(seed_input)
|
||||
|
||||
inputs = [input_text, input_text_neg, input_seed]
|
||||
|
||||
# 同步推理
|
||||
result = grpc_client.infer(model_name=GSL_MODEL_NAME, inputs=inputs)
|
||||
image = result.as_numpy("generated_image")
|
||||
return Image.fromarray(np.squeeze(image.astype(np.uint8)))
|
||||
|
||||
|
||||
async def generate_logo(prompt: str) -> Image.Image:
|
||||
"""异步生成透明背景的 Logo 图片"""
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(executor, _generate_logo_sync, prompt)
|
||||
|
||||
|
||||
class GenerateLogoToolInput(BaseModel):
|
||||
"""Input schema for the Generate Logo Tool."""
|
||||
|
||||
prompt: str = Field(description="Simple keyword for logo generation, e.g., 'cat', 'flower', 'dog'")
|
||||
user_id: str = Field(description="User ID for image storage", default="agent")
|
||||
|
||||
|
||||
@tool(args_schema=GenerateLogoToolInput)
|
||||
async def generate_logo_tool(prompt: str, user_id: str = "agent") -> str:
|
||||
"""Generate a transparent background logo image based on a simple keyword."""
|
||||
|
||||
image = await generate_logo(prompt=prompt)
|
||||
|
||||
# 上传到 minio(使用线程池避免阻塞事件循环)
|
||||
file_name = f"{uuid7()}.png"
|
||||
loop = asyncio.get_event_loop()
|
||||
image_url = await loop.run_in_executor(executor, upload_SDXL_image, image, user_id, "logo", file_name)
|
||||
return image_url
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
result = asyncio.run(generate_logo_tool.ainvoke({"prompt": "golden retriever"}))
|
||||
print(f"Logo saved to: {result}")
|
||||
Reference in New Issue
Block a user