"""FactorStorage 单元测试。""" import numpy as np import pandas as pd import polars as pl import pytest from src.data.factor_storage import FactorStorage @pytest.fixture def storage(tmp_path): return FactorStorage(base_dir=tmp_path / "factor") @pytest.fixture def sample_df(): return pl.DataFrame( { "trade_date": ["20240101", "20240102", "20240103"], "ts_code": ["000001.SZ", "000002.SZ", "000003.SZ"], "test_factor": [1.0, 2.0, 3.0], } ) def test_exists_and_save_load(storage, sample_df): assert not storage.exists("test_factor") storage.save("test_factor", sample_df) assert storage.exists("test_factor") loaded = storage.load("test_factor") assert loaded.shape == (3, 3) assert set(loaded.columns) == {"trade_date", "ts_code", "test_factor"} assert loaded["test_factor"].to_list() == [1.0, 2.0, 3.0] def test_get_date_range(storage, sample_df): storage.save("test_factor", sample_df) dr = storage.get_date_range("test_factor") assert dr == ("20240101", "20240103") assert storage.get_date_range("missing") is None def test_load_with_date_filter(storage, sample_df): storage.save("test_factor", sample_df) loaded = storage.load("test_factor", start_date="20240102") assert loaded.shape == (2, 3) assert loaded["trade_date"].to_list() == ["20240102", "20240103"] def test_incremental_update(storage, sample_df): # 第一次写入 storage.save("test_factor", sample_df) # 第二次写入:新增日期 + 覆盖已有日期 new_df = pl.DataFrame( { "trade_date": ["20240103", "20240104"], "ts_code": ["000003.SZ", "000004.SZ"], "test_factor": [30.0, 4.0], } ) storage.save("test_factor", new_df) loaded = storage.load("test_factor").sort(["trade_date", "ts_code"]) assert loaded.shape == (4, 3) assert loaded["test_factor"].to_list() == [1.0, 2.0, 30.0, 4.0] def test_validate_pass_when_no_local_file(storage, sample_df): passed, stats = storage.validate("test_factor", sample_df) assert passed is True assert stats == {} def test_validate_pass_with_identical_data(storage, sample_df): storage.save("test_factor", sample_df) passed, stats = storage.validate("test_factor", sample_df) assert passed is True assert stats["matched_rows"] == 3 assert stats["max_abs_diff"] == pytest.approx(0.0) assert stats["mean_abs_diff"] == pytest.approx(0.0) def test_validate_fail_on_data_mismatch(storage, sample_df): storage.save("test_factor", sample_df) modified = sample_df.with_columns( pl.when(pl.col("trade_date") == "20240101") .then(pl.col("test_factor") + 1.0) .otherwise(pl.col("test_factor")) .alias("test_factor") ) passed, stats = storage.validate("test_factor", modified, tolerance=1e-6) assert passed is False assert stats["matched_rows"] == 3 assert stats["max_abs_diff"] == pytest.approx(1.0) def test_validate_pass_with_non_overlapping_data(storage, sample_df): storage.save("test_factor", sample_df) non_overlap = pl.DataFrame( { "trade_date": ["20240105"], "ts_code": ["000001.SZ"], "test_factor": [99.0], } ) passed, stats = storage.validate("test_factor", non_overlap) assert passed is True assert stats == {} def test_save_preserves_column_order(storage): df = pl.DataFrame( { "trade_date": ["20240101"], "ts_code": ["000001.SZ"], "my_factor": [1.5], } ) storage.save("my_factor", df) pdf = pd.read_hdf(storage._file_path("my_factor"), key=storage._HDF_KEY) assert list(pdf.columns) == ["trade_date", "ts_code", "my_factor"] # type: ignore[attr-defined]