2026-03-29 23:47:20 +08:00

201 lines
6.5 KiB
Python

"""
数据验证模块
用于检查数据列是否匹配 LLM 识别的处理变量和结果变量
"""
import pandas as pd
from typing import Dict, Any, Tuple, Optional, List
class ValidationResult:
    """
    Plain container describing the outcome of one validation run.

    Attributes:
        is_valid: Overall verdict, as supplied by the caller.
        treatment: Treatment column name that was checked, or None.
        outcome: Outcome column name that was checked, or None.
        errors: Messages describing fatal problems.
        warnings: Messages describing non-fatal concerns.
    """

    def __init__(
        self,
        is_valid: bool,
        treatment: Optional[str] = None,
        outcome: Optional[str] = None,
        errors: Optional[List[str]] = None,
        warnings: Optional[List[str]] = None
    ):
        self.is_valid = is_valid
        self.treatment, self.outcome = treatment, outcome
        # None or an empty list becomes a fresh list; a non-empty list is
        # stored as the same object (not copied).
        self.errors = errors if errors else []
        self.warnings = warnings if warnings else []
class DataValidator:
    """
    Check that the treatment/outcome columns named by an LLM exist in a
    dataset and look usable for causal analysis.

    The dataset is read lazily from ``data_path`` on first use and cached in
    ``self.data``; assigning a DataFrame to ``self.data`` beforehand skips the
    file read entirely.
    """

    def __init__(self, data_path: str, min_sample_size: int = 10):
        """
        Initialise the validator without reading any data yet.

        Args:
            data_path: Path to the data file. Paths ending in ``.csv``
                (case-insensitive) are read with ``pd.read_csv``; anything
                else is read with ``pd.read_excel``.
            min_sample_size: Row count below which a small-sample warning is
                emitted. Defaults to 10.
        """
        self.data_path = data_path
        self.min_sample_size = min_sample_size
        # Cached DataFrame; populated by load_data().
        self.data: Optional[pd.DataFrame] = None

    def load_data(self) -> pd.DataFrame:
        """
        Return the dataset, reading it from ``data_path`` on the first call.

        Returns:
            The cached DataFrame.

        Raises:
            Whatever the pandas reader raises (propagated unchanged).
        """
        if self.data is None:
            if str(self.data_path).lower().endswith(".csv"):
                self.data = pd.read_csv(self.data_path)
            else:
                self.data = pd.read_excel(self.data_path)
        return self.data

    @staticmethod
    def _treatment_warnings(name: Any, column: pd.Series) -> List[str]:
        """Return cardinality and dtype warnings for the treatment column.

        The two checks are independent, so one column can receive both.
        """
        warnings: List[str] = []
        n_unique = int(column.nunique())  # NaN is not counted
        if n_unique < 2:
            # A constant column has no contrast between treated and untreated.
            warnings.append(f"处理变量 '{name}' 只有 {n_unique} 个唯一值,无法比较处理组与对照组")
        elif n_unique > 2:
            warnings.append(f"处理变量 '{name}' 有多个唯一值,可能不是二元变量")

        dtype = column.dtype
        types = pd.api.types
        # Numeric (which includes bool), object/string and categorical encodings are accepted.
        usable = (
            types.is_numeric_dtype(dtype)
            or types.is_object_dtype(dtype)
            or types.is_string_dtype(dtype)
            or isinstance(dtype, pd.CategoricalDtype)
        )
        if not usable:
            warnings.append(f"处理变量 '{name}' 的数据类型可能不适合因果分析")
        return warnings

    @staticmethod
    def _outcome_warnings(name: Any, column: pd.Series) -> List[str]:
        """Return a dtype warning unless the outcome column is numeric (bool counts as numeric)."""
        if pd.api.types.is_numeric_dtype(column.dtype):
            return []
        return [f"结果变量 '{name}' 的数据类型可能不适合因果分析"]

    @staticmethod
    def _missing_value_warnings(data: pd.DataFrame, treatment: Any, outcome: Any) -> List[str]:
        """Return one warning per named, existing column that contains missing values."""
        available_columns = list(data.columns)
        warnings: List[str] = []
        for label, name in (("处理变量", treatment), ("结果变量", outcome)):
            if name is None or name not in available_columns:
                continue
            n_missing = int(data[name].isna().sum())
            if n_missing > 0:
                warnings.append(f"{label} '{name}' 有 {n_missing} 个缺失值")
        return warnings

    def validate_columns(
        self,
        llm_result: Dict[str, Any]
    ) -> "ValidationResult":
        """
        Check the LLM-identified treatment/outcome variables against the data.

        Errors (they make ``is_valid`` False):
        - the data cannot be loaded;
        - a variable is absent from ``llm_result``;
        - a named column does not exist;
        - treatment and outcome name the same column.

        Warnings (they do not affect ``is_valid``):
        - suspicious dtypes or cardinality;
        - fewer rows than ``min_sample_size``;
        - missing values in the chosen columns.

        Args:
            llm_result: Parsed LLM JSON output. The ``treatment`` and
                ``outcome`` keys are read, and ``None`` is treated as an
                empty dict.

        Returns:
            A ValidationResult holding the variable names, errors and warnings.
        """
        try:
            data = self.load_data()
        except Exception as e:
            # Deliberate boundary: any read failure is reported as a
            # validation error instead of propagating.
            return ValidationResult(
                is_valid=False,
                treatment=None,
                outcome=None,
                errors=[f"加载数据失败:{str(e)}"],
                warnings=[]
            )

        llm_result = llm_result or {}
        available_columns = list(data.columns)
        treatment = llm_result.get('treatment')
        outcome = llm_result.get('outcome')

        errors: List[str] = []
        warnings: List[str] = []
        # List the available columns at most once, even if both names are missing.
        columns_listed = False

        if treatment is None:
            errors.append("LLM 未识别处理变量 (treatment)")
        elif treatment not in available_columns:
            errors.append(f"数据中不存在处理变量 '{treatment}'")
            errors.append(f"可用列名:{available_columns}")
            columns_listed = True
        else:
            warnings.extend(self._treatment_warnings(treatment, data[treatment]))

        if outcome is None:
            errors.append("LLM 未识别结果变量 (outcome)")
        elif outcome not in available_columns:
            errors.append(f"数据中不存在结果变量 '{outcome}'")
            if not columns_listed:
                errors.append(f"可用列名:{available_columns}")
        else:
            warnings.extend(self._outcome_warnings(outcome, data[outcome]))

        # One column cannot be both the cause and the effect.
        if treatment is not None and treatment == outcome:
            errors.append(f"处理变量和结果变量不能是同一列 '{treatment}'")

        if len(data) < self.min_sample_size:
            warnings.append(f"样本量 ({len(data)}) 较小,可能影响分析结果")

        warnings.extend(self._missing_value_warnings(data, treatment, outcome))

        return ValidationResult(
            is_valid=not errors,
            treatment=treatment,
            outcome=outcome,
            errors=errors,
            warnings=warnings
        )

    def validate_and_raise(
        self,
        llm_result: Dict[str, Any]
    ) -> Tuple[Optional[str], Optional[str]]:
        """
        Validate and raise if the result is invalid.

        Args:
            llm_result: Parsed LLM JSON output.

        Returns:
            The ``(treatment, outcome)`` tuple. Both are non-None on success,
            because a missing name always records an error.

        Raises:
            ValueError: If validation fails. The message lists every error,
                one per line.
        """
        result = self.validate_columns(llm_result)
        if not result.is_valid:
            details = "".join(f" - {error}\n" for error in result.errors)
            raise ValueError("数据验证失败:\n" + details)
        return result.treatment, result.outcome
# 便捷函数
def validate_data(data_path: str, llm_result: Dict[str, Any]) -> ValidationResult:
    """
    Convenience wrapper: validate ``llm_result`` against the file at ``data_path``.

    Builds a throwaway DataValidator and returns its validate_columns() result.

    Args:
        data_path: Path to the data file.
        llm_result: Parsed LLM JSON output.

    Returns:
        The ValidationResult produced by DataValidator.validate_columns.
    """
    return DataValidator(data_path).validate_columns(llm_result)
if __name__ == "__main__":
    # Manual smoke test against the bundled example dataset.
    print("测试数据验证模块")
    print("=" * 50)

    # Test 1: both names should exist as columns.
    print("\n测试 1: 有效数据")
    checker = DataValidator("examples/medical/data.xlsx")
    good_request = {"treatment": "treatment", "outcome": "health"}
    report = checker.validate_columns(good_request)
    print(f" 是否有效:{report.is_valid}")
    print(f" 处理变量:{report.treatment}")
    print(f" 结果变量:{report.outcome}")
    print(f" 错误:{report.errors}")
    print(f" 警告:{report.warnings}")

    # Test 2: the treatment name does not exist in the data.
    print("\n测试 2: 无效数据(变量名不存在)")
    bad_request = {"treatment": "invalid_col", "outcome": "health"}
    bad_report = checker.validate_columns(bad_request)
    print(f" 是否有效:{bad_report.is_valid}")
    print(f" 错误:{bad_report.errors}")