Files
ProStock/tests/test_phase1_2_factors.py
liaozhaorun 1520c2a51e feat(factors): 新增 Phase 1-2 数学和统计因子函数
- 新增 atan, log1p 数学函数
- 新增 ts_var, ts_skew, ts_kurt, ts_pct_change, ts_ema 统计函数
- 新增 ts_atr, ts_rsi, ts_obv TA-Lib 技术指标函数
- 新增完整集成测试覆盖所有新函数
2026-03-07 01:03:49 +08:00

542 lines
15 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Phase 1-2 因子函数集成测试。
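Every newly implemented function is evaluated through its string factor
expression and compared against a reference computed directly with Polars
(or with TA-Lib for the technical indicators).

Scope:
1. Math functions: atan, log1p
2. Statistical functions: ts_var, ts_skew, ts_kurt, ts_pct_change, ts_ema
3. TA-Lib functions: ts_atr, ts_rsi, ts_obv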
"""
import numpy as np
import polars as pl
import pytest
from src.factors import FormulaParser, FunctionRegistry
from src.factors.translator import PolarsTranslator, HAS_TALIB
from src.factors.engine import FactorEngine
from src.data.catalog import DatabaseCatalog
# ============== 测试数据准备 ==============
def create_test_data() -> pl.DataFrame:
    """Build a synthetic multi-stock daily price panel for factor tests.

    Four tickers are generated for every calendar day of January 2024. The
    NumPy global random state is re-seeded with 42 on every call, so repeated
    calls produce the same frame.

    Returns:
        A DataFrame with the columns ``ts_code``, ``trade_date``, ``close``,
        ``open``, ``high``, ``low`` and ``vol``.
    """
    np.random.seed(42)
    trading_days = pl.date_range(
        start=pl.date(2024, 1, 1),
        end=pl.date(2024, 1, 31),
        interval="1d",
        eager=True,
    )

    rows = []
    for ts_code in ("000001.SZ", "000002.SZ", "600000.SH", "600001.SH"):
        # Every ticker starts from its own randomly drawn price level.
        anchor = 100 + np.random.randn() * 10
        for day_index, trade_date in enumerate(trading_days):
            # The draws below must stay in this exact order (close, open,
            # high, low, vol) so the seeded stream yields the same values.
            close = anchor + np.random.randn() * 5 + day_index * 0.1
            open_ = close * (1 + np.random.randn() * 0.01)
            high = close * (1 + abs(np.random.randn()) * 0.02)
            low = close * (1 - abs(np.random.randn()) * 0.02)
            vol = int(1000000 + np.random.randn() * 500000)
            rows.append(
                {
                    "ts_code": ts_code,
                    "trade_date": trade_date,
                    "close": close,
                    "open": open_,
                    "high": high,
                    "low": low,
                    "vol": vol,
                }
            )
    return pl.DataFrame(rows)
# ============== 数学函数测试 ==============
def test_atan_function():
    """Check that the DSL ``atan`` agrees with Polars' ``arctan``."""
    parser = FormulaParser(FunctionRegistry())

    # Single-stock frame covering zero, positive and negative inputs.
    frame = pl.DataFrame(
        {
            "ts_code": ["A"] * 5,
            "trade_date": pl.date_range(
                pl.date(2024, 1, 1), pl.date(2024, 1, 5), eager=True
            ),
            "value": [0.0, 1.0, -1.0, 0.5, -0.5],
        }
    )

    # Value produced by the string-expression pipeline.
    ast = parser.parse("atan(value)")
    dsl_expr = PolarsTranslator().translate(ast)
    actual = frame.with_columns(dsl_result=dsl_expr).to_pandas()["dsl_result"]

    # Reference written directly against the Polars API.
    reference_expr = pl.col("value").arctan()
    expected = frame.with_columns(pl_result=reference_expr).to_pandas()["pl_result"]

    np.testing.assert_array_almost_equal(actual.values, expected.values, decimal=10)
def test_log1p_function():
    """Check that the DSL ``log1p`` agrees with Polars' ``log1p`` (log(1 + x))."""
    parser = FormulaParser(FunctionRegistry())

    # Single-stock frame; every input is greater than -1.
    frame = pl.DataFrame(
        {
            "ts_code": ["A"] * 5,
            "trade_date": pl.date_range(
                pl.date(2024, 1, 1), pl.date(2024, 1, 5), eager=True
            ),
            "value": [0.0, 0.1, -0.1, 1.0, -0.5],
        }
    )

    # Value produced by the string-expression pipeline.
    ast = parser.parse("log1p(value)")
    dsl_expr = PolarsTranslator().translate(ast)
    actual = frame.with_columns(dsl_result=dsl_expr).to_pandas()["dsl_result"]

    # Reference written directly against the Polars API.
    reference_expr = pl.col("value").log1p()
    expected = frame.with_columns(pl_result=reference_expr).to_pandas()["pl_result"]

    np.testing.assert_array_almost_equal(actual.values, expected.values, decimal=10)
# ============== 统计函数测试 ==============
def test_ts_var_function():
    """Check that ``ts_var`` matches a per-stock Polars rolling variance."""
    parser = FormulaParser(FunctionRegistry())

    # Two stocks, ten days each, with simple increasing closes.
    first_stock_days = pl.date_range(
        pl.date(2024, 1, 1), pl.date(2024, 1, 10), eager=True
    )
    trade_dates = first_stock_days.append(
        pl.date_range(pl.date(2024, 1, 1), pl.date(2024, 1, 10), eager=True)
    )
    frame = pl.DataFrame(
        {
            "ts_code": ["A"] * 10 + ["B"] * 10,
            "trade_date": trade_dates,
            "close": list(range(1, 11)) + list(range(10, 20)),
        }
    )

    # Value produced by the string-expression pipeline, collected per stock.
    ast = parser.parse("ts_var(close, 5)")
    dsl_expr = PolarsTranslator().translate(ast)
    dsl_frame = frame.with_columns(dsl_result=dsl_expr).to_pandas()
    actual_by_stock = dsl_frame.groupby("ts_code")["dsl_result"].apply(list)

    # Reference written directly against the Polars API, collected per stock.
    reference_expr = pl.col("close").rolling_var(window_size=5).over("ts_code")
    reference_frame = frame.with_columns(pl_result=reference_expr).to_pandas()
    expected_by_stock = reference_frame.groupby("ts_code")["pl_result"].apply(list)

    for code in ("A", "B"):
        np.testing.assert_array_almost_equal(
            actual_by_stock[code], expected_by_stock[code], decimal=10
        )
def test_ts_skew_function():
    """Check that ``ts_skew`` matches a per-stock Polars rolling skewness."""
    parser = FormulaParser(FunctionRegistry())

    # Two stocks, twenty days each, with seeded standard-normal closes.
    np.random.seed(42)
    twenty_days = list(
        pl.date_range(pl.date(2024, 1, 1), pl.date(2024, 1, 20), eager=True)
    )
    frame = pl.DataFrame(
        {
            "ts_code": ["A"] * 20 + ["B"] * 20,
            "trade_date": twenty_days * 2,
            "close": np.random.randn(40),
        }
    )

    # Value produced by the string-expression pipeline.
    ast = parser.parse("ts_skew(close, 10)")
    dsl_expr = PolarsTranslator().translate(ast)
    actual = frame.with_columns(dsl_result=dsl_expr).to_pandas()["dsl_result"]

    # Reference written directly against the Polars API.
    reference_expr = pl.col("close").rolling_skew(window_size=10).over("ts_code")
    expected = frame.with_columns(pl_result=reference_expr).to_pandas()["pl_result"]

    np.testing.assert_array_almost_equal(actual.values, expected.values, decimal=10)
def test_ts_kurt_function():
    """Check that ``ts_kurt`` matches a per-stock rolling kurtosis reference."""
    parser = FormulaParser(FunctionRegistry())

    # Two stocks, twenty days each, with seeded standard-normal closes.
    np.random.seed(42)
    twenty_days = list(
        pl.date_range(pl.date(2024, 1, 1), pl.date(2024, 1, 20), eager=True)
    )
    frame = pl.DataFrame(
        {
            "ts_code": ["A"] * 20 + ["B"] * 20,
            "trade_date": twenty_days * 2,
            "close": np.random.randn(40),
        }
    )

    # Value produced by the string-expression pipeline.
    ast = parser.parse("ts_kurt(close, 10)")
    dsl_expr = PolarsTranslator().translate(ast)
    actual = frame.with_columns(dsl_result=dsl_expr).to_pandas()["dsl_result"]

    def _window_kurtosis(window):
        # Windows with fewer than four non-null values yield NaN.
        if len(window.drop_nulls()) >= 4:
            return window.kurtosis()
        return float("nan")

    # Reference: apply the per-window kurtosis through a rolling map per stock.
    reference_expr = (
        pl.col("close").rolling_map(_window_kurtosis, window_size=10).over("ts_code")
    )
    expected = frame.with_columns(pl_result=reference_expr).to_pandas()["pl_result"]

    np.testing.assert_array_almost_equal(actual.values, expected.values, decimal=10)
def test_ts_pct_change_function():
    """Check that ``ts_pct_change`` matches a per-stock percentage change.

    The reference is ``(close - previous_close) / previous_close`` evaluated
    inside each ``ts_code`` group, so the first row of every stock has no
    previous value and is expected to be missing.
    """
    parser = FormulaParser(FunctionRegistry())

    # Two stocks, five days each.
    df = pl.DataFrame(
        {
            "ts_code": ["A"] * 5 + ["B"] * 5,
            "trade_date": list(
                pl.date_range(pl.date(2024, 1, 1), pl.date(2024, 1, 5), eager=True)
            )
            * 2,
            "close": [100, 105, 102, 108, 110, 50, 52, 48, 55, 60],
        }
    )

    # Value produced by the string-expression pipeline.
    expr = parser.parse("ts_pct_change(close, 1)")
    translator = PolarsTranslator()
    polars_expr = translator.translate(expr)
    result_dsl = df.with_columns(dsl_result=polars_expr).to_pandas()["dsl_result"]

    # Reference written directly against the Polars API. The window is applied
    # to the whole ratio: attaching ``.over`` only to the denominator would
    # leave the numerator's shift ungrouped, letting it read the last close of
    # stock A when computing the first row of stock B.
    previous_close = pl.col("close").shift(1)
    reference_expr = ((pl.col("close") - previous_close) / previous_close).over(
        "ts_code"
    )
    result_pl = df.with_columns(pl_result=reference_expr).to_pandas()["pl_result"]

    np.testing.assert_array_almost_equal(
        result_dsl.values, result_pl.values, decimal=10
    )
def test_ts_ema_function():
    """Check that ``ts_ema`` matches a per-stock Polars ``ewm_mean`` with the same span."""
    parser = FormulaParser(FunctionRegistry())

    # Two stocks, ten days each, with simple increasing closes.
    ten_days = list(
        pl.date_range(pl.date(2024, 1, 1), pl.date(2024, 1, 10), eager=True)
    )
    frame = pl.DataFrame(
        {
            "ts_code": ["A"] * 10 + ["B"] * 10,
            "trade_date": ten_days * 2,
            "close": list(range(1, 11)) + list(range(10, 20)),
        }
    )

    # Value produced by the string-expression pipeline.
    ast = parser.parse("ts_ema(close, 5)")
    dsl_expr = PolarsTranslator().translate(ast)
    actual = frame.with_columns(dsl_result=dsl_expr).to_pandas()["dsl_result"]

    # Reference written directly against the Polars API.
    reference_expr = pl.col("close").ewm_mean(span=5).over("ts_code")
    expected = frame.with_columns(pl_result=reference_expr).to_pandas()["pl_result"]

    np.testing.assert_array_almost_equal(actual.values, expected.values, decimal=10)
# ============== TA-Lib 函数测试 ==============
@pytest.mark.skipif(not HAS_TALIB, reason="TA-Lib not installed")
def test_ts_atr_function():
    """Check that ``ts_atr`` matches ``talib.ATR`` computed separately per stock."""
    import talib

    parser = FormulaParser(FunctionRegistry())

    # Two stocks, twenty days each. The draws are made in the order
    # high, low, close so the seeded stream is consumed consistently.
    np.random.seed(42)
    twenty_days = list(
        pl.date_range(pl.date(2024, 1, 1), pl.date(2024, 1, 20), eager=True)
    )
    highs = 100 + np.random.randn(40) * 2
    lows = 98 + np.random.randn(40) * 2
    closes = 99 + np.random.randn(40) * 2
    frame = pl.DataFrame(
        {
            "ts_code": ["A"] * 20 + ["B"] * 20,
            "trade_date": twenty_days * 2,
            "high": highs,
            "low": lows,
            "close": closes,
        }
    )

    # Value produced by the string-expression pipeline.
    ast = parser.parse("ts_atr(high, low, close, 14)")
    dsl_expr = PolarsTranslator().translate(ast)
    actual = frame.with_columns(dsl_result=dsl_expr).to_pandas()["dsl_result"]

    # Reference: run TA-Lib on each stock's rows and stitch the pieces back
    # together in the same A-then-B order as the frame.
    per_stock = (
        frame.filter(pl.col("ts_code") == code).to_pandas() for code in ("A", "B")
    )
    expected = np.concatenate(
        [
            talib.ATR(
                group["high"].values,
                group["low"].values,
                group["close"].values,
                timeperiod=14,
            )
            for group in per_stock
        ]
    )

    # A looser tolerance is used for the TA-Lib comparison.
    np.testing.assert_array_almost_equal(actual.values, expected, decimal=5)
@pytest.mark.skipif(not HAS_TALIB, reason="TA-Lib not installed")
def test_ts_rsi_function():
    """Check that ``ts_rsi`` matches ``talib.RSI`` computed separately per stock."""
    import talib

    parser = FormulaParser(FunctionRegistry())

    # Two stocks, thirty days each, closes following a seeded random walk.
    np.random.seed(42)
    thirty_days = list(
        pl.date_range(pl.date(2024, 1, 1), pl.date(2024, 1, 30), eager=True)
    )
    frame = pl.DataFrame(
        {
            "ts_code": ["A"] * 30 + ["B"] * 30,
            "trade_date": thirty_days * 2,
            "close": 100 + np.cumsum(np.random.randn(60)),
        }
    )

    # Value produced by the string-expression pipeline.
    ast = parser.parse("ts_rsi(close, 14)")
    dsl_expr = PolarsTranslator().translate(ast)
    actual = frame.with_columns(dsl_result=dsl_expr).to_pandas()["dsl_result"]

    # Reference: run TA-Lib on each stock's rows and stitch the pieces back
    # together in the same A-then-B order as the frame.
    per_stock = (
        frame.filter(pl.col("ts_code") == code).to_pandas() for code in ("A", "B")
    )
    expected = np.concatenate(
        [talib.RSI(group["close"].values, timeperiod=14) for group in per_stock]
    )

    # A looser tolerance is used for the TA-Lib comparison.
    np.testing.assert_array_almost_equal(actual.values, expected, decimal=5)
@pytest.mark.skipif(not HAS_TALIB, reason="TA-Lib not installed")
def test_ts_obv_function():
    """Check that ``ts_obv`` matches ``talib.OBV`` computed separately per stock."""
    import talib

    parser = FormulaParser(FunctionRegistry())

    # Two stocks, twenty days each. Closes are drawn before volumes so the
    # seeded stream is consumed in a fixed order.
    np.random.seed(42)
    twenty_days = list(
        pl.date_range(pl.date(2024, 1, 1), pl.date(2024, 1, 20), eager=True)
    )
    closes = 100 + np.cumsum(np.random.randn(40))
    volumes = np.random.randint(100000, 1000000, 40).astype(float)
    frame = pl.DataFrame(
        {
            "ts_code": ["A"] * 20 + ["B"] * 20,
            "trade_date": twenty_days * 2,
            "close": closes,
            "vol": volumes,
        }
    )

    # Value produced by the string-expression pipeline.
    ast = parser.parse("ts_obv(close, vol)")
    dsl_expr = PolarsTranslator().translate(ast)
    actual = frame.with_columns(dsl_result=dsl_expr).to_pandas()["dsl_result"]

    # Reference: run TA-Lib on each stock's rows and stitch the pieces back
    # together in the same A-then-B order as the frame.
    per_stock = (
        frame.filter(pl.col("ts_code") == code).to_pandas() for code in ("A", "B")
    )
    expected = np.concatenate(
        [
            talib.OBV(group["close"].values, group["vol"].values)
            for group in per_stock
        ]
    )

    # A looser tolerance is used for the TA-Lib comparison.
    np.testing.assert_array_almost_equal(actual.values, expected, decimal=5)
# ============== 综合测试 ==============
def test_complex_factor_expressions():
    """Check that a nested factor expression parses, translates and evaluates.

    The expression combines ``atan``, ``ts_ema`` and ``ts_delay`` with
    arithmetic. Besides the shape of the output, the test verifies that the
    factor actually contains finite numbers and that they respect the bound
    implied by ``atan``: since ``|atan(x)| <= pi / 2``, every finite factor
    value must satisfy ``|factor| <= pi / 2 * 57.3 / 50``.
    """
    parser = FormulaParser(FunctionRegistry())

    # Two stocks, thirty days each, closes following a seeded random walk.
    np.random.seed(42)
    df = pl.DataFrame(
        {
            "ts_code": ["A"] * 30 + ["B"] * 30,
            "trade_date": list(
                pl.date_range(pl.date(2024, 1, 1), pl.date(2024, 1, 30), eager=True)
            )
            * 2,
            "close": 100 + np.cumsum(np.random.randn(60)),
        }
    )

    # act_factor1: atan((ts_ema(close,5)/ts_delay(ts_ema(close,5),1)-1)*100) * 57.3 / 50
    expr = parser.parse(
        "atan((ts_ema(close, 5) / ts_delay(ts_ema(close, 5), 1) - 1) * 100) * 57.3 / 50"
    )
    translator = PolarsTranslator()
    polars_expr = translator.translate(expr)
    result = df.with_columns(factor=polars_expr)

    # Shape checks: one output row per input row, in the requested column.
    assert len(result) == 60
    assert "factor" in result.columns

    # Content checks: the shape assertions above hold for any expression, so
    # also require real numbers in the output. Missing values (e.g. rows with
    # no lagged observation) are ignored.
    factor_values = result.to_pandas()["factor"].values.astype(float)
    finite_values = factor_values[np.isfinite(factor_values)]
    assert finite_values.size > 0

    # Small slack guards against floating-point rounding at the boundary.
    upper_bound = np.pi / 2 * 57.3 / 50 + 1e-12
    assert np.all(np.abs(finite_values) <= upper_bound)

    print("复杂因子表达式测试通过")
# ============== 主函数 ==============
if __name__ == "__main__":
    # Manual runner: executes every test in order and prints a progress line
    # after each one. Any assertion failure aborts the run with a traceback.
    divider = "=" * 80
    print("运行 Phase 1-2 因子函数测试...")
    print(divider)

    # Math functions.
    print("\n[数学函数测试]")
    for run_case, done_message in (
        (test_atan_function, " ✅ atan 测试通过"),
        (test_log1p_function, " ✅ log1p 测试通过"),
    ):
        run_case()
        print(done_message)

    # Statistical functions.
    print("\n[统计函数测试]")
    for run_case, done_message in (
        (test_ts_var_function, " ✅ ts_var 测试通过"),
        (test_ts_skew_function, " ✅ ts_skew 测试通过"),
        (test_ts_kurt_function, " ✅ ts_kurt 测试通过"),
        (test_ts_pct_change_function, " ✅ ts_pct_change 测试通过"),
        (test_ts_ema_function, " ✅ ts_ema 测试通过"),
    ):
        run_case()
        print(done_message)

    # TA-Lib functions: only run when the library can be imported here.
    print("\n[TA-Lib 函数测试]")
    try:
        import talib
    except ImportError:
        HAS_TALIB = False
        print(" ⚠️ TA-Lib 未安装,跳过 TA-Lib 测试")
    else:
        HAS_TALIB = True
        for run_case, done_message in (
            (test_ts_atr_function, " ✅ ts_atr 测试通过"),
            (test_ts_rsi_function, " ✅ ts_rsi 测试通过"),
            (test_ts_obv_function, " ✅ ts_obv 测试通过"),
        ):
            run_case()
            print(done_message)

    # Combined expression.
    print("\n[综合测试]")
    test_complex_factor_expressions()
    print(" ✅ 复杂因子表达式测试通过")

    print("\n" + divider)
    print("所有测试通过!")