2026-03-29 23:47:20 +08:00

134 lines
4.8 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
变量解析模块
使用 LLM 识别处理变量、结果变量,并对每个变量进行时间层级解析。
"""
from typing import Any, Dict, List, Optional
from causal_agent.core.llm_client import LLMClient
class VariableParser:
"""LLM 变量解析器"""
SYSTEM_PROMPT = """你是一位专业的因果推断分析师。你的任务是分析给定的数据识别处理变量treatment、结果变量outcome并对每个变量进行时间层级解析。
请以 JSON 格式输出分析结果,不要包含任何额外的解释或思考过程。
JSON 输出规范:
{
"treatment": "处理变量名称",
"outcome": "结果变量名称",
"time_tiers": {
"变量名1": 整数层级,
"变量名2": 整数层级,
...
}
}
time_tiers 层级说明(整数,越小表示越早发生):
- -1: 非时间变量(如样本唯一标识符 id、index 等)
- 0: 人口学特征或不变的混杂因素(如 age、gender、region 等)
- 1: 基线测量(干预前测得,可能是混杂因素,如 baseline_score、pre_test 等)
- 2: 干预点/处理变量(如 treatment、intervention、policy 等)
- 3: 中介变量(干预后、结果前测得)
- 4: 随访结果/结果变量(如 outcome、post_test、score 等)
- 5+: 更晚的时间点(如有多次随访)
注意:
- 只输出上述 JSON 格式,不要包含其他字段
- treatment 和 outcome 必须是数据表格中真实存在的列名
- time_tiers 必须包含数据中的所有列名
- 不要使用 markdown 代码块标记(如 ```json
- 直接输出纯 JSON 字符串"""
def __init__(self, llm_client: Optional[LLMClient] = None):
self.llm_client = llm_client or LLMClient()
def parse(
self,
columns: List[str],
data_info: str,
custom_prompt: Optional[str] = None,
) -> Dict[str, Any]:
"""
调用 LLM 解析变量。
Args:
columns: 数据列名列表
data_info: 数据信息摘要字符串
custom_prompt: 可选的自定义用户提示词补充
Returns:
包含 treatment、outcome、time_tiers 的字典
"""
user_prompt = self._build_user_prompt(columns, data_info, custom_prompt)
result = self.llm_client.generate_response(self.SYSTEM_PROMPT, user_prompt)
if not result.get("success"):
raise RuntimeError(f"LLM 调用失败: {result.get('error')}")
try:
parsed = LLMClient.parse_json_response(result["content"])
except Exception as e:
raise ValueError(f"LLM 输出解析失败: {e}\n原始内容: {result.get('content')}")
self._validate(parsed, columns)
return parsed
def _build_user_prompt(
self,
columns: List[str],
data_info: str,
custom_prompt: Optional[str] = None,
) -> str:
"""构建用户提示词"""
parts = [
"请分析以下数据,并严格按照 JSON 格式输出分析结果:",
"",
data_info,
"",
"JSON 输出格式要求:",
'{',
' "treatment": "处理变量名称",',
' "outcome": "结果变量名称",',
' "time_tiers": {',
' "列名1": 层级整数,',
' "列名2": 层级整数,',
" ...",
" }",
"}",
"",
"要求:",
"1. treatment 和 outcome 必须与表格列名完全一致",
"2. time_tiers 必须覆盖所有列名",
"3. 根据列名含义和统计摘要推断每个变量的时间层级",
"4. 只输出 JSON不要包含其他任何内容",
"5. 不要使用 markdown 代码块标记",
]
if custom_prompt:
parts.insert(1, f"\n用户补充说明:{custom_prompt}\n")
return "\n".join(parts)
def _validate(self, parsed: Dict[str, Any], columns: List[str]):
"""验证 LLM 返回结果"""
treatment = parsed.get("treatment")
outcome = parsed.get("outcome")
tiers = parsed.get("time_tiers", {})
if treatment not in columns:
raise ValueError(f"LLM 返回的 treatment '{treatment}' 不在数据列中")
if outcome not in columns:
raise ValueError(f"LLM 返回的 outcome '{outcome}' 不在数据列中")
if not isinstance(tiers, dict):
raise ValueError("time_tiers 必须是字典类型")
missing = set(columns) - set(tiers.keys())
if missing:
raise ValueError(f"time_tiers 缺少以下列: {sorted(missing)}")
for col, tier in tiers.items():
if not isinstance(tier, int):
raise ValueError(f"time_tiers 中 '{col}' 的层级必须是整数")