134 lines
4.8 KiB
Python
134 lines
4.8 KiB
Python
"""
|
||
变量解析模块
|
||
|
||
使用 LLM 识别处理变量、结果变量,并对每个变量进行时间层级解析。
|
||
"""
|
||
|
||
from typing import Any, Dict, List, Optional
|
||
|
||
from causal_agent.core.llm_client import LLMClient
|
||
|
||
|
||
class VariableParser:
|
||
"""LLM 变量解析器"""
|
||
|
||
SYSTEM_PROMPT = """你是一位专业的因果推断分析师。你的任务是分析给定的数据,识别处理变量(treatment)、结果变量(outcome),并对每个变量进行时间层级解析。
|
||
|
||
请以 JSON 格式输出分析结果,不要包含任何额外的解释或思考过程。
|
||
|
||
JSON 输出规范:
|
||
{
|
||
"treatment": "处理变量名称",
|
||
"outcome": "结果变量名称",
|
||
"time_tiers": {
|
||
"变量名1": 整数层级,
|
||
"变量名2": 整数层级,
|
||
...
|
||
}
|
||
}
|
||
|
||
time_tiers 层级说明(整数,越小表示越早发生):
|
||
- -1: 非时间变量(如样本唯一标识符 id、index 等)
|
||
- 0: 人口学特征或不变的混杂因素(如 age、gender、region 等)
|
||
- 1: 基线测量(干预前测得,可能是混杂因素,如 baseline_score、pre_test 等)
|
||
- 2: 干预点/处理变量(如 treatment、intervention、policy 等)
|
||
- 3: 中介变量(干预后、结果前测得)
|
||
- 4: 随访结果/结果变量(如 outcome、post_test、score 等)
|
||
- 5+: 更晚的时间点(如有多次随访)
|
||
|
||
注意:
|
||
- 只输出上述 JSON 格式,不要包含其他字段
|
||
- treatment 和 outcome 必须是数据表格中真实存在的列名
|
||
- time_tiers 必须包含数据中的所有列名
|
||
- 不要使用 markdown 代码块标记(如 ```json)
|
||
- 直接输出纯 JSON 字符串"""
|
||
|
||
def __init__(self, llm_client: Optional[LLMClient] = None):
|
||
self.llm_client = llm_client or LLMClient()
|
||
|
||
def parse(
|
||
self,
|
||
columns: List[str],
|
||
data_info: str,
|
||
custom_prompt: Optional[str] = None,
|
||
) -> Dict[str, Any]:
|
||
"""
|
||
调用 LLM 解析变量。
|
||
|
||
Args:
|
||
columns: 数据列名列表
|
||
data_info: 数据信息摘要字符串
|
||
custom_prompt: 可选的自定义用户提示词补充
|
||
|
||
Returns:
|
||
包含 treatment、outcome、time_tiers 的字典
|
||
"""
|
||
user_prompt = self._build_user_prompt(columns, data_info, custom_prompt)
|
||
result = self.llm_client.generate_response(self.SYSTEM_PROMPT, user_prompt)
|
||
|
||
if not result.get("success"):
|
||
raise RuntimeError(f"LLM 调用失败: {result.get('error')}")
|
||
|
||
try:
|
||
parsed = LLMClient.parse_json_response(result["content"])
|
||
except Exception as e:
|
||
raise ValueError(f"LLM 输出解析失败: {e}\n原始内容: {result.get('content')}")
|
||
|
||
self._validate(parsed, columns)
|
||
return parsed
|
||
|
||
def _build_user_prompt(
|
||
self,
|
||
columns: List[str],
|
||
data_info: str,
|
||
custom_prompt: Optional[str] = None,
|
||
) -> str:
|
||
"""构建用户提示词"""
|
||
parts = [
|
||
"请分析以下数据,并严格按照 JSON 格式输出分析结果:",
|
||
"",
|
||
data_info,
|
||
"",
|
||
"JSON 输出格式要求:",
|
||
'{',
|
||
' "treatment": "处理变量名称",',
|
||
' "outcome": "结果变量名称",',
|
||
' "time_tiers": {',
|
||
' "列名1": 层级整数,',
|
||
' "列名2": 层级整数,',
|
||
" ...",
|
||
" }",
|
||
"}",
|
||
"",
|
||
"要求:",
|
||
"1. treatment 和 outcome 必须与表格列名完全一致",
|
||
"2. time_tiers 必须覆盖所有列名",
|
||
"3. 根据列名含义和统计摘要推断每个变量的时间层级",
|
||
"4. 只输出 JSON,不要包含其他任何内容",
|
||
"5. 不要使用 markdown 代码块标记",
|
||
]
|
||
if custom_prompt:
|
||
parts.insert(1, f"\n用户补充说明:{custom_prompt}\n")
|
||
return "\n".join(parts)
|
||
|
||
def _validate(self, parsed: Dict[str, Any], columns: List[str]):
|
||
"""验证 LLM 返回结果"""
|
||
treatment = parsed.get("treatment")
|
||
outcome = parsed.get("outcome")
|
||
tiers = parsed.get("time_tiers", {})
|
||
|
||
if treatment not in columns:
|
||
raise ValueError(f"LLM 返回的 treatment '{treatment}' 不在数据列中")
|
||
if outcome not in columns:
|
||
raise ValueError(f"LLM 返回的 outcome '{outcome}' 不在数据列中")
|
||
if not isinstance(tiers, dict):
|
||
raise ValueError("time_tiers 必须是字典类型")
|
||
|
||
missing = set(columns) - set(tiers.keys())
|
||
if missing:
|
||
raise ValueError(f"time_tiers 缺少以下列: {sorted(missing)}")
|
||
|
||
for col, tier in tiers.items():
|
||
if not isinstance(tier, int):
|
||
raise ValueError(f"time_tiers 中 '{col}' 的层级必须是整数")
|