32 lines
1.0 KiB
Python
32 lines
1.0 KiB
Python
"""Tests for screening module."""
|
|
|
|
import numpy as np
|
|
import pandas as pd
|
|
import pytest
|
|
|
|
from causal_agent.analysis.screening import local_screen
|
|
|
|
|
|
def test_local_screen_finds_confounders():
    """Check that ``local_screen`` keeps a true confounder and drops noise.

    Synthetic data: ``Z`` feeds into both the binary treatment ``T`` and the
    outcome ``Y``, while ``noise`` is drawn separately and enters neither.
    """
    # A private RandomState seeded with 42 produces the same stream as
    # np.random.seed(42) followed by the module-level np.random.normal calls,
    # but leaves the process-wide NumPy RNG untouched for other tests.
    # NOTE(review): assumes local_screen does not itself depend on the global
    # NumPy RNG having been seeded -- confirm against its implementation.
    rng = np.random.RandomState(42)
    n = 200

    # The order of the draws is kept as-is so the generated data is unchanged.
    z = rng.normal(0, 1, n)
    t = (z + rng.normal(0, 0.5, n) > 0).astype(int)
    y = 2 * t + 3 * z + rng.normal(0, 0.5, n)
    noise = rng.normal(0, 1, n)

    df = pd.DataFrame({"T": t, "Y": y, "Z": z, "noise": noise})
    candidates = local_screen(
        df, T="T", Y="Y", excluded=[], corr_threshold=0.1, alpha=0.05
    )

    var_names = [c["var"] for c in candidates]
    assert "Z" in var_names, f"confounder Z missing from {var_names}"
    assert "noise" not in var_names, f"noise column selected: {var_names}"
|
|
|
|
|
|
def test_local_screen_excludes_specified_columns():
    """The excluded column and the T/Y columns must never be candidates."""
    frame = pd.DataFrame(
        {"T": [0, 1], "Y": [1, 2], "X": [0, 1], "id": [1, 2]}
    )
    candidates = local_screen(frame, T="T", Y="Y", excluded=["id"])

    # Same three checks as separate asserts, in the same order: the explicitly
    # excluded column first, then the treatment and outcome columns.
    for forbidden in ("id", "T", "Y"):
        assert all(c["var"] != forbidden for c in candidates)
|