feat(models): 实现机器学习模型训练框架
- 添加核心抽象:Processor、Model、Splitter、Metric 基类 - 实现阶段感知机制(TRAIN/TEST/ALL),防止数据泄露 - 内置 8 个数据处理器和 3 种时序划分策略 - 支持 LightGBM、CatBoost 模型 - PluginRegistry 装饰器注册,插件式架构 - 22 个单元测试
This commit is contained in:
70
src/models/pipeline.py
Normal file
70
src/models/pipeline.py
Normal file
@@ -0,0 +1,70 @@
|
||||
"""数据处理流水线
|
||||
|
||||
管理多个处理器的顺序执行,支持阶段感知处理。
|
||||
"""
|
||||
|
||||
from typing import List, Dict
|
||||
import polars as pl
|
||||
|
||||
from src.models.core import BaseProcessor, PipelineStage
|
||||
|
||||
|
||||
class ProcessingPipeline:
|
||||
"""数据处理流水线
|
||||
|
||||
按顺序执行多个处理器,自动处理阶段标记。
|
||||
关键特性:在测试阶段使用训练阶段学习到的参数,防止数据泄露。
|
||||
"""
|
||||
|
||||
def __init__(self, processors: List[BaseProcessor]):
|
||||
"""初始化流水线
|
||||
|
||||
Args:
|
||||
processors: 处理器列表(按执行顺序)
|
||||
"""
|
||||
self.processors = processors
|
||||
self._fitted_processors: Dict[int, BaseProcessor] = {}
|
||||
|
||||
def fit_transform(
|
||||
self, data: pl.DataFrame, stage: PipelineStage = PipelineStage.TRAIN
|
||||
) -> pl.DataFrame:
|
||||
"""在训练数据上fit所有处理器并transform"""
|
||||
result = data
|
||||
for i, processor in enumerate(self.processors):
|
||||
if processor.stage in [PipelineStage.ALL, stage]:
|
||||
result = processor.fit_transform(result)
|
||||
self._fitted_processors[i] = processor
|
||||
elif stage == PipelineStage.TRAIN and processor.stage == PipelineStage.TEST:
|
||||
processor.fit(result)
|
||||
self._fitted_processors[i] = processor
|
||||
return result
|
||||
|
||||
def transform(
|
||||
self, data: pl.DataFrame, stage: PipelineStage = PipelineStage.TEST
|
||||
) -> pl.DataFrame:
|
||||
"""在测试数据上应用已fit的处理器"""
|
||||
result = data
|
||||
for i, processor in enumerate(self.processors):
|
||||
if processor.stage in [PipelineStage.ALL, stage]:
|
||||
if i in self._fitted_processors:
|
||||
result = self._fitted_processors[i].transform(result)
|
||||
else:
|
||||
result = processor.transform(result)
|
||||
return result
|
||||
|
||||
def save_processors(self, path: str) -> None:
|
||||
"""保存所有已fit的处理器状态"""
|
||||
import pickle
|
||||
|
||||
with open(path, "wb") as f:
|
||||
pickle.dump(self._fitted_processors, f)
|
||||
|
||||
def load_processors(self, path: str) -> None:
|
||||
"""加载处理器状态"""
|
||||
import pickle
|
||||
|
||||
with open(path, "rb") as f:
|
||||
self._fitted_processors = pickle.load(f)
|
||||
|
||||
|
||||
__all__ = ["ProcessingPipeline"]
|
||||
Reference in New Issue
Block a user