2026-03-29 23:47:20 +08:00

135 lines
3.8 KiB
Python

"""
因果图构建与后门准则模块
功能:
1. 基于时间层级和统计相关性构建混合智能因果图
2. 自动发现后门路径
3. 基于后门准则寻找最小调整集
"""
from typing import Any, Dict, List, Optional, Tuple
import networkx as nx
import pandas as pd
from scipy.stats import spearmanr
def build_local_graph(
    df: pd.DataFrame,
    T: str,
    Y: str,
    candidates: List[Dict[str, Any]],
    tiers: Dict[str, int],
    corr_threshold: float = 0.1,
) -> nx.DiGraph:
    """Build a local causal graph from temporal tiers and rank correlations.

    A directed edge ``u -> v`` is added when ``u`` lies in a strictly earlier
    tier than ``v`` and the absolute Spearman correlation between their
    columns exceeds ``corr_threshold``. The edge ``T -> Y`` is always added
    (or overwritten) last and tagged as the hypothesized effect.

    Args:
        df: Data frame with one column per variable that gets compared.
        T: Name of the treatment variable.
        Y: Name of the outcome variable.
        candidates: Candidate confounders; each dict supplies its column
            name under the ``"var"`` key.
        tiers: Maps a variable name to its temporal tier (smaller tiers
            point at larger ones). Names missing from the mapping are
            treated as tier ``-1``.
        corr_threshold: Absolute correlation an edge must strictly exceed.

    Returns:
        A ``networkx.DiGraph``. Correlation edges carry ``type`` equal to
        ``"confounding"`` (edge points into ``T`` or ``Y``) or ``"temporal"``
        and ``confidence=1.0``; the ``T -> Y`` edge carries
        ``type="hypothesized"`` and ``confidence="research_question"``.

    Raises:
        KeyError: If a candidate dict has no ``"var"`` key, or a variable
            that gets compared has no column in ``df``.
    """
    G = nx.DiGraph()

    # De-duplicate while keeping first-seen order: a candidate that repeats
    # T or Y would otherwise make the same pairs be correlated twice.
    relevant_vars = list(dict.fromkeys([T, Y] + [c["var"] for c in candidates]))
    G.add_nodes_from(relevant_vars)

    # Tier lookups do not depend on the loop position, so resolve them once.
    tier_of = {var: tiers.get(var, -1) for var in relevant_vars}
    endpoints = (T, Y)

    for u in relevant_vars:
        for v in relevant_vars:
            # Only earlier -> later pairs can form an edge; this also skips
            # u == v because a variable never precedes itself.
            if tier_of[u] >= tier_of[v]:
                continue

            # Pairwise deletion: spearmanr propagates NaN by default, so a
            # single missing value would otherwise suppress the edge.
            pair = df[[u, v]].dropna()
            if len(pair) < 2:
                continue

            corr, _ = spearmanr(pair[u], pair[v])
            # A NaN correlation (e.g. constant column) compares False here,
            # so no edge is drawn for it.
            if abs(corr) > corr_threshold:
                edge_type = "confounding" if v in endpoints else "temporal"
                G.add_edge(u, v, type=edge_type, confidence=1.0)

    G.add_edge(T, Y, type="hypothesized", confidence="research_question")
    return G
def find_backdoor_paths(G: nx.DiGraph, T: str, Y: str) -> List[str]:
    """Return the formatted back-door paths between ``T`` and ``Y``.

    Simple paths from ``T`` to ``Y`` are enumerated on the undirected copy of
    ``G`` with a search cutoff of 5. A path is kept when ``G`` has an edge
    from the path's second node into ``T``. Each kept path is rendered with
    ``_format_path`` and duplicates are dropped, keeping first-seen order.

    Args:
        G: Directed causal graph.
        T: Treatment node.
        Y: Outcome node.

    Returns:
        Unique formatted path strings; empty when networkx raises
        ``NodeNotFound`` for the requested endpoints.
    """
    skeleton = G.to_undirected()
    try:
        # Materialise inside the try so a lazily raised NodeNotFound is caught.
        candidate_paths = list(
            nx.all_simple_paths(skeleton, source=T, target=Y, cutoff=5)
        )
    except nx.NodeNotFound:
        candidate_paths = []

    rendered = [
        _format_path(G, nodes)
        for nodes in candidate_paths
        if len(nodes) >= 2 and G.has_edge(nodes[1], T)
    ]
    # dict keys preserve insertion order, giving an order-stable de-duplication.
    return list(dict.fromkeys(rendered))
def _format_path(G: nx.DiGraph, path: List[str]) -> str:
    """Render a node path as a string with a direction arrow per step.

    For consecutive nodes ``a`` and ``b`` the arrow is ``<->`` when ``G`` has
    both ``a -> b`` and ``b -> a``, ``->`` when only ``a -> b`` exists, ``<-``
    when only ``b -> a`` exists, and ``--`` when neither exists.

    Args:
        G: Directed graph consulted for edge directions.
        path: Non-empty sequence of node names.

    Returns:
        The nodes joined by spaces with an arrow before every node but the
        first, e.g. ``"T <- X -> Y"``.
    """
    # Keyed on (has forward edge, has backward edge).
    arrow_for = {
        (True, True): "<->",
        (True, False): "->",
        (False, True): "<-",
        (False, False): "--",
    }
    rendered = [path[0]]
    for prev, nxt in zip(path, path[1:]):
        direction = (bool(G.has_edge(prev, nxt)), bool(G.has_edge(nxt, prev)))
        rendered.append(f"{arrow_for[direction]} {nxt}")
    return " ".join(rendered)
def find_adjustment_set(
    G: nx.DiGraph, T: str, Y: str, backdoor_paths: Optional[List[str]] = None
) -> Tuple[List[str], str]:
    """Derive an adjustment set from the back-door paths between T and Y.

    The set consists of the second node of every back-door path, i.e. the
    node written right after the starting node in the formatted string.

    Args:
        G: Directed causal graph; only consulted when ``backdoor_paths`` is
            ``None``.
        T: Treatment node.
        Y: Outcome node.
        backdoor_paths: Pre-computed formatted path strings. When ``None``
            they are obtained from ``find_backdoor_paths(G, T, Y)``.

    Returns:
        A tuple ``(adjustment_set, reasoning)``: the sorted list of variables
        to control for, and a human-readable explanation. The list is empty
        when there are no back-door paths.
    """
    backdoor_paths = (
        find_backdoor_paths(G, T, Y) if backdoor_paths is None else backdoor_paths
    )
    if not backdoor_paths:
        return [], "未发现后门路径,无需额外调整变量即可识别因果效应。"

    parsed_paths = (_parse_path_nodes(path_str) for path_str in backdoor_paths)
    # Index 1 is the node directly after the path's starting node.
    safe_adjustment_set = sorted(
        {nodes[1] for nodes in parsed_paths if len(nodes) >= 2}
    )

    summary = f"发现 {len(backdoor_paths)} 条后门路径。"
    conclusion = f"通过控制变量 {safe_adjustment_set} 可阻断所有后门路径,满足后门准则。"
    return safe_adjustment_set, summary + conclusion
def _parse_path_nodes(path_str: str) -> List[str]:
"""从格式化路径字符串中提取节点列表"""
cleaned = (
path_str.replace("->", " ")
.replace("<-", " ")
.replace("<->", " ")
.replace("--", " ")
)
nodes = [n.strip() for n in cleaned.split() if n.strip()]
return nodes