50 lines
1.3 KiB
Python
50 lines
1.3 KiB
Python
"""Tests for causal_graph module."""
|
|
|
|
import networkx as nx
|
|
import pandas as pd
|
|
import pytest
|
|
|
|
from causal_agent.analysis.causal_graph import (
|
|
build_local_graph,
|
|
find_adjustment_set,
|
|
find_backdoor_paths,
|
|
)
|
|
|
|
|
|
def test_build_local_graph():
    """Check that build_local_graph wires treatment, outcome and one candidate.

    ``age`` is the only candidate variable and sits in tier 0, ahead of
    ``treatment`` (tier 2) and ``health`` (tier 4). The returned object must
    be a ``networkx.DiGraph`` containing the treatment -> health edge plus
    edges from ``age`` into both treatment and health.
    """
    # Over these four rows, age has Pearson r ~0.45 with treatment and ~0.67
    # with health. Both values clear the 0.1 threshold passed below, assuming
    # build_local_graph thresholds on correlation magnitude (TODO confirm).
    frame = pd.DataFrame(
        {
            "treatment": [0, 1, 0, 1],
            "health": [1, 2, 1, 3],
            "age": [20, 30, 40, 50],
        }
    )
    candidate_vars = [{"var": "age"}]
    # NOTE(review): lower tier presumably means "causally earlier", which
    # would make age -> treatment -> health the expected edge direction.
    tier_of = {"treatment": 2, "health": 4, "age": 0}

    graph = build_local_graph(
        frame, "treatment", "health", candidate_vars, tier_of, corr_threshold=0.1
    )

    assert isinstance(graph, nx.DiGraph)
    expected_edges = (
        ("treatment", "health"),
        ("age", "treatment"),
        ("age", "health"),
    )
    for src, dst in expected_edges:
        assert graph.has_edge(src, dst)
|
|
|
|
|
|
def test_find_backdoor_paths():
    """Check that a single confounder yields exactly one backdoor path.

    The graph is age -> treatment, age -> health, treatment -> health. The
    only path from treatment to health that enters treatment through an
    incoming edge goes via ``age``. The test expects it to be reported as the
    string ``"treatment <- age -> health"``.
    """
    # Constructing from an edge list inserts nodes and edges in the same
    # order as successive add_edge calls would.
    confounded = nx.DiGraph(
        [
            ("age", "treatment"),
            ("age", "health"),
            ("treatment", "health"),
        ]
    )

    found = find_backdoor_paths(confounded, "treatment", "health")

    assert len(found) == 1
    assert "treatment <- age -> health" in found
|
|
|
|
|
|
def test_find_adjustment_set():
    """Check that the confounder ``age`` is included in the adjustment set.

    The graph is age -> treatment, age -> health, treatment -> health.
    find_adjustment_set returns a pair ``(adjustment_set, reasoning)``. The
    set must contain ``age``, and the reasoning text must mention backdoor
    paths.
    """
    confounded = nx.DiGraph()
    confounded.add_edges_from(
        [("age", "treatment"), ("age", "health"), ("treatment", "health")]
    )

    adjustment, explanation = find_adjustment_set(confounded, "treatment", "health")

    assert "age" in adjustment
    # "后门" is Chinese for "backdoor"; the test expects the reasoning text to
    # contain this term.
    assert "后门" in explanation
|