Files
NewQuant/src/tqsdk_real_engine.py
2025-11-07 16:26:00 +08:00

584 lines
23 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# filename: tqsdk_real_engine.py
import asyncio
from datetime import date, datetime, timedelta
from typing import Literal, Type, Dict, Any, List, Optional
import pandas as pd
import time
# 导入你提供的 core_data 中的类型
from src.common_utils import is_bar_pre_close_period, is_futures_trading_time, generate_strategy_identifier
from src.core_data import Bar, Order, Trade, PortfolioSnapshot
# 导入 Tqsdk 的核心类型
import tqsdk
from tqsdk import (
TqApi,
TqAccount,
tafunc,
TqSim,
TqBacktest,
TqAuth,
TargetPosTask,
BacktestFinished,
)
from src.state_repo import JsonFileStateRepository
# 导入 TqsdkContext 和 BaseStrategy
from src.tqsdk_real_context import TqsdkContext
from src.strategies.base_strategy import Strategy # 假设你的策略基类在此路径
BEIJING_TZ = "Asia/Shanghai"
class TqsdkEngine:
    """
    Tqsdk backtest engine: coordinates the Tqsdk data stream, strategy execution,
    order simulation and result recording.

    Replaces the original BacktestEngine.

    NOTE(review): the main loop paces itself on wall-clock time (``datetime.now()``),
    which suggests this variant is meant for live/simulated real-time sessions rather
    than historical backtests — confirm and reword if so.
    """
def __init__(
    self,
    strategy_class: Type[Strategy],
    strategy_params: Dict[str, Any],
    api: TqApi,
    roll_over_mode: bool = False,  # whether main-contract rollover detection is enabled
    symbol: Optional[str] = None,
    duration_seconds: int = 1,
    history_length: int = 50,
    close_bar_delta: Optional[timedelta] = None,
):
    """
    Bind a strategy instance to an existing TqApi connection and subscribe to market data.

    Args:
        strategy_class: Strategy class to instantiate. It is called with
            ``context=<TqsdkContext>`` plus ``strategy_params`` as keyword arguments.
        strategy_params: Keyword arguments for the strategy. The optional key
            ``"bar_duration_seconds"`` (default 60) is stored as
            ``self.bar_duration_seconds``.
        api: An already constructed ``TqApi``; the engine does not create it.
        roll_over_mode: Enable main-contract rollover detection.
        symbol: Symbol passed to ``api.get_quote``. The default is ``None`` only for
            signature compatibility: the quote's ``underlying_symbol`` is read right
            away to subscribe the K-line series, so a real symbol is required
            (presumably a main-contract symbol — confirm with callers).
        duration_seconds: Bar period, in seconds, of the main K-line series.
        history_length: Number of historical bars the engine works with; the main
            K-line series is requested with ``history_length + 2`` rows.
        close_bar_delta: Stored on the instance; not otherwise used in this constructor.

    Side effects:
        Creates a ``TqsdkContext`` backed by a ``JsonFileStateRepository`` keyed by the
        strategy identifier, instantiates the strategy, subscribes to one quote and two
        K-line series (main period and 60 seconds), and prints a completion message.
    """
    self.strategy_class = strategy_class
    self.strategy_params = strategy_params
    self.roll_over_mode = roll_over_mode
    self.history_length = history_length
    self.close_bar_delta = close_bar_delta
    self.next_close_time = None

    self._api: TqApi = api
    self.symbol = symbol

    # Bar period taken from the strategy parameters (default: 60 s).
    # NOTE(review): this is independent of ``duration_seconds``, which is what the
    # main K-line subscription below actually uses — confirm both are intended.
    self.bar_duration_seconds: int = strategy_params.get("bar_duration_seconds", 60)

    # Build the context first so the strategy can receive it, then give the context a
    # back-reference to this engine.
    identifier = generate_strategy_identifier(strategy_class, strategy_params)
    self._context: TqsdkContext = TqsdkContext(
        api=self._api, state_repository=JsonFileStateRepository(identifier)
    )
    self._strategy: Strategy = self.strategy_class(
        context=self._context, **self.strategy_params
    )
    self._context.set_engine(self)

    self.portfolio_snapshots: List[PortfolioSnapshot] = []
    self.trade_history: List[Trade] = []
    self.all_bars: List[Bar] = []  # every completed bar handed to the strategy so far
    self.close_list: List[float] = []
    self.open_list: List[float] = []
    self.high_list: List[float] = []
    self.low_list: List[float] = []
    self.volume_list: List[float] = []
    self.last_processed_bar: Optional[Bar] = None

    self._is_rollover_bar: bool = False  # rollover signal
    self._last_underlying_symbol = self.symbol  # used to detect a main-contract switch
    self.now = None

    # The K-line series are subscribed once, using the underlying symbol that the quote
    # reports at construction time.
    self.quote = api.get_quote(symbol)
    self.klines = api.get_kline_serial(
        self.quote.underlying_symbol, duration_seconds, data_length=history_length + 2
    )
    self.klines_1min = api.get_kline_serial(self.quote.underlying_symbol, 60)

    self.partial_bar: Optional[Bar] = None
    self.kline_row = None
    self.target_pos_dict = {}  # symbol -> TargetPosTask, created lazily
    print("TqsdkEngine: 初始化完成。")
@property
def is_rollover_bar(self) -> bool:
    """
    Read-only view of the engine's rollover flag.

    Returns:
        The current value of ``_is_rollover_bar``, which ``main`` sets to True when it
        detects that the main contract has switched.
    """
    rollover_flag = self._is_rollover_bar
    return rollover_flag
def _process_queued_requests(self):
    """
    Drain the context's order queue and cancel queue, forwarding each request to TqSdk.

    The method is synchronous: it blocks in ``wait_update()`` once per order sent and
    once per cancel issued.

    Order handling:
        * ``BUY`` / ``SELL`` are sent with ``insert_order``. Orders whose ``price_type``
          is not ``"LIMIT"`` are priced at the opposite side of ``self.quote``
          (ask for buys, bid for sells).
        * ``CLOSE_LONG`` / ``CLOSE_SHORT`` are not sent as orders; they move the
          symbol's cached ``TargetPosTask`` towards zero by the requested volume. The
          target is clamped so that an oversized close request can never flip the
          position to the opposite side.
        * Any other direction is logged and skipped.

    Failures while sending or cancelling are logged and do not stop the queue.
    """
    # Engine direction -> (TqSdk direction, offset used when the order carries none).
    direction_map = {
        "BUY": ("BUY", "OPEN"),
        "SELL": ("SELL", "OPEN"),
        "CLOSE_LONG": ("SELL", "CLOSE"),
        "CLOSE_SHORT": ("BUY", "CLOSE"),
    }

    # --- order requests ---
    while self._context.order_queue:
        order_to_send = self._context.order_queue.popleft()
        print(f"Engine: 处理订单请求: {order_to_send}")

        mapping = direction_map.get(order_to_send.direction)
        if mapping is None:
            print(f"Engine: 未知订单方向: {order_to_send.direction}")
            continue
        tqsdk_direction, default_offset = mapping
        tqsdk_offset = order_to_send.offset or default_offset
        # Only opening orders reach insert_order (closes go through TargetPosTask), so
        # SHFE orders are always sent with the OPEN offset.
        if "SHFE" in order_to_send.symbol:
            tqsdk_offset = "OPEN"

        if order_to_send.direction in ("CLOSE_LONG", "CLOSE_SHORT"):
            current_pos_volume = self._context.get_current_positions().get(
                order_to_send.symbol, 0
            )
            if order_to_send.direction == "CLOSE_LONG":
                requested = current_pos_volume - order_to_send.volume
                # Never go below flat when long; leave a flat/short position untouched.
                target_volume = max(requested, min(current_pos_volume, 0))
            else:
                requested = current_pos_volume + order_to_send.volume
                # Never go above flat when short; leave a flat/long position untouched.
                target_volume = min(requested, max(current_pos_volume, 0))
            if target_volume != requested:
                print(
                    f"Engine: 平仓数量 {order_to_send.volume} 超过可平持仓 "
                    f"{current_pos_volume}, 目标持仓限制为 {target_volume}"
                )

            task = self.target_pos_dict.get(order_to_send.symbol)
            if task is None:
                task = TargetPosTask(self._api, order_to_send.symbol)
                self.target_pos_dict[order_to_send.symbol] = task
            task.set_target_volume(target_volume)
            continue

        try:
            if order_to_send.price_type == "LIMIT":
                limit_price = order_to_send.limit_price
            elif tqsdk_direction == "SELL":
                limit_price = self.quote.bid_price1
            else:
                limit_price = self.quote.ask_price1

            tq_order = self._api.insert_order(
                symbol=order_to_send.symbol,
                direction=tqsdk_direction,
                offset=tqsdk_offset,
                volume=order_to_send.volume,
                limit_price=limit_price,
            )
            order_to_send.id = tq_order.order_id

            # The order object only receives its insert time once an update has been
            # processed, so read it after wait_update(); reading it earlier yields 0,
            # i.e. 1970-01-01.
            self._api.wait_update()
            insert_ns = tq_order.insert_date_time
            if insert_ns:
                order_to_send.submitted_time = pd.to_datetime(
                    insert_ns, unit="ns", utc=True
                )
            else:
                # Still unconfirmed after one update: fall back to the local clock
                # rather than recording the epoch.
                order_to_send.submitted_time = pd.Timestamp.now(tz="UTC")
        except Exception as e:
            print(f"Engine: 发送订单 {order_to_send.id} 失败: {e}")

    # --- cancel requests ---
    while self._context.cancel_queue:
        order_id_to_cancel = self._context.cancel_queue.popleft()
        print(f"Engine: 处理取消请求: {order_id_to_cancel}")
        tq_order_to_cancel = self._api.get_order(order_id_to_cancel)
        if tq_order_to_cancel and tq_order_to_cancel.status == "ALIVE":
            try:
                self._api.cancel_order(tq_order_to_cancel)
                self._api.wait_update()  # give the cancel a chance to be acknowledged
                print(
                    f"Engine: 订单 {order_id_to_cancel} 已尝试取消。当前状态: {tq_order_to_cancel.status}"
                )
            except Exception as e:
                print(f"Engine: 取消订单 {order_id_to_cancel} 失败: {e}")
        else:
            print(
                f"Engine: 订单 {order_id_to_cancel} 不存在或已非活动状态,无法取消。"
            )
def _record_portfolio_snapshot(self, current_time: datetime):
    """
    Append a ``PortfolioSnapshot`` describing the account at ``current_time``.

    The snapshot's ``total_value`` is the account ``balance`` and ``cash`` is the
    account's ``available`` funds, both as reported by ``api.get_account()``.
    ``price_at_snapshot`` maps each held symbol to its latest trade price; when the
    quote has no usable price (missing, zero or NaN) the close of the previous K-line
    is used instead, and the symbol is left out if that is unavailable too.

    Args:
        current_time: Timestamp stored on the snapshot.
    """
    account = self._api.get_account()
    current_positions = self._context.get_current_positions()

    current_prices: Dict[str, float] = {}
    for symbol in current_positions:
        price = self._api.get_quote(symbol).last_price
        # NaN is truthy, so a plain truthiness test would accept an unset quote price.
        if not price or pd.isna(price):
            print(f"警告: 未获取到 {symbol} 最新价格,可能影响净值计算。")
            kline = self._api.get_kline_serial(symbol, self.bar_duration_seconds)
            if len(kline) < 2:
                continue
            # iloc[-1] is the bar still forming; iloc[-2] is the last finished one.
            price = kline.iloc[-2].close
            if pd.isna(price):
                continue
        current_prices[symbol] = price

    snapshot = PortfolioSnapshot(
        datetime=current_time,
        total_value=account.balance,
        cash=account.available,
        positions=current_positions,
        price_at_snapshot=current_prices,
    )
    self.portfolio_snapshots.append(snapshot)
def _close_all_positions_at_end(self):
    """
    Request a flat position (target volume 0) for every symbol currently held.

    Uses the same per-symbol ``TargetPosTask`` cache (``self.target_pos_dict``) as the
    order-processing path, so each symbol is driven by a single task. The call only sets
    the targets; it does not wait for the resulting orders to fill.

    Called by ``main`` when a contract rollover is detected.
    """
    current_positions = self._context.get_current_positions()
    if not current_positions:
        print("回测结束:没有需要平仓的持仓。")
        return

    print("回测结束:开始平仓所有剩余持仓...")
    for symbol in current_positions:
        task = self.target_pos_dict.get(symbol)
        if task is None:
            task = TargetPosTask(self._api, symbol)
            self.target_pos_dict[symbol] = task
        task.set_target_volume(0)
def _run_async(self):
    """
    Warm the strategy up on recent K-lines, then process live updates in an endless loop.

    Despite its name this is an ordinary blocking method (no ``async``/``await``). The
    final ``while True`` never exits normally, so the method only ends by raising.

    Phases:
      1. Replay: with ``strategy.trading`` off, feed the newest rows of ``self.klines``
         through ``main``. During trading hours the newest row is left out of the replay.
      2. Turn ``strategy.trading`` on, call ``on_init`` if the strategy defines it, and,
         during trading hours, feed the newest row through ``main`` as well.
      3. Live loop, once per ``wait_update()``:
         * with ``roll_over_mode`` on, keep ``_last_underlying_symbol`` in sync with the
           quote's ``underlying_symbol``;
         * after a new bar has opened, on each new 1-minute K-line, call ``close_bar``
           once when ``is_bar_pre_close_period`` reports the bar is about to close;
         * when the newest row of ``self.klines`` has a different ``datetime`` from the
           row being tracked, wait for it to show volume and for the wall clock to
           reach a 5-minute boundary, then hand it to ``main``.

    NOTE(review): nesting below was reconstructed from control flow — confirm against
    the canonical file.
    """
    print(f"TqsdkEngine: 开始加载历史数据加载k线数量{self.history_length}")
    self._strategy.trading = False
    is_trading_time = is_futures_trading_time()
    # i counts rows back from the end of the series; main() records the row *before*
    # the one it is given as a completed bar, hence the ``-i - 1`` companion row.
    # The stop value is 1 during trading hours so that iloc[-1] is handled after on_init.
    for i in range(self.history_length + 1, 0 if not is_trading_time else 1, -1):
        kline_row = self.klines.iloc[-i]
        kline_dt = pd.to_datetime(kline_row.datetime, unit="ns", utc=True)
        kline_dt = kline_dt.tz_convert(BEIJING_TZ)
        self.main(kline_row, self.klines.iloc[-i - 1])
    # NOTE(review): all_bars[-1] raises IndexError if the replay produced no bars
    # (e.g. history_length == 0 outside trading hours).
    print(
        f"TqsdkEngine: 加载历史k线完成, bars数量:{len(self.all_bars)},last bar datetime:{self.all_bars[-1].datetime}"
    )
    self._strategy.trading = True
    self._last_underlying_symbol = self.quote.underlying_symbol
    print(
        f"TqsdkEngine: self._last_underlying_symbol:{self._last_underlying_symbol}, is_trading_time:{is_trading_time}"
    )
    # Initialise the strategy (if it defines on_init)
    if hasattr(self._strategy, "on_init"):
        self._strategy.on_init()
    # new_bar is True between a bar being opened via main() and close_bar() firing for it.
    new_bar = False
    if is_trading_time:
        kline_row = self.klines.iloc[-1]
        kline_dt = pd.to_datetime(kline_row.datetime, unit="ns", utc=True)
        kline_dt = kline_dt.tz_convert(BEIJING_TZ)
        print(f"TqsdkEngine: 当前是交易时间处理最新一根k线datetime:{kline_dt}")
        self.main(self.klines.iloc[-1], self.klines.iloc[-2])
        new_bar = True
    # Remember the row currently being tracked so the live loop can tell when a
    # different one appears.
    kline_row = self.klines.iloc[-1]
    self.kline_row = kline_row
    # Show the last few completed bars for a quick sanity check.
    for bar in self.all_bars[-5:]:
        print(bar)
    print(f"TqsdkEngine: 开始等待最新数据, all bars -1:{self.all_bars[-1].datetime}")
    last_min_k = None
    while True:
        # wait_update() blocks until TqSdk has new data
        self._api.wait_update()
        if self.roll_over_mode and (
            self._api.is_changing(self.quote, "underlying_symbol")
            or self._last_underlying_symbol != self.quote.underlying_symbol
        ):
            self._last_underlying_symbol = self.quote.underlying_symbol
        # Pre-close check: evaluated once per new 1-minute K-line while a bar is open.
        if new_bar and (last_min_k is None or last_min_k.datetime != self.klines_1min.iloc[-1].datetime):
            last_min_k = self.klines_1min.iloc[-1]
            if self.kline_row is not None:
                kline_dt = pd.to_datetime(self.kline_row.datetime, unit="ns", utc=True)
                kline_dt = kline_dt.tz_convert(BEIJING_TZ)
                is_close_bar = is_bar_pre_close_period(
                    kline_dt, int(self.kline_row.duration), pre_close_minutes=1
                )
                if is_close_bar:
                    print(
                        f"TqsdkEngine: close bar, kline_dt:{kline_dt}, now: {datetime.now()}"
                    )
                    self.close_bar(self.kline_row)
                    # Only fire close_bar once per bar.
                    new_bar = False
        # New-bar check: the newest row differs from the one being tracked.
        if self.kline_row is None or self.kline_row.datetime != self.klines.iloc[-1].datetime:
            # Original comment: "at this point the time is guaranteed to be on the
            # hour/half hour (minute 00/30) with second > 1".
            # NOTE(review): that does not match the 5-minute test below — stale comment?
            kline_row = self.klines.iloc[-1]
            kline_dt = pd.to_datetime(kline_row.datetime, unit="ns", utc=True)
            kline_dt = kline_dt.tz_convert(BEIJING_TZ)
            # Original comment: "wait until minute 00 or 30 and second > 1" (see note above).
            now = datetime.now()
            # Wait until the new bar has traded at least once.
            while self.klines.iloc[-1].volume <= 0:
                self._api.wait_update()
            # Wait until the local wall clock sits on a 5-minute boundary, outside
            # hours 8 and 20.
            # NOTE(review): uses naive local time, so the result depends on the host
            # time zone; ``second >= 0`` is always true.
            while True:
                now = datetime.now()
                minute = now.minute
                second = now.second
                hour = now.hour
                if (minute % 5 == 0) and (second >= 0) and hour != 8 and hour != 20:
                    break
                # Original comment: "short sleep to avoid spinning the CPU" — the wait
                # is actually a blocking wait_update().
                self._api.wait_update()
            # Only treat it as a new bar if its time of day differs from the last
            # completed bar's.
            if kline_dt.hour != self.all_bars[-1].datetime.hour or kline_dt.minute != self.all_bars[
                -1].datetime.minute:
                print(
                    f"TqsdkEngine: 新k线产生, k line datetime:{kline_dt}, now: {datetime.now()}, open: {self.klines.iloc[-1].open}")
                self.kline_row = self.klines.iloc[-1]
                self.main(self.klines.iloc[-1], self.klines.iloc[-2])
                new_bar = True
def close_bar(self, kline_row):
    """
    Give the strategy a pre-close look at the bar that is about to finish.

    Builds a ``Bar`` from ``kline_row`` (timestamp converted to Beijing time, symbol
    taken from ``_last_underlying_symbol``), stores it as ``last_processed_bar`` and,
    while the strategy is trading, calls ``on_close_bar`` and then flushes the queued
    order/cancel requests. Nothing happens until at least one bar has been completed.

    Args:
        kline_row: K-line row exposing ``datetime`` (epoch nanoseconds), ``open``,
            ``high``, ``low``, ``close``, ``volume``, ``open_oi`` and ``close_oi``.
    """
    bar_time = pd.to_datetime(kline_row.datetime, unit="ns", utc=True).tz_convert(BEIJING_TZ)

    # Guard: no completed bars yet, so there is nothing to close.
    if len(self.all_bars) == 0:
        return

    field_names = ("open", "high", "low", "close", "volume", "open_oi", "close_oi")
    bar_fields = {name: getattr(kline_row, name) for name in field_names}
    closing_bar = Bar(
        datetime=bar_time,
        symbol=self._last_underlying_symbol,
        **bar_fields,
    )
    self.last_processed_bar = closing_bar

    # Guard: the strategy is only notified while it is trading.
    if self._strategy.trading is not True:
        return
    self._strategy.on_close_bar(closing_bar)
    # Handle order and cancel requests queued by the callback.
    self._process_queued_requests()
def main(self, kline_row, prev_kline_row):
    """
    Handle the opening of a new bar.

    Steps:
      1. If a bar was already open, record ``prev_kline_row`` as its completed form
         (appended to ``all_bars`` and to the per-field price lists).
      2. If rollover detection is on, the strategy is trading and the completed bar's
         symbol differs from ``_last_underlying_symbol``: flag a rollover, flatten all
         positions, cancel pending orders and call ``on_rollover(old, new)``.
         Otherwise call ``on_open_bar`` with the new bar's open price.
      3. While the strategy is trading, flush queued order/cancel requests.
      4. Start a new placeholder ``partial_bar`` stamped with the new bar's time and
         the quote's current ``underlying_symbol``.

    Args:
        kline_row: Row of the bar that is opening now.
        prev_kline_row: Row of the bar that has just completed.

    NOTE(review): step 2 is placed at function level (not inside the "bar was already
    open" branch) based on its own ``last_processed_bar is not None`` check — confirm
    against the canonical file.
    """
    kline_dt = pd.to_datetime(kline_row.datetime, unit="ns", utc=True).tz_convert(BEIJING_TZ)

    # The rollover flag describes the bar being opened now; without this reset it would
    # stay True for every bar after the first rollover.
    self._is_rollover_bar = False

    if self.partial_bar is not None:
        last_bar = Bar(
            datetime=pd.to_datetime(prev_kline_row.datetime, unit="ns", utc=True).tz_convert(BEIJING_TZ),
            symbol=self.partial_bar.symbol,
            open=prev_kline_row.open,
            high=prev_kline_row.high,
            low=prev_kline_row.low,
            close=prev_kline_row.close,
            volume=prev_kline_row.volume,
            open_oi=prev_kline_row.open_oi,
            close_oi=prev_kline_row.close_oi,
        )
        self.all_bars.append(last_bar)
        self.close_list.append(last_bar.close)
        self.open_list.append(last_bar.open)
        self.high_list.append(last_bar.high)
        self.low_list.append(last_bar.low)
        self.volume_list.append(last_bar.volume)
        self.last_processed_bar = last_bar

    rollover_detected = (
        self.roll_over_mode
        and self.last_processed_bar is not None
        and self._last_underlying_symbol != self.last_processed_bar.symbol
        and self._strategy.trading
    )
    if rollover_detected:
        self._is_rollover_bar = True
        # The old contract is the symbol of the bar that just completed;
        # _last_underlying_symbol already holds the new contract.
        old_symbol = self.last_processed_bar.symbol
        new_symbol = self._last_underlying_symbol
        print(
            f"TqsdkEngine: 检测到换月信号!从 {old_symbol} 切换到 {new_symbol}"
        )
        self._close_all_positions_at_end()
        self._strategy.cancel_all_pending_orders()
        self._strategy.on_rollover(old_symbol, new_symbol)
    else:
        self._strategy.on_open_bar(kline_row.open, self._last_underlying_symbol)

    # Handle order and cancel requests queued by the callbacks above.
    if self._strategy.trading is True:
        self._process_queued_requests()

    # Placeholder for the bar that is now forming; only its timestamp and symbol are
    # meaningful (the symbol is what the next call compares against for rollover).
    self.partial_bar = Bar(
        datetime=kline_dt,
        symbol=self.quote.underlying_symbol,
        open=0,
        high=0,
        low=0,
        close=0,
        volume=0,
        open_oi=0,
        close_oi=0,
    )
def run(self):
    """
    Run the engine's main loop and always close the API connection afterwards.

    A ``KeyboardInterrupt`` is treated as a normal stop request and is swallowed after
    printing a message; any other exception propagates to the caller once the API has
    been closed.
    """
    try:
        try:
            self._run_async()
        except KeyboardInterrupt:
            print("\n回测被用户中断。")
    finally:
        # Runs on every exit path, including unexpected errors.
        self._api.close()
        print("TqsdkEngine: API 已关闭。")
def get_results(self) -> Dict[str, Any]:
    """
    Collect the recorded run data for the result-analysis module.

    Returns:
        A dict with the keys ``"portfolio_snapshots"``, ``"trade_history"``,
        ``"all_bars"`` and ``"final_portfolio_value"``. The last one is the
        ``total_value`` of the most recent snapshot, or ``0.0`` when no snapshot has
        been recorded. The lists are the engine's own objects, not copies.
    """
    snapshots = self.portfolio_snapshots
    closing_value = snapshots[-1].total_value if snapshots else 0.0
    results = {
        "portfolio_snapshots": snapshots,
        "trade_history": self.trade_history,
        "all_bars": self.all_bars,
        "final_portfolio_value": closing_value,
    }
    return results
def get_bar_history(self):
    """Return the engine's list of completed bars (the live list, not a copy)."""
    completed_bars = self.all_bars
    return completed_bars
def get_price_history(self, key: str):
    """
    Return the recorded series for one bar field.

    Args:
        key: One of ``"close"``, ``"open"``, ``"high"``, ``"low"`` or ``"volume"``.

    Returns:
        The engine's own list for that field (not a copy), or ``None`` for any other key.
    """
    series_by_key = (
        ("close", "close_list"),
        ("open", "open_list"),
        ("high", "high_list"),
        ("low", "low_list"),
        ("volume", "volume_list"),
    )
    for field_name, attr_name in series_by_key:
        if key == field_name:
            return getattr(self, attr_name)
    return None