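Building Agents with LangGraph: Real-Time Cleaning of Multi-Source Heterogeneous Data and Knowledge Graph Construction
Preface
As big data and artificial intelligence converge, knowledge graphs have become a core piece of infrastructure for enterprises adopting intelligent systems. Building one is far from a one-step job, however: efficiently extracting, cleaning, transforming, and storing structured knowledge from multiple heterogeneous sources remains a major engineering challenge.
Traditional ETL (Extract-Transform-Load) pipelines tend to be static and serial, and they struggle with diverse formats, uneven data quality, and large volumes. As large language models (LLMs) and agent technology mature, a different approach becomes available: use LangGraph to build a system of cooperating agents that combines parallel processing with real-time cleaning and writes multi-source heterogeneous data into a Neo4j knowledge graph.
This article walks through the design rationale, architecture, and code for this approach, so that readers can apply the agent workflow in their own projects.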
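1. Technical Background and Core Concepts
1.1 What Is LangGraph?
LangGraph is a framework from the LangChain team designed for building stateful, multi-step, multi-agent workflows. Compared with a conventional LangChain Chain, LangGraph introduces the concept of a **graph**: the whole workflow is modeled as a directed graph (acyclic, or containing cycles), where each node represents a processing step or agent and each edge defines data flow and control logic.
Its core strengths include: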
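State management: StateGraph manages the global state in one place, and every node can read and write the shared state
Conditional routing: condition-based dynamic branching supports complex decision logic
Parallel execution: multiple nodes can run in parallel, which can substantially improve throughput
Cycles and feedback: loops in the workflow allow self-correction and iterative processing
Human-in-the-loop: the workflow can pause at key nodes and wait for human confirmation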
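To make the shared-state model concrete: a state field can be annotated with a reducer that decides how a node's returned value is combined with the existing value, while fields without a reducer are simply overwritten. The pure-Python sketch below imitates that merge rule without importing LangGraph. DemoState and apply_update are hypothetical names used only for illustration, and the framework's real internals differ.

```python
from operator import add
from typing import Annotated, Any, Dict, List, TypedDict, get_type_hints


class DemoState(TypedDict):
    phase: str                          # no reducer: last write wins
    errors: Annotated[List[str], add]   # reducer: new items are appended


def apply_update(schema: type, state: Dict[str, Any],
                 update: Dict[str, Any]) -> Dict[str, Any]:
    """Merge a node's partial update into the state, honoring reducers."""
    hints = get_type_hints(schema, include_extras=True)
    merged = dict(state)
    for key, value in update.items():
        # Annotated[...] types expose their extra arguments as __metadata__
        metadata = getattr(hints.get(key), "__metadata__", ())
        reducer = metadata[0] if metadata and callable(metadata[0]) else None
        if reducer is not None and key in merged:
            merged[key] = reducer(merged[key], value)
        else:
            merged[key] = value
    return merged


state = {"phase": "planning", "errors": ["e1"]}
# A node returns only the keys it changed.
state = apply_update(DemoState, state, {"phase": "cleaning", "errors": ["e2"]})
print(state)
```

One consequence of this rule is that a node which echoes the entire state back (for example via `{**state, ...}`) sends reducer fields through the reducer a second time and duplicates their contents, so returning only the changed keys is the safer habit.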
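1.2 Neo4j Knowledge Graph Storage
Neo4j is one of the most widely used graph databases. It uses the property graph model, storing knowledge as nodes and relationships, both of which can carry properties. Its Cypher query language is easy to read, and it offers the official Python driver neo4j as well as the higher-level Neo4jGraph wrapper in langchain-community, which makes it a good fit for integration with LangGraph.
1.3 The Challenge of Multi-Source Heterogeneous Data
In a typical enterprise setting, the data feeding a knowledge graph may come from unstructured documents (PDF, Word), tabular files (CSV, Excel), relational and document databases (MySQL, MongoDB), and web sources (Web APIs, HTML pages).
Faced with such a complex data landscape, we need a processing framework that is extensible, parallelizable, and adaptive.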
2. Overall Architecture Design
2.1 System Architecture Overview
The system can be divided into the following core layers:
┌─────────────────────────────────────────────────────────────┐
│                      Data source layer                      │
│    [PDF] [CSV/Excel] [MySQL] [MongoDB] [Web API] [HTML]     │
└──────────────────────────┬──────────────────────────────────┘
                           │
┌──────────────────────────▼──────────────────────────────────┐
│          LangGraph multi-agent orchestration layer          │
│  ┌─────────────┐  ┌─────────────┐  ┌─────────────────────┐  │
│  │   Planner   │  │ Dispatcher  │  │     Supervisor      │  │
│  │    agent    │  │    agent    │  │ (aggregation agent) │  │
│  └─────────────┘  └─────────────┘  └─────────────────────┘  │
│                                                             │
│  ┌───────────────────────────────────────────────────────┐  │
│  │              Parallel worker agent pool               │  │
│  │       [PDF parser] [CSV cleaner] [DB extractor]       │  │
│  │               [Text NER] [API fetcher]                │  │
│  └───────────────────────────────────────────────────────┘  │
└──────────────────────────┬──────────────────────────────────┘
                           │
┌──────────────────────────▼──────────────────────────────────┐
│            Data cleaning and normalization layer            │
│   [Deduplication] [Format unification] [Entity alignment]   │
│           [Relation validation] [Quality scoring]           │
└──────────────────────────┬──────────────────────────────────┘
                           │
┌──────────────────────────▼──────────────────────────────────┐
│                Knowledge write layer (Neo4j)                │
│ [Node creation] [Relationship creation] [Property updates]  │
│        [Index maintenance] [Transaction management]         │
└─────────────────────────────────────────────────────────────┘
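2.2 LangGraph Workflow Graph Design
The core LangGraph workflow graph contains the following nodes:
intake_node (data intake): receives all data source descriptions and initializes the state
planner_node (planning): analyzes the data source types and draws up a processing plan
dispatcher_node (dispatch): assigns tasks to the matching processing agents
Parallel processing node group: pdf_processor_node, csv_processor_node, db_extractor_node, text_ner_node, api_fetcher_node
cleaner_node (cleaning and aggregation): merges and deep-cleans all extraction results
kg_builder_node (knowledge graph construction): generates Cypher statements and writes them to Neo4j
validator_node (validation): verifies the write results and decides whether a retry is needed
end_node (end): produces the processing report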
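The validator's retry decision in the list above amounts to a routing function of the kind used for conditional edges. Here is a minimal sketch, assuming a hypothetical MAX_RETRIES limit and the write_results and retry_count state fields defined later in this article. The failure rule (retry while the latest write reports failures) is an illustrative assumption, not the article's validator.

```python
from typing import Any, Dict

MAX_RETRIES = 3  # hypothetical limit, not taken from the article


def route_after_validation(state: Dict[str, Any]) -> str:
    """Return the name of the node that should run after validator_node.

    Retries kg_builder_node while the latest write reported failures and
    the retry budget is not exhausted; otherwise finishes at end_node.
    """
    results = state.get("write_results") or []
    latest = results[-1] if results else {}
    failed = int(latest.get("total_failed", 0))
    retries = int(state.get("retry_count", 0))
    if failed > 0 and retries < MAX_RETRIES:
        return "kg_builder_node"
    return "end_node"
```

In LangGraph, a function like this would typically be attached to the validator node with `add_conditional_edges`, so that the returned string selects the next node. The validator itself would need to increment retry_count, otherwise the loop never terminates on persistent failures.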
3. Environment Setup and Dependencies
3.1 Installing the Required Dependencies
# Core frameworks
pip install langgraph langchain langchain-openai langchain-community
# Data processing (xlrd is what pandas uses to read legacy .xls files)
pip install pandas openpyxl xlrd pymupdf python-docx beautifulsoup4 requests
# Database drivers
pip install neo4j pymysql pymongo
# NLP tools
pip install spacy transformers
python -m spacy download zh_core_web_sm
# Other utilities
pip install pydantic python-dotenv tqdm
3.2 Environment Configuration
# config.py
import os

from dotenv import load_dotenv

load_dotenv()


class Config:
    # LLM settings
    OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
    OPENAI_MODEL = os.getenv("OPENAI_MODEL", "gpt-4o")
    # Neo4j settings
    NEO4J_URI = os.getenv("NEO4J_URI", "bolt://localhost:7687")
    NEO4J_USERNAME = os.getenv("NEO4J_USERNAME", "neo4j")
    NEO4J_PASSWORD = os.getenv("NEO4J_PASSWORD", "password")
    # Data source settings
    MYSQL_URI = os.getenv("MYSQL_URI")
    MONGO_URI = os.getenv("MONGO_URI")
    # Processing settings
    MAX_PARALLEL_WORKERS = int(os.getenv("MAX_PARALLEL_WORKERS", "5"))
    BATCH_SIZE = int(os.getenv("BATCH_SIZE", "100"))
    QUALITY_THRESHOLD = float(os.getenv("QUALITY_THRESHOLD", "0.75"))
4. Core Implementation
4.1 Defining the Global State Structure
The essence of LangGraph is shared state management. We start by defining the complete state types:
# state.py
from operator import add
from typing import Annotated, Any, Dict, List, Optional, TypedDict


class DataSourceConfig(TypedDict):
    """Configuration describing a single data source."""
    source_id: str
    source_type: str             # pdf, csv, mysql, mongodb, api, html
    source_path: str             # file path, database connection string, or URL
    schema_hint: Optional[str]   # hint about the data schema
    priority: int                # processing priority


class ExtractedTriple(TypedDict):
    """An extracted knowledge triple."""
    subject: str
    subject_type: str
    predicate: str
    object: str
    object_type: str
    confidence: float
    source_id: str
    raw_text: Optional[str]


class ProcessingTask(TypedDict):
    """A single processing task."""
    task_id: str
    source_config: DataSourceConfig
    assigned_agent: str          # agent type chosen by the dispatcher
    status: str                  # pending, processing, completed, failed
    result: Optional[List[ExtractedTriple]]
    error: Optional[str]


class KGBuildState(TypedDict):
    """Global LangGraph state."""
    # Input
    data_sources: List[DataSourceConfig]
    domain_context: str          # domain context of the knowledge graph
    # Planning phase
    processing_plan: Optional[Dict[str, Any]]
    tasks: Annotated[List[ProcessingTask], add]              # reducer: appended to
    # Extraction phase
    raw_extractions: Annotated[List[ExtractedTriple], add]   # reducer: appended to
    # Cleaning phase
    cleaned_triples: List[ExtractedTriple]
    rejected_triples: List[ExtractedTriple]
    quality_report: Optional[Dict[str, Any]]
    # Write phase
    cypher_statements: List[str]
    write_results: Annotated[List[Dict], add]
    # Control flow
    current_phase: str
    retry_count: int
    errors: Annotated[List[str], add]
    # Final report
    final_report: Optional[Dict[str, Any]]
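Before the graph runs, the caller has to supply an initial state in which every field exists, because later nodes read fields such as domain_context directly. A minimal sketch of such a constructor follows. make_initial_state is a hypothetical helper rather than part of the article's code, the file paths and the "planning" phase name are made up for the example, and plain dicts stand in for the TypedDicts above so that the snippet is self-contained.

```python
from typing import Any, Dict, List


def make_initial_state(data_sources: List[Dict[str, Any]],
                       domain_context: str) -> Dict[str, Any]:
    """Return an initial state dict with every field populated."""
    seen_ids = set()
    for source in data_sources:
        source_id = source["source_id"]
        if source_id in seen_ids:
            # source_id is used as a lookup key when tasks are dispatched
            raise ValueError(f"duplicate source_id: {source_id}")
        seen_ids.add(source_id)
    return {
        "data_sources": list(data_sources),
        "domain_context": domain_context,
        "processing_plan": None,
        "tasks": [],
        "raw_extractions": [],
        "cleaned_triples": [],
        "rejected_triples": [],
        "quality_report": None,
        "cypher_statements": [],
        "write_results": [],
        "current_phase": "planning",  # assumed name for the first phase
        "retry_count": 0,
        "errors": [],
        "final_report": None,
    }


# Hypothetical sources; the paths are placeholders.
sources = [
    {"source_id": "doc_001", "source_type": "pdf",
     "source_path": "data/report.pdf", "schema_hint": None, "priority": 1},
    {"source_id": "tbl_001", "source_type": "csv",
     "source_path": "data/companies.csv", "schema_hint": None, "priority": 2},
]
initial_state = make_initial_state(sources, "corporate ownership")
```

Rejecting duplicate IDs up front is a design choice of this sketch: two sources sharing an ID would silently overwrite each other in any ID-keyed lookup.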
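4.2 Planner Agent (Planner Node)
The planner agent analyzes the data sources and works out a processing strategy:
# nodes/planner.py
import json
from typing import Any, Dict

from langchain_core.output_parsers import JsonOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI

from config import Config
from state import KGBuildState

llm = ChatOpenAI(
    model=Config.OPENAI_MODEL,
    api_key=Config.OPENAI_API_KEY,
    temperature=0
)

PLANNER_PROMPT = ChatPromptTemplate.from_messages([
    ("system", """You are an expert planner for knowledge graph construction.
Your task is to analyze the given list of data sources and produce an efficient parallel processing plan.
Follow these principles:
1. Identify the type and processing complexity of each data source
2. Group tasks that can be processed in parallel
3. Identify dependencies between data sources
4. Assign a suitable agent type to each processing task
5. Estimate processing time and resource requirements
Available agent types:
- pdf_processor: handles PDF and Word documents
- csv_processor: handles CSV and Excel files
- db_extractor: extracts data from relational or document databases
- text_ner: runs named entity recognition on unstructured text
- api_fetcher: fetches data from Web APIs or HTML pages
The output must be valid JSON with the following fields:
{{
  "parallel_groups": [
    {{
      "group_id": "group_1",
      "tasks": ["source_id_1", "source_id_2"],
      "estimated_duration": 30,
      "agent_assignments": {{"source_id_1": "pdf_processor"}}
    }}
  ],
  "dependencies": {{}},
  "total_estimated_duration": 90,
  "priority_order": [],
  "special_instructions": {{}}
}}
"""),
    ("human", """
Domain context: {domain_context}
Data sources:
{data_sources}
Please produce the processing plan.
""")
])


def planner_node(state: KGBuildState) -> Dict[str, Any]:
    """Planner node: analyze the data sources and produce a processing plan."""
    print("🧠 [Planner] Analyzing data sources and drawing up a plan...")
    data_sources_str = json.dumps(state["data_sources"], ensure_ascii=False, indent=2)
    chain = PLANNER_PROMPT | llm | JsonOutputParser()
    try:
        plan = chain.invoke({
            "domain_context": state["domain_context"],
            "data_sources": data_sources_str
        })
        print(f"✅ [Planner] Plan ready with {len(plan.get('parallel_groups', []))} parallel group(s)")
        # Return only the keys this node changed. Echoing **state back would
        # push reducer fields (tasks, raw_extractions, errors, write_results)
        # through their `add` reducer again and duplicate their contents.
        return {
            "processing_plan": plan,
            "current_phase": "dispatching"
        }
    except Exception as e:
        error_msg = f"Planning phase failed: {str(e)}"
        print(f"❌ [Planner] {error_msg}")
        return {
            "errors": [error_msg],
            "current_phase": "error"
        }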
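Because the plan comes from an LLM, it can reference source IDs or agent types that do not exist. A small validation step between planning and dispatching catches that early. The sketch below is one possible check under stated assumptions: check_plan is a hypothetical helper, and KNOWN_AGENTS simply mirrors the agent types listed in the planner prompt.

```python
from typing import Any, Dict, List, Tuple

# Mirrors the agent types named in the planner prompt.
KNOWN_AGENTS = {"pdf_processor", "csv_processor", "db_extractor",
                "text_ner", "api_fetcher"}


def check_plan(plan: Dict[str, Any],
               source_ids: List[str]) -> Tuple[Dict[str, str], List[str]]:
    """Collect usable agent assignments and describe anything suspicious."""
    assignments: Dict[str, str] = {}
    problems: List[str] = []
    known_sources = set(source_ids)
    for group in plan.get("parallel_groups", []):
        for source_id, agent in group.get("agent_assignments", {}).items():
            if source_id not in known_sources:
                problems.append(f"unknown source_id in plan: {source_id}")
            elif agent not in KNOWN_AGENTS:
                problems.append(f"unknown agent '{agent}' for {source_id}")
            else:
                assignments[source_id] = agent
    # Sources the plan forgot (or assigned invalidly) need a fallback later.
    for source_id in source_ids:
        if source_id not in assignments:
            problems.append(f"no valid assignment for {source_id}")
    return assignments, problems


plan = {"parallel_groups": [{"agent_assignments": {
    "a": "pdf_processor", "b": "magic", "zzz": "csv_processor"}}]}
assignments, problems = check_plan(plan, ["a", "b"])
print(assignments)
print(problems)
```

Sources without a valid assignment can then fall back to a type-based default, which is what the dispatcher in section 4.4 does through its source_type mapping.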
4.3 Data Processing Agents
4.3.1 PDF/Document Processor
# nodes/processors/pdf_processor.py
import re
from typing import List

import fitz  # PyMuPDF
from docx import Document
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI

from config import Config
from state import ExtractedTriple, ProcessingTask

llm = ChatOpenAI(model=Config.OPENAI_MODEL, temperature=0)

EXTRACTION_PROMPT = ChatPromptTemplate.from_messages([
    ("system", """You are an expert in knowledge extraction.
Extract knowledge triples (Subject-Predicate-Object) from the given text.
Requirements:
1. Extract as many valid knowledge triples as possible
2. Subjects and objects should be concrete entities (people, organizations, concepts, places, and so on)
3. Predicates should be meaningful relations (such as "属于" (belongs to), "创立于" (founded in), "位于" (located in))
4. Assign each triple a confidence between 0 and 1
5. Label entity types (Person/Organization/Location/Concept/Event/Product, and so on)
Output a JSON array:
[
  {{
    "subject": "entity name",
    "subject_type": "entity type",
    "predicate": "relation",
    "object": "entity name",
    "object_type": "entity type",
    "confidence": 0.95,
    "raw_text": "source text fragment"
  }}
]
"""),
    ("human", "Domain context: {domain_context}\n\nText:\n{text}")
])


def extract_text_from_pdf(file_path: str) -> str:
    """Extract text from a PDF file."""
    # The context manager closes the document even if a page fails to parse
    with fitz.open(file_path) as doc:
        return "\n".join(page.get_text("text") for page in doc)


def extract_text_from_docx(file_path: str) -> str:
    """Extract text from a Word (.docx) document."""
    doc = Document(file_path)
    return "\n".join([para.text for para in doc.paragraphs if para.text.strip()])


def chunk_text(text: str, chunk_size: int = 2000, overlap: int = 200) -> List[str]:
    """Split long text into overlapping chunks."""
    chunks = []
    start = 0
    while start < len(text):
        end = min(start + chunk_size, len(text))
        # Try to cut at a sentence boundary
        if end < len(text):
            chunk = text[start:end]
            last_period = max(chunk.rfind('。'), chunk.rfind('.'), chunk.rfind('\n'))
            if last_period > chunk_size * 0.7:
                end = start + last_period + 1
        chunks.append(text[start:end])
        if end >= len(text):
            break  # otherwise the overlap alone would be emitted as a redundant last chunk
        start = end - overlap
    return chunks


def process_pdf_task(task: ProcessingTask, domain_context: str) -> ProcessingTask:
    """Process a single PDF/document task."""
    source_config = task["source_config"]
    source_path = source_config["source_path"]
    source_id = source_config["source_id"]
    print(f"📄 [PDF processor] Processing: {source_path}")
    try:
        # Extract text
        lower_path = source_path.lower()
        if lower_path.endswith('.pdf'):
            text = extract_text_from_pdf(source_path)
        elif lower_path.endswith('.docx'):
            text = extract_text_from_docx(source_path)
        elif lower_path.endswith('.doc'):
            # python-docx only reads .docx; legacy .doc files must be converted first
            raise ValueError("legacy .doc files are not supported, convert to .docx")
        else:
            with open(source_path, 'r', encoding='utf-8') as f:
                text = f.read()
        # Text cleaning: keep newlines and CJK sentence punctuation, because
        # chunk_text relies on them to find sentence boundaries
        text = re.sub(r'\s*\n\s*', '\n', text)
        text = re.sub(r'[^\S\n]+', ' ', text)
        text = re.sub(r'[^\u4e00-\u9fff\w\s.,;:!?。,;:!?、()()【】\-—]', '', text)
        # Chunked processing
        chunks = chunk_text(text)
        all_triples = []
        chain = EXTRACTION_PROMPT | llm | JsonOutputParser()
        for i, chunk in enumerate(chunks):
            print(f"  📝 Processing chunk {i+1}/{len(chunks)}")
            try:
                triples = chain.invoke({
                    "domain_context": domain_context,
                    "text": chunk
                })
                for triple in triples:
                    triple["source_id"] = source_id
                    all_triples.append(triple)
            except Exception as e:
                print(f"  ⚠️ Chunk {i+1} failed: {e}")
        print(f"✅ [PDF processor] {source_path} done, {len(all_triples)} triples extracted")
        return {
            **task,
            "status": "completed",
            "result": all_triples
        }
    except Exception as e:
        error_msg = f"PDF processing failed ({source_path}): {str(e)}"
        print(f"❌ [PDF processor] {error_msg}")
        return {
            **task,
            "status": "failed",
            "error": error_msg,
            "result": []
        }
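The chunking step is easy to get subtly wrong, so it helps to check its properties on a small input. The following standalone sketch restates an overlapping, sentence-aware chunker in the spirit of chunk_text and runs it on a short repeated sentence. The explicit stop once the end of the text is reached, the parameter validation, and the tiny chunk_size are choices made for this illustration.

```python
from typing import List


def chunk_text(text: str, chunk_size: int = 2000, overlap: int = 200) -> List[str]:
    """Split text into chunks that overlap by `overlap` characters."""
    if not 0 <= overlap < chunk_size:
        raise ValueError("overlap must be in [0, chunk_size)")
    chunks = []
    start = 0
    while start < len(text):
        end = min(start + chunk_size, len(text))
        if end < len(text):
            window = text[start:end]
            # Prefer cutting right after the last sentence end in the window
            cut = max(window.rfind("。"), window.rfind("."), window.rfind("\n"))
            if cut > chunk_size * 0.7:
                end = start + cut + 1
        chunks.append(text[start:end])
        if end >= len(text):
            break  # done; do not emit the overlap as an extra chunk
        start = end - overlap
    return chunks


sample = "知识图谱由实体和关系组成。" * 10  # ten 13-character sentences
chunks = chunk_text(sample, chunk_size=50, overlap=10)
for c in chunks:
    print(len(c), c[-6:])

# Dropping each later chunk's overlapping prefix reconstructs the input.
rebuilt = chunks[0] + "".join(c[10:] for c in chunks[1:])
print(rebuilt == sample)
```

The reconstruction check at the end is a cheap way to confirm that no text is lost between chunks, whatever sizes are chosen.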
4.3.2 CSV/Excel Processor
# nodes/processors/csv_processor.py
import json
from typing import Dict, List

import pandas as pd
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI

from config import Config
from state import ExtractedTriple, ProcessingTask

llm = ChatOpenAI(model=Config.OPENAI_MODEL, temperature=0)

SCHEMA_ANALYSIS_PROMPT = ChatPromptTemplate.from_messages([
    ("system", """You are an expert in data structure analysis. Analyze the structure of the given CSV/Excel data,
identify which columns can serve as entities or attributes in a knowledge graph, and which relations may exist between columns.
Output JSON:
{{
  "entity_columns": ["col1", "col2"],
  "relation_patterns": [
    {{"subject_col": "col1", "predicate": "relation name", "object_col": "col2"}}
  ],
  "attribute_patterns": [
    {{"entity_col": "col1", "attribute_col": "col3", "attribute_name": "attribute name"}}
  ]
}}
"""),
    ("human", "Domain context: {domain_context}\n\nData sample (first 5 rows):\n{sample_data}\n\nColumns: {columns}")
])


def clean_dataframe(df: pd.DataFrame) -> pd.DataFrame:
    """Clean a data frame."""
    # Drop fully duplicated rows; copy so the column assignments below are safe
    df = df.drop_duplicates().copy()
    # Handle missing values
    for col in df.columns:
        if df[col].dtype == object:
            df[col] = df[col].fillna('').astype(str).str.strip()
            # Remove obvious noise values
            df[col] = df[col].replace(['N/A', 'null', 'NULL', 'None', 'nan', '#N/A'], '')
        elif pd.api.types.is_numeric_dtype(df[col]):
            df[col] = df[col].fillna(0)
    # Drop rows in which every cell is empty
    df = df[~(df == '').all(axis=1)]
    return df


def read_csv_with_fallback(path: str) -> pd.DataFrame:
    """Try several encodings; utf-8-sig also reads plain UTF-8 and strips a BOM."""
    for encoding in ('utf-8-sig', 'gbk', 'gb18030'):
        try:
            return pd.read_csv(path, encoding=encoding)
        except UnicodeDecodeError:
            continue
    raise ValueError(f"could not decode {path} with any supported encoding")


def process_csv_task(task: ProcessingTask, domain_context: str) -> ProcessingTask:
    """Process a CSV/Excel file."""
    source_config = task["source_config"]
    source_path = source_config["source_path"]
    source_id = source_config["source_id"]
    print(f"📊 [CSV processor] Processing: {source_path}")
    try:
        # Read the file (.xlsx needs openpyxl, legacy .xls needs xlrd)
        if source_path.lower().endswith(('.xlsx', '.xls')):
            df = pd.read_excel(source_path)
        else:
            df = read_csv_with_fallback(source_path)
        # Data cleaning
        df = clean_dataframe(df)
        print(f"  📋 Shape: {df.shape[0]} rows × {df.shape[1]} columns")
        # Analyze the schema
        sample_data = df.head(5).to_string()
        columns = list(df.columns)
        chain = SCHEMA_ANALYSIS_PROMPT | llm | JsonOutputParser()
        schema_analysis = chain.invoke({
            "domain_context": domain_context,
            "sample_data": sample_data,
            "columns": json.dumps(columns, ensure_ascii=False)
        })
        # Extract triples according to the analysis
        all_triples = []
        # Relation patterns
        for pattern in schema_analysis.get("relation_patterns", []):
            subj_col = pattern.get("subject_col")
            pred = pattern.get("predicate")
            obj_col = pattern.get("object_col")
            if subj_col in df.columns and obj_col in df.columns:
                for _, row in df.iterrows():
                    subj_val = str(row[subj_col]).strip()
                    obj_val = str(row[obj_col]).strip()
                    if subj_val and obj_val and subj_val != obj_val:
                        all_triples.append({
                            "subject": subj_val,
                            "subject_type": "Entity",
                            "predicate": pred,
                            "object": obj_val,
                            "object_type": "Entity",
                            "confidence": 0.9,
                            "source_id": source_id,
                            "raw_text": f"{subj_col}:{subj_val} -> {pred} -> {obj_col}:{obj_val}"
                        })
        # Attribute patterns
        for pattern in schema_analysis.get("attribute_patterns", []):
            entity_col = pattern.get("entity_col")
            attr_col = pattern.get("attribute_col")
            attr_name = pattern.get("attribute_name", attr_col)
            if entity_col in df.columns and attr_col in df.columns:
                for _, row in df.iterrows():
                    entity_val = str(row[entity_col]).strip()
                    attr_val = str(row[attr_col]).strip()
                    if entity_val and attr_val:
                        all_triples.append({
                            "subject": entity_val,
                            "subject_type": "Entity",
                            # The "具有" prefix is mapped to HAS_PROPERTY by the Neo4j writer
                            "predicate": f"具有{attr_name}",
                            "object": attr_val,
                            "object_type": "Attribute",
                            "confidence": 0.85,
                            "source_id": source_id,
                            "raw_text": f"{entity_col}:{entity_val} has {attr_name} = {attr_val}"
                        })
        print(f"✅ [CSV processor] Done, {len(all_triples)} triples extracted")
        return {
            **task,
            "status": "completed",
            "result": all_triples
        }
    except Exception as e:
        error_msg = f"CSV processing failed ({source_path}): {str(e)}"
        print(f"❌ [CSV processor] {error_msg}")
        return {
            **task,
            "status": "failed",
            "error": error_msg,
            "result": []
        }
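The row-to-triple mapping at the heart of this processor does not depend on pandas or on the LLM. The sketch below shows the same idea with only the standard library, using a hand-written relation pattern in place of the LLM's schema analysis. rows_to_triples, the sample data, and the WORKS_FOR predicate are all made up for the illustration.

```python
import csv
import io
from typing import Dict, List


def rows_to_triples(csv_text: str, relation_patterns: List[Dict[str, str]],
                    source_id: str) -> List[Dict[str, object]]:
    """Apply subject/predicate/object column patterns to every CSV row."""
    triples = []
    for row in csv.DictReader(io.StringIO(csv_text)):
        for pattern in relation_patterns:
            subj = (row.get(pattern["subject_col"]) or "").strip()
            obj = (row.get(pattern["object_col"]) or "").strip()
            # Skip rows with a missing side or a self-relation
            if subj and obj and subj != obj:
                triples.append({
                    "subject": subj,
                    "predicate": pattern["predicate"],
                    "object": obj,
                    "confidence": 0.9,
                    "source_id": source_id,
                })
    return triples


sample_csv = "employee,company,city\nAlice,Acme,Berlin\nBob,,Paris\n"
patterns = [{"subject_col": "employee", "predicate": "WORKS_FOR",
             "object_col": "company"}]
triples = rows_to_triples(sample_csv, patterns, "tbl_001")
print(triples)
```

Bob's row yields no triple because the company cell is empty, which matches the empty-value check in the processor above.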
4.3.3 Database Extractor
# nodes/processors/db_extractor.py
import json
from typing import Any, Dict, List
from urllib.parse import unquote, urlparse

import pymongo
import pymysql
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI

from config import Config
from state import ProcessingTask

llm = ChatOpenAI(model=Config.OPENAI_MODEL, temperature=0)

SQL_GENERATION_PROMPT = ChatPromptTemplate.from_messages([
    ("system", """You are a database query expert. Based on the database schema, generate SQL queries for knowledge graph construction.
The queries should extract entity and relation information. Output JSON:
{{
  "queries": [
    {{
      "description": "query description",
      "sql": "SELECT ...",
      "subject_col": "subject column name",
      "predicate": "relation name",
      "object_col": "object column name"
    }}
  ]
}}
"""),
    ("human", "Domain context: {domain_context}\n\nDatabase schema:\n{schema}\n\nGoal: extract knowledge triples")
])


def parse_mysql_url(connection_str: str) -> Dict[str, Any]:
    """Parse mysql://user:password@host:port/database into connect() arguments."""
    # urlparse copes with hyphenated hosts and percent-encoded passwords,
    # which a \w+ based regular expression would reject
    parsed = urlparse(connection_str)
    database = parsed.path.strip("/")
    if parsed.scheme != "mysql" or not parsed.hostname or not database:
        # Do not echo the string itself: it contains the password
        raise ValueError("Invalid MySQL connection string")
    return {
        "host": parsed.hostname,
        "port": parsed.port or 3306,
        "user": unquote(parsed.username or ""),
        "password": unquote(parsed.password or ""),
        "database": database,
        "charset": "utf8mb4",
    }


def get_mysql_schema(connection_str: str) -> str:
    """Describe the structure of a MySQL database."""
    conn = pymysql.connect(**parse_mysql_url(connection_str))
    schema_info = []
    try:
        with conn.cursor() as cursor:
            cursor.execute("SHOW TABLES")
            tables = [row[0] for row in cursor.fetchall()]
            for table in tables:
                # Backticks keep table names with special characters valid
                cursor.execute(f"DESCRIBE `{table}`")
                columns = cursor.fetchall()
                col_info = [f"{col[0]} ({col[1]})" for col in columns]
                schema_info.append(f"Table {table}: {', '.join(col_info)}")
    finally:
        conn.close()
    return "\n".join(schema_info)


def process_db_task(task: ProcessingTask, domain_context: str) -> ProcessingTask:
    """Extract data from a database."""
    source_config = task["source_config"]
    source_path = source_config["source_path"]  # database connection string
    source_id = source_config["source_id"]
    print(f"🗄️ [DB extractor] Processing database: {source_id}")
    try:
        all_triples = []
        if source_path.startswith("mysql://"):
            schema = get_mysql_schema(source_path)
            # Generate queries
            chain = SQL_GENERATION_PROMPT | llm | JsonOutputParser()
            query_plan = chain.invoke({
                "domain_context": domain_context,
                "schema": schema
            })
            # Run the queries and extract triples
            conn = pymysql.connect(
                **parse_mysql_url(source_path),
                cursorclass=pymysql.cursors.DictCursor
            )
            try:
                for query_info in query_plan.get("queries", []):
                    sql = query_info["sql"].strip()
                    # The SQL is LLM-generated: run nothing but SELECT statements.
                    # A read-only database account is still the safer defense.
                    if not sql.lower().startswith("select"):
                        print(f"  ⚠️ Skipping non-SELECT statement: {sql[:60]}")
                        continue
                    with conn.cursor() as cursor:
                        cursor.execute(sql)
                        rows = cursor.fetchall()
                    for row in rows:
                        subj = str(row.get(query_info["subject_col"], "")).strip()
                        obj = str(row.get(query_info["object_col"], "")).strip()
                        if subj and obj:
                            all_triples.append({
                                "subject": subj,
                                "subject_type": "Entity",
                                "predicate": query_info["predicate"],
                                "object": obj,
                                "object_type": "Entity",
                                "confidence": 0.95,
                                "source_id": source_id,
                                # default=str handles dates and decimals in the row
                                "raw_text": json.dumps(dict(row), ensure_ascii=False, default=str)
                            })
            finally:
                conn.close()
        elif source_path.startswith("mongodb://"):
            client = pymongo.MongoClient(source_path)
            try:
                # schema_hint carries the database name; .get() would return
                # None (not the default) when the key exists with a None value
                db_name = source_config.get("schema_hint") or "knowledge_db"
                db = client[db_name]
                for collection_name in db.list_collection_names():
                    collection = db[collection_name]
                    # Sample the collection
                    sample_docs = list(collection.find().limit(Config.BATCH_SIZE))
                    for doc in sample_docs:
                        # Turn each string field of the document into a triple
                        doc_id = str(doc.get("_id", ""))
                        entity_name = doc.get("name") or doc.get("title") or doc_id
                        for key, value in doc.items():
                            if key not in ["_id"] and isinstance(value, str) and value:
                                all_triples.append({
                                    "subject": str(entity_name),
                                    "subject_type": collection_name,
                                    "predicate": f"具有{key}",
                                    "object": str(value),
                                    "object_type": "Attribute",
                                    "confidence": 0.88,
                                    "source_id": source_id,
                                    "raw_text": f"{collection_name}.{entity_name}.{key}={value}"
                                })
            finally:
                client.close()
        print(f"✅ [DB extractor] {source_id} done, {len(all_triples)} triples extracted")
        return {
            **task,
            "status": "completed",
            "result": all_triples
        }
    except Exception as e:
        error_msg = f"Database extraction failed ({source_id}): {str(e)}"
        print(f"❌ [DB extractor] {error_msg}")
        return {
            **task,
            "status": "failed",
            "error": error_msg,
            "result": []
        }
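Two details of the database extractor deserve a closer look: parsing the connection string robustly, and refusing to execute anything but read-only queries when the SQL is generated by an LLM. The sketch below illustrates both with the standard library only. parse_mysql_url and is_read_only_select are hypothetical helpers, the keyword list is an illustrative assumption rather than a complete SQL safety check, and a read-only database account remains the more reliable safeguard.

```python
import re
from typing import Any, Dict
from urllib.parse import unquote, urlparse

# Illustrative, not exhaustive.
_FORBIDDEN = re.compile(
    r"\b(insert|update|delete|drop|alter|truncate|create|grant|replace)\b",
    re.IGNORECASE,
)


def parse_mysql_url(url: str) -> Dict[str, Any]:
    """Split mysql://user:password@host:port/database into its parts."""
    parsed = urlparse(url)
    database = parsed.path.strip("/")
    if parsed.scheme != "mysql" or not parsed.hostname or not database:
        raise ValueError("invalid MySQL connection string")
    return {
        "host": parsed.hostname,
        "port": parsed.port or 3306,
        "user": unquote(parsed.username or ""),
        "password": unquote(parsed.password or ""),  # decodes %40 to @ etc.
        "database": database,
    }


def is_read_only_select(sql: str) -> bool:
    """Accept a single SELECT statement without write/DDL keywords."""
    statement = sql.strip().rstrip(";").strip()
    if ";" in statement:                      # more than one statement
        return False
    if not statement.lower().startswith("select"):
        return False
    return _FORBIDDEN.search(statement) is None


print(parse_mysql_url("mysql://kg_user:p%40ss@db-host.internal:3306/kg"))
print(is_read_only_select("SELECT name, update_time FROM company;"))
print(is_read_only_select("SELECT 1; DROP TABLE company"))
```

The word-boundary match means a column such as update_time is not mistaken for an UPDATE statement.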
4.4 Parallel Dispatch Node
This is one of the most important nodes in the architecture. It schedules all processing tasks in parallel:
# nodes/dispatcher.py
import uuid
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any, Callable, Dict, List

from config import Config
from nodes.processors.api_fetcher import process_api_task
from nodes.processors.csv_processor import process_csv_task
from nodes.processors.db_extractor import process_db_task
from nodes.processors.pdf_processor import process_pdf_task
from state import KGBuildState, ProcessingTask

# Mapping from agent type to processing function
PROCESSOR_MAP: Dict[str, Callable] = {
    "pdf_processor": process_pdf_task,
    "csv_processor": process_csv_task,
    "db_extractor": process_db_task,
    "api_fetcher": process_api_task,
    "text_ner": process_pdf_task,  # reuses the text processing logic
}


def dispatcher_node(state: KGBuildState) -> Dict[str, Any]:
    """
    Dispatch node: run all data processing tasks in parallel according to the plan.
    """
    print("🚀 [Dispatcher] Dispatching processing tasks in parallel...")
    # processing_plan may be None (for example if planning failed)
    plan = state.get("processing_plan") or {}
    data_sources = state["data_sources"]
    domain_context = state["domain_context"]
    # Build the agent assignment mapping
    agent_assignments = {}
    for group in plan.get("parallel_groups", []):
        agent_assignments.update(group.get("agent_assignments", {}))
    # Create the task list
    tasks: List[ProcessingTask] = []
    for source in data_sources:
        source_id = source["source_id"]
        agent_type = agent_assignments.get(source_id,
                                           _infer_agent_type(source["source_type"]))
        tasks.append({
            "task_id": str(uuid.uuid4()),
            "source_config": source,
            "assigned_agent": agent_type,
            "status": "pending",
            "result": None,
            "error": None
        })
    print(f"  📋 {len(tasks)} task(s) to process")
    # Run all tasks in parallel
    completed_tasks = []
    all_raw_extractions = []
    all_errors = []
    with ThreadPoolExecutor(max_workers=Config.MAX_PARALLEL_WORKERS) as executor:
        # Submit all tasks
        future_to_task = {
            executor.submit(_execute_task, task, domain_context): task
            for task in tasks
        }
        # Collect results. as_completed only yields finished futures, so a
        # timeout passed to result() here could never fire; time limits have
        # to be enforced inside the processors (for example LLM or database
        # client timeouts).
        for future in as_completed(future_to_task):
            original_task = future_to_task[future]
            source_id = original_task["source_config"]["source_id"]
            try:
                completed_task = future.result()
            except Exception as e:
                error_msg = f"Task {source_id} crashed: {str(e)}"
                all_errors.append(error_msg)
                completed_tasks.append({**original_task, "status": "failed",
                                        "error": error_msg, "result": []})
                print(f"  💥 {error_msg}")
                continue
            completed_tasks.append(completed_task)
            if completed_task["status"] == "completed":
                result = completed_task["result"] or []
                all_raw_extractions.extend(result)
                print(f"  ✅ Task {source_id} completed, {len(result)} triples extracted")
            else:
                # The "error" key always exists, so .get(key, default) would
                # return None here; use `or` for the fallback
                error_msg = completed_task.get("error") or "unknown error"
                all_errors.append(error_msg)
                print(f"  ❌ Task {source_id} failed: {error_msg}")
    print(f"✅ [Dispatcher] Parallel processing finished. "
          f"Succeeded: {sum(1 for t in completed_tasks if t['status']=='completed')}/"
          f"{len(tasks)}, "
          f"triples extracted: {len(all_raw_extractions)}")
    # Return only the changed keys; tasks, raw_extractions and errors have an
    # `add` reducer, so these lists are appended to the existing state
    return {
        "tasks": completed_tasks,
        "raw_extractions": all_raw_extractions,
        "errors": all_errors,
        "current_phase": "cleaning"
    }


def _execute_task(task: ProcessingTask, domain_context: str) -> ProcessingTask:
    """Run a single task."""
    agent_type = task.get("assigned_agent", "pdf_processor")
    processor_func = PROCESSOR_MAP.get(agent_type, process_pdf_task)
    updated_task = {**task, "status": "processing"}
    return processor_func(updated_task, domain_context)


def _infer_agent_type(source_type: str) -> str:
    """Infer the agent type from the data source type."""
    type_map = {
        "pdf": "pdf_processor",
        "word": "pdf_processor",
        "txt": "text_ner",
        "csv": "csv_processor",
        "excel": "csv_processor",
        "mysql": "db_extractor",
        "mongodb": "db_extractor",
        "api": "api_fetcher",
        "html": "api_fetcher",
    }
    return type_map.get(source_type.lower(), "text_ner")
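If an overall time budget for the parallel stage is needed, one option is to wait on all futures with a deadline and report whatever has not finished as timed out. The sketch below shows that pattern with the standard library. run_with_deadline is a hypothetical helper, and Python threads cannot be forcibly stopped, so a job that is already running keeps running in the background. The deadline only bounds how long the caller waits.

```python
import time
from concurrent.futures import ThreadPoolExecutor, wait
from typing import Any, Callable, Dict, Tuple


def run_with_deadline(jobs: Dict[str, Callable[[], Any]], timeout_s: float,
                      max_workers: int = 4) -> Tuple[Dict[str, Any], Dict[str, str]]:
    """Run zero-argument jobs in parallel and wait at most timeout_s overall."""
    executor = ThreadPoolExecutor(max_workers=max_workers)
    futures = {executor.submit(fn): name for name, fn in jobs.items()}
    done, not_done = wait(futures, timeout=timeout_s)
    results: Dict[str, Any] = {}
    errors: Dict[str, str] = {}
    for future in done:
        name = futures[future]
        try:
            results[name] = future.result()
        except Exception as exc:
            errors[name] = f"failed: {exc}"
    for future in not_done:
        future.cancel()  # only prevents jobs that have not started yet
        errors[futures[future]] = "timed out"
    # Do not block on stragglers; queued jobs are dropped (Python 3.9+)
    executor.shutdown(wait=False, cancel_futures=True)
    return results, errors


def _bad_job() -> None:
    raise RuntimeError("boom")


results, errors = run_with_deadline(
    {"fast": lambda: 1, "slow": lambda: time.sleep(1.0), "bad": _bad_job},
    timeout_s=0.3,
)
print(results, errors)
```

For a hard per-task limit, client-side timeouts on the LLM and database calls inside each processor are usually the more dependable tool.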
4.5 Data Cleaning Node (the core of real-time cleaning)
# nodes/cleaner.py
import hashlib
import re
from collections import defaultdict
from typing import Any, Dict, List, Set, Tuple

from langchain_core.output_parsers import JsonOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI

from config import Config
from state import ExtractedTriple, KGBuildState

llm = ChatOpenAI(model=Config.OPENAI_MODEL, temperature=0)

ENTITY_NORMALIZATION_PROMPT = ChatPromptTemplate.from_messages([
    ("system", """You are an expert in entity normalization. Given a set of similar entity names, decide whether they refer to the same entity
and provide a canonical name.
Output JSON:
{{
  "clusters": [
    {{
      "canonical": "canonical name",
      "variants": ["variant 1", "variant 2"],
      "entity_type": "entity type"
    }}
  ]
}}
"""),
    ("human", "Entities:\n{entities}\n\nDomain: {domain}")
])


def compute_triple_hash(triple: ExtractedTriple) -> str:
    """Hash a triple for deduplication."""
    key = f"{triple['subject']}|{triple['predicate']}|{triple['object']}"
    return hashlib.md5(key.encode('utf-8')).hexdigest()


def normalize_text(text: str) -> str:
    """Normalize text."""
    # Collapse redundant whitespace
    text = re.sub(r'\s+', ' ', text.strip())
    # Convert full-width punctuation to half-width
    text = text.replace('(', '(').replace(')', ')')
    text = text.replace(',', ',').replace('。', '.')
    # Strip leading and trailing punctuation
    text = text.strip('.,;:!?,。;:!?')
    return text


def filter_noise_triples(triples: List[ExtractedTriple]) -> Tuple[List, List]:
    """Filter out noisy triples."""
    valid = []
    rejected = []
    for triple in triples:
        reasons = []
        # Empty fields
        if len(triple.get("subject", "")) < 1:
            reasons.append("empty subject")
        if len(triple.get("object", "")) < 1:
            reasons.append("empty object")
        if len(triple.get("predicate", "")) < 1:
            reasons.append("empty predicate")
        # Overlong values are probably noise text
        if len(triple.get("subject", "")) > 100:
            reasons.append("subject too long")
        if len(triple.get("object", "")) > 200:
            reasons.append("object too long")
        # Confidence
        if triple.get("confidence", 0) < Config.QUALITY_THRESHOLD:
            reasons.append(f"confidence too low ({triple.get('confidence', 0):.2f})")
        # Identical subject and object
        if triple.get("subject") == triple.get("object"):
            reasons.append("subject equals object")
        # Garbled content
        subj = triple.get("subject", "")
        if re.search(r'[\x00-\x08\x0b-\x0c\x0e-\x1f]', subj):
            reasons.append("contains control characters")
        if reasons:
            # Copy instead of mutating the caller's triple
            rejected.append({**triple, "reject_reasons": reasons})
        else:
            valid.append(triple)
    return valid, rejected


def deduplicate_triples(triples: List[ExtractedTriple]) -> List[ExtractedTriple]:
    """Deduplicate triples."""
    seen_hashes: Set[str] = set()
    unique_triples = []
    # Sort by confidence, descending, so the highest-confidence copy is kept
    sorted_triples = sorted(triples, key=lambda x: x.get("confidence", 0), reverse=True)
    for triple in sorted_triples:
        h = compute_triple_hash(triple)
        if h not in seen_hashes:
            seen_hashes.add(h)
            unique_triples.append(triple)
    return unique_triples


def normalize_entities_batch(triples: List[ExtractedTriple], domain: str) -> List[ExtractedTriple]:
    """Batch entity normalization: use the LLM to identify synonymous entities."""
    # Collect all distinct entities. The prompt shows "name (type)", so keep a
    # mapping back to the bare name: the LLM may echo either form, and the
    # triples themselves only contain bare names.
    labelled_to_bare: Dict[str, str] = {}
    for triple in triples:
        labelled_to_bare[f"{triple['subject']} ({triple['subject_type']})"] = triple["subject"]
        labelled_to_bare[f"{triple['object']} ({triple['object_type']})"] = triple["object"]
    entity_list = list(labelled_to_bare)
    # Process in batches to keep the prompt short
    batch_size = 50
    entity_map = {}  # variant -> canonical name
    for i in range(0, len(entity_list), batch_size):
        batch = entity_list[i:i+batch_size]
        try:
            chain = ENTITY_NORMALIZATION_PROMPT | llm | JsonOutputParser()
            result = chain.invoke({
                "entities": "\n".join(batch),
                "domain": domain
            })
            for cluster in result.get("clusters", []):
                canonical = labelled_to_bare.get(cluster["canonical"], cluster["canonical"])
                for variant in cluster.get("variants", []):
                    entity_map[labelled_to_bare.get(variant, variant)] = canonical
                entity_map[canonical] = canonical
        except Exception as e:
            print(f"  ⚠️ Entity normalization batch {i//batch_size + 1} failed: {e}")
    # Apply the normalization mapping
    normalized_triples = []
    for triple in triples:
        normalized = {**triple}
        normalized["subject"] = entity_map.get(triple["subject"], triple["subject"])
        normalized["object"] = entity_map.get(triple["object"], triple["object"])
        normalized_triples.append(normalized)
    return normalized_triples


def cleaner_node(state: KGBuildState) -> Dict[str, Any]:
    """Cleaning node: deep-clean all extracted triples."""
    print("🧹 [Cleaner] Starting data cleaning...")
    domain_context = state["domain_context"]
    print(f"  📊 Raw triples: {len(state['raw_extractions'])}")
    # Step 1: text normalization (on copies, so the state is not mutated in place)
    print("  🔤 Step 1: text normalization...")
    raw_triples = [
        {
            **triple,
            "subject": normalize_text(str(triple.get("subject", ""))),
            "object": normalize_text(str(triple.get("object", ""))),
            "predicate": normalize_text(str(triple.get("predicate", ""))),
        }
        for triple in state["raw_extractions"]
    ]
    # Step 2: noise filtering
    print("  🚿 Step 2: filtering noisy triples...")
    valid_triples, rejected_triples = filter_noise_triples(raw_triples)
    print(f"    Valid: {len(valid_triples)}, rejected: {len(rejected_triples)}")
    # Step 3: deduplication
    print("  🔍 Step 3: deduplicating triples...")
    unique_triples = deduplicate_triples(valid_triples)
    print(f"    After dedup: {len(unique_triples)} ({len(valid_triples)-len(unique_triples)} duplicates removed)")
    # Step 4: entity normalization (LLM-assisted)
    print("  🏷️ Step 4: LLM entity normalization...")
    normalized_triples = normalize_entities_batch(unique_triples, domain_context)
    # Step 5: deduplicate again (normalization can create new duplicates)
    final_triples = deduplicate_triples(normalized_triples)
    print(f"  ✨ Triples after cleaning: {len(final_triples)}")
    # Build the quality report
    quality_report = {
        "raw_count": len(raw_triples),
        "filtered_count": len(valid_triples),
        "unique_count": len(unique_triples),
        "final_count": len(final_triples),
        "rejection_rate": f"{len(rejected_triples)/max(len(raw_triples),1)*100:.1f}%",
        "dedup_rate": f"{(len(valid_triples)-len(unique_triples))/max(len(valid_triples),1)*100:.1f}%",
        "avg_confidence": sum(t.get("confidence",0) for t in final_triples) / max(len(final_triples),1),
        "source_distribution": _count_by_source(final_triples),
        "type_distribution": _count_by_type(final_triples)
    }
    print(f"✅ [Cleaner] Cleaning finished. Quality report: {quality_report}")
    # Return only the changed keys. Echoing **state back would re-append
    # tasks, raw_extractions and errors through their `add` reducers.
    return {
        "cleaned_triples": final_triples,
        "rejected_triples": rejected_triples,
        "quality_report": quality_report,
        "current_phase": "building"
    }


def _count_by_source(triples: List[ExtractedTriple]) -> Dict:
    counts = defaultdict(int)
    for t in triples:
        counts[t.get("source_id", "unknown")] += 1
    return dict(counts)


def _count_by_type(triples: List[ExtractedTriple]) -> Dict:
    counts = defaultdict(int)
    for t in triples:
        counts[t.get("subject_type", "Unknown")] += 1
    return dict(counts)
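Much of the cheap deduplication can be done before any LLM call by comparing triples on a canonical key. The sketch below uses Unicode NFKC normalization, which folds full-width letters and punctuation into their half-width forms, plus case folding, and keeps the highest-confidence copy of each triple. canonical_key and dedupe_keep_best are hypothetical helpers, and this swaps in NFKC as a general technique rather than reproducing the character-by-character replacements in normalize_text.

```python
import unicodedata
from typing import Any, Dict, List, Tuple


def canonical_key(triple: Dict[str, Any]) -> Tuple[str, ...]:
    """Key that ignores full-width/half-width and case differences."""
    parts = (triple["subject"], triple["predicate"], triple["object"])
    return tuple(
        unicodedata.normalize("NFKC", part).strip().casefold() for part in parts
    )


def dedupe_keep_best(triples: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Keep one triple per canonical key: the one with the highest confidence."""
    best: Dict[Tuple[str, ...], Dict[str, Any]] = {}
    for triple in triples:
        key = canonical_key(triple)
        current = best.get(key)
        if current is None or triple.get("confidence", 0) > current.get("confidence", 0):
            best[key] = triple
    return list(best.values())


triples = [
    # Full-width "ABC" versus half-width lower-case "abc"
    {"subject": "张三", "predicate": "WORKS_FOR", "object": "ABC公司", "confidence": 0.80},
    {"subject": "张三", "predicate": "works_for", "object": "abc公司", "confidence": 0.90},
    {"subject": "张三", "predicate": "LOCATED_IN", "object": "北京", "confidence": 0.85},
]
unique = dedupe_keep_best(triples)
print(len(unique), [t["confidence"] for t in unique])
```

Whether case should be ignored depends on the domain. For identifiers that are case-sensitive, dropping the casefold() call keeps the width normalization while preserving case distinctions.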
4.6 Knowledge Graph Write Node (Neo4j)
# nodes/kg_builder.py
from neo4j import GraphDatabase
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import JsonOutputParser
from typing import List, Dict, Any
import time
import json
from state import KGBuildState, ExtractedTriple
from config import Config

llm = ChatOpenAI(model=Config.OPENAI_MODEL, temperature=0)

CYPHER_GENERATION_PROMPT = ChatPromptTemplate.from_messages([
    ("system", """You are a Neo4j Cypher expert. Given knowledge triples, generate efficient Cypher MERGE statements.
Rules:
1. Use MERGE rather than CREATE to keep writes idempotent (no duplicate nodes)
2. Use the entity type as the node label (for example :Person, :Organization)
3. Use upper-case underscore-separated relationship types (for example :BELONGS_TO, :LOCATED_IN)
4. Add the necessary properties to nodes and relationships
5. Keep each statement independent, separated by newlines
6. Add source and confidence properties for provenance
Example:
MERGE (s:Person {{name: "Zhang San"}})
MERGE (o:Organization {{name: "ABC Corp"}})
MERGE (s)-[:WORKS_FOR {{confidence: 0.95, source: "doc_001"}}]->(o)
Output plain-text Cypher, one group of MERGE statements per triple.
"""),
    ("human", "Triples (JSON):\n{triples_json}")
])
class Neo4jWriter:
"""Neo4j 写入器,支持批量写入和事务管理"""
def __init__(self):
self.driver = GraphDatabase.driver(
Config.NEO4J_URI,
auth=(Config.NEO4J_USERNAME, Config.NEO4J_PASSWORD)
)
def close(self):
self.driver.close()
def create_indexes(self):
"""创建必要的索引以提升查询性能"""
index_queries = [
"CREATE INDEX IF NOT EXISTS FOR (n:Entity) ON (n.name)",
"CREATE INDEX IF NOT EXISTS FOR (n:Person) ON (n.name)",
"CREATE INDEX IF NOT EXISTS FOR (n:Organization) ON (n.name)",
"CREATE INDEX IF NOT EXISTS FOR (n:Location) ON (n.name)",
"CREATE INDEX IF NOT EXISTS FOR (n:Concept) ON (n.name)",
"CREATE INDEX IF NOT EXISTS FOR (n:Product) ON (n.name)",
"CREATE INDEX IF NOT EXISTS FOR (n:Event) ON (n.name)",
]
with self.driver.session() as session:
for query in index_queries:
try:
session.run(query)
except Exception as e:
print(f" ⚠️ 索引创建警告: {e}")
print(" ✅ 索引创建完成")
def write_triple_batch(self, triples: List[ExtractedTriple],
batch_size: int = 50) -> Dict[str, Any]:
"""批量写入三元组到 Neo4j"""
total_written = 0
total_failed = 0
errors = []
# 使用参数化 Cypher 语句进行批量写入
cypher = """
UNWIND $triples AS triple
CALL {
WITH triple
MERGE (s {name: triple.subject})
ON CREATE SET s.name = triple.subject,
s.created_at = datetime(),
s.source = triple.source_id
SET s:`Entity`
CALL apoc.create.addLabels(s, [triple.subject_type]) YIELD node AS subj
MERGE (o {name: triple.object})
ON CREATE SET o.name = triple.object,
o.created_at = datetime(),
o.source = triple.source_id
SET o:`Entity`
CALL apoc.create.addLabels(o, [triple.object_type]) YIELD node AS obj
WITH subj, obj, triple
CALL apoc.merge.relationship(
subj,
triple.predicate,
{source: triple.source_id},
{confidence: triple.confidence, created_at: datetime()},
obj
) YIELD rel
RETURN rel
}
RETURN count(*) AS count
"""
# 回退到更基础的写入方式(不依赖 APOC)
cypher_basic = """
MERGE (s:Entity {name: $subject})
SET s.entity_type = $subject_type, s.updated_at = datetime()
MERGE (o:Entity {name: $object})
SET o.entity_type = $object_type, o.updated_at = datetime()
MERGE (s)-[r:RELATED_TO {predicate: $predicate}]->(o)
SET r.confidence = $confidence,
r.source = $source_id,
r.updated_at = datetime()
"""
# 分批写入
for i in range(0, len(triples), batch_size):
batch = triples[i:i+batch_size]
with self.driver.session() as session:
tx = session.begin_transaction()
try:
batch_written = 0
for triple in batch:
# 构建动态标签的 Cypher
subject_label = self._sanitize_label(triple.get("subject_type", "Entity"))
object_label = self._sanitize_label(triple.get("object_type", "Entity"))
rel_type = self._sanitize_rel_type(triple.get("predicate", "RELATED_TO"))
cypher_dynamic = f"""
MERGE (s:{subject_label} {{name: $subject}})
SET s.updated_at = datetime(), s.source = $source_id
MERGE (o:{object_label} {{name: $object}})
SET o.updated_at = datetime(), o.source = $source_id
MERGE (s)-[r:{rel_type}]->(o)
SET r.confidence = $confidence,
r.source = $source_id,
r.predicate_text = $predicate,
r.updated_at = datetime()
"""
tx.run(cypher_dynamic,
subject=triple.get("subject", ""),
object=triple.get("object", ""),
predicate=triple.get("predicate", ""),
confidence=float(triple.get("confidence", 0.5)),
source_id=triple.get("source_id", "unknown"))
batch_written += 1
tx.commit()
total_written += batch_written
print(f" 💾 批次 {i//batch_size + 1}: 写入 {batch_written} 个三元组")
except Exception as e:
tx.rollback()
error_msg = f"批次 {i//batch_size + 1} 写入失败: {str(e)}"
errors.append(error_msg)
total_failed += len(batch)
print(f" ❌ {error_msg}")
finally:
tx.close()
# 避免频繁写入造成数据库压力
time.sleep(0.1)
return {
"total_written": total_written,
"total_failed": total_failed,
"errors": errors
}
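    def _sanitize_label(self, label: str) -> str:
        """Sanitize a node label so it is a valid unquoted Neo4j identifier."""
        import re
        # Replace everything except ASCII letters, digits, CJK characters and underscores
        clean = re.sub(r'[^a-zA-Z0-9\u4e00-\u9fff_]', '_', label)
        # An unquoted identifier must not start with a digit
        if clean and clean[0].isdigit():
            clean = "Type_" + clean
        return clean if clean else "Entity"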
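    def _sanitize_rel_type(self, rel_type: str) -> str:
        """Sanitize a relationship type into UPPER_SNAKE_CASE."""
        import re
        # Map common Chinese predicates to English relationship types.
        # The edge always points from subject to object, so active-voice names
        # are used ("A 管理 B" becomes (A)-[:MANAGES]->(B)).
        rel_map = {
            "属于": "BELONGS_TO", "位于": "LOCATED_IN", "创建": "CREATED",
            "包含": "CONTAINS", "相关": "RELATED_TO", "合作": "COOPERATES_WITH",
            "管理": "MANAGES", "使用": "USES", "具有": "HAS_PROPERTY",
            "生产": "PRODUCES", "参与": "PARTICIPATES_IN"
        }
        # The first key found as a substring wins
        for cn, en in rel_map.items():
            if cn in rel_type:
                return en
        # Otherwise upper-case the text and collapse other characters into underscores
        clean = re.sub(r'[^a-zA-Z0-9\u4e00-\u9fff]', '_', rel_type.upper())
        clean = re.sub(r'_+', '_', clean).strip('_')
        # An unquoted identifier must not start with a digit
        if clean and clean[0].isdigit():
            clean = "REL_" + clean
        return clean if clean else "RELATED_TO"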
    def get_graph_stats(self) -> Dict:
        """Return node, relationship and per-label counts for the graph."""
        with self.driver.session() as session:
            node_count = session.run("MATCH (n) RETURN count(n) AS count").single()["count"]
            rel_count = session.run("MATCH ()-[r]->() RETURN count(r) AS count").single()["count"]
            # Materialize the label list before issuing further queries on this session
            labels = [record["label"] for record in session.run("CALL db.labels() YIELD label RETURN label")]
            label_counts = {}
            for label in labels:
                # Backtick-quote the label so names with spaces or symbols stay valid Cypher
                quoted = "`" + label.replace("`", "``") + "`"
                count = session.run(f"MATCH (n:{quoted}) RETURN count(n) AS count").single()["count"]
                label_counts[label] = count
        return {
            "total_nodes": node_count,
            "total_relationships": rel_count,
            "node_labels": label_counts
        }
def kg_builder_node(state: KGBuildState) -> KGBuildState:
    """KG builder node: write the cleaned triples to Neo4j."""
    print("🏗️ [KG builder agent] Building the knowledge graph...")
    cleaned_triples = state["cleaned_triples"]
    if not cleaned_triples:
        print("  ⚠️ No triples to write, skipping the build")
        return {
            **state,
            "write_results": [{"status": "skipped", "reason": "no_data"}],
            "current_phase": "validating"
        }
    writer = Neo4jWriter()
    try:
        # Create indexes before writing so MERGE lookups stay fast
        print("  📑 Creating graph database indexes...")
        writer.create_indexes()
        # Batched write
        print(f"  💾 Writing {len(cleaned_triples)} triples...")
        write_result = writer.write_triple_batch(
            cleaned_triples,
            batch_size=Config.BATCH_SIZE
        )
        # Collect graph statistics
        stats = writer.get_graph_stats()
        print(f"✅ [KG builder agent] Write finished. "
              f"Succeeded: {write_result['total_written']}, "
              f"failed: {write_result['total_failed']}")
        print(f"  📊 Graph stats: {stats}")
        return {
            **state,
            "write_results": [{
                "total_written": write_result["total_written"],
                "total_failed": write_result["total_failed"],
                "errors": write_result["errors"],
                "graph_stats": stats,
                "status": "success" if write_result["total_failed"] == 0 else "partial"
            }],
            "current_phase": "validating"
        }
    except Exception as e:
        error_msg = f"Knowledge graph build failed: {e}"
        print(f"❌ [KG builder agent] {error_msg}")
        return {
            **state,
            "errors": [error_msg],
            "write_results": [{"status": "failed", "error": error_msg}],
            "current_phase": "error"
        }
    finally:
        # Always release the driver, even when the write fails
        writer.close()
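The writer above sends one statement per triple because labels and relationship types cannot be passed as Cypher parameters. One possible way to reduce round trips, sketched here under assumptions, is to group triples that share the same label and relationship-type combination and write each group with a single `UNWIND $rows` statement. The helpers `group_triples_by_shape` and `build_unwind_query` and the keys `subject_label`, `rel_type` and `object_label` are hypothetical and not part of the code above; the sketch assumes those values have already been sanitized.

```python
from collections import defaultdict
from typing import Dict, List, Tuple

# (subject_label, rel_type, object_label); all three are assumed to be sanitized already
Shape = Tuple[str, str, str]


def group_triples_by_shape(triples: List[dict]) -> Dict[Shape, List[dict]]:
    """Group triples so that every group shares labels and relationship type."""
    groups: Dict[Shape, List[dict]] = defaultdict(list)
    for triple in triples:
        shape = (triple["subject_label"], triple["rel_type"], triple["object_label"])
        # Keep only the values that travel as query parameters
        groups[shape].append({
            "subject": triple["subject"],
            "object": triple["object"],
            "confidence": float(triple.get("confidence", 0.5)),
            "source_id": triple.get("source_id", "unknown"),
        })
    return dict(groups)


def build_unwind_query(shape: Shape) -> str:
    """Build one Cypher statement that writes a whole group via UNWIND $rows."""
    subject_label, rel_type, object_label = shape
    return (
        "UNWIND $rows AS row\n"
        f"MERGE (s:`{subject_label}` {{name: row.subject}})\n"
        f"MERGE (o:`{object_label}` {{name: row.object}})\n"
        f"MERGE (s)-[r:`{rel_type}`]->(o)\n"
        "SET r.confidence = row.confidence, r.source = row.source_id"
    )
```

Each group could then be written with one call such as `tx.run(build_unwind_query(shape), rows=rows)` inside the existing transaction. The trade-off is that a single bad row still fails its whole group, just as it fails its whole batch in the per-triple version.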
4.7 Validator Node and Routing Logic
# nodes/validator.py
from state import KGBuildState
from config import Config

MAX_RETRIES = 2  # upper bound on clean-and-write retries


def validator_node(state: KGBuildState) -> KGBuildState:
    """Validate the build result and decide whether a retry is needed."""
    print("🔍 [Validator agent] Validating the build result...")
    write_results = state.get("write_results") or []
    # The initial state stores None under this key, so fall back to an empty dict
    quality_report = state.get("quality_report") or {}
    retry_count = state.get("retry_count", 0)
    latest_result = write_results[-1] if write_results else {}
    issues = []
    # Check the write success rate
    total_written = latest_result.get("total_written", 0)
    total_failed = latest_result.get("total_failed", 0)
    total = total_written + total_failed
    if total > 0:
        success_rate = total_written / total
        if success_rate < 0.9:  # below 90%
            issues.append(f"Write success rate too low: {success_rate:.1%}")
    # Check the graph size
    if "graph_stats" in latest_result:
        stats = latest_result["graph_stats"]
        if stats.get("total_nodes", 0) == 0:
            issues.append("The graph contains no nodes")
    # Check the quality score
    avg_confidence = quality_report.get("avg_confidence", 0)
    if avg_confidence < Config.QUALITY_THRESHOLD:
        issues.append(f"Average confidence too low: {avg_confidence:.2f}")
    if issues and retry_count < MAX_RETRIES:
        print(f"  ⚠️ Issues found, retrying: {issues}")
        return {
            **state,
            "retry_count": retry_count + 1,
            "current_phase": "retry_needed",
            "errors": issues
        }
    # Build the final report; remaining issues are reported as warnings
    final_report = {
        "status": "completed" if not issues else "completed_with_warnings",
        "warnings": issues,
        "quality_report": quality_report,
        "write_summary": latest_result,
        "total_retry_count": retry_count
    }
    print(f"✅ [Validator agent] Validation finished. Status: {final_report['status']}")
    return {
        **state,
        "final_report": final_report,
        "current_phase": "completed"
    }
# Routing functions
def route_after_validation(state: KGBuildState) -> str:
    """Route after validation: retry the clean-and-write cycle or finish."""
    if state.get("current_phase", "") == "retry_needed":
        return "retry"
    # "completed", "error" and any unexpected phase all end the run
    return "end"


def route_after_dispatch(state: KGBuildState) -> str:
    """Route after dispatch: stop only when nothing was extracted and errors occurred."""
    errors = state.get("errors", [])
    raw_extractions = state.get("raw_extractions", [])
    if not raw_extractions and errors:
        return "error"
    return "clean"
4.8 Assembling the Complete LangGraph Workflow
# graph.py
from langgraph.graph import StateGraph, END
from langgraph.graph.state import CompiledStateGraph
from langgraph.checkpoint.memory import MemorySaver
from state import KGBuildState
from nodes.planner import planner_node
from nodes.dispatcher import dispatcher_node
from nodes.cleaner import cleaner_node
from nodes.kg_builder import kg_builder_node
from nodes.validator import validator_node, route_after_validation, route_after_dispatch


def create_kg_builder_graph() -> CompiledStateGraph:
    """Build and compile the LangGraph workflow for knowledge graph construction."""
    # Initialize the graph
    workflow = StateGraph(KGBuildState)
    # Register all nodes
    workflow.add_node("planner", planner_node)
    workflow.add_node("dispatcher", dispatcher_node)
    workflow.add_node("cleaner", cleaner_node)
    workflow.add_node("kg_builder", kg_builder_node)
    workflow.add_node("validator", validator_node)
    # Set the entry point
    workflow.set_entry_point("planner")
    # Unconditional edges
    workflow.add_edge("planner", "dispatcher")
    workflow.add_edge("cleaner", "kg_builder")
    workflow.add_edge("kg_builder", "validator")
    # Conditional edge: route on the dispatch result
    workflow.add_conditional_edges(
        "dispatcher",
        route_after_dispatch,
        {
            "clean": "cleaner",
            "error": END
        }
    )
    # Conditional edge: route on the validation result
    workflow.add_conditional_edges(
        "validator",
        route_after_validation,
        {
            "retry": "cleaner",  # on failure, clean and write again
            "end": END
        }
    )
    # Compile the graph with an in-memory checkpointer (state is lost when the process exits)
    memory = MemorySaver()
    app = workflow.compile(checkpointer=memory)
    return app
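def visualize_graph(app) -> None:
    """Render the workflow graph, falling back to ASCII outside a notebook."""
    try:
        from IPython.display import Image, display
        # By default the PNG is rendered through the remote Mermaid service, so this needs network access
        display(Image(app.get_graph().draw_mermaid_png()))
    except Exception:
        # The ASCII fallback requires the optional grandalf package
        print(app.get_graph().draw_ascii())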
# Main entry point
def run_kg_builder(data_sources: list, domain_context: str, thread_id: str = "kg_build_001"):
    """Run the knowledge graph build workflow and return the final state."""
    app = create_kg_builder_graph()
    # Initial state
    initial_state = {
        "data_sources": data_sources,
        "domain_context": domain_context,
        "processing_plan": None,
        "tasks": [],
        "raw_extractions": [],
        "cleaned_triples": [],
        "rejected_triples": [],
        "quality_report": None,
        "cypher_statements": [],
        "write_results": [],
        "current_phase": "planning",
        "retry_count": 0,
        "errors": [],
        "final_report": None
    }
    config = {"configurable": {"thread_id": thread_id}}
    print("=" * 60)
    print("🚀 Knowledge graph build started")
    print("=" * 60)
    # Stream the run so progress is printed as each node finishes
    for event in app.stream(initial_state, config=config):
        for node_name, node_output in event.items():
            print(f"\n{'='*20} Node: {node_name} {'='*20}")
            if node_name == "validator" and node_output.get("final_report"):
                print(f"📋 Final report: {node_output['final_report']}")
    # Read the merged state from the checkpointer rather than the last node's update
    final_state = app.get_state(config).values
    print("\n" + "=" * 60)
    print("✅ Knowledge graph build finished")
    print("=" * 60)
    return final_state
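Because the validator can route back to the cleaner, it is worth checking that the loop terminates. The following is a plain-Python trace of the topology above with stub logic. It does not use the LangGraph API, and `simulate_run` is a hypothetical helper that only mirrors the `retry_count < 2` check in `validator_node`.

```python
from typing import List

MAX_RETRIES = 2  # mirrors the retry_count < 2 check in validator_node


def simulate_run(validation_passes_on_attempt: int) -> List[str]:
    """Return the order in which the nodes would be visited.

    validation_passes_on_attempt is the 1-based attempt at which the stub
    validator stops reporting issues; pass a large number for "never".
    """
    visited: List[str] = ["planner", "dispatcher"]
    retry_count = 0
    attempt = 0
    while True:
        attempt += 1
        visited += ["cleaner", "kg_builder", "validator"]
        has_issues = attempt < validation_passes_on_attempt
        if has_issues and retry_count < MAX_RETRIES:
            retry_count += 1  # route_after_validation returns "retry"
            continue
        return visited  # route_after_validation returns "end"
```

Under these assumptions the cleaner, builder and validator run at most three times (one initial pass plus two retries), after which any remaining issues end up as warnings in the final report.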
5. End-to-End Usage Example
# main.py
import json

from graph import run_kg_builder

if __name__ == "__main__":
    # Define the data sources
    data_sources = [
        {
            "source_id": "annual_report_2023",
            "source_type": "pdf",
            "source_path": "./data/annual_report_2023.pdf",
            "schema_hint": "Annual company report covering company profile, business segments and executives",
            "priority": 1
        },
        {
            "source_id": "employee_data",
            "source_type": "csv",
            "source_path": "./data/employees.csv",
            "schema_hint": "Employee table with name, department, title and reporting line",
            "priority": 2
        },
        {
            "source_id": "product_db",
            "source_type": "mysql",
            "source_path": "mysql://root:password@localhost:3306/products_db",
            "schema_hint": "Product database with products, categories and supplier relations",
            "priority": 2
        },
        {
            "source_id": "news_corpus",
            "source_type": "txt",
            "source_path": "./data/news_articles.txt",
            "schema_hint": "News article collection on partnerships, investments and acquisitions",
            "priority": 3
        },
        {
            "source_id": "partner_mongodb",
            "source_type": "mongodb",
            "source_path": "mongodb://localhost:27017",
            "schema_hint": "Partner database with partner companies and joint projects",
            "priority": 2
        }
    ]
    # Domain context
    domain_context = """
    Enterprise knowledge graph for the technology manufacturing sector.
    Main entity types: Company, Person, Product,
    Technology, Project, Department.
    Main relations: affiliation, partnership, product relations, technology dependency, staff assignment.
    """
    # Run the build
    result = run_kg_builder(
        data_sources=data_sources,
        domain_context=domain_context,
        thread_id="enterprise_kg_v1"
    )
    # Print the final report
    if result and result.get("final_report"):
        print("\n📊 Final build report:")
        print(json.dumps(result["final_report"], ensure_ascii=False, indent=2))
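6. Advanced Optimization Strategies
6.1 Streaming Execution and Resuming from Checkpoints
LangGraph saves a checkpoint after each step. With a persistent checkpointer, a long-running job that fails midway can resume from the last saved step instead of starting over:
# Persist checkpoints in SQLite (requires the langgraph-checkpoint-sqlite package)
import sqlite3
from langgraph.checkpoint.sqlite import SqliteSaver

conn = sqlite3.connect("kg_build_checkpoint.db", check_same_thread=False)
checkpointer = SqliteSaver(conn)
# workflow is the StateGraph assembled in create_kg_builder_graph()
app = workflow.compile(checkpointer=checkpointer)
# To resume an interrupted job, reuse the same thread_id
config = {"configurable": {"thread_id": "enterprise_kg_v1"}}
# Passing None as the input continues from the last checkpoint;
# passing a fresh initial state would start a new run on that thread
for event in app.stream(None, config=config):
    print(event)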
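6.2 Dynamic Parallelism Control
Adjust the degree of parallelism to the current system load:
import psutil
from config import Config


def get_optimal_workers() -> int:
    """Pick the number of parallel workers based on current system load."""
    # interval=1 blocks for one second while CPU usage is sampled
    cpu_percent = psutil.cpu_percent(interval=1)
    memory_percent = psutil.virtual_memory().percent
    base_workers = Config.MAX_PARALLEL_WORKERS
    if cpu_percent > 80 or memory_percent > 85:
        # Under pressure: halve the pool, but keep at least one worker
        return max(1, base_workers // 2)
    elif cpu_percent < 30 and memory_percent < 50:
        # Mostly idle: double the pool, capped at 20
        return min(base_workers * 2, 20)
    else:
        return base_workers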
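To make the policy above testable without measuring a live machine, the thresholds can be separated from the `psutil` calls. The sketch below is one way to do that. `choose_workers` and `run_in_pool` are hypothetical names, and the thresholds are simply copied from `get_optimal_workers`.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, Iterable, List, TypeVar

T = TypeVar("T")
R = TypeVar("R")


def choose_workers(cpu_percent: float, memory_percent: float,
                   base_workers: int, hard_cap: int = 20) -> int:
    """Pure version of the sizing policy: same thresholds, no system calls."""
    if cpu_percent > 80 or memory_percent > 85:
        return max(1, base_workers // 2)
    if cpu_percent < 30 and memory_percent < 50:
        return min(base_workers * 2, hard_cap)
    return base_workers


def run_in_pool(func: Callable[[T], R], items: Iterable[T], workers: int) -> List[R]:
    """Apply func to every item with a thread pool of the given size, keeping input order."""
    with ThreadPoolExecutor(max_workers=workers) as pool:
        return list(pool.map(func, items))
```

A caller could feed live measurements into it, for example `choose_workers(psutil.cpu_percent(interval=1), psutil.virtual_memory().percent, Config.MAX_PARALLEL_WORKERS)`, and pass the result to `run_in_pool`. Threads suit this workload when processors spend most of their time waiting on I/O or LLM calls.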
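6.3 Incremental Updates
Before reprocessing a source, check whether its data already exists in Neo4j, so that the pipeline can update incrementally instead of rebuilding everything:
from typing import Dict

# Neo4jWriter is the writer class from the KG builder section


def check_existing_data(writer: Neo4jWriter, source_id: str) -> Dict:
    """Check whether a data source has already been processed."""
    with writer.driver.session() as session:
        # Relies on the source and updated_at properties set during the write
        result = session.run("""
            MATCH (n) WHERE n.source = $source_id
            RETURN count(n) AS node_count,
                   max(n.updated_at) AS last_update
        """, source_id=source_id)
        record = result.single()
        return {
            "exists": record["node_count"] > 0,
            "node_count": record["node_count"],
            "last_update": record["last_update"]
        }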
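6.4 Quality Scoring
A finer-grained quality score can combine several signals:
def compute_triple_quality_score(triple: ExtractedTriple) -> float:
    """Composite score: confidence + entity completeness + relation validity."""
    score = triple.get("confidence", 0.5)
    # Reward entity names of plausible length
    subj_len = len(triple.get("subject", ""))
    obj_len = len(triple.get("object", ""))
    if 1 <= subj_len <= 50:
        score += 0.05
    if 1 <= obj_len <= 100:
        score += 0.05
    # Predicates with concrete meaning earn a bonus.
    # The keywords stay in Chinese because they are matched against Chinese predicate text.
    predicate = triple.get("predicate", "")
    meaningful_predicates = ["属于", "位于", "创立", "合作", "管理", "生产", "使用"]
    if any(p in predicate for p in meaningful_predicates):
        score += 0.1
    # Source trust bonus. The source type is matched as a substring of source_id,
    # so it only applies when IDs embed the type (for example "mysql_products").
    source_trust = {
        "mysql": 0.1, "mongodb": 0.08,
        "csv": 0.05, "pdf": 0.03, "api": 0.02
    }
    source_id = triple.get("source_id", "")
    for source_type, bonus in source_trust.items():
        if source_type in source_id.lower():
            score += bonus
            break
    # Cap the score at 1.0
    return min(score, 1.0)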
7. Monitoring and Observability
7.1 Prometheus Metrics
# monitoring.py
from prometheus_client import Counter, Histogram, Gauge, start_http_server
import time

# Metric definitions
triples_extracted = Counter('kg_triples_extracted_total',
                            'Total number of extracted triples', ['source_type'])
triples_cleaned = Counter('kg_triples_cleaned_total',
                          'Number of triples after cleaning', ['status'])
neo4j_write_duration = Histogram('kg_neo4j_write_duration_seconds',
                                 'Neo4j write duration in seconds')
active_processors = Gauge('kg_active_processors', 'Number of currently active processors')


class MetricsCollector:
    def record_extraction(self, source_type: str, count: int):
        triples_extracted.labels(source_type=source_type).inc(count)

    def record_cleaning(self, accepted: int, rejected: int):
        triples_cleaned.labels(status='accepted').inc(accepted)
        triples_cleaned.labels(status='rejected').inc(rejected)

    # Records the wall-clock time of every call in the histogram
    @neo4j_write_duration.time()
    def timed_write(self, write_func, *args, **kwargs):
        return write_func(*args, **kwargs)
7.2 LangSmith Tracing
import os

# Set these before the graph runs
os.environ["LANGCHAIN_TRACING_V2"] = "true"
os.environ["LANGCHAIN_PROJECT"] = "KG-Builder"
os.environ["LANGCHAIN_API_KEY"] = "your_langsmith_api_key"
# With tracing enabled, LangGraph reports each node's execution to LangSmith.
# The LangSmith console then shows the full execution trace, token usage and latency.
8. Production Deployment Recommendations
8.1 Containerized Deployment
# docker-compose.yml
version: '3.8'
services:
  neo4j:
    image: neo4j:5.14
    environment:
      - NEO4J_AUTH=neo4j/your_password
      - NEO4J_PLUGINS=["apoc"]
      # Neo4j 5 setting name: server.memory.heap.max_size
      - NEO4J_server_memory_heap_max__size=4G
    ports:
      - "7474:7474"
      - "7687:7687"
    volumes:
      - neo4j_data:/data
  kg-builder:
    build: .
    environment:
      - OPENAI_API_KEY=${OPENAI_API_KEY}
      - NEO4J_URI=bolt://neo4j:7687
      - MAX_PARALLEL_WORKERS=8
    depends_on:
      - neo4j
    volumes:
      - ./data:/app/data
volumes:
  neo4j_data:
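In the compose file above, `depends_on` only orders container startup. It does not wait until Neo4j accepts Bolt connections, so the builder may start before the database is ready. A minimal sketch of a retry helper with exponential backoff follows; `retry_with_backoff` is a hypothetical name, and the attempt count and delays are illustrative defaults rather than tuned values.

```python
import time
from typing import Callable, Tuple, Type, TypeVar

T = TypeVar("T")


def retry_with_backoff(
    operation: Callable[[], T],
    attempts: int = 5,
    base_delay: float = 1.0,
    retry_on: Tuple[Type[BaseException], ...] = (Exception,),
    sleep: Callable[[float], None] = time.sleep,
) -> T:
    """Call operation until it succeeds, doubling the delay after each failure."""
    for attempt in range(attempts):
        try:
            return operation()
        except retry_on:
            # Out of attempts: surface the last error to the caller
            if attempt == attempts - 1:
                raise
            sleep(base_delay * (2 ** attempt))
    raise ValueError("attempts must be at least 1")
```

At startup the builder could wrap its first connectivity check, for example `retry_with_backoff(driver.verify_connectivity)` with the Neo4j Python driver, before any write is attempted. The injectable `sleep` argument exists only so the helper can be tested without real waiting.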
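8.2 Design Principles for Scalability
For large-scale production environments, the following principles are recommended:
Stateless processors: each processing agent should be stateless, with all state managed by LangGraph's StateGraph, which makes horizontal scaling straightforward
Message-queue decoupling: for very large data volumes, place Kafka or RabbitMQ between the Dispatcher and the Processors so that they communicate asynchronously
Neo4j clustering: in production, run a Neo4j cluster (an Enterprise Edition feature, known as Causal Clustering in Neo4j 4.x) for read scaling and high availability
LLM cost control: for repetitive structured data (such as CSV files and database tables), prefer a rule engine over the LLM to reduce API costs
Caching: cache a hash of each processed document's content so that unchanged documents are not processed again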
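The caching principle can be illustrated with a small content-hash cache. This is a sketch under assumptions: `file_sha256` and `ProcessedCache` are hypothetical helpers, the cache is a local JSON file, and only file-based sources are covered.

```python
import hashlib
import json
from pathlib import Path
from typing import Dict


def file_sha256(path: Path, chunk_size: int = 1 << 20) -> str:
    """Hash a file in chunks so large documents are not loaded into memory at once."""
    digest = hashlib.sha256()
    with path.open("rb") as handle:
        for chunk in iter(lambda: handle.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


class ProcessedCache:
    """Remember the content hash of each processed source in a JSON file."""

    def __init__(self, cache_file: Path):
        self.cache_file = cache_file
        self.hashes: Dict[str, str] = (
            json.loads(cache_file.read_text(encoding="utf-8"))
            if cache_file.exists() else {}
        )

    def needs_processing(self, source_id: str, path: Path) -> bool:
        """True when the source is new or its content changed since the last run."""
        return self.hashes.get(source_id) != file_sha256(path)

    def mark_processed(self, source_id: str, path: Path) -> None:
        """Store the current hash and persist the cache to disk."""
        self.hashes[source_id] = file_sha256(path)
        self.cache_file.write_text(json.dumps(self.hashes, indent=2), encoding="utf-8")
```

A dispatcher could call `needs_processing` before creating a task and `mark_processed` only after the write to Neo4j succeeds, so a failed run is picked up again next time.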
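9. Summary and Outlook
Core Value Recap
This article built a LangGraph-based multi-agent system that constructs a knowledge graph from multi-source heterogeneous data.
Future Directions
RAG-enhanced entity disambiguation: use the existing knowledge graph as context to help disambiguate entities in new data
Multimodal data support: extend knowledge extraction to images and video subtitles
Online learning and feedback: adjust cleaning rules dynamically based on user feedback on graph quality
Enhanced graph reasoning: combine with GraphRAG so the LLM can perform multi-hop reasoning over the graph
Federated learning: build knowledge graphs collaboratively from several organizations' data while protecting data privacy
A knowledge graph is valuable to the extent that it stays current and complete. LangGraph provides an orchestration framework in which agents divide work, run in parallel and correct their own output, turning a tedious multi-source data pipeline into an automated one. As large language models improve, extraction quality in this pipeline should improve with them, strengthening the data foundation for enterprise intelligence.
The code in this article was verified with Python 3.10+, LangGraph 0.2+ and Neo4j 5.14. Questions are welcome in the comments.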