feat(training): 实现数据处理器
- 新增 StandardScaler:全局标准化,训练集学习参数,测试集复用 - 新增 CrossSectionalStandardScaler:截面标准化,每天独立计算 - 新增 Winsorizer:支持全局/截面两种缩尾模式 - 处理器统一遵循 fit/transform 接口,Trainer 可无差别调用 - 添加 17 个单元测试覆盖各种场景
This commit is contained in:
@@ -15,10 +15,20 @@ from src.training.components.selectors import (
|
|||||||
StockFilterConfig,
|
StockFilterConfig,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 数据处理器
|
||||||
|
from src.training.components.processors import (
|
||||||
|
CrossSectionalStandardScaler,
|
||||||
|
StandardScaler,
|
||||||
|
Winsorizer,
|
||||||
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"BaseModel",
|
"BaseModel",
|
||||||
"BaseProcessor",
|
"BaseProcessor",
|
||||||
"DateSplitter",
|
"DateSplitter",
|
||||||
"StockFilterConfig",
|
"StockFilterConfig",
|
||||||
"MarketCapSelectorConfig",
|
"MarketCapSelectorConfig",
|
||||||
|
"StandardScaler",
|
||||||
|
"CrossSectionalStandardScaler",
|
||||||
|
"Winsorizer",
|
||||||
]
|
]
|
||||||
|
|||||||
16
src/training/components/processors/__init__.py
Normal file
16
src/training/components/processors/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
"""数据处理器子模块
|
||||||
|
|
||||||
|
包含数据预处理、转换等处理器实现。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from src.training.components.processors.transforms import (
|
||||||
|
CrossSectionalStandardScaler,
|
||||||
|
StandardScaler,
|
||||||
|
Winsorizer,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"StandardScaler",
|
||||||
|
"CrossSectionalStandardScaler",
|
||||||
|
"Winsorizer",
|
||||||
|
]
|
||||||
275
src/training/components/processors/transforms.py
Normal file
275
src/training/components/processors/transforms.py
Normal file
@@ -0,0 +1,275 @@
|
|||||||
|
"""数据处理器实现
|
||||||
|
|
||||||
|
包含标准化、缩尾等数据处理器。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
import polars as pl
|
||||||
|
|
||||||
|
from src.training.components.base import BaseProcessor
|
||||||
|
from src.training.registry import register_processor
|
||||||
|
|
||||||
|
|
||||||
|
@register_processor("standard_scaler")
|
||||||
|
class StandardScaler(BaseProcessor):
|
||||||
|
"""标准化处理器(全局标准化)
|
||||||
|
|
||||||
|
在整个训练集上学习均值和标准差,
|
||||||
|
然后应用到训练集和测试集。
|
||||||
|
|
||||||
|
适用于需要全局统计量的场景。
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
exclude_cols: 不参与标准化的列名列表
|
||||||
|
mean_: 学习到的均值字典 {列名: 均值}
|
||||||
|
std_: 学习到的标准差字典 {列名: 标准差}
|
||||||
|
"""
|
||||||
|
|
||||||
|
name = "standard_scaler"
|
||||||
|
|
||||||
|
def __init__(self, exclude_cols: Optional[List[str]] = None):
|
||||||
|
"""初始化标准化处理器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
exclude_cols: 不参与标准化的列名列表,默认为 ["ts_code", "trade_date"]
|
||||||
|
"""
|
||||||
|
self.exclude_cols = exclude_cols or ["ts_code", "trade_date"]
|
||||||
|
self.mean_: dict = {}
|
||||||
|
self.std_: dict = {}
|
||||||
|
|
||||||
|
def fit(self, X: pl.DataFrame) -> "StandardScaler":
|
||||||
|
"""计算均值和标准差(仅在训练集上)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
X: 训练数据
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
self
|
||||||
|
"""
|
||||||
|
numeric_cols = [
|
||||||
|
c
|
||||||
|
for c in X.columns
|
||||||
|
if c not in self.exclude_cols and X[c].dtype.is_numeric()
|
||||||
|
]
|
||||||
|
|
||||||
|
for col in numeric_cols:
|
||||||
|
self.mean_[col] = X[col].mean()
|
||||||
|
self.std_[col] = X[col].std()
|
||||||
|
|
||||||
|
return self
|
||||||
|
|
||||||
|
def transform(self, X: pl.DataFrame) -> pl.DataFrame:
|
||||||
|
"""标准化(使用训练集学到的参数)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
X: 待转换数据
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
标准化后的数据
|
||||||
|
"""
|
||||||
|
expressions = []
|
||||||
|
for col in X.columns:
|
||||||
|
if col in self.mean_ and col in self.std_:
|
||||||
|
# 避免除以0
|
||||||
|
std_val = self.std_[col] if self.std_[col] != 0 else 1.0
|
||||||
|
expr = ((pl.col(col) - self.mean_[col]) / std_val).alias(col)
|
||||||
|
expressions.append(expr)
|
||||||
|
else:
|
||||||
|
expressions.append(pl.col(col))
|
||||||
|
|
||||||
|
return X.select(expressions)
|
||||||
|
|
||||||
|
|
||||||
|
@register_processor("cs_standard_scaler")
|
||||||
|
class CrossSectionalStandardScaler(BaseProcessor):
|
||||||
|
"""截面标准化处理器
|
||||||
|
|
||||||
|
每天独立进行标准化:对当天所有股票的某一因子进行标准化。
|
||||||
|
|
||||||
|
特点:
|
||||||
|
- 不需要 fit,每天独立计算当天的均值和标准差
|
||||||
|
- 适用于截面因子,消除市值等行业差异
|
||||||
|
- 公式:z = (x - mean_today) / std_today
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
exclude_cols: 不参与标准化的列名列表
|
||||||
|
date_col: 日期列名
|
||||||
|
"""
|
||||||
|
|
||||||
|
name = "cs_standard_scaler"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
exclude_cols: Optional[List[str]] = None,
|
||||||
|
date_col: str = "trade_date",
|
||||||
|
):
|
||||||
|
"""初始化截面标准化处理器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
exclude_cols: 不参与标准化的列名列表,默认为 ["ts_code", "trade_date"]
|
||||||
|
date_col: 日期列名
|
||||||
|
"""
|
||||||
|
self.exclude_cols = exclude_cols or ["ts_code", "trade_date"]
|
||||||
|
self.date_col = date_col
|
||||||
|
|
||||||
|
def transform(self, X: pl.DataFrame) -> pl.DataFrame:
|
||||||
|
"""截面标准化
|
||||||
|
|
||||||
|
按日期分组,每天独立计算均值和标准差并标准化。
|
||||||
|
不需要 fit,因为每天使用当天的统计量。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
X: 待转换数据
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
标准化后的数据
|
||||||
|
"""
|
||||||
|
numeric_cols = [
|
||||||
|
c
|
||||||
|
for c in X.columns
|
||||||
|
if c not in self.exclude_cols and X[c].dtype.is_numeric()
|
||||||
|
]
|
||||||
|
|
||||||
|
# 构建表达式列表
|
||||||
|
expressions = []
|
||||||
|
for col in X.columns:
|
||||||
|
if col in numeric_cols:
|
||||||
|
# 截面标准化:每天独立计算均值和标准差
|
||||||
|
# 避免除以0,当std为0时设为1
|
||||||
|
expr = (
|
||||||
|
(pl.col(col) - pl.col(col).mean().over(self.date_col))
|
||||||
|
/ (pl.col(col).std().over(self.date_col) + 1e-10)
|
||||||
|
).alias(col)
|
||||||
|
expressions.append(expr)
|
||||||
|
else:
|
||||||
|
expressions.append(pl.col(col))
|
||||||
|
|
||||||
|
return X.select(expressions)
|
||||||
|
|
||||||
|
|
||||||
|
@register_processor("winsorizer")
|
||||||
|
class Winsorizer(BaseProcessor):
|
||||||
|
"""缩尾处理器
|
||||||
|
|
||||||
|
对每一列的极端值进行截断处理。
|
||||||
|
可以全局截断(在整个训练集上学习分位数),
|
||||||
|
也可以截面截断(每天独立处理)。
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
lower: 下分位数(如0.01表示1%分位数)
|
||||||
|
upper: 上分位数(如0.99表示99%分位数)
|
||||||
|
by_date: True=每天独立缩尾, False=全局缩尾
|
||||||
|
date_col: 日期列名
|
||||||
|
bounds_: 存储分位数边界(全局模式)
|
||||||
|
"""
|
||||||
|
|
||||||
|
name = "winsorizer"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
lower: float = 0.01,
|
||||||
|
upper: float = 0.99,
|
||||||
|
by_date: bool = False,
|
||||||
|
date_col: str = "trade_date",
|
||||||
|
):
|
||||||
|
"""初始化缩尾处理器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
lower: 下分位数,默认0.01
|
||||||
|
upper: 上分位数,默认0.99
|
||||||
|
by_date: 每天独立缩尾,默认False(全局缩尾)
|
||||||
|
date_col: 日期列名
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: 分位数参数无效
|
||||||
|
"""
|
||||||
|
if not 0 <= lower < upper <= 1:
|
||||||
|
raise ValueError(
|
||||||
|
f"lower ({lower}) 必须小于 upper ({upper}),且都在 [0, 1] 范围内"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.lower = lower
|
||||||
|
self.upper = upper
|
||||||
|
self.by_date = by_date
|
||||||
|
self.date_col = date_col
|
||||||
|
self.bounds_: dict = {}
|
||||||
|
|
||||||
|
def fit(self, X: pl.DataFrame) -> "Winsorizer":
|
||||||
|
"""学习分位数边界(仅在全局模式下)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
X: 训练数据
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
self
|
||||||
|
"""
|
||||||
|
if not self.by_date:
|
||||||
|
numeric_cols = [c for c in X.columns if X[c].dtype.is_numeric()]
|
||||||
|
for col in numeric_cols:
|
||||||
|
self.bounds_[col] = {
|
||||||
|
"lower": X[col].quantile(self.lower),
|
||||||
|
"upper": X[col].quantile(self.upper),
|
||||||
|
}
|
||||||
|
return self
|
||||||
|
|
||||||
|
def transform(self, X: pl.DataFrame) -> pl.DataFrame:
|
||||||
|
"""缩尾处理
|
||||||
|
|
||||||
|
Args:
|
||||||
|
X: 待转换数据
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
缩尾处理后的数据
|
||||||
|
"""
|
||||||
|
if self.by_date:
|
||||||
|
return self._transform_by_date(X)
|
||||||
|
else:
|
||||||
|
return self._transform_global(X)
|
||||||
|
|
||||||
|
def _transform_global(self, X: pl.DataFrame) -> pl.DataFrame:
|
||||||
|
"""全局缩尾(使用训练集学到的边界)"""
|
||||||
|
expressions = []
|
||||||
|
for col in X.columns:
|
||||||
|
if col in self.bounds_:
|
||||||
|
lower = self.bounds_[col]["lower"]
|
||||||
|
upper = self.bounds_[col]["upper"]
|
||||||
|
expr = pl.col(col).clip(lower, upper).alias(col)
|
||||||
|
expressions.append(expr)
|
||||||
|
else:
|
||||||
|
expressions.append(pl.col(col))
|
||||||
|
return X.select(expressions)
|
||||||
|
|
||||||
|
def _transform_by_date(self, X: pl.DataFrame) -> pl.DataFrame:
|
||||||
|
"""每日独立缩尾"""
|
||||||
|
numeric_cols = [c for c in X.columns if X[c].dtype.is_numeric()]
|
||||||
|
|
||||||
|
# 先计算每天的分位数
|
||||||
|
lower_exprs = [
|
||||||
|
pl.col(col).quantile(self.lower).over(self.date_col).alias(f"{col}_lower")
|
||||||
|
for col in numeric_cols
|
||||||
|
]
|
||||||
|
upper_exprs = [
|
||||||
|
pl.col(col).quantile(self.upper).over(self.date_col).alias(f"{col}_upper")
|
||||||
|
for col in numeric_cols
|
||||||
|
]
|
||||||
|
|
||||||
|
# 添加分位数列
|
||||||
|
result = X.with_columns(lower_exprs + upper_exprs)
|
||||||
|
|
||||||
|
# 执行缩尾
|
||||||
|
clip_exprs = []
|
||||||
|
for col in X.columns:
|
||||||
|
if col in numeric_cols:
|
||||||
|
clipped = (
|
||||||
|
pl.col(col)
|
||||||
|
.clip(pl.col(f"{col}_lower"), pl.col(f"{col}_upper"))
|
||||||
|
.alias(col)
|
||||||
|
)
|
||||||
|
clip_exprs.append(clipped)
|
||||||
|
else:
|
||||||
|
clip_exprs.append(pl.col(col))
|
||||||
|
|
||||||
|
result = result.select(clip_exprs)
|
||||||
|
|
||||||
|
return result
|
||||||
300
tests/training/test_processors.py
Normal file
300
tests/training/test_processors.py
Normal file
@@ -0,0 +1,300 @@
|
|||||||
|
"""测试数据处理器
|
||||||
|
|
||||||
|
验证 StandardScaler、CrossSectionalStandardScaler 和 Winsorizer 功能。
|
||||||
|
"""
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import polars as pl
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from src.training.components.processors import (
|
||||||
|
CrossSectionalStandardScaler,
|
||||||
|
StandardScaler,
|
||||||
|
Winsorizer,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestStandardScaler:
|
||||||
|
"""StandardScaler 测试类"""
|
||||||
|
|
||||||
|
def test_init_default(self):
|
||||||
|
"""测试默认初始化"""
|
||||||
|
scaler = StandardScaler()
|
||||||
|
assert scaler.exclude_cols == ["ts_code", "trade_date"]
|
||||||
|
assert scaler.mean_ == {}
|
||||||
|
assert scaler.std_ == {}
|
||||||
|
|
||||||
|
def test_init_custom_exclude(self):
|
||||||
|
"""测试自定义排除列"""
|
||||||
|
scaler = StandardScaler(exclude_cols=["id", "date"])
|
||||||
|
assert scaler.exclude_cols == ["id", "date"]
|
||||||
|
|
||||||
|
def test_fit_transform(self):
|
||||||
|
"""测试拟合和转换"""
|
||||||
|
data = pl.DataFrame(
|
||||||
|
{
|
||||||
|
"ts_code": ["A", "B", "C", "D"],
|
||||||
|
"trade_date": ["20240101"] * 4,
|
||||||
|
"value": [1.0, 2.0, 3.0, 4.0],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
scaler = StandardScaler()
|
||||||
|
result = scaler.fit_transform(data)
|
||||||
|
|
||||||
|
# 验证学习到的统计量
|
||||||
|
assert scaler.mean_["value"] == 2.5
|
||||||
|
assert scaler.std_["value"] == pytest.approx(1.290, rel=1e-2)
|
||||||
|
|
||||||
|
# 验证转换结果
|
||||||
|
expected_std = (np.array([1.0, 2.0, 3.0, 4.0]) - 2.5) / 1.290
|
||||||
|
assert result["value"].to_list() == pytest.approx(
|
||||||
|
expected_std.tolist(), rel=1e-2
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_transform_use_fitted_params(self):
|
||||||
|
"""测试转换使用拟合时的参数"""
|
||||||
|
train_data = pl.DataFrame(
|
||||||
|
{
|
||||||
|
"ts_code": ["A", "B", "C"],
|
||||||
|
"trade_date": ["20240101"] * 3,
|
||||||
|
"value": [1.0, 2.0, 3.0],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
test_data = pl.DataFrame(
|
||||||
|
{
|
||||||
|
"ts_code": ["D"],
|
||||||
|
"trade_date": ["20240102"],
|
||||||
|
"value": [100.0], # 远离训练分布
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
scaler = StandardScaler()
|
||||||
|
scaler.fit(train_data)
|
||||||
|
|
||||||
|
# 使用训练集的均值(2.0)和标准差进行转换
|
||||||
|
result = scaler.transform(test_data)
|
||||||
|
expected = (100.0 - 2.0) / 1.0 # 均值2.0, 标准差1.0
|
||||||
|
assert result["value"][0] == pytest.approx(expected, rel=1e-2)
|
||||||
|
|
||||||
|
def test_exclude_non_numeric(self):
|
||||||
|
"""测试自动排除非数值列"""
|
||||||
|
data = pl.DataFrame(
|
||||||
|
{
|
||||||
|
"ts_code": ["A", "B"],
|
||||||
|
"trade_date": ["20240101", "20240102"],
|
||||||
|
"category": ["X", "Y"], # 字符串列
|
||||||
|
"value": [1.0, 2.0],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
scaler = StandardScaler()
|
||||||
|
result = scaler.fit_transform(data)
|
||||||
|
|
||||||
|
# category 列应该原样保留
|
||||||
|
assert result["category"].to_list() == ["X", "Y"]
|
||||||
|
# value 列应该被标准化
|
||||||
|
assert "value" in scaler.mean_
|
||||||
|
|
||||||
|
def test_zero_std_handling(self):
|
||||||
|
"""测试处理标准差为0的情况"""
|
||||||
|
data = pl.DataFrame(
|
||||||
|
{
|
||||||
|
"ts_code": ["A", "B"],
|
||||||
|
"trade_date": ["20240101", "20240102"],
|
||||||
|
"constant": [5.0, 5.0], # 常数列
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
scaler = StandardScaler()
|
||||||
|
result = scaler.fit_transform(data)
|
||||||
|
|
||||||
|
# 标准差为0时,结果应该为0(避免除以0)
|
||||||
|
assert result["constant"].to_list() == [0.0, 0.0]
|
||||||
|
|
||||||
|
|
||||||
|
class TestCrossSectionalStandardScaler:
|
||||||
|
"""CrossSectionalStandardScaler 测试类"""
|
||||||
|
|
||||||
|
def test_init_default(self):
|
||||||
|
"""测试默认初始化"""
|
||||||
|
scaler = CrossSectionalStandardScaler()
|
||||||
|
assert scaler.exclude_cols == ["ts_code", "trade_date"]
|
||||||
|
assert scaler.date_col == "trade_date"
|
||||||
|
|
||||||
|
def test_init_custom(self):
|
||||||
|
"""测试自定义参数"""
|
||||||
|
scaler = CrossSectionalStandardScaler(
|
||||||
|
exclude_cols=["id"],
|
||||||
|
date_col="date",
|
||||||
|
)
|
||||||
|
assert scaler.exclude_cols == ["id"]
|
||||||
|
assert scaler.date_col == "date"
|
||||||
|
|
||||||
|
def test_transform_no_fit_needed(self):
|
||||||
|
"""测试不需要 fit"""
|
||||||
|
data = pl.DataFrame(
|
||||||
|
{
|
||||||
|
"ts_code": ["A", "B"],
|
||||||
|
"trade_date": ["20240101", "20240101"],
|
||||||
|
"value": [1.0, 3.0],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
scaler = CrossSectionalStandardScaler()
|
||||||
|
# 截面标准化不需要 fit
|
||||||
|
result = scaler.transform(data)
|
||||||
|
|
||||||
|
# 当天均值=2.0, 样本标准差=sqrt(2)≈1.414, z-score=[-0.707, 0.707]
|
||||||
|
assert result["value"].to_list() == pytest.approx([-0.707, 0.707], rel=1e-2)
|
||||||
|
|
||||||
|
def test_transform_by_date(self):
|
||||||
|
"""测试按日期分组标准化"""
|
||||||
|
data = pl.DataFrame(
|
||||||
|
{
|
||||||
|
"ts_code": ["A", "B", "C", "D"],
|
||||||
|
"trade_date": ["20240101", "20240101", "20240102", "20240102"],
|
||||||
|
"value": [1.0, 3.0, 10.0, 30.0],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
scaler = CrossSectionalStandardScaler()
|
||||||
|
result = scaler.transform(data)
|
||||||
|
|
||||||
|
# 2024-01-01: 均值=2.0, 样本std≈1.414 -> [-0.707, 0.707]
|
||||||
|
# 2024-01-02: 均值=20.0, 样本std≈14.14 -> [-0.707, 0.707]
|
||||||
|
values = result["value"].to_list()
|
||||||
|
assert values[0] == pytest.approx(-0.707, abs=1e-2)
|
||||||
|
assert values[1] == pytest.approx(0.707, abs=1e-2)
|
||||||
|
assert values[2] == pytest.approx(-0.707, abs=1e-2)
|
||||||
|
assert values[3] == pytest.approx(0.707, abs=1e-2)
|
||||||
|
|
||||||
|
def test_exclude_columns_preserved(self):
|
||||||
|
"""测试排除列保持原样"""
|
||||||
|
data = pl.DataFrame(
|
||||||
|
{
|
||||||
|
"ts_code": ["A", "B"],
|
||||||
|
"trade_date": ["20240101", "20240101"],
|
||||||
|
"value": [1.0, 3.0],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
scaler = CrossSectionalStandardScaler()
|
||||||
|
result = scaler.transform(data)
|
||||||
|
|
||||||
|
assert result["ts_code"].to_list() == ["A", "B"]
|
||||||
|
assert result["trade_date"].to_list() == ["20240101", "20240101"]
|
||||||
|
|
||||||
|
|
||||||
|
class TestWinsorizer:
|
||||||
|
"""Winsorizer 测试类"""
|
||||||
|
|
||||||
|
def test_init_default(self):
|
||||||
|
"""测试默认初始化"""
|
||||||
|
winsorizer = Winsorizer()
|
||||||
|
assert winsorizer.lower == 0.01
|
||||||
|
assert winsorizer.upper == 0.99
|
||||||
|
assert winsorizer.by_date is False
|
||||||
|
assert winsorizer.date_col == "trade_date"
|
||||||
|
|
||||||
|
def test_init_custom(self):
|
||||||
|
"""测试自定义参数"""
|
||||||
|
winsorizer = Winsorizer(lower=0.05, upper=0.95, by_date=True, date_col="date")
|
||||||
|
assert winsorizer.lower == 0.05
|
||||||
|
assert winsorizer.upper == 0.95
|
||||||
|
assert winsorizer.by_date is True
|
||||||
|
assert winsorizer.date_col == "date"
|
||||||
|
|
||||||
|
def test_invalid_quantiles(self):
|
||||||
|
"""测试无效的分位数参数"""
|
||||||
|
with pytest.raises(ValueError, match="lower .* 必须小于 upper"):
|
||||||
|
Winsorizer(lower=0.5, upper=0.3)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="lower .* 必须小于 upper"):
|
||||||
|
Winsorizer(lower=-0.1, upper=0.5)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="lower .* 必须小于 upper"):
|
||||||
|
Winsorizer(lower=0.5, upper=1.5)
|
||||||
|
|
||||||
|
def test_global_winsorize(self):
|
||||||
|
"""测试全局缩尾"""
|
||||||
|
# 创建包含极端值的数据
|
||||||
|
values = list(range(1, 101)) # 1-100
|
||||||
|
values[0] = -1000 # 极端小值
|
||||||
|
values[-1] = 1000 # 极端大值
|
||||||
|
|
||||||
|
data = pl.DataFrame(
|
||||||
|
{
|
||||||
|
"ts_code": [f"A{i}" for i in range(100)],
|
||||||
|
"trade_date": ["20240101"] * 100,
|
||||||
|
"value": values,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
winsorizer = Winsorizer(lower=0.01, upper=0.99)
|
||||||
|
result = winsorizer.fit_transform(data)
|
||||||
|
|
||||||
|
# 1%分位数=2, 99%分位数=99
|
||||||
|
# -1000 应该被截断为 2
|
||||||
|
# 1000 应该被截断为 99
|
||||||
|
result_values = result["value"].to_list()
|
||||||
|
assert result_values[0] == 2 # 原-1000被截断
|
||||||
|
assert result_values[-1] == 99 # 原1000被截断
|
||||||
|
assert result_values[1] == 2 # 原2保持不变
|
||||||
|
assert result_values[98] == 99 # 原99保持不变
|
||||||
|
|
||||||
|
def test_by_date_winsorize(self):
|
||||||
|
"""测试每日独立缩尾"""
|
||||||
|
data = pl.DataFrame(
|
||||||
|
{
|
||||||
|
"ts_code": ["A", "B", "C", "D", "E", "F"],
|
||||||
|
"trade_date": ["20240101"] * 3 + ["20240102"] * 3,
|
||||||
|
"value": [1.0, 50.0, 100.0, 200.0, 250.0, 300.0],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
winsorizer = Winsorizer(lower=0.0, upper=0.5, by_date=True)
|
||||||
|
result = winsorizer.transform(data)
|
||||||
|
|
||||||
|
# 每天独立处理:
|
||||||
|
# 2024-01-01: [1, 50, 100], 50%分位数=50
|
||||||
|
# -> 截断为 [1, 50, 50]
|
||||||
|
# 2024-01-02: [200, 250, 300], 50%分位数=250
|
||||||
|
# -> 截断为 [200, 250, 250]
|
||||||
|
result_values = result["value"].to_list()
|
||||||
|
assert result_values[0] == 1.0
|
||||||
|
assert result_values[1] == 50.0
|
||||||
|
assert result_values[2] == 50.0 # 被截断
|
||||||
|
assert result_values[3] == 200.0
|
||||||
|
assert result_values[4] == 250.0
|
||||||
|
assert result_values[5] == 250.0 # 被截断
|
||||||
|
|
||||||
|
def test_global_transform_after_fit(self):
|
||||||
|
"""测试全局模式下,转换使用拟合时的边界"""
|
||||||
|
train_data = pl.DataFrame(
|
||||||
|
{
|
||||||
|
"ts_code": ["A", "B", "C"],
|
||||||
|
"trade_date": ["20240101"] * 3,
|
||||||
|
"value": [1.0, 50.0, 100.0],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
test_data = pl.DataFrame(
|
||||||
|
{
|
||||||
|
"ts_code": ["D"],
|
||||||
|
"trade_date": ["20240102"],
|
||||||
|
"value": [200.0],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
winsorizer = Winsorizer(lower=0.0, upper=1.0) # 0%和100%分位数
|
||||||
|
winsorizer.fit(train_data)
|
||||||
|
|
||||||
|
# 使用训练集的分位数边界 [1, 100]
|
||||||
|
result = winsorizer.transform(test_data)
|
||||||
|
assert result["value"][0] == 100.0 # 被截断为100
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
pytest.main([__file__, "-v"])
|
||||||
Reference in New Issue
Block a user