1.新增视觉能力 2.新增对上次图片 或 上传图片 引用图片做编辑能力.
This commit is contained in:
@@ -10,16 +10,18 @@ from typing import AsyncGenerator
|
||||
from fastapi.responses import StreamingResponse
|
||||
from langchain_core.messages import SystemMessage, AIMessageChunk, ToolMessage, AIMessage, ToolMessageChunk
|
||||
|
||||
from src.core.config import PROJECT_ROOT, settings
|
||||
from src.core.config import PROJECT_ROOT, settings, MONGO_URI
|
||||
from src.server.deep_agent.agents.main_agent import build_main_agent
|
||||
from src.server.deep_agent.tools.conversation_title_tool import conversation_title
|
||||
from src.server.deep_agent.tools.generate_furniture_sketch import is_image_path_exist
|
||||
from src.schemas.deep_agent_chat import DeepAgentChatRequest, HistoryResponse, HistoryItem
|
||||
from src.server.deep_agent.tools.extract_suggested_questions import generate_suggested_questions
|
||||
from src.server.deep_agent.utils.mongodb_util import ThreadImageMinIOStore
|
||||
from src.server.utils.new_oss_client import is_minio_file_exist, oss_upload_image_file, oss_get_image, get_presigned_url
|
||||
|
||||
router = APIRouter(prefix="/chat", tags=["Furniture Design Chat"])
|
||||
logger = logging.getLogger(__name__)
|
||||
image_store = ThreadImageMinIOStore(MONGO_URI, "agent_tool_generate_db")
|
||||
minio_client = Minio(settings.MINIO_URL, access_key=settings.MINIO_ACCESS, secret_key=settings.MINIO_SECRET, secure=settings.MINIO_SECURE)
|
||||
|
||||
|
||||
@@ -75,7 +77,6 @@ async def chat_stream(request: DeepAgentChatRequest):
|
||||
响应流包含三种类型的事件:会话开始、节点消息、会话结束
|
||||
|
||||
"""
|
||||
logger.info(f"chat request data: {request}")
|
||||
if request.thread_id:
|
||||
need_title = False
|
||||
else:
|
||||
@@ -88,7 +89,7 @@ async def chat_stream(request: DeepAgentChatRequest):
|
||||
target_thread_id = str(uuid.uuid4())[:8] if is_branching else (source_thread_id or str(uuid.uuid4())[:8])
|
||||
# 构建主agent
|
||||
workspace_dir = os.path.join(PROJECT_ROOT, f"agent_workspace/{target_thread_id}")
|
||||
logger.info(f"target_thread_id : workspace_dir: {workspace_dir}")
|
||||
logger.info(f"chat request data: {request} | target_thread_id : workspace_dir: {workspace_dir}")
|
||||
main_agent = build_main_agent(request.use_report, workspace_dir)
|
||||
# 2. 配置參數
|
||||
temp = request.config_params.temperature if request.config_params else 0.7
|
||||
@@ -132,39 +133,46 @@ async def chat_stream(request: DeepAgentChatRequest):
|
||||
|
||||
async def event_generator() -> AsyncGenerator[str, None]:
|
||||
is_first = True
|
||||
if request.image_url:
|
||||
bucket, object_name = request.image_url.split('/', 1)
|
||||
content = [{"type": "text", "text": request.message}]
|
||||
files = {
|
||||
"input_image": [],
|
||||
"quote_image": "",
|
||||
"current_image": ""
|
||||
}
|
||||
# 用户上传图片
|
||||
if request.input_image_path:
|
||||
for path in request.input_image_path:
|
||||
bucket, object_name = path.split('/', 1)
|
||||
image_url = get_presigned_url(oss_client=minio_client, bucket=bucket, object_name=object_name)
|
||||
content.append({"type": "image_url", "image_url": {"url": image_url}})
|
||||
files["input_image"].append(path)
|
||||
|
||||
# 用户引用图片
|
||||
if len(request.quote_image_path):
|
||||
bucket, object_name = request.quote_image_path.split('/', 1)
|
||||
image_url = get_presigned_url(oss_client=minio_client, bucket=bucket, object_name=object_name)
|
||||
new_messages = {
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": request.message},
|
||||
{"type": "image_url", "image_url": {"url": image_url}}
|
||||
]
|
||||
},
|
||||
],
|
||||
"files": {
|
||||
"input_image": request.image_url,
|
||||
}
|
||||
}
|
||||
else:
|
||||
new_messages = {
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": request.message},
|
||||
]
|
||||
},
|
||||
],
|
||||
"files": {
|
||||
"input_image": "",
|
||||
}
|
||||
}
|
||||
content.append({"type": "image_url", "image_url": {"url": image_url}})
|
||||
files["quote_image"] = request.quote_image_path
|
||||
|
||||
# 用户最近生成图片
|
||||
if image_store.get_image_path(target_thread_id):
|
||||
current_image_path = image_store.get_image_path(target_thread_id).get("current_image_path", False)
|
||||
if current_image_path:
|
||||
bucket, object_name = current_image_path.split('/', 1)
|
||||
image_url = get_presigned_url(oss_client=minio_client, bucket=bucket, object_name=object_name)
|
||||
content.append({"type": "image_url", "image_url": {"url": image_url}})
|
||||
|
||||
final_messages = {
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": content
|
||||
},
|
||||
],
|
||||
"files": files
|
||||
}
|
||||
async for stream in main_agent.astream(
|
||||
new_messages,
|
||||
final_messages,
|
||||
config=current_config,
|
||||
stream_mode=["updates", "messages", "custom"],
|
||||
subgraphs=True
|
||||
@@ -175,6 +183,8 @@ async def chat_stream(request: DeepAgentChatRequest):
|
||||
is_first = False
|
||||
_, mode, chunks = stream
|
||||
if mode == "updates": # 只做记录 不做事件返回
|
||||
logger.info(f"[updates] -- {chunks}")
|
||||
|
||||
update_model_messages = chunks.get("model", None)
|
||||
update_tools_messages = chunks.get("tools", None)
|
||||
payload_out = {
|
||||
@@ -194,7 +204,6 @@ async def chat_stream(request: DeepAgentChatRequest):
|
||||
"node": model_name if model_name else "main",
|
||||
"tool_calls": model_token.tool_calls
|
||||
})
|
||||
logger.info(f"[updates] {model_name} -- {model_content_blocks} -- {model_token.tool_calls}")
|
||||
yield f"data: {json.dumps(payload_out, ensure_ascii=False)}\n\n"
|
||||
elif update_tools_messages:
|
||||
tools_messages = update_tools_messages.get("messages", [])
|
||||
@@ -207,6 +216,8 @@ async def chat_stream(request: DeepAgentChatRequest):
|
||||
logger.info(f"[updates] -- {chunks}")
|
||||
|
||||
elif mode == "messages":
|
||||
# logger.info(f"[messages] -- {chunks}")
|
||||
|
||||
token, metadata = chunks
|
||||
subagent_name = metadata.get('lc_agent_name', "main")
|
||||
payload_out = {
|
||||
@@ -267,6 +278,8 @@ async def chat_stream(request: DeepAgentChatRequest):
|
||||
continue
|
||||
|
||||
elif mode == "custom":
|
||||
logger.info(f"[custom] -- {chunks}")
|
||||
|
||||
payload_out = {
|
||||
"node": "research-agent",
|
||||
"is_delta": False,
|
||||
@@ -289,6 +302,7 @@ async def chat_stream(request: DeepAgentChatRequest):
|
||||
# 获取标题
|
||||
if need_title:
|
||||
title = await conversation_title(agent=main_agent, config=current_config)
|
||||
logger.info(f"[need_title] {title}")
|
||||
yield f"data: {json.dumps({'title': title}, ensure_ascii=False)}\n\n"
|
||||
|
||||
yield f"data: {json.dumps({'status': 'end'}, ensure_ascii=False)}\n\n"
|
||||
|
||||
@@ -11,7 +11,8 @@ class AgentConfig(BaseModel):
|
||||
|
||||
class DeepAgentChatRequest(BaseModel):
|
||||
message: str = Field(..., description="用户的输入指令")
|
||||
image_url: Optional[str] = Field(None, description="图片地址") # ✅ 新增
|
||||
quote_image_path: Optional[str] = Field(None, description="引用图片地址") # ✅ 新增
|
||||
input_image_path: Optional[list[str]] = Field(None, description="上传图片地址集合") # ✅ 新增
|
||||
thread_id: Optional[str] = Field(None, description="会话线程ID,不传则开启新会话")
|
||||
checkpoint_id: Optional[str] = Field(None, description="回溯点的ID,用于从历史点开启新对话")
|
||||
config_params: Optional[AgentConfig] = None
|
||||
|
||||
@@ -90,10 +90,20 @@ def create_edit_furniture_tool(workspace_dir, width: int = 1024, height: int = 1
|
||||
logger.info(f"\n[系统日志] 正在调用 edit_furniture ...")
|
||||
thread_id = runtime.config.get("configurable").get("thread_id")
|
||||
try:
|
||||
current_image_path = image_store.get_image_path(thread_id).get("current_image_path", False)
|
||||
input_image_path = runtime.state.get("files").get("input_image", False)
|
||||
if input_image_path or current_image_path:
|
||||
input_path = [path for path in (input_image_path, current_image_path) if path]
|
||||
current_image_path = None
|
||||
if image_store.get_image_path(thread_id):
|
||||
current_image_path = image_store.get_image_path(thread_id).get("current_image_path", False)
|
||||
user_input_image_paths = runtime.state.get("files").get("input_image", [])
|
||||
user_quote_image_path = runtime.state.get("files").get("quote_image", "")
|
||||
input_path = []
|
||||
if len(user_input_image_paths) or current_image_path:
|
||||
if len(user_input_image_paths):
|
||||
for path in user_input_image_paths:
|
||||
input_path.append(path)
|
||||
if user_quote_image_path:
|
||||
input_path.append(user_quote_image_path)
|
||||
if not len(user_input_image_paths) and not user_quote_image_path:
|
||||
input_path = [current_image_path]
|
||||
object_name = f"furniture/sketches/{uuid.uuid4()}.png"
|
||||
bucket_name = "fida-test" # 替换为你的 bucket 名称
|
||||
request_data = {
|
||||
|
||||
Reference in New Issue
Block a user