Files
ProStock/tests/test_factor_engine.py
liaozhaorun b461a4940d refactor(factor): 完成因子框架 DSL 化重构
- 重构 FactorEngine 实现完整的 DSL 表达式执行链路
- 新增 DataRouter 数据路由器,支持内存模式和核心宽表组装
- 新增 ExecutionPlanner 执行计划生成器,整合编译器和翻译器
- 新增 ComputeEngine 计算引擎,支持并行运算
- 完善 factors/__init__.py 公开 API 导出
- 新增 test_factor_engine.py 引擎单元测试
- 移除旧引擎实现和废弃的 DSL promotion 测试
- 更新 AGENTS.md 添加 v2.2 架构变更历史和 Factors 框架设计说明
2026-03-01 15:03:56 +08:00

161 lines
5.5 KiB
Python

"""FactorEngine 端到端测试。
模拟内存数据作为假数据库,完整跑通从表达式注册到结果输出的全流程链路。
"""
import pytest
import polars as pl
import numpy as np
from datetime import datetime, timedelta
from src.factors.engine import FactorEngine, DataSpec
from src.factors.api import close, ts_mean, ts_std, cs_rank, cs_zscore, open as open_sym
from src.factors.dsl import Symbol, FunctionNode
def create_mock_data(
start_date: str = "20240101",
end_date: str = "20240131",
n_stocks: int = 5,
) -> pl.DataFrame:
"""创建模拟的日线数据。"""
start = datetime.strptime(start_date, "%Y%m%d")
end = datetime.strptime(end_date, "%Y%m%d")
dates = []
current = start
while current <= end:
if current.weekday() < 5: # 周一到周五
dates.append(current.strftime("%Y%m%d"))
current += timedelta(days=1)
stocks = [f"{600000 + i:06d}.SH" for i in range(n_stocks)]
np.random.seed(42)
rows = []
for date in dates:
for stock in stocks:
base_price = 10 + np.random.randn() * 5
close_val = base_price + np.random.randn() * 0.5
open_val = close_val + np.random.randn() * 0.2
high_val = max(open_val, close_val) + abs(np.random.randn()) * 0.3
low_val = min(open_val, close_val) - abs(np.random.randn()) * 0.3
vol = int(1000000 + np.random.exponential(500000))
amt = close_val * vol
rows.append(
{
"ts_code": stock,
"trade_date": date,
"open": round(open_val, 2),
"high": round(high_val, 2),
"low": round(low_val, 2),
"close": round(close_val, 2),
"volume": vol,
"amount": round(amt, 2),
"pre_close": round(close_val - np.random.randn() * 0.3, 2),
}
)
return pl.DataFrame(rows)
class TestFactorEngineEndToEnd:
"""FactorEngine 端到端测试类。"""
@pytest.fixture
def mock_data(self):
"""提供模拟数据的 fixture。"""
return create_mock_data("20240101", "20240131", n_stocks=5)
@pytest.fixture
def engine(self, mock_data):
"""提供配置好的 FactorEngine fixture。"""
data_source = {"daily": mock_data}
return FactorEngine(data_source=data_source, max_workers=2)
def test_simple_symbol_expression(self, engine):
"""测试简单的符号表达式。"""
engine.register("close_price", close)
result = engine.compute("close_price", "20240115", "20240120")
assert "close_price" in result.columns
assert len(result) > 0
print("[PASS] 简单符号表达式测试")
def test_arithmetic_expression(self, engine):
"""测试算术表达式。"""
engine.register("returns", (close - open_sym) / open_sym)
result = engine.compute("returns", "20240115", "20240120")
assert "returns" in result.columns
print("[PASS] 算术表达式测试")
def test_cs_rank_factor(self, engine):
"""测试截面排名因子。"""
engine.register("price_rank", cs_rank(close))
result = engine.compute("price_rank", "20240115", "20240120")
assert "price_rank" in result.columns
assert result["price_rank"].min() >= 0
assert result["price_rank"].max() <= 1
print("[PASS] 截面排名因子测试")
class TestFullWorkflow:
"""完整工作流测试类。"""
def test_full_workflow_demo(self):
"""演示完整的因子计算工作流。"""
print("\n" + "=" * 60)
print("FactorEngine Full Workflow Demo")
print("=" * 60)
# 1. 准备数据
print("\nStep 1: Prepare mock data...")
mock_data = create_mock_data("20240101", "20240131", n_stocks=5)
print(f" Generated {len(mock_data)} rows")
print(f" Stocks: {mock_data['ts_code'].n_unique()}")
# 2. 初始化引擎
print("\nStep 2: Initialize FactorEngine...")
engine = FactorEngine(data_source={"daily": mock_data})
print(" Engine initialized")
# 3. 注册因子 - 使用简单因子避免回看窗口问题
print("\nStep 3: Register factors...")
engine.register("returns", (close - open_sym) / open_sym)
engine.register("price_rank", cs_rank(close))
print(" Registered: returns, price_rank")
# 4. 执行计算 - 使用完整日期范围
print("\nStep 4: Compute factors...")
result = engine.compute(
["returns", "price_rank"],
"20240115",
"20240120",
)
print(f" Computed {len(result)} rows")
# 5. 验证结果
print("\nStep 5: Verify results...")
assert "returns" in result.columns
assert "price_rank" in result.columns
assert result["price_rank"].min() >= 0
assert result["price_rank"].max() <= 1
print(" All assertions passed")
# 6. 展示样本
print("\nStep 6: Sample output...")
sample = result.select(
["ts_code", "trade_date", "close", "returns", "price_rank"]
).head(3)
print(sample.to_pandas().to_string(index=False))
print("\n" + "=" * 60)
print("Workflow completed successfully!")
print("=" * 60)
if __name__ == "__main__":
test = TestFullWorkflow()
test.test_full_workflow_demo()
pytest.main([__file__, "-v", "--tb=short"])