refactor: 代码审查修复 - 日期过滤、性能优化、数据泄露防护
- 修复 data_loader.py 财务数据日期过滤,支持按范围加载 - 优化 MADClipper 使用窗口函数替代 join,提升性能 - 修复训练日期边界问题,添加1天间隔避免数据泄露 - 新增 .gitignore 规则忽略训练输出目录
This commit is contained in:
@@ -8,6 +8,7 @@ from src.pipeline.processors.processors import (
|
||||
MinMaxScaler,
|
||||
RankTransformer,
|
||||
Neutralizer,
|
||||
MADClipper,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
@@ -18,4 +19,5 @@ __all__ = [
|
||||
"MinMaxScaler",
|
||||
"RankTransformer",
|
||||
"Neutralizer",
|
||||
"MADClipper",
|
||||
]
|
||||
|
||||
@@ -3,9 +3,8 @@
|
||||
提供常用的数据预处理和转换处理器。
|
||||
"""
|
||||
|
||||
from typing import List, Optional, Dict, Any
|
||||
from typing import List, Optional
|
||||
import polars as pl
|
||||
import numpy as np
|
||||
|
||||
from src.pipeline.core import BaseProcessor, PipelineStage
|
||||
from src.pipeline.registry import PluginRegistry
|
||||
@@ -227,6 +226,64 @@ class Neutralizer(BaseProcessor):
|
||||
return result
|
||||
|
||||
|
||||
@PluginRegistry.register_processor("mad_clipper")
|
||||
class MADClipper(BaseProcessor):
|
||||
"""MAD去极值处理器 - 基于每日截面的中位数绝对偏差去除极值
|
||||
|
||||
使用3倍MAD作为阈值,比标准差方法更稳健,对异常值不敏感。
|
||||
阈值范围: [median - n*MAD, median + n*MAD]
|
||||
"""
|
||||
|
||||
stage = PipelineStage.TRAIN
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
columns: Optional[List[str]] = None,
|
||||
n_mad: float = 3.0,
|
||||
):
|
||||
super().__init__(columns)
|
||||
self.n_mad = n_mad
|
||||
|
||||
def fit(self, data: pl.DataFrame) -> "MADClipper":
|
||||
cols = _get_numeric_columns(data, self.columns)
|
||||
bounds = {}
|
||||
|
||||
for col in cols:
|
||||
# 按日期分组计算每个截面的 median 和 MAD
|
||||
daily_stats = data.group_by("trade_date").agg(
|
||||
pl.col(col).median().alias("median"),
|
||||
(pl.col(col) - pl.col(col).median()).abs().median().alias("mad"),
|
||||
)
|
||||
bounds[col] = daily_stats
|
||||
|
||||
self._fitted_params = {"bounds": bounds, "columns": cols}
|
||||
self._is_fitted = True
|
||||
return self
|
||||
|
||||
def transform(self, data: pl.DataFrame) -> pl.DataFrame:
|
||||
"""使用窗口函数进行MAD去极值,避免join操作提升性能"""
|
||||
result = data
|
||||
bounds = self._fitted_params.get("bounds", {})
|
||||
|
||||
for col in bounds.keys():
|
||||
if col not in result.columns:
|
||||
continue
|
||||
|
||||
# 使用窗口函数直接计算每个截面的median和MAD,避免join
|
||||
# 1. 计算每个日期截面的median
|
||||
median = pl.col(col).median().over("trade_date")
|
||||
# 2. 计算每个日期截面的MAD
|
||||
mad = (pl.col(col) - median).abs().median().over("trade_date")
|
||||
|
||||
# 3. 计算上下界并clip
|
||||
lower = median - self.n_mad * mad
|
||||
upper = median + self.n_mad * mad
|
||||
|
||||
result = result.with_columns(pl.col(col).clip(lower, upper).alias(col))
|
||||
|
||||
return result
|
||||
|
||||
|
||||
__all__ = [
|
||||
"DropNAProcessor",
|
||||
"FillNAProcessor",
|
||||
@@ -235,4 +292,5 @@ __all__ = [
|
||||
"MinMaxScaler",
|
||||
"RankTransformer",
|
||||
"Neutralizer",
|
||||
"MADClipper",
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user