"""Tests for estimation module."""

import numpy as np
import pandas as pd
import pytest

from causal_agent.analysis.estimation import estimate_ate


def test_estimate_ate_basic():
    """Recover a known ATE when treatment is confounded by covariate Z.

    Z drives both treatment assignment and the outcome, so Z is passed as the
    adjustment set. The simulated treatment effect is 2.0. The result is also
    expected to expose a balance check that covers Z.
    """
    # Seeding the legacy global RNG keeps the simulated data identical
    # across runs, so the tolerances below are checked against fixed data.
    np.random.seed(42)
    n = 300
    true_ate = 2.0
    z = np.random.normal(0, 1, n)
    # Treatment becomes more likely as Z grows, which confounds a naive
    # treated-vs-control comparison.
    t = (z + np.random.normal(0, 0.5, n) > 0).astype(int)
    y = true_ate * t + 3 * z + np.random.normal(0, 0.5, n)
    df = pd.DataFrame({"T": t, "Y": y, "Z": z})

    result = estimate_ate(df, "T", "Y", ["Z"], compute_ci=False)

    assert "ATE_Outcome_Regression" in result
    assert "ATE_IPW" in result
    assert result["ATE_Outcome_Regression"] == pytest.approx(true_ate, abs=0.5)
    # IPW estimates have larger variance, so the tolerance is looser.
    assert result["ATE_IPW"] == pytest.approx(true_ate, abs=1.5)
    assert "balance_check" in result
    assert "Z" in result["balance_check"]


def test_estimate_ate_empty_adjustment():
    """Estimate the ATE with an empty adjustment set.

    Treatment is assigned by a fair coin flip independent of everything else,
    so no adjustment is needed. The simulated treatment effect is 1.5.
    """
    np.random.seed(42)
    n = 200
    true_ate = 1.5
    t = np.random.binomial(1, 0.5, n)
    y = true_ate * t + np.random.normal(0, 0.5, n)
    df = pd.DataFrame({"T": t, "Y": y})

    result = estimate_ate(df, "T", "Y", [], compute_ci=False)

    # Check key presence first so a missing estimator fails with a clear
    # assertion rather than a KeyError.
    assert "ATE_Outcome_Regression" in result
    assert "ATE_IPW" in result
    assert result["ATE_Outcome_Regression"] == pytest.approx(true_ate, abs=0.5)
    assert result["ATE_IPW"] == pytest.approx(true_ate, abs=0.5)