""" 输入处理模块测试脚本 该脚本测试文件查询工具和因果推断解析器的功能。 """ import os import sys import json from pathlib import Path # 添加项目根目录到 Python 路径 project_root = Path(__file__).parent.parent sys.path.insert(0, str(project_root)) # 导入模块 from input_processor import FileQueryTool, CausalParser, CausalInferenceRequest def test_file_query_tool(): """测试文件查询工具""" print("=" * 60) print("测试文件查询工具") print("=" * 60) # 使用项目中的示例数据 test_files = [ "data/output/simple_ate_data.csv", "data/output/simple_ate_data.xlsx" ] tool = FileQueryTool(default_read_rows=100) for file_path in test_files: if not os.path.exists(file_path): print(f"\n文件不存在,跳过:{file_path}") continue print(f"\n测试文件:{file_path}") print("-" * 40) try: # 测试基本查询 metadata = tool.query(file_path) print(f"文件名:{metadata.file_name}") print(f"文件类型:{metadata.file_type}") print(f"文件大小:{metadata.file_size} 字节") print(f"总行数:{metadata.total_rows}") print(f"总列数:{metadata.total_columns}") print(f"实际读取行数:{metadata.read_rows}") print("\n列信息:") for col in metadata.columns: print(f" - {col.name}: {col.dtype}") print(f" 非空值:{col.non_null_count}, 空值:{col.null_count}") print(f" 唯一值:{col.unique_count}") if col.sample_values: print(f" 样本值:{col.sample_values[:3]}") print("\n样本数据:") if metadata.sample_data: for i, row in enumerate(metadata.sample_data.rows[:3]): print(f" 行 {i+1}: {row}") # 测试获取列名 print("\n列名列表:") col_names = tool.get_column_names(file_path) print(f" {col_names}") # 测试获取数据类型 if col_names: first_col = col_names[0] dtype = tool.get_column_dtype(file_path, first_col) print(f"\n列 '{first_col}' 的数据类型:{dtype}") # 测试获取文件信息字典 print("\n文件信息字典:") file_info = tool.get_file_info(file_path) print(json.dumps(file_info, ensure_ascii=False, indent=2)) except Exception as e: print(f"错误:{e}") print("\n" + "=" * 60) print("文件查询工具测试完成") print("=" * 60) def test_causal_parser(): """测试因果推断解析器""" print("\n" + "=" * 60) print("测试因果推断解析器") print("=" * 60) # 使用项目中的示例数据 test_file = "data/output/simple_ate_data.csv" if not os.path.exists(test_file): print(f"文件不存在,跳过测试:{test_file}") return parser = CausalParser() # 测试问题列表 test_questions = [ "教育对收入的影响是什么?", "处理变量对结果变量的影响如何?", "分析 treatment 对 outcome 的因果效应", "这个数据集中 treatment 和 outcome 有什么关系?", "如何评估干预措施的效果?" ] for question in test_questions: print(f"\n用户问题:{question}") print("-" * 40) try: # 使用简化的解析方法 result = parser.parse_simple(question, test_file) # 输出结果 print(result.to_json()) # 结构化输出 print("\n结构化结果:") if result.treatment_variable: print(f" 处理变量 (T): {result.treatment_variable.name}") print(f" 置信度:{result.treatment_variable.confidence:.2f}") print(f" 推理:{result.treatment_variable.reasoning}") if result.outcome_variable: print(f" 结果变量 (Y): {result.outcome_variable.name}") print(f" 置信度:{result.outcome_variable.confidence:.2f}") print(f" 推理:{result.outcome_variable.reasoning}") if result.control_variables: print(f" 协变量 ({len(result.control_variables)} 个):") for var in result.control_variables[:3]: print(f" - {var.name} (置信度:{var.confidence:.2f})") if len(result.control_variables) > 3: print(f" ... 还有 {len(result.control_variables) - 3} 个") print(f"\n 整体置信度:{result.confidence_score:.2f}") print(f" 推理理由:{result.reasoning}") except Exception as e: print(f"错误:{e}") print("\n" + "=" * 60) print("因果推断解析器测试完成") print("=" * 60) def test_causal_parser_with_request(): """测试使用 CausalInferenceRequest 对象的解析""" print("\n" + "=" * 60) print("测试使用 CausalInferenceRequest 对象") print("=" * 60) test_file = "data/output/simple_ate_data.csv" if not os.path.exists(test_file): print(f"文件不存在,跳过测试:{test_file}") return parser = CausalParser() # 创建请求对象 request = CausalInferenceRequest( user_question="分析 treatment 变量对 outcome 变量的影响", file_path=test_file, read_rows=50, llm_provider="mock", llm_model="mock-model" ) print(f"请求信息:") print(f" 问题:{request.user_question}") print(f" 文件:{request.file_path}") print(f" 读取行数:{request.read_rows}") print(f" LLM 提供商:{request.llm_provider}") print(f" LLM 模型:{request.llm_model}") # 执行解析 result = parser.parse(request) print("\n解析结果:") print(result.to_json()) print("\n" + "=" * 60) print("请求对象测试完成") print("=" * 60) def main(): """主测试函数""" print("\n" + "=" * 60) print("输入处理模块测试") print("=" * 60) # 测试文件查询工具 test_file_query_tool() # 测试因果推断解析器 test_causal_parser() # 测试使用请求对象的解析 test_causal_parser_with_request() print("\n" + "=" * 60) print("所有测试完成!") print("=" * 60) if __name__ == "__main__": main()