241 lines
9.6 KiB
Python
241 lines
9.6 KiB
Python
# src/strategies/base_strategy.py
|
||
|
||
from abc import ABC, abstractmethod
|
||
from datetime import datetime
|
||
import math
|
||
from typing import Dict, Any, Optional, List, TYPE_CHECKING
|
||
|
||
import numpy as np
|
||
|
||
# 使用 TYPE_CHECKING 避免循环导入,但保留类型提示
|
||
from ..backtest_context import BacktestContext # 转发引用 BacktestEngine
|
||
from ..core_data import Bar, Order, Trade # 导入必要的类型
|
||
|
||
|
||
class Strategy(ABC):
|
||
"""
|
||
所有交易策略的抽象基类。
|
||
策略通过 context 对象与回测引擎和模拟器进行交互,并提供辅助方法。
|
||
"""
|
||
|
||
def __init__(
|
||
self,
|
||
context: "BacktestContext",
|
||
symbol: str,
|
||
enable_log: bool = True,
|
||
**params: Any,
|
||
):
|
||
"""
|
||
Args:
|
||
context (BacktestEngine): 回测引擎实例,作为策略的上下文,提供与模拟器等的交互接口。
|
||
symbol (str): 策略操作的合约Symbol。
|
||
**params (Any): 其他策略特定参数。
|
||
"""
|
||
self.context = context # 存储 context 对象
|
||
self.symbol = symbol # 策略操作的合约Symbol
|
||
self.main_symbol = symbol
|
||
self.params = params
|
||
self.enable_log = enable_log
|
||
self.trading = False
|
||
# 缓存指标用
|
||
self._indicator_cache = None # type: Optional[Tuple[np.ndarray, ...]]
|
||
self._cache_length = 0 # 上次缓存时数据长度
|
||
|
||
def on_init(self):
|
||
"""
|
||
策略初始化时调用(在回测开始前)。
|
||
可用于设置初始状态或打印信息。
|
||
"""
|
||
print(f"{self.__class__.__name__} 策略初始化回调被调用。")
|
||
|
||
def on_trade(self, trade: "Trade"):
|
||
"""
|
||
当模拟器成功执行一笔交易时调用。
|
||
可用于更新策略内部持仓状态或记录交易。
|
||
|
||
Args:
|
||
trade (Trade): 已完成的交易记录。
|
||
"""
|
||
# print(f"策略接收到交易: {trade.direction} {trade.volume} {trade.symbol} @ {trade.price:.2f}")
|
||
pass # 默认不执行任何操作,具体策略可覆盖
|
||
|
||
@abstractmethod
|
||
def on_open_bar(self, open: float, symbol: str):
|
||
"""
|
||
每当新的K线数据到来时调用此方法。
|
||
Args:
|
||
bar (Bar): 当前的K线数据对象。
|
||
next_bar_open (Optional[float]): 下一根K线的开盘价,如果存在的话。
|
||
"""
|
||
pass
|
||
|
||
def on_close_bar(self, bar: "Bar"):
|
||
"""
|
||
每当新的K线数据到来时调用此方法。
|
||
Args:
|
||
bar (Bar): 当前的K线数据对象。
|
||
next_bar_close (Optional[float]): 下一根K线的开盘价,如果存在的话。
|
||
"""
|
||
pass
|
||
|
||
def on_start_trading(self):
|
||
pass
|
||
|
||
# --- 新增/修改的辅助方法 ---
|
||
|
||
def send_order(self, order: "Order") -> Optional[Order]:
|
||
"""
|
||
发送订单的辅助方法。
|
||
会在 BaseStrategy 内部构建 Order 对象,并通过 context 转发给模拟器。
|
||
"""
|
||
if not self.trading:
|
||
return None
|
||
if self.context.is_rollover_bar:
|
||
self.log(f"当前是换月K线,禁止开仓订单")
|
||
return None
|
||
|
||
if order.price_type == 'LIMIT':
|
||
limit_price = order.limit_price
|
||
if order.direction in ["BUY", "CLOSE_SHORT"]:
|
||
# 买入限价单(或平空),希望以更低或相等的价格成交,
|
||
# 所以向下取整,确保挂单价格不高于预期。
|
||
# 例如:价格100.3,tick_size=1 -> math.floor(100.3) = 100
|
||
# 价格100.8,tick_size=1 -> math.floor(100.8) = 100
|
||
order.limit_price = math.floor(limit_price)
|
||
elif order.direction in ["SELL", "CLOSE_LONG"]:
|
||
# 卖出限价单(或平多),希望以更高或相等的价格成交,
|
||
# 所以向上取整,确保挂单价格不低于预期。
|
||
# 例如:价格100.3,tick_size=1 -> math.ceil(100.3) = 101
|
||
# 价格100.8,tick_size=1 -> math.ceil(100.8) = 101
|
||
order.limit_price = math.ceil(limit_price)
|
||
return self.context.send_order(order)
|
||
|
||
def cancel_order(self, order_id: str) -> bool:
|
||
"""
|
||
取消指定ID的订单。
|
||
通过 context 调用模拟器的 cancel_order 方法。
|
||
"""
|
||
if not self.trading:
|
||
return False
|
||
|
||
return self.context.cancel_order(order_id)
|
||
|
||
|
||
def cancel_all_pending_orders(self, main_symbol = None) -> int:
|
||
"""取消当前策略的未决订单,仅限于当前策略关注的Symbol。"""
|
||
# 注意:在换月模式下,引擎会自动取消旧合约的挂单,这里是策略主动取消
|
||
if not self.trading:
|
||
return 0
|
||
|
||
pending_orders = self.get_pending_orders()
|
||
cancelled_count = 0
|
||
# orders_to_cancel = [
|
||
# order.id for order in pending_orders.values() if order.symbol == self.symbol
|
||
# ]
|
||
if main_symbol is not None:
|
||
orders_to_cancel = [
|
||
order.id for order in pending_orders.values() if main_symbol in order.symbol
|
||
]
|
||
else:
|
||
orders_to_cancel = [
|
||
order.id for order in pending_orders.values()
|
||
]
|
||
for order_id in orders_to_cancel:
|
||
if self.cancel_order(order_id):
|
||
cancelled_count += 1
|
||
return cancelled_count
|
||
|
||
def get_current_positions(self) -> Dict[str, int]:
|
||
"""获取所有当前持仓 (可能包含多个合约)。"""
|
||
return self.context.get_current_positions()
|
||
|
||
def get_pending_orders(self) -> Dict[str, "Order"]:
|
||
"""获取所有当前待处理订单的副本 (可能包含多个合约)。"""
|
||
return self.context.get_pending_orders()
|
||
|
||
def get_average_position_price(self, symbol: str) -> Optional[float]:
|
||
"""获取指定合约的平均持仓成本。"""
|
||
return self.context.get_average_position_price(symbol)
|
||
|
||
def get_account_cash(self) -> float:
|
||
"""获取当前账户现金余额。"""
|
||
return self.context.cash
|
||
|
||
def get_current_time(self) -> datetime:
|
||
"""获取模拟器当前时间。"""
|
||
return self.context.get_current_time()
|
||
|
||
def log(self, *args: Any, **kwargs: Any):
|
||
"""
|
||
统一的日志打印方法。
|
||
如果 enable_log 为 True,则打印消息到控制台,并包含当前模拟时间。
|
||
支持传入多个参数,像 print() 函数一样使用。
|
||
"""
|
||
if self.enable_log:
|
||
# 尝试获取当前模拟时间,如果模拟器或时间不可用,则跳过时间前缀
|
||
try:
|
||
current_time_str = self.context.get_current_time().strftime(
|
||
"%Y-%m-%d %H:%M:%S"
|
||
)
|
||
time_prefix = f"[{current_time_str}] "
|
||
except AttributeError:
|
||
# 如果获取不到时间(例如在策略初始化时,模拟器时间还未设置),则不加时间前缀
|
||
time_prefix = ""
|
||
|
||
# 使用 f-string 结合 *args 来构建消息
|
||
# print() 函数会将 *args 自动用空格分隔,这里我们模仿这个行为
|
||
message = " ".join(map(str, args))
|
||
|
||
# 你可以将其他 kwargs (如 sep, end, file, flush) 传递给 print,
|
||
# 但通常日志方法不会频繁使用这些。这里只支持最基础的打印。
|
||
print(f"{time_prefix}策略 ({self.symbol}): {message}")
|
||
|
||
def on_rollover(self, old_symbol: str, new_symbol: str):
|
||
"""
|
||
当回测的合约发生换月时调用此方法。
|
||
子类可以重写此方法来执行换月相关的逻辑(例如,调整目标仓位,清空历史数据)。
|
||
注意:在调用此方法前,引擎已强制平仓旧合约的所有仓位并取消所有挂单。
|
||
Args:
|
||
old_symbol (str): 旧的合约代码。
|
||
new_symbol (str): 新的合约代码。
|
||
"""
|
||
self.log(f"合约换月事件: 从 {old_symbol} 切换到 {new_symbol}")
|
||
# 默认实现可以为空,子类根据需要重写
|
||
pass
|
||
|
||
def get_bar_history(self):
|
||
return self.context.get_bar_history()
|
||
|
||
|
||
def get_price_history(self, key: str):
|
||
return self.context.get_price_history(key)
|
||
|
||
def get_indicator_tuple(self):
|
||
"""获取价格数据的 numpy 数组元组,带缓存功能。"""
|
||
close_data = self.get_price_history("close")
|
||
current_length = len(close_data)
|
||
|
||
# 如果长度没有变化,直接返回缓存
|
||
if self._indicator_cache is not None and current_length == self._cache_length:
|
||
return self._indicator_cache
|
||
|
||
# 数据有变化,重新创建数组并更新缓存
|
||
close = np.array(close_data[-1000:])
|
||
open_price = np.array(self.get_price_history("open")[-1000:])
|
||
high = np.array(self.get_price_history("high")[-1000:])
|
||
low = np.array(self.get_price_history("low")[-1000:])
|
||
volume = np.array(self.get_price_history("volume")[-1000:])
|
||
|
||
self._indicator_cache = (close, open_price, high, low, volume)
|
||
self._cache_length = current_length
|
||
|
||
return self._indicator_cache
|
||
|
||
def save_state(self, state: Any) -> None:
|
||
if self.trading:
|
||
self.context.save_state(state)
|
||
|
||
def load_state(self) -> None:
|
||
if self.trading:
|
||
self.context.load_state()
|