256 lines
8.8 KiB
Python
256 lines
8.8 KiB
Python
|
|
"""Tests for data synchronization module.
|
||
|
|
|
||
|
|
Tests the sync module's full/incremental sync logic for daily data:
|
||
|
|
- Full sync when local data doesn't exist (from 20180101)
|
||
|
|
- Incremental sync when local data exists (from last_date + 1)
|
||
|
|
- Data integrity validation
|
||
|
|
"""
|
||
|
|
from datetime import datetime, timedelta
from unittest.mock import MagicMock, Mock, patch

import pandas as pd
import pytest

from src.data.sync import (
    DEFAULT_START_DATE,
    DataSync,
    Storage,
    get_next_date,
    get_today_date,
    sync_all,
)
|
||
|
|
|
||
|
|
|
||
|
|
class TestDateUtilities:
    """Test date utility functions."""

    def test_get_today_date_format(self):
        """Test today date is a real calendar date in YYYYMMDD format."""
        result = get_today_date()
        assert len(result) == 8
        assert result.isdigit()
        # Eight digits can still be an impossible date (e.g. "20241399");
        # strptime rejects those, so this pins "valid date", not just "digits".
        parsed = datetime.strptime(result, "%Y%m%d").date()
        # NOTE(review): the implementation may use a market timezone rather
        # than local time, so allow one day of slack instead of exact equality.
        assert abs((parsed - datetime.now().date()).days) <= 1

    def test_get_next_date(self):
        """Test getting next date."""
        result = get_next_date("20240101")
        assert result == "20240102"

    def test_get_next_date_year_end(self):
        """Test getting next date across year boundary."""
        result = get_next_date("20241231")
        assert result == "20250101"

    def test_get_next_date_month_end(self):
        """Test getting next date across month boundary."""
        result = get_next_date("20240131")
        assert result == "20240201"

    def test_get_next_date_leap_year(self):
        """Test getting next date around February in leap and common years."""
        # 2024 is a leap year: Feb 28 -> Feb 29 -> Mar 1.
        assert get_next_date("20240228") == "20240229"
        assert get_next_date("20240229") == "20240301"
        # 2023 is not: Feb 28 rolls straight over to Mar 1.
        assert get_next_date("20230228") == "20230301"
|
||
|
|
|
||
|
|
|
||
|
|
@pytest.fixture
def mock_storage():
    """Create a mock storage instance.

    Defined at module level so every test class in this file can request it.
    A fixture defined inside a class is only visible to that class's tests,
    and ``TestSyncAll`` depends on this fixture as well as ``TestDataSync``.
    """
    storage = Mock(spec=Storage)
    storage.exists = Mock(return_value=False)
    storage.load = Mock(return_value=pd.DataFrame())
    storage.save = Mock(return_value={"status": "success", "rows": 0})
    return storage


@pytest.fixture
def mock_client():
    """Create a mock client instance (not currently requested by any test).

    ``TushareClient`` is never imported in this module, so a bare
    ``Mock(spec=TushareClient)`` would raise ``NameError`` when requested.
    The spec is looked up on ``src.data.sync`` instead, falling back to an
    unspecced ``Mock`` when that module does not expose the name.
    NOTE(review): confirm the real import path of ``TushareClient``.
    """
    import src.data.sync as sync_module

    return Mock(spec=getattr(sync_module, "TushareClient", None))


class TestDataSync:
    """Test DataSync class functionality."""

    def test_get_all_stock_codes_from_daily(self, mock_storage):
        """Test getting de-duplicated stock codes from daily data."""
        with patch('src.data.sync.Storage', return_value=mock_storage):
            sync = DataSync()
            sync.storage = mock_storage

            # Three rows but only two distinct codes.
            mock_storage.load.return_value = pd.DataFrame({
                'ts_code': ['000001.SZ', '000001.SZ', '600000.SH'],
            })

            codes = sync.get_all_stock_codes()

            assert len(codes) == 2
            assert '000001.SZ' in codes
            assert '600000.SH' in codes

    def test_get_all_stock_codes_fallback(self, mock_storage):
        """Test fallback to stock_basic when daily is empty."""
        with patch('src.data.sync.Storage', return_value=mock_storage):
            sync = DataSync()
            sync.storage = mock_storage

            # First call (daily) returns empty, second call (stock_basic) returns data
            mock_storage.load.side_effect = [
                pd.DataFrame(),  # daily empty
                pd.DataFrame({'ts_code': ['000001.SZ', '600000.SH']}),  # stock_basic
            ]

            codes = sync.get_all_stock_codes()

            # Check the actual codes, not just the count, and that the
            # fallback source was really consulted.
            assert len(codes) == 2
            assert set(codes) == {'000001.SZ', '600000.SH'}
            assert mock_storage.load.call_count == 2

    def test_get_global_last_date(self, mock_storage):
        """Test getting global last date."""
        with patch('src.data.sync.Storage', return_value=mock_storage):
            sync = DataSync()
            sync.storage = mock_storage

            mock_storage.load.return_value = pd.DataFrame({
                'ts_code': ['000001.SZ', '600000.SH'],
                'trade_date': ['20240102', '20240103'],
            })

            last_date = sync.get_global_last_date()
            assert last_date == '20240103'

    def test_get_global_last_date_empty(self, mock_storage):
        """Test getting last date from empty storage."""
        with patch('src.data.sync.Storage', return_value=mock_storage):
            sync = DataSync()
            sync.storage = mock_storage

            mock_storage.load.return_value = pd.DataFrame()

            last_date = sync.get_global_last_date()
            assert last_date is None

    def test_sync_single_stock(self, mock_storage):
        """Test syncing a single stock."""
        daily = pd.DataFrame({
            'ts_code': ['000001.SZ'],
            'trade_date': ['20240102'],
        })
        with patch('src.data.sync.Storage', return_value=mock_storage):
            with patch('src.data.sync.get_daily', return_value=daily):
                sync = DataSync()
                sync.storage = mock_storage

                result = sync.sync_single_stock(
                    ts_code='000001.SZ',
                    start_date='20240101',
                    end_date='20240102',
                )

                assert isinstance(result, pd.DataFrame)
                assert len(result) == 1

    def test_sync_single_stock_empty(self, mock_storage):
        """Test syncing a stock with no data."""
        with patch('src.data.sync.Storage', return_value=mock_storage):
            with patch('src.data.sync.get_daily', return_value=pd.DataFrame()):
                sync = DataSync()
                sync.storage = mock_storage

                result = sync.sync_single_stock(
                    ts_code='INVALID.SZ',
                    start_date='20240101',
                    end_date='20240102',
                )

                assert result.empty


class TestSyncAll:
    """Test sync_all function."""

    def test_full_sync_mode(self, mock_storage):
        """Test full sync mode when force_full=True."""
        with patch('src.data.sync.Storage', return_value=mock_storage):
            # get_daily is patched as a safety net so no real API call can
            # happen even though sync_single_stock is mocked below.
            with patch('src.data.sync.get_daily', return_value=pd.DataFrame()):
                sync = DataSync()
                sync.storage = mock_storage
                sync.sync_single_stock = Mock(return_value=pd.DataFrame())

                mock_storage.load.return_value = pd.DataFrame({
                    'ts_code': ['000001.SZ'],
                })

                sync.sync_all(force_full=True)

                # Verify sync_single_stock was called with default start date
                sync.sync_single_stock.assert_called_once()
                call_args = sync.sync_single_stock.call_args
                assert call_args.kwargs['start_date'] == DEFAULT_START_DATE

    def test_incremental_sync_mode(self, mock_storage):
        """Test incremental sync mode when data exists."""
        with patch('src.data.sync.Storage', return_value=mock_storage):
            sync = DataSync()
            sync.storage = mock_storage
            sync.sync_single_stock = Mock(return_value=pd.DataFrame())

            # Mock existing data with last date
            mock_storage.load.side_effect = [
                pd.DataFrame({
                    'ts_code': ['000001.SZ'],
                    'trade_date': ['20240102'],
                }),  # get_all_stock_codes
                pd.DataFrame({
                    'ts_code': ['000001.SZ'],
                    'trade_date': ['20240102'],
                }),  # get_global_last_date
            ]

            sync.sync_all(force_full=False)

            # Verify sync_single_stock was called with next date
            sync.sync_single_stock.assert_called_once()
            call_args = sync.sync_single_stock.call_args
            assert call_args.kwargs['start_date'] == '20240103'

    def test_manual_start_date(self, mock_storage):
        """Test sync with manual start date."""
        with patch('src.data.sync.Storage', return_value=mock_storage):
            sync = DataSync()
            sync.storage = mock_storage
            sync.sync_single_stock = Mock(return_value=pd.DataFrame())

            mock_storage.load.return_value = pd.DataFrame({
                'ts_code': ['000001.SZ'],
            })

            sync.sync_all(force_full=False, start_date='20230601')

            sync.sync_single_stock.assert_called_once()
            call_args = sync.sync_single_stock.call_args
            assert call_args.kwargs['start_date'] == '20230601'

    def test_no_stocks_found(self, mock_storage):
        """Test sync when no stocks are found."""
        with patch('src.data.sync.Storage', return_value=mock_storage):
            sync = DataSync()
            sync.storage = mock_storage

            mock_storage.load.return_value = pd.DataFrame()

            result = sync.sync_all()

            assert result == {}
|
||
|
|
|
||
|
|
|
||
|
|
class TestSyncAllConvenienceFunction:
    """Test sync_all convenience function."""

    def test_sync_all_function(self):
        """Test sync_all builds a DataSync, forwards its args and returns the result."""
        with patch('src.data.sync.DataSync') as MockSync:
            mock_instance = Mock()
            mock_instance.sync_all.return_value = {}
            MockSync.return_value = mock_instance

            result = sync_all(force_full=True)

            MockSync.assert_called_once()
            mock_instance.sync_all.assert_called_once_with(
                force_full=True,
                start_date=None,
                end_date=None,
            )
            # The wrapper must hand back what DataSync.sync_all returned;
            # without this, a wrapper that drops the return value would pass.
            assert result == {}
|
||
|
|
|
||
|
|
|
||
|
|
if __name__ == '__main__':
    # Propagate pytest's exit code so running this file directly reports
    # failures to the shell / CI instead of always exiting 0.
    raise SystemExit(pytest.main([__file__, '-v']))
|