feat: 新增因子装饰器系统和完整因子文档
- 添加因子表达式文档，收录 180+ 个因子及数学表达式
- 添加因子实现分析报告，明确 ts_* 与 cs_* 算子分类
- 实现装饰器系统：@time_series / @cross_section / @element_wise
- 优化 API 和翻译器以支持新架构
This commit is contained in:
@@ -347,6 +347,30 @@ def sign(x: Union[Node, str]) -> FunctionNode:
|
||||
return FunctionNode("sign", x)
|
||||
|
||||
|
||||
def cos(x: Union[Node, str]) -> FunctionNode:
    """Build the DSL function-call node for the cosine of ``x``.

    Args:
        x: Input factor expression, or a field name given as a string.

    Returns:
        FunctionNode: Function-call node named ``"cos"`` with ``x`` as its
        single argument.
    """
    cos_node = FunctionNode("cos", x)
    return cos_node
|
||||
|
||||
|
||||
def sin(x: Union[Node, str]) -> FunctionNode:
    """Build the DSL function-call node for the sine of ``x``.

    Args:
        x: Input factor expression, or a field name given as a string.

    Returns:
        FunctionNode: Function-call node named ``"sin"`` with ``x`` as its
        single argument.
    """
    sin_node = FunctionNode("sin", x)
    return sin_node
|
||||
|
||||
|
||||
def abs(x: Union[Node, str]) -> FunctionNode:
|
||||
"""绝对值。
|
||||
|
||||
|
||||
57
src/factors/decorators.py
Normal file
57
src/factors/decorators.py
Normal file
@@ -0,0 +1,57 @@
|
||||
"""函数装饰器 - 用于标记因子函数类型并自动注入 over 处理。
|
||||
|
||||
提供三种装饰器:
|
||||
- @time_series: 时序因子,自动添加 .over("ts_code")
|
||||
- @cross_section: 截面因子,自动添加 .over("trade_date")
|
||||
- @element_wise: 元素级运算,不添加 over
|
||||
"""
|
||||
|
||||
from functools import wraps
|
||||
from typing import Callable
|
||||
|
||||
|
||||
def time_series(func: Callable) -> Callable:
    """Mark a handler as a time-series factor and append ``.over("ts_code")``.

    Intended for ``ts_*`` handlers such as ``ts_mean``, ``ts_std`` and
    ``ts_corr``. The wrapper calls the handler, then calls
    ``.over("ts_code")`` on whatever it returns, so every time-series
    computation is grouped by stock code and data cannot leak across stocks.

    Args:
        func: Handler to wrap. It is invoked as ``func(self, node, ...)`` and
            must return an object exposing an ``over`` method.

    Returns:
        Callable: The wrapped handler, tagged with ``_factor_type = "ts"``.
    """

    @wraps(func)
    def wrapper(self, node, *args, **kwargs):
        # Forward any extra arguments untouched so handlers with optional
        # parameters keep working behind the decorator.
        expr = func(self, node, *args, **kwargs)
        return expr.over("ts_code")

    # Tag used to tell the factor categories apart at runtime.
    wrapper._factor_type = "ts"
    return wrapper
|
||||
|
||||
|
||||
def cross_section(func: Callable) -> Callable:
    """Mark a handler as a cross-sectional factor and append ``.over("trade_date")``.

    Intended for ``cs_*`` handlers such as ``cs_rank`` and ``cs_zscore``.
    The wrapper calls the handler, then calls ``.over("trade_date")`` on
    whatever it returns, so every cross-sectional computation is grouped by
    trading day and evaluated across all stocks of that day.

    Args:
        func: Handler to wrap. It is invoked as ``func(self, node, ...)`` and
            must return an object exposing an ``over`` method.

    Returns:
        Callable: The wrapped handler, tagged with ``_factor_type = "cs"``.
    """

    @wraps(func)
    def wrapper(self, node, *args, **kwargs):
        # Forward any extra arguments untouched so handlers with optional
        # parameters keep working behind the decorator.
        expr = func(self, node, *args, **kwargs)
        return expr.over("trade_date")

    # Tag used to tell the factor categories apart at runtime.
    wrapper._factor_type = "cs"
    return wrapper
|
||||
|
||||
|
||||
def element_wise(func: Callable) -> Callable:
    """Mark a handler as an element-wise operation; no ``over`` is added.

    Intended for math functions such as ``log``, ``exp``, ``sqrt``, ``cos``
    and ``sin``. These act on each element independently, so the handler's
    result is returned unchanged, without any grouping.

    Args:
        func: Handler to wrap. It is invoked as ``func(self, node, ...)``.

    Returns:
        Callable: The wrapped handler, tagged with
        ``_factor_type = "element"``.
    """

    @wraps(func)
    def wrapper(self, node, *args, **kwargs):
        # Pure pass-through: the wrapper exists only to carry the tag below.
        return func(self, node, *args, **kwargs)

    # Tag used to tell the factor categories apart at runtime.
    wrapper._factor_type = "element"
    return wrapper
|
||||
@@ -8,6 +8,7 @@ from typing import Any, Callable, Dict
|
||||
|
||||
import polars as pl
|
||||
|
||||
from src.factors.decorators import cross_section, element_wise, time_series
|
||||
from src.factors.dsl import (
|
||||
BinaryOpNode,
|
||||
Constant,
|
||||
@@ -58,6 +59,14 @@ class PolarsTranslator:
|
||||
self.register_handler("cs_zscore", self._handle_cs_zscore)
|
||||
self.register_handler("cs_neutral", self._handle_cs_neutral)
|
||||
|
||||
# 元素级数学函数 (element_wise)
|
||||
self.register_handler("log", self._handle_log)
|
||||
self.register_handler("exp", self._handle_exp)
|
||||
self.register_handler("sqrt", self._handle_sqrt)
|
||||
self.register_handler("sign", self._handle_sign)
|
||||
self.register_handler("cos", self._handle_cos)
|
||||
self.register_handler("sin", self._handle_sin)
|
||||
|
||||
def register_handler(
|
||||
self, func_name: str, handler: Callable[[FunctionNode], pl.Expr]
|
||||
) -> None:
|
||||
@@ -201,109 +210,172 @@ class PolarsTranslator:
|
||||
)
|
||||
|
||||
# ==================== 时序因子处理器 (ts_*) ====================
|
||||
# 所有时序因子使用 @time_series 装饰器自动注入 over("ts_code") 防串表
|
||||
|
||||
@time_series
def _handle_ts_mean(self, node: FunctionNode) -> pl.Expr:
    """Translate ``ts_mean(expr, window)`` into a rolling mean.

    The per-stock ``.over("ts_code")`` grouping is not applied here; the
    ``@time_series`` decorator appends it to the returned expression.

    Args:
        node: Function node whose two arguments are the input expression
            and the window size.

    Returns:
        pl.Expr: ``rolling_mean(window_size=window)`` of the translated input.

    Raises:
        ValueError: If ``node`` does not carry exactly 2 arguments.
    """
    if len(node.args) != 2:
        raise ValueError("ts_mean 需要 2 个参数: (expr, window)")
    expr = self.translate(node.args[0])
    window = self._extract_window(node.args[1])
    # Return the bare rolling expression; adding .over() here as well would
    # apply the grouping twice.
    return expr.rolling_mean(window_size=window)
|
||||
|
||||
@time_series
def _handle_ts_sum(self, node: FunctionNode) -> pl.Expr:
    """Translate ``ts_sum(expr, window)`` into a rolling sum.

    The per-stock ``.over("ts_code")`` grouping is not applied here; the
    ``@time_series`` decorator appends it to the returned expression.

    Args:
        node: Function node whose two arguments are the input expression
            and the window size.

    Returns:
        pl.Expr: ``rolling_sum(window_size=window)`` of the translated input.

    Raises:
        ValueError: If ``node`` does not carry exactly 2 arguments.
    """
    if len(node.args) != 2:
        raise ValueError("ts_sum 需要 2 个参数: (expr, window)")
    expr = self.translate(node.args[0])
    window = self._extract_window(node.args[1])
    # Return the bare rolling expression; adding .over() here as well would
    # apply the grouping twice.
    return expr.rolling_sum(window_size=window)
|
||||
|
||||
@time_series
def _handle_ts_std(self, node: FunctionNode) -> pl.Expr:
    """Translate ``ts_std(expr, window)`` into a rolling standard deviation.

    The per-stock ``.over("ts_code")`` grouping is not applied here; the
    ``@time_series`` decorator appends it to the returned expression.

    Args:
        node: Function node whose two arguments are the input expression
            and the window size.

    Returns:
        pl.Expr: ``rolling_std(window_size=window)`` of the translated input.

    Raises:
        ValueError: If ``node`` does not carry exactly 2 arguments.
    """
    if len(node.args) != 2:
        raise ValueError("ts_std 需要 2 个参数: (expr, window)")
    expr = self.translate(node.args[0])
    window = self._extract_window(node.args[1])
    # Return the bare rolling expression; adding .over() here as well would
    # apply the grouping twice.
    return expr.rolling_std(window_size=window)
|
||||
|
||||
@time_series
def _handle_ts_max(self, node: FunctionNode) -> pl.Expr:
    """Translate ``ts_max(expr, window)`` into a rolling maximum.

    The per-stock ``.over("ts_code")`` grouping is not applied here; the
    ``@time_series`` decorator appends it to the returned expression.

    Args:
        node: Function node whose two arguments are the input expression
            and the window size.

    Returns:
        pl.Expr: ``rolling_max(window_size=window)`` of the translated input.

    Raises:
        ValueError: If ``node`` does not carry exactly 2 arguments.
    """
    if len(node.args) != 2:
        raise ValueError("ts_max 需要 2 个参数: (expr, window)")
    expr = self.translate(node.args[0])
    window = self._extract_window(node.args[1])
    # Return the bare rolling expression; adding .over() here as well would
    # apply the grouping twice.
    return expr.rolling_max(window_size=window)
|
||||
|
||||
@time_series
def _handle_ts_min(self, node: FunctionNode) -> pl.Expr:
    """Translate ``ts_min(expr, window)`` into a rolling minimum.

    The per-stock ``.over("ts_code")`` grouping is not applied here; the
    ``@time_series`` decorator appends it to the returned expression.

    Args:
        node: Function node whose two arguments are the input expression
            and the window size.

    Returns:
        pl.Expr: ``rolling_min(window_size=window)`` of the translated input.

    Raises:
        ValueError: If ``node`` does not carry exactly 2 arguments.
    """
    if len(node.args) != 2:
        raise ValueError("ts_min 需要 2 个参数: (expr, window)")
    expr = self.translate(node.args[0])
    window = self._extract_window(node.args[1])
    # Return the bare rolling expression; adding .over() here as well would
    # apply the grouping twice.
    return expr.rolling_min(window_size=window)
|
||||
|
||||
@time_series
def _handle_ts_delay(self, node: FunctionNode) -> pl.Expr:
    """Translate ``ts_delay(expr, n)`` into a shift by ``n`` rows.

    The per-stock ``.over("ts_code")`` grouping is not applied here; the
    ``@time_series`` decorator appends it to the returned expression.

    Args:
        node: Function node whose two arguments are the input expression
            and the shift amount ``n``.

    Returns:
        pl.Expr: ``shift(n)`` of the translated input.

    Raises:
        ValueError: If ``node`` does not carry exactly 2 arguments.
    """
    if len(node.args) != 2:
        raise ValueError("ts_delay 需要 2 个参数: (expr, n)")
    expr = self.translate(node.args[0])
    # The shift amount is parsed with the same helper used for window sizes.
    n = self._extract_window(node.args[1])
    # Return the bare shifted expression; adding .over() here as well would
    # apply the grouping twice.
    return expr.shift(n)
|
||||
|
||||
@time_series
def _handle_ts_delta(self, node: FunctionNode) -> pl.Expr:
    """Translate ``ts_delta(expr, n)`` into ``expr - expr.shift(n)``.

    The per-stock ``.over("ts_code")`` grouping is not applied here; the
    ``@time_series`` decorator appends it to the returned expression.

    Args:
        node: Function node whose two arguments are the input expression
            and the shift amount ``n``.

    Returns:
        pl.Expr: The translated input minus itself shifted by ``n``.

    Raises:
        ValueError: If ``node`` does not carry exactly 2 arguments.
    """
    if len(node.args) != 2:
        raise ValueError("ts_delta 需要 2 个参数: (expr, n)")
    expr = self.translate(node.args[0])
    # The shift amount is parsed with the same helper used for window sizes.
    n = self._extract_window(node.args[1])
    # Return the bare difference; adding .over() here as well would apply
    # the grouping twice.
    return expr - expr.shift(n)
|
||||
|
||||
@time_series
def _handle_ts_corr(self, node: FunctionNode) -> pl.Expr:
    """Translate ``ts_corr(x, y, window)`` into a rolling correlation.

    The per-stock ``.over("ts_code")`` grouping is not applied here; the
    ``@time_series`` decorator appends it to the returned expression.

    Args:
        node: Function node whose three arguments are the two input
            expressions and the window size.

    Returns:
        pl.Expr: Rolling correlation of the two translated inputs.

    Raises:
        ValueError: If ``node`` does not carry exactly 3 arguments.
    """
    if len(node.args) != 3:
        raise ValueError("ts_corr 需要 3 个参数: (x, y, window)")
    x = self.translate(node.args[0])
    y = self.translate(node.args[1])
    window = self._extract_window(node.args[2])
    # Polars documents rolling correlation as the module-level function
    # pl.rolling_corr(a, b, window_size=...), not as an Expr method.
    return pl.rolling_corr(x, y, window_size=window)
|
||||
|
||||
@time_series
def _handle_ts_cov(self, node: FunctionNode) -> pl.Expr:
    """Translate ``ts_cov(x, y, window)`` into a rolling covariance.

    The per-stock ``.over("ts_code")`` grouping is not applied here; the
    ``@time_series`` decorator appends it to the returned expression.

    Args:
        node: Function node whose three arguments are the two input
            expressions and the window size.

    Returns:
        pl.Expr: Rolling covariance of the two translated inputs.

    Raises:
        ValueError: If ``node`` does not carry exactly 3 arguments.
    """
    if len(node.args) != 3:
        raise ValueError("ts_cov 需要 3 个参数: (x, y, window)")
    x = self.translate(node.args[0])
    y = self.translate(node.args[1])
    window = self._extract_window(node.args[2])
    # Polars documents rolling covariance as the module-level function
    # pl.rolling_cov(a, b, window_size=...), not as an Expr method.
    return pl.rolling_cov(x, y, window_size=window)
|
||||
|
||||
# ==================== 截面因子处理器 (cs_*) ====================
|
||||
# 所有截面因子使用 @cross_section 装饰器自动注入 over("trade_date") 防串表
|
||||
|
||||
@cross_section
def _handle_cs_rank(self, node: FunctionNode) -> pl.Expr:
    """Translate ``cs_rank(expr)`` into ``rank() / count()``.

    Dividing the rank by the count normalizes it. The per-day
    ``.over("trade_date")`` grouping is not applied here; the
    ``@cross_section`` decorator appends it to the returned expression.

    Args:
        node: Function node whose single argument is the input expression.

    Returns:
        pl.Expr: Rank of the translated input divided by its count.

    Raises:
        ValueError: If ``node`` does not carry exactly 1 argument.
    """
    if len(node.args) != 1:
        raise ValueError("cs_rank 需要 1 个参数: (expr)")
    expr = self.translate(node.args[0])
    # Return the bare expression; adding .over() here as well would apply
    # the grouping twice.
    return expr.rank() / expr.count()
|
||||
|
||||
@cross_section
def _handle_cs_zscore(self, node: FunctionNode) -> pl.Expr:
    """Translate ``cs_zscore(expr)`` into ``(expr - mean()) / std()``.

    The per-day ``.over("trade_date")`` grouping is not applied here; the
    ``@cross_section`` decorator appends it to the returned expression.

    Args:
        node: Function node whose single argument is the input expression.

    Returns:
        pl.Expr: The translated input minus its mean, divided by its
        standard deviation.

    Raises:
        ValueError: If ``node`` does not carry exactly 1 argument.
    """
    if len(node.args) != 1:
        raise ValueError("cs_zscore 需要 1 个参数: (expr)")
    expr = self.translate(node.args[0])
    # Return the bare expression; adding .over() here as well would apply
    # the grouping twice.
    return (expr - expr.mean()) / expr.std()
|
||||
|
||||
@cross_section
def _handle_cs_neutral(self, node: FunctionNode) -> pl.Expr:
    """Translate ``cs_neutral(expr, [group_col])`` into a demeaned expression.

    The per-day ``.over("trade_date")`` grouping is not applied here; the
    ``@cross_section`` decorator appends it to the returned expression.

    Args:
        node: Function node with 1 or 2 arguments. Only the first (the
            input expression) is used; an optional second argument is
            accepted but currently ignored.

    Returns:
        pl.Expr: The translated input minus its mean.

    Raises:
        ValueError: If ``node`` carries neither 1 nor 2 arguments.
    """
    if len(node.args) not in [1, 2]:
        raise ValueError("cs_neutral 需要 1-2 个参数: (expr, [group_col])")
    expr = self.translate(node.args[0])
    # Simple implementation: subtract the cross-sectional mean. It can be
    # extended to group-wise neutralization later.
    return expr - expr.mean()
|
||||
|
||||
# ==================== 元素级数学函数 (element_wise) ====================
|
||||
# 这些函数对每个元素独立计算,不添加 over
|
||||
|
||||
@element_wise
def _handle_log(self, node: FunctionNode) -> pl.Expr:
    """Translate ``log(expr)`` into the natural logarithm of the input.

    Args:
        node: Function node whose single argument is the input expression.

    Returns:
        pl.Expr: ``.log()`` applied to the translated input.

    Raises:
        ValueError: If ``node`` does not carry exactly 1 argument.
    """
    arg_count = len(node.args)
    if arg_count != 1:
        raise ValueError("log 需要 1 个参数: (expr)")
    return self.translate(node.args[0]).log()
|
||||
|
||||
@element_wise
def _handle_exp(self, node: FunctionNode) -> pl.Expr:
    """Translate ``exp(expr)`` into the exponential of the input.

    Args:
        node: Function node whose single argument is the input expression.

    Returns:
        pl.Expr: ``.exp()`` applied to the translated input.

    Raises:
        ValueError: If ``node`` does not carry exactly 1 argument.
    """
    arg_count = len(node.args)
    if arg_count != 1:
        raise ValueError("exp 需要 1 个参数: (expr)")
    return self.translate(node.args[0]).exp()
|
||||
|
||||
@element_wise
def _handle_sqrt(self, node: FunctionNode) -> pl.Expr:
    """Translate ``sqrt(expr)`` into the square root of the input.

    Args:
        node: Function node whose single argument is the input expression.

    Returns:
        pl.Expr: ``.sqrt()`` applied to the translated input.

    Raises:
        ValueError: If ``node`` does not carry exactly 1 argument.
    """
    arg_count = len(node.args)
    if arg_count != 1:
        raise ValueError("sqrt 需要 1 个参数: (expr)")
    return self.translate(node.args[0]).sqrt()
|
||||
|
||||
@element_wise
def _handle_sign(self, node: FunctionNode) -> pl.Expr:
    """Translate ``sign(expr)`` into the sign function of the input.

    Args:
        node: Function node whose single argument is the input expression.

    Returns:
        pl.Expr: ``.sign()`` applied to the translated input.

    Raises:
        ValueError: If ``node`` does not carry exactly 1 argument.
    """
    arg_count = len(node.args)
    if arg_count != 1:
        raise ValueError("sign 需要 1 个参数: (expr)")
    return self.translate(node.args[0]).sign()
|
||||
|
||||
@element_wise
def _handle_cos(self, node: FunctionNode) -> pl.Expr:
    """Translate ``cos(expr)`` into the cosine of the input.

    Args:
        node: Function node whose single argument is the input expression.

    Returns:
        pl.Expr: ``.cos()`` applied to the translated input.

    Raises:
        ValueError: If ``node`` does not carry exactly 1 argument.
    """
    arg_count = len(node.args)
    if arg_count != 1:
        raise ValueError("cos 需要 1 个参数: (expr)")
    return self.translate(node.args[0]).cos()
|
||||
|
||||
@element_wise
def _handle_sin(self, node: FunctionNode) -> pl.Expr:
    """Translate ``sin(expr)`` into the sine of the input.

    Args:
        node: Function node whose single argument is the input expression.

    Returns:
        pl.Expr: ``.sin()`` applied to the translated input.

    Raises:
        ValueError: If ``node`` does not carry exactly 1 argument.
    """
    arg_count = len(node.args)
    if arg_count != 1:
        raise ValueError("sin 需要 1 个参数: (expr)")
    return self.translate(node.args[0]).sin()
|
||||
|
||||
# ==================== 辅助方法 ====================
|
||||
|
||||
|
||||
Reference in New Issue
Block a user