feat(training): 添加缺失值填充处理器 NullFiller

新增 NullFiller 处理器,支持 zero/mean/median/value 填充策略,
支持全局统计量或按日期截面填充。在回归训练流程中添加 NullFiller。
This commit is contained in:
2026-03-05 21:57:34 +08:00
parent aefe6d06cf
commit 7b935b0fa3
4 changed files with 206 additions and 3 deletions

View File

@@ -19,7 +19,7 @@ from src.training import (
StockFilterConfig,
StockPoolManager,
Trainer,
Winsorizer,
Winsorizer, NullFiller,
)
from src.training.config import TrainingConfig
@@ -224,6 +224,7 @@ def train_regression_model():
# 6. 创建数据处理器(从 PROCESSOR_CONFIGS 解析)
processors = [
NullFiller(strategy="mean"),
Winsorizer(**PROCESSOR_CONFIGS[0]["params"]), # type: ignore[arg-type]
StandardScaler(exclude_cols=["ts_code", "trade_date", target_col]), # type: ignore[call-arg]
]

View File

@@ -26,6 +26,7 @@ from src.training.components.selectors import (
# 数据处理器
from src.training.components.processors import (
CrossSectionalStandardScaler,
NullFiller,
StandardScaler,
Winsorizer,
)
@@ -57,6 +58,7 @@ __all__ = [
"StockFilterConfig",
"MarketCapSelectorConfig",
# 数据处理器
"NullFiller",
"StandardScaler",
"CrossSectionalStandardScaler",
"Winsorizer",

View File

@@ -5,11 +5,13 @@
from src.training.components.processors.transforms import (
CrossSectionalStandardScaler,
NullFiller,
StandardScaler,
Winsorizer,
)
__all__ = [
"NullFiller",
"StandardScaler",
"CrossSectionalStandardScaler",
"Winsorizer",

View File

@@ -1,9 +1,9 @@
"""数据处理器实现
包含标准化、缩尾等数据处理器。
包含标准化、缩尾、缺失值填充等数据处理器。
"""
from typing import List, Optional
from typing import List, Literal, Optional, Union
import polars as pl
@@ -11,6 +11,204 @@ from src.training.components.base import BaseProcessor
from src.training.registry import register_processor
@register_processor("null_filler")
class NullFiller(BaseProcessor):
    """Processor that fills null (missing) values in numeric columns.

    Only columns with a numeric dtype that are not listed in
    ``exclude_cols`` are filled; every other column is passed through
    unchanged and the column order of the input is preserved.

    Filling is done with polars ``fill_null``, so only nulls are replaced.
    NOTE(review): floating-point NaN is not null in polars and is therefore
    left as-is; convert NaN to null upstream if it should be imputed too.

    Fill strategies:
        - "zero": fill with 0
        - "mean": fill with the column mean (global or per-date cross-section)
        - "median": fill with the column median (global or per-date cross-section)
        - "value": fill with ``fill_value``

    Attributes:
        strategy: One of "zero", "mean", "median", "value".
        fill_value: Constant used when ``strategy="value"``.
        by_date: Whether mean/median are computed independently per date
            (only relevant for the "mean" and "median" strategies).
        date_col: Name of the date column used for per-date grouping.
        exclude_cols: Column names that are never filled.
        stats_: Per-column statistics learned by ``fit`` (global mode only).
    """

    name = "null_filler"

    _STRATEGIES = ("zero", "mean", "median", "value")

    def __init__(
        self,
        strategy: Literal["zero", "mean", "median", "value"] = "zero",
        fill_value: Optional[float] = None,
        by_date: bool = True,
        date_col: str = "trade_date",
        exclude_cols: Optional[List[str]] = None,
    ):
        """Initialise the null filler.

        Args:
            strategy: Fill strategy, default "zero".
                - "zero": fill with 0
                - "mean": fill with the mean
                - "median": fill with the median
                - "value": fill with ``fill_value``
            fill_value: Fill value for ``strategy="value"``; default None.
            by_date: If True (the default), mean/median are computed per
                date cross-section at transform time. If False, global
                statistics learned in ``fit`` are used instead.
            date_col: Name of the date column, default "trade_date".
            exclude_cols: Columns that must not be filled. None or an empty
                list falls back to ``["ts_code", "trade_date"]``.

        Raises:
            ValueError: If ``strategy`` is not a supported value, or if
                ``strategy="value"`` is used without ``fill_value``.
        """
        if strategy not in self._STRATEGIES:
            raise ValueError(
                f"无效的填充策略: {strategy},必须是 'zero', 'mean', 'median', 'value' 之一"
            )

        if strategy == "value" and fill_value is None:
            raise ValueError("当 strategy='value' 时,必须提供 fill_value")

        self.strategy = strategy
        self.fill_value = fill_value
        self.by_date = by_date
        self.date_col = date_col
        self.exclude_cols = exclude_cols or ["ts_code", "trade_date"]
        self.stats_: dict = {}

    def _numeric_cols(self, X: pl.DataFrame) -> List[str]:
        """Return the numeric, non-excluded column names of ``X`` in order."""
        return [
            c
            for c in X.columns
            if c not in self.exclude_cols and X[c].dtype.is_numeric()
        ]

    def _stat_expr(self, col: str) -> pl.Expr:
        """Build the mean/median aggregation expression for ``col``."""
        column = pl.col(col)
        return column.mean() if self.strategy == "mean" else column.median()

    def _fill_constant(self, X: pl.DataFrame, value) -> pl.DataFrame:
        """Fill nulls in every eligible numeric column with ``value``."""
        return X.with_columns(
            [pl.col(c).fill_null(value) for c in self._numeric_cols(X)]
        )

    def fit(self, X: pl.DataFrame) -> "NullFiller":
        """Learn global fill statistics (global mean/median mode only).

        In per-date mode (``by_date=True``) and for the "zero"/"value"
        strategies nothing needs to be learned. Previously learned
        statistics are always discarded so that a refit never leaves stale
        entries behind.

        Columns that are entirely null in ``X`` yield no statistic and are
        therefore left unfilled by ``transform``.

        Args:
            X: Training data.

        Returns:
            self
        """
        self.stats_ = {}
        if self.by_date or self.strategy not in ("mean", "median"):
            return self

        for col in self._numeric_cols(X):
            series = X[col]
            stat = series.mean() if self.strategy == "mean" else series.median()
            # An all-null column has no mean/median; storing None would make
            # fill_null() fail later, so such columns are simply skipped.
            if stat is not None:
                self.stats_[col] = stat
        return self

    def transform(self, X: pl.DataFrame) -> pl.DataFrame:
        """Fill null values according to the configured strategy.

        Args:
            X: Data to transform.

        Returns:
            A DataFrame with nulls filled in the eligible columns.

        Raises:
            ValueError: If ``self.strategy`` holds an unsupported value.
        """
        if self.strategy == "zero":
            return self._fill_with_zero(X)
        if self.strategy == "value":
            return self._fill_with_value(X)
        if self.strategy in ("mean", "median"):
            if self.by_date:
                return self._fill_by_date(X)
            return self._fill_global(X)
        # Only reachable if ``strategy`` was mutated after construction,
        # because __init__ already validates it.
        raise ValueError(f"未知的填充策略: {self.strategy}")

    def _fill_with_zero(self, X: pl.DataFrame) -> pl.DataFrame:
        """Fill nulls in eligible numeric columns with 0."""
        return self._fill_constant(X, 0)

    def _fill_with_value(self, X: pl.DataFrame) -> pl.DataFrame:
        """Fill nulls in eligible numeric columns with ``self.fill_value``."""
        return self._fill_constant(X, self.fill_value)

    def _fill_global(self, X: pl.DataFrame) -> pl.DataFrame:
        """Fill nulls with the global statistics learned by ``fit``.

        Only columns that are present in ``X`` and have a learned, non-None
        statistic are filled; all other columns pass through unchanged.
        """
        present = set(X.columns)
        fill_exprs = [
            pl.col(col).fill_null(stat)
            for col, stat in self.stats_.items()
            if col in present and stat is not None
        ]
        return X.with_columns(fill_exprs)

    def _fill_by_date(self, X: pl.DataFrame) -> pl.DataFrame:
        """Fill nulls with the per-date cross-sectional mean/median.

        The statistic is computed as a window expression over
        ``self.date_col`` and used directly as the fill value, so no
        temporary helper columns are created and existing columns can never
        be overwritten by a name collision.
        """
        fill_exprs = [
            pl.col(col).fill_null(self._stat_expr(col).over(self.date_col))
            for col in self._numeric_cols(X)
        ]
        return X.with_columns(fill_exprs)
@register_processor("standard_scaler")
class StandardScaler(BaseProcessor):
"""标准化处理器(全局标准化)