diff --git a/src/training/components/__init__.py b/src/training/components/__init__.py index 9b2ebe4..3106fed 100644 --- a/src/training/components/__init__.py +++ b/src/training/components/__init__.py @@ -15,10 +15,20 @@ from src.training.components.selectors import ( StockFilterConfig, ) +# 数据处理器 +from src.training.components.processors import ( + CrossSectionalStandardScaler, + StandardScaler, + Winsorizer, +) + __all__ = [ "BaseModel", "BaseProcessor", "DateSplitter", "StockFilterConfig", "MarketCapSelectorConfig", + "StandardScaler", + "CrossSectionalStandardScaler", + "Winsorizer", ] diff --git a/src/training/components/processors/__init__.py b/src/training/components/processors/__init__.py new file mode 100644 index 0000000..c790cf6 --- /dev/null +++ b/src/training/components/processors/__init__.py @@ -0,0 +1,16 @@ +"""数据处理器子模块 + +包含数据预处理、转换等处理器实现。 +""" + +from src.training.components.processors.transforms import ( + CrossSectionalStandardScaler, + StandardScaler, + Winsorizer, +) + +__all__ = [ + "StandardScaler", + "CrossSectionalStandardScaler", + "Winsorizer", +] diff --git a/src/training/components/processors/transforms.py b/src/training/components/processors/transforms.py new file mode 100644 index 0000000..363e66b --- /dev/null +++ b/src/training/components/processors/transforms.py @@ -0,0 +1,275 @@ +"""数据处理器实现 + +包含标准化、缩尾等数据处理器。 +""" + +from typing import List, Optional + +import polars as pl + +from src.training.components.base import BaseProcessor +from src.training.registry import register_processor + + +@register_processor("standard_scaler") +class StandardScaler(BaseProcessor): + """标准化处理器(全局标准化) + + 在整个训练集上学习均值和标准差, + 然后应用到训练集和测试集。 + + 适用于需要全局统计量的场景。 + + Attributes: + exclude_cols: 不参与标准化的列名列表 + mean_: 学习到的均值字典 {列名: 均值} + std_: 学习到的标准差字典 {列名: 标准差} + """ + + name = "standard_scaler" + + def __init__(self, exclude_cols: Optional[List[str]] = None): + """初始化标准化处理器 + + Args: + exclude_cols: 不参与标准化的列名列表,默认为 ["ts_code", "trade_date"] + """ + self.exclude_cols = exclude_cols or ["ts_code", "trade_date"] + self.mean_: dict = {} + self.std_: dict = {} + + def fit(self, X: pl.DataFrame) -> "StandardScaler": + """计算均值和标准差(仅在训练集上) + + Args: + X: 训练数据 + + Returns: + self + """ + numeric_cols = [ + c + for c in X.columns + if c not in self.exclude_cols and X[c].dtype.is_numeric() + ] + + for col in numeric_cols: + self.mean_[col] = X[col].mean() + self.std_[col] = X[col].std() + + return self + + def transform(self, X: pl.DataFrame) -> pl.DataFrame: + """标准化(使用训练集学到的参数) + + Args: + X: 待转换数据 + + Returns: + 标准化后的数据 + """ + expressions = [] + for col in X.columns: + if col in self.mean_ and col in self.std_: + # 避免除以0 + std_val = self.std_[col] if self.std_[col] != 0 else 1.0 + expr = ((pl.col(col) - self.mean_[col]) / std_val).alias(col) + expressions.append(expr) + else: + expressions.append(pl.col(col)) + + return X.select(expressions) + + +@register_processor("cs_standard_scaler") +class CrossSectionalStandardScaler(BaseProcessor): + """截面标准化处理器 + + 每天独立进行标准化:对当天所有股票的某一因子进行标准化。 + + 特点: + - 不需要 fit,每天独立计算当天的均值和标准差 + - 适用于截面因子,消除市值等行业差异 + - 公式:z = (x - mean_today) / std_today + + Attributes: + exclude_cols: 不参与标准化的列名列表 + date_col: 日期列名 + """ + + name = "cs_standard_scaler" + + def __init__( + self, + exclude_cols: Optional[List[str]] = None, + date_col: str = "trade_date", + ): + """初始化截面标准化处理器 + + Args: + exclude_cols: 不参与标准化的列名列表,默认为 ["ts_code", "trade_date"] + date_col: 日期列名 + """ + self.exclude_cols = exclude_cols or ["ts_code", "trade_date"] + self.date_col = date_col + + def transform(self, X: pl.DataFrame) -> pl.DataFrame: + """截面标准化 + + 按日期分组,每天独立计算均值和标准差并标准化。 + 不需要 fit,因为每天使用当天的统计量。 + + Args: + X: 待转换数据 + + Returns: + 标准化后的数据 + """ + numeric_cols = [ + c + for c in X.columns + if c not in self.exclude_cols and X[c].dtype.is_numeric() + ] + + # 构建表达式列表 + expressions = [] + for col in X.columns: + if col in numeric_cols: + # 截面标准化:每天独立计算均值和标准差 + # 避免除以0,当std为0时设为1 + expr = ( + (pl.col(col) - pl.col(col).mean().over(self.date_col)) + / (pl.col(col).std().over(self.date_col) + 1e-10) + ).alias(col) + expressions.append(expr) + else: + expressions.append(pl.col(col)) + + return X.select(expressions) + + +@register_processor("winsorizer") +class Winsorizer(BaseProcessor): + """缩尾处理器 + + 对每一列的极端值进行截断处理。 + 可以全局截断(在整个训练集上学习分位数), + 也可以截面截断(每天独立处理)。 + + Attributes: + lower: 下分位数(如0.01表示1%分位数) + upper: 上分位数(如0.99表示99%分位数) + by_date: True=每天独立缩尾, False=全局缩尾 + date_col: 日期列名 + bounds_: 存储分位数边界(全局模式) + """ + + name = "winsorizer" + + def __init__( + self, + lower: float = 0.01, + upper: float = 0.99, + by_date: bool = False, + date_col: str = "trade_date", + ): + """初始化缩尾处理器 + + Args: + lower: 下分位数,默认0.01 + upper: 上分位数,默认0.99 + by_date: 每天独立缩尾,默认False(全局缩尾) + date_col: 日期列名 + + Raises: + ValueError: 分位数参数无效 + """ + if not 0 <= lower < upper <= 1: + raise ValueError( + f"lower ({lower}) 必须小于 upper ({upper}),且都在 [0, 1] 范围内" + ) + + self.lower = lower + self.upper = upper + self.by_date = by_date + self.date_col = date_col + self.bounds_: dict = {} + + def fit(self, X: pl.DataFrame) -> "Winsorizer": + """学习分位数边界(仅在全局模式下) + + Args: + X: 训练数据 + + Returns: + self + """ + if not self.by_date: + numeric_cols = [c for c in X.columns if X[c].dtype.is_numeric()] + for col in numeric_cols: + self.bounds_[col] = { + "lower": X[col].quantile(self.lower), + "upper": X[col].quantile(self.upper), + } + return self + + def transform(self, X: pl.DataFrame) -> pl.DataFrame: + """缩尾处理 + + Args: + X: 待转换数据 + + Returns: + 缩尾处理后的数据 + """ + if self.by_date: + return self._transform_by_date(X) + else: + return self._transform_global(X) + + def _transform_global(self, X: pl.DataFrame) -> pl.DataFrame: + """全局缩尾(使用训练集学到的边界)""" + expressions = [] + for col in X.columns: + if col in self.bounds_: + lower = self.bounds_[col]["lower"] + upper = self.bounds_[col]["upper"] + expr = pl.col(col).clip(lower, upper).alias(col) + expressions.append(expr) + else: + expressions.append(pl.col(col)) + return X.select(expressions) + + def _transform_by_date(self, X: pl.DataFrame) -> pl.DataFrame: + """每日独立缩尾""" + numeric_cols = [c for c in X.columns if X[c].dtype.is_numeric()] + + # 先计算每天的分位数 + lower_exprs = [ + pl.col(col).quantile(self.lower).over(self.date_col).alias(f"{col}_lower") + for col in numeric_cols + ] + upper_exprs = [ + pl.col(col).quantile(self.upper).over(self.date_col).alias(f"{col}_upper") + for col in numeric_cols + ] + + # 添加分位数列 + result = X.with_columns(lower_exprs + upper_exprs) + + # 执行缩尾 + clip_exprs = [] + for col in X.columns: + if col in numeric_cols: + clipped = ( + pl.col(col) + .clip(pl.col(f"{col}_lower"), pl.col(f"{col}_upper")) + .alias(col) + ) + clip_exprs.append(clipped) + else: + clip_exprs.append(pl.col(col)) + + result = result.select(clip_exprs) + + return result diff --git a/tests/training/test_processors.py b/tests/training/test_processors.py new file mode 100644 index 0000000..4f7ce1d --- /dev/null +++ b/tests/training/test_processors.py @@ -0,0 +1,300 @@ +"""测试数据处理器 + +验证 StandardScaler、CrossSectionalStandardScaler 和 Winsorizer 功能。 +""" + +import numpy as np +import polars as pl +import pytest + +from src.training.components.processors import ( + CrossSectionalStandardScaler, + StandardScaler, + Winsorizer, +) + + +class TestStandardScaler: + """StandardScaler 测试类""" + + def test_init_default(self): + """测试默认初始化""" + scaler = StandardScaler() + assert scaler.exclude_cols == ["ts_code", "trade_date"] + assert scaler.mean_ == {} + assert scaler.std_ == {} + + def test_init_custom_exclude(self): + """测试自定义排除列""" + scaler = StandardScaler(exclude_cols=["id", "date"]) + assert scaler.exclude_cols == ["id", "date"] + + def test_fit_transform(self): + """测试拟合和转换""" + data = pl.DataFrame( + { + "ts_code": ["A", "B", "C", "D"], + "trade_date": ["20240101"] * 4, + "value": [1.0, 2.0, 3.0, 4.0], + } + ) + + scaler = StandardScaler() + result = scaler.fit_transform(data) + + # 验证学习到的统计量 + assert scaler.mean_["value"] == 2.5 + assert scaler.std_["value"] == pytest.approx(1.290, rel=1e-2) + + # 验证转换结果 + expected_std = (np.array([1.0, 2.0, 3.0, 4.0]) - 2.5) / 1.290 + assert result["value"].to_list() == pytest.approx( + expected_std.tolist(), rel=1e-2 + ) + + def test_transform_use_fitted_params(self): + """测试转换使用拟合时的参数""" + train_data = pl.DataFrame( + { + "ts_code": ["A", "B", "C"], + "trade_date": ["20240101"] * 3, + "value": [1.0, 2.0, 3.0], + } + ) + + test_data = pl.DataFrame( + { + "ts_code": ["D"], + "trade_date": ["20240102"], + "value": [100.0], # 远离训练分布 + } + ) + + scaler = StandardScaler() + scaler.fit(train_data) + + # 使用训练集的均值(2.0)和标准差进行转换 + result = scaler.transform(test_data) + expected = (100.0 - 2.0) / 1.0 # 均值2.0, 标准差1.0 + assert result["value"][0] == pytest.approx(expected, rel=1e-2) + + def test_exclude_non_numeric(self): + """测试自动排除非数值列""" + data = pl.DataFrame( + { + "ts_code": ["A", "B"], + "trade_date": ["20240101", "20240102"], + "category": ["X", "Y"], # 字符串列 + "value": [1.0, 2.0], + } + ) + + scaler = StandardScaler() + result = scaler.fit_transform(data) + + # category 列应该原样保留 + assert result["category"].to_list() == ["X", "Y"] + # value 列应该被标准化 + assert "value" in scaler.mean_ + + def test_zero_std_handling(self): + """测试处理标准差为0的情况""" + data = pl.DataFrame( + { + "ts_code": ["A", "B"], + "trade_date": ["20240101", "20240102"], + "constant": [5.0, 5.0], # 常数列 + } + ) + + scaler = StandardScaler() + result = scaler.fit_transform(data) + + # 标准差为0时,结果应该为0(避免除以0) + assert result["constant"].to_list() == [0.0, 0.0] + + +class TestCrossSectionalStandardScaler: + """CrossSectionalStandardScaler 测试类""" + + def test_init_default(self): + """测试默认初始化""" + scaler = CrossSectionalStandardScaler() + assert scaler.exclude_cols == ["ts_code", "trade_date"] + assert scaler.date_col == "trade_date" + + def test_init_custom(self): + """测试自定义参数""" + scaler = CrossSectionalStandardScaler( + exclude_cols=["id"], + date_col="date", + ) + assert scaler.exclude_cols == ["id"] + assert scaler.date_col == "date" + + def test_transform_no_fit_needed(self): + """测试不需要 fit""" + data = pl.DataFrame( + { + "ts_code": ["A", "B"], + "trade_date": ["20240101", "20240101"], + "value": [1.0, 3.0], + } + ) + + scaler = CrossSectionalStandardScaler() + # 截面标准化不需要 fit + result = scaler.transform(data) + + # 当天均值=2.0, 样本标准差=sqrt(2)≈1.414, z-score=[-0.707, 0.707] + assert result["value"].to_list() == pytest.approx([-0.707, 0.707], rel=1e-2) + + def test_transform_by_date(self): + """测试按日期分组标准化""" + data = pl.DataFrame( + { + "ts_code": ["A", "B", "C", "D"], + "trade_date": ["20240101", "20240101", "20240102", "20240102"], + "value": [1.0, 3.0, 10.0, 30.0], + } + ) + + scaler = CrossSectionalStandardScaler() + result = scaler.transform(data) + + # 2024-01-01: 均值=2.0, 样本std≈1.414 -> [-0.707, 0.707] + # 2024-01-02: 均值=20.0, 样本std≈14.14 -> [-0.707, 0.707] + values = result["value"].to_list() + assert values[0] == pytest.approx(-0.707, abs=1e-2) + assert values[1] == pytest.approx(0.707, abs=1e-2) + assert values[2] == pytest.approx(-0.707, abs=1e-2) + assert values[3] == pytest.approx(0.707, abs=1e-2) + + def test_exclude_columns_preserved(self): + """测试排除列保持原样""" + data = pl.DataFrame( + { + "ts_code": ["A", "B"], + "trade_date": ["20240101", "20240101"], + "value": [1.0, 3.0], + } + ) + + scaler = CrossSectionalStandardScaler() + result = scaler.transform(data) + + assert result["ts_code"].to_list() == ["A", "B"] + assert result["trade_date"].to_list() == ["20240101", "20240101"] + + +class TestWinsorizer: + """Winsorizer 测试类""" + + def test_init_default(self): + """测试默认初始化""" + winsorizer = Winsorizer() + assert winsorizer.lower == 0.01 + assert winsorizer.upper == 0.99 + assert winsorizer.by_date is False + assert winsorizer.date_col == "trade_date" + + def test_init_custom(self): + """测试自定义参数""" + winsorizer = Winsorizer(lower=0.05, upper=0.95, by_date=True, date_col="date") + assert winsorizer.lower == 0.05 + assert winsorizer.upper == 0.95 + assert winsorizer.by_date is True + assert winsorizer.date_col == "date" + + def test_invalid_quantiles(self): + """测试无效的分位数参数""" + with pytest.raises(ValueError, match="lower .* 必须小于 upper"): + Winsorizer(lower=0.5, upper=0.3) + + with pytest.raises(ValueError, match="lower .* 必须小于 upper"): + Winsorizer(lower=-0.1, upper=0.5) + + with pytest.raises(ValueError, match="lower .* 必须小于 upper"): + Winsorizer(lower=0.5, upper=1.5) + + def test_global_winsorize(self): + """测试全局缩尾""" + # 创建包含极端值的数据 + values = list(range(1, 101)) # 1-100 + values[0] = -1000 # 极端小值 + values[-1] = 1000 # 极端大值 + + data = pl.DataFrame( + { + "ts_code": [f"A{i}" for i in range(100)], + "trade_date": ["20240101"] * 100, + "value": values, + } + ) + + winsorizer = Winsorizer(lower=0.01, upper=0.99) + result = winsorizer.fit_transform(data) + + # 1%分位数=2, 99%分位数=99 + # -1000 应该被截断为 2 + # 1000 应该被截断为 99 + result_values = result["value"].to_list() + assert result_values[0] == 2 # 原-1000被截断 + assert result_values[-1] == 99 # 原1000被截断 + assert result_values[1] == 2 # 原2保持不变 + assert result_values[98] == 99 # 原99保持不变 + + def test_by_date_winsorize(self): + """测试每日独立缩尾""" + data = pl.DataFrame( + { + "ts_code": ["A", "B", "C", "D", "E", "F"], + "trade_date": ["20240101"] * 3 + ["20240102"] * 3, + "value": [1.0, 50.0, 100.0, 200.0, 250.0, 300.0], + } + ) + + winsorizer = Winsorizer(lower=0.0, upper=0.5, by_date=True) + result = winsorizer.transform(data) + + # 每天独立处理: + # 2024-01-01: [1, 50, 100], 50%分位数=50 + # -> 截断为 [1, 50, 50] + # 2024-01-02: [200, 250, 300], 50%分位数=250 + # -> 截断为 [200, 250, 250] + result_values = result["value"].to_list() + assert result_values[0] == 1.0 + assert result_values[1] == 50.0 + assert result_values[2] == 50.0 # 被截断 + assert result_values[3] == 200.0 + assert result_values[4] == 250.0 + assert result_values[5] == 250.0 # 被截断 + + def test_global_transform_after_fit(self): + """测试全局模式下,转换使用拟合时的边界""" + train_data = pl.DataFrame( + { + "ts_code": ["A", "B", "C"], + "trade_date": ["20240101"] * 3, + "value": [1.0, 50.0, 100.0], + } + ) + + test_data = pl.DataFrame( + { + "ts_code": ["D"], + "trade_date": ["20240102"], + "value": [200.0], + } + ) + + winsorizer = Winsorizer(lower=0.0, upper=1.0) # 0%和100%分位数 + winsorizer.fit(train_data) + + # 使用训练集的分位数边界 [1, 100] + result = winsorizer.transform(test_data) + assert result["value"][0] == 100.0 # 被截断为100 + + +if __name__ == "__main__": + pytest.main([__file__, "-v"])