""" 因果推断解析器模块 (Causal Parser) 该模块提供因果推断解析功能,使用启发式规则解析用户的自然语言问题, 识别因果推断中的关键要素(处理变量 T、结果变量 Y、协变量)。 """ import json from dataclasses import dataclass, asdict, field from typing import List, Optional, Dict, Any, Union from datetime import datetime # 支持直接运行和作为模块导入两种方式 if __name__ == "__main__": import sys import os sys.path.insert(0, os.path.dirname(__file__)) from file_query_tool import FileQueryTool, FileMetadata else: from .file_query_tool import FileQueryTool, FileMetadata @dataclass class CausalInferenceRequest: """因果推断请求数据类""" user_question: str # 用户的自然语言问题 file_path: str # 数据文件路径 read_rows: int = 100 # 读取的行数 def to_dict(self) -> Dict[str, Any]: """转换为字典""" return asdict(self) @dataclass class VariableInfo: """变量信息""" name: str # 变量名 variable_type: str # 变量类型(treatment/outcome/control) confidence: float # 置信度(0-1) reasoning: str # 推理理由 data_type: Optional[str] = None # 数据类型 is_binary: Optional[bool] = None # 是否为二值变量 def to_dict(self) -> Dict[str, Any]: """转换为字典""" return asdict(self) @dataclass class CausalInferenceResult: """因果推断结果数据类""" original_question: str # 用户原始问题 file_path: str # 文件路径 treatment_variable: Optional[VariableInfo] = None # 处理变量 T outcome_variable: Optional[VariableInfo] = None # 结果变量 Y control_variables: List[VariableInfo] = field(default_factory=list) # 协变量列表 confidence_score: float = 0.0 # 整体置信度评分 reasoning: str = "" # 推理理由 alternative_matches: List[Dict[str, Any]] = field(default_factory=list) # 其他可能的匹配 metadata: Dict[str, Any] = field(default_factory=dict) # 元数据 def to_dict(self) -> Dict[str, Any]: """转换为字典""" result = asdict(self) if self.treatment_variable: result["treatment_variable"] = self.treatment_variable.to_dict() if self.outcome_variable: result["outcome_variable"] = self.outcome_variable.to_dict() result["control_variables"] = [v.to_dict() for v in self.control_variables] return result def to_json(self, indent: int = 2) -> str: """转换为 JSON 字符串""" return json.dumps(self.to_dict(), ensure_ascii=False, indent=indent) class CausalParser: """ 因果推断解析器 使用 LLM 解析用户的自然语言问题,识别因果推断中的关键要素。 Attributes: file_query_tool: 文件查询工具实例 llm_client: LLM 客户端(可选) default_read_rows: 默认读取的行数 Example: >>> parser = CausalParser() >>> request = CausalInferenceRequest( ... user_question="教育对收入的影响是什么?", ... file_path="data/education_data.csv" ... ) >>> result = parser.parse(request) >>> print(result.to_json()) """ # 内置的提示词模板 DEFAULT_SYSTEM_PROMPT = """你是一位因果推断专家,负责从用户的自然语言问题中识别因果推断的关键要素。 任务: 1. 分析用户的问题,理解其因果推断意图 2. 根据提供的数据表结构,识别以下变量: - 处理变量 (T/Treatment): 被操纵或观察的干预措施 - 结果变量 (Y/Outcome): 受处理影响的指标 - 协变量 (Control): 需要控制的混淆因素 输出要求: - 只输出 JSON 格式,不要包含其他内容 - 对于每个变量,提供名称、类型、置信度和推理理由 - 如果无法确定某个变量,将其设为 null 或空列表 - 置信度评分范围:0.0-1.0 变量类型说明: - treatment: 处理变量,通常是实验中的干预措施或观察中的暴露因素 - outcome: 结果变量,是我们要研究的效应指标 - control: 协变量/控制变量,用于控制混淆因素 示例输出格式: {{ "treatment_variable": {{ "name": "treatment", "variable_type": "treatment", "confidence": 0.95, "reasoning": "该变量表示是否接受干预" }}, "outcome_variable": {{ "name": "outcome", "variable_type": "outcome", "confidence": 0.90, "reasoning": "该变量表示干预后的结果" }}, "control_variables": [ {{ "name": "age", "variable_type": "control", "confidence": 0.85, "reasoning": "年龄可能影响结果,需要控制" }} ], "confidence_score": 0.90, "reasoning": "根据问题描述,这是一个典型的因果推断问题..." }}""" DEFAULT_USER_PROMPT_TEMPLATE = """用户问题:{question} 数据表结构: {table_info} 请识别因果推断要素并输出 JSON 格式的结果。""" def __init__(self, file_query_tool: Optional[FileQueryTool] = None, default_read_rows: int = 100): """ 初始化因果推断解析器 Args: file_query_tool: 文件查询工具实例,如果为 None 则创建默认实例 default_read_rows: 默认读取的行数 """ self.file_query_tool = file_query_tool or FileQueryTool() self.default_read_rows = default_read_rows def parse(self, request: CausalInferenceRequest) -> CausalInferenceResult: """ 解析因果推断请求 Args: request: 因果推断请求对象 Returns: CausalInferenceResult: 解析结果 """ # 获取文件信息 file_metadata = self.file_query_tool.query( request.file_path, request.read_rows ) # 使用启发式方法进行解析 result = self._heuristic_parse( request.user_question, file_metadata ) return result def _build_table_info(self, metadata: FileMetadata) -> str: """ 构建表头信息字符串 Args: metadata: 文件元数据 Returns: str: 格式化的表头信息 """ lines = [ f"文件名:{metadata.file_name}", f"总行数:{metadata.total_rows}", f"总列数:{metadata.total_columns}", "", "列信息:" ] if metadata.columns is None: return "\n".join(lines) for col in metadata.columns: lines.append(f" - {col.name}: {col.dtype} (非空:{col.non_null_count}, 空值:{col.null_count}, 唯一值:{col.unique_count})") if col.sample_values: sample_str = ", ".join(str(v) for v in col.sample_values[:3]) lines.append(f" 样本值:{sample_str}") return "\n".join(lines) def _heuristic_parse(self, question: str, metadata: FileMetadata) -> CausalInferenceResult: """ 使用启发式规则进行因果推断解析 Args: question: 用户问题 metadata: 文件元数据 Returns: CausalInferenceResult: 解析结果 """ # 启发式规则识别变量 treatment_var = self._heuristic_identify_treatment(question, metadata) outcome_var = self._heuristic_identify_outcome(question, metadata) control_vars = self._heuristic_identify_controls(question, metadata, treatment_var, outcome_var) # 计算置信度 confidence = self._calculate_confidence(treatment_var, outcome_var, control_vars) # 构建推理理由 reasoning = self._build_reasoning(question, treatment_var, outcome_var, control_vars) # 查找其他可能的匹配 alternative_matches = self._find_alternative_matches(metadata, treatment_var, outcome_var) # 构建结果对象 result = CausalInferenceResult( original_question=question, file_path=metadata.file_path, treatment_variable=treatment_var, outcome_variable=outcome_var, control_variables=control_vars, confidence_score=confidence, reasoning=reasoning, alternative_matches=alternative_matches, metadata={ "parsed_at": datetime.now().isoformat(), "file_info": { "file_name": metadata.file_name, "total_rows": metadata.total_rows, "total_columns": metadata.total_columns } } ) return result """ 模拟 LLM 解析(用于测试) 基于启发式规则进行解析,不依赖真实的 LLM。 Args: question: 用户问题 metadata: 文件元数据 Returns: str: 模拟的 JSON 响应 """ # 启发式规则识别变量 treatment_var = self._heuristic_identify_treatment(question, metadata) outcome_var = self._heuristic_identify_outcome(question, metadata) control_vars = self._heuristic_identify_controls(question, metadata, treatment_var, outcome_var) # 计算置信度 confidence = self._calculate_confidence(treatment_var, outcome_var, control_vars) # 构建推理理由 reasoning = self._build_reasoning(question, treatment_var, outcome_var, control_vars) # 构建结果 result = { "treatment_variable": treatment_var.to_dict() if treatment_var else None, "outcome_variable": outcome_var.to_dict() if outcome_var else None, "control_variables": [v.to_dict() for v in control_vars], "confidence_score": confidence, "reasoning": reasoning, "alternative_matches": self._find_alternative_matches(metadata, treatment_var, outcome_var) } return json.dumps(result, ensure_ascii=False) def _heuristic_identify_treatment(self, question: str, metadata: FileMetadata) -> Optional[VariableInfo]: """ 启发式识别处理变量 Args: question: 用户问题 metadata: 文件元数据 Returns: Optional[VariableInfo]: 处理变量信息 """ # 常见处理变量关键词 treatment_keywords = [ "处理", "干预", "治疗", "实验组", "对照组", "treatment", "treat", "干预组", "实验", "exposure", "exposed", "assigned", "group" ] # 检查列名是否包含处理变量关键词 if metadata.columns is None: return None for col in metadata.columns: col_name_lower = col.name.lower() for keyword in treatment_keywords: if keyword in col_name_lower: return VariableInfo( name=col.name, variable_type="treatment", confidence=0.85, reasoning=f"列名包含处理变量关键词:{keyword}", data_type=col.dtype, is_binary=col.unique_count == 2 if col.unique_count else None ) # 检查是否有明显的二值变量 for col in metadata.columns: if col.unique_count == 2 and col.non_null_count == metadata.total_rows: return VariableInfo( name=col.name, variable_type="treatment", confidence=0.70, reasoning="发现二值变量,可能是处理变量", data_type=col.dtype, is_binary=True ) return None def _heuristic_identify_outcome(self, question: str, metadata: FileMetadata) -> Optional[VariableInfo]: """ 启发式识别结果变量 Args: question: 用户问题 metadata: 文件元数据 Returns: Optional[VariableInfo]: 结果变量信息 """ # 常见结果变量关键词 outcome_keywords = [ "结果", "影响", "效应", "收入", "工资", "成绩", "分数", "outcome", "result", "effect", "impact", "dependent", "y", "目标", "指标" ] # 检查列名是否包含结果变量关键词 if metadata.columns is None: return None for col in metadata.columns: col_name_lower = col.name.lower() for keyword in outcome_keywords: if keyword in col_name_lower: return VariableInfo( name=col.name, variable_type="outcome", confidence=0.85, reasoning=f"列名包含结果变量关键词:{keyword}", data_type=col.dtype, is_binary=False ) # 检查是否有数值型列(通常是结果变量) for col in metadata.columns: if col.dtype in ['float64', 'int64', 'float32', 'int32']: return VariableInfo( name=col.name, variable_type="outcome", confidence=0.65, reasoning="发现数值型列,可能是结果变量", data_type=col.dtype, is_binary=False ) return None def _heuristic_identify_controls(self, question: str, metadata: FileMetadata, treatment: Optional[VariableInfo], outcome: Optional[VariableInfo]) -> List[VariableInfo]: """ 启发式识别协变量 Args: question: 用户问题 metadata: 文件元数据 treatment: 处理变量 outcome: 结果变量 Returns: List[VariableInfo]: 协变量列表 """ controls = [] treated_names = {treatment.name} if treatment else set() outcome_names = {outcome.name} if outcome else set() # 常见协变量关键词 control_keywords = [ "年龄", "性别", "教育", "经验", "控制", "covariate", "control", "confounder", "confounding", "特征", "变量", "demographic" ] if metadata.columns is None: return [] for col in metadata.columns: # 跳过已识别的变量 if col.name in treated_names or col.name in outcome_names: continue col_name_lower = col.name.lower() # 检查是否包含协变量关键词 is_control = False reasoning = "" for keyword in control_keywords: if keyword in col_name_lower: is_control = True reasoning = f"列名包含协变量关键词:{keyword}" break # 如果没有关键词,但也不是处理或结果变量,可能是协变量 if not is_control: is_control = True reasoning = "可能是协变量,用于控制混淆因素" controls.append(VariableInfo( name=col.name, variable_type="control", confidence=0.60 if is_control else 0.40, reasoning=reasoning, data_type=col.dtype, is_binary=col.unique_count == 2 if col.unique_count else None )) return controls def _calculate_confidence(self, treatment: Optional[VariableInfo], outcome: Optional[VariableInfo], controls: List[VariableInfo]) -> float: """ 计算整体置信度 Args: treatment: 处理变量 outcome: 结果变量 controls: 协变量列表 Returns: float: 置信度评分 (0-1) """ if treatment is None or outcome is None: return 0.3 # 基础置信度 base_confidence = (treatment.confidence + outcome.confidence) / 2 # 协变量数量影响 if len(controls) > 0: avg_control_confidence = sum(c.confidence for c in controls) / len(controls) base_confidence = (base_confidence + avg_control_confidence) / 2 return min(base_confidence, 1.0) def _build_reasoning(self, question: str, treatment: Optional[VariableInfo], outcome: Optional[VariableInfo], controls: List[VariableInfo]) -> str: """ 构建推理理由 Args: question: 用户问题 treatment: 处理变量 outcome: 结果变量 controls: 协变量列表 Returns: str: 推理理由 """ reasons = [] if treatment: reasons.append(f"处理变量 '{treatment.name}' 被识别为干预措施") if outcome: reasons.append(f"结果变量 '{outcome.name}' 被识别为效应指标") if controls: control_names = [c.name for c in controls[:3]] if len(controls) > 3: control_names.append(f"等 {len(controls) - 3} 个其他变量") reasons.append(f"协变量 {', '.join(control_names)} 被识别为需要控制的混淆因素") return ";".join(reasons) if reasons else "未能充分识别因果推断要素" def _find_alternative_matches(self, metadata: FileMetadata, treatment: Optional[VariableInfo], outcome: Optional[VariableInfo]) -> List[Dict[str, Any]]: """ 查找其他可能的匹配 Args: metadata: 文件元数据 treatment: 处理变量 outcome: 结果变量 Returns: List[Dict[str, Any]]: 其他可能的匹配 """ alternatives = [] if metadata.columns is None: return [] # 查找其他可能的处理变量 if treatment: for col in metadata.columns: if col.name != treatment.name: alternatives.append({ "variable_type": "treatment", "name": col.name, "confidence": 0.3, "reasoning": "可能是替代的处理变量" }) # 查找其他可能的结果变量 if outcome: for col in metadata.columns: if col.name != outcome.name: alternatives.append({ "variable_type": "outcome", "name": col.name, "confidence": 0.3, "reasoning": "可能是替代的结果变量" }) return alternatives[:5] # 最多返回 5 个替代匹配 def _parse_llm_response(self, response: str, request: CausalInferenceRequest, metadata: FileMetadata) -> CausalInferenceResult: """ 解析 LLM 响应 Args: response: LLM 响应 request: 原始请求 metadata: 文件元数据 Returns: CausalInferenceResult: 解析结果 """ try: # 清理响应,提取 JSON 部分 response = response.strip() if response.startswith("```json"): response = response[7:] if response.startswith("```"): response = response[1:] if response.endswith("```"): response = response[:-3] response = response.strip() # 解析 JSON data = json.loads(response) # 构建结果对象 result = CausalInferenceResult( original_question=request.user_question, file_path=request.file_path, treatment_variable=VariableInfo(**data.get("treatment_variable")) if data.get("treatment_variable") else None, outcome_variable=VariableInfo(**data.get("outcome_variable")) if data.get("outcome_variable") else None, control_variables=[VariableInfo(**v) for v in data.get("control_variables", [])], confidence_score=data.get("confidence_score", 0.0), reasoning=data.get("reasoning", ""), alternative_matches=data.get("alternative_matches", []), metadata={ "parsed_at": datetime.now().isoformat(), "file_info": { "file_name": metadata.file_name, "total_rows": metadata.total_rows, "total_columns": metadata.total_columns } } ) return result except json.JSONDecodeError as e: # 如果 JSON 解析失败,返回错误结果 return CausalInferenceResult( original_question=request.user_question, file_path=request.file_path, confidence_score=0.0, reasoning=f"LLM 响应解析失败:{str(e)}", metadata={ "parsed_at": datetime.now().isoformat(), "error": str(e), "raw_response": response[:500] } ) def parse_simple(self, question: str, file_path: str) -> CausalInferenceResult: """ 简化的解析方法 Args: question: 用户问题 file_path: 文件路径 Returns: CausalInferenceResult: 解析结果 """ request = CausalInferenceRequest( user_question=question, file_path=file_path ) return self.parse(request) if __name__ == "__main__": """ 因果推断解析器模块的简单测试示例 """ from file_query_tool import FileQueryTool print("=" * 60) print("CausalParser 简单测试示例") print("=" * 60) print() # 创建解析器实例(不使用 LLM 客户端) print("步骤 1: 创建 CausalParser 实例...") parser = CausalParser() print(" - CausalParser 已初始化") print() # 准备示例数据文件 print("步骤 2: 准备示例数据文件...") sample_file = "data/output/test_simple_ate.csv" print(f" - 文件路径:{sample_file}") print() # 创建解析请求 print("步骤 3: 创建解析请求...") sample_question = "教育对收入的影响是什么?" request = CausalInferenceRequest( user_question=sample_question, file_path=sample_file ) print(f" - 问题:{sample_question}") print() # 执行解析 print("步骤 4: 执行解析...") try: result = parser.parse(request) print(" - 解析成功!") print() # 输出解析结果 print("步骤 5: 解析结果:") print("-" * 60) print(f"原始问题:{result.original_question}") print(f"文件路径:{result.file_path}") print(f"整体置信度:{result.confidence_score:.2%}") print() if result.treatment_variable: tv = result.treatment_variable print(f"处理变量 (T):") print(f" - 名称:{tv.name}") print(f" - 类型:{tv.variable_type}") print(f" - 置信度:{tv.confidence:.2%}") print(f" - 推理:{tv.reasoning}") print() if result.outcome_variable: ov = result.outcome_variable print(f"结果变量 (Y):") print(f" - 名称:{ov.name}") print(f" - 类型:{ov.variable_type}") print(f" - 置信度:{ov.confidence:.2%}") print(f" - 推理:{ov.reasoning}") print() if result.control_variables: print(f"协变量 (Control Variables):") for cv in result.control_variables: print(f" - {cv.name} (置信度:{cv.confidence:.2%})") print(f" 推理:{cv.reasoning}") print() print(f"推理理由:{result.reasoning}") print() print("=" * 60) print("完整 JSON 输出:") print("=" * 60) print(result.to_json()) except Exception as e: print(f" - 解析失败:{str(e)}") print() print("请检查:") print(" 1. 数据文件路径是否正确")