# File: NewQuant/futures_trading_strategies/rb/Spectral/SpectralTrendStrategy2.py
# (200 lines, 7.7 KiB, Python; last modified 2025-11-29 16:35:02 +08:00)
from typing import Any
import numpy as np
import pywt
from src.core_data import Order
from src.strategies.base_strategy import Strategy
# =============================================================================
# Strategy implementation (WaveletSignalNoiseStrategy - wavelet trend-to-noise ratio strategy)
# =============================================================================
class WaveletSignalNoiseStrategy(Strategy):
    """
    Wavelet trend-to-noise ratio strategy.

    Core idea:
      1. Use the wavelet transform to split a recent price window into a
         smooth approximation component ("trend") and the detail components
         ("noise").
      2. Summarise trend quality with a single factor, the trend-to-noise
         ratio (TNR) = std(trend reconstruction) / std(noise reconstruction).
      3. Trading logic:
         - when flat, enter once the TNR is at or above ``tnr_entry_threshold``;
         - when in a position, exit at market once the TNR falls below
           ``tnr_exit_threshold``.

    NOTE(review): with the default thresholds (entry == exit == 5) there is no
    hysteresis band between entering and exiting; consider exit < entry.
    NOTE(review): ``max_hold_days`` is stored and ``entry_time`` is recorded,
    but nothing in this class uses them to force an exit.
    """

    # Distance (in bars) between the two trend-line points compared to get the
    # slope direction: trend_signal[-1] vs trend_signal[-(4 + 1)].
    _SLOPE_LOOKBACK = 4

    def __init__(
        self,
        context: Any,
        main_symbol: str,
        enable_log: bool,
        trade_volume: int,
        # --- Core parameters ---
        bars_per_day: int = 23,
        analysis_window_days: float = 2.0,  # a moderate window length is enough
        wavelet_family: str = 'db4',
        # --- TNR trading thresholds ---
        tnr_entry_threshold: float = 5,  # enter when trend strength >= 5x noise strength
        tnr_exit_threshold: float = 5,  # exit when trend strength drops below 5x noise strength
        # --- Position management ---
        max_hold_days: int = 10,  # stored only; not used by any exit rule (see class NOTE)
    ):
        super().__init__(context, main_symbol, enable_log)
        self.bars_per_day = bars_per_day
        self.analysis_window_days = analysis_window_days
        self.wavelet = wavelet_family
        self.tnr_entry_threshold = tnr_entry_threshold
        self.tnr_exit_threshold = tnr_exit_threshold
        self.trade_volume = trade_volume
        self.max_hold_days = max_hold_days
        # Window length in bars; int() truncates toward zero.
        self.analysis_window = int(self.analysis_window_days * self.bars_per_day)
        # Deepest decomposition level pywt allows for this window length and
        # wavelet filter length.
        self.decomposition_level = pywt.dwt_max_level(self.analysis_window, self.wavelet)
        self.entry_time = None
        self.order_id_counter = 0
        self.log("WaveletSignalNoiseStrategy Initialized.")
        if self.decomposition_level < 1:
            # At level 0 there are no detail coefficients, so the "noise"
            # component is expected to vanish (TNR -> inf) or the transform
            # to fail; either way the factor is meaningless.
            self.log(
                f"Analysis window ({self.analysis_window} bars) is too short for "
                f"wavelet '{self.wavelet}'; TNR will be degenerate."
            )

    def calculate_trend_noise_ratio(
        self, prices: np.ndarray
    ) -> "tuple[float, np.ndarray | None]":
        """
        Compute the trend-to-noise ratio (TNR) and the intrinsic trend line.

        Only the last ``self.analysis_window`` values of ``prices`` are used.
        The window is decomposed with ``pywt.wavedec``. The trend is then
        rebuilt from the approximation coefficients alone, and the noise from
        the detail coefficients alone. The strength of each component is its
        standard deviation.

        Args:
            prices: 1-D price array, oldest first (the caller passes bar closes).

        Returns:
            ``(tnr_factor, trend_signal)``.
            - ``tnr_factor`` is ``np.inf`` when the noise strength is below 1e-9.
            - ``trend_signal`` has the same length as the analysis window.
            - ``(0.0, None)`` is returned when the window is empty/non-positive,
              when there is not enough data, or when the transform raises
              (the error is logged).
        """
        # A non-positive window would make prices[-0:] select the whole history.
        if self.analysis_window <= 0 or len(prices) < self.analysis_window:
            return 0.0, None

        window_data = prices[-self.analysis_window:]
        window_len = len(window_data)
        try:
            coeffs = pywt.wavedec(window_data, self.wavelet, level=self.decomposition_level)

            # 1. Trend: keep the approximation, zero every detail level.
            trend_coeffs = [coeffs[0]] + [np.zeros_like(d) for d in coeffs[1:]]
            # waverec may return a slightly longer array than the input;
            # trim back to the window length.
            trend_signal = pywt.waverec(trend_coeffs, self.wavelet)[:window_len]

            # 2. Noise: zero the approximation, keep every detail level.
            noise_coeffs = [np.zeros_like(coeffs[0])] + coeffs[1:]
            noise_signal = pywt.waverec(noise_coeffs, self.wavelet)[:window_len]

            # 3. Strength of each component = its standard deviation.
            strength_trend = np.std(trend_signal)
            strength_noise = np.std(noise_signal)

            # 4. Ratio, guarding against division by (near) zero.
            if strength_noise < 1e-9:
                tnr_factor = np.inf
            else:
                tnr_factor = strength_trend / strength_noise
            return tnr_factor, trend_signal
        except Exception as e:
            # Best-effort: a failed transform skips this bar instead of
            # stopping the strategy.
            self.log(f"TNR calculation error: {e}", "ERROR")
            return 0.0, None

    def on_open_bar(self, open_price: float, symbol: str):
        """
        Per-bar entry point.

        Refresh the TNR from the bar-history closes. When flat, look for an
        entry; otherwise, manage the open position.

        Args:
            open_price: opening price of the current bar; used as the limit
                price for entries.
            symbol: symbol being traded; stored on ``self.symbol`` and used
                for position lookup and order ids.
        """
        self.symbol = symbol
        bar_history = self.get_bar_history()
        position_volume = self.get_current_positions().get(self.symbol, 0)
        # Cancel outstanding orders (presumably last bar's unfilled limit
        # entry) before re-evaluating.
        self.cancel_all_pending_orders(self.main_symbol)

        if len(bar_history) < self.analysis_window:
            return

        closes = np.array([b.close for b in bar_history], dtype=float)
        tnr_factor, trend_signal = self.calculate_trend_noise_ratio(closes)
        if trend_signal is None:
            return

        if position_volume == 0:
            self.evaluate_entry_signal(open_price, tnr_factor, trend_signal)
        else:
            self.manage_open_position(position_volume, tnr_factor)

    def evaluate_entry_signal(self, open_price: float, tnr_factor: float, trend_signal: np.ndarray):
        """
        Entry logic.

        The TNR must reach the entry threshold. The slope of the trend line
        over the last ``_SLOPE_LOOKBACK`` bars then picks the order side, and
        a limit order is sent at ``open_price``.

        NOTE(review): the side is contrarian to the trend line. A rising
        trend line sends "SELL" and a falling one sends "BUY". This conflicts
        with the "trend" framing of the class; confirm it is intentional.
        """
        if tnr_factor < self.tnr_entry_threshold:
            return

        lookback = self._SLOPE_LOOKBACK
        # Need lookback + 1 points to compare the latest value with the
        # value `lookback` bars earlier.
        if len(trend_signal) <= lookback:
            return

        latest = trend_signal[-1]
        reference = trend_signal[-(lookback + 1)]
        if latest > reference:
            direction = "SELL"
        elif latest < reference:
            direction = "BUY"
        else:
            return  # flat (or NaN) slope: no directional signal

        self.log(f"Entry Signal: {direction} | Trend-Noise Ratio={tnr_factor:.2f}")
        self.entry_time = self.get_current_time()
        self.send_limit_order(direction, open_price, self.trade_volume, "OPEN")

    def manage_open_position(self, volume: int, tnr_factor: float):
        """
        Exit logic: once the TNR falls below the exit threshold, close the
        whole position with a market order.

        Args:
            volume: signed position size (> 0 long, otherwise short).
            tnr_factor: latest trend-to-noise ratio.
        """
        if tnr_factor < self.tnr_exit_threshold:
            direction_str = "CLOSE_LONG" if volume > 0 else "CLOSE_SHORT"
            self.log(f"Exit Signal: TNR ({tnr_factor:.2f}) < Threshold ({self.tnr_exit_threshold})")
            self.close_position(direction_str, abs(volume))
            self.entry_time = None

    # --- Helpers ---

    def close_all_positions(self):
        """
        Force-close any open position in ``self.symbol`` at market.

        NOTE(review): ``self.symbol`` is assigned in ``on_open_bar``. Confirm
        that the base class sets it too, otherwise calling this before the
        first bar raises AttributeError.
        """
        positions = self.get_current_positions()
        if self.symbol in positions and positions[self.symbol] != 0:
            volume = positions[self.symbol]
            direction = "CLOSE_LONG" if volume > 0 else "CLOSE_SHORT"
            self.close_position(direction, abs(volume))
            self.log(f"Forced exit of {abs(volume)} contracts")

    def close_position(self, direction: str, volume: int):
        """Close ``volume`` contracts with a market order on the CLOSE offset."""
        self.send_market_order(direction, volume, offset="CLOSE")

    def _next_order_id(self, direction: str, price_type: str) -> str:
        """Return a unique order id tagged with the order's price type, and advance the counter."""
        order_id = f"{self.symbol}_{direction}_{price_type}_{self.order_id_counter}"
        self.order_id_counter += 1
        return order_id

    def send_market_order(self, direction: str, volume: int, offset: str):
        """Submit a MARKET order for ``self.symbol``."""
        order = Order(
            id=self._next_order_id(direction, "MARKET"),
            symbol=self.symbol,
            direction=direction,
            volume=volume,
            price_type="MARKET",
            submitted_time=self.get_current_time(),
            offset=offset
        )
        self.send_order(order)

    def send_limit_order(self, direction: str, limit_price: float, volume: int, offset: str):
        """Submit a LIMIT order for ``self.symbol`` at ``limit_price``."""
        # Previously these ids were mislabelled "_MARKET_".
        order = Order(
            id=self._next_order_id(direction, "LIMIT"),
            symbol=self.symbol,
            direction=direction,
            volume=volume,
            price_type="LIMIT",
            submitted_time=self.get_current_time(),
            offset=offset,
            limit_price=limit_price
        )
        self.send_order(order)

    def on_init(self):
        """Run base initialisation, then cancel any pending orders on the main symbol."""
        super().on_init()
        self.cancel_all_pending_orders(self.main_symbol)
        self.log("Strategy initialized. Waiting for phase transition signals...")

    def on_rollover(self, old_symbol: str, new_symbol: str):
        """Reset per-position state after the base class handles a contract rollover."""
        super().on_rollover(old_symbol, new_symbol)
        self.log(f"Rollover from {old_symbol} to {new_symbol}. Resetting position state.")
        self.entry_time = None
        # NOTE(review): the two attributes below are not read anywhere in this
        # class (apparently left over from an earlier version); kept for
        # compatibility.
        self.position_direction = None
        self.last_trend_strength = 0.0