Files
ProStock/tests/test_new_ts_functions.py
liaozhaorun c6ebab0e58 feat(factors): 添加时间序列函数及智能路由
- 新增 8 个国泰君安 191 兼容的时间序列函数:ts_sma, ts_wma, ts_decay_linear, ts_argmax, ts_argmin, ts_count, ts_prod, ts_sumac
- max_/min_ 函数智能路由:正整数参数自动调用 ts_max/ts_min 实现滚动窗口逻辑
2026-03-15 13:05:55 +08:00

131 lines
3.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""测试新增的时间序列函数和智能分发逻辑。"""
import pytest
import polars as pl
import numpy as np
from src.factors.dsl import Symbol, FunctionNode
from src.factors.translator import PolarsTranslator
def test_ts_sma_translate():
    """Check that ts_sma(close, 10, 5) translates into a Polars expression."""
    node = FunctionNode("ts_sma", Symbol("close"), 10, 5)
    translated = PolarsTranslator().translate(node)
    assert isinstance(translated, pl.Expr)
def test_ts_wma_translate():
    """Check that ts_wma(close, 20) translates into a Polars expression."""
    node = FunctionNode("ts_wma", Symbol("close"), 20)
    translated = PolarsTranslator().translate(node)
    assert isinstance(translated, pl.Expr)
def test_ts_sumac_translate():
    """Check that ts_sumac(close) translates into a Polars expression."""
    node = FunctionNode("ts_sumac", Symbol("close"))
    translated = PolarsTranslator().translate(node)
    assert isinstance(translated, pl.Expr)
def test_max_intelligent_dispatch():
    """Check max_ dispatch: positive int -> ts_max, anything else -> element-wise max."""
    from src.factors.api import max_, close

    # A positive integer is treated as a rolling window.
    assert max_(close, 20).func_name == "ts_max"

    # Zero, negative integers and floats fall back to element-wise max.
    for operand in (0, -1, 10.5):
        assert max_(close, operand).func_name == "max"
def test_min_intelligent_dispatch():
    """Check min_ dispatch: positive int -> ts_min, anything else -> element-wise min.

    Covers the same argument classes as test_max_intelligent_dispatch
    (positive int, zero, negative int, float), so both routers are held
    to the same contract.
    """
    from src.factors.api import min_, close

    # Positive integer -> rolling ts_min.
    result = min_(close, 20)
    assert result.func_name == "ts_min"

    # Zero or negative integer -> element-wise min.
    result = min_(close, 0)
    assert result.func_name == "min"
    result = min_(close, -1)
    assert result.func_name == "min"

    # Float -> element-wise min (only positive *integers* are windows).
    result = min_(close, 10.5)
    assert result.func_name == "min"
def create_test_data(n: int = 100, seed: int = 42) -> pl.DataFrame:
    """Build a deterministic single-stock daily panel for computation tests.

    Args:
        n: Number of rows; one row per consecutive calendar day from 2024-01-01.
        seed: Seed for a local NumPy ``Generator``. This makes the data
            reproducible without reseeding NumPy's global random state, which
            would otherwise leak into any other test that uses ``np.random``.

    Returns:
        DataFrame with three columns:
        ``ts_code`` -- always "000001.SZ";
        ``trade_date`` -- valid calendar dates as strictly increasing
        YYYYMMDD integers;
        ``close`` -- a Gaussian random walk offset by 100.
    """
    from datetime import date, timedelta

    rng = np.random.default_rng(seed)
    start = date(2024, 1, 1)
    # Step through real calendar days. A plain integer range would yield
    # impossible dates such as 20240132 once it runs past the end of January.
    trade_dates = [
        int((start + timedelta(days=offset)).strftime("%Y%m%d"))
        for offset in range(n)
    ]
    return pl.DataFrame(
        {
            "ts_code": ["000001.SZ"] * n,
            "trade_date": trade_dates,
            "close": rng.standard_normal(n).cumsum() + 100,
        }
    )
def test_ts_sma_computation():
    """Check that translated ts_sma(close, n, m) matches Polars' native EWM.

    The reference is ``ewm_mean(alpha=m / n, adjust=False)``, i.e. the
    recursion y_t = (1 - alpha) * y_{t-1} + alpha * x_t. The first
    ``window - 1`` rows are excluded from the comparison. Presumably the
    translated expression may emit warm-up values there; confirm this
    against the translator.
    """
    # Single source of truth for the SMA parameters. The DSL call, the native
    # alpha and the warm-up slice are all derived from these two values.
    window, weight = 10, 5

    df = create_test_data()
    translator = PolarsTranslator()

    # Translate the factor and evaluate it.
    expr_node = FunctionNode("ts_sma", Symbol("close"), window, weight)
    expr = translator.translate(expr_node)
    result = df.select(["ts_code", "trade_date", "close", expr.alias("ts_sma_result")])

    # Native Polars reference.
    native = df.select(
        pl.col("close").ewm_mean(alpha=weight / window, adjust=False).alias("native_sma")
    )

    actual = result["ts_sma_result"].to_numpy()
    expected = native["native_sma"].to_numpy()
    # Guard against np.allclose silently broadcasting mismatched lengths.
    assert actual.shape == expected.shape

    warmup = window - 1
    assert np.allclose(actual[warmup:], expected[warmup:], equal_nan=True)
def test_ts_sumac_computation():
    """Check that translated ts_sumac matches Polars' native cumulative sum."""
    df = create_test_data()
    translator = PolarsTranslator()
    sumac_expr = translator.translate(FunctionNode("ts_sumac", Symbol("close")))

    selected = ["ts_code", "trade_date", "close", sumac_expr.alias("ts_sumac_result")]
    computed = df.select(selected)["ts_sumac_result"]

    reference_frame = df.with_columns([pl.col("close").cum_sum().alias("native_sumac")])
    reference = reference_frame["native_sumac"]

    assert np.allclose(computed.to_numpy(), reference.to_numpy())
if __name__ == "__main__":
    # Propagate pytest's exit status. Otherwise running this file directly
    # exits with 0 even when tests fail.
    raise SystemExit(pytest.main([__file__, "-v"]))