# NewQuant/futures_trading_strategies/MA/Spectral/SpectralTrendStrategy.py
# (last modified 2025-11-29 16:35:02 +08:00)
from datetime import datetime, timedelta
from typing import Optional, Any, List, Dict, Tuple

import numpy as np
from scipy.signal import stft

from src.core_data import Bar, Order
from src.indicators.base_indicators import Indicator
from src.indicators.indicators import Empty, NormalizedATR, AtrVolatility, ZScoreATR
from src.strategies.base_strategy import Strategy
# =============================================================================
# 策略实现 (SpectralTrendStrategy)
# =============================================================================
class SpectralTrendStrategy(Strategy):
    """
    Spectral energy "phase-transition" strategy aimed at catching fat-tailed trends.

    Core idea:
    1. Explicit Fourier transform: split the recent closing-price window into
       low-frequency ("trend") and high-frequency ("noise") energy with an STFT.
    2. Phase-transition entry: open a position only when the low-frequency
       energy share exceeds ``trend_strength_threshold``.
    3. Low trading frequency: designed for roughly 2-5 signals per month with
       multi-day holds (a design goal, not enforced in code).
    4. Parameterized time structure: band edges are given in days and converted
       with ``bars_per_day``, which is also the STFT sampling rate, so all
       frequencies are in cycles/day.

    Parameters:
    - bars_per_day: number of bars per trading day (default 23).
    - spectral_window_days: STFT window length in days.
    - low_freq_days: components with a period longer than this are "low
      frequency" (f < 1 / low_freq_days).
    - high_freq_days: components with a period shorter than this are "high
      frequency" (f > 1 / high_freq_days). Bins between the two bands are ignored.
    - trend_strength_threshold: minimum trend strength required to enter.
    - exit_threshold: an open position is closed once trend strength falls below this.
    - max_hold_days: calendar-day limit (``timedelta(days=...)``) after which the
      position is force-closed.
    - order_direction: subset of ['BUY', 'SELL'] whose raw price-drift signals are
      considered for entry.
    - indicators: gate; an entry is only sent when its ``is_condition_met`` is true.
    - model_indicator: the raw signal direction is kept when its
      ``is_condition_met`` is true and flipped otherwise.
    - reverse: flip the final entry direction.

    NOTE(review): with the defaults, trend_strength_threshold (0.1) is below
    exit_threshold (0.4), so a position opened with strength in (0.1, 0.4) is
    closed on the next bar if strength stays in that range. Confirm intended.
    """

    def __init__(
        self,
        context: Any,
        main_symbol: str,
        enable_log: bool,
        trade_volume: int,
        # --- Market structure ---
        bars_per_day: int = 23,  # bars per day; also used as the STFT sampling rate
        # --- Spectral core ---
        spectral_window_days: float = 2.0,  # STFT window length (days)
        low_freq_days: float = 2.0,  # periods longer than this count as low frequency
        high_freq_days: float = 1.0,  # periods shorter than this count as high frequency
        trend_strength_threshold: float = 0.1,  # entry threshold on the low-frequency energy share
        exit_threshold: float = 0.4,  # exit when strength drops below this
        # --- Position management ---
        max_hold_days: int = 10,  # maximum holding period (calendar days)
        # --- Other ---
        order_direction: Optional[List[str]] = None,
        indicators: Optional[Indicator] = None,
        model_indicator: Optional[Indicator] = None,
        reverse: bool = False,
    ):
        super().__init__(context, main_symbol, enable_log)
        if order_direction is None:
            order_direction = ['BUY', 'SELL']
        if indicators is None:
            indicators = Empty()  # default placeholder indicator, kept for compatibility

        # --- Parameters ---
        self.trade_volume = trade_volume
        self.bars_per_day = bars_per_day
        self.spectral_window_days = spectral_window_days
        self.low_freq_days = low_freq_days
        self.high_freq_days = high_freq_days
        self.trend_strength_threshold = trend_strength_threshold
        self.exit_threshold = exit_threshold
        self.max_hold_days = max_hold_days
        self.order_direction = order_direction
        if model_indicator is None:
            model_indicator = Empty()
        self.model_indicator = model_indicator

        # --- Derived parameters ---
        window = int(self.spectral_window_days * self.bars_per_day)
        # Force an even window length. scipy's stft does not strictly require
        # this, but it is kept so window sizes match earlier runs.
        self.spectral_window = window if window % 2 == 0 else window + 1
        # Band edges in cycles/day.
        self.low_freq_bound = 1.0 / self.low_freq_days if self.low_freq_days > 0 else float('inf')
        self.high_freq_bound = 1.0 / self.high_freq_days if self.high_freq_days > 0 else 0.0

        # --- Internal state ---
        self.main_symbol = main_symbol
        self.order_id_counter = 0
        self.indicators = indicators
        self.entry_time = None  # time the entry order was sent; None when flat
        self.position_direction = None  # 'LONG' or 'SHORT'
        self.last_trend_strength = 0.0
        self.last_dominant_freq = 0.0  # despite the name, a period in days
        self.reverse = reverse
        self.log(f"SpectralTrendStrategy Initialized (bars/day={bars_per_day}, window={self.spectral_window} bars)")

    def on_open_bar(self, open_price: float, symbol: str) -> None:
        """
        Run the strategy once per bar, at the bar's open.

        Steps:
        - Cancel pending orders for ``main_symbol``.
        - Compute the spectral trend strength over the last ``spectral_window`` closes.
        - Enforce the maximum holding period.
        - Evaluate an entry (when flat) or manage the open position.

        Entry/exit orders are only sent while ``self.trading`` is true. The
        max-hold forced exit does not check that flag.
        """
        self.symbol = symbol
        bar_history = self.get_bar_history()
        current_time = self.get_current_time()
        self.cancel_all_pending_orders(self.main_symbol)

        # Need the STFT window plus a small buffer of history.
        if len(bar_history) < self.spectral_window + 10:
            if self.enable_log and len(bar_history) % 50 == 0:
                self.log(f"Waiting for {len(bar_history)}/{self.spectral_window + 10} bars")
            return

        position_volume = self.get_current_positions().get(self.symbol, 0)
        if position_volume == 0 and self.entry_time is not None:
            # Flat, yet an entry is still recorded: e.g. the entry limit order never
            # filled and was just cancelled above. Clear the stale state so the
            # max-hold timer below does not run for a position that does not exist.
            self.entry_time = None
            self.position_direction = None

        closes = np.array([b.close for b in bar_history[-self.spectral_window:]], dtype=float)
        trend_strength, dominant_freq = self.calculate_trend_strength(closes)
        self.last_trend_strength = trend_strength
        self.last_dominant_freq = dominant_freq

        # Safety valve: force-close after max_hold_days calendar days.
        if self.entry_time and (current_time - self.entry_time) >= timedelta(days=self.max_hold_days):
            self.log(f"Max hold time reached ({self.max_hold_days} days). Forcing exit.")
            self.close_all_positions()
            self.entry_time = None
            self.position_direction = None
            return

        if self.trading:
            if position_volume == 0:
                self.evaluate_entry_signal(open_price, trend_strength, dominant_freq)
            else:
                self.manage_open_position(position_volume, trend_strength, dominant_freq)

    def calculate_trend_strength(self, prices: np.ndarray) -> Tuple[float, float]:
        """
        Compute the low-frequency energy share of the latest price window.

        Steps:
        1. Z-score normalize the prices.
        2. Run an STFT with ``fs=bars_per_day``, so frequencies are in cycles/day.
        3. Split the latest segment's power spectrum into a low band
           (f < 1/low_freq_days) and a high band (f > 1/high_freq_days).
        4. Trend strength = low energy / (low energy + high energy).

        Args:
            prices: closing prices, oldest first. Needs at least
                ``spectral_window`` values.

        Returns:
            ``(trend_strength, dominant_period_days)``.
            - ``trend_strength`` is in [0, 1).
            - ``dominant_period_days`` is the period (days) of the strongest
              low-band bin, or 0.0 when the low band carries no energy.
            - Both are 0.0 when ``prices`` is too short or the STFT fails.

        NOTE(review): the STFT bin spacing is bars_per_day / spectral_window,
        i.e. about 1/spectral_window_days cycles/day. With the defaults (2-day
        window, low_freq_days=2.0), the low band therefore contains only the
        0-frequency (DC) bin of the Hann-windowed, mean-removed window, and the
        period reported for that bin is ~1e8 days (1 / (0 + 1e-8)). Resolving
        genuine multi-day cycles needs a window longer than low_freq_days.
        Confirm this is intended.
        """
        if len(prices) < self.spectral_window:
            return 0.0, 0.0

        # Z-score over up to the last 10 windows of `prices` (on_open_bar passes
        # exactly one window, so in practice this is the window itself), then keep
        # only the last window for the transform.
        history = prices[-self.spectral_window * 10:]
        normalized = (history - np.mean(history)) / (np.std(history) + 1e-8)
        normalized = normalized[-self.spectral_window:]

        try:
            # Input length equals nperseg and boundary=None / padded=False,
            # so the STFT has exactly one segment.
            f, _, Zxx = stft(
                normalized,
                fs=self.bars_per_day,
                nperseg=self.spectral_window,
                noverlap=max(0, self.spectral_window // 2),
                boundary=None,
                padded=False,
            )
        except Exception as e:
            self.log(f"STFT calculation error: {str(e)}")
            return 0.0, 0.0

        # Defensive: keep only bins inside [0, Nyquist].
        valid_mask = (f >= 0) & (f <= self.bars_per_day / 2)
        f = f[valid_mask]
        Zxx = Zxx[valid_mask, :]
        if Zxx.size == 0 or Zxx.shape[1] == 0:
            return 0.0, 0.0

        # Power spectrum of the most recent segment.
        current_energy = np.abs(Zxx[:, -1]) ** 2

        # Low band: period > low_freq_days  <=>  f < 1/low_freq_days.
        low_freq_mask = f < self.low_freq_bound
        # High band: period < high_freq_days  <=>  f > 1/high_freq_days.
        high_freq_mask = f > self.high_freq_bound

        low_energy = np.sum(current_energy[low_freq_mask]) if np.any(low_freq_mask) else 0.0
        high_energy = np.sum(current_energy[high_freq_mask]) if np.any(high_freq_mask) else 0.0
        total_energy = low_energy + high_energy + 1e-8  # avoid division by zero

        trend_strength = low_energy / total_energy

        # Period (days) of the strongest low-band bin ("freq" in the name is historical).
        dominant_freq = 0.0
        if np.any(low_freq_mask) and low_energy > 0:
            low_freqs = f[low_freq_mask]
            strongest = np.argmax(current_energy[low_freq_mask])
            dominant_freq = 1.0 / (low_freqs[strongest] + 1e-8)

        return trend_strength, dominant_freq

    def evaluate_entry_signal(self, open_price: float, trend_strength: float, dominant_freq: float) -> None:
        """
        Send an entry limit order at ``open_price`` when trend strength exceeds
        ``trend_strength_threshold``.

        Direction is decided in four steps:
        - The raw signal compares the mean of the last 5 closes with the mean of
          the whole spectral window. Only sides listed in ``order_direction`` are
          considered.
        - ``model_indicator`` keeps that direction when its condition is met and
          flips it otherwise.
        - ``indicators`` must be met for any order to be sent.
        - ``reverse`` flips the result.

        ``dominant_freq`` is currently unused.

        NOTE(review): after the model_indicator flip, the final direction may be
        one that is not listed in ``order_direction``. Confirm this is intended.
        """
        self.log(
            f"Strength={trend_strength:.2f}")
        # Written as `not (x > t)` so a NaN strength never triggers an entry.
        if not (trend_strength > self.trend_strength_threshold):
            return

        indicator = self.model_indicator
        closes = np.array([b.close for b in self.get_bar_history()[-self.spectral_window:]], dtype=float)
        recent_mean = np.mean(closes[-5:])
        window_mean = np.mean(closes)

        direction = None
        # Long signal: recent prices above the window mean.
        if "BUY" in self.order_direction and recent_mean > window_mean:
            direction = "BUY" if indicator.is_condition_met(*self.get_indicator_tuple()) else "SELL"
        # Short signal: recent prices below the window mean.
        elif "SELL" in self.order_direction and recent_mean < window_mean:
            direction = "SELL" if indicator.is_condition_met(*self.get_indicator_tuple()) else "BUY"

        if direction and self.indicators.is_condition_met(*self.get_indicator_tuple()):
            if self.reverse:
                direction = "SELL" if direction == "BUY" else "BUY"
            self.log(f"Direction={direction}, Open Position")
            self.send_limit_order(direction, open_price, self.trade_volume, "OPEN")
            self.entry_time = self.get_current_time()
            self.position_direction = "LONG" if direction == "BUY" else "SHORT"

    def manage_open_position(self, volume: int, trend_strength: float, dominant_freq: float) -> None:
        """
        Close the open position with a market order once trend strength falls
        below ``exit_threshold``.

        ``volume`` is the signed position size (positive means long).
        ``dominant_freq`` is currently unused.
        """
        if trend_strength < self.exit_threshold:
            direction = "CLOSE_LONG" if volume > 0 else "CLOSE_SHORT"
            self.log(f"Phase Transition Exit: {direction} | Strength={trend_strength:.2f} < {self.exit_threshold}")
            self.close_position(direction, abs(volume))
            self.entry_time = None
            self.position_direction = None

    # --- Helpers ---
    def close_all_positions(self) -> None:
        """Force-close any non-zero position in ``self.symbol`` with a market order."""
        positions = self.get_current_positions()
        if self.symbol in positions and positions[self.symbol] != 0:
            direction = "CLOSE_LONG" if positions[self.symbol] > 0 else "CLOSE_SHORT"
            self.close_position(direction, abs(positions[self.symbol]))
            self.log(f"Forced exit of {abs(positions[self.symbol])} contracts")

    def close_position(self, direction: str, volume: int) -> None:
        """Close ``volume`` contracts with a market order (``direction`` is CLOSE_LONG/CLOSE_SHORT)."""
        self.send_market_order(direction, volume, offset="CLOSE")

    def send_market_order(self, direction: str, volume: int, offset: str) -> None:
        """Build a MARKET ``Order`` with a unique id and submit it via ``send_order``."""
        order_id = f"{self.symbol}_{direction}_MARKET_{self.order_id_counter}"
        self.order_id_counter += 1
        order = Order(
            id=order_id,
            symbol=self.symbol,
            direction=direction,
            volume=volume,
            price_type="MARKET",
            submitted_time=self.get_current_time(),
            offset=offset
        )
        self.send_order(order)

    def send_limit_order(self, direction: str, limit_price: float, volume: int, offset: str) -> None:
        """Build a LIMIT ``Order`` at ``limit_price`` with a unique id and submit it via ``send_order``."""
        # Tag the id with the actual order type (it was mislabelled "_MARKET_").
        order_id = f"{self.symbol}_{direction}_LIMIT_{self.order_id_counter}"
        self.order_id_counter += 1
        order = Order(
            id=order_id,
            symbol=self.symbol,
            direction=direction,
            volume=volume,
            price_type="LIMIT",
            submitted_time=self.get_current_time(),
            offset=offset,
            limit_price=limit_price
        )
        self.send_order(order)

    def on_init(self) -> None:
        """Initialize the base strategy and cancel any pending orders on the main symbol."""
        super().on_init()
        self.cancel_all_pending_orders(self.main_symbol)
        self.log("Strategy initialized. Waiting for phase transition signals...")

    def on_rollover(self, old_symbol: str, new_symbol: str) -> None:
        """Handle a contract rollover by resetting the tracked position and signal diagnostics."""
        super().on_rollover(old_symbol, new_symbol)
        self.log(f"Rollover from {old_symbol} to {new_symbol}. Resetting position state.")
        self.entry_time = None
        self.position_direction = None
        self.last_trend_strength = 0.0
        self.last_dominant_freq = 0.0