import numpy as np
import talib
from typing import Optional, Any, List

from src.core_data import Bar, Order
from src.indicators.base_indicators import Indicator
from src.indicators.indicators import Empty
from src.strategies.base_strategy import Strategy


class AreaReversalStrategy(Strategy):
    """Area-reversal strategy with an ATR-based dynamic trailing stop.

    Entry: the "area" is the rolling sum of ``|close - SMA(close)|``. A signal
    fires when the latest area is smaller than the previous one while the
    previous area is both above a quantile of its recent history and among
    the ``top_k`` largest values of that history. A long (short) is opened
    when the signal coincides with the latest bar making the highest high
    (lowest low) of the breakout window.

    Exit: a trailing stop measured from the best price seen since entry.
    When ``use_atr_trailing`` is on and a finite ATR was captured at entry,
    the stop distance is a multiple of that entry ATR which widens linearly
    with unrealized profit. Otherwise a fixed-points / fixed-percent trailing
    stop is used as a fallback.
    """

    def __init__(
        self,
        context: Any,
        main_symbol: str,
        enable_log: bool,
        trade_volume: int,
        ma_period: int = 14,
        area_window: int = 14,
        strength_window: int = 50,
        breakout_window: int = 20,
        quantile_threshold: float = 0.4,
        top_k: int = 3,
        # --- Original fixed trailing stop (kept as a fallback) ---
        trailing_points: float = 100.0,
        trailing_percent: Optional[float] = None,
        # --- ATR dynamic trailing-stop parameters ---
        atr_period: int = 14,
        initial_atr_mult: float = 3.0,  # stop distance at/below breakeven = initial_atr_mult * entry ATR
        max_atr_mult: float = 9.0,  # widest stop distance = max_atr_mult * entry ATR
        scale_threshold_mult: float = 1.0,  # max multiplier is reached at a profit of
        # scale_threshold_mult * initial_atr_mult * entry ATR
        use_atr_trailing: bool = True,  # enable the ATR-based stop
        # --- Misc ---
        order_direction: Optional[List[str]] = None,
        indicators: Optional[List[Indicator]] = None,
    ):
        """Store the configuration and initialise the per-trade state.

        Args:
            context: Passed through to ``Strategy.__init__``.
            main_symbol: Passed through to ``Strategy.__init__``.
            enable_log: Passed through to ``Strategy.__init__``.
            trade_volume: Volume sent with every opening order.
            ma_period: SMA period used for the area calculation.
            area_window: Rolling window over which ``|close - SMA|`` is summed.
            strength_window: Number of past area values used for the
                quantile and top-k tests.
            breakout_window: Number of recent bars used for the
                highest-high / lowest-low breakout test.
            quantile_threshold: Quantile in ``[0, 1]`` of the historical
                areas that the previous area must reach.
            top_k: The previous area must be at least the k-th largest
                historical area.
            trailing_points: Fallback stop distance in price units.
            trailing_percent: Fallback stop distance as a fraction of the
                trailing extreme; takes precedence over ``trailing_points``
                when not ``None``.
            atr_period: ATR look-back period.
            initial_atr_mult: ATR multiple used while the trade is not in
                profit.
            max_atr_mult: ATR multiple used once the profit threshold is
                reached.
            scale_threshold_mult: Scales the profit threshold at which
                ``max_atr_mult`` is reached.
            use_atr_trailing: Whether to use the ATR stop instead of the
                fallback stop.
            order_direction: Allowed entry sides; defaults to
                ``["BUY", "SELL"]``.
            indicators: Stored on the instance; defaults to two ``Empty()``
                indicators.
        """
        super().__init__(context, main_symbol, enable_log)

        if order_direction is None:
            order_direction = ["BUY", "SELL"]
        if indicators is None:
            indicators = [Empty(), Empty()]

        self.trade_volume = trade_volume
        self.ma_period = ma_period
        self.area_window = area_window
        self.strength_window = strength_window
        self.breakout_window = breakout_window
        self.quantile_threshold = quantile_threshold
        self.top_k = top_k

        self.trailing_points = trailing_points
        self.trailing_percent = trailing_percent

        self.atr_period = atr_period
        self.initial_atr_mult = initial_atr_mult
        self.max_atr_mult = max_atr_mult
        self.scale_threshold_mult = scale_threshold_mult
        self.use_atr_trailing = use_atr_trailing

        self.order_direction = order_direction
        self.indicators = indicators

        # Per-trade state.
        self.entry_price = None
        self.highest_high = None
        self.lowest_low = None
        self.entry_atr = None  # ATR value captured at entry
        self.order_id_counter = 0

        # Do nothing until this many bars are available.
        self.min_bars_needed = max(
            ma_period,
            area_window * 3,
            strength_window,
            breakout_window,
            atr_period,
        ) + 10

        self.log("AreaReversalStrategy with ATR Trailing Stop Initialized")

    def _calculate_areas(self, closes: np.ndarray, ma: np.ndarray) -> np.ndarray:
        """Return the rolling sum of ``|closes - ma|`` over ``area_window`` bars."""
        diffs = np.abs(closes - ma)
        areas = talib.SUM(diffs, self.area_window)
        return areas

    def on_open_bar(self, open_price: float, symbol: str):
        """Evaluate entry and exit rules against the bar history.

        Args:
            open_price: Not used by this strategy.
            symbol: Stored as ``self.symbol`` and used for order routing.
        """
        self.symbol = symbol
        bar_history = self.get_bar_history()
        if len(bar_history) < self.min_bars_needed or not self.trading:
            return

        position = self.get_current_positions().get(self.symbol, 0)
        current_bar = bar_history[-1]

        # Price series (highs/lows are needed for the ATR).
        closes = np.array([b.close for b in bar_history], dtype=float)
        highs = np.array([b.high for b in bar_history], dtype=float)
        lows = np.array([b.low for b in bar_history], dtype=float)

        # Indicators.
        ma = talib.SMA(closes, self.ma_period)
        areas = self._calculate_areas(closes, ma)

        if self.use_atr_trailing:
            atr = talib.ATR(highs, lows, closes, self.atr_period)
            current_atr = atr[-1]
        else:
            current_atr = None

        latest_area = areas[-1]
        prev_area = areas[-2] if len(areas) >= 2 else 0

        # History excludes the latest value, so it ends at prev_area.
        historical_areas = areas[-(self.strength_window + 1):-1]
        if len(historical_areas) < self.strength_window:
            return

        # --- Area signal ---
        area_contracting = (latest_area < prev_area) and (prev_area > 0)
        threshold = np.nanpercentile(historical_areas, self.quantile_threshold * 100)
        strength_satisfied = prev_area >= threshold
        # NaN entries (indicator warm-up) sort last in np.partition, so while
        # the window still contains NaN the minimum below is NaN and
        # local_peak evaluates to False.
        top_k_values = np.partition(historical_areas, -self.top_k)[-self.top_k:]
        local_peak = prev_area >= np.min(top_k_values)
        area_signal = area_contracting and strength_satisfied and local_peak

        # --- Breakout levels (the window includes the latest bar) ---
        recent_bars = bar_history[-self.breakout_window:]
        highest = max(b.high for b in recent_bars)
        lowest = min(b.low for b in recent_bars)

        # --- Entry ---
        if position == 0 and area_signal:
            # Only keep an ATR that can yield a usable stop price; a NaN/inf
            # value would make every stop comparison False, so the stop could
            # never trigger. With None the fallback trailing stop is used.
            if (
                self.use_atr_trailing
                and current_atr is not None
                and np.isfinite(current_atr)
            ):
                entry_atr = current_atr
            else:
                entry_atr = None

            if "BUY" in self.order_direction and current_bar.high >= highest:
                self.send_market_order("BUY", self.trade_volume, "OPEN")
                self.entry_price = current_bar.close
                self.highest_high = current_bar.high
                self.lowest_low = None
                self.entry_atr = entry_atr
                self.log(f"🚀 Long Entry | A2={prev_area:.4f}")
            elif "SELL" in self.order_direction and current_bar.low <= lowest:
                self.send_market_order("SELL", self.trade_volume, "OPEN")
                self.entry_price = current_bar.close
                self.lowest_low = current_bar.low
                self.highest_high = None
                self.entry_atr = entry_atr
                self.log(f"⬇️ Short Entry | A2={prev_area:.4f}")

        # --- Exit: trailing stop ---
        elif position != 0 and self.entry_price is not None:
            self._manage_exit(position, current_bar)

    def _trail_multiplier(self, unrealized_pnl: float) -> float:
        """Return the ATR multiple for the stop given the unrealized profit.

        The multiple is ``initial_atr_mult`` at or below breakeven,
        ``max_atr_mult`` once the profit reaches
        ``scale_threshold_mult * initial_atr_mult * entry_atr``, and linearly
        interpolated in between. Requires ``self.entry_atr`` to be set.
        """
        scale_threshold_pnl = (
            self.scale_threshold_mult * self.initial_atr_mult * self.entry_atr
        )
        if unrealized_pnl <= 0:
            return self.initial_atr_mult
        if unrealized_pnl >= scale_threshold_pnl:
            return self.max_atr_mult
        ratio = unrealized_pnl / scale_threshold_pnl
        return self.initial_atr_mult + ratio * (
            self.max_atr_mult - self.initial_atr_mult
        )

    def _manage_exit(self, position, current_bar) -> None:
        """Update the trailing extreme and close the position if the stop is hit.

        Args:
            position: Signed position size; positive is treated as long,
                anything else as short.
            current_bar: Latest bar; its high/low/close drive the stop.
        """
        is_long = position > 0
        use_atr = self.use_atr_trailing and self.entry_atr is not None

        # Ratchet the best price seen since entry.
        if is_long:
            if self.highest_high is None or current_bar.high > self.highest_high:
                self.highest_high = current_bar.high
            extreme = self.highest_high
        else:
            if self.lowest_low is None or current_bar.low < self.lowest_low:
                self.lowest_low = current_bar.low
            extreme = self.lowest_low

        # Stop distance from the extreme.
        trail_mult = None
        if use_atr:
            if is_long:
                unrealized_pnl = current_bar.close - self.entry_price
            else:
                unrealized_pnl = self.entry_price - current_bar.close
            trail_mult = self._trail_multiplier(unrealized_pnl)
            distance = trail_mult * self.entry_atr
        elif self.trailing_percent is not None:
            distance = extreme * self.trailing_percent
        else:
            distance = self.trailing_points

        if is_long:
            stop_hit = current_bar.low <= extreme - distance
        else:
            stop_hit = current_bar.high >= extreme + distance
        if not stop_hit:
            return

        if is_long:
            self.close_position("CLOSE_LONG", position)
        else:
            # Short positions are negative; orders take a positive volume.
            self.close_position("CLOSE_SHORT", -position)
        self._reset_state()
        if use_atr:
            self.log(f"CloseOperation (ATR Trailing) | Mult={trail_mult:.2f}")

    def _reset_state(self):
        """Clear all per-trade trailing-stop state."""
        self.entry_price = None
        self.highest_high = None
        self.lowest_low = None
        self.entry_atr = None

    # --- Template methods ---
    def on_init(self):
        """Run base initialisation, cancel pending orders and reset trade state."""
        super().on_init()
        self.cancel_all_pending_orders(self.main_symbol)
        self._reset_state()

    def close_position(self, direction: str, volume: int):
        """Send a market order with ``offset="CLOSE"`` for ``volume``."""
        self.send_market_order(direction, volume, offset="CLOSE")

    def send_market_order(self, direction: str, volume: int, offset: str):
        """Build and submit a MARKET order on ``self.symbol``.

        The order id combines the symbol, direction and a per-instance
        counter that is incremented on every call.
        """
        order_id = f"{self.symbol}_{direction}_MARKET_{self.order_id_counter}"
        self.order_id_counter += 1
        order = Order(
            id=order_id,
            symbol=self.symbol,
            direction=direction,
            volume=volume,
            price_type="MARKET",
            submitted_time=self.get_current_time(),
            offset=offset,
        )
        self.send_order(order)

    def on_rollover(self, old_symbol: str, new_symbol: str):
        """Delegate to the base rollover handler, then reset trade state."""
        super().on_rollover(old_symbol, new_symbol)
        self._reset_state()
        self.log("Rollover: Reset trailing stop state.")