Files
ProStock/tests/test_factor_engine.py
liaozhaorun 53225b9443 feat(data): 添加每日指标接口并优化因子引擎
- 新增 api_daily_basic.py 封装 Tushare 每日指标接口
- 因子引擎移除 lookback_days,支持 daily_basic 表字段路由
- 将每日指标纳入自动同步流程
- 删除废弃的 training/main.py
2026-03-03 17:09:39 +08:00

161 lines
5.5 KiB
Python

"""FactorEngine 端到端测试。
模拟内存数据作为假数据库,完整跑通从表达式注册到结果输出的全流程链路。
"""
import pytest
import polars as pl
import numpy as np
from datetime import datetime, timedelta
from src.factors.engine import FactorEngine, DataSpec
from src.factors.api import close, ts_mean, ts_std, cs_rank, cs_zscore, open as open_sym
from src.factors.dsl import Symbol, FunctionNode
def create_mock_data(
    start_date: str = "20240101",
    end_date: str = "20240131",
    n_stocks: int = 5,
    seed: int = 42,
) -> pl.DataFrame:
    """Create mock daily bar data for ``n_stocks`` synthetic tickers.

    One row is produced per (weekday, stock) pair in the inclusive range
    ``[start_date, end_date]``. Weekends are skipped; exchange holidays are not.

    Args:
        start_date: First calendar date, formatted ``YYYYMMDD``.
        end_date: Last calendar date (inclusive), formatted ``YYYYMMDD``.
        n_stocks: Number of synthetic tickers, numbered upward from
            ``600000.SH``.
        seed: Seed for a private NumPy ``RandomState``. The default (42)
            reproduces the data formerly obtained by seeding the global NumPy
            RNG with 42.

    Returns:
        A DataFrame with columns ts_code, trade_date, open, high, low, close,
        volume, amount, pre_close. If the range contains no weekdays the frame
        is empty but still carries these columns.

    Raises:
        ValueError: If a date string does not match ``YYYYMMDD``.
    """
    start = datetime.strptime(start_date, "%Y%m%d")
    end = datetime.strptime(end_date, "%Y%m%d")
    dates = []
    current = start
    while current <= end:
        if current.weekday() < 5:  # Monday..Friday
            dates.append(current.strftime("%Y%m%d"))
        current += timedelta(days=1)

    stocks = [f"{600000 + i:06d}.SH" for i in range(n_stocks)]

    # Private generator so this helper does not reset the global NumPy RNG
    # state that other tests may depend on. RandomState(seed) yields the same
    # stream as np.random.seed(seed) followed by np.random.* calls.
    rng = np.random.RandomState(seed)

    rows = []
    for date in dates:
        for stock in stocks:
            # NOTE: every row draws a fresh base price (not a random walk), so
            # there is no continuity across days and prices may be <= 0.
            base_price = 10 + rng.randn() * 5
            close_val = base_price + rng.randn() * 0.5
            open_val = close_val + rng.randn() * 0.2
            high_val = max(open_val, close_val) + abs(rng.randn()) * 0.3
            low_val = min(open_val, close_val) - abs(rng.randn()) * 0.3
            vol = int(1000000 + rng.exponential(500000))
            amt = close_val * vol
            rows.append(
                {
                    "ts_code": stock,
                    "trade_date": date,
                    "open": round(open_val, 2),
                    "high": round(high_val, 2),
                    "low": round(low_val, 2),
                    "close": round(close_val, 2),
                    "volume": vol,
                    "amount": round(amt, 2),
                    # pre_close is drawn around today's close; it is not the
                    # previous row's close.
                    "pre_close": round(close_val - rng.randn() * 0.3, 2),
                }
            )

    if not rows:
        # pl.DataFrame([]) would have no columns at all; keep the schema so
        # callers can still select/filter by column name.
        return pl.DataFrame(
            schema={
                "ts_code": pl.Utf8,
                "trade_date": pl.Utf8,
                "open": pl.Float64,
                "high": pl.Float64,
                "low": pl.Float64,
                "close": pl.Float64,
                "volume": pl.Int64,
                "amount": pl.Float64,
                "pre_close": pl.Float64,
            }
        )
    return pl.DataFrame(rows)
class TestFactorEngineEndToEnd:
    """End-to-end tests for FactorEngine driven by in-memory mock data."""

    @pytest.fixture
    def mock_data(self):
        """Mock daily bars for 5 stocks covering January 2024."""
        return create_mock_data("20240101", "20240131", n_stocks=5)

    @pytest.fixture
    def engine(self, mock_data):
        """A FactorEngine whose ``pro_bar`` table is the mock data."""
        data_source = {"pro_bar": mock_data}
        return FactorEngine(data_source=data_source, max_workers=2)

    def test_simple_symbol_expression(self, engine):
        """A bare symbol factor is computed into a column named after it."""
        engine.register("close_price", close)
        result = engine.compute("close_price", "20240115", "20240120")
        assert "close_price" in result.columns
        assert len(result) > 0
        print("[PASS] 简单符号表达式测试")

    def test_arithmetic_expression(self, engine):
        """An arithmetic expression over two symbols is computed."""
        engine.register("returns", (close - open_sym) / open_sym)
        result = engine.compute("returns", "20240115", "20240120")
        assert "returns" in result.columns
        # Without this, an empty result would pass the column check vacuously.
        assert len(result) > 0
        print("[PASS] 算术表达式测试")

    def test_cs_rank_factor(self, engine):
        """A cross-sectional rank factor produces values within [0, 1]."""
        engine.register("price_rank", cs_rank(close))
        result = engine.compute("price_rank", "20240115", "20240120")
        assert "price_rank" in result.columns
        ranks = result["price_rank"].drop_nulls()
        # An empty or all-null column makes min()/max() return None, which
        # would fail with a TypeError rather than a clear assertion.
        assert len(ranks) > 0
        assert ranks.min() >= 0
        assert ranks.max() <= 1
        print("[PASS] 截面排名因子测试")
class TestFullWorkflow:
    """Narrated, step-by-step run of the factor workflow (prints progress)."""

    def test_full_workflow_demo(self):
        """Register two factors, compute them, validate, and print a sample."""
        print("\n" + "=" * 60)
        print("FactorEngine Full Workflow Demo")
        print("=" * 60)

        # 1. Prepare mock data.
        print("\nStep 1: Prepare mock data...")
        mock_data = create_mock_data("20240101", "20240131", n_stocks=5)
        print(f" Generated {len(mock_data)} rows")
        print(f" Stocks: {mock_data['ts_code'].n_unique()}")

        # 2. Initialise the engine over the mock `pro_bar` table.
        print("\nStep 2: Initialize FactorEngine...")
        engine = FactorEngine(data_source={"pro_bar": mock_data})
        print(" Engine initialized")

        # 3. Register factors. Only simple (non-windowed) factors are used to
        #    avoid look-back window issues.
        print("\nStep 3: Register factors...")
        engine.register("returns", (close - open_sym) / open_sym)
        engine.register("price_rank", cs_rank(close))
        print(" Registered: returns, price_rank")

        # 4. Compute both factors over a sub-range (20240115..20240120) of the
        #    mock data, which itself spans 20240101..20240131.
        print("\nStep 4: Compute factors...")
        result = engine.compute(
            ["returns", "price_rank"],
            "20240115",
            "20240120",
        )
        print(f" Computed {len(result)} rows")

        # 5. Verify results.
        print("\nStep 5: Verify results...")
        assert len(result) > 0
        assert "returns" in result.columns
        assert "price_rank" in result.columns
        ranks = result["price_rank"].drop_nulls()
        # Guard so min()/max() cannot return None on an all-null column.
        assert len(ranks) > 0
        assert ranks.min() >= 0
        assert ranks.max() <= 1
        print(" All assertions passed")

        # 6. Show a sample. Printed via polars' own repr: to_pandas() would
        #    require pandas (and, depending on the polars version, pyarrow)
        #    purely for display.
        print("\nStep 6: Sample output...")
        sample = result.select(
            ["ts_code", "trade_date", "close", "returns", "price_rank"]
        ).head(3)
        print(sample)

        print("\n" + "=" * 60)
        print("Workflow completed successfully!")
        print("=" * 60)
if __name__ == "__main__":
    # Run the narrated demo once for readable console output, then the whole
    # module under pytest (which collects the demo again). pytest's exit
    # status is propagated so a failing run is visible to shells / CI.
    test = TestFullWorkflow()
    test.test_full_workflow_demo()
    raise SystemExit(pytest.main([__file__, "-v", "--tb=short"]))