298 lines
7.2 KiB
Markdown
298 lines
7.2 KiB
Markdown
# Causal Inference Agent
|
||
|
||
一个领域无关的通用因果推断 Agent,提供从数据加载、变量识别、因果图构建、混杂识别到效应估计的完整 Pipeline。
|
||
|
||
## 功能特性
|
||
|
||
- **LLM 智能变量识别**:自动识别处理变量 (T)、结果变量 (Y),并对所有变量进行时间层级解析
|
||
- **快速相关性筛查**:使用 Pearson + Spearman + 互信息筛选潜在混杂因素
|
||
- **混合智能因果图**:基于时间层级和统计相关性构建因果图
|
||
- **后门准则自动识别**:发现后门路径并推导最小调整集
|
||
- **双重效应估计**:支持 Outcome Regression (OR) 和 Inverse Probability Weighting (IPW)
|
||
- **标准化 JSON 报告**:包含因果图、识别策略、估计结果、诊断信息
|
||
|
||
## 快速开始
|
||
|
||
### 环境要求
|
||
|
||
- Python >= 3.11
|
||
- uv >= 0.4.0
|
||
|
||
### 安装 uv
|
||
|
||
```bash
|
||
# Windows (PowerShell)
|
||
irm https://astral.sh/uv/install.ps1 | iex
|
||
|
||
# macOS/Linux
|
||
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||
```
|
||
|
||
### 项目设置
|
||
|
||
```bash
|
||
# 进入项目目录
|
||
cd CausalInferenceAgent_V2
|
||
|
||
# 创建虚拟环境并安装依赖
|
||
uv venv
|
||
uv pip install -e .
|
||
```
|
||
|
||
## 使用方式
|
||
|
||
### 1. 命令行工具 (CLI)
|
||
|
||
```bash
|
||
# 基础用法
|
||
uv run python -m causal_agent --data data.xlsx --output report.json --prompt "分析吃药的效果"
|
||
|
||
# 自定义 LLM 配置
|
||
uv run python -m causal_agent --data data.xlsx --model qwen3.5-35b --base-url http://... --output report.json
|
||
|
||
# 查看帮助
|
||
uv run python -m causal_agent --help
|
||
```
|
||
|
||
### 2. Python API
|
||
|
||
```python
|
||
from causal_agent.agent import CausalInferenceAgent
|
||
from causal_agent.core.config import AgentConfig
|
||
|
||
# 使用默认配置
|
||
agent = CausalInferenceAgent()
|
||
result = agent.analyze("data.xlsx")
|
||
|
||
# 查看报告
|
||
print(result["report"])
|
||
```
|
||
|
||
### 3. 运行示例
|
||
|
||
```bash
|
||
uv run python -m causal_agent --data data.xlsx --output report.json --prompt "分析吃药的效果"
|
||
```
|
||
|
||
## 项目结构
|
||
|
||
```
|
||
CausalInferenceAgent_V2/
|
||
├── pyproject.toml # 项目配置
|
||
├── README.md # 本文件
|
||
├── uv.lock # uv 锁定文件
|
||
├── causal_agent/ # 主包
|
||
│ ├── __init__.py
|
||
│ ├── agent.py # CausalInferenceAgent
|
||
│ ├── cli.py # 命令行入口
|
||
│ ├── logger.py # 日志记录器
|
||
│ ├── core/ # 核心模块
|
||
│ │ ├── config.py # AgentConfig 配置
|
||
│ │ ├── llm_client.py # LLM 客户端
|
||
│ │ └── data_loader.py # 数据加载器
|
||
│ └── analysis/ # 分析模块
|
||
│ ├── variable_parser.py # LLM 变量识别
|
||
│ ├── screening.py # 相关性筛查
|
||
│ ├── causal_graph.py # 因果图构建
|
||
│ ├── estimation.py # 效应估计
|
||
│ └── reporting.py # 报告生成
|
||
├── input_processor/ # 输入处理(原有)
|
||
├── examples/
|
||
│ └── medical_v2/ # 医疗数据示例
|
||
├── tests/
|
||
│ └── test_causal_agent/ # 单元测试
|
||
└── data/ # 测试数据
|
||
```
|
||
|
||
## 完整 Pipeline 说明
|
||
|
||
Agent 执行以下 6 个步骤:
|
||
|
||
### Step 1: LLM 变量识别
|
||
|
||
调用 LLM 识别:
|
||
|
||
- `treatment`: 处理变量
|
||
- `outcome`: 结果变量
|
||
- `time_tiers`: 所有变量的时间层级(-1~4+)
|
||
|
||
**time_tiers 层级说明:**
|
||
|
||
- `-1`: 非时间变量(如 id、index)
|
||
- `0`: 人口学特征(如 age、gender)
|
||
- `1`: 基线测量(如 baseline_score)
|
||
- `2`: 干预点/处理变量
|
||
- `3`: 中介变量
|
||
- `4`: 随访结果/结果变量
|
||
- `5+`: 更晚时间点
|
||
|
||
### Step 2: 快速相关性筛查
|
||
|
||
使用 Pearson + Spearman 相关系数 + 互信息 (MI):
|
||
|
||
- 筛选与 T 和 Y 都显著相关的变量(|r| > 0.1, p < 0.05)
|
||
- 作为候选混杂因素
|
||
|
||
### Step 3: 因果图构建
|
||
|
||
基于时间层级构建混合智能因果图:
|
||
|
||
- 时间层级低的变量指向层级高的变量
|
||
- 统计相关性低于阈值的边被过滤
|
||
- 边类型:`temporal`、`confounding`、`hypothesized`
|
||
|
||
### Step 4: 混杂识别(后门准则)
|
||
|
||
- 发现所有以 `T <-` 开始的后门路径
|
||
- 自动推导阻断所有后门路径的最小调整集
|
||
|
||
### Step 5: 效应估计
|
||
|
||
提供两种估计方法:
|
||
|
||
**Outcome Regression (OR)**
|
||
|
||
- 线性回归 + G-computation
|
||
- 直接估计 E[Y|do(T=1)] - E[Y|do(T=0)]
|
||
|
||
**Inverse Probability Weighting (IPW)**
|
||
|
||
- Logistic 回归估计倾向得分
|
||
- 截断稳定权重避免极端值
|
||
|
||
同时提供:
|
||
|
||
- Bootstrap 95% 置信区间
|
||
- 标准化均值差(SMD)平衡检验
|
||
- 重叠假设(Positivity)检查
|
||
- OR/IPW 稳健性判断
|
||
|
||
### Step 6: 生成报告
|
||
|
||
标准化 JSON 报告包含:
|
||
|
||
```json
|
||
{
|
||
"query_interpretation": {
|
||
"treatment": "treatment",
|
||
"outcome": "health",
|
||
"estimand": "ATE"
|
||
},
|
||
"causal_graph": {
|
||
"nodes": ["age", "base_health", "treatment", "health"],
|
||
"edges": [...],
|
||
"backdoor_paths": ["treatment <- base_health -> health"]
|
||
},
|
||
"identification": {
|
||
"strategy": "Backdoor Adjustment",
|
||
"adjustment_set": ["age", "base_health"],
|
||
"reasoning": "..."
|
||
},
|
||
"estimation": {
|
||
"ATE_Outcome_Regression": 0.2444,
|
||
"ATE_IPW": 0.2425,
|
||
"95%_CI": [0.2351, 0.2544],
|
||
"interpretation": "..."
|
||
},
|
||
"diagnostics": {
|
||
"balance_check": {...},
|
||
"overlap_assumption": "满足",
|
||
"robustness": "稳健"
|
||
},
|
||
"warnings": [...]
|
||
}
|
||
```
|
||
|
||
## 配置选项
|
||
|
||
### 方法 1:使用 `.env` 文件(推荐)
|
||
|
||
项目已包含 `.env.example` 文件,复制为 `.env` 并修改:
|
||
|
||
```bash
|
||
cp .env.example .env
|
||
```
|
||
|
||
编辑 `.env` 文件修改 LLM URL:
|
||
|
||
```bash
|
||
# 只需修改这一行即可切换 LLM 服务器
|
||
LLM_BASE_URL=https://your-new-llm-server.com/v1
|
||
LLM_MODEL=your-model-name
|
||
```
|
||
|
||
**优点**:
|
||
- 切换 URL 只需改一个文件
|
||
- `.env` 不会被提交到 git(已在 .gitignore 中)
|
||
- 所有代码自动读取,无需重启
|
||
|
||
### 方法 2:环境变量
|
||
|
||
```bash
|
||
# Windows PowerShell
|
||
$env:LLM_BASE_URL="https://your-new-llm-server.com/v1"
|
||
$env:LLM_MODEL="qwen3.5-35b"
|
||
|
||
# macOS/Linux
|
||
export LLM_BASE_URL="https://your-new-llm-server.com/v1"
|
||
export LLM_MODEL="qwen3.5-35b"
|
||
export LLM_TEMPERATURE="0.3"
|
||
export LLM_MAX_TOKENS="2048"
|
||
export LLM_API_KEY="your-api-key"
|
||
|
||
# 统计筛查配置
|
||
export CORR_THRESHOLD="0.1"
|
||
export ALPHA="0.05"
|
||
export BOOTSTRAP_ITERATIONS="500"
|
||
|
||
# 路径配置
|
||
export LOG_PATH="causal_analysis.log.md"
|
||
```
|
||
|
||
### 方法 3:Python 代码中配置
|
||
|
||
```python
|
||
from causal_agent.core.config import AgentConfig
|
||
|
||
config = AgentConfig(
|
||
llm_model="gpt-4",
|
||
llm_temperature=0.5,
|
||
corr_threshold=0.15,
|
||
log_path="my_analysis.log"
|
||
)
|
||
agent = CausalInferenceAgent(config)
|
||
```
|
||
|
||
## 开发工具
|
||
|
||
```bash
|
||
# 代码格式化
|
||
uv run black .
|
||
|
||
# 代码检查
|
||
uv run ruff check .
|
||
|
||
# 类型检查
|
||
uv run mypy .
|
||
|
||
# 运行测试
|
||
uv run pytest tests/test_causal_agent/ -v
|
||
|
||
# 运行测试并生成覆盖率报告
|
||
uv run pytest tests/ --cov=. --cov-report=html
|
||
```
|
||
|
||
## 许可证
|
||
|
||
MIT License
|
||
|
||
## 贡献
|
||
|
||
欢迎贡献代码!请遵循以下步骤:
|
||
|
||
1. Fork 本仓库
|
||
2. 创建特性分支 (`git checkout -b feature/AmazingFeature`)
|
||
3. 提交更改 (`git commit -m 'Add some AmazingFeature'`)
|
||
4. 推送到分支 (`git push origin feature/AmazingFeature`)
|
||
5. 开启 Pull Request
|