修复未来函数bug
This commit is contained in:
@@ -8,8 +8,10 @@ from src.core_data import Bar
|
||||
|
||||
class Indicator(ABC):
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
def __init__(self, down_bound:float = None, up_bound:float = None, shift_window: int = 0):
|
||||
self.down_bound =down_bound
|
||||
self.up_bound =up_bound
|
||||
self.shift_window = shift_window
|
||||
|
||||
@abstractmethod
|
||||
def get_values(self, close: np.array, open: np.array, high: np.array, low: np.array, volume: np.array):
|
||||
@@ -17,7 +19,27 @@ class Indicator(ABC):
|
||||
|
||||
|
||||
def get_latest_value(self, close: np.array, open: np.array, high: np.array, low: np.array, volume: np.array):
|
||||
return self.get_values(close, open, high, low, volume)[-1].item()
|
||||
values = self.get_values(close, open, high, low, volume)
|
||||
return values[-(self.shift_window + 1)].item() if len(values) > self.shift_window + 1 else None
|
||||
|
||||
|
||||
def is_condition_met(self,
|
||||
close: np.array,
|
||||
open: np.array,
|
||||
high: np.array,
|
||||
low: np.array,
|
||||
volume: np.array):
|
||||
value = self.get_latest_value(close, open, high, low, volume)
|
||||
condition_met = True
|
||||
if value is None:
|
||||
return False
|
||||
if self.up_bound is None and self.down_bound is None:
|
||||
condition_met = False
|
||||
if self.up_bound is not None:
|
||||
condition_met = condition_met and (value < self.up_bound)
|
||||
if self.down_bound is not None:
|
||||
condition_met = condition_met and (value > self.down_bound)
|
||||
return condition_met
|
||||
|
||||
@abstractmethod
|
||||
def get_name(self):
|
||||
|
||||
@@ -1,17 +1,33 @@
|
||||
from src.indicators.indicators import RSI, HistoricalRange
|
||||
|
||||
from src.indicators.indicators import RSI, HistoricalRange, DifferencedVolumeIndicator, StochasticOscillator, \
|
||||
RateOfChange, NormalizedATR
|
||||
|
||||
INDICATOR_LIST = [
|
||||
RSI(5),
|
||||
RSI(7),
|
||||
RSI(10),
|
||||
RSI(14),
|
||||
RSI(15),
|
||||
RSI(20),
|
||||
RSI(25),
|
||||
RSI(30),
|
||||
RSI(35),
|
||||
RSI(40),
|
||||
HistoricalRange(1),
|
||||
HistoricalRange(8),
|
||||
HistoricalRange(15),
|
||||
HistoricalRange(21),
|
||||
]
|
||||
HistoricalRange(shift_window=0),
|
||||
HistoricalRange(shift_window=6),
|
||||
HistoricalRange(shift_window=13),
|
||||
HistoricalRange(shift_window=20),
|
||||
# DifferencedVolumeIndicator(shift_window=0),
|
||||
# DifferencedVolumeIndicator(shift_window=6),
|
||||
# DifferencedVolumeIndicator(shift_window=13),
|
||||
# DifferencedVolumeIndicator(shift_window=20),
|
||||
StochasticOscillator(fastk_period=14, slowd_period=3, slowk_period=3),
|
||||
StochasticOscillator(fastk_period=5, slowd_period=3, slowk_period=3),
|
||||
StochasticOscillator(fastk_period=21, slowd_period=5, slowk_period=5),
|
||||
RateOfChange(window=5),
|
||||
RateOfChange(window=10),
|
||||
RateOfChange(window=15),
|
||||
RateOfChange(window=20),
|
||||
NormalizedATR(window=5),
|
||||
NormalizedATR(window=14),
|
||||
NormalizedATR(window=21)
|
||||
]
|
||||
|
||||
@@ -10,21 +10,25 @@ class RSI(Indicator):
|
||||
相对强弱指数 (RSI) 指标实现,使用 TA-Lib 简化计算。
|
||||
"""
|
||||
|
||||
def __init__(self, window: int = 14):
|
||||
"""
|
||||
初始化RSI指标。
|
||||
Args:
|
||||
window (int): RSI的计算周期,默认为14。
|
||||
"""
|
||||
super().__init__()
|
||||
def __init__(
|
||||
self,
|
||||
window: int = 14,
|
||||
down_bound: float = None,
|
||||
up_bound: float = None,
|
||||
shift_window: int = 0,
|
||||
):
|
||||
super().__init__(down_bound, up_bound)
|
||||
self.window = window
|
||||
self.shift_window = shift_window
|
||||
|
||||
def get_values(self,
|
||||
close: np.array,
|
||||
open: np.array, # 不使用
|
||||
high: np.array, # 不使用
|
||||
low: np.array, # 不使用
|
||||
volume: np.array) -> np.array: # 不使用
|
||||
def get_values(
|
||||
self,
|
||||
close: np.array,
|
||||
open: np.array, # 不使用
|
||||
high: np.array, # 不使用
|
||||
low: np.array, # 不使用
|
||||
volume: np.array,
|
||||
) -> np.array: # 不使用
|
||||
"""
|
||||
根据收盘价列表计算RSI值,使用 TA-Lib。
|
||||
Args:
|
||||
@@ -39,10 +43,9 @@ class RSI(Indicator):
|
||||
|
||||
# 将 numpy 数组转换为 list 并返回
|
||||
return rsi_values
|
||||
|
||||
|
||||
def get_name(self):
|
||||
return f'rsi_{self.window}'
|
||||
|
||||
return f"rsi_{self.window}"
|
||||
|
||||
|
||||
class HistoricalRange(Indicator):
|
||||
@@ -50,21 +53,20 @@ class HistoricalRange(Indicator):
|
||||
历史波动幅度指标:计算过去 N 日的 (最高价 - 最低价) 的简单移动平均。
|
||||
"""
|
||||
|
||||
def __init__(self, window: int = 20):
|
||||
"""
|
||||
初始化历史波动幅度指标。
|
||||
Args:
|
||||
window (int): 计算范围平均值的周期,默认为20。
|
||||
"""
|
||||
super().__init__()
|
||||
self.window = window
|
||||
def __init__(
|
||||
self, down_bound: float = None, up_bound: float = None, shift_window: int = 0
|
||||
):
|
||||
super().__init__(down_bound, up_bound)
|
||||
self.shift_window = shift_window
|
||||
|
||||
def get_values(self,
|
||||
close: np.array, # 不使用
|
||||
open: np.array, # 不使用
|
||||
high: np.array,
|
||||
low: np.array,
|
||||
volume: np.array) -> np.array: # 不使用
|
||||
def get_values(
|
||||
self,
|
||||
close: np.array, # 不使用
|
||||
open: np.array, # 不使用
|
||||
high: np.array,
|
||||
low: np.array,
|
||||
volume: np.array,
|
||||
) -> np.array: # 不使用
|
||||
"""
|
||||
根据最高价和最低价列表计算过去 N 日的 (high - low) 值的简单移动平均。
|
||||
Args:
|
||||
@@ -82,7 +84,188 @@ class HistoricalRange(Indicator):
|
||||
daily_ranges = high - low
|
||||
|
||||
# 将 numpy 数组转换为 list 并返回
|
||||
return np.roll(daily_ranges, self.window)
|
||||
return daily_ranges
|
||||
|
||||
def get_name(self):
|
||||
return f'range_{self.window}'
|
||||
return f"range_{self.shift_window}"
|
||||
|
||||
|
||||
class DifferencedVolumeIndicator(Indicator):
|
||||
"""
|
||||
计算当前交易量与前一交易量的差值。
|
||||
volume[t] - volume[t-1]。
|
||||
用于识别交易量变化的趋势,常用于平稳化交易量序列。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, down_bound: float = None, up_bound: float = None, shift_window: int = 0
|
||||
):
|
||||
# 差值没有固定上下界,取决于实际交易量
|
||||
super().__init__(down_bound, up_bound)
|
||||
self.shift_window = shift_window
|
||||
|
||||
def get_values(
|
||||
self,
|
||||
close: np.array, # 不使用
|
||||
open: np.array, # 不使用
|
||||
high: np.array, # 不使用
|
||||
low: np.array, # 不使用
|
||||
volume: np.array,
|
||||
) -> np.array:
|
||||
"""
|
||||
根据交易量计算其差分值。
|
||||
Args:
|
||||
volume (np.array): 交易量列表。
|
||||
其他 OHLCV 参数在此指标中不使用。
|
||||
Returns:
|
||||
np.array: 交易量差分值列表。第一个值为NaN。
|
||||
"""
|
||||
if not isinstance(volume, np.ndarray) or len(volume) < 2:
|
||||
return np.full_like(
|
||||
volume if isinstance(volume, np.ndarray) else [], np.nan, dtype=float
|
||||
)
|
||||
|
||||
# 计算相邻交易量的差值
|
||||
# np.diff(volume) 会比原数组少一个元素,前面补 NaN
|
||||
diff_volume = np.concatenate(([np.nan], np.diff(volume)))
|
||||
return diff_volume
|
||||
|
||||
def get_name(self) -> str:
|
||||
return f"differenced_volume_{self.shift_window}"
|
||||
|
||||
|
||||
class StochasticOscillator(Indicator):
|
||||
"""
|
||||
随机摆动指标 (%K),衡量收盘价在近期价格高低区间内的位置。
|
||||
这是一个平稳的动量摆动指标,值域在 [0, 100] 之间。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
fastk_period: int = 14,
|
||||
slowk_period: int = 3,
|
||||
slowd_period: int = 3, # 在此实现中未使用 slowd,但保留以符合标准
|
||||
down_bound: float = None,
|
||||
up_bound: float = None,
|
||||
shift_window: int = 0,
|
||||
):
|
||||
super().__init__(down_bound, up_bound)
|
||||
self.fastk_period = fastk_period
|
||||
self.slowk_period = slowk_period
|
||||
self.slowd_period = slowd_period
|
||||
self.shift_window = shift_window
|
||||
|
||||
def get_values(
|
||||
self,
|
||||
close: np.array,
|
||||
open: np.array, # 不使用
|
||||
high: np.array,
|
||||
low: np.array,
|
||||
volume: np.array, # 不使用
|
||||
) -> np.array:
|
||||
"""
|
||||
根据最高价、最低价和收盘价计算随机摆动指标 %K 的值。
|
||||
Args:
|
||||
high (np.array): 最高价列表。
|
||||
low (np.array): 最低价列表。
|
||||
close (np.array): 收盘价列表。
|
||||
Returns:
|
||||
np.array: 慢速 %K 线的值列表。
|
||||
"""
|
||||
# TA-Lib 的 STOCH 函数返回 slowk 和 slowd 两条线
|
||||
# 我们通常使用 slowk 作为主要的摆动指标
|
||||
slowk, _ = talib.STOCH(
|
||||
high,
|
||||
low,
|
||||
close,
|
||||
fastk_period=self.fastk_period,
|
||||
slowk_period=self.slowk_period,
|
||||
slowk_matype=0, # 使用 SMA
|
||||
slowd_period=self.slowd_period,
|
||||
slowd_matype=0, # 使用 SMA
|
||||
)
|
||||
return slowk
|
||||
|
||||
def get_name(self):
|
||||
return f"stoch_k_{self.fastk_period}_{self.slowk_period}"
|
||||
|
||||
|
||||
class RateOfChange(Indicator):
|
||||
"""
|
||||
价格变化率 (ROC),衡量当前价格与 N 期前价格的百分比变化。
|
||||
这是一个平稳的动量指标,围绕 0 波动。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
window: int = 10,
|
||||
down_bound: float = None,
|
||||
up_bound: float = None,
|
||||
shift_window: int = 0,
|
||||
):
|
||||
super().__init__(down_bound, up_bound)
|
||||
self.window = window
|
||||
self.shift_window = shift_window
|
||||
|
||||
def get_values(
|
||||
self,
|
||||
close: np.array,
|
||||
open: np.array, # 不使用
|
||||
high: np.array, # 不使用
|
||||
low: np.array, # 不使用
|
||||
volume: np.array, # 不使用
|
||||
) -> np.array:
|
||||
"""
|
||||
根据收盘价计算 ROC 值。
|
||||
Args:
|
||||
close (np.array): 收盘价列表。
|
||||
Returns:
|
||||
np.array: ROC 值列表。
|
||||
"""
|
||||
roc_values = talib.ROC(close, timeperiod=self.window)
|
||||
return roc_values
|
||||
|
||||
def get_name(self):
|
||||
return f"roc_{self.window}"
|
||||
|
||||
|
||||
class NormalizedATR(Indicator):
|
||||
"""
|
||||
归一化平均真实波幅 (NATR),即 ATR / Close * 100。
|
||||
将绝对波动幅度转换为相对波动百分比,使其成为一个更平稳的波动率指标。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
window: int = 14,
|
||||
down_bound: float = None,
|
||||
up_bound: float = None,
|
||||
shift_window: int = 0,
|
||||
):
|
||||
super().__init__(down_bound, up_bound)
|
||||
self.window = window
|
||||
self.shift_window = shift_window
|
||||
|
||||
def get_values(
|
||||
self,
|
||||
close: np.array,
|
||||
open: np.array, # 不使用
|
||||
high: np.array,
|
||||
low: np.array,
|
||||
volume: np.array, # 不使用
|
||||
) -> np.array:
|
||||
"""
|
||||
根据最高价、最低价和收盘价计算 NATR 值。
|
||||
Args:
|
||||
high (np.array): 最高价列表。
|
||||
low (np.array): 最低价列表。
|
||||
close (np.array): 收盘价列表。
|
||||
Returns:
|
||||
np.array: NATR 值列表。
|
||||
"""
|
||||
# 使用 TA-Lib 直接计算 NATR
|
||||
natr_values = talib.NATR(high, low, close, timeperiod=self.window)
|
||||
return natr_values
|
||||
|
||||
def get_name(self):
|
||||
return f"natr_{self.window}"
|
||||
|
||||
Reference in New Issue
Block a user