2026-03-29 23:47:20 +08:00

357 lines
11 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
LLM 因果分析测试入口
使用自然语言输入让 LLM 测试分析因果变量,并将所有参数记录到 log.md 文档中
"""
import os
import json
from datetime import datetime
from typing import Dict, Any, Optional
import pandas as pd
from llm_client import LLMClient
class CausalAnalysisLogger:
"""因果分析日志记录器"""
def __init__(self, log_path: str = "examples/medical/log.md"):
self.log_path = log_path
self.log_entries = []
self._init_log_file()
self._load_existing_entries()
def _init_log_file(self):
"""初始化日志文件"""
header = """# 因果分析日志
## 日志说明
本文档记录 LLM 进行因果分析时的所有输入参数和输出结果。
## 分析记录
"""
if not os.path.exists(self.log_path):
with open(self.log_path, 'w', encoding='utf-8') as f:
f.write(header)
def _load_existing_entries(self):
"""从日志文件中加载已存在的分析记录,以确定下一个分析 ID"""
if not os.path.exists(self.log_path):
return
analysis_count = 0
with open(self.log_path, 'r', encoding='utf-8') as f:
content = f.read()
# 统计已存在的分析记录数量
import re
matches = re.findall(r'### 分析 #(\d+)', content)
if matches:
analysis_count = max(int(m) for m in matches)
self.analysis_count = analysis_count
def _append_to_log(self, content: str):
"""追加内容到日志文件"""
with open(self.log_path, 'a', encoding='utf-8') as f:
f.write(content)
def log_analysis(self, system_prompt: str, user_prompt: str,
model_output: str, parameters: Dict[str, Any]):
"""记录一次完整的分析过程"""
# 增加分析计数
self.analysis_count += 1
analysis_id = f"{self.analysis_count:03d}"
timestamp = datetime.now().isoformat()
entry = f"""
---
### 分析 #{analysis_id}
**时间**: {timestamp}
#### 系统提示词
```
{system_prompt}
```
#### 用户提示词
```
{user_prompt}
```
#### LLM 输出
```
{model_output}
```
#### 调用参数
```json
{json.dumps(parameters, indent=2, ensure_ascii=False)}
```
---
"""
self._append_to_log(entry)
self.log_entries.append({
'id': analysis_id,
'timestamp': timestamp,
'system_prompt': system_prompt,
'user_prompt': user_prompt,
'model_output': model_output,
'parameters': parameters
})
print(f"分析 #{analysis_id} 已记录到 {self.log_path}")
class CausalAnalysisAgent:
"""因果分析代理"""
def __init__(self, data_path: str = "examples/medical/data.xlsx", log_path: str = "examples/medical/log.md"):
self.data_path = data_path
self.logger = CausalAnalysisLogger(log_path)
self.data = None
def load_data(self) -> pd.DataFrame:
"""加载数据"""
if not os.path.exists(self.data_path):
raise FileNotFoundError(f"数据文件不存在:{self.data_path}")
self.data = pd.read_excel(self.data_path)
print(f"成功加载数据:{self.data_path}")
print(f"数据形状:{self.data.shape}")
print(f"列名:{list(self.data.columns)}")
print(f"\n数据预览:")
print(self.data.head())
return self.data
def _generate_system_prompt(self) -> str:
"""生成系统提示词"""
return """你是一位专业的因果推断分析师。你的任务是分析给定的数据,识别因果变量并评估因果关系。
请以 JSON 格式输出分析结果,不要包含任何额外的解释或思考过程。
JSON 输出规范:
{
"treatment": "处理变量名称",
"outcome": "结果变量名称"
}
注意:
- 只输出上述 JSON 格式,不要包含其他字段
- 处理变量和结果变量名称必须与数据表格的列名完全一致
- 不要使用 markdown 代码块标记(如 ```json
- 直接输出纯 JSON 字符串"""
def _generate_user_prompt(self, data_info: str) -> str:
"""生成用户提示词"""
return f"""请分析以下医疗数据,并严格按照 JSON 格式输出分析结果:
{data_info}
JSON 输出格式要求:
{{
"treatment": "处理变量名称",
"outcome": "结果变量名称"
}}
要求:
1. 处理变量和结果变量名称必须与表格列名完全一致
2. 只输出 JSON不要包含其他任何内容
3. 不要使用 markdown 代码块标记"""
def _call_llm(self, system_prompt: str, user_prompt: str) -> Dict[str, Any]:
"""
调用 LLM API 获取响应
Args:
system_prompt: 系统提示词
user_prompt: 用户提示词
Returns:
包含响应内容的字典
"""
# 初始化 LLM 客户端(从环境变量读取配置)
import os
llm_client = LLMClient(
base_url=os.getenv("LLM_BASE_URL", "https://glm47flash.cloyir.com/v1"),
model=os.getenv("LLM_MODEL", "qwen3.5-35b"),
temperature=float(os.getenv("LLM_TEMPERATURE", "0.3")),
max_tokens=int(os.getenv("LLM_MAX_TOKENS", "2048"))
)
# 调用 LLM
result = llm_client.generate_response(system_prompt, user_prompt)
return result
def analyze(self, custom_prompt: Optional[str] = None) -> Dict[str, Any]:
"""
执行因果分析
Args:
custom_prompt: 自定义用户提示词,如果为 None 则使用默认提示词
Returns:
分析结果字典
"""
# 加载数据
if self.data is None:
self.load_data()
# 类型断言:确保 self.data 不是 None
assert self.data is not None, "数据未加载"
# 生成数据信息
data_info = f"""
**数据概览:**
- 样本数量:{len(self.data)}
- 变量:{', '.join(self.data.columns)}
**统计摘要:**
{self.data.describe()}
**处理变量分布:**
{self.data['treatment'].value_counts().to_string()}
"""
# 生成提示词
system_prompt = self._generate_system_prompt()
user_prompt = custom_prompt if custom_prompt else self._generate_user_prompt(data_info)
# 计算统计值
treatment_mean = self.data[self.data['treatment'] == 1]['health'].mean() if self.data['treatment'].sum() > 0 else 0
control_mean = self.data[self.data['treatment'] == 0]['health'].mean() if self.data['treatment'].sum() < len(self.data) else 0
ate = treatment_mean - control_mean if self.data['treatment'].sum() > 0 and self.data['treatment'].sum() < len(self.data) else 0
# 调用 LLM 获取真实响应
llm_result = self._call_llm(system_prompt, user_prompt)
# 打印 LLM 调用结果用于调试
print(f"\nLLM 调用结果:{llm_result}")
if llm_result.get("success"):
model_output = llm_result.get("content", "LLM 返回内容为空")
else:
# 如果 LLM 调用失败,使用模拟响应
model_output = f"""## 因果分析结果LLM 调用失败,使用模拟响应)
### 1. 变量识别
- **处理变量 (Treatment)**: treatment (是否吃药)
- 类型:二元变量 (0/1)
- 含义1 表示吃药0 表示不吃药
- **结果变量 (Outcome)**: health (健康状态)
- 类型:连续变量 (0-1 浮点数)
- 含义:健康状态评分,越高越好
### 2. 因果效应估计
根据数据描述性统计:
- 吃药组 (treatment=1) 的平均健康状态:{treatment_mean:.4f}
- 未吃药组 (treatment=0) 的平均健康状态:{control_mean:.4f}
- **平均处理效应 (ATE)**: {ate:.4f}
### 3. 分析方法
使用简单的组间比较方法估计因果效应:
- 方法:均值差异 (Mean Difference)
- 假设:无混杂因素或混杂因素已控制
### 4. 结论
吃药对健康状态有正向的因果效应,估计效应大小为 {ate:.4f}
这意味着吃药可以使健康状态平均提高 {ate * 100:.1f}%。
### 5. 建议
- 建议进行更严格的因果推断分析(如倾向得分匹配、工具变量法等)
- 考虑控制可能的混杂因素(如年龄、基础健康状况等)
---
**LLM 调用错误**: {llm_result.get('error', 'Unknown error')}"""
# 准备参数
parameters = {
'data_path': self.data_path,
'sample_size': len(self.data),
'variables': list(self.data.columns),
'treatment_variable': 'treatment',
'outcome_variable': 'health'
}
# 记录到日志(包含 LLM 调用信息)
import os
llm_params = {
'base_url': os.getenv("LLM_BASE_URL", "https://glm47flash.cloyir.com/v1"),
'model': os.getenv("LLM_MODEL", "qwen3.5-35b"),
'temperature': float(os.getenv("LLM_TEMPERATURE", "0.3")),
'max_tokens': int(os.getenv("LLM_MAX_TOKENS", "2048"))
}
parameters['llm_params'] = llm_params
parameters['data_path'] = self.data_path
parameters['log_path'] = self.logger.log_path
self.logger.log_analysis(system_prompt, user_prompt, model_output, parameters)
# 打印结果
print(f"\n{'='*60}")
print(f"分析 #{self.logger.analysis_count:03d} 结果:")
print(f"{'='*60}")
print(model_output)
return {
'id': f"{self.logger.analysis_count:03d}",
'system_prompt': system_prompt,
'user_prompt': user_prompt,
'model_output': model_output,
'parameters': parameters
}
def interactive_analyze(self):
"""交互式分析模式"""
print("=" * 60)
print("LLM 因果分析测试系统")
print("=" * 60)
# 加载数据
self.load_data()
while True:
print("\n" + "-" * 40)
print("请输入分析提示词(输入 'quit' 退出,'default' 使用默认):")
user_input = input("> ").strip()
if user_input.lower() == 'quit':
print("感谢使用,再见!")
break
elif user_input.lower() == 'default':
user_input = None
try:
result = self.analyze(custom_prompt=user_input)
print(f"\n分析已记录到:{self.logger.log_path}")
except Exception as e:
print(f"分析出错:{e}")
def main():
"""主函数"""
# 检查数据文件是否存在,不存在则生成
data_path = "examples/medical/data.xlsx"
if not os.path.exists(data_path):
print("数据文件不存在,正在生成...")
from data_generator import generate_medical_data
generate_medical_data(n_samples=500, output_path=data_path)
# 创建分析代理并执行分析
agent = CausalAnalysisAgent(data_path=data_path, log_path="examples/medical/log.md")
# 执行分析
result = agent.analyze()
print("\n" + "=" * 60)
print("测试完成!")
print(f"日志已保存到:{agent.logger.log_path}")
print("=" * 60)
if __name__ == "__main__":
main()