Files
ProStock/tests/training/test_processors.py
liaozhaorun 9ca1deae56 feat(training): 实现数据处理器
- 新增 StandardScaler:全局标准化,训练集学习参数,测试集复用
- 新增 CrossSectionalStandardScaler:截面标准化,每天独立计算
- 新增 Winsorizer:支持全局/截面两种缩尾模式
- 处理器统一遵循 fit/transform 接口,Trainer 可无差别调用
- 添加 17 个单元测试覆盖各种场景
2026-03-03 22:23:43 +08:00

301 lines
9.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""测试数据处理器
验证 StandardScaler、CrossSectionalStandardScaler 和 Winsorizer 功能。
"""
import numpy as np
import polars as pl
import pytest
from src.training.components.processors import (
CrossSectionalStandardScaler,
StandardScaler,
Winsorizer,
)
class TestStandardScaler:
    """Tests for StandardScaler: global standardization with fitted mean/std."""

    def test_init_default(self):
        """Default construction excludes the key columns and has no fitted stats."""
        scaler = StandardScaler()
        assert scaler.exclude_cols == ["ts_code", "trade_date"]
        assert scaler.mean_ == {}
        assert scaler.std_ == {}

    def test_init_custom_exclude(self):
        """A caller-supplied exclude_cols list is stored unchanged."""
        scaler = StandardScaler(exclude_cols=["id", "date"])
        assert scaler.exclude_cols == ["id", "date"]

    def test_fit_transform(self):
        """fit_transform learns per-column mean/std and returns z-scores."""
        raw = np.array([1.0, 2.0, 3.0, 4.0])
        data = pl.DataFrame(
            {
                "ts_code": ["A", "B", "C", "D"],
                "trade_date": ["20240101"] * 4,
                "value": raw.tolist(),
            }
        )
        scaler = StandardScaler()
        result = scaler.fit_transform(data)

        # Derive the expectations from the data rather than from a rounded
        # constant. The scaler is expected to use the sample std (ddof=1),
        # i.e. sqrt(5/3) ≈ 1.291 for this input.
        expected_mean = raw.mean()
        expected_sd = raw.std(ddof=1)
        assert scaler.mean_["value"] == pytest.approx(expected_mean)
        assert scaler.std_["value"] == pytest.approx(expected_sd)

        # The transformed column holds z-scores computed with those statistics.
        expected_z = (raw - expected_mean) / expected_sd
        assert result["value"].to_list() == pytest.approx(expected_z.tolist())

    def test_transform_use_fitted_params(self):
        """transform reuses the statistics learned in fit, not the new data's."""
        train_data = pl.DataFrame(
            {
                "ts_code": ["A", "B", "C"],
                "trade_date": ["20240101"] * 3,
                "value": [1.0, 2.0, 3.0],
            }
        )
        test_data = pl.DataFrame(
            {
                "ts_code": ["D"],
                "trade_date": ["20240102"],
                "value": [100.0],  # far outside the training distribution
            }
        )
        scaler = StandardScaler()
        scaler.fit(train_data)

        # Training statistics: mean 2.0, sample std 1.0.
        result = scaler.transform(test_data)
        expected = (100.0 - 2.0) / 1.0
        assert result["value"][0] == pytest.approx(expected)

    def test_exclude_non_numeric(self):
        """Non-numeric columns are skipped automatically and passed through."""
        data = pl.DataFrame(
            {
                "ts_code": ["A", "B"],
                "trade_date": ["20240101", "20240102"],
                "category": ["X", "Y"],  # string column
                "value": [1.0, 2.0],
            }
        )
        scaler = StandardScaler()
        result = scaler.fit_transform(data)

        # The string column is returned untouched and never gets fitted stats.
        assert result["category"].to_list() == ["X", "Y"]
        assert "category" not in scaler.mean_

        # The numeric column is fitted and standardized:
        # mean 1.5, sample std sqrt(0.5), so z = [-sqrt(0.5), +sqrt(0.5)].
        assert "value" in scaler.mean_
        half_root = np.sqrt(0.5)
        assert result["value"].to_list() == pytest.approx([-half_root, half_root])

    def test_zero_std_handling(self):
        """A constant column (std == 0) maps to 0.0 instead of dividing by zero."""
        data = pl.DataFrame(
            {
                "ts_code": ["A", "B"],
                "trade_date": ["20240101", "20240102"],
                "constant": [5.0, 5.0],  # constant column
            }
        )
        scaler = StandardScaler()
        result = scaler.fit_transform(data)
        assert result["constant"].to_list() == [0.0, 0.0]
class TestCrossSectionalStandardScaler:
    """Tests for CrossSectionalStandardScaler: z-scores computed per date."""

    @staticmethod
    def _one_day_frame():
        """Build a frame with two stocks on one date, valued 1.0 and 3.0."""
        return pl.DataFrame(
            {
                "ts_code": ["A", "B"],
                "trade_date": ["20240101", "20240101"],
                "value": [1.0, 3.0],
            }
        )

    def test_init_default(self):
        """Defaults exclude the key columns and group on trade_date."""
        css = CrossSectionalStandardScaler()
        assert css.exclude_cols == ["ts_code", "trade_date"]
        assert css.date_col == "trade_date"

    def test_init_custom(self):
        """Custom exclude_cols and date_col are stored as given."""
        css = CrossSectionalStandardScaler(exclude_cols=["id"], date_col="date")
        assert css.exclude_cols == ["id"]
        assert css.date_col == "date"

    def test_transform_no_fit_needed(self):
        """transform works on a fresh instance; no fit call is required."""
        out = CrossSectionalStandardScaler().transform(self._one_day_frame())
        # Day mean 2.0, sample std sqrt(2) ≈ 1.414, so z ≈ [-0.707, 0.707].
        assert out["value"].to_list() == pytest.approx([-0.707, 0.707], rel=1e-2)

    def test_transform_by_date(self):
        """Each date is standardized with its own mean and std."""
        frame = pl.DataFrame(
            {
                "ts_code": ["A", "B", "C", "D"],
                "trade_date": ["20240101", "20240101", "20240102", "20240102"],
                "value": [1.0, 3.0, 10.0, 30.0],
            }
        )
        observed = CrossSectionalStandardScaler().transform(frame)["value"].to_list()
        # 20240101: mean 2.0, sample std ≈ 1.414.
        # 20240102: mean 20.0, sample std ≈ 14.14.
        # Both days therefore map to z ≈ [-0.707, 0.707].
        wanted = [-0.707, 0.707, -0.707, 0.707]
        for pos, target in enumerate(wanted):
            assert observed[pos] == pytest.approx(target, abs=1e-2)

    def test_exclude_columns_preserved(self):
        """Excluded key columns come back unchanged."""
        out = CrossSectionalStandardScaler().transform(self._one_day_frame())
        assert out["ts_code"].to_list() == ["A", "B"]
        assert out["trade_date"].to_list() == ["20240101", "20240101"]
class TestWinsorizer:
    """Tests for Winsorizer: quantile clipping, globally or per date."""

    def test_init_default(self):
        """Defaults: 1%/99% bounds, global mode, trade_date as the date column."""
        winsorizer = Winsorizer()
        assert winsorizer.lower == 0.01
        assert winsorizer.upper == 0.99
        assert winsorizer.by_date is False
        assert winsorizer.date_col == "trade_date"

    def test_init_custom(self):
        """All constructor arguments are stored as given."""
        winsorizer = Winsorizer(lower=0.05, upper=0.95, by_date=True, date_col="date")
        assert winsorizer.lower == 0.05
        assert winsorizer.upper == 0.95
        assert winsorizer.by_date is True
        assert winsorizer.date_col == "date"

    def test_invalid_quantiles(self):
        """Inverted or out-of-range quantile bounds raise ValueError."""
        # lower > upper
        with pytest.raises(ValueError, match="lower .* 必须小于 upper"):
            Winsorizer(lower=0.5, upper=0.3)
        # lower below 0
        with pytest.raises(ValueError, match="lower .* 必须小于 upper"):
            Winsorizer(lower=-0.1, upper=0.5)
        # upper above 1
        with pytest.raises(ValueError, match="lower .* 必须小于 upper"):
            Winsorizer(lower=0.5, upper=1.5)

    def test_global_winsorize(self):
        """Global mode clips every value to the fitted [1%, 99%] quantiles."""
        # 1..100 with both ends replaced by extreme outliers.
        values = list(range(1, 101))
        values[0] = -1000
        values[-1] = 1000
        data = pl.DataFrame(
            {
                "ts_code": [f"A{i}" for i in range(100)],
                "trade_date": ["20240101"] * 100,
                "value": values,
            }
        )
        winsorizer = Winsorizer(lower=0.01, upper=0.99)
        result = winsorizer.fit_transform(data)

        # Expected bounds: 1% quantile = 2, 99% quantile = 99.
        # NOTE(review): these bounds assume a nearest-rank style quantile;
        # linear interpolation would instead give about -8.02 and 108.01.
        # Confirm this matches the implementation.
        # Only the two outliers move; every other value already lies in [2, 99].
        expected = [2] + list(range(2, 100)) + [99]
        assert result["value"].to_list() == expected

    def test_by_date_winsorize(self):
        """by_date mode computes the clipping bounds separately for each date."""
        data = pl.DataFrame(
            {
                "ts_code": ["A", "B", "C", "D", "E", "F"],
                "trade_date": ["20240101"] * 3 + ["20240102"] * 3,
                "value": [1.0, 50.0, 100.0, 200.0, 250.0, 300.0],
            }
        )
        winsorizer = Winsorizer(lower=0.0, upper=0.5, by_date=True)
        result = winsorizer.transform(data)

        # 20240101: [1, 50, 100],   50% quantile = 50  -> [1, 50, 50]
        # 20240102: [200, 250, 300], 50% quantile = 250 -> [200, 250, 250]
        assert result["value"].to_list() == [1.0, 50.0, 50.0, 200.0, 250.0, 250.0]

    def test_global_transform_after_fit(self):
        """Global mode transforms new data with the bounds learned in fit."""
        train_data = pl.DataFrame(
            {
                "ts_code": ["A", "B", "C"],
                "trade_date": ["20240101"] * 3,
                "value": [1.0, 50.0, 100.0],
            }
        )
        # One value above, one below and one inside the training range.
        test_data = pl.DataFrame(
            {
                "ts_code": ["D", "E", "F"],
                "trade_date": ["20240102"] * 3,
                "value": [200.0, -50.0, 60.0],
            }
        )
        winsorizer = Winsorizer(lower=0.0, upper=1.0)  # 0% and 100% quantiles
        winsorizer.fit(train_data)

        # The fitted bounds are the training min/max, [1, 100]. Both ends must
        # be enforced, and in-range values must pass through unchanged.
        result = winsorizer.transform(test_data)
        assert result["value"].to_list() == [100.0, 1.0, 60.0]
if __name__ == "__main__":
    # pytest.main returns an exit code; propagate it so that running this
    # file directly reports test failures to the shell / CI.
    raise SystemExit(pytest.main([__file__, "-v"]))