feat(training): 实现数据处理器
- 新增 StandardScaler:全局标准化,训练集学习参数,测试集复用
- 新增 CrossSectionalStandardScaler:截面标准化,每天独立计算
- 新增 Winsorizer:支持全局/截面两种缩尾模式
- 处理器统一遵循 fit/transform 接口,Trainer 可无差别调用
- 添加 17 个单元测试覆盖各种场景
This commit is contained in:
@@ -15,10 +15,20 @@ from src.training.components.selectors import (
|
||||
StockFilterConfig,
|
||||
)
|
||||
|
||||
# 数据处理器
|
||||
from src.training.components.processors import (
|
||||
CrossSectionalStandardScaler,
|
||||
StandardScaler,
|
||||
Winsorizer,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"BaseModel",
|
||||
"BaseProcessor",
|
||||
"DateSplitter",
|
||||
"StockFilterConfig",
|
||||
"MarketCapSelectorConfig",
|
||||
"StandardScaler",
|
||||
"CrossSectionalStandardScaler",
|
||||
"Winsorizer",
|
||||
]
|
||||
|
||||
16
src/training/components/processors/__init__.py
Normal file
16
src/training/components/processors/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
"""数据处理器子模块
|
||||
|
||||
包含数据预处理、转换等处理器实现。
|
||||
"""
|
||||
|
||||
from src.training.components.processors.transforms import (
|
||||
CrossSectionalStandardScaler,
|
||||
StandardScaler,
|
||||
Winsorizer,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"StandardScaler",
|
||||
"CrossSectionalStandardScaler",
|
||||
"Winsorizer",
|
||||
]
|
||||
275
src/training/components/processors/transforms.py
Normal file
275
src/training/components/processors/transforms.py
Normal file
@@ -0,0 +1,275 @@
|
||||
"""数据处理器实现
|
||||
|
||||
包含标准化、缩尾等数据处理器。
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
import polars as pl
|
||||
|
||||
from src.training.components.base import BaseProcessor
|
||||
from src.training.registry import register_processor
|
||||
|
||||
|
||||
@register_processor("standard_scaler")
class StandardScaler(BaseProcessor):
    """Global standardization (z-score) processor.

    The per-column mean and standard deviation are learned once from the
    data passed to ``fit`` and then reused by ``transform`` for any frame
    (training or test), so no statistics are taken from the data being
    transformed.

    Attributes:
        exclude_cols: Column names that are never standardized.
        mean_: Learned means, ``{column name: mean}``.
        std_: Learned standard deviations, ``{column name: std}``.
    """

    name = "standard_scaler"

    def __init__(self, exclude_cols: Optional[List[str]] = None):
        """Create the scaler.

        Args:
            exclude_cols: Column names that must not be standardized.
                ``None`` or an empty list falls back to
                ``["ts_code", "trade_date"]``.
        """
        self.exclude_cols = exclude_cols or ["ts_code", "trade_date"]
        self.mean_: dict = {}
        self.std_: dict = {}

    def fit(self, X: pl.DataFrame) -> "StandardScaler":
        """Learn the mean and standard deviation of every numeric column.

        Statistics from a previous ``fit`` call are discarded, so columns
        that are absent from ``X`` are no longer transformed afterwards.

        Args:
            X: Training data.

        Returns:
            self
        """
        excluded = set(self.exclude_cols)
        means: dict = {}
        stds: dict = {}

        for col in X.columns:
            if col in excluded or not X[col].dtype.is_numeric():
                continue
            mean_val = X[col].mean()
            if mean_val is None:
                # No non-null training values: there is nothing to learn,
                # and a null mean would null out the column in transform.
                continue
            means[col] = mean_val
            stds[col] = X[col].std()

        # Rebind rather than update in place so a refit never keeps stale
        # statistics for columns that are missing from the new frame.
        self.mean_ = means
        self.std_ = stds
        return self

    def transform(self, X: pl.DataFrame) -> pl.DataFrame:
        """Standardize ``X`` with the statistics learned by ``fit``.

        Columns without learned statistics are passed through unchanged.
        The column order of ``X`` is preserved.

        Args:
            X: Data to transform.

        Returns:
            The standardized data.
        """
        expressions = []
        for col in X.columns:
            if col not in self.mean_ or col not in self.std_:
                expressions.append(pl.col(col))
                continue

            # A zero std (constant column) would divide by zero, and a
            # missing std (e.g. a single training row) would turn the whole
            # column null; in both cases only center the values.
            std_val = self.std_[col] or 1.0
            scaled = (pl.col(col) - self.mean_[col]) / std_val
            expressions.append(scaled.alias(col))

        return X.select(expressions)
|
||||
|
||||
|
||||
@register_processor("cs_standard_scaler")
class CrossSectionalStandardScaler(BaseProcessor):
    """Cross-sectional standardization processor.

    Every date is standardized independently: for each numeric column the
    values of all rows sharing the same date are centered and scaled with
    that date's own mean and standard deviation.

    Notes:
        - No fitting is required; statistics come from the frame being
          transformed.
        - Intended for cross-sectional factors, to remove level differences
          between dates.
        - Formula: ``z = (x - mean_today) / (std_today + 1e-10)``.

    Attributes:
        exclude_cols: Column names that are never standardized.
        date_col: Name of the date column used for grouping.
    """

    name = "cs_standard_scaler"

    def __init__(
        self,
        exclude_cols: Optional[List[str]] = None,
        date_col: str = "trade_date",
    ):
        """Create the scaler.

        Args:
            exclude_cols: Column names that must not be standardized.
                ``None`` or an empty list falls back to
                ``["ts_code", "trade_date"]``.
            date_col: Name of the date column used for grouping.
        """
        self.exclude_cols = exclude_cols or ["ts_code", "trade_date"]
        self.date_col = date_col

    def transform(self, X: pl.DataFrame) -> pl.DataFrame:
        """Standardize each numeric column within every date group.

        The grouping column itself is always left untouched, even when it is
        numeric and not listed in ``exclude_cols``; standardizing it over
        itself would collapse every date to 0. The column order of ``X`` is
        preserved.

        Args:
            X: Data to transform.

        Returns:
            The standardized data.
        """
        skipped = set(self.exclude_cols)
        skipped.add(self.date_col)

        expressions = []
        for col in X.columns:
            if col in skipped or not X[col].dtype.is_numeric():
                expressions.append(pl.col(col))
                continue

            value = pl.col(col)
            centered = value - value.mean().over(self.date_col)
            # A small epsilon keeps the division finite when all values of a
            # date are identical (std == 0); the centered values are 0 there,
            # so the result is 0.
            scale = value.std().over(self.date_col) + 1e-10
            expressions.append((centered / scale).alias(col))

        return X.select(expressions)
|
||||
|
||||
|
||||
@register_processor("winsorizer")
class Winsorizer(BaseProcessor):
    """Winsorization processor: clips extreme values to quantile bounds.

    Two modes are supported:

    - global (``by_date=False``): the quantile bounds are learned once in
      ``fit`` and reused by ``transform``;
    - cross-sectional (``by_date=True``): every date is clipped to its own
      quantiles, computed from the frame being transformed.

    The date column is never winsorized, even when it is numeric.

    Attributes:
        lower: Lower quantile (e.g. 0.01 for the 1% quantile).
        upper: Upper quantile (e.g. 0.99 for the 99% quantile).
        by_date: True = clip each date independently, False = global bounds.
        date_col: Name of the date column.
        bounds_: Learned bounds in global mode,
            ``{column name: {"lower": value, "upper": value}}``.
    """

    name = "winsorizer"

    def __init__(
        self,
        lower: float = 0.01,
        upper: float = 0.99,
        by_date: bool = False,
        date_col: str = "trade_date",
    ):
        """Create the winsorizer.

        Args:
            lower: Lower quantile, default 0.01.
            upper: Upper quantile, default 0.99.
            by_date: Clip each date independently; default False (global).
            date_col: Name of the date column.

        Raises:
            ValueError: If ``0 <= lower < upper <= 1`` does not hold.
        """
        if not 0 <= lower < upper <= 1:
            raise ValueError(
                f"lower ({lower}) 必须小于 upper ({upper}),且都在 [0, 1] 范围内"
            )

        self.lower = lower
        self.upper = upper
        self.by_date = by_date
        self.date_col = date_col
        self.bounds_: dict = {}

    def _winsorized_columns(self, X: pl.DataFrame) -> List[str]:
        """Return the numeric columns of ``X`` that should be clipped."""
        return [
            col
            for col in X.columns
            if col != self.date_col and X[col].dtype.is_numeric()
        ]

    def fit(self, X: pl.DataFrame) -> "Winsorizer":
        """Learn the global quantile bounds (no-op in by-date mode).

        Bounds from a previous ``fit`` call are discarded, so columns that
        are absent from ``X`` are no longer clipped afterwards.

        Args:
            X: Training data.

        Returns:
            self
        """
        if self.by_date:
            return self

        bounds: dict = {}
        for col in self._winsorized_columns(X):
            low = X[col].quantile(self.lower)
            high = X[col].quantile(self.upper)
            if low is None or high is None:
                # No non-null training values: nothing to clip against.
                continue
            bounds[col] = {"lower": low, "upper": high}

        # Rebind rather than update in place so a refit never keeps stale
        # bounds for columns that are missing from the new frame.
        self.bounds_ = bounds
        return self

    def transform(self, X: pl.DataFrame) -> pl.DataFrame:
        """Clip extreme values according to the configured mode.

        Args:
            X: Data to transform.

        Returns:
            The winsorized data, with the column order of ``X`` preserved.
        """
        if self.by_date:
            return self._transform_by_date(X)
        return self._transform_global(X)

    def _transform_global(self, X: pl.DataFrame) -> pl.DataFrame:
        """Clip with the bounds learned by ``fit``; other columns pass through."""
        expressions = []
        for col in X.columns:
            limits = self.bounds_.get(col)
            if limits is None:
                expressions.append(pl.col(col))
                continue
            clipped = pl.col(col).clip(limits["lower"], limits["upper"])
            expressions.append(clipped.alias(col))
        return X.select(expressions)

    def _transform_by_date(self, X: pl.DataFrame) -> pl.DataFrame:
        """Clip every numeric column to the quantiles of its own date."""
        targets = set(self._winsorized_columns(X))

        expressions = []
        for col in X.columns:
            if col not in targets:
                expressions.append(pl.col(col))
                continue

            value = pl.col(col)
            # The per-date quantiles are passed to clip directly as window
            # expressions. No helper columns are materialized, so existing
            # columns named "<col>_lower" / "<col>_upper" are not overwritten.
            low = value.quantile(self.lower).over(self.date_col)
            high = value.quantile(self.upper).over(self.date_col)
            expressions.append(value.clip(low, high).alias(col))

        return X.select(expressions)
|
||||
Reference in New Issue
Block a user