Files
ProStock/tests/test_factor_storage.py

128 lines
3.8 KiB
Python
Raw Normal View History

"""FactorStorage 单元测试。"""
import numpy as np
import pandas as pd
import polars as pl
import pytest
from src.data.factor_storage import FactorStorage
@pytest.fixture
def storage(tmp_path):
    """Provide a FactorStorage rooted in a fresh per-test temporary directory."""
    factor_dir = tmp_path / "factor"
    return FactorStorage(base_dir=factor_dir)
@pytest.fixture
def sample_df():
    """Three rows of ``test_factor``: one stock per consecutive trade_date."""
    columns = {
        "trade_date": ["20240101", "20240102", "20240103"],
        "ts_code": ["000001.SZ", "000002.SZ", "000003.SZ"],
        "test_factor": [1.0, 2.0, 3.0],
    }
    return pl.DataFrame(columns)
def test_exists_and_save_load(storage, sample_df):
    """save() makes a factor visible to exists(), and load() round-trips its rows."""
    assert not storage.exists("test_factor")
    storage.save("test_factor", sample_df)
    assert storage.exists("test_factor")

    # Sort before comparing column contents so the assertions do not depend on
    # whatever row order load() happens to return.
    loaded = storage.load("test_factor").sort(["trade_date", "ts_code"])
    assert loaded.shape == (3, 3)
    assert set(loaded.columns) == {"trade_date", "ts_code", "test_factor"}
    # Check the key columns too, so each value is pinned to its (date, code) row.
    assert loaded["trade_date"].to_list() == ["20240101", "20240102", "20240103"]
    assert loaded["ts_code"].to_list() == ["000001.SZ", "000002.SZ", "000003.SZ"]
    assert loaded["test_factor"].to_list() == [1.0, 2.0, 3.0]
def test_get_date_range(storage, sample_df):
    """get_date_range gives (first, last) trade_date, or None for an unknown factor."""
    storage.save("test_factor", sample_df)
    expected_range = ("20240101", "20240103")
    assert storage.get_date_range("test_factor") == expected_range
    assert storage.get_date_range("missing") is None
def test_load_with_date_filter(storage, sample_df):
    """load(start_date=...) keeps only rows on or after that trade_date (inclusive)."""
    storage.save("test_factor", sample_df)
    # Sort so the comparison is independent of the row order load() returns.
    loaded = storage.load("test_factor", start_date="20240102").sort(
        ["trade_date", "ts_code"]
    )
    assert loaded.shape == (2, 3)
    assert loaded["trade_date"].to_list() == ["20240102", "20240103"]
    # The surviving rows must still carry their original factor values.
    assert loaded["test_factor"].to_list() == [2.0, 3.0]
def test_incremental_update(storage, sample_df):
    """A second save merges into the stored factor.

    The new trade_date is added, and the overlapping row (20240103 / 000003.SZ)
    ends up with the newly saved value.
    """
    # First write: the three baseline rows.
    storage.save("test_factor", sample_df)

    # Second write: one new date plus an overwrite of an existing date.
    update = pl.DataFrame(
        {
            "trade_date": ["20240103", "20240104"],
            "ts_code": ["000003.SZ", "000004.SZ"],
            "test_factor": [30.0, 4.0],
        }
    )
    storage.save("test_factor", update)

    merged = storage.load("test_factor")
    merged = merged.sort(["trade_date", "ts_code"])
    assert merged.shape == (4, 3)
    assert merged["test_factor"].to_list() == [1.0, 2.0, 30.0, 4.0]
def test_validate_pass_when_no_local_file(storage, sample_df):
    """With nothing saved for the factor, validate passes and reports empty stats."""
    ok, report = storage.validate("test_factor", sample_df)
    assert ok is True
    assert report == {}
def test_validate_pass_with_identical_data(storage, sample_df):
    """Validating the exact saved frame matches all rows with zero difference."""
    storage.save("test_factor", sample_df)
    ok, report = storage.validate("test_factor", sample_df)
    assert ok is True
    assert report["matched_rows"] == 3
    for diff_key in ("max_abs_diff", "mean_abs_diff"):
        assert report[diff_key] == pytest.approx(0.0)
def test_validate_fail_on_data_mismatch(storage, sample_df):
    """Shifting one overlapping value by 1.0 exceeds a 1e-6 tolerance, so validate fails."""
    storage.save("test_factor", sample_df)

    # Bump only the 20240101 row; the other two rows stay identical.
    on_first_day = pl.col("trade_date") == "20240101"
    factor = pl.col("test_factor")
    bumped = pl.when(on_first_day).then(factor + 1.0).otherwise(factor)
    modified = sample_df.with_columns(bumped.alias("test_factor"))

    ok, report = storage.validate("test_factor", modified, tolerance=1e-6)
    assert ok is False
    assert report["matched_rows"] == 3
    assert report["max_abs_diff"] == pytest.approx(1.0)
def test_validate_pass_with_non_overlapping_data(storage, sample_df):
    """A candidate whose only trade_date is absent from storage passes with empty stats."""
    storage.save("test_factor", sample_df)
    unseen_row = {
        "trade_date": ["20240105"],
        "ts_code": ["000001.SZ"],
        "test_factor": [99.0],
    }
    ok, report = storage.validate("test_factor", pl.DataFrame(unseen_row))
    assert ok is True
    assert report == {}
def test_save_preserves_column_order(storage):
    """The on-disk HDF object keeps columns in the same order as the saved frame."""
    df = pl.DataFrame(
        {
            "trade_date": ["20240101"],
            "ts_code": ["000001.SZ"],
            "my_factor": [1.5],
        }
    )
    storage.save("my_factor", df)

    # Read the raw file through the storage's private path/key helpers so the
    # check covers the persisted layout, not just what load() returns.
    pdf = pd.read_hdf(storage._file_path("my_factor"), key=storage._HDF_KEY)
    # read_hdf may return a Series depending on what was stored; assert the
    # DataFrame case explicitly so the check is real and the type is narrowed.
    assert isinstance(pdf, pd.DataFrame)
    assert list(pdf.columns) == ["trade_date", "ts_code", "my_factor"]