"""Test suite for stock_st API wrapper."""

import pytest
import pandas as pd
from unittest.mock import patch, MagicMock

from src.data.api_wrappers.api_stock_st import (
    get_stock_st,
    sync_stock_st,
    StockSTSync,
)

# Patch target used by every test so no real Tushare client is constructed.
CLIENT_PATH = "src.data.api_wrappers.api_stock_st.TushareClient"

# Columns the mocked stock_st responses carry and the tests expect to see.
ST_COLUMNS = ("ts_code", "name", "trade_date", "type", "type_name")


def _st_frame(rows):
    """Build a stock_st-shaped DataFrame from ``(ts_code, name, trade_date)`` rows.

    Every row is tagged with type ``"ST"`` and type_name ``"风险警示板"``,
    which is what all the fixtures in this module use.
    """
    return pd.DataFrame(
        [(ts_code, name, trade_date, "ST", "风险警示板") for ts_code, name, trade_date in rows],
        columns=list(ST_COLUMNS),
    )


def _mock_client(mock_client_class, frame):
    """Make the patched client class yield an instance whose ``query`` returns *frame*.

    Returns the mock instance so tests can assert on its ``query`` calls.
    """
    mock_client = MagicMock()
    mock_client_class.return_value = mock_client
    mock_client.query.return_value = frame
    return mock_client


class TestStockST:
    """Test suite for stock_st API wrapper."""

    @patch(CLIENT_PATH)
    def test_get_by_date(self, mock_client_class):
        """Test fetching ST stock list by date."""
        mock_client = _mock_client(
            mock_client_class,
            _st_frame(
                [
                    ("300313.SZ", "*ST天山", "20240101"),
                    ("605081.SH", "*ST太和", "20240101"),
                    ("300391.SZ", "*ST长药", "20240101"),
                ]
            ),
        )

        result = get_stock_st(trade_date="20240101")

        assert not result.empty
        assert len(result) == 3
        missing = [column for column in ST_COLUMNS if column not in result.columns]
        assert not missing
        mock_client.query.assert_called_once()

    @patch(CLIENT_PATH)
    def test_get_by_stock(self, mock_client_class):
        """Test fetching ST history by stock code."""
        mock_client = _mock_client(
            mock_client_class,
            _st_frame(
                [
                    ("300313.SZ", "*ST天山", "20240101"),
                    ("300313.SZ", "*ST天山", "20240102"),
                ]
            ),
        )

        result = get_stock_st(
            ts_code="300313.SZ", start_date="20240101", end_date="20240102"
        )

        assert not result.empty
        assert len(result) == 2
        mock_client.query.assert_called_once()

    @patch(CLIENT_PATH)
    def test_empty_response(self, mock_client_class):
        """Test handling empty response."""
        # A column-less frame, exactly what an empty API reply looks like here.
        _mock_client(mock_client_class, pd.DataFrame())

        result = get_stock_st(trade_date="20240101")

        assert result.empty

    @patch(CLIENT_PATH)
    def test_get_by_date_range(self, mock_client_class):
        """Test fetching ST stock list by date range."""
        mock_client = _mock_client(
            mock_client_class,
            _st_frame([("300313.SZ", "*ST天山", "20240101")]),
        )

        result = get_stock_st(start_date="20240101", end_date="20240131")

        assert not result.empty
        mock_client.query.assert_called_once()


class TestStockSTSync:
    """Test suite for StockSTSync class."""

    @patch(CLIENT_PATH)
    def test_sync_class_attributes(self, mock_client_class):
        """Test that sync class has correct attributes."""
        # The client is patched (though unused) so constructing StockSTSync
        # cannot build a real TushareClient if its __init__ happens to do so.
        sync = StockSTSync()

        assert sync.table_name == "stock_st"
        assert sync.default_start_date == "20160101"
        missing = [column for column in ST_COLUMNS if column not in sync.TABLE_SCHEMA]
        assert not missing
        assert sync.PRIMARY_KEY == ("trade_date", "ts_code")

    @patch(CLIENT_PATH)
    def test_fetch_single_date(self, mock_client_class):
        """Test fetching single date data."""
        _mock_client(
            mock_client_class,
            _st_frame([("300313.SZ", "*ST天山", "20240101")]),
        )

        sync = StockSTSync()
        result = sync.fetch_single_date("20240101")

        assert not result.empty
        assert len(result) == 1


if __name__ == "__main__":
    pytest.main([__file__, "-v"])