Files
ProStock/tests/test_db_manager.py
liaozhaorun e58b39970c feat: HDF5迁移至DuckDB存储
- 新增DuckDB Storage与ThreadSafeStorage实现
- 新增db_manager模块支持增量同步策略
- DataLoader与Sync模块适配DuckDB
- 补充迁移相关文档与测试
- 修复README文档链接
2026-02-23 00:07:21 +08:00

378 lines
12 KiB
Python

"""Tests for DuckDB database manager and incremental sync."""
import pytest
import pandas as pd
from datetime import datetime, timedelta
from unittest.mock import Mock, patch, MagicMock
from src.data.db_manager import (
TableManager,
IncrementalSync,
SyncManager,
ensure_table,
get_table_info,
sync_table,
)
class TestTableManager:
    """Exercise TableManager's table-creation helpers against a mocked storage."""

    @pytest.fixture
    def mock_storage(self):
        """Storage double whose table never exists and whose connection is a Mock."""
        return Mock(_connection=Mock(), exists=Mock(return_value=False))

    @pytest.fixture
    def sample_data(self):
        """Three-row frame keyed by ts_code/trade_date plus a few value columns."""
        columns = {
            "ts_code": ["000001.SZ", "000001.SZ", "600000.SH"],
            "trade_date": ["20240101", "20240102", "20240101"],
            "open": [10.0, 10.5, 20.0],
            "close": [10.5, 11.0, 20.5],
            "volume": [1000, 2000, 3000],
        }
        return pd.DataFrame(columns)

    def test_create_table_from_dataframe(self, mock_storage, sample_data):
        """A non-empty frame yields a CREATE TABLE statement naming its key columns."""
        created = TableManager(mock_storage).create_table_from_dataframe(
            "daily", sample_data
        )
        assert created is True

        execute = mock_storage._connection.execute
        # At least one statement must have reached the connection.
        assert execute.call_count >= 1

        # The SQL text may be passed positionally or as a ``sql=`` keyword.
        statements = (
            args[0] if args else kwargs.get("sql", "")
            for args, kwargs in execute.call_args_list
        )
        create_sql = next(
            (stmt for stmt in statements if "CREATE TABLE" in str(stmt)), None
        )
        assert create_sql is not None
        assert "ts_code" in str(create_sql)
        assert "trade_date" in str(create_sql)

    def test_create_table_with_index(self, mock_storage, sample_data):
        """Requesting an index issues at least one CREATE INDEX statement."""
        TableManager(mock_storage).create_table_from_dataframe(
            "daily", sample_data, create_index=True
        )
        recorded = mock_storage._connection.execute.call_args_list
        # NOTE(review): the intent is a composite (trade_date, ts_code) index,
        # but only the presence of a CREATE INDEX statement is checked here.
        assert any("CREATE INDEX" in str(entry) for entry in recorded)

    def test_create_table_empty_dataframe(self, mock_storage):
        """An empty frame is refused and no SQL is executed."""
        outcome = TableManager(mock_storage).create_table_from_dataframe(
            "daily", pd.DataFrame()
        )
        assert outcome is False
        mock_storage._connection.execute.assert_not_called()

    def test_ensure_table_exists_creates_table(self, mock_storage, sample_data):
        """When the table is missing, ensure_table_exists runs SQL and succeeds."""
        mock_storage.exists.return_value = False
        outcome = TableManager(mock_storage).ensure_table_exists("daily", sample_data)
        assert outcome is True
        mock_storage._connection.execute.assert_called()

    def test_ensure_table_exists_already_exists(self, mock_storage):
        """When the table already exists, no SQL is run and True is returned."""
        mock_storage.exists.return_value = True
        outcome = TableManager(mock_storage).ensure_table_exists("daily", None)
        assert outcome is True
        mock_storage._connection.execute.assert_not_called()
class TestIncrementalSync:
    """Exercise IncrementalSync strategy selection and data syncing."""

    @pytest.fixture
    def mock_storage(self):
        """Storage double: table absent, connection mocked, no stored stocks."""
        return Mock(
            _connection=Mock(),
            exists=Mock(return_value=False),
            get_distinct_stocks=Mock(return_value=[]),
        )

    @staticmethod
    def _stub_stats(sync, row_count, max_date):
        """Make ``sync.get_table_stats`` report an existing table with these stats."""
        sync.get_table_stats = Mock(
            return_value={
                "exists": True,
                "row_count": row_count,
                "max_date": max_date,
            }
        )

    def test_sync_strategy_new_table(self, mock_storage):
        """A missing table is synced by date over the full requested range."""
        mock_storage.exists.return_value = False
        sync = IncrementalSync(mock_storage)

        strategy, start, end, stocks = sync.get_sync_strategy(
            "daily", "20240101", "20240131"
        )

        assert strategy == "by_date"
        assert (start, end) == ("20240101", "20240131")
        assert stocks is None

    def test_sync_strategy_empty_table(self, mock_storage):
        """An existing but empty table (no max_date) is synced over the full range."""
        mock_storage.exists.return_value = True
        sync = IncrementalSync(mock_storage)
        self._stub_stats(sync, row_count=0, max_date=None)

        strategy, start, end, _stocks = sync.get_sync_strategy(
            "daily", "20240101", "20240131"
        )

        assert strategy == "by_date"
        assert (start, end) == ("20240101", "20240131")

    def test_sync_strategy_up_to_date(self, mock_storage):
        """A table whose max_date equals the requested end needs no sync."""
        mock_storage.exists.return_value = True
        sync = IncrementalSync(mock_storage)
        self._stub_stats(sync, row_count=100, max_date="20240131")

        strategy, start, end, _stocks = sync.get_sync_strategy(
            "daily", "20240101", "20240131"
        )

        assert strategy == "none"
        assert start is None
        assert end is None

    def test_sync_strategy_incremental_by_date(self, mock_storage):
        """A partially filled table resumes by date from the day after max_date."""
        mock_storage.exists.return_value = True
        sync = IncrementalSync(mock_storage)
        # Stored data stops at 2024-01-15.
        self._stub_stats(sync, row_count=100, max_date="20240115")

        strategy, start, end, _stocks = sync.get_sync_strategy(
            "daily", "20240101", "20240131"
        )

        assert strategy == "by_date"
        assert start == "20240116"  # the day after the stored max_date
        assert end == "20240131"

    def test_sync_strategy_by_stock(self, mock_storage):
        """Requesting stocks not yet stored selects a per-stock sync of just those."""
        mock_storage.exists.return_value = True
        mock_storage.get_distinct_stocks.return_value = ["000001.SZ"]
        sync = IncrementalSync(mock_storage)
        self._stub_stats(sync, row_count=100, max_date="20240131")

        # Two codes requested; only 000001.SZ is already present in storage.
        strategy, _start, _end, stocks = sync.get_sync_strategy(
            "daily", "20240101", "20240131", stock_codes=["000001.SZ", "600000.SH"]
        )

        assert strategy == "by_stock"
        assert "600000.SH" in stocks
        assert "000001.SZ" not in stocks

    def test_sync_data_by_date(self, mock_storage):
        """sync_data with the by_date strategy reports the storage save's success."""
        mock_storage.exists.return_value = True
        mock_storage.save = Mock(return_value={"status": "success", "rows": 1})
        sync = IncrementalSync(mock_storage)
        frame = pd.DataFrame(
            {"ts_code": ["000001.SZ"], "trade_date": ["20240101"], "close": [10.0]}
        )

        outcome = sync.sync_data("daily", frame, strategy="by_date")

        assert outcome["status"] == "success"

    def test_sync_data_empty_dataframe(self, mock_storage):
        """An empty frame is skipped rather than written."""
        outcome = IncrementalSync(mock_storage).sync_data("daily", pd.DataFrame())
        assert outcome["status"] == "skipped"
class TestSyncManager:
    """Exercise SyncManager.sync with its strategy/sync collaborators stubbed."""

    @pytest.fixture
    def mock_storage(self):
        """Storage double: table absent, save reports 10 rows, no stored stocks."""
        return Mock(
            _connection=Mock(),
            exists=Mock(return_value=False),
            save=Mock(return_value={"status": "success", "rows": 10}),
            get_distinct_stocks=Mock(return_value=[]),
        )

    def test_sync_no_sync_needed(self, mock_storage):
        """A 'none' strategy short-circuits: the fetcher is never invoked."""
        mock_storage.exists.return_value = True
        manager = SyncManager(mock_storage)
        manager.incremental_sync.get_sync_strategy = Mock(
            return_value=("none", None, None, None)
        )
        fetcher = Mock()

        outcome = manager.sync("daily", fetcher, "20240101", "20240131")

        assert outcome["status"] == "skipped"
        fetcher.assert_not_called()

    def test_sync_fetches_data(self, mock_storage):
        """A 'by_date' strategy calls the fetcher exactly once and reports success."""
        mock_storage.exists.return_value = False
        manager = SyncManager(mock_storage)
        manager.table_manager.ensure_table_exists = Mock(return_value=True)
        manager.incremental_sync.get_sync_strategy = Mock(
            return_value=("by_date", "20240101", "20240131", None)
        )
        manager.incremental_sync.sync_data = Mock(
            return_value={"status": "success", "rows_inserted": 10}
        )
        fetched = pd.DataFrame({"ts_code": ["000001.SZ"], "trade_date": ["20240101"]})
        fetcher = Mock(return_value=fetched)

        outcome = manager.sync("daily", fetcher, "20240101", "20240131")

        fetcher.assert_called_once()
        assert outcome["status"] == "success"

    def test_sync_handles_fetch_error(self, mock_storage):
        """An exception from the fetcher is reported as an error result, not raised."""
        manager = SyncManager(mock_storage)
        manager.incremental_sync.get_sync_strategy = Mock(
            return_value=("by_date", "20240101", "20240131", None)
        )
        failing_fetcher = Mock(side_effect=Exception("API Error"))

        outcome = manager.sync("daily", failing_fetcher, "20240101", "20240131")

        assert outcome["status"] == "error"
        assert "API Error" in outcome["error"]
class TestConvenienceFunctions:
    """Exercise the module-level wrappers with their manager classes patched out.

    NOTE(review): each test patches only the manager/sync class; if a wrapper
    builds a real storage backend before delegating, these tests touch it —
    confirm that is intended.
    """

    @patch("src.data.db_manager.TableManager")
    def test_ensure_table(self, mock_manager_class):
        """ensure_table forwards (table, data) to TableManager.ensure_table_exists."""
        fake_manager = Mock(ensure_table_exists=Mock(return_value=True))
        mock_manager_class.return_value = fake_manager
        frame = pd.DataFrame({"ts_code": ["000001.SZ"], "trade_date": ["20240101"]})

        assert ensure_table("daily", frame) is True
        fake_manager.ensure_table_exists.assert_called_once_with("daily", frame)

    @patch("src.data.db_manager.IncrementalSync")
    def test_get_table_info(self, mock_sync_class):
        """get_table_info surfaces the stats reported by IncrementalSync.get_table_stats."""
        mock_sync_class.return_value = Mock(
            get_table_stats=Mock(return_value={"exists": True, "row_count": 100})
        )

        info = get_table_info("daily")

        assert info["exists"] is True
        assert info["row_count"] == 100

    @patch("src.data.db_manager.SyncManager")
    def test_sync_table(self, mock_manager_class):
        """sync_table delegates exactly once to SyncManager.sync and returns its result."""
        fake_manager = Mock(sync=Mock(return_value={"status": "success", "rows": 10}))
        mock_manager_class.return_value = fake_manager

        outcome = sync_table("daily", Mock(), "20240101", "20240131")

        assert outcome["status"] == "success"
        fake_manager.sync.assert_called_once()
if __name__ == "__main__":
    # Allow running this file directly. pytest.main() returns an exit code
    # instead of exiting. Pass it to SystemExit so shell/CI callers see a
    # non-zero status when tests fail; discarding it would always exit 0.
    raise SystemExit(pytest.main([__file__, "-v"]))