135 lines
3.8 KiB
Python
135 lines
3.8 KiB
Python
"""
|
|
因果图构建与后门准则模块
|
|
|
|
功能:
|
|
1. 基于时间层级和统计相关性构建混合智能因果图
|
|
2. 自动发现后门路径
|
|
3. 基于后门准则寻找最小调整集
|
|
"""
|
|
|
|
from typing import Any, Dict, List, Optional, Tuple
|
|
|
|
import networkx as nx
|
|
import pandas as pd
|
|
from scipy.stats import spearmanr
|
|
|
|
|
|
def build_local_graph(
    df: pd.DataFrame,
    T: str,
    Y: str,
    candidates: List[Dict[str, Any]],
    tiers: Dict[str, int],
    corr_threshold: float = 0.1,
) -> nx.DiGraph:
    """Build a local causal graph over the treatment, outcome and candidates.

    Nodes are ``T``, ``Y`` and the ``"var"`` of every candidate. A directed
    edge ``u -> v`` is added when ``u`` lies in a strictly earlier time tier
    than ``v`` and the absolute Spearman correlation between the two columns
    exceeds ``corr_threshold``. Correlation edges ending at ``T`` or ``Y`` are
    tagged ``type="confounding"``, all others ``type="temporal"`` (both with
    ``confidence=1.0``). Finally the research hypothesis ``T -> Y`` is added
    with ``type="hypothesized"`` and ``confidence="research_question"``.

    Args:
        df: Input data; must contain a column for every node.
        T: Treatment variable (column name).
        Y: Outcome variable (column name).
        candidates: Candidate confounders; each dict must have a ``"var"`` key.
        tiers: Mapping from variable to time tier. Variables missing from the
            mapping get tier -1, i.e. they are treated as the earliest tier.
        corr_threshold: Minimum absolute Spearman correlation for an edge.

    Returns:
        A ``networkx.DiGraph`` whose edges carry ``type`` and ``confidence``
        attributes.
    """
    G = nx.DiGraph()
    # Deduplicate (keeping first-occurrence order) so a candidate that repeats
    # T, Y or another candidate does not trigger redundant correlations.
    relevant_vars = list(dict.fromkeys([T, Y] + [c["var"] for c in candidates]))
    G.add_nodes_from(relevant_vars)

    tier_of = {var: tiers.get(var, -1) for var in relevant_vars}

    for u in relevant_vars:
        for v in relevant_vars:
            if u == v or tier_of[u] >= tier_of[v]:
                continue
            # Never let the data add an outcome -> treatment edge: combined
            # with the hypothesized T -> Y edge below it would form a cycle,
            # and Y would then be reported as a backdoor parent of T.
            if u == Y and v == T:
                continue
            corr, _ = spearmanr(df[u], df[v])
            # A NaN correlation (constant or missing data) fails this
            # comparison, so no edge is added for that pair.
            if abs(corr) > corr_threshold:
                edge_type = "confounding" if v in (T, Y) else "temporal"
                G.add_edge(u, v, type=edge_type, confidence=1.0)

    # Added last so its attributes override any correlation edge T -> Y.
    G.add_edge(T, Y, type="hypothesized", confidence="research_question")
    return G
|
|
|
|
|
|
def find_backdoor_paths(
    G: nx.DiGraph, T: str, Y: str, cutoff: Optional[int] = 5
) -> List[str]:
    """Find the backdoor paths from ``T`` to ``Y``.

    A backdoor path is a simple path between ``T`` and ``Y`` in the undirected
    skeleton of ``G`` whose first edge points into ``T``, i.e. the second
    node on the path has a directed edge to ``T`` in ``G``.

    Args:
        G: Causal graph.
        T: Treatment node.
        Y: Outcome node.
        cutoff: Maximum path length (in edges) to search. The default of 5
            keeps the previous behaviour; pass ``None`` to enumerate every
            simple path (this can be exponential in the size of the graph).

    Returns:
        Unique backdoor paths rendered with direction arrows by
        ``_format_path`` (e.g. ``"T <- X -> Y"``), in discovery order.
        Empty when ``T`` or ``Y`` is not a node of ``G``.
    """
    UG = G.to_undirected()
    try:
        # list() inside the try: depending on the networkx version the
        # missing-node error may only surface while iterating.
        paths = list(nx.all_simple_paths(UG, source=T, target=Y, cutoff=cutoff))
    except nx.NodeNotFound:
        paths = []

    backdoor_paths = [
        _format_path(G, path)
        for path in paths
        if len(path) >= 2 and G.has_edge(path[1], T)
    ]
    # dict.fromkeys removes duplicates while preserving discovery order.
    return list(dict.fromkeys(backdoor_paths))
|
|
|
|
|
|
def _format_path(G: nx.DiGraph, path: List[str]) -> str:
    """Render a skeleton path as a string annotated with edge directions.

    Each hop ``a, b`` becomes ``"<-> b"`` when both ``a -> b`` and ``b -> a``
    are edges of ``G``, ``"-> b"`` when only ``a -> b`` is, ``"<- b"`` when
    only ``b -> a`` is, and ``"-- b"`` when neither is. The first node is
    emitted as-is and all tokens are joined by single spaces, e.g.
    ``"T <- X -> Y"``.
    """
    # Keyed by (forward edge exists, backward edge exists).
    arrow_for = {
        (True, True): "<->",
        (True, False): "->",
        (False, True): "<-",
        (False, False): "--",
    }
    pieces = [path[0]]
    for src, dst in zip(path, path[1:]):
        arrow = arrow_for[(G.has_edge(src, dst), G.has_edge(dst, src))]
        pieces.append(f"{arrow} {dst}")
    return " ".join(pieces)
|
|
|
|
|
|
def find_adjustment_set(
    G: nx.DiGraph, T: str, Y: str, backdoor_paths: Optional[List[str]] = None
) -> Tuple[List[str], str]:
    """Derive an adjustment set for the effect of ``T`` on ``Y``.

    For every formatted backdoor path, the node right after ``T`` (the second
    node of the path) is added to the adjustment set. No minimality check is
    performed on the resulting set.

    Args:
        G: Causal graph.
        T: Treatment node.
        Y: Outcome node.
        backdoor_paths: Pre-computed formatted backdoor paths. When ``None``
            they are computed with ``find_backdoor_paths(G, T, Y)``; an empty
            list means there is nothing to adjust for.

    Returns:
        ``(adjustment_set, reasoning)``: the sorted list of variables to
        control for, and a human-readable explanation (in Chinese).
    """
    paths = find_backdoor_paths(G, T, Y) if backdoor_paths is None else backdoor_paths
    if not paths:
        return [], "未发现后门路径,无需额外调整变量即可识别因果效应。"

    # For paths from find_backdoor_paths, the second node is a parent of T.
    adjustment_set = sorted(
        {nodes[1] for nodes in map(_parse_path_nodes, paths) if len(nodes) >= 2}
    )
    reasoning = (
        f"发现 {len(paths)} 条后门路径。"
        f"通过控制变量 {adjustment_set} 可阻断所有后门路径,满足后门准则。"
    )
    return adjustment_set, reasoning
|
|
|
|
|
|
def _parse_path_nodes(path_str: str) -> List[str]:
|
|
"""从格式化路径字符串中提取节点列表"""
|
|
cleaned = (
|
|
path_str.replace("->", " ")
|
|
.replace("<-", " ")
|
|
.replace("<->", " ")
|
|
.replace("--", " ")
|
|
)
|
|
nodes = [n.strip() for n in cleaned.split() if n.strip()]
|
|
return nodes
|