feat(factors/translator): 新增 ts_rank 和 if 函数处理器,修复类型注解
- 添加 ts_rank 滚动排名函数(计算分位排名 0-1) - 添加 if 条件选择函数 - 为 map_batches 调用添加 return_dtype=pl.Float64 类型注解 - 修复 ts_wma 边界条件判断
This commit is contained in:
@@ -71,6 +71,7 @@ class PolarsTranslator:
|
||||
self.register_handler("ts_atr", self._handle_ts_atr)
|
||||
self.register_handler("ts_rsi", self._handle_ts_rsi)
|
||||
self.register_handler("ts_obv", self._handle_ts_obv)
|
||||
self.register_handler("ts_rank", self._handle_ts_rank)
|
||||
# 补充的时间序列因子处理器
|
||||
self.register_handler("ts_sma", self._handle_ts_sma)
|
||||
self.register_handler("ts_wma", self._handle_ts_wma)
|
||||
@@ -100,6 +101,8 @@ class PolarsTranslator:
|
||||
self.register_handler("sin", self._handle_sin)
|
||||
self.register_handler("atan", self._handle_atan)
|
||||
self.register_handler("log1p", self._handle_log1p)
|
||||
# 条件函数
|
||||
self.register_handler("if", self._handle_if)
|
||||
|
||||
def register_handler(
|
||||
self, func_name: str, handler: Callable[[FunctionNode], pl.Expr]
|
||||
@@ -413,7 +416,7 @@ class PolarsTranslator:
|
||||
|
||||
return pl.struct(
|
||||
[high.alias("h"), low.alias("l"), close.alias("c")]
|
||||
).map_batches(calc_atr)
|
||||
).map_batches(calc_atr, return_dtype=pl.Float64)
|
||||
|
||||
@time_series
|
||||
def _handle_ts_rsi(self, node: FunctionNode) -> pl.Expr:
|
||||
@@ -437,7 +440,7 @@ class PolarsTranslator:
|
||||
result = talib.RSI(values, timeperiod=window)
|
||||
return pl.Series(result)
|
||||
|
||||
return close.map_batches(calc_rsi)
|
||||
return close.map_batches(calc_rsi, return_dtype=pl.Float64)
|
||||
|
||||
@time_series
|
||||
def _handle_ts_obv(self, node: FunctionNode) -> pl.Expr:
|
||||
@@ -466,7 +469,39 @@ class PolarsTranslator:
|
||||
result = talib.OBV(c, v)
|
||||
return pl.Series(result)
|
||||
|
||||
return pl.struct([close.alias("c"), volume.alias("v")]).map_batches(calc_obv)
|
||||
return pl.struct([close.alias("c"), volume.alias("v")]).map_batches(
|
||||
calc_obv, return_dtype=pl.Float64
|
||||
)
|
||||
|
||||
@time_series
|
||||
def _handle_ts_rank(self, node: FunctionNode) -> pl.Expr:
|
||||
"""处理 ts_rank(x, window) -> 滚动排名(分位数)。
|
||||
|
||||
计算当前值在过去窗口内的分位排名(0-1之间)。
|
||||
"""
|
||||
if len(node.args) != 2:
|
||||
raise ValueError("ts_rank 需要 2 个参数: (x, window)")
|
||||
expr = self.translate(node.args[0])
|
||||
window = self._extract_window(node.args[1])
|
||||
|
||||
def rank_calc(s: pl.Series) -> pl.Series:
|
||||
"""计算滚动排名。"""
|
||||
values = s.to_numpy()
|
||||
n = len(values)
|
||||
if n == 0:
|
||||
return pl.Series([float("nan")] * n)
|
||||
|
||||
result = np.full(n, np.nan)
|
||||
for i in range(window - 1, n):
|
||||
window_slice = values[i - window + 1 : i + 1]
|
||||
# 计算分位排名 (0-1)
|
||||
current_value = values[i]
|
||||
rank = np.sum(window_slice <= current_value) / len(window_slice)
|
||||
result[i] = rank
|
||||
|
||||
return pl.Series(result)
|
||||
|
||||
return expr.map_batches(rank_calc, return_dtype=pl.Float64)
|
||||
|
||||
# ==================== 补充的时间序列因子处理器 ====================
|
||||
|
||||
@@ -499,7 +534,8 @@ class PolarsTranslator:
|
||||
"""计算线性加权移动平均,使用 numpy 卷积优化。"""
|
||||
values = s.to_numpy()
|
||||
n = len(values)
|
||||
if n == 0:
|
||||
if n == 0 or n < window:
|
||||
# 数据为空或不足以计算完整窗口时,返回全 NaN
|
||||
return pl.Series([float("nan")] * n)
|
||||
|
||||
# 线性递增权重: 1, 2, 3, ..., window
|
||||
@@ -514,7 +550,7 @@ class PolarsTranslator:
|
||||
|
||||
return pl.Series(result)
|
||||
|
||||
return expr.map_batches(wma_calc)
|
||||
return expr.map_batches(wma_calc, return_dtype=pl.Float64)
|
||||
|
||||
@time_series
|
||||
def _handle_ts_decay_linear(self, node: FunctionNode) -> pl.Expr:
|
||||
@@ -556,7 +592,7 @@ class PolarsTranslator:
|
||||
|
||||
return pl.Series(result)
|
||||
|
||||
return expr.map_batches(argmax_calc)
|
||||
return expr.map_batches(argmax_calc, return_dtype=pl.Float64)
|
||||
|
||||
@time_series
|
||||
def _handle_ts_argmin(self, node: FunctionNode) -> pl.Expr:
|
||||
@@ -586,7 +622,7 @@ class PolarsTranslator:
|
||||
|
||||
return pl.Series(result)
|
||||
|
||||
return expr.map_batches(argmin_calc)
|
||||
return expr.map_batches(argmin_calc, return_dtype=pl.Float64)
|
||||
|
||||
@time_series
|
||||
def _handle_ts_count(self, node: FunctionNode) -> pl.Expr:
|
||||
@@ -623,7 +659,7 @@ class PolarsTranslator:
|
||||
|
||||
return pl.Series(result)
|
||||
|
||||
return expr.map_batches(prod_calc)
|
||||
return expr.map_batches(prod_calc, return_dtype=pl.Float64)
|
||||
|
||||
@time_series
|
||||
def _handle_ts_sumac(self, node: FunctionNode) -> pl.Expr:
|
||||
@@ -651,6 +687,19 @@ class PolarsTranslator:
|
||||
y = self.translate(node.args[1])
|
||||
return pl.when(x <= y).then(x).otherwise(y)
|
||||
|
||||
@element_wise
|
||||
def _handle_if(self, node: FunctionNode) -> pl.Expr:
|
||||
"""处理 if(condition, true_val, false_val) -> 条件选择。
|
||||
|
||||
使用 Polars 的 when/then/otherwise 实现条件选择。
|
||||
"""
|
||||
if len(node.args) != 3:
|
||||
raise ValueError("if 需要 3 个参数: (condition, true_val, false_val)")
|
||||
condition = self.translate(node.args[0])
|
||||
true_val = self.translate(node.args[1])
|
||||
false_val = self.translate(node.args[2])
|
||||
return pl.when(condition).then(true_val).otherwise(false_val)
|
||||
|
||||
# ==================== 截面因子处理器 (cs_*) ====================
|
||||
# 所有截面因子使用 @cross_section 装饰器自动注入 over("trade_date") 防串表
|
||||
|
||||
|
||||
Reference in New Issue
Block a user