主力合约回测
This commit is contained in:
@@ -10,5 +10,9 @@ print(ls)
|
||||
ls = api.query_quotes(ins_class="CONT", product_id="au")
|
||||
print(ls)
|
||||
|
||||
quote = api.get_quote("KQ.m@SHFE.rb")
|
||||
# 打印现在螺纹钢主连的标的合约
|
||||
print(quote.underlying_symbol)
|
||||
|
||||
# 关闭api,释放相应资源
|
||||
api.close()
|
||||
@@ -2,7 +2,13 @@ import traceback
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from tqsdk import TqApi, TqAuth, TqBacktest, TqSim # 确保导入所有需要的回测/模拟类
|
||||
from tqsdk import (
|
||||
TqApi,
|
||||
TqAuth,
|
||||
TqBacktest,
|
||||
TqSim,
|
||||
BacktestFinished,
|
||||
) # 确保导入所有需要的回测/模拟类
|
||||
import os
|
||||
import datetime
|
||||
from datetime import date # 导入 datetime.date
|
||||
@@ -13,18 +19,18 @@ from datetime import date # 导入 datetime.date
|
||||
TQ_USER_NAME = "emanresu" # 例如: "123456"
|
||||
TQ_PASSWORD = "dfgvfgdfgg" # 例如: "your_password"
|
||||
|
||||
BEIJING_TZ = 'Asia/Shanghai'
|
||||
BEIJING_TZ = "Asia/Shanghai"
|
||||
|
||||
|
||||
def collect_and_save_tqsdk_data_stream(
|
||||
symbol: str,
|
||||
freq: str,
|
||||
start_date_str: str,
|
||||
end_date_str: str,
|
||||
mode: str = "backtest", # 默认为回测模式,因为获取历史数据通常用于回测
|
||||
output_dir: str = "../data",
|
||||
tq_user: str = TQ_USER_NAME,
|
||||
tq_pwd: str = TQ_PASSWORD
|
||||
symbol: str,
|
||||
freq: str,
|
||||
start_date_str: str,
|
||||
end_date_str: str,
|
||||
mode: str = "backtest", # 默认为回测模式,因为获取历史数据通常用于回测
|
||||
output_dir: str = "../data",
|
||||
tq_user: str = TQ_USER_NAME,
|
||||
tq_pwd: str = TQ_PASSWORD,
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
通过 TqSdk 在指定模式下(回测或模拟)运行,监听并收集指定品种、频率、日期范围的K线数据流,
|
||||
@@ -52,16 +58,20 @@ def collect_and_save_tqsdk_data_stream(
|
||||
collected_data = [] # 用于收集每一根完整K线的数据
|
||||
|
||||
try:
|
||||
start_dt_data_obj = datetime.datetime.strptime(start_date_str, '%Y-%m-%d')
|
||||
end_dt_data_obj = datetime.datetime.strptime(end_date_str, '%Y-%m-%d')
|
||||
start_dt_data_obj = datetime.datetime.strptime(start_date_str, "%Y-%m-%d")
|
||||
end_dt_data_obj = datetime.datetime.strptime(end_date_str, "%Y-%m-%d")
|
||||
|
||||
if mode == "backtest":
|
||||
backtest_start_date = start_dt_data_obj.date()
|
||||
backtest_end_date = end_dt_data_obj.date()
|
||||
print(f"初始化天勤回测API,回测日期范围:{backtest_start_date} 至 {backtest_end_date}")
|
||||
print(
|
||||
f"初始化天勤回测API,回测日期范围:{backtest_start_date} 至 {backtest_end_date}"
|
||||
)
|
||||
api = TqApi(
|
||||
backtest=TqBacktest(start_dt=backtest_start_date, end_dt=backtest_end_date),
|
||||
auth=TqAuth(tq_user, tq_pwd)
|
||||
backtest=TqBacktest(
|
||||
start_dt=backtest_start_date, end_dt=backtest_end_date
|
||||
),
|
||||
auth=TqAuth(tq_user, tq_pwd),
|
||||
)
|
||||
elif mode == "sim":
|
||||
print("初始化天勤模拟/实盘API")
|
||||
@@ -83,21 +93,31 @@ def collect_and_save_tqsdk_data_stream(
|
||||
elif freq == "month":
|
||||
duration_seconds = 30 * 24 * 60 * 60 # 大约一个月
|
||||
else:
|
||||
print(f"错误: 不支持的数据频率 '{freq}'。目前支持 '1min', '5min', 'day', 'week', 'month'。")
|
||||
print(
|
||||
f"错误: 不支持的数据频率 '{freq}'。目前支持 '1min', '5min', 'day', 'week', 'month'。"
|
||||
)
|
||||
print("注意:Tick数据量巨大,不建议用此方法直接收集,因为它会耗尽内存。")
|
||||
return None
|
||||
|
||||
# 获取K线序列,这里获取的是指定频率的K线,天勤会根据模式从历史或实时流中推送
|
||||
klines = api.get_kline_serial(symbol, duration_seconds)
|
||||
quote = api.get_quote(symbol=symbol)
|
||||
underlying_symbol = quote.underlying_symbol
|
||||
|
||||
print(f"开始在 '{mode}' 模式下收集 {symbol} 从 {start_date_str} 到 {end_date_str} 的 {freq} 数据...")
|
||||
print(
|
||||
f"开始在 '{mode}' 模式下收集 {symbol} 从 {start_date_str} 到 {end_date_str} 的 {freq} 数据..."
|
||||
)
|
||||
|
||||
last_kline_datetime = None # 用于跟踪上一根已完成K线的时间
|
||||
|
||||
while api.wait_update():
|
||||
if underlying_symbol is None:
|
||||
underlying_symbol = quote.underlying_symbol
|
||||
|
||||
# 检查是否有新的完整K线生成,或者当前K线是最后一次更新 (在回测结束时)
|
||||
# TqSdk会在K线完成时发送最后一次更新,或者在回测结束时强制更新
|
||||
if api.is_changing(quote, "underlying_symbol"):
|
||||
underlying_symbol = quote.underlying_symbol
|
||||
if api.is_changing(klines):
|
||||
# 只有当K线序列发生变化时才处理
|
||||
# 关注最新一根 K 线(即 klines.iloc[-1])
|
||||
@@ -107,32 +127,44 @@ def collect_and_save_tqsdk_data_stream(
|
||||
# 判断当前K线是否已经结束 (is_last=True) 并且与上一次保存的K线不同
|
||||
# 或者,在回测模式下,回测结束时,最后一根K线也会被视为“完成”
|
||||
# 判断条件:K线时间戳不是 None 且 大于上一次记录的 K线时间
|
||||
if not pd.isna(current_kline['datetime']) and (last_kline_datetime is None or (
|
||||
last_kline_datetime is not None and current_kline['datetime'] > last_kline_datetime)):
|
||||
if not pd.isna(current_kline["datetime"]) and (
|
||||
last_kline_datetime is None
|
||||
or (
|
||||
last_kline_datetime is not None
|
||||
and current_kline["datetime"] > last_kline_datetime
|
||||
)
|
||||
):
|
||||
# 将datetime (微秒) 转换为可读格式
|
||||
|
||||
# 检查K线的时间戳是否在我们要获取的日期范围内
|
||||
# 注意:get_kline_serial 会获取指定范围前后的一小段数据,我们需要过滤
|
||||
|
||||
kline_dt = pd.to_datetime(current_kline['datetime'], unit='ns', utc=True)
|
||||
kline_dt = kline_dt.tz_convert(BEIJING_TZ).strftime('%Y-%m-%d %H:%M:%S')
|
||||
kline_dt = pd.to_datetime(
|
||||
current_kline["datetime"], unit="ns", utc=True
|
||||
)
|
||||
kline_dt = kline_dt.tz_convert(BEIJING_TZ).strftime(
|
||||
"%Y-%m-%d %H:%M:%S"
|
||||
)
|
||||
kline_data_to_save = {
|
||||
'datetime': kline_dt,
|
||||
'open': current_kline['open'],
|
||||
'high': current_kline['high'],
|
||||
'low': current_kline['low'],
|
||||
'close': current_kline['close'],
|
||||
'volume': current_kline['volume'],
|
||||
'open_oi': current_kline['open_oi'],
|
||||
'close_oi': current_kline['close_oi']
|
||||
"datetime": kline_dt,
|
||||
"open": current_kline["open"],
|
||||
"high": current_kline["high"],
|
||||
"low": current_kline["low"],
|
||||
"close": current_kline["close"],
|
||||
"volume": current_kline["volume"],
|
||||
"open_oi": current_kline["open_oi"],
|
||||
"close_oi": current_kline["close_oi"],
|
||||
"underlying_symbol": underlying_symbol,
|
||||
}
|
||||
|
||||
collected_data.append(kline_data_to_save)
|
||||
last_kline_datetime = current_kline['datetime']
|
||||
last_kline_datetime = current_kline["datetime"]
|
||||
# print(f"收集到 K线: {kline_dt}, close: {current_kline['close']}") # 用于调试
|
||||
|
||||
# 在回测模式下,当回测结束时,api.wait_update() 会抛出异常,此时我们可以退出循环
|
||||
if api.is_changing(api.get_account()) or api.is_changing(api.get_position()):
|
||||
if api.is_changing(api.get_account()) or api.is_changing(
|
||||
api.get_position()
|
||||
):
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
@@ -146,16 +178,19 @@ def collect_and_save_tqsdk_data_stream(
|
||||
# 无论如何,都尝试处理剩余数据并保存
|
||||
finally:
|
||||
if collected_data:
|
||||
df = pd.DataFrame(collected_data).set_index('datetime')
|
||||
df = pd.DataFrame(collected_data).set_index("datetime")
|
||||
df = df.sort_index() # 确保数据按时间排序
|
||||
|
||||
# 构造保存路径
|
||||
freq_folder = freq.replace("min", "m") if "min" in freq else freq
|
||||
if freq == "day": freq_folder = "daily"
|
||||
if freq == "week": freq_folder = "weekly"
|
||||
if freq == "month": freq_folder = "monthly"
|
||||
if freq == "day":
|
||||
freq_folder = "daily"
|
||||
if freq == "week":
|
||||
freq_folder = "weekly"
|
||||
if freq == "month":
|
||||
freq_folder = "monthly"
|
||||
|
||||
safe_symbol = symbol.replace('.', '_')
|
||||
safe_symbol = symbol.replace(".", "_")
|
||||
|
||||
save_folder = os.path.join(output_dir, safe_symbol)
|
||||
os.makedirs(save_folder, exist_ok=True)
|
||||
@@ -163,6 +198,7 @@ def collect_and_save_tqsdk_data_stream(
|
||||
file_name = f"{safe_symbol}_{freq}.csv"
|
||||
file_path = os.path.join(save_folder, file_name)
|
||||
|
||||
print(df.head())
|
||||
df.to_csv(file_path, index=True)
|
||||
print(f"数据已成功保存到: {file_path}, 共 {len(df)} 条记录。")
|
||||
|
||||
@@ -175,6 +211,7 @@ def collect_and_save_tqsdk_data_stream(
|
||||
api.close()
|
||||
return None
|
||||
|
||||
|
||||
# --- 示例用法 ---
|
||||
if __name__ == "__main__":
|
||||
import os
|
||||
@@ -187,16 +224,17 @@ if __name__ == "__main__":
|
||||
TQ_USER_NAME = "emanresu" # 例如: "123456"
|
||||
TQ_PASSWORD = "dfgvfgdfgg" # 例如: "your_password"
|
||||
|
||||
# 示例1: 在回测模式下获取沪深300指数主连的日线数据 (用于历史回测)
|
||||
# 这种方式适合获取相对较短或中等长度的历史K线数据。
|
||||
df_if_backtest_daily = collect_and_save_tqsdk_data_stream(
|
||||
symbol="KQ.i@SHFE.rb",
|
||||
freq="day",
|
||||
symbol="KQ.m@SHFE.rb",
|
||||
# symbol='SHFE.rb2510',
|
||||
# symbol='KQ.i@SHFE.bu',
|
||||
freq="min60",
|
||||
start_date_str="2023-01-01",
|
||||
end_date_str="2025-05-01",
|
||||
end_date_str="2025-06-22",
|
||||
mode="backtest", # 指定为回测模式
|
||||
tq_user=TQ_USER_NAME,
|
||||
tq_pwd=TQ_PASSWORD
|
||||
tq_pwd=TQ_PASSWORD,
|
||||
)
|
||||
if df_if_backtest_daily is not None:
|
||||
print(df_if_backtest_daily.tail())
|
||||
|
||||
859
main.ipynb
859
main.ipynb
File diff suppressed because one or more lines are too long
@@ -31,12 +31,13 @@ class ResultAnalyzer:
|
||||
Args:
|
||||
portfolio_snapshots (List[PortfolioSnapshot]): 回测引擎输出的投资组合快照列表。
|
||||
trade_history (List[Trade]): 回测引擎输出的交易历史记录列表。
|
||||
bars (List[Bar]): 回测引擎输出的所有K线数据列表 (可能包含多个合约的K线)。
|
||||
initial_capital (float): 初始资金。
|
||||
"""
|
||||
self.portfolio_snapshots = portfolio_snapshots
|
||||
self.trade_history = trade_history
|
||||
self.initial_capital = initial_capital
|
||||
self.bars = bars
|
||||
self.bars = bars # 接收所有K线数据
|
||||
self._metrics_cache: Optional[Dict[str, Any]] = None
|
||||
|
||||
print("\n--- 结果分析器初始化完成 ---")
|
||||
@@ -69,30 +70,41 @@ class ResultAnalyzer:
|
||||
print(f"{'夏普比率':<15}: {metrics['夏普比率']:.2f}")
|
||||
print(f"{'卡玛比率':<15}: {metrics['卡玛比率']:.2f}")
|
||||
print(f"{'总交易次数':<15}: {metrics['总交易次数']}")
|
||||
print(f"{'总实现盈亏':<15}: {metrics['总实现盈亏']:.2f}") # 新增
|
||||
print(f"{'交易成本':<15}: {metrics['交易成本']:.2f}")
|
||||
|
||||
# 新增交易相关详细指标,以适应更全面的交易分析需求
|
||||
print("\n--- 交易详情 ---")
|
||||
print(f"{'盈利交易次数':<15}: {metrics['盈利交易次数']}")
|
||||
print(f"{'亏损交易次数':<15}: {metrics['亏损交易次数']}")
|
||||
print(f"{'胜率':<15}: {metrics['胜率']:.2%}")
|
||||
print(f"{'盈亏比':<15}: {metrics['盈亏比']:.2f}")
|
||||
print(f"{'平均每次盈利':<15}: {metrics['平均每次盈利']:.2f}")
|
||||
print(f"{'平均每次亏损':<15}: {metrics['平均每次亏损']:.2f}")
|
||||
|
||||
|
||||
if self.trade_history:
|
||||
print("\n--- 部分交易明细 (最近5笔) ---")
|
||||
for trade in self.trade_history[-5:]:
|
||||
# 调整输出格式,显示实现盈亏
|
||||
pnl_display = f" | PnL: {trade.realized_pnl:.2f}" if trade.is_close_trade else ""
|
||||
print(
|
||||
f" {trade.fill_time} | {trade.direction:<10} | {trade.symbol} | Vol: {trade.volume} | Price: {trade.price:.2f} | Commission: {trade.commission:.2f}"
|
||||
f" {trade.fill_time} | {trade.direction:<10} | {trade.symbol} | Vol: {trade.volume} | Price: {trade.price:.2f} | Comm: {trade.commission:.2f}{pnl_display}"
|
||||
)
|
||||
else:
|
||||
print("\n没有交易记录。")
|
||||
|
||||
def plot_performance(self) -> None:
|
||||
"""
|
||||
绘制投资组合净值和回撤曲线。
|
||||
绘制投资组合净值和回撤曲线,以及所有合约的收盘价曲线。
|
||||
"""
|
||||
print("正在绘制绩效图表...")
|
||||
# plot_performance_chart(self.portfolio_snapshots, self.initial_capital, self.bars)
|
||||
plot_equity_and_drawdown_chart(
|
||||
self.portfolio_snapshots,
|
||||
self.initial_capital,
|
||||
title="Portfolio Equity and Drawdown Curve",
|
||||
title="Portfolio Equity and Drawdown Curve (All Contracts)" # 明确标题,表明是整体曲线
|
||||
)
|
||||
|
||||
# 绘制单独的收盘价曲线
|
||||
plot_close_price_chart(self.bars, title="Underlying Asset Close Price")
|
||||
|
||||
# 绘制所有处理过的K线收盘价曲线
|
||||
plot_close_price_chart(self.bars, title="Underlying Asset Close Price (Concatenated Bars)") # 明确标题
|
||||
print("图表绘制完成。")
|
||||
@@ -1,92 +1,104 @@
|
||||
# src/backtest_context.py
|
||||
|
||||
from typing import Dict, List, Optional
|
||||
import pandas as pd
|
||||
from datetime import datetime
|
||||
from typing import Optional, Any, Dict, TYPE_CHECKING
|
||||
|
||||
# 导入核心数据类
|
||||
from .core_data import Bar, Order, Trade
|
||||
|
||||
# 导入 DataManager 的相关类 (虽然这里不是类,但 BacktestEngine 会传入)
|
||||
# 为了避免循环导入,这里假设 BacktestEngine 会传入 DataManager 实例
|
||||
# 或者,如果 DataManager 是一个函数集合,那么 BacktestContext 需要直接访问这些函数
|
||||
# 在这里,我们假设 DataManager 是一个 OOP 类,或者有一个接口让 Context 可以获取历史Bar
|
||||
from .data_manager import DataManager # 假设 DataManager 是一个类
|
||||
|
||||
# 导入 ExecutionSimulator
|
||||
from .execution_simulator import ExecutionSimulator
|
||||
# 使用 TYPE_CHECKING 避免循环导入,只在类型检查时导入 BacktestEngine
|
||||
if TYPE_CHECKING:
|
||||
from .backtest_engine import BacktestEngine
|
||||
from .execution_simulator import ExecutionSimulator
|
||||
from .data_manager import DataManager
|
||||
from .core_data import Bar, Order # 确保导入 Order
|
||||
|
||||
class BacktestContext:
|
||||
"""
|
||||
回测上下文:作为策略与回测引擎和模拟器之间的桥梁。
|
||||
策略通过Context获取市场数据和发出交易指令。
|
||||
回测上下文,用于连接策略与数据管理器、模拟器。
|
||||
策略通过此上下文与回测引擎进行交互。
|
||||
"""
|
||||
def __init__(self, data_manager: DataManager, simulator: ExecutionSimulator):
|
||||
self._data_manager = data_manager # 引用DataManager实例
|
||||
self._simulator = simulator # 引用ExecutionSimulator实例
|
||||
self._current_bar: Optional[Bar] = None # 当前Bar,由引擎设置
|
||||
def __init__(self, data_manager: 'DataManager', simulator: 'ExecutionSimulator'):
|
||||
"""
|
||||
初始化回测上下文。
|
||||
|
||||
def set_current_bar(self, bar: Bar):
|
||||
"""由回测引擎调用,更新当前Bar。"""
|
||||
Args:
|
||||
data_manager (DataManager): 数据管理器实例。
|
||||
simulator (ExecutionSimulator): 交易模拟器实例。
|
||||
"""
|
||||
self._data_manager = data_manager
|
||||
self._simulator = simulator
|
||||
self._current_bar: Optional['Bar'] = None
|
||||
self._engine: Optional['BacktestEngine'] = None # 添加对引擎的引用
|
||||
|
||||
def set_current_bar(self, bar: 'Bar'):
|
||||
"""
|
||||
设置当前正在处理的 K 线数据。
|
||||
由 BacktestEngine 调用。
|
||||
"""
|
||||
self._current_bar = bar
|
||||
|
||||
def get_current_bar(self) -> Bar:
|
||||
"""获取当前正在处理的Bar对象。"""
|
||||
if self._current_bar is None:
|
||||
raise RuntimeError("当前Bar未设置。请确保在策略on_bar调用前已设置。")
|
||||
def get_current_bar(self) -> Optional['Bar']:
|
||||
"""
|
||||
获取当前正在处理的 K 线数据。
|
||||
策略可以通过此方法获取最新 K 线。
|
||||
"""
|
||||
return self._current_bar
|
||||
|
||||
def get_history_bars(self, num_bars: int) -> List[Bar]:
|
||||
def get_current_time(self) -> datetime:
|
||||
"""
|
||||
获取当前Bar之前的历史Bar列表。
|
||||
通过DataManager获取,确保不包含未来函数。
|
||||
获取当前模拟时间。
|
||||
"""
|
||||
# DataManager需要提供一个方法来获取不包含当前bar的历史数据
|
||||
# 这里假设DataManager已经实现了这个功能,通过内部索引管理
|
||||
# 注意:这里的实现需要和 DataManager 内部逻辑匹配,确保严格的未来函数防范
|
||||
if self._current_bar is None:
|
||||
raise RuntimeError("当前Bar未设置,无法获取历史Bar。")
|
||||
return self._data_manager.get_history_bars(self._current_bar.datetime, num_bars)
|
||||
|
||||
def send_order(self, order: Order) -> Optional[Trade]:
|
||||
"""
|
||||
策略通过此方法发出交易订单。
|
||||
"""
|
||||
if self._current_bar is None:
|
||||
raise RuntimeError("当前Bar未设置,无法发送订单。")
|
||||
# 将订单转发给模拟器执行
|
||||
trade = self._simulator.send_order(order, self._current_bar)
|
||||
# 可以在这里触发策略的on_trade回调(如果策略定义了)
|
||||
return trade
|
||||
|
||||
def cancel_order(self, order_id: str) -> bool:
|
||||
"""
|
||||
策略通过此方法发出交易订单。
|
||||
"""
|
||||
if self._current_bar is None:
|
||||
raise RuntimeError("当前Bar未设置,无法发送订单。")
|
||||
|
||||
return self._simulator.cancel_order(order_id)
|
||||
|
||||
return self._simulator.get_current_time()
|
||||
|
||||
def get_current_positions(self) -> Dict[str, int]:
|
||||
"""
|
||||
获取当前模拟器的持仓情况。
|
||||
获取当前所有持仓。
|
||||
"""
|
||||
return self._simulator.get_current_positions()
|
||||
|
||||
def get_current_cash(self) -> float:
|
||||
def get_pending_orders(self) -> Dict[str, 'Order']:
|
||||
"""
|
||||
获取当前模拟器的可用资金。
|
||||
获取当前所有待处理(未成交)订单。
|
||||
"""
|
||||
return self._simulator.get_pending_orders()
|
||||
|
||||
def get_account_cash(self) -> float:
|
||||
"""
|
||||
获取当前可用现金。
|
||||
"""
|
||||
return self._simulator.cash
|
||||
|
||||
def get_current_portfolio_value(self, current_bar: Bar) -> float:
|
||||
def get_average_position_price(self, symbol: str) -> Optional[float]:
|
||||
"""
|
||||
获取当前的投资组合总价值(包括现金和持仓市值)。
|
||||
Args:
|
||||
current_bar (Bar): 当前的Bar数据,用于计算持仓市值。
|
||||
Returns:
|
||||
float: 当前的投资组合总价值。
|
||||
获取指定合约的平均持仓成本。
|
||||
"""
|
||||
# 调用底层模拟器的方法来获取投资组合价值
|
||||
return self._simulator.get_portfolio_value(current_bar)
|
||||
return self._simulator.get_average_position_price(symbol)
|
||||
|
||||
def send_order(self, order: 'Order') -> Optional['Order']:
|
||||
"""
|
||||
策略通过此方法发送订单到模拟器。
|
||||
"""
|
||||
return self._simulator.send_order_to_pending(order)
|
||||
|
||||
def cancel_order(self, order_id: str) -> bool:
|
||||
"""
|
||||
策略通过此方法取消指定ID的订单。
|
||||
"""
|
||||
return self._simulator.cancel_order(order_id)
|
||||
|
||||
def set_engine(self, engine: 'BacktestEngine'):
|
||||
"""
|
||||
设置对 BacktestEngine 实例的引用。
|
||||
由 BacktestEngine 在初始化时调用,用于允许 Context 访问 Engine 的状态。
|
||||
"""
|
||||
self._engine = engine
|
||||
|
||||
@property
|
||||
def is_rollover_bar(self) -> bool:
|
||||
"""
|
||||
属性:判断当前 K 线是否为换月 K 线(即新合约的第一根 K 线)。
|
||||
用于在换月时禁止策略开仓。
|
||||
"""
|
||||
if self._engine:
|
||||
return self._engine.is_rollover_bar
|
||||
# 如果没有设置引擎引用,默认不认为是换月 K 线
|
||||
# 这通常发生在测试 Context 本身时,或 Engine 初始化不完整的情况。
|
||||
return False
|
||||
@@ -1,6 +1,6 @@
|
||||
# src/backtest_engine.py
|
||||
|
||||
from typing import Type, Dict, Any, List
|
||||
from typing import Type, Dict, Any, List, Optional
|
||||
import pandas as pd
|
||||
|
||||
# 导入所有需要协调的模块
|
||||
@@ -8,22 +8,21 @@ from .core_data import Bar, Order, Trade, PortfolioSnapshot
|
||||
from .data_manager import DataManager
|
||||
from .execution_simulator import ExecutionSimulator
|
||||
from .backtest_context import BacktestContext
|
||||
from .strategies.base_strategy import Strategy # 导入策略基类
|
||||
|
||||
from .strategies.base_strategy import Strategy
|
||||
|
||||
class BacktestEngine:
|
||||
"""
|
||||
回测引擎:协调数据流、策略执行、订单模拟和结果记录。
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
data_manager: DataManager,
|
||||
strategy_class: Type[Strategy],
|
||||
strategy_params: Dict[str, Any],
|
||||
current_segment_symbol: str,
|
||||
# current_segment_symbol: str, # 这个参数不再需要,因为 symbol 会动态更新
|
||||
initial_capital: float = 100000.0,
|
||||
slippage_rate: float = 0.0001,
|
||||
commission_rate: float = 0.0002):
|
||||
commission_rate: float = 0.0002,
|
||||
roll_over_mode: bool = False): # 新增换月模式参数
|
||||
"""
|
||||
初始化回测引擎。
|
||||
|
||||
@@ -34,6 +33,7 @@ class BacktestEngine:
|
||||
initial_capital (float): 初始交易资金。
|
||||
slippage_rate (float): 交易滑点率。
|
||||
commission_rate (float): 交易佣金率。
|
||||
roll_over_mode (bool): 是否启用主连合约换月模式。
|
||||
"""
|
||||
self.data_manager = data_manager
|
||||
self.initial_capital = initial_capital
|
||||
@@ -42,64 +42,110 @@ class BacktestEngine:
|
||||
slippage_rate=slippage_rate,
|
||||
commission_rate=commission_rate
|
||||
)
|
||||
# 传入引擎自身给 context,以便 context 可以获取引擎的状态(如 is_rollover_bar)
|
||||
self.context = BacktestContext(self.data_manager, self.simulator)
|
||||
self.current_segment_symbol = current_segment_symbol
|
||||
self.context.set_engine(self) # 建立 Context 到 Engine 的引用
|
||||
|
||||
# 实例化策略
|
||||
self.strategy = strategy_class(self.context, **strategy_params)
|
||||
# self.current_segment_symbol = current_segment_symbol # 此行移除或作为内部变量动态管理
|
||||
|
||||
self.portfolio_snapshots: List[PortfolioSnapshot] = [] # 存储每天的投资组合快照
|
||||
self.trade_history: List[Trade] = [] # 存储所有成交记录
|
||||
# 实例化策略。初始 symbol 会在 run_backtest 中根据第一根 Bar 动态设置。
|
||||
self.strategy = strategy_class(self.context, symbol="INITIAL_PLACEHOLDER_SYMBOL", **strategy_params)
|
||||
|
||||
self.portfolio_snapshots: List[PortfolioSnapshot] = []
|
||||
self.trade_history: List[Trade] = []
|
||||
self.all_bars: List[Bar] = []
|
||||
|
||||
# 历史Bar缓存,用于特征计算
|
||||
self._history_bars: List[Bar] = []
|
||||
self._max_history_bars: int = 200 # 例如,只保留最近200根Bar的历史数据,可根据策略需求调整
|
||||
self._history_bars: List[Bar] = [] # 引擎层面保留的历史 Bar,通常供策略在 on_bar 中使用
|
||||
self._max_history_bars: int = strategy_params.get('history_bars_limit', 200)
|
||||
|
||||
# 换月相关状态
|
||||
self.roll_over_mode = roll_over_mode # 是否启用换月模式
|
||||
self._last_processed_bar_symbol: Optional[str] = None # 记录上一根 K 线的 symbol
|
||||
self.is_rollover_bar: bool = False # 标记当前 K 线是否为换月 K 线(禁止开仓)
|
||||
|
||||
print("\n--- 回测引擎初始化完成 ---")
|
||||
print(f" 策略: {strategy_class.__name__}")
|
||||
print(f" 初始资金: {initial_capital:.2f}")
|
||||
print(f" 换月模式: {'启用' if roll_over_mode else '禁用'}")
|
||||
|
||||
def run_backtest(self):
|
||||
"""
|
||||
运行整个回测流程。
|
||||
运行整个回测流程,包含换月逻辑。
|
||||
"""
|
||||
print("\n--- 回测开始 ---")
|
||||
|
||||
# 调用策略的初始化方法
|
||||
self.strategy.on_init()
|
||||
|
||||
last_processed_bar: Optional[Bar] = None # 用于在换月时引用旧合约的最后一根 K 线
|
||||
|
||||
# 主回测循环
|
||||
while True:
|
||||
current_bar = self.data_manager.get_next_bar()
|
||||
if current_bar is None:
|
||||
break # 没有更多数据,回测结束
|
||||
|
||||
# 设置当前Bar到Context,供策略访问
|
||||
# --- 换月逻辑判断和处理 (在处理 current_bar 之前进行) ---
|
||||
# 1. 重置 is_rollover_bar 标记
|
||||
self.is_rollover_bar = False
|
||||
|
||||
# 2. 如果启用换月模式,并且检测到合约 symbol 变化
|
||||
if current_bar.symbol != self._last_processed_bar_symbol:
|
||||
print(self.roll_over_mode,
|
||||
self._last_processed_bar_symbol,
|
||||
current_bar.symbol, self._last_processed_bar_symbol)
|
||||
if self.roll_over_mode and \
|
||||
self._last_processed_bar_symbol is not None and \
|
||||
current_bar.symbol != self._last_processed_bar_symbol:
|
||||
|
||||
old_symbol = self._last_processed_bar_symbol
|
||||
new_symbol = current_bar.symbol
|
||||
|
||||
# 确认 last_processed_bar 确实是旧合约的最后一根 K 线
|
||||
if last_processed_bar and last_processed_bar.symbol == old_symbol:
|
||||
self.strategy.log(f"检测到换月!从 [{old_symbol}] 切换到 [{new_symbol}]。"
|
||||
f"在旧合约最后一根K线 ({last_processed_bar.datetime}) 执行强制平仓和取消操作。")
|
||||
|
||||
# A. 强制平仓旧合约的所有持仓
|
||||
self.simulator.force_close_all_positions_for_symbol(old_symbol, last_processed_bar)
|
||||
|
||||
# B. 取消旧合约的所有挂单
|
||||
self.simulator.cancel_all_pending_orders_for_symbol(old_symbol)
|
||||
|
||||
# C. 标记【当前这根 Bar (即新合约的第一根 K 线)】为换月 K 线
|
||||
# 此时 self.is_rollover_bar 变为 True,将通过 Context 传递给策略,
|
||||
# 策略在该 K 线周期内不能开仓。
|
||||
self.is_rollover_bar = True
|
||||
|
||||
# D. 通知策略换月事件,让策略有机会重置内部状态
|
||||
self.strategy.on_rollover(old_symbol, new_symbol)
|
||||
else:
|
||||
self.strategy.log(f"警告: 检测到换月从 {old_symbol} 到 {new_symbol},但 last_processed_bar 为空或与旧合约不符。"
|
||||
"强制平仓/取消操作可能未正确执行。")
|
||||
|
||||
# 3. 更新策略关注的当前合约 symbol
|
||||
self.strategy.symbol = current_bar.symbol
|
||||
|
||||
# 4. 更新 Context 和 Simulator 的当前 Bar 和时间
|
||||
self.context.set_current_bar(current_bar)
|
||||
self.simulator.update_time(current_time=current_bar.datetime)
|
||||
|
||||
# 更新历史Bar缓存
|
||||
# 5. 更新引擎内部的历史 Bar 缓存
|
||||
self._history_bars.append(current_bar)
|
||||
if len(self._history_bars) > self._max_history_bars:
|
||||
self._history_bars.pop(0) # 移除最旧的Bar
|
||||
self._history_bars.pop(0)
|
||||
|
||||
# 1. 计算特征 (使用纯函数)
|
||||
# 注意: extract_bar_features 接收的是完整的历史数据,不包含当前Bar
|
||||
# 但为了简单起见,这里传入的是包含当前bar在内的历史数据,但内部函数应确保不使用“未来”数据
|
||||
# 严格来说,应该传入 self._history_bars[:-1]
|
||||
# features = extract_bar_features(current_bar, self._history_bars[:-1]) # 传入当前Bar之前的所有历史Bar
|
||||
# 6. 处理待撮合订单 (在调用策略 on_bar 之前,确保订单在当前 K 线开盘价撮合)
|
||||
self.simulator.process_pending_orders(current_bar)
|
||||
|
||||
# 2. 调用策略的 on_bar 方法
|
||||
# 7. 调用策略的 on_bar 方法
|
||||
self.strategy.on_bar(current_bar)
|
||||
|
||||
# 3. 记录投资组合快照
|
||||
# 8. 记录投资组合快照
|
||||
current_portfolio_value = self.simulator.get_portfolio_value(current_bar)
|
||||
current_positions = self.simulator.get_current_positions()
|
||||
|
||||
# 创建 PortfolioSnapshot,记录当前Bar的收盘价
|
||||
price_at_snapshot = {
|
||||
current_bar.symbol if hasattr(current_bar, 'symbol') else "DEFAULT_SYMBOL": current_bar.close}
|
||||
price_at_snapshot = {current_bar.symbol: current_bar.close} # 使用当前 Bar 的收盘价记录快照
|
||||
|
||||
snapshot = PortfolioSnapshot(
|
||||
datetime=current_bar.datetime,
|
||||
@@ -111,49 +157,36 @@ class BacktestEngine:
|
||||
self.portfolio_snapshots.append(snapshot)
|
||||
self.all_bars.append(current_bar)
|
||||
|
||||
# 9. 更新 `_last_processed_bar_symbol` 和 `last_processed_bar` 为当前 Bar,为下一轮循环做准备
|
||||
self._last_processed_bar_symbol = current_bar.symbol
|
||||
last_processed_bar = current_bar
|
||||
|
||||
# 记录交易历史(从模拟器获取)
|
||||
# 简化处理:每次获取模拟器中的所有交易历史,并更新引擎的trade_history
|
||||
# 更好的做法是模拟器提供一个方法,返回自上次查询以来的新增交易
|
||||
# 这里为了不重复添加,可以在 trade_log 中只添加当前 Bar 生成的交易
|
||||
|
||||
# 在 on_bar 循环的末尾,获取本Bar周期内新产生的交易
|
||||
# 模拟器在每次send_order成功时会将trade添加到其trade_log
|
||||
# 这里可以做一个增量获取,或者简单地在循环结束后统一获取
|
||||
# 目前我们在执行模拟器中已经将成交记录在了 trade_log 中,所以这里不用重复记录,
|
||||
# 而是等到回测结束后再统一获取。
|
||||
# 不在此处记录 self.trade_history
|
||||
|
||||
print("\n--- 回测片段结束,检查并平仓所有持仓 ---")
|
||||
if last_processed_bar: # 确保至少有一根Bar被处理过
|
||||
positions_to_close = self.simulator.get_current_positions()
|
||||
for symbol_held, quantity in positions_to_close.items():
|
||||
if quantity != 0:
|
||||
print(f"[{last_processed_bar.datetime}] 回测结束平仓: 平仓 {symbol_held} ({quantity} 手) @ {last_processed_bar.close:.2f}。")
|
||||
direction = "CLOSE_LONG" if quantity > 0 else "CLOSE_SELL"
|
||||
volume = abs(quantity)
|
||||
|
||||
# 使用当前合约的最后一根Bar的价格进行平仓
|
||||
# 注意:这里假设平仓的symbol_held就是当前segment的symbol
|
||||
# 如果策略可能同时持有其他旧合约的仓位(多主力同时持有),这里需要更复杂的逻辑来获取正确的平仓价格
|
||||
# 但在主力合约切换场景下,通常只持有当前主力合约的仓位。
|
||||
rollover_order = Order(symbol=symbol_held, direction=direction, volume=volume, price_type="MARKET")
|
||||
self.simulator.send_order(rollover_order, current_bar=last_processed_bar)
|
||||
# --- 回测结束后的清理工作 ---
|
||||
print("\n--- 回测结束,检查并平仓所有剩余持仓 ---")
|
||||
if last_processed_bar: # 确保至少有一根 Bar 被处理过
|
||||
# 在回测结束时,强制平仓所有可能存在的剩余持仓
|
||||
# 遍历所有持仓,确保全部清算
|
||||
remaining_positions_symbols = list(self.simulator.get_current_positions().keys())
|
||||
for symbol_held in remaining_positions_symbols:
|
||||
if self.simulator.get_current_positions().get(symbol_held, 0) != 0:
|
||||
self.strategy.log(f"回测结束清理: 强制平仓合约 {symbol_held} 的剩余持仓。")
|
||||
# 使用 simulator 的 force_close_all_positions_for_symbol 方法进行清理
|
||||
self.simulator.force_close_all_positions_for_symbol(symbol_held, last_processed_bar)
|
||||
self.simulator.cancel_all_pending_orders_for_symbol(symbol_held)
|
||||
else:
|
||||
print("没有处理任何Bar,无需平仓。")
|
||||
print("没有处理任何 Bar,无需平仓。")
|
||||
|
||||
# 回测结束后,获取所有交易记录
|
||||
self.trade_history = self.simulator.get_trade_history()
|
||||
|
||||
print("--- 回测结束 ---")
|
||||
print(f"总计处理了 {len(self.portfolio_snapshots)} 根K线。")
|
||||
print(f"总计处理了 {len(self.all_bars)} 根K线。")
|
||||
print(f"总计发生了 {len(self.trade_history)} 笔交易。")
|
||||
|
||||
final_portfolio_value = 0.0
|
||||
if last_processed_bar:
|
||||
final_portfolio_value = self.simulator.get_portfolio_value(last_processed_bar)
|
||||
else: # 如果数据为空,或者回测根本没跑,则净值为初始资金
|
||||
else:
|
||||
final_portfolio_value = self.initial_capital
|
||||
|
||||
total_return_percentage = ((final_portfolio_value - self.initial_capital) / self.initial_capital) * 100
|
||||
@@ -168,12 +201,9 @@ class BacktestEngine:
|
||||
return {
|
||||
"portfolio_snapshots": self.portfolio_snapshots,
|
||||
"trade_history": self.trade_history,
|
||||
"initial_capital": self.simulator.initial_capital, # 或 self.initial_capital
|
||||
"initial_capital": self.simulator.initial_capital,
|
||||
"all_bars": self.all_bars
|
||||
}
|
||||
|
||||
def get_simulator(self) -> ExecutionSimulator: # <--- 新增的方法
|
||||
"""
|
||||
返回引擎内部的 ExecutionSimulator 实例,以便外部可以访问和修改其状态。
|
||||
"""
|
||||
def get_simulator(self) -> ExecutionSimulator:
|
||||
return self.simulator
|
||||
@@ -28,7 +28,7 @@ def load_raw_data(file_path: str) -> pd.DataFrame:
|
||||
raise FileNotFoundError(f"数据文件未找到: {file_path}")
|
||||
|
||||
# 定义期望的列名,用于检查和选择
|
||||
expected_cols = ['datetime', 'open', 'high', 'low', 'close', 'volume', 'open_oi', 'close_oi']
|
||||
expected_cols = ['datetime', 'open', 'high', 'low', 'close', 'volume', 'open_oi', 'close_oi', 'underlying_symbol']
|
||||
|
||||
try:
|
||||
# 使用 pandas.read_csv 直接解析 datetime 列
|
||||
@@ -42,7 +42,7 @@ def load_raw_data(file_path: str) -> pd.DataFrame:
|
||||
# 检查所有必需的列是否存在
|
||||
missing_cols = [col for col in expected_cols[1:] if col not in df.columns]
|
||||
if missing_cols:
|
||||
raise KeyError(f"CSV文件中缺少以下必需列: {', '.join(missing_cols)}")
|
||||
print(f"CSV文件中缺少以下列: {', '.join(missing_cols)}")
|
||||
|
||||
# 确保数据按时间排序 (这是回测的基础)
|
||||
df = df.sort_index()
|
||||
@@ -68,19 +68,33 @@ def df_to_bar_stream(df: pd.DataFrame, symbol: str) -> Iterator[Bar]:
|
||||
Bar: 逐个生成的 Bar 对象。
|
||||
"""
|
||||
print("开始将 DataFrame 转换为 Bar 对象流...")
|
||||
print(df)
|
||||
for index, row in df.iterrows():
|
||||
try:
|
||||
bar = Bar(
|
||||
symbol=symbol,
|
||||
datetime=index,
|
||||
open=row['open'],
|
||||
high=row['high'],
|
||||
low=row['low'],
|
||||
close=row['close'],
|
||||
volume=int(row['volume']), # 确保成交量为整数
|
||||
open_oi=int(row['open_oi']), # 确保持仓量为整数
|
||||
close_oi=int(row['close_oi'])
|
||||
)
|
||||
if 'underlying_symbol' in df.columns and row['underlying_symbol'] != '':
|
||||
bar = Bar(
|
||||
symbol=row['underlying_symbol'],
|
||||
datetime=index,
|
||||
open=row['open'],
|
||||
high=row['high'],
|
||||
low=row['low'],
|
||||
close=row['close'],
|
||||
volume=int(row['volume']), # 确保成交量为整数
|
||||
open_oi=int(row['open_oi']), # 确保持仓量为整数
|
||||
close_oi=int(row['close_oi'])
|
||||
)
|
||||
else:
|
||||
bar = Bar(
|
||||
symbol=symbol,
|
||||
datetime=index,
|
||||
open=row['open'],
|
||||
high=row['high'],
|
||||
low=row['low'],
|
||||
close=row['close'],
|
||||
volume=int(row['volume']), # 确保成交量为整数
|
||||
open_oi=int(row['open_oi']), # 确保持仓量为整数
|
||||
close_oi=int(row['close_oi'])
|
||||
)
|
||||
yield bar
|
||||
except (ValueError, TypeError) as e:
|
||||
print(f"警告: 无法为 {index} 时间创建Bar对象,跳过。错误: {e}")
|
||||
|
||||
@@ -1,42 +1,31 @@
|
||||
# src/execution_simulator.py (修改部分)
|
||||
# src/execution_simulator.py
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional
|
||||
import pandas as pd
|
||||
from .core_data import Order, Trade, Bar, PortfolioSnapshot
|
||||
|
||||
|
||||
class ExecutionSimulator:
|
||||
"""
|
||||
模拟交易执行和管理账户资金、持仓。
|
||||
"""
|
||||
|
||||
def __init__(self, initial_capital: float,
|
||||
slippage_rate: float = 0.0001,
|
||||
commission_rate: float = 0.0002,
|
||||
initial_positions: Optional[Dict[str, int]] = None):
|
||||
"""
|
||||
Args:
|
||||
initial_capital (float): 初始资金。
|
||||
slippage_rate (float): 滑点率(相对于成交价格的百分比)。
|
||||
commission_rate (float): 佣金率(相对于成交金额的百分比)。
|
||||
initial_positions (Optional[Dict[str, int]]): 初始持仓,格式为 {symbol: quantity}。
|
||||
"""
|
||||
self.initial_capital = initial_capital
|
||||
self.cash = initial_capital
|
||||
self.positions: Dict[str, int] = initial_positions if initial_positions is not None else {}
|
||||
# 新增:跟踪持仓的平均成本 {symbol: average_cost}
|
||||
self.average_costs: Dict[str, float] = {}
|
||||
# 如果有初始持仓,需要设置初始成本(简化为0,或在外部配置)
|
||||
if initial_positions:
|
||||
for symbol, qty in initial_positions.items():
|
||||
# 初始持仓成本,如果需要精确,应该从外部传入
|
||||
self.average_costs[symbol] = 0.0 # 简化处理,初始持仓成本为0
|
||||
self.average_costs[symbol] = 0.0
|
||||
|
||||
self.slippage_rate = slippage_rate
|
||||
self.commission_rate = commission_rate
|
||||
self.trade_log: List[Trade] = [] # 存储所有成交记录
|
||||
self.pending_orders: Dict[str, Order] = {} # {order_id: Order_object}
|
||||
self._current_time = None
|
||||
self.trade_log: List[Trade] = []
|
||||
self.pending_orders: Dict[str, Order] = {}
|
||||
self._current_time: Optional[datetime] = None
|
||||
|
||||
print(
|
||||
f"模拟器初始化:初始资金={self.initial_capital:.2f}, 滑点率={self.slippage_rate}, 佣金率={self.commission_rate}")
|
||||
@@ -44,20 +33,11 @@ class ExecutionSimulator:
|
||||
print(f"初始持仓:{self.positions}")
|
||||
|
||||
def update_time(self, current_time: datetime):
|
||||
"""
|
||||
更新模拟器的当前时间。
|
||||
这个方法由 BacktestEngine 在遍历 K 线时调用。
|
||||
"""
|
||||
self._current_time = current_time
|
||||
|
||||
# --- 新增的公共方法 ---
|
||||
def get_current_time(self) -> datetime:
|
||||
"""
|
||||
获取模拟器的当前时间。
|
||||
"""
|
||||
if self._current_time is None:
|
||||
# 可以在这里抛出错误或者返回一个默认值,取决于你对未初始化时间的处理
|
||||
# 抛出错误可以帮助你发现问题,例如在模拟器时间未设置时就尝试获取
|
||||
# 改进:如果时间未设置,可以抛出错误,防止策略在 on_init 阶段意外调用
|
||||
# raise RuntimeError("Simulator time has not been set. Ensure update_time is called.")
|
||||
return None
|
||||
return self._current_time
|
||||
@@ -65,126 +45,100 @@ class ExecutionSimulator:
|
||||
def _calculate_fill_price(self, order: Order, current_bar: Bar) -> float:
|
||||
"""
|
||||
内部方法:根据订单类型和滑点计算实际成交价格。
|
||||
- 市价单通常以当前K线的开盘价成交(考虑滑点)。
|
||||
- 限价单判断是否触及限价,如果触及,以限价成交(考虑滑点)。
|
||||
撮合逻辑:所有订单(市价/限价)都以当前K线的 **开盘价 (open)** 为基准进行撮合。
|
||||
"""
|
||||
fill_price = -1.0 # 默认未成交
|
||||
|
||||
base_price = current_bar.open # 所有成交都以当前K线的开盘价为基准
|
||||
|
||||
if order.price_type == "MARKET":
|
||||
# 市价单通常以开盘价成交,或者根据你的策略需求选择收盘价
|
||||
# 这里我们仍然使用开盘价作为市价单的基准成交价
|
||||
base_price = current_bar.open
|
||||
|
||||
if order.direction == "BUY" or order.direction == "CLOSE_SHORT": # 买入或平空,价格向上偏离
|
||||
# 市价单:直接以开盘价成交,考虑滑点
|
||||
if order.direction == "BUY" or order.direction == "CLOSE_SHORT": # 买入/平空:向上偏离(多付)
|
||||
fill_price = base_price * (1 + self.slippage_rate)
|
||||
elif order.direction == "SELL" or order.direction == "CLOSE_LONG": # 卖出或平多,价格向下偏离
|
||||
elif order.direction == "SELL" or order.direction == "CLOSE_LONG": # 卖出/平多:向下偏离(少收)
|
||||
fill_price = base_price * (1 - self.slippage_rate)
|
||||
else: # 默认情况,理论上不应该到这里,因为方向应该明确
|
||||
fill_price = base_price # 不考虑滑点
|
||||
|
||||
# 市价单只要有价格就会成交,无需额外价格区间判断
|
||||
else:
|
||||
fill_price = base_price # 理论上不发生
|
||||
|
||||
elif order.price_type == "LIMIT" and order.limit_price is not None:
|
||||
limit_price = order.limit_price
|
||||
high = current_bar.high
|
||||
low = current_bar.low
|
||||
open_price = current_bar.open # 也可以在限价单成交时用开盘价作为实际成交价,或限价本身
|
||||
|
||||
if order.direction == "BUY" or order.direction == "CLOSE_SHORT": # 限价买入或限价平空
|
||||
# 买入:如果K线最低价 <= 限价,则限价单可能成交
|
||||
if low <= limit_price:
|
||||
# 成交价通常是限价,但考虑滑点后可能略高(对买方不利)
|
||||
# 或者可以以 open/low 之间的一个价格成交,这里简化为限价+滑点
|
||||
fill_price_candidate = limit_price * (1 + self.slippage_rate)
|
||||
# 限价单:判断开盘价是否满足限价条件,如果满足,则以开盘价成交(考虑滑点)
|
||||
if order.direction == "BUY" or order.direction == "CLOSE_SHORT": # 限价买入/平空
|
||||
# 买单只有当开盘价低于或等于限价时才可能成交
|
||||
# 即:我愿意出 limit_price 买,开盘价 open_price 更低或一样,当然买
|
||||
if base_price <= limit_price:
|
||||
fill_price = base_price * (1 + self.slippage_rate)
|
||||
# else: 未满足限价条件,不成交
|
||||
elif order.direction == "SELL" or order.direction == "CLOSE_LONG": # 限价卖出/平多
|
||||
# 卖单只有当开盘价高于或等于限价时才可能成交
|
||||
# 即:我愿意出 limit_price 卖,开盘价 open_price 更高或一样,当然卖
|
||||
if base_price >= limit_price:
|
||||
fill_price = base_price * (1 - self.slippage_rate)
|
||||
# else: 未满足限价条件,不成交
|
||||
|
||||
# 确保成交价不会比当前K线的最低价还低(如果不是在最低点成交)
|
||||
# 也可以简单就返回 limit_price * (1 + self.slippage_rate)
|
||||
|
||||
# 如果开盘价低于或等于限价,那么通常会以开盘价成交(或者略差)
|
||||
# 否则,如果价格是从上方跌落到限价区,那么会在限价附近成交
|
||||
# 这里简化:如果K线触及限价,则以限价成交(考虑滑点)。
|
||||
# 更精细的模拟会考虑价格穿越顺序 (例如,是否开盘就跳过了限价)
|
||||
|
||||
# 如果开盘价已经低于或等于限价,那么就以开盘价成交(考虑滑点)
|
||||
if open_price <= limit_price:
|
||||
fill_price = open_price * (1 + self.slippage_rate)
|
||||
else: # 价格从高处跌落到限价
|
||||
fill_price = limit_price * (1 + self.slippage_rate)
|
||||
|
||||
# 确保成交价不会超过限价 (虽然加滑点可能会略超,但这是交易成本的一部分)
|
||||
# 这个检查是为了避免逻辑错误,理论上加滑点后应该可以接受比限价略高的价格
|
||||
# 如果你严格要求成交价不能高于限价,则需要移除滑点或者更复杂的逻辑
|
||||
# 这里我们接受加滑点后的价格
|
||||
if fill_price > high and fill_price > limit_price: # 如果计算出来的成交价高于K线最高价,则可能不合理
|
||||
fill_price = -1.0 # 理论上不该发生,除非滑点过大
|
||||
else:
|
||||
return -1.0 # 未触及限价
|
||||
|
||||
elif order.direction == "SELL" or order.direction == "CLOSE_LONG": # 限价卖出或限价平多
|
||||
# 卖出:如果K线最高价 >= 限价,则限价单可能成交
|
||||
if high >= limit_price:
|
||||
# 成交价通常是限价,但考虑滑点后可能略低(对卖方不利)
|
||||
fill_price_candidate = limit_price * (1 - self.slippage_rate)
|
||||
|
||||
# 如果开盘价已经高于或等于限价,那么就以开盘价成交(考虑滑点)
|
||||
if open_price >= limit_price:
|
||||
fill_price = open_price * (1 - self.slippage_rate)
|
||||
else: # 价格从低处上涨到限价
|
||||
fill_price = limit_price * (1 - self.slippage_rate)
|
||||
|
||||
# 确保成交价不会低于限价
|
||||
if fill_price < low and fill_price < limit_price: # 如果计算出来的成交价低于K线最低价,则可能不合理
|
||||
fill_price = -1.0 # 理论上不该发生
|
||||
else:
|
||||
return -1.0 # 未触及限价
|
||||
|
||||
# 最后检查成交价是否有效
|
||||
# 最终检查成交价是否有效且合理(大于0)
|
||||
if fill_price <= 0:
|
||||
return -1.0 # 如果计算出来价格无效,返回未成交
|
||||
return -1.0 # 未成交或价格无效
|
||||
|
||||
return fill_price
|
||||
|
||||
def send_order(self, order: Order, current_bar: Bar) -> Optional[Trade]:
|
||||
def send_order_to_pending(self, order: Order) -> Optional[Order]:
|
||||
"""
|
||||
接收策略发出的订单,并模拟执行。
|
||||
如果订单未立即成交,则加入待处理订单列表。
|
||||
特殊处理:如果 order.direction 是 "CANCEL",则调用 cancel_order。
|
||||
将订单添加到待处理队列。由 BacktestEngine 或 Strategy 调用。
|
||||
此方法不进行撮合,撮合由 process_pending_orders 统一处理。
|
||||
"""
|
||||
if order.id in self.pending_orders:
|
||||
# print(f"订单 {order.id} 已经存在于待处理队列。")
|
||||
return None
|
||||
self.pending_orders[order.id] = order
|
||||
# print(f"订单 {order.id} 加入待处理队列。")
|
||||
return order
|
||||
|
||||
Args:
|
||||
order (Order): 待执行的订单对象。
|
||||
current_bar (Bar): 当前的Bar数据,用于确定成交价格。
|
||||
def process_pending_orders(self, current_bar: Bar):
|
||||
"""
|
||||
处理所有待撮合的订单。在每个K线数据到来时调用。
|
||||
"""
|
||||
# 复制一份待处理订单的键,防止在迭代时修改字典
|
||||
order_ids_to_process = list(self.pending_orders.keys())
|
||||
|
||||
Returns:
|
||||
Optional[Trade]: 如果订单成功执行则返回 Trade 对象,否则返回 None。
|
||||
for order_id in order_ids_to_process:
|
||||
if order_id not in self.pending_orders: # 订单可能已被取消
|
||||
continue
|
||||
|
||||
order = self.pending_orders[order_id]
|
||||
|
||||
# 只有当订单的symbol与当前bar的symbol一致时才尝试撮合
|
||||
# 这样确保了在换月后,旧合约的挂单不会被尝试撮合 (尽管换月时会强制取消)
|
||||
if order.symbol != current_bar.symbol:
|
||||
# 这种情况理论上应该被换月逻辑清理掉的旧合约挂单,
|
||||
# 如果因为某种原因漏掉了,这里直接跳过,避免异常。
|
||||
continue
|
||||
|
||||
# 尝试成交订单
|
||||
self._execute_single_order(order, current_bar)
|
||||
|
||||
def _execute_single_order(self, order: Order, current_bar: Bar) -> Optional[Trade]:
|
||||
"""
|
||||
内部方法:尝试执行单个订单,并处理资金和持仓变化。
|
||||
由 send_order 或 process_pending_orders 调用。
|
||||
"""
|
||||
# --- 处理撤单指令 ---
|
||||
if order.direction == "CANCEL":
|
||||
if order.direction == "CANCEL": # 策略主动发起撤单
|
||||
success = self.cancel_order(order.id)
|
||||
if success:
|
||||
# print(f"[{current_bar.datetime}] 模拟器: 收到并成功处理撤单指令 for Order ID: {order.id}")
|
||||
pass
|
||||
else:
|
||||
# print(f"[{current_bar.datetime}] 模拟器: 收到撤单指令 for Order ID: {order.id}, 但订单已成交或不存在。")
|
||||
pass
|
||||
return None # 撤单操作不返回Trade
|
||||
return None # 撤单操作不返回Trade
|
||||
|
||||
# --- 正常买卖订单处理 ---
|
||||
symbol = order.symbol
|
||||
volume = order.volume
|
||||
|
||||
# 尝试计算成交价格
|
||||
fill_price = self._calculate_fill_price(order, current_bar)
|
||||
|
||||
executed_trade: Optional[Trade] = None
|
||||
realized_pnl = 0.0 # 初始化实现盈亏
|
||||
is_open_trade = False
|
||||
is_close_trade = False
|
||||
|
||||
if fill_price <= 0: # 表示未成交或不满足限价条件
|
||||
if order.price_type == "LIMIT":
|
||||
self.pending_orders[order.id] = order
|
||||
# print(f'撮合失败,order id:{order.id},fill_price:{fill_price}')
|
||||
return None # 未成交,返回None
|
||||
if fill_price <= 0: # 未成交或不满足限价条件
|
||||
return None
|
||||
|
||||
# --- 以下是订单成功成交的逻辑 ---
|
||||
trade_value = volume * fill_price
|
||||
@@ -193,195 +147,215 @@ class ExecutionSimulator:
|
||||
current_position = self.positions.get(symbol, 0)
|
||||
current_average_cost = self.average_costs.get(symbol, 0.0)
|
||||
|
||||
actual_direction = order.direction
|
||||
if order.direction == "CLOSE_SHORT":
|
||||
actual_direction = "BUY"
|
||||
elif order.direction == "CLOSE_LONG":
|
||||
actual_direction = "SELL"
|
||||
realized_pnl = 0.0
|
||||
|
||||
is_close_order_intent = (order.direction == "CLOSE_LONG" or
|
||||
order.direction == "CLOSE_SHORT")
|
||||
# 根据 direction 判断开平仓意图
|
||||
# 如果 direction 是 CLOSE_LONG 或 CLOSE_SELL (平多), CLOSE_SHORT (平空) 则是平仓交易
|
||||
is_close_trade = order.direction in ["CLOSE_LONG", "CLOSE_SELL", "CLOSE_SHORT"]
|
||||
# 如果 direction 是 BUY 或 SELL 且不是平仓意图,则是开仓交易
|
||||
is_open_trade = (order.direction in ["BUY", "SELL"]) and (not is_close_trade)
|
||||
|
||||
if actual_direction == "BUY": # 处理买入 (开多 / 平空)
|
||||
# 开多仓或平空仓
|
||||
|
||||
# 区分实际的买卖方向
|
||||
actual_execution_direction = ""
|
||||
if order.direction == "BUY" or order.direction == "CLOSE_SHORT":
|
||||
actual_execution_direction = "BUY"
|
||||
elif order.direction == "SELL" or order.direction == "CLOSE_LONG" or order.direction == "CLOSE_SELL":
|
||||
actual_execution_direction = "SELL"
|
||||
else:
|
||||
print(f"[{current_bar.datetime}] 模拟器: 收到未知订单方向 {order.direction} for Order ID: {order.id}. 订单未处理。")
|
||||
if order.id in self.pending_orders: del self.pending_orders[order.id]
|
||||
return None
|
||||
|
||||
if actual_execution_direction == "BUY": # 处理实际的买入 (开多 / 平空)
|
||||
if current_position >= 0: # 当前持有多仓或无仓位 (开多)
|
||||
is_open_trade = not is_close_order_intent # 如果是平仓意图,则不是开仓交易
|
||||
# 更新平均成本 (加权平均)
|
||||
new_total_cost = (current_average_cost * current_position) + (fill_price * volume)
|
||||
new_total_volume = current_position + volume
|
||||
self.average_costs[symbol] = new_total_cost / new_total_volume if new_total_volume > 0 else 0.0
|
||||
self.positions[symbol] = new_total_volume
|
||||
else: # 当前持有空仓 (平空)
|
||||
is_close_trade = is_close_order_intent # 这是平仓交易
|
||||
# 计算平空盈亏
|
||||
pnl_per_share = current_average_cost - fill_price
|
||||
pnl_per_share = current_average_cost - fill_price # 空头平仓盈亏
|
||||
realized_pnl = pnl_per_share * volume
|
||||
|
||||
# 更新持仓和成本
|
||||
self.positions[symbol] += volume
|
||||
if self.positions[symbol] == 0: # 如果全部平仓
|
||||
if self.positions[symbol] == 0:
|
||||
del self.positions[symbol]
|
||||
if symbol in self.average_costs: del self.average_costs[symbol]
|
||||
elif self.positions[symbol] > 0 and current_position < 0: # 部分平空,且空头仓位被买平为多头仓位
|
||||
# 这是从空头转为多头的复杂情况。需要重新计算平均成本
|
||||
# 简单处理:将剩余的多头仓位成本设为当前价格
|
||||
self.average_costs[symbol] = fill_price
|
||||
elif self.positions[symbol] > 0 and current_position < 0: # 空转多
|
||||
self.average_costs[symbol] = fill_price # 新多头仓位成本以成交价为准
|
||||
|
||||
if self.cash >= trade_value + commission:
|
||||
self.cash -= (trade_value + commission)
|
||||
else:
|
||||
print(f"[{current_bar.datetime}] 资金不足,无法执行买入 {volume} {symbol}")
|
||||
if self.cash < trade_value + commission:
|
||||
print(f"[{current_bar.datetime}] 模拟器: 资金不足,无法执行买入 {volume} {symbol} @ {fill_price:.2f}")
|
||||
if order.id in self.pending_orders: del self.pending_orders[order.id]
|
||||
return None
|
||||
self.cash -= (trade_value + commission)
|
||||
|
||||
|
||||
elif actual_direction == "SELL": # 处理卖出 (开空 / 平多)
|
||||
# 开空仓或平多仓
|
||||
elif actual_execution_direction == "SELL": # 处理实际的卖出 (开空 / 平多)
|
||||
if current_position <= 0: # 当前持有空仓或无仓位 (开空)
|
||||
is_open_trade = not is_close_order_intent # 如果是平仓意图,则不是开仓交易
|
||||
# 更新平均成本 (空头成本为负值)
|
||||
# 对于空头,平均成本通常是指你卖出开仓的平均价格
|
||||
# 这里需要根据你的空头成本计算方式来调整
|
||||
# 常见的做法是:总卖出价值 / 总卖出数量
|
||||
new_total_value = (current_average_cost * abs(current_position)) + (fill_price * volume)
|
||||
new_total_volume = abs(current_position) + volume
|
||||
self.average_costs[symbol] = new_total_value / new_total_volume if new_total_volume > 0 else 0.0
|
||||
self.positions[symbol] -= volume # 空头数量增加,持仓量变为负更多
|
||||
|
||||
self.positions[symbol] -= volume
|
||||
else: # 当前持有多仓 (平多)
|
||||
is_close_trade = is_close_order_intent # 这是平仓交易
|
||||
# 计算平多盈亏
|
||||
pnl_per_share = fill_price - current_average_cost
|
||||
pnl_per_share = fill_price - current_average_cost # 多头平仓盈亏
|
||||
realized_pnl = pnl_per_share * volume
|
||||
|
||||
# 更新持仓和成本
|
||||
self.positions[symbol] -= volume
|
||||
if self.positions[symbol] == 0: # 如果全部平仓
|
||||
if self.positions[symbol] == 0:
|
||||
del self.positions[symbol]
|
||||
if symbol in self.average_costs: del self.average_costs[symbol]
|
||||
elif self.positions[symbol] < 0 and current_position > 0: # 部分平多,且多头仓位被卖平为空头仓位
|
||||
# 从多头转为空头的复杂情况
|
||||
self.average_costs[symbol] = fill_price # 简单将剩余空头仓位成本设为当前价格
|
||||
elif self.positions[symbol] < 0 and current_position > 0: # 多转空
|
||||
self.average_costs[symbol] = fill_price # 新空头仓位成本以成交价为准
|
||||
|
||||
if self.cash >= commission:
|
||||
self.cash -= commission
|
||||
self.cash += trade_value
|
||||
else:
|
||||
print(f"[{current_bar.datetime}] 资金不足,无法执行卖出 {volume} {symbol}")
|
||||
if self.cash < commission: # 卖出交易,佣金先扣
|
||||
print(f"[{current_bar.datetime}] 模拟器: 资金不足(佣金),无法执行卖出 {volume} {symbol} @ {fill_price:.2f}")
|
||||
if order.id in self.pending_orders: del self.pending_orders[order.id]
|
||||
return None
|
||||
else: # 既不是 BUY 也不是 SELL,且也不是 CANCEL。这可能是未知的 direction
|
||||
print(
|
||||
f"[{current_bar.datetime}] 模拟器: 收到未知订单方向 {order.direction} for Order ID: {order.id}. 订单未处理。")
|
||||
return None
|
||||
self.cash -= commission
|
||||
self.cash += trade_value
|
||||
|
||||
# 创建 Trade 对象
|
||||
# 创建 Trade 对象时,direction 使用原始订单的 direction
|
||||
executed_trade = Trade(
|
||||
order_id=order.id, fill_time=current_bar.datetime, symbol=symbol,
|
||||
direction=order.direction, # 记录原始订单方向 (BUY/SELL/CLOSE_X)
|
||||
direction=order.direction, # 使用原始订单的 direction
|
||||
volume=volume, price=fill_price, commission=commission,
|
||||
cash_after_trade=self.cash, positions_after_trade=self.positions.copy(),
|
||||
realized_pnl=realized_pnl, # 填充实现盈亏
|
||||
realized_pnl=realized_pnl,
|
||||
is_open_trade=is_open_trade,
|
||||
is_close_trade=is_close_trade
|
||||
)
|
||||
self.trade_log.append(executed_trade)
|
||||
|
||||
# 如果订单成交,无论它是市价单还是限价单,都从待处理订单中移除
|
||||
# 订单成交,从待处理订单中移除
|
||||
if order.id in self.pending_orders:
|
||||
del self.pending_orders[order.id]
|
||||
|
||||
return executed_trade
|
||||
|
||||
def cancel_order(self, order_id: str) -> bool:
|
||||
"""
|
||||
尝试取消一个待处理订单。
|
||||
|
||||
Args:
|
||||
order_id (str): 要取消的订单ID。
|
||||
|
||||
Returns:
|
||||
bool: 如果成功取消则返回 True,否则返回 False(例如,订单不存在或已成交)。
|
||||
"""
|
||||
if order_id in self.pending_orders:
|
||||
# print(f"订单 {order_id} 已成功取消。")
|
||||
del self.pending_orders[order_id]
|
||||
return True
|
||||
# print(f"订单 {order_id} 不存在或已成交,无法取消。")
|
||||
return False
|
||||
|
||||
# --- 新增:强制平仓指定合约的所有持仓 ---
|
||||
def force_close_all_positions_for_symbol(self, symbol_to_close: str, closing_bar: Bar) -> List[Trade]:
|
||||
"""
|
||||
强制平仓指定合约的所有持仓。
|
||||
Args:
|
||||
symbol_to_close (str): 需要平仓的合约代码。
|
||||
closing_bar (Bar): 用于获取平仓价格的当前K线数据(通常是旧合约的最后一根K线)。
|
||||
Returns:
|
||||
List[Trade]: 因强制平仓而产生的交易记录。
|
||||
"""
|
||||
closed_trades: List[Trade] = []
|
||||
|
||||
# 仅处理指定symbol的持仓
|
||||
if symbol_to_close in self.positions and self.positions[symbol_to_close] != 0:
|
||||
volume_to_close = self.positions[symbol_to_close]
|
||||
|
||||
# 根据持仓方向决定平仓订单的方向
|
||||
direction = "SELL" if volume_to_close > 0 else "BUY" # 多头平仓是卖出,空头平仓是买入
|
||||
|
||||
# 构造一个市价平仓订单
|
||||
rollover_order = Order(
|
||||
id=f"FORCE_CLOSE_{symbol_to_close}_{closing_bar.datetime.strftime('%Y%m%d%H%M%S%f')}",
|
||||
symbol=symbol_to_close,
|
||||
direction=direction,
|
||||
volume=abs(volume_to_close),
|
||||
price_type="MARKET",
|
||||
limit_price=None,
|
||||
submitted_time=closing_bar.datetime,
|
||||
)
|
||||
|
||||
# 使用内部的执行逻辑进行撮合
|
||||
trade = self._execute_single_order(rollover_order, closing_bar)
|
||||
if trade:
|
||||
closed_trades.append(trade)
|
||||
else:
|
||||
print(f"[{closing_bar.datetime}] 警告: 强制平仓 {symbol_to_close} 失败!")
|
||||
|
||||
return closed_trades
|
||||
|
||||
# --- 新增:取消指定合约的所有挂单 ---
|
||||
def cancel_all_pending_orders_for_symbol(self, symbol_to_cancel: str) -> int:
|
||||
"""
|
||||
取消指定合约的所有待处理订单。
|
||||
"""
|
||||
cancelled_count = 0
|
||||
order_ids_to_cancel = [
|
||||
order_id for order_id, order in self.pending_orders.items()
|
||||
if order.symbol == symbol_to_cancel
|
||||
]
|
||||
for order_id in order_ids_to_cancel:
|
||||
if self.cancel_order(order_id): # 调用现有的 cancel_order 方法
|
||||
cancelled_count += 1
|
||||
return cancelled_count
|
||||
|
||||
def get_pending_orders(self) -> Dict[str, Order]:
|
||||
"""
|
||||
获取当前所有待处理订单的副本。
|
||||
"""
|
||||
return self.pending_orders.copy()
|
||||
|
||||
def get_portfolio_value(self, current_bar: Bar) -> float:
|
||||
"""
|
||||
计算当前的投资组合总价值(包括现金和持仓市值)。
|
||||
此方法需要兼容多合约持仓的场景。
|
||||
Args:
|
||||
current_bar (Bar): 当前的Bar数据,用于计算持仓市值。
|
||||
current_bar (Bar): 当前的Bar数据,用于计算**当前活跃合约**的持仓市值。
|
||||
注意:如果 simulator 中持有多个合约,这里需要更复杂的逻辑。
|
||||
目前假设主力合约回测时,simulator.positions 主要只包含当前主力合约。
|
||||
Returns:
|
||||
float: 当前的投资组合总价值。
|
||||
"""
|
||||
total_value = self.cash
|
||||
|
||||
# 在单品种场景下,我们假设 self.positions 最多只包含一个品种
|
||||
# 并且这个品种就是 current_bar.symbol 所代表的品种
|
||||
symbol_in_position = list(self.positions.keys())[0] if self.positions else None
|
||||
|
||||
if symbol_in_position and symbol_in_position == current_bar.symbol:
|
||||
quantity = self.positions[symbol_in_position]
|
||||
# 持仓市值 = 数量 * 当前市场价格 (current_bar.close)
|
||||
# 无论多头(quantity > 0)还是空头(quantity < 0),这个计算都是正确的
|
||||
total_value += quantity * current_bar.open
|
||||
|
||||
# 您也可以选择在这里打印调试信息
|
||||
# print(f" DEBUG Portfolio Value Calculation: Cash={self.cash:.2f}, "
|
||||
# f"Position for {symbol_in_position}: {quantity} @ {current_bar.close:.2f}, "
|
||||
# f"Position Value={quantity * current_bar.close:.2f}, Total Value={total_value:.2f}")
|
||||
|
||||
# 如果没有持仓,或者持仓品种与当前Bar品种不符 (理论上单品种不会发生)
|
||||
# 那么 total_value 依然是 self.cash
|
||||
# 遍历所有持仓,计算市值。
|
||||
# 注意:这里假设 current_bar 提供了当前活跃主力合约的价格。
|
||||
# 如果 self.positions 中包含其他非 current_bar.symbol 的旧合约,
|
||||
# 它们的市值将无法用 current_bar.open 来准确计算。
|
||||
# 在换月模式下,旧合约会被强制平仓,因此 simulator.positions 通常只包含一个合约。
|
||||
for symbol, quantity in self.positions.items():
|
||||
# 这里简单处理:如果持仓合约与 current_bar.symbol 相同,则使用 current_bar.open 计算。
|
||||
# 如果是其他合约,则需要外部提供其最新价格,但这超出了本函数当前的能力范围。
|
||||
# 考虑到换月模式,旧合约会被平仓,所以大部分时候这不会是问题。
|
||||
if symbol == current_bar.symbol:
|
||||
total_value += quantity * current_bar.open
|
||||
else:
|
||||
# 警告:如果这里出现,说明有未平仓的旧合约持仓,且没有其最新价格来计算市值。
|
||||
# 在严谨的主力连续回测中,这不应该发生,因为换月会强制平仓。
|
||||
print(f"[{current_bar.datetime}] 警告:持仓中存在非当前K线合约 {symbol},无法准确计算其市值。")
|
||||
# 可以选择将这部分持仓价值计为0,或者使用上一个已知价格(需要额外数据结构)
|
||||
# 这里我们假设它不影响总价值计算,因为换月时会处理掉
|
||||
pass
|
||||
|
||||
return total_value
|
||||
|
||||
def get_current_positions(self) -> Dict[str, int]:
|
||||
"""
|
||||
返回当前持仓字典的副本。
|
||||
"""
|
||||
return self.positions.copy()
|
||||
|
||||
def get_trade_history(self) -> List[Trade]:
|
||||
"""
|
||||
返回所有成交记录的副本。
|
||||
"""
|
||||
return self.trade_log.copy()
|
||||
|
||||
def reset(self, new_initial_capital: float = None, new_initial_positions: Dict[str, int] = None) -> None:
|
||||
"""
|
||||
重置模拟器状态到新的初始条件。
|
||||
可以在总回测开始时调用,或在合约切换时调整资金和持仓。
|
||||
此方法不用于换月时的平仓,它用于整个回测开始前的初始化。
|
||||
"""
|
||||
print("ExecutionSimulator: 重置状态。")
|
||||
self.cash = new_initial_capital if new_initial_capital is not None else self.initial_capital
|
||||
self.positions = new_initial_positions.copy() if new_initial_positions is not None else {}
|
||||
self.trade_history = []
|
||||
self.current_orders = {}
|
||||
self.average_costs = {}
|
||||
for symbol, qty in self.positions.items(): # 重置平均成本
|
||||
self.average_costs[symbol] = 0.0
|
||||
self.trade_log = []
|
||||
self.pending_orders = {} # 清空挂单
|
||||
self._current_time = None
|
||||
|
||||
def clear_trade_history(self) -> None:
|
||||
"""
|
||||
清空当前模拟器的交易历史。
|
||||
在每个合约片段结束时调用,以便我们只收集当前片段的交易记录。
|
||||
"""
|
||||
print("ExecutionSimulator: 清空交易历史。")
|
||||
self.trade_history = []
|
||||
# Removed clear_trade_history as trade_log is cleared in reset
|
||||
|
||||
def get_average_position_price(self, symbol: str) -> Optional[float]:
|
||||
"""
|
||||
获取指定合约的平均持仓成本。
|
||||
如果无持仓或无该合约记录,返回 None。
|
||||
"""
|
||||
# 返回 average_costs 字典中对应 symbol 的值
|
||||
# 如果没有持仓或者没有记录,返回 None
|
||||
if symbol in self.positions and self.positions[symbol] != 0:
|
||||
return self.average_costs.get(symbol)
|
||||
return None
|
||||
File diff suppressed because one or more lines are too long
@@ -41,7 +41,7 @@ class SimpleLimitBuyStrategy(Strategy):
|
||||
|
||||
self.order_id_counter = 0
|
||||
|
||||
self._bar_history: deque[Bar] = deque(maxlen=7)
|
||||
self._bar_history: deque[Bar] = deque(maxlen=10)
|
||||
self._last_order_id: Optional[str] = None # 用于跟踪上一根K线发出的订单ID
|
||||
|
||||
self.log(f"策略初始化: symbol={self.symbol}, trade_volume={self.trade_volume}, "
|
||||
@@ -58,6 +58,7 @@ class SimpleLimitBuyStrategy(Strategy):
|
||||
next_bar_open (Optional[float]): 下一根K线的开盘价,此处策略未使用。
|
||||
"""
|
||||
current_datetime = bar.datetime # 获取当前K线时间
|
||||
self.symbol = bar.symbol
|
||||
|
||||
# --- 1. 撤销上一根K线未成交的订单 ---
|
||||
# 检查是否记录了上一笔订单ID,并且该订单仍然在待处理列表中
|
||||
@@ -87,7 +88,6 @@ class SimpleLimitBuyStrategy(Strategy):
|
||||
# --- 3. 平仓逻辑 (止损/止盈) ---
|
||||
# 只有当有持仓时才考虑平仓
|
||||
if current_pos_volume > 0: # 假设只做多,所以持仓量 > 0
|
||||
|
||||
avg_entry_price = self.get_average_position_price(self.symbol)
|
||||
if avg_entry_price is not None:
|
||||
pnl_per_unit = bar.close - avg_entry_price # 当前浮动盈亏(以收盘价计算)
|
||||
@@ -140,7 +140,7 @@ class SimpleLimitBuyStrategy(Strategy):
|
||||
|
||||
# 获取前1根K线 (倒数第二根) 和前7根K线 (队列中最老的一根)
|
||||
bar_1_ago = self._bar_history[-2]
|
||||
bar_7_ago = self._bar_history[0]
|
||||
bar_7_ago = self._bar_history[-8]
|
||||
|
||||
# 计算历史 K 线的 Range
|
||||
range_1_ago = bar_1_ago.high - bar_1_ago.low
|
||||
@@ -174,10 +174,21 @@ class SimpleLimitBuyStrategy(Strategy):
|
||||
new_order = self.send_order(order)
|
||||
# 记录下这个订单的ID,以便在下一根K线开始时进行撤销
|
||||
if new_order:
|
||||
self._last_order_id = new_order.order_id
|
||||
self._last_order_id = new_order.id
|
||||
self.log(f"[{current_datetime}] 策略: 发送限价买入订单 {self._last_order_id} @ {target_buy_price:.2f}")
|
||||
else:
|
||||
self.log(f"[{current_datetime}] 策略: 发送订单失败。")
|
||||
|
||||
# else:
|
||||
# self.log(f"[{current_datetime}] 不满足开仓条件:持仓={current_pos_volume}, 待处理订单={len(pending_orders_after_cancel)}, K线历史长度={len(self._bar_history)}")
|
||||
|
||||
def on_rollover(self, old_symbol: str, new_symbol: str):
|
||||
"""
|
||||
在合约换月时清空历史K线数据和上次订单ID,避免使用旧合约数据进行计算。
|
||||
"""
|
||||
super().on_rollover(old_symbol, new_symbol) # 调用基类方法打印日志
|
||||
self._bar_history.clear() # 清空历史K线
|
||||
self._last_order_id = None # 清空上次订单ID,因为旧合约订单已取消
|
||||
|
||||
self.log(f"换月完成,清空历史K线数据和上次订单ID,准备新合约交易。")
|
||||
|
||||
|
||||
@@ -15,7 +15,13 @@ class Strategy(ABC):
|
||||
策略通过 context 对象与回测引擎和模拟器进行交互,并提供辅助方法。
|
||||
"""
|
||||
|
||||
def __init__(self, context: 'BacktestContext', symbol: str, enable_log: bool = True, **params: Any):
|
||||
def __init__(
|
||||
self,
|
||||
context: "BacktestContext",
|
||||
symbol: str,
|
||||
enable_log: bool = True,
|
||||
**params: Any,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
context (BacktestEngine): 回测引擎实例,作为策略的上下文,提供与模拟器等的交互接口。
|
||||
@@ -34,7 +40,7 @@ class Strategy(ABC):
|
||||
"""
|
||||
print(f"{self.__class__.__name__} 策略初始化回调被调用。")
|
||||
|
||||
def on_trade(self, trade: 'Trade'):
|
||||
def on_trade(self, trade: "Trade"):
|
||||
"""
|
||||
当模拟器成功执行一笔交易时调用。
|
||||
可用于更新策略内部持仓状态或记录交易。
|
||||
@@ -46,7 +52,7 @@ class Strategy(ABC):
|
||||
pass # 默认不执行任何操作,具体策略可覆盖
|
||||
|
||||
@abstractmethod
|
||||
def on_bar(self, bar: 'Bar'):
|
||||
def on_bar(self, bar: "Bar"):
|
||||
"""
|
||||
每当新的K线数据到来时调用此方法。
|
||||
Args:
|
||||
@@ -57,11 +63,14 @@ class Strategy(ABC):
|
||||
|
||||
# --- 新增/修改的辅助方法 ---
|
||||
|
||||
def send_order(self, order: 'Order') -> Optional[Trade]:
|
||||
def send_order(self, order: "Order") -> Optional[Order]:
|
||||
"""
|
||||
发送订单的辅助方法。
|
||||
会在 BaseStrategy 内部构建 Order 对象,并通过 context 转发给模拟器。
|
||||
"""
|
||||
if self.context.is_rollover_bar:
|
||||
self.log(f"当前是换月K线,禁止开仓订单")
|
||||
return None
|
||||
return self.context.send_order(order)
|
||||
|
||||
def cancel_order(self, order_id: str) -> bool:
|
||||
@@ -73,44 +82,37 @@ class Strategy(ABC):
|
||||
return self.context.cancel_order(order_id)
|
||||
|
||||
def cancel_all_pending_orders(self) -> int:
|
||||
"""
|
||||
取消所有当前策略的未决订单。
|
||||
返回成功取消的订单数量。
|
||||
"""
|
||||
pending_orders = self.get_pending_orders() # 调用 BaseStrategy 自己的 get_pending_orders
|
||||
"""取消当前策略的未决订单,仅限于当前策略关注的Symbol。"""
|
||||
# 注意:在换月模式下,引擎会自动取消旧合约的挂单,这里是策略主动取消
|
||||
pending_orders = self.get_pending_orders()
|
||||
cancelled_count = 0
|
||||
orders_to_cancel = [order.id for order in pending_orders.values() if order.symbol == self.symbol]
|
||||
orders_to_cancel = [
|
||||
order.id for order in pending_orders.values() if order.symbol == self.symbol
|
||||
]
|
||||
for order_id in orders_to_cancel:
|
||||
if self.cancel_order(order_id): # 调用 BaseStrategy 自己的 cancel_order
|
||||
if self.cancel_order(order_id):
|
||||
cancelled_count += 1
|
||||
return cancelled_count
|
||||
|
||||
def get_current_positions(self) -> Dict[str, int]:
|
||||
"""
|
||||
获取当前持仓。
|
||||
通过 context 调用模拟器的 get_positions 方法。
|
||||
"""
|
||||
"""获取所有当前持仓 (可能包含多个合约)。"""
|
||||
return self.context._simulator.get_current_positions()
|
||||
|
||||
def get_pending_orders(self) -> Dict[str, 'Order']:
|
||||
"""
|
||||
获取当前所有待处理订单的副本。
|
||||
通过 context 调用模拟器的 get_pending_orders 方法。
|
||||
"""
|
||||
def get_pending_orders(self) -> Dict[str, "Order"]:
|
||||
"""获取所有当前待处理订单的副本 (可能包含多个合约)。"""
|
||||
return self.context._simulator.get_pending_orders()
|
||||
|
||||
def get_average_position_price(self, symbol: str) -> Optional[float]:
|
||||
"""
|
||||
获取指定合约的平均持仓成本。
|
||||
通过 context 调用模拟器的 get_average_position_price 方法。
|
||||
"""
|
||||
"""获取指定合约的平均持仓成本。"""
|
||||
return self.context._simulator.get_average_position_price(symbol)
|
||||
|
||||
# 你可以根据需要在这里添加更多辅助方法,如获取账户净值等
|
||||
def get_account_cash(self) -> float:
|
||||
"""获取当前账户现金余额。"""
|
||||
return self.context._simulator.cash
|
||||
|
||||
def get_current_time(self) -> datetime:
|
||||
"""获取模拟器当前时间。"""
|
||||
return self.context._simulator.get_current_time()
|
||||
|
||||
def log(self, *args: Any, **kwargs: Any):
|
||||
"""
|
||||
@@ -121,7 +123,9 @@ class Strategy(ABC):
|
||||
if self.enable_log:
|
||||
# 尝试获取当前模拟时间,如果模拟器或时间不可用,则跳过时间前缀
|
||||
try:
|
||||
current_time_str = self.context._simulator.get_current_time().strftime('%Y-%m-%d %H:%M:%S')
|
||||
current_time_str = self.context._simulator.get_current_time().strftime(
|
||||
"%Y-%m-%d %H:%M:%S"
|
||||
)
|
||||
time_prefix = f"[{current_time_str}] "
|
||||
except AttributeError:
|
||||
# 如果获取不到时间(例如在策略初始化时,模拟器时间还未设置),则不加时间前缀
|
||||
@@ -129,8 +133,21 @@ class Strategy(ABC):
|
||||
|
||||
# 使用 f-string 结合 *args 来构建消息
|
||||
# print() 函数会将 *args 自动用空格分隔,这里我们模仿这个行为
|
||||
message = ' '.join(map(str, args))
|
||||
message = " ".join(map(str, args))
|
||||
|
||||
# 你可以将其他 kwargs (如 sep, end, file, flush) 传递给 print,
|
||||
# 但通常日志方法不会频繁使用这些。这里只支持最基础的打印。
|
||||
print(f"{time_prefix}策略 ({self.symbol}): {message}", **kwargs)
|
||||
|
||||
def on_rollover(self, old_symbol: str, new_symbol: str):
|
||||
"""
|
||||
当回测的合约发生换月时调用此方法。
|
||||
子类可以重写此方法来执行换月相关的逻辑(例如,调整目标仓位,清空历史数据)。
|
||||
注意:在调用此方法前,引擎已强制平仓旧合约的所有仓位并取消所有挂单。
|
||||
Args:
|
||||
old_symbol (str): 旧的合约代码。
|
||||
new_symbol (str): 新的合约代码。
|
||||
"""
|
||||
self.log(f"合约换月事件: 从 {old_symbol} 切换到 {new_symbol}")
|
||||
# 默认实现可以为空,子类根据需要重写
|
||||
pass
|
||||
|
||||
@@ -57,7 +57,7 @@ class SimpleLimitBuyStrategy(Strategy):
|
||||
"""
|
||||
每接收到一根Bar时,执行策略逻辑。
|
||||
"""
|
||||
current_portfolio_value = self.context.get_current_portfolio_value(bar)
|
||||
current_portfolio_value = self.context.get_account_cash()
|
||||
# print(f"[{bar.datetime}] Strategy processing Bar. Current close price: {bar.close:.2f}. Current Portfolio Value: {current_portfolio_value:.2f}")
|
||||
|
||||
# 1. 撤销上一根K线未成交的订单
|
||||
@@ -108,16 +108,19 @@ class SimpleLimitBuyStrategy(Strategy):
|
||||
)
|
||||
|
||||
# 通过上下文发送订单
|
||||
trade = self.send_order(order)
|
||||
if trade:
|
||||
print(
|
||||
f"[{bar.datetime}] 策略: 发送并立即成交限价买单 {trade.volume} 股 @ {trade.price:.2f}(open:{bar.open}, close:{bar.close}) (订单ID: {order.id})")
|
||||
# 如果立即成交,_last_order_id 仍然保持 None
|
||||
else:
|
||||
# 如果未立即成交,将订单ID记录下来,以便下一根Bar撤销
|
||||
self._last_order_id = order.id
|
||||
print(
|
||||
f"[{bar.datetime}] 策略: 发送限价买单 {trade_volume} 股 @ {limit_price:.2f} (未成交,订单ID: {order.id} 已挂单)")
|
||||
# trade = self.send_order(order)
|
||||
# if trade:
|
||||
# print(
|
||||
# f"[{bar.datetime}] 策略: 发送并立即成交限价买单 {trade.volume} 股 (open:{bar.open}, close:{bar.close}) (订单ID: {order.id})")
|
||||
# # 如果立即成交,_last_order_id 仍然保持 None
|
||||
# else:
|
||||
# # 如果未立即成交,将订单ID记录下来,以便下一根Bar撤销
|
||||
# self._last_order_id = order.id
|
||||
# print(
|
||||
# f"[{bar.datetime}] 策略: 发送限价买单 {trade_volume} 股 @ {limit_price:.2f} (未成交,订单ID: {order.id} 已挂单)")
|
||||
order = self.send_order(order)
|
||||
if order:
|
||||
print(f"[{bar.datetime}]发送订单 {order.id}, direction {order.direction}")
|
||||
else:
|
||||
# print(f"[{bar.datetime}] 策略: 当前已有持仓或有未成交订单,不重复下单。")
|
||||
pass
|
||||
Reference in New Issue
Block a user