""" Causal Inference Agent Orchestrator 整合完整 pipeline 的核心代理类。 """ from typing import Any, Dict, Optional import pandas as pd from causal_agent.analysis.causal_graph import ( build_local_graph, find_adjustment_set, find_backdoor_paths, ) from causal_agent.analysis.estimation import estimate_ate from causal_agent.analysis.reporting import generate_report, print_report from causal_agent.analysis.screening import local_screen from causal_agent.analysis.variable_parser import VariableParser from causal_agent.core.config import AgentConfig from causal_agent.core.data_loader import DataLoader, get_data_info from causal_agent.core.llm_client import LLMClient from causal_agent.logger import CausalAnalysisLogger class CausalInferenceAgent: """通用因果推断 Agent""" def __init__(self, config: Optional[AgentConfig] = None): self.config = config or AgentConfig.from_env() self.data_loader = DataLoader() self.llm_client = LLMClient( base_url=self.config.llm_base_url, model=self.config.llm_model, temperature=self.config.llm_temperature, max_tokens=self.config.llm_max_tokens, api_key=self.config.llm_api_key, timeout=self.config.llm_timeout, max_retries=self.config.llm_max_retries, ) self.variable_parser = VariableParser(self.llm_client) self.logger = CausalAnalysisLogger(self.config.log_path) def analyze( self, data_path: str, custom_prompt: Optional[str] = None ) -> Dict[str, Any]: """ 执行完整的因果分析 pipeline。 Args: data_path: 数据文件路径 custom_prompt: 可选的自定义提示词补充 Returns: 包含完整报告和中间结果的字典 """ df = self.data_loader.load(data_path) columns = list(df.columns) data_info = get_data_info(df) print(f"成功加载数据:{data_path}") print(f"数据形状:{df.shape}") print(f"列名:{columns}") # Step 1: LLM 变量识别 print("\n[Step 1] LLM 变量识别...") parsed = self.variable_parser.parse(columns, data_info, custom_prompt) T = parsed["treatment"] Y = parsed["outcome"] tiers = parsed["time_tiers"] print(f"处理变量:{T},结果变量:{Y}") print(f"时间层级:{tiers}") # Step 2: 快速相关性筛查 print("\n[Step 2] 快速相关性筛查...") candidates = local_screen( df, T=T, Y=Y, excluded=[c for c in columns if tiers.get(c, 0) < 0], # 排除 id 类 corr_threshold=self.config.corr_threshold, alpha=self.config.alpha, ) print(f"候选混杂变量:{[c['var'] for c in candidates]}") # Step 3: 因果图构建 print("\n[Step 3] 因果图构建...") G = build_local_graph( df, T, Y, candidates, tiers, corr_threshold=self.config.corr_threshold ) edges = [ {"from": u, "to": v, "type": data.get("type", "unknown")} for u, v, data in G.edges(data=True) ] nodes = list(G.nodes()) print(f"图节点:{nodes}") print(f"图边:{edges}") # Step 4: 混杂识别 print("\n[Step 4] 后门路径识别...") backdoor_paths = find_backdoor_paths(G, T, Y) print(f"后门路径:{backdoor_paths}") adjustment_set, reasoning = find_adjustment_set(G, T, Y, backdoor_paths) print(f"调整集:{adjustment_set}") print(f"调整理由:{reasoning}") # Step 5: 效应估计 print("\n[Step 5] 效应估计...") estimation_result = estimate_ate( df, T, Y, adjustment_set, compute_ci=True, n_bootstrap=self.config.bootstrap_iterations, alpha=self.config.bootstrap_alpha, ) print(f"ATE (OR): {estimation_result['ATE_Outcome_Regression']}") print(f"ATE (IPW): {estimation_result['ATE_IPW']}") print(f"95% CI: {estimation_result['95%_CI']}") # Step 6: 生成报告 print("\n[Step 6] 生成报告...") causal_graph_info = { "nodes": nodes, "edges": edges, "backdoor_paths": backdoor_paths, } identification_info = { "strategy": "Backdoor Adjustment", "adjustment_set": adjustment_set, "reasoning": reasoning, } query_interpretation = { "treatment": T, "outcome": Y, "estimand": "ATE", } report = generate_report( query_interpretation=query_interpretation, causal_graph=causal_graph_info, identification=identification_info, estimation=estimation_result, ) print_report(report) # 记录日志 parameters = { "data_path": data_path, "sample_size": len(df), "variables": columns, "treatment_variable": T, "outcome_variable": Y, "time_tiers": tiers, "llm_params": { "base_url": self.config.llm_base_url, "model": self.config.llm_model, "temperature": self.config.llm_temperature, "max_tokens": self.config.llm_max_tokens, }, "candidates": candidates, "causal_graph": causal_graph_info, "identification": identification_info, "estimation": { k: v for k, v in estimation_result.items() if k not in ("warnings", "balance_check") }, "log_path": self.logger.log_path, } self.logger.log_analysis( VariableParser.SYSTEM_PROMPT, self.variable_parser._build_user_prompt(columns, data_info, custom_prompt), str(parsed), parameters, report, ) return { "id": f"{self.logger.analysis_count:03d}", "report": report, "parameters": parameters, }