276 lines
12 KiB
Python
276 lines
12 KiB
Python
# src/backtest_engine.py
|
||
from datetime import datetime
|
||
from typing import Type, Dict, Any, List, Optional
|
||
import numpy as np
|
||
import pandas as pd
|
||
|
||
from src.indicators.base_indicators import Indicator
|
||
|
||
# 导入所有需要协调的模块
|
||
from .core_data import Bar, Order, Trade, PortfolioSnapshot
|
||
from .data_manager import DataManager
|
||
from .execution_simulator import ExecutionSimulator
|
||
from .backtest_context import BacktestContext
|
||
from .strategies.base_strategy import Strategy
|
||
|
||
class BacktestEngine:
|
||
"""
|
||
回测引擎:协调数据流、策略执行、订单模拟和结果记录。
|
||
"""
|
||
def __init__(self,
|
||
data_manager: DataManager,
|
||
strategy_class: Type[Strategy],
|
||
strategy_params: Dict[str, Any],
|
||
# current_segment_symbol: str, # 这个参数不再需要,因为 symbol 会动态更新
|
||
initial_capital: float = 100000.0,
|
||
slippage_rate: float = 0.0001,
|
||
commission_rate: float = 0.0002,
|
||
roll_over_mode: bool = False,
|
||
start_time: Optional[datetime] = None, # 新增开始时间
|
||
end_time: Optional[datetime] = None, # 新增结束时间
|
||
indicators: List[Indicator] = [],
|
||
): # 新增换月模式参数
|
||
"""
|
||
初始化回测引擎。
|
||
|
||
Args:
|
||
data_manager (DataManager): 已经初始化好的数据管理器实例。
|
||
strategy_class (Type[Strategy]): 策略类(而不是实例),引擎会负责实例化。
|
||
strategy_params (Dict[str, Any]): 传递给策略的参数字典。
|
||
initial_capital (float): 初始交易资金。
|
||
slippage_rate (float): 交易滑点率。
|
||
commission_rate (float): 交易佣金率。
|
||
roll_over_mode (bool): 是否启用主连合约换月模式。
|
||
"""
|
||
self.data_manager = data_manager
|
||
self.initial_capital = initial_capital
|
||
self.simulator = ExecutionSimulator(
|
||
initial_capital=initial_capital,
|
||
slippage_rate=slippage_rate,
|
||
commission_rate=commission_rate
|
||
)
|
||
# 传入引擎自身给 context,以便 context 可以获取引擎的状态(如 is_rollover_bar)
|
||
self.context = BacktestContext(self.data_manager, self.simulator)
|
||
self.context.set_engine(self) # 建立 Context 到 Engine 的引用
|
||
|
||
# self.current_segment_symbol = current_segment_symbol # 此行移除或作为内部变量动态管理
|
||
|
||
# 实例化策略。初始 symbol 会在 run_backtest 中根据第一根 Bar 动态设置。
|
||
self.strategy = strategy_class(self.context, symbol="INITIAL_PLACEHOLDER_SYMBOL", **strategy_params)
|
||
|
||
self.indicators = indicators
|
||
|
||
self.portfolio_snapshots: List[PortfolioSnapshot] = []
|
||
self.trade_history: List[Trade] = []
|
||
self.all_bars: List[Bar] = []
|
||
|
||
self.close_list: List[float] = []
|
||
self.open_list: List[float] = []
|
||
self.high_list: List[float] = []
|
||
self.low_list: List[float] = []
|
||
self.volume_list: List[float] = []
|
||
|
||
self._history_bars: List[Bar] = [] # 引擎层面保留的历史 Bar,通常供策略在 on_bar 中使用
|
||
self._max_history_bars: int = strategy_params.get('history_bars_limit', 200)
|
||
|
||
# 换月相关状态
|
||
self.roll_over_mode = roll_over_mode # 是否启用换月模式
|
||
self._last_processed_bar_symbol: Optional[str] = None # 记录上一根 K 线的 symbol
|
||
self.is_rollover_bar: bool = False # 标记当前 K 线是否为换月 K 线(禁止开仓)
|
||
|
||
# 新增时间过滤属性
|
||
self.start_time = start_time
|
||
self.end_time = end_time
|
||
|
||
print("\n--- 回测引擎初始化完成 ---")
|
||
print(f" 策略: {strategy_class.__name__}")
|
||
print(f" 初始资金: {initial_capital:.2f}")
|
||
print(f" 换月模式: {'启用' if roll_over_mode else '禁用'}")
|
||
|
||
def run_backtest(self):
|
||
"""
|
||
运行整个回测流程,包含换月逻辑。
|
||
"""
|
||
print("\n--- 回测开始 ---")
|
||
|
||
# 调用策略的初始化方法
|
||
self.strategy.on_init()
|
||
|
||
self.strategy.trading = True
|
||
|
||
last_processed_bar: Optional[Bar] = None # 用于在换月时引用旧合约的最后一根 K 线
|
||
|
||
# 主回测循环
|
||
while True:
|
||
current_bar = self.data_manager.get_next_bar()
|
||
|
||
if current_bar is None:
|
||
break # 没有更多数据,回测结束
|
||
|
||
if self.start_time and current_bar.datetime < self.start_time:
|
||
continue
|
||
|
||
# 如果设置了结束时间,且当前K线在结束时间之后,则终止回测
|
||
if self.end_time and current_bar.datetime >= self.end_time:
|
||
print(f"到达结束时间 {self.end_time},回测终止。")
|
||
break
|
||
|
||
# --- 换月逻辑判断和处理 (在处理 current_bar 之前进行) ---
|
||
# 1. 重置 is_rollover_bar 标记
|
||
self.is_rollover_bar = False
|
||
|
||
# 4. 更新 Context 和 Simulator 的当前 Bar 和时间
|
||
self.context.set_current_bar(current_bar)
|
||
self.simulator.update_time(current_time=current_bar.datetime)
|
||
|
||
# 2. 如果启用换月模式,并且检测到合约 symbol 变化
|
||
if self.roll_over_mode and \
|
||
self._last_processed_bar_symbol is not None and \
|
||
current_bar.symbol != self._last_processed_bar_symbol:
|
||
|
||
old_symbol = self._last_processed_bar_symbol
|
||
new_symbol = current_bar.symbol
|
||
|
||
# 确认 last_processed_bar 确实是旧合约的最后一根 K 线
|
||
if last_processed_bar and last_processed_bar.symbol == old_symbol:
|
||
self.strategy.log(f"检测到换月!从 [{old_symbol}] 切换到 [{new_symbol}]。"
|
||
f"在旧合约最后一根K线 ({last_processed_bar.datetime}) 执行强制平仓和取消操作。")
|
||
|
||
# A. 强制平仓旧合约的所有持仓
|
||
self.simulator.force_close_all_positions_for_symbol(old_symbol, last_processed_bar)
|
||
|
||
# B. 取消旧合约的所有挂单
|
||
self.simulator.cancel_all_pending_orders_for_symbol(old_symbol)
|
||
|
||
# C. 标记【当前这根 Bar (即新合约的第一根 K 线)】为换月 K 线
|
||
# 此时 self.is_rollover_bar 变为 True,将通过 Context 传递给策略,
|
||
# 策略在该 K 线周期内不能开仓。
|
||
self.is_rollover_bar = True
|
||
|
||
# D. 通知策略换月事件,让策略有机会重置内部状态
|
||
self.strategy.on_rollover(old_symbol, new_symbol)
|
||
else:
|
||
self.strategy.log(f"警告: 检测到换月从 {old_symbol} 到 {new_symbol},但 last_processed_bar 为空或与旧合约不符。"
|
||
"强制平仓/取消操作可能未正确执行。")
|
||
|
||
# 3. 更新策略关注的当前合约 symbol
|
||
self.strategy.symbol = current_bar.symbol
|
||
|
||
self.strategy.on_open_bar(current_bar.open, current_bar.symbol)
|
||
|
||
current_indicator_dict = {}
|
||
close_array = np.array(self.close_list)
|
||
open_array = np.array(self.open_list)
|
||
high_array = np.array(self.high_list)
|
||
low_array = np.array(self.low_list)
|
||
volume_array = np.array(self.volume_list)
|
||
|
||
for indicator in self.indicators:
|
||
current_indicator_dict[indicator.get_name()] = indicator.get_latest_value(
|
||
close_array,
|
||
open_array,
|
||
high_array,
|
||
low_array,
|
||
volume_array
|
||
)
|
||
self.simulator.process_pending_orders(current_bar, current_indicator_dict)
|
||
|
||
self.all_bars.append(current_bar)
|
||
self.close_list.append(current_bar.close)
|
||
self.open_list.append(current_bar.open)
|
||
self.high_list.append(current_bar.high)
|
||
self.low_list.append(current_bar.low)
|
||
self.volume_list.append(current_bar.volume)
|
||
|
||
|
||
# 7. 调用策略的 on_bar 方法
|
||
# self.strategy.on_bar(current_bar)
|
||
|
||
self.strategy.on_close_bar(current_bar)
|
||
self.simulator.process_pending_orders(current_bar, current_indicator_dict)
|
||
|
||
|
||
# 8. 记录投资组合快照
|
||
current_portfolio_value = self.simulator.get_portfolio_value(current_bar)
|
||
current_positions = self.simulator.get_current_positions()
|
||
|
||
price_at_snapshot = {current_bar.symbol: current_bar.close} # 使用当前 Bar 的收盘价记录快照
|
||
|
||
snapshot = PortfolioSnapshot(
|
||
datetime=current_bar.datetime,
|
||
total_value=current_portfolio_value,
|
||
cash=self.simulator.cash,
|
||
positions=current_positions,
|
||
price_at_snapshot=price_at_snapshot
|
||
)
|
||
self.portfolio_snapshots.append(snapshot)
|
||
|
||
# 9. 更新 `_last_processed_bar_symbol` 和 `last_processed_bar` 为当前 Bar,为下一轮循环做准备
|
||
self._last_processed_bar_symbol = current_bar.symbol
|
||
last_processed_bar = current_bar
|
||
|
||
# --- 回测结束后的清理工作 ---
|
||
print("\n--- 回测结束,检查并平仓所有剩余持仓 ---")
|
||
if last_processed_bar: # 确保至少有一根 Bar 被处理过
|
||
# 在回测结束时,强制平仓所有可能存在的剩余持仓
|
||
# 遍历所有持仓,确保全部清算
|
||
remaining_positions_symbols = list(self.simulator.get_current_positions().keys())
|
||
for symbol_held in remaining_positions_symbols:
|
||
if self.simulator.get_current_positions().get(symbol_held, 0) != 0:
|
||
self.strategy.log(f"回测结束清理: 强制平仓合约 {symbol_held} 的剩余持仓。")
|
||
# 使用 simulator 的 force_close_all_positions_for_symbol 方法进行清理
|
||
self.simulator.force_close_all_positions_for_symbol(symbol_held, last_processed_bar)
|
||
self.simulator.cancel_all_pending_orders_for_symbol(symbol_held)
|
||
else:
|
||
print("没有处理任何 Bar,无需平仓。")
|
||
|
||
# 回测结束后,获取所有交易记录
|
||
self.trade_history = self.simulator.get_trade_history()
|
||
|
||
print("--- 回测结束 ---")
|
||
print(f"总计处理了 {len(self.all_bars)} 根K线。")
|
||
print(f"总计发生了 {len(self.trade_history)} 笔交易。")
|
||
|
||
final_portfolio_value = 0.0
|
||
if last_processed_bar:
|
||
final_portfolio_value = self.simulator.get_portfolio_value(last_processed_bar)
|
||
else:
|
||
final_portfolio_value = self.initial_capital
|
||
|
||
total_return_percentage = ((final_portfolio_value - self.initial_capital) / self.initial_capital) * 100
|
||
|
||
print(f"最终总净值: {final_portfolio_value:.2f}")
|
||
print(f"总收益率: {total_return_percentage:.2f}%")
|
||
|
||
def get_backtest_results(self) -> Dict[str, Any]:
|
||
"""
|
||
返回回测结果数据,供结果分析模块使用。
|
||
"""
|
||
return {
|
||
"portfolio_snapshots": self.portfolio_snapshots,
|
||
"trade_history": self.trade_history,
|
||
"initial_capital": self.simulator.initial_capital,
|
||
"all_bars": self.all_bars
|
||
}
|
||
|
||
def get_simulator(self) -> ExecutionSimulator:
|
||
return self.simulator
|
||
|
||
def get_bar_history(self):
|
||
return self.all_bars
|
||
|
||
|
||
def get_price_history(self, key: str):
|
||
if key == 'close':
|
||
return self.close_list
|
||
elif key == 'open':
|
||
return self.open_list
|
||
elif key == 'high':
|
||
return self.high_list
|
||
elif key == 'low':
|
||
return self.low_list
|
||
elif key == 'volume':
|
||
return self.volume_list
|
||
return None
|
||
|