"""测试数据处理器 验证 StandardScaler、CrossSectionalStandardScaler 和 Winsorizer 功能。 """ import numpy as np import polars as pl import pytest from src.training.components.processors import ( CrossSectionalStandardScaler, StandardScaler, Winsorizer, ) class TestStandardScaler: """StandardScaler 测试类""" def test_init_default(self): """测试默认初始化""" scaler = StandardScaler() assert scaler.exclude_cols == ["ts_code", "trade_date"] assert scaler.mean_ == {} assert scaler.std_ == {} def test_init_custom_exclude(self): """测试自定义排除列""" scaler = StandardScaler(exclude_cols=["id", "date"]) assert scaler.exclude_cols == ["id", "date"] def test_fit_transform(self): """测试拟合和转换""" data = pl.DataFrame( { "ts_code": ["A", "B", "C", "D"], "trade_date": ["20240101"] * 4, "value": [1.0, 2.0, 3.0, 4.0], } ) scaler = StandardScaler() result = scaler.fit_transform(data) # 验证学习到的统计量 assert scaler.mean_["value"] == 2.5 assert scaler.std_["value"] == pytest.approx(1.290, rel=1e-2) # 验证转换结果 expected_std = (np.array([1.0, 2.0, 3.0, 4.0]) - 2.5) / 1.290 assert result["value"].to_list() == pytest.approx( expected_std.tolist(), rel=1e-2 ) def test_transform_use_fitted_params(self): """测试转换使用拟合时的参数""" train_data = pl.DataFrame( { "ts_code": ["A", "B", "C"], "trade_date": ["20240101"] * 3, "value": [1.0, 2.0, 3.0], } ) test_data = pl.DataFrame( { "ts_code": ["D"], "trade_date": ["20240102"], "value": [100.0], # 远离训练分布 } ) scaler = StandardScaler() scaler.fit(train_data) # 使用训练集的均值(2.0)和标准差进行转换 result = scaler.transform(test_data) expected = (100.0 - 2.0) / 1.0 # 均值2.0, 标准差1.0 assert result["value"][0] == pytest.approx(expected, rel=1e-2) def test_exclude_non_numeric(self): """测试自动排除非数值列""" data = pl.DataFrame( { "ts_code": ["A", "B"], "trade_date": ["20240101", "20240102"], "category": ["X", "Y"], # 字符串列 "value": [1.0, 2.0], } ) scaler = StandardScaler() result = scaler.fit_transform(data) # category 列应该原样保留 assert result["category"].to_list() == ["X", "Y"] # value 列应该被标准化 assert "value" in scaler.mean_ def test_zero_std_handling(self): """测试处理标准差为0的情况""" data = pl.DataFrame( { "ts_code": ["A", "B"], "trade_date": ["20240101", "20240102"], "constant": [5.0, 5.0], # 常数列 } ) scaler = StandardScaler() result = scaler.fit_transform(data) # 标准差为0时,结果应该为0(避免除以0) assert result["constant"].to_list() == [0.0, 0.0] class TestCrossSectionalStandardScaler: """CrossSectionalStandardScaler 测试类""" def test_init_default(self): """测试默认初始化""" scaler = CrossSectionalStandardScaler() assert scaler.exclude_cols == ["ts_code", "trade_date"] assert scaler.date_col == "trade_date" def test_init_custom(self): """测试自定义参数""" scaler = CrossSectionalStandardScaler( exclude_cols=["id"], date_col="date", ) assert scaler.exclude_cols == ["id"] assert scaler.date_col == "date" def test_transform_no_fit_needed(self): """测试不需要 fit""" data = pl.DataFrame( { "ts_code": ["A", "B"], "trade_date": ["20240101", "20240101"], "value": [1.0, 3.0], } ) scaler = CrossSectionalStandardScaler() # 截面标准化不需要 fit result = scaler.transform(data) # 当天均值=2.0, 样本标准差=sqrt(2)≈1.414, z-score=[-0.707, 0.707] assert result["value"].to_list() == pytest.approx([-0.707, 0.707], rel=1e-2) def test_transform_by_date(self): """测试按日期分组标准化""" data = pl.DataFrame( { "ts_code": ["A", "B", "C", "D"], "trade_date": ["20240101", "20240101", "20240102", "20240102"], "value": [1.0, 3.0, 10.0, 30.0], } ) scaler = CrossSectionalStandardScaler() result = scaler.transform(data) # 2024-01-01: 均值=2.0, 样本std≈1.414 -> [-0.707, 0.707] # 2024-01-02: 均值=20.0, 样本std≈14.14 -> [-0.707, 0.707] values = result["value"].to_list() assert values[0] == pytest.approx(-0.707, abs=1e-2) assert values[1] == pytest.approx(0.707, abs=1e-2) assert values[2] == pytest.approx(-0.707, abs=1e-2) assert values[3] == pytest.approx(0.707, abs=1e-2) def test_exclude_columns_preserved(self): """测试排除列保持原样""" data = pl.DataFrame( { "ts_code": ["A", "B"], "trade_date": ["20240101", "20240101"], "value": [1.0, 3.0], } ) scaler = CrossSectionalStandardScaler() result = scaler.transform(data) assert result["ts_code"].to_list() == ["A", "B"] assert result["trade_date"].to_list() == ["20240101", "20240101"] class TestWinsorizer: """Winsorizer 测试类""" def test_init_default(self): """测试默认初始化""" winsorizer = Winsorizer() assert winsorizer.lower == 0.01 assert winsorizer.upper == 0.99 assert winsorizer.by_date is False assert winsorizer.date_col == "trade_date" def test_init_custom(self): """测试自定义参数""" winsorizer = Winsorizer(lower=0.05, upper=0.95, by_date=True, date_col="date") assert winsorizer.lower == 0.05 assert winsorizer.upper == 0.95 assert winsorizer.by_date is True assert winsorizer.date_col == "date" def test_invalid_quantiles(self): """测试无效的分位数参数""" with pytest.raises(ValueError, match="lower .* 必须小于 upper"): Winsorizer(lower=0.5, upper=0.3) with pytest.raises(ValueError, match="lower .* 必须小于 upper"): Winsorizer(lower=-0.1, upper=0.5) with pytest.raises(ValueError, match="lower .* 必须小于 upper"): Winsorizer(lower=0.5, upper=1.5) def test_global_winsorize(self): """测试全局缩尾""" # 创建包含极端值的数据 values = list(range(1, 101)) # 1-100 values[0] = -1000 # 极端小值 values[-1] = 1000 # 极端大值 data = pl.DataFrame( { "ts_code": [f"A{i}" for i in range(100)], "trade_date": ["20240101"] * 100, "value": values, } ) winsorizer = Winsorizer(lower=0.01, upper=0.99) result = winsorizer.fit_transform(data) # 1%分位数=2, 99%分位数=99 # -1000 应该被截断为 2 # 1000 应该被截断为 99 result_values = result["value"].to_list() assert result_values[0] == 2 # 原-1000被截断 assert result_values[-1] == 99 # 原1000被截断 assert result_values[1] == 2 # 原2保持不变 assert result_values[98] == 99 # 原99保持不变 def test_by_date_winsorize(self): """测试每日独立缩尾""" data = pl.DataFrame( { "ts_code": ["A", "B", "C", "D", "E", "F"], "trade_date": ["20240101"] * 3 + ["20240102"] * 3, "value": [1.0, 50.0, 100.0, 200.0, 250.0, 300.0], } ) winsorizer = Winsorizer(lower=0.0, upper=0.5, by_date=True) result = winsorizer.transform(data) # 每天独立处理: # 2024-01-01: [1, 50, 100], 50%分位数=50 # -> 截断为 [1, 50, 50] # 2024-01-02: [200, 250, 300], 50%分位数=250 # -> 截断为 [200, 250, 250] result_values = result["value"].to_list() assert result_values[0] == 1.0 assert result_values[1] == 50.0 assert result_values[2] == 50.0 # 被截断 assert result_values[3] == 200.0 assert result_values[4] == 250.0 assert result_values[5] == 250.0 # 被截断 def test_global_transform_after_fit(self): """测试全局模式下,转换使用拟合时的边界""" train_data = pl.DataFrame( { "ts_code": ["A", "B", "C"], "trade_date": ["20240101"] * 3, "value": [1.0, 50.0, 100.0], } ) test_data = pl.DataFrame( { "ts_code": ["D"], "trade_date": ["20240102"], "value": [200.0], } ) winsorizer = Winsorizer(lower=0.0, upper=1.0) # 0%和100%分位数 winsorizer.fit(train_data) # 使用训练集的分位数边界 [1, 100] result = winsorizer.transform(test_data) assert result["value"][0] == 100.0 # 被截断为100 if __name__ == "__main__": pytest.main([__file__, "-v"])