131 lines
3.5 KiB
Python
131 lines
3.5 KiB
Python
|
|
"""测试新增的时间序列函数和智能分发逻辑。"""
|
|||
|
|
|
|||
|
|
import pytest
|
|||
|
|
import polars as pl
|
|||
|
|
import numpy as np
|
|||
|
|
from src.factors.dsl import Symbol, FunctionNode
|
|||
|
|
from src.factors.translator import PolarsTranslator
|
|||
|
|
|
|||
|
|
|
|||
|
|
def test_ts_sma_translate():
    """Check that ts_sma(close, 10, 5) translates into a Polars expression."""
    node = FunctionNode("ts_sma", Symbol("close"), 10, 5)
    translated = PolarsTranslator().translate(node)
    assert isinstance(translated, pl.Expr)
|
|||
|
|
|
|||
|
|
|
|||
|
|
def test_ts_wma_translate():
    """Check that ts_wma(close, 20) translates into a Polars expression."""
    node = FunctionNode("ts_wma", Symbol("close"), 20)
    translated = PolarsTranslator().translate(node)
    assert isinstance(translated, pl.Expr)
|
|||
|
|
|
|||
|
|
|
|||
|
|
def test_ts_sumac_translate():
    """Check that ts_sumac(close) translates into a Polars expression."""
    node = FunctionNode("ts_sumac", Symbol("close"))
    translated = PolarsTranslator().translate(node)
    assert isinstance(translated, pl.Expr)
|
|||
|
|
|
|||
|
|
|
|||
|
|
def test_max_intelligent_dispatch():
    """Check the dispatch rule of ``max_``.

    A positive ``int`` second argument dispatches to ``ts_max``; zero,
    negative integers and floats dispatch to the element-wise ``max``.
    """
    from src.factors.api import max_, close

    cases = [
        (20, "ts_max"),  # positive int -> time-series max
        (0, "max"),  # zero -> element-wise max
        (-1, "max"),  # negative int -> element-wise max
        (10.5, "max"),  # float -> element-wise max
    ]
    for arg, expected in cases:
        # Include the failing case in the message so a regression is easy to locate.
        assert max_(close, arg).func_name == expected, (arg, expected)
|
|||
|
|
|
|||
|
|
|
|||
|
|
def test_min_intelligent_dispatch():
    """Check the dispatch rule of ``min_``.

    A positive ``int`` second argument dispatches to ``ts_min``; zero,
    negative integers and floats dispatch to the element-wise ``min``.
    The negative and float cases were previously untested even though the
    documented rule covers them.
    """
    from src.factors.api import min_, close

    cases = [
        (20, "ts_min"),  # positive int -> time-series min
        (0, "min"),  # zero -> element-wise min
        (-1, "min"),  # negative int -> element-wise min
        (10.5, "min"),  # float -> element-wise min
    ]
    for arg, expected in cases:
        # Include the failing case in the message so a regression is easy to locate.
        assert min_(close, arg).func_name == expected, (arg, expected)
|
|||
|
|
|
|||
|
|
|
|||
|
|
def create_test_data(n: int = 100, seed: int = 42) -> pl.DataFrame:
    """Build a reproducible single-stock price frame for computation tests.

    A local ``numpy.random.Generator`` is used instead of ``np.random.seed``
    so the fixture stays deterministic without mutating NumPy's global RNG
    state, which would otherwise leak into unrelated tests.

    Args:
        n: Number of rows to generate.
        seed: Seed for the local random generator.

    Returns:
        DataFrame with columns ``ts_code`` (constant ``"000001.SZ"``),
        ``trade_date`` (consecutive integers starting at 20240101; note they
        are not all valid calendar dates, e.g. 20240132) and ``close``
        (cumulative sum of standard-normal draws, offset by 100).
    """
    rng = np.random.default_rng(seed)
    return pl.DataFrame(
        {
            "ts_code": ["000001.SZ"] * n,
            "trade_date": list(range(20240101, 20240101 + n)),
            "close": rng.standard_normal(n).cumsum() + 100,
        }
    )
|
|||
|
|
|
|||
|
|
|
|||
|
|
def test_ts_sma_computation():
    """Check the translated ts_sma(close, 10, 5) against Polars' native EWM.

    ``ts_sma(x, n, m)`` is compared with ``ewm_mean(alpha=m / n,
    adjust=False)``. The first ``n - 1`` rows are excluded from the
    comparison, presumably to ignore warm-up differences in the translation.
    """
    window, weight = 10, 5
    frame = create_test_data()

    sma_expr = PolarsTranslator().translate(
        FunctionNode("ts_sma", Symbol("close"), window, weight)
    )

    translated = frame.select(
        ["ts_code", "trade_date", "close", sma_expr.alias("ts_sma_result")]
    )["ts_sma_result"].to_numpy()

    native_expr = pl.col("close").ewm_mean(alpha=weight / window, adjust=False)
    expected = frame.with_columns([native_expr.alias("native_sma")])[
        "native_sma"
    ].to_numpy()

    skip = window - 1
    assert np.allclose(translated[skip:], expected[skip:], equal_nan=True)
|
|||
|
|
|
|||
|
|
|
|||
|
|
def test_ts_sumac_computation():
    """Check the translated ts_sumac(close) against Polars' native cum_sum."""
    frame = create_test_data()

    sumac_expr = PolarsTranslator().translate(FunctionNode("ts_sumac", Symbol("close")))

    translated = frame.select(
        ["ts_code", "trade_date", "close", sumac_expr.alias("ts_sumac_result")]
    )["ts_sumac_result"].to_numpy()

    expected = frame.with_columns(
        [pl.col("close").cum_sum().alias("native_sumac")]
    )["native_sumac"].to_numpy()

    assert np.allclose(translated, expected)
|
|||
|
|
|
|||
|
|
|
|||
|
|
if __name__ == "__main__":
    # pytest.main() returns the exit code instead of raising SystemExit, so
    # re-raise it here; otherwise running this file directly would exit 0
    # even when tests fail.
    raise SystemExit(pytest.main([__file__, "-v"]))
|