Files
NewQuant/src/strategies/base_strategy.py
2025-11-07 16:26:00 +08:00

241 lines
9.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# src/strategies/base_strategy.py
from abc import ABC, abstractmethod
from datetime import datetime
import math
from typing import Dict, Any, Optional, List, TYPE_CHECKING
import numpy as np
# 使用 TYPE_CHECKING 避免循环导入,但保留类型提示
from ..backtest_context import BacktestContext # 转发引用 BacktestEngine
from ..core_data import Bar, Order, Trade # 导入必要的类型
class Strategy(ABC):
"""
所有交易策略的抽象基类。
策略通过 context 对象与回测引擎和模拟器进行交互,并提供辅助方法。
"""
def __init__(
self,
context: "BacktestContext",
symbol: str,
enable_log: bool = True,
**params: Any,
):
"""
Args:
context (BacktestEngine): 回测引擎实例,作为策略的上下文,提供与模拟器等的交互接口。
symbol (str): 策略操作的合约Symbol。
**params (Any): 其他策略特定参数。
"""
self.context = context # 存储 context 对象
self.symbol = symbol # 策略操作的合约Symbol
self.main_symbol = symbol
self.params = params
self.enable_log = enable_log
self.trading = False
# 缓存指标用
self._indicator_cache = None # type: Optional[Tuple[np.ndarray, ...]]
self._cache_length = 0 # 上次缓存时数据长度
def on_init(self):
"""
策略初始化时调用(在回测开始前)。
可用于设置初始状态或打印信息。
"""
print(f"{self.__class__.__name__} 策略初始化回调被调用。")
def on_trade(self, trade: "Trade"):
"""
当模拟器成功执行一笔交易时调用。
可用于更新策略内部持仓状态或记录交易。
Args:
trade (Trade): 已完成的交易记录。
"""
# print(f"策略接收到交易: {trade.direction} {trade.volume} {trade.symbol} @ {trade.price:.2f}")
pass # 默认不执行任何操作,具体策略可覆盖
@abstractmethod
def on_open_bar(self, open: float, symbol: str):
"""
每当新的K线数据到来时调用此方法。
Args:
bar (Bar): 当前的K线数据对象。
next_bar_open (Optional[float]): 下一根K线的开盘价如果存在的话。
"""
pass
def on_close_bar(self, bar: "Bar"):
"""
每当新的K线数据到来时调用此方法。
Args:
bar (Bar): 当前的K线数据对象。
next_bar_close (Optional[float]): 下一根K线的开盘价如果存在的话。
"""
pass
def on_start_trading(self):
pass
# --- 新增/修改的辅助方法 ---
def send_order(self, order: "Order") -> Optional[Order]:
"""
发送订单的辅助方法。
会在 BaseStrategy 内部构建 Order 对象,并通过 context 转发给模拟器。
"""
if not self.trading:
return None
if self.context.is_rollover_bar:
self.log(f"当前是换月K线禁止开仓订单")
return None
if order.price_type == 'LIMIT':
limit_price = order.limit_price
if order.direction in ["BUY", "CLOSE_SHORT"]:
# 买入限价单(或平空),希望以更低或相等的价格成交,
# 所以向下取整,确保挂单价格不高于预期。
# 例如价格100.3tick_size=1 -> math.floor(100.3) = 100
# 价格100.8tick_size=1 -> math.floor(100.8) = 100
order.limit_price = math.floor(limit_price)
elif order.direction in ["SELL", "CLOSE_LONG"]:
# 卖出限价单(或平多),希望以更高或相等的价格成交,
# 所以向上取整,确保挂单价格不低于预期。
# 例如价格100.3tick_size=1 -> math.ceil(100.3) = 101
# 价格100.8tick_size=1 -> math.ceil(100.8) = 101
order.limit_price = math.ceil(limit_price)
return self.context.send_order(order)
def cancel_order(self, order_id: str) -> bool:
"""
取消指定ID的订单。
通过 context 调用模拟器的 cancel_order 方法。
"""
if not self.trading:
return False
return self.context.cancel_order(order_id)
def cancel_all_pending_orders(self, main_symbol = None) -> int:
"""取消当前策略的未决订单仅限于当前策略关注的Symbol。"""
# 注意:在换月模式下,引擎会自动取消旧合约的挂单,这里是策略主动取消
if not self.trading:
return 0
pending_orders = self.get_pending_orders()
cancelled_count = 0
# orders_to_cancel = [
# order.id for order in pending_orders.values() if order.symbol == self.symbol
# ]
if main_symbol is not None:
orders_to_cancel = [
order.id for order in pending_orders.values() if main_symbol in order.symbol
]
else:
orders_to_cancel = [
order.id for order in pending_orders.values()
]
for order_id in orders_to_cancel:
if self.cancel_order(order_id):
cancelled_count += 1
return cancelled_count
def get_current_positions(self) -> Dict[str, int]:
"""获取所有当前持仓 (可能包含多个合约)。"""
return self.context.get_current_positions()
def get_pending_orders(self) -> Dict[str, "Order"]:
"""获取所有当前待处理订单的副本 (可能包含多个合约)。"""
return self.context.get_pending_orders()
def get_average_position_price(self, symbol: str) -> Optional[float]:
"""获取指定合约的平均持仓成本。"""
return self.context.get_average_position_price(symbol)
def get_account_cash(self) -> float:
"""获取当前账户现金余额。"""
return self.context.cash
def get_current_time(self) -> datetime:
"""获取模拟器当前时间。"""
return self.context.get_current_time()
def log(self, *args: Any, **kwargs: Any):
"""
统一的日志打印方法。
如果 enable_log 为 True则打印消息到控制台并包含当前模拟时间。
支持传入多个参数,像 print() 函数一样使用。
"""
if self.enable_log:
# 尝试获取当前模拟时间,如果模拟器或时间不可用,则跳过时间前缀
try:
current_time_str = self.context.get_current_time().strftime(
"%Y-%m-%d %H:%M:%S"
)
time_prefix = f"[{current_time_str}] "
except AttributeError:
# 如果获取不到时间(例如在策略初始化时,模拟器时间还未设置),则不加时间前缀
time_prefix = ""
# 使用 f-string 结合 *args 来构建消息
# print() 函数会将 *args 自动用空格分隔,这里我们模仿这个行为
message = " ".join(map(str, args))
# 你可以将其他 kwargs (如 sep, end, file, flush) 传递给 print
# 但通常日志方法不会频繁使用这些。这里只支持最基础的打印。
print(f"{time_prefix}策略 ({self.symbol}): {message}")
def on_rollover(self, old_symbol: str, new_symbol: str):
"""
当回测的合约发生换月时调用此方法。
子类可以重写此方法来执行换月相关的逻辑(例如,调整目标仓位,清空历史数据)。
注意:在调用此方法前,引擎已强制平仓旧合约的所有仓位并取消所有挂单。
Args:
old_symbol (str): 旧的合约代码。
new_symbol (str): 新的合约代码。
"""
self.log(f"合约换月事件: 从 {old_symbol} 切换到 {new_symbol}")
# 默认实现可以为空,子类根据需要重写
pass
def get_bar_history(self):
return self.context.get_bar_history()
def get_price_history(self, key: str):
return self.context.get_price_history(key)
def get_indicator_tuple(self):
"""获取价格数据的 numpy 数组元组,带缓存功能。"""
close_data = self.get_price_history("close")
current_length = len(close_data)
# 如果长度没有变化,直接返回缓存
if self._indicator_cache is not None and current_length == self._cache_length:
return self._indicator_cache
# 数据有变化,重新创建数组并更新缓存
close = np.array(close_data[-1000:])
open_price = np.array(self.get_price_history("open")[-1000:])
high = np.array(self.get_price_history("high")[-1000:])
low = np.array(self.get_price_history("low")[-1000:])
volume = np.array(self.get_price_history("volume")[-1000:])
self._indicator_cache = (close, open_price, high, low, volume)
self._cache_length = current_length
return self._indicator_cache
def save_state(self, state: Any) -> None:
if self.trading:
self.context.save_state(state)
def load_state(self) -> None:
if self.trading:
self.context.load_state()