feat(training): 添加缺失值填充处理器 NullFiller
新增 NullFiller 处理器,支持 zero/mean/median/value 填充策略,支持全局统计量或按日期截面填充。在回归训练流程中添加 NullFiller。
This commit is contained in:
@@ -19,7 +19,7 @@ from src.training import (
|
||||
StockFilterConfig,
|
||||
StockPoolManager,
|
||||
Trainer,
|
||||
Winsorizer,
|
||||
Winsorizer, NullFiller,
|
||||
)
|
||||
from src.training.config import TrainingConfig
|
||||
|
||||
@@ -224,6 +224,7 @@ def train_regression_model():
|
||||
|
||||
# 6. 创建数据处理器(从 PROCESSOR_CONFIGS 解析)
|
||||
processors = [
|
||||
NullFiller(strategy="mean"),
|
||||
Winsorizer(**PROCESSOR_CONFIGS[0]["params"]), # type: ignore[arg-type]
|
||||
StandardScaler(exclude_cols=["ts_code", "trade_date", target_col]), # type: ignore[call-arg]
|
||||
]
|
||||
|
||||
@@ -26,6 +26,7 @@ from src.training.components.selectors import (
|
||||
# 数据处理器
|
||||
from src.training.components.processors import (
|
||||
CrossSectionalStandardScaler,
|
||||
NullFiller,
|
||||
StandardScaler,
|
||||
Winsorizer,
|
||||
)
|
||||
@@ -57,6 +58,7 @@ __all__ = [
|
||||
"StockFilterConfig",
|
||||
"MarketCapSelectorConfig",
|
||||
# 数据处理器
|
||||
"NullFiller",
|
||||
"StandardScaler",
|
||||
"CrossSectionalStandardScaler",
|
||||
"Winsorizer",
|
||||
|
||||
@@ -5,11 +5,13 @@
|
||||
|
||||
from src.training.components.processors.transforms import (
|
||||
CrossSectionalStandardScaler,
|
||||
NullFiller,
|
||||
StandardScaler,
|
||||
Winsorizer,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"NullFiller",
|
||||
"StandardScaler",
|
||||
"CrossSectionalStandardScaler",
|
||||
"Winsorizer",
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
"""数据处理器实现
|
||||
|
||||
包含标准化、缩尾等数据处理器。
|
||||
包含标准化、缩尾、缺失值填充等数据处理器。
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
from typing import List, Literal, Optional, Union
|
||||
|
||||
import polars as pl
|
||||
|
||||
@@ -11,6 +11,204 @@ from src.training.components.base import BaseProcessor
|
||||
from src.training.registry import register_processor
|
||||
|
||||
|
||||
@register_processor("null_filler")
class NullFiller(BaseProcessor):
    """Processor that fills null values in numeric columns.

    Supported strategies:
        - "zero": fill with 0
        - "mean": fill with the column mean (global or per-date cross-section)
        - "median": fill with the column median (global or per-date
          cross-section)
        - "value": fill with a caller-supplied constant

    Only columns whose dtype is numeric and that are not listed in
    ``exclude_cols`` are touched; every other column passes through unchanged
    and the column order of the input frame is preserved.

    Attributes:
        strategy: Fill strategy, one of "zero", "mean", "median", "value".
        fill_value: Constant used when ``strategy="value"``.
        by_date: Whether mean/median are computed independently per value of
            ``date_col`` (ignored by the "zero" and "value" strategies).
        date_col: Name of the date column used for per-date statistics.
        exclude_cols: Column names that are never filled.
        stats_: Per-column statistics learned by ``fit`` (global mode only).
    """

    name = "null_filler"

    def __init__(
        self,
        strategy: Literal["zero", "mean", "median", "value"] = "zero",
        fill_value: Optional[float] = None,
        by_date: bool = True,
        date_col: str = "trade_date",
        exclude_cols: Optional[List[str]] = None,
    ):
        """Initialise the null filler.

        Args:
            strategy: Fill strategy, defaults to "zero".
                - "zero": fill with 0
                - "mean": fill with the mean
                - "median": fill with the median
                - "value": fill with ``fill_value``
            fill_value: Constant used when ``strategy="value"``; defaults to
                None.
            by_date: Compute mean/median separately for each date when True
                (the default); use statistics learned in ``fit`` when False.
            date_col: Name of the date column, defaults to "trade_date".
            exclude_cols: Columns that must not be filled; a None or empty
                value falls back to ["ts_code", "trade_date"].

        Raises:
            ValueError: If ``strategy`` is not recognised, or if
                ``strategy="value"`` is used without a ``fill_value``.
        """
        if strategy not in ("zero", "mean", "median", "value"):
            raise ValueError(
                f"无效的填充策略: {strategy},必须是 'zero', 'mean', 'median', 'value' 之一"
            )

        if strategy == "value" and fill_value is None:
            raise ValueError("当 strategy='value' 时,必须提供 fill_value")

        self.strategy = strategy
        self.fill_value = fill_value
        self.by_date = by_date
        self.date_col = date_col
        # Copy so later mutation of the caller's list cannot change which
        # columns this processor skips.
        self.exclude_cols = list(exclude_cols or ("ts_code", "trade_date"))
        self.stats_: dict = {}

    def fit(self, X: pl.DataFrame) -> "NullFiller":
        """Learn per-column fill statistics (global mode only).

        When ``by_date`` is False and the strategy is "mean" or "median", the
        statistic of every eligible numeric column is stored in ``stats_``.
        In every other configuration nothing needs to be learned.

        Args:
            X: Training data.

        Returns:
            self
        """
        # Start from a clean slate so a re-fit never keeps statistics of
        # columns that belonged to a previous training frame.
        self.stats_ = {}

        if self.by_date or self.strategy not in ("mean", "median"):
            return self

        for col in self._numeric_cols(X):
            series = X[col]
            stat = series.mean() if self.strategy == "mean" else series.median()
            # An all-null column has no statistic; storing None would make
            # fill_null() fail at transform time, so the column is left out.
            if stat is not None:
                self.stats_[col] = stat

        return self

    def transform(self, X: pl.DataFrame) -> pl.DataFrame:
        """Fill null values according to the configured strategy.

        Args:
            X: Data to transform.

        Returns:
            A frame with the same columns, in the same order, where nulls in
            the eligible numeric columns have been filled.

        Raises:
            ValueError: If ``strategy`` holds an unknown value (only possible
                when the attribute was modified after construction).
        """
        if self.strategy == "zero":
            return self._fill_with_zero(X)
        if self.strategy == "value":
            return self._fill_with_value(X)
        if self.strategy in ("mean", "median"):
            if self.by_date:
                return self._fill_by_date(X)
            return self._fill_global(X)

        # Unreachable through __init__, which validates the strategy.
        raise ValueError(f"未知的填充策略: {self.strategy}")

    def _numeric_cols(self, X: pl.DataFrame) -> List[str]:
        """Return the numeric columns of ``X`` that are not excluded."""
        excluded = set(self.exclude_cols)
        return [
            col
            for col, dtype in zip(X.columns, X.dtypes)
            if col not in excluded and dtype.is_numeric()
        ]

    def _fill_constant(self, X: pl.DataFrame, value) -> pl.DataFrame:
        """Fill nulls of every eligible numeric column with ``value``."""
        # with_columns() replaces same-named columns in place, so the column
        # order of X is preserved without re-selecting untouched columns.
        return X.with_columns(
            [pl.col(col).fill_null(value) for col in self._numeric_cols(X)]
        )

    def _fill_with_zero(self, X: pl.DataFrame) -> pl.DataFrame:
        """Fill nulls with 0."""
        return self._fill_constant(X, 0)

    def _fill_with_value(self, X: pl.DataFrame) -> pl.DataFrame:
        """Fill nulls with the configured ``fill_value``."""
        return self._fill_constant(X, self.fill_value)

    def _fill_global(self, X: pl.DataFrame) -> pl.DataFrame:
        """Fill nulls with the statistics learned by ``fit``.

        Columns without a learned statistic are returned unchanged, so calling
        this before ``fit`` leaves the frame as it is.
        """
        present = set(X.columns)
        return X.with_columns(
            [
                pl.col(col).fill_null(stat)
                for col, stat in self.stats_.items()
                # Guard against None in case stats_ was populated externally.
                if col in present and stat is not None
            ]
        )

    def _fill_by_date(self, X: pl.DataFrame) -> pl.DataFrame:
        """Fill nulls with the mean/median of the same ``date_col`` group."""
        use_mean = self.strategy == "mean"

        def _date_stat(col: str) -> pl.Expr:
            expr = pl.col(col)
            stat = expr.mean() if use_mean else expr.median()
            return stat.over(self.date_col)

        # The window expression is passed straight to fill_null(), so no
        # temporary "<col>_stat" column is created that could overwrite an
        # existing column with that name.
        return X.with_columns(
            [
                pl.col(col).fill_null(_date_stat(col))
                for col in self._numeric_cols(X)
            ]
        )
|
||||
|
||||
|
||||
@register_processor("standard_scaler")
|
||||
class StandardScaler(BaseProcessor):
|
||||
"""标准化处理器(全局标准化)
|
||||
|
||||
Reference in New Issue
Block a user