2026-03-29 23:47:20 +08:00

723 lines
25 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
因果推断解析器模块 (Causal Parser)
该模块提供因果推断解析功能,使用启发式规则解析用户的自然语言问题,
识别因果推断中的关键要素(处理变量 T、结果变量 Y、协变量 Control)。
"""
import json
from dataclasses import dataclass, asdict, field
from typing import List, Optional, Dict, Any, Union
from datetime import datetime
# 支持直接运行和作为模块导入两种方式
if __name__ == "__main__":
import sys
import os
sys.path.insert(0, os.path.dirname(__file__))
from file_query_tool import FileQueryTool, FileMetadata
else:
from .file_query_tool import FileQueryTool, FileMetadata
@dataclass
class CausalInferenceRequest:
    """A causal-inference parsing request.

    Bundles the user's natural-language question with the data file to
    analyse and how many rows of it to read.
    """
    user_question: str   # the user's natural-language question
    file_path: str       # path to the data file
    read_rows: int = 100  # number of rows to read from the file

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain-dict representation of this request."""
        return {
            "user_question": self.user_question,
            "file_path": self.file_path,
            "read_rows": self.read_rows,
        }
@dataclass
class VariableInfo:
    """A variable identified during causal parsing."""
    name: str                         # column name in the data file
    variable_type: str                # "treatment" / "outcome" / "control"
    confidence: float                 # confidence score in [0, 1]
    reasoning: str                    # why this column was chosen
    data_type: Optional[str] = None   # dtype string, if known
    is_binary: Optional[bool] = None  # True when the column has exactly 2 unique values

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain-dict representation of this variable."""
        return asdict(self)


@dataclass
class CausalInferenceResult:
    """Result of parsing a causal-inference question against a data file."""
    original_question: str                     # the user's original question
    file_path: str                             # path of the analysed file
    treatment_variable: Optional[VariableInfo] = None  # treatment (T)
    outcome_variable: Optional[VariableInfo] = None    # outcome (Y)
    control_variables: List[VariableInfo] = field(default_factory=list)  # covariates
    confidence_score: float = 0.0              # overall confidence in [0, 1]
    reasoning: str = ""                        # overall reasoning text
    alternative_matches: List[Dict[str, Any]] = field(default_factory=list)  # other candidates
    metadata: Dict[str, Any] = field(default_factory=dict)  # bookkeeping (timestamps, file info)

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain-dict representation of the full result.

        ``dataclasses.asdict`` converts nested dataclasses recursively, so
        the VariableInfo fields and the control-variable list are already
        plain dicts; the manual re-conversion the original code performed
        here was redundant and has been removed (output is identical).
        """
        return asdict(self)

    def to_json(self, indent: int = 2) -> str:
        """Serialize to a JSON string, keeping non-ASCII characters as-is."""
        return json.dumps(self.to_dict(), ensure_ascii=False, indent=indent)
class CausalParser:
    """
    Causal-inference parser.

    Parses a user's natural-language question together with the structure
    of a data file to identify the key elements of a causal question:
    the treatment variable (T), the outcome variable (Y) and control
    covariates. The public ``parse`` entry point currently uses heuristic
    rules (see ``_heuristic_parse``); the prompt templates below support
    LLM-based parsing via ``_parse_llm_response``.

    Attributes:
        file_query_tool: FileQueryTool instance used to read file metadata.
        default_read_rows: default number of rows to read from a file.

    Example:
        >>> parser = CausalParser()
        >>> request = CausalInferenceRequest(
        ...     user_question="教育对收入的影响是什么?",
        ...     file_path="data/education_data.csv"
        ... )
        >>> result = parser.parse(request)
        >>> print(result.to_json())
    """
    # Built-in prompt templates. These are runtime strings intended to be
    # sent verbatim to an LLM, so their (Chinese) content is kept as-is.
    DEFAULT_SYSTEM_PROMPT = """你是一位因果推断专家,负责从用户的自然语言问题中识别因果推断的关键要素。
任务:
1. 分析用户的问题,理解其因果推断意图
2. 根据提供的数据表结构,识别以下变量:
- 处理变量 (T/Treatment): 被操纵或观察的干预措施
- 结果变量 (Y/Outcome): 受处理影响的指标
- 协变量 (Control): 需要控制的混淆因素
输出要求:
- 只输出 JSON 格式,不要包含其他内容
- 对于每个变量,提供名称、类型、置信度和推理理由
- 如果无法确定某个变量,将其设为 null 或空列表
- 置信度评分范围0.0-1.0
变量类型说明:
- treatment: 处理变量,通常是实验中的干预措施或观察中的暴露因素
- outcome: 结果变量,是我们要研究的效应指标
- control: 协变量/控制变量,用于控制混淆因素
示例输出格式:
{{
"treatment_variable": {{
"name": "treatment",
"variable_type": "treatment",
"confidence": 0.95,
"reasoning": "该变量表示是否接受干预"
}},
"outcome_variable": {{
"name": "outcome",
"variable_type": "outcome",
"confidence": 0.90,
"reasoning": "该变量表示干预后的结果"
}},
"control_variables": [
{{
"name": "age",
"variable_type": "control",
"confidence": 0.85,
"reasoning": "年龄可能影响结果,需要控制"
}}
],
"confidence_score": 0.90,
"reasoning": "根据问题描述,这是一个典型的因果推断问题..."
}}"""
    # User-prompt template: filled with the question and the table
    # description produced by _build_table_info.
    DEFAULT_USER_PROMPT_TEMPLATE = """用户问题:{question}
数据表结构:
{table_info}
请识别因果推断要素并输出 JSON 格式的结果。"""

    def __init__(self, file_query_tool: Optional["FileQueryTool"] = None,
                 default_read_rows: int = 100):
        """
        Initialize the causal parser.

        Args:
            file_query_tool: file-query tool instance; a default
                FileQueryTool() is created when None.
            default_read_rows: default number of rows to read from a file.
        """
        self.file_query_tool = file_query_tool or FileQueryTool()
        self.default_read_rows = default_read_rows
def parse(self, request: CausalInferenceRequest) -> CausalInferenceResult:
"""
解析因果推断请求
Args:
request: 因果推断请求对象
Returns:
CausalInferenceResult: 解析结果
"""
# 获取文件信息
file_metadata = self.file_query_tool.query(
request.file_path,
request.read_rows
)
# 使用启发式方法进行解析
result = self._heuristic_parse(
request.user_question,
file_metadata
)
return result
def _build_table_info(self, metadata: FileMetadata) -> str:
"""
构建表头信息字符串
Args:
metadata: 文件元数据
Returns:
str: 格式化的表头信息
"""
lines = [
f"文件名:{metadata.file_name}",
f"总行数:{metadata.total_rows}",
f"总列数:{metadata.total_columns}",
"",
"列信息:"
]
if metadata.columns is None:
return "\n".join(lines)
for col in metadata.columns:
lines.append(f" - {col.name}: {col.dtype} (非空:{col.non_null_count}, 空值:{col.null_count}, 唯一值:{col.unique_count})")
if col.sample_values:
sample_str = ", ".join(str(v) for v in col.sample_values[:3])
lines.append(f" 样本值:{sample_str}")
return "\n".join(lines)
def _heuristic_parse(self, question: str, metadata: FileMetadata) -> CausalInferenceResult:
"""
使用启发式规则进行因果推断解析
Args:
question: 用户问题
metadata: 文件元数据
Returns:
CausalInferenceResult: 解析结果
"""
# 启发式规则识别变量
treatment_var = self._heuristic_identify_treatment(question, metadata)
outcome_var = self._heuristic_identify_outcome(question, metadata)
control_vars = self._heuristic_identify_controls(question, metadata, treatment_var, outcome_var)
# 计算置信度
confidence = self._calculate_confidence(treatment_var, outcome_var, control_vars)
# 构建推理理由
reasoning = self._build_reasoning(question, treatment_var, outcome_var, control_vars)
# 查找其他可能的匹配
alternative_matches = self._find_alternative_matches(metadata, treatment_var, outcome_var)
# 构建结果对象
result = CausalInferenceResult(
original_question=question,
file_path=metadata.file_path,
treatment_variable=treatment_var,
outcome_variable=outcome_var,
control_variables=control_vars,
confidence_score=confidence,
reasoning=reasoning,
alternative_matches=alternative_matches,
metadata={
"parsed_at": datetime.now().isoformat(),
"file_info": {
"file_name": metadata.file_name,
"total_rows": metadata.total_rows,
"total_columns": metadata.total_columns
}
}
)
return result
"""
模拟 LLM 解析(用于测试)
基于启发式规则进行解析,不依赖真实的 LLM。
Args:
question: 用户问题
metadata: 文件元数据
Returns:
str: 模拟的 JSON 响应
"""
# 启发式规则识别变量
treatment_var = self._heuristic_identify_treatment(question, metadata)
outcome_var = self._heuristic_identify_outcome(question, metadata)
control_vars = self._heuristic_identify_controls(question, metadata, treatment_var, outcome_var)
# 计算置信度
confidence = self._calculate_confidence(treatment_var, outcome_var, control_vars)
# 构建推理理由
reasoning = self._build_reasoning(question, treatment_var, outcome_var, control_vars)
# 构建结果
result = {
"treatment_variable": treatment_var.to_dict() if treatment_var else None,
"outcome_variable": outcome_var.to_dict() if outcome_var else None,
"control_variables": [v.to_dict() for v in control_vars],
"confidence_score": confidence,
"reasoning": reasoning,
"alternative_matches": self._find_alternative_matches(metadata, treatment_var, outcome_var)
}
return json.dumps(result, ensure_ascii=False)
def _heuristic_identify_treatment(self, question: str,
metadata: FileMetadata) -> Optional[VariableInfo]:
"""
启发式识别处理变量
Args:
question: 用户问题
metadata: 文件元数据
Returns:
Optional[VariableInfo]: 处理变量信息
"""
# 常见处理变量关键词
treatment_keywords = [
"处理", "干预", "治疗", "实验组", "对照组", "treatment", "treat",
"干预组", "实验", "exposure", "exposed", "assigned", "group"
]
# 检查列名是否包含处理变量关键词
if metadata.columns is None:
return None
for col in metadata.columns:
col_name_lower = col.name.lower()
for keyword in treatment_keywords:
if keyword in col_name_lower:
return VariableInfo(
name=col.name,
variable_type="treatment",
confidence=0.85,
reasoning=f"列名包含处理变量关键词:{keyword}",
data_type=col.dtype,
is_binary=col.unique_count == 2 if col.unique_count else None
)
# 检查是否有明显的二值变量
for col in metadata.columns:
if col.unique_count == 2 and col.non_null_count == metadata.total_rows:
return VariableInfo(
name=col.name,
variable_type="treatment",
confidence=0.70,
reasoning="发现二值变量,可能是处理变量",
data_type=col.dtype,
is_binary=True
)
return None
def _heuristic_identify_outcome(self, question: str,
metadata: FileMetadata) -> Optional[VariableInfo]:
"""
启发式识别结果变量
Args:
question: 用户问题
metadata: 文件元数据
Returns:
Optional[VariableInfo]: 结果变量信息
"""
# 常见结果变量关键词
outcome_keywords = [
"结果", "影响", "效应", "收入", "工资", "成绩", "分数", "outcome",
"result", "effect", "impact", "dependent", "y", "目标", "指标"
]
# 检查列名是否包含结果变量关键词
if metadata.columns is None:
return None
for col in metadata.columns:
col_name_lower = col.name.lower()
for keyword in outcome_keywords:
if keyword in col_name_lower:
return VariableInfo(
name=col.name,
variable_type="outcome",
confidence=0.85,
reasoning=f"列名包含结果变量关键词:{keyword}",
data_type=col.dtype,
is_binary=False
)
# 检查是否有数值型列(通常是结果变量)
for col in metadata.columns:
if col.dtype in ['float64', 'int64', 'float32', 'int32']:
return VariableInfo(
name=col.name,
variable_type="outcome",
confidence=0.65,
reasoning="发现数值型列,可能是结果变量",
data_type=col.dtype,
is_binary=False
)
return None
def _heuristic_identify_controls(self, question: str,
metadata: FileMetadata,
treatment: Optional[VariableInfo],
outcome: Optional[VariableInfo]) -> List[VariableInfo]:
"""
启发式识别协变量
Args:
question: 用户问题
metadata: 文件元数据
treatment: 处理变量
outcome: 结果变量
Returns:
List[VariableInfo]: 协变量列表
"""
controls = []
treated_names = {treatment.name} if treatment else set()
outcome_names = {outcome.name} if outcome else set()
# 常见协变量关键词
control_keywords = [
"年龄", "性别", "教育", "经验", "控制", "covariate", "control",
"confounder", "confounding", "特征", "变量", "demographic"
]
if metadata.columns is None:
return []
for col in metadata.columns:
# 跳过已识别的变量
if col.name in treated_names or col.name in outcome_names:
continue
col_name_lower = col.name.lower()
# 检查是否包含协变量关键词
is_control = False
reasoning = ""
for keyword in control_keywords:
if keyword in col_name_lower:
is_control = True
reasoning = f"列名包含协变量关键词:{keyword}"
break
# 如果没有关键词,但也不是处理或结果变量,可能是协变量
if not is_control:
is_control = True
reasoning = "可能是协变量,用于控制混淆因素"
controls.append(VariableInfo(
name=col.name,
variable_type="control",
confidence=0.60 if is_control else 0.40,
reasoning=reasoning,
data_type=col.dtype,
is_binary=col.unique_count == 2 if col.unique_count else None
))
return controls
def _calculate_confidence(self, treatment: Optional[VariableInfo],
outcome: Optional[VariableInfo],
controls: List[VariableInfo]) -> float:
"""
计算整体置信度
Args:
treatment: 处理变量
outcome: 结果变量
controls: 协变量列表
Returns:
float: 置信度评分 (0-1)
"""
if treatment is None or outcome is None:
return 0.3
# 基础置信度
base_confidence = (treatment.confidence + outcome.confidence) / 2
# 协变量数量影响
if len(controls) > 0:
avg_control_confidence = sum(c.confidence for c in controls) / len(controls)
base_confidence = (base_confidence + avg_control_confidence) / 2
return min(base_confidence, 1.0)
def _build_reasoning(self, question: str, treatment: Optional[VariableInfo],
outcome: Optional[VariableInfo],
controls: List[VariableInfo]) -> str:
"""
构建推理理由
Args:
question: 用户问题
treatment: 处理变量
outcome: 结果变量
controls: 协变量列表
Returns:
str: 推理理由
"""
reasons = []
if treatment:
reasons.append(f"处理变量 '{treatment.name}' 被识别为干预措施")
if outcome:
reasons.append(f"结果变量 '{outcome.name}' 被识别为效应指标")
if controls:
control_names = [c.name for c in controls[:3]]
if len(controls) > 3:
control_names.append(f"{len(controls) - 3} 个其他变量")
reasons.append(f"协变量 {', '.join(control_names)} 被识别为需要控制的混淆因素")
return "".join(reasons) if reasons else "未能充分识别因果推断要素"
def _find_alternative_matches(self, metadata: FileMetadata,
treatment: Optional[VariableInfo],
outcome: Optional[VariableInfo]) -> List[Dict[str, Any]]:
"""
查找其他可能的匹配
Args:
metadata: 文件元数据
treatment: 处理变量
outcome: 结果变量
Returns:
List[Dict[str, Any]]: 其他可能的匹配
"""
alternatives = []
if metadata.columns is None:
return []
# 查找其他可能的处理变量
if treatment:
for col in metadata.columns:
if col.name != treatment.name:
alternatives.append({
"variable_type": "treatment",
"name": col.name,
"confidence": 0.3,
"reasoning": "可能是替代的处理变量"
})
# 查找其他可能的结果变量
if outcome:
for col in metadata.columns:
if col.name != outcome.name:
alternatives.append({
"variable_type": "outcome",
"name": col.name,
"confidence": 0.3,
"reasoning": "可能是替代的结果变量"
})
return alternatives[:5] # 最多返回 5 个替代匹配
def _parse_llm_response(self, response: str, request: CausalInferenceRequest,
metadata: FileMetadata) -> CausalInferenceResult:
"""
解析 LLM 响应
Args:
response: LLM 响应
request: 原始请求
metadata: 文件元数据
Returns:
CausalInferenceResult: 解析结果
"""
try:
# 清理响应,提取 JSON 部分
response = response.strip()
if response.startswith("```json"):
response = response[7:]
if response.startswith("```"):
response = response[1:]
if response.endswith("```"):
response = response[:-3]
response = response.strip()
# 解析 JSON
data = json.loads(response)
# 构建结果对象
result = CausalInferenceResult(
original_question=request.user_question,
file_path=request.file_path,
treatment_variable=VariableInfo(**data.get("treatment_variable")) if data.get("treatment_variable") else None,
outcome_variable=VariableInfo(**data.get("outcome_variable")) if data.get("outcome_variable") else None,
control_variables=[VariableInfo(**v) for v in data.get("control_variables", [])],
confidence_score=data.get("confidence_score", 0.0),
reasoning=data.get("reasoning", ""),
alternative_matches=data.get("alternative_matches", []),
metadata={
"parsed_at": datetime.now().isoformat(),
"file_info": {
"file_name": metadata.file_name,
"total_rows": metadata.total_rows,
"total_columns": metadata.total_columns
}
}
)
return result
except json.JSONDecodeError as e:
# 如果 JSON 解析失败,返回错误结果
return CausalInferenceResult(
original_question=request.user_question,
file_path=request.file_path,
confidence_score=0.0,
reasoning=f"LLM 响应解析失败:{str(e)}",
metadata={
"parsed_at": datetime.now().isoformat(),
"error": str(e),
"raw_response": response[:500]
}
)
def parse_simple(self, question: str, file_path: str) -> CausalInferenceResult:
"""
简化的解析方法
Args:
question: 用户问题
file_path: 文件路径
Returns:
CausalInferenceResult: 解析结果
"""
request = CausalInferenceRequest(
user_question=question,
file_path=file_path
)
return self.parse(request)
if __name__ == "__main__":
"""
因果推断解析器模块的简单测试示例
"""
from file_query_tool import FileQueryTool
print("=" * 60)
print("CausalParser 简单测试示例")
print("=" * 60)
print()
# 创建解析器实例(不使用 LLM 客户端)
print("步骤 1: 创建 CausalParser 实例...")
parser = CausalParser()
print(" - CausalParser 已初始化")
print()
# 准备示例数据文件
print("步骤 2: 准备示例数据文件...")
sample_file = "data/output/test_simple_ate.csv"
print(f" - 文件路径:{sample_file}")
print()
# 创建解析请求
print("步骤 3: 创建解析请求...")
sample_question = "教育对收入的影响是什么?"
request = CausalInferenceRequest(
user_question=sample_question,
file_path=sample_file
)
print(f" - 问题:{sample_question}")
print()
# 执行解析
print("步骤 4: 执行解析...")
try:
result = parser.parse(request)
print(" - 解析成功!")
print()
# 输出解析结果
print("步骤 5: 解析结果:")
print("-" * 60)
print(f"原始问题:{result.original_question}")
print(f"文件路径:{result.file_path}")
print(f"整体置信度:{result.confidence_score:.2%}")
print()
if result.treatment_variable:
tv = result.treatment_variable
print(f"处理变量 (T):")
print(f" - 名称:{tv.name}")
print(f" - 类型:{tv.variable_type}")
print(f" - 置信度:{tv.confidence:.2%}")
print(f" - 推理:{tv.reasoning}")
print()
if result.outcome_variable:
ov = result.outcome_variable
print(f"结果变量 (Y):")
print(f" - 名称:{ov.name}")
print(f" - 类型:{ov.variable_type}")
print(f" - 置信度:{ov.confidence:.2%}")
print(f" - 推理:{ov.reasoning}")
print()
if result.control_variables:
print(f"协变量 (Control Variables):")
for cv in result.control_variables:
print(f" - {cv.name} (置信度:{cv.confidence:.2%})")
print(f" 推理:{cv.reasoning}")
print()
print(f"推理理由:{result.reasoning}")
print()
print("=" * 60)
print("完整 JSON 输出:")
print("=" * 60)
print(result.to_json())
except Exception as e:
print(f" - 解析失败:{str(e)}")
print()
print("请检查:")
print(" 1. 数据文件路径是否正确")