201 lines
6.5 KiB
Python
201 lines
6.5 KiB
Python
"""
Data validation module.

Checks whether the columns of a dataset match the treatment and outcome
variables identified by the LLM.
"""
|
|
|
|
import pandas as pd
|
|
from typing import Dict, Any, Tuple, Optional, List
|
|
|
|
|
|
class ValidationResult:
    """Outcome of a column validation run.

    Attributes:
        is_valid: Whether validation passed.
        treatment: Name of the treatment variable, if any.
        outcome: Name of the outcome variable, if any.
        errors: Error messages; an empty list when none were supplied.
        warnings: Warning messages; an empty list when none were supplied.
    """

    def __init__(
        self,
        is_valid: bool,
        treatment: Optional[str] = None,
        outcome: Optional[str] = None,
        errors: Optional[List[str]] = None,
        warnings: Optional[List[str]] = None
    ):
        """Store the validation verdict together with its messages.

        Args:
            is_valid: Whether validation passed.
            treatment: Name of the treatment variable.
            outcome: Name of the outcome variable.
            errors: Error messages; a falsy value (``None`` or an empty
                list) is replaced by a fresh empty list.
            warnings: Warning messages; a falsy value is replaced by a
                fresh empty list.
        """
        self.is_valid = is_valid
        self.treatment = treatment
        self.outcome = outcome
        # A fresh list per instance avoids sharing one default between results.
        self.errors = errors if errors else []
        self.warnings = warnings if warnings else []
|
|
|
|
|
|
class DataValidator:
    """Validate that a dataset contains the variables named by an LLM.

    The dataset is read lazily from ``data_path`` with ``pandas.read_excel``
    and cached on the instance, so repeated validations reuse one DataFrame.

    Attributes:
        MIN_SAMPLE_SIZE: Row count below which a small-sample warning is
            emitted; override on a subclass or instance to tune it.
        data_path: Path of the data file to validate.
        data: Cached DataFrame, or ``None`` until ``load_data`` has run.
    """

    MIN_SAMPLE_SIZE = 10

    def __init__(self, data_path: str):
        """Create a validator for the file at ``data_path``.

        Args:
            data_path: Path of the data file; it is not opened until
                ``load_data`` is first called.
        """
        self.data_path = data_path
        self.data: Optional[pd.DataFrame] = None

    def load_data(self) -> pd.DataFrame:
        """Return the dataset, reading it from disk on the first call only."""
        if self.data is None:
            self.data = pd.read_excel(self.data_path)
        return self.data

    @staticmethod
    def _column_errors(
        name: Any,
        label: str,
        key: str,
        available_columns: List[Any],
    ) -> List[str]:
        """Return the errors for one requested column (empty when present)."""
        if name is None:
            return [f"LLM 未识别{label} ({key})"]
        if name not in available_columns:
            return [
                f"数据中不存在{label} '{name}'",
                f"可用列名:{available_columns}",
            ]
        return []

    @staticmethod
    def _is_supported_treatment_dtype(dtype: Any) -> bool:
        """Tell whether ``dtype`` is numeric, boolean or object."""
        types = pd.api.types
        return bool(
            types.is_numeric_dtype(dtype)
            or types.is_bool_dtype(dtype)
            or types.is_object_dtype(dtype)
        )

    @staticmethod
    def _is_supported_outcome_dtype(dtype: Any) -> bool:
        """Tell whether ``dtype`` is numeric but not boolean."""
        types = pd.api.types
        # pandas counts bool as numeric; a boolean outcome is still flagged.
        return bool(
            types.is_numeric_dtype(dtype) and not types.is_bool_dtype(dtype)
        )

    def validate_columns(
        self,
        llm_result: Dict[str, Any]
    ) -> ValidationResult:
        """Check the dataset against the variables identified by the LLM.

        Args:
            llm_result: Parsed LLM JSON output, expected to be a dict with
                ``treatment`` and ``outcome`` entries.

        Returns:
            A ``ValidationResult``. It is invalid when the data cannot be
            loaded, when ``llm_result`` is not a dict, or when either
            variable is missing from ``llm_result`` or from the data columns.
            Warnings cover non-binary treatment, unsuitable dtypes, small
            samples and missing values; they never affect validity.
        """
        try:
            data = self.load_data()
        except Exception as e:
            # Deliberately broad: any load failure (missing file, unreadable
            # format, missing reader engine) is reported as a validation
            # error rather than propagated.
            return ValidationResult(
                is_valid=False,
                treatment=None,
                outcome=None,
                errors=[f"加载数据失败:{str(e)}"],
                warnings=[],
            )

        # LLM output is untrusted; a non-dict payload has no ``.get``.
        if not isinstance(llm_result, dict):
            return ValidationResult(
                is_valid=False,
                errors=[
                    "LLM 结果格式无效:期望字典,实际为 "
                    f"{type(llm_result).__name__}"
                ],
            )

        available_columns = list(data.columns)
        treatment = llm_result.get('treatment')
        outcome = llm_result.get('outcome')

        errors: List[str] = []
        warnings: List[str] = []

        # Treatment variable: presence first, then binary/dtype sanity checks.
        treatment_errors = self._column_errors(
            treatment, "处理变量", "treatment", available_columns
        )
        errors.extend(treatment_errors)
        has_treatment = not treatment_errors
        if has_treatment:
            treatment_col = data[treatment]
            if treatment_col.nunique() > 2:
                warnings.append(
                    f"处理变量 '{treatment}' 有多个唯一值,可能不是二元变量"
                )
            elif not self._is_supported_treatment_dtype(treatment_col.dtype):
                warnings.append(
                    f"处理变量 '{treatment}' 的数据类型可能不适合因果分析"
                )

        # Outcome variable: presence first, then the numeric dtype check.
        outcome_errors = self._column_errors(
            outcome, "结果变量", "outcome", available_columns
        )
        errors.extend(outcome_errors)
        has_outcome = not outcome_errors
        if has_outcome and not self._is_supported_outcome_dtype(
            data[outcome].dtype
        ):
            warnings.append(
                f"结果变量 '{outcome}' 的数据类型可能不适合因果分析"
            )

        # Sample size.
        if len(data) < self.MIN_SAMPLE_SIZE:
            warnings.append(f"样本量 ({len(data)}) 较小,可能影响分析结果")

        # Missing values, only for columns that actually exist.
        checks = (
            (has_treatment, "处理变量", treatment),
            (has_outcome, "结果变量", outcome),
        )
        for found, label, name in checks:
            if not found:
                continue
            missing = int(data[name].isna().sum())
            if missing > 0:
                warnings.append(f"{label} '{name}' 有 {missing} 个缺失值")

        return ValidationResult(
            is_valid=not errors,
            treatment=treatment,
            outcome=outcome,
            errors=errors,
            warnings=warnings,
        )

    def validate_and_raise(
        self,
        llm_result: Dict[str, Any]
    ) -> Tuple[Optional[str], Optional[str]]:
        """Validate and raise if validation fails.

        Args:
            llm_result: Parsed LLM JSON output.

        Returns:
            The ``(treatment, outcome)`` pair when validation succeeds.

        Raises:
            ValueError: If validation fails; the message lists every error.
        """
        result = self.validate_columns(llm_result)

        if not result.is_valid:
            details = "".join(f" - {error}\n" for error in result.errors)
            raise ValueError(f"数据验证失败:\n{details}")

        return result.treatment, result.outcome
|
|
|
|
|
|
# Convenience function
|
|
def validate_data(data_path: str, llm_result: Dict[str, Any]) -> ValidationResult:
    """Validate a data file in one call.

    Builds a ``DataValidator`` for ``data_path`` and runs its column
    validation against ``llm_result``.

    Args:
        data_path: Path of the data file.
        llm_result: Parsed LLM JSON output.

    Returns:
        The validation result.
    """
    return DataValidator(data_path).validate_columns(llm_result)
|
|
|
|
|
|
if __name__ == "__main__":
    # Manual smoke test of the validation module against the example dataset.
    print("测试数据验证模块")
    print("=" * 50)

    # Case 1: both variable names are expected to exist in the data.
    print("\n测试 1: 有效数据")
    validator = DataValidator("examples/medical/data.xlsx")
    llm_result = {"treatment": "treatment", "outcome": "health"}
    result = validator.validate_columns(llm_result)

    print(f" 是否有效:{result.is_valid}")
    print(f" 处理变量:{result.treatment}")
    print(f" 结果变量:{result.outcome}")
    print(f" 错误:{result.errors}")
    print(f" 警告:{result.warnings}")

    # Case 2: the treatment name does not exist in the data; the same
    # validator is reused, so the dataset is not read again.
    print("\n测试 2: 无效数据(变量名不存在)")
    llm_result_invalid = {"treatment": "invalid_col", "outcome": "health"}
    result_invalid = validator.validate_columns(llm_result_invalid)

    print(f" 是否有效:{result_invalid.is_valid}")
    print(f" 错误:{result_invalid.errors}")
|