弃用langgrpah更换deepagent
This commit is contained in:
225
src/server/deep_agent/tools/structured_retrieval_tool.py
Normal file
225
src/server/deep_agent/tools/structured_retrieval_tool.py
Normal file
@@ -0,0 +1,225 @@
|
||||
import os
|
||||
import re
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from langchain_core.tools import tool
|
||||
from langchain_core.documents import Document
|
||||
|
||||
# RAG
|
||||
from langchain_community.vectorstores import FAISS
|
||||
from langchain_huggingface import HuggingFaceEmbeddings
|
||||
from sentence_transformers import CrossEncoder
|
||||
|
||||
# =========================
|
||||
# 全局模型(单例)
|
||||
# =========================
|
||||
|
||||
_EMBEDDING_MODEL = HuggingFaceEmbeddings(
|
||||
model_name="sentence-transformers/all-MiniLM-L6-v2"
|
||||
)
|
||||
|
||||
_RERANK_MODEL = CrossEncoder(
|
||||
"cross-encoder/ms-marco-MiniLM-L-6-v2"
|
||||
)
|
||||
|
||||
|
||||
class StructuredRetrievalInput(BaseModel):
|
||||
file_paths: List[str] = Field(..., description="List of local markdown file paths.")
|
||||
query: str = Field(..., description="Extraction query")
|
||||
source_url: Optional[str] = Field(None, description="Optional global source URL")
|
||||
|
||||
|
||||
@tool("structured_retrieval", args_schema=StructuredRetrievalInput)
|
||||
def structured_retrieval(
|
||||
file_paths: List[str],
|
||||
query: str,
|
||||
source_url: Optional[str] = None
|
||||
) -> Dict:
|
||||
"""
|
||||
Batch structured extraction from markdown files.
|
||||
- Performs vector search + re-ranking
|
||||
- Saves extracted structured data as JSON file to disk
|
||||
- Returns ONLY summary (status, count, file path)
|
||||
"""
|
||||
|
||||
# ── 1. 收集所有文件內容 ──────────────────────────────────────
|
||||
all_docs_pool: List[Document] = []
|
||||
|
||||
for path in file_paths:
|
||||
if not os.path.exists(path) or not path.endswith((".md", ".markdown")):
|
||||
continue
|
||||
|
||||
file_name = os.path.basename(path)
|
||||
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
|
||||
current_source = source_url or _extract_source_from_md(content) or "unknown"
|
||||
|
||||
sections = _split_markdown_by_headers(content)
|
||||
|
||||
for sec in sections:
|
||||
all_docs_pool.append(
|
||||
Document(
|
||||
page_content=sec,
|
||||
metadata={"source_url": current_source, "file_name": file_name}
|
||||
)
|
||||
)
|
||||
|
||||
if not all_docs_pool:
|
||||
return {"status": "no_documents_found", "items_count": 0, "json_path": None}
|
||||
|
||||
# ── 2. Vector search ────────────────────────────────────────────
|
||||
vector_store = FAISS.from_documents(all_docs_pool, _EMBEDDING_MODEL)
|
||||
retrieved = vector_store.similarity_search(query, k=200)
|
||||
|
||||
# ── 3. 提取結構化片段 ──────────────────────────────────────────
|
||||
structured_items = []
|
||||
|
||||
for doc in retrieved:
|
||||
text = doc.page_content.strip()
|
||||
if len(text) < 30:
|
||||
continue
|
||||
|
||||
images = list(set(re.findall(r"!\[.*?\]\((.*?)\)", text)))
|
||||
|
||||
structured_items.append(
|
||||
{
|
||||
"text": text,
|
||||
"images": images,
|
||||
"source_url": doc.metadata.get("source_url"),
|
||||
"file_name": doc.metadata.get("file_name")
|
||||
}
|
||||
)
|
||||
|
||||
# ── 4. Re-rank ──────────────────────────────────────────────────
|
||||
if structured_items:
|
||||
unique_items = {item["text"]: item for item in structured_items}.values()
|
||||
pairs = [[query, item["text"]] for item in unique_items]
|
||||
scores = _RERANK_MODEL.predict(pairs)
|
||||
|
||||
sorted_items = sorted(
|
||||
zip(scores, unique_items),
|
||||
key=lambda x: x[0],
|
||||
reverse=True
|
||||
)
|
||||
top_items = [item for _, item in sorted_items[:50]]
|
||||
else:
|
||||
top_items = []
|
||||
|
||||
# ── 5. 寫入 JSON 文件 ──────────────────────────────────────────
|
||||
if not top_items:
|
||||
return {"status": "no_relevant_content", "items_count": 0, "json_path": None}
|
||||
|
||||
# 產生有意義的檔名
|
||||
safe_query = re.sub(r'[^a-zA-Z0-9\u4e00-\u9fa5]', '_', query)[:40]
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
json_filename = f"extracted_{safe_query}_{timestamp}.json"
|
||||
|
||||
# 建議的儲存目錄(與 crawl4ai_batch 對齊)
|
||||
output_dir = os.path.join(os.path.dirname(file_paths[0]), "..", "extracted")
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
json_path = os.path.join(output_dir, json_filename)
|
||||
|
||||
with open(json_path, "w", encoding="utf-8") as f:
|
||||
json.dump(
|
||||
{
|
||||
"query": query,
|
||||
"extracted_at": timestamp,
|
||||
"item_count": len(top_items),
|
||||
"items": top_items
|
||||
},
|
||||
f,
|
||||
ensure_ascii=False,
|
||||
indent=2
|
||||
)
|
||||
|
||||
# ── 6. 只回傳摘要 ──────────────────────────────────────────────
|
||||
return {
|
||||
"status": "success",
|
||||
"items_count": len(top_items),
|
||||
"json_path": json_path,
|
||||
"summary": f"已提取 {len(top_items)} 個高相關片段,儲存於 {json_path}"
|
||||
}
|
||||
|
||||
|
||||
def _extract_source_from_md(content: str) -> Optional[str]:
|
||||
match = re.search(r"<!--\s*Source:\s*(.*?)\s*-->", content)
|
||||
return match.group(1).strip() if match else None
|
||||
|
||||
|
||||
# =========================
|
||||
# Markdown Header Split
|
||||
# =========================
|
||||
|
||||
def _split_markdown_by_headers(
|
||||
content: str,
|
||||
max_chars: int = 2000,
|
||||
overlap: int = 150,
|
||||
):
|
||||
header_re = re.compile(
|
||||
r'^(#{1,6})\s+(.+?)\s*$',
|
||||
re.MULTILINE
|
||||
)
|
||||
|
||||
matches = list(header_re.finditer(content))
|
||||
|
||||
if not matches:
|
||||
return _chunk_text(content, max_chars, overlap)
|
||||
|
||||
sections = []
|
||||
|
||||
for i, m in enumerate(matches):
|
||||
start = m.start()
|
||||
end = (
|
||||
matches[i + 1].start()
|
||||
if i + 1 < len(matches)
|
||||
else len(content)
|
||||
)
|
||||
|
||||
block = content[start:end].strip()
|
||||
if block:
|
||||
sections.append(block)
|
||||
|
||||
final_sections = []
|
||||
|
||||
for s in sections:
|
||||
if len(s) > max_chars:
|
||||
final_sections.extend(
|
||||
_chunk_text(s, max_chars, overlap)
|
||||
)
|
||||
else:
|
||||
final_sections.append(s)
|
||||
|
||||
return final_sections
|
||||
|
||||
|
||||
def _chunk_text(
|
||||
text: str,
|
||||
max_chars: int = 2000,
|
||||
overlap: int = 150
|
||||
):
|
||||
text = text.strip()
|
||||
if len(text) <= max_chars:
|
||||
return [text]
|
||||
|
||||
chunks = []
|
||||
start = 0
|
||||
|
||||
while start < len(text):
|
||||
end = min(len(text), start + max_chars)
|
||||
chunk = text[start:end].strip()
|
||||
|
||||
if chunk:
|
||||
chunks.append(chunk)
|
||||
|
||||
if end == len(text):
|
||||
break
|
||||
|
||||
start = max(0, end - overlap)
|
||||
|
||||
return chunks
|
||||
Reference in New Issue
Block a user