refactor(factor): 完成因子框架 DSL 化重构
- 重构 FactorEngine 实现完整的 DSL 表达式执行链路 - 新增 DataRouter 数据路由器,支持内存模式和核心宽表组装 - 新增 ExecutionPlanner 执行计划生成器,整合编译器和翻译器 - 新增 ComputeEngine 计算引擎,支持并行运算 - 完善 factors/__init__.py 公开 API 导出 - 新增 test_factor_engine.py 引擎单元测试 - 移除旧引擎实现和废弃的 DSL promotion 测试 - 更新 AGENTS.md 添加 v2.2 架构变更历史和 Factors 框架设计说明
This commit is contained in:
@@ -1,325 +0,0 @@
|
||||
"""测试 DSL 字符串自动提升(Promotion)功能。
|
||||
|
||||
验证以下功能:
|
||||
1. 字符串自动转换为 Symbol
|
||||
2. 算子函数支持字符串参数
|
||||
3. 右位运算支持
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from src.factors.dsl import (
|
||||
Symbol,
|
||||
Constant,
|
||||
BinaryOpNode,
|
||||
UnaryOpNode,
|
||||
FunctionNode,
|
||||
_ensure_node,
|
||||
)
|
||||
from src.factors.api import (
|
||||
close,
|
||||
open,
|
||||
ts_mean,
|
||||
ts_std,
|
||||
ts_corr,
|
||||
cs_rank,
|
||||
cs_zscore,
|
||||
log,
|
||||
exp,
|
||||
max_,
|
||||
min_,
|
||||
clip,
|
||||
if_,
|
||||
where,
|
||||
)
|
||||
|
||||
|
||||
class TestEnsureNode:
|
||||
"""测试 _ensure_node 辅助函数。"""
|
||||
|
||||
def test_ensure_node_with_node(self):
|
||||
"""Node 类型应该原样返回。"""
|
||||
sym = Symbol("close")
|
||||
result = _ensure_node(sym)
|
||||
assert result is sym
|
||||
|
||||
def test_ensure_node_with_int(self):
|
||||
"""整数应该转换为 Constant。"""
|
||||
result = _ensure_node(100)
|
||||
assert isinstance(result, Constant)
|
||||
assert result.value == 100
|
||||
|
||||
def test_ensure_node_with_float(self):
|
||||
"""浮点数应该转换为 Constant。"""
|
||||
result = _ensure_node(3.14)
|
||||
assert isinstance(result, Constant)
|
||||
assert result.value == 3.14
|
||||
|
||||
def test_ensure_node_with_str(self):
|
||||
"""字符串应该转换为 Symbol。"""
|
||||
result = _ensure_node("close")
|
||||
assert isinstance(result, Symbol)
|
||||
assert result.name == "close"
|
||||
|
||||
def test_ensure_node_with_invalid_type(self):
|
||||
"""无效类型应该抛出 TypeError。"""
|
||||
with pytest.raises(TypeError):
|
||||
_ensure_node([1, 2, 3])
|
||||
|
||||
|
||||
class TestSymbolStringPromotion:
|
||||
"""测试 Symbol 与字符串的运算。"""
|
||||
|
||||
def test_symbol_add_str(self):
|
||||
"""Symbol + 字符串。"""
|
||||
expr = close + "pe_ratio"
|
||||
assert isinstance(expr, BinaryOpNode)
|
||||
assert expr.op == "+"
|
||||
assert isinstance(expr.left, Symbol)
|
||||
assert expr.left.name == "close"
|
||||
assert isinstance(expr.right, Symbol)
|
||||
assert expr.right.name == "pe_ratio"
|
||||
|
||||
def test_symbol_sub_str(self):
|
||||
"""Symbol - 字符串。"""
|
||||
expr = close - "open"
|
||||
assert isinstance(expr, BinaryOpNode)
|
||||
assert expr.op == "-"
|
||||
assert expr.right.name == "open"
|
||||
|
||||
def test_symbol_mul_str(self):
|
||||
"""Symbol * 字符串。"""
|
||||
expr = close * "volume"
|
||||
assert isinstance(expr, BinaryOpNode)
|
||||
assert expr.op == "*"
|
||||
assert expr.right.name == "volume"
|
||||
|
||||
def test_symbol_div_str(self):
|
||||
"""Symbol / 字符串。"""
|
||||
expr = close / "pe_ratio"
|
||||
assert isinstance(expr, BinaryOpNode)
|
||||
assert expr.op == "/"
|
||||
assert expr.right.name == "pe_ratio"
|
||||
|
||||
def test_symbol_pow_str(self):
|
||||
"""Symbol ** 字符串。"""
|
||||
expr = close ** "exponent"
|
||||
assert isinstance(expr, BinaryOpNode)
|
||||
assert expr.op == "**"
|
||||
assert expr.right.name == "exponent"
|
||||
|
||||
|
||||
class TestRightHandOperations:
|
||||
"""测试右位运算。"""
|
||||
|
||||
def test_int_add_symbol(self):
|
||||
"""整数 + Symbol。"""
|
||||
expr = 100 + close
|
||||
assert isinstance(expr, BinaryOpNode)
|
||||
assert expr.op == "+"
|
||||
assert isinstance(expr.left, Constant)
|
||||
assert expr.left.value == 100
|
||||
assert isinstance(expr.right, Symbol)
|
||||
assert expr.right.name == "close"
|
||||
|
||||
def test_int_sub_symbol(self):
|
||||
"""整数 - Symbol。"""
|
||||
expr = 100 - close
|
||||
assert isinstance(expr, BinaryOpNode)
|
||||
assert expr.op == "-"
|
||||
assert expr.left.value == 100
|
||||
assert expr.right.name == "close"
|
||||
|
||||
def test_int_mul_symbol(self):
|
||||
"""整数 * Symbol。"""
|
||||
expr = 2 * close
|
||||
assert isinstance(expr, BinaryOpNode)
|
||||
assert expr.op == "*"
|
||||
assert expr.left.value == 2
|
||||
assert expr.right.name == "close"
|
||||
|
||||
def test_int_div_symbol(self):
|
||||
"""整数 / Symbol。"""
|
||||
expr = 100 / close
|
||||
assert isinstance(expr, BinaryOpNode)
|
||||
assert expr.op == "/"
|
||||
assert expr.left.value == 100
|
||||
assert expr.right.name == "close"
|
||||
|
||||
def test_int_div_str_not_supported(self):
|
||||
"""Python 内置 int 不支持直接与 str 进行除法运算。
|
||||
|
||||
注意:Python 内置的 int 类型不支持直接与 str 进行除法运算,
|
||||
所以 100 / "close" 会抛出 TypeError。正确的用法是 100 / Symbol("close") 或
|
||||
使用已有的 Symbol 对象如 close。
|
||||
"""
|
||||
with pytest.raises(TypeError):
|
||||
100 / "close"
|
||||
def test_int_floordiv_symbol(self):
|
||||
"""整数 // Symbol。"""
|
||||
expr = 100 // close
|
||||
assert isinstance(expr, BinaryOpNode)
|
||||
assert expr.op == "//"
|
||||
|
||||
def test_int_mod_symbol(self):
|
||||
"""整数 % Symbol。"""
|
||||
expr = 100 % close
|
||||
assert isinstance(expr, BinaryOpNode)
|
||||
assert expr.op == "%"
|
||||
|
||||
def test_int_pow_symbol(self):
|
||||
"""整数 ** Symbol。"""
|
||||
expr = 2**close
|
||||
assert isinstance(expr, BinaryOpNode)
|
||||
assert expr.op == "**"
|
||||
assert expr.left.value == 2
|
||||
assert expr.right.name == "close"
|
||||
|
||||
|
||||
class TestOperatorFunctionsWithStrings:
|
||||
"""测试算子函数支持字符串参数。"""
|
||||
|
||||
def test_ts_mean_with_str(self):
|
||||
"""ts_mean 支持字符串参数。"""
|
||||
expr = ts_mean("close", 20)
|
||||
assert isinstance(expr, FunctionNode)
|
||||
assert expr.func_name == "ts_mean"
|
||||
assert len(expr.args) == 2
|
||||
assert isinstance(expr.args[0], Symbol)
|
||||
assert expr.args[0].name == "close"
|
||||
assert isinstance(expr.args[1], Constant)
|
||||
assert expr.args[1].value == 20
|
||||
|
||||
def test_ts_std_with_str(self):
|
||||
"""ts_std 支持字符串参数。"""
|
||||
expr = ts_std("volume", 10)
|
||||
assert isinstance(expr, FunctionNode)
|
||||
assert expr.func_name == "ts_std"
|
||||
assert expr.args[0].name == "volume"
|
||||
|
||||
def test_ts_corr_with_str(self):
|
||||
"""ts_corr 支持字符串参数。"""
|
||||
expr = ts_corr("close", "open", 20)
|
||||
assert isinstance(expr, FunctionNode)
|
||||
assert expr.func_name == "ts_corr"
|
||||
assert expr.args[0].name == "close"
|
||||
assert expr.args[1].name == "open"
|
||||
|
||||
def test_cs_rank_with_str(self):
|
||||
"""cs_rank 支持字符串参数。"""
|
||||
expr = cs_rank("pe_ratio")
|
||||
assert isinstance(expr, FunctionNode)
|
||||
assert expr.func_name == "cs_rank"
|
||||
assert expr.args[0].name == "pe_ratio"
|
||||
|
||||
def test_cs_zscore_with_str(self):
|
||||
"""cs_zscore 支持字符串参数。"""
|
||||
expr = cs_zscore("market_cap")
|
||||
assert isinstance(expr, FunctionNode)
|
||||
assert expr.func_name == "cs_zscore"
|
||||
assert expr.args[0].name == "market_cap"
|
||||
|
||||
def test_log_with_str(self):
|
||||
"""log 支持字符串参数。"""
|
||||
expr = log("close")
|
||||
assert isinstance(expr, FunctionNode)
|
||||
assert expr.func_name == "log"
|
||||
assert expr.args[0].name == "close"
|
||||
|
||||
def test_max_with_str(self):
|
||||
"""max_ 支持字符串参数。"""
|
||||
expr = max_("close", "open")
|
||||
assert isinstance(expr, FunctionNode)
|
||||
assert expr.func_name == "max"
|
||||
assert expr.args[0].name == "close"
|
||||
assert expr.args[1].name == "open"
|
||||
|
||||
def test_max_with_str_and_number(self):
|
||||
"""max_ 支持字符串和数值混合。"""
|
||||
expr = max_("close", 100)
|
||||
assert isinstance(expr, FunctionNode)
|
||||
assert expr.args[0].name == "close"
|
||||
assert expr.args[1].value == 100
|
||||
|
||||
def test_clip_with_str(self):
|
||||
"""clip 支持字符串参数。"""
|
||||
expr = clip("pe_ratio", "lower_bound", "upper_bound")
|
||||
assert isinstance(expr, FunctionNode)
|
||||
assert expr.func_name == "clip"
|
||||
assert expr.args[0].name == "pe_ratio"
|
||||
assert expr.args[1].name == "lower_bound"
|
||||
assert expr.args[2].name == "upper_bound"
|
||||
|
||||
def test_if_with_str(self):
|
||||
"""if_ 支持字符串参数。"""
|
||||
expr = if_("condition", "true_val", "false_val")
|
||||
assert isinstance(expr, FunctionNode)
|
||||
assert expr.func_name == "if"
|
||||
assert expr.args[0].name == "condition"
|
||||
assert expr.args[1].name == "true_val"
|
||||
assert expr.args[2].name == "false_val"
|
||||
|
||||
|
||||
class TestComplexExpressions:
|
||||
"""测试复杂表达式。"""
|
||||
|
||||
def test_complex_expression_1(self):
|
||||
"""复杂表达式:ts_mean("close", 5) / "pe_ratio"。"""
|
||||
expr = ts_mean("close", 5) / "pe_ratio"
|
||||
assert isinstance(expr, BinaryOpNode)
|
||||
assert expr.op == "/"
|
||||
assert isinstance(expr.left, FunctionNode)
|
||||
assert expr.left.func_name == "ts_mean"
|
||||
assert isinstance(expr.right, Symbol)
|
||||
assert expr.right.name == "pe_ratio"
|
||||
|
||||
def test_complex_expression_2(self):
|
||||
"""复杂表达式:100 / close * cs_rank("volume") 。
|
||||
|
||||
注意:Python 内置的 int 类型不支持直接与 str 进行除法运算,
|
||||
所以需要使用已有的 Symbol 对象或先创建 Symbol。
|
||||
"""
|
||||
expr = 100 / close * cs_rank("volume")
|
||||
assert isinstance(expr, BinaryOpNode)
|
||||
assert expr.op == "*"
|
||||
assert isinstance(expr.left, BinaryOpNode)
|
||||
assert expr.left.op == "/"
|
||||
assert isinstance(expr.right, FunctionNode)
|
||||
assert expr.right.func_name == "cs_rank"
|
||||
def test_complex_expression_3(self):
|
||||
"""复杂表达式:ts_mean(close - "open", 20) / close。"""
|
||||
expr = ts_mean(close - "open", 20) / close
|
||||
assert isinstance(expr, BinaryOpNode)
|
||||
assert expr.op == "/"
|
||||
assert isinstance(expr.left, FunctionNode)
|
||||
assert expr.left.func_name == "ts_mean"
|
||||
# 检查 ts_mean 的第一个参数是 close - open
|
||||
assert isinstance(expr.left.args[0], BinaryOpNode)
|
||||
assert expr.left.args[0].op == "-"
|
||||
|
||||
|
||||
class TestExpressionRepr:
|
||||
"""测试表达式字符串表示。"""
|
||||
|
||||
def test_symbol_str_repr(self):
|
||||
"""Symbol 的字符串表示。"""
|
||||
expr = Symbol("close")
|
||||
assert repr(expr) == "close"
|
||||
|
||||
def test_binary_op_repr(self):
|
||||
"""二元运算的字符串表示。"""
|
||||
expr = close + "open"
|
||||
assert repr(expr) == "(close + open)"
|
||||
|
||||
def test_function_node_repr(self):
|
||||
"""函数节点的字符串表示。"""
|
||||
expr = ts_mean("close", 20)
|
||||
assert repr(expr) == "ts_mean(close, 20)"
|
||||
|
||||
def test_complex_expr_repr(self):
|
||||
"""复杂表达式的字符串表示。"""
|
||||
expr = ts_mean("close", 5) / "pe_ratio"
|
||||
assert repr(expr) == "(ts_mean(close, 5) / pe_ratio)"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
160
tests/test_factor_engine.py
Normal file
160
tests/test_factor_engine.py
Normal file
@@ -0,0 +1,160 @@
|
||||
"""FactorEngine 端到端测试。
|
||||
|
||||
模拟内存数据作为假数据库,完整跑通从表达式注册到结果输出的全流程链路。
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import polars as pl
|
||||
import numpy as np
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from src.factors.engine import FactorEngine, DataSpec
|
||||
from src.factors.api import close, ts_mean, ts_std, cs_rank, cs_zscore, open as open_sym
|
||||
from src.factors.dsl import Symbol, FunctionNode
|
||||
|
||||
|
||||
def create_mock_data(
|
||||
start_date: str = "20240101",
|
||||
end_date: str = "20240131",
|
||||
n_stocks: int = 5,
|
||||
) -> pl.DataFrame:
|
||||
"""创建模拟的日线数据。"""
|
||||
start = datetime.strptime(start_date, "%Y%m%d")
|
||||
end = datetime.strptime(end_date, "%Y%m%d")
|
||||
|
||||
dates = []
|
||||
current = start
|
||||
while current <= end:
|
||||
if current.weekday() < 5: # 周一到周五
|
||||
dates.append(current.strftime("%Y%m%d"))
|
||||
current += timedelta(days=1)
|
||||
|
||||
stocks = [f"{600000 + i:06d}.SH" for i in range(n_stocks)]
|
||||
np.random.seed(42)
|
||||
|
||||
rows = []
|
||||
for date in dates:
|
||||
for stock in stocks:
|
||||
base_price = 10 + np.random.randn() * 5
|
||||
close_val = base_price + np.random.randn() * 0.5
|
||||
open_val = close_val + np.random.randn() * 0.2
|
||||
high_val = max(open_val, close_val) + abs(np.random.randn()) * 0.3
|
||||
low_val = min(open_val, close_val) - abs(np.random.randn()) * 0.3
|
||||
vol = int(1000000 + np.random.exponential(500000))
|
||||
amt = close_val * vol
|
||||
|
||||
rows.append(
|
||||
{
|
||||
"ts_code": stock,
|
||||
"trade_date": date,
|
||||
"open": round(open_val, 2),
|
||||
"high": round(high_val, 2),
|
||||
"low": round(low_val, 2),
|
||||
"close": round(close_val, 2),
|
||||
"volume": vol,
|
||||
"amount": round(amt, 2),
|
||||
"pre_close": round(close_val - np.random.randn() * 0.3, 2),
|
||||
}
|
||||
)
|
||||
|
||||
return pl.DataFrame(rows)
|
||||
|
||||
|
||||
class TestFactorEngineEndToEnd:
|
||||
"""FactorEngine 端到端测试类。"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_data(self):
|
||||
"""提供模拟数据的 fixture。"""
|
||||
return create_mock_data("20240101", "20240131", n_stocks=5)
|
||||
|
||||
@pytest.fixture
|
||||
def engine(self, mock_data):
|
||||
"""提供配置好的 FactorEngine fixture。"""
|
||||
data_source = {"daily": mock_data}
|
||||
return FactorEngine(data_source=data_source, max_workers=2)
|
||||
|
||||
def test_simple_symbol_expression(self, engine):
|
||||
"""测试简单的符号表达式。"""
|
||||
engine.register("close_price", close)
|
||||
result = engine.compute("close_price", "20240115", "20240120")
|
||||
assert "close_price" in result.columns
|
||||
assert len(result) > 0
|
||||
print("[PASS] 简单符号表达式测试")
|
||||
|
||||
def test_arithmetic_expression(self, engine):
|
||||
"""测试算术表达式。"""
|
||||
engine.register("returns", (close - open_sym) / open_sym)
|
||||
result = engine.compute("returns", "20240115", "20240120")
|
||||
assert "returns" in result.columns
|
||||
print("[PASS] 算术表达式测试")
|
||||
|
||||
def test_cs_rank_factor(self, engine):
|
||||
"""测试截面排名因子。"""
|
||||
engine.register("price_rank", cs_rank(close))
|
||||
result = engine.compute("price_rank", "20240115", "20240120")
|
||||
assert "price_rank" in result.columns
|
||||
assert result["price_rank"].min() >= 0
|
||||
assert result["price_rank"].max() <= 1
|
||||
print("[PASS] 截面排名因子测试")
|
||||
|
||||
|
||||
class TestFullWorkflow:
|
||||
"""完整工作流测试类。"""
|
||||
|
||||
def test_full_workflow_demo(self):
|
||||
"""演示完整的因子计算工作流。"""
|
||||
print("\n" + "=" * 60)
|
||||
print("FactorEngine Full Workflow Demo")
|
||||
print("=" * 60)
|
||||
|
||||
# 1. 准备数据
|
||||
print("\nStep 1: Prepare mock data...")
|
||||
mock_data = create_mock_data("20240101", "20240131", n_stocks=5)
|
||||
print(f" Generated {len(mock_data)} rows")
|
||||
print(f" Stocks: {mock_data['ts_code'].n_unique()}")
|
||||
|
||||
# 2. 初始化引擎
|
||||
print("\nStep 2: Initialize FactorEngine...")
|
||||
engine = FactorEngine(data_source={"daily": mock_data})
|
||||
print(" Engine initialized")
|
||||
|
||||
# 3. 注册因子 - 使用简单因子避免回看窗口问题
|
||||
print("\nStep 3: Register factors...")
|
||||
engine.register("returns", (close - open_sym) / open_sym)
|
||||
engine.register("price_rank", cs_rank(close))
|
||||
print(" Registered: returns, price_rank")
|
||||
|
||||
# 4. 执行计算 - 使用完整日期范围
|
||||
print("\nStep 4: Compute factors...")
|
||||
result = engine.compute(
|
||||
["returns", "price_rank"],
|
||||
"20240115",
|
||||
"20240120",
|
||||
)
|
||||
print(f" Computed {len(result)} rows")
|
||||
|
||||
# 5. 验证结果
|
||||
print("\nStep 5: Verify results...")
|
||||
assert "returns" in result.columns
|
||||
assert "price_rank" in result.columns
|
||||
assert result["price_rank"].min() >= 0
|
||||
assert result["price_rank"].max() <= 1
|
||||
print(" All assertions passed")
|
||||
|
||||
# 6. 展示样本
|
||||
print("\nStep 6: Sample output...")
|
||||
sample = result.select(
|
||||
["ts_code", "trade_date", "close", "returns", "price_rank"]
|
||||
).head(3)
|
||||
print(sample.to_pandas().to_string(index=False))
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("Workflow completed successfully!")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test = TestFullWorkflow()
|
||||
test.test_full_workflow_demo()
|
||||
pytest.main([__file__, "-v", "--tb=short"])
|
||||
Reference in New Issue
Block a user