255 lines
11 KiB
Python
255 lines
11 KiB
Python
import numpy as np
|
||
import talib
|
||
from typing import Optional, Any, List
|
||
|
||
from src.core_data import Bar, Order
|
||
from src.indicators.base_indicators import Indicator
|
||
from src.indicators.indicators import Empty
|
||
from src.strategies.base_strategy import Strategy
|
||
|
||
|
||
class AreaReversalStrategy(Strategy):
    """
    Area-reversal strategy with an ATR-based dynamic trailing stop.

    Entry (unchanged from the original strategy): the "area" series is the
    rolling ``area_window`` sum of |close - SMA(ma_period)|. A signal fires when
    the previous bar's area (A2) is above the ``quantile_threshold`` percentile
    and within the ``top_k`` largest values of the preceding ``strength_window``
    areas, and the latest area (A1) is smaller than A2 (contraction). With a
    signal and no position, a long is opened if the last bar's high reaches the
    ``breakout_window`` high, otherwise a short if its low reaches the window low.

    Exit: a trailing stop anchored at the best high (long) / low (short) since
    entry. With ATR trailing, the stop distance is ``mult * entry_atr`` where
    ``mult`` grows linearly from ``initial_atr_mult`` (profit <= 0) to
    ``max_atr_mult`` (profit >= scale_threshold_mult * initial_atr_mult * entry_atr).
    If ATR trailing is disabled, or no finite ATR was available at entry, a
    fixed point / percent trailing stop is used instead.
    """

    def __init__(
        self,
        context: Any,
        main_symbol: str,
        enable_log: bool,
        trade_volume: int,
        ma_period: int = 14,
        area_window: int = 14,
        strength_window: int = 50,
        breakout_window: int = 20,
        quantile_threshold: float = 0.4,
        top_k: int = 3,
        # --- Legacy fixed trailing stop (kept as a fallback) ---
        trailing_points: float = 100.0,
        trailing_percent: Optional[float] = None,
        # --- ATR dynamic trailing stop ---
        atr_period: int = 14,
        initial_atr_mult: float = 3.0,  # initial stop distance = initial_atr_mult * entry ATR
        max_atr_mult: float = 9.0,  # widest stop distance = max_atr_mult * entry ATR
        scale_threshold_mult: float = 1.0,  # max width reached once profit >= this * initial_atr_mult * entry ATR
        use_atr_trailing: bool = True,  # False -> always use the fixed trailing stop
        # --- Other ---
        order_direction: Optional[List[str]] = None,
        indicators: Optional[List[Indicator]] = None,
    ):
        """
        Args:
            context: Engine context forwarded to ``Strategy``.
            main_symbol: Primary symbol forwarded to ``Strategy``.
            enable_log: Forwarded to ``Strategy``.
            trade_volume: Volume of each opening market order.
            ma_period: SMA period used to build the area series.
            area_window: Rolling-sum window of |close - SMA|.
            strength_window: Number of past areas A2 is compared against.
            breakout_window: Look-back (bars) for the breakout high/low.
            quantile_threshold: Fraction (0-1) giving the percentile A2 must reach.
            top_k: A2 must be >= the smallest of the top_k historical areas.
            trailing_points: Fixed stop distance when ``trailing_percent`` is None.
            trailing_percent: If set, fixed stop distance = anchor price * this.
            atr_period: ATR period.
            initial_atr_mult / max_atr_mult / scale_threshold_mult: see class docstring.
            use_atr_trailing: Enable the ATR trailing stop.
            order_direction: Allowed entry directions; defaults to ["BUY", "SELL"].
            indicators: Stored on the instance; defaults to two ``Empty()`` placeholders.
        """
        super().__init__(context, main_symbol, enable_log)

        if order_direction is None:
            order_direction = ["BUY", "SELL"]
        if indicators is None:
            indicators = [Empty(), Empty()]

        self.trade_volume = trade_volume
        self.ma_period = ma_period
        self.area_window = area_window
        self.strength_window = strength_window
        self.breakout_window = breakout_window
        self.quantile_threshold = quantile_threshold
        self.top_k = top_k
        self.trailing_points = trailing_points
        self.trailing_percent = trailing_percent
        self.atr_period = atr_period
        self.initial_atr_mult = initial_atr_mult
        self.max_atr_mult = max_atr_mult
        self.scale_threshold_mult = scale_threshold_mult
        self.use_atr_trailing = use_atr_trailing
        self.order_direction = order_direction
        self.indicators = indicators

        # Per-position trailing-stop state; cleared by _reset_state().
        self.entry_price: Optional[float] = None
        self.highest_high: Optional[float] = None  # best high since a long entry
        self.lowest_low: Optional[float] = None  # best low since a short entry
        self.entry_atr: Optional[float] = None  # ATR at entry; None -> fixed trailing stop

        self.order_id_counter = 0
        # NOTE(review): TA-Lib SMA/SUM leave leading NaNs, so the historical area
        # window is only fully populated after ~ma_period + area_window +
        # strength_window - 1 bars; _is_area_signal() reports no signal until then.
        self.min_bars_needed = max(
            ma_period,
            area_window * 3,
            strength_window,
            breakout_window,
            atr_period,
        ) + 10
        self.log("AreaReversalStrategy with ATR Trailing Stop Initialized")

    def _calculate_areas(self, closes: np.ndarray, ma: np.ndarray) -> np.ndarray:
        """Return the rolling ``area_window`` sum of |closes - ma| (TA-Lib SUM)."""
        diffs = np.abs(closes - ma)
        return talib.SUM(diffs, self.area_window)

    def _is_area_signal(self, a1: float, a2: float, historical_areas: np.ndarray) -> bool:
        """
        Return True when the previous area ``a2`` is a strong local peak and the
        latest area ``a1`` has contracted below it.
        """
        if historical_areas.size == 0:
            return False
        # Warm-up: NaNs from SMA/SUM mean there is no valid signal yet. This is the
        # same result the NaN comparisons below would give, made explicit instead
        # of relying on np.partition's NaN ordering.
        if np.isnan(a1) or np.isnan(a2) or np.isnan(historical_areas).any():
            return False

        area_contracting = (a1 < a2) and (a2 > 0)
        threshold = np.nanpercentile(historical_areas, self.quantile_threshold * 100)
        strength_satisfied = a2 >= threshold

        # Clamp so np.partition never receives an out-of-range kth.
        k = min(max(int(self.top_k), 1), len(historical_areas))
        top_k_values = np.partition(historical_areas, -k)[-k:]
        local_peak = a2 >= np.min(top_k_values)

        return bool(area_contracting and strength_satisfied and local_peak)

    def on_open_bar(self, open_price: float, symbol: str):
        """
        Per-bar handler: look for entries when flat, run the trailing stop otherwise.

        Args:
            open_price: Opening price passed by the engine; not used by this logic.
                NOTE(review): entries record ``bar_history[-1].close`` as the entry
                price, not ``open_price`` — confirm this matches the fill model.
            symbol: Traded symbol; stored on ``self.symbol`` for order creation.
        """
        self.symbol = symbol
        bar_history = self.get_bar_history()

        if len(bar_history) < self.min_bars_needed or not self.trading:
            return

        position = self.get_current_positions().get(self.symbol, 0)
        current_bar = bar_history[-1]

        # === Price series ===
        closes = np.array([b.close for b in bar_history], dtype=float)
        highs = np.array([b.high for b in bar_history], dtype=float)
        lows = np.array([b.low for b in bar_history], dtype=float)

        # === Indicators ===
        ma = talib.SMA(closes, self.ma_period)
        areas = self._calculate_areas(closes, ma)

        current_atr: Optional[float] = None
        if self.use_atr_trailing:
            latest_atr = float(talib.ATR(highs, lows, closes, self.atr_period)[-1])
            # A NaN ATR stored as entry_atr would make every stop comparison False
            # and the position could never be closed; treat it as unavailable.
            if np.isfinite(latest_atr):
                current_atr = latest_atr

        a1 = areas[-1]
        a2 = areas[-2] if len(areas) >= 2 else 0

        historical_areas = areas[-(self.strength_window + 1):-1]
        if len(historical_areas) < self.strength_window:
            return

        area_signal = self._is_area_signal(a1, a2, historical_areas)

        # === Breakout levels (window includes the last bar) ===
        recent_bars = bar_history[-self.breakout_window:]
        highest = max(b.high for b in recent_bars)
        lowest = min(b.low for b in recent_bars)

        # =============== Entry ===============
        if position == 0 and area_signal:
            if "BUY" in self.order_direction and current_bar.high >= highest:
                self.send_market_order("BUY", self.trade_volume, "OPEN")
                self.entry_price = current_bar.close
                self.highest_high = current_bar.high
                self.lowest_low = None
                self.entry_atr = current_atr  # None -> fixed trailing stop
                self.log(f"🚀 Long Entry | A2={a2:.4f}")
            elif "SELL" in self.order_direction and current_bar.low <= lowest:
                self.send_market_order("SELL", self.trade_volume, "OPEN")
                self.entry_price = current_bar.close
                self.lowest_low = current_bar.low
                self.highest_high = None
                self.entry_atr = current_atr  # None -> fixed trailing stop
                self.log(f"⬇️ Short Entry | A2={a2:.4f}")

        # =============== Exit: trailing stop ===============
        elif position != 0 and self.entry_price is not None:
            if self.use_atr_trailing and self.entry_atr is not None:
                self._check_atr_exit(position, current_bar)
            else:
                self._check_fixed_exit(position, current_bar)

    def _atr_trail_multiplier(self, unrealized_pnl: float) -> float:
        """Interpolate the ATR multiplier from initial_atr_mult to max_atr_mult by profit."""
        scale_threshold_pnl = self.scale_threshold_mult * self.initial_atr_mult * self.entry_atr
        if unrealized_pnl <= 0:
            return self.initial_atr_mult
        if unrealized_pnl >= scale_threshold_pnl:
            return self.max_atr_mult
        ratio = unrealized_pnl / scale_threshold_pnl
        return self.initial_atr_mult + ratio * (self.max_atr_mult - self.initial_atr_mult)

    def _check_atr_exit(self, position: int, bar: Bar) -> None:
        """Update the trailing anchor and close the position if the ATR stop is hit."""
        if position > 0:
            if self.highest_high is None or bar.high > self.highest_high:
                self.highest_high = bar.high
            trail_mult = self._atr_trail_multiplier(bar.close - self.entry_price)
            stop_loss_price = self.highest_high - trail_mult * self.entry_atr
            if bar.low <= stop_loss_price:
                self.close_position("CLOSE_LONG", position)
                self._reset_state()
                self.log(f"CloseOperation (ATR Trailing) | Mult={trail_mult:.2f}")
        else:
            if self.lowest_low is None or bar.low < self.lowest_low:
                self.lowest_low = bar.low
            trail_mult = self._atr_trail_multiplier(self.entry_price - bar.close)
            stop_loss_price = self.lowest_low + trail_mult * self.entry_atr
            if bar.high >= stop_loss_price:
                self.close_position("CLOSE_SHORT", -position)
                self._reset_state()
                self.log(f"CloseOperation (ATR Trailing) | Mult={trail_mult:.2f}")

    def _fixed_trail_offset(self, anchor: float) -> float:
        """Fixed stop distance: ``anchor * trailing_percent`` if set, else ``trailing_points``."""
        if self.trailing_percent is not None:
            return anchor * self.trailing_percent
        return self.trailing_points

    def _check_fixed_exit(self, position: int, bar: Bar) -> None:
        """Fallback: update the anchor and close the position if the fixed stop is hit."""
        if position > 0:
            if self.highest_high is None or bar.high > self.highest_high:
                self.highest_high = bar.high
            offset = self._fixed_trail_offset(self.highest_high)
            if bar.low <= self.highest_high - offset:
                self.close_position("CLOSE_LONG", position)
                self._reset_state()
                self.log(f"CloseOperation (Fixed Trailing) | Offset={offset:.2f}")
        else:
            if self.lowest_low is None or bar.low < self.lowest_low:
                self.lowest_low = bar.low
            offset = self._fixed_trail_offset(self.lowest_low)
            if bar.high >= self.lowest_low + offset:
                self.close_position("CLOSE_SHORT", -position)
                self._reset_state()
                self.log(f"CloseOperation (Fixed Trailing) | Offset={offset:.2f}")

    def _reset_state(self):
        """Clear all per-position trailing-stop state."""
        self.entry_price = None
        self.highest_high = None
        self.lowest_low = None
        self.entry_atr = None

    # --- Template-method hooks ---
    def on_init(self):
        """Cancel pending orders on the main symbol and clear trailing-stop state."""
        super().on_init()
        self.cancel_all_pending_orders(self.main_symbol)
        self._reset_state()

    def close_position(self, direction: str, volume: int):
        """Send a market order with offset "CLOSE"."""
        self.send_market_order(direction, volume, offset="CLOSE")

    def send_market_order(self, direction: str, volume: int, offset: str):
        """Build a MARKET ``Order`` with a unique per-instance id and submit it."""
        order_id = f"{self.symbol}_{direction}_MARKET_{self.order_id_counter}"
        self.order_id_counter += 1
        order = Order(
            id=order_id,
            symbol=self.symbol,
            direction=direction,
            volume=volume,
            price_type="MARKET",
            submitted_time=self.get_current_time(),
            offset=offset,
        )
        self.send_order(order)

    def on_rollover(self, old_symbol: str, new_symbol: str):
        """
        Delegate the roll to the base class, then clear trailing-stop state.

        NOTE(review): if the base class carries an open position over to
        new_symbol, entry_price is None afterwards and on_open_bar will not
        manage that position's exit — confirm the base-class behaviour.
        """
        super().on_rollover(old_symbol, new_symbol)
        self._reset_state()
        self.log("Rollover: Reset trailing stop state.")