2026-03-29 23:47:20 +08:00

60 lines
1.9 KiB
Python

"""
报告生成模块
将因果分析各阶段结果整合为统一的 JSON 报告。
"""
import json
from typing import Any, Dict, List, Optional
def generate_report(
    query_interpretation: Dict[str, str],
    causal_graph: Dict[str, Any],
    identification: Dict[str, Any],
    estimation: Dict[str, Any],
    extra_warnings: Optional[List[Dict[str, str]]] = None,
) -> Dict[str, Any]:
    """Combine the results of every analysis stage into one standardized report dict.

    Args:
        query_interpretation: Interpretation of the query; stored in the report as-is.
        causal_graph: Causal graph description; stored in the report as-is.
        identification: Identification-stage result. Reads ``strategy``
            (default ``"Backdoor Adjustment"``), ``adjustment_set`` (default
            ``[]``) and ``reasoning`` (default ``""``).
        estimation: Estimation-stage result. Reads ``ATE_Outcome_Regression``,
            ``ATE_IPW``, ``95%_CI`` (each ``None`` when absent),
            ``interpretation``, ``balance_check``, ``overlap_assumption``,
            ``robustness`` and an optional ``warnings`` list (may be missing
            or ``None``).
        extra_warnings: Additional warnings supplied by the caller; they are
            placed first in the report's warning list. The caller's list is
            not modified.

    Returns:
        A dict with keys ``query_interpretation``, ``causal_graph``,
        ``identification``, ``estimation``, ``diagnostics`` and ``warnings``.
        ``warnings`` always contains exactly one or more entries, including at
        least one of type ``"unobserved_confounding"``.
    """
    # Copy so the caller's list is never mutated by the extend/append below.
    collected: List[Dict[str, str]] = list(extra_warnings) if extra_warnings else []
    # ``or []`` also covers an explicit ``"warnings": None`` from the estimator,
    # which would otherwise make ``list.extend`` raise TypeError.
    collected.extend(estimation.get("warnings") or [])
    # A reminder about possible unobserved confounding is always reported, but
    # only once: skip it if an earlier stage already emitted that warning type.
    # The isinstance guard tolerates non-dict entries from upstream.
    already_flagged = any(
        isinstance(w, dict) and w.get("type") == "unobserved_confounding"
        for w in collected
    )
    if not already_flagged:
        collected.append(
            {
                "type": "unobserved_confounding",
                "message": "可能存在未观测混杂,建议进行敏感性分析。",
            }
        )
    report = {
        "query_interpretation": query_interpretation,
        "causal_graph": causal_graph,
        "identification": {
            "strategy": identification.get("strategy", "Backdoor Adjustment"),
            "adjustment_set": identification.get("adjustment_set", []),
            "reasoning": identification.get("reasoning", ""),
        },
        "estimation": {
            "ATE_Outcome_Regression": estimation.get("ATE_Outcome_Regression"),
            "ATE_IPW": estimation.get("ATE_IPW"),
            "95%_CI": estimation.get("95%_CI"),
            "interpretation": estimation.get("interpretation", ""),
        },
        "diagnostics": {
            "balance_check": estimation.get("balance_check", {}),
            "overlap_assumption": estimation.get("overlap_assumption", "未知"),
            "robustness": estimation.get("robustness", "未知"),
        },
        "warnings": collected,
    }
    return report
def print_report(report: Dict[str, Any]) -> None:
    """Print *report* as indented JSON framed by a title banner.

    Non-ASCII text (e.g. Chinese) is emitted verbatim rather than escaped.
    """
    rule = "=" * 60
    # Banner first: a blank line, rule, title, rule.
    print("", rule, "因果推断分析报告", rule, sep="\n")
    # Serialize only after the banner is out, then close with a final rule.
    print(json.dumps(report, indent=2, ensure_ascii=False), rule, sep="\n")