187 lines
6.2 KiB
Python
187 lines
6.2 KiB
Python
"""
|
|
Causal Inference Agent Orchestrator
|
|
|
|
整合完整 pipeline 的核心代理类。
|
|
"""
|
|
|
|
from typing import Any, Dict, Optional
|
|
|
|
import pandas as pd
|
|
|
|
from causal_agent.analysis.causal_graph import (
|
|
build_local_graph,
|
|
find_adjustment_set,
|
|
find_backdoor_paths,
|
|
)
|
|
from causal_agent.analysis.estimation import estimate_ate
|
|
from causal_agent.analysis.reporting import generate_report, print_report
|
|
from causal_agent.analysis.screening import local_screen
|
|
from causal_agent.analysis.variable_parser import VariableParser
|
|
from causal_agent.core.config import AgentConfig
|
|
from causal_agent.core.data_loader import DataLoader, get_data_info
|
|
from causal_agent.core.llm_client import LLMClient
|
|
from causal_agent.logger import CausalAnalysisLogger
|
|
|
|
|
|
class CausalInferenceAgent:
    """General-purpose causal inference agent.

    Orchestrates the full pipeline: data loading, LLM-based variable
    identification, correlation screening, causal-graph construction,
    backdoor identification, ATE estimation, and report generation.
    """

    def __init__(self, config: Optional[AgentConfig] = None):
        """Initialize the agent and its collaborators.

        Args:
            config: Optional agent configuration. When omitted, settings
                are loaded from the environment via ``AgentConfig.from_env()``.
        """
        self.config = config or AgentConfig.from_env()
        self.data_loader = DataLoader()
        self.llm_client = LLMClient(
            base_url=self.config.llm_base_url,
            model=self.config.llm_model,
            temperature=self.config.llm_temperature,
            max_tokens=self.config.llm_max_tokens,
            api_key=self.config.llm_api_key,
            timeout=self.config.llm_timeout,
            max_retries=self.config.llm_max_retries,
        )
        self.variable_parser = VariableParser(self.llm_client)
        self.logger = CausalAnalysisLogger(self.config.log_path)

    def analyze(
        self, data_path: str, custom_prompt: Optional[str] = None
    ) -> Dict[str, Any]:
        """Run the complete causal analysis pipeline.

        Args:
            data_path: Path to the data file to analyze.
            custom_prompt: Optional extra text appended to the
                variable-identification prompt.

        Returns:
            A dict with the analysis id, the full report, and all
            intermediate parameters.
        """
        df = self.data_loader.load(data_path)
        columns = list(df.columns)
        data_info = get_data_info(df)

        print(f"成功加载数据:{data_path}")
        print(f"数据形状:{df.shape}")
        print(f"列名:{columns}")

        # Steps 1-5: each pipeline stage lives in its own helper so the
        # orchestration reads as a linear summary of the pipeline.
        parsed, T, Y, tiers = self._identify_variables(
            columns, data_info, custom_prompt
        )
        candidates = self._screen_candidates(df, columns, T, Y, tiers)
        nodes, edges = self._build_graph_info(df, T, Y, candidates, tiers)
        # Rebuild the graph object for identification; build_local_graph is
        # deterministic for fixed inputs, so this matches the graph above.
        G = build_local_graph(
            df, T, Y, candidates, tiers, corr_threshold=self.config.corr_threshold
        )
        backdoor_paths, adjustment_set, reasoning = self._identify_confounding(
            G, T, Y
        )
        estimation_result = self._estimate_effect(df, T, Y, adjustment_set)

        # Step 6: assemble and print the final report.
        print("\n[Step 6] 生成报告...")
        causal_graph_info = {
            "nodes": nodes,
            "edges": edges,
            "backdoor_paths": backdoor_paths,
        }
        identification_info = {
            "strategy": "Backdoor Adjustment",
            "adjustment_set": adjustment_set,
            "reasoning": reasoning,
        }
        query_interpretation = {
            "treatment": T,
            "outcome": Y,
            "estimand": "ATE",
        }
        report = generate_report(
            query_interpretation=query_interpretation,
            causal_graph=causal_graph_info,
            identification=identification_info,
            estimation=estimation_result,
        )
        print_report(report)

        # Persist a structured record of this run.
        parameters = {
            "data_path": data_path,
            "sample_size": len(df),
            "variables": columns,
            "treatment_variable": T,
            "outcome_variable": Y,
            "time_tiers": tiers,
            "llm_params": {
                "base_url": self.config.llm_base_url,
                "model": self.config.llm_model,
                "temperature": self.config.llm_temperature,
                "max_tokens": self.config.llm_max_tokens,
            },
            "candidates": candidates,
            "causal_graph": causal_graph_info,
            "identification": identification_info,
            # Drop the bulky diagnostic fields from the logged estimation.
            "estimation": {
                k: v
                for k, v in estimation_result.items()
                if k not in ("warnings", "balance_check")
            },
            "log_path": self.logger.log_path,
        }
        self.logger.log_analysis(
            VariableParser.SYSTEM_PROMPT,
            self.variable_parser._build_user_prompt(columns, data_info, custom_prompt),
            str(parsed),
            parameters,
            report,
        )

        return {
            "id": f"{self.logger.analysis_count:03d}",
            "report": report,
            "parameters": parameters,
        }

    def _identify_variables(self, columns, data_info, custom_prompt):
        """Step 1: ask the LLM to identify treatment, outcome, and time tiers."""
        print("\n[Step 1] LLM 变量识别...")
        parsed = self.variable_parser.parse(columns, data_info, custom_prompt)
        T = parsed["treatment"]
        Y = parsed["outcome"]
        tiers = parsed["time_tiers"]
        print(f"处理变量:{T},结果变量:{Y}")
        print(f"时间层级:{tiers}")
        return parsed, T, Y, tiers

    def _screen_candidates(self, df, columns, T, Y, tiers):
        """Step 2: quick correlation screen for candidate confounders."""
        print("\n[Step 2] 快速相关性筛查...")
        candidates = local_screen(
            df,
            T=T,
            Y=Y,
            # Negative tiers mark id-like columns that must not enter screening.
            excluded=[c for c in columns if tiers.get(c, 0) < 0],
            corr_threshold=self.config.corr_threshold,
            alpha=self.config.alpha,
        )
        print(f"候选混杂变量:{[c['var'] for c in candidates]}")
        return candidates

    def _build_graph_info(self, df, T, Y, candidates, tiers):
        """Step 3: build the local causal graph and summarize its structure."""
        print("\n[Step 3] 因果图构建...")
        G = build_local_graph(
            df, T, Y, candidates, tiers, corr_threshold=self.config.corr_threshold
        )
        edges = [
            {"from": u, "to": v, "type": data.get("type", "unknown")}
            for u, v, data in G.edges(data=True)
        ]
        nodes = list(G.nodes())
        print(f"图节点:{nodes}")
        print(f"图边:{edges}")
        return nodes, edges

    def _identify_confounding(self, G, T, Y):
        """Step 4: find backdoor paths and a valid adjustment set."""
        print("\n[Step 4] 后门路径识别...")
        backdoor_paths = find_backdoor_paths(G, T, Y)
        print(f"后门路径:{backdoor_paths}")
        adjustment_set, reasoning = find_adjustment_set(G, T, Y, backdoor_paths)
        print(f"调整集:{adjustment_set}")
        print(f"调整理由:{reasoning}")
        return backdoor_paths, adjustment_set, reasoning

    def _estimate_effect(self, df, T, Y, adjustment_set):
        """Step 5: estimate the ATE with bootstrap confidence intervals."""
        print("\n[Step 5] 效应估计...")
        estimation_result = estimate_ate(
            df,
            T,
            Y,
            adjustment_set,
            compute_ci=True,
            n_bootstrap=self.config.bootstrap_iterations,
            alpha=self.config.bootstrap_alpha,
        )
        print(f"ATE (OR): {estimation_result['ATE_Outcome_Regression']}")
        print(f"ATE (IPW): {estimation_result['ATE_IPW']}")
        print(f"95% CI: {estimation_result['95%_CI']}")
        return estimation_result