144 lines
4.8 KiB
Python
144 lines
4.8 KiB
Python
"""
Technical indicator factor module.

Contains implementations of technical-indicator factors based on stock cross-sections.
"""
|
|
|
|
import numpy as np
|
|
import polars as pl
|
|
import talib
|
|
from main.factor.operator_framework import DateWiseFactor, StockWiseFactor
|
|
|
|
|
|
class SMAFactor(StockWiseFactor):
    """Simple moving average (SMA) factor over the ``close`` column.

    The window length is stored in ``self.parameters["window"]`` and the
    factor declares ``close`` as its only required input.
    """

    def __init__(self, window: int):
        # Numeric settings go into `parameters`; raw input columns are
        # declared separately as dependencies.
        numeric_params = {"window": window}
        input_columns = ["close"]
        super().__init__(
            name="SMA",
            parameters=numeric_params,
            required_factor_ids=input_columns,
        )

    def calc_factor(self, group_df: pl.DataFrame) -> pl.Series:
        """Return the rolling mean of ``close``, named ``self.factor_id``.

        ``group_df`` must contain a ``close`` column (presumably the rows of
        a single stock -- confirm against ``StockWiseFactor``).
        """
        closes = group_df["close"]
        # The window is read straight from the stored parameters.
        averaged = closes.rolling_mean(window_size=self.parameters["window"])
        return averaged.alias(self.factor_id)
|
|
|
|
|
|
class EMAFactor(StockWiseFactor):
    """Exponential moving average (EMA) factor over the ``close`` column.

    The ``window`` parameter is used as the ``span`` of the exponentially
    weighted mean.
    """

    def __init__(self, window: int):
        numeric_params = {"window": window}
        input_columns = ["close"]
        super().__init__(
            name="EMA",
            parameters=numeric_params,
            required_factor_ids=input_columns,
        )

    def calc_factor(self, group_df: pl.DataFrame) -> pl.Series:
        """Return the exponentially weighted mean of ``close``.

        The result is named ``self.factor_id``. ``group_df`` must contain a
        ``close`` column.
        """
        closes = group_df["close"]
        smoothed = closes.ewm_mean(span=self.parameters["window"])
        return smoothed.alias(self.factor_id)
|
|
|
|
|
|
class ATRFactor(StockWiseFactor):
    """Average True Range (ATR) factor, computed with TA-Lib.

    The ``window`` parameter is forwarded to ``talib.ATR`` as ``timeperiod``.
    The factor requires the ``high``, ``low`` and ``close`` columns.
    """

    def __init__(self, window: int):
        super().__init__(
            name="ATR",
            parameters={"window": window},
            required_factor_ids=["high", "low", "close"],
        )

    def calc_factor(self, group_df: pl.DataFrame) -> pl.Series:
        """Return ``talib.ATR`` of the group's prices, named ``self.factor_id``.

        Args:
            group_df: Frame containing ``high``, ``low`` and ``close`` columns.

        Returns:
            A series built from the TA-Lib output array, one value per row.
        """
        window = self.parameters["window"]

        # TA-Lib only accepts float64 ("double") arrays and raises on any
        # other dtype (e.g. Float32 or integer prices), so cast explicitly.
        # The cast is a no-op for columns that are already Float64.
        high_array, low_array, close_array = (
            group_df[column].cast(pl.Float64).to_numpy()
            for column in ("high", "low", "close")
        )

        atr_values = talib.ATR(high_array, low_array, close_array, timeperiod=window)
        return pl.Series(atr_values).alias(self.factor_id)
|
|
|
|
|
|
class OBVFactor(StockWiseFactor):
    """On-Balance Volume (OBV) factor, computed with TA-Lib.

    Takes no numeric parameters; requires the ``close`` and ``vol`` columns.
    """

    def __init__(self):
        super().__init__(
            name="OBV",
            parameters={},
            required_factor_ids=["close", "vol"],
        )

    def calc_factor(self, group_df: pl.DataFrame) -> pl.Series:
        """Return ``talib.OBV`` of the group's data, named ``self.factor_id``.

        Args:
            group_df: Frame containing ``close`` and ``vol`` columns.

        Returns:
            A series built from the TA-Lib output array, one value per row.
        """
        # TA-Lib only accepts float64 ("double") arrays and raises on any
        # other dtype. Volume is frequently stored as an integer column, so
        # both inputs are cast explicitly (a no-op when already Float64).
        close_array = group_df["close"].cast(pl.Float64).to_numpy()
        vol_array = group_df["vol"].cast(pl.Float64).to_numpy()

        obv_values = talib.OBV(close_array, vol_array)
        return pl.Series(obv_values).alias(self.factor_id)
|
|
|
|
|
|
class MACDFactor(StockWiseFactor):
    """MACD factor, computed with TA-Lib.

    Only the MACD line (the first of the three arrays returned by
    ``talib.MACD``) is exposed as the factor value; the signal line and
    histogram are discarded.
    """

    def __init__(self, fast_period: int = 12, slow_period: int = 26, signal_period: int = 9):
        super().__init__(
            name="MACD",
            parameters={
                "fast_period": fast_period,
                "slow_period": slow_period,
                "signal_period": signal_period,
            },
            required_factor_ids=["close"],
        )

    def calc_factor(self, group_df: pl.DataFrame) -> pl.Series:
        """Return the MACD line of ``close``, named ``self.factor_id``.

        Args:
            group_df: Frame containing a ``close`` column.

        Returns:
            A series built from the first array returned by ``talib.MACD``.
        """
        fast_period = self.parameters["fast_period"]
        slow_period = self.parameters["slow_period"]
        signal_period = self.parameters["signal_period"]

        # TA-Lib only accepts float64 ("double") arrays and raises on any
        # other dtype, so cast explicitly (a no-op when already Float64).
        close_array = group_df["close"].cast(pl.Float64).to_numpy()

        # talib.MACD returns (macd, signal, histogram); only the MACD line
        # is used, so the other two are bound to throwaway names.
        macd, _signal, _hist = talib.MACD(
            close_array,
            fastperiod=fast_period,
            slowperiod=slow_period,
            signalperiod=signal_period,
        )

        return pl.Series(macd).alias(self.factor_id)
|
|
|
|
|
|
class RSI_Factor(StockWiseFactor):
    """Relative Strength Index (RSI) factor, computed with TA-Lib.

    The ``window`` parameter (default 14) is forwarded to ``talib.RSI`` as
    ``timeperiod``. The factor requires the ``close`` column.
    """

    def __init__(self, window: int = 14):
        super().__init__(
            name="RSI",
            parameters={"window": window},
            required_factor_ids=["close"],
        )

    def calc_factor(self, group_df: pl.DataFrame) -> pl.Series:
        """Return ``talib.RSI`` of ``close``, named ``self.factor_id``.

        Args:
            group_df: Frame containing a ``close`` column.

        Returns:
            A series built from the TA-Lib output array, one value per row.
        """
        window = self.parameters["window"]

        # TA-Lib only accepts float64 ("double") arrays and raises on any
        # other dtype, so cast explicitly (a no-op when already Float64).
        close_array = group_df["close"].cast(pl.Float64).to_numpy()

        rsi_values = talib.RSI(close_array, timeperiod=window)
        return pl.Series(rsi_values).alias(self.factor_id)
|
|
|
|
|
|
|
|
class CrossSectionalRankFactor(DateWiseFactor):
    """Centered, scaled percentile rank of one column within a group.

    Args:
        column: Name of the column to rank; also declared as the factor's
            only required input.
        name: Factor name. When falsy (e.g. ``None``), ``f"{column}_rank"``
            is used instead.
        ascending: When ``True`` larger values receive larger ranks; when
            ``False`` the ranking is reversed.
    """

    def __init__(self, column: str, name: str = None, ascending: bool = True):
        # Set before calling the base initializer, as in the original code.
        self.target_column = column
        self.ascending = ascending
        factor_name = name or f"{column}_rank"
        super().__init__(
            name=factor_name,
            parameters={"column": column, "ascending": ascending},
            required_factor_ids=[column],
        )

    def calc_factor(self, group_df: pl.DataFrame) -> pl.Series:
        """Return the normalized rank of ``self.target_column``.

        Ranks use the "average" tie method. They are mapped to mid-point
        percentiles ``(rank - 0.5) / n``, where ``n`` is the number of
        non-null values, then shifted by ``-0.5`` and scaled by ``3.46``
        (approximately ``2 * sqrt(3)``; presumably chosen so a uniform
        percentile has roughly unit variance -- confirm with the author).

        Using mid-point percentiles keeps the output centered on zero, so
        ``ascending=False`` yields exactly the negation of ``ascending=True``.
        Dividing by the non-null count prevents missing values from
        compressing the ranks of the valid rows.

        Args:
            group_df: Frame containing the target column.

        Returns:
            The normalized rank series, named ``self.factor_id``.
        """
        values = group_df[self.target_column]

        # NOTE(review): assumes polars assigns a null rank to null values, so
        # only non-null rows take part in the ranking -- confirm for the
        # pinned polars version.
        valid_count = len(values) - values.null_count()
        # Avoid dividing by zero for empty / all-null groups; the ranks are
        # all null (or absent) in that case, so the result is unaffected.
        denominator = max(valid_count, 1)

        ranks = values.rank(method="average", descending=not self.ascending)
        rank_pct = (ranks - 0.5) / denominator
        normalized = (rank_pct - 0.5) * 3.46
        return normalized.alias(self.factor_id)