Files
ProStock/src/models/pipeline.py
liaozhaorun 9f95be56a0 feat(models): 实现机器学习模型训练框架
- 添加核心抽象:Processor、Model、Splitter、Metric 基类
- 实现阶段感知机制(TRAIN/TEST/ALL),防止数据泄露
- 内置 8 个数据处理器和 3 种时序划分策略
- 支持 LightGBM、CatBoost 模型
- PluginRegistry 装饰器注册,插件式架构
- 22 个单元测试
2026-02-23 01:37:34 +08:00

71 lines
2.3 KiB
Python

"""数据处理流水线
管理多个处理器的顺序执行,支持阶段感知处理。
"""
from typing import List, Dict
import polars as pl
from src.models.core import BaseProcessor, PipelineStage
class ProcessingPipeline:
"""数据处理流水线
按顺序执行多个处理器,自动处理阶段标记。
关键特性:在测试阶段使用训练阶段学习到的参数,防止数据泄露。
"""
def __init__(self, processors: List[BaseProcessor]):
"""初始化流水线
Args:
processors: 处理器列表(按执行顺序)
"""
self.processors = processors
self._fitted_processors: Dict[int, BaseProcessor] = {}
def fit_transform(
self, data: pl.DataFrame, stage: PipelineStage = PipelineStage.TRAIN
) -> pl.DataFrame:
"""在训练数据上fit所有处理器并transform"""
result = data
for i, processor in enumerate(self.processors):
if processor.stage in [PipelineStage.ALL, stage]:
result = processor.fit_transform(result)
self._fitted_processors[i] = processor
elif stage == PipelineStage.TRAIN and processor.stage == PipelineStage.TEST:
processor.fit(result)
self._fitted_processors[i] = processor
return result
def transform(
self, data: pl.DataFrame, stage: PipelineStage = PipelineStage.TEST
) -> pl.DataFrame:
"""在测试数据上应用已fit的处理器"""
result = data
for i, processor in enumerate(self.processors):
if processor.stage in [PipelineStage.ALL, stage]:
if i in self._fitted_processors:
result = self._fitted_processors[i].transform(result)
else:
result = processor.transform(result)
return result
def save_processors(self, path: str) -> None:
"""保存所有已fit的处理器状态"""
import pickle
with open(path, "wb") as f:
pickle.dump(self._fitted_processors, f)
def load_processors(self, path: str) -> None:
"""加载处理器状态"""
import pickle
with open(path, "rb") as f:
self._fitted_processors = pickle.load(f)
__all__ = ["ProcessingPipeline"]