"""Tests for automatic string promotion in the factor DSL.

Covers:
1. Strings are automatically promoted to ``Symbol`` nodes.
2. Operator functions accept string arguments.
3. Reflected (right-hand) arithmetic operators, e.g. ``100 + close``.
"""

import pytest

from src.factors.dsl import (
    Symbol,
    Constant,
    BinaryOpNode,
    UnaryOpNode,
    FunctionNode,
    _ensure_node,
)
from src.factors.api import (
    close,
    # Aliased so this module does not shadow the builtin ``open``.
    open as open_,
    ts_mean,
    ts_std,
    ts_corr,
    cs_rank,
    cs_zscore,
    log,
    exp,
    max_,
    min_,
    clip,
    if_,
    where,
)

# NOTE: UnaryOpNode, open_, exp, min_ and where are imported but not
# exercised by any test below.


class TestEnsureNode:
    """Tests for the ``_ensure_node`` helper."""

    def test_ensure_node_with_node(self):
        """A Node instance is returned unchanged (the very same object)."""
        sym = Symbol("close")
        result = _ensure_node(sym)
        assert result is sym

    def test_ensure_node_with_int(self):
        """An int is wrapped in a Constant."""
        result = _ensure_node(100)
        assert isinstance(result, Constant)
        assert result.value == 100

    def test_ensure_node_with_float(self):
        """A float is wrapped in a Constant."""
        result = _ensure_node(3.14)
        assert isinstance(result, Constant)
        assert result.value == 3.14

    def test_ensure_node_with_str(self):
        """A str is promoted to a Symbol carrying that name."""
        result = _ensure_node("close")
        assert isinstance(result, Symbol)
        assert result.name == "close"

    def test_ensure_node_with_invalid_type(self):
        """Unsupported input types raise TypeError."""
        with pytest.raises(TypeError):
            _ensure_node([1, 2, 3])


class TestSymbolStringPromotion:
    """Tests for arithmetic between a Symbol and a plain string."""

    def test_symbol_add_str(self):
        """Symbol + str."""
        expr = close + "pe_ratio"
        assert isinstance(expr, BinaryOpNode)
        assert expr.op == "+"
        assert isinstance(expr.left, Symbol)
        assert expr.left.name == "close"
        assert isinstance(expr.right, Symbol)
        assert expr.right.name == "pe_ratio"

    def test_symbol_sub_str(self):
        """Symbol - str."""
        expr = close - "open"
        assert isinstance(expr, BinaryOpNode)
        assert expr.op == "-"
        assert expr.right.name == "open"

    def test_symbol_mul_str(self):
        """Symbol * str."""
        expr = close * "volume"
        assert isinstance(expr, BinaryOpNode)
        assert expr.op == "*"
        assert expr.right.name == "volume"

    def test_symbol_div_str(self):
        """Symbol / str."""
        expr = close / "pe_ratio"
        assert isinstance(expr, BinaryOpNode)
        assert expr.op == "/"
        assert expr.right.name == "pe_ratio"

    def test_symbol_pow_str(self):
        """Symbol ** str."""
        expr = close ** "exponent"
        assert isinstance(expr, BinaryOpNode)
        assert expr.op == "**"
        assert expr.left.name == "close"
        assert isinstance(expr.right, Symbol)
        assert expr.right.name == "exponent"


class TestRightHandOperations:
    """Tests for reflected operators (number on the left, Symbol on the right)."""

    def test_int_add_symbol(self):
        """int + Symbol."""
        expr = 100 + close
        assert isinstance(expr, BinaryOpNode)
        assert expr.op == "+"
        assert isinstance(expr.left, Constant)
        assert expr.left.value == 100
        assert isinstance(expr.right, Symbol)
        assert expr.right.name == "close"

    def test_int_sub_symbol(self):
        """int - Symbol."""
        expr = 100 - close
        assert isinstance(expr, BinaryOpNode)
        assert expr.op == "-"
        assert expr.left.value == 100
        assert expr.right.name == "close"

    def test_int_mul_symbol(self):
        """int * Symbol."""
        expr = 2 * close
        assert isinstance(expr, BinaryOpNode)
        assert expr.op == "*"
        assert expr.left.value == 2
        assert expr.right.name == "close"

    def test_int_div_symbol(self):
        """int / Symbol."""
        expr = 100 / close
        assert isinstance(expr, BinaryOpNode)
        assert expr.op == "/"
        assert expr.left.value == 100
        assert expr.right.name == "close"

    def test_int_div_str_not_supported(self):
        """Plain Python int / str is not supported.

        Neither operand is a DSL node, so Python's own ``int / str`` raises
        TypeError before the DSL is ever involved. Use ``100 / Symbol("close")``
        or an existing Symbol such as ``close`` instead.
        """
        with pytest.raises(TypeError):
            100 / "close"

    def test_int_floordiv_symbol(self):
        """int // Symbol keeps the int on the left and the Symbol on the right."""
        expr = 100 // close
        assert isinstance(expr, BinaryOpNode)
        assert expr.op == "//"
        assert isinstance(expr.left, Constant)
        assert expr.left.value == 100
        assert expr.right.name == "close"

    def test_int_mod_symbol(self):
        """int % Symbol keeps the int on the left and the Symbol on the right."""
        expr = 100 % close
        assert isinstance(expr, BinaryOpNode)
        assert expr.op == "%"
        assert isinstance(expr.left, Constant)
        assert expr.left.value == 100
        assert expr.right.name == "close"

    def test_int_pow_symbol(self):
        """int ** Symbol."""
        expr = 2**close
        assert isinstance(expr, BinaryOpNode)
        assert expr.op == "**"
        assert expr.left.value == 2
        assert expr.right.name == "close"


class TestOperatorFunctionsWithStrings:
    """Tests that operator functions accept string arguments."""

    def test_ts_mean_with_str(self):
        """ts_mean accepts a string field name."""
        expr = ts_mean("close", 20)
        assert isinstance(expr, FunctionNode)
        assert expr.func_name == "ts_mean"
        assert len(expr.args) == 2
        assert isinstance(expr.args[0], Symbol)
        assert expr.args[0].name == "close"
        assert isinstance(expr.args[1], Constant)
        assert expr.args[1].value == 20

    def test_ts_std_with_str(self):
        """ts_std accepts a string field name."""
        expr = ts_std("volume", 10)
        assert isinstance(expr, FunctionNode)
        assert expr.func_name == "ts_std"
        assert expr.args[0].name == "volume"

    def test_ts_corr_with_str(self):
        """ts_corr accepts two string field names."""
        expr = ts_corr("close", "open", 20)
        assert isinstance(expr, FunctionNode)
        assert expr.func_name == "ts_corr"
        assert expr.args[0].name == "close"
        assert expr.args[1].name == "open"

    def test_cs_rank_with_str(self):
        """cs_rank accepts a string field name."""
        expr = cs_rank("pe_ratio")
        assert isinstance(expr, FunctionNode)
        assert expr.func_name == "cs_rank"
        assert expr.args[0].name == "pe_ratio"

    def test_cs_zscore_with_str(self):
        """cs_zscore accepts a string field name."""
        expr = cs_zscore("market_cap")
        assert isinstance(expr, FunctionNode)
        assert expr.func_name == "cs_zscore"
        assert expr.args[0].name == "market_cap"

    def test_log_with_str(self):
        """log accepts a string field name."""
        expr = log("close")
        assert isinstance(expr, FunctionNode)
        assert expr.func_name == "log"
        assert expr.args[0].name == "close"

    def test_max_with_str(self):
        """max_ accepts two string field names."""
        expr = max_("close", "open")
        assert isinstance(expr, FunctionNode)
        assert expr.func_name == "max"
        assert expr.args[0].name == "close"
        assert expr.args[1].name == "open"

    def test_max_with_str_and_number(self):
        """max_ accepts a mix of string and numeric arguments."""
        expr = max_("close", 100)
        assert isinstance(expr, FunctionNode)
        assert expr.args[0].name == "close"
        assert expr.args[1].value == 100

    def test_clip_with_str(self):
        """clip accepts string arguments for value and both bounds."""
        expr = clip("pe_ratio", "lower_bound", "upper_bound")
        assert isinstance(expr, FunctionNode)
        assert expr.func_name == "clip"
        assert expr.args[0].name == "pe_ratio"
        assert expr.args[1].name == "lower_bound"
        assert expr.args[2].name == "upper_bound"

    def test_if_with_str(self):
        """if_ accepts string arguments for condition and both branches."""
        expr = if_("condition", "true_val", "false_val")
        assert isinstance(expr, FunctionNode)
        assert expr.func_name == "if"
        assert expr.args[0].name == "condition"
        assert expr.args[1].name == "true_val"
        assert expr.args[2].name == "false_val"


class TestComplexExpressions:
    """Tests for nested expressions mixing functions, Symbols and strings."""

    def test_complex_expression_1(self):
        """Complex expression: ``ts_mean("close", 5) / "pe_ratio"``."""
        expr = ts_mean("close", 5) / "pe_ratio"
        assert isinstance(expr, BinaryOpNode)
        assert expr.op == "/"
        assert isinstance(expr.left, FunctionNode)
        assert expr.left.func_name == "ts_mean"
        assert isinstance(expr.right, Symbol)
        assert expr.right.name == "pe_ratio"

    def test_complex_expression_2(self):
        """Complex expression: ``100 / close * cs_rank("volume")``.

        A plain int cannot be divided by a str, so the divisor must already
        be a DSL node (here the existing ``close`` Symbol).
        """
        expr = 100 / close * cs_rank("volume")
        assert isinstance(expr, BinaryOpNode)
        assert expr.op == "*"
        # Left operand is the reflected division 100 / close.
        assert isinstance(expr.left, BinaryOpNode)
        assert expr.left.op == "/"
        assert expr.left.left.value == 100
        assert expr.left.right.name == "close"
        # Right operand is cs_rank applied to the promoted "volume" Symbol.
        assert isinstance(expr.right, FunctionNode)
        assert expr.right.func_name == "cs_rank"
        assert expr.right.args[0].name == "volume"

    def test_complex_expression_3(self):
        """Complex expression: ``ts_mean(close - "open", 20) / close``."""
        expr = ts_mean(close - "open", 20) / close
        assert isinstance(expr, BinaryOpNode)
        assert expr.op == "/"
        assert isinstance(expr.left, FunctionNode)
        assert expr.left.func_name == "ts_mean"
        # The first ts_mean argument is the subtraction close - open.
        inner = expr.left.args[0]
        assert isinstance(inner, BinaryOpNode)
        assert inner.op == "-"
        assert inner.left.name == "close"
        assert isinstance(inner.right, Symbol)
        assert inner.right.name == "open"
        # The window length is wrapped as a Constant.
        assert expr.left.args[1].value == 20
        assert isinstance(expr.right, Symbol)
        assert expr.right.name == "close"


class TestExpressionRepr:
    """Tests for the string representation of expressions."""

    def test_symbol_str_repr(self):
        """repr of a Symbol is its bare name."""
        expr = Symbol("close")
        assert repr(expr) == "close"

    def test_binary_op_repr(self):
        """repr of a binary operation is parenthesised infix."""
        expr = close + "open"
        assert repr(expr) == "(close + open)"

    def test_function_node_repr(self):
        """repr of a function node is a call with comma-separated args."""
        expr = ts_mean("close", 20)
        assert repr(expr) == "ts_mean(close, 20)"

    def test_complex_expr_repr(self):
        """repr of a nested expression composes the rules above."""
        expr = ts_mean("close", 5) / "pe_ratio"
        assert repr(expr) == "(ts_mean(close, 5) / pe_ratio)"


if __name__ == "__main__":
    pytest.main([__file__, "-v"])