2026-03-29 23:47:20 +08:00

50 lines
1.3 KiB
Python

"""Tests for causal_graph module."""
import networkx as nx
import pandas as pd
import pytest
from causal_agent.analysis.causal_graph import (
build_local_graph,
find_adjustment_set,
find_backdoor_paths,
)
def test_build_local_graph():
    """Check that build_local_graph wires a correlated candidate into the graph.

    The candidate ``age`` is given tier 0, which is earlier than
    ``treatment`` (tier 2) and ``health`` (tier 4). With a low correlation
    threshold (0.1), the test expects three edges:
    - the direct treatment -> health edge;
    - age -> treatment;
    - age -> health.
    """
    frame = pd.DataFrame(
        {
            "treatment": [0, 1, 0, 1],
            "health": [1, 2, 1, 3],
            "age": [20, 30, 40, 50],
        }
    )
    tier_of = {"treatment": 2, "health": 4, "age": 0}

    graph = build_local_graph(
        frame,
        "treatment",
        "health",
        [{"var": "age"}],
        tier_of,
        corr_threshold=0.1,
    )

    assert isinstance(graph, nx.DiGraph)
    expected_edges = (
        ("treatment", "health"),
        ("age", "treatment"),
        ("age", "health"),
    )
    for src, dst in expected_edges:
        assert graph.has_edge(src, dst)
def test_find_backdoor_paths():
    """Check that a single confounder yields exactly one backdoor path.

    The graph is: age -> treatment, age -> health, treatment -> health.
    The test expects exactly one backdoor path from treatment to health,
    reported as the string "treatment <- age -> health".
    """
    dag = nx.DiGraph()
    dag.add_edges_from(
        [
            ("age", "treatment"),
            ("age", "health"),
            ("treatment", "health"),
        ]
    )

    backdoor = find_backdoor_paths(dag, "treatment", "health")

    assert len(backdoor) == 1
    assert "treatment <- age -> health" in backdoor
def test_find_adjustment_set():
    """Check that find_adjustment_set selects the confounder and explains why.

    The graph is: age -> treatment, age -> health, treatment -> health.
    The test expects two things from the returned pair:
    - ``age`` is included in the adjustment set;
    - the reasoning text mentions backdoor paths.
    """
    dag = nx.DiGraph()
    dag.add_edges_from(
        [
            ("age", "treatment"),
            ("age", "health"),
            ("treatment", "health"),
        ]
    )

    adj_vars, explanation = find_adjustment_set(dag, "treatment", "health")

    assert "age" in adj_vars
    # The reasoning is expected to contain the Chinese term "后门" ("backdoor").
    assert "后门" in explanation