138 lines
6.1 KiB
Python
138 lines
6.1 KiB
Python
|
|
# src/backtest_engine.py
|
|||
|
|
|
|||
|
|
from typing import Type, Dict, Any, List
|
|||
|
|
import pandas as pd
|
|||
|
|
|
|||
|
|
# 导入所有需要协调的模块
|
|||
|
|
from .core_data import Bar, Order, Trade, PortfolioSnapshot
|
|||
|
|
from .data_manager import DataManager
|
|||
|
|
from .execution_simulator import ExecutionSimulator
|
|||
|
|
from .backtest_context import BacktestContext
|
|||
|
|
from .strategies.base_strategy import Strategy # 导入策略基类
|
|||
|
|
|
|||
|
|
|
|||
|
|
class BacktestEngine:
|
|||
|
|
"""
|
|||
|
|
回测引擎:协调数据流、策略执行、订单模拟和结果记录。
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
def __init__(self,
|
|||
|
|
data_manager: DataManager,
|
|||
|
|
strategy_class: Type[Strategy],
|
|||
|
|
strategy_params: Dict[str, Any],
|
|||
|
|
initial_capital: float = 100000.0,
|
|||
|
|
slippage_rate: float = 0.0001,
|
|||
|
|
commission_rate: float = 0.0002):
|
|||
|
|
"""
|
|||
|
|
初始化回测引擎。
|
|||
|
|
|
|||
|
|
Args:
|
|||
|
|
data_manager (DataManager): 已经初始化好的数据管理器实例。
|
|||
|
|
strategy_class (Type[Strategy]): 策略类(而不是实例),引擎会负责实例化。
|
|||
|
|
strategy_params (Dict[str, Any]): 传递给策略的参数字典。
|
|||
|
|
initial_capital (float): 初始交易资金。
|
|||
|
|
slippage_rate (float): 交易滑点率。
|
|||
|
|
commission_rate (float): 交易佣金率。
|
|||
|
|
"""
|
|||
|
|
self.data_manager = data_manager
|
|||
|
|
self.simulator = ExecutionSimulator(
|
|||
|
|
initial_capital=initial_capital,
|
|||
|
|
slippage_rate=slippage_rate,
|
|||
|
|
commission_rate=commission_rate
|
|||
|
|
)
|
|||
|
|
self.context = BacktestContext(self.data_manager, self.simulator)
|
|||
|
|
|
|||
|
|
# 实例化策略
|
|||
|
|
self.strategy = strategy_class(self.context, **strategy_params)
|
|||
|
|
|
|||
|
|
self.portfolio_snapshots: List[PortfolioSnapshot] = [] # 存储每天的投资组合快照
|
|||
|
|
self.trade_history: List[Trade] = [] # 存储所有成交记录
|
|||
|
|
self.all_bars: List[Bar] = []
|
|||
|
|
|
|||
|
|
# 历史Bar缓存,用于特征计算
|
|||
|
|
self._history_bars: List[Bar] = []
|
|||
|
|
self._max_history_bars: int = 200 # 例如,只保留最近200根Bar的历史数据,可根据策略需求调整
|
|||
|
|
|
|||
|
|
print("\n--- 回测引擎初始化完成 ---")
|
|||
|
|
print(f" 策略: {strategy_class.__name__}")
|
|||
|
|
print(f" 初始资金: {initial_capital:.2f}")
|
|||
|
|
|
|||
|
|
def run_backtest(self):
|
|||
|
|
"""
|
|||
|
|
运行整个回测流程。
|
|||
|
|
"""
|
|||
|
|
print("\n--- 回测开始 ---")
|
|||
|
|
|
|||
|
|
# 调用策略的初始化方法
|
|||
|
|
self.strategy.on_init()
|
|||
|
|
|
|||
|
|
# 主回测循环
|
|||
|
|
while True:
|
|||
|
|
current_bar = self.data_manager.get_next_bar()
|
|||
|
|
if current_bar is None:
|
|||
|
|
break # 没有更多数据,回测结束
|
|||
|
|
|
|||
|
|
# 设置当前Bar到Context,供策略访问
|
|||
|
|
self.context.set_current_bar(current_bar)
|
|||
|
|
|
|||
|
|
# 更新历史Bar缓存
|
|||
|
|
self._history_bars.append(current_bar)
|
|||
|
|
if len(self._history_bars) > self._max_history_bars:
|
|||
|
|
self._history_bars.pop(0) # 移除最旧的Bar
|
|||
|
|
|
|||
|
|
# 1. 计算特征 (使用纯函数)
|
|||
|
|
# 注意: extract_bar_features 接收的是完整的历史数据,不包含当前Bar
|
|||
|
|
# 但为了简单起见,这里传入的是包含当前bar在内的历史数据,但内部函数应确保不使用“未来”数据
|
|||
|
|
# 严格来说,应该传入 self._history_bars[:-1]
|
|||
|
|
# features = extract_bar_features(current_bar, self._history_bars[:-1]) # 传入当前Bar之前的所有历史Bar
|
|||
|
|
|
|||
|
|
# 2. 调用策略的 on_bar 方法
|
|||
|
|
self.strategy.on_bar(current_bar)
|
|||
|
|
|
|||
|
|
# 3. 记录投资组合快照
|
|||
|
|
current_portfolio_value = self.simulator.get_portfolio_value(current_bar)
|
|||
|
|
current_positions = self.simulator.get_current_positions()
|
|||
|
|
|
|||
|
|
# 创建 PortfolioSnapshot,记录当前Bar的收盘价
|
|||
|
|
price_at_snapshot = {
|
|||
|
|
current_bar.symbol if hasattr(current_bar, 'symbol') else "DEFAULT_SYMBOL": current_bar.close}
|
|||
|
|
|
|||
|
|
snapshot = PortfolioSnapshot(
|
|||
|
|
datetime=current_bar.datetime,
|
|||
|
|
total_value=current_portfolio_value,
|
|||
|
|
cash=self.simulator.cash,
|
|||
|
|
positions=current_positions,
|
|||
|
|
price_at_snapshot=price_at_snapshot
|
|||
|
|
)
|
|||
|
|
self.portfolio_snapshots.append(snapshot)
|
|||
|
|
self.all_bars.append(current_bar)
|
|||
|
|
|
|||
|
|
# 记录交易历史(从模拟器获取)
|
|||
|
|
# 简化处理:每次获取模拟器中的所有交易历史,并更新引擎的trade_history
|
|||
|
|
# 更好的做法是模拟器提供一个方法,返回自上次查询以来的新增交易
|
|||
|
|
# 这里为了不重复添加,可以在 trade_log 中只添加当前 Bar 生成的交易
|
|||
|
|
|
|||
|
|
# 在 on_bar 循环的末尾,获取本Bar周期内新产生的交易
|
|||
|
|
# 模拟器在每次send_order成功时会将trade添加到其trade_log
|
|||
|
|
# 这里可以做一个增量获取,或者简单地在循环结束后统一获取
|
|||
|
|
# 目前我们在执行模拟器中已经将成交记录在了 trade_log 中,所以这里不用重复记录,
|
|||
|
|
# 而是等到回测结束后再统一获取。
|
|||
|
|
pass # 不在此处记录 self.trade_history
|
|||
|
|
|
|||
|
|
# 回测结束后,获取所有交易记录
|
|||
|
|
self.trade_history = self.simulator.get_trade_history()
|
|||
|
|
|
|||
|
|
print("--- 回测结束 ---")
|
|||
|
|
print(f"总计处理了 {len(self.portfolio_snapshots)} 根K线。")
|
|||
|
|
print(f"总计发生了 {len(self.trade_history)} 笔交易。")
|
|||
|
|
|
|||
|
|
def get_backtest_results(self) -> Dict[str, Any]:
|
|||
|
|
"""
|
|||
|
|
返回回测结果数据,供结果分析模块使用。
|
|||
|
|
"""
|
|||
|
|
return {
|
|||
|
|
"portfolio_snapshots": self.portfolio_snapshots,
|
|||
|
|
"trade_history": self.trade_history,
|
|||
|
|
"initial_capital": self.simulator.initial_capital, # 或 self.initial_capital
|
|||
|
|
"all_bars": self.all_bars
|
|||
|
|
}
|