218 lines
6.8 KiB
Python
218 lines
6.8 KiB
Python
"""
|
|
输入处理模块测试脚本
|
|
|
|
该脚本测试文件查询工具和因果推断解析器的功能。
|
|
"""
|
|
|
|
import json
import os
import sys
import traceback
from pathlib import Path
|
|
|
|
# Make the project root importable. It is the parent of the directory that
# holds this script. Prepending it (index 0) lets the in-repo
# ``input_processor`` module win over any same-named module found later on
# the search path.
project_root = Path(__file__).parent.parent
sys.path.insert(0, os.fspath(project_root))
|
|
|
|
# 导入模块
|
|
from input_processor import FileQueryTool, CausalParser, CausalInferenceRequest
|
|
|
|
|
|
def test_file_query_tool():
    """Smoke-test ``FileQueryTool`` against the sample data files.

    Sample files that do not exist are reported and skipped. For each file
    that does exist:

    * the metadata returned by ``tool.query`` is printed;
    * ``get_column_names`` is exercised;
    * ``get_column_dtype`` is exercised on the first column only;
    * ``get_file_info`` is exercised.

    An exception for one file is reported together with its traceback, and
    the loop moves on to the next file. One bad file therefore does not hide
    the results of the others.
    """
    print("=" * 60)
    print("测试文件查询工具")
    print("=" * 60)

    # Sample data shipped with the project. These paths are resolved
    # relative to the current working directory, not to this script.
    test_files = [
        "data/output/simple_ate_data.csv",
        "data/output/simple_ate_data.xlsx",
    ]

    tool = FileQueryTool(default_read_rows=100)

    for file_path in test_files:
        if not os.path.exists(file_path):
            print(f"\n文件不存在,跳过:{file_path}")
            continue

        print(f"\n测试文件:{file_path}")
        print("-" * 40)

        try:
            _print_file_metadata(tool.query(file_path))
            _exercise_file_accessors(tool, file_path)
        except Exception as e:
            # Best effort: continue with the next file, but also print the
            # traceback. ``str(e)`` alone drops the exception type and the
            # failing line (e.g. a KeyError prints only the missing key).
            print(f"错误:{e}")
            traceback.print_exc()

    print("\n" + "=" * 60)
    print("文件查询工具测试完成")
    print("=" * 60)


def _print_file_metadata(metadata):
    """Print summary fields, per-column statistics and the first sample rows.

    *metadata* is the object returned by ``FileQueryTool.query``.
    """
    print(f"文件名:{metadata.file_name}")
    print(f"文件类型:{metadata.file_type}")
    print(f"文件大小:{metadata.file_size} 字节")
    print(f"总行数:{metadata.total_rows}")
    print(f"总列数:{metadata.total_columns}")
    print(f"实际读取行数:{metadata.read_rows}")

    print("\n列信息:")
    for col in metadata.columns:
        print(f" - {col.name}: {col.dtype}")
        print(f" 非空值:{col.non_null_count}, 空值:{col.null_count}")
        print(f" 唯一值:{col.unique_count}")
        if col.sample_values:
            print(f" 样本值:{col.sample_values[:3]}")

    print("\n样本数据:")
    if metadata.sample_data:
        for row_no, row in enumerate(metadata.sample_data.rows[:3], start=1):
            print(f" 行 {row_no}: {row}")


def _exercise_file_accessors(tool, file_path):
    """Call the column-name, column-dtype and file-info accessors and print results."""
    print("\n列名列表:")
    col_names = tool.get_column_names(file_path)
    print(f" {col_names}")

    # Only the first column is probed for its dtype.
    if col_names:
        first_col = col_names[0]
        dtype = tool.get_column_dtype(file_path, first_col)
        print(f"\n列 '{first_col}' 的数据类型:{dtype}")

    print("\n文件信息字典:")
    file_info = tool.get_file_info(file_path)
    print(json.dumps(file_info, ensure_ascii=False, indent=2))
|
|
|
|
|
|
def test_causal_parser():
    """Smoke-test ``CausalParser.parse_simple`` on a set of sample questions.

    Prints a message and returns early when the sample CSV is missing.
    For each question, prints the raw JSON result followed by a structured
    summary. An error on one question is reported with its traceback, and
    the loop continues with the next question.
    """
    print("\n" + "=" * 60)
    print("测试因果推断解析器")
    print("=" * 60)

    # Sample data shipped with the project; relative to the working directory.
    test_file = "data/output/simple_ate_data.csv"

    if not os.path.exists(test_file):
        print(f"文件不存在,跳过测试:{test_file}")
        return

    parser = CausalParser()

    # Varied phrasings; some name ``treatment``/``outcome`` explicitly and
    # some do not.
    test_questions = [
        "教育对收入的影响是什么?",
        "处理变量对结果变量的影响如何?",
        "分析 treatment 对 outcome 的因果效应",
        "这个数据集中 treatment 和 outcome 有什么关系?",
        "如何评估干预措施的效果?",
    ]

    for question in test_questions:
        print(f"\n用户问题:{question}")
        print("-" * 40)

        try:
            result = parser.parse_simple(question, test_file)
            print(result.to_json())
            _print_causal_result(result)
        except Exception as e:
            # Best effort: report and move on. The traceback is kept because
            # ``str(e)`` alone hides the exception type and location.
            print(f"错误:{e}")
            traceback.print_exc()

    print("\n" + "=" * 60)
    print("因果推断解析器测试完成")
    print("=" * 60)


def _print_causal_result(result, max_controls=3):
    """Print a human-readable summary of a parser result.

    Shows the treatment (T) and outcome (Y) variables when present, then the
    first *max_controls* control variables plus a count of any remaining
    ones, then the overall confidence score and reasoning.
    """
    print("\n结构化结果:")
    if result.treatment_variable:
        print(f" 处理变量 (T): {result.treatment_variable.name}")
        print(f" 置信度:{result.treatment_variable.confidence:.2f}")
        print(f" 推理:{result.treatment_variable.reasoning}")

    if result.outcome_variable:
        print(f" 结果变量 (Y): {result.outcome_variable.name}")
        print(f" 置信度:{result.outcome_variable.confidence:.2f}")
        print(f" 推理:{result.outcome_variable.reasoning}")

    controls = result.control_variables
    if controls:
        print(f" 协变量 ({len(controls)} 个):")
        for var in controls[:max_controls]:
            print(f" - {var.name} (置信度:{var.confidence:.2f})")
        hidden = len(controls) - max_controls
        if hidden > 0:
            print(f" ... 还有 {hidden} 个")

    print(f"\n 整体置信度:{result.confidence_score:.2f}")
    print(f" 推理理由:{result.reasoning}")
|
|
|
|
|
|
def test_causal_parser_with_request():
    """Smoke-test ``CausalParser.parse`` driven by a ``CausalInferenceRequest``.

    Prints a message and returns early when the sample CSV is missing.
    Otherwise, builds a request for the sample CSV with ``read_rows=50`` and
    the ``"mock"`` LLM provider/model. It echoes the request fields, then
    prints the JSON result of ``parser.parse``.

    A parse failure is reported with its traceback rather than propagated,
    so the script still reaches its final summary. This is the same
    best-effort policy the other tests in this module apply per case.
    """
    print("\n" + "=" * 60)
    print("测试使用 CausalInferenceRequest 对象")
    print("=" * 60)

    test_file = "data/output/simple_ate_data.csv"

    if not os.path.exists(test_file):
        print(f"文件不存在,跳过测试:{test_file}")
        return

    parser = CausalParser()

    # NOTE(review): the "mock" provider/model presumably avoids a real LLM
    # call; confirm against CausalParser.
    request = CausalInferenceRequest(
        user_question="分析 treatment 变量对 outcome 变量的影响",
        file_path=test_file,
        read_rows=50,
        llm_provider="mock",
        llm_model="mock-model",
    )

    print("请求信息:")
    print(f" 问题:{request.user_question}")
    print(f" 文件:{request.file_path}")
    print(f" 读取行数:{request.read_rows}")
    print(f" LLM 提供商:{request.llm_provider}")
    print(f" LLM 模型:{request.llm_model}")

    try:
        result = parser.parse(request)
        print("\n解析结果:")
        print(result.to_json())
    except Exception as e:
        print(f"错误:{e}")
        traceback.print_exc()

    print("\n" + "=" * 60)
    print("请求对象测试完成")
    print("=" * 60)
|
|
|
|
|
|
def main():
    """Run every input-processing smoke test in sequence, framed by banners."""
    banner = "=" * 60

    print("\n" + banner)
    print("输入处理模块测试")
    print(banner)

    # Sections run one after another in this fixed order: file querying,
    # parsing from a plain question, then parsing from a request object.
    for run_section in (
        test_file_query_tool,
        test_causal_parser,
        test_causal_parser_with_request,
    ):
        run_section()

    print("\n" + banner)
    print("所有测试完成!")
    print(banner)
|
|
|
|
|
|
# Script entry point: run the smoke tests only when this file is executed
# directly, not when it is merely imported.
if __name__ == "__main__":
    main()
|