Files
NewQuant/futures_trading_strategies/FG/AreaReversal/AreaReversalStrategy2.py

255 lines
11 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import numpy as np
import talib
from typing import Optional, Any, List
from src.core_data import Bar, Order
from src.indicators.base_indicators import Indicator
from src.indicators.indicators import Empty
from src.strategies.base_strategy import Strategy
class AreaReversalStrategy(Strategy):
    """Area-reversal strategy with an ATR-based dynamic trailing stop.

    Entry: the rolling "area" is the sum of |close - SMA| over
    ``area_window`` bars. When this area contracts after a strong, locally
    peaking value, a position is opened in the direction of a
    ``breakout_window`` high/low breakout.

    Exit: a trailing stop anchored at the best price seen since entry. With
    ATR trailing enabled, the stop distance starts at
    ``initial_atr_mult * entry_atr``. It widens linearly up to
    ``max_atr_mult * entry_atr`` as unrealized profit grows from 0 to
    ``scale_threshold_mult * initial_atr_mult * entry_atr``. Otherwise, or
    when no valid entry ATR is available, a fixed points/percent trailing
    stop is used.
    """

    def __init__(
        self,
        context: Any,
        main_symbol: str,
        enable_log: bool,
        trade_volume: int,
        ma_period: int = 14,
        area_window: int = 14,
        strength_window: int = 50,
        breakout_window: int = 20,
        quantile_threshold: float = 0.4,
        top_k: int = 3,
        # --- Legacy trailing stop (kept as a fallback) ---
        # trailing_percent is a fraction of the extreme price (0.01 == 1%);
        # when None, the fixed trailing_points distance is used.
        trailing_points: float = 100.0,
        trailing_percent: Optional[float] = None,
        # --- ATR dynamic stop parameters ---
        atr_period: int = 14,
        # Initial stop distance = initial_atr_mult * entry ATR.
        initial_atr_mult: float = 3.0,
        # Widest stop distance = max_atr_mult * entry ATR.
        max_atr_mult: float = 9.0,
        # Profit (in units of initial_atr_mult * entry ATR) at which the stop
        # distance reaches max_atr_mult; below it the multiple is interpolated.
        scale_threshold_mult: float = 1.0,
        # Enable the ATR stop (otherwise the legacy stop is used).
        use_atr_trailing: bool = True,
        # --- Other ---
        order_direction: Optional[List[str]] = None,
        indicators: Optional[List[Indicator]] = None,
    ):
        super().__init__(context, main_symbol, enable_log)
        if order_direction is None:
            order_direction = ["BUY", "SELL"]
        if indicators is None:
            indicators = [Empty(), Empty()]
        self.trade_volume = trade_volume
        self.ma_period = ma_period
        self.area_window = area_window
        self.strength_window = strength_window
        self.breakout_window = breakout_window
        self.quantile_threshold = quantile_threshold
        self.top_k = top_k
        self.trailing_points = trailing_points
        self.trailing_percent = trailing_percent
        self.atr_period = atr_period
        self.initial_atr_mult = initial_atr_mult
        self.max_atr_mult = max_atr_mult
        self.scale_threshold_mult = scale_threshold_mult
        self.use_atr_trailing = use_atr_trailing
        self.order_direction = order_direction
        self.indicators = indicators
        # Per-position trailing-stop state; entry_atr is the ATR at entry.
        self.entry_price = None
        self.highest_high = None
        self.lowest_low = None
        self.entry_atr = None
        self.order_id_counter = 0
        self.min_bars_needed = max(
            ma_period,
            area_window * 3,
            strength_window,
            breakout_window,
            atr_period,
        ) + 10
        self.log("AreaReversalStrategy with ATR Trailing Stop Initialized")

    def _calculate_areas(self, closes: np.ndarray, ma: np.ndarray) -> np.ndarray:
        """Return the rolling sum of |close - ma| over ``area_window`` bars.

        Leading entries are expected to be NaN during indicator warm-up.
        """
        diffs = np.abs(closes - ma)
        return talib.SUM(diffs, self.area_window)

    def _area_signal(self, areas: np.ndarray):
        """Evaluate the area-contraction reversal condition.

        ``A1`` is the latest area and ``A2`` the previous one. The reference
        window is the ``strength_window`` areas ending at A2 (inclusive; A1
        excluded). The signal requires all of:

        * the area is contracting: A1 < A2 and A2 > 0;
        * A2 is at or above the ``quantile_threshold`` quantile of the window;
        * A2 is among the ``top_k`` largest values of the window.

        Returns:
            ``(signal, A2)`` where ``signal`` is a bool.
        """
        A1 = areas[-1]
        A2 = areas[-2] if len(areas) >= 2 else 0
        historical_areas = areas[-(self.strength_window + 1):-1]
        # An incomplete or not-yet-formed (NaN) window never produces a
        # signal. NaN would otherwise only suppress it implicitly through
        # np.partition/np.min NaN ordering.
        if (
            len(historical_areas) < self.strength_window
            or np.isnan(historical_areas).any()
        ):
            return False, A2

        area_contracting = (A1 < A2) and (A2 > 0)
        threshold = np.nanpercentile(historical_areas, self.quantile_threshold * 100)
        strength_satisfied = A2 >= threshold
        # Clamp k so np.partition cannot raise when top_k exceeds the window.
        k = min(max(int(self.top_k), 1), len(historical_areas))
        top_k_values = np.partition(historical_areas, -k)[-k:]
        local_peak = A2 >= np.min(top_k_values)
        return bool(area_contracting and strength_satisfied and local_peak), A2

    def _atr_trail_multiplier(self, unrealized_pnl: float) -> float:
        """Return the ATR multiple for the trailing stop given unrealized PnL.

        Non-positive PnL keeps ``initial_atr_mult``. PnL at or above the scale
        threshold uses ``max_atr_mult``. In between, the multiple is
        interpolated linearly. Requires ``self.entry_atr`` to be set.
        """
        scale_threshold_pnl = (
            self.scale_threshold_mult * self.initial_atr_mult * self.entry_atr
        )
        if unrealized_pnl <= 0:
            return self.initial_atr_mult
        if unrealized_pnl >= scale_threshold_pnl:
            # Also covers a zero threshold (e.g. zero ATR): no division below.
            return self.max_atr_mult
        ratio = unrealized_pnl / scale_threshold_pnl
        return self.initial_atr_mult + ratio * (self.max_atr_mult - self.initial_atr_mult)

    def _start_tracking(self, current_bar, current_atr, is_long: bool) -> None:
        """Initialise trailing-stop state for a newly opened or adopted position.

        NOTE(review): the last history bar's close is used as the entry
        reference (as in the original logic), although the market order
        presumably fills near the next open; confirm with the engine.
        """
        self.entry_price = current_bar.close
        if is_long:
            self.highest_high = current_bar.high
            self.lowest_low = None
        else:
            self.lowest_low = current_bar.low
            self.highest_high = None
        # None means "use the legacy points/percent stop".
        self.entry_atr = current_atr if self.use_atr_trailing else None

    def _try_entry(self, bar_history, closes: np.ndarray, current_bar, current_atr) -> None:
        """Open a position when the area signal coincides with a breakout."""
        ma = talib.SMA(closes, self.ma_period)
        areas = self._calculate_areas(closes, ma)
        signal, A2 = self._area_signal(areas)
        if not signal:
            return

        # The window includes the current bar, so ">= highest" means the
        # current bar set the window high (and symmetrically for lows).
        recent_bars = bar_history[-self.breakout_window:]
        highest = max(b.high for b in recent_bars)
        lowest = min(b.low for b in recent_bars)

        if "BUY" in self.order_direction and current_bar.high >= highest:
            self.send_market_order("BUY", self.trade_volume, "OPEN")
            self._start_tracking(current_bar, current_atr, is_long=True)
            self.log(f"🚀 Long Entry | A2={A2:.4f}")
        elif "SELL" in self.order_direction and current_bar.low <= lowest:
            self.send_market_order("SELL", self.trade_volume, "OPEN")
            self._start_tracking(current_bar, current_atr, is_long=False)
            self.log(f"⬇️ Short Entry | A2={A2:.4f}")

    def _adopt_position(self, position: int, current_bar, current_atr) -> None:
        """Start trailing a live position that has no tracked stop state.

        This can happen, presumably, after on_rollover()/on_init() reset the
        state while a position is still open. Without adoption, the position
        would never be exited and no new entries would be taken.
        """
        self._start_tracking(current_bar, current_atr, is_long=position > 0)
        self.log(f"Adopted untracked position {position} at ref price {current_bar.close}")

    def _manage_exit(self, position: int, current_bar) -> None:
        """Update the trailing extreme and close the position if its stop is hit."""
        is_long = position > 0
        if is_long:
            if self.highest_high is None or current_bar.high > self.highest_high:
                self.highest_high = current_bar.high
        else:
            if self.lowest_low is None or current_bar.low < self.lowest_low:
                self.lowest_low = current_bar.low

        use_atr = self.use_atr_trailing and self.entry_atr is not None
        trail_mult = None
        if use_atr:
            if is_long:
                unrealized_pnl = current_bar.close - self.entry_price
            else:
                unrealized_pnl = self.entry_price - current_bar.close
            trail_mult = self._atr_trail_multiplier(unrealized_pnl)
            distance = trail_mult * self.entry_atr
        else:
            # Legacy fallback: percent of the extreme price, else fixed points.
            anchor = self.highest_high if is_long else self.lowest_low
            if self.trailing_percent is not None:
                distance = anchor * self.trailing_percent
            else:
                distance = self.trailing_points

        if is_long:
            stop_loss_price = self.highest_high - distance
            stop_hit = current_bar.low <= stop_loss_price
        else:
            stop_loss_price = self.lowest_low + distance
            stop_hit = current_bar.high >= stop_loss_price
        if not stop_hit:
            return

        if is_long:
            self.close_position("CLOSE_LONG", position)
        else:
            self.close_position("CLOSE_SHORT", -position)
        self._reset_state()
        if use_atr:
            self.log(f"CloseOperation (ATR Trailing) | Mult={trail_mult:.2f}")

    def on_open_bar(self, open_price: float, symbol: str):
        """Evaluate entries when flat; manage the trailing stop when in a position.

        ``open_price`` is not used; all prices come from the bar history.
        """
        self.symbol = symbol
        bar_history = self.get_bar_history()
        if len(bar_history) < self.min_bars_needed or not self.trading:
            return

        position = self.get_current_positions().get(self.symbol, 0)
        current_bar = bar_history[-1]

        closes = np.array([b.close for b in bar_history], dtype=float)
        highs = np.array([b.high for b in bar_history], dtype=float)
        lows = np.array([b.low for b in bar_history], dtype=float)

        current_atr = None
        if self.use_atr_trailing:
            atr = talib.ATR(highs, lows, closes, self.atr_period)
            # A NaN ATR would make every stop comparison False and leave the
            # position without a working stop, so treat it as unavailable.
            if np.isfinite(atr[-1]):
                current_atr = atr[-1]

        if position == 0:
            self._try_entry(bar_history, closes, current_bar, current_atr)
            return

        if self.entry_price is None:
            self._adopt_position(position, current_bar, current_atr)
        self._manage_exit(position, current_bar)

    def _reset_state(self):
        """Clear all per-position trailing-stop state."""
        self.entry_price = None
        self.highest_high = None
        self.lowest_low = None
        self.entry_atr = None

    # --- Template methods ---
    def on_init(self):
        """Initialise the base strategy, cancel pending orders and reset state."""
        super().on_init()
        self.cancel_all_pending_orders(self.main_symbol)
        self._reset_state()

    def close_position(self, direction: str, volume: int):
        """Send a market order with offset "CLOSE" for ``volume`` lots."""
        self.send_market_order(direction, volume, offset="CLOSE")

    def send_market_order(self, direction: str, volume: int, offset: str):
        """Build and submit a market order with a unique, incrementing id."""
        order_id = f"{self.symbol}_{direction}_MARKET_{self.order_id_counter}"
        self.order_id_counter += 1
        order = Order(
            id=order_id,
            symbol=self.symbol,
            direction=direction,
            volume=volume,
            price_type="MARKET",
            submitted_time=self.get_current_time(),
            offset=offset,
        )
        self.send_order(order)

    def on_rollover(self, old_symbol: str, new_symbol: str):
        """Handle a contract rollover by resetting trailing-stop state.

        If a position remains open afterwards, on_open_bar() adopts it.
        """
        super().on_rollover(old_symbol, new_symbol)
        self._reset_state()
        self.log("Rollover: Reset trailing stop state.")