- 新增 factors/dsl.py: 纯Python DSL表达式层,通过运算符重载实现因子组合 - 新增 factors/api.py: 提供常用因子符号(close/open/high/low)和时序函数(ts_mean/ts_std/cs_rank等) - 新增 factors/compiler.py: 因子编译器 - 新增 factors/translator.py: DSL表达式翻译器 - 新增 data/api_wrappers/api_pro_bar.py: Tushare Pro Bar API封装,支持后复权行情数据 - 新增 data/data_router.py: 数据路由功能 - 新增相关测试用例
326 lines
10 KiB
Python
326 lines
10 KiB
Python
"""测试 DSL 字符串自动提升(Promotion)功能。

验证以下功能:

1. 字符串自动转换为 Symbol
2. 算子函数支持字符串参数
3. 反射(右操作数)运算符支持,例如 100 + close
"""
|
||
|
||
import pytest
|
||
from src.factors.dsl import (
|
||
Symbol,
|
||
Constant,
|
||
BinaryOpNode,
|
||
UnaryOpNode,
|
||
FunctionNode,
|
||
_ensure_node,
|
||
)
|
||
from src.factors.api import (
|
||
close,
|
||
open,
|
||
ts_mean,
|
||
ts_std,
|
||
ts_corr,
|
||
cs_rank,
|
||
cs_zscore,
|
||
log,
|
||
exp,
|
||
max_,
|
||
min_,
|
||
clip,
|
||
if_,
|
||
where,
|
||
)
|
||
|
||
|
||
class TestEnsureNode:
    """Tests for the ``_ensure_node`` promotion helper."""

    def test_ensure_node_with_node(self):
        """An existing node is returned as the very same object."""
        original = Symbol("close")
        assert _ensure_node(original) is original

    def test_ensure_node_with_int(self):
        """An int is wrapped in a Constant holding the same value."""
        promoted = _ensure_node(100)
        assert isinstance(promoted, Constant)
        assert promoted.value == 100

    def test_ensure_node_with_float(self):
        """A float is wrapped in a Constant holding the same value."""
        promoted = _ensure_node(3.14)
        assert isinstance(promoted, Constant)
        assert promoted.value == 3.14

    def test_ensure_node_with_str(self):
        """A str is promoted to a Symbol named after the string."""
        promoted = _ensure_node("close")
        assert isinstance(promoted, Symbol)
        assert promoted.name == "close"

    def test_ensure_node_with_invalid_type(self):
        """An unsupported input type (here a list) raises TypeError."""
        unsupported = [1, 2, 3]
        with pytest.raises(TypeError):
            _ensure_node(unsupported)
|
||
|
||
|
||
class TestSymbolStringPromotion:
    """Symbol <op> str: the string operand is promoted to a Symbol.

    Every test checks the complete shape of the resulting node (operator and
    both operands with their types). A regression in either operand is then
    caught for every operator, not only for ``+``.
    """

    @staticmethod
    def _assert_promoted(expr, op, right_name):
        """Assert ``expr`` is ``BinaryOpNode(op, Symbol("close"), Symbol(right_name))``."""
        assert isinstance(expr, BinaryOpNode)
        assert expr.op == op
        assert isinstance(expr.left, Symbol)
        assert expr.left.name == "close"
        # The string operand must become a Symbol, not e.g. a Constant.
        assert isinstance(expr.right, Symbol)
        assert expr.right.name == right_name

    def test_symbol_add_str(self):
        """Symbol + str."""
        self._assert_promoted(close + "pe_ratio", "+", "pe_ratio")

    def test_symbol_sub_str(self):
        """Symbol - str."""
        self._assert_promoted(close - "open", "-", "open")

    def test_symbol_mul_str(self):
        """Symbol * str."""
        self._assert_promoted(close * "volume", "*", "volume")

    def test_symbol_div_str(self):
        """Symbol / str."""
        self._assert_promoted(close / "pe_ratio", "/", "pe_ratio")

    def test_symbol_pow_str(self):
        """Symbol ** str."""
        self._assert_promoted(close ** "exponent", "**", "exponent")
|
||
|
||
|
||
class TestRightHandOperations:
    """Reflected operators: number <op> Symbol.

    When the left operand is a plain number, Python falls back to the
    Symbol's reflected method (``__radd__``, ``__rsub__``, ...). The number
    must end up as a ``Constant`` on the left and the Symbol on the right.
    Operand order matters for the non-commutative ``-``, ``/``, ``//``,
    ``%`` and ``**``.
    """

    @staticmethod
    def _assert_reflected(expr, op, left_value):
        """Assert ``expr`` is ``BinaryOpNode(op, Constant(left_value), Symbol("close"))``."""
        assert isinstance(expr, BinaryOpNode)
        assert expr.op == op
        assert isinstance(expr.left, Constant)
        assert expr.left.value == left_value
        assert isinstance(expr.right, Symbol)
        assert expr.right.name == "close"

    def test_int_add_symbol(self):
        """int + Symbol."""
        self._assert_reflected(100 + close, "+", 100)

    def test_int_sub_symbol(self):
        """int - Symbol."""
        self._assert_reflected(100 - close, "-", 100)

    def test_int_mul_symbol(self):
        """int * Symbol."""
        self._assert_reflected(2 * close, "*", 2)

    def test_int_div_symbol(self):
        """int / Symbol."""
        self._assert_reflected(100 / close, "/", 100)

    def test_int_div_str_not_supported(self):
        """A bare int cannot be divided by a bare str.

        Neither ``int`` nor ``str`` knows about the DSL, so ``100 / "close"``
        never reaches DSL code and Python raises TypeError. The supported
        spelling wraps the right operand in a Symbol explicitly; that
        workaround is verified here too, so the test covers DSL behavior
        and not only CPython's.
        """
        with pytest.raises(TypeError):
            _ = 100 / "close"
        self._assert_reflected(100 / Symbol("close"), "/", 100)

    def test_int_floordiv_symbol(self):
        """int // Symbol."""
        self._assert_reflected(100 // close, "//", 100)

    def test_int_mod_symbol(self):
        """int % Symbol."""
        self._assert_reflected(100 % close, "%", 100)

    def test_int_pow_symbol(self):
        """int ** Symbol."""
        self._assert_reflected(2**close, "**", 2)
|
||
|
||
|
||
class TestOperatorFunctionsWithStrings:
    """Operator functions accept plain strings wherever a node is expected."""

    @staticmethod
    def _names(nodes):
        """Return the ``name`` attribute of each node, in order."""
        return [node.name for node in nodes]

    def test_ts_mean_with_str(self):
        """ts_mean promotes the series string and the numeric window."""
        node = ts_mean("close", 20)
        assert isinstance(node, FunctionNode)
        assert node.func_name == "ts_mean"
        assert len(node.args) == 2
        series, window = node.args
        assert isinstance(series, Symbol)
        assert series.name == "close"
        assert isinstance(window, Constant)
        assert window.value == 20

    def test_ts_std_with_str(self):
        """ts_std accepts a string series."""
        node = ts_std("volume", 10)
        assert isinstance(node, FunctionNode)
        assert node.func_name == "ts_std"
        assert node.args[0].name == "volume"

    def test_ts_corr_with_str(self):
        """ts_corr accepts two string series."""
        node = ts_corr("close", "open", 20)
        assert isinstance(node, FunctionNode)
        assert node.func_name == "ts_corr"
        assert self._names(node.args[:2]) == ["close", "open"]

    def test_cs_rank_with_str(self):
        """cs_rank accepts a string series."""
        node = cs_rank("pe_ratio")
        assert isinstance(node, FunctionNode)
        assert node.func_name == "cs_rank"
        assert node.args[0].name == "pe_ratio"

    def test_cs_zscore_with_str(self):
        """cs_zscore accepts a string series."""
        node = cs_zscore("market_cap")
        assert isinstance(node, FunctionNode)
        assert node.func_name == "cs_zscore"
        assert node.args[0].name == "market_cap"

    def test_log_with_str(self):
        """log accepts a string series."""
        node = log("close")
        assert isinstance(node, FunctionNode)
        assert node.func_name == "log"
        assert node.args[0].name == "close"

    def test_max_with_str(self):
        """max_ accepts two strings and is registered under the name "max"."""
        node = max_("close", "open")
        assert isinstance(node, FunctionNode)
        assert node.func_name == "max"
        assert self._names(node.args[:2]) == ["close", "open"]

    def test_max_with_str_and_number(self):
        """max_ accepts a string mixed with a number."""
        node = max_("close", 100)
        assert isinstance(node, FunctionNode)
        first, second = node.args[0], node.args[1]
        assert first.name == "close"
        assert second.value == 100

    def test_clip_with_str(self):
        """clip accepts strings for the value and both bounds."""
        node = clip("pe_ratio", "lower_bound", "upper_bound")
        assert isinstance(node, FunctionNode)
        assert node.func_name == "clip"
        assert self._names(node.args[:3]) == ["pe_ratio", "lower_bound", "upper_bound"]

    def test_if_with_str(self):
        """if_ accepts strings for the condition and both branches."""
        node = if_("condition", "true_val", "false_val")
        assert isinstance(node, FunctionNode)
        assert node.func_name == "if"
        assert self._names(node.args[:3]) == ["condition", "true_val", "false_val"]
|
||
|
||
|
||
class TestComplexExpressions:
    """Composite expressions mixing functions, Symbols, numbers and strings."""

    def test_complex_expression_1(self):
        """ts_mean("close", 5) / "pe_ratio"."""
        expr = ts_mean("close", 5) / "pe_ratio"
        assert isinstance(expr, BinaryOpNode)
        assert expr.op == "/"
        assert isinstance(expr.left, FunctionNode)
        assert expr.left.func_name == "ts_mean"
        assert isinstance(expr.right, Symbol)
        assert expr.right.name == "pe_ratio"

    def test_complex_expression_2(self):
        """100 / close * cs_rank("volume").

        A bare int cannot be divided by a bare str, so the existing Symbol
        ``close`` is used as the divisor.
        """
        expr = 100 / close * cs_rank("volume")
        assert isinstance(expr, BinaryOpNode)
        assert expr.op == "*"

        # ``/`` and ``*`` are left-associative, so ``100 / close`` is the left
        # operand of ``*``. The reflected division must keep 100 on the left.
        quotient = expr.left
        assert isinstance(quotient, BinaryOpNode)
        assert quotient.op == "/"
        assert isinstance(quotient.left, Constant)
        assert quotient.left.value == 100
        assert isinstance(quotient.right, Symbol)
        assert quotient.right.name == "close"

        assert isinstance(expr.right, FunctionNode)
        assert expr.right.func_name == "cs_rank"
        assert expr.right.args[0].name == "volume"

    def test_complex_expression_3(self):
        """ts_mean(close - "open", 20) / close."""
        expr = ts_mean(close - "open", 20) / close
        assert isinstance(expr, BinaryOpNode)
        assert expr.op == "/"
        assert isinstance(expr.left, FunctionNode)
        assert expr.left.func_name == "ts_mean"

        # The first ts_mean argument is the promoted expression close - open.
        spread, window = expr.left.args
        assert isinstance(spread, BinaryOpNode)
        assert spread.op == "-"
        assert isinstance(spread.left, Symbol)
        assert spread.left.name == "close"
        assert isinstance(spread.right, Symbol)
        assert spread.right.name == "open"
        assert isinstance(window, Constant)
        assert window.value == 20

        assert isinstance(expr.right, Symbol)
        assert expr.right.name == "close"
|
||
|
||
|
||
class TestExpressionRepr:
    """``repr`` renders expressions back into readable DSL text."""

    def test_symbol_str_repr(self):
        """A bare Symbol renders as its name."""
        assert repr(Symbol("close")) == "close"

    def test_binary_op_repr(self):
        """A binary operation is wrapped in parentheses with a spaced operator."""
        rendered = repr(close + "open")
        assert rendered == "(close + open)"

    def test_function_node_repr(self):
        """A function call renders as name(arg, arg)."""
        rendered = repr(ts_mean("close", 20))
        assert rendered == "ts_mean(close, 20)"

    def test_complex_expr_repr(self):
        """Nested nodes render recursively."""
        rendered = repr(ts_mean("close", 5) / "pe_ratio")
        assert rendered == "(ts_mean(close, 5) / pe_ratio)"
|
||
|
||
|
||
if __name__ == "__main__":
    # Run this module's tests directly. pytest.main returns the session exit
    # code. Without propagating it, `python <this file>` exits with status 0
    # even when tests fail, which hides failures from shells and CI scripts.
    raise SystemExit(pytest.main([__file__, "-v"]))
|