158 lines
5.1 KiB
Python
158 lines
5.1 KiB
Python
|
|
"""Tests for library I/O and paper factor imports."""
|
|||
|
|
|
|||
|
|
import json
|
|||
|
|
from pathlib import Path
|
|||
|
|
|
|||
|
|
import numpy as np
|
|||
|
|
import pytest
|
|||
|
|
|
|||
|
|
from src.factorminer.core.factor_library import Factor, FactorLibrary
|
|||
|
|
from src.factorminer.core.library_io import (
|
|||
|
|
import_from_paper,
|
|||
|
|
load_library,
|
|||
|
|
save_library,
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
|
|||
|
|
class TestSaveLoadLibrary:
    """Serialization and deserialization tests for FactorLibrary."""

    def test_save_library_ignores_save_signals(self, tmp_path: Path) -> None:
        """Passing save_signals=True must still not produce a .npz file."""
        lib = FactorLibrary()
        momentum = Factor(
            id=0,
            name="test_factor",
            formula="close / ts_delay(close, 1) - 1",
            category="Momentum",
            ic_mean=0.05,
            icir=0.5,
            ic_win_rate=0.55,
            max_correlation=0.1,
            batch_number=1,
        )
        # Attach a signal matrix on purpose: it still must not be persisted.
        momentum.signals = np.ones((10, 20))
        lib.admit_factor(momentum)

        target = tmp_path / "test_lib"
        save_library(lib, str(target), save_signals=True)

        json_file = target.with_suffix(".json")
        signals_file = Path(f"{target}_signals.npz")
        assert json_file.exists()
        assert not signals_file.exists()

    def test_load_library_restores_metadata_and_unsupported(
        self, tmp_path: Path
    ) -> None:
        """Loading the JSON restores metadata and flags "# TODO" formulas as unsupported."""
        # Both factors share the same all-zero metrics; only identity differs.
        zero_metrics = {
            "ic_mean": 0.0,
            "icir": 0.0,
            "ic_win_rate": 0.0,
            "max_correlation": 0.0,
            "batch_number": 0,
        }
        lib = FactorLibrary()
        supported = Factor(
            id=0,
            name="ok_factor",
            formula="cs_rank(close)",
            category="Test",
            metadata={"author": "ai"},
            **zero_metrics,
        )
        placeholder = Factor(
            id=0,
            name="todo_factor",
            formula="# TODO: Neg(CsRank(Decay(close, 10)))",
            category="Test",
            **zero_metrics,
        )
        for candidate in (supported, placeholder):
            lib.admit_factor(candidate)

        target = tmp_path / "meta_lib"
        save_library(lib, str(target))

        reloaded = load_library(str(target))
        assert reloaded.size == 2

        # The factors are fetched back under ids 1 and 2 (both were built with id=0).
        first_meta = reloaded.get_factor(1).metadata
        assert first_meta.get("author") == "ai"
        assert not first_meta.get("unsupported", False)

        second_meta = reloaded.get_factor(2).metadata
        assert second_meta.get("unsupported") is True

    def test_factor_round_trip_with_metadata(self) -> None:
        """Factor.to_dict followed by Factor.from_dict keeps metadata intact."""
        original = Factor(
            id=1,
            name="round_trip",
            formula="ts_mean(close, 20)",
            category="Momentum",
            ic_mean=0.1,
            icir=1.0,
            ic_win_rate=0.6,
            max_correlation=0.2,
            batch_number=2,
            metadata={"unsupported": True, "tags": ["test"]},
        )
        clone = Factor.from_dict(original.to_dict())
        assert clone.metadata == original.metadata
|
|||
|
|
|
|||
|
|
|
|||
|
|
class TestImportFromPaper:
    """Tests for importing factors through import_from_paper."""

    def test_import_from_paper_includes_all_translated_factors(self) -> None:
        """Calling import_from_paper() with no path yields a non-empty library.

        Every factor returned by the library must have a positive id and a
        non-empty name, formula and category.
        """
        library = import_from_paper()
        assert library.size > 0
        # Each admitted factor is expected to be fully populated.
        for factor in library.list_factors():
            assert factor.id > 0
            assert factor.name
            assert factor.formula
            assert factor.category

    def test_import_from_paper_marks_todo_as_unsupported(self, tmp_path: Path) -> None:
        """A "# TODO" formula is flagged as unsupported in the factor metadata."""
        custom_path = tmp_path / "custom_factors.json"
        custom_data = [
            {
                "name": "Normal Factor",
                "formula": "cs_rank(close)",
                "category": "Test",
            },
            {
                "name": "Unsupported Factor",
                "formula": "# TODO: Neg(CsRank(Decay(close, 10)))",
                "category": "Test",
            },
        ]
        custom_path.write_text(json.dumps(custom_data), encoding="utf-8")

        library = import_from_paper(str(custom_path))
        assert library.size == 2

        # Look factors up by name rather than by list position so the test
        # does not depend on the ordering of list_factors().
        by_name = {factor.name: factor for factor in library.list_factors()}
        normal = by_name["Normal Factor"]
        todo = by_name["Unsupported Factor"]

        assert normal.metadata.get("unsupported") is None
        assert todo.metadata.get("unsupported") is True

    def test_import_from_paper_path_override(self, tmp_path: Path) -> None:
        """The path argument loads factors from an external JSON list."""
        custom_path = tmp_path / "override.json"
        custom_data = [
            {"name": "custom_1", "formula": "open + close", "category": "Custom"},
        ]
        custom_path.write_text(json.dumps(custom_data), encoding="utf-8")

        library = import_from_paper(str(custom_path))
        assert library.size == 1
        assert library.list_factors()[0].name == "custom_1"
|