feat(factors): 添加时间序列函数及智能路由
- 新增 8 个国泰君安 191 兼容的时间序列函数:ts_sma, ts_wma, ts_decay_linear, ts_argmax, ts_argmin, ts_count, ts_prod, ts_sumac - max_/min_ 函数智能路由:正整数参数自动调用 ts_max/ts_min 实现滚动窗口逻辑
This commit is contained in:
130
tests/test_new_ts_functions.py
Normal file
130
tests/test_new_ts_functions.py
Normal file
@@ -0,0 +1,130 @@
|
||||
"""测试新增的时间序列函数和智能分发逻辑。"""
|
||||
|
||||
import pytest
|
||||
import polars as pl
|
||||
import numpy as np
|
||||
from src.factors.dsl import Symbol, FunctionNode
|
||||
from src.factors.translator import PolarsTranslator
|
||||
|
||||
|
||||
def test_ts_sma_translate():
|
||||
"""测试 ts_sma 翻译正确。"""
|
||||
close = Symbol("close")
|
||||
expr = FunctionNode("ts_sma", close, 10, 5)
|
||||
translator = PolarsTranslator()
|
||||
result = translator.translate(expr)
|
||||
assert isinstance(result, pl.Expr)
|
||||
|
||||
|
||||
def test_ts_wma_translate():
|
||||
"""测试 ts_wma 翻译正确。"""
|
||||
close = Symbol("close")
|
||||
expr = FunctionNode("ts_wma", close, 20)
|
||||
translator = PolarsTranslator()
|
||||
result = translator.translate(expr)
|
||||
assert isinstance(result, pl.Expr)
|
||||
|
||||
|
||||
def test_ts_sumac_translate():
|
||||
"""测试 ts_sumac 翻译正确。"""
|
||||
close = Symbol("close")
|
||||
expr = FunctionNode("ts_sumac", close)
|
||||
translator = PolarsTranslator()
|
||||
result = translator.translate(expr)
|
||||
assert isinstance(result, pl.Expr)
|
||||
|
||||
|
||||
def test_max_intelligent_dispatch():
|
||||
"""测试 max_ 智能分发: int -> ts_max,其他 -> element-wise max。"""
|
||||
from src.factors.api import max_, close
|
||||
|
||||
# 正整数 -> ts_max
|
||||
result = max_(close, 20)
|
||||
assert result.func_name == "ts_max"
|
||||
|
||||
# 零或负数 -> element-wise max
|
||||
result = max_(close, 0)
|
||||
assert result.func_name == "max"
|
||||
|
||||
result = max_(close, -1)
|
||||
assert result.func_name == "max"
|
||||
|
||||
# 浮点数 -> element-wise max
|
||||
result = max_(close, 10.5)
|
||||
assert result.func_name == "max"
|
||||
|
||||
|
||||
def test_min_intelligent_dispatch():
|
||||
"""测试 min_ 智能分发: int -> ts_min,其他 -> element-wise min。"""
|
||||
from src.factors.api import min_, close
|
||||
|
||||
# 正整数 -> ts_min
|
||||
result = min_(close, 20)
|
||||
assert result.func_name == "ts_min"
|
||||
|
||||
# 零或负数 -> element-wise min
|
||||
result = min_(close, 0)
|
||||
assert result.func_name == "min"
|
||||
|
||||
|
||||
def create_test_data() -> pl.DataFrame:
|
||||
"""创建测试数据。"""
|
||||
np.random.seed(42)
|
||||
n = 100
|
||||
return pl.DataFrame(
|
||||
{
|
||||
"ts_code": ["000001.SZ"] * n,
|
||||
"trade_date": list(range(20240101, 20240101 + n)),
|
||||
"close": np.random.randn(n).cumsum() + 100,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def test_ts_sma_computation():
|
||||
"""测试 ts_sma 计算与原生 Polars 一致。"""
|
||||
df = create_test_data()
|
||||
translator = PolarsTranslator()
|
||||
|
||||
# 翻译因子
|
||||
close = Symbol("close")
|
||||
expr_node = FunctionNode("ts_sma", close, 10, 5)
|
||||
expr = translator.translate(expr_node)
|
||||
|
||||
# 使用翻译后的表达式计算
|
||||
result = df.select(["ts_code", "trade_date", "close", expr.alias("ts_sma_result")])
|
||||
|
||||
# 原生 Polars 计算
|
||||
native = df.with_columns(
|
||||
[pl.col("close").ewm_mean(alpha=5 / 10, adjust=False).alias("native_sma")]
|
||||
)
|
||||
|
||||
# 对比结果
|
||||
assert np.allclose(
|
||||
result["ts_sma_result"].to_numpy()[9:],
|
||||
native["native_sma"].to_numpy()[9:],
|
||||
equal_nan=True,
|
||||
)
|
||||
|
||||
|
||||
def test_ts_sumac_computation():
|
||||
"""测试 ts_sumac 计算与原生 Polars 一致。"""
|
||||
df = create_test_data()
|
||||
translator = PolarsTranslator()
|
||||
|
||||
close = Symbol("close")
|
||||
expr_node = FunctionNode("ts_sumac", close)
|
||||
expr = translator.translate(expr_node)
|
||||
|
||||
result = df.select(
|
||||
["ts_code", "trade_date", "close", expr.alias("ts_sumac_result")]
|
||||
)
|
||||
|
||||
native = df.with_columns([pl.col("close").cum_sum().alias("native_sumac")])
|
||||
|
||||
assert np.allclose(
|
||||
result["ts_sumac_result"].to_numpy(), native["native_sumac"].to_numpy()
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
Reference in New Issue
Block a user