From 9ca1deae56c3a79aa11c28289c20a3f4d7bcf1d8 Mon Sep 17 00:00:00 2001
From: liaozhaorun <1300336796@qq.com>
Date: Tue, 3 Mar 2026 22:23:43 +0800
Subject: [PATCH] =?UTF-8?q?feat(training):=20=E5=AE=9E=E7=8E=B0=E6=95=B0?=
 =?UTF-8?q?=E6=8D=AE=E5=A4=84=E7=90=86=E5=99=A8=20-=20=E6=96=B0=E5=A2=9E?=
 =?UTF-8?q?=20StandardScaler=EF=BC=9A=E5=85=A8=E5=B1=80=E6=A0=87=E5=87=86?=
 =?UTF-8?q?=E5=8C=96=EF=BC=8C=E8=AE=AD=E7=BB=83=E9=9B=86=E5=AD=A6=E4=B9=A0?=
 =?UTF-8?q?=E5=8F=82=E6=95=B0=EF=BC=8C=E6=B5=8B=E8=AF=95=E9=9B=86=E5=A4=8D?=
 =?UTF-8?q?=E7=94=A8=20-=20=E6=96=B0=E5=A2=9E=20CrossSectionalStandardScal?=
 =?UTF-8?q?er=EF=BC=9A=E6=88=AA=E9=9D=A2=E6=A0=87=E5=87=86=E5=8C=96?=
 =?UTF-8?q?=EF=BC=8C=E6=AF=8F=E5=A4=A9=E7=8B=AC=E7=AB=8B=E8=AE=A1=E7=AE=97?=
 =?UTF-8?q?=20-=20=E6=96=B0=E5=A2=9E=20Winsorizer=EF=BC=9A=E6=94=AF?=
 =?UTF-8?q?=E6=8C=81=E5=85=A8=E5=B1=80/=E6=88=AA=E9=9D=A2=E4=B8=A4?=
 =?UTF-8?q?=E7=A7=8D=E7=BC=A9=E5=B0=BE=E6=A8=A1=E5=BC=8F=20-=20=E5=A4=84?=
 =?UTF-8?q?=E7=90=86=E5=99=A8=E7=BB=9F=E4=B8=80=E9=81=B5=E5=BE=AA=20fit/tr?=
 =?UTF-8?q?ansform=20=E6=8E=A5=E5=8F=A3=EF=BC=8CTrainer=20=E5=8F=AF?=
 =?UTF-8?q?=E6=97=A0=E5=B7=AE=E5=88=AB=E8=B0=83=E7=94=A8=20-=20=E6=B7=BB?=
 =?UTF-8?q?=E5=8A=A0=2017=20=E4=B8=AA=E5=8D=95=E5=85=83=E6=B5=8B=E8=AF=95?=
 =?UTF-8?q?=E8=A6=86=E7=9B=96=E5=90=84=E7=A7=8D=E5=9C=BA=E6=99=AF?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 src/training/components/__init__.py           |  10 +
 .../components/processors/__init__.py         |  16 +
 .../components/processors/transforms.py       | 270 ++++++++++++++++
 tests/training/test_processors.py             | 300 ++++++++++++++++++
 4 files changed, 596 insertions(+)
 create mode 100644 src/training/components/processors/__init__.py
 create mode 100644 src/training/components/processors/transforms.py
 create mode 100644 tests/training/test_processors.py

diff --git a/src/training/components/__init__.py b/src/training/components/__init__.py
index 9b2ebe4..3106fed 100644
--- a/src/training/components/__init__.py
+++ b/src/training/components/__init__.py
@@ -15,10 +15,20 @@ from src.training.components.selectors import (
     StockFilterConfig,
 )
 
+# Data processors
+from src.training.components.processors import (
+    CrossSectionalStandardScaler,
+    StandardScaler,
+    Winsorizer,
+)
+
 __all__ = [
     "BaseModel",
     "BaseProcessor",
     "DateSplitter",
     "StockFilterConfig",
     "MarketCapSelectorConfig",
+    "StandardScaler",
+    "CrossSectionalStandardScaler",
+    "Winsorizer",
 ]
diff --git a/src/training/components/processors/__init__.py b/src/training/components/processors/__init__.py
new file mode 100644
index 0000000..c790cf6
--- /dev/null
+++ b/src/training/components/processors/__init__.py
@@ -0,0 +1,16 @@
+"""Data processor subpackage.
+
+Contains processors for data preprocessing and transformation.
+"""
+
+from src.training.components.processors.transforms import (
+    CrossSectionalStandardScaler,
+    StandardScaler,
+    Winsorizer,
+)
+
+__all__ = [
+    "StandardScaler",
+    "CrossSectionalStandardScaler",
+    "Winsorizer",
+]
diff --git a/src/training/components/processors/transforms.py b/src/training/components/processors/transforms.py
new file mode 100644
index 0000000..363e66b
--- /dev/null
+++ b/src/training/components/processors/transforms.py
@@ -0,0 +1,270 @@
+"""Data processor implementations.
+
+Contains standardization and winsorization processors.
+"""
+
+from typing import List, Optional
+
+import polars as pl
+
+from src.training.components.base import BaseProcessor
+from src.training.registry import register_processor
+
+
+@register_processor("standard_scaler")
+class StandardScaler(BaseProcessor):
+    """Global standardization processor.
+
+    Learns the mean and standard deviation of each numeric column in
+    ``fit`` and reuses them in ``transform``, so later data is scaled with
+    the statistics of the data that was passed to ``fit``.
+
+    Attributes:
+        exclude_cols: Column names that are never standardized.
+        mean_: Learned means, ``{column name: mean}``.
+        std_: Learned standard deviations, ``{column name: std}``.
+    """
+
+    name = "standard_scaler"
+
+    def __init__(self, exclude_cols: Optional[List[str]] = None):
+        """Initialize the scaler.
+
+        Args:
+            exclude_cols: Column names that are never standardized.
+                Defaults to ``["ts_code", "trade_date"]``.
+        """
+        self.exclude_cols = exclude_cols or ["ts_code", "trade_date"]
+        self.mean_: dict = {}
+        self.std_: dict = {}
+
+    def fit(self, X: pl.DataFrame) -> "StandardScaler":
+        """Learn the mean and standard deviation of each numeric column.
+
+        Statistics from an earlier ``fit`` call are discarded, so refitting
+        never leaves stale entries for columns that are absent from ``X``.
+
+        Args:
+            X: Training data.
+
+        Returns:
+            self
+        """
+        numeric_cols = [
+            c
+            for c in X.columns
+            if c not in self.exclude_cols and X[c].dtype.is_numeric()
+        ]
+
+        self.mean_ = {col: X[col].mean() for col in numeric_cols}
+        self.std_ = {col: X[col].std() for col in numeric_cols}
+
+        return self
+
+    def transform(self, X: pl.DataFrame) -> pl.DataFrame:
+        """Standardize ``X`` with the statistics learned in ``fit``.
+
+        Columns without learned statistics are passed through unchanged.
+
+        Args:
+            X: Data to transform.
+
+        Returns:
+            The standardized data.
+        """
+        expressions = []
+        for col in X.columns:
+            if col in self.mean_ and col in self.std_:
+                # Divide by 1 when the learned std is zero (constant
+                # column) or missing (None), instead of dividing by it.
+                std_val = self.std_[col] or 1.0
+                expr = ((pl.col(col) - self.mean_[col]) / std_val).alias(col)
+                expressions.append(expr)
+            else:
+                expressions.append(pl.col(col))
+
+        return X.select(expressions)
+
+
+@register_processor("cs_standard_scaler")
+class CrossSectionalStandardScaler(BaseProcessor):
+    """Cross-sectional standardization processor.
+
+    Standardizes every numeric column within each ``date_col`` group, using
+    that group's own mean and standard deviation:
+    ``z = (x - mean_of_group) / (std_of_group + 1e-10)``.
+
+    No ``fit`` step is needed because no statistics are stored.
+
+    Attributes:
+        exclude_cols: Column names that are never standardized.
+        date_col: Name of the column that defines the groups.
+    """
+
+    name = "cs_standard_scaler"
+
+    def __init__(
+        self,
+        exclude_cols: Optional[List[str]] = None,
+        date_col: str = "trade_date",
+    ):
+        """Initialize the cross-sectional scaler.
+
+        Args:
+            exclude_cols: Column names that are never standardized.
+                Defaults to ``["ts_code", "trade_date"]``.
+            date_col: Name of the column that defines the groups.
+        """
+        self.exclude_cols = exclude_cols or ["ts_code", "trade_date"]
+        self.date_col = date_col
+
+    def transform(self, X: pl.DataFrame) -> pl.DataFrame:
+        """Standardize each numeric column within its ``date_col`` group.
+
+        ``date_col`` itself is never standardized, even when it is numeric
+        and not listed in ``exclude_cols``; it is the grouping key.
+
+        Args:
+            X: Data to transform.
+
+        Returns:
+            The standardized data.
+        """
+        skipped = {self.date_col, *self.exclude_cols}
+        numeric_cols = {
+            c for c in X.columns if c not in skipped and X[c].dtype.is_numeric()
+        }
+
+        expressions = []
+        for col in X.columns:
+            if col in numeric_cols:
+                # The small epsilon keeps the division finite when a group
+                # has zero standard deviation.
+                expr = (
+                    (pl.col(col) - pl.col(col).mean().over(self.date_col))
+                    / (pl.col(col).std().over(self.date_col) + 1e-10)
+                ).alias(col)
+                expressions.append(expr)
+            else:
+                expressions.append(pl.col(col))
+
+        return X.select(expressions)
+
+
+@register_processor("winsorizer")
+class Winsorizer(BaseProcessor):
+    """Winsorization processor.
+
+    Clips each numeric column to a lower and an upper quantile. The
+    quantiles are either learned once in ``fit`` (global mode) or computed
+    separately for each ``date_col`` group (by-date mode). The ``date_col``
+    column itself is never clipped.
+
+    Attributes:
+        lower: Lower quantile, e.g. 0.01 for the 1% quantile.
+        upper: Upper quantile, e.g. 0.99 for the 99% quantile.
+        by_date: True clips per ``date_col`` group, False clips globally.
+        date_col: Name of the date column.
+        bounds_: Learned bounds in global mode,
+            ``{column name: {"lower": value, "upper": value}}``.
+    """
+
+    name = "winsorizer"
+
+    def __init__(
+        self,
+        lower: float = 0.01,
+        upper: float = 0.99,
+        by_date: bool = False,
+        date_col: str = "trade_date",
+    ):
+        """Initialize the winsorizer.
+
+        Args:
+            lower: Lower quantile. Defaults to 0.01.
+            upper: Upper quantile. Defaults to 0.99.
+            by_date: Clip per ``date_col`` group instead of globally.
+                Defaults to False.
+            date_col: Name of the date column.
+
+        Raises:
+            ValueError: If ``0 <= lower < upper <= 1`` does not hold.
+        """
+        if not 0 <= lower < upper <= 1:
+            raise ValueError(
+                f"lower ({lower}) 必须小于 upper ({upper}),且都在 [0, 1] 范围内"
+            )
+
+        self.lower = lower
+        self.upper = upper
+        self.by_date = by_date
+        self.date_col = date_col
+        self.bounds_: dict = {}
+
+    def fit(self, X: pl.DataFrame) -> "Winsorizer":
+        """Learn the quantile bounds (global mode only).
+
+        In global mode the bounds of every numeric column except ``date_col``
+        are relearned from scratch; in by-date mode nothing is learned.
+
+        Args:
+            X: Training data.
+
+        Returns:
+            self
+        """
+        if not self.by_date:
+            self.bounds_ = {
+                col: {
+                    "lower": X[col].quantile(self.lower),
+                    "upper": X[col].quantile(self.upper),
+                }
+                for col in X.columns
+                if col != self.date_col and X[col].dtype.is_numeric()
+            }
+        return self
+
+    def transform(self, X: pl.DataFrame) -> pl.DataFrame:
+        """Winsorize ``X``.
+
+        Args:
+            X: Data to transform.
+
+        Returns:
+            The winsorized data.
+        """
+        if self.by_date:
+            return self._transform_by_date(X)
+        return self._transform_global(X)
+
+    def _transform_global(self, X: pl.DataFrame) -> pl.DataFrame:
+        """Clip each column to the bounds learned in ``fit``."""
+        expressions = []
+        for col in X.columns:
+            if col in self.bounds_:
+                lower = self.bounds_[col]["lower"]
+                upper = self.bounds_[col]["upper"]
+                expr = pl.col(col).clip(lower, upper).alias(col)
+                expressions.append(expr)
+            else:
+                expressions.append(pl.col(col))
+        return X.select(expressions)
+
+    def _transform_by_date(self, X: pl.DataFrame) -> pl.DataFrame:
+        """Clip each numeric column to its per-``date_col`` quantiles.
+
+        The quantiles are passed to ``clip`` as window expressions rather
+        than materialized as ``<col>_lower`` / ``<col>_upper`` helper
+        columns, so input columns that already carry such names are not
+        overwritten. ``date_col`` itself is left untouched.
+        """
+        expressions = []
+        for col in X.columns:
+            if col != self.date_col and X[col].dtype.is_numeric():
+                day_lower = pl.col(col).quantile(self.lower).over(self.date_col)
+                day_upper = pl.col(col).quantile(self.upper).over(self.date_col)
+                expr = pl.col(col).clip(day_lower, day_upper).alias(col)
+                expressions.append(expr)
+            else:
+                expressions.append(pl.col(col))
+        return X.select(expressions)
diff --git a/tests/training/test_processors.py b/tests/training/test_processors.py
new file mode 100644
index 0000000..4f7ce1d
--- /dev/null
+++ b/tests/training/test_processors.py
@@ -0,0 +1,300 @@
+"""Tests for the data processors.
+
+Verifies StandardScaler, CrossSectionalStandardScaler and Winsorizer.
+"""
+
+import numpy as np
+import polars as pl
+import pytest
+
+from src.training.components.processors import (
+    CrossSectionalStandardScaler,
+    StandardScaler,
+    Winsorizer,
+)
+
+
+class TestStandardScaler:
+    """Tests for StandardScaler."""
+
+    def test_init_default(self):
+        """Default initialization."""
+        scaler = StandardScaler()
+        assert scaler.exclude_cols == ["ts_code", "trade_date"]
+        assert scaler.mean_ == {}
+        assert scaler.std_ == {}
+
+    def test_init_custom_exclude(self):
+        """Custom excluded columns."""
+        scaler = StandardScaler(exclude_cols=["id", "date"])
+        assert scaler.exclude_cols == ["id", "date"]
+
+    def test_fit_transform(self):
+        """Fit followed by transform."""
+        data = pl.DataFrame(
+            {
+                "ts_code": ["A", "B", "C", "D"],
+                "trade_date": ["20240101"] * 4,
+                "value": [1.0, 2.0, 3.0, 4.0],
+            }
+        )
+
+        scaler = StandardScaler()
+        result = scaler.fit_transform(data)
+
+        # Check the learned statistics
+        assert scaler.mean_["value"] == 2.5
+        assert scaler.std_["value"] == pytest.approx(1.290, rel=1e-2)
+
+        # Check the transformed values
+        expected_std = (np.array([1.0, 2.0, 3.0, 4.0]) - 2.5) / 1.290
+        assert result["value"].to_list() == pytest.approx(
+            expected_std.tolist(), rel=1e-2
+        )
+
+    def test_transform_use_fitted_params(self):
+        """Transform uses the parameters learned during fit."""
+        train_data = pl.DataFrame(
+            {
+                "ts_code": ["A", "B", "C"],
+                "trade_date": ["20240101"] * 3,
+                "value": [1.0, 2.0, 3.0],
+            }
+        )
+
+        test_data = pl.DataFrame(
+            {
+                "ts_code": ["D"],
+                "trade_date": ["20240102"],
+                "value": [100.0],  # far from the training distribution
+            }
+        )
+
+        scaler = StandardScaler()
+        scaler.fit(train_data)
+
+        # Transform with the training mean (2.0) and standard deviation
+        result = scaler.transform(test_data)
+        expected = (100.0 - 2.0) / 1.0  # mean 2.0, std 1.0
+        assert result["value"][0] == pytest.approx(expected, rel=1e-2)
+
+    def test_exclude_non_numeric(self):
+        """Non-numeric columns are excluded automatically."""
+        data = pl.DataFrame(
+            {
+                "ts_code": ["A", "B"],
+                "trade_date": ["20240101", "20240102"],
+                "category": ["X", "Y"],  # string column
+                "value": [1.0, 2.0],
+            }
+        )
+
+        scaler = StandardScaler()
+        result = scaler.fit_transform(data)
+
+        # The category column must be kept as-is
+        assert result["category"].to_list() == ["X", "Y"]
+        # The value column must be standardized
+        assert "value" in scaler.mean_
+
+    def test_zero_std_handling(self):
+        """A standard deviation of zero is handled."""
+        data = pl.DataFrame(
+            {
+                "ts_code": ["A", "B"],
+                "trade_date": ["20240101", "20240102"],
+                "constant": [5.0, 5.0],  # constant column
+            }
+        )
+
+        scaler = StandardScaler()
+        result = scaler.fit_transform(data)
+
+        # With std 0 the result must be 0 (no division by zero)
+        assert result["constant"].to_list() == [0.0, 0.0]
+
+
+class TestCrossSectionalStandardScaler:
+    """Tests for CrossSectionalStandardScaler."""
+
+    def test_init_default(self):
+        """Default initialization."""
+        scaler = CrossSectionalStandardScaler()
+        assert scaler.exclude_cols == ["ts_code", "trade_date"]
+        assert scaler.date_col == "trade_date"
+
+    def test_init_custom(self):
+        """Custom parameters."""
+        scaler = CrossSectionalStandardScaler(
+            exclude_cols=["id"],
+            date_col="date",
+        )
+        assert scaler.exclude_cols == ["id"]
+        assert scaler.date_col == "date"
+
+    def test_transform_no_fit_needed(self):
+        """No fit is required."""
+        data = pl.DataFrame(
+            {
+                "ts_code": ["A", "B"],
+                "trade_date": ["20240101", "20240101"],
+                "value": [1.0, 3.0],
+            }
+        )
+
+        scaler = CrossSectionalStandardScaler()
+        # Cross-sectional standardization does not need fit
+        result = scaler.transform(data)
+
+        # Day mean=2.0, sample std=sqrt(2)≈1.414, z-score=[-0.707, 0.707]
+        assert result["value"].to_list() == pytest.approx([-0.707, 0.707], rel=1e-2)
+
+    def test_transform_by_date(self):
+        """Standardization is grouped by date."""
+        data = pl.DataFrame(
+            {
+                "ts_code": ["A", "B", "C", "D"],
+                "trade_date": ["20240101", "20240101", "20240102", "20240102"],
+                "value": [1.0, 3.0, 10.0, 30.0],
+            }
+        )
+
+        scaler = CrossSectionalStandardScaler()
+        result = scaler.transform(data)
+
+        # 2024-01-01: mean=2.0, sample std≈1.414 -> [-0.707, 0.707]
+        # 2024-01-02: mean=20.0, sample std≈14.14 -> [-0.707, 0.707]
+        values = result["value"].to_list()
+        assert values[0] == pytest.approx(-0.707, abs=1e-2)
+        assert values[1] == pytest.approx(0.707, abs=1e-2)
+        assert values[2] == pytest.approx(-0.707, abs=1e-2)
+        assert values[3] == pytest.approx(0.707, abs=1e-2)
+
+    def test_exclude_columns_preserved(self):
+        """Excluded columns are kept as-is."""
+        data = pl.DataFrame(
+            {
+                "ts_code": ["A", "B"],
+                "trade_date": ["20240101", "20240101"],
+                "value": [1.0, 3.0],
+            }
+        )
+
+        scaler = CrossSectionalStandardScaler()
+        result = scaler.transform(data)
+
+        assert result["ts_code"].to_list() == ["A", "B"]
+        assert result["trade_date"].to_list() == ["20240101", "20240101"]
+
+
+class TestWinsorizer:
+    """Tests for Winsorizer."""
+
+    def test_init_default(self):
+        """Default initialization."""
+        winsorizer = Winsorizer()
+        assert winsorizer.lower == 0.01
+        assert winsorizer.upper == 0.99
+        assert winsorizer.by_date is False
+        assert winsorizer.date_col == "trade_date"
+
+    def test_init_custom(self):
+        """Custom parameters."""
+        winsorizer = Winsorizer(lower=0.05, upper=0.95, by_date=True, date_col="date")
+        assert winsorizer.lower == 0.05
+        assert winsorizer.upper == 0.95
+        assert winsorizer.by_date is True
+        assert winsorizer.date_col == "date"
+
+    def test_invalid_quantiles(self):
+        """Invalid quantile parameters."""
+        with pytest.raises(ValueError, match="lower .* 必须小于 upper"):
+            Winsorizer(lower=0.5, upper=0.3)
+
+        with pytest.raises(ValueError, match="lower .* 必须小于 upper"):
+            Winsorizer(lower=-0.1, upper=0.5)
+
+        with pytest.raises(ValueError, match="lower .* 必须小于 upper"):
+            Winsorizer(lower=0.5, upper=1.5)
+
+    def test_global_winsorize(self):
+        """Global winsorization."""
+        # Build data that contains extreme values
+        values = list(range(1, 101))  # 1-100
+        values[0] = -1000  # extreme low value
+        values[-1] = 1000  # extreme high value
+
+        data = pl.DataFrame(
+            {
+                "ts_code": [f"A{i}" for i in range(100)],
+                "trade_date": ["20240101"] * 100,
+                "value": values,
+            }
+        )
+
+        winsorizer = Winsorizer(lower=0.01, upper=0.99)
+        result = winsorizer.fit_transform(data)
+
+        # 1% quantile=2, 99% quantile=99
+        # -1000 must be clipped to 2
+        # 1000 must be clipped to 99
+        result_values = result["value"].to_list()
+        assert result_values[0] == 2  # original -1000 was clipped
+        assert result_values[-1] == 99  # original 1000 was clipped
+        assert result_values[1] == 2  # original 2 is unchanged
+        assert result_values[98] == 99  # original 99 is unchanged
+
+    def test_by_date_winsorize(self):
+        """Independent winsorization per day."""
+        data = pl.DataFrame(
+            {
+                "ts_code": ["A", "B", "C", "D", "E", "F"],
+                "trade_date": ["20240101"] * 3 + ["20240102"] * 3,
+                "value": [1.0, 50.0, 100.0, 200.0, 250.0, 300.0],
+            }
+        )
+
+        winsorizer = Winsorizer(lower=0.0, upper=0.5, by_date=True)
+        result = winsorizer.transform(data)
+
+        # Each day is processed independently:
+        # 2024-01-01: [1, 50, 100], 50% quantile=50
+        #   -> clipped to [1, 50, 50]
+        # 2024-01-02: [200, 250, 300], 50% quantile=250
+        #   -> clipped to [200, 250, 250]
+        result_values = result["value"].to_list()
+        assert result_values[0] == 1.0
+        assert result_values[1] == 50.0
+        assert result_values[2] == 50.0  # clipped
+        assert result_values[3] == 200.0
+        assert result_values[4] == 250.0
+        assert result_values[5] == 250.0  # clipped
+
+    def test_global_transform_after_fit(self):
+        """In global mode, transform uses the bounds learned during fit."""
+        train_data = pl.DataFrame(
+            {
+                "ts_code": ["A", "B", "C"],
+                "trade_date": ["20240101"] * 3,
+                "value": [1.0, 50.0, 100.0],
+            }
+        )
+
+        test_data = pl.DataFrame(
+            {
+                "ts_code": ["D"],
+                "trade_date": ["20240102"],
+                "value": [200.0],
+            }
+        )
+
+        winsorizer = Winsorizer(lower=0.0, upper=1.0)  # 0% and 100% quantiles
+        winsorizer.fit(train_data)
+
+        # Uses the training-set quantile bounds [1, 100]
+        result = winsorizer.transform(test_data)
+        assert result["value"][0] == 100.0  # clipped to 100
+
+
+if __name__ == "__main__":
+    pytest.main([__file__, "-v"])