"""Tests for causal_graph module.""" import networkx as nx import pandas as pd import pytest from causal_agent.analysis.causal_graph import ( build_local_graph, find_adjustment_set, find_backdoor_paths, ) def test_build_local_graph(): df = pd.DataFrame({ "treatment": [0, 1, 0, 1], "health": [1, 2, 1, 3], "age": [20, 30, 40, 50], }) candidates = [{"var": "age"}] tiers = {"treatment": 2, "health": 4, "age": 0} G = build_local_graph(df, "treatment", "health", candidates, tiers, corr_threshold=0.1) assert isinstance(G, nx.DiGraph) assert G.has_edge("treatment", "health") assert G.has_edge("age", "treatment") assert G.has_edge("age", "health") def test_find_backdoor_paths(): G = nx.DiGraph() G.add_edge("age", "treatment") G.add_edge("age", "health") G.add_edge("treatment", "health") paths = find_backdoor_paths(G, "treatment", "health") assert len(paths) == 1 assert "treatment <- age -> health" in paths def test_find_adjustment_set(): G = nx.DiGraph() G.add_edge("age", "treatment") G.add_edge("age", "health") G.add_edge("treatment", "health") adjustment_set, reasoning = find_adjustment_set(G, "treatment", "health") assert "age" in adjustment_set assert "后门" in reasoning