# filename: tqsdk_engine.py
import asyncio
from datetime import date, datetime, timedelta
from typing import Literal, Type, Dict, Any, List, Optional
import pandas as pd
import time

# Project data types and helpers.
from src.common_utils import is_bar_pre_close_period, is_futures_trading_time, generate_strategy_identifier
from src.core_data import Bar, Order, Trade, PortfolioSnapshot

# Core Tqsdk types.
import tqsdk
from tqsdk import (
    TqApi,
    TqAccount,
    tafunc,
    TqSim,
    TqBacktest,
    TqAuth,
    TargetPosTask,
    BacktestFinished,
)

from src.state_repo import JsonFileStateRepository
# TqsdkContext and the strategy base class.
from src.tqsdk_real_context import TqsdkContext
from src.strategies.base_strategy import Strategy

BEIJING_TZ = "Asia/Shanghai"


class TqsdkEngine:
    """
    Engine that drives a strategy from a Tqsdk data stream.

    It replays the most recent history bars into the strategy, then loops on
    ``TqApi.wait_update()``: it detects new K-lines, calls the strategy's
    open/close-bar hooks, forwards the orders and cancellations queued on the
    context to Tqsdk, and (optionally) detects a change of the underlying
    contract of ``symbol`` (roll-over).
    """

    # Maps a core_data order direction to (tqsdk direction, default offset).
    _DIRECTION_MAP = {
        "BUY": ("BUY", "OPEN"),
        "SELL": ("SELL", "OPEN"),
        "CLOSE_LONG": ("SELL", "CLOSE"),
        "CLOSE_SHORT": ("BUY", "CLOSE"),
    }

    def __init__(
        self,
        strategy_class: Type[Strategy],
        strategy_params: Dict[str, Any],
        api: TqApi,
        roll_over_mode: bool = False,  # whether roll-over detection is enabled
        symbol: Optional[str] = None,
        duration_seconds: int = 1,
        history_length: int = 50,
        close_bar_delta: Optional[timedelta] = None,
    ):
        """
        Build the context and strategy and subscribe to quote / K-line data.

        Args:
            strategy_class: Strategy class; instantiated as
                ``strategy_class(context=..., **strategy_params)``.
            strategy_params: Keyword arguments for the strategy. The optional
                key ``"bar_duration_seconds"`` (default 60) is the K-line
                period used for the price fallback in
                ``_record_portfolio_snapshot``.
            api: An already-constructed ``TqApi``; ``run()`` closes it.
            roll_over_mode: Enable detection of a change of
                ``quote.underlying_symbol``.
            symbol: Symbol passed to ``api.get_quote``; K-lines are subscribed
                on that quote's ``underlying_symbol``.
            duration_seconds: Period, in seconds, of the main K-line series.
            history_length: Number of history bars replayed at start-up; the
                main series is requested with ``history_length + 2`` rows.
            close_bar_delta: Stored on the instance; not used by this class.
        """
        self.strategy_class = strategy_class
        self.strategy_params = strategy_params
        self.roll_over_mode = roll_over_mode
        self.history_length = history_length
        self.close_bar_delta = close_bar_delta
        self.next_close_time = None

        self._api: TqApi = api
        self.symbol = symbol

        # K-line period requested by the strategy; defaults to 1 minute.
        self.bar_duration_seconds: int = strategy_params.get("bar_duration_seconds", 60)

        # The state repository is keyed by an identifier derived from the
        # strategy class and its parameters.
        identifier = generate_strategy_identifier(strategy_class, strategy_params)
        self._context: TqsdkContext = TqsdkContext(
            api=self._api, state_repository=JsonFileStateRepository(identifier)
        )

        # Instantiate the strategy with the context, then give the context a
        # back-reference to this engine.
        self._strategy: Strategy = self.strategy_class(
            context=self._context, **self.strategy_params
        )
        self._context.set_engine(self)

        self.portfolio_snapshots: List[PortfolioSnapshot] = []
        self.trade_history: List[Trade] = []
        self.all_bars: List[Bar] = []  # every completed bar, oldest first
        self.close_list: List[float] = []
        self.open_list: List[float] = []
        self.high_list: List[float] = []
        self.low_list: List[float] = []
        self.volume_list: List[float] = []
        self.last_processed_bar: Optional[Bar] = None

        self._is_rollover_bar: bool = False  # roll-over signal for the current bar
        self._last_underlying_symbol = self.symbol  # used to detect a roll-over

        self.now = None
        self.quote = api.get_quote(symbol)
        self.klines = api.get_kline_serial(
            self.quote.underlying_symbol,
            duration_seconds,
            data_length=history_length + 2,
        )
        self.klines_1min = api.get_kline_serial(self.quote.underlying_symbol, 60)

        self.partial_bar: Optional[Bar] = None  # placeholder for the bar still forming
        self.kline_row = None  # K-line row most recently handed to main()
        self.target_pos_dict = {}  # symbol -> TargetPosTask
        print("TqsdkEngine: 初始化完成。")

    @property
    def is_rollover_bar(self) -> bool:
        """
        Whether the bar most recently passed to ``main`` triggered a roll-over.
        """
        return self._is_rollover_bar

    @staticmethod
    def _to_beijing_time(timestamp_ns):
        """Convert a nanosecond epoch timestamp to an aware Beijing-time Timestamp."""
        return pd.to_datetime(timestamp_ns, unit="ns", utc=True).tz_convert(BEIJING_TZ)

    def _get_target_pos_task(self, symbol: str) -> TargetPosTask:
        """Return the cached TargetPosTask for ``symbol``, creating it on first use."""
        task = self.target_pos_dict.get(symbol)
        if task is None:
            task = TargetPosTask(self._api, symbol)
            self.target_pos_dict[symbol] = task
        return task

    def _process_queued_requests(self):
        """
        Synchronously drain the context's order queue, then its cancel queue.
        """
        while self._context.order_queue:
            self._dispatch_order(self._context.order_queue.popleft())

        while self._context.cancel_queue:
            self._cancel_tq_order(self._context.cancel_queue.popleft())

    def _dispatch_order(self, order_to_send: Order):
        """Send one queued order: closes go through TargetPosTask, opens through insert_order."""
        print(f"Engine: 处理订单请求: {order_to_send}")

        mapping = self._DIRECTION_MAP.get(order_to_send.direction)
        if mapping is None:
            print(f"Engine: 未知订单方向: {order_to_send.direction}")
            return  # skip this order
        tqsdk_direction, default_offset = mapping
        tqsdk_offset = order_to_send.offset or default_offset

        # NOTE(review): for SHFE symbols the offset is forced to "OPEN" even
        # when the order carries an explicit offset — confirm this is intended.
        if "SHFE" in order_to_send.symbol:
            tqsdk_offset = "OPEN"

        symbol = order_to_send.symbol
        if order_to_send.direction in ("CLOSE_LONG", "CLOSE_SHORT"):
            # Closing orders are expressed as a new net target position.
            # NOTE(review): this mixes TargetPosTask with insert_order on the
            # same symbol; check the TqSdk TargetPosTask documentation.
            current_pos_volume = self._context.get_current_positions().get(symbol, 0)
            if order_to_send.direction == "CLOSE_LONG":
                target_volume = current_pos_volume - order_to_send.volume
            else:
                target_volume = current_pos_volume + order_to_send.volume
            self._get_target_pos_task(symbol).set_target_volume(target_volume)
            return

        try:
            if order_to_send.price_type == "LIMIT":
                limit_price = order_to_send.limit_price
            elif tqsdk_direction == "SELL":
                # Non-limit orders are priced at the opposite best quote.
                limit_price = self.quote.bid_price1
            else:
                limit_price = self.quote.ask_price1

            tq_order = self._api.insert_order(
                symbol=symbol,
                direction=tqsdk_direction,
                offset=tqsdk_offset,
                volume=order_to_send.volume,
                limit_price=limit_price,
            )
            # Copy the Tqsdk order id and submission time back onto the order.
            order_to_send.id = tq_order.order_id
            order_to_send.submitted_time = pd.to_datetime(
                tq_order.insert_date_time, unit="ns", utc=True
            )
            self._api.wait_update()  # wait for one update
        except Exception as e:
            # Best effort: a failed order is reported and the queue continues.
            print(f"Engine: 发送订单 {order_to_send.id} 失败: {e}")

    def _cancel_tq_order(self, order_id_to_cancel):
        """Cancel one order by id if Tqsdk still reports it as ALIVE."""
        print(f"Engine: 处理取消请求: {order_id_to_cancel}")
        tq_order_to_cancel = self._api.get_order(order_id_to_cancel)
        if not (tq_order_to_cancel and tq_order_to_cancel.status == "ALIVE"):
            print(
                f"Engine: 订单 {order_id_to_cancel} 不存在或已非活动状态,无法取消。"
            )
            return
        try:
            self._api.cancel_order(tq_order_to_cancel)
            self._api.wait_update()  # wait for the cancel confirmation
            print(
                f"Engine: 订单 {order_id_to_cancel} 已尝试取消。当前状态: {tq_order_to_cancel.status}"
            )
        except Exception as e:
            print(f"Engine: 取消订单 {order_id_to_cancel} 失败: {e}")

    def _record_portfolio_snapshot(self, current_time: datetime):
        """
        Append a PortfolioSnapshot for ``current_time``.

        ``total_value`` is the account balance and ``cash`` the available
        funds reported by Tqsdk. Each held symbol's price is the quote's last
        price, falling back to the close of the second-to-last K-line row.
        """
        account = self._api.get_account()
        current_positions = self._context.get_current_positions()

        current_prices: Dict[str, float] = {}
        for symbol in current_positions:
            price = self._api.get_quote(symbol).last_price
            # A missing price may be NaN, which is truthy, so test it explicitly.
            if pd.notna(price) and price:
                current_prices[symbol] = price
                continue

            print(f"警告: 未获取到 {symbol} 最新价格,可能影响净值计算。")
            kline = self._api.get_kline_serial(symbol, self.bar_duration_seconds)
            if len(kline) >= 2:
                current_prices[symbol] = kline.iloc[-2].close

        snapshot = PortfolioSnapshot(
            datetime=current_time,
            total_value=account.balance,
            cash=account.available,
            positions=current_positions,
            price_at_snapshot=current_prices,
        )
        self.portfolio_snapshots.append(snapshot)

    def _close_all_positions_at_end(self):
        """
        Set the target position of every currently held symbol to zero.
        """
        current_positions = self._context.get_current_positions()
        if not current_positions:
            print("回测结束:没有需要平仓的持仓。")
            return

        print("回测结束:开始平仓所有剩余持仓...")
        for symbol in current_positions:
            self._get_target_pos_task(symbol).set_target_volume(0)

    def _run_async(self):
        """
        Replay history bars, then process live updates forever.

        Despite its name this method is synchronous; it blocks in
        ``TqApi.wait_update()``.
        """
        print(f"TqsdkEngine: 开始加载历史数据,加载k线数量{self.history_length}")
        self._strategy.trading = False

        is_trading_time = is_futures_trading_time()
        # Outside trading hours the newest row is replayed as history too;
        # during trading hours it is held back and handled below with
        # trading enabled.
        stop = 1 if is_trading_time else 0
        for i in range(self.history_length + 1, stop, -1):
            self.main(self.klines.iloc[-i], self.klines.iloc[-i - 1])

        print(
            f"TqsdkEngine: 加载历史k线完成, bars数量:{len(self.all_bars)},last bar datetime:{self.all_bars[-1].datetime}"
        )

        self._strategy.trading = True
        self._last_underlying_symbol = self.quote.underlying_symbol
        print(
            f"TqsdkEngine: self._last_underlying_symbol:{self._last_underlying_symbol}, is_trading_time:{is_trading_time}"
        )

        # Optional strategy hook.
        if hasattr(self._strategy, "on_init"):
            self._strategy.on_init()

        new_bar = False
        if is_trading_time:
            kline_dt = self._to_beijing_time(self.klines.iloc[-1].datetime)
            print(f"TqsdkEngine: 当前是交易时间,处理最新一根k线,datetime:{kline_dt}")
            self.main(self.klines.iloc[-1], self.klines.iloc[-2])
            new_bar = True

        self.kline_row = self.klines.iloc[-1]

        for bar in self.all_bars[-5:]:
            print(bar)

        print(f"TqsdkEngine: 开始等待最新数据, all bars -1:{self.all_bars[-1].datetime}")

        last_min_k = None
        while True:
            self._api.wait_update()

            # NOTE(review): on a roll-over only the tracked symbol is updated;
            # self.klines stays subscribed to the previous underlying contract.
            if self.roll_over_mode and (
                self._api.is_changing(self.quote, "underlying_symbol")
                or self._last_underlying_symbol != self.quote.underlying_symbol
            ):
                self._last_underlying_symbol = self.quote.underlying_symbol

            # Once per new 1-minute K-line, check whether the current bar has
            # entered its pre-close period; if so fire the close-bar hook once.
            latest_min_k = self.klines_1min.iloc[-1]
            if new_bar and (
                last_min_k is None or last_min_k.datetime != latest_min_k.datetime
            ):
                last_min_k = latest_min_k
                if self.kline_row is not None:
                    kline_dt = self._to_beijing_time(self.kline_row.datetime)
                    is_close_bar = is_bar_pre_close_period(
                        kline_dt, int(self.kline_row.duration), pre_close_minutes=1
                    )
                    if is_close_bar:
                        print(
                            f"TqsdkEngine: close bar, kline_dt:{kline_dt}, now: {datetime.now()}"
                        )
                        self.close_bar(self.kline_row)
                        new_bar = False

            if self.kline_row is None or self.kline_row.datetime != self.klines.iloc[-1].datetime:
                # A new K-line row appeared; capture its timestamp before waiting.
                kline_dt = self._to_beijing_time(self.klines.iloc[-1].datetime)

                # Wait until the new row has traded volume.
                while self.klines.iloc[-1].volume <= 0:
                    self._api.wait_update()

                # Wait until local wall-clock time is on a 5-minute boundary
                # and the hour is neither 8 nor 20.
                while True:
                    now = datetime.now()
                    if now.minute % 5 == 0 and now.hour != 8 and now.hour != 20:
                        break
                    self._api.wait_update()

                last_bar_dt = self.all_bars[-1].datetime
                if kline_dt.hour != last_bar_dt.hour or kline_dt.minute != last_bar_dt.minute:
                    print(
                        f"TqsdkEngine: 新k线产生, k line datetime:{kline_dt}, now: {datetime.now()}, open: {self.klines.iloc[-1].open}")
                    self.kline_row = self.klines.iloc[-1]
                    self.main(self.klines.iloc[-1], self.klines.iloc[-2])
                    new_bar = True

    def close_bar(self, kline_row):
        """
        Build a Bar from ``kline_row`` and fire the strategy's close-bar hook.

        Does nothing until at least one completed bar exists. The hook and the
        queued-request processing run only while ``strategy.trading`` is True.
        """
        kline_dt = self._to_beijing_time(kline_row.datetime)

        if self.all_bars:
            current_bar = Bar(
                datetime=kline_dt,
                symbol=self._last_underlying_symbol,
                open=kline_row.open,
                high=kline_row.high,
                low=kline_row.low,
                close=kline_row.close,
                volume=kline_row.volume,
                open_oi=kline_row.open_oi,
                close_oi=kline_row.close_oi,
            )
            self.last_processed_bar = current_bar

            if self._strategy.trading is True:
                self._strategy.on_close_bar(current_bar)
                # Forward any orders / cancellations the hook queued.
                self._process_queued_requests()

    def main(self, kline_row, prev_kline_row):
        """
        Handle the opening of a new K-line.

        Records ``prev_kline_row`` as a completed bar (skipped on the very
        first call), then either runs the roll-over sequence or calls the
        strategy's ``on_open_bar``, processes queued requests while trading,
        and stores a placeholder bar for ``kline_row``.
        """
        kline_dt = self._to_beijing_time(kline_row.datetime)

        if self.partial_bar is not None:
            last_bar = Bar(
                datetime=self._to_beijing_time(prev_kline_row.datetime),
                symbol=self.partial_bar.symbol,
                open=prev_kline_row.open,
                high=prev_kline_row.high,
                low=prev_kline_row.low,
                close=prev_kline_row.close,
                volume=prev_kline_row.volume,
                open_oi=prev_kline_row.open_oi,
                close_oi=prev_kline_row.close_oi,
            )
            self.all_bars.append(last_bar)
            self.close_list.append(last_bar.close)
            self.open_list.append(last_bar.open)
            self.high_list.append(last_bar.high)
            self.low_list.append(last_bar.low)
            self.volume_list.append(last_bar.volume)
            self.last_processed_bar = last_bar

        if (
            self.roll_over_mode
            and self.last_processed_bar is not None
            and self._last_underlying_symbol != self.last_processed_bar.symbol
            and self._strategy.trading
        ):
            self._is_rollover_bar = True
            # The old contract is the one the last completed bar belongs to.
            print(
                f"TqsdkEngine: 检测到换月信号!从 {self.last_processed_bar.symbol} 切换到 {self.quote.underlying_symbol}"
            )
            self._close_all_positions_at_end()
            self._strategy.cancel_all_pending_orders()
            self._strategy.on_rollover(
                self.last_processed_bar.symbol, self._last_underlying_symbol
            )
        else:
            # Clear the signal so it reflects only the current bar.
            self._is_rollover_bar = False
            self._strategy.on_open_bar(kline_row.open, self._last_underlying_symbol)

        # Forward any orders / cancellations the hooks queued.
        if self._strategy.trading is True:
            self._process_queued_requests()

        # Placeholder for the bar that is still forming; its symbol labels the
        # completed bar built on the next call.
        self.partial_bar = Bar(
            datetime=kline_dt,
            symbol=self.quote.underlying_symbol,
            open=0,
            high=0,
            low=0,
            close=0,
            volume=0,
            open_oi=0,
            close_oi=0,
        )

    def run(self):
        """
        Run the main loop; the API is closed on exit, including on Ctrl-C.
        """
        try:
            self._run_async()
        except KeyboardInterrupt:
            print("\n回测被用户中断。")
        finally:
            self._api.close()
            print("TqsdkEngine: API 已关闭。")

    def get_results(self) -> Dict[str, Any]:
        """
        Return collected result data for downstream analysis.

        ``final_portfolio_value`` is the last snapshot's ``total_value``, or
        0.0 when no snapshot was recorded.
        """
        final_portfolio_value = 0.0
        if self.portfolio_snapshots:
            final_portfolio_value = self.portfolio_snapshots[-1].total_value

        return {
            "portfolio_snapshots": self.portfolio_snapshots,
            "trade_history": self.trade_history,
            "all_bars": self.all_bars,
            "final_portfolio_value": final_portfolio_value,
        }

    def get_bar_history(self):
        """Return the list of all completed bars."""
        return self.all_bars

    def get_price_history(self, key: str):
        """
        Return the history list for ``key``.

        ``key`` is one of "close", "open", "high", "low" or "volume"; any
        other value returns None.
        """
        histories = {
            "close": self.close_list,
            "open": self.open_list,
            "high": self.high_list,
            "low": self.low_list,
            "volume": self.volume_list,
        }
        return histories.get(key)