feat(models): 实现机器学习模型训练框架
- 添加核心抽象:Processor、Model、Splitter、Metric 基类 - 实现阶段感知机制(TRAIN/TEST/ALL),防止数据泄露 - 内置 8 个数据处理器和 3 种时序划分策略 - 支持 LightGBM、CatBoost 模型 - PluginRegistry 装饰器注册,插件式架构 - 22 个单元测试
This commit is contained in:
86
src/models/__init__.py
Normal file
86
src/models/__init__.py
Normal file
@@ -0,0 +1,86 @@
|
||||
"""ProStock 模型训练框架
|
||||
|
||||
组件化、低耦合、插件式的机器学习训练框架。
|
||||
|
||||
示例:
|
||||
>>> from src.models import (
|
||||
... PluginRegistry, ProcessingPipeline,
|
||||
... PipelineStage, BaseProcessor
|
||||
... )
|
||||
|
||||
>>> # 获取注册的处理器
|
||||
>>> scaler_class = PluginRegistry.get_processor("standard_scaler")
|
||||
>>> scaler = scaler_class()
|
||||
|
||||
>>> # 创建处理流水线
|
||||
>>> pipeline = ProcessingPipeline([
|
||||
... PluginRegistry.get_processor("dropna")(),
|
||||
... PluginRegistry.get_processor("winsorizer")(lower=0.01, upper=0.99),
|
||||
... PluginRegistry.get_processor("standard_scaler")(),
|
||||
... ])
|
||||
"""
|
||||
|
||||
# 导入核心抽象类和划分策略
|
||||
from src.models.core import (
|
||||
PipelineStage,
|
||||
TaskType,
|
||||
BaseProcessor,
|
||||
BaseModel,
|
||||
BaseSplitter,
|
||||
BaseMetric,
|
||||
TimeSeriesSplit,
|
||||
WalkForwardSplit,
|
||||
ExpandingWindowSplit,
|
||||
)
|
||||
|
||||
# 导入注册中心
|
||||
from src.models.registry import PluginRegistry
|
||||
|
||||
# 导入处理流水线
|
||||
from src.models.pipeline import ProcessingPipeline
|
||||
|
||||
# 导入并注册内置处理器
|
||||
from src.models.processors.processors import (
|
||||
DropNAProcessor,
|
||||
FillNAProcessor,
|
||||
Winsorizer,
|
||||
StandardScaler,
|
||||
MinMaxScaler,
|
||||
RankTransformer,
|
||||
Neutralizer,
|
||||
)
|
||||
|
||||
# 导入并注册内置模型
|
||||
from src.models.models.models import (
|
||||
LightGBMModel,
|
||||
CatBoostModel,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# 核心抽象
|
||||
"PipelineStage",
|
||||
"TaskType",
|
||||
"BaseProcessor",
|
||||
"BaseModel",
|
||||
"BaseSplitter",
|
||||
"BaseMetric",
|
||||
# 划分策略
|
||||
"TimeSeriesSplit",
|
||||
"WalkForwardSplit",
|
||||
"ExpandingWindowSplit",
|
||||
# 注册中心
|
||||
"PluginRegistry",
|
||||
# 处理流水线
|
||||
"ProcessingPipeline",
|
||||
# 处理器
|
||||
"DropNAProcessor",
|
||||
"FillNAProcessor",
|
||||
"Winsorizer",
|
||||
"StandardScaler",
|
||||
"MinMaxScaler",
|
||||
"RankTransformer",
|
||||
"Neutralizer",
|
||||
# 模型
|
||||
"LightGBMModel",
|
||||
"CatBoostModel",
|
||||
]
|
||||
30
src/models/core/__init__.py
Normal file
30
src/models/core/__init__.py
Normal file
@@ -0,0 +1,30 @@
|
||||
"""核心模块导出"""
|
||||
|
||||
from src.models.core.base import (
|
||||
PipelineStage,
|
||||
TaskType,
|
||||
BaseProcessor,
|
||||
BaseModel,
|
||||
BaseSplitter,
|
||||
BaseMetric,
|
||||
)
|
||||
|
||||
from src.models.core.splitter import (
|
||||
TimeSeriesSplit,
|
||||
WalkForwardSplit,
|
||||
ExpandingWindowSplit,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# 基础抽象
|
||||
"PipelineStage",
|
||||
"TaskType",
|
||||
"BaseProcessor",
|
||||
"BaseModel",
|
||||
"BaseSplitter",
|
||||
"BaseMetric",
|
||||
# 划分策略
|
||||
"TimeSeriesSplit",
|
||||
"WalkForwardSplit",
|
||||
"ExpandingWindowSplit",
|
||||
]
|
||||
351
src/models/core/base.py
Normal file
351
src/models/core/base.py
Normal file
@@ -0,0 +1,351 @@
|
||||
"""模型训练框架核心抽象类
|
||||
|
||||
提供处理器、模型、划分策略和评估指标的基类定义。
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum, auto
|
||||
from typing import Any, Dict, Iterator, List, Optional, Tuple, Literal
|
||||
import polars as pl
|
||||
import numpy as np
|
||||
|
||||
# 任务类型
|
||||
TaskType = Literal["classification", "regression", "ranking"]
|
||||
|
||||
|
||||
class PipelineStage(Enum):
    """Marker for the pipeline phase a processor participates in.

    Processors carry a stage tag so the pipeline can decide when to run
    them; this is the framework's mechanism for avoiding data leakage.

    Attributes:
        ALL: Active in every phase (train, test, validation).
        TRAIN: Active only during training.
        TEST: Active only during testing.
        VALIDATION: Active only during validation.
    """

    ALL = auto()
    TRAIN = auto()
    TEST = auto()
    VALIDATION = auto()
|
||||
|
||||
|
||||
class BaseProcessor(ABC):
    """Abstract base class for every data processor.

    Concrete processors inherit from this class. The ``stage`` class
    attribute controls in which pipeline phases the processor runs — the
    key mechanism for preventing data leakage.

    Stage semantics:
        - ALL: the same parameters are used during training and testing.
        - TRAIN: parameters (quantiles, means, ...) are learned on the
          training data only; the test phase reuses the learned values.
        - TEST: executed during the test phase only.
    """

    # Subclasses override this to restrict when they run.
    stage: PipelineStage = PipelineStage.ALL

    def __init__(self, columns: Optional[List[str]] = None, **params):
        """Set up the processor.

        Args:
            columns: Columns to operate on; ``None`` means every numeric column.
            **params: Processor-specific options.
        """
        self.columns = columns
        self.params = params
        self._is_fitted = False
        self._fitted_params: Dict[str, Any] = {}

    @abstractmethod
    def fit(self, data: pl.DataFrame) -> "BaseProcessor":
        """Learn parameters from the training data.

        Called exactly once, on training data. Learned values are stored
        in ``self._fitted_params``.

        Args:
            data: Training data.

        Returns:
            self, to allow chaining.
        """

    @abstractmethod
    def transform(self, data: pl.DataFrame) -> pl.DataFrame:
        """Apply the transformation.

        Invoked in both the train and the test phase, using whatever
        ``fit()`` has previously learned.

        Args:
            data: Input data.

        Returns:
            The transformed data.
        """

    def fit_transform(self, data: pl.DataFrame) -> pl.DataFrame:
        """Convenience wrapper: fit on *data*, then transform it.

        Args:
            data: Training data.

        Returns:
            The transformed data.
        """
        fitted = self.fit(data)
        return fitted.transform(data)

    def get_fitted_params(self) -> Dict[str, Any]:
        """Return a copy of the learned parameters (for save/load).

        Returns:
            A copy of the fitted-parameter dict.
        """
        return dict(self._fitted_params)

    def set_fitted_params(self, params: Dict[str, Any]) -> "BaseProcessor":
        """Install learned parameters (e.g. restored from a checkpoint).

        Args:
            params: Parameter dict to install (copied defensively).

        Returns:
            self, to allow chaining.
        """
        self._fitted_params = dict(params)
        self._is_fitted = True
        return self
|
||||
|
||||
|
||||
class BaseModel(ABC):
    """Abstract base class for machine-learning models.

    Provides one uniform interface over several backends (LightGBM,
    CatBoost, XGBoost, ...) and task types (classification, regression,
    ranking).
    """

    def __init__(
        self,
        task_type: TaskType,
        params: Optional[Dict[str, Any]] = None,
        name: Optional[str] = None,
    ):
        """Set up the model wrapper.

        Args:
            task_type: One of "classification", "regression", "ranking".
            params: Backend-specific hyper-parameters.
            name: Display name for logs/reports; defaults to the class name.
        """
        self.task_type = task_type
        self.params = params or {}
        self.name = name or self.__class__.__name__
        self._model: Any = None
        self._is_fitted = False

    @abstractmethod
    def fit(
        self,
        X: pl.DataFrame,
        y: pl.Series,
        X_val: Optional[pl.DataFrame] = None,
        y_val: Optional[pl.Series] = None,
        **fit_params,
    ) -> "BaseModel":
        """Train the model.

        Args:
            X: Feature matrix.
            y: Target variable.
            X_val: Optional validation features.
            y_val: Optional validation targets.
            **fit_params: Extra backend-specific fit options.

        Returns:
            self, to allow chaining.
        """

    @abstractmethod
    def predict(self, X: pl.DataFrame) -> np.ndarray:
        """Produce predictions for *X*.

        Args:
            X: Feature matrix.

        Returns:
            Prediction array:
            - classification: class labels or probabilities
            - regression: continuous values
            - ranking: ranking scores
        """

    def predict_proba(self, X: pl.DataFrame) -> np.ndarray:
        """Predict class probabilities (classification only).

        Args:
            X: Feature matrix.

        Returns:
            Probability array of shape [n_samples, n_classes].

        Raises:
            NotImplementedError: For non-classification tasks.
        """
        raise NotImplementedError(
            "predict_proba only available for classification tasks"
        )

    def get_feature_importance(self) -> Optional[pl.DataFrame]:
        """Return feature importances when the backend supports them.

        Returns:
            DataFrame[feature, importance], or None when unsupported.
        """
        return None

    def save(self, path: str) -> None:
        """Serialize this model to *path* via pickle.

        Args:
            path: Destination file path.
        """
        import pickle

        with open(path, "wb") as fh:
            pickle.dump(self, fh)

    @classmethod
    def load(cls, path: str) -> "BaseModel":
        """Deserialize a model previously written by :meth:`save`.

        WARNING: pickle executes arbitrary code while loading — only load
        files from trusted sources.

        Args:
            path: Path to the pickled model file.

        Returns:
            The restored model instance.
        """
        import pickle

        with open(path, "rb") as fh:
            return pickle.load(fh)
|
||||
|
||||
|
||||
class BaseSplitter(ABC):
    """Base class for data-splitting strategies.

    Splitting strategies tailored to time-series data, designed to avoid
    look-ahead (future) leakage between train and test sets.
    """

    @abstractmethod
    def split(
        self, data: pl.DataFrame, date_col: str = "trade_date"
    ) -> Iterator[Tuple[List[int], List[int]]]:
        """Generate train/test row indices.

        Args:
            data: The full dataset.
            date_col: Name of the date column.

        Yields:
            (train_indices, test_indices) tuples.
        """
        pass

    @abstractmethod
    def get_split_dates(
        self, data: pl.DataFrame, date_col: str = "trade_date"
    ) -> List[Tuple[str, str, str, str]]:
        """Return the date range covered by each split.

        Args:
            data: The full dataset.
            date_col: Name of the date column.

        Returns:
            [(train_start, train_end, test_start, test_end), ...]
        """
        pass
|
||||
|
||||
|
||||
class BaseMetric(ABC):
    """Base class for evaluation metrics.

    Every metric inherits from this class. Two usage modes are supported:
    a one-off computation (:meth:`compute`) and an accumulating mode
    (:meth:`update` / :meth:`get_mean` / :meth:`get_std`).
    """

    def __init__(self, name: Optional[str] = None):
        """Set up the metric.

        Args:
            name: Metric name; defaults to the class name.
        """
        self.name = name or self.__class__.__name__
        self._values: List[float] = []

    @abstractmethod
    def compute(self, y_true: np.ndarray, y_pred: np.ndarray) -> float:
        """Compute the metric for one batch.

        Args:
            y_true: Ground-truth values.
            y_pred: Predicted values.

        Returns:
            The metric value.
        """

    def update(self, y_true: np.ndarray, y_pred: np.ndarray) -> "BaseMetric":
        """Compute the metric and append it to the running history.

        Args:
            y_true: Ground-truth values.
            y_pred: Predicted values.

        Returns:
            self, to allow chaining.
        """
        score = self.compute(y_true, y_pred)
        self._values.append(score)
        return self

    def get_mean(self) -> float:
        """Mean of all accumulated values (0.0 when nothing accumulated)."""
        return float(np.mean(self._values)) if self._values else 0.0

    def get_std(self) -> float:
        """Standard deviation of accumulated values (0.0 when empty).

        Note: population std — numpy's default ddof=0.
        """
        return float(np.std(self._values)) if self._values else 0.0

    def reset(self) -> "BaseMetric":
        """Clear the accumulated history.

        Returns:
            self, to allow chaining.
        """
        self._values = []
        return self
|
||||
|
||||
|
||||
__all__ = [
|
||||
"PipelineStage",
|
||||
"TaskType",
|
||||
"BaseProcessor",
|
||||
"BaseModel",
|
||||
"BaseSplitter",
|
||||
"BaseMetric",
|
||||
]
|
||||
222
src/models/core/splitter.py
Normal file
222
src/models/core/splitter.py
Normal file
@@ -0,0 +1,222 @@
|
||||
"""时间序列数据划分策略
|
||||
|
||||
提供针对金融时间序列的特殊划分策略,防止未来泄露。
|
||||
"""
|
||||
|
||||
from typing import Iterator, List, Tuple
|
||||
import polars as pl
|
||||
|
||||
from src.models.core.base import BaseSplitter
|
||||
|
||||
|
||||
class TimeSeriesSplit(BaseSplitter):
    """Time-ordered K-fold split — training data always precedes test data.

    Folds are laid out chronologically: each fold trains on everything up
    to a cut-off date and tests on the following window. A ``gap`` of
    days is left between train end and test start to limit leakage from
    overlapping labels.

    Args:
        n_splits: Number of folds.
        gap: Days skipped between the train end and the test start.
        min_train_size: Minimum size of the first training window (days).
    """

    def __init__(self, n_splits: int = 5, gap: int = 5, min_train_size: int = 252):
        if n_splits < 1:
            raise ValueError("n_splits must be >= 1")
        if gap < 0:
            raise ValueError("gap must be >= 0")
        self.n_splits = n_splits
        self.gap = gap
        self.min_train_size = min_train_size

    def _fold_bounds(self, n_dates: int) -> Iterator[Tuple[int, int, int]]:
        """Yield (train_end, test_start, test_end) date-index bounds per fold.

        Shared by :meth:`split` and :meth:`get_split_dates` so the fold
        arithmetic cannot drift between the two.

        Raises:
            ValueError: When there are too few distinct dates for the
                configured ``min_train_size``/``n_splits`` (the previous
                arithmetic silently produced empty or nonsense folds).
        """
        test_size = (n_dates - self.min_train_size) // self.n_splits
        if test_size < 1:
            raise ValueError(
                f"not enough dates ({n_dates}) for min_train_size="
                f"{self.min_train_size} and n_splits={self.n_splits}"
            )
        for i in range(self.n_splits):
            train_end = self.min_train_size + i * test_size
            test_start = train_end + self.gap
            test_end = test_start + test_size
            if test_end > n_dates:
                break
            yield train_end, test_start, test_end

    def split(
        self, data: pl.DataFrame, date_col: str = "trade_date"
    ) -> Iterator[Tuple[List[int], List[int]]]:
        """Yield (train_indices, test_indices) row-index lists per fold."""
        dates = data[date_col].unique().sort()
        # Attach row indices once, instead of twice per fold as before.
        indexed = data.with_row_index()

        for train_end, test_start, test_end in self._fold_bounds(len(dates)):
            train_dates = dates[:train_end]
            test_dates = dates[test_start:test_end]

            train_idx = indexed.filter(
                data[date_col].is_in(train_dates.to_list())
            )["index"].to_list()
            test_idx = indexed.filter(
                data[date_col].is_in(test_dates.to_list())
            )["index"].to_list()

            yield train_idx, test_idx

    def get_split_dates(
        self, data: pl.DataFrame, date_col: str = "trade_date"
    ) -> List[Tuple[str, str, str, str]]:
        """Return (train_start, train_end, test_start, test_end) per fold."""
        dates = data[date_col].unique().sort()
        return [
            (
                str(dates[0]),
                str(dates[train_end - 1]),
                str(dates[test_start]),
                str(dates[test_end - 1]),
            )
            for train_end, test_start, test_end in self._fold_bounds(len(dates))
        ]
|
||||
|
||||
|
||||
class WalkForwardSplit(BaseSplitter):
    """Walk-forward validation with a fixed-size rolling training window.

    Each fold trains on the most recent ``train_window`` days and tests on
    the next ``test_window`` days, then the whole window rolls forward by
    ``test_window`` days.

    Args:
        train_window: Training window length in days.
        test_window: Test window length in days.
        gap: Days skipped between train end and test start.
    """

    def __init__(self, train_window: int = 504, test_window: int = 21, gap: int = 5):
        if train_window < 1 or test_window < 1:
            raise ValueError("train_window and test_window must be >= 1")
        if gap < 0:
            raise ValueError("gap must be >= 0")
        self.train_window = train_window
        self.test_window = test_window
        self.gap = gap

    def _fold_bounds(self, n_dates: int) -> Iterator[Tuple[int, int, int, int]]:
        """Yield (train_start, train_end, test_start, test_end) index bounds.

        Shared by :meth:`split` and :meth:`get_split_dates` so the fold
        arithmetic cannot drift between the two.
        """
        anchor = self.train_window
        while anchor + self.gap + self.test_window <= n_dates:
            test_start = anchor + self.gap
            yield (
                anchor - self.train_window,
                anchor,
                test_start,
                test_start + self.test_window,
            )
            anchor += self.test_window

    def split(
        self, data: pl.DataFrame, date_col: str = "trade_date"
    ) -> Iterator[Tuple[List[int], List[int]]]:
        """Yield (train_indices, test_indices) row-index lists per fold."""
        dates = data[date_col].unique().sort()
        # Attach row indices once, instead of twice per fold as before.
        indexed = data.with_row_index()

        for train_start, train_end, test_start, test_end in self._fold_bounds(
            len(dates)
        ):
            train_dates = dates[train_start:train_end]
            test_dates = dates[test_start:test_end]

            train_idx = indexed.filter(
                data[date_col].is_in(train_dates.to_list())
            )["index"].to_list()
            test_idx = indexed.filter(
                data[date_col].is_in(test_dates.to_list())
            )["index"].to_list()

            yield train_idx, test_idx

    def get_split_dates(
        self, data: pl.DataFrame, date_col: str = "trade_date"
    ) -> List[Tuple[str, str, str, str]]:
        """Return (train_start, train_end, test_start, test_end) per fold."""
        dates = data[date_col].unique().sort()
        return [
            (
                str(dates[train_start]),
                str(dates[train_end - 1]),
                str(dates[test_start]),
                str(dates[test_end - 1]),
            )
            for train_start, train_end, test_start, test_end in self._fold_bounds(
                len(dates)
            )
        ]
|
||||
|
||||
|
||||
class ExpandingWindowSplit(BaseSplitter):
    """Expanding-window split — the training set grows with every fold.

    The first fold trains on ``initial_train_size`` days; each subsequent
    fold extends the training window by ``test_window`` days.

    Args:
        initial_train_size: Size of the first training window (days).
        test_window: Test window length (days).
        gap: Days skipped between train end and test start.
    """

    def __init__(
        self, initial_train_size: int = 252, test_window: int = 21, gap: int = 5
    ):
        if initial_train_size < 1 or test_window < 1:
            raise ValueError("initial_train_size and test_window must be >= 1")
        if gap < 0:
            raise ValueError("gap must be >= 0")
        self.initial_train_size = initial_train_size
        self.test_window = test_window
        self.gap = gap

    def _fold_bounds(self, n_dates: int) -> Iterator[Tuple[int, int, int]]:
        """Yield (train_end, test_start, test_end) date-index bounds per fold.

        Shared by :meth:`split` and :meth:`get_split_dates` so the fold
        arithmetic cannot drift between the two.
        """
        train_end = self.initial_train_size
        while train_end + self.gap + self.test_window <= n_dates:
            test_start = train_end + self.gap
            yield train_end, test_start, test_start + self.test_window
            train_end += self.test_window

    def split(
        self, data: pl.DataFrame, date_col: str = "trade_date"
    ) -> Iterator[Tuple[List[int], List[int]]]:
        """Yield (train_indices, test_indices) row-index lists per fold."""
        dates = data[date_col].unique().sort()
        # Attach row indices once, instead of twice per fold as before.
        indexed = data.with_row_index()

        for train_end, test_start, test_end in self._fold_bounds(len(dates)):
            train_dates = dates[:train_end]
            test_dates = dates[test_start:test_end]

            train_idx = indexed.filter(
                data[date_col].is_in(train_dates.to_list())
            )["index"].to_list()
            test_idx = indexed.filter(
                data[date_col].is_in(test_dates.to_list())
            )["index"].to_list()

            yield train_idx, test_idx

    def get_split_dates(
        self, data: pl.DataFrame, date_col: str = "trade_date"
    ) -> List[Tuple[str, str, str, str]]:
        """Return (train_start, train_end, test_start, test_end) per fold."""
        dates = data[date_col].unique().sort()
        return [
            (
                str(dates[0]),
                str(dates[train_end - 1]),
                str(dates[test_start]),
                str(dates[test_end - 1]),
            )
            for train_end, test_start, test_end in self._fold_bounds(len(dates))
        ]
|
||||
|
||||
|
||||
__all__ = [
|
||||
"TimeSeriesSplit",
|
||||
"WalkForwardSplit",
|
||||
"ExpandingWindowSplit",
|
||||
]
|
||||
11
src/models/models/__init__.py
Normal file
11
src/models/models/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
||||
"""模型模块"""
|
||||
|
||||
from src.models.models.models import (
|
||||
LightGBMModel,
|
||||
CatBoostModel,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"LightGBMModel",
|
||||
"CatBoostModel",
|
||||
]
|
||||
210
src/models/models/models.py
Normal file
210
src/models/models/models.py
Normal file
@@ -0,0 +1,210 @@
|
||||
"""内置机器学习模型
|
||||
|
||||
提供 LightGBM、CatBoost 等模型的统一接口包装器。
|
||||
"""
|
||||
|
||||
from typing import Optional, Dict, Any
|
||||
import polars as pl
|
||||
import numpy as np
|
||||
|
||||
from src.models.core import BaseModel, TaskType
|
||||
from src.models.registry import PluginRegistry
|
||||
|
||||
|
||||
@PluginRegistry.register_model("lightgbm")
|
||||
class LightGBMModel(BaseModel):
|
||||
"""LightGBM 模型包装器
|
||||
|
||||
支持分类、回归、排序三种任务类型。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
task_type: TaskType,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
name: Optional[str] = None,
|
||||
):
|
||||
super().__init__(task_type, params, name)
|
||||
self._model = None
|
||||
|
||||
def fit(
|
||||
self,
|
||||
X: pl.DataFrame,
|
||||
y: pl.Series,
|
||||
X_val: Optional[pl.DataFrame] = None,
|
||||
y_val: Optional[pl.Series] = None,
|
||||
**fit_params,
|
||||
) -> "LightGBMModel":
|
||||
"""训练模型"""
|
||||
try:
|
||||
import lightgbm as lgb
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"lightgbm is required. Install with: uv pip install lightgbm"
|
||||
)
|
||||
|
||||
X_arr = X.to_numpy()
|
||||
y_arr = y.to_numpy()
|
||||
|
||||
train_data = lgb.Dataset(X_arr, label=y_arr)
|
||||
valid_sets = [train_data]
|
||||
valid_names = ["train"]
|
||||
|
||||
if X_val is not None and y_val is not None:
|
||||
valid_data = lgb.Dataset(X_val.to_numpy(), label=y_val.to_numpy())
|
||||
valid_sets.append(valid_data)
|
||||
valid_names.append("valid")
|
||||
|
||||
default_params = {
|
||||
"objective": self._get_objective(),
|
||||
"metric": self._get_metric(),
|
||||
"boosting_type": "gbdt",
|
||||
"num_leaves": 31,
|
||||
"learning_rate": 0.05,
|
||||
"feature_fraction": 0.9,
|
||||
"bagging_fraction": 0.8,
|
||||
"bagging_freq": 5,
|
||||
"verbose": -1,
|
||||
}
|
||||
default_params.update(self.params)
|
||||
|
||||
callbacks = []
|
||||
if len(valid_sets) > 1:
|
||||
callbacks.append(lgb.early_stopping(stopping_rounds=10, verbose=False))
|
||||
|
||||
self._model = lgb.train(
|
||||
default_params,
|
||||
train_data,
|
||||
num_boost_round=fit_params.get("num_boost_round", 100),
|
||||
valid_sets=valid_sets,
|
||||
valid_names=valid_names,
|
||||
callbacks=callbacks,
|
||||
)
|
||||
self._is_fitted = True
|
||||
return self
|
||||
|
||||
def predict(self, X: pl.DataFrame) -> np.ndarray:
|
||||
"""预测"""
|
||||
if not self._is_fitted:
|
||||
raise RuntimeError("Model not fitted yet")
|
||||
return self._model.predict(X.to_numpy())
|
||||
|
||||
def predict_proba(self, X: pl.DataFrame) -> np.ndarray:
|
||||
"""预测概率(仅分类任务)"""
|
||||
if self.task_type != "classification":
|
||||
raise ValueError("predict_proba only for classification")
|
||||
probs = self.predict(X)
|
||||
if len(probs.shape) == 1:
|
||||
return np.vstack([1 - probs, probs]).T
|
||||
return probs
|
||||
|
||||
def get_feature_importance(self) -> Optional[pl.DataFrame]:
|
||||
"""获取特征重要性"""
|
||||
if self._model is None:
|
||||
return None
|
||||
importance = self._model.feature_importance(importance_type="gain")
|
||||
feature_names = getattr(
|
||||
self._model,
|
||||
"feature_name",
|
||||
lambda: [f"feature_{i}" for i in range(len(importance))],
|
||||
)()
|
||||
return pl.DataFrame({"feature": feature_names, "importance": importance}).sort(
|
||||
"importance", descending=True
|
||||
)
|
||||
|
||||
def _get_objective(self) -> str:
|
||||
objectives = {
|
||||
"classification": "binary",
|
||||
"regression": "regression",
|
||||
"ranking": "lambdarank",
|
||||
}
|
||||
return objectives.get(self.task_type, "regression")
|
||||
|
||||
def _get_metric(self) -> str:
|
||||
metrics = {"classification": "auc", "regression": "rmse", "ranking": "ndcg"}
|
||||
return metrics.get(self.task_type, "rmse")
|
||||
|
||||
|
||||
@PluginRegistry.register_model("catboost")
|
||||
class CatBoostModel(BaseModel):
|
||||
"""CatBoost 模型包装器"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
task_type: TaskType,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
name: Optional[str] = None,
|
||||
):
|
||||
super().__init__(task_type, params, name)
|
||||
self._model = None
|
||||
|
||||
def fit(
|
||||
self,
|
||||
X: pl.DataFrame,
|
||||
y: pl.Series,
|
||||
X_val: Optional[pl.DataFrame] = None,
|
||||
y_val: Optional[pl.Series] = None,
|
||||
**fit_params,
|
||||
) -> "CatBoostModel":
|
||||
"""训练模型"""
|
||||
try:
|
||||
from catboost import CatBoostClassifier, CatBoostRegressor
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"catboost is required. Install with: uv pip install catboost"
|
||||
)
|
||||
|
||||
if self.task_type == "classification":
|
||||
model_class = CatBoostClassifier
|
||||
default_params = {"loss_function": "Logloss", "eval_metric": "AUC"}
|
||||
elif self.task_type == "regression":
|
||||
model_class = CatBoostRegressor
|
||||
default_params = {"loss_function": "RMSE"}
|
||||
else:
|
||||
model_class = CatBoostRegressor
|
||||
default_params = {"loss_function": "QueryRMSE"}
|
||||
|
||||
default_params.update(self.params)
|
||||
default_params["verbose"] = False
|
||||
|
||||
self._model = model_class(**default_params)
|
||||
|
||||
eval_set = None
|
||||
if X_val is not None and y_val is not None:
|
||||
eval_set = (X_val.to_pandas(), y_val.to_pandas())
|
||||
|
||||
self._model.fit(
|
||||
X.to_pandas(),
|
||||
y.to_pandas(),
|
||||
eval_set=eval_set,
|
||||
early_stopping_rounds=fit_params.get("early_stopping_rounds", 10),
|
||||
verbose=False,
|
||||
)
|
||||
self._is_fitted = True
|
||||
return self
|
||||
|
||||
def predict(self, X: pl.DataFrame) -> np.ndarray:
|
||||
"""预测"""
|
||||
if not self._is_fitted:
|
||||
raise RuntimeError("Model not fitted yet")
|
||||
return self._model.predict(X.to_pandas())
|
||||
|
||||
def predict_proba(self, X: pl.DataFrame) -> np.ndarray:
|
||||
"""预测概率"""
|
||||
if self.task_type != "classification":
|
||||
raise ValueError("predict_proba only for classification")
|
||||
return self._model.predict_proba(X.to_pandas())
|
||||
|
||||
def get_feature_importance(self) -> Optional[pl.DataFrame]:
|
||||
"""获取特征重要性"""
|
||||
if self._model is None:
|
||||
return None
|
||||
return pl.DataFrame(
|
||||
{
|
||||
"feature": self._model.feature_names_,
|
||||
"importance": self._model.feature_importances_,
|
||||
}
|
||||
).sort("importance", descending=True)
|
||||
|
||||
|
||||
__all__ = ["LightGBMModel", "CatBoostModel"]
|
||||
70
src/models/pipeline.py
Normal file
70
src/models/pipeline.py
Normal file
@@ -0,0 +1,70 @@
|
||||
"""数据处理流水线
|
||||
|
||||
管理多个处理器的顺序执行,支持阶段感知处理。
|
||||
"""
|
||||
|
||||
from typing import List, Dict
|
||||
import polars as pl
|
||||
|
||||
from src.models.core import BaseProcessor, PipelineStage
|
||||
|
||||
|
||||
class ProcessingPipeline:
    """Sequential data-processing pipeline.

    Runs a list of processors in order, honouring each processor's stage
    tag. The intent is that test-time transforms reuse parameters fitted
    on training data, so statistics cannot leak from the test set.
    """

    def __init__(self, processors: List[BaseProcessor]):
        """Create the pipeline.

        Args:
            processors: Processors, in execution order.
        """
        self.processors = processors
        # Maps position in ``processors`` -> the (same) processor object
        # once it has been fitted by fit_transform().
        self._fitted_processors: Dict[int, BaseProcessor] = {}

    def fit_transform(
        self, data: pl.DataFrame, stage: PipelineStage = PipelineStage.TRAIN
    ) -> pl.DataFrame:
        """Fit every applicable processor on *data* and transform it.

        Processors tagged ALL or matching *stage* are fit and applied;
        when *stage* is TRAIN, TEST-only processors are fitted (so they
        can run later) but do not transform the training data.

        Args:
            data: Training data.
            stage: The phase being executed (default TRAIN).

        Returns:
            The transformed data.
        """
        result = data
        for i, processor in enumerate(self.processors):
            if processor.stage in [PipelineStage.ALL, stage]:
                result = processor.fit_transform(result)
                self._fitted_processors[i] = processor
            elif stage == PipelineStage.TRAIN and processor.stage == PipelineStage.TEST:
                # Fit TEST-only processors on training data without applying them.
                processor.fit(result)
                self._fitted_processors[i] = processor
        return result

    def transform(
        self, data: pl.DataFrame, stage: PipelineStage = PipelineStage.TEST
    ) -> pl.DataFrame:
        """Apply previously fitted processors to *data*.

        NOTE(review): processors tagged TRAIN are skipped here entirely
        (their stage is neither ALL nor the given *stage*), so their
        train-fitted parameters are never applied to test data. That
        appears to contradict BaseProcessor's documented TRAIN semantics
        ("the test phase uses parameters learned during training") —
        confirm whether TRAIN-stage processors should transform test data.

        Args:
            data: Data to transform (typically test data).
            stage: The phase being executed (default TEST).

        Returns:
            The transformed data.
        """
        result = data
        for i, processor in enumerate(self.processors):
            if processor.stage in [PipelineStage.ALL, stage]:
                if i in self._fitted_processors:
                    result = self._fitted_processors[i].transform(result)
                else:
                    # Never fitted: transforms with whatever default
                    # (empty) state the processor currently holds.
                    result = processor.transform(result)
        return result

    def save_processors(self, path: str) -> None:
        """Pickle the fitted-processor map to *path*."""
        import pickle

        with open(path, "wb") as f:
            pickle.dump(self._fitted_processors, f)

    def load_processors(self, path: str) -> None:
        """Restore the fitted-processor map from *path*.

        WARNING: pickle can execute arbitrary code — load trusted files only.
        """
        import pickle

        with open(path, "rb") as f:
            self._fitted_processors = pickle.load(f)
|
||||
|
||||
|
||||
__all__ = ["ProcessingPipeline"]
|
||||
21
src/models/processors/__init__.py
Normal file
21
src/models/processors/__init__.py
Normal file
@@ -0,0 +1,21 @@
|
||||
"""处理器模块"""
|
||||
|
||||
from src.models.processors.processors import (
|
||||
DropNAProcessor,
|
||||
FillNAProcessor,
|
||||
Winsorizer,
|
||||
StandardScaler,
|
||||
MinMaxScaler,
|
||||
RankTransformer,
|
||||
Neutralizer,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"DropNAProcessor",
|
||||
"FillNAProcessor",
|
||||
"Winsorizer",
|
||||
"StandardScaler",
|
||||
"MinMaxScaler",
|
||||
"RankTransformer",
|
||||
"Neutralizer",
|
||||
]
|
||||
238
src/models/processors/processors.py
Normal file
238
src/models/processors/processors.py
Normal file
@@ -0,0 +1,238 @@
|
||||
"""内置数据处理器
|
||||
|
||||
提供常用的数据预处理和转换处理器。
|
||||
"""
|
||||
|
||||
from typing import List, Optional, Dict, Any
|
||||
import polars as pl
|
||||
import numpy as np
|
||||
|
||||
from src.models.core import BaseProcessor, PipelineStage
|
||||
from src.models.registry import PluginRegistry
|
||||
|
||||
# 数值类型列表
|
||||
FLOAT_TYPES = [pl.Float32, pl.Float64, pl.Int8, pl.Int16, pl.Int32, pl.Int64]
|
||||
|
||||
|
||||
def _get_numeric_columns(
    data: pl.DataFrame, columns: Optional[List[str]] = None
) -> List[str]:
    """Resolve which columns a processor should operate on.

    Returns *columns* unchanged when given; otherwise every column of
    *data* whose dtype is one of FLOAT_TYPES.
    """
    if columns is None:
        return [name for name in data.columns if data[name].dtype in FLOAT_TYPES]
    return columns
|
||||
|
||||
|
||||
@PluginRegistry.register_processor("dropna")
|
||||
class DropNAProcessor(BaseProcessor):
|
||||
"""缺失值删除处理器"""
|
||||
|
||||
stage = PipelineStage.ALL
|
||||
|
||||
def fit(self, data: pl.DataFrame) -> "DropNAProcessor":
|
||||
self._is_fitted = True
|
||||
return self
|
||||
|
||||
def transform(self, data: pl.DataFrame) -> pl.DataFrame:
|
||||
cols = self.columns or data.columns
|
||||
return data.drop_nulls(subset=cols)
|
||||
|
||||
|
||||
@PluginRegistry.register_processor("fillna")
|
||||
class FillNAProcessor(BaseProcessor):
|
||||
"""缺失值填充处理器(只在训练阶段计算填充值)"""
|
||||
|
||||
stage = PipelineStage.TRAIN
|
||||
|
||||
def __init__(self, columns: Optional[List[str]] = None, method: str = "median"):
|
||||
super().__init__(columns)
|
||||
if method not in ["median", "mean", "zero"]:
|
||||
raise ValueError(f"Unknown fill method: {method}")
|
||||
self.method = method
|
||||
|
||||
def fit(self, data: pl.DataFrame) -> "FillNAProcessor":
|
||||
cols = _get_numeric_columns(data, self.columns)
|
||||
fill_values = {}
|
||||
|
||||
for col in cols:
|
||||
if self.method == "median":
|
||||
fill_values[col] = data[col].median()
|
||||
elif self.method == "mean":
|
||||
fill_values[col] = data[col].mean()
|
||||
elif self.method == "zero":
|
||||
fill_values[col] = 0.0
|
||||
|
||||
self._fitted_params = {"fill_values": fill_values, "columns": cols}
|
||||
self._is_fitted = True
|
||||
return self
|
||||
|
||||
def transform(self, data: pl.DataFrame) -> pl.DataFrame:
|
||||
result = data
|
||||
for col, val in self._fitted_params.get("fill_values", {}).items():
|
||||
if col in result.columns:
|
||||
result = result.with_columns(pl.col(col).fill_null(val).alias(col))
|
||||
return result
|
||||
|
||||
|
||||
@PluginRegistry.register_processor("winsorizer")
|
||||
class Winsorizer(BaseProcessor):
|
||||
"""缩尾处理器 - 防止极端值影响(只在训练阶段计算分位数)"""
|
||||
|
||||
stage = PipelineStage.TRAIN
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
columns: Optional[List[str]] = None,
|
||||
lower: float = 0.01,
|
||||
upper: float = 0.99,
|
||||
):
|
||||
super().__init__(columns)
|
||||
self.lower = lower
|
||||
self.upper = upper
|
||||
|
||||
def fit(self, data: pl.DataFrame) -> "Winsorizer":
|
||||
cols = _get_numeric_columns(data, self.columns)
|
||||
bounds = {}
|
||||
|
||||
for col in cols:
|
||||
bounds[col] = {
|
||||
"lower": data[col].quantile(self.lower),
|
||||
"upper": data[col].quantile(self.upper),
|
||||
}
|
||||
|
||||
self._fitted_params = {"bounds": bounds, "columns": cols}
|
||||
self._is_fitted = True
|
||||
return self
|
||||
|
||||
def transform(self, data: pl.DataFrame) -> pl.DataFrame:
|
||||
result = data
|
||||
for col, bounds in self._fitted_params.get("bounds", {}).items():
|
||||
if col in result.columns:
|
||||
result = result.with_columns(
|
||||
pl.col(col).clip(bounds["lower"], bounds["upper"]).alias(col)
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
@PluginRegistry.register_processor("standard_scaler")
|
||||
class StandardScaler(BaseProcessor):
|
||||
"""标准化处理器 - Z-score标准化"""
|
||||
|
||||
stage = PipelineStage.ALL
|
||||
|
||||
def fit(self, data: pl.DataFrame) -> "StandardScaler":
|
||||
cols = _get_numeric_columns(data, self.columns)
|
||||
stats = {}
|
||||
|
||||
for col in cols:
|
||||
stats[col] = {"mean": data[col].mean(), "std": data[col].std()}
|
||||
|
||||
self._fitted_params = {"stats": stats, "columns": cols}
|
||||
self._is_fitted = True
|
||||
return self
|
||||
|
||||
def transform(self, data: pl.DataFrame) -> pl.DataFrame:
|
||||
result = data
|
||||
for col, stats in self._fitted_params.get("stats", {}).items():
|
||||
if col in result.columns and stats["std"] is not None and stats["std"] > 0:
|
||||
result = result.with_columns(
|
||||
((pl.col(col) - stats["mean"]) / stats["std"]).alias(col)
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
@PluginRegistry.register_processor("minmax_scaler")
class MinMaxScaler(BaseProcessor):
    """Min-max normalization: rescale each fitted column to ``[0, 1]``."""

    # Applied identically in training and inference pipelines.
    stage = PipelineStage.ALL

    def fit(self, data: pl.DataFrame) -> "MinMaxScaler":
        """Record min and max of each numeric target column."""
        target_cols = _get_numeric_columns(data, self.columns)
        stats = {
            name: {"min": data[name].min(), "max": data[name].max()}
            for name in target_cols
        }
        self._fitted_params = {"stats": stats, "columns": target_cols}
        self._is_fitted = True
        return self

    def transform(self, data: pl.DataFrame) -> pl.DataFrame:
        """Rescale fitted columns of *data* to ``[0, 1]``.

        Columns that are missing, all-null, or constant are left untouched.
        """
        result = data
        for name, col_stats in self._fitted_params.get("stats", {}).items():
            if name not in result.columns:
                continue
            lo, hi = col_stats["min"], col_stats["max"]
            # Bug fix: an all-null column yields None stats and the original
            # `max - min` raised TypeError; skip such columns instead
            # (mirrors StandardScaler's None-std guard).
            if lo is None or hi is None:
                continue
            span = hi - lo
            if span > 0:
                result = result.with_columns(
                    ((pl.col(name) - lo) / span).alias(name)
                )
        return result
|
||||
|
||||
|
||||
@PluginRegistry.register_processor("rank_transformer")
class RankTransformer(BaseProcessor):
    """Cross-sectional rank transform.

    Replaces each target column with its rank computed within every date
    group, turning raw factor values into cross-sectional ranks.
    """

    # Ranking is stateless, so it is safe in both TRAIN and TEST stages.
    stage = PipelineStage.ALL

    def __init__(
        self,
        columns: Optional[List[str]] = None,
        date_col: str = "trade_date",
    ):
        """
        Args:
            columns: Columns to rank; ``None`` means all numeric columns.
            date_col: Column defining the cross-section to rank within.
                Defaults to ``"trade_date"`` (the previously hard-coded
                value), so existing callers are unaffected.
        """
        super().__init__(columns)
        self.date_col = date_col

    def fit(self, data: pl.DataFrame) -> "RankTransformer":
        """Stateless: ranking requires no fitted parameters."""
        self._is_fitted = True
        return self

    def transform(self, data: pl.DataFrame) -> pl.DataFrame:
        """Rank each target column within every ``date_col`` group."""
        target_cols = self.columns or _get_numeric_columns(data)
        rank_exprs = [
            pl.col(name).rank().over(self.date_col).alias(name)
            for name in target_cols
            if name in data.columns
        ]
        return data.with_columns(rank_exprs) if rank_exprs else data
|
||||
|
||||
|
||||
@PluginRegistry.register_processor("neutralizer")
class Neutralizer(BaseProcessor):
    """Group neutralization (e.g. industry or size neutralization).

    Subtracts the per-(date, group) mean from every target column,
    removing the group-level component of each factor.
    """

    # NOTE(review): group means are recomputed on the data being
    # transformed (stateless), including at TEST stage — confirm this
    # is the intended cross-sectional semantics rather than leakage.
    stage = PipelineStage.ALL

    def __init__(
        self,
        columns: Optional[List[str]] = None,
        group_col: str = "industry",
        exclude_cols: Optional[List[str]] = None,
        date_col: str = "trade_date",
    ):
        """
        Args:
            columns: Columns to neutralize; ``None`` means all numeric columns.
            group_col: Categorical column to neutralize against.
            exclude_cols: Columns to leave untouched even if selected.
            date_col: Column defining the cross-sectional date grouping.
                Defaults to ``"trade_date"`` (the previously hard-coded
                value), so existing callers are unaffected.
        """
        super().__init__(columns)
        self.group_col = group_col
        self.exclude_cols = exclude_cols or []
        self.date_col = date_col

    def fit(self, data: pl.DataFrame) -> "Neutralizer":
        """Stateless: neutralization requires no fitted parameters."""
        self._is_fitted = True
        return self

    def transform(self, data: pl.DataFrame) -> pl.DataFrame:
        """Demean each target column within every (date, group) cell."""
        target_cols = self.columns or _get_numeric_columns(data)
        excluded = set(self.exclude_cols)
        demean_exprs = [
            (
                pl.col(name)
                - pl.col(name).mean().over([self.date_col, self.group_col])
            ).alias(name)
            for name in target_cols
            if name in data.columns and name not in excluded
        ]
        return data.with_columns(demean_exprs) if demean_exprs else data
|
||||
|
||||
|
||||
# Public names of this module, re-exported via src.models.__init__;
# keep in sync with the processors registered above.
__all__ = [
    "DropNAProcessor",
    "FillNAProcessor",
    "Winsorizer",
    "StandardScaler",
    "MinMaxScaler",
    "RankTransformer",
    "Neutralizer",
]
|
||||
297
src/models/registry.py
Normal file
297
src/models/registry.py
Normal file
@@ -0,0 +1,297 @@
|
||||
"""插件注册中心
|
||||
|
||||
提供装饰器方式注册处理器、模型、划分策略等组件。
|
||||
实现真正的插件式架构 - 新功能只需注册即可使用。
|
||||
|
||||
示例:
|
||||
>>> @PluginRegistry.register_processor("standard_scaler")
|
||||
... class StandardScaler(BaseProcessor):
|
||||
... pass
|
||||
|
||||
>>> # 使用
|
||||
>>> scaler = PluginRegistry.get_processor("standard_scaler")()
|
||||
"""
|
||||
|
||||
from typing import Type, Dict, List, TypeVar, Optional
|
||||
from functools import wraps
|
||||
from weakref import WeakValueDictionary
|
||||
import contextlib
|
||||
|
||||
from src.models.core import BaseProcessor, BaseModel, BaseSplitter, BaseMetric
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class PluginRegistry:
    """Central plugin registry.

    Manages registration and lookup of all framework components
    (processors, models, splitters, metrics). New components are added
    via the ``register_*`` decorators and retrieved via ``get_*``.

    Attributes:
        _processors: Registered processor classes, keyed by name.
        _models: Registered model classes, keyed by name.
        _splitters: Registered splitter classes, keyed by name.
        _metrics: Registered metric classes, keyed by name.
    """

    # Project base-class annotations are quoted (forward references) so
    # importing this class never eagerly resolves them.
    _processors: "Dict[str, Type[BaseProcessor]]" = {}
    _models: "Dict[str, Type[BaseModel]]" = {}
    _splitters: "Dict[str, Type[BaseSplitter]]" = {}
    _metrics: "Dict[str, Type[BaseMetric]]" = {}

    # Attribute names of every component registry; drives the shared
    # snapshot/restore, register and lookup helpers below.
    _REGISTRY_ATTRS = ("_processors", "_models", "_splitters", "_metrics")

    @classmethod
    @contextlib.contextmanager
    def temp_registry(cls):
        """Context manager for temporary registrations.

        Components registered inside the context are discarded on exit,
        preventing state pollution between tests.

        Example:
            >>> with PluginRegistry.temp_registry():
            ...     @PluginRegistry.register_processor("temp_processor")
            ...     class TempProcessor(BaseProcessor):
            ...         pass
            ...     # "temp_processor" is usable here
            ...     # and removed automatically on exit
        """
        snapshot = {attr: getattr(cls, attr).copy() for attr in cls._REGISTRY_ATTRS}
        try:
            yield cls
        finally:
            # Rebind (not mutate) each registry, exactly as the original
            # restore did, so in-context clear_all() is also undone.
            for attr, saved in snapshot.items():
                setattr(cls, attr, saved)

    # ------------------------------------------------------------------
    # Shared helpers — deduplicate the four register_*/get_* families.
    # ------------------------------------------------------------------

    @classmethod
    def _make_register(cls, attr: str, name: Optional[str]):
        """Build a decorator that registers a class into registry *attr*.

        The key defaults to the class name and is also stored on the
        class as ``_registry_name``.
        """

        def decorator(component_class):
            key = name or component_class.__name__
            # Look the dict up at call time so temp_registry() rebinding
            # is always respected.
            getattr(cls, attr)[key] = component_class
            component_class._registry_name = key
            return component_class

        return decorator

    @classmethod
    def _lookup(cls, attr: str, kind: str, name: str):
        """Return the class registered under *name* in registry *attr*.

        Raises:
            KeyError: If *name* is unknown; the message lists available names.
        """
        registry = getattr(cls, attr)
        if name not in registry:
            available = list(registry.keys())
            raise KeyError(f"{kind} '{name}' not found. Available: {available}")
        return registry[name]

    # ------------------------------------------------------------------
    # Registration decorators
    # ------------------------------------------------------------------

    @classmethod
    def register_processor(cls, name: Optional[str] = None):
        """Decorator registering a data processor.

        Example:
            >>> @PluginRegistry.register_processor("standard_scaler")
            ... class StandardScaler(BaseProcessor):
            ...     pass

            >>> scaler_class = PluginRegistry.get_processor("standard_scaler")
            >>> scaler = scaler_class()

        Args:
            name: Registration name; defaults to the class name.

        Returns:
            A class decorator.
        """
        return cls._make_register("_processors", name)

    @classmethod
    def register_model(cls, name: Optional[str] = None):
        """Decorator registering a machine-learning model.

        Example:
            >>> @PluginRegistry.register_model("lightgbm")
            ... class LightGBMModel(BaseModel):
            ...     pass

        Args:
            name: Registration name; defaults to the class name.

        Returns:
            A class decorator.
        """
        return cls._make_register("_models", name)

    @classmethod
    def register_splitter(cls, name: Optional[str] = None):
        """Decorator registering a data-splitting strategy.

        Example:
            >>> @PluginRegistry.register_splitter("time_series")
            ... class TimeSeriesSplit(BaseSplitter):
            ...     pass

        Args:
            name: Registration name; defaults to the class name.

        Returns:
            A class decorator.
        """
        return cls._make_register("_splitters", name)

    @classmethod
    def register_metric(cls, name: Optional[str] = None):
        """Decorator registering an evaluation metric.

        Example:
            >>> @PluginRegistry.register_metric("ic")
            ... class ICMetric(BaseMetric):
            ...     pass

        Args:
            name: Registration name; defaults to the class name.

        Returns:
            A class decorator.
        """
        return cls._make_register("_metrics", name)

    # ------------------------------------------------------------------
    # Lookup
    # ------------------------------------------------------------------

    @classmethod
    def get_processor(cls, name: str) -> "Type[BaseProcessor]":
        """Return the processor class registered under *name*.

        Raises:
            KeyError: If no processor with that name is registered.
        """
        return cls._lookup("_processors", "Processor", name)

    @classmethod
    def get_model(cls, name: str) -> "Type[BaseModel]":
        """Return the model class registered under *name*.

        Raises:
            KeyError: If no model with that name is registered.
        """
        return cls._lookup("_models", "Model", name)

    @classmethod
    def get_splitter(cls, name: str) -> "Type[BaseSplitter]":
        """Return the splitter class registered under *name*.

        Raises:
            KeyError: If no splitter with that name is registered.
        """
        return cls._lookup("_splitters", "Splitter", name)

    @classmethod
    def get_metric(cls, name: str) -> "Type[BaseMetric]":
        """Return the metric class registered under *name*.

        Raises:
            KeyError: If no metric with that name is registered.
        """
        return cls._lookup("_metrics", "Metric", name)

    # ------------------------------------------------------------------
    # Introspection
    # ------------------------------------------------------------------

    @classmethod
    def list_processors(cls) -> List[str]:
        """Return the names of all registered processors."""
        return list(cls._processors.keys())

    @classmethod
    def list_models(cls) -> List[str]:
        """Return the names of all registered models."""
        return list(cls._models.keys())

    @classmethod
    def list_splitters(cls) -> List[str]:
        """Return the names of all registered splitters."""
        return list(cls._splitters.keys())

    @classmethod
    def list_metrics(cls) -> List[str]:
        """Return the names of all registered metrics."""
        return list(cls._metrics.keys())

    @classmethod
    def clear_all(cls) -> None:
        """Remove every registration (intended for tests)."""
        for attr in cls._REGISTRY_ATTRS:
            getattr(cls, attr).clear()
|
||||
|
||||
|
||||
# Only the registry class is part of this module's public API.
__all__ = ["PluginRegistry"]
|
||||
Reference in New Issue
Block a user