diff --git a/src/experiment/regression.py b/src/experiment/regression.py index e2e7c50..f004ff0 100644 --- a/src/experiment/regression.py +++ b/src/experiment/regression.py @@ -19,7 +19,7 @@ from src.training import ( StockFilterConfig, StockPoolManager, Trainer, - Winsorizer, + Winsorizer, NullFiller, ) from src.training.config import TrainingConfig @@ -224,6 +224,7 @@ def train_regression_model(): # 6. 创建数据处理器(从 PROCESSOR_CONFIGS 解析) processors = [ + NullFiller(strategy="mean"), Winsorizer(**PROCESSOR_CONFIGS[0]["params"]), # type: ignore[arg-type] StandardScaler(exclude_cols=["ts_code", "trade_date", target_col]), # type: ignore[call-arg] ] diff --git a/src/training/__init__.py b/src/training/__init__.py index e885e2c..87eb845 100644 --- a/src/training/__init__.py +++ b/src/training/__init__.py @@ -26,6 +26,7 @@ from src.training.components.selectors import ( # 数据处理器 from src.training.components.processors import ( CrossSectionalStandardScaler, + NullFiller, StandardScaler, Winsorizer, ) @@ -57,6 +58,7 @@ __all__ = [ "StockFilterConfig", "MarketCapSelectorConfig", # 数据处理器 + "NullFiller", "StandardScaler", "CrossSectionalStandardScaler", "Winsorizer", diff --git a/src/training/components/processors/__init__.py b/src/training/components/processors/__init__.py index c790cf6..637dd16 100644 --- a/src/training/components/processors/__init__.py +++ b/src/training/components/processors/__init__.py @@ -5,11 +5,13 @@ from src.training.components.processors.transforms import ( CrossSectionalStandardScaler, + NullFiller, StandardScaler, Winsorizer, ) __all__ = [ + "NullFiller", "StandardScaler", "CrossSectionalStandardScaler", "Winsorizer", diff --git a/src/training/components/processors/transforms.py b/src/training/components/processors/transforms.py index 363e66b..22a53dd 100644 --- a/src/training/components/processors/transforms.py +++ b/src/training/components/processors/transforms.py @@ -1,9 +1,9 @@ """数据处理器实现 -包含标准化、缩尾等数据处理器。 +包含标准化、缩尾、缺失值填充等数据处理器。 """ -from typing 
@register_processor("null_filler")
class NullFiller(BaseProcessor):
    """Fill null values in the numeric columns of a polars DataFrame.

    Supported strategies:
        - "zero":   replace nulls with 0
        - "mean":   replace nulls with the column mean
        - "median": replace nulls with the column median
        - "value":  replace nulls with a caller-supplied constant

    For "mean"/"median" the statistic is either computed per date
    (``by_date=True``, cross-sectional, nothing is learned in ``fit``) or
    learned once from the training data in ``fit`` (``by_date=False``).

    Only nulls are filled (``fill_null``); non-numeric columns, the columns
    listed in ``exclude_cols`` and the date column are passed through
    unchanged.

    Attributes:
        strategy: One of "zero", "mean", "median", "value".
        fill_value: Constant used when ``strategy == "value"``.
        by_date: Whether mean/median are computed independently per date.
        date_col: Name of the date column used for per-date grouping.
        exclude_cols: Column names that are never filled.
        stats_: Per-column statistics learned by ``fit`` (global mode only).
    """

    name = "null_filler"

    def __init__(
        self,
        strategy: Literal["zero", "mean", "median", "value"] = "zero",
        fill_value: Optional[float] = None,
        by_date: bool = True,
        date_col: str = "trade_date",
        exclude_cols: Optional[List[str]] = None,
    ):
        """Validate and store the fill configuration.

        Args:
            strategy: Fill strategy, defaults to "zero".
            fill_value: Constant to fill with; required when
                ``strategy == "value"``, ignored otherwise.
            by_date: If True (the default), mean/median are computed per
                date at transform time; if False, they are learned globally
                in ``fit``. Has no effect for "zero"/"value".
            date_col: Name of the date column, defaults to "trade_date".
            exclude_cols: Columns that must not be filled. Defaults to
                ``["ts_code", "trade_date"]`` when omitted or empty.

        Raises:
            ValueError: If ``strategy`` is not a supported value, or if
                ``strategy == "value"`` and ``fill_value`` is None.
        """
        if strategy not in ("zero", "mean", "median", "value"):
            raise ValueError(
                f"无效的填充策略: {strategy},必须是 'zero', 'mean', 'median', 'value' 之一"
            )

        if strategy == "value" and fill_value is None:
            raise ValueError("当 strategy='value' 时,必须提供 fill_value")

        self.strategy = strategy
        self.fill_value = fill_value
        self.by_date = by_date
        self.date_col = date_col
        # Copy so later mutation of the caller's list cannot change this
        # processor's behaviour.
        self.exclude_cols = (
            list(exclude_cols) if exclude_cols else ["ts_code", "trade_date"]
        )
        self.stats_: dict = {}

    def _numeric_cols(self, X: pl.DataFrame) -> List[str]:
        """Return the numeric columns of ``X`` that are eligible for filling."""
        # The date column is the grouping key and must never be imputed, even
        # when a custom ``exclude_cols`` forgets to list it.
        skipped = set(self.exclude_cols)
        skipped.add(self.date_col)
        return [
            c for c in X.columns if c not in skipped and X[c].dtype.is_numeric()
        ]

    def fit(self, X: pl.DataFrame) -> "NullFiller":
        """Learn global fill statistics from the training data.

        Statistics are only learned when ``by_date`` is False and the
        strategy is "mean" or "median"; in every other configuration this is
        a no-op apart from clearing previously learned statistics.

        Args:
            X: Training data.

        Returns:
            self
        """
        # Start from a clean slate so a refit never keeps statistics of
        # columns that belonged to an earlier training frame.
        self.stats_ = {}

        if self.by_date or self.strategy not in ("mean", "median"):
            return self

        for col in self._numeric_cols(X):
            series = X[col]
            stat = series.mean() if self.strategy == "mean" else series.median()
            # An all-null (or empty) column has no statistic; storing None
            # would make ``fill_null(None)`` raise at transform time.
            if stat is not None:
                self.stats_[col] = stat

        return self

    def transform(self, X: pl.DataFrame) -> pl.DataFrame:
        """Return ``X`` with nulls filled according to the configured strategy.

        Args:
            X: Data to transform.

        Returns:
            A DataFrame with the same columns, in the same order, as ``X``.

        Raises:
            ValueError: If ``self.strategy`` was changed to an unsupported
                value after construction.
        """
        if self.strategy == "zero":
            return self._fill_with_zero(X)
        if self.strategy == "value":
            return self._fill_with_value(X)
        if self.strategy in ("mean", "median"):
            if self.by_date:
                return self._fill_by_date(X)
            return self._fill_global(X)
        # Unreachable unless the attribute was mutated after __init__.
        raise ValueError(f"未知的填充策略: {self.strategy}")

    def _fill_with_zero(self, X: pl.DataFrame) -> pl.DataFrame:
        """Fill nulls in the eligible numeric columns with 0."""
        # with_columns replaces same-named columns in place, so the column
        # order of X is preserved without rebuilding every column.
        return X.with_columns(
            [pl.col(c).fill_null(0) for c in self._numeric_cols(X)]
        )

    def _fill_with_value(self, X: pl.DataFrame) -> pl.DataFrame:
        """Fill nulls in the eligible numeric columns with ``fill_value``."""
        return X.with_columns(
            [pl.col(c).fill_null(self.fill_value) for c in self._numeric_cols(X)]
        )

    def _fill_global(self, X: pl.DataFrame) -> pl.DataFrame:
        """Fill nulls with the statistics learned in ``fit``.

        Columns without a learned statistic are left untouched; if ``fit``
        was never called, ``stats_`` is empty and ``X`` is returned unchanged.
        """
        present = set(X.columns)
        return X.with_columns(
            [
                pl.col(c).fill_null(stat)
                for c, stat in self.stats_.items()
                if c in present and stat is not None
            ]
        )

    def _fill_by_date(self, X: pl.DataFrame) -> pl.DataFrame:
        """Fill nulls with the per-date mean or median of each column.

        The statistic is evaluated as a window expression over ``date_col``
        directly inside ``fill_null``, so no temporary helper columns are
        added and no existing column can be overwritten by name. A date whose
        values are all null for a column has a null statistic, so those nulls
        remain.
        """
        exprs = []
        for c in self._numeric_cols(X):
            column = pl.col(c)
            stat = column.mean() if self.strategy == "mean" else column.median()
            exprs.append(column.fill_null(stat.over(self.date_col)))
        return X.with_columns(exprs)