"""Tests for stock limit price API wrapper.""" import pytest import pandas as pd from unittest.mock import patch, MagicMock from src.data.api_wrappers.api_stk_limit import ( get_stk_limit, sync_stk_limit, preview_stk_limit_sync, StkLimitSync, ) class TestStkLimit: """Test suite for stk_limit API wrapper.""" @patch("src.data.api_wrappers.api_stk_limit.TushareClient") def test_get_by_date(self, mock_client_class): """Test fetching data by trade_date.""" # Setup mock mock_client = MagicMock() mock_client_class.return_value = mock_client mock_client.query.return_value = pd.DataFrame( { "ts_code": ["000001.SZ", "000002.SZ"], "trade_date": ["20240625", "20240625"], "pre_close": [10.0, 20.0], "up_limit": [11.0, 22.0], "down_limit": [9.0, 18.0], } ) # Test result = get_stk_limit(trade_date="20240625") # Assert assert not result.empty assert len(result) == 2 assert "ts_code" in result.columns assert "trade_date" in result.columns assert "up_limit" in result.columns assert "down_limit" in result.columns mock_client.query.assert_called_once_with("stk_limit", trade_date="20240625") @patch("src.data.api_wrappers.api_stk_limit.TushareClient") def test_get_by_date_range(self, mock_client_class): """Test fetching data by date range.""" # Setup mock mock_client = MagicMock() mock_client_class.return_value = mock_client mock_client.query.return_value = pd.DataFrame( { "ts_code": ["000001.SZ", "000001.SZ"], "trade_date": ["20240624", "20240625"], "pre_close": [10.0, 10.5], "up_limit": [11.0, 11.55], "down_limit": [9.0, 9.45], } ) # Test result = get_stk_limit(start_date="20240624", end_date="20240625") # Assert assert not result.empty assert len(result) == 2 mock_client.query.assert_called_once_with( "stk_limit", start_date="20240624", end_date="20240625" ) @patch("src.data.api_wrappers.api_stk_limit.TushareClient") def test_get_by_stock_code(self, mock_client_class): """Test fetching data by stock code.""" # Setup mock mock_client = MagicMock() mock_client_class.return_value = mock_client mock_client.query.return_value = pd.DataFrame( { "ts_code": ["000001.SZ"], "trade_date": ["20240625"], "pre_close": [10.0], "up_limit": [11.0], "down_limit": [9.0], } ) # Test result = get_stk_limit(ts_code="000001.SZ", trade_date="20240625") # Assert assert not result.empty assert len(result) == 1 assert result.iloc[0]["ts_code"] == "000001.SZ" mock_client.query.assert_called_once_with( "stk_limit", trade_date="20240625", ts_code="000001.SZ" ) @patch("src.data.api_wrappers.api_stk_limit.TushareClient") def test_empty_response(self, mock_client_class): """Test handling empty response.""" # Setup mock mock_client = MagicMock() mock_client_class.return_value = mock_client mock_client.query.return_value = pd.DataFrame() # Test result = get_stk_limit(trade_date="20240625") # Assert assert result.empty @patch("src.data.api_wrappers.api_stk_limit.TushareClient") def test_shared_client(self, mock_client_class): """Test passing shared client for rate limiting.""" # Setup mock shared_client = MagicMock() shared_client.query.return_value = pd.DataFrame( { "ts_code": ["000001.SZ"], "trade_date": ["20240625"], "pre_close": [10.0], "up_limit": [11.0], "down_limit": [9.0], } ) # Test result = get_stk_limit(trade_date="20240625", client=shared_client) # Assert assert not result.empty shared_client.query.assert_called_once() # Verify new client was not created mock_client_class.assert_not_called() class TestStkLimitSync: """Test suite for StkLimitSync class.""" @patch("src.data.api_wrappers.api_stk_limit.TushareClient") @patch("src.data.api_wrappers.base_sync.Storage") @patch("src.data.api_wrappers.base_sync.sync_trade_cal_cache") def test_fetch_single_date( self, mock_sync_cal, mock_storage_class, mock_client_class ): """Test fetch_single_date method.""" # Setup mock mock_client = MagicMock() mock_client_class.return_value = mock_client mock_client.query.return_value = pd.DataFrame( { "ts_code": ["000001.SZ", "000002.SZ"], "trade_date": ["20240625", "20240625"], "pre_close": [10.0, 20.0], "up_limit": [11.0, 22.0], "down_limit": [9.0, 18.0], } ) mock_storage = MagicMock() mock_storage_class.return_value = mock_storage mock_storage.exists.return_value = True mock_storage.load.return_value = pd.DataFrame() # Test sync = StkLimitSync() result = sync.fetch_single_date("20240625") # Assert assert not result.empty assert len(result) == 2 mock_client.query.assert_called_once_with("stk_limit", trade_date="20240625") def test_table_schema(self): """Test table schema definition.""" sync = StkLimitSync() # Assert table configuration assert sync.table_name == "stk_limit" assert "ts_code" in sync.TABLE_SCHEMA assert "trade_date" in sync.TABLE_SCHEMA assert "pre_close" in sync.TABLE_SCHEMA assert "up_limit" in sync.TABLE_SCHEMA assert "down_limit" in sync.TABLE_SCHEMA assert sync.PRIMARY_KEY == ("ts_code", "trade_date") class TestSyncFunctions: """Test suite for sync convenience functions.""" @patch.object(StkLimitSync, "sync_all") def test_sync_stk_limit(self, mock_sync_all): """Test sync_stk_limit convenience function.""" # Setup mock mock_sync_all.return_value = pd.DataFrame( { "ts_code": ["000001.SZ"], "trade_date": ["20240625"], "up_limit": [11.0], "down_limit": [9.0], } ) # Test result = sync_stk_limit(force_full=True) # Assert assert not result.empty mock_sync_all.assert_called_once_with( force_full=True, start_date=None, end_date=None, dry_run=False, ) @patch.object(StkLimitSync, "preview_sync") def test_preview_stk_limit_sync(self, mock_preview): """Test preview_stk_limit_sync convenience function.""" # Setup mock mock_preview.return_value = { "sync_needed": True, "date_count": 10, "start_date": "20240601", "end_date": "20240610", "estimated_records": 5000, "sample_data": pd.DataFrame(), "mode": "incremental", } # Test result = preview_stk_limit_sync() # Assert assert result["sync_needed"] is True assert result["mode"] == "incremental" mock_preview.assert_called_once_with( force_full=False, start_date=None, end_date=None, sample_size=3, ) if __name__ == "__main__": pytest.main([__file__, "-v"])