"""Report generation module.

Combines the results of each causal-analysis stage into a single, unified
JSON report.
"""

import json
from typing import Any, Dict, List, Optional


def generate_report(
    query_interpretation: Dict[str, str],
    causal_graph: Dict[str, Any],
    identification: Dict[str, Any],
    estimation: Dict[str, Any],
    extra_warnings: Optional[List[Dict[str, str]]] = None,
) -> Dict[str, Any]:
    """Build the standardized report dictionary.

    Args:
        query_interpretation: Interpretation of the user query; embedded as-is.
        causal_graph: Causal graph description; embedded as-is.
        identification: Identification-stage output. Reads the keys
            ``strategy`` (default ``"Backdoor Adjustment"``),
            ``adjustment_set`` (default ``[]``) and ``reasoning``
            (default ``""``).
        estimation: Estimation-stage output. Reads the effect estimates
            (``ATE_Outcome_Regression``, ``ATE_IPW``, ``95%_CI``,
            ``interpretation``), the diagnostics (``balance_check``,
            ``overlap_assumption``, ``robustness``) and an optional
            ``warnings`` list.
        extra_warnings: Additional warnings to place first in the report.
            The caller's list is copied, never mutated.

    Returns:
        A dict with the sections ``query_interpretation``, ``causal_graph``,
        ``identification``, ``estimation``, ``diagnostics`` and ``warnings``.
        Missing estimation values come out as ``None``.
    """
    warnings: List[Dict[str, str]] = list(extra_warnings) if extra_warnings else []
    # The estimation stage may report ``"warnings": None``; treat any falsy
    # value as "no warnings" instead of crashing in ``extend``.
    warnings.extend(estimation.get("warnings") or [])

    # Always flag possible unobserved confounding, but only once: an upstream
    # stage (or the caller) may already have emitted this warning type.
    already_flagged = any(
        isinstance(w, dict) and w.get("type") == "unobserved_confounding"
        for w in warnings
    )
    if not already_flagged:
        warnings.append(
            {
                "type": "unobserved_confounding",
                "message": "可能存在未观测混杂,建议进行敏感性分析。",
            }
        )

    report = {
        "query_interpretation": query_interpretation,
        "causal_graph": causal_graph,
        "identification": {
            "strategy": identification.get("strategy", "Backdoor Adjustment"),
            "adjustment_set": identification.get("adjustment_set", []),
            "reasoning": identification.get("reasoning", ""),
        },
        "estimation": {
            "ATE_Outcome_Regression": estimation.get("ATE_Outcome_Regression"),
            "ATE_IPW": estimation.get("ATE_IPW"),
            "95%_CI": estimation.get("95%_CI"),
            "interpretation": estimation.get("interpretation", ""),
        },
        "diagnostics": {
            "balance_check": estimation.get("balance_check", {}),
            "overlap_assumption": estimation.get("overlap_assumption", "未知"),
            "robustness": estimation.get("robustness", "未知"),
        },
        "warnings": warnings,
    }
    return report


def print_report(report: Dict[str, Any], width: int = 60) -> None:
    """Pretty-print a report as indented JSON between separator lines.

    Args:
        report: The report to print. It must be JSON-serializable.
        width: Length of the ``=`` separator lines (default 60).
    """
    separator = "=" * width
    print("\n" + separator)
    print("因果推断分析报告")
    print(separator)
    # ensure_ascii=False keeps the Chinese text readable instead of \u-escaped.
    print(json.dumps(report, indent=2, ensure_ascii=False))
    print(separator)