Files
FiDA_Python/src/server/deep_agent/tools/structured_retrieval_tool.py

226 lines
6.8 KiB
Python
Raw Normal View History

2026-03-11 21:45:46 +08:00
import os
import re
import json
from datetime import datetime
from typing import List, Dict, Optional
from pydantic import BaseModel, Field
from langchain_core.tools import tool
from langchain_core.documents import Document
# RAG
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
from sentence_transformers import CrossEncoder
# =========================
# Global models (module-level singletons)
# =========================
# Both models are constructed once, when this module is imported, and are
# shared by every call to the tools below.

# Embedding model handed to FAISS.from_documents to index markdown sections.
_EMBEDDING_MODEL = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

# Cross-encoder that scores (query, section text) pairs during re-ranking.
_RERANK_MODEL = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
class StructuredRetrievalInput(BaseModel):
    # Argument schema for the `structured_retrieval` tool (passed as
    # `args_schema`). The `description` strings end up in the tool's JSON
    # schema, so they are intentionally left unchanged; a class docstring is
    # avoided for the same reason (pydantic would add it to the schema).

    # Markdown files to search; other paths are skipped by the tool.
    file_paths: List[str] = Field(
        ...,
        description="List of local markdown file paths.",
    )
    # Natural-language query used for both vector search and re-ranking.
    query: str = Field(..., description="Extraction query")
    # When set, overrides any per-file "<!-- Source: ... -->" marker.
    source_url: Optional[str] = Field(
        default=None,
        description="Optional global source URL",
    )
# Size of the FAISS candidate pool handed to the re-ranker.
_VECTOR_TOP_K = 200
# Number of re-ranked sections kept in the JSON output.
_RERANK_TOP_N = 50
# Retrieved sections shorter than this (after stripping) are skipped.
_MIN_TEXT_CHARS = 30
# Captures the target of markdown image links: ![alt](target).
_IMAGE_LINK_RE = re.compile(r"!\[.*?\]\((.*?)\)")


@tool("structured_retrieval", args_schema=StructuredRetrievalInput)
def structured_retrieval(
    file_paths: List[str],
    query: str,
    source_url: Optional[str] = None
) -> Dict:
    """
    Batch structured extraction from markdown files.
    - Performs vector search + re-ranking
    - Saves extracted structured data as JSON file to disk
    - Returns ONLY summary (status, count, file path)
    """
    # Developer notes live in comments because `@tool` publishes the
    # docstring above verbatim as the tool description.
    #
    # * Paths that are not existing regular files with a .md / .markdown
    #   extension (case-insensitive) are skipped.
    # * Returns {"status", "items_count", "json_path"}, plus "summary" on
    #   success. Status is "no_documents_found" when no accepted file yields
    #   a section, "no_relevant_content" when nothing survives filtering.
    # * The JSON file is written to "<dir of first accepted file>/../extracted".

    # ── 1. Collect sections from every accepted markdown file ──────────
    md_paths = [path for path in file_paths if _is_markdown_file(path)]
    all_docs_pool: List[Document] = [
        doc
        for path in md_paths
        for doc in _load_markdown_documents(path, source_url)
    ]
    if not all_docs_pool:
        return {"status": "no_documents_found", "items_count": 0, "json_path": None}

    # ── 2. Vector search ──────────────────────────────────────────────
    vector_store = FAISS.from_documents(all_docs_pool, _EMBEDDING_MODEL)
    retrieved = vector_store.similarity_search(query, k=_VECTOR_TOP_K)

    # ── 3. Turn retrieved sections into structured items ──────────────
    structured_items = [
        item for item in map(_to_structured_item, retrieved) if item is not None
    ]

    # ── 4. Re-rank ────────────────────────────────────────────────────
    top_items = _rerank_items(query, structured_items, _RERANK_TOP_N)
    if not top_items:
        return {"status": "no_relevant_content", "items_count": 0, "json_path": None}

    # ── 5. Write the JSON file ────────────────────────────────────────
    # Anchor on the first *accepted* file rather than file_paths[0]: the
    # latter may have been skipped (missing / not markdown), and makedirs
    # would then create directories next to a path that does not exist.
    output_dir = os.path.join(os.path.dirname(md_paths[0]), "..", "extracted")
    json_path = _write_extraction_json(query, top_items, output_dir)

    # ── 6. Return only a summary ──────────────────────────────────────
    return {
        "status": "success",
        "items_count": len(top_items),
        "json_path": json_path,
        "summary": f"已提取 {len(top_items)} 個高相關片段,儲存於 {json_path}"
    }


def _is_markdown_file(path: str) -> bool:
    """Return True if *path* is an existing regular file with a markdown extension."""
    # isfile (not exists) so a directory named "x.md" is not handed to open().
    return os.path.isfile(path) and path.lower().endswith((".md", ".markdown"))


def _load_markdown_documents(path: str, source_url: Optional[str]) -> List[Document]:
    """Read one markdown file and wrap each header-delimited section in a Document."""
    with open(path, "r", encoding="utf-8") as f:
        content = f.read()
    # An explicit source_url wins over the file's own "<!-- Source: ... -->" marker.
    current_source = source_url or _extract_source_from_md(content) or "unknown"
    file_name = os.path.basename(path)
    return [
        Document(
            page_content=section,
            metadata={"source_url": current_source, "file_name": file_name},
        )
        for section in _split_markdown_by_headers(content)
    ]


def _to_structured_item(doc: Document) -> Optional[Dict]:
    """Convert a retrieved Document into an output item, or None if it is too short."""
    text = doc.page_content.strip()
    if len(text) < _MIN_TEXT_CHARS:
        return None
    return {
        "text": text,
        # dict.fromkeys de-duplicates while keeping first-seen order; set()
        # ordering of strings varies between runs (hash randomization).
        "images": list(dict.fromkeys(_IMAGE_LINK_RE.findall(text))),
        "source_url": doc.metadata.get("source_url"),
        "file_name": doc.metadata.get("file_name"),
    }


def _rerank_items(query: str, items: List[Dict], top_n: int) -> List[Dict]:
    """Score unique item texts against *query* with the cross-encoder; keep the best *top_n*."""
    # Duplicate texts collapse to one entry (position of the first occurrence,
    # metadata of the last).
    unique_items = list({item["text"]: item for item in items}.values())
    if not unique_items:
        return []
    scores = _RERANK_MODEL.predict([[query, item["text"]] for item in unique_items])
    ranked = sorted(zip(scores, unique_items), key=lambda pair: pair[0], reverse=True)
    return [item for _, item in ranked[:top_n]]


def _write_extraction_json(query: str, items: List[Dict], output_dir: str) -> str:
    """Write *items* to a timestamped JSON file under *output_dir*; return its path."""
    # Keep ASCII letters/digits and CJK ideographs; everything else becomes "_".
    safe_query = re.sub(r'[^a-zA-Z0-9\u4e00-\u9fa5]', '_', query)[:40]
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    json_filename = f"extracted_{safe_query}_{timestamp}.json"
    os.makedirs(output_dir, exist_ok=True)
    json_path = os.path.join(output_dir, json_filename)
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(
            {
                "query": query,
                "extracted_at": timestamp,
                "item_count": len(items),
                "items": items
            },
            f,
            ensure_ascii=False,
            indent=2
        )
    return json_path
def _extract_source_from_md(content: str) -> Optional[str]:
match = re.search(r"<!--\s*Source:\s*(.*?)\s*-->", content)
return match.group(1).strip() if match else None
# =========================
# Markdown Header Split
# =========================
def _split_markdown_by_headers(
content: str,
max_chars: int = 2000,
overlap: int = 150,
):
header_re = re.compile(
r'^(#{1,6})\s+(.+?)\s*$',
re.MULTILINE
)
matches = list(header_re.finditer(content))
if not matches:
return _chunk_text(content, max_chars, overlap)
sections = []
for i, m in enumerate(matches):
start = m.start()
end = (
matches[i + 1].start()
if i + 1 < len(matches)
else len(content)
)
block = content[start:end].strip()
if block:
sections.append(block)
final_sections = []
for s in sections:
if len(s) > max_chars:
final_sections.extend(
_chunk_text(s, max_chars, overlap)
)
else:
final_sections.append(s)
return final_sections
def _chunk_text(
text: str,
max_chars: int = 2000,
overlap: int = 150
):
text = text.strip()
if len(text) <= max_chars:
return [text]
chunks = []
start = 0
while start < len(text):
end = min(len(text), start + max_chars)
chunk = text[start:end].strip()
if chunk:
chunks.append(chunk)
if end == len(text):
break
start = max(0, end - overlap)
return chunks