同步本地回测与tqsdk回测
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
import math
|
||||
from typing import Dict, Any, Optional, List, TYPE_CHECKING
|
||||
|
||||
# 使用 TYPE_CHECKING 避免循环导入,但保留类型提示
|
||||
@@ -52,7 +53,7 @@ class Strategy(ABC):
|
||||
pass # 默认不执行任何操作,具体策略可覆盖
|
||||
|
||||
@abstractmethod
|
||||
def on_bar(self, bar: "Bar"):
|
||||
def on_open_bar(self, bar: "Bar"):
|
||||
"""
|
||||
每当新的K线数据到来时调用此方法。
|
||||
Args:
|
||||
@@ -61,6 +62,15 @@ class Strategy(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
def on_close_bar(self, bar: "Bar"):
|
||||
"""
|
||||
每当新的K线数据到来时调用此方法。
|
||||
Args:
|
||||
bar (Bar): 当前的K线数据对象。
|
||||
next_bar_close (Optional[float]): 下一根K线的开盘价,如果存在的话。
|
||||
"""
|
||||
pass
|
||||
|
||||
# --- 新增/修改的辅助方法 ---
|
||||
|
||||
def send_order(self, order: "Order") -> Optional[Order]:
|
||||
@@ -71,6 +81,21 @@ class Strategy(ABC):
|
||||
if self.context.is_rollover_bar:
|
||||
self.log(f"当前是换月K线,禁止开仓订单")
|
||||
return None
|
||||
|
||||
if order.price_type == 'LIMIT':
|
||||
limit_price = order.limit_price
|
||||
if order.direction in ["BUY", "CLOSE_SHORT"]:
|
||||
# 买入限价单(或平空),希望以更低或相等的价格成交,
|
||||
# 所以向下取整,确保挂单价格不高于预期。
|
||||
# 例如:价格100.3,tick_size=1 -> math.floor(100.3) = 100
|
||||
# 价格100.8,tick_size=1 -> math.floor(100.8) = 100
|
||||
order.limit_price = math.floor(limit_price)
|
||||
elif order.direction in ["SELL", "CLOSE_LONG"]:
|
||||
# 卖出限价单(或平多),希望以更高或相等的价格成交,
|
||||
# 所以向上取整,确保挂单价格不低于预期。
|
||||
# 例如:价格100.3,tick_size=1 -> math.ceil(100.3) = 101
|
||||
# 价格100.8,tick_size=1 -> math.ceil(100.8) = 101
|
||||
order.limit_price = math.ceil(limit_price)
|
||||
return self.context.send_order(order)
|
||||
|
||||
def cancel_order(self, order_id: str) -> bool:
|
||||
@@ -96,23 +121,23 @@ class Strategy(ABC):
|
||||
|
||||
def get_current_positions(self) -> Dict[str, int]:
|
||||
"""获取所有当前持仓 (可能包含多个合约)。"""
|
||||
return self.context._simulator.get_current_positions()
|
||||
return self.context.get_current_positions()
|
||||
|
||||
def get_pending_orders(self) -> Dict[str, "Order"]:
|
||||
"""获取所有当前待处理订单的副本 (可能包含多个合约)。"""
|
||||
return self.context._simulator.get_pending_orders()
|
||||
return self.context.get_pending_orders()
|
||||
|
||||
def get_average_position_price(self, symbol: str) -> Optional[float]:
|
||||
"""获取指定合约的平均持仓成本。"""
|
||||
return self.context._simulator.get_average_position_price(symbol)
|
||||
return self.context.get_average_position_price(symbol)
|
||||
|
||||
def get_account_cash(self) -> float:
|
||||
"""获取当前账户现金余额。"""
|
||||
return self.context._simulator.cash
|
||||
return self.context.cash
|
||||
|
||||
def get_current_time(self) -> datetime:
|
||||
"""获取模拟器当前时间。"""
|
||||
return self.context._simulator.get_current_time()
|
||||
return self.context.get_current_time()
|
||||
|
||||
def log(self, *args: Any, **kwargs: Any):
|
||||
"""
|
||||
@@ -123,7 +148,7 @@ class Strategy(ABC):
|
||||
if self.enable_log:
|
||||
# 尝试获取当前模拟时间,如果模拟器或时间不可用,则跳过时间前缀
|
||||
try:
|
||||
current_time_str = self.context._simulator.get_current_time().strftime(
|
||||
current_time_str = self.context.get_current_time().strftime(
|
||||
"%Y-%m-%d %H:%M:%S"
|
||||
)
|
||||
time_prefix = f"[{current_time_str}] "
|
||||
@@ -151,3 +176,6 @@ class Strategy(ABC):
|
||||
self.log(f"合约换月事件: 从 {old_symbol} 切换到 {new_symbol}")
|
||||
# 默认实现可以为空,子类根据需要重写
|
||||
pass
|
||||
|
||||
def get_bar_history(self):
|
||||
return self.context.get_bar_history()
|
||||
|
||||
Reference in New Issue
Block a user