2025-06-18 10:25:05 +08:00
|
|
|
|
# src/strategies/base_strategy.py
|
|
|
|
|
|
|
|
|
|
|
|
from abc import ABC, abstractmethod
|
2025-06-22 23:03:50 +08:00
|
|
|
|
from datetime import datetime
|
|
|
|
|
|
from typing import Dict, Any, Optional, List, TYPE_CHECKING
|
|
|
|
|
|
|
|
|
|
|
|
# 使用 TYPE_CHECKING 避免循环导入,但保留类型提示
|
|
|
|
|
|
from ..backtest_context import BacktestContext # 转发引用 BacktestEngine
|
|
|
|
|
|
from ..core_data import Bar, Order, Trade # 导入必要的类型
|
2025-06-18 10:25:05 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Strategy(ABC):
|
|
|
|
|
|
"""
|
2025-06-22 23:03:50 +08:00
|
|
|
|
所有交易策略的抽象基类。
|
|
|
|
|
|
策略通过 context 对象与回测引擎和模拟器进行交互,并提供辅助方法。
|
2025-06-18 10:25:05 +08:00
|
|
|
|
"""
|
|
|
|
|
|
|
2025-06-22 23:03:50 +08:00
|
|
|
|
def __init__(self, context: 'BacktestContext', symbol: str, enable_log: bool = True, **params: Any):
|
2025-06-18 10:25:05 +08:00
|
|
|
|
"""
|
|
|
|
|
|
Args:
|
2025-06-22 23:03:50 +08:00
|
|
|
|
context (BacktestEngine): 回测引擎实例,作为策略的上下文,提供与模拟器等的交互接口。
|
|
|
|
|
|
symbol (str): 策略操作的合约Symbol。
|
|
|
|
|
|
**params (Any): 其他策略特定参数。
|
2025-06-18 10:25:05 +08:00
|
|
|
|
"""
|
2025-06-22 23:03:50 +08:00
|
|
|
|
self.context = context # 存储 context 对象
|
|
|
|
|
|
self.symbol = symbol # 策略操作的合约Symbol
|
|
|
|
|
|
self.params = params
|
|
|
|
|
|
self.enable_log = enable_log
|
2025-06-18 10:25:05 +08:00
|
|
|
|
|
|
|
|
|
|
def on_init(self):
|
|
|
|
|
|
"""
|
|
|
|
|
|
策略初始化时调用(在回测开始前)。
|
|
|
|
|
|
可用于设置初始状态或打印信息。
|
|
|
|
|
|
"""
|
|
|
|
|
|
print(f"{self.__class__.__name__} 策略初始化回调被调用。")
|
|
|
|
|
|
|
2025-06-22 23:03:50 +08:00
|
|
|
|
def on_trade(self, trade: 'Trade'):
|
2025-06-18 10:25:05 +08:00
|
|
|
|
"""
|
|
|
|
|
|
当模拟器成功执行一笔交易时调用。
|
|
|
|
|
|
可用于更新策略内部持仓状态或记录交易。
|
|
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
|
trade (Trade): 已完成的交易记录。
|
|
|
|
|
|
"""
|
|
|
|
|
|
# print(f"策略接收到交易: {trade.direction} {trade.volume} {trade.symbol} @ {trade.price:.2f}")
|
2025-06-22 23:03:50 +08:00
|
|
|
|
pass # 默认不执行任何操作,具体策略可覆盖
|
2025-06-18 10:25:05 +08:00
|
|
|
|
|
2025-06-22 23:03:50 +08:00
|
|
|
|
@abstractmethod
|
|
|
|
|
|
def on_bar(self, bar: 'Bar'):
|
2025-06-18 10:25:05 +08:00
|
|
|
|
"""
|
2025-06-22 23:03:50 +08:00
|
|
|
|
每当新的K线数据到来时调用此方法。
|
2025-06-18 10:25:05 +08:00
|
|
|
|
Args:
|
2025-06-22 23:03:50 +08:00
|
|
|
|
bar (Bar): 当前的K线数据对象。
|
|
|
|
|
|
next_bar_open (Optional[float]): 下一根K线的开盘价,如果存在的话。
|
|
|
|
|
|
"""
|
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
|
|
# --- 新增/修改的辅助方法 ---
|
|
|
|
|
|
|
|
|
|
|
|
def send_order(self, order: 'Order') -> Optional[Trade]:
|
|
|
|
|
|
"""
|
|
|
|
|
|
发送订单的辅助方法。
|
|
|
|
|
|
会在 BaseStrategy 内部构建 Order 对象,并通过 context 转发给模拟器。
|
|
|
|
|
|
"""
|
|
|
|
|
|
return self.context.send_order(order)
|
|
|
|
|
|
|
|
|
|
|
|
def cancel_order(self, order_id: str) -> bool:
|
|
|
|
|
|
"""
|
|
|
|
|
|
取消指定ID的订单。
|
|
|
|
|
|
通过 context 调用模拟器的 cancel_order 方法。
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
return self.context.cancel_order(order_id)
|
|
|
|
|
|
|
|
|
|
|
|
def cancel_all_pending_orders(self) -> int:
|
|
|
|
|
|
"""
|
|
|
|
|
|
取消所有当前策略的未决订单。
|
|
|
|
|
|
返回成功取消的订单数量。
|
|
|
|
|
|
"""
|
|
|
|
|
|
pending_orders = self.get_pending_orders() # 调用 BaseStrategy 自己的 get_pending_orders
|
|
|
|
|
|
cancelled_count = 0
|
|
|
|
|
|
orders_to_cancel = [order.id for order in pending_orders.values() if order.symbol == self.symbol]
|
|
|
|
|
|
for order_id in orders_to_cancel:
|
|
|
|
|
|
if self.cancel_order(order_id): # 调用 BaseStrategy 自己的 cancel_order
|
|
|
|
|
|
cancelled_count += 1
|
|
|
|
|
|
return cancelled_count
|
|
|
|
|
|
|
|
|
|
|
|
def get_current_positions(self) -> Dict[str, int]:
|
2025-06-18 10:25:05 +08:00
|
|
|
|
"""
|
2025-06-22 23:03:50 +08:00
|
|
|
|
获取当前持仓。
|
|
|
|
|
|
通过 context 调用模拟器的 get_positions 方法。
|
|
|
|
|
|
"""
|
|
|
|
|
|
return self.context._simulator.get_current_positions()
|
|
|
|
|
|
|
|
|
|
|
|
def get_pending_orders(self) -> Dict[str, 'Order']:
|
|
|
|
|
|
"""
|
|
|
|
|
|
获取当前所有待处理订单的副本。
|
|
|
|
|
|
通过 context 调用模拟器的 get_pending_orders 方法。
|
|
|
|
|
|
"""
|
|
|
|
|
|
return self.context._simulator.get_pending_orders()
|
|
|
|
|
|
|
|
|
|
|
|
def get_average_position_price(self, symbol: str) -> Optional[float]:
|
|
|
|
|
|
"""
|
|
|
|
|
|
获取指定合约的平均持仓成本。
|
|
|
|
|
|
通过 context 调用模拟器的 get_average_position_price 方法。
|
|
|
|
|
|
"""
|
|
|
|
|
|
return self.context._simulator.get_average_position_price(symbol)
|
|
|
|
|
|
|
|
|
|
|
|
# 你可以根据需要在这里添加更多辅助方法,如获取账户净值等
|
|
|
|
|
|
def get_account_cash(self) -> float:
|
|
|
|
|
|
"""获取当前账户现金余额。"""
|
|
|
|
|
|
return self.context._simulator.cash
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def log(self, *args: Any, **kwargs: Any):
|
|
|
|
|
|
"""
|
|
|
|
|
|
统一的日志打印方法。
|
|
|
|
|
|
如果 enable_log 为 True,则打印消息到控制台,并包含当前模拟时间。
|
|
|
|
|
|
支持传入多个参数,像 print() 函数一样使用。
|
|
|
|
|
|
"""
|
|
|
|
|
|
if self.enable_log:
|
|
|
|
|
|
# 尝试获取当前模拟时间,如果模拟器或时间不可用,则跳过时间前缀
|
|
|
|
|
|
try:
|
|
|
|
|
|
current_time_str = self.context._simulator.get_current_time().strftime('%Y-%m-%d %H:%M:%S')
|
|
|
|
|
|
time_prefix = f"[{current_time_str}] "
|
|
|
|
|
|
except AttributeError:
|
|
|
|
|
|
# 如果获取不到时间(例如在策略初始化时,模拟器时间还未设置),则不加时间前缀
|
|
|
|
|
|
time_prefix = ""
|
|
|
|
|
|
|
|
|
|
|
|
# 使用 f-string 结合 *args 来构建消息
|
|
|
|
|
|
# print() 函数会将 *args 自动用空格分隔,这里我们模仿这个行为
|
|
|
|
|
|
message = ' '.join(map(str, args))
|
|
|
|
|
|
|
|
|
|
|
|
# 你可以将其他 kwargs (如 sep, end, file, flush) 传递给 print,
|
|
|
|
|
|
# 但通常日志方法不会频繁使用这些。这里只支持最基础的打印。
|
|
|
|
|
|
print(f"{time_prefix}策略 ({self.symbol}): {message}", **kwargs)
|