新增图片上下文存储

This commit is contained in:
zcr
2026-03-30 15:12:56 +08:00
parent 1579c8d0f5
commit e3cf22edae
4 changed files with 121 additions and 248 deletions

View File

@@ -10,18 +10,15 @@ from typing import AsyncGenerator
from fastapi.responses import StreamingResponse
from langchain_core.messages import SystemMessage, AIMessageChunk, ToolMessage, AIMessage, ToolMessageChunk
from src.core.config import PROJECT_ROOT, settings, MONGO_URI
from src.core.config import PROJECT_ROOT, settings
from src.server.deep_agent.agents.main_agent import build_main_agent
from src.server.deep_agent.tools.conversation_title_tool import conversation_title
from src.server.deep_agent.tools.generate_furniture_sketch import is_image_path_exist
from src.schemas.deep_agent_chat import DeepAgentChatRequest, HistoryResponse, HistoryItem
from src.server.deep_agent.tools.extract_suggested_questions import generate_suggested_questions
from src.server.deep_agent.utils.mongodb_util import ThreadImageMinIOStore
from src.server.utils.new_oss_client import is_minio_file_exist, oss_upload_image_file, oss_get_image, get_presigned_url
from src.server.utils.new_oss_client import get_presigned_url
router = APIRouter(prefix="/chat", tags=["Furniture Design Chat"])
logger = logging.getLogger(__name__)
image_store = ThreadImageMinIOStore(MONGO_URI, "agent_tool_generate_db")
minio_client = Minio(settings.MINIO_URL, access_key=settings.MINIO_ACCESS, secret_key=settings.MINIO_SECRET, secure=settings.MINIO_SECURE)
@@ -108,6 +105,7 @@ async def chat_stream(request: DeepAgentChatRequest):
workspace_dir = os.path.join(PROJECT_ROOT, f"agent_workspace/{target_thread_id}")
logger.info(f"chat request data: {request} | target_thread_id : workspace_dir: {workspace_dir}")
main_agent = build_main_agent(request.use_report, workspace_dir, request.enable_thinking)
# 2. 配置參數
temp = request.config_params.temperature if request.config_params else 0.7
@@ -142,11 +140,14 @@ async def chat_stream(request: DeepAgentChatRequest):
"checkpoint_id": checkpoint_id
}
}
last_checkpoint_id = await get_branch_checkpoint_id(main_agent, source_config)
older_state = await main_agent.aget_state(source_config)
combined_values = older_state.values.copy()
if initial_messages:
combined_values["messages"] = list(combined_values.get("messages", [])) + initial_messages
await main_agent.aupdate_state(current_config, combined_values)
else:
last_checkpoint_id = await get_checkpoint_id(main_agent, current_config)
async def event_generator() -> AsyncGenerator[str, None]:
is_first = True
@@ -171,15 +172,6 @@ async def chat_stream(request: DeepAgentChatRequest):
content.append({"type": "image_url", "image_url": {"url": image_url}})
files["quote_image"] = request.quote_image_path
# 用户最近生成图片
if image_store.get_image_path(target_thread_id):
current_image_path = image_store.get_image_path(target_thread_id).get("current_image_path", False)
if current_image_path:
bucket, object_name = current_image_path.split('/', 1)
image_url = get_presigned_url(oss_client=minio_client, bucket=bucket, object_name=object_name)
if image_url is not None:
content.append({"type": "image_url", "image_url": {"url": image_url}})
final_messages = {
"messages": [
{
@@ -197,6 +189,17 @@ async def chat_stream(request: DeepAgentChatRequest):
):
if is_first:
checkpoint_id = main_agent.get_state(current_config).config.get("configurable").get("checkpoint_id")
if not checkpoint_id:
print("123")
main_agent.store.put(
("image_history",),
"checkpoint_id",
{
"current_checkpoint_id": checkpoint_id,
"last_checkpoint_id": last_checkpoint_id,
}
)
logger.info(f"*******************{checkpoint_id}**********************************")
yield f"data: {json.dumps({'thread_id': target_thread_id, 'is_branch': is_branching, 'status': 'start', "checkpoint_id": checkpoint_id}, ensure_ascii=False)}\n\n"
is_first = False
_, mode, chunks = stream
@@ -257,7 +260,7 @@ async def chat_stream(request: DeepAgentChatRequest):
# "tool_call_chunk": token.tool_call_chunks[0] if token.tool_call_chunks else None
})
else:
print(f"[reasoning] {reasoning}*************************************************************************************")
logger.info(f"[reasoning] {reasoning}*************************************************************************************")
elif text:
if len(text) == 1:
payload_out.update({
@@ -267,7 +270,7 @@ async def chat_stream(request: DeepAgentChatRequest):
# "tool_call_chunk": token.tool_call_chunks[0] if token.tool_call_chunks else None
})
else:
print(f"[text] {text}*************************************************************************************")
logger.info(f"[text] {text}*************************************************************************************")
else:
payload_out.update({
"type": "tool_call",
@@ -395,7 +398,29 @@ async def get_chat_history(thread_id: str):
))
return HistoryResponse(thread_id=thread_id, history=history_data)
# try:
# except Exception as e:
# raise HTTPException(status_code=404, detail=f"History not found: {str(e)}")
async def get_checkpoint_id(main_agent, current_config):
# 🔥 最优:边遍历边找,找到第一个就返回,不浪费内存
async for item in main_agent.aget_state_history(config=current_config):
if item.next == ("__start__",):
# 找到直接处理并返回
# if item.parent_config:
# return item.parent_config.get('configurable', {}).get('checkpoint_id')
return item.config.get('configurable', {}).get('checkpoint_id')
# 没找到
return None
async def get_branch_checkpoint_id(main_agent, current_config):
# 🔥 最优:边遍历边找,找到第一个就返回,不浪费内存
async for item in main_agent.aget_state_history(config=current_config):
current_id = current_config.get('configurable', {}).get('checkpoint_id')
if item.next == ("__start__",) and item.config.get('configurable', {}).get('checkpoint_id') != current_id:
if item.parent_config:
if item.parent_config.get('configurable', {}).get('checkpoint_id') != current_id:
return item.config.get('configurable', {}).get('checkpoint_id')
else:
return item.config.get('configurable', {}).get('checkpoint_id')
# 没找到
return None