Files
NewQuant/src/tqsdk_engine.py

501 lines
20 KiB
Python
Raw Normal View History

2025-06-29 12:03:43 +08:00
# filename: tqsdk_engine.py
import asyncio
2025-11-07 16:26:00 +08:00
import traceback
2025-06-29 12:03:43 +08:00
from datetime import date, datetime, timedelta
from typing import Literal, Type, Dict, Any, List, Optional
import pandas as pd
import uuid
2025-11-07 16:26:00 +08:00
from src.common_utils import generate_strategy_identifier
2025-06-29 12:03:43 +08:00
# 导入你提供的 core_data 中的类型
from src.core_data import Bar, Order, Trade, PortfolioSnapshot
# 导入 Tqsdk 的核心类型
import tqsdk
from tqsdk import (
TqApi,
TqAccount,
tafunc,
TqSim,
TqBacktest,
TqAuth,
TargetPosTask,
BacktestFinished,
)
2025-11-07 16:26:00 +08:00
from src.state_repo import MemoryStateRepository
2025-06-29 12:03:43 +08:00
# 导入 TqsdkContext 和 BaseStrategy
from src.tqsdk_context import TqsdkContext
from src.strategies.base_strategy import Strategy # 假设你的策略基类在此路径
BEIJING_TZ = "Asia/Shanghai"
class TqsdkEngine:
"""
Tqsdk 回测引擎协调 Tqsdk 数据流策略执行订单模拟和结果记录
替代原有的 BacktestEngine
"""
def __init__(
2025-11-07 16:26:00 +08:00
self,
strategy_class: Type[Strategy],
strategy_params: Dict[str, Any],
api: TqApi,
roll_over_mode: bool = False, # 是否开启换月模式检测
symbol: str = None,
start_time: Optional[datetime] = None,
end_time: Optional[datetime] = None,
duration_seconds: int = 1,
2025-06-29 12:03:43 +08:00
):
"""
初始化 Tqsdk 回测引擎
Args:
strategy_class (Type[Strategy]): 策略类
strategy_params (Dict[str, Any]): 传递给策略的参数字典
data_path (str): 本地 K 线数据文件路径用于 TqSim 加载
initial_capital (float): 初始资金
slippage_rate (float): 交易滑点率 Tqsdk 中通常需要手动实现或通过费用设置
commission_rate (float): 交易佣金率 Tqsdk 中通常需要手动实现或通过费用设置
roll_over_mode (bool): 是否启用换月检测
start_time (Optional[datetime]): 回测开始时间
end_time (Optional[datetime]): 回测结束时间
"""
self.strategy_class = strategy_class
self.strategy_params = strategy_params
self.roll_over_mode = roll_over_mode
self.start_time = start_time
self.end_time = end_time
# Tqsdk API 和模拟器
# 这里使用 file_path 参数指定本地数据文件
self._api: TqApi = api
# 从策略参数中获取主symbolTqsdkContext 需要知道它
2025-09-16 09:59:38 +08:00
self.symbol: str = symbol.replace('_', '.')
2025-06-29 12:03:43 +08:00
if not self.symbol:
raise ValueError("strategy_params 必须包含 'symbol' 字段")
# 获取 K 线数据Tqsdk 自动处理)
# 这里假设策略所需 K 线周期在 strategy_params 中否则默认60秒1分钟K线
self.bar_duration_seconds: int = strategy_params.get("bar_duration_seconds", 60)
# self._main_kline_serial = self._api.get_kline_serial(
# self.symbol, self.bar_duration_seconds
# )
# 初始化上下文
2025-11-07 16:26:00 +08:00
identifier = generate_strategy_identifier(strategy_class, strategy_params)
self._context: TqsdkContext = TqsdkContext(api=self._api, state_repository=MemoryStateRepository(identifier))
2025-06-29 12:03:43 +08:00
# 实例化策略,并将上下文传递给它
self._strategy: Strategy = self.strategy_class(
context=self._context, **self.strategy_params
)
self._context.set_engine(
self
) # 将引擎自身传递给上下文,以便 Context 可以访问引擎属性
self.portfolio_snapshots: List[PortfolioSnapshot] = []
self.trade_history: List[Trade] = []
self.all_bars: List[Bar] = [] # 收集所有处理过的Bar
2025-07-10 15:07:31 +08:00
self.close_list: List[float] = []
self.open_list: List[float] = []
self.high_list: List[float] = []
self.low_list: List[float] = []
self.volume_list: List[float] = []
2025-06-29 12:03:43 +08:00
self.last_processed_bar: Optional[Bar] = None
self._is_rollover_bar: bool = False # 换月信号
self._last_underlying_symbol = self.symbol # 用于检测主力合约换月
2025-09-16 09:59:38 +08:00
self.klines = api.get_kline_serial(self.symbol, duration_seconds)
self.klines_1min = api.get_kline_serial(self.symbol, 60)
2025-06-29 12:03:43 +08:00
self.now = None
self.quote = None
if roll_over_mode:
2025-09-16 09:59:38 +08:00
self.quote = api.get_quote(self.symbol)
2025-06-29 12:03:43 +08:00
2025-11-07 16:26:00 +08:00
self.target_pos_dict = {}
2025-06-29 12:03:43 +08:00
print("TqsdkEngine: 初始化完成。")
@property
def is_rollover_bar(self) -> bool:
"""
属性判断当前 K 线是否为换月 K 线即检测到主力合约切换
"""
return self._is_rollover_bar
def _process_queued_requests(self):
"""
异步处理 Context 中排队的订单和取消请求
"""
# 处理订单
while self._context.order_queue:
order_to_send: Order = self._context.order_queue.popleft()
print(f"Engine: 处理订单请求: {order_to_send}")
# 映射 core_data.Order 到 Tqsdk 的订单参数
tqsdk_direction = ""
tqsdk_offset = ""
if order_to_send.direction == "BUY":
tqsdk_direction = "BUY"
tqsdk_offset = order_to_send.offset or "OPEN" # 默认开仓
elif order_to_send.direction == "SELL":
tqsdk_direction = "SELL"
tqsdk_offset = order_to_send.offset or "OPEN" # 默认开仓
elif order_to_send.direction == "CLOSE_LONG":
tqsdk_direction = "SELL"
tqsdk_offset = order_to_send.offset or "CLOSE" # 平多,默认平仓
elif order_to_send.direction == "CLOSE_SHORT":
tqsdk_direction = "BUY"
tqsdk_offset = order_to_send.offset or "CLOSE" # 平空,默认平仓
else:
print(f"Engine: 未知订单方向: {order_to_send.direction}")
continue # 跳过此订单
if "SHFE" in order_to_send.symbol:
tqsdk_offset = "OPEN"
2025-11-07 16:26:00 +08:00
if "CLOSE" in order_to_send.direction:
current_positions = self._context.get_current_positions()
current_pos_volume = current_positions.get(order_to_send.symbol, 0)
target_volume = None
if order_to_send.direction == 'CLOSE_LONG':
target_volume = current_pos_volume - order_to_send.volume
elif order_to_send.direction == 'CLOSE_SHORT':
target_volume = current_pos_volume + order_to_send.volume
if target_volume is not None:
if order_to_send.symbol not in self.target_pos_dict:
self.target_pos_dict[order_to_send.symbol] = TargetPosTask(self._api, order_to_send.symbol)
self.target_pos_dict[order_to_send.symbol].set_target_volume(target_volume)
else:
# try:
2025-06-29 12:03:43 +08:00
tq_order = self._api.insert_order(
symbol=order_to_send.symbol,
direction=tqsdk_direction,
offset=tqsdk_offset,
volume=order_to_send.volume,
# Tqsdk 市价单 limit_price 设为 None限价单则传递价格
limit_price=(
order_to_send.limit_price
if order_to_send.price_type == "LIMIT"
# else self.quote.bid_price1 + (1 if tqsdk_direction == "BUY" else -1)
2025-07-15 22:45:51 +08:00
else (
self.quote.bid_price1
if tqsdk_direction == "SELL"
else self.quote.ask_price1
)
2025-06-29 12:03:43 +08:00
),
)
# 更新原始 Order 对象与 Tqsdk 的订单ID和状态
order_to_send.id = tq_order.order_id
# order_to_send.order_id = tq_order.order_id
# order_to_send.status = tq_order.status
order_to_send.submitted_time = pd.to_datetime(
tq_order.insert_date_time, unit="ns", utc=True
)
self._api.wait_update() # 等待一次更新
2025-11-07 16:26:00 +08:00
#
# except Exception as e:
# print(f"Engine: 发送订单 {order_to_send.id} 失败: {e}")
2025-06-29 12:03:43 +08:00
# 处理取消请求
while self._context.cancel_queue:
order_id_to_cancel = self._context.cancel_queue.popleft()
print(f"Engine: 处理取消请求: {order_id_to_cancel}")
tq_order_to_cancel = self._api.get_order(order_id_to_cancel)
if tq_order_to_cancel and tq_order_to_cancel.status == "ALIVE":
try:
self._api.cancel_order(tq_order_to_cancel)
self._api.wait_update() # 等待取消确认
print(
f"Engine: 订单 {order_id_to_cancel} 已尝试取消。当前状态: {tq_order_to_cancel.status}"
)
except Exception as e:
print(f"Engine: 取消订单 {order_id_to_cancel} 失败: {e}")
else:
print(
f"Engine: 订单 {order_id_to_cancel} 不存在或已非活动状态,无法取消。"
)
def _record_portfolio_snapshot(self, current_time: datetime):
"""
记录当前投资组合的快照
"""
account: TqAccount = self._api.get_account()
current_positions = self._context.get_current_positions()
# 计算当前持仓市值
total_market_value = 0.0
current_prices: Dict[str, float] = {}
for symbol, qty in current_positions.items():
# 获取当前合约的最新价格
quote = self._api.get_quote(symbol)
if quote.last_price: # 确保价格是最近的
price = quote.last_price
current_prices[symbol] = price
total_market_value += (
2025-11-07 16:26:00 +08:00
price * qty * quote.volume_multiple
2025-06-29 12:03:43 +08:00
) # volume_multiple 乘数
else:
# 如果没有最新价格使用最近的K线收盘价作为估算
# 在实盘或连续回测中,通常会有最新的行情
print(f"警告: 未获取到 {symbol} 最新价格,可能影响净值计算。")
# 可以尝试从 K 线获取最近价格
kline = self._api.get_kline_serial(symbol, self.bar_duration_seconds)
if not kline.empty:
last_kline = kline.iloc[-2]
price = last_kline.close
current_prices[symbol] = price
total_market_value += (
2025-11-07 16:26:00 +08:00
price * qty * self._api.get_instrument(symbol).volume_multiple
2025-06-29 12:03:43 +08:00
) # 使用 instrument 的乘数
total_value = (
2025-11-07 16:26:00 +08:00
account.available + account.frozen_margin + total_market_value
2025-06-29 12:03:43 +08:00
) # Tqsdk 的 balance 已包含持仓市值和冻结资金
# Tqsdk 的 total_profit/balance 已经包含了所有盈亏和资金
snapshot = PortfolioSnapshot(
datetime=current_time,
total_value=account.balance, # Tqsdk 的 balance 包含了可用资金、冻结保证金和持仓市值
cash=account.available,
positions=current_positions,
price_at_snapshot=current_prices,
)
self.portfolio_snapshots.append(snapshot)
def _close_all_positions_at_end(self):
"""
回测结束时平掉所有剩余持仓
"""
current_positions = self._context.get_current_positions()
if not current_positions:
print("回测结束:没有需要平仓的持仓。")
return
print("回测结束:开始平仓所有剩余持仓...")
for symbol, qty in current_positions.items():
order_direction: Literal["BUY", "SELL"]
if qty > 0: # 多头持仓,卖出平仓
order_direction = "SELL"
else: # 空头持仓,买入平仓
order_direction = "BUY"
TargetPosTask(self._api, symbol).set_target_volume(0)
# # 使用市价单快速平仓
# tq_order = self._api.insert_order(
# symbol=symbol,
# direction=order_direction,
# offset="CLOSE", # 平仓
# volume=abs(qty),
# limit_price=self
# )
# print(f"平仓订单已发送: {symbol} {order_direction} {abs(qty)} 手")
# 等待订单完成
# while tq_order.status == "ALIVE":
# self._api.wait_update()
# if tq_order.status == "FINISHED":
# print(f"订单 {tq_order.order_id} 平仓完成。")
# else:
# print(f"订单 {tq_order.order_id} 平仓失败或未完成,状态: {tq_order.status}")
def _run_backtest_async(self):
"""
异步运行回测的主循环
"""
print(f"TqsdkEngine: 开始运行回测,从 {self.start_time}{self.end_time}")
2025-07-10 15:07:31 +08:00
self._strategy.trading = True
2025-06-29 12:03:43 +08:00
# 初始化策略 (如果策略有 on_init 方法)
if hasattr(self._strategy, "on_init"):
self._strategy.on_init()
last_bar_datetime = None
# 迭代 K 线数据
# 使用 self._api.get_kline_serial 获取到的 K 线是 Pandas DataFrame
# 直接迭代其行Bar更符合回测逻辑
try:
while True:
# Tqsdk API 的 wait_update() 确保数据更新
self._api.wait_update()
if self.roll_over_mode and (
2025-11-07 16:26:00 +08:00
self._api.is_changing(self.quote, "underlying_symbol")
or self._last_underlying_symbol != self.quote.underlying_symbol
2025-06-29 12:03:43 +08:00
):
self._last_underlying_symbol = self.quote.underlying_symbol
if self._api.is_changing(self.klines_1min):
now_kline = self.klines_1min.iloc[-1]
now_dt = pd.to_datetime(now_kline.datetime, unit="ns", utc=True)
now_dt = now_dt.tz_convert(BEIJING_TZ)
if (
2025-11-07 16:26:00 +08:00
self.now is not None
and self.now.hour != 13
and now_dt.hour == 13
2025-06-29 12:03:43 +08:00
):
self.main()
self.now = now_dt
2025-07-15 22:45:51 +08:00
if self._api.is_changing(self.klines.iloc[-1]):
2025-06-29 12:03:43 +08:00
kline_row = self.klines.iloc[-1]
kline_dt = pd.to_datetime(kline_row.datetime, unit="ns", utc=True)
kline_dt = kline_dt.tz_convert(BEIJING_TZ)
if kline_dt.hour == 13 and self.now.hour == 11:
continue
else:
self.main()
except BacktestFinished:
# 回测结束时,确保所有排队请求得到处理
self._process_queued_requests()
# 回测结束后,如果需要,平掉所有剩余持仓
self._close_all_positions_at_end()
print("TqsdkEngine: 回测运行完毕。")
def main(self):
kline_row = self.klines.iloc[-1]
kline_dt = pd.to_datetime(kline_row.datetime, unit="ns", utc=True)
kline_dt = kline_dt.tz_convert(BEIJING_TZ)
2025-07-15 22:45:51 +08:00
current_bar = Bar(
datetime=kline_dt,
symbol=self._last_underlying_symbol,
open=kline_row.open,
high=kline_row.high,
low=kline_row.low,
close=kline_row.close,
volume=kline_row.volume,
open_oi=kline_row.open_oi,
close_oi=kline_row.close_oi,
)
2025-06-29 12:03:43 +08:00
if (
2025-11-07 16:26:00 +08:00
self.last_processed_bar is None
or self.last_processed_bar.datetime != kline_dt
2025-06-29 12:03:43 +08:00
):
# 设置当前 Bar 到 Context
self._context.set_current_bar(current_bar)
# Tqsdk 的 is_changing 用于判断数据是否有变化,对于回测遍历 K 线,每次迭代都算作新 Bar
# 如果 kline_row.datetime 与上次不同,则认为是新 Bar
if (
2025-11-07 16:26:00 +08:00
self.roll_over_mode
and self.last_processed_bar is not None
and self._last_underlying_symbol != self.last_processed_bar.symbol
2025-06-29 12:03:43 +08:00
):
self._is_rollover_bar = True
print(
f"TqsdkEngine: 检测到换月信号!从 {self._last_underlying_symbol} 切换到 {self.quote.underlying_symbol}"
)
self._close_all_positions_at_end()
self._strategy.cancel_all_pending_orders()
self._strategy.on_rollover(
self.last_processed_bar.symbol, self._last_underlying_symbol
)
else:
self._is_rollover_bar = False
self.last_processed_bar = current_bar
# 调用策略的 on_bar 方法
2025-07-15 22:45:51 +08:00
self._strategy.on_open_bar(current_bar.open, current_bar.symbol)
2025-06-29 12:03:43 +08:00
# 处理订单和取消请求
self._process_queued_requests()
# 记录投资组合快照
self._record_portfolio_snapshot(current_bar.datetime)
else:
2025-11-07 16:26:00 +08:00
if current_bar.volume == 0:
return
2025-07-15 22:45:51 +08:00
self.all_bars.append(current_bar)
self.close_list.append(current_bar.close)
self.open_list.append(current_bar.open)
self.high_list.append(current_bar.high)
self.low_list.append(current_bar.low)
self.volume_list.append(current_bar.volume)
2025-07-10 15:07:31 +08:00
2025-06-29 12:03:43 +08:00
self.last_processed_bar = current_bar
# 设置当前 Bar 到 Context
self._context.set_current_bar(current_bar)
# 调用策略的 on_bar 方法
self._strategy.on_close_bar(current_bar)
# 处理订单和取消请求
self._process_queued_requests()
def run_backtest(self):
"""
同步调用异步回测主循环
"""
try:
self._run_backtest_async()
except KeyboardInterrupt:
print("\n回测被用户中断。")
finally:
self._api.close()
print("TqsdkEngine: API 已关闭。")
def get_backtest_results(self) -> Dict[str, Any]:
"""
返回回测结果数据供结果分析模块使用
"""
final_portfolio_value = 0.0
if self.portfolio_snapshots:
final_portfolio_value = self.portfolio_snapshots[-1].total_value
# else:
# final_portfolio_value = self.initial_capital # 如果没有快照,则净值是初始资金
# total_return_percentage = (
# (final_portfolio_value - self.initial_capital) / self.initial_capital
# ) * 100 if self.initial_capital != 0 else 0.0
return {
"portfolio_snapshots": self.portfolio_snapshots,
"trade_history": self.trade_history,
# "initial_capital": self.initial_capital,
"all_bars": self.all_bars,
"final_portfolio_value": final_portfolio_value,
# "total_return_percentage": total_return_percentage,
}
def get_bar_history(self):
return self.all_bars
2025-07-15 22:45:51 +08:00
2025-07-10 15:07:31 +08:00
def get_price_history(self, key: str):
2025-07-15 22:45:51 +08:00
if key == "close":
2025-07-10 15:07:31 +08:00
return self.close_list
2025-07-15 22:45:51 +08:00
elif key == "open":
2025-07-10 15:07:31 +08:00
return self.open_list
2025-07-15 22:45:51 +08:00
elif key == "high":
2025-07-10 15:07:31 +08:00
return self.high_list
2025-07-15 22:45:51 +08:00
elif key == "low":
2025-07-10 15:07:31 +08:00
return self.low_list
2025-07-15 22:45:51 +08:00
elif key == "volume":
2025-07-10 15:07:31 +08:00
return self.volume_list