1、卡尔曼策略
This commit is contained in:
@@ -1,11 +1,13 @@
|
||||
# filename: tqsdk_engine.py
|
||||
|
||||
import asyncio
|
||||
import traceback
|
||||
from datetime import date, datetime, timedelta
|
||||
from typing import Literal, Type, Dict, Any, List, Optional
|
||||
import pandas as pd
|
||||
import uuid
|
||||
|
||||
from src.common_utils import generate_strategy_identifier
|
||||
# 导入你提供的 core_data 中的类型
|
||||
from src.core_data import Bar, Order, Trade, PortfolioSnapshot
|
||||
|
||||
@@ -22,6 +24,7 @@ from tqsdk import (
|
||||
BacktestFinished,
|
||||
)
|
||||
|
||||
from src.state_repo import MemoryStateRepository
|
||||
# 导入 TqsdkContext 和 BaseStrategy
|
||||
from src.tqsdk_context import TqsdkContext
|
||||
from src.strategies.base_strategy import Strategy # 假设你的策略基类在此路径
|
||||
@@ -36,15 +39,15 @@ class TqsdkEngine:
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
strategy_class: Type[Strategy],
|
||||
strategy_params: Dict[str, Any],
|
||||
api: TqApi,
|
||||
roll_over_mode: bool = False, # 是否开启换月模式检测
|
||||
symbol: str = None,
|
||||
start_time: Optional[datetime] = None,
|
||||
end_time: Optional[datetime] = None,
|
||||
duration_seconds: int = 1,
|
||||
self,
|
||||
strategy_class: Type[Strategy],
|
||||
strategy_params: Dict[str, Any],
|
||||
api: TqApi,
|
||||
roll_over_mode: bool = False, # 是否开启换月模式检测
|
||||
symbol: str = None,
|
||||
start_time: Optional[datetime] = None,
|
||||
end_time: Optional[datetime] = None,
|
||||
duration_seconds: int = 1,
|
||||
):
|
||||
"""
|
||||
初始化 Tqsdk 回测引擎。
|
||||
@@ -83,7 +86,8 @@ class TqsdkEngine:
|
||||
# )
|
||||
|
||||
# 初始化上下文
|
||||
self._context: TqsdkContext = TqsdkContext(api=self._api)
|
||||
identifier = generate_strategy_identifier(strategy_class, strategy_params)
|
||||
self._context: TqsdkContext = TqsdkContext(api=self._api, state_repository=MemoryStateRepository(identifier))
|
||||
# 实例化策略,并将上下文传递给它
|
||||
self._strategy: Strategy = self.strategy_class(
|
||||
context=self._context, **self.strategy_params
|
||||
@@ -113,6 +117,8 @@ class TqsdkEngine:
|
||||
if roll_over_mode:
|
||||
self.quote = api.get_quote(self.symbol)
|
||||
|
||||
self.target_pos_dict = {}
|
||||
|
||||
print("TqsdkEngine: 初始化完成。")
|
||||
|
||||
@property
|
||||
@@ -127,7 +133,6 @@ class TqsdkEngine:
|
||||
异步处理 Context 中排队的订单和取消请求。
|
||||
"""
|
||||
# 处理订单
|
||||
print(self._context.order_queue)
|
||||
while self._context.order_queue:
|
||||
order_to_send: Order = self._context.order_queue.popleft()
|
||||
print(f"Engine: 处理订单请求: {order_to_send}")
|
||||
@@ -155,7 +160,23 @@ class TqsdkEngine:
|
||||
if "SHFE" in order_to_send.symbol:
|
||||
tqsdk_offset = "OPEN"
|
||||
|
||||
try:
|
||||
if "CLOSE" in order_to_send.direction:
|
||||
current_positions = self._context.get_current_positions()
|
||||
current_pos_volume = current_positions.get(order_to_send.symbol, 0)
|
||||
|
||||
target_volume = None
|
||||
if order_to_send.direction == 'CLOSE_LONG':
|
||||
target_volume = current_pos_volume - order_to_send.volume
|
||||
elif order_to_send.direction == 'CLOSE_SHORT':
|
||||
target_volume = current_pos_volume + order_to_send.volume
|
||||
|
||||
if target_volume is not None:
|
||||
if order_to_send.symbol not in self.target_pos_dict:
|
||||
self.target_pos_dict[order_to_send.symbol] = TargetPosTask(self._api, order_to_send.symbol)
|
||||
|
||||
self.target_pos_dict[order_to_send.symbol].set_target_volume(target_volume)
|
||||
else:
|
||||
# try:
|
||||
tq_order = self._api.insert_order(
|
||||
symbol=order_to_send.symbol,
|
||||
direction=tqsdk_direction,
|
||||
@@ -181,42 +202,10 @@ class TqsdkEngine:
|
||||
tq_order.insert_date_time, unit="ns", utc=True
|
||||
)
|
||||
|
||||
# 等待订单状态更新(成交/撤销/报错)
|
||||
# 在 Tqsdk 中,订单和成交是独立的,通常在 wait_update() 循环中通过 api.is_changing() 检查
|
||||
# 这里为了模拟同步处理,直接等待订单状态最终确定
|
||||
# 注意:实际回测中,不应在这里长时间阻塞,而应在主循环中持续 wait_update
|
||||
# 为了简化适配,这里模拟即时处理,但可能与真实异步行为有差异。
|
||||
# 更健壮的方式是在主循环中通过订单状态回调更新
|
||||
# 这里我们假设订单会很快更新状态,或者在下一个 wait_update() 周期中被检测到
|
||||
self._api.wait_update() # 等待一次更新
|
||||
|
||||
# # 检查最终订单状态和成交
|
||||
# if tq_order.status == "FINISHED":
|
||||
# # 查找对应的成交记录
|
||||
# for trade_id, tq_trade in self._api.get_trade().items():
|
||||
# if tq_trade.order_id == tq_order.order_id and tq_trade.volume > 0: # 确保是实际成交
|
||||
# # 创建 core_data.Trade 对象
|
||||
# trade = Trade(
|
||||
# order_id=tq_trade.order_id,
|
||||
# fill_time=tafunc.get_datetime_from_timestamp(tq_trade.trade_date_time) if tq_trade.trade_date_time else datetime.now(),
|
||||
# symbol=order_to_send.symbol, # 使用 Context 中的 symbol
|
||||
# direction=tq_trade.direction, # 实际成交方向
|
||||
# volume=tq_trade.volume,
|
||||
# price=tq_trade.price,
|
||||
# commission=tq_trade.commission,
|
||||
# cash_after_trade=self._api.get_account().available,
|
||||
# positions_after_trade=self._context.get_current_positions(),
|
||||
# realized_pnl=tq_trade.realized_pnl, # Tqsdk TqTrade 对象有 realized_pnl
|
||||
# is_open_trade=tq_trade.offset == "OPEN",
|
||||
# is_close_trade=tq_trade.offset in ["CLOSE", "CLOSETODAY", "CLOSEYESTERDAY"]
|
||||
# )
|
||||
# self.trade_history.append(trade)
|
||||
# print(f"Engine: 成交记录: {trade}")
|
||||
# break # 找到成交就跳出
|
||||
# order_to_send.status = tq_order.status # 更新最终状态
|
||||
except Exception as e:
|
||||
print(f"Engine: 发送订单 {order_to_send.id} 失败: {e}")
|
||||
# order_to_send.status = "ERROR"
|
||||
#
|
||||
# except Exception as e:
|
||||
# print(f"Engine: 发送订单 {order_to_send.id} 失败: {e}")
|
||||
|
||||
# 处理取消请求
|
||||
while self._context.cancel_queue:
|
||||
@@ -254,7 +243,7 @@ class TqsdkEngine:
|
||||
price = quote.last_price
|
||||
current_prices[symbol] = price
|
||||
total_market_value += (
|
||||
price * qty * quote.volume_multiple
|
||||
price * qty * quote.volume_multiple
|
||||
) # volume_multiple 乘数
|
||||
else:
|
||||
# 如果没有最新价格,使用最近的K线收盘价作为估算
|
||||
@@ -267,11 +256,11 @@ class TqsdkEngine:
|
||||
price = last_kline.close
|
||||
current_prices[symbol] = price
|
||||
total_market_value += (
|
||||
price * qty * self._api.get_instrument(symbol).volume_multiple
|
||||
price * qty * self._api.get_instrument(symbol).volume_multiple
|
||||
) # 使用 instrument 的乘数
|
||||
|
||||
total_value = (
|
||||
account.available + account.frozen_margin + total_market_value
|
||||
account.available + account.frozen_margin + total_market_value
|
||||
) # Tqsdk 的 balance 已包含持仓市值和冻结资金
|
||||
# Tqsdk 的 total_profit/balance 已经包含了所有盈亏和资金
|
||||
|
||||
@@ -344,8 +333,8 @@ class TqsdkEngine:
|
||||
self._api.wait_update()
|
||||
|
||||
if self.roll_over_mode and (
|
||||
self._api.is_changing(self.quote, "underlying_symbol")
|
||||
or self._last_underlying_symbol != self.quote.underlying_symbol
|
||||
self._api.is_changing(self.quote, "underlying_symbol")
|
||||
or self._last_underlying_symbol != self.quote.underlying_symbol
|
||||
):
|
||||
self._last_underlying_symbol = self.quote.underlying_symbol
|
||||
|
||||
@@ -355,9 +344,9 @@ class TqsdkEngine:
|
||||
now_dt = now_dt.tz_convert(BEIJING_TZ)
|
||||
|
||||
if (
|
||||
self.now is not None
|
||||
and self.now.hour != 13
|
||||
and now_dt.hour == 13
|
||||
self.now is not None
|
||||
and self.now.hour != 13
|
||||
and now_dt.hour == 13
|
||||
):
|
||||
self.main()
|
||||
|
||||
@@ -400,8 +389,8 @@ class TqsdkEngine:
|
||||
)
|
||||
|
||||
if (
|
||||
self.last_processed_bar is None
|
||||
or self.last_processed_bar.datetime != kline_dt
|
||||
self.last_processed_bar is None
|
||||
or self.last_processed_bar.datetime != kline_dt
|
||||
):
|
||||
|
||||
# 设置当前 Bar 到 Context
|
||||
@@ -410,9 +399,9 @@ class TqsdkEngine:
|
||||
# Tqsdk 的 is_changing 用于判断数据是否有变化,对于回测遍历 K 线,每次迭代都算作新 Bar
|
||||
# 如果 kline_row.datetime 与上次不同,则认为是新 Bar
|
||||
if (
|
||||
self.roll_over_mode
|
||||
and self.last_processed_bar is not None
|
||||
and self._last_underlying_symbol != self.last_processed_bar.symbol
|
||||
self.roll_over_mode
|
||||
and self.last_processed_bar is not None
|
||||
and self._last_underlying_symbol != self.last_processed_bar.symbol
|
||||
):
|
||||
self._is_rollover_bar = True
|
||||
print(
|
||||
@@ -439,6 +428,8 @@ class TqsdkEngine:
|
||||
# 记录投资组合快照
|
||||
self._record_portfolio_snapshot(current_bar.datetime)
|
||||
else:
|
||||
if current_bar.volume == 0:
|
||||
return
|
||||
self.all_bars.append(current_bar)
|
||||
|
||||
self.close_list.append(current_bar.close)
|
||||
|
||||
Reference in New Issue
Block a user