diff --git a/src/factors/api.py b/src/factors/api.py index 229754e..99bb229 100644 --- a/src/factors/api.py +++ b/src/factors/api.py @@ -510,26 +510,40 @@ def abs(x: Union[Node, str]) -> FunctionNode: def max_(x: Union[Node, str], y: Union[Node, str, int, float]) -> FunctionNode: """逐元素最大值。 + 智能分发逻辑: + - 如果 y 是正整数 (y > 0),调用 ts_max(x, y) 滚动窗口最大值 + - 否则,调用逐元素 max(x, y) + + 注意:避免 MAX(CLOSE - DELAY(CLOSE, 1), 0) 这类场景被错误路由到 ts_max + Args: x: 第一个因子表达式或字段名字符串 - y: 第二个因子表达式、字段名字符串或数值 + y: 第二个因子表达式、字段名字符串或正整数(窗口大小) Returns: FunctionNode: 函数调用节点 """ + if isinstance(y, int) and y > 0: + return ts_max(x, y) return FunctionNode("max", x, _ensure_node(y)) def min_(x: Union[Node, str], y: Union[Node, str, int, float]) -> FunctionNode: """逐元素最小值。 + 智能分发逻辑: + - 如果 y 是正整数 (y > 0),调用 ts_min(x, y) 滚动窗口最小值 + - 否则,调用逐元素 min(x, y) + Args: x: 第一个因子表达式或字段名字符串 - y: 第二个因子表达式、字段名字符串或数值 + y: 第二个因子表达式、字段名字符串或正整数(窗口大小) Returns: FunctionNode: 函数调用节点 """ + if isinstance(y, int) and y > 0: + return ts_min(x, y) return FunctionNode("min", x, _ensure_node(y)) @@ -622,3 +636,129 @@ def where( FunctionNode: 函数调用节点 """ return if_(condition, true_val, false_val) + + +# ==================== 补充的时间序列函数 ==================== + + +def ts_sma(x: Union[Node, str], n: int, m: int) -> FunctionNode: + """国内平滑移动平均 (Simple Moving Average)。 + + 使用 N*M 平滑,公式: SMA = (M*CLOSE + (N-M)*REF(CLOSE,1)) / N + 对应国泰君安 191 因子的 SMA 函数。 + + Args: + x: 输入因子表达式或字段名字符串 + n: 第一个参数,对应公式中的 N + m: 第二个参数,对应公式中的 M + + Returns: + FunctionNode: 函数调用节点 + """ + return FunctionNode("ts_sma", x, n, m) + + +def ts_wma(x: Union[Node, str], window: int) -> FunctionNode: + """线性加权移动平均 (Weighted Moving Average)。 + + 权重按线性递增,最近一天权重最大。 + + Args: + x: 输入因子表达式或字段名字符串 + window: 滚动窗口大小 + + Returns: + FunctionNode: 函数调用节点 + """ + return FunctionNode("ts_wma", x, window) + + +def ts_decay_linear(x: Union[Node, str], window: int) -> FunctionNode: + """线性衰减移动平均 (Decay Linear)。 + + 与 WMA 等价,权重按线性递增(近期权重最大)。 + + Args: + x: 输入因子表达式或字段名字符串 + window: 滚动窗口大小 + + Returns: + FunctionNode: 函数调用节点 + """ + return FunctionNode("ts_decay_linear", x, window) + + +def ts_argmax(x: Union[Node, str], window: int) -> FunctionNode: + """时间序列最大值位置 (HIGHDAY)。 + + 返回距离过去 window 期内最大值的交易日数。 + 例如:今天是最高点返回 0,昨天是最高点返回 1。 + + Args: + x: 输入因子表达式或字段名字符串 + window: 滚动窗口大小 + + Returns: + FunctionNode: 函数调用节点 + """ + return FunctionNode("ts_argmax", x, window) + + +def ts_argmin(x: Union[Node, str], window: int) -> FunctionNode: + """时间序列最小值位置 (LOWDAY)。 + + 返回距离过去 window 期内最小值的交易日数。 + 例如:今天是最低点返回 0,昨天是最低点返回 1。 + + Args: + x: 输入因子表达式或字段名字符串 + window: 滚动窗口大小 + + Returns: + FunctionNode: 函数调用节点 + """ + return FunctionNode("ts_argmin", x, window) + + +def ts_count(condition: Union[Node, str], window: int) -> FunctionNode: + """窗口内条件为真的天数。 + + 统计过去 window 内满足条件的交易日数量。 + + Args: + condition: 条件表达式或字段名字符串(布尔型) + window: 滚动窗口大小 + + Returns: + FunctionNode: 函数调用节点 + """ + return FunctionNode("ts_count", condition, window) + + +def ts_prod(x: Union[Node, str], window: int) -> FunctionNode: + """窗口期内的连乘积。 + + 计算过去 window 期因子值的乘积。 + + Args: + x: 输入因子表达式或字段名字符串 + window: 滚动窗口大小 + + Returns: + FunctionNode: 函数调用节点 + """ + return FunctionNode("ts_prod", x, window) + + +def ts_sumac(x: Union[Node, str]) -> FunctionNode: + """累计求和。 + + 从序列开始到当前位置的累计求和。 + + Args: + x: 输入因子表达式或字段名字符串 + + Returns: + FunctionNode: 函数调用节点 + """ + return FunctionNode("ts_sumac", x) diff --git a/src/factors/translator.py b/src/factors/translator.py index 64686f8..1270f5b 100644 --- a/src/factors/translator.py +++ b/src/factors/translator.py @@ -71,8 +71,20 @@ class PolarsTranslator: self.register_handler("ts_atr", self._handle_ts_atr) self.register_handler("ts_rsi", self._handle_ts_rsi) self.register_handler("ts_obv", self._handle_ts_obv) + # 补充的时间序列因子处理器 + self.register_handler("ts_sma", self._handle_ts_sma) + self.register_handler("ts_wma", self._handle_ts_wma) + self.register_handler("ts_decay_linear", self._handle_ts_decay_linear) + self.register_handler("ts_argmax", self._handle_ts_argmax) + self.register_handler("ts_argmin", self._handle_ts_argmin) + self.register_handler("ts_count", self._handle_ts_count) + self.register_handler("ts_prod", self._handle_ts_prod) + self.register_handler("ts_sumac", self._handle_ts_sumac) + self.register_handler("max", self._handle_max) + self.register_handler("min", self._handle_min) # 截面因子处理器 (cs_*) + self.register_handler("cs_rank", self._handle_cs_rank) self.register_handler("cs_zscore", self._handle_cs_zscore) self.register_handler("cs_neutral", self._handle_cs_neutral) @@ -455,6 +467,189 @@ class PolarsTranslator: return pl.struct([close.alias("c"), volume.alias("v")]).map_batches(calc_obv) + # ==================== 补充的时间序列因子处理器 ==================== + + @time_series + def _handle_ts_sma(self, node: FunctionNode) -> pl.Expr: + """处理 ts_sma(x, n, m) -> ewm_mean(alpha=m/n, adjust=False)。 + + 国泰君安 191 的 SMA 使用 N*M 平滑,等效于 alpha=m/n 的 EWM。 + """ + if len(node.args) != 3: + raise ValueError("ts_sma 需要 3 个参数: (x, n, m)") + expr = self.translate(node.args[0]) + n = self._extract_window(node.args[1]) + m = self._extract_window(node.args[2]) + alpha = m / n if n != 0 else 0.5 + return expr.ewm_mean(alpha=alpha, adjust=False) + + @time_series + def _handle_ts_wma(self, node: FunctionNode) -> pl.Expr: + """处理 ts_wma(x, window) -> 使用 numpy.convolve 实现线性加权。 + + 使用卷积实现高性能线性加权移动平均。 + """ + if len(node.args) != 2: + raise ValueError("ts_wma 需要 2 个参数: (x, window)") + expr = self.translate(node.args[0]) + window = self._extract_window(node.args[1]) + + def wma_calc(s: pl.Series) -> pl.Series: + """计算线性加权移动平均,使用 numpy 卷积优化。""" + values = s.to_numpy() + n = len(values) + if n == 0: + return pl.Series([float("nan")] * n) + + # 线性递增权重: 1, 2, 3, ..., window + weights = np.arange(1, window + 1, dtype=float) + weights = weights / weights.sum() # 归一化 + + # 使用卷积计算加权平均 + result = np.full(n, np.nan) + # valid 卷积结果需要至少 window 个有效数据点 + conv_res = np.convolve(values, weights[::-1], mode="valid") + result[window - 1 :] = conv_res + + return pl.Series(result) + + return expr.map_batches(wma_calc) + + @time_series + def _handle_ts_decay_linear(self, node: FunctionNode) -> pl.Expr: + """处理 ts_decay_linear(x, window) -> 复用 ts_wma 实现。 + + DECAYLINEAR 与 WMA 等价,都使用近期最高权重的线性权重。 + 直接复用 ts_wma 的实现。 + """ + # 复用 ts_wma 的实现 + return self._handle_ts_wma(node) + + @time_series + def _handle_ts_argmax(self, node: FunctionNode) -> pl.Expr: + """处理 ts_argmax(x, window) -> 距离最大值的交易日数。 + + HIGHDAY 语义:返回距离过去 window 期内最大值的交易日数。 + 例如:今天是最高点返回 0,昨天是最高点返回 1。 + """ + if len(node.args) != 2: + raise ValueError("ts_argmax 需要 2 个参数: (x, window)") + expr = self.translate(node.args[0]) + window = self._extract_window(node.args[1]) + + def argmax_calc(s: pl.Series) -> pl.Series: + """计算距离最大值的交易日数。""" + values = s.to_numpy() + n = len(values) + if n == 0: + return pl.Series([float("nan")] * n) + + result = np.full(n, np.nan) + for i in range(window - 1, n): + window_slice = values[i - window + 1 : i + 1] + # 距离 = 当前索引 - (i - window + 1) - argmax_idx + # 其中 argmax_idx = np.argmax(window_slice) + argmax_idx = np.argmax(window_slice) + distance = window - 1 - argmax_idx + result[i] = distance + + return pl.Series(result) + + return expr.map_batches(argmax_calc) + + @time_series + def _handle_ts_argmin(self, node: FunctionNode) -> pl.Expr: + """处理 ts_argmin(x, window) -> 距离最小值的交易日数。 + + LOWDAY 语义:返回距离过去 window 期内最小值的交易日数。 + 例如:今天是最低点返回 0,昨天是最低点返回 1。 + """ + if len(node.args) != 2: + raise ValueError("ts_argmin 需要 2 个参数: (x, window)") + expr = self.translate(node.args[0]) + window = self._extract_window(node.args[1]) + + def argmin_calc(s: pl.Series) -> pl.Series: + """计算距离最小值的交易日数。""" + values = s.to_numpy() + n = len(values) + if n == 0: + return pl.Series([float("nan")] * n) + + result = np.full(n, np.nan) + for i in range(window - 1, n): + window_slice = values[i - window + 1 : i + 1] + argmin_idx = np.argmin(window_slice) + distance = window - 1 - argmin_idx + result[i] = distance + + return pl.Series(result) + + return expr.map_batches(argmin_calc) + + @time_series + def _handle_ts_count(self, node: FunctionNode) -> pl.Expr: + """处理 ts_count(condition, window) -> 布尔expr.cast(Int32).rolling_sum(window)。""" + if len(node.args) != 2: + raise ValueError("ts_count 需要 2 个参数: (condition, window)") + expr = self.translate(node.args[0]) + window = self._extract_window(node.args[1]) + # 布尔型转换为 Int32 后 rolling_sum + return expr.cast(pl.Int32).rolling_sum(window_size=window) + + @time_series + def _handle_ts_prod(self, node: FunctionNode) -> pl.Expr: + """处理 ts_prod(x, window) -> 滚动连乘积。 + + 使用 map_batches 实现窗口内的累积乘积。 + """ + if len(node.args) != 2: + raise ValueError("ts_prod 需要 2 个参数: (x, window)") + expr = self.translate(node.args[0]) + window = self._extract_window(node.args[1]) + + def prod_calc(s: pl.Series) -> pl.Series: + """计算窗口内的连乘积。""" + values = s.to_numpy() + n = len(values) + if n == 0: + return pl.Series([float("nan")] * n) + + result = np.full(n, np.nan) + for i in range(window - 1, n): + window_slice = values[i - window + 1 : i + 1] + result[i] = np.prod(window_slice) + + return pl.Series(result) + + return expr.map_batches(prod_calc) + + @time_series + def _handle_ts_sumac(self, node: FunctionNode) -> pl.Expr: + """处理 ts_sumac(x) -> cum_sum()。""" + if len(node.args) != 1: + raise ValueError("ts_sumac 需要 1 个参数: (x)") + expr = self.translate(node.args[0]) + return expr.cum_sum() + + @element_wise + def _handle_max(self, node: FunctionNode) -> pl.Expr: + """处理 max(x, y) -> 逐元素最大值。""" + if len(node.args) != 2: + raise ValueError("max 需要 2 个参数: (x, y)") + x = self.translate(node.args[0]) + y = self.translate(node.args[1]) + return pl.when(x >= y).then(x).otherwise(y) + + @element_wise + def _handle_min(self, node: FunctionNode) -> pl.Expr: + """处理 min(x, y) -> 逐元素最小值。""" + if len(node.args) != 2: + raise ValueError("min 需要 2 个参数: (x, y)") + x = self.translate(node.args[0]) + y = self.translate(node.args[1]) + return pl.when(x <= y).then(x).otherwise(y) + # ==================== 截面因子处理器 (cs_*) ==================== # 所有截面因子使用 @cross_section 装饰器自动注入 over("trade_date") 防串表 diff --git a/tests/test_new_ts_functions.py b/tests/test_new_ts_functions.py new file mode 100644 index 0000000..2c9429b --- /dev/null +++ b/tests/test_new_ts_functions.py @@ -0,0 +1,130 @@ +"""测试新增的时间序列函数和智能分发逻辑。""" + +import pytest +import polars as pl +import numpy as np +from src.factors.dsl import Symbol, FunctionNode +from src.factors.translator import PolarsTranslator + + +def test_ts_sma_translate(): + """测试 ts_sma 翻译正确。""" + close = Symbol("close") + expr = FunctionNode("ts_sma", close, 10, 5) + translator = PolarsTranslator() + result = translator.translate(expr) + assert isinstance(result, pl.Expr) + + +def test_ts_wma_translate(): + """测试 ts_wma 翻译正确。""" + close = Symbol("close") + expr = FunctionNode("ts_wma", close, 20) + translator = PolarsTranslator() + result = translator.translate(expr) + assert isinstance(result, pl.Expr) + + +def test_ts_sumac_translate(): + """测试 ts_sumac 翻译正确。""" + close = Symbol("close") + expr = FunctionNode("ts_sumac", close) + translator = PolarsTranslator() + result = translator.translate(expr) + assert isinstance(result, pl.Expr) + + +def test_max_intelligent_dispatch(): + """测试 max_ 智能分发: int -> ts_max,其他 -> element-wise max。""" + from src.factors.api import max_, close + + # 正整数 -> ts_max + result = max_(close, 20) + assert result.func_name == "ts_max" + + # 零或负数 -> element-wise max + result = max_(close, 0) + assert result.func_name == "max" + + result = max_(close, -1) + assert result.func_name == "max" + + # 浮点数 -> element-wise max + result = max_(close, 10.5) + assert result.func_name == "max" + + +def test_min_intelligent_dispatch(): + """测试 min_ 智能分发: int -> ts_min,其他 -> element-wise min。""" + from src.factors.api import min_, close + + # 正整数 -> ts_min + result = min_(close, 20) + assert result.func_name == "ts_min" + + # 零或负数 -> element-wise min + result = min_(close, 0) + assert result.func_name == "min" + + +def create_test_data() -> pl.DataFrame: + """创建测试数据。""" + np.random.seed(42) + n = 100 + return pl.DataFrame( + { + "ts_code": ["000001.SZ"] * n, + "trade_date": list(range(20240101, 20240101 + n)), + "close": np.random.randn(n).cumsum() + 100, + } + ) + + +def test_ts_sma_computation(): + """测试 ts_sma 计算与原生 Polars 一致。""" + df = create_test_data() + translator = PolarsTranslator() + + # 翻译因子 + close = Symbol("close") + expr_node = FunctionNode("ts_sma", close, 10, 5) + expr = translator.translate(expr_node) + + # 使用翻译后的表达式计算 + result = df.select(["ts_code", "trade_date", "close", expr.alias("ts_sma_result")]) + + # 原生 Polars 计算 + native = df.with_columns( + [pl.col("close").ewm_mean(alpha=5 / 10, adjust=False).alias("native_sma")] + ) + + # 对比结果 + assert np.allclose( + result["ts_sma_result"].to_numpy()[9:], + native["native_sma"].to_numpy()[9:], + equal_nan=True, + ) + + +def test_ts_sumac_computation(): + """测试 ts_sumac 计算与原生 Polars 一致。""" + df = create_test_data() + translator = PolarsTranslator() + + close = Symbol("close") + expr_node = FunctionNode("ts_sumac", close) + expr = translator.translate(expr_node) + + result = df.select( + ["ts_code", "trade_date", "close", expr.alias("ts_sumac_result")] + ) + + native = df.with_columns([pl.col("close").cum_sum().alias("native_sumac")]) + + assert np.allclose( + result["ts_sumac_result"].to_numpy(), native["native_sumac"].to_numpy() + ) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"])