同步本地回测与tqsdk回测

This commit is contained in:
2025-06-29 12:03:43 +08:00
parent 4521939b95
commit 70c3b8186a
20 changed files with 29892 additions and 23140 deletions

View File

@@ -1,18 +1,19 @@
from tqsdk import TqApi, TqAuth from datetime import date
from tqsdk import TqApi, TqAuth, TqBacktest, TargetPosTask
api = TqApi(auth=TqAuth("emanresu", "dfgvfgdfgg")) '''
如果当前价格大于5分钟K线的MA15则开多仓
如果小于则平仓
回测从 2018-05-01 到 2018-10-01
'''
# 在创建 api 实例时传入 TqBacktest 就会进入回测模式
api = TqApi(backtest=TqBacktest(start_dt=date(2018, 5, 1), end_dt=date(2018, 10, 1)), auth=TqAuth("emanresu", "dfgvfgdfgg"))
# 获得 m1901 5分钟K线的引用
klines = api.get_kline_serial("DCE.m1901", 60 * 60, data_length=15)
# 创建 m1901 的目标持仓 task该 task 负责调整 m1901 的仓位到指定的目标仓位
target_pos = TargetPosTask(api, "DCE.m1901")
# au 品种指数合约 while True:
ls = api.query_quotes(ins_class="INDEX", product_id="au") api.wait_update()
print(ls) if api.is_changing(klines) and len(klines) > 2:
target_pos.set_target_volume(5)
# au 品种主连合约
ls = api.query_quotes(ins_class="CONT", product_id="au")
print(ls)
quote = api.get_quote("KQ.m@SHFE.rb")
# 打印现在螺纹钢主连的标的合约
print(quote.underlying_symbol)
# 关闭api,释放相应资源
api.close()

View File

@@ -109,6 +109,7 @@ def collect_and_save_tqsdk_data_stream(
) )
last_kline_datetime = None # 用于跟踪上一根已完成K线的时间 last_kline_datetime = None # 用于跟踪上一根已完成K线的时间
swap_month_dt = None
while api.wait_update(): while api.wait_update():
if underlying_symbol is None: if underlying_symbol is None:
@@ -117,7 +118,11 @@ def collect_and_save_tqsdk_data_stream(
# 检查是否有新的完整K线生成或者当前K线是最后一次更新 (在回测结束时) # 检查是否有新的完整K线生成或者当前K线是最后一次更新 (在回测结束时)
# TqSdk会在K线完成时发送最后一次更新或者在回测结束时强制更新 # TqSdk会在K线完成时发送最后一次更新或者在回测结束时强制更新
if api.is_changing(quote, "underlying_symbol"): if api.is_changing(quote, "underlying_symbol"):
underlying_symbol = quote.underlying_symbol swap_month_dt = pd.to_datetime(
quote.datetime, unit="ns", utc=True
)
if api.is_changing(klines): if api.is_changing(klines):
# 只有当K线序列发生变化时才处理 # 只有当K线序列发生变化时才处理
# 关注最新一根 K 线(即 klines.iloc[-1] # 关注最新一根 K 线(即 klines.iloc[-1]
@@ -141,10 +146,15 @@ def collect_and_save_tqsdk_data_stream(
kline_dt = pd.to_datetime( kline_dt = pd.to_datetime(
current_kline["datetime"], unit="ns", utc=True current_kline["datetime"], unit="ns", utc=True
) ).tz_convert(BEIJING_TZ)
kline_dt = kline_dt.tz_convert(BEIJING_TZ).strftime(
if swap_month_dt is not None and kline_dt.hour == swap_month_dt.hour:
underlying_symbol = quote.underlying_symbol
kline_dt = kline_dt.strftime(
"%Y-%m-%d %H:%M:%S" "%Y-%m-%d %H:%M:%S"
) )
kline_data_to_save = { kline_data_to_save = {
"datetime": kline_dt, "datetime": kline_dt,
"open": current_kline["open"], "open": current_kline["open"],

File diff suppressed because one or more lines are too long

13759
main.ipynb

File diff suppressed because one or more lines are too long

View File

@@ -9,7 +9,7 @@ from ..core_data import PortfolioSnapshot, Trade, Bar
def calculate_metrics( def calculate_metrics(
snapshots: List[PortfolioSnapshot], trades: List[Trade], initial_capital: float snapshots: List[PortfolioSnapshot], trades: List[Trade], initial_capital: float
) -> Dict[str, Any]: ) -> Dict[str, Any]:
""" """
纯函数:根据投资组合快照和交易历史计算关键绩效指标。 纯函数:根据投资组合快照和交易历史计算关键绩效指标。
@@ -124,27 +124,30 @@ def calculate_metrics(
"亏损交易次数": losing_count, "亏损交易次数": losing_count,
"平均每次盈利": avg_profit_per_trade, "平均每次盈利": avg_profit_per_trade,
"平均每次亏损": avg_loss_per_trade, # 这个值是负数 "平均每次亏损": avg_loss_per_trade, # 这个值是负数
"InitialCapital": initial_capital, "initial_capital": initial_capital,
"FinalCapital": final_value, "final_capital": final_value,
"TotalReturn": total_return, "total_return": total_return,
"AnnualizedReturn": annualized_return, "annualized_return": annualized_return,
"MaxDrawdown": max_drawdown, "max_drawdown": max_drawdown,
"SharpeRatio": sharpe_ratio, "sharpe_ratio": sharpe_ratio,
"CalmarRatio": calmar_ratio, "calmar_ratio": calmar_ratio,
"TotalTrades": len(trades), # All buy and sell trades "total_trades": len(trades), # All buy and sell trades
"TransactionCosts": total_commissions, "transaction_costs": total_commissions,
"TotalRealizedPNL": total_realized_pnl, # New "total_realized_pnl": total_realized_pnl, # New
"WinRate": win_rate, "win_rate": win_rate,
"ProfitLossRatio": profit_loss_ratio, "profit_loss_ratio": profit_loss_ratio,
"WinningTradesCount": winning_count, "winning_trades_count": winning_count,
"LosingTradesCount": losing_count, "losing_trades_count": losing_count,
"AvgProfitPerTrade": avg_profit_per_trade, "avg_profit_per_trade": avg_profit_per_trade,
"AvgLossPerTrade": avg_loss_per_trade, # This value is negative "avg_loss_per_trade": avg_loss_per_trade, # This value is negative
} }
def plot_equity_and_drawdown_chart(snapshots: List[PortfolioSnapshot], initial_capital: float, def plot_equity_and_drawdown_chart(
title: str = "Portfolio Equity and Drawdown Curve") -> None: snapshots: List[PortfolioSnapshot],
initial_capital: float,
title: str = "Portfolio Equity and Drawdown Curve",
) -> None:
""" """
Plots the portfolio equity curve and drawdown. X-axis points are equally spaced. Plots the portfolio equity curve and drawdown. X-axis points are equally spaced.
@@ -157,35 +160,45 @@ def plot_equity_and_drawdown_chart(snapshots: List[PortfolioSnapshot], initial_c
print("No portfolio snapshots available to plot equity and drawdown.") print("No portfolio snapshots available to plot equity and drawdown.")
return return
df_equity = pd.DataFrame([ df_equity = pd.DataFrame(
{'datetime': s.datetime, 'total_value': s.total_value} [{"datetime": s.datetime, "total_value": s.total_value} for s in snapshots]
for s in snapshots )
])
equity_curve = df_equity['total_value'] / initial_capital equity_curve = df_equity["total_value"] / initial_capital
rolling_max = equity_curve.cummax() rolling_max = equity_curve.cummax()
drawdown = (rolling_max - equity_curve) / rolling_max drawdown = (rolling_max - equity_curve) / rolling_max
plt.style.use('seaborn-v0_8-darkgrid') plt.style.use("seaborn-v0_8-darkgrid")
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(14, 10), sharex=True, gridspec_kw={'height_ratios': [3, 1]}) fig, (ax1, ax2) = plt.subplots(
2, 1, figsize=(14, 10), sharex=True, gridspec_kw={"height_ratios": [3, 1]}
)
x_axis_indices = np.arange(len(df_equity)) x_axis_indices = np.arange(len(df_equity))
# Equity Curve Plot # Equity Curve Plot
ax1.plot(x_axis_indices, equity_curve, label='Equity Curve', color='blue', linewidth=1.5) ax1.plot(
ax1.set_ylabel('Equity', fontsize=12) x_axis_indices, equity_curve, label="Equity Curve", color="blue", linewidth=1.5
ax1.legend(loc='upper left') )
ax1.set_ylabel("Equity", fontsize=12)
ax1.legend(loc="upper left")
ax1.grid(True) ax1.grid(True)
ax1.set_title(title, fontsize=16) ax1.set_title(title, fontsize=16)
# Drawdown Curve Plot # Drawdown Curve Plot
ax2.fill_between(x_axis_indices, 0, drawdown, color='red', alpha=0.3) ax2.fill_between(x_axis_indices, 0, drawdown, color="red", alpha=0.3)
ax2.plot(x_axis_indices, drawdown, color='red', linewidth=1.0, linestyle='--', label='Drawdown') ax2.plot(
ax2.set_ylabel('Drawdown Rate', fontsize=12) x_axis_indices,
ax2.set_xlabel('Data Point Index (Date Labels Below)', fontsize=12) drawdown,
ax2.set_title('Portfolio Drawdown Curve', fontsize=14) color="red",
ax2.legend(loc='upper left') linewidth=1.0,
linestyle="--",
label="Drawdown",
)
ax2.set_ylabel("Drawdown Rate", fontsize=12)
ax2.set_xlabel("Data Point Index (Date Labels Below)", fontsize=12)
ax2.set_title("Portfolio Drawdown Curve", fontsize=14)
ax2.legend(loc="upper left")
ax2.grid(True) ax2.grid(True)
ax2.set_ylim(0, max(drawdown.max() * 1.1, 0.05)) ax2.set_ylim(0, max(drawdown.max() * 1.1, 0.05))
@@ -193,9 +206,12 @@ def plot_equity_and_drawdown_chart(snapshots: List[PortfolioSnapshot], initial_c
num_ticks = 10 num_ticks = 10
if len(df_equity) > 0: if len(df_equity) > 0:
tick_positions = np.linspace(0, len(df_equity) - 1, num_ticks, dtype=int) tick_positions = np.linspace(0, len(df_equity) - 1, num_ticks, dtype=int)
tick_labels = [df_equity['datetime'].iloc[i].strftime('%Y-%m-%d %H:%M') for i in tick_positions] tick_labels = [
df_equity["datetime"].iloc[i].strftime("%Y-%m-%d %H:%M")
for i in tick_positions
]
ax1.set_xticks(tick_positions) ax1.set_xticks(tick_positions)
ax1.set_xticklabels(tick_labels, rotation=45, ha='right') ax1.set_xticklabels(tick_labels, rotation=45, ha="right")
plt.tight_layout() plt.tight_layout()
plt.show() plt.show()
@@ -213,30 +229,38 @@ def plot_close_price_chart(bars: List[Bar], title: str = "Close Price Chart") ->
print("No bar data available to plot close price.") print("No bar data available to plot close price.")
return return
df_prices = pd.DataFrame([ df_prices = pd.DataFrame(
{'datetime': b.datetime, 'close_price': b.close} [{"datetime": b.datetime, "close_price": b.close} for b in bars]
for b in bars )
])
plt.style.use('seaborn-v0_8-darkgrid') plt.style.use("seaborn-v0_8-darkgrid")
fig, ax = plt.subplots(1, 1, figsize=(14, 7)) # Single subplot fig, ax = plt.subplots(1, 1, figsize=(14, 7)) # Single subplot
x_axis_indices = np.arange(len(df_prices)) x_axis_indices = np.arange(len(df_prices))
ax.plot(x_axis_indices, df_prices['close_price'], label='Close Price', color='orange', linewidth=1.5) ax.plot(
ax.set_ylabel('Price', fontsize=12) x_axis_indices,
ax.set_xlabel('Data Point Index (Date Labels Below)', fontsize=12) df_prices["close_price"],
label="Close Price",
color="orange",
linewidth=1.5,
)
ax.set_ylabel("Price", fontsize=12)
ax.set_xlabel("Data Point Index (Date Labels Below)", fontsize=12)
ax.set_title(title, fontsize=16) ax.set_title(title, fontsize=16)
ax.legend(loc='upper left') ax.legend(loc="upper left")
ax.grid(True) ax.grid(True)
# Set X-axis ticks to show actual dates at intervals # Set X-axis ticks to show actual dates at intervals
num_ticks = 10 num_ticks = 10
if len(df_prices) > 0: if len(df_prices) > 0:
tick_positions = np.linspace(0, len(df_prices) - 1, num_ticks, dtype=int) tick_positions = np.linspace(0, len(df_prices) - 1, num_ticks, dtype=int)
tick_labels = [df_prices['datetime'].iloc[i].strftime('%Y-%m-%d %H:%M') for i in tick_positions] tick_labels = [
df_prices["datetime"].iloc[i].strftime("%Y-%m-%d %H:%M")
for i in tick_positions
]
ax.set_xticks(tick_positions) ax.set_xticks(tick_positions)
ax.set_xticklabels(tick_labels, rotation=45, ha='right') ax.set_xticklabels(tick_labels, rotation=45, ha="right")
plt.tight_layout() plt.tight_layout()
plt.show() plt.show()
@@ -244,7 +268,7 @@ def plot_close_price_chart(bars: List[Bar], title: str = "Close Price Chart") ->
# 辅助函数:计算单笔交易的盈亏 # 辅助函数:计算单笔交易的盈亏
def calculate_trade_pnl( def calculate_trade_pnl(
trade: Trade, entry_price: float, exit_price: float, direction: str trade: Trade, entry_price: float, exit_price: float, direction: str
) -> float: ) -> float:
if direction == "LONG": if direction == "LONG":
pnl = (exit_price - entry_price) * trade.volume pnl = (exit_price - entry_price) * trade.volume

View File

@@ -65,6 +65,7 @@ class GridSearchAnalyzer:
y_idx = y_values.index(item[self.param2_name]) y_idx = y_values.index(item[self.param2_name])
heatmap_matrix[y_idx, x_idx] = item[self.optimization_metric] heatmap_matrix[y_idx, x_idx] = item[self.optimization_metric]
print([x_values[0], x_values[-1], y_values[0], y_values[-1]])
fig, ax = plt.subplots(figsize=(10, 8)) fig, ax = plt.subplots(figsize=(10, 8))
im = ax.imshow(heatmap_matrix, cmap='viridis', origin='lower', im = ax.imshow(heatmap_matrix, cmap='viridis', origin='lower',
extent=[x_values[0], x_values[-1], y_values[0], y_values[-1]], extent=[x_values[0], x_values[-1], y_values[0], y_values[-1]],

View File

@@ -59,6 +59,18 @@ class ResultAnalyzer:
""" """
生成并打印详细的回测报告。 生成并打印详细的回测报告。
""" """
if self.trade_history:
print("\n--- 交易明细 ---")
for trade in self.trade_history:
# 调整输出格式,显示实现盈亏
pnl_display = f" | PnL: {trade.realized_pnl:.2f}" if trade.is_close_trade else ""
print(
f" {trade.fill_time} | {trade.direction:<10} | {trade.symbol} | Vol: {trade.volume} | Price: {trade.price:.2f} | Comm: {trade.commission:.2f}{pnl_display}"
)
else:
print("\n没有交易记录。")
metrics = self.calculate_all_metrics() metrics = self.calculate_all_metrics()
print("\n--- 回测绩效报告 ---") print("\n--- 回测绩效报告 ---")
@@ -83,17 +95,6 @@ class ResultAnalyzer:
print(f"{'平均每次亏损':<15}: {metrics['平均每次亏损']:.2f}") print(f"{'平均每次亏损':<15}: {metrics['平均每次亏损']:.2f}")
if self.trade_history:
print("\n--- 部分交易明细 (最近5笔) ---")
for trade in self.trade_history[-5:]:
# 调整输出格式,显示实现盈亏
pnl_display = f" | PnL: {trade.realized_pnl:.2f}" if trade.is_close_trade else ""
print(
f" {trade.fill_time} | {trade.direction:<10} | {trade.symbol} | Vol: {trade.volume} | Price: {trade.price:.2f} | Comm: {trade.commission:.2f}{pnl_display}"
)
else:
print("\n没有交易记录。")
def plot_performance(self) -> None: def plot_performance(self) -> None:
""" """
绘制投资组合净值和回撤曲线,以及所有合约的收盘价曲线。 绘制投资组合净值和回撤曲线,以及所有合约的收盘价曲线。

View File

@@ -91,6 +91,9 @@ class BacktestContext:
""" """
self._engine = engine self._engine = engine
def get_bar_history(self):
return self._engine.get_bar_history()
@property @property
def is_rollover_bar(self) -> bool: def is_rollover_bar(self) -> bool:
""" """

View File

@@ -89,9 +89,12 @@ class BacktestEngine:
# 主回测循环 # 主回测循环
while True: while True:
current_bar = self.data_manager.get_next_bar() current_bar = self.data_manager.get_next_bar()
if current_bar is None: if current_bar is None:
break # 没有更多数据,回测结束 break # 没有更多数据,回测结束
self.all_bars.append(current_bar)
if self.start_time and current_bar.datetime < self.start_time: if self.start_time and current_bar.datetime < self.start_time:
continue continue
@@ -104,11 +107,11 @@ class BacktestEngine:
# 1. 重置 is_rollover_bar 标记 # 1. 重置 is_rollover_bar 标记
self.is_rollover_bar = False self.is_rollover_bar = False
# 4. 更新 Context 和 Simulator 的当前 Bar 和时间
self.context.set_current_bar(current_bar)
self.simulator.update_time(current_time=current_bar.datetime)
# 2. 如果启用换月模式,并且检测到合约 symbol 变化 # 2. 如果启用换月模式,并且检测到合约 symbol 变化
if current_bar.symbol != self._last_processed_bar_symbol:
print(self.roll_over_mode,
self._last_processed_bar_symbol,
current_bar.symbol, self._last_processed_bar_symbol)
if self.roll_over_mode and \ if self.roll_over_mode and \
self._last_processed_bar_symbol is not None and \ self._last_processed_bar_symbol is not None and \
current_bar.symbol != self._last_processed_bar_symbol: current_bar.symbol != self._last_processed_bar_symbol:
@@ -141,20 +144,22 @@ class BacktestEngine:
# 3. 更新策略关注的当前合约 symbol # 3. 更新策略关注的当前合约 symbol
self.strategy.symbol = current_bar.symbol self.strategy.symbol = current_bar.symbol
# 4. 更新 Context 和 Simulator 的当前 Bar 和时间
self.context.set_current_bar(current_bar)
self.simulator.update_time(current_time=current_bar.datetime)
# 5. 更新引擎内部的历史 Bar 缓存 # 5. 更新引擎内部的历史 Bar 缓存
self._history_bars.append(current_bar) self._history_bars.append(current_bar)
if len(self._history_bars) > self._max_history_bars: if len(self._history_bars) > self._max_history_bars:
self._history_bars.pop(0) self._history_bars.pop(0)
# 6. 处理待撮合订单 (在调用策略 on_bar 之前,确保订单在当前 K 线开盘价撮合) # 6. 处理待撮合订单 (在调用策略 on_bar 之前,确保订单在当前 K 线开盘价撮合)
self.simulator.process_pending_orders(current_bar) # self.simulator.process_pending_orders(current_bar)
self.strategy.on_open_bar(current_bar)
# 7. 调用策略的 on_bar 方法 # 7. 调用策略的 on_bar 方法
self.strategy.on_bar(current_bar) # self.strategy.on_bar(current_bar)
self.simulator.process_pending_orders(current_bar)
self.strategy.on_close_bar(current_bar)
self.simulator.process_pending_orders(current_bar)
# 8. 记录投资组合快照 # 8. 记录投资组合快照
current_portfolio_value = self.simulator.get_portfolio_value(current_bar) current_portfolio_value = self.simulator.get_portfolio_value(current_bar)
@@ -170,7 +175,6 @@ class BacktestEngine:
price_at_snapshot=price_at_snapshot price_at_snapshot=price_at_snapshot
) )
self.portfolio_snapshots.append(snapshot) self.portfolio_snapshots.append(snapshot)
self.all_bars.append(current_bar)
# 9. 更新 `_last_processed_bar_symbol` 和 `last_processed_bar` 为当前 Bar为下一轮循环做准备 # 9. 更新 `_last_processed_bar_symbol` 和 `last_processed_bar` 为当前 Bar为下一轮循环做准备
self._last_processed_bar_symbol = current_bar.symbol self._last_processed_bar_symbol = current_bar.symbol
@@ -222,3 +226,7 @@ class BacktestEngine:
def get_simulator(self) -> ExecutionSimulator: def get_simulator(self) -> ExecutionSimulator:
return self.simulator return self.simulator
def get_bar_history(self):
return self.all_bars

View File

@@ -6,7 +6,7 @@ from typing import Dict, Any, List, Optional
import uuid # 用于生成唯一订单ID import uuid # 用于生成唯一订单ID
@dataclass(frozen=True) # frozen=True 使实例变为不可变 @dataclass() # 使实例变为不可变
class Bar: class Bar:
""" """
K线数据对象包含期货或股票的 OHLCV 和持仓量信息。 K线数据对象包含期货或股票的 OHLCV 和持仓量信息。
@@ -25,11 +25,11 @@ class Bar:
""" """
数据验证(可选):确保持仓量为非负整数。 数据验证(可选):确保持仓量为非负整数。
""" """
if not isinstance(self.volume, int) or self.volume < 0: if self.volume < 0:
raise ValueError(f"Volume must be a non-negative integer, got {self.volume}") raise ValueError(f"Volume must be a non-negative integer, got {self.volume}")
if not isinstance(self.open_oi, int) or self.open_oi < 0: if self.open_oi < 0:
raise ValueError(f"Open interest must be a non-negative integer, got {self.open_oi}") raise ValueError(f"Open interest must be a non-negative integer, got {self.open_oi}")
if not isinstance(self.close_oi, int) or self.close_oi < 0: if self.close_oi < 0:
raise ValueError(f"Close interest must be a non-negative integer, got {self.close_oi}") raise ValueError(f"Close interest must be a non-negative integer, got {self.close_oi}")
# 验证价格是否合理 # 验证价格是否合理
@@ -41,7 +41,7 @@ class Bar:
pass pass
@dataclass(frozen=True) @dataclass()
class Order: class Order:
""" """
代表一个待执行的交易指令。 代表一个待执行的交易指令。
@@ -53,6 +53,7 @@ class Order:
price_type: str = "MARKET" # "MARKET", "LIMIT" (简易版默认市价) price_type: str = "MARKET" # "MARKET", "LIMIT" (简易版默认市价)
limit_price: Optional[float] = None # 限价单价格 limit_price: Optional[float] = None # 限价单价格
submitted_time: pd.Timestamp = field(default_factory=pd.Timestamp.now) # 订单提交时间 submitted_time: pd.Timestamp = field(default_factory=pd.Timestamp.now) # 订单提交时间
offset: str = "OPEN"
def __post_init__(self): def __post_init__(self):
if self.direction not in ["BUY", "SELL", "CLOSE_LONG", "CLOSE_SHORT", "CANCEL"]: if self.direction not in ["BUY", "SELL", "CLOSE_LONG", "CLOSE_SHORT", "CANCEL"]:
@@ -63,7 +64,7 @@ class Order:
raise ValueError(f"Order volume must be a positive integer, got {self.volume}") raise ValueError(f"Order volume must be a positive integer, got {self.volume}")
@dataclass(frozen=True) @dataclass()
class Trade: class Trade:
""" """
代表一个已完成的成交记录。 代表一个已完成的成交记录。
@@ -83,7 +84,7 @@ class Trade:
is_close_trade: bool = False # <--- 新增字段:是否是平仓交易 (用于计算盈亏) is_close_trade: bool = False # <--- 新增字段:是否是平仓交易 (用于计算盈亏)
@dataclass(frozen=True) @dataclass()
class PortfolioSnapshot: class PortfolioSnapshot:
""" """
在特定时间点记录投资组合的快照。 在特定时间点记录投资组合的快照。

View File

@@ -1,5 +1,6 @@
# src/data_manager.py (修改并添加 get_history_bars 方法) # src/data_manager.py (修改并添加 get_history_bars 方法)
from datetime import datetime
import pandas as pd import pandas as pd
from typing import Iterator, List, Dict, Any, Optional from typing import Iterator, List, Dict, Any, Optional
import os import os
@@ -15,10 +16,16 @@ class DataManager: # DataManager 现在是一个类,以便维护内部索引
并提供获取历史Bar的能力。 并提供获取历史Bar的能力。
""" """
def __init__(self, file_path: str, symbol: str, tz='Asia/Shanghai'): def __init__(self, file_path: str, symbol: str, tz='Asia/Shanghai', start_time: Optional[datetime] = None, end_time: Optional[datetime] = None):
self.file_path = file_path self.file_path = file_path
self.tz = tz self.tz = tz
self.raw_df = load_raw_data(self.file_path) # 调用函数式加载数据 self.raw_df = load_raw_data(self.file_path) # 调用函数式加载数据
self.start_time = start_time
self.end_time = end_time
if start_time is not None and end_time is not None:
self.raw_df = self.raw_df[(self.raw_df.index >= start_time) & (self.raw_df.index <= end_time)]
# self.bars = list(df_to_bar_stream(self.raw_df)) # 一次性转换所有Bar方便历史数据查找 # self.bars = list(df_to_bar_stream(self.raw_df)) # 一次性转换所有Bar方便历史数据查找
# 优化使用内部迭代器和缓存避免一次性生成所有Bar但可以按需提供历史 # 优化使用内部迭代器和缓存避免一次性生成所有Bar但可以按需提供历史
self._bar_generator = df_to_bar_stream(self.raw_df, symbol) self._bar_generator = df_to_bar_stream(self.raw_df, symbol)
@@ -26,6 +33,8 @@ class DataManager: # DataManager 现在是一个类,以便维护内部索引
self._last_bar_index = -1 # 用于跟踪当前的Bar在原始df中的索引 self._last_bar_index = -1 # 用于跟踪当前的Bar在原始df中的索引
self.symbol = symbol self.symbol = symbol
def get_next_bar(self) -> Optional[Bar]: def get_next_bar(self) -> Optional[Bar]:
""" """
按顺序返回下一根 Bar 对象。 按顺序返回下一根 Bar 对象。

View File

@@ -43,6 +43,7 @@ def load_raw_data(file_path: str) -> pd.DataFrame:
missing_cols = [col for col in expected_cols[1:] if col not in df.columns] missing_cols = [col for col in expected_cols[1:] if col not in df.columns]
if missing_cols: if missing_cols:
print(f"CSV文件中缺少以下列: {', '.join(missing_cols)}") print(f"CSV文件中缺少以下列: {', '.join(missing_cols)}")
expected_cols = [col for col in expected_cols if col in df.columns]
# 确保数据按时间排序 (这是回测的基础) # 确保数据按时间排序 (这是回测的基础)
df = df.sort_index() df = df.sort_index()
@@ -51,7 +52,7 @@ def load_raw_data(file_path: str) -> pd.DataFrame:
print(f"数据范围从 {df.index.min()}{df.index.max()}") print(f"数据范围从 {df.index.min()}{df.index.max()}")
print(f"总计 {len(df)} 条记录。") print(f"总计 {len(df)} 条记录。")
return df[expected_cols[1:]] # 返回包含核心数据的DataFrame return df[expected_cols] # 返回包含核心数据的DataFrame
except Exception as e: except Exception as e:
print(f"加载数据时发生错误: {e}") print(f"加载数据时发生错误: {e}")
raise raise
@@ -68,7 +69,6 @@ def df_to_bar_stream(df: pd.DataFrame, symbol: str) -> Iterator[Bar]:
Bar: 逐个生成的 Bar 对象。 Bar: 逐个生成的 Bar 对象。
""" """
print("开始将 DataFrame 转换为 Bar 对象流...") print("开始将 DataFrame 转换为 Bar 对象流...")
print(df)
for index, row in df.iterrows(): for index, row in df.iterrows():
try: try:
if 'underlying_symbol' in df.columns and row['underlying_symbol'] != '': if 'underlying_symbol' in df.columns and row['underlying_symbol'] != '':

View File

@@ -11,17 +11,31 @@ class ExecutionSimulator:
模拟交易执行和管理账户资金、持仓。 模拟交易执行和管理账户资金、持仓。
""" """
def __init__(self, initial_capital: float, def __init__(
slippage_rate: float = 0.0001, self,
commission_rate: float = 0.0002, initial_capital: float,
initial_positions: Optional[Dict[str, int]] = None): slippage_rate: float = 0.0001,
commission_rate: float = 0.0002,
initial_positions: Optional[Dict[str, int]] = None,
initial_average_costs: Optional[Dict[str, float]] = None,
): # 新增参数
self.initial_capital = initial_capital self.initial_capital = initial_capital
self.cash = initial_capital self.cash = initial_capital
self.positions: Dict[str, int] = initial_positions if initial_positions is not None else {} self.positions: Dict[str, int] = (
self.average_costs: Dict[str, float] = {} initial_positions if initial_positions is not None else {}
if initial_positions: )
# 修正初始平均成本应该从参数传入而不是默认0.0
self.average_costs: Dict[str, float] = (
initial_average_costs if initial_average_costs is not None else {}
)
# 如果提供了 initial_positions 但没有提供 initial_average_costs可以警告或默认处理
if initial_positions and not initial_average_costs:
print(
f"[{datetime.now()}] 警告: 提供了初始持仓但未提供初始平均成本这些持仓的成本默认为0.0。"
)
for symbol, qty in initial_positions.items(): for symbol, qty in initial_positions.items():
self.average_costs[symbol] = 0.0 if symbol not in self.average_costs:
self.average_costs[symbol] = 0.0 # 如果没有提供默认给0
self.slippage_rate = slippage_rate self.slippage_rate = slippage_rate
self.commission_rate = commission_rate self.commission_rate = commission_rate
@@ -30,94 +44,94 @@ class ExecutionSimulator:
self._current_time: Optional[datetime] = None self._current_time: Optional[datetime] = None
print( print(
f"模拟器初始化:初始资金={self.initial_capital:.2f}, 滑点率={self.slippage_rate}, 佣金率={self.commission_rate}") f"模拟器初始化:初始资金={self.initial_capital:.2f}, 滑点率={self.slippage_rate}, 佣金率={self.commission_rate}"
)
if self.positions: if self.positions:
print(f"初始持仓:{self.positions}") print(f"初始持仓:{self.positions}")
print(f"初始平均成本:{self.average_costs}") # 打印初始成本以便检查
def update_time(self, current_time: datetime): def update_time(self, current_time: datetime):
self._current_time = current_time self._current_time = current_time
def get_current_time(self) -> datetime: def get_current_time(self) -> datetime:
if self._current_time is None: if self._current_time is None:
# 改进:如果时间未设置,可以抛出错误,防止策略在 on_init 阶段意外调用
# raise RuntimeError("Simulator time has not been set. Ensure update_time is called.")
return None return None
return self._current_time return self._current_time
def _calculate_fill_price(self, order: Order, current_bar: Bar) -> float: def _calculate_fill_price(self, order: Order, current_bar: Bar) -> float:
""" """
内部方法:根据订单类型和滑点计算实际成交价格。 内部方法:根据订单类型和滑点计算实际成交价格。
撮合逻辑:所有订单(市价/限价都以当前K线的 **开盘价 (open)** 为基准进行撮合。 撮合逻辑:
- 市价单以当前K线的 **开盘价 (open)** 为基准进行撮合,并考虑滑点。
- 限价单:判断 K 线的 **最高价 (high)** 和 **最低价 (low)** 是否触及限价。如果触及,则以 **限价 (limit_price)** 为基准计算成交价,并考虑滑点。
""" """
fill_price = -1.0 # 默认未成交 fill_price = -1.0 # 默认未成交
base_price = current_bar.open # 所有成交都以当前K线的开盘价为基准 # 对于市价单,仍然使用开盘价为基准检查点
base_price_for_market_order = current_bar.open
if order.price_type == "MARKET": if order.price_type == "MARKET":
# 市价单:直接以开盘价成交,考虑滑点 # 市价单:直接以开盘价成交,考虑滑点
if order.direction == "BUY" or order.direction == "CLOSE_SHORT": # 买入/平空:向上偏离(多付) if (
fill_price = base_price * (1 + self.slippage_rate) order.direction == "BUY" or order.direction == "CLOSE_SHORT"
elif order.direction == "SELL" or order.direction == "CLOSE_LONG": # 卖出/平:向偏离(少收 ): # 买入/平:向偏离(多付
fill_price = base_price * (1 - self.slippage_rate) fill_price = (base_price_for_market_order + 1) * (1 + self.slippage_rate)
elif (
order.direction == "SELL" or order.direction == "CLOSE_LONG"
): # 卖出/平多:向下偏离(少收)
fill_price = (base_price_for_market_order - 1) * (1 - self.slippage_rate)
else: else:
fill_price = base_price # 理论上不发生 fill_price = base_price_for_market_order # 理论上不发生
elif order.price_type == "LIMIT" and order.limit_price is not None: elif order.price_type == "LIMIT" and order.limit_price is not None:
limit_price = order.limit_price limit_price = order.limit_price
# 限价单:判断开盘价是否满足限价条件,如果满足,则以开盘价成交(考虑滑点) # 限价单:判断 K 线的高低价是否触及限价
if order.direction == "BUY" or order.direction == "CLOSE_SHORT": # 限价买入/平空 if (
# 买单只有当开盘价低于或等于限价时才可能成交 order.direction == "BUY" or order.direction == "CLOSE_SHORT"
# 即:我愿意出 limit_price 买,开盘价 open_price 更低或一样,当然买 ): # 限价买入/平空
if base_price <= limit_price: # 如果当前K线的最低价低于或等于限价则买入限价单有机会成交
fill_price = base_price * (1 + self.slippage_rate) if current_bar.low < limit_price:
# else: 未满足限价条件,不成交 # 成交价以限价为基准,并考虑滑点(买入向上偏离)
elif order.direction == "SELL" or order.direction == "CLOSE_LONG": # 限价卖出/平多 fill_price = limit_price * (1 + self.slippage_rate)
# 卖单只有当开盘价高于或等于限价时才可能成交 elif (
# 即:我愿意出 limit_price 卖,开盘价 open_price 更高或一样,当然卖 order.direction == "SELL" or order.direction == "CLOSE_LONG"
if base_price >= limit_price: ): # 限价卖出/平多
fill_price = base_price * (1 - self.slippage_rate) # 如果当前K线的最高价高于或等于限价则卖出限价单有机会成交
# else: 未满足限价条件,不成交 if current_bar.high > limit_price:
# 成交价以限价为基准,并考虑滑点(卖出向下偏离)
fill_price = limit_price * (1 - self.slippage_rate)
# 最终检查成交价是否有效且合理大于0
if fill_price <= 0: if fill_price <= 0:
return -1.0 # 未成交或价格无效 return -1.0
return fill_price return fill_price
def send_order_to_pending(self, order: Order) -> Optional[Order]: def send_order_to_pending(self, order: Order) -> Optional[Order]:
""" """
将订单添加到待处理队列。由 BacktestEngine 或 Strategy 调用。 将订单添加到待处理队列。由 BacktestEngine 或 Strategy 调用。
此方法不进行撮合,撮合由 process_pending_orders 统一处理。 此方法不进行撮合,撮合由 process_pending_orders 统一处理。
""" """
if order.id in self.pending_orders: if order.id in self.pending_orders:
# print(f"订单 {order.id} 已经存在于待处理队列。")
return None return None
self.pending_orders[order.id] = order self.pending_orders[order.id] = order
# print(f"订单 {order.id} 加入待处理队列。")
return order return order
def process_pending_orders(self, current_bar: Bar): def process_pending_orders(self, current_bar: Bar):
""" """
处理所有待撮合的订单。在每个K线数据到来时调用。 处理所有待撮合的订单。在每个K线数据到来时调用。
""" """
# 复制一份待处理订单的键,防止在迭代时修改字典
order_ids_to_process = list(self.pending_orders.keys()) order_ids_to_process = list(self.pending_orders.keys())
for order_id in order_ids_to_process: for order_id in order_ids_to_process:
if order_id not in self.pending_orders: # 订单可能已被取消 if order_id not in self.pending_orders:
continue continue
order = self.pending_orders[order_id] order = self.pending_orders[order_id]
# 只有当订单的symbol与当前bar的symbol一致时才尝试撮合
# 这样确保了在换月后,旧合约的挂单不会被尝试撮合 (尽管换月时会强制取消)
if order.symbol != current_bar.symbol: if order.symbol != current_bar.symbol:
# 这种情况理论上应该被换月逻辑清理掉的旧合约挂单,
# 如果因为某种原因漏掉了,这里直接跳过,避免异常。
continue continue
# 尝试成交订单
self._execute_single_order(order, current_bar) self._execute_single_order(order, current_bar)
def _execute_single_order(self, order: Order, current_bar: Bar) -> Optional[Trade]: def _execute_single_order(self, order: Order, current_bar: Bar) -> Optional[Trade]:
@@ -125,193 +139,185 @@ class ExecutionSimulator:
内部方法:尝试执行单个订单,并处理资金和持仓变化。 内部方法:尝试执行单个订单,并处理资金和持仓变化。
由 send_order 或 process_pending_orders 调用。 由 send_order 或 process_pending_orders 调用。
""" """
# --- 处理撤单指令 --- if order.direction == "CANCEL":
if order.direction == "CANCEL": # 策略主动发起撤单
success = self.cancel_order(order.id) success = self.cancel_order(order.id)
if success: if success:
# print(f"[{current_bar.datetime}] 模拟器: 收到并成功处理撤单指令 for Order ID: {order.id}")
pass pass
return None # 撤单操作不返回Trade return None
symbol = order.symbol symbol = order.symbol
volume = order.volume volume = order.volume
# 尝试计算成交价格
fill_price = self._calculate_fill_price(order, current_bar) fill_price = self._calculate_fill_price(order, current_bar)
if fill_price <= 0:
if fill_price <= 0: # 未成交或不满足限价条件
return None return None
# --- 以下是订单成功成交前的预检查逻辑 ---
trade_value = volume * fill_price trade_value = volume * fill_price
commission = trade_value * self.commission_rate commission = trade_value * self.commission_rate
current_position = self.positions.get(symbol, 0) current_position = self.positions.get(symbol, 0)
current_average_cost = self.average_costs.get(symbol, 0.0) current_average_cost = self.average_costs.get(symbol, 0.0)
realized_pnl = 0.0 # 预先计算的实现盈亏 realized_pnl = 0.0
# -----------------------------------------------------------
# 精确判断 is_open_trade 和 is_close_trade
# -----------------------------------------------------------
is_trade_a_close_operation = False is_trade_a_close_operation = False
is_trade_an_open_operation = False is_trade_an_open_operation = False
# 1. 判断是否为平仓操作 if order.direction in ["CLOSE_LONG", "CLOSE_SHORT"]:
# 显式平仓指令
if order.direction in ["CLOSE_LONG", "CLOSE_SELL", "CLOSE_SHORT"]:
is_trade_a_close_operation = True is_trade_a_close_operation = True
# 隐式平仓 (例如,持有空头时买入,或持有多头时卖出) elif order.direction == "BUY" and current_position < 0:
elif order.direction == "BUY" and current_position < 0: # 买入平空
is_trade_a_close_operation = True is_trade_a_close_operation = True
elif order.direction == "SELL" and current_position > 0: # 卖出平多 elif order.direction == "SELL" and current_position > 0:
is_trade_a_close_operation = True is_trade_a_close_operation = True
# 2. 判断是否为开仓操作
if order.direction == "BUY": if order.direction == "BUY":
# 买入开多: 如果当前持有多头或无仓位,或者从空头转为多头 if current_position >= 0 or (
if current_position >= 0 or (current_position < 0 and (current_position + volume) > 0): current_position < 0 and (current_position + volume) > 0
):
is_trade_an_open_operation = True is_trade_an_open_operation = True
elif order.direction == "SELL": elif order.direction == "SELL":
# 卖出开空: 如果当前持有空头或无仓位,或者从多头转为空头 if current_position <= 0 or (
if current_position <= 0 or (current_position > 0 and (current_position - volume) < 0): current_position > 0 and (current_position - volume) < 0
):
is_trade_an_open_operation = True is_trade_an_open_operation = True
# -----------------------------------------------------------
# 区分实际的买卖方向 (用于资金和持仓计算)
actual_execution_direction = "" actual_execution_direction = ""
if order.direction == "BUY" or order.direction == "CLOSE_SHORT": if order.direction == "BUY" or order.direction == "CLOSE_SHORT":
actual_execution_direction = "BUY" actual_execution_direction = "BUY"
elif order.direction == "SELL" or order.direction == "CLOSE_LONG" or order.direction == "CLOSE_SELL": elif order.direction == "SELL" or order.direction == "CLOSE_LONG":
actual_execution_direction = "SELL" actual_execution_direction = "SELL"
else: else:
print( print(
f"[{current_bar.datetime}] 模拟器: 收到未知订单方向 {order.direction} for Order ID: {order.id}. 订单未处理。") f"[{current_bar.datetime}] 模拟器: 收到未知订单方向 {order.direction} for Order ID: {order.id}. 订单未处理。"
if order.id in self.pending_orders: del self.pending_orders[order.id] )
if order.id in self.pending_orders:
del self.pending_orders[order.id]
return None return None
# --- 临时变量,用于预计算新的资金和持仓状态 ---
temp_cash = self.cash temp_cash = self.cash
temp_positions = self.positions.copy() temp_positions = self.positions.copy()
temp_average_costs = self.average_costs.copy() temp_average_costs = self.average_costs.copy()
# 根据实际执行方向进行预计算和资金检查 if actual_execution_direction == "BUY":
if actual_execution_direction == "BUY": # 处理实际的买入 (开多 / 平空) if current_position >= 0:
if current_position >= 0: # 当前持有多仓或无仓位 (开多)
required_cash = trade_value + commission required_cash = trade_value + commission
if temp_cash < required_cash: if temp_cash < required_cash:
print( # print(
f"[{current_bar.datetime}] 模拟器: 资金不足 (开多), 无法执行买入 {volume} {symbol} @ {fill_price:.2f}. 需要: {required_cash:.2f}, 当前: {temp_cash:.2f}") # f"[{current_bar.datetime}] 模拟器: 资金不足 (开多), 无法执行买入 {volume} {symbol} @ {fill_price:.2f}. 需要: {required_cash:.2f}, 当前: {temp_cash:.2f}"
if order.id in self.pending_orders: del self.pending_orders[order.id] # )
if order.id in self.pending_orders:
del self.pending_orders[order.id]
return None return None
temp_cash -= required_cash temp_cash -= required_cash
new_total_cost = (temp_average_costs.get(symbol, 0.0) * temp_positions.get(symbol, 0)) + ( new_total_cost = (
fill_price * volume) temp_average_costs.get(symbol, 0.0) * temp_positions.get(symbol, 0)
) + (fill_price * volume)
new_total_volume = temp_positions.get(symbol, 0) + volume new_total_volume = temp_positions.get(symbol, 0) + volume
temp_average_costs[symbol] = new_total_cost / new_total_volume if new_total_volume > 0 else 0.0 temp_average_costs[symbol] = (
new_total_cost / new_total_volume if new_total_volume > 0 else 0.0
)
temp_positions[symbol] = new_total_volume temp_positions[symbol] = new_total_volume
else: # 当前持有空仓 (平空) - 平仓交易,佣金从交易价值中扣除,不单独检查现金余额 else: # 当前持有空仓 (平空) - 平仓交易
pnl_per_share = current_average_cost - fill_price # 空头平仓盈亏 pnl_per_share = current_average_cost - fill_price
realized_pnl = pnl_per_share * volume realized_pnl = pnl_per_share * volume
temp_cash -= commission # 扣除佣金 temp_cash -= commission
temp_cash += trade_value # 回收平仓价值 temp_cash -= trade_value
temp_cash += realized_pnl # 计入实现盈亏 temp_cash += realized_pnl
temp_positions[symbol] += volume temp_positions[symbol] += volume
if temp_positions[symbol] == 0: if temp_positions[symbol] == 0:
del temp_positions[symbol] del temp_positions[symbol]
if symbol in temp_average_costs: del temp_average_costs[symbol] if symbol in temp_average_costs:
elif current_position < 0 and temp_positions[symbol] > 0: # 发生空转多 del temp_average_costs[symbol]
temp_average_costs[symbol] = fill_price # 新多头仓位成本以成交价为准 elif current_position < 0 and temp_positions[symbol] > 0:
temp_average_costs[symbol] = fill_price
elif actual_execution_direction == "SELL":
elif actual_execution_direction == "SELL": # 处理实际的卖出 (开空 / 平多)
if current_position <= 0: # 当前持有空仓或无仓位 (开空) if current_position <= 0: # 当前持有空仓或无仓位 (开空)
# 开空主要检查佣金是否足够
if temp_cash < commission: if temp_cash < commission:
print( # print(
f"[{current_bar.datetime}] 模拟器: 资金不足 (开空佣金), 无法执行卖出 {volume} {symbol} @ {fill_price:.2f}. 佣金: {commission:.2f}, 当前: {temp_cash:.2f}") # f"[{current_bar.datetime}] 模拟器: 资金不足 (开空佣金), 无法执行卖出 {volume} {symbol} @ {fill_price:.2f}. 佣金: {commission:.2f}, 当前: {temp_cash:.2f}"
if order.id in self.pending_orders: del self.pending_orders[order.id] # )
if order.id in self.pending_orders:
del self.pending_orders[order.id]
return None return None
temp_cash -= commission temp_cash -= commission
new_total_value = (temp_average_costs.get(symbol, 0.0) * abs(temp_positions.get(symbol, 0))) + ( temp_cash += trade_value # 修正点:开空时将卖出资金计入现金
fill_price * volume)
new_total_volume = abs(temp_positions.get(symbol, 0)) + volume
temp_average_costs[symbol] = new_total_value / new_total_volume if new_total_volume > 0 else 0.0 # 平均成本
temp_positions[symbol] -= volume
else: # 当前持有多仓 (平多) - 平仓交易,佣金从交易价值中扣除,不单独检查现金余额 existing_abs_volume = abs(temp_positions.get(symbol, 0))
pnl_per_share = fill_price - current_average_cost # 多头平仓盈亏 existing_abs_cost = (
temp_average_costs.get(symbol, 0.0) * existing_abs_volume
)
new_total_value = existing_abs_cost + (fill_price * volume)
new_total_volume = existing_abs_volume + volume
temp_average_costs[symbol] = (
new_total_value / new_total_volume if new_total_volume > 0 else 0.0
)
temp_positions[symbol] = -new_total_volume
else: # 当前持有多仓 (平多) - 平仓交易
pnl_per_share = fill_price - current_average_cost
realized_pnl = pnl_per_share * volume realized_pnl = pnl_per_share * volume
temp_cash -= commission # 扣除佣金 temp_cash -= commission
temp_cash += trade_value # 回收平仓价值 temp_cash += trade_value
temp_cash += realized_pnl # 计入实现盈亏 temp_cash += realized_pnl
temp_positions[symbol] -= volume temp_positions[symbol] -= volume
if temp_positions[symbol] == 0: if temp_positions[symbol] == 0:
del temp_positions[symbol] del temp_positions[symbol]
if symbol in temp_average_costs: del temp_average_costs[symbol] if symbol in temp_average_costs:
elif current_position > 0 and temp_positions[symbol] < 0: # 发生多转空 del temp_average_costs[symbol]
temp_average_costs[symbol] = fill_price # 新空头仓位成本以成交价为准 elif current_position > 0 and temp_positions[symbol] < 0:
temp_average_costs[symbol] = fill_price
# --- 所有检查通过后,才正式更新模拟器状态 ---
self.cash = temp_cash self.cash = temp_cash
self.positions = temp_positions self.positions = temp_positions
self.average_costs = temp_average_costs self.average_costs = temp_average_costs
# 创建 Trade 对象时direction 使用原始订单的 direction
executed_trade = Trade( executed_trade = Trade(
order_id=order.id, fill_time=current_bar.datetime, symbol=symbol, order_id=order.id,
direction=order.direction, # 使用原始订单的 direction fill_time=current_bar.datetime,
volume=volume, price=fill_price, commission=commission, symbol=symbol,
cash_after_trade=self.cash, positions_after_trade=self.positions.copy(), direction=order.direction,
volume=volume,
price=fill_price,
commission=commission,
cash_after_trade=self.cash,
positions_after_trade=self.positions.copy(),
realized_pnl=realized_pnl, realized_pnl=realized_pnl,
is_open_trade=is_trade_an_open_operation, # 使用更精确的判断 is_open_trade=is_trade_an_open_operation,
is_close_trade=is_trade_a_close_operation # 使用更精确的判断 is_close_trade=is_trade_a_close_operation,
) )
self.trade_log.append(executed_trade) self.trade_log.append(executed_trade)
# 订单成交,从待处理订单中移除
if order.id in self.pending_orders: if order.id in self.pending_orders:
del self.pending_orders[order.id] del self.pending_orders[order.id]
return executed_trade return executed_trade
def cancel_order(self, order_id: str) -> bool: def cancel_order(self, order_id: str) -> bool:
"""
尝试取消一个待处理订单。
"""
if order_id in self.pending_orders: if order_id in self.pending_orders:
del self.pending_orders[order_id] del self.pending_orders[order_id]
return True return True
return False return False
# --- 新增:强制平仓指定合约的所有持仓 --- def force_close_all_positions_for_symbol(
def force_close_all_positions_for_symbol(self, symbol_to_close: str, closing_bar: Bar) -> List[Trade]: self, symbol_to_close: str, closing_bar: Bar
""" ) -> List[Trade]:
强制平仓指定合约的所有持仓。
Args:
symbol_to_close (str): 需要平仓的合约代码。
closing_bar (Bar): 用于获取平仓价格的当前K线数据通常是旧合约的最后一根K线
Returns:
List[Trade]: 因强制平仓而产生的交易记录。
"""
closed_trades: List[Trade] = [] closed_trades: List[Trade] = []
# 仅处理指定symbol的持仓
if symbol_to_close in self.positions and self.positions[symbol_to_close] != 0: if symbol_to_close in self.positions and self.positions[symbol_to_close] != 0:
volume_to_close = self.positions[symbol_to_close] volume_to_close = self.positions[symbol_to_close]
# 根据持仓方向决定平仓订单的方向 direction = "CLOSE_LONG" if volume_to_close > 0 else "CLOSE_SHORT"
direction = "CLOSE_LONG" if volume_to_close > 0 else "CLOSE_SELL" # 多头平仓是卖出,空头平仓是买入
# 构造一个市价平仓订单
rollover_order = Order( rollover_order = Order(
id=f"FORCE_CLOSE_{symbol_to_close}_{closing_bar.datetime.strftime('%Y%m%d%H%M%S%f')}", id=f"FORCE_CLOSE_{symbol_to_close}_{closing_bar.datetime.strftime('%Y%m%d%H%M%S%f')}",
symbol=symbol_to_close, symbol=symbol_to_close,
@@ -321,28 +327,26 @@ class ExecutionSimulator:
limit_price=None, limit_price=None,
submitted_time=closing_bar.datetime, submitted_time=closing_bar.datetime,
) )
# 这里直接调用 _execute_single_order 确保强制平仓立即成交
# 使用内部的执行逻辑进行撮合
trade = self._execute_single_order(rollover_order, closing_bar) trade = self._execute_single_order(rollover_order, closing_bar)
if trade: if trade:
closed_trades.append(trade) closed_trades.append(trade)
else: else:
print(f"[{closing_bar.datetime}] 警告: 强制平仓 {symbol_to_close} 失败!") print(
f"[{closing_bar.datetime}] 警告: 强制平仓 {symbol_to_close} 失败!"
)
return closed_trades return closed_trades
# --- 新增:取消指定合约的所有挂单 ---
def cancel_all_pending_orders_for_symbol(self, symbol_to_cancel: str) -> int: def cancel_all_pending_orders_for_symbol(self, symbol_to_cancel: str) -> int:
"""
取消指定合约的所有待处理订单。
"""
cancelled_count = 0 cancelled_count = 0
order_ids_to_cancel = [ order_ids_to_cancel = [
order_id for order_id, order in self.pending_orders.items() order_id
for order_id, order in self.pending_orders.items()
if order.symbol == symbol_to_cancel if order.symbol == symbol_to_cancel
] ]
for order_id in order_ids_to_cancel: for order_id in order_ids_to_cancel:
if self.cancel_order(order_id): # 调用现有的 cancel_order 方法 if self.cancel_order(order_id):
cancelled_count += 1 cancelled_count += 1
return cancelled_count return cancelled_count
@@ -350,37 +354,15 @@ class ExecutionSimulator:
return self.pending_orders.copy() return self.pending_orders.copy()
def get_portfolio_value(self, current_bar: Bar) -> float: def get_portfolio_value(self, current_bar: Bar) -> float:
"""
计算当前的投资组合总价值(包括现金和持仓市值)。
此方法需要兼容多合约持仓的场景。
Args:
current_bar (Bar): 当前的Bar数据用于计算**当前活跃合约**的持仓市值。
注意:如果 simulator 中持有多个合约,这里需要更复杂的逻辑。
目前假设主力合约回测时simulator.positions 主要只包含当前主力合约。
Returns:
float: 当前的投资组合总价值。
"""
total_value = self.cash total_value = self.cash
# 遍历所有持仓,计算市值。
# 注意:这里假设 current_bar 提供了当前活跃主力合约的价格。
# 如果 self.positions 中包含其他非 current_bar.symbol 的旧合约,
# 它们的市值将无法用 current_bar.open 来准确计算。
# 在换月模式下,旧合约会被强制平仓,因此 simulator.positions 通常只包含一个合约。
for symbol, quantity in self.positions.items(): for symbol, quantity in self.positions.items():
# 这里简单处理:如果持仓合约与 current_bar.symbol 相同,则使用 current_bar.open 计算。
# 如果是其他合约,则需要外部提供其最新价格,但这超出了本函数当前的能力范围。
# 考虑到换月模式,旧合约会被平仓,所以大部分时候这不会是问题。
if symbol == current_bar.symbol: if symbol == current_bar.symbol:
total_value += quantity * current_bar.open total_value += quantity * current_bar.open
else: else:
# 警告:如果这里出现,说明有未平仓的旧合约持仓,且没有其最新价格来计算市值。 print(
# 在严谨的主力连续回测中,这不应该发生,因为换月会强制平仓。 f"[{current_bar.datetime}] 警告持仓中存在非当前K线合约 {symbol},无法准确计算其市值。"
print(f"[{current_bar.datetime}] 警告持仓中存在非当前K线合约 {symbol},无法准确计算其市值。") )
# 可以选择将这部分持仓价值计为0或者使用上一个已知价格需要额外数据结构
# 这里我们假设它不影响总价值计算,因为换月时会处理掉
pass pass
return total_value return total_value
def get_current_positions(self) -> Dict[str, int]: def get_current_positions(self) -> Dict[str, int]:
@@ -389,22 +371,38 @@ class ExecutionSimulator:
def get_trade_history(self) -> List[Trade]: def get_trade_history(self) -> List[Trade]:
return self.trade_log.copy() return self.trade_log.copy()
def reset(self, new_initial_capital: float = None, new_initial_positions: Dict[str, int] = None) -> None: def reset(
""" self,
重置模拟器状态到新的初始条件。 new_initial_capital: float = None,
此方法不用于换月时的平仓,它用于整个回测开始前的初始化。 new_initial_positions: Dict[str, int] = None,
""" new_initial_average_costs: Dict[str, float] = None,
) -> None: # 新增参数
print("ExecutionSimulator: 重置状态。") print("ExecutionSimulator: 重置状态。")
self.cash = new_initial_capital if new_initial_capital is not None else self.initial_capital self.cash = (
self.positions = new_initial_positions.copy() if new_initial_positions is not None else {} new_initial_capital
self.average_costs = {} if new_initial_capital is not None
for symbol, qty in self.positions.items(): # 重置平均成本 else self.initial_capital
self.average_costs[symbol] = 0.0 )
self.trade_log = [] self.positions = (
self.pending_orders = {} # 清空挂单 new_initial_positions.copy() if new_initial_positions is not None else {}
self._current_time = None )
# 修正:重置时也应该考虑传入初始平均成本
self.average_costs = (
new_initial_average_costs.copy()
if new_initial_average_costs is not None
else {}
)
if self.positions and not new_initial_average_costs:
print(
f"[{datetime.now()}] 警告: 重置时提供了初始持仓但未提供初始平均成本这些持仓的成本默认为0.0。"
)
for symbol, qty in self.positions.items():
if symbol not in self.average_costs:
self.average_costs[symbol] = 0.0
# Removed clear_trade_history as trade_log is cleared in reset self.trade_log = []
self.pending_orders = {}
self._current_time = None
def get_average_position_price(self, symbol: str) -> Optional[float]: def get_average_position_price(self, symbol: str) -> Optional[float]:
if symbol in self.positions and self.positions[symbol] != 0: if symbol in self.positions and self.positions[symbol] != 0:

View File

@@ -5,7 +5,8 @@ from ..core_data import Bar, Order
from typing import Optional, Dict, Any from typing import Optional, Dict, Any
from collections import deque from collections import deque
class SimpleLimitBuyStrategy(Strategy):
class SimpleLimitBuyStrategyLong(Strategy):
""" """
一个基于当前K线Open、前1根和前7根K线Range计算优势价格进行限价买入的策略。 一个基于当前K线Open、前1根和前7根K线Range计算优势价格进行限价买入的策略。
具备以下特点: 具备以下特点:
@@ -13,16 +14,23 @@ class SimpleLimitBuyStrategy(Strategy):
- 最多只能有一个开仓挂单和一个持仓。 - 最多只能有一个开仓挂单和一个持仓。
- 包含简单的止损和止盈逻辑。 - 包含简单的止损和止盈逻辑。
""" """
def __init__(self, simulator: Any, symbol: str, enable_log: bool, trade_volume: int,
open_range_factor_1_ago: float, def __init__(
open_range_factor_7_ago: float, self,
max_position: int, context: Any,
stop_loss_points: float = 10, # 新增:止损点数 symbol: str,
take_profit_points: float = 10): # 新增:止盈点数 enable_log: bool,
trade_volume: int,
open_range_factor_1_ago: float,
open_range_factor_7_ago: float,
max_position: int,
stop_loss_points: float = 10, # 新增:止损点数
take_profit_points: float = 10,
): # 新增:止盈点数
""" """
初始化策略。 初始化策略。
Args: Args:
simulator: 模拟器实例。 context: 模拟器实例。
symbol (str): 交易合约代码。 symbol (str): 交易合约代码。
trade_volume (int): 单笔交易量。 trade_volume (int): 单笔交易量。
open_range_factor_1_ago (float): 前1根K线Range的权重因子用于从Open价向下偏移。 open_range_factor_1_ago (float): 前1根K线Range的权重因子用于从Open价向下偏移。
@@ -31,33 +39,34 @@ class SimpleLimitBuyStrategy(Strategy):
stop_loss_points (float): 止损点数(例如,亏损达到此点数则止损)。 stop_loss_points (float): 止损点数(例如,亏损达到此点数则止损)。
take_profit_points (float): 止盈点数(例如,盈利达到此点数则止盈)。 take_profit_points (float): 止盈点数(例如,盈利达到此点数则止盈)。
""" """
super().__init__(simulator, symbol, enable_log) super().__init__(context, symbol, enable_log)
self.trade_volume = trade_volume self.trade_volume = trade_volume
self.open_range_factor_1_ago = open_range_factor_1_ago self.open_range_factor_1_ago = open_range_factor_1_ago
self.open_range_factor_7_ago = open_range_factor_7_ago self.open_range_factor_7_ago = open_range_factor_7_ago
self.max_position = max_position # 理论上这里应为1 self.max_position = max_position # 理论上这里应为1
self.stop_loss_points = stop_loss_points self.stop_loss_points = stop_loss_points
self.take_profit_points = take_profit_points self.take_profit_points = take_profit_points
self.order_id_counter = 0 self.order_id_counter = 0
self._bar_history: deque[Bar] = deque(maxlen=10) self._last_order_id: Optional[str] = None # 用于跟踪上一根K线发出的订单ID
self._last_order_id: Optional[str] = None # 用于跟踪上一根K线发出的订单ID
self.log(f"策略初始化: symbol={self.symbol}, trade_volume={self.trade_volume}, " self.log(
f"open_range_factor_1_ago={self.open_range_factor_1_ago}, " f"策略初始化: symbol={self.symbol}, trade_volume={self.trade_volume}, "
f"open_range_factor_7_ago={self.open_range_factor_7_ago}, " f"open_range_factor_1_ago={self.open_range_factor_1_ago}, "
f"max_position={self.max_position}, " f"open_range_factor_7_ago={self.open_range_factor_7_ago}, "
f"止损点={self.stop_loss_points}, 止盈点={self.take_profit_points}") f"max_position={self.max_position}, "
f"止损点={self.stop_loss_points}, 止盈点={self.take_profit_points}"
)
def on_bar(self, bar: Bar, next_bar_open: Optional[float] = None): def on_open_bar(self, bar: Bar, next_bar_open: Optional[float] = None):
""" """
每当新的K线数据到来时调用。 每当新的K线数据到来时调用。
Args: Args:
bar (Bar): 当前的K线数据对象。 bar (Bar): 当前的K线数据对象。
next_bar_open (Optional[float]): 下一根K线的开盘价此处策略未使用。 next_bar_open (Optional[float]): 下一根K线的开盘价此处策略未使用。
""" """
current_datetime = bar.datetime # 获取当前K线时间 current_datetime = bar.datetime # 获取当前K线时间
self.symbol = bar.symbol self.symbol = bar.symbol
# --- 1. 撤销上一根K线未成交的订单 --- # --- 1. 撤销上一根K线未成交的订单 ---
@@ -65,36 +74,46 @@ class SimpleLimitBuyStrategy(Strategy):
if self._last_order_id: if self._last_order_id:
pending_orders = self.get_pending_orders() pending_orders = self.get_pending_orders()
if self._last_order_id in pending_orders: if self._last_order_id in pending_orders:
success = self.cancel_order(self._last_order_id) # 直接调用基类的取消方法 success = self.cancel_order(
self._last_order_id
) # 直接调用基类的取消方法
if success: if success:
self.log(f"[{current_datetime}] 策略: 成功撤销上一根K线未成交订单 {self._last_order_id}") self.log(
f"[{current_datetime}] 策略: 成功撤销上一根K线未成交订单 {self._last_order_id}"
)
else: else:
self.log(f"[{current_datetime}] 策略: 尝试撤销订单 {self._last_order_id} 失败(可能已成交或不存在)") self.log(
f"[{current_datetime}] 策略: 尝试撤销订单 {self._last_order_id} 失败(可能已成交或不存在)"
)
# 无论撤销成功与否,既然我们尝试了撤销,就清除记录 # 无论撤销成功与否,既然我们尝试了撤销,就清除记录
self._last_order_id = None self._last_order_id = None
# else: # else:
# self.log(f"[{current_datetime}] 策略: 无上一根K线未成交订单需要撤销。") # self.log(f"[{current_datetime}] 策略: 无上一根K线未成交订单需要撤销。")
# 2. 更新K线历史 # 2. 更新K线历史
self._bar_history.append(bar)
trade_volume = self.trade_volume trade_volume = self.trade_volume
# 获取当前持仓和未决订单(在取消之后获取,确保是最新的状态) # 获取当前持仓和未决订单(在取消之后获取,确保是最新的状态)
current_positions = self.get_current_positions() current_positions = self.get_current_positions()
current_pos_volume = current_positions.get(self.symbol, 0) current_pos_volume = current_positions.get(self.symbol, 0)
pending_orders_after_cancel = self.get_pending_orders() # 再次获取,此时应已取消旧订单 pending_orders_after_cancel = (
self.get_pending_orders()
) # 再次获取,此时应已取消旧订单
# --- 3. 平仓逻辑 (止损/止盈) --- # --- 3. 平仓逻辑 (止损/止盈) ---
# 只有当有持仓时才考虑平仓 # 只有当有持仓时才考虑平仓
if current_pos_volume > 0: # 假设只做多,所以持仓量 > 0 if current_pos_volume > 0: # 假设只做多,所以持仓量 > 0
avg_entry_price = self.get_average_position_price(self.symbol) avg_entry_price = self.get_average_position_price(self.symbol)
if avg_entry_price is not None: if avg_entry_price is not None:
pnl_per_unit = bar.close - avg_entry_price # 当前浮动盈亏(以收盘价计算) pnl_per_unit = (
bar.open - avg_entry_price
) # 当前浮动盈亏(以收盘价计算)
# 止盈条件 # 止盈条件
if pnl_per_unit >= self.take_profit_points: if pnl_per_unit >= self.take_profit_points:
self.log(f"[{current_datetime}] 止盈信号 - PnL per unit: {pnl_per_unit:.2f}, 目标: {self.take_profit_points:.2f}") self.log(
f"[{current_datetime}] 止盈信号 - PnL per unit: {pnl_per_unit:.2f}, 目标: {self.take_profit_points:.2f}"
)
order_id = f"{self.symbol}_BUY_{bar.datetime.strftime('%Y%m%d%H%M%S')}_{self.order_id_counter}" order_id = f"{self.symbol}_BUY_{bar.datetime.strftime('%Y%m%d%H%M%S')}_{self.order_id_counter}"
self.order_id_counter += 1 self.order_id_counter += 1
@@ -107,14 +126,17 @@ class SimpleLimitBuyStrategy(Strategy):
volume=trade_volume, volume=trade_volume,
price_type="MARKET", price_type="MARKET",
# limit_price=limit_price, # limit_price=limit_price,
submitted_time=bar.datetime submitted_time=bar.datetime,
offset="CLOSE",
) )
trade = self.send_order(order) trade = self.send_order(order)
return # 平仓后本K线不再进行开仓判断 return # 平仓后本K线不再进行开仓判断
# 止损条件 # 止损条件
elif pnl_per_unit <= -self.stop_loss_points: elif pnl_per_unit <= -self.stop_loss_points:
self.log(f"[{current_datetime}] 止损信号 - PnL per unit: {pnl_per_unit:.2f}, 目标: {-self.stop_loss_points:.2f}") self.log(
f"[{current_datetime}] 止损信号 - PnL per unit: {pnl_per_unit:.2f}, 目标: {-self.stop_loss_points:.2f}"
)
# 发送市价卖出订单平仓,确保立即成交 # 发送市价卖出订单平仓,确保立即成交
order_id = f"{self.symbol}_BUY_{bar.datetime.strftime('%Y%m%d%H%M%S')}_{self.order_id_counter}" order_id = f"{self.symbol}_BUY_{bar.datetime.strftime('%Y%m%d%H%M%S')}_{self.order_id_counter}"
self.order_id_counter += 1 self.order_id_counter += 1
@@ -127,20 +149,21 @@ class SimpleLimitBuyStrategy(Strategy):
volume=trade_volume, volume=trade_volume,
price_type="MARKET", price_type="MARKET",
# limit_price=limit_price, # limit_price=limit_price,
submitted_time=bar.datetime submitted_time=bar.datetime,
offset="CLOSE",
) )
trade = self.send_order(order) trade = self.send_order(order)
return # 平仓后本K线不再进行开仓判断 return # 平仓后本K线不再进行开仓判断
# --- 4. 开仓逻辑 (只考虑做多 BUY 方向) --- # --- 4. 开仓逻辑 (只考虑做多 BUY 方向) ---
# 只有在没有持仓 (current_pos_volume == 0) 且没有待处理订单 (not pending_orders_after_cancel) # 只有在没有持仓 (current_pos_volume == 0) 且没有待处理订单 (not pending_orders_after_cancel)
# 且K线历史足够长时才考虑开仓 # 且K线历史足够长时才考虑开仓
if current_pos_volume == 0 and \ bar_history = self.get_bar_history()
len(self._bar_history) == self._bar_history.maxlen: if current_pos_volume == 0 and len(bar_history) > 10:
# 获取前1根K线 (倒数第二根) 和前7根K线 (队列中最老的一根) # 获取前1根K线 (倒数第二根) 和前7根K线 (队列中最老的一根)
bar_1_ago = self._bar_history[-2] bar_1_ago = bar_history[-2]
bar_7_ago = self._bar_history[-8] bar_7_ago = bar_history[-8]
# 计算历史 K 线的 Range # 计算历史 K 线的 Range
range_1_ago = bar_1_ago.high - bar_1_ago.low range_1_ago = bar_1_ago.high - bar_1_ago.low
@@ -148,16 +171,24 @@ class SimpleLimitBuyStrategy(Strategy):
# 根据策略逻辑计算目标买入价格 # 根据策略逻辑计算目标买入价格
# 目标买入价 = 当前K线Open - (前1根Range * 因子1 + 前7根Range * 因子2) # 目标买入价 = 当前K线Open - (前1根Range * 因子1 + 前7根Range * 因子2)
self.log(bar.open ,range_1_ago * self.open_range_factor_1_ago, range_7_ago * self.open_range_factor_7_ago) self.log(
target_buy_price = bar.open - (range_1_ago * self.open_range_factor_1_ago + range_7_ago * self.open_range_factor_7_ago) bar.open,
range_1_ago * self.open_range_factor_1_ago,
range_7_ago * self.open_range_factor_7_ago,
)
target_buy_price = bar.open - (
range_1_ago * self.open_range_factor_1_ago
+ range_7_ago * self.open_range_factor_7_ago
)
# 确保目标买入价格有效,例如不能是负数 # 确保目标买入价格有效,例如不能是负数
target_buy_price = max(0.01, target_buy_price) target_buy_price = max(0.01, target_buy_price)
self.log(f"[{current_datetime}] 开多仓信号 - 当前Open={bar.open:.2f}, " self.log(
f"前1Range={range_1_ago:.2f}, 前7Range={range_7_ago:.2f}, " f"[{current_datetime}] 开多仓信号 - 当前Open={bar.open:.2f}, "
f"计算目标买入价={target_buy_price:.2f}") f"前1Range={range_1_ago:.2f}, 前7Range={range_7_ago:.2f}, "
self.log(f'{self.context._simulator.get_current_positions()}') f"计算目标买入价={target_buy_price:.2f}"
)
order_id = f"{self.symbol}_BUY_{bar.datetime.strftime('%Y%m%d%H%M%S')}_{self.order_id_counter}" order_id = f"{self.symbol}_BUY_{bar.datetime.strftime('%Y%m%d%H%M%S')}_{self.order_id_counter}"
self.order_id_counter += 1 self.order_id_counter += 1
@@ -170,26 +201,554 @@ class SimpleLimitBuyStrategy(Strategy):
volume=trade_volume, volume=trade_volume,
price_type="LIMIT", price_type="LIMIT",
limit_price=target_buy_price, limit_price=target_buy_price,
submitted_time=bar.datetime submitted_time=bar.datetime,
) )
new_order = self.send_order(order) new_order = self.send_order(order)
# 记录下这个订单的ID以便在下一根K线开始时进行撤销 # 记录下这个订单的ID以便在下一根K线开始时进行撤销
if new_order: if new_order:
self._last_order_id = new_order.id self._last_order_id = new_order.id
self.log(f"[{current_datetime}] 策略: 发送限价买入订单 {self._last_order_id} @ {target_buy_price:.2f}") self.log(
f"[{current_datetime}] 策略: 发送限价买入订单 {self._last_order_id} @ {target_buy_price:.2f}"
)
else: else:
self.log(f"[{current_datetime}] 策略: 发送订单失败。") self.log(f"[{current_datetime}] 策略: 发送订单失败。")
# else: # else:
# self.log(f"[{current_datetime}] 不满足开仓条件:持仓={current_pos_volume}, 待处理订单={len(pending_orders_after_cancel)}, K线历史长度={len(self._bar_history)}") # self.log(f"[{current_datetime}] 不满足开仓条件:持仓={current_pos_volume}, 待处理订单={len(pending_orders_after_cancel)}, K线历史长度={len(bar_history)}")
def on_close_bar(self, bar: Bar, next_bar_open: Optional[float] = None):
self.cancel_all_pending_orders()
def on_rollover(self, old_symbol: str, new_symbol: str): def on_rollover(self, old_symbol: str, new_symbol: str):
""" """
在合约换月时清空历史K线数据和上次订单ID避免使用旧合约数据进行计算。 在合约换月时清空历史K线数据和上次订单ID避免使用旧合约数据进行计算。
""" """
super().on_rollover(old_symbol, new_symbol) # 调用基类方法打印日志 super().on_rollover(old_symbol, new_symbol) # 调用基类方法打印日志
self._bar_history.clear() # 清空历史K线 self._last_order_id = None # 清空上次订单ID因为旧合约订单已取消
self._last_order_id = None # 清空上次订单ID因为旧合约订单已取消
self.log(f"换月完成清空历史K线数据和上次订单ID准备新合约交易。") self.log(f"换月完成清空历史K线数据和上次订单ID准备新合约交易。")
class SimpleLimitBuyStrategyShort(Strategy):
"""
一个基于当前K线Open、前1根和前7根K线Range计算优势价格进行限价买入的策略。
具备以下特点:
- 每根K线开始时取消上一根K线未成交的订单。
- 最多只能有一个开仓挂单和一个持仓。
- 包含简单的止损和止盈逻辑。
"""
def __init__(
self,
context: Any,
symbol: str,
enable_log: bool,
trade_volume: int,
open_range_factor_1_ago: float,
open_range_factor_7_ago: float,
max_position: int,
stop_loss_points: float = 10, # 新增:止损点数
take_profit_points: float = 10,
): # 新增:止盈点数
"""
初始化策略。
Args:
context: 模拟器实例。
symbol (str): 交易合约代码。
trade_volume (int): 单笔交易量。
open_range_factor_1_ago (float): 前1根K线Range的权重因子用于从Open价向下偏移。
open_range_factor_7_ago (float): 前7根K线Range的权重因子用于从Open价向下偏移。
max_position (int): 最大持仓量此处为1因为只允许一个持仓
stop_loss_points (float): 止损点数(例如,亏损达到此点数则止损)。
take_profit_points (float): 止盈点数(例如,盈利达到此点数则止盈)。
"""
super().__init__(context, symbol, enable_log)
self.trade_volume = trade_volume
self.open_range_factor_1_ago = open_range_factor_1_ago
self.open_range_factor_7_ago = open_range_factor_7_ago
self.max_position = max_position # 理论上这里应为1
self.stop_loss_points = stop_loss_points
self.take_profit_points = take_profit_points
self.order_id_counter = 0
self.last_buy_price = 0
self.last_sell_price = 0
bar_history: deque[Bar] = deque(maxlen=10)
self._last_order_id: Optional[str] = None # 用于跟踪上一根K线发出的订单ID
self.log(
f"策略初始化: symbol={self.symbol}, trade_volume={self.trade_volume}, "
f"open_range_factor_1_ago={self.open_range_factor_1_ago}, "
f"open_range_factor_7_ago={self.open_range_factor_7_ago}, "
f"max_position={self.max_position}, "
f"止损点={self.stop_loss_points}, 止盈点={self.take_profit_points}"
)
def on_open_bar(self, bar: Bar, next_bar_open: Optional[float] = None):
"""
每当新的K线数据到来时调用。
Args:
bar (Bar): 当前的K线数据对象。
next_bar_open (Optional[float]): 下一根K线的开盘价此处策略未使用。
"""
current_datetime = bar.datetime # 获取当前K线时间
self.symbol = bar.symbol
# --- 1. 撤销上一根K线未成交的订单 ---
# 检查是否记录了上一笔订单ID并且该订单仍然在待处理列表中
if self._last_order_id:
pending_orders = self.get_pending_orders()
if self._last_order_id in pending_orders:
success = self.cancel_order(
self._last_order_id
) # 直接调用基类的取消方法
if success:
self.log(
f"[{current_datetime}] 策略: 成功撤销上一根K线未成交订单 {self._last_order_id}"
)
else:
self.log(
f"[{current_datetime}] 策略: 尝试撤销订单 {self._last_order_id} 失败(可能已成交或不存在)"
)
# 无论撤销成功与否,既然我们尝试了撤销,就清除记录
self._last_order_id = None
# else:
# self.log(f"[{current_datetime}] 策略: 无上一根K线未成交订单需要撤销。")
# 2. 更新K线历史
trade_volume = self.trade_volume
# 获取当前持仓和未决订单(在取消之后获取,确保是最新的状态)
current_positions = self.get_current_positions()
current_pos_volume = current_positions.get(self.symbol, 0)
pending_orders_after_cancel = (
self.get_pending_orders()
) # 再次获取,此时应已取消旧订单
# --- 3. 平仓逻辑 (止损/止盈) ---
# 只有当有持仓时才考虑平仓
if current_pos_volume < 0: # 假设只做多,所以持仓量 > 0
avg_entry_price = self.get_average_position_price(self.symbol)
if avg_entry_price is not None:
pnl_per_unit = (
avg_entry_price - bar.open
) # 当前浮动盈亏(以收盘价计算)
# 止盈条件
if pnl_per_unit >= self.take_profit_points:
self.log(
f"[{current_datetime}] 止盈信号 - PnL per unit: {pnl_per_unit:.2f}, 目标: {self.take_profit_points:.2f}"
)
order_id = f"{self.symbol}_BUY_{bar.datetime.strftime('%Y%m%d%H%M%S')}_{self.order_id_counter}"
self.order_id_counter += 1
# 创建一个限价多单
order = Order(
id=order_id,
symbol=self.symbol,
direction="CLOSE_SHORT",
volume=trade_volume,
price_type="MARKET",
# limit_price=limit_price,
submitted_time=bar.datetime,
offset="CLOSE",
)
trade = self.send_order(order)
return # 平仓后本K线不再进行开仓判断
# 止损条件
elif pnl_per_unit <= -self.stop_loss_points:
self.log(
f"[{current_datetime}] 止损信号 - PnL per unit: {pnl_per_unit:.2f}, 目标: {-self.stop_loss_points:.2f}"
)
# 发送市价卖出订单平仓,确保立即成交
order_id = f"{self.symbol}_BUY_{bar.datetime.strftime('%Y%m%d%H%M%S')}_{self.order_id_counter}"
self.order_id_counter += 1
# 创建一个限价多单
order = Order(
id=order_id,
symbol=self.symbol,
direction="CLOSE_SHORT",
volume=trade_volume,
price_type="MARKET",
# limit_price=limit_price,
submitted_time=bar.datetime,
offset="CLOSE",
)
trade = self.send_order(order)
return # 平仓后本K线不再进行开仓判断
# --- 4. 开仓逻辑 (只考虑做多 BUY 方向) ---
# 只有在没有持仓 (current_pos_volume == 0) 且没有待处理订单 (not pending_orders_after_cancel)
# 且K线历史足够长时才考虑开仓
bar_history = self.get_bar_history()
if current_pos_volume == 0 and len(bar_history) > 10:
# 获取前1根K线 (倒数第二根) 和前7根K线 (队列中最老的一根)
bar_1_ago = bar_history[-2]
bar_7_ago = bar_history[-8]
# 计算历史 K 线的 Range
range_1_ago = bar_1_ago.high - bar_1_ago.low
range_7_ago = bar_7_ago.high - bar_7_ago.low
# 根据策略逻辑计算目标买入价格
# 目标买入价 = 当前K线Open - (前1根Range * 因子1 + 前7根Range * 因子2)
self.log(
bar.open,
range_1_ago * self.open_range_factor_1_ago,
range_7_ago * self.open_range_factor_7_ago,
)
target_buy_price = bar.open + (
range_1_ago * self.open_range_factor_1_ago
+ range_7_ago * self.open_range_factor_7_ago
)
# 确保目标买入价格有效,例如不能是负数
target_buy_price = max(0.01, target_buy_price)
self.log(
f"[{current_datetime}] 开多仓信号 - 当前Open={bar.open:.2f}, "
f"前1Range={range_1_ago:.2f}, 前7Range={range_7_ago:.2f}, "
f"计算目标买入价={target_buy_price:.2f}"
)
order_id = f"{self.symbol}_BUY_{bar.datetime.strftime('%Y%m%d%H%M%S')}_{self.order_id_counter}"
self.order_id_counter += 1
# 创建一个限价多单
order = Order(
id=order_id,
symbol=self.symbol,
direction="SELL",
volume=trade_volume,
price_type="LIMIT",
limit_price=target_buy_price,
submitted_time=bar.datetime,
)
new_order = self.send_order(order)
# 记录下这个订单的ID以便在下一根K线开始时进行撤销
if new_order:
self._last_order_id = new_order.id
self.log(
f"[{current_datetime}] 策略: 发送限价买入订单 {self._last_order_id} @ {target_buy_price:.2f}"
)
else:
self.log(f"[{current_datetime}] 策略: 发送订单失败。")
# else:
# self.log(f"[{current_datetime}] 不满足开仓条件:持仓={current_pos_volume}, 待处理订单={len(pending_orders_after_cancel)}, K线历史长度={len(bar_history)}")
def on_rollover(self, old_symbol: str, new_symbol: str):
"""
在合约换月时清空历史K线数据和上次订单ID避免使用旧合约数据进行计算。
"""
super().on_rollover(old_symbol, new_symbol) # 调用基类方法打印日志
# bar_history.clear() # 清空历史K线
self._last_order_id = None # 清空上次订单ID因为旧合约订单已取消
self.log(f"换月完成清空历史K线数据和上次订单ID准备新合约交易。")
class SimpleLimitBuyStrategy(Strategy):
"""
一个基于当前K线Open、前1根和前7根K线Range计算优势价格进行限价买入的策略。
具备以下特点:
- 每根K线开始时取消上一根K线未成交的订单。
- 最多只能有一个开仓挂单和一个持仓。
- 包含简单的止损和止盈逻辑。
"""
def __init__(
self,
context: Any,
symbol: str,
enable_log: bool,
trade_volume: int,
open_range_factor_1_long: float,
open_range_factor_7_long: float,
open_range_factor_1_short: float,
open_range_factor_7_short: float,
max_position: int,
stop_loss_points: float = 10, # 新增:止损点数
take_profit_points: float = 10,
): # 新增:止盈点数
"""
初始化策略。
Args:
context: 模拟器实例。
symbol (str): 交易合约代码。
trade_volume (int): 单笔交易量。
open_range_factor_1_ago (float): 前1根K线Range的权重因子用于从Open价向下偏移。
open_range_factor_7_ago (float): 前7根K线Range的权重因子用于从Open价向下偏移。
max_position (int): 最大持仓量此处为1因为只允许一个持仓
stop_loss_points (float): 止损点数(例如,亏损达到此点数则止损)。
take_profit_points (float): 止盈点数(例如,盈利达到此点数则止盈)。
"""
super().__init__(context, symbol, enable_log)
self.last_buy_price = 0
self.last_sell_price = 0
self.trade_volume = trade_volume
self.open_range_factor_1_long = open_range_factor_1_long
self.open_range_factor_7_long = open_range_factor_7_long
self.open_range_factor_1_short = open_range_factor_1_short
self.open_range_factor_7_short = open_range_factor_7_short
self.max_position = max_position # 理论上这里应为1
self.stop_loss_points = stop_loss_points
self.take_profit_points = take_profit_points
self.order_id_counter = 0
self._last_order_id: Optional[str] = None # 用于跟踪上一根K线发出的订单ID
self.log(
f"策略初始化: symbol={self.symbol}, trade_volume={self.trade_volume}, "
f"max_position={self.max_position}, "
f"止损点={self.stop_loss_points}, 止盈点={self.take_profit_points}"
)
def on_open_bar(self, bar: Bar, next_bar_open: Optional[float] = None):
"""
每当新的K线数据到来时调用。
Args:
bar (Bar): 当前的K线数据对象。
next_bar_open (Optional[float]): 下一根K线的开盘价此处策略未使用。
"""
current_datetime = bar.datetime # 获取当前K线时间
self.symbol = bar.symbol
# --- 1. 撤销上一根K线未成交的订单 ---
# 检查是否记录了上一笔订单ID并且该订单仍然在待处理列表中
if self._last_order_id:
pending_orders = self.get_pending_orders()
# if self._last_order_id in pending_orders:
# success = self.cancel_order(self._last_order_id) # 直接调用基类的取消方法
# if success:
# self.log(f"[{current_datetime}] 策略: 成功撤销上一根K线未成交订单 {self._last_order_id}")
# else:
# self.log(f"[{current_datetime}] 策略: 尝试撤销订单 {self._last_order_id} 失败(可能已成交或不存在)")
# # 无论撤销成功与否,既然我们尝试了撤销,就清除记录
# self._last_order_id = None
self.cancel_all_pending_orders()
# else:
# self.log(f"[{current_datetime}] 策略: 无上一根K线未成交订单需要撤销。")
# 2. 更新K线历史
trade_volume = self.trade_volume
# 获取当前持仓和未决订单(在取消之后获取,确保是最新的状态)
current_positions = self.get_current_positions()
current_pos_volume = current_positions.get(self.symbol, 0)
pending_orders_after_cancel = (
self.get_pending_orders()
) # 再次获取,此时应已取消旧订单
# --- 3. 平仓逻辑 (止损/止盈) ---
# 只有当有持仓时才考虑平仓
self.log(current_positions, self.symbol)
if current_pos_volume < 0: # 假设只做多,所以持仓量 > 0
avg_entry_price = self.get_average_position_price(self.symbol)
avg_entry_price = self.last_sell_price
if avg_entry_price is not None:
pnl_per_unit = (
avg_entry_price - bar.open
) # 当前浮动盈亏(以收盘价计算)
# 止盈条件
if pnl_per_unit >= self.take_profit_points:
self.log(
f"[{current_datetime}] 止盈信号 - PnL per unit: {pnl_per_unit:.2f}, 目标: {self.take_profit_points:.2f}"
)
order_id = f"{self.symbol}_BUY_{bar.datetime.strftime('%Y%m%d%H%M%S')}_{self.order_id_counter}"
self.order_id_counter += 1
# 创建一个限价多单
order = Order(
id=order_id,
symbol=self.symbol,
direction="CLOSE_SHORT",
volume=trade_volume,
price_type="MARKET",
# limit_price=limit_price,
submitted_time=bar.datetime,
offset="CLOSE",
)
trade = self.send_order(order)
return # 平仓后本K线不再进行开仓判断
# 止损条件
elif pnl_per_unit <= -self.stop_loss_points:
self.log(
f"[{current_datetime}] 止损信号 - PnL per unit: {pnl_per_unit:.2f}, 目标: {-self.stop_loss_points:.2f}"
)
# 发送市价卖出订单平仓,确保立即成交
order_id = f"{self.symbol}_BUY_{bar.datetime.strftime('%Y%m%d%H%M%S')}_{self.order_id_counter}"
self.order_id_counter += 1
# 创建一个限价多单
order = Order(
id=order_id,
symbol=self.symbol,
direction="CLOSE_SHORT",
volume=trade_volume,
price_type="MARKET",
# limit_price=limit_price,
submitted_time=bar.datetime,
offset="CLOSE",
)
trade = self.send_order(order)
return # 平仓后本K线不再进行开仓判断
if current_pos_volume > 0: # 假设只做多,所以持仓量 > 0
avg_entry_price = self.get_average_position_price(self.symbol)
avg_entry_price = self.last_buy_price
if avg_entry_price is not None:
pnl_per_unit = (
bar.open - avg_entry_price
) # 当前浮动盈亏(以收盘价计算)
# 止盈条件
if pnl_per_unit >= self.take_profit_points:
self.log(
f"[{current_datetime}] 止盈信号 - PnL per unit: {pnl_per_unit:.2f}, 目标: {self.take_profit_points:.2f}"
)
order_id = f"{self.symbol}_BUY_{bar.datetime.strftime('%Y%m%d%H%M%S')}_{self.order_id_counter}"
self.order_id_counter += 1
# 创建一个限价多单
order = Order(
id=order_id,
symbol=self.symbol,
direction="CLOSE_LONG",
volume=trade_volume,
price_type="MARKET",
# limit_price=limit_price,
submitted_time=bar.datetime,
offset="CLOSE",
)
trade = self.send_order(order)
return # 平仓后本K线不再进行开仓判断
# 止损条件
elif pnl_per_unit <= -self.stop_loss_points:
self.log(
f"[{current_datetime}] 止损信号 - PnL per unit: {pnl_per_unit:.2f}, 目标: {-self.stop_loss_points:.2f}"
)
# 发送市价卖出订单平仓,确保立即成交
order_id = f"{self.symbol}_BUY_{bar.datetime.strftime('%Y%m%d%H%M%S')}_{self.order_id_counter}"
self.order_id_counter += 1
# 创建一个限价多单
order = Order(
id=order_id,
symbol=self.symbol,
direction="CLOSE_LONG",
volume=trade_volume,
price_type="MARKET",
# limit_price=limit_price,
submitted_time=bar.datetime,
offset="CLOSE",
)
trade = self.send_order(order)
return # 平仓后本K线不再进行开仓判断
bar_history = self.get_bar_history()
if current_pos_volume == 0 and len(bar_history) > 10:
# 获取前1根K线 (倒数第二根) 和前7根K线 (队列中最老的一根)
bar_1_ago = bar_history[-2]
bar_7_ago = bar_history[-8]
print(bar_1_ago, bar_7_ago)
# 计算历史 K 线的 Range
range_1_ago = bar_1_ago.high - bar_1_ago.low
range_7_ago = bar_7_ago.high - bar_7_ago.low
# 根据策略逻辑计算目标买入价格
# 目标买入价 = 当前K线Open - (前1根Range * 因子1 + 前7根Range * 因子2)
target_buy_price = bar.open + (
range_1_ago * self.open_range_factor_1_short
+ range_7_ago * self.open_range_factor_7_short
)
# 确保目标买入价格有效,例如不能是负数
target_buy_price = max(0.01, target_buy_price)
order_id = f"{self.symbol}_BUY_{bar.datetime.strftime('%Y%m%d%H%M%S')}_{self.order_id_counter}"
self.order_id_counter += 1
# 创建一个限价多单
order = Order(
id=order_id,
symbol=bar.symbol,
direction="SELL",
volume=trade_volume,
price_type="LIMIT",
limit_price=target_buy_price,
submitted_time=bar.datetime,
)
self.last_sell_price = target_buy_price
new_order = self.send_order(order)
# 记录下这个订单的ID以便在下一根K线开始时进行撤销
if new_order:
self._last_order_id = new_order.id
self.log(
f"[{current_datetime}] 策略: 发送限价SELL订单 {self._last_order_id} @ {target_buy_price:.2f}"
)
else:
self.log(f"[{current_datetime}] 策略: 发送订单失败。")
target_buy_price = bar.open - (
range_1_ago * self.open_range_factor_1_long
+ range_7_ago * self.open_range_factor_7_long
)
# 确保目标买入价格有效,例如不能是负数
target_buy_price = max(0.01, target_buy_price)
order_id = f"{self.symbol}_BUY_{bar.datetime.strftime('%Y%m%d%H%M%S')}_{self.order_id_counter}"
self.order_id_counter += 1
# 创建一个限价多单
order = Order(
id=order_id,
symbol=bar.symbol,
direction="BUY",
volume=trade_volume,
price_type="LIMIT",
limit_price=target_buy_price,
submitted_time=bar.datetime,
)
new_order = self.send_order(order)
self.last_buy_price = target_buy_price
# 记录下这个订单的ID以便在下一根K线开始时进行撤销
if new_order:
self._last_order_id = new_order.id
self.log(
f"[{current_datetime}] 策略: 发送限价BUY订单 {self._last_order_id} @ {target_buy_price:.2f}"
)
else:
self.log(f"[{current_datetime}] 策略: 发送订单失败。")
# else:
# self.log(f"[{current_datetime}] 不满足开仓条件:持仓={current_pos_volume}, 待处理订单={len(pending_orders_after_cancel)}, K线历史长度={len(bar_history)}")
def on_close_bar(self, bar):
self.log('on close bar!')
self.log(self.get_pending_orders())
self.cancel_all_pending_orders()
def on_rollover(self, old_symbol: str, new_symbol: str):
"""
在合约换月时清空历史K线数据和上次订单ID避免使用旧合约数据进行计算。
"""
super().on_rollover(old_symbol, new_symbol) # 调用基类方法打印日志
self._last_order_id = None # 清空上次订单ID因为旧合约订单已取消
self.log(f"换月完成清空历史K线数据和上次订单ID准备新合约交易。")

View File

@@ -2,6 +2,7 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from datetime import datetime from datetime import datetime
import math
from typing import Dict, Any, Optional, List, TYPE_CHECKING from typing import Dict, Any, Optional, List, TYPE_CHECKING
# 使用 TYPE_CHECKING 避免循环导入,但保留类型提示 # 使用 TYPE_CHECKING 避免循环导入,但保留类型提示
@@ -52,7 +53,7 @@ class Strategy(ABC):
pass # 默认不执行任何操作,具体策略可覆盖 pass # 默认不执行任何操作,具体策略可覆盖
@abstractmethod @abstractmethod
def on_bar(self, bar: "Bar"): def on_open_bar(self, bar: "Bar"):
""" """
每当新的K线数据到来时调用此方法。 每当新的K线数据到来时调用此方法。
Args: Args:
@@ -61,6 +62,15 @@ class Strategy(ABC):
""" """
pass pass
def on_close_bar(self, bar: "Bar"):
"""
每当新的K线数据到来时调用此方法。
Args:
bar (Bar): 当前的K线数据对象。
next_bar_close (Optional[float]): 下一根K线的开盘价如果存在的话。
"""
pass
# --- 新增/修改的辅助方法 --- # --- 新增/修改的辅助方法 ---
def send_order(self, order: "Order") -> Optional[Order]: def send_order(self, order: "Order") -> Optional[Order]:
@@ -71,6 +81,21 @@ class Strategy(ABC):
if self.context.is_rollover_bar: if self.context.is_rollover_bar:
self.log(f"当前是换月K线禁止开仓订单") self.log(f"当前是换月K线禁止开仓订单")
return None return None
if order.price_type == 'LIMIT':
limit_price = order.limit_price
if order.direction in ["BUY", "CLOSE_SHORT"]:
# 买入限价单(或平空),希望以更低或相等的价格成交,
# 所以向下取整,确保挂单价格不高于预期。
# 例如价格100.3tick_size=1 -> math.floor(100.3) = 100
# 价格100.8tick_size=1 -> math.floor(100.8) = 100
order.limit_price = math.floor(limit_price)
elif order.direction in ["SELL", "CLOSE_LONG"]:
# 卖出限价单(或平多),希望以更高或相等的价格成交,
# 所以向上取整,确保挂单价格不低于预期。
# 例如价格100.3tick_size=1 -> math.ceil(100.3) = 101
# 价格100.8tick_size=1 -> math.ceil(100.8) = 101
order.limit_price = math.ceil(limit_price)
return self.context.send_order(order) return self.context.send_order(order)
def cancel_order(self, order_id: str) -> bool: def cancel_order(self, order_id: str) -> bool:
@@ -96,23 +121,23 @@ class Strategy(ABC):
def get_current_positions(self) -> Dict[str, int]: def get_current_positions(self) -> Dict[str, int]:
"""获取所有当前持仓 (可能包含多个合约)。""" """获取所有当前持仓 (可能包含多个合约)。"""
return self.context._simulator.get_current_positions() return self.context.get_current_positions()
def get_pending_orders(self) -> Dict[str, "Order"]: def get_pending_orders(self) -> Dict[str, "Order"]:
"""获取所有当前待处理订单的副本 (可能包含多个合约)。""" """获取所有当前待处理订单的副本 (可能包含多个合约)。"""
return self.context._simulator.get_pending_orders() return self.context.get_pending_orders()
def get_average_position_price(self, symbol: str) -> Optional[float]: def get_average_position_price(self, symbol: str) -> Optional[float]:
"""获取指定合约的平均持仓成本。""" """获取指定合约的平均持仓成本。"""
return self.context._simulator.get_average_position_price(symbol) return self.context.get_average_position_price(symbol)
def get_account_cash(self) -> float: def get_account_cash(self) -> float:
"""获取当前账户现金余额。""" """获取当前账户现金余额。"""
return self.context._simulator.cash return self.context.cash
def get_current_time(self) -> datetime: def get_current_time(self) -> datetime:
"""获取模拟器当前时间。""" """获取模拟器当前时间。"""
return self.context._simulator.get_current_time() return self.context.get_current_time()
def log(self, *args: Any, **kwargs: Any): def log(self, *args: Any, **kwargs: Any):
""" """
@@ -123,7 +148,7 @@ class Strategy(ABC):
if self.enable_log: if self.enable_log:
# 尝试获取当前模拟时间,如果模拟器或时间不可用,则跳过时间前缀 # 尝试获取当前模拟时间,如果模拟器或时间不可用,则跳过时间前缀
try: try:
current_time_str = self.context._simulator.get_current_time().strftime( current_time_str = self.context.get_current_time().strftime(
"%Y-%m-%d %H:%M:%S" "%Y-%m-%d %H:%M:%S"
) )
time_prefix = f"[{current_time_str}] " time_prefix = f"[{current_time_str}] "
@@ -151,3 +176,6 @@ class Strategy(ABC):
self.log(f"合约换月事件: 从 {old_symbol} 切换到 {new_symbol}") self.log(f"合约换月事件: 从 {old_symbol} 切换到 {new_symbol}")
# 默认实现可以为空,子类根据需要重写 # 默认实现可以为空,子类根据需要重写
pass pass
def get_bar_history(self):
return self.context.get_bar_history()

View File

@@ -9,7 +9,7 @@ from .base_strategy import Strategy
from ..core_data import Bar, Order, Trade from ..core_data import Bar, Order, Trade
class SimpleLimitBuyStrategy(Strategy): class TestStrategy(Strategy):
""" """
一个简单的限价买入策略: 一个简单的限价买入策略:
在每根Bar线上如果当前没有持仓且没有待处理的买入订单则尝试下一个限价多单。 在每根Bar线上如果当前没有持仓且没有待处理的买入订单则尝试下一个限价多单。

202
src/tqsdk_context.py Normal file
View File

@@ -0,0 +1,202 @@
# filename: tqsdk_context.py
from datetime import datetime
from typing import Optional, Any, Dict, List, Literal, Deque, TYPE_CHECKING
from collections import deque
# 导入你提供的 core_data 中的类型
from src.core_data import Bar, Order, Trade, PortfolioSnapshot # 确保此路径正确如果core_data不在同级目录需要调整
# 导入 Tqsdk 的核心类型
import tqsdk
from tqsdk import TqApi, TqAccount, tafunc
import pandas as pd
# 使用 TYPE_CHECKING 避免循环导入,只在类型检查时导入 TqsdkEngine
if TYPE_CHECKING:
from src.tqsdk_engine import TqsdkEngine # 假设 TqsdkEngine 在 tqsdk_engine.py 中
class TqsdkContext:
"""
Tqsdk 回测上下文,适配原有 BacktestContext 接口。
策略通过此上下文与 Tqsdk 进行交互。
"""
def __init__(self, api: TqApi):
"""
初始化 Tqsdk 回测上下文。
Args:
api (TqApi): Tqsdk 的 TqApi 实例。
"""
self._api = api
self._current_bar: Optional[Bar] = None
self._engine: Optional['TqsdkEngine'] = None # 添加对引擎的引用,用于访问其状态或触发事件
# 用于缓存 Tqsdk 的 K 线序列,避免每次都 get_kline_serial
self._kline_serial: Dict[str, object] = {}
# 订单/取消请求队列TqsdkEngine 会在异步循环中处理它们
self.order_queue: Deque[Order] = deque()
self.cancel_queue: Deque[str] = deque() # 存储 order_id
print("TqsdkContext: 初始化完成。")
def set_current_bar(self, bar: Bar):
"""
设置当前正在处理的 K 线数据。
由 TqsdkEngine 调用。
"""
self._current_bar = bar
def get_current_bar(self) -> Optional[Bar]:
"""
获取当前正在处理的 K 线数据。
策略可以通过此方法获取最新 K 线。
"""
return self._current_bar
def get_kline_data(self, symbol: str, duration_seconds: int, data_length: int = 10):
"""
获取指定合约的 K 线数据。
返回 Tqsdk 的 DataFrame 格式 K 线序列TqKLine 对象),可以直接用于计算指标。
如果需要转换为你自己的 Bar 对象列表,则需要在此方法内部进行转换。
"""
if symbol not in self._kline_serial:
# 这里的 get_kline_serial 并不是实时获取,而是在 TqApi 启动时就已经加载
# 所以在 Context 中直接调用是安全的TqApi 会返回已加载的数据引用
self._kline_serial[symbol] = self._api.get_kline_serial(symbol, duration_seconds, data_length=data_length)
return self._kline_serial[symbol]
def get_current_time(self) -> datetime:
"""
获取当前模拟时间Tqsdk 的数据时间)。
"""
# Tqsdk 的 get_tick_timestamp() 返回微秒时间戳
return self.get_current_bar().datetime
def get_current_positions(self) -> Dict[str, int]:
"""
获取当前所有持仓。返回 {symbol: quantity} 的字典quantity 为净持仓量(多头-空头)。
"""
tq_positions: Dict[str] = self._api.get_position()
converted_positions: Dict[str, int] = {}
for symbol, pos in tq_positions.items():
net_pos = pos.pos_long - pos.pos_short
if net_pos != 0:
converted_positions[symbol] = net_pos
return converted_positions
def get_pending_orders(self) -> Dict[str, Order]:
"""
获取当前所有待处理(未成交)订单。
返回 {order_id: Order} 的字典。
"""
tq_orders: Dict[str] = self._api.get_order()
pending_orders: Dict[str, Order] = {}
for order_id, tq_order in tq_orders.items():
if tq_order.status == "ALIVE": # 正在进行中的订单
# 将 TqOrder 转换为你自己的 core_data.Order 类型
# 注意core_data.Order 的 direction 有 "CLOSE_LONG", "CLOSE_SHORT" 等,需要映射
# Tqsdk 的 direction 只有 "BUY", "SELL"
# Tqsdk 的 offset 决定了是开仓还是平仓
core_direction: Literal["BUY", "SELL", "CLOSE_LONG", "CLOSE_SHORT"]
if tq_order.offset == "OPEN":
core_direction = tq_order.direction # 开仓时方向直接对应买卖
elif tq_order.offset in ["CLOSE", "CLOSETODAY", "CLOSEYESTERDAY"]:
# 平仓时,买入平空,卖出平多
core_direction = "CLOSE_SHORT" if tq_order.direction == "BUY" else "CLOSE_LONG"
else: # 默认为 BUY/SELL
core_direction = tq_order.direction
converted_order = Order(
id=tq_order.order_id, # 将 Tqsdk 的 order_id 赋值给你的 Order 类的 id
symbol=tq_order.exchange_id + "." + tq_order.instrument_id, # 例如 "SHFE.rb2401"
direction=core_direction,
volume=tq_order.volume_orign,
price_type="LIMIT" if tq_order.limit_price is not None else "MARKET", # Tqsdk 市价单类型为 "ANY"
limit_price=tq_order.limit_price,
offset=tq_order.offset, # Tqsdk 原生 offset
# order_id=tq_order.order_id, # 存储 Tqsdk 的 order_id
submitted_time=pd.to_datetime(tq_order.insert_date_time, unit="ns", utc=True),
# status=tq_order.status # 保持 Tqsdk 的状态字符串
)
pending_orders[order_id] = converted_order
return pending_orders
def get_account_cash(self) -> float:
"""
获取当前可用现金。
"""
account: TqAccount = self._api.get_account()
return account.available_cash if account else 0.0
def get_average_position_price(self, symbol: str) -> Optional[float]:
"""
获取指定合约的平均持仓成本。
注意: Tqsdk 的 TqPosition 对象中包含了 open_price_long 和 open_price_short。
这里需要根据多头或空头持仓返回对应的平均成本。
"""
position = self._api.get_position(symbol)
# if position:
# return avg_cost
if position:
if position.pos_long > 0:
return position.open_price_long
elif position.pos_short > 0:
return position.open_price_short
return None
def send_order(self, order: Order) -> Optional[Order]:
"""
策略通过此方法发送订单。
将订单放入队列,等待 TqsdkEngine 在其异步循环中处理。
"""
# 为订单分配一个临时ID便于在队列中追踪实际ID由Tqsdk返回后更新
if not order.id: # 使用 order.id 属性
order.id = f"LOCAL_{id(order)}_{datetime.now().strftime('%f')}"
order.order_id = order.id # 保持 Tqsdk 风格的 order_id 也一致
self.order_queue.append(order)
print(f"Context: 订单已加入队列: {order}")
return order # 返回传入的订单待引擎更新其状态和ID
def cancel_order(self, order_id: str) -> bool:
"""
策略通过此方法取消指定ID的订单。
将取消请求放入队列,等待 TqsdkEngine 在其异步循环中处理。
"""
# 检查订单是否处于待处理状态
if order_id in self.get_pending_orders():
self.cancel_queue.append(order_id)
print(f"Context: 取消订单请求已加入队列: {order_id}")
return True
print(f"Context: 订单 {order_id} 不在待处理队列中,无法取消。")
return False
def set_engine(self, engine: 'TqsdkEngine'): # 使用 TYPE_CHECKING 中的 TqsdkEngine 类型提示
"""
设置对 TqsdkEngine 实例的引用。
由 TqsdkEngine 在初始化时调用,用于允许 Context 访问 Engine 的状态。
"""
self._engine = engine
print("TqsdkContext: 已设置引擎引用。")
@property
def is_rollover_bar(self) -> bool:
"""
属性:判断当前 K 线是否为换月 K 线(即新合约的第一根 K 线)。
用于在换月时禁止策略开仓。
Tqsdk 的回测模式下,通常通过主力连续合约或多合约同时回测来处理换月。
此处为适配原有接口的简化实现。如果你需要 Tqsdk 的换月逻辑,
可能需要在 TqsdkEngine 中实现更复杂的判断,并通过 Context 暴露此状态。
对于 Tqsdk 的主力连续合约,通常不需要策略层面关心具体的换月 K 线。
"""
# 如果引擎设置了 is_rollover_bar 属性,则使用引擎的判断
if self._engine and hasattr(self._engine, 'is_rollover_bar'):
return self._engine.is_rollover_bar
return False # 默认返回 False
def get_bar_history(self):
return self._engine.get_bar_history()

491
src/tqsdk_engine.py Normal file
View File

@@ -0,0 +1,491 @@
# filename: tqsdk_engine.py
import asyncio
from datetime import date, datetime, timedelta
from typing import Literal, Type, Dict, Any, List, Optional
import pandas as pd
import uuid
# 导入你提供的 core_data 中的类型
from src.core_data import Bar, Order, Trade, PortfolioSnapshot
# 导入 Tqsdk 的核心类型
import tqsdk
from tqsdk import (
TqApi,
TqAccount,
tafunc,
TqSim,
TqBacktest,
TqAuth,
TargetPosTask,
BacktestFinished,
)
# 导入 TqsdkContext 和 BaseStrategy
from src.tqsdk_context import TqsdkContext
from src.strategies.base_strategy import Strategy # 假设你的策略基类在此路径
BEIJING_TZ = "Asia/Shanghai"
class TqsdkEngine:
"""
Tqsdk 回测引擎:协调 Tqsdk 数据流、策略执行、订单模拟和结果记录。
替代原有的 BacktestEngine。
"""
def __init__(
self,
strategy_class: Type[Strategy],
strategy_params: Dict[str, Any],
api: TqApi,
roll_over_mode: bool = False, # 是否开启换月模式检测
symbol: str = None,
start_time: Optional[datetime] = None,
end_time: Optional[datetime] = None,
duration_seconds: int = 1,
):
"""
初始化 Tqsdk 回测引擎。
Args:
strategy_class (Type[Strategy]): 策略类。
strategy_params (Dict[str, Any]): 传递给策略的参数字典。
data_path (str): 本地 K 线数据文件路径,用于 TqSim 加载。
initial_capital (float): 初始资金。
slippage_rate (float): 交易滑点率(在 Tqsdk 中通常需要手动实现或通过费用设置)。
commission_rate (float): 交易佣金率(在 Tqsdk 中通常需要手动实现或通过费用设置)。
roll_over_mode (bool): 是否启用换月检测。
start_time (Optional[datetime]): 回测开始时间。
end_time (Optional[datetime]): 回测结束时间。
"""
self.strategy_class = strategy_class
self.strategy_params = strategy_params
self.roll_over_mode = roll_over_mode
self.start_time = start_time
self.end_time = end_time
# Tqsdk API 和模拟器
# 这里使用 file_path 参数指定本地数据文件
self._api: TqApi = api
# 从策略参数中获取主symbolTqsdkContext 需要知道它
self.symbol: str = strategy_params.get("symbol")
if not self.symbol:
raise ValueError("strategy_params 必须包含 'symbol' 字段")
# 获取 K 线数据Tqsdk 自动处理)
# 这里假设策略所需 K 线周期在 strategy_params 中否则默认60秒1分钟K线
self.bar_duration_seconds: int = strategy_params.get("bar_duration_seconds", 60)
# self._main_kline_serial = self._api.get_kline_serial(
# self.symbol, self.bar_duration_seconds
# )
# 初始化上下文
self._context: TqsdkContext = TqsdkContext(api=self._api)
# 实例化策略,并将上下文传递给它
self._strategy: Strategy = self.strategy_class(
context=self._context, **self.strategy_params
)
self._context.set_engine(
self
) # 将引擎自身传递给上下文,以便 Context 可以访问引擎属性
self.portfolio_snapshots: List[PortfolioSnapshot] = []
self.trade_history: List[Trade] = []
self.all_bars: List[Bar] = [] # 收集所有处理过的Bar
self.last_processed_bar: Optional[Bar] = None
self._is_rollover_bar: bool = False # 换月信号
self._last_underlying_symbol = self.symbol # 用于检测主力合约换月
self.klines = api.get_kline_serial(symbol, duration_seconds)
self.klines_1min = api.get_kline_serial(symbol, 60)
self.now = None
self.quote = None
if roll_over_mode:
self.quote = api.get_quote(symbol)
print("TqsdkEngine: 初始化完成。")
@property
def is_rollover_bar(self) -> bool:
"""
属性:判断当前 K 线是否为换月 K 线(即检测到主力合约切换)。
"""
return self._is_rollover_bar
def _process_queued_requests(self):
"""
异步处理 Context 中排队的订单和取消请求。
"""
# 处理订单
while self._context.order_queue:
order_to_send: Order = self._context.order_queue.popleft()
print(f"Engine: 处理订单请求: {order_to_send}")
# 映射 core_data.Order 到 Tqsdk 的订单参数
tqsdk_direction = ""
tqsdk_offset = ""
if order_to_send.direction == "BUY":
tqsdk_direction = "BUY"
tqsdk_offset = order_to_send.offset or "OPEN" # 默认开仓
elif order_to_send.direction == "SELL":
tqsdk_direction = "SELL"
tqsdk_offset = order_to_send.offset or "OPEN" # 默认开仓
elif order_to_send.direction == "CLOSE_LONG":
tqsdk_direction = "SELL"
tqsdk_offset = order_to_send.offset or "CLOSE" # 平多,默认平仓
elif order_to_send.direction == "CLOSE_SHORT":
tqsdk_direction = "BUY"
tqsdk_offset = order_to_send.offset or "CLOSE" # 平空,默认平仓
else:
print(f"Engine: 未知订单方向: {order_to_send.direction}")
continue # 跳过此订单
if "SHFE" in order_to_send.symbol:
tqsdk_offset = "OPEN"
try:
tq_order = self._api.insert_order(
symbol=order_to_send.symbol,
direction=tqsdk_direction,
offset=tqsdk_offset,
volume=order_to_send.volume,
# Tqsdk 市价单 limit_price 设为 None限价单则传递价格
limit_price=(
order_to_send.limit_price
if order_to_send.price_type == "LIMIT"
# else self.quote.bid_price1 + (1 if tqsdk_direction == "BUY" else -1)
else self.quote.bid_price1 if tqsdk_direction == "SELL" else self.quote.ask_price1
),
)
# 更新原始 Order 对象与 Tqsdk 的订单ID和状态
order_to_send.id = tq_order.order_id
# order_to_send.order_id = tq_order.order_id
# order_to_send.status = tq_order.status
order_to_send.submitted_time = pd.to_datetime(
tq_order.insert_date_time, unit="ns", utc=True
)
# 等待订单状态更新(成交/撤销/报错)
# 在 Tqsdk 中,订单和成交是独立的,通常在 wait_update() 循环中通过 api.is_changing() 检查
# 这里为了模拟同步处理,直接等待订单状态最终确定
# 注意:实际回测中,不应在这里长时间阻塞,而应在主循环中持续 wait_update
# 为了简化适配,这里模拟即时处理,但可能与真实异步行为有差异。
# 更健壮的方式是在主循环中通过订单状态回调更新
# 这里我们假设订单会很快更新状态,或者在下一个 wait_update() 周期中被检测到
self._api.wait_update() # 等待一次更新
# # 检查最终订单状态和成交
# if tq_order.status == "FINISHED":
# # 查找对应的成交记录
# for trade_id, tq_trade in self._api.get_trade().items():
# if tq_trade.order_id == tq_order.order_id and tq_trade.volume > 0: # 确保是实际成交
# # 创建 core_data.Trade 对象
# trade = Trade(
# order_id=tq_trade.order_id,
# fill_time=tafunc.get_datetime_from_timestamp(tq_trade.trade_date_time) if tq_trade.trade_date_time else datetime.now(),
# symbol=order_to_send.symbol, # 使用 Context 中的 symbol
# direction=tq_trade.direction, # 实际成交方向
# volume=tq_trade.volume,
# price=tq_trade.price,
# commission=tq_trade.commission,
# cash_after_trade=self._api.get_account().available,
# positions_after_trade=self._context.get_current_positions(),
# realized_pnl=tq_trade.realized_pnl, # Tqsdk TqTrade 对象有 realized_pnl
# is_open_trade=tq_trade.offset == "OPEN",
# is_close_trade=tq_trade.offset in ["CLOSE", "CLOSETODAY", "CLOSEYESTERDAY"]
# )
# self.trade_history.append(trade)
# print(f"Engine: 成交记录: {trade}")
# break # 找到成交就跳出
# order_to_send.status = tq_order.status # 更新最终状态
except Exception as e:
print(f"Engine: 发送订单 {order_to_send.id} 失败: {e}")
# order_to_send.status = "ERROR"
# 处理取消请求
while self._context.cancel_queue:
order_id_to_cancel = self._context.cancel_queue.popleft()
print(f"Engine: 处理取消请求: {order_id_to_cancel}")
tq_order_to_cancel = self._api.get_order(order_id_to_cancel)
if tq_order_to_cancel and tq_order_to_cancel.status == "ALIVE":
try:
self._api.cancel_order(tq_order_to_cancel)
self._api.wait_update() # 等待取消确认
print(
f"Engine: 订单 {order_id_to_cancel} 已尝试取消。当前状态: {tq_order_to_cancel.status}"
)
except Exception as e:
print(f"Engine: 取消订单 {order_id_to_cancel} 失败: {e}")
else:
print(
f"Engine: 订单 {order_id_to_cancel} 不存在或已非活动状态,无法取消。"
)
def _record_portfolio_snapshot(self, current_time: datetime):
"""
记录当前投资组合的快照。
"""
account: TqAccount = self._api.get_account()
current_positions = self._context.get_current_positions()
# 计算当前持仓市值
total_market_value = 0.0
current_prices: Dict[str, float] = {}
for symbol, qty in current_positions.items():
# 获取当前合约的最新价格
quote = self._api.get_quote(symbol)
if quote.last_price: # 确保价格是最近的
price = quote.last_price
current_prices[symbol] = price
total_market_value += (
price * qty * quote.volume_multiple
) # volume_multiple 乘数
else:
# 如果没有最新价格使用最近的K线收盘价作为估算
# 在实盘或连续回测中,通常会有最新的行情
print(f"警告: 未获取到 {symbol} 最新价格,可能影响净值计算。")
# 可以尝试从 K 线获取最近价格
kline = self._api.get_kline_serial(symbol, self.bar_duration_seconds)
if not kline.empty:
last_kline = kline.iloc[-2]
price = last_kline.close
current_prices[symbol] = price
total_market_value += (
price * qty * self._api.get_instrument(symbol).volume_multiple
) # 使用 instrument 的乘数
total_value = (
account.available + account.frozen_margin + total_market_value
) # Tqsdk 的 balance 已包含持仓市值和冻结资金
# Tqsdk 的 total_profit/balance 已经包含了所有盈亏和资金
snapshot = PortfolioSnapshot(
datetime=current_time,
total_value=account.balance, # Tqsdk 的 balance 包含了可用资金、冻结保证金和持仓市值
cash=account.available,
positions=current_positions,
price_at_snapshot=current_prices,
)
self.portfolio_snapshots.append(snapshot)
def _close_all_positions_at_end(self):
"""
回测结束时,平掉所有剩余持仓。
"""
current_positions = self._context.get_current_positions()
if not current_positions:
print("回测结束:没有需要平仓的持仓。")
return
print("回测结束:开始平仓所有剩余持仓...")
for symbol, qty in current_positions.items():
order_direction: Literal["BUY", "SELL"]
if qty > 0: # 多头持仓,卖出平仓
order_direction = "SELL"
else: # 空头持仓,买入平仓
order_direction = "BUY"
TargetPosTask(self._api, symbol).set_target_volume(0)
# # 使用市价单快速平仓
# tq_order = self._api.insert_order(
# symbol=symbol,
# direction=order_direction,
# offset="CLOSE", # 平仓
# volume=abs(qty),
# limit_price=self
# )
# print(f"平仓订单已发送: {symbol} {order_direction} {abs(qty)} 手")
# 等待订单完成
# while tq_order.status == "ALIVE":
# self._api.wait_update()
# if tq_order.status == "FINISHED":
# print(f"订单 {tq_order.order_id} 平仓完成。")
# else:
# print(f"订单 {tq_order.order_id} 平仓失败或未完成,状态: {tq_order.status}")
def _run_backtest_async(self):
"""
异步运行回测的主循环。
"""
print(f"TqsdkEngine: 开始运行回测,从 {self.start_time}{self.end_time}")
# 初始化策略 (如果策略有 on_init 方法)
if hasattr(self._strategy, "on_init"):
self._strategy.on_init()
last_bar_datetime = None
# 迭代 K 线数据
# 使用 self._api.get_kline_serial 获取到的 K 线是 Pandas DataFrame
# 直接迭代其行Bar更符合回测逻辑
try:
while True:
# Tqsdk API 的 wait_update() 确保数据更新
self._api.wait_update()
if self.roll_over_mode and (
self._api.is_changing(self.quote, "underlying_symbol")
or self._last_underlying_symbol != self.quote.underlying_symbol
):
self._last_underlying_symbol = self.quote.underlying_symbol
if self._api.is_changing(self.klines_1min):
now_kline = self.klines_1min.iloc[-1]
now_dt = pd.to_datetime(now_kline.datetime, unit="ns", utc=True)
now_dt = now_dt.tz_convert(BEIJING_TZ)
if (
self.now is not None
and self.now.hour != 13
and now_dt.hour == 13
):
self.main()
self.now = now_dt
if self._api.is_changing(self.klines):
kline_row = self.klines.iloc[-1]
kline_dt = pd.to_datetime(kline_row.datetime, unit="ns", utc=True)
kline_dt = kline_dt.tz_convert(BEIJING_TZ)
if kline_dt.hour == 13 and self.now.hour == 11:
continue
else:
self.main()
except BacktestFinished:
# 回测结束时,确保所有排队请求得到处理
self._process_queued_requests()
# 回测结束后,如果需要,平掉所有剩余持仓
self._close_all_positions_at_end()
print("TqsdkEngine: 回测运行完毕。")
def main(self):
kline_row = self.klines.iloc[-1]
kline_dt = pd.to_datetime(kline_row.datetime, unit="ns", utc=True)
kline_dt = kline_dt.tz_convert(BEIJING_TZ)
if (
self.last_processed_bar is None
or self.last_processed_bar.datetime != kline_dt
):
# 创建 core_data.Bar 对象
current_bar = Bar(
datetime=kline_dt,
symbol=self._last_underlying_symbol,
open=kline_row.open,
high=kline_row.high,
low=kline_row.low,
close=kline_row.close,
volume=kline_row.volume,
open_oi=kline_row.open_oi,
close_oi=kline_row.close_oi,
)
# 设置当前 Bar 到 Context
self._context.set_current_bar(current_bar)
# Tqsdk 的 is_changing 用于判断数据是否有变化,对于回测遍历 K 线,每次迭代都算作新 Bar
# 如果 kline_row.datetime 与上次不同,则认为是新 Bar
if (
self.roll_over_mode
and self.last_processed_bar is not None
and self._last_underlying_symbol != self.last_processed_bar.symbol
):
self._is_rollover_bar = True
print(
f"TqsdkEngine: 检测到换月信号!从 {self._last_underlying_symbol} 切换到 {self.quote.underlying_symbol}"
)
self._close_all_positions_at_end()
self._strategy.cancel_all_pending_orders()
self._strategy.on_rollover(
self.last_processed_bar.symbol, self._last_underlying_symbol
)
else:
self._is_rollover_bar = False
self.all_bars.append(current_bar)
self.last_processed_bar = current_bar
# 调用策略的 on_bar 方法
self._strategy.on_open_bar(current_bar)
# 处理订单和取消请求
self._process_queued_requests()
# 记录投资组合快照
self._record_portfolio_snapshot(current_bar.datetime)
else:
# 创建 core_data.Bar 对象
current_bar = Bar(
datetime=kline_dt,
symbol=self._last_underlying_symbol,
open=kline_row.open,
high=kline_row.high,
low=kline_row.low,
close=kline_row.close,
volume=kline_row.volume,
open_oi=kline_row.open_oi,
close_oi=kline_row.close_oi,
)
self.all_bars[-1] = current_bar
self.last_processed_bar = current_bar
# 设置当前 Bar 到 Context
self._context.set_current_bar(current_bar)
# 调用策略的 on_bar 方法
self._strategy.on_close_bar(current_bar)
# 处理订单和取消请求
self._process_queued_requests()
def run_backtest(self):
"""
同步调用异步回测主循环。
"""
try:
self._run_backtest_async()
except KeyboardInterrupt:
print("\n回测被用户中断。")
finally:
self._api.close()
print("TqsdkEngine: API 已关闭。")
def get_backtest_results(self) -> Dict[str, Any]:
"""
返回回测结果数据,供结果分析模块使用。
"""
final_portfolio_value = 0.0
if self.portfolio_snapshots:
final_portfolio_value = self.portfolio_snapshots[-1].total_value
# else:
# final_portfolio_value = self.initial_capital # 如果没有快照,则净值是初始资金
# total_return_percentage = (
# (final_portfolio_value - self.initial_capital) / self.initial_capital
# ) * 100 if self.initial_capital != 0 else 0.0
return {
"portfolio_snapshots": self.portfolio_snapshots,
"trade_history": self.trade_history,
# "initial_capital": self.initial_capital,
"all_bars": self.all_bars,
"final_portfolio_value": final_portfolio_value,
# "total_return_percentage": total_return_percentage,
}
def get_bar_history(self):
return self.all_bars

0
test Normal file
View File

11330
tqsdk_main.ipynb Normal file

File diff suppressed because one or more lines are too long