2026-03-29 23:47:20 +08:00

40 lines
1.1 KiB
Python

"""Tests for estimation module."""
import numpy as np
import pandas as pd
import pytest
from causal_agent.analysis.estimation import estimate_ate
def test_estimate_ate_basic():
    """Check estimate_ate on synthetic data with one observed confounder.

    The simulated data use a standard-normal covariate ``Z`` that drives
    both the binary treatment ``T`` (thresholded noisy ``Z``) and the
    outcome ``Y = 2*T + 3*Z + noise``, so the effect to recover is 2.0.
    The test adjusts for ``Z`` and verifies that both reported estimates
    land near that value and that a balance check entry exists for ``Z``.
    """
    # Fixed seed: the tolerances below are tuned for this exact sample.
    np.random.seed(42)
    n = 300
    true_ate = 2.0

    # NOTE: the order of the random draws is significant for reproducibility.
    z = np.random.normal(0, 1, n)
    t = (z + np.random.normal(0, 0.5, n) > 0).astype(int)
    y = true_ate * t + 3 * z + np.random.normal(0, 0.5, n)
    df = pd.DataFrame({"T": t, "Y": y, "Z": z})

    # compute_ci=False presumably skips confidence-interval computation;
    # confirm against estimate_ate's documentation.
    result = estimate_ate(df, "T", "Y", ["Z"], compute_ci=False)

    assert "ATE_Outcome_Regression" in result, "missing outcome-regression estimate"
    assert "ATE_IPW" in result, "missing IPW estimate"

    reg_error = abs(result["ATE_Outcome_Regression"] - true_ate)
    assert reg_error < 0.5, f"outcome-regression ATE off by {reg_error:.3f}"

    # The IPW estimate has higher variance, so use a looser tolerance.
    ipw_error = abs(result["ATE_IPW"] - true_ate)
    assert ipw_error < 1.5, f"IPW ATE off by {ipw_error:.3f}"

    assert "balance_check" in result, "missing balance_check entry"
    assert "Z" in result["balance_check"], "balance_check has no entry for Z"
def test_estimate_ate_empty_adjustment():
    """Check estimate_ate when the adjustment set is empty.

    Treatment ``T`` is drawn as an independent Bernoulli(0.5) variable and
    the outcome is ``Y = 1.5*T + noise``, so there is no confounding and
    the effect to recover is 1.5. Both reported estimates must fall within
    0.5 of that value when an empty covariate list is passed.
    """
    # Fixed seed: the tolerances below are tuned for this exact sample.
    np.random.seed(42)
    n = 200
    true_ate = 1.5

    # NOTE: the order of the random draws is significant for reproducibility.
    t = np.random.binomial(1, 0.5, n)
    y = true_ate * t + np.random.normal(0, 0.5, n)
    df = pd.DataFrame({"T": t, "Y": y})

    # compute_ci=False presumably skips confidence-interval computation;
    # confirm against estimate_ate's documentation.
    result = estimate_ate(df, "T", "Y", [], compute_ci=False)

    reg_error = abs(result["ATE_Outcome_Regression"] - true_ate)
    assert reg_error < 0.5, f"outcome-regression ATE off by {reg_error:.3f}"

    ipw_error = abs(result["ATE_IPW"] - true_ate)
    assert ipw_error < 0.5, f"IPW ATE off by {ipw_error:.3f}"