refactor: 存储层迁移DuckDB + 模块重构
- 存储层重构: HDF5 → DuckDB(UPSERT模式、线程安全存储)
- Sync类迁移: DataSync从sync.py迁移到api_daily.py(职责分离)
- 模型模块重构: src/models → src/pipeline(更清晰的命名)
- 新增因子模块: factors/momentum(MA、收益率排名)、factors/financial
- 新增API接口: api_namechange、api_bak_basic
- 新增训练入口: training模块(main.py、pipeline配置)
- 工具函数统一: get_today_date等移至utils.py
- 文档更新: AGENTS.md添加架构变更历史
This commit is contained in:
@@ -1,277 +1,163 @@
|
||||
"""Tests for data synchronization module.
|
||||
"""Sync 接口测试规范与实现。
|
||||
|
||||
Tests the sync module's full/incremental sync logic for daily data:
|
||||
- Full sync when local data doesn't exist (from 20180101)
|
||||
- Incremental sync when local data exists (from last_date + 1)
|
||||
- Data integrity validation
|
||||
【测试规范】
|
||||
1. 所有 sync 测试只使用 2018-01-01 到 2018-04-01 的数据
|
||||
2. 只测试接口是否能正常返回数据,不测试落库逻辑
|
||||
3. 对于按股票查询的接口,只测试 000001.SZ、000002.SZ 两支股票
|
||||
4. 使用真实 API 调用,确保接口可用性
|
||||
|
||||
【测试范围】
|
||||
- get_daily: 日线数据接口(按股票)
|
||||
- sync_all_stocks: 股票基础信息接口
|
||||
- sync_trade_cal_cache: 交易日历接口
|
||||
- sync_namechange: 名称变更接口
|
||||
- sync_bak_basic: 备用股票基础信息接口
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import pandas as pd
|
||||
from unittest.mock import Mock, patch, MagicMock
|
||||
from datetime import datetime, timedelta
|
||||
from datetime import datetime
|
||||
|
||||
from src.data.sync import (
|
||||
DataSync,
|
||||
sync_all,
|
||||
get_today_date,
|
||||
get_next_date,
|
||||
DEFAULT_START_DATE,
|
||||
)
|
||||
from src.data.storage import ThreadSafeStorage
|
||||
from src.data.client import TushareClient
|
||||
# 测试用常量
|
||||
TEST_START_DATE = "20180101"
|
||||
TEST_END_DATE = "20180401"
|
||||
TEST_STOCK_CODES = ["000001.SZ", "000002.SZ"]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_storage():
|
||||
"""Create a mock storage instance."""
|
||||
storage = Mock(spec=ThreadSafeStorage)
|
||||
storage.exists = Mock(return_value=False)
|
||||
storage.load = Mock(return_value=pd.DataFrame())
|
||||
storage.save = Mock(return_value={"status": "success", "rows": 0})
|
||||
return storage
|
||||
class TestGetDaily:
|
||||
"""测试日线数据 get 接口(按股票查询)."""
|
||||
|
||||
def test_get_daily_single_stock(self):
|
||||
"""测试 get_daily 获取单只股票数据."""
|
||||
from src.data.api_wrappers.api_daily import get_daily
|
||||
|
||||
@pytest.fixture
|
||||
def mock_client():
|
||||
"""Create a mock client instance."""
|
||||
return Mock(spec=TushareClient)
|
||||
result = get_daily(
|
||||
ts_code=TEST_STOCK_CODES[0],
|
||||
start_date=TEST_START_DATE,
|
||||
end_date=TEST_END_DATE,
|
||||
)
|
||||
|
||||
# 验证返回了数据
|
||||
assert isinstance(result, pd.DataFrame), "get_daily 应返回 DataFrame"
|
||||
assert not result.empty, "get_daily 应返回非空数据"
|
||||
|
||||
class TestDateUtilities:
|
||||
"""Test date utility functions."""
|
||||
def test_get_daily_has_required_columns(self):
|
||||
"""测试 get_daily 返回的数据包含必要字段."""
|
||||
from src.data.api_wrappers.api_daily import get_daily
|
||||
|
||||
def test_get_today_date_format(self):
|
||||
"""Test today date is in YYYYMMDD format."""
|
||||
result = get_today_date()
|
||||
assert len(result) == 8
|
||||
assert result.isdigit()
|
||||
result = get_daily(
|
||||
ts_code=TEST_STOCK_CODES[0],
|
||||
start_date=TEST_START_DATE,
|
||||
end_date=TEST_END_DATE,
|
||||
)
|
||||
|
||||
def test_get_next_date(self):
|
||||
"""Test getting next date."""
|
||||
result = get_next_date("20240101")
|
||||
assert result == "20240102"
|
||||
# 验证必要的列存在
|
||||
required_columns = ["ts_code", "trade_date", "open", "high", "low", "close"]
|
||||
for col in required_columns:
|
||||
assert col in result.columns, f"get_daily 返回应包含 {col} 列"
|
||||
|
||||
def test_get_next_date_year_end(self):
|
||||
"""Test getting next date across year boundary."""
|
||||
result = get_next_date("20241231")
|
||||
assert result == "20250101"
|
||||
def test_get_daily_multiple_stocks(self):
|
||||
"""测试 get_daily 获取多只股票数据."""
|
||||
from src.data.api_wrappers.api_daily import get_daily
|
||||
|
||||
def test_get_next_date_month_end(self):
|
||||
"""Test getting next date across month boundary."""
|
||||
result = get_next_date("20240131")
|
||||
assert result == "20240201"
|
||||
|
||||
|
||||
class TestDataSync:
|
||||
"""Test DataSync class functionality."""
|
||||
|
||||
def test_get_all_stock_codes_from_daily(self, mock_storage):
|
||||
"""Test getting stock codes from daily data."""
|
||||
with patch("src.data.sync.ThreadSafeStorage", return_value=mock_storage):
|
||||
sync = DataSync()
|
||||
sync.storage = mock_storage
|
||||
|
||||
mock_storage.load.return_value = pd.DataFrame(
|
||||
{
|
||||
"ts_code": ["000001.SZ", "000001.SZ", "600000.SH"],
|
||||
}
|
||||
results = {}
|
||||
for code in TEST_STOCK_CODES:
|
||||
result = get_daily(
|
||||
ts_code=code,
|
||||
start_date=TEST_START_DATE,
|
||||
end_date=TEST_END_DATE,
|
||||
)
|
||||
|
||||
codes = sync.get_all_stock_codes()
|
||||
|
||||
assert len(codes) == 2
|
||||
assert "000001.SZ" in codes
|
||||
assert "600000.SH" in codes
|
||||
|
||||
def test_get_all_stock_codes_fallback(self, mock_storage):
|
||||
"""Test fallback to stock_basic when daily is empty."""
|
||||
with patch("src.data.sync.ThreadSafeStorage", return_value=mock_storage):
|
||||
sync = DataSync()
|
||||
sync.storage = mock_storage
|
||||
|
||||
# First call (daily) returns empty, second call (stock_basic) returns data
|
||||
mock_storage.load.side_effect = [
|
||||
pd.DataFrame(), # daily empty
|
||||
pd.DataFrame({"ts_code": ["000001.SZ", "600000.SH"]}), # stock_basic
|
||||
]
|
||||
|
||||
codes = sync.get_all_stock_codes()
|
||||
|
||||
assert len(codes) == 2
|
||||
|
||||
def test_get_global_last_date(self, mock_storage):
|
||||
"""Test getting global last date."""
|
||||
with patch("src.data.sync.ThreadSafeStorage", return_value=mock_storage):
|
||||
sync = DataSync()
|
||||
sync.storage = mock_storage
|
||||
|
||||
mock_storage.load.return_value = pd.DataFrame(
|
||||
{
|
||||
"ts_code": ["000001.SZ", "600000.SH"],
|
||||
"trade_date": ["20240102", "20240103"],
|
||||
}
|
||||
results[code] = result
|
||||
assert isinstance(result, pd.DataFrame), (
|
||||
f"get_daily({code}) 应返回 DataFrame"
|
||||
)
|
||||
|
||||
last_date = sync.get_global_last_date()
|
||||
assert last_date == "20240103"
|
||||
|
||||
def test_get_global_last_date_empty(self, mock_storage):
|
||||
"""Test getting last date from empty storage."""
|
||||
with patch("src.data.sync.ThreadSafeStorage", return_value=mock_storage):
|
||||
sync = DataSync()
|
||||
sync.storage = mock_storage
|
||||
|
||||
mock_storage.load.return_value = pd.DataFrame()
|
||||
|
||||
last_date = sync.get_global_last_date()
|
||||
assert last_date is None
|
||||
|
||||
def test_sync_single_stock(self, mock_storage):
|
||||
"""Test syncing a single stock."""
|
||||
with patch("src.data.sync.ThreadSafeStorage", return_value=mock_storage):
|
||||
with patch(
|
||||
"src.data.sync.get_daily",
|
||||
return_value=pd.DataFrame(
|
||||
{
|
||||
"ts_code": ["000001.SZ"],
|
||||
"trade_date": ["20240102"],
|
||||
}
|
||||
),
|
||||
):
|
||||
sync = DataSync()
|
||||
sync.storage = mock_storage
|
||||
|
||||
result = sync.sync_single_stock(
|
||||
ts_code="000001.SZ",
|
||||
start_date="20240101",
|
||||
end_date="20240102",
|
||||
)
|
||||
|
||||
assert isinstance(result, pd.DataFrame)
|
||||
assert len(result) == 1
|
||||
|
||||
def test_sync_single_stock_empty(self, mock_storage):
|
||||
"""Test syncing a stock with no data."""
|
||||
with patch("src.data.sync.ThreadSafeStorage", return_value=mock_storage):
|
||||
with patch("src.data.sync.get_daily", return_value=pd.DataFrame()):
|
||||
sync = DataSync()
|
||||
sync.storage = mock_storage
|
||||
|
||||
result = sync.sync_single_stock(
|
||||
ts_code="INVALID.SZ",
|
||||
start_date="20240101",
|
||||
end_date="20240102",
|
||||
)
|
||||
|
||||
assert result.empty
|
||||
assert not result.empty, f"get_daily({code}) 应返回非空数据"
|
||||
|
||||
|
||||
class TestSyncAll:
|
||||
"""Test sync_all function."""
|
||||
class TestSyncStockBasic:
|
||||
"""测试股票基础信息 sync 接口."""
|
||||
|
||||
def test_full_sync_mode(self, mock_storage):
|
||||
"""Test full sync mode when force_full=True."""
|
||||
with patch("src.data.sync.ThreadSafeStorage", return_value=mock_storage):
|
||||
with patch("src.data.sync.get_daily", return_value=pd.DataFrame()):
|
||||
sync = DataSync()
|
||||
sync.storage = mock_storage
|
||||
sync.sync_single_stock = Mock(return_value=pd.DataFrame())
|
||||
def test_sync_all_stocks_returns_data(self):
|
||||
"""测试 sync_all_stocks 是否能正常返回数据."""
|
||||
from src.data.api_wrappers.api_stock_basic import sync_all_stocks
|
||||
|
||||
mock_storage.load.return_value = pd.DataFrame(
|
||||
{
|
||||
"ts_code": ["000001.SZ"],
|
||||
}
|
||||
)
|
||||
result = sync_all_stocks()
|
||||
|
||||
result = sync.sync_all(force_full=True)
|
||||
# 验证返回了数据
|
||||
assert isinstance(result, pd.DataFrame), "sync_all_stocks 应返回 DataFrame"
|
||||
assert not result.empty, "sync_all_stocks 应返回非空数据"
|
||||
|
||||
# Verify sync_single_stock was called with default start date
|
||||
sync.sync_single_stock.assert_called_once()
|
||||
call_args = sync.sync_single_stock.call_args
|
||||
assert call_args[1]["start_date"] == DEFAULT_START_DATE
|
||||
def test_sync_all_stocks_has_required_columns(self):
|
||||
"""测试 sync_all_stocks 返回的数据包含必要字段."""
|
||||
from src.data.api_wrappers.api_stock_basic import sync_all_stocks
|
||||
|
||||
def test_incremental_sync_mode(self, mock_storage):
|
||||
"""Test incremental sync mode when data exists."""
|
||||
with patch("src.data.sync.ThreadSafeStorage", return_value=mock_storage):
|
||||
sync = DataSync()
|
||||
sync.storage = mock_storage
|
||||
sync.sync_single_stock = Mock(return_value=pd.DataFrame())
|
||||
result = sync_all_stocks()
|
||||
|
||||
# Mock existing data with last date
|
||||
mock_storage.load.side_effect = [
|
||||
pd.DataFrame(
|
||||
{
|
||||
"ts_code": ["000001.SZ"],
|
||||
"trade_date": ["20240102"],
|
||||
}
|
||||
), # get_all_stock_codes
|
||||
pd.DataFrame(
|
||||
{
|
||||
"ts_code": ["000001.SZ"],
|
||||
"trade_date": ["20240102"],
|
||||
}
|
||||
), # get_global_last_date
|
||||
]
|
||||
|
||||
result = sync.sync_all(force_full=False)
|
||||
|
||||
# Verify sync_single_stock was called with next date
|
||||
sync.sync_single_stock.assert_called_once()
|
||||
call_args = sync.sync_single_stock.call_args
|
||||
assert call_args[1]["start_date"] == "20240103"
|
||||
|
||||
def test_manual_start_date(self, mock_storage):
|
||||
"""Test sync with manual start date."""
|
||||
with patch("src.data.sync.ThreadSafeStorage", return_value=mock_storage):
|
||||
sync = DataSync()
|
||||
sync.storage = mock_storage
|
||||
sync.sync_single_stock = Mock(return_value=pd.DataFrame())
|
||||
|
||||
mock_storage.load.return_value = pd.DataFrame(
|
||||
{
|
||||
"ts_code": ["000001.SZ"],
|
||||
}
|
||||
)
|
||||
|
||||
result = sync.sync_all(force_full=False, start_date="20230601")
|
||||
|
||||
sync.sync_single_stock.assert_called_once()
|
||||
call_args = sync.sync_single_stock.call_args
|
||||
assert call_args[1]["start_date"] == "20230601"
|
||||
|
||||
def test_no_stocks_found(self, mock_storage):
|
||||
"""Test sync when no stocks are found."""
|
||||
with patch("src.data.sync.ThreadSafeStorage", return_value=mock_storage):
|
||||
sync = DataSync()
|
||||
sync.storage = mock_storage
|
||||
|
||||
mock_storage.load.return_value = pd.DataFrame()
|
||||
|
||||
result = sync.sync_all()
|
||||
|
||||
assert result == {}
|
||||
# 验证必要的列存在
|
||||
required_columns = ["ts_code"]
|
||||
for col in required_columns:
|
||||
assert col in result.columns, f"sync_all_stocks 返回应包含 {col} 列"
|
||||
|
||||
|
||||
class TestSyncAllConvenienceFunction:
|
||||
"""Test sync_all convenience function."""
|
||||
class TestSyncTradeCal:
|
||||
"""测试交易日历 sync 接口."""
|
||||
|
||||
def test_sync_all_function(self):
|
||||
"""Test sync_all convenience function."""
|
||||
with patch("src.data.sync.DataSync") as MockSync:
|
||||
mock_instance = Mock()
|
||||
mock_instance.sync_all.return_value = {}
|
||||
MockSync.return_value = mock_instance
|
||||
def test_sync_trade_cal_cache_returns_data(self):
|
||||
"""测试 sync_trade_cal_cache 是否能正常返回数据."""
|
||||
from src.data.api_wrappers.api_trade_cal import sync_trade_cal_cache
|
||||
|
||||
result = sync_all(force_full=True)
|
||||
result = sync_trade_cal_cache(
|
||||
start_date=TEST_START_DATE,
|
||||
end_date=TEST_END_DATE,
|
||||
)
|
||||
|
||||
MockSync.assert_called_once()
|
||||
mock_instance.sync_all.assert_called_once_with(
|
||||
force_full=True,
|
||||
start_date=None,
|
||||
end_date=None,
|
||||
dry_run=False,
|
||||
)
|
||||
# 验证返回了数据
|
||||
assert isinstance(result, pd.DataFrame), "sync_trade_cal_cache 应返回 DataFrame"
|
||||
assert not result.empty, "sync_trade_cal_cache 应返回非空数据"
|
||||
|
||||
def test_sync_trade_cal_cache_has_required_columns(self):
|
||||
"""测试 sync_trade_cal_cache 返回的数据包含必要字段."""
|
||||
from src.data.api_wrappers.api_trade_cal import sync_trade_cal_cache
|
||||
|
||||
result = sync_trade_cal_cache(
|
||||
start_date=TEST_START_DATE,
|
||||
end_date=TEST_END_DATE,
|
||||
)
|
||||
|
||||
# 验证必要的列存在
|
||||
required_columns = ["cal_date", "is_open"]
|
||||
for col in required_columns:
|
||||
assert col in result.columns, f"sync_trade_cal_cache 返回应包含 {col} 列"
|
||||
|
||||
|
||||
class TestSyncNamechange:
|
||||
"""测试名称变更 sync 接口."""
|
||||
|
||||
def test_sync_namechange_returns_data(self):
|
||||
"""测试 sync_namechange 是否能正常返回数据."""
|
||||
from src.data.api_wrappers.api_namechange import sync_namechange
|
||||
|
||||
result = sync_namechange()
|
||||
|
||||
# 验证返回了数据(可能是空 DataFrame,因为是历史变更)
|
||||
assert isinstance(result, pd.DataFrame), "sync_namechange 应返回 DataFrame"
|
||||
|
||||
|
||||
class TestSyncBakBasic:
|
||||
"""测试备用股票基础信息 sync 接口."""
|
||||
|
||||
def test_sync_bak_basic_returns_data(self):
|
||||
"""测试 sync_bak_basic 是否能正常返回数据."""
|
||||
from src.data.api_wrappers.api_bak_basic import sync_bak_basic
|
||||
|
||||
result = sync_bak_basic(
|
||||
start_date=TEST_START_DATE,
|
||||
end_date=TEST_END_DATE,
|
||||
)
|
||||
|
||||
# 验证返回了数据
|
||||
assert isinstance(result, pd.DataFrame), "sync_bak_basic 应返回 DataFrame"
|
||||
# 注意:bak_basic 可能返回空数据,这是正常的
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user