84 lines
2.5 KiB
Python
84 lines
2.5 KiB
Python
"""
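Fast correlation screening module.

Uses Pearson + Spearman methods to quickly filter for variables that are
correlated with both the treatment variable (T) and the outcome variable (Y).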
"""
|
|
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
import numpy as np
|
|
import pandas as pd
|
|
from scipy.stats import pearsonr, spearmanr
|
|
from sklearn.feature_selection import mutual_info_regression
|
|
|
|
|
|
def local_screen(
    df: pd.DataFrame,
    T: str,
    Y: str,
    excluded: Optional[List[str]] = None,
    corr_threshold: float = 0.1,
    alpha: float = 0.05,
) -> List[Dict[str, Any]]:
    """Screen for variables that are significantly correlated with both T and Y.

    Every numeric column other than ``T``, ``Y`` and the ``excluded`` names is
    tested, provided that no more than half of its values are missing. A
    column is kept when, for both ``T`` and ``Y``, the absolute Spearman
    coefficient exceeds ``corr_threshold`` and the Spearman p-value is below
    ``alpha``. Pearson coefficients and mutual information are reported for
    the kept columns but play no part in the selection.

    Statistics are computed on the rows where the column, ``T`` and ``Y`` are
    all present. Without this, a single missing value turns every coefficient
    into NaN and the column is dropped silently.

    Args:
        df: Input data frame.
        T: Name of the treatment variable column.
        Y: Name of the outcome variable column.
        excluded: Column names to leave out of the screening.
        corr_threshold: Threshold on the absolute Spearman coefficient.
        alpha: Significance level applied to the Spearman p-values.

    Returns:
        Candidate confounders, one dict per variable with the keys ``var``,
        ``pearson_T``, ``pearson_Y``, ``spearman_T``, ``spearman_Y``,
        ``pvalue_T``, ``pvalue_Y``, ``mi_T`` and ``mi_Y`` (numbers rounded to
        4 decimals), sorted by ``|spearman_T| + |spearman_Y|`` descending.

    Raises:
        KeyError: If ``T`` or ``Y`` is not a column of ``df`` and at least one
            column is eligible for testing.
    """
    # Build the skip set once instead of concatenating lists per column.
    skipped = {T, Y}
    skipped.update(excluded or ())

    eligible = []
    for name in df.columns:
        if name in skipped:
            continue
        series = df[name]
        if series.isna().mean() > 0.5:
            continue
        if not pd.api.types.is_numeric_dtype(series):
            continue
        eligible.append(name)

    if not eligible:
        return []

    treatment = df[T]
    outcome = df[Y]
    targets_present = treatment.notna() & outcome.notna()

    # With fewer complete rows the tests carry no information, and the
    # nearest-neighbour MI estimator behind _compute_mi cannot run.
    min_obs = 4

    candidates: List[Dict[str, Any]] = []
    for name in eligible:
        series = df[name]
        complete = series.notna() & targets_present
        if int(complete.sum()) < min_obs:
            continue

        x = series[complete].to_numpy(dtype=float)
        t = treatment[complete].to_numpy(dtype=float)
        y = outcome[complete].to_numpy(dtype=float)

        spear_t, pval_t = spearmanr(x, t)
        spear_y, pval_y = spearmanr(x, y)

        # NaN coefficients (e.g. constant input) fail these comparisons and
        # are therefore rejected.
        linked_to_t = abs(spear_t) > corr_threshold and pval_t < alpha
        linked_to_y = abs(spear_y) > corr_threshold and pval_y < alpha
        if not (linked_to_t and linked_to_y):
            continue

        # Pearson and MI are only reported, so compute them for survivors only.
        pearson_t, _ = pearsonr(x, t)
        pearson_y, _ = pearsonr(x, y)
        features = x.reshape(-1, 1)
        mi_t = _compute_mi(features, t)
        mi_y = _compute_mi(features, y)

        candidates.append(
            {
                "var": name,
                "pearson_T": round(float(pearson_t), 4),
                "pearson_Y": round(float(pearson_y), 4),
                "spearman_T": round(float(spear_t), 4),
                "spearman_Y": round(float(spear_y), 4),
                "pvalue_T": round(float(pval_t), 4),
                "pvalue_Y": round(float(pval_y), 4),
                "mi_T": round(float(mi_t), 4),
                "mi_Y": round(float(mi_y), 4),
            }
        )

    candidates.sort(
        key=lambda row: abs(row["spearman_T"]) + abs(row["spearman_Y"]),
        reverse=True,
    )
    return candidates
|
|
|
|
|
|
def _compute_mi(X: np.ndarray, y: np.ndarray) -> float:
    """Return the estimated mutual information between ``X``'s first feature and ``y``.

    Args:
        X: 2-D feature matrix; only the score of its first column is returned.
        y: Target values, one per row of ``X``.

    Returns:
        The first score produced by ``mutual_info_regression``, as a float.
    """
    # The seed is fixed at 42 so repeated calls give the same estimate.
    scores = mutual_info_regression(X, y, random_state=42)
    first_score = scores[0]
    return float(first_score)
|