import numpy as np
import talib
from scipy.signal import stft
from datetime import timedelta
from typing import Optional, Any, List, Tuple

from src.core_data import Bar, Order
from src.indicators.base_indicators import Indicator
from src.indicators.indicators import Empty
from src.strategies.base_strategy import Strategy


class SpectralTrendStrategy(Strategy):
    """
    Spectral energy phase-transition strategy - Dual-Factor Adaptive version.

    Instead of a static ``reverse`` flag, the trade direction is decided
    dynamically by two indicators:

    1. Compute the STFT-based raw direction: the sign of the least-squares
       slope of the normalized close window, gated by the low-frequency
       energy share (``trend_strength``).
    2. If ``indicator_primary`` is satisfied, trade the raw direction
       (trend-following).
    3. Otherwise, if ``indicator_secondary`` is satisfied, trade the reversed
       direction (contrarian).
    4. If neither is satisfied, stay flat.

    An optional ``model_indicator`` must also be satisfied before any entry,
    and the final direction must be listed in ``order_direction``.

    Exits: a Chandelier-style ATR trailing stop measured from the extreme
    price since entry, or ``trend_strength`` falling below ``exit_threshold``.

    State tracking: ``self.entry_signal_source`` records whether the current
    position was opened by the ``'PRIMARY'`` or the ``'SECONDARY'`` branch.
    """

    def __init__(
            self,
            context: Any,
            main_symbol: str,
            enable_log: bool,
            trade_volume: int,
            # --- Market parameters ---
            bars_per_day: int = 23,
            # --- STFT signal parameters ---
            spectral_window_days: float = 2.0,
            low_freq_days: float = 2.0,
            high_freq_days: float = 1.0,
            trend_strength_threshold: float = 0.2,
            exit_threshold: float = 0.1,
            slope_threshold: float = 0.0,
            # --- Risk parameters (Chandelier Exit) ---
            stop_loss_atr_multiplier: float = 5.0,
            stop_loss_atr_period: int = 14,
            # --- Signal-control indicators ---
            indicator_primary: Optional[Indicator] = None,  # satisfied -> trade raw direction (reverse=False)
            indicator_secondary: Optional[Indicator] = None,  # satisfied -> trade reversed direction (reverse=True)
            model_indicator: Optional[Indicator] = None,  # optional additional (e.g. AI model) filter
            # --- Other ---
            order_direction: Optional[List[str]] = None,
    ):
        """
        Configure the strategy.

        Raises:
            ValueError: if ``spectral_window_days * bars_per_day`` yields an
                STFT window shorter than 2 bars. Such a window cannot be used
                by the STFT, and a window of 0 would make ``history[-0:]``
                select the entire history.
        """
        super().__init__(context, main_symbol, enable_log)
        if order_direction is None:
            order_direction = ['BUY', 'SELL']

        self.trade_volume = trade_volume
        self.bars_per_day = bars_per_day

        # Signal parameters
        self.spectral_window_days = spectral_window_days
        self.low_freq_days = low_freq_days
        self.high_freq_days = high_freq_days
        self.trend_strength_threshold = trend_strength_threshold
        self.exit_threshold = exit_threshold
        self.slope_threshold = slope_threshold

        # Risk parameters
        self.sl_atr_multiplier = stop_loss_atr_multiplier
        self.sl_atr_period = stop_loss_atr_period
        self.order_direction = order_direction

        # Missing indicators default to Empty(). What Empty.is_condition_met
        # returns is defined elsewhere, so prefer passing concrete indicators.
        self.indicator_primary = indicator_primary or Empty()
        self.indicator_secondary = indicator_secondary or Empty()
        self.model_indicator = model_indicator or Empty()

        # STFT window length in bars, rounded up to an even number.
        self.spectral_window = int(self.spectral_window_days * self.bars_per_day)
        if self.spectral_window % 2 != 0:
            self.spectral_window += 1
        if self.spectral_window < 2:
            raise ValueError(
                "spectral_window_days * bars_per_day must be at least 1 bar "
                f"(got spectral_window={self.spectral_window})"
            )

        # Frequency bounds in cycles per day (STFT is sampled at fs=bars_per_day).
        self.low_freq_bound = 1.0 / self.low_freq_days if self.low_freq_days > 0 else float('inf')
        self.high_freq_bound = 1.0 / self.high_freq_days if self.high_freq_days > 0 else 0.0

        self.order_id_counter = 0

        # --- Position-tracking state ---
        # entry_price / pos_highest / pos_lowest feed the trailing stop;
        # entry_signal_source is 'PRIMARY', 'SECONDARY' or None.
        self._reset_position_state()

        self.log(
            f"SpectralTrend Dual-Adaptive Strategy Initialized.\n"
            f"Window: {self.spectral_window}, ATR Stop: {self.sl_atr_multiplier}x\n"
            f"Primary Ind: {type(self.indicator_primary).__name__}, "
            f"Secondary Ind: {type(self.indicator_secondary).__name__}"
        )

    def _reset_position_state(self):
        """Clear all per-position tracking state."""
        self.entry_price = 0.0
        self.pos_highest = 0.0
        self.pos_lowest = 0.0
        self.entry_signal_source = None

    def on_open_bar(self, open_price: float, symbol: str):
        """
        Per-bar entry point: cancel pending orders, compute ATR and the STFT
        market state, then either look for an entry (when flat) or manage the
        open position.

        Args:
            open_price: Open price of the new bar; used as entry limit price
                and as the price checked against the trailing stop.
            symbol: Symbol being traded; stored on ``self.symbol``.
        """
        self.symbol = symbol
        bar_history = self.get_bar_history()
        self.cancel_all_pending_orders(self.main_symbol)

        # 1. Require enough history for both the STFT window and the ATR.
        required_len = max(self.spectral_window, self.sl_atr_period + 5)
        if len(bar_history) < required_len:
            return

        # 2. ATR for the Chandelier stop (0.0 if unavailable).
        current_atr = self._compute_atr(bar_history)

        # 3. STFT-derived market state over the last spectral_window closes.
        stft_closes = np.array([b.close for b in bar_history[-self.spectral_window:]], dtype=float)
        trend_strength, trend_slope = self.calculate_market_state(stft_closes)

        # 4. Trading logic.
        position_volume = self.get_current_positions().get(self.symbol, 0)
        last_bar = bar_history[-1]
        current_high = last_bar.high
        current_low = last_bar.low

        if not self.trading:
            return

        if position_volume == 0:
            # Flat: discard any state left from a previous (or unfilled) entry.
            self._reset_position_state()
            self.evaluate_entry_signal(open_price, trend_strength, trend_slope)
        else:
            self.manage_open_position(
                position_volume, trend_strength, open_price, current_atr, current_high, current_low
            )

    def _compute_atr(self, bar_history) -> float:
        """
        Return the latest TA-Lib ATR over the most recent ``sl_atr_period + 10``
        bars, or 0.0 when the result is NaN or the computation raises.
        """
        recent = bar_history[-(self.sl_atr_period + 10):]
        highs = np.array([b.high for b in recent], dtype=float)
        lows = np.array([b.low for b in recent], dtype=float)
        closes = np.array([b.close for b in recent], dtype=float)

        try:
            atr_values = talib.ATR(highs, lows, closes, timeperiod=self.sl_atr_period)
            current_atr = atr_values[-1]
            if np.isnan(current_atr):
                current_atr = 0.0
        except Exception as e:
            self.log(f"ATR Calc Error: {e}")
            current_atr = 0.0
        return current_atr

    def calculate_market_state(self, prices: np.ndarray) -> Tuple[float, float]:
        """
        Compute the low-frequency energy share and the linear-regression slope.

        Args:
            prices: 1-D array of closes; only the last ``spectral_window``
                values are used.

        Returns:
            ``(trend_strength, slope)``. ``trend_strength`` is
            ``low_energy / (low_energy + high_energy)`` taken from the last
            STFT segment, with a small epsilon added to the denominator.
            ``slope`` is the least-squares slope, per bar, of the
            z-score-normalized window. Returns ``(0.0, 0.0)`` when there is too
            little data or the STFT raises.
        """
        if len(prices) < self.spectral_window:
            return 0.0, 0.0

        window_data = prices[-self.spectral_window:]
        mean_val = np.mean(window_data)
        std_val = np.std(window_data)
        if std_val == 0:
            # Flat window: avoid dividing by zero (normalized data becomes all zeros).
            std_val = 1.0
        normalized = (window_data - mean_val) / (std_val + 1e-8)

        # With nperseg == len(normalized), boundary=None and padded=False this
        # produces a single segment covering the whole window.
        # NOTE(review): frequency bins are spaced bars_per_day / spectral_window
        # cycles/day. With the defaults (23 bars/day, 2-day window, low_freq_days=2)
        # that spacing equals low_freq_bound, so only the DC bin counts as
        # "low frequency". Confirm this is intended.
        try:
            f, _, Zxx = stft(
                normalized,
                fs=self.bars_per_day,
                nperseg=self.spectral_window,
                noverlap=max(0, self.spectral_window // 2),
                boundary=None,
                padded=False
            )
        except Exception as e:
            # Previously swallowed silently; log so configuration problems surface.
            self.log(f"STFT Calc Error: {e}")
            return 0.0, 0.0

        valid_mask = (f >= 0) & (f <= self.bars_per_day / 2)
        f = f[valid_mask]
        Zxx = Zxx[valid_mask, :]

        if Zxx.size == 0:
            return 0.0, 0.0

        # Power spectrum of the most recent segment.
        current_energy = np.abs(Zxx[:, -1]) ** 2

        low_mask = f < self.low_freq_bound
        high_mask = f > self.high_freq_bound

        low_energy = np.sum(current_energy[low_mask]) if np.any(low_mask) else 0.0
        high_energy = np.sum(current_energy[high_mask]) if np.any(high_mask) else 0.0
        total_energy = low_energy + high_energy + 1e-8

        trend_strength = low_energy / total_energy

        x = np.arange(len(normalized))
        slope, _ = np.polyfit(x, normalized, 1)

        return trend_strength, slope

    def evaluate_entry_signal(self, open_price: float, trend_strength: float, trend_slope: float):
        """
        Entry logic with dual-indicator direction control.

        Sends a limit order at ``open_price`` and records the entry state when
        all of these hold: the energy threshold passes, the slope gives a
        direction, the primary or secondary indicator fires, the model filter
        passes, and the final direction is allowed by ``order_direction``.
        """
        # 1. Low-frequency energy threshold.
        if trend_strength <= self.trend_strength_threshold:
            return

        # 2. Raw STFT direction from the regression slope.
        raw_direction = None
        if trend_slope > self.slope_threshold:
            raw_direction = "BUY"
        elif trend_slope < -self.slope_threshold:
            raw_direction = "SELL"

        if not raw_direction:
            return

        # 3. Dual-indicator branching. The argument tuple comes from the base
        # class (its contents are defined there).
        indicator_args = self.get_indicator_tuple()

        final_direction = None
        source_tag = None

        if self.indicator_primary.is_condition_met(*indicator_args):
            # Branch 1 (preferred): keep the raw direction (reverse=False).
            final_direction = raw_direction
            source_tag = "PRIMARY"
        elif self.indicator_secondary.is_condition_met(*indicator_args):
            # Branch 2 (fallback): flip the raw direction (reverse=True).
            final_direction = "SELL" if raw_direction == "BUY" else "BUY"
            source_tag = "SECONDARY"
        else:
            # Branch 3: neither condition holds -> no trade.
            return

        # 4. Global model filter (optional).
        if not self.model_indicator.is_condition_met(*indicator_args):
            return

        # 5. Direction whitelist.
        if final_direction not in self.order_direction:
            return

        # 6. Open the position.
        self.log(
            f"Entry Triggered [{source_tag}]: "
            f"Raw={raw_direction} -> Final={final_direction} | "
            f"Strength={trend_strength:.2f} | Slope={trend_slope:.4f}"
        )
        self.send_limit_order(final_direction, open_price, self.trade_volume, "OPEN")

        # Record entry state; the trailing stop starts from the entry price.
        self.entry_price = open_price
        self.pos_highest = open_price
        self.pos_lowest = open_price
        self.entry_signal_source = source_tag

    def manage_open_position(self, volume: int, trend_strength: float, current_price: float,
                             current_atr: float, high_price: float, low_price: float):
        """
        Exit logic for an open position.

        ``volume > 0`` is treated as long, anything else as short. The
        position is closed at market when the Chandelier trailing stop is hit
        (only checked when ``current_atr > 0``). It is also closed when
        ``trend_strength`` drops below ``exit_threshold``.
        """
        exit_dir = "CLOSE_LONG" if volume > 0 else "CLOSE_SHORT"

        # Track the favourable extreme since entry (0 means "not yet set").
        if volume > 0:
            if high_price > self.pos_highest or self.pos_highest == 0:
                self.pos_highest = high_price
        else:
            if (low_price < self.pos_lowest or self.pos_lowest == 0) and low_price > 0:
                self.pos_lowest = low_price

        # 1. Structural stop (Chandelier Exit): extreme -/+ ATR * multiplier.
        is_stop_triggered = False
        stop_line = 0.0

        if current_atr > 0:
            stop_distance = current_atr * self.sl_atr_multiplier
            if volume > 0:
                stop_line = self.pos_highest - stop_distance
                if current_price <= stop_line:
                    is_stop_triggered = True
            else:
                stop_line = self.pos_lowest + stop_distance
                if current_price >= stop_line:
                    is_stop_triggered = True

        if is_stop_triggered:
            self.log(
                f"STOP Loss ({self.entry_signal_source}): Price {current_price} hit Chandelier line {stop_line:.2f}")
            self.close_position(exit_dir, abs(volume))
            return

        # 2. Signal exhaustion exit.
        if trend_strength < self.exit_threshold:
            self.log(
                f"Exit Signal ({self.entry_signal_source}): Energy Faded {trend_strength:.2f} < {self.exit_threshold}")
            self.close_position(exit_dir, abs(volume))
            return

    # --- Order helpers ---
    def close_position(self, direction: str, volume: int):
        """Close ``volume`` units with a market order (offset CLOSE)."""
        self.send_market_order(direction, volume, offset="CLOSE")

    def _next_order_id(self, direction: str, kind: str) -> str:
        """Return ``{symbol}_{direction}_{kind}_{n}`` and advance the order counter."""
        order_id = f"{self.symbol}_{direction}_{kind}_{self.order_id_counter}"
        self.order_id_counter += 1
        return order_id

    def send_market_order(self, direction: str, volume: int, offset: str):
        """Build and submit a MARKET order for ``self.symbol``."""
        order = Order(
            id=self._next_order_id(direction, "MKT"),
            symbol=self.symbol,
            direction=direction,
            volume=volume,
            price_type="MARKET",
            submitted_time=self.get_current_time(),
            offset=offset
        )
        self.send_order(order)

    def send_limit_order(self, direction: str, limit_price: float, volume: int, offset: str):
        """Build and submit a LIMIT order at ``limit_price`` for ``self.symbol``."""
        order = Order(
            id=self._next_order_id(direction, "LMT"),
            symbol=self.symbol,
            direction=direction,
            volume=volume,
            price_type="LIMIT",
            submitted_time=self.get_current_time(),
            offset=offset,
            limit_price=limit_price
        )
        self.send_order(order)