# File: NewQuant/futures_trading_strategies/SA/TrendlineBreakoutStrategy/AtrVol/TrendlineHawkesStrategy.py
# Snapshot: 2025-10-05 00:09:59 +08:00 (284 lines, 14 KiB, Python)
# NOTE: the repository viewer flagged this file for containing Unicode characters
# that might be confused with other characters (presumably the Chinese log
# messages and their punctuation). Review before normalising any of that text,
# because it is emitted at runtime.
import numpy as np
import pandas as pd
from typing import Optional, Dict, Any, List, Union
import talib # <-- 【新增】导入talib库
from src.core_data import Bar, Order
from src.indicators.base_indicators import Indicator
from src.indicators.indicators import Empty
from src.strategies.base_strategy import Strategy
from src.algo.TrendLine import calculate_latest_trendline_values
import numpy as np
import pandas as pd
from typing import Optional, Dict, Any, List, Union
import talib
class TrendlineHawkesStrategy(Strategy):
    """
    Trendline + Hawkes-process double-confirmation strategy (V8, O(1) rolling statistics).

    - The rolling Z-score of bar volume is maintained incrementally with a
      fixed-size circular buffer plus a running sum and sum of squares, so each
      mean / standard-deviation update costs O(1) instead of O(N).
    - The running sums are recomputed exactly from the buffer once per full pass
      over the window. This bounds floating-point drift from the add/subtract
      updates while keeping the amortised cost O(1).
    - A Hawkes-style intensity h_t = h_{t-1} * exp(-kappa) + z_t (z_t = volume
      Z-score) is tracked. The scaled value h_t * kappa is kept in a rolling
      window of length ``hawkes_lookback``. Quantiles of that window gate
      entries (``hawkes_entry_percent``) and exits (``hawkes_exit_percent``).
    - Entries additionally require the last completed bar's close to cross a
      trendline computed from the ``trendline_n`` closes before it, plus the
      matching indicator condition. An optional ATR stop is attached to every
      entry.
    """

    def __init__(
        self,
        context: Any,
        main_symbol: str,
        # --- All parameters are identical to V7 ---
        trade_volume: int = 1,
        order_direction: Optional[List[str]] = None,
        reverse_logic: bool = False,
        trendline_n: int = 50,
        hawkes_kappa: float = 0.1,
        hawkes_lookback: int = 50,
        hawkes_entry_percent: float = 0.95,
        hawkes_exit_percent: float = 0.25,
        volume_norm_n: int = 50,
        enable_atr_stop_loss: bool = True,
        atr_period: int = 14,
        atr_multiplier: float = 1.0,
        enable_log: bool = True,
        indicators: Optional[Union[Indicator, List[Indicator]]] = None,
    ):
        """
        Store the strategy parameters and allocate the incremental state.

        Args:
            context: Framework context, forwarded to the base ``Strategy``.
            main_symbol: Symbol traded by this strategy.
            trade_volume: Lots per entry order.
            order_direction: Allowed entry directions; defaults to ["BUY", "SELL"].
            reverse_logic: If True, trade against the trendline break.
            trendline_n: Number of closes used to compute the trendlines.
            hawkes_kappa: Decay rate; the per-bar decay factor is exp(-kappa) and
                stored intensities are scaled by kappa.
            hawkes_lookback: Length of the Hawkes-intensity window used for quantiles.
            hawkes_entry_percent: Quantile the intensity must exceed to enter.
            hawkes_exit_percent: Quantile below which an open position is closed.
            volume_norm_n: Length of the rolling window for the volume Z-score.
            enable_atr_stop_loss: Attach and enforce an ATR-based stop.
            atr_period: ATR period (TA-Lib ``timeperiod``).
            atr_multiplier: Stop distance in ATR multiples.
            enable_log: Forwarded to the base ``Strategy``.
            indicators: Indexed as [upper-break filter, lower-break filter];
                defaults to two ``Empty()`` indicators.

        Raises:
            ValueError: If ``volume_norm_n`` < 1 (the circular buffer would be empty).
        """
        if volume_norm_n < 1:
            raise ValueError(f"volume_norm_n must be >= 1, got {volume_norm_n}")
        super().__init__(context, main_symbol, enable_log)
        # --- Parameter assignment (same as V7) ---
        self.main_symbol = main_symbol
        self.trade_volume = trade_volume
        self.order_direction = order_direction or ["BUY", "SELL"]
        self.reverse_logic = reverse_logic
        self.trendline_n = trendline_n
        self.hawkes_kappa = hawkes_kappa
        self.hawkes_lookback = hawkes_lookback
        self.hawkes_entry_percent = hawkes_entry_percent
        self.hawkes_exit_percent = hawkes_exit_percent
        self.volume_norm_n = volume_norm_n
        self.enable_atr_stop_loss = enable_atr_stop_loss
        self.atr_period = atr_period
        self.atr_multiplier = atr_multiplier
        # symbol -> {"direction", "volume", "entry_price", "stop_loss_price"}
        self.pos_meta: Dict[str, Dict[str, Any]] = {}
        if indicators is None:
            indicators = [Empty(), Empty()]
        self.indicators = indicators
        # Hawkes state and O(1) rolling-volume state.
        self._reset_rolling_state()

    def _reset_rolling_state(self) -> None:
        """(Re)build the Hawkes and rolling-volume state from the current parameters."""
        # --- Hawkes-process state ---
        self._last_hawkes_unscaled: float = 0.0
        self._hawkes_window: np.ndarray = np.array([], dtype=np.float64)
        # Per-bar decay factor. It is recomputed here so that a hawkes_kappa changed
        # after construction takes effect on the next on_init().
        self._hawkes_alpha: float = float(np.exp(-self.hawkes_kappa))
        # --- O(1) rolling statistics ---
        # Pre-allocated circular buffer, sized from the *current* volume_norm_n.
        self._volume_window: np.ndarray = np.zeros(self.volume_norm_n, dtype=np.float64)
        self._volume_sum: float = 0.0  # sum of the elements in the window
        self._volume_sum_sq: float = 0.0  # sum of squares of the elements in the window
        self._volume_pointer: int = 0  # index of the oldest element (next slot to overwrite)
        self._is_volume_window_full: bool = False  # True once the pointer has wrapped once

    def on_init(self):
        """Reset position metadata and all incremental indicator state."""
        super().on_init()
        self.pos_meta.clear()
        self._reset_rolling_state()

    def _initialize_state(self, initial_volumes: np.ndarray) -> None:
        """
        Warm up all incremental state from historical volumes (called once).

        Each historical volume goes through the same O(1) rolling update used on
        live bars, so warm-up Z-scores match what live updates would produce.
        The Z-scores are then folded into the recursion h[i] = h[i-1] * alpha + z[i],
        with h[0] = z[0]. The last ``hawkes_lookback`` scaled values (h * kappa)
        seed the Hawkes window.

        Args:
            initial_volumes: 1-D array of historical bar volumes, oldest first.
        """
        self.log("首次运行,正在以增量方式初始化所有状态...")
        # 1. Fill the volume window incrementally, recording each bar's Z-score.
        normalized_volumes = np.empty(len(initial_volumes), dtype=np.float64)
        for i, vol in enumerate(initial_volumes):
            self._update_volume_stats_incrementally(vol)
            normalized_volumes[i] = self._volume_zscore(vol)
        # 2. Build the Hawkes history from the normalized volumes.
        self.log("正在基于标准化的交易量初始化霍克斯过程...")
        alpha = self._hawkes_alpha
        hawkes_history = np.empty_like(normalized_volumes)
        running = 0.0
        for i, z in enumerate(normalized_volumes):
            running = running * alpha + z
            hawkes_history[i] = running
        # 3. Record the final state.
        self._last_hawkes_unscaled = float(hawkes_history[-1]) if hawkes_history.size else 0.0
        self._hawkes_window = (hawkes_history * self.hawkes_kappa)[-self.hawkes_lookback:]
        self.log("状态初始化完成。")

    def _update_volume_stats_incrementally(self, latest_volume: float) -> None:
        """
        Push ``latest_volume`` into the circular buffer in O(1).

        The slot at the pointer is replaced and the running sum / sum of squares
        are adjusted by the difference. That slot holds the oldest element, or a
        0.0 placeholder while the window is still filling.

        When the pointer wraps, both sums are recomputed exactly from the buffer.
        This stops rounding error from the add/subtract updates accumulating
        without bound, and lets the sums recover once a non-finite value has
        left the window. The O(N) resync runs once every N updates, so it is
        amortised O(1).
        """
        window = self._volume_window
        ptr = self._volume_pointer
        oldest_volume = window[ptr]
        self._volume_sum += latest_volume - oldest_volume
        self._volume_sum_sq += latest_volume ** 2 - oldest_volume ** 2
        window[ptr] = latest_volume
        ptr += 1
        if ptr >= window.size:
            ptr = 0
            self._is_volume_window_full = True
            # Exact resync of the running sums (see docstring).
            self._volume_sum = float(window.sum())
            self._volume_sum_sq = float(np.dot(window, window))
        self._volume_pointer = ptr

    def _get_current_volume_stats(self) -> tuple:
        """
        Return ``(mean, std)`` of the rolling volume window in O(1).

        While the window is still filling, only the elements written so far are
        used. The population variance is E[x^2] - mean^2, clamped at 0 to absorb
        tiny negative values caused by floating-point rounding. Returns
        (0.0, 0.0) when no element has been written yet.
        """
        n = self._volume_window.size if self._is_volume_window_full else self._volume_pointer
        if n == 0:
            return 0.0, 0.0
        mean = self._volume_sum / n
        variance = max(0.0, (self._volume_sum_sq / n) - mean ** 2)
        return mean, np.sqrt(variance)

    def _volume_zscore(self, volume: float) -> float:
        """Z-score of ``volume`` against the current window; 0.0 when std <= 1e-9."""
        mean, std = self._get_current_volume_stats()
        if std > 1e-9:
            return (volume - mean) / std
        return 0.0

    def _update_state_incrementally(self, latest_volume: float) -> None:
        """
        Per-bar O(1) update: rolling volume stats -> Z-score -> Hawkes intensity.

        The unscaled intensity is h_t = h_{t-1} * exp(-kappa) + z_t, and h_t * kappa
        is appended to the Hawkes window. Once the window holds
        ``hawkes_lookback`` values, it behaves as a FIFO of that length.
        """
        # 1. O(1) update of the volume statistics, then the latest Z-score.
        self._update_volume_stats_incrementally(latest_volume)
        normalized_volume = self._volume_zscore(latest_volume)
        # 2. Hawkes recursion.
        new_hawkes_unscaled = self._last_hawkes_unscaled * self._hawkes_alpha + normalized_volume
        self._last_hawkes_unscaled = new_hawkes_unscaled
        new_hawkes_scaled = new_hawkes_unscaled * self.hawkes_kappa
        if self._hawkes_window.size < self.hawkes_lookback:
            self._hawkes_window = np.append(self._hawkes_window, new_hawkes_scaled)
        else:
            # Shift left by one in place (NumPy handles the overlapping copy),
            # instead of allocating a new array with np.roll on every bar.
            self._hawkes_window[:-1] = self._hawkes_window[1:]
            self._hawkes_window[-1] = new_hawkes_scaled

    def on_open_bar(self, open_price: float, symbol: str):
        """
        Per-bar handler: update indicator state, then manage exits or entries.

        ``bar_history[-1]`` is treated as the latest completed bar. ``open_price``
        is the reference price for the ATR stop and is logged as the
        approximate execution price of entries.

        Steps:
          1. Do nothing until the history covers every lookback (+2 bars).
          2. On the first eligible bar, warm up state from all but the last bar;
             then apply the O(1) update for the last bar.
          3. With an open position that has metadata, close it if the Hawkes
             intensity drops below the exit quantile or the last close breaches
             the ATR stop. If both fire, the stop message is the one logged.
          4. When flat, enter on a trendline break confirmed by the matching
             indicator and by an intensity above the entry quantile.
        """
        bar_history = self.get_bar_history()
        min_bars_required = max(
            self.trendline_n + 2,
            self.hawkes_lookback + 2,
            self.volume_norm_n + 2,
            self.atr_period + 2,
        )
        if len(bar_history) < min_bars_required:
            return

        # --- State update: one-off warm-up, then O(1) per bar ---
        if self._hawkes_window.size == 0:
            initial_volumes = np.array([b.volume for b in bar_history], dtype=float)
            self._initialize_state(initial_volumes[:-1])
        self._update_state_incrementally(float(bar_history[-1].volume))

        self.cancel_all_pending_orders(symbol)
        pos = self.get_current_positions().get(symbol, 0)
        latest_hawkes_value = self._hawkes_window[-1]
        latest_hawkes_lower = np.quantile(self._hawkes_window, self.hawkes_exit_percent)

        # --- Exit management ---
        # NOTE(review): pos_meta is written under self.symbol (send_open_order)
        # but read here under `symbol`. This assumes the two are the same.
        meta = self.pos_meta.get(symbol)
        if meta and pos != 0:
            close_reason = None
            if latest_hawkes_value < latest_hawkes_lower:
                close_reason = f"霍克斯出场信号(强度: {latest_hawkes_value:.4f} < 阈值: {latest_hawkes_lower:.4f})"
            stop_loss_price = meta.get('stop_loss_price')
            if self.enable_atr_stop_loss and stop_loss_price is not None:
                last_close = bar_history[-1].close
                long_stopped = meta['direction'] == "BUY" and last_close < stop_loss_price
                short_stopped = meta['direction'] == "SELL" and last_close > stop_loss_price
                if long_stopped or short_stopped:
                    close_reason = f"ATR止损触发(收盘价: {last_close:.2f}, 止损价: {stop_loss_price:.2f})"
            if close_reason:
                self.log(close_reason)
                self.send_market_order("CLOSE_LONG" if meta['direction'] == "BUY" else "CLOSE_SHORT", abs(pos))
                self.pos_meta.pop(symbol, None)
                return

        # --- Entry management (only when flat) ---
        if pos != 0:
            return
        latest_hawkes_upper = np.quantile(self._hawkes_window, self.hawkes_entry_percent)
        # Trendlines use the trendline_n closes *before* the last completed bar.
        # Only those bars are converted, not the whole history.
        prices_for_trendline = np.array([b.close for b in bar_history[-self.trendline_n - 1:-1]])
        trend_upper, trend_lower = calculate_latest_trendline_values(prices_for_trendline)
        if trend_upper is None or trend_lower is None:
            return

        prev_close, last_close = bar_history[-2].close, bar_history[-1].close
        # The indicator filter is only evaluated once the price condition holds (short-circuit).
        upper_break = (
            last_close > trend_upper
            and prev_close < trend_upper
            and self.indicators[0].is_condition_met(*self.get_indicator_tuple())
        )
        lower_break = (
            last_close < trend_lower
            and prev_close > trend_lower
            and self.indicators[1].is_condition_met(*self.get_indicator_tuple())
        )
        hawkes_confirm = latest_hawkes_value > latest_hawkes_upper
        if not (hawkes_confirm and (upper_break or lower_break)):
            return

        if upper_break:
            direction = "SELL" if self.reverse_logic else "BUY"
        else:
            direction = "BUY" if self.reverse_logic else "SELL"
        if direction not in self.order_direction:
            return

        sl_price = None
        if self.enable_atr_stop_loss:
            # ATR is computed on the history excluding the last bar.
            atr_val = self._calculate_atr(bar_history[:-1], self.atr_period)
            if atr_val is not None:
                stop_distance = atr_val * self.atr_multiplier
                sl_price = open_price - stop_distance if direction == "BUY" else open_price + stop_distance
                self.log(f"ATR({self.atr_period})={atr_val:.4f}, 止损价设置为: {sl_price:.2f}")
        self.log(
            f"开仓信号确认(霍克斯强度: {latest_hawkes_value:.4f} > 阈值: {latest_hawkes_upper:.4f})")
        self.send_open_order(direction, open_price, self.trade_volume, sl_price)

    def _calculate_atr(self, bar_history: List[Bar], period: int) -> Optional[float]:
        """
        Return the latest TA-Lib ATR(``period``) over ``bar_history``.

        Returns None when there are fewer than ``period + 1`` bars or the latest
        ATR value is NaN.
        """
        if len(bar_history) < period + 1:
            return None
        highs = np.array([b.high for b in bar_history], dtype=float)
        lows = np.array([b.low for b in bar_history], dtype=float)
        closes = np.array([b.close for b in bar_history], dtype=float)
        latest_atr = talib.ATR(highs, lows, closes, timeperiod=period)[-1]
        return None if np.isnan(latest_atr) else latest_atr

    def send_open_order(self, direction: str, entry_price: float, volume: int, stop_loss_price: Optional[float] = None):
        """
        Send a market OPEN order and record its metadata in ``pos_meta``.

        Args:
            direction: "BUY" or "SELL".
            entry_price: Reference price; only logged and stored, since the
                order itself is a market order.
            volume: Lots to open.
            stop_loss_price: Optional ATR stop, checked later in ``on_open_bar``.
        """
        current_time = self.get_current_time()
        order_id = f"{self.symbol}_{direction}_{current_time.strftime('%Y%m%d%H%M%S')}"
        order_direction = "BUY" if direction == "BUY" else "SELL"
        order = Order(id=order_id, symbol=self.symbol, direction=order_direction, volume=volume, price_type="MARKET",
                      submitted_time=current_time, offset="OPEN")
        self.send_order(order)
        self.pos_meta[self.symbol] = {
            "direction": direction,
            "volume": volume,
            "entry_price": entry_price,
            "stop_loss_price": stop_loss_price,
        }
        self.log(f"发送开仓订单: {direction} {volume}手 @ Market Price (执行价约 {entry_price:.2f})")

    def send_market_order(self, direction: str, volume: int):
        """Send a market CLOSE order; ``direction`` is passed through as given (e.g. "CLOSE_LONG")."""
        current_time = self.get_current_time()
        order_id = f"{self.symbol}_{direction}_{current_time.strftime('%Y%m%d%H%M%S')}"
        order = Order(id=order_id, symbol=self.symbol, direction=direction, volume=volume, price_type="MARKET",
                      submitted_time=current_time, offset="CLOSE")
        self.send_order(order)
        self.log(f"发送平仓订单: {direction} {volume}手 @ Market Price")

    def on_rollover(self, old_symbol: str, new_symbol: str):
        """On contract rollover, cancel pending orders for the new symbol and drop position metadata."""
        super().on_rollover(old_symbol, new_symbol)
        self.cancel_all_pending_orders(new_symbol)
        self.pos_meta.clear()