Files
NewQuant/src/backtest_engine.py
2025-07-15 22:45:51 +08:00

276 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# src/backtest_engine.py
from datetime import datetime
from typing import Type, Dict, Any, List, Optional
import numpy as np
import pandas as pd
from src.indicators.base_indicators import Indicator
# 导入所有需要协调的模块
from .core_data import Bar, Order, Trade, PortfolioSnapshot
from .data_manager import DataManager
from .execution_simulator import ExecutionSimulator
from .backtest_context import BacktestContext
from .strategies.base_strategy import Strategy
class BacktestEngine:
"""
回测引擎:协调数据流、策略执行、订单模拟和结果记录。
"""
def __init__(self,
data_manager: DataManager,
strategy_class: Type[Strategy],
strategy_params: Dict[str, Any],
# current_segment_symbol: str, # 这个参数不再需要,因为 symbol 会动态更新
initial_capital: float = 100000.0,
slippage_rate: float = 0.0001,
commission_rate: float = 0.0002,
roll_over_mode: bool = False,
start_time: Optional[datetime] = None, # 新增开始时间
end_time: Optional[datetime] = None, # 新增结束时间
indicators: List[Indicator] = [],
): # 新增换月模式参数
"""
初始化回测引擎。
Args:
data_manager (DataManager): 已经初始化好的数据管理器实例。
strategy_class (Type[Strategy]): 策略类(而不是实例),引擎会负责实例化。
strategy_params (Dict[str, Any]): 传递给策略的参数字典。
initial_capital (float): 初始交易资金。
slippage_rate (float): 交易滑点率。
commission_rate (float): 交易佣金率。
roll_over_mode (bool): 是否启用主连合约换月模式。
"""
self.data_manager = data_manager
self.initial_capital = initial_capital
self.simulator = ExecutionSimulator(
initial_capital=initial_capital,
slippage_rate=slippage_rate,
commission_rate=commission_rate
)
# 传入引擎自身给 context以便 context 可以获取引擎的状态(如 is_rollover_bar
self.context = BacktestContext(self.data_manager, self.simulator)
self.context.set_engine(self) # 建立 Context 到 Engine 的引用
# self.current_segment_symbol = current_segment_symbol # 此行移除或作为内部变量动态管理
# 实例化策略。初始 symbol 会在 run_backtest 中根据第一根 Bar 动态设置。
self.strategy = strategy_class(self.context, symbol="INITIAL_PLACEHOLDER_SYMBOL", **strategy_params)
self.indicators = indicators
self.portfolio_snapshots: List[PortfolioSnapshot] = []
self.trade_history: List[Trade] = []
self.all_bars: List[Bar] = []
self.close_list: List[float] = []
self.open_list: List[float] = []
self.high_list: List[float] = []
self.low_list: List[float] = []
self.volume_list: List[float] = []
self._history_bars: List[Bar] = [] # 引擎层面保留的历史 Bar通常供策略在 on_bar 中使用
self._max_history_bars: int = strategy_params.get('history_bars_limit', 200)
# 换月相关状态
self.roll_over_mode = roll_over_mode # 是否启用换月模式
self._last_processed_bar_symbol: Optional[str] = None # 记录上一根 K 线的 symbol
self.is_rollover_bar: bool = False # 标记当前 K 线是否为换月 K 线(禁止开仓)
# 新增时间过滤属性
self.start_time = start_time
self.end_time = end_time
print("\n--- 回测引擎初始化完成 ---")
print(f" 策略: {strategy_class.__name__}")
print(f" 初始资金: {initial_capital:.2f}")
print(f" 换月模式: {'启用' if roll_over_mode else '禁用'}")
def run_backtest(self):
"""
运行整个回测流程,包含换月逻辑。
"""
print("\n--- 回测开始 ---")
# 调用策略的初始化方法
self.strategy.on_init()
self.strategy.trading = True
last_processed_bar: Optional[Bar] = None # 用于在换月时引用旧合约的最后一根 K 线
# 主回测循环
while True:
current_bar = self.data_manager.get_next_bar()
if current_bar is None:
break # 没有更多数据,回测结束
if self.start_time and current_bar.datetime < self.start_time:
continue
# 如果设置了结束时间且当前K线在结束时间之后则终止回测
if self.end_time and current_bar.datetime >= self.end_time:
print(f"到达结束时间 {self.end_time},回测终止。")
break
# --- 换月逻辑判断和处理 (在处理 current_bar 之前进行) ---
# 1. 重置 is_rollover_bar 标记
self.is_rollover_bar = False
# 4. 更新 Context 和 Simulator 的当前 Bar 和时间
self.context.set_current_bar(current_bar)
self.simulator.update_time(current_time=current_bar.datetime)
# 2. 如果启用换月模式,并且检测到合约 symbol 变化
if self.roll_over_mode and \
self._last_processed_bar_symbol is not None and \
current_bar.symbol != self._last_processed_bar_symbol:
old_symbol = self._last_processed_bar_symbol
new_symbol = current_bar.symbol
# 确认 last_processed_bar 确实是旧合约的最后一根 K 线
if last_processed_bar and last_processed_bar.symbol == old_symbol:
self.strategy.log(f"检测到换月!从 [{old_symbol}] 切换到 [{new_symbol}]。"
f"在旧合约最后一根K线 ({last_processed_bar.datetime}) 执行强制平仓和取消操作。")
# A. 强制平仓旧合约的所有持仓
self.simulator.force_close_all_positions_for_symbol(old_symbol, last_processed_bar)
# B. 取消旧合约的所有挂单
self.simulator.cancel_all_pending_orders_for_symbol(old_symbol)
# C. 标记【当前这根 Bar (即新合约的第一根 K 线)】为换月 K 线
# 此时 self.is_rollover_bar 变为 True将通过 Context 传递给策略,
# 策略在该 K 线周期内不能开仓。
self.is_rollover_bar = True
# D. 通知策略换月事件,让策略有机会重置内部状态
self.strategy.on_rollover(old_symbol, new_symbol)
else:
self.strategy.log(f"警告: 检测到换月从 {old_symbol}{new_symbol},但 last_processed_bar 为空或与旧合约不符。"
"强制平仓/取消操作可能未正确执行。")
# 3. 更新策略关注的当前合约 symbol
self.strategy.symbol = current_bar.symbol
self.strategy.on_open_bar(current_bar.open, current_bar.symbol)
current_indicator_dict = {}
close_array = np.array(self.close_list)
open_array = np.array(self.open_list)
high_array = np.array(self.high_list)
low_array = np.array(self.low_list)
volume_array = np.array(self.volume_list)
for indicator in self.indicators:
current_indicator_dict[indicator.get_name()] = indicator.get_latest_value(
close_array,
open_array,
high_array,
low_array,
volume_array
)
self.simulator.process_pending_orders(current_bar, current_indicator_dict)
self.all_bars.append(current_bar)
self.close_list.append(current_bar.close)
self.open_list.append(current_bar.open)
self.high_list.append(current_bar.high)
self.low_list.append(current_bar.low)
self.volume_list.append(current_bar.volume)
# 7. 调用策略的 on_bar 方法
# self.strategy.on_bar(current_bar)
self.strategy.on_close_bar(current_bar)
self.simulator.process_pending_orders(current_bar, current_indicator_dict)
# 8. 记录投资组合快照
current_portfolio_value = self.simulator.get_portfolio_value(current_bar)
current_positions = self.simulator.get_current_positions()
price_at_snapshot = {current_bar.symbol: current_bar.close} # 使用当前 Bar 的收盘价记录快照
snapshot = PortfolioSnapshot(
datetime=current_bar.datetime,
total_value=current_portfolio_value,
cash=self.simulator.cash,
positions=current_positions,
price_at_snapshot=price_at_snapshot
)
self.portfolio_snapshots.append(snapshot)
# 9. 更新 `_last_processed_bar_symbol` 和 `last_processed_bar` 为当前 Bar为下一轮循环做准备
self._last_processed_bar_symbol = current_bar.symbol
last_processed_bar = current_bar
# --- 回测结束后的清理工作 ---
print("\n--- 回测结束,检查并平仓所有剩余持仓 ---")
if last_processed_bar: # 确保至少有一根 Bar 被处理过
# 在回测结束时,强制平仓所有可能存在的剩余持仓
# 遍历所有持仓,确保全部清算
remaining_positions_symbols = list(self.simulator.get_current_positions().keys())
for symbol_held in remaining_positions_symbols:
if self.simulator.get_current_positions().get(symbol_held, 0) != 0:
self.strategy.log(f"回测结束清理: 强制平仓合约 {symbol_held} 的剩余持仓。")
# 使用 simulator 的 force_close_all_positions_for_symbol 方法进行清理
self.simulator.force_close_all_positions_for_symbol(symbol_held, last_processed_bar)
self.simulator.cancel_all_pending_orders_for_symbol(symbol_held)
else:
print("没有处理任何 Bar无需平仓。")
# 回测结束后,获取所有交易记录
self.trade_history = self.simulator.get_trade_history()
print("--- 回测结束 ---")
print(f"总计处理了 {len(self.all_bars)} 根K线。")
print(f"总计发生了 {len(self.trade_history)} 笔交易。")
final_portfolio_value = 0.0
if last_processed_bar:
final_portfolio_value = self.simulator.get_portfolio_value(last_processed_bar)
else:
final_portfolio_value = self.initial_capital
total_return_percentage = ((final_portfolio_value - self.initial_capital) / self.initial_capital) * 100
print(f"最终总净值: {final_portfolio_value:.2f}")
print(f"总收益率: {total_return_percentage:.2f}%")
def get_backtest_results(self) -> Dict[str, Any]:
"""
返回回测结果数据,供结果分析模块使用。
"""
return {
"portfolio_snapshots": self.portfolio_snapshots,
"trade_history": self.trade_history,
"initial_capital": self.simulator.initial_capital,
"all_bars": self.all_bars
}
def get_simulator(self) -> ExecutionSimulator:
return self.simulator
def get_bar_history(self):
return self.all_bars
def get_price_history(self, key: str):
if key == 'close':
return self.close_list
elif key == 'open':
return self.open_list
elif key == 'high':
return self.high_list
elif key == 'low':
return self.low_list
elif key == 'volume':
return self.volume_list
return None