40 lines
1.1 KiB
Python
40 lines
1.1 KiB
Python
"""Tests for estimation module."""
|
|
|
|
import numpy as np
|
|
import pandas as pd
|
|
import pytest
|
|
|
|
from causal_agent.analysis.estimation import estimate_ate
|
|
|
|
|
|
def test_estimate_ate_basic():
    """Check that both estimators recover a known effect when Z confounds T and Y.

    Z shifts both the treatment threshold and the outcome, and the true
    treatment effect is 2.0. The result must also include a
    ``balance_check`` entry that mentions the adjustment column Z.
    """
    np.random.seed(42)
    n = 300
    true_effect = 2.0

    z = np.random.normal(0, 1, n)
    # Treatment is a noisy threshold on Z, so Z influences who gets treated.
    assignment_noise = np.random.normal(0, 0.5, n)
    t = (z + assignment_noise > 0).astype(int)
    outcome_noise = np.random.normal(0, 0.5, n)
    y = 2 * t + 3 * z + outcome_noise

    frame = pd.DataFrame({"T": t, "Y": y, "Z": z})
    estimates = estimate_ate(frame, "T", "Y", ["Z"], compute_ci=False)

    for key in ("ATE_Outcome_Regression", "ATE_IPW"):
        assert key in estimates

    assert abs(estimates["ATE_Outcome_Regression"] - true_effect) < 0.5
    # IPW estimates have high variance, so the tolerance is widened.
    assert abs(estimates["ATE_IPW"] - true_effect) < 1.5

    assert "balance_check" in estimates
    assert "Z" in estimates["balance_check"]
|
|
|
|
|
|
def test_estimate_ate_empty_adjustment():
    """Check that both estimators recover the effect with an empty adjustment set.

    Treatment is an independent fair coin flip and the outcome depends only
    on treatment plus noise, so no covariates are needed. The true effect
    is 1.5.
    """
    np.random.seed(42)
    n = 200
    true_effect = 1.5

    t = np.random.binomial(1, 0.5, n)
    noise = np.random.normal(0, 0.5, n)
    y = 1.5 * t + noise

    frame = pd.DataFrame({"T": t, "Y": y})
    estimates = estimate_ate(frame, "T", "Y", [], compute_ci=False)

    for key in ("ATE_Outcome_Regression", "ATE_IPW"):
        assert abs(estimates[key] - true_effect) < 0.5
|