aida agent (基础版)搭建完成
This commit is contained in:
@@ -0,0 +1,27 @@
|
||||
import httpx
|
||||
|
||||
|
||||
async def generate_image(
|
||||
bucket_name="fida-public-bucket",
|
||||
object_name=f"furniture/sketches/123456.png",
|
||||
prompt="Generate a modern minimalist dining chair made of light "
|
||||
"oak wood and white leather, with slim metal legs, photographed "
|
||||
"in a bright Scandinavian living room with natural sunlight, high detail, "
|
||||
"8k resolution.",
|
||||
):
|
||||
request_data = {
|
||||
"input_image_paths": [],
|
||||
"prompt": prompt,
|
||||
"bucket_name": bucket_name,
|
||||
"object_name": object_name,
|
||||
"width": 1024,
|
||||
"height": 1024,
|
||||
}
|
||||
async with httpx.AsyncClient(timeout=120) as client:
|
||||
resp = await client.post(
|
||||
f"http://20.1.1.33:14202/predict",
|
||||
json=request_data,
|
||||
)
|
||||
result = resp.json()
|
||||
image_url = result.get("output_path", None)
|
||||
return image_url
|
||||
@@ -0,0 +1,79 @@
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import tritonclient.grpc as grpcclient
|
||||
from langchain.tools import tool
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel, Field
|
||||
from tritonclient.utils import np_to_triton_dtype
|
||||
from uuid_utils import uuid7
|
||||
from app.core.config import settings
|
||||
from app.service.generate_image.utils.upload_sd_image import upload_SDXL_image
|
||||
|
||||
# 模型配置
|
||||
GSL_MODEL_URL = f"{settings.B_4_X_4090_SERVICE_HOST}:10041"
|
||||
GSL_MODEL_NAME = "stable_diffusion_xl_transparent"
|
||||
|
||||
# 线程池用于执行同步推理
|
||||
executor = concurrent.futures.ThreadPoolExecutor(max_workers=4)
|
||||
|
||||
|
||||
def _generate_logo_sync(prompt: str) -> Image.Image:
|
||||
"""同步生成 Logo 的内部函数"""
|
||||
seed = random.randint(0, 2**32 - 1)
|
||||
grpc_client = grpcclient.InferenceServerClient(url=GSL_MODEL_URL)
|
||||
|
||||
# 准备输入
|
||||
prompts = [prompt]
|
||||
text_obj = np.array(prompts, dtype="object").reshape((-1, 1))
|
||||
input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype))
|
||||
input_text.set_data_from_numpy(text_obj)
|
||||
|
||||
negative_prompts = "bad, ugly"
|
||||
text_obj_neg = np.array(negative_prompts, dtype="object").reshape((-1, 1))
|
||||
input_text_neg = grpcclient.InferInput("negative_prompt", text_obj_neg.shape, np_to_triton_dtype(text_obj_neg.dtype))
|
||||
input_text_neg.set_data_from_numpy(text_obj_neg)
|
||||
|
||||
seed_input = np.array(seed, dtype="object").reshape((-1, 1))
|
||||
input_seed = grpcclient.InferInput("seed", seed_input.shape, np_to_triton_dtype(seed_input.dtype))
|
||||
input_seed.set_data_from_numpy(seed_input)
|
||||
|
||||
inputs = [input_text, input_text_neg, input_seed]
|
||||
|
||||
# 同步推理
|
||||
result = grpc_client.infer(model_name=GSL_MODEL_NAME, inputs=inputs)
|
||||
image = result.as_numpy("generated_image")
|
||||
return Image.fromarray(np.squeeze(image.astype(np.uint8)))
|
||||
|
||||
|
||||
async def generate_logo(prompt: str) -> Image.Image:
|
||||
"""异步生成透明背景的 Logo 图片"""
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(executor, _generate_logo_sync, prompt)
|
||||
|
||||
|
||||
class GenerateLogoToolInput(BaseModel):
|
||||
"""Input schema for the Generate Logo Tool."""
|
||||
|
||||
prompt: str = Field(description="Simple keyword for logo generation, e.g., 'cat', 'flower', 'dog'")
|
||||
user_id: str = Field(description="User ID for image storage", default="agent")
|
||||
|
||||
|
||||
@tool(args_schema=GenerateLogoToolInput)
|
||||
async def generate_logo_tool(prompt: str, user_id: str = "agent") -> str:
|
||||
"""Generate a transparent background logo image based on a simple keyword."""
|
||||
|
||||
image = await generate_logo(prompt=prompt)
|
||||
|
||||
# 上传到 minio(使用线程池避免阻塞事件循环)
|
||||
file_name = f"{uuid7()}.png"
|
||||
loop = asyncio.get_event_loop()
|
||||
image_url = await loop.run_in_executor(executor, upload_SDXL_image, image, user_id, "logo", file_name)
|
||||
return image_url
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
result = asyncio.run(generate_logo_tool.ainvoke({"prompt": "golden retriever"}))
|
||||
print(f"Logo saved to: {result}")
|
||||
@@ -0,0 +1,72 @@
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
import io
|
||||
import logging
|
||||
import os
|
||||
|
||||
import httpx
|
||||
from dotenv import load_dotenv
|
||||
from PIL import Image
|
||||
from uuid_utils import uuid7
|
||||
from app.service.generate_image.utils.upload_sd_image import upload_SDXL_image
|
||||
|
||||
load_dotenv()
|
||||
|
||||
logger = logging.getLogger()
|
||||
PEXELS_API_KEY = os.environ.get("PEXELS_API_KEY", "")
|
||||
PEXELS_BASE_URL = os.environ.get("PEXELS_BASE_URL", "")
|
||||
|
||||
# 线程池用于执行同步上传
|
||||
executor = concurrent.futures.ThreadPoolExecutor(max_workers=4)
|
||||
|
||||
|
||||
async def search_photos(query: str, per_page: int = 4, user_id: str = "agent") -> list[dict]:
|
||||
"""从 Pexels 搜索图片并上传到 minio
|
||||
|
||||
Args:
|
||||
query: 搜索关键词
|
||||
per_page: 返回图片数量 (1-80)
|
||||
user_id: 用户 ID
|
||||
|
||||
Returns:
|
||||
图片信息列表,每项包含 image_url 和 minio_path
|
||||
"""
|
||||
# 搜索图片
|
||||
async with httpx.AsyncClient(timeout=30) as client:
|
||||
response = await client.get(
|
||||
f"{PEXELS_BASE_URL}/search",
|
||||
headers={"Authorization": PEXELS_API_KEY},
|
||||
params={
|
||||
"query": query,
|
||||
"per_page": per_page,
|
||||
"orientation": "square",
|
||||
"size": "medium",
|
||||
},
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise Exception(f"Pexels API error: {response.status_code} - {response.text}")
|
||||
|
||||
data = response.json()
|
||||
photos = data.get("photos", [])
|
||||
|
||||
# 下载并上传到 minio
|
||||
results = []
|
||||
for photo in photos:
|
||||
try:
|
||||
# 下载图片(使用 large 尺寸)
|
||||
image_url = photo["src"]["original"]
|
||||
async with httpx.AsyncClient(timeout=60) as dl_client:
|
||||
dl_response = await dl_client.get(image_url)
|
||||
image = Image.open(io.BytesIO(dl_response.content))
|
||||
|
||||
# 上传到 minio(使用线程池避免阻塞事件循环)
|
||||
file_name = f"{uuid7()}.jpg"
|
||||
loop = asyncio.get_event_loop()
|
||||
minio_url = await loop.run_in_executor(executor, upload_SDXL_image, image, user_id, "explorer", file_name)
|
||||
results.append({"image_url": image_url, "minio_path": minio_url})
|
||||
logger.info(f"[Explorer] 上传成功: {minio_url}")
|
||||
except Exception as e:
|
||||
logger.error(f"[Explorer] 上传失败: {e}")
|
||||
|
||||
return results
|
||||
@@ -0,0 +1,90 @@
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
import io
|
||||
import logging
|
||||
import os
|
||||
import httpx
|
||||
|
||||
from PIL import Image
|
||||
from uuid_utils import uuid7
|
||||
from dotenv import load_dotenv
|
||||
from app.service.generate_image.utils.upload_sd_image import upload_SDXL_image
|
||||
|
||||
load_dotenv()
|
||||
|
||||
# Unsplash API 配置
|
||||
UNSPLASH_ACCESS_KEY = os.environ.get("UNSPLASH_ACCESS_KEY", "")
|
||||
UNSPLASH_BASE_URL = os.environ.get("UNSPLASH_BASE_URL", "")
|
||||
logger = logging.getLogger()
|
||||
# 线程池用于执行同步上传
|
||||
executor = concurrent.futures.ThreadPoolExecutor(max_workers=4)
|
||||
|
||||
|
||||
async def get_random_photos(query: str, count: int = 4, user_id: str = "agent") -> list[dict]:
|
||||
"""从 Unsplash 获取随机图片并上传到 minio
|
||||
|
||||
Args:
|
||||
query: 搜索关键词
|
||||
count: 返回图片数量 (1-30)
|
||||
user_id: 用户 ID
|
||||
|
||||
Returns:
|
||||
图片信息列表,每项包含 image_url 和 minio_path
|
||||
"""
|
||||
# 获取随机图片
|
||||
async with httpx.AsyncClient(timeout=30) as client:
|
||||
response = await client.get(
|
||||
f"{UNSPLASH_BASE_URL}/search/photos",
|
||||
headers={"Authorization": f"Client-ID {UNSPLASH_ACCESS_KEY}"},
|
||||
params={
|
||||
"query": query,
|
||||
"per_page": count,
|
||||
"page": 1,
|
||||
},
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise Exception(f"Unsplash API error: {response.status_code} - {response.text}")
|
||||
|
||||
data = response.json()
|
||||
# /search/photos 返回 {"results": [...], "total": ...}
|
||||
photos = data.get("results", [])
|
||||
|
||||
# 下载并上传到 minio
|
||||
results = []
|
||||
for photo in photos:
|
||||
try:
|
||||
# 下载图片
|
||||
image_url = photo["urls"]["raw"]
|
||||
async with httpx.AsyncClient(timeout=60) as dl_client:
|
||||
dl_response = await dl_client.get(image_url)
|
||||
image = Image.open(io.BytesIO(dl_response.content))
|
||||
|
||||
# 上传到 minio(使用线程池避免阻塞事件循环)
|
||||
file_name = f"{uuid7()}.jpg"
|
||||
loop = asyncio.get_event_loop()
|
||||
minio_url = await loop.run_in_executor(executor, upload_SDXL_image, image, user_id, "explorer", file_name)
|
||||
results.append({"image_url": image_url, "minio_path": minio_url})
|
||||
logger.info(f"[Explorer] 上传成功: {minio_url}")
|
||||
except Exception as e:
|
||||
logger.error(f"[Explorer] 上传失败: {e}")
|
||||
|
||||
return results
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
async def test():
|
||||
"""测试 Unsplash 搜索"""
|
||||
query = "summer dress fresh natural style"
|
||||
print(f"搜索关键词: {query}")
|
||||
print("=" * 50)
|
||||
|
||||
results = await get_random_photos(query, count=4, user_id="test")
|
||||
print(f"\n找到 {len(results)} 张图片:")
|
||||
for i, item in enumerate(results, 1):
|
||||
print(f" {i}. 原图: {item.get('image_url', '')}")
|
||||
print(f" Minio: {item.get('minio_path', '')}")
|
||||
|
||||
asyncio.run(test())
|
||||
Reference in New Issue
Block a user