from typing import List, Union import numpy as np import talib from src.indicators.base_indicators import Indicator class RSI(Indicator): """ 相对强弱指数 (RSI) 指标实现,使用 TA-Lib 简化计算。 """ def __init__(self, window: int = 14): """ 初始化RSI指标。 Args: window (int): RSI的计算周期,默认为14。 """ super().__init__() self.window = window def get_values(self, close: np.array, open: np.array, # 不使用 high: np.array, # 不使用 low: np.array, # 不使用 volume: np.array) -> np.array: # 不使用 """ 根据收盘价列表计算RSI值,使用 TA-Lib。 Args: close (np.array): 收盘价列表。 其他 OHLCV 参数在此指标中不使用。 Returns: np.array: RSI值列表。如果数据不足,则列表开头为NaN。 """ # 使用 talib.RSI 直接计算 # 注意:TA-Lib 会在数据不足时自动填充 NaN rsi_values = talib.RSI(close, timeperiod=self.window) # 将 numpy 数组转换为 list 并返回 return rsi_values def get_name(self): return f'rsi_{self.window}' class HistoricalRange(Indicator): """ 历史波动幅度指标:计算过去 N 日的 (最高价 - 最低价) 的简单移动平均。 """ def __init__(self, window: int = 20): """ 初始化历史波动幅度指标。 Args: window (int): 计算范围平均值的周期,默认为20。 """ super().__init__() self.window = window def get_values(self, close: np.array, # 不使用 open: np.array, # 不使用 high: np.array, low: np.array, volume: np.array) -> np.array: # 不使用 """ 根据最高价和最低价列表计算过去 N 日的 (high - low) 值的简单移动平均。 Args: high (np.array): 最高价列表。 low (np.array): 最低价列表。 其他 OHLCV 参数在此指标中不使用。 Returns: np.array: 历史波动幅度指标值列表。如果数据不足,则列表开头为NaN。 """ # if not high or not low or len(high) != len(low): # print(high, low, len(high), len(low)) # return [] # 计算每日的 (high - low) 范围 daily_ranges = high - low # 将 numpy 数组转换为 list 并返回 return np.roll(daily_ranges, self.window) def get_name(self): return f'range_{self.window}'