feat(models): 实现机器学习模型训练框架

- 添加核心抽象:Processor、Model、Splitter、Metric 基类
- 实现阶段感知机制(TRAIN/TEST/ALL),防止数据泄露
- 内置 8 个数据处理器和 3 种时序划分策略
- 支持 LightGBM、CatBoost 模型
- PluginRegistry 装饰器注册,插件式架构
- 22 个单元测试
This commit is contained in:
2026-02-23 01:37:34 +08:00
parent e58b39970c
commit 9f95be56a0
16 changed files with 3774 additions and 865 deletions

86
src/models/__init__.py Normal file
View File

@@ -0,0 +1,86 @@
"""ProStock 模型训练框架
组件化、低耦合、插件式的机器学习训练框架。
示例:
>>> from src.models import (
... PluginRegistry, ProcessingPipeline,
... PipelineStage, BaseProcessor
... )
>>> # 获取注册的处理器
>>> scaler_class = PluginRegistry.get_processor("standard_scaler")
>>> scaler = scaler_class()
>>> # 创建处理流水线
>>> pipeline = ProcessingPipeline([
... PluginRegistry.get_processor("dropna")(),
... PluginRegistry.get_processor("winsorizer")(lower=0.01, upper=0.99),
... PluginRegistry.get_processor("standard_scaler")(),
... ])
"""
# 导入核心抽象类和划分策略
from src.models.core import (
PipelineStage,
TaskType,
BaseProcessor,
BaseModel,
BaseSplitter,
BaseMetric,
TimeSeriesSplit,
WalkForwardSplit,
ExpandingWindowSplit,
)
# 导入注册中心
from src.models.registry import PluginRegistry
# 导入处理流水线
from src.models.pipeline import ProcessingPipeline
# 导入并注册内置处理器
from src.models.processors.processors import (
DropNAProcessor,
FillNAProcessor,
Winsorizer,
StandardScaler,
MinMaxScaler,
RankTransformer,
Neutralizer,
)
# 导入并注册内置模型
from src.models.models.models import (
LightGBMModel,
CatBoostModel,
)
__all__ = [
# 核心抽象
"PipelineStage",
"TaskType",
"BaseProcessor",
"BaseModel",
"BaseSplitter",
"BaseMetric",
# 划分策略
"TimeSeriesSplit",
"WalkForwardSplit",
"ExpandingWindowSplit",
# 注册中心
"PluginRegistry",
# 处理流水线
"ProcessingPipeline",
# 处理器
"DropNAProcessor",
"FillNAProcessor",
"Winsorizer",
"StandardScaler",
"MinMaxScaler",
"RankTransformer",
"Neutralizer",
# 模型
"LightGBMModel",
"CatBoostModel",
]

View File

@@ -0,0 +1,30 @@
"""核心模块导出"""
from src.models.core.base import (
PipelineStage,
TaskType,
BaseProcessor,
BaseModel,
BaseSplitter,
BaseMetric,
)
from src.models.core.splitter import (
TimeSeriesSplit,
WalkForwardSplit,
ExpandingWindowSplit,
)
__all__ = [
# 基础抽象
"PipelineStage",
"TaskType",
"BaseProcessor",
"BaseModel",
"BaseSplitter",
"BaseMetric",
# 划分策略
"TimeSeriesSplit",
"WalkForwardSplit",
"ExpandingWindowSplit",
]

351
src/models/core/base.py Normal file
View File

@@ -0,0 +1,351 @@
"""模型训练框架核心抽象类
提供处理器、模型、划分策略和评估指标的基类定义。
"""
from abc import ABC, abstractmethod
from enum import Enum, auto
from typing import Any, Dict, Iterator, List, Optional, Tuple, Literal
import polars as pl
import numpy as np
# Task types supported by models: classification, regression, or ranking.
TaskType = Literal["classification", "regression", "ranking"]


class PipelineStage(Enum):
    """Pipeline stage marker.

    Marks the stages in which a processor is active, to prevent data
    leakage between training and evaluation data.

    Attributes:
        ALL: Applies to every stage (train, test, validation).
        TRAIN: Training stage only.
        TEST: Test stage only.
        VALIDATION: Validation stage only.
    """

    ALL = auto()
    TRAIN = auto()
    TEST = auto()
    VALIDATION = auto()
class BaseProcessor(ABC):
    """Abstract base class for data processors.

    Every data processor must inherit from this class.  The key feature
    is the ``stage`` attribute, which controls the stages in which the
    processor is active.

    Stage semantics:
        - ALL: the same parameters are used for training and testing.
        - TRAIN: parameters (quantiles, means, ...) are learned on the
          training data only; the test stage reuses the learned values.
        - TEST: the processor runs during the test stage only.
    """

    # Subclasses must declare the stage(s) they apply to.
    stage: PipelineStage = PipelineStage.ALL

    def __init__(self, columns: Optional[List[str]] = None, **params):
        """Initialize the processor.

        Args:
            columns: Columns to process; ``None`` selects every numeric column.
            **params: Processor-specific parameters.
        """
        self.columns = columns
        self.params = params
        self._is_fitted = False
        self._fitted_params: Dict[str, Any] = {}

    @abstractmethod
    def fit(self, data: pl.DataFrame) -> "BaseProcessor":
        """Learn parameters from the training data.

        Called exactly once, during the training stage.  Learned values
        are stored in ``self._fitted_params``.

        Args:
            data: Training data.

        Returns:
            self, to allow chaining.
        """

    @abstractmethod
    def transform(self, data: pl.DataFrame) -> pl.DataFrame:
        """Transform the input data.

        Invoked in both the training and the test stage, using whatever
        parameters ``fit()`` learned.

        Args:
            data: Input data.

        Returns:
            The transformed data.
        """

    def fit_transform(self, data: pl.DataFrame) -> pl.DataFrame:
        """Convenience shortcut: ``fit`` followed by ``transform``.

        Args:
            data: Training data.

        Returns:
            The transformed data.
        """
        self.fit(data)
        return self.transform(data)

    def get_fitted_params(self) -> Dict[str, Any]:
        """Return a copy of the learned parameters (for saving/loading).

        Returns:
            Copy of the fitted-parameter dictionary.
        """
        return dict(self._fitted_params)

    def set_fitted_params(self, params: Dict[str, Any]) -> "BaseProcessor":
        """Restore learned parameters (e.g. from a checkpoint).

        Args:
            params: Parameter dictionary.

        Returns:
            self, to allow chaining.
        """
        self._fitted_params = dict(params)
        self._is_fitted = True
        return self
class BaseModel(ABC):
    """Abstract base class for machine-learning models.

    One interface spanning multiple backends (LightGBM, CatBoost,
    XGBoost, ...) and task types (classification, regression, ranking).
    """

    def __init__(
        self,
        task_type: TaskType,
        params: Optional[Dict[str, Any]] = None,
        name: Optional[str] = None,
    ):
        """Initialize the model.

        Args:
            task_type: "classification", "regression", or "ranking".
            params: Backend-specific hyper-parameters.
            name: Display name for logs and reports; defaults to the
                class name.
        """
        self.task_type = task_type
        self.params = params or {}
        self.name = name or self.__class__.__name__
        self._model: Any = None
        self._is_fitted = False

    @abstractmethod
    def fit(
        self,
        X: pl.DataFrame,
        y: pl.Series,
        X_val: Optional[pl.DataFrame] = None,
        y_val: Optional[pl.Series] = None,
        **fit_params,
    ) -> "BaseModel":
        """Train the model.

        Args:
            X: Feature frame.
            y: Target series.
            X_val: Optional validation features.
            y_val: Optional validation targets.
            **fit_params: Backend-specific fit options.

        Returns:
            self, to allow chaining.
        """

    @abstractmethod
    def predict(self, X: pl.DataFrame) -> np.ndarray:
        """Predict on X.

        Args:
            X: Feature frame.

        Returns:
            Prediction array:
            - classification: class labels or probabilities
            - regression: continuous values
            - ranking: ranking scores
        """

    def predict_proba(self, X: pl.DataFrame) -> np.ndarray:
        """Class probabilities (classification tasks only).

        Args:
            X: Feature frame.

        Returns:
            Probability array of shape [n_samples, n_classes].

        Raises:
            NotImplementedError: When not overridden for classification.
        """
        raise NotImplementedError(
            "predict_proba only available for classification tasks"
        )

    def get_feature_importance(self) -> Optional[pl.DataFrame]:
        """Feature importances, if the backend provides them.

        Returns:
            DataFrame[feature, importance], or None by default.
        """
        return None

    def save(self, path: str) -> None:
        """Pickle this model to a file.

        Args:
            path: Destination path.
        """
        import pickle

        with open(path, "wb") as f:
            pickle.dump(self, f)

    @classmethod
    def load(cls, path: str) -> "BaseModel":
        """Load a pickled model from a file.

        Args:
            path: Model file path.

        Returns:
            The unpickled model instance.
        """
        import pickle

        with open(path, "rb") as f:
            return pickle.load(f)
class BaseSplitter(ABC):
    """Abstract base class for data-splitting strategies.

    Splitters implement time-series-aware partitioning designed to
    prevent look-ahead (future) leakage.
    """

    @abstractmethod
    def split(
        self, data: pl.DataFrame, date_col: str = "trade_date"
    ) -> Iterator[Tuple[List[int], List[int]]]:
        """Generate train/test row indices.

        Args:
            data: Full dataset.
            date_col: Name of the date column.

        Yields:
            (train_indices, test_indices) tuples.
        """
        pass

    @abstractmethod
    def get_split_dates(
        self, data: pl.DataFrame, date_col: str = "trade_date"
    ) -> List[Tuple[str, str, str, str]]:
        """Return the date range of each fold.

        Args:
            data: Full dataset.
            date_col: Name of the date column.

        Returns:
            [(train_start, train_end, test_start, test_end), ...]
        """
        pass
class BaseMetric(ABC):
    """Abstract base class for evaluation metrics.

    Two usage modes are supported: one-shot computation via ``compute``
    and cumulative aggregation via ``update`` / ``get_mean`` / ``get_std``.
    """

    def __init__(self, name: Optional[str] = None):
        """Initialize the metric.

        Args:
            name: Metric name; defaults to the class name.
        """
        self.name = name or self.__class__.__name__
        self._values: List[float] = []

    @abstractmethod
    def compute(self, y_true: np.ndarray, y_pred: np.ndarray) -> float:
        """Compute the metric for one set of predictions.

        Args:
            y_true: Ground-truth values.
            y_pred: Predicted values.

        Returns:
            The metric value.
        """

    def update(self, y_true: np.ndarray, y_pred: np.ndarray) -> "BaseMetric":
        """Accumulate the metric value for another batch.

        Args:
            y_true: Ground-truth values.
            y_pred: Predicted values.

        Returns:
            self, to allow chaining.
        """
        batch_value = self.compute(y_true, y_pred)
        self._values.append(batch_value)
        return self

    def get_mean(self) -> float:
        """Mean of all accumulated values (0.0 when nothing accumulated)."""
        return float(np.mean(self._values)) if self._values else 0.0

    def get_std(self) -> float:
        """Standard deviation of accumulated values (0.0 when empty)."""
        return float(np.std(self._values)) if self._values else 0.0

    def reset(self) -> "BaseMetric":
        """Discard all accumulated values.

        Returns:
            self, to allow chaining.
        """
        self._values = []
        return self
__all__ = [
"PipelineStage",
"TaskType",
"BaseProcessor",
"BaseModel",
"BaseSplitter",
"BaseMetric",
]

222
src/models/core/splitter.py Normal file
View File

@@ -0,0 +1,222 @@
"""时间序列数据划分策略
提供针对金融时间序列的特殊划分策略,防止未来泄露。
"""
from typing import Iterator, List, Tuple
import polars as pl
from src.models.core.base import BaseSplitter
class TimeSeriesSplit(BaseSplitter):
    """Time-ordered split: training data always precedes test data.

    Performs K sequential folds ordered by date, where each fold's
    training dates all come before its test dates.  A ``gap`` of skipped
    days between the two windows guards against leakage from overlapping
    information.

    Args:
        n_splits: Number of folds.
        gap: Days skipped between the training and test windows.
        min_train_size: Minimum number of training days.
    """

    def __init__(self, n_splits: int = 5, gap: int = 5, min_train_size: int = 252):
        self.n_splits = n_splits
        self.gap = gap
        self.min_train_size = min_train_size

    def _fold_bounds(self, n_dates: int) -> Iterator[Tuple[int, int, int]]:
        """Yield (train_end, test_start, test_end) date-index triples.

        Bug fix: when there are not enough dates to carve a positive-sized
        test window after ``min_train_size`` days, ``test_size`` used to be
        <= 0 and the old loop silently yielded folds with empty or invalid
        test windows.  Now no folds are produced in that case.
        """
        test_size = (n_dates - self.min_train_size) // self.n_splits
        if test_size <= 0:
            return
        for i in range(self.n_splits):
            train_end = self.min_train_size + i * test_size
            test_start = train_end + self.gap
            test_end = test_start + test_size
            if test_end > n_dates:
                break
            yield train_end, test_start, test_end

    def split(
        self, data: pl.DataFrame, date_col: str = "trade_date"
    ) -> Iterator[Tuple[List[int], List[int]]]:
        """Yield (train_indices, test_indices) row-index pairs."""
        dates = data[date_col].unique().sort()
        # Hoisted: build the row index once instead of twice per fold.
        indexed = data.with_row_index()
        for train_end, test_start, test_end in self._fold_bounds(len(dates)):
            train_dates = dates[:train_end]
            test_dates = dates[test_start:test_end]
            train_idx = indexed.filter(
                pl.col(date_col).is_in(train_dates.to_list())
            )["index"].to_list()
            test_idx = indexed.filter(
                pl.col(date_col).is_in(test_dates.to_list())
            )["index"].to_list()
            yield train_idx, test_idx

    def get_split_dates(
        self, data: pl.DataFrame, date_col: str = "trade_date"
    ) -> List[Tuple[str, str, str, str]]:
        """Return (train_start, train_end, test_start, test_end) per fold."""
        dates = data[date_col].unique().sort()
        return [
            (
                str(dates[0]),
                str(dates[train_end - 1]),
                str(dates[test_start]),
                str(dates[test_end - 1]),
            )
            for train_end, test_start, test_end in self._fold_bounds(len(dates))
        ]
class WalkForwardSplit(BaseSplitter):
    """Walk-forward validation with a fixed-size rolling training window.

    Args:
        train_window: Training window length in days.
        test_window: Test window length in days.
        gap: Days skipped between the training and test windows.
    """

    def __init__(self, train_window: int = 504, test_window: int = 21, gap: int = 5):
        self.train_window = train_window
        self.test_window = test_window
        self.gap = gap

    def _fold_bounds(self, n_dates: int) -> Iterator[Tuple[int, int, int, int]]:
        """Yield (train_start, train_end, test_start, test_end) date indices.

        Shared by ``split`` and ``get_split_dates`` so the fold arithmetic
        is defined in exactly one place.
        """
        anchor = self.train_window
        while anchor + self.gap + self.test_window <= n_dates:
            test_start = anchor + self.gap
            yield (
                anchor - self.train_window,
                anchor,
                test_start,
                test_start + self.test_window,
            )
            anchor += self.test_window

    def split(
        self, data: pl.DataFrame, date_col: str = "trade_date"
    ) -> Iterator[Tuple[List[int], List[int]]]:
        """Yield (train_indices, test_indices) row-index pairs."""
        dates = data[date_col].unique().sort()
        # Hoisted: build the row index once instead of twice per fold.
        indexed = data.with_row_index()
        for train_start, train_end, test_start, test_end in self._fold_bounds(
            len(dates)
        ):
            train_dates = dates[train_start:train_end]
            test_dates = dates[test_start:test_end]
            train_idx = indexed.filter(
                pl.col(date_col).is_in(train_dates.to_list())
            )["index"].to_list()
            test_idx = indexed.filter(
                pl.col(date_col).is_in(test_dates.to_list())
            )["index"].to_list()
            yield train_idx, test_idx

    def get_split_dates(
        self, data: pl.DataFrame, date_col: str = "trade_date"
    ) -> List[Tuple[str, str, str, str]]:
        """Return (train_start, train_end, test_start, test_end) per fold."""
        dates = data[date_col].unique().sort()
        return [
            (
                str(dates[train_start]),
                str(dates[train_end - 1]),
                str(dates[test_start]),
                str(dates[test_end - 1]),
            )
            for train_start, train_end, test_start, test_end in self._fold_bounds(
                len(dates)
            )
        ]
class ExpandingWindowSplit(BaseSplitter):
    """Expanding-window split: the training set keeps growing over time.

    Args:
        initial_train_size: Initial training window length in days.
        test_window: Test window length in days.
        gap: Days skipped between the training and test windows.
    """

    def __init__(
        self, initial_train_size: int = 252, test_window: int = 21, gap: int = 5
    ):
        self.initial_train_size = initial_train_size
        self.test_window = test_window
        self.gap = gap

    def split(
        self, data: pl.DataFrame, date_col: str = "trade_date"
    ) -> Iterator[Tuple[List[int], List[int]]]:
        """Yield (train_indices, test_indices) row-index pairs."""
        dates = data[date_col].unique().sort()
        total = len(dates)
        cutoff = self.initial_train_size
        while cutoff + self.gap + self.test_window <= total:
            first_test = cutoff + self.gap
            last_test = first_test + self.test_window
            train_sel = data[date_col].is_in(dates[:cutoff].to_list())
            test_sel = data[date_col].is_in(dates[first_test:last_test].to_list())
            yield (
                data.with_row_index().filter(train_sel)["index"].to_list(),
                data.with_row_index().filter(test_sel)["index"].to_list(),
            )
            cutoff += self.test_window

    def get_split_dates(
        self, data: pl.DataFrame, date_col: str = "trade_date"
    ) -> List[Tuple[str, str, str, str]]:
        """Return (train_start, train_end, test_start, test_end) per fold."""
        dates = data[date_col].unique().sort()
        total = len(dates)
        ranges: List[Tuple[str, str, str, str]] = []
        cutoff = self.initial_train_size
        while cutoff + self.gap + self.test_window <= total:
            first_test = cutoff + self.gap
            last_test = first_test + self.test_window
            ranges.append(
                (
                    str(dates[0]),
                    str(dates[cutoff - 1]),
                    str(dates[first_test]),
                    str(dates[last_test - 1]),
                )
            )
            cutoff += self.test_window
        return ranges
__all__ = [
"TimeSeriesSplit",
"WalkForwardSplit",
"ExpandingWindowSplit",
]

View File

@@ -0,0 +1,11 @@
"""模型模块"""
from src.models.models.models import (
LightGBMModel,
CatBoostModel,
)
__all__ = [
"LightGBMModel",
"CatBoostModel",
]

210
src/models/models/models.py Normal file
View File

@@ -0,0 +1,210 @@
"""内置机器学习模型
提供 LightGBM、CatBoost 等模型的统一接口包装器。
"""
from typing import Optional, Dict, Any
import polars as pl
import numpy as np
from src.models.core import BaseModel, TaskType
from src.models.registry import PluginRegistry
@PluginRegistry.register_model("lightgbm")
class LightGBMModel(BaseModel):
"""LightGBM 模型包装器
支持分类、回归、排序三种任务类型。
"""
def __init__(
self,
task_type: TaskType,
params: Optional[Dict[str, Any]] = None,
name: Optional[str] = None,
):
super().__init__(task_type, params, name)
self._model = None
def fit(
self,
X: pl.DataFrame,
y: pl.Series,
X_val: Optional[pl.DataFrame] = None,
y_val: Optional[pl.Series] = None,
**fit_params,
) -> "LightGBMModel":
"""训练模型"""
try:
import lightgbm as lgb
except ImportError:
raise ImportError(
"lightgbm is required. Install with: uv pip install lightgbm"
)
X_arr = X.to_numpy()
y_arr = y.to_numpy()
train_data = lgb.Dataset(X_arr, label=y_arr)
valid_sets = [train_data]
valid_names = ["train"]
if X_val is not None and y_val is not None:
valid_data = lgb.Dataset(X_val.to_numpy(), label=y_val.to_numpy())
valid_sets.append(valid_data)
valid_names.append("valid")
default_params = {
"objective": self._get_objective(),
"metric": self._get_metric(),
"boosting_type": "gbdt",
"num_leaves": 31,
"learning_rate": 0.05,
"feature_fraction": 0.9,
"bagging_fraction": 0.8,
"bagging_freq": 5,
"verbose": -1,
}
default_params.update(self.params)
callbacks = []
if len(valid_sets) > 1:
callbacks.append(lgb.early_stopping(stopping_rounds=10, verbose=False))
self._model = lgb.train(
default_params,
train_data,
num_boost_round=fit_params.get("num_boost_round", 100),
valid_sets=valid_sets,
valid_names=valid_names,
callbacks=callbacks,
)
self._is_fitted = True
return self
def predict(self, X: pl.DataFrame) -> np.ndarray:
"""预测"""
if not self._is_fitted:
raise RuntimeError("Model not fitted yet")
return self._model.predict(X.to_numpy())
def predict_proba(self, X: pl.DataFrame) -> np.ndarray:
"""预测概率(仅分类任务)"""
if self.task_type != "classification":
raise ValueError("predict_proba only for classification")
probs = self.predict(X)
if len(probs.shape) == 1:
return np.vstack([1 - probs, probs]).T
return probs
def get_feature_importance(self) -> Optional[pl.DataFrame]:
"""获取特征重要性"""
if self._model is None:
return None
importance = self._model.feature_importance(importance_type="gain")
feature_names = getattr(
self._model,
"feature_name",
lambda: [f"feature_{i}" for i in range(len(importance))],
)()
return pl.DataFrame({"feature": feature_names, "importance": importance}).sort(
"importance", descending=True
)
def _get_objective(self) -> str:
objectives = {
"classification": "binary",
"regression": "regression",
"ranking": "lambdarank",
}
return objectives.get(self.task_type, "regression")
def _get_metric(self) -> str:
metrics = {"classification": "auc", "regression": "rmse", "ranking": "ndcg"}
return metrics.get(self.task_type, "rmse")
@PluginRegistry.register_model("catboost")
class CatBoostModel(BaseModel):
"""CatBoost 模型包装器"""
def __init__(
self,
task_type: TaskType,
params: Optional[Dict[str, Any]] = None,
name: Optional[str] = None,
):
super().__init__(task_type, params, name)
self._model = None
def fit(
self,
X: pl.DataFrame,
y: pl.Series,
X_val: Optional[pl.DataFrame] = None,
y_val: Optional[pl.Series] = None,
**fit_params,
) -> "CatBoostModel":
"""训练模型"""
try:
from catboost import CatBoostClassifier, CatBoostRegressor
except ImportError:
raise ImportError(
"catboost is required. Install with: uv pip install catboost"
)
if self.task_type == "classification":
model_class = CatBoostClassifier
default_params = {"loss_function": "Logloss", "eval_metric": "AUC"}
elif self.task_type == "regression":
model_class = CatBoostRegressor
default_params = {"loss_function": "RMSE"}
else:
model_class = CatBoostRegressor
default_params = {"loss_function": "QueryRMSE"}
default_params.update(self.params)
default_params["verbose"] = False
self._model = model_class(**default_params)
eval_set = None
if X_val is not None and y_val is not None:
eval_set = (X_val.to_pandas(), y_val.to_pandas())
self._model.fit(
X.to_pandas(),
y.to_pandas(),
eval_set=eval_set,
early_stopping_rounds=fit_params.get("early_stopping_rounds", 10),
verbose=False,
)
self._is_fitted = True
return self
def predict(self, X: pl.DataFrame) -> np.ndarray:
"""预测"""
if not self._is_fitted:
raise RuntimeError("Model not fitted yet")
return self._model.predict(X.to_pandas())
def predict_proba(self, X: pl.DataFrame) -> np.ndarray:
"""预测概率"""
if self.task_type != "classification":
raise ValueError("predict_proba only for classification")
return self._model.predict_proba(X.to_pandas())
def get_feature_importance(self) -> Optional[pl.DataFrame]:
"""获取特征重要性"""
if self._model is None:
return None
return pl.DataFrame(
{
"feature": self._model.feature_names_,
"importance": self._model.feature_importances_,
}
).sort("importance", descending=True)
__all__ = ["LightGBMModel", "CatBoostModel"]

70
src/models/pipeline.py Normal file
View File

@@ -0,0 +1,70 @@
"""数据处理流水线
管理多个处理器的顺序执行,支持阶段感知处理。
"""
from typing import List, Dict
import polars as pl
from src.models.core import BaseProcessor, PipelineStage
class ProcessingPipeline:
    """Ordered, stage-aware execution of multiple data processors.

    Key property: during the test stage, processors reuse the parameters
    learned during the training stage, which prevents data leakage.
    """

    def __init__(self, processors: List[BaseProcessor]):
        """Initialize the pipeline.

        Args:
            processors: Processors, in execution order.
        """
        self.processors = processors
        self._fitted_processors: Dict[int, BaseProcessor] = {}

    def fit_transform(
        self, data: pl.DataFrame, stage: PipelineStage = PipelineStage.TRAIN
    ) -> pl.DataFrame:
        """Fit every applicable processor on training data and transform it.

        TEST-only processors are fitted here (on training data) but not
        applied; they run later, in ``transform``.
        """
        result = data
        for i, processor in enumerate(self.processors):
            if processor.stage in (PipelineStage.ALL, stage):
                result = processor.fit_transform(result)
                self._fitted_processors[i] = processor
            elif stage == PipelineStage.TRAIN and processor.stage == PipelineStage.TEST:
                # Learn parameters on training data; applied at test time.
                processor.fit(result)
                self._fitted_processors[i] = processor
        return result

    def transform(
        self, data: pl.DataFrame, stage: PipelineStage = PipelineStage.TEST
    ) -> pl.DataFrame:
        """Apply the fitted processors to new (e.g. test) data.

        Bug fix: per the BaseProcessor contract, TRAIN-stage processors
        learn their parameters on training data but must still be applied
        at test time using those learned parameters.  Previously they were
        skipped here entirely, so e.g. a fitted Winsorizer or
        FillNAProcessor never ran on test data.
        """
        result = data
        for i, processor in enumerate(self.processors):
            applies = processor.stage in (PipelineStage.ALL, stage)
            if (
                not applies
                and processor.stage == PipelineStage.TRAIN
                and stage != PipelineStage.TRAIN
            ):
                # TRAIN-stage processors apply at other stages only when
                # they were actually fitted during training.
                applies = i in self._fitted_processors
            if not applies:
                continue
            # Prefer the fitted instance; fall back to the raw processor.
            fitted = self._fitted_processors.get(i, processor)
            result = fitted.transform(result)
        return result

    def save_processors(self, path: str) -> None:
        """Persist the fitted processor state via pickle."""
        import pickle

        with open(path, "wb") as f:
            pickle.dump(self._fitted_processors, f)

    def load_processors(self, path: str) -> None:
        """Restore previously saved processor state."""
        import pickle

        with open(path, "rb") as f:
            self._fitted_processors = pickle.load(f)

View File

@@ -0,0 +1,21 @@
"""处理器模块"""
from src.models.processors.processors import (
DropNAProcessor,
FillNAProcessor,
Winsorizer,
StandardScaler,
MinMaxScaler,
RankTransformer,
Neutralizer,
)
__all__ = [
"DropNAProcessor",
"FillNAProcessor",
"Winsorizer",
"StandardScaler",
"MinMaxScaler",
"RankTransformer",
"Neutralizer",
]

View File

@@ -0,0 +1,238 @@
"""内置数据处理器
提供常用的数据预处理和转换处理器。
"""
from typing import List, Optional, Dict, Any
import polars as pl
import numpy as np
from src.models.core import BaseProcessor, PipelineStage
from src.models.registry import PluginRegistry
# Numeric dtypes eligible for automatic column selection.
FLOAT_TYPES = [pl.Float32, pl.Float64, pl.Int8, pl.Int16, pl.Int32, pl.Int64]


def _get_numeric_columns(
    data: pl.DataFrame, columns: Optional[List[str]] = None
) -> List[str]:
    """Return *columns* if given, else every numeric column of *data*."""
    if columns is not None:
        return columns
    return [c for c in data.columns if data[c].dtype in FLOAT_TYPES]
@PluginRegistry.register_processor("dropna")
class DropNAProcessor(BaseProcessor):
"""缺失值删除处理器"""
stage = PipelineStage.ALL
def fit(self, data: pl.DataFrame) -> "DropNAProcessor":
self._is_fitted = True
return self
def transform(self, data: pl.DataFrame) -> pl.DataFrame:
cols = self.columns or data.columns
return data.drop_nulls(subset=cols)
@PluginRegistry.register_processor("fillna")
class FillNAProcessor(BaseProcessor):
"""缺失值填充处理器(只在训练阶段计算填充值)"""
stage = PipelineStage.TRAIN
def __init__(self, columns: Optional[List[str]] = None, method: str = "median"):
super().__init__(columns)
if method not in ["median", "mean", "zero"]:
raise ValueError(f"Unknown fill method: {method}")
self.method = method
def fit(self, data: pl.DataFrame) -> "FillNAProcessor":
cols = _get_numeric_columns(data, self.columns)
fill_values = {}
for col in cols:
if self.method == "median":
fill_values[col] = data[col].median()
elif self.method == "mean":
fill_values[col] = data[col].mean()
elif self.method == "zero":
fill_values[col] = 0.0
self._fitted_params = {"fill_values": fill_values, "columns": cols}
self._is_fitted = True
return self
def transform(self, data: pl.DataFrame) -> pl.DataFrame:
result = data
for col, val in self._fitted_params.get("fill_values", {}).items():
if col in result.columns:
result = result.with_columns(pl.col(col).fill_null(val).alias(col))
return result
@PluginRegistry.register_processor("winsorizer")
class Winsorizer(BaseProcessor):
"""缩尾处理器 - 防止极端值影响(只在训练阶段计算分位数)"""
stage = PipelineStage.TRAIN
def __init__(
self,
columns: Optional[List[str]] = None,
lower: float = 0.01,
upper: float = 0.99,
):
super().__init__(columns)
self.lower = lower
self.upper = upper
def fit(self, data: pl.DataFrame) -> "Winsorizer":
cols = _get_numeric_columns(data, self.columns)
bounds = {}
for col in cols:
bounds[col] = {
"lower": data[col].quantile(self.lower),
"upper": data[col].quantile(self.upper),
}
self._fitted_params = {"bounds": bounds, "columns": cols}
self._is_fitted = True
return self
def transform(self, data: pl.DataFrame) -> pl.DataFrame:
result = data
for col, bounds in self._fitted_params.get("bounds", {}).items():
if col in result.columns:
result = result.with_columns(
pl.col(col).clip(bounds["lower"], bounds["upper"]).alias(col)
)
return result
@PluginRegistry.register_processor("standard_scaler")
class StandardScaler(BaseProcessor):
"""标准化处理器 - Z-score标准化"""
stage = PipelineStage.ALL
def fit(self, data: pl.DataFrame) -> "StandardScaler":
cols = _get_numeric_columns(data, self.columns)
stats = {}
for col in cols:
stats[col] = {"mean": data[col].mean(), "std": data[col].std()}
self._fitted_params = {"stats": stats, "columns": cols}
self._is_fitted = True
return self
def transform(self, data: pl.DataFrame) -> pl.DataFrame:
result = data
for col, stats in self._fitted_params.get("stats", {}).items():
if col in result.columns and stats["std"] is not None and stats["std"] > 0:
result = result.with_columns(
((pl.col(col) - stats["mean"]) / stats["std"]).alias(col)
)
return result
@PluginRegistry.register_processor("minmax_scaler")
class MinMaxScaler(BaseProcessor):
"""归一化处理器 - 缩放到[0, 1]范围"""
stage = PipelineStage.ALL
def fit(self, data: pl.DataFrame) -> "MinMaxScaler":
cols = _get_numeric_columns(data, self.columns)
stats = {}
for col in cols:
stats[col] = {"min": data[col].min(), "max": data[col].max()}
self._fitted_params = {"stats": stats, "columns": cols}
self._is_fitted = True
return self
def transform(self, data: pl.DataFrame) -> pl.DataFrame:
result = data
for col, stats in self._fitted_params.get("stats", {}).items():
if col in result.columns:
range_val = stats["max"] - stats["min"]
if range_val is not None and range_val > 0:
result = result.with_columns(
((pl.col(col) - stats["min"]) / range_val).alias(col)
)
return result
@PluginRegistry.register_processor("rank_transformer")
class RankTransformer(BaseProcessor):
"""排名转换处理器 - 转换为截面排名"""
stage = PipelineStage.ALL
def fit(self, data: pl.DataFrame) -> "RankTransformer":
self._is_fitted = True
return self
def transform(self, data: pl.DataFrame) -> pl.DataFrame:
result = data
cols = self.columns or _get_numeric_columns(data)
for col in cols:
if col in result.columns:
result = result.with_columns(
pl.col(col).rank().over("trade_date").alias(col)
)
return result
@PluginRegistry.register_processor("neutralizer")
class Neutralizer(BaseProcessor):
"""中性化处理器 - 行业/市值中性化"""
stage = PipelineStage.ALL
def __init__(
self,
columns: Optional[List[str]] = None,
group_col: str = "industry",
exclude_cols: Optional[List[str]] = None,
):
super().__init__(columns)
self.group_col = group_col
self.exclude_cols = exclude_cols or []
def fit(self, data: pl.DataFrame) -> "Neutralizer":
self._is_fitted = True
return self
def transform(self, data: pl.DataFrame) -> pl.DataFrame:
result = data
cols = self.columns or _get_numeric_columns(data)
for col in cols:
if col in result.columns and col not in self.exclude_cols:
result = result.with_columns(
(
pl.col(col)
- pl.col(col).mean().over(["trade_date", self.group_col])
).alias(col)
)
return result
__all__ = [
"DropNAProcessor",
"FillNAProcessor",
"Winsorizer",
"StandardScaler",
"MinMaxScaler",
"RankTransformer",
"Neutralizer",
]

297
src/models/registry.py Normal file
View File

@@ -0,0 +1,297 @@
"""插件注册中心
提供装饰器方式注册处理器、模型、划分策略等组件。
实现真正的插件式架构 - 新功能只需注册即可使用。
示例:
>>> @PluginRegistry.register_processor("standard_scaler")
... class StandardScaler(BaseProcessor):
... pass
>>> # 使用
>>> scaler = PluginRegistry.get_processor("standard_scaler")()
"""
from typing import Type, Dict, List, TypeVar, Optional
from functools import wraps
from weakref import WeakValueDictionary
import contextlib
from src.models.core import BaseProcessor, BaseModel, BaseSplitter, BaseMetric
T = TypeVar("T")
class PluginRegistry:
    """Central plugin registry.

    Components (processors, models, splitters, metrics) are registered
    with decorators and retrieved by name, giving the framework a
    plugin-style architecture.

    The four register_*/get_* families previously duplicated identical
    logic; they now delegate to the private ``_make_register`` and
    ``_lookup`` helpers.

    Attributes:
        _processors: Registered processor classes by name.
        _models: Registered model classes by name.
        _splitters: Registered splitter classes by name.
        _metrics: Registered metric classes by name.
    """

    _processors: Dict[str, Type[BaseProcessor]] = {}
    _models: Dict[str, Type[BaseModel]] = {}
    _splitters: Dict[str, Type[BaseSplitter]] = {}
    _metrics: Dict[str, Type[BaseMetric]] = {}

    @classmethod
    @contextlib.contextmanager
    def temp_registry(cls):
        """Context manager for temporary registrations.

        Components registered inside the context are discarded on exit,
        preventing state from leaking between tests.

        Example:
            >>> with PluginRegistry.temp_registry():
            ...     @PluginRegistry.register_processor("temp_processor")
            ...     class TempProcessor(BaseProcessor):
            ...         pass
            ...     # "temp_processor" is usable here
            ... # automatically cleaned up afterwards
        """
        snapshot = {
            "_processors": cls._processors.copy(),
            "_models": cls._models.copy(),
            "_splitters": cls._splitters.copy(),
            "_metrics": cls._metrics.copy(),
        }
        try:
            yield cls
        finally:
            cls._processors = snapshot["_processors"]
            cls._models = snapshot["_models"]
            cls._splitters = snapshot["_splitters"]
            cls._metrics = snapshot["_metrics"]

    # ---- shared implementation ------------------------------------------

    @classmethod
    def _make_register(cls, attr: str, name: Optional[str]):
        """Build a decorator that registers a class in registry ``attr``.

        The registry dict is resolved with ``getattr`` at decoration time
        (not captured eagerly) so that ``temp_registry`` swaps are honored.
        """

        def decorator(component_class: type) -> type:
            key = name or component_class.__name__
            getattr(cls, attr)[key] = component_class
            # Remember the name the class was registered under.
            component_class._registry_name = key
            return component_class

        return decorator

    @classmethod
    def _lookup(cls, registry: Dict[str, type], kind: str, name: str) -> type:
        """Return ``registry[name]`` or raise a KeyError listing alternatives."""
        if name not in registry:
            available = list(registry.keys())
            raise KeyError(f"{kind} '{name}' not found. Available: {available}")
        return registry[name]

    # ---- registration decorators ----------------------------------------

    @classmethod
    def register_processor(cls, name: Optional[str] = None):
        """Decorator: register a data processor (default name: class name).

        Example:
            >>> @PluginRegistry.register_processor("standard_scaler")
            ... class StandardScaler(BaseProcessor):
            ...     pass
        """
        return cls._make_register("_processors", name)

    @classmethod
    def register_model(cls, name: Optional[str] = None):
        """Decorator: register a model (default name: class name)."""
        return cls._make_register("_models", name)

    @classmethod
    def register_splitter(cls, name: Optional[str] = None):
        """Decorator: register a split strategy (default name: class name)."""
        return cls._make_register("_splitters", name)

    @classmethod
    def register_metric(cls, name: Optional[str] = None):
        """Decorator: register an evaluation metric (default name: class name)."""
        return cls._make_register("_metrics", name)

    # ---- lookups ---------------------------------------------------------

    @classmethod
    def get_processor(cls, name: str) -> Type[BaseProcessor]:
        """Return the processor class registered as *name*.

        Raises:
            KeyError: If no processor was registered under *name*.
        """
        return cls._lookup(cls._processors, "Processor", name)

    @classmethod
    def get_model(cls, name: str) -> Type[BaseModel]:
        """Return the model class registered as *name*.

        Raises:
            KeyError: If no model was registered under *name*.
        """
        return cls._lookup(cls._models, "Model", name)

    @classmethod
    def get_splitter(cls, name: str) -> Type[BaseSplitter]:
        """Return the splitter class registered as *name*.

        Raises:
            KeyError: If no splitter was registered under *name*.
        """
        return cls._lookup(cls._splitters, "Splitter", name)

    @classmethod
    def get_metric(cls, name: str) -> Type[BaseMetric]:
        """Return the metric class registered as *name*.

        Raises:
            KeyError: If no metric was registered under *name*.
        """
        return cls._lookup(cls._metrics, "Metric", name)

    # ---- introspection ---------------------------------------------------

    @classmethod
    def list_processors(cls) -> List[str]:
        """List the names of all registered processors."""
        return list(cls._processors.keys())

    @classmethod
    def list_models(cls) -> List[str]:
        """List the names of all registered models."""
        return list(cls._models.keys())

    @classmethod
    def list_splitters(cls) -> List[str]:
        """List the names of all registered split strategies."""
        return list(cls._splitters.keys())

    @classmethod
    def list_metrics(cls) -> List[str]:
        """List the names of all registered metrics."""
        return list(cls._metrics.keys())

    @classmethod
    def clear_all(cls) -> None:
        """Remove every registration (intended for tests)."""
        cls._processors.clear()
        cls._models.clear()
        cls._splitters.clear()
        cls._metrics.clear()
__all__ = ["PluginRegistry"]