From 1520c2a51ec6f326844af01385229c02653f5b05 Mon Sep 17 00:00:00 2001 From: liaozhaorun <1300336796@qq.com> Date: Sat, 7 Mar 2026 01:03:49 +0800 Subject: [PATCH] =?UTF-8?q?feat(factors):=20=E6=96=B0=E5=A2=9E=20Phase=201?= =?UTF-8?q?-2=20=E6=95=B0=E5=AD=A6=E5=92=8C=E7=BB=9F=E8=AE=A1=E5=9B=A0?= =?UTF-8?q?=E5=AD=90=E5=87=BD=E6=95=B0=20-=20=E6=96=B0=E5=A2=9E=20atan,=20?= =?UTF-8?q?log1p=20=E6=95=B0=E5=AD=A6=E5=87=BD=E6=95=B0=20-=20=E6=96=B0?= =?UTF-8?q?=E5=A2=9E=20ts=5Fvar,=20ts=5Fskew,=20ts=5Fkurt,=20ts=5Fpct=5Fch?= =?UTF-8?q?ange,=20ts=5Fema=20=E7=BB=9F=E8=AE=A1=E5=87=BD=E6=95=B0=20-=20?= =?UTF-8?q?=E6=96=B0=E5=A2=9E=20ts=5Fatr,=20ts=5Frsi,=20ts=5Fobv=20TA-Lib?= =?UTF-8?q?=20=E6=8A=80=E6=9C=AF=E6=8C=87=E6=A0=87=E5=87=BD=E6=95=B0=20-?= =?UTF-8?q?=20=E6=96=B0=E5=A2=9E=E5=AE=8C=E6=95=B4=E9=9B=86=E6=88=90?= =?UTF-8?q?=E6=B5=8B=E8=AF=95=E8=A6=86=E7=9B=96=E6=89=80=E6=9C=89=E6=96=B0?= =?UTF-8?q?=E5=87=BD=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/factors/api.py | 152 +++++++++ src/factors/translator.py | 173 +++++++++++ tests/test_phase1_2_factors.py | 541 +++++++++++++++++++++++++++++++++ 3 files changed, 866 insertions(+) create mode 100644 tests/test_phase1_2_factors.py diff --git a/src/factors/api.py b/src/factors/api.py index 83435d7..229754e 100644 --- a/src/factors/api.py +++ b/src/factors/api.py @@ -190,6 +190,130 @@ def ts_cov(x: Union[Node, str], y: Union[Node, str], window: int) -> FunctionNod return FunctionNode("ts_cov", x, y, window) +def ts_var(x: Union[Node, str], window: int) -> FunctionNode: + """时间序列方差。 + + 计算给定因子在滚动窗口内的方差。 + + Args: + x: 输入因子表达式或字段名字符串 + window: 滚动窗口大小 + + Returns: + FunctionNode: 函数调用节点 + """ + return FunctionNode("ts_var", x, window) + + +def ts_skew(x: Union[Node, str], window: int) -> FunctionNode: + """时间序列偏度。 + + 计算给定因子在滚动窗口内的偏度(三阶矩)。 + + Args: + x: 输入因子表达式或字段名字符串 + window: 滚动窗口大小 + + Returns: + FunctionNode: 函数调用节点 + """ + return FunctionNode("ts_skew", x, window) + + +def ts_kurt(x: Union[Node, str], window: int) -> FunctionNode: + """时间序列峰度。 + + 计算给定因子在滚动窗口内的峰度(四阶矩)。 + + Args: + x: 输入因子表达式或字段名字符串 + window: 滚动窗口大小 + + Returns: + FunctionNode: 函数调用节点 + """ + return FunctionNode("ts_kurt", x, window) + + +def ts_pct_change(x: Union[Node, str], periods: int) -> FunctionNode: + """时间序列百分比变化。 + + 计算给定因子与 N 个周期前的百分比变化:(x - x.shift(n)) / x.shift(n)。 + + Args: + x: 输入因子表达式或字段名字符串 + periods: 滞后期数 + + Returns: + FunctionNode: 函数调用节点 + """ + return FunctionNode("ts_pct_change", x, periods) + + +def ts_ema(x: Union[Node, str], window: int) -> FunctionNode: + """指数移动平均。 + + 计算给定因子的指数移动平均值。 + + Args: + x: 输入因子表达式或字段名字符串 + window: 指数移动平均的 span 参数 + + Returns: + FunctionNode: 函数调用节点 + """ + return FunctionNode("ts_ema", x, window) + + +def ts_atr( + high: Union[Node, str], low: Union[Node, str], close: Union[Node, str], window: int +) -> FunctionNode: + """平均真实波幅 (Average True Range)。 + + 计算给定窗口内的平均真实波幅,使用 TA-Lib 实现。 + + Args: + high: 最高价表达式或字段名字符串 + low: 最低价表达式或字段名字符串 + close: 收盘价表达式或字段名字符串 + window: 滚动窗口大小 + + Returns: + FunctionNode: 函数调用节点 + """ + return FunctionNode("ts_atr", high, low, close, window) + + +def ts_rsi(close: Union[Node, str], window: int) -> FunctionNode: + """相对强弱指数 (Relative Strength Index)。 + + 计算给定窗口内的 RSI 值,使用 TA-Lib 实现。 + + Args: + close: 收盘价表达式或字段名字符串 + window: 滚动窗口大小 + + Returns: + FunctionNode: 函数调用节点 + """ + return FunctionNode("ts_rsi", close, window) + + +def ts_obv(close: Union[Node, str], volume: Union[Node, str]) -> FunctionNode: + """能量潮指标 (On Balance Volume)。 + + 计算 OBV 值,使用 TA-Lib 实现。 + + Args: + close: 收盘价表达式或字段名字符串 + volume: 成交量表达式或字段名字符串 + + Returns: + FunctionNode: 函数调用节点 + """ + return FunctionNode("ts_obv", close, volume) + + def ts_rank(x: Union[Node, str], window: int) -> FunctionNode: """时间序列排名。 @@ -429,6 +553,34 @@ def clip( return FunctionNode("clip", x, _ensure_node(lower), _ensure_node(upper)) +def atan(x: Union[Node, str]) -> FunctionNode: + """反正切函数。 + + 计算输入值的反正切值(弧度)。 + + Args: + x: 输入因子表达式或字段名字符串 + + Returns: + FunctionNode: 函数调用节点 + """ + return FunctionNode("atan", x) + + +def log1p(x: Union[Node, str]) -> FunctionNode: + """log(1+x) 函数。 + + 计算 log(1+x),对 x 接近 0 的情况更精确。 + + Args: + x: 输入因子表达式或字段名字符串 + + Returns: + FunctionNode: 函数调用节点 + """ + return FunctionNode("log1p", x) + + # ==================== 条件函数 ==================== diff --git a/src/factors/translator.py b/src/factors/translator.py index 3bc20a5..2c79443 100644 --- a/src/factors/translator.py +++ b/src/factors/translator.py @@ -6,8 +6,18 @@ from typing import Any, Callable, Dict +import numpy as np import polars as pl +# TA-Lib 可选依赖 +try: + import talib + + HAS_TALIB = True +except ImportError: + HAS_TALIB = False + talib = None + from src.factors.decorators import cross_section, element_wise, time_series from src.factors.dsl import ( BinaryOpNode, @@ -53,6 +63,14 @@ class PolarsTranslator: self.register_handler("ts_delta", self._handle_ts_delta) self.register_handler("ts_corr", self._handle_ts_corr) self.register_handler("ts_cov", self._handle_ts_cov) + self.register_handler("ts_var", self._handle_ts_var) + self.register_handler("ts_skew", self._handle_ts_skew) + self.register_handler("ts_kurt", self._handle_ts_kurt) + self.register_handler("ts_pct_change", self._handle_ts_pct_change) + self.register_handler("ts_ema", self._handle_ts_ema) + self.register_handler("ts_atr", self._handle_ts_atr) + self.register_handler("ts_rsi", self._handle_ts_rsi) + self.register_handler("ts_obv", self._handle_ts_obv) # 截面因子处理器 (cs_*) self.register_handler("cs_rank", self._handle_cs_rank) @@ -66,6 +84,8 @@ class PolarsTranslator: self.register_handler("sign", self._handle_sign) self.register_handler("cos", self._handle_cos) self.register_handler("sin", self._handle_sin) + self.register_handler("atan", self._handle_atan) + self.register_handler("log1p", self._handle_log1p) def register_handler( self, func_name: str, handler: Callable[[FunctionNode], pl.Expr] @@ -295,6 +315,143 @@ class PolarsTranslator: window = self._extract_window(node.args[2]) return x.rolling_cov(y, window_size=window) + @time_series + def _handle_ts_var(self, node: FunctionNode) -> pl.Expr: + """处理 ts_var(close, window) -> rolling_var(window)。""" + if len(node.args) != 2: + raise ValueError("ts_var 需要 2 个参数: (expr, window)") + expr = self.translate(node.args[0]) + window = self._extract_window(node.args[1]) + return expr.rolling_var(window_size=window) + + @time_series + def _handle_ts_skew(self, node: FunctionNode) -> pl.Expr: + """处理 ts_skew(close, window) -> rolling_skew(window)。""" + if len(node.args) != 2: + raise ValueError("ts_skew 需要 2 个参数: (expr, window)") + expr = self.translate(node.args[0]) + window = self._extract_window(node.args[1]) + return expr.rolling_skew(window_size=window) + + @time_series + def _handle_ts_kurt(self, node: FunctionNode) -> pl.Expr: + """处理 ts_kurt(close, window) -> rolling_kurt(window)。""" + if len(node.args) != 2: + raise ValueError("ts_kurt 需要 2 个参数: (expr, window)") + expr = self.translate(node.args[0]) + window = self._extract_window(node.args[1]) + # 使用 rolling_map 计算峰度 + return expr.rolling_map( + lambda s: s.kurtosis() if len(s.drop_nulls()) >= 4 else float("nan"), + window_size=window, + ) + + @time_series + def _handle_ts_pct_change(self, node: FunctionNode) -> pl.Expr: + """处理 ts_pct_change(x, n) -> (x - shift(n)) / shift(n)。""" + if len(node.args) != 2: + raise ValueError("ts_pct_change 需要 2 个参数: (expr, periods)") + expr = self.translate(node.args[0]) + n = self._extract_window(node.args[1]) + shifted = expr.shift(n) + return (expr - shifted) / shifted + + @time_series + def _handle_ts_ema(self, node: FunctionNode) -> pl.Expr: + """处理 ts_ema(x, window) -> ewm_mean(span=window)。""" + if len(node.args) != 2: + raise ValueError("ts_ema 需要 2 个参数: (expr, window)") + expr = self.translate(node.args[0]) + window = self._extract_window(node.args[1]) + return expr.ewm_mean(span=window) + + @time_series + def _handle_ts_atr(self, node: FunctionNode) -> pl.Expr: + """处理 ts_atr(high, low, close, window) -> 使用 TA-Lib 计算 ATR。 + + 使用 map_batches 在每个分组上应用 TA-Lib ATR 函数。 + @time_series 装饰器会自动添加 .over("ts_code") + """ + if not HAS_TALIB: + raise ImportError("ts_atr 需要安装 TA-Lib。请运行: pip install TA-Lib") + if len(node.args) != 4: + raise ValueError("ts_atr 需要 4 个参数: (high, low, close, window)") + + high = self.translate(node.args[0]) + low = self.translate(node.args[1]) + close = self.translate(node.args[2]) + window = self._extract_window(node.args[3]) + + # 使用 map_batches 应用 TA-Lib ATR 到整个分组 + def calc_atr(struct_series: pl.Series) -> pl.Series: + """计算 ATR 的辅助函数。""" + if len(struct_series) == 0: + return pl.Series([float("nan")] * len(struct_series)) + + # struct_series 包含 h, l, c 三个字段 + h = np.array(struct_series.struct.field("h").to_list(), dtype=float) + l = np.array(struct_series.struct.field("l").to_list(), dtype=float) + c = np.array(struct_series.struct.field("c").to_list(), dtype=float) + result = talib.ATR(h, l, c, timeperiod=window) + return pl.Series(result) + + return pl.struct( + [high.alias("h"), low.alias("l"), close.alias("c")] + ).map_batches(calc_atr) + + @time_series + def _handle_ts_rsi(self, node: FunctionNode) -> pl.Expr: + """处理 ts_rsi(close, window) -> 使用 TA-Lib 计算 RSI。 + + 使用 map_batches 在每个分组上应用 TA-Lib RSI 函数。 + @time_series 装饰器会自动添加 .over("ts_code") + """ + if not HAS_TALIB: + raise ImportError("ts_rsi 需要安装 TA-Lib。请运行: pip install TA-Lib") + if len(node.args) != 2: + raise ValueError("ts_rsi 需要 2 个参数: (close, window)") + + close = self.translate(node.args[0]) + window = self._extract_window(node.args[1]) + + # 使用 map_batches 应用 TA-Lib RSI 到整个分组 + def calc_rsi(series: pl.Series) -> pl.Series: + """计算 RSI 的辅助函数。""" + values = np.array(series.to_list(), dtype=float) + result = talib.RSI(values, timeperiod=window) + return pl.Series(result) + + return close.map_batches(calc_rsi) + + @time_series + def _handle_ts_obv(self, node: FunctionNode) -> pl.Expr: + """处理 ts_obv(close, volume) -> 使用 TA-Lib 计算 OBV。 + + 使用 map_batches 在每个分组上应用 TA-Lib OBV 函数。 + @time_series 装饰器会自动添加 .over("ts_code") + """ + if not HAS_TALIB: + raise ImportError("ts_obv 需要安装 TA-Lib。请运行: pip install TA-Lib") + if len(node.args) != 2: + raise ValueError("ts_obv 需要 2 个参数: (close, volume)") + + close = self.translate(node.args[0]) + volume = self.translate(node.args[1]) + + # 使用 map_batches 应用 TA-Lib OBV 到整个分组 + def calc_obv(struct_series: pl.Series) -> pl.Series: + """计算 OBV 的辅助函数。""" + if len(struct_series) == 0: + return pl.Series([float("nan")] * len(struct_series)) + + # struct_series 包含 c 和 v 两个字段 + c = np.array(struct_series.struct.field("c").to_list(), dtype=float) + v = np.array(struct_series.struct.field("v").to_list(), dtype=float) + result = talib.OBV(c, v) + return pl.Series(result) + + return pl.struct([close.alias("c"), volume.alias("v")]).map_batches(calc_obv) + # ==================== 截面因子处理器 (cs_*) ==================== # 所有截面因子使用 @cross_section 装饰器自动注入 over("trade_date") 防串表 @@ -377,6 +534,22 @@ class PolarsTranslator: expr = self.translate(node.args[0]) return expr.sin() + @element_wise + def _handle_atan(self, node: FunctionNode) -> pl.Expr: + """处理 atan(expr) -> 反正切函数。""" + if len(node.args) != 1: + raise ValueError("atan 需要 1 个参数: (expr)") + expr = self.translate(node.args[0]) + return expr.arctan() + + @element_wise + def _handle_log1p(self, node: FunctionNode) -> pl.Expr: + """处理 log1p(expr) -> log(1+x) 函数。""" + if len(node.args) != 1: + raise ValueError("log1p 需要 1 个参数: (expr)") + expr = self.translate(node.args[0]) + return expr.log1p() + # ==================== 辅助方法 ==================== def _extract_window(self, node: Node) -> int: diff --git a/tests/test_phase1_2_factors.py b/tests/test_phase1_2_factors.py new file mode 100644 index 0000000..acc464c --- /dev/null +++ b/tests/test_phase1_2_factors.py @@ -0,0 +1,541 @@ +"""Phase 1-2 因子函数集成测试。 + +测试所有新实现的函数,使用字符串因子表达式形式计算因子, +并与原始 Polars 计算结果进行对比。 + +测试范围: +1. 数学函数:atan, log1p +2. 统计函数:ts_var, ts_skew, ts_kurt, ts_pct_change, ts_ema +3. TA-Lib 函数:ts_atr, ts_rsi, ts_obv +""" + +import numpy as np +import polars as pl +import pytest + +from src.factors import FormulaParser, FunctionRegistry +from src.factors.translator import PolarsTranslator, HAS_TALIB +from src.factors.engine import FactorEngine +from src.data.catalog import DatabaseCatalog + + +# ============== 测试数据准备 ============== + + +def create_test_data() -> pl.DataFrame: + """创建测试用的模拟数据。 + + 创建一个包含多只股票、多个交易日的 DataFrame, + 用于测试因子函数的计算。 + """ + np.random.seed(42) + + dates = pl.date_range( + start=pl.date(2024, 1, 1), + end=pl.date(2024, 1, 31), + interval="1d", + eager=True, + ) + + stocks = ["000001.SZ", "000002.SZ", "600000.SH", "600001.SH"] + + data = [] + for stock in stocks: + base_price = 100 + np.random.randn() * 10 + for i, date in enumerate(dates): + price = base_price + np.random.randn() * 5 + i * 0.1 + data.append( + { + "ts_code": stock, + "trade_date": date, + "close": price, + "open": price * (1 + np.random.randn() * 0.01), + "high": price * (1 + abs(np.random.randn()) * 0.02), + "low": price * (1 - abs(np.random.randn()) * 0.02), + "vol": int(1000000 + np.random.randn() * 500000), + } + ) + + return pl.DataFrame(data) + + +# ============== 数学函数测试 ============== + + +def test_atan_function(): + """测试 atan 函数:计算反正切值。""" + parser = FormulaParser(FunctionRegistry()) + + # 创建测试数据 + df = pl.DataFrame( + { + "ts_code": ["A"] * 5, + "trade_date": pl.date_range( + pl.date(2024, 1, 1), pl.date(2024, 1, 5), eager=True + ), + "value": [0.0, 1.0, -1.0, 0.5, -0.5], + } + ) + + # DSL 计算 + expr = parser.parse("atan(value)") + translator = PolarsTranslator() + polars_expr = translator.translate(expr) + result_dsl = df.with_columns(dsl_result=polars_expr).to_pandas()["dsl_result"] + + # 原始 Polars 计算 + result_pl = df.with_columns(pl_result=pl.col("value").arctan()).to_pandas()[ + "pl_result" + ] + + # 对比结果 + np.testing.assert_array_almost_equal( + result_dsl.values, result_pl.values, decimal=10 + ) + + +def test_log1p_function(): + """测试 log1p 函数:计算 log(1+x)。""" + parser = FormulaParser(FunctionRegistry()) + + # 创建测试数据 + df = pl.DataFrame( + { + "ts_code": ["A"] * 5, + "trade_date": pl.date_range( + pl.date(2024, 1, 1), pl.date(2024, 1, 5), eager=True + ), + "value": [0.0, 0.1, -0.1, 1.0, -0.5], + } + ) + + # DSL 计算 + expr = parser.parse("log1p(value)") + translator = PolarsTranslator() + polars_expr = translator.translate(expr) + result_dsl = df.with_columns(dsl_result=polars_expr).to_pandas()["dsl_result"] + + # 原始 Polars 计算 + result_pl = df.with_columns(pl_result=pl.col("value").log1p()).to_pandas()[ + "pl_result" + ] + + # 对比结果 + np.testing.assert_array_almost_equal( + result_dsl.values, result_pl.values, decimal=10 + ) + + +# ============== 统计函数测试 ============== + + +def test_ts_var_function(): + """测试 ts_var 函数:滚动方差。""" + parser = FormulaParser(FunctionRegistry()) + + # 创建测试数据 + df = pl.DataFrame( + { + "ts_code": ["A"] * 10 + ["B"] * 10, + "trade_date": pl.date_range( + pl.date(2024, 1, 1), pl.date(2024, 1, 10), eager=True + ).append( + pl.date_range(pl.date(2024, 1, 1), pl.date(2024, 1, 10), eager=True) + ), + "close": list(range(1, 11)) + list(range(10, 20)), + } + ) + + # DSL 计算 + expr = parser.parse("ts_var(close, 5)") + translator = PolarsTranslator() + polars_expr = translator.translate(expr) + result_dsl = ( + df.with_columns(dsl_result=polars_expr) + .to_pandas() + .groupby("ts_code")["dsl_result"] + .apply(list) + ) + + # 原始 Polars 计算 + result_pl = ( + df.with_columns( + pl_result=pl.col("close").rolling_var(window_size=5).over("ts_code") + ) + .to_pandas() + .groupby("ts_code")["pl_result"] + .apply(list) + ) + + # 对比结果 + for stock in ["A", "B"]: + np.testing.assert_array_almost_equal( + result_dsl[stock], result_pl[stock], decimal=10 + ) + + +def test_ts_skew_function(): + """测试 ts_skew 函数:滚动偏度。""" + parser = FormulaParser(FunctionRegistry()) + + # 创建测试数据 + np.random.seed(42) + df = pl.DataFrame( + { + "ts_code": ["A"] * 20 + ["B"] * 20, + "trade_date": list( + pl.date_range(pl.date(2024, 1, 1), pl.date(2024, 1, 20), eager=True) + ) + * 2, + "close": np.random.randn(40), + } + ) + + # DSL 计算 + expr = parser.parse("ts_skew(close, 10)") + translator = PolarsTranslator() + polars_expr = translator.translate(expr) + result_dsl = df.with_columns(dsl_result=polars_expr).to_pandas()["dsl_result"] + + # 原始 Polars 计算 + result_pl = df.with_columns( + pl_result=pl.col("close").rolling_skew(window_size=10).over("ts_code") + ).to_pandas()["pl_result"] + + # 对比结果 + np.testing.assert_array_almost_equal( + result_dsl.values, result_pl.values, decimal=10 + ) + + +def test_ts_kurt_function(): + """测试 ts_kurt 函数:滚动峰度。""" + parser = FormulaParser(FunctionRegistry()) + + # 创建测试数据 + np.random.seed(42) + df = pl.DataFrame( + { + "ts_code": ["A"] * 20 + ["B"] * 20, + "trade_date": list( + pl.date_range(pl.date(2024, 1, 1), pl.date(2024, 1, 20), eager=True) + ) + * 2, + "close": np.random.randn(40), + } + ) + + # DSL 计算 + expr = parser.parse("ts_kurt(close, 10)") + translator = PolarsTranslator() + polars_expr = translator.translate(expr) + result_dsl = df.with_columns(dsl_result=polars_expr).to_pandas()["dsl_result"] + + # 原始 Polars 计算 + result_pl = df.with_columns( + pl_result=pl.col("close") + .rolling_map( + lambda s: s.kurtosis() if len(s.drop_nulls()) >= 4 else float("nan"), + window_size=10, + ) + .over("ts_code") + ).to_pandas()["pl_result"] + + # 对比结果 + np.testing.assert_array_almost_equal( + result_dsl.values, result_pl.values, decimal=10 + ) + + +def test_ts_pct_change_function(): + """测试 ts_pct_change 函数:百分比变化。""" + parser = FormulaParser(FunctionRegistry()) + + # 创建测试数据 + df = pl.DataFrame( + { + "ts_code": ["A"] * 5 + ["B"] * 5, + "trade_date": list( + pl.date_range(pl.date(2024, 1, 1), pl.date(2024, 1, 5), eager=True) + ) + * 2, + "close": [100, 105, 102, 108, 110, 50, 52, 48, 55, 60], + } + ) + + # DSL 计算 + expr = parser.parse("ts_pct_change(close, 1)") + translator = PolarsTranslator() + polars_expr = translator.translate(expr) + result_dsl = df.with_columns(dsl_result=polars_expr).to_pandas()["dsl_result"] + + # 原始 Polars 计算 + result_pl = df.with_columns( + pl_result=(pl.col("close") - pl.col("close").shift(1)) + / pl.col("close").shift(1).over("ts_code") + ).to_pandas()["pl_result"] + + # 对比结果 + np.testing.assert_array_almost_equal( + result_dsl.values, result_pl.values, decimal=10 + ) + + +def test_ts_ema_function(): + """测试 ts_ema 函数:指数移动平均。""" + parser = FormulaParser(FunctionRegistry()) + + # 创建测试数据 + df = pl.DataFrame( + { + "ts_code": ["A"] * 10 + ["B"] * 10, + "trade_date": list( + pl.date_range(pl.date(2024, 1, 1), pl.date(2024, 1, 10), eager=True) + ) + * 2, + "close": list(range(1, 11)) + list(range(10, 20)), + } + ) + + # DSL 计算 + expr = parser.parse("ts_ema(close, 5)") + translator = PolarsTranslator() + polars_expr = translator.translate(expr) + result_dsl = df.with_columns(dsl_result=polars_expr).to_pandas()["dsl_result"] + + # 原始 Polars 计算 + result_pl = df.with_columns( + pl_result=pl.col("close").ewm_mean(span=5).over("ts_code") + ).to_pandas()["pl_result"] + + # 对比结果 + np.testing.assert_array_almost_equal( + result_dsl.values, result_pl.values, decimal=10 + ) + + +# ============== TA-Lib 函数测试 ============== + + +@pytest.mark.skipif(not HAS_TALIB, reason="TA-Lib not installed") +def test_ts_atr_function(): + """测试 ts_atr 函数:平均真实波幅。""" + import talib + + parser = FormulaParser(FunctionRegistry()) + + # 创建测试数据 + np.random.seed(42) + df = pl.DataFrame( + { + "ts_code": ["A"] * 20 + ["B"] * 20, + "trade_date": list( + pl.date_range(pl.date(2024, 1, 1), pl.date(2024, 1, 20), eager=True) + ) + * 2, + "high": 100 + np.random.randn(40) * 2, + "low": 98 + np.random.randn(40) * 2, + "close": 99 + np.random.randn(40) * 2, + } + ) + + # DSL 计算 + expr = parser.parse("ts_atr(high, low, close, 14)") + translator = PolarsTranslator() + polars_expr = translator.translate(expr) + result_dsl = df.with_columns(dsl_result=polars_expr).to_pandas()["dsl_result"] + + # 使用 talib 手动计算(分组计算) + result_expected = [] + for stock in ["A", "B"]: + stock_df = df.filter(pl.col("ts_code") == stock).to_pandas() + atr = talib.ATR( + stock_df["high"].values, + stock_df["low"].values, + stock_df["close"].values, + timeperiod=14, + ) + result_expected.extend(atr) + + # 对比结果(允许小误差) + np.testing.assert_array_almost_equal( + result_dsl.values, np.array(result_expected), decimal=5 + ) + + +@pytest.mark.skipif(not HAS_TALIB, reason="TA-Lib not installed") +def test_ts_rsi_function(): + """测试 ts_rsi 函数:相对强弱指数。""" + import talib + + parser = FormulaParser(FunctionRegistry()) + + # 创建测试数据 + np.random.seed(42) + df = pl.DataFrame( + { + "ts_code": ["A"] * 30 + ["B"] * 30, + "trade_date": list( + pl.date_range(pl.date(2024, 1, 1), pl.date(2024, 1, 30), eager=True) + ) + * 2, + "close": 100 + np.cumsum(np.random.randn(60)), + } + ) + + # DSL 计算 + expr = parser.parse("ts_rsi(close, 14)") + translator = PolarsTranslator() + polars_expr = translator.translate(expr) + result_dsl = df.with_columns(dsl_result=polars_expr).to_pandas()["dsl_result"] + + # 使用 talib 手动计算(分组计算) + result_expected = [] + for stock in ["A", "B"]: + stock_df = df.filter(pl.col("ts_code") == stock).to_pandas() + rsi = talib.RSI(stock_df["close"].values, timeperiod=14) + result_expected.extend(rsi) + + # 对比结果(允许小误差) + np.testing.assert_array_almost_equal( + result_dsl.values, np.array(result_expected), decimal=5 + ) + + +@pytest.mark.skipif(not HAS_TALIB, reason="TA-Lib not installed") +def test_ts_obv_function(): + """测试 ts_obv 函数:能量潮指标。""" + import talib + + parser = FormulaParser(FunctionRegistry()) + + # 创建测试数据 + np.random.seed(42) + df = pl.DataFrame( + { + "ts_code": ["A"] * 20 + ["B"] * 20, + "trade_date": list( + pl.date_range(pl.date(2024, 1, 1), pl.date(2024, 1, 20), eager=True) + ) + * 2, + "close": 100 + np.cumsum(np.random.randn(40)), + "vol": np.random.randint(100000, 1000000, 40).astype(float), + } + ) + + # DSL 计算 + expr = parser.parse("ts_obv(close, vol)") + translator = PolarsTranslator() + polars_expr = translator.translate(expr) + result_dsl = df.with_columns(dsl_result=polars_expr).to_pandas()["dsl_result"] + + # 使用 talib 手动计算(分组计算) + result_expected = [] + for stock in ["A", "B"]: + stock_df = df.filter(pl.col("ts_code") == stock).to_pandas() + obv = talib.OBV( + stock_df["close"].values, + stock_df["vol"].values, + ) + result_expected.extend(obv) + + # 对比结果(允许小误差) + np.testing.assert_array_almost_equal( + result_dsl.values, np.array(result_expected), decimal=5 + ) + + +# ============== 综合测试 ============== + + +def test_complex_factor_expressions(): + """测试复杂因子表达式的计算。""" + parser = FormulaParser(FunctionRegistry()) + + # 创建测试数据 + np.random.seed(42) + df = pl.DataFrame( + { + "ts_code": ["A"] * 30 + ["B"] * 30, + "trade_date": list( + pl.date_range(pl.date(2024, 1, 1), pl.date(2024, 1, 30), eager=True) + ) + * 2, + "close": 100 + np.cumsum(np.random.randn(60)), + } + ) + + # 测试 act_factor1: atan((ts_ema(close,5)/ts_delay(ts_ema(close,5),1)-1)*100) * 57.3 / 50 + expr = parser.parse( + "atan((ts_ema(close, 5) / ts_delay(ts_ema(close, 5), 1) - 1) * 100) * 57.3 / 50" + ) + translator = PolarsTranslator() + polars_expr = translator.translate(expr) + result = df.with_columns(factor=polars_expr) + + # 验证结果不为空 + assert len(result) == 60 + assert "factor" in result.columns + + print("复杂因子表达式测试通过") + + +# ============== 主函数 ============== + + +if __name__ == "__main__": + print("运行 Phase 1-2 因子函数测试...") + print("=" * 80) + + # 运行数学函数测试 + print("\n[数学函数测试]") + test_atan_function() + print(" ✅ atan 测试通过") + + test_log1p_function() + print(" ✅ log1p 测试通过") + + # 运行统计函数测试 + print("\n[统计函数测试]") + test_ts_var_function() + print(" ✅ ts_var 测试通过") + + test_ts_skew_function() + print(" ✅ ts_skew 测试通过") + + test_ts_kurt_function() + print(" ✅ ts_kurt 测试通过") + + test_ts_pct_change_function() + print(" ✅ ts_pct_change 测试通过") + + test_ts_ema_function() + print(" ✅ ts_ema 测试通过") + + # 运行 TA-Lib 函数测试 + print("\n[TA-Lib 函数测试]") + try: + import talib + + HAS_TALIB = True + except ImportError: + HAS_TALIB = False + print(" ⚠️ TA-Lib 未安装,跳过 TA-Lib 测试") + + if HAS_TALIB: + test_ts_atr_function() + print(" ✅ ts_atr 测试通过") + + test_ts_rsi_function() + print(" ✅ ts_rsi 测试通过") + + test_ts_obv_function() + print(" ✅ ts_obv 测试通过") + + # 运行综合测试 + print("\n[综合测试]") + test_complex_factor_expressions() + print(" ✅ 复杂因子表达式测试通过") + + print("\n" + "=" * 80) + print("所有测试通过!")