Files
ProStock/tests/training/test_processors.py

301 lines
9.6 KiB
Python
Raw Normal View History

"""测试数据处理器
Verifies the behaviour of StandardScaler, CrossSectionalStandardScaler, and Winsorizer.
"""
import numpy as np
import polars as pl
import pytest
from src.training.components.processors import (
CrossSectionalStandardScaler,
StandardScaler,
Winsorizer,
)
class TestStandardScaler:
    """Test suite for StandardScaler."""

    def test_init_default(self):
        """A default-constructed scaler has the default exclusions and empty statistics."""
        sut = StandardScaler()

        assert sut.exclude_cols == ["ts_code", "trade_date"]
        assert sut.mean_ == {}
        assert sut.std_ == {}

    def test_init_custom_exclude(self):
        """A caller-supplied exclusion list is exposed unchanged."""
        sut = StandardScaler(exclude_cols=["id", "date"])

        assert sut.exclude_cols == ["id", "date"]

    def test_fit_transform(self):
        """fit_transform learns mean/std and returns the standardized column."""
        raw = [1.0, 2.0, 3.0, 4.0]
        frame = pl.DataFrame(
            {
                "ts_code": ["A", "B", "C", "D"],
                "trade_date": ["20240101"] * 4,
                "value": raw,
            }
        )

        sut = StandardScaler()
        out = sut.fit_transform(frame)

        # Learned statistics.
        assert sut.mean_["value"] == 2.5
        assert sut.std_["value"] == pytest.approx(1.290, rel=1e-2)

        # Transformed values, checked against z-scores built from the same
        # rounded constants used above.
        z_scores = ((np.array(raw) - 2.5) / 1.290).tolist()
        assert out["value"].to_list() == pytest.approx(z_scores, rel=1e-2)

    def test_transform_use_fitted_params(self):
        """transform applies the parameters learned during fit, not those of the new data."""
        train = pl.DataFrame(
            {
                "ts_code": ["A", "B", "C"],
                "trade_date": ["20240101"] * 3,
                "value": [1.0, 2.0, 3.0],
            }
        )
        unseen = pl.DataFrame(
            {
                "ts_code": ["D"],
                "trade_date": ["20240102"],
                "value": [100.0],  # far outside the training distribution
            }
        )

        sut = StandardScaler()
        sut.fit(train)
        out = sut.transform(unseen)

        # The expectation is built from the training mean (2.0) and std (1.0).
        want = (100.0 - 2.0) / 1.0
        assert out["value"][0] == pytest.approx(want, rel=1e-2)

    def test_exclude_non_numeric(self):
        """String columns pass through untouched while numeric ones are fitted."""
        frame = pl.DataFrame(
            {
                "ts_code": ["A", "B"],
                "trade_date": ["20240101", "20240102"],
                "category": ["X", "Y"],  # string column
                "value": [1.0, 2.0],
            }
        )

        sut = StandardScaler()
        out = sut.fit_transform(frame)

        # The category column must come back as-is.
        assert out["category"].to_list() == ["X", "Y"]
        # The value column must have been picked up for standardization.
        assert "value" in sut.mean_

    def test_zero_std_handling(self):
        """A constant column (std == 0) is mapped to zeros."""
        frame = pl.DataFrame(
            {
                "ts_code": ["A", "B"],
                "trade_date": ["20240101", "20240102"],
                "constant": [5.0, 5.0],  # constant column
            }
        )

        sut = StandardScaler()
        out = sut.fit_transform(frame)

        # With zero std the result is expected to be 0 (no division by zero).
        assert out["constant"].to_list() == [0.0, 0.0]
class TestCrossSectionalStandardScaler:
    """Test suite for CrossSectionalStandardScaler."""

    def test_init_default(self):
        """Default construction exposes the default exclusions and date column."""
        sut = CrossSectionalStandardScaler()

        assert sut.exclude_cols == ["ts_code", "trade_date"]
        assert sut.date_col == "trade_date"

    def test_init_custom(self):
        """Custom constructor arguments are exposed unchanged."""
        sut = CrossSectionalStandardScaler(
            exclude_cols=["id"],
            date_col="date",
        )

        assert sut.exclude_cols == ["id"]
        assert sut.date_col == "date"

    def test_transform_no_fit_needed(self):
        """transform can be called directly, without a prior fit."""
        frame = pl.DataFrame(
            {
                "ts_code": ["A", "B"],
                "trade_date": ["20240101", "20240101"],
                "value": [1.0, 3.0],
            }
        )

        sut = CrossSectionalStandardScaler()
        # No fit call: transform is invoked on a fresh instance.
        out = sut.transform(frame)

        # Same-day mean = 2.0, sample std = sqrt(2) ~ 1.414, so z = [-0.707, 0.707].
        assert out["value"].to_list() == pytest.approx([-0.707, 0.707], rel=1e-2)

    def test_transform_by_date(self):
        """Each trade date is standardized independently of the others."""
        frame = pl.DataFrame(
            {
                "ts_code": ["A", "B", "C", "D"],
                "trade_date": ["20240101", "20240101", "20240102", "20240102"],
                "value": [1.0, 3.0, 10.0, 30.0],
            }
        )

        sut = CrossSectionalStandardScaler()
        out = sut.transform(frame)

        # 20240101: mean = 2.0,  sample std ~ 1.414 -> [-0.707, 0.707]
        # 20240102: mean = 20.0, sample std ~ 14.14 -> [-0.707, 0.707]
        z_values = out["value"].to_list()
        wanted = (-0.707, 0.707, -0.707, 0.707)
        for pos, want in enumerate(wanted):
            assert z_values[pos] == pytest.approx(want, abs=1e-2)

    def test_exclude_columns_preserved(self):
        """Excluded columns are returned with their original contents."""
        frame = pl.DataFrame(
            {
                "ts_code": ["A", "B"],
                "trade_date": ["20240101", "20240101"],
                "value": [1.0, 3.0],
            }
        )

        sut = CrossSectionalStandardScaler()
        out = sut.transform(frame)

        assert out["ts_code"].to_list() == ["A", "B"]
        assert out["trade_date"].to_list() == ["20240101", "20240101"]
class TestWinsorizer:
    """Test suite for Winsorizer."""

    def test_init_default(self):
        """Default construction exposes the default quantiles and grouping settings."""
        sut = Winsorizer()

        assert sut.lower == 0.01
        assert sut.upper == 0.99
        assert sut.by_date is False
        assert sut.date_col == "trade_date"

    def test_init_custom(self):
        """Custom constructor arguments are exposed unchanged."""
        sut = Winsorizer(lower=0.05, upper=0.95, by_date=True, date_col="date")

        assert sut.lower == 0.05
        assert sut.upper == 0.95
        assert sut.by_date is True
        assert sut.date_col == "date"

    def test_invalid_quantiles(self):
        """Invalid quantile pairs are rejected with a ValueError."""
        message_pattern = "lower .* 必须小于 upper"
        rejected_pairs = (
            (0.5, 0.3),
            (-0.1, 0.5),
            (0.5, 1.5),
        )
        for bad_lower, bad_upper in rejected_pairs:
            with pytest.raises(ValueError, match=message_pattern):
                Winsorizer(lower=bad_lower, upper=bad_upper)

    def test_global_winsorize(self):
        """Global mode clips extreme values to the fitted quantile bounds."""
        # 2..99 with an extreme low value in front and an extreme high value at the end.
        series = [-1000, *range(2, 100), 1000]
        frame = pl.DataFrame(
            {
                "ts_code": [f"A{i}" for i in range(100)],
                "trade_date": ["20240101"] * 100,
                "value": series,
            }
        )

        sut = Winsorizer(lower=0.01, upper=0.99)
        out = sut.fit_transform(frame)

        # Expected bounds: 1% quantile = 2, 99% quantile = 99, so
        # -1000 is clipped up to 2 and 1000 is clipped down to 99.
        clipped = out["value"].to_list()
        checks = (
            (0, 2),  # was -1000, clipped
            (-1, 99),  # was 1000, clipped
            (1, 2),  # was 2, unchanged
            (98, 99),  # was 99, unchanged
        )
        for pos, want in checks:
            assert clipped[pos] == want

    def test_by_date_winsorize(self):
        """With by_date=True every trade date is clipped on its own."""
        frame = pl.DataFrame(
            {
                "ts_code": ["A", "B", "C", "D", "E", "F"],
                "trade_date": ["20240101"] * 3 + ["20240102"] * 3,
                "value": [1.0, 50.0, 100.0, 200.0, 250.0, 300.0],
            }
        )

        sut = Winsorizer(lower=0.0, upper=0.5, by_date=True)
        out = sut.transform(frame)

        # Per-day processing:
        #   20240101: [1, 50, 100],    50% quantile = 50  -> [1, 50, 50]
        #   20240102: [200, 250, 300], 50% quantile = 250 -> [200, 250, 250]
        clipped = out["value"].to_list()
        wanted = (1.0, 50.0, 50.0, 200.0, 250.0, 250.0)
        for pos, want in enumerate(wanted):
            assert clipped[pos] == want

    def test_global_transform_after_fit(self):
        """In global mode transform reuses the bounds learned during fit."""
        train = pl.DataFrame(
            {
                "ts_code": ["A", "B", "C"],
                "trade_date": ["20240101"] * 3,
                "value": [1.0, 50.0, 100.0],
            }
        )
        unseen = pl.DataFrame(
            {
                "ts_code": ["D"],
                "trade_date": ["20240102"],
                "value": [200.0],
            }
        )

        # 0% and 100% quantiles.
        sut = Winsorizer(lower=0.0, upper=1.0)
        sut.fit(train)

        # The training bounds [1, 100] are expected to be applied to the new data.
        out = sut.transform(unseen)
        assert out["value"][0] == 100.0  # clipped down to 100
if __name__ == "__main__":
    # Run this module's tests verbosely when executed as a script.
    # pytest.main() returns the exit code instead of exiting; hand it to
    # SystemExit so a failing run yields a non-zero process status rather
    # than the return value being silently discarded.
    raise SystemExit(pytest.main([__file__, "-v"]))