feat(models): 实现机器学习模型训练框架
- 添加核心抽象:Processor、Model、Splitter、Metric 基类 - 实现阶段感知机制(TRAIN/TEST/ALL),防止数据泄露 - 内置 8 个数据处理器和 3 种时序划分策略 - 支持 LightGBM、CatBoost 模型 - PluginRegistry 装饰器注册,插件式架构 - 22 个单元测试
This commit is contained in:
86
src/models/__init__.py
Normal file
86
src/models/__init__.py
Normal file
@@ -0,0 +1,86 @@
|
||||
"""ProStock 模型训练框架
|
||||
|
||||
组件化、低耦合、插件式的机器学习训练框架。
|
||||
|
||||
示例:
|
||||
>>> from src.models import (
|
||||
... PluginRegistry, ProcessingPipeline,
|
||||
... PipelineStage, BaseProcessor
|
||||
... )
|
||||
|
||||
>>> # 获取注册的处理器
|
||||
>>> scaler_class = PluginRegistry.get_processor("standard_scaler")
|
||||
>>> scaler = scaler_class()
|
||||
|
||||
>>> # 创建处理流水线
|
||||
>>> pipeline = ProcessingPipeline([
|
||||
... PluginRegistry.get_processor("dropna")(),
|
||||
... PluginRegistry.get_processor("winsorizer")(lower=0.01, upper=0.99),
|
||||
... PluginRegistry.get_processor("standard_scaler")(),
|
||||
... ])
|
||||
"""
|
||||
|
||||
# 导入核心抽象类和划分策略
|
||||
from src.models.core import (
|
||||
PipelineStage,
|
||||
TaskType,
|
||||
BaseProcessor,
|
||||
BaseModel,
|
||||
BaseSplitter,
|
||||
BaseMetric,
|
||||
TimeSeriesSplit,
|
||||
WalkForwardSplit,
|
||||
ExpandingWindowSplit,
|
||||
)
|
||||
|
||||
# 导入注册中心
|
||||
from src.models.registry import PluginRegistry
|
||||
|
||||
# 导入处理流水线
|
||||
from src.models.pipeline import ProcessingPipeline
|
||||
|
||||
# 导入并注册内置处理器
|
||||
from src.models.processors.processors import (
|
||||
DropNAProcessor,
|
||||
FillNAProcessor,
|
||||
Winsorizer,
|
||||
StandardScaler,
|
||||
MinMaxScaler,
|
||||
RankTransformer,
|
||||
Neutralizer,
|
||||
)
|
||||
|
||||
# 导入并注册内置模型
|
||||
from src.models.models.models import (
|
||||
LightGBMModel,
|
||||
CatBoostModel,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# 核心抽象
|
||||
"PipelineStage",
|
||||
"TaskType",
|
||||
"BaseProcessor",
|
||||
"BaseModel",
|
||||
"BaseSplitter",
|
||||
"BaseMetric",
|
||||
# 划分策略
|
||||
"TimeSeriesSplit",
|
||||
"WalkForwardSplit",
|
||||
"ExpandingWindowSplit",
|
||||
# 注册中心
|
||||
"PluginRegistry",
|
||||
# 处理流水线
|
||||
"ProcessingPipeline",
|
||||
# 处理器
|
||||
"DropNAProcessor",
|
||||
"FillNAProcessor",
|
||||
"Winsorizer",
|
||||
"StandardScaler",
|
||||
"MinMaxScaler",
|
||||
"RankTransformer",
|
||||
"Neutralizer",
|
||||
# 模型
|
||||
"LightGBMModel",
|
||||
"CatBoostModel",
|
||||
]
|
||||
Reference in New Issue
Block a user