feat(training): 添加缺失值填充处理器 NullFiller
新增 NullFiller 处理器,支持 zero/mean/median/value 填充策略, 支持全局统计量或按日期截面填充。在回归训练流程中添加 NullFiller。
This commit is contained in:
@@ -19,7 +19,7 @@ from src.training import (
|
|||||||
StockFilterConfig,
|
StockFilterConfig,
|
||||||
StockPoolManager,
|
StockPoolManager,
|
||||||
Trainer,
|
Trainer,
|
||||||
Winsorizer,
|
Winsorizer, NullFiller,
|
||||||
)
|
)
|
||||||
from src.training.config import TrainingConfig
|
from src.training.config import TrainingConfig
|
||||||
|
|
||||||
@@ -224,6 +224,7 @@ def train_regression_model():
|
|||||||
|
|
||||||
# 6. 创建数据处理器(从 PROCESSOR_CONFIGS 解析)
|
# 6. 创建数据处理器(从 PROCESSOR_CONFIGS 解析)
|
||||||
processors = [
|
processors = [
|
||||||
|
NullFiller(strategy="mean"),
|
||||||
Winsorizer(**PROCESSOR_CONFIGS[0]["params"]), # type: ignore[arg-type]
|
Winsorizer(**PROCESSOR_CONFIGS[0]["params"]), # type: ignore[arg-type]
|
||||||
StandardScaler(exclude_cols=["ts_code", "trade_date", target_col]), # type: ignore[call-arg]
|
StandardScaler(exclude_cols=["ts_code", "trade_date", target_col]), # type: ignore[call-arg]
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ from src.training.components.selectors import (
|
|||||||
# 数据处理器
|
# 数据处理器
|
||||||
from src.training.components.processors import (
|
from src.training.components.processors import (
|
||||||
CrossSectionalStandardScaler,
|
CrossSectionalStandardScaler,
|
||||||
|
NullFiller,
|
||||||
StandardScaler,
|
StandardScaler,
|
||||||
Winsorizer,
|
Winsorizer,
|
||||||
)
|
)
|
||||||
@@ -57,6 +58,7 @@ __all__ = [
|
|||||||
"StockFilterConfig",
|
"StockFilterConfig",
|
||||||
"MarketCapSelectorConfig",
|
"MarketCapSelectorConfig",
|
||||||
# 数据处理器
|
# 数据处理器
|
||||||
|
"NullFiller",
|
||||||
"StandardScaler",
|
"StandardScaler",
|
||||||
"CrossSectionalStandardScaler",
|
"CrossSectionalStandardScaler",
|
||||||
"Winsorizer",
|
"Winsorizer",
|
||||||
|
|||||||
@@ -5,11 +5,13 @@
|
|||||||
|
|
||||||
from src.training.components.processors.transforms import (
|
from src.training.components.processors.transforms import (
|
||||||
CrossSectionalStandardScaler,
|
CrossSectionalStandardScaler,
|
||||||
|
NullFiller,
|
||||||
StandardScaler,
|
StandardScaler,
|
||||||
Winsorizer,
|
Winsorizer,
|
||||||
)
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
"NullFiller",
|
||||||
"StandardScaler",
|
"StandardScaler",
|
||||||
"CrossSectionalStandardScaler",
|
"CrossSectionalStandardScaler",
|
||||||
"Winsorizer",
|
"Winsorizer",
|
||||||
|
|||||||
@@ -1,9 +1,9 @@
|
|||||||
"""数据处理器实现
|
"""数据处理器实现
|
||||||
|
|
||||||
包含标准化、缩尾等数据处理器。
|
包含标准化、缩尾、缺失值填充等数据处理器。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import List, Optional
|
from typing import List, Literal, Optional, Union
|
||||||
|
|
||||||
import polars as pl
|
import polars as pl
|
||||||
|
|
||||||
@@ -11,6 +11,204 @@ from src.training.components.base import BaseProcessor
|
|||||||
from src.training.registry import register_processor
|
from src.training.registry import register_processor
|
||||||
|
|
||||||
|
|
||||||
|
@register_processor("null_filler")
|
||||||
|
class NullFiller(BaseProcessor):
|
||||||
|
"""缺失值填充处理器
|
||||||
|
|
||||||
|
支持多种填充策略:固定值、0、均值、中值。
|
||||||
|
可以全局填充或使用当天截面统计量填充。
|
||||||
|
|
||||||
|
填充策略:
|
||||||
|
- "zero": 填充0
|
||||||
|
- "mean": 填充均值(全局或当天截面)
|
||||||
|
- "median": 填充中值(全局或当天截面)
|
||||||
|
- "value": 填充指定数值
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
strategy: 填充策略,可选 "zero", "mean", "median", "value"
|
||||||
|
fill_value: 当 strategy="value" 时使用的填充值
|
||||||
|
by_date: 是否按日期独立计算统计量(仅对 mean/median 有效)
|
||||||
|
date_col: 日期列名
|
||||||
|
exclude_cols: 不参与填充的列名列表
|
||||||
|
stats_: 存储学习到的统计量(全局模式)
|
||||||
|
"""
|
||||||
|
|
||||||
|
name = "null_filler"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
strategy: Literal["zero", "mean", "median", "value"] = "zero",
|
||||||
|
fill_value: Optional[float] = None,
|
||||||
|
by_date: bool = True,
|
||||||
|
date_col: str = "trade_date",
|
||||||
|
exclude_cols: Optional[List[str]] = None,
|
||||||
|
):
|
||||||
|
"""初始化缺失值填充处理器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
strategy: 填充策略,默认 "zero"
|
||||||
|
- "zero": 填充0
|
||||||
|
- "mean": 填充均值
|
||||||
|
- "median": 填充中值
|
||||||
|
- "value": 填充指定数值(需配合 fill_value)
|
||||||
|
fill_value: 当 strategy="value" 时的填充值,默认为 None
|
||||||
|
by_date: 是否每天独立计算统计量,默认 False(全局统计量)
|
||||||
|
date_col: 日期列名,默认 "trade_date"
|
||||||
|
exclude_cols: 不参与填充的列名列表,默认为 ["ts_code", "trade_date"]
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: 策略无效或 fill_value 未提供时
|
||||||
|
"""
|
||||||
|
if strategy not in ("zero", "mean", "median", "value"):
|
||||||
|
raise ValueError(
|
||||||
|
f"无效的填充策略: {strategy},必须是 'zero', 'mean', 'median', 'value' 之一"
|
||||||
|
)
|
||||||
|
|
||||||
|
if strategy == "value" and fill_value is None:
|
||||||
|
raise ValueError("当 strategy='value' 时,必须提供 fill_value")
|
||||||
|
|
||||||
|
self.strategy = strategy
|
||||||
|
self.fill_value = fill_value
|
||||||
|
self.by_date = by_date
|
||||||
|
self.date_col = date_col
|
||||||
|
self.exclude_cols = exclude_cols or ["ts_code", "trade_date"]
|
||||||
|
self.stats_: dict = {}
|
||||||
|
|
||||||
|
def fit(self, X: pl.DataFrame) -> "NullFiller":
|
||||||
|
"""学习统计量(仅在全局模式下)
|
||||||
|
|
||||||
|
在全局模式下,计算每列的均值或中值作为填充值。
|
||||||
|
在截面模式下(by_date=True),不需要 fit,每天独立计算。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
X: 训练数据
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
self
|
||||||
|
"""
|
||||||
|
if not self.by_date and self.strategy in ("mean", "median"):
|
||||||
|
numeric_cols = [
|
||||||
|
c
|
||||||
|
for c in X.columns
|
||||||
|
if c not in self.exclude_cols and X[c].dtype.is_numeric()
|
||||||
|
]
|
||||||
|
|
||||||
|
for col in numeric_cols:
|
||||||
|
if self.strategy == "mean":
|
||||||
|
self.stats_[col] = X[col].mean()
|
||||||
|
else: # median
|
||||||
|
self.stats_[col] = X[col].median()
|
||||||
|
|
||||||
|
return self
|
||||||
|
|
||||||
|
def transform(self, X: pl.DataFrame) -> pl.DataFrame:
|
||||||
|
"""填充缺失值
|
||||||
|
|
||||||
|
Args:
|
||||||
|
X: 待转换数据
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
填充后的数据
|
||||||
|
"""
|
||||||
|
if self.strategy == "zero":
|
||||||
|
return self._fill_with_zero(X)
|
||||||
|
elif self.strategy == "value":
|
||||||
|
return self._fill_with_value(X)
|
||||||
|
elif self.strategy in ("mean", "median"):
|
||||||
|
if self.by_date:
|
||||||
|
return self._fill_by_date(X)
|
||||||
|
else:
|
||||||
|
return self._fill_global(X)
|
||||||
|
else:
|
||||||
|
# 不应该到达这里,因为 __init__ 已经验证
|
||||||
|
raise ValueError(f"未知的填充策略: {self.strategy}")
|
||||||
|
|
||||||
|
def _fill_with_zero(self, X: pl.DataFrame) -> pl.DataFrame:
|
||||||
|
"""使用0填充缺失值"""
|
||||||
|
numeric_cols = [
|
||||||
|
c
|
||||||
|
for c in X.columns
|
||||||
|
if c not in self.exclude_cols and X[c].dtype.is_numeric()
|
||||||
|
]
|
||||||
|
|
||||||
|
expressions = []
|
||||||
|
for col in X.columns:
|
||||||
|
if col in numeric_cols:
|
||||||
|
expr = pl.col(col).fill_null(0).alias(col)
|
||||||
|
expressions.append(expr)
|
||||||
|
else:
|
||||||
|
expressions.append(pl.col(col))
|
||||||
|
|
||||||
|
return X.select(expressions)
|
||||||
|
|
||||||
|
def _fill_with_value(self, X: pl.DataFrame) -> pl.DataFrame:
|
||||||
|
"""使用指定值填充缺失值"""
|
||||||
|
numeric_cols = [
|
||||||
|
c
|
||||||
|
for c in X.columns
|
||||||
|
if c not in self.exclude_cols and X[c].dtype.is_numeric()
|
||||||
|
]
|
||||||
|
|
||||||
|
expressions = []
|
||||||
|
for col in X.columns:
|
||||||
|
if col in numeric_cols:
|
||||||
|
expr = pl.col(col).fill_null(self.fill_value).alias(col)
|
||||||
|
expressions.append(expr)
|
||||||
|
else:
|
||||||
|
expressions.append(pl.col(col))
|
||||||
|
|
||||||
|
return X.select(expressions)
|
||||||
|
|
||||||
|
def _fill_global(self, X: pl.DataFrame) -> pl.DataFrame:
|
||||||
|
"""使用全局统计量填充(训练集学到的统计量)"""
|
||||||
|
expressions = []
|
||||||
|
for col in X.columns:
|
||||||
|
if col in self.stats_:
|
||||||
|
fill_val = self.stats_[col]
|
||||||
|
expr = pl.col(col).fill_null(fill_val).alias(col)
|
||||||
|
expressions.append(expr)
|
||||||
|
else:
|
||||||
|
expressions.append(pl.col(col))
|
||||||
|
|
||||||
|
return X.select(expressions)
|
||||||
|
|
||||||
|
def _fill_by_date(self, X: pl.DataFrame) -> pl.DataFrame:
|
||||||
|
"""使用每天截面统计量填充"""
|
||||||
|
numeric_cols = [
|
||||||
|
c
|
||||||
|
for c in X.columns
|
||||||
|
if c not in self.exclude_cols and X[c].dtype.is_numeric()
|
||||||
|
]
|
||||||
|
|
||||||
|
# 计算每天的统计量
|
||||||
|
stat_exprs = []
|
||||||
|
for col in numeric_cols:
|
||||||
|
if self.strategy == "mean":
|
||||||
|
stat_exprs.append(
|
||||||
|
pl.col(col).mean().over(self.date_col).alias(f"{col}_stat")
|
||||||
|
)
|
||||||
|
else: # median
|
||||||
|
stat_exprs.append(
|
||||||
|
pl.col(col).median().over(self.date_col).alias(f"{col}_stat")
|
||||||
|
)
|
||||||
|
|
||||||
|
# 添加统计量列
|
||||||
|
result = X.with_columns(stat_exprs)
|
||||||
|
|
||||||
|
# 使用统计量填充缺失值
|
||||||
|
fill_exprs = []
|
||||||
|
for col in X.columns:
|
||||||
|
if col in numeric_cols:
|
||||||
|
expr = pl.col(col).fill_null(pl.col(f"{col}_stat")).alias(col)
|
||||||
|
fill_exprs.append(expr)
|
||||||
|
else:
|
||||||
|
fill_exprs.append(pl.col(col))
|
||||||
|
|
||||||
|
result = result.select(fill_exprs)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
@register_processor("standard_scaler")
|
@register_processor("standard_scaler")
|
||||||
class StandardScaler(BaseProcessor):
|
class StandardScaler(BaseProcessor):
|
||||||
"""标准化处理器(全局标准化)
|
"""标准化处理器(全局标准化)
|
||||||
|
|||||||
Reference in New Issue
Block a user