Files
ProStock/tests/factors/test_dsl_promotion.py
liaozhaorun 0698b9d919 feat: 添加DSL因子表达式系统和Pro Bar API封装
- 新增 factors/dsl.py: 纯Python DSL表达式层,通过运算符重载实现因子组合
- 新增 factors/api.py: 提供常用因子符号(close/open/high/low)和时序函数(ts_mean/ts_std/cs_rank等)
- 新增 factors/compiler.py: 因子编译器
- 新增 factors/translator.py: DSL表达式翻译器
- 新增 data/api_wrappers/api_pro_bar.py: Tushare Pro Bar API封装,支持后复权行情数据
- 新增 data/data_router.py: 数据路由功能
- 新增相关测试用例
2026-02-27 22:43:45 +08:00

326 lines
10 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""测试 DSL 字符串自动提升Promotion功能。
验证以下功能:
1. 字符串自动转换为 Symbol
2. 算子函数支持字符串参数
3. 右位运算支持
"""
import pytest
from src.factors.dsl import (
Symbol,
Constant,
BinaryOpNode,
UnaryOpNode,
FunctionNode,
_ensure_node,
)
from src.factors.api import (
close,
open,
ts_mean,
ts_std,
ts_corr,
cs_rank,
cs_zscore,
log,
exp,
max_,
min_,
clip,
if_,
where,
)
class TestEnsureNode:
    """Tests for the ``_ensure_node`` coercion helper."""

    @staticmethod
    def _coerce(raw, expected_type):
        """Run ``_ensure_node`` on *raw*, assert the node type, return the node."""
        node = _ensure_node(raw)
        assert isinstance(node, expected_type)
        return node

    def test_ensure_node_with_node(self):
        """An existing node is passed through as the very same object."""
        original = Symbol("close")
        assert _ensure_node(original) is original

    def test_ensure_node_with_int(self):
        """An int is wrapped in a ``Constant`` holding that value."""
        assert self._coerce(100, Constant).value == 100

    def test_ensure_node_with_float(self):
        """A float is wrapped in a ``Constant`` holding that value."""
        assert self._coerce(3.14, Constant).value == 3.14

    def test_ensure_node_with_str(self):
        """A string is promoted to a ``Symbol`` named after it."""
        assert self._coerce("close", Symbol).name == "close"

    def test_ensure_node_with_invalid_type(self):
        """Unsupported input (here a list) is rejected with ``TypeError``."""
        unsupported = [1, 2, 3]
        with pytest.raises(TypeError):
            _ensure_node(unsupported)
class TestSymbolStringPromotion:
    """Binary operators between a ``Symbol`` and a plain string.

    For every operator the string on the right must be promoted to a
    ``Symbol`` and the original ``Symbol`` (``close``) must stay on the left.
    """

    @staticmethod
    def _assert_promoted(expr, op, right_name):
        """Assert *expr* is ``close <op> Symbol(right_name)`` in that order."""
        assert isinstance(expr, BinaryOpNode)
        assert expr.op == op
        assert isinstance(expr.left, Symbol)
        assert expr.left.name == "close"
        # The string operand must have been promoted to a Symbol node,
        # not kept as a raw str or wrapped as a Constant.
        assert isinstance(expr.right, Symbol)
        assert expr.right.name == right_name

    def test_symbol_add_str(self):
        """Symbol + str."""
        self._assert_promoted(close + "pe_ratio", "+", "pe_ratio")

    def test_symbol_sub_str(self):
        """Symbol - str."""
        self._assert_promoted(close - "open", "-", "open")

    def test_symbol_mul_str(self):
        """Symbol * str."""
        self._assert_promoted(close * "volume", "*", "volume")

    def test_symbol_div_str(self):
        """Symbol / str."""
        self._assert_promoted(close / "pe_ratio", "/", "pe_ratio")

    def test_symbol_pow_str(self):
        """Symbol ** str."""
        self._assert_promoted(close ** "exponent", "**", "exponent")
class TestRightHandOperations:
    """Reflected (right-hand) operators: ``<number> <op> Symbol``.

    Most of these operators are not commutative, so every test checks that the
    numeric literal ends up on the LEFT as a ``Constant`` and the ``Symbol`` on
    the RIGHT. A reflected-operator implementation that swapped its operands
    must fail here.
    """

    @staticmethod
    def _assert_reflected(expr, op, value):
        """Assert *expr* is ``Constant(value) <op> Symbol("close")``."""
        assert isinstance(expr, BinaryOpNode)
        assert expr.op == op
        assert isinstance(expr.left, Constant)
        assert expr.left.value == value
        assert isinstance(expr.right, Symbol)
        assert expr.right.name == "close"

    def test_int_add_symbol(self):
        """int + Symbol."""
        self._assert_reflected(100 + close, "+", 100)

    def test_int_sub_symbol(self):
        """int - Symbol."""
        self._assert_reflected(100 - close, "-", 100)

    def test_int_mul_symbol(self):
        """int * Symbol."""
        self._assert_reflected(2 * close, "*", 2)

    def test_int_div_symbol(self):
        """int / Symbol."""
        self._assert_reflected(100 / close, "/", 100)

    def test_int_div_str_not_supported(self):
        """A bare ``int / str`` is plain Python and raises ``TypeError``.

        Neither operand is a DSL node, so the DSL cannot intercept the
        operation. Use ``100 / Symbol("close")`` or an existing symbol such as
        ``close`` instead.
        """
        with pytest.raises(TypeError):
            100 / "close"

    def test_int_floordiv_symbol(self):
        """int // Symbol."""
        self._assert_reflected(100 // close, "//", 100)

    def test_int_mod_symbol(self):
        """int % Symbol."""
        self._assert_reflected(100 % close, "%", 100)

    def test_int_pow_symbol(self):
        """int ** Symbol."""
        self._assert_reflected(2**close, "**", 2)
class TestOperatorFunctionsWithStrings:
    """Operator functions must accept plain strings as column references."""

    @staticmethod
    def _args_of(expr, func_name=None):
        """Assert *expr* is a ``FunctionNode`` (named *func_name* if given); return its args."""
        assert isinstance(expr, FunctionNode)
        if func_name is not None:
            assert expr.func_name == func_name
        return expr.args

    def test_ts_mean_with_str(self):
        """ts_mean promotes the column string and wraps the window as a Constant."""
        args = self._args_of(ts_mean("close", 20), "ts_mean")
        assert len(args) == 2
        column, window = args[0], args[1]
        assert isinstance(column, Symbol)
        assert column.name == "close"
        assert isinstance(window, Constant)
        assert window.value == 20

    def test_ts_std_with_str(self):
        """ts_std promotes its column string."""
        args = self._args_of(ts_std("volume", 10), "ts_std")
        assert args[0].name == "volume"

    def test_ts_corr_with_str(self):
        """ts_corr promotes both column strings, preserving their order."""
        args = self._args_of(ts_corr("close", "open", 20), "ts_corr")
        assert args[0].name == "close"
        assert args[1].name == "open"

    def test_cs_rank_with_str(self):
        """cs_rank promotes its column string."""
        args = self._args_of(cs_rank("pe_ratio"), "cs_rank")
        assert args[0].name == "pe_ratio"

    def test_cs_zscore_with_str(self):
        """cs_zscore promotes its column string."""
        args = self._args_of(cs_zscore("market_cap"), "cs_zscore")
        assert args[0].name == "market_cap"

    def test_log_with_str(self):
        """log promotes its column string."""
        args = self._args_of(log("close"), "log")
        assert args[0].name == "close"

    def test_max_with_str(self):
        """max_ promotes both strings; the resulting node is named ``max``."""
        args = self._args_of(max_("close", "open"), "max")
        assert args[0].name == "close"
        assert args[1].name == "open"

    def test_max_with_str_and_number(self):
        """max_ accepts a string mixed with a number."""
        args = self._args_of(max_("close", 100))
        assert args[0].name == "close"
        assert args[1].value == 100

    def test_clip_with_str(self):
        """clip promotes the value and both bound strings."""
        args = self._args_of(
            clip("pe_ratio", "lower_bound", "upper_bound"), "clip"
        )
        assert args[0].name == "pe_ratio"
        assert args[1].name == "lower_bound"
        assert args[2].name == "upper_bound"

    def test_if_with_str(self):
        """if_ promotes all three strings; the resulting node is named ``if``."""
        args = self._args_of(if_("condition", "true_val", "false_val"), "if")
        assert args[0].name == "condition"
        assert args[1].name == "true_val"
        assert args[2].name == "false_val"
class TestComplexExpressions:
    """Composite expressions mixing functions, symbols and promoted strings."""

    def test_complex_expression_1(self):
        """Complex expression: ts_mean("close", 5) / "pe_ratio"."""
        expr = ts_mean("close", 5) / "pe_ratio"
        assert isinstance(expr, BinaryOpNode)
        assert expr.op == "/"
        assert isinstance(expr.left, FunctionNode)
        assert expr.left.func_name == "ts_mean"
        assert isinstance(expr.right, Symbol)
        assert expr.right.name == "pe_ratio"

    def test_complex_expression_2(self):
        """Complex expression: 100 / close * cs_rank("volume").

        Python's built-in ``int`` cannot be divided by a ``str``, so the divisor
        here must be an existing ``Symbol`` (``close``) rather than the string
        ``"close"``.
        """
        expr = 100 / close * cs_rank("volume")
        # ``/`` and ``*`` share precedence and associate left-to-right, so this
        # parses as ``(100 / close) * cs_rank("volume")``.
        assert isinstance(expr, BinaryOpNode)
        assert expr.op == "*"
        assert isinstance(expr.left, BinaryOpNode)
        assert expr.left.op == "/"
        assert isinstance(expr.left.left, Constant)
        assert expr.left.left.value == 100
        assert isinstance(expr.left.right, Symbol)
        assert expr.left.right.name == "close"
        assert isinstance(expr.right, FunctionNode)
        assert expr.right.func_name == "cs_rank"
        assert expr.right.args[0].name == "volume"

    def test_complex_expression_3(self):
        """Complex expression: ts_mean(close - "open", 20) / close."""
        expr = ts_mean(close - "open", 20) / close
        assert isinstance(expr, BinaryOpNode)
        assert expr.op == "/"
        assert isinstance(expr.left, FunctionNode)
        assert expr.left.func_name == "ts_mean"
        # The first ts_mean argument must be the node for ``close - open``,
        # with the string "open" promoted to a Symbol on the right.
        inner = expr.left.args[0]
        assert isinstance(inner, BinaryOpNode)
        assert inner.op == "-"
        assert isinstance(inner.left, Symbol)
        assert inner.left.name == "close"
        assert isinstance(inner.right, Symbol)
        assert inner.right.name == "open"
        assert expr.left.args[1].value == 20
        assert isinstance(expr.right, Symbol)
        assert expr.right.name == "close"
class TestExpressionRepr:
    """``repr()`` of DSL nodes renders a readable expression string."""

    @staticmethod
    def _check(expr, expected):
        """Assert that ``repr(expr)`` equals *expected*."""
        rendered = repr(expr)
        assert rendered == expected

    def test_symbol_str_repr(self):
        """A bare Symbol renders as its name."""
        self._check(Symbol("close"), "close")

    def test_binary_op_repr(self):
        """A binary operation renders parenthesised, with spaces around the operator."""
        self._check(close + "open", "(close + open)")

    def test_function_node_repr(self):
        """A function node renders as ``name(arg, ...)``."""
        self._check(ts_mean("close", 20), "ts_mean(close, 20)")

    def test_complex_expr_repr(self):
        """Nested expressions render recursively."""
        self._check(
            ts_mean("close", 5) / "pe_ratio",
            "(ts_mean(close, 5) / pe_ratio)",
        )
if __name__ == "__main__":
    # Propagate pytest's exit status so that running this file directly exits
    # non-zero when tests fail, instead of always exiting 0.
    raise SystemExit(pytest.main([__file__, "-v"]))