204 lines
6.3 KiB
Python
204 lines
6.3 KiB
Python
"""
|
||
效应估计模块
|
||
|
||
提供 Outcome Regression (OR) 和 Inverse Probability Weighting (IPW) 两种估计方法。
|
||
"""
|
||
|
||
from typing import Any, Dict, List, Optional
|
||
|
||
import numpy as np
|
||
import pandas as pd
|
||
from sklearn.linear_model import LinearRegression, LogisticRegression
|
||
|
||
|
||
def _compute_smd(df: pd.DataFrame, var: str, T: str) -> float:
    """Return the standardized mean difference (SMD) of ``var`` between arms.

    Rows with ``T == 1`` form the treated arm and rows with ``T == 0`` the
    control arm. The SMD is
    ``(mean_treated - mean_control) / sqrt((var_treated + var_control) / 2)``,
    using pandas' sample variances (ddof=1). Returns 0.0 when the pooled
    standard deviation is exactly zero; if it is NaN (e.g. an arm with fewer
    than two rows) the result is NaN.
    """
    arm = df[T]
    x_treated = df.loc[arm == 1, var]
    x_control = df.loc[arm == 0, var]

    spread = np.sqrt(0.5 * (x_treated.var() + x_control.var()))
    gap = x_treated.mean() - x_control.mean()
    return 0.0 if spread == 0 else float(gap / spread)
|
||
|
||
|
||
def _bootstrap_ci(
    df: pd.DataFrame,
    T: str,
    Y: str,
    adjustment_set: List[str],
    estimator: str,
    n_bootstrap: int = 500,
    alpha: float = 0.05,
    random_state: Optional[int] = None,
) -> List[float]:
    """Percentile-bootstrap confidence interval for an ATE estimator.

    Draws ``n_bootstrap`` resamples of ``df`` (same size, with replacement),
    re-runs the chosen estimator on each, and returns the ``alpha / 2`` and
    ``1 - alpha / 2`` percentiles of the successful estimates.

    Args:
        df: Input data.
        T: Treatment column (0/1).
        Y: Outcome column.
        adjustment_set: Covariate columns passed to the estimator.
        estimator: ``"OR"`` (outcome regression) or ``"IPW"``.
        n_bootstrap: Number of resamples to draw.
        alpha: Significance level; the interval has nominal coverage
            ``1 - alpha``.
        random_state: Optional seed for reproducible resampling. ``None``
            (the default) keeps the previous behaviour of drawing from
            NumPy's global random state.

    Returns:
        ``[lower, upper]`` rounded to 4 decimals, or ``[nan, nan]`` when no
        resample produced a usable estimate.

    Raises:
        ValueError: If ``estimator`` is not ``"OR"`` or ``"IPW"``.
    """
    if estimator not in ("OR", "IPW"):
        raise ValueError(
            f"Unknown estimator {estimator!r}; expected 'OR' or 'IPW'."
        )

    # One generator shared by all resamples when seeded; passing an int
    # straight to df.sample would repeat the same resample every iteration.
    rng = None if random_state is None else np.random.RandomState(random_state)

    n = len(df)
    estimates: List[float] = []
    for _ in range(n_bootstrap):
        sample = df.sample(n=n, replace=True, random_state=rng)
        # A resample that contains only one treatment arm carries no
        # information about the treated-vs-control contrast; skip it rather
        # than letting a degenerate estimate distort the interval.
        if sample[T].nunique() < 2:
            continue
        try:
            if estimator == "OR":
                ate = _estimate_or(sample, T, Y, adjustment_set)
            else:
                ate = _estimate_ipw(sample, T, Y, adjustment_set)
        except (ValueError, np.linalg.LinAlgError):
            # Model fitting can fail on an unlucky resample; count it as a
            # failed draw instead of aborting the whole bootstrap. Other
            # exception types (e.g. a missing column) indicate real bugs and
            # are allowed to propagate.
            continue
        estimates.append(ate)

    valid = np.asarray(estimates, dtype=float)
    valid = valid[~np.isnan(valid)]
    if valid.size == 0:
        return [np.nan, np.nan]

    lower = float(np.percentile(valid, 100 * alpha / 2))
    upper = float(np.percentile(valid, 100 * (1 - alpha / 2)))
    return [round(lower, 4), round(upper, 4)]
|
||
|
||
|
||
def _estimate_or(df: pd.DataFrame, T: str, Y: str, adjustment_set: List[str]) -> float:
    """Estimate the ATE by outcome regression (g-computation).

    Fits a linear regression of ``Y`` on ``adjustment_set + [T]`` and returns
    the mean prediction with ``T`` set to 1 for every row minus the mean
    prediction with ``T`` set to 0 for every row.

    Args:
        df: Input data (not modified).
        T: Treatment column (expected 0/1 coding).
        Y: Outcome column.
        adjustment_set: Covariate columns to adjust for (any sequence).

    Returns:
        The estimated average treatment effect as a float.
    """
    x_cols = list(adjustment_set) + [T]
    X = df[x_cols].to_numpy()
    y = df[Y].to_numpy()
    model = LinearRegression().fit(X, y)

    # Counterfactual design matrices: only the treatment column (the last
    # one) differs, so copy the small model matrix instead of the whole
    # DataFrame. This function runs once per bootstrap resample.
    X_treated = X.copy()
    X_treated[:, -1] = 1
    X_control = X.copy()
    X_control[:, -1] = 0

    ate = model.predict(X_treated).mean() - model.predict(X_control).mean()
    return float(ate)
|
||
|
||
|
||
def _estimate_ipw(df: pd.DataFrame, T: str, Y: str, adjustment_set: List[str]) -> float:
    """Estimate the ATE by inverse probability weighting.

    Propensity scores are the overall treated share when ``adjustment_set``
    is empty. Otherwise they are the fitted probabilities of a
    ``LogisticRegression(max_iter=1000, solver="lbfgs")`` of ``T`` on the
    adjustment set. Scores are clipped to [0.05, 0.95], and the unnormalised
    contrast ``mean(T * Y / ps) - mean((1 - T) * Y / (1 - ps))`` is returned.

    Args:
        df: Input data.
        T: Treatment column (0/1).
        Y: Outcome column.
        adjustment_set: Covariate columns for the propensity model.

    Returns:
        The estimated average treatment effect as a float.
    """
    treat = df[T].values

    if len(adjustment_set) > 0:
        covariates = df[adjustment_set].values
        clf = LogisticRegression(max_iter=1000, solver="lbfgs")
        clf.fit(covariates, treat)
        propensity = clf.predict_proba(covariates)[:, 1]
    else:
        # No covariates: every unit gets the marginal probability of treatment.
        propensity = np.full(treat.shape, treat.mean(), dtype=float)

    # Clip to keep extreme scores from producing huge weights.
    bounded = np.clip(propensity, 0.05, 0.95)
    outcome = df[Y].values

    treated_term = np.mean(treat / bounded * outcome)
    control_term = np.mean((1 - treat) / (1 - bounded) * outcome)
    return float(treated_term - control_term)
|
||
|
||
|
||
def estimate_ate(
    df: pd.DataFrame,
    T: str,
    Y: str,
    adjustment_set: List[str],
    compute_ci: bool = True,
    n_bootstrap: int = 500,
    alpha: float = 0.05,
) -> Dict[str, Any]:
    """Estimate the average treatment effect (ATE).

    Runs outcome regression (OR) and inverse probability weighting (IPW).
    Checks the overlap (positivity) assumption on the fitted propensity
    scores and compares covariate balance (SMD) before and after IPW
    weighting. Optionally adds percentile-bootstrap confidence intervals.

    Args:
        df: Input data.
        T: Treatment column (rows with 1 are treated, 0 are control).
        Y: Outcome column.
        adjustment_set: Covariate columns to adjust for.
        compute_ci: Whether to compute bootstrap confidence intervals.
        n_bootstrap: Number of bootstrap resamples.
        alpha: Significance level; the interval has coverage ``1 - alpha``.

    Returns:
        Dict with keys:

        - ``ATE_Outcome_Regression`` and ``ATE_IPW``: the two estimates.
        - ``ATE_reported``: their mean.
        - ``95%_CI``: the OR interval, falling back to the IPW interval when
          the OR one is unavailable. The key name is kept for backward
          compatibility even when ``alpha != 0.05``.
        - ``interpretation``, ``balance_check`` (per-variable SMD before and
          after weighting), ``overlap_assumption``, ``robustness`` and
          ``warnings``.
    """
    warnings_list: List[Dict[str, str]] = []

    ate_or = _estimate_or(df, T, Y, adjustment_set)

    # Propensity scores for the overlap and balance diagnostics (same model
    # specification as used for the IPW estimate).
    t = df[T].values
    if len(adjustment_set) == 0:
        ps = np.full_like(t, fill_value=t.mean(), dtype=float)
    else:
        X = df[adjustment_set].values
        ps_model = LogisticRegression(max_iter=1000, solver="lbfgs")
        ps_model.fit(X, t)
        ps = ps_model.predict_proba(X)[:, 1]

    overlap_ok = not (np.any(ps < 0.05) or np.any(ps > 0.95))
    if not overlap_ok:
        warnings_list.append(
            {
                "type": "positivity_violation",
                "message": "倾向得分存在极端值(<0.05 或 >0.95),存在重叠假设违反风险。",
            }
        )

    ate_ipw = _estimate_ipw(df, T, Y, adjustment_set)

    # Flag disagreement of more than 10% relative to either estimate. This is
    # equivalent to a 10% threshold relative to the smaller magnitude, so a
    # zero estimate paired with a nonzero one is always flagged.
    robustness = "稳健"
    if abs(ate_or - ate_ipw) > 0.1 * min(abs(ate_or), abs(ate_ipw)):
        robustness = "OR 与 IPW 估计差异 >10%,可能存在模型误设,建议进一步检查"

    # Loop-invariant weighting quantities, computed once for all covariates.
    ps_clipped = np.clip(ps, 0.05, 0.95)
    treated_idx = t == 1
    control_idx = t == 0
    weights = np.where(treated_idx, 1.0 / ps_clipped, 1.0 / (1 - ps_clipped))

    balance_check: Dict[str, Dict[str, float]] = {}
    for var in adjustment_set:
        treated_vals = df.loc[treated_idx, var]
        control_vals = df.loc[control_idx, var]
        weighted_mean_t = np.average(treated_vals.values, weights=weights[treated_idx])
        weighted_mean_c = np.average(control_vals.values, weights=weights[control_idx])
        # Unweighted pooled SD, so "before" and "after" share a denominator.
        pooled_std = np.sqrt((treated_vals.var() + control_vals.var()) / 2)
        smd_after = (
            float((weighted_mean_t - weighted_mean_c) / pooled_std)
            if pooled_std > 0
            else 0.0
        )
        balance_check[var] = {
            "before": round(_compute_smd(df, var, T), 4),
            "after": round(smd_after, 4),
        }

    ci_or: List[float] = [np.nan, np.nan]
    ci_ipw: List[float] = [np.nan, np.nan]
    if compute_ci:
        ci_or = _bootstrap_ci(df, T, Y, adjustment_set, "OR", n_bootstrap, alpha)
        ci_ipw = _bootstrap_ci(df, T, Y, adjustment_set, "IPW", n_bootstrap, alpha)

    ate_report = round((ate_or + ate_ipw) / 2, 4)
    ci_report = ci_or if not np.isnan(ci_or[0]) else ci_ipw

    # Label the interval with its real coverage (e.g. "95%CI" for alpha=0.05)
    # and omit it entirely when no interval was computed.
    if np.isnan(ci_report[0]):
        ci_text = ""
    else:
        level = round(100 * (1 - alpha), 2)
        ci_text = f" ({level:g}%CI: {ci_report[0]:.2f}-{ci_report[1]:.2f})"
    interpretation = (
        f"在控制 {adjustment_set} 后,接受处理使 {Y} "
        f"平均变化 {ate_report:.4f}{ci_text}。"
    )

    return {
        "ATE_Outcome_Regression": round(ate_or, 4),
        "ATE_IPW": round(ate_ipw, 4),
        "ATE_reported": ate_report,
        "95%_CI": ci_report,
        "interpretation": interpretation,
        "balance_check": balance_check,
        "overlap_assumption": "满足" if overlap_ok else "存在风险",
        "robustness": robustness,
        "warnings": warnings_list,
    }
|