feat(factors): 新增 Phase 1-2 数学和统计因子函数
- 新增 atan, log1p 数学函数 - 新增 ts_var, ts_skew, ts_kurt, ts_pct_change, ts_ema 统计函数 - 新增 ts_atr, ts_rsi, ts_obv TA-Lib 技术指标函数 - 新增完整集成测试覆盖所有新函数
This commit is contained in:
@@ -6,8 +6,18 @@
|
||||
|
||||
from typing import Any, Callable, Dict
|
||||
|
||||
import numpy as np
|
||||
import polars as pl
|
||||
|
||||
# TA-Lib 可选依赖
|
||||
try:
|
||||
import talib
|
||||
|
||||
HAS_TALIB = True
|
||||
except ImportError:
|
||||
HAS_TALIB = False
|
||||
talib = None
|
||||
|
||||
from src.factors.decorators import cross_section, element_wise, time_series
|
||||
from src.factors.dsl import (
|
||||
BinaryOpNode,
|
||||
@@ -53,6 +63,14 @@ class PolarsTranslator:
|
||||
self.register_handler("ts_delta", self._handle_ts_delta)
|
||||
self.register_handler("ts_corr", self._handle_ts_corr)
|
||||
self.register_handler("ts_cov", self._handle_ts_cov)
|
||||
self.register_handler("ts_var", self._handle_ts_var)
|
||||
self.register_handler("ts_skew", self._handle_ts_skew)
|
||||
self.register_handler("ts_kurt", self._handle_ts_kurt)
|
||||
self.register_handler("ts_pct_change", self._handle_ts_pct_change)
|
||||
self.register_handler("ts_ema", self._handle_ts_ema)
|
||||
self.register_handler("ts_atr", self._handle_ts_atr)
|
||||
self.register_handler("ts_rsi", self._handle_ts_rsi)
|
||||
self.register_handler("ts_obv", self._handle_ts_obv)
|
||||
|
||||
# 截面因子处理器 (cs_*)
|
||||
self.register_handler("cs_rank", self._handle_cs_rank)
|
||||
@@ -66,6 +84,8 @@ class PolarsTranslator:
|
||||
self.register_handler("sign", self._handle_sign)
|
||||
self.register_handler("cos", self._handle_cos)
|
||||
self.register_handler("sin", self._handle_sin)
|
||||
self.register_handler("atan", self._handle_atan)
|
||||
self.register_handler("log1p", self._handle_log1p)
|
||||
|
||||
def register_handler(
|
||||
self, func_name: str, handler: Callable[[FunctionNode], pl.Expr]
|
||||
@@ -295,6 +315,143 @@ class PolarsTranslator:
|
||||
window = self._extract_window(node.args[2])
|
||||
return x.rolling_cov(y, window_size=window)
|
||||
|
||||
@time_series
|
||||
def _handle_ts_var(self, node: FunctionNode) -> pl.Expr:
|
||||
"""处理 ts_var(close, window) -> rolling_var(window)。"""
|
||||
if len(node.args) != 2:
|
||||
raise ValueError("ts_var 需要 2 个参数: (expr, window)")
|
||||
expr = self.translate(node.args[0])
|
||||
window = self._extract_window(node.args[1])
|
||||
return expr.rolling_var(window_size=window)
|
||||
|
||||
@time_series
|
||||
def _handle_ts_skew(self, node: FunctionNode) -> pl.Expr:
|
||||
"""处理 ts_skew(close, window) -> rolling_skew(window)。"""
|
||||
if len(node.args) != 2:
|
||||
raise ValueError("ts_skew 需要 2 个参数: (expr, window)")
|
||||
expr = self.translate(node.args[0])
|
||||
window = self._extract_window(node.args[1])
|
||||
return expr.rolling_skew(window_size=window)
|
||||
|
||||
@time_series
|
||||
def _handle_ts_kurt(self, node: FunctionNode) -> pl.Expr:
|
||||
"""处理 ts_kurt(close, window) -> rolling_kurt(window)。"""
|
||||
if len(node.args) != 2:
|
||||
raise ValueError("ts_kurt 需要 2 个参数: (expr, window)")
|
||||
expr = self.translate(node.args[0])
|
||||
window = self._extract_window(node.args[1])
|
||||
# 使用 rolling_map 计算峰度
|
||||
return expr.rolling_map(
|
||||
lambda s: s.kurtosis() if len(s.drop_nulls()) >= 4 else float("nan"),
|
||||
window_size=window,
|
||||
)
|
||||
|
||||
@time_series
|
||||
def _handle_ts_pct_change(self, node: FunctionNode) -> pl.Expr:
|
||||
"""处理 ts_pct_change(x, n) -> (x - shift(n)) / shift(n)。"""
|
||||
if len(node.args) != 2:
|
||||
raise ValueError("ts_pct_change 需要 2 个参数: (expr, periods)")
|
||||
expr = self.translate(node.args[0])
|
||||
n = self._extract_window(node.args[1])
|
||||
shifted = expr.shift(n)
|
||||
return (expr - shifted) / shifted
|
||||
|
||||
@time_series
|
||||
def _handle_ts_ema(self, node: FunctionNode) -> pl.Expr:
|
||||
"""处理 ts_ema(x, window) -> ewm_mean(span=window)。"""
|
||||
if len(node.args) != 2:
|
||||
raise ValueError("ts_ema 需要 2 个参数: (expr, window)")
|
||||
expr = self.translate(node.args[0])
|
||||
window = self._extract_window(node.args[1])
|
||||
return expr.ewm_mean(span=window)
|
||||
|
||||
@time_series
|
||||
def _handle_ts_atr(self, node: FunctionNode) -> pl.Expr:
|
||||
"""处理 ts_atr(high, low, close, window) -> 使用 TA-Lib 计算 ATR。
|
||||
|
||||
使用 map_batches 在每个分组上应用 TA-Lib ATR 函数。
|
||||
@time_series 装饰器会自动添加 .over("ts_code")
|
||||
"""
|
||||
if not HAS_TALIB:
|
||||
raise ImportError("ts_atr 需要安装 TA-Lib。请运行: pip install TA-Lib")
|
||||
if len(node.args) != 4:
|
||||
raise ValueError("ts_atr 需要 4 个参数: (high, low, close, window)")
|
||||
|
||||
high = self.translate(node.args[0])
|
||||
low = self.translate(node.args[1])
|
||||
close = self.translate(node.args[2])
|
||||
window = self._extract_window(node.args[3])
|
||||
|
||||
# 使用 map_batches 应用 TA-Lib ATR 到整个分组
|
||||
def calc_atr(struct_series: pl.Series) -> pl.Series:
|
||||
"""计算 ATR 的辅助函数。"""
|
||||
if len(struct_series) == 0:
|
||||
return pl.Series([float("nan")] * len(struct_series))
|
||||
|
||||
# struct_series 包含 h, l, c 三个字段
|
||||
h = np.array(struct_series.struct.field("h").to_list(), dtype=float)
|
||||
l = np.array(struct_series.struct.field("l").to_list(), dtype=float)
|
||||
c = np.array(struct_series.struct.field("c").to_list(), dtype=float)
|
||||
result = talib.ATR(h, l, c, timeperiod=window)
|
||||
return pl.Series(result)
|
||||
|
||||
return pl.struct(
|
||||
[high.alias("h"), low.alias("l"), close.alias("c")]
|
||||
).map_batches(calc_atr)
|
||||
|
||||
@time_series
|
||||
def _handle_ts_rsi(self, node: FunctionNode) -> pl.Expr:
|
||||
"""处理 ts_rsi(close, window) -> 使用 TA-Lib 计算 RSI。
|
||||
|
||||
使用 map_batches 在每个分组上应用 TA-Lib RSI 函数。
|
||||
@time_series 装饰器会自动添加 .over("ts_code")
|
||||
"""
|
||||
if not HAS_TALIB:
|
||||
raise ImportError("ts_rsi 需要安装 TA-Lib。请运行: pip install TA-Lib")
|
||||
if len(node.args) != 2:
|
||||
raise ValueError("ts_rsi 需要 2 个参数: (close, window)")
|
||||
|
||||
close = self.translate(node.args[0])
|
||||
window = self._extract_window(node.args[1])
|
||||
|
||||
# 使用 map_batches 应用 TA-Lib RSI 到整个分组
|
||||
def calc_rsi(series: pl.Series) -> pl.Series:
|
||||
"""计算 RSI 的辅助函数。"""
|
||||
values = np.array(series.to_list(), dtype=float)
|
||||
result = talib.RSI(values, timeperiod=window)
|
||||
return pl.Series(result)
|
||||
|
||||
return close.map_batches(calc_rsi)
|
||||
|
||||
@time_series
|
||||
def _handle_ts_obv(self, node: FunctionNode) -> pl.Expr:
|
||||
"""处理 ts_obv(close, volume) -> 使用 TA-Lib 计算 OBV。
|
||||
|
||||
使用 map_batches 在每个分组上应用 TA-Lib OBV 函数。
|
||||
@time_series 装饰器会自动添加 .over("ts_code")
|
||||
"""
|
||||
if not HAS_TALIB:
|
||||
raise ImportError("ts_obv 需要安装 TA-Lib。请运行: pip install TA-Lib")
|
||||
if len(node.args) != 2:
|
||||
raise ValueError("ts_obv 需要 2 个参数: (close, volume)")
|
||||
|
||||
close = self.translate(node.args[0])
|
||||
volume = self.translate(node.args[1])
|
||||
|
||||
# 使用 map_batches 应用 TA-Lib OBV 到整个分组
|
||||
def calc_obv(struct_series: pl.Series) -> pl.Series:
|
||||
"""计算 OBV 的辅助函数。"""
|
||||
if len(struct_series) == 0:
|
||||
return pl.Series([float("nan")] * len(struct_series))
|
||||
|
||||
# struct_series 包含 c 和 v 两个字段
|
||||
c = np.array(struct_series.struct.field("c").to_list(), dtype=float)
|
||||
v = np.array(struct_series.struct.field("v").to_list(), dtype=float)
|
||||
result = talib.OBV(c, v)
|
||||
return pl.Series(result)
|
||||
|
||||
return pl.struct([close.alias("c"), volume.alias("v")]).map_batches(calc_obv)
|
||||
|
||||
# ==================== 截面因子处理器 (cs_*) ====================
|
||||
# 所有截面因子使用 @cross_section 装饰器自动注入 over("trade_date") 防串表
|
||||
|
||||
@@ -377,6 +534,22 @@ class PolarsTranslator:
|
||||
expr = self.translate(node.args[0])
|
||||
return expr.sin()
|
||||
|
||||
@element_wise
|
||||
def _handle_atan(self, node: FunctionNode) -> pl.Expr:
|
||||
"""处理 atan(expr) -> 反正切函数。"""
|
||||
if len(node.args) != 1:
|
||||
raise ValueError("atan 需要 1 个参数: (expr)")
|
||||
expr = self.translate(node.args[0])
|
||||
return expr.arctan()
|
||||
|
||||
@element_wise
|
||||
def _handle_log1p(self, node: FunctionNode) -> pl.Expr:
|
||||
"""处理 log1p(expr) -> log(1+x) 函数。"""
|
||||
if len(node.args) != 1:
|
||||
raise ValueError("log1p 需要 1 个参数: (expr)")
|
||||
expr = self.translate(node.args[0])
|
||||
return expr.log1p()
|
||||
|
||||
# ==================== 辅助方法 ====================
|
||||
|
||||
def _extract_window(self, node: Node) -> int:
|
||||
|
||||
Reference in New Issue
Block a user