- 新增 api_daily_basic.py 封装 Tushare 每日指标接口 - 因子引擎移除 lookback_days,支持 daily_basic 表字段路由 - 将每日指标纳入自动同步流程 - 删除废弃的 training/main.py
161 lines
5.5 KiB
Python
161 lines
5.5 KiB
Python
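"""End-to-end tests for FactorEngine.

In-memory mock data stands in for the database, so the whole pipeline, from
expression registration through to result output, is exercised end to end.
"""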
|
|
|
|
import pytest
|
|
import polars as pl
|
|
import numpy as np
|
|
from datetime import datetime, timedelta
|
|
|
|
from src.factors.engine import FactorEngine, DataSpec
|
|
from src.factors.api import close, ts_mean, ts_std, cs_rank, cs_zscore, open as open_sym
|
|
from src.factors.dsl import Symbol, FunctionNode
|
|
|
|
|
|
def create_mock_data(
    start_date: str = "20240101",
    end_date: str = "20240131",
    n_stocks: int = 5,
    seed: int = 42,
) -> pl.DataFrame:
    """Create deterministic mock daily bars for every weekday in a date range.

    Each stock follows its own multiplicative (geometric) random walk. This
    keeps prices strictly positive and makes consecutive days related, so
    time-series operators see a meaningful series. ``pre_close`` is the
    stock's close on its previous generated row.

    Args:
        start_date: First calendar day, ``YYYYMMDD`` (inclusive).
        end_date: Last calendar day, ``YYYYMMDD`` (inclusive).
        n_stocks: Number of synthetic tickers (``600000.SH``, ``600001.SH``, ...).
        seed: Seed for a private NumPy ``Generator``. The global NumPy random
            state is left untouched, so other tests are not affected.

    Returns:
        One row per (trade_date, ts_code) with columns ``ts_code``,
        ``trade_date``, ``open``, ``high``, ``low``, ``close``, ``volume``,
        ``amount`` and ``pre_close``. If the range contains no weekday, the
        frame is empty but still carries these columns.

    Raises:
        ValueError: If a date string does not match ``YYYYMMDD``.
    """
    start = datetime.strptime(start_date, "%Y%m%d")
    end = datetime.strptime(end_date, "%Y%m%d")

    # Weekdays only (Mon=0 .. Fri=4); exchange holidays are not modelled.
    n_days = (end - start).days + 1
    dates = [
        day.strftime("%Y%m%d")
        for day in (start + timedelta(days=i) for i in range(n_days))
        if day.weekday() < 5
    ]

    stocks = [f"{600000 + i:06d}.SH" for i in range(n_stocks)]
    rng = np.random.default_rng(seed)

    # Starting level per stock around 10, floored at 1.0 so no path starts
    # near zero (open is used as a divisor by the returns factor).
    last_close = {
        stock: round(max(10.0 + 5.0 * rng.standard_normal(), 1.0), 2)
        for stock in stocks
    }

    rows = []
    for date in dates:
        for stock in stocks:
            pre_close = last_close[stock]
            # exp(...) multipliers keep every price > 0 whatever the draw.
            close_val = pre_close * float(np.exp(0.02 * rng.standard_normal()))
            open_val = close_val * float(np.exp(0.01 * rng.standard_normal()))
            high_val = max(open_val, close_val) * float(
                np.exp(0.01 * abs(rng.standard_normal()))
            )
            low_val = min(open_val, close_val) * float(
                np.exp(-0.01 * abs(rng.standard_normal()))
            )
            vol = int(1_000_000 + rng.exponential(500_000))
            close_rounded = round(close_val, 2)

            # round() is monotonic, so high >= max(open, close) and
            # low <= min(open, close) still hold after rounding.
            rows.append(
                {
                    "ts_code": stock,
                    "trade_date": date,
                    "open": round(open_val, 2),
                    "high": round(high_val, 2),
                    "low": round(low_val, 2),
                    "close": close_rounded,
                    "volume": vol,
                    "amount": round(close_rounded * vol, 2),
                    "pre_close": pre_close,
                }
            )
            last_close[stock] = close_rounded

    if not rows:
        # pl.DataFrame([]) would have no columns at all; keep the schema so
        # callers can still select/inspect columns on an empty range.
        return pl.DataFrame(
            schema={
                "ts_code": pl.Utf8,
                "trade_date": pl.Utf8,
                "open": pl.Float64,
                "high": pl.Float64,
                "low": pl.Float64,
                "close": pl.Float64,
                "volume": pl.Int64,
                "amount": pl.Float64,
                "pre_close": pl.Float64,
            }
        )
    return pl.DataFrame(rows)
|
|
|
|
|
|
class TestFactorEngineEndToEnd:
    """End-to-end checks of FactorEngine against an in-memory mock source."""

    @pytest.fixture
    def mock_data(self):
        """Mock daily bars for January 2024 covering five stocks."""
        return create_mock_data("20240101", "20240131", n_stocks=5)

    @pytest.fixture
    def engine(self, mock_data):
        """A FactorEngine whose ``pro_bar`` source is the mock frame."""
        return FactorEngine(data_source={"pro_bar": mock_data}, max_workers=2)

    def test_simple_symbol_expression(self, engine):
        """A bare symbol expression yields a non-empty, correctly named column."""
        factor_name = "close_price"
        engine.register(factor_name, close)
        out = engine.compute(factor_name, "20240115", "20240120")

        assert factor_name in out.columns
        assert len(out) > 0
        print("[PASS] 简单符号表达式测试")

    def test_arithmetic_expression(self, engine):
        """An arithmetic combination of symbols is registered and computed."""
        factor_name = "returns"
        intraday_return = (close - open_sym) / open_sym
        engine.register(factor_name, intraday_return)
        out = engine.compute(factor_name, "20240115", "20240120")

        assert factor_name in out.columns
        print("[PASS] 算术表达式测试")

    def test_cs_rank_factor(self, engine):
        """A cross-sectional rank factor stays within [0, 1]."""
        factor_name = "price_rank"
        engine.register(factor_name, cs_rank(close))
        out = engine.compute(factor_name, "20240115", "20240120")

        assert factor_name in out.columns
        ranks = out[factor_name]
        assert ranks.min() >= 0
        assert ranks.max() <= 1
        print("[PASS] 截面排名因子测试")
|
|
|
|
|
|
class TestFullWorkflow:
    """Full-workflow demonstration test."""

    def test_full_workflow_demo(self):
        """Walk through the complete factor pipeline, printing each stage.

        Steps: build mock data, create an engine over it, register two factors,
        compute them for a sub-window, check the results, and print a sample.
        """
        print("\n" + "=" * 60)
        print("FactorEngine Full Workflow Demo")
        print("=" * 60)

        # 1. Prepare data.
        print("\nStep 1: Prepare mock data...")
        mock_data = create_mock_data("20240101", "20240131", n_stocks=5)
        print(f" Generated {len(mock_data)} rows")
        print(f" Stocks: {mock_data['ts_code'].n_unique()}")

        # 2. Initialise the engine.
        print("\nStep 2: Initialize FactorEngine...")
        engine = FactorEngine(data_source={"pro_bar": mock_data})
        print(" Engine initialized")

        # 3. Register factors. Both are same-day expressions (an intraday
        #    return and a cross-sectional rank), so neither needs a lookback
        #    window of history before the compute range.
        print("\nStep 3: Register factors...")
        engine.register("returns", (close - open_sym) / open_sym)
        engine.register("price_rank", cs_rank(close))
        print(" Registered: returns, price_rank")

        # 4. Compute over a sub-window (2024-01-15 .. 2024-01-20) of the
        #    January mock data, not the full generated range.
        print("\nStep 4: Compute factors...")
        result = engine.compute(
            ["returns", "price_rank"],
            "20240115",
            "20240120",
        )
        print(f" Computed {len(result)} rows")

        # 5. Verify results.
        print("\nStep 5: Verify results...")
        assert "returns" in result.columns
        assert "price_rank" in result.columns
        assert result["price_rank"].min() >= 0
        assert result["price_rank"].max() <= 1
        print(" All assertions passed")

        # 6. Show a sample. Polars renders the frame itself; converting via
        #    to_pandas() would require pandas (and, depending on the Polars
        #    version, pyarrow) just for display.
        print("\nStep 6: Sample output...")
        sample = result.select(
            ["ts_code", "trade_date", "close", "returns", "price_rank"]
        ).head(3)
        print(sample)

        print("\n" + "=" * 60)
        print("Workflow completed successfully!")
        print("=" * 60)
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the narrated demo directly first, because pytest captures stdout by
    # default and would hide its output. Then run the whole module under pytest
    # and hand pytest's exit status to the shell so failures are visible to
    # callers such as CI.
    test = TestFullWorkflow()
    test.test_full_workflow_demo()
    raise SystemExit(pytest.main([__file__, "-v", "--tb=short"]))
|