init
This commit is contained in:
commit
16da68c038
19
.env.example
Normal file
19
.env.example
Normal file
@ -0,0 +1,19 @@
|
||||
# Causal Inference Agent - LLM 配置示例
|
||||
# 复制此文件为 .env 并修改为你自己的配置
|
||||
|
||||
# LLM API 配置(必填)
|
||||
LLM_BASE_URL=https://xxxx/v1
|
||||
LLM_MODEL=qwen3.5-35b
|
||||
|
||||
# 可选配置
|
||||
LLM_TEMPERATURE=0.3
|
||||
LLM_MAX_TOKENS=2048
|
||||
# LLM_API_KEY=your-api-key-here
|
||||
|
||||
# 统计筛查配置(可选)
|
||||
# CORR_THRESHOLD=0.1
|
||||
# ALPHA=0.05
|
||||
# BOOTSTRAP_ITERATIONS=500
|
||||
|
||||
# 日志配置(可选)
|
||||
# LOG_PATH=causal_analysis.log.md
|
||||
159
.gitignore
vendored
Normal file
159
.gitignore
vendored
Normal file
@ -0,0 +1,159 @@
|
||||
causal_analysis.log.md
|
||||
report.json
|
||||
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
.pybuilder/
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# pyenv
|
||||
.python-version
|
||||
|
||||
# pipenv
|
||||
Pipfile.lock
|
||||
|
||||
# PEP 582
|
||||
__pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv/
|
||||
env/
|
||||
env.bak/
|
||||
venv/
|
||||
ENV/
|
||||
env.sh
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# pytype static type analyzer
|
||||
.pytype/
|
||||
|
||||
# Cython debug symbols
|
||||
cython_debug/
|
||||
|
||||
# IDE
|
||||
.idea/
|
||||
.vscode/
|
||||
*.swp
|
||||
*.swo
|
||||
*~
|
||||
|
||||
# OS
|
||||
.DS_Store
|
||||
Thumbs.db
|
||||
|
||||
# Logs
|
||||
*.log
|
||||
logs/
|
||||
|
||||
# Data (optional - adjust based on your needs)
|
||||
data/output/
|
||||
# 保留根目录的 data.xlsx
|
||||
*.xlsx
|
||||
!/data.xlsx
|
||||
*.csv
|
||||
|
||||
# Temporary files
|
||||
*.tmp
|
||||
*.temp
|
||||
*.bak
|
||||
297
README.md
Normal file
297
README.md
Normal file
@ -0,0 +1,297 @@
|
||||
# Causal Inference Agent
|
||||
|
||||
一个领域无关的通用因果推断 Agent,提供从数据加载、变量识别、因果图构建、混杂识别到效应估计的完整 Pipeline。
|
||||
|
||||
## 功能特性
|
||||
|
||||
- **LLM 智能变量识别**:自动识别处理变量 (T)、结果变量 (Y),并对所有变量进行时间层级解析
|
||||
- **快速相关性筛查**:使用 Pearson + Spearman + 互信息筛选潜在混杂因素
|
||||
- **混合智能因果图**:基于时间层级和统计相关性构建因果图
|
||||
- **后门准则自动识别**:发现后门路径并推导最小调整集
|
||||
- **双重效应估计**:支持 Outcome Regression (OR) 和 Inverse Probability Weighting (IPW)
|
||||
- **标准化 JSON 报告**:包含因果图、识别策略、估计结果、诊断信息
|
||||
|
||||
## 快速开始
|
||||
|
||||
### 环境要求
|
||||
|
||||
- Python >= 3.11
|
||||
- uv >= 0.4.0
|
||||
|
||||
### 安装 uv
|
||||
|
||||
```bash
|
||||
# Windows (PowerShell)
|
||||
irm https://astral.sh/uv/install.ps1 | iex
|
||||
|
||||
# macOS/Linux
|
||||
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
```
|
||||
|
||||
### 项目设置
|
||||
|
||||
```bash
|
||||
# 进入项目目录
|
||||
cd CausalInferenceAgent_V2
|
||||
|
||||
# 创建虚拟环境并安装依赖
|
||||
uv venv
|
||||
uv pip install -e .
|
||||
```
|
||||
|
||||
## 使用方式
|
||||
|
||||
### 1. 命令行工具 (CLI)
|
||||
|
||||
```bash
|
||||
# 基础用法
|
||||
uv run python -m causal_agent --data data.xlsx --output report.json --prompt "分析吃药的效果"
|
||||
|
||||
# 自定义 LLM 配置
|
||||
uv run python -m causal_agent --data data.xlsx --model qwen3.5-35b --base-url http://... --output report.json
|
||||
|
||||
# 查看帮助
|
||||
uv run python -m causal_agent --help
|
||||
```
|
||||
|
||||
### 2. Python API
|
||||
|
||||
```python
|
||||
from causal_agent.agent import CausalInferenceAgent
|
||||
from causal_agent.core.config import AgentConfig
|
||||
|
||||
# 使用默认配置
|
||||
agent = CausalInferenceAgent()
|
||||
result = agent.analyze("data.xlsx")
|
||||
|
||||
# 查看报告
|
||||
print(result["report"])
|
||||
```
|
||||
|
||||
### 3. 运行示例
|
||||
|
||||
```bash
|
||||
uv run python -m causal_agent --data data.xlsx --output report.json --prompt "分析吃药的效果"
|
||||
```
|
||||
|
||||
## 项目结构
|
||||
|
||||
```
|
||||
CausalInferenceAgent_V2/
|
||||
├── pyproject.toml # 项目配置
|
||||
├── README.md # 本文件
|
||||
├── uv.lock # uv 锁定文件
|
||||
├── causal_agent/ # 主包
|
||||
│ ├── __init__.py
|
||||
│ ├── agent.py # CausalInferenceAgent
|
||||
│ ├── cli.py # 命令行入口
|
||||
│ ├── logger.py # 日志记录器
|
||||
│ ├── core/ # 核心模块
|
||||
│ │ ├── config.py # AgentConfig 配置
|
||||
│ │ ├── llm_client.py # LLM 客户端
|
||||
│ │ └── data_loader.py # 数据加载器
|
||||
│ └── analysis/ # 分析模块
|
||||
│ ├── variable_parser.py # LLM 变量识别
|
||||
│ ├── screening.py # 相关性筛查
|
||||
│ ├── causal_graph.py # 因果图构建
|
||||
│ ├── estimation.py # 效应估计
|
||||
│ └── reporting.py # 报告生成
|
||||
├── input_processor/ # 输入处理(原有)
|
||||
├── examples/
|
||||
│ └── medical_v2/ # 医疗数据示例
|
||||
├── tests/
|
||||
│ └── test_causal_agent/ # 单元测试
|
||||
└── data/ # 测试数据
|
||||
```
|
||||
|
||||
## 完整 Pipeline 说明
|
||||
|
||||
Agent 执行以下 6 个步骤:
|
||||
|
||||
### Step 1: LLM 变量识别
|
||||
|
||||
调用 LLM 识别:
|
||||
|
||||
- `treatment`: 处理变量
|
||||
- `outcome`: 结果变量
|
||||
- `time_tiers`: 所有变量的时间层级(-1~4+)
|
||||
|
||||
**time_tiers 层级说明:**
|
||||
|
||||
- `-1`: 非时间变量(如 id、index)
|
||||
- `0`: 人口学特征(如 age、gender)
|
||||
- `1`: 基线测量(如 baseline_score)
|
||||
- `2`: 干预点/处理变量
|
||||
- `3`: 中介变量
|
||||
- `4`: 随访结果/结果变量
|
||||
- `5+`: 更晚时间点
|
||||
|
||||
### Step 2: 快速相关性筛查
|
||||
|
||||
使用 Pearson + Spearman 相关系数 + 互信息 (MI):
|
||||
|
||||
- 筛选与 T 和 Y 都显著相关的变量(|r| > 0.1, p < 0.05)
|
||||
- 作为候选混杂因素
|
||||
|
||||
### Step 3: 因果图构建
|
||||
|
||||
基于时间层级构建混合智能因果图:
|
||||
|
||||
- 时间层级低的变量指向层级高的变量
|
||||
- 统计相关性低于阈值的边被过滤
|
||||
- 边类型:`temporal`、`confounding`、`hypothesized`
|
||||
|
||||
### Step 4: 混杂识别(后门准则)
|
||||
|
||||
- 发现所有以 `T <-` 开始的后门路径
|
||||
- 自动推导阻断所有后门路径的最小调整集
|
||||
|
||||
### Step 5: 效应估计
|
||||
|
||||
提供两种估计方法:
|
||||
|
||||
**Outcome Regression (OR)**
|
||||
|
||||
- 线性回归 + G-computation
|
||||
- 直接估计 E[Y|do(T=1)] - E[Y|do(T=0)]
|
||||
|
||||
**Inverse Probability Weighting (IPW)**
|
||||
|
||||
- Logistic 回归估计倾向得分
|
||||
- 截断稳定权重避免极端值
|
||||
|
||||
同时提供:
|
||||
|
||||
- Bootstrap 95% 置信区间
|
||||
- 标准化均值差(SMD)平衡检验
|
||||
- 重叠假设(Positivity)检查
|
||||
- OR/IPW 稳健性判断
|
||||
|
||||
### Step 6: 生成报告
|
||||
|
||||
标准化 JSON 报告包含:
|
||||
|
||||
```json
|
||||
{
|
||||
"query_interpretation": {
|
||||
"treatment": "treatment",
|
||||
"outcome": "health",
|
||||
"estimand": "ATE"
|
||||
},
|
||||
"causal_graph": {
|
||||
"nodes": ["age", "base_health", "treatment", "health"],
|
||||
"edges": [...],
|
||||
"backdoor_paths": ["treatment <- base_health -> health"]
|
||||
},
|
||||
"identification": {
|
||||
"strategy": "Backdoor Adjustment",
|
||||
"adjustment_set": ["age", "base_health"],
|
||||
"reasoning": "..."
|
||||
},
|
||||
"estimation": {
|
||||
"ATE_Outcome_Regression": 0.2444,
|
||||
"ATE_IPW": 0.2425,
|
||||
"95%_CI": [0.2351, 0.2544],
|
||||
"interpretation": "..."
|
||||
},
|
||||
"diagnostics": {
|
||||
"balance_check": {...},
|
||||
"overlap_assumption": "满足",
|
||||
"robustness": "稳健"
|
||||
},
|
||||
"warnings": [...]
|
||||
}
|
||||
```
|
||||
|
||||
## 配置选项
|
||||
|
||||
### 方法 1:使用 `.env` 文件(推荐)
|
||||
|
||||
项目已包含 `.env.example` 文件,复制为 `.env` 并修改:
|
||||
|
||||
```bash
|
||||
cp .env.example .env
|
||||
```
|
||||
|
||||
编辑 `.env` 文件修改 LLM URL:
|
||||
|
||||
```bash
|
||||
# 只需修改这一行即可切换 LLM 服务器
|
||||
LLM_BASE_URL=https://your-new-llm-server.com/v1
|
||||
LLM_MODEL=your-model-name
|
||||
```
|
||||
|
||||
**优点**:
|
||||
- 切换 URL 只需改一个文件
|
||||
- `.env` 不会被提交到 git(已在 .gitignore 中)
|
||||
- 所有代码自动读取,无需重启
|
||||
|
||||
### 方法 2:环境变量
|
||||
|
||||
```bash
|
||||
# Windows PowerShell
|
||||
$env:LLM_BASE_URL="https://your-new-llm-server.com/v1"
|
||||
$env:LLM_MODEL="qwen3.5-35b"
|
||||
|
||||
# macOS/Linux
|
||||
export LLM_BASE_URL="https://your-new-llm-server.com/v1"
|
||||
export LLM_MODEL="qwen3.5-35b"
|
||||
export LLM_TEMPERATURE="0.3"
|
||||
export LLM_MAX_TOKENS="2048"
|
||||
export LLM_API_KEY="your-api-key"
|
||||
|
||||
# 统计筛查配置
|
||||
export CORR_THRESHOLD="0.1"
|
||||
export ALPHA="0.05"
|
||||
export BOOTSTRAP_ITERATIONS="500"
|
||||
|
||||
# 路径配置
|
||||
export LOG_PATH="causal_analysis.log.md"
|
||||
```
|
||||
|
||||
### 方法 3:Python 代码中配置
|
||||
|
||||
```python
|
||||
from causal_agent.core.config import AgentConfig
|
||||
|
||||
config = AgentConfig(
|
||||
llm_model="gpt-4",
|
||||
llm_temperature=0.5,
|
||||
corr_threshold=0.15,
|
||||
log_path="my_analysis.log"
|
||||
)
|
||||
agent = CausalInferenceAgent(config)
|
||||
```
|
||||
|
||||
## 开发工具
|
||||
|
||||
```bash
|
||||
# 代码格式化
|
||||
uv run black .
|
||||
|
||||
# 代码检查
|
||||
uv run ruff check .
|
||||
|
||||
# 类型检查
|
||||
uv run mypy .
|
||||
|
||||
# 运行测试
|
||||
uv run pytest tests/test_causal_agent/ -v
|
||||
|
||||
# 运行测试并生成覆盖率报告
|
||||
uv run pytest tests/ --cov=. --cov-report=html
|
||||
```
|
||||
|
||||
## 许可证
|
||||
|
||||
MIT License
|
||||
|
||||
## 贡献
|
||||
|
||||
欢迎贡献代码!请遵循以下步骤:
|
||||
|
||||
1. Fork 本仓库
|
||||
2. 创建特性分支 (`git checkout -b feature/AmazingFeature`)
|
||||
3. 提交更改 (`git commit -m 'Add some AmazingFeature'`)
|
||||
4. 推送到分支 (`git push origin feature/AmazingFeature`)
|
||||
5. 开启 Pull Request
|
||||
12
causal_agent/__init__.py
Normal file
12
causal_agent/__init__.py
Normal file
@ -0,0 +1,12 @@
|
||||
"""
|
||||
Causal Inference Agent
|
||||
|
||||
一个领域无关的通用因果推断 Agent,提供从数据加载、变量识别、
|
||||
因果图构建、混杂识别到效应估计的完整 pipeline。
|
||||
"""
|
||||
|
||||
from .agent import CausalInferenceAgent
|
||||
from .core.config import AgentConfig
|
||||
|
||||
__version__ = "0.2.0"
|
||||
__all__ = ["CausalInferenceAgent", "AgentConfig"]
|
||||
6
causal_agent/__main__.py
Normal file
6
causal_agent/__main__.py
Normal file
@ -0,0 +1,6 @@
|
||||
"""Allow running causal_agent as a module: python -m causal_agent"""
|
||||
|
||||
from causal_agent.cli import main
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
186
causal_agent/agent.py
Normal file
186
causal_agent/agent.py
Normal file
@ -0,0 +1,186 @@
|
||||
"""
|
||||
Causal Inference Agent Orchestrator
|
||||
|
||||
整合完整 pipeline 的核心代理类。
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from causal_agent.analysis.causal_graph import (
|
||||
build_local_graph,
|
||||
find_adjustment_set,
|
||||
find_backdoor_paths,
|
||||
)
|
||||
from causal_agent.analysis.estimation import estimate_ate
|
||||
from causal_agent.analysis.reporting import generate_report, print_report
|
||||
from causal_agent.analysis.screening import local_screen
|
||||
from causal_agent.analysis.variable_parser import VariableParser
|
||||
from causal_agent.core.config import AgentConfig
|
||||
from causal_agent.core.data_loader import DataLoader, get_data_info
|
||||
from causal_agent.core.llm_client import LLMClient
|
||||
from causal_agent.logger import CausalAnalysisLogger
|
||||
|
||||
|
||||
class CausalInferenceAgent:
    """General-purpose causal inference agent.

    Orchestrates the full 6-step pipeline: variable identification (LLM),
    correlation screening, causal-graph construction, backdoor
    identification, ATE estimation, and report generation.
    """

    def __init__(self, config: Optional[AgentConfig] = None):
        """Wire up all collaborators from the given (or env-derived) config.

        Args:
            config: optional AgentConfig; falls back to AgentConfig.from_env().
        """
        self.config = config or AgentConfig.from_env()
        self.data_loader = DataLoader()
        # LLM client is configured entirely from AgentConfig fields.
        self.llm_client = LLMClient(
            base_url=self.config.llm_base_url,
            model=self.config.llm_model,
            temperature=self.config.llm_temperature,
            max_tokens=self.config.llm_max_tokens,
            api_key=self.config.llm_api_key,
            timeout=self.config.llm_timeout,
            max_retries=self.config.llm_max_retries,
        )
        self.variable_parser = VariableParser(self.llm_client)
        self.logger = CausalAnalysisLogger(self.config.log_path)

    def analyze(
        self, data_path: str, custom_prompt: Optional[str] = None
    ) -> Dict[str, Any]:
        """Run the complete causal analysis pipeline.

        Args:
            data_path: path to the data file.
            custom_prompt: optional extra user prompt passed to the LLM parser.

        Returns:
            Dict with the full report, run parameters, and an analysis id.
        """
        df = self.data_loader.load(data_path)
        columns = list(df.columns)
        data_info = get_data_info(df)

        print(f"成功加载数据:{data_path}")
        print(f"数据形状:{df.shape}")
        print(f"列名:{columns}")

        # Step 1: LLM-based identification of treatment/outcome/time tiers.
        print("\n[Step 1] LLM 变量识别...")
        parsed = self.variable_parser.parse(columns, data_info, custom_prompt)
        T = parsed["treatment"]
        Y = parsed["outcome"]
        tiers = parsed["time_tiers"]
        print(f"处理变量:{T},结果变量:{Y}")
        print(f"时间层级:{tiers}")

        # Step 2: quick correlation screen for candidate confounders.
        print("\n[Step 2] 快速相关性筛查...")
        candidates = local_screen(
            df,
            T=T,
            Y=Y,
            excluded=[c for c in columns if tiers.get(c, 0) < 0],  # drop id-like (tier -1) columns
            corr_threshold=self.config.corr_threshold,
            alpha=self.config.alpha,
        )
        print(f"候选混杂变量:{[c['var'] for c in candidates]}")

        # Step 3: build the hybrid (temporal + statistical) causal graph.
        print("\n[Step 3] 因果图构建...")
        G = build_local_graph(
            df, T, Y, candidates, tiers, corr_threshold=self.config.corr_threshold
        )
        edges = [
            {"from": u, "to": v, "type": data.get("type", "unknown")}
            for u, v, data in G.edges(data=True)
        ]
        nodes = list(G.nodes())
        print(f"图节点:{nodes}")
        print(f"图边:{edges}")

        # Step 4: backdoor-path discovery and minimal adjustment set.
        print("\n[Step 4] 后门路径识别...")
        backdoor_paths = find_backdoor_paths(G, T, Y)
        print(f"后门路径:{backdoor_paths}")

        adjustment_set, reasoning = find_adjustment_set(G, T, Y, backdoor_paths)
        print(f"调整集:{adjustment_set}")
        print(f"调整理由:{reasoning}")

        # Step 5: ATE estimation (OR + IPW) with bootstrap CI.
        print("\n[Step 5] 效应估计...")
        estimation_result = estimate_ate(
            df,
            T,
            Y,
            adjustment_set,
            compute_ci=True,
            n_bootstrap=self.config.bootstrap_iterations,
            alpha=self.config.bootstrap_alpha,
        )
        print(f"ATE (OR): {estimation_result['ATE_Outcome_Regression']}")
        print(f"ATE (IPW): {estimation_result['ATE_IPW']}")
        print(f"95% CI: {estimation_result['95%_CI']}")

        # Step 6: assemble and print the standardized JSON report.
        print("\n[Step 6] 生成报告...")
        causal_graph_info = {
            "nodes": nodes,
            "edges": edges,
            "backdoor_paths": backdoor_paths,
        }
        identification_info = {
            "strategy": "Backdoor Adjustment",
            "adjustment_set": adjustment_set,
            "reasoning": reasoning,
        }
        query_interpretation = {
            "treatment": T,
            "outcome": Y,
            "estimand": "ATE",
        }

        report = generate_report(
            query_interpretation=query_interpretation,
            causal_graph=causal_graph_info,
            identification=identification_info,
            estimation=estimation_result,
        )
        print_report(report)

        # Persist an audit-log entry with parameters and intermediate results.
        parameters = {
            "data_path": data_path,
            "sample_size": len(df),
            "variables": columns,
            "treatment_variable": T,
            "outcome_variable": Y,
            "time_tiers": tiers,
            "llm_params": {
                "base_url": self.config.llm_base_url,
                "model": self.config.llm_model,
                "temperature": self.config.llm_temperature,
                "max_tokens": self.config.llm_max_tokens,
            },
            "candidates": candidates,
            "causal_graph": causal_graph_info,
            "identification": identification_info,
            # Bulky per-variable diagnostics are kept out of the log entry.
            "estimation": {
                k: v
                for k, v in estimation_result.items()
                if k not in ("warnings", "balance_check")
            },
            "log_path": self.logger.log_path,
        }

        self.logger.log_analysis(
            VariableParser.SYSTEM_PROMPT,
            self.variable_parser._build_user_prompt(columns, data_info, custom_prompt),
            str(parsed),
            parameters,
            report,
        )

        return {
            # Zero-padded sequential id from the logger's analysis counter.
            "id": f"{self.logger.analysis_count:03d}",
            "report": report,
            "parameters": parameters,
        }
|
||||
18
causal_agent/analysis/__init__.py
Normal file
18
causal_agent/analysis/__init__.py
Normal file
@ -0,0 +1,18 @@
|
||||
"""Analysis modules for Causal Inference Agent."""
|
||||
|
||||
from .screening import local_screen
|
||||
from .causal_graph import build_local_graph, find_backdoor_paths, find_adjustment_set
|
||||
from .estimation import estimate_ate
|
||||
from .reporting import generate_report, print_report
|
||||
from .variable_parser import VariableParser
|
||||
|
||||
__all__ = [
|
||||
"local_screen",
|
||||
"build_local_graph",
|
||||
"find_backdoor_paths",
|
||||
"find_adjustment_set",
|
||||
"estimate_ate",
|
||||
"generate_report",
|
||||
"print_report",
|
||||
"VariableParser",
|
||||
]
|
||||
134
causal_agent/analysis/causal_graph.py
Normal file
134
causal_agent/analysis/causal_graph.py
Normal file
@ -0,0 +1,134 @@
|
||||
"""
|
||||
因果图构建与后门准则模块
|
||||
|
||||
功能:
|
||||
1. 基于时间层级和统计相关性构建混合智能因果图
|
||||
2. 自动发现后门路径
|
||||
3. 基于后门准则寻找最小调整集
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import networkx as nx
|
||||
import pandas as pd
|
||||
from scipy.stats import spearmanr
|
||||
|
||||
|
||||
def build_local_graph(
    df: pd.DataFrame,
    T: str,
    Y: str,
    candidates: List[Dict[str, Any]],
    tiers: Dict[str, int],
    corr_threshold: float = 0.1,
) -> nx.DiGraph:
    """Build the local causal graph over T, Y and the candidate confounders.

    Edges follow temporal ordering (lower tier -> higher tier) and are kept
    only when the Spearman correlation exceeds ``corr_threshold``. A
    hypothesized T -> Y edge encoding the research question is always added.

    Args:
        df: input data frame.
        T: treatment variable name.
        Y: outcome variable name.
        candidates: candidate confounder records (each with a "var" key).
        tiers: time-tier mapping per variable.
        corr_threshold: absolute-correlation cutoff for keeping an edge.

    Returns:
        A ``networkx.DiGraph`` with typed edges.
    """
    graph = nx.DiGraph()
    members = [T, Y] + [entry["var"] for entry in candidates]
    graph.add_nodes_from(members)

    for src in members:
        for dst in members:
            # Only earlier-tier variables may cause later-tier ones.
            if src == dst or tiers.get(src, -1) >= tiers.get(dst, -1):
                continue
            rho, _ = spearmanr(df[src], df[dst])
            if abs(rho) > corr_threshold:
                kind = "confounding" if dst in [T, Y] else "temporal"
                graph.add_edge(src, dst, type=kind, confidence=1.0)

    # The research question itself: a hypothesized direct effect.
    graph.add_edge(T, Y, type="hypothesized", confidence="research_question")
    return graph
|
||||
|
||||
|
||||
def find_backdoor_paths(G: nx.DiGraph, T: str, Y: str) -> List[str]:
    """Enumerate all backdoor paths from T to Y (undirected, cutoff 5).

    A backdoor path is one whose first step enters T via an arrow pointing
    into the treatment. Returns formatted path strings, de-duplicated while
    preserving discovery order.
    """
    skeleton = G.to_undirected()
    try:
        raw = list(nx.all_simple_paths(skeleton, source=T, target=Y, cutoff=5))
    except nx.NodeNotFound:
        raw = []

    formatted = []
    for nodes in raw:
        if len(nodes) < 2:
            continue
        # Backdoor criterion: the path must start with an edge INTO T.
        neighbor = nodes[1]
        if G.has_edge(neighbor, T):
            formatted.append(_format_path(G, nodes))

    # Order-preserving de-duplication.
    seen = set()
    distinct = []
    for candidate in formatted:
        if candidate in seen:
            continue
        seen.add(candidate)
        distinct.append(candidate)
    return distinct
|
||||
|
||||
|
||||
def _format_path(G: nx.DiGraph, path: List[str]) -> str:
|
||||
"""将无向路径格式化为带方向箭头的字符串"""
|
||||
parts = [path[0]]
|
||||
for i in range(len(path) - 1):
|
||||
u, v = path[i], path[i + 1]
|
||||
if G.has_edge(u, v) and G.has_edge(v, u):
|
||||
parts.append(f"<-> {v}")
|
||||
elif G.has_edge(u, v):
|
||||
parts.append(f"-> {v}")
|
||||
elif G.has_edge(v, u):
|
||||
parts.append(f"<- {v}")
|
||||
else:
|
||||
parts.append(f"-- {v}")
|
||||
return " ".join(parts)
|
||||
|
||||
|
||||
def find_adjustment_set(
    G: nx.DiGraph, T: str, Y: str, backdoor_paths: Optional[List[str]] = None
) -> Tuple[List[str], str]:
    """Derive a minimal adjustment set via the backdoor criterion.

    Blocks each backdoor path by conditioning on its first node after T.
    Returns the sorted adjustment set and a human-readable rationale.
    """
    paths = find_backdoor_paths(G, T, Y) if backdoor_paths is None else backdoor_paths

    if not paths:
        return [], "未发现后门路径,无需额外调整变量即可识别因果效应。"

    chosen = set()
    for formatted in paths:
        parsed = _parse_path_nodes(formatted)
        # nodes[1] is the first variable after T on this backdoor path.
        if len(parsed) >= 2:
            chosen.add(parsed[1])

    adjustment = sorted(chosen)
    explanation = (
        f"发现 {len(paths)} 条后门路径。"
        f"通过控制变量 {adjustment} 可阻断所有后门路径,满足后门准则。"
    )
    return adjustment, explanation
|
||||
|
||||
|
||||
def _parse_path_nodes(path_str: str) -> List[str]:
|
||||
"""从格式化路径字符串中提取节点列表"""
|
||||
cleaned = (
|
||||
path_str.replace("->", " ")
|
||||
.replace("<-", " ")
|
||||
.replace("<->", " ")
|
||||
.replace("--", " ")
|
||||
)
|
||||
nodes = [n.strip() for n in cleaned.split() if n.strip()]
|
||||
return nodes
|
||||
203
causal_agent/analysis/estimation.py
Normal file
203
causal_agent/analysis/estimation.py
Normal file
@ -0,0 +1,203 @@
|
||||
"""
|
||||
效应估计模块
|
||||
|
||||
提供 Outcome Regression (OR) 和 Inverse Probability Weighting (IPW) 两种估计方法。
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from sklearn.linear_model import LinearRegression, LogisticRegression
|
||||
|
||||
|
||||
def _compute_smd(df: pd.DataFrame, var: str, T: str) -> float:
|
||||
"""计算标准化均值差(SMD)"""
|
||||
treated = df[df[T] == 1][var]
|
||||
control = df[df[T] == 0][var]
|
||||
pooled_std = np.sqrt((treated.var() + control.var()) / 2)
|
||||
if pooled_std == 0:
|
||||
return 0.0
|
||||
return float((treated.mean() - control.mean()) / pooled_std)
|
||||
|
||||
|
||||
def _bootstrap_ci(
    df: pd.DataFrame,
    T: str,
    Y: str,
    adjustment_set: List[str],
    estimator: str,
    n_bootstrap: int = 500,
    alpha: float = 0.05,
) -> List[float]:
    """Percentile bootstrap confidence interval for the chosen estimator.

    Resamples rows with replacement, re-estimates the ATE each time, and
    takes the alpha/2 and 1-alpha/2 percentiles of the successful draws.
    """
    size = len(df)
    draws = []
    for _ in range(n_bootstrap):
        resampled = df.sample(n=size, replace=True, random_state=None)
        try:
            if estimator == "OR":
                draws.append(_estimate_or(resampled, T, Y, adjustment_set))
            elif estimator == "IPW":
                draws.append(_estimate_ipw(resampled, T, Y, adjustment_set))
            else:
                draws.append(np.nan)
        except Exception:
            # A degenerate resample (e.g. one-class treatment) is recorded
            # as NaN and filtered out below.
            draws.append(np.nan)

    finite = np.array(draws)
    finite = finite[~np.isnan(finite)]
    if len(finite) == 0:
        return [np.nan, np.nan]

    lo = float(np.percentile(finite, 100 * alpha / 2))
    hi = float(np.percentile(finite, 100 * (1 - alpha / 2)))
    return [round(lo, 4), round(hi, 4)]
|
||||
|
||||
|
||||
def _estimate_or(df: pd.DataFrame, T: str, Y: str, adjustment_set: List[str]) -> float:
    """ATE via outcome regression (linear model + G-computation)."""
    feature_cols = adjustment_set + [T]
    fitted = LinearRegression().fit(df[feature_cols].values, df[Y].values)

    # Counterfactual frames: everyone treated vs. nobody treated.
    everyone = df.copy()
    everyone[T] = 1
    nobody = df.copy()
    nobody[T] = 0

    mean_treated = fitted.predict(everyone[feature_cols].values).mean()
    mean_control = fitted.predict(nobody[feature_cols].values).mean()
    return float(mean_treated - mean_control)
|
||||
|
||||
|
||||
def _estimate_ipw(df: pd.DataFrame, T: str, Y: str, adjustment_set: List[str]) -> float:
    """ATE via inverse probability weighting with clipped propensity scores."""
    treat = df[T].values
    if len(adjustment_set) == 0:
        # No confounders: the propensity is the marginal treatment rate.
        propensity = np.full_like(treat, fill_value=treat.mean(), dtype=float)
    else:
        features = df[adjustment_set].values
        classifier = LogisticRegression(max_iter=1000, solver="lbfgs")
        classifier.fit(features, treat)
        propensity = classifier.predict_proba(features)[:, 1]

    # Clip to [0.05, 0.95] to avoid extreme weights near 0/1 propensities.
    bounded = np.clip(propensity, 0.05, 0.95)
    treated_weights = treat / bounded
    control_weights = (1 - treat) / (1 - bounded)

    outcome = df[Y].values
    return float(np.mean(treated_weights * outcome) - np.mean(control_weights * outcome))
|
||||
|
||||
|
||||
def estimate_ate(
    df: pd.DataFrame,
    T: str,
    Y: str,
    adjustment_set: List[str],
    compute_ci: bool = True,
    n_bootstrap: int = 500,
    alpha: float = 0.05,
) -> Dict[str, Any]:
    """Estimate the Average Treatment Effect (ATE).

    Runs both Outcome Regression (OR) and Inverse Probability Weighting
    (IPW), checks overlap (positivity) and covariate balance, and optionally
    computes bootstrap confidence intervals.

    Args:
        df: input data frame.
        T: treatment variable name.
        Y: outcome variable name.
        adjustment_set: variables to adjust for.
        compute_ci: whether to compute bootstrap CIs.
        n_bootstrap: number of bootstrap resamples.
        alpha: significance level for the CI.

    Returns:
        Dict with ATE_Outcome_Regression, ATE_IPW, CI, balance and
        robustness diagnostics, and any warnings.
    """
    warnings_list: List[Dict[str, str]] = []

    ate_or = _estimate_or(df, T, Y, adjustment_set)

    # Fit a propensity model (or fall back to the marginal treatment rate)
    # for the positivity check and the post-weighting balance diagnostics.
    t = df[T].values
    if len(adjustment_set) == 0:
        ps = np.full_like(t, fill_value=t.mean(), dtype=float)
    else:
        X = df[adjustment_set].values
        ps_model = LogisticRegression(max_iter=1000, solver="lbfgs")
        ps_model.fit(X, t)
        ps = ps_model.predict_proba(X)[:, 1]

    # Positivity / overlap check: extreme propensities (<0.05 or >0.95)
    # indicate units that are near-deterministically (un)treated.
    overlap_ok = True
    if np.any(ps < 0.05) or np.any(ps > 0.95):
        overlap_ok = False
        warnings_list.append(
            {
                "type": "positivity_violation",
                "message": "倾向得分存在极端值(<0.05 或 >0.95),存在重叠假设违反风险。",
            }
        )

    ate_ipw = _estimate_ipw(df, T, Y, adjustment_set)

    # Robustness heuristic: OR and IPW disagreeing by >10% (relative to
    # either estimate) suggests model misspecification.
    diff = abs(ate_or - ate_ipw)
    robustness = "稳健"
    if ate_or != 0 and diff / abs(ate_or) > 0.1:
        robustness = "OR 与 IPW 估计差异 >10%,可能存在模型误设,建议进一步检查"
    elif ate_ipw != 0 and diff / abs(ate_ipw) > 0.1:
        robustness = "OR 与 IPW 估计差异 >10%,可能存在模型误设,建议进一步检查"

    # Per-covariate SMD before vs. after IPW weighting.
    balance_check: Dict[str, Dict[str, float]] = {}
    for var in adjustment_set:
        smd_before = _compute_smd(df, var, T)
        ps_clipped = np.clip(ps, 0.05, 0.95)
        weights = np.where(df[T].values == 1, 1.0 / ps_clipped, 1.0 / (1 - ps_clipped))
        treated_idx = df[T].values == 1
        control_idx = df[T].values == 0
        weighted_mean_t = np.average(
            df.loc[treated_idx, var].values, weights=weights[treated_idx]
        )
        weighted_mean_c = np.average(
            df.loc[control_idx, var].values, weights=weights[control_idx]
        )
        pooled_std = np.sqrt(
            (df.loc[treated_idx, var].var() + df.loc[control_idx, var].var()) / 2
        )
        smd_after = (
            float((weighted_mean_t - weighted_mean_c) / pooled_std)
            if pooled_std > 0
            else 0.0
        )
        balance_check[var] = {
            "before": round(smd_before, 4),
            "after": round(smd_after, 4),
        }

    # Bootstrap CIs for both estimators (NaN pair when disabled/failed).
    ci_or: List[float] = [np.nan, np.nan]
    ci_ipw: List[float] = [np.nan, np.nan]
    if compute_ci:
        ci_or = _bootstrap_ci(df, T, Y, adjustment_set, "OR", n_bootstrap, alpha)
        ci_ipw = _bootstrap_ci(df, T, Y, adjustment_set, "IPW", n_bootstrap, alpha)

    # Report the average of the two point estimates; prefer the OR CI,
    # falling back to the IPW CI if the OR bootstrap failed entirely.
    ate_report = round((ate_or + ate_ipw) / 2, 4)
    ci_report = ci_or if not np.isnan(ci_or[0]) else ci_ipw

    interpretation = (
        f"在控制 {adjustment_set} 后,接受处理使 {Y} "
        f"平均变化 {ate_report:.4f} (95%CI: {ci_report[0]:.2f}-{ci_report[1]:.2f})。"
    )

    return {
        "ATE_Outcome_Regression": round(ate_or, 4),
        "ATE_IPW": round(ate_ipw, 4),
        "ATE_reported": ate_report,
        "95%_CI": ci_report,
        "interpretation": interpretation,
        "balance_check": balance_check,
        "overlap_assumption": "满足" if overlap_ok else "存在风险",
        "robustness": robustness,
        "warnings": warnings_list,
    }
|
||||
59
causal_agent/analysis/reporting.py
Normal file
59
causal_agent/analysis/reporting.py
Normal file
@ -0,0 +1,59 @@
|
||||
"""
|
||||
报告生成模块
|
||||
|
||||
将因果分析各阶段结果整合为统一的 JSON 报告。
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
|
||||
def generate_report(
    query_interpretation: Dict[str, str],
    causal_graph: Dict[str, Any],
    identification: Dict[str, Any],
    estimation: Dict[str, Any],
    extra_warnings: Optional[List[Dict[str, str]]] = None,
) -> Dict[str, Any]:
    """Assemble the standardized JSON report from the pipeline stage outputs."""
    collected: List[Dict[str, str]] = []
    if extra_warnings:
        collected.extend(extra_warnings)
    collected.extend(estimation.get("warnings", []))

    # Always flag the possibility of unobserved confounding.
    collected.append(
        {
            "type": "unobserved_confounding",
            "message": "可能存在未观测混杂,建议进行敏感性分析。",
        }
    )

    return {
        "query_interpretation": query_interpretation,
        "causal_graph": causal_graph,
        "identification": {
            "strategy": identification.get("strategy", "Backdoor Adjustment"),
            "adjustment_set": identification.get("adjustment_set", []),
            "reasoning": identification.get("reasoning", ""),
        },
        "estimation": {
            "ATE_Outcome_Regression": estimation.get("ATE_Outcome_Regression"),
            "ATE_IPW": estimation.get("ATE_IPW"),
            "95%_CI": estimation.get("95%_CI"),
            "interpretation": estimation.get("interpretation", ""),
        },
        "diagnostics": {
            "balance_check": estimation.get("balance_check", {}),
            "overlap_assumption": estimation.get("overlap_assumption", "未知"),
            "robustness": estimation.get("robustness", "未知"),
        },
        "warnings": collected,
    }
|
||||
|
||||
|
||||
def print_report(report: Dict[str, Any]) -> None:
    """Pretty-print the report (banner + indented JSON) to stdout."""
    rule = "=" * 60
    print("\n" + rule)
    print("因果推断分析报告")
    print(rule)
    # ensure_ascii=False keeps CJK text readable in the console output.
    print(json.dumps(report, indent=2, ensure_ascii=False))
    print(rule)
||||
83
causal_agent/analysis/screening.py
Normal file
83
causal_agent/analysis/screening.py
Normal file
@ -0,0 +1,83 @@
|
||||
"""
|
||||
快速相关性筛查模块
|
||||
|
||||
使用 Pearson + Spearman 方法快速过滤与处理变量 (T) 和结果变量 (Y) 都相关的变量。
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from scipy.stats import pearsonr, spearmanr
|
||||
from sklearn.feature_selection import mutual_info_regression
|
||||
|
||||
|
||||
def local_screen(
    df: pd.DataFrame,
    T: str,
    Y: str,
    excluded: Optional[List[str]] = None,
    corr_threshold: float = 0.1,
    alpha: float = 0.05,
) -> List[Dict[str, Any]]:
    """Fast correlation screen: variables significantly correlated with BOTH T and Y.

    Args:
        df: input data frame.
        T: treatment variable name.
        Y: outcome variable name.
        excluded: variables to skip entirely.
        corr_threshold: absolute-correlation threshold.
        alpha: significance level.

    Returns:
        Candidate confounder records, strongest combined Spearman
        association first. Each record carries Pearson/Spearman/MI stats.
    """
    if excluded is None:
        excluded = []

    candidates = []
    cols = [c for c in df.columns if c not in [T, Y] + excluded]

    for col in cols:
        # Skip mostly-missing or non-numeric columns.
        if df[col].isna().mean() > 0.5:
            continue
        if not pd.api.types.is_numeric_dtype(df[col]):
            continue

        # Pearson stats are reported but NOT used for selection.
        p_t, pv_t_pearson = pearsonr(df[col], df[T])
        p_y, pv_y_pearson = pearsonr(df[col], df[Y])

        s_t, pv_t_spear = spearmanr(df[col], df[T])
        s_y, pv_y_spear = spearmanr(df[col], df[Y])

        # Selection uses the rank-robust Spearman statistics only.
        cond_t = abs(s_t) > corr_threshold and pv_t_spear < alpha
        cond_y = abs(s_y) > corr_threshold and pv_y_spear < alpha

        if cond_t and cond_y:
            # Mutual information is computed only for accepted candidates
            # (it is the most expensive statistic here).
            mi_t = _compute_mi(df[[col]].values, df[T].values)
            mi_y = _compute_mi(df[[col]].values, df[Y].values)

            candidates.append(
                {
                    "var": col,
                    "pearson_T": round(float(p_t), 4),
                    "pearson_Y": round(float(p_y), 4),
                    "spearman_T": round(float(s_t), 4),
                    "spearman_Y": round(float(s_y), 4),
                    "pvalue_T": round(float(pv_t_spear), 4),
                    "pvalue_Y": round(float(pv_y_spear), 4),
                    "mi_T": round(float(mi_t), 4),
                    "mi_Y": round(float(mi_y), 4),
                }
            )

    # Strongest combined association first.
    candidates.sort(key=lambda x: abs(x["spearman_T"]) + abs(x["spearman_Y"]), reverse=True)
    return candidates
|
||||
|
||||
|
||||
def _compute_mi(X: np.ndarray, y: np.ndarray) -> float:
    """Estimate the mutual information between one feature and a target.

    Wraps sklearn's mutual_info_regression with a fixed random_state so
    repeated calls on the same data yield the same estimate.
    """
    scores = mutual_info_regression(X, y, random_state=42)
    return float(scores[0])
|
||||
133
causal_agent/analysis/variable_parser.py
Normal file
133
causal_agent/analysis/variable_parser.py
Normal file
@ -0,0 +1,133 @@
|
||||
"""
|
||||
变量解析模块
|
||||
|
||||
使用 LLM 识别处理变量、结果变量,并对每个变量进行时间层级解析。
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from causal_agent.core.llm_client import LLMClient
|
||||
|
||||
|
||||
class VariableParser:
    """LLM-backed variable parser.

    Asks an LLM to identify the treatment and outcome variables of a
    dataset and to assign every column an integer temporal tier, then
    validates the response against the actual column names.
    """

    # System prompt sent verbatim to the LLM. It is runtime data, not
    # documentation, so it is kept in its original (Chinese) form.
    SYSTEM_PROMPT = """你是一位专业的因果推断分析师。你的任务是分析给定的数据,识别处理变量(treatment)、结果变量(outcome),并对每个变量进行时间层级解析。

请以 JSON 格式输出分析结果,不要包含任何额外的解释或思考过程。

JSON 输出规范:
{
    "treatment": "处理变量名称",
    "outcome": "结果变量名称",
    "time_tiers": {
        "变量名1": 整数层级,
        "变量名2": 整数层级,
        ...
    }
}

time_tiers 层级说明(整数,越小表示越早发生):
- -1: 非时间变量(如样本唯一标识符 id、index 等)
- 0: 人口学特征或不变的混杂因素(如 age、gender、region 等)
- 1: 基线测量(干预前测得,可能是混杂因素,如 baseline_score、pre_test 等)
- 2: 干预点/处理变量(如 treatment、intervention、policy 等)
- 3: 中介变量(干预后、结果前测得)
- 4: 随访结果/结果变量(如 outcome、post_test、score 等)
- 5+: 更晚的时间点(如有多次随访)

注意:
- 只输出上述 JSON 格式,不要包含其他字段
- treatment 和 outcome 必须是数据表格中真实存在的列名
- time_tiers 必须包含数据中的所有列名
- 不要使用 markdown 代码块标记(如 ```json)
- 直接输出纯 JSON 字符串"""

    def __init__(self, llm_client: Optional[LLMClient] = None):
        # Reuse the caller-supplied client, or build one from env config.
        self.llm_client = llm_client or LLMClient()

    def parse(
        self,
        columns: List[str],
        data_info: str,
        custom_prompt: Optional[str] = None,
    ) -> Dict[str, Any]:
        """
        Ask the LLM to parse the variables of a dataset.

        Args:
            columns: List of column names in the data.
            data_info: Data summary string (see get_data_info).
            custom_prompt: Optional extra user instructions.

        Returns:
            Dict with "treatment", "outcome" and "time_tiers" keys.

        Raises:
            RuntimeError: The LLM call itself failed.
            ValueError: The LLM output could not be parsed as JSON, or
                failed validation against the column names.
        """
        user_prompt = self._build_user_prompt(columns, data_info, custom_prompt)
        result = self.llm_client.generate_response(self.SYSTEM_PROMPT, user_prompt)

        if not result.get("success"):
            raise RuntimeError(f"LLM 调用失败: {result.get('error')}")

        try:
            parsed = LLMClient.parse_json_response(result["content"])
        except Exception as e:
            # Include the raw content so the failure is debuggable.
            raise ValueError(f"LLM 输出解析失败: {e}\n原始内容: {result.get('content')}")

        self._validate(parsed, columns)
        return parsed

    def _build_user_prompt(
        self,
        columns: List[str],
        data_info: str,
        custom_prompt: Optional[str] = None,
    ) -> str:
        """Assemble the user prompt from the data summary and format spec."""
        parts = [
            "请分析以下数据,并严格按照 JSON 格式输出分析结果:",
            "",
            data_info,
            "",
            "JSON 输出格式要求:",
            '{',
            ' "treatment": "处理变量名称",',
            ' "outcome": "结果变量名称",',
            ' "time_tiers": {',
            ' "列名1": 层级整数,',
            ' "列名2": 层级整数,',
            " ...",
            " }",
            "}",
            "",
            "要求:",
            "1. treatment 和 outcome 必须与表格列名完全一致",
            "2. time_tiers 必须覆盖所有列名",
            "3. 根据列名含义和统计摘要推断每个变量的时间层级",
            "4. 只输出 JSON,不要包含其他任何内容",
            "5. 不要使用 markdown 代码块标记",
        ]
        if custom_prompt:
            # Insert the user's extra instructions right after the opener.
            parts.insert(1, f"\n用户补充说明:{custom_prompt}\n")
        return "\n".join(parts)

    def _validate(self, parsed: Dict[str, Any], columns: List[str]) -> None:
        """Validate the LLM response against the real column names.

        Raises ValueError when treatment/outcome are not actual columns,
        when time_tiers is not a dict, misses columns, or contains
        non-integer tier values.
        """
        treatment = parsed.get("treatment")
        outcome = parsed.get("outcome")
        tiers = parsed.get("time_tiers", {})

        if treatment not in columns:
            raise ValueError(f"LLM 返回的 treatment '{treatment}' 不在数据列中")
        if outcome not in columns:
            raise ValueError(f"LLM 返回的 outcome '{outcome}' 不在数据列中")
        if not isinstance(tiers, dict):
            raise ValueError("time_tiers 必须是字典类型")

        missing = set(columns) - set(tiers.keys())
        if missing:
            raise ValueError(f"time_tiers 缺少以下列: {sorted(missing)}")

        for col, tier in tiers.items():
            if not isinstance(tier, int):
                raise ValueError(f"time_tiers 中 '{col}' 的层级必须是整数")
|
||||
67
causal_agent/cli.py
Normal file
67
causal_agent/cli.py
Normal file
@ -0,0 +1,67 @@
|
||||
"""
|
||||
Causal Inference Agent CLI
|
||||
|
||||
命令行入口,支持直接对数据文件执行因果推断分析。
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
|
||||
from causal_agent.agent import CausalInferenceAgent
|
||||
from causal_agent.core.config import AgentConfig
|
||||
|
||||
|
||||
def main():
    """CLI entry point: parse arguments, run the causal analysis, and
    optionally persist the resulting report as a JSON file."""
    parser = argparse.ArgumentParser(
        description="通用因果推断 Agent",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
示例:
  python -m causal_agent --data data.csv --output report.json
  python -m causal_agent --data data.xlsx --model qwen3.5-35b
  python -m causal_agent --data data.csv --prompt "分析教育干预对成绩的影响"
""",
    )
    parser.add_argument("--data", "-d", required=True, help="数据文件路径")
    parser.add_argument("--output", "-o", help="报告输出 JSON 文件路径")
    parser.add_argument("--prompt", "-p", help="自定义分析提示词")
    parser.add_argument("--base-url", help="LLM API Base URL")
    parser.add_argument("--model", help="LLM 模型名称")
    parser.add_argument("--temperature", type=float, help="LLM 温度参数")
    parser.add_argument("--log-path", help="日志文件路径")
    parser.add_argument("--corr-threshold", type=float, help="相关性筛查阈值")

    args = parser.parse_args()

    # Fail fast when the input file is missing.
    if not os.path.exists(args.data):
        print(f"错误:数据文件不存在:{args.data}", file=sys.stderr)
        sys.exit(1)

    # Build the configuration: environment first, then CLI overrides.
    config = AgentConfig.from_env()
    # String options: applied only when non-empty (same semantics as the
    # original truthiness checks).
    for attr, value in (
        ("llm_base_url", args.base_url),
        ("llm_model", args.model),
        ("log_path", args.log_path),
    ):
        if value:
            setattr(config, attr, value)
    # Numeric options: 0 / 0.0 are legitimate, so compare against None.
    for attr, value in (
        ("llm_temperature", args.temperature),
        ("corr_threshold", args.corr_threshold),
    ):
        if value is not None:
            setattr(config, attr, value)

    agent = CausalInferenceAgent(config)
    result = agent.analyze(args.data, custom_prompt=args.prompt)

    if args.output:
        with open(args.output, "w", encoding="utf-8") as f:
            json.dump(result["report"], f, ensure_ascii=False, indent=2)
        print(f"\n报告已保存到:{args.output}")

    print("\n分析完成!")


if __name__ == "__main__":
    main()
|
||||
7
causal_agent/core/__init__.py
Normal file
7
causal_agent/core/__init__.py
Normal file
@ -0,0 +1,7 @@
|
||||
"""Core modules for Causal Inference Agent."""
|
||||
|
||||
from .config import AgentConfig
|
||||
from .llm_client import LLMClient
|
||||
from .data_loader import DataLoader, get_data_info
|
||||
|
||||
__all__ = ["AgentConfig", "LLMClient", "DataLoader", "get_data_info"]
|
||||
81
causal_agent/core/config.py
Normal file
81
causal_agent/core/config.py
Normal file
@ -0,0 +1,81 @@
|
||||
"""
|
||||
Agent 配置模块
|
||||
|
||||
支持从环境变量、构造函数参数加载配置。
|
||||
"""
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def _load_env_file():
|
||||
"""加载项目根目录的 .env 文件"""
|
||||
# 尝试多个可能的位置
|
||||
possible_paths = [
|
||||
Path.cwd() / ".env",
|
||||
Path(__file__).parent.parent.parent / ".env",
|
||||
Path(__file__).parent.parent.parent.parent / ".env",
|
||||
]
|
||||
|
||||
for env_path in possible_paths:
|
||||
if env_path.exists():
|
||||
with open(env_path, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line and not line.startswith("#") and "=" in line:
|
||||
key, value = line.split("=", 1)
|
||||
# 只在环境变量不存在时才设置
|
||||
if key not in os.environ:
|
||||
os.environ[key] = value
|
||||
break
|
||||
|
||||
|
||||
# 模块加载时自动读取 .env 文件
|
||||
_load_env_file()
|
||||
|
||||
|
||||
@dataclass
class AgentConfig:
    """Configuration for the Causal Inference Agent.

    Defaults apply only when neither the process environment nor a .env
    file (loaded at import time) provides a value.
    """

    # LLM settings
    llm_base_url: str = "https://glm47flash.cloyir.com/v1"
    llm_model: str = "qwen3.5-35b"
    llm_temperature: float = 0.3
    llm_max_tokens: int = 2048
    llm_api_key: Optional[str] = None
    llm_timeout: int = 120
    llm_max_retries: int = 3

    # Statistical screening settings
    corr_threshold: float = 0.1
    alpha: float = 0.05
    bootstrap_iterations: int = 500
    bootstrap_alpha: float = 0.05

    # Path settings
    log_path: str = "causal_analysis.log.md"
    output_dir: str = "."

    @classmethod
    def from_env(cls) -> "AgentConfig":
        """Build a config from environment variables, falling back to the
        class defaults for anything unset."""

        def _env(name, default, cast):
            # Cast only when the variable exists; keep the typed default
            # (including None for the API key) otherwise.
            raw = os.getenv(name)
            return default if raw is None else cast(raw)

        return cls(
            llm_base_url=_env("LLM_BASE_URL", cls.llm_base_url, str),
            llm_model=_env("LLM_MODEL", cls.llm_model, str),
            llm_temperature=_env("LLM_TEMPERATURE", cls.llm_temperature, float),
            llm_max_tokens=_env("LLM_MAX_TOKENS", cls.llm_max_tokens, int),
            llm_api_key=_env("LLM_API_KEY", cls.llm_api_key, str),
            llm_timeout=_env("LLM_TIMEOUT", cls.llm_timeout, int),
            llm_max_retries=_env("LLM_MAX_RETRIES", cls.llm_max_retries, int),
            corr_threshold=_env("CORR_THRESHOLD", cls.corr_threshold, float),
            alpha=_env("ALPHA", cls.alpha, float),
            bootstrap_iterations=_env(
                "BOOTSTRAP_ITERATIONS", cls.bootstrap_iterations, int
            ),
            bootstrap_alpha=_env("BOOTSTRAP_ALPHA", cls.bootstrap_alpha, float),
            log_path=_env("LOG_PATH", cls.log_path, str),
            output_dir=_env("OUTPUT_DIR", cls.output_dir, str),
        )
|
||||
68
causal_agent/core/data_loader.py
Normal file
68
causal_agent/core/data_loader.py
Normal file
@ -0,0 +1,68 @@
|
||||
"""
|
||||
通用数据加载器
|
||||
|
||||
支持 CSV、Excel、JSON Lines 等格式,同时提供数据信息摘要生成。
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import pandas as pd
|
||||
|
||||
|
||||
def _read_json_lines(file_path: str) -> pd.DataFrame:
|
||||
"""读取 JSON Lines 文件"""
|
||||
return pd.read_json(file_path, lines=True)
|
||||
|
||||
|
||||
class DataLoader:
    """Generic tabular data loader for CSV, Excel and JSON Lines files."""

    SUPPORTED_EXTENSIONS = {".csv", ".xlsx", ".xls", ".json", ".jsonl"}

    def __init__(self, default_read_rows: Optional[int] = None):
        """
        Args:
            default_read_rows: If set, cap the number of rows read from
                row-oriented formats (CSV, Excel).
        """
        self.default_read_rows = default_read_rows

    def load(self, file_path: str) -> pd.DataFrame:
        """Load a data file into a DataFrame.

        Args:
            file_path: Path to the data file.

        Returns:
            The loaded DataFrame.

        Raises:
            FileNotFoundError: The file does not exist.
            ValueError: The file extension is not supported.
        """
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"数据文件不存在:{file_path}")

        ext = os.path.splitext(file_path)[1].lower()
        if ext not in self.SUPPORTED_EXTENSIONS:
            raise ValueError(
                f"不支持的文件格式:{ext},仅支持 {self.SUPPORTED_EXTENSIONS}"
            )

        # FIX: apply the row limit consistently. Previously only the CSV
        # branch honored default_read_rows; Excel files were always read
        # in full. pandas supports `nrows` for both readers.
        kwargs = {}
        if self.default_read_rows is not None:
            kwargs["nrows"] = self.default_read_rows

        if ext == ".csv":
            return pd.read_csv(file_path, **kwargs)
        elif ext in {".xlsx", ".xls"}:
            return pd.read_excel(file_path, **kwargs)
        elif ext in {".json", ".jsonl"}:
            # NOTE(review): plain .json files are also parsed as JSON
            # Lines here; a standard JSON array/object file will fail.
            # Confirm whether that is intended.
            return _read_json_lines(file_path)

        # Defensive: unreachable while the two sets above cover
        # SUPPORTED_EXTENSIONS, but kept as a guard for future additions.
        raise ValueError(f"未实现的文件格式:{ext}")
|
||||
|
||||
|
||||
def get_data_info(df: pd.DataFrame, treatment_col: Optional[str] = None) -> str:
    """Build a data-summary string suitable for embedding in an LLM prompt.

    Includes sample count, column names and `describe()` output; when a
    valid treatment column is given, its value counts are appended.
    """
    lines = []
    lines.append("**数据概览:**")
    lines.append(f"- 样本数量:{len(df)}")
    lines.append(f"- 变量:{', '.join(df.columns)}")
    lines.append("")
    lines.append("**统计摘要:**")
    lines.append(f"{df.describe()}")
    if treatment_col and treatment_col in df.columns:
        lines.append("")
        lines.append("**处理变量分布:**")
        lines.append(f"{df[treatment_col].value_counts().to_string()}")
    return "\n".join(lines)
|
||||
129
causal_agent/core/llm_client.py
Normal file
129
causal_agent/core/llm_client.py
Normal file
@ -0,0 +1,129 @@
|
||||
"""
|
||||
LLM 调用模块
|
||||
|
||||
用于调用外部 LLM API 进行文本生成,支持重试和 JSON 解析辅助。
|
||||
"""
|
||||
|
||||
import json
import os
import re
import time
from typing import Dict, Any, List, Optional

import requests
|
||||
|
||||
|
||||
class LLMClient:
    """Thin client for an OpenAI-compatible chat-completions API.

    Supports constructor arguments with environment-variable fallbacks,
    simple linear-backoff retries, and a JSON-cleanup helper for model
    output wrapped in markdown code fences.
    """

    def __init__(
        self,
        base_url: Optional[str] = None,
        model: Optional[str] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        api_key: Optional[str] = None,
        timeout: int = 120,
        max_retries: int = 3,
    ):
        """
        Args:
            base_url: API base URL; falls back to the LLM_BASE_URL env var.
            model: Model name; falls back to LLM_MODEL.
            temperature: Sampling temperature; falls back to LLM_TEMPERATURE.
            max_tokens: Completion token limit; falls back to LLM_MAX_TOKENS.
            api_key: Optional bearer token for the Authorization header.
            timeout: Per-request timeout in seconds.
            max_retries: Number of attempts before giving up.
        """
        # FIX: the previous code called base_url.rstrip("/") directly and
        # crashed with AttributeError whenever base_url was None (the
        # default); fall back to the environment like the other settings.
        resolved_base = base_url or os.getenv("LLM_BASE_URL", "")
        self.base_url = resolved_base.rstrip("/")
        self.model = model or os.getenv("LLM_MODEL", "qwen3.5-35b")
        self.temperature = temperature if temperature is not None else float(os.getenv("LLM_TEMPERATURE", "0.3"))
        self.max_tokens = max_tokens if max_tokens is not None else int(os.getenv("LLM_MAX_TOKENS", "2048"))
        self.timeout = timeout
        self.max_retries = max_retries
        self.api_key = api_key

        self._headers = {"Content-Type": "application/json"}
        if api_key:
            self._headers["Authorization"] = f"Bearer {api_key}"

    def chat_completion(
        self,
        messages: List[Dict[str, str]],
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
    ) -> Dict[str, Any]:
        """POST a chat request, retrying with linear backoff on failure.

        Returns:
            The raw response JSON on success, or a dict with "error" (and
            "status_code" when available) after all retries fail.
        """
        url = f"{self.base_url}/chat/completions"
        payload = {
            "model": self.model,
            "messages": messages,
            # FIX: compare against None so an explicit temperature or
            # max_tokens of 0 is honored; `temperature or self.temperature`
            # silently discarded falsy overrides.
            "temperature": temperature if temperature is not None else self.temperature,
            "max_tokens": max_tokens if max_tokens is not None else self.max_tokens,
            "stream": False,
            "extra_body": {"chat_template_kwargs": {"enable_thinking": False}},
        }

        last_error = None
        for attempt in range(self.max_retries):
            try:
                response = requests.post(
                    url, headers=self._headers, json=payload, timeout=self.timeout
                )
                response.raise_for_status()
                return response.json()
            except requests.exceptions.RequestException as e:
                last_error = e
                if attempt < self.max_retries - 1:
                    # Linear backoff: 1s, 2s, 3s, ...
                    time.sleep(1 * (attempt + 1))

        return {
            "error": str(last_error),
            "status_code": (
                getattr(last_error.response, "status_code", None)
                if hasattr(last_error, "response")
                else None
            ),
        }

    def generate_response(
        self,
        system_prompt: str,
        user_prompt: str,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
    ) -> Dict[str, Any]:
        """Simplified interface: send one system + one user message.

        Returns:
            {"success": bool, "error": ..., "content": ...} plus "usage"
            and "model" on success.
        """
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ]
        result = self.chat_completion(messages, temperature, max_tokens)

        if "error" in result:
            return {"success": False, "error": result["error"], "content": None}

        if "choices" in result and len(result["choices"]) > 0:
            choice = result["choices"][0]
            message = choice.get("message", {})
            content = message.get("content")
            if content is None:
                # Some backends put the text under "reasoning" instead.
                content = message.get("reasoning", "")
            if content is None:
                # Reachable when "reasoning" exists but is explicitly null.
                return {
                    "success": False,
                    "error": "No content found in response",
                    "content": None,
                }
            return {
                "success": True,
                "error": None,
                "content": content,
                "usage": result.get("usage", {}),
                "model": result.get("model", self.model),
            }

        return {
            "success": False,
            "error": f"Unexpected response format: {result}",
            "content": None,
        }

    @staticmethod
    def parse_json_response(content: str) -> Dict[str, Any]:
        """Strip optional markdown code fences and parse the JSON payload.

        Raises:
            json.JSONDecodeError: The cleaned content is not valid JSON.
        """
        cleaned = content.strip()
        if cleaned.startswith("```"):
            cleaned = re.sub(r"^```(?:json)?\s*", "", cleaned)
            cleaned = re.sub(r"\s*```$", "", cleaned)
        return json.loads(cleaned)
|
||||
115
causal_agent/logger.py
Normal file
115
causal_agent/logger.py
Normal file
@ -0,0 +1,115 @@
|
||||
"""
|
||||
因果分析日志记录器
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
|
||||
class CausalAnalysisLogger:
    """Markdown logger for causal-analysis runs.

    Appends one markdown section per analysis to a log file and keeps an
    in-memory list of entries. The exact "### 分析 #NNN" heading format is
    load-bearing: _load_existing_entries re-parses it with a regex to
    resume the counter across process restarts.
    """

    def __init__(self, log_path: str = "causal_analysis.log.md"):
        # Path of the markdown log file.
        self.log_path = log_path
        # In-memory record of entries written during this process's lifetime.
        self.log_entries: List[Dict[str, Any]] = []
        self._init_log_file()
        self._load_existing_entries()

    def _init_log_file(self) -> None:
        """Create the log file with a header if it does not exist yet."""
        header = """# 因果分析日志

## 日志说明
本文档记录因果推断分析的所有输入参数和输出结果。

## 分析记录
"""
        if not os.path.exists(self.log_path):
            with open(self.log_path, "w", encoding="utf-8") as f:
                f.write(header)

    def _load_existing_entries(self) -> None:
        """Resume the analysis counter from headings already in the file."""
        if not os.path.exists(self.log_path):
            self.analysis_count = 0
            return

        analysis_count = 0
        with open(self.log_path, "r", encoding="utf-8") as f:
            content = f.read()
            # Counter resumes at the highest ID seen, not the entry count,
            # so IDs stay unique even if earlier entries were deleted.
            matches = re.findall(r"### 分析 #(\d+)", content)
            if matches:
                analysis_count = max(int(m) for m in matches)
        self.analysis_count = analysis_count

    def _append_to_log(self, content: str) -> None:
        """Append raw markdown to the log file."""
        with open(self.log_path, "a", encoding="utf-8") as f:
            f.write(content)

    def log_analysis(
        self,
        system_prompt: str,
        user_prompt: str,
        model_output: str,
        parameters: Dict[str, Any],
        report: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Record one complete analysis (prompts, LLM output, parameters
        and optional report) to the markdown log and the in-memory list."""
        self.analysis_count += 1
        # Zero-padded ID keeps headings lexically sortable.
        analysis_id = f"{self.analysis_count:03d}"
        timestamp = datetime.now().isoformat()

        report_json = ""
        if report is not None:
            report_json = (
                f"\n#### 分析报告\n"
                f"```json\n{json.dumps(report, indent=2, ensure_ascii=False)}\n```\n"
            )

        entry = f"""
---

### 分析 #{analysis_id}
**时间**: {timestamp}

#### 系统提示词
```
{system_prompt}
```

#### 用户提示词
```
{user_prompt}
```

#### LLM 输出
```
{model_output}
```
{report_json}
#### 调用参数
```json
{json.dumps(parameters, indent=2, ensure_ascii=False)}
```

---

"""
        self._append_to_log(entry)
        self.log_entries.append(
            {
                "id": analysis_id,
                "timestamp": timestamp,
                "system_prompt": system_prompt,
                "user_prompt": user_prompt,
                "model_output": model_output,
                "parameters": parameters,
                "report": report,
            }
        )

        print(f"分析 #{analysis_id} 已记录到 {self.log_path}")
|
||||
807
data/simulator.py
Normal file
807
data/simulator.py
Normal file
@ -0,0 +1,807 @@
|
||||
"""
|
||||
模拟数据生成模块
|
||||
|
||||
本模块提供用于测试因果推断的模拟数据生成功能,支持多种数据生成场景:
|
||||
- 简单 ATE 场景
|
||||
- 协变量场景
|
||||
- 交互效应场景
|
||||
- CATE 场景
|
||||
|
||||
作者:CausalInferenceAgent
|
||||
版本:1.0.0
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional, Dict, Any, Union
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class BaseDataSimulator(ABC):
    """Abstract base class for all causal-data simulators.

    Holds the shared simulation parameters and the common save/describe
    interface; subclasses implement `generate`.
    """

    def __init__(
        self,
        n_samples: int = 1000,
        treatment_rate: float = 0.5,
        noise_level: float = 1.0,
        effect_size: float = 1.0,
        random_state: Optional[int] = None
    ):
        """Store simulation parameters and optionally seed NumPy's RNG.

        Args:
            n_samples: Number of rows to generate (default 1000).
            treatment_rate: Treatment assignment probability in [0, 1].
            noise_level: Standard deviation of the additive noise.
            effect_size: Magnitude of the treatment effect.
            random_state: Seed for reproducibility; seeds the *global*
                NumPy RNG when provided.
        """
        self.n_samples = n_samples
        self.treatment_rate = treatment_rate
        self.noise_level = noise_level
        self.effect_size = effect_size
        self.random_state = random_state
        if random_state is not None:
            # NOTE: seeds the global RNG, so the order in which several
            # simulators are constructed affects reproducibility.
            np.random.seed(random_state)

    @abstractmethod
    def generate(self) -> pd.DataFrame:
        """Produce and return the simulated dataset as a DataFrame."""

    @abstractmethod
    def get_description(self) -> str:
        """Return a human-readable description of the scenario."""

    def save_to_csv(self, filepath: Union[str, Path]) -> None:
        """Generate a dataset and write it to `filepath` as UTF-8 CSV."""
        self.generate().to_csv(filepath, index=False, encoding='utf-8')

    def save_to_excel(self, filepath: Union[str, Path]) -> None:
        """Generate a dataset and write it to `filepath` as an Excel file."""
        self.generate().to_excel(filepath, index=False)
|
||||
|
||||
|
||||
class SimpleATESimulator(BaseDataSimulator):
    """Simple ATE (average treatment effect) scenario simulator.

    Produces data with a single binary treatment T and a single outcome Y
    under the model Y = α + τ*T + ε, where τ is the average treatment
    effect (`effect_size`) and α the intercept.
    """

    def __init__(
        self,
        n_samples: int = 1000,
        treatment_rate: float = 0.5,
        noise_level: float = 1.0,
        effect_size: float = 1.0,
        intercept: float = 0.0,
        random_state: Optional[int] = None
    ):
        """Initialize the simple-ATE simulator.

        Args:
            n_samples: Number of samples.
            treatment_rate: Treatment assignment probability.
            noise_level: Noise standard deviation.
            effect_size: Average treatment effect τ.
            intercept: Intercept term α.
            random_state: Random seed.
        """
        super().__init__(
            n_samples=n_samples,
            treatment_rate=treatment_rate,
            noise_level=noise_level,
            effect_size=effect_size,
            random_state=random_state
        )
        self.intercept = intercept

    def generate(self) -> pd.DataFrame:
        """Generate the simple-ATE dataset.

        Returns:
            DataFrame with columns 'T' (treatment) and 'Y' (outcome).
        """
        # RNG draws happen in a fixed order (treatment, then noise) so
        # seeded runs are reproducible.
        treatment = np.random.binomial(1, self.treatment_rate, self.n_samples)
        epsilon = np.random.normal(0, self.noise_level, self.n_samples)
        outcome = self.intercept + self.effect_size * treatment + epsilon
        return pd.DataFrame({'T': treatment, 'Y': outcome})

    def get_description(self) -> str:
        """Return a description of the scenario and its parameters."""
        pieces = [
            f"简单 ATE 场景:",
            f"- 样本数量:{self.n_samples}",
            f"- 处理比例:{self.treatment_rate:.2%}",
            f"- 噪声水平:{self.noise_level}",
            f"- 平均处理效应:{self.effect_size}",
            f"- 截距项:{self.intercept}",
            f"模型:Y = {self.intercept} + {self.effect_size}*T + ε",
        ]
        return "\n".join(pieces)
|
||||
|
||||
|
||||
class CovariateSimulator(BaseDataSimulator):
    """
    Covariate scenario simulator.

    Generates data including covariates X, for exercising confounder
    control. Model: Y = β*X + τ*T + ε (the covariate coefficients β are
    drawn at random on every `generate()` call and are not stored).
    """

    def __init__(
        self,
        n_samples: int = 1000,
        treatment_rate: float = 0.5,
        noise_level: float = 1.0,
        effect_size: float = 1.0,
        n_covariates: int = 3,
        covariate_distribution: str = 'normal',
        covariate_params: Optional[Dict[str, Any]] = None,
        random_state: Optional[int] = None
    ):
        """
        Initialize the covariate-scenario simulator.

        Args:
            n_samples: Number of samples.
            treatment_rate: Treatment assignment probability.
            noise_level: Noise standard deviation.
            effect_size: Treatment effect τ.
            n_covariates: Number of covariates.
            covariate_distribution: Covariate distribution
                ('normal', 'uniform', 'beta').
            covariate_params: Parameters for the chosen distribution.
            random_state: Random seed.
        """
        super().__init__(
            n_samples=n_samples,
            treatment_rate=treatment_rate,
            noise_level=noise_level,
            effect_size=effect_size,
            random_state=random_state
        )
        self.n_covariates = n_covariates
        self.covariate_distribution = covariate_distribution
        self.covariate_params = covariate_params or {}

    def _generate_covariates(self) -> np.ndarray:
        """
        Draw the covariate matrix.

        Returns:
            Array of shape (n_samples, n_covariates).

        Raises:
            ValueError: Unknown covariate distribution type.
        """
        if self.covariate_distribution == 'normal':
            mean = self.covariate_params.get('mean', 0)
            std = self.covariate_params.get('std', 1)
            X = np.random.normal(mean, std, (self.n_samples, self.n_covariates))

        elif self.covariate_distribution == 'uniform':
            low = self.covariate_params.get('low', -1)
            high = self.covariate_params.get('high', 1)
            X = np.random.uniform(low, high, (self.n_samples, self.n_covariates))

        elif self.covariate_distribution == 'beta':
            a = self.covariate_params.get('a', 2)
            b = self.covariate_params.get('b', 2)
            X = np.random.beta(a, b, (self.n_samples, self.n_covariates))

        else:
            raise ValueError(f"未知的协变量分布类型:{self.covariate_distribution}")

        return X

    def generate(self) -> pd.DataFrame:
        """
        Generate the covariate-scenario dataset.

        Returns:
            DataFrame with covariates X0..Xk, treatment T and outcome Y.
        """
        # RNG draws happen in a fixed order (covariates, treatment, noise,
        # coefficients) so seeded runs are reproducible.
        X = self._generate_covariates()

        T = np.random.binomial(1, self.treatment_rate, self.n_samples)

        noise = np.random.normal(0, self.noise_level, self.n_samples)

        # β is re-drawn on each call, so successive generate() calls use
        # different covariate effects.
        beta = np.random.uniform(-1, 1, self.n_covariates)

        # Outcome: Y = β*X + τ*T + ε
        Y = np.dot(X, beta) + self.effect_size * T + noise

        df_dict = {}
        for i in range(self.n_covariates):
            df_dict[f'X{i}'] = X[:, i]
        df_dict['T'] = T
        df_dict['Y'] = Y

        df = pd.DataFrame(df_dict)

        return df

    def get_description(self) -> str:
        """Return a description of the scenario and its parameters."""
        return (
            f"协变量场景:\n"
            f"- 样本数量:{self.n_samples}\n"
            f"- 处理比例:{self.treatment_rate:.2%}\n"
            f"- 噪声水平:{self.noise_level}\n"
            f"- 处理效应:{self.effect_size}\n"
            f"- 协变量数量:{self.n_covariates}\n"
            f"- 协变量分布:{self.covariate_distribution}\n"
            f"模型:Y = β*X + {self.effect_size}*T + ε"
        )
|
||||
|
||||
|
||||
class InteractionEffectSimulator(BaseDataSimulator):
    """
    Interaction-effect scenario simulator.

    Generates data where the treatment effect varies with the covariates.
    Model: Y = β*X + τ*T + γ*X*T + ε, where γ are the interaction
    coefficients drawn from U(-s, s) with s = `interaction_strength`.
    """

    def __init__(
        self,
        n_samples: int = 1000,
        treatment_rate: float = 0.5,
        noise_level: float = 1.0,
        effect_size: float = 1.0,
        interaction_strength: float = 0.5,
        n_covariates: int = 2,
        random_state: Optional[int] = None
    ):
        """
        Initialize the interaction-effect simulator.

        Args:
            n_samples: Number of samples.
            treatment_rate: Treatment assignment probability.
            noise_level: Noise standard deviation.
            effect_size: Baseline treatment effect τ.
            interaction_strength: Bound s of the uniform distribution the
                interaction coefficients γ are drawn from.
            n_covariates: Number of covariates.
            random_state: Random seed.
        """
        super().__init__(
            n_samples=n_samples,
            treatment_rate=treatment_rate,
            noise_level=noise_level,
            effect_size=effect_size,
            random_state=random_state
        )
        self.interaction_strength = interaction_strength
        self.n_covariates = n_covariates

    def generate(self) -> pd.DataFrame:
        """
        Generate the interaction-effect dataset.

        Returns:
            DataFrame with covariates X0..Xk, treatment T and outcome Y.
        """
        X = np.random.normal(0, 1, (self.n_samples, self.n_covariates))

        T = np.random.binomial(1, self.treatment_rate, self.n_samples)

        noise = np.random.normal(0, self.noise_level, self.n_samples)

        beta = np.random.uniform(-0.5, 0.5, self.n_covariates)

        # FIX: `interaction_strength` was stored but never used — the
        # interaction coefficients were always drawn from U(-0.5, 0.5)
        # regardless of the parameter. Draw them from U(-s, s) instead;
        # at the default s = 0.5 this is identical to the old behavior.
        gamma = np.random.uniform(
            -self.interaction_strength, self.interaction_strength, self.n_covariates
        )

        # Outcome: Y = β*X + τ*T + γ*X*T + ε
        Y = np.dot(X, beta) + self.effect_size * T
        for i in range(self.n_covariates):
            Y += gamma[i] * X[:, i] * T
        Y += noise

        df_dict = {}
        for i in range(self.n_covariates):
            df_dict[f'X{i}'] = X[:, i]
        df_dict['T'] = T
        df_dict['Y'] = Y

        df = pd.DataFrame(df_dict)

        return df

    def get_description(self) -> str:
        """Return a description of the scenario and its parameters."""
        return (
            f"交互效应场景:\n"
            f"- 样本数量:{self.n_samples}\n"
            f"- 处理比例:{self.treatment_rate:.2%}\n"
            f"- 噪声水平:{self.noise_level}\n"
            f"- 基础处理效应:{self.effect_size}\n"
            f"- 交互效应强度:{self.interaction_strength}\n"
            f"- 协变量数量:{self.n_covariates}\n"
            f"模型:Y = β*X + {self.effect_size}*T + γ*X*T + ε"
        )
|
||||
|
||||
|
||||
class CATESimulator(BaseDataSimulator):
    """
    Conditional average treatment effect (CATE) scenario simulator.

    Generates data where the treatment effect depends on individual
    features. Model: Y = β*X + τ(X)*T + ε, where τ(X) is the conditional
    treatment-effect function selected by `cate_function`.
    """

    def __init__(
        self,
        n_samples: int = 1000,
        treatment_rate: float = 0.5,
        noise_level: float = 1.0,
        effect_size: float = 1.0,
        n_features: int = 5,
        cate_function: str = 'linear',
        random_state: Optional[int] = None
    ):
        """
        Initialize the CATE simulator.

        Args:
            n_samples: Number of samples.
            treatment_rate: Treatment assignment probability.
            noise_level: Noise standard deviation.
            effect_size: Baseline average treatment effect τ0.
            n_features: Number of features.
            cate_function: CATE function type
                ('linear', 'nonlinear', 'threshold').
            random_state: Random seed.
        """
        super().__init__(
            n_samples=n_samples,
            treatment_rate=treatment_rate,
            noise_level=noise_level,
            effect_size=effect_size,
            random_state=random_state
        )
        self.n_features = n_features
        self.cate_function = cate_function

    def _compute_cate(self, X: np.ndarray) -> np.ndarray:
        """
        Compute the conditional treatment effect for each row.

        The CATE coefficients are re-drawn on every call, so repeated
        `generate()` calls use different effect surfaces.

        Args:
            X: Feature matrix of shape (n_samples, n_features).

        Returns:
            Array of per-sample conditional treatment effects.

        Raises:
            ValueError: Unknown CATE function type.
        """
        if self.cate_function == 'linear':
            # Linear CATE: τ(X) = τ0 + β*X
            beta_cate = np.random.uniform(-0.5, 0.5, self.n_features)
            cate = self.effect_size + np.dot(X, beta_cate)

        elif self.cate_function == 'nonlinear':
            # Nonlinear CATE: τ(X) = τ0 + sin(β*X)
            beta_cate = np.random.uniform(-1, 1, self.n_features)
            linear_comb = np.dot(X, beta_cate)
            cate = self.effect_size + np.sin(linear_comb)

        elif self.cate_function == 'threshold':
            # Threshold CATE: the linear term only contributes where the
            # linear combination exceeds a randomly drawn threshold.
            beta_cate = np.random.uniform(-0.5, 0.5, self.n_features)
            threshold = np.random.uniform(-1, 1)
            linear_comb = np.dot(X, beta_cate)
            cate = self.effect_size + linear_comb * (linear_comb > threshold).astype(float)

        else:
            raise ValueError(f"未知的 CATE 函数类型:{self.cate_function}")

        return cate

    def generate(self) -> pd.DataFrame:
        """
        Generate the CATE-scenario dataset.

        Returns:
            DataFrame with features X0..Xk, treatment T, outcome Y and a
            'CATE' column holding the true per-sample effect (for
            validation).
        """
        # RNG draws happen in a fixed order (features, treatment, noise,
        # CATE coefficients, outcome coefficients) for reproducibility.
        X = np.random.normal(0, 1, (self.n_samples, self.n_features))

        T = np.random.binomial(1, self.treatment_rate, self.n_samples)

        noise = np.random.normal(0, self.noise_level, self.n_samples)

        cate = self._compute_cate(X)

        # Coefficients for the outcome model's direct feature effect.
        beta_y = np.random.uniform(-0.5, 0.5, self.n_features)

        # Outcome: Y = β*X + τ(X)*T + ε
        Y = np.dot(X, beta_y) + cate * T + noise

        df_dict = {}
        for i in range(self.n_features):
            df_dict[f'X{i}'] = X[:, i]
        df_dict['T'] = T
        df_dict['Y'] = Y

        # Expose the ground-truth CATE for downstream validation.
        df_dict['CATE'] = cate

        df = pd.DataFrame(df_dict)

        return df

    def get_description(self) -> str:
        """Return a description of the scenario and its parameters."""
        return (
            f"CATE 场景:\n"
            f"- 样本数量:{self.n_samples}\n"
            f"- 处理比例:{self.treatment_rate:.2%}\n"
            f"- 噪声水平:{self.noise_level}\n"
            f"- 平均处理效应基准:{self.effect_size}\n"
            f"- 特征数量:{self.n_features}\n"
            f"- CATE 函数类型:{self.cate_function}\n"
            f"模型:Y = β*X + τ(X)*T + ε"
        )
|
||||
|
||||
|
||||
class DataSimulatorFactory:
    """Factory for creating data-simulator instances by type name."""

    # Registry mapping type keys to concrete simulator classes.
    _simulators = {
        'simple_ate': SimpleATESimulator,
        'covariate': CovariateSimulator,
        'interaction': InteractionEffectSimulator,
        'cate': CATESimulator
    }

    @classmethod
    def create(cls, simulator_type: str, **kwargs) -> BaseDataSimulator:
        """Create a simulator of the requested type.

        Args:
            simulator_type: one of 'simple_ate', 'covariate',
                'interaction', 'cate'.
            **kwargs: forwarded to the simulator constructor.

        Returns:
            A simulator instance.

        Raises:
            ValueError: if the type name is not registered.
        """
        simulator_cls = cls._simulators.get(simulator_type)
        if simulator_cls is None:
            available_types = ', '.join(cls._simulators.keys())
            raise ValueError(
                f"未知的模拟器类型:{simulator_type}。"
                f"可用类型:{available_types}"
            )
        return simulator_cls(**kwargs)

    @classmethod
    def list_available_types(cls) -> list:
        """Return the list of registered simulator type names."""
        return list(cls._simulators.keys())
|
||||
|
||||
|
||||
# 便捷函数
|
||||
def generate_simple_ate_data(
    n_samples: int = 1000,
    treatment_rate: float = 0.5,
    noise_level: float = 1.0,
    effect_size: float = 1.0,
    intercept: float = 0.0,
    random_state: Optional[int] = None
) -> pd.DataFrame:
    """Convenience helper: simulate data for the simple ATE scenario.

    Args:
        n_samples: number of samples.
        treatment_rate: share of treated units.
        noise_level: outcome noise scale.
        effect_size: average treatment effect.
        intercept: outcome intercept.
        random_state: random seed.

    Returns:
        DataFrame with 'T' and 'Y' columns.
    """
    params = dict(
        n_samples=n_samples,
        treatment_rate=treatment_rate,
        noise_level=noise_level,
        effect_size=effect_size,
        intercept=intercept,
        random_state=random_state,
    )
    return SimpleATESimulator(**params).generate()
|
||||
|
||||
|
||||
def generate_covariate_data(
    n_samples: int = 1000,
    treatment_rate: float = 0.5,
    noise_level: float = 1.0,
    effect_size: float = 1.0,
    n_covariates: int = 3,
    covariate_distribution: str = 'normal',
    random_state: Optional[int] = None
) -> pd.DataFrame:
    """Convenience helper: simulate data for the covariate scenario.

    Args:
        n_samples: number of samples.
        treatment_rate: share of treated units.
        noise_level: outcome noise scale.
        effect_size: treatment effect.
        n_covariates: number of covariates.
        covariate_distribution: distribution type for the covariates.
        random_state: random seed.

    Returns:
        DataFrame with covariate columns plus 'T' and 'Y'.
    """
    params = dict(
        n_samples=n_samples,
        treatment_rate=treatment_rate,
        noise_level=noise_level,
        effect_size=effect_size,
        n_covariates=n_covariates,
        covariate_distribution=covariate_distribution,
        random_state=random_state,
    )
    return CovariateSimulator(**params).generate()
|
||||
|
||||
|
||||
def generate_interaction_data(
    n_samples: int = 1000,
    treatment_rate: float = 0.5,
    noise_level: float = 1.0,
    effect_size: float = 1.0,
    interaction_strength: float = 0.5,
    n_covariates: int = 2,
    random_state: Optional[int] = None
) -> pd.DataFrame:
    """Convenience helper: simulate data for the interaction-effect scenario.

    Args:
        n_samples: number of samples.
        treatment_rate: share of treated units.
        noise_level: outcome noise scale.
        effect_size: base treatment effect.
        interaction_strength: strength of the treatment-covariate interaction.
        n_covariates: number of covariates.
        random_state: random seed.

    Returns:
        DataFrame with covariate columns plus 'T' and 'Y'.
    """
    params = dict(
        n_samples=n_samples,
        treatment_rate=treatment_rate,
        noise_level=noise_level,
        effect_size=effect_size,
        interaction_strength=interaction_strength,
        n_covariates=n_covariates,
        random_state=random_state,
    )
    return InteractionEffectSimulator(**params).generate()
|
||||
|
||||
|
||||
def generate_cate_data(
    n_samples: int = 1000,
    treatment_rate: float = 0.5,
    noise_level: float = 1.0,
    effect_size: float = 1.0,
    n_features: int = 5,
    cate_function: str = 'linear',
    random_state: Optional[int] = None
) -> pd.DataFrame:
    """Convenience helper: simulate data for the CATE scenario.

    Args:
        n_samples: number of samples.
        treatment_rate: share of treated units.
        noise_level: outcome noise scale.
        effect_size: baseline average treatment effect.
        n_features: number of features.
        cate_function: CATE functional form.
        random_state: random seed.

    Returns:
        DataFrame with feature columns plus 'T', 'Y' and 'CATE'.
    """
    params = dict(
        n_samples=n_samples,
        treatment_rate=treatment_rate,
        noise_level=noise_level,
        effect_size=effect_size,
        n_features=n_features,
        cate_function=cate_function,
        random_state=random_state,
    )
    return CATESimulator(**params).generate()
|
||||
|
||||
|
||||
# Example usage / demo script
if __name__ == '__main__':
    print("=" * 60)
    print("模拟数据生成模块示例")
    print("=" * 60)

    # 1. Simple ATE scenario
    print("\n1. 简单 ATE 场景")
    print("-" * 40)
    simple_simulator = SimpleATESimulator(
        n_samples=500,
        treatment_rate=0.5,
        noise_level=1.0,
        effect_size=2.0,
        intercept=5.0,
        random_state=42
    )
    print(simple_simulator.get_description())
    simple_df = simple_simulator.generate()
    print(f"\n数据形状:{simple_df.shape}")
    print(f"前 5 行:\n{simple_df.head()}")

    # 2. Covariate scenario
    print("\n\n2. 协变量场景")
    print("-" * 40)
    covariate_simulator = CovariateSimulator(
        n_samples=500,
        treatment_rate=0.4,
        noise_level=1.5,
        effect_size=1.5,
        n_covariates=3,
        covariate_distribution='normal',
        random_state=42
    )
    print(covariate_simulator.get_description())
    covariate_df = covariate_simulator.generate()
    print(f"\n数据形状:{covariate_df.shape}")
    print(f"前 5 行:\n{covariate_df.head()}")

    # 3. Interaction-effect scenario
    print("\n\n3. 交互效应场景")
    print("-" * 40)
    interaction_simulator = InteractionEffectSimulator(
        n_samples=500,
        treatment_rate=0.5,
        noise_level=1.0,
        effect_size=1.0,
        interaction_strength=0.5,
        n_covariates=2,
        random_state=42
    )
    print(interaction_simulator.get_description())
    interaction_df = interaction_simulator.generate()
    print(f"\n数据形状:{interaction_df.shape}")
    print(f"前 5 行:\n{interaction_df.head()}")

    # 4. CATE scenario
    print("\n\n4. CATE 场景")
    print("-" * 40)
    cate_simulator = CATESimulator(
        n_samples=500,
        treatment_rate=0.5,
        noise_level=1.0,
        effect_size=1.5,
        n_features=4,
        cate_function='linear',
        random_state=42
    )
    print(cate_simulator.get_description())
    cate_df = cate_simulator.generate()
    print(f"\n数据形状:{cate_df.shape}")
    print(f"前 5 行:\n{cate_df.head()}")

    # 5. Create a simulator through the factory
    print("\n\n5. 使用工厂创建模拟器")
    print("-" * 40)
    print(f"可用模拟器类型:{DataSimulatorFactory.list_available_types()}")
    factory_simulator = DataSimulatorFactory.create(
        'simple_ate',
        n_samples=300,
        treatment_rate=0.6,
        effect_size=2.5,
        random_state=42
    )
    print(factory_simulator.get_description())

    # 6. Saving generated data
    print("\n\n6. 保存数据示例")
    print("-" * 40)
    output_dir = Path('data/output')
    # BUGFIX: 'data/output' is a nested path; without parents=True, mkdir
    # raises FileNotFoundError whenever the 'data' directory does not exist.
    output_dir.mkdir(parents=True, exist_ok=True)

    simple_simulator.save_to_csv(output_dir / 'simple_ate_data.csv')
    print(f"已保存:{output_dir / 'simple_ate_data.csv'}")

    simple_simulator.save_to_excel(output_dir / 'simple_ate_data.xlsx')
    print(f"已保存:{output_dir / 'simple_ate_data.xlsx'}")

    # 7. Convenience functions
    print("\n\n7. 使用便捷函数")
    print("-" * 40)
    simple_df = generate_simple_ate_data(n_samples=200, effect_size=3.0, random_state=123)
    print(f"便捷函数生成数据形状:{simple_df.shape}")
    print(f"前 5 行:\n{simple_df.head()}")

    print("\n" + "=" * 60)
    print("示例完成!")
    print("=" * 60)
||||
78
examples/medical/data_generator.py
Normal file
78
examples/medical/data_generator.py
Normal file
@ -0,0 +1,78 @@
|
||||
"""
|
||||
医疗数据生成器
|
||||
用于模拟构建医疗测试数据:
|
||||
- id: 唯一标识符
|
||||
- treatment: 是否吃药(0 或 1)
|
||||
- health: 病人健康状态(0~1 浮点数,越高越好)
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import os
|
||||
|
||||
|
||||
def generate_medical_data(n_samples: int = 500, output_path: str = "examples/medical/data.xlsx") -> str:
    """Generate simulated medical test data and save it to an Excel file.

    Columns:
        id: unique identifier (1..n_samples)
        treatment: whether medication was taken (0 or 1)
        health: health score in [0, 1], higher is better

    Args:
        n_samples: number of rows to generate (default 500).
        output_path: destination .xlsx path.

    Returns:
        The path the file was written to.
    """
    # Fixed seed so generated datasets are reproducible.
    np.random.seed(42)

    # Unique ids.
    ids = list(range(1, n_samples + 1))

    # Treatment indicator: ~40% of patients take the medication.
    treatment = np.random.binomial(1, 0.4, n_samples)

    # Health = baseline + treatment lift + noise, clipped to [0, 1].
    base_health = np.random.beta(2, 2, n_samples)  # baseline health distribution
    treatment_effect = treatment * 0.2  # medication adds 0.2 on average
    noise = np.random.normal(0, 0.1, n_samples)  # random noise
    health = np.clip(base_health + treatment_effect + noise, 0, 1)

    df = pd.DataFrame({
        'id': ids,
        'treatment': treatment,
        'health': np.round(health, 4)
    })

    # Ensure the output directory exists.
    # BUGFIX: this block previously appeared twice back-to-back; the
    # duplicate was removed and exist_ok=True avoids the check-then-create race.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Write the Excel file.
    df.to_excel(output_path, index=False)

    print(f"成功生成 {n_samples} 条医疗数据,已保存到:{output_path}")
    print(f"文件已创建:{output_path}")
    print(f"数据预览:")
    print(df.head(10))
    print(f"\n统计信息:")
    print(f"吃药人数:{treatment.sum()} ({treatment.mean()*100:.1f}%)")
    print(f"健康状态均值:{health.mean():.4f}")
    print(f"吃药组健康均值:{health[treatment==1].mean():.4f}")
    print(f"未吃药组健康均值:{health[treatment==0].mean():.4f}")

    return output_path
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # By default, generate 500 rows into data.xlsx in the current directory.
    output_file = generate_medical_data(n_samples=500, output_path="data.xlsx")
    print(f"\n文件已创建:{output_file}")
|
||||
200
examples/medical/data_validator.py
Normal file
200
examples/medical/data_validator.py
Normal file
@ -0,0 +1,200 @@
|
||||
"""
|
||||
数据验证模块
|
||||
用于检查数据列是否匹配 LLM 识别的处理变量和结果变量
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
from typing import Dict, Any, Tuple, Optional, List
|
||||
|
||||
|
||||
class ValidationResult:
    """Outcome of a data-validation run.

    Attributes:
        is_valid: whether validation passed overall.
        treatment: identified treatment column name, if any.
        outcome: identified outcome column name, if any.
        errors: blocking problems found during validation.
        warnings: non-blocking, advisory issues.
    """

    def __init__(
        self,
        is_valid: bool,
        treatment: Optional[str] = None,
        outcome: Optional[str] = None,
        errors: Optional[List[str]] = None,
        warnings: Optional[List[str]] = None
    ):
        self.is_valid = is_valid
        self.treatment = treatment
        self.outcome = outcome
        # Falsy (None/empty) containers normalize to fresh empty lists.
        self.errors = errors if errors else []
        self.warnings = warnings if warnings else []
|
||||
|
||||
|
||||
class DataValidator:
    """Validates that a data file actually contains the columns an LLM identified."""

    def __init__(self, data_path: str):
        """
        Initialize the validator.

        Args:
            data_path: path to the Excel data file.
        """
        self.data_path = data_path
        # Lazily-loaded DataFrame cache; populated by load_data().
        self.data = None

    def load_data(self) -> pd.DataFrame:
        """Load (and cache) the Excel file as a DataFrame."""
        if self.data is None:
            self.data = pd.read_excel(self.data_path)
        return self.data

    def validate_columns(
        self,
        llm_result: Dict[str, Any]
    ) -> ValidationResult:
        """
        Check that the LLM-identified variables exist and look usable.

        Args:
            llm_result: LLM output JSON containing 'treatment' and 'outcome' keys.

        Returns:
            A ValidationResult where errors are blocking and warnings advisory.
        """
        errors = []
        warnings = []
        is_valid = True

        # Loading the data is itself part of validation; any failure is fatal.
        try:
            self.load_data()
        except Exception as e:
            return ValidationResult(
                is_valid=False,
                treatment=None,
                outcome=None,
                errors=[f"加载数据失败:{str(e)}"],
                warnings=[]
            )

        # Column names actually present in the data.
        assert self.data is not None, "数据未加载"
        available_columns = list(self.data.columns)

        # Variables the LLM claims to have identified.
        treatment = llm_result.get('treatment')
        outcome = llm_result.get('outcome')

        # Treatment variable: must be named and must exist as a column.
        if treatment is None:
            errors.append("LLM 未识别处理变量 (treatment)")
            is_valid = False
        elif treatment not in available_columns:
            errors.append(f"数据中不存在处理变量 '{treatment}'")
            errors.append(f"可用列名:{available_columns}")
            is_valid = False
        else:
            # Advisory checks: treatment should look binary and numeric-ish.
            treatment_col = self.data[treatment]
            if treatment_col.nunique() > 2:
                warnings.append(f"处理变量 '{treatment}' 有多个唯一值,可能不是二元变量")
            elif treatment_col.dtype not in ['int64', 'float64', 'bool', 'object']:
                warnings.append(f"处理变量 '{treatment}' 的数据类型可能不适合因果分析")

        # Outcome variable: same existence checks.
        if outcome is None:
            errors.append("LLM 未识别结果变量 (outcome)")
            is_valid = False
        elif outcome not in available_columns:
            errors.append(f"数据中不存在结果变量 '{outcome}'")
            errors.append(f"可用列名:{available_columns}")
            is_valid = False
        else:
            # Advisory check: outcome should be numeric for effect estimation.
            outcome_col = self.data[outcome]
            if outcome_col.dtype not in ['int64', 'float64']:
                warnings.append(f"结果变量 '{outcome}' 的数据类型可能不适合因果分析")

        # Sample-size sanity check (advisory only).
        if len(self.data) < 10:
            warnings.append(f"样本量 ({len(self.data)}) 较小,可能影响分析结果")

        # Missing-value checks, guarded so invalid/absent column names are skipped.
        missing_treatment = self.data[treatment].isna().sum() if treatment and treatment in available_columns else 0
        missing_outcome = self.data[outcome].isna().sum() if outcome and outcome in available_columns else 0

        if missing_treatment > 0:
            warnings.append(f"处理变量 '{treatment}' 有 {missing_treatment} 个缺失值")
        if missing_outcome > 0:
            warnings.append(f"结果变量 '{outcome}' 有 {missing_outcome} 个缺失值")

        return ValidationResult(
            is_valid=is_valid,
            treatment=treatment,
            outcome=outcome,
            errors=errors,
            warnings=warnings
        )

    def validate_and_raise(
        self,
        llm_result: Dict[str, Any]
    ) -> Tuple[Optional[str], Optional[str]]:
        """
        Validate and raise if validation fails.

        Args:
            llm_result: LLM output JSON.

        Returns:
            (treatment, outcome) tuple when validation passes.

        Raises:
            ValueError: when validation fails; the message lists every error.
        """
        result = self.validate_columns(llm_result)

        if not result.is_valid:
            error_msg = "数据验证失败:\n"
            for error in result.errors:
                error_msg += f" - {error}\n"
            raise ValueError(error_msg)

        return result.treatment, result.outcome
|
||||
|
||||
|
||||
# 便捷函数
|
||||
def validate_data(data_path: str, llm_result: Dict[str, Any]) -> ValidationResult:
    """Convenience wrapper: validate a data file against LLM-identified columns.

    Args:
        data_path: path to the data file.
        llm_result: LLM output JSON with 'treatment' and 'outcome' keys.

    Returns:
        The ValidationResult produced by DataValidator.validate_columns.
    """
    return DataValidator(data_path).validate_columns(llm_result)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Smoke-test the validation module against the generated sample data.
    print("测试数据验证模块")
    print("=" * 50)

    # Test 1: valid data — both column names exist.
    print("\n测试 1: 有效数据")
    validator = DataValidator("examples/medical/data.xlsx")
    llm_result = {"treatment": "treatment", "outcome": "health"}
    result = validator.validate_columns(llm_result)

    print(f" 是否有效:{result.is_valid}")
    print(f" 处理变量:{result.treatment}")
    print(f" 结果变量:{result.outcome}")
    print(f" 错误:{result.errors}")
    print(f" 警告:{result.warnings}")

    # Test 2: invalid data — the treatment column name does not exist.
    print("\n测试 2: 无效数据(变量名不存在)")
    llm_result_invalid = {"treatment": "invalid_col", "outcome": "health"}
    result_invalid = validator.validate_columns(llm_result_invalid)

    print(f" 是否有效:{result_invalid.is_valid}")
    print(f" 错误:{result_invalid.errors}")
||||
166
examples/medical/llm_client.py
Normal file
166
examples/medical/llm_client.py
Normal file
@ -0,0 +1,166 @@
|
||||
"""
|
||||
LLM 调用模块
|
||||
用于调用外部 LLM API 进行文本生成
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
from typing import Dict, Any, Optional, List
|
||||
import requests
|
||||
|
||||
|
||||
class LLMClient:
    """Thin client for an OpenAI-compatible chat-completions API."""

    def __init__(
        self,
        base_url: Optional[str] = None,
        model: Optional[str] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        api_key: Optional[str] = None
    ):
        """
        Initialize the LLM client.

        Args:
            base_url: API base URL (falls back to env LLM_BASE_URL).
            model: model name (falls back to env LLM_MODEL).
            temperature: sampling temperature (falls back to env LLM_TEMPERATURE).
            max_tokens: max generated tokens (falls back to env LLM_MAX_TOKENS).
            api_key: optional API key (falls back to env LLM_API_KEY).
        """
        # NOTE: the redundant function-local `import os` was removed; the
        # module already imports os at the top.
        self.base_url = base_url or os.getenv("LLM_BASE_URL", "https://glm47flash.cloyir.com/v1")
        self.model = model or os.getenv("LLM_MODEL", "qwen3.5-35b")
        # Explicit None checks so temperature=0.0 / max_tokens=0 are honored.
        self.temperature = temperature if temperature is not None else float(os.getenv("LLM_TEMPERATURE", "0.3"))
        self.max_tokens = max_tokens if max_tokens is not None else int(os.getenv("LLM_MAX_TOKENS", "2048"))
        self.api_key = api_key or os.getenv("LLM_API_KEY", "")

        self._headers = {
            "Content-Type": "application/json"
        }
        if self.api_key:
            self._headers["Authorization"] = f"Bearer {self.api_key}"

    def chat_completion(
        self,
        messages: List[Dict[str, str]],
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None
    ) -> Dict[str, Any]:
        """
        Send a chat request and return the raw API response.

        Args:
            messages: list of {"role": "system/user/assistant", "content": "..."}.
            temperature: per-call temperature override.
            max_tokens: per-call max-tokens override.

        Returns:
            The parsed JSON response, or {"error": ..., "status_code": ...}
            on transport failure.
        """
        url = f"{self.base_url}/chat/completions"

        payload = {
            "model": self.model,
            "messages": messages,
            # BUGFIX: the previous `temperature or self.temperature` silently
            # replaced a deliberate 0 / 0.0 override with the default.
            "temperature": temperature if temperature is not None else self.temperature,
            "max_tokens": max_tokens if max_tokens is not None else self.max_tokens,
            "stream": False,
            "extra_body": {
                "chat_template_kwargs": {
                    "enable_thinking": False
                }
            }
        }

        try:
            response = requests.post(
                url,
                headers=self._headers,
                json=payload,
                timeout=120
            )
            response.raise_for_status()
            return response.json()
        except requests.exceptions.RequestException as e:
            # Transport errors are reported as a dict rather than raised so
            # callers can degrade gracefully.
            return {
                "error": str(e),
                "status_code": getattr(e.response, 'status_code', None) if hasattr(e, 'response') else None
            }

    def generate_response(
        self,
        system_prompt: str,
        user_prompt: str,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None
    ) -> Dict[str, Any]:
        """
        Simplified interface: one system + one user message.

        Args:
            system_prompt: system prompt text.
            user_prompt: user prompt text.
            temperature: per-call temperature override.
            max_tokens: per-call max-tokens override.

        Returns:
            Dict with 'success', 'error', 'content' (and 'usage'/'model'
            on success).
        """
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt}
        ]

        result = self.chat_completion(messages, temperature, max_tokens)

        if "error" in result:
            return {
                "success": False,
                "error": result["error"],
                "content": None
            }

        if "choices" in result and len(result["choices"]) > 0:
            choice = result["choices"][0]
            message = choice.get("message", {})

            # Content normally lives in message["content"].
            content = message.get("content")

            # Some models put the text in a "reasoning" field instead.
            if content is None:
                reasoning = message.get("reasoning", "")
                if reasoning:
                    content = reasoning

            # Still nothing usable -> report a structured failure.
            if content is None:
                return {
                    "success": False,
                    "error": "No content found in response",
                    "content": None
                }

            return {
                "success": True,
                "error": None,
                "content": content,
                "usage": result.get("usage", {}),
                "model": result.get("model", self.model)
            }

        return {
            "success": False,
            "error": f"Unexpected response format: {result}",
            "content": None
        }
|
||||
|
||||
|
||||
# 全局 LLM 客户端实例(可通过环境变量配置)
|
||||
def get_llm_client() -> LLMClient:
    """Return an LLMClient whose configuration comes from environment variables."""
    # All defaults are resolved from env vars inside LLMClient.__init__.
    client = LLMClient()
    return client
|
||||
200
examples/medical/log.md
Normal file
200
examples/medical/log.md
Normal file
@ -0,0 +1,200 @@
|
||||
# 因果分析日志
|
||||
|
||||
## 日志说明
|
||||
本文档记录 LLM 进行因果分析时的所有输入参数和输出结果。
|
||||
|
||||
## 分析记录
|
||||
|
||||
---
|
||||
|
||||
### 分析 #001
|
||||
**时间**: 2026-03-26T23:19:00.211723
|
||||
|
||||
#### 系统提示词
|
||||
```
|
||||
你是一位专业的因果推断分析师。你的任务是分析给定的数据,识别因果变量并评估因果关系。
|
||||
|
||||
请以 JSON 格式输出分析结果,不要包含任何额外的解释或思考过程。
|
||||
|
||||
JSON 输出规范:
|
||||
{
|
||||
"treatment": "处理变量名称",
|
||||
"outcome": "结果变量名称"
|
||||
}
|
||||
|
||||
注意:
|
||||
- 只输出上述 JSON 格式,不要包含其他字段
|
||||
- 处理变量和结果变量名称必须与数据表格的列名完全一致
|
||||
- 不要使用 markdown 代码块标记(如 ```json)
|
||||
- 直接输出纯 JSON 字符串
|
||||
```
|
||||
|
||||
#### 用户提示词
|
||||
```
|
||||
请分析以下医疗数据,并严格按照 JSON 格式输出分析结果:
|
||||
|
||||
|
||||
**数据概览:**
|
||||
- 样本数量:500
|
||||
- 变量:id, treatment, health
|
||||
|
||||
**统计摘要:**
|
||||
id treatment health
|
||||
count 500.000000 500.000000 500.000000
|
||||
mean 250.500000 0.414000 0.586082
|
||||
std 144.481833 0.493042 0.236948
|
||||
min 1.000000 0.000000 0.000000
|
||||
25% 125.750000 0.000000 0.421975
|
||||
50% 250.500000 0.000000 0.583000
|
||||
75% 375.250000 1.000000 0.769625
|
||||
max 500.000000 1.000000 1.000000
|
||||
|
||||
**处理变量分布:**
|
||||
treatment
|
||||
0 293
|
||||
1 207
|
||||
|
||||
|
||||
JSON 输出格式要求:
|
||||
{
|
||||
"treatment": "处理变量名称",
|
||||
"outcome": "结果变量名称"
|
||||
}
|
||||
|
||||
要求:
|
||||
1. 处理变量和结果变量名称必须与表格列名完全一致
|
||||
2. 只输出 JSON,不要包含其他任何内容
|
||||
3. 不要使用 markdown 代码块标记
|
||||
```
|
||||
|
||||
#### LLM 输出
|
||||
```
|
||||
|
||||
|
||||
{
|
||||
"treatment": "treatment",
|
||||
"outcome": "health"
|
||||
}
|
||||
```
|
||||
|
||||
#### 调用参数
|
||||
```json
|
||||
{
|
||||
"data_path": "examples/medical/data.xlsx",
|
||||
"sample_size": 500,
|
||||
"variables": [
|
||||
"id",
|
||||
"treatment",
|
||||
"health"
|
||||
],
|
||||
"treatment_variable": "treatment",
|
||||
"outcome_variable": "health",
|
||||
"llm_params": {
|
||||
"base_url": "http://10.106.123.247:8000/v1",
|
||||
"model": "qwen3.5-35b",
|
||||
"temperature": 0.3,
|
||||
"max_tokens": 2048
|
||||
},
|
||||
"log_path": "examples/medical/log.md"
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
|
||||
---
|
||||
|
||||
### 分析 #002
|
||||
**时间**: 2026-03-26T23:19:53.065927
|
||||
|
||||
#### 系统提示词
|
||||
```
|
||||
你是一位专业的因果推断分析师。你的任务是分析给定的数据,识别因果变量并评估因果关系。
|
||||
|
||||
请以 JSON 格式输出分析结果,不要包含任何额外的解释或思考过程。
|
||||
|
||||
JSON 输出规范:
|
||||
{
|
||||
"treatment": "处理变量名称",
|
||||
"outcome": "结果变量名称"
|
||||
}
|
||||
|
||||
注意:
|
||||
- 只输出上述 JSON 格式,不要包含其他字段
|
||||
- 处理变量和结果变量名称必须与数据表格的列名完全一致
|
||||
- 不要使用 markdown 代码块标记(如 ```json)
|
||||
- 直接输出纯 JSON 字符串
|
||||
```
|
||||
|
||||
#### 用户提示词
|
||||
```
|
||||
请分析以下医疗数据,并严格按照 JSON 格式输出分析结果:
|
||||
|
||||
|
||||
**数据概览:**
|
||||
- 样本数量:500
|
||||
- 变量:id, treatment, health
|
||||
|
||||
**统计摘要:**
|
||||
id treatment health
|
||||
count 500.000000 500.000000 500.000000
|
||||
mean 250.500000 0.414000 0.586082
|
||||
std 144.481833 0.493042 0.236948
|
||||
min 1.000000 0.000000 0.000000
|
||||
25% 125.750000 0.000000 0.421975
|
||||
50% 250.500000 0.000000 0.583000
|
||||
75% 375.250000 1.000000 0.769625
|
||||
max 500.000000 1.000000 1.000000
|
||||
|
||||
**处理变量分布:**
|
||||
treatment
|
||||
0 293
|
||||
1 207
|
||||
|
||||
|
||||
JSON 输出格式要求:
|
||||
{
|
||||
"treatment": "处理变量名称",
|
||||
"outcome": "结果变量名称"
|
||||
}
|
||||
|
||||
要求:
|
||||
1. 处理变量和结果变量名称必须与表格列名完全一致
|
||||
2. 只输出 JSON,不要包含其他任何内容
|
||||
3. 不要使用 markdown 代码块标记
|
||||
```
|
||||
|
||||
#### LLM 输出
|
||||
```
|
||||
|
||||
|
||||
{
|
||||
"treatment": "treatment",
|
||||
"outcome": "health"
|
||||
}
|
||||
```
|
||||
|
||||
#### 调用参数
|
||||
```json
|
||||
{
|
||||
"data_path": "examples/medical/data.xlsx",
|
||||
"sample_size": 500,
|
||||
"variables": [
|
||||
"id",
|
||||
"treatment",
|
||||
"health"
|
||||
],
|
||||
"treatment_variable": "treatment",
|
||||
"outcome_variable": "health",
|
||||
"llm_params": {
|
||||
"base_url": "http://10.106.123.247:8000/v1",
|
||||
"model": "qwen3.5-35b",
|
||||
"temperature": 0.3,
|
||||
"max_tokens": 2048
|
||||
},
|
||||
"log_path": "examples/medical/log.md"
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
356
examples/medical/start.py
Normal file
356
examples/medical/start.py
Normal file
@ -0,0 +1,356 @@
|
||||
"""
|
||||
LLM 因果分析测试入口
|
||||
使用自然语言输入让 LLM 测试分析因果变量,并将所有参数记录到 log.md 文档中
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any, Optional
|
||||
import pandas as pd
|
||||
from llm_client import LLMClient
|
||||
|
||||
|
||||
class CausalAnalysisLogger:
    """Appends causal-analysis records (prompts, outputs, parameters) to a markdown log."""

    def __init__(self, log_path: str = "examples/medical/log.md"):
        self.log_path = log_path
        self.log_entries = []  # in-memory copies of entries written this session
        # BUGFIX: always define analysis_count here. Previously it was only
        # set inside _load_existing_entries(), whose early-return path (log
        # file missing) left the attribute undefined and made log_analysis()
        # raise AttributeError.
        self.analysis_count = 0
        self._init_log_file()
        self._load_existing_entries()

    def _init_log_file(self):
        """Create the log file with a markdown header if it does not exist yet."""
        header = """# 因果分析日志

## 日志说明
本文档记录 LLM 进行因果分析时的所有输入参数和输出结果。

## 分析记录
"""
        if not os.path.exists(self.log_path):
            with open(self.log_path, 'w', encoding='utf-8') as f:
                f.write(header)

    def _load_existing_entries(self):
        """Scan the existing log to find the highest analysis id already recorded."""
        if not os.path.exists(self.log_path):
            return

        import re  # local import: only needed for this one scan
        with open(self.log_path, 'r', encoding='utf-8') as f:
            content = f.read()
        matches = re.findall(r'### 分析 #(\d+)', content)
        if matches:
            self.analysis_count = max(int(m) for m in matches)

    def _append_to_log(self, content: str):
        """Append raw text to the log file."""
        with open(self.log_path, 'a', encoding='utf-8') as f:
            f.write(content)

    def log_analysis(self, system_prompt: str, user_prompt: str,
                     model_output: str, parameters: Dict[str, Any]):
        """Record one complete analysis: prompts, model output and call parameters."""
        # Allocate the next sequential analysis id.
        self.analysis_count += 1
        analysis_id = f"{self.analysis_count:03d}"
        timestamp = datetime.now().isoformat()

        entry = f"""
---

### 分析 #{analysis_id}
**时间**: {timestamp}

#### 系统提示词
```
{system_prompt}
```

#### 用户提示词
```
{user_prompt}
```

#### LLM 输出
```
{model_output}
```

#### 调用参数
```json
{json.dumps(parameters, indent=2, ensure_ascii=False)}
```

---

"""
        self._append_to_log(entry)
        self.log_entries.append({
            'id': analysis_id,
            'timestamp': timestamp,
            'system_prompt': system_prompt,
            'user_prompt': user_prompt,
            'model_output': model_output,
            'parameters': parameters
        })

        print(f"分析 #{analysis_id} 已记录到 {self.log_path}")
|
||||
|
||||
|
||||
class CausalAnalysisAgent:
|
||||
"""因果分析代理"""
|
||||
|
||||
def __init__(self, data_path: str = "examples/medical/data.xlsx", log_path: str = "examples/medical/log.md"):
|
||||
self.data_path = data_path
|
||||
self.logger = CausalAnalysisLogger(log_path)
|
||||
self.data = None
|
||||
|
||||
def load_data(self) -> pd.DataFrame:
|
||||
"""加载数据"""
|
||||
if not os.path.exists(self.data_path):
|
||||
raise FileNotFoundError(f"数据文件不存在:{self.data_path}")
|
||||
|
||||
self.data = pd.read_excel(self.data_path)
|
||||
print(f"成功加载数据:{self.data_path}")
|
||||
print(f"数据形状:{self.data.shape}")
|
||||
print(f"列名:{list(self.data.columns)}")
|
||||
print(f"\n数据预览:")
|
||||
print(self.data.head())
|
||||
return self.data
|
||||
|
||||
def _generate_system_prompt(self) -> str:
|
||||
"""生成系统提示词"""
|
||||
return """你是一位专业的因果推断分析师。你的任务是分析给定的数据,识别因果变量并评估因果关系。
|
||||
|
||||
请以 JSON 格式输出分析结果,不要包含任何额外的解释或思考过程。
|
||||
|
||||
JSON 输出规范:
|
||||
{
|
||||
"treatment": "处理变量名称",
|
||||
"outcome": "结果变量名称"
|
||||
}
|
||||
|
||||
注意:
|
||||
- 只输出上述 JSON 格式,不要包含其他字段
|
||||
- 处理变量和结果变量名称必须与数据表格的列名完全一致
|
||||
- 不要使用 markdown 代码块标记(如 ```json)
|
||||
- 直接输出纯 JSON 字符串"""
|
||||
|
||||
def _generate_user_prompt(self, data_info: str) -> str:
|
||||
"""生成用户提示词"""
|
||||
return f"""请分析以下医疗数据,并严格按照 JSON 格式输出分析结果:
|
||||
|
||||
{data_info}
|
||||
|
||||
JSON 输出格式要求:
|
||||
{{
|
||||
"treatment": "处理变量名称",
|
||||
"outcome": "结果变量名称"
|
||||
}}
|
||||
|
||||
要求:
|
||||
1. 处理变量和结果变量名称必须与表格列名完全一致
|
||||
2. 只输出 JSON,不要包含其他任何内容
|
||||
3. 不要使用 markdown 代码块标记"""
|
||||
|
||||
def _call_llm(self, system_prompt: str, user_prompt: str) -> Dict[str, Any]:
|
||||
"""
|
||||
调用 LLM API 获取响应
|
||||
|
||||
Args:
|
||||
system_prompt: 系统提示词
|
||||
user_prompt: 用户提示词
|
||||
|
||||
Returns:
|
||||
包含响应内容的字典
|
||||
"""
|
||||
# 初始化 LLM 客户端(从环境变量读取配置)
|
||||
import os
|
||||
llm_client = LLMClient(
|
||||
base_url=os.getenv("LLM_BASE_URL", "https://glm47flash.cloyir.com/v1"),
|
||||
model=os.getenv("LLM_MODEL", "qwen3.5-35b"),
|
||||
temperature=float(os.getenv("LLM_TEMPERATURE", "0.3")),
|
||||
max_tokens=int(os.getenv("LLM_MAX_TOKENS", "2048"))
|
||||
)
|
||||
|
||||
# 调用 LLM
|
||||
result = llm_client.generate_response(system_prompt, user_prompt)
|
||||
|
||||
return result
|
||||
|
||||
def analyze(self, custom_prompt: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
执行因果分析
|
||||
|
||||
Args:
|
||||
custom_prompt: 自定义用户提示词,如果为 None 则使用默认提示词
|
||||
|
||||
Returns:
|
||||
分析结果字典
|
||||
"""
|
||||
# 加载数据
|
||||
if self.data is None:
|
||||
self.load_data()
|
||||
|
||||
# 类型断言:确保 self.data 不是 None
|
||||
assert self.data is not None, "数据未加载"
|
||||
|
||||
# 生成数据信息
|
||||
data_info = f"""
|
||||
**数据概览:**
|
||||
- 样本数量:{len(self.data)}
|
||||
- 变量:{', '.join(self.data.columns)}
|
||||
|
||||
**统计摘要:**
|
||||
{self.data.describe()}
|
||||
|
||||
**处理变量分布:**
|
||||
{self.data['treatment'].value_counts().to_string()}
|
||||
"""
|
||||
|
||||
# 生成提示词
|
||||
system_prompt = self._generate_system_prompt()
|
||||
user_prompt = custom_prompt if custom_prompt else self._generate_user_prompt(data_info)
|
||||
|
||||
# 计算统计值
|
||||
treatment_mean = self.data[self.data['treatment'] == 1]['health'].mean() if self.data['treatment'].sum() > 0 else 0
|
||||
control_mean = self.data[self.data['treatment'] == 0]['health'].mean() if self.data['treatment'].sum() < len(self.data) else 0
|
||||
ate = treatment_mean - control_mean if self.data['treatment'].sum() > 0 and self.data['treatment'].sum() < len(self.data) else 0
|
||||
|
||||
# 调用 LLM 获取真实响应
|
||||
llm_result = self._call_llm(system_prompt, user_prompt)
|
||||
|
||||
# 打印 LLM 调用结果用于调试
|
||||
print(f"\nLLM 调用结果:{llm_result}")
|
||||
|
||||
if llm_result.get("success"):
|
||||
model_output = llm_result.get("content", "LLM 返回内容为空")
|
||||
else:
|
||||
# 如果 LLM 调用失败,使用模拟响应
|
||||
model_output = f"""## 因果分析结果(LLM 调用失败,使用模拟响应)
|
||||
|
||||
### 1. 变量识别
|
||||
- **处理变量 (Treatment)**: treatment (是否吃药)
|
||||
- 类型:二元变量 (0/1)
|
||||
- 含义:1 表示吃药,0 表示不吃药
|
||||
|
||||
- **结果变量 (Outcome)**: health (健康状态)
|
||||
- 类型:连续变量 (0-1 浮点数)
|
||||
- 含义:健康状态评分,越高越好
|
||||
|
||||
### 2. 因果效应估计
|
||||
根据数据描述性统计:
|
||||
- 吃药组 (treatment=1) 的平均健康状态:{treatment_mean:.4f}
|
||||
- 未吃药组 (treatment=0) 的平均健康状态:{control_mean:.4f}
|
||||
- **平均处理效应 (ATE)**: {ate:.4f}
|
||||
|
||||
### 3. 分析方法
|
||||
使用简单的组间比较方法估计因果效应:
|
||||
- 方法:均值差异 (Mean Difference)
|
||||
- 假设:无混杂因素或混杂因素已控制
|
||||
|
||||
### 4. 结论
|
||||
吃药对健康状态有正向的因果效应,估计效应大小为 {ate:.4f}。
|
||||
这意味着吃药可以使健康状态平均提高 {ate * 100:.1f}%。
|
||||
|
||||
### 5. 建议
|
||||
- 建议进行更严格的因果推断分析(如倾向得分匹配、工具变量法等)
|
||||
- 考虑控制可能的混杂因素(如年龄、基础健康状况等)
|
||||
|
||||
---
|
||||
**LLM 调用错误**: {llm_result.get('error', 'Unknown error')}"""
|
||||
|
||||
# 准备参数
|
||||
parameters = {
|
||||
'data_path': self.data_path,
|
||||
'sample_size': len(self.data),
|
||||
'variables': list(self.data.columns),
|
||||
'treatment_variable': 'treatment',
|
||||
'outcome_variable': 'health'
|
||||
}
|
||||
|
||||
# 记录到日志(包含 LLM 调用信息)
|
||||
import os
|
||||
llm_params = {
|
||||
'base_url': os.getenv("LLM_BASE_URL", "https://glm47flash.cloyir.com/v1"),
|
||||
'model': os.getenv("LLM_MODEL", "qwen3.5-35b"),
|
||||
'temperature': float(os.getenv("LLM_TEMPERATURE", "0.3")),
|
||||
'max_tokens': int(os.getenv("LLM_MAX_TOKENS", "2048"))
|
||||
}
|
||||
parameters['llm_params'] = llm_params
|
||||
parameters['data_path'] = self.data_path
|
||||
parameters['log_path'] = self.logger.log_path
|
||||
|
||||
self.logger.log_analysis(system_prompt, user_prompt, model_output, parameters)
|
||||
|
||||
# 打印结果
|
||||
print(f"\n{'='*60}")
|
||||
print(f"分析 #{self.logger.analysis_count:03d} 结果:")
|
||||
print(f"{'='*60}")
|
||||
print(model_output)
|
||||
|
||||
return {
|
||||
'id': f"{self.logger.analysis_count:03d}",
|
||||
'system_prompt': system_prompt,
|
||||
'user_prompt': user_prompt,
|
||||
'model_output': model_output,
|
||||
'parameters': parameters
|
||||
}
|
||||
|
||||
def interactive_analyze(self):
    """Run an interactive analysis loop on stdin.

    Loads the dataset once, then repeatedly prompts the user for an
    analysis prompt. 'quit' exits the loop, 'default' falls back to the
    built-in prompt (``custom_prompt=None``). Any exception raised by a
    single analysis round is reported and the loop continues.
    """
    banner = "=" * 60
    print(banner)
    print("LLM 因果分析测试系统")
    print(banner)

    # Load the data once before entering the prompt loop.
    self.load_data()

    while True:
        print("\n" + "-" * 40)
        print("请输入分析提示词(输入 'quit' 退出,'default' 使用默认):")
        user_input = input("> ").strip()
        command = user_input.lower()

        if command == 'quit':
            print("感谢使用,再见!")
            break
        if command == 'default':
            user_input = None

        try:
            self.analyze(custom_prompt=user_input)
            print(f"\n分析已记录到:{self.logger.log_path}")
        except Exception as e:
            # Keep the session alive on per-round failures.
            print(f"分析出错:{e}")
|
||||
|
||||
def main():
    """Script entry point: ensure the sample dataset exists, then run one analysis.

    If the Excel data file is missing it is generated on the fly, after
    which a ``CausalAnalysisAgent`` performs a single analysis pass and the
    log location is reported.
    """
    data_path = "examples/medical/data.xlsx"

    # Generate the dataset on demand.
    if not os.path.exists(data_path):
        print("数据文件不存在,正在生成...")
        from data_generator import generate_medical_data
        generate_medical_data(n_samples=500, output_path=data_path)

    # Build the agent and run one analysis pass.
    agent = CausalAnalysisAgent(data_path=data_path, log_path="examples/medical/log.md")
    agent.analyze()

    banner = "=" * 60
    print("\n" + banner)
    print("测试完成!")
    print(f"日志已保存到:{agent.logger.log_path}")
    print(banner)


if __name__ == "__main__":
    main()
174
examples/medical_v2/data_generator.py
Normal file
174
examples/medical_v2/data_generator.py
Normal file
@ -0,0 +1,174 @@
|
||||
"""
|
||||
医疗测试数据生成器
|
||||
|
||||
生成符合因果推断测试要求的医疗数据,包含以下特征:
|
||||
- id: 唯一标识符
|
||||
- treatment: 是否吃药(0 或 1),吃药比例 2/3
|
||||
- health: 病人的健康状态(0~1 浮点数,越高越好)
|
||||
- base_health: 病人吃药后的健康状态,如果没吃药就和 health 栏一致
|
||||
- age: 年龄(18~70 岁),年龄越大,健康状态越低,药效越显著
|
||||
|
||||
数据生成逻辑:
|
||||
- 基础健康状态服从 Beta 分布
|
||||
- 年龄对健康有负面影响(年龄越大,健康越低)
|
||||
- 年龄对药效有调节作用(年龄越大,吃药带来的健康提升越显著)
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Tuple
|
||||
|
||||
|
||||
def generate_medical_data(
    n_samples: int = 500,
    treatment_ratio: float = 2/3,
    seed: int = 42
) -> pd.DataFrame:
    """Generate a synthetic medical dataset for causal-inference testing.

    Columns produced:
      - id: unique identifier (1..n_samples)
      - treatment: 0/1, whether the patient took the drug
      - health: observed health score in [0, 1] (higher is better)
      - base_health: health the patient would have without the drug
        (equals ``health`` for untreated patients)
      - age: age in years, uniform over 18..70; older patients are less
        healthy and respond more strongly to treatment

    Treatment assignment is confounded: older and less-healthy patients
    are more likely to be treated (logistic model).

    Args:
        n_samples: number of rows to generate.
        treatment_ratio: nominal treated fraction.
            NOTE(review): this parameter is currently unused — assignment
            comes entirely from the logistic model below; confirm whether
            it should calibrate the intercept.
        seed: RNG seed for reproducibility.

    Returns:
        A DataFrame with the columns described above.
    """
    # Fix the RNG so repeated calls with the same seed are identical.
    np.random.seed(seed)

    ids = list(range(1, n_samples + 1))

    # Ages uniform on 18..70 inclusive.
    ages = np.random.randint(18, 71, size=n_samples)

    # Raw baseline health ~ Beta(2, 2) on [0, 1] (mean 0.5).
    base_health_raw = np.random.beta(2, 2, size=n_samples)

    # Age penalty: normalize age to [0, 1] and scale by 0.3.
    age_normalized = (ages - 18) / (70 - 18)
    age_health_effect = 0.3 * age_normalized

    # Baseline health after the age penalty, clipped to [0, 1].
    base_health = np.clip(base_health_raw * (1 - age_health_effect), 0, 1)

    # Confounded treatment assignment: older / sicker -> more likely treated.
    logit_p = 1.5 * age_normalized - 2.0 * base_health + 0.8
    prob_treatment = 1 / (1 + np.exp(-logit_p))
    treatments = (np.random.uniform(0, 1, size=n_samples) < prob_treatment).astype(int)

    # Treatment effect grows with age: 0.1 (young) up to 0.4 (old).
    treatment_effect = 0.1 + 0.3 * age_normalized

    # Observed health: baseline plus the effect for treated patients.
    treated = treatments == 1
    health = base_health.copy()
    health[treated] = np.clip(
        base_health[treated] + treatment_effect[treated],
        0, 1
    )

    # Counterfactual baseline column: back out the effect for the treated.
    base_health_final = health.copy()
    base_health_final[treated] = np.clip(
        health[treated] - treatment_effect[treated],
        0, 1
    )

    return pd.DataFrame({
        'id': ids,
        'treatment': treatments,
        'health': np.round(health, 4),
        'base_health': np.round(base_health_final, 4),
        'age': ages
    })
|
||||
|
||||
def print_data_preview(df: pd.DataFrame) -> None:
    """Print the first ten rows of *df* under a banner, without the index."""
    banner = "=" * 60
    print(banner)
    print("数据预览(前 10 行)")
    print(banner)
    print(df.head(10).to_string(index=False))
    print()
||||
|
||||
|
||||
def print_statistics(df: pd.DataFrame) -> None:
    """Print summary statistics for the generated medical dataset.

    Sections: overall describe(), treatment distribution, per-treatment
    health stats, treatment-effect-by-age-quartile for treated rows, and
    two correlation coefficients.

    Fix: the first age-quartile label was 'Q1( youngest)' (stray space);
    corrected to 'Q1(youngest)' to match 'Q4(oldest)'.

    Args:
        df: DataFrame with at least 'treatment', 'health', 'base_health'
            and 'age' columns, as produced by generate_medical_data.
    """
    print("=" * 60)
    print("统计信息")
    print("=" * 60)

    # Overall summary of every numeric column.
    print("\n【总体统计】")
    print(df.describe().round(4))

    # Treated / untreated counts and percentages.
    print("\n【Treatment 分布】")
    treatment_dist = df['treatment'].value_counts()
    print(f"不吃药 (0): {treatment_dist.get(0, 0)} 人 ({100*treatment_dist.get(0, 0)/len(df):.1f}%)")
    print(f"吃药 (1):  {treatment_dist.get(1, 0)} 人 ({100*treatment_dist.get(1, 0)/len(df):.1f}%)")

    # Health statistics split by treatment status.
    print("\n【按 Treatment 分组的健康状态统计】")
    grouped = df.groupby('treatment')['health'].agg(['mean', 'std', 'min', 'max'])
    print(grouped.round(4))

    # Effect size (health gain) by age, treated patients only.
    print("\n【年龄与药效关系分析】")
    treated_df = df[df['treatment'] == 1].copy()
    treated_df['age_normalized'] = (treated_df['age'] - 18) / (70 - 18)
    treated_df['health_gain'] = treated_df['health'] - treated_df['base_health']

    # Quartile split of treated patients by age.
    treated_df['age_quartile'] = pd.qcut(treated_df['age'], q=4, labels=['Q1(youngest)', 'Q2', 'Q3', 'Q4(oldest)'])
    quartile_analysis = treated_df.groupby('age_quartile')['health_gain'].agg(['mean', 'std', 'count'])
    print(quartile_analysis.round(4))

    # Simple Pearson correlations.
    print("\n【相关性分析】")
    print(f"年龄与健康的相关系数:{df['age'].corr(df['health']):.4f}")
    print(f"年龄与药效的相关系数:{treated_df['age_normalized'].corr(treated_df['health_gain']):.4f}")
|
||||
|
||||
def main():
    """Generate the sample dataset, print diagnostics, and save it as Excel.

    The output file is written next to this script so the location does
    not depend on the current working directory.
    """
    print("医疗测试数据生成器")
    print("-" * 40)

    print("正在生成数据...")
    df = generate_medical_data(n_samples=500, treatment_ratio=2/3, seed=42)

    # Show a preview and summary statistics on stdout.
    print_data_preview(df)
    print_statistics(df)

    # Resolve the output path relative to this file (absolute path).
    import os
    output_file = os.path.join(os.path.dirname(__file__), "data.xlsx")
    df.to_excel(output_file, index=False)
    print(f"\n数据已保存到:{output_file}")

    return df


if __name__ == "__main__":
    main()
||||
1196
examples/medical_v2/log.md
Normal file
1196
examples/medical_v2/log.md
Normal file
File diff suppressed because it is too large
Load Diff
46
examples/medical_v2/start.py
Normal file
46
examples/medical_v2/start.py
Normal file
@ -0,0 +1,46 @@
|
||||
"""
|
||||
医疗数据因果推断分析示例(v2)
|
||||
|
||||
本脚本演示如何使用 causal_agent 主包对医疗数据进行完整的因果推断分析。
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
# 将项目根目录加入路径
|
||||
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
sys.path.insert(0, project_root)
|
||||
|
||||
from causal_agent.agent import CausalInferenceAgent
|
||||
from causal_agent.core.config import AgentConfig
|
||||
|
||||
|
||||
def main():
    """Run the full causal-inference pipeline on the v2 medical dataset.

    Generates the example data if missing, configures the agent from the
    environment (overriding the log path), runs one analysis, and prints
    a short completion summary.
    """
    data_path = "examples/medical_v2/data.xlsx"

    # Create the example dataset on first run.
    if not os.path.exists(data_path):
        print("数据文件不存在,正在生成示例医疗数据...")
        from examples.medical_v2.data_generator import generate_medical_data

        df = generate_medical_data(n_samples=500, treatment_ratio=2 / 3, seed=42)
        df.to_excel(data_path, index=False)
        print(f"数据已保存到:{data_path}")

    # Agent configuration: environment defaults plus a dedicated log file.
    config = AgentConfig.from_env()
    config.log_path = "examples/medical_v2/log.md"

    # One full analysis pass over the dataset.
    agent = CausalInferenceAgent(config)
    result = agent.analyze(data_path)

    banner = "=" * 60
    print("\n" + banner)
    print("示例运行完成!")
    print(f"分析 ID:{result['id']}")
    print(f"日志已保存到:{agent.logger.log_path}")
    print(banner)


if __name__ == "__main__":
    main()
||||
23
input_processor/__init__.py
Normal file
23
input_processor/__init__.py
Normal file
@ -0,0 +1,23 @@
|
||||
"""
|
||||
输入处理模块 (Input Processor Module)
|
||||
|
||||
该模块提供文件查询和因果推断解析功能,用于处理用户自然语言问题并识别因果推断要素。
|
||||
|
||||
主要组件:
|
||||
- FileQueryTool: 文件查询工具,支持 CSV 和 Excel 文件
|
||||
- CausalParser: 因果推断解析器,使用 LLM 解析用户问题
|
||||
"""
|
||||
|
||||
from .file_query_tool import FileQueryTool, FileMetadata, ColumnInfo, SampleData
|
||||
from .causal_parser import CausalParser, CausalInferenceRequest, CausalInferenceResult
|
||||
|
||||
__version__ = "1.0.0"
|
||||
__all__ = [
|
||||
"FileQueryTool",
|
||||
"FileMetadata",
|
||||
"ColumnInfo",
|
||||
"SampleData",
|
||||
"CausalParser",
|
||||
"CausalInferenceRequest",
|
||||
"CausalInferenceResult",
|
||||
]
|
||||
723
input_processor/causal_parser.py
Normal file
723
input_processor/causal_parser.py
Normal file
@ -0,0 +1,723 @@
|
||||
"""
|
||||
因果推断解析器模块 (Causal Parser)
|
||||
|
||||
该模块提供因果推断解析功能,使用启发式规则解析用户的自然语言问题,
|
||||
识别因果推断中的关键要素(处理变量 T、结果变量 Y、协变量)。
|
||||
"""
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass, asdict, field
|
||||
from typing import List, Optional, Dict, Any, Union
|
||||
from datetime import datetime
|
||||
|
||||
# 支持直接运行和作为模块导入两种方式
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
import os
|
||||
sys.path.insert(0, os.path.dirname(__file__))
|
||||
from file_query_tool import FileQueryTool, FileMetadata
|
||||
else:
|
||||
from .file_query_tool import FileQueryTool, FileMetadata
|
||||
|
||||
@dataclass
class CausalInferenceRequest:
    """A single causal-inference parsing request.

    Bundles the user's natural-language question with the data file it
    refers to and how many rows of that file should be inspected.
    """
    user_question: str  # the user's natural-language question
    file_path: str      # path to the data file
    read_rows: int = 100  # number of rows to read for schema inspection

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain-dict view of this request."""
        return asdict(self)
|
||||
|
||||
@dataclass
class VariableInfo:
    """A single variable identified during causal parsing.

    Records the variable's role (treatment/outcome/control), how confident
    the heuristic was, and the rationale behind the choice.
    """
    name: str            # column name in the data file
    variable_type: str   # role: "treatment", "outcome" or "control"
    confidence: float    # heuristic confidence in [0, 1]
    reasoning: str       # human-readable justification
    data_type: Optional[str] = None   # pandas dtype string, if known
    is_binary: Optional[bool] = None  # True when the column has exactly two values

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain-dict view of this variable."""
        return asdict(self)
|
||||
|
||||
@dataclass
class CausalInferenceResult:
    """Outcome of parsing one causal-inference request.

    Holds the identified treatment/outcome/control variables together
    with an overall confidence score, the reasoning text, alternative
    candidate matches and free-form metadata.
    """
    original_question: str  # the user's original question
    file_path: str          # data file the question refers to
    treatment_variable: Optional[VariableInfo] = None  # identified treatment (T)
    outcome_variable: Optional[VariableInfo] = None    # identified outcome (Y)
    control_variables: List[VariableInfo] = field(default_factory=list)  # covariates
    confidence_score: float = 0.0  # overall confidence in [0, 1]
    reasoning: str = ""            # human-readable justification
    alternative_matches: List[Dict[str, Any]] = field(default_factory=list)  # other candidates
    metadata: Dict[str, Any] = field(default_factory=dict)  # timestamps, file info, errors

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain-dict representation.

        ``dataclasses.asdict`` already recurses into nested dataclasses
        (the VariableInfo fields and the control_variables list), so the
        manual per-field re-conversion the original code performed was
        redundant and has been removed. Output is unchanged.
        """
        return asdict(self)

    def to_json(self, indent: int = 2) -> str:
        """Serialize to a JSON string, preserving non-ASCII characters."""
        return json.dumps(self.to_dict(), ensure_ascii=False, indent=indent)
|
||||
|
||||
class CausalParser:
|
||||
"""
|
||||
因果推断解析器
|
||||
|
||||
使用 LLM 解析用户的自然语言问题,识别因果推断中的关键要素。
|
||||
|
||||
Attributes:
|
||||
file_query_tool: 文件查询工具实例
|
||||
llm_client: LLM 客户端(可选)
|
||||
default_read_rows: 默认读取的行数
|
||||
|
||||
Example:
|
||||
>>> parser = CausalParser()
|
||||
>>> request = CausalInferenceRequest(
|
||||
... user_question="教育对收入的影响是什么?",
|
||||
... file_path="data/education_data.csv"
|
||||
... )
|
||||
>>> result = parser.parse(request)
|
||||
>>> print(result.to_json())
|
||||
"""
|
||||
|
||||
# 内置的提示词模板
|
||||
DEFAULT_SYSTEM_PROMPT = """你是一位因果推断专家,负责从用户的自然语言问题中识别因果推断的关键要素。
|
||||
|
||||
任务:
|
||||
1. 分析用户的问题,理解其因果推断意图
|
||||
2. 根据提供的数据表结构,识别以下变量:
|
||||
- 处理变量 (T/Treatment): 被操纵或观察的干预措施
|
||||
- 结果变量 (Y/Outcome): 受处理影响的指标
|
||||
- 协变量 (Control): 需要控制的混淆因素
|
||||
|
||||
输出要求:
|
||||
- 只输出 JSON 格式,不要包含其他内容
|
||||
- 对于每个变量,提供名称、类型、置信度和推理理由
|
||||
- 如果无法确定某个变量,将其设为 null 或空列表
|
||||
- 置信度评分范围:0.0-1.0
|
||||
|
||||
变量类型说明:
|
||||
- treatment: 处理变量,通常是实验中的干预措施或观察中的暴露因素
|
||||
- outcome: 结果变量,是我们要研究的效应指标
|
||||
- control: 协变量/控制变量,用于控制混淆因素
|
||||
|
||||
示例输出格式:
|
||||
{{
|
||||
"treatment_variable": {{
|
||||
"name": "treatment",
|
||||
"variable_type": "treatment",
|
||||
"confidence": 0.95,
|
||||
"reasoning": "该变量表示是否接受干预"
|
||||
}},
|
||||
"outcome_variable": {{
|
||||
"name": "outcome",
|
||||
"variable_type": "outcome",
|
||||
"confidence": 0.90,
|
||||
"reasoning": "该变量表示干预后的结果"
|
||||
}},
|
||||
"control_variables": [
|
||||
{{
|
||||
"name": "age",
|
||||
"variable_type": "control",
|
||||
"confidence": 0.85,
|
||||
"reasoning": "年龄可能影响结果,需要控制"
|
||||
}}
|
||||
],
|
||||
"confidence_score": 0.90,
|
||||
"reasoning": "根据问题描述,这是一个典型的因果推断问题..."
|
||||
}}"""
|
||||
|
||||
DEFAULT_USER_PROMPT_TEMPLATE = """用户问题:{question}
|
||||
|
||||
数据表结构:
|
||||
{table_info}
|
||||
|
||||
请识别因果推断要素并输出 JSON 格式的结果。"""
|
||||
|
||||
def __init__(self, file_query_tool: Optional[FileQueryTool] = None,
             default_read_rows: int = 100):
    """Initialize the parser.

    Args:
        file_query_tool: tool used to inspect data files; a fresh
            ``FileQueryTool`` is created when none is supplied.
        default_read_rows: how many rows to read by default when
            inspecting a file.
    """
    # Fall back to a default FileQueryTool when no tool is provided.
    self.file_query_tool = file_query_tool or FileQueryTool()
    self.default_read_rows = default_read_rows
|
||||
def parse(self, request: CausalInferenceRequest) -> CausalInferenceResult:
    """Parse one causal-inference request.

    Inspects the referenced data file (schema + sample rows) via the
    file-query tool, then applies heuristic rules to identify the
    treatment, outcome and control variables.

    Args:
        request: the request holding the question, file path and row count.

    Returns:
        The heuristic parsing result.
    """
    # Inspect the file: schema, column stats and sample rows.
    file_metadata = self.file_query_tool.query(
        request.file_path,
        request.read_rows
    )

    # Hand off to the heuristic identification pipeline.
    return self._heuristic_parse(request.user_question, file_metadata)
|
||||
def _build_table_info(self, metadata: FileMetadata) -> str:
    """Format file metadata as a human-readable table description.

    Emits the file name, row/column counts, and one entry per column
    (dtype, null counts, unique count, up to three sample values).

    Args:
        metadata: file metadata from the file-query tool.

    Returns:
        The formatted multi-line string.
    """
    out = [
        f"文件名:{metadata.file_name}",
        f"总行数:{metadata.total_rows}",
        f"总列数:{metadata.total_columns}",
        "",
        "列信息:"
    ]

    # Without column metadata only the header section is produced.
    if metadata.columns is None:
        return "\n".join(out)

    for col in metadata.columns:
        out.append(f"  - {col.name}: {col.dtype} (非空:{col.non_null_count}, 空值:{col.null_count}, 唯一值:{col.unique_count})")
        if col.sample_values:
            # Show at most three sample values per column.
            preview = ", ".join(str(v) for v in col.sample_values[:3])
            out.append(f"    样本值:{preview}")

    return "\n".join(out)
|
||||
def _heuristic_parse(self, question: str, metadata: FileMetadata) -> CausalInferenceResult:
    """Identify causal-inference elements using heuristic rules.

    Runs the treatment/outcome/control identification heuristics,
    derives an overall confidence score and reasoning text, collects
    alternative candidate matches, and packages everything into a
    CausalInferenceResult.

    Fix: removed ~90 lines of unreachable code that followed the
    ``return`` statement — a leftover "mock LLM" duplicate of this very
    logic (orphaned docstring, re-running the same heuristics, ending in
    a second ``return json.dumps(...)``). Behavior of the reachable code
    is unchanged.

    Args:
        question: the user's natural-language question.
        metadata: file metadata from the file-query tool.

    Returns:
        The assembled parsing result.
    """
    # Identify the three variable roles with heuristic rules.
    treatment_var = self._heuristic_identify_treatment(question, metadata)
    outcome_var = self._heuristic_identify_outcome(question, metadata)
    control_vars = self._heuristic_identify_controls(question, metadata, treatment_var, outcome_var)

    # Derived summary information.
    confidence = self._calculate_confidence(treatment_var, outcome_var, control_vars)
    reasoning = self._build_reasoning(question, treatment_var, outcome_var, control_vars)
    alternative_matches = self._find_alternative_matches(metadata, treatment_var, outcome_var)

    return CausalInferenceResult(
        original_question=question,
        file_path=metadata.file_path,
        treatment_variable=treatment_var,
        outcome_variable=outcome_var,
        control_variables=control_vars,
        confidence_score=confidence,
        reasoning=reasoning,
        alternative_matches=alternative_matches,
        metadata={
            "parsed_at": datetime.now().isoformat(),
            "file_info": {
                "file_name": metadata.file_name,
                "total_rows": metadata.total_rows,
                "total_columns": metadata.total_columns
            }
        }
    )
|
||||
def _heuristic_identify_treatment(self, question: str,
                                  metadata: FileMetadata) -> Optional[VariableInfo]:
    """Heuristically pick the treatment variable (T).

    First pass: return the first column whose lower-cased name contains a
    known treatment keyword. Second pass: fall back to the first strictly
    binary column with no missing values. Returns None when neither pass
    matches.

    Args:
        question: the user's question (currently unused by the rules).
        metadata: file metadata with per-column statistics.

    Returns:
        The identified treatment variable, or None.
    """
    treatment_keywords = [
        "处理", "干预", "治疗", "实验组", "对照组", "treatment", "treat",
        "干预组", "实验", "exposure", "exposed", "assigned", "group"
    ]

    if metadata.columns is None:
        return None

    # Pass 1: keyword match on the column name (keyword order matters).
    for col in metadata.columns:
        lowered = col.name.lower()
        hit = next((kw for kw in treatment_keywords if kw in lowered), None)
        if hit is not None:
            return VariableInfo(
                name=col.name,
                variable_type="treatment",
                confidence=0.85,
                reasoning=f"列名包含处理变量关键词:{hit}",
                data_type=col.dtype,
                is_binary=col.unique_count == 2 if col.unique_count else None
            )

    # Pass 2: a fully populated binary column is a plausible treatment.
    for col in metadata.columns:
        if col.unique_count == 2 and col.non_null_count == metadata.total_rows:
            return VariableInfo(
                name=col.name,
                variable_type="treatment",
                confidence=0.70,
                reasoning="发现二值变量,可能是处理变量",
                data_type=col.dtype,
                is_binary=True
            )

    return None
|
||||
def _heuristic_identify_outcome(self, question: str,
                                metadata: FileMetadata) -> Optional[VariableInfo]:
    """Heuristically pick the outcome variable (Y).

    First pass: return the first column whose lower-cased name contains a
    known outcome keyword. Second pass: fall back to the first numeric
    column. Returns None when neither pass matches.

    Args:
        question: the user's question (currently unused by the rules).
        metadata: file metadata with per-column statistics.

    Returns:
        The identified outcome variable, or None.
    """
    outcome_keywords = [
        "结果", "影响", "效应", "收入", "工资", "成绩", "分数", "outcome",
        "result", "effect", "impact", "dependent", "y", "目标", "指标"
    ]

    if metadata.columns is None:
        return None

    # Pass 1: keyword match on the column name (keyword order matters).
    for col in metadata.columns:
        lowered = col.name.lower()
        hit = next((kw for kw in outcome_keywords if kw in lowered), None)
        if hit is not None:
            return VariableInfo(
                name=col.name,
                variable_type="outcome",
                confidence=0.85,
                reasoning=f"列名包含结果变量关键词:{hit}",
                data_type=col.dtype,
                is_binary=False
            )

    # Pass 2: a numeric column is a plausible outcome.
    numeric_dtypes = ('float64', 'int64', 'float32', 'int32')
    for col in metadata.columns:
        if col.dtype in numeric_dtypes:
            return VariableInfo(
                name=col.name,
                variable_type="outcome",
                confidence=0.65,
                reasoning="发现数值型列,可能是结果变量",
                data_type=col.dtype,
                is_binary=False
            )

    return None
|
||||
def _heuristic_identify_controls(self, question: str,
                                 metadata: FileMetadata,
                                 treatment: Optional[VariableInfo],
                                 outcome: Optional[VariableInfo]) -> List[VariableInfo]:
    """Heuristically identify covariates (control variables).

    Every column not already chosen as treatment or outcome becomes a
    candidate control. Columns whose name contains a known covariate
    keyword get confidence 0.60; the remaining fallback candidates get
    0.40.

    Fix: the original code unconditionally set ``is_control = True`` in
    the fallback branch, which made the ``0.40`` confidence branch of
    ``0.60 if is_control else 0.40`` unreachable — every control scored
    0.60 regardless of evidence. Keyword matches now score 0.60 and
    fallback candidates 0.40, as the original expression clearly intended.

    Args:
        question: the user's question (currently unused by the rules).
        metadata: file metadata with per-column statistics.
        treatment: the already-identified treatment variable, if any.
        outcome: the already-identified outcome variable, if any.

    Returns:
        The list of candidate control variables (possibly empty).
    """
    control_keywords = [
        "年龄", "性别", "教育", "经验", "控制", "covariate", "control",
        "confounder", "confounding", "特征", "变量", "demographic"
    ]

    if metadata.columns is None:
        return []

    # Columns already claimed by the treatment/outcome heuristics.
    taken = set()
    if treatment:
        taken.add(treatment.name)
    if outcome:
        taken.add(outcome.name)

    controls: List[VariableInfo] = []
    for col in metadata.columns:
        if col.name in taken:
            continue

        lowered = col.name.lower()
        hit = next((kw for kw in control_keywords if kw in lowered), None)
        if hit is not None:
            confidence = 0.60
            reasoning = f"列名包含协变量关键词:{hit}"
        else:
            # No keyword evidence: still a plausible confounder, but weaker.
            confidence = 0.40
            reasoning = "可能是协变量,用于控制混淆因素"

        controls.append(VariableInfo(
            name=col.name,
            variable_type="control",
            confidence=confidence,
            reasoning=reasoning,
            data_type=col.dtype,
            is_binary=col.unique_count == 2 if col.unique_count else None
        ))

    return controls
|
||||
def _calculate_confidence(self, treatment: Optional[VariableInfo],
                          outcome: Optional[VariableInfo],
                          controls: List[VariableInfo]) -> float:
    """Combine per-variable confidences into one overall score.

    Returns 0.3 when either the treatment or the outcome is missing.
    Otherwise averages the treatment and outcome confidences, and — when
    controls exist — averages that with the mean control confidence.
    The result is capped at 1.0.

    Args:
        treatment: identified treatment variable, if any.
        outcome: identified outcome variable, if any.
        controls: identified control variables (possibly empty).

    Returns:
        Overall confidence in [0, 1].
    """
    # Missing either core variable: low fixed confidence.
    if treatment is None or outcome is None:
        return 0.3

    score = (treatment.confidence + outcome.confidence) / 2

    # Blend in the mean confidence of the controls, when present.
    if len(controls) > 0:
        mean_control = sum(c.confidence for c in controls) / len(controls)
        score = (score + mean_control) / 2

    return min(score, 1.0)
|
||||
def _build_reasoning(self, question: str, treatment: Optional[VariableInfo],
                     outcome: Optional[VariableInfo],
                     controls: List[VariableInfo]) -> str:
    """Assemble a human-readable justification string.

    Mentions the identified treatment and outcome by name; for controls,
    lists at most three names and summarizes the remainder by count.
    Falls back to a fixed "not enough elements identified" message when
    nothing was found.

    Args:
        question: the user's question (currently unused in the text).
        treatment: identified treatment variable, if any.
        outcome: identified outcome variable, if any.
        controls: identified control variables (possibly empty).

    Returns:
        The joined reasoning sentence.
    """
    parts = []

    if treatment:
        parts.append(f"处理变量 '{treatment.name}' 被识别为干预措施")

    if outcome:
        parts.append(f"结果变量 '{outcome.name}' 被识别为效应指标")

    if controls:
        # Show up to three control names; summarize any overflow.
        shown = [c.name for c in controls[:3]]
        if len(controls) > 3:
            shown.append(f"等 {len(controls) - 3} 个其他变量")
        parts.append(f"协变量 {', '.join(shown)} 被识别为需要控制的混淆因素")

    return ";".join(parts) if parts else "未能充分识别因果推断要素"
|
||||
def _find_alternative_matches(self, metadata: FileMetadata,
                              treatment: Optional[VariableInfo],
                              outcome: Optional[VariableInfo]) -> List[Dict[str, Any]]:
    """List alternative candidate columns for the chosen roles.

    For each role that was identified, every other column becomes a
    low-confidence (0.3) alternative for that role. Treatment
    alternatives come first, then outcome alternatives; at most five
    entries are returned.

    Args:
        metadata: file metadata with the column list.
        treatment: identified treatment variable, if any.
        outcome: identified outcome variable, if any.

    Returns:
        Up to five alternative-match dicts.
    """
    if metadata.columns is None:
        return []

    alternatives: List[Dict[str, Any]] = []

    # Same scan for both roles; order (treatment first) is preserved.
    role_specs = (
        (treatment, "treatment", "可能是替代的处理变量"),
        (outcome, "outcome", "可能是替代的结果变量"),
    )
    for chosen, role, reason in role_specs:
        if not chosen:
            continue
        for col in metadata.columns:
            if col.name != chosen.name:
                alternatives.append({
                    "variable_type": role,
                    "name": col.name,
                    "confidence": 0.3,
                    "reasoning": reason
                })

    # Cap the list at five entries.
    return alternatives[:5]
|
||||
def _parse_llm_response(self, response: str, request: CausalInferenceRequest,
                        metadata: FileMetadata) -> CausalInferenceResult:
    """Parse an LLM's JSON response into a CausalInferenceResult.

    Strips optional markdown code fences, decodes the JSON payload and
    maps it onto the result dataclasses. On a JSON decode error a
    zero-confidence result is returned with the error message and a
    500-character snippet of the raw response stored in ``metadata``.

    Fix: the opening plain code fence is three characters ("```"), but
    the original stripped only one (``response[1:]``), leaving "``" at
    the front and guaranteeing a JSONDecodeError for responses wrapped
    in bare ``` fences. It now strips three characters.

    Args:
        response: raw text returned by the LLM.
        request: the originating request (question, file path).
        metadata: file metadata used to fill the result's file info.

    Returns:
        The parsed result, or an error result on decode failure.
    """
    try:
        # Strip optional markdown code fences around the JSON payload.
        response = response.strip()
        if response.startswith("```json"):
            response = response[7:]
        if response.startswith("```"):
            response = response[3:]
        if response.endswith("```"):
            response = response[:-3]
        response = response.strip()

        data = json.loads(response)

        return CausalInferenceResult(
            original_question=request.user_question,
            file_path=request.file_path,
            treatment_variable=VariableInfo(**data.get("treatment_variable")) if data.get("treatment_variable") else None,
            outcome_variable=VariableInfo(**data.get("outcome_variable")) if data.get("outcome_variable") else None,
            control_variables=[VariableInfo(**v) for v in data.get("control_variables", [])],
            confidence_score=data.get("confidence_score", 0.0),
            reasoning=data.get("reasoning", ""),
            alternative_matches=data.get("alternative_matches", []),
            metadata={
                "parsed_at": datetime.now().isoformat(),
                "file_info": {
                    "file_name": metadata.file_name,
                    "total_rows": metadata.total_rows,
                    "total_columns": metadata.total_columns
                }
            }
        )

    except json.JSONDecodeError as e:
        # Decode failure: return an "empty" result that records the error.
        return CausalInferenceResult(
            original_question=request.user_question,
            file_path=request.file_path,
            confidence_score=0.0,
            reasoning=f"LLM 响应解析失败:{str(e)}",
            metadata={
                "parsed_at": datetime.now().isoformat(),
                "error": str(e),
                "raw_response": response[:500]
            }
        )
|
||||
def parse_simple(self, question: str, file_path: str) -> CausalInferenceResult:
    """Convenience wrapper: parse from a bare question and file path.

    Builds a CausalInferenceRequest with default settings and delegates
    to :meth:`parse`.

    Args:
        question: the user's natural-language question.
        file_path: path to the data file.

    Returns:
        The parsing result.
    """
    return self.parse(CausalInferenceRequest(
        user_question=question,
        file_path=file_path
    ))
|
||||
|
||||
if __name__ == "__main__":
|
||||
"""
|
||||
因果推断解析器模块的简单测试示例
|
||||
"""
|
||||
from file_query_tool import FileQueryTool
|
||||
|
||||
print("=" * 60)
|
||||
print("CausalParser 简单测试示例")
|
||||
print("=" * 60)
|
||||
print()
|
||||
|
||||
# 创建解析器实例(不使用 LLM 客户端)
|
||||
print("步骤 1: 创建 CausalParser 实例...")
|
||||
parser = CausalParser()
|
||||
print(" - CausalParser 已初始化")
|
||||
print()
|
||||
|
||||
# 准备示例数据文件
|
||||
print("步骤 2: 准备示例数据文件...")
|
||||
sample_file = "data/output/test_simple_ate.csv"
|
||||
print(f" - 文件路径:{sample_file}")
|
||||
print()
|
||||
|
||||
# 创建解析请求
|
||||
print("步骤 3: 创建解析请求...")
|
||||
sample_question = "教育对收入的影响是什么?"
|
||||
request = CausalInferenceRequest(
|
||||
user_question=sample_question,
|
||||
file_path=sample_file
|
||||
)
|
||||
print(f" - 问题:{sample_question}")
|
||||
print()
|
||||
|
||||
# 执行解析
|
||||
print("步骤 4: 执行解析...")
|
||||
try:
|
||||
result = parser.parse(request)
|
||||
print(" - 解析成功!")
|
||||
print()
|
||||
|
||||
# 输出解析结果
|
||||
print("步骤 5: 解析结果:")
|
||||
print("-" * 60)
|
||||
print(f"原始问题:{result.original_question}")
|
||||
print(f"文件路径:{result.file_path}")
|
||||
print(f"整体置信度:{result.confidence_score:.2%}")
|
||||
print()
|
||||
|
||||
if result.treatment_variable:
|
||||
tv = result.treatment_variable
|
||||
print(f"处理变量 (T):")
|
||||
print(f" - 名称:{tv.name}")
|
||||
print(f" - 类型:{tv.variable_type}")
|
||||
print(f" - 置信度:{tv.confidence:.2%}")
|
||||
print(f" - 推理:{tv.reasoning}")
|
||||
print()
|
||||
|
||||
if result.outcome_variable:
|
||||
ov = result.outcome_variable
|
||||
print(f"结果变量 (Y):")
|
||||
print(f" - 名称:{ov.name}")
|
||||
print(f" - 类型:{ov.variable_type}")
|
||||
print(f" - 置信度:{ov.confidence:.2%}")
|
||||
print(f" - 推理:{ov.reasoning}")
|
||||
print()
|
||||
|
||||
if result.control_variables:
|
||||
print(f"协变量 (Control Variables):")
|
||||
for cv in result.control_variables:
|
||||
print(f" - {cv.name} (置信度:{cv.confidence:.2%})")
|
||||
print(f" 推理:{cv.reasoning}")
|
||||
print()
|
||||
|
||||
print(f"推理理由:{result.reasoning}")
|
||||
print()
|
||||
print("=" * 60)
|
||||
print("完整 JSON 输出:")
|
||||
print("=" * 60)
|
||||
print(result.to_json())
|
||||
|
||||
except Exception as e:
|
||||
print(f" - 解析失败:{str(e)}")
|
||||
print()
|
||||
print("请检查:")
|
||||
print(" 1. 数据文件路径是否正确")
|
||||
293
input_processor/file_query_tool.py
Normal file
293
input_processor/file_query_tool.py
Normal file
@ -0,0 +1,293 @@
|
||||
"""
|
||||
文件查询工具模块 (File Query Tool)
|
||||
|
||||
该模块提供文件查询功能,支持读取 CSV 和 Excel 文件的表头信息、
|
||||
统计信息和样本数据,而不需要读取完整数据。
|
||||
"""
|
||||
|
||||
import os
|
||||
import pandas as pd
|
||||
from dataclasses import dataclass, asdict
|
||||
from typing import List, Optional, Dict, Any
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
@dataclass
class ColumnInfo:
    """Per-column statistics gathered while inspecting a data file."""
    name: str             # column name
    dtype: str            # pandas dtype string
    non_null_count: int   # number of non-null values
    null_count: int       # number of null values
    unique_count: Optional[int] = None          # distinct-value count, if computed
    sample_values: Optional[List[Any]] = None   # a few example values, if collected

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain-dict view of this column."""
        return asdict(self)
|
||||
|
||||
@dataclass
class SampleData:
    """A small sample of rows read from a data file."""
    rows: List[Dict[str, Any]]  # sampled rows as column->value dicts
    row_count: int              # number of sampled rows
    columns: List[str]          # column names, in file order

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain-dict view of the sample."""
        return dict(
            rows=self.rows,
            row_count=self.row_count,
            columns=self.columns,
        )
|
||||
|
||||
@dataclass
class FileMetadata:
    """Everything the file-query tool learns about one data file."""
    file_path: str   # absolute or relative path to the file
    file_name: str   # base name of the file
    file_size: int   # size in bytes
    file_type: str   # logical type: "csv" or "excel"
    extension: str   # file extension (e.g. ".csv")
    total_rows: Optional[int] = None      # total row count, if known
    total_columns: Optional[int] = None   # total column count, if known
    columns: Optional[List[ColumnInfo]] = None  # per-column statistics
    sample_data: Optional[SampleData] = None    # sampled rows
    read_rows: int = 100  # how many rows were actually read

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain-dict representation.

        ``dataclasses.asdict`` already recurses into the nested
        ColumnInfo / SampleData dataclasses, so the manual per-field
        re-conversion the original code performed was redundant and has
        been removed. Output is unchanged.
        """
        return asdict(self)
|
||||
|
||||
class FileQueryTool:
    """
    File query tool.

    Reads header information, per-column statistics and sample data from
    CSV and Excel files. Only the first N rows are loaded, so large files
    are never read in full.

    Attributes:
        default_read_rows: Default number of rows to read (100).

    Example:
        >>> tool = FileQueryTool()
        >>> metadata = tool.query("data.csv")
        >>> print(metadata.columns)
        >>> print(metadata.sample_data)
    """

    def __init__(self, default_read_rows: int = 100):
        """
        Initialize the file query tool.

        Args:
            default_read_rows: Default number of rows to read when
                extracting headers and sample data.
        """
        self.default_read_rows = default_read_rows

    def query(self, file_path: str, read_rows: Optional[int] = None) -> FileMetadata:
        """
        Query file information.

        Reads the first N rows of the file and extracts header
        information, per-column statistics and sample data.

        Args:
            file_path: Path of the file to inspect.
            read_rows: Number of rows to read; when None, falls back to
                the instance default.

        Returns:
            FileMetadata: Metadata object describing the file.

        Raises:
            FileNotFoundError: The file does not exist.
            ValueError: Unsupported file format.
            Exception: Any other read error propagated from pandas.
        """
        if read_rows is None:
            read_rows = self.default_read_rows

        # Validate existence before handing the path to pandas.
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"文件不存在:{file_path}")

        # Collect filesystem-level facts about the file.
        file_path = os.path.abspath(file_path)
        file_name = os.path.basename(file_path)
        file_size = os.path.getsize(file_path)
        extension = os.path.splitext(file_name)[1].lower()

        # Dispatch on the extension to pick the reader.
        if extension == ".csv":
            file_type = "csv"
            df = self._read_csv(file_path, read_rows)
        elif extension in [".xlsx", ".xls"]:
            file_type = "excel"
            df = self._read_excel(file_path, read_rows)
        else:
            raise ValueError(f"不支持的文件格式:{extension},仅支持 CSV 和 Excel 文件")

        # Per-column metadata (dtype, null counts, samples).
        columns = self._extract_column_info(df)

        # Preview rows for display purposes.
        sample_data = self._extract_sample_data(df, read_rows)

        # NOTE(review): df holds at most read_rows rows, so total_rows
        # below reflects the rows read, not the full file's row count.
        metadata = FileMetadata(
            file_path=file_path,
            file_name=file_name,
            file_size=file_size,
            file_type=file_type,
            extension=extension,
            total_rows=len(df),
            total_columns=len(df.columns),
            columns=columns,
            sample_data=sample_data,
            read_rows=min(len(df), read_rows)
        )

        return metadata

    def _read_csv(self, file_path: str, read_rows: int) -> pd.DataFrame:
        """
        Read a CSV file.

        Args:
            file_path: Path of the CSV file.
            read_rows: Number of rows to read.

        Returns:
            pd.DataFrame: The loaded dataframe (at most read_rows rows).
        """
        df = pd.read_csv(file_path, nrows=read_rows)
        return df

    def _read_excel(self, file_path: str, read_rows: int) -> pd.DataFrame:
        """
        Read an Excel file.

        Args:
            file_path: Path of the Excel file.
            read_rows: Number of rows to read.

        Returns:
            pd.DataFrame: The loaded dataframe (at most read_rows rows).
        """
        df = pd.read_excel(file_path, nrows=read_rows)
        return df

    def _extract_column_info(self, df: pd.DataFrame) -> List[ColumnInfo]:
        """
        Extract per-column metadata from the sampled dataframe.

        Args:
            df: The sampled dataframe.

        Returns:
            List[ColumnInfo]: One record per column.
        """
        columns = []

        for col in df.columns:
            # Dtype as a display string.
            dtype = str(df[col].dtype)

            # Null / non-null counts over the sampled rows.
            non_null_count = df[col].notna().sum()
            null_count = df[col].isna().sum()

            # Distinct values, ignoring nulls.
            unique_count = df[col].dropna().nunique()

            # Up to 5 example non-null values.
            sample_values = df[col].dropna().head(5).tolist()

            column_info = ColumnInfo(
                name=col,
                dtype=dtype,
                # Cast numpy integers to plain int for clean serialization.
                non_null_count=int(non_null_count),
                null_count=int(null_count),
                unique_count=int(unique_count),
                sample_values=sample_values
            )
            columns.append(column_info)

        return columns

    def _extract_sample_data(self, df: pd.DataFrame, read_rows: int) -> SampleData:
        """
        Extract preview rows from the sampled dataframe.

        Args:
            df: The sampled dataframe.
            read_rows: Upper bound on the number of preview rows.

        Returns:
            SampleData: Preview rows plus the column list.
        """
        # The frame may contain fewer rows than requested.
        actual_rows = min(len(df), read_rows)

        # One dict per row, keyed by column name.
        rows = df.head(actual_rows).to_dict(orient='records')

        return SampleData(
            rows=rows,
            row_count=actual_rows,
            columns=df.columns.tolist()
        )

    def get_column_names(self, file_path: str, read_rows: Optional[int] = None) -> List[str]:
        """
        Convenience helper: return just the column names.

        Args:
            file_path: Path of the file.
            read_rows: Number of rows to read.

        Returns:
            List[str]: Column names in file order.
        """
        metadata = self.query(file_path, read_rows)
        return [col.name for col in metadata.columns]

    def get_column_dtype(self, file_path: str, column_name: str,
                         read_rows: Optional[int] = None) -> str:
        """
        Return the dtype string of one named column.

        Args:
            file_path: Path of the file.
            column_name: Name of the column to look up.
            read_rows: Number of rows to read.

        Returns:
            str: The column's dtype as a string.

        Raises:
            ValueError: The named column does not exist.
        """
        metadata = self.query(file_path, read_rows)
        for col in metadata.columns:
            if col.name == column_name:
                return col.dtype
        raise ValueError(f"列 '{column_name}' 不存在")

    def get_file_info(self, file_path: str, read_rows: Optional[int] = None) -> Dict[str, Any]:
        """
        Return the full file metadata as a plain dictionary.

        Args:
            file_path: Path of the file.
            read_rows: Number of rows to read.

        Returns:
            Dict[str, Any]: The FileMetadata serialized via to_dict().
        """
        metadata = self.query(file_path, read_rows)
        return metadata.to_dict()
|
||||
217
input_processor/test_input_processor.py
Normal file
217
input_processor/test_input_processor.py
Normal file
@ -0,0 +1,217 @@
|
||||
"""
|
||||
输入处理模块测试脚本
|
||||
|
||||
该脚本测试文件查询工具和因果推断解析器的功能。
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
# 添加项目根目录到 Python 路径
|
||||
project_root = Path(__file__).parent.parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
# 导入模块
|
||||
from input_processor import FileQueryTool, CausalParser, CausalInferenceRequest
|
||||
|
||||
|
||||
def test_file_query_tool():
    """Exercise FileQueryTool end-to-end against the bundled sample files."""
    print("=" * 60)
    print("测试文件查询工具")
    print("=" * 60)

    # Sample data files shipped with the project (CSV + Excel variants).
    test_files = [
        "data/output/simple_ate_data.csv",
        "data/output/simple_ate_data.xlsx"
    ]

    tool = FileQueryTool(default_read_rows=100)

    for file_path in test_files:
        # Missing fixtures are skipped rather than treated as failures.
        if not os.path.exists(file_path):
            print(f"\n文件不存在,跳过:{file_path}")
            continue

        print(f"\n测试文件:{file_path}")
        print("-" * 40)

        try:
            # Basic query: headers, stats and samples in one call.
            metadata = tool.query(file_path)

            print(f"文件名:{metadata.file_name}")
            print(f"文件类型:{metadata.file_type}")
            print(f"文件大小:{metadata.file_size} 字节")
            print(f"总行数:{metadata.total_rows}")
            print(f"总列数:{metadata.total_columns}")
            print(f"实际读取行数:{metadata.read_rows}")

            # Dump per-column metadata.
            print("\n列信息:")
            for col in metadata.columns:
                print(f"  - {col.name}: {col.dtype}")
                print(f"    非空值:{col.non_null_count}, 空值:{col.null_count}")
                print(f"    唯一值:{col.unique_count}")
                if col.sample_values:
                    print(f"    样本值:{col.sample_values[:3]}")

            # Show at most three preview rows.
            print("\n样本数据:")
            if metadata.sample_data:
                for i, row in enumerate(metadata.sample_data.rows[:3]):
                    print(f"  行 {i+1}: {row}")

            # Column-name convenience helper.
            print("\n列名列表:")
            col_names = tool.get_column_names(file_path)
            print(f"  {col_names}")

            # Dtype helper, probed with the first column.
            if col_names:
                first_col = col_names[0]
                dtype = tool.get_column_dtype(file_path, first_col)
                print(f"\n列 '{first_col}' 的数据类型:{dtype}")

            # Dict-shaped metadata helper, dumped as JSON.
            print("\n文件信息字典:")
            file_info = tool.get_file_info(file_path)
            print(json.dumps(file_info, ensure_ascii=False, indent=2))

        except Exception as e:
            # Smoke-test style: report and continue with the next file.
            print(f"错误:{e}")

    print("\n" + "=" * 60)
    print("文件查询工具测试完成")
    print("=" * 60)
|
||||
|
||||
|
||||
def test_causal_parser():
    """Run CausalParser.parse_simple over a batch of sample questions."""
    print("\n" + "=" * 60)
    print("测试因果推断解析器")
    print("=" * 60)

    # Sample data file shipped with the project.
    test_file = "data/output/simple_ate_data.csv"

    if not os.path.exists(test_file):
        print(f"文件不存在,跳过测试:{test_file}")
        return

    parser = CausalParser()

    # Questions mixing explicit column names and vaguer phrasings.
    test_questions = [
        "教育对收入的影响是什么?",
        "处理变量对结果变量的影响如何?",
        "分析 treatment 对 outcome 的因果效应",
        "这个数据集中 treatment 和 outcome 有什么关系?",
        "如何评估干预措施的效果?"
    ]

    for question in test_questions:
        print(f"\n用户问题:{question}")
        print("-" * 40)

        try:
            # Simplified parse entry point: question text + file path.
            result = parser.parse_simple(question, test_file)

            # Raw JSON form of the result.
            print(result.to_json())

            # Human-readable summary of the parsed variables.
            print("\n结构化结果:")
            if result.treatment_variable:
                print(f"  处理变量 (T): {result.treatment_variable.name}")
                print(f"    置信度:{result.treatment_variable.confidence:.2f}")
                print(f"    推理:{result.treatment_variable.reasoning}")

            if result.outcome_variable:
                print(f"  结果变量 (Y): {result.outcome_variable.name}")
                print(f"    置信度:{result.outcome_variable.confidence:.2f}")
                print(f"    推理:{result.outcome_variable.reasoning}")

            if result.control_variables:
                # Show at most three covariates; summarize the rest.
                print(f"  协变量 ({len(result.control_variables)} 个):")
                for var in result.control_variables[:3]:
                    print(f"    - {var.name} (置信度:{var.confidence:.2f})")
                if len(result.control_variables) > 3:
                    print(f"    ... 还有 {len(result.control_variables) - 3} 个")

            print(f"\n  整体置信度:{result.confidence_score:.2f}")
            print(f"  推理理由:{result.reasoning}")

        except Exception as e:
            # Smoke-test style: report and move on to the next question.
            print(f"错误:{e}")

    print("\n" + "=" * 60)
    print("因果推断解析器测试完成")
    print("=" * 60)
|
||||
|
||||
|
||||
def test_causal_parser_with_request():
    """Drive CausalParser.parse via a fully populated CausalInferenceRequest."""
    print("\n" + "=" * 60)
    print("测试使用 CausalInferenceRequest 对象")
    print("=" * 60)

    test_file = "data/output/simple_ate_data.csv"

    if not os.path.exists(test_file):
        print(f"文件不存在,跳过测试:{test_file}")
        return

    parser = CausalParser()

    # Build the request object.
    # NOTE(review): llm_provider="mock" presumably avoids a real LLM
    # API call — confirm against CausalParser's provider handling.
    request = CausalInferenceRequest(
        user_question="分析 treatment 变量对 outcome 变量的影响",
        file_path=test_file,
        read_rows=50,
        llm_provider="mock",
        llm_model="mock-model"
    )

    print(f"请求信息:")
    print(f"  问题:{request.user_question}")
    print(f"  文件:{request.file_path}")
    print(f"  读取行数:{request.read_rows}")
    print(f"  LLM 提供商:{request.llm_provider}")
    print(f"  LLM 模型:{request.llm_model}")

    # Run the full (non-simplified) parse.
    result = parser.parse(request)

    print("\n解析结果:")
    print(result.to_json())

    print("\n" + "=" * 60)
    print("请求对象测试完成")
    print("=" * 60)
|
||||
|
||||
|
||||
def main():
    """Run every input-processor test in sequence."""

    def _banner(title: str) -> None:
        # Print a framed section header identical to the other tests'.
        print("\n" + "=" * 60)
        print(title)
        print("=" * 60)

    _banner("输入处理模块测试")

    # File query tool.
    test_file_query_tool()

    # Causal inference parser.
    test_causal_parser()

    # Parser driven by a request object.
    test_causal_parser_with_request()

    _banner("所有测试完成!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
192
pyproject.toml
Normal file
192
pyproject.toml
Normal file
@ -0,0 +1,192 @@
|
||||
[project]
|
||||
# 项目名称
|
||||
name = "causal-inference-agent"
|
||||
# 项目版本(遵循语义化版本控制)
|
||||
version = "0.1.0"
|
||||
# 项目描述
|
||||
description = "Causal Inference Agent System"
|
||||
# 项目作者
|
||||
authors = [
|
||||
{ name = "Causal Inference Team", email = "team@example.com" }
|
||||
]
|
||||
# 项目许可证(可选)
|
||||
license = "MIT"
|
||||
# 项目 readme 文件
|
||||
readme = "README.md"
|
||||
# 项目关键词
|
||||
keywords = ["causal-inference", "agent", "data-science", "statistics"]
|
||||
# 项目分类
|
||||
classifiers = [
|
||||
"Development Status :: 3 - Alpha",
|
||||
"Intended Audience :: Developers",
|
||||
"Intended Audience :: Science/Research",
|
||||
"Operating System :: OS Independent",
|
||||
"Programming Language :: Python :: 3",
|
||||
"Programming Language :: Python :: 3.11",
|
||||
"Programming Language :: Python :: 3.12",
|
||||
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
||||
"Topic :: Scientific/Engineering :: Mathematics",
|
||||
]
|
||||
|
||||
# Python 版本要求
|
||||
# 最低支持 Python 3.11
|
||||
requires-python = ">=3.11"
|
||||
|
||||
# 项目运行时依赖
|
||||
dependencies = [
|
||||
"pandas>=2.0.0", # 数据处理和分析
|
||||
"openpyxl>=3.1.0", # Excel 文件读写支持
|
||||
"numpy>=1.24.0", # 数值计算基础库
|
||||
"scipy>=1.10.0", # 科学计算和统计功能
|
||||
"pydantic>=2.0.0", # 数据验证和设置管理(用于 JSON 输出)
|
||||
"requests>=2.31.0", # HTTP 请求库(用于 LLM API 调用)
|
||||
"networkx>=3.0", # 图论与因果图构建
|
||||
"scikit-learn>=1.3.0", # 机器学习与因果效应估计
|
||||
]
|
||||
|
||||
# 可选依赖(额外功能)
|
||||
[project.optional-dependencies]
|
||||
# 开发依赖
|
||||
dev = [
|
||||
"pytest>=7.4.0", # 测试框架
|
||||
"pytest-cov>=4.1.0", # 代码覆盖率
|
||||
"black>=23.0.0", # 代码格式化
|
||||
"ruff>=0.1.0", # 代码 linting
|
||||
"mypy>=1.5.0", # 类型检查
|
||||
"pre-commit>=3.4.0", # Git 预提交钩子
|
||||
]
|
||||
|
||||
# 文档依赖
|
||||
docs = [
|
||||
"sphinx>=7.0.0", # 文档生成工具
|
||||
"sphinx-rtd-theme>=1.3.0", # Sphinx 主题
|
||||
"myst-parser>=2.0.0", # Markdown 解析器
|
||||
]
|
||||
|
||||
# 测试依赖
|
||||
test = [
|
||||
"pytest>=7.4.0",
|
||||
"pytest-cov>=4.1.0",
|
||||
"pytest-xdist>=3.3.0", # 并行测试
|
||||
]
|
||||
|
||||
# 完整依赖(包含所有可选依赖)
|
||||
all = [
|
||||
"causal-inference-agent[dev,docs,test]",
|
||||
]
|
||||
|
||||
# 包配置
|
||||
[tool.setuptools.packages.find]
|
||||
include = ["causal_agent*", "input_processor*", "data*"]
|
||||
|
||||
[project.scripts]
|
||||
# 命令行入口点
|
||||
causal-agent = "causal_agent.cli:main"
|
||||
|
||||
[project.urls]
|
||||
# 项目相关链接
|
||||
Homepage = "https://github.com/example/causal-inference-agent"
|
||||
Repository = "https://github.com/example/causal-inference-agent"
|
||||
"Bug Tracker" = "https://github.com/example/causal-inference-agent/issues"
|
||||
"Documentation" = "https://causal-inference-agent.readthedocs.io"
|
||||
"Changelog" = "https://github.com/example/causal-inference-agent/blob/main/CHANGELOG.md"
|
||||
|
||||
# ==================== uv 配置 ====================
|
||||
|
||||
# uv 索引配置
|
||||
[tool.uv]
|
||||
# 使用 PyPI 作为默认索引
|
||||
index-url = "https://pypi.org/simple"
|
||||
# 额外索引(可选)
|
||||
extra-index-url = []
|
||||
# 依赖解析策略
|
||||
resolution = "lowest-direct"
|
||||
|
||||
# ==================== 开发工具配置 ====================
|
||||
|
||||
# Black 代码格式化配置
|
||||
[tool.black]
|
||||
line-length = 88
|
||||
target-version = ['py311', 'py312']
|
||||
include = '\.pyi?$'
|
||||
exclude = '''
|
||||
/(
|
||||
\.git
|
||||
| \.hg
|
||||
| \.mypy_cache
|
||||
| \.tox
|
||||
| \.venv
|
||||
| _build
|
||||
| buck-out
|
||||
| build
|
||||
| dist
|
||||
)/
|
||||
'''
|
||||
|
||||
# Ruff 代码 linting 配置
|
||||
[tool.ruff]
|
||||
line-length = 88
|
||||
target-version = "py39"
|
||||
# 启用的 lint 规则
|
||||
select = [
|
||||
"E", # pycodestyle errors
|
||||
"W", # pycodestyle warnings
|
||||
"F", # Pyflakes
|
||||
"I", # isort
|
||||
"B", # flake8-bugbear
|
||||
"C4", # flake8-comprehensions
|
||||
"UP", # pyupgrade
|
||||
]
|
||||
ignore = [
|
||||
"E501", # line too long (handled by black)
|
||||
"B008", # do not perform function calls in argument defaults
|
||||
]
|
||||
|
||||
[tool.ruff.per-file-ignores]
|
||||
"__init__.py" = ["F401"] # unused imports in __init__.py
|
||||
|
||||
# MyPy 类型检查配置
|
||||
[tool.mypy]
|
||||
python_version = "3.11"
|
||||
warn_return_any = true
|
||||
warn_unused_configs = true
|
||||
disallow_untyped_defs = true
|
||||
ignore_missing_imports = true
|
||||
|
||||
# Pytest 测试配置
|
||||
[tool.pytest.ini_options]
|
||||
testpaths = ["tests", "input_processor"]
|
||||
python_files = ["test_*.py"]
|
||||
python_functions = ["test_*"]
|
||||
addopts = [
|
||||
"-v",
|
||||
"--strict-markers",
|
||||
"--strict-config",
|
||||
]
|
||||
markers = [
|
||||
"slow: marks tests as slow",
|
||||
"integration: marks tests as integration tests",
|
||||
]
|
||||
|
||||
# 代码覆盖率配置
|
||||
[tool.coverage.run]
|
||||
source = ["."]
|
||||
omit = [
|
||||
"*/tests/*",
|
||||
"*/__init__.py",
|
||||
"*/conftest.py",
|
||||
]
|
||||
|
||||
[tool.coverage.report]
|
||||
exclude_lines = [
|
||||
"pragma: no cover",
|
||||
"def __repr__",
|
||||
"raise AssertionError",
|
||||
"raise NotImplementedError",
|
||||
"if __name__ == .__main__.:",
|
||||
]
|
||||
|
||||
# 预提交钩子配置
|
||||
[tool.pre-commit]
|
||||
# 默认钩子
|
||||
default_stages = ["commit", "push", "manual"]
|
||||
0
tests/test_causal_agent/__init__.py
Normal file
0
tests/test_causal_agent/__init__.py
Normal file
49
tests/test_causal_agent/test_causal_graph.py
Normal file
49
tests/test_causal_agent/test_causal_graph.py
Normal file
@ -0,0 +1,49 @@
|
||||
"""Tests for causal_graph module."""
|
||||
|
||||
import networkx as nx
|
||||
import pandas as pd
|
||||
import pytest
|
||||
|
||||
from causal_agent.analysis.causal_graph import (
|
||||
build_local_graph,
|
||||
find_adjustment_set,
|
||||
find_backdoor_paths,
|
||||
)
|
||||
|
||||
|
||||
def test_build_local_graph():
    """A screened confounder should produce edges T->Y, Z->T and Z->Y."""
    frame = pd.DataFrame({
        "treatment": [0, 1, 0, 1],
        "health": [1, 2, 1, 3],
        "age": [20, 30, 40, 50],
    })
    screened = [{"var": "age"}]
    tier_map = {"treatment": 2, "health": 4, "age": 0}
    graph = build_local_graph(frame, "treatment", "health", screened, tier_map, corr_threshold=0.1)

    assert isinstance(graph, nx.DiGraph)
    for src, dst in [("treatment", "health"), ("age", "treatment"), ("age", "health")]:
        assert graph.has_edge(src, dst)
|
||||
|
||||
|
||||
def test_find_backdoor_paths():
    """The classic confounding triangle yields exactly one backdoor path."""
    triangle = nx.DiGraph()
    triangle.add_edges_from([
        ("age", "treatment"),
        ("age", "health"),
        ("treatment", "health"),
    ])

    found = find_backdoor_paths(triangle, "treatment", "health")
    assert len(found) == 1
    assert "treatment <- age -> health" in found
|
||||
|
||||
|
||||
def test_find_adjustment_set():
    """Adjusting for the single confounder closes the only backdoor path."""
    triangle = nx.DiGraph()
    triangle.add_edges_from([
        ("age", "treatment"),
        ("age", "health"),
        ("treatment", "health"),
    ])

    adjustment_set, reasoning = find_adjustment_set(triangle, "treatment", "health")
    assert "age" in adjustment_set
    # The explanation text should mention the backdoor criterion.
    assert "后门" in reasoning
|
||||
39
tests/test_causal_agent/test_estimation.py
Normal file
39
tests/test_causal_agent/test_estimation.py
Normal file
@ -0,0 +1,39 @@
|
||||
"""Tests for estimation module."""
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import pytest
|
||||
|
||||
from causal_agent.analysis.estimation import estimate_ate
|
||||
|
||||
|
||||
def test_estimate_ate_basic():
    """Both estimators should recover the true effect (2.0) given Z."""
    # Keep the RNG call order identical so the synthetic data is reproducible.
    np.random.seed(42)
    n = 300
    confounder = np.random.normal(0, 1, n)
    treated = (confounder + np.random.normal(0, 0.5, n) > 0).astype(int)
    outcome = 2 * treated + 3 * confounder + np.random.normal(0, 0.5, n)

    data = pd.DataFrame({"T": treated, "Y": outcome, "Z": confounder})
    result = estimate_ate(data, "T", "Y", ["Z"], compute_ci=False)

    for key in ("ATE_Outcome_Regression", "ATE_IPW"):
        assert key in result
    assert abs(result["ATE_Outcome_Regression"] - 2.0) < 0.5
    # The IPW estimate has higher variance, so a looser bound is used.
    assert abs(result["ATE_IPW"] - 2.0) < 1.5
    assert "balance_check" in result
    assert "Z" in result["balance_check"]
|
||||
|
||||
|
||||
def test_estimate_ate_empty_adjustment():
    """With no adjustment set, both estimators still recover the effect (1.5)."""
    np.random.seed(42)
    n = 200
    treated = np.random.binomial(1, 0.5, n)
    outcome = 1.5 * treated + np.random.normal(0, 0.5, n)

    data = pd.DataFrame({"T": treated, "Y": outcome})
    estimates = estimate_ate(data, "T", "Y", [], compute_ci=False)

    for key in ("ATE_Outcome_Regression", "ATE_IPW"):
        assert abs(estimates[key] - 1.5) < 0.5
|
||||
31
tests/test_causal_agent/test_screening.py
Normal file
31
tests/test_causal_agent/test_screening.py
Normal file
@ -0,0 +1,31 @@
|
||||
"""Tests for screening module."""
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import pytest
|
||||
|
||||
from causal_agent.analysis.screening import local_screen
|
||||
|
||||
|
||||
def test_local_screen_finds_confounders():
    """Screening keeps the true confounder Z and drops pure noise."""
    # Keep the RNG call order identical so the synthetic data is reproducible.
    np.random.seed(42)
    n = 200
    confounder = np.random.normal(0, 1, n)
    treated = (confounder + np.random.normal(0, 0.5, n) > 0).astype(int)
    outcome = 2 * treated + 3 * confounder + np.random.normal(0, 0.5, n)
    irrelevant = np.random.normal(0, 1, n)

    frame = pd.DataFrame({"T": treated, "Y": outcome, "Z": confounder, "noise": irrelevant})
    kept = local_screen(frame, T="T", Y="Y", excluded=[], corr_threshold=0.1, alpha=0.05)

    names = [entry["var"] for entry in kept]
    assert "Z" in names
    assert "noise" not in names
|
||||
|
||||
|
||||
def test_local_screen_excludes_specified_columns():
    """Excluded columns and T/Y themselves never appear as candidates."""
    frame = pd.DataFrame({"T": [0, 1], "Y": [1, 2], "X": [0, 1], "id": [1, 2]})
    kept = local_screen(frame, T="T", Y="Y", excluded=["id"])
    for banned in ("id", "T", "Y"):
        assert all(entry["var"] != banned for entry in kept)
|
||||
Loading…
x
Reference in New Issue
Block a user