"""数据处理流水线 管理多个处理器的顺序执行,支持阶段感知处理。 """ from typing import List, Dict import polars as pl from src.models.core import BaseProcessor, PipelineStage class ProcessingPipeline: """数据处理流水线 按顺序执行多个处理器,自动处理阶段标记。 关键特性:在测试阶段使用训练阶段学习到的参数,防止数据泄露。 """ def __init__(self, processors: List[BaseProcessor]): """初始化流水线 Args: processors: 处理器列表(按执行顺序) """ self.processors = processors self._fitted_processors: Dict[int, BaseProcessor] = {} def fit_transform( self, data: pl.DataFrame, stage: PipelineStage = PipelineStage.TRAIN ) -> pl.DataFrame: """在训练数据上fit所有处理器并transform""" result = data for i, processor in enumerate(self.processors): if processor.stage in [PipelineStage.ALL, stage]: result = processor.fit_transform(result) self._fitted_processors[i] = processor elif stage == PipelineStage.TRAIN and processor.stage == PipelineStage.TEST: processor.fit(result) self._fitted_processors[i] = processor return result def transform( self, data: pl.DataFrame, stage: PipelineStage = PipelineStage.TEST ) -> pl.DataFrame: """在测试数据上应用已fit的处理器""" result = data for i, processor in enumerate(self.processors): if processor.stage in [PipelineStage.ALL, stage]: if i in self._fitted_processors: result = self._fitted_processors[i].transform(result) else: result = processor.transform(result) return result def save_processors(self, path: str) -> None: """保存所有已fit的处理器状态""" import pickle with open(path, "wb") as f: pickle.dump(self._fitted_processors, f) def load_processors(self, path: str) -> None: """加载处理器状态""" import pickle with open(path, "rb") as f: self._fitted_processors = pickle.load(f) __all__ = ["ProcessingPipeline"]