2026-03-29 23:47:20 +08:00

218 lines
6.8 KiB
Python

"""
输入处理模块测试脚本
该脚本测试文件查询工具和因果推断解析器的功能。
"""
import os
import sys
import json
from pathlib import Path
# 添加项目根目录到 Python 路径
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
# 导入模块
from input_processor import FileQueryTool, CausalParser, CausalInferenceRequest
def test_file_query_tool():
"""测试文件查询工具"""
print("=" * 60)
print("测试文件查询工具")
print("=" * 60)
# 使用项目中的示例数据
test_files = [
"data/output/simple_ate_data.csv",
"data/output/simple_ate_data.xlsx"
]
tool = FileQueryTool(default_read_rows=100)
for file_path in test_files:
if not os.path.exists(file_path):
print(f"\n文件不存在,跳过:{file_path}")
continue
print(f"\n测试文件:{file_path}")
print("-" * 40)
try:
# 测试基本查询
metadata = tool.query(file_path)
print(f"文件名:{metadata.file_name}")
print(f"文件类型:{metadata.file_type}")
print(f"文件大小:{metadata.file_size} 字节")
print(f"总行数:{metadata.total_rows}")
print(f"总列数:{metadata.total_columns}")
print(f"实际读取行数:{metadata.read_rows}")
print("\n列信息:")
for col in metadata.columns:
print(f" - {col.name}: {col.dtype}")
print(f" 非空值:{col.non_null_count}, 空值:{col.null_count}")
print(f" 唯一值:{col.unique_count}")
if col.sample_values:
print(f" 样本值:{col.sample_values[:3]}")
print("\n样本数据:")
if metadata.sample_data:
for i, row in enumerate(metadata.sample_data.rows[:3]):
print(f"{i+1}: {row}")
# 测试获取列名
print("\n列名列表:")
col_names = tool.get_column_names(file_path)
print(f" {col_names}")
# 测试获取数据类型
if col_names:
first_col = col_names[0]
dtype = tool.get_column_dtype(file_path, first_col)
print(f"\n'{first_col}' 的数据类型:{dtype}")
# 测试获取文件信息字典
print("\n文件信息字典:")
file_info = tool.get_file_info(file_path)
print(json.dumps(file_info, ensure_ascii=False, indent=2))
except Exception as e:
print(f"错误:{e}")
print("\n" + "=" * 60)
print("文件查询工具测试完成")
print("=" * 60)
def test_causal_parser():
"""测试因果推断解析器"""
print("\n" + "=" * 60)
print("测试因果推断解析器")
print("=" * 60)
# 使用项目中的示例数据
test_file = "data/output/simple_ate_data.csv"
if not os.path.exists(test_file):
print(f"文件不存在,跳过测试:{test_file}")
return
parser = CausalParser()
# 测试问题列表
test_questions = [
"教育对收入的影响是什么?",
"处理变量对结果变量的影响如何?",
"分析 treatment 对 outcome 的因果效应",
"这个数据集中 treatment 和 outcome 有什么关系?",
"如何评估干预措施的效果?"
]
for question in test_questions:
print(f"\n用户问题:{question}")
print("-" * 40)
try:
# 使用简化的解析方法
result = parser.parse_simple(question, test_file)
# 输出结果
print(result.to_json())
# 结构化输出
print("\n结构化结果:")
if result.treatment_variable:
print(f" 处理变量 (T): {result.treatment_variable.name}")
print(f" 置信度:{result.treatment_variable.confidence:.2f}")
print(f" 推理:{result.treatment_variable.reasoning}")
if result.outcome_variable:
print(f" 结果变量 (Y): {result.outcome_variable.name}")
print(f" 置信度:{result.outcome_variable.confidence:.2f}")
print(f" 推理:{result.outcome_variable.reasoning}")
if result.control_variables:
print(f" 协变量 ({len(result.control_variables)} 个):")
for var in result.control_variables[:3]:
print(f" - {var.name} (置信度:{var.confidence:.2f})")
if len(result.control_variables) > 3:
print(f" ... 还有 {len(result.control_variables) - 3}")
print(f"\n 整体置信度:{result.confidence_score:.2f}")
print(f" 推理理由:{result.reasoning}")
except Exception as e:
print(f"错误:{e}")
print("\n" + "=" * 60)
print("因果推断解析器测试完成")
print("=" * 60)
def test_causal_parser_with_request():
"""测试使用 CausalInferenceRequest 对象的解析"""
print("\n" + "=" * 60)
print("测试使用 CausalInferenceRequest 对象")
print("=" * 60)
test_file = "data/output/simple_ate_data.csv"
if not os.path.exists(test_file):
print(f"文件不存在,跳过测试:{test_file}")
return
parser = CausalParser()
# 创建请求对象
request = CausalInferenceRequest(
user_question="分析 treatment 变量对 outcome 变量的影响",
file_path=test_file,
read_rows=50,
llm_provider="mock",
llm_model="mock-model"
)
print(f"请求信息:")
print(f" 问题:{request.user_question}")
print(f" 文件:{request.file_path}")
print(f" 读取行数:{request.read_rows}")
print(f" LLM 提供商:{request.llm_provider}")
print(f" LLM 模型:{request.llm_model}")
# 执行解析
result = parser.parse(request)
print("\n解析结果:")
print(result.to_json())
print("\n" + "=" * 60)
print("请求对象测试完成")
print("=" * 60)
def main():
"""主测试函数"""
print("\n" + "=" * 60)
print("输入处理模块测试")
print("=" * 60)
# 测试文件查询工具
test_file_query_tool()
# 测试因果推断解析器
test_causal_parser()
# 测试使用请求对象的解析
test_causal_parser_with_request()
print("\n" + "=" * 60)
print("所有测试完成!")
print("=" * 60)
if __name__ == "__main__":
main()