# Source listing metadata: 377 lines, 16 KiB, Python
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
from scipy.signal import stft

from src.core_data import Bar, Order
from src.indicators.base_indicators import Indicator
from src.indicators.indicators import Empty
from src.strategies.base_strategy import Strategy


# =============================================================================
# 策略实现 / Strategy implementation: SpectralTrendStrategy
# (its startup log line refers to it as "VolatilityAdaptiveSpectralStrategy")
# =============================================================================


class SpectralTrendStrategy(Strategy):
    """Volatility-adaptive frequency-domain trend strategy.

    Core ideas:
      1. Explicit Fourier transform: split the spectral energy of recent closes
         into a low-frequency (trend) band and a high-frequency (noise) band.
      2. Volatility-conditioned signals: the trading side depends on the
         current volatility regime:
           - low volatility:  trend following (strong trend -> continuation)
           - high volatility: mean reversion  (strong trend -> reversal)
           - in between:      trend following, but only on a stronger signal
      3. Trading thresholds are constructor parameters rather than literals.
      4. Signals are computed only from ``get_bar_history()``.
         NOTE(review): this is look-ahead free only if that history excludes
         the bar that is currently opening -- confirm against the engine.

    Parameters (beyond the base ``Strategy`` arguments):
        trade_volume: contracts per entry order.
        bars_per_day: bars per trading day; used as the STFT sampling rate, so
            spectral frequencies are in cycles/day.
        spectral_window_days: STFT window length in days (rounded up to an
            even number of bars).
        low_freq_days: periods of at least this many days form the low (trend)
            band; <= 0 puts every frequency bin in the low band.
        high_freq_days: periods of at most this many days form the high (noise)
            band; <= 0 puts every non-DC bin in the high band.
        trend_strength_threshold: minimum low-band energy share to enter.
        exit_threshold: close the position when the share drops below this.
        volatility_lookback_days: ATR window in days.
        low_vol_threshold / high_vol_threshold: regime cut-offs on the
            normalized volatility in [0, 1].
        max_hold_days: force an exit after this much calendar time in a trade.
        order_direction: allowed entry sides, a subset of ['BUY', 'SELL'].
        indicators: stored on ``self.indicators``; not read by this class.
        cautious_strength_threshold: trend share required to enter in the
            medium-volatility regime.
        vol_spike_multiplier: exit when volatility exceeds
            ``high_vol_threshold * vol_spike_multiplier``.
        short_ma_bars: bars in the short mean compared with the window mean.
        min_ma_divergence: skip entries when |short - long| is below this
            fraction of the long mean.
        vol_history_size: normalized-ATR samples kept for normalization.
        vol_min_history: samples required before volatility leaves the
            neutral value 0.5.
        vol_low_percentile / vol_high_percentile: history percentiles that map
            to 0 and 1 respectively.
    """

    def __init__(
        self,
        context: Any,
        main_symbol: str,
        enable_log: bool,
        trade_volume: int,
        # --- market structure ---
        bars_per_day: int = 23,  # bars per trading day (default suits a 23-bar/day market)
        # --- spectral core ---
        spectral_window_days: float = 2.0,  # STFT window length (days)
        low_freq_days: float = 2.0,  # shortest period counted as trend (days)
        high_freq_days: float = 1.0,  # longest period counted as noise (days)
        trend_strength_threshold: float = 0.8,  # entry threshold on the low-band share
        exit_threshold: float = 0.5,  # exit threshold on the low-band share
        # --- volatility ---
        volatility_lookback_days: float = 5.0,  # ATR window (days)
        low_vol_threshold: float = 0.3,  # low-volatility regime cut-off (0-1)
        high_vol_threshold: float = 0.7,  # high-volatility regime cut-off (0-1)
        # --- position management ---
        max_hold_days: int = 10,  # maximum holding time (calendar days)
        # --- other ---
        order_direction: Optional[List[str]] = None,
        indicators: Optional[List[Indicator]] = None,
        # --- formerly hard-coded constants (defaults reproduce the old behaviour) ---
        cautious_strength_threshold: float = 0.9,
        vol_spike_multiplier: float = 1.2,
        short_ma_bars: int = 5,
        min_ma_divergence: float = 0.0005,
        vol_history_size: int = 1000,
        vol_min_history: int = 50,
        vol_low_percentile: float = 50.0,
        vol_high_percentile: float = 95.0,
    ):
        super().__init__(context, main_symbol, enable_log)
        if order_direction is None:
            order_direction = ['BUY', 'SELL']
        if indicators is None:
            indicators = [Empty(), Empty()]  # placeholder indicators, kept for compatibility

        # --- configuration ---
        self.trade_volume = trade_volume
        self.bars_per_day = bars_per_day
        self.spectral_window_days = spectral_window_days
        self.low_freq_days = low_freq_days
        self.high_freq_days = high_freq_days
        self.trend_strength_threshold = trend_strength_threshold
        self.exit_threshold = exit_threshold
        self.volatility_lookback_days = volatility_lookback_days
        self.low_vol_threshold = low_vol_threshold
        self.high_vol_threshold = high_vol_threshold
        self.max_hold_days = max_hold_days
        self.order_direction = order_direction
        self.cautious_strength_threshold = cautious_strength_threshold
        self.vol_spike_multiplier = vol_spike_multiplier
        self.short_ma_bars = short_ma_bars
        self.min_ma_divergence = min_ma_divergence
        self.vol_history_size = vol_history_size
        self.vol_min_history = vol_min_history
        self.vol_low_percentile = vol_low_percentile
        self.vol_high_percentile = vol_high_percentile

        # --- derived window sizes (in bars) ---
        window = int(self.spectral_window_days * self.bars_per_day)
        self.spectral_window = window + (window % 2)  # round up to an even length
        self.volatility_window = int(self.volatility_lookback_days * self.bars_per_day)

        # Band edges in cycles/day.
        self.low_freq_bound = 1.0 / self.low_freq_days if self.low_freq_days > 0 else float('inf')
        self.high_freq_bound = 1.0 / self.high_freq_days if self.high_freq_days > 0 else 0.0

        # --- runtime state ---
        self.main_symbol = main_symbol
        # The order/close helpers read self.symbol; default it so they also work
        # before the first on_open_bar call sets it.
        self.symbol = getattr(self, "symbol", None) or main_symbol
        self.order_id_counter = 0
        self.indicators = indicators
        self.entry_time = None  # when the current position was opened (or adopted)
        self.position_direction = None  # 'LONG', 'SHORT' or None
        self.last_trend_strength = 0.0
        self.last_dominant_freq = 0.0  # despite the name: dominant trend PERIOD in days
        self.last_volatility = 0.0  # normalized volatility in [0, 1]
        self.volatility_history = []  # rolling normalized-ATR samples

        self.log(f"VolatilityAdaptiveSpectralStrategy Initialized (bars/day={bars_per_day}, "
                 f"window={self.spectral_window} bars, vol_window={self.volatility_window} bars)")

    def on_open_bar(self, open_price: float, symbol: str):
        """Run the signal pipeline at the open of each bar.

        Waits for enough history, computes the spectral trend strength and the
        normalized volatility from the most recent bars, enforces the max-hold
        limit, then looks for an entry (when flat) or manages the open position.

        Args:
            open_price: open of the current bar (forwarded to the entry check).
            symbol: contract the bar belongs to; stored as ``self.symbol``.
        """
        self.symbol = symbol
        bar_history = self.get_bar_history()
        current_time = self.get_current_time()

        lookback = max(self.spectral_window, self.volatility_window)
        min_required = lookback + 10  # largest window plus a small buffer
        if len(bar_history) < min_required:
            if self.enable_log and len(bar_history) % 50 == 0:
                self.log(f"Waiting for {len(bar_history)}/{min_required} bars")
            return

        position_volume = self.get_current_positions().get(self.symbol, 0)

        # Only the bars the calculations need; float64 avoids float32 rounding on large prices.
        recent_bars = bar_history[-(lookback + 5):]
        closes = np.array([bar.close for bar in recent_bars], dtype=np.float64)
        highs = np.array([bar.high for bar in recent_bars], dtype=np.float64)
        lows = np.array([bar.low for bar in recent_bars], dtype=np.float64)

        trend_strength, dominant_period = self.calculate_trend_strength(closes)
        self.last_trend_strength = trend_strength
        self.last_dominant_freq = dominant_period

        volatility = self.calculate_normalized_volatility(highs, lows, closes)
        self.last_volatility = volatility

        # A position without a recorded entry (e.g. state cleared on rollover) would
        # otherwise never hit the max-hold guard; start its clock now.
        if position_volume != 0 and self.entry_time is None:
            self.entry_time = current_time
            self.position_direction = "LONG" if position_volume > 0 else "SHORT"

        # Max-hold guard against runaway trades (calendar time since entry).
        if self.entry_time is not None and \
                current_time - self.entry_time >= timedelta(days=self.max_hold_days):
            self.log(f"Max hold time reached ({self.max_hold_days} days). Forcing exit.")
            self.close_all_positions()
            self._reset_position_state()
            return

        if position_volume == 0:
            self.evaluate_entry_signal(open_price, trend_strength, dominant_period, volatility, recent_bars)
        else:
            self.manage_open_position(position_volume, trend_strength, volatility)

    def calculate_trend_strength(self, closes: np.ndarray) -> Tuple[float, float]:
        """Return the low-frequency energy share of the latest spectral window.

        The last ``spectral_window`` closes are z-scored and transformed with
        ``scipy.signal.stft`` using ``fs = bars_per_day`` (frequencies in
        cycles/day). Because ``nperseg`` equals the data length and
        ``boundary=None, padded=False``, there is a single segment.

        Bands (edges inclusive, with a small tolerance because ``rfftfreq``
        values can sit one ulp away from the nominal edge):
          * low:  f <= 1 / low_freq_days          (period >= low_freq_days)
          * high: f >= 1 / high_freq_days, f > 0  (period <= high_freq_days)
        Bins between the bands are ignored.

        Args:
            closes: 1-D array of close prices, oldest first.

        Returns:
            ``(trend_strength, dominant_period_days)``. ``trend_strength`` is
            low / (low + high) energy in [0, 1]. ``dominant_period_days`` is the
            period of the strongest non-DC low-band bin (the DC bin has no
            period), or 0.0 if there is none. Both are 0.0 when data is
            insufficient, the window is flat, or the STFT fails.
        """
        window = self.spectral_window
        if window <= 0 or len(closes) < window:
            return 0.0, 0.0

        window_data = np.asarray(closes[-window:], dtype=np.float64)
        window_std = np.std(window_data)
        if window_std < 1e-8:
            return 0.0, 0.0
        normalized = (window_data - np.mean(window_data)) / window_std

        try:
            freqs, _, zxx = stft(
                normalized,
                fs=self.bars_per_day,
                nperseg=window,
                noverlap=window // 2,
                boundary=None,
                padded=False,
            )
        except Exception as e:  # best effort: a failed transform just yields no signal
            self.log(f"STFT calculation error: {str(e)}")
            return 0.0, 0.0

        # Keep only bins in [0, Nyquist] (a cheap guard; one-sided STFT output already is).
        valid = (freqs >= 0) & (freqs <= self.bars_per_day / 2)
        if not np.any(valid):
            return 0.0, 0.0
        freqs = freqs[valid]
        zxx = zxx[valid, :]
        if zxx.size == 0 or zxx.shape[1] == 0:
            return 0.0, 0.0

        energy = np.abs(zxx[:, -1]) ** 2  # energy of the most recent segment

        tol = 1e-9
        low_mask = freqs <= self.low_freq_bound + tol
        high_mask = (freqs > 0) & (freqs >= self.high_freq_bound - tol)

        low_energy = float(np.sum(energy[low_mask]))
        high_energy = float(np.sum(energy[high_mask]))
        trend_strength = low_energy / (low_energy + high_energy + 1e-8)

        dominant_period = 0.0
        periodic_low = low_mask & (freqs > 0)
        if np.any(periodic_low) and low_energy > 0:
            candidate_freqs = freqs[periodic_low]
            dominant_period = float(1.0 / candidate_freqs[np.argmax(energy[periodic_low])])

        return float(trend_strength), dominant_period

    def calculate_normalized_volatility(self, highs: np.ndarray, lows: np.ndarray,
                                        closes: np.ndarray) -> float:
        """Return the current ATR-based volatility scaled into [0, 1].

        Steps:
          1. True range over the last ``volatility_window`` bars (at least 1),
             each bar using the preceding bar's close.
          2. ATR = mean true range, divided by the latest close.
          3. Append that value to ``self.volatility_history`` (capped at
             ``vol_history_size`` samples) and map it linearly so the
             ``vol_low_percentile`` of the history is 0 and the
             ``vol_high_percentile`` is 1, clipped to [0, 1].

        Args:
            highs, lows, closes: equal-length 1-D price arrays, oldest first.

        Returns:
            Normalized volatility in [0, 1]; the neutral 0.5 when there are too
            few bars, fewer than ``vol_min_history`` samples, or the percentile
            spread is degenerate.

        Side effects:
            Appends to ``self.volatility_history`` whenever enough bars exist.
        """
        span = max(self.volatility_window, 1)
        if len(closes) < span + 1:
            return 0.5  # neutral default

        high = highs[-span:]
        low = lows[-span:]
        prev_close = closes[-span - 1:-1]  # previous bar's close, no wrap-around
        true_range = np.maximum(high - low,
                                np.maximum(np.abs(high - prev_close), np.abs(low - prev_close)))
        atr = float(np.mean(true_range))

        current_price = float(closes[-1])
        normalized_atr = atr / current_price if current_price > 0 else 0.0

        history = self.volatility_history
        history.append(normalized_atr)
        overflow = len(history) - max(self.vol_history_size, 1)
        if overflow > 0:
            del history[:overflow]

        if len(history) < self.vol_min_history:
            return 0.5  # not enough history to normalize against

        low_pct, high_pct = np.percentile(history, [self.vol_low_percentile, self.vol_high_percentile])
        spread = float(high_pct - low_pct)
        if spread < 1e-8:
            return 0.5

        scaled = (normalized_atr - low_pct) / (spread + 1e-8)
        return float(max(0.0, min(1.0, scaled)))

    def evaluate_entry_signal(self, open_price: float, trend_strength: float, dominant_freq: float,
                              volatility: float, recent_bars: List[Bar]):
        """Open a position when the trend share is high; pick the side by volatility regime.

        Entry requires ``trend_strength > trend_strength_threshold`` and a
        short-vs-long mean divergence of at least ``min_ma_divergence`` of the
        long mean. The long mean covers the last ``spectral_window`` bars and
        the short mean the last ``short_ma_bars`` of those.

        Regimes (``volatility`` in [0, 1]):
          * below ``low_vol_threshold``  -> TREND: buy above the mean, sell below.
          * above ``high_vol_threshold`` -> REVERSAL: buy below the mean, sell above.
          * otherwise -> CAUTIOUS: TREND rules, but only if
            ``trend_strength > cautious_strength_threshold``.
        Sides missing from ``order_direction`` are never opened.

        Args:
            open_price: open of the current bar (currently unused).
            trend_strength: low-band energy share.
            dominant_freq: dominant trend period in days (currently unused).
            volatility: normalized volatility in [0, 1].
            recent_bars: recent bars, oldest first.
        """
        if not trend_strength > self.trend_strength_threshold:  # also rejects NaN
            return

        window_closes = np.array([bar.close for bar in recent_bars[-self.spectral_window:]],
                                 dtype=np.float64)
        short_avg = float(np.mean(window_closes[-self.short_ma_bars:]))
        long_avg = float(np.mean(window_closes))

        # Ignore entries where the two means are practically equal.
        if abs(short_avg - long_avg) < self.min_ma_divergence * long_avg:
            return

        can_buy = "BUY" in self.order_direction
        can_sell = "SELL" in self.order_direction
        above = short_avg > long_avg
        below = short_avg < long_avg

        direction = None
        if volatility < self.low_vol_threshold:
            trade_type = "TREND"
            if can_buy and above:
                direction = "BUY"
            elif can_sell and below:
                direction = "SELL"
        elif volatility > self.high_vol_threshold:
            trade_type = "REVERSAL"
            if can_buy and below:
                direction = "BUY"  # price under its mean: expect reversion up
            elif can_sell and above:
                direction = "SELL"  # price over its mean: expect reversion down
        else:
            trade_type = "CAUTIOUS"
            strong = trend_strength > self.cautious_strength_threshold
            if strong and can_buy and above:
                direction = "BUY"
            elif strong and can_sell and below:
                direction = "SELL"

        if direction is None:
            return

        self.log(
            f"Entry: {direction} | Type={trade_type} | Strength={trend_strength:.2f} | "
            f"Volatility={volatility:.2f} | ShortAvg={short_avg:.4f} vs LongAvg={long_avg:.4f}"
        )
        self.send_market_order(direction, self.trade_volume, "OPEN")
        self.entry_time = self.get_current_time()
        self.position_direction = "LONG" if direction == "BUY" else "SHORT"

    def manage_open_position(self, volume: int, trend_strength: float, volatility: float):
        """Close the open position on a faded trend or a volatility spike.

        The side to close is taken from the sign of ``volume`` (positive =
        long), so the exits work even when ``position_direction`` is unknown.

        Args:
            volume: signed position size for ``self.symbol``.
            trend_strength: current low-band energy share.
            volatility: current normalized volatility in [0, 1].
        """
        close_direction = "CLOSE_LONG" if volume > 0 else "CLOSE_SHORT"

        # Exit 1: the trend share has faded below the exit threshold.
        if trend_strength < self.exit_threshold:
            self.log(f"Exit (Strength): {close_direction} | Strength={trend_strength:.2f} < {self.exit_threshold}")
            self.close_position(close_direction, abs(volume))
            self._reset_position_state()
            return

        # Exit 2: volatility spiked above the high-volatility threshold times the multiplier.
        # NOTE: volatility is clipped to [0, 1], so this never fires if the product exceeds 1.
        spike_level = self.high_vol_threshold * self.vol_spike_multiplier
        if volatility > spike_level:
            self.log(
                f"Exit (Volatility Spike): {close_direction} | Volatility={volatility:.2f} > {spike_level:.2f}")
            self.close_position(close_direction, abs(volume))
            self._reset_position_state()

    # --- helpers ---

    def _reset_position_state(self):
        """Forget the entry time and side of the closed position."""
        self.entry_time = None
        self.position_direction = None

    def close_all_positions(self):
        """Send a market close for the whole position in ``self.symbol``, if any."""
        positions = self.get_current_positions()
        held = positions.get(self.symbol, 0) if positions else 0
        if held == 0:
            return

        self.close_position("CLOSE_LONG" if held > 0 else "CLOSE_SHORT", abs(held))
        if self.enable_log:
            self.log(f"Closed {abs(held)} contracts")

    def close_position(self, direction: str, volume: int):
        """Send a market order with offset ``CLOSE`` for ``volume`` contracts."""
        self.send_market_order(direction, volume, offset="CLOSE")

    def send_market_order(self, direction: str, volume: int, offset: str):
        """Build and submit a market ``Order`` for ``self.symbol``.

        Order ids have the form ``{symbol}_{last 6 chars of direction}_{counter}``;
        the counter makes them unique within this strategy instance.
        """
        order_id = f"{self.symbol}_{direction[-6:]}_{self.order_id_counter}"
        self.order_id_counter += 1
        order = Order(
            id=order_id,
            symbol=self.symbol,
            direction=direction,
            volume=volume,
            price_type="MARKET",
            submitted_time=self.get_current_time(),
            offset=offset,
        )
        self.send_order(order)

    def on_init(self):
        """Run base initialization, then cancel pending orders on the main symbol."""
        super().on_init()
        self.cancel_all_pending_orders(self.main_symbol)
        if self.enable_log:
            self.log("Strategy initialized. Waiting for volatility-adaptive signals...")

    def on_rollover(self, old_symbol: str, new_symbol: str):
        """Clear per-contract state after a contract rollover.

        Position bookkeeping, cached signal values and the volatility history
        refer to the old contract and are reset. If a position is still reported
        on a later bar, ``on_open_bar`` adopts it and restarts the max-hold clock.
        """
        super().on_rollover(old_symbol, new_symbol)
        if self.enable_log:
            self.log(f"Rollover: {old_symbol} -> {new_symbol}. Resetting state.")
        self._reset_position_state()
        self.last_trend_strength = 0.0
        self.last_dominant_freq = 0.0
        self.last_volatility = 0.0
        self.volatility_history = []  # normalized ATR is contract-specific