from typing import Any, Optional, Tuple

import numpy as np
import pywt

from src.core_data import Order
from src.strategies.base_strategy import Strategy


# =============================================================================
# Strategy implementation (WaveletSignalNoiseStrategy - wavelet trend/noise ratio)
# =============================================================================
class WaveletSignalNoiseStrategy(Strategy):
    """Wavelet trend-to-noise ratio (TNR) strategy.

    Core ideas:
    1. Rely on the wavelet transform's ability to separate a price window into
       a smooth approximation ("trend") and detail components ("noise").
    2. Use a single factor, the trend-to-noise ratio
       TNR = std(trend) / std(noise), as a measure of trend quality.
    3. Trading logic:
       - Enter when TNR >= ``tnr_entry_threshold`` (trend is clear).
       - Exit when TNR < ``tnr_exit_threshold`` (noise dominates).
    """

    # Number of bars back used to measure the slope of the reconstructed trend.
    _SLOPE_LOOKBACK = 5

    def __init__(
        self,
        context: Any,
        main_symbol: str,
        enable_log: bool,
        trade_volume: int,
        # --- core parameters ---
        bars_per_day: int = 23,
        analysis_window_days: float = 2.0,  # a moderate window length is enough
        wavelet_family: str = 'db4',
        # --- TNR trading thresholds ---
        tnr_entry_threshold: float = 5,  # enter when trend std >= 5x noise std
        tnr_exit_threshold: float = 5,   # exit when trend std < 5x noise std
        # --- position management ---
        max_hold_days: int = 10,
    ):
        """Store parameters and derive the analysis window / DWT level.

        Args:
            context: Engine context, forwarded to ``Strategy.__init__``.
            main_symbol: Symbol whose pending orders are cancelled each bar.
            enable_log: Forwarded to ``Strategy.__init__``.
            trade_volume: Volume of each entry order.
            bars_per_day: Bars per trading day; converts day counts to bars.
            analysis_window_days: Length of the analysis window in days.
            wavelet_family: PyWavelets wavelet name used for the DWT.
            tnr_entry_threshold: Minimum TNR required to open a position.
            tnr_exit_threshold: A TNR below this closes an open position.
                With equal entry/exit thresholds there is no hysteresis band.
            max_hold_days: Stored only; see the review note below.
        """
        super().__init__(context, main_symbol, enable_log)
        self.bars_per_day = bars_per_day
        self.analysis_window_days = analysis_window_days
        self.wavelet = wavelet_family
        self.tnr_entry_threshold = tnr_entry_threshold
        self.tnr_exit_threshold = tnr_exit_threshold
        self.trade_volume = trade_volume
        # NOTE(review): max_hold_days is stored but never enforced in this file;
        # entry_time is recorded on entry but no time-based exit reads it.
        self.max_hold_days = max_hold_days

        self.analysis_window = int(self.analysis_window_days * self.bars_per_day)
        # dwt_max_level returns 0 for windows that are short relative to the
        # filter length. Level 0 leaves no detail coefficients, so the noise
        # component would be identically zero and the TNR always infinite;
        # force at least one decomposition level.
        self.decomposition_level = max(
            1, pywt.dwt_max_level(self.analysis_window, self.wavelet)
        )

        # on_open_bar overwrites this on every bar; seeding it here keeps the
        # order/close helpers usable before the first bar has been processed.
        if not hasattr(self, "symbol"):
            self.symbol = main_symbol

        self.entry_time = None
        self.order_id_counter = 0
        self.log("WaveletSignalNoiseStrategy Initialized.")

    def calculate_trend_noise_ratio(
        self, prices: np.ndarray
    ) -> Tuple[float, Optional[np.ndarray]]:
        """Compute the trend-to-noise ratio (TNR) and the reconstructed trend line.

        The last ``analysis_window`` prices are decomposed with
        ``pywt.wavedec``. The trend is rebuilt from the approximation
        coefficients alone and the noise from the detail coefficients alone.
        Each component's strength is its standard deviation, and
        TNR = strength_trend / strength_noise.

        Args:
            prices: 1-D price series, oldest first (``on_open_bar`` passes bar
                closes).

        Returns:
            ``(tnr_factor, trend_signal)``. ``tnr_factor`` is ``np.inf`` when
            the noise std is below 1e-9. ``(0.0, None)`` is returned when fewer
            than ``analysis_window`` prices are supplied, or when the wavelet
            computation raises (the error is logged).
        """
        if len(prices) < self.analysis_window:
            return 0.0, None

        window_data = prices[-self.analysis_window:]
        n_samples = len(window_data)
        try:
            coeffs = pywt.wavedec(
                window_data, self.wavelet, level=self.decomposition_level
            )

            # 1. Trend (signal): keep the approximation, zero every detail level.
            #    waverec can return more samples than the input (e.g. odd
            #    lengths), so trim back to the window length.
            trend_coeffs = [coeffs[0]] + [np.zeros_like(d) for d in coeffs[1:]]
            trend_signal = pywt.waverec(trend_coeffs, self.wavelet)[:n_samples]

            # 2. Noise: zero the approximation, keep every detail level.
            noise_coeffs = [np.zeros_like(coeffs[0])] + coeffs[1:]
            noise_signal = pywt.waverec(noise_coeffs, self.wavelet)[:n_samples]

            # 3. Strength of each component, measured as standard deviation.
            strength_trend = np.std(trend_signal)
            strength_noise = np.std(noise_signal)

            # 4. Ratio, guarding against division by a (near-)zero noise level.
            if strength_noise < 1e-9:
                tnr_factor = np.inf
            else:
                tnr_factor = strength_trend / strength_noise

            return tnr_factor, trend_signal
        except Exception as e:
            # Best-effort per bar: log and report "no signal" instead of
            # aborting the strategy loop.
            self.log(f"TNR calculation error: {e}", "ERROR")
            return 0.0, None

    def on_open_bar(self, open_price: float, symbol: str):
        """Per-bar driver: cancel pending orders, compute TNR, then enter or exit.

        Args:
            open_price: Open price of the current bar; used as the entry limit.
            symbol: Symbol of the current bar; becomes ``self.symbol``.
        """
        self.symbol = symbol
        bar_history = self.get_bar_history()
        position_volume = self.get_current_positions().get(self.symbol, 0)
        # Presumably clears an unfilled entry limit order from the previous bar
        # before the decision is re-made.
        self.cancel_all_pending_orders(self.main_symbol)

        if len(bar_history) < self.analysis_window:
            return

        closes = np.array([bar.close for bar in bar_history], dtype=float)
        tnr_factor, trend_signal = self.calculate_trend_noise_ratio(closes)
        if trend_signal is None:
            return

        if position_volume == 0:
            self.evaluate_entry_signal(open_price, tnr_factor, trend_signal)
        else:
            self.manage_open_position(position_volume, tnr_factor)

    def evaluate_entry_signal(
        self, open_price: float, tnr_factor: float, trend_signal: np.ndarray
    ):
        """Send an entry limit order when TNR clears the threshold and the trend has a slope.

        Args:
            open_price: Limit price for the entry order.
            tnr_factor: Trend-to-noise ratio from ``calculate_trend_noise_ratio``.
            trend_signal: Reconstructed trend line; its last ``_SLOPE_LOOKBACK``
                samples determine the direction.
        """
        if tnr_factor < self.tnr_entry_threshold:
            return

        lookback = self._SLOPE_LOOKBACK
        # A very short analysis window yields fewer samples than the slope
        # lookback; indexing [-lookback] would raise IndexError.
        if len(trend_signal) < lookback:
            return

        latest = trend_signal[-1]
        earlier = trend_signal[-lookback]
        # NOTE(review): this is contrarian w.r.t. the trend slope: a rising
        # trend line opens a SELL and a falling one opens a BUY, the opposite
        # of a trend-following reading of the strategy description. Kept
        # unchanged; confirm it is intentional.
        direction = None
        if latest > earlier:
            direction = "SELL"
        elif latest < earlier:
            direction = "BUY"

        if direction:
            self.log(f"Entry Signal: {direction} | Trend-Noise Ratio={tnr_factor:.2f}")
            self.entry_time = self.get_current_time()
            self.send_limit_order(direction, open_price, self.trade_volume, "OPEN")

    def manage_open_position(self, volume: int, tnr_factor: float):
        """Close the whole position with a market order once TNR drops below the exit threshold.

        Args:
            volume: Signed position size (positive long, negative short).
            tnr_factor: Current trend-to-noise ratio.
        """
        if tnr_factor < self.tnr_exit_threshold:
            direction_str = "CLOSE_LONG" if volume > 0 else "CLOSE_SHORT"
            self.log(f"Exit Signal: TNR ({tnr_factor:.2f}) < Threshold ({self.tnr_exit_threshold})")
            self.close_position(direction_str, abs(volume))
            self.entry_time = None

    # --- helpers ---

    def close_all_positions(self):
        """Force-close the entire position in ``self.symbol``, if one is open."""
        volume = self.get_current_positions().get(self.symbol, 0)
        if volume == 0:
            return
        # Volume is captured before sending the close so the log reports what
        # was closed even if the positions mapping is updated by the close.
        direction = "CLOSE_LONG" if volume > 0 else "CLOSE_SHORT"
        self.close_position(direction, abs(volume))
        self.log(f"Forced exit of {abs(volume)} contracts")

    def close_position(self, direction: str, volume: int):
        """Close ``volume`` contracts with a market order."""
        self.send_market_order(direction, volume, offset="CLOSE")

    def _next_order_id(self, direction: str, price_type: str) -> str:
        """Return a unique order id tagged with direction and price type, then bump the counter."""
        order_id = f"{self.symbol}_{direction}_{price_type}_{self.order_id_counter}"
        self.order_id_counter += 1
        return order_id

    def send_market_order(self, direction: str, volume: int, offset: str):
        """Build and submit a MARKET order for ``self.symbol``."""
        order = Order(
            id=self._next_order_id(direction, "MARKET"),
            symbol=self.symbol,
            direction=direction,
            volume=volume,
            price_type="MARKET",
            submitted_time=self.get_current_time(),
            offset=offset,
        )
        self.send_order(order)

    def send_limit_order(self, direction: str, limit_price: float, volume: int, offset: str):
        """Build and submit a LIMIT order for ``self.symbol`` at ``limit_price``."""
        # The id tag previously read "MARKET" for limit orders (copy-paste bug).
        order = Order(
            id=self._next_order_id(direction, "LIMIT"),
            symbol=self.symbol,
            direction=direction,
            volume=volume,
            price_type="LIMIT",
            submitted_time=self.get_current_time(),
            offset=offset,
            limit_price=limit_price,
        )
        self.send_order(order)

    def on_init(self):
        """Initialization hook: cancel leftover pending orders for ``main_symbol``."""
        super().on_init()
        self.cancel_all_pending_orders(self.main_symbol)
        self.log("Strategy initialized. Waiting for phase transition signals...")

    def on_rollover(self, old_symbol: str, new_symbol: str):
        """Contract rollover hook: reset per-position state."""
        super().on_rollover(old_symbol, new_symbol)
        self.log(f"Rollover from {old_symbol} to {new_symbol}. Resetting position state.")
        self.entry_time = None
        # NOTE(review): position_direction and last_trend_strength are not read
        # anywhere in this file; kept in case external code relies on them.
        self.position_direction = None
        self.last_trend_strength = 0.0