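# NOTE(review): the lines below are residue from a web repository view and
# were commented out so the module parses.
# Files
# ProStock/tests/test_phase1_2_factors.py
#
# 542 lines
# 15 KiB
# Python
# Raw Normal View History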

"""Phase 1-2 因子函数集成测试。
测试所有新实现的函数,使用字符串因子表达式形式计算因子,
并与原始 Polars 计算结果进行对比。
测试范围:
1. 数学函数:atan, log1p
2. 统计函数:ts_var, ts_skew, ts_kurt, ts_pct_change, ts_ema
3. TA-Lib 函数:ts_atr, ts_rsi, ts_obv
"""
import numpy as np
import polars as pl
import pytest
from src.factors import FormulaParser, FunctionRegistry
from src.factors.translator import PolarsTranslator, HAS_TALIB
from src.factors.engine import FactorEngine
from src.data.catalog import DatabaseCatalog
# ============== 测试数据准备 ==============
def create_test_data() -> pl.DataFrame:
    """Create a synthetic multi-stock, multi-day OHLCV frame for factor tests.

    Four stocks x every calendar day from 2024-01-01 to 2024-01-31.
    Prices follow a noisy upward drift around a random per-stock base price.

    Randomness comes from a local ``np.random.RandomState(42)``. It yields the
    same draw sequence as ``np.random.seed(42)`` followed by ``np.random``
    calls, but it does not reset the process-wide NumPy RNG as a side effect.
    Each bar is also kept internally consistent: volume stays positive and
    ``high``/``low`` always bracket both ``open`` and ``close``.

    Returns:
        pl.DataFrame with columns ts_code, trade_date, close, open, high,
        low and vol.
    """
    rng = np.random.RandomState(42)
    dates = pl.date_range(
        start=pl.date(2024, 1, 1),
        end=pl.date(2024, 1, 31),
        interval="1d",
        eager=True,
    )
    stocks = ["000001.SZ", "000002.SZ", "600000.SH", "600001.SH"]
    rows = []
    for stock in stocks:
        base_price = 100 + rng.randn() * 10
        for i, date in enumerate(dates):
            price = base_price + rng.randn() * 5 + i * 0.1
            # Draw open/high/low/vol in the same order as before, so the
            # generated series are unchanged apart from the clamping below.
            open_ = price * (1 + rng.randn() * 0.01)
            high = price * (1 + abs(rng.randn()) * 0.02)
            low = price * (1 - abs(rng.randn()) * 0.02)
            vol = int(1000000 + rng.randn() * 500000)
            rows.append(
                {
                    "ts_code": stock,
                    "trade_date": date,
                    "close": price,
                    "open": open_,
                    # open is drawn independently of high/low; widen the
                    # range so the bar is a valid OHLC candle.
                    "high": max(high, open_, price),
                    "low": min(low, open_, price),
                    # A normal draw below -2 sigma would give negative volume.
                    "vol": max(vol, 1),
                }
            )
    return pl.DataFrame(rows)
# ============== 数学函数测试 ==============
def test_atan_function():
    """The DSL's atan(value) must match Polars' native ``Expr.arctan``."""
    values = [0.0, 1.0, -1.0, 0.5, -0.5]
    frame = pl.DataFrame(
        {
            "ts_code": ["A"] * len(values),
            "trade_date": pl.date_range(
                pl.date(2024, 1, 1), pl.date(2024, 1, 5), eager=True
            ),
            "value": values,
        }
    )

    # Turn the string formula into a Polars expression.
    ast = FormulaParser(FunctionRegistry()).parse("atan(value)")
    dsl_expr = PolarsTranslator().translate(ast)

    # Evaluate the DSL expression and the hand-written reference separately.
    dsl_values = frame.with_columns(dsl_result=dsl_expr)["dsl_result"]
    ref_values = frame.with_columns(pl_result=pl.col("value").arctan())[
        "pl_result"
    ]

    np.testing.assert_array_almost_equal(
        dsl_values.to_numpy(), ref_values.to_numpy(), decimal=10
    )
def test_log1p_function():
    """The DSL's log1p(value) must match Polars' native ``Expr.log1p``."""
    registry = FunctionRegistry()
    inputs = pl.DataFrame(
        {
            "ts_code": ["A", "A", "A", "A", "A"],
            "trade_date": pl.date_range(
                pl.date(2024, 1, 1), pl.date(2024, 1, 5), eager=True
            ),
            # Includes negative inputs greater than -1.
            "value": [0.0, 0.1, -0.1, 1.0, -0.5],
        }
    )

    translated = PolarsTranslator().translate(
        FormulaParser(registry).parse("log1p(value)")
    )
    got = inputs.with_columns(dsl_result=translated)["dsl_result"].to_numpy()
    want = inputs.with_columns(pl_result=pl.col("value").log1p())[
        "pl_result"
    ].to_numpy()

    np.testing.assert_array_almost_equal(got, want, decimal=10)
# ============== 统计函数测试 ==============
def test_ts_var_function():
    """ts_var(close, 5) must equal each stock's 5-row rolling variance.

    The whole DSL column is compared against the grouped reference in input
    row order, which checks every stock's windows. This matches the sibling
    tests, instead of round-tripping through pandas and re-grouping by a
    hard-coded list of stock codes.
    """
    dates = pl.date_range(pl.date(2024, 1, 1), pl.date(2024, 1, 10), eager=True)
    df = pl.DataFrame(
        {
            "ts_code": ["A"] * 10 + ["B"] * 10,
            "trade_date": list(dates) * 2,
            "close": list(range(1, 11)) + list(range(10, 20)),
        }
    )

    # DSL computation.
    expr = FormulaParser(FunctionRegistry()).parse("ts_var(close, 5)")
    polars_expr = PolarsTranslator().translate(expr)
    result_dsl = df.with_columns(dsl_result=polars_expr)["dsl_result"]

    # Reference: rolling variance restarted for each ts_code.
    result_pl = df.with_columns(
        pl_result=pl.col("close").rolling_var(window_size=5).over("ts_code")
    )["pl_result"]

    np.testing.assert_array_almost_equal(
        result_dsl.to_numpy(), result_pl.to_numpy(), decimal=10
    )
def test_ts_skew_function():
    """ts_skew(close, 10) must match Polars' per-stock ``rolling_skew``."""
    np.random.seed(42)
    per_stock = 20
    calendar = list(
        pl.date_range(pl.date(2024, 1, 1), pl.date(2024, 1, 20), eager=True)
    )
    panel = pl.DataFrame(
        {
            "ts_code": ["A"] * per_stock + ["B"] * per_stock,
            "trade_date": calendar * 2,
            "close": np.random.randn(2 * per_stock),
        }
    )

    formula = FormulaParser(FunctionRegistry()).parse("ts_skew(close, 10)")
    dsl_frame = panel.with_columns(
        dsl_result=PolarsTranslator().translate(formula)
    )
    ref_frame = panel.with_columns(
        pl_result=pl.col("close").rolling_skew(window_size=10).over("ts_code")
    )

    np.testing.assert_array_almost_equal(
        dsl_frame["dsl_result"].to_numpy(),
        ref_frame["pl_result"].to_numpy(),
        decimal=10,
    )
def test_ts_kurt_function():
    """ts_kurt(close, 10) must match a per-stock rolling kurtosis.

    The reference applies Polars' ``Series.kurtosis`` through ``rolling_map``.
    Windows with fewer than 4 non-null values are reported as NaN.
    """
    np.random.seed(42)
    dates = list(
        pl.date_range(pl.date(2024, 1, 1), pl.date(2024, 1, 20), eager=True)
    )
    df = pl.DataFrame(
        {
            "ts_code": ["A"] * 20 + ["B"] * 20,
            "trade_date": dates * 2,
            "close": np.random.randn(40),
        }
    )

    def window_kurtosis(window):
        # Too few observations in the window: report NaN instead.
        if len(window.drop_nulls()) < 4:
            return float("nan")
        return window.kurtosis()

    dsl_expr = PolarsTranslator().translate(
        FormulaParser(FunctionRegistry()).parse("ts_kurt(close, 10)")
    )
    reference = (
        pl.col("close")
        .rolling_map(window_kurtosis, window_size=10)
        .over("ts_code")
    )

    actual = df.with_columns(dsl_result=dsl_expr)["dsl_result"].to_numpy()
    expected = df.with_columns(pl_result=reference)["pl_result"].to_numpy()
    np.testing.assert_array_almost_equal(actual, expected, decimal=10)
def test_ts_pct_change_function():
    """ts_pct_change(close, 1) must equal each stock's one-period return.

    The reference applies ``.over("ts_code")`` to the whole ratio. The
    previous reference grouped only the denominator, because a method call
    binds tighter than ``/``. The numerator's ``shift(1)`` therefore leaked
    across stocks. That was masked only because the grouped denominator is
    null on exactly those boundary rows; with interleaved (non-contiguous)
    stock rows the reference would have been wrong.
    """
    dates = list(
        pl.date_range(pl.date(2024, 1, 1), pl.date(2024, 1, 5), eager=True)
    )
    df = pl.DataFrame(
        {
            "ts_code": ["A"] * 5 + ["B"] * 5,
            "trade_date": dates * 2,
            "close": [100, 105, 102, 108, 110, 50, 52, 48, 55, 60],
        }
    )

    # DSL computation.
    expr = FormulaParser(FunctionRegistry()).parse("ts_pct_change(close, 1)")
    polars_expr = PolarsTranslator().translate(expr)
    result_dsl = df.with_columns(dsl_result=polars_expr)["dsl_result"]

    # Reference: (close - prev_close) / prev_close, computed within each stock.
    prev_close = pl.col("close").shift(1)
    reference = ((pl.col("close") - prev_close) / prev_close).over("ts_code")
    result_pl = df.with_columns(pl_result=reference)["pl_result"]

    np.testing.assert_array_almost_equal(
        result_dsl.to_numpy(), result_pl.to_numpy(), decimal=10
    )
def test_ts_ema_function():
    """ts_ema(close, 5) must match a per-stock ``ewm_mean(span=5)``."""
    trading_days = pl.date_range(
        pl.date(2024, 1, 1), pl.date(2024, 1, 10), eager=True
    )
    prices = pl.DataFrame(
        {
            "ts_code": ["A"] * 10 + ["B"] * 10,
            "trade_date": list(trading_days) * 2,
            "close": [*range(1, 11), *range(10, 20)],
        }
    )

    dsl_expr = PolarsTranslator().translate(
        FormulaParser(FunctionRegistry()).parse("ts_ema(close, 5)")
    )
    # Reference: Polars EWM with span=5, partitioned by ts_code.
    reference = pl.col("close").ewm_mean(span=5).over("ts_code")

    actual = prices.with_columns(dsl_result=dsl_expr)["dsl_result"].to_numpy()
    expected = prices.with_columns(pl_result=reference)["pl_result"].to_numpy()
    np.testing.assert_array_almost_equal(actual, expected, decimal=10)
# ============== TA-Lib 函数测试 ==============
@pytest.mark.skipif(not HAS_TALIB, reason="TA-Lib not installed")
def test_ts_atr_function():
    """ts_atr(high, low, close, 14) must match ``talib.ATR`` run per stock."""
    import talib

    np.random.seed(42)
    n_days = 20
    calendar = list(
        pl.date_range(pl.date(2024, 1, 1), pl.date(2024, 1, 20), eager=True)
    )
    df = pl.DataFrame(
        {
            "ts_code": ["A"] * n_days + ["B"] * n_days,
            "trade_date": calendar * 2,
            "high": 100 + np.random.randn(2 * n_days) * 2,
            "low": 98 + np.random.randn(2 * n_days) * 2,
            "close": 99 + np.random.randn(2 * n_days) * 2,
        }
    )

    ast = FormulaParser(FunctionRegistry()).parse("ts_atr(high, low, close, 14)")
    actual = df.with_columns(dsl_result=PolarsTranslator().translate(ast))[
        "dsl_result"
    ].to_numpy()

    # Reference: run TA-Lib on each stock separately, in first-seen order
    # (A, then B), and stitch the pieces back together.
    expected = np.concatenate(
        [
            talib.ATR(
                part["high"].to_numpy().astype(np.float64),
                part["low"].to_numpy().astype(np.float64),
                part["close"].to_numpy().astype(np.float64),
                timeperiod=14,
            )
            for part in df.partition_by("ts_code", maintain_order=True)
        ]
    )

    # TA-Lib and Polars may differ in the last few digits.
    np.testing.assert_array_almost_equal(actual, expected, decimal=5)
@pytest.mark.skipif(not HAS_TALIB, reason="TA-Lib not installed")
def test_ts_rsi_function():
    """ts_rsi(close, 14) must match ``talib.RSI`` computed per stock."""
    import talib

    np.random.seed(42)
    session_dates = list(
        pl.date_range(pl.date(2024, 1, 1), pl.date(2024, 1, 30), eager=True)
    )
    df = pl.DataFrame(
        {
            "ts_code": ["A"] * 30 + ["B"] * 30,
            "trade_date": session_dates * 2,
            # Random-walk closes, so there are both up and down moves.
            "close": 100 + np.cumsum(np.random.randn(60)),
        }
    )

    dsl_expr = PolarsTranslator().translate(
        FormulaParser(FunctionRegistry()).parse("ts_rsi(close, 14)")
    )
    actual = df.with_columns(dsl_result=dsl_expr)["dsl_result"].to_numpy()

    # Reference: TA-Lib RSI run separately on each stock (A, then B).
    pieces = []
    for part in df.partition_by("ts_code", maintain_order=True):
        closes = part["close"].to_numpy().astype(np.float64)
        pieces.append(talib.RSI(closes, timeperiod=14))
    expected = np.concatenate(pieces)

    # TA-Lib and Polars may differ in the last few digits.
    np.testing.assert_array_almost_equal(actual, expected, decimal=5)
@pytest.mark.skipif(not HAS_TALIB, reason="TA-Lib not installed")
def test_ts_obv_function():
    """ts_obv(close, vol) must match ``talib.OBV`` computed per stock."""
    import talib

    np.random.seed(42)
    days = list(
        pl.date_range(pl.date(2024, 1, 1), pl.date(2024, 1, 20), eager=True)
    )
    df = pl.DataFrame(
        {
            "ts_code": ["A"] * 20 + ["B"] * 20,
            "trade_date": days * 2,
            "close": 100 + np.cumsum(np.random.randn(40)),
            # TA-Lib expects float64 input, hence the cast.
            "vol": np.random.randint(100000, 1000000, 40).astype(float),
        }
    )

    dsl_expr = PolarsTranslator().translate(
        FormulaParser(FunctionRegistry()).parse("ts_obv(close, vol)")
    )
    actual = df.with_columns(dsl_result=dsl_expr)["dsl_result"].to_numpy()

    # Reference: OBV restarted for each stock, in first-seen order.
    expected = np.concatenate(
        [
            talib.OBV(
                part["close"].to_numpy().astype(np.float64),
                part["vol"].to_numpy().astype(np.float64),
            )
            for part in df.partition_by("ts_code", maintain_order=True)
        ]
    )

    # TA-Lib and Polars may differ in the last few digits.
    np.testing.assert_array_almost_equal(actual, expected, decimal=5)
# ============== 综合测试 ==============
def test_complex_factor_expressions():
    """A nested formula (act_factor1) must evaluate to plausible values.

    act_factor1 =
        atan((ts_ema(close,5) / ts_delay(ts_ema(close,5),1) - 1) * 100) * 57.3 / 50

    Row count and column presence are always true after ``with_columns``, so
    on their own they verify nothing about the computation. This test also
    checks that finite values were produced and that they respect the bound
    implied by atan's (-pi/2, pi/2) range.
    """
    parser = FormulaParser(FunctionRegistry())
    np.random.seed(42)
    dates = list(
        pl.date_range(pl.date(2024, 1, 1), pl.date(2024, 1, 30), eager=True)
    )
    df = pl.DataFrame(
        {
            "ts_code": ["A"] * 30 + ["B"] * 30,
            "trade_date": dates * 2,
            "close": 100 + np.cumsum(np.random.randn(60)),
        }
    )

    expr = parser.parse(
        "atan((ts_ema(close, 5) / ts_delay(ts_ema(close, 5), 1) - 1) * 100) * 57.3 / 50"
    )
    polars_expr = PolarsTranslator().translate(expr)
    result = df.with_columns(factor=polars_expr)

    assert len(result) == 60
    assert "factor" in result.columns

    factor = result["factor"].drop_nulls().drop_nans()
    # test_ts_ema pins ts_ema to ewm_mean(span=5), which fills every row, so
    # only ts_delay(..., 1) should leave gaps: at most one per stock.
    assert factor.len() >= 58
    # atan() lies in (-pi/2, pi/2); scaling by 57.3 / 50 bounds |factor|.
    bound = np.pi / 2 * 57.3 / 50
    assert factor.abs().max() <= bound
# ============== 主函数 ==============
if __name__ == "__main__":
    # Manual runner: calls every test function directly instead of going
    # through pytest. Pytest markers (e.g. the TA-Lib skipif) are therefore
    # not applied, and TA-Lib availability is checked explicitly below.
    print("运行 Phase 1-2 因子函数测试...")
    print("=" * 80)

    # Math functions.
    print("\n[数学函数测试]")
    test_atan_function()
    print(" ✅ atan 测试通过")
    test_log1p_function()
    print(" ✅ log1p 测试通过")

    # Statistical (rolling) functions.
    print("\n[统计函数测试]")
    test_ts_var_function()
    print(" ✅ ts_var 测试通过")
    test_ts_skew_function()
    print(" ✅ ts_skew 测试通过")
    test_ts_kurt_function()
    print(" ✅ ts_kurt 测试通过")
    test_ts_pct_change_function()
    print(" ✅ ts_pct_change 测试通过")
    test_ts_ema_function()
    print(" ✅ ts_ema 测试通过")

    # TA-Lib functions. Gate on the HAS_TALIB flag imported from
    # src.factors.translator, which is the same flag the skipif markers use.
    # The old local ``import talib`` probe rebound that module-level name
    # and may disagree with whether the translator enabled TA-Lib support.
    print("\n[TA-Lib 函数测试]")
    if HAS_TALIB:
        test_ts_atr_function()
        print(" ✅ ts_atr 测试通过")
        test_ts_rsi_function()
        print(" ✅ ts_rsi 测试通过")
        test_ts_obv_function()
        print(" ✅ ts_obv 测试通过")
    else:
        print(" ⚠️ TA-Lib 未安装,跳过 TA-Lib 测试")

    # Combined / nested expressions.
    print("\n[综合测试]")
    test_complex_factor_expressions()
    print(" ✅ 复杂因子表达式测试通过")

    print("\n" + "=" * 80)
    print("所有测试通过!")