312 lines
12 KiB
Python
312 lines
12 KiB
Python
import numpy as np
|
||
import talib
|
||
from scipy.signal import stft
|
||
from datetime import timedelta
|
||
from typing import Any, List, Optional, Tuple
|
||
|
||
from src.core_data import Bar, Order
|
||
from src.indicators.base_indicators import Indicator
|
||
from src.indicators.indicators import Empty
|
||
from src.strategies.base_strategy import Strategy
|
||
|
||
|
||
class SpectralTrendStrategy(Strategy):
    """
    Spectral energy phase-transition strategy - Taleb-style wide-structure variant (Chandelier Exit).

    Key points of this version:
    1. Slope-based exits are removed to avoid being shaken out by noise; the regression
       slope is only used to pick the entry direction.
    2. State variables track the price extremes reached while a position is held
       (``pos_highest`` for longs, ``pos_lowest`` for shorts).
    3. A "chandelier" stop is applied: exit when price retraces N * ATR from the
       in-position extreme.
    4. Suggested stop multiplier: 4.0 - 6.0 (very wide breathing room; only meant to
       guard against a trend collapse).

    Entry: the low-frequency STFT energy ratio (``trend_strength``) exceeds
    ``trend_strength_threshold`` and the sign of the regression slope gives the side.
    Exit: the chandelier stop fires, or ``trend_strength`` drops below ``exit_threshold``.
    """

    def __init__(
        self,
        context: Any,
        main_symbol: str,
        enable_log: bool,
        trade_volume: int,
        # --- market parameters ---
        bars_per_day: int = 23,
        # --- STFT signal parameters ---
        spectral_window_days: float = 2.0,
        low_freq_days: float = 2.0,
        high_freq_days: float = 1.0,
        trend_strength_threshold: float = 0.2,  # entry energy-ratio threshold
        exit_threshold: float = 0.1,  # exit when the energy ratio decays below this
        slope_threshold: float = 0.0,  # direction filter only, never used for exits
        # --- key risk parameters (Chandelier Exit) ---
        stop_loss_atr_multiplier: float = 5.0,  # Taleb-style wide stop, 4.0 ~ 6.0 suggested
        stop_loss_atr_period: int = 14,
        # --- misc ---
        order_direction: Optional[List[str]] = None,
        indicators: Indicator = None,
        model_indicator: Indicator = None,
        reverse: bool = False,
    ):
        """
        Args:
            context, main_symbol, enable_log: Forwarded to ``Strategy.__init__``.
            trade_volume: Volume submitted with each entry order.
            bars_per_day: Bars per trading day. Used as the STFT sampling rate, so all
                frequencies below are in cycles per day.
            spectral_window_days: STFT window length in days. Converted to bars with
                ``int()`` and bumped by one bar if the result is odd.
            low_freq_days: Frequencies below ``1 / low_freq_days`` count as low-frequency
                (trend) energy; a non-positive value puts every bin in the low band.
            high_freq_days: Frequencies above ``1 / high_freq_days`` count as
                high-frequency (noise) energy; a non-positive value puts every bin
                above 0 in the high band.
            trend_strength_threshold: Minimum energy ratio required to enter.
            exit_threshold: Energy ratio below which an open position is closed.
            slope_threshold: Minimum |slope| of the normalized window for a direction.
            stop_loss_atr_multiplier: N in the ``extreme -/+ N * ATR`` stop line.
            stop_loss_atr_period: ATR look-back period.
            order_direction: Raw signal sides allowed; defaults to ``['BUY', 'SELL']``.
            indicators: Entry filter; no entry when ``is_condition_met`` is False.
                Defaults to ``Empty()``.
            model_indicator: When its ``is_condition_met`` is False the entry side is
                flipped. Defaults to ``Empty()``.
            reverse: Flip every entry side.
        """
        super().__init__(context, main_symbol, enable_log)
        if order_direction is None:
            order_direction = ['BUY', 'SELL']

        self.trade_volume = trade_volume
        self.bars_per_day = bars_per_day

        # Signal parameters
        self.spectral_window_days = spectral_window_days
        self.low_freq_days = low_freq_days
        self.high_freq_days = high_freq_days
        self.trend_strength_threshold = trend_strength_threshold
        self.exit_threshold = exit_threshold
        self.slope_threshold = slope_threshold

        # Risk parameters
        self.sl_atr_multiplier = stop_loss_atr_multiplier
        self.sl_atr_period = stop_loss_atr_period

        self.order_direction = order_direction
        self.model_indicator = model_indicator or Empty()
        self.indicators = indicators or Empty()
        self.reverse = reverse

        # STFT window size in bars, forced to be even.
        self.spectral_window = int(self.spectral_window_days * self.bars_per_day)
        if self.spectral_window % 2 != 0:
            self.spectral_window += 1

        # Band edges in cycles per day.
        self.low_freq_bound = 1.0 / self.low_freq_days if self.low_freq_days > 0 else float('inf')
        self.high_freq_bound = 1.0 / self.high_freq_days if self.high_freq_days > 0 else 0.0

        self.order_id_counter = 0

        # --- per-position state ---
        self.entry_price = 0.0
        self.pos_highest = 0.0  # highest price seen while long (0.0 = not initialised)
        self.pos_lowest = 0.0  # lowest price seen while short (0.0 = not initialised)

        self.log(
            f"SpectralTrend Strategy Initialized. Window: {self.spectral_window}, "
            f"Chandelier Stop: {self.sl_atr_multiplier}x ATR"
        )

    def on_open_bar(self, open_price: float, symbol: str):
        """
        Per-bar entry point, called with the opening price of the new bar.

        Cancels pending orders, computes ATR and the STFT market state from the bar
        history, then (only when ``self.trading`` is truthy) either looks for an entry
        while flat or manages the open position.
        """
        self.symbol = symbol
        bar_history = self.get_bar_history()

        self.cancel_all_pending_orders(self.main_symbol)

        # 1. Need enough bars for both the STFT window and the ATR warm-up.
        required_len = max(self.spectral_window, self.sl_atr_period + 5)
        if len(bar_history) < required_len:
            return

        # 2. ATR drives the chandelier stop distance.
        current_atr = self._compute_atr(bar_history)

        # 3. Core STFT indicators.
        stft_closes = np.array([b.close for b in bar_history[-self.spectral_window:]], dtype=float)
        trend_strength, trend_slope = self.calculate_market_state(stft_closes)

        # 4. Trading logic.
        position_volume = self.get_current_positions().get(self.symbol, 0)

        # For robustness the extremes are updated from the previous bar's high/low
        # (intrabar values of the current bar would be more responsive).
        last_bar = bar_history[-1]
        current_high = last_bar.high
        current_low = last_bar.low

        if not self.trading:
            return

        if position_volume == 0:
            # Flat: clear stale per-position state before looking for an entry.
            self.pos_highest = 0.0
            self.pos_lowest = 0.0
            self.entry_price = 0.0
            self.evaluate_entry_signal(open_price, trend_strength, trend_slope)
        else:
            # current_high/current_low update the trailing-stop anchor.
            self.manage_open_position(
                position_volume,
                trend_strength,
                open_price,
                current_atr,
                current_high,
                current_low,
            )

    def _compute_atr(self, bar_history) -> float:
        """Return the latest ATR(``sl_atr_period``) over recent bars, or 0.0 if unavailable."""
        # A few extra bars beyond the period so TA-Lib's look-back is covered.
        atr_window = self.sl_atr_period + 10
        recent = bar_history[-atr_window:]
        highs = np.array([b.high for b in recent], dtype=float)
        lows = np.array([b.low for b in recent], dtype=float)
        closes = np.array([b.close for b in recent], dtype=float)

        try:
            current_atr = talib.ATR(highs, lows, closes, timeperiod=self.sl_atr_period)[-1]
        except Exception as e:
            self.log(f"ATR Calc Error: {e}")
            return 0.0

        # A NaN ATR is treated as "no valid ATR", which disables the stop for this bar.
        return 0.0 if np.isnan(current_atr) else float(current_atr)

    def calculate_market_state(self, prices: np.ndarray) -> Tuple[float, float]:
        """
        Compute the low-frequency energy ratio and a direction slope.

        The last ``spectral_window`` prices are z-score normalized and fed to an STFT
        whose segment length equals the window (so there is exactly one segment),
        sampled at ``bars_per_day`` so frequencies are in cycles/day. The power of the
        last time slice is split into a low band (f < ``low_freq_bound``) and a high
        band (f > ``high_freq_bound``); bins in between are ignored.

        Args:
            prices: Close prices (1-D), at least ``spectral_window`` long.

        Returns:
            ``(trend_strength, slope)`` where ``trend_strength = low / (low + high)``
            (in [0, 1)) and ``slope`` is the least-squares slope of the normalized
            window, used only for direction. ``(0.0, 0.0)`` when there is not enough
            data, no valid frequency bins, or the STFT raises.
        """
        window_size = self.spectral_window
        if len(prices) < window_size:
            return 0.0, 0.0

        # Z-score normalization removes the absolute price level. A flat window
        # (std == 0) is divided by ~1 instead, mapping it to zeros.
        window_data = prices[-window_size:]
        mean_val = np.mean(window_data)
        std_val = np.std(window_data)
        scale = (std_val if std_val != 0 else 1.0) + 1e-8
        normalized = (window_data - mean_val) / scale

        try:
            f, _, Zxx = stft(
                normalized,
                fs=self.bars_per_day,
                nperseg=window_size,
                noverlap=max(0, window_size // 2),
                boundary=None,
                padded=False,
            )
        except Exception as e:
            # Degrade to "no trend" instead of crashing, but leave a trace: a 0.0
            # strength makes manage_open_position take the signal exit.
            self.log(f"STFT Calc Error: {e}")
            return 0.0, 0.0

        # Keep frequencies between 0 and Nyquist (cycles/day).
        valid_mask = (f >= 0) & (f <= self.bars_per_day / 2)
        f = f[valid_mask]
        Zxx = Zxx[valid_mask, :]
        if Zxx.size == 0:
            return 0.0, 0.0

        # Power spectrum of the most recent time slice.
        # NOTE(review): with the defaults (46-bar window, 23 bars/day) the bin spacing is
        # 0.5 cycles/day, so the low band (f < 0.5) holds only the 0 Hz bin of the
        # mean-removed window - verify the band edges are what was intended.
        current_energy = np.abs(Zxx[:, -1]) ** 2
        low_energy = current_energy[f < self.low_freq_bound].sum()
        high_energy = current_energy[f > self.high_freq_bound].sum()
        trend_strength = low_energy / (low_energy + high_energy + 1e-8)

        # Linear-fit slope of the normalized window; only used to choose the side.
        x = np.arange(len(normalized))
        slope, _ = np.polyfit(x, normalized, 1)

        return trend_strength, slope

    def evaluate_entry_signal(self, open_price: float, trend_strength: float, trend_slope: float):
        """
        Entry logic: energy ratio above threshold, slope sign picks the side.

        When an entry is taken, a limit order for ``trade_volume`` at ``open_price`` is
        sent and the per-position state (entry price and both extremes) is seeded with
        ``open_price``.
        """
        # Require enough low-frequency dominance (written this way so NaN is rejected too).
        if not trend_strength > self.trend_strength_threshold:
            return

        # The slope only chooses the side (with a small dead zone); it is never an exit.
        direction = None
        if "BUY" in self.order_direction and trend_slope > self.slope_threshold:
            direction = "BUY"
        elif "SELL" in self.order_direction and trend_slope < -self.slope_threshold:
            direction = "SELL"
        if not direction:
            return

        # External filter.
        if not self.indicators.is_condition_met(*self.get_indicator_tuple()):
            return
        # NOTE(review): these flips run after the order_direction filter, so they can
        # yield a side that order_direction does not list - confirm this is intended.
        if not self.model_indicator.is_condition_met(*self.get_indicator_tuple()):
            direction = "SELL" if direction == "BUY" else "BUY"
        if self.reverse:
            direction = "SELL" if direction == "BUY" else "BUY"

        self.log(f"Entry: {direction} | Strength={trend_strength:.2f} | DirSlope={trend_slope:.4f}")

        self.send_limit_order(direction, open_price, self.trade_volume, "OPEN")

        # Seed per-position state; the extremes start at the order price.
        self.entry_price = open_price
        self.pos_highest = open_price
        self.pos_lowest = open_price

    def manage_open_position(self, volume: int, trend_strength: float, current_price: float,
                             current_atr: float, high_price: float, low_price: float):
        """
        Exit logic for an open position.

        1. Structural exit (checked first): chandelier stop at N * ATR from the
           in-position extreme.
        2. Signal exit: the energy ratio has decayed below ``exit_threshold``.

        Args:
            volume: Signed position size (> 0 long, otherwise short).
            trend_strength: Energy ratio from ``calculate_market_state``.
            current_price: Price compared against the stop line.
            current_atr: Latest ATR; the stop is skipped when it is not positive.
            high_price: High used to update ``pos_highest`` while long.
            low_price: Low used to update ``pos_lowest`` while short.
        """
        is_long = volume > 0
        exit_dir = "CLOSE_LONG" if is_long else "CLOSE_SHORT"

        # Ratchet the in-position extreme (0 means "not initialised yet").
        if is_long:
            if high_price > self.pos_highest or self.pos_highest == 0:
                self.pos_highest = high_price
        elif (low_price < self.pos_lowest or self.pos_lowest == 0) and low_price > 0:
            self.pos_lowest = low_price

        # --- 1. Wide chandelier stop (structural stop) ---
        # Without a valid ATR the stop is skipped (a percentage fallback could go here).
        if current_atr > 0:
            stop_distance = current_atr * self.sl_atr_multiplier
            if is_long:
                # Long stop = highest - N * ATR. The anchor only ratchets up; the line
                # itself also moves with the current ATR.
                stop_line = self.pos_highest - stop_distance
                if current_price <= stop_line:
                    self.log(
                        f"STOP (Long): Price {current_price} <= Highest {self.pos_highest} - {self.sl_atr_multiplier}xATR")
                    self.close_position(exit_dir, abs(volume))
                    return  # the stop takes priority
            else:
                # Short stop = lowest + N * ATR.
                stop_line = self.pos_lowest + stop_distance
                if current_price >= stop_line:
                    self.log(
                        f"STOP (Short): Price {current_price} >= Lowest {self.pos_lowest} + {self.sl_atr_multiplier}xATR")
                    self.close_position(exit_dir, abs(volume))
                    return  # the stop takes priority

        # --- 2. Natural signal exit ---
        # Low-frequency energy has dissipated (choppy / noisy market). This is the "slow"
        # exit, typically taken after the trend has run its course.
        if trend_strength < self.exit_threshold:
            self.log(f"Exit (Signal): Strength {trend_strength:.2f} < {self.exit_threshold}")
            self.close_position(exit_dir, abs(volume))

    # --- order execution helpers ---

    def close_position(self, direction: str, volume: int):
        """Close ``volume`` with a market order (``direction`` is CLOSE_LONG / CLOSE_SHORT)."""
        self.send_market_order(direction, volume, offset="CLOSE")

    def _next_order_id(self, direction: str, tag: str) -> str:
        """Return ``<symbol>_<direction>_<tag>_<n>`` and advance the order counter."""
        order_id = f"{self.symbol}_{direction}_{tag}_{self.order_id_counter}"
        self.order_id_counter += 1
        return order_id

    def send_market_order(self, direction: str, volume: int, offset: str):
        """Build and submit a MARKET order for the current symbol."""
        order = Order(
            id=self._next_order_id(direction, "MKT"), symbol=self.symbol, direction=direction, volume=volume,
            price_type="MARKET", submitted_time=self.get_current_time(), offset=offset
        )
        self.send_order(order)

    def send_limit_order(self, direction: str, limit_price: float, volume: int, offset: str):
        """Build and submit a LIMIT order at ``limit_price`` for the current symbol."""
        order = Order(
            id=self._next_order_id(direction, "LMT"), symbol=self.symbol, direction=direction, volume=volume,
            price_type="LIMIT", submitted_time=self.get_current_time(), offset=offset,
            limit_price=limit_price
        )
        self.send_order(order)