69 lines
2.0 KiB
Python
69 lines
2.0 KiB
Python
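"""
通用数据加载器 / General-purpose data loader.

支持 CSV、Excel、JSON Lines 等格式,同时提供数据信息摘要生成。
Supports CSV, Excel, JSON Lines and similar formats, and can also generate a
textual summary of a dataset.
"""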
|
|
|
|
import os
|
|
from typing import Optional
|
|
|
|
import pandas as pd
|
|
|
|
|
|
def _read_json_lines(file_path: str) -> pd.DataFrame:
    """Parse a line-delimited JSON (JSON Lines) file into a DataFrame.

    Args:
        file_path: Path of the JSON Lines file to read.

    Returns:
        The DataFrame produced by ``pandas.read_json`` in ``lines`` mode.
    """
    frame = pd.read_json(path_or_buf=file_path, lines=True)
    return frame


class DataLoader:
    """General-purpose loader that reads a tabular data file into a DataFrame.

    The parser is chosen from the (case-insensitive) file extension; the
    accepted extensions are listed in ``SUPPORTED_EXTENSIONS``.
    """

    SUPPORTED_EXTENSIONS = {".csv", ".xlsx", ".xls", ".json", ".jsonl"}

    def __init__(self, default_read_rows: Optional[int] = None):
        """Create a loader.

        Args:
            default_read_rows: Maximum number of data rows returned by
                :meth:`load`, applied to every supported format. ``None``
                keeps all rows.
        """
        self.default_read_rows = default_read_rows

    def load(self, file_path: str) -> pd.DataFrame:
        """Load a data file into a DataFrame.

        Args:
            file_path: Path to a ``.csv``, ``.xlsx``/``.xls`` or
                ``.json``/``.jsonl`` file.

        Returns:
            The parsed data, limited to ``default_read_rows`` rows when that
            attribute is not ``None``.

        Raises:
            FileNotFoundError: If ``file_path`` does not exist.
            ValueError: If the file extension is not supported.

        Errors raised by the underlying pandas readers propagate unchanged.
        """
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"数据文件不存在:{file_path}")

        ext = os.path.splitext(file_path)[1].lower()
        if ext not in self.SUPPORTED_EXTENSIONS:
            # Sort so the message is stable: iteration order of a set of str
            # depends on per-process string hashing.
            supported = ", ".join(sorted(self.SUPPORTED_EXTENSIONS))
            raise ValueError(f"不支持的文件格式:{ext},仅支持 {supported}")

        nrows = self.default_read_rows
        if ext == ".csv":
            return pd.read_csv(file_path, nrows=nrows)
        if ext in {".xlsx", ".xls"}:
            return pd.read_excel(file_path, nrows=nrows)
        if ext in {".json", ".jsonl"}:
            df = self._read_json(file_path, ext)
            # pandas.read_json only supports ``nrows`` for JSON Lines input,
            # so the row limit is applied after parsing instead.
            return df if nrows is None else df.head(nrows)

        # Defensive: only reachable if SUPPORTED_EXTENSIONS gains an entry
        # without a matching branch above.
        raise ValueError(f"未实现的文件格式:{ext}")

    @classmethod
    def _read_json(cls, file_path: str, ext: str) -> pd.DataFrame:
        """Parse a ``.json``/``.jsonl`` file as JSON Lines or as standard JSON."""
        # A ``.json`` file whose first token is "[" is a standard JSON array
        # (e.g. a list of records). The JSON Lines parser rejects such a file
        # when it is pretty-printed, and misparses it otherwise.
        if ext == ".json" and cls._first_non_space_char(file_path) == "[":
            return pd.read_json(file_path)
        try:
            return _read_json_lines(file_path)
        except ValueError:
            if ext != ".json":
                raise
            # A ``.json`` file that is not valid JSON Lines (e.g. one object
            # spread over several lines): retry as a single JSON document.
            return pd.read_json(file_path)

    @staticmethod
    def _first_non_space_char(file_path: str) -> str:
        """Return the first non-whitespace character of a file, or '' if none."""
        # "utf-8-sig" skips a leading BOM; errors="replace" keeps this peek
        # from failing on bytes that only matter to the real parser.
        with open(file_path, encoding="utf-8-sig", errors="replace") as fh:
            while True:
                chunk = fh.read(4096)
                if not chunk:
                    return ""
                stripped = chunk.lstrip()
                if stripped:
                    return stripped[0]


def get_data_info(df: pd.DataFrame, treatment_col: Optional[str] = None) -> str:
    """Build a Markdown-style text summary of *df* for use in an LLM prompt.

    Args:
        df: The dataset to summarise.
        treatment_col: Optional column name. When it names a column of *df*,
            that column's value counts are appended to the summary.

    Returns:
        A newline-joined string containing the row count, the column names,
        the full ``df.describe()`` table and, optionally, the distribution of
        the treatment column.
    """
    # Column labels are not always strings (e.g. integer labels), so convert
    # each one before joining.
    column_names = ", ".join(str(col) for col in df.columns)

    # to_string() renders every row and column. str() would apply pandas
    # display options and could wrap or elide columns with "..." in the
    # prompt. describe() raises on a frame with no columns, so guard that case.
    if len(df.columns) > 0:
        stats = df.describe().to_string()
    else:
        stats = "(无数据列)"

    lines = [
        "**数据概览:**",
        f"- 样本数量:{len(df)}",
        f"- 变量:{column_names}",
        "",
        "**统计摘要:**",
        stats,
    ]

    if treatment_col and treatment_col in df.columns:
        lines.extend(
            [
                "",
                "**处理变量分布:**",
                df[treatment_col].value_counts().to_string(),
            ]
        )

    return "\n".join(lines)