feat(factors): 新增 Phase 1-2 数学和统计因子函数
- 新增 atan, log1p 数学函数 - 新增 ts_var, ts_skew, ts_kurt, ts_pct_change, ts_ema 统计函数 - 新增 ts_atr, ts_rsi, ts_obv TA-Lib 技术指标函数 - 新增完整集成测试覆盖所有新函数
This commit is contained in:
@@ -190,6 +190,130 @@ def ts_cov(x: Union[Node, str], y: Union[Node, str], window: int) -> FunctionNod
|
|||||||
return FunctionNode("ts_cov", x, y, window)
|
return FunctionNode("ts_cov", x, y, window)
|
||||||
|
|
||||||
|
|
||||||
|
def ts_var(x: Union[Node, str], window: int) -> FunctionNode:
|
||||||
|
"""时间序列方差。
|
||||||
|
|
||||||
|
计算给定因子在滚动窗口内的方差。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: 输入因子表达式或字段名字符串
|
||||||
|
window: 滚动窗口大小
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
FunctionNode: 函数调用节点
|
||||||
|
"""
|
||||||
|
return FunctionNode("ts_var", x, window)
|
||||||
|
|
||||||
|
|
||||||
|
def ts_skew(x: Union[Node, str], window: int) -> FunctionNode:
|
||||||
|
"""时间序列偏度。
|
||||||
|
|
||||||
|
计算给定因子在滚动窗口内的偏度(三阶矩)。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: 输入因子表达式或字段名字符串
|
||||||
|
window: 滚动窗口大小
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
FunctionNode: 函数调用节点
|
||||||
|
"""
|
||||||
|
return FunctionNode("ts_skew", x, window)
|
||||||
|
|
||||||
|
|
||||||
|
def ts_kurt(x: Union[Node, str], window: int) -> FunctionNode:
|
||||||
|
"""时间序列峰度。
|
||||||
|
|
||||||
|
计算给定因子在滚动窗口内的峰度(四阶矩)。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: 输入因子表达式或字段名字符串
|
||||||
|
window: 滚动窗口大小
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
FunctionNode: 函数调用节点
|
||||||
|
"""
|
||||||
|
return FunctionNode("ts_kurt", x, window)
|
||||||
|
|
||||||
|
|
||||||
|
def ts_pct_change(x: Union[Node, str], periods: int) -> FunctionNode:
|
||||||
|
"""时间序列百分比变化。
|
||||||
|
|
||||||
|
计算给定因子与 N 个周期前的百分比变化:(x - x.shift(n)) / x.shift(n)。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: 输入因子表达式或字段名字符串
|
||||||
|
periods: 滞后期数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
FunctionNode: 函数调用节点
|
||||||
|
"""
|
||||||
|
return FunctionNode("ts_pct_change", x, periods)
|
||||||
|
|
||||||
|
|
||||||
|
def ts_ema(x: Union[Node, str], window: int) -> FunctionNode:
|
||||||
|
"""指数移动平均。
|
||||||
|
|
||||||
|
计算给定因子的指数移动平均值。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: 输入因子表达式或字段名字符串
|
||||||
|
window: 指数移动平均的 span 参数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
FunctionNode: 函数调用节点
|
||||||
|
"""
|
||||||
|
return FunctionNode("ts_ema", x, window)
|
||||||
|
|
||||||
|
|
||||||
|
def ts_atr(
|
||||||
|
high: Union[Node, str], low: Union[Node, str], close: Union[Node, str], window: int
|
||||||
|
) -> FunctionNode:
|
||||||
|
"""平均真实波幅 (Average True Range)。
|
||||||
|
|
||||||
|
计算给定窗口内的平均真实波幅,使用 TA-Lib 实现。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
high: 最高价表达式或字段名字符串
|
||||||
|
low: 最低价表达式或字段名字符串
|
||||||
|
close: 收盘价表达式或字段名字符串
|
||||||
|
window: 滚动窗口大小
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
FunctionNode: 函数调用节点
|
||||||
|
"""
|
||||||
|
return FunctionNode("ts_atr", high, low, close, window)
|
||||||
|
|
||||||
|
|
||||||
|
def ts_rsi(close: Union[Node, str], window: int) -> FunctionNode:
|
||||||
|
"""相对强弱指数 (Relative Strength Index)。
|
||||||
|
|
||||||
|
计算给定窗口内的 RSI 值,使用 TA-Lib 实现。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
close: 收盘价表达式或字段名字符串
|
||||||
|
window: 滚动窗口大小
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
FunctionNode: 函数调用节点
|
||||||
|
"""
|
||||||
|
return FunctionNode("ts_rsi", close, window)
|
||||||
|
|
||||||
|
|
||||||
|
def ts_obv(close: Union[Node, str], volume: Union[Node, str]) -> FunctionNode:
|
||||||
|
"""能量潮指标 (On Balance Volume)。
|
||||||
|
|
||||||
|
计算 OBV 值,使用 TA-Lib 实现。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
close: 收盘价表达式或字段名字符串
|
||||||
|
volume: 成交量表达式或字段名字符串
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
FunctionNode: 函数调用节点
|
||||||
|
"""
|
||||||
|
return FunctionNode("ts_obv", close, volume)
|
||||||
|
|
||||||
|
|
||||||
def ts_rank(x: Union[Node, str], window: int) -> FunctionNode:
|
def ts_rank(x: Union[Node, str], window: int) -> FunctionNode:
|
||||||
"""时间序列排名。
|
"""时间序列排名。
|
||||||
|
|
||||||
@@ -429,6 +553,34 @@ def clip(
|
|||||||
return FunctionNode("clip", x, _ensure_node(lower), _ensure_node(upper))
|
return FunctionNode("clip", x, _ensure_node(lower), _ensure_node(upper))
|
||||||
|
|
||||||
|
|
||||||
|
def atan(x: Union[Node, str]) -> FunctionNode:
|
||||||
|
"""反正切函数。
|
||||||
|
|
||||||
|
计算输入值的反正切值(弧度)。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: 输入因子表达式或字段名字符串
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
FunctionNode: 函数调用节点
|
||||||
|
"""
|
||||||
|
return FunctionNode("atan", x)
|
||||||
|
|
||||||
|
|
||||||
|
def log1p(x: Union[Node, str]) -> FunctionNode:
|
||||||
|
"""log(1+x) 函数。
|
||||||
|
|
||||||
|
计算 log(1+x),对 x 接近 0 的情况更精确。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: 输入因子表达式或字段名字符串
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
FunctionNode: 函数调用节点
|
||||||
|
"""
|
||||||
|
return FunctionNode("log1p", x)
|
||||||
|
|
||||||
|
|
||||||
# ==================== 条件函数 ====================
|
# ==================== 条件函数 ====================
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -6,8 +6,18 @@
|
|||||||
|
|
||||||
from typing import Any, Callable, Dict
|
from typing import Any, Callable, Dict
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import polars as pl
|
import polars as pl
|
||||||
|
|
||||||
|
# TA-Lib 可选依赖
|
||||||
|
try:
|
||||||
|
import talib
|
||||||
|
|
||||||
|
HAS_TALIB = True
|
||||||
|
except ImportError:
|
||||||
|
HAS_TALIB = False
|
||||||
|
talib = None
|
||||||
|
|
||||||
from src.factors.decorators import cross_section, element_wise, time_series
|
from src.factors.decorators import cross_section, element_wise, time_series
|
||||||
from src.factors.dsl import (
|
from src.factors.dsl import (
|
||||||
BinaryOpNode,
|
BinaryOpNode,
|
||||||
@@ -53,6 +63,14 @@ class PolarsTranslator:
|
|||||||
self.register_handler("ts_delta", self._handle_ts_delta)
|
self.register_handler("ts_delta", self._handle_ts_delta)
|
||||||
self.register_handler("ts_corr", self._handle_ts_corr)
|
self.register_handler("ts_corr", self._handle_ts_corr)
|
||||||
self.register_handler("ts_cov", self._handle_ts_cov)
|
self.register_handler("ts_cov", self._handle_ts_cov)
|
||||||
|
self.register_handler("ts_var", self._handle_ts_var)
|
||||||
|
self.register_handler("ts_skew", self._handle_ts_skew)
|
||||||
|
self.register_handler("ts_kurt", self._handle_ts_kurt)
|
||||||
|
self.register_handler("ts_pct_change", self._handle_ts_pct_change)
|
||||||
|
self.register_handler("ts_ema", self._handle_ts_ema)
|
||||||
|
self.register_handler("ts_atr", self._handle_ts_atr)
|
||||||
|
self.register_handler("ts_rsi", self._handle_ts_rsi)
|
||||||
|
self.register_handler("ts_obv", self._handle_ts_obv)
|
||||||
|
|
||||||
# 截面因子处理器 (cs_*)
|
# 截面因子处理器 (cs_*)
|
||||||
self.register_handler("cs_rank", self._handle_cs_rank)
|
self.register_handler("cs_rank", self._handle_cs_rank)
|
||||||
@@ -66,6 +84,8 @@ class PolarsTranslator:
|
|||||||
self.register_handler("sign", self._handle_sign)
|
self.register_handler("sign", self._handle_sign)
|
||||||
self.register_handler("cos", self._handle_cos)
|
self.register_handler("cos", self._handle_cos)
|
||||||
self.register_handler("sin", self._handle_sin)
|
self.register_handler("sin", self._handle_sin)
|
||||||
|
self.register_handler("atan", self._handle_atan)
|
||||||
|
self.register_handler("log1p", self._handle_log1p)
|
||||||
|
|
||||||
def register_handler(
|
def register_handler(
|
||||||
self, func_name: str, handler: Callable[[FunctionNode], pl.Expr]
|
self, func_name: str, handler: Callable[[FunctionNode], pl.Expr]
|
||||||
@@ -295,6 +315,143 @@ class PolarsTranslator:
|
|||||||
window = self._extract_window(node.args[2])
|
window = self._extract_window(node.args[2])
|
||||||
return x.rolling_cov(y, window_size=window)
|
return x.rolling_cov(y, window_size=window)
|
||||||
|
|
||||||
|
@time_series
|
||||||
|
def _handle_ts_var(self, node: FunctionNode) -> pl.Expr:
|
||||||
|
"""处理 ts_var(close, window) -> rolling_var(window)。"""
|
||||||
|
if len(node.args) != 2:
|
||||||
|
raise ValueError("ts_var 需要 2 个参数: (expr, window)")
|
||||||
|
expr = self.translate(node.args[0])
|
||||||
|
window = self._extract_window(node.args[1])
|
||||||
|
return expr.rolling_var(window_size=window)
|
||||||
|
|
||||||
|
@time_series
|
||||||
|
def _handle_ts_skew(self, node: FunctionNode) -> pl.Expr:
|
||||||
|
"""处理 ts_skew(close, window) -> rolling_skew(window)。"""
|
||||||
|
if len(node.args) != 2:
|
||||||
|
raise ValueError("ts_skew 需要 2 个参数: (expr, window)")
|
||||||
|
expr = self.translate(node.args[0])
|
||||||
|
window = self._extract_window(node.args[1])
|
||||||
|
return expr.rolling_skew(window_size=window)
|
||||||
|
|
||||||
|
@time_series
|
||||||
|
def _handle_ts_kurt(self, node: FunctionNode) -> pl.Expr:
|
||||||
|
"""处理 ts_kurt(close, window) -> rolling_kurt(window)。"""
|
||||||
|
if len(node.args) != 2:
|
||||||
|
raise ValueError("ts_kurt 需要 2 个参数: (expr, window)")
|
||||||
|
expr = self.translate(node.args[0])
|
||||||
|
window = self._extract_window(node.args[1])
|
||||||
|
# 使用 rolling_map 计算峰度
|
||||||
|
return expr.rolling_map(
|
||||||
|
lambda s: s.kurtosis() if len(s.drop_nulls()) >= 4 else float("nan"),
|
||||||
|
window_size=window,
|
||||||
|
)
|
||||||
|
|
||||||
|
@time_series
|
||||||
|
def _handle_ts_pct_change(self, node: FunctionNode) -> pl.Expr:
|
||||||
|
"""处理 ts_pct_change(x, n) -> (x - shift(n)) / shift(n)。"""
|
||||||
|
if len(node.args) != 2:
|
||||||
|
raise ValueError("ts_pct_change 需要 2 个参数: (expr, periods)")
|
||||||
|
expr = self.translate(node.args[0])
|
||||||
|
n = self._extract_window(node.args[1])
|
||||||
|
shifted = expr.shift(n)
|
||||||
|
return (expr - shifted) / shifted
|
||||||
|
|
||||||
|
@time_series
|
||||||
|
def _handle_ts_ema(self, node: FunctionNode) -> pl.Expr:
|
||||||
|
"""处理 ts_ema(x, window) -> ewm_mean(span=window)。"""
|
||||||
|
if len(node.args) != 2:
|
||||||
|
raise ValueError("ts_ema 需要 2 个参数: (expr, window)")
|
||||||
|
expr = self.translate(node.args[0])
|
||||||
|
window = self._extract_window(node.args[1])
|
||||||
|
return expr.ewm_mean(span=window)
|
||||||
|
|
||||||
|
@time_series
|
||||||
|
def _handle_ts_atr(self, node: FunctionNode) -> pl.Expr:
|
||||||
|
"""处理 ts_atr(high, low, close, window) -> 使用 TA-Lib 计算 ATR。
|
||||||
|
|
||||||
|
使用 map_batches 在每个分组上应用 TA-Lib ATR 函数。
|
||||||
|
@time_series 装饰器会自动添加 .over("ts_code")
|
||||||
|
"""
|
||||||
|
if not HAS_TALIB:
|
||||||
|
raise ImportError("ts_atr 需要安装 TA-Lib。请运行: pip install TA-Lib")
|
||||||
|
if len(node.args) != 4:
|
||||||
|
raise ValueError("ts_atr 需要 4 个参数: (high, low, close, window)")
|
||||||
|
|
||||||
|
high = self.translate(node.args[0])
|
||||||
|
low = self.translate(node.args[1])
|
||||||
|
close = self.translate(node.args[2])
|
||||||
|
window = self._extract_window(node.args[3])
|
||||||
|
|
||||||
|
# 使用 map_batches 应用 TA-Lib ATR 到整个分组
|
||||||
|
def calc_atr(struct_series: pl.Series) -> pl.Series:
|
||||||
|
"""计算 ATR 的辅助函数。"""
|
||||||
|
if len(struct_series) == 0:
|
||||||
|
return pl.Series([float("nan")] * len(struct_series))
|
||||||
|
|
||||||
|
# struct_series 包含 h, l, c 三个字段
|
||||||
|
h = np.array(struct_series.struct.field("h").to_list(), dtype=float)
|
||||||
|
l = np.array(struct_series.struct.field("l").to_list(), dtype=float)
|
||||||
|
c = np.array(struct_series.struct.field("c").to_list(), dtype=float)
|
||||||
|
result = talib.ATR(h, l, c, timeperiod=window)
|
||||||
|
return pl.Series(result)
|
||||||
|
|
||||||
|
return pl.struct(
|
||||||
|
[high.alias("h"), low.alias("l"), close.alias("c")]
|
||||||
|
).map_batches(calc_atr)
|
||||||
|
|
||||||
|
@time_series
|
||||||
|
def _handle_ts_rsi(self, node: FunctionNode) -> pl.Expr:
|
||||||
|
"""处理 ts_rsi(close, window) -> 使用 TA-Lib 计算 RSI。
|
||||||
|
|
||||||
|
使用 map_batches 在每个分组上应用 TA-Lib RSI 函数。
|
||||||
|
@time_series 装饰器会自动添加 .over("ts_code")
|
||||||
|
"""
|
||||||
|
if not HAS_TALIB:
|
||||||
|
raise ImportError("ts_rsi 需要安装 TA-Lib。请运行: pip install TA-Lib")
|
||||||
|
if len(node.args) != 2:
|
||||||
|
raise ValueError("ts_rsi 需要 2 个参数: (close, window)")
|
||||||
|
|
||||||
|
close = self.translate(node.args[0])
|
||||||
|
window = self._extract_window(node.args[1])
|
||||||
|
|
||||||
|
# 使用 map_batches 应用 TA-Lib RSI 到整个分组
|
||||||
|
def calc_rsi(series: pl.Series) -> pl.Series:
|
||||||
|
"""计算 RSI 的辅助函数。"""
|
||||||
|
values = np.array(series.to_list(), dtype=float)
|
||||||
|
result = talib.RSI(values, timeperiod=window)
|
||||||
|
return pl.Series(result)
|
||||||
|
|
||||||
|
return close.map_batches(calc_rsi)
|
||||||
|
|
||||||
|
@time_series
|
||||||
|
def _handle_ts_obv(self, node: FunctionNode) -> pl.Expr:
|
||||||
|
"""处理 ts_obv(close, volume) -> 使用 TA-Lib 计算 OBV。
|
||||||
|
|
||||||
|
使用 map_batches 在每个分组上应用 TA-Lib OBV 函数。
|
||||||
|
@time_series 装饰器会自动添加 .over("ts_code")
|
||||||
|
"""
|
||||||
|
if not HAS_TALIB:
|
||||||
|
raise ImportError("ts_obv 需要安装 TA-Lib。请运行: pip install TA-Lib")
|
||||||
|
if len(node.args) != 2:
|
||||||
|
raise ValueError("ts_obv 需要 2 个参数: (close, volume)")
|
||||||
|
|
||||||
|
close = self.translate(node.args[0])
|
||||||
|
volume = self.translate(node.args[1])
|
||||||
|
|
||||||
|
# 使用 map_batches 应用 TA-Lib OBV 到整个分组
|
||||||
|
def calc_obv(struct_series: pl.Series) -> pl.Series:
|
||||||
|
"""计算 OBV 的辅助函数。"""
|
||||||
|
if len(struct_series) == 0:
|
||||||
|
return pl.Series([float("nan")] * len(struct_series))
|
||||||
|
|
||||||
|
# struct_series 包含 c 和 v 两个字段
|
||||||
|
c = np.array(struct_series.struct.field("c").to_list(), dtype=float)
|
||||||
|
v = np.array(struct_series.struct.field("v").to_list(), dtype=float)
|
||||||
|
result = talib.OBV(c, v)
|
||||||
|
return pl.Series(result)
|
||||||
|
|
||||||
|
return pl.struct([close.alias("c"), volume.alias("v")]).map_batches(calc_obv)
|
||||||
|
|
||||||
# ==================== 截面因子处理器 (cs_*) ====================
|
# ==================== 截面因子处理器 (cs_*) ====================
|
||||||
# 所有截面因子使用 @cross_section 装饰器自动注入 over("trade_date") 防串表
|
# 所有截面因子使用 @cross_section 装饰器自动注入 over("trade_date") 防串表
|
||||||
|
|
||||||
@@ -377,6 +534,22 @@ class PolarsTranslator:
|
|||||||
expr = self.translate(node.args[0])
|
expr = self.translate(node.args[0])
|
||||||
return expr.sin()
|
return expr.sin()
|
||||||
|
|
||||||
|
@element_wise
|
||||||
|
def _handle_atan(self, node: FunctionNode) -> pl.Expr:
|
||||||
|
"""处理 atan(expr) -> 反正切函数。"""
|
||||||
|
if len(node.args) != 1:
|
||||||
|
raise ValueError("atan 需要 1 个参数: (expr)")
|
||||||
|
expr = self.translate(node.args[0])
|
||||||
|
return expr.arctan()
|
||||||
|
|
||||||
|
@element_wise
|
||||||
|
def _handle_log1p(self, node: FunctionNode) -> pl.Expr:
|
||||||
|
"""处理 log1p(expr) -> log(1+x) 函数。"""
|
||||||
|
if len(node.args) != 1:
|
||||||
|
raise ValueError("log1p 需要 1 个参数: (expr)")
|
||||||
|
expr = self.translate(node.args[0])
|
||||||
|
return expr.log1p()
|
||||||
|
|
||||||
# ==================== 辅助方法 ====================
|
# ==================== 辅助方法 ====================
|
||||||
|
|
||||||
def _extract_window(self, node: Node) -> int:
|
def _extract_window(self, node: Node) -> int:
|
||||||
|
|||||||
541
tests/test_phase1_2_factors.py
Normal file
541
tests/test_phase1_2_factors.py
Normal file
@@ -0,0 +1,541 @@
|
|||||||
|
"""Phase 1-2 因子函数集成测试。
|
||||||
|
|
||||||
|
测试所有新实现的函数,使用字符串因子表达式形式计算因子,
|
||||||
|
并与原始 Polars 计算结果进行对比。
|
||||||
|
|
||||||
|
测试范围:
|
||||||
|
1. 数学函数:atan, log1p
|
||||||
|
2. 统计函数:ts_var, ts_skew, ts_kurt, ts_pct_change, ts_ema
|
||||||
|
3. TA-Lib 函数:ts_atr, ts_rsi, ts_obv
|
||||||
|
"""
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import polars as pl
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from src.factors import FormulaParser, FunctionRegistry
|
||||||
|
from src.factors.translator import PolarsTranslator, HAS_TALIB
|
||||||
|
from src.factors.engine import FactorEngine
|
||||||
|
from src.data.catalog import DatabaseCatalog
|
||||||
|
|
||||||
|
|
||||||
|
# ============== 测试数据准备 ==============
|
||||||
|
|
||||||
|
|
||||||
|
def create_test_data() -> pl.DataFrame:
|
||||||
|
"""创建测试用的模拟数据。
|
||||||
|
|
||||||
|
创建一个包含多只股票、多个交易日的 DataFrame,
|
||||||
|
用于测试因子函数的计算。
|
||||||
|
"""
|
||||||
|
np.random.seed(42)
|
||||||
|
|
||||||
|
dates = pl.date_range(
|
||||||
|
start=pl.date(2024, 1, 1),
|
||||||
|
end=pl.date(2024, 1, 31),
|
||||||
|
interval="1d",
|
||||||
|
eager=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
stocks = ["000001.SZ", "000002.SZ", "600000.SH", "600001.SH"]
|
||||||
|
|
||||||
|
data = []
|
||||||
|
for stock in stocks:
|
||||||
|
base_price = 100 + np.random.randn() * 10
|
||||||
|
for i, date in enumerate(dates):
|
||||||
|
price = base_price + np.random.randn() * 5 + i * 0.1
|
||||||
|
data.append(
|
||||||
|
{
|
||||||
|
"ts_code": stock,
|
||||||
|
"trade_date": date,
|
||||||
|
"close": price,
|
||||||
|
"open": price * (1 + np.random.randn() * 0.01),
|
||||||
|
"high": price * (1 + abs(np.random.randn()) * 0.02),
|
||||||
|
"low": price * (1 - abs(np.random.randn()) * 0.02),
|
||||||
|
"vol": int(1000000 + np.random.randn() * 500000),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return pl.DataFrame(data)
|
||||||
|
|
||||||
|
|
||||||
|
# ============== 数学函数测试 ==============
|
||||||
|
|
||||||
|
|
||||||
|
def test_atan_function():
|
||||||
|
"""测试 atan 函数:计算反正切值。"""
|
||||||
|
parser = FormulaParser(FunctionRegistry())
|
||||||
|
|
||||||
|
# 创建测试数据
|
||||||
|
df = pl.DataFrame(
|
||||||
|
{
|
||||||
|
"ts_code": ["A"] * 5,
|
||||||
|
"trade_date": pl.date_range(
|
||||||
|
pl.date(2024, 1, 1), pl.date(2024, 1, 5), eager=True
|
||||||
|
),
|
||||||
|
"value": [0.0, 1.0, -1.0, 0.5, -0.5],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# DSL 计算
|
||||||
|
expr = parser.parse("atan(value)")
|
||||||
|
translator = PolarsTranslator()
|
||||||
|
polars_expr = translator.translate(expr)
|
||||||
|
result_dsl = df.with_columns(dsl_result=polars_expr).to_pandas()["dsl_result"]
|
||||||
|
|
||||||
|
# 原始 Polars 计算
|
||||||
|
result_pl = df.with_columns(pl_result=pl.col("value").arctan()).to_pandas()[
|
||||||
|
"pl_result"
|
||||||
|
]
|
||||||
|
|
||||||
|
# 对比结果
|
||||||
|
np.testing.assert_array_almost_equal(
|
||||||
|
result_dsl.values, result_pl.values, decimal=10
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_log1p_function():
|
||||||
|
"""测试 log1p 函数:计算 log(1+x)。"""
|
||||||
|
parser = FormulaParser(FunctionRegistry())
|
||||||
|
|
||||||
|
# 创建测试数据
|
||||||
|
df = pl.DataFrame(
|
||||||
|
{
|
||||||
|
"ts_code": ["A"] * 5,
|
||||||
|
"trade_date": pl.date_range(
|
||||||
|
pl.date(2024, 1, 1), pl.date(2024, 1, 5), eager=True
|
||||||
|
),
|
||||||
|
"value": [0.0, 0.1, -0.1, 1.0, -0.5],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# DSL 计算
|
||||||
|
expr = parser.parse("log1p(value)")
|
||||||
|
translator = PolarsTranslator()
|
||||||
|
polars_expr = translator.translate(expr)
|
||||||
|
result_dsl = df.with_columns(dsl_result=polars_expr).to_pandas()["dsl_result"]
|
||||||
|
|
||||||
|
# 原始 Polars 计算
|
||||||
|
result_pl = df.with_columns(pl_result=pl.col("value").log1p()).to_pandas()[
|
||||||
|
"pl_result"
|
||||||
|
]
|
||||||
|
|
||||||
|
# 对比结果
|
||||||
|
np.testing.assert_array_almost_equal(
|
||||||
|
result_dsl.values, result_pl.values, decimal=10
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ============== 统计函数测试 ==============
|
||||||
|
|
||||||
|
|
||||||
|
def test_ts_var_function():
|
||||||
|
"""测试 ts_var 函数:滚动方差。"""
|
||||||
|
parser = FormulaParser(FunctionRegistry())
|
||||||
|
|
||||||
|
# 创建测试数据
|
||||||
|
df = pl.DataFrame(
|
||||||
|
{
|
||||||
|
"ts_code": ["A"] * 10 + ["B"] * 10,
|
||||||
|
"trade_date": pl.date_range(
|
||||||
|
pl.date(2024, 1, 1), pl.date(2024, 1, 10), eager=True
|
||||||
|
).append(
|
||||||
|
pl.date_range(pl.date(2024, 1, 1), pl.date(2024, 1, 10), eager=True)
|
||||||
|
),
|
||||||
|
"close": list(range(1, 11)) + list(range(10, 20)),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# DSL 计算
|
||||||
|
expr = parser.parse("ts_var(close, 5)")
|
||||||
|
translator = PolarsTranslator()
|
||||||
|
polars_expr = translator.translate(expr)
|
||||||
|
result_dsl = (
|
||||||
|
df.with_columns(dsl_result=polars_expr)
|
||||||
|
.to_pandas()
|
||||||
|
.groupby("ts_code")["dsl_result"]
|
||||||
|
.apply(list)
|
||||||
|
)
|
||||||
|
|
||||||
|
# 原始 Polars 计算
|
||||||
|
result_pl = (
|
||||||
|
df.with_columns(
|
||||||
|
pl_result=pl.col("close").rolling_var(window_size=5).over("ts_code")
|
||||||
|
)
|
||||||
|
.to_pandas()
|
||||||
|
.groupby("ts_code")["pl_result"]
|
||||||
|
.apply(list)
|
||||||
|
)
|
||||||
|
|
||||||
|
# 对比结果
|
||||||
|
for stock in ["A", "B"]:
|
||||||
|
np.testing.assert_array_almost_equal(
|
||||||
|
result_dsl[stock], result_pl[stock], decimal=10
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_ts_skew_function():
|
||||||
|
"""测试 ts_skew 函数:滚动偏度。"""
|
||||||
|
parser = FormulaParser(FunctionRegistry())
|
||||||
|
|
||||||
|
# 创建测试数据
|
||||||
|
np.random.seed(42)
|
||||||
|
df = pl.DataFrame(
|
||||||
|
{
|
||||||
|
"ts_code": ["A"] * 20 + ["B"] * 20,
|
||||||
|
"trade_date": list(
|
||||||
|
pl.date_range(pl.date(2024, 1, 1), pl.date(2024, 1, 20), eager=True)
|
||||||
|
)
|
||||||
|
* 2,
|
||||||
|
"close": np.random.randn(40),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# DSL 计算
|
||||||
|
expr = parser.parse("ts_skew(close, 10)")
|
||||||
|
translator = PolarsTranslator()
|
||||||
|
polars_expr = translator.translate(expr)
|
||||||
|
result_dsl = df.with_columns(dsl_result=polars_expr).to_pandas()["dsl_result"]
|
||||||
|
|
||||||
|
# 原始 Polars 计算
|
||||||
|
result_pl = df.with_columns(
|
||||||
|
pl_result=pl.col("close").rolling_skew(window_size=10).over("ts_code")
|
||||||
|
).to_pandas()["pl_result"]
|
||||||
|
|
||||||
|
# 对比结果
|
||||||
|
np.testing.assert_array_almost_equal(
|
||||||
|
result_dsl.values, result_pl.values, decimal=10
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_ts_kurt_function():
|
||||||
|
"""测试 ts_kurt 函数:滚动峰度。"""
|
||||||
|
parser = FormulaParser(FunctionRegistry())
|
||||||
|
|
||||||
|
# 创建测试数据
|
||||||
|
np.random.seed(42)
|
||||||
|
df = pl.DataFrame(
|
||||||
|
{
|
||||||
|
"ts_code": ["A"] * 20 + ["B"] * 20,
|
||||||
|
"trade_date": list(
|
||||||
|
pl.date_range(pl.date(2024, 1, 1), pl.date(2024, 1, 20), eager=True)
|
||||||
|
)
|
||||||
|
* 2,
|
||||||
|
"close": np.random.randn(40),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# DSL 计算
|
||||||
|
expr = parser.parse("ts_kurt(close, 10)")
|
||||||
|
translator = PolarsTranslator()
|
||||||
|
polars_expr = translator.translate(expr)
|
||||||
|
result_dsl = df.with_columns(dsl_result=polars_expr).to_pandas()["dsl_result"]
|
||||||
|
|
||||||
|
# 原始 Polars 计算
|
||||||
|
result_pl = df.with_columns(
|
||||||
|
pl_result=pl.col("close")
|
||||||
|
.rolling_map(
|
||||||
|
lambda s: s.kurtosis() if len(s.drop_nulls()) >= 4 else float("nan"),
|
||||||
|
window_size=10,
|
||||||
|
)
|
||||||
|
.over("ts_code")
|
||||||
|
).to_pandas()["pl_result"]
|
||||||
|
|
||||||
|
# 对比结果
|
||||||
|
np.testing.assert_array_almost_equal(
|
||||||
|
result_dsl.values, result_pl.values, decimal=10
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_ts_pct_change_function():
|
||||||
|
"""测试 ts_pct_change 函数:百分比变化。"""
|
||||||
|
parser = FormulaParser(FunctionRegistry())
|
||||||
|
|
||||||
|
# 创建测试数据
|
||||||
|
df = pl.DataFrame(
|
||||||
|
{
|
||||||
|
"ts_code": ["A"] * 5 + ["B"] * 5,
|
||||||
|
"trade_date": list(
|
||||||
|
pl.date_range(pl.date(2024, 1, 1), pl.date(2024, 1, 5), eager=True)
|
||||||
|
)
|
||||||
|
* 2,
|
||||||
|
"close": [100, 105, 102, 108, 110, 50, 52, 48, 55, 60],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# DSL 计算
|
||||||
|
expr = parser.parse("ts_pct_change(close, 1)")
|
||||||
|
translator = PolarsTranslator()
|
||||||
|
polars_expr = translator.translate(expr)
|
||||||
|
result_dsl = df.with_columns(dsl_result=polars_expr).to_pandas()["dsl_result"]
|
||||||
|
|
||||||
|
# 原始 Polars 计算
|
||||||
|
result_pl = df.with_columns(
|
||||||
|
pl_result=(pl.col("close") - pl.col("close").shift(1))
|
||||||
|
/ pl.col("close").shift(1).over("ts_code")
|
||||||
|
).to_pandas()["pl_result"]
|
||||||
|
|
||||||
|
# 对比结果
|
||||||
|
np.testing.assert_array_almost_equal(
|
||||||
|
result_dsl.values, result_pl.values, decimal=10
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_ts_ema_function():
|
||||||
|
"""测试 ts_ema 函数:指数移动平均。"""
|
||||||
|
parser = FormulaParser(FunctionRegistry())
|
||||||
|
|
||||||
|
# 创建测试数据
|
||||||
|
df = pl.DataFrame(
|
||||||
|
{
|
||||||
|
"ts_code": ["A"] * 10 + ["B"] * 10,
|
||||||
|
"trade_date": list(
|
||||||
|
pl.date_range(pl.date(2024, 1, 1), pl.date(2024, 1, 10), eager=True)
|
||||||
|
)
|
||||||
|
* 2,
|
||||||
|
"close": list(range(1, 11)) + list(range(10, 20)),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# DSL 计算
|
||||||
|
expr = parser.parse("ts_ema(close, 5)")
|
||||||
|
translator = PolarsTranslator()
|
||||||
|
polars_expr = translator.translate(expr)
|
||||||
|
result_dsl = df.with_columns(dsl_result=polars_expr).to_pandas()["dsl_result"]
|
||||||
|
|
||||||
|
# 原始 Polars 计算
|
||||||
|
result_pl = df.with_columns(
|
||||||
|
pl_result=pl.col("close").ewm_mean(span=5).over("ts_code")
|
||||||
|
).to_pandas()["pl_result"]
|
||||||
|
|
||||||
|
# 对比结果
|
||||||
|
np.testing.assert_array_almost_equal(
|
||||||
|
result_dsl.values, result_pl.values, decimal=10
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ============== TA-Lib 函数测试 ==============
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(not HAS_TALIB, reason="TA-Lib not installed")
|
||||||
|
def test_ts_atr_function():
|
||||||
|
"""测试 ts_atr 函数:平均真实波幅。"""
|
||||||
|
import talib
|
||||||
|
|
||||||
|
parser = FormulaParser(FunctionRegistry())
|
||||||
|
|
||||||
|
# 创建测试数据
|
||||||
|
np.random.seed(42)
|
||||||
|
df = pl.DataFrame(
|
||||||
|
{
|
||||||
|
"ts_code": ["A"] * 20 + ["B"] * 20,
|
||||||
|
"trade_date": list(
|
||||||
|
pl.date_range(pl.date(2024, 1, 1), pl.date(2024, 1, 20), eager=True)
|
||||||
|
)
|
||||||
|
* 2,
|
||||||
|
"high": 100 + np.random.randn(40) * 2,
|
||||||
|
"low": 98 + np.random.randn(40) * 2,
|
||||||
|
"close": 99 + np.random.randn(40) * 2,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# DSL 计算
|
||||||
|
expr = parser.parse("ts_atr(high, low, close, 14)")
|
||||||
|
translator = PolarsTranslator()
|
||||||
|
polars_expr = translator.translate(expr)
|
||||||
|
result_dsl = df.with_columns(dsl_result=polars_expr).to_pandas()["dsl_result"]
|
||||||
|
|
||||||
|
# 使用 talib 手动计算(分组计算)
|
||||||
|
result_expected = []
|
||||||
|
for stock in ["A", "B"]:
|
||||||
|
stock_df = df.filter(pl.col("ts_code") == stock).to_pandas()
|
||||||
|
atr = talib.ATR(
|
||||||
|
stock_df["high"].values,
|
||||||
|
stock_df["low"].values,
|
||||||
|
stock_df["close"].values,
|
||||||
|
timeperiod=14,
|
||||||
|
)
|
||||||
|
result_expected.extend(atr)
|
||||||
|
|
||||||
|
# 对比结果(允许小误差)
|
||||||
|
np.testing.assert_array_almost_equal(
|
||||||
|
result_dsl.values, np.array(result_expected), decimal=5
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(not HAS_TALIB, reason="TA-Lib not installed")
|
||||||
|
def test_ts_rsi_function():
|
||||||
|
"""测试 ts_rsi 函数:相对强弱指数。"""
|
||||||
|
import talib
|
||||||
|
|
||||||
|
parser = FormulaParser(FunctionRegistry())
|
||||||
|
|
||||||
|
# 创建测试数据
|
||||||
|
np.random.seed(42)
|
||||||
|
df = pl.DataFrame(
|
||||||
|
{
|
||||||
|
"ts_code": ["A"] * 30 + ["B"] * 30,
|
||||||
|
"trade_date": list(
|
||||||
|
pl.date_range(pl.date(2024, 1, 1), pl.date(2024, 1, 30), eager=True)
|
||||||
|
)
|
||||||
|
* 2,
|
||||||
|
"close": 100 + np.cumsum(np.random.randn(60)),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# DSL 计算
|
||||||
|
expr = parser.parse("ts_rsi(close, 14)")
|
||||||
|
translator = PolarsTranslator()
|
||||||
|
polars_expr = translator.translate(expr)
|
||||||
|
result_dsl = df.with_columns(dsl_result=polars_expr).to_pandas()["dsl_result"]
|
||||||
|
|
||||||
|
# 使用 talib 手动计算(分组计算)
|
||||||
|
result_expected = []
|
||||||
|
for stock in ["A", "B"]:
|
||||||
|
stock_df = df.filter(pl.col("ts_code") == stock).to_pandas()
|
||||||
|
rsi = talib.RSI(stock_df["close"].values, timeperiod=14)
|
||||||
|
result_expected.extend(rsi)
|
||||||
|
|
||||||
|
# 对比结果(允许小误差)
|
||||||
|
np.testing.assert_array_almost_equal(
|
||||||
|
result_dsl.values, np.array(result_expected), decimal=5
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(not HAS_TALIB, reason="TA-Lib not installed")
|
||||||
|
def test_ts_obv_function():
|
||||||
|
"""测试 ts_obv 函数:能量潮指标。"""
|
||||||
|
import talib
|
||||||
|
|
||||||
|
parser = FormulaParser(FunctionRegistry())
|
||||||
|
|
||||||
|
# 创建测试数据
|
||||||
|
np.random.seed(42)
|
||||||
|
df = pl.DataFrame(
|
||||||
|
{
|
||||||
|
"ts_code": ["A"] * 20 + ["B"] * 20,
|
||||||
|
"trade_date": list(
|
||||||
|
pl.date_range(pl.date(2024, 1, 1), pl.date(2024, 1, 20), eager=True)
|
||||||
|
)
|
||||||
|
* 2,
|
||||||
|
"close": 100 + np.cumsum(np.random.randn(40)),
|
||||||
|
"vol": np.random.randint(100000, 1000000, 40).astype(float),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# DSL 计算
|
||||||
|
expr = parser.parse("ts_obv(close, vol)")
|
||||||
|
translator = PolarsTranslator()
|
||||||
|
polars_expr = translator.translate(expr)
|
||||||
|
result_dsl = df.with_columns(dsl_result=polars_expr).to_pandas()["dsl_result"]
|
||||||
|
|
||||||
|
# 使用 talib 手动计算(分组计算)
|
||||||
|
result_expected = []
|
||||||
|
for stock in ["A", "B"]:
|
||||||
|
stock_df = df.filter(pl.col("ts_code") == stock).to_pandas()
|
||||||
|
obv = talib.OBV(
|
||||||
|
stock_df["close"].values,
|
||||||
|
stock_df["vol"].values,
|
||||||
|
)
|
||||||
|
result_expected.extend(obv)
|
||||||
|
|
||||||
|
# 对比结果(允许小误差)
|
||||||
|
np.testing.assert_array_almost_equal(
|
||||||
|
result_dsl.values, np.array(result_expected), decimal=5
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ============== 综合测试 ==============
|
||||||
|
|
||||||
|
|
||||||
|
def test_complex_factor_expressions():
|
||||||
|
"""测试复杂因子表达式的计算。"""
|
||||||
|
parser = FormulaParser(FunctionRegistry())
|
||||||
|
|
||||||
|
# 创建测试数据
|
||||||
|
np.random.seed(42)
|
||||||
|
df = pl.DataFrame(
|
||||||
|
{
|
||||||
|
"ts_code": ["A"] * 30 + ["B"] * 30,
|
||||||
|
"trade_date": list(
|
||||||
|
pl.date_range(pl.date(2024, 1, 1), pl.date(2024, 1, 30), eager=True)
|
||||||
|
)
|
||||||
|
* 2,
|
||||||
|
"close": 100 + np.cumsum(np.random.randn(60)),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# 测试 act_factor1: atan((ts_ema(close,5)/ts_delay(ts_ema(close,5),1)-1)*100) * 57.3 / 50
|
||||||
|
expr = parser.parse(
|
||||||
|
"atan((ts_ema(close, 5) / ts_delay(ts_ema(close, 5), 1) - 1) * 100) * 57.3 / 50"
|
||||||
|
)
|
||||||
|
translator = PolarsTranslator()
|
||||||
|
polars_expr = translator.translate(expr)
|
||||||
|
result = df.with_columns(factor=polars_expr)
|
||||||
|
|
||||||
|
# 验证结果不为空
|
||||||
|
assert len(result) == 60
|
||||||
|
assert "factor" in result.columns
|
||||||
|
|
||||||
|
print("复杂因子表达式测试通过")
|
||||||
|
|
||||||
|
|
||||||
|
# ============== 主函数 ==============
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
print("运行 Phase 1-2 因子函数测试...")
|
||||||
|
print("=" * 80)
|
||||||
|
|
||||||
|
# 运行数学函数测试
|
||||||
|
print("\n[数学函数测试]")
|
||||||
|
test_atan_function()
|
||||||
|
print(" ✅ atan 测试通过")
|
||||||
|
|
||||||
|
test_log1p_function()
|
||||||
|
print(" ✅ log1p 测试通过")
|
||||||
|
|
||||||
|
# 运行统计函数测试
|
||||||
|
print("\n[统计函数测试]")
|
||||||
|
test_ts_var_function()
|
||||||
|
print(" ✅ ts_var 测试通过")
|
||||||
|
|
||||||
|
test_ts_skew_function()
|
||||||
|
print(" ✅ ts_skew 测试通过")
|
||||||
|
|
||||||
|
test_ts_kurt_function()
|
||||||
|
print(" ✅ ts_kurt 测试通过")
|
||||||
|
|
||||||
|
test_ts_pct_change_function()
|
||||||
|
print(" ✅ ts_pct_change 测试通过")
|
||||||
|
|
||||||
|
test_ts_ema_function()
|
||||||
|
print(" ✅ ts_ema 测试通过")
|
||||||
|
|
||||||
|
# 运行 TA-Lib 函数测试
|
||||||
|
print("\n[TA-Lib 函数测试]")
|
||||||
|
try:
|
||||||
|
import talib
|
||||||
|
|
||||||
|
HAS_TALIB = True
|
||||||
|
except ImportError:
|
||||||
|
HAS_TALIB = False
|
||||||
|
print(" ⚠️ TA-Lib 未安装,跳过 TA-Lib 测试")
|
||||||
|
|
||||||
|
if HAS_TALIB:
|
||||||
|
test_ts_atr_function()
|
||||||
|
print(" ✅ ts_atr 测试通过")
|
||||||
|
|
||||||
|
test_ts_rsi_function()
|
||||||
|
print(" ✅ ts_rsi 测试通过")
|
||||||
|
|
||||||
|
test_ts_obv_function()
|
||||||
|
print(" ✅ ts_obv 测试通过")
|
||||||
|
|
||||||
|
# 运行综合测试
|
||||||
|
print("\n[综合测试]")
|
||||||
|
test_complex_factor_expressions()
|
||||||
|
print(" ✅ 复杂因子表达式测试通过")
|
||||||
|
|
||||||
|
print("\n" + "=" * 80)
|
||||||
|
print("所有测试通过!")
|
||||||
Reference in New Issue
Block a user