feat(factors): 添加时间序列函数及智能路由

- 新增 8 个国泰君安 191 兼容的时间序列函数:ts_sma, ts_wma, ts_decay_linear, ts_argmax, ts_argmin, ts_count, ts_prod, ts_sumac
- max_/min_ 函数智能路由:正整数参数自动调用 ts_max/ts_min 实现滚动窗口逻辑
This commit is contained in:
2026-03-15 13:05:55 +08:00
parent 0e9ea5d533
commit c6ebab0e58
3 changed files with 467 additions and 2 deletions

View File

@@ -0,0 +1,130 @@
"""测试新增的时间序列函数和智能分发逻辑。"""
import pytest
import polars as pl
import numpy as np
from src.factors.dsl import Symbol, FunctionNode
from src.factors.translator import PolarsTranslator
def test_ts_sma_translate():
"""测试 ts_sma 翻译正确。"""
close = Symbol("close")
expr = FunctionNode("ts_sma", close, 10, 5)
translator = PolarsTranslator()
result = translator.translate(expr)
assert isinstance(result, pl.Expr)
def test_ts_wma_translate():
"""测试 ts_wma 翻译正确。"""
close = Symbol("close")
expr = FunctionNode("ts_wma", close, 20)
translator = PolarsTranslator()
result = translator.translate(expr)
assert isinstance(result, pl.Expr)
def test_ts_sumac_translate():
"""测试 ts_sumac 翻译正确。"""
close = Symbol("close")
expr = FunctionNode("ts_sumac", close)
translator = PolarsTranslator()
result = translator.translate(expr)
assert isinstance(result, pl.Expr)
def test_max_intelligent_dispatch():
"""测试 max_ 智能分发: int -> ts_max其他 -> element-wise max。"""
from src.factors.api import max_, close
# 正整数 -> ts_max
result = max_(close, 20)
assert result.func_name == "ts_max"
# 零或负数 -> element-wise max
result = max_(close, 0)
assert result.func_name == "max"
result = max_(close, -1)
assert result.func_name == "max"
# 浮点数 -> element-wise max
result = max_(close, 10.5)
assert result.func_name == "max"
def test_min_intelligent_dispatch():
"""测试 min_ 智能分发: int -> ts_min其他 -> element-wise min。"""
from src.factors.api import min_, close
# 正整数 -> ts_min
result = min_(close, 20)
assert result.func_name == "ts_min"
# 零或负数 -> element-wise min
result = min_(close, 0)
assert result.func_name == "min"
def create_test_data() -> pl.DataFrame:
"""创建测试数据。"""
np.random.seed(42)
n = 100
return pl.DataFrame(
{
"ts_code": ["000001.SZ"] * n,
"trade_date": list(range(20240101, 20240101 + n)),
"close": np.random.randn(n).cumsum() + 100,
}
)
def test_ts_sma_computation():
"""测试 ts_sma 计算与原生 Polars 一致。"""
df = create_test_data()
translator = PolarsTranslator()
# 翻译因子
close = Symbol("close")
expr_node = FunctionNode("ts_sma", close, 10, 5)
expr = translator.translate(expr_node)
# 使用翻译后的表达式计算
result = df.select(["ts_code", "trade_date", "close", expr.alias("ts_sma_result")])
# 原生 Polars 计算
native = df.with_columns(
[pl.col("close").ewm_mean(alpha=5 / 10, adjust=False).alias("native_sma")]
)
# 对比结果
assert np.allclose(
result["ts_sma_result"].to_numpy()[9:],
native["native_sma"].to_numpy()[9:],
equal_nan=True,
)
def test_ts_sumac_computation():
"""测试 ts_sumac 计算与原生 Polars 一致。"""
df = create_test_data()
translator = PolarsTranslator()
close = Symbol("close")
expr_node = FunctionNode("ts_sumac", close)
expr = translator.translate(expr_node)
result = df.select(
["ts_code", "trade_date", "close", expr.alias("ts_sumac_result")]
)
native = df.with_columns([pl.col("close").cum_sum().alias("native_sumac")])
assert np.allclose(
result["ts_sumac_result"].to_numpy(), native["native_sumac"].to_numpy()
)
if __name__ == "__main__":
pytest.main([__file__, "-v"])