import numpy as np
from scipy.signal import stft
from datetime import datetime, timedelta
from typing import Optional, Any, List, Dict, Tuple

from src.core_data import Bar, Order
from src.indicators.base_indicators import Indicator
from src.indicators.indicators import Empty, NormalizedATR, AtrVolatility, ZScoreATR
from src.strategies.base_strategy import Strategy


# =============================================================================
# Strategy implementation (SpectralTrendStrategy)
# =============================================================================

class SpectralTrendStrategy(Strategy):
    """
    Frequency-domain energy "phase transition" strategy aimed at fat-tailed trends.

    Core ideas:
    1. Explicit Fourier analysis: an STFT over a rolling window separates
       low-frequency (trend) energy from high-frequency (noise) energy.
    2. Phase-transition entry: only enter when the low-frequency share of the
       energy exceeds ``trend_strength_threshold``.
    3. Low-frequency trading: design goal is roughly 2-5 signals per month,
       holding for several days to capture fat tails.
    4. Fully parameterized: bands are given in days and converted to
       cycles/day via ``bars_per_day``, so no session structure is hard-coded.

    Parameters:
    - bars_per_day: number of bars per trading day (e.g. 23 for 15-min US
      markets); used as the STFT sampling rate.
    - spectral_window_days: STFT window length in days.
    - low_freq_days: periods longer than this (frequencies below
      1 / low_freq_days cycles/day) count as trend energy. Default 2.0.
    - high_freq_days: periods shorter than this (frequencies above
      1 / high_freq_days cycles/day) count as noise energy. Default 1.0.
    - trend_strength_threshold: entry threshold on the low-frequency share.
    - exit_threshold: an open position is closed when the share falls below it.
    - max_hold_days: hard cap on holding time, measured with ``timedelta(days=...)``.

    NOTE(review): with the defaults (bars_per_day=23, spectral_window_days=2.0),
    the window is 46 bars, so STFT bin spacing is fs / nperseg = 0.5 cycles/day.
    The low band (f < 0.5) then contains only the DC bin, and the 0.5 and
    1.0 cycles/day bins belong to neither band. Also, the default
    exit_threshold (0.4) is above the default entry threshold (0.1), so an entry
    with strength in (0.1, 0.4) is exited on the next bar if strength is
    unchanged. Confirm these defaults are intended.
    """

    def __init__(
        self,
        context: Any,
        main_symbol: str,
        enable_log: bool,
        trade_volume: int,
        # --- Market-structure parameters ---
        bars_per_day: int = 23,  # bars per day; also the STFT sampling rate
        # --- Frequency-domain parameters ---
        spectral_window_days: float = 2.0,  # STFT window length (days)
        low_freq_days: float = 2.0,  # lower period bound of the trend band (days)
        high_freq_days: float = 1.0,  # upper period bound of the noise band (days)
        trend_strength_threshold: float = 0.1,  # entry threshold on the low-frequency share
        exit_threshold: float = 0.4,  # exit threshold on the low-frequency share
        # --- Position management ---
        max_hold_days: int = 10,  # maximum holding period (days)
        # --- Other ---
        order_direction: Optional[List[str]] = None,
        indicators: Optional[List[Indicator]] = None,
        model_indicator: Optional[Indicator] = None,
    ):
        super().__init__(context, main_symbol, enable_log)
        if order_direction is None:
            order_direction = ['BUY', 'SELL']
        if indicators is None:
            indicators = [Empty(), Empty()]  # kept for backward compatibility

        # --- Parameter assignment ---
        self.trade_volume = trade_volume
        self.bars_per_day = bars_per_day
        self.spectral_window_days = spectral_window_days
        self.low_freq_days = low_freq_days
        self.high_freq_days = high_freq_days
        self.trend_strength_threshold = trend_strength_threshold
        self.exit_threshold = exit_threshold
        self.max_hold_days = max_hold_days
        self.order_direction = order_direction
        if model_indicator is None:
            model_indicator = Empty()
        self.model_indicator = model_indicator

        # --- Derived parameters ---
        self.spectral_window = int(self.spectral_window_days * self.bars_per_day)
        # Round the window up to an even bar count so that noverlap = window // 2
        # is an exact half overlap.
        if self.spectral_window % 2 != 0:
            self.spectral_window += 1

        # Band edges in cycles/day. A non-positive low_freq_days puts every bin
        # in the low band (bound = inf). A non-positive high_freq_days puts every
        # bin above 0 cycles/day in the high band (bound = 0.0).
        self.low_freq_bound = 1.0 / self.low_freq_days if self.low_freq_days > 0 else float('inf')
        self.high_freq_bound = 1.0 / self.high_freq_days if self.high_freq_days > 0 else 0.0

        # --- Internal state ---
        self.main_symbol = main_symbol
        self.order_id_counter = 0
        # Stored for the base class; presumably consumed by get_indicator_tuple() -- confirm.
        self.indicators = indicators
        self.entry_time = None  # time the entry order was submitted
        self.position_direction = None  # 'LONG' or 'SHORT'
        self.last_trend_strength = 0.0
        self.last_dominant_freq = 0.0  # dominant low-frequency period (days), 0.0 if none

        self.log(f"SpectralTrendStrategy Initialized (bars/day={bars_per_day}, window={self.spectral_window} bars)")

    def on_open_bar(self, open_price: float, symbol: str):
        """Handle the open of each bar.

        Cancels pending orders for ``main_symbol`` and computes the spectral trend
        strength from the close history. It then enforces the maximum holding
        period while a position is open. Finally, if ``self.trading`` is truthy,
        it either evaluates an entry (when flat) or manages the open position.
        """
        self.symbol = symbol
        bar_history = self.get_bar_history()
        current_time = self.get_current_time()

        self.cancel_all_pending_orders(self.main_symbol)

        # Need the STFT window plus a small buffer of bars before doing anything.
        if len(bar_history) < self.spectral_window + 10:
            if self.enable_log and len(bar_history) % 50 == 0:
                self.log(f"Waiting for {len(bar_history)}/{self.spectral_window + 10} bars")
            return

        position_volume = self.get_current_positions().get(self.symbol, 0)

        # Full close history; calculate_trend_strength only uses the last window.
        closes = np.array([b.close for b in bar_history], dtype=float)

        trend_strength, dominant_freq = self.calculate_trend_strength(closes)
        self.last_trend_strength = trend_strength
        self.last_dominant_freq = dominant_freq

        # Hard cap on holding time. Only applied while a position is actually open:
        # entry_time is set when the entry order is submitted, and if that order
        # never filled there is nothing to exit. In that case entry evaluation
        # should proceed instead of being skipped.
        if (position_volume != 0
                and self.entry_time is not None
                and (current_time - self.entry_time) >= timedelta(days=self.max_hold_days)):
            self.log(f"Max hold time reached ({self.max_hold_days} days). Forcing exit.")
            self.close_all_positions()
            self.entry_time = None
            self.position_direction = None
            return

        # Core logic: phase-transition entry / exit.
        if self.trading:
            if position_volume == 0:
                self.evaluate_entry_signal(open_price, trend_strength, dominant_freq)
            else:
                self.manage_open_position(position_volume, trend_strength, dominant_freq)

    def calculate_trend_strength(self, prices: np.ndarray) -> Tuple[float, float]:
        """Return ``(trend_strength, dominant_period_days)`` for the latest window.

        Steps:
        1. z-score normalize the last ``spectral_window`` prices.
        2. Run an STFT with ``fs = bars_per_day``, so frequencies are in cycles/day.
           Because the input length equals ``nperseg`` and ``boundary=None``,
           ``padded=False``, this yields a single segment.
        3. Split the power of the last segment into a low band
           (f < low_freq_bound) and a high band (f > high_freq_bound).
        4. trend_strength = low / (low + high), in [0, 1).

        The second value (named ``dominant_freq`` elsewhere for compatibility) is
        actually a *period in days*: 1 / f of the strongest non-DC bin in the
        low band, or 0.0 if there is none.

        Returns ``(0.0, 0.0)`` if there is not enough data, the window size is
        non-positive, or the STFT fails.
        """
        # 1. Enough data? (A non-positive window would make prices[-0:] select
        #    the entire history, so it is rejected explicitly.)
        if self.spectral_window <= 0 or len(prices) < self.spectral_window:
            return 0.0, 0.0

        # 2. Normalize within the window only.
        window_data = prices[-self.spectral_window:]
        normalized = (window_data - np.mean(window_data)) / (np.std(window_data) + 1e-8)

        # 3. STFT sampled at bars_per_day, so frequencies are in cycles/day.
        try:
            f, _segment_times, Zxx = stft(
                normalized,
                fs=self.bars_per_day,
                nperseg=self.spectral_window,
                noverlap=self.spectral_window // 2,
                boundary=None,
                padded=False,
            )
        except Exception as e:
            # Best-effort: a failed transform means "no signal" rather than a crash.
            self.log(f"STFT calculation error: {str(e)}")
            return 0.0, 0.0

        if Zxx.size == 0 or Zxx.shape[1] == 0:
            return 0.0, 0.0

        # 4. Power spectrum of the most recent segment. For real input, the
        #    one-sided frequencies already lie in [0, fs/2].
        current_energy = np.abs(Zxx[:, -1]) ** 2

        # 5. Bands in cycles/day.
        #    Low band:  period > low_freq_days  -> f < 1 / low_freq_days
        #    High band: period < high_freq_days -> f > 1 / high_freq_days
        low_freq_mask = f < self.low_freq_bound
        high_freq_mask = f > self.high_freq_bound

        # Summing an empty selection yields 0.0.
        low_energy = float(np.sum(current_energy[low_freq_mask]))
        high_energy = float(np.sum(current_energy[high_freq_mask]))
        total_energy = low_energy + high_energy + 1e-8  # guard against division by zero

        # 6. Trend strength = low-frequency share of the banded energy.
        trend_strength = low_energy / total_energy

        # 7. Dominant trend period (days). The DC bin (0 cycles/day) has no finite
        #    period, so it is excluded from the search.
        dominant_freq = 0.0
        periodic_low_mask = low_freq_mask & (f > 0)
        if np.any(periodic_low_mask):
            low_energies = current_energy[periodic_low_mask]
            if np.max(low_energies) > 0:
                strongest = f[periodic_low_mask][np.argmax(low_energies)]
                dominant_freq = float(1.0 / strongest)

        return trend_strength, dominant_freq

    def evaluate_entry_signal(self, open_price: float, trend_strength: float, dominant_freq: float):
        """Submit an entry limit order at ``open_price`` on a phase-transition signal.

        Entry requires ``trend_strength > trend_strength_threshold`` and
        ``model_indicator.is_condition_met(*get_indicator_tuple())``. The
        direction is chosen as follows:
        - BUY if "BUY" is allowed and the mean of the last 5 closes is above the
          window mean;
        - otherwise SELL if "SELL" is allowed and that mean is below the window mean.

        ``dominant_freq`` is accepted for interface compatibility and is not used.
        """
        self.log(f"Strength={trend_strength:.2f}")

        if trend_strength <= self.trend_strength_threshold:
            return
        # The model filter is evaluated once. Its result gates the entry; it does
        # not flip the direction.
        if not self.model_indicator.is_condition_met(*self.get_indicator_tuple()):
            return

        closes = np.array([b.close for b in self.get_bar_history()[-self.spectral_window:]], dtype=float)
        recent_mean = np.mean(closes[-5:])
        window_mean = np.mean(closes)

        direction = None
        if "BUY" in self.order_direction and recent_mean > window_mean:
            # Long signal: recent prices above the window mean.
            direction = "BUY"
        elif "SELL" in self.order_direction and recent_mean < window_mean:
            # Short signal: recent prices below the window mean.
            direction = "SELL"

        if direction is None:
            return

        self.send_limit_order(direction, open_price, self.trade_volume, "OPEN")
        # Recorded at submission; the order may still fail to fill.
        self.entry_time = self.get_current_time()
        self.position_direction = "LONG" if direction == "BUY" else "SHORT"

    def manage_open_position(self, volume: int, trend_strength: float, dominant_freq: float):
        """Close the position when the low-frequency share falls below ``exit_threshold``.

        ``volume`` is the signed position size (positive = long). ``dominant_freq``
        is accepted for interface compatibility and is not used.
        """
        if trend_strength < self.exit_threshold:
            direction = "CLOSE_LONG" if volume > 0 else "CLOSE_SHORT"
            self.log(f"Phase Transition Exit: {direction} | Strength={trend_strength:.2f} < {self.exit_threshold}")
            self.close_position(direction, abs(volume))
            self.entry_time = None
            self.position_direction = None

    # --- Helpers ---

    def close_all_positions(self):
        """Send a market order closing any non-zero position in ``self.symbol``."""
        positions = self.get_current_positions()
        # Capture the size before sending the order, in case the position store
        # is updated when the close is processed.
        volume = positions.get(self.symbol, 0)
        if volume != 0:
            direction = "CLOSE_LONG" if volume > 0 else "CLOSE_SHORT"
            self.close_position(direction, abs(volume))
            self.log(f"Forced exit of {abs(volume)} contracts")

    def close_position(self, direction: str, volume: int):
        """Close ``volume`` contracts with a market order (offset "CLOSE")."""
        self.send_market_order(direction, volume, offset="CLOSE")

    def send_market_order(self, direction: str, volume: int, offset: str):
        """Build and submit a market order for ``self.symbol``."""
        order_id = f"{self.symbol}_{direction}_MARKET_{self.order_id_counter}"
        self.order_id_counter += 1
        order = Order(
            id=order_id,
            symbol=self.symbol,
            direction=direction,
            volume=volume,
            price_type="MARKET",
            submitted_time=self.get_current_time(),
            offset=offset,
        )
        self.send_order(order)

    def send_limit_order(self, direction: str, limit_price: float, volume: int, offset: str):
        """Build and submit a limit order at ``limit_price`` for ``self.symbol``."""
        # IDs share one counter with market orders, so they stay unique; the
        # type tag reflects that this is a LIMIT order.
        order_id = f"{self.symbol}_{direction}_LIMIT_{self.order_id_counter}"
        self.order_id_counter += 1
        order = Order(
            id=order_id,
            symbol=self.symbol,
            direction=direction,
            volume=volume,
            price_type="LIMIT",
            submitted_time=self.get_current_time(),
            offset=offset,
            limit_price=limit_price,
        )
        self.send_order(order)

    def on_init(self):
        """Initialize via the base class and clear any pending orders for ``main_symbol``."""
        super().on_init()
        self.cancel_all_pending_orders(self.main_symbol)
        self.log("Strategy initialized. Waiting for phase transition signals...")

    def on_rollover(self, old_symbol: str, new_symbol: str):
        """Reset per-position and last-signal state on contract rollover."""
        super().on_rollover(old_symbol, new_symbol)
        self.log(f"Rollover from {old_symbol} to {new_symbol}. Resetting position state.")
        self.entry_time = None
        self.position_direction = None
        self.last_trend_strength = 0.0
        self.last_dominant_freq = 0.0