Files
ProStock/src/factors/api.py
liaozhaorun f943cc98d0 feat(factors): 添加 cs_mean 函数并增强 max_/min_ 单参数支持
- 新增 cs_mean 截面均值函数,支持 GTJA Alpha127 等因子转换
- max_/min_ 支持单参数调用,默认使用 252 天(约 1 年)滚动窗口
2026-03-15 18:00:48 +08:00

797 lines
20 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""DSL API 层 - 提供常用的符号和函数。
该模块提供量化因子表达式中常用的符号(如 close, open 等)
和函数(如 ts_mean, cs_rank 等),用户可以直接导入使用。
示例:
>>> from src.factors.api import close, open, ts_mean, cs_rank
>>> expr = ts_mean(close - open, 20) / close
>>> print(expr)
ts_mean(((close - open), 20)) / close
"""
from src.factors.dsl import Symbol, FunctionNode, Node, _ensure_node
from typing import Union
# ==================== 常用价格符号 ====================
#: 收盘价
close = Symbol("close")
#: 开盘价
open = Symbol("open")
#: 最高价
high = Symbol("high")
#: 最低价
low = Symbol("low")
#: 成交量(数据库字段名为 vol
vol = Symbol("vol")
#: 成交额
amount = Symbol("amount")
#: 前收盘价
pre_close = Symbol("pre_close")
#: 涨跌额
change = Symbol("change")
#: 涨跌幅(数据库字段名为 pct_chg
pct_chg = Symbol("pct_chg")
# ==================== 时间序列函数 (ts_*) ====================
def ts_mean(x: Union[Node, str], window: int) -> FunctionNode:
"""时间序列均值。
计算给定因子在滚动窗口内的平均值。
Args:
x: 输入因子表达式或字段名字符串
window: 滚动窗口大小
Returns:
FunctionNode: 函数调用节点
Example:
>>> from src.factors.api import close, ts_mean
>>> expr = ts_mean(close, 20) # 20日收盘价均值
>>> expr = ts_mean("close", 20) # 使用字符串
>>> print(expr)
ts_mean(close, 20)
"""
return FunctionNode("ts_mean", x, window)
def ts_std(x: Union[Node, str], window: int) -> FunctionNode:
"""时间序列标准差。
计算给定因子在滚动窗口内的标准差。
Args:
x: 输入因子表达式或字段名字符串
window: 滚动窗口大小
Returns:
FunctionNode: 函数调用节点
"""
return FunctionNode("ts_std", x, window)
def ts_max(x: Union[Node, str], window: int) -> FunctionNode:
"""时间序列最大值。
计算给定因子在滚动窗口内的最大值。
Args:
x: 输入因子表达式或字段名字符串
window: 滚动窗口大小
Returns:
FunctionNode: 函数调用节点
"""
return FunctionNode("ts_max", x, window)
def ts_min(x: Union[Node, str], window: int) -> FunctionNode:
"""时间序列最小值。
计算给定因子在滚动窗口内的最小值。
Args:
x: 输入因子表达式或字段名字符串
window: 滚动窗口大小
Returns:
FunctionNode: 函数调用节点
"""
return FunctionNode("ts_min", x, window)
def ts_sum(x: Union[Node, str], window: int) -> FunctionNode:
"""时间序列求和。
计算给定因子在滚动窗口内的求和。
Args:
x: 输入因子表达式或字段名字符串
window: 滚动窗口大小
Returns:
FunctionNode: 函数调用节点
"""
return FunctionNode("ts_sum", x, window)
def ts_delay(x: Union[Node, str], periods: int) -> FunctionNode:
"""时间序列滞后。
获取给定因子在 N 个周期前的值。
Args:
x: 输入因子表达式或字段名字符串
periods: 滞后期数
Returns:
FunctionNode: 函数调用节点
"""
return FunctionNode("ts_delay", x, periods)
def ts_delta(x: Union[Node, str], periods: int) -> FunctionNode:
"""时间序列差分。
计算给定因子与 N 个周期前的差值。
Args:
x: 输入因子表达式或字段名字符串
periods: 差分期数
Returns:
FunctionNode: 函数调用节点
"""
return FunctionNode("ts_delta", x, periods)
def ts_corr(x: Union[Node, str], y: Union[Node, str], window: int) -> FunctionNode:
"""时间序列相关系数。
计算两个因子在滚动窗口内的相关系数。
Args:
x: 第一个因子表达式或字段名字符串
y: 第二个因子表达式或字段名字符串
window: 滚动窗口大小
Returns:
FunctionNode: 函数调用节点
"""
return FunctionNode("ts_corr", x, y, window)
def ts_cov(x: Union[Node, str], y: Union[Node, str], window: int) -> FunctionNode:
"""时间序列协方差。
计算两个因子在滚动窗口内的协方差。
Args:
x: 第一个因子表达式或字段名字符串
y: 第二个因子表达式或字段名字符串
window: 滚动窗口大小
Returns:
FunctionNode: 函数调用节点
"""
return FunctionNode("ts_cov", x, y, window)
def ts_var(x: Union[Node, str], window: int) -> FunctionNode:
"""时间序列方差。
计算给定因子在滚动窗口内的方差。
Args:
x: 输入因子表达式或字段名字符串
window: 滚动窗口大小
Returns:
FunctionNode: 函数调用节点
"""
return FunctionNode("ts_var", x, window)
def ts_skew(x: Union[Node, str], window: int) -> FunctionNode:
"""时间序列偏度。
计算给定因子在滚动窗口内的偏度(三阶矩)。
Args:
x: 输入因子表达式或字段名字符串
window: 滚动窗口大小
Returns:
FunctionNode: 函数调用节点
"""
return FunctionNode("ts_skew", x, window)
def ts_kurt(x: Union[Node, str], window: int) -> FunctionNode:
"""时间序列峰度。
计算给定因子在滚动窗口内的峰度(四阶矩)。
Args:
x: 输入因子表达式或字段名字符串
window: 滚动窗口大小
Returns:
FunctionNode: 函数调用节点
"""
return FunctionNode("ts_kurt", x, window)
def ts_pct_change(x: Union[Node, str], periods: int) -> FunctionNode:
"""时间序列百分比变化。
计算给定因子与 N 个周期前的百分比变化:(x - x.shift(n)) / x.shift(n)。
Args:
x: 输入因子表达式或字段名字符串
periods: 滞后期数
Returns:
FunctionNode: 函数调用节点
"""
return FunctionNode("ts_pct_change", x, periods)
def ts_ema(x: Union[Node, str], window: int) -> FunctionNode:
"""指数移动平均。
计算给定因子的指数移动平均值。
Args:
x: 输入因子表达式或字段名字符串
window: 指数移动平均的 span 参数
Returns:
FunctionNode: 函数调用节点
"""
return FunctionNode("ts_ema", x, window)
def ts_atr(
high: Union[Node, str], low: Union[Node, str], close: Union[Node, str], window: int
) -> FunctionNode:
"""平均真实波幅 (Average True Range)。
计算给定窗口内的平均真实波幅,使用 TA-Lib 实现。
Args:
high: 最高价表达式或字段名字符串
low: 最低价表达式或字段名字符串
close: 收盘价表达式或字段名字符串
window: 滚动窗口大小
Returns:
FunctionNode: 函数调用节点
"""
return FunctionNode("ts_atr", high, low, close, window)
def ts_rsi(close: Union[Node, str], window: int) -> FunctionNode:
"""相对强弱指数 (Relative Strength Index)。
计算给定窗口内的 RSI 值,使用 TA-Lib 实现。
Args:
close: 收盘价表达式或字段名字符串
window: 滚动窗口大小
Returns:
FunctionNode: 函数调用节点
"""
return FunctionNode("ts_rsi", close, window)
def ts_obv(close: Union[Node, str], volume: Union[Node, str]) -> FunctionNode:
"""能量潮指标 (On Balance Volume)。
计算 OBV 值,使用 TA-Lib 实现。
Args:
close: 收盘价表达式或字段名字符串
volume: 成交量表达式或字段名字符串
Returns:
FunctionNode: 函数调用节点
"""
return FunctionNode("ts_obv", close, volume)
def ts_rank(x: Union[Node, str], window: int) -> FunctionNode:
"""时间序列排名。
计算当前值在过去窗口内的分位排名。
Args:
x: 输入因子表达式或字段名字符串
window: 滚动窗口大小
Returns:
FunctionNode: 函数调用节点
"""
return FunctionNode("ts_rank", x, window)
# ==================== 截面函数 (cs_*) ====================
def cs_rank(x: Union[Node, str]) -> FunctionNode:
"""截面排名。
计算因子在横截面上的排名(分位数)。
Args:
x: 输入因子表达式或字段名字符串
Returns:
FunctionNode: 函数调用节点
Example:
>>> from src.factors.api import close, cs_rank
>>> expr = cs_rank(close) # 收盘价截面排名
>>> expr = cs_rank("close") # 使用字符串
>>> print(expr)
cs_rank(close)
"""
return FunctionNode("cs_rank", x)
def cs_zscore(x: Union[Node, str]) -> FunctionNode:
"""截面标准化 (Z-Score)。
计算因子在横截面上的 Z-Score 标准化值。
Args:
x: 输入因子表达式或字段名字符串
Returns:
FunctionNode: 函数调用节点
"""
return FunctionNode("cs_zscore", x)
def cs_neutralize(
x: Union[Node, str], group: Union[Symbol, str, None] = None
) -> FunctionNode:
"""截面中性化。
对因子进行行业/市值中性化处理。
Args:
x: 输入因子表达式或字段名字符串
group: 分组变量(如行业分类),可以为字符串或 Symbol默认为 None
Returns:
FunctionNode: 函数调用节点
"""
if group is not None:
return FunctionNode("cs_neutralize", x, group)
return FunctionNode("cs_neutralize", x)
def cs_winsorize(
x: Union[Node, str], lower: float = 0.01, upper: float = 0.99
) -> FunctionNode:
"""截面缩尾处理。
对因子进行截面缩尾处理,去除极端值。
Args:
x: 输入因子表达式或字段名字符串
lower: 下尾分位数,默认 0.01
upper: 上尾分位数,默认 0.99
Returns:
FunctionNode: 函数调用节点
"""
return FunctionNode("cs_winsorize", x, lower, upper)
def cs_demean(x: Union[Node, str]) -> FunctionNode:
"""截面去均值。
计算因子在横截面上减去均值。
Args:
x: 输入因子表达式或字段名字符串
Returns:
FunctionNode: 函数调用节点
"""
return FunctionNode("cs_demean", x)
def cs_mean(x: Union[Node, str]) -> FunctionNode:
"""截面均值。
计算因子在横截面上的平均值。
Args:
x: 输入因子表达式或字段名字符串
Returns:
FunctionNode: 函数调用节点
Example:
>>> from src.factors.api import close, cs_mean
>>> expr = cs_mean((close - 100) ** 2)
>>> print(expr)
cs_mean(((close - 100) ** 2))
"""
return FunctionNode("cs_mean", x)
# ==================== 数学函数 ====================
def log(x: Union[Node, str]) -> FunctionNode:
"""自然对数。
Args:
x: 输入因子表达式或字段名字符串
Returns:
FunctionNode: 函数调用节点
"""
return FunctionNode("log", x)
def exp(x: Union[Node, str]) -> FunctionNode:
"""指数函数。
Args:
x: 输入因子表达式或字段名字符串
Returns:
FunctionNode: 函数调用节点
"""
return FunctionNode("exp", x)
def sqrt(x: Union[Node, str]) -> FunctionNode:
"""平方根。
Args:
x: 输入因子表达式或字段名字符串
Returns:
FunctionNode: 函数调用节点
"""
return FunctionNode("sqrt", x)
def sign(x: Union[Node, str]) -> FunctionNode:
"""符号函数。
返回 -1, 0, 1 表示输入值的符号。
Args:
x: 输入因子表达式或字段名字符串
Returns:
FunctionNode: 函数调用节点
"""
return FunctionNode("sign", x)
def cos(x: Union[Node, str]) -> FunctionNode:
"""余弦函数。
Args:
x: 输入因子表达式或字段名字符串
Returns:
FunctionNode: 函数调用节点
"""
return FunctionNode("cos", x)
def sin(x: Union[Node, str]) -> FunctionNode:
"""正弦函数。
Args:
x: 输入因子表达式或字段名字符串
Returns:
FunctionNode: 函数调用节点
"""
return FunctionNode("sin", x)
def abs(x: Union[Node, str]) -> FunctionNode:
"""绝对值。
Args:
x: 输入因子表达式或字段名字符串
Returns:
FunctionNode: 函数调用节点
"""
return FunctionNode("abs", x)
def max_(
x: Union[Node, str], y: Union[Node, str, int, float, None] = None
) -> FunctionNode:
"""最大值。
智能分发逻辑:
- 单参数:调用 ts_max(x, 252) 计算滚动窗口最大值(默认 252 天≈1年
- 如果 y 是正整数 (y > 0),调用 ts_max(x, y) 滚动窗口最大值
- 否则,调用逐元素 max(x, y)
注意:避免 MAX(CLOSE - DELAY(CLOSE, 1), 0) 这类场景被错误路由到 ts_max
Args:
x: 第一个因子表达式或字段名字符串,或单参数时的输入序列
y: 可选,第二个因子表达式、字段名字符串或正整数(窗口大小)
Returns:
FunctionNode: 函数调用节点
"""
if y is None:
# 单参数:默认使用 252 天(约 1 年交易日)窗口
return ts_max(x, 252)
if isinstance(y, int) and y > 0:
return ts_max(x, y)
return FunctionNode("max", x, _ensure_node(y))
def min_(
x: Union[Node, str], y: Union[Node, str, int, float, None] = None
) -> FunctionNode:
"""最小值。
智能分发逻辑:
- 单参数:调用 ts_min(x, 252) 计算滚动窗口最小值(默认 252 天≈1年
- 如果 y 是正整数 (y > 0),调用 ts_min(x, y) 滚动窗口最小值
- 否则,调用逐元素 min(x, y)
Args:
x: 第一个因子表达式或字段名字符串,或单参数时的输入序列
y: 可选,第二个因子表达式、字段名字符串或正整数(窗口大小)
Returns:
FunctionNode: 函数调用节点
"""
if y is None:
# 单参数:默认使用 252 天(约 1 年交易日)窗口
return ts_min(x, 252)
if isinstance(y, int) and y > 0:
return ts_min(x, y)
return FunctionNode("min", x, _ensure_node(y))
def clip(
x: Union[Node, str],
lower: Union[Node, str, int, float],
upper: Union[Node, str, int, float],
) -> FunctionNode:
"""数值裁剪。
将因子值限制在 [lower, upper] 范围内。
Args:
x: 输入因子表达式或字段名字符串
lower: 下限(因子表达式、字段名字符串或数值)
upper: 上限(因子表达式、字段名字符串或数值)
Returns:
FunctionNode: 函数调用节点
"""
return FunctionNode("clip", x, _ensure_node(lower), _ensure_node(upper))
def atan(x: Union[Node, str]) -> FunctionNode:
"""反正切函数。
计算输入值的反正切值(弧度)。
Args:
x: 输入因子表达式或字段名字符串
Returns:
FunctionNode: 函数调用节点
"""
return FunctionNode("atan", x)
def log1p(x: Union[Node, str]) -> FunctionNode:
"""log(1+x) 函数。
计算 log(1+x),对 x 接近 0 的情况更精确。
Args:
x: 输入因子表达式或字段名字符串
Returns:
FunctionNode: 函数调用节点
"""
return FunctionNode("log1p", x)
# ==================== 条件函数 ====================
def if_(
condition: Union[Node, str],
true_val: Union[Node, str, int, float],
false_val: Union[Node, str, int, float],
) -> FunctionNode:
"""条件选择。
根据条件选择值。
Args:
condition: 条件表达式或字段名字符串
true_val: 条件为真时的值(因子表达式、字段名字符串或数值)
false_val: 条件为假时的值(因子表达式、字段名字符串或数值)
Returns:
FunctionNode: 函数调用节点
"""
return FunctionNode(
"if", condition, _ensure_node(true_val), _ensure_node(false_val)
)
def where(
condition: Union[Node, str],
true_val: Union[Node, str, int, float],
false_val: Union[Node, str, int, float],
) -> FunctionNode:
"""条件选择if_ 的别名)。
Args:
condition: 条件表达式或字段名字符串
true_val: 条件为真时的值(因子表达式、字段名字符串或数值)
false_val: 条件为假时的值(因子表达式、字段名字符串或数值)
Returns:
FunctionNode: 函数调用节点
"""
return if_(condition, true_val, false_val)
# ==================== 补充的时间序列函数 ====================
def ts_sma(x: Union[Node, str], n: int, m: int) -> FunctionNode:
"""国内平滑移动平均 (Simple Moving Average)。
使用 N*M 平滑,公式: SMA = (M*CLOSE + (N-M)*REF(CLOSE,1)) / N
对应国泰君安 191 因子的 SMA 函数。
Args:
x: 输入因子表达式或字段名字符串
n: 第一个参数,对应公式中的 N
m: 第二个参数,对应公式中的 M
Returns:
FunctionNode: 函数调用节点
"""
return FunctionNode("ts_sma", x, n, m)
def ts_wma(x: Union[Node, str], window: int) -> FunctionNode:
"""线性加权移动平均 (Weighted Moving Average)。
权重按线性递增,最近一天权重最大。
Args:
x: 输入因子表达式或字段名字符串
window: 滚动窗口大小
Returns:
FunctionNode: 函数调用节点
"""
return FunctionNode("ts_wma", x, window)
def ts_decay_linear(x: Union[Node, str], window: int) -> FunctionNode:
"""线性衰减移动平均 (Decay Linear)。
与 WMA 等价,权重按线性递增(近期权重最大)。
Args:
x: 输入因子表达式或字段名字符串
window: 滚动窗口大小
Returns:
FunctionNode: 函数调用节点
"""
return FunctionNode("ts_decay_linear", x, window)
def ts_argmax(x: Union[Node, str], window: int) -> FunctionNode:
"""时间序列最大值位置 (HIGHDAY)。
返回距离过去 window 期内最大值的交易日数。
例如:今天是最高点返回 0昨天是最高点返回 1。
Args:
x: 输入因子表达式或字段名字符串
window: 滚动窗口大小
Returns:
FunctionNode: 函数调用节点
"""
return FunctionNode("ts_argmax", x, window)
def ts_argmin(x: Union[Node, str], window: int) -> FunctionNode:
"""时间序列最小值位置 (LOWDAY)。
返回距离过去 window 期内最小值的交易日数。
例如:今天是最低点返回 0昨天是最低点返回 1。
Args:
x: 输入因子表达式或字段名字符串
window: 滚动窗口大小
Returns:
FunctionNode: 函数调用节点
"""
return FunctionNode("ts_argmin", x, window)
def ts_count(condition: Union[Node, str], window: int) -> FunctionNode:
"""窗口内条件为真的天数。
统计过去 window 内满足条件的交易日数量。
Args:
condition: 条件表达式或字段名字符串(布尔型)
window: 滚动窗口大小
Returns:
FunctionNode: 函数调用节点
"""
return FunctionNode("ts_count", condition, window)
def ts_prod(x: Union[Node, str], window: int) -> FunctionNode:
"""窗口期内的连乘积。
计算过去 window 期因子值的乘积。
Args:
x: 输入因子表达式或字段名字符串
window: 滚动窗口大小
Returns:
FunctionNode: 函数调用节点
"""
return FunctionNode("ts_prod", x, window)
def ts_sumac(x: Union[Node, str]) -> FunctionNode:
"""累计求和。
从序列开始到当前位置的累计求和。
Args:
x: 输入因子表达式或字段名字符串
Returns:
FunctionNode: 函数调用节点
"""
return FunctionNode("ts_sumac", x)