"""General-purpose data loader.

Loads CSV, Excel and JSON / JSON Lines files into pandas DataFrames and
builds a plain-text summary of a DataFrame.
"""

import os
from typing import Optional

import pandas as pd


def _read_json_lines(file_path: str, nrows: Optional[int] = None) -> pd.DataFrame:
    """Read a JSON file, preferring the JSON Lines layout.

    The file is first parsed as JSON Lines (one JSON value per line). If
    pandas rejects it with ``ValueError`` (for example a pretty-printed
    JSON document spread over several lines), it is parsed again as a
    standard JSON document.

    Args:
        file_path: Path of the file to read.
        nrows: Maximum number of rows to return; ``None`` returns all rows.

    Returns:
        The parsed DataFrame.

    Raises:
        ValueError: If the file can be parsed in neither layout.
    """
    # pandas only accepts ``nrows`` together with ``lines=True``.
    limit_kwargs = {} if nrows is None else {"nrows": nrows}
    try:
        frame = pd.read_json(file_path, lines=True, **limit_kwargs)
    except ValueError:
        # Not line-delimited: retry as one standard JSON document.
        frame = pd.read_json(file_path)
    # ``head`` enforces the limit on the fallback path and is a no-op when
    # the line-delimited reader already honoured it.
    return frame if nrows is None else frame.head(nrows)


def get_data_info(df: pd.DataFrame, treatment_col: Optional[str] = None) -> str:
    """Build a text summary of ``df`` (intended for use in an LLM prompt).

    The summary lists the row count, the column labels and the output of
    ``df.describe()``. When ``treatment_col`` names an existing column, its
    value counts are appended.

    Args:
        df: DataFrame to summarise.
        treatment_col: Optional column whose value distribution is appended.

    Returns:
        The summary as a newline-joined string.
    """
    # Column labels are not guaranteed to be strings (e.g. integer labels),
    # so convert them before joining.
    column_names = ", ".join(map(str, df.columns))

    # ``describe()`` raises ValueError on a frame without columns.
    if len(df.columns) > 0:
        stats_text = str(df.describe())
    else:
        stats_text = "(无可用列)"

    lines = [
        "**数据概览:**",
        f"- 样本数量:{len(df)}",
        f"- 变量:{column_names}",
        "",
        "**统计摘要:**",
        stats_text,
    ]

    if treatment_col and treatment_col in df.columns:
        lines.extend(
            [
                "",
                "**处理变量分布:**",
                df[treatment_col].value_counts().to_string(),
            ]
        )

    return "\n".join(lines)


class DataLoader:
    """Load tabular data files into pandas DataFrames.

    The format is chosen from the file extension (case-insensitive).

    Attributes are documented inline below.
    """

    # Extensions accepted by ``load``.
    SUPPORTED_EXTENSIONS = {".csv", ".xlsx", ".xls", ".json", ".jsonl"}

    # Also reachable as ``DataLoader.get_data_info`` for callers that use the
    # class namespace instead of the module-level function.
    get_data_info = staticmethod(get_data_info)

    def __init__(self, default_read_rows: Optional[int] = None):
        """Create a loader.

        Args:
            default_read_rows: Maximum number of rows to read from a file;
                ``None`` reads the whole file.
        """
        self.default_read_rows = default_read_rows

    def load(self, file_path: str) -> pd.DataFrame:
        """Load a data file into a DataFrame.

        Args:
            file_path: Path of the file to load.

        Returns:
            The loaded DataFrame, limited to ``default_read_rows`` rows when
            that limit is set.

        Raises:
            FileNotFoundError: If ``file_path`` does not exist.
            ValueError: If the extension is not in ``SUPPORTED_EXTENSIONS``
                or no reader is implemented for it.
        """
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"数据文件不存在:{file_path}")

        ext = os.path.splitext(file_path)[1].lower()
        if ext not in self.SUPPORTED_EXTENSIONS:
            raise ValueError(
                f"不支持的文件格式:{ext},仅支持 {self.SUPPORTED_EXTENSIONS}"
            )

        # Apply the row limit uniformly to every format, not only CSV.
        row_limit = self.default_read_rows
        limit_kwargs = {} if row_limit is None else {"nrows": row_limit}

        if ext == ".csv":
            return pd.read_csv(file_path, **limit_kwargs)
        if ext in {".xlsx", ".xls"}:
            return pd.read_excel(file_path, **limit_kwargs)
        if ext in {".json", ".jsonl"}:
            return _read_json_lines(file_path, row_limit)

        # Reached only if SUPPORTED_EXTENSIONS lists a format with no reader.
        raise ValueError(f"未实现的文件格式:{ext}")