# filename: tqsdk_engine.py import asyncio import traceback from datetime import date, datetime, timedelta from typing import Literal, Type, Dict, Any, List, Optional import pandas as pd import uuid from src.common_utils import generate_strategy_identifier # 导入你提供的 core_data 中的类型 from src.core_data import Bar, Order, Trade, PortfolioSnapshot # 导入 Tqsdk 的核心类型 import tqsdk from tqsdk import ( TqApi, TqAccount, tafunc, TqSim, TqBacktest, TqAuth, TargetPosTask, BacktestFinished, ) from src.state_repo import MemoryStateRepository # 导入 TqsdkContext 和 BaseStrategy from src.tqsdk_context import TqsdkContext from src.strategies.base_strategy import Strategy # 假设你的策略基类在此路径 BEIJING_TZ = "Asia/Shanghai" class TqsdkEngine: """ Tqsdk 回测引擎:协调 Tqsdk 数据流、策略执行、订单模拟和结果记录。 替代原有的 BacktestEngine。 """ def __init__( self, strategy_class: Type[Strategy], strategy_params: Dict[str, Any], api: TqApi, roll_over_mode: bool = False, # 是否开启换月模式检测 symbol: str = None, start_time: Optional[datetime] = None, end_time: Optional[datetime] = None, duration_seconds: int = 1, ): """ 初始化 Tqsdk 回测引擎。 Args: strategy_class (Type[Strategy]): 策略类。 strategy_params (Dict[str, Any]): 传递给策略的参数字典。 data_path (str): 本地 K 线数据文件路径,用于 TqSim 加载。 initial_capital (float): 初始资金。 slippage_rate (float): 交易滑点率(在 Tqsdk 中通常需要手动实现或通过费用设置)。 commission_rate (float): 交易佣金率(在 Tqsdk 中通常需要手动实现或通过费用设置)。 roll_over_mode (bool): 是否启用换月检测。 start_time (Optional[datetime]): 回测开始时间。 end_time (Optional[datetime]): 回测结束时间。 """ self.strategy_class = strategy_class self.strategy_params = strategy_params self.roll_over_mode = roll_over_mode self.start_time = start_time self.end_time = end_time # Tqsdk API 和模拟器 # 这里使用 file_path 参数指定本地数据文件 self._api: TqApi = api # 从策略参数中获取主symbol,TqsdkContext 需要知道它 self.symbol: str = symbol.replace('_', '.') if not self.symbol: raise ValueError("strategy_params 必须包含 'symbol' 字段") # 获取 K 线数据(Tqsdk 自动处理) # 这里假设策略所需 K 线周期在 strategy_params 中,否则默认60秒(1分钟K线) self.bar_duration_seconds: int = strategy_params.get("bar_duration_seconds", 60) # self._main_kline_serial = self._api.get_kline_serial( # self.symbol, self.bar_duration_seconds # ) # 初始化上下文 identifier = generate_strategy_identifier(strategy_class, strategy_params) self._context: TqsdkContext = TqsdkContext(api=self._api, state_repository=MemoryStateRepository(identifier)) # 实例化策略,并将上下文传递给它 self._strategy: Strategy = self.strategy_class( context=self._context, **self.strategy_params ) self._context.set_engine( self ) # 将引擎自身传递给上下文,以便 Context 可以访问引擎属性 self.portfolio_snapshots: List[PortfolioSnapshot] = [] self.trade_history: List[Trade] = [] self.all_bars: List[Bar] = [] # 收集所有处理过的Bar self.close_list: List[float] = [] self.open_list: List[float] = [] self.high_list: List[float] = [] self.low_list: List[float] = [] self.volume_list: List[float] = [] self.last_processed_bar: Optional[Bar] = None self._is_rollover_bar: bool = False # 换月信号 self._last_underlying_symbol = self.symbol # 用于检测主力合约换月 self.klines = api.get_kline_serial(self.symbol, duration_seconds) self.klines_1min = api.get_kline_serial(self.symbol, 60) self.now = None self.quote = None if roll_over_mode: self.quote = api.get_quote(self.symbol) self.target_pos_dict = {} print("TqsdkEngine: 初始化完成。") @property def is_rollover_bar(self) -> bool: """ 属性:判断当前 K 线是否为换月 K 线(即检测到主力合约切换)。 """ return self._is_rollover_bar def _process_queued_requests(self): """ 异步处理 Context 中排队的订单和取消请求。 """ # 处理订单 while self._context.order_queue: order_to_send: Order = self._context.order_queue.popleft() print(f"Engine: 处理订单请求: {order_to_send}") # 映射 core_data.Order 到 Tqsdk 的订单参数 tqsdk_direction = "" tqsdk_offset = "" if order_to_send.direction == "BUY": tqsdk_direction = "BUY" tqsdk_offset = order_to_send.offset or "OPEN" # 默认开仓 elif order_to_send.direction == "SELL": tqsdk_direction = "SELL" tqsdk_offset = order_to_send.offset or "OPEN" # 默认开仓 elif order_to_send.direction == "CLOSE_LONG": tqsdk_direction = "SELL" tqsdk_offset = order_to_send.offset or "CLOSE" # 平多,默认平仓 elif order_to_send.direction == "CLOSE_SHORT": tqsdk_direction = "BUY" tqsdk_offset = order_to_send.offset or "CLOSE" # 平空,默认平仓 else: print(f"Engine: 未知订单方向: {order_to_send.direction}") continue # 跳过此订单 if "SHFE" in order_to_send.symbol: tqsdk_offset = "OPEN" if "CLOSE" in order_to_send.direction: current_positions = self._context.get_current_positions() current_pos_volume = current_positions.get(order_to_send.symbol, 0) target_volume = None if order_to_send.direction == 'CLOSE_LONG': target_volume = current_pos_volume - order_to_send.volume elif order_to_send.direction == 'CLOSE_SHORT': target_volume = current_pos_volume + order_to_send.volume if target_volume is not None: if order_to_send.symbol not in self.target_pos_dict: self.target_pos_dict[order_to_send.symbol] = TargetPosTask(self._api, order_to_send.symbol) self.target_pos_dict[order_to_send.symbol].set_target_volume(target_volume) else: # try: tq_order = self._api.insert_order( symbol=order_to_send.symbol, direction=tqsdk_direction, offset=tqsdk_offset, volume=order_to_send.volume, # Tqsdk 市价单 limit_price 设为 None,限价单则传递价格 limit_price=( order_to_send.limit_price if order_to_send.price_type == "LIMIT" # else self.quote.bid_price1 + (1 if tqsdk_direction == "BUY" else -1) else ( self.quote.bid_price1 if tqsdk_direction == "SELL" else self.quote.ask_price1 ) ), ) # 更新原始 Order 对象与 Tqsdk 的订单ID和状态 order_to_send.id = tq_order.order_id # order_to_send.order_id = tq_order.order_id # order_to_send.status = tq_order.status order_to_send.submitted_time = pd.to_datetime( tq_order.insert_date_time, unit="ns", utc=True ) self._api.wait_update() # 等待一次更新 # # except Exception as e: # print(f"Engine: 发送订单 {order_to_send.id} 失败: {e}") # 处理取消请求 while self._context.cancel_queue: order_id_to_cancel = self._context.cancel_queue.popleft() print(f"Engine: 处理取消请求: {order_id_to_cancel}") tq_order_to_cancel = self._api.get_order(order_id_to_cancel) if tq_order_to_cancel and tq_order_to_cancel.status == "ALIVE": try: self._api.cancel_order(tq_order_to_cancel) self._api.wait_update() # 等待取消确认 print( f"Engine: 订单 {order_id_to_cancel} 已尝试取消。当前状态: {tq_order_to_cancel.status}" ) except Exception as e: print(f"Engine: 取消订单 {order_id_to_cancel} 失败: {e}") else: print( f"Engine: 订单 {order_id_to_cancel} 不存在或已非活动状态,无法取消。" ) def _record_portfolio_snapshot(self, current_time: datetime): """ 记录当前投资组合的快照。 """ account: TqAccount = self._api.get_account() current_positions = self._context.get_current_positions() # 计算当前持仓市值 total_market_value = 0.0 current_prices: Dict[str, float] = {} for symbol, qty in current_positions.items(): # 获取当前合约的最新价格 quote = self._api.get_quote(symbol) if quote.last_price: # 确保价格是最近的 price = quote.last_price current_prices[symbol] = price total_market_value += ( price * qty * quote.volume_multiple ) # volume_multiple 乘数 else: # 如果没有最新价格,使用最近的K线收盘价作为估算 # 在实盘或连续回测中,通常会有最新的行情 print(f"警告: 未获取到 {symbol} 最新价格,可能影响净值计算。") # 可以尝试从 K 线获取最近价格 kline = self._api.get_kline_serial(symbol, self.bar_duration_seconds) if not kline.empty: last_kline = kline.iloc[-2] price = last_kline.close current_prices[symbol] = price total_market_value += ( price * qty * self._api.get_instrument(symbol).volume_multiple ) # 使用 instrument 的乘数 total_value = ( account.available + account.frozen_margin + total_market_value ) # Tqsdk 的 balance 已包含持仓市值和冻结资金 # Tqsdk 的 total_profit/balance 已经包含了所有盈亏和资金 snapshot = PortfolioSnapshot( datetime=current_time, total_value=account.balance, # Tqsdk 的 balance 包含了可用资金、冻结保证金和持仓市值 cash=account.available, positions=current_positions, price_at_snapshot=current_prices, ) self.portfolio_snapshots.append(snapshot) def _close_all_positions_at_end(self): """ 回测结束时,平掉所有剩余持仓。 """ current_positions = self._context.get_current_positions() if not current_positions: print("回测结束:没有需要平仓的持仓。") return print("回测结束:开始平仓所有剩余持仓...") for symbol, qty in current_positions.items(): order_direction: Literal["BUY", "SELL"] if qty > 0: # 多头持仓,卖出平仓 order_direction = "SELL" else: # 空头持仓,买入平仓 order_direction = "BUY" TargetPosTask(self._api, symbol).set_target_volume(0) # # 使用市价单快速平仓 # tq_order = self._api.insert_order( # symbol=symbol, # direction=order_direction, # offset="CLOSE", # 平仓 # volume=abs(qty), # limit_price=self # ) # print(f"平仓订单已发送: {symbol} {order_direction} {abs(qty)} 手") # 等待订单完成 # while tq_order.status == "ALIVE": # self._api.wait_update() # if tq_order.status == "FINISHED": # print(f"订单 {tq_order.order_id} 平仓完成。") # else: # print(f"订单 {tq_order.order_id} 平仓失败或未完成,状态: {tq_order.status}") def _run_backtest_async(self): """ 异步运行回测的主循环。 """ print(f"TqsdkEngine: 开始运行回测,从 {self.start_time} 到 {self.end_time}") self._strategy.trading = True # 初始化策略 (如果策略有 on_init 方法) if hasattr(self._strategy, "on_init"): self._strategy.on_init() last_bar_datetime = None # 迭代 K 线数据 # 使用 self._api.get_kline_serial 获取到的 K 线是 Pandas DataFrame, # 直接迭代其行(Bar)更符合回测逻辑 try: while True: # Tqsdk API 的 wait_update() 确保数据更新 self._api.wait_update() if self.roll_over_mode and ( self._api.is_changing(self.quote, "underlying_symbol") or self._last_underlying_symbol != self.quote.underlying_symbol ): self._last_underlying_symbol = self.quote.underlying_symbol if self._api.is_changing(self.klines_1min): now_kline = self.klines_1min.iloc[-1] now_dt = pd.to_datetime(now_kline.datetime, unit="ns", utc=True) now_dt = now_dt.tz_convert(BEIJING_TZ) if ( self.now is not None and self.now.hour != 13 and now_dt.hour == 13 ): self.main() self.now = now_dt if self._api.is_changing(self.klines.iloc[-1]): kline_row = self.klines.iloc[-1] kline_dt = pd.to_datetime(kline_row.datetime, unit="ns", utc=True) kline_dt = kline_dt.tz_convert(BEIJING_TZ) if kline_dt.hour == 13 and self.now.hour == 11: continue else: self.main() except BacktestFinished: # 回测结束时,确保所有排队请求得到处理 self._process_queued_requests() # 回测结束后,如果需要,平掉所有剩余持仓 self._close_all_positions_at_end() print("TqsdkEngine: 回测运行完毕。") def main(self): kline_row = self.klines.iloc[-1] kline_dt = pd.to_datetime(kline_row.datetime, unit="ns", utc=True) kline_dt = kline_dt.tz_convert(BEIJING_TZ) current_bar = Bar( datetime=kline_dt, symbol=self._last_underlying_symbol, open=kline_row.open, high=kline_row.high, low=kline_row.low, close=kline_row.close, volume=kline_row.volume, open_oi=kline_row.open_oi, close_oi=kline_row.close_oi, ) if ( self.last_processed_bar is None or self.last_processed_bar.datetime != kline_dt ): # 设置当前 Bar 到 Context self._context.set_current_bar(current_bar) # Tqsdk 的 is_changing 用于判断数据是否有变化,对于回测遍历 K 线,每次迭代都算作新 Bar # 如果 kline_row.datetime 与上次不同,则认为是新 Bar if ( self.roll_over_mode and self.last_processed_bar is not None and self._last_underlying_symbol != self.last_processed_bar.symbol ): self._is_rollover_bar = True print( f"TqsdkEngine: 检测到换月信号!从 {self._last_underlying_symbol} 切换到 {self.quote.underlying_symbol}" ) self._close_all_positions_at_end() self._strategy.cancel_all_pending_orders() self._strategy.on_rollover( self.last_processed_bar.symbol, self._last_underlying_symbol ) else: self._is_rollover_bar = False self.last_processed_bar = current_bar # 调用策略的 on_bar 方法 self._strategy.on_open_bar(current_bar.open, current_bar.symbol) # 处理订单和取消请求 self._process_queued_requests() # 记录投资组合快照 self._record_portfolio_snapshot(current_bar.datetime) else: if current_bar.volume == 0: return self.all_bars.append(current_bar) self.close_list.append(current_bar.close) self.open_list.append(current_bar.open) self.high_list.append(current_bar.high) self.low_list.append(current_bar.low) self.volume_list.append(current_bar.volume) self.last_processed_bar = current_bar # 设置当前 Bar 到 Context self._context.set_current_bar(current_bar) # 调用策略的 on_bar 方法 self._strategy.on_close_bar(current_bar) # 处理订单和取消请求 self._process_queued_requests() def run_backtest(self): """ 同步调用异步回测主循环。 """ try: self._run_backtest_async() except KeyboardInterrupt: print("\n回测被用户中断。") finally: self._api.close() print("TqsdkEngine: API 已关闭。") def get_backtest_results(self) -> Dict[str, Any]: """ 返回回测结果数据,供结果分析模块使用。 """ final_portfolio_value = 0.0 if self.portfolio_snapshots: final_portfolio_value = self.portfolio_snapshots[-1].total_value # else: # final_portfolio_value = self.initial_capital # 如果没有快照,则净值是初始资金 # total_return_percentage = ( # (final_portfolio_value - self.initial_capital) / self.initial_capital # ) * 100 if self.initial_capital != 0 else 0.0 return { "portfolio_snapshots": self.portfolio_snapshots, "trade_history": self.trade_history, # "initial_capital": self.initial_capital, "all_bars": self.all_bars, "final_portfolio_value": final_portfolio_value, # "total_return_percentage": total_return_percentage, } def get_bar_history(self): return self.all_bars def get_price_history(self, key: str): if key == "close": return self.close_list elif key == "open": return self.open_list elif key == "high": return self.high_list elif key == "low": return self.low_list elif key == "volume": return self.volume_list