refactor: 存储层迁移DuckDB + 模块重构
- 存储层重构: HDF5 → DuckDB(UPSERT模式、线程安全存储) - Sync类迁移: DataSync从sync.py迁移到api_daily.py(职责分离) - 模型模块重构: src/models → src/pipeline(更清晰的命名) - 新增因子模块: factors/momentum (MA、收益率排名)、factors/financial - 新增API接口: api_namechange、api_bak_basic - 新增训练入口: training模块(main.py、pipeline配置) - 工具函数统一: get_today_date等移至utils.py - 文档更新: AGENTS.md添加架构变更历史
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
"""模型框架核心测试
|
||||
"""Pipeline 组件库核心测试
|
||||
|
||||
测试核心抽象类、插件注册中心、处理器、模型和划分策略。
|
||||
"""
|
||||
@@ -9,7 +9,7 @@ import numpy as np
|
||||
from typing import List, Optional
|
||||
|
||||
# 确保导入时注册所有组件
|
||||
from src.models import (
|
||||
from src.pipeline import (
|
||||
PluginRegistry,
|
||||
PipelineStage,
|
||||
BaseProcessor,
|
||||
@@ -17,7 +17,7 @@ from src.models import (
|
||||
BaseSplitter,
|
||||
ProcessingPipeline,
|
||||
)
|
||||
from src.models.core import TaskType
|
||||
from src.pipeline.core import TaskType
|
||||
|
||||
|
||||
# ========== 测试核心抽象类 ==========
|
||||
@@ -232,7 +232,7 @@ class TestBuiltInProcessors:
|
||||
|
||||
def test_dropna_processor(self):
|
||||
"""测试缺失值删除处理器"""
|
||||
from src.models.processors import DropNAProcessor
|
||||
from src.pipeline.processors import DropNAProcessor
|
||||
|
||||
processor = DropNAProcessor(columns=["a", "b"])
|
||||
df = pl.DataFrame({"a": [1, None, 3], "b": [4, 5, None], "c": [7, 8, 9]})
|
||||
@@ -246,7 +246,7 @@ class TestBuiltInProcessors:
|
||||
|
||||
def test_fillna_processor(self):
|
||||
"""测试缺失值填充处理器"""
|
||||
from src.models.processors import FillNAProcessor
|
||||
from src.pipeline.processors import FillNAProcessor
|
||||
|
||||
processor = FillNAProcessor(columns=["a"], method="mean")
|
||||
df = pl.DataFrame({"a": [1.0, 2.0, None, 4.0]})
|
||||
@@ -258,7 +258,7 @@ class TestBuiltInProcessors:
|
||||
|
||||
def test_standard_scaler(self):
|
||||
"""测试标准化处理器"""
|
||||
from src.models.processors import StandardScaler
|
||||
from src.pipeline.processors import StandardScaler
|
||||
|
||||
processor = StandardScaler(columns=["value"])
|
||||
df = pl.DataFrame({"value": [1.0, 2.0, 3.0, 4.0, 5.0]})
|
||||
@@ -271,7 +271,7 @@ class TestBuiltInProcessors:
|
||||
|
||||
def test_winsorizer(self):
|
||||
"""测试缩尾处理器"""
|
||||
from src.models.processors import Winsorizer
|
||||
from src.pipeline.processors import Winsorizer
|
||||
|
||||
processor = Winsorizer(columns=["value"], lower=0.1, upper=0.9)
|
||||
df = pl.DataFrame(
|
||||
@@ -288,7 +288,7 @@ class TestBuiltInProcessors:
|
||||
|
||||
def test_rank_transformer(self):
|
||||
"""测试排名转换处理器"""
|
||||
from src.models.processors import RankTransformer
|
||||
from src.pipeline.processors import RankTransformer
|
||||
|
||||
processor = RankTransformer(columns=["value"])
|
||||
df = pl.DataFrame(
|
||||
@@ -302,7 +302,7 @@ class TestBuiltInProcessors:
|
||||
|
||||
def test_neutralizer(self):
|
||||
"""测试中性化处理器"""
|
||||
from src.models.processors import Neutralizer
|
||||
from src.pipeline.processors import Neutralizer
|
||||
|
||||
processor = Neutralizer(columns=["value"], group_col="industry")
|
||||
df = pl.DataFrame(
|
||||
@@ -331,7 +331,7 @@ class TestProcessingPipeline:
|
||||
|
||||
def test_pipeline_fit_transform(self):
|
||||
"""测试流水线的 fit_transform"""
|
||||
from src.models.processors import StandardScaler
|
||||
from src.pipeline.processors import StandardScaler
|
||||
|
||||
scaler1 = StandardScaler(columns=["a"])
|
||||
scaler2 = StandardScaler(columns=["b"])
|
||||
@@ -348,7 +348,7 @@ class TestProcessingPipeline:
|
||||
|
||||
def test_pipeline_transform_uses_fitted_params(self):
|
||||
"""测试 transform 使用已 fit 的参数"""
|
||||
from src.models.processors import StandardScaler
|
||||
from src.pipeline.processors import StandardScaler
|
||||
|
||||
scaler = StandardScaler(columns=["value"])
|
||||
pipeline = ProcessingPipeline([scaler])
|
||||
@@ -383,7 +383,7 @@ class TestSplitters:
|
||||
|
||||
def test_time_series_split(self):
|
||||
"""测试时间序列划分"""
|
||||
from src.models.core import TimeSeriesSplit
|
||||
from src.pipeline.core import TimeSeriesSplit
|
||||
|
||||
splitter = TimeSeriesSplit(n_splits=2, gap=1, min_train_size=3)
|
||||
|
||||
@@ -406,7 +406,7 @@ class TestSplitters:
|
||||
|
||||
def test_walk_forward_split(self):
|
||||
"""测试滚动前向划分"""
|
||||
from src.models.core import WalkForwardSplit
|
||||
from src.pipeline.core import WalkForwardSplit
|
||||
|
||||
splitter = WalkForwardSplit(train_window=5, test_window=2, gap=1)
|
||||
|
||||
@@ -426,7 +426,7 @@ class TestSplitters:
|
||||
|
||||
def test_expanding_window_split(self):
|
||||
"""测试扩展窗口划分"""
|
||||
from src.models.core import ExpandingWindowSplit
|
||||
from src.pipeline.core import ExpandingWindowSplit
|
||||
|
||||
splitter = ExpandingWindowSplit(initial_train_size=3, test_window=2, gap=1)
|
||||
|
||||
@@ -455,7 +455,7 @@ class TestModels:
|
||||
@pytest.mark.skip(reason="需要安装 lightgbm")
|
||||
def test_lightgbm_model(self):
|
||||
"""测试 LightGBM 模型"""
|
||||
from src.models.models import LightGBMModel
|
||||
from src.pipeline.models import LightGBMModel
|
||||
|
||||
model = LightGBMModel(task_type="regression", params={"n_estimators": 10})
|
||||
|
||||
@@ -1,277 +1,163 @@
|
||||
"""Tests for data synchronization module.
|
||||
"""Sync 接口测试规范与实现。
|
||||
|
||||
Tests the sync module's full/incremental sync logic for daily data:
|
||||
- Full sync when local data doesn't exist (from 20180101)
|
||||
- Incremental sync when local data exists (from last_date + 1)
|
||||
- Data integrity validation
|
||||
【测试规范】
|
||||
1. 所有 sync 测试只使用 2018-01-01 到 2018-04-01 的数据
|
||||
2. 只测试接口是否能正常返回数据,不测试落库逻辑
|
||||
3. 对于按股票查询的接口,只测试 000001.SZ、000002.SZ 两支股票
|
||||
4. 使用真实 API 调用,确保接口可用性
|
||||
|
||||
【测试范围】
|
||||
- get_daily: 日线数据接口(按股票)
|
||||
- sync_all_stocks: 股票基础信息接口
|
||||
- sync_trade_cal_cache: 交易日历接口
|
||||
- sync_namechange: 名称变更接口
|
||||
- sync_bak_basic: 备用股票基础信息接口
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import pandas as pd
|
||||
from unittest.mock import Mock, patch, MagicMock
|
||||
from datetime import datetime, timedelta
|
||||
from datetime import datetime
|
||||
|
||||
from src.data.sync import (
|
||||
DataSync,
|
||||
sync_all,
|
||||
get_today_date,
|
||||
get_next_date,
|
||||
DEFAULT_START_DATE,
|
||||
)
|
||||
from src.data.storage import ThreadSafeStorage
|
||||
from src.data.client import TushareClient
|
||||
# 测试用常量
|
||||
TEST_START_DATE = "20180101"
|
||||
TEST_END_DATE = "20180401"
|
||||
TEST_STOCK_CODES = ["000001.SZ", "000002.SZ"]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_storage():
|
||||
"""Create a mock storage instance."""
|
||||
storage = Mock(spec=ThreadSafeStorage)
|
||||
storage.exists = Mock(return_value=False)
|
||||
storage.load = Mock(return_value=pd.DataFrame())
|
||||
storage.save = Mock(return_value={"status": "success", "rows": 0})
|
||||
return storage
|
||||
class TestGetDaily:
|
||||
"""测试日线数据 get 接口(按股票查询)."""
|
||||
|
||||
def test_get_daily_single_stock(self):
|
||||
"""测试 get_daily 获取单只股票数据."""
|
||||
from src.data.api_wrappers.api_daily import get_daily
|
||||
|
||||
@pytest.fixture
|
||||
def mock_client():
|
||||
"""Create a mock client instance."""
|
||||
return Mock(spec=TushareClient)
|
||||
result = get_daily(
|
||||
ts_code=TEST_STOCK_CODES[0],
|
||||
start_date=TEST_START_DATE,
|
||||
end_date=TEST_END_DATE,
|
||||
)
|
||||
|
||||
# 验证返回了数据
|
||||
assert isinstance(result, pd.DataFrame), "get_daily 应返回 DataFrame"
|
||||
assert not result.empty, "get_daily 应返回非空数据"
|
||||
|
||||
class TestDateUtilities:
|
||||
"""Test date utility functions."""
|
||||
def test_get_daily_has_required_columns(self):
|
||||
"""测试 get_daily 返回的数据包含必要字段."""
|
||||
from src.data.api_wrappers.api_daily import get_daily
|
||||
|
||||
def test_get_today_date_format(self):
|
||||
"""Test today date is in YYYYMMDD format."""
|
||||
result = get_today_date()
|
||||
assert len(result) == 8
|
||||
assert result.isdigit()
|
||||
result = get_daily(
|
||||
ts_code=TEST_STOCK_CODES[0],
|
||||
start_date=TEST_START_DATE,
|
||||
end_date=TEST_END_DATE,
|
||||
)
|
||||
|
||||
def test_get_next_date(self):
|
||||
"""Test getting next date."""
|
||||
result = get_next_date("20240101")
|
||||
assert result == "20240102"
|
||||
# 验证必要的列存在
|
||||
required_columns = ["ts_code", "trade_date", "open", "high", "low", "close"]
|
||||
for col in required_columns:
|
||||
assert col in result.columns, f"get_daily 返回应包含 {col} 列"
|
||||
|
||||
def test_get_next_date_year_end(self):
|
||||
"""Test getting next date across year boundary."""
|
||||
result = get_next_date("20241231")
|
||||
assert result == "20250101"
|
||||
def test_get_daily_multiple_stocks(self):
|
||||
"""测试 get_daily 获取多只股票数据."""
|
||||
from src.data.api_wrappers.api_daily import get_daily
|
||||
|
||||
def test_get_next_date_month_end(self):
|
||||
"""Test getting next date across month boundary."""
|
||||
result = get_next_date("20240131")
|
||||
assert result == "20240201"
|
||||
|
||||
|
||||
class TestDataSync:
|
||||
"""Test DataSync class functionality."""
|
||||
|
||||
def test_get_all_stock_codes_from_daily(self, mock_storage):
|
||||
"""Test getting stock codes from daily data."""
|
||||
with patch("src.data.sync.ThreadSafeStorage", return_value=mock_storage):
|
||||
sync = DataSync()
|
||||
sync.storage = mock_storage
|
||||
|
||||
mock_storage.load.return_value = pd.DataFrame(
|
||||
{
|
||||
"ts_code": ["000001.SZ", "000001.SZ", "600000.SH"],
|
||||
}
|
||||
results = {}
|
||||
for code in TEST_STOCK_CODES:
|
||||
result = get_daily(
|
||||
ts_code=code,
|
||||
start_date=TEST_START_DATE,
|
||||
end_date=TEST_END_DATE,
|
||||
)
|
||||
|
||||
codes = sync.get_all_stock_codes()
|
||||
|
||||
assert len(codes) == 2
|
||||
assert "000001.SZ" in codes
|
||||
assert "600000.SH" in codes
|
||||
|
||||
def test_get_all_stock_codes_fallback(self, mock_storage):
|
||||
"""Test fallback to stock_basic when daily is empty."""
|
||||
with patch("src.data.sync.ThreadSafeStorage", return_value=mock_storage):
|
||||
sync = DataSync()
|
||||
sync.storage = mock_storage
|
||||
|
||||
# First call (daily) returns empty, second call (stock_basic) returns data
|
||||
mock_storage.load.side_effect = [
|
||||
pd.DataFrame(), # daily empty
|
||||
pd.DataFrame({"ts_code": ["000001.SZ", "600000.SH"]}), # stock_basic
|
||||
]
|
||||
|
||||
codes = sync.get_all_stock_codes()
|
||||
|
||||
assert len(codes) == 2
|
||||
|
||||
def test_get_global_last_date(self, mock_storage):
|
||||
"""Test getting global last date."""
|
||||
with patch("src.data.sync.ThreadSafeStorage", return_value=mock_storage):
|
||||
sync = DataSync()
|
||||
sync.storage = mock_storage
|
||||
|
||||
mock_storage.load.return_value = pd.DataFrame(
|
||||
{
|
||||
"ts_code": ["000001.SZ", "600000.SH"],
|
||||
"trade_date": ["20240102", "20240103"],
|
||||
}
|
||||
results[code] = result
|
||||
assert isinstance(result, pd.DataFrame), (
|
||||
f"get_daily({code}) 应返回 DataFrame"
|
||||
)
|
||||
|
||||
last_date = sync.get_global_last_date()
|
||||
assert last_date == "20240103"
|
||||
|
||||
def test_get_global_last_date_empty(self, mock_storage):
|
||||
"""Test getting last date from empty storage."""
|
||||
with patch("src.data.sync.ThreadSafeStorage", return_value=mock_storage):
|
||||
sync = DataSync()
|
||||
sync.storage = mock_storage
|
||||
|
||||
mock_storage.load.return_value = pd.DataFrame()
|
||||
|
||||
last_date = sync.get_global_last_date()
|
||||
assert last_date is None
|
||||
|
||||
def test_sync_single_stock(self, mock_storage):
|
||||
"""Test syncing a single stock."""
|
||||
with patch("src.data.sync.ThreadSafeStorage", return_value=mock_storage):
|
||||
with patch(
|
||||
"src.data.sync.get_daily",
|
||||
return_value=pd.DataFrame(
|
||||
{
|
||||
"ts_code": ["000001.SZ"],
|
||||
"trade_date": ["20240102"],
|
||||
}
|
||||
),
|
||||
):
|
||||
sync = DataSync()
|
||||
sync.storage = mock_storage
|
||||
|
||||
result = sync.sync_single_stock(
|
||||
ts_code="000001.SZ",
|
||||
start_date="20240101",
|
||||
end_date="20240102",
|
||||
)
|
||||
|
||||
assert isinstance(result, pd.DataFrame)
|
||||
assert len(result) == 1
|
||||
|
||||
def test_sync_single_stock_empty(self, mock_storage):
|
||||
"""Test syncing a stock with no data."""
|
||||
with patch("src.data.sync.ThreadSafeStorage", return_value=mock_storage):
|
||||
with patch("src.data.sync.get_daily", return_value=pd.DataFrame()):
|
||||
sync = DataSync()
|
||||
sync.storage = mock_storage
|
||||
|
||||
result = sync.sync_single_stock(
|
||||
ts_code="INVALID.SZ",
|
||||
start_date="20240101",
|
||||
end_date="20240102",
|
||||
)
|
||||
|
||||
assert result.empty
|
||||
assert not result.empty, f"get_daily({code}) 应返回非空数据"
|
||||
|
||||
|
||||
class TestSyncAll:
|
||||
"""Test sync_all function."""
|
||||
class TestSyncStockBasic:
|
||||
"""测试股票基础信息 sync 接口."""
|
||||
|
||||
def test_full_sync_mode(self, mock_storage):
|
||||
"""Test full sync mode when force_full=True."""
|
||||
with patch("src.data.sync.ThreadSafeStorage", return_value=mock_storage):
|
||||
with patch("src.data.sync.get_daily", return_value=pd.DataFrame()):
|
||||
sync = DataSync()
|
||||
sync.storage = mock_storage
|
||||
sync.sync_single_stock = Mock(return_value=pd.DataFrame())
|
||||
def test_sync_all_stocks_returns_data(self):
|
||||
"""测试 sync_all_stocks 是否能正常返回数据."""
|
||||
from src.data.api_wrappers.api_stock_basic import sync_all_stocks
|
||||
|
||||
mock_storage.load.return_value = pd.DataFrame(
|
||||
{
|
||||
"ts_code": ["000001.SZ"],
|
||||
}
|
||||
)
|
||||
result = sync_all_stocks()
|
||||
|
||||
result = sync.sync_all(force_full=True)
|
||||
# 验证返回了数据
|
||||
assert isinstance(result, pd.DataFrame), "sync_all_stocks 应返回 DataFrame"
|
||||
assert not result.empty, "sync_all_stocks 应返回非空数据"
|
||||
|
||||
# Verify sync_single_stock was called with default start date
|
||||
sync.sync_single_stock.assert_called_once()
|
||||
call_args = sync.sync_single_stock.call_args
|
||||
assert call_args[1]["start_date"] == DEFAULT_START_DATE
|
||||
def test_sync_all_stocks_has_required_columns(self):
|
||||
"""测试 sync_all_stocks 返回的数据包含必要字段."""
|
||||
from src.data.api_wrappers.api_stock_basic import sync_all_stocks
|
||||
|
||||
def test_incremental_sync_mode(self, mock_storage):
|
||||
"""Test incremental sync mode when data exists."""
|
||||
with patch("src.data.sync.ThreadSafeStorage", return_value=mock_storage):
|
||||
sync = DataSync()
|
||||
sync.storage = mock_storage
|
||||
sync.sync_single_stock = Mock(return_value=pd.DataFrame())
|
||||
result = sync_all_stocks()
|
||||
|
||||
# Mock existing data with last date
|
||||
mock_storage.load.side_effect = [
|
||||
pd.DataFrame(
|
||||
{
|
||||
"ts_code": ["000001.SZ"],
|
||||
"trade_date": ["20240102"],
|
||||
}
|
||||
), # get_all_stock_codes
|
||||
pd.DataFrame(
|
||||
{
|
||||
"ts_code": ["000001.SZ"],
|
||||
"trade_date": ["20240102"],
|
||||
}
|
||||
), # get_global_last_date
|
||||
]
|
||||
|
||||
result = sync.sync_all(force_full=False)
|
||||
|
||||
# Verify sync_single_stock was called with next date
|
||||
sync.sync_single_stock.assert_called_once()
|
||||
call_args = sync.sync_single_stock.call_args
|
||||
assert call_args[1]["start_date"] == "20240103"
|
||||
|
||||
def test_manual_start_date(self, mock_storage):
|
||||
"""Test sync with manual start date."""
|
||||
with patch("src.data.sync.ThreadSafeStorage", return_value=mock_storage):
|
||||
sync = DataSync()
|
||||
sync.storage = mock_storage
|
||||
sync.sync_single_stock = Mock(return_value=pd.DataFrame())
|
||||
|
||||
mock_storage.load.return_value = pd.DataFrame(
|
||||
{
|
||||
"ts_code": ["000001.SZ"],
|
||||
}
|
||||
)
|
||||
|
||||
result = sync.sync_all(force_full=False, start_date="20230601")
|
||||
|
||||
sync.sync_single_stock.assert_called_once()
|
||||
call_args = sync.sync_single_stock.call_args
|
||||
assert call_args[1]["start_date"] == "20230601"
|
||||
|
||||
def test_no_stocks_found(self, mock_storage):
|
||||
"""Test sync when no stocks are found."""
|
||||
with patch("src.data.sync.ThreadSafeStorage", return_value=mock_storage):
|
||||
sync = DataSync()
|
||||
sync.storage = mock_storage
|
||||
|
||||
mock_storage.load.return_value = pd.DataFrame()
|
||||
|
||||
result = sync.sync_all()
|
||||
|
||||
assert result == {}
|
||||
# 验证必要的列存在
|
||||
required_columns = ["ts_code"]
|
||||
for col in required_columns:
|
||||
assert col in result.columns, f"sync_all_stocks 返回应包含 {col} 列"
|
||||
|
||||
|
||||
class TestSyncAllConvenienceFunction:
|
||||
"""Test sync_all convenience function."""
|
||||
class TestSyncTradeCal:
|
||||
"""测试交易日历 sync 接口."""
|
||||
|
||||
def test_sync_all_function(self):
|
||||
"""Test sync_all convenience function."""
|
||||
with patch("src.data.sync.DataSync") as MockSync:
|
||||
mock_instance = Mock()
|
||||
mock_instance.sync_all.return_value = {}
|
||||
MockSync.return_value = mock_instance
|
||||
def test_sync_trade_cal_cache_returns_data(self):
|
||||
"""测试 sync_trade_cal_cache 是否能正常返回数据."""
|
||||
from src.data.api_wrappers.api_trade_cal import sync_trade_cal_cache
|
||||
|
||||
result = sync_all(force_full=True)
|
||||
result = sync_trade_cal_cache(
|
||||
start_date=TEST_START_DATE,
|
||||
end_date=TEST_END_DATE,
|
||||
)
|
||||
|
||||
MockSync.assert_called_once()
|
||||
mock_instance.sync_all.assert_called_once_with(
|
||||
force_full=True,
|
||||
start_date=None,
|
||||
end_date=None,
|
||||
dry_run=False,
|
||||
)
|
||||
# 验证返回了数据
|
||||
assert isinstance(result, pd.DataFrame), "sync_trade_cal_cache 应返回 DataFrame"
|
||||
assert not result.empty, "sync_trade_cal_cache 应返回非空数据"
|
||||
|
||||
def test_sync_trade_cal_cache_has_required_columns(self):
|
||||
"""测试 sync_trade_cal_cache 返回的数据包含必要字段."""
|
||||
from src.data.api_wrappers.api_trade_cal import sync_trade_cal_cache
|
||||
|
||||
result = sync_trade_cal_cache(
|
||||
start_date=TEST_START_DATE,
|
||||
end_date=TEST_END_DATE,
|
||||
)
|
||||
|
||||
# 验证必要的列存在
|
||||
required_columns = ["cal_date", "is_open"]
|
||||
for col in required_columns:
|
||||
assert col in result.columns, f"sync_trade_cal_cache 返回应包含 {col} 列"
|
||||
|
||||
|
||||
class TestSyncNamechange:
|
||||
"""测试名称变更 sync 接口."""
|
||||
|
||||
def test_sync_namechange_returns_data(self):
|
||||
"""测试 sync_namechange 是否能正常返回数据."""
|
||||
from src.data.api_wrappers.api_namechange import sync_namechange
|
||||
|
||||
result = sync_namechange()
|
||||
|
||||
# 验证返回了数据(可能是空 DataFrame,因为是历史变更)
|
||||
assert isinstance(result, pd.DataFrame), "sync_namechange 应返回 DataFrame"
|
||||
|
||||
|
||||
class TestSyncBakBasic:
|
||||
"""测试备用股票基础信息 sync 接口."""
|
||||
|
||||
def test_sync_bak_basic_returns_data(self):
|
||||
"""测试 sync_bak_basic 是否能正常返回数据."""
|
||||
from src.data.api_wrappers.api_bak_basic import sync_bak_basic
|
||||
|
||||
result = sync_bak_basic(
|
||||
start_date=TEST_START_DATE,
|
||||
end_date=TEST_END_DATE,
|
||||
)
|
||||
|
||||
# 验证返回了数据
|
||||
assert isinstance(result, pd.DataFrame), "sync_bak_basic 应返回 DataFrame"
|
||||
# 注意:bak_basic 可能返回空数据,这是正常的
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -1,256 +0,0 @@
|
||||
"""Tests for data sync with REAL data (read-only).
|
||||
|
||||
Tests verify:
|
||||
1. get_global_last_date() correctly reads local data's max date
|
||||
2. Incremental sync date calculation (local_last_date + 1)
|
||||
3. Full sync date calculation (20180101)
|
||||
4. Multi-stock scenario with real data
|
||||
|
||||
⚠️ IMPORTANT: These tests ONLY read data, no write operations.
|
||||
- NO sync_all() calls (writes daily.h5)
|
||||
- NO check_sync_needed() calls (writes trade_cal.h5)
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import pandas as pd
|
||||
from pathlib import Path
|
||||
|
||||
from src.data.sync import (
|
||||
DataSync,
|
||||
get_next_date,
|
||||
DEFAULT_START_DATE,
|
||||
)
|
||||
from src.data.storage import Storage
|
||||
|
||||
|
||||
class TestDataSyncReadOnly:
|
||||
"""Read-only tests for data sync - verify date calculation logic."""
|
||||
|
||||
@pytest.fixture
|
||||
def storage(self):
|
||||
"""Create storage instance."""
|
||||
return Storage()
|
||||
|
||||
@pytest.fixture
|
||||
def data_sync(self):
|
||||
"""Create DataSync instance."""
|
||||
return DataSync()
|
||||
|
||||
@pytest.fixture
|
||||
def daily_exists(self, storage):
|
||||
"""Check if daily.h5 exists."""
|
||||
return storage.exists("daily")
|
||||
|
||||
def test_daily_h5_exists(self, storage):
|
||||
"""Verify daily.h5 data file exists before running tests."""
|
||||
assert storage.exists("daily"), (
|
||||
"daily.h5 not found. Please run full sync first: "
|
||||
"uv run python -c 'from src.data.sync import sync_all; sync_all(force_full=True)'"
|
||||
)
|
||||
|
||||
def test_get_global_last_date(self, data_sync, daily_exists):
|
||||
"""Test get_global_last_date returns correct max date from local data."""
|
||||
if not daily_exists:
|
||||
pytest.skip("daily.h5 not found")
|
||||
|
||||
last_date = data_sync.get_global_last_date()
|
||||
|
||||
# Verify it's a valid date string
|
||||
assert last_date is not None, "get_global_last_date returned None"
|
||||
assert isinstance(last_date, str), f"Expected str, got {type(last_date)}"
|
||||
assert len(last_date) == 8, f"Expected 8-digit date, got {last_date}"
|
||||
assert last_date.isdigit(), f"Expected numeric date, got {last_date}"
|
||||
|
||||
# Verify by reading storage directly
|
||||
daily_data = data_sync.storage.load("daily")
|
||||
expected_max = str(daily_data["trade_date"].max())
|
||||
|
||||
assert last_date == expected_max, (
|
||||
f"get_global_last_date returned {last_date}, "
|
||||
f"but actual max date is {expected_max}"
|
||||
)
|
||||
|
||||
print(f"[TEST] Local data last date: {last_date}")
|
||||
|
||||
def test_incremental_sync_date_calculation(self, data_sync, daily_exists):
|
||||
"""Test incremental sync: start_date = local_last_date + 1.
|
||||
|
||||
This verifies that when local data exists, incremental sync should
|
||||
fetch data from (local_last_date + 1), not from 20180101.
|
||||
"""
|
||||
if not daily_exists:
|
||||
pytest.skip("daily.h5 not found")
|
||||
|
||||
# Get local last date
|
||||
local_last_date = data_sync.get_global_last_date()
|
||||
assert local_last_date is not None, "No local data found"
|
||||
|
||||
# Calculate expected incremental start date
|
||||
expected_start_date = get_next_date(local_last_date)
|
||||
|
||||
# Verify the calculation is correct
|
||||
local_last_int = int(local_last_date)
|
||||
expected_int = local_last_int + 1
|
||||
actual_int = int(expected_start_date)
|
||||
|
||||
assert actual_int == expected_int, (
|
||||
f"Incremental start date calculation error: "
|
||||
f"expected {expected_int}, got {actual_int}"
|
||||
)
|
||||
|
||||
print(
|
||||
f"[TEST] Incremental sync: local_last={local_last_date}, "
|
||||
f"start_date should be {expected_start_date}"
|
||||
)
|
||||
|
||||
# Verify this is NOT 20180101 (would be full sync)
|
||||
assert expected_start_date != DEFAULT_START_DATE, (
|
||||
f"Incremental sync should NOT start from {DEFAULT_START_DATE}"
|
||||
)
|
||||
|
||||
def test_full_sync_date_calculation(self):
|
||||
"""Test full sync: start_date = 20180101 when force_full=True.
|
||||
|
||||
This verifies that force_full=True always starts from 20180101.
|
||||
"""
|
||||
# Full sync should always use DEFAULT_START_DATE
|
||||
full_sync_start = DEFAULT_START_DATE
|
||||
|
||||
assert full_sync_start == "20180101", (
|
||||
f"Full sync should start from 20180101, got {full_sync_start}"
|
||||
)
|
||||
|
||||
print(f"[TEST] Full sync start date: {full_sync_start}")
|
||||
|
||||
def test_date_comparison_logic(self, data_sync, daily_exists):
|
||||
"""Test date comparison: incremental vs full sync selection logic.
|
||||
|
||||
Verify that:
|
||||
- If local_last_date < today: incremental sync needed
|
||||
- If local_last_date >= today: no sync needed
|
||||
"""
|
||||
if not daily_exists:
|
||||
pytest.skip("daily.h5 not found")
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
local_last_date = data_sync.get_global_last_date()
|
||||
today = datetime.now().strftime("%Y%m%d")
|
||||
|
||||
local_last_int = int(local_last_date)
|
||||
today_int = int(today)
|
||||
|
||||
# Log the comparison
|
||||
print(
|
||||
f"[TEST] Date comparison: local_last={local_last_date} ({local_last_int}), "
|
||||
f"today={today} ({today_int})"
|
||||
)
|
||||
|
||||
# This test just verifies the comparison logic works
|
||||
if local_last_int < today_int:
|
||||
print("[TEST] Local data is older than today - sync needed")
|
||||
# Incremental sync should fetch from local_last_date + 1
|
||||
sync_start = get_next_date(local_last_date)
|
||||
assert int(sync_start) > local_last_int, (
|
||||
"Sync start should be after local last"
|
||||
)
|
||||
else:
|
||||
print("[TEST] Local data is up-to-date - no sync needed")
|
||||
|
||||
def test_get_all_stock_codes_real_data(self, data_sync, daily_exists):
|
||||
"""Test get_all_stock_codes returns multiple real stock codes."""
|
||||
if not daily_exists:
|
||||
pytest.skip("daily.h5 not found")
|
||||
|
||||
codes = data_sync.get_all_stock_codes()
|
||||
|
||||
# Verify it's a list
|
||||
assert isinstance(codes, list), f"Expected list, got {type(codes)}"
|
||||
assert len(codes) > 0, "No stock codes found"
|
||||
|
||||
# Verify multiple stocks
|
||||
assert len(codes) >= 10, (
|
||||
f"Expected at least 10 stocks for multi-stock test, got {len(codes)}"
|
||||
)
|
||||
|
||||
# Verify format (should be like 000001.SZ, 600000.SH)
|
||||
sample_codes = codes[:5]
|
||||
for code in sample_codes:
|
||||
assert "." in code, f"Invalid stock code format: {code}"
|
||||
suffix = code.split(".")[-1]
|
||||
assert suffix in ["SZ", "SH"], f"Invalid exchange suffix: {suffix}"
|
||||
|
||||
print(f"[TEST] Found {len(codes)} stock codes (sample: {sample_codes})")
|
||||
|
||||
def test_multi_stock_date_range(self, data_sync, daily_exists):
|
||||
"""Test that multiple stocks share the same date range in local data.
|
||||
|
||||
This verifies that local data has consistent date coverage across stocks.
|
||||
"""
|
||||
if not daily_exists:
|
||||
pytest.skip("daily.h5 not found")
|
||||
|
||||
daily_data = data_sync.storage.load("daily")
|
||||
|
||||
# Get date range for each stock
|
||||
stock_dates = daily_data.groupby("ts_code")["trade_date"].agg(["min", "max"])
|
||||
|
||||
# Get global min and max
|
||||
global_min = str(daily_data["trade_date"].min())
|
||||
global_max = str(daily_data["trade_date"].max())
|
||||
|
||||
print(f"[TEST] Global date range: {global_min} to {global_max}")
|
||||
print(f"[TEST] Total stocks: {len(stock_dates)}")
|
||||
|
||||
# Verify we have data for multiple stocks
|
||||
assert len(stock_dates) >= 10, (
|
||||
f"Expected at least 10 stocks, got {len(stock_dates)}"
|
||||
)
|
||||
|
||||
# Verify date range is reasonable (at least 1 year of data)
|
||||
global_min_int = int(global_min)
|
||||
global_max_int = int(global_max)
|
||||
days_span = global_max_int - global_min_int
|
||||
|
||||
assert days_span > 100, (
|
||||
f"Date range too small: {days_span} days. "
|
||||
f"Expected at least 100 days of data."
|
||||
)
|
||||
|
||||
print(f"[TEST] Date span: {days_span} days")
|
||||
|
||||
|
||||
class TestDateUtilities:
|
||||
"""Test date utility functions."""
|
||||
|
||||
def test_get_next_date(self):
|
||||
"""Test get_next_date correctly calculates next day."""
|
||||
# Test normal cases
|
||||
assert get_next_date("20240101") == "20240102"
|
||||
assert get_next_date("20240131") == "20240201" # Month boundary
|
||||
assert get_next_date("20241231") == "20250101" # Year boundary
|
||||
|
||||
def test_incremental_vs_full_sync_logic(self):
|
||||
"""Test the logic difference between incremental and full sync.
|
||||
|
||||
Incremental: start_date = local_last_date + 1
|
||||
Full: start_date = 20180101
|
||||
"""
|
||||
# Scenario 1: Local data exists
|
||||
local_last_date = "20240115"
|
||||
incremental_start = get_next_date(local_last_date)
|
||||
|
||||
assert incremental_start == "20240116"
|
||||
assert incremental_start != DEFAULT_START_DATE
|
||||
|
||||
# Scenario 2: Force full sync
|
||||
full_sync_start = DEFAULT_START_DATE # "20180101"
|
||||
|
||||
assert full_sync_start == "20180101"
|
||||
assert incremental_start != full_sync_start
|
||||
|
||||
print("[TEST] Incremental vs Full sync logic verified")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v", "-s"])
|
||||
Reference in New Issue
Block a user