"""FactorEngine 端到端测试。 模拟内存数据作为假数据库,完整跑通从表达式注册到结果输出的全流程链路。 """ import pytest import polars as pl import numpy as np from datetime import datetime, timedelta from src.factors.engine import FactorEngine, DataSpec from src.factors.api import close, ts_mean, ts_std, cs_rank, cs_zscore, open as open_sym from src.factors.dsl import Symbol, FunctionNode def create_mock_data( start_date: str = "20240101", end_date: str = "20240131", n_stocks: int = 5, ) -> pl.DataFrame: """创建模拟的日线数据。""" start = datetime.strptime(start_date, "%Y%m%d") end = datetime.strptime(end_date, "%Y%m%d") dates = [] current = start while current <= end: if current.weekday() < 5: # 周一到周五 dates.append(current.strftime("%Y%m%d")) current += timedelta(days=1) stocks = [f"{600000 + i:06d}.SH" for i in range(n_stocks)] np.random.seed(42) rows = [] for date in dates: for stock in stocks: base_price = 10 + np.random.randn() * 5 close_val = base_price + np.random.randn() * 0.5 open_val = close_val + np.random.randn() * 0.2 high_val = max(open_val, close_val) + abs(np.random.randn()) * 0.3 low_val = min(open_val, close_val) - abs(np.random.randn()) * 0.3 vol = int(1000000 + np.random.exponential(500000)) amt = close_val * vol rows.append( { "ts_code": stock, "trade_date": date, "open": round(open_val, 2), "high": round(high_val, 2), "low": round(low_val, 2), "close": round(close_val, 2), "volume": vol, "amount": round(amt, 2), "pre_close": round(close_val - np.random.randn() * 0.3, 2), } ) return pl.DataFrame(rows) class TestFactorEngineEndToEnd: """FactorEngine 端到端测试类。""" @pytest.fixture def mock_data(self): """提供模拟数据的 fixture。""" return create_mock_data("20240101", "20240131", n_stocks=5) @pytest.fixture def engine(self, mock_data): """提供配置好的 FactorEngine fixture。""" data_source = {"pro_bar": mock_data} return FactorEngine(data_source=data_source, max_workers=2) def test_simple_symbol_expression(self, engine): """测试简单的符号表达式。""" engine.register("close_price", close) result = engine.compute("close_price", "20240115", "20240120") assert "close_price" in result.columns assert len(result) > 0 print("[PASS] 简单符号表达式测试") def test_arithmetic_expression(self, engine): """测试算术表达式。""" engine.register("returns", (close - open_sym) / open_sym) result = engine.compute("returns", "20240115", "20240120") assert "returns" in result.columns print("[PASS] 算术表达式测试") def test_cs_rank_factor(self, engine): """测试截面排名因子。""" engine.register("price_rank", cs_rank(close)) result = engine.compute("price_rank", "20240115", "20240120") assert "price_rank" in result.columns assert result["price_rank"].min() >= 0 assert result["price_rank"].max() <= 1 print("[PASS] 截面排名因子测试") class TestFullWorkflow: """完整工作流测试类。""" def test_full_workflow_demo(self): """演示完整的因子计算工作流。""" print("\n" + "=" * 60) print("FactorEngine Full Workflow Demo") print("=" * 60) # 1. 准备数据 print("\nStep 1: Prepare mock data...") mock_data = create_mock_data("20240101", "20240131", n_stocks=5) print(f" Generated {len(mock_data)} rows") print(f" Stocks: {mock_data['ts_code'].n_unique()}") # 2. 初始化引擎 print("\nStep 2: Initialize FactorEngine...") engine = FactorEngine(data_source={"pro_bar": mock_data}) print(" Engine initialized") # 3. 注册因子 - 使用简单因子避免回看窗口问题 print("\nStep 3: Register factors...") engine.register("returns", (close - open_sym) / open_sym) engine.register("price_rank", cs_rank(close)) print(" Registered: returns, price_rank") # 4. 执行计算 - 使用完整日期范围 print("\nStep 4: Compute factors...") result = engine.compute( ["returns", "price_rank"], "20240115", "20240120", ) print(f" Computed {len(result)} rows") # 5. 验证结果 print("\nStep 5: Verify results...") assert "returns" in result.columns assert "price_rank" in result.columns assert result["price_rank"].min() >= 0 assert result["price_rank"].max() <= 1 print(" All assertions passed") # 6. 展示样本 print("\nStep 6: Sample output...") sample = result.select( ["ts_code", "trade_date", "close", "returns", "price_rank"] ).head(3) print(sample.to_pandas().to_string(index=False)) print("\n" + "=" * 60) print("Workflow completed successfully!") print("=" * 60) if __name__ == "__main__": test = TestFullWorkflow() test.test_full_workflow_demo() pytest.main([__file__, "-v", "--tb=short"])