# src/backtest_engine.py from typing import Type, Dict, Any, List import pandas as pd # 导入所有需要协调的模块 from .core_data import Bar, Order, Trade, PortfolioSnapshot from .data_manager import DataManager from .execution_simulator import ExecutionSimulator from .backtest_context import BacktestContext from .strategies.base_strategy import Strategy # 导入策略基类 class BacktestEngine: """ 回测引擎:协调数据流、策略执行、订单模拟和结果记录。 """ def __init__(self, data_manager: DataManager, strategy_class: Type[Strategy], strategy_params: Dict[str, Any], current_segment_symbol: str, initial_capital: float = 100000.0, slippage_rate: float = 0.0001, commission_rate: float = 0.0002): """ 初始化回测引擎。 Args: data_manager (DataManager): 已经初始化好的数据管理器实例。 strategy_class (Type[Strategy]): 策略类(而不是实例),引擎会负责实例化。 strategy_params (Dict[str, Any]): 传递给策略的参数字典。 initial_capital (float): 初始交易资金。 slippage_rate (float): 交易滑点率。 commission_rate (float): 交易佣金率。 """ self.data_manager = data_manager self.simulator = ExecutionSimulator( initial_capital=initial_capital, slippage_rate=slippage_rate, commission_rate=commission_rate ) self.context = BacktestContext(self.data_manager, self.simulator) self.current_segment_symbol = current_segment_symbol # 实例化策略 self.strategy = strategy_class(self.context, **strategy_params) self.portfolio_snapshots: List[PortfolioSnapshot] = [] # 存储每天的投资组合快照 self.trade_history: List[Trade] = [] # 存储所有成交记录 self.all_bars: List[Bar] = [] # 历史Bar缓存,用于特征计算 self._history_bars: List[Bar] = [] self._max_history_bars: int = 200 # 例如,只保留最近200根Bar的历史数据,可根据策略需求调整 print("\n--- 回测引擎初始化完成 ---") print(f" 策略: {strategy_class.__name__}") print(f" 初始资金: {initial_capital:.2f}") def run_backtest(self): """ 运行整个回测流程。 """ print("\n--- 回测开始 ---") # 调用策略的初始化方法 self.strategy.on_init() # 主回测循环 while True: current_bar = self.data_manager.get_next_bar() if current_bar is None: break # 没有更多数据,回测结束 # 设置当前Bar到Context,供策略访问 self.context.set_current_bar(current_bar) # 更新历史Bar缓存 self._history_bars.append(current_bar) if len(self._history_bars) > self._max_history_bars: self._history_bars.pop(0) # 移除最旧的Bar # 1. 计算特征 (使用纯函数) # 注意: extract_bar_features 接收的是完整的历史数据,不包含当前Bar # 但为了简单起见,这里传入的是包含当前bar在内的历史数据,但内部函数应确保不使用“未来”数据 # 严格来说,应该传入 self._history_bars[:-1] # features = extract_bar_features(current_bar, self._history_bars[:-1]) # 传入当前Bar之前的所有历史Bar # 2. 调用策略的 on_bar 方法 self.strategy.on_bar(current_bar) # 3. 记录投资组合快照 current_portfolio_value = self.simulator.get_portfolio_value(current_bar) current_positions = self.simulator.get_current_positions() # 创建 PortfolioSnapshot,记录当前Bar的收盘价 price_at_snapshot = { current_bar.symbol if hasattr(current_bar, 'symbol') else "DEFAULT_SYMBOL": current_bar.close} snapshot = PortfolioSnapshot( datetime=current_bar.datetime, total_value=current_portfolio_value, cash=self.simulator.cash, positions=current_positions, price_at_snapshot=price_at_snapshot ) self.portfolio_snapshots.append(snapshot) self.all_bars.append(current_bar) last_processed_bar = current_bar # 记录交易历史(从模拟器获取) # 简化处理:每次获取模拟器中的所有交易历史,并更新引擎的trade_history # 更好的做法是模拟器提供一个方法,返回自上次查询以来的新增交易 # 这里为了不重复添加,可以在 trade_log 中只添加当前 Bar 生成的交易 # 在 on_bar 循环的末尾,获取本Bar周期内新产生的交易 # 模拟器在每次send_order成功时会将trade添加到其trade_log # 这里可以做一个增量获取,或者简单地在循环结束后统一获取 # 目前我们在执行模拟器中已经将成交记录在了 trade_log 中,所以这里不用重复记录, # 而是等到回测结束后再统一获取。 # 不在此处记录 self.trade_history print("\n--- 回测片段结束,检查并平仓所有持仓 ---") if last_processed_bar: # 确保至少有一根Bar被处理过 positions_to_close = self.simulator.get_current_positions() for symbol_held, quantity in positions_to_close.items(): if quantity != 0: print(f"[{last_processed_bar.datetime}] 回测结束平仓: 平仓 {symbol_held} ({quantity} 手) @ {last_processed_bar.close:.2f}。") direction = "SELL" if quantity > 0 else "BUY" volume = abs(quantity) # 使用当前合约的最后一根Bar的价格进行平仓 # 注意:这里假设平仓的symbol_held就是当前segment的symbol # 如果策略可能同时持有其他旧合约的仓位(多主力同时持有),这里需要更复杂的逻辑来获取正确的平仓价格 # 但在主力合约切换场景下,通常只持有当前主力合约的仓位。 rollover_order = Order(symbol=symbol_held, direction=direction, volume=volume, price_type="MARKET") self.simulator.send_order(rollover_order, current_bar=last_processed_bar) else: print("没有处理任何Bar,无需平仓。") # 回测结束后,获取所有交易记录 self.trade_history = self.simulator.get_trade_history() print("--- 回测结束 ---") print(f"总计处理了 {len(self.portfolio_snapshots)} 根K线。") print(f"总计发生了 {len(self.trade_history)} 笔交易。") def get_backtest_results(self) -> Dict[str, Any]: """ 返回回测结果数据,供结果分析模块使用。 """ return { "portfolio_snapshots": self.portfolio_snapshots, "trade_history": self.trade_history, "initial_capital": self.simulator.initial_capital, # 或 self.initial_capital "all_bars": self.all_bars } def get_simulator(self) -> ExecutionSimulator: # <--- 新增的方法 """ 返回引擎内部的 ExecutionSimulator 实例,以便外部可以访问和修改其状态。 """ return self.simulator