"""Tests for cyq_perf API wrapper.

Tests for src.data.api_wrappers.api_cyq_perf module.
"""

from unittest.mock import MagicMock, patch

import pandas as pd
import pytest

from src.data.api_wrappers.api_cyq_perf import (
    get_cyq_perf,
    sync_cyq_perf,
    preview_cyq_perf_sync,
    CyqPerfSync,
)


class TestCyqPerf:
    """Test suite for the get_cyq_perf API wrapper."""

    @patch("src.data.api_wrappers.api_cyq_perf.TushareClient")
    def test_get_cyq_perf_by_stock(self, mock_client_class):
        """Fetch chip distribution data by stock code and check the query call."""
        # TushareClient() yields a mock whose query returns a single row.
        mock_client = MagicMock()
        mock_client_class.return_value = mock_client
        mock_client.query.return_value = pd.DataFrame(
            {
                "ts_code": ["000001.SZ"],
                "trade_date": ["20240101"],
                "his_low": [8.50],
                "his_high": [12.00],
                "cost_5pct": [8.80],
                "cost_15pct": [9.20],
                "cost_50pct": [10.00],
                "cost_85pct": [10.80],
                "cost_95pct": [11.20],
                "weight_avg": [10.00],
                "winner_rate": [5.50],
            }
        )

        result = get_cyq_perf(
            ts_code="000001.SZ",
            start_date="20240101",
            end_date="20240131",
        )

        assert not result.empty
        for column in ("ts_code", "trade_date", "cost_50pct", "winner_rate"):
            assert column in result.columns
        assert result["ts_code"].iloc[0] == "000001.SZ"
        mock_client.query.assert_called_once()

        # The API name is passed positionally; the filters as keyword arguments.
        args, kwargs = mock_client.query.call_args
        assert args[0] == "cyq_perf"
        assert kwargs["ts_code"] == "000001.SZ"
        assert kwargs["start_date"] == "20240101"
        assert kwargs["end_date"] == "20240131"

    @patch("src.data.api_wrappers.api_cyq_perf.TushareClient")
    def test_get_cyq_perf_with_shared_client(self, mock_client_class):
        """A caller-supplied client is used instead of constructing a new one."""
        shared_client = MagicMock()
        shared_client.query.return_value = pd.DataFrame(
            {
                "ts_code": ["600000.SH"],
                "trade_date": ["20240115"],
                "his_low": [5.00],
                "his_high": [8.00],
                "cost_5pct": [5.50],
                "cost_15pct": [6.00],
                "cost_50pct": [6.50],
                "cost_85pct": [7.00],
                "cost_95pct": [7.50],
                "weight_avg": [6.50],
                "winner_rate": [3.20],
            }
        )

        result = get_cyq_perf(
            ts_code="600000.SH",
            start_date="20240101",
            end_date="20240131",
            client=shared_client,
        )

        # The shared client must be used; no new TushareClient instance.
        mock_client_class.assert_not_called()
        shared_client.query.assert_called_once()
        assert not result.empty
        assert result["ts_code"].iloc[0] == "600000.SH"

    @patch("src.data.api_wrappers.api_cyq_perf.TushareClient")
    def test_get_cyq_perf_empty_response(self, mock_client_class):
        """An empty API response yields an empty DataFrame."""
        mock_client = MagicMock()
        mock_client_class.return_value = mock_client
        mock_client.query.return_value = pd.DataFrame()

        result = get_cyq_perf(
            ts_code="000001.SZ",
            start_date="20240101",
            end_date="20240131",
        )

        assert result.empty

    @patch("src.data.api_wrappers.api_cyq_perf.TushareClient")
    def test_get_cyq_perf_date_column_rename(self, mock_client_class):
        """A 'date' column in the response is renamed to 'trade_date'."""
        mock_client = MagicMock()
        mock_client_class.return_value = mock_client
        # Response uses 'date' instead of 'trade_date'.
        mock_client.query.return_value = pd.DataFrame(
            {
                "ts_code": ["000001.SZ"],
                "date": ["20240101"],
                "cost_50pct": [10.00],
                "winner_rate": [5.50],
            }
        )

        result = get_cyq_perf(ts_code="000001.SZ")

        assert "trade_date" in result.columns
        assert "date" not in result.columns
        assert result["trade_date"].iloc[0] == "20240101"


class TestCyqPerfSync:
    """Test suite for the CyqPerfSync class."""

    def test_cyq_perf_sync_class_structure(self):
        """Check CyqPerfSync's table name, schema columns and primary key.

        Only class attributes are read, so no collaborators need patching.
        """
        assert CyqPerfSync.table_name == "cyq_perf"
        for column in (
            "ts_code",
            "trade_date",
            "cost_5pct",
            "cost_95pct",
            "winner_rate",
        ):
            assert column in CyqPerfSync.TABLE_SCHEMA
        assert CyqPerfSync.PRIMARY_KEY == ("ts_code", "trade_date")

    # Decorators apply bottom-up, so the bottom-most patch is the first mock arg.
    @patch("src.data.api_wrappers.api_cyq_perf.get_cyq_perf")
    @patch("src.data.api_wrappers.api_cyq_perf.TushareClient")
    @patch("src.data.api_wrappers.base_sync.Storage")
    @patch("src.data.api_wrappers.base_sync.ThreadSafeStorage")
    @patch("src.data.api_wrappers.base_sync.sync_trade_cal_cache")
    @patch("src.data.api_wrappers.base_sync.sync_all_stocks")
    def test_fetch_single_stock(
        self,
        mock_sync_stocks,
        mock_sync_cal,
        mock_thread_safe_storage_class,
        mock_storage_class,
        mock_client_class,
        mock_get_cyq_perf,
    ):
        """fetch_single_stock delegates to get_cyq_perf and returns its rows.

        Storage, client and cache-sync collaborators are patched *before*
        constructing CyqPerfSync, in case its constructor touches them
        (TODO confirm against base_sync), so the test does no real I/O.
        """
        mock_get_cyq_perf.return_value = pd.DataFrame(
            {
                "ts_code": ["000001.SZ", "000001.SZ"],
                "trade_date": ["20240101", "20240102"],
                "cost_50pct": [10.00, 10.10],
                "winner_rate": [5.50, 5.60],
            }
        )

        sync = CyqPerfSync()
        # Ensure any client the sync object hands to get_cyq_perf is a mock.
        sync.client = MagicMock()

        result = sync.fetch_single_stock(
            ts_code="000001.SZ",
            start_date="20240101",
            end_date="20240102",
        )

        assert not result.empty
        assert len(result) == 2
        mock_get_cyq_perf.assert_called_once()


class TestSyncCyqPerf:
    """Test suite for the sync_cyq_perf convenience function."""

    @patch("src.data.api_wrappers.api_cyq_perf.CyqPerfSync")
    def test_sync_cyq_perf_calls_sync_all(self, mock_sync_class):
        """sync_cyq_perf builds a CyqPerfSync and calls sync_all on it."""
        mock_sync_instance = MagicMock()
        mock_sync_class.return_value = mock_sync_instance
        mock_sync_instance.sync_all.return_value = {
            "000001.SZ": pd.DataFrame({"ts_code": ["000001.SZ"]})
        }

        result = sync_cyq_perf()

        mock_sync_class.assert_called_once_with(max_workers=None)
        mock_sync_instance.sync_all.assert_called_once()
        assert isinstance(result, dict)

    @patch("src.data.api_wrappers.api_cyq_perf.CyqPerfSync")
    def test_sync_cyq_perf_with_params(self, mock_sync_class):
        """Keyword arguments are split between the constructor and sync_all."""
        mock_sync_instance = MagicMock()
        mock_sync_class.return_value = mock_sync_instance
        mock_sync_instance.sync_all.return_value = {}

        result = sync_cyq_perf(
            force_full=True,
            start_date="20240101",
            end_date="20240131",
            max_workers=20,
            dry_run=True,
        )

        # max_workers goes to the constructor; the rest are forwarded to sync_all.
        mock_sync_class.assert_called_once_with(max_workers=20)
        mock_sync_instance.sync_all.assert_called_once_with(
            force_full=True,
            start_date="20240101",
            end_date="20240131",
            dry_run=True,
        )
        assert isinstance(result, dict)


class TestPreviewCyqPerfSync:
    """Test suite for the preview_cyq_perf_sync convenience function."""

    @patch("src.data.api_wrappers.api_cyq_perf.CyqPerfSync")
    def test_preview_cyq_perf_sync(self, mock_sync_class):
        """preview_cyq_perf_sync returns the preview_sync result unchanged."""
        mock_sync_instance = MagicMock()
        mock_sync_class.return_value = mock_sync_instance
        mock_sync_instance.preview_sync.return_value = {
            "sync_needed": True,
            "stock_count": 5000,
            "start_date": "20240101",
            "end_date": "20240131",
            "estimated_records": 100000,
            "sample_data": pd.DataFrame(),
            "mode": "incremental",
        }

        result = preview_cyq_perf_sync()

        mock_sync_class.assert_called_once_with()
        mock_sync_instance.preview_sync.assert_called_once()
        assert result["sync_needed"] is True
        assert result["stock_count"] == 5000

    @patch("src.data.api_wrappers.api_cyq_perf.CyqPerfSync")
    def test_preview_cyq_perf_sync_with_params(self, mock_sync_class):
        """Keyword arguments are forwarded verbatim to preview_sync."""
        mock_sync_instance = MagicMock()
        mock_sync_class.return_value = mock_sync_instance
        mock_sync_instance.preview_sync.return_value = {}

        preview_cyq_perf_sync(
            force_full=True,
            start_date="20240101",
            end_date="20240131",
            sample_size=5,
        )

        mock_sync_instance.preview_sync.assert_called_once_with(
            force_full=True,
            start_date="20240101",
            end_date="20240131",
            sample_size=5,
        )