""" 变量解析模块 使用 LLM 识别处理变量、结果变量,并对每个变量进行时间层级解析。 """ from typing import Any, Dict, List, Optional from causal_agent.core.llm_client import LLMClient class VariableParser: """LLM 变量解析器""" SYSTEM_PROMPT = """你是一位专业的因果推断分析师。你的任务是分析给定的数据,识别处理变量(treatment)、结果变量(outcome),并对每个变量进行时间层级解析。 请以 JSON 格式输出分析结果,不要包含任何额外的解释或思考过程。 JSON 输出规范: { "treatment": "处理变量名称", "outcome": "结果变量名称", "time_tiers": { "变量名1": 整数层级, "变量名2": 整数层级, ... } } time_tiers 层级说明(整数,越小表示越早发生): - -1: 非时间变量(如样本唯一标识符 id、index 等) - 0: 人口学特征或不变的混杂因素(如 age、gender、region 等) - 1: 基线测量(干预前测得,可能是混杂因素,如 baseline_score、pre_test 等) - 2: 干预点/处理变量(如 treatment、intervention、policy 等) - 3: 中介变量(干预后、结果前测得) - 4: 随访结果/结果变量(如 outcome、post_test、score 等) - 5+: 更晚的时间点(如有多次随访) 注意: - 只输出上述 JSON 格式,不要包含其他字段 - treatment 和 outcome 必须是数据表格中真实存在的列名 - time_tiers 必须包含数据中的所有列名 - 不要使用 markdown 代码块标记(如 ```json) - 直接输出纯 JSON 字符串""" def __init__(self, llm_client: Optional[LLMClient] = None): self.llm_client = llm_client or LLMClient() def parse( self, columns: List[str], data_info: str, custom_prompt: Optional[str] = None, ) -> Dict[str, Any]: """ 调用 LLM 解析变量。 Args: columns: 数据列名列表 data_info: 数据信息摘要字符串 custom_prompt: 可选的自定义用户提示词补充 Returns: 包含 treatment、outcome、time_tiers 的字典 """ user_prompt = self._build_user_prompt(columns, data_info, custom_prompt) result = self.llm_client.generate_response(self.SYSTEM_PROMPT, user_prompt) if not result.get("success"): raise RuntimeError(f"LLM 调用失败: {result.get('error')}") try: parsed = LLMClient.parse_json_response(result["content"]) except Exception as e: raise ValueError(f"LLM 输出解析失败: {e}\n原始内容: {result.get('content')}") self._validate(parsed, columns) return parsed def _build_user_prompt( self, columns: List[str], data_info: str, custom_prompt: Optional[str] = None, ) -> str: """构建用户提示词""" parts = [ "请分析以下数据,并严格按照 JSON 格式输出分析结果:", "", data_info, "", "JSON 输出格式要求:", '{', ' "treatment": "处理变量名称",', ' "outcome": "结果变量名称",', ' "time_tiers": {', ' "列名1": 层级整数,', ' "列名2": 层级整数,', " ...", " }", "}", "", "要求:", "1. treatment 和 outcome 必须与表格列名完全一致", "2. time_tiers 必须覆盖所有列名", "3. 根据列名含义和统计摘要推断每个变量的时间层级", "4. 只输出 JSON,不要包含其他任何内容", "5. 不要使用 markdown 代码块标记", ] if custom_prompt: parts.insert(1, f"\n用户补充说明:{custom_prompt}\n") return "\n".join(parts) def _validate(self, parsed: Dict[str, Any], columns: List[str]): """验证 LLM 返回结果""" treatment = parsed.get("treatment") outcome = parsed.get("outcome") tiers = parsed.get("time_tiers", {}) if treatment not in columns: raise ValueError(f"LLM 返回的 treatment '{treatment}' 不在数据列中") if outcome not in columns: raise ValueError(f"LLM 返回的 outcome '{outcome}' 不在数据列中") if not isinstance(tiers, dict): raise ValueError("time_tiers 必须是字典类型") missing = set(columns) - set(tiers.keys()) if missing: raise ValueError(f"time_tiers 缺少以下列: {sorted(missing)}") for col, tier in tiers.items(): if not isinstance(tier, int): raise ValueError(f"time_tiers 中 '{col}' 的层级必须是整数")