1250 lines
42 KiB
Python
1250 lines
42 KiB
Python
from abc import ABC
|
||
from typing import List, Union, Tuple, Optional
|
||
|
||
import numpy as np
|
||
import talib
|
||
from numpy.lib._stride_tricks_impl import sliding_window_view
|
||
from scipy import stats
|
||
|
||
from src.indicators.base_indicators import Indicator
|
||
|
||
|
||
class Empty(Indicator, ABC):
    """Placeholder indicator: it produces no values and its condition always holds."""

    def get_name(self):
        """Return the fixed identifier of this indicator."""
        return "Empty"

    def get_values(self, close: np.array, open: np.array, high: np.array, low: np.array, volume: np.array):
        """Ignore every OHLCV input and return an empty list."""
        return list()

    def is_condition_met(
        self,
        close: np.array,
        open: np.array,
        high: np.array,
        low: np.array,
        volume: np.array,
    ):
        """Report the condition as met, whatever the inputs are."""
        return True
|
||
|
||
|
||
class RSI(Indicator):
    """
    Relative Strength Index (RSI), delegated entirely to TA-Lib.
    """

    def __init__(
        self,
        window: int = 14,
        down_bound: float = None,
        up_bound: float = None,
        shift_window: int = 0,
    ):
        """
        Args:
            window: look-back period handed to ``talib.RSI``.
            down_bound: optional lower bound, forwarded to ``Indicator``.
            up_bound: optional upper bound, forwarded to ``Indicator``.
            shift_window: stored on the instance; ``get_values`` does not read it.
        """
        super().__init__(down_bound, up_bound)
        self.window = window
        self.shift_window = shift_window

    def get_name(self):
        """Return the identifier ``rsi_<window>``."""
        return f"rsi_{self.window}"

    def get_values(
        self,
        close: np.array,
        open: np.array,
        high: np.array,
        low: np.array,
        volume: np.array,
    ) -> np.array:
        """
        Compute the RSI series from closing prices.

        Args:
            close: closing prices.
            open, high, low, volume: accepted for interface compatibility, unused.

        Returns:
            The array produced by ``talib.RSI``; TA-Lib pads the warm-up
            section at the start with NaN.
        """
        return talib.RSI(close, timeperiod=self.window)
|
||
|
||
|
||
class HistoricalRange(Indicator):
    """
    Per-bar price range indicator: ``high - low`` for every bar.

    No averaging or windowing is applied by this class.
    """

    def __init__(
        self, down_bound: float = None, up_bound: float = None, shift_window: int = 0
    ):
        """
        Args:
            down_bound: optional lower bound, forwarded to ``Indicator``.
            up_bound: optional upper bound, forwarded to ``Indicator``.
            shift_window: stored on the instance and embedded in the name;
                ``get_values`` does not read it.
        """
        super().__init__(down_bound, up_bound)
        self.shift_window = shift_window

    def get_name(self):
        """Return the identifier ``range_<shift_window>``."""
        return f"range_{self.shift_window}"

    def get_values(
        self,
        close: np.array,
        open: np.array,
        high: np.array,
        low: np.array,
        volume: np.array,
    ) -> np.array:
        """
        Return the element-wise difference between highs and lows.

        Args:
            high: high prices.
            low: low prices.
            close, open, volume: accepted for interface compatibility, unused.

        Returns:
            ``high - low``, one value per bar.
        """
        return high - low
|
||
|
||
|
||
class DifferencedVolumeIndicator(Indicator):
    """
    First difference of traded volume: ``volume[t] - volume[t-1]``.

    Differencing is used to expose changes in volume and to make the
    volume series closer to stationary.
    """

    def __init__(
        self, down_bound: float = None, up_bound: float = None, shift_window: int = 0
    ):
        """
        Args:
            down_bound: optional lower bound, forwarded to ``Indicator``.
                The difference has no natural bounds; it depends on the
                actual volume.
            up_bound: optional upper bound, forwarded to ``Indicator``.
            shift_window: stored on the instance and embedded in the name;
                ``get_values`` does not read it.
        """
        super().__init__(down_bound, up_bound)
        self.shift_window = shift_window

    def get_values(
        self,
        close: np.array,
        open: np.array,
        high: np.array,
        low: np.array,
        volume: np.array,
    ) -> np.array:
        """
        Compute the one-step difference of the volume series.

        Args:
            volume: traded volume; any array-like of numbers is accepted.
            close, open, high, low: accepted for interface compatibility, unused.

        Returns:
            A float array of the same length as ``volume`` whose first entry
            is NaN. Inputs with fewer than two elements yield an all-NaN
            array of the same shape; ``None`` or a bare scalar yields an
            empty array.
        """
        if volume is None:
            return np.array([], dtype=float)

        # Cast to float before differencing: this accepts lists/tuples and
        # prevents wrap-around when the volume is stored as unsigned integers.
        series = np.asarray(volume, dtype=float)
        if series.ndim == 0:
            # A bare scalar has no neighbour to difference against.
            return np.array([], dtype=float)
        if len(series) < 2:
            return np.full_like(series, np.nan)

        # np.diff is one element shorter than its input; pad the front with NaN.
        return np.concatenate(([np.nan], np.diff(series)))

    def get_name(self) -> str:
        """Return the identifier ``differenced_volume_<shift_window>``."""
        return f"differenced_volume_{self.shift_window}"
|
||
|
||
|
||
class StochasticOscillator(Indicator):
    """
    Stochastic oscillator (slow %K): where the close sits inside the recent
    high-low range.

    The original author describes it as a stationary momentum oscillator
    ranging over [0, 100].
    """

    def __init__(
        self,
        fastk_period: int = 14,
        slowk_period: int = 3,
        slowd_period: int = 3,
        down_bound: float = None,
        up_bound: float = None,
        shift_window: int = 0,
    ):
        """
        Args:
            fastk_period: fast %K look-back passed to ``talib.STOCH``.
            slowk_period: smoothing period of the slow %K line.
            slowd_period: smoothing period of the slow %D line; it is passed
                to TA-Lib, but the resulting %D line is discarded.
            down_bound: optional lower bound, forwarded to ``Indicator``.
            up_bound: optional upper bound, forwarded to ``Indicator``.
            shift_window: stored on the instance; ``get_values`` does not read it.
        """
        super().__init__(down_bound, up_bound)
        self.fastk_period = fastk_period
        self.slowk_period = slowk_period
        self.slowd_period = slowd_period
        self.shift_window = shift_window

    def get_name(self):
        """Return the identifier ``stoch_k_<fastk_period>_<slowk_period>``."""
        return f"stoch_k_{self.fastk_period}_{self.slowk_period}"

    def get_values(
        self,
        close: np.array,
        open: np.array,
        high: np.array,
        low: np.array,
        volume: np.array,
    ) -> np.array:
        """
        Compute the slow %K line from highs, lows and closes.

        Args:
            close: closing prices.
            high: high prices.
            low: low prices.
            open, volume: accepted for interface compatibility, unused.

        Returns:
            The slow %K array returned by ``talib.STOCH``.
        """
        # Both smoothing stages use matype 0, i.e. a simple moving average.
        stoch_settings = dict(
            fastk_period=self.fastk_period,
            slowk_period=self.slowk_period,
            slowk_matype=0,
            slowd_period=self.slowd_period,
            slowd_matype=0,
        )
        # STOCH yields (slowk, slowd); only slowk is reported.
        slow_k, _slow_d = talib.STOCH(high, low, close, **stoch_settings)
        return slow_k
|
||
|
||
|
||
class RateOfChange(Indicator):
    """
    Rate of Change (ROC): percentage change of the price versus N bars ago.

    The original author describes it as a stationary momentum measure that
    oscillates around zero.
    """

    def __init__(
        self,
        window: int = 10,
        down_bound: float = None,
        up_bound: float = None,
        shift_window: int = 0,
    ):
        """
        Args:
            window: look-back period handed to ``talib.ROC``.
            down_bound: optional lower bound, forwarded to ``Indicator``.
            up_bound: optional upper bound, forwarded to ``Indicator``.
            shift_window: stored on the instance; ``get_values`` does not read it.
        """
        super().__init__(down_bound, up_bound)
        self.window = window
        self.shift_window = shift_window

    def get_name(self):
        """Return the identifier ``roc_<window>``."""
        return f"roc_{self.window}"

    def get_values(
        self,
        close: np.array,
        open: np.array,
        high: np.array,
        low: np.array,
        volume: np.array,
    ) -> np.array:
        """
        Compute ROC from closing prices.

        Args:
            close: closing prices.
            open, high, low, volume: accepted for interface compatibility, unused.

        Returns:
            The array produced by ``talib.ROC``.
        """
        return talib.ROC(close, timeperiod=self.window)
|
||
|
||
|
||
class NormalizedATR(Indicator):
    """
    Normalized Average True Range (NATR), i.e. ATR / Close * 100.

    Expressing the absolute range as a percentage of price gives a
    volatility measure that is more comparable across price levels.
    """

    def __init__(
        self,
        window: int = 14,
        down_bound: float = None,
        up_bound: float = None,
        shift_window: int = 0,
    ):
        """
        Args:
            window: look-back period handed to ``talib.NATR``.
            down_bound: optional lower bound, forwarded to ``Indicator``.
            up_bound: optional upper bound, forwarded to ``Indicator``.
            shift_window: stored on the instance; ``get_values`` does not read it.
        """
        super().__init__(down_bound, up_bound)
        self.window = window
        self.shift_window = shift_window

    def get_name(self):
        """Return the identifier ``natr_<window>``."""
        return f"natr_{self.window}"

    def get_values(
        self,
        close: np.array,
        open: np.array,
        high: np.array,
        low: np.array,
        volume: np.array,
    ) -> np.array:
        """
        Compute NATR from highs, lows and closes.

        Args:
            close: closing prices.
            high: high prices.
            low: low prices.
            open, volume: accepted for interface compatibility, unused.

        Returns:
            The array produced by ``talib.NATR``.
        """
        return talib.NATR(high, low, close, timeperiod=self.window)
|
||
|
||
|
||
class ADX(Indicator):
    """
    Average Directional Index (ADX): measures trend strength, not direction.

    Intended as a filter for telling trending markets from ranging ones.
    """

    def __init__(
        self,
        window: int = 14,
        down_bound: float = None,
        up_bound: float = None,
        shift_window: int = 0,
    ):
        """
        Args:
            window: look-back period handed to ``talib.ADX``.
            down_bound: optional lower bound, forwarded to ``Indicator``;
                the original author suggests e.g. 25 to keep strong trends.
            up_bound: optional upper bound, forwarded to ``Indicator``;
                the original author suggests e.g. 20 to keep ranging markets.
            shift_window: stored on the instance; ``get_values`` does not read it.
        """
        super().__init__(down_bound, up_bound)
        self.window = window
        self.shift_window = shift_window

    def get_name(self):
        """Return the identifier ``adx_<window>``."""
        return f"adx_{self.window}"

    def get_values(
        self,
        close: np.array,
        open: np.array,
        high: np.array,
        low: np.array,
        volume: np.array,
    ) -> np.array:
        """
        Compute ADX from highs, lows and closes.

        ``open`` and ``volume`` are accepted for interface compatibility only.
        """
        return talib.ADX(high, low, close, timeperiod=self.window)
|
||
|
||
|
||
class BollingerBandwidth(Indicator):
    """
    Bollinger bandwidth: (upper band - lower band) / middle band, reported
    as a percentage (the ratio is multiplied by 100).

    A normalized volatility measure used to spot volatility contraction
    (squeeze) and expansion.
    """

    def __init__(
        self,
        window: int = 20,
        nbdev: float = 2.0,
        down_bound: float = None,
        up_bound: float = None,
        shift_window: int = 0,
    ):
        """
        Args:
            window: look-back period handed to ``talib.BBANDS``.
            nbdev: standard-deviation multiplier, used for both bands.
            down_bound: optional lower bound, forwarded to ``Indicator``.
            up_bound: optional upper bound, forwarded to ``Indicator``.
            shift_window: stored on the instance; ``get_values`` does not read it.
        """
        super().__init__(down_bound, up_bound)
        self.window = window
        self.nbdev = nbdev
        self.shift_window = shift_window

    def get_name(self):
        """Return ``bbw_<window>_<nbdev*10 as int>``."""
        return f"bbw_{self.window}_{int(self.nbdev * 10)}"

    def get_values(
        self,
        close: np.array,
        open: np.array,
        high: np.array,
        low: np.array,
        volume: np.array,
    ) -> np.array:
        """
        Compute the bandwidth from closing prices.

        Only ``close`` is used. Entries whose middle band is not strictly
        positive (including NaN) are left as NaN.
        """
        # matype 0 selects a simple moving average for the middle band.
        upper, middle, lower = talib.BBANDS(
            close,
            timeperiod=self.window,
            nbdevup=self.nbdev,
            nbdevdn=self.nbdev,
            matype=0,
        )

        bandwidth = np.full_like(middle, np.nan)
        # Only divide where the middle band is positive; NaN compares False.
        usable = np.flatnonzero(middle > 0)
        bandwidth[usable] = (upper[usable] - lower[usable]) / middle[usable] * 100
        return bandwidth
|
||
|
||
|
||
# ====================================================================
|
||
# 1. 通用版:价格范围与波动率比率 (Price Range to Volatility Ratio)
|
||
# ====================================================================
|
||
class PriceRangeToVolatilityRatio(Indicator):
    """
    Ratio of the price range inside an n-bar window to the ATR.

    n_period: size of the rolling window.
    atr_period: period used for the ATR.
    """

    def __init__(self, n_period: int = 3, atr_period: int = 14, down_bound: Optional[float] = None,
                 up_bound: Optional[float] = None):
        super().__init__(down_bound, up_bound)
        self.n_period = n_period
        self.atr_period = atr_period

    def get_name(self) -> str:
        """Return ``price_range_to_vol_ratio_n<n_period>_atr<atr_period>``."""
        return f"price_range_to_vol_ratio_n{self.n_period}_atr{self.atr_period}"

    def get_values(self, close: np.array, open: np.array, high: np.array, low: np.array, volume: np.array,
                   **kwargs) -> np.array:
        """
        Divide the rolling (highest high - lowest low) by the ATR.

        ``open``, ``volume`` and any extra keyword arguments are ignored.
        No guard is applied to the division.
        """
        # Range spanned by the last n_period bars.
        window_span = self._rolling_max(high, self.n_period) - self._rolling_min(low, self.n_period)
        return window_span / talib.ATR(high, low, close, timeperiod=self.atr_period)

    def _rolling_max(self, arr: np.array, window: int) -> np.array:
        """Trailing maximum over ``window`` elements; the first ``window - 1`` slots are NaN."""
        padded = np.full_like(arr, np.nan)
        if len(arr) >= window:
            # One row per window position; reduce each row to its maximum.
            padded[window - 1:] = sliding_window_view(arr, window_shape=window).max(axis=1)
        return padded

    def _rolling_min(self, arr: np.array, window: int) -> np.array:
        """Trailing minimum over ``window`` elements; the first ``window - 1`` slots are NaN."""
        padded = np.full_like(arr, np.nan)
        if len(arr) >= window:
            padded[window - 1:] = sliding_window_view(arr, window_shape=window).min(axis=1)
        return padded
|
||
|
||
|
||
# ====================================================================
|
||
# 2. 通用版:动力K线信念度 (Impulse Candle Conviction)
|
||
# ====================================================================
|
||
class ImpulseCandleConviction(Indicator):
    """
    Quantify where a selected candle closed inside its own high-low range.

    n_period: window size; only used to validate ``impulse_index_from_end``.
    impulse_index_from_end: offset of the impulse candle (0 means no offset).
    """

    def __init__(self, n_period: int = 3, impulse_index_from_end: int = 1, down_bound: Optional[float] = None,
                 up_bound: Optional[float] = None):
        """
        Raises:
            ValueError: if ``impulse_index_from_end`` is not smaller than ``n_period``.
        """
        super().__init__(down_bound, up_bound)
        self.n_period = n_period
        self.impulse_index_from_end = impulse_index_from_end
        if self.impulse_index_from_end >= self.n_period:
            raise ValueError("impulse_index_from_end must be less than n_period")

    def get_values(self, close: np.array, open: np.array, high: np.array, low: np.array, volume: np.array,
                   **kwargs) -> np.array:
        """
        Compute the conviction score of the impulse candle for every position.

        For a bullish candle (close > open) the score is
        ``(close - low) / (high - low)``; otherwise it is
        ``(high - close) / (high - low)``. Candles with a non-positive range
        give NaN.

        Entry ``i`` of the result describes the candle at index
        ``i + impulse_index_from_end`` (the arrays are shifted with
        ``np.roll(..., -impulse_index_from_end)``).
        NOTE(review): that alignment reads bars located after ``i``; confirm
        it matches how callers align indicator values.

        ``volume`` and extra keyword arguments are ignored.

        Returns:
            Array shaped like ``close``. Positions whose shifted source index
            falls outside the series are NaN.
        """
        offset = self.impulse_index_from_end

        impulse_high = np.roll(high, -offset)
        impulse_low = np.roll(low, -offset)
        impulse_close = np.roll(close, -offset)
        impulse_open = np.roll(open, -offset)

        is_bullish = impulse_close > impulse_open
        candle_range = impulse_high - impulse_low

        # Zero-range candles are overwritten with NaN below, so the
        # divide-by-zero warnings raised here carry no information.
        with np.errstate(divide="ignore", invalid="ignore"):
            bullish_conviction = (impulse_close - impulse_low) / candle_range
            bearish_conviction = (impulse_high - impulse_close) / candle_range

        conviction_values = np.full_like(close, np.nan)
        conviction_values[is_bullish] = bullish_conviction[is_bullish]
        conviction_values[~is_bullish] = bearish_conviction[~is_bullish]
        conviction_values[~(candle_range > 0)] = np.nan

        # np.roll is circular: without this step the entries at the edge
        # would be computed from bars at the opposite end of the series.
        if offset > 0:
            conviction_values[-offset:] = np.nan
        elif offset < 0:
            conviction_values[:-offset] = np.nan

        return conviction_values

    def get_name(self) -> str:
        """Return ``conviction_n<n_period>_idx<impulse_index_from_end>``."""
        return f"conviction_n{self.n_period}_idx{self.impulse_index_from_end}"
|
||
|
||
|
||
# ====================================================================
|
||
# 3. 通用版:相对成交量 (Relative Volume)
|
||
# ====================================================================
|
||
class RelativeVolumeInWindow(Indicator):
    """
    Volume of a selected candle relative to the simple moving average of volume.

    n_period: period of the volume SMA.
    impulse_index_from_end: offset of the impulse candle (0 means no offset).
    """

    def __init__(self, n_period: int = 20, impulse_index_from_end: int = 1, down_bound: Optional[float] = None,
                 up_bound: Optional[float] = None):
        """
        Raises:
            ValueError: if ``impulse_index_from_end`` is not smaller than ``n_period``.
        """
        super().__init__(down_bound, up_bound)
        self.n_period = n_period
        self.impulse_index_from_end = impulse_index_from_end
        if self.impulse_index_from_end >= self.n_period:
            raise ValueError("impulse_index_from_end must be less than n_period")

    def get_name(self) -> str:
        """Return ``relative_volume_sma<n_period>_idx<impulse_index_from_end>``."""
        return f"relative_volume_sma{self.n_period}_idx{self.impulse_index_from_end}"

    def get_values(self, close: np.array, open: np.array, high: np.array, low: np.array, volume: np.array,
                   **kwargs) -> np.array:
        """
        Divide the shifted volume by the equally shifted volume SMA.

        Both series are moved with ``np.roll(..., -impulse_index_from_end)``
        (a circular shift). Positions where the shifted SMA is not strictly
        positive (including NaN) stay NaN. Only ``volume`` is used.
        """
        offset = -self.impulse_index_from_end

        # SMA of volume, then align both series on the impulse candle.
        sma_shifted = np.roll(talib.SMA(volume, timeperiod=self.n_period), offset)
        volume_shifted = np.roll(volume, offset)

        ratio = np.full_like(volume, np.nan)
        usable = np.flatnonzero(sma_shifted > 0)
        ratio[usable] = volume_shifted[usable] / sma_shifted[usable]
        return ratio
|
||
|
||
|
||
class ROC_MA(Indicator):
    """
    Moving average of the Rate of Change (ROC_MA).

    ROC is computed first and then smoothed with a simple moving average,
    giving a smoother momentum curve.
    """

    def __init__(
        self,
        roc_window: int = 60,
        ma_window: int = 20,
        down_bound: float = None,
        up_bound: float = None,
        shift_window: int = 0,
    ):
        """
        Set up the ROC_MA indicator.

        Args:
            roc_window (int): look-back period of the ROC.
            ma_window (int): period of the SMA applied to the ROC values.
            down_bound (float): optional lower bound, forwarded to ``Indicator``.
            up_bound (float): optional upper bound, forwarded to ``Indicator``.
            shift_window (int): stored on the instance; ``get_values`` does
                not read it.
        """
        super().__init__(down_bound, up_bound)

        self.roc_window = roc_window
        self.ma_window = ma_window
        self.shift_window = shift_window

    def get_name(self) -> str:
        """
        Return the identifier ``roc_ma_<roc_window>_<ma_window>``.
        """
        return f"roc_ma_{self.roc_window}_{self.ma_window}"

    def get_values(
        self,
        close: np.array,
        open: np.array,
        high: np.array,
        low: np.array,
        volume: np.array,
    ) -> np.array:
        """
        Compute the smoothed ROC from closing prices.

        Args:
            close (np.array): closing prices.
            open, high, low, volume: accepted for interface compatibility, unused.

        Returns:
            np.array: SMA of the ROC series. The ROC warm-up NaNs are passed
            into the SMA, which lengthens the leading NaN section.
        """
        raw_roc = talib.ROC(close, timeperiod=self.roc_window)
        return talib.SMA(raw_roc, timeperiod=self.ma_window)
|
||
|
||
|
||
from numpy.lib.stride_tricks import sliding_window_view
|
||
|
||
|
||
class ZScoreATR(Indicator):
    """
    Rolling z-score of the ATR.

    For each bar the current ATR is compared with the mean and the
    population standard deviation of the last ``z_window`` ATR values
    (the current one included).
    """

    def __init__(
        self,
        atr_window: int = 14,
        z_window: int = 100,
        down_bound: float = None,
        up_bound: float = None,
    ):
        """
        Args:
            atr_window: period handed to ``talib.ATR``.
            z_window: number of ATR values in each z-score window.
            down_bound: optional lower bound, forwarded to ``Indicator``.
            up_bound: optional upper bound, forwarded to ``Indicator``.
        """
        super().__init__(down_bound, up_bound)
        self.atr_window = atr_window
        self.z_window = z_window

    def get_values(self, close, open, high, low, volume) -> np.ndarray:
        """
        Compute the rolling ATR z-score.

        Only ``high``, ``low`` and ``close`` are used.

        Returns:
            float64 array with one value per bar. It is NaN until a full
            window of valid ATR values is available, 0.0 where the window's
            standard deviation is effectively zero, and NaN where the window
            itself contains NaN.
        """
        n = len(close)
        result = np.full(n, np.nan, dtype=np.float64)
        if n < self.atr_window + self.z_window:
            return result

        atr = talib.ATR(high, low, close, timeperiod=self.atr_window)

        # Locate the first usable ATR value from the data instead of assuming
        # a warm-up length: assuming ``atr_window - 1`` pulled one NaN into
        # the first window, which then surfaced as a spurious 0.0.
        finite_idx = np.flatnonzero(~np.isnan(atr))
        if finite_idx.size == 0:
            return result
        start_idx = int(finite_idx[0])

        valid_atr = atr[start_idx:]
        if len(valid_atr) < self.z_window:
            return result

        # One row per window position, built as a view (no data copy).
        windows = sliding_window_view(valid_atr, window_shape=self.z_window)
        means = np.mean(windows, axis=1)
        stds = np.std(windows, axis=1, ddof=0)

        # The value being scored is the last element of each window.
        current_vals = valid_atr[self.z_window - 1:]
        with np.errstate(divide="ignore", invalid="ignore"):
            z = (current_vals - means) / stds
            flat_or_missing = np.where(np.isnan(stds), np.nan, 0.0)
            zscores = np.where(stds > 1e-12, z, flat_or_missing)

        result[start_idx + self.z_window - 1:] = zscores
        return result

    def get_name(self):
        """Return the identifier ``z_atr_<atr_window>_<z_window>``."""
        return f"z_atr_{self.atr_window}_{self.z_window}"
|
||
|
||
|
||
from scipy.signal import stft
|
||
|
||
|
||
class FFTTrendStrength(Indicator):
    """
    Fourier trend strength (FFT_TrendStrength).

    A short-time Fourier transform (STFT) is run on a trailing window of
    z-scored closes, and the indicator reports the share of low-frequency
    energy in (low-frequency + high-frequency) energy. The original author
    reads a higher share as a stronger trend.
    """

    def __init__(
        self,
        spectral_window: int = 46,
        low_freq_days: float = 2.0,
        bars_per_day: int = 23,
        down_bound: float = None,
        up_bound: float = None,
        shift_window: int = 0,
    ):
        """
        Set up the FFT_TrendStrength indicator.

        Args:
            spectral_window (int): STFT window size in bars (default 46,
                i.e. 2 days of 23 bars).
            low_freq_days (float): period in days whose reciprocal is the
                low-frequency cut-off (cycles/day).
            bars_per_day (int): bars per day, used as the sampling rate.
            down_bound (float): optional lower bound, forwarded to ``Indicator``.
            up_bound (float): optional upper bound, forwarded to ``Indicator``.
            shift_window (int): number of positions by which the result is
                shifted towards the start of the array.
        """
        super().__init__(down_bound, up_bound)
        self.spectral_window = spectral_window
        self.low_freq_days = low_freq_days
        self.bars_per_day = bars_per_day
        self.shift_window = shift_window

    def get_name(self) -> str:
        """Return the identifier ``fft_trend_<spectral_window>_<low_freq_days>``."""
        return f"fft_trend_{self.spectral_window}_{self.low_freq_days}"

    def get_values(
        self,
        close: np.array,
        open: np.array,
        high: np.array,
        low: np.array,
        volume: np.array,
    ) -> np.array:
        """
        Compute the trend-strength series.

        Args:
            close (np.array): closing prices; the only input used.
            open, high, low, volume: accepted for interface compatibility.

        Returns:
            np.array: values in [0, 1]; NaN where the value could not be
            computed (warm-up, flat window, STFT failure) and in the tail
            vacated by ``shift_window``.
        """
        total = len(close)
        strengths = np.full(total, np.nan)

        # Nothing is reported before spectral_window + 5 bars are available.
        warmup = self.spectral_window + 5
        if total < warmup:
            return strengths

        if self.low_freq_days > 0:
            low_cut = 1.0 / self.low_freq_days
        else:
            low_cut = float('inf')

        for idx in range(warmup - 1, total):
            value = self._trend_strength_at(close, idx, low_cut)
            if value is not None:
                strengths[idx] = value

        # Optional shift: move values towards the start and blank the tail.
        lag = self.shift_window
        if lag > 0 and len(strengths) > lag:
            strengths = np.roll(strengths, -lag)
            strengths[-lag:] = np.nan

        return strengths

    def _trend_strength_at(self, close, idx, low_cut):
        """Return the low-frequency energy share for the window ending at ``idx``, or None if unavailable."""
        segment = close[max(0, idx - self.spectral_window + 1): idx + 1]
        if len(segment) < self.spectral_window:
            return None

        # z-score the window; a flat window cannot be normalised.
        mu = np.mean(segment)
        sigma = np.std(segment)
        if sigma < 1e-8:
            return None
        scaled = (segment - mu) / sigma

        try:
            freqs, _times, spectrum = stft(
                scaled,
                fs=self.bars_per_day,
                nperseg=self.spectral_window,
                noverlap=max(0, self.spectral_window // 2),
                boundary=None,
                padded=False,
            )

            # Keep the non-negative frequencies up to half the sampling rate.
            nyquist = self.bars_per_day / 2
            keep = (freqs >= 0) & (freqs <= nyquist)
            if not np.any(keep):
                return None
            freqs = freqs[keep]
            spectrum = spectrum[keep, :]
            if spectrum.shape[1] == 0:
                return None

            # Energy of the most recent STFT frame.
            power = np.abs(spectrum[:, -1]) ** 2
            low_sel = freqs < low_cut
            high_sel = freqs > 1.0  # more than one cycle per day

            low_power = np.sum(power[low_sel]) if np.any(low_sel) else 0.0
            high_power = np.sum(power[high_sel]) if np.any(high_sel) else 0.0

            # The small constant keeps the denominator away from zero.
            share = low_power / (low_power + high_power + 1e-8)
            return np.clip(share, 0.0, 1.0)
        except Exception:
            # Best effort: a failing window simply stays NaN.
            return None
|
||
|
||
|
||
class AtrVolatility(Indicator):
    """
    ATR-based volatility level.

    Reports ATR as a percentage of the close, ``ATR / close * 100``. The
    original author uses it to tell high-volatility from low-volatility
    environments; this class returns the continuous value, not a regime label.
    """

    def __init__(
        self,
        vol_window: int = 23,
        down_bound: float = None,
        up_bound: float = None,
        shift_window: int = 0,
    ):
        """
        Set up the indicator.

        Args:
            vol_window (int): period handed to ``talib.ATR``.
            down_bound (float): optional lower bound, forwarded to ``Indicator``.
            up_bound (float): optional upper bound, forwarded to ``Indicator``.
            shift_window (int): stored on the instance; ``get_values`` does
                not read it.
        """
        super().__init__(down_bound, up_bound)
        self.vol_window = vol_window
        self.shift_window = shift_window

    def get_name(self) -> str:
        """Return the identifier ``atr_volume_<vol_window>``."""
        return f"atr_volume_{self.vol_window}"

    def get_values(
        self,
        close: np.array,
        open: np.array,
        high: np.array,
        low: np.array,
        volume: np.array,
    ) -> np.array:
        """
        Compute ATR as a percentage of the closing price.

        Args:
            close (np.array): closing prices.
            high (np.array): high prices.
            low (np.array): low prices.
            open, volume: accepted for interface compatibility, unused.

        Returns:
            np.array: ``ATR / close * 100``. An all-NaN array of the input
            length is returned when fewer than ``vol_window + 1`` bars are
            supplied or when the ATR call raises.
        """
        size = len(close)
        if size < self.vol_window + 1:
            return np.full(size, np.nan)

        try:
            atr = talib.ATR(high, low, close, timeperiod=self.vol_window)
        except Exception:
            # Best effort: report "no data" rather than propagate the failure.
            return np.full(size, np.nan)

        return (atr / close) * 100
|
||
|
||
|
||
class FFTPhaseShift(Indicator):
    """
    Fourier phase-shift indicator (FFT_PhaseShift).

    Tracks the phase of the strongest low-frequency STFT component between
    two consecutive STFT frames and flags a jump larger than a threshold.
    The original author reads such jumps as possible turning points.
    """

    def __init__(
        self,
        spectral_window: int = 46,
        dominant_freq_bound: float = 0.5,
        phase_shift_threshold: float = 1.0,
        bars_per_day: int = 23,
        down_bound: float = None,
        up_bound: float = None,
        shift_window: int = 0,
    ):
        """
        Set up the FFT_PhaseShift indicator.

        Args:
            spectral_window (int): STFT frame size in bars (default 46,
                i.e. 2 days of 23 bars).
            dominant_freq_bound (float): only frequencies strictly below this
                value (cycles/day) are candidates for the dominant component.
            phase_shift_threshold (float): minimum absolute phase change
                (radians) that is flagged.
            bars_per_day (int): bars per day, used as the sampling rate.
            down_bound (float): optional lower bound, forwarded to ``Indicator``.
            up_bound (float): optional upper bound, forwarded to ``Indicator``.
            shift_window (int): number of positions by which the result is
                shifted towards the start of the array.
        """
        super().__init__(down_bound, up_bound)
        self.spectral_window = spectral_window
        self.dominant_freq_bound = dominant_freq_bound
        self.phase_shift_threshold = phase_shift_threshold
        self.bars_per_day = bars_per_day
        self.shift_window = shift_window

    def get_values(
        self,
        close: np.array,
        open: np.array,
        high: np.array,
        low: np.array,
        volume: np.array,
    ) -> np.array:
        """
        Compute the phase-shift flags.

        Meaning of the values:
            1.0: positive phase jump above the threshold
            -1.0: negative phase jump above the threshold
            0.0: no significant jump

        Each bar uses the trailing ``spectral_window + hop`` closes, where
        ``hop = spectral_window - spectral_window // 2``, so that the STFT
        yields two frames whose phases can be compared. Feeding exactly
        ``spectral_window`` bars gives a single frame, and then no value can
        ever be produced.

        NOTE(review): with the default settings the frequency resolution
        (bars_per_day / spectral_window) equals ``dominant_freq_bound``, so
        only the zero-frequency bin qualifies as "low"; confirm the intended
        bound.

        Args:
            close (np.array): closing prices; the only input used.
            open, high, low, volume: accepted for interface compatibility.

        Returns:
            np.array: flags as described above; NaN where no value could be
            computed and in the tail vacated by ``shift_window``.
        """
        n = len(close)
        phase_shifts = np.full(n, np.nan)

        min_required = self.spectral_window + 5
        if n < min_required:
            return phase_shifts

        overlap = max(0, self.spectral_window // 2)
        hop = self.spectral_window - overlap
        # Two overlapping frames need one extra hop of data.
        frame_span = self.spectral_window + hop
        first_idx = max(min_required, frame_span) - 1

        for i in range(first_idx, n):
            window_data = close[i - frame_span + 1: i + 1]

            # z-score the window; a flat window cannot be normalised.
            window_mean = np.mean(window_data)
            window_std = np.std(window_data)
            if window_std < 1e-8:
                continue
            normalized = (window_data - window_mean) / window_std

            try:
                f, _t, Zxx = stft(
                    normalized,
                    fs=self.bars_per_day,
                    nperseg=self.spectral_window,
                    noverlap=overlap,
                    boundary=None,
                    padded=False,
                )

                # Keep the non-negative frequencies up to half the sampling rate.
                max_freq = self.bars_per_day / 2
                valid_mask = (f >= 0) & (f <= max_freq)
                if not np.any(valid_mask):
                    continue
                f = f[valid_mask]
                Zxx = Zxx[valid_mask, :]

                # A phase change needs two frames.
                if Zxx.shape[1] < 2:
                    continue

                low_freq_mask = f < self.dominant_freq_bound
                if not np.any(low_freq_mask):
                    continue

                # Dominant component: largest magnitude in the latest frame.
                latest = Zxx[low_freq_mask, -1]
                previous = Zxx[low_freq_mask, -2]
                dominant_idx = np.argmax(np.abs(latest))

                # Wrap the phase difference into [-pi, pi).
                phase_diff = np.angle(latest[dominant_idx]) - np.angle(previous[dominant_idx])
                phase_diff = (phase_diff + np.pi) % (2 * np.pi) - np.pi

                if np.abs(phase_diff) > self.phase_shift_threshold:
                    phase_shifts[i] = 1.0 if phase_diff > 0 else -1.0
                else:
                    phase_shifts[i] = 0.0

            except Exception:
                # Deliberate best effort: a failing window stays NaN.
                continue

        # Optional shift: move values towards the start and blank the tail.
        if self.shift_window > 0 and len(phase_shifts) > self.shift_window:
            phase_shifts = np.roll(phase_shifts, -self.shift_window)
            phase_shifts[-self.shift_window:] = np.nan

        return phase_shifts

    def get_name(self) -> str:
        """Return the identifier ``fft_phase_<spectral_window>_<dominant_freq_bound>``."""
        return f"fft_phase_{self.spectral_window}_{self.dominant_freq_bound}"
|
||
|
||
|
||
class VolatilitySkew(Indicator):
    """
    Volatility skew (VolatilitySkew).

    Measures the skewness of the recent distribution of normalised
    volatility (ATR / close * 100). The original author reads positive skew
    as rising volatility and negative skew as falling volatility.
    """

    def __init__(
        self,
        vol_window: int = 20,
        skew_window: int = 60,
        down_bound: float = None,
        up_bound: float = None,
        shift_window: int = 0,
    ):
        """
        Set up the VolatilitySkew indicator.

        Args:
            vol_window (int): period handed to ``talib.ATR``.
            skew_window (int): number of volatility values per skew window.
            down_bound (float): optional lower bound, forwarded to ``Indicator``.
            up_bound (float): optional upper bound, forwarded to ``Indicator``.
            shift_window (int): stored on the instance; ``get_values`` does
                not read it.
        """
        super().__init__(down_bound, up_bound)
        self.vol_window = vol_window
        self.skew_window = skew_window
        self.shift_window = shift_window

    def get_name(self) -> str:
        """Return the identifier ``vol_skew_<vol_window>_<skew_window>``."""
        return f"vol_skew_{self.vol_window}_{self.skew_window}"

    def get_values(
        self,
        close: np.array,
        open: np.array,
        high: np.array,
        low: np.array,
        volume: np.array,
    ) -> np.array:
        """
        Compute the rolling skewness of normalised volatility.

        Args:
            close (np.array): closing prices.
            high (np.array): high prices.
            low (np.array): low prices.
            open, volume: accepted for interface compatibility, unused.

        Returns:
            np.array: the value of ``scipy.stats.skew`` for each trailing
            window; NaN during warm-up, for windows with less than 70 %
            non-NaN values, and everywhere if the ATR call raises.
        """
        length = len(close)
        skews = np.full(length, np.nan)

        warmup = self.vol_window + self.skew_window
        if length < warmup:
            return skews

        try:
            atr = talib.ATR(high, low, close, timeperiod=self.vol_window)
        except Exception:
            # Best effort: report "no data" rather than propagate the failure.
            return skews

        # Volatility as a percentage of price.
        vol_pct = (atr / close) * 100

        # A window needs at least 70 % usable observations.
        min_points = self.skew_window * 0.7
        for end in range(warmup, length + 1):
            chunk = vol_pct[end - self.skew_window: end]
            finite = chunk[~np.isnan(chunk)]
            if len(finite) >= min_points:
                skews[end - 1] = stats.skew(finite)

        return skews
|
||
|
||
|
||
import numpy as np
|
||
import talib
|
||
from src.indicators.base_indicators import Indicator
|
||
|
||
|
||
class VolatilityTrendRelationship(Indicator):
    """
    Volatility-trend relationship indicator.

    Multiplies the bar-to-bar relative change of normalised volatility
    (ATR / close * 100) by a short-window lagged correlation of returns,
    smooths the product with a NaN-aware moving average, scales it by three
    standard deviations and clips it to [-1, 1].

    NaN handling:
    - ATR comes from TA-Lib.
    - NaNs are filtered out before the standard deviation is taken.
    - Positions without a smoothed value are reported as 0.0, so the output
      has no NaN apart from the tail vacated by ``shift_window``.
    """

    def __init__(
        self,
        vol_window: int = 20,
        price_lag: int = 3,
        ma_window: int = 5,
        down_bound: float = None,
        up_bound: float = None,
        shift_window: int = 0,
    ):
        """
        Args:
            vol_window: period handed to ``talib.ATR``.
            price_lag: lag used for the correlation of returns.
            ma_window: length of the smoothing window.
            down_bound: optional lower bound, forwarded to ``Indicator``.
            up_bound: optional upper bound, forwarded to ``Indicator``.
            shift_window: number of positions by which the result is shifted
                towards the start of the array.
        """
        super().__init__(down_bound, up_bound)
        self.vol_window = vol_window
        self.price_lag = price_lag
        self.ma_window = ma_window
        self.shift_window = shift_window

    def get_name(self) -> str:
        """Return the identifier ``vol_trend_rel_<vol_window>_<price_lag>``."""
        return f"vol_trend_rel_{self.vol_window}_{self.price_lag}"

    def get_values(
        self,
        close: np.array,
        open: np.array,
        high: np.array,
        low: np.array,
        volume: np.array,
    ) -> np.array:
        """
        Compute the relationship series from ``high``, ``low`` and ``close``.

        Returns an all-NaN array when fewer than
        ``max(vol_window, price_lag, ma_window) + 5`` bars are supplied or
        when the ATR step raises; otherwise values clipped to [-1, 1].
        """
        size = len(close)
        warmup = max(self.vol_window, self.price_lag, self.ma_window) + 5
        if size < warmup:
            return np.full(size, np.nan)

        # 1. Normalised volatility in percent.
        try:
            atr = talib.ATR(high, low, close, timeperiod=self.vol_window)
            vol_pct = (atr / close) * 100
        except Exception:
            return np.full(size, np.nan)

        # 2. Relative change of volatility; 0 where the previous value is
        #    NaN or not above 1e-8.
        vol_change = np.zeros(size)
        previous = vol_pct[:-1]
        with np.errstate(divide="ignore", invalid="ignore"):
            step = (vol_pct[1:] - previous) / previous
            vol_change[1:] = np.where(previous > 1e-8, step, 0.0)

        # 3. Lagged correlation of returns over a trailing 2*lag+1 window.
        lag = self.price_lag
        returns = np.diff(close, prepend=close[0]) / (close + 1e-8)
        autocorr = np.zeros(size)
        for pos in range(max(lag, lag * 2), size):
            recent = returns[pos - lag * 2: pos + 1]
            clean = recent[~np.isnan(recent)]
            if len(clean) < lag * 1.5:
                continue

            earlier = clean[:-lag]
            later = clean[lag:]
            if len(earlier) == 0 or len(later) == 0:
                continue

            earlier_dev = earlier - np.mean(earlier)
            later_dev = later - np.mean(later)
            covariance = np.sum(earlier_dev * later_dev)
            earlier_ss = np.sum(earlier_dev ** 2)
            later_ss = np.sum(later_dev ** 2)

            # Skip degenerate windows to avoid dividing by ~0.
            if earlier_ss > 1e-8 and later_ss > 1e-8:
                autocorr[pos] = covariance / np.sqrt(earlier_ss * later_ss)

        # 4. Core product.
        raw = vol_change * autocorr

        # 5. NaN-aware trailing mean.
        smoothed = np.full(size, np.nan)
        for pos in range(self.ma_window - 1, size):
            span = raw[max(0, pos - self.ma_window + 1): pos + 1]
            usable = span[~np.isnan(span)]
            if len(usable) > 0:
                smoothed[pos] = np.mean(usable)

        # 6. Scale: std of the non-NaN smoothed values after the warm-up,
        #    falling back to 1.0 and floored at 1e-8.
        tail = smoothed[warmup - 1:]
        finite_tail = tail[~np.isnan(tail)]
        scale = np.std(finite_tail) if len(finite_tail) > 1 else 1.0
        scale = max(scale, 1e-8)

        # 7./8. Divide by three standard deviations, default missing values
        #       to 0.0, then clip.
        relationship = np.where(np.isnan(smoothed), 0.0, smoothed / (scale * 3.0))
        relationship = np.clip(relationship, -1.0, 1.0)

        # Optional shift: move values towards the start and blank the tail.
        shift = self.shift_window
        if shift > 0 and len(relationship) > shift:
            relationship = np.roll(relationship, -shift)
            relationship[-shift:] = np.nan

        return relationship
|