2026-03-29 23:47:20 +08:00

69 lines
2.0 KiB
Python

"""
通用数据加载器
支持 CSV、Excel、JSON Lines 等格式,同时提供数据信息摘要生成。
"""
import os
from typing import Optional
import pandas as pd
def _read_json_lines(file_path: str, nrows: Optional[int] = None) -> pd.DataFrame:
    """Read a JSON Lines file (one JSON document per line) into a DataFrame.

    Args:
        file_path: Path of the file to read.
        nrows: Maximum number of lines to read; ``None`` reads the whole file.

    Returns:
        The parsed records as a DataFrame.
    """
    if nrows is None:
        return pd.read_json(file_path, lines=True)
    # ``head`` keeps ``nrows=0`` meaning "no rows": pandas' line reader appears
    # to treat a zero limit as "unlimited" in at least some releases.
    return pd.read_json(file_path, lines=True, nrows=nrows).head(nrows)


class DataLoader:
    """Load tabular data files into pandas DataFrames.

    The reader is chosen from the file extension (case-insensitive). ``.json``
    files are parsed the same way as ``.jsonl`` files, i.e. as JSON Lines.
    """

    # Extensions accepted by ``load``.
    SUPPORTED_EXTENSIONS = {".csv", ".xlsx", ".xls", ".json", ".jsonl"}

    def __init__(self, default_read_rows: Optional[int] = None):
        """Create a loader.

        Args:
            default_read_rows: Maximum number of data rows to read from each
                file, whatever its format; ``None`` reads every row.
        """
        self.default_read_rows = default_read_rows

    def load(self, file_path: str) -> pd.DataFrame:
        """Load a data file into a DataFrame.

        Args:
            file_path: Path to a CSV, Excel or JSON Lines file.

        Returns:
            The file contents, limited to ``default_read_rows`` rows when that
            limit is set.

        Raises:
            FileNotFoundError: If ``file_path`` does not exist.
            ValueError: If the file extension is not supported.
        """
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"数据文件不存在:{file_path}")

        ext = os.path.splitext(file_path)[1].lower()
        if ext not in self.SUPPORTED_EXTENSIONS:
            raise ValueError(
                f"不支持的文件格式:{ext},仅支持 {self.SUPPORTED_EXTENSIONS}"
            )

        # Apply the row limit uniformly; previously only CSV honoured it.
        limit = self.default_read_rows
        row_kwargs = {} if limit is None else {"nrows": limit}

        if ext == ".csv":
            return pd.read_csv(file_path, **row_kwargs)
        if ext in {".xlsx", ".xls"}:
            return pd.read_excel(file_path, **row_kwargs)
        if ext in {".json", ".jsonl"}:
            return _read_json_lines(file_path, nrows=limit)

        # Only reachable if an extension is added to SUPPORTED_EXTENSIONS
        # without a matching branch above.
        raise ValueError(f"未实现的文件格式:{ext}")
def get_data_info(df: pd.DataFrame, treatment_col: Optional[str] = None) -> str:
    """Build a plain-text summary of ``df`` for use in an LLM prompt.

    The summary lists the row count, the column labels and the output of
    ``df.describe()``. When ``treatment_col`` names an existing column, the
    value counts of that column are appended.

    Args:
        df: The data to summarise.
        treatment_col: Optional name of the treatment column; ignored when it
            is empty or not present in ``df``.

    Returns:
        The summary sections joined with newlines.
    """
    # Column labels are not guaranteed to be strings (e.g. integer labels),
    # so convert them explicitly; ``str.join`` rejects non-string items.
    column_names = ", ".join(str(col) for col in df.columns)

    lines = [
        "**数据概览:**",
        f"- 样本数量:{len(df)}",
        f"- 变量:{column_names}",
        "",
        "**统计摘要:**",
        f"{df.describe()}",
    ]

    if treatment_col and treatment_col in df.columns:
        lines.extend(
            [
                "",
                "**处理变量分布:**",
                f"{df[treatment_col].value_counts().to_string()}",
            ]
        )

    return "\n".join(lines)