from datetime import timedelta
from typing import Optional, Any, List, Tuple

import numpy as np
import talib
from scipy.signal import stft

from src.core_data import Bar, Order
from src.indicators.base_indicators import Indicator
from src.indicators.indicators import Empty
from src.strategies.base_strategy import Strategy


class SpectralTrendStrategy(Strategy):
    """
    Spectral energy phase-transition strategy -- wide-structure variant with a
    Chandelier Exit.

    Design notes:
    1. No slope-based exit (avoids noise); the regression slope only selects
       the entry direction.
    2. While a position is open, the most favourable extreme is tracked
       (``pos_highest`` for longs, ``pos_lowest`` for shorts).
    3. Chandelier stop: the position is closed when price retraces
       N * ATR from that extreme.
    4. Suggested stop multiplier: 4.0 - 6.0 (very wide room; intended only
       to guard against a collapse of the trend).
    """

    def __init__(
        self,
        context: Any,
        main_symbol: str,
        enable_log: bool,
        trade_volume: int,
        # --- market parameters ---
        bars_per_day: int = 23,
        # --- STFT signal parameters ---
        spectral_window_days: float = 2.0,
        low_freq_days: float = 2.0,
        high_freq_days: float = 1.0,
        trend_strength_threshold: float = 0.2,  # entry: minimum low-frequency energy share
        exit_threshold: float = 0.1,  # exit: energy share below which the trend is treated as exhausted
        slope_threshold: float = 0.0,  # direction filter only, never used for exits
        # --- risk parameters (Chandelier Exit) ---
        stop_loss_atr_multiplier: float = 5.0,  # wide stop; 4.0 ~ 6.0 suggested
        stop_loss_atr_period: int = 14,
        # --- other ---
        order_direction: Optional[List[str]] = None,
        indicators: Optional[Indicator] = None,
        model_indicator: Optional[Indicator] = None,
        reverse: bool = False,
    ):
        """Store parameters and derive the STFT window and frequency bounds.

        Args:
            bars_per_day: bars per trading day; used as the STFT sampling
                rate, so all frequencies below are in cycles per day.
            spectral_window_days: STFT window length in days (rounded up to
                an even number of bars).
            low_freq_days / high_freq_days: periods (in days) whose inverses
                are the low / high frequency band bounds. A non-positive
                value makes that band cover every frequency bin.
            order_direction: allowed raw signal directions; defaults to
                ``['BUY', 'SELL']``.
            indicators: entry filter; entries are skipped when it is not met.
            model_indicator: when its condition is NOT met, the entry
                direction is flipped.
            reverse: flip every entry direction.
        """
        super().__init__(context, main_symbol, enable_log)

        if order_direction is None:
            order_direction = ['BUY', 'SELL']

        self.trade_volume = trade_volume
        self.bars_per_day = bars_per_day

        # Signal parameters
        self.spectral_window_days = spectral_window_days
        self.low_freq_days = low_freq_days
        self.high_freq_days = high_freq_days
        self.trend_strength_threshold = trend_strength_threshold
        self.exit_threshold = exit_threshold
        self.slope_threshold = slope_threshold

        # Risk parameters
        self.sl_atr_multiplier = stop_loss_atr_multiplier
        self.sl_atr_period = stop_loss_atr_period

        self.order_direction = order_direction
        self.model_indicator = model_indicator or Empty()
        self.indicators = indicators or Empty()
        self.reverse = reverse

        # STFT window in bars, forced to an even length.
        self.spectral_window = int(self.spectral_window_days * self.bars_per_day)
        if self.spectral_window % 2 != 0:
            self.spectral_window += 1

        self.low_freq_bound = 1.0 / self.low_freq_days if self.low_freq_days > 0 else float('inf')
        self.high_freq_bound = 1.0 / self.high_freq_days if self.high_freq_days > 0 else 0.0

        self.order_id_counter = 0

        # --- Position tracking state ---
        self.entry_price = 0.0
        self.pos_highest = 0.0  # highest price seen while long
        self.pos_lowest = 0.0  # lowest price seen while short

        self.log(
            f"SpectralTrend Strategy Initialized. Window: {self.spectral_window}, "
            f"Chandelier Stop: {self.sl_atr_multiplier}x ATR"
        )
        if self.low_freq_bound > self.high_freq_bound:
            # Bins with high_freq_bound < f < low_freq_bound satisfy both masks in
            # calculate_market_state and are counted twice in total_energy.
            self.log(
                f"Warning: low band (< {self.low_freq_bound}) overlaps high band "
                f"(> {self.high_freq_bound}); overlapping bins are counted in both."
            )

    def on_open_bar(self, open_price: float, symbol: str):
        """Run one decision step at the open of a bar.

        Cancels pending orders on the main symbol. When flat, it evaluates an
        entry; otherwise it manages the open position (chandelier stop, then
        signal exit). Nothing else happens until enough history exists or
        while ``self.trading`` is False.
        """
        self.symbol = symbol
        bar_history = self.get_bar_history()
        self.cancel_all_pending_orders(self.main_symbol)

        # 1. Need enough bars for both the STFT window and the ATR warm-up.
        required_len = max(self.spectral_window, self.sl_atr_period + 5)
        if len(bar_history) < required_len:
            return

        # The indicators below only feed trading decisions.
        if not self.trading:
            return

        # 2. ATR for the chandelier stop.
        current_atr = self._compute_atr(bar_history)

        # 3. STFT trend strength and regression slope.
        stft_closes = np.array([b.close for b in bar_history[-self.spectral_window:]], dtype=float)
        trend_strength, trend_slope = self.calculate_market_state(stft_closes)

        # 4. Trading logic.
        position_volume = self.get_current_positions().get(self.symbol, 0)

        if position_volume == 0:
            # Flat: clear tracking state from any previous trade before a new entry.
            self._reset_position_state()
            self.evaluate_entry_signal(open_price, trend_strength, trend_slope)
            return

        # Extremes are updated from the most recent bar in history (presumably
        # the last completed bar -- confirm with get_bar_history), not from the
        # current open.
        last_bar = bar_history[-1]
        self.manage_open_position(
            position_volume, trend_strength, open_price, current_atr, last_bar.high, last_bar.low
        )

    def _compute_atr(self, bar_history) -> float:
        """Return the latest ATR(sl_atr_period) over recent bars, or 0.0 if unavailable.

        Only the last ``sl_atr_period + 10`` bars are passed to TA-Lib. A NaN
        result or any TA-Lib error yields 0.0, which disables the chandelier
        stop for this bar.
        """
        recent = bar_history[-(self.sl_atr_period + 10):]
        highs = np.array([b.high for b in recent], dtype=float)
        lows = np.array([b.low for b in recent], dtype=float)
        closes = np.array([b.close for b in recent], dtype=float)
        try:
            current_atr = talib.ATR(highs, lows, closes, timeperiod=self.sl_atr_period)[-1]
            if np.isnan(current_atr):
                current_atr = 0.0
        except Exception as e:
            self.log(f"ATR Calc Error: {e}")
            current_atr = 0.0
        return current_atr

    def _reset_position_state(self):
        """Clear the entry price and the tracked in-position extremes."""
        self.pos_highest = 0.0
        self.pos_lowest = 0.0
        self.entry_price = 0.0

    def calculate_market_state(self, prices: np.ndarray) -> Tuple[float, float]:
        """Return ``(trend_strength, slope)`` for the latest spectral window.

        Frequency units:
            The STFT is sampled at ``fs = bars_per_day``, so frequencies are
            in cycles per day.

        trend_strength:
            The low-band energy share, low / (low + high), of the last STFT
            segment. It lies in [0, 1).

        slope:
            The least-squares slope, per bar, of the z-scored prices. It is
            used only for entry direction.

        Fallback:
            Returns ``(0.0, 0.0)`` when fewer than ``spectral_window`` prices
            are given or the STFT fails. Note that a 0.0 strength triggers
            the signal exit for an open position.
        """
        if len(prices) < self.spectral_window:
            return 0.0, 0.0

        # z-score the window so the result does not depend on price level.
        window_data = prices[-self.spectral_window:]
        mean_val = np.mean(window_data)
        std_val = np.std(window_data)
        if std_val == 0:
            std_val = 1.0
        normalized = (window_data - mean_val) / (std_val + 1e-8)

        # The input is exactly one segment long (nperseg == window, no boundary
        # extension or padding).
        try:
            f, _, Zxx = stft(
                normalized,
                fs=self.bars_per_day,
                nperseg=self.spectral_window,
                noverlap=max(0, self.spectral_window // 2),
                boundary=None,
                padded=False,
            )
        except Exception as e:
            # Previously swallowed silently; the (0, 0) fallback can close a
            # position via the signal exit, so leave a trace.
            self.log(f"STFT Calc Error: {e}")
            return 0.0, 0.0

        # Keep the one-sided frequency range [0, Nyquist].
        valid_mask = (f >= 0) & (f <= self.bars_per_day / 2)
        f = f[valid_mask]
        Zxx = Zxx[valid_mask, :]
        if Zxx.size == 0:
            return 0.0, 0.0

        # Power spectrum of the most recent segment.
        current_energy = np.abs(Zxx[:, -1]) ** 2
        low_mask = f < self.low_freq_bound
        high_mask = f > self.high_freq_bound
        low_energy = np.sum(current_energy[low_mask]) if np.any(low_mask) else 0.0
        high_energy = np.sum(current_energy[high_mask]) if np.any(high_mask) else 0.0
        total_energy = low_energy + high_energy + 1e-8
        trend_strength = low_energy / total_energy

        # Linear trend of the normalized window, used for direction only.
        x = np.arange(len(normalized))
        slope, _ = np.polyfit(x, normalized, 1)

        return float(trend_strength), float(slope)

    def evaluate_entry_signal(self, open_price: float, trend_strength: float, trend_slope: float):
        """Send an entry limit order at ``open_price`` when the signal qualifies.

        Entry conditions:
            - ``trend_strength`` exceeds ``trend_strength_threshold``.
            - The slope sign picks an allowed direction.
            - ``self.indicators`` is met.

        Direction flips:
            The direction is flipped when ``model_indicator`` is NOT met, and
            flipped again when ``reverse`` is set.

        Side effect:
            On entry, ``entry_price``, ``pos_highest`` and ``pos_lowest`` are
            seeded with ``open_price``.
        """
        # Written as `not >` so a NaN strength is rejected, as before.
        if not trend_strength > self.trend_strength_threshold:
            return

        # The slope only chooses the direction; it is never used for exits.
        direction = None
        if "BUY" in self.order_direction and trend_slope > self.slope_threshold:
            direction = "BUY"
        elif "SELL" in self.order_direction and trend_slope < -self.slope_threshold:
            direction = "SELL"
        if not direction:
            return

        # External filters.
        indicator_args = self.get_indicator_tuple()
        if not self.indicators.is_condition_met(*indicator_args):
            return
        # NOTE(review): the flips below are not re-checked against
        # order_direction, so a flipped direction may be one that is not listed
        # there -- confirm this is intended.
        if not self.model_indicator.is_condition_met(*indicator_args):
            direction = "SELL" if direction == "BUY" else "BUY"
        if self.reverse:
            direction = "SELL" if direction == "BUY" else "BUY"

        self.log(f"Entry: {direction} | Strength={trend_strength:.2f} | DirSlope={trend_slope:.4f}")
        self.send_limit_order(direction, open_price, self.trade_volume, "OPEN")

        # Seed tracking state; the extremes start at the entry price.
        self.entry_price = open_price
        self.pos_highest = open_price
        self.pos_lowest = open_price

    def manage_open_position(self, volume: int, trend_strength: float, current_price: float,
                             current_atr: float, high_price: float, low_price: float):
        """Apply the exit rules to an open position (positive volume = long, otherwise short).

        1. Structural exit (checked first):
           - The chandelier stop is anchored at the in-position extreme and
             sits ``sl_atr_multiplier * current_atr`` away from it.
           - It is skipped when ``current_atr`` is not positive.
        2. Signal exit:
           - Triggered when ``trend_strength`` falls below ``exit_threshold``,
             i.e. the low-frequency energy has faded.

        Both exits close the whole position with a market order.
        """
        exit_dir = "CLOSE_LONG" if volume > 0 else "CLOSE_SHORT"

        # --- Update the in-position extreme ---
        if volume > 0:
            if high_price > self.pos_highest or self.pos_highest == 0:
                self.pos_highest = high_price
        elif (low_price < self.pos_lowest or self.pos_lowest == 0) and low_price > 0:
            self.pos_lowest = low_price

        # --- 1. Wide chandelier stop (structural) ---
        # Without a valid ATR the stop is skipped for this bar; no percentage
        # fallback is implemented.
        if current_atr > 0:
            stop_distance = current_atr * self.sl_atr_multiplier
            if volume > 0:
                # Long stop = highest since entry - N * ATR. The anchor only
                # ratchets up, but ATR is recomputed every bar, so the line can
                # also move down when volatility expands.
                stop_line = self.pos_highest - stop_distance
                if current_price <= stop_line:
                    self.log(
                        f"STOP (Long): Price {current_price} <= Highest {self.pos_highest} - {self.sl_atr_multiplier}xATR")
                    self.close_position(exit_dir, abs(volume))
                    return  # the stop takes priority over the signal exit
            else:
                # Short stop = lowest since entry + N * ATR (same ATR caveat).
                stop_line = self.pos_lowest + stop_distance
                if current_price >= stop_line:
                    self.log(
                        f"STOP (Short): Price {current_price} >= Lowest {self.pos_lowest} + {self.sl_atr_multiplier}xATR")
                    self.close_position(exit_dir, abs(volume))
                    return  # the stop takes priority over the signal exit

        # --- 2. Signal exit ---
        # Low-frequency energy has faded (range-bound / noisy market): a slow
        # exit that typically fires after the trend has run its course.
        if trend_strength < self.exit_threshold:
            self.log(f"Exit (Signal): Strength {trend_strength:.2f} < {self.exit_threshold}")
            self.close_position(exit_dir, abs(volume))

    # --- Order execution helpers ---
    def close_position(self, direction: str, volume: int):
        """Close ``volume`` via a market order with offset "CLOSE"."""
        self.send_market_order(direction, volume, offset="CLOSE")

    def _next_order_id(self, direction: str, tag: str) -> str:
        """Build ``{symbol}_{direction}_{tag}_{counter}`` and advance the counter."""
        order_id = f"{self.symbol}_{direction}_{tag}_{self.order_id_counter}"
        self.order_id_counter += 1
        return order_id

    def send_market_order(self, direction: str, volume: int, offset: str):
        """Submit a MARKET order for ``volume`` on the current symbol."""
        order = Order(
            id=self._next_order_id(direction, "MKT"),
            symbol=self.symbol,
            direction=direction,
            volume=volume,
            price_type="MARKET",
            submitted_time=self.get_current_time(),
            offset=offset,
        )
        self.send_order(order)

    def send_limit_order(self, direction: str, limit_price: float, volume: int, offset: str):
        """Submit a LIMIT order at ``limit_price`` for ``volume`` on the current symbol."""
        order = Order(
            id=self._next_order_id(direction, "LMT"),
            symbol=self.symbol,
            direction=direction,
            volume=volume,
            price_type="LIMIT",
            submitted_time=self.get_current_time(),
            offset=offset,
            limit_price=limit_price,
        )
        self.send_order(order)