1.新增视觉能力 2.新增对上次图片 或 上传图片 引用图片做编辑能力.
This commit is contained in:
@@ -10,16 +10,18 @@ from typing import AsyncGenerator
|
|||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
from langchain_core.messages import SystemMessage, AIMessageChunk, ToolMessage, AIMessage, ToolMessageChunk
|
from langchain_core.messages import SystemMessage, AIMessageChunk, ToolMessage, AIMessage, ToolMessageChunk
|
||||||
|
|
||||||
from src.core.config import PROJECT_ROOT, settings
|
from src.core.config import PROJECT_ROOT, settings, MONGO_URI
|
||||||
from src.server.deep_agent.agents.main_agent import build_main_agent
|
from src.server.deep_agent.agents.main_agent import build_main_agent
|
||||||
from src.server.deep_agent.tools.conversation_title_tool import conversation_title
|
from src.server.deep_agent.tools.conversation_title_tool import conversation_title
|
||||||
from src.server.deep_agent.tools.generate_furniture_sketch import is_image_path_exist
|
from src.server.deep_agent.tools.generate_furniture_sketch import is_image_path_exist
|
||||||
from src.schemas.deep_agent_chat import DeepAgentChatRequest, HistoryResponse, HistoryItem
|
from src.schemas.deep_agent_chat import DeepAgentChatRequest, HistoryResponse, HistoryItem
|
||||||
from src.server.deep_agent.tools.extract_suggested_questions import generate_suggested_questions
|
from src.server.deep_agent.tools.extract_suggested_questions import generate_suggested_questions
|
||||||
|
from src.server.deep_agent.utils.mongodb_util import ThreadImageMinIOStore
|
||||||
from src.server.utils.new_oss_client import is_minio_file_exist, oss_upload_image_file, oss_get_image, get_presigned_url
|
from src.server.utils.new_oss_client import is_minio_file_exist, oss_upload_image_file, oss_get_image, get_presigned_url
|
||||||
|
|
||||||
router = APIRouter(prefix="/chat", tags=["Furniture Design Chat"])
|
router = APIRouter(prefix="/chat", tags=["Furniture Design Chat"])
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
image_store = ThreadImageMinIOStore(MONGO_URI, "agent_tool_generate_db")
|
||||||
minio_client = Minio(settings.MINIO_URL, access_key=settings.MINIO_ACCESS, secret_key=settings.MINIO_SECRET, secure=settings.MINIO_SECURE)
|
minio_client = Minio(settings.MINIO_URL, access_key=settings.MINIO_ACCESS, secret_key=settings.MINIO_SECRET, secure=settings.MINIO_SECURE)
|
||||||
|
|
||||||
|
|
||||||
@@ -75,7 +77,6 @@ async def chat_stream(request: DeepAgentChatRequest):
|
|||||||
响应流包含三种类型的事件:会话开始、节点消息、会话结束
|
响应流包含三种类型的事件:会话开始、节点消息、会话结束
|
||||||
|
|
||||||
"""
|
"""
|
||||||
logger.info(f"chat request data: {request}")
|
|
||||||
if request.thread_id:
|
if request.thread_id:
|
||||||
need_title = False
|
need_title = False
|
||||||
else:
|
else:
|
||||||
@@ -88,7 +89,7 @@ async def chat_stream(request: DeepAgentChatRequest):
|
|||||||
target_thread_id = str(uuid.uuid4())[:8] if is_branching else (source_thread_id or str(uuid.uuid4())[:8])
|
target_thread_id = str(uuid.uuid4())[:8] if is_branching else (source_thread_id or str(uuid.uuid4())[:8])
|
||||||
# 构建主agent
|
# 构建主agent
|
||||||
workspace_dir = os.path.join(PROJECT_ROOT, f"agent_workspace/{target_thread_id}")
|
workspace_dir = os.path.join(PROJECT_ROOT, f"agent_workspace/{target_thread_id}")
|
||||||
logger.info(f"target_thread_id : workspace_dir: {workspace_dir}")
|
logger.info(f"chat request data: {request} | target_thread_id : workspace_dir: {workspace_dir}")
|
||||||
main_agent = build_main_agent(request.use_report, workspace_dir)
|
main_agent = build_main_agent(request.use_report, workspace_dir)
|
||||||
# 2. 配置參數
|
# 2. 配置參數
|
||||||
temp = request.config_params.temperature if request.config_params else 0.7
|
temp = request.config_params.temperature if request.config_params else 0.7
|
||||||
@@ -132,39 +133,46 @@ async def chat_stream(request: DeepAgentChatRequest):
|
|||||||
|
|
||||||
async def event_generator() -> AsyncGenerator[str, None]:
|
async def event_generator() -> AsyncGenerator[str, None]:
|
||||||
is_first = True
|
is_first = True
|
||||||
if request.image_url:
|
content = [{"type": "text", "text": request.message}]
|
||||||
bucket, object_name = request.image_url.split('/', 1)
|
files = {
|
||||||
|
"input_image": [],
|
||||||
|
"quote_image": "",
|
||||||
|
"current_image": ""
|
||||||
|
}
|
||||||
|
# 用户上传图片
|
||||||
|
if request.input_image_path:
|
||||||
|
for path in request.input_image_path:
|
||||||
|
bucket, object_name = path.split('/', 1)
|
||||||
|
image_url = get_presigned_url(oss_client=minio_client, bucket=bucket, object_name=object_name)
|
||||||
|
content.append({"type": "image_url", "image_url": {"url": image_url}})
|
||||||
|
files["input_image"].append(path)
|
||||||
|
|
||||||
|
# 用户引用图片
|
||||||
|
if len(request.quote_image_path):
|
||||||
|
bucket, object_name = request.quote_image_path.split('/', 1)
|
||||||
image_url = get_presigned_url(oss_client=minio_client, bucket=bucket, object_name=object_name)
|
image_url = get_presigned_url(oss_client=minio_client, bucket=bucket, object_name=object_name)
|
||||||
new_messages = {
|
content.append({"type": "image_url", "image_url": {"url": image_url}})
|
||||||
"messages": [
|
files["quote_image"] = request.quote_image_path
|
||||||
{
|
|
||||||
"role": "user",
|
# 用户最近生成图片
|
||||||
"content": [
|
if image_store.get_image_path(target_thread_id):
|
||||||
{"type": "text", "text": request.message},
|
current_image_path = image_store.get_image_path(target_thread_id).get("current_image_path", False)
|
||||||
{"type": "image_url", "image_url": {"url": image_url}}
|
if current_image_path:
|
||||||
]
|
bucket, object_name = current_image_path.split('/', 1)
|
||||||
},
|
image_url = get_presigned_url(oss_client=minio_client, bucket=bucket, object_name=object_name)
|
||||||
],
|
content.append({"type": "image_url", "image_url": {"url": image_url}})
|
||||||
"files": {
|
|
||||||
"input_image": request.image_url,
|
final_messages = {
|
||||||
}
|
"messages": [
|
||||||
}
|
{
|
||||||
else:
|
"role": "user",
|
||||||
new_messages = {
|
"content": content
|
||||||
"messages": [
|
},
|
||||||
{
|
],
|
||||||
"role": "user",
|
"files": files
|
||||||
"content": [
|
}
|
||||||
{"type": "text", "text": request.message},
|
|
||||||
]
|
|
||||||
},
|
|
||||||
],
|
|
||||||
"files": {
|
|
||||||
"input_image": "",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
async for stream in main_agent.astream(
|
async for stream in main_agent.astream(
|
||||||
new_messages,
|
final_messages,
|
||||||
config=current_config,
|
config=current_config,
|
||||||
stream_mode=["updates", "messages", "custom"],
|
stream_mode=["updates", "messages", "custom"],
|
||||||
subgraphs=True
|
subgraphs=True
|
||||||
@@ -175,6 +183,8 @@ async def chat_stream(request: DeepAgentChatRequest):
|
|||||||
is_first = False
|
is_first = False
|
||||||
_, mode, chunks = stream
|
_, mode, chunks = stream
|
||||||
if mode == "updates": # 只做记录 不做事件返回
|
if mode == "updates": # 只做记录 不做事件返回
|
||||||
|
logger.info(f"[updates] -- {chunks}")
|
||||||
|
|
||||||
update_model_messages = chunks.get("model", None)
|
update_model_messages = chunks.get("model", None)
|
||||||
update_tools_messages = chunks.get("tools", None)
|
update_tools_messages = chunks.get("tools", None)
|
||||||
payload_out = {
|
payload_out = {
|
||||||
@@ -194,7 +204,6 @@ async def chat_stream(request: DeepAgentChatRequest):
|
|||||||
"node": model_name if model_name else "main",
|
"node": model_name if model_name else "main",
|
||||||
"tool_calls": model_token.tool_calls
|
"tool_calls": model_token.tool_calls
|
||||||
})
|
})
|
||||||
logger.info(f"[updates] {model_name} -- {model_content_blocks} -- {model_token.tool_calls}")
|
|
||||||
yield f"data: {json.dumps(payload_out, ensure_ascii=False)}\n\n"
|
yield f"data: {json.dumps(payload_out, ensure_ascii=False)}\n\n"
|
||||||
elif update_tools_messages:
|
elif update_tools_messages:
|
||||||
tools_messages = update_tools_messages.get("messages", [])
|
tools_messages = update_tools_messages.get("messages", [])
|
||||||
@@ -207,6 +216,8 @@ async def chat_stream(request: DeepAgentChatRequest):
|
|||||||
logger.info(f"[updates] -- {chunks}")
|
logger.info(f"[updates] -- {chunks}")
|
||||||
|
|
||||||
elif mode == "messages":
|
elif mode == "messages":
|
||||||
|
# logger.info(f"[messages] -- {chunks}")
|
||||||
|
|
||||||
token, metadata = chunks
|
token, metadata = chunks
|
||||||
subagent_name = metadata.get('lc_agent_name', "main")
|
subagent_name = metadata.get('lc_agent_name', "main")
|
||||||
payload_out = {
|
payload_out = {
|
||||||
@@ -267,6 +278,8 @@ async def chat_stream(request: DeepAgentChatRequest):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
elif mode == "custom":
|
elif mode == "custom":
|
||||||
|
logger.info(f"[custom] -- {chunks}")
|
||||||
|
|
||||||
payload_out = {
|
payload_out = {
|
||||||
"node": "research-agent",
|
"node": "research-agent",
|
||||||
"is_delta": False,
|
"is_delta": False,
|
||||||
@@ -289,6 +302,7 @@ async def chat_stream(request: DeepAgentChatRequest):
|
|||||||
# 获取标题
|
# 获取标题
|
||||||
if need_title:
|
if need_title:
|
||||||
title = await conversation_title(agent=main_agent, config=current_config)
|
title = await conversation_title(agent=main_agent, config=current_config)
|
||||||
|
logger.info(f"[need_title] {title}")
|
||||||
yield f"data: {json.dumps({'title': title}, ensure_ascii=False)}\n\n"
|
yield f"data: {json.dumps({'title': title}, ensure_ascii=False)}\n\n"
|
||||||
|
|
||||||
yield f"data: {json.dumps({'status': 'end'}, ensure_ascii=False)}\n\n"
|
yield f"data: {json.dumps({'status': 'end'}, ensure_ascii=False)}\n\n"
|
||||||
|
|||||||
@@ -11,7 +11,8 @@ class AgentConfig(BaseModel):
|
|||||||
|
|
||||||
class DeepAgentChatRequest(BaseModel):
|
class DeepAgentChatRequest(BaseModel):
|
||||||
message: str = Field(..., description="用户的输入指令")
|
message: str = Field(..., description="用户的输入指令")
|
||||||
image_url: Optional[str] = Field(None, description="图片地址") # ✅ 新增
|
quote_image_path: Optional[str] = Field(None, description="引用图片地址") # ✅ 新增
|
||||||
|
input_image_path: Optional[list[str]] = Field(None, description="上传图片地址集合") # ✅ 新增
|
||||||
thread_id: Optional[str] = Field(None, description="会话线程ID,不传则开启新会话")
|
thread_id: Optional[str] = Field(None, description="会话线程ID,不传则开启新会话")
|
||||||
checkpoint_id: Optional[str] = Field(None, description="回溯点的ID,用于从历史点开启新对话")
|
checkpoint_id: Optional[str] = Field(None, description="回溯点的ID,用于从历史点开启新对话")
|
||||||
config_params: Optional[AgentConfig] = None
|
config_params: Optional[AgentConfig] = None
|
||||||
|
|||||||
@@ -90,10 +90,20 @@ def create_edit_furniture_tool(workspace_dir, width: int = 1024, height: int = 1
|
|||||||
logger.info(f"\n[系统日志] 正在调用 edit_furniture ...")
|
logger.info(f"\n[系统日志] 正在调用 edit_furniture ...")
|
||||||
thread_id = runtime.config.get("configurable").get("thread_id")
|
thread_id = runtime.config.get("configurable").get("thread_id")
|
||||||
try:
|
try:
|
||||||
current_image_path = image_store.get_image_path(thread_id).get("current_image_path", False)
|
current_image_path = None
|
||||||
input_image_path = runtime.state.get("files").get("input_image", False)
|
if image_store.get_image_path(thread_id):
|
||||||
if input_image_path or current_image_path:
|
current_image_path = image_store.get_image_path(thread_id).get("current_image_path", False)
|
||||||
input_path = [path for path in (input_image_path, current_image_path) if path]
|
user_input_image_paths = runtime.state.get("files").get("input_image", [])
|
||||||
|
user_quote_image_path = runtime.state.get("files").get("quote_image", "")
|
||||||
|
input_path = []
|
||||||
|
if len(user_input_image_paths) or current_image_path:
|
||||||
|
if len(user_input_image_paths):
|
||||||
|
for path in user_input_image_paths:
|
||||||
|
input_path.append(path)
|
||||||
|
if user_quote_image_path:
|
||||||
|
input_path.append(user_quote_image_path)
|
||||||
|
if not len(user_input_image_paths) and not user_quote_image_path:
|
||||||
|
input_path = [current_image_path]
|
||||||
object_name = f"furniture/sketches/{uuid.uuid4()}.png"
|
object_name = f"furniture/sketches/{uuid.uuid4()}.png"
|
||||||
bucket_name = "fida-test" # 替换为你的 bucket 名称
|
bucket_name = "fida-test" # 替换为你的 bucket 名称
|
||||||
request_data = {
|
request_data = {
|
||||||
|
|||||||
Reference in New Issue
Block a user