import os
import re
import json
from datetime import datetime
from typing import List, Dict, Optional

from pydantic import BaseModel, Field
from langchain_core.tools import tool
from langchain_core.documents import Document

# RAG
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
from sentence_transformers import CrossEncoder

# =========================
# Global models (singletons)
# =========================
_EMBEDDING_MODEL = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-MiniLM-L6-v2"
)
_RERANK_MODEL = CrossEncoder(
    "cross-encoder/ms-marco-MiniLM-L-6-v2"
)


class StructuredRetrievalInput(BaseModel):
    file_paths: List[str] = Field(..., description="List of local markdown file paths.")
    query: str = Field(..., description="Extraction query")
    source_url: Optional[str] = Field(None, description="Optional global source URL")


def _extract_source_from_md(content: str) -> Optional[str]:
    # NOTE: the original pattern was lost in this copy; the pattern below is an
    # assumption that the crawler embeds the page URL in an HTML comment such as
    # `<!-- source_url: https://... -->` near the top of the markdown file.
    match = re.search(r"<!--\s*source_url:\s*(.+?)\s*-->", content)
    return match.group(1).strip() if match else None


# =========================
# Markdown header split
# =========================
def _split_markdown_by_headers(
    content: str,
    max_chars: int = 2000,
    overlap: int = 150,
):
    header_re = re.compile(
        r'^(#{1,6})\s+(.+?)\s*$',
        re.MULTILINE
    )
    matches = list(header_re.finditer(content))
    if not matches:
        return _chunk_text(content, max_chars, overlap)

    # Split at each header, keeping the header line together with its body.
    sections = []
    for i, m in enumerate(matches):
        start = m.start()
        end = (
            matches[i + 1].start()
            if i + 1 < len(matches)
            else len(content)
        )
        block = content[start:end].strip()
        if block:
            sections.append(block)

    # Sections that are still too long fall back to plain character chunking.
    final_sections = []
    for s in sections:
        if len(s) > max_chars:
            final_sections.extend(
                _chunk_text(s, max_chars, overlap)
            )
        else:
            final_sections.append(s)
    return final_sections


def _chunk_text(
    text: str,
    max_chars: int = 2000,
    overlap: int = 150
):
    text = text.strip()
    if len(text) <= max_chars:
        return [text]

    chunks = []
    start = 0
    while start < len(text):
        end = min(len(text), start + max_chars)
        chunk = text[start:end].strip()
        if chunk:
            chunks.append(chunk)
        if end == len(text):
            break
        start = max(0, end - overlap)
    return chunks


def create_structured_retrieval_tool(workspace_dir):

    @tool("structured_retrieval", args_schema=StructuredRetrievalInput)
    def structured_retrieval(
        file_paths: List[str],
        query: str,
        source_url: Optional[str] = None
    ) -> Dict:
        """
        Batch structured extraction from markdown files.

        - Performs vector search + re-ranking
        - Saves the extracted structured data as a JSON file on disk
        - Returns ONLY a summary (status, count, file path)
        """
        # ── 1. Collect the content of all files ─────────────────────────
        all_docs_pool: List[Document] = []

        for path in file_paths:
            if not os.path.exists(path) or not path.endswith((".md", ".markdown")):
                continue

            file_name = os.path.basename(path)
            with open(path, "r", encoding="utf-8") as f:
                content = f.read()

            current_source = source_url or _extract_source_from_md(content) or "unknown"
            sections = _split_markdown_by_headers(content)

            for sec in sections:
                all_docs_pool.append(
                    Document(
                        page_content=sec,
                        metadata={"source_url": current_source, "file_name": file_name}
                    )
                )

        if not all_docs_pool:
            return {"status": "no_documents_found", "items_count": 0, "json_path": None}

        # ── 2. Vector search ─────────────────────────────────────────────
        vector_store = FAISS.from_documents(all_docs_pool, _EMBEDDING_MODEL)
        retrieved = vector_store.similarity_search(query, k=200)
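        # Note: k=200 deliberately over-retrieves; the FAISS/bi-encoder pass is
        # cheap but coarse, and the cross-encoder re-rank in step 4 narrows the
        # candidates down to the top 50 before anything is written to disk.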
        # ── 3. Extract structured fragments ─────────────────────────────
        structured_items = []
        for doc in retrieved:
            text = doc.page_content.strip()
            if len(text) < 30:
                continue

            # Collect any markdown image links referenced in the fragment.
            images = list(set(re.findall(r"!\[.*?\]\((.*?)\)", text)))

            structured_items.append(
                {
                    "text": text,
                    "images": images,
                    "source_url": doc.metadata.get("source_url"),
                    "file_name": doc.metadata.get("file_name")
                }
            )

        # ── 4. Re-rank ───────────────────────────────────────────────────
        if structured_items:
            # Deduplicate by text before scoring with the cross-encoder.
            unique_items = {item["text"]: item for item in structured_items}.values()
            pairs = [[query, item["text"]] for item in unique_items]
            scores = _RERANK_MODEL.predict(pairs)
            sorted_items = sorted(
                zip(scores, unique_items),
                key=lambda x: x[0],
                reverse=True
            )
            top_items = [item for _, item in sorted_items[:50]]
        else:
            top_items = []

        # ── 5. Write the JSON file ───────────────────────────────────────
        if not top_items:
            return {"status": "no_relevant_content", "items_count": 0, "json_path": None}

        # Build a meaningful file name from the query.
        safe_query = re.sub(r'[^a-zA-Z0-9\u4e00-\u9fa5]', '_', query)[:40]
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        json_filename = f"extracted_{safe_query}_{timestamp}.json"

        # Output directory (aligned with crawl4ai_batch); makedirs creates
        # intermediate directories and is a no-op when the path already exists.
        output_dir = os.path.join(workspace_dir, "extracted")
        os.makedirs(output_dir, exist_ok=True)

        json_path = os.path.join(output_dir, json_filename)
        with open(json_path, "w", encoding="utf-8") as f:
            json.dump(
                {
                    "query": query,
                    "extracted_at": timestamp,
                    "item_count": len(top_items),
                    "items": top_items
                },
                f,
                ensure_ascii=False,
                indent=2
            )

        # ── 6. Return only the summary ───────────────────────────────────
        return {
            "status": "success",
            "items_count": len(top_items),
            "json_path": json_path,
            "summary": f"Extracted {len(top_items)} highly relevant fragments, saved to {json_path}"
        }

    return structured_retrieval
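

# ---------------------------------------------------------------------------
# Minimal usage sketch, assuming the tool is exercised directly for a local
# smoke test; the workspace path, markdown file name, and query below are
# hypothetical placeholders. In normal use the tool is bound to a LangChain
# agent rather than invoked by hand.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    demo_tool = create_structured_retrieval_tool(workspace_dir="./workspace")

    # LangChain tools accept a dict matching args_schema (StructuredRetrievalInput)
    # via .invoke(); only the summary dict is returned, while the full extraction
    # is written to <workspace_dir>/extracted/*.json.
    result = demo_tool.invoke(
        {
            "file_paths": ["./workspace/pages/example_page.md"],
            "query": "product specifications and pricing",
        }
    )
    print(result)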