1、卡尔曼策略
This commit is contained in:
@@ -589,4 +589,62 @@ class ROC_MA(Indicator):
|
||||
"""
|
||||
返回指标的唯一名称,用于标识和调试。
|
||||
"""
|
||||
return f"roc_ma_{self.roc_window}_{self.ma_window}"
|
||||
return f"roc_ma_{self.roc_window}_{self.ma_window}"
|
||||
|
||||
from numpy.lib.stride_tricks import sliding_window_view
|
||||
|
||||
class ZScoreATR(Indicator):
|
||||
def __init__(
|
||||
self,
|
||||
atr_window: int = 14,
|
||||
z_window: int = 100,
|
||||
down_bound: float = None,
|
||||
up_bound: float = None,
|
||||
):
|
||||
super().__init__(down_bound, up_bound)
|
||||
self.atr_window = atr_window
|
||||
self.z_window = z_window
|
||||
|
||||
def get_values(self, close, open, high, low, volume) -> np.ndarray:
|
||||
n = len(close)
|
||||
min_len = self.atr_window + self.z_window
|
||||
if n < min_len:
|
||||
return np.full(n, np.nan, dtype=np.float64)
|
||||
|
||||
# Step 1: 计算 ATR (NumPy array)
|
||||
atr = talib.ATR(high, low, close, timeperiod=self.atr_window) # shape: (n,)
|
||||
|
||||
# Step 2: 只对有效区域计算 z-score
|
||||
start_idx = self.atr_window - 1 # ATR 从这里开始非 NaN
|
||||
valid_atr = atr[start_idx:] # shape: (n - start_idx,)
|
||||
valid_n = len(valid_atr)
|
||||
|
||||
if valid_n < self.z_window:
|
||||
return np.full(n, np.nan, dtype=np.float64)
|
||||
|
||||
# Step 3: 使用 sliding_window_view 构造滚动窗口(无数据复制)
|
||||
# windows: shape = (valid_n - z_window + 1, z_window)
|
||||
windows = sliding_window_view(valid_atr, window_shape=self.z_window)
|
||||
|
||||
# Step 4: 向量化计算均值和标准差(沿窗口轴)
|
||||
means = np.mean(windows, axis=1) # shape: (M,)
|
||||
stds = np.std(windows, axis=1, ddof=0) # shape: (M,)
|
||||
|
||||
# Step 5: 计算 z-score(当前值是窗口最后一个元素)
|
||||
current_vals = valid_atr[self.z_window - 1:] # 对齐窗口末尾
|
||||
zscores_valid = np.empty_like(valid_atr)
|
||||
zscores_valid[:self.z_window - 1] = np.nan
|
||||
|
||||
# 安全除法:避免除零
|
||||
with np.errstate(divide='ignore', invalid='ignore'):
|
||||
z = (current_vals - means) / stds
|
||||
zscores_valid[self.z_window - 1:] = np.where(stds > 1e-12, z, 0.0)
|
||||
|
||||
# Step 6: 拼回完整长度(前面 ATR 无效部分为 NaN)
|
||||
result = np.full(n, np.nan, dtype=np.float64)
|
||||
result[start_idx:] = zscores_valid
|
||||
|
||||
return result
|
||||
|
||||
def get_name(self):
|
||||
return f"z_atr_{self.atr_window}_{self.z_window}"
|
||||
Reference in New Issue
Block a user