294 lines
8.8 KiB
Python
294 lines
8.8 KiB
Python
"""
File Query Tool module.

This module provides file-query helpers that read the header information,
summary statistics, and sample data of CSV and Excel files without reading
the complete data.
"""
|
||
|
||
import os
|
||
import pandas as pd
|
||
from dataclasses import dataclass, asdict
|
||
from typing import List, Optional, Dict, Any
|
||
from pathlib import Path
|
||
|
||
|
||
@dataclass
class ColumnInfo:
    """Describe a single column of a tabular file.

    Attributes:
        name: Column name.
        dtype: Data type of the column.
        non_null_count: Number of non-null values.
        null_count: Number of null values.
        unique_count: Number of unique values (optional).
        sample_values: Sample values taken from the column (optional).
    """

    name: str
    dtype: str
    non_null_count: int
    null_count: int
    unique_count: Optional[int] = None
    sample_values: Optional[List[Any]] = None

    def to_dict(self) -> Dict[str, Any]:
        """Return the fields of this record as a dictionary."""
        as_mapping = asdict(self)
        return as_mapping
|
||
|
||
|
||
@dataclass
class SampleData:
    """Hold a sample of rows taken from a tabular file.

    Attributes:
        rows: Sample rows, one dictionary per row.
        row_count: Number of sample rows.
        columns: List of column names.
    """

    rows: List[Dict[str, Any]]
    row_count: int
    columns: List[str]

    def to_dict(self) -> Dict[str, Any]:
        """Return the sample as a dictionary.

        The stored ``rows`` and ``columns`` objects are referenced directly;
        no copy is made.
        """
        return dict(
            rows=self.rows,
            row_count=self.row_count,
            columns=self.columns,
        )
|
||
|
||
|
||
@dataclass
class FileMetadata:
    """Metadata collected about a queried file.

    Attributes:
        file_path: Path of the file.
        file_name: Name of the file.
        file_size: File size in bytes.
        file_type: File type (csv/excel).
        extension: File extension.
        total_rows: Total number of rows.
        total_columns: Total number of columns.
        columns: List of per-column information.
        sample_data: Sample data.
        read_rows: Number of rows actually read (defaults to 100).
    """

    file_path: str
    file_name: str
    file_size: int
    file_type: str
    extension: str
    total_rows: Optional[int] = None
    total_columns: Optional[int] = None
    columns: Optional[List[ColumnInfo]] = None
    sample_data: Optional[SampleData] = None
    read_rows: int = 100

    def to_dict(self) -> Dict[str, Any]:
        """Return the metadata as a dictionary.

        Starts from ``asdict`` and, when they are set, replaces the
        ``columns`` and ``sample_data`` entries with the output of each
        nested object's own ``to_dict``.
        """
        serialized = asdict(self)

        column_infos = self.columns
        if column_infos:
            serialized["columns"] = [info.to_dict() for info in column_infos]

        sample = self.sample_data
        if sample:
            serialized["sample_data"] = sample.to_dict()

        return serialized
|
||
|
||
|
||
class FileQueryTool:
    """Query CSV and Excel files by reading only their first rows.

    The tool reads at most ``read_rows`` rows of a file and derives header
    information, per-column statistics and a sample of rows from that
    slice, so the complete data is never loaded.

    NOTE(review): every figure the tool reports -- including
    ``FileMetadata.total_rows`` -- is computed from the rows that were
    read, not from the whole file.

    Attributes:
        default_read_rows: Default number of rows to read (100 rows).

    Example:
        >>> tool = FileQueryTool()
        >>> metadata = tool.query("data.csv")
        >>> print(metadata.columns)
        >>> print(metadata.sample_data)
    """

    def __init__(self, default_read_rows: int = 100):
        """Initialise the file query tool.

        Args:
            default_read_rows: Default number of rows to read when
                collecting header information and sample data.
        """
        self.default_read_rows = default_read_rows

    def query(self, file_path: str, read_rows: Optional[int] = None) -> "FileMetadata":
        """Query information about a file.

        Reads the first N rows of the file and extracts header information,
        statistics and sample data from them.

        Args:
            file_path: Path of the file.
            read_rows: Number of rows to read; when ``None`` the instance
                default ``default_read_rows`` is used.

        Returns:
            FileMetadata: Metadata object for the file. ``total_rows`` is the
            number of rows in the slice that was loaded, so it never exceeds
            ``read_rows``.

        Raises:
            FileNotFoundError: The path does not point to an existing
                regular file (this includes directories).
            ValueError: The file extension is not ``.csv``, ``.xlsx`` or
                ``.xls``.
            Exception: Any other error raised while reading the file.
        """
        if read_rows is None:
            read_rows = self.default_read_rows

        # os.path.exists() is also true for directories, which can never be
        # parsed; require a regular file so the documented error is raised
        # up front instead of a platform-dependent one from the reader.
        if not os.path.isfile(file_path):
            raise FileNotFoundError(f"文件不存在:{file_path}")

        # Collect basic file-system information.
        file_path = os.path.abspath(file_path)
        file_name = os.path.basename(file_path)
        file_size = os.path.getsize(file_path)
        extension = os.path.splitext(file_name)[1].lower()

        # Pick the reader from the (lower-cased) extension.
        if extension == ".csv":
            file_type = "csv"
            df = self._read_csv(file_path, read_rows)
        elif extension in (".xlsx", ".xls"):
            file_type = "excel"
            df = self._read_excel(file_path, read_rows)
        else:
            raise ValueError(f"不支持的文件格式:{extension},仅支持 CSV 和 Excel 文件")

        columns = self._extract_column_info(df)
        sample_data = self._extract_sample_data(df, read_rows)

        return FileMetadata(
            file_path=file_path,
            file_name=file_name,
            file_size=file_size,
            file_type=file_type,
            extension=extension,
            # Row count of the loaded slice only (bounded by read_rows).
            total_rows=len(df),
            total_columns=len(df.columns),
            columns=columns,
            sample_data=sample_data,
            read_rows=min(len(df), read_rows),
        )

    def _read_csv(self, file_path: str, read_rows: int) -> pd.DataFrame:
        """Read the first rows of a CSV file.

        Args:
            file_path: Path of the CSV file.
            read_rows: Number of rows to read (passed to pandas as ``nrows``).

        Returns:
            pd.DataFrame: The data frame that was read.
        """
        return pd.read_csv(file_path, nrows=read_rows)

    def _read_excel(self, file_path: str, read_rows: int) -> pd.DataFrame:
        """Read the first rows of an Excel file.

        Args:
            file_path: Path of the Excel file.
            read_rows: Number of rows to read (passed to pandas as ``nrows``).

        Returns:
            pd.DataFrame: The data frame that was read.
        """
        return pd.read_excel(file_path, nrows=read_rows)

    def _extract_column_info(self, df: pd.DataFrame) -> List["ColumnInfo"]:
        """Extract per-column information from a data frame.

        Args:
            df: Data frame to describe.

        Returns:
            List[ColumnInfo]: One entry per column, in column order.
        """
        columns = []

        # items() yields one Series per column by position, so a frame with
        # duplicate labels is still handled column by column (df[label]
        # would return a DataFrame there, which has no single dtype).
        for label, series in df.items():
            # Drop nulls once; both the unique count and the sample values
            # are defined over the non-null values only.
            present = series.dropna()

            columns.append(
                ColumnInfo(
                    name=label,
                    dtype=str(series.dtype),
                    non_null_count=int(series.notna().sum()),
                    null_count=int(series.isna().sum()),
                    unique_count=int(present.nunique()),
                    # At most 5 sample values.
                    sample_values=present.head(5).tolist(),
                )
            )

        return columns

    def _extract_sample_data(self, df: pd.DataFrame, read_rows: int) -> "SampleData":
        """Extract sample rows from a data frame.

        Args:
            df: Data frame to sample.
            read_rows: Number of rows that were requested.

        Returns:
            SampleData: Sample data object holding the rows as dictionaries.
        """
        # The frame may hold fewer rows than were requested.
        actual_rows = min(len(df), read_rows)

        # Convert the rows to a list of dictionaries.
        rows = df.head(actual_rows).to_dict(orient='records')

        return SampleData(
            rows=rows,
            row_count=actual_rows,
            columns=df.columns.tolist(),
        )

    def get_column_names(self, file_path: str, read_rows: Optional[int] = None) -> List[str]:
        """Get the list of column names of a file.

        Runs a full ``query`` and returns the names of the reported columns.

        Args:
            file_path: Path of the file.
            read_rows: Number of rows to read.

        Returns:
            List[str]: List of column names.
        """
        metadata = self.query(file_path, read_rows)
        return [col.name for col in metadata.columns]

    def get_column_dtype(self, file_path: str, column_name: str,
                         read_rows: Optional[int] = None) -> str:
        """Get the data type of one column.

        Args:
            file_path: Path of the file.
            column_name: Name of the column.
            read_rows: Number of rows to read.

        Returns:
            str: Data type of the column, as inferred from the rows read.

        Raises:
            ValueError: No column with the given name exists.
        """
        metadata = self.query(file_path, read_rows)
        for col in metadata.columns:
            if col.name == column_name:
                return col.dtype
        raise ValueError(f"列 '{column_name}' 不存在")

    def get_file_info(self, file_path: str, read_rows: Optional[int] = None) -> Dict[str, Any]:
        """Get the file information in dictionary form.

        Args:
            file_path: Path of the file.
            read_rows: Number of rows to read.

        Returns:
            Dict[str, Any]: The result of ``query`` converted with
            ``to_dict``.
        """
        metadata = self.query(file_path, read_rows)
        return metadata.to_dict()
|