feat(factors): 添加时间序列函数及智能路由

- 新增 8 个国泰君安 191 兼容的时间序列函数:ts_sma, ts_wma, ts_decay_linear, ts_argmax, ts_argmin, ts_count, ts_prod, ts_sumac
- max_/min_ 函数智能路由:正整数参数自动调用 ts_max/ts_min 实现滚动窗口逻辑
This commit is contained in:
2026-03-15 13:05:55 +08:00
parent 0e9ea5d533
commit c6ebab0e58
3 changed files with 467 additions and 2 deletions

View File

@@ -510,26 +510,40 @@ def abs(x: Union[Node, str]) -> FunctionNode:
def max_(x: Union[Node, str], y: Union[Node, str, int, float]) -> FunctionNode:
"""逐元素最大值。
智能分发逻辑:
- 如果 y 是正整数 (y > 0),调用 ts_max(x, y) 滚动窗口最大值
- 否则,调用逐元素 max(x, y)
注意:避免 MAX(CLOSE - DELAY(CLOSE, 1), 0) 这类场景被错误路由到 ts_max
Args:
x: 第一个因子表达式或字段名字符串
y: 第二个因子表达式、字段名字符串或数值
y: 第二个因子表达式、字段名字符串或正整数(窗口大小)
Returns:
FunctionNode: 函数调用节点
"""
if isinstance(y, int) and y > 0:
return ts_max(x, y)
return FunctionNode("max", x, _ensure_node(y))
def min_(x: Union[Node, str], y: Union[Node, str, int, float]) -> FunctionNode:
"""逐元素最小值。
智能分发逻辑:
- 如果 y 是正整数 (y > 0),调用 ts_min(x, y) 滚动窗口最小值
- 否则,调用逐元素 min(x, y)
Args:
x: 第一个因子表达式或字段名字符串
y: 第二个因子表达式、字段名字符串或数值
y: 第二个因子表达式、字段名字符串或正整数(窗口大小)
Returns:
FunctionNode: 函数调用节点
"""
if isinstance(y, int) and y > 0:
return ts_min(x, y)
return FunctionNode("min", x, _ensure_node(y))
@@ -622,3 +636,129 @@ def where(
FunctionNode: 函数调用节点
"""
return if_(condition, true_val, false_val)
# ==================== 补充的时间序列函数 ====================
def ts_sma(x: Union[Node, str], n: int, m: int) -> FunctionNode:
"""国内平滑移动平均 (Simple Moving Average)。
使用 N*M 平滑,公式: SMA = (M*CLOSE + (N-M)*REF(CLOSE,1)) / N
对应国泰君安 191 因子的 SMA 函数。
Args:
x: 输入因子表达式或字段名字符串
n: 第一个参数,对应公式中的 N
m: 第二个参数,对应公式中的 M
Returns:
FunctionNode: 函数调用节点
"""
return FunctionNode("ts_sma", x, n, m)
def ts_wma(x: Union[Node, str], window: int) -> FunctionNode:
"""线性加权移动平均 (Weighted Moving Average)。
权重按线性递增,最近一天权重最大。
Args:
x: 输入因子表达式或字段名字符串
window: 滚动窗口大小
Returns:
FunctionNode: 函数调用节点
"""
return FunctionNode("ts_wma", x, window)
def ts_decay_linear(x: Union[Node, str], window: int) -> FunctionNode:
"""线性衰减移动平均 (Decay Linear)。
与 WMA 等价,权重按线性递增(近期权重最大)。
Args:
x: 输入因子表达式或字段名字符串
window: 滚动窗口大小
Returns:
FunctionNode: 函数调用节点
"""
return FunctionNode("ts_decay_linear", x, window)
def ts_argmax(x: Union[Node, str], window: int) -> FunctionNode:
"""时间序列最大值位置 (HIGHDAY)。
返回距离过去 window 期内最大值的交易日数。
例如:今天是最高点返回 0昨天是最高点返回 1。
Args:
x: 输入因子表达式或字段名字符串
window: 滚动窗口大小
Returns:
FunctionNode: 函数调用节点
"""
return FunctionNode("ts_argmax", x, window)
def ts_argmin(x: Union[Node, str], window: int) -> FunctionNode:
"""时间序列最小值位置 (LOWDAY)。
返回距离过去 window 期内最小值的交易日数。
例如:今天是最低点返回 0昨天是最低点返回 1。
Args:
x: 输入因子表达式或字段名字符串
window: 滚动窗口大小
Returns:
FunctionNode: 函数调用节点
"""
return FunctionNode("ts_argmin", x, window)
def ts_count(condition: Union[Node, str], window: int) -> FunctionNode:
"""窗口内条件为真的天数。
统计过去 window 内满足条件的交易日数量。
Args:
condition: 条件表达式或字段名字符串(布尔型)
window: 滚动窗口大小
Returns:
FunctionNode: 函数调用节点
"""
return FunctionNode("ts_count", condition, window)
def ts_prod(x: Union[Node, str], window: int) -> FunctionNode:
"""窗口期内的连乘积。
计算过去 window 期因子值的乘积。
Args:
x: 输入因子表达式或字段名字符串
window: 滚动窗口大小
Returns:
FunctionNode: 函数调用节点
"""
return FunctionNode("ts_prod", x, window)
def ts_sumac(x: Union[Node, str]) -> FunctionNode:
"""累计求和。
从序列开始到当前位置的累计求和。
Args:
x: 输入因子表达式或字段名字符串
Returns:
FunctionNode: 函数调用节点
"""
return FunctionNode("ts_sumac", x)

View File

@@ -71,8 +71,20 @@ class PolarsTranslator:
self.register_handler("ts_atr", self._handle_ts_atr)
self.register_handler("ts_rsi", self._handle_ts_rsi)
self.register_handler("ts_obv", self._handle_ts_obv)
# 补充的时间序列因子处理器
self.register_handler("ts_sma", self._handle_ts_sma)
self.register_handler("ts_wma", self._handle_ts_wma)
self.register_handler("ts_decay_linear", self._handle_ts_decay_linear)
self.register_handler("ts_argmax", self._handle_ts_argmax)
self.register_handler("ts_argmin", self._handle_ts_argmin)
self.register_handler("ts_count", self._handle_ts_count)
self.register_handler("ts_prod", self._handle_ts_prod)
self.register_handler("ts_sumac", self._handle_ts_sumac)
self.register_handler("max", self._handle_max)
self.register_handler("min", self._handle_min)
# 截面因子处理器 (cs_*)
self.register_handler("cs_rank", self._handle_cs_rank)
self.register_handler("cs_zscore", self._handle_cs_zscore)
self.register_handler("cs_neutral", self._handle_cs_neutral)
@@ -455,6 +467,189 @@ class PolarsTranslator:
return pl.struct([close.alias("c"), volume.alias("v")]).map_batches(calc_obv)
# ==================== 补充的时间序列因子处理器 ====================
@time_series
def _handle_ts_sma(self, node: FunctionNode) -> pl.Expr:
"""处理 ts_sma(x, n, m) -> ewm_mean(alpha=m/n, adjust=False)。
国泰君安 191 的 SMA 使用 N*M 平滑,等效于 alpha=m/n 的 EWM。
"""
if len(node.args) != 3:
raise ValueError("ts_sma 需要 3 个参数: (x, n, m)")
expr = self.translate(node.args[0])
n = self._extract_window(node.args[1])
m = self._extract_window(node.args[2])
alpha = m / n if n != 0 else 0.5
return expr.ewm_mean(alpha=alpha, adjust=False)
@time_series
def _handle_ts_wma(self, node: FunctionNode) -> pl.Expr:
"""处理 ts_wma(x, window) -> 使用 numpy.convolve 实现线性加权。
使用卷积实现高性能线性加权移动平均。
"""
if len(node.args) != 2:
raise ValueError("ts_wma 需要 2 个参数: (x, window)")
expr = self.translate(node.args[0])
window = self._extract_window(node.args[1])
def wma_calc(s: pl.Series) -> pl.Series:
"""计算线性加权移动平均,使用 numpy 卷积优化。"""
values = s.to_numpy()
n = len(values)
if n == 0:
return pl.Series([float("nan")] * n)
# 线性递增权重: 1, 2, 3, ..., window
weights = np.arange(1, window + 1, dtype=float)
weights = weights / weights.sum() # 归一化
# 使用卷积计算加权平均
result = np.full(n, np.nan)
# valid 卷积结果需要至少 window 个有效数据点
conv_res = np.convolve(values, weights[::-1], mode="valid")
result[window - 1 :] = conv_res
return pl.Series(result)
return expr.map_batches(wma_calc)
@time_series
def _handle_ts_decay_linear(self, node: FunctionNode) -> pl.Expr:
"""处理 ts_decay_linear(x, window) -> 复用 ts_wma 实现。
DECAYLINEAR 与 WMA 等价,都使用近期最高权重的线性权重。
直接复用 ts_wma 的实现。
"""
# 复用 ts_wma 的实现
return self._handle_ts_wma(node)
@time_series
def _handle_ts_argmax(self, node: FunctionNode) -> pl.Expr:
"""处理 ts_argmax(x, window) -> 距离最大值的交易日数。
HIGHDAY 语义:返回距离过去 window 期内最大值的交易日数。
例如:今天是最高点返回 0昨天是最高点返回 1。
"""
if len(node.args) != 2:
raise ValueError("ts_argmax 需要 2 个参数: (x, window)")
expr = self.translate(node.args[0])
window = self._extract_window(node.args[1])
def argmax_calc(s: pl.Series) -> pl.Series:
"""计算距离最大值的交易日数。"""
values = s.to_numpy()
n = len(values)
if n == 0:
return pl.Series([float("nan")] * n)
result = np.full(n, np.nan)
for i in range(window - 1, n):
window_slice = values[i - window + 1 : i + 1]
# 距离 = 当前索引 - (i - window + 1) - argmax_idx
# 其中 argmax_idx = np.argmax(window_slice)
argmax_idx = np.argmax(window_slice)
distance = window - 1 - argmax_idx
result[i] = distance
return pl.Series(result)
return expr.map_batches(argmax_calc)
@time_series
def _handle_ts_argmin(self, node: FunctionNode) -> pl.Expr:
"""处理 ts_argmin(x, window) -> 距离最小值的交易日数。
LOWDAY 语义:返回距离过去 window 期内最小值的交易日数。
例如:今天是最低点返回 0昨天是最低点返回 1。
"""
if len(node.args) != 2:
raise ValueError("ts_argmin 需要 2 个参数: (x, window)")
expr = self.translate(node.args[0])
window = self._extract_window(node.args[1])
def argmin_calc(s: pl.Series) -> pl.Series:
"""计算距离最小值的交易日数。"""
values = s.to_numpy()
n = len(values)
if n == 0:
return pl.Series([float("nan")] * n)
result = np.full(n, np.nan)
for i in range(window - 1, n):
window_slice = values[i - window + 1 : i + 1]
argmin_idx = np.argmin(window_slice)
distance = window - 1 - argmin_idx
result[i] = distance
return pl.Series(result)
return expr.map_batches(argmin_calc)
@time_series
def _handle_ts_count(self, node: FunctionNode) -> pl.Expr:
"""处理 ts_count(condition, window) -> 布尔expr.cast(Int32).rolling_sum(window)。"""
if len(node.args) != 2:
raise ValueError("ts_count 需要 2 个参数: (condition, window)")
expr = self.translate(node.args[0])
window = self._extract_window(node.args[1])
# 布尔型转换为 Int32 后 rolling_sum
return expr.cast(pl.Int32).rolling_sum(window_size=window)
@time_series
def _handle_ts_prod(self, node: FunctionNode) -> pl.Expr:
"""处理 ts_prod(x, window) -> 滚动连乘积。
使用 map_batches 实现窗口内的累积乘积。
"""
if len(node.args) != 2:
raise ValueError("ts_prod 需要 2 个参数: (x, window)")
expr = self.translate(node.args[0])
window = self._extract_window(node.args[1])
def prod_calc(s: pl.Series) -> pl.Series:
"""计算窗口内的连乘积。"""
values = s.to_numpy()
n = len(values)
if n == 0:
return pl.Series([float("nan")] * n)
result = np.full(n, np.nan)
for i in range(window - 1, n):
window_slice = values[i - window + 1 : i + 1]
result[i] = np.prod(window_slice)
return pl.Series(result)
return expr.map_batches(prod_calc)
@time_series
def _handle_ts_sumac(self, node: FunctionNode) -> pl.Expr:
"""处理 ts_sumac(x) -> cum_sum()。"""
if len(node.args) != 1:
raise ValueError("ts_sumac 需要 1 个参数: (x)")
expr = self.translate(node.args[0])
return expr.cum_sum()
@element_wise
def _handle_max(self, node: FunctionNode) -> pl.Expr:
"""处理 max(x, y) -> 逐元素最大值。"""
if len(node.args) != 2:
raise ValueError("max 需要 2 个参数: (x, y)")
x = self.translate(node.args[0])
y = self.translate(node.args[1])
return pl.when(x >= y).then(x).otherwise(y)
@element_wise
def _handle_min(self, node: FunctionNode) -> pl.Expr:
"""处理 min(x, y) -> 逐元素最小值。"""
if len(node.args) != 2:
raise ValueError("min 需要 2 个参数: (x, y)")
x = self.translate(node.args[0])
y = self.translate(node.args[1])
return pl.when(x <= y).then(x).otherwise(y)
# ==================== 截面因子处理器 (cs_*) ====================
# 所有截面因子使用 @cross_section 装饰器自动注入 over("trade_date") 防串表

View File

@@ -0,0 +1,130 @@
"""测试新增的时间序列函数和智能分发逻辑。"""
import pytest
import polars as pl
import numpy as np
from src.factors.dsl import Symbol, FunctionNode
from src.factors.translator import PolarsTranslator
def test_ts_sma_translate():
"""测试 ts_sma 翻译正确。"""
close = Symbol("close")
expr = FunctionNode("ts_sma", close, 10, 5)
translator = PolarsTranslator()
result = translator.translate(expr)
assert isinstance(result, pl.Expr)
def test_ts_wma_translate():
"""测试 ts_wma 翻译正确。"""
close = Symbol("close")
expr = FunctionNode("ts_wma", close, 20)
translator = PolarsTranslator()
result = translator.translate(expr)
assert isinstance(result, pl.Expr)
def test_ts_sumac_translate():
"""测试 ts_sumac 翻译正确。"""
close = Symbol("close")
expr = FunctionNode("ts_sumac", close)
translator = PolarsTranslator()
result = translator.translate(expr)
assert isinstance(result, pl.Expr)
def test_max_intelligent_dispatch():
"""测试 max_ 智能分发: int -> ts_max其他 -> element-wise max。"""
from src.factors.api import max_, close
# 正整数 -> ts_max
result = max_(close, 20)
assert result.func_name == "ts_max"
# 零或负数 -> element-wise max
result = max_(close, 0)
assert result.func_name == "max"
result = max_(close, -1)
assert result.func_name == "max"
# 浮点数 -> element-wise max
result = max_(close, 10.5)
assert result.func_name == "max"
def test_min_intelligent_dispatch():
"""测试 min_ 智能分发: int -> ts_min其他 -> element-wise min。"""
from src.factors.api import min_, close
# 正整数 -> ts_min
result = min_(close, 20)
assert result.func_name == "ts_min"
# 零或负数 -> element-wise min
result = min_(close, 0)
assert result.func_name == "min"
def create_test_data() -> pl.DataFrame:
"""创建测试数据。"""
np.random.seed(42)
n = 100
return pl.DataFrame(
{
"ts_code": ["000001.SZ"] * n,
"trade_date": list(range(20240101, 20240101 + n)),
"close": np.random.randn(n).cumsum() + 100,
}
)
def test_ts_sma_computation():
"""测试 ts_sma 计算与原生 Polars 一致。"""
df = create_test_data()
translator = PolarsTranslator()
# 翻译因子
close = Symbol("close")
expr_node = FunctionNode("ts_sma", close, 10, 5)
expr = translator.translate(expr_node)
# 使用翻译后的表达式计算
result = df.select(["ts_code", "trade_date", "close", expr.alias("ts_sma_result")])
# 原生 Polars 计算
native = df.with_columns(
[pl.col("close").ewm_mean(alpha=5 / 10, adjust=False).alias("native_sma")]
)
# 对比结果
assert np.allclose(
result["ts_sma_result"].to_numpy()[9:],
native["native_sma"].to_numpy()[9:],
equal_nan=True,
)
def test_ts_sumac_computation():
"""测试 ts_sumac 计算与原生 Polars 一致。"""
df = create_test_data()
translator = PolarsTranslator()
close = Symbol("close")
expr_node = FunctionNode("ts_sumac", close)
expr = translator.translate(expr_node)
result = df.select(
["ts_code", "trade_date", "close", expr.alias("ts_sumac_result")]
)
native = df.with_columns([pl.col("close").cum_sum().alias("native_sumac")])
assert np.allclose(
result["ts_sumac_result"].to_numpy(), native["native_sumac"].to_numpy()
)
if __name__ == "__main__":
pytest.main([__file__, "-v"])