Files
NewQuant/futures_trading_strategies/ru/Spectral/SpectralTrendStrategy5.py
2025-12-16 00:36:36 +08:00

335 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from datetime import timedelta
from typing import Any, List, Optional, Tuple

import numpy as np
import talib
from scipy.signal import stft

from src.core_data import Bar, Order
from src.indicators.base_indicators import Indicator
from src.indicators.indicators import Empty
from src.strategies.base_strategy import Strategy
class SpectralTrendStrategy(Strategy):
    """
    Spectral energy phase-transition strategy - Dual-Factor Adaptive version.

    Instead of a static ``reverse`` flag, the trade direction is decided
    dynamically by two indicators:

    1. Compute the base (raw) direction from the STFT low-frequency energy
       share and the linear-regression slope of the normalized price window.
    2. If ``indicator_primary`` is satisfied, trade the raw direction
       (trend-following / forward).
    3. Otherwise, if ``indicator_secondary`` is satisfied, trade the reversed
       direction (counter-trend / reversal).
    4. If neither is satisfied, stay flat.

    State tracking:
        ``self.entry_signal_source`` records whether the current position was
        opened by the 'PRIMARY' or the 'SECONDARY' branch.

    Exits: an ATR-based Chandelier trailing stop, plus an "energy faded" exit
    when the low-frequency energy share drops below ``exit_threshold``.
    """

    def __init__(
            self,
            context: Any,
            main_symbol: str,
            enable_log: bool,
            trade_volume: int,
            # --- Market parameters ---
            bars_per_day: int = 23,
            # --- STFT strategy parameters ---
            spectral_window_days: float = 2.0,
            low_freq_days: float = 2.0,
            high_freq_days: float = 1.0,
            trend_strength_threshold: float = 0.2,
            exit_threshold: float = 0.1,
            slope_threshold: float = 0.0,
            # --- Key risk parameters (Chandelier Exit) ---
            stop_loss_atr_multiplier: float = 5.0,
            stop_loss_atr_period: int = 14,
            # --- Signal-control indicators (core of this version) ---
            indicator_primary: Optional[Indicator] = None,  # met -> open in raw direction (reverse=False)
            indicator_secondary: Optional[Indicator] = None,  # met -> open in reversed direction (reverse=True)
            model_indicator: Optional[Indicator] = None,  # optional extra filter (e.g. an AI model)
            # --- Other ---
            order_direction: Optional[List[str]] = None,
    ):
        """
        Args:
            context: Framework context, forwarded to ``Strategy.__init__``.
            main_symbol: Main symbol, forwarded to ``Strategy.__init__``;
                pending orders for it are cancelled at every bar open.
            enable_log: Forwarded to ``Strategy.__init__``.
            trade_volume: Volume of each opening order.
            bars_per_day: Bars per trading day. Used to size the STFT window
                and as the STFT sampling rate, so frequencies are in cycles/day.
            spectral_window_days: STFT window length in days. Converted to bars
                (truncated to int) and bumped up to an even number.
            low_freq_days: Bins with frequency < 1/low_freq_days count as
                low-frequency energy (every bin if low_freq_days <= 0).
            high_freq_days: Bins with frequency > 1/high_freq_days count as
                high-frequency energy (every positive bin if high_freq_days <= 0).
            trend_strength_threshold: Entry requires the low-frequency energy
                share to be strictly greater than this.
            exit_threshold: An open position is closed when the share falls
                below this.
            slope_threshold: Minimum slope (normalized price units per bar)
                that defines the raw direction.
            stop_loss_atr_multiplier: Chandelier stop distance, in ATRs.
            stop_loss_atr_period: ATR period.
            indicator_primary: If satisfied, open in the raw STFT direction.
            indicator_secondary: Checked only when the primary one is not
                satisfied; if satisfied, open in the reversed direction.
            model_indicator: Extra filter that must also be satisfied to open.
            order_direction: Allowed final directions; defaults to
                ['BUY', 'SELL'].
        """
        super().__init__(context, main_symbol, enable_log)
        if order_direction is None:
            order_direction = ['BUY', 'SELL']
        self.trade_volume = trade_volume
        self.bars_per_day = bars_per_day

        # Signal parameters
        self.spectral_window_days = spectral_window_days
        self.low_freq_days = low_freq_days
        self.high_freq_days = high_freq_days
        self.trend_strength_threshold = trend_strength_threshold
        self.exit_threshold = exit_threshold
        self.slope_threshold = slope_threshold

        # Risk parameters
        self.sl_atr_multiplier = stop_loss_atr_multiplier
        self.sl_atr_period = stop_loss_atr_period
        self.order_direction = order_direction

        # Indicator slots. Missing ones default to Empty(); whether Empty() is
        # "satisfied" depends on its implementation (not visible here), so
        # passing concrete indicators is recommended.
        self.indicator_primary = indicator_primary or Empty()
        self.indicator_secondary = indicator_secondary or Empty()
        self.model_indicator = model_indicator or Empty()

        # STFT window size in bars, forced to be even.
        self.spectral_window = int(self.spectral_window_days * self.bars_per_day)
        if self.spectral_window % 2 != 0:
            self.spectral_window += 1

        # Band edges in cycles/day (the STFT is sampled at bars_per_day).
        self.low_freq_bound = 1.0 / self.low_freq_days if self.low_freq_days > 0 else float('inf')
        self.high_freq_bound = 1.0 / self.high_freq_days if self.high_freq_days > 0 else 0.0

        self.order_id_counter = 0

        # --- Position tracking state ---
        # entry_price / pos_highest / pos_lowest feed the Chandelier stop;
        # entry_signal_source is 'PRIMARY', 'SECONDARY' or None.
        self._reset_position_state()

        self.log(
            f"SpectralTrend Dual-Adaptive Strategy Initialized.\n"
            f"Window: {self.spectral_window}, ATR Stop: {self.sl_atr_multiplier}x\n"
            f"Primary Ind: {type(self.indicator_primary).__name__}, "
            f"Secondary Ind: {type(self.indicator_secondary).__name__}"
        )

    def _reset_position_state(self) -> None:
        """Clear per-position tracking: entry price, extremes and signal source."""
        self.entry_price = 0.0
        self.pos_highest = 0.0
        self.pos_lowest = 0.0
        self.entry_signal_source = None

    def _calculate_atr(self, bar_history) -> float:
        """
        Return the latest TA-Lib ATR over the most recent bars.

        Uses the last ``sl_atr_period + 10`` bars. Returns 0.0 when the value
        is NaN or TA-Lib raises (the error is logged). A 0.0 ATR disables the
        Chandelier stop for that bar.
        """
        recent_bars = bar_history[-(self.sl_atr_period + 10):]
        highs = np.array([b.high for b in recent_bars], dtype=float)
        lows = np.array([b.low for b in recent_bars], dtype=float)
        closes = np.array([b.close for b in recent_bars], dtype=float)
        try:
            atr_values = talib.ATR(highs, lows, closes, timeperiod=self.sl_atr_period)
            current_atr = atr_values[-1]
            if np.isnan(current_atr):
                current_atr = 0.0
        except Exception as e:
            self.log(f"ATR Calc Error: {e}")
            current_atr = 0.0
        return float(current_atr)

    def on_open_bar(self, open_price: float, symbol: str) -> None:
        """
        Per-bar entry point, called with the current bar's open price.

        Cancels pending orders, computes ATR and the spectral market state
        from the bar history, then (when ``self.trading`` is set) either
        evaluates a new entry while flat or manages the open position.
        """
        self.symbol = symbol
        bar_history = self.get_bar_history()
        self.cancel_all_pending_orders(self.main_symbol)

        # 1. Need enough history for both the STFT window and the ATR.
        required_len = max(self.spectral_window, self.sl_atr_period + 5)
        if len(bar_history) < required_len:
            return

        # 2. ATR for the trailing stop.
        current_atr = self._calculate_atr(bar_history)

        # 3. STFT core indicators.
        stft_closes = np.array([b.close for b in bar_history[-self.spectral_window:]], dtype=float)
        trend_strength, trend_slope = self.calculate_market_state(stft_closes)

        # 4. Trading logic. The last bar in history is the most recent
        #    completed bar; its high/low update the position extremes.
        position_volume = self.get_current_positions().get(self.symbol, 0)
        last_bar = bar_history[-1]
        current_high = last_bar.high
        current_low = last_bar.low

        if not self.trading:
            return

        if position_volume == 0:
            # Flat: clear any stale state (including state recorded for an
            # entry order that never filled) before looking for a new entry.
            self._reset_position_state()
            self.evaluate_entry_signal(open_price, trend_strength, trend_slope)
        else:
            self.manage_open_position(
                position_volume,
                trend_strength,
                open_price,
                current_atr,
                current_high,
                current_low
            )

    def calculate_market_state(self, prices: np.ndarray) -> Tuple[float, float]:
        """
        Compute the low-frequency energy share and the linear-regression slope.

        The last ``spectral_window`` prices are z-score normalized. A single
        STFT segment spanning the window is taken (sampling rate
        ``bars_per_day``, so frequencies are in cycles/day).

        Args:
            prices: 1-D sequence of close prices; only the last
                ``spectral_window`` values are used.

        Returns:
            ``(trend_strength, slope)``:

            - ``trend_strength`` is low_energy / (low_energy + high_energy),
              using the bands defined by ``low_freq_bound`` and
              ``high_freq_bound``. Bins between the two bounds are ignored.
            - ``slope`` is the degree-1 ``np.polyfit`` slope of the normalized
              prices against the bar index.

            ``(0.0, 0.0)`` is returned when there are too few prices, when the
            window contains non-finite values, when the STFT fails, or when no
            frequency bins remain.
        """
        if len(prices) < self.spectral_window:
            return 0.0, 0.0

        window_data = np.asarray(prices[-self.spectral_window:], dtype=float)

        # Non-finite input would turn every energy into NaN and can make
        # np.polyfit raise LinAlgError; treat it like any failed computation.
        if not np.all(np.isfinite(window_data)):
            self.log("STFT Skipped: non-finite prices in window")
            return 0.0, 0.0

        mean_val = np.mean(window_data)
        std_val = np.std(window_data)
        if std_val == 0:
            std_val = 1.0
        normalized = (window_data - mean_val) / (std_val + 1e-8)

        try:
            # nperseg == window length with no boundary padding -> one segment.
            freqs, _, spectrum = stft(
                normalized,
                fs=self.bars_per_day,
                nperseg=self.spectral_window,
                noverlap=max(0, self.spectral_window // 2),
                boundary=None,
                padded=False
            )
        except Exception as e:
            self.log(f"STFT Calc Error: {e}")
            return 0.0, 0.0

        valid_mask = (freqs >= 0) & (freqs <= self.bars_per_day / 2)
        freqs = freqs[valid_mask]
        spectrum = spectrum[valid_mask, :]
        if spectrum.size == 0:
            return 0.0, 0.0

        # Power of the most recent (here: only) segment.
        current_energy = np.abs(spectrum[:, -1]) ** 2
        # np.sum of an empty selection is 0.0, so no explicit any() check is needed.
        low_energy = np.sum(current_energy[freqs < self.low_freq_bound])
        high_energy = np.sum(current_energy[freqs > self.high_freq_bound])
        total_energy = low_energy + high_energy + 1e-8
        trend_strength = low_energy / total_energy

        bar_index = np.arange(len(normalized))
        slope, _ = np.polyfit(bar_index, normalized, 1)
        return float(trend_strength), float(slope)

    def evaluate_entry_signal(self, open_price: float, trend_strength: float, trend_slope: float):
        """
        Entry logic with dual-indicator direction control.

        Steps: energy threshold -> raw direction from the slope -> primary
        indicator (keep raw direction) or secondary indicator (reverse it) ->
        model filter -> allowed-direction check. On success, sends a LIMIT
        order at ``open_price`` with offset "OPEN" and records entry state.
        The state is recorded at submission time; if the order does not fill,
        ``on_open_bar`` resets it on the next bar while flat.
        """
        # 1. Base energy threshold.
        if trend_strength <= self.trend_strength_threshold:
            return

        # 2. Raw STFT direction from the regression slope.
        raw_direction = None
        if trend_slope > self.slope_threshold:
            raw_direction = "BUY"
        elif trend_slope < -self.slope_threshold:
            raw_direction = "SELL"
        if not raw_direction:
            return

        # 3. Dual-indicator branching. The argument tuple comes from the base
        #    class (presumably bar history and similar data - verify there).
        indicator_args = self.get_indicator_tuple()

        if self.indicator_primary.is_condition_met(*indicator_args):
            # Branch 1: primary satisfied -> raw direction (reverse=False).
            final_direction = raw_direction
            source_tag = "PRIMARY"
        elif self.indicator_secondary.is_condition_met(*indicator_args):
            # Branch 2: only secondary satisfied -> reversed direction (reverse=True).
            final_direction = "SELL" if raw_direction == "BUY" else "BUY"
            source_tag = "SECONDARY"
        else:
            # Branch 3: neither satisfied -> skip the trade.
            return

        # 4. Global model filter.
        if not self.model_indicator.is_condition_met(*indicator_args):
            return

        # 5. Final direction must be allowed.
        if final_direction not in self.order_direction:
            return

        # 6. Open the position.
        self.log(
            f"Entry Triggered [{source_tag}]: "
            f"Raw={raw_direction} -> Final={final_direction} | "
            f"Strength={trend_strength:.2f} | Slope={trend_slope:.4f}"
        )
        self.send_limit_order(final_direction, open_price, self.trade_volume, "OPEN")

        # Record state for the stop logic and for exit logs.
        self.entry_price = open_price
        self.pos_highest = open_price
        self.pos_lowest = open_price
        self.entry_signal_source = source_tag

    def manage_open_position(self, volume: int, trend_strength: float, current_price: float,
                             current_atr: float, high_price: float, low_price: float):
        """
        Exit logic for an open position (logs include the entry signal source).

        Args:
            volume: Signed position; > 0 is long, otherwise short.
            trend_strength: Current low-frequency energy share.
            current_price: Price compared against the stop line (the bar open).
            current_atr: Latest ATR; 0 disables the Chandelier stop.
            high_price: High of the last completed bar (updates the long extreme).
            low_price: Low of the last completed bar (updates the short extreme).

        Closes the whole position with a market order when either the
        Chandelier stop or the energy-fade exit triggers.
        """
        exit_dir = "CLOSE_LONG" if volume > 0 else "CLOSE_SHORT"

        # Update the favourable extreme since entry (0 means "not yet set").
        if volume > 0:
            if high_price > self.pos_highest or self.pos_highest == 0:
                self.pos_highest = high_price
        else:
            if (low_price < self.pos_lowest or self.pos_lowest == 0) and low_price > 0:
                self.pos_lowest = low_price

        # 1. Structural stop (Chandelier Exit): trail the extreme by N x ATR.
        is_stop_triggered = False
        stop_line = 0.0
        if current_atr > 0:
            stop_distance = current_atr * self.sl_atr_multiplier
            if volume > 0:
                stop_line = self.pos_highest - stop_distance
                if current_price <= stop_line:
                    is_stop_triggered = True
            else:
                stop_line = self.pos_lowest + stop_distance
                if current_price >= stop_line:
                    is_stop_triggered = True

        if is_stop_triggered:
            self.log(
                f"STOP Loss ({self.entry_signal_source}): Price {current_price} hit Chandelier line {stop_line:.2f}")
            self.close_position(exit_dir, abs(volume))
            return

        # 2. Signal exhaustion: low-frequency energy share has faded.
        if trend_strength < self.exit_threshold:
            self.log(
                f"Exit Signal ({self.entry_signal_source}): Energy Faded {trend_strength:.2f} < {self.exit_threshold}")
            self.close_position(exit_dir, abs(volume))
            return

    # --- Order helpers ---
    def close_position(self, direction: str, volume: int):
        """Close ``volume`` units with a market order (offset "CLOSE")."""
        self.send_market_order(direction, volume, offset="CLOSE")

    def send_market_order(self, direction: str, volume: int, offset: str):
        """Build and submit a MARKET order with a locally unique id."""
        order_id = f"{self.symbol}_{direction}_MKT_{self.order_id_counter}"
        self.order_id_counter += 1
        order = Order(
            id=order_id, symbol=self.symbol, direction=direction, volume=volume,
            price_type="MARKET", submitted_time=self.get_current_time(), offset=offset
        )
        self.send_order(order)

    def send_limit_order(self, direction: str, limit_price: float, volume: int, offset: str):
        """Build and submit a LIMIT order at ``limit_price`` with a locally unique id."""
        order_id = f"{self.symbol}_{direction}_LMT_{self.order_id_counter}"
        self.order_id_counter += 1
        order = Order(
            id=order_id, symbol=self.symbol, direction=direction, volume=volume,
            price_type="LIMIT", submitted_time=self.get_current_time(), offset=offset,
            limit_price=limit_price
        )
        self.send_order(order)