diff --git a/src/experiment/common.py b/src/experiment/common.py index 7826c28..b6fdc23 100644 --- a/src/experiment/common.py +++ b/src/experiment/common.py @@ -14,7 +14,7 @@ from src.factors import FactorEngine # ============================================================================= # 日期范围配置(正确的 train/val/test 三分法) # ============================================================================= -TRAIN_START = "20200101" +TRAIN_START = "20190101" TRAIN_END = "20231231" VAL_START = "20240101" VAL_END = "20241231" @@ -79,8 +79,8 @@ SELECTED_FACTORS = [ "GTJA_alpha002", "GTJA_alpha003", "GTJA_alpha004", - "GTJA_alpha005", - "GTJA_alpha006", + # "GTJA_alpha005", + # "GTJA_alpha006", "GTJA_alpha007", "GTJA_alpha008", "GTJA_alpha009", @@ -123,133 +123,133 @@ SELECTED_FACTORS = [ "GTJA_alpha048", "GTJA_alpha049", "GTJA_alpha050", - "GTJA_alpha051", - "GTJA_alpha052", - "GTJA_alpha053", - "GTJA_alpha054", - "GTJA_alpha056", - "GTJA_alpha057", - "GTJA_alpha058", - "GTJA_alpha059", - "GTJA_alpha060", - "GTJA_alpha061", - "GTJA_alpha062", - "GTJA_alpha063", - "GTJA_alpha064", - "GTJA_alpha065", - "GTJA_alpha066", - "GTJA_alpha067", - "GTJA_alpha068", - "GTJA_alpha070", - "GTJA_alpha071", - "GTJA_alpha072", - "GTJA_alpha073", - "GTJA_alpha074", - "GTJA_alpha076", - "GTJA_alpha077", - "GTJA_alpha078", - "GTJA_alpha079", - "GTJA_alpha080", - "GTJA_alpha081", - "GTJA_alpha082", - "GTJA_alpha083", - "GTJA_alpha084", - "GTJA_alpha085", - "GTJA_alpha086", - "GTJA_alpha087", - "GTJA_alpha088", - "GTJA_alpha089", - "GTJA_alpha090", - "GTJA_alpha091", - "GTJA_alpha092", - "GTJA_alpha093", - "GTJA_alpha094", - "GTJA_alpha095", - "GTJA_alpha096", - "GTJA_alpha097", - "GTJA_alpha098", - "GTJA_alpha099", - "GTJA_alpha100", - "GTJA_alpha101", - "GTJA_alpha102", - "GTJA_alpha103", - "GTJA_alpha104", - "GTJA_alpha105", - "GTJA_alpha106", - "GTJA_alpha107", - "GTJA_alpha108", - "GTJA_alpha109", - "GTJA_alpha110", - "GTJA_alpha111", - "GTJA_alpha112", - # "GTJA_alpha113", - 
"GTJA_alpha114", - "GTJA_alpha115", - "GTJA_alpha117", - "GTJA_alpha118", - "GTJA_alpha119", - "GTJA_alpha120", - # "GTJA_alpha121", - "GTJA_alpha122", - "GTJA_alpha123", - "GTJA_alpha124", - "GTJA_alpha125", - "GTJA_alpha126", - "GTJA_alpha127", - "GTJA_alpha128", - "GTJA_alpha129", - "GTJA_alpha130", - "GTJA_alpha131", - "GTJA_alpha132", - "GTJA_alpha133", - "GTJA_alpha134", - "GTJA_alpha135", - "GTJA_alpha136", - # "GTJA_alpha138", - "GTJA_alpha139", - # "GTJA_alpha140", - "GTJA_alpha141", - "GTJA_alpha142", - "GTJA_alpha145", - # "GTJA_alpha146", - "GTJA_alpha148", - "GTJA_alpha150", - "GTJA_alpha151", - "GTJA_alpha152", - "GTJA_alpha153", - "GTJA_alpha154", - "GTJA_alpha155", - "GTJA_alpha156", - "GTJA_alpha157", - "GTJA_alpha158", - "GTJA_alpha159", - "GTJA_alpha160", - "GTJA_alpha161", - "GTJA_alpha162", - "GTJA_alpha163", - "GTJA_alpha164", - # "GTJA_alpha165", - "GTJA_alpha166", - "GTJA_alpha167", - "GTJA_alpha168", - "GTJA_alpha169", - "GTJA_alpha170", - "GTJA_alpha171", - "GTJA_alpha173", - "GTJA_alpha174", - "GTJA_alpha175", - "GTJA_alpha176", - "GTJA_alpha177", - "GTJA_alpha178", - "GTJA_alpha179", - "GTJA_alpha180", - # "GTJA_alpha183", - "GTJA_alpha184", - "GTJA_alpha185", - "GTJA_alpha187", - "GTJA_alpha188", - "GTJA_alpha189", - "GTJA_alpha191", + # "GTJA_alpha051", + # "GTJA_alpha052", + # "GTJA_alpha053", + # "GTJA_alpha054", + # "GTJA_alpha056", + # "GTJA_alpha057", + # "GTJA_alpha058", + # "GTJA_alpha059", + # "GTJA_alpha060", + # "GTJA_alpha061", + # "GTJA_alpha062", + # "GTJA_alpha063", + # "GTJA_alpha064", + # "GTJA_alpha065", + # "GTJA_alpha066", + # "GTJA_alpha067", + # "GTJA_alpha068", + # "GTJA_alpha070", + # "GTJA_alpha071", + # "GTJA_alpha072", + # "GTJA_alpha073", + # "GTJA_alpha074", + # "GTJA_alpha076", + # "GTJA_alpha077", + # "GTJA_alpha078", + # "GTJA_alpha079", + # "GTJA_alpha080", + # "GTJA_alpha081", + # "GTJA_alpha082", + # "GTJA_alpha083", + # "GTJA_alpha084", + # "GTJA_alpha085", + # "GTJA_alpha086", + # "GTJA_alpha087", + 
# "GTJA_alpha088", + # "GTJA_alpha089", + # "GTJA_alpha090", + # "GTJA_alpha091", + # "GTJA_alpha092", + # "GTJA_alpha093", + # "GTJA_alpha094", + # "GTJA_alpha095", + # "GTJA_alpha096", + # "GTJA_alpha097", + # "GTJA_alpha098", + # "GTJA_alpha099", + # "GTJA_alpha100", + # "GTJA_alpha101", + # "GTJA_alpha102", + # "GTJA_alpha103", + # "GTJA_alpha104", + # "GTJA_alpha105", + # "GTJA_alpha106", + # "GTJA_alpha107", + # "GTJA_alpha108", + # "GTJA_alpha109", + # "GTJA_alpha110", + # "GTJA_alpha111", + # "GTJA_alpha112", + # # "GTJA_alpha113", + # "GTJA_alpha114", + # "GTJA_alpha115", + # "GTJA_alpha117", + # "GTJA_alpha118", + # "GTJA_alpha119", + # "GTJA_alpha120", + # # "GTJA_alpha121", + # "GTJA_alpha122", + # "GTJA_alpha123", + # "GTJA_alpha124", + # "GTJA_alpha125", + # "GTJA_alpha126", + # "GTJA_alpha127", + # "GTJA_alpha128", + # "GTJA_alpha129", + # "GTJA_alpha130", + # "GTJA_alpha131", + # "GTJA_alpha132", + # "GTJA_alpha133", + # "GTJA_alpha134", + # "GTJA_alpha135", + # "GTJA_alpha136", + # # "GTJA_alpha138", + # "GTJA_alpha139", + # # "GTJA_alpha140", + # "GTJA_alpha141", + # "GTJA_alpha142", + # "GTJA_alpha145", + # # "GTJA_alpha146", + # "GTJA_alpha148", + # "GTJA_alpha150", + # "GTJA_alpha151", + # "GTJA_alpha152", + # "GTJA_alpha153", + # "GTJA_alpha154", + # "GTJA_alpha155", + # "GTJA_alpha156", + # "GTJA_alpha157", + # "GTJA_alpha158", + # "GTJA_alpha159", + # "GTJA_alpha160", + # "GTJA_alpha161", + # # "GTJA_alpha162", + # "GTJA_alpha163", + # "GTJA_alpha164", + # # "GTJA_alpha165", + # "GTJA_alpha166", + # "GTJA_alpha167", + # "GTJA_alpha168", + # "GTJA_alpha169", + # "GTJA_alpha170", + # "GTJA_alpha171", + # "GTJA_alpha173", + # "GTJA_alpha174", + # "GTJA_alpha175", + # "GTJA_alpha176", + # "GTJA_alpha177", + # "GTJA_alpha178", + # "GTJA_alpha179", + # "GTJA_alpha180", + # # "GTJA_alpha183", + # "GTJA_alpha184", + # "GTJA_alpha185", + # "GTJA_alpha187", + # "GTJA_alpha188", + # "GTJA_alpha189", + # "GTJA_alpha191", "chip_dispersion_90", 
"chip_dispersion_70", "cost_skewness", @@ -488,6 +488,9 @@ MODEL_SAVE_DIR = "models" # 模型保存目录 # Top N 配置:每日推荐股票数量 TOP_N = 5 # 可调整为 10, 20 等 +# 训练数据跳过天数配置 +TRAIN_SKIP_DAYS = 300 # 跳过训练数据前252天的数据,避免训练初期数据不足的问题 + def get_output_path(model_type: str, test_start: str, test_end: str) -> str: """生成输出文件路径。 diff --git a/src/experiment/data_quality_analyzer.py b/src/experiment/data_quality_analyzer.py new file mode 100644 index 0000000..34842aa --- /dev/null +++ b/src/experiment/data_quality_analyzer.py @@ -0,0 +1,514 @@ +"""数据质量分析模块 + +提供数据质量检查功能,包括: +- 缺失值统计 +- 零值统计 +- 按日期检查全空列 +""" + +from typing import Any, Dict, List, Optional +import polars as pl +import numpy as np + + +class DataQualityAnalyzer: + """数据质量分析器 + + 用于分析训练数据的质量问题,帮助识别数据异常。 + + Attributes: + feature_cols: 特征列名列表 + label_col: 标签列名 + verbose: 是否打印详细信息 + """ + + def __init__( + self, + feature_cols: Optional[List[str]] = None, + label_col: Optional[str] = None, + verbose: bool = True, + ): + """初始化数据质量分析器 + + Args: + feature_cols: 特征列名列表 + label_col: 标签列名 + verbose: 是否打印详细信息 + """ + self.feature_cols = feature_cols or [] + self.label_col = label_col + self.verbose = verbose + self.analysis_results: Dict[str, Any] = {} + + def set_columns(self, feature_cols: List[str], label_col: str) -> None: + """设置要分析的列 + + Args: + feature_cols: 特征列名列表 + label_col: 标签列名 + """ + self.feature_cols = feature_cols + self.label_col = label_col + + def analyze( + self, + data: Dict[str, Dict[str, Any]], + split_names: Optional[List[str]] = None, + ) -> Dict[str, Any]: + """执行完整的数据质量分析 + + Args: + data: 数据字典,格式为 {"train": {...}, "val": {...}, "test": {...}} + split_names: 要分析的数据划分名称列表,默认为 ["train", "val", "test"] + + Returns: + 分析结果字典 + """ + if not split_names: + split_names = ["train", "val", "test"] + + if self.verbose: + print("\n" + "=" * 80) + print("数据质量分析报告") + print("=" * 80) + + self.analysis_results = {} + + for split_name in split_names: + if split_name not in data: + continue + + split_data = data[split_name] + 
raw_df = split_data.get("raw_data") + + if raw_df is None: + continue + + if self.verbose: + print(f"\n[{split_name.upper()} 数据集]") + print("-" * 40) + + split_results = self._analyze_split(raw_df, split_name) + self.analysis_results[split_name] = split_results + + if self.verbose: + print("\n" + "=" * 80) + + return self.analysis_results + + def _analyze_split( + self, + df: pl.DataFrame, + split_name: str, + ) -> Dict[str, Any]: + """分析单个数据集划分 + + Args: + df: 数据框 + split_name: 划分名称 + + Returns: + 分析结果字典 + """ + results = { + "total_rows": len(df), + "total_cols": len(df.columns), + "feature_cols": self.feature_cols, + "label_col": self.label_col, + "null_analysis": {}, + "zero_analysis": {}, + "all_null_by_date": {}, + } + + # 1. 分析特征列的缺失值 + null_stats = self._analyze_null_values(df, self.feature_cols) + results["null_analysis"] = null_stats + + if self.verbose: + self._print_null_analysis(null_stats) + + # 2. 分析特征列的零值 + zero_stats = self._analyze_zero_values(df, self.feature_cols) + results["zero_analysis"] = zero_stats + + if self.verbose: + self._print_zero_analysis(zero_stats) + + # 3. 检查是否存在某天某列全为空的情况 + all_null_by_date = self._check_all_null_by_date(df, self.feature_cols) + results["all_null_by_date"] = all_null_by_date + + if self.verbose: + self._print_all_null_by_date(all_null_by_date) + + # 4. 
分析标签列 + if self.label_col and self.label_col in df.columns: + label_stats = self._analyze_label(df, self.label_col) + results["label_analysis"] = label_stats + + if self.verbose: + self._print_label_analysis(label_stats) + + return results + + def _analyze_null_values( + self, + df: pl.DataFrame, + cols: List[str], + ) -> Dict[str, Any]: + """分析缺失值 + + Args: + df: 数据框 + cols: 要分析的列名列表 + + Returns: + 缺失值统计字典 + """ + stats = { + "total_cells": len(df) * len(cols), + "null_counts": {}, + "null_percentages": {}, + "columns_with_null": [], + "total_null_cells": 0, + } + + for col in cols: + if col not in df.columns: + continue + + null_count = df[col].null_count() + if null_count > 0: + null_pct = null_count / len(df) * 100 + stats["null_counts"][col] = null_count + stats["null_percentages"][col] = null_pct + stats["columns_with_null"].append(col) + stats["total_null_cells"] += null_count + + return stats + + def _analyze_zero_values( + self, + df: pl.DataFrame, + cols: List[str], + ) -> Dict[str, Any]: + """分析零值 + + Args: + df: 数据框 + cols: 要分析的列名列表 + + Returns: + 零值统计字典 + """ + stats = { + "total_cells": len(df) * len(cols), + "zero_counts": {}, + "zero_percentages": {}, + "columns_with_zero": [], + "total_zero_cells": 0, + } + + for col in cols: + if col not in df.columns: + continue + + # 计算零值数量(排除空值) + non_null_series = df[col].drop_nulls() + if len(non_null_series) == 0: + continue + + zero_count = (non_null_series == 0).sum() + if zero_count > 0: + zero_pct = zero_count / len(df) * 100 + stats["zero_counts"][col] = int(zero_count) + stats["zero_percentages"][col] = zero_pct + stats["columns_with_zero"].append(col) + stats["total_zero_cells"] += int(zero_count) + + return stats + + def _check_all_null_by_date( + self, + df: pl.DataFrame, + cols: List[str], + ) -> Dict[str, Any]: + """检查是否存在某天某列全为空的情况 + + 使用 polars lazy frame 进行内存安全的高效计算。 + + Args: + df: 数据框 + cols: 要分析的列名列表 + + Returns: + 全空检查结果字典 + """ + results = { + "issues_found": False, + "issues": [], + } + + 
if "trade_date" not in df.columns: + return results + + # 过滤掉不在表中的列 + valid_cols = [c for c in cols if c in df.columns] + if not valid_cols: + return results + + # 使用 lazy frame 进行查询优化 + lf = df.lazy() + + # 核心步骤:只计算 null_count 和总行数 (聚合后数据量极小) + # 为每个列创建单独的 null_count 聚合表达式 + agg_exprs = [ + pl.col(col).null_count().alias(f"{col}_nulls") for col in valid_cols + ] + agg_exprs.append(pl.len().alias("total_rows")) + + agg_lf = lf.group_by("trade_date").agg(agg_exprs) + + # 收集结果 (此时 agg_df 行数通常只有几百到几千行) + agg_df = agg_lf.collect() + + # 在这个已经"脱水"的小表上进行逻辑检查 + issues = [] + for col in valid_cols: + null_col = f"{col}_nulls" + # 找出 null 数量等于总行数的日期 + bad_dates = agg_df.filter( + (pl.col(null_col) == pl.col("total_rows")) & (pl.col("total_rows") > 0) + ).select(["trade_date", "total_rows"]) + + if not bad_dates.is_empty(): + for row in bad_dates.to_dicts(): + issues.append( + { + "date": row["trade_date"], + "column": col, + "total_rows": row["total_rows"], + } + ) + + if issues: + results["issues_found"] = True + results["issues"] = issues + + return results + + def _analyze_label( + self, + df: pl.DataFrame, + label_col: str, + ) -> Dict[str, Any]: + """分析标签列 + + Args: + df: 数据框 + label_col: 标签列名 + + Returns: + 标签分析字典 + """ + stats = { + "total_count": len(df), + "null_count": 0, + "null_percentage": 0.0, + "zero_count": 0, + "zero_percentage": 0.0, + "min": None, + "max": None, + "mean": None, + "std": None, + } + + if label_col not in df.columns: + return stats + + series = df[label_col] + + # 缺失值统计 + null_count = series.null_count() + stats["null_count"] = null_count + stats["null_percentage"] = null_count / len(df) * 100 if len(df) > 0 else 0 + + # 零值统计 + non_null_series = series.drop_nulls() + if len(non_null_series) > 0: + zero_count = (non_null_series == 0).sum() + stats["zero_count"] = int(zero_count) + stats["zero_percentage"] = zero_count / len(df) * 100 + + # 基本统计量 + stats["min"] = float(non_null_series.min()) + stats["max"] = float(non_null_series.max()) + 
stats["mean"] = float(non_null_series.mean()) + stats["std"] = float(non_null_series.std()) + + return stats + + def _print_null_analysis(self, stats: Dict[str, Any]) -> None: + """打印缺失值分析结果 + + Args: + stats: 缺失值统计字典 + """ + total_cells = stats["total_cells"] + total_null = stats["total_null_cells"] + null_cols = stats["columns_with_null"] + + print(f" 缺失值统计:") + print(f" 总单元格数: {total_cells:,}") + print( + f" 缺失单元格数: {total_null:,} ({total_null / total_cells * 100:.2f}%)" + ) + print(f" 有缺失值的列数: {len(null_cols)}/{len(self.feature_cols)}") + + if null_cols: + print(f" 缺失值最多的5个特征:") + sorted_cols = sorted( + stats["null_counts"].items(), + key=lambda x: x[1], + reverse=True, + )[:5] + for col, count in sorted_cols: + pct = stats["null_percentages"][col] + print(f" {col}: {count:,} ({pct:.2f}%)") + + def _print_zero_analysis(self, stats: Dict[str, Any]) -> None: + """打印零值分析结果 + + Args: + stats: 零值统计字典 + """ + total_cells = stats["total_cells"] + total_zero = stats["total_zero_cells"] + zero_cols = stats["columns_with_zero"] + + print(f" 零值统计:") + print(f" 总单元格数: {total_cells:,}") + print( + f" 零值单元格数: {total_zero:,} ({total_zero / total_cells * 100:.2f}%)" + ) + print(f" 有零值的列数: {len(zero_cols)}/{len(self.feature_cols)}") + + if zero_cols: + print(f" 零值最多的5个特征:") + sorted_cols = sorted( + stats["zero_counts"].items(), + key=lambda x: x[1], + reverse=True, + )[:5] + for col, count in sorted_cols: + pct = stats["zero_percentages"][col] + print(f" {col}: {count:,} ({pct:.2f}%)") + + def _print_all_null_by_date(self, results: Dict[str, Any]) -> None: + """打印按日期全空检查结果 + + Args: + results: 全空检查结果字典 + """ + issues = results["issues"] + + print(f" 按日期全空检查:") + if results["issues_found"]: + print(f" [警告] 发现 {len(issues)} 个问题:") + # 按日期分组显示 + by_date = {} + for issue in issues: + date = issue["date"] + if date not in by_date: + by_date[date] = [] + by_date[date].append(issue["column"]) + + for date in sorted(by_date.keys())[:5]: # 只显示前5个日期 + cols = by_date[date] + print(f" 日期 
{date}: {len(cols)} 列全为空") + if len(cols) <= 3: + print(f" 列名: {', '.join(cols)}") + + if len(by_date) > 5: + print(f" ... 还有 {len(by_date) - 5} 个日期存在问题") + else: + print(f" [正常] 未发现某天某列全为空的情况") + + def _print_label_analysis(self, stats: Dict[str, Any]) -> None: + """打印标签分析结果 + + Args: + stats: 标签分析字典 + """ + print(f" 标签列统计 ({self.label_col}):") + print(f" 总数: {stats['total_count']:,}") + print(f" 缺失值: {stats['null_count']:,} ({stats['null_percentage']:.2f}%)") + print(f" 零值: {stats['zero_count']:,} ({stats['zero_percentage']:.2f}%)") + + if stats["mean"] is not None: + print(f" 最小值: {stats['min']:.6f}") + print(f" 最大值: {stats['max']:.6f}") + print(f" 均值: {stats['mean']:.6f}") + print(f" 标准差: {stats['std']:.6f}") + + def get_summary(self) -> str: + """获取分析结果摘要 + + Returns: + 摘要字符串 + """ + if not self.analysis_results: + return "尚未执行分析" + + lines = ["数据质量分析摘要", "=" * 40] + + for split_name, results in self.analysis_results.items(): + lines.append(f"\n[{split_name.upper()}]") + lines.append(f" 总行数: {results['total_rows']:,}") + + null_stats = results.get("null_analysis", {}) + if null_stats.get("columns_with_null"): + lines.append( + f" 缺失值: {null_stats['total_null_cells']:,} 个单元格, " + f"{len(null_stats['columns_with_null'])} 列受影响" + ) + + zero_stats = results.get("zero_analysis", {}) + if zero_stats.get("columns_with_zero"): + lines.append( + f" 零值: {zero_stats['total_zero_cells']:,} 个单元格, " + f"{len(zero_stats['columns_with_zero'])} 列受影响" + ) + + all_null = results.get("all_null_by_date", {}) + if all_null.get("issues_found"): + lines.append( + f" [警告] 发现 {len(all_null['issues'])} 个日期列全空问题" + ) + + return "\n".join(lines) + + +def analyze_data_quality( + data: Dict[str, Dict[str, Any]], + feature_cols: Optional[List[str]] = None, + label_col: Optional[str] = None, + verbose: bool = True, +) -> Dict[str, Any]: + """便捷函数:执行数据质量分析 + + Args: + data: 数据字典 + feature_cols: 特征列名列表 + label_col: 标签列名 + verbose: 是否打印详细信息 + + Returns: + 分析结果字典 + """ + analyzer = 
DataQualityAnalyzer( + feature_cols=feature_cols, + label_col=label_col, + verbose=verbose, + ) + return analyzer.analyze(data) diff --git a/src/experiment/learn_to_rank.py b/src/experiment/learn_to_rank.py index 0f51caa..5eca623 100644 --- a/src/experiment/learn_to_rank.py +++ b/src/experiment/learn_to_rank.py @@ -37,6 +37,7 @@ from src.experiment.common import ( get_model_save_path, save_model_with_factors, TOP_N, + TRAIN_SKIP_DAYS, ) # 训练类型标识 @@ -155,6 +156,7 @@ def main(): filters=[STFilter(data_router=engine.router)], stock_pool_filter_func=stock_pool_filter, stock_pool_required_columns=STOCK_FILTER_REQUIRED_COLUMNS, + train_skip_days=TRAIN_SKIP_DAYS, ) # 4. 创建 RankTask diff --git a/src/experiment/regression.py b/src/experiment/regression.py index 0eee2c7..045debc 100644 --- a/src/experiment/regression.py +++ b/src/experiment/regression.py @@ -38,6 +38,7 @@ from src.experiment.common import ( get_model_save_path, save_model_with_factors, TOP_N, + TRAIN_SKIP_DAYS, ) # 训练类型标识 @@ -51,55 +52,55 @@ TRAINING_TYPE = "regression" # 排除的因子列表 EXCLUDED_FACTORS = [ - 'GTJA_alpha036', - 'GTJA_alpha032', - 'GTJA_alpha010', - 'GTJA_alpha005', - 'CP', - 'BP', - 'debt_to_equity', - 'current_ratio', - 'GTJA_alpha002', - 'GTJA_alpha027', - 'GTJA_alpha064', - 'GTJA_alpha062', - 'GTJA_alpha043', - 'GTJA_alpha044', - 'GTJA_alpha120', - 'GTJA_alpha117', - 'GTJA_alpha103', - 'GTJA_alpha104', - 'GTJA_alpha105', - 'GTJA_alpha073', - 'GTJA_alpha077', - 'GTJA_alpha085', - 'GTJA_alpha090', - 'GTJA_alpha087', - 'GTJA_alpha083', - 'GTJA_alpha092', - 'GTJA_alpha133', - 'GTJA_alpha131', - 'GTJA_alpha126', - 'GTJA_alpha124', - 'GTJA_alpha162', - 'GTJA_alpha164', - 'GTJA_alpha157', - 'GTJA_alpha177', - 'price_to_avg_cost', - 'cost_skewness', - 'GTJA_alpha191', - 'GTJA_alpha180', - 'history_position', - 'bottom_profit', - 'mean_median_dev', - 'smart_money_accumulation', -'GTJA_alpha013', -'GTJA_alpha099', -'GTJA_alpha107', -'GTJA_alpha119', -'GTJA_alpha141', -'GTJA_alpha130', -'GTJA_alpha173', + 
"GTJA_alpha036", + "GTJA_alpha032", + "GTJA_alpha010", + "GTJA_alpha005", + "CP", + "BP", + "debt_to_equity", + "current_ratio", + "GTJA_alpha002", + "GTJA_alpha027", + "GTJA_alpha064", + "GTJA_alpha062", + "GTJA_alpha043", + "GTJA_alpha044", + "GTJA_alpha120", + "GTJA_alpha117", + "GTJA_alpha103", + "GTJA_alpha104", + "GTJA_alpha105", + "GTJA_alpha073", + "GTJA_alpha077", + "GTJA_alpha085", + "GTJA_alpha090", + "GTJA_alpha087", + "GTJA_alpha083", + "GTJA_alpha092", + "GTJA_alpha133", + "GTJA_alpha131", + "GTJA_alpha126", + "GTJA_alpha124", + "GTJA_alpha162", + "GTJA_alpha164", + "GTJA_alpha157", + "GTJA_alpha177", + "price_to_avg_cost", + "cost_skewness", + "GTJA_alpha191", + "GTJA_alpha180", + "history_position", + "bottom_profit", + "mean_median_dev", + "smart_money_accumulation", + "GTJA_alpha013", + "GTJA_alpha099", + "GTJA_alpha107", + "GTJA_alpha119", + "GTJA_alpha141", + "GTJA_alpha130", + "GTJA_alpha173", ] # 模型参数配置 @@ -184,6 +185,7 @@ def main(): filters=[STFilter(data_router=engine.router)], stock_pool_filter_func=stock_pool_filter, stock_pool_required_columns=STOCK_FILTER_REQUIRED_COLUMNS, + train_skip_days=TRAIN_SKIP_DAYS, ) # 4. 创建 RegressionTask diff --git a/src/experiment/tabm_regression.py b/src/experiment/tabm_regression.py new file mode 100644 index 0000000..16f04a6 --- /dev/null +++ b/src/experiment/tabm_regression.py @@ -0,0 +1,375 @@ +# %% md +# # TabM 回归训练流程 +# +# 使用 TabM (Tabular MLP with Ensembles) 模型进行回归训练。 +# TabM 通过内置集成机制(ensemble_size=32)实现高效的多模型集成。 +# %% md +# ## 1. 
导入依赖 +# %% +import os + +from src.factors import FactorEngine +from src.training import ( + FactorManager, + DataPipeline, + TabMRegressionTask, + NullFiller, + Winsorizer, + StandardScaler, + CrossSectionalStandardScaler, +) +from src.training.core.trainer_v2 import Trainer +from src.training.components.filters import STFilter +from src.experiment.common import ( + SELECTED_FACTORS, + FACTOR_DEFINITIONS, + LABEL_NAME, + LABEL_FACTOR, + TRAIN_START, + TRAIN_END, + VAL_START, + VAL_END, + TEST_START, + TEST_END, + stock_pool_filter, + STOCK_FILTER_REQUIRED_COLUMNS, + OUTPUT_DIR, + SAVE_PREDICTIONS, + SAVE_MODEL, + get_model_save_path, + save_model_with_factors, + TOP_N, + TRAIN_SKIP_DAYS, +) +from src.experiment.data_quality_analyzer import DataQualityAnalyzer + +# 训练类型标识 +TRAINING_TYPE = "tabm_regression" + +# %% md +# ## 2. 训练特定配置 +# %% +# Label 配置(从 common.py 统一导入) +# LABEL_NAME 和 LABEL_FACTOR 已在 common.py 中绑定,只需从 common 导入 + +# 排除的因子列表(与 LightGBM 回归保持一致) +EXCLUDED_FACTORS = [ + # "GTJA_alpha001", + # "GTJA_alpha002", + # "GTJA_alpha003", + # "GTJA_alpha004", + # "GTJA_alpha005", + # "GTJA_alpha006", + # "GTJA_alpha007", + # "GTJA_alpha008", + # "GTJA_alpha009", + # "GTJA_alpha010", + # "GTJA_alpha011", + # "GTJA_alpha012", + # "GTJA_alpha013", + # "GTJA_alpha014", + # "GTJA_alpha015", + # "GTJA_alpha016", + # "GTJA_alpha017", + # "GTJA_alpha018", + # "GTJA_alpha019", + # "GTJA_alpha020", + # "GTJA_alpha022", + # "GTJA_alpha023", + # "GTJA_alpha024", + # "GTJA_alpha025", + # "GTJA_alpha026", + # "GTJA_alpha027", + # "GTJA_alpha028", + # "GTJA_alpha029", + # "GTJA_alpha031", + # "GTJA_alpha032", + # "GTJA_alpha033", + # "GTJA_alpha034", + # "GTJA_alpha035", + # "GTJA_alpha036", + # "GTJA_alpha037", + # # "GTJA_alpha038", + # "GTJA_alpha039", + # "GTJA_alpha040", + # "GTJA_alpha041", + # "GTJA_alpha042", + # "GTJA_alpha043", + # "GTJA_alpha044", + # "GTJA_alpha045", + # "GTJA_alpha046", + # "GTJA_alpha047", + # "GTJA_alpha048", + # "GTJA_alpha049", + # 
"GTJA_alpha050", + # "GTJA_alpha051", + # "GTJA_alpha052", + # "GTJA_alpha053", + # "GTJA_alpha054", + # "GTJA_alpha056", + # "GTJA_alpha057", + # "GTJA_alpha058", + # "GTJA_alpha059", + # "GTJA_alpha060", + # "GTJA_alpha061", + # "GTJA_alpha062", + # "GTJA_alpha063", + # "GTJA_alpha064", + # "GTJA_alpha065", + # "GTJA_alpha066", + # "GTJA_alpha067", + # "GTJA_alpha068", + # "GTJA_alpha070", + # "GTJA_alpha071", + # "GTJA_alpha072", + # "GTJA_alpha073", + # "GTJA_alpha074", + # "GTJA_alpha076", + # "GTJA_alpha077", + # "GTJA_alpha078", + # "GTJA_alpha079", + # "GTJA_alpha080", + # "GTJA_alpha081", + # "GTJA_alpha082", + # "GTJA_alpha083", + # "GTJA_alpha084", + # "GTJA_alpha085", + # "GTJA_alpha086", + # "GTJA_alpha087", + # "GTJA_alpha088", + # "GTJA_alpha089", + # "GTJA_alpha090", + # "GTJA_alpha091", + # "GTJA_alpha092", + # "GTJA_alpha093", + # "GTJA_alpha094", + # "GTJA_alpha095", + # "GTJA_alpha096", + # "GTJA_alpha097", + # "GTJA_alpha098", + # "GTJA_alpha099", + # "GTJA_alpha100", + # "GTJA_alpha101", + # "GTJA_alpha102", + # "GTJA_alpha103", + # "GTJA_alpha104", + # "GTJA_alpha105", + # "GTJA_alpha106", + # "GTJA_alpha107", + # "GTJA_alpha108", + # "GTJA_alpha109", + # "GTJA_alpha110", + # "GTJA_alpha111", + # "GTJA_alpha112", + # # "GTJA_alpha113", + # "GTJA_alpha114", + # "GTJA_alpha115", + # "GTJA_alpha117", + # "GTJA_alpha118", + # "GTJA_alpha119", + # "GTJA_alpha120", + # # "GTJA_alpha121", + # "GTJA_alpha122", + # "GTJA_alpha123", + # "GTJA_alpha124", + # "GTJA_alpha125", + # "GTJA_alpha126", + # "GTJA_alpha127", + # "GTJA_alpha128", + # "GTJA_alpha129", + # "GTJA_alpha130", + # "GTJA_alpha131", + # "GTJA_alpha132", + # "GTJA_alpha133", + # "GTJA_alpha134", + # "GTJA_alpha135", + # "GTJA_alpha136", + # # "GTJA_alpha138", + # "GTJA_alpha139", + # # "GTJA_alpha140", + # "GTJA_alpha141", + # "GTJA_alpha142", + # "GTJA_alpha145", + # # "GTJA_alpha146", + # "GTJA_alpha148", + # "GTJA_alpha150", + # "GTJA_alpha151", + # "GTJA_alpha152", + # 
"GTJA_alpha153", + # "GTJA_alpha154", + # "GTJA_alpha155", + # "GTJA_alpha156", + # "GTJA_alpha157", + # "GTJA_alpha158", + # "GTJA_alpha159", + # "GTJA_alpha160", + # "GTJA_alpha161", + # "GTJA_alpha162", + # "GTJA_alpha163", + # "GTJA_alpha164", + # # "GTJA_alpha165", + # "GTJA_alpha166", + # "GTJA_alpha167", + # "GTJA_alpha168", + # "GTJA_alpha169", + # "GTJA_alpha170", + # "GTJA_alpha171", + # "GTJA_alpha173", + # "GTJA_alpha174", + # "GTJA_alpha175", + # "GTJA_alpha176", + # "GTJA_alpha177", + # "GTJA_alpha178", + # "GTJA_alpha179", + # "GTJA_alpha180", + # # "GTJA_alpha183", + # "GTJA_alpha184", + # "GTJA_alpha185", + # "GTJA_alpha187", + # "GTJA_alpha188", + # "GTJA_alpha189", + # "GTJA_alpha191", + # "chip_dispersion_90", + # "chip_dispersion_70", + # "cost_skewness", + # "dispersion_change_20", + # "price_to_avg_cost", + # "price_to_median_cost", + # "mean_median_dev", + # "trap_pressure", + # "bottom_profit", + # "history_position", + # "winner_rate_surge_5", + # "winner_rate_cs_rank", + # "winner_rate_dev_20", + # "winner_rate_volatility", + # "smart_money_accumulation", + # "winner_vol_corr_20", + # "cost_base_momentum", + # "bottom_cost_stability", + # "pivot_reversion", + # "chip_transition", +] + +# TabM 模型参数配置(来自用户提供的示例代码) +MODEL_PARAMS = { + # ==================== MLP 结构 ==================== + "n_blocks": 3, # MLP 层数 + "d_block": 256, # 每层神经元数 + "dropout": 0.3, # Dropout 率 + # ==================== 集成机制 ==================== + "ensemble_size": 32, # 内置集成大小(模拟 32 个模型集成) + # ==================== 训练参数 ==================== + "batch_size": 1024, # 批次大小 + "learning_rate": 1e-3, # 学习率 + "weight_decay": 1e-5, # 权重衰减 + "epochs": 100, # 训练轮数 + # ==================== 早停 ==================== + "early_stopping_patience": 30, # 早停耐心值 +} + +# 日期范围配置 +date_range = { + "train": (TRAIN_START, TRAIN_END), + "val": (VAL_START, VAL_END), + "test": (TEST_START, TEST_END), +} + +# 输出配置 +output_config = { + "output_dir": OUTPUT_DIR, + "output_filename": 
"tabm_regression_output.csv", + "save_predictions": SAVE_PREDICTIONS, + "save_model": SAVE_MODEL, + "model_save_path": get_model_save_path(TRAINING_TYPE), + "top_n": TOP_N, +} + + +def main(): + """主函数""" + print("\n" + "=" * 80) + print("TabM 回归模型训练") + print("=" * 80) + + # 1. 创建 FactorEngine + print("\n[1] 创建 FactorEngine") + engine = FactorEngine() + + # 2. 创建 FactorManager + print("\n[2] 创建 FactorManager") + factor_manager = FactorManager( + selected_factors=SELECTED_FACTORS, + factor_definitions=FACTOR_DEFINITIONS, + label_factor=LABEL_FACTOR, + excluded_factors=EXCLUDED_FACTORS, + ) + + # 3. 创建 DataPipeline + # 【关键】TabM 需要标准化输入,使用 StandardScaler + # 处理顺序:NullFiller -> Winsorizer -> StandardScaler + print("\n[3] 创建 DataPipeline") + pipeline = DataPipeline( + factor_manager=factor_manager, + processor_configs=[ + (NullFiller, {"strategy": "mean"}), + (Winsorizer, {"lower": 0.01, "upper": 0.99}), # 先缩尾处理厚尾分布 + (StandardScaler, {}), # TabM 需要标准化输入 + ], + label_processor_configs=[ + # 对 label 进行缩尾处理(去除极端收益率) + (Winsorizer, {"lower": 0.05, "upper": 0.95}), + ], + filters=[STFilter(data_router=engine.router)], + stock_pool_filter_func=stock_pool_filter, + stock_pool_required_columns=STOCK_FILTER_REQUIRED_COLUMNS, + train_skip_days=TRAIN_SKIP_DAYS, + ) + + # 4. 创建 TabMRegressionTask + print("\n[4] 创建 TabMRegressionTask") + task = TabMRegressionTask( + model_params=MODEL_PARAMS, + label_name=LABEL_NAME, + ) + + # 5. 创建 Trainer + print("\n[5] 创建 Trainer") + trainer = Trainer( + data_pipeline=pipeline, + task=task, + output_config=output_config, + verbose=True, + ) + + # 6. 执行训练 + print("\n[6] 执行训练") + results = trainer.run(engine=engine, date_range=date_range) + + # 7. 绘制训练曲线 + print("\n[7] 绘制训练曲线") + task.plot_training_metrics( + output_path=os.path.join(OUTPUT_DIR, "tabm_training_curve.png") + ) + + # 8. 
保存模型和因子信息(如果启用) + if SAVE_MODEL: + print("\n[8] 保存模型和因子信息") + save_model_with_factors( + model=task.get_model(), + model_path=output_config["model_save_path"], + selected_factors=SELECTED_FACTORS, + factor_definitions=FACTOR_DEFINITIONS, + fitted_processors=pipeline.get_fitted_processors(), + ) + + print("\n" + "=" * 80) + print("TabM 训练流程完成!") + print(f"结果保存路径: {os.path.join(OUTPUT_DIR, 'tabm_regression_output.csv')}") + print("=" * 80) + + return results + + +if __name__ == "__main__": + main() diff --git a/src/experiment/tabpfn_regression.py b/src/experiment/tabpfn_regression.py new file mode 100644 index 0000000..c647839 --- /dev/null +++ b/src/experiment/tabpfn_regression.py @@ -0,0 +1,425 @@ +# %% md +# # TabPFN 回归训练流程 +# +# 使用 TabPFN (Prior-Data Fitted Network) 进行回归预测。 +# TabPFN 通过上下文学习进行预测,无需传统梯度下降训练过程。 +# %% md +# ## 1. 导入依赖 +# %% +import os + +from src.factors import FactorEngine +from src.training import ( + FactorManager, + DataPipeline, + NullFiller, + Winsorizer, + StandardScaler, + CrossSectionalStandardScaler, +) +from src.training.core.trainer_v2 import Trainer +from src.training.components.filters import STFilter +from src.training.components.models import TabPFNModel +from src.experiment.common import ( + SELECTED_FACTORS, + FACTOR_DEFINITIONS, + LABEL_NAME, + LABEL_FACTOR, + TRAIN_START, + TRAIN_END, + VAL_START, + VAL_END, + TEST_START, + TEST_END, + stock_pool_filter, + STOCK_FILTER_REQUIRED_COLUMNS, + OUTPUT_DIR, + SAVE_PREDICTIONS, + SAVE_MODEL, + get_model_save_path, + save_model_with_factors, + TOP_N, +) + +# 训练类型标识 +TRAINING_TYPE = "tabpfn" + +# %% md +# ## 2. 
训练特定配置 +# %% +# Label 配置(从 common.py 统一导入) +# LABEL_NAME 和 LABEL_FACTOR 已在 common.py 中绑定,只需从 common 导入 + +# 排除的因子列表(与 regression.py 保持一致) +EXCLUDED_FACTORS = [ + "GTJA_alpha001", + "GTJA_alpha002", + "GTJA_alpha003", + "GTJA_alpha004", + "GTJA_alpha005", + "GTJA_alpha006", + "GTJA_alpha007", + "GTJA_alpha008", + "GTJA_alpha009", + "GTJA_alpha010", + "GTJA_alpha011", + "GTJA_alpha012", + "GTJA_alpha013", + "GTJA_alpha014", + "GTJA_alpha015", + "GTJA_alpha016", + "GTJA_alpha017", + "GTJA_alpha018", + "GTJA_alpha019", + "GTJA_alpha020", + "GTJA_alpha022", + "GTJA_alpha023", + "GTJA_alpha024", + "GTJA_alpha025", + "GTJA_alpha026", + "GTJA_alpha027", + "GTJA_alpha028", + "GTJA_alpha029", + "GTJA_alpha031", + "GTJA_alpha032", + "GTJA_alpha033", + "GTJA_alpha034", + "GTJA_alpha035", + "GTJA_alpha036", + "GTJA_alpha037", + # "GTJA_alpha038", + "GTJA_alpha039", + "GTJA_alpha040", + "GTJA_alpha041", + "GTJA_alpha042", + "GTJA_alpha043", + "GTJA_alpha044", + "GTJA_alpha045", + "GTJA_alpha046", + "GTJA_alpha047", + "GTJA_alpha048", + "GTJA_alpha049", + "GTJA_alpha050", + "GTJA_alpha051", + "GTJA_alpha052", + "GTJA_alpha053", + "GTJA_alpha054", + "GTJA_alpha056", + "GTJA_alpha057", + "GTJA_alpha058", + "GTJA_alpha059", + "GTJA_alpha060", + "GTJA_alpha061", + "GTJA_alpha062", + "GTJA_alpha063", + "GTJA_alpha064", + "GTJA_alpha065", + "GTJA_alpha066", + "GTJA_alpha067", + "GTJA_alpha068", + "GTJA_alpha070", + "GTJA_alpha071", + "GTJA_alpha072", + "GTJA_alpha073", + "GTJA_alpha074", + "GTJA_alpha076", + "GTJA_alpha077", + "GTJA_alpha078", + "GTJA_alpha079", + "GTJA_alpha080", + "GTJA_alpha081", + "GTJA_alpha082", + "GTJA_alpha083", + "GTJA_alpha084", + "GTJA_alpha085", + "GTJA_alpha086", + "GTJA_alpha087", + "GTJA_alpha088", + "GTJA_alpha089", + "GTJA_alpha090", + "GTJA_alpha091", + "GTJA_alpha092", + "GTJA_alpha093", + "GTJA_alpha094", + "GTJA_alpha095", + "GTJA_alpha096", + "GTJA_alpha097", + "GTJA_alpha098", + "GTJA_alpha099", + "GTJA_alpha100", + "GTJA_alpha101", + 
"GTJA_alpha102", + "GTJA_alpha103", + "GTJA_alpha104", + "GTJA_alpha105", + "GTJA_alpha106", + "GTJA_alpha107", + "GTJA_alpha108", + "GTJA_alpha109", + "GTJA_alpha110", + "GTJA_alpha111", + "GTJA_alpha112", + # "GTJA_alpha113", + "GTJA_alpha114", + "GTJA_alpha115", + "GTJA_alpha117", + "GTJA_alpha118", + "GTJA_alpha119", + "GTJA_alpha120", + # "GTJA_alpha121", + "GTJA_alpha122", + "GTJA_alpha123", + "GTJA_alpha124", + "GTJA_alpha125", + "GTJA_alpha126", + "GTJA_alpha127", + "GTJA_alpha128", + "GTJA_alpha129", + "GTJA_alpha130", + "GTJA_alpha131", + "GTJA_alpha132", + "GTJA_alpha133", + "GTJA_alpha134", + "GTJA_alpha135", + "GTJA_alpha136", + # "GTJA_alpha138", + "GTJA_alpha139", + # "GTJA_alpha140", + "GTJA_alpha141", + "GTJA_alpha142", + "GTJA_alpha145", + # "GTJA_alpha146", + "GTJA_alpha148", + "GTJA_alpha150", + "GTJA_alpha151", + "GTJA_alpha152", + "GTJA_alpha153", + "GTJA_alpha154", + "GTJA_alpha155", + "GTJA_alpha156", + "GTJA_alpha157", + "GTJA_alpha158", + "GTJA_alpha159", + "GTJA_alpha160", + "GTJA_alpha161", + "GTJA_alpha162", + "GTJA_alpha163", + "GTJA_alpha164", + # "GTJA_alpha165", + "GTJA_alpha166", + "GTJA_alpha167", + "GTJA_alpha168", + "GTJA_alpha169", + "GTJA_alpha170", + "GTJA_alpha171", + "GTJA_alpha173", + "GTJA_alpha174", + "GTJA_alpha175", + "GTJA_alpha176", + "GTJA_alpha177", + "GTJA_alpha178", + "GTJA_alpha179", + "GTJA_alpha180", + # "GTJA_alpha183", + "GTJA_alpha184", + "GTJA_alpha185", + "GTJA_alpha187", + "GTJA_alpha188", + "GTJA_alpha189", + "GTJA_alpha191", + "chip_dispersion_90", + "chip_dispersion_70", + "cost_skewness", + "dispersion_change_20", + "price_to_avg_cost", + "price_to_median_cost", + "mean_median_dev", + "trap_pressure", + "bottom_profit", + "history_position", + "winner_rate_surge_5", + "winner_rate_cs_rank", + "winner_rate_dev_20", + "winner_rate_volatility", + "smart_money_accumulation", + "winner_vol_corr_20", + "cost_base_momentum", + "bottom_cost_stability", + "pivot_reversion", + "chip_transition", +] + +# 模型参数配置 
+MODEL_PARAMS = { + # ==================== 设备配置 ==================== + "device": "cuda", # 计算设备: "cuda" 或 "cpu"(默认 cuda) + # ==================== 上下文限制 ==================== + "max_context_size": 100, # 16GB GPU 建议 1000-3000,32GB 可尝试 5000-8000 +} + +# 日期范围配置 +date_range = { + "train": (TRAIN_START, TRAIN_END), + "val": (VAL_START, VAL_END), + "test": (TEST_START, TEST_END), +} + +# 输出配置 +output_config = { + "output_dir": OUTPUT_DIR, + "output_filename": "tabpfn_output.csv", + "save_predictions": SAVE_PREDICTIONS, + "save_model": SAVE_MODEL, + "model_save_path": get_model_save_path(TRAINING_TYPE), + "top_n": TOP_N, +} + + +# %% md +# ## 3. 自定义 TabPFN 任务 +# %% +from src.training.tasks import RegressionTask + + +class TabPFNTask(RegressionTask): + """TabPFN 回归任务 + + 继承自 RegressionTask,但使用 TabPFNModel 作为模型。 + TabPFN 不需要传统的训练过程,而是通过上下文学习进行预测。 + """ + + def __init__(self, model_params: dict, label_name: str): + """初始化 TabPFN 任务 + + Args: + model_params: TabPFN 参数字典 + label_name: Label 列名称 + """ + # 不调用父类 __init__,直接初始化以避免创建 LightGBMModel + from src.training.tasks.base import BaseTask + + BaseTask.__init__(self, model_params, label_name) + self.evals_result: dict | None = None + self.model = TabPFNModel(params=model_params) + + def fit(self, train_data: dict, val_data: dict) -> None: + """训练 TabPFN 模型 + + TabPFN 通过将训练数据加载到模型上下文中进行"训练", + 不需要传统的梯度下降优化过程。 + + Args: + train_data: 训练数据 {"X": DataFrame, "y": Series} + val_data: 验证数据,用于评估但不参与训练 + """ + X_train = train_data["X"] + y_train = train_data["y"] + X_val = val_data.get("X") + y_val = val_data.get("y") + + # TabPFN 使用 eval_set 进行验证 + self.model.fit( + X_train, y_train, eval_set=(X_val, y_val) if X_val is not None else None + ) + + def get_model(self) -> TabPFNModel: + """获取训练好的模型实例""" + return self.model + + +# %% md +# ## 4. 
主函数 +# %% +def main(): + """主函数""" + print("\n" + "=" * 80) + print("TabPFN 回归模型训练") + print("=" * 80) + print("\n[说明] TabPFN 使用上下文学习(In-Context Learning),") + print(" 训练过程实际是加载数据到模型上下文。") + print(" 如果训练数据超过上下文限制,会自动截取最近的数据。") + + # 1. 创建 FactorEngine + print("\n[1] 创建 FactorEngine") + engine = FactorEngine() + + # 2. 创建 FactorManager + print("\n[2] 创建 FactorManager") + factor_manager = FactorManager( + selected_factors=SELECTED_FACTORS, + factor_definitions=FACTOR_DEFINITIONS, + label_factor=LABEL_FACTOR, + excluded_factors=EXCLUDED_FACTORS, + ) + + # 3. 创建 DataPipeline + print("\n[3] 创建 DataPipeline") + pipeline = DataPipeline( + factor_manager=factor_manager, + processor_configs=[ + (NullFiller, {"strategy": "mean"}), + (Winsorizer, {"lower": 0.01, "upper": 0.99}), + (StandardScaler, {}), + # (CrossSectionalStandardScaler, {}), + ], + label_processor_configs=[ + # 对 label 进行缩尾处理(去除极端收益率) + (Winsorizer, {"lower": 0.05, "upper": 0.95}), + # (StandardScaler, {}), + ], + filters=[STFilter(data_router=engine.router)], + stock_pool_filter_func=stock_pool_filter, + stock_pool_required_columns=STOCK_FILTER_REQUIRED_COLUMNS, + ) + + # 4. 创建 TabPFNTask + print("\n[4] 创建 TabPFNTask") + task = TabPFNTask( + model_params=MODEL_PARAMS, + label_name=LABEL_NAME, + ) + + # 5. 创建 Trainer + print("\n[5] 创建 Trainer") + trainer = Trainer( + data_pipeline=pipeline, + task=task, + output_config=output_config, + verbose=True, + ) + + # 6. 执行训练 + print("\n[6] 执行训练") + results = trainer.run(engine=engine, date_range=date_range) + + # 7. 保存模型和因子信息(如果启用) + if SAVE_MODEL: + print("\n[7] 保存模型和因子信息") + save_model_with_factors( + model=task.get_model(), + model_path=output_config["model_save_path"], + selected_factors=SELECTED_FACTORS, + factor_definitions=FACTOR_DEFINITIONS, + fitted_processors=pipeline.get_fitted_processors(), + ) + + # 8. 
输出 TabPFN 特有指标 + print("\n" + "=" * 80) + print("TabPFN 训练完成!") + print(f"结果保存路径: {os.path.join(OUTPUT_DIR, 'tabpfn_output.csv')}") + + # 显示验证集评估结果(如果可用) + model = task.get_model() + best_score = model.get_best_score() + if best_score: + print("\n[验证集评估指标]") + for metric, value in best_score.get("valid_0", {}).items(): + print(f" - {metric}: {value:.6f}") + + print("=" * 80) + + return results + + +if __name__ == "__main__": + main() diff --git a/src/scripts/check_gpu.py b/src/scripts/check_gpu.py new file mode 100644 index 0000000..04b7ca2 --- /dev/null +++ b/src/scripts/check_gpu.py @@ -0,0 +1,15 @@ +import torch + +# 查看 PyTorch 版本(关键!) +print(f"PyTorch 版本: {torch.__version__}") +# CPU 版本会显示: 2.6.0+cpu +# GPU 版本会显示: 2.6.0+cu118 / 2.6.0+cu121 / 2.6.0+cu124 等 + +# 检查 CUDA 是否可用 +print(f"CUDA 可用: {torch.cuda.is_available()}") + +# 如果有 CUDA,查看版本 +if torch.cuda.is_available(): + print(f"CUDA 版本: {torch.version.cuda}") + print(f"GPU 数量: {torch.cuda.device_count()}") + print(f"GPU 名称: {torch.cuda.get_device_name(0)}") \ No newline at end of file diff --git a/src/training/__init__.py b/src/training/__init__.py index 0d04219..948ba0b 100644 --- a/src/training/__init__.py +++ b/src/training/__init__.py @@ -29,7 +29,7 @@ from src.training.components.processors import ( ) # 模型 -from src.training.components.models import LightGBMModel +from src.training.components.models import LightGBMModel, TabMModel # 数据过滤器 from src.training.components.filters import BaseFilter, STFilter @@ -50,7 +50,7 @@ from src.training.config import TrainingConfig from src.training.factor_manager import FactorManager from src.training.pipeline import DataPipeline from src.training.result_analyzer import ResultAnalyzer -from src.training.tasks import BaseTask, RegressionTask, RankTask +from src.training.tasks import BaseTask, RegressionTask, RankTask, TabMRegressionTask # 从 trainer_v2 导入新 Trainer(推荐) from src.training.core.trainer_v2 import Trainer as TrainerV2 @@ -79,6 +79,7 @@ __all__ = [ "STFilter", # 
模型 "LightGBMModel", + "TabMModel", # 训练核心(旧版,已废弃) "StockPoolManager", "Trainer", @@ -93,5 +94,6 @@ __all__ = [ "BaseTask", "RegressionTask", "RankTask", + "TabMRegressionTask", "TrainerV2", # 新的 Trainer(推荐) ] diff --git a/src/training/components/models/__init__.py b/src/training/components/models/__init__.py index 921a4b0..648f4ec 100644 --- a/src/training/components/models/__init__.py +++ b/src/training/components/models/__init__.py @@ -5,5 +5,7 @@ from src.training.components.models.lightgbm import LightGBMModel from src.training.components.models.lightgbm_lambdarank import LightGBMLambdaRankModel +from src.training.components.models.tabpfn_model import TabPFNModel +from src.training.components.models.tabm_model import TabMModel -__all__ = ["LightGBMModel", "LightGBMLambdaRankModel"] +__all__ = ["LightGBMModel", "LightGBMLambdaRankModel", "TabPFNModel", "TabMModel"] diff --git a/src/training/components/models/tabm_model.py b/src/training/components/models/tabm_model.py new file mode 100644 index 0000000..d0ee247 --- /dev/null +++ b/src/training/components/models/tabm_model.py @@ -0,0 +1,368 @@ +"""TabM模型实现 + +TabM (Tabular Multilayer Perceptron with Ensembles) +基于 rtdl_revisiting_models 的 TabM 模型,支持内置集成。 +""" + +from typing import Dict, Any, List, Optional +from pathlib import Path +import pickle + +import numpy as np +import polars as pl +import torch +import torch.nn as nn +import torch.optim as optim +from torch.utils.data import DataLoader, TensorDataset +from tabm import TabM + +from src.training.components.base import BaseModel +from src.training.registry import register_model + + +@register_model("tabm") +class TabMModel(BaseModel): + """TabM回归模型 + + 特点: + - 使用MLP架构 + - 内置集成机制(ensemble_size),显存开销远小于独立模型 + - 训练时所有集成成员独立优化,保持多样性 + - 预测时取集成成员均值获得稳定结果 + + Attributes: + name: 模型名称标识 + params: 模型参数字典 + model: TabM模型实例 + device: 计算设备(cuda/cpu) + training_history_: 训练历史记录 + feature_names_: 特征名称列表 + """ + + name = "tabm" + + def __init__(self, params: Dict[str, 
Any]): + """初始化TabM模型 + + Args: + params: 模型参数字典,包含: + - n_blocks: MLP层数 (默认: 3) + - d_block: 每层神经元数 (默认: 256) + - dropout: Dropout率 (默认: 0.1) + - ensemble_size: 集成大小 (默认: 32) + - batch_size: 批次大小 (默认: 1024) + - learning_rate: 学习率 (默认: 1e-3) + - weight_decay: 权重衰减 (默认: 1e-5) + - epochs: 训练轮数 (默认: 50) + - early_stopping_patience: 早停耐心值 (默认: 10) + """ + self.params = params + self.model = None + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.training_history_: Dict[str, List[float]] = { + "train_loss": [], + "val_loss": [], + } + self.feature_names_: Optional[List[str]] = None + + # 损失函数 + self.criterion = nn.MSELoss() + + def _make_loader( + self, X: np.ndarray, y: Optional[np.ndarray] = None, shuffle: bool = False + ) -> DataLoader: + """创建DataLoader + + Args: + X: 特征数组 [N, n_features] + y: 标签数组 [N] 或 None + shuffle: 是否打乱数据 + + Returns: + DataLoader实例 + """ + if y is not None: + dataset = TensorDataset(torch.from_numpy(X), torch.from_numpy(y)) + else: + dataset = TensorDataset(torch.from_numpy(X)) + + batch_size = self.params.get("batch_size", 1024) + return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle) + + def _validate(self, val_loader: DataLoader) -> float: + """验证模型 + + Args: + val_loader: 验证数据加载器 + + Returns: + 平均验证损失 + """ + self.model.eval() + total_loss = 0.0 + n_batches = 0 + + with torch.no_grad(): + for batch in val_loader: + if len(batch) == 2: + bx, by = batch + bx, by = bx.to(self.device), by.to(self.device) + else: + bx = batch[0].to(self.device) + by = None + + # 预测时取集成成员均值 + outputs = self.model(bx) # [B, E, 1] + preds = outputs.mean(dim=1).squeeze(-1) # [B] + + if by is not None: + loss = self.criterion(preds, by).item() + total_loss += loss + n_batches += 1 + + return total_loss / max(n_batches, 1) + + def fit( + self, X: pl.DataFrame, y: pl.Series, eval_set: Optional[tuple] = None + ) -> "TabMModel": + """训练TabM模型 + + 训练策略: + 1. 对所有集成成员独立计算Loss,保持多样性 + 2. 
验证和预测时取ensemble成员均值 + + Args: + X: 训练特征DataFrame + y: 训练标签Series + eval_set: 验证集元组 (X_val, y_val),可选 + + Returns: + self + """ + # 保存特征名称 + self.feature_names_ = X.columns + + # 【关键】数据类型强制转换为float32 + # PyTorch对float64支持较差,避免使用Polars/Numpy默认类型 + X_np = X.to_numpy().astype(np.float32) + y_np = y.to_numpy().astype(np.float32) + + # 创建DataLoader + train_loader = self._make_loader(X_np, y_np, shuffle=True) + val_loader = None + if eval_set is not None: + X_val, y_val = eval_set + X_val_np = X_val.to_numpy().astype(np.float32) + y_val_np = y_val.to_numpy().astype(np.float32) + val_loader = self._make_loader(X_val_np, y_val_np, shuffle=False) + + n_features = X_np.shape[1] + ensemble_size = self.params.get("ensemble_size", 32) + + # 初始化TabM模型,使用TabM.make()自动填充默认参数 + self.model = TabM.make( + n_num_features=n_features, + cat_cardinalities=[], + d_out=1, # 回归任务输出维度为1 + n_blocks=self.params.get("n_blocks", 3), + d_block=self.params.get("d_block", 256), + dropout=self.params.get("dropout", 0.1), + k=ensemble_size, # 集成大小 + ).to(self.device) + + # 优化器 + optimizer = optim.AdamW( + self.model.parameters(), + lr=self.params.get("learning_rate", 1e-3), + weight_decay=self.params.get("weight_decay", 1e-5), + ) + + # 训练参数 + epochs = self.params.get("epochs", 50) + early_stopping_patience = self.params.get("early_stopping_patience", 10) + + # 训练循环 + best_val_loss = float("inf") + patience_counter = 0 + + print(f"[TabM] 开始训练... 
设备: {self.device}, 集成大小: {ensemble_size}") + + for epoch in range(epochs): + # 训练阶段 + self.model.train() + train_loss = 0.0 + n_train_batches = 0 + + for bx, by in train_loader: + bx, by = bx.to(self.device), by.to(self.device) + + optimizer.zero_grad() + + # 前向传播 + # outputs形状: [Batch, Ensemble, 1] + outputs = self.model(bx) + outputs_squeezed = outputs.squeeze(-1) # [B, E] + + # 【关键】针对所有集成成员计算Loss + # 不先取均值,让每个集成成员独立收敛,保持集成多样性 + by_expanded = by.unsqueeze(-1).expand(-1, ensemble_size) # [B, E] + loss = self.criterion(outputs_squeezed, by_expanded) + + loss.backward() + optimizer.step() + + train_loss += loss.item() + n_train_batches += 1 + + avg_train_loss = train_loss / max(n_train_batches, 1) + self.training_history_["train_loss"].append(avg_train_loss) + + # 验证阶段 + if val_loader is not None: + val_loss = self._validate(val_loader) + self.training_history_["val_loss"].append(val_loss) + + # 早停逻辑 + if val_loss < best_val_loss: + best_val_loss = val_loss + patience_counter = 0 + else: + patience_counter += 1 + + if (epoch + 1) % 5 == 0 or epoch == 0: + print( + f"[TabM] Epoch {epoch + 1}/{epochs} | " + f"Train Loss: {avg_train_loss:.6f} | " + f"Val Loss: {val_loss:.6f}" + ) + + if patience_counter >= early_stopping_patience: + print(f"[TabM] 早停触发,epoch {epoch + 1}") + break + else: + if (epoch + 1) % 5 == 0 or epoch == 0: + print( + f"[TabM] Epoch {epoch + 1}/{epochs} | " + f"Train Loss: {avg_train_loss:.6f}" + ) + + print(f"[TabM] 训练完成") + return self + + def predict(self, X: pl.DataFrame) -> np.ndarray: + """生成预测 + + 预测时对ensemble_size个成员取均值,获得稳定结果。 + + Args: + X: 特征DataFrame + + Returns: + 预测结果数组 [N] + """ + if self.model is None: + raise RuntimeError("模型未训练,请先调用fit()") + + # 数据类型转换 + X_np = X.to_numpy().astype(np.float32) + loader = self._make_loader(X_np, shuffle=False) + + self.model.eval() + all_preds = [] + + with torch.no_grad(): + for batch in loader: + bx = batch[0].to(self.device) + # 预测时取集成成员均值 + outputs = self.model(bx) # [B, E, 1] + preds = 
outputs.mean(dim=1).squeeze(-1) # [B] + all_preds.append(preds.cpu().numpy()) + + return np.concatenate(all_preds) + + def feature_importance(self) -> Optional[pl.Series]: + """获取特征重要性 + + TabM没有内置特征重要性计算,返回None。 + + Returns: + None + """ + return None + + def save(self, path: str | Path) -> None: + """保存模型 + + 保存模型state_dict和元数据(params, feature_names, training_history) + + Args: + path: 保存路径 + """ + if self.model is None: + raise RuntimeError("模型未训练,无法保存") + + path = Path(path) + path.parent.mkdir(parents=True, exist_ok=True) + + # 保存模型权重 + model_path = path.with_suffix(".pt") + torch.save(self.model.state_dict(), model_path) + + # 保存元数据 + meta_path = path.with_suffix(".meta") + meta = { + "params": self.params, + "feature_names": self.feature_names_, + "training_history": self.training_history_, + "device": str(self.device), + } + with open(meta_path, "wb") as f: + pickle.dump(meta, f) + + print(f"[TabM] 模型保存到: {path}") + + @classmethod + def load(cls, path: str) -> "TabMModel": + """加载模型 + + Args: + path: 模型路径(不含扩展名) + + Returns: + 加载的TabMModel实例 + """ + path = Path(path) + + # 加载元数据 + meta_path = path.with_suffix(".meta") + with open(meta_path, "rb") as f: + meta = pickle.load(f) + + # 创建实例 + instance = cls(meta["params"]) + instance.feature_names_ = meta["feature_names"] + instance.training_history_ = meta["training_history"] + + # 重建模型结构 + if instance.feature_names_ is not None: + n_features = len(instance.feature_names_) + ensemble_size = instance.params.get("ensemble_size", 32) + + instance.model = TabM.make( + n_num_features=n_features, + cat_cardinalities=[], + d_out=1, # 回归任务输出维度为1 + n_blocks=instance.params.get("n_blocks", 3), + d_block=instance.params.get("d_block", 256), + dropout=instance.params.get("dropout", 0.1), + k=ensemble_size, # 集成大小 + ).to(instance.device) + + # 加载权重 + model_path = path.with_suffix(".pt") + instance.model.load_state_dict( + torch.load(model_path, map_location=instance.device) + ) + + print(f"[TabM] 模型从 {path} 加载完成") + return 
instance diff --git a/src/training/components/models/tabpfn_model.py b/src/training/components/models/tabpfn_model.py new file mode 100644 index 0000000..3415a47 --- /dev/null +++ b/src/training/components/models/tabpfn_model.py @@ -0,0 +1,296 @@ +"""TabPFN 模型实现
+
+基于 TabPFN (Prior-Data Fitted Network) 的回归模型实现。
+TabPFN 利用预训练的 Transformer 网络,通过上下文学习(in-context learning)
+进行快速的小样本/中样本回归预测,无需传统训练过程。
+"""
+
+import json
+import os
+from pathlib import Path
+from typing import Any, Optional
+
+import numpy as np
+import pandas as pd
+import polars as pl
+from scipy.stats import spearmanr
+
+from src.training.components.base import BaseModel
+from src.training.registry import register_model
+
+# SECURITY: a hardcoded Hugging Face access token was removed from this line — supply HF_TOKEN via the environment (e.g. `export HF_TOKEN=...`); never commit credentials.
+
+
+@register_model("tabpfn")
+class TabPFNModel(BaseModel):
+    """TabPFN 回归模型
+
+    使用 TabPFN 库实现基于 Prior-Data Fitted Network 的回归预测。
+    该模型通过上下文学习方式进行预测,无需传统梯度下降训练。
+    支持 GPU 加速和自动上下文截断处理。
+
+    Attributes:
+        name: 模型名称 "tabpfn"
+        params: TabPFN 参数字典
+        model: TabPFNRegressor 实例
+        feature_names_: 特征名称列表
+        evals_result_: 训练评估结果
+        best_score_: 最佳评估指标
+    """
+
+    name = "tabpfn"
+
+    # TabPFN 官方限制(最大样本数),可通过 ignore_pretraining_limits=True 扩展
+    MAX_CONTEXT_SIZE = 10000
+
+    def __init__(self, params: Optional[dict] = None):
+        """初始化 TabPFN 模型
+
+        Args:
+            params: TabPFN 参数字典,支持以下参数:
+                - device: 计算设备,'cuda' 或 'cpu'(默认 'cuda')
+                - model_path: 本地模型权重文件路径(可选)
+                - N_ensemble: 集成数量,用于降低预测方差(默认 1;NOTE(review): fit() 目前固定 n_estimators=1,此参数尚未生效 — 待接线或删除)
+                - max_context_size: 最大上下文样本数(默认 10000)
+
+        Examples:
+            >>> model = TabPFNModel(params={
+            ...     "device": "cuda",
+            ...     "N_ensemble": 5,
+            ...
}) + """ + self.params = dict(params) if params is not None else {} + self.model = None + self.feature_names_: Optional[list] = None + self.evals_result_: Optional[dict] = None + self.best_score_: Optional[dict] = None + + def fit( + self, + X: pl.DataFrame, + y: pl.Series, + eval_set: Optional[tuple] = None, + ) -> "TabPFNModel": + """训练/加载 TabPFN 模型 + + TabPFN 采用上下文学习,"fit" 操作实际上是加载训练数据到模型上下文。 + 如果训练数据超过上下文限制,会自动截取最近的数据。 + + Args: + X: 特征矩阵 (Polars DataFrame) + y: 目标变量 (Polars Series) + eval_set: 验证集元组 (X_val, y_val),用于评估模型性能 + + Returns: + self (支持链式调用) + + Raises: + ImportError: 未安装 tabpfn + RuntimeError: 模型初始化或加载失败 + """ + from tabpfn import TabPFNRegressor + + self.feature_names_ = X.columns + + # 转换为 numpy 数组 + X_np = X.to_numpy() + y_np = y.to_numpy() + + # 处理上下文大小限制 + max_context = self.params.get("max_context_size", self.MAX_CONTEXT_SIZE) + if len(X_np) > max_context: + print( + f"[TabPFN] 训练数据 {len(X_np)} 超过上下文限制 {max_context},截取最近数据" + ) + X_np = X_np[-max_context:] + y_np = y_np[-max_context:] + + # 初始化模型 + # TabPFNRegressor 需要设置 ignore_pretraining_limits=True 以支持超过 10,000 样本 + device = self.params.get("device", "cuda") + ignore_limits = self.params.get("ignore_pretraining_limits", True) + self.model = TabPFNRegressor( + device=device, + ignore_pretraining_limits=ignore_limits, + n_estimators=1 + ) + + # 加载上下文(TabPFN 的 "fit" 是加载上下文) + print("[TabPFN] 加载训练数据到上下文...") + self.model.fit(X_np, y_np) + + # 评估验证集 + if eval_set is not None: + X_val, y_val = eval_set + val_preds = self.predict(X_val) + y_val_np = y_val.to_numpy() + + # 计算评估指标 + mse = np.mean((y_val_np - val_preds) ** 2) + rank_ic, p_value = spearmanr(val_preds, y_val_np) + + self.evals_result_ = { + "valid_0": { + "mse": [mse], + "rank_ic": [rank_ic], + } + } + self.best_score_ = { + "valid_0": { + "mse": mse, + "rank_ic": rank_ic, + "rank_ic_pvalue": p_value, + } + } + + print(f"[TabPFN] 验证集 MSE: {mse:.6f}, Rank IC: {rank_ic:.4f}") + + return self + + def predict(self, X: pl.DataFrame) -> 
np.ndarray: + """预测 + + Args: + X: 特征矩阵 (Polars DataFrame) + + Returns: + 预测结果 (numpy ndarray) + + Raises: + RuntimeError: 模型未初始化时调用 + """ + if self.model is None: + raise RuntimeError("模型尚未初始化,请先调用 fit()") + + X_np = X.to_numpy() + result = self.model.predict(X_np) + return np.asarray(result) + + def predict_with_uncertainty( + self, X: pl.DataFrame + ) -> tuple[np.ndarray, np.ndarray]: + """预测并返回不确定性估计 + + 利用 N_ensemble 预测的标准差作为不确定性估计。 + + Args: + X: 特征矩阵 (Polars DataFrame) + + Returns: + (predictions, uncertainties) 元组,均为 numpy ndarray + """ + if self.model is None: + raise RuntimeError("模型尚未初始化,请先调用 fit()") + + X_np = X.to_numpy() + predictions = self.model.predict(X_np) + + # 如果使用了 ensemble,可以通过多次预测计算标准差 + # 注意:这需要修改 TabPFNRegressor 的使用方式 + # 这里返回预测值的零不确定性作为默认行为 + uncertainties = np.zeros_like(predictions) + + return np.asarray(predictions), np.asarray(uncertainties) + + def get_evals_result(self) -> Optional[dict]: + """获取训练评估结果 + + Returns: + 评估结果字典,如果未进行评估返回 None + """ + return self.evals_result_ + + def get_best_score(self) -> Optional[dict]: + """获取最佳评分 + + Returns: + 最佳评分字典,如果未进行评估返回 None + """ + return self.best_score_ + + def evaluate(self, X: pl.DataFrame, y: pl.Series) -> dict[str, float]: + """评估模型性能 + + 计算回归任务常用指标:MSE 和 Rank IC。 + + Args: + X: 特征矩阵 + y: 真实目标值 + + Returns: + 评估指标字典,包含 mse 和 rank_ic + """ + preds = self.predict(X) + y_np = y.to_numpy() + + mse = float(np.mean((y_np - preds) ** 2)) + rank_ic_result = spearmanr(preds, y_np) + rank_ic = float(rank_ic_result.correlation) + p_value = float(rank_ic_result.pvalue) + + return { + "mse": mse, + "rank_ic": rank_ic, + "rank_ic_pvalue": p_value, + } + + def feature_importance(self) -> None: + """TabPFN 不支持传统特征重要性 + + TabPFN 是基于 Transformer 的上下文学习模型, + 不提供类似决策树的特征重要性指标。 + + Returns: + None + """ + return None + + def save(self, path: str) -> None: + """保存模型元数据和配置 + + TabPFN 模型本身不支持序列化保存,因此只保存: + - 模型参数配置 + - 特征名称列表 + - 上下文数据摘要(样本数、特征数) + + 注意:实际使用时需要重新 fit 来加载上下文。 + + Args: + path: 保存路径 + """ + 
save_data = { + "model_type": self.name, + "params": self.params, + "feature_names": self.feature_names_, + "evals_result": self.evals_result_, + "best_score": self.best_score_, + } + + # 保存为 JSON + Path(path).parent.mkdir(parents=True, exist_ok=True) + with open(path, "w", encoding="utf-8") as f: + json.dump(save_data, f, indent=2, ensure_ascii=False) + + @classmethod + def load(cls, path: str) -> "TabPFNModel": + """加载模型配置 + + 注意:TabPFN 模型需要重新 fit 才能使用, + 此方法仅恢复模型参数配置。 + + Args: + path: 配置文件路径 + + Returns: + 配置恢复的 TabPFNModel 实例(未 fit) + """ + with open(path, "r", encoding="utf-8") as f: + save_data = json.load(f) + + instance = cls(params=save_data.get("params", {})) + instance.feature_names_ = save_data.get("feature_names") + instance.evals_result_ = save_data.get("evals_result") + instance.best_score_ = save_data.get("best_score") + + print(f"[TabPFN] 已加载模型配置,需要调用 fit() 重新加载上下文") + return instance diff --git a/src/training/components/processors/transforms.py b/src/training/components/processors/transforms.py index 501cf4b..5554c57 100644 --- a/src/training/components/processors/transforms.py +++ b/src/training/components/processors/transforms.py @@ -3,6 +3,7 @@ 包含标准化、缩尾、缺失值填充等数据处理器。 """ +import math from typing import List, Literal, Optional, Union import polars as pl @@ -88,7 +89,9 @@ class NullFiller(BaseProcessor): """ if not self.by_date and self.strategy in ("mean", "median"): for col in self.feature_cols: - if col in X.columns and X[col].dtype.is_numeric(): + if col in X.columns and ( + X[col].dtype.is_numeric() or X[col].dtype == pl.Boolean + ): if self.strategy == "mean": self.stats_[col] = X[col].mean() or 0.0 else: # median @@ -119,11 +122,14 @@ class NullFiller(BaseProcessor): raise ValueError(f"未知的填充策略: {self.strategy}") def _fill_with_zero(self, X: pl.DataFrame) -> pl.DataFrame: - """使用0填充缺失值""" + """使用0填充缺失值(同时处理 NaN 和 null)""" expressions = [] for col in X.columns: - if col in self.feature_cols and X[col].dtype.is_numeric(): - expr = 
pl.col(col).fill_null(0).alias(col) + if col in self.feature_cols and ( + X[col].dtype.is_numeric() or X[col].dtype == pl.Boolean + ): + # 先 fill_nan 再 fill_null,确保两种缺失值都被处理 + expr = pl.col(col).fill_nan(0).fill_null(0).alias(col) expressions.append(expr) else: expressions.append(pl.col(col)) @@ -131,11 +137,19 @@ class NullFiller(BaseProcessor): return X.select(expressions) def _fill_with_value(self, X: pl.DataFrame) -> pl.DataFrame: - """使用指定值填充缺失值""" + """使用指定值填充缺失值(同时处理 NaN 和 null)""" expressions = [] for col in X.columns: - if col in self.feature_cols and X[col].dtype.is_numeric(): - expr = pl.col(col).fill_null(self.fill_value).alias(col) + if col in self.feature_cols and ( + X[col].dtype.is_numeric() or X[col].dtype == pl.Boolean + ): + # 先 fill_nan 再 fill_null + expr = ( + pl.col(col) + .fill_nan(self.fill_value) + .fill_null(self.fill_value) + .alias(col) + ) expressions.append(expr) else: expressions.append(pl.col(col)) @@ -143,12 +157,13 @@ class NullFiller(BaseProcessor): return X.select(expressions) def _fill_global(self, X: pl.DataFrame) -> pl.DataFrame: - """使用全局统计量填充(训练集学到的统计量)""" + """使用全局统计量填充(训练集学到的统计量,同时处理 NaN 和 null)""" expressions = [] for col in X.columns: if col in self.stats_: fill_val = self.stats_[col] - expr = pl.col(col).fill_null(fill_val).alias(col) + # 先 fill_nan 再 fill_null + expr = pl.col(col).fill_nan(fill_val).fill_null(fill_val).alias(col) expressions.append(expr) else: expressions.append(pl.col(col)) @@ -156,8 +171,9 @@ class NullFiller(BaseProcessor): return X.select(expressions) def _fill_by_date(self, X: pl.DataFrame) -> pl.DataFrame: - """使用每天截面统计量填充""" - # 确定需要处理的数值列 + """使用每天截面统计量填充(同时处理 NaN 和 null)""" + # 确定需要处理的列(仅 numeric 类型,排除 boolean) + # 注意:boolean 类型没有 NaN 概念,fill_nan 会报错 target_cols = [ col for col in self.feature_cols @@ -180,10 +196,20 @@ class NullFiller(BaseProcessor): result = X.with_columns(stat_exprs) # 使用统计量填充缺失值 + # 注意:如果某天某列全为null,统计量也会为null,所以需要链式填充 + # 同时处理 NaN 和 null fill_exprs = [] for col in 
X.columns: if col in target_cols: - expr = pl.col(col).fill_null(pl.col(f"{col}_stat")).alias(col) + # 先用当天统计量填充 NaN 和 null,如果统计量也是null则用0填充 + expr = ( + pl.col(col) + .fill_nan(pl.col(f"{col}_stat")) + .fill_null(pl.col(f"{col}_stat")) + .fill_nan(0) # 如果统计量是 NaN,再用 0 填充 + .fill_null(0) # 如果统计量是 null,再用 0 填充 + .alias(col) + ) fill_exprs.append(expr) else: fill_exprs.append(pl.col(col)) @@ -230,17 +256,40 @@ class StandardScaler(BaseProcessor): self """ for col in self.feature_cols: + # 仅处理数值类型,排除布尔类型(标准化布尔类型语义不明确) if col in X.columns and X[col].dtype.is_numeric(): col_mean = X[col].mean() col_std = X[col].std() - if col_mean is not None and col_std is not None: + # 关键修复:检查是否为 None 且不是 NaN + # 注意:使用 try-except 处理类型转换,避免 LSP 类型检查错误 + try: + mean_is_valid = ( + col_mean is not None + and isinstance(col_mean, (int, float)) + and not math.isnan(col_mean) + ) + std_is_valid = ( + col_std is not None + and isinstance(col_std, (int, float)) + and not math.isnan(col_std) + ) + except (TypeError, ValueError): + mean_is_valid = False + std_is_valid = False + + if mean_is_valid and std_is_valid: self.mean_[col] = col_mean self.std_[col] = col_std + else: + # 如果统计量无效,使用默认值(mean=0, std=1) + # 防止 transform 时产生更多 NaN + self.mean_[col] = 0.0 + self.std_[col] = 1.0 return self def transform(self, X: pl.DataFrame) -> pl.DataFrame: - """标准化(使用训练集学到的参数) + """标准化(使用训练集学到的参数,增加 NaN 保护) Args: X: 待转换数据 @@ -253,7 +302,18 @@ class StandardScaler(BaseProcessor): if col in self.mean_ and col in self.std_: # 避免除以0 std_val = self.std_[col] if self.std_[col] != 0 else 1.0 - expr = ((pl.col(col) - self.mean_[col]) / std_val).alias(col) + # 关键修复:添加 fill_nan(0) 保险,防止计算产生 NaN + expr = ( + ((pl.col(col) - self.mean_[col]) / std_val) + .fill_nan(0) + .fill_null(0) + .alias(col) + ) + expressions.append(expr) + elif col in self.feature_cols: + # 对于应该被处理但未学习到统计量的列 + # 统一转换为float并同时处理 NaN 和 null + expr = pl.col(col).cast(pl.Float64).fill_nan(0).fill_null(0).alias(col) expressions.append(expr) else: 
expressions.append(pl.col(col)) @@ -308,13 +368,24 @@ class CrossSectionalStandardScaler(BaseProcessor): # 构建表达式列表 expressions = [] for col in X.columns: + # 仅处理数值类型,排除布尔类型(标准化布尔类型语义不明确) if col in self.feature_cols and X[col].dtype.is_numeric(): # 截面标准化:每天独立计算均值和标准差 # 避免除以0,当std为0时设为1 + # 关键修复:先 fill_nan 再 fill_null,防止计算产生的 NaN expr = ( - (pl.col(col) - pl.col(col).mean().over(self.date_col)) - / (pl.col(col).std().over(self.date_col) + 1e-10) - ).alias(col) + ( + (pl.col(col) - pl.col(col).mean().over(self.date_col)) + / (pl.col(col).std().over(self.date_col) + 1e-10) + ) + .fill_nan(0) + .fill_null(0) + .alias(col) + ) + expressions.append(expr) + elif col in self.feature_cols: + # 对于应该被处理但类型不匹配的列,转换为float并同时处理 NaN 和 null + expr = pl.col(col).cast(pl.Float64).fill_nan(0).fill_null(0).alias(col) expressions.append(expr) else: expressions.append(pl.col(col)) @@ -384,6 +455,7 @@ class Winsorizer(BaseProcessor): """ if not self.by_date: for col in self.feature_cols: + # 仅处理数值类型,排除布尔类型(quantile 不支持布尔类型) if col in X.columns and X[col].dtype.is_numeric(): self.bounds_[col] = { "lower": X[col].quantile(self.lower), @@ -414,13 +486,19 @@ class Winsorizer(BaseProcessor): upper = self.bounds_[col]["upper"] expr = pl.col(col).clip(lower, upper).alias(col) expressions.append(expr) + elif col in self.feature_cols: + # 对于应该被处理但未学习到边界的列(如全为NaN、布尔列等) + # 统一转换为float并填充0 + expr = pl.col(col).cast(pl.Float64).fill_null(0).alias(col) + expressions.append(expr) else: expressions.append(pl.col(col)) return X.select(expressions) def _transform_by_date(self, X: pl.DataFrame) -> pl.DataFrame: """每日独立缩尾""" - # 确定需要处理的数值列 + # 确定需要处理的列(仅 numeric 类型,排除 boolean) + # 注意:quantile 操作不支持布尔类型 target_cols = [ col for col in self.feature_cols @@ -444,9 +522,11 @@ class Winsorizer(BaseProcessor): clip_exprs = [] for col in X.columns: if col in target_cols: + # 先用当天分位数缩尾,如果分位数是null(该日全为NaN)则填充0 clipped = ( pl.col(col) .clip(pl.col(f"{col}_lower"), pl.col(f"{col}_upper")) + .fill_null(0) .alias(col) ) 
clip_exprs.append(clipped) diff --git a/src/training/core/trainer_v2.py b/src/training/core/trainer_v2.py index 63679ac..2aa5d2d 100644 --- a/src/training/core/trainer_v2.py +++ b/src/training/core/trainer_v2.py @@ -95,6 +95,27 @@ class Trainer: verbose=self.verbose, ) + # Step 1.5: 数据质量分析 + if self.verbose: + print("\n[Step 1.5/7] 数据质量分析...") + + try: + from src.experiment.data_quality_analyzer import DataQualityAnalyzer + + # 获取特征列名(从训练集) + feature_cols = data["train"].get("feature_cols", []) + label_name = self.task.label_name + + analyzer = DataQualityAnalyzer( + feature_cols=feature_cols, + label_col=label_name, + verbose=self.verbose, + ) + analyzer.analyze(data) + except Exception as e: + if self.verbose: + print(f" [警告] 数据质量分析失败: {e}") + # Step 2: 处理标签 if self.verbose: print("\n[Step 2/7] 处理标签...") diff --git a/src/training/pipeline.py b/src/training/pipeline.py index c376987..cb23715 100644 --- a/src/training/pipeline.py +++ b/src/training/pipeline.py @@ -43,6 +43,7 @@ class DataPipeline: label_processor_configs: Optional[ List[Tuple[Type[BaseProcessor], Dict[str, Any]]] ] = None, + train_skip_days: int = 252, ): """初始化数据流水线 @@ -55,6 +56,7 @@ class DataPipeline: stock_pool_required_columns: 股票池筛选所需的额外列 label_processor_configs: Label 数据处理器配置列表,格式与 processor_configs 相同 例如:[(Winsorizer, {"lower": 0.01, "upper": 0.99})] 用于对 label 进行缩尾处理 + train_skip_days: 训练数据跳过前n天,用于避免训练初期数据不足的问题,默认252天 """ self.factor_manager = factor_manager self.processor_configs = processor_configs or [] @@ -64,6 +66,7 @@ class DataPipeline: self.fitted_processors: List[BaseProcessor] = [] self.label_processor_configs = label_processor_configs or [] self.fitted_label_processors: List[BaseProcessor] = [] + self.train_skip_days = train_skip_days def prepare_data( self, @@ -220,6 +223,8 @@ class DataPipeline: ) -> Dict[str, Dict[str, Any]]: """划分数据集 + 对于训练集,会根据 train_skip_days 参数跳过前n个交易日的数据。 + Args: data: 完整数据 date_range: 日期范围字典 @@ -236,6 +241,33 @@ class DataPipeline: mask = 
(data["trade_date"] >= start) & (data["trade_date"] <= end) split_df = data.filter(mask) + # 对训练集跳过前n天数据 + if split_name == "train" and self.train_skip_days > 0: + original_count = len(split_df) + # 获取唯一的交易日列表并按日期排序 + unique_dates = split_df["trade_date"].unique().sort() + if len(unique_dates) > self.train_skip_days: + # 跳过前n个交易日 + start_date = unique_dates[self.train_skip_days] + split_df = split_df.filter(pl.col("trade_date") >= start_date) + skipped_count = original_count - len(split_df) + if verbose: + print( + f" {split_name}: {len(split_df)} 条记录" + f" (跳过前{self.train_skip_days}天,减少{skipped_count}条)" + ) + else: + if verbose: + print( + f" [警告] 训练数据交易日数量({len(unique_dates)})" + f"少于跳过天数({self.train_skip_days}),未进行过滤" + ) + if verbose: + print(f" {split_name}: {len(split_df)} 条记录") + else: + if verbose: + print(f" {split_name}: {len(split_df)} 条记录") + result[split_name] = { "X": split_df.select(feature_cols), "y": split_df[label_name], @@ -243,9 +275,6 @@ class DataPipeline: "feature_cols": feature_cols, } - if verbose: - print(f" {split_name}: {len(split_df)} 条记录") - return result def _preprocess( @@ -345,6 +374,30 @@ class DataPipeline: split_data[split_name]["X"] = split_df.select(feature_cols) split_data[split_name]["y"] = split_df[label_name] + # 删除标签为 NaN 的行 + for split_name in ["train", "val", "test"]: + if split_name in split_data: + y_series = split_data[split_name]["y"] + y_nan_count = y_series.null_count() + + if y_nan_count > 0: + if verbose: + print(f" 删除 {split_name} 集中 {y_nan_count} 个标签为NaN的行") + + # 创建有效标签的mask + valid_mask = y_series.is_not_null() + + # 过滤所有相关数据 + split_data[split_name]["raw_data"] = split_data[split_name][ + "raw_data" + ].filter(valid_mask) + split_data[split_name]["X"] = split_data[split_name]["X"].filter( + valid_mask + ) + split_data[split_name]["y"] = split_data[split_name]["y"].filter( + valid_mask + ) + return split_data def get_fitted_processors(self) -> List[BaseProcessor]: diff --git a/src/training/tasks/__init__.py 
b/src/training/tasks/__init__.py index 09b7442..16f3d5c 100644 --- a/src/training/tasks/__init__.py +++ b/src/training/tasks/__init__.py @@ -6,9 +6,11 @@ from src.training.tasks.base import BaseTask from src.training.tasks.regression_task import RegressionTask from src.training.tasks.rank_task import RankTask +from src.training.tasks.tabm_regression_task import TabMRegressionTask __all__ = [ "BaseTask", "RegressionTask", "RankTask", + "TabMRegressionTask", ] diff --git a/src/training/tasks/tabm_regression_task.py b/src/training/tasks/tabm_regression_task.py new file mode 100644 index 0000000..63f42dc --- /dev/null +++ b/src/training/tasks/tabm_regression_task.py @@ -0,0 +1,165 @@ +"""TabM回归任务 + +TabM模型的回归训练任务实现。 +""" + +from typing import Dict, Any, Optional +from pathlib import Path + +import numpy as np +import polars as pl +import matplotlib.pyplot as plt + +from src.training.tasks.base import BaseTask +from src.training.components.models.tabm_model import TabMModel + +# Type alias for model type +TabMModelType = TabMModel + + +class TabMRegressionTask(BaseTask): + """TabM回归任务 + + 使用TabM模型进行回归训练,支持: + - 内置集成训练(ensemble_size) + - 早停机制 + - 训练曲线绘制 + + Attributes: + model_params: 模型参数字典 + label_name: 目标列名称 + model: TabMModel实例 + """ + + def __init__( + self, + model_params: Dict[str, Any], + label_name: str = "future_return_5", + ): + """初始化TabM回归任务 + + Args: + model_params: TabM模型参数,包含: + - n_blocks: MLP层数 + - d_block: 每层神经元数 + - dropout: Dropout率 + - ensemble_size: 集成大小 + - batch_size: 批次大小 + - learning_rate: 学习率 + - weight_decay: 权重衰减 + - epochs: 训练轮数 + - early_stopping_patience: 早停耐心值 + label_name: 目标列名称 + """ + super().__init__(model_params, label_name) + self.model_params = model_params + self.label_name = label_name + self.model = None # type: Optional[TabMModelType] + + def prepare_labels(self, data: Dict[str, Dict]) -> Dict[str, Dict]: + """准备标签 + + 回归任务不需要转换标签,直接返回原始数据。 + + Args: + data: 数据字典 + + Returns: + 未修改的数据字典 + """ + # 回归任务:标签已经是连续值,无需转换 + return 
data + + def fit( + self, + train_data: Dict[str, Any], + val_data: Dict[str, Any], + ) -> None: + """训练TabM模型 + + Args: + train_data: 训练数据字典,包含: + - X: 特征DataFrame + - y: 标签Series + val_data: 验证数据字典,包含: + - X: 特征DataFrame + - y: 标签Series + """ + print("\n[TabMRegressionTask] 开始训练...") + + # 创建模型实例 + self.model = TabMModel(self.model_params) + + # 提取训练数据 + X_train = train_data["X"] + y_train = train_data["y"] + X_val = val_data["X"] + y_val = val_data["y"] + + # 训练模型 + self.model.fit(X=X_train, y=y_train, eval_set=(X_val, y_val)) + + print("[TabMRegressionTask] 训练完成") + + def predict(self, test_data: Dict[str, Any]) -> np.ndarray: + """生成预测 + + Args: + test_data: 测试数据字典,包含: + - X: 特征DataFrame + + Returns: + 预测结果数组 + """ + if self.model is None: + raise RuntimeError("模型未训练,请先调用fit()") + + X_test = test_data["X"] + return self.model.predict(X_test) + + def get_model(self) -> Any: + """获取训练好的模型 + + Returns: + TabMModel实例或None + """ + return self.model + + def plot_training_metrics(self, output_path: Optional[str] = None) -> None: + """绘制训练指标 + + 绘制训练和验证损失曲线。 + + Args: + output_path: 图表保存路径,None则显示图表 + """ + if self.model is None or not self.model.training_history_["train_loss"]: + print("[TabMRegressionTask] 无训练历史可绘制") + return + + history = self.model.training_history_ + + fig, ax = plt.subplots(figsize=(10, 6)) + + epochs = range(1, len(history["train_loss"]) + 1) + ax.plot(epochs, history["train_loss"], "b-", label="Train Loss", linewidth=2) + + if history["val_loss"]: + ax.plot(epochs, history["val_loss"], "r-", label="Val Loss", linewidth=2) + + ax.set_xlabel("Epoch", fontsize=12) + ax.set_ylabel("Loss (MSE)", fontsize=12) + ax.set_title("TabM Training History", fontsize=14) + ax.legend(fontsize=10) + ax.grid(True, alpha=0.3) + + plt.tight_layout() + + if output_path: + Path(output_path).parent.mkdir(parents=True, exist_ok=True) + plt.savefig(output_path, dpi=150, bbox_inches="tight") + print(f"[TabMRegressionTask] 训练曲线保存到: {output_path}") + else: + plt.show() + 
+ plt.close() diff --git a/tests/check_gtja.py b/tests/check_gtja.py new file mode 100644 index 0000000..b4aeef6 --- /dev/null +++ b/tests/check_gtja.py @@ -0,0 +1,85 @@ +"""检查 GTJA_alpha 因子""" + +import polars as pl + +from src.factors import FactorEngine +from src.training import FactorManager +from src.experiment.common import ( + SELECTED_FACTORS, + FACTOR_DEFINITIONS, + LABEL_FACTOR, +) + +EXCLUDED_FACTORS = [ + "GTJA_alpha001", + "GTJA_alpha002", + "GTJA_alpha003", + "GTJA_alpha004", + "GTJA_alpha005", + "GTJA_alpha006", + "GTJA_alpha007", + "GTJA_alpha008", + "GTJA_alpha009", + "GTJA_alpha010", + "GTJA_alpha011", + "GTJA_alpha012", + "GTJA_alpha013", + "GTJA_alpha014", + "GTJA_alpha015", +] + + +def main(): + print("=" * 80) + print("检查 GTJA_alpha 因子") + print("=" * 80) + + engine = FactorEngine() + factor_manager = FactorManager( + selected_factors=SELECTED_FACTORS, + factor_definitions=FACTOR_DEFINITIONS, + label_factor=LABEL_FACTOR, + excluded_factors=EXCLUDED_FACTORS, + ) + + # 注册因子 + feature_cols = factor_manager.register_to_engine(engine, verbose=False) + + # 找出 GTJA_alpha 因子 + gtja_factors = [f for f in feature_cols if f.startswith("GTJA_alpha")] + print(f"\nGTJA_alpha 因子数量: {len(gtja_factors)}") + print(f"前10个: {gtja_factors[:10]}") + + # 计算一个小的日期范围 + print("\n计算因子数据...") + data = engine.compute( + factor_names=gtja_factors[:10] + ["close"], # 只计算前10个 GTJA_alpha + close + start_date="20200101", + end_date="20200110", + ) + + print(f"\n数据形状: {data.shape}") + print(f"列: {data.columns}") + + # 检查每个 GTJA_alpha 因子的 NaN 情况 + print("\nGTJA_alpha 因子 NaN 统计:") + for col in gtja_factors[:10]: + if col in data.columns: + nan_count = data[col].null_count() + total = len(data) + print(f" {col}: {nan_count}/{total} ({nan_count / total * 100:.1f}%)") + else: + print(f" {col}: 列不存在!") + + # 检查 close 列作为对比 + print( + f"\n close: {data['close'].null_count()}/{len(data)} ({data['close'].null_count() / len(data) * 100:.1f}%)" + ) + + # 查看实际数据 + print("\n实际数据预览:") + 
print(data.select(["trade_date", "ts_code"] + gtja_factors[:3]).head(10)) + + +if __name__ == "__main__": + main() diff --git a/tests/diagnose_nan.py b/tests/diagnose_nan.py new file mode 100644 index 0000000..c44f92d --- /dev/null +++ b/tests/diagnose_nan.py @@ -0,0 +1,208 @@ +"""诊断 NaN 来源""" + +import numpy as np +import polars as pl + +from src.factors import FactorEngine +from src.training import ( + FactorManager, + DataPipeline, + NullFiller, + Winsorizer, + StandardScaler, +) +from src.training.components.filters import STFilter +from src.experiment.common import ( + SELECTED_FACTORS, + FACTOR_DEFINITIONS, + LABEL_NAME, + LABEL_FACTOR, + stock_pool_filter, + STOCK_FILTER_REQUIRED_COLUMNS, +) + +# 只使用少量因子加速测试 +EXCLUDED_FACTORS = [ + "GTJA_alpha001", + "GTJA_alpha002", + "GTJA_alpha003", + "GTJA_alpha004", + "GTJA_alpha005", + "GTJA_alpha006", + "GTJA_alpha007", + "GTJA_alpha008", + "GTJA_alpha009", + "GTJA_alpha010", + "GTJA_alpha011", + "GTJA_alpha012", + "GTJA_alpha013", + "GTJA_alpha014", + "GTJA_alpha015", +] + +TEST_DATE_RANGE = { + "train": ("20200101", "20200331"), # 缩小范围加速测试 + "val": ("20200401", "20200430"), + "test": ("20200501", "20200531"), +} + + +def main(): + print("=" * 80) + print("NaN 来源诊断") + print("=" * 80) + + engine = FactorEngine() + factor_manager = FactorManager( + selected_factors=SELECTED_FACTORS, + factor_definitions=FACTOR_DEFINITIONS, + label_factor=LABEL_FACTOR, + excluded_factors=EXCLUDED_FACTORS, + ) + + # Step 1: 注册因子并计算原始数据 + print("\n[Step 1] 注册因子并计算原始数据...") + feature_cols = factor_manager.register_to_engine(engine, verbose=False) + print(f" 特征数: {len(feature_cols)}") + + all_start = min( + TEST_DATE_RANGE["train"][0], + TEST_DATE_RANGE["val"][0], + TEST_DATE_RANGE["test"][0], + ) + all_end = max( + TEST_DATE_RANGE["train"][1], + TEST_DATE_RANGE["val"][1], + TEST_DATE_RANGE["test"][1], + ) + + raw_data = engine.compute( + factor_names=feature_cols + [LABEL_NAME], + start_date=all_start, + end_date=all_end, + ) + print(f" 
原始数据形状: {raw_data.shape}") + + # 检查原始数据中的 NaN + print("\n[Step 2] 原始数据 NaN 统计...") + nan_counts = {} + for col in feature_cols[:20]: # 只检查前20个特征 + nan_count = raw_data[col].null_count() + if nan_count > 0: + nan_counts[col] = nan_count + + print(f" 含 NaN 的特征数 (前20个): {len(nan_counts)}") + for col, count in list(nan_counts.items())[:10]: + pct = count / len(raw_data) * 100 + print(f" {col}: {count} ({pct:.1f}%)") + + # Step 3: 应用过滤器 + print("\n[Step 3] 应用过滤器...") + st_filter = STFilter(data_router=engine.router) + filtered_data = st_filter.filter(raw_data) + print(f" 过滤后数据形状: {filtered_data.shape}") + + # 检查过滤后的 NaN + nan_after_filter = sum(filtered_data[col].null_count() for col in feature_cols[:20]) + print(f" 前20个特征总 NaN 数: {nan_after_filter}") + + # Step 4: 应用股票池筛选 + print("\n[Step 4] 应用股票池筛选...") + from src.training.core.stock_pool_manager import StockPoolManager + + pool_manager = StockPoolManager( + filter_func=stock_pool_filter, + required_columns=STOCK_FILTER_REQUIRED_COLUMNS, + data_router=engine.router, + ) + pool_data = pool_manager.filter_and_select_daily(filtered_data) + print(f" 筛选后数据形状: {pool_data.shape}") + + # 检查筛选后的 NaN + nan_after_pool = sum(pool_data[col].null_count() for col in feature_cols[:20]) + print(f" 前20个特征总 NaN 数: {nan_after_pool}") + + # Step 5: 划分数据 + print("\n[Step 5] 划分训练集...") + train_mask = (pool_data["trade_date"] >= TEST_DATE_RANGE["train"][0]) & ( + pool_data["trade_date"] <= TEST_DATE_RANGE["train"][1] + ) + train_df = pool_data.filter(train_mask) + print(f" 训练集形状: {train_df.shape}") + + # 检查训练集的 NaN + nan_train_before = sum(train_df[col].null_count() for col in feature_cols[:20]) + print(f" 前20个特征总 NaN 数: {nan_train_before}") + + # Step 6: 依次应用 processors 并检查每一步的 NaN + print("\n[Step 6] 依次应用 processors...") + + # 6.1 NullFiller + print("\n [6.1] NullFiller (by_date=True, strategy=mean)...") + null_filler = NullFiller(feature_cols=feature_cols, strategy="mean", by_date=True) + after_null = null_filler.fit_transform(train_df) + 
nan_after_null = sum(after_null[col].null_count() for col in feature_cols[:20]) + print(f" 处理后前20个特征总 NaN 数: {nan_after_null}") + + # 检查具体哪些列还有 NaN + if nan_after_null > 0: + print(" 仍有 NaN 的列:") + for col in feature_cols[:20]: + count = after_null[col].null_count() + if count > 0: + print(f" {col}: {count}") + + # 6.2 Winsorizer + print("\n [6.2] Winsorizer (by_date=False)...") + winsorizer = Winsorizer( + feature_cols=feature_cols, lower=0.01, upper=0.99, by_date=False + ) + after_winsor = winsorizer.fit_transform(after_null) + nan_after_winsor = sum(after_winsor[col].null_count() for col in feature_cols[:20]) + print(f" 处理后前20个特征总 NaN 数: {nan_after_winsor}") + + # 6.3 StandardScaler + print("\n [6.3] StandardScaler...") + scaler = StandardScaler(feature_cols=feature_cols) + after_scaler = scaler.fit_transform(after_winsor) + nan_after_scaler = sum(after_scaler[col].null_count() for col in feature_cols[:20]) + print(f" 处理后前20个特征总 NaN 数: {nan_after_scaler}") + + # 检查具体哪些列还有 NaN + if nan_after_scaler > 0: + print(" 仍有 NaN 的列:") + for col in feature_cols[:20]: + count = after_scaler[col].null_count() + if count > 0: + # 检查这列在训练时的统计量 + has_mean = col in scaler.mean_ + has_std = col in scaler.std_ + mean_val = scaler.mean_.get(col, "N/A") + std_val = scaler.std_.get(col, "N/A") + print(f" {col}: {count}, mean={mean_val}, std={std_val}") + + # Step 7: 提取 X 并检查 + print("\n[Step 7] 提取特征矩阵 X...") + X = after_scaler.select(feature_cols) + X_np = X.to_numpy() + print(f" X 形状: {X_np.shape}") + print(f" X 中 NaN 总数: {np.isnan(X_np).sum()}") + + # 检查哪些特征列有 NaN + nan_by_col = [] + for i, col in enumerate(feature_cols): + col_nan = np.isnan(X_np[:, i]).sum() + if col_nan > 0: + nan_by_col.append((col, col_nan)) + + print(f" 含 NaN 的特征列数: {len(nan_by_col)}") + for col, count in nan_by_col[:10]: + print(f" {col}: {count}") + + print("\n" + "=" * 80) + print("诊断完成") + print("=" * 80) + + +if __name__ == "__main__": + main() diff --git a/tests/test_diagnose_nan.py 
b/tests/test_diagnose_nan.py new file mode 100644 index 0000000..4fe99fc --- /dev/null +++ b/tests/test_diagnose_nan.py @@ -0,0 +1,278 @@ +"""诊断 NaN 来源 - pytest 版本""" + +import numpy as np +import polars as pl +import pytest + +from src.factors import FactorEngine +from src.training import ( + FactorManager, + NullFiller, + Winsorizer, + StandardScaler, +) +from src.training.components.filters import STFilter +from src.training.core.stock_pool_manager import StockPoolManager +from src.experiment.common import ( + SELECTED_FACTORS, + FACTOR_DEFINITIONS, + LABEL_NAME, + LABEL_FACTOR, + stock_pool_filter, + STOCK_FILTER_REQUIRED_COLUMNS, +) + +# 只使用少量因子加速测试 +EXCLUDED_FACTORS = [ + "GTJA_alpha001", + "GTJA_alpha002", + "GTJA_alpha003", + "GTJA_alpha004", + "GTJA_alpha005", + "GTJA_alpha006", + "GTJA_alpha007", + "GTJA_alpha008", + "GTJA_alpha009", + "GTJA_alpha010", + "GTJA_alpha011", + "GTJA_alpha012", + "GTJA_alpha013", + "GTJA_alpha014", + "GTJA_alpha015", +] + +TEST_DATE_RANGE = { + "train": ("20200101", "20200331"), # 缩小范围加速测试 + "val": ("20200401", "20200430"), + "test": ("20200501", "20200531"), +} + + +def test_diagnose_nan_source(): + """诊断 NaN 来源""" + print("\n" + "=" * 80) + print("NaN 来源诊断") + print("=" * 80) + + engine = FactorEngine() + factor_manager = FactorManager( + selected_factors=SELECTED_FACTORS, + factor_definitions=FACTOR_DEFINITIONS, + label_factor=LABEL_FACTOR, + excluded_factors=EXCLUDED_FACTORS, + ) + + # Step 1: 注册因子并计算原始数据 + print("\n[Step 1] 注册因子并计算原始数据...") + feature_cols = factor_manager.register_to_engine(engine, verbose=False) + print(f" 特征数: {len(feature_cols)}") + + all_start = min( + TEST_DATE_RANGE["train"][0], + TEST_DATE_RANGE["val"][0], + TEST_DATE_RANGE["test"][0], + ) + all_end = max( + TEST_DATE_RANGE["train"][1], + TEST_DATE_RANGE["val"][1], + TEST_DATE_RANGE["test"][1], + ) + + raw_data = engine.compute( + factor_names=feature_cols + [LABEL_NAME], + start_date=all_start, + end_date=all_end, + ) + print(f" 原始数据形状: 
{raw_data.shape}") + + # 检查原始数据中的 NaN + print("\n[Step 2] 原始数据 NaN 统计...") + nan_counts = {} + for col in feature_cols[:20]: # 只检查前20个特征 + nan_count = raw_data[col].null_count() + if nan_count > 0: + nan_counts[col] = nan_count + + print(f" 含 NaN 的特征数 (前20个): {len(nan_counts)}") + for col, count in list(nan_counts.items())[:10]: + pct = count / len(raw_data) * 100 + print(f" {col}: {count} ({pct:.1f}%)") + + # Step 3: 应用过滤器 + print("\n[Step 3] 应用过滤器...") + st_filter = STFilter(data_router=engine.router) + filtered_data = st_filter.filter(raw_data) + print(f" 过滤后数据形状: {filtered_data.shape}") + + # 检查过滤后的 NaN + nan_after_filter = sum(filtered_data[col].null_count() for col in feature_cols[:20]) + print(f" 前20个特征总 NaN 数: {nan_after_filter}") + + # Step 4: 应用股票池筛选 + print("\n[Step 4] 应用股票池筛选...") + pool_manager = StockPoolManager( + filter_func=stock_pool_filter, + required_columns=STOCK_FILTER_REQUIRED_COLUMNS, + data_router=engine.router, + ) + pool_data = pool_manager.filter_and_select_daily(filtered_data) + print(f" 筛选后数据形状: {pool_data.shape}") + + # 检查筛选后的 NaN + nan_after_pool = sum(pool_data[col].null_count() for col in feature_cols[:20]) + print(f" 前20个特征总 NaN 数: {nan_after_pool}") + + # Step 5: 划分数据 + print("\n[Step 5] 划分训练集...") + train_mask = (pool_data["trade_date"] >= TEST_DATE_RANGE["train"][0]) & ( + pool_data["trade_date"] <= TEST_DATE_RANGE["train"][1] + ) + train_df = pool_data.filter(train_mask) + print(f" 训练集形状: {train_df.shape}") + + # 检查训练集的 NaN + nan_train_before = sum(train_df[col].null_count() for col in feature_cols[:20]) + print(f" 前20个特征总 NaN 数: {nan_train_before}") + + # Step 6: 依次应用 processors 并检查每一步的 NaN + print("\n[Step 6] 依次应用 processors...") + + # 6.1 NullFiller + print("\n [6.1] NullFiller (by_date=True, strategy=mean)...") + null_filler = NullFiller(feature_cols=feature_cols, strategy="mean", by_date=True) + after_null = null_filler.fit_transform(train_df) + nan_after_null = sum(after_null[col].null_count() for col in 
feature_cols[:20]) + print(f" 处理后前20个特征总 NaN 数: {nan_after_null}") + + # 检查具体哪些列还有 NaN + if nan_after_null > 0: + print(" 仍有 NaN 的列:") + for col in feature_cols[:20]: + count = after_null[col].null_count() + if count > 0: + print(f" {col}: {count}") + + # 6.2 Winsorizer + print("\n [6.2] Winsorizer (by_date=False)...") + winsorizer = Winsorizer( + feature_cols=feature_cols, lower=0.01, upper=0.99, by_date=False + ) + after_winsor = winsorizer.fit_transform(after_null) + nan_after_winsor = sum(after_winsor[col].null_count() for col in feature_cols[:20]) + print(f" 处理后前20个特征总 NaN 数: {nan_after_winsor}") + + # 6.3 StandardScaler + print("\n [6.3] StandardScaler...") + scaler = StandardScaler(feature_cols=feature_cols) + after_scaler = scaler.fit_transform(after_winsor) + nan_after_scaler = sum(after_scaler[col].null_count() for col in feature_cols[:20]) + print(f" 处理后前20个特征总 NaN 数: {nan_after_scaler}") + + # 检查具体哪些列还有 NaN + if nan_after_scaler > 0: + print(" 仍有 NaN 的列:") + for col in feature_cols[:20]: + count = after_scaler[col].null_count() + if count > 0: + # 检查这列在训练时的统计量 + has_mean = col in scaler.mean_ + has_std = col in scaler.std_ + mean_val = scaler.mean_.get(col, "N/A") + std_val = scaler.std_.get(col, "N/A") + print(f" {col}: {count}, mean={has_mean}, std={has_std}") + + # Step 6.4: 检查 StandardScaler 之后、select 之前的所有列 + print("\n [6.4] 检查 StandardScaler 后的所有列...") + all_nan_counts = {} + for col in feature_cols: + count = after_scaler[col].null_count() + if count > 0: + all_nan_counts[col] = count + print(f" 所有特征列中含 NaN 的列数: {len(all_nan_counts)}") + + # 检查这些列是否在 feature_cols 的前20个中 + nan_cols_in_first_20 = [c for c in all_nan_counts.keys() if c in feature_cols[:20]] + nan_cols_not_in_first_20 = [ + c for c in all_nan_counts.keys() if c not in feature_cols[:20] + ] + print(f" 在前20个中的: {len(nan_cols_in_first_20)}") + print(f" 不在前20个中的: {len(nan_cols_not_in_first_20)}") + if nan_cols_not_in_first_20: + print(f" 例如: {nan_cols_not_in_first_20[:10]}") + + # 检查 
StandardScaler 是否学到了这些列的统计量 + print("\n [6.5] 检查 StandardScaler 学到的统计量...") + missing_stats_cols = [c for c in all_nan_counts.keys() if c not in scaler.mean_] + print(f" 未学到 mean 的列数: {len(missing_stats_cols)}") + if missing_stats_cols: + print(f" 例如: {missing_stats_cols[:10]}") + # 检查这些列的数据类型 + for col in missing_stats_cols[:3]: + dtype = after_scaler[col].dtype + print(f" {col}: dtype={dtype}") + + # Step 7: 提取 X 并检查 + print("\n[Step 7] 提取特征矩阵 X...") + X = after_scaler.select(feature_cols) + + # 关键检查:对比 after_scaler 和 X 中的列 + print("\n [7.1] 对比 after_scaler 和 X 中的列...") + for col in feature_cols[:20]: + null_in_raw = after_scaler[col].null_count() + null_in_x = X[col].null_count() + if null_in_raw != null_in_x: + print(f" {col}: after_scaler={null_in_raw}, X={null_in_x}") + + X_np = X.to_numpy() + print(f" X 形状: {X_np.shape}") + print(f" X 中 NaN 总数: {np.isnan(X_np).sum()}") + + # 检查哪些特征列有 NaN + nan_by_col = [] + for i, col in enumerate(feature_cols): + col_nan = np.isnan(X_np[:, i]).sum() + if col_nan > 0: + nan_by_col.append((col, col_nan)) + + print(f" 含 NaN 的特征列数: {len(nan_by_col)}") + for col, count in nan_by_col[:10]: + print(f" {col}: {count}") + + # 检查这些列在 after_scaler 中的数据类型 + print("\n [Step 8] 检查含 NaN 列的数据类型...") + for col, count in nan_by_col[:5]: + dtype = after_scaler[col].dtype + null_count = after_scaler[col].null_count() + print(f" {col}: dtype={dtype}, null_count={null_count}") + + # 检查这些列是否是布尔类型 + boolean_cols = [ + col for col in feature_cols if after_scaler[col].dtype == pl.Boolean + ] + print(f"\n Boolean 类型的特征列数: {len(boolean_cols)}") + print(f" 例如: {boolean_cols[:10]}") + + # 检查这些布尔列是否有 null + boolean_with_null = [ + col for col in boolean_cols if after_scaler[col].null_count() > 0 + ] + print(f"\n 含 null 的 Boolean 列数: {len(boolean_with_null)}") + + # Step 9: 检查是否有不在 feature_cols 中的列有 NaN + print("\n [Step 9] 检查非特征列的 NaN...") + non_feature_cols = [c for c in after_scaler.columns if c not in feature_cols] + non_feature_nan = {} + for col in 
non_feature_cols[:10]: + count = after_scaler[col].null_count() + if count > 0: + non_feature_nan[col] = count + print(f" 非特征列中含 NaN 的列数: {len(non_feature_nan)}") + for col, count in list(non_feature_nan.items())[:5]: + print(f" {col}: {count}") + + print("\n" + "=" * 80) + print("诊断完成") + print("=" * 80) + + # 断言用于pytest + assert True diff --git a/tests/test_nan_step_by_step.py b/tests/test_nan_step_by_step.py new file mode 100644 index 0000000..5677454 --- /dev/null +++ b/tests/test_nan_step_by_step.py @@ -0,0 +1,492 @@ +"""NaN 问题逐步诊断测试 - 精确定位问题环节 + +此测试会逐步检查 DataPipeline 的每个处理步骤,精确定位 NaN 产生的位置。 +""" + +import numpy as np +import polars as pl +import pytest + +from src.factors import FactorEngine +from src.training import ( + FactorManager, + DataPipeline, + NullFiller, + Winsorizer, + StandardScaler, +) +from src.training.components.filters import STFilter +from src.experiment.common import ( + SELECTED_FACTORS, + FACTOR_DEFINITIONS, + LABEL_NAME, + LABEL_FACTOR, + stock_pool_filter, + STOCK_FILTER_REQUIRED_COLUMNS, +) + +# 测试配置 +EXCLUDED_FACTORS = [f"GTJA_alpha{i:03d}" for i in range(1, 50)] # 排除前50个加速测试 +TEST_DATE_RANGE = { + "train": ("20200101", "20201231"), # 一整年数据 + "val": ("20210101", "20210331"), + "test": ("20210401", "20210630"), +} + + +class TestNaNStepByStep: + """逐步诊断 NaN 问题的测试类""" + + @pytest.fixture(scope="class") + def base_data(self): + """准备基础数据(未经过任何处理)""" + print("\n" + "=" * 80) + print("[Fixture] 准备基础数据...") + + engine = FactorEngine() + factor_manager = FactorManager( + selected_factors=SELECTED_FACTORS, + factor_definitions=FACTOR_DEFINITIONS, + label_factor=LABEL_FACTOR, + excluded_factors=EXCLUDED_FACTORS, + ) + + # 注册因子 + feature_cols = factor_manager.register_to_engine(engine, verbose=False) + print(f" 特征数: {len(feature_cols)}") + + # 计算完整日期范围 + all_start = min( + TEST_DATE_RANGE["train"][0], + TEST_DATE_RANGE["val"][0], + TEST_DATE_RANGE["test"][0], + ) + all_end = max( + TEST_DATE_RANGE["train"][1], + TEST_DATE_RANGE["val"][1], + 
TEST_DATE_RANGE["test"][1], + ) + + # 计算因子 + raw_data = engine.compute( + factor_names=feature_cols + [LABEL_NAME], + start_date=all_start, + end_date=all_end, + ) + print(f" 原始数据形状: {raw_data.shape}") + + return { + "engine": engine, + "factor_manager": factor_manager, + "feature_cols": feature_cols, + "raw_data": raw_data, + } + + def test_step_0_raw_data(self, base_data): + """步骤0: 检查原始数据中的 NaN""" + print("\n" + "=" * 80) + print("[步骤0] 检查原始数据中的 NaN") + + raw_data = base_data["raw_data"] + feature_cols = base_data["feature_cols"] + + nan_stats = self._check_nan_in_df(raw_data, feature_cols, "原始数据") + + # 记录有 NaN 的列 + print(f" 含 NaN 的特征列数: {len(nan_stats['cols_with_nan'])}") + if nan_stats["cols_with_nan"]: + print(f" 示例: {nan_stats['cols_with_nan'][:5]}") + + return nan_stats + + def test_step_1_after_st_filter(self, base_data): + """步骤1: 检查 STFilter 后的 NaN""" + print("\n" + "=" * 80) + print("[步骤1] 检查 STFilter 后的 NaN") + + raw_data = base_data["raw_data"] + feature_cols = base_data["feature_cols"] + engine = base_data["engine"] + + st_filter = STFilter(data_router=engine.router) + filtered_data = st_filter.filter(raw_data) + + print(f" 过滤后数据形状: {filtered_data.shape}") + print(f" 删除记录数: {len(raw_data) - len(filtered_data)}") + + nan_stats = self._check_nan_in_df(filtered_data, feature_cols, "STFilter后") + + # 对比步骤0,看是否有新增 NaN + step0_nan = self.test_step_0_raw_data(base_data) + if nan_stats["total_nan"] != step0_nan["total_nan"]: + print( + f" [警告] NaN 数量变化: {step0_nan['total_nan']} -> {nan_stats['total_nan']}" + ) + + return nan_stats + + def test_step_2_after_stock_pool(self, base_data): + """步骤2: 检查股票池筛选后的 NaN""" + print("\n" + "=" * 80) + print("[步骤2] 检查股票池筛选后的 NaN") + + raw_data = base_data["raw_data"] + feature_cols = base_data["feature_cols"] + engine = base_data["engine"] + + # 先应用 STFilter + st_filter = STFilter(data_router=engine.router) + filtered_data = st_filter.filter(raw_data) + + # 再应用股票池筛选 + from src.training.core.stock_pool_manager import 
StockPoolManager + + pool_manager = StockPoolManager( + filter_func=stock_pool_filter, + required_columns=STOCK_FILTER_REQUIRED_COLUMNS, + data_router=engine.router, + ) + pool_data = pool_manager.filter_and_select_daily(filtered_data) + + print(f" 筛选后数据形状: {pool_data.shape}") + print(f" 删除记录数: {len(filtered_data) - len(pool_data)}") + + nan_stats = self._check_nan_in_df(pool_data, feature_cols, "股票池筛选后") + return nan_stats + + def test_step_3_train_split_without_skip(self, base_data): + """步骤3: 检查训练集划分后的 NaN(不跳过天数)""" + print("\n" + "=" * 80) + print("[步骤3] 检查训练集划分后的 NaN(不跳过天数)") + + raw_data = base_data["raw_data"] + feature_cols = base_data["feature_cols"] + engine = base_data["engine"] + + # 应用过滤器 + st_filter = STFilter(data_router=engine.router) + filtered_data = st_filter.filter(raw_data) + + from src.training.core.stock_pool_manager import StockPoolManager + + pool_manager = StockPoolManager( + filter_func=stock_pool_filter, + required_columns=STOCK_FILTER_REQUIRED_COLUMNS, + data_router=engine.router, + ) + pool_data = pool_manager.filter_and_select_daily(filtered_data) + + # 划分训练集 + train_start, train_end = TEST_DATE_RANGE["train"] + train_mask = (pool_data["trade_date"] >= train_start) & ( + pool_data["trade_date"] <= train_end + ) + train_df = pool_data.filter(train_mask) + + print(f" 训练集形状: {train_df.shape}") + + # 统计交易日数量 + unique_dates = train_df["trade_date"].unique().sort() + print(f" 训练集交易日数量: {len(unique_dates)}") + print(f" 日期范围: {unique_dates[0]} ~ {unique_dates[-1]}") + + nan_stats = self._check_nan_in_df(train_df, feature_cols, "训练集(不跳过)") + + # 返回训练集供后续测试使用 + return { + "nan_stats": nan_stats, + "train_df": train_df, + "unique_dates": unique_dates, + } + + def test_step_4_train_split_with_skip(self, base_data): + """步骤4: 检查训练集划分后的 NaN(跳过前252天)""" + print("\n" + "=" * 80) + print("[步骤4] 检查训练集划分后的 NaN(跳过前252天)") + + step3_result = self.test_step_3_train_split_without_skip(base_data) + train_df = step3_result["train_df"] + unique_dates = 
step3_result["unique_dates"] + feature_cols = base_data["feature_cols"] + + # 跳过前252天 + skip_days = 252 + if len(unique_dates) > skip_days: + start_date = unique_dates[skip_days] + train_df_skipped = train_df.filter(pl.col("trade_date") >= start_date) + print(f" 跳过前{skip_days}天后,从 {start_date} 开始") + print(f" 跳过后训练集形状: {train_df_skipped.shape}") + print(f" 跳过记录数: {len(train_df) - len(train_df_skipped)}") + else: + train_df_skipped = train_df + print( + f" [警告] 训练集交易日数({len(unique_dates)})少于跳过天数({skip_days}),未跳过" + ) + + nan_stats = self._check_nan_in_df( + train_df_skipped, feature_cols, "训练集(跳过252天)" + ) + + return { + "nan_stats": nan_stats, + "train_df": train_df_skipped, + } + + def test_step_5_after_null_filler(self, base_data): + """步骤5: 检查 NullFiller 后的 NaN""" + print("\n" + "=" * 80) + print("[步骤5] 检查 NullFiller 后的 NaN") + + step4_result = self.test_step_4_train_split_with_skip(base_data) + train_df = step4_result["train_df"] + feature_cols = base_data["feature_cols"] + + print(f" 处理前数据形状: {train_df.shape}") + + # 应用 NullFiller + null_filler = NullFiller( + feature_cols=feature_cols, strategy="mean", by_date=True + ) + after_null = null_filler.fit_transform(train_df) + + print(f" 处理后数据形状: {after_null.shape}") + + nan_stats = self._check_nan_in_df(after_null, feature_cols, "NullFiller后") + + # 检查哪些列还有 NaN + if nan_stats["cols_with_nan"]: + print( + f" [错误] NullFiller 后仍有 {len(nan_stats['cols_with_nan'])} 列含 NaN:" + ) + for col in nan_stats["cols_with_nan"][:10]: + count = after_null[col].null_count() + dtype = after_null[col].dtype + print(f" {col}: {count} 个 NaN, dtype={dtype}") + + return { + "nan_stats": nan_stats, + "after_null": after_null, + } + + def test_step_6_after_winsorizer(self, base_data): + """步骤6: 检查 Winsorizer 后的 NaN""" + print("\n" + "=" * 80) + print("[步骤6] 检查 Winsorizer 后的 NaN") + + step5_result = self.test_step_5_after_null_filler(base_data) + after_null = step5_result["after_null"] + feature_cols = base_data["feature_cols"] + + # 应用 
Winsorizer + winsorizer = Winsorizer( + feature_cols=feature_cols, lower=0.01, upper=0.99, by_date=False + ) + after_winsor = winsorizer.fit_transform(after_null) + + nan_stats = self._check_nan_in_df(after_winsor, feature_cols, "Winsorizer后") + + # 检查哪些列还有 NaN + if nan_stats["cols_with_nan"]: + print( + f" [错误] Winsorizer 后仍有 {len(nan_stats['cols_with_nan'])} 列含 NaN:" + ) + for col in nan_stats["cols_with_nan"][:10]: + count = after_winsor[col].null_count() + dtype = after_winsor[col].dtype + print(f" {col}: {count} 个 NaN, dtype={dtype}") + + return { + "nan_stats": nan_stats, + "after_winsor": after_winsor, + } + + def test_step_7_after_standard_scaler(self, base_data): + """步骤7: 检查 StandardScaler 后的 NaN""" + print("\n" + "=" * 80) + print("[步骤7] 检查 StandardScaler 后的 NaN") + + step6_result = self.test_step_6_after_winsorizer(base_data) + after_winsor = step6_result["after_winsor"] + feature_cols = base_data["feature_cols"] + + # 在应用 StandardScaler 之前,检查那些后来出问题的列 + print("\n [预检查] StandardScaler 前,检查关键列...") + problematic_cols = [ + "GTJA_alpha062", + "GTJA_alpha073", + "GTJA_alpha085", + "GTJA_alpha087", + "GTJA_alpha092", + "GTJA_alpha103", + "GTJA_alpha104", + "GTJA_alpha117", + "GTJA_alpha124", + "GTJA_alpha131", + ] + for col in problematic_cols: + if col in after_winsor.columns: + null_count = after_winsor[col].null_count() + dtype = after_winsor[col].dtype + min_val = after_winsor[col].min() + max_val = after_winsor[col].max() + print( + f" {col}: null={null_count}, dtype={dtype}, min={min_val}, max={max_val}" + ) + + # 应用 StandardScaler + scaler = StandardScaler(feature_cols=feature_cols) + after_scaler = scaler.fit_transform(after_winsor) + + # 检查 StandardScaler 学到的统计量 + print("\n [统计量检查] StandardScaler 学到的统计量...") + for col in problematic_cols: + if col in scaler.mean_: + print(f" {col}: mean={scaler.mean_[col]}, std={scaler.std_[col]}") + else: + print(f" {col}: [未学到统计量]") + + nan_stats = self._check_nan_in_df( + after_scaler, feature_cols, 
"StandardScaler后" + ) + + # 检查哪些列还有 NaN + if nan_stats["cols_with_nan"]: + print( + f" [错误] StandardScaler 后仍有 {len(nan_stats['cols_with_nan'])} 列含 NaN:" + ) + for col in nan_stats["cols_with_nan"][:10]: + count = after_scaler[col].null_count() + dtype = after_scaler[col].dtype + print(f" {col}: {count} 个 NaN, dtype={dtype}") + + # 检查这列是否学到了统计量 + if col in scaler.mean_: + print( + f" mean={scaler.mean_[col]:.4f}, std={scaler.std_[col]:.4f}" + ) + else: + print(f" [警告] 未学到统计量!") + + return { + "nan_stats": nan_stats, + "after_scaler": after_scaler, + "scaler": scaler, + } + + def test_step_8_extract_X(self, base_data): + """步骤8: 检查提取 X 后的 NaN(转换为 numpy)""" + print("\n" + "=" * 80) + print("[步骤8] 检查提取 X 后的 NaN") + + step7_result = self.test_step_7_after_standard_scaler(base_data) + after_scaler = step7_result["after_scaler"] + feature_cols = base_data["feature_cols"] + + # 提取 X + X_df = after_scaler.select(feature_cols) + print(f" X DataFrame 形状: {X_df.shape}") + + # 对比 DataFrame 和 select 后的 null 数量 + print("\n [对比] DataFrame vs select 后的 null 数量:") + mismatched = [] + for col in feature_cols[:20]: # 只检查前20个 + null_in_df = after_scaler[col].null_count() + null_in_x = X_df[col].null_count() + if null_in_df != null_in_x: + mismatched.append((col, null_in_df, null_in_x)) + + if mismatched: + print(f" [警告] 发现 {len(mismatched)} 列不匹配:") + for col, df_null, x_null in mismatched[:10]: + print(f" {col}: DataFrame={df_null}, X={x_null}") + else: + print(f" [通过] 所有列的 null 数量一致") + + # 转换为 numpy + X_np = X_df.to_numpy() + print(f"\n X numpy 形状: {X_np.shape}") + + nan_count = np.isnan(X_np).sum() + print(f" X 中 NaN 总数: {nan_count}") + + if nan_count > 0: + # 找出哪些列有 NaN + nan_by_col = [] + for i, col in enumerate(feature_cols): + col_nan = np.isnan(X_np[:, i]).sum() + if col_nan > 0: + nan_by_col.append((col, col_nan)) + + print(f"\n [错误] 含 NaN 的特征列数: {len(nan_by_col)}") + for col, count in nan_by_col[:10]: + # 检查原始 DataFrame 中的情况 + df_null = after_scaler[col].null_count() + dtype 
= after_scaler[col].dtype + + # 检查是否有 Infinity + inf_count_pos = (after_scaler[col] == float("inf")).sum() + inf_count_neg = (after_scaler[col] == float("-inf")).sum() + + # 检查 min/max + col_min = after_scaler[col].min() + col_max = after_scaler[col].max() + + print( + f" {col}: numpy中{count}个NaN, DataFrame中{df_null}个null, dtype={dtype}" + ) + print(f" min={col_min}, max={col_max}") + print(f" +inf={inf_count_pos}, -inf={inf_count_neg}") + + # 如果有 inf,显示一些样本值 + if inf_count_pos > 0 or inf_count_neg > 0: + sample_vals = after_scaler[col].drop_nulls().tail(5).to_list() + print(f" 样本值: {sample_vals}") + + # 断言失败,显示详细信息 + assert False, f"X 中含 {nan_count} 个 NaN,涉及 {len(nan_by_col)} 个特征列" + else: + print("\n [通过] X 中无 NaN!") + + def _check_nan_in_df( + self, df: pl.DataFrame, feature_cols: list, step_name: str + ) -> dict: + """检查 DataFrame 中的 NaN 统计信息 + + Returns: + dict: { + 'total_nan': 总NaN数, + 'cols_with_nan': 含NaN的列名列表, + 'nan_by_col': {列名: NaN数} 的字典 + } + """ + nan_by_col = {} + total_nan = 0 + + for col in feature_cols: + null_count = df[col].null_count() + if null_count > 0: + nan_by_col[col] = null_count + total_nan += null_count + + cols_with_nan = list(nan_by_col.keys()) + + print(f" {step_name}:") + print(f" 总记录数: {len(df)}") + print(f" 特征列数: {len(feature_cols)}") + print(f" 总NaN数: {total_nan}") + print(f" 含NaN的列数: {len(cols_with_nan)}") + + if cols_with_nan and len(cols_with_nan) <= 5: + print(f" 含NaN的列: {cols_with_nan}") + elif cols_with_nan: + print(f" 含NaN的列(前5): {cols_with_nan[:5]}...") + + return { + "total_nan": total_nan, + "cols_with_nan": cols_with_nan, + "nan_by_col": nan_by_col, + } + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "-s"]) diff --git a/tests/test_tabm_integration.py b/tests/test_tabm_integration.py new file mode 100644 index 0000000..294904a --- /dev/null +++ b/tests/test_tabm_integration.py @@ -0,0 +1,310 @@ +"""TabM 集成测试 + +测试 TabMModel 和 TabMRegressionTask 的完整训练流程。 +""" + +import os +import sys +from pathlib import 
Path + +import numpy as np +import polars as pl +import pytest +import torch + +# 确保 src 在路径中 +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from src.training.components.models import TabMModel +from src.training.tasks import TabMRegressionTask + + +# ========================================== +# 测试数据准备 +# ========================================== + + +def create_sample_data(n_samples: int = 1000, n_features: int = 20, seed: int = 42): + """创建样本数据用于测试 + + Args: + n_samples: 样本数量 + n_features: 特征数量 + seed: 随机种子 + + Returns: + (train_X, train_y, val_X, val_y, test_X, test_y) + """ + np.random.seed(seed) + torch.manual_seed(seed) + + # 创建特征矩阵 + X_train = pl.DataFrame( + np.random.randn(n_samples, n_features).astype(np.float32), + schema=[f"feature_{i}" for i in range(n_features)], + ) + y_train = pl.Series("target", np.random.randn(n_samples).astype(np.float32)) + + X_val = pl.DataFrame( + np.random.randn(n_samples // 2, n_features).astype(np.float32), + schema=[f"feature_{i}" for i in range(n_features)], + ) + y_val = pl.Series("target", np.random.randn(n_samples // 2).astype(np.float32)) + + X_test = pl.DataFrame( + np.random.randn(n_samples // 2, n_features).astype(np.float32), + schema=[f"feature_{i}" for i in range(n_features)], + ) + y_test = pl.Series("target", np.random.randn(n_samples // 2).astype(np.float32)) + + return X_train, y_train, X_val, y_val, X_test, y_test + + +# ========================================== +# TabMModel 测试 +# ========================================== + + +class TestTabMModel: + """TabMModel 单元测试""" + + def test_initialization(self): + """测试模型初始化""" + params = { + "n_blocks": 2, + "d_block": 128, + "ensemble_size": 8, # 小规模集成用于测试 + "batch_size": 64, + "epochs": 2, + } + + model = TabMModel(params) + + assert model.name == "tabm" + assert model.params == params + assert model.device.type in ["cuda", "cpu"] + assert model.model is None # 未训练时为 None + + def test_fit_and_predict(self): + """测试训练和预测""" + # 创建小规模数据 + X_train, 
y_train, X_val, y_val, X_test, _ = create_sample_data( + n_samples=200, n_features=10, seed=42 + ) + + params = { + "n_blocks": 1, + "d_block": 64, + "ensemble_size": 4, + "batch_size": 32, + "epochs": 2, + "early_stopping_patience": 10, + } + + model = TabMModel(params) + + # 训练 + model.fit(X_train, y_train, eval_set=(X_val, y_val)) + + # 验证模型已训练 + assert model.model is not None + assert len(model.training_history_["train_loss"]) > 0 + + # 预测 + predictions = model.predict(X_test) + + # 验证预测结果 + assert isinstance(predictions, np.ndarray) + assert len(predictions) == len(X_test) + assert predictions.shape == (len(X_test),) + + def test_save_and_load(self, tmp_path): + """测试模型保存和加载""" + # 创建数据 + X_train, y_train, X_val, y_val, _, _ = create_sample_data( + n_samples=200, n_features=10, seed=42 + ) + + params = { + "n_blocks": 1, + "d_block": 64, + "ensemble_size": 4, + "batch_size": 32, + "epochs": 2, + } + + # 训练模型 + model = TabMModel(params) + model.fit(X_train, y_train, eval_set=(X_val, y_val)) + + # 保存 + save_path = tmp_path / "test_tabm_model" + model.save(str(save_path)) + + # 加载 + loaded_model = TabMModel.load(str(save_path)) + + # 验证加载的模型 + assert loaded_model.params == params + assert loaded_model.feature_names_ == model.feature_names_ + assert loaded_model.model is not None + + # 预测结果应该一致 + pred1 = model.predict(X_val) + pred2 = loaded_model.predict(X_val) + + np.testing.assert_allclose(pred1, pred2, rtol=1e-5) + + +# ========================================== +# TabMRegressionTask 测试 +# ========================================== + + +class TestTabMRegressionTask: + """TabMRegressionTask 单元测试""" + + def test_initialization(self): + """测试任务初始化""" + params = { + "n_blocks": 2, + "d_block": 128, + "ensemble_size": 8, + "batch_size": 64, + "epochs": 2, + } + + task = TabMRegressionTask(model_params=params, label_name="target") + + assert task.model_params == params + assert task.label_name == "target" + assert task.model is None + + def 
test_prepare_labels(self): + """测试标签准备(回归任务不做转换)""" + params = { + "ensemble_size": 4, + "epochs": 2, + } + + task = TabMRegressionTask(model_params=params, label_name="target") + + # 创建测试数据 + data = { + "train": { + "X": pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}), + "y": pl.Series("target", [0.1, 0.2, 0.3]), + } + } + + result = task.prepare_labels(data) + + # 回归任务不做转换,数据应该保持不变 + assert result == data + + def test_fit_train_and_predict(self): + """测试完整训练和预测流程""" + # 创建数据 + X_train, y_train, X_val, y_val, X_test, y_test = create_sample_data( + n_samples=300, n_features=10, seed=42 + ) + + params = { + "n_blocks": 1, + "d_block": 64, + "ensemble_size": 4, + "batch_size": 32, + "epochs": 3, + } + + task = TabMRegressionTask(model_params=params, label_name="target") + + # 准备数据格式 + train_data = {"X": X_train, "y": y_train} + val_data = {"X": X_val, "y": y_val} + + # 训练 + task.fit(train_data, val_data) + + # 验证模型已训练 + assert task.get_model() is not None + + # 预测 + predictions = task.predict({"X": X_test}) + + # 验证预测结果 + assert len(predictions) == len(X_test) + + +# ========================================== +# 集成测试 +# ========================================== + + +class TestTabMIntegration: + """TabM 集成测试""" + + def test_full_workflow(self): + """测试完整工作流程""" + # 创建数据 + X_train, y_train, X_val, y_val, X_test, y_test = create_sample_data( + n_samples=500, n_features=15, seed=42 + ) + + params = { + "n_blocks": 2, + "d_block": 128, + "ensemble_size": 8, + "batch_size": 64, + "epochs": 5, + } + + # 1. 创建 Task + task = TabMRegressionTask(model_params=params, label_name="target") + + # 2. 准备数据 + train_data = {"X": X_train, "y": y_train} + val_data = {"X": X_val, "y": y_val} + + # 3. 训练 + task.fit(train_data, val_data) + + # 4. 验证训练历史 + model = task.get_model() + assert len(model.training_history_["train_loss"]) > 0 + assert len(model.training_history_["val_loss"]) > 0 + + # 5. 预测 + predictions = task.predict({"X": X_test}) + + # 6. 
验证预测质量 + # 简单验证:预测值不应全为常数 + assert np.std(predictions) > 1e-6, "预测值全为常数,可能是模型未正常训练" + + # 验证预测值与真实值存在一定相关性 + correlation = np.corrcoef(predictions, y_test.to_numpy())[0, 1] + # 注意:随机数据的相关性可能很低,这是正常的 + print(f"预测与真实值相关系数: {correlation:.4f}") + + def test_gpu_availability(self): + """测试 GPU 可用性""" + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + params = { + "ensemble_size": 2, + "epochs": 1, + } + + model = TabMModel(params) + + assert model.device == device + expected_type = "cuda" if torch.cuda.is_available() else "cpu" + assert model.device.type == expected_type + + +# ========================================== +# 运行测试 +# ========================================== + +if __name__ == "__main__": + pytest.main([__file__, "-v", "--tb=short"]) diff --git a/tests/test_tabm_nan_debug.py b/tests/test_tabm_nan_debug.py new file mode 100644 index 0000000..4bc18d3 --- /dev/null +++ b/tests/test_tabm_nan_debug.py @@ -0,0 +1,519 @@ +"""TabM NaN 问题诊断测试 + train_skip_days 功能验证 + +诊断 loss 为 nan 的根因: +1. 标签中是否有 NaN 或极端值 +2. 标准化后是否有 NaN +3. 是否有方差为0或接近0的列 +4. 
train_skip_days 功能是否正常工作 +""" + +import numpy as np +import polars as pl +import pytest + +from src.factors import FactorEngine +from src.training import ( + FactorManager, + DataPipeline, + TabMRegressionTask, + NullFiller, + Winsorizer, + StandardScaler, +) +from src.training.components.filters import STFilter +from src.experiment.common import ( + SELECTED_FACTORS, + FACTOR_DEFINITIONS, + LABEL_NAME, + LABEL_FACTOR, + stock_pool_filter, + STOCK_FILTER_REQUIRED_COLUMNS, +) + + +# TabM 模型参数(简化用于快速测试) +MODEL_PARAMS = { + "n_blocks": 2, + "d_block": 128, + "dropout": 0.1, + "ensemble_size": 4, # 简化 + "batch_size": 256, + "learning_rate": 1e-4, # 降低学习率 + "weight_decay": 1e-5, + "epochs": 3, + "early_stopping_patience": 5, +} + +# 测试日期范围 +TEST_DATE_RANGE = { + "train": ("20200101", "20200630"), + "val": ("20200701", "20200731"), + "test": ("20200801", "20200831"), +} + +# 小范围测试用的因子排除列表 +EXCLUDED_FACTORS = [ + "GTJA_alpha001", + "GTJA_alpha002", + "GTJA_alpha003", + "GTJA_alpha004", + "GTJA_alpha005", + "GTJA_alpha006", + "GTJA_alpha007", + "GTJA_alpha008", + "GTJA_alpha009", + "GTJA_alpha010", +] + + +class TestTabMNanDebug: + """TabM NaN 问题诊断测试类(使用 DataPipeline)""" + + @pytest.fixture(scope="class") + def engine_and_factor_manager(self): + """准备 FactorEngine 和 FactorManager""" + engine = FactorEngine() + + factor_manager = FactorManager( + selected_factors=SELECTED_FACTORS, + factor_definitions=FACTOR_DEFINITIONS, + label_factor=LABEL_FACTOR, + excluded_factors=EXCLUDED_FACTORS, + ) + + return { + "engine": engine, + "factor_manager": factor_manager, + } + + @pytest.fixture(scope="class") + def pipeline_data(self, engine_and_factor_manager): + """使用 DataPipeline 准备数据,并过滤掉全为NaN的列""" + engine = engine_and_factor_manager["engine"] + factor_manager = engine_and_factor_manager["factor_manager"] + + # 创建 DataPipeline(使用 train_skip_days=0 进行基础测试) + pipeline = DataPipeline( + factor_manager=factor_manager, + processor_configs=[ + (NullFiller, {"strategy": "mean"}), + 
(Winsorizer, {"lower": 0.01, "upper": 0.99}), + (StandardScaler, {}), + ], + filters=[STFilter(data_router=engine.router)], + stock_pool_filter_func=stock_pool_filter, + stock_pool_required_columns=STOCK_FILTER_REQUIRED_COLUMNS, + train_skip_days=0, # 先不跳过天数进行测试 + ) + + # 准备数据 + data = pipeline.prepare_data( + engine=engine, + date_range=TEST_DATE_RANGE, + label_name=LABEL_NAME, + verbose=True, + ) + + # 获取特征列 + feature_cols = factor_manager.register_to_engine(engine, verbose=False) + + # 过滤掉全为NaN的列(这些因子计算失败,没有修复意义) + print("\n[DataPipeline] 检查并过滤全为NaN的特征列...") + for split_name in ["train", "val", "test"]: + X_df = data[split_name]["X"] + y_series = data[split_name]["y"] + raw_data = data[split_name]["raw_data"] + + # 删除标签为NaN的行 + y_nan_count = y_series.null_count() + if y_nan_count > 0: + print(f" {split_name}: 发现 {y_nan_count} 个标签为NaN的行,将被删除") + # 创建有效标签的mask + valid_mask = y_series.is_not_null() + # 过滤所有相关数据 + X_df = X_df.filter(valid_mask) + y_series = y_series.filter(valid_mask) + raw_data = raw_data.filter(valid_mask) + # 更新数据 + data[split_name]["X"] = X_df + data[split_name]["y"] = y_series + data[split_name]["raw_data"] = raw_data + + # 检查每列的NaN数量 + nan_counts = {col: X_df[col].null_count() for col in X_df.columns} + total_rows = len(X_df) + + # 找出全为NaN的列 + all_nan_cols = [ + col for col, count in nan_counts.items() if count == total_rows + ] + + if all_nan_cols: + print( + f" {split_name}: 发现 {len(all_nan_cols)} 个全为NaN的列,将被删除" + ) + print( + f" 列名: {all_nan_cols[:5]}{'...' 
if len(all_nan_cols) > 5 else ''}" + ) + + # 更新feature_cols + feature_cols = [c for c in feature_cols if c not in all_nan_cols] + + # 从X中删除这些列 + X_df = X_df.select(feature_cols) + + # 从raw_data中也删除这些列(保留原始特征列以外的列如trade_date, ts_code等) + raw_cols_to_keep = [ + c for c in raw_data.columns if c not in all_nan_cols + ] + raw_data = raw_data.select(raw_cols_to_keep) + + # 更新数据 + data[split_name]["X"] = X_df + data[split_name]["raw_data"] = raw_data + data[split_name]["feature_cols"] = feature_cols + + # 验证没有NaN了 + X_np = X_df.to_numpy() + nan_count = np.isnan(X_np).sum() + assert nan_count == 0, f"{split_name} 中仍有 {nan_count} 个NaN" + + print(f" 过滤后特征数: {len(feature_cols)}") + print(" [通过] 所有特征列均无NaN") + + return { + "pipeline": pipeline, + "data": data, + "feature_cols": feature_cols, + "engine": engine, + } + + def test_pipeline_data_structure(self, pipeline_data): + """诊断0: DataPipeline 返回的数据结构检查""" + data = pipeline_data["data"] + feature_cols = pipeline_data["feature_cols"] + + print("\n[诊断0] DataPipeline 数据结构检查:") + + # 检查数据结构 + assert "train" in data, "缺少 train 数据" + assert "val" in data, "缺少 val 数据" + assert "test" in data, "缺少 test 数据" + + for split_name in ["train", "val", "test"]: + split_data = data[split_name] + assert "X" in split_data, f"{split_name} 缺少 X" + assert "y" in split_data, f"{split_name} 缺少 y" + assert "raw_data" in split_data, f"{split_name} 缺少 raw_data" + assert "feature_cols" in split_data, f"{split_name} 缺少 feature_cols" + + X = split_data["X"] + y = split_data["y"] + + print(f" {split_name}:") + print(f" X 形状: {X.shape}") + print(f" y 形状: {len(y)}") + print(f" 特征数: {len(split_data['feature_cols'])}") + + # 检查维度一致性 + assert X.shape[0] == len(y), f"{split_name} X 和 y 行数不匹配" + assert X.shape[1] == len(feature_cols), f"{split_name} 特征数不匹配" + + print(" [通过] 数据结构正确") + + def test_label_quality_with_pipeline(self, pipeline_data): + """诊断1: 使用 DataPipeline 后的标签数据质量检查""" + data = pipeline_data["data"] + + # 检查所有数据集的标签质量 + for split_name in ["train", 
"val", "test"]: + y = data[split_name]["y"] + y_np = y.to_numpy() + + print(f"\n[诊断1-{split_name}] 标签数据质量:") + print(f" 总数: {len(y)}") + print(f" NaN数量: {y.null_count()}") + print(f" 均值: {y.mean():.6f}") + print(f" 标准差: {y.std():.6f}") + print(f" 最小值: {y.min():.6f}") + print(f" 最大值: {y.max():.6f}") + + inf_count = np.isinf(y_np).sum() + nan_count = np.isnan(y_np).sum() + print(f" inf数量: {inf_count}") + print(f" nan数量: {nan_count}") + + assert inf_count == 0, f"{split_name} 标签含 inf: {inf_count}" + assert nan_count == 0, f"{split_name} 标签含 nan: {nan_count}" + + def test_processed_data_quality_with_pipeline(self, pipeline_data): + """诊断2: 使用 DataPipeline 处理后的特征数据质量""" + data = pipeline_data["data"] + feature_cols = pipeline_data["feature_cols"] + + for split_name in ["train", "val", "test"]: + X = data[split_name]["X"] + y = data[split_name]["y"] + + X_np = X.to_numpy().astype(np.float32) + y_np = y.to_numpy().astype(np.float32) + + print(f"\n[诊断2-{split_name}] 处理后数据质量:") + print(f" X 形状: {X_np.shape}, dtype: {X_np.dtype}") + print(f" y 形状: {y_np.shape}, dtype: {y_np.dtype}") + print(f" X中NaN: {np.isnan(X_np).sum()}") + print(f" X中Inf: {np.isinf(X_np).sum()}") + print(f" y中NaN: {np.isnan(y_np).sum()}") + print(f" y中Inf: {np.isinf(y_np).sum()}") + + assert np.isnan(X_np).sum() == 0, f"{split_name} X含NaN" + assert np.isnan(y_np).sum() == 0, f"{split_name} y含NaN" + assert np.isinf(X_np).sum() == 0, f"{split_name} X含Inf" + assert np.isinf(y_np).sum() == 0, f"{split_name} y含Inf" + + def test_train_skip_days_functionality(self, engine_and_factor_manager): + """诊断3: train_skip_days 功能验证""" + engine = engine_and_factor_manager["engine"] + factor_manager = engine_and_factor_manager["factor_manager"] + + print("\n[诊断3] train_skip_days 功能验证:") + + # 创建一个跳过50天数据的 pipeline + skip_days = 50 + pipeline_with_skip = DataPipeline( + factor_manager=factor_manager, + processor_configs=[ + (NullFiller, {"strategy": "mean"}), + (Winsorizer, {"lower": 0.01, "upper": 0.99}), + 
(StandardScaler, {}), + ], + filters=[STFilter(data_router=engine.router)], + stock_pool_filter_func=stock_pool_filter, + stock_pool_required_columns=STOCK_FILTER_REQUIRED_COLUMNS, + train_skip_days=skip_days, + ) + + # 准备数据 + data_with_skip = pipeline_with_skip.prepare_data( + engine=engine, + date_range=TEST_DATE_RANGE, + label_name=LABEL_NAME, + verbose=True, + ) + + # 获取原始数据(不跳过天数)用于对比 + pipeline_no_skip = DataPipeline( + factor_manager=factor_manager, + processor_configs=[ + (NullFiller, {"strategy": "mean"}), + (Winsorizer, {"lower": 0.01, "upper": 0.99}), + (StandardScaler, {}), + ], + filters=[STFilter(data_router=engine.router)], + stock_pool_filter_func=stock_pool_filter, + stock_pool_required_columns=STOCK_FILTER_REQUIRED_COLUMNS, + train_skip_days=0, + ) + + data_no_skip = pipeline_no_skip.prepare_data( + engine=engine, + date_range=TEST_DATE_RANGE, + label_name=LABEL_NAME, + verbose=False, + ) + + # 验证训练数据减少了 + train_with_skip = data_with_skip["train"]["raw_data"] + train_no_skip = data_no_skip["train"]["raw_data"] + + print(f"\n 对比结果:") + print(f" 不跳过时训练数据: {len(train_no_skip)} 条") + print(f" 跳过{skip_days}天后: {len(train_with_skip)} 条") + print(f" 减少: {len(train_no_skip) - len(train_with_skip)} 条") + + # 验证验证集和测试集不受影响 + assert len(data_with_skip["val"]["raw_data"]) == len( + data_no_skip["val"]["raw_data"] + ), "val 数据不应该受 train_skip_days 影响" + assert len(data_with_skip["test"]["raw_data"]) == len( + data_no_skip["test"]["raw_data"] + ), "test 数据不应该受 train_skip_days 影响" + + # 验证日期确实被跳过了 + if len(train_no_skip) > 0: + dates_no_skip = sorted(train_no_skip["trade_date"].unique()) + dates_with_skip = sorted(train_with_skip["trade_date"].unique()) + + print(f"\n 日期对比:") + print( + f" 不跳过 - 最早日期: {dates_no_skip[0]}, 共 {len(dates_no_skip)} 个交易日" + ) + print( + f" 跳过 - 最早日期: {dates_with_skip[0]}, 共 {len(dates_with_skip)} 个交易日" + ) + + # 验证跳过的数据确实从更晚的日期开始 + if len(dates_no_skip) > skip_days: + expected_start_date = dates_no_skip[skip_days] + assert 
dates_with_skip[0] == expected_start_date, ( + f"预期从 {expected_start_date} 开始,实际从 {dates_with_skip[0]} 开始" + ) + print(f" [通过] 正确跳过前 {skip_days} 个交易日") + + def test_training_with_pipeline(self, pipeline_data): + """诊断4: 使用 DataPipeline 处理后的数据进行训练测试""" + import torch + from torch.utils.data import DataLoader, TensorDataset + from tabm import TabM + + data = pipeline_data["data"] + feature_cols = pipeline_data["feature_cols"] + + # 获取训练数据 + X_train_df = data["train"]["X"] + y_train_series = data["train"]["y"] + + # 删除标签为NaN的行 + valid_mask = y_train_series.is_not_null() + y_nan_count = y_train_series.null_count() + if y_nan_count > 0: + print(f" 发现 {y_nan_count} 个标签为NaN的行,将被删除") + X_train_df = X_train_df.filter(valid_mask) + y_train_series = y_train_series.filter(valid_mask) + + X_train = X_train_df.to_numpy().astype(np.float32) + y_train = y_train_series.to_numpy().astype(np.float32) + + # 只取前1000条加速测试 + X_train = X_train[:1000] + y_train = y_train[:1000] + + print(f"\n[诊断4] DataPipeline 数据训练测试:") + print(f" X 形状: {X_train.shape}") + print(f" y 形状: {y_train.shape}") + print(f" X中NaN: {np.isnan(X_train).sum()}, Inf: {np.isinf(X_train).sum()}") + print(f" y中NaN: {np.isnan(y_train).sum()}, Inf: {np.isinf(y_train).sum()}") + + # 创建 TabM 模型 + n_features = X_train.shape[1] + model = TabM.make( + n_num_features=n_features, + cat_cardinalities=[], + d_out=1, + n_blocks=2, + d_block=128, + dropout=0.1, + k=4, + ) + + # 训练一个 epoch + dataset = TensorDataset(torch.from_numpy(X_train), torch.from_numpy(y_train)) + loader = DataLoader(dataset, batch_size=256, shuffle=True) + + optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4) + criterion = torch.nn.MSELoss() + + model.train() + losses = [] + + for batch_idx, (bx, by) in enumerate(loader): + optimizer.zero_grad() + outputs = model(bx) # [B, E, 1] + outputs_squeezed = outputs.squeeze(-1) # [B, E] + by_expanded = by.unsqueeze(-1).expand(-1, 4) # [B, E] + loss = criterion(outputs_squeezed, by_expanded) + loss.backward() + 
optimizer.step() + + losses.append(loss.item()) + print(f" 批次 {batch_idx + 1} loss: {loss.item():.6f}") + + assert not torch.isnan(loss), f"批次 {batch_idx + 1} loss 为 nan!" + + if batch_idx >= 2: # 只测试前3个批次 + break + + print(f" [通过] 所有批次 loss 正常,无 NaN") + + +class TestTrainSkipDaysEdgeCases: + """train_skip_days 边界情况测试""" + + @pytest.fixture(scope="class") + def engine_and_factor_manager(self): + """准备 FactorEngine 和 FactorManager""" + engine = FactorEngine() + + factor_manager = FactorManager( + selected_factors=SELECTED_FACTORS, + factor_definitions=FACTOR_DEFINITIONS, + label_factor=LABEL_FACTOR, + excluded_factors=EXCLUDED_FACTORS, + ) + + return { + "engine": engine, + "factor_manager": factor_manager, + } + + def test_skip_more_than_available_days(self, engine_and_factor_manager): + """测试跳过天数超过可用天数的情况""" + engine = engine_and_factor_manager["engine"] + factor_manager = engine_and_factor_manager["factor_manager"] + + print("\n[边界测试] 跳过天数超过可用天数:") + + # 使用一个很大的跳过天数 + pipeline = DataPipeline( + factor_manager=factor_manager, + processor_configs=[(NullFiller, {"strategy": "mean"})], + filters=[STFilter(data_router=engine.router)], + stock_pool_filter_func=stock_pool_filter, + stock_pool_required_columns=STOCK_FILTER_REQUIRED_COLUMNS, + train_skip_days=1000, # 超过测试期间的交易日数 + ) + + # 准备数据(应该能正常运行并发出警告) + data = pipeline.prepare_data( + engine=engine, + date_range=TEST_DATE_RANGE, + label_name=LABEL_NAME, + verbose=True, + ) + + # 即使跳过天数很多,也应该有数据返回 + # 如果交易日数少于跳过天数,应该保留所有数据(只发出警告) + train_data = data["train"]["raw_data"] + print(f" 训练数据量: {len(train_data)} 条") + print(f" [通过] 程序未崩溃") + + def test_skip_zero_days(self, engine_and_factor_manager): + """测试跳过0天(即不跳过)""" + engine = engine_and_factor_manager["engine"] + factor_manager = engine_and_factor_manager["factor_manager"] + + print("\n[边界测试] 跳过0天:") + + pipeline = DataPipeline( + factor_manager=factor_manager, + processor_configs=[(NullFiller, {"strategy": "mean"})], + filters=[STFilter(data_router=engine.router)], + 
stock_pool_filter_func=stock_pool_filter, + stock_pool_required_columns=STOCK_FILTER_REQUIRED_COLUMNS, + train_skip_days=0, + ) + + data = pipeline.prepare_data( + engine=engine, + date_range=TEST_DATE_RANGE, + label_name=LABEL_NAME, + verbose=False, + ) + + train_data = data["train"]["raw_data"] + print(f" 训练数据量: {len(train_data)} 条") + print(f" [通过] skip=0 时数据正常") + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "-s"])