This commit is contained in:
Cloyir 2026-03-29 23:47:20 +08:00
commit 16da68c038
38 changed files with 7796 additions and 0 deletions

19
.env.example Normal file

@ -0,0 +1,19 @@
# Causal Inference Agent - example LLM configuration
# Copy this file to .env and adjust the values for your own setup
# LLM API settings (required)
LLM_BASE_URL=https://xxxx/v1
LLM_MODEL=qwen3.5-35b
# Optional settings
LLM_TEMPERATURE=0.3
LLM_MAX_TOKENS=2048
# LLM_API_KEY=your-api-key-here
# Statistical screening settings (optional)
# CORR_THRESHOLD=0.1
# ALPHA=0.05
# BOOTSTRAP_ITERATIONS=500
# Logging settings (optional)
# LOG_PATH=causal_analysis.log.md

159
.gitignore vendored Normal file

@ -0,0 +1,159 @@
causal_analysis.log.md
report.json
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
Pipfile.lock
# PEP 582
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv/
env/
env.bak/
venv/
ENV/
env.sh
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# IDE
.idea/
.vscode/
*.swp
*.swo
*~
# OS
.DS_Store
Thumbs.db
# Logs
*.log
logs/
# Data (optional - adjust based on your needs)
data/output/
# Keep data.xlsx in the repo root
*.xlsx
!/data.xlsx
*.csv
# Temporary files
*.tmp
*.temp
*.bak

297
README.md Normal file

@ -0,0 +1,297 @@
# Causal Inference Agent
A domain-agnostic, general-purpose causal inference agent that provides a complete pipeline covering data loading, variable identification, causal graph construction, confounder identification, and effect estimation.
## Features
- **LLM-based variable identification**: automatically identifies the treatment variable (T) and the outcome variable (Y), and assigns a temporal tier to every variable
- **Fast correlation screening**: uses Pearson + Spearman correlations plus mutual information to screen for potential confounders
- **Hybrid causal graph**: builds the causal graph from temporal tiers combined with statistical correlations
- **Automatic backdoor-criterion analysis**: discovers backdoor paths and derives a minimal adjustment set
- **Dual effect estimation**: supports Outcome Regression (OR) and Inverse Probability Weighting (IPW)
- **Standardized JSON report**: includes the causal graph, identification strategy, estimation results, and diagnostics
## Quick Start
### Requirements
- Python >= 3.11
- uv >= 0.4.0
### Install uv
```bash
# Windows (PowerShell)
irm https://astral.sh/uv/install.ps1 | iex
# macOS/Linux
curl -LsSf https://astral.sh/uv/install.sh | sh
```
### Project Setup
```bash
# Enter the project directory
cd CausalInferenceAgent_V2
# Create a virtual environment and install dependencies
uv venv
uv pip install -e .
```
## Usage
### 1. Command Line Interface (CLI)
```bash
# Basic usage
uv run python -m causal_agent --data data.xlsx --output report.json --prompt "Analyze the effect of taking the medication"
# Custom LLM configuration
uv run python -m causal_agent --data data.xlsx --model qwen3.5-35b --base-url http://... --output report.json
# Show help
uv run python -m causal_agent --help
```
### 2. Python API
```python
from causal_agent.agent import CausalInferenceAgent
from causal_agent.core.config import AgentConfig
# Use the default configuration
agent = CausalInferenceAgent()
result = agent.analyze("data.xlsx")
# Inspect the report
print(result["report"])
```
### 3. Run the Example
```bash
uv run python -m causal_agent --data data.xlsx --output report.json --prompt "Analyze the effect of taking the medication"
```
## Project Structure
```
CausalInferenceAgent_V2/
├── pyproject.toml              # Project configuration
├── README.md                   # This file
├── uv.lock                     # uv lockfile
├── causal_agent/               # Main package
│   ├── __init__.py
│   ├── agent.py                # CausalInferenceAgent
│   ├── cli.py                  # CLI entry point
│   ├── logger.py               # Analysis logger
│   ├── core/                   # Core modules
│   │   ├── config.py           # AgentConfig settings
│   │   ├── llm_client.py       # LLM client
│   │   └── data_loader.py      # Data loader
│   └── analysis/               # Analysis modules
│       ├── variable_parser.py  # LLM variable identification
│       ├── screening.py        # Correlation screening
│       ├── causal_graph.py     # Causal graph construction
│       ├── estimation.py       # Effect estimation
│       └── reporting.py        # Report generation
├── input_processor/            # Input processing (pre-existing)
├── examples/
│   └── medical_v2/             # Medical data example
├── tests/
│   └── test_causal_agent/      # Unit tests
└── data/                       # Test data
```
## Pipeline Overview
The agent executes the following 6 steps:
### Step 1: LLM Variable Identification
Calls the LLM to identify:
- `treatment`: the treatment variable
- `outcome`: the outcome variable
- `time_tiers`: a temporal tier for every variable (-1 to 5+)
**time_tiers levels:**
- `-1`: non-temporal variables (e.g. id, index)
- `0`: demographic attributes (e.g. age, gender)
- `1`: baseline measurements (e.g. baseline_score)
- `2`: intervention point / treatment variable
- `3`: mediator variables
- `4`: follow-up results / outcome variable
- `5+`: later time points
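For illustration, a hypothetical response for a dataset with columns `id`, `age`, `treatment`, and `health` (the bundled medical example) might look like:
```json
{
  "treatment": "treatment",
  "outcome": "health",
  "time_tiers": {"id": -1, "age": 0, "treatment": 2, "health": 4}
}
```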
### Step 2: Fast Correlation Screening
Uses Pearson + Spearman correlation coefficients plus mutual information (MI) to:
- screen for variables significantly correlated with both T and Y (|r| > 0.1, p < 0.05)
- keep them as candidate confounders (see the sketch below)
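A minimal sketch of the screening rule, assuming a pandas DataFrame `df` and hypothetical column names `x`, `t`, `y` (the full version, with Pearson and MI statistics, lives in `causal_agent/analysis/screening.py`):
```python
from scipy.stats import spearmanr

CORR_THRESHOLD, ALPHA = 0.1, 0.05

def is_candidate_confounder(df, x, t, y):
    # A column qualifies when it is significantly (Spearman-)correlated
    # with both the treatment and the outcome.
    r_t, p_t = spearmanr(df[x], df[t])
    r_y, p_y = spearmanr(df[x], df[y])
    return (abs(r_t) > CORR_THRESHOLD and p_t < ALPHA
            and abs(r_y) > CORR_THRESHOLD and p_y < ALPHA)
```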
### Step 3: Causal Graph Construction
Builds a hybrid causal graph based on temporal tiers:
- variables in lower (earlier) tiers point to variables in higher (later) tiers
- edges whose statistical correlation falls below the threshold are filtered out
- edge types: `temporal`, `confounding`, `hypothesized` (the tier rule is sketched below)
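A minimal sketch of the tier-based edge rule, assuming `tiers` maps each column to its integer tier from Step 1 (the real implementation in `causal_agent/analysis/causal_graph.py` additionally filters edges by Spearman correlation and tags edge types):
```python
import networkx as nx

def build_tier_graph(variables, tiers):
    # Variables in earlier tiers point to variables in later tiers.
    G = nx.DiGraph()
    G.add_nodes_from(variables)
    for u in variables:
        for v in variables:
            if u != v and tiers[u] < tiers[v]:
                G.add_edge(u, v)
    return G
```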
### Step 4: Confounder Identification (Backdoor Criterion)
- discovers all backdoor paths starting with `T <-`
- automatically derives a minimal adjustment set that blocks all backdoor paths
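For example, in the sample report below, `treatment <- base_health -> health` is a backdoor path: it starts with an arrow into `treatment`, so conditioning on `base_health` blocks it and puts `base_health` into the adjustment set.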
### Step 5: Effect Estimation
Two estimation methods are provided (both sketched below):
**Outcome Regression (OR)**
- linear regression + G-computation
- directly estimates E[Y|do(T=1)] - E[Y|do(T=0)]
**Inverse Probability Weighting (IPW)**
- logistic regression to estimate propensity scores
- truncated, stabilized weights to avoid extreme values
Also provided:
- bootstrap 95% confidence intervals
- standardized mean difference (SMD) balance checks
- overlap (positivity) assumption checks
- an OR/IPW robustness comparison
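A minimal sketch of both estimators, assuming a binary treatment vector `t`, outcome `y`, and covariate matrix `X` (hypothetical names; the full versions with bootstrap CIs and diagnostics are in `causal_agent/analysis/estimation.py`):
```python
import numpy as np
from sklearn.linear_model import LinearRegression, LogisticRegression

def ate_outcome_regression(X, t, y):
    # G-computation: fit Y ~ (X, T), then contrast predictions at T=1 vs T=0.
    model = LinearRegression().fit(np.column_stack([X, t]), y)
    y1 = model.predict(np.column_stack([X, np.ones_like(t)]))
    y0 = model.predict(np.column_stack([X, np.zeros_like(t)]))
    return float((y1 - y0).mean())

def ate_ipw(X, t, y):
    # Horvitz-Thompson IPW with propensity scores clipped to [0.05, 0.95].
    ps = LogisticRegression(max_iter=1000).fit(X, t).predict_proba(X)[:, 1]
    ps = np.clip(ps, 0.05, 0.95)
    return float(np.mean(t * y / ps) - np.mean((1 - t) * y / (1 - ps)))
```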
### Step 6: Report Generation
The standardized JSON report contains:
```json
{
"query_interpretation": {
"treatment": "treatment",
"outcome": "health",
"estimand": "ATE"
},
"causal_graph": {
"nodes": ["age", "base_health", "treatment", "health"],
"edges": [...],
"backdoor_paths": ["treatment <- base_health -> health"]
},
"identification": {
"strategy": "Backdoor Adjustment",
"adjustment_set": ["age", "base_health"],
"reasoning": "..."
},
"estimation": {
"ATE_Outcome_Regression": 0.2444,
"ATE_IPW": 0.2425,
"95%_CI": [0.2351, 0.2544],
"interpretation": "..."
},
"diagnostics": {
"balance_check": {...},
"overlap_assumption": "满足",
"robustness": "稳健"
},
"warnings": [...]
}
```
## Configuration
### Method 1: Use a `.env` file (recommended)
The project ships with a `.env.example` file; copy it to `.env` and edit:
```bash
cp .env.example .env
```
Edit `.env` to change the LLM URL:
```bash
# Changing this one line is enough to switch LLM servers
LLM_BASE_URL=https://your-new-llm-server.com/v1
LLM_MODEL=your-model-name
```
**Advantages:**
- switching URLs only requires editing one file
- `.env` is never committed to git (it is listed in .gitignore)
- all code picks it up automatically, no restart needed
### Method 2: Environment Variables
```bash
# Windows PowerShell
$env:LLM_BASE_URL="https://your-new-llm-server.com/v1"
$env:LLM_MODEL="qwen3.5-35b"
# macOS/Linux
export LLM_BASE_URL="https://your-new-llm-server.com/v1"
export LLM_MODEL="qwen3.5-35b"
export LLM_TEMPERATURE="0.3"
export LLM_MAX_TOKENS="2048"
export LLM_API_KEY="your-api-key"
# Statistical screening settings
export CORR_THRESHOLD="0.1"
export ALPHA="0.05"
export BOOTSTRAP_ITERATIONS="500"
# Path settings
export LOG_PATH="causal_analysis.log.md"
```
### Method 3: Configure in Python Code
```python
from causal_agent.core.config import AgentConfig
config = AgentConfig(
llm_model="gpt-4",
llm_temperature=0.5,
corr_threshold=0.15,
log_path="my_analysis.log"
)
agent = CausalInferenceAgent(config)
```
## Development Tools
```bash
# Format code
uv run black .
# Lint
uv run ruff check .
# Type-check
uv run mypy .
# Run the tests
uv run pytest tests/test_causal_agent/ -v
# Run the tests with a coverage report
uv run pytest tests/ --cov=. --cov-report=html
```
## License
MIT License
## Contributing
Contributions are welcome! Please follow these steps:
1. Fork the repository
2. Create a feature branch (`git checkout -b feature/AmazingFeature`)
3. Commit your changes (`git commit -m 'Add some AmazingFeature'`)
4. Push to the branch (`git push origin feature/AmazingFeature`)
5. Open a Pull Request

12
causal_agent/__init__.py Normal file

@ -0,0 +1,12 @@
"""
Causal Inference Agent
一个领域无关的通用因果推断 Agent提供从数据加载变量识别
因果图构建混杂识别到效应估计的完整 pipeline
"""
from .agent import CausalInferenceAgent
from .core.config import AgentConfig
__version__ = "0.2.0"
__all__ = ["CausalInferenceAgent", "AgentConfig"]

6
causal_agent/__main__.py Normal file

@ -0,0 +1,6 @@
"""Allow running causal_agent as a module: python -m causal_agent"""
from causal_agent.cli import main
if __name__ == "__main__":
main()

186
causal_agent/agent.py Normal file

@ -0,0 +1,186 @@
"""
Causal Inference Agent Orchestrator
整合完整 pipeline 的核心代理类
"""
from typing import Any, Dict, Optional
import pandas as pd
from causal_agent.analysis.causal_graph import (
build_local_graph,
find_adjustment_set,
find_backdoor_paths,
)
from causal_agent.analysis.estimation import estimate_ate
from causal_agent.analysis.reporting import generate_report, print_report
from causal_agent.analysis.screening import local_screen
from causal_agent.analysis.variable_parser import VariableParser
from causal_agent.core.config import AgentConfig
from causal_agent.core.data_loader import DataLoader, get_data_info
from causal_agent.core.llm_client import LLMClient
from causal_agent.logger import CausalAnalysisLogger
class CausalInferenceAgent:
"""通用因果推断 Agent"""
def __init__(self, config: Optional[AgentConfig] = None):
self.config = config or AgentConfig.from_env()
self.data_loader = DataLoader()
self.llm_client = LLMClient(
base_url=self.config.llm_base_url,
model=self.config.llm_model,
temperature=self.config.llm_temperature,
max_tokens=self.config.llm_max_tokens,
api_key=self.config.llm_api_key,
timeout=self.config.llm_timeout,
max_retries=self.config.llm_max_retries,
)
self.variable_parser = VariableParser(self.llm_client)
self.logger = CausalAnalysisLogger(self.config.log_path)
def analyze(
self, data_path: str, custom_prompt: Optional[str] = None
) -> Dict[str, Any]:
"""
执行完整的因果分析 pipeline
Args:
data_path: 数据文件路径
custom_prompt: 可选的自定义提示词补充
Returns:
包含完整报告和中间结果的字典
"""
df = self.data_loader.load(data_path)
columns = list(df.columns)
data_info = get_data_info(df)
print(f"成功加载数据:{data_path}")
print(f"数据形状:{df.shape}")
print(f"列名:{columns}")
# Step 1: LLM 变量识别
print("\n[Step 1] LLM 变量识别...")
parsed = self.variable_parser.parse(columns, data_info, custom_prompt)
T = parsed["treatment"]
Y = parsed["outcome"]
tiers = parsed["time_tiers"]
print(f"处理变量:{T},结果变量:{Y}")
print(f"时间层级:{tiers}")
# Step 2: 快速相关性筛查
print("\n[Step 2] 快速相关性筛查...")
candidates = local_screen(
df,
T=T,
Y=Y,
            excluded=[c for c in columns if tiers.get(c, 0) < 0],  # exclude id-like columns
corr_threshold=self.config.corr_threshold,
alpha=self.config.alpha,
)
print(f"候选混杂变量:{[c['var'] for c in candidates]}")
# Step 3: 因果图构建
print("\n[Step 3] 因果图构建...")
G = build_local_graph(
df, T, Y, candidates, tiers, corr_threshold=self.config.corr_threshold
)
edges = [
{"from": u, "to": v, "type": data.get("type", "unknown")}
for u, v, data in G.edges(data=True)
]
nodes = list(G.nodes())
print(f"图节点:{nodes}")
print(f"图边:{edges}")
# Step 4: 混杂识别
print("\n[Step 4] 后门路径识别...")
backdoor_paths = find_backdoor_paths(G, T, Y)
print(f"后门路径:{backdoor_paths}")
adjustment_set, reasoning = find_adjustment_set(G, T, Y, backdoor_paths)
print(f"调整集:{adjustment_set}")
print(f"调整理由:{reasoning}")
# Step 5: 效应估计
print("\n[Step 5] 效应估计...")
estimation_result = estimate_ate(
df,
T,
Y,
adjustment_set,
compute_ci=True,
n_bootstrap=self.config.bootstrap_iterations,
alpha=self.config.bootstrap_alpha,
)
print(f"ATE (OR): {estimation_result['ATE_Outcome_Regression']}")
print(f"ATE (IPW): {estimation_result['ATE_IPW']}")
print(f"95% CI: {estimation_result['95%_CI']}")
        # Step 6: report generation
        print("\n[Step 6] Generating report...")
causal_graph_info = {
"nodes": nodes,
"edges": edges,
"backdoor_paths": backdoor_paths,
}
identification_info = {
"strategy": "Backdoor Adjustment",
"adjustment_set": adjustment_set,
"reasoning": reasoning,
}
query_interpretation = {
"treatment": T,
"outcome": Y,
"estimand": "ATE",
}
report = generate_report(
query_interpretation=query_interpretation,
causal_graph=causal_graph_info,
identification=identification_info,
estimation=estimation_result,
)
print_report(report)
        # Write the log entry
parameters = {
"data_path": data_path,
"sample_size": len(df),
"variables": columns,
"treatment_variable": T,
"outcome_variable": Y,
"time_tiers": tiers,
"llm_params": {
"base_url": self.config.llm_base_url,
"model": self.config.llm_model,
"temperature": self.config.llm_temperature,
"max_tokens": self.config.llm_max_tokens,
},
"candidates": candidates,
"causal_graph": causal_graph_info,
"identification": identification_info,
"estimation": {
k: v
for k, v in estimation_result.items()
if k not in ("warnings", "balance_check")
},
"log_path": self.logger.log_path,
}
self.logger.log_analysis(
VariableParser.SYSTEM_PROMPT,
self.variable_parser._build_user_prompt(columns, data_info, custom_prompt),
str(parsed),
parameters,
report,
)
return {
"id": f"{self.logger.analysis_count:03d}",
"report": report,
"parameters": parameters,
}

18
causal_agent/analysis/__init__.py Normal file

@ -0,0 +1,18 @@
"""Analysis modules for Causal Inference Agent."""
from .screening import local_screen
from .causal_graph import build_local_graph, find_backdoor_paths, find_adjustment_set
from .estimation import estimate_ate
from .reporting import generate_report, print_report
from .variable_parser import VariableParser
__all__ = [
"local_screen",
"build_local_graph",
"find_backdoor_paths",
"find_adjustment_set",
"estimate_ate",
"generate_report",
"print_report",
"VariableParser",
]

134
causal_agent/analysis/causal_graph.py Normal file

@ -0,0 +1,134 @@
"""
因果图构建与后门准则模块
功能
1. 基于时间层级和统计相关性构建混合智能因果图
2. 自动发现后门路径
3. 基于后门准则寻找最小调整集
"""
from typing import Any, Dict, List, Optional, Tuple
import networkx as nx
import pandas as pd
from scipy.stats import spearmanr
def build_local_graph(
df: pd.DataFrame,
T: str,
Y: str,
candidates: List[Dict[str, Any]],
tiers: Dict[str, int],
corr_threshold: float = 0.1,
) -> nx.DiGraph:
"""
构建局部因果图
Args:
df: 输入数据框
T: 处理变量
Y: 结果变量
candidates: 候选混杂变量列表
tiers: 时间层级字典
corr_threshold: 相关性阈值
Returns:
networkx.DiGraph 对象
"""
G = nx.DiGraph()
relevant_vars = [T, Y] + [c["var"] for c in candidates]
G.add_nodes_from(relevant_vars)
for u in relevant_vars:
for v in relevant_vars:
if u == v:
continue
tier_u = tiers.get(u, -1)
tier_v = tiers.get(v, -1)
if tier_u < tier_v:
corr, _ = spearmanr(df[u], df[v])
if abs(corr) > corr_threshold:
edge_type = "confounding" if v in [T, Y] else "temporal"
G.add_edge(u, v, type=edge_type, confidence=1.0)
G.add_edge(T, Y, type="hypothesized", confidence="research_question")
return G
def find_backdoor_paths(G: nx.DiGraph, T: str, Y: str) -> List[str]:
"""寻找从 T 到 Y 的所有后门路径"""
UG = G.to_undirected()
try:
paths = list(nx.all_simple_paths(UG, source=T, target=Y, cutoff=5))
except nx.NodeNotFound:
paths = []
backdoor_paths = []
for path in paths:
if len(path) < 2:
continue
second_node = path[1]
if G.has_edge(second_node, T):
formatted = _format_path(G, path)
backdoor_paths.append(formatted)
seen = set()
unique_paths = []
for p in backdoor_paths:
if p not in seen:
seen.add(p)
unique_paths.append(p)
return unique_paths
def _format_path(G: nx.DiGraph, path: List[str]) -> str:
"""将无向路径格式化为带方向箭头的字符串"""
parts = [path[0]]
for i in range(len(path) - 1):
u, v = path[i], path[i + 1]
if G.has_edge(u, v) and G.has_edge(v, u):
parts.append(f"<-> {v}")
elif G.has_edge(u, v):
parts.append(f"-> {v}")
elif G.has_edge(v, u):
parts.append(f"<- {v}")
else:
parts.append(f"-- {v}")
return " ".join(parts)
def find_adjustment_set(
G: nx.DiGraph, T: str, Y: str, backdoor_paths: Optional[List[str]] = None
) -> Tuple[List[str], str]:
"""基于后门准则寻找最小调整集"""
if backdoor_paths is None:
backdoor_paths = find_backdoor_paths(G, T, Y)
if not backdoor_paths:
return [], "未发现后门路径,无需额外调整变量即可识别因果效应。"
adjustment_candidates = set()
for path_str in backdoor_paths:
nodes = _parse_path_nodes(path_str)
if len(nodes) >= 2:
adjustment_candidates.add(nodes[1])
safe_adjustment_set = sorted(adjustment_candidates)
    reasoning = (
        f"Found {len(backdoor_paths)} backdoor path(s). "
        f"Conditioning on {safe_adjustment_set} blocks all backdoor paths, satisfying the backdoor criterion."
    )
return safe_adjustment_set, reasoning
def _parse_path_nodes(path_str: str) -> List[str]:
    """Extract the node list from a formatted path string."""
    # Strip "<->" before "->"/"<-" so the bidirectional marker is not
    # left behind as a stray "<" token.
    cleaned = (
        path_str.replace("<->", " ")
        .replace("->", " ")
        .replace("<-", " ")
        .replace("--", " ")
    )
    nodes = [n.strip() for n in cleaned.split() if n.strip()]
    return nodes

203
causal_agent/analysis/estimation.py Normal file

@ -0,0 +1,203 @@
"""
效应估计模块
提供 Outcome Regression (OR) Inverse Probability Weighting (IPW) 两种估计方法
"""
from typing import Any, Dict, List, Optional
import numpy as np
import pandas as pd
from sklearn.linear_model import LinearRegression, LogisticRegression
def _compute_smd(df: pd.DataFrame, var: str, T: str) -> float:
"""计算标准化均值差SMD"""
treated = df[df[T] == 1][var]
control = df[df[T] == 0][var]
pooled_std = np.sqrt((treated.var() + control.var()) / 2)
if pooled_std == 0:
return 0.0
return float((treated.mean() - control.mean()) / pooled_std)
def _bootstrap_ci(
df: pd.DataFrame,
T: str,
Y: str,
adjustment_set: List[str],
estimator: str,
n_bootstrap: int = 500,
alpha: float = 0.05,
) -> List[float]:
"""Bootstrap 置信区间"""
n = len(df)
estimates = []
for _ in range(n_bootstrap):
sample = df.sample(n=n, replace=True, random_state=None)
try:
if estimator == "OR":
ate = _estimate_or(sample, T, Y, adjustment_set)
elif estimator == "IPW":
ate = _estimate_ipw(sample, T, Y, adjustment_set)
else:
ate = np.nan
estimates.append(ate)
except Exception:
estimates.append(np.nan)
estimates = np.array(estimates)
estimates = estimates[~np.isnan(estimates)]
if len(estimates) == 0:
return [np.nan, np.nan]
lower = float(np.percentile(estimates, 100 * alpha / 2))
upper = float(np.percentile(estimates, 100 * (1 - alpha / 2)))
return [round(lower, 4), round(upper, 4)]
def _estimate_or(df: pd.DataFrame, T: str, Y: str, adjustment_set: List[str]) -> float:
"""Outcome Regression 估计 ATE"""
X_cols = adjustment_set + [T]
X = df[X_cols].values
y = df[Y].values
model = LinearRegression().fit(X, y)
df_t1 = df.copy()
df_t1[T] = 1
df_t0 = df.copy()
df_t0[T] = 0
ate = (
model.predict(df_t1[X_cols].values).mean()
- model.predict(df_t0[X_cols].values).mean()
)
return float(ate)
def _estimate_ipw(df: pd.DataFrame, T: str, Y: str, adjustment_set: List[str]) -> float:
"""IPW 估计 ATE"""
t = df[T].values
if len(adjustment_set) == 0:
ps = np.full_like(t, fill_value=t.mean(), dtype=float)
else:
X = df[adjustment_set].values
ps_model = LogisticRegression(max_iter=1000, solver="lbfgs")
ps_model.fit(X, t)
ps = ps_model.predict_proba(X)[:, 1]
ps_clipped = np.clip(ps, 0.05, 0.95)
weights_t = t / ps_clipped
weights_c = (1 - t) / (1 - ps_clipped)
ate = np.mean(weights_t * df[Y].values) - np.mean(weights_c * df[Y].values)
return float(ate)
def estimate_ate(
df: pd.DataFrame,
T: str,
Y: str,
adjustment_set: List[str],
compute_ci: bool = True,
n_bootstrap: int = 500,
alpha: float = 0.05,
) -> Dict[str, Any]:
"""
估计平均处理效应ATE
Args:
df: 输入数据框
T: 处理变量
Y: 结果变量
adjustment_set: 调整变量列表
compute_ci: 是否计算 Bootstrap CI
n_bootstrap: Bootstrap 次数
alpha: 显著性水平
Returns:
包含 ATE_ORATE_IPW诊断信息的字典
"""
warnings_list: List[Dict[str, str]] = []
ate_or = _estimate_or(df, T, Y, adjustment_set)
t = df[T].values
if len(adjustment_set) == 0:
ps = np.full_like(t, fill_value=t.mean(), dtype=float)
else:
X = df[adjustment_set].values
ps_model = LogisticRegression(max_iter=1000, solver="lbfgs")
ps_model.fit(X, t)
ps = ps_model.predict_proba(X)[:, 1]
overlap_ok = True
if np.any(ps < 0.05) or np.any(ps > 0.95):
overlap_ok = False
        warnings_list.append(
            {
                "type": "positivity_violation",
                "message": "Extreme propensity scores (<0.05 or >0.95) detected; the overlap (positivity) assumption may be violated.",
            }
        )
ate_ipw = _estimate_ipw(df, T, Y, adjustment_set)
diff = abs(ate_or - ate_ipw)
robustness = "稳健"
if ate_or != 0 and diff / abs(ate_or) > 0.1:
robustness = "OR 与 IPW 估计差异 >10%,可能存在模型误设,建议进一步检查"
elif ate_ipw != 0 and diff / abs(ate_ipw) > 0.1:
robustness = "OR 与 IPW 估计差异 >10%,可能存在模型误设,建议进一步检查"
balance_check: Dict[str, Dict[str, float]] = {}
for var in adjustment_set:
smd_before = _compute_smd(df, var, T)
ps_clipped = np.clip(ps, 0.05, 0.95)
weights = np.where(df[T].values == 1, 1.0 / ps_clipped, 1.0 / (1 - ps_clipped))
treated_idx = df[T].values == 1
control_idx = df[T].values == 0
weighted_mean_t = np.average(
df.loc[treated_idx, var].values, weights=weights[treated_idx]
)
weighted_mean_c = np.average(
df.loc[control_idx, var].values, weights=weights[control_idx]
)
pooled_std = np.sqrt(
(df.loc[treated_idx, var].var() + df.loc[control_idx, var].var()) / 2
)
smd_after = (
float((weighted_mean_t - weighted_mean_c) / pooled_std)
if pooled_std > 0
else 0.0
)
balance_check[var] = {
"before": round(smd_before, 4),
"after": round(smd_after, 4),
}
ci_or: List[float] = [np.nan, np.nan]
ci_ipw: List[float] = [np.nan, np.nan]
if compute_ci:
ci_or = _bootstrap_ci(df, T, Y, adjustment_set, "OR", n_bootstrap, alpha)
ci_ipw = _bootstrap_ci(df, T, Y, adjustment_set, "IPW", n_bootstrap, alpha)
ate_report = round((ate_or + ate_ipw) / 2, 4)
ci_report = ci_or if not np.isnan(ci_or[0]) else ci_ipw
    interpretation = (
        f"After adjusting for {adjustment_set}, receiving the treatment changes {Y} "
        f"by {ate_report:.4f} on average (95% CI: {ci_report[0]:.2f} to {ci_report[1]:.2f})."
    )
return {
"ATE_Outcome_Regression": round(ate_or, 4),
"ATE_IPW": round(ate_ipw, 4),
"ATE_reported": ate_report,
"95%_CI": ci_report,
"interpretation": interpretation,
"balance_check": balance_check,
"overlap_assumption": "满足" if overlap_ok else "存在风险",
"robustness": robustness,
"warnings": warnings_list,
}

59
causal_agent/analysis/reporting.py Normal file

@ -0,0 +1,59 @@
"""
报告生成模块
将因果分析各阶段结果整合为统一的 JSON 报告
"""
import json
from typing import Any, Dict, List, Optional
def generate_report(
query_interpretation: Dict[str, str],
causal_graph: Dict[str, Any],
identification: Dict[str, Any],
estimation: Dict[str, Any],
extra_warnings: Optional[List[Dict[str, str]]] = None,
) -> Dict[str, Any]:
"""生成标准化 JSON 报告"""
warnings = list(extra_warnings) if extra_warnings else []
warnings.extend(estimation.get("warnings", []))
    warnings.append(
        {
            "type": "unobserved_confounding",
            "message": "Unobserved confounding may be present; a sensitivity analysis is recommended.",
        }
    )
report = {
"query_interpretation": query_interpretation,
"causal_graph": causal_graph,
"identification": {
"strategy": identification.get("strategy", "Backdoor Adjustment"),
"adjustment_set": identification.get("adjustment_set", []),
"reasoning": identification.get("reasoning", ""),
},
"estimation": {
"ATE_Outcome_Regression": estimation.get("ATE_Outcome_Regression"),
"ATE_IPW": estimation.get("ATE_IPW"),
"95%_CI": estimation.get("95%_CI"),
"interpretation": estimation.get("interpretation", ""),
},
"diagnostics": {
"balance_check": estimation.get("balance_check", {}),
"overlap_assumption": estimation.get("overlap_assumption", "未知"),
"robustness": estimation.get("robustness", "未知"),
},
"warnings": warnings,
}
return report
def print_report(report: Dict[str, Any]) -> None:
"""美观地打印报告"""
print("\n" + "=" * 60)
print("因果推断分析报告")
print("=" * 60)
print(json.dumps(report, indent=2, ensure_ascii=False))
print("=" * 60)

83
causal_agent/analysis/screening.py Normal file

@ -0,0 +1,83 @@
"""
快速相关性筛查模块
使用 Pearson + Spearman 方法快速过滤与处理变量 (T) 和结果变量 (Y) 都相关的变量
"""
from typing import Any, Dict, List, Optional
import numpy as np
import pandas as pd
from scipy.stats import pearsonr, spearmanr
from sklearn.feature_selection import mutual_info_regression
def local_screen(
df: pd.DataFrame,
T: str,
Y: str,
excluded: Optional[List[str]] = None,
corr_threshold: float = 0.1,
alpha: float = 0.05,
) -> List[Dict[str, Any]]:
"""
快速相关性筛查找出与 T Y 都显著相关的变量
Args:
df: 输入数据框
T: 处理变量名称
Y: 结果变量名称
excluded: 需要排除的变量列表
corr_threshold: 相关系数绝对值阈值
alpha: 显著性水平
Returns:
候选混杂变量列表
"""
if excluded is None:
excluded = []
candidates = []
cols = [c for c in df.columns if c not in [T, Y] + excluded]
for col in cols:
if df[col].isna().mean() > 0.5:
continue
if not pd.api.types.is_numeric_dtype(df[col]):
continue
p_t, pv_t_pearson = pearsonr(df[col], df[T])
p_y, pv_y_pearson = pearsonr(df[col], df[Y])
s_t, pv_t_spear = spearmanr(df[col], df[T])
s_y, pv_y_spear = spearmanr(df[col], df[Y])
cond_t = abs(s_t) > corr_threshold and pv_t_spear < alpha
cond_y = abs(s_y) > corr_threshold and pv_y_spear < alpha
if cond_t and cond_y:
mi_t = _compute_mi(df[[col]].values, df[T].values)
mi_y = _compute_mi(df[[col]].values, df[Y].values)
candidates.append(
{
"var": col,
"pearson_T": round(float(p_t), 4),
"pearson_Y": round(float(p_y), 4),
"spearman_T": round(float(s_t), 4),
"spearman_Y": round(float(s_y), 4),
"pvalue_T": round(float(pv_t_spear), 4),
"pvalue_Y": round(float(pv_y_spear), 4),
"mi_T": round(float(mi_t), 4),
"mi_Y": round(float(mi_y), 4),
}
)
candidates.sort(key=lambda x: abs(x["spearman_T"]) + abs(x["spearman_Y"]), reverse=True)
return candidates
def _compute_mi(X: np.ndarray, y: np.ndarray) -> float:
"""计算互信息"""
mi = mutual_info_regression(X, y, random_state=42)
return float(mi[0])

133
causal_agent/analysis/variable_parser.py Normal file

@ -0,0 +1,133 @@
"""
变量解析模块
使用 LLM 识别处理变量结果变量并对每个变量进行时间层级解析
"""
from typing import Any, Dict, List, Optional
from causal_agent.core.llm_client import LLMClient
class VariableParser:
"""LLM 变量解析器"""
    SYSTEM_PROMPT = """You are a professional causal inference analyst. Your task is to analyze the given data, identify the treatment variable and the outcome variable, and assign a temporal tier to every variable.
Output the result as JSON only, with no extra explanation or reasoning.
JSON output specification:
{
  "treatment": "name of the treatment variable",
  "outcome": "name of the outcome variable",
  "time_tiers": {
    "variable_1": integer tier,
    "variable_2": integer tier,
    ...
  }
}
time_tiers levels (smaller integers happen earlier):
- -1: non-temporal variables (e.g. unique sample identifiers: id, index)
- 0: demographic attributes or time-invariant confounders (e.g. age, gender, region)
- 1: baseline measurements (taken before the intervention, possible confounders, e.g. baseline_score, pre_test)
- 2: intervention point / treatment variable (e.g. treatment, intervention, policy)
- 3: mediator variables (measured after the intervention, before the outcome)
- 4: follow-up results / outcome variable (e.g. outcome, post_test, score)
- 5+: later time points (e.g. repeated follow-ups)
Notes:
- Output only the JSON format above, with no additional fields
- treatment and outcome must be column names that actually exist in the data table
- time_tiers must include every column in the data
- Do not use markdown code fences (```json)
- Output the raw JSON string directly"""
def __init__(self, llm_client: Optional[LLMClient] = None):
self.llm_client = llm_client or LLMClient()
def parse(
self,
columns: List[str],
data_info: str,
custom_prompt: Optional[str] = None,
) -> Dict[str, Any]:
"""
调用 LLM 解析变量
Args:
columns: 数据列名列表
data_info: 数据信息摘要字符串
custom_prompt: 可选的自定义用户提示词补充
Returns:
包含 treatmentoutcometime_tiers 的字典
"""
user_prompt = self._build_user_prompt(columns, data_info, custom_prompt)
result = self.llm_client.generate_response(self.SYSTEM_PROMPT, user_prompt)
        if not result.get("success"):
            raise RuntimeError(f"LLM call failed: {result.get('error')}")
        try:
            parsed = LLMClient.parse_json_response(result["content"])
        except Exception as e:
            raise ValueError(f"Failed to parse LLM output: {e}\nRaw content: {result.get('content')}")
self._validate(parsed, columns)
return parsed
def _build_user_prompt(
self,
columns: List[str],
data_info: str,
custom_prompt: Optional[str] = None,
) -> str:
"""构建用户提示词"""
parts = [
"请分析以下数据,并严格按照 JSON 格式输出分析结果:",
"",
data_info,
"",
"JSON 输出格式要求:",
'{',
' "treatment": "处理变量名称",',
' "outcome": "结果变量名称",',
' "time_tiers": {',
' "列名1": 层级整数,',
' "列名2": 层级整数,',
" ...",
" }",
"}",
"",
"要求:",
"1. treatment 和 outcome 必须与表格列名完全一致",
"2. time_tiers 必须覆盖所有列名",
"3. 根据列名含义和统计摘要推断每个变量的时间层级",
"4. 只输出 JSON不要包含其他任何内容",
"5. 不要使用 markdown 代码块标记",
]
if custom_prompt:
parts.insert(1, f"\n用户补充说明:{custom_prompt}\n")
return "\n".join(parts)
def _validate(self, parsed: Dict[str, Any], columns: List[str]):
"""验证 LLM 返回结果"""
treatment = parsed.get("treatment")
outcome = parsed.get("outcome")
tiers = parsed.get("time_tiers", {})
        if treatment not in columns:
            raise ValueError(f"LLM-returned treatment '{treatment}' is not a data column")
        if outcome not in columns:
            raise ValueError(f"LLM-returned outcome '{outcome}' is not a data column")
        if not isinstance(tiers, dict):
            raise ValueError("time_tiers must be a dict")
        missing = set(columns) - set(tiers.keys())
        if missing:
            raise ValueError(f"time_tiers is missing the following columns: {sorted(missing)}")
        for col, tier in tiers.items():
            if not isinstance(tier, int):
                raise ValueError(f"the tier for '{col}' in time_tiers must be an integer")

67
causal_agent/cli.py Normal file

@ -0,0 +1,67 @@
"""
Causal Inference Agent CLI
命令行入口支持直接对数据文件执行因果推断分析
"""
import argparse
import json
import os
import sys
from causal_agent.agent import CausalInferenceAgent
from causal_agent.core.config import AgentConfig
def main():
    parser = argparse.ArgumentParser(
        description="General-purpose causal inference agent",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  python -m causal_agent --data data.csv --output report.json
  python -m causal_agent --data data.xlsx --model qwen3.5-35b
  python -m causal_agent --data data.csv --prompt "Analyze the effect of an educational intervention on test scores"
""",
)
parser.add_argument("--data", "-d", required=True, help="数据文件路径")
parser.add_argument("--output", "-o", help="报告输出 JSON 文件路径")
parser.add_argument("--prompt", "-p", help="自定义分析提示词")
parser.add_argument("--base-url", help="LLM API Base URL")
parser.add_argument("--model", help="LLM 模型名称")
parser.add_argument("--temperature", type=float, help="LLM 温度参数")
parser.add_argument("--log-path", help="日志文件路径")
parser.add_argument("--corr-threshold", type=float, help="相关性筛查阈值")
args = parser.parse_args()
    if not os.path.exists(args.data):
        print(f"Error: data file not found: {args.data}", file=sys.stderr)
        sys.exit(1)
    # Build the configuration
config = AgentConfig.from_env()
if args.base_url:
config.llm_base_url = args.base_url
if args.model:
config.llm_model = args.model
if args.temperature is not None:
config.llm_temperature = args.temperature
if args.log_path:
config.log_path = args.log_path
if args.corr_threshold is not None:
config.corr_threshold = args.corr_threshold
agent = CausalInferenceAgent(config)
result = agent.analyze(args.data, custom_prompt=args.prompt)
if args.output:
with open(args.output, "w", encoding="utf-8") as f:
json.dump(result["report"], f, ensure_ascii=False, indent=2)
print(f"\n报告已保存到:{args.output}")
print("\n分析完成!")
if __name__ == "__main__":
main()

7
causal_agent/core/__init__.py Normal file

@ -0,0 +1,7 @@
"""Core modules for Causal Inference Agent."""
from .config import AgentConfig
from .llm_client import LLMClient
from .data_loader import DataLoader, get_data_info
__all__ = ["AgentConfig", "LLMClient", "DataLoader", "get_data_info"]

81
causal_agent/core/config.py Normal file

@ -0,0 +1,81 @@
"""
Agent 配置模块
支持从环境变量构造函数参数加载配置
"""
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional
def _load_env_file():
"""加载项目根目录的 .env 文件"""
# 尝试多个可能的位置
possible_paths = [
Path.cwd() / ".env",
Path(__file__).parent.parent.parent / ".env",
Path(__file__).parent.parent.parent.parent / ".env",
]
for env_path in possible_paths:
if env_path.exists():
with open(env_path, "r", encoding="utf-8") as f:
for line in f:
line = line.strip()
if line and not line.startswith("#") and "=" in line:
key, value = line.split("=", 1)
                        # Only set keys that are not already present in the environment
if key not in os.environ:
os.environ[key] = value
break
# Read the .env file automatically at module import time
_load_env_file()
@dataclass
class AgentConfig:
    """Causal Inference Agent configuration."""
    # LLM settings (defaults apply only when neither the environment nor .env provides a value)
llm_base_url: str = "https://glm47flash.cloyir.com/v1"
llm_model: str = "qwen3.5-35b"
llm_temperature: float = 0.3
llm_max_tokens: int = 2048
llm_api_key: Optional[str] = None
llm_timeout: int = 120
llm_max_retries: int = 3
    # Statistical screening settings
corr_threshold: float = 0.1
alpha: float = 0.05
bootstrap_iterations: int = 500
bootstrap_alpha: float = 0.05
    # Path settings
log_path: str = "causal_analysis.log.md"
output_dir: str = "."
@classmethod
def from_env(cls) -> "AgentConfig":
"""从环境变量加载配置"""
return cls(
llm_base_url=os.getenv("LLM_BASE_URL", cls.llm_base_url),
llm_model=os.getenv("LLM_MODEL", cls.llm_model),
llm_temperature=float(os.getenv("LLM_TEMPERATURE", str(cls.llm_temperature))),
llm_max_tokens=int(os.getenv("LLM_MAX_TOKENS", str(cls.llm_max_tokens))),
llm_api_key=os.getenv("LLM_API_KEY", cls.llm_api_key),
llm_timeout=int(os.getenv("LLM_TIMEOUT", str(cls.llm_timeout))),
llm_max_retries=int(os.getenv("LLM_MAX_RETRIES", str(cls.llm_max_retries))),
corr_threshold=float(os.getenv("CORR_THRESHOLD", str(cls.corr_threshold))),
alpha=float(os.getenv("ALPHA", str(cls.alpha))),
bootstrap_iterations=int(
os.getenv("BOOTSTRAP_ITERATIONS", str(cls.bootstrap_iterations))
),
bootstrap_alpha=float(os.getenv("BOOTSTRAP_ALPHA", str(cls.bootstrap_alpha))),
log_path=os.getenv("LOG_PATH", cls.log_path),
output_dir=os.getenv("OUTPUT_DIR", cls.output_dir),
)

68
causal_agent/core/data_loader.py Normal file

@ -0,0 +1,68 @@
"""
通用数据加载器
支持 CSVExcelJSON Lines 等格式同时提供数据信息摘要生成
"""
import os
from typing import Optional
import pandas as pd
def _read_json_lines(file_path: str) -> pd.DataFrame:
"""读取 JSON Lines 文件"""
return pd.read_json(file_path, lines=True)
class DataLoader:
"""通用数据加载器"""
SUPPORTED_EXTENSIONS = {".csv", ".xlsx", ".xls", ".json", ".jsonl"}
def __init__(self, default_read_rows: Optional[int] = None):
self.default_read_rows = default_read_rows
def load(self, file_path: str) -> pd.DataFrame:
"""加载数据文件为 DataFrame"""
if not os.path.exists(file_path):
raise FileNotFoundError(f"数据文件不存在:{file_path}")
ext = os.path.splitext(file_path)[1].lower()
        if ext not in self.SUPPORTED_EXTENSIONS:
            raise ValueError(
                f"Unsupported file format: {ext}; only {self.SUPPORTED_EXTENSIONS} are supported"
            )
if ext == ".csv":
kwargs = {}
if self.default_read_rows is not None:
kwargs["nrows"] = self.default_read_rows
return pd.read_csv(file_path, **kwargs)
elif ext in {".xlsx", ".xls"}:
return pd.read_excel(file_path)
elif ext in {".json", ".jsonl"}:
return _read_json_lines(file_path)
raise ValueError(f"未实现的文件格式:{ext}")
def get_data_info(df: pd.DataFrame, treatment_col: Optional[str] = None) -> str:
"""生成数据信息摘要字符串,用于 LLM prompt"""
lines = [
f"**数据概览:**",
f"- 样本数量:{len(df)}",
f"- 变量:{', '.join(df.columns)}",
f"",
f"**统计摘要:**",
f"{df.describe()}",
]
    if treatment_col and treatment_col in df.columns:
        lines.extend(
            [
                "",
                "**Treatment variable distribution:**",
                f"{df[treatment_col].value_counts().to_string()}",
            ]
        )
return "\n".join(lines)

129
causal_agent/core/llm_client.py Normal file

@ -0,0 +1,129 @@
"""
LLM 调用模块
用于调用外部 LLM API 进行文本生成支持重试和 JSON 解析辅助
"""
import json
import re
import time
from typing import Dict, Any, List, Optional
import requests
class LLMClient:
"""LLM API 客户端"""
def __init__(
self,
base_url: Optional[str] = None,
model: Optional[str] = None,
temperature: Optional[float] = None,
max_tokens: Optional[int] = None,
api_key: Optional[str] = None,
timeout: int = 120,
max_retries: int = 3,
):
        # Fall back to the LLM_BASE_URL env var (matching AgentConfig's default) when no base_url is given
        self.base_url = (base_url or os.getenv("LLM_BASE_URL", "https://glm47flash.cloyir.com/v1")).rstrip("/")
self.model = model or os.getenv("LLM_MODEL", "qwen3.5-35b")
self.temperature = temperature if temperature is not None else float(os.getenv("LLM_TEMPERATURE", "0.3"))
self.max_tokens = max_tokens if max_tokens is not None else int(os.getenv("LLM_MAX_TOKENS", "2048"))
self.timeout = timeout
self.max_retries = max_retries
self.api_key = api_key
self._headers = {"Content-Type": "application/json"}
if api_key:
self._headers["Authorization"] = f"Bearer {api_key}"
def chat_completion(
self,
messages: List[Dict[str, str]],
temperature: Optional[float] = None,
max_tokens: Optional[int] = None,
) -> Dict[str, Any]:
"""发送聊天请求并获取响应,支持重试"""
url = f"{self.base_url}/chat/completions"
payload = {
"model": self.model,
"messages": messages,
"temperature": temperature or self.temperature,
"max_tokens": max_tokens or self.max_tokens,
"stream": False,
"extra_body": {"chat_template_kwargs": {"enable_thinking": False}},
}
last_error = None
for attempt in range(self.max_retries):
try:
response = requests.post(
url, headers=self._headers, json=payload, timeout=self.timeout
)
response.raise_for_status()
return response.json()
except requests.exceptions.RequestException as e:
last_error = e
if attempt < self.max_retries - 1:
time.sleep(1 * (attempt + 1))
return {
"error": str(last_error),
"status_code": (
getattr(last_error.response, "status_code", None)
if hasattr(last_error, "response")
else None
),
}
def generate_response(
self,
system_prompt: str,
user_prompt: str,
temperature: Optional[float] = None,
max_tokens: Optional[int] = None,
) -> Dict[str, Any]:
"""生成响应(简化接口)"""
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
]
result = self.chat_completion(messages, temperature, max_tokens)
if "error" in result:
return {"success": False, "error": result["error"], "content": None}
if "choices" in result and len(result["choices"]) > 0:
choice = result["choices"][0]
message = choice.get("message", {})
content = message.get("content")
if content is None:
content = message.get("reasoning", "")
if content is None:
return {
"success": False,
"error": "No content found in response",
"content": None,
}
return {
"success": True,
"error": None,
"content": content,
"usage": result.get("usage", {}),
"model": result.get("model", self.model),
}
return {
"success": False,
"error": f"Unexpected response format: {result}",
"content": None,
}
@staticmethod
def parse_json_response(content: str) -> Dict[str, Any]:
"""清理并解析 LLM 返回的 JSON 字符串"""
cleaned = content.strip()
if cleaned.startswith("```"):
cleaned = re.sub(r"^```(?:json)?\s*", "", cleaned)
cleaned = re.sub(r"\s*```$", "", cleaned)
return json.loads(cleaned)

115
causal_agent/logger.py Normal file

@ -0,0 +1,115 @@
"""
因果分析日志记录器
"""
import json
import os
import re
from datetime import datetime
from typing import Any, Dict, List, Optional
class CausalAnalysisLogger:
"""因果分析日志记录器"""
def __init__(self, log_path: str = "causal_analysis.log.md"):
self.log_path = log_path
self.log_entries: List[Dict[str, Any]] = []
self._init_log_file()
self._load_existing_entries()
    def _init_log_file(self):
        """Initialize the log file."""
        header = """# Causal Analysis Log
## About This Log
This document records all input parameters and output results of the causal inference analyses.
## Analysis Records
"""
if not os.path.exists(self.log_path):
with open(self.log_path, "w", encoding="utf-8") as f:
f.write(header)
def _load_existing_entries(self):
"""加载已存在的分析记录数量"""
if not os.path.exists(self.log_path):
self.analysis_count = 0
return
analysis_count = 0
with open(self.log_path, "r", encoding="utf-8") as f:
content = f.read()
            matches = re.findall(r"### Analysis #(\d+)", content)
if matches:
analysis_count = max(int(m) for m in matches)
self.analysis_count = analysis_count
def _append_to_log(self, content: str):
"""追加内容到日志文件"""
with open(self.log_path, "a", encoding="utf-8") as f:
f.write(content)
def log_analysis(
self,
system_prompt: str,
user_prompt: str,
model_output: str,
parameters: Dict[str, Any],
report: Optional[Dict[str, Any]] = None,
):
"""记录一次完整的分析过程"""
self.analysis_count += 1
analysis_id = f"{self.analysis_count:03d}"
timestamp = datetime.now().isoformat()
report_json = ""
        if report is not None:
            report_json = (
                f"\n#### Analysis Report\n"
                f"```json\n{json.dumps(report, indent=2, ensure_ascii=False)}\n```\n"
            )
entry = f"""
---
### 分析 #{analysis_id}
**时间**: {timestamp}
#### 系统提示词
```
{system_prompt}
```
#### 用户提示词
```
{user_prompt}
```
#### LLM 输出
```
{model_output}
```
{report_json}
#### 调用参数
```json
{json.dumps(parameters, indent=2, ensure_ascii=False)}
```
---
"""
self._append_to_log(entry)
self.log_entries.append(
{
"id": analysis_id,
"timestamp": timestamp,
"system_prompt": system_prompt,
"user_prompt": user_prompt,
"model_output": model_output,
"parameters": parameters,
"report": report,
}
)
print(f"分析 #{analysis_id} 已记录到 {self.log_path}")

BIN
data.xlsx Normal file

Binary file not shown.

807
data/simulator.py Normal file

@ -0,0 +1,807 @@
"""
模拟数据生成模块
本模块提供用于测试因果推断的模拟数据生成功能支持多种数据生成场景
- 简单 ATE 场景
- 协变量场景
- 交互效应场景
- CATE 场景
作者CausalInferenceAgent
版本1.0.0
"""
import numpy as np
import pandas as pd
from abc import ABC, abstractmethod
from typing import Optional, Dict, Any, Union
from pathlib import Path
class BaseDataSimulator(ABC):
"""
数据模拟器基类
定义所有数据模拟器的通用接口
"""
def __init__(
self,
n_samples: int = 1000,
treatment_rate: float = 0.5,
noise_level: float = 1.0,
effect_size: float = 1.0,
random_state: Optional[int] = None
):
"""
初始化数据模拟器
参数:
n_samples: 样本数量默认为 1000
treatment_rate: 处理比例0-1 之间默认为 0.5
noise_level: 噪声水平默认为 1.0
effect_size: 效应大小默认为 1.0
random_state: 随机种子用于可重复性
"""
self.n_samples = n_samples
self.treatment_rate = treatment_rate
self.noise_level = noise_level
self.effect_size = effect_size
self.random_state = random_state
        # Set the random seed
        if random_state is not None:
            np.random.seed(random_state)
@abstractmethod
def generate(self) -> pd.DataFrame:
"""
生成模拟数据
返回:
包含模拟数据的 DataFrame
"""
pass
def save_to_csv(self, filepath: Union[str, Path]) -> None:
"""
将生成的数据保存为 CSV 文件
参数:
filepath: 文件路径
"""
df = self.generate()
df.to_csv(filepath, index=False, encoding='utf-8')
def save_to_excel(self, filepath: Union[str, Path]) -> None:
"""
将生成的数据保存为 Excel 文件
参数:
filepath: 文件路径
"""
df = self.generate()
df.to_excel(filepath, index=False)
@abstractmethod
def get_description(self) -> str:
"""
获取数据生成场景的描述
返回:
描述字符串
"""
pass
class SimpleATESimulator(BaseDataSimulator):
"""
简单 ATE平均处理效应场景模拟器
生成只有一个处理变量 T 和一个结果变量 Y 的简单数据
模型Y = α + τ*T + ε
其中 τ 是平均处理效应
"""
def __init__(
self,
n_samples: int = 1000,
treatment_rate: float = 0.5,
noise_level: float = 1.0,
effect_size: float = 1.0,
intercept: float = 0.0,
random_state: Optional[int] = None
):
"""
初始化简单 ATE 模拟器
参数:
n_samples: 样本数量
treatment_rate: 处理比例
noise_level: 噪声水平
effect_size: 平均处理效应大小
intercept: 截距项
random_state: 随机种子
"""
super().__init__(
n_samples=n_samples,
treatment_rate=treatment_rate,
noise_level=noise_level,
effect_size=effect_size,
random_state=random_state
)
self.intercept = intercept
def generate(self) -> pd.DataFrame:
"""
生成简单 ATE 场景的模拟数据
返回:
包含 'T'处理变量 'Y'结果变量 DataFrame
"""
# 生成处理变量(伯努利分布)
T = np.random.binomial(1, self.treatment_rate, self.n_samples)
# 生成噪声
noise = np.random.normal(0, self.noise_level, self.n_samples)
# 生成结果变量Y = intercept + effect_size * T + noise
Y = self.intercept + self.effect_size * T + noise
df = pd.DataFrame({
'T': T,
'Y': Y
})
return df
def get_description(self) -> str:
"""获取场景描述"""
return (
f"简单 ATE 场景:\n"
f"- 样本数量:{self.n_samples}\n"
f"- 处理比例:{self.treatment_rate:.2%}\n"
f"- 噪声水平:{self.noise_level}\n"
f"- 平均处理效应:{self.effect_size}\n"
f"- 截距项:{self.intercept}\n"
f"模型Y = {self.intercept} + {self.effect_size}*T + ε"
)
class CovariateSimulator(BaseDataSimulator):
"""
协变量场景模拟器
生成包含协变量 X 的数据用于控制混淆因素
模型Y = α + β*X + τ*T + ε
"""
def __init__(
self,
n_samples: int = 1000,
treatment_rate: float = 0.5,
noise_level: float = 1.0,
effect_size: float = 1.0,
n_covariates: int = 3,
covariate_distribution: str = 'normal',
covariate_params: Optional[Dict[str, Any]] = None,
random_state: Optional[int] = None
):
"""
初始化协变量场景模拟器
参数:
n_samples: 样本数量
treatment_rate: 处理比例
noise_level: 噪声水平
effect_size: 处理效应大小
n_covariates: 协变量数量
covariate_distribution: 协变量分布类型 ('normal', 'uniform', 'beta')
covariate_params: 协变量分布参数
random_state: 随机种子
"""
super().__init__(
n_samples=n_samples,
treatment_rate=treatment_rate,
noise_level=noise_level,
effect_size=effect_size,
random_state=random_state
)
self.n_covariates = n_covariates
self.covariate_distribution = covariate_distribution
self.covariate_params = covariate_params or {}
def _generate_covariates(self) -> np.ndarray:
"""
生成协变量数据
返回:
形状为 (n_samples, n_covariates) 的协变量数组
"""
if self.covariate_distribution == 'normal':
mean = self.covariate_params.get('mean', 0)
std = self.covariate_params.get('std', 1)
X = np.random.normal(mean, std, (self.n_samples, self.n_covariates))
elif self.covariate_distribution == 'uniform':
low = self.covariate_params.get('low', -1)
high = self.covariate_params.get('high', 1)
X = np.random.uniform(low, high, (self.n_samples, self.n_covariates))
elif self.covariate_distribution == 'beta':
a = self.covariate_params.get('a', 2)
b = self.covariate_params.get('b', 2)
X = np.random.beta(a, b, (self.n_samples, self.n_covariates))
else:
raise ValueError(f"未知的协变量分布类型:{self.covariate_distribution}")
return X
def generate(self) -> pd.DataFrame:
"""
生成协变量场景的模拟数据
返回:
包含协变量 X处理变量 T 和结果变量 Y DataFrame
"""
# 生成协变量
X = self._generate_covariates()
# 生成处理变量
T = np.random.binomial(1, self.treatment_rate, self.n_samples)
# 生成噪声
noise = np.random.normal(0, self.noise_level, self.n_samples)
# 生成协变量系数
beta = np.random.uniform(-1, 1, self.n_covariates)
# 生成结果变量Y = β*X + τ*T + ε
Y = np.dot(X, beta) + self.effect_size * T + noise
# 创建 DataFrame
df_dict = {}
for i in range(self.n_covariates):
df_dict[f'X{i}'] = X[:, i]
df_dict['T'] = T
df_dict['Y'] = Y
df = pd.DataFrame(df_dict)
return df
def get_description(self) -> str:
"""获取场景描述"""
return (
f"协变量场景:\n"
f"- 样本数量:{self.n_samples}\n"
f"- 处理比例:{self.treatment_rate:.2%}\n"
f"- 噪声水平:{self.noise_level}\n"
f"- 处理效应:{self.effect_size}\n"
f"- 协变量数量:{self.n_covariates}\n"
f"- 协变量分布:{self.covariate_distribution}\n"
f"模型Y = β*X + {self.effect_size}*T + ε"
)
class InteractionEffectSimulator(BaseDataSimulator):
"""
交互效应场景模拟器
生成处理效应随协变量变化的数据
模型Y = α + β*X + τ*T + γ*X*T + ε
其中 γ 是交互效应系数
"""
def __init__(
self,
n_samples: int = 1000,
treatment_rate: float = 0.5,
noise_level: float = 1.0,
effect_size: float = 1.0,
interaction_strength: float = 0.5,
n_covariates: int = 2,
random_state: Optional[int] = None
):
"""
初始化交互效应场景模拟器
参数:
n_samples: 样本数量
treatment_rate: 处理比例
noise_level: 噪声水平
effect_size: 基础处理效应
interaction_strength: 交互效应强度
n_covariates: 协变量数量
random_state: 随机种子
"""
super().__init__(
n_samples=n_samples,
treatment_rate=treatment_rate,
noise_level=noise_level,
effect_size=effect_size,
random_state=random_state
)
self.interaction_strength = interaction_strength
self.n_covariates = n_covariates
def generate(self) -> pd.DataFrame:
"""
生成交互效应场景的模拟数据
返回:
包含协变量 X处理变量 T 和结果变量 Y DataFrame
"""
# 生成协变量
X = np.random.normal(0, 1, (self.n_samples, self.n_covariates))
# 生成处理变量
T = np.random.binomial(1, self.treatment_rate, self.n_samples)
# 生成噪声
noise = np.random.normal(0, self.noise_level, self.n_samples)
# 生成协变量系数
beta = np.random.uniform(-0.5, 0.5, self.n_covariates)
# 生成交互效应系数
gamma = np.random.uniform(-0.5, 0.5, self.n_covariates)
# 生成结果变量Y = β*X + τ*T + γ*X*T + ε
Y = np.dot(X, beta) + self.effect_size * T
for i in range(self.n_covariates):
Y += gamma[i] * X[:, i] * T
Y += noise
# 创建 DataFrame
df_dict = {}
for i in range(self.n_covariates):
df_dict[f'X{i}'] = X[:, i]
df_dict['T'] = T
df_dict['Y'] = Y
df = pd.DataFrame(df_dict)
return df
def get_description(self) -> str:
"""获取场景描述"""
return (
f"交互效应场景:\n"
f"- 样本数量:{self.n_samples}\n"
f"- 处理比例:{self.treatment_rate:.2%}\n"
f"- 噪声水平:{self.noise_level}\n"
f"- 基础处理效应:{self.effect_size}\n"
f"- 交互效应强度:{self.interaction_strength}\n"
f"- 协变量数量:{self.n_covariates}\n"
f"模型Y = β*X + {self.effect_size}*T + γ*X*T + ε"
)
class CATESimulator(BaseDataSimulator):
"""
条件平均处理效应CATE场景模拟器
生成处理效应依赖于个体特征的数据
模型Y = α + β*X + τ(X)*T + ε
其中 τ(X) 是条件处理效应函数
"""
def __init__(
self,
n_samples: int = 1000,
treatment_rate: float = 0.5,
noise_level: float = 1.0,
effect_size: float = 1.0,
n_features: int = 5,
cate_function: str = 'linear',
random_state: Optional[int] = None
):
"""
初始化 CATE 场景模拟器
参数:
n_samples: 样本数量
treatment_rate: 处理比例
noise_level: 噪声水平
effect_size: 平均处理效应基准
n_features: 特征数量
cate_function: CATE 函数类型 ('linear', 'nonlinear', 'threshold')
random_state: 随机种子
"""
super().__init__(
n_samples=n_samples,
treatment_rate=treatment_rate,
noise_level=noise_level,
effect_size=effect_size,
random_state=random_state
)
self.n_features = n_features
self.cate_function = cate_function
def _compute_cate(self, X: np.ndarray) -> np.ndarray:
"""
计算条件处理效应
参数:
X: 特征矩阵形状为 (n_samples, n_features)
返回:
条件处理效应数组
"""
if self.cate_function == 'linear':
# 线性 CATE: τ(X) = τ0 + β*X
beta_cate = np.random.uniform(-0.5, 0.5, self.n_features)
cate = self.effect_size + np.dot(X, beta_cate)
elif self.cate_function == 'nonlinear':
# 非线性 CATE: τ(X) = τ0 + sin(β*X)
beta_cate = np.random.uniform(-1, 1, self.n_features)
linear_comb = np.dot(X, beta_cate)
cate = self.effect_size + np.sin(linear_comb)
elif self.cate_function == 'threshold':
# 阈值 CATE: τ(X) = τ0 + β*X * I(X > threshold)
beta_cate = np.random.uniform(-0.5, 0.5, self.n_features)
threshold = np.random.uniform(-1, 1)
linear_comb = np.dot(X, beta_cate)
cate = self.effect_size + linear_comb * (linear_comb > threshold).astype(float)
else:
raise ValueError(f"未知的 CATE 函数类型:{self.cate_function}")
return cate
def generate(self) -> pd.DataFrame:
"""
生成 CATE 场景的模拟数据
返回:
包含特征 X处理变量 T 和结果变量 Y DataFrame
"""
# 生成特征
X = np.random.normal(0, 1, (self.n_samples, self.n_features))
# 生成处理变量
T = np.random.binomial(1, self.treatment_rate, self.n_samples)
# 生成噪声
noise = np.random.normal(0, self.noise_level, self.n_samples)
# 计算 CATE
cate = self._compute_cate(X)
# 生成协变量系数(用于结果模型)
beta_y = np.random.uniform(-0.5, 0.5, self.n_features)
# 生成结果变量Y = β*X + τ(X)*T + ε
Y = np.dot(X, beta_y) + cate * T + noise
# 创建 DataFrame
df_dict = {}
for i in range(self.n_features):
df_dict[f'X{i}'] = X[:, i]
df_dict['T'] = T
df_dict['Y'] = Y
        # Add the CATE column (for validation)
df_dict['CATE'] = cate
df = pd.DataFrame(df_dict)
return df
def get_description(self) -> str:
"""获取场景描述"""
return (
f"CATE 场景:\n"
f"- 样本数量:{self.n_samples}\n"
f"- 处理比例:{self.treatment_rate:.2%}\n"
f"- 噪声水平:{self.noise_level}\n"
f"- 平均处理效应基准:{self.effect_size}\n"
f"- 特征数量:{self.n_features}\n"
f"- CATE 函数类型:{self.cate_function}\n"
f"模型Y = β*X + τ(X)*T + ε"
)
class DataSimulatorFactory:
"""
数据模拟器工厂类
用于创建不同类型的模拟器实例
"""
_simulators = {
'simple_ate': SimpleATESimulator,
'covariate': CovariateSimulator,
'interaction': InteractionEffectSimulator,
'cate': CATESimulator
}
@classmethod
def create(cls, simulator_type: str, **kwargs) -> BaseDataSimulator:
"""
创建指定类型的模拟器
参数:
simulator_type: 模拟器类型 ('simple_ate', 'covariate', 'interaction', 'cate')
**kwargs: 模拟器初始化参数
返回:
模拟器实例
Raises:
ValueError: 未知的模拟器类型
"""
if simulator_type not in cls._simulators:
available_types = ', '.join(cls._simulators.keys())
raise ValueError(
f"未知的模拟器类型:{simulator_type}"
f"可用类型:{available_types}"
)
return cls._simulators[simulator_type](**kwargs)
@classmethod
def list_available_types(cls) -> list:
"""
获取所有可用的模拟器类型
返回:
模拟器类型列表
"""
return list(cls._simulators.keys())
# Convenience functions
def generate_simple_ate_data(
n_samples: int = 1000,
treatment_rate: float = 0.5,
noise_level: float = 1.0,
effect_size: float = 1.0,
intercept: float = 0.0,
random_state: Optional[int] = None
) -> pd.DataFrame:
"""
生成简单 ATE 场景的模拟数据便捷函数
参数:
n_samples: 样本数量
treatment_rate: 处理比例
noise_level: 噪声水平
effect_size: 平均处理效应
intercept: 截距项
random_state: 随机种子
返回:
包含 'T' 'Y' DataFrame
"""
simulator = SimpleATESimulator(
n_samples=n_samples,
treatment_rate=treatment_rate,
noise_level=noise_level,
effect_size=effect_size,
intercept=intercept,
random_state=random_state
)
return simulator.generate()
def generate_covariate_data(
n_samples: int = 1000,
treatment_rate: float = 0.5,
noise_level: float = 1.0,
effect_size: float = 1.0,
n_covariates: int = 3,
covariate_distribution: str = 'normal',
random_state: Optional[int] = None
) -> pd.DataFrame:
"""
生成协变量场景的模拟数据便捷函数
参数:
n_samples: 样本数量
treatment_rate: 处理比例
noise_level: 噪声水平
effect_size: 处理效应
n_covariates: 协变量数量
covariate_distribution: 协变量分布类型
random_state: 随机种子
返回:
包含协变量'T' 'Y' DataFrame
"""
simulator = CovariateSimulator(
n_samples=n_samples,
treatment_rate=treatment_rate,
noise_level=noise_level,
effect_size=effect_size,
n_covariates=n_covariates,
covariate_distribution=covariate_distribution,
random_state=random_state
)
return simulator.generate()
def generate_interaction_data(
n_samples: int = 1000,
treatment_rate: float = 0.5,
noise_level: float = 1.0,
effect_size: float = 1.0,
interaction_strength: float = 0.5,
n_covariates: int = 2,
random_state: Optional[int] = None
) -> pd.DataFrame:
"""
生成交互效应场景的模拟数据便捷函数
参数:
n_samples: 样本数量
treatment_rate: 处理比例
noise_level: 噪声水平
effect_size: 基础处理效应
interaction_strength: 交互效应强度
n_covariates: 协变量数量
random_state: 随机种子
返回:
包含协变量'T' 'Y' DataFrame
"""
simulator = InteractionEffectSimulator(
n_samples=n_samples,
treatment_rate=treatment_rate,
noise_level=noise_level,
effect_size=effect_size,
interaction_strength=interaction_strength,
n_covariates=n_covariates,
random_state=random_state
)
return simulator.generate()
def generate_cate_data(
n_samples: int = 1000,
treatment_rate: float = 0.5,
noise_level: float = 1.0,
effect_size: float = 1.0,
n_features: int = 5,
cate_function: str = 'linear',
random_state: Optional[int] = None
) -> pd.DataFrame:
"""
生成 CATE 场景的模拟数据便捷函数
参数:
n_samples: 样本数量
treatment_rate: 处理比例
noise_level: 噪声水平
effect_size: 平均处理效应基准
n_features: 特征数量
cate_function: CATE 函数类型
random_state: 随机种子
返回:
包含特征'T''Y' 'CATE' DataFrame
"""
simulator = CATESimulator(
n_samples=n_samples,
treatment_rate=treatment_rate,
noise_level=noise_level,
effect_size=effect_size,
n_features=n_features,
cate_function=cate_function,
random_state=random_state
)
return simulator.generate()
# Example usage
if __name__ == '__main__':
    print("=" * 60)
    print("Simulated data generation examples")
    print("=" * 60)
    # 1. Simple ATE scenario
    print("\n1. Simple ATE scenario")
    print("-" * 40)
simple_simulator = SimpleATESimulator(
n_samples=500,
treatment_rate=0.5,
noise_level=1.0,
effect_size=2.0,
intercept=5.0,
random_state=42
)
print(simple_simulator.get_description())
simple_df = simple_simulator.generate()
print(f"\n数据形状:{simple_df.shape}")
print(f"前 5 行:\n{simple_df.head()}")
# 2. 协变量场景
print("\n\n2. 协变量场景")
print("-" * 40)
covariate_simulator = CovariateSimulator(
n_samples=500,
treatment_rate=0.4,
noise_level=1.5,
effect_size=1.5,
n_covariates=3,
covariate_distribution='normal',
random_state=42
)
print(covariate_simulator.get_description())
covariate_df = covariate_simulator.generate()
print(f"\n数据形状:{covariate_df.shape}")
print(f"前 5 行:\n{covariate_df.head()}")
# 3. 交互效应场景
print("\n\n3. 交互效应场景")
print("-" * 40)
interaction_simulator = InteractionEffectSimulator(
n_samples=500,
treatment_rate=0.5,
noise_level=1.0,
effect_size=1.0,
interaction_strength=0.5,
n_covariates=2,
random_state=42
)
print(interaction_simulator.get_description())
interaction_df = interaction_simulator.generate()
print(f"\n数据形状:{interaction_df.shape}")
print(f"前 5 行:\n{interaction_df.head()}")
# 4. CATE 场景
print("\n\n4. CATE 场景")
print("-" * 40)
cate_simulator = CATESimulator(
n_samples=500,
treatment_rate=0.5,
noise_level=1.0,
effect_size=1.5,
n_features=4,
cate_function='linear',
random_state=42
)
print(cate_simulator.get_description())
cate_df = cate_simulator.generate()
print(f"\n数据形状:{cate_df.shape}")
print(f"前 5 行:\n{cate_df.head()}")
# 5. 使用工厂创建模拟器
print("\n\n5. 使用工厂创建模拟器")
print("-" * 40)
print(f"可用模拟器类型:{DataSimulatorFactory.list_available_types()}")
factory_simulator = DataSimulatorFactory.create(
'simple_ate',
n_samples=300,
treatment_rate=0.6,
effect_size=2.5,
random_state=42
)
print(factory_simulator.get_description())
    # 6. Saving data
    print("\n\n6. Saving data example")
    print("-" * 40)
    output_dir = Path('data/output')
    output_dir.mkdir(parents=True, exist_ok=True)
    simple_simulator.save_to_csv(output_dir / 'simple_ate_data.csv')
    print(f"Saved: {output_dir / 'simple_ate_data.csv'}")
    simple_simulator.save_to_excel(output_dir / 'simple_ate_data.xlsx')
    print(f"Saved: {output_dir / 'simple_ate_data.xlsx'}")
    # 7. Convenience functions
    print("\n\n7. Convenience functions")
    print("-" * 40)
    simple_df = generate_simple_ate_data(n_samples=200, effect_size=3.0, random_state=123)
    print(f"Data shape from convenience function: {simple_df.shape}")
    print(f"First 5 rows:\n{simple_df.head()}")
    print("\n" + "=" * 60)
    print("Examples complete!")
    print("=" * 60)


@ -0,0 +1,78 @@
"""
医疗数据生成器
用于模拟构建医疗测试数据
- id: 唯一标识符
- treatment: 是否吃药0 1
- health: 病人健康状态0~1 浮点数越高越好
"""
import pandas as pd
import numpy as np
import os
def generate_medical_data(n_samples: int = 500, output_path: str = "examples/medical/data.xlsx") -> str:
    """
    Generate medical test data and save it to an Excel file.
    Args:
        n_samples: number of samples (default 500)
        output_path: output file path
    Returns:
        the path of the generated file
    """
    # Fix the random seed for reproducibility
    np.random.seed(42)
    # Unique IDs
    ids = list(range(1, n_samples + 1))
    # Treatment (0 or 1); assume 40% of patients take the medication
    treatment = np.random.binomial(1, 0.4, n_samples)
    # Health status (float in 0-1)
    # Health is affected by the treatment: treated patients are healthier on average
    # base health + treatment effect + random noise
    base_health = np.random.beta(2, 2, n_samples)  # baseline health distribution
    treatment_effect = treatment * 0.2  # taking the medication adds 0.2 to health
    noise = np.random.normal(0, 0.1, n_samples)  # random noise
    health = np.clip(base_health + treatment_effect + noise, 0, 1)
    # Build the DataFrame
    df = pd.DataFrame({
        'id': ids,
        'treatment': treatment,
        'health': np.round(health, 4)
    })
    # Ensure the output directory exists
    output_dir = os.path.dirname(output_path)
    if output_dir and not os.path.exists(output_dir):
        os.makedirs(output_dir)
    # Save to Excel
    df.to_excel(output_path, index=False)
    print(f"Generated {n_samples} medical records, saved to: {output_path}")
    print("Data preview:")
    print(df.head(10))
    print("\nStatistics:")
    print(f"Patients on medication: {treatment.sum()} ({treatment.mean()*100:.1f}%)")
    print(f"Mean health: {health.mean():.4f}")
    print(f"Mean health (treated): {health[treatment==1].mean():.4f}")
    print(f"Mean health (untreated): {health[treatment==0].mean():.4f}")
    return output_path
if __name__ == "__main__":
    # Generate 500 records to data.xlsx by default
    output_file = generate_medical_data(n_samples=500, output_path="data.xlsx")
    print(f"\nFile created: {output_file}")


@ -0,0 +1,200 @@
"""
数据验证模块
用于检查数据列是否匹配 LLM 识别的处理变量和结果变量
"""
import pandas as pd
from typing import Dict, Any, Tuple, Optional, List
class ValidationResult:
"""验证结果模型"""
def __init__(
self,
is_valid: bool,
treatment: Optional[str] = None,
outcome: Optional[str] = None,
errors: Optional[List[str]] = None,
warnings: Optional[List[str]] = None
):
self.is_valid = is_valid
self.treatment = treatment
self.outcome = outcome
self.errors = errors or []
self.warnings = warnings or []
class DataValidator:
"""数据验证器"""
def __init__(self, data_path: str):
"""
初始化数据验证器
Args:
data_path: 数据文件路径
"""
self.data_path = data_path
self.data = None
def load_data(self) -> pd.DataFrame:
"""加载数据"""
if self.data is None:
self.data = pd.read_excel(self.data_path)
return self.data
def validate_columns(
self,
llm_result: Dict[str, Any]
) -> ValidationResult:
"""
验证数据列是否匹配 LLM 识别的变量
Args:
llm_result: LLM 输出的 JSON 结果包含 treatment outcome 字段
Returns:
验证结果
"""
errors = []
warnings = []
is_valid = True
# 加载数据
try:
self.load_data()
except Exception as e:
return ValidationResult(
is_valid=False,
treatment=None,
outcome=None,
errors=[f"加载数据失败:{str(e)}"],
warnings=[]
)
        # Get the data column names
        assert self.data is not None, "data not loaded"
        available_columns = list(self.data.columns)
        # Get the variables identified by the LLM
        treatment = llm_result.get('treatment')
        outcome = llm_result.get('outcome')
        # Validate the treatment variable
        if treatment is None:
            errors.append("The LLM did not identify a treatment variable")
            is_valid = False
        elif treatment not in available_columns:
            errors.append(f"Treatment variable '{treatment}' does not exist in the data")
            errors.append(f"Available columns: {available_columns}")
            is_valid = False
        else:
            # Check the treatment variable's type
            treatment_col = self.data[treatment]
            if treatment_col.nunique() > 2:
                warnings.append(f"Treatment variable '{treatment}' has more than two unique values and may not be binary")
            elif treatment_col.dtype not in ['int64', 'float64', 'bool', 'object']:
                warnings.append(f"The dtype of treatment variable '{treatment}' may be unsuitable for causal analysis")
# 验证结果变量
if outcome is None:
errors.append("LLM 未识别结果变量 (outcome)")
is_valid = False
elif outcome not in available_columns:
errors.append(f"数据中不存在结果变量 '{outcome}'")
errors.append(f"可用列名:{available_columns}")
is_valid = False
else:
# 检查结果变量类型
outcome_col = self.data[outcome]
if outcome_col.dtype not in ['int64', 'float64']:
warnings.append(f"结果变量 '{outcome}' 的数据类型可能不适合因果分析")
# 检查样本量
if len(self.data) < 10:
warnings.append(f"样本量 ({len(self.data)}) 较小,可能影响分析结果")
# 检查缺失值
missing_treatment = self.data[treatment].isna().sum() if treatment and treatment in available_columns else 0
missing_outcome = self.data[outcome].isna().sum() if outcome and outcome in available_columns else 0
if missing_treatment > 0:
warnings.append(f"处理变量 '{treatment}'{missing_treatment} 个缺失值")
if missing_outcome > 0:
warnings.append(f"结果变量 '{outcome}'{missing_outcome} 个缺失值")
return ValidationResult(
is_valid=is_valid,
treatment=treatment,
outcome=outcome,
errors=errors,
warnings=warnings
)
def validate_and_raise(
self,
llm_result: Dict[str, Any]
) -> Tuple[Optional[str], Optional[str]]:
"""
验证并抛出异常如果验证失败
Args:
llm_result: LLM 输出的 JSON 结果
Returns:
(treatment, outcome) 元组如果验证失败则抛出异常
Raises:
ValueError: 验证失败时抛出
"""
result = self.validate_columns(llm_result)
if not result.is_valid:
error_msg = "数据验证失败:\n"
for error in result.errors:
error_msg += f" - {error}\n"
raise ValueError(error_msg)
return result.treatment, result.outcome
# 便捷函数
def validate_data(data_path: str, llm_result: Dict[str, Any]) -> ValidationResult:
"""
便捷函数验证数据
Args:
data_path: 数据文件路径
llm_result: LLM 输出的 JSON 结果
Returns:
验证结果
"""
validator = DataValidator(data_path)
return validator.validate_columns(llm_result)
if __name__ == "__main__":
# 测试数据验证
print("测试数据验证模块")
print("=" * 50)
# 测试 1: 有效数据
print("\n测试 1: 有效数据")
validator = DataValidator("examples/medical/data.xlsx")
llm_result = {"treatment": "treatment", "outcome": "health"}
result = validator.validate_columns(llm_result)
print(f" 是否有效:{result.is_valid}")
print(f" 处理变量:{result.treatment}")
print(f" 结果变量:{result.outcome}")
print(f" 错误:{result.errors}")
print(f" 警告:{result.warnings}")
# 测试 2: 无效数据(变量名不存在)
print("\n测试 2: 无效数据(变量名不存在)")
llm_result_invalid = {"treatment": "invalid_col", "outcome": "health"}
result_invalid = validator.validate_columns(llm_result_invalid)
print(f" 是否有效:{result_invalid.is_valid}")
print(f" 错误:{result_invalid.errors}")

View File

@ -0,0 +1,166 @@
"""
LLM 调用模块
用于调用外部 LLM API 进行文本生成
"""
import os
import json
from typing import Dict, Any, Optional, List
import requests
class LLMClient:
"""LLM API 客户端"""
def __init__(
self,
base_url: Optional[str] = None,
model: Optional[str] = None,
temperature: Optional[float] = None,
max_tokens: Optional[int] = None,
api_key: Optional[str] = None
):
"""
初始化 LLM 客户端
Args:
base_url: API 基础 URL默认从环境变量 LLM_BASE_URL 读取
model: 模型名称默认从环境变量 LLM_MODEL 读取
temperature: 温度参数默认从环境变量 LLM_TEMPERATURE 读取
max_tokens: 最大生成长度默认从环境变量 LLM_MAX_TOKENS 读取
api_key: API 密钥可选从环境变量 LLM_API_KEY 读取
"""
import os
self.base_url = base_url or os.getenv("LLM_BASE_URL", "https://glm47flash.cloyir.com/v1")
self.model = model or os.getenv("LLM_MODEL", "qwen3.5-35b")
self.temperature = temperature if temperature is not None else float(os.getenv("LLM_TEMPERATURE", "0.3"))
self.max_tokens = max_tokens if max_tokens is not None else int(os.getenv("LLM_MAX_TOKENS", "2048"))
self.api_key = api_key or os.getenv("LLM_API_KEY", "")
self._headers = {
"Content-Type": "application/json"
}
if self.api_key:
self._headers["Authorization"] = f"Bearer {self.api_key}"
def chat_completion(
self,
messages: List[Dict[str, str]],
temperature: Optional[float] = None,
max_tokens: Optional[int] = None
) -> Dict[str, Any]:
"""
发送聊天请求并获取响应
Args:
messages: 消息列表格式为 [{"role": "system/user/assistant", "content": "..."}]
temperature: 温度参数覆盖默认值
max_tokens: 最大生成长度覆盖默认值
Returns:
API 响应字典
"""
url = f"{self.base_url}/chat/completions"
payload = {
"model": self.model,
"messages": messages,
"temperature": temperature or self.temperature,
"max_tokens": max_tokens or self.max_tokens,
"stream": False,
"extra_body": {
"chat_template_kwargs": {
"enable_thinking": False
}
}
}
try:
response = requests.post(
url,
headers=self._headers,
json=payload,
timeout=120
)
response.raise_for_status()
return response.json()
except requests.exceptions.RequestException as e:
return {
"error": str(e),
"status_code": getattr(e.response, 'status_code', None) if hasattr(e, 'response') else None
}
def generate_response(
self,
system_prompt: str,
user_prompt: str,
temperature: Optional[float] = None,
max_tokens: Optional[int] = None
) -> Dict[str, Any]:
"""
生成响应简化接口
Args:
system_prompt: 系统提示词
user_prompt: 用户提示词
temperature: 温度参数
max_tokens: 最大生成长度
Returns:
包含响应内容的字典
"""
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt}
]
result = self.chat_completion(messages, temperature, max_tokens)
if "error" in result:
return {
"success": False,
"error": result["error"],
"content": None
}
if "choices" in result and len(result["choices"]) > 0:
choice = result["choices"][0]
message = choice.get("message", {})
# 尝试从不同位置获取 content
content = message.get("content")
# 如果 content 是 None检查是否有 reasoning 或其他字段
if content is None:
# 有些模型可能将内容放在 reasoning 字段中
reasoning = message.get("reasoning", "")
if reasoning:
content = reasoning
# 如果仍然没有内容,返回错误
if content is None:
return {
"success": False,
"error": "No content found in response",
"content": None
}
return {
"success": True,
"error": None,
"content": content,
"usage": result.get("usage", {}),
"model": result.get("model", self.model)
}
return {
"success": False,
"error": f"Unexpected response format: {result}",
"content": None
}
# 全局 LLM 客户端实例(可通过环境变量配置)
def get_llm_client() -> LLMClient:
"""获取 LLM 客户端实例(配置优先从环境变量读取)"""
return LLMClient() # 所有默认值已从环境变量读取
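

# A minimal usage sketch (assumes an OpenAI-compatible endpoint is reachable at
# LLM_BASE_URL; the prompts below are illustrative only):
if __name__ == "__main__":
    client = get_llm_client()
    reply = client.generate_response(
        system_prompt="You are a helpful assistant.",
        user_prompt="Say hello in one word."
    )
    print(reply["content"] if reply["success"] else reply["error"])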

200
examples/medical/log.md Normal file
View File

@ -0,0 +1,200 @@
# Causal Analysis Log
## About this log
This document records every input parameter and output produced when the LLM performs causal analysis.
## Analysis records
---
### Analysis #001
**Time**: 2026-03-26T23:19:00.211723
#### System prompt
```
You are a professional causal inference analyst. Your task is to analyze the given data, identify the causal variables, and assess the causal relationship.
Output the analysis result as JSON, without any extra explanation or reasoning process.
JSON output specification:
{
"treatment": "name of the treatment variable",
"outcome": "name of the outcome variable"
}
Notes:
- Output only the JSON format above; do not include any other fields
- The treatment and outcome variable names must exactly match the column names of the data table
- Do not use markdown code fences (such as ```json)
- Output a raw JSON string directly
```
#### User prompt
```
Please analyze the following medical data and output the result strictly as JSON:
**Data overview:**
- Sample size: 500
- Variables: id, treatment, health
**Statistical summary:**
id treatment health
count 500.000000 500.000000 500.000000
mean 250.500000 0.414000 0.586082
std 144.481833 0.493042 0.236948
min 1.000000 0.000000 0.000000
25% 125.750000 0.000000 0.421975
50% 250.500000 0.000000 0.583000
75% 375.250000 1.000000 0.769625
max 500.000000 1.000000 1.000000
**Treatment distribution:**
treatment
0 293
1 207
Required JSON output format:
{
"treatment": "name of the treatment variable",
"outcome": "name of the outcome variable"
}
Requirements:
1. The treatment and outcome variable names must exactly match the table's column names
2. Output only the JSON, nothing else
3. Do not use markdown code fences
```
#### LLM output
```
{
"treatment": "treatment",
"outcome": "health"
}
```
#### Call parameters
```json
{
  "data_path": "examples/medical/data.xlsx",
  "sample_size": 500,
  "variables": [
    "id",
    "treatment",
    "health"
  ],
  "treatment_variable": "treatment",
  "outcome_variable": "health",
  "llm_params": {
    "base_url": "http://10.106.123.247:8000/v1",
    "model": "qwen3.5-35b",
    "temperature": 0.3,
    "max_tokens": 2048
  },
  "log_path": "examples/medical/log.md"
}
```
---
---
### Analysis #002
**Time**: 2026-03-26T23:19:53.065927
#### System prompt
```
You are a professional causal inference analyst. Your task is to analyze the given data, identify the causal variables, and assess the causal relationship.
Output the analysis result as JSON, without any extra explanation or reasoning process.
JSON output specification:
{
"treatment": "name of the treatment variable",
"outcome": "name of the outcome variable"
}
Notes:
- Output only the JSON format above; do not include any other fields
- The treatment and outcome variable names must exactly match the column names of the data table
- Do not use markdown code fences (such as ```json)
- Output a raw JSON string directly
```
#### User prompt
```
Please analyze the following medical data and output the result strictly as JSON:
**Data overview:**
- Sample size: 500
- Variables: id, treatment, health
**Statistical summary:**
id treatment health
count 500.000000 500.000000 500.000000
mean 250.500000 0.414000 0.586082
std 144.481833 0.493042 0.236948
min 1.000000 0.000000 0.000000
25% 125.750000 0.000000 0.421975
50% 250.500000 0.000000 0.583000
75% 375.250000 1.000000 0.769625
max 500.000000 1.000000 1.000000
**Treatment distribution:**
treatment
0 293
1 207
Required JSON output format:
{
"treatment": "name of the treatment variable",
"outcome": "name of the outcome variable"
}
Requirements:
1. The treatment and outcome variable names must exactly match the table's column names
2. Output only the JSON, nothing else
3. Do not use markdown code fences
```
#### LLM output
```
{
"treatment": "treatment",
"outcome": "health"
}
```
#### Call parameters
```json
{
  "data_path": "examples/medical/data.xlsx",
  "sample_size": 500,
  "variables": [
    "id",
    "treatment",
    "health"
  ],
  "treatment_variable": "treatment",
  "outcome_variable": "health",
  "llm_params": {
    "base_url": "http://10.106.123.247:8000/v1",
    "model": "qwen3.5-35b",
    "temperature": 0.3,
    "max_tokens": 2048
  },
  "log_path": "examples/medical/log.md"
}
```
---

356
examples/medical/start.py Normal file
View File

@ -0,0 +1,356 @@
"""
LLM 因果分析测试入口
使用自然语言输入让 LLM 测试分析因果变量并将所有参数记录到 log.md 文档中
"""
import os
import json
from datetime import datetime
from typing import Dict, Any, Optional
import pandas as pd
from llm_client import LLMClient
class CausalAnalysisLogger:
"""因果分析日志记录器"""
def __init__(self, log_path: str = "examples/medical/log.md"):
self.log_path = log_path
self.log_entries = []
self._init_log_file()
self._load_existing_entries()
def _init_log_file(self):
"""初始化日志文件"""
header = """# 因果分析日志
## 日志说明
本文档记录 LLM 进行因果分析时的所有输入参数和输出结果
## 分析记录
"""
if not os.path.exists(self.log_path):
with open(self.log_path, 'w', encoding='utf-8') as f:
f.write(header)
def _load_existing_entries(self):
"""从日志文件中加载已存在的分析记录,以确定下一个分析 ID"""
if not os.path.exists(self.log_path):
return
analysis_count = 0
with open(self.log_path, 'r', encoding='utf-8') as f:
content = f.read()
# 统计已存在的分析记录数量
import re
matches = re.findall(r'### 分析 #(\d+)', content)
if matches:
analysis_count = max(int(m) for m in matches)
self.analysis_count = analysis_count
def _append_to_log(self, content: str):
"""追加内容到日志文件"""
with open(self.log_path, 'a', encoding='utf-8') as f:
f.write(content)
def log_analysis(self, system_prompt: str, user_prompt: str,
model_output: str, parameters: Dict[str, Any]):
"""记录一次完整的分析过程"""
# 增加分析计数
self.analysis_count += 1
analysis_id = f"{self.analysis_count:03d}"
timestamp = datetime.now().isoformat()
entry = f"""
---
### 分析 #{analysis_id}
**时间**: {timestamp}
#### 系统提示词
```
{system_prompt}
```
#### 用户提示词
```
{user_prompt}
```
#### LLM 输出
```
{model_output}
```
#### 调用参数
```json
{json.dumps(parameters, indent=2, ensure_ascii=False)}
```
---
"""
self._append_to_log(entry)
self.log_entries.append({
'id': analysis_id,
'timestamp': timestamp,
'system_prompt': system_prompt,
'user_prompt': user_prompt,
'model_output': model_output,
'parameters': parameters
})
print(f"分析 #{analysis_id} 已记录到 {self.log_path}")
class CausalAnalysisAgent:
"""因果分析代理"""
def __init__(self, data_path: str = "examples/medical/data.xlsx", log_path: str = "examples/medical/log.md"):
self.data_path = data_path
self.logger = CausalAnalysisLogger(log_path)
self.data = None
def load_data(self) -> pd.DataFrame:
"""加载数据"""
if not os.path.exists(self.data_path):
raise FileNotFoundError(f"数据文件不存在:{self.data_path}")
self.data = pd.read_excel(self.data_path)
print(f"成功加载数据:{self.data_path}")
print(f"数据形状:{self.data.shape}")
print(f"列名:{list(self.data.columns)}")
print(f"\n数据预览:")
print(self.data.head())
return self.data
def _generate_system_prompt(self) -> str:
"""生成系统提示词"""
return """你是一位专业的因果推断分析师。你的任务是分析给定的数据,识别因果变量并评估因果关系。
请以 JSON 格式输出分析结果不要包含任何额外的解释或思考过程
JSON 输出规范
{
"treatment": "处理变量名称",
"outcome": "结果变量名称"
}
注意
- 只输出上述 JSON 格式不要包含其他字段
- 处理变量和结果变量名称必须与数据表格的列名完全一致
- 不要使用 markdown 代码块标记 ```json
- 直接输出纯 JSON 字符串"""
def _generate_user_prompt(self, data_info: str) -> str:
"""生成用户提示词"""
return f"""请分析以下医疗数据,并严格按照 JSON 格式输出分析结果:
{data_info}
JSON 输出格式要求
{{
"treatment": "处理变量名称",
"outcome": "结果变量名称"
}}
要求
1. 处理变量和结果变量名称必须与表格列名完全一致
2. 只输出 JSON不要包含其他任何内容
3. 不要使用 markdown 代码块标记"""
def _call_llm(self, system_prompt: str, user_prompt: str) -> Dict[str, Any]:
"""
调用 LLM API 获取响应
Args:
system_prompt: 系统提示词
user_prompt: 用户提示词
Returns:
包含响应内容的字典
"""
# 初始化 LLM 客户端(从环境变量读取配置)
import os
llm_client = LLMClient(
base_url=os.getenv("LLM_BASE_URL", "https://glm47flash.cloyir.com/v1"),
model=os.getenv("LLM_MODEL", "qwen3.5-35b"),
temperature=float(os.getenv("LLM_TEMPERATURE", "0.3")),
max_tokens=int(os.getenv("LLM_MAX_TOKENS", "2048"))
)
# 调用 LLM
result = llm_client.generate_response(system_prompt, user_prompt)
return result
def analyze(self, custom_prompt: Optional[str] = None) -> Dict[str, Any]:
"""
执行因果分析
Args:
custom_prompt: 自定义用户提示词如果为 None 则使用默认提示词
Returns:
分析结果字典
"""
# 加载数据
if self.data is None:
self.load_data()
# 类型断言:确保 self.data 不是 None
assert self.data is not None, "数据未加载"
# 生成数据信息
data_info = f"""
**数据概览:**
- 样本数量{len(self.data)}
- 变量{', '.join(self.data.columns)}
**统计摘要:**
{self.data.describe()}
**处理变量分布:**
{self.data['treatment'].value_counts().to_string()}
"""
# 生成提示词
system_prompt = self._generate_system_prompt()
user_prompt = custom_prompt if custom_prompt else self._generate_user_prompt(data_info)
# 计算统计值
treatment_mean = self.data[self.data['treatment'] == 1]['health'].mean() if self.data['treatment'].sum() > 0 else 0
control_mean = self.data[self.data['treatment'] == 0]['health'].mean() if self.data['treatment'].sum() < len(self.data) else 0
ate = treatment_mean - control_mean if self.data['treatment'].sum() > 0 and self.data['treatment'].sum() < len(self.data) else 0
# 调用 LLM 获取真实响应
llm_result = self._call_llm(system_prompt, user_prompt)
# 打印 LLM 调用结果用于调试
print(f"\nLLM 调用结果:{llm_result}")
if llm_result.get("success"):
model_output = llm_result.get("content", "LLM 返回内容为空")
else:
# 如果 LLM 调用失败,使用模拟响应
model_output = f"""## 因果分析结果LLM 调用失败,使用模拟响应)
### 1. 变量识别
- **处理变量 (Treatment)**: treatment (是否吃药)
- 类型二元变量 (0/1)
- 含义1 表示吃药0 表示不吃药
- **结果变量 (Outcome)**: health (健康状态)
- 类型连续变量 (0-1 浮点数)
- 含义健康状态评分越高越好
### 2. 因果效应估计
根据数据描述性统计
- 吃药组 (treatment=1) 的平均健康状态{treatment_mean:.4f}
- 未吃药组 (treatment=0) 的平均健康状态{control_mean:.4f}
- **平均处理效应 (ATE)**: {ate:.4f}
### 3. 分析方法
使用简单的组间比较方法估计因果效应
- 方法均值差异 (Mean Difference)
- 假设无混杂因素或混杂因素已控制
### 4. 结论
吃药对健康状态有正向的因果效应估计效应大小为 {ate:.4f}
这意味着吃药可以使健康状态平均提高 {ate * 100:.1f}%
### 5. 建议
- 建议进行更严格的因果推断分析如倾向得分匹配工具变量法等
- 考虑控制可能的混杂因素如年龄基础健康状况等
---
**LLM 调用错误**: {llm_result.get('error', 'Unknown error')}"""
# 准备参数
parameters = {
'data_path': self.data_path,
'sample_size': len(self.data),
'variables': list(self.data.columns),
'treatment_variable': 'treatment',
'outcome_variable': 'health'
}
# 记录到日志(包含 LLM 调用信息)
import os
llm_params = {
'base_url': os.getenv("LLM_BASE_URL", "https://glm47flash.cloyir.com/v1"),
'model': os.getenv("LLM_MODEL", "qwen3.5-35b"),
'temperature': float(os.getenv("LLM_TEMPERATURE", "0.3")),
'max_tokens': int(os.getenv("LLM_MAX_TOKENS", "2048"))
}
parameters['llm_params'] = llm_params
parameters['data_path'] = self.data_path
parameters['log_path'] = self.logger.log_path
self.logger.log_analysis(system_prompt, user_prompt, model_output, parameters)
# 打印结果
print(f"\n{'='*60}")
print(f"分析 #{self.logger.analysis_count:03d} 结果:")
print(f"{'='*60}")
print(model_output)
return {
'id': f"{self.logger.analysis_count:03d}",
'system_prompt': system_prompt,
'user_prompt': user_prompt,
'model_output': model_output,
'parameters': parameters
}
def interactive_analyze(self):
"""交互式分析模式"""
print("=" * 60)
print("LLM 因果分析测试系统")
print("=" * 60)
# 加载数据
self.load_data()
while True:
print("\n" + "-" * 40)
print("请输入分析提示词(输入 'quit' 退出,'default' 使用默认):")
user_input = input("> ").strip()
if user_input.lower() == 'quit':
print("感谢使用,再见!")
break
elif user_input.lower() == 'default':
user_input = None
try:
result = self.analyze(custom_prompt=user_input)
print(f"\n分析已记录到:{self.logger.log_path}")
except Exception as e:
print(f"分析出错:{e}")
def main():
"""主函数"""
# 检查数据文件是否存在,不存在则生成
data_path = "examples/medical/data.xlsx"
if not os.path.exists(data_path):
print("数据文件不存在,正在生成...")
from data_generator import generate_medical_data
generate_medical_data(n_samples=500, output_path=data_path)
# 创建分析代理并执行分析
agent = CausalAnalysisAgent(data_path=data_path, log_path="examples/medical/log.md")
# 执行分析
result = agent.analyze()
print("\n" + "=" * 60)
print("测试完成!")
print(f"日志已保存到:{agent.logger.log_path}")
print("=" * 60)
if __name__ == "__main__":
main()

View File

@ -0,0 +1,174 @@
"""
医疗测试数据生成器
生成符合因果推断测试要求的医疗数据包含以下特征
- id: 唯一标识符
- treatment: 是否吃药0 1吃药比例 2/3
- health: 病人的健康状态0~1 浮点数越高越好
- base_health: 病人吃药后的健康状态如果没吃药就和 health 栏一致
- age: 年龄18~70 年龄越大健康状态越低药效越显著
数据生成逻辑
- 基础健康状态服从 Beta 分布
- 年龄对健康有负面影响年龄越大健康越低
- 年龄对药效有调节作用年龄越大吃药带来的健康提升越显著
"""
import numpy as np
import pandas as pd
from typing import Tuple
def generate_medical_data(
n_samples: int = 500,
treatment_ratio: float = 2/3,
seed: int = 42
) -> pd.DataFrame:
"""
生成医疗测试数据
参数:
n_samples: 样本数量默认 500
treatment_ratio: 吃药比例默认 2/3
seed: 随机种子用于可重复性
返回:
包含生成数据的 DataFrame
"""
# 设置随机种子以确保可重复性
np.random.seed(seed)
# 1. 生成唯一标识符
ids = list(range(1, n_samples + 1))
# 2. 生成年龄18~70 岁,均匀分布)
ages = np.random.randint(18, 71, size=n_samples)
# 3. 生成基础健康状态(服从 Beta 分布)
# Beta(2, 2) 分布在 [0,1] 区间,均值 0.5,适合模拟健康状态
base_health_raw = np.random.beta(2, 2, size=n_samples)
# 4. 计算年龄对健康的影响
# 将年龄归一化到 [0,1],然后乘以一个系数
age_normalized = (ages - 18) / (70 - 18) # 归一化到 [0,1]
age_health_effect = 0.3 * age_normalized # 年龄对健康的影响系数
# 5. 计算基础健康状态(考虑年龄影响)
base_health = base_health_raw * (1 - age_health_effect)
# 确保在 [0,1] 范围内
base_health = np.clip(base_health, 0, 1)
# 6. 生成 treatment是否吃药
# 治疗分配受年龄和基线健康影响,产生混杂:
# 年龄越大、基线健康越差的患者,越可能接受治疗
logit_p = 1.5 * age_normalized - 2.0 * base_health + 0.8
prob_treatment = 1 / (1 + np.exp(-logit_p))
treatments = (np.random.uniform(0, 1, size=n_samples) < prob_treatment).astype(int)
# 7. 计算药效(年龄越大,药效越显著)
# 药效系数0.1(年轻)到 0.4(年老)
treatment_effect = 0.1 + 0.3 * age_normalized
# 8. 计算最终健康状态
# 如果吃药health = base_health + treatment_effect
# 如果没吃药health = base_health
health = base_health.copy()
health[treatments == 1] = np.clip(
base_health[treatments == 1] + treatment_effect[treatments == 1],
0, 1
)
# 9. 计算 base_health 列
# 如果没吃药base_health = health
# 如果吃药base_health = health - treatment_effect即不吃药时的健康状态
base_health_final = health.copy()
base_health_final[treatments == 1] = np.clip(
health[treatments == 1] - treatment_effect[treatments == 1],
0, 1
)
# 10. 创建 DataFrame
df = pd.DataFrame({
'id': ids,
'treatment': treatments,
'health': np.round(health, 4),
'base_health': np.round(base_health_final, 4),
'age': ages
})
return df
def print_data_preview(df: pd.DataFrame) -> None:
"""打印数据预览"""
print("=" * 60)
print("数据预览(前 10 行)")
print("=" * 60)
print(df.head(10).to_string(index=False))
print()
def print_statistics(df: pd.DataFrame) -> None:
"""打印统计信息"""
print("=" * 60)
print("统计信息")
print("=" * 60)
# 总体统计
print("\n【总体统计】")
print(df.describe().round(4))
# Treatment 分布
print("\n【Treatment 分布】")
treatment_dist = df['treatment'].value_counts()
print(f"不吃药 (0): {treatment_dist.get(0, 0)} 人 ({100*treatment_dist.get(0, 0)/len(df):.1f}%)")
print(f"吃药 (1): {treatment_dist.get(1, 0)} 人 ({100*treatment_dist.get(1, 0)/len(df):.1f}%)")
# 按 treatment 分组的统计
print("\n【按 Treatment 分组的健康状态统计】")
grouped = df.groupby('treatment')['health'].agg(['mean', 'std', 'min', 'max'])
print(grouped.round(4))
# 年龄与药效的关系
print("\n【年龄与药效关系分析】")
treated_df = df[df['treatment'] == 1].copy()
treated_df['age_normalized'] = (treated_df['age'] - 18) / (70 - 18)
treated_df['health_gain'] = treated_df['health'] - treated_df['base_health']
# 按年龄四分位数分组
treated_df['age_quartile'] = pd.qcut(treated_df['age'], q=4, labels=['Q1( youngest)', 'Q2', 'Q3', 'Q4(oldest)'])
quartile_analysis = treated_df.groupby('age_quartile')['health_gain'].agg(['mean', 'std', 'count'])
print(quartile_analysis.round(4))
# 相关性分析
print("\n【相关性分析】")
print(f"年龄与健康的相关系数:{df['age'].corr(df['health']):.4f}")
print(f"年龄与药效的相关系数:{treated_df['age_normalized'].corr(treated_df['health_gain']):.4f}")
def main():
    """Main entry point."""
    print("Medical test-data generator")
    print("-" * 40)
    # Generate the data
    print("Generating data...")
    df = generate_medical_data(n_samples=500, treatment_ratio=2/3, seed=42)
    # Preview
    print_data_preview(df)
    # Statistics
    print_statistics(df)
    # Save to Excel (absolute path)
    import os
    output_file = os.path.join(os.path.dirname(__file__), "data.xlsx")
    df.to_excel(output_file, index=False)
    print(f"\nData saved to: {output_file}")
    return df


if __name__ == "__main__":
    main()

1196
examples/medical_v2/log.md Normal file

File diff suppressed because it is too large

View File

@ -0,0 +1,46 @@
"""
医疗数据因果推断分析示例v2
本脚本演示如何使用 causal_agent 主包对医疗数据进行完整的因果推断分析
"""
import os
import sys
# 将项目根目录加入路径
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, project_root)
from causal_agent.agent import CausalInferenceAgent
from causal_agent.core.config import AgentConfig
def main():
data_path = "examples/medical_v2/data.xlsx"
# 如果数据不存在,自动生成示例数据
if not os.path.exists(data_path):
print("数据文件不存在,正在生成示例医疗数据...")
from examples.medical_v2.data_generator import generate_medical_data
df = generate_medical_data(n_samples=500, treatment_ratio=2 / 3, seed=42)
df.to_excel(data_path, index=False)
print(f"数据已保存到:{data_path}")
# 配置 Agent可选自定义 LLM 参数、日志路径等)
config = AgentConfig.from_env()
config.log_path = "examples/medical_v2/log.md"
# 创建 Agent 并执行分析
agent = CausalInferenceAgent(config)
result = agent.analyze(data_path)
print("\n" + "=" * 60)
print("示例运行完成!")
print(f"分析 ID{result['id']}")
print(f"日志已保存到:{agent.logger.log_path}")
print("=" * 60)
if __name__ == "__main__":
main()

View File

@ -0,0 +1,23 @@
"""
输入处理模块 (Input Processor Module)
该模块提供文件查询和因果推断解析功能用于处理用户自然语言问题并识别因果推断要素
主要组件:
- FileQueryTool: 文件查询工具支持 CSV Excel 文件
- CausalParser: 因果推断解析器使用 LLM 解析用户问题
"""
from .file_query_tool import FileQueryTool, FileMetadata, ColumnInfo, SampleData
from .causal_parser import CausalParser, CausalInferenceRequest, CausalInferenceResult
__version__ = "1.0.0"
__all__ = [
"FileQueryTool",
"FileMetadata",
"ColumnInfo",
"SampleData",
"CausalParser",
"CausalInferenceRequest",
"CausalInferenceResult",
]
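
# Usage sketch (illustrative; the CSV path is a placeholder):
#
#   from input_processor import CausalParser
#   parser = CausalParser()
#   result = parser.parse_simple("What is the effect of education on income?", "data.csv")
#   print(result.to_json())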

View File

@ -0,0 +1,723 @@
"""
因果推断解析器模块 (Causal Parser)
该模块提供因果推断解析功能使用启发式规则解析用户的自然语言问题
识别因果推断中的关键要素处理变量 T结果变量 Y协变量
"""
import json
from dataclasses import dataclass, asdict, field
from typing import List, Optional, Dict, Any, Union
from datetime import datetime
# 支持直接运行和作为模块导入两种方式
if __name__ == "__main__":
import sys
import os
sys.path.insert(0, os.path.dirname(__file__))
from file_query_tool import FileQueryTool, FileMetadata
else:
from .file_query_tool import FileQueryTool, FileMetadata
@dataclass
class CausalInferenceRequest:
"""因果推断请求数据类"""
user_question: str # 用户的自然语言问题
file_path: str # 数据文件路径
read_rows: int = 100 # 读取的行数
def to_dict(self) -> Dict[str, Any]:
"""转换为字典"""
return asdict(self)
@dataclass
class VariableInfo:
"""变量信息"""
name: str # 变量名
variable_type: str # 变量类型treatment/outcome/control
confidence: float # 置信度0-1
reasoning: str # 推理理由
data_type: Optional[str] = None # 数据类型
is_binary: Optional[bool] = None # 是否为二值变量
def to_dict(self) -> Dict[str, Any]:
"""转换为字典"""
return asdict(self)
@dataclass
class CausalInferenceResult:
"""因果推断结果数据类"""
original_question: str # 用户原始问题
file_path: str # 文件路径
treatment_variable: Optional[VariableInfo] = None # 处理变量 T
outcome_variable: Optional[VariableInfo] = None # 结果变量 Y
control_variables: List[VariableInfo] = field(default_factory=list) # 协变量列表
confidence_score: float = 0.0 # 整体置信度评分
reasoning: str = "" # 推理理由
alternative_matches: List[Dict[str, Any]] = field(default_factory=list) # 其他可能的匹配
metadata: Dict[str, Any] = field(default_factory=dict) # 元数据
def to_dict(self) -> Dict[str, Any]:
"""转换为字典"""
result = asdict(self)
if self.treatment_variable:
result["treatment_variable"] = self.treatment_variable.to_dict()
if self.outcome_variable:
result["outcome_variable"] = self.outcome_variable.to_dict()
result["control_variables"] = [v.to_dict() for v in self.control_variables]
return result
def to_json(self, indent: int = 2) -> str:
"""转换为 JSON 字符串"""
return json.dumps(self.to_dict(), ensure_ascii=False, indent=indent)
class CausalParser:
"""
因果推断解析器
使用 LLM 解析用户的自然语言问题识别因果推断中的关键要素
Attributes:
file_query_tool: 文件查询工具实例
llm_client: LLM 客户端可选
default_read_rows: 默认读取的行数
Example:
>>> parser = CausalParser()
>>> request = CausalInferenceRequest(
... user_question="教育对收入的影响是什么?",
... file_path="data/education_data.csv"
... )
>>> result = parser.parse(request)
>>> print(result.to_json())
"""
# 内置的提示词模板
DEFAULT_SYSTEM_PROMPT = """你是一位因果推断专家,负责从用户的自然语言问题中识别因果推断的关键要素。
任务
1. 分析用户的问题理解其因果推断意图
2. 根据提供的数据表结构识别以下变量
- 处理变量 (T/Treatment): 被操纵或观察的干预措施
- 结果变量 (Y/Outcome): 受处理影响的指标
- 协变量 (Control): 需要控制的混淆因素
输出要求
- 只输出 JSON 格式不要包含其他内容
- 对于每个变量提供名称类型置信度和推理理由
- 如果无法确定某个变量将其设为 null 或空列表
- 置信度评分范围0.0-1.0
变量类型说明
- treatment: 处理变量通常是实验中的干预措施或观察中的暴露因素
- outcome: 结果变量是我们要研究的效应指标
- control: 协变量/控制变量用于控制混淆因素
示例输出格式
{{
"treatment_variable": {{
"name": "treatment",
"variable_type": "treatment",
"confidence": 0.95,
"reasoning": "该变量表示是否接受干预"
}},
"outcome_variable": {{
"name": "outcome",
"variable_type": "outcome",
"confidence": 0.90,
"reasoning": "该变量表示干预后的结果"
}},
"control_variables": [
{{
"name": "age",
"variable_type": "control",
"confidence": 0.85,
"reasoning": "年龄可能影响结果,需要控制"
}}
],
"confidence_score": 0.90,
"reasoning": "根据问题描述,这是一个典型的因果推断问题..."
}}"""
DEFAULT_USER_PROMPT_TEMPLATE = """用户问题:{question}
数据表结构
{table_info}
请识别因果推断要素并输出 JSON 格式的结果"""
def __init__(self, file_query_tool: Optional[FileQueryTool] = None,
default_read_rows: int = 100):
"""
初始化因果推断解析器
Args:
file_query_tool: 文件查询工具实例如果为 None 则创建默认实例
default_read_rows: 默认读取的行数
"""
self.file_query_tool = file_query_tool or FileQueryTool()
self.default_read_rows = default_read_rows
def parse(self, request: CausalInferenceRequest) -> CausalInferenceResult:
"""
解析因果推断请求
Args:
request: 因果推断请求对象
Returns:
CausalInferenceResult: 解析结果
"""
# 获取文件信息
file_metadata = self.file_query_tool.query(
request.file_path,
request.read_rows
)
# 使用启发式方法进行解析
result = self._heuristic_parse(
request.user_question,
file_metadata
)
return result
def _build_table_info(self, metadata: FileMetadata) -> str:
"""
构建表头信息字符串
Args:
metadata: 文件元数据
Returns:
str: 格式化的表头信息
"""
lines = [
f"文件名:{metadata.file_name}",
f"总行数:{metadata.total_rows}",
f"总列数:{metadata.total_columns}",
"",
"列信息:"
]
if metadata.columns is None:
return "\n".join(lines)
for col in metadata.columns:
lines.append(f" - {col.name}: {col.dtype} (非空:{col.non_null_count}, 空值:{col.null_count}, 唯一值:{col.unique_count})")
if col.sample_values:
sample_str = ", ".join(str(v) for v in col.sample_values[:3])
lines.append(f" 样本值:{sample_str}")
return "\n".join(lines)
def _heuristic_parse(self, question: str, metadata: FileMetadata) -> CausalInferenceResult:
"""
使用启发式规则进行因果推断解析
Args:
question: 用户问题
metadata: 文件元数据
Returns:
CausalInferenceResult: 解析结果
"""
# 启发式规则识别变量
treatment_var = self._heuristic_identify_treatment(question, metadata)
outcome_var = self._heuristic_identify_outcome(question, metadata)
control_vars = self._heuristic_identify_controls(question, metadata, treatment_var, outcome_var)
# 计算置信度
confidence = self._calculate_confidence(treatment_var, outcome_var, control_vars)
# 构建推理理由
reasoning = self._build_reasoning(question, treatment_var, outcome_var, control_vars)
# 查找其他可能的匹配
alternative_matches = self._find_alternative_matches(metadata, treatment_var, outcome_var)
# 构建结果对象
result = CausalInferenceResult(
original_question=question,
file_path=metadata.file_path,
treatment_variable=treatment_var,
outcome_variable=outcome_var,
control_variables=control_vars,
confidence_score=confidence,
reasoning=reasoning,
alternative_matches=alternative_matches,
metadata={
"parsed_at": datetime.now().isoformat(),
"file_info": {
"file_name": metadata.file_name,
"total_rows": metadata.total_rows,
"total_columns": metadata.total_columns
}
}
)
return result
"""
模拟 LLM 解析用于测试
基于启发式规则进行解析不依赖真实的 LLM
Args:
question: 用户问题
metadata: 文件元数据
Returns:
str: 模拟的 JSON 响应
"""
# 启发式规则识别变量
treatment_var = self._heuristic_identify_treatment(question, metadata)
outcome_var = self._heuristic_identify_outcome(question, metadata)
control_vars = self._heuristic_identify_controls(question, metadata, treatment_var, outcome_var)
# 计算置信度
confidence = self._calculate_confidence(treatment_var, outcome_var, control_vars)
# 构建推理理由
reasoning = self._build_reasoning(question, treatment_var, outcome_var, control_vars)
# 构建结果
result = {
"treatment_variable": treatment_var.to_dict() if treatment_var else None,
"outcome_variable": outcome_var.to_dict() if outcome_var else None,
"control_variables": [v.to_dict() for v in control_vars],
"confidence_score": confidence,
"reasoning": reasoning,
"alternative_matches": self._find_alternative_matches(metadata, treatment_var, outcome_var)
}
return json.dumps(result, ensure_ascii=False)
def _heuristic_identify_treatment(self, question: str,
metadata: FileMetadata) -> Optional[VariableInfo]:
"""
启发式识别处理变量
Args:
question: 用户问题
metadata: 文件元数据
Returns:
Optional[VariableInfo]: 处理变量信息
"""
# 常见处理变量关键词
treatment_keywords = [
"处理", "干预", "治疗", "实验组", "对照组", "treatment", "treat",
"干预组", "实验", "exposure", "exposed", "assigned", "group"
]
# 检查列名是否包含处理变量关键词
if metadata.columns is None:
return None
for col in metadata.columns:
col_name_lower = col.name.lower()
for keyword in treatment_keywords:
if keyword in col_name_lower:
return VariableInfo(
name=col.name,
variable_type="treatment",
confidence=0.85,
reasoning=f"列名包含处理变量关键词:{keyword}",
data_type=col.dtype,
is_binary=col.unique_count == 2 if col.unique_count else None
)
# 检查是否有明显的二值变量
for col in metadata.columns:
if col.unique_count == 2 and col.non_null_count == metadata.total_rows:
return VariableInfo(
name=col.name,
variable_type="treatment",
confidence=0.70,
reasoning="发现二值变量,可能是处理变量",
data_type=col.dtype,
is_binary=True
)
return None
def _heuristic_identify_outcome(self, question: str,
metadata: FileMetadata) -> Optional[VariableInfo]:
"""
启发式识别结果变量
Args:
question: 用户问题
metadata: 文件元数据
Returns:
Optional[VariableInfo]: 结果变量信息
"""
# 常见结果变量关键词
outcome_keywords = [
"结果", "影响", "效应", "收入", "工资", "成绩", "分数", "outcome",
"result", "effect", "impact", "dependent", "y", "目标", "指标"
]
# 检查列名是否包含结果变量关键词
if metadata.columns is None:
return None
for col in metadata.columns:
col_name_lower = col.name.lower()
for keyword in outcome_keywords:
if keyword in col_name_lower:
return VariableInfo(
name=col.name,
variable_type="outcome",
confidence=0.85,
reasoning=f"列名包含结果变量关键词:{keyword}",
data_type=col.dtype,
is_binary=False
)
# 检查是否有数值型列(通常是结果变量)
for col in metadata.columns:
if col.dtype in ['float64', 'int64', 'float32', 'int32']:
return VariableInfo(
name=col.name,
variable_type="outcome",
confidence=0.65,
reasoning="发现数值型列,可能是结果变量",
data_type=col.dtype,
is_binary=False
)
return None
def _heuristic_identify_controls(self, question: str,
metadata: FileMetadata,
treatment: Optional[VariableInfo],
outcome: Optional[VariableInfo]) -> List[VariableInfo]:
"""
启发式识别协变量
Args:
question: 用户问题
metadata: 文件元数据
treatment: 处理变量
outcome: 结果变量
Returns:
List[VariableInfo]: 协变量列表
"""
controls = []
treated_names = {treatment.name} if treatment else set()
outcome_names = {outcome.name} if outcome else set()
# 常见协变量关键词
control_keywords = [
"年龄", "性别", "教育", "经验", "控制", "covariate", "control",
"confounder", "confounding", "特征", "变量", "demographic"
]
if metadata.columns is None:
return []
for col in metadata.columns:
# 跳过已识别的变量
if col.name in treated_names or col.name in outcome_names:
continue
col_name_lower = col.name.lower()
# 检查是否包含协变量关键词
is_control = False
reasoning = ""
for keyword in control_keywords:
if keyword in col_name_lower:
is_control = True
reasoning = f"列名包含协变量关键词:{keyword}"
break
# 如果没有关键词,但也不是处理或结果变量,可能是协变量
if not is_control:
is_control = True
reasoning = "可能是协变量,用于控制混淆因素"
controls.append(VariableInfo(
name=col.name,
variable_type="control",
confidence=0.60 if is_control else 0.40,
reasoning=reasoning,
data_type=col.dtype,
is_binary=col.unique_count == 2 if col.unique_count else None
))
return controls
def _calculate_confidence(self, treatment: Optional[VariableInfo],
outcome: Optional[VariableInfo],
controls: List[VariableInfo]) -> float:
"""
计算整体置信度
Args:
treatment: 处理变量
outcome: 结果变量
controls: 协变量列表
Returns:
float: 置信度评分 (0-1)
"""
if treatment is None or outcome is None:
return 0.3
# 基础置信度
base_confidence = (treatment.confidence + outcome.confidence) / 2
# 协变量数量影响
if len(controls) > 0:
avg_control_confidence = sum(c.confidence for c in controls) / len(controls)
base_confidence = (base_confidence + avg_control_confidence) / 2
return min(base_confidence, 1.0)
def _build_reasoning(self, question: str, treatment: Optional[VariableInfo],
outcome: Optional[VariableInfo],
controls: List[VariableInfo]) -> str:
"""
构建推理理由
Args:
question: 用户问题
treatment: 处理变量
outcome: 结果变量
controls: 协变量列表
Returns:
str: 推理理由
"""
reasons = []
if treatment:
reasons.append(f"处理变量 '{treatment.name}' 被识别为干预措施")
if outcome:
reasons.append(f"结果变量 '{outcome.name}' 被识别为效应指标")
if controls:
control_names = [c.name for c in controls[:3]]
if len(controls) > 3:
control_names.append(f"{len(controls) - 3} 个其他变量")
reasons.append(f"协变量 {', '.join(control_names)} 被识别为需要控制的混淆因素")
return "".join(reasons) if reasons else "未能充分识别因果推断要素"
def _find_alternative_matches(self, metadata: FileMetadata,
treatment: Optional[VariableInfo],
outcome: Optional[VariableInfo]) -> List[Dict[str, Any]]:
"""
查找其他可能的匹配
Args:
metadata: 文件元数据
treatment: 处理变量
outcome: 结果变量
Returns:
List[Dict[str, Any]]: 其他可能的匹配
"""
alternatives = []
if metadata.columns is None:
return []
# 查找其他可能的处理变量
if treatment:
for col in metadata.columns:
if col.name != treatment.name:
alternatives.append({
"variable_type": "treatment",
"name": col.name,
"confidence": 0.3,
"reasoning": "可能是替代的处理变量"
})
# 查找其他可能的结果变量
if outcome:
for col in metadata.columns:
if col.name != outcome.name:
alternatives.append({
"variable_type": "outcome",
"name": col.name,
"confidence": 0.3,
"reasoning": "可能是替代的结果变量"
})
return alternatives[:5] # 最多返回 5 个替代匹配
def _parse_llm_response(self, response: str, request: CausalInferenceRequest,
metadata: FileMetadata) -> CausalInferenceResult:
"""
解析 LLM 响应
Args:
response: LLM 响应
request: 原始请求
metadata: 文件元数据
Returns:
CausalInferenceResult: 解析结果
"""
try:
# 清理响应,提取 JSON 部分
response = response.strip()
if response.startswith("```json"):
response = response[7:]
if response.startswith("```"):
response = response[1:]
if response.endswith("```"):
response = response[:-3]
response = response.strip()
# 解析 JSON
data = json.loads(response)
# 构建结果对象
result = CausalInferenceResult(
original_question=request.user_question,
file_path=request.file_path,
treatment_variable=VariableInfo(**data.get("treatment_variable")) if data.get("treatment_variable") else None,
outcome_variable=VariableInfo(**data.get("outcome_variable")) if data.get("outcome_variable") else None,
control_variables=[VariableInfo(**v) for v in data.get("control_variables", [])],
confidence_score=data.get("confidence_score", 0.0),
reasoning=data.get("reasoning", ""),
alternative_matches=data.get("alternative_matches", []),
metadata={
"parsed_at": datetime.now().isoformat(),
"file_info": {
"file_name": metadata.file_name,
"total_rows": metadata.total_rows,
"total_columns": metadata.total_columns
}
}
)
return result
except json.JSONDecodeError as e:
# 如果 JSON 解析失败,返回错误结果
return CausalInferenceResult(
original_question=request.user_question,
file_path=request.file_path,
confidence_score=0.0,
reasoning=f"LLM 响应解析失败:{str(e)}",
metadata={
"parsed_at": datetime.now().isoformat(),
"error": str(e),
"raw_response": response[:500]
}
)
def parse_simple(self, question: str, file_path: str) -> CausalInferenceResult:
"""
简化的解析方法
Args:
question: 用户问题
file_path: 文件路径
Returns:
CausalInferenceResult: 解析结果
"""
request = CausalInferenceRequest(
user_question=question,
file_path=file_path
)
return self.parse(request)
if __name__ == "__main__":
"""
因果推断解析器模块的简单测试示例
"""
from file_query_tool import FileQueryTool
print("=" * 60)
print("CausalParser 简单测试示例")
print("=" * 60)
print()
# 创建解析器实例(不使用 LLM 客户端)
print("步骤 1: 创建 CausalParser 实例...")
parser = CausalParser()
print(" - CausalParser 已初始化")
print()
# 准备示例数据文件
print("步骤 2: 准备示例数据文件...")
sample_file = "data/output/test_simple_ate.csv"
print(f" - 文件路径:{sample_file}")
print()
# 创建解析请求
print("步骤 3: 创建解析请求...")
sample_question = "教育对收入的影响是什么?"
request = CausalInferenceRequest(
user_question=sample_question,
file_path=sample_file
)
print(f" - 问题:{sample_question}")
print()
# 执行解析
print("步骤 4: 执行解析...")
try:
result = parser.parse(request)
print(" - 解析成功!")
print()
# 输出解析结果
print("步骤 5: 解析结果:")
print("-" * 60)
print(f"原始问题:{result.original_question}")
print(f"文件路径:{result.file_path}")
print(f"整体置信度:{result.confidence_score:.2%}")
print()
if result.treatment_variable:
tv = result.treatment_variable
print(f"处理变量 (T):")
print(f" - 名称:{tv.name}")
print(f" - 类型:{tv.variable_type}")
print(f" - 置信度:{tv.confidence:.2%}")
print(f" - 推理:{tv.reasoning}")
print()
if result.outcome_variable:
ov = result.outcome_variable
print(f"结果变量 (Y):")
print(f" - 名称:{ov.name}")
print(f" - 类型:{ov.variable_type}")
print(f" - 置信度:{ov.confidence:.2%}")
print(f" - 推理:{ov.reasoning}")
print()
if result.control_variables:
print(f"协变量 (Control Variables):")
for cv in result.control_variables:
print(f" - {cv.name} (置信度:{cv.confidence:.2%})")
print(f" 推理:{cv.reasoning}")
print()
print(f"推理理由:{result.reasoning}")
print()
print("=" * 60)
print("完整 JSON 输出:")
print("=" * 60)
print(result.to_json())
except Exception as e:
print(f" - 解析失败:{str(e)}")
print()
print("请检查:")
print(" 1. 数据文件路径是否正确")

View File

@ -0,0 +1,293 @@
"""
文件查询工具模块 (File Query Tool)
该模块提供文件查询功能支持读取 CSV Excel 文件的表头信息
统计信息和样本数据而不需要读取完整数据
"""
import os
import pandas as pd
from dataclasses import dataclass, asdict
from typing import List, Optional, Dict, Any
from pathlib import Path
@dataclass
class ColumnInfo:
"""列信息数据类"""
name: str # 列名
dtype: str # 数据类型
non_null_count: int # 非空值数量
null_count: int # 空值数量
unique_count: Optional[int] = None # 唯一值数量(可选)
sample_values: Optional[List[Any]] = None # 样本值(可选)
def to_dict(self) -> Dict[str, Any]:
"""转换为字典"""
return asdict(self)
@dataclass
class SampleData:
"""样本数据数据类"""
rows: List[Dict[str, Any]] # 样本行数据
row_count: int # 样本行数
columns: List[str] # 列名列表
def to_dict(self) -> Dict[str, Any]:
"""转换为字典"""
return {
"rows": self.rows,
"row_count": self.row_count,
"columns": self.columns
}
@dataclass
class FileMetadata:
"""文件元数据"""
file_path: str # 文件路径
file_name: str # 文件名
file_size: int # 文件大小(字节)
file_type: str # 文件类型csv/excel
extension: str # 文件扩展名
total_rows: Optional[int] = None # 总行数
total_columns: Optional[int] = None # 总列数
columns: Optional[List[ColumnInfo]] = None # 列信息列表
sample_data: Optional[SampleData] = None # 样本数据
read_rows: int = 100 # 实际读取的行数
def to_dict(self) -> Dict[str, Any]:
"""转换为字典"""
result = asdict(self)
if self.columns:
result["columns"] = [col.to_dict() for col in self.columns]
if self.sample_data:
result["sample_data"] = self.sample_data.to_dict()
return result
class FileQueryTool:
"""
文件查询工具
用于读取 CSV Excel 文件的表头信息统计信息和样本数据
支持只读取前 N 避免加载完整数据
Attributes:
default_read_rows: 默认读取的行数100
Example:
>>> tool = FileQueryTool()
>>> metadata = tool.query("data.csv")
>>> print(metadata.columns)
>>> print(metadata.sample_data)
"""
def __init__(self, default_read_rows: int = 100):
"""
初始化文件查询工具
Args:
default_read_rows: 默认读取的行数用于获取表头和样本数据
"""
self.default_read_rows = default_read_rows
def query(self, file_path: str, read_rows: Optional[int] = None) -> FileMetadata:
"""
查询文件信息
读取文件的前 N 提取表头信息统计信息和样本数据
Args:
file_path: 文件路径
read_rows: 要读取的行数如果为 None 则使用默认值
Returns:
FileMetadata: 文件元数据对象
Raises:
FileNotFoundError: 文件不存在
ValueError: 不支持的文件格式
Exception: 其他读取错误
"""
if read_rows is None:
read_rows = self.default_read_rows
# 验证文件存在
if not os.path.exists(file_path):
raise FileNotFoundError(f"文件不存在:{file_path}")
# 获取文件信息
file_path = os.path.abspath(file_path)
file_name = os.path.basename(file_path)
file_size = os.path.getsize(file_path)
extension = os.path.splitext(file_name)[1].lower()
# 确定文件类型
if extension == ".csv":
file_type = "csv"
df = self._read_csv(file_path, read_rows)
elif extension in [".xlsx", ".xls"]:
file_type = "excel"
df = self._read_excel(file_path, read_rows)
else:
raise ValueError(f"不支持的文件格式:{extension},仅支持 CSV 和 Excel 文件")
# 提取列信息
columns = self._extract_column_info(df)
# 提取样本数据
sample_data = self._extract_sample_data(df, read_rows)
# 创建元数据对象
metadata = FileMetadata(
file_path=file_path,
file_name=file_name,
file_size=file_size,
file_type=file_type,
extension=extension,
total_rows=len(df),
total_columns=len(df.columns),
columns=columns,
sample_data=sample_data,
read_rows=min(len(df), read_rows)
)
return metadata
def _read_csv(self, file_path: str, read_rows: int) -> pd.DataFrame:
"""
读取 CSV 文件
Args:
file_path: CSV 文件路径
read_rows: 要读取的行数
Returns:
pd.DataFrame: 读取的数据框
"""
df = pd.read_csv(file_path, nrows=read_rows)
return df
def _read_excel(self, file_path: str, read_rows: int) -> pd.DataFrame:
"""
读取 Excel 文件
Args:
file_path: Excel 文件路径
read_rows: 要读取的行数
Returns:
pd.DataFrame: 读取的数据框
"""
df = pd.read_excel(file_path, nrows=read_rows)
return df
def _extract_column_info(self, df: pd.DataFrame) -> List[ColumnInfo]:
"""
提取列信息
Args:
df: 数据框
Returns:
List[ColumnInfo]: 列信息列表
"""
columns = []
for col in df.columns:
# 获取数据类型
dtype = str(df[col].dtype)
# 统计非空值和空值
non_null_count = df[col].notna().sum()
null_count = df[col].isna().sum()
# 计算唯一值数量(仅对非空值)
unique_count = df[col].dropna().nunique()
# 获取样本值(最多 5 个)
sample_values = df[col].dropna().head(5).tolist()
column_info = ColumnInfo(
name=col,
dtype=dtype,
non_null_count=int(non_null_count),
null_count=int(null_count),
unique_count=int(unique_count),
sample_values=sample_values
)
columns.append(column_info)
return columns
def _extract_sample_data(self, df: pd.DataFrame, read_rows: int) -> SampleData:
"""
提取样本数据
Args:
df: 数据框
read_rows: 实际读取的行数
Returns:
SampleData: 样本数据对象
"""
# 获取实际行数(可能小于 read_rows
actual_rows = min(len(df), read_rows)
# 转换为字典列表
rows = df.head(actual_rows).to_dict(orient='records')
return SampleData(
rows=rows,
row_count=actual_rows,
columns=df.columns.tolist()
)
def get_column_names(self, file_path: str, read_rows: Optional[int] = None) -> List[str]:
"""
快速获取列名列表
Args:
file_path: 文件路径
read_rows: 要读取的行数
Returns:
List[str]: 列名列表
"""
metadata = self.query(file_path, read_rows)
return [col.name for col in metadata.columns]
def get_column_dtype(self, file_path: str, column_name: str,
read_rows: Optional[int] = None) -> str:
"""
获取指定列的数据类型
Args:
file_path: 文件路径
column_name: 列名
read_rows: 要读取的行数
Returns:
str: 数据类型
"""
metadata = self.query(file_path, read_rows)
for col in metadata.columns:
if col.name == column_name:
return col.dtype
raise ValueError(f"'{column_name}' 不存在")
def get_file_info(self, file_path: str, read_rows: Optional[int] = None) -> Dict[str, Any]:
"""
获取文件信息的字典格式
Args:
file_path: 文件路径
read_rows: 要读取的行数
Returns:
Dict[str, Any]: 文件信息字典
"""
metadata = self.query(file_path, read_rows)
return metadata.to_dict()

View File

@ -0,0 +1,217 @@
"""
输入处理模块测试脚本
该脚本测试文件查询工具和因果推断解析器的功能
"""
import os
import sys
import json
from pathlib import Path
# 添加项目根目录到 Python 路径
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
# 导入模块
from input_processor import FileQueryTool, CausalParser, CausalInferenceRequest
def test_file_query_tool():
"""测试文件查询工具"""
print("=" * 60)
print("测试文件查询工具")
print("=" * 60)
# 使用项目中的示例数据
test_files = [
"data/output/simple_ate_data.csv",
"data/output/simple_ate_data.xlsx"
]
tool = FileQueryTool(default_read_rows=100)
for file_path in test_files:
if not os.path.exists(file_path):
print(f"\n文件不存在,跳过:{file_path}")
continue
print(f"\n测试文件:{file_path}")
print("-" * 40)
try:
# 测试基本查询
metadata = tool.query(file_path)
print(f"文件名:{metadata.file_name}")
print(f"文件类型:{metadata.file_type}")
print(f"文件大小:{metadata.file_size} 字节")
print(f"总行数:{metadata.total_rows}")
print(f"总列数:{metadata.total_columns}")
print(f"实际读取行数:{metadata.read_rows}")
print("\n列信息:")
for col in metadata.columns:
print(f" - {col.name}: {col.dtype}")
print(f" 非空值:{col.non_null_count}, 空值:{col.null_count}")
print(f" 唯一值:{col.unique_count}")
if col.sample_values:
print(f" 样本值:{col.sample_values[:3]}")
print("\n样本数据:")
if metadata.sample_data:
for i, row in enumerate(metadata.sample_data.rows[:3]):
print(f"{i+1}: {row}")
# 测试获取列名
print("\n列名列表:")
col_names = tool.get_column_names(file_path)
print(f" {col_names}")
# 测试获取数据类型
if col_names:
first_col = col_names[0]
dtype = tool.get_column_dtype(file_path, first_col)
print(f"\n'{first_col}' 的数据类型:{dtype}")
# 测试获取文件信息字典
print("\n文件信息字典:")
file_info = tool.get_file_info(file_path)
print(json.dumps(file_info, ensure_ascii=False, indent=2))
except Exception as e:
print(f"错误:{e}")
print("\n" + "=" * 60)
print("文件查询工具测试完成")
print("=" * 60)
def test_causal_parser():
"""测试因果推断解析器"""
print("\n" + "=" * 60)
print("测试因果推断解析器")
print("=" * 60)
# 使用项目中的示例数据
test_file = "data/output/simple_ate_data.csv"
if not os.path.exists(test_file):
print(f"文件不存在,跳过测试:{test_file}")
return
parser = CausalParser()
# 测试问题列表
test_questions = [
"教育对收入的影响是什么?",
"处理变量对结果变量的影响如何?",
"分析 treatment 对 outcome 的因果效应",
"这个数据集中 treatment 和 outcome 有什么关系?",
"如何评估干预措施的效果?"
]
for question in test_questions:
print(f"\n用户问题:{question}")
print("-" * 40)
try:
# 使用简化的解析方法
result = parser.parse_simple(question, test_file)
# 输出结果
print(result.to_json())
# 结构化输出
print("\n结构化结果:")
if result.treatment_variable:
print(f" 处理变量 (T): {result.treatment_variable.name}")
print(f" 置信度:{result.treatment_variable.confidence:.2f}")
print(f" 推理:{result.treatment_variable.reasoning}")
if result.outcome_variable:
print(f" 结果变量 (Y): {result.outcome_variable.name}")
print(f" 置信度:{result.outcome_variable.confidence:.2f}")
print(f" 推理:{result.outcome_variable.reasoning}")
if result.control_variables:
print(f" 协变量 ({len(result.control_variables)} 个):")
for var in result.control_variables[:3]:
print(f" - {var.name} (置信度:{var.confidence:.2f})")
if len(result.control_variables) > 3:
print(f" ... 还有 {len(result.control_variables) - 3}")
print(f"\n 整体置信度:{result.confidence_score:.2f}")
print(f" 推理理由:{result.reasoning}")
except Exception as e:
print(f"错误:{e}")
print("\n" + "=" * 60)
print("因果推断解析器测试完成")
print("=" * 60)
def test_causal_parser_with_request():
"""测试使用 CausalInferenceRequest 对象的解析"""
print("\n" + "=" * 60)
print("测试使用 CausalInferenceRequest 对象")
print("=" * 60)
test_file = "data/output/simple_ate_data.csv"
if not os.path.exists(test_file):
print(f"文件不存在,跳过测试:{test_file}")
return
parser = CausalParser()
# 创建请求对象
request = CausalInferenceRequest(
user_question="分析 treatment 变量对 outcome 变量的影响",
file_path=test_file,
read_rows=50,
llm_provider="mock",
llm_model="mock-model"
)
print(f"请求信息:")
print(f" 问题:{request.user_question}")
print(f" 文件:{request.file_path}")
print(f" 读取行数:{request.read_rows}")
print(f" LLM 提供商:{request.llm_provider}")
print(f" LLM 模型:{request.llm_model}")
# 执行解析
result = parser.parse(request)
print("\n解析结果:")
print(result.to_json())
print("\n" + "=" * 60)
print("请求对象测试完成")
print("=" * 60)
def main():
"""主测试函数"""
print("\n" + "=" * 60)
print("输入处理模块测试")
print("=" * 60)
# 测试文件查询工具
test_file_query_tool()
# 测试因果推断解析器
test_causal_parser()
# 测试使用请求对象的解析
test_causal_parser_with_request()
print("\n" + "=" * 60)
print("所有测试完成!")
print("=" * 60)
if __name__ == "__main__":
main()

192
pyproject.toml Normal file
View File

@ -0,0 +1,192 @@
[project]
# Project name
name = "causal-inference-agent"
# Project version (semantic versioning)
version = "0.1.0"
# Project description
description = "Causal Inference Agent System"
# Authors
authors = [
    { name = "Causal Inference Team", email = "team@example.com" }
]
# License (optional)
license = "MIT"
# Readme file
readme = "README.md"
# Keywords
keywords = ["causal-inference", "agent", "data-science", "statistics"]
# Classifiers
classifiers = [
    "Development Status :: 3 - Alpha",
    "Intended Audience :: Developers",
    "Intended Audience :: Science/Research",
    "Operating System :: OS Independent",
    "Programming Language :: Python :: 3",
    "Programming Language :: Python :: 3.11",
    "Programming Language :: Python :: 3.12",
    "Topic :: Scientific/Engineering :: Artificial Intelligence",
    "Topic :: Scientific/Engineering :: Mathematics",
]
# Python version requirement
# Python 3.11 is the minimum supported version
requires-python = ">=3.11"
# Runtime dependencies
dependencies = [
    "pandas>=2.0.0",        # data handling and analysis
    "openpyxl>=3.1.0",      # Excel read/write support
    "numpy>=1.24.0",        # numerical computing
    "scipy>=1.10.0",        # scientific computing and statistics
    "pydantic>=2.0.0",      # data validation and settings management (for JSON output)
    "requests>=2.31.0",     # HTTP requests (for LLM API calls)
    "networkx>=3.0",        # graph theory and causal-graph construction
    "scikit-learn>=1.3.0",  # machine learning and causal-effect estimation
]
# Optional dependencies (extras)
[project.optional-dependencies]
# Development dependencies
dev = [
    "pytest>=7.4.0",      # test framework
    "pytest-cov>=4.1.0",  # coverage
    "black>=23.0.0",      # code formatting
    "ruff>=0.1.0",        # linting
    "mypy>=1.5.0",        # type checking
    "pre-commit>=3.4.0",  # git pre-commit hooks
]
# Documentation dependencies
docs = [
    "sphinx>=7.0.0",            # documentation generator
    "sphinx-rtd-theme>=1.3.0",  # Sphinx theme
    "myst-parser>=2.0.0",       # Markdown parser
]
# Test dependencies
test = [
    "pytest>=7.4.0",
    "pytest-cov>=4.1.0",
    "pytest-xdist>=3.3.0",  # parallel testing
]
# Everything (all optional dependencies)
all = [
    "causal-inference-agent[dev,docs,test]",
]
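# Example (assuming a local checkout with uv available):
#   uv pip install -e ".[all]"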
# Package discovery
[tool.setuptools.packages.find]
include = ["causal_agent*", "input_processor*", "data*"]

[project.scripts]
# Command-line entry point
causal-agent = "causal_agent.cli:main"
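# After installation the entry point is available on PATH, e.g.:
#   causal-agent --data data.xlsx --output report.json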
[project.urls]
# Project links
Homepage = "https://github.com/example/causal-inference-agent"
Repository = "https://github.com/example/causal-inference-agent"
"Bug Tracker" = "https://github.com/example/causal-inference-agent/issues"
"Documentation" = "https://causal-inference-agent.readthedocs.io"
"Changelog" = "https://github.com/example/causal-inference-agent/blob/main/CHANGELOG.md"

# ==================== uv configuration ====================
# uv index configuration
[tool.uv]
# Use PyPI as the default index
index-url = "https://pypi.org/simple"
# Extra indexes (optional)
extra-index-url = []
# Dependency resolution strategy: resolve direct dependencies to their
# lowest compatible versions (useful for testing declared lower bounds)
resolution = "lowest-direct"

# ==================== development tool configuration ====================
# Black formatter configuration
[tool.black]
line-length = 88
# Match requires-python (>=3.11)
target-version = ['py311', 'py312']
include = '\.pyi?$'
exclude = '''
/(
    \.git
  | \.hg
  | \.mypy_cache
  | \.tox
  | \.venv
  | _build
  | buck-out
  | build
  | dist
)/
'''

# Ruff linter configuration
[tool.ruff]
line-length = 88
# Match requires-python (>=3.11)
target-version = "py311"
# Enabled lint rules
select = [
    "E",   # pycodestyle errors
    "W",   # pycodestyle warnings
    "F",   # Pyflakes
    "I",   # isort
    "B",   # flake8-bugbear
    "C4",  # flake8-comprehensions
    "UP",  # pyupgrade
]
ignore = [
    "E501",  # line too long (handled by black)
    "B008",  # do not perform function calls in argument defaults
]

[tool.ruff.per-file-ignores]
"__init__.py" = ["F401"]  # unused imports in __init__.py

# MyPy type-checking configuration
[tool.mypy]
python_version = "3.11"
warn_return_any = true
warn_unused_configs = true
disallow_untyped_defs = true
ignore_missing_imports = true

# Pytest configuration
[tool.pytest.ini_options]
testpaths = ["tests", "input_processor"]
python_files = ["test_*.py"]
python_functions = ["test_*"]
addopts = [
    "-v",
    "--strict-markers",
    "--strict-config",
]
markers = [
    "slow: marks tests as slow",
    "integration: marks tests as integration tests",
]

# Coverage configuration
[tool.coverage.run]
source = ["."]
omit = [
    "*/tests/*",
    "*/__init__.py",
    "*/conftest.py",
]

[tool.coverage.report]
exclude_lines = [
    "pragma: no cover",
    "def __repr__",
    "raise AssertionError",
    "raise NotImplementedError",
    "if __name__ == .__main__.:",
]

# Pre-commit hook configuration
[tool.pre-commit]
# Default stages
default_stages = ["commit", "push", "manual"]

View File

@ -0,0 +1,49 @@
"""Tests for causal_graph module."""
import networkx as nx
import pandas as pd
import pytest
from causal_agent.analysis.causal_graph import (
build_local_graph,
find_adjustment_set,
find_backdoor_paths,
)
def test_build_local_graph():
df = pd.DataFrame({
"treatment": [0, 1, 0, 1],
"health": [1, 2, 1, 3],
"age": [20, 30, 40, 50],
})
candidates = [{"var": "age"}]
tiers = {"treatment": 2, "health": 4, "age": 0}
G = build_local_graph(df, "treatment", "health", candidates, tiers, corr_threshold=0.1)
assert isinstance(G, nx.DiGraph)
assert G.has_edge("treatment", "health")
assert G.has_edge("age", "treatment")
assert G.has_edge("age", "health")
def test_find_backdoor_paths():
G = nx.DiGraph()
G.add_edge("age", "treatment")
G.add_edge("age", "health")
G.add_edge("treatment", "health")
paths = find_backdoor_paths(G, "treatment", "health")
assert len(paths) == 1
assert "treatment <- age -> health" in paths
def test_find_adjustment_set():
G = nx.DiGraph()
G.add_edge("age", "treatment")
G.add_edge("age", "health")
G.add_edge("treatment", "health")
adjustment_set, reasoning = find_adjustment_set(G, "treatment", "health")
assert "age" in adjustment_set
assert "后门" in reasoning

View File

@ -0,0 +1,39 @@
"""Tests for estimation module."""
import numpy as np
import pandas as pd
import pytest
from causal_agent.analysis.estimation import estimate_ate
def test_estimate_ate_basic():
np.random.seed(42)
n = 300
z = np.random.normal(0, 1, n)
t = (z + np.random.normal(0, 0.5, n) > 0).astype(int)
y = 2 * t + 3 * z + np.random.normal(0, 0.5, n)
df = pd.DataFrame({"T": t, "Y": y, "Z": z})
result = estimate_ate(df, "T", "Y", ["Z"], compute_ci=False)
assert "ATE_Outcome_Regression" in result
assert "ATE_IPW" in result
assert abs(result["ATE_Outcome_Regression"] - 2.0) < 0.5
# IPW 估计方差较大,放宽阈值
assert abs(result["ATE_IPW"] - 2.0) < 1.5
assert "balance_check" in result
assert "Z" in result["balance_check"]
def test_estimate_ate_empty_adjustment():
np.random.seed(42)
n = 200
t = np.random.binomial(1, 0.5, n)
y = 1.5 * t + np.random.normal(0, 0.5, n)
df = pd.DataFrame({"T": t, "Y": y})
result = estimate_ate(df, "T", "Y", [], compute_ci=False)
assert abs(result["ATE_Outcome_Regression"] - 1.5) < 0.5
assert abs(result["ATE_IPW"] - 1.5) < 0.5

View File

@ -0,0 +1,31 @@
"""Tests for screening module."""
import numpy as np
import pandas as pd
import pytest
from causal_agent.analysis.screening import local_screen
def test_local_screen_finds_confounders():
np.random.seed(42)
n = 200
z = np.random.normal(0, 1, n)
t = (z + np.random.normal(0, 0.5, n) > 0).astype(int)
y = 2 * t + 3 * z + np.random.normal(0, 0.5, n)
noise = np.random.normal(0, 1, n)
df = pd.DataFrame({"T": t, "Y": y, "Z": z, "noise": noise})
candidates = local_screen(df, T="T", Y="Y", excluded=[], corr_threshold=0.1, alpha=0.05)
var_names = [c["var"] for c in candidates]
assert "Z" in var_names
assert "noise" not in var_names
def test_local_screen_excludes_specified_columns():
df = pd.DataFrame({"T": [0, 1], "Y": [1, 2], "X": [0, 1], "id": [1, 2]})
candidates = local_screen(df, T="T", Y="Y", excluded=["id"])
assert all(c["var"] != "id" for c in candidates)
assert all(c["var"] != "T" for c in candidates)
assert all(c["var"] != "Y" for c in candidates)

1230
uv.lock generated Normal file

File diff suppressed because it is too large