- 添加 TabMModel、TabPFNModel 深度学习模型实现 - 新增 DataQualityAnalyzer 进行训练前数据质量诊断 - 改进数据处理器 NaN/null 双重处理,增强数据鲁棒性 - 支持 train_skip_days 参数跳过训练初期数据不足期 - Pipeline 自动清理标签为 NaN 的样本
515 lines
16 KiB
Python
515 lines
16 KiB
Python
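"""Data quality analysis module.

Provides data quality checks, including:
- missing-value (null) statistics
- zero-value statistics
- per-date detection of feature columns that are entirely null
"""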
|
|
|
|
from typing import Any, Dict, List, Optional
|
|
import polars as pl
|
|
import numpy as np
|
|
|
|
|
|
class DataQualityAnalyzer:
    """Analyze the quality of training data and report anomalies.

    For every dataset split the analyzer counts nulls and zeros per feature
    column, looks for dates on which a feature column is entirely null, and
    summarizes the label column.

    Attributes:
        feature_cols: Names of the feature columns to inspect.
        label_col: Name of the label column; falsy values skip the label
            analysis.
        verbose: Whether a report is printed while analyzing.
        analysis_results: Results of the most recent ``analyze`` call, keyed
            by split name.
    """

    def __init__(
        self,
        feature_cols: Optional[List[str]] = None,
        label_col: Optional[str] = None,
        verbose: bool = True,
    ):
        """Initialize the analyzer.

        Args:
            feature_cols: Names of the feature columns; ``None`` is stored as
                an empty list.
            label_col: Name of the label column.
            verbose: Whether to print the report while analyzing.
        """
        self.feature_cols = feature_cols or []
        self.label_col = label_col
        self.verbose = verbose
        self.analysis_results: Dict[str, Any] = {}

    def set_columns(self, feature_cols: List[str], label_col: str) -> None:
        """Replace the columns to analyze.

        Args:
            feature_cols: Names of the feature columns.
            label_col: Name of the label column.
        """
        self.feature_cols = feature_cols
        self.label_col = label_col

    @staticmethod
    def _percentage(part: float, total: float) -> float:
        """Return ``part`` as a percentage of ``total`` (0.0 if ``total`` is 0)."""
        return part / total * 100 if total else 0.0

    def analyze(
        self,
        data: Dict[str, Dict[str, Any]],
        split_names: Optional[List[str]] = None,
    ) -> Dict[str, Any]:
        """Run the full data quality analysis.

        Args:
            data: Mapping of split name to a split dictionary, e.g.
                ``{"train": {...}, "val": {...}, "test": {...}}``. The frame
                to analyze is read from each split's ``"raw_data"`` entry.
            split_names: Names of the splits to analyze; defaults to
                ``["train", "val", "test"]`` when empty or ``None``.

        Returns:
            Analysis results keyed by split name. Splits that are missing
            from ``data`` or have no ``"raw_data"`` are skipped.
        """
        if not split_names:
            split_names = ["train", "val", "test"]

        if self.verbose:
            print("\n" + "=" * 80)
            print("数据质量分析报告")
            print("=" * 80)

        # Start from a clean slate so stale splits from an earlier run do not
        # leak into this one.
        self.analysis_results = {}

        for split_name in split_names:
            if split_name not in data:
                continue

            raw_df = data[split_name].get("raw_data")
            if raw_df is None:
                continue

            if self.verbose:
                print(f"\n[{split_name.upper()} 数据集]")
                print("-" * 40)

            self.analysis_results[split_name] = self._analyze_split(
                raw_df, split_name
            )

        if self.verbose:
            print("\n" + "=" * 80)

        return self.analysis_results

    def _analyze_split(
        self,
        df: "pl.DataFrame",
        split_name: str,
    ) -> Dict[str, Any]:
        """Analyze a single dataset split.

        Args:
            df: Frame holding the split's data.
            split_name: Name of the split (not used by the analysis itself).

        Returns:
            Dictionary with row/column counts and the null, zero, all-null-by-
            date and (when the label column exists) label analyses.
        """
        results: Dict[str, Any] = {
            "total_rows": len(df),
            "total_cols": len(df.columns),
            "feature_cols": self.feature_cols,
            "label_col": self.label_col,
            "null_analysis": {},
            "zero_analysis": {},
            "all_null_by_date": {},
        }

        # 1. Missing values in the feature columns.
        null_stats = self._analyze_null_values(df, self.feature_cols)
        results["null_analysis"] = null_stats
        if self.verbose:
            self._print_null_analysis(null_stats)

        # 2. Zero values in the feature columns.
        zero_stats = self._analyze_zero_values(df, self.feature_cols)
        results["zero_analysis"] = zero_stats
        if self.verbose:
            self._print_zero_analysis(zero_stats)

        # 3. Dates on which a feature column is entirely null.
        all_null_by_date = self._check_all_null_by_date(df, self.feature_cols)
        results["all_null_by_date"] = all_null_by_date
        if self.verbose:
            self._print_all_null_by_date(all_null_by_date)

        # 4. Label column, only when configured and present in the frame.
        if self.label_col and self.label_col in df.columns:
            label_stats = self._analyze_label(df, self.label_col)
            results["label_analysis"] = label_stats
            if self.verbose:
                self._print_label_analysis(label_stats)

        return results

    def _analyze_null_values(
        self,
        df: "pl.DataFrame",
        cols: List[str],
    ) -> Dict[str, Any]:
        """Count null values per column.

        Args:
            df: Frame to inspect.
            cols: Column names to analyze; names absent from ``df`` are
                skipped.

        Returns:
            Dictionary with the total cell count, per-column null counts and
            percentages (only for columns that contain nulls), the list of
            affected columns and the total number of null cells.
        """
        total_rows = len(df)
        stats: Dict[str, Any] = {
            "total_cells": total_rows * len(cols),
            "null_counts": {},
            "null_percentages": {},
            "columns_with_null": [],
            "total_null_cells": 0,
        }

        for col in cols:
            if col not in df.columns:
                continue

            null_count = df[col].null_count()
            if null_count > 0:
                stats["null_counts"][col] = null_count
                stats["null_percentages"][col] = self._percentage(
                    null_count, total_rows
                )
                stats["columns_with_null"].append(col)
                stats["total_null_cells"] += null_count

        return stats

    def _analyze_zero_values(
        self,
        df: "pl.DataFrame",
        cols: List[str],
    ) -> Dict[str, Any]:
        """Count zero values per column, ignoring nulls.

        Args:
            df: Frame to inspect.
            cols: Column names to analyze; names absent from ``df`` are
                skipped.

        Returns:
            Dictionary with the total cell count, per-column zero counts and
            percentages (only for columns that contain zeros), the list of
            affected columns and the total number of zero cells.
        """
        total_rows = len(df)
        stats: Dict[str, Any] = {
            "total_cells": total_rows * len(cols),
            "zero_counts": {},
            "zero_percentages": {},
            "columns_with_zero": [],
            "total_zero_cells": 0,
        }

        for col in cols:
            if col not in df.columns:
                continue

            # Nulls are dropped first so they are never counted as zeros.
            non_null_series = df[col].drop_nulls()
            if len(non_null_series) == 0:
                continue

            zero_count = int((non_null_series == 0).sum())
            if zero_count > 0:
                stats["zero_counts"][col] = zero_count
                # The percentage is relative to all rows, nulls included.
                stats["zero_percentages"][col] = self._percentage(
                    zero_count, total_rows
                )
                stats["columns_with_zero"].append(col)
                stats["total_zero_cells"] += zero_count

        return stats

    def _check_all_null_by_date(
        self,
        df: "pl.DataFrame",
        cols: List[str],
    ) -> Dict[str, Any]:
        """Find dates on which a column is null for every row.

        The frame is grouped by ``trade_date`` through a lazy query, so only
        the small per-date aggregate is materialized.

        Args:
            df: Frame to inspect; must contain a ``trade_date`` column for
                the check to run.
            cols: Column names to analyze; names absent from ``df`` are
                skipped.

        Returns:
            Dictionary with ``issues_found`` (bool) and ``issues``, a list of
            ``{"date", "column", "total_rows"}`` entries ordered by column
            and then by ascending date.
        """
        results: Dict[str, Any] = {
            "issues_found": False,
            "issues": [],
        }

        if "trade_date" not in df.columns:
            return results

        # Only columns that actually exist in the frame can be aggregated.
        valid_cols = [c for c in cols if c in df.columns]
        if not valid_cols:
            return results

        # One null-count expression per column plus the group size.
        agg_exprs = [
            pl.col(col).null_count().alias(f"{col}_nulls") for col in valid_cols
        ]
        agg_exprs.append(pl.len().alias("total_rows"))

        # group_by does not guarantee any output order, so sort by date to
        # keep the reported issues deterministic between runs.
        agg_df = (
            df.lazy()
            .group_by("trade_date")
            .agg(agg_exprs)
            .sort("trade_date")
            .collect()
        )

        issues = []
        for col in valid_cols:
            null_col = f"{col}_nulls"
            # A date is bad when every row of that date is null in the column.
            bad_dates = agg_df.filter(
                (pl.col(null_col) == pl.col("total_rows"))
                & (pl.col("total_rows") > 0)
            ).select(["trade_date", "total_rows"])

            issues.extend(
                {
                    "date": row["trade_date"],
                    "column": col,
                    "total_rows": row["total_rows"],
                }
                for row in bad_dates.to_dicts()
            )

        if issues:
            results["issues_found"] = True
            results["issues"] = issues

        return results

    def _analyze_label(
        self,
        df: "pl.DataFrame",
        label_col: str,
    ) -> Dict[str, Any]:
        """Summarize the label column.

        Args:
            df: Frame to inspect.
            label_col: Name of the label column.

        Returns:
            Dictionary with total/null/zero counts and percentages plus
            ``min``, ``max``, ``mean`` and ``std`` of the non-null values.
            A statistic stays ``None`` when it cannot be computed.
        """
        total_rows = len(df)
        stats: Dict[str, Any] = {
            "total_count": total_rows,
            "null_count": 0,
            "null_percentage": 0.0,
            "zero_count": 0,
            "zero_percentage": 0.0,
            "min": None,
            "max": None,
            "mean": None,
            "std": None,
        }

        if label_col not in df.columns:
            return stats

        series = df[label_col]

        # Missing-value statistics.
        null_count = series.null_count()
        stats["null_count"] = null_count
        stats["null_percentage"] = self._percentage(null_count, total_rows)

        non_null_series = series.drop_nulls()
        if len(non_null_series) > 0:
            # Zero-value statistics, relative to all rows (nulls included).
            zero_count = int((non_null_series == 0).sum())
            stats["zero_count"] = zero_count
            stats["zero_percentage"] = self._percentage(zero_count, total_rows)

            # Basic statistics. An aggregate may come back as None (e.g. the
            # sample std of a single value), so convert defensively instead
            # of calling float(None).
            aggregates = {
                "min": non_null_series.min(),
                "max": non_null_series.max(),
                "mean": non_null_series.mean(),
                "std": non_null_series.std(),
            }
            for key, value in aggregates.items():
                stats[key] = None if value is None else float(value)

        return stats

    def _print_null_analysis(self, stats: Dict[str, Any]) -> None:
        """Print the missing-value analysis.

        Args:
            stats: Dictionary produced by ``_analyze_null_values``.
        """
        total_cells = stats["total_cells"]
        total_null = stats["total_null_cells"]
        null_cols = stats["columns_with_null"]
        # Guarded: total_cells is 0 for an empty frame or no feature columns.
        null_pct = self._percentage(total_null, total_cells)

        print(" 缺失值统计:")
        print(f" 总单元格数: {total_cells:,}")
        print(f" 缺失单元格数: {total_null:,} ({null_pct:.2f}%)")
        print(f" 有缺失值的列数: {len(null_cols)}/{len(self.feature_cols)}")

        if null_cols:
            print(" 缺失值最多的5个特征:")
            sorted_cols = sorted(
                stats["null_counts"].items(),
                key=lambda x: x[1],
                reverse=True,
            )[:5]
            for col, count in sorted_cols:
                pct = stats["null_percentages"][col]
                print(f" {col}: {count:,} ({pct:.2f}%)")

    def _print_zero_analysis(self, stats: Dict[str, Any]) -> None:
        """Print the zero-value analysis.

        Args:
            stats: Dictionary produced by ``_analyze_zero_values``.
        """
        total_cells = stats["total_cells"]
        total_zero = stats["total_zero_cells"]
        zero_cols = stats["columns_with_zero"]
        # Guarded: total_cells is 0 for an empty frame or no feature columns.
        zero_pct = self._percentage(total_zero, total_cells)

        print(" 零值统计:")
        print(f" 总单元格数: {total_cells:,}")
        print(f" 零值单元格数: {total_zero:,} ({zero_pct:.2f}%)")
        print(f" 有零值的列数: {len(zero_cols)}/{len(self.feature_cols)}")

        if zero_cols:
            print(" 零值最多的5个特征:")
            sorted_cols = sorted(
                stats["zero_counts"].items(),
                key=lambda x: x[1],
                reverse=True,
            )[:5]
            for col, count in sorted_cols:
                pct = stats["zero_percentages"][col]
                print(f" {col}: {count:,} ({pct:.2f}%)")

    def _print_all_null_by_date(self, results: Dict[str, Any]) -> None:
        """Print the per-date all-null check.

        Args:
            results: Dictionary produced by ``_check_all_null_by_date``.
        """
        issues = results["issues"]

        print(" 按日期全空检查:")
        if not results["issues_found"]:
            print(" [正常] 未发现某天某列全为空的情况")
            return

        print(f" [警告] 发现 {len(issues)} 个问题:")

        # Group the affected columns by date for a compact display.
        by_date: Dict[Any, List[str]] = {}
        for issue in issues:
            by_date.setdefault(issue["date"], []).append(issue["column"])

        # Only the first five dates are listed to keep the report short.
        for date in sorted(by_date)[:5]:
            cols = by_date[date]
            print(f" 日期 {date}: {len(cols)} 列全为空")
            # Column names are spelled out only when there are few of them.
            if len(cols) <= 3:
                print(f" 列名: {', '.join(cols)}")

        if len(by_date) > 5:
            print(f" ... 还有 {len(by_date) - 5} 个日期存在问题")

    def _print_label_analysis(self, stats: Dict[str, Any]) -> None:
        """Print the label analysis.

        Args:
            stats: Dictionary produced by ``_analyze_label``.
        """
        print(f" 标签列统计 ({self.label_col}):")
        print(f" 总数: {stats['total_count']:,}")
        print(f" 缺失值: {stats['null_count']:,} ({stats['null_percentage']:.2f}%)")
        print(f" 零值: {stats['zero_count']:,} ({stats['zero_percentage']:.2f}%)")

        # Each statistic is printed only when it is available, so a missing
        # value (None) never reaches the float format specifier.
        for title, key in (
            ("最小值", "min"),
            ("最大值", "max"),
            ("均值", "mean"),
            ("标准差", "std"),
        ):
            value = stats[key]
            if value is not None:
                print(f" {title}: {value:.6f}")

    def get_summary(self) -> str:
        """Build a short text summary of the latest analysis.

        Returns:
            The summary string, or a placeholder message when ``analyze`` has
            not produced any results yet.
        """
        if not self.analysis_results:
            return "尚未执行分析"

        lines = ["数据质量分析摘要", "=" * 40]

        for split_name, results in self.analysis_results.items():
            lines.append(f"\n[{split_name.upper()}]")
            lines.append(f" 总行数: {results['total_rows']:,}")

            null_stats = results.get("null_analysis", {})
            if null_stats.get("columns_with_null"):
                lines.append(
                    f" 缺失值: {null_stats['total_null_cells']:,} 个单元格, "
                    f"{len(null_stats['columns_with_null'])} 列受影响"
                )

            zero_stats = results.get("zero_analysis", {})
            if zero_stats.get("columns_with_zero"):
                lines.append(
                    f" 零值: {zero_stats['total_zero_cells']:,} 个单元格, "
                    f"{len(zero_stats['columns_with_zero'])} 列受影响"
                )

            all_null = results.get("all_null_by_date", {})
            if all_null.get("issues_found"):
                lines.append(
                    f" [警告] 发现 {len(all_null['issues'])} 个日期列全空问题"
                )

        return "\n".join(lines)
|
|
|
|
|
|
def analyze_data_quality(
    data: Dict[str, Dict[str, Any]],
    feature_cols: Optional[List[str]] = None,
    label_col: Optional[str] = None,
    verbose: bool = True,
) -> Dict[str, Any]:
    """Convenience wrapper that runs a one-shot data quality analysis.

    Builds a ``DataQualityAnalyzer`` from the given options and immediately
    calls its ``analyze`` method on ``data`` without passing split names.

    Args:
        data: Data dictionary handed to ``DataQualityAnalyzer.analyze``.
        feature_cols: Names of the feature columns.
        label_col: Name of the label column.
        verbose: Whether to print detailed information.

    Returns:
        The analysis result dictionary returned by ``analyze``.
    """
    options = {
        "feature_cols": feature_cols,
        "label_col": label_col,
        "verbose": verbose,
    }
    return DataQualityAnalyzer(**options).analyze(data)
|