refactor: 代码审查修复 - 日期过滤、性能优化、数据泄露防护

- 修复 data_loader.py 财务数据日期过滤,支持按范围加载
- 优化 MADClipper 使用窗口函数替代 join,提升性能
- 修复训练日期边界问题,添加1天间隔避免数据泄露
- 新增 .gitignore 规则忽略训练输出目录
This commit is contained in:
2026-02-25 21:11:19 +08:00
parent 593ec99466
commit a9e4746239
24 changed files with 3597 additions and 56 deletions

View File

@@ -8,6 +8,7 @@ from src.pipeline.processors.processors import (
MinMaxScaler,
RankTransformer,
Neutralizer,
MADClipper,
)
__all__ = [
@@ -18,4 +19,5 @@ __all__ = [
"MinMaxScaler",
"RankTransformer",
"Neutralizer",
"MADClipper",
]

View File

@@ -3,9 +3,8 @@
提供常用的数据预处理和转换处理器。
"""
from typing import List, Optional, Dict, Any
from typing import List, Optional
import polars as pl
import numpy as np
from src.pipeline.core import BaseProcessor, PipelineStage
from src.pipeline.registry import PluginRegistry
@@ -227,6 +226,64 @@ class Neutralizer(BaseProcessor):
return result
@PluginRegistry.register_processor("mad_clipper")
class MADClipper(BaseProcessor):
    """Clip outliers per ``trade_date`` cross-section using the MAD rule.

    For every selected column and every distinct ``trade_date``, values are
    clipped to ``[median - n_mad * MAD, median + n_mad * MAD]``, where
    ``MAD`` is the median of ``|x - median|`` inside that cross-section.
    The MAD rule is less sensitive to extreme values than a
    standard-deviation based threshold.

    The statistics are always derived from the frame passed to
    :meth:`transform`; :meth:`fit` only decides which columns are processed.

    NOTE(review): when more than half of a cross-section shares one value the
    MAD is 0 and both bounds equal the median — confirm that collapsing such
    a cross-section is acceptable for the factors this is applied to.
    """

    stage = PipelineStage.TRAIN

    def __init__(
        self,
        columns: Optional[List[str]] = None,
        n_mad: float = 3.0,
    ):
        """Store the clipping configuration.

        Args:
            columns: Columns to clip; forwarded to ``BaseProcessor`` and later
                to ``_get_numeric_columns`` (presumably ``None`` selects all
                numeric columns — verify against that helper).
            n_mad: Half-width of the allowed band, in multiples of the MAD.
        """
        super().__init__(columns)
        self.n_mad = n_mad

    def fit(self, data: pl.DataFrame) -> "MADClipper":
        """Resolve the columns to clip and mark the processor as fitted.

        No per-date statistics are computed here: :meth:`transform` derives
        the median and MAD from whatever frame it receives, so aggregating
        them during ``fit`` would be work whose result is never read.

        Args:
            data: Frame used to resolve the target columns.

        Returns:
            ``self``, to allow call chaining.
        """
        cols = _get_numeric_columns(data, self.columns)
        self._fitted_params = {"columns": cols}
        self._is_fitted = True
        return self

    def transform(self, data: pl.DataFrame) -> pl.DataFrame:
        """Clip every fitted column inside each ``trade_date`` cross-section.

        Window expressions (``.over("trade_date")``) are used instead of an
        aggregate-then-join, and all columns are clipped in a single
        ``with_columns`` call because the expressions are independent.

        Args:
            data: Frame containing a ``trade_date`` column and the fitted
                columns. Fitted columns missing from ``data`` are skipped.

        Returns:
            A frame with the fitted columns clipped; ``data`` itself when
            there is nothing to clip.
        """
        params = self._fitted_params
        cols = params.get("columns")
        if cols is None:
            # State fitted by the earlier implementation kept the column
            # names only as the keys of a "bounds" mapping.
            cols = list(params.get("bounds", {}))

        present = set(data.columns)
        # dict.fromkeys drops duplicate names while preserving order; a
        # repeated name would produce two outputs with the same alias.
        clip_exprs = [
            self._clip_expr(col) for col in dict.fromkeys(cols) if col in present
        ]
        if not clip_exprs:
            return data
        return data.with_columns(clip_exprs)

    def _clip_expr(self, col: str) -> pl.Expr:
        """Build the per-date MAD clipping expression for one column."""
        value = pl.col(col)
        # Median of the column within each trade_date cross-section.
        median = value.median().over("trade_date")
        # Median absolute deviation within the same cross-section.
        mad = (value - median).abs().median().over("trade_date")
        lower = median - self.n_mad * mad
        upper = median + self.n_mad * mad
        return value.clip(lower, upper).alias(col)
__all__ = [
"DropNAProcessor",
"FillNAProcessor",
@@ -235,4 +292,5 @@ __all__ = [
"MinMaxScaler",
"RankTransformer",
"Neutralizer",
"MADClipper",
]