# Source: NewQuant/futures_trading_strategies/ru/Spectral/SpectralTrendStrategy5.py
# (Captured from a repository web viewer: 335 lines, 13 KiB, Python;
#  last change 2025-12-16 00:36:36 +08:00.)
from datetime import timedelta
from typing import Any, List, Optional, Tuple

import numpy as np
import talib
from scipy.signal import stft

from src.core_data import Bar, Order
from src.indicators.base_indicators import Indicator
from src.indicators.indicators import Empty
from src.strategies.base_strategy import Strategy
class SpectralTrendStrategy(Strategy):
    """Spectral energy phase-transition strategy, dual-factor adaptive version.

    Instead of a static ``reverse`` flag, the trade direction is decided
    per bar by two indicators:

    1. Derive a base (raw) direction from the STFT energy split and the
       slope of the normalized price window.
    2. If ``indicator_primary`` is satisfied, trade the raw direction
       (trend-following).
    3. Otherwise, if ``indicator_secondary`` is satisfied, trade the
       opposite direction (contrarian).
    4. If neither is satisfied, stay flat.

    State tracking:
        ``self.entry_signal_source`` records which branch opened the current
        position: ``'PRIMARY'`` or ``'SECONDARY'`` (``None`` while flat).
    """

    def __init__(
        self,
        context: Any,
        main_symbol: str,
        enable_log: bool,
        trade_volume: int,
        # --- Market parameters ---
        bars_per_day: int = 23,
        # --- STFT signal parameters ---
        spectral_window_days: float = 2.0,
        low_freq_days: float = 2.0,
        high_freq_days: float = 1.0,
        trend_strength_threshold: float = 0.2,
        exit_threshold: float = 0.1,
        slope_threshold: float = 0.0,
        # --- Risk parameters (Chandelier Exit) ---
        stop_loss_atr_multiplier: float = 5.0,
        stop_loss_atr_period: int = 14,
        # --- Signal-control indicators ---
        indicator_primary: Optional[Indicator] = None,  # met -> trade raw direction (reverse=False)
        indicator_secondary: Optional[Indicator] = None,  # met -> trade reversed direction (reverse=True)
        model_indicator: Optional[Indicator] = None,  # optional extra (e.g. AI model) filter
        # --- Other ---
        order_direction: Optional[List[str]] = None,
    ):
        """Configure the strategy.

        Args:
            context: Forwarded to ``Strategy.__init__``.
            main_symbol: Forwarded to ``Strategy.__init__``; also used when
                cancelling pending orders each bar.
            enable_log: Forwarded to ``Strategy.__init__``.
            trade_volume: Volume of every entry order.
            bars_per_day: Bars per trading day. Used as the STFT sampling
                rate and to convert the ``*_days`` parameters into bars.
            spectral_window_days: STFT window length in days. It is converted
                to bars and rounded up to an even count.
            low_freq_days: Components with a period longer than this many days
                count as low-frequency ("trend") energy.
            high_freq_days: Components with a period shorter than this many
                days count as high-frequency ("noise") energy.
            trend_strength_threshold: Entry requires the low-frequency energy
                share to be strictly above this value.
            exit_threshold: An open position is closed when the energy share
                falls below this value.
            slope_threshold: The normalized slope must exceed this value in
                magnitude to define a direction.
            stop_loss_atr_multiplier: Chandelier stop distance, in ATRs.
            stop_loss_atr_period: ATR look-back period.
            indicator_primary: When met, trade the raw STFT direction.
            indicator_secondary: Checked only if the primary indicator is not
                met; when it is met, trade the reversed direction.
            model_indicator: Additional filter applied after either branch.
            order_direction: Allowed final directions. Defaults to
                ``['BUY', 'SELL']``.

        Any indicator left as ``None`` falls back to ``Empty()``. Whether
        that passes or blocks depends on ``Empty.is_condition_met``, which is
        not visible here, so passing concrete indicators is recommended.
        """
        super().__init__(context, main_symbol, enable_log)
        if order_direction is None:
            order_direction = ['BUY', 'SELL']
        self.trade_volume = trade_volume
        self.bars_per_day = bars_per_day

        # Signal parameters
        self.spectral_window_days = spectral_window_days
        self.low_freq_days = low_freq_days
        self.high_freq_days = high_freq_days
        self.trend_strength_threshold = trend_strength_threshold
        self.exit_threshold = exit_threshold
        self.slope_threshold = slope_threshold

        # Risk parameters
        self.sl_atr_multiplier = stop_loss_atr_multiplier
        self.sl_atr_period = stop_loss_atr_period
        self.order_direction = order_direction

        # Indicator slots (see the docstring note about the Empty() fallback).
        self.indicator_primary = indicator_primary or Empty()
        self.indicator_secondary = indicator_secondary or Empty()
        self.model_indicator = model_indicator or Empty()

        # STFT window length in bars, forced to an even number.
        self.spectral_window = int(self.spectral_window_days * self.bars_per_day)
        if self.spectral_window % 2 != 0:
            self.spectral_window += 1

        # Band edges in cycles per day (the STFT runs with fs=bars_per_day).
        # A non-positive *_days value disables its band split: every bin
        # becomes "low" (bound = inf) or every positive bin becomes "high"
        # (bound = 0).
        self.low_freq_bound = 1.0 / self.low_freq_days if self.low_freq_days > 0 else float('inf')
        self.high_freq_bound = 1.0 / self.high_freq_days if self.high_freq_days > 0 else 0.0

        self.order_id_counter = 0

        # --- Per-position tracking state (reset whenever flat) ---
        self.entry_price = 0.0
        self.pos_highest = 0.0  # highest high since entry (long side); 0.0 = unset
        self.pos_lowest = 0.0  # lowest low since entry (short side); 0.0 = unset
        self.entry_signal_source = None  # 'PRIMARY' / 'SECONDARY' / None

        self.log(
            f"SpectralTrend Dual-Adaptive Strategy Initialized.\n"
            f"Window: {self.spectral_window}, ATR Stop: {self.sl_atr_multiplier}x\n"
            f"Primary Ind: {type(self.indicator_primary).__name__}, "
            f"Secondary Ind: {type(self.indicator_secondary).__name__}"
        )

    def on_open_bar(self, open_price: float, symbol: str) -> None:
        """Handle a new bar: compute features, then try to enter or manage an exit.

        Pending orders for ``main_symbol`` are cancelled on every bar, so an
        entry limit order lives for one bar only.

        Args:
            open_price: Open price of the new bar. It is the limit price for
                entries and the reference price for stop/exit checks.
            symbol: Bar symbol. It is stored on ``self.symbol`` and used for
                the position lookup and order ids.
        """
        self.symbol = symbol
        bar_history = self.get_bar_history()
        self.cancel_all_pending_orders(self.main_symbol)

        # 1. Need enough history for both the STFT window and the ATR warm-up.
        required_len = max(self.spectral_window, self.sl_atr_period + 5)
        if len(bar_history) < required_len:
            return

        # 2. ATR for the Chandelier stop (0.0 means "unavailable": no stop).
        current_atr = self._compute_atr(bar_history)

        # 3. STFT trend strength and normalized slope.
        stft_closes = np.array([b.close for b in bar_history[-self.spectral_window:]], dtype=float)
        trend_strength, trend_slope = self.calculate_market_state(stft_closes)

        # 4. Trading logic. The range of the last completed bar feeds the
        #    trailing extremes.
        position_volume = self.get_current_positions().get(self.symbol, 0)
        current_high = bar_history[-1].high
        current_low = bar_history[-1].low

        if self.trading:
            if position_volume == 0:
                # Flat: clear any state left over from a previous position or
                # an unfilled entry order before looking for a new entry.
                self.pos_highest = 0.0
                self.pos_lowest = 0.0
                self.entry_price = 0.0
                self.entry_signal_source = None
                self.evaluate_entry_signal(open_price, trend_strength, trend_slope)
            else:
                self.manage_open_position(
                    position_volume,
                    trend_strength,
                    open_price,
                    current_atr,
                    current_high,
                    current_low,
                )

    def _compute_atr(self, bar_history) -> float:
        """Return the latest ATR over the tail of ``bar_history``, or 0.0 if unavailable.

        Uses the last ``sl_atr_period + 10`` bars. A NaN result (warm-up) or
        a TA-Lib error yields 0.0, and errors are logged.
        """
        tail = bar_history[-(self.sl_atr_period + 10):]
        highs = np.array([b.high for b in tail], dtype=float)
        lows = np.array([b.low for b in tail], dtype=float)
        closes = np.array([b.close for b in tail], dtype=float)
        try:
            current_atr = talib.ATR(highs, lows, closes, timeperiod=self.sl_atr_period)[-1]
        except Exception as e:
            self.log(f"ATR Calc Error: {e}")
            return 0.0
        if np.isnan(current_atr):
            return 0.0
        return current_atr

    def calculate_market_state(self, prices: np.ndarray) -> Tuple[float, float]:
        """Compute the low-frequency energy share and the linear-regression slope.

        The last ``spectral_window`` prices are z-score normalized, and a
        single-segment STFT (``fs=bars_per_day``, so frequencies are in
        cycles per day) is taken. Energy below ``low_freq_bound`` counts as
        trend; energy above ``high_freq_bound`` counts as noise. Bins in
        between are ignored.

        Args:
            prices: 1-D array of closing prices, at least ``spectral_window``
                long.

        Returns:
            ``(trend_strength, slope)``. ``trend_strength`` is
            low / (low + high) energy, in [0, 1]. ``slope`` is the
            least-squares slope of the normalized window per bar. Returns
            ``(0.0, 0.0)`` when there is too little data or the STFT fails.
        """
        if len(prices) < self.spectral_window:
            return 0.0, 0.0

        window_data = prices[-self.spectral_window:]
        mean_val = np.mean(window_data)
        std_val = np.std(window_data)
        if std_val == 0:
            std_val = 1.0  # flat window: avoid dividing by zero
        normalized = (window_data - mean_val) / (std_val + 1e-8)

        try:
            f, _, Zxx = stft(
                normalized,
                fs=self.bars_per_day,
                nperseg=self.spectral_window,
                noverlap=max(0, self.spectral_window // 2),
                boundary=None,
                padded=False,
            )
        except Exception as e:
            # Best effort: report the failure (like the ATR path) and treat
            # the market as neutral instead of crashing the bar handler.
            self.log(f"STFT Calc Error: {e}")
            return 0.0, 0.0

        # Keep only 0..Nyquist (a defensive filter on the one-sided output).
        valid_mask = (f >= 0) & (f <= self.bars_per_day / 2)
        f = f[valid_mask]
        Zxx = Zxx[valid_mask, :]
        if Zxx.size == 0:
            return 0.0, 0.0

        # Power spectrum of the most recent segment.
        current_energy = np.abs(Zxx[:, -1]) ** 2
        low_mask = f < self.low_freq_bound
        high_mask = f > self.high_freq_bound
        low_energy = np.sum(current_energy[low_mask]) if np.any(low_mask) else 0.0
        high_energy = np.sum(current_energy[high_mask]) if np.any(high_mask) else 0.0
        total_energy = low_energy + high_energy + 1e-8
        trend_strength = low_energy / total_energy

        x = np.arange(len(normalized))
        slope, _ = np.polyfit(x, normalized, 1)
        return trend_strength, slope

    def evaluate_entry_signal(self, open_price: float, trend_strength: float, trend_slope: float) -> None:
        """Decide whether and in which direction to open a position (dual-factor logic).

        Sends a one-bar limit entry at ``open_price`` and records the
        tracking state when every gate passes. The gates, in order, are:
        the energy threshold, the slope-defined raw direction, the
        primary/secondary indicator branch, the model filter, and the
        allowed ``order_direction``.
        """
        # 1. Energy-share threshold.
        if trend_strength <= self.trend_strength_threshold:
            return

        # 2. Raw STFT direction from the normalized slope.
        raw_direction = None
        if trend_slope > self.slope_threshold:
            raw_direction = "BUY"
        elif trend_slope < -self.slope_threshold:
            raw_direction = "SELL"
        if not raw_direction:
            return

        # 3. Dual-indicator branching. The argument tuple comes from the
        #    base class.
        indicator_args = self.get_indicator_tuple()
        if self.indicator_primary.is_condition_met(*indicator_args):
            # Primary met: follow the raw direction (reverse=False).
            final_direction = raw_direction
            source_tag = "PRIMARY"
        elif self.indicator_secondary.is_condition_met(*indicator_args):
            # Only secondary met: trade against the raw direction (reverse=True).
            final_direction = "SELL" if raw_direction == "BUY" else "BUY"
            source_tag = "SECONDARY"
        else:
            return  # neither branch qualifies: skip this bar

        # 4. Global model filter (optional).
        if not self.model_indicator.is_condition_met(*indicator_args):
            return

        # 5. Respect the allowed trading directions.
        if final_direction not in self.order_direction:
            return

        # 6. Place the entry.
        self.log(
            f"Entry Triggered [{source_tag}]: "
            f"Raw={raw_direction} -> Final={final_direction} | "
            f"Strength={trend_strength:.2f} | Slope={trend_slope:.4f}"
        )
        self.send_limit_order(final_direction, open_price, self.trade_volume, "OPEN")

        # State is recorded at order time. If the limit order does not fill,
        # the next bar sees a flat position and resets it.
        self.entry_price = open_price
        self.pos_highest = open_price
        self.pos_lowest = open_price
        self.entry_signal_source = source_tag

    def manage_open_position(self, volume: int, trend_strength: float, current_price: float,
                             current_atr: float, high_price: float, low_price: float) -> None:
        """Apply the exit rules to an open position.

        Steps:
            1. Update the favourable extreme since entry (highest high for a
               long, lowest low for a short).
            2. Chandelier stop: close at market if ``current_price`` has
               crossed ``extreme -/+ current_atr * sl_atr_multiplier``. This
               is skipped when ATR is unavailable (0).
            3. Signal decay: close at market when ``trend_strength`` drops
               below ``exit_threshold``.

        Args:
            volume: Signed position size (>0 long, <0 short).
            trend_strength: Current low-frequency energy share.
            current_price: Price checked against the stop (the bar open).
            current_atr: Latest ATR; 0 disables the stop.
            high_price: High of the last completed bar.
            low_price: Low of the last completed bar.
        """
        exit_dir = "CLOSE_LONG" if volume > 0 else "CLOSE_SHORT"

        # Update the trailing extreme (0.0 means "not yet initialised").
        if volume > 0:
            if high_price > self.pos_highest or self.pos_highest == 0:
                self.pos_highest = high_price
        else:
            if (low_price < self.pos_lowest or self.pos_lowest == 0) and low_price > 0:
                self.pos_lowest = low_price

        # 1. Structural stop (Chandelier Exit).
        is_stop_triggered = False
        stop_line = 0.0
        if current_atr > 0:
            stop_distance = current_atr * self.sl_atr_multiplier
            if volume > 0:
                stop_line = self.pos_highest - stop_distance
                if current_price <= stop_line:
                    is_stop_triggered = True
            else:
                stop_line = self.pos_lowest + stop_distance
                if current_price >= stop_line:
                    is_stop_triggered = True

        if is_stop_triggered:
            self.log(
                f"STOP Loss ({self.entry_signal_source}): Price {current_price} hit Chandelier line {stop_line:.2f}")
            self.close_position(exit_dir, abs(volume))
            return

        # 2. Exit on trend-energy decay.
        if trend_strength < self.exit_threshold:
            self.log(
                f"Exit Signal ({self.entry_signal_source}): Energy Faded {trend_strength:.2f} < {self.exit_threshold}")
            self.close_position(exit_dir, abs(volume))
            return

    # --- Order helpers ---

    def close_position(self, direction: str, volume: int) -> None:
        """Close ``volume`` units with a market order (``direction`` is CLOSE_LONG/CLOSE_SHORT)."""
        self.send_market_order(direction, volume, offset="CLOSE")

    def send_market_order(self, direction: str, volume: int, offset: str) -> None:
        """Build and submit a MARKET order with a unique, counter-based id."""
        order_id = f"{self.symbol}_{direction}_MKT_{self.order_id_counter}"
        self.order_id_counter += 1
        order = Order(
            id=order_id, symbol=self.symbol, direction=direction, volume=volume,
            price_type="MARKET", submitted_time=self.get_current_time(), offset=offset
        )
        self.send_order(order)

    def send_limit_order(self, direction: str, limit_price: float, volume: int, offset: str) -> None:
        """Build and submit a LIMIT order at ``limit_price`` with a unique, counter-based id."""
        order_id = f"{self.symbol}_{direction}_LMT_{self.order_id_counter}"
        self.order_id_counter += 1
        order = Order(
            id=order_id, symbol=self.symbol, direction=direction, volume=volume,
            price_type="LIMIT", submitted_time=self.get_current_time(), offset=offset,
            limit_price=limit_price
        )
        self.send_order(order)