1.新增视觉能力 2.新增对上次图片 或 上传图片 引用图片做编辑能力.
This commit is contained in:
@@ -1,23 +1,26 @@
|
|||||||
import logging
|
|
||||||
import os
|
import os
|
||||||
import random
|
|
||||||
import uuid
|
import uuid
|
||||||
import json
|
import json
|
||||||
from typing import AsyncGenerator
|
import random
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from minio import Minio
|
||||||
from fastapi import APIRouter
|
from fastapi import APIRouter
|
||||||
|
from typing import AsyncGenerator
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
|
from langchain_core.messages import SystemMessage, AIMessageChunk, ToolMessage, AIMessage, ToolMessageChunk
|
||||||
|
|
||||||
from src.core.config import PROJECT_ROOT
|
from src.core.config import PROJECT_ROOT, settings
|
||||||
from src.schemas.deep_agent_chat import DeepAgentChatRequest, HistoryResponse, HistoryItem
|
|
||||||
from langchain_core.messages import HumanMessage, SystemMessage, AIMessageChunk, ToolMessage, AIMessage, ToolMessageChunk
|
|
||||||
|
|
||||||
from src.server.deep_agent.agents.main_agent import build_main_agent
|
from src.server.deep_agent.agents.main_agent import build_main_agent
|
||||||
from src.server.deep_agent.tools.conversation_title_tool import conversation_title
|
from src.server.deep_agent.tools.conversation_title_tool import conversation_title
|
||||||
|
from src.server.deep_agent.tools.generate_furniture_sketch import is_image_path_exist
|
||||||
|
from src.schemas.deep_agent_chat import DeepAgentChatRequest, HistoryResponse, HistoryItem
|
||||||
from src.server.deep_agent.tools.extract_suggested_questions import generate_suggested_questions
|
from src.server.deep_agent.tools.extract_suggested_questions import generate_suggested_questions
|
||||||
|
from src.server.utils.new_oss_client import is_minio_file_exist, oss_upload_image_file, oss_get_image, get_presigned_url
|
||||||
|
|
||||||
router = APIRouter(prefix="/chat", tags=["Furniture Design Chat"])
|
router = APIRouter(prefix="/chat", tags=["Furniture Design Chat"])
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
minio_client = Minio(settings.MINIO_URL, access_key=settings.MINIO_ACCESS, secret_key=settings.MINIO_SECRET, secure=settings.MINIO_SECURE)
|
||||||
|
|
||||||
|
|
||||||
@router.post("/deep_agent_stream")
|
@router.post("/deep_agent_stream")
|
||||||
@@ -83,12 +86,10 @@ async def chat_stream(request: DeepAgentChatRequest):
|
|||||||
# 1. 確定目標 thread_id
|
# 1. 確定目標 thread_id
|
||||||
is_branching = source_thread_id and checkpoint_id
|
is_branching = source_thread_id and checkpoint_id
|
||||||
target_thread_id = str(uuid.uuid4())[:8] if is_branching else (source_thread_id or str(uuid.uuid4())[:8])
|
target_thread_id = str(uuid.uuid4())[:8] if is_branching else (source_thread_id or str(uuid.uuid4())[:8])
|
||||||
|
|
||||||
# 构建主agent
|
# 构建主agent
|
||||||
workspace_dir = os.path.join(PROJECT_ROOT, f"agent_workspace/{target_thread_id}")
|
workspace_dir = os.path.join(PROJECT_ROOT, f"agent_workspace/{target_thread_id}")
|
||||||
print(f"target_thread_id : workspace_dir: {workspace_dir}")
|
logger.info(f"target_thread_id : workspace_dir: {workspace_dir}")
|
||||||
main_agent = build_main_agent(request.use_report, workspace_dir)
|
main_agent = build_main_agent(request.use_report, workspace_dir)
|
||||||
|
|
||||||
# 2. 配置參數
|
# 2. 配置參數
|
||||||
temp = request.config_params.temperature if request.config_params else 0.7
|
temp = request.config_params.temperature if request.config_params else 0.7
|
||||||
|
|
||||||
@@ -131,14 +132,39 @@ async def chat_stream(request: DeepAgentChatRequest):
|
|||||||
|
|
||||||
async def event_generator() -> AsyncGenerator[str, None]:
|
async def event_generator() -> AsyncGenerator[str, None]:
|
||||||
is_first = True
|
is_first = True
|
||||||
new_messages = initial_messages[:] if not source_thread_id else []
|
if request.image_url:
|
||||||
new_messages.append(HumanMessage(content=request.message))
|
bucket, object_name = request.image_url.split('/', 1)
|
||||||
|
image_url = get_presigned_url(oss_client=minio_client, bucket=bucket, object_name=object_name)
|
||||||
input_data = {
|
new_messages = {
|
||||||
"messages": new_messages,
|
"messages": [
|
||||||
}
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": request.message},
|
||||||
|
{"type": "image_url", "image_url": {"url": image_url}}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"files": {
|
||||||
|
"input_image": request.image_url,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
new_messages = {
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": request.message},
|
||||||
|
]
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"files": {
|
||||||
|
"input_image": "",
|
||||||
|
}
|
||||||
|
}
|
||||||
async for stream in main_agent.astream(
|
async for stream in main_agent.astream(
|
||||||
input_data,
|
new_messages,
|
||||||
config=current_config,
|
config=current_config,
|
||||||
stream_mode=["updates", "messages", "custom"],
|
stream_mode=["updates", "messages", "custom"],
|
||||||
subgraphs=True
|
subgraphs=True
|
||||||
@@ -149,12 +175,10 @@ async def chat_stream(request: DeepAgentChatRequest):
|
|||||||
is_first = False
|
is_first = False
|
||||||
_, mode, chunks = stream
|
_, mode, chunks = stream
|
||||||
if mode == "updates": # 只做记录 不做事件返回
|
if mode == "updates": # 只做记录 不做事件返回
|
||||||
print(f"[updates] {chunks}")
|
|
||||||
update_model_messages = chunks.get("model", None)
|
update_model_messages = chunks.get("model", None)
|
||||||
update_tools_messages = chunks.get("tools", None)
|
update_tools_messages = chunks.get("tools", None)
|
||||||
payload_out = {
|
payload_out = {
|
||||||
"node": "",
|
"node": "",
|
||||||
# "checkpoint_id": current_cp_id or "unknown", TODO 替换为checkpoint_idns
|
|
||||||
"is_delta": False,
|
"is_delta": False,
|
||||||
"content": "",
|
"content": "",
|
||||||
"type": "updates"
|
"type": "updates"
|
||||||
@@ -187,7 +211,6 @@ async def chat_stream(request: DeepAgentChatRequest):
|
|||||||
subagent_name = metadata.get('lc_agent_name', "main")
|
subagent_name = metadata.get('lc_agent_name', "main")
|
||||||
payload_out = {
|
payload_out = {
|
||||||
"node": subagent_name,
|
"node": subagent_name,
|
||||||
# "checkpoint_id": current_cp_id or "unknown", TODO 替换为checkpoint_idns
|
|
||||||
"is_delta": False,
|
"is_delta": False,
|
||||||
"content": "",
|
"content": "",
|
||||||
"type": ""
|
"type": ""
|
||||||
@@ -220,7 +243,6 @@ async def chat_stream(request: DeepAgentChatRequest):
|
|||||||
payload_out.update({
|
payload_out.update({
|
||||||
"type": "tool_call",
|
"type": "tool_call",
|
||||||
"is_delta": True,
|
"is_delta": True,
|
||||||
# "tool_call_chunk": token.tool_call_chunks[0] if token.tool_call_chunks else None
|
|
||||||
})
|
})
|
||||||
yield f"data: {json.dumps(payload_out, ensure_ascii=False)}\n\n"
|
yield f"data: {json.dumps(payload_out, ensure_ascii=False)}\n\n"
|
||||||
elif isinstance(token, ToolMessageChunk): # 工具返回
|
elif isinstance(token, ToolMessageChunk): # 工具返回
|
||||||
@@ -247,7 +269,6 @@ async def chat_stream(request: DeepAgentChatRequest):
|
|||||||
elif mode == "custom":
|
elif mode == "custom":
|
||||||
payload_out = {
|
payload_out = {
|
||||||
"node": "research-agent",
|
"node": "research-agent",
|
||||||
# "checkpoint_id": current_cp_id or "unknown", TODO 替换为checkpoint_idns
|
|
||||||
"is_delta": False,
|
"is_delta": False,
|
||||||
"content": "",
|
"content": "",
|
||||||
"type": ""
|
"type": ""
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ class AgentConfig(BaseModel):
|
|||||||
|
|
||||||
class DeepAgentChatRequest(BaseModel):
|
class DeepAgentChatRequest(BaseModel):
|
||||||
message: str = Field(..., description="用户的输入指令")
|
message: str = Field(..., description="用户的输入指令")
|
||||||
# image_url: Optional[str] = Field(None, description="图片地址") # ✅ 新增
|
image_url: Optional[str] = Field(None, description="图片地址") # ✅ 新增
|
||||||
thread_id: Optional[str] = Field(None, description="会话线程ID,不传则开启新会话")
|
thread_id: Optional[str] = Field(None, description="会话线程ID,不传则开启新会话")
|
||||||
checkpoint_id: Optional[str] = Field(None, description="回溯点的ID,用于从历史点开启新对话")
|
checkpoint_id: Optional[str] = Field(None, description="回溯点的ID,用于从历史点开启新对话")
|
||||||
config_params: Optional[AgentConfig] = None
|
config_params: Optional[AgentConfig] = None
|
||||||
|
|||||||
@@ -1,77 +0,0 @@
|
|||||||
import json
|
|
||||||
|
|
||||||
from src.server.deep_agent.init_llm import vision_llm
|
|
||||||
from src.server.deep_agent.tools.vision_analyze_tool import vision_analyze_tool
|
|
||||||
|
|
||||||
vision_subagent = {
|
|
||||||
"name": "vision_subagent",
|
|
||||||
"description": "分析用户上传的图片,提取家具、风格、颜色、材质等信息",
|
|
||||||
"system_prompt": """
|
|
||||||
你是一个专业的视觉分析助手(家具设计方向)。
|
|
||||||
|
|
||||||
你的任务:
|
|
||||||
1. 理解用户提供的图片(路径或URL)
|
|
||||||
2. 分析家具内容
|
|
||||||
3. 输出结构化JSON(不要解释)
|
|
||||||
|
|
||||||
格式:
|
|
||||||
{
|
|
||||||
"objects": [],
|
|
||||||
"style": "",
|
|
||||||
"color": [],
|
|
||||||
"material": [],
|
|
||||||
"room_type": "",
|
|
||||||
"description": ""
|
|
||||||
}
|
|
||||||
""",
|
|
||||||
"tools": [], # ❗这里不用tool,直接用多模态模型
|
|
||||||
"model": vision_llm,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def vision_execute(state):
|
|
||||||
image = state.get("image")
|
|
||||||
|
|
||||||
if image is None:
|
|
||||||
return {
|
|
||||||
"error": "NO_IMAGE"
|
|
||||||
}
|
|
||||||
|
|
||||||
prompt = """
|
|
||||||
你是一个家具视觉分析模型。
|
|
||||||
|
|
||||||
任务:分析图片并输出JSON:
|
|
||||||
|
|
||||||
{
|
|
||||||
"objects": [],
|
|
||||||
"style": "",
|
|
||||||
"color": [],
|
|
||||||
"material": [],
|
|
||||||
"room_type": "",
|
|
||||||
"description": ""
|
|
||||||
}
|
|
||||||
|
|
||||||
规则:
|
|
||||||
- 只基于图像内容
|
|
||||||
- 不允许编造
|
|
||||||
- objects 最多5个
|
|
||||||
- color 最多3个
|
|
||||||
- 只输出JSON
|
|
||||||
"""
|
|
||||||
|
|
||||||
result = vision_llm.generate(
|
|
||||||
image=image, # ⭐ 关键:真正喂图
|
|
||||||
prompt=prompt
|
|
||||||
)
|
|
||||||
|
|
||||||
return safe_parse_json(result)
|
|
||||||
|
|
||||||
|
|
||||||
def safe_parse_json(text):
|
|
||||||
try:
|
|
||||||
return json.loads(text)
|
|
||||||
except:
|
|
||||||
return {
|
|
||||||
"error": "INVALID_JSON",
|
|
||||||
"raw": text
|
|
||||||
}
|
|
||||||
@@ -60,20 +60,95 @@ def build_system_prompt(use_report):
|
|||||||
|
|
||||||
def build_painter_prompt():
|
def build_painter_prompt():
|
||||||
prompt = """
|
prompt = """
|
||||||
你是 painter_subagent,专门生成或编辑 sketch 图。
|
你是 painter_subagent,专门负责「生成」或「编辑」 sketch 图像的工具调度助手。
|
||||||
1. 每次开始决策前,先调用工具 read_file("/current_sketch_path.txt") 获取当前路径。
|
|
||||||
- 如果文件不存在或返回空 → 当前没有历史图,使用 generate_sketch。
|
|
||||||
- 如果有路径 → 检查用户意图是否为「修改/编辑/改成/调整/优化/把...变成」,如果是则必须使用 edit_sketch,并传入 image_path = 读取到的路径。
|
|
||||||
2. 生成或编辑完成后,**必须立即**调用 write_file("/current_sketch_path.txt", content=本次生成的图片完整路径) 来更新状态。
|
|
||||||
3. 【对用户隐藏路径】:
|
|
||||||
- 永远不要在最终回复给用户的任何消息中出现路径、/tmp/、/current_sketch_path.txt 等字符串!
|
|
||||||
- 回复格式只能是:
|
|
||||||
"图片已成功生成!"
|
|
||||||
或
|
|
||||||
"已按你的要求把狗改成猫,图片更新完成!"
|
|
||||||
- 如果前端支持图片展示,你可以直接返回图片(但不要带路径文字)。
|
|
||||||
|
|
||||||
现在开始严格遵守以上规则。
|
你的唯一任务是:根据用户意图,严格选择正确的工具(generate_furniture 或 edit_furniture),并构造对应参数。
|
||||||
|
|
||||||
|
--------------------------------
|
||||||
|
【一、工具选择规则(最高优先级)】
|
||||||
|
|
||||||
|
你必须先判断用户意图属于以下哪一类:
|
||||||
|
|
||||||
|
### ✅ 1. 编辑类(必须使用 edit_furniture)
|
||||||
|
当用户输入包含以下语义时:
|
||||||
|
- 修改 / 改成 / 换成 / 调整 / 优化 / 变成 / 改颜色 / 改样式
|
||||||
|
- 或任何“基于已有图片做改变”的表达
|
||||||
|
|
||||||
|
👉 必须使用:
|
||||||
|
edit_furniture
|
||||||
|
|
||||||
|
👉 严格要求:
|
||||||
|
- 不允许调用 generate_furniture
|
||||||
|
- 不允许重新生成整张图
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### ✅ 2. 生成类(使用 generate_furniture)
|
||||||
|
仅当用户明确表达:
|
||||||
|
- 生成 / 创建 / 设计 / 画一个 / 给我一个
|
||||||
|
|
||||||
|
👉 才允许使用:
|
||||||
|
generate_furniture
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### ❗默认规则(非常重要)
|
||||||
|
如果用户输入不明确(例如:“改成绿色”):
|
||||||
|
|
||||||
|
👉 一律视为【编辑类】
|
||||||
|
👉 使用 edit_furniture
|
||||||
|
|
||||||
|
--------------------------------
|
||||||
|
【二、关于图片来源(关键规则)】
|
||||||
|
|
||||||
|
- 当前系统已经提供了一张“当前图片”(不需要你生成 image_url)
|
||||||
|
- ❗禁止你自行编造 image_url
|
||||||
|
- ❗禁止你猜测 image_url
|
||||||
|
- edit_furniture 会自动从上下文获取图片
|
||||||
|
|
||||||
|
--------------------------------
|
||||||
|
【三、参数构造规则】
|
||||||
|
|
||||||
|
调用 edit_furniture 时:
|
||||||
|
|
||||||
|
- 只需要提供:
|
||||||
|
{
|
||||||
|
"prompt": "<英文图像编辑描述>"
|
||||||
|
}
|
||||||
|
|
||||||
|
- prompt 要求:
|
||||||
|
- 清晰描述修改内容
|
||||||
|
- 保留原结构(除非用户明确要求改变)
|
||||||
|
- 示例:
|
||||||
|
"Change the sofa to green color while keeping the original lines and structure."
|
||||||
|
|
||||||
|
--------------------------------
|
||||||
|
【四、禁止行为(强约束)】
|
||||||
|
|
||||||
|
你绝对不能:
|
||||||
|
|
||||||
|
- ❌ 在编辑场景调用 generate_furniture
|
||||||
|
- ❌ 编造 image_url
|
||||||
|
- ❌ 忽略“修改类”意图
|
||||||
|
- ❌ 因为信息少就拒绝调用工具
|
||||||
|
|
||||||
|
--------------------------------
|
||||||
|
【五、用户回复规则(必须遵守)】
|
||||||
|
|
||||||
|
你对用户的最终回复只能是以下格式之一:
|
||||||
|
|
||||||
|
- "图片已成功生成!"
|
||||||
|
- "已按你的要求完成修改,图片已更新!"
|
||||||
|
|
||||||
|
❗禁止输出:
|
||||||
|
- 路径
|
||||||
|
- URL
|
||||||
|
- 工具参数
|
||||||
|
- 解释过程
|
||||||
|
|
||||||
|
--------------------------------
|
||||||
|
|
||||||
|
现在开始工作。
|
||||||
"""
|
"""
|
||||||
return prompt
|
return prompt
|
||||||
|
|
||||||
|
|||||||
@@ -1,36 +1,37 @@
|
|||||||
import json
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import uuid
|
import uuid
|
||||||
from pathlib import Path
|
|
||||||
from typing import Annotated
|
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from google.oauth2 import service_account
|
import logging
|
||||||
from langchain_core.tools import tool
|
|
||||||
from google import genai
|
|
||||||
from google.genai.types import GenerateContentConfig, Modality
|
|
||||||
from langgraph.prebuilt import ToolRuntime
|
|
||||||
|
|
||||||
from minio import Minio
|
from minio import Minio
|
||||||
|
from google import genai
|
||||||
|
from pathlib import Path
|
||||||
|
from datetime import datetime
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
from google.oauth2 import service_account
|
||||||
|
from langgraph.prebuilt import ToolRuntime
|
||||||
|
|
||||||
from src.core.config import settings
|
from src.core.config import settings, MONGO_URI
|
||||||
from src.server.utils.new_oss_client import oss_upload_image, oss_get_image, is_minio_file_exist, oss_upload_image_file
|
from src.server.deep_agent.utils.mongodb_util import ThreadImageMinIOStore
|
||||||
|
|
||||||
|
# from google.genai.types import GenerateContentConfig, Modality
|
||||||
|
# from langgraph.config import get_stream_writer
|
||||||
|
# from src.server.utils.new_oss_client import oss_upload_image, oss_get_image, is_minio_file_exist, oss_upload_image_file
|
||||||
|
|
||||||
|
# 初始化全局凭证和客户端
|
||||||
|
# creds = service_account.Credentials.from_service_account_file(
|
||||||
|
# settings.GOOGLE_GENAI_USE_VERTEXAI,
|
||||||
|
# scopes=["https://www.googleapis.com/auth/cloud-platform"],
|
||||||
|
# )
|
||||||
|
# client = genai.Client(
|
||||||
|
# credentials=creds,
|
||||||
|
# project=settings.GOOGLE_CLOUD_PROJECT,
|
||||||
|
# location=settings.GOOGLE_CLOUD_LOCATION,
|
||||||
|
# vertexai=True
|
||||||
|
# )
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
# 初始化全局凭证和客户端
|
|
||||||
creds = service_account.Credentials.from_service_account_file(
|
|
||||||
settings.GOOGLE_GENAI_USE_VERTEXAI,
|
|
||||||
scopes=["https://www.googleapis.com/auth/cloud-platform"],
|
|
||||||
)
|
|
||||||
|
|
||||||
minio_client = Minio(settings.MINIO_URL, access_key=settings.MINIO_ACCESS, secret_key=settings.MINIO_SECRET, secure=settings.MINIO_SECURE)
|
minio_client = Minio(settings.MINIO_URL, access_key=settings.MINIO_ACCESS, secret_key=settings.MINIO_SECRET, secure=settings.MINIO_SECURE)
|
||||||
client = genai.Client(
|
image_store = ThreadImageMinIOStore(MONGO_URI, "agent_tool_generate_db")
|
||||||
credentials=creds,
|
|
||||||
project=settings.GOOGLE_CLOUD_PROJECT,
|
|
||||||
location=settings.GOOGLE_CLOUD_LOCATION,
|
|
||||||
vertexai=True
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def is_image_path_exist(image_path):
|
def is_image_path_exist(image_path):
|
||||||
@@ -47,6 +48,7 @@ def create_generate_furniture_tool(workspace_dir, width: int = 1024, height: int
|
|||||||
使用 Gemini 图像生成模型根据详细的英文提示词生成家具设计草图。
|
使用 Gemini 图像生成模型根据详细的英文提示词生成家具设计草图。
|
||||||
"""
|
"""
|
||||||
logger.info(f"\n[系统日志] 正在调用 generate_furniture ...")
|
logger.info(f"\n[系统日志] 正在调用 generate_furniture ...")
|
||||||
|
thread_id = runtime.config.get("configurable").get("thread_id")
|
||||||
try:
|
try:
|
||||||
# 1. 生成图像 - local flux2-klein
|
# 1. 生成图像 - local flux2-klein
|
||||||
object_name = f"furniture/sketches/{uuid.uuid4()}.png"
|
object_name = f"furniture/sketches/{uuid.uuid4()}.png"
|
||||||
@@ -67,15 +69,7 @@ def create_generate_furniture_tool(workspace_dir, width: int = 1024, height: int
|
|||||||
image_url = result.get("output_path", None)
|
image_url = result.get("output_path", None)
|
||||||
|
|
||||||
if image_url:
|
if image_url:
|
||||||
filename = os.path.join(workspace_dir, image_url)
|
image_store.save_image_path(thread_id=thread_id, object_path=image_url, metadata={"prompt": prompt, "generated_at": str(datetime.now())})
|
||||||
# 2. 创建本地目录(确保目录存在)
|
|
||||||
local_dir = os.path.dirname(filename)
|
|
||||||
if not os.path.exists(local_dir):
|
|
||||||
os.makedirs(local_dir, exist_ok=True)
|
|
||||||
|
|
||||||
img = oss_get_image(oss_client=minio_client, bucket=image_url.split('/')[0], object_name=image_url[image_url.find('/') + 1:])
|
|
||||||
img.save(filename)
|
|
||||||
|
|
||||||
return image_url
|
return image_url
|
||||||
else:
|
else:
|
||||||
return f"Image generation failed."
|
return f"Image generation failed."
|
||||||
@@ -89,65 +83,42 @@ def create_generate_furniture_tool(workspace_dir, width: int = 1024, height: int
|
|||||||
|
|
||||||
def create_edit_furniture_tool(workspace_dir, width: int = 1024, height: int = 1024):
|
def create_edit_furniture_tool(workspace_dir, width: int = 1024, height: int = 1024):
|
||||||
@tool
|
@tool
|
||||||
async def edit_furniture(prompt: str, input_image_path) -> str:
|
async def edit_furniture(prompt: str, runtime: ToolRuntime) -> str:
|
||||||
"""
|
"""
|
||||||
使用图像生成模型根据详细的英文提示词编辑家具设计草图。
|
使用图像生成模型根据详细的英文提示词编辑家具设计草图。
|
||||||
"""
|
"""
|
||||||
logger.info(f"\n[系统日志] 正在调用 edit_furniture ...")
|
logger.info(f"\n[系统日志] 正在调用 edit_furniture ...")
|
||||||
|
thread_id = runtime.config.get("configurable").get("thread_id")
|
||||||
try:
|
try:
|
||||||
# 0. 编辑前先检查工作环境和minio上是否存在该图像
|
current_image_path = image_store.get_image_path(thread_id).get("current_image_path", False)
|
||||||
input_image_path = input_image_path.lstrip('/')
|
input_image_path = runtime.state.get("files").get("input_image", False)
|
||||||
filename = os.path.join(workspace_dir, input_image_path)
|
if input_image_path or current_image_path:
|
||||||
local_exist = is_image_path_exist(filename)
|
input_path = [path for path in (input_image_path, current_image_path) if path]
|
||||||
minio_exist = is_minio_file_exist(minio_client=minio_client, bucket_name=input_image_path.split('/')[0], object_name=input_image_path.split('/')[0])
|
object_name = f"furniture/sketches/{uuid.uuid4()}.png"
|
||||||
|
bucket_name = "fida-test" # 替换为你的 bucket 名称
|
||||||
|
request_data = {
|
||||||
|
"input_image_paths": input_path,
|
||||||
|
"prompt": prompt,
|
||||||
|
"bucket_name": bucket_name,
|
||||||
|
"object_name": object_name,
|
||||||
|
"width": width,
|
||||||
|
"height": height
|
||||||
|
}
|
||||||
|
async with httpx.AsyncClient(timeout=120) as client:
|
||||||
|
resp = await client.post(
|
||||||
|
f"http://{settings.FLUX2_GEN_IMG_MODEL_URL}/predict",
|
||||||
|
json=request_data,
|
||||||
|
)
|
||||||
|
result = resp.json()
|
||||||
|
image_url = result.get("output_path", None)
|
||||||
|
|
||||||
if not local_exist and not minio_exist:
|
if image_url:
|
||||||
# 两个地方都不存在 直接报错
|
image_store.save_image_path(thread_id=thread_id, object_path=image_url, metadata={"prompt": prompt, "generated_at": str(datetime.now())})
|
||||||
return f"Image generation failed."
|
return image_url
|
||||||
elif local_exist and not minio_exist:
|
else:
|
||||||
# 把本地的上传到minio
|
return f"Image generation failed."
|
||||||
oss_upload_image_file(oss_client=minio_client, bucket=input_image_path.split('/')[0], object_name=input_image_path.split('/')[0], file_path=filename)
|
|
||||||
elif not local_exist and minio_exist:
|
|
||||||
# minio的下载到本地
|
|
||||||
img = oss_get_image(oss_client=minio_client, bucket=input_image_path.split('/')[0], object_name=input_image_path.split('/')[0], )
|
|
||||||
img.save(filename)
|
|
||||||
elif minio_exist and local_exist:
|
|
||||||
# 两个地方都存在 直接跳过
|
|
||||||
pass
|
|
||||||
|
|
||||||
# 1. 生成图像 - local flux2-klein
|
|
||||||
object_name = f"furniture/sketches/{uuid.uuid4()}.png"
|
|
||||||
bucket_name = "fida-test" # 替换为你的 bucket 名称
|
|
||||||
request_data = {
|
|
||||||
"input_image_paths": [input_image_path],
|
|
||||||
"prompt": prompt,
|
|
||||||
"bucket_name": bucket_name,
|
|
||||||
"object_name": object_name,
|
|
||||||
"width": width,
|
|
||||||
"height": height
|
|
||||||
}
|
|
||||||
async with httpx.AsyncClient(timeout=120) as client:
|
|
||||||
resp = await client.post(
|
|
||||||
f"http://{settings.FLUX2_GEN_IMG_MODEL_URL}/predict",
|
|
||||||
json=request_data,
|
|
||||||
)
|
|
||||||
result = resp.json()
|
|
||||||
image_url = result.get("output_path", None)
|
|
||||||
|
|
||||||
if image_url:
|
|
||||||
filename = os.path.join(workspace_dir, image_url)
|
|
||||||
# 2. 创建本地目录(确保目录存在)
|
|
||||||
local_dir = os.path.dirname(filename)
|
|
||||||
if not os.path.exists(local_dir):
|
|
||||||
os.makedirs(local_dir, exist_ok=True)
|
|
||||||
|
|
||||||
img = oss_get_image(oss_client=minio_client, bucket=image_url.split('/')[0], object_name=image_url[image_url.find('/') + 1:])
|
|
||||||
img.save(filename)
|
|
||||||
return image_url
|
|
||||||
else:
|
else:
|
||||||
return f"Image generation failed."
|
return f"The picture to be edited does not exist."
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"edit_furniture error :{e}")
|
logger.warning(f"edit_furniture error :{e}")
|
||||||
return "edit_furniture error"
|
return "edit_furniture error"
|
||||||
|
|||||||
144
src/server/deep_agent/utils/mongodb_util.py
Normal file
144
src/server/deep_agent/utils/mongodb_util.py
Normal file
@@ -0,0 +1,144 @@
|
|||||||
|
from typing import Optional
|
||||||
|
from pymongo import MongoClient
|
||||||
|
from pymongo.collection import Collection
|
||||||
|
from pymongo.errors import PyMongoError
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from src.core.config import MONGO_URI
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ThreadImageMinIOStore:
|
||||||
|
"""
|
||||||
|
根據 thread_id 存取/更新 current_image 的 MinIO 物件路徑(不存 binary)
|
||||||
|
|
||||||
|
儲存格式範例:
|
||||||
|
{
|
||||||
|
"thread_id": "thread_abc123",
|
||||||
|
"current_image_path": "images/2025/03/thread_abc123_latest.png",
|
||||||
|
"updated_at": ISODate,
|
||||||
|
"metadata": {"format": "png", "desc": "生成的貓圖", "size_bytes": 512345}
|
||||||
|
}
|
||||||
|
|
||||||
|
使用方式:
|
||||||
|
store = ThreadImageMinIOStore("mongodb://localhost:27017/", "deepagents_db")
|
||||||
|
store.save_image_path("thread_abc123", "images/cat/001.png", "https://minio.example.com/bucket/images/cat/001.png")
|
||||||
|
path_info = store.get_image_path("thread_abc123")
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
mongo_uri: str,
|
||||||
|
db_name: str = "deepagents_db",
|
||||||
|
collection_name: str = "agent_image_paths",
|
||||||
|
connect_timeout_ms: int = 5000,
|
||||||
|
server_selection_timeout_ms: int = 5000,
|
||||||
|
):
|
||||||
|
self.client = MongoClient(
|
||||||
|
mongo_uri,
|
||||||
|
connectTimeoutMS=connect_timeout_ms,
|
||||||
|
serverSelectionTimeoutMS=server_selection_timeout_ms,
|
||||||
|
retryWrites=True,
|
||||||
|
retryReads=True,
|
||||||
|
)
|
||||||
|
self.db = self.client[db_name]
|
||||||
|
self.collection: Collection = self.db[collection_name]
|
||||||
|
|
||||||
|
# 建立唯一索引
|
||||||
|
self.collection.create_index("thread_id", unique=True)
|
||||||
|
|
||||||
|
def save_image_path(
|
||||||
|
self,
|
||||||
|
thread_id: str,
|
||||||
|
object_path: str, # MinIO 中的相對路徑,例如 "test/123.png" 或 "images/20250320/abc.png"
|
||||||
|
metadata: Optional[dict] = None
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
保存或更新某個 thread 的 current_image MinIO 路徑
|
||||||
|
|
||||||
|
Args:
|
||||||
|
thread_id: 對話執行緒 ID
|
||||||
|
object_path: MinIO bucket 內的物件路徑 (不含 bucket 名稱)
|
||||||
|
metadata: 額外資訊,例如 {"prompt": "...", "format": "png", "width": 1024}
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 是否成功
|
||||||
|
"""
|
||||||
|
document = {
|
||||||
|
"thread_id": thread_id,
|
||||||
|
"current_image_path": object_path,
|
||||||
|
"updated_at": datetime.now(),
|
||||||
|
"metadata": metadata or {}
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = self.collection.update_one(
|
||||||
|
{"thread_id": thread_id},
|
||||||
|
{"$set": document},
|
||||||
|
upsert=True
|
||||||
|
)
|
||||||
|
action = "updated" if result.modified_count > 0 else "inserted"
|
||||||
|
logger.info(f"Image path for thread {thread_id} {action}: {object_path}")
|
||||||
|
return True
|
||||||
|
except PyMongoError as e:
|
||||||
|
logger.error(f"Failed to save image path for thread {thread_id}: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def get_image_path(self, thread_id: str) -> Optional[dict]:
|
||||||
|
"""
|
||||||
|
取得某 thread 的 current_image MinIO 資訊
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
{
|
||||||
|
"current_image_path": str,
|
||||||
|
"updated_at": datetime,
|
||||||
|
"metadata": dict
|
||||||
|
} 或 None
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
doc = self.collection.find_one({"thread_id": thread_id})
|
||||||
|
if not doc:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return {
|
||||||
|
"current_image_path": doc.get("current_image_path"),
|
||||||
|
"updated_at": doc.get("updated_at"),
|
||||||
|
"metadata": doc.get("metadata", {})
|
||||||
|
}
|
||||||
|
except PyMongoError as e:
|
||||||
|
logger.error(f"Failed to get image path for thread {thread_id}: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_object_path_only(self, thread_id: str) -> Optional[str]:
|
||||||
|
"""只取 MinIO 相對路徑,方便直接給 MinIO client 使用"""
|
||||||
|
info = self.get_image_path(thread_id)
|
||||||
|
return info["current_image_path"] if info else None
|
||||||
|
|
||||||
|
def delete_image_path(self, thread_id: str) -> bool:
|
||||||
|
"""刪除某 thread 的記錄(不影響 MinIO 實際檔案)"""
|
||||||
|
try:
|
||||||
|
result = self.collection.delete_one({"thread_id": thread_id})
|
||||||
|
if result.deleted_count > 0:
|
||||||
|
logger.info(f"Image path record for thread {thread_id} deleted")
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
except PyMongoError as e:
|
||||||
|
logger.error(f"Failed to delete image path for thread {thread_id}: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
self.client.close()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
image_store = ThreadImageMinIOStore(MONGO_URI, "agent_tool_generate_db")
|
||||||
|
success = image_store.save_image_path(
|
||||||
|
thread_id="121233",
|
||||||
|
object_path="test/123.png",
|
||||||
|
metadata={"prompt": "prompt", "generated_at": str(datetime.now())})
|
||||||
|
print(success)
|
||||||
|
info = image_store.get_image_path("121233")
|
||||||
|
print(info)
|
||||||
@@ -3,7 +3,7 @@ import logging
|
|||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
import urllib3
|
import urllib3
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from minio import Minio
|
from minio import Minio, S3Error
|
||||||
|
|
||||||
from src.core.config import settings
|
from src.core.config import settings
|
||||||
|
|
||||||
@@ -72,10 +72,10 @@ def oss_upload_image_file(oss_client, bucket, object_name, file_path):
|
|||||||
|
|
||||||
def get_presigned_url(oss_client, bucket, object_name):
|
def get_presigned_url(oss_client, bucket, object_name):
|
||||||
try:
|
try:
|
||||||
presigned_url = oss_client.presigned_get_object(
|
presigned_url = oss_client.get_presigned_url(
|
||||||
|
"GET",
|
||||||
bucket_name=bucket,
|
bucket_name=bucket,
|
||||||
object_name=object_name,
|
object_name=object_name,
|
||||||
expires=3600
|
|
||||||
)
|
)
|
||||||
return presigned_url
|
return presigned_url
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -83,18 +83,41 @@ def get_presigned_url(oss_client, bucket, object_name):
|
|||||||
return "object not found"
|
return "object not found"
|
||||||
|
|
||||||
|
|
||||||
def is_minio_file_exist(minio_client: Minio, bucket_name: str, object_name: str) -> bool:
|
def is_minio_file_exist(oss_client: Minio, bucket_name: str, object_name: str) -> bool:
|
||||||
try:
|
try:
|
||||||
# 核心判断:检查MinIO中指定bucket+object是否存在
|
# 核心判断:检查MinIO中指定bucket+object是否存在
|
||||||
minio_client.stat_object(bucket_name, object_name)
|
oss_client.stat_object(bucket_name, object_name)
|
||||||
return True
|
return True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def load_minio_file_to_state(oss_client, bucket: str, object_name: str, display_filename: str = None):
|
||||||
|
try:
|
||||||
|
# 下載 object 成 bytes
|
||||||
|
response = oss_client.get_object(
|
||||||
|
bucket_name=bucket,
|
||||||
|
object_name=object_name,
|
||||||
|
)
|
||||||
|
data_bytes = response.read()
|
||||||
|
response.close()
|
||||||
|
response.release_conn()
|
||||||
|
|
||||||
|
# 決定在 agent 裡顯示的檔名(可覆寫,避免洩漏真實 object name)
|
||||||
|
filename = display_filename or object_name.split("/")[-1]
|
||||||
|
|
||||||
|
# 回傳適合塞進 state["files"] 的格式
|
||||||
|
return {filename: data_bytes}
|
||||||
|
|
||||||
|
except S3Error as err:
|
||||||
|
raise ValueError(f"MinIO 下載失敗: {err}")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
url = 'fida-test/furniture/sketches/9356e3d8-d56e-4478-adde-61b29119979b.png'
|
url = 'fida-test/furniture/sketches/1b82b2db-8019-4796-b2cc-11fb24c7799d.png'
|
||||||
read_type = "2"
|
read_type = "2"
|
||||||
img = oss_get_image(oss_client=minio_client, bucket=url.split('/')[0], object_name=url[url.find('/') + 1:])
|
img = oss_get_image(oss_client=minio_client, bucket=url.split('/')[0], object_name=url[url.find('/') + 1:])
|
||||||
img.show()
|
img.show()
|
||||||
img.save("result.png")
|
img.save("result.png")
|
||||||
|
# get_presigned_url(oss_client=minio_client, bucket="fida-test", object_name="furniture/sketches/07bf4cfe-4502-4821-b78f-7727bf409498.png")
|
||||||
|
#
|
||||||
Reference in New Issue
Block a user