aida agent (基础版)搭建完成

This commit is contained in:
zcr
2026-06-15 14:48:17 +08:00
parent b602c47fc9
commit dbbaa7503c
25 changed files with 1953 additions and 717 deletions

View File

@@ -0,0 +1,27 @@
import httpx
async def generate_image(
bucket_name="fida-public-bucket",
object_name=f"furniture/sketches/123456.png",
prompt="Generate a modern minimalist dining chair made of light "
"oak wood and white leather, with slim metal legs, photographed "
"in a bright Scandinavian living room with natural sunlight, high detail, "
"8k resolution.",
):
request_data = {
"input_image_paths": [],
"prompt": prompt,
"bucket_name": bucket_name,
"object_name": object_name,
"width": 1024,
"height": 1024,
}
async with httpx.AsyncClient(timeout=120) as client:
resp = await client.post(
f"http://20.1.1.33:14202/predict",
json=request_data,
)
result = resp.json()
image_url = result.get("output_path", None)
return image_url

View File

@@ -0,0 +1,79 @@
import asyncio
import concurrent.futures
import random
import numpy as np
import tritonclient.grpc as grpcclient
from langchain.tools import tool
from PIL import Image
from pydantic import BaseModel, Field
from tritonclient.utils import np_to_triton_dtype
from uuid_utils import uuid7
from app.core.config import settings
from app.service.generate_image.utils.upload_sd_image import upload_SDXL_image
# 模型配置
GSL_MODEL_URL = f"{settings.B_4_X_4090_SERVICE_HOST}:10041"
GSL_MODEL_NAME = "stable_diffusion_xl_transparent"
# 线程池用于执行同步推理
executor = concurrent.futures.ThreadPoolExecutor(max_workers=4)
def _generate_logo_sync(prompt: str) -> Image.Image:
"""同步生成 Logo 的内部函数"""
seed = random.randint(0, 2**32 - 1)
grpc_client = grpcclient.InferenceServerClient(url=GSL_MODEL_URL)
# 准备输入
prompts = [prompt]
text_obj = np.array(prompts, dtype="object").reshape((-1, 1))
input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype))
input_text.set_data_from_numpy(text_obj)
negative_prompts = "bad, ugly"
text_obj_neg = np.array(negative_prompts, dtype="object").reshape((-1, 1))
input_text_neg = grpcclient.InferInput("negative_prompt", text_obj_neg.shape, np_to_triton_dtype(text_obj_neg.dtype))
input_text_neg.set_data_from_numpy(text_obj_neg)
seed_input = np.array(seed, dtype="object").reshape((-1, 1))
input_seed = grpcclient.InferInput("seed", seed_input.shape, np_to_triton_dtype(seed_input.dtype))
input_seed.set_data_from_numpy(seed_input)
inputs = [input_text, input_text_neg, input_seed]
# 同步推理
result = grpc_client.infer(model_name=GSL_MODEL_NAME, inputs=inputs)
image = result.as_numpy("generated_image")
return Image.fromarray(np.squeeze(image.astype(np.uint8)))
async def generate_logo(prompt: str) -> Image.Image:
"""异步生成透明背景的 Logo 图片"""
loop = asyncio.get_event_loop()
return await loop.run_in_executor(executor, _generate_logo_sync, prompt)
class GenerateLogoToolInput(BaseModel):
"""Input schema for the Generate Logo Tool."""
prompt: str = Field(description="Simple keyword for logo generation, e.g., 'cat', 'flower', 'dog'")
user_id: str = Field(description="User ID for image storage", default="agent")
@tool(args_schema=GenerateLogoToolInput)
async def generate_logo_tool(prompt: str, user_id: str = "agent") -> str:
"""Generate a transparent background logo image based on a simple keyword."""
image = await generate_logo(prompt=prompt)
# 上传到 minio使用线程池避免阻塞事件循环
file_name = f"{uuid7()}.png"
loop = asyncio.get_event_loop()
image_url = await loop.run_in_executor(executor, upload_SDXL_image, image, user_id, "logo", file_name)
return image_url
if __name__ == "__main__":
result = asyncio.run(generate_logo_tool.ainvoke({"prompt": "golden retriever"}))
print(f"Logo saved to: {result}")

View File

@@ -0,0 +1,72 @@
import asyncio
import concurrent.futures
import io
import logging
import os
import httpx
from dotenv import load_dotenv
from PIL import Image
from uuid_utils import uuid7
from app.service.generate_image.utils.upload_sd_image import upload_SDXL_image
load_dotenv()
logger = logging.getLogger()
PEXELS_API_KEY = os.environ.get("PEXELS_API_KEY", "")
PEXELS_BASE_URL = os.environ.get("PEXELS_BASE_URL", "")
# 线程池用于执行同步上传
executor = concurrent.futures.ThreadPoolExecutor(max_workers=4)
async def search_photos(query: str, per_page: int = 4, user_id: str = "agent") -> list[dict]:
"""从 Pexels 搜索图片并上传到 minio
Args:
query: 搜索关键词
per_page: 返回图片数量 (1-80)
user_id: 用户 ID
Returns:
图片信息列表,每项包含 image_url 和 minio_path
"""
# 搜索图片
async with httpx.AsyncClient(timeout=30) as client:
response = await client.get(
f"{PEXELS_BASE_URL}/search",
headers={"Authorization": PEXELS_API_KEY},
params={
"query": query,
"per_page": per_page,
"orientation": "square",
"size": "medium",
},
)
if response.status_code != 200:
raise Exception(f"Pexels API error: {response.status_code} - {response.text}")
data = response.json()
photos = data.get("photos", [])
# 下载并上传到 minio
results = []
for photo in photos:
try:
# 下载图片(使用 large 尺寸)
image_url = photo["src"]["original"]
async with httpx.AsyncClient(timeout=60) as dl_client:
dl_response = await dl_client.get(image_url)
image = Image.open(io.BytesIO(dl_response.content))
# 上传到 minio使用线程池避免阻塞事件循环
file_name = f"{uuid7()}.jpg"
loop = asyncio.get_event_loop()
minio_url = await loop.run_in_executor(executor, upload_SDXL_image, image, user_id, "explorer", file_name)
results.append({"image_url": image_url, "minio_path": minio_url})
logger.info(f"[Explorer] 上传成功: {minio_url}")
except Exception as e:
logger.error(f"[Explorer] 上传失败: {e}")
return results

View File

@@ -0,0 +1,90 @@
import asyncio
import concurrent.futures
import io
import logging
import os
import httpx
from PIL import Image
from uuid_utils import uuid7
from dotenv import load_dotenv
from app.service.generate_image.utils.upload_sd_image import upload_SDXL_image
load_dotenv()
# Unsplash API 配置
UNSPLASH_ACCESS_KEY = os.environ.get("UNSPLASH_ACCESS_KEY", "")
UNSPLASH_BASE_URL = os.environ.get("UNSPLASH_BASE_URL", "")
logger = logging.getLogger()
# 线程池用于执行同步上传
executor = concurrent.futures.ThreadPoolExecutor(max_workers=4)
async def get_random_photos(query: str, count: int = 4, user_id: str = "agent") -> list[dict]:
"""从 Unsplash 获取随机图片并上传到 minio
Args:
query: 搜索关键词
count: 返回图片数量 (1-30)
user_id: 用户 ID
Returns:
图片信息列表,每项包含 image_url 和 minio_path
"""
# 获取随机图片
async with httpx.AsyncClient(timeout=30) as client:
response = await client.get(
f"{UNSPLASH_BASE_URL}/search/photos",
headers={"Authorization": f"Client-ID {UNSPLASH_ACCESS_KEY}"},
params={
"query": query,
"per_page": count,
"page": 1,
},
)
if response.status_code != 200:
raise Exception(f"Unsplash API error: {response.status_code} - {response.text}")
data = response.json()
# /search/photos 返回 {"results": [...], "total": ...}
photos = data.get("results", [])
# 下载并上传到 minio
results = []
for photo in photos:
try:
# 下载图片
image_url = photo["urls"]["raw"]
async with httpx.AsyncClient(timeout=60) as dl_client:
dl_response = await dl_client.get(image_url)
image = Image.open(io.BytesIO(dl_response.content))
# 上传到 minio使用线程池避免阻塞事件循环
file_name = f"{uuid7()}.jpg"
loop = asyncio.get_event_loop()
minio_url = await loop.run_in_executor(executor, upload_SDXL_image, image, user_id, "explorer", file_name)
results.append({"image_url": image_url, "minio_path": minio_url})
logger.info(f"[Explorer] 上传成功: {minio_url}")
except Exception as e:
logger.error(f"[Explorer] 上传失败: {e}")
return results
if __name__ == "__main__":
import asyncio
async def test():
"""测试 Unsplash 搜索"""
query = "summer dress fresh natural style"
print(f"搜索关键词: {query}")
print("=" * 50)
results = await get_random_photos(query, count=4, user_id="test")
print(f"\n找到 {len(results)} 张图片:")
for i, item in enumerate(results, 1):
print(f" {i}. 原图: {item.get('image_url', '')}")
print(f" Minio: {item.get('minio_path', '')}")
asyncio.run(test())