Files
FiDA_Python/src/server/deep_agent/tools/structured_retrieval_tool.py

234 lines
7.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import os
import re
import json
from datetime import datetime
from typing import List, Dict, Optional
from pydantic import BaseModel, Field
from langchain_core.tools import tool
from langchain_core.documents import Document
# RAG
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
from sentence_transformers import CrossEncoder
# =========================
# Global models (module-level singletons)
# =========================
# Both objects are built once, when this module is imported, and are then
# reused by every call of the retrieval tool defined below.

# Embedding model handed to FAISS when the per-call vector index is built.
_EMBEDDING_MODEL = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

# Cross-encoder that scores (query, text) pairs in the re-ranking step.
_RERANK_MODEL = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
class StructuredRetrievalInput(BaseModel):
    # Argument schema for the `structured_retrieval` tool.  Documented with
    # `#` comments rather than a class docstring so the schema generated from
    # this model is left untouched.

    # Paths of the local markdown files to read and index.
    file_paths: List[str] = Field(
        ...,
        description="List of local markdown file paths.",
    )

    # Query text used for the vector search and for re-ranking.
    query: str = Field(
        ...,
        description="Extraction query",
    )

    # When provided, takes precedence over any source found inside a file.
    source_url: Optional[str] = Field(
        default=None,
        description="Optional global source URL",
    )
def _extract_source_from_md(content: str) -> Optional[str]:
match = re.search(r"<!--\s*Source:\s*(.*?)\s*-->", content)
return match.group(1).strip() if match else None
# =========================
# Markdown Header Split
# =========================
def _split_markdown_by_headers(
content: str,
max_chars: int = 2000,
overlap: int = 150,
):
header_re = re.compile(
r'^(#{1,6})\s+(.+?)\s*$',
re.MULTILINE
)
matches = list(header_re.finditer(content))
if not matches:
return _chunk_text(content, max_chars, overlap)
sections = []
for i, m in enumerate(matches):
start = m.start()
end = (
matches[i + 1].start()
if i + 1 < len(matches)
else len(content)
)
block = content[start:end].strip()
if block:
sections.append(block)
final_sections = []
for s in sections:
if len(s) > max_chars:
final_sections.extend(
_chunk_text(s, max_chars, overlap)
)
else:
final_sections.append(s)
return final_sections
def _chunk_text(
text: str,
max_chars: int = 2000,
overlap: int = 150
):
text = text.strip()
if len(text) <= max_chars:
return [text]
chunks = []
start = 0
while start < len(text):
end = min(len(text), start + max_chars)
chunk = text[start:end].strip()
if chunk:
chunks.append(chunk)
if end == len(text):
break
start = max(0, end - overlap)
return chunks
def create_structured_retrieval_tool(workspace_dir):
    """Build the ``structured_retrieval`` tool bound to ``workspace_dir``.

    Args:
        workspace_dir: Directory under which an ``extracted`` sub-directory is
            created to hold the JSON file written by each successful call.

    Returns:
        The object produced by applying the ``@tool`` decorator to the inner
        ``structured_retrieval`` function.
    """

    # NOTE: the inner docstring is left verbatim because it is runtime text
    # (presumably consumed by the ``@tool`` decorator as the tool description).
    @tool("structured_retrieval", args_schema=StructuredRetrievalInput)
    def structured_retrieval(
        file_paths: List[str],
        query: str,
        source_url: Optional[str] = None
    ) -> Dict:
        """
        Batch structured extraction from markdown files.
        - Performs vector search + re-ranking
        - Saves extracted structured data as JSON file to disk
        - Returns ONLY summary (status, count, file path)
        """
        # ── 1. Collect the content of every input file ─────────────────
        all_docs_pool: List[Document] = []
        for path in file_paths:
            # isfile() rather than exists(): a directory whose name ends in
            # ".md" is skipped instead of making open() raise.
            if not os.path.isfile(path) or not path.endswith((".md", ".markdown")):
                continue
            file_name = os.path.basename(path)
            with open(path, "r", encoding="utf-8") as f:
                content = f.read()
            # Precedence: explicit argument, then the file's own source
            # comment, then the literal "unknown".
            current_source = source_url or _extract_source_from_md(content) or "unknown"
            for sec in _split_markdown_by_headers(content):
                all_docs_pool.append(
                    Document(
                        page_content=sec,
                        metadata={"source_url": current_source, "file_name": file_name}
                    )
                )
        if not all_docs_pool:
            return {"status": "no_documents_found", "items_count": 0, "json_path": None}

        # ── 2. Vector search ───────────────────────────────────────────
        vector_store = FAISS.from_documents(all_docs_pool, _EMBEDDING_MODEL)
        retrieved = vector_store.similarity_search(query, k=200)

        # ── 3. Extract structured snippets ─────────────────────────────
        structured_items = []
        for doc in retrieved:
            text = doc.page_content.strip()
            if len(text) < 30:
                # Very short fragments are dropped.
                continue
            # dict.fromkeys de-duplicates while keeping first-seen order, so
            # the JSON written below is identical from run to run (a set
            # would order the URLs arbitrarily).
            images = list(dict.fromkeys(re.findall(r"!\[.*?\]\((.*?)\)", text)))
            structured_items.append(
                {
                    "text": text,
                    "images": images,
                    "source_url": doc.metadata.get("source_url"),
                    "file_name": doc.metadata.get("file_name")
                }
            )

        # ── 4. Re-rank ─────────────────────────────────────────────────
        if structured_items:
            # De-duplicate by text; for repeated text the last item wins.
            unique_items = list({item["text"]: item for item in structured_items}.values())
            pairs = [[query, item["text"]] for item in unique_items]
            scores = _RERANK_MODEL.predict(pairs)
            sorted_items = sorted(
                zip(scores, unique_items),
                key=lambda x: x[0],
                reverse=True
            )
            top_items = [item for _, item in sorted_items[:50]]
        else:
            top_items = []

        # ── 5. Write the JSON file ─────────────────────────────────────
        if not top_items:
            return {"status": "no_relevant_content", "items_count": 0, "json_path": None}

        # Build a readable file name from the query and a timestamp.
        safe_query = re.sub(r'[^a-zA-Z0-9\u4e00-\u9fa5]', '_', query)[:40]
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        json_filename = f"extracted_{safe_query}_{timestamp}.json"

        # Output directory (kept aligned with crawl4ai_batch, per the
        # original note).  makedirs creates intermediate directories, and
        # exist_ok=True makes a separate existence check unnecessary.
        output_dir = os.path.join(workspace_dir, "extracted")
        os.makedirs(output_dir, exist_ok=True)

        json_path = os.path.join(output_dir, json_filename)
        with open(json_path, "w", encoding="utf-8") as f:
            json.dump(
                {
                    "query": query,
                    "extracted_at": timestamp,
                    "item_count": len(top_items),
                    "items": top_items
                },
                f,
                ensure_ascii=False,
                indent=2
            )

        # ── 6. Return only a summary ───────────────────────────────────
        return {
            "status": "success",
            "items_count": len(top_items),
            "json_path": json_path,
            "summary": f"已提取 {len(top_items)} 個高相關片段,儲存於 {json_path}"
        }

    return structured_retrieval