Files
ProStock/tests/test_factor_engine.py
liaozhaorun 181994f063 perf(factors/engine): 重构计算引擎使用 Polars 原生并行
- 移除 Python 多进程/多线程池,消除 DataFrame 序列化开销
- 采用 BFS 分层执行策略,每层表达式通过单次 with_columns 提交
- 利用 Polars Rust 引擎实现零拷贝并行计算
- 添加死锁检测机制处理依赖环
2026-03-14 01:24:52 +08:00

161 lines
5.5 KiB
Python

"""FactorEngine 端到端测试。
模拟内存数据作为假数据库,完整跑通从表达式注册到结果输出的全流程链路。
"""
import pytest
import polars as pl
import numpy as np
from datetime import datetime, timedelta
from src.factors.engine import FactorEngine, DataSpec
from src.factors.api import close, ts_mean, ts_std, cs_rank, cs_zscore, open as open_sym
from src.factors.dsl import Symbol, FunctionNode
def create_mock_data(
    start_date: str = "20240101",
    end_date: str = "20240131",
    n_stocks: int = 5,
    seed: int = 42,
) -> pl.DataFrame:
    """Create synthetic daily bar data that stands in for a database table.

    One row is generated for every (weekday, stock) pair in the inclusive
    ``[start_date, end_date]`` range. Saturdays and Sundays are skipped;
    exchange holidays are not modelled.

    Args:
        start_date: First calendar day, formatted ``YYYYMMDD``.
        end_date: Last calendar day (inclusive), formatted ``YYYYMMDD``.
        n_stocks: Number of synthetic tickers, named ``600000.SH``,
            ``600001.SH``, ...
        seed: Seed for a private random generator. The default (42)
            reproduces exactly the data this helper has always produced.

    Returns:
        A DataFrame with columns ``ts_code``, ``trade_date``, ``open``,
        ``high``, ``low``, ``close``, ``volume``, ``amount`` and
        ``pre_close``, ordered date-major then stock. Prices are rounded to
        2 decimals. If the range contains no weekdays, the frame is empty
        but still carries the full schema.
    """
    start = datetime.strptime(start_date, "%Y%m%d")
    end = datetime.strptime(end_date, "%Y%m%d")
    # range() of a negative count is empty, so start > end yields no dates.
    all_days = (start + timedelta(days=offset) for offset in range((end - start).days + 1))
    dates = [day.strftime("%Y%m%d") for day in all_days if day.weekday() < 5]  # Mon..Fri

    stocks = [f"{600000 + i:06d}.SH" for i in range(n_stocks)]

    # A private legacy RandomState produces the same stream as the global
    # np.random.seed(seed) did, without reseeding NumPy's global RNG (which
    # would leak state into any other test that relies on np.random).
    rng = np.random.RandomState(seed)

    rows = []
    for date in dates:
        for stock in stocks:
            # The draw order below must stay fixed to keep the data reproducible.
            base_price = 10 + rng.randn() * 5
            close_val = base_price + rng.randn() * 0.5
            open_val = close_val + rng.randn() * 0.2
            high_val = max(open_val, close_val) + abs(rng.randn()) * 0.3
            low_val = min(open_val, close_val) - abs(rng.randn()) * 0.3
            vol = int(1000000 + rng.exponential(500000))
            amt = close_val * vol
            pre_close_val = close_val - rng.randn() * 0.3
            rows.append(
                {
                    "ts_code": stock,
                    "trade_date": date,
                    "open": round(open_val, 2),
                    "high": round(high_val, 2),
                    "low": round(low_val, 2),
                    "close": round(close_val, 2),
                    "volume": vol,
                    "amount": round(amt, 2),
                    "pre_close": round(pre_close_val, 2),
                }
            )

    # An explicit schema keeps column names/dtypes stable even when `rows` is empty.
    schema = {
        "ts_code": pl.Utf8,
        "trade_date": pl.Utf8,
        "open": pl.Float64,
        "high": pl.Float64,
        "low": pl.Float64,
        "close": pl.Float64,
        "volume": pl.Int64,
        "amount": pl.Float64,
        "pre_close": pl.Float64,
    }
    return pl.DataFrame(rows, schema=schema)
class TestFactorEngineEndToEnd:
    """End-to-end tests: register an expression on FactorEngine and compute it."""

    @pytest.fixture
    def mock_data(self):
        """Mock daily bars for 5 stocks covering 2024-01-01 .. 2024-01-31."""
        return create_mock_data("20240101", "20240131", n_stocks=5)

    @pytest.fixture
    def engine(self, mock_data):
        """A FactorEngine whose only data source is the mock table under ``pro_bar``."""
        data_source = {"pro_bar": mock_data}
        return FactorEngine(data_source=data_source)

    def test_simple_symbol_expression(self, engine):
        """A bare symbol registered as a factor appears as a non-empty output column."""
        engine.register("close_price", close)
        result = engine.compute("close_price", "20240115", "20240120")
        assert "close_price" in result.columns
        assert len(result) > 0
        print("[PASS] 简单符号表达式测试")

    def test_arithmetic_expression(self, engine):
        """An arithmetic expression over two symbols, (close - open) / open, is computed."""
        engine.register("returns", (close - open_sym) / open_sym)
        result = engine.compute("returns", "20240115", "20240120")
        assert "returns" in result.columns
        # Same non-empty guard as the symbol test; returns need no lookback,
        # so the column should not be entirely null either.
        assert len(result) > 0
        assert result["returns"].null_count() < len(result)
        print("[PASS] 算术表达式测试")

    def test_cs_rank_factor(self, engine):
        """A cross-sectional rank factor stays within [0, 1]."""
        engine.register("price_rank", cs_rank(close))
        result = engine.compute("price_rank", "20240115", "20240120")
        assert "price_rank" in result.columns
        # min()/max() return None on an empty or all-null column, which would
        # surface as a confusing TypeError in the range checks below.
        assert len(result) > 0
        assert result["price_rank"].null_count() < len(result)
        assert result["price_rank"].min() >= 0
        assert result["price_rank"].max() <= 1
        print("[PASS] 截面排名因子测试")
class TestFullWorkflow:
    """Walk-through of the full factor workflow, printing each step."""

    def test_full_workflow_demo(self):
        """Demonstrate the whole pipeline: mock data -> engine -> register -> compute -> verify."""
        print("\n" + "=" * 60)
        print("FactorEngine Full Workflow Demo")
        print("=" * 60)

        # 1. Prepare data.
        print("\nStep 1: Prepare mock data...")
        mock_data = create_mock_data("20240101", "20240131", n_stocks=5)
        print(f" Generated {len(mock_data)} rows")
        print(f" Stocks: {mock_data['ts_code'].n_unique()}")

        # 2. Initialize the engine over the in-memory table.
        print("\nStep 2: Initialize FactorEngine...")
        engine = FactorEngine(data_source={"pro_bar": mock_data})
        print(" Engine initialized")

        # 3. Register factors. Only expressions without a lookback window are
        #    used, so the demo does not depend on lookback handling.
        print("\nStep 3: Register factors...")
        engine.register("returns", (close - open_sym) / open_sym)
        engine.register("price_rank", cs_rank(close))
        print(" Registered: returns, price_rank")

        # 4. Compute both factors over a sub-range (2024-01-15 .. 2024-01-20)
        #    of the mock data, not the full month.
        print("\nStep 4: Compute factors...")
        result = engine.compute(
            ["returns", "price_rank"],
            "20240115",
            "20240120",
        )
        print(f" Computed {len(result)} rows")

        # 5. Verify results. Guard against an empty/all-null rank column first:
        #    min()/max() would return None and the range checks would raise TypeError.
        print("\nStep 5: Verify results...")
        assert "returns" in result.columns
        assert "price_rank" in result.columns
        assert len(result) > 0
        assert result["price_rank"].null_count() < len(result)
        assert result["price_rank"].min() >= 0
        assert result["price_rank"].max() <= 1
        print(" All assertions passed")

        # 6. Show a few sample rows. Printing the polars frame directly avoids
        #    an otherwise unneeded pandas/pyarrow dependency for display.
        print("\nStep 6: Sample output...")
        sample = result.select(
            ["ts_code", "trade_date", "close", "returns", "price_rank"]
        ).head(3)
        print(sample)

        print("\n" + "=" * 60)
        print("Workflow completed successfully!")
        print("=" * 60)
if __name__ == "__main__":
    # Run the verbose demo directly first so its walkthrough is printed
    # uncaptured, then run the whole suite under pytest (which runs the
    # demo a second time, with output captured).
    test = TestFullWorkflow()
    test.test_full_workflow_demo()
    # Propagate pytest's exit code so test failures make the script exit non-zero.
    raise SystemExit(pytest.main([__file__, "-v", "--tb=short"]))