1、新增傅里叶策略

2、新增策略管理、策略重启功能
This commit is contained in:
2025-11-20 16:10:16 +08:00
parent 2ae9f2db9e
commit 2c917a467a
19 changed files with 3368 additions and 6643 deletions

View File

@@ -4,19 +4,21 @@ from typing import List, Union, Tuple, Optional
import numpy as np
import talib
from numpy.lib._stride_tricks_impl import sliding_window_view
from scipy import stats
from src.indicators.base_indicators import Indicator
class Empty(Indicator, ABC):
def get_values(self, close: np.array, open: np.array, high: np.array, low: np.array, volume: np.array):
return []
def is_condition_met(self,
close: np.array,
open: np.array,
high: np.array,
low: np.array,
volume: np.array):
close: np.array,
open: np.array,
high: np.array,
low: np.array,
volume: np.array):
return True
def get_name(self):
@@ -402,7 +404,7 @@ class PriceRangeToVolatilityRatio(Indicator):
return ratio
def _rolling_max(self,arr: np.array, window: int) -> np.array:
def _rolling_max(self, arr: np.array, window: int) -> np.array:
if len(arr) < window:
return np.full_like(arr, np.nan)
@@ -591,15 +593,17 @@ class ROC_MA(Indicator):
"""
return f"roc_ma_{self.roc_window}_{self.ma_window}"
from numpy.lib.stride_tricks import sliding_window_view
class ZScoreATR(Indicator):
def __init__(
self,
atr_window: int = 14,
z_window: int = 100,
down_bound: float = None,
up_bound: float = None,
self,
atr_window: int = 14,
z_window: int = 100,
down_bound: float = None,
up_bound: float = None,
):
super().__init__(down_bound, up_bound)
self.atr_window = atr_window
@@ -616,7 +620,7 @@ class ZScoreATR(Indicator):
# Step 2: 只对有效区域计算 z-score
start_idx = self.atr_window - 1 # ATR 从这里开始非 NaN
valid_atr = atr[start_idx:] # shape: (n - start_idx,)
valid_atr = atr[start_idx:] # shape: (n - start_idx,)
valid_n = len(valid_atr)
if valid_n < self.z_window:
@@ -627,8 +631,8 @@ class ZScoreATR(Indicator):
windows = sliding_window_view(valid_atr, window_shape=self.z_window)
# Step 4: 向量化计算均值和标准差(沿窗口轴)
means = np.mean(windows, axis=1) # shape: (M,)
stds = np.std(windows, axis=1, ddof=0) # shape: (M,)
means = np.mean(windows, axis=1) # shape: (M,)
stds = np.std(windows, axis=1, ddof=0) # shape: (M,)
# Step 5: 计算 z-score当前值是窗口最后一个元素
current_vals = valid_atr[self.z_window - 1:] # 对齐窗口末尾
@@ -647,4 +651,599 @@ class ZScoreATR(Indicator):
return result
def get_name(self):
return f"z_atr_{self.atr_window}_{self.z_window}"
return f"z_atr_{self.atr_window}_{self.z_window}"
from scipy.signal import stft
class FFTTrendStrength(Indicator):
"""
傅里叶趋势强度指标 (FFT_TrendStrength)
该指标通过短时傅里叶变换(STFT)计算低频能量占比,量化趋势强度。
低频能量占比越高,趋势越强;当该值在不同波动率环境下变化时,
往往预示策略转折点。
"""
def __init__(
self,
spectral_window: int = 46, # 2天×23根/天
low_freq_days: float = 2.0, # 低频定义下限(天)
bars_per_day: int = 23, # 每日K线数量
down_bound: float = None,
up_bound: float = None,
shift_window: int = 0,
):
"""
初始化 FFT_TrendStrength 指标。
Args:
spectral_window (int): STFT窗口大小(根K线)
low_freq_days (float): 低频定义下限(天)
bars_per_day (int): 每日K线数量
down_bound (float): (可选) 用于条件判断的下轨
up_bound (float): (可选) 用于条件判断的上轨
shift_window (int): (可选) 指标值的时间偏移
"""
super().__init__(down_bound, up_bound)
self.spectral_window = spectral_window
self.low_freq_days = low_freq_days
self.bars_per_day = bars_per_day
self.shift_window = shift_window
def get_values(
self,
close: np.array,
open: np.array,
high: np.array,
low: np.array,
volume: np.array,
) -> np.array:
"""
计算傅里叶趋势强度值。
Args:
close (np.array): 收盘价列表
其他参数保留接口兼容性本指标仅使用close
Returns:
np.array: 趋势强度值列表(0~1)数据不足时为NaN
"""
n = len(close)
trend_strengths = np.full(n, np.nan)
# 验证最小数据要求
min_required = self.spectral_window + 5
if n < min_required:
return trend_strengths
# 频率边界计算
low_freq_bound = 1.0 / self.low_freq_days if self.low_freq_days > 0 else float('inf')
# 为每个时间点计算趋势强度
for i in range(min_required - 1, n):
# 获取窗口内数据
window_data = close[max(0, i - self.spectral_window + 1): i + 1]
# 跳过数据不足的窗口
if len(window_data) < self.spectral_window:
continue
# 价格归一化
window_mean = np.mean(window_data)
window_std = np.std(window_data)
if window_std < 1e-8:
continue
normalized = (window_data - window_mean) / window_std
try:
# STFT计算
f, t, Zxx = stft(
normalized,
fs=self.bars_per_day,
nperseg=self.spectral_window,
noverlap=max(0, self.spectral_window // 2),
boundary=None,
padded=False
)
# 频率过滤
max_freq = self.bars_per_day / 2
valid_mask = (f >= 0) & (f <= max_freq)
if not np.any(valid_mask):
continue
f = f[valid_mask]
Zxx = Zxx[valid_mask, :]
if Zxx.shape[1] == 0:
continue
# 能量计算
current_energy = np.abs(Zxx[:, -1]) ** 2
low_freq_mask = f < low_freq_bound
high_freq_mask = f > 1.0 # 高频: <1天周期
low_energy = np.sum(current_energy[low_freq_mask]) if np.any(low_freq_mask) else 0.0
high_energy = np.sum(current_energy[high_freq_mask]) if np.any(high_freq_mask) else 0.0
total_energy = low_energy + high_energy + 1e-8
trend_strength = low_energy / total_energy
trend_strengths[i] = np.clip(trend_strength, 0.0, 1.0)
except Exception:
continue
# 应用时间偏移
if self.shift_window > 0 and len(trend_strengths) > self.shift_window:
trend_strengths = np.roll(trend_strengths, -self.shift_window)
trend_strengths[-self.shift_window:] = np.nan
return trend_strengths
def get_name(self) -> str:
return f"fft_trend_{self.spectral_window}_{self.low_freq_days}"
class AtrVolatility(Indicator):
"""
波动率环境识别指标 (VolatilityRegime)
该指标识别当前市场处于高波动还是低波动环境,对策略转折点
有强预测能力。在低波动环境下,趋势信号往往失效转为反转。
"""
def __init__(
self,
vol_window: int = 23, # 波动率计算窗口
down_bound: float = None,
up_bound: float = None,
shift_window: int = 0,
):
"""
初始化 VolatilityRegime 指标。
Args:
vol_window (int): ATR波动率计算窗口
high_vol_threshold (float): 高波动阈值(%),高于此值为高波动环境
low_vol_threshold (float): 低波动阈值(%),低于此值为低波动环境
down_bound (float): (可选) 用于条件判断的下轨
up_bound (float): (可选) 用于条件判断的上轨
shift_window (int): (可选) 指标值的时间偏移
"""
super().__init__(down_bound, up_bound)
self.vol_window = vol_window
self.shift_window = shift_window
def get_values(
self,
close: np.array,
open: np.array,
high: np.array,
low: np.array,
volume: np.array,
) -> np.array:
"""
计算波动率环境指标。
返回值含义:
- 1.0: 高波动环境 (趋势策略有效)
- 0.0: 中波动环境 (谨慎)
- -1.0: 低波动环境 (反转策略有效)
Args:
close (np.array): 收盘价列表
high (np.array): 最高价列表
low (np.array): 最低价列表
其他参数保留接口兼容性
Returns:
np.array: 波动率环境标识数据不足时为NaN
"""
n = len(close)
regimes = np.full(n, np.nan)
# 验证最小数据要求
if n < self.vol_window + 1:
return regimes
# 计算ATR
try:
atr = talib.ATR(high, low, close, timeperiod=self.vol_window)
except Exception:
return regimes
# 计算标准化波动率 (%)
volatility = (atr / close) * 100
return volatility
def get_name(self) -> str:
return f"atr_volume_{self.vol_window}"
class FFTPhaseShift(Indicator):
"""
傅里叶相位偏移指标 (FFT_PhaseShift)
该指标检测频域中主导频率的相位偏移,相位突变往往预示市场
趋势的转折点。特别适用于捕捉低波动环境下的价格极端位置。
"""
def __init__(
self,
spectral_window: int = 46, # 2天×23根/天
dominant_freq_bound: float = 0.5, # 主导频率上限(cycles/day)
phase_shift_threshold: float = 1.0, # 相位偏移阈值(弧度)
bars_per_day: int = 23, # 每日K线数量
down_bound: float = None,
up_bound: float = None,
shift_window: int = 0,
):
"""
初始化 FFT_PhaseShift 指标。
Args:
spectral_window (int): STFT窗口大小(根K线)
dominant_freq_bound (float): 主导频率上限(cycles/day)
phase_shift_threshold (float): 相位偏移阈值(弧度)
bars_per_day (int): 每日K线数量
down_bound (float): (可选) 用于条件判断的下轨
up_bound (float): (可选) 用于条件判断的上轨
shift_window (int): (可选) 指标值的时间偏移
"""
super().__init__(down_bound, up_bound)
self.spectral_window = spectral_window
self.dominant_freq_bound = dominant_freq_bound
self.phase_shift_threshold = phase_shift_threshold
self.bars_per_day = bars_per_day
self.shift_window = shift_window
def get_values(
self,
close: np.array,
open: np.array,
high: np.array,
low: np.array,
volume: np.array,
) -> np.array:
"""
计算傅里叶相位偏移值。
返回值含义:
- 1.0: 相位正向偏移(可能预示上涨转折)
- -1.0: 相位负向偏移(可能预示下跌转折)
- 0.0: 无显著相位偏移
Args:
close (np.array): 收盘价列表
其他参数保留接口兼容性本指标仅使用close
Returns:
np.array: 相位偏移标识数据不足时为NaN
"""
n = len(close)
phase_shifts = np.full(n, np.nan)
# 验证最小数据要求
min_required = self.spectral_window + 5
if n < min_required:
return phase_shifts
# 为每个时间点计算相位偏移
prev_phase = None
for i in range(min_required - 1, n):
# 获取窗口内数据
window_data = close[max(0, i - self.spectral_window + 1): i + 1]
if len(window_data) < self.spectral_window:
continue
# 价格归一化
window_mean = np.mean(window_data)
window_std = np.std(window_data)
if window_std < 1e-8:
continue
normalized = (window_data - window_mean) / window_std
try:
# STFT计算
f, t, Zxx = stft(
normalized,
fs=self.bars_per_day,
nperseg=self.spectral_window,
noverlap=max(0, self.spectral_window // 2),
boundary=None,
padded=False
)
# 频率过滤
max_freq = self.bars_per_day / 2
valid_mask = (f >= 0) & (f <= max_freq)
if not np.any(valid_mask):
continue
f = f[valid_mask]
Zxx = Zxx[valid_mask, :]
if Zxx.shape[1] < 2: # 需要至少两个时间点计算相位变化
continue
# 计算相位
phases = np.angle(Zxx[:, -1])
prev_phases = np.angle(Zxx[:, -2])
# 找出主导频率(低频)
low_freq_mask = f < self.dominant_freq_bound
if not np.any(low_freq_mask):
continue
# 计算主导频率的相位差
dominant_idx = np.argmax(np.abs(Zxx[low_freq_mask, -1]))
current_phase = phases[low_freq_mask][dominant_idx]
prev_dominant_phase = prev_phases[low_freq_mask][dominant_idx]
# 计算相位差(考虑2π周期性)
phase_diff = current_phase - prev_dominant_phase
phase_diff = (phase_diff + np.pi) % (2 * np.pi) - np.pi
# 确定相位偏移方向
if np.abs(phase_diff) > self.phase_shift_threshold:
phase_shifts[i] = 1.0 if phase_diff > 0 else -1.0
else:
phase_shifts[i] = 0.0
prev_phase = current_phase
except Exception:
continue
# 应用时间偏移
if self.shift_window > 0 and len(phase_shifts) > self.shift_window:
phase_shifts = np.roll(phase_shifts, -self.shift_window)
phase_shifts[-self.shift_window:] = np.nan
return phase_shifts
def get_name(self) -> str:
return f"fft_phase_{self.spectral_window}_{self.dominant_freq_bound}"
class VolatilitySkew(Indicator):
"""
波动率偏斜指标 (VolatilitySkew)
该指标测量近期波动率分布的偏斜程度,正偏斜表示波动率上升趋势,
负偏斜表示波动率下降趋势。波动率偏斜的变化往往预示策略逻辑
的转折点,特别是在低波动环境向高波动环境转换时。
"""
def __init__(
self,
vol_window: int = 20, # 单期波动率计算窗口
skew_window: int = 60, # 偏斜计算窗口
down_bound: float = None,
up_bound: float = None,
shift_window: int = 0,
):
"""
初始化 VolatilitySkew 指标。
Args:
vol_window (int): ATR波动率计算窗口
skew_window (int): 偏斜计算窗口
positive_threshold (float): 正偏斜阈值
negative_threshold (float): 负偏斜阈值
down_bound (float): (可选) 用于条件判断的下轨
up_bound (float): (可选) 用于条件判断的上轨
shift_window (int): (可选) 指标值的时间偏移
"""
super().__init__(down_bound, up_bound)
self.vol_window = vol_window
self.skew_window = skew_window
self.shift_window = shift_window
def get_values(
self,
close: np.array,
open: np.array,
high: np.array,
low: np.array,
volume: np.array,
) -> np.array:
"""
计算波动率偏斜指标。
返回值含义:
- 1.0: 正偏斜(波动率上升趋势,可能预示高波动环境到来)
- -1.0: 负偏斜(波动率下降趋势,可能预示低波动环境到来)
- 0.0: 无显著偏斜
Args:
close (np.array): 收盘价列表
high (np.array): 最高价列表
low (np.array): 最低价列表
其他参数保留接口兼容性
Returns:
np.array: 波动率偏斜标识数据不足时为NaN
"""
n = len(close)
skews = np.full(n, np.nan)
# 验证最小数据要求
if n < self.vol_window + self.skew_window:
return skews
# 计算ATR
try:
atr = talib.ATR(high, low, close, timeperiod=self.vol_window)
except Exception:
return skews
# 计算标准化波动率 (%)
volatility = (atr / close) * 100
# 计算滚动偏斜
for i in range(self.vol_window + self.skew_window - 1, n):
window_vol = volatility[i - self.skew_window + 1: i + 1]
valid_vol = window_vol[~np.isnan(window_vol)]
if len(valid_vol) < self.skew_window * 0.7: # 要求70%有效数据
continue
# 计算偏斜
skew_value = stats.skew(valid_vol)
skews[i] = skew_value
return skews
def get_name(self) -> str:
return f"vol_skew_{self.vol_window}_{self.skew_window}"
import numpy as np
import talib
from src.indicators.base_indicators import Indicator
class VolatilityTrendRelationship(Indicator):
"""
精准修复版:波动率-趋势关系指标
仅修复NaN问题
1. 保留talib的ATR计算性能和稳定性更优
2. 修复std_val计算中的NaN传播
3. 添加严格的NaN处理确保100%数据有效性
4. 保持原始物理逻辑不变
核心修复点:
- 在计算标准差前过滤NaN值
- 为平滑后的序列提供安全回退值
- 确保所有中间步骤处理NaN
"""
def __init__(
self,
vol_window: int = 20, # 波动率计算窗口
price_lag: int = 3, # 价格自相关滞后
ma_window: int = 5, # 平滑窗口
down_bound: float = None,
up_bound: float = None,
shift_window: int = 0,
):
super().__init__(down_bound, up_bound)
self.vol_window = vol_window
self.price_lag = price_lag
self.ma_window = ma_window
self.shift_window = shift_window
def get_values(
self,
close: np.array,
open: np.array,
high: np.array,
low: np.array,
volume: np.array,
) -> np.array:
n = len(close)
relationship = np.full(n, np.nan)
# 验证最小数据要求
min_required = max(self.vol_window, self.price_lag, self.ma_window) + 5
if n < min_required:
return relationship
# 1. 计算标准化波动率 (使用talib保持性能)
try:
atr = talib.ATR(high, low, close, timeperiod=self.vol_window)
volatility = (atr / close) * 100
except Exception:
return relationship
# 2. 计算波动率变化率 (安全处理除零)
vol_change = np.zeros(n)
for i in range(1, n):
if volatility[i - 1] > 1e-8:
vol_change[i] = (volatility[i] - volatility[i - 1]) / volatility[i - 1]
else:
vol_change[i] = 0.0
# 3. 计算价格自相关 (安全实现)
returns = np.diff(close, prepend=close[0]) / (close + 1e-8)
autocorr = np.zeros(n)
for i in range(self.price_lag, n):
if i < self.price_lag * 2:
continue
window_returns = returns[i - self.price_lag * 2:i + 1]
valid_returns = window_returns[~np.isnan(window_returns)]
if len(valid_returns) < self.price_lag * 1.5:
continue
# 计算自相关
lagged = valid_returns[:-self.price_lag]
current = valid_returns[self.price_lag:]
if len(lagged) == 0 or len(current) == 0:
continue
mean_lagged = np.mean(lagged)
mean_current = np.mean(current)
numerator = np.sum((lagged - mean_lagged) * (current - mean_current))
denom_lagged = np.sum((lagged - mean_lagged) ** 2)
denom_current = np.sum((current - mean_current) ** 2)
if denom_lagged > 1e-8 and denom_current > 1e-8:
autocorr[i] = numerator / np.sqrt(denom_lagged * denom_current)
# 4. 计算核心关系指标
raw_relationship = vol_change * autocorr
# 5. 平滑处理 (处理NaN)
smoothed_relationship = np.full(n, np.nan)
for i in range(self.ma_window - 1, n):
window = raw_relationship[max(0, i - self.ma_window + 1):i + 1]
valid_window = window[~np.isnan(window)]
if len(valid_window) > 0:
smoothed_relationship[i] = np.mean(valid_window)
# 6. 修复关键问题std_val计算
# 获取有效数据范围
valid_mask = ~np.isnan(smoothed_relationship[min_required - 1:])
if np.any(valid_mask):
valid_values = smoothed_relationship[min_required - 1:][valid_mask]
std_val = np.std(valid_values) if len(valid_values) > 1 else 1.0
else:
std_val = 1.0 # 安全回退值
# 确保std_val不为零
std_val = max(std_val, 1e-8)
# 7. 标准化到稳定范围 (-1, 1)
for i in range(n):
if not np.isnan(smoothed_relationship[i]):
relationship[i] = smoothed_relationship[i] / (std_val * 3.0)
else:
relationship[i] = 0.0 # 安全默认值
# 8. 截断到合理范围
relationship = np.clip(relationship, -1.0, 1.0)
# 应用时间偏移
if self.shift_window > 0 and len(relationship) > self.shift_window:
relationship = np.roll(relationship, -self.shift_window)
relationship[-self.shift_window:] = np.nan
return relationship
def get_name(self) -> str:
return f"vol_trend_rel_{self.vol_window}_{self.price_lag}"