357 lines
11 KiB
Python
357 lines
11 KiB
Python
"""
|
||
LLM 因果分析测试入口
|
||
使用自然语言输入让 LLM 测试分析因果变量,并将所有参数记录到 log.md 文档中
|
||
"""
|
||
|
||
import os
|
||
import json
|
||
from datetime import datetime
|
||
from typing import Dict, Any, Optional
|
||
import pandas as pd
|
||
from llm_client import LLMClient
|
||
|
||
|
||
class CausalAnalysisLogger:
|
||
"""因果分析日志记录器"""
|
||
|
||
def __init__(self, log_path: str = "examples/medical/log.md"):
|
||
self.log_path = log_path
|
||
self.log_entries = []
|
||
self._init_log_file()
|
||
self._load_existing_entries()
|
||
|
||
def _init_log_file(self):
|
||
"""初始化日志文件"""
|
||
header = """# 因果分析日志
|
||
|
||
## 日志说明
|
||
本文档记录 LLM 进行因果分析时的所有输入参数和输出结果。
|
||
|
||
## 分析记录
|
||
"""
|
||
if not os.path.exists(self.log_path):
|
||
with open(self.log_path, 'w', encoding='utf-8') as f:
|
||
f.write(header)
|
||
|
||
def _load_existing_entries(self):
|
||
"""从日志文件中加载已存在的分析记录,以确定下一个分析 ID"""
|
||
if not os.path.exists(self.log_path):
|
||
return
|
||
|
||
analysis_count = 0
|
||
with open(self.log_path, 'r', encoding='utf-8') as f:
|
||
content = f.read()
|
||
# 统计已存在的分析记录数量
|
||
import re
|
||
matches = re.findall(r'### 分析 #(\d+)', content)
|
||
if matches:
|
||
analysis_count = max(int(m) for m in matches)
|
||
|
||
self.analysis_count = analysis_count
|
||
|
||
def _append_to_log(self, content: str):
|
||
"""追加内容到日志文件"""
|
||
with open(self.log_path, 'a', encoding='utf-8') as f:
|
||
f.write(content)
|
||
|
||
def log_analysis(self, system_prompt: str, user_prompt: str,
|
||
model_output: str, parameters: Dict[str, Any]):
|
||
"""记录一次完整的分析过程"""
|
||
# 增加分析计数
|
||
self.analysis_count += 1
|
||
analysis_id = f"{self.analysis_count:03d}"
|
||
timestamp = datetime.now().isoformat()
|
||
|
||
entry = f"""
|
||
---
|
||
|
||
### 分析 #{analysis_id}
|
||
**时间**: {timestamp}
|
||
|
||
#### 系统提示词
|
||
```
|
||
{system_prompt}
|
||
```
|
||
|
||
#### 用户提示词
|
||
```
|
||
{user_prompt}
|
||
```
|
||
|
||
#### LLM 输出
|
||
```
|
||
{model_output}
|
||
```
|
||
|
||
#### 调用参数
|
||
```json
|
||
{json.dumps(parameters, indent=2, ensure_ascii=False)}
|
||
```
|
||
|
||
---
|
||
|
||
"""
|
||
self._append_to_log(entry)
|
||
self.log_entries.append({
|
||
'id': analysis_id,
|
||
'timestamp': timestamp,
|
||
'system_prompt': system_prompt,
|
||
'user_prompt': user_prompt,
|
||
'model_output': model_output,
|
||
'parameters': parameters
|
||
})
|
||
|
||
print(f"分析 #{analysis_id} 已记录到 {self.log_path}")
|
||
|
||
|
||
class CausalAnalysisAgent:
|
||
"""因果分析代理"""
|
||
|
||
def __init__(self, data_path: str = "examples/medical/data.xlsx", log_path: str = "examples/medical/log.md"):
|
||
self.data_path = data_path
|
||
self.logger = CausalAnalysisLogger(log_path)
|
||
self.data = None
|
||
|
||
def load_data(self) -> pd.DataFrame:
|
||
"""加载数据"""
|
||
if not os.path.exists(self.data_path):
|
||
raise FileNotFoundError(f"数据文件不存在:{self.data_path}")
|
||
|
||
self.data = pd.read_excel(self.data_path)
|
||
print(f"成功加载数据:{self.data_path}")
|
||
print(f"数据形状:{self.data.shape}")
|
||
print(f"列名:{list(self.data.columns)}")
|
||
print(f"\n数据预览:")
|
||
print(self.data.head())
|
||
return self.data
|
||
|
||
def _generate_system_prompt(self) -> str:
|
||
"""生成系统提示词"""
|
||
return """你是一位专业的因果推断分析师。你的任务是分析给定的数据,识别因果变量并评估因果关系。
|
||
|
||
请以 JSON 格式输出分析结果,不要包含任何额外的解释或思考过程。
|
||
|
||
JSON 输出规范:
|
||
{
|
||
"treatment": "处理变量名称",
|
||
"outcome": "结果变量名称"
|
||
}
|
||
|
||
注意:
|
||
- 只输出上述 JSON 格式,不要包含其他字段
|
||
- 处理变量和结果变量名称必须与数据表格的列名完全一致
|
||
- 不要使用 markdown 代码块标记(如 ```json)
|
||
- 直接输出纯 JSON 字符串"""
|
||
|
||
def _generate_user_prompt(self, data_info: str) -> str:
|
||
"""生成用户提示词"""
|
||
return f"""请分析以下医疗数据,并严格按照 JSON 格式输出分析结果:
|
||
|
||
{data_info}
|
||
|
||
JSON 输出格式要求:
|
||
{{
|
||
"treatment": "处理变量名称",
|
||
"outcome": "结果变量名称"
|
||
}}
|
||
|
||
要求:
|
||
1. 处理变量和结果变量名称必须与表格列名完全一致
|
||
2. 只输出 JSON,不要包含其他任何内容
|
||
3. 不要使用 markdown 代码块标记"""
|
||
|
||
def _call_llm(self, system_prompt: str, user_prompt: str) -> Dict[str, Any]:
|
||
"""
|
||
调用 LLM API 获取响应
|
||
|
||
Args:
|
||
system_prompt: 系统提示词
|
||
user_prompt: 用户提示词
|
||
|
||
Returns:
|
||
包含响应内容的字典
|
||
"""
|
||
# 初始化 LLM 客户端(从环境变量读取配置)
|
||
import os
|
||
llm_client = LLMClient(
|
||
base_url=os.getenv("LLM_BASE_URL", "https://glm47flash.cloyir.com/v1"),
|
||
model=os.getenv("LLM_MODEL", "qwen3.5-35b"),
|
||
temperature=float(os.getenv("LLM_TEMPERATURE", "0.3")),
|
||
max_tokens=int(os.getenv("LLM_MAX_TOKENS", "2048"))
|
||
)
|
||
|
||
# 调用 LLM
|
||
result = llm_client.generate_response(system_prompt, user_prompt)
|
||
|
||
return result
|
||
|
||
def analyze(self, custom_prompt: Optional[str] = None) -> Dict[str, Any]:
|
||
"""
|
||
执行因果分析
|
||
|
||
Args:
|
||
custom_prompt: 自定义用户提示词,如果为 None 则使用默认提示词
|
||
|
||
Returns:
|
||
分析结果字典
|
||
"""
|
||
# 加载数据
|
||
if self.data is None:
|
||
self.load_data()
|
||
|
||
# 类型断言:确保 self.data 不是 None
|
||
assert self.data is not None, "数据未加载"
|
||
|
||
# 生成数据信息
|
||
data_info = f"""
|
||
**数据概览:**
|
||
- 样本数量:{len(self.data)}
|
||
- 变量:{', '.join(self.data.columns)}
|
||
|
||
**统计摘要:**
|
||
{self.data.describe()}
|
||
|
||
**处理变量分布:**
|
||
{self.data['treatment'].value_counts().to_string()}
|
||
"""
|
||
|
||
# 生成提示词
|
||
system_prompt = self._generate_system_prompt()
|
||
user_prompt = custom_prompt if custom_prompt else self._generate_user_prompt(data_info)
|
||
|
||
# 计算统计值
|
||
treatment_mean = self.data[self.data['treatment'] == 1]['health'].mean() if self.data['treatment'].sum() > 0 else 0
|
||
control_mean = self.data[self.data['treatment'] == 0]['health'].mean() if self.data['treatment'].sum() < len(self.data) else 0
|
||
ate = treatment_mean - control_mean if self.data['treatment'].sum() > 0 and self.data['treatment'].sum() < len(self.data) else 0
|
||
|
||
# 调用 LLM 获取真实响应
|
||
llm_result = self._call_llm(system_prompt, user_prompt)
|
||
|
||
# 打印 LLM 调用结果用于调试
|
||
print(f"\nLLM 调用结果:{llm_result}")
|
||
|
||
if llm_result.get("success"):
|
||
model_output = llm_result.get("content", "LLM 返回内容为空")
|
||
else:
|
||
# 如果 LLM 调用失败,使用模拟响应
|
||
model_output = f"""## 因果分析结果(LLM 调用失败,使用模拟响应)
|
||
|
||
### 1. 变量识别
|
||
- **处理变量 (Treatment)**: treatment (是否吃药)
|
||
- 类型:二元变量 (0/1)
|
||
- 含义:1 表示吃药,0 表示不吃药
|
||
|
||
- **结果变量 (Outcome)**: health (健康状态)
|
||
- 类型:连续变量 (0-1 浮点数)
|
||
- 含义:健康状态评分,越高越好
|
||
|
||
### 2. 因果效应估计
|
||
根据数据描述性统计:
|
||
- 吃药组 (treatment=1) 的平均健康状态:{treatment_mean:.4f}
|
||
- 未吃药组 (treatment=0) 的平均健康状态:{control_mean:.4f}
|
||
- **平均处理效应 (ATE)**: {ate:.4f}
|
||
|
||
### 3. 分析方法
|
||
使用简单的组间比较方法估计因果效应:
|
||
- 方法:均值差异 (Mean Difference)
|
||
- 假设:无混杂因素或混杂因素已控制
|
||
|
||
### 4. 结论
|
||
吃药对健康状态有正向的因果效应,估计效应大小为 {ate:.4f}。
|
||
这意味着吃药可以使健康状态平均提高 {ate * 100:.1f}%。
|
||
|
||
### 5. 建议
|
||
- 建议进行更严格的因果推断分析(如倾向得分匹配、工具变量法等)
|
||
- 考虑控制可能的混杂因素(如年龄、基础健康状况等)
|
||
|
||
---
|
||
**LLM 调用错误**: {llm_result.get('error', 'Unknown error')}"""
|
||
|
||
# 准备参数
|
||
parameters = {
|
||
'data_path': self.data_path,
|
||
'sample_size': len(self.data),
|
||
'variables': list(self.data.columns),
|
||
'treatment_variable': 'treatment',
|
||
'outcome_variable': 'health'
|
||
}
|
||
|
||
# 记录到日志(包含 LLM 调用信息)
|
||
import os
|
||
llm_params = {
|
||
'base_url': os.getenv("LLM_BASE_URL", "https://glm47flash.cloyir.com/v1"),
|
||
'model': os.getenv("LLM_MODEL", "qwen3.5-35b"),
|
||
'temperature': float(os.getenv("LLM_TEMPERATURE", "0.3")),
|
||
'max_tokens': int(os.getenv("LLM_MAX_TOKENS", "2048"))
|
||
}
|
||
parameters['llm_params'] = llm_params
|
||
parameters['data_path'] = self.data_path
|
||
parameters['log_path'] = self.logger.log_path
|
||
|
||
self.logger.log_analysis(system_prompt, user_prompt, model_output, parameters)
|
||
|
||
# 打印结果
|
||
print(f"\n{'='*60}")
|
||
print(f"分析 #{self.logger.analysis_count:03d} 结果:")
|
||
print(f"{'='*60}")
|
||
print(model_output)
|
||
|
||
return {
|
||
'id': f"{self.logger.analysis_count:03d}",
|
||
'system_prompt': system_prompt,
|
||
'user_prompt': user_prompt,
|
||
'model_output': model_output,
|
||
'parameters': parameters
|
||
}
|
||
|
||
def interactive_analyze(self):
|
||
"""交互式分析模式"""
|
||
print("=" * 60)
|
||
print("LLM 因果分析测试系统")
|
||
print("=" * 60)
|
||
|
||
# 加载数据
|
||
self.load_data()
|
||
|
||
while True:
|
||
print("\n" + "-" * 40)
|
||
print("请输入分析提示词(输入 'quit' 退出,'default' 使用默认):")
|
||
user_input = input("> ").strip()
|
||
|
||
if user_input.lower() == 'quit':
|
||
print("感谢使用,再见!")
|
||
break
|
||
elif user_input.lower() == 'default':
|
||
user_input = None
|
||
|
||
try:
|
||
result = self.analyze(custom_prompt=user_input)
|
||
print(f"\n分析已记录到:{self.logger.log_path}")
|
||
except Exception as e:
|
||
print(f"分析出错:{e}")
|
||
|
||
|
||
def main():
|
||
"""主函数"""
|
||
# 检查数据文件是否存在,不存在则生成
|
||
data_path = "examples/medical/data.xlsx"
|
||
if not os.path.exists(data_path):
|
||
print("数据文件不存在,正在生成...")
|
||
from data_generator import generate_medical_data
|
||
generate_medical_data(n_samples=500, output_path=data_path)
|
||
|
||
# 创建分析代理并执行分析
|
||
agent = CausalAnalysisAgent(data_path=data_path, log_path="examples/medical/log.md")
|
||
|
||
# 执行分析
|
||
result = agent.analyze()
|
||
|
||
print("\n" + "=" * 60)
|
||
print("测试完成!")
|
||
print(f"日志已保存到:{agent.logger.log_path}")
|
||
print("=" * 60)
|
||
|
||
|
||
if __name__ == "__main__":
|
||
main()
|