2026-03-29 23:47:20 +08:00

187 lines
6.2 KiB
Python

"""
Causal Inference Agent Orchestrator
整合完整 pipeline 的核心代理类。
"""
from typing import Any, Dict, Optional
import pandas as pd
from causal_agent.analysis.causal_graph import (
build_local_graph,
find_adjustment_set,
find_backdoor_paths,
)
from causal_agent.analysis.estimation import estimate_ate
from causal_agent.analysis.reporting import generate_report, print_report
from causal_agent.analysis.screening import local_screen
from causal_agent.analysis.variable_parser import VariableParser
from causal_agent.core.config import AgentConfig
from causal_agent.core.data_loader import DataLoader, get_data_info
from causal_agent.core.llm_client import LLMClient
from causal_agent.logger import CausalAnalysisLogger
class CausalInferenceAgent:
    """General-purpose causal inference agent.

    Wires LLM-based variable identification, correlation screening, local
    causal-graph construction, backdoor adjustment, effect estimation,
    reporting and logging into a single ``analyze`` call.
    """

    def __init__(self, config: Optional[AgentConfig] = None):
        """Create the agent and its collaborators.

        Args:
            config: Agent configuration. When ``None``, one is built with
                ``AgentConfig.from_env()``.
        """
        # Explicit None check: a caller-supplied config must never be swapped
        # for the environment config just because it happens to be falsy.
        self.config = config if config is not None else AgentConfig.from_env()
        self.data_loader = DataLoader()
        self.llm_client = LLMClient(
            base_url=self.config.llm_base_url,
            model=self.config.llm_model,
            temperature=self.config.llm_temperature,
            max_tokens=self.config.llm_max_tokens,
            api_key=self.config.llm_api_key,
            timeout=self.config.llm_timeout,
            max_retries=self.config.llm_max_retries,
        )
        self.variable_parser = VariableParser(self.llm_client)
        self.logger = CausalAnalysisLogger(self.config.log_path)

    @staticmethod
    def _check_roles(treatment: Any, outcome: Any, columns: list) -> None:
        """Validate the treatment/outcome chosen by the LLM against the data.

        Args:
            treatment: Treatment variable name returned by the parser.
            outcome: Outcome variable name returned by the parser.
            columns: Column names of the loaded data.

        Raises:
            ValueError: If either role is not one of ``columns``, or both
                roles name the same column.
        """
        missing = [
            f"{role}={name!r}"
            for role, name in (("treatment", treatment), ("outcome", outcome))
            if name not in columns
        ]
        if missing:
            raise ValueError(
                "LLM-selected variable(s) not found in data: "
                f"{', '.join(missing)}; available columns: {columns}"
            )
        if treatment == outcome:
            raise ValueError(
                f"Treatment and outcome must differ, got {treatment!r} for both"
            )

    def analyze(
        self, data_path: str, custom_prompt: Optional[str] = None
    ) -> Dict[str, Any]:
        """Run the full causal-analysis pipeline on one data file.

        Steps: load data -> LLM variable identification -> correlation
        screening -> local causal graph -> backdoor paths / adjustment set
        -> ATE estimation -> report generation -> analysis logging.
        Progress is printed to stdout at each step.

        Args:
            data_path: Path to the data file, loaded via ``DataLoader.load``.
            custom_prompt: Optional extra prompt text forwarded to the
                variable parser.

        Returns:
            A dict with keys ``"id"`` (the logger's ``analysis_count``
            zero-padded to 3 digits), ``"report"`` (the generated report) and
            ``"parameters"`` (the inputs and intermediate results that were
            logged).

        Raises:
            ValueError: If the LLM-selected treatment or outcome is not a
                column of the loaded data, or both name the same column.
        """
        df = self.data_loader.load(data_path)
        columns = list(df.columns)
        data_info = get_data_info(df)
        print(f"成功加载数据:{data_path}")
        print(f"数据形状:{df.shape}")
        print(f"列名:{columns}")

        # Step 1: LLM variable identification
        print("\n[Step 1] LLM 变量识别...")
        parsed = self.variable_parser.parse(columns, data_info, custom_prompt)
        T = parsed["treatment"]
        Y = parsed["outcome"]
        tiers = parsed["time_tiers"]
        print(f"处理变量:{T},结果变量:{Y}")
        print(f"时间层级:{tiers}")
        # The roles come from free-form LLM output. Fail fast with a clear
        # message rather than an opaque lookup error in a later step.
        self._check_roles(T, Y, columns)

        # Step 2: quick correlation screening
        print("\n[Step 2] 快速相关性筛查...")
        # Columns with a negative time tier are excluded (ID-like columns).
        excluded = [c for c in columns if tiers.get(c, 0) < 0]
        candidates = local_screen(
            df,
            T=T,
            Y=Y,
            excluded=excluded,
            corr_threshold=self.config.corr_threshold,
            alpha=self.config.alpha,
        )
        print(f"候选混杂变量:{[c['var'] for c in candidates]}")

        # Step 3: causal graph construction
        print("\n[Step 3] 因果图构建...")
        G = build_local_graph(
            df, T, Y, candidates, tiers, corr_threshold=self.config.corr_threshold
        )
        # Flatten the graph into plain lists/dicts for printing and logging.
        edges = [
            {"from": u, "to": v, "type": data.get("type", "unknown")}
            for u, v, data in G.edges(data=True)
        ]
        nodes = list(G.nodes())
        print(f"图节点:{nodes}")
        print(f"图边:{edges}")

        # Step 4: confounder identification via backdoor paths
        print("\n[Step 4] 后门路径识别...")
        backdoor_paths = find_backdoor_paths(G, T, Y)
        print(f"后门路径:{backdoor_paths}")
        adjustment_set, reasoning = find_adjustment_set(G, T, Y, backdoor_paths)
        print(f"调整集:{adjustment_set}")
        print(f"调整理由:{reasoning}")

        # Step 5: effect estimation, with a bootstrap confidence interval
        print("\n[Step 5] 效应估计...")
        estimation_result = estimate_ate(
            df,
            T,
            Y,
            adjustment_set,
            compute_ci=True,
            n_bootstrap=self.config.bootstrap_iterations,
            alpha=self.config.bootstrap_alpha,
        )
        print(f"ATE (OR): {estimation_result['ATE_Outcome_Regression']}")
        print(f"ATE (IPW): {estimation_result['ATE_IPW']}")
        # NOTE(review): the "95%" label and result key are fixed, while alpha
        # comes from config.bootstrap_alpha. Confirm they agree in estimate_ate.
        print(f"95% CI: {estimation_result['95%_CI']}")

        # Step 6: report generation
        print("\n[Step 6] 生成报告...")
        causal_graph_info = {
            "nodes": nodes,
            "edges": edges,
            "backdoor_paths": backdoor_paths,
        }
        identification_info = {
            "strategy": "Backdoor Adjustment",
            "adjustment_set": adjustment_set,
            "reasoning": reasoning,
        }
        query_interpretation = {
            "treatment": T,
            "outcome": Y,
            "estimand": "ATE",
        }
        report = generate_report(
            query_interpretation=query_interpretation,
            causal_graph=causal_graph_info,
            identification=identification_info,
            estimation=estimation_result,
        )
        print_report(report)

        # Record the run.
        # llm_params lists only base_url/model/temperature/max_tokens; the API
        # key is not included. "warnings" and "balance_check" are dropped from
        # the logged estimation results.
        parameters = {
            "data_path": data_path,
            "sample_size": len(df),
            "variables": columns,
            "treatment_variable": T,
            "outcome_variable": Y,
            "time_tiers": tiers,
            "llm_params": {
                "base_url": self.config.llm_base_url,
                "model": self.config.llm_model,
                "temperature": self.config.llm_temperature,
                "max_tokens": self.config.llm_max_tokens,
            },
            "candidates": candidates,
            "causal_graph": causal_graph_info,
            "identification": identification_info,
            "estimation": {
                k: v
                for k, v in estimation_result.items()
                if k not in ("warnings", "balance_check")
            },
            "log_path": self.logger.log_path,
        }
        # NOTE(review): the user prompt is rebuilt through the parser's private
        # _build_user_prompt, and the parsed result (not raw LLM text) is
        # logged as str(parsed). Presumably this mirrors what parse() sent and
        # received; confirm against VariableParser.
        self.logger.log_analysis(
            VariableParser.SYSTEM_PROMPT,
            self.variable_parser._build_user_prompt(columns, data_info, custom_prompt),
            str(parsed),
            parameters,
            report,
        )
        # NOTE(review): assumes log_analysis has already advanced
        # analysis_count for this run. Verify in CausalAnalysisLogger.
        return {
            "id": f"{self.logger.analysis_count:03d}",
            "report": report,
            "parameters": parameters,
        }