200 lines
7.7 KiB
Python
200 lines
7.7 KiB
Python
from typing import Any
|
||
|
||
import numpy as np
|
||
import pywt
|
||
|
||
from src.core_data import Order
|
||
from src.strategies.base_strategy import Strategy
|
||
|
||
|
||
# =============================================================================
|
||
# 策略实现 (WaveletSignalNoiseStrategy - 小波信噪比策略)
|
||
# =============================================================================
|
||
|
||
class WaveletSignalNoiseStrategy(Strategy):
    """Wavelet signal-to-noise strategy (final version).

    Core ideas:
    1. Trust the wavelet: decisions rely only on the discrete wavelet
       transform's split of a price window into a smooth approximation
       ("signal") and detail components ("noise").
    2. One factor: the trend-to-noise ratio (TNR), i.e. the standard deviation
       of the reconstructed approximation divided by the standard deviation of
       the reconstructed details.
    3. Simple rules:
       - enter when TNR >= ``tnr_entry_threshold`` (trend is clear);
       - exit when TNR < ``tnr_exit_threshold`` (noise dominates).
    """

    def __init__(
        self,
        context: Any,
        main_symbol: str,
        enable_log: bool,
        trade_volume: int,
        # --- core parameters ---
        bars_per_day: int = 23,
        analysis_window_days: float = 2.0,  # a moderate window is sufficient
        wavelet_family: str = 'db4',
        # --- TNR trading thresholds ---
        tnr_entry_threshold: float = 5,  # open only when TNR >= this value
        tnr_exit_threshold: float = 5,  # close once TNR falls below this value
        # --- position management ---
        max_hold_days: int = 10,
        trend_lookback: int = 5,
    ):
        """Store parameters and derive the analysis window / DWT level.

        Args:
            context: framework context, forwarded to ``Strategy``.
            main_symbol: symbol whose pending orders are cancelled every bar;
                also the default for ``self.symbol`` until a bar arrives.
            enable_log: forwarded to ``Strategy``.
            trade_volume: contracts per entry order.
            bars_per_day: bars in one trading day.
            analysis_window_days: window length in days; the window in bars
                is ``int(analysis_window_days * bars_per_day)``.
            wavelet_family: PyWavelets wavelet name (e.g. ``'db4'``).
            tnr_entry_threshold: minimum TNR required to open a position.
            tnr_exit_threshold: a position is closed when TNR drops below it.
                With equal defaults there is no hysteresis between the two.
            max_hold_days: stored only; not enforced by this class.
            trend_lookback: the entry direction compares ``trend[-1]`` with
                ``trend[-trend_lookback]`` (default 5 = previous behaviour).

        Raises:
            ValueError: if ``trend_lookback`` < 2.
        """
        super().__init__(context, main_symbol, enable_log)
        if trend_lookback < 2:
            raise ValueError(f"trend_lookback must be >= 2, got {trend_lookback}")

        self.bars_per_day = bars_per_day
        self.analysis_window_days = analysis_window_days
        self.wavelet = wavelet_family
        self.tnr_entry_threshold = tnr_entry_threshold
        self.tnr_exit_threshold = tnr_exit_threshold
        self.trade_volume = trade_volume
        # NOTE(review): not read anywhere in this class; no time-based exit
        # is implemented even though entry_time is tracked.
        self.max_hold_days = max_hold_days
        self.trend_lookback = trend_lookback

        self.analysis_window = int(self.analysis_window_days * self.bars_per_day)
        # Deepest decomposition the window length supports for this wavelet.
        self.decomposition_level = pywt.dwt_max_level(self.analysis_window, self.wavelet)

        self.entry_time = None
        self.order_id_counter = 0
        # The order helpers read self.symbol; give it a value before the first
        # on_open_bar call so they do not raise AttributeError.
        if not hasattr(self, "symbol"):
            self.symbol = main_symbol

        if self.decomposition_level < 1:
            self.log(
                f"Analysis window of {self.analysis_window} bars is too short "
                f"for wavelet '{self.wavelet}'; TNR will never be computed."
            )
        self.log("WaveletSignalNoiseStrategy Initialized.")

    def calculate_trend_noise_ratio(self, prices: np.ndarray) -> "tuple[float, np.ndarray | None]":
        """Compute the trend-to-noise ratio (TNR) and the intrinsic trend line.

        The last ``self.analysis_window`` prices are decomposed with
        ``pywt.wavedec``. Reconstructing only the approximation coefficients
        gives the trend; reconstructing only the detail coefficients gives the
        noise. TNR is ``std(trend) / std(noise)``.

        Args:
            prices: 1-D price series, oldest first (the caller passes closes).

        Returns:
            ``(tnr_factor, trend_signal)`` where ``trend_signal`` has the
            window's length and ``tnr_factor`` is ``np.inf`` when the noise
            component is (near) zero. ``(0.0, None)`` is returned when there is
            too little data, the window cannot be decomposed (level 0), the
            window contains NaN/inf, or the transform raises.
        """
        if len(prices) < self.analysis_window:
            return 0.0, None
        # With level 0 there are no detail coefficients, so "noise" would be
        # identically zero and TNR always inf -- a spurious permanent entry.
        if self.decomposition_level < 1:
            return 0.0, None

        window_data = np.asarray(prices[-self.analysis_window:], dtype=float)
        # NaN would propagate into TNR, and NaN slips through every `<` check.
        if not np.all(np.isfinite(window_data)):
            self.log("TNR calculation skipped: non-finite prices in window")
            return 0.0, None
        n = len(window_data)

        try:
            coeffs = pywt.wavedec(window_data, self.wavelet, level=self.decomposition_level)
            approx, details = coeffs[0], list(coeffs[1:])
            # 1. Trend: keep the approximation, zero every detail level.
            #    waverec can return one extra sample, hence the truncation.
            trend_signal = pywt.waverec(
                [approx] + [np.zeros_like(d) for d in details], self.wavelet
            )[:n]
            # 2. Noise: zero the approximation, keep every detail level.
            noise_signal = pywt.waverec(
                [np.zeros_like(approx)] + details, self.wavelet
            )[:n]
        except Exception as e:  # best-effort: a bad bar must not kill the strategy
            self.log(f"TNR calculation error: {e}", "ERROR")
            return 0.0, None

        # 3. Strength of each component = its standard deviation.
        strength_trend = np.std(trend_signal)
        strength_noise = np.std(noise_signal)

        # 4. Ratio, guarding against division by (near) zero.
        if strength_noise < 1e-9:
            tnr_factor = np.inf
        else:
            tnr_factor = strength_trend / strength_noise
        return tnr_factor, trend_signal

    def on_open_bar(self, open_price: float, symbol: str):
        """Per-bar entry point: compute TNR on the close history and act on it.

        Pending orders for ``self.main_symbol`` are cancelled first. When flat,
        an entry is evaluated; otherwise the open position is managed.
        """
        self.symbol = symbol
        bar_history = self.get_bar_history()
        position_volume = self.get_current_positions().get(self.symbol, 0)

        # NOTE(review): cancels on main_symbol while positions are read for
        # `symbol`; presumably identical outside rollover -- confirm.
        self.cancel_all_pending_orders(self.main_symbol)

        if len(bar_history) < self.analysis_window:
            return

        closes = np.array([bar.close for bar in bar_history], dtype=float)
        tnr_factor, trend_signal = self.calculate_trend_noise_ratio(closes)
        if trend_signal is None:
            return

        if position_volume == 0:
            self.evaluate_entry_signal(open_price, tnr_factor, trend_signal)
        else:
            self.manage_open_position(position_volume, tnr_factor)

    def evaluate_entry_signal(self, open_price: float, tnr_factor: float, trend_signal: np.ndarray):
        """Entry: TNR clears the entry threshold and the trend has a slope.

        Direction comes from comparing ``trend_signal[-1]`` with
        ``trend_signal[-self.trend_lookback]``; a flat trend sends no order.
        The order is a limit order at ``open_price``.
        """
        if tnr_factor < self.tnr_entry_threshold:
            return

        lookback = self.trend_lookback
        # Previously the [-5] index was used with no length check.
        if len(trend_signal) < lookback:
            return

        latest, earlier = trend_signal[-1], trend_signal[-lookback]
        # NOTE(review): this fades the trend -- a rising trend line opens a
        # SELL and a falling one opens a BUY. The class docstring describes
        # entering on a clear trend, which usually means following it; confirm
        # the contrarian mapping is intentional.
        if latest > earlier:
            direction = "SELL"
        elif latest < earlier:
            direction = "BUY"
        else:
            return

        self.log(f"Entry Signal: {direction} | Trend-Noise Ratio={tnr_factor:.2f}")
        self.entry_time = self.get_current_time()
        self.send_limit_order(direction, open_price, self.trade_volume, "OPEN")

    def manage_open_position(self, volume: int, tnr_factor: float):
        """Exit: close the whole position once TNR drops below the exit threshold.

        Args:
            volume: signed position size; positive closes as CLOSE_LONG,
                anything else as CLOSE_SHORT.
            tnr_factor: current trend-to-noise ratio.
        """
        if tnr_factor >= self.tnr_exit_threshold:
            return
        direction_str = "CLOSE_LONG" if volume > 0 else "CLOSE_SHORT"
        self.log(f"Exit Signal: TNR ({tnr_factor:.2f}) < Threshold ({self.tnr_exit_threshold})")
        self.close_position(direction_str, abs(volume))
        self.entry_time = None

    # --- helpers ---

    def close_all_positions(self):
        """Force a market close of any non-zero position in ``self.symbol``."""
        # Read the size once so the log reports what was actually closed, even
        # if the positions mapping is updated by the close order.
        volume = self.get_current_positions().get(self.symbol, 0)
        if volume == 0:
            return
        direction = "CLOSE_LONG" if volume > 0 else "CLOSE_SHORT"
        self.close_position(direction, abs(volume))
        self.entry_time = None
        self.log(f"Forced exit of {abs(volume)} contracts")

    def close_position(self, direction: str, volume: int):
        """Close ``volume`` contracts with a market order."""
        self.send_market_order(direction, volume, offset="CLOSE")

    def _next_order_id(self, direction: str, price_type: str) -> str:
        """Return a unique order id tagged with the price type; bump the counter."""
        order_id = f"{self.symbol}_{direction}_{price_type}_{self.order_id_counter}"
        self.order_id_counter += 1
        return order_id

    def send_market_order(self, direction: str, volume: int, offset: str):
        """Submit a MARKET order for ``self.symbol``."""
        order = Order(
            id=self._next_order_id(direction, "MARKET"),
            symbol=self.symbol,
            direction=direction,
            volume=volume,
            price_type="MARKET",
            submitted_time=self.get_current_time(),
            offset=offset,
        )
        self.send_order(order)

    def send_limit_order(self, direction: str, limit_price: float, volume: int, offset: str):
        """Submit a LIMIT order for ``self.symbol`` at ``limit_price``."""
        order = Order(
            # Previously mislabelled "_MARKET_" (copy-paste from the market helper).
            id=self._next_order_id(direction, "LIMIT"),
            symbol=self.symbol,
            direction=direction,
            volume=volume,
            price_type="LIMIT",
            submitted_time=self.get_current_time(),
            offset=offset,
            limit_price=limit_price,
        )
        self.send_order(order)

    def on_init(self):
        """Framework init hook: clear stale orders for the main symbol."""
        super().on_init()
        self.cancel_all_pending_orders(self.main_symbol)
        self.log("Strategy initialized. Waiting for phase transition signals...")

    def on_rollover(self, old_symbol: str, new_symbol: str):
        """Framework rollover hook: reset per-position bookkeeping."""
        super().on_rollover(old_symbol, new_symbol)
        self.log(f"Rollover from {old_symbol} to {new_symbol}. Resetting position state.")
        self.entry_time = None
        # NOTE(review): the two attributes below are not read by this class;
        # kept in case the base class or a subclass relies on them.
        self.position_direction = None
        self.last_trend_strength = 0.0