255 lines
10 KiB
Python
255 lines
10 KiB
Python
import numpy as np
|
||
import talib
|
||
from scipy.signal import stft
|
||
from datetime import datetime, timedelta
|
||
from typing import Optional, Any, List, Dict
|
||
|
||
from src.core_data import Bar, Order
|
||
from src.indicators.base_indicators import Indicator
|
||
from src.indicators.indicators import Empty
|
||
from src.strategies.base_strategy import Strategy
|
||
|
||
|
||
class SpectralTrendStrategy(Strategy):
    """Spectral-energy regime strategy (minimal regression variant).

    Core idea:
      1. Frequency domain (STFT) judges the *regime*: is the market currently
         noise-dominated (ranging) or low-frequency-dominated (trending)?
      2. Time domain (linear regression) gives the *direction*: is that
         low-frequency trend pointing up or down?

    Combining the two avoids the complexity and instability of estimating
    phase in the frequency domain.
    """

    def __init__(
        self,
        context: Any,
        main_symbol: str,
        enable_log: bool,
        trade_volume: int,
        # --- Market parameters ---
        bars_per_day: int = 23,
        # --- Strategy parameters ---
        spectral_window_days: float = 2.0,
        low_freq_days: float = 2.0,
        high_freq_days: float = 1.0,
        trend_strength_threshold: float = 0.2,  # entry threshold on trend strength
        exit_threshold: float = 0.1,  # exit threshold on trend strength
        slope_threshold: float = 0.0,  # e.g. 0.05 = 0.05 std-devs per bar
        max_hold_days: int = 10,
        # --- Misc ---
        order_direction: Optional[List[str]] = None,
        indicators: Optional[Indicator] = None,
        model_indicator: Optional[Indicator] = None,
        reverse: bool = False,
    ):
        """Store the configuration and derive the analysis window.

        Args:
            context: Forwarded to ``Strategy.__init__``.
            main_symbol: Forwarded to ``Strategy.__init__``.
            enable_log: Forwarded to ``Strategy.__init__``.
            trade_volume: Volume used for every opening order.
            bars_per_day: Number of bars per trading day; used as the STFT
                sampling rate so frequencies are in cycles/day.
            spectral_window_days: Analysis window length in days.
            low_freq_days: Periods longer than this count as "low frequency".
            high_freq_days: Periods shorter than this count as "high frequency".
            trend_strength_threshold: Minimum low-frequency energy share
                required to open a position.
            exit_threshold: Energy share below which an open position is closed.
            slope_threshold: Minimum absolute regression slope (normalized
                price units per bar) required to pick a direction.
            max_hold_days: Positions are force-closed after this many days.
            order_direction: Allowed entry sides; defaults to both
                ``'BUY'`` and ``'SELL'``.
            indicators: Optional entry filter; falls back to ``Empty()``.
            model_indicator: Optional indicator that flips the direction when
                its condition is not met; falls back to ``Empty()``.
            reverse: If true, the final direction is inverted.
        """
        super().__init__(context, main_symbol, enable_log)
        if order_direction is None:
            order_direction = ['BUY', 'SELL']

        self.trade_volume = trade_volume
        self.bars_per_day = bars_per_day
        self.spectral_window_days = spectral_window_days
        self.low_freq_days = low_freq_days
        self.high_freq_days = high_freq_days
        self.trend_strength_threshold = trend_strength_threshold
        self.exit_threshold = exit_threshold
        self.slope_threshold = slope_threshold
        self.max_hold_days = max_hold_days
        self.order_direction = order_direction
        self.model_indicator = model_indicator or Empty()
        self.indicators = indicators or Empty()
        self.reverse = reverse

        # The order/position helpers read self.symbol, which is otherwise only
        # assigned in on_open_bar; seed it so they cannot hit AttributeError
        # when called earlier. Anything the base class already set is kept.
        if not hasattr(self, "symbol"):
            self.symbol = main_symbol

        # Window size in bars, rounded up to an even number (preferred by STFT).
        self.spectral_window = int(self.spectral_window_days * self.bars_per_day)
        if self.spectral_window % 2 != 0:
            self.spectral_window += 1

        # Band edges in cycles/day. A non-positive period disables the bound.
        self.low_freq_bound = 1.0 / self.low_freq_days if self.low_freq_days > 0 else float('inf')
        self.high_freq_bound = 1.0 / self.high_freq_days if self.high_freq_days > 0 else 0.0

        self.order_id_counter = 0
        self.entry_time = None
        self.position_direction = None

        self.log(f"SpectralTrendStrategy (Regression) Init. Window: {self.spectral_window} bars")

    def on_open_bar(self, open_price: float, symbol: str):
        """Per-bar entry point: cancel pending orders, then enter or manage.

        Args:
            open_price: Open price of the new bar; used as the entry limit price.
            symbol: Symbol of the bar; stored on ``self.symbol`` for the
                order helpers.
        """
        self.symbol = symbol
        bar_history = self.get_bar_history()
        current_time = self.get_current_time()

        self.cancel_all_pending_orders(self.main_symbol)

        # Not enough history yet for one full analysis window (plus margin).
        if len(bar_history) < self.spectral_window + 5:
            return

        # Forced exit once the maximum holding period has elapsed.
        # NOTE(review): entry_time is stamped when the entry order is *sent*,
        # not when it fills — confirm that a stale timestamp from an unfilled
        # order is acceptable here.
        if self.entry_time and (current_time - self.entry_time) >= timedelta(days=self.max_hold_days):
            self.close_all_positions()
            self.entry_time = None
            self.position_direction = None
            return

        # Closing prices of the most recent analysis window.
        closes = np.array(
            [bar.close for bar in bar_history[-self.spectral_window:]],
            dtype=float,
        )

        trend_strength, trend_slope = self.calculate_market_state(closes)

        position_volume = self.get_current_positions().get(self.symbol, 0)

        if not self.trading:
            return
        if position_volume == 0:
            self.evaluate_entry_signal(open_price, trend_strength, trend_slope)
        else:
            self.manage_open_position(position_volume, trend_strength, trend_slope)

    def calculate_market_state(self, prices: np.ndarray) -> "tuple[float, float]":
        """Compute the low-frequency energy share and the regression slope.

        Steps:
          1. Z-score the prices inside the analysis window.
          2. Run an STFT with ``fs = bars_per_day`` (frequencies in cycles/day).
          3. Split the spectrum of the latest segment into a low band
             (``f < low_freq_bound``) and a high band (``f > high_freq_bound``).
          4. Trend strength = low energy / (low energy + high energy).
          5. Fit a least-squares line to the normalized prices; its slope is
             the change in standard deviations per bar.

        Args:
            prices: Price series; only the last ``spectral_window`` values
                are used.

        Returns:
            ``(trend_strength, slope)``. ``(0.0, 0.0)`` is returned when the
            input is too short, contains non-finite values, or the STFT fails
            or yields no segment.
        """
        neutral = (0.0, 0.0)

        prices = np.asarray(prices, dtype=float)
        if len(prices) < self.spectral_window:
            return neutral

        window_data = prices[-self.spectral_window:]
        # NaN/inf would propagate into the spectrum and can make the
        # least-squares fit below fail; treat such a window as "no signal".
        if not np.isfinite(window_data).all():
            return neutral

        # Z-score using only in-window data; epsilon guards a flat window.
        normalized = (window_data - np.mean(window_data)) / (np.std(window_data) + 1e-8)

        try:
            f, t, Zxx = stft(
                normalized,
                fs=self.bars_per_day,  # samples per day -> f is in cycles/day
                nperseg=self.spectral_window,
                noverlap=max(0, self.spectral_window // 2),
                boundary=None,
                padded=False,
            )
        except Exception as e:
            # Best effort: a failed transform is reported and treated as "no signal".
            self.log(f"STFT calculation error: {str(e)}")
            return neutral

        # Keep only frequencies inside [0, fs/2].
        valid_mask = (f >= 0) & (f <= self.bars_per_day / 2)
        f = f[valid_mask]
        Zxx = Zxx[valid_mask, :]

        if Zxx.size == 0 or Zxx.shape[1] == 0:
            return neutral

        # Power spectrum of the most recent segment.
        current_energy = np.abs(Zxx[:, -1]) ** 2

        # Low band: period > low_freq_days  <=> frequency < 1 / low_freq_days.
        # High band: period < high_freq_days <=> frequency > 1 / high_freq_days.
        # NOTE(review): with the defaults the frequency step looks to be
        # fs / nperseg = 0.5 cycles/day, so `f < 0.5` would select only the
        # DC bin — confirm this is intended.
        low_freq_mask = f < self.low_freq_bound
        high_freq_mask = f > self.high_freq_bound

        low_energy = np.sum(current_energy[low_freq_mask])
        high_energy = np.sum(current_energy[high_freq_mask])
        total_energy = low_energy + high_energy + 1e-8  # avoid division by zero

        trend_strength = low_energy / total_energy

        # Direction: least-squares line y = k*x + b over bar index x.
        x = np.arange(len(normalized))
        slope, _intercept = np.polyfit(x, normalized, 1)

        return float(trend_strength), float(slope)

    @staticmethod
    def _flip(direction: str) -> str:
        """Return the opposite entry side ("BUY" <-> "SELL")."""
        return "SELL" if direction == "BUY" else "BUY"

    def evaluate_entry_signal(self, open_price: float, trend_strength: float, trend_slope: float):
        """Open a position when the regime is trending and the slope is clear.

        Entry requires ``trend_strength`` above ``trend_strength_threshold``
        (frequency-domain gate) and a slope beyond ``slope_threshold`` in an
        allowed direction (regression gate). The auxiliary ``indicators``
        filter can veto the trade; ``model_indicator`` and ``reverse`` can
        each flip the side.

        Args:
            open_price: Limit price of the entry order.
            trend_strength: Low-frequency energy share.
            trend_slope: Regression slope of the normalized prices.
        """
        # 1. Frequency-domain gate: skip noisy / ranging markets.
        if not trend_strength > self.trend_strength_threshold:
            return

        # 2. Regression gate: pick the side from the slope sign.
        direction = None
        if "BUY" in self.order_direction and trend_slope > self.slope_threshold:
            direction = "BUY"
        elif "SELL" in self.order_direction and trend_slope < -self.slope_threshold:
            direction = "SELL"

        if not direction:
            return

        # Auxiliary indicator filter.
        if not self.indicators.is_condition_met(*self.get_indicator_tuple()):
            return

        # Direction flips: first by the model indicator, then by the reverse flag.
        if not self.model_indicator.is_condition_met(*self.get_indicator_tuple()):
            direction = self._flip(direction)
        if self.reverse:
            direction = self._flip(direction)

        self.log(f"Signal: {direction} | Strength={trend_strength:.2f} | Slope={trend_slope:.4f}")

        self.send_limit_order(direction, open_price, self.trade_volume, "OPEN")
        self.entry_time = self.get_current_time()
        self.position_direction = "LONG" if direction == "BUY" else "SHORT"

    def manage_open_position(self, volume: int, trend_strength: float, trend_slope: float):
        """Close the open position once low-frequency energy fades.

        The exit depends only on the frequency-domain reading: while the
        low-frequency share stays at or above ``exit_threshold`` the position
        is held; once it drops below, the position is closed.

        Args:
            volume: Signed position size (positive = long, negative = short).
            trend_strength: Low-frequency energy share.
            trend_slope: Unused here; kept for a uniform signature.
        """
        if trend_strength >= self.exit_threshold:
            return

        direction = "CLOSE_LONG" if volume > 0 else "CLOSE_SHORT"
        self.log(f"Exit: {direction} | Strength={trend_strength:.2f} < {self.exit_threshold}")
        self.close_position(direction, abs(volume))
        self.entry_time = None
        self.position_direction = None

    # --- Trading helpers ---

    def close_all_positions(self):
        """Close any non-zero position held in ``self.symbol`` at market."""
        held = self.get_current_positions().get(self.symbol, 0)
        if held != 0:
            close_side = "CLOSE_LONG" if held > 0 else "CLOSE_SHORT"
            self.close_position(close_side, abs(held))

    def close_position(self, direction: str, volume: int):
        """Send a market order with offset ``"CLOSE"``."""
        self.send_market_order(direction, volume, offset="CLOSE")

    def _next_order_id(self, direction: str, tag: str) -> str:
        """Build an order id from symbol, side and tag, then bump the counter."""
        order_id = f"{self.symbol}_{direction}_{tag}_{self.order_id_counter}"
        self.order_id_counter += 1
        return order_id

    def send_market_order(self, direction: str, volume: int, offset: str):
        """Build and submit a MARKET order for ``self.symbol``."""
        order = Order(
            id=self._next_order_id(direction, "MKT"),
            symbol=self.symbol,
            direction=direction,
            volume=volume,
            price_type="MARKET",
            submitted_time=self.get_current_time(),
            offset=offset,
        )
        self.send_order(order)

    def send_limit_order(self, direction: str, limit_price: float, volume: int, offset: str):
        """Build and submit a LIMIT order for ``self.symbol`` at ``limit_price``."""
        order = Order(
            id=self._next_order_id(direction, "LMT"),
            symbol=self.symbol,
            direction=direction,
            volume=volume,
            price_type="LIMIT",
            submitted_time=self.get_current_time(),
            offset=offset,
            limit_price=limit_price,
        )
        self.send_order(order)