# File: NewQuant/futures_trading_strategies/FG/AreaReversal/AreaReversalStrategy2.py
# Repository viewer metadata: 255 lines, 11 KiB, Python; timestamp 2025-11-07 16:37:16 +08:00
import numpy as np
import talib
from typing import Optional, Any, List
from src.core_data import Bar, Order
from src.indicators.base_indicators import Indicator
from src.indicators.indicators import Empty
from src.strategies.base_strategy import Strategy
class AreaReversalStrategy(Strategy):
    """Area-reversal entry strategy with an ATR-scaled trailing-stop exit.

    Entry: the rolling "area" (sum of ``|close - SMA|`` over ``area_window``
    bars) must start contracting from a level that is both above a quantile
    of its recent history and among the ``top_k`` largest recent values,
    while the latest bar makes the highest high (long) or lowest low (short)
    of the last ``breakout_window`` bars.

    Exit: a trailing stop measured from the best price seen since entry.
    When ATR trailing is enabled and a usable ATR was captured at entry, the
    trailing distance is a multiple of that entry ATR which widens from
    ``initial_atr_mult`` to ``max_atr_mult`` as open profit grows. Otherwise
    the fixed ``trailing_percent`` / ``trailing_points`` stop is used.
    """

    def __init__(
        self,
        context: Any,
        main_symbol: str,
        enable_log: bool,
        trade_volume: int,
        ma_period: int = 14,
        area_window: int = 14,
        strength_window: int = 50,
        breakout_window: int = 20,
        quantile_threshold: float = 0.4,
        top_k: int = 3,
        # --- Legacy fixed trailing stop (kept as a fallback) ---
        trailing_points: float = 100.0,
        trailing_percent: Optional[float] = None,
        # --- ATR dynamic trailing-stop parameters ---
        atr_period: int = 14,
        initial_atr_mult: float = 3.0,  # trailing distance (x entry ATR) while flat or losing
        max_atr_mult: float = 9.0,  # trailing distance (x entry ATR) at/after the profit threshold
        scale_threshold_mult: float = 1.0,  # threshold = this * initial_atr_mult * entry ATR
        use_atr_trailing: bool = True,  # enable the ATR stop; otherwise use the fixed stop
        # --- Misc ---
        order_direction: Optional[List[str]] = None,
        indicators: Optional[List[Indicator]] = None,
    ):
        """Store the strategy parameters and reset the per-trade state.

        Args:
            context: Passed through to ``Strategy.__init__``.
            main_symbol: Passed through to ``Strategy.__init__``.
            enable_log: Passed through to ``Strategy.__init__``.
            trade_volume: Volume sent with every opening order.
            ma_period: SMA period used for the area calculation.
            area_window: Number of bars summed into one area value.
            strength_window: Number of past area values used as the
                reference history for the quantile and top-k checks.
            breakout_window: Number of most recent bars (including the
                latest) used for the high/low breakout check.
            quantile_threshold: Quantile in [0, 1] of the area history that
                the previous area must reach.
            top_k: The previous area must be at least the ``top_k``-th
                largest value of the area history.
            trailing_points: Fixed trailing distance in price units, used
                when ``trailing_percent`` is ``None``.
            trailing_percent: Fixed trailing distance as a fraction of the
                best price since entry; takes precedence over
                ``trailing_points`` when set.
            atr_period: ATR look-back period.
            initial_atr_mult: ATR multiple used while open profit is <= 0.
            max_atr_mult: ATR multiple used once open profit reaches the
                scaling threshold.
            scale_threshold_mult: Scaling threshold expressed as a multiple
                of ``initial_atr_mult * entry ATR``.
            use_atr_trailing: Whether to use the ATR trailing stop.
            order_direction: Allowed entry directions; defaults to
                ``["BUY", "SELL"]``.
            indicators: Stored on the instance; defaults to two ``Empty``
                indicators.
        """
        super().__init__(context, main_symbol, enable_log)
        if order_direction is None:
            order_direction = ["BUY", "SELL"]
        if indicators is None:
            indicators = [Empty(), Empty()]

        self.trade_volume = trade_volume
        self.ma_period = ma_period
        self.area_window = area_window
        self.strength_window = strength_window
        self.breakout_window = breakout_window
        self.quantile_threshold = quantile_threshold
        self.top_k = top_k
        self.trailing_points = trailing_points
        self.trailing_percent = trailing_percent
        self.atr_period = atr_period
        self.initial_atr_mult = initial_atr_mult
        self.max_atr_mult = max_atr_mult
        self.scale_threshold_mult = scale_threshold_mult
        self.use_atr_trailing = use_atr_trailing
        self.order_direction = order_direction
        self.indicators = indicators

        # Per-trade state; cleared by _reset_state().
        self.entry_price = None
        self.highest_high = None
        self.lowest_low = None
        self.entry_atr = None  # ATR captured at entry, or None if unusable
        self.order_id_counter = 0

        # NOTE(review): the area series presumably still carries indicator
        # warm-up NaNs for roughly ma_period + area_window + strength_window
        # bars, which can exceed this value; _has_area_signal() treats such
        # windows as "no signal" - confirm whether this bound should grow.
        self.min_bars_needed = max(
            ma_period,
            area_window * 3,
            strength_window,
            breakout_window,
            atr_period,
        ) + 10
        self.log("AreaReversalStrategy with ATR Trailing Stop Initialized")

    def _calculate_areas(self, closes: np.ndarray, ma: np.ndarray) -> np.ndarray:
        """Return the rolling ``area_window`` sum of ``|closes - ma|``."""
        diffs = np.abs(closes - ma)
        return talib.SUM(diffs, self.area_window)

    def on_open_bar(self, open_price: float, symbol: str):
        """Evaluate entries and trailing-stop exits on the bar history.

        Args:
            open_price: Not used by this strategy.
            symbol: Symbol to trade; stored on ``self.symbol`` and used for
                position lookup and order creation.
        """
        self.symbol = symbol
        bar_history = self.get_bar_history()
        if len(bar_history) < self.min_bars_needed or not self.trading:
            return

        position = self.get_current_positions().get(self.symbol, 0)
        current_bar = bar_history[-1]

        closes = np.array([b.close for b in bar_history], dtype=float)
        highs = np.array([b.high for b in bar_history], dtype=float)
        lows = np.array([b.low for b in bar_history], dtype=float)

        ma = talib.SMA(closes, self.ma_period)
        areas = self._calculate_areas(closes, ma)

        if self.use_atr_trailing:
            current_atr = talib.ATR(highs, lows, closes, self.atr_period)[-1]
        else:
            current_atr = None

        latest_area = areas[-1]
        prev_area = areas[-2] if len(areas) >= 2 else 0
        # Reference history: the strength_window values before the latest one.
        historical_areas = areas[-(self.strength_window + 1):-1]
        if len(historical_areas) < self.strength_window:
            return

        if position == 0:
            if self._has_area_signal(latest_area, prev_area, historical_areas):
                self._try_open(current_bar, bar_history, current_atr, prev_area)
        elif self.entry_price is not None:
            self._manage_open_position(position, current_bar)

    def _has_area_signal(self, latest_area, prev_area, historical_areas) -> bool:
        """Return True when the area series signals a reversal set-up.

        All three conditions must hold: the area is contracting from a
        positive value, the previous area reaches the configured quantile of
        the history, and the previous area is at least the ``top_k``-th
        largest value of the history.
        """
        window = np.asarray(historical_areas, dtype=float)
        # NaNs can never satisfy the comparisons below, so make "no signal"
        # explicit instead of relying on NaN comparison semantics.
        if (
            window.size == 0
            or np.isnan(latest_area)
            or np.isnan(prev_area)
            or np.isnan(window).any()
        ):
            return False

        if not (latest_area < prev_area and prev_area > 0):
            return False

        threshold = np.nanpercentile(window, self.quantile_threshold * 100)
        if prev_area < threshold:
            return False

        # Clamp so that top_k larger than the window cannot make
        # np.partition raise; element [-k] of the partition is the k-th
        # largest value, i.e. the minimum of the top-k set.
        k = max(1, min(self.top_k, window.size))
        kth_largest = np.partition(window, -k)[-k]
        return bool(prev_area >= kth_largest)

    def _try_open(self, current_bar, bar_history, current_atr, prev_area) -> None:
        """Open a long or short position if the latest bar is a breakout bar."""
        recent_bars = bar_history[-self.breakout_window:]
        highest = max(b.high for b in recent_bars)
        lowest = min(b.low for b in recent_bars)

        if "BUY" in self.order_direction and current_bar.high >= highest:
            self.send_market_order("BUY", self.trade_volume, "OPEN")
            self._record_entry(current_bar, current_atr, is_long=True)
            self.log(f"🚀 Long Entry | A2={prev_area:.4f}")
        elif "SELL" in self.order_direction and current_bar.low <= lowest:
            self.send_market_order("SELL", self.trade_volume, "OPEN")
            self._record_entry(current_bar, current_atr, is_long=False)
            self.log(f"⬇️ Short Entry | A2={prev_area:.4f}")

    def _record_entry(self, bar, current_atr, is_long: bool) -> None:
        """Initialise the trailing-stop state for a freshly opened position."""
        self.entry_price = bar.close
        self.highest_high = bar.high if is_long else None
        self.lowest_low = None if is_long else bar.low
        # Always assign so a value left over from an earlier trade is never reused.
        self.entry_atr = self._usable_atr(current_atr) if self.use_atr_trailing else None
        if self.use_atr_trailing and self.entry_atr is None:
            self.log("ATR unavailable at entry; falling back to fixed trailing stop")

    @staticmethod
    def _usable_atr(atr_value) -> Optional[float]:
        """Return ``atr_value`` as a float if finite and positive, else None.

        A NaN ATR would make the stop price NaN (every stop comparison is
        then False, so the position would never be stopped out), and a zero
        ATR would place the stop on the extreme itself.
        """
        if atr_value is None:
            return None
        if not np.isfinite(atr_value) or atr_value <= 0:
            return None
        return float(atr_value)

    def _atr_trail_multiplier(self, unrealized_pnl: float) -> float:
        """Return the ATR multiple for the trailing distance.

        ``initial_atr_mult`` while profit is <= 0, ``max_atr_mult`` once
        profit reaches ``scale_threshold_mult * initial_atr_mult * entry_atr``,
        and a linear interpolation between the two in between.
        """
        scale_threshold_pnl = (
            self.scale_threshold_mult * self.initial_atr_mult * self.entry_atr
        )
        if unrealized_pnl <= 0:
            return self.initial_atr_mult
        if unrealized_pnl >= scale_threshold_pnl:
            return self.max_atr_mult
        ratio = unrealized_pnl / scale_threshold_pnl
        return self.initial_atr_mult + ratio * (self.max_atr_mult - self.initial_atr_mult)

    def _manage_open_position(self, position, bar) -> None:
        """Update the trailing extreme and close the position if stopped out."""
        is_long = position > 0

        # Track the best price reached since entry.
        if is_long:
            if self.highest_high is None or bar.high > self.highest_high:
                self.highest_high = bar.high
            extreme = self.highest_high
        else:
            if self.lowest_low is None or bar.low < self.lowest_low:
                self.lowest_low = bar.low
            extreme = self.lowest_low

        atr_mode = self.use_atr_trailing and self.entry_atr is not None
        trail_mult = None
        if atr_mode:
            if is_long:
                unrealized_pnl = bar.close - self.entry_price
            else:
                unrealized_pnl = self.entry_price - bar.close
            trail_mult = self._atr_trail_multiplier(unrealized_pnl)
            offset = trail_mult * self.entry_atr
        elif self.trailing_percent is not None:
            offset = extreme * self.trailing_percent
        else:
            offset = self.trailing_points

        if is_long:
            stopped = bar.low <= extreme - offset
        else:
            stopped = bar.high >= extreme + offset
        if not stopped:
            return

        if is_long:
            self.close_position("CLOSE_LONG", position)
        else:
            # Short positions are negative; orders carry a positive volume.
            self.close_position("CLOSE_SHORT", -position)
        self._reset_state()
        if atr_mode:
            self.log(f"CloseOperation (ATR Trailing) | Mult={trail_mult:.2f}")

    def _reset_state(self):
        """Clear all per-trade trailing-stop state."""
        self.entry_price = None
        self.highest_high = None
        self.lowest_low = None
        self.entry_atr = None

    # --- Template methods ---
    def on_init(self):
        """Run base initialisation, cancel pending orders and reset state."""
        super().on_init()
        self.cancel_all_pending_orders(self.main_symbol)
        self._reset_state()

    def close_position(self, direction: str, volume: int):
        """Send a market order with offset ``"CLOSE"`` for ``volume``."""
        self.send_market_order(direction, volume, offset="CLOSE")

    def send_market_order(self, direction: str, volume: int, offset: str):
        """Build a MARKET ``Order`` for ``self.symbol`` and submit it.

        The order id is ``"{symbol}_{direction}_MARKET_{counter}"`` and the
        counter is incremented on every call.
        """
        order_id = f"{self.symbol}_{direction}_MARKET_{self.order_id_counter}"
        self.order_id_counter += 1
        order = Order(
            id=order_id,
            symbol=self.symbol,
            direction=direction,
            volume=volume,
            price_type="MARKET",
            submitted_time=self.get_current_time(),
            offset=offset,
        )
        self.send_order(order)

    def on_rollover(self, old_symbol: str, new_symbol: str):
        """Run the base rollover handling, then clear trailing-stop state."""
        super().on_rollover(old_symbol, new_symbol)
        self._reset_state()
        self.log("Rollover: Reset trailing stop state.")