Files
ProStock/src/pipeline/pipeline.py

71 lines
2.3 KiB
Python
Raw Normal View History

"""数据处理流水线
管理多个处理器的顺序执行支持阶段感知处理
"""
from typing import List, Dict
import polars as pl
from src.pipeline.core import BaseProcessor, PipelineStage
class ProcessingPipeline:
    """Data processing pipeline.

    Runs a sequence of processors in order and handles stage tagging
    automatically. Key property: during the test stage the pipeline reuses
    parameters learned during the training stage, which prevents data leakage.
    """

    def __init__(self, processors: List[BaseProcessor]):
        """Create a pipeline.

        Args:
            processors: Processors to run, in execution order.
        """
        # Maps a position in ``processors`` to the processor instance that was
        # fitted at that position (or restored via ``load_processors``).
        self._fitted_processors: Dict[int, BaseProcessor] = {}
        self.processors = processors

    def fit_transform(
        self, data: pl.DataFrame, stage: PipelineStage = PipelineStage.TRAIN
    ) -> pl.DataFrame:
        """Fit the processors on ``data`` and return the transformed frame.

        Every processor whose ``stage`` is ``PipelineStage.ALL`` or equals
        ``stage`` is fitted and applied in turn. Each one receives the
        previous one's output.

        When ``stage`` is ``PipelineStage.TRAIN``, processors tagged
        ``PipelineStage.TEST`` are also fitted on the current intermediate
        frame, but their output does not replace it. This makes their learned
        parameters available to ``transform`` later. Processors matching
        neither rule are skipped.

        Every processor that gets fitted is recorded in ``_fitted_processors``
        under its position in ``self.processors``.

        Args:
            data: Input frame (training data).
            stage: Stage being processed; defaults to ``PipelineStage.TRAIN``.

        Returns:
            The frame produced by the last applied processor, or ``data``
            itself if no processor was applied.
        """
        frame = data
        fit_test_only = stage == PipelineStage.TRAIN
        for idx, proc in enumerate(self.processors):
            if proc.stage in (PipelineStage.ALL, stage):
                frame = proc.fit_transform(frame)
            elif fit_test_only and proc.stage == PipelineStage.TEST:
                # Learn parameters from the training data only; the frame that
                # flows to later processors is not replaced here.
                proc.fit(frame)
            else:
                continue
            self._fitted_processors[idx] = proc
        return frame

    def transform(
        self, data: pl.DataFrame, stage: PipelineStage = PipelineStage.TEST
    ) -> pl.DataFrame:
        """Apply processors to ``data`` without fitting anything.

        Only processors whose ``stage`` is ``PipelineStage.ALL`` or equals
        ``stage`` are applied, in order. For each one, the instance stored in
        ``_fitted_processors`` at the same position is used when present
        (e.g. one restored by ``load_processors``). Otherwise the processor
        from ``self.processors`` is used as-is, even if it was never fitted.

        Args:
            data: Input frame (typically test data).
            stage: Stage being processed; defaults to ``PipelineStage.TEST``.

        Returns:
            The transformed frame, or ``data`` itself if no processor applied.
        """
        frame = data
        for idx, proc in enumerate(self.processors):
            if proc.stage not in (PipelineStage.ALL, stage):
                continue
            # Stage filtering is decided by self.processors; execution prefers
            # the fitted/restored instance stored for the same slot.
            active = (
                self._fitted_processors[idx]
                if idx in self._fitted_processors
                else proc
            )
            frame = active.transform(frame)
        return frame

    def save_processors(self, path: str) -> None:
        """Pickle ``_fitted_processors`` (position -> fitted processor) to ``path``.

        Only the fitted processors are written, not ``self.processors``.
        An existing file at ``path`` is overwritten.
        """
        import pickle

        with open(path, "wb") as sink:
            pickle.dump(self._fitted_processors, sink)

    def load_processors(self, path: str) -> None:
        """Replace ``_fitted_processors`` with the object unpickled from ``path``.

        NOTE(review): nothing checks that the loaded keys line up with the
        positions in ``self.processors``.

        Security: ``pickle.load`` can execute arbitrary code while loading.
        Only load files produced by ``save_processors`` from a trusted source.
        """
        import pickle

        with open(path, "rb") as source:
            self._fitted_processors = pickle.load(source)
# Public API of this module (what `from <module> import *` exports). Kept as a
# literal list so static analyzers can read it.
__all__ = ["ProcessingPipeline"]