2026-03-29 23:47:20 +08:00

204 lines
6.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
效应估计模块
提供 Outcome Regression (OR) 和 Inverse Probability Weighting (IPW) 两种估计方法。
"""
from typing import Any, Dict, List, Optional
import numpy as np
import pandas as pd
from sklearn.linear_model import LinearRegression, LogisticRegression
def _compute_smd(df: pd.DataFrame, var: str, T: str) -> float:
    """Return the standardized mean difference (SMD) of ``var`` between arms.

    Rows with ``df[T] == 1`` form the treated group and rows with
    ``df[T] == 0`` the control group. The difference of the group means is
    divided by the pooled standard deviation, i.e. the square root of the
    average of the two sample variances.

    Args:
        df: Input data frame.
        var: Name of the covariate column to compare.
        T: Name of the treatment indicator column.

    Returns:
        The SMD as a float, or ``0.0`` when the pooled standard deviation is
        exactly zero.
    """
    column = df[var]
    arm = df[T]
    treated_values = column[arm == 1]
    control_values = column[arm == 0]

    spread = np.sqrt((treated_values.var() + control_values.var()) / 2)
    if spread == 0:
        # Both groups are constant, so the ratio is undefined; report no imbalance.
        return 0.0

    gap = treated_values.mean() - control_values.mean()
    return float(gap / spread)
def _bootstrap_ci(
    df: pd.DataFrame,
    T: str,
    Y: str,
    adjustment_set: List[str],
    estimator: str,
    n_bootstrap: int = 500,
    alpha: float = 0.05,
    random_state: Optional[int] = None,
) -> List[float]:
    """Return a percentile-bootstrap confidence interval for the ATE.

    The frame is resampled with replacement ``n_bootstrap`` times, the chosen
    estimator is applied to every resample, and the ``alpha / 2`` and
    ``1 - alpha / 2`` percentiles of the successful estimates are reported.

    Args:
        df: Input data frame.
        T: Name of the treatment column.
        Y: Name of the outcome column.
        adjustment_set: Names of the adjustment covariates.
        estimator: ``"OR"`` (outcome regression) or ``"IPW"`` (inverse
            probability weighting).
        n_bootstrap: Number of bootstrap replicates.
        alpha: Significance level; the interval covers ``1 - alpha``.
        random_state: Optional seed for the resampling generator. ``None``
            (the default) keeps the draws non-deterministic; an integer makes
            the interval reproducible.

    Returns:
        ``[lower, upper]`` rounded to four decimals, or ``[nan, nan]`` when
        the estimator name is not recognised, the frame is empty, or no
        replicate produced a finite estimate.
    """
    # Pick the estimator once instead of re-testing the string per replicate.
    if estimator == "OR":
        estimate = _estimate_or
    elif estimator == "IPW":
        estimate = _estimate_ipw
    else:
        # Every replicate would be NaN anyway, so skip the resampling work.
        return [np.nan, np.nan]

    n = len(df)
    if n == 0:
        # Nothing to resample from; no replicate could yield a finite value.
        return [np.nan, np.nan]

    # A single generator shared by all replicates: reusing the same integer
    # seed per draw would produce identical resamples.
    rng = np.random.default_rng(random_state)

    replicates: List[float] = []
    for _ in range(n_bootstrap):
        sample = df.iloc[rng.integers(0, n, size=n)]
        try:
            replicates.append(estimate(sample, T, Y, adjustment_set))
        except Exception:
            # Deliberate best-effort: a degenerate resample (for example one
            # containing a single treatment arm) is dropped, not fatal.
            continue

    finite = np.asarray(replicates, dtype=float)
    finite = finite[~np.isnan(finite)]
    if finite.size == 0:
        return [np.nan, np.nan]

    lower, upper = np.percentile(
        finite, [100 * alpha / 2, 100 * (1 - alpha / 2)]
    )
    return [round(float(lower), 4), round(float(upper), 4)]
def _estimate_or(df: pd.DataFrame, T: str, Y: str, adjustment_set: List[str]) -> float:
    """Estimate the ATE by outcome regression.

    A linear regression of ``Y`` on the adjustment covariates plus the
    treatment column is fitted. The treatment column is then overwritten with
    1 and with 0 for every row, and the ATE is the difference between the
    mean predictions under the two settings.

    Args:
        df: Input data frame.
        T: Name of the treatment column.
        Y: Name of the outcome column.
        adjustment_set: Names of the adjustment covariates.

    Returns:
        The estimated average treatment effect.
    """
    features = adjustment_set + [T]
    outcome_model = LinearRegression()
    outcome_model.fit(df[features].values, df[Y].values)

    def _mean_prediction(level: int) -> float:
        # Predict on a copy in which every row receives the same treatment.
        counterfactual = df.copy()
        counterfactual[T] = level
        return outcome_model.predict(counterfactual[features].values).mean()

    return float(_mean_prediction(1) - _mean_prediction(0))
def _estimate_ipw(df: pd.DataFrame, T: str, Y: str, adjustment_set: List[str]) -> float:
    """Estimate the ATE by inverse probability weighting.

    Propensity scores come from a logistic regression of the treatment on the
    adjustment covariates; with an empty adjustment set the overall treatment
    rate is used for every row. Scores are clipped to ``[0.05, 0.95]`` before
    weighting, and the ATE is the mean of the weighted treated outcomes minus
    the mean of the weighted control outcomes, both taken over all rows.

    Args:
        df: Input data frame.
        T: Name of the treatment column.
        Y: Name of the outcome column.
        adjustment_set: Names of the adjustment covariates.

    Returns:
        The estimated average treatment effect.
    """
    treatment = df[T].values

    if len(adjustment_set) == 0:
        # No covariates: every unit shares the marginal treatment probability.
        propensity = np.full(treatment.shape, treatment.mean(), dtype=float)
    else:
        covariates = df[adjustment_set].values
        propensity = (
            LogisticRegression(max_iter=1000, solver="lbfgs")
            .fit(covariates, treatment)
            .predict_proba(covariates)[:, 1]
        )

    # Clipping bounds the inverse weights at 1 / 0.05 = 20.
    propensity = np.clip(propensity, 0.05, 0.95)

    outcome = df[Y].values
    treated_term = np.mean(treatment / propensity * outcome)
    control_term = np.mean((1 - treatment) / (1 - propensity) * outcome)
    return float(treated_term - control_term)
def estimate_ate(
    df: pd.DataFrame,
    T: str,
    Y: str,
    adjustment_set: List[str],
    compute_ci: bool = True,
    n_bootstrap: int = 500,
    alpha: float = 0.05,
) -> Dict[str, Any]:
    """Estimate the average treatment effect (ATE) and report diagnostics.

    Two estimates are produced, one by outcome regression (OR) and one by
    inverse probability weighting (IPW). On top of them the function reports
    an overlap check on the propensity scores, an OR-versus-IPW agreement
    check, a per-covariate balance table (SMD before and after weighting) and,
    optionally, bootstrap confidence intervals.

    Args:
        df: Input data frame.
        T: Name of the treatment column; rows equal to 1 are treated and rows
            equal to 0 are controls.
        Y: Name of the outcome column.
        adjustment_set: Names of the adjustment covariates.
        compute_ci: Whether to compute bootstrap confidence intervals.
        n_bootstrap: Number of bootstrap replicates.
        alpha: Significance level of the confidence intervals.

    Returns:
        A dict with the keys ``ATE_Outcome_Regression``, ``ATE_IPW``,
        ``ATE_reported`` (mean of the two estimates), ``95%_CI``,
        ``interpretation``, ``balance_check``, ``overlap_assumption``,
        ``robustness`` and ``warnings``. The ``95%_CI`` key name is fixed for
        compatibility even though the interval level follows ``alpha``; it is
        ``[nan, nan]`` when no interval was computed.
    """
    # Same bounds the IPW estimator clips its propensity scores to.
    ps_lower, ps_upper = 0.05, 0.95
    warnings_list: List[Dict[str, str]] = []

    ate_or = _estimate_or(df, T, Y, adjustment_set)

    # Propensity scores for the diagnostics below. _estimate_ipw fits the
    # same model again internally for its own estimate.
    t = df[T].values
    if len(adjustment_set) == 0:
        ps = np.full_like(t, fill_value=t.mean(), dtype=float)
    else:
        X = df[adjustment_set].values
        ps_model = LogisticRegression(max_iter=1000, solver="lbfgs")
        ps_model.fit(X, t)
        ps = ps_model.predict_proba(X)[:, 1]

    overlap_ok = not (np.any(ps < ps_lower) or np.any(ps > ps_upper))
    if not overlap_ok:
        warnings_list.append(
            {
                "type": "positivity_violation",
                "message": "倾向得分存在极端值(<0.05 或 >0.95),存在重叠假设违反风险。",
            }
        )

    ate_ipw = _estimate_ipw(df, T, Y, adjustment_set)

    # Flag disagreement when the gap exceeds 10% of either non-zero estimate.
    diff = abs(ate_or - ate_ipw)
    estimates_disagree = any(
        ref != 0 and diff / abs(ref) > 0.1 for ref in (ate_or, ate_ipw)
    )
    robustness = (
        "OR 与 IPW 估计差异 >10%,可能存在模型误设,建议进一步检查"
        if estimates_disagree
        else "稳健"
    )

    # The weights and group masks do not depend on the covariate, so they are
    # computed once rather than on every loop iteration.
    treated_idx = t == 1
    control_idx = t == 0
    ps_clipped = np.clip(ps, ps_lower, ps_upper)
    weights = np.where(treated_idx, 1.0 / ps_clipped, 1.0 / (1 - ps_clipped))

    balance_check: Dict[str, Dict[str, float]] = {}
    for var in adjustment_set:
        smd_before = _compute_smd(df, var, T)
        treated_values = df.loc[treated_idx, var]
        control_values = df.loc[control_idx, var]
        weighted_mean_t = np.average(
            treated_values.values, weights=weights[treated_idx]
        )
        weighted_mean_c = np.average(
            control_values.values, weights=weights[control_idx]
        )
        # The weighted gap is scaled by the unweighted pooled deviation.
        pooled_std = np.sqrt((treated_values.var() + control_values.var()) / 2)
        smd_after = (
            float((weighted_mean_t - weighted_mean_c) / pooled_std)
            if pooled_std > 0
            else 0.0
        )
        balance_check[var] = {
            "before": round(smd_before, 4),
            "after": round(smd_after, 4),
        }

    ci_or: List[float] = [np.nan, np.nan]
    ci_ipw: List[float] = [np.nan, np.nan]
    if compute_ci:
        ci_or = _bootstrap_ci(df, T, Y, adjustment_set, "OR", n_bootstrap, alpha)
        ci_ipw = _bootstrap_ci(df, T, Y, adjustment_set, "IPW", n_bootstrap, alpha)

    ate_report = round((ate_or + ate_ipw) / 2, 4)
    # NOTE(review): the reported interval is the OR interval (IPW only as a
    # fallback) while the point estimate averages both methods.
    ci_report = ci_or if not np.isnan(ci_or[0]) else ci_ipw

    if np.isnan(ci_report[0]) or np.isnan(ci_report[1]):
        # No usable interval: omit the clause instead of printing "nan-nan".
        ci_clause = ""
    else:
        # The label follows alpha, and the parenthesis is now balanced.
        ci_level = f"{100 * (1 - alpha):g}"
        ci_clause = (
            f" ({ci_level}%CI: {ci_report[0]:.2f}-{ci_report[1]:.2f})"
        )
    interpretation = (
        f"在控制 {adjustment_set} 后,接受处理使 {Y} "
        f"平均变化 {ate_report:.4f}{ci_clause}。"
    )

    return {
        "ATE_Outcome_Regression": round(ate_or, 4),
        "ATE_IPW": round(ate_ipw, 4),
        "ATE_reported": ate_report,
        "95%_CI": ci_report,
        "interpretation": interpretation,
        "balance_check": balance_check,
        "overlap_assumption": "满足" if overlap_ok else "存在风险",
        "robustness": robustness,
        "warnings": warnings_list,
    }