60 lines
1.9 KiB
Python
60 lines
1.9 KiB
Python
"""
|
|
报告生成模块
|
|
|
|
将因果分析各阶段结果整合为统一的 JSON 报告。
|
|
"""
|
|
|
|
import json
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
|
|
def generate_report(
    query_interpretation: Dict[str, str],
    causal_graph: Dict[str, Any],
    identification: Dict[str, Any],
    estimation: Dict[str, Any],
    extra_warnings: Optional[List[Dict[str, str]]] = None,
) -> Dict[str, Any]:
    """Assemble the per-stage analysis results into one report dictionary.

    Args:
        query_interpretation: Stored unchanged under "query_interpretation".
        causal_graph: Stored unchanged under "causal_graph".
        identification: Source of "strategy", "adjustment_set" and
            "reasoning"; missing keys fall back to "Backdoor Adjustment",
            an empty list and an empty string respectively.
        estimation: Source of the "estimation" and "diagnostics" fields and
            of an optional "warnings" list. Missing estimate values become
            None; missing diagnostics fall back to defaults.
        extra_warnings: Optional warnings placed ahead of the ones taken
            from ``estimation``. The caller's list is copied, not mutated.

    Returns:
        A dict with the keys "query_interpretation", "causal_graph",
        "identification", "estimation", "diagnostics" and "warnings".
    """
    warnings: List[Dict[str, str]] = list(extra_warnings) if extra_warnings else []

    # ``or []`` also covers an explicit ``"warnings": None`` entry; a plain
    # ``.get(key, [])`` default would hand None to extend() and raise TypeError.
    warnings.extend(estimation.get("warnings") or [])

    # This caveat is appended to every report, after all other warnings.
    warnings.append(
        {
            "type": "unobserved_confounding",
            "message": "可能存在未观测混杂,建议进行敏感性分析。",
        }
    )

    report = {
        "query_interpretation": query_interpretation,
        "causal_graph": causal_graph,
        "identification": {
            "strategy": identification.get("strategy", "Backdoor Adjustment"),
            "adjustment_set": identification.get("adjustment_set", []),
            "reasoning": identification.get("reasoning", ""),
        },
        "estimation": {
            "ATE_Outcome_Regression": estimation.get("ATE_Outcome_Regression"),
            "ATE_IPW": estimation.get("ATE_IPW"),
            "95%_CI": estimation.get("95%_CI"),
            "interpretation": estimation.get("interpretation", ""),
        },
        "diagnostics": {
            "balance_check": estimation.get("balance_check", {}),
            "overlap_assumption": estimation.get("overlap_assumption", "未知"),
            "robustness": estimation.get("robustness", "未知"),
        },
        "warnings": warnings,
    }
    return report
|
|
|
|
|
|
def print_report(report: Dict[str, Any]) -> None:
    """Write the report to stdout as indented JSON framed by a banner.

    Output is a blank line, a 60-character ``=`` rule, the title, another
    rule, the JSON dump (2-space indent, non-ASCII characters kept as-is),
    and a closing rule.
    """
    rule = "=" * 60
    # Header is emitted before serialization, so it appears even if
    # json.dumps raises on an unserializable report.
    for header_line in ("\n" + rule, "因果推断分析报告", rule):
        print(header_line)
    print(json.dumps(report, indent=2, ensure_ascii=False))
    print(rule)
|