同步本地回测与tqsdk回测
This commit is contained in:
@@ -1,18 +1,19 @@
|
|||||||
from tqsdk import TqApi, TqAuth
|
from datetime import date
|
||||||
|
from tqsdk import TqApi, TqAuth, TqBacktest, TargetPosTask
|
||||||
|
|
||||||
api = TqApi(auth=TqAuth("emanresu", "dfgvfgdfgg"))
|
'''
|
||||||
|
如果当前价格大于5分钟K线的MA15则开多仓
|
||||||
|
如果小于则平仓
|
||||||
|
回测从 2018-05-01 到 2018-10-01
|
||||||
|
'''
|
||||||
|
# 在创建 api 实例时传入 TqBacktest 就会进入回测模式
|
||||||
|
api = TqApi(backtest=TqBacktest(start_dt=date(2018, 5, 1), end_dt=date(2018, 10, 1)), auth=TqAuth("emanresu", "dfgvfgdfgg"))
|
||||||
|
# 获得 m1901 5分钟K线的引用
|
||||||
|
klines = api.get_kline_serial("DCE.m1901", 60 * 60, data_length=15)
|
||||||
|
# 创建 m1901 的目标持仓 task,该 task 负责调整 m1901 的仓位到指定的目标仓位
|
||||||
|
target_pos = TargetPosTask(api, "DCE.m1901")
|
||||||
|
|
||||||
# au 品种指数合约
|
while True:
|
||||||
ls = api.query_quotes(ins_class="INDEX", product_id="au")
|
api.wait_update()
|
||||||
print(ls)
|
if api.is_changing(klines) and len(klines) > 2:
|
||||||
|
target_pos.set_target_volume(5)
|
||||||
# au 品种主连合约
|
|
||||||
ls = api.query_quotes(ins_class="CONT", product_id="au")
|
|
||||||
print(ls)
|
|
||||||
|
|
||||||
quote = api.get_quote("KQ.m@SHFE.rb")
|
|
||||||
# 打印现在螺纹钢主连的标的合约
|
|
||||||
print(quote.underlying_symbol)
|
|
||||||
|
|
||||||
# 关闭api,释放相应资源
|
|
||||||
api.close()
|
|
||||||
@@ -109,6 +109,7 @@ def collect_and_save_tqsdk_data_stream(
|
|||||||
)
|
)
|
||||||
|
|
||||||
last_kline_datetime = None # 用于跟踪上一根已完成K线的时间
|
last_kline_datetime = None # 用于跟踪上一根已完成K线的时间
|
||||||
|
swap_month_dt = None
|
||||||
|
|
||||||
while api.wait_update():
|
while api.wait_update():
|
||||||
if underlying_symbol is None:
|
if underlying_symbol is None:
|
||||||
@@ -117,7 +118,11 @@ def collect_and_save_tqsdk_data_stream(
|
|||||||
# 检查是否有新的完整K线生成,或者当前K线是最后一次更新 (在回测结束时)
|
# 检查是否有新的完整K线生成,或者当前K线是最后一次更新 (在回测结束时)
|
||||||
# TqSdk会在K线完成时发送最后一次更新,或者在回测结束时强制更新
|
# TqSdk会在K线完成时发送最后一次更新,或者在回测结束时强制更新
|
||||||
if api.is_changing(quote, "underlying_symbol"):
|
if api.is_changing(quote, "underlying_symbol"):
|
||||||
underlying_symbol = quote.underlying_symbol
|
swap_month_dt = pd.to_datetime(
|
||||||
|
quote.datetime, unit="ns", utc=True
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if api.is_changing(klines):
|
if api.is_changing(klines):
|
||||||
# 只有当K线序列发生变化时才处理
|
# 只有当K线序列发生变化时才处理
|
||||||
# 关注最新一根 K 线(即 klines.iloc[-1])
|
# 关注最新一根 K 线(即 klines.iloc[-1])
|
||||||
@@ -141,10 +146,15 @@ def collect_and_save_tqsdk_data_stream(
|
|||||||
|
|
||||||
kline_dt = pd.to_datetime(
|
kline_dt = pd.to_datetime(
|
||||||
current_kline["datetime"], unit="ns", utc=True
|
current_kline["datetime"], unit="ns", utc=True
|
||||||
)
|
).tz_convert(BEIJING_TZ)
|
||||||
kline_dt = kline_dt.tz_convert(BEIJING_TZ).strftime(
|
|
||||||
|
if swap_month_dt is not None and kline_dt.hour == swap_month_dt.hour:
|
||||||
|
underlying_symbol = quote.underlying_symbol
|
||||||
|
|
||||||
|
kline_dt = kline_dt.strftime(
|
||||||
"%Y-%m-%d %H:%M:%S"
|
"%Y-%m-%d %H:%M:%S"
|
||||||
)
|
)
|
||||||
|
|
||||||
kline_data_to_save = {
|
kline_data_to_save = {
|
||||||
"datetime": kline_dt,
|
"datetime": kline_dt,
|
||||||
"open": current_kline["open"],
|
"open": current_kline["open"],
|
||||||
|
|||||||
25621
grid_search.ipynb
25621
grid_search.ipynb
File diff suppressed because one or more lines are too long
13757
main.ipynb
13757
main.ipynb
File diff suppressed because one or more lines are too long
@@ -124,27 +124,30 @@ def calculate_metrics(
|
|||||||
"亏损交易次数": losing_count,
|
"亏损交易次数": losing_count,
|
||||||
"平均每次盈利": avg_profit_per_trade,
|
"平均每次盈利": avg_profit_per_trade,
|
||||||
"平均每次亏损": avg_loss_per_trade, # 这个值是负数
|
"平均每次亏损": avg_loss_per_trade, # 这个值是负数
|
||||||
"InitialCapital": initial_capital,
|
"initial_capital": initial_capital,
|
||||||
"FinalCapital": final_value,
|
"final_capital": final_value,
|
||||||
"TotalReturn": total_return,
|
"total_return": total_return,
|
||||||
"AnnualizedReturn": annualized_return,
|
"annualized_return": annualized_return,
|
||||||
"MaxDrawdown": max_drawdown,
|
"max_drawdown": max_drawdown,
|
||||||
"SharpeRatio": sharpe_ratio,
|
"sharpe_ratio": sharpe_ratio,
|
||||||
"CalmarRatio": calmar_ratio,
|
"calmar_ratio": calmar_ratio,
|
||||||
"TotalTrades": len(trades), # All buy and sell trades
|
"total_trades": len(trades), # All buy and sell trades
|
||||||
"TransactionCosts": total_commissions,
|
"transaction_costs": total_commissions,
|
||||||
"TotalRealizedPNL": total_realized_pnl, # New
|
"total_realized_pnl": total_realized_pnl, # New
|
||||||
"WinRate": win_rate,
|
"win_rate": win_rate,
|
||||||
"ProfitLossRatio": profit_loss_ratio,
|
"profit_loss_ratio": profit_loss_ratio,
|
||||||
"WinningTradesCount": winning_count,
|
"winning_trades_count": winning_count,
|
||||||
"LosingTradesCount": losing_count,
|
"losing_trades_count": losing_count,
|
||||||
"AvgProfitPerTrade": avg_profit_per_trade,
|
"avg_profit_per_trade": avg_profit_per_trade,
|
||||||
"AvgLossPerTrade": avg_loss_per_trade, # This value is negative
|
"avg_loss_per_trade": avg_loss_per_trade, # This value is negative
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def plot_equity_and_drawdown_chart(snapshots: List[PortfolioSnapshot], initial_capital: float,
|
def plot_equity_and_drawdown_chart(
|
||||||
title: str = "Portfolio Equity and Drawdown Curve") -> None:
|
snapshots: List[PortfolioSnapshot],
|
||||||
|
initial_capital: float,
|
||||||
|
title: str = "Portfolio Equity and Drawdown Curve",
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Plots the portfolio equity curve and drawdown. X-axis points are equally spaced.
|
Plots the portfolio equity curve and drawdown. X-axis points are equally spaced.
|
||||||
|
|
||||||
@@ -157,35 +160,45 @@ def plot_equity_and_drawdown_chart(snapshots: List[PortfolioSnapshot], initial_c
|
|||||||
print("No portfolio snapshots available to plot equity and drawdown.")
|
print("No portfolio snapshots available to plot equity and drawdown.")
|
||||||
return
|
return
|
||||||
|
|
||||||
df_equity = pd.DataFrame([
|
df_equity = pd.DataFrame(
|
||||||
{'datetime': s.datetime, 'total_value': s.total_value}
|
[{"datetime": s.datetime, "total_value": s.total_value} for s in snapshots]
|
||||||
for s in snapshots
|
)
|
||||||
])
|
|
||||||
|
|
||||||
equity_curve = df_equity['total_value'] / initial_capital
|
equity_curve = df_equity["total_value"] / initial_capital
|
||||||
|
|
||||||
rolling_max = equity_curve.cummax()
|
rolling_max = equity_curve.cummax()
|
||||||
drawdown = (rolling_max - equity_curve) / rolling_max
|
drawdown = (rolling_max - equity_curve) / rolling_max
|
||||||
|
|
||||||
plt.style.use('seaborn-v0_8-darkgrid')
|
plt.style.use("seaborn-v0_8-darkgrid")
|
||||||
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(14, 10), sharex=True, gridspec_kw={'height_ratios': [3, 1]})
|
fig, (ax1, ax2) = plt.subplots(
|
||||||
|
2, 1, figsize=(14, 10), sharex=True, gridspec_kw={"height_ratios": [3, 1]}
|
||||||
|
)
|
||||||
|
|
||||||
x_axis_indices = np.arange(len(df_equity))
|
x_axis_indices = np.arange(len(df_equity))
|
||||||
|
|
||||||
# Equity Curve Plot
|
# Equity Curve Plot
|
||||||
ax1.plot(x_axis_indices, equity_curve, label='Equity Curve', color='blue', linewidth=1.5)
|
ax1.plot(
|
||||||
ax1.set_ylabel('Equity', fontsize=12)
|
x_axis_indices, equity_curve, label="Equity Curve", color="blue", linewidth=1.5
|
||||||
ax1.legend(loc='upper left')
|
)
|
||||||
|
ax1.set_ylabel("Equity", fontsize=12)
|
||||||
|
ax1.legend(loc="upper left")
|
||||||
ax1.grid(True)
|
ax1.grid(True)
|
||||||
ax1.set_title(title, fontsize=16)
|
ax1.set_title(title, fontsize=16)
|
||||||
|
|
||||||
# Drawdown Curve Plot
|
# Drawdown Curve Plot
|
||||||
ax2.fill_between(x_axis_indices, 0, drawdown, color='red', alpha=0.3)
|
ax2.fill_between(x_axis_indices, 0, drawdown, color="red", alpha=0.3)
|
||||||
ax2.plot(x_axis_indices, drawdown, color='red', linewidth=1.0, linestyle='--', label='Drawdown')
|
ax2.plot(
|
||||||
ax2.set_ylabel('Drawdown Rate', fontsize=12)
|
x_axis_indices,
|
||||||
ax2.set_xlabel('Data Point Index (Date Labels Below)', fontsize=12)
|
drawdown,
|
||||||
ax2.set_title('Portfolio Drawdown Curve', fontsize=14)
|
color="red",
|
||||||
ax2.legend(loc='upper left')
|
linewidth=1.0,
|
||||||
|
linestyle="--",
|
||||||
|
label="Drawdown",
|
||||||
|
)
|
||||||
|
ax2.set_ylabel("Drawdown Rate", fontsize=12)
|
||||||
|
ax2.set_xlabel("Data Point Index (Date Labels Below)", fontsize=12)
|
||||||
|
ax2.set_title("Portfolio Drawdown Curve", fontsize=14)
|
||||||
|
ax2.legend(loc="upper left")
|
||||||
ax2.grid(True)
|
ax2.grid(True)
|
||||||
ax2.set_ylim(0, max(drawdown.max() * 1.1, 0.05))
|
ax2.set_ylim(0, max(drawdown.max() * 1.1, 0.05))
|
||||||
|
|
||||||
@@ -193,9 +206,12 @@ def plot_equity_and_drawdown_chart(snapshots: List[PortfolioSnapshot], initial_c
|
|||||||
num_ticks = 10
|
num_ticks = 10
|
||||||
if len(df_equity) > 0:
|
if len(df_equity) > 0:
|
||||||
tick_positions = np.linspace(0, len(df_equity) - 1, num_ticks, dtype=int)
|
tick_positions = np.linspace(0, len(df_equity) - 1, num_ticks, dtype=int)
|
||||||
tick_labels = [df_equity['datetime'].iloc[i].strftime('%Y-%m-%d %H:%M') for i in tick_positions]
|
tick_labels = [
|
||||||
|
df_equity["datetime"].iloc[i].strftime("%Y-%m-%d %H:%M")
|
||||||
|
for i in tick_positions
|
||||||
|
]
|
||||||
ax1.set_xticks(tick_positions)
|
ax1.set_xticks(tick_positions)
|
||||||
ax1.set_xticklabels(tick_labels, rotation=45, ha='right')
|
ax1.set_xticklabels(tick_labels, rotation=45, ha="right")
|
||||||
|
|
||||||
plt.tight_layout()
|
plt.tight_layout()
|
||||||
plt.show()
|
plt.show()
|
||||||
@@ -213,30 +229,38 @@ def plot_close_price_chart(bars: List[Bar], title: str = "Close Price Chart") ->
|
|||||||
print("No bar data available to plot close price.")
|
print("No bar data available to plot close price.")
|
||||||
return
|
return
|
||||||
|
|
||||||
df_prices = pd.DataFrame([
|
df_prices = pd.DataFrame(
|
||||||
{'datetime': b.datetime, 'close_price': b.close}
|
[{"datetime": b.datetime, "close_price": b.close} for b in bars]
|
||||||
for b in bars
|
)
|
||||||
])
|
|
||||||
|
|
||||||
plt.style.use('seaborn-v0_8-darkgrid')
|
plt.style.use("seaborn-v0_8-darkgrid")
|
||||||
fig, ax = plt.subplots(1, 1, figsize=(14, 7)) # Single subplot
|
fig, ax = plt.subplots(1, 1, figsize=(14, 7)) # Single subplot
|
||||||
|
|
||||||
x_axis_indices = np.arange(len(df_prices))
|
x_axis_indices = np.arange(len(df_prices))
|
||||||
|
|
||||||
ax.plot(x_axis_indices, df_prices['close_price'], label='Close Price', color='orange', linewidth=1.5)
|
ax.plot(
|
||||||
ax.set_ylabel('Price', fontsize=12)
|
x_axis_indices,
|
||||||
ax.set_xlabel('Data Point Index (Date Labels Below)', fontsize=12)
|
df_prices["close_price"],
|
||||||
|
label="Close Price",
|
||||||
|
color="orange",
|
||||||
|
linewidth=1.5,
|
||||||
|
)
|
||||||
|
ax.set_ylabel("Price", fontsize=12)
|
||||||
|
ax.set_xlabel("Data Point Index (Date Labels Below)", fontsize=12)
|
||||||
ax.set_title(title, fontsize=16)
|
ax.set_title(title, fontsize=16)
|
||||||
ax.legend(loc='upper left')
|
ax.legend(loc="upper left")
|
||||||
ax.grid(True)
|
ax.grid(True)
|
||||||
|
|
||||||
# Set X-axis ticks to show actual dates at intervals
|
# Set X-axis ticks to show actual dates at intervals
|
||||||
num_ticks = 10
|
num_ticks = 10
|
||||||
if len(df_prices) > 0:
|
if len(df_prices) > 0:
|
||||||
tick_positions = np.linspace(0, len(df_prices) - 1, num_ticks, dtype=int)
|
tick_positions = np.linspace(0, len(df_prices) - 1, num_ticks, dtype=int)
|
||||||
tick_labels = [df_prices['datetime'].iloc[i].strftime('%Y-%m-%d %H:%M') for i in tick_positions]
|
tick_labels = [
|
||||||
|
df_prices["datetime"].iloc[i].strftime("%Y-%m-%d %H:%M")
|
||||||
|
for i in tick_positions
|
||||||
|
]
|
||||||
ax.set_xticks(tick_positions)
|
ax.set_xticks(tick_positions)
|
||||||
ax.set_xticklabels(tick_labels, rotation=45, ha='right')
|
ax.set_xticklabels(tick_labels, rotation=45, ha="right")
|
||||||
|
|
||||||
plt.tight_layout()
|
plt.tight_layout()
|
||||||
plt.show()
|
plt.show()
|
||||||
|
|||||||
@@ -65,6 +65,7 @@ class GridSearchAnalyzer:
|
|||||||
y_idx = y_values.index(item[self.param2_name])
|
y_idx = y_values.index(item[self.param2_name])
|
||||||
heatmap_matrix[y_idx, x_idx] = item[self.optimization_metric]
|
heatmap_matrix[y_idx, x_idx] = item[self.optimization_metric]
|
||||||
|
|
||||||
|
print([x_values[0], x_values[-1], y_values[0], y_values[-1]])
|
||||||
fig, ax = plt.subplots(figsize=(10, 8))
|
fig, ax = plt.subplots(figsize=(10, 8))
|
||||||
im = ax.imshow(heatmap_matrix, cmap='viridis', origin='lower',
|
im = ax.imshow(heatmap_matrix, cmap='viridis', origin='lower',
|
||||||
extent=[x_values[0], x_values[-1], y_values[0], y_values[-1]],
|
extent=[x_values[0], x_values[-1], y_values[0], y_values[-1]],
|
||||||
|
|||||||
@@ -59,6 +59,18 @@ class ResultAnalyzer:
|
|||||||
"""
|
"""
|
||||||
生成并打印详细的回测报告。
|
生成并打印详细的回测报告。
|
||||||
"""
|
"""
|
||||||
|
if self.trade_history:
|
||||||
|
print("\n--- 交易明细 ---")
|
||||||
|
for trade in self.trade_history:
|
||||||
|
# 调整输出格式,显示实现盈亏
|
||||||
|
pnl_display = f" | PnL: {trade.realized_pnl:.2f}" if trade.is_close_trade else ""
|
||||||
|
print(
|
||||||
|
f" {trade.fill_time} | {trade.direction:<10} | {trade.symbol} | Vol: {trade.volume} | Price: {trade.price:.2f} | Comm: {trade.commission:.2f}{pnl_display}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
print("\n没有交易记录。")
|
||||||
|
|
||||||
|
|
||||||
metrics = self.calculate_all_metrics()
|
metrics = self.calculate_all_metrics()
|
||||||
|
|
||||||
print("\n--- 回测绩效报告 ---")
|
print("\n--- 回测绩效报告 ---")
|
||||||
@@ -83,17 +95,6 @@ class ResultAnalyzer:
|
|||||||
print(f"{'平均每次亏损':<15}: {metrics['平均每次亏损']:.2f}")
|
print(f"{'平均每次亏损':<15}: {metrics['平均每次亏损']:.2f}")
|
||||||
|
|
||||||
|
|
||||||
if self.trade_history:
|
|
||||||
print("\n--- 部分交易明细 (最近5笔) ---")
|
|
||||||
for trade in self.trade_history[-5:]:
|
|
||||||
# 调整输出格式,显示实现盈亏
|
|
||||||
pnl_display = f" | PnL: {trade.realized_pnl:.2f}" if trade.is_close_trade else ""
|
|
||||||
print(
|
|
||||||
f" {trade.fill_time} | {trade.direction:<10} | {trade.symbol} | Vol: {trade.volume} | Price: {trade.price:.2f} | Comm: {trade.commission:.2f}{pnl_display}"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
print("\n没有交易记录。")
|
|
||||||
|
|
||||||
def plot_performance(self) -> None:
|
def plot_performance(self) -> None:
|
||||||
"""
|
"""
|
||||||
绘制投资组合净值和回撤曲线,以及所有合约的收盘价曲线。
|
绘制投资组合净值和回撤曲线,以及所有合约的收盘价曲线。
|
||||||
|
|||||||
@@ -91,6 +91,9 @@ class BacktestContext:
|
|||||||
"""
|
"""
|
||||||
self._engine = engine
|
self._engine = engine
|
||||||
|
|
||||||
|
def get_bar_history(self):
|
||||||
|
return self._engine.get_bar_history()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_rollover_bar(self) -> bool:
|
def is_rollover_bar(self) -> bool:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -89,9 +89,12 @@ class BacktestEngine:
|
|||||||
# 主回测循环
|
# 主回测循环
|
||||||
while True:
|
while True:
|
||||||
current_bar = self.data_manager.get_next_bar()
|
current_bar = self.data_manager.get_next_bar()
|
||||||
|
|
||||||
if current_bar is None:
|
if current_bar is None:
|
||||||
break # 没有更多数据,回测结束
|
break # 没有更多数据,回测结束
|
||||||
|
|
||||||
|
self.all_bars.append(current_bar)
|
||||||
|
|
||||||
if self.start_time and current_bar.datetime < self.start_time:
|
if self.start_time and current_bar.datetime < self.start_time:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@@ -104,11 +107,11 @@ class BacktestEngine:
|
|||||||
# 1. 重置 is_rollover_bar 标记
|
# 1. 重置 is_rollover_bar 标记
|
||||||
self.is_rollover_bar = False
|
self.is_rollover_bar = False
|
||||||
|
|
||||||
|
# 4. 更新 Context 和 Simulator 的当前 Bar 和时间
|
||||||
|
self.context.set_current_bar(current_bar)
|
||||||
|
self.simulator.update_time(current_time=current_bar.datetime)
|
||||||
|
|
||||||
# 2. 如果启用换月模式,并且检测到合约 symbol 变化
|
# 2. 如果启用换月模式,并且检测到合约 symbol 变化
|
||||||
if current_bar.symbol != self._last_processed_bar_symbol:
|
|
||||||
print(self.roll_over_mode,
|
|
||||||
self._last_processed_bar_symbol,
|
|
||||||
current_bar.symbol, self._last_processed_bar_symbol)
|
|
||||||
if self.roll_over_mode and \
|
if self.roll_over_mode and \
|
||||||
self._last_processed_bar_symbol is not None and \
|
self._last_processed_bar_symbol is not None and \
|
||||||
current_bar.symbol != self._last_processed_bar_symbol:
|
current_bar.symbol != self._last_processed_bar_symbol:
|
||||||
@@ -141,20 +144,22 @@ class BacktestEngine:
|
|||||||
# 3. 更新策略关注的当前合约 symbol
|
# 3. 更新策略关注的当前合约 symbol
|
||||||
self.strategy.symbol = current_bar.symbol
|
self.strategy.symbol = current_bar.symbol
|
||||||
|
|
||||||
# 4. 更新 Context 和 Simulator 的当前 Bar 和时间
|
|
||||||
self.context.set_current_bar(current_bar)
|
|
||||||
self.simulator.update_time(current_time=current_bar.datetime)
|
|
||||||
|
|
||||||
# 5. 更新引擎内部的历史 Bar 缓存
|
# 5. 更新引擎内部的历史 Bar 缓存
|
||||||
self._history_bars.append(current_bar)
|
self._history_bars.append(current_bar)
|
||||||
if len(self._history_bars) > self._max_history_bars:
|
if len(self._history_bars) > self._max_history_bars:
|
||||||
self._history_bars.pop(0)
|
self._history_bars.pop(0)
|
||||||
|
|
||||||
# 6. 处理待撮合订单 (在调用策略 on_bar 之前,确保订单在当前 K 线开盘价撮合)
|
# 6. 处理待撮合订单 (在调用策略 on_bar 之前,确保订单在当前 K 线开盘价撮合)
|
||||||
self.simulator.process_pending_orders(current_bar)
|
# self.simulator.process_pending_orders(current_bar)
|
||||||
|
self.strategy.on_open_bar(current_bar)
|
||||||
|
|
||||||
# 7. 调用策略的 on_bar 方法
|
# 7. 调用策略的 on_bar 方法
|
||||||
self.strategy.on_bar(current_bar)
|
# self.strategy.on_bar(current_bar)
|
||||||
|
self.simulator.process_pending_orders(current_bar)
|
||||||
|
|
||||||
|
self.strategy.on_close_bar(current_bar)
|
||||||
|
self.simulator.process_pending_orders(current_bar)
|
||||||
|
|
||||||
|
|
||||||
# 8. 记录投资组合快照
|
# 8. 记录投资组合快照
|
||||||
current_portfolio_value = self.simulator.get_portfolio_value(current_bar)
|
current_portfolio_value = self.simulator.get_portfolio_value(current_bar)
|
||||||
@@ -170,7 +175,6 @@ class BacktestEngine:
|
|||||||
price_at_snapshot=price_at_snapshot
|
price_at_snapshot=price_at_snapshot
|
||||||
)
|
)
|
||||||
self.portfolio_snapshots.append(snapshot)
|
self.portfolio_snapshots.append(snapshot)
|
||||||
self.all_bars.append(current_bar)
|
|
||||||
|
|
||||||
# 9. 更新 `_last_processed_bar_symbol` 和 `last_processed_bar` 为当前 Bar,为下一轮循环做准备
|
# 9. 更新 `_last_processed_bar_symbol` 和 `last_processed_bar` 为当前 Bar,为下一轮循环做准备
|
||||||
self._last_processed_bar_symbol = current_bar.symbol
|
self._last_processed_bar_symbol = current_bar.symbol
|
||||||
@@ -222,3 +226,7 @@ class BacktestEngine:
|
|||||||
|
|
||||||
def get_simulator(self) -> ExecutionSimulator:
|
def get_simulator(self) -> ExecutionSimulator:
|
||||||
return self.simulator
|
return self.simulator
|
||||||
|
|
||||||
|
def get_bar_history(self):
|
||||||
|
return self.all_bars
|
||||||
|
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ from typing import Dict, Any, List, Optional
|
|||||||
import uuid # 用于生成唯一订单ID
|
import uuid # 用于生成唯一订单ID
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True) # frozen=True 使实例变为不可变
|
@dataclass() # 使实例变为不可变
|
||||||
class Bar:
|
class Bar:
|
||||||
"""
|
"""
|
||||||
K线数据对象,包含期货或股票的 OHLCV 和持仓量信息。
|
K线数据对象,包含期货或股票的 OHLCV 和持仓量信息。
|
||||||
@@ -25,11 +25,11 @@ class Bar:
|
|||||||
"""
|
"""
|
||||||
数据验证(可选):确保持仓量为非负整数。
|
数据验证(可选):确保持仓量为非负整数。
|
||||||
"""
|
"""
|
||||||
if not isinstance(self.volume, int) or self.volume < 0:
|
if self.volume < 0:
|
||||||
raise ValueError(f"Volume must be a non-negative integer, got {self.volume}")
|
raise ValueError(f"Volume must be a non-negative integer, got {self.volume}")
|
||||||
if not isinstance(self.open_oi, int) or self.open_oi < 0:
|
if self.open_oi < 0:
|
||||||
raise ValueError(f"Open interest must be a non-negative integer, got {self.open_oi}")
|
raise ValueError(f"Open interest must be a non-negative integer, got {self.open_oi}")
|
||||||
if not isinstance(self.close_oi, int) or self.close_oi < 0:
|
if self.close_oi < 0:
|
||||||
raise ValueError(f"Close interest must be a non-negative integer, got {self.close_oi}")
|
raise ValueError(f"Close interest must be a non-negative integer, got {self.close_oi}")
|
||||||
|
|
||||||
# 验证价格是否合理
|
# 验证价格是否合理
|
||||||
@@ -41,7 +41,7 @@ class Bar:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass()
|
||||||
class Order:
|
class Order:
|
||||||
"""
|
"""
|
||||||
代表一个待执行的交易指令。
|
代表一个待执行的交易指令。
|
||||||
@@ -53,6 +53,7 @@ class Order:
|
|||||||
price_type: str = "MARKET" # "MARKET", "LIMIT" (简易版默认市价)
|
price_type: str = "MARKET" # "MARKET", "LIMIT" (简易版默认市价)
|
||||||
limit_price: Optional[float] = None # 限价单价格
|
limit_price: Optional[float] = None # 限价单价格
|
||||||
submitted_time: pd.Timestamp = field(default_factory=pd.Timestamp.now) # 订单提交时间
|
submitted_time: pd.Timestamp = field(default_factory=pd.Timestamp.now) # 订单提交时间
|
||||||
|
offset: str = "OPEN"
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
if self.direction not in ["BUY", "SELL", "CLOSE_LONG", "CLOSE_SHORT", "CANCEL"]:
|
if self.direction not in ["BUY", "SELL", "CLOSE_LONG", "CLOSE_SHORT", "CANCEL"]:
|
||||||
@@ -63,7 +64,7 @@ class Order:
|
|||||||
raise ValueError(f"Order volume must be a positive integer, got {self.volume}")
|
raise ValueError(f"Order volume must be a positive integer, got {self.volume}")
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass()
|
||||||
class Trade:
|
class Trade:
|
||||||
"""
|
"""
|
||||||
代表一个已完成的成交记录。
|
代表一个已完成的成交记录。
|
||||||
@@ -83,7 +84,7 @@ class Trade:
|
|||||||
is_close_trade: bool = False # <--- 新增字段:是否是平仓交易 (用于计算盈亏)
|
is_close_trade: bool = False # <--- 新增字段:是否是平仓交易 (用于计算盈亏)
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass()
|
||||||
class PortfolioSnapshot:
|
class PortfolioSnapshot:
|
||||||
"""
|
"""
|
||||||
在特定时间点记录投资组合的快照。
|
在特定时间点记录投资组合的快照。
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
# src/data_manager.py (修改并添加 get_history_bars 方法)
|
# src/data_manager.py (修改并添加 get_history_bars 方法)
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from typing import Iterator, List, Dict, Any, Optional
|
from typing import Iterator, List, Dict, Any, Optional
|
||||||
import os
|
import os
|
||||||
@@ -15,10 +16,16 @@ class DataManager: # DataManager 现在是一个类,以便维护内部索引
|
|||||||
并提供获取历史Bar的能力。
|
并提供获取历史Bar的能力。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, file_path: str, symbol: str, tz='Asia/Shanghai'):
|
def __init__(self, file_path: str, symbol: str, tz='Asia/Shanghai', start_time: Optional[datetime] = None, end_time: Optional[datetime] = None):
|
||||||
self.file_path = file_path
|
self.file_path = file_path
|
||||||
self.tz = tz
|
self.tz = tz
|
||||||
self.raw_df = load_raw_data(self.file_path) # 调用函数式加载数据
|
self.raw_df = load_raw_data(self.file_path) # 调用函数式加载数据
|
||||||
|
|
||||||
|
self.start_time = start_time
|
||||||
|
self.end_time = end_time
|
||||||
|
if start_time is not None and end_time is not None:
|
||||||
|
self.raw_df = self.raw_df[(self.raw_df.index >= start_time) & (self.raw_df.index <= end_time)]
|
||||||
|
|
||||||
# self.bars = list(df_to_bar_stream(self.raw_df)) # 一次性转换所有Bar,方便历史数据查找
|
# self.bars = list(df_to_bar_stream(self.raw_df)) # 一次性转换所有Bar,方便历史数据查找
|
||||||
# 优化:使用内部迭代器和缓存,避免一次性生成所有Bar,但可以按需提供历史
|
# 优化:使用内部迭代器和缓存,避免一次性生成所有Bar,但可以按需提供历史
|
||||||
self._bar_generator = df_to_bar_stream(self.raw_df, symbol)
|
self._bar_generator = df_to_bar_stream(self.raw_df, symbol)
|
||||||
@@ -26,6 +33,8 @@ class DataManager: # DataManager 现在是一个类,以便维护内部索引
|
|||||||
self._last_bar_index = -1 # 用于跟踪当前的Bar在原始df中的索引
|
self._last_bar_index = -1 # 用于跟踪当前的Bar在原始df中的索引
|
||||||
self.symbol = symbol
|
self.symbol = symbol
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def get_next_bar(self) -> Optional[Bar]:
|
def get_next_bar(self) -> Optional[Bar]:
|
||||||
"""
|
"""
|
||||||
按顺序返回下一根 Bar 对象。
|
按顺序返回下一根 Bar 对象。
|
||||||
|
|||||||
@@ -43,6 +43,7 @@ def load_raw_data(file_path: str) -> pd.DataFrame:
|
|||||||
missing_cols = [col for col in expected_cols[1:] if col not in df.columns]
|
missing_cols = [col for col in expected_cols[1:] if col not in df.columns]
|
||||||
if missing_cols:
|
if missing_cols:
|
||||||
print(f"CSV文件中缺少以下列: {', '.join(missing_cols)}")
|
print(f"CSV文件中缺少以下列: {', '.join(missing_cols)}")
|
||||||
|
expected_cols = [col for col in expected_cols if col in df.columns]
|
||||||
|
|
||||||
# 确保数据按时间排序 (这是回测的基础)
|
# 确保数据按时间排序 (这是回测的基础)
|
||||||
df = df.sort_index()
|
df = df.sort_index()
|
||||||
@@ -51,7 +52,7 @@ def load_raw_data(file_path: str) -> pd.DataFrame:
|
|||||||
print(f"数据范围从 {df.index.min()} 到 {df.index.max()}")
|
print(f"数据范围从 {df.index.min()} 到 {df.index.max()}")
|
||||||
print(f"总计 {len(df)} 条记录。")
|
print(f"总计 {len(df)} 条记录。")
|
||||||
|
|
||||||
return df[expected_cols[1:]] # 返回包含核心数据的DataFrame
|
return df[expected_cols] # 返回包含核心数据的DataFrame
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"加载数据时发生错误: {e}")
|
print(f"加载数据时发生错误: {e}")
|
||||||
raise
|
raise
|
||||||
@@ -68,7 +69,6 @@ def df_to_bar_stream(df: pd.DataFrame, symbol: str) -> Iterator[Bar]:
|
|||||||
Bar: 逐个生成的 Bar 对象。
|
Bar: 逐个生成的 Bar 对象。
|
||||||
"""
|
"""
|
||||||
print("开始将 DataFrame 转换为 Bar 对象流...")
|
print("开始将 DataFrame 转换为 Bar 对象流...")
|
||||||
print(df)
|
|
||||||
for index, row in df.iterrows():
|
for index, row in df.iterrows():
|
||||||
try:
|
try:
|
||||||
if 'underlying_symbol' in df.columns and row['underlying_symbol'] != '':
|
if 'underlying_symbol' in df.columns and row['underlying_symbol'] != '':
|
||||||
|
|||||||
@@ -11,17 +11,31 @@ class ExecutionSimulator:
|
|||||||
模拟交易执行和管理账户资金、持仓。
|
模拟交易执行和管理账户资金、持仓。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, initial_capital: float,
|
def __init__(
|
||||||
|
self,
|
||||||
|
initial_capital: float,
|
||||||
slippage_rate: float = 0.0001,
|
slippage_rate: float = 0.0001,
|
||||||
commission_rate: float = 0.0002,
|
commission_rate: float = 0.0002,
|
||||||
initial_positions: Optional[Dict[str, int]] = None):
|
initial_positions: Optional[Dict[str, int]] = None,
|
||||||
|
initial_average_costs: Optional[Dict[str, float]] = None,
|
||||||
|
): # 新增参数
|
||||||
self.initial_capital = initial_capital
|
self.initial_capital = initial_capital
|
||||||
self.cash = initial_capital
|
self.cash = initial_capital
|
||||||
self.positions: Dict[str, int] = initial_positions if initial_positions is not None else {}
|
self.positions: Dict[str, int] = (
|
||||||
self.average_costs: Dict[str, float] = {}
|
initial_positions if initial_positions is not None else {}
|
||||||
if initial_positions:
|
)
|
||||||
|
# 修正:初始平均成本应该从参数传入,而不是默认0.0
|
||||||
|
self.average_costs: Dict[str, float] = (
|
||||||
|
initial_average_costs if initial_average_costs is not None else {}
|
||||||
|
)
|
||||||
|
# 如果提供了 initial_positions 但没有提供 initial_average_costs,可以警告或默认处理
|
||||||
|
if initial_positions and not initial_average_costs:
|
||||||
|
print(
|
||||||
|
f"[{datetime.now()}] 警告: 提供了初始持仓但未提供初始平均成本,这些持仓的成本默认为0.0。"
|
||||||
|
)
|
||||||
for symbol, qty in initial_positions.items():
|
for symbol, qty in initial_positions.items():
|
||||||
self.average_costs[symbol] = 0.0
|
if symbol not in self.average_costs:
|
||||||
|
self.average_costs[symbol] = 0.0 # 如果没有提供,默认给0
|
||||||
|
|
||||||
self.slippage_rate = slippage_rate
|
self.slippage_rate = slippage_rate
|
||||||
self.commission_rate = commission_rate
|
self.commission_rate = commission_rate
|
||||||
@@ -30,94 +44,94 @@ class ExecutionSimulator:
|
|||||||
self._current_time: Optional[datetime] = None
|
self._current_time: Optional[datetime] = None
|
||||||
|
|
||||||
print(
|
print(
|
||||||
f"模拟器初始化:初始资金={self.initial_capital:.2f}, 滑点率={self.slippage_rate}, 佣金率={self.commission_rate}")
|
f"模拟器初始化:初始资金={self.initial_capital:.2f}, 滑点率={self.slippage_rate}, 佣金率={self.commission_rate}"
|
||||||
|
)
|
||||||
if self.positions:
|
if self.positions:
|
||||||
print(f"初始持仓:{self.positions}")
|
print(f"初始持仓:{self.positions}")
|
||||||
|
print(f"初始平均成本:{self.average_costs}") # 打印初始成本以便检查
|
||||||
|
|
||||||
def update_time(self, current_time: datetime):
|
def update_time(self, current_time: datetime):
|
||||||
self._current_time = current_time
|
self._current_time = current_time
|
||||||
|
|
||||||
def get_current_time(self) -> datetime:
|
def get_current_time(self) -> datetime:
|
||||||
if self._current_time is None:
|
if self._current_time is None:
|
||||||
# 改进:如果时间未设置,可以抛出错误,防止策略在 on_init 阶段意外调用
|
|
||||||
# raise RuntimeError("Simulator time has not been set. Ensure update_time is called.")
|
|
||||||
return None
|
return None
|
||||||
return self._current_time
|
return self._current_time
|
||||||
|
|
||||||
def _calculate_fill_price(self, order: Order, current_bar: Bar) -> float:
|
def _calculate_fill_price(self, order: Order, current_bar: Bar) -> float:
|
||||||
"""
|
"""
|
||||||
内部方法:根据订单类型和滑点计算实际成交价格。
|
内部方法:根据订单类型和滑点计算实际成交价格。
|
||||||
撮合逻辑:所有订单(市价/限价)都以当前K线的 **开盘价 (open)** 为基准进行撮合。
|
撮合逻辑:
|
||||||
|
- 市价单:以当前K线的 **开盘价 (open)** 为基准进行撮合,并考虑滑点。
|
||||||
|
- 限价单:判断 K 线的 **最高价 (high)** 和 **最低价 (low)** 是否触及限价。如果触及,则以 **限价 (limit_price)** 为基准计算成交价,并考虑滑点。
|
||||||
"""
|
"""
|
||||||
fill_price = -1.0 # 默认未成交
|
fill_price = -1.0 # 默认未成交
|
||||||
|
|
||||||
base_price = current_bar.open # 所有成交都以当前K线的开盘价为基准
|
# 对于市价单,仍然使用开盘价作为基准检查点
|
||||||
|
base_price_for_market_order = current_bar.open
|
||||||
|
|
||||||
if order.price_type == "MARKET":
|
if order.price_type == "MARKET":
|
||||||
# 市价单:直接以开盘价成交,考虑滑点
|
# 市价单:直接以开盘价成交,考虑滑点
|
||||||
if order.direction == "BUY" or order.direction == "CLOSE_SHORT": # 买入/平空:向上偏离(多付)
|
if (
|
||||||
fill_price = base_price * (1 + self.slippage_rate)
|
order.direction == "BUY" or order.direction == "CLOSE_SHORT"
|
||||||
elif order.direction == "SELL" or order.direction == "CLOSE_LONG": # 卖出/平多:向下偏离(少收)
|
): # 买入/平空:向上偏离(多付)
|
||||||
fill_price = base_price * (1 - self.slippage_rate)
|
fill_price = (base_price_for_market_order + 1) * (1 + self.slippage_rate)
|
||||||
|
elif (
|
||||||
|
order.direction == "SELL" or order.direction == "CLOSE_LONG"
|
||||||
|
): # 卖出/平多:向下偏离(少收)
|
||||||
|
fill_price = (base_price_for_market_order - 1) * (1 - self.slippage_rate)
|
||||||
else:
|
else:
|
||||||
fill_price = base_price # 理论上不发生
|
fill_price = base_price_for_market_order # 理论上不发生
|
||||||
|
|
||||||
elif order.price_type == "LIMIT" and order.limit_price is not None:
|
elif order.price_type == "LIMIT" and order.limit_price is not None:
|
||||||
limit_price = order.limit_price
|
limit_price = order.limit_price
|
||||||
|
|
||||||
# 限价单:判断开盘价是否满足限价条件,如果满足,则以开盘价成交(考虑滑点)
|
# 限价单:判断 K 线的高低价是否触及限价
|
||||||
if order.direction == "BUY" or order.direction == "CLOSE_SHORT": # 限价买入/平空
|
if (
|
||||||
# 买单只有当开盘价低于或等于限价时才可能成交
|
order.direction == "BUY" or order.direction == "CLOSE_SHORT"
|
||||||
# 即:我愿意出 limit_price 买,开盘价 open_price 更低或一样,当然买
|
): # 限价买入/平空
|
||||||
if base_price <= limit_price:
|
# 如果当前K线的最低价低于或等于限价,则买入限价单有机会成交
|
||||||
fill_price = base_price * (1 + self.slippage_rate)
|
if current_bar.low < limit_price:
|
||||||
# else: 未满足限价条件,不成交
|
# 成交价以限价为基准,并考虑滑点(买入向上偏离)
|
||||||
elif order.direction == "SELL" or order.direction == "CLOSE_LONG": # 限价卖出/平多
|
fill_price = limit_price * (1 + self.slippage_rate)
|
||||||
# 卖单只有当开盘价高于或等于限价时才可能成交
|
elif (
|
||||||
# 即:我愿意出 limit_price 卖,开盘价 open_price 更高或一样,当然卖
|
order.direction == "SELL" or order.direction == "CLOSE_LONG"
|
||||||
if base_price >= limit_price:
|
): # 限价卖出/平多
|
||||||
fill_price = base_price * (1 - self.slippage_rate)
|
# 如果当前K线的最高价高于或等于限价,则卖出限价单有机会成交
|
||||||
# else: 未满足限价条件,不成交
|
if current_bar.high > limit_price:
|
||||||
|
# 成交价以限价为基准,并考虑滑点(卖出向下偏离)
|
||||||
|
fill_price = limit_price * (1 - self.slippage_rate)
|
||||||
|
|
||||||
# 最终检查成交价是否有效且合理(大于0)
|
|
||||||
if fill_price <= 0:
|
if fill_price <= 0:
|
||||||
return -1.0 # 未成交或价格无效
|
return -1.0
|
||||||
|
|
||||||
return fill_price
|
return fill_price
|
||||||
|
|
||||||
|
|
||||||
def send_order_to_pending(self, order: Order) -> Optional[Order]:
|
def send_order_to_pending(self, order: Order) -> Optional[Order]:
|
||||||
"""
|
"""
|
||||||
将订单添加到待处理队列。由 BacktestEngine 或 Strategy 调用。
|
将订单添加到待处理队列。由 BacktestEngine 或 Strategy 调用。
|
||||||
此方法不进行撮合,撮合由 process_pending_orders 统一处理。
|
此方法不进行撮合,撮合由 process_pending_orders 统一处理。
|
||||||
"""
|
"""
|
||||||
if order.id in self.pending_orders:
|
if order.id in self.pending_orders:
|
||||||
# print(f"订单 {order.id} 已经存在于待处理队列。")
|
|
||||||
return None
|
return None
|
||||||
self.pending_orders[order.id] = order
|
self.pending_orders[order.id] = order
|
||||||
# print(f"订单 {order.id} 加入待处理队列。")
|
|
||||||
return order
|
return order
|
||||||
|
|
||||||
def process_pending_orders(self, current_bar: Bar):
|
def process_pending_orders(self, current_bar: Bar):
|
||||||
"""
|
"""
|
||||||
处理所有待撮合的订单。在每个K线数据到来时调用。
|
处理所有待撮合的订单。在每个K线数据到来时调用。
|
||||||
"""
|
"""
|
||||||
# 复制一份待处理订单的键,防止在迭代时修改字典
|
|
||||||
order_ids_to_process = list(self.pending_orders.keys())
|
order_ids_to_process = list(self.pending_orders.keys())
|
||||||
|
|
||||||
for order_id in order_ids_to_process:
|
for order_id in order_ids_to_process:
|
||||||
if order_id not in self.pending_orders: # 订单可能已被取消
|
if order_id not in self.pending_orders:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
order = self.pending_orders[order_id]
|
order = self.pending_orders[order_id]
|
||||||
|
|
||||||
# 只有当订单的symbol与当前bar的symbol一致时才尝试撮合
|
|
||||||
# 这样确保了在换月后,旧合约的挂单不会被尝试撮合 (尽管换月时会强制取消)
|
|
||||||
if order.symbol != current_bar.symbol:
|
if order.symbol != current_bar.symbol:
|
||||||
# 这种情况理论上应该被换月逻辑清理掉的旧合约挂单,
|
|
||||||
# 如果因为某种原因漏掉了,这里直接跳过,避免异常。
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 尝试成交订单
|
|
||||||
self._execute_single_order(order, current_bar)
|
self._execute_single_order(order, current_bar)
|
||||||
|
|
||||||
def _execute_single_order(self, order: Order, current_bar: Bar) -> Optional[Trade]:
|
def _execute_single_order(self, order: Order, current_bar: Bar) -> Optional[Trade]:
|
||||||
@@ -125,193 +139,185 @@ class ExecutionSimulator:
|
|||||||
内部方法:尝试执行单个订单,并处理资金和持仓变化。
|
内部方法:尝试执行单个订单,并处理资金和持仓变化。
|
||||||
由 send_order 或 process_pending_orders 调用。
|
由 send_order 或 process_pending_orders 调用。
|
||||||
"""
|
"""
|
||||||
# --- 处理撤单指令 ---
|
if order.direction == "CANCEL":
|
||||||
if order.direction == "CANCEL": # 策略主动发起撤单
|
|
||||||
success = self.cancel_order(order.id)
|
success = self.cancel_order(order.id)
|
||||||
if success:
|
if success:
|
||||||
# print(f"[{current_bar.datetime}] 模拟器: 收到并成功处理撤单指令 for Order ID: {order.id}")
|
|
||||||
pass
|
pass
|
||||||
return None # 撤单操作不返回Trade
|
return None
|
||||||
|
|
||||||
symbol = order.symbol
|
symbol = order.symbol
|
||||||
volume = order.volume
|
volume = order.volume
|
||||||
|
|
||||||
# 尝试计算成交价格
|
|
||||||
fill_price = self._calculate_fill_price(order, current_bar)
|
fill_price = self._calculate_fill_price(order, current_bar)
|
||||||
|
if fill_price <= 0:
|
||||||
if fill_price <= 0: # 未成交或不满足限价条件
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# --- 以下是订单成功成交前的预检查逻辑 ---
|
|
||||||
trade_value = volume * fill_price
|
trade_value = volume * fill_price
|
||||||
commission = trade_value * self.commission_rate
|
commission = trade_value * self.commission_rate
|
||||||
|
|
||||||
current_position = self.positions.get(symbol, 0)
|
current_position = self.positions.get(symbol, 0)
|
||||||
current_average_cost = self.average_costs.get(symbol, 0.0)
|
current_average_cost = self.average_costs.get(symbol, 0.0)
|
||||||
|
|
||||||
realized_pnl = 0.0 # 预先计算的实现盈亏
|
realized_pnl = 0.0
|
||||||
|
|
||||||
# -----------------------------------------------------------
|
|
||||||
# 精确判断 is_open_trade 和 is_close_trade
|
|
||||||
# -----------------------------------------------------------
|
|
||||||
is_trade_a_close_operation = False
|
is_trade_a_close_operation = False
|
||||||
is_trade_an_open_operation = False
|
is_trade_an_open_operation = False
|
||||||
|
|
||||||
# 1. 判断是否为平仓操作
|
if order.direction in ["CLOSE_LONG", "CLOSE_SHORT"]:
|
||||||
# 显式平仓指令
|
|
||||||
if order.direction in ["CLOSE_LONG", "CLOSE_SELL", "CLOSE_SHORT"]:
|
|
||||||
is_trade_a_close_operation = True
|
is_trade_a_close_operation = True
|
||||||
# 隐式平仓 (例如,持有空头时买入,或持有多头时卖出)
|
elif order.direction == "BUY" and current_position < 0:
|
||||||
elif order.direction == "BUY" and current_position < 0: # 买入平空
|
|
||||||
is_trade_a_close_operation = True
|
is_trade_a_close_operation = True
|
||||||
elif order.direction == "SELL" and current_position > 0: # 卖出平多
|
elif order.direction == "SELL" and current_position > 0:
|
||||||
is_trade_a_close_operation = True
|
is_trade_a_close_operation = True
|
||||||
|
|
||||||
# 2. 判断是否为开仓操作
|
|
||||||
if order.direction == "BUY":
|
if order.direction == "BUY":
|
||||||
# 买入开多: 如果当前持有多头或无仓位,或者从空头转为多头
|
if current_position >= 0 or (
|
||||||
if current_position >= 0 or (current_position < 0 and (current_position + volume) > 0):
|
current_position < 0 and (current_position + volume) > 0
|
||||||
|
):
|
||||||
is_trade_an_open_operation = True
|
is_trade_an_open_operation = True
|
||||||
elif order.direction == "SELL":
|
elif order.direction == "SELL":
|
||||||
# 卖出开空: 如果当前持有空头或无仓位,或者从多头转为空头
|
if current_position <= 0 or (
|
||||||
if current_position <= 0 or (current_position > 0 and (current_position - volume) < 0):
|
current_position > 0 and (current_position - volume) < 0
|
||||||
|
):
|
||||||
is_trade_an_open_operation = True
|
is_trade_an_open_operation = True
|
||||||
|
|
||||||
# -----------------------------------------------------------
|
|
||||||
|
|
||||||
# 区分实际的买卖方向 (用于资金和持仓计算)
|
|
||||||
actual_execution_direction = ""
|
actual_execution_direction = ""
|
||||||
if order.direction == "BUY" or order.direction == "CLOSE_SHORT":
|
if order.direction == "BUY" or order.direction == "CLOSE_SHORT":
|
||||||
actual_execution_direction = "BUY"
|
actual_execution_direction = "BUY"
|
||||||
elif order.direction == "SELL" or order.direction == "CLOSE_LONG" or order.direction == "CLOSE_SELL":
|
elif order.direction == "SELL" or order.direction == "CLOSE_LONG":
|
||||||
actual_execution_direction = "SELL"
|
actual_execution_direction = "SELL"
|
||||||
else:
|
else:
|
||||||
print(
|
print(
|
||||||
f"[{current_bar.datetime}] 模拟器: 收到未知订单方向 {order.direction} for Order ID: {order.id}. 订单未处理。")
|
f"[{current_bar.datetime}] 模拟器: 收到未知订单方向 {order.direction} for Order ID: {order.id}. 订单未处理。"
|
||||||
if order.id in self.pending_orders: del self.pending_orders[order.id]
|
)
|
||||||
|
if order.id in self.pending_orders:
|
||||||
|
del self.pending_orders[order.id]
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# --- 临时变量,用于预计算新的资金和持仓状态 ---
|
|
||||||
temp_cash = self.cash
|
temp_cash = self.cash
|
||||||
temp_positions = self.positions.copy()
|
temp_positions = self.positions.copy()
|
||||||
temp_average_costs = self.average_costs.copy()
|
temp_average_costs = self.average_costs.copy()
|
||||||
|
|
||||||
# 根据实际执行方向进行预计算和资金检查
|
if actual_execution_direction == "BUY":
|
||||||
if actual_execution_direction == "BUY": # 处理实际的买入 (开多 / 平空)
|
if current_position >= 0:
|
||||||
if current_position >= 0: # 当前持有多仓或无仓位 (开多)
|
|
||||||
required_cash = trade_value + commission
|
required_cash = trade_value + commission
|
||||||
if temp_cash < required_cash:
|
if temp_cash < required_cash:
|
||||||
print(
|
# print(
|
||||||
f"[{current_bar.datetime}] 模拟器: 资金不足 (开多), 无法执行买入 {volume} {symbol} @ {fill_price:.2f}. 需要: {required_cash:.2f}, 当前: {temp_cash:.2f}")
|
# f"[{current_bar.datetime}] 模拟器: 资金不足 (开多), 无法执行买入 {volume} {symbol} @ {fill_price:.2f}. 需要: {required_cash:.2f}, 当前: {temp_cash:.2f}"
|
||||||
if order.id in self.pending_orders: del self.pending_orders[order.id]
|
# )
|
||||||
|
if order.id in self.pending_orders:
|
||||||
|
del self.pending_orders[order.id]
|
||||||
return None
|
return None
|
||||||
|
|
||||||
temp_cash -= required_cash
|
temp_cash -= required_cash
|
||||||
new_total_cost = (temp_average_costs.get(symbol, 0.0) * temp_positions.get(symbol, 0)) + (
|
new_total_cost = (
|
||||||
fill_price * volume)
|
temp_average_costs.get(symbol, 0.0) * temp_positions.get(symbol, 0)
|
||||||
|
) + (fill_price * volume)
|
||||||
new_total_volume = temp_positions.get(symbol, 0) + volume
|
new_total_volume = temp_positions.get(symbol, 0) + volume
|
||||||
temp_average_costs[symbol] = new_total_cost / new_total_volume if new_total_volume > 0 else 0.0
|
temp_average_costs[symbol] = (
|
||||||
|
new_total_cost / new_total_volume if new_total_volume > 0 else 0.0
|
||||||
|
)
|
||||||
temp_positions[symbol] = new_total_volume
|
temp_positions[symbol] = new_total_volume
|
||||||
|
|
||||||
else: # 当前持有空仓 (平空) - 平仓交易,佣金从交易价值中扣除,不单独检查现金余额
|
else: # 当前持有空仓 (平空) - 平仓交易
|
||||||
pnl_per_share = current_average_cost - fill_price # 空头平仓盈亏
|
pnl_per_share = current_average_cost - fill_price
|
||||||
realized_pnl = pnl_per_share * volume
|
realized_pnl = pnl_per_share * volume
|
||||||
|
|
||||||
temp_cash -= commission # 扣除佣金
|
temp_cash -= commission
|
||||||
temp_cash += trade_value # 回收平仓价值
|
temp_cash -= trade_value
|
||||||
temp_cash += realized_pnl # 计入实现盈亏
|
temp_cash += realized_pnl
|
||||||
|
|
||||||
temp_positions[symbol] += volume
|
temp_positions[symbol] += volume
|
||||||
if temp_positions[symbol] == 0:
|
if temp_positions[symbol] == 0:
|
||||||
del temp_positions[symbol]
|
del temp_positions[symbol]
|
||||||
if symbol in temp_average_costs: del temp_average_costs[symbol]
|
if symbol in temp_average_costs:
|
||||||
elif current_position < 0 and temp_positions[symbol] > 0: # 发生空转多
|
del temp_average_costs[symbol]
|
||||||
temp_average_costs[symbol] = fill_price # 新多头仓位成本以成交价为准
|
elif current_position < 0 and temp_positions[symbol] > 0:
|
||||||
|
temp_average_costs[symbol] = fill_price
|
||||||
|
|
||||||
|
elif actual_execution_direction == "SELL":
|
||||||
elif actual_execution_direction == "SELL": # 处理实际的卖出 (开空 / 平多)
|
|
||||||
if current_position <= 0: # 当前持有空仓或无仓位 (开空)
|
if current_position <= 0: # 当前持有空仓或无仓位 (开空)
|
||||||
# 开空主要检查佣金是否足够
|
|
||||||
if temp_cash < commission:
|
if temp_cash < commission:
|
||||||
print(
|
# print(
|
||||||
f"[{current_bar.datetime}] 模拟器: 资金不足 (开空佣金), 无法执行卖出 {volume} {symbol} @ {fill_price:.2f}. 佣金: {commission:.2f}, 当前: {temp_cash:.2f}")
|
# f"[{current_bar.datetime}] 模拟器: 资金不足 (开空佣金), 无法执行卖出 {volume} {symbol} @ {fill_price:.2f}. 佣金: {commission:.2f}, 当前: {temp_cash:.2f}"
|
||||||
if order.id in self.pending_orders: del self.pending_orders[order.id]
|
# )
|
||||||
|
if order.id in self.pending_orders:
|
||||||
|
del self.pending_orders[order.id]
|
||||||
return None
|
return None
|
||||||
|
|
||||||
temp_cash -= commission
|
temp_cash -= commission
|
||||||
new_total_value = (temp_average_costs.get(symbol, 0.0) * abs(temp_positions.get(symbol, 0))) + (
|
temp_cash += trade_value # 修正点:开空时将卖出资金计入现金
|
||||||
fill_price * volume)
|
|
||||||
new_total_volume = abs(temp_positions.get(symbol, 0)) + volume
|
|
||||||
temp_average_costs[symbol] = new_total_value / new_total_volume if new_total_volume > 0 else 0.0 # 平均成本
|
|
||||||
temp_positions[symbol] -= volume
|
|
||||||
|
|
||||||
else: # 当前持有多仓 (平多) - 平仓交易,佣金从交易价值中扣除,不单独检查现金余额
|
existing_abs_volume = abs(temp_positions.get(symbol, 0))
|
||||||
pnl_per_share = fill_price - current_average_cost # 多头平仓盈亏
|
existing_abs_cost = (
|
||||||
|
temp_average_costs.get(symbol, 0.0) * existing_abs_volume
|
||||||
|
)
|
||||||
|
|
||||||
|
new_total_value = existing_abs_cost + (fill_price * volume)
|
||||||
|
new_total_volume = existing_abs_volume + volume
|
||||||
|
|
||||||
|
temp_average_costs[symbol] = (
|
||||||
|
new_total_value / new_total_volume if new_total_volume > 0 else 0.0
|
||||||
|
)
|
||||||
|
temp_positions[symbol] = -new_total_volume
|
||||||
|
|
||||||
|
else: # 当前持有多仓 (平多) - 平仓交易
|
||||||
|
pnl_per_share = fill_price - current_average_cost
|
||||||
realized_pnl = pnl_per_share * volume
|
realized_pnl = pnl_per_share * volume
|
||||||
|
|
||||||
temp_cash -= commission # 扣除佣金
|
temp_cash -= commission
|
||||||
temp_cash += trade_value # 回收平仓价值
|
temp_cash += trade_value
|
||||||
temp_cash += realized_pnl # 计入实现盈亏
|
temp_cash += realized_pnl
|
||||||
|
|
||||||
temp_positions[symbol] -= volume
|
temp_positions[symbol] -= volume
|
||||||
if temp_positions[symbol] == 0:
|
if temp_positions[symbol] == 0:
|
||||||
del temp_positions[symbol]
|
del temp_positions[symbol]
|
||||||
if symbol in temp_average_costs: del temp_average_costs[symbol]
|
if symbol in temp_average_costs:
|
||||||
elif current_position > 0 and temp_positions[symbol] < 0: # 发生多转空
|
del temp_average_costs[symbol]
|
||||||
temp_average_costs[symbol] = fill_price # 新空头仓位成本以成交价为准
|
elif current_position > 0 and temp_positions[symbol] < 0:
|
||||||
|
temp_average_costs[symbol] = fill_price
|
||||||
|
|
||||||
# --- 所有检查通过后,才正式更新模拟器状态 ---
|
|
||||||
self.cash = temp_cash
|
self.cash = temp_cash
|
||||||
self.positions = temp_positions
|
self.positions = temp_positions
|
||||||
self.average_costs = temp_average_costs
|
self.average_costs = temp_average_costs
|
||||||
|
|
||||||
# 创建 Trade 对象时,direction 使用原始订单的 direction
|
|
||||||
executed_trade = Trade(
|
executed_trade = Trade(
|
||||||
order_id=order.id, fill_time=current_bar.datetime, symbol=symbol,
|
order_id=order.id,
|
||||||
direction=order.direction, # 使用原始订单的 direction
|
fill_time=current_bar.datetime,
|
||||||
volume=volume, price=fill_price, commission=commission,
|
symbol=symbol,
|
||||||
cash_after_trade=self.cash, positions_after_trade=self.positions.copy(),
|
direction=order.direction,
|
||||||
|
volume=volume,
|
||||||
|
price=fill_price,
|
||||||
|
commission=commission,
|
||||||
|
cash_after_trade=self.cash,
|
||||||
|
positions_after_trade=self.positions.copy(),
|
||||||
realized_pnl=realized_pnl,
|
realized_pnl=realized_pnl,
|
||||||
is_open_trade=is_trade_an_open_operation, # 使用更精确的判断
|
is_open_trade=is_trade_an_open_operation,
|
||||||
is_close_trade=is_trade_a_close_operation # 使用更精确的判断
|
is_close_trade=is_trade_a_close_operation,
|
||||||
)
|
)
|
||||||
self.trade_log.append(executed_trade)
|
self.trade_log.append(executed_trade)
|
||||||
|
|
||||||
# 订单成交,从待处理订单中移除
|
|
||||||
if order.id in self.pending_orders:
|
if order.id in self.pending_orders:
|
||||||
del self.pending_orders[order.id]
|
del self.pending_orders[order.id]
|
||||||
|
|
||||||
return executed_trade
|
return executed_trade
|
||||||
|
|
||||||
def cancel_order(self, order_id: str) -> bool:
|
def cancel_order(self, order_id: str) -> bool:
|
||||||
"""
|
|
||||||
尝试取消一个待处理订单。
|
|
||||||
"""
|
|
||||||
if order_id in self.pending_orders:
|
if order_id in self.pending_orders:
|
||||||
del self.pending_orders[order_id]
|
del self.pending_orders[order_id]
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# --- 新增:强制平仓指定合约的所有持仓 ---
|
def force_close_all_positions_for_symbol(
|
||||||
def force_close_all_positions_for_symbol(self, symbol_to_close: str, closing_bar: Bar) -> List[Trade]:
|
self, symbol_to_close: str, closing_bar: Bar
|
||||||
"""
|
) -> List[Trade]:
|
||||||
强制平仓指定合约的所有持仓。
|
|
||||||
Args:
|
|
||||||
symbol_to_close (str): 需要平仓的合约代码。
|
|
||||||
closing_bar (Bar): 用于获取平仓价格的当前K线数据(通常是旧合约的最后一根K线)。
|
|
||||||
Returns:
|
|
||||||
List[Trade]: 因强制平仓而产生的交易记录。
|
|
||||||
"""
|
|
||||||
closed_trades: List[Trade] = []
|
closed_trades: List[Trade] = []
|
||||||
|
|
||||||
# 仅处理指定symbol的持仓
|
|
||||||
if symbol_to_close in self.positions and self.positions[symbol_to_close] != 0:
|
if symbol_to_close in self.positions and self.positions[symbol_to_close] != 0:
|
||||||
volume_to_close = self.positions[symbol_to_close]
|
volume_to_close = self.positions[symbol_to_close]
|
||||||
|
|
||||||
# 根据持仓方向决定平仓订单的方向
|
direction = "CLOSE_LONG" if volume_to_close > 0 else "CLOSE_SHORT"
|
||||||
direction = "CLOSE_LONG" if volume_to_close > 0 else "CLOSE_SELL" # 多头平仓是卖出,空头平仓是买入
|
|
||||||
|
|
||||||
# 构造一个市价平仓订单
|
|
||||||
rollover_order = Order(
|
rollover_order = Order(
|
||||||
id=f"FORCE_CLOSE_{symbol_to_close}_{closing_bar.datetime.strftime('%Y%m%d%H%M%S%f')}",
|
id=f"FORCE_CLOSE_{symbol_to_close}_{closing_bar.datetime.strftime('%Y%m%d%H%M%S%f')}",
|
||||||
symbol=symbol_to_close,
|
symbol=symbol_to_close,
|
||||||
@@ -321,28 +327,26 @@ class ExecutionSimulator:
|
|||||||
limit_price=None,
|
limit_price=None,
|
||||||
submitted_time=closing_bar.datetime,
|
submitted_time=closing_bar.datetime,
|
||||||
)
|
)
|
||||||
|
# 这里直接调用 _execute_single_order 确保强制平仓立即成交
|
||||||
# 使用内部的执行逻辑进行撮合
|
|
||||||
trade = self._execute_single_order(rollover_order, closing_bar)
|
trade = self._execute_single_order(rollover_order, closing_bar)
|
||||||
if trade:
|
if trade:
|
||||||
closed_trades.append(trade)
|
closed_trades.append(trade)
|
||||||
else:
|
else:
|
||||||
print(f"[{closing_bar.datetime}] 警告: 强制平仓 {symbol_to_close} 失败!")
|
print(
|
||||||
|
f"[{closing_bar.datetime}] 警告: 强制平仓 {symbol_to_close} 失败!"
|
||||||
|
)
|
||||||
|
|
||||||
return closed_trades
|
return closed_trades
|
||||||
|
|
||||||
# --- 新增:取消指定合约的所有挂单 ---
|
|
||||||
def cancel_all_pending_orders_for_symbol(self, symbol_to_cancel: str) -> int:
|
def cancel_all_pending_orders_for_symbol(self, symbol_to_cancel: str) -> int:
|
||||||
"""
|
|
||||||
取消指定合约的所有待处理订单。
|
|
||||||
"""
|
|
||||||
cancelled_count = 0
|
cancelled_count = 0
|
||||||
order_ids_to_cancel = [
|
order_ids_to_cancel = [
|
||||||
order_id for order_id, order in self.pending_orders.items()
|
order_id
|
||||||
|
for order_id, order in self.pending_orders.items()
|
||||||
if order.symbol == symbol_to_cancel
|
if order.symbol == symbol_to_cancel
|
||||||
]
|
]
|
||||||
for order_id in order_ids_to_cancel:
|
for order_id in order_ids_to_cancel:
|
||||||
if self.cancel_order(order_id): # 调用现有的 cancel_order 方法
|
if self.cancel_order(order_id):
|
||||||
cancelled_count += 1
|
cancelled_count += 1
|
||||||
return cancelled_count
|
return cancelled_count
|
||||||
|
|
||||||
@@ -350,37 +354,15 @@ class ExecutionSimulator:
|
|||||||
return self.pending_orders.copy()
|
return self.pending_orders.copy()
|
||||||
|
|
||||||
def get_portfolio_value(self, current_bar: Bar) -> float:
|
def get_portfolio_value(self, current_bar: Bar) -> float:
|
||||||
"""
|
|
||||||
计算当前的投资组合总价值(包括现金和持仓市值)。
|
|
||||||
此方法需要兼容多合约持仓的场景。
|
|
||||||
Args:
|
|
||||||
current_bar (Bar): 当前的Bar数据,用于计算**当前活跃合约**的持仓市值。
|
|
||||||
注意:如果 simulator 中持有多个合约,这里需要更复杂的逻辑。
|
|
||||||
目前假设主力合约回测时,simulator.positions 主要只包含当前主力合约。
|
|
||||||
Returns:
|
|
||||||
float: 当前的投资组合总价值。
|
|
||||||
"""
|
|
||||||
total_value = self.cash
|
total_value = self.cash
|
||||||
|
|
||||||
# 遍历所有持仓,计算市值。
|
|
||||||
# 注意:这里假设 current_bar 提供了当前活跃主力合约的价格。
|
|
||||||
# 如果 self.positions 中包含其他非 current_bar.symbol 的旧合约,
|
|
||||||
# 它们的市值将无法用 current_bar.open 来准确计算。
|
|
||||||
# 在换月模式下,旧合约会被强制平仓,因此 simulator.positions 通常只包含一个合约。
|
|
||||||
for symbol, quantity in self.positions.items():
|
for symbol, quantity in self.positions.items():
|
||||||
# 这里简单处理:如果持仓合约与 current_bar.symbol 相同,则使用 current_bar.open 计算。
|
|
||||||
# 如果是其他合约,则需要外部提供其最新价格,但这超出了本函数当前的能力范围。
|
|
||||||
# 考虑到换月模式,旧合约会被平仓,所以大部分时候这不会是问题。
|
|
||||||
if symbol == current_bar.symbol:
|
if symbol == current_bar.symbol:
|
||||||
total_value += quantity * current_bar.open
|
total_value += quantity * current_bar.open
|
||||||
else:
|
else:
|
||||||
# 警告:如果这里出现,说明有未平仓的旧合约持仓,且没有其最新价格来计算市值。
|
print(
|
||||||
# 在严谨的主力连续回测中,这不应该发生,因为换月会强制平仓。
|
f"[{current_bar.datetime}] 警告:持仓中存在非当前K线合约 {symbol},无法准确计算其市值。"
|
||||||
print(f"[{current_bar.datetime}] 警告:持仓中存在非当前K线合约 {symbol},无法准确计算其市值。")
|
)
|
||||||
# 可以选择将这部分持仓价值计为0,或者使用上一个已知价格(需要额外数据结构)
|
|
||||||
# 这里我们假设它不影响总价值计算,因为换月时会处理掉
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
return total_value
|
return total_value
|
||||||
|
|
||||||
def get_current_positions(self) -> Dict[str, int]:
|
def get_current_positions(self) -> Dict[str, int]:
|
||||||
@@ -389,22 +371,38 @@ class ExecutionSimulator:
|
|||||||
def get_trade_history(self) -> List[Trade]:
|
def get_trade_history(self) -> List[Trade]:
|
||||||
return self.trade_log.copy()
|
return self.trade_log.copy()
|
||||||
|
|
||||||
def reset(self, new_initial_capital: float = None, new_initial_positions: Dict[str, int] = None) -> None:
|
def reset(
|
||||||
"""
|
self,
|
||||||
重置模拟器状态到新的初始条件。
|
new_initial_capital: float = None,
|
||||||
此方法不用于换月时的平仓,它用于整个回测开始前的初始化。
|
new_initial_positions: Dict[str, int] = None,
|
||||||
"""
|
new_initial_average_costs: Dict[str, float] = None,
|
||||||
|
) -> None: # 新增参数
|
||||||
print("ExecutionSimulator: 重置状态。")
|
print("ExecutionSimulator: 重置状态。")
|
||||||
self.cash = new_initial_capital if new_initial_capital is not None else self.initial_capital
|
self.cash = (
|
||||||
self.positions = new_initial_positions.copy() if new_initial_positions is not None else {}
|
new_initial_capital
|
||||||
self.average_costs = {}
|
if new_initial_capital is not None
|
||||||
for symbol, qty in self.positions.items(): # 重置平均成本
|
else self.initial_capital
|
||||||
|
)
|
||||||
|
self.positions = (
|
||||||
|
new_initial_positions.copy() if new_initial_positions is not None else {}
|
||||||
|
)
|
||||||
|
# 修正:重置时也应该考虑传入初始平均成本
|
||||||
|
self.average_costs = (
|
||||||
|
new_initial_average_costs.copy()
|
||||||
|
if new_initial_average_costs is not None
|
||||||
|
else {}
|
||||||
|
)
|
||||||
|
if self.positions and not new_initial_average_costs:
|
||||||
|
print(
|
||||||
|
f"[{datetime.now()}] 警告: 重置时提供了初始持仓但未提供初始平均成本,这些持仓的成本默认为0.0。"
|
||||||
|
)
|
||||||
|
for symbol, qty in self.positions.items():
|
||||||
|
if symbol not in self.average_costs:
|
||||||
self.average_costs[symbol] = 0.0
|
self.average_costs[symbol] = 0.0
|
||||||
self.trade_log = []
|
|
||||||
self.pending_orders = {} # 清空挂单
|
|
||||||
self._current_time = None
|
|
||||||
|
|
||||||
# Removed clear_trade_history as trade_log is cleared in reset
|
self.trade_log = []
|
||||||
|
self.pending_orders = {}
|
||||||
|
self._current_time = None
|
||||||
|
|
||||||
def get_average_position_price(self, symbol: str) -> Optional[float]:
|
def get_average_position_price(self, symbol: str) -> Optional[float]:
|
||||||
if symbol in self.positions and self.positions[symbol] != 0:
|
if symbol in self.positions and self.positions[symbol] != 0:
|
||||||
|
|||||||
@@ -5,7 +5,8 @@ from ..core_data import Bar, Order
|
|||||||
from typing import Optional, Dict, Any
|
from typing import Optional, Dict, Any
|
||||||
from collections import deque
|
from collections import deque
|
||||||
|
|
||||||
class SimpleLimitBuyStrategy(Strategy):
|
|
||||||
|
class SimpleLimitBuyStrategyLong(Strategy):
|
||||||
"""
|
"""
|
||||||
一个基于当前K线Open、前1根和前7根K线Range计算优势价格进行限价买入的策略。
|
一个基于当前K线Open、前1根和前7根K线Range计算优势价格进行限价买入的策略。
|
||||||
具备以下特点:
|
具备以下特点:
|
||||||
@@ -13,16 +14,23 @@ class SimpleLimitBuyStrategy(Strategy):
|
|||||||
- 最多只能有一个开仓挂单和一个持仓。
|
- 最多只能有一个开仓挂单和一个持仓。
|
||||||
- 包含简单的止损和止盈逻辑。
|
- 包含简单的止损和止盈逻辑。
|
||||||
"""
|
"""
|
||||||
def __init__(self, simulator: Any, symbol: str, enable_log: bool, trade_volume: int,
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
context: Any,
|
||||||
|
symbol: str,
|
||||||
|
enable_log: bool,
|
||||||
|
trade_volume: int,
|
||||||
open_range_factor_1_ago: float,
|
open_range_factor_1_ago: float,
|
||||||
open_range_factor_7_ago: float,
|
open_range_factor_7_ago: float,
|
||||||
max_position: int,
|
max_position: int,
|
||||||
stop_loss_points: float = 10, # 新增:止损点数
|
stop_loss_points: float = 10, # 新增:止损点数
|
||||||
take_profit_points: float = 10): # 新增:止盈点数
|
take_profit_points: float = 10,
|
||||||
|
): # 新增:止盈点数
|
||||||
"""
|
"""
|
||||||
初始化策略。
|
初始化策略。
|
||||||
Args:
|
Args:
|
||||||
simulator: 模拟器实例。
|
context: 模拟器实例。
|
||||||
symbol (str): 交易合约代码。
|
symbol (str): 交易合约代码。
|
||||||
trade_volume (int): 单笔交易量。
|
trade_volume (int): 单笔交易量。
|
||||||
open_range_factor_1_ago (float): 前1根K线Range的权重因子,用于从Open价向下偏移。
|
open_range_factor_1_ago (float): 前1根K线Range的权重因子,用于从Open价向下偏移。
|
||||||
@@ -31,7 +39,7 @@ class SimpleLimitBuyStrategy(Strategy):
|
|||||||
stop_loss_points (float): 止损点数(例如,亏损达到此点数则止损)。
|
stop_loss_points (float): 止损点数(例如,亏损达到此点数则止损)。
|
||||||
take_profit_points (float): 止盈点数(例如,盈利达到此点数则止盈)。
|
take_profit_points (float): 止盈点数(例如,盈利达到此点数则止盈)。
|
||||||
"""
|
"""
|
||||||
super().__init__(simulator, symbol, enable_log)
|
super().__init__(context, symbol, enable_log)
|
||||||
self.trade_volume = trade_volume
|
self.trade_volume = trade_volume
|
||||||
self.open_range_factor_1_ago = open_range_factor_1_ago
|
self.open_range_factor_1_ago = open_range_factor_1_ago
|
||||||
self.open_range_factor_7_ago = open_range_factor_7_ago
|
self.open_range_factor_7_ago = open_range_factor_7_ago
|
||||||
@@ -41,16 +49,17 @@ class SimpleLimitBuyStrategy(Strategy):
|
|||||||
|
|
||||||
self.order_id_counter = 0
|
self.order_id_counter = 0
|
||||||
|
|
||||||
self._bar_history: deque[Bar] = deque(maxlen=10)
|
|
||||||
self._last_order_id: Optional[str] = None # 用于跟踪上一根K线发出的订单ID
|
self._last_order_id: Optional[str] = None # 用于跟踪上一根K线发出的订单ID
|
||||||
|
|
||||||
self.log(f"策略初始化: symbol={self.symbol}, trade_volume={self.trade_volume}, "
|
self.log(
|
||||||
|
f"策略初始化: symbol={self.symbol}, trade_volume={self.trade_volume}, "
|
||||||
f"open_range_factor_1_ago={self.open_range_factor_1_ago}, "
|
f"open_range_factor_1_ago={self.open_range_factor_1_ago}, "
|
||||||
f"open_range_factor_7_ago={self.open_range_factor_7_ago}, "
|
f"open_range_factor_7_ago={self.open_range_factor_7_ago}, "
|
||||||
f"max_position={self.max_position}, "
|
f"max_position={self.max_position}, "
|
||||||
f"止损点={self.stop_loss_points}, 止盈点={self.take_profit_points}")
|
f"止损点={self.stop_loss_points}, 止盈点={self.take_profit_points}"
|
||||||
|
)
|
||||||
|
|
||||||
def on_bar(self, bar: Bar, next_bar_open: Optional[float] = None):
|
def on_open_bar(self, bar: Bar, next_bar_open: Optional[float] = None):
|
||||||
"""
|
"""
|
||||||
每当新的K线数据到来时调用。
|
每当新的K线数据到来时调用。
|
||||||
Args:
|
Args:
|
||||||
@@ -65,36 +74,46 @@ class SimpleLimitBuyStrategy(Strategy):
|
|||||||
if self._last_order_id:
|
if self._last_order_id:
|
||||||
pending_orders = self.get_pending_orders()
|
pending_orders = self.get_pending_orders()
|
||||||
if self._last_order_id in pending_orders:
|
if self._last_order_id in pending_orders:
|
||||||
success = self.cancel_order(self._last_order_id) # 直接调用基类的取消方法
|
success = self.cancel_order(
|
||||||
|
self._last_order_id
|
||||||
|
) # 直接调用基类的取消方法
|
||||||
if success:
|
if success:
|
||||||
self.log(f"[{current_datetime}] 策略: 成功撤销上一根K线未成交订单 {self._last_order_id}")
|
self.log(
|
||||||
|
f"[{current_datetime}] 策略: 成功撤销上一根K线未成交订单 {self._last_order_id}"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
self.log(f"[{current_datetime}] 策略: 尝试撤销订单 {self._last_order_id} 失败(可能已成交或不存在)")
|
self.log(
|
||||||
|
f"[{current_datetime}] 策略: 尝试撤销订单 {self._last_order_id} 失败(可能已成交或不存在)"
|
||||||
|
)
|
||||||
# 无论撤销成功与否,既然我们尝试了撤销,就清除记录
|
# 无论撤销成功与否,既然我们尝试了撤销,就清除记录
|
||||||
self._last_order_id = None
|
self._last_order_id = None
|
||||||
# else:
|
# else:
|
||||||
# self.log(f"[{current_datetime}] 策略: 无上一根K线未成交订单需要撤销。")
|
# self.log(f"[{current_datetime}] 策略: 无上一根K线未成交订单需要撤销。")
|
||||||
|
|
||||||
|
|
||||||
# 2. 更新K线历史
|
# 2. 更新K线历史
|
||||||
self._bar_history.append(bar)
|
|
||||||
trade_volume = self.trade_volume
|
trade_volume = self.trade_volume
|
||||||
|
|
||||||
# 获取当前持仓和未决订单(在取消之后获取,确保是最新的状态)
|
# 获取当前持仓和未决订单(在取消之后获取,确保是最新的状态)
|
||||||
current_positions = self.get_current_positions()
|
current_positions = self.get_current_positions()
|
||||||
current_pos_volume = current_positions.get(self.symbol, 0)
|
current_pos_volume = current_positions.get(self.symbol, 0)
|
||||||
pending_orders_after_cancel = self.get_pending_orders() # 再次获取,此时应已取消旧订单
|
pending_orders_after_cancel = (
|
||||||
|
self.get_pending_orders()
|
||||||
|
) # 再次获取,此时应已取消旧订单
|
||||||
|
|
||||||
# --- 3. 平仓逻辑 (止损/止盈) ---
|
# --- 3. 平仓逻辑 (止损/止盈) ---
|
||||||
# 只有当有持仓时才考虑平仓
|
# 只有当有持仓时才考虑平仓
|
||||||
if current_pos_volume > 0: # 假设只做多,所以持仓量 > 0
|
if current_pos_volume > 0: # 假设只做多,所以持仓量 > 0
|
||||||
avg_entry_price = self.get_average_position_price(self.symbol)
|
avg_entry_price = self.get_average_position_price(self.symbol)
|
||||||
if avg_entry_price is not None:
|
if avg_entry_price is not None:
|
||||||
pnl_per_unit = bar.close - avg_entry_price # 当前浮动盈亏(以收盘价计算)
|
pnl_per_unit = (
|
||||||
|
bar.open - avg_entry_price
|
||||||
|
) # 当前浮动盈亏(以收盘价计算)
|
||||||
|
|
||||||
# 止盈条件
|
# 止盈条件
|
||||||
if pnl_per_unit >= self.take_profit_points:
|
if pnl_per_unit >= self.take_profit_points:
|
||||||
self.log(f"[{current_datetime}] 止盈信号 - PnL per unit: {pnl_per_unit:.2f}, 目标: {self.take_profit_points:.2f}")
|
self.log(
|
||||||
|
f"[{current_datetime}] 止盈信号 - PnL per unit: {pnl_per_unit:.2f}, 目标: {self.take_profit_points:.2f}"
|
||||||
|
)
|
||||||
|
|
||||||
order_id = f"{self.symbol}_BUY_{bar.datetime.strftime('%Y%m%d%H%M%S')}_{self.order_id_counter}"
|
order_id = f"{self.symbol}_BUY_{bar.datetime.strftime('%Y%m%d%H%M%S')}_{self.order_id_counter}"
|
||||||
self.order_id_counter += 1
|
self.order_id_counter += 1
|
||||||
@@ -107,14 +126,17 @@ class SimpleLimitBuyStrategy(Strategy):
|
|||||||
volume=trade_volume,
|
volume=trade_volume,
|
||||||
price_type="MARKET",
|
price_type="MARKET",
|
||||||
# limit_price=limit_price,
|
# limit_price=limit_price,
|
||||||
submitted_time=bar.datetime
|
submitted_time=bar.datetime,
|
||||||
|
offset="CLOSE",
|
||||||
)
|
)
|
||||||
trade = self.send_order(order)
|
trade = self.send_order(order)
|
||||||
return # 平仓后本K线不再进行开仓判断
|
return # 平仓后本K线不再进行开仓判断
|
||||||
|
|
||||||
# 止损条件
|
# 止损条件
|
||||||
elif pnl_per_unit <= -self.stop_loss_points:
|
elif pnl_per_unit <= -self.stop_loss_points:
|
||||||
self.log(f"[{current_datetime}] 止损信号 - PnL per unit: {pnl_per_unit:.2f}, 目标: {-self.stop_loss_points:.2f}")
|
self.log(
|
||||||
|
f"[{current_datetime}] 止损信号 - PnL per unit: {pnl_per_unit:.2f}, 目标: {-self.stop_loss_points:.2f}"
|
||||||
|
)
|
||||||
# 发送市价卖出订单平仓,确保立即成交
|
# 发送市价卖出订单平仓,确保立即成交
|
||||||
order_id = f"{self.symbol}_BUY_{bar.datetime.strftime('%Y%m%d%H%M%S')}_{self.order_id_counter}"
|
order_id = f"{self.symbol}_BUY_{bar.datetime.strftime('%Y%m%d%H%M%S')}_{self.order_id_counter}"
|
||||||
self.order_id_counter += 1
|
self.order_id_counter += 1
|
||||||
@@ -127,7 +149,8 @@ class SimpleLimitBuyStrategy(Strategy):
|
|||||||
volume=trade_volume,
|
volume=trade_volume,
|
||||||
price_type="MARKET",
|
price_type="MARKET",
|
||||||
# limit_price=limit_price,
|
# limit_price=limit_price,
|
||||||
submitted_time=bar.datetime
|
submitted_time=bar.datetime,
|
||||||
|
offset="CLOSE",
|
||||||
)
|
)
|
||||||
trade = self.send_order(order)
|
trade = self.send_order(order)
|
||||||
return # 平仓后本K线不再进行开仓判断
|
return # 平仓后本K线不再进行开仓判断
|
||||||
@@ -135,12 +158,12 @@ class SimpleLimitBuyStrategy(Strategy):
|
|||||||
# --- 4. 开仓逻辑 (只考虑做多 BUY 方向) ---
|
# --- 4. 开仓逻辑 (只考虑做多 BUY 方向) ---
|
||||||
# 只有在没有持仓 (current_pos_volume == 0) 且没有待处理订单 (not pending_orders_after_cancel)
|
# 只有在没有持仓 (current_pos_volume == 0) 且没有待处理订单 (not pending_orders_after_cancel)
|
||||||
# 且K线历史足够长时才考虑开仓
|
# 且K线历史足够长时才考虑开仓
|
||||||
if current_pos_volume == 0 and \
|
bar_history = self.get_bar_history()
|
||||||
len(self._bar_history) == self._bar_history.maxlen:
|
if current_pos_volume == 0 and len(bar_history) > 10:
|
||||||
|
|
||||||
# 获取前1根K线 (倒数第二根) 和前7根K线 (队列中最老的一根)
|
# 获取前1根K线 (倒数第二根) 和前7根K线 (队列中最老的一根)
|
||||||
bar_1_ago = self._bar_history[-2]
|
bar_1_ago = bar_history[-2]
|
||||||
bar_7_ago = self._bar_history[-8]
|
bar_7_ago = bar_history[-8]
|
||||||
|
|
||||||
# 计算历史 K 线的 Range
|
# 计算历史 K 线的 Range
|
||||||
range_1_ago = bar_1_ago.high - bar_1_ago.low
|
range_1_ago = bar_1_ago.high - bar_1_ago.low
|
||||||
@@ -148,16 +171,24 @@ class SimpleLimitBuyStrategy(Strategy):
|
|||||||
|
|
||||||
# 根据策略逻辑计算目标买入价格
|
# 根据策略逻辑计算目标买入价格
|
||||||
# 目标买入价 = 当前K线Open - (前1根Range * 因子1 + 前7根Range * 因子2)
|
# 目标买入价 = 当前K线Open - (前1根Range * 因子1 + 前7根Range * 因子2)
|
||||||
self.log(bar.open ,range_1_ago * self.open_range_factor_1_ago, range_7_ago * self.open_range_factor_7_ago)
|
self.log(
|
||||||
target_buy_price = bar.open - (range_1_ago * self.open_range_factor_1_ago + range_7_ago * self.open_range_factor_7_ago)
|
bar.open,
|
||||||
|
range_1_ago * self.open_range_factor_1_ago,
|
||||||
|
range_7_ago * self.open_range_factor_7_ago,
|
||||||
|
)
|
||||||
|
target_buy_price = bar.open - (
|
||||||
|
range_1_ago * self.open_range_factor_1_ago
|
||||||
|
+ range_7_ago * self.open_range_factor_7_ago
|
||||||
|
)
|
||||||
|
|
||||||
# 确保目标买入价格有效,例如不能是负数
|
# 确保目标买入价格有效,例如不能是负数
|
||||||
target_buy_price = max(0.01, target_buy_price)
|
target_buy_price = max(0.01, target_buy_price)
|
||||||
|
|
||||||
self.log(f"[{current_datetime}] 开多仓信号 - 当前Open={bar.open:.2f}, "
|
self.log(
|
||||||
|
f"[{current_datetime}] 开多仓信号 - 当前Open={bar.open:.2f}, "
|
||||||
f"前1Range={range_1_ago:.2f}, 前7Range={range_7_ago:.2f}, "
|
f"前1Range={range_1_ago:.2f}, 前7Range={range_7_ago:.2f}, "
|
||||||
f"计算目标买入价={target_buy_price:.2f}")
|
f"计算目标买入价={target_buy_price:.2f}"
|
||||||
self.log(f'{self.context._simulator.get_current_positions()}')
|
)
|
||||||
|
|
||||||
order_id = f"{self.symbol}_BUY_{bar.datetime.strftime('%Y%m%d%H%M%S')}_{self.order_id_counter}"
|
order_id = f"{self.symbol}_BUY_{bar.datetime.strftime('%Y%m%d%H%M%S')}_{self.order_id_counter}"
|
||||||
self.order_id_counter += 1
|
self.order_id_counter += 1
|
||||||
@@ -170,26 +201,554 @@ class SimpleLimitBuyStrategy(Strategy):
|
|||||||
volume=trade_volume,
|
volume=trade_volume,
|
||||||
price_type="LIMIT",
|
price_type="LIMIT",
|
||||||
limit_price=target_buy_price,
|
limit_price=target_buy_price,
|
||||||
submitted_time=bar.datetime
|
submitted_time=bar.datetime,
|
||||||
)
|
)
|
||||||
new_order = self.send_order(order)
|
new_order = self.send_order(order)
|
||||||
# 记录下这个订单的ID,以便在下一根K线开始时进行撤销
|
# 记录下这个订单的ID,以便在下一根K线开始时进行撤销
|
||||||
if new_order:
|
if new_order:
|
||||||
self._last_order_id = new_order.id
|
self._last_order_id = new_order.id
|
||||||
self.log(f"[{current_datetime}] 策略: 发送限价买入订单 {self._last_order_id} @ {target_buy_price:.2f}")
|
self.log(
|
||||||
|
f"[{current_datetime}] 策略: 发送限价买入订单 {self._last_order_id} @ {target_buy_price:.2f}"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
self.log(f"[{current_datetime}] 策略: 发送订单失败。")
|
self.log(f"[{current_datetime}] 策略: 发送订单失败。")
|
||||||
|
|
||||||
# else:
|
# else:
|
||||||
# self.log(f"[{current_datetime}] 不满足开仓条件:持仓={current_pos_volume}, 待处理订单={len(pending_orders_after_cancel)}, K线历史长度={len(self._bar_history)}")
|
# self.log(f"[{current_datetime}] 不满足开仓条件:持仓={current_pos_volume}, 待处理订单={len(pending_orders_after_cancel)}, K线历史长度={len(bar_history)}")
|
||||||
|
|
||||||
|
def on_close_bar(self, bar: Bar, next_bar_open: Optional[float] = None):
|
||||||
|
self.cancel_all_pending_orders()
|
||||||
|
|
||||||
def on_rollover(self, old_symbol: str, new_symbol: str):
|
def on_rollover(self, old_symbol: str, new_symbol: str):
|
||||||
"""
|
"""
|
||||||
在合约换月时清空历史K线数据和上次订单ID,避免使用旧合约数据进行计算。
|
在合约换月时清空历史K线数据和上次订单ID,避免使用旧合约数据进行计算。
|
||||||
"""
|
"""
|
||||||
super().on_rollover(old_symbol, new_symbol) # 调用基类方法打印日志
|
super().on_rollover(old_symbol, new_symbol) # 调用基类方法打印日志
|
||||||
self._bar_history.clear() # 清空历史K线
|
|
||||||
self._last_order_id = None # 清空上次订单ID,因为旧合约订单已取消
|
self._last_order_id = None # 清空上次订单ID,因为旧合约订单已取消
|
||||||
|
|
||||||
self.log(f"换月完成,清空历史K线数据和上次订单ID,准备新合约交易。")
|
self.log(f"换月完成,清空历史K线数据和上次订单ID,准备新合约交易。")
|
||||||
|
|
||||||
|
|
||||||
|
class SimpleLimitBuyStrategyShort(Strategy):
|
||||||
|
"""
|
||||||
|
一个基于当前K线Open、前1根和前7根K线Range计算优势价格进行限价买入的策略。
|
||||||
|
具备以下特点:
|
||||||
|
- 每根K线开始时取消上一根K线未成交的订单。
|
||||||
|
- 最多只能有一个开仓挂单和一个持仓。
|
||||||
|
- 包含简单的止损和止盈逻辑。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
context: Any,
|
||||||
|
symbol: str,
|
||||||
|
enable_log: bool,
|
||||||
|
trade_volume: int,
|
||||||
|
open_range_factor_1_ago: float,
|
||||||
|
open_range_factor_7_ago: float,
|
||||||
|
max_position: int,
|
||||||
|
stop_loss_points: float = 10, # 新增:止损点数
|
||||||
|
take_profit_points: float = 10,
|
||||||
|
): # 新增:止盈点数
|
||||||
|
"""
|
||||||
|
初始化策略。
|
||||||
|
Args:
|
||||||
|
context: 模拟器实例。
|
||||||
|
symbol (str): 交易合约代码。
|
||||||
|
trade_volume (int): 单笔交易量。
|
||||||
|
open_range_factor_1_ago (float): 前1根K线Range的权重因子,用于从Open价向下偏移。
|
||||||
|
open_range_factor_7_ago (float): 前7根K线Range的权重因子,用于从Open价向下偏移。
|
||||||
|
max_position (int): 最大持仓量(此处为1,因为只允许一个持仓)。
|
||||||
|
stop_loss_points (float): 止损点数(例如,亏损达到此点数则止损)。
|
||||||
|
take_profit_points (float): 止盈点数(例如,盈利达到此点数则止盈)。
|
||||||
|
"""
|
||||||
|
super().__init__(context, symbol, enable_log)
|
||||||
|
self.trade_volume = trade_volume
|
||||||
|
self.open_range_factor_1_ago = open_range_factor_1_ago
|
||||||
|
self.open_range_factor_7_ago = open_range_factor_7_ago
|
||||||
|
self.max_position = max_position # 理论上这里应为1
|
||||||
|
self.stop_loss_points = stop_loss_points
|
||||||
|
self.take_profit_points = take_profit_points
|
||||||
|
|
||||||
|
self.order_id_counter = 0
|
||||||
|
self.last_buy_price = 0
|
||||||
|
self.last_sell_price = 0
|
||||||
|
|
||||||
|
bar_history: deque[Bar] = deque(maxlen=10)
|
||||||
|
self._last_order_id: Optional[str] = None # 用于跟踪上一根K线发出的订单ID
|
||||||
|
|
||||||
|
self.log(
|
||||||
|
f"策略初始化: symbol={self.symbol}, trade_volume={self.trade_volume}, "
|
||||||
|
f"open_range_factor_1_ago={self.open_range_factor_1_ago}, "
|
||||||
|
f"open_range_factor_7_ago={self.open_range_factor_7_ago}, "
|
||||||
|
f"max_position={self.max_position}, "
|
||||||
|
f"止损点={self.stop_loss_points}, 止盈点={self.take_profit_points}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def on_open_bar(self, bar: Bar, next_bar_open: Optional[float] = None):
|
||||||
|
"""
|
||||||
|
每当新的K线数据到来时调用。
|
||||||
|
Args:
|
||||||
|
bar (Bar): 当前的K线数据对象。
|
||||||
|
next_bar_open (Optional[float]): 下一根K线的开盘价,此处策略未使用。
|
||||||
|
"""
|
||||||
|
current_datetime = bar.datetime # 获取当前K线时间
|
||||||
|
self.symbol = bar.symbol
|
||||||
|
|
||||||
|
# --- 1. 撤销上一根K线未成交的订单 ---
|
||||||
|
# 检查是否记录了上一笔订单ID,并且该订单仍然在待处理列表中
|
||||||
|
if self._last_order_id:
|
||||||
|
pending_orders = self.get_pending_orders()
|
||||||
|
if self._last_order_id in pending_orders:
|
||||||
|
success = self.cancel_order(
|
||||||
|
self._last_order_id
|
||||||
|
) # 直接调用基类的取消方法
|
||||||
|
if success:
|
||||||
|
self.log(
|
||||||
|
f"[{current_datetime}] 策略: 成功撤销上一根K线未成交订单 {self._last_order_id}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.log(
|
||||||
|
f"[{current_datetime}] 策略: 尝试撤销订单 {self._last_order_id} 失败(可能已成交或不存在)"
|
||||||
|
)
|
||||||
|
# 无论撤销成功与否,既然我们尝试了撤销,就清除记录
|
||||||
|
self._last_order_id = None
|
||||||
|
# else:
|
||||||
|
# self.log(f"[{current_datetime}] 策略: 无上一根K线未成交订单需要撤销。")
|
||||||
|
|
||||||
|
# 2. 更新K线历史
|
||||||
|
trade_volume = self.trade_volume
|
||||||
|
|
||||||
|
# 获取当前持仓和未决订单(在取消之后获取,确保是最新的状态)
|
||||||
|
current_positions = self.get_current_positions()
|
||||||
|
current_pos_volume = current_positions.get(self.symbol, 0)
|
||||||
|
pending_orders_after_cancel = (
|
||||||
|
self.get_pending_orders()
|
||||||
|
) # 再次获取,此时应已取消旧订单
|
||||||
|
|
||||||
|
# --- 3. 平仓逻辑 (止损/止盈) ---
|
||||||
|
# 只有当有持仓时才考虑平仓
|
||||||
|
if current_pos_volume < 0: # 假设只做多,所以持仓量 > 0
|
||||||
|
avg_entry_price = self.get_average_position_price(self.symbol)
|
||||||
|
if avg_entry_price is not None:
|
||||||
|
pnl_per_unit = (
|
||||||
|
avg_entry_price - bar.open
|
||||||
|
) # 当前浮动盈亏(以收盘价计算)
|
||||||
|
|
||||||
|
# 止盈条件
|
||||||
|
if pnl_per_unit >= self.take_profit_points:
|
||||||
|
self.log(
|
||||||
|
f"[{current_datetime}] 止盈信号 - PnL per unit: {pnl_per_unit:.2f}, 目标: {self.take_profit_points:.2f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
order_id = f"{self.symbol}_BUY_{bar.datetime.strftime('%Y%m%d%H%M%S')}_{self.order_id_counter}"
|
||||||
|
self.order_id_counter += 1
|
||||||
|
|
||||||
|
# 创建一个限价多单
|
||||||
|
order = Order(
|
||||||
|
id=order_id,
|
||||||
|
symbol=self.symbol,
|
||||||
|
direction="CLOSE_SHORT",
|
||||||
|
volume=trade_volume,
|
||||||
|
price_type="MARKET",
|
||||||
|
# limit_price=limit_price,
|
||||||
|
submitted_time=bar.datetime,
|
||||||
|
offset="CLOSE",
|
||||||
|
)
|
||||||
|
trade = self.send_order(order)
|
||||||
|
return # 平仓后本K线不再进行开仓判断
|
||||||
|
|
||||||
|
# 止损条件
|
||||||
|
elif pnl_per_unit <= -self.stop_loss_points:
|
||||||
|
self.log(
|
||||||
|
f"[{current_datetime}] 止损信号 - PnL per unit: {pnl_per_unit:.2f}, 目标: {-self.stop_loss_points:.2f}"
|
||||||
|
)
|
||||||
|
# 发送市价卖出订单平仓,确保立即成交
|
||||||
|
order_id = f"{self.symbol}_BUY_{bar.datetime.strftime('%Y%m%d%H%M%S')}_{self.order_id_counter}"
|
||||||
|
self.order_id_counter += 1
|
||||||
|
|
||||||
|
# 创建一个限价多单
|
||||||
|
order = Order(
|
||||||
|
id=order_id,
|
||||||
|
symbol=self.symbol,
|
||||||
|
direction="CLOSE_SHORT",
|
||||||
|
volume=trade_volume,
|
||||||
|
price_type="MARKET",
|
||||||
|
# limit_price=limit_price,
|
||||||
|
submitted_time=bar.datetime,
|
||||||
|
offset="CLOSE",
|
||||||
|
)
|
||||||
|
trade = self.send_order(order)
|
||||||
|
return # 平仓后本K线不再进行开仓判断
|
||||||
|
|
||||||
|
# --- 4. 开仓逻辑 (只考虑做多 BUY 方向) ---
|
||||||
|
# 只有在没有持仓 (current_pos_volume == 0) 且没有待处理订单 (not pending_orders_after_cancel)
|
||||||
|
# 且K线历史足够长时才考虑开仓
|
||||||
|
bar_history = self.get_bar_history()
|
||||||
|
if current_pos_volume == 0 and len(bar_history) > 10:
|
||||||
|
|
||||||
|
# 获取前1根K线 (倒数第二根) 和前7根K线 (队列中最老的一根)
|
||||||
|
bar_1_ago = bar_history[-2]
|
||||||
|
bar_7_ago = bar_history[-8]
|
||||||
|
|
||||||
|
# 计算历史 K 线的 Range
|
||||||
|
range_1_ago = bar_1_ago.high - bar_1_ago.low
|
||||||
|
range_7_ago = bar_7_ago.high - bar_7_ago.low
|
||||||
|
|
||||||
|
# 根据策略逻辑计算目标买入价格
|
||||||
|
# 目标买入价 = 当前K线Open - (前1根Range * 因子1 + 前7根Range * 因子2)
|
||||||
|
self.log(
|
||||||
|
bar.open,
|
||||||
|
range_1_ago * self.open_range_factor_1_ago,
|
||||||
|
range_7_ago * self.open_range_factor_7_ago,
|
||||||
|
)
|
||||||
|
target_buy_price = bar.open + (
|
||||||
|
range_1_ago * self.open_range_factor_1_ago
|
||||||
|
+ range_7_ago * self.open_range_factor_7_ago
|
||||||
|
)
|
||||||
|
|
||||||
|
# 确保目标买入价格有效,例如不能是负数
|
||||||
|
target_buy_price = max(0.01, target_buy_price)
|
||||||
|
|
||||||
|
self.log(
|
||||||
|
f"[{current_datetime}] 开多仓信号 - 当前Open={bar.open:.2f}, "
|
||||||
|
f"前1Range={range_1_ago:.2f}, 前7Range={range_7_ago:.2f}, "
|
||||||
|
f"计算目标买入价={target_buy_price:.2f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
order_id = f"{self.symbol}_BUY_{bar.datetime.strftime('%Y%m%d%H%M%S')}_{self.order_id_counter}"
|
||||||
|
self.order_id_counter += 1
|
||||||
|
|
||||||
|
# 创建一个限价多单
|
||||||
|
order = Order(
|
||||||
|
id=order_id,
|
||||||
|
symbol=self.symbol,
|
||||||
|
direction="SELL",
|
||||||
|
volume=trade_volume,
|
||||||
|
price_type="LIMIT",
|
||||||
|
limit_price=target_buy_price,
|
||||||
|
submitted_time=bar.datetime,
|
||||||
|
)
|
||||||
|
new_order = self.send_order(order)
|
||||||
|
# 记录下这个订单的ID,以便在下一根K线开始时进行撤销
|
||||||
|
if new_order:
|
||||||
|
self._last_order_id = new_order.id
|
||||||
|
self.log(
|
||||||
|
f"[{current_datetime}] 策略: 发送限价买入订单 {self._last_order_id} @ {target_buy_price:.2f}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.log(f"[{current_datetime}] 策略: 发送订单失败。")
|
||||||
|
|
||||||
|
# else:
|
||||||
|
# self.log(f"[{current_datetime}] 不满足开仓条件:持仓={current_pos_volume}, 待处理订单={len(pending_orders_after_cancel)}, K线历史长度={len(bar_history)}")
|
||||||
|
|
||||||
|
def on_rollover(self, old_symbol: str, new_symbol: str):
|
||||||
|
"""
|
||||||
|
在合约换月时清空历史K线数据和上次订单ID,避免使用旧合约数据进行计算。
|
||||||
|
"""
|
||||||
|
super().on_rollover(old_symbol, new_symbol) # 调用基类方法打印日志
|
||||||
|
# bar_history.clear() # 清空历史K线
|
||||||
|
self._last_order_id = None # 清空上次订单ID,因为旧合约订单已取消
|
||||||
|
|
||||||
|
self.log(f"换月完成,清空历史K线数据和上次订单ID,准备新合约交易。")
|
||||||
|
|
||||||
|
|
||||||
|
class SimpleLimitBuyStrategy(Strategy):
|
||||||
|
"""
|
||||||
|
一个基于当前K线Open、前1根和前7根K线Range计算优势价格进行限价买入的策略。
|
||||||
|
具备以下特点:
|
||||||
|
- 每根K线开始时取消上一根K线未成交的订单。
|
||||||
|
- 最多只能有一个开仓挂单和一个持仓。
|
||||||
|
- 包含简单的止损和止盈逻辑。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
context: Any,
|
||||||
|
symbol: str,
|
||||||
|
enable_log: bool,
|
||||||
|
trade_volume: int,
|
||||||
|
open_range_factor_1_long: float,
|
||||||
|
open_range_factor_7_long: float,
|
||||||
|
open_range_factor_1_short: float,
|
||||||
|
open_range_factor_7_short: float,
|
||||||
|
max_position: int,
|
||||||
|
stop_loss_points: float = 10, # 新增:止损点数
|
||||||
|
take_profit_points: float = 10,
|
||||||
|
): # 新增:止盈点数
|
||||||
|
"""
|
||||||
|
初始化策略。
|
||||||
|
Args:
|
||||||
|
context: 模拟器实例。
|
||||||
|
symbol (str): 交易合约代码。
|
||||||
|
trade_volume (int): 单笔交易量。
|
||||||
|
open_range_factor_1_ago (float): 前1根K线Range的权重因子,用于从Open价向下偏移。
|
||||||
|
open_range_factor_7_ago (float): 前7根K线Range的权重因子,用于从Open价向下偏移。
|
||||||
|
max_position (int): 最大持仓量(此处为1,因为只允许一个持仓)。
|
||||||
|
stop_loss_points (float): 止损点数(例如,亏损达到此点数则止损)。
|
||||||
|
take_profit_points (float): 止盈点数(例如,盈利达到此点数则止盈)。
|
||||||
|
"""
|
||||||
|
super().__init__(context, symbol, enable_log)
|
||||||
|
self.last_buy_price = 0
|
||||||
|
self.last_sell_price = 0
|
||||||
|
self.trade_volume = trade_volume
|
||||||
|
self.open_range_factor_1_long = open_range_factor_1_long
|
||||||
|
self.open_range_factor_7_long = open_range_factor_7_long
|
||||||
|
self.open_range_factor_1_short = open_range_factor_1_short
|
||||||
|
self.open_range_factor_7_short = open_range_factor_7_short
|
||||||
|
self.max_position = max_position # 理论上这里应为1
|
||||||
|
self.stop_loss_points = stop_loss_points
|
||||||
|
self.take_profit_points = take_profit_points
|
||||||
|
|
||||||
|
self.order_id_counter = 0
|
||||||
|
|
||||||
|
self._last_order_id: Optional[str] = None # 用于跟踪上一根K线发出的订单ID
|
||||||
|
|
||||||
|
self.log(
|
||||||
|
f"策略初始化: symbol={self.symbol}, trade_volume={self.trade_volume}, "
|
||||||
|
f"max_position={self.max_position}, "
|
||||||
|
f"止损点={self.stop_loss_points}, 止盈点={self.take_profit_points}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def on_open_bar(self, bar: Bar, next_bar_open: Optional[float] = None):
|
||||||
|
"""
|
||||||
|
每当新的K线数据到来时调用。
|
||||||
|
Args:
|
||||||
|
bar (Bar): 当前的K线数据对象。
|
||||||
|
next_bar_open (Optional[float]): 下一根K线的开盘价,此处策略未使用。
|
||||||
|
"""
|
||||||
|
current_datetime = bar.datetime # 获取当前K线时间
|
||||||
|
self.symbol = bar.symbol
|
||||||
|
|
||||||
|
# --- 1. 撤销上一根K线未成交的订单 ---
|
||||||
|
# 检查是否记录了上一笔订单ID,并且该订单仍然在待处理列表中
|
||||||
|
if self._last_order_id:
|
||||||
|
pending_orders = self.get_pending_orders()
|
||||||
|
# if self._last_order_id in pending_orders:
|
||||||
|
# success = self.cancel_order(self._last_order_id) # 直接调用基类的取消方法
|
||||||
|
# if success:
|
||||||
|
# self.log(f"[{current_datetime}] 策略: 成功撤销上一根K线未成交订单 {self._last_order_id}")
|
||||||
|
# else:
|
||||||
|
# self.log(f"[{current_datetime}] 策略: 尝试撤销订单 {self._last_order_id} 失败(可能已成交或不存在)")
|
||||||
|
# # 无论撤销成功与否,既然我们尝试了撤销,就清除记录
|
||||||
|
# self._last_order_id = None
|
||||||
|
self.cancel_all_pending_orders()
|
||||||
|
# else:
|
||||||
|
# self.log(f"[{current_datetime}] 策略: 无上一根K线未成交订单需要撤销。")
|
||||||
|
|
||||||
|
# 2. 更新K线历史
|
||||||
|
trade_volume = self.trade_volume
|
||||||
|
|
||||||
|
# 获取当前持仓和未决订单(在取消之后获取,确保是最新的状态)
|
||||||
|
current_positions = self.get_current_positions()
|
||||||
|
current_pos_volume = current_positions.get(self.symbol, 0)
|
||||||
|
pending_orders_after_cancel = (
|
||||||
|
self.get_pending_orders()
|
||||||
|
) # 再次获取,此时应已取消旧订单
|
||||||
|
|
||||||
|
# --- 3. 平仓逻辑 (止损/止盈) ---
|
||||||
|
# 只有当有持仓时才考虑平仓
|
||||||
|
self.log(current_positions, self.symbol)
|
||||||
|
if current_pos_volume < 0: # 假设只做多,所以持仓量 > 0
|
||||||
|
avg_entry_price = self.get_average_position_price(self.symbol)
|
||||||
|
avg_entry_price = self.last_sell_price
|
||||||
|
if avg_entry_price is not None:
|
||||||
|
pnl_per_unit = (
|
||||||
|
avg_entry_price - bar.open
|
||||||
|
) # 当前浮动盈亏(以收盘价计算)
|
||||||
|
|
||||||
|
# 止盈条件
|
||||||
|
if pnl_per_unit >= self.take_profit_points:
|
||||||
|
self.log(
|
||||||
|
f"[{current_datetime}] 止盈信号 - PnL per unit: {pnl_per_unit:.2f}, 目标: {self.take_profit_points:.2f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
order_id = f"{self.symbol}_BUY_{bar.datetime.strftime('%Y%m%d%H%M%S')}_{self.order_id_counter}"
|
||||||
|
self.order_id_counter += 1
|
||||||
|
|
||||||
|
# 创建一个限价多单
|
||||||
|
order = Order(
|
||||||
|
id=order_id,
|
||||||
|
symbol=self.symbol,
|
||||||
|
direction="CLOSE_SHORT",
|
||||||
|
volume=trade_volume,
|
||||||
|
price_type="MARKET",
|
||||||
|
# limit_price=limit_price,
|
||||||
|
submitted_time=bar.datetime,
|
||||||
|
offset="CLOSE",
|
||||||
|
)
|
||||||
|
trade = self.send_order(order)
|
||||||
|
return # 平仓后本K线不再进行开仓判断
|
||||||
|
|
||||||
|
# 止损条件
|
||||||
|
elif pnl_per_unit <= -self.stop_loss_points:
|
||||||
|
self.log(
|
||||||
|
f"[{current_datetime}] 止损信号 - PnL per unit: {pnl_per_unit:.2f}, 目标: {-self.stop_loss_points:.2f}"
|
||||||
|
)
|
||||||
|
# 发送市价卖出订单平仓,确保立即成交
|
||||||
|
order_id = f"{self.symbol}_BUY_{bar.datetime.strftime('%Y%m%d%H%M%S')}_{self.order_id_counter}"
|
||||||
|
self.order_id_counter += 1
|
||||||
|
|
||||||
|
# 创建一个限价多单
|
||||||
|
order = Order(
|
||||||
|
id=order_id,
|
||||||
|
symbol=self.symbol,
|
||||||
|
direction="CLOSE_SHORT",
|
||||||
|
volume=trade_volume,
|
||||||
|
price_type="MARKET",
|
||||||
|
# limit_price=limit_price,
|
||||||
|
submitted_time=bar.datetime,
|
||||||
|
offset="CLOSE",
|
||||||
|
)
|
||||||
|
trade = self.send_order(order)
|
||||||
|
return # 平仓后本K线不再进行开仓判断
|
||||||
|
|
||||||
|
if current_pos_volume > 0: # 假设只做多,所以持仓量 > 0
|
||||||
|
avg_entry_price = self.get_average_position_price(self.symbol)
|
||||||
|
avg_entry_price = self.last_buy_price
|
||||||
|
if avg_entry_price is not None:
|
||||||
|
pnl_per_unit = (
|
||||||
|
bar.open - avg_entry_price
|
||||||
|
) # 当前浮动盈亏(以收盘价计算)
|
||||||
|
|
||||||
|
# 止盈条件
|
||||||
|
if pnl_per_unit >= self.take_profit_points:
|
||||||
|
self.log(
|
||||||
|
f"[{current_datetime}] 止盈信号 - PnL per unit: {pnl_per_unit:.2f}, 目标: {self.take_profit_points:.2f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
order_id = f"{self.symbol}_BUY_{bar.datetime.strftime('%Y%m%d%H%M%S')}_{self.order_id_counter}"
|
||||||
|
self.order_id_counter += 1
|
||||||
|
|
||||||
|
# 创建一个限价多单
|
||||||
|
order = Order(
|
||||||
|
id=order_id,
|
||||||
|
symbol=self.symbol,
|
||||||
|
direction="CLOSE_LONG",
|
||||||
|
volume=trade_volume,
|
||||||
|
price_type="MARKET",
|
||||||
|
# limit_price=limit_price,
|
||||||
|
submitted_time=bar.datetime,
|
||||||
|
offset="CLOSE",
|
||||||
|
)
|
||||||
|
trade = self.send_order(order)
|
||||||
|
return # 平仓后本K线不再进行开仓判断
|
||||||
|
|
||||||
|
# 止损条件
|
||||||
|
elif pnl_per_unit <= -self.stop_loss_points:
|
||||||
|
self.log(
|
||||||
|
f"[{current_datetime}] 止损信号 - PnL per unit: {pnl_per_unit:.2f}, 目标: {-self.stop_loss_points:.2f}"
|
||||||
|
)
|
||||||
|
# 发送市价卖出订单平仓,确保立即成交
|
||||||
|
order_id = f"{self.symbol}_BUY_{bar.datetime.strftime('%Y%m%d%H%M%S')}_{self.order_id_counter}"
|
||||||
|
self.order_id_counter += 1
|
||||||
|
|
||||||
|
# 创建一个限价多单
|
||||||
|
order = Order(
|
||||||
|
id=order_id,
|
||||||
|
symbol=self.symbol,
|
||||||
|
direction="CLOSE_LONG",
|
||||||
|
volume=trade_volume,
|
||||||
|
price_type="MARKET",
|
||||||
|
# limit_price=limit_price,
|
||||||
|
submitted_time=bar.datetime,
|
||||||
|
offset="CLOSE",
|
||||||
|
)
|
||||||
|
trade = self.send_order(order)
|
||||||
|
return # 平仓后本K线不再进行开仓判断
|
||||||
|
|
||||||
|
bar_history = self.get_bar_history()
|
||||||
|
if current_pos_volume == 0 and len(bar_history) > 10:
|
||||||
|
|
||||||
|
# 获取前1根K线 (倒数第二根) 和前7根K线 (队列中最老的一根)
|
||||||
|
bar_1_ago = bar_history[-2]
|
||||||
|
bar_7_ago = bar_history[-8]
|
||||||
|
|
||||||
|
print(bar_1_ago, bar_7_ago)
|
||||||
|
|
||||||
|
# 计算历史 K 线的 Range
|
||||||
|
range_1_ago = bar_1_ago.high - bar_1_ago.low
|
||||||
|
range_7_ago = bar_7_ago.high - bar_7_ago.low
|
||||||
|
|
||||||
|
# 根据策略逻辑计算目标买入价格
|
||||||
|
# 目标买入价 = 当前K线Open - (前1根Range * 因子1 + 前7根Range * 因子2)
|
||||||
|
target_buy_price = bar.open + (
|
||||||
|
range_1_ago * self.open_range_factor_1_short
|
||||||
|
+ range_7_ago * self.open_range_factor_7_short
|
||||||
|
)
|
||||||
|
|
||||||
|
# 确保目标买入价格有效,例如不能是负数
|
||||||
|
target_buy_price = max(0.01, target_buy_price)
|
||||||
|
|
||||||
|
order_id = f"{self.symbol}_BUY_{bar.datetime.strftime('%Y%m%d%H%M%S')}_{self.order_id_counter}"
|
||||||
|
self.order_id_counter += 1
|
||||||
|
|
||||||
|
# 创建一个限价多单
|
||||||
|
order = Order(
|
||||||
|
id=order_id,
|
||||||
|
symbol=bar.symbol,
|
||||||
|
direction="SELL",
|
||||||
|
volume=trade_volume,
|
||||||
|
price_type="LIMIT",
|
||||||
|
limit_price=target_buy_price,
|
||||||
|
submitted_time=bar.datetime,
|
||||||
|
)
|
||||||
|
self.last_sell_price = target_buy_price
|
||||||
|
new_order = self.send_order(order)
|
||||||
|
# 记录下这个订单的ID,以便在下一根K线开始时进行撤销
|
||||||
|
if new_order:
|
||||||
|
self._last_order_id = new_order.id
|
||||||
|
self.log(
|
||||||
|
f"[{current_datetime}] 策略: 发送限价SELL订单 {self._last_order_id} @ {target_buy_price:.2f}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.log(f"[{current_datetime}] 策略: 发送订单失败。")
|
||||||
|
|
||||||
|
target_buy_price = bar.open - (
|
||||||
|
range_1_ago * self.open_range_factor_1_long
|
||||||
|
+ range_7_ago * self.open_range_factor_7_long
|
||||||
|
)
|
||||||
|
# 确保目标买入价格有效,例如不能是负数
|
||||||
|
target_buy_price = max(0.01, target_buy_price)
|
||||||
|
|
||||||
|
order_id = f"{self.symbol}_BUY_{bar.datetime.strftime('%Y%m%d%H%M%S')}_{self.order_id_counter}"
|
||||||
|
self.order_id_counter += 1
|
||||||
|
|
||||||
|
# 创建一个限价多单
|
||||||
|
order = Order(
|
||||||
|
id=order_id,
|
||||||
|
symbol=bar.symbol,
|
||||||
|
direction="BUY",
|
||||||
|
volume=trade_volume,
|
||||||
|
price_type="LIMIT",
|
||||||
|
limit_price=target_buy_price,
|
||||||
|
submitted_time=bar.datetime,
|
||||||
|
)
|
||||||
|
new_order = self.send_order(order)
|
||||||
|
self.last_buy_price = target_buy_price
|
||||||
|
# 记录下这个订单的ID,以便在下一根K线开始时进行撤销
|
||||||
|
if new_order:
|
||||||
|
self._last_order_id = new_order.id
|
||||||
|
self.log(
|
||||||
|
f"[{current_datetime}] 策略: 发送限价BUY订单 {self._last_order_id} @ {target_buy_price:.2f}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.log(f"[{current_datetime}] 策略: 发送订单失败。")
|
||||||
|
|
||||||
|
# else:
|
||||||
|
# self.log(f"[{current_datetime}] 不满足开仓条件:持仓={current_pos_volume}, 待处理订单={len(pending_orders_after_cancel)}, K线历史长度={len(bar_history)}")
|
||||||
|
|
||||||
|
|
||||||
|
def on_close_bar(self, bar):
|
||||||
|
self.log('on close bar!')
|
||||||
|
self.log(self.get_pending_orders())
|
||||||
|
self.cancel_all_pending_orders()
|
||||||
|
|
||||||
|
def on_rollover(self, old_symbol: str, new_symbol: str):
|
||||||
|
"""
|
||||||
|
在合约换月时清空历史K线数据和上次订单ID,避免使用旧合约数据进行计算。
|
||||||
|
"""
|
||||||
|
super().on_rollover(old_symbol, new_symbol) # 调用基类方法打印日志
|
||||||
|
self._last_order_id = None # 清空上次订单ID,因为旧合约订单已取消
|
||||||
|
|
||||||
|
self.log(f"换月完成,清空历史K线数据和上次订单ID,准备新合约交易。")
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
import math
|
||||||
from typing import Dict, Any, Optional, List, TYPE_CHECKING
|
from typing import Dict, Any, Optional, List, TYPE_CHECKING
|
||||||
|
|
||||||
# 使用 TYPE_CHECKING 避免循环导入,但保留类型提示
|
# 使用 TYPE_CHECKING 避免循环导入,但保留类型提示
|
||||||
@@ -52,7 +53,7 @@ class Strategy(ABC):
|
|||||||
pass # 默认不执行任何操作,具体策略可覆盖
|
pass # 默认不执行任何操作,具体策略可覆盖
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def on_bar(self, bar: "Bar"):
|
def on_open_bar(self, bar: "Bar"):
|
||||||
"""
|
"""
|
||||||
每当新的K线数据到来时调用此方法。
|
每当新的K线数据到来时调用此方法。
|
||||||
Args:
|
Args:
|
||||||
@@ -61,6 +62,15 @@ class Strategy(ABC):
|
|||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def on_close_bar(self, bar: "Bar"):
|
||||||
|
"""
|
||||||
|
每当新的K线数据到来时调用此方法。
|
||||||
|
Args:
|
||||||
|
bar (Bar): 当前的K线数据对象。
|
||||||
|
next_bar_close (Optional[float]): 下一根K线的开盘价,如果存在的话。
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
# --- 新增/修改的辅助方法 ---
|
# --- 新增/修改的辅助方法 ---
|
||||||
|
|
||||||
def send_order(self, order: "Order") -> Optional[Order]:
|
def send_order(self, order: "Order") -> Optional[Order]:
|
||||||
@@ -71,6 +81,21 @@ class Strategy(ABC):
|
|||||||
if self.context.is_rollover_bar:
|
if self.context.is_rollover_bar:
|
||||||
self.log(f"当前是换月K线,禁止开仓订单")
|
self.log(f"当前是换月K线,禁止开仓订单")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
if order.price_type == 'LIMIT':
|
||||||
|
limit_price = order.limit_price
|
||||||
|
if order.direction in ["BUY", "CLOSE_SHORT"]:
|
||||||
|
# 买入限价单(或平空),希望以更低或相等的价格成交,
|
||||||
|
# 所以向下取整,确保挂单价格不高于预期。
|
||||||
|
# 例如:价格100.3,tick_size=1 -> math.floor(100.3) = 100
|
||||||
|
# 价格100.8,tick_size=1 -> math.floor(100.8) = 100
|
||||||
|
order.limit_price = math.floor(limit_price)
|
||||||
|
elif order.direction in ["SELL", "CLOSE_LONG"]:
|
||||||
|
# 卖出限价单(或平多),希望以更高或相等的价格成交,
|
||||||
|
# 所以向上取整,确保挂单价格不低于预期。
|
||||||
|
# 例如:价格100.3,tick_size=1 -> math.ceil(100.3) = 101
|
||||||
|
# 价格100.8,tick_size=1 -> math.ceil(100.8) = 101
|
||||||
|
order.limit_price = math.ceil(limit_price)
|
||||||
return self.context.send_order(order)
|
return self.context.send_order(order)
|
||||||
|
|
||||||
def cancel_order(self, order_id: str) -> bool:
|
def cancel_order(self, order_id: str) -> bool:
|
||||||
@@ -96,23 +121,23 @@ class Strategy(ABC):
|
|||||||
|
|
||||||
def get_current_positions(self) -> Dict[str, int]:
|
def get_current_positions(self) -> Dict[str, int]:
|
||||||
"""获取所有当前持仓 (可能包含多个合约)。"""
|
"""获取所有当前持仓 (可能包含多个合约)。"""
|
||||||
return self.context._simulator.get_current_positions()
|
return self.context.get_current_positions()
|
||||||
|
|
||||||
def get_pending_orders(self) -> Dict[str, "Order"]:
|
def get_pending_orders(self) -> Dict[str, "Order"]:
|
||||||
"""获取所有当前待处理订单的副本 (可能包含多个合约)。"""
|
"""获取所有当前待处理订单的副本 (可能包含多个合约)。"""
|
||||||
return self.context._simulator.get_pending_orders()
|
return self.context.get_pending_orders()
|
||||||
|
|
||||||
def get_average_position_price(self, symbol: str) -> Optional[float]:
|
def get_average_position_price(self, symbol: str) -> Optional[float]:
|
||||||
"""获取指定合约的平均持仓成本。"""
|
"""获取指定合约的平均持仓成本。"""
|
||||||
return self.context._simulator.get_average_position_price(symbol)
|
return self.context.get_average_position_price(symbol)
|
||||||
|
|
||||||
def get_account_cash(self) -> float:
|
def get_account_cash(self) -> float:
|
||||||
"""获取当前账户现金余额。"""
|
"""获取当前账户现金余额。"""
|
||||||
return self.context._simulator.cash
|
return self.context.cash
|
||||||
|
|
||||||
def get_current_time(self) -> datetime:
|
def get_current_time(self) -> datetime:
|
||||||
"""获取模拟器当前时间。"""
|
"""获取模拟器当前时间。"""
|
||||||
return self.context._simulator.get_current_time()
|
return self.context.get_current_time()
|
||||||
|
|
||||||
def log(self, *args: Any, **kwargs: Any):
|
def log(self, *args: Any, **kwargs: Any):
|
||||||
"""
|
"""
|
||||||
@@ -123,7 +148,7 @@ class Strategy(ABC):
|
|||||||
if self.enable_log:
|
if self.enable_log:
|
||||||
# 尝试获取当前模拟时间,如果模拟器或时间不可用,则跳过时间前缀
|
# 尝试获取当前模拟时间,如果模拟器或时间不可用,则跳过时间前缀
|
||||||
try:
|
try:
|
||||||
current_time_str = self.context._simulator.get_current_time().strftime(
|
current_time_str = self.context.get_current_time().strftime(
|
||||||
"%Y-%m-%d %H:%M:%S"
|
"%Y-%m-%d %H:%M:%S"
|
||||||
)
|
)
|
||||||
time_prefix = f"[{current_time_str}] "
|
time_prefix = f"[{current_time_str}] "
|
||||||
@@ -151,3 +176,6 @@ class Strategy(ABC):
|
|||||||
self.log(f"合约换月事件: 从 {old_symbol} 切换到 {new_symbol}")
|
self.log(f"合约换月事件: 从 {old_symbol} 切换到 {new_symbol}")
|
||||||
# 默认实现可以为空,子类根据需要重写
|
# 默认实现可以为空,子类根据需要重写
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def get_bar_history(self):
|
||||||
|
return self.context.get_bar_history()
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ from .base_strategy import Strategy
|
|||||||
from ..core_data import Bar, Order, Trade
|
from ..core_data import Bar, Order, Trade
|
||||||
|
|
||||||
|
|
||||||
class SimpleLimitBuyStrategy(Strategy):
|
class TestStrategy(Strategy):
|
||||||
"""
|
"""
|
||||||
一个简单的限价买入策略:
|
一个简单的限价买入策略:
|
||||||
在每根Bar线上,如果当前没有持仓,且没有待处理的买入订单,则尝试下一个限价多单。
|
在每根Bar线上,如果当前没有持仓,且没有待处理的买入订单,则尝试下一个限价多单。
|
||||||
|
|||||||
202
src/tqsdk_context.py
Normal file
202
src/tqsdk_context.py
Normal file
@@ -0,0 +1,202 @@
|
|||||||
|
# filename: tqsdk_context.py
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Optional, Any, Dict, List, Literal, Deque, TYPE_CHECKING
|
||||||
|
from collections import deque
|
||||||
|
|
||||||
|
# 导入你提供的 core_data 中的类型
|
||||||
|
from src.core_data import Bar, Order, Trade, PortfolioSnapshot # 确保此路径正确,如果core_data不在同级目录,需要调整
|
||||||
|
|
||||||
|
# 导入 Tqsdk 的核心类型
|
||||||
|
import tqsdk
|
||||||
|
from tqsdk import TqApi, TqAccount, tafunc
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
# 使用 TYPE_CHECKING 避免循环导入,只在类型检查时导入 TqsdkEngine
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from src.tqsdk_engine import TqsdkEngine # 假设 TqsdkEngine 在 tqsdk_engine.py 中
|
||||||
|
|
||||||
|
|
||||||
|
class TqsdkContext:
|
||||||
|
"""
|
||||||
|
Tqsdk 回测上下文,适配原有 BacktestContext 接口。
|
||||||
|
策略通过此上下文与 Tqsdk 进行交互。
|
||||||
|
"""
|
||||||
|
def __init__(self, api: TqApi):
|
||||||
|
"""
|
||||||
|
初始化 Tqsdk 回测上下文。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
api (TqApi): Tqsdk 的 TqApi 实例。
|
||||||
|
"""
|
||||||
|
self._api = api
|
||||||
|
self._current_bar: Optional[Bar] = None
|
||||||
|
self._engine: Optional['TqsdkEngine'] = None # 添加对引擎的引用,用于访问其状态或触发事件
|
||||||
|
|
||||||
|
# 用于缓存 Tqsdk 的 K 线序列,避免每次都 get_kline_serial
|
||||||
|
self._kline_serial: Dict[str, object] = {}
|
||||||
|
|
||||||
|
# 订单/取消请求队列,TqsdkEngine 会在异步循环中处理它们
|
||||||
|
self.order_queue: Deque[Order] = deque()
|
||||||
|
self.cancel_queue: Deque[str] = deque() # 存储 order_id
|
||||||
|
|
||||||
|
print("TqsdkContext: 初始化完成。")
|
||||||
|
|
||||||
|
def set_current_bar(self, bar: Bar):
|
||||||
|
"""
|
||||||
|
设置当前正在处理的 K 线数据。
|
||||||
|
由 TqsdkEngine 调用。
|
||||||
|
"""
|
||||||
|
self._current_bar = bar
|
||||||
|
|
||||||
|
def get_current_bar(self) -> Optional[Bar]:
|
||||||
|
"""
|
||||||
|
获取当前正在处理的 K 线数据。
|
||||||
|
策略可以通过此方法获取最新 K 线。
|
||||||
|
"""
|
||||||
|
return self._current_bar
|
||||||
|
|
||||||
|
def get_kline_data(self, symbol: str, duration_seconds: int, data_length: int = 10):
|
||||||
|
"""
|
||||||
|
获取指定合约的 K 线数据。
|
||||||
|
返回 Tqsdk 的 DataFrame 格式 K 线序列(TqKLine 对象),可以直接用于计算指标。
|
||||||
|
如果需要转换为你自己的 Bar 对象列表,则需要在此方法内部进行转换。
|
||||||
|
"""
|
||||||
|
if symbol not in self._kline_serial:
|
||||||
|
# 这里的 get_kline_serial 并不是实时获取,而是在 TqApi 启动时就已经加载
|
||||||
|
# 所以在 Context 中直接调用是安全的,TqApi 会返回已加载的数据引用
|
||||||
|
self._kline_serial[symbol] = self._api.get_kline_serial(symbol, duration_seconds, data_length=data_length)
|
||||||
|
return self._kline_serial[symbol]
|
||||||
|
|
||||||
|
def get_current_time(self) -> datetime:
|
||||||
|
"""
|
||||||
|
获取当前模拟时间(Tqsdk 的数据时间)。
|
||||||
|
"""
|
||||||
|
# Tqsdk 的 get_tick_timestamp() 返回微秒时间戳
|
||||||
|
return self.get_current_bar().datetime
|
||||||
|
|
||||||
|
def get_current_positions(self) -> Dict[str, int]:
|
||||||
|
"""
|
||||||
|
获取当前所有持仓。返回 {symbol: quantity} 的字典,quantity 为净持仓量(多头-空头)。
|
||||||
|
"""
|
||||||
|
tq_positions: Dict[str] = self._api.get_position()
|
||||||
|
converted_positions: Dict[str, int] = {}
|
||||||
|
for symbol, pos in tq_positions.items():
|
||||||
|
net_pos = pos.pos_long - pos.pos_short
|
||||||
|
if net_pos != 0:
|
||||||
|
converted_positions[symbol] = net_pos
|
||||||
|
return converted_positions
|
||||||
|
|
||||||
|
def get_pending_orders(self) -> Dict[str, Order]:
|
||||||
|
"""
|
||||||
|
获取当前所有待处理(未成交)订单。
|
||||||
|
返回 {order_id: Order} 的字典。
|
||||||
|
"""
|
||||||
|
tq_orders: Dict[str] = self._api.get_order()
|
||||||
|
pending_orders: Dict[str, Order] = {}
|
||||||
|
for order_id, tq_order in tq_orders.items():
|
||||||
|
if tq_order.status == "ALIVE": # 正在进行中的订单
|
||||||
|
# 将 TqOrder 转换为你自己的 core_data.Order 类型
|
||||||
|
# 注意:core_data.Order 的 direction 有 "CLOSE_LONG", "CLOSE_SHORT" 等,需要映射
|
||||||
|
# Tqsdk 的 direction 只有 "BUY", "SELL"
|
||||||
|
# Tqsdk 的 offset 决定了是开仓还是平仓
|
||||||
|
core_direction: Literal["BUY", "SELL", "CLOSE_LONG", "CLOSE_SHORT"]
|
||||||
|
if tq_order.offset == "OPEN":
|
||||||
|
core_direction = tq_order.direction # 开仓时方向直接对应买卖
|
||||||
|
elif tq_order.offset in ["CLOSE", "CLOSETODAY", "CLOSEYESTERDAY"]:
|
||||||
|
# 平仓时,买入平空,卖出平多
|
||||||
|
core_direction = "CLOSE_SHORT" if tq_order.direction == "BUY" else "CLOSE_LONG"
|
||||||
|
else: # 默认为 BUY/SELL
|
||||||
|
core_direction = tq_order.direction
|
||||||
|
|
||||||
|
|
||||||
|
converted_order = Order(
|
||||||
|
id=tq_order.order_id, # 将 Tqsdk 的 order_id 赋值给你的 Order 类的 id
|
||||||
|
symbol=tq_order.exchange_id + "." + tq_order.instrument_id, # 例如 "SHFE.rb2401"
|
||||||
|
direction=core_direction,
|
||||||
|
volume=tq_order.volume_orign,
|
||||||
|
price_type="LIMIT" if tq_order.limit_price is not None else "MARKET", # Tqsdk 市价单类型为 "ANY"
|
||||||
|
limit_price=tq_order.limit_price,
|
||||||
|
offset=tq_order.offset, # Tqsdk 原生 offset
|
||||||
|
# order_id=tq_order.order_id, # 存储 Tqsdk 的 order_id
|
||||||
|
submitted_time=pd.to_datetime(tq_order.insert_date_time, unit="ns", utc=True),
|
||||||
|
# status=tq_order.status # 保持 Tqsdk 的状态字符串
|
||||||
|
)
|
||||||
|
pending_orders[order_id] = converted_order
|
||||||
|
return pending_orders
|
||||||
|
|
||||||
|
def get_account_cash(self) -> float:
|
||||||
|
"""
|
||||||
|
获取当前可用现金。
|
||||||
|
"""
|
||||||
|
account: TqAccount = self._api.get_account()
|
||||||
|
return account.available_cash if account else 0.0
|
||||||
|
|
||||||
|
def get_average_position_price(self, symbol: str) -> Optional[float]:
|
||||||
|
"""
|
||||||
|
获取指定合约的平均持仓成本。
|
||||||
|
注意: Tqsdk 的 TqPosition 对象中包含了 open_price_long 和 open_price_short。
|
||||||
|
这里需要根据多头或空头持仓返回对应的平均成本。
|
||||||
|
"""
|
||||||
|
position = self._api.get_position(symbol)
|
||||||
|
# if position:
|
||||||
|
# return avg_cost
|
||||||
|
if position:
|
||||||
|
if position.pos_long > 0:
|
||||||
|
return position.open_price_long
|
||||||
|
elif position.pos_short > 0:
|
||||||
|
return position.open_price_short
|
||||||
|
return None
|
||||||
|
|
||||||
|
def send_order(self, order: Order) -> Optional[Order]:
|
||||||
|
"""
|
||||||
|
策略通过此方法发送订单。
|
||||||
|
将订单放入队列,等待 TqsdkEngine 在其异步循环中处理。
|
||||||
|
"""
|
||||||
|
# 为订单分配一个临时ID,便于在队列中追踪,实际ID由Tqsdk返回后更新
|
||||||
|
if not order.id: # 使用 order.id 属性
|
||||||
|
order.id = f"LOCAL_{id(order)}_{datetime.now().strftime('%f')}"
|
||||||
|
order.order_id = order.id # 保持 Tqsdk 风格的 order_id 也一致
|
||||||
|
self.order_queue.append(order)
|
||||||
|
print(f"Context: 订单已加入队列: {order}")
|
||||||
|
return order # 返回传入的订单,待引擎更新其状态和ID
|
||||||
|
|
||||||
|
def cancel_order(self, order_id: str) -> bool:
|
||||||
|
"""
|
||||||
|
策略通过此方法取消指定ID的订单。
|
||||||
|
将取消请求放入队列,等待 TqsdkEngine 在其异步循环中处理。
|
||||||
|
"""
|
||||||
|
# 检查订单是否处于待处理状态
|
||||||
|
if order_id in self.get_pending_orders():
|
||||||
|
self.cancel_queue.append(order_id)
|
||||||
|
print(f"Context: 取消订单请求已加入队列: {order_id}")
|
||||||
|
return True
|
||||||
|
print(f"Context: 订单 {order_id} 不在待处理队列中,无法取消。")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def set_engine(self, engine: 'TqsdkEngine'): # 使用 TYPE_CHECKING 中的 TqsdkEngine 类型提示
|
||||||
|
"""
|
||||||
|
设置对 TqsdkEngine 实例的引用。
|
||||||
|
由 TqsdkEngine 在初始化时调用,用于允许 Context 访问 Engine 的状态。
|
||||||
|
"""
|
||||||
|
self._engine = engine
|
||||||
|
print("TqsdkContext: 已设置引擎引用。")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_rollover_bar(self) -> bool:
|
||||||
|
"""
|
||||||
|
属性:判断当前 K 线是否为换月 K 线(即新合约的第一根 K 线)。
|
||||||
|
用于在换月时禁止策略开仓。
|
||||||
|
Tqsdk 的回测模式下,通常通过主力连续合约或多合约同时回测来处理换月。
|
||||||
|
此处为适配原有接口的简化实现。如果你需要 Tqsdk 的换月逻辑,
|
||||||
|
可能需要在 TqsdkEngine 中实现更复杂的判断,并通过 Context 暴露此状态。
|
||||||
|
对于 Tqsdk 的主力连续合约,通常不需要策略层面关心具体的换月 K 线。
|
||||||
|
"""
|
||||||
|
# 如果引擎设置了 is_rollover_bar 属性,则使用引擎的判断
|
||||||
|
if self._engine and hasattr(self._engine, 'is_rollover_bar'):
|
||||||
|
return self._engine.is_rollover_bar
|
||||||
|
return False # 默认返回 False
|
||||||
|
|
||||||
|
def get_bar_history(self):
|
||||||
|
return self._engine.get_bar_history()
|
||||||
491
src/tqsdk_engine.py
Normal file
491
src/tqsdk_engine.py
Normal file
@@ -0,0 +1,491 @@
|
|||||||
|
# filename: tqsdk_engine.py
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from datetime import date, datetime, timedelta
|
||||||
|
from typing import Literal, Type, Dict, Any, List, Optional
|
||||||
|
import pandas as pd
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
# 导入你提供的 core_data 中的类型
|
||||||
|
from src.core_data import Bar, Order, Trade, PortfolioSnapshot
|
||||||
|
|
||||||
|
# 导入 Tqsdk 的核心类型
|
||||||
|
import tqsdk
|
||||||
|
from tqsdk import (
|
||||||
|
TqApi,
|
||||||
|
TqAccount,
|
||||||
|
tafunc,
|
||||||
|
TqSim,
|
||||||
|
TqBacktest,
|
||||||
|
TqAuth,
|
||||||
|
TargetPosTask,
|
||||||
|
BacktestFinished,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 导入 TqsdkContext 和 BaseStrategy
|
||||||
|
from src.tqsdk_context import TqsdkContext
|
||||||
|
from src.strategies.base_strategy import Strategy # 假设你的策略基类在此路径
|
||||||
|
|
||||||
|
BEIJING_TZ = "Asia/Shanghai"
|
||||||
|
|
||||||
|
|
||||||
|
class TqsdkEngine:
|
||||||
|
"""
|
||||||
|
Tqsdk 回测引擎:协调 Tqsdk 数据流、策略执行、订单模拟和结果记录。
|
||||||
|
替代原有的 BacktestEngine。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
strategy_class: Type[Strategy],
|
||||||
|
strategy_params: Dict[str, Any],
|
||||||
|
api: TqApi,
|
||||||
|
roll_over_mode: bool = False, # 是否开启换月模式检测
|
||||||
|
symbol: str = None,
|
||||||
|
start_time: Optional[datetime] = None,
|
||||||
|
end_time: Optional[datetime] = None,
|
||||||
|
duration_seconds: int = 1,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
初始化 Tqsdk 回测引擎。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
strategy_class (Type[Strategy]): 策略类。
|
||||||
|
strategy_params (Dict[str, Any]): 传递给策略的参数字典。
|
||||||
|
data_path (str): 本地 K 线数据文件路径,用于 TqSim 加载。
|
||||||
|
initial_capital (float): 初始资金。
|
||||||
|
slippage_rate (float): 交易滑点率(在 Tqsdk 中通常需要手动实现或通过费用设置)。
|
||||||
|
commission_rate (float): 交易佣金率(在 Tqsdk 中通常需要手动实现或通过费用设置)。
|
||||||
|
roll_over_mode (bool): 是否启用换月检测。
|
||||||
|
start_time (Optional[datetime]): 回测开始时间。
|
||||||
|
end_time (Optional[datetime]): 回测结束时间。
|
||||||
|
"""
|
||||||
|
self.strategy_class = strategy_class
|
||||||
|
self.strategy_params = strategy_params
|
||||||
|
self.roll_over_mode = roll_over_mode
|
||||||
|
self.start_time = start_time
|
||||||
|
self.end_time = end_time
|
||||||
|
|
||||||
|
# Tqsdk API 和模拟器
|
||||||
|
# 这里使用 file_path 参数指定本地数据文件
|
||||||
|
self._api: TqApi = api
|
||||||
|
|
||||||
|
# 从策略参数中获取主symbol,TqsdkContext 需要知道它
|
||||||
|
self.symbol: str = strategy_params.get("symbol")
|
||||||
|
if not self.symbol:
|
||||||
|
raise ValueError("strategy_params 必须包含 'symbol' 字段")
|
||||||
|
|
||||||
|
# 获取 K 线数据(Tqsdk 自动处理)
|
||||||
|
# 这里假设策略所需 K 线周期在 strategy_params 中,否则默认60秒(1分钟K线)
|
||||||
|
self.bar_duration_seconds: int = strategy_params.get("bar_duration_seconds", 60)
|
||||||
|
# self._main_kline_serial = self._api.get_kline_serial(
|
||||||
|
# self.symbol, self.bar_duration_seconds
|
||||||
|
# )
|
||||||
|
|
||||||
|
# 初始化上下文
|
||||||
|
self._context: TqsdkContext = TqsdkContext(api=self._api)
|
||||||
|
# 实例化策略,并将上下文传递给它
|
||||||
|
self._strategy: Strategy = self.strategy_class(
|
||||||
|
context=self._context, **self.strategy_params
|
||||||
|
)
|
||||||
|
self._context.set_engine(
|
||||||
|
self
|
||||||
|
) # 将引擎自身传递给上下文,以便 Context 可以访问引擎属性
|
||||||
|
|
||||||
|
self.portfolio_snapshots: List[PortfolioSnapshot] = []
|
||||||
|
self.trade_history: List[Trade] = []
|
||||||
|
self.all_bars: List[Bar] = [] # 收集所有处理过的Bar
|
||||||
|
|
||||||
|
self.last_processed_bar: Optional[Bar] = None
|
||||||
|
self._is_rollover_bar: bool = False # 换月信号
|
||||||
|
self._last_underlying_symbol = self.symbol # 用于检测主力合约换月
|
||||||
|
|
||||||
|
self.klines = api.get_kline_serial(symbol, duration_seconds)
|
||||||
|
self.klines_1min = api.get_kline_serial(symbol, 60)
|
||||||
|
self.now = None
|
||||||
|
self.quote = None
|
||||||
|
if roll_over_mode:
|
||||||
|
self.quote = api.get_quote(symbol)
|
||||||
|
|
||||||
|
print("TqsdkEngine: 初始化完成。")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_rollover_bar(self) -> bool:
|
||||||
|
"""
|
||||||
|
属性:判断当前 K 线是否为换月 K 线(即检测到主力合约切换)。
|
||||||
|
"""
|
||||||
|
return self._is_rollover_bar
|
||||||
|
|
||||||
|
def _process_queued_requests(self):
|
||||||
|
"""
|
||||||
|
异步处理 Context 中排队的订单和取消请求。
|
||||||
|
"""
|
||||||
|
# 处理订单
|
||||||
|
while self._context.order_queue:
|
||||||
|
order_to_send: Order = self._context.order_queue.popleft()
|
||||||
|
print(f"Engine: 处理订单请求: {order_to_send}")
|
||||||
|
|
||||||
|
# 映射 core_data.Order 到 Tqsdk 的订单参数
|
||||||
|
tqsdk_direction = ""
|
||||||
|
tqsdk_offset = ""
|
||||||
|
|
||||||
|
if order_to_send.direction == "BUY":
|
||||||
|
tqsdk_direction = "BUY"
|
||||||
|
tqsdk_offset = order_to_send.offset or "OPEN" # 默认开仓
|
||||||
|
elif order_to_send.direction == "SELL":
|
||||||
|
tqsdk_direction = "SELL"
|
||||||
|
tqsdk_offset = order_to_send.offset or "OPEN" # 默认开仓
|
||||||
|
elif order_to_send.direction == "CLOSE_LONG":
|
||||||
|
tqsdk_direction = "SELL"
|
||||||
|
tqsdk_offset = order_to_send.offset or "CLOSE" # 平多,默认平仓
|
||||||
|
elif order_to_send.direction == "CLOSE_SHORT":
|
||||||
|
tqsdk_direction = "BUY"
|
||||||
|
tqsdk_offset = order_to_send.offset or "CLOSE" # 平空,默认平仓
|
||||||
|
else:
|
||||||
|
print(f"Engine: 未知订单方向: {order_to_send.direction}")
|
||||||
|
continue # 跳过此订单
|
||||||
|
|
||||||
|
if "SHFE" in order_to_send.symbol:
|
||||||
|
tqsdk_offset = "OPEN"
|
||||||
|
|
||||||
|
try:
|
||||||
|
tq_order = self._api.insert_order(
|
||||||
|
symbol=order_to_send.symbol,
|
||||||
|
direction=tqsdk_direction,
|
||||||
|
offset=tqsdk_offset,
|
||||||
|
volume=order_to_send.volume,
|
||||||
|
# Tqsdk 市价单 limit_price 设为 None,限价单则传递价格
|
||||||
|
limit_price=(
|
||||||
|
order_to_send.limit_price
|
||||||
|
if order_to_send.price_type == "LIMIT"
|
||||||
|
# else self.quote.bid_price1 + (1 if tqsdk_direction == "BUY" else -1)
|
||||||
|
else self.quote.bid_price1 if tqsdk_direction == "SELL" else self.quote.ask_price1
|
||||||
|
),
|
||||||
|
)
|
||||||
|
# 更新原始 Order 对象与 Tqsdk 的订单ID和状态
|
||||||
|
order_to_send.id = tq_order.order_id
|
||||||
|
# order_to_send.order_id = tq_order.order_id
|
||||||
|
# order_to_send.status = tq_order.status
|
||||||
|
order_to_send.submitted_time = pd.to_datetime(
|
||||||
|
tq_order.insert_date_time, unit="ns", utc=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# 等待订单状态更新(成交/撤销/报错)
|
||||||
|
# 在 Tqsdk 中,订单和成交是独立的,通常在 wait_update() 循环中通过 api.is_changing() 检查
|
||||||
|
# 这里为了模拟同步处理,直接等待订单状态最终确定
|
||||||
|
# 注意:实际回测中,不应在这里长时间阻塞,而应在主循环中持续 wait_update
|
||||||
|
# 为了简化适配,这里模拟即时处理,但可能与真实异步行为有差异。
|
||||||
|
# 更健壮的方式是在主循环中通过订单状态回调更新
|
||||||
|
# 这里我们假设订单会很快更新状态,或者在下一个 wait_update() 周期中被检测到
|
||||||
|
self._api.wait_update() # 等待一次更新
|
||||||
|
|
||||||
|
# # 检查最终订单状态和成交
|
||||||
|
# if tq_order.status == "FINISHED":
|
||||||
|
# # 查找对应的成交记录
|
||||||
|
# for trade_id, tq_trade in self._api.get_trade().items():
|
||||||
|
# if tq_trade.order_id == tq_order.order_id and tq_trade.volume > 0: # 确保是实际成交
|
||||||
|
# # 创建 core_data.Trade 对象
|
||||||
|
# trade = Trade(
|
||||||
|
# order_id=tq_trade.order_id,
|
||||||
|
# fill_time=tafunc.get_datetime_from_timestamp(tq_trade.trade_date_time) if tq_trade.trade_date_time else datetime.now(),
|
||||||
|
# symbol=order_to_send.symbol, # 使用 Context 中的 symbol
|
||||||
|
# direction=tq_trade.direction, # 实际成交方向
|
||||||
|
# volume=tq_trade.volume,
|
||||||
|
# price=tq_trade.price,
|
||||||
|
# commission=tq_trade.commission,
|
||||||
|
# cash_after_trade=self._api.get_account().available,
|
||||||
|
# positions_after_trade=self._context.get_current_positions(),
|
||||||
|
# realized_pnl=tq_trade.realized_pnl, # Tqsdk TqTrade 对象有 realized_pnl
|
||||||
|
# is_open_trade=tq_trade.offset == "OPEN",
|
||||||
|
# is_close_trade=tq_trade.offset in ["CLOSE", "CLOSETODAY", "CLOSEYESTERDAY"]
|
||||||
|
# )
|
||||||
|
# self.trade_history.append(trade)
|
||||||
|
# print(f"Engine: 成交记录: {trade}")
|
||||||
|
# break # 找到成交就跳出
|
||||||
|
# order_to_send.status = tq_order.status # 更新最终状态
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Engine: 发送订单 {order_to_send.id} 失败: {e}")
|
||||||
|
# order_to_send.status = "ERROR"
|
||||||
|
|
||||||
|
# 处理取消请求
|
||||||
|
while self._context.cancel_queue:
|
||||||
|
order_id_to_cancel = self._context.cancel_queue.popleft()
|
||||||
|
print(f"Engine: 处理取消请求: {order_id_to_cancel}")
|
||||||
|
tq_order_to_cancel = self._api.get_order(order_id_to_cancel)
|
||||||
|
if tq_order_to_cancel and tq_order_to_cancel.status == "ALIVE":
|
||||||
|
try:
|
||||||
|
self._api.cancel_order(tq_order_to_cancel)
|
||||||
|
self._api.wait_update() # 等待取消确认
|
||||||
|
print(
|
||||||
|
f"Engine: 订单 {order_id_to_cancel} 已尝试取消。当前状态: {tq_order_to_cancel.status}"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Engine: 取消订单 {order_id_to_cancel} 失败: {e}")
|
||||||
|
else:
|
||||||
|
print(
|
||||||
|
f"Engine: 订单 {order_id_to_cancel} 不存在或已非活动状态,无法取消。"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _record_portfolio_snapshot(self, current_time: datetime):
|
||||||
|
"""
|
||||||
|
记录当前投资组合的快照。
|
||||||
|
"""
|
||||||
|
account: TqAccount = self._api.get_account()
|
||||||
|
current_positions = self._context.get_current_positions()
|
||||||
|
|
||||||
|
# 计算当前持仓市值
|
||||||
|
total_market_value = 0.0
|
||||||
|
current_prices: Dict[str, float] = {}
|
||||||
|
for symbol, qty in current_positions.items():
|
||||||
|
# 获取当前合约的最新价格
|
||||||
|
quote = self._api.get_quote(symbol)
|
||||||
|
if quote.last_price: # 确保价格是最近的
|
||||||
|
price = quote.last_price
|
||||||
|
current_prices[symbol] = price
|
||||||
|
total_market_value += (
|
||||||
|
price * qty * quote.volume_multiple
|
||||||
|
) # volume_multiple 乘数
|
||||||
|
else:
|
||||||
|
# 如果没有最新价格,使用最近的K线收盘价作为估算
|
||||||
|
# 在实盘或连续回测中,通常会有最新的行情
|
||||||
|
print(f"警告: 未获取到 {symbol} 最新价格,可能影响净值计算。")
|
||||||
|
# 可以尝试从 K 线获取最近价格
|
||||||
|
kline = self._api.get_kline_serial(symbol, self.bar_duration_seconds)
|
||||||
|
if not kline.empty:
|
||||||
|
last_kline = kline.iloc[-2]
|
||||||
|
price = last_kline.close
|
||||||
|
current_prices[symbol] = price
|
||||||
|
total_market_value += (
|
||||||
|
price * qty * self._api.get_instrument(symbol).volume_multiple
|
||||||
|
) # 使用 instrument 的乘数
|
||||||
|
|
||||||
|
total_value = (
|
||||||
|
account.available + account.frozen_margin + total_market_value
|
||||||
|
) # Tqsdk 的 balance 已包含持仓市值和冻结资金
|
||||||
|
# Tqsdk 的 total_profit/balance 已经包含了所有盈亏和资金
|
||||||
|
|
||||||
|
snapshot = PortfolioSnapshot(
|
||||||
|
datetime=current_time,
|
||||||
|
total_value=account.balance, # Tqsdk 的 balance 包含了可用资金、冻结保证金和持仓市值
|
||||||
|
cash=account.available,
|
||||||
|
positions=current_positions,
|
||||||
|
price_at_snapshot=current_prices,
|
||||||
|
)
|
||||||
|
self.portfolio_snapshots.append(snapshot)
|
||||||
|
|
||||||
|
def _close_all_positions_at_end(self):
|
||||||
|
"""
|
||||||
|
回测结束时,平掉所有剩余持仓。
|
||||||
|
"""
|
||||||
|
current_positions = self._context.get_current_positions()
|
||||||
|
if not current_positions:
|
||||||
|
print("回测结束:没有需要平仓的持仓。")
|
||||||
|
return
|
||||||
|
|
||||||
|
print("回测结束:开始平仓所有剩余持仓...")
|
||||||
|
for symbol, qty in current_positions.items():
|
||||||
|
order_direction: Literal["BUY", "SELL"]
|
||||||
|
if qty > 0: # 多头持仓,卖出平仓
|
||||||
|
order_direction = "SELL"
|
||||||
|
else: # 空头持仓,买入平仓
|
||||||
|
order_direction = "BUY"
|
||||||
|
|
||||||
|
TargetPosTask(self._api, symbol).set_target_volume(0)
|
||||||
|
|
||||||
|
# # 使用市价单快速平仓
|
||||||
|
# tq_order = self._api.insert_order(
|
||||||
|
# symbol=symbol,
|
||||||
|
# direction=order_direction,
|
||||||
|
# offset="CLOSE", # 平仓
|
||||||
|
# volume=abs(qty),
|
||||||
|
# limit_price=self
|
||||||
|
# )
|
||||||
|
# print(f"平仓订单已发送: {symbol} {order_direction} {abs(qty)} 手")
|
||||||
|
# 等待订单完成
|
||||||
|
# while tq_order.status == "ALIVE":
|
||||||
|
# self._api.wait_update()
|
||||||
|
|
||||||
|
# if tq_order.status == "FINISHED":
|
||||||
|
# print(f"订单 {tq_order.order_id} 平仓完成。")
|
||||||
|
# else:
|
||||||
|
# print(f"订单 {tq_order.order_id} 平仓失败或未完成,状态: {tq_order.status}")
|
||||||
|
|
||||||
|
def _run_backtest_async(self):
|
||||||
|
"""
|
||||||
|
异步运行回测的主循环。
|
||||||
|
"""
|
||||||
|
print(f"TqsdkEngine: 开始运行回测,从 {self.start_time} 到 {self.end_time}")
|
||||||
|
|
||||||
|
# 初始化策略 (如果策略有 on_init 方法)
|
||||||
|
if hasattr(self._strategy, "on_init"):
|
||||||
|
self._strategy.on_init()
|
||||||
|
|
||||||
|
last_bar_datetime = None
|
||||||
|
|
||||||
|
# 迭代 K 线数据
|
||||||
|
# 使用 self._api.get_kline_serial 获取到的 K 线是 Pandas DataFrame,
|
||||||
|
# 直接迭代其行(Bar)更符合回测逻辑
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
# Tqsdk API 的 wait_update() 确保数据更新
|
||||||
|
self._api.wait_update()
|
||||||
|
|
||||||
|
if self.roll_over_mode and (
|
||||||
|
self._api.is_changing(self.quote, "underlying_symbol")
|
||||||
|
or self._last_underlying_symbol != self.quote.underlying_symbol
|
||||||
|
):
|
||||||
|
self._last_underlying_symbol = self.quote.underlying_symbol
|
||||||
|
|
||||||
|
if self._api.is_changing(self.klines_1min):
|
||||||
|
now_kline = self.klines_1min.iloc[-1]
|
||||||
|
now_dt = pd.to_datetime(now_kline.datetime, unit="ns", utc=True)
|
||||||
|
now_dt = now_dt.tz_convert(BEIJING_TZ)
|
||||||
|
|
||||||
|
if (
|
||||||
|
self.now is not None
|
||||||
|
and self.now.hour != 13
|
||||||
|
and now_dt.hour == 13
|
||||||
|
):
|
||||||
|
self.main()
|
||||||
|
|
||||||
|
self.now = now_dt
|
||||||
|
|
||||||
|
if self._api.is_changing(self.klines):
|
||||||
|
kline_row = self.klines.iloc[-1]
|
||||||
|
kline_dt = pd.to_datetime(kline_row.datetime, unit="ns", utc=True)
|
||||||
|
kline_dt = kline_dt.tz_convert(BEIJING_TZ)
|
||||||
|
if kline_dt.hour == 13 and self.now.hour == 11:
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
self.main()
|
||||||
|
|
||||||
|
except BacktestFinished:
|
||||||
|
# 回测结束时,确保所有排队请求得到处理
|
||||||
|
self._process_queued_requests()
|
||||||
|
|
||||||
|
# 回测结束后,如果需要,平掉所有剩余持仓
|
||||||
|
self._close_all_positions_at_end()
|
||||||
|
|
||||||
|
print("TqsdkEngine: 回测运行完毕。")
|
||||||
|
|
||||||
|
def main(self):
|
||||||
|
kline_row = self.klines.iloc[-1]
|
||||||
|
|
||||||
|
kline_dt = pd.to_datetime(kline_row.datetime, unit="ns", utc=True)
|
||||||
|
kline_dt = kline_dt.tz_convert(BEIJING_TZ)
|
||||||
|
|
||||||
|
if (
|
||||||
|
self.last_processed_bar is None
|
||||||
|
or self.last_processed_bar.datetime != kline_dt
|
||||||
|
):
|
||||||
|
# 创建 core_data.Bar 对象
|
||||||
|
current_bar = Bar(
|
||||||
|
datetime=kline_dt,
|
||||||
|
symbol=self._last_underlying_symbol,
|
||||||
|
open=kline_row.open,
|
||||||
|
high=kline_row.high,
|
||||||
|
low=kline_row.low,
|
||||||
|
close=kline_row.close,
|
||||||
|
volume=kline_row.volume,
|
||||||
|
open_oi=kline_row.open_oi,
|
||||||
|
close_oi=kline_row.close_oi,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 设置当前 Bar 到 Context
|
||||||
|
self._context.set_current_bar(current_bar)
|
||||||
|
|
||||||
|
# Tqsdk 的 is_changing 用于判断数据是否有变化,对于回测遍历 K 线,每次迭代都算作新 Bar
|
||||||
|
# 如果 kline_row.datetime 与上次不同,则认为是新 Bar
|
||||||
|
if (
|
||||||
|
self.roll_over_mode
|
||||||
|
and self.last_processed_bar is not None
|
||||||
|
and self._last_underlying_symbol != self.last_processed_bar.symbol
|
||||||
|
):
|
||||||
|
self._is_rollover_bar = True
|
||||||
|
print(
|
||||||
|
f"TqsdkEngine: 检测到换月信号!从 {self._last_underlying_symbol} 切换到 {self.quote.underlying_symbol}"
|
||||||
|
)
|
||||||
|
self._close_all_positions_at_end()
|
||||||
|
|
||||||
|
self._strategy.cancel_all_pending_orders()
|
||||||
|
|
||||||
|
self._strategy.on_rollover(
|
||||||
|
self.last_processed_bar.symbol, self._last_underlying_symbol
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self._is_rollover_bar = False
|
||||||
|
|
||||||
|
self.all_bars.append(current_bar)
|
||||||
|
self.last_processed_bar = current_bar
|
||||||
|
|
||||||
|
|
||||||
|
# 调用策略的 on_bar 方法
|
||||||
|
self._strategy.on_open_bar(current_bar)
|
||||||
|
|
||||||
|
# 处理订单和取消请求
|
||||||
|
self._process_queued_requests()
|
||||||
|
|
||||||
|
# 记录投资组合快照
|
||||||
|
self._record_portfolio_snapshot(current_bar.datetime)
|
||||||
|
else:
|
||||||
|
# 创建 core_data.Bar 对象
|
||||||
|
current_bar = Bar(
|
||||||
|
datetime=kline_dt,
|
||||||
|
symbol=self._last_underlying_symbol,
|
||||||
|
open=kline_row.open,
|
||||||
|
high=kline_row.high,
|
||||||
|
low=kline_row.low,
|
||||||
|
close=kline_row.close,
|
||||||
|
volume=kline_row.volume,
|
||||||
|
open_oi=kline_row.open_oi,
|
||||||
|
close_oi=kline_row.close_oi,
|
||||||
|
)
|
||||||
|
self.all_bars[-1] = current_bar
|
||||||
|
self.last_processed_bar = current_bar
|
||||||
|
|
||||||
|
# 设置当前 Bar 到 Context
|
||||||
|
self._context.set_current_bar(current_bar)
|
||||||
|
|
||||||
|
# 调用策略的 on_bar 方法
|
||||||
|
self._strategy.on_close_bar(current_bar)
|
||||||
|
|
||||||
|
# 处理订单和取消请求
|
||||||
|
self._process_queued_requests()
|
||||||
|
|
||||||
|
def run_backtest(self):
|
||||||
|
"""
|
||||||
|
同步调用异步回测主循环。
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
self._run_backtest_async()
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("\n回测被用户中断。")
|
||||||
|
finally:
|
||||||
|
self._api.close()
|
||||||
|
print("TqsdkEngine: API 已关闭。")
|
||||||
|
|
||||||
|
def get_backtest_results(self) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
返回回测结果数据,供结果分析模块使用。
|
||||||
|
"""
|
||||||
|
final_portfolio_value = 0.0
|
||||||
|
if self.portfolio_snapshots:
|
||||||
|
final_portfolio_value = self.portfolio_snapshots[-1].total_value
|
||||||
|
# else:
|
||||||
|
# final_portfolio_value = self.initial_capital # 如果没有快照,则净值是初始资金
|
||||||
|
|
||||||
|
# total_return_percentage = (
|
||||||
|
# (final_portfolio_value - self.initial_capital) / self.initial_capital
|
||||||
|
# ) * 100 if self.initial_capital != 0 else 0.0
|
||||||
|
|
||||||
|
return {
|
||||||
|
"portfolio_snapshots": self.portfolio_snapshots,
|
||||||
|
"trade_history": self.trade_history,
|
||||||
|
# "initial_capital": self.initial_capital,
|
||||||
|
"all_bars": self.all_bars,
|
||||||
|
"final_portfolio_value": final_portfolio_value,
|
||||||
|
# "total_return_percentage": total_return_percentage,
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_bar_history(self):
|
||||||
|
return self.all_bars
|
||||||
11330
tqsdk_main.ipynb
Normal file
11330
tqsdk_main.ipynb
Normal file
File diff suppressed because one or more lines are too long
Reference in New Issue
Block a user