"""Tests for the pro_bar (universal market) API wrapper.

Covers the pro_bar interface implementation:

- Backward-adjusted (hfq) data fetching
- All output fields, including tor, vr and adj_factor (default behavior)
- Support for multiple asset types
- ProBarSync batch synchronization
"""

import os
from pathlib import Path
from unittest.mock import MagicMock, patch

import pandas as pd
import pytest

from src.data.api_wrappers.api_pro_bar import (
    get_pro_bar,
    ProBarSync,
    sync_pro_bar,
    preview_pro_bar_sync,
)

# Expected output fields according to api.md
EXPECTED_BASE_FIELDS = [
    "ts_code",  # stock code
    "trade_date",  # trade date
    "open",  # open price
    "high",  # high price
    "low",  # low price
    "close",  # close price
    "pre_close",  # previous close price
    "change",  # price change
    "pct_chg",  # percent change
    "vol",  # volume
    "amount",  # turnover amount
]

EXPECTED_FACTOR_FIELDS = [
    "turnover_rate",  # turnover rate (tor)
    "volume_ratio",  # volume ratio (vr)
]

# Patch target for the Tushare client that get_pro_bar instantiates.
_CLIENT_TARGET = "src.data.api_wrappers.api_pro_bar.TushareClient"


def _install_client(mock_client_class, columns=None):
    """Wire a fresh mock client into the patched ``TushareClient`` class.

    Args:
        mock_client_class: The MagicMock that replaced ``TushareClient``.
        columns: Column-name -> values mapping for the DataFrame returned by
            ``client.query``. ``None`` makes ``query`` return an empty frame.

    Returns:
        The mock client instance, so tests can inspect its ``query`` calls.
    """
    mock_client = MagicMock()
    mock_client_class.return_value = mock_client
    mock_client.query.return_value = (
        pd.DataFrame(columns) if columns is not None else pd.DataFrame()
    )
    return mock_client


class TestGetProBar:
    """Test cases for get_pro_bar function."""

    @patch(_CLIENT_TARGET)
    def test_fetch_basic(self, mock_client_class):
        """Test basic pro_bar data fetch."""
        mock_client = _install_client(
            mock_client_class,
            {
                "ts_code": ["000001.SZ"],
                "trade_date": ["20240115"],
                "open": [10.5],
                "high": [11.0],
                "low": [10.2],
                "close": [10.8],
                "pre_close": [10.5],
                "change": [0.3],
                "pct_chg": [2.86],
                "vol": [100000.0],
                "amount": [1080000.0],
            },
        )

        result = get_pro_bar("000001.SZ", start_date="20240101", end_date="20240131")

        assert isinstance(result, pd.DataFrame)
        assert not result.empty
        assert result["ts_code"].iloc[0] == "000001.SZ"
        mock_client.query.assert_called_once()

        # Verify the pro_bar endpoint is the one queried.
        call = mock_client.query.call_args
        assert call.args[0] == "pro_bar"
        assert call.kwargs["ts_code"] == "000001.SZ"
        # Default should use hfq (backward-adjusted).
        assert call.kwargs["adj"] == "hfq"

    @patch(_CLIENT_TARGET)
    def test_default_backward_adjusted(self, mock_client_class):
        """Test that default adjustment is backward (hfq)."""
        mock_client = _install_client(
            mock_client_class,
            {
                "ts_code": ["000001.SZ"],
                "trade_date": ["20240115"],
                "close": [100.5],
            },
        )

        get_pro_bar("000001.SZ")

        kwargs = mock_client.query.call_args.kwargs
        assert kwargs["adj"] == "hfq"
        # The wrapper passes the flag as the string "True", not a bool.
        assert kwargs["adjfactor"] == "True"

    @patch(_CLIENT_TARGET)
    def test_default_factors_all_fields(self, mock_client_class):
        """Test that default factors includes tor and vr."""
        mock_client = _install_client(
            mock_client_class,
            {
                "ts_code": ["000001.SZ"],
                "trade_date": ["20240115"],
                "close": [10.8],
                "turnover_rate": [2.5],
                "volume_ratio": [1.2],
                "adj_factor": [1.05],
            },
        )

        result = get_pro_bar("000001.SZ")

        # Default should include both tor and vr.
        assert mock_client.query.call_args.kwargs["factors"] == "tor,vr"
        assert "turnover_rate" in result.columns
        assert "volume_ratio" in result.columns
        assert "adj_factor" in result.columns

    @patch(_CLIENT_TARGET)
    def test_fetch_with_custom_factors(self, mock_client_class):
        """Test fetch with custom factors."""
        mock_client = _install_client(
            mock_client_class,
            {
                "ts_code": ["000001.SZ"],
                "trade_date": ["20240115"],
                "close": [10.8],
                "turnover_rate": [2.5],
            },
        )

        # Only request tor.
        get_pro_bar(
            "000001.SZ",
            start_date="20240101",
            end_date="20240131",
            factors=["tor"],
        )

        assert mock_client.query.call_args.kwargs["factors"] == "tor"

    @patch(_CLIENT_TARGET)
    def test_fetch_with_no_factors(self, mock_client_class):
        """Test fetch with no factors (empty list)."""
        mock_client = _install_client(
            mock_client_class,
            {
                "ts_code": ["000001.SZ"],
                "trade_date": ["20240115"],
                "close": [10.8],
            },
        )

        # Explicitly set factors to an empty list.
        get_pro_bar(
            "000001.SZ",
            start_date="20240101",
            end_date="20240131",
            factors=[],
        )

        # An empty factor list must omit the parameter entirely.
        assert "factors" not in mock_client.query.call_args.kwargs

    @patch(_CLIENT_TARGET)
    def test_fetch_with_ma(self, mock_client_class):
        """Test fetch with moving averages."""
        mock_client = _install_client(
            mock_client_class,
            {
                "ts_code": ["000001.SZ"],
                "trade_date": ["20240115"],
                "close": [10.8],
                "ma_5": [10.5],
                "ma_10": [10.3],
                "ma_v_5": [95000.0],
            },
        )

        result = get_pro_bar(
            "000001.SZ",
            start_date="20240101",
            end_date="20240131",
            ma=[5, 10],
        )

        assert mock_client.query.call_args.kwargs["ma"] == "5,10"
        assert "ma_5" in result.columns
        assert "ma_10" in result.columns
        assert "ma_v_5" in result.columns

    @patch(_CLIENT_TARGET)
    def test_fetch_index_data(self, mock_client_class):
        """Test fetching index data."""
        mock_client = _install_client(
            mock_client_class,
            {
                "ts_code": ["000001.SH"],
                "trade_date": ["20240115"],
                "close": [2900.5],
            },
        )

        get_pro_bar(
            "000001.SH",
            asset="I",
            start_date="20240101",
            end_date="20240131",
        )

        kwargs = mock_client.query.call_args.kwargs
        assert kwargs["asset"] == "I"
        assert kwargs["ts_code"] == "000001.SH"

    @patch(_CLIENT_TARGET)
    def test_forward_adjustment(self, mock_client_class):
        """Test forward adjustment (qfq)."""
        mock_client = _install_client(
            mock_client_class,
            {
                "ts_code": ["000001.SZ"],
                "trade_date": ["20240115"],
                "close": [10.8],
            },
        )

        get_pro_bar("000001.SZ", adj="qfq")

        assert mock_client.query.call_args.kwargs["adj"] == "qfq"

    @patch(_CLIENT_TARGET)
    def test_no_adjustment(self, mock_client_class):
        """Test no adjustment."""
        mock_client = _install_client(
            mock_client_class,
            {
                "ts_code": ["000001.SZ"],
                "trade_date": ["20240115"],
                "close": [10.8],
            },
        )

        get_pro_bar("000001.SZ", adj=None)

        # adj=None must omit the parameter rather than send None.
        assert "adj" not in mock_client.query.call_args.kwargs

    @patch(_CLIENT_TARGET)
    def test_empty_response(self, mock_client_class):
        """Test handling empty response."""
        _install_client(mock_client_class)

        result = get_pro_bar("INVALID.SZ")

        assert isinstance(result, pd.DataFrame)
        assert result.empty

    @patch(_CLIENT_TARGET)
    def test_date_column_rename(self, mock_client_class):
        """Test that 'date' column is renamed to 'trade_date'."""
        _install_client(
            mock_client_class,
            {
                "ts_code": ["000001.SZ"],
                "date": ["20240115"],  # API returns 'date' instead of 'trade_date'
                "close": [10.8],
            },
        )

        result = get_pro_bar("000001.SZ")

        assert "trade_date" in result.columns
        assert "date" not in result.columns
        assert result["trade_date"].iloc[0] == "20240115"


class TestProBarSync:
    """Test cases for ProBarSync class."""

    @patch("src.data.api_wrappers.api_pro_bar.sync_all_stocks")
    @patch("src.data.api_wrappers.api_pro_bar.pd.read_csv")
    @patch("src.data.api_wrappers.api_pro_bar._get_csv_path")
    def test_get_all_stock_codes(self, mock_get_path, mock_read_csv, mock_sync_stocks):
        """Test getting all stock codes."""
        # Make the CSV path report that the file exists. sync_all_stocks is
        # patched as well so it cannot trigger a real sync during the test.
        mock_path = MagicMock(spec=Path)
        mock_path.exists.return_value = True
        mock_get_path.return_value = mock_path

        mock_read_csv.return_value = pd.DataFrame(
            {
                "ts_code": ["000001.SZ", "600000.SH"],
                "list_status": ["L", "L"],
            }
        )

        sync = ProBarSync()
        codes = sync.get_all_stock_codes()

        assert len(codes) == 2
        assert "000001.SZ" in codes
        assert "600000.SH" in codes

    @patch("src.data.api_wrappers.api_pro_bar.Storage")
    def test_check_sync_needed_force_full(self, mock_storage_class):
        """Test check_sync_needed with force_full=True."""
        mock_storage = MagicMock()
        mock_storage_class.return_value = mock_storage
        mock_storage.exists.return_value = False

        sync = ProBarSync()
        needed, start, end, local_last = sync.check_sync_needed(force_full=True)

        assert needed is True
        assert start == "20180101"  # DEFAULT_START_DATE
        assert local_last is None

    # TODO(review): the incremental path (force_full=False with existing local
    # data) is not covered by any test yet.


class TestSyncProBar:
    """Test cases for sync_pro_bar function."""

    @patch("src.data.api_wrappers.api_pro_bar.ProBarSync")
    def test_sync_pro_bar(self, mock_sync_class):
        """Test sync_pro_bar function."""
        mock_sync = MagicMock()
        mock_sync_class.return_value = mock_sync
        mock_sync.sync_all.return_value = {"000001.SZ": pd.DataFrame({"close": [10.5]})}

        result = sync_pro_bar(force_full=True, max_workers=5)

        mock_sync_class.assert_called_once_with(max_workers=5)
        mock_sync.sync_all.assert_called_once()
        assert "000001.SZ" in result

    @patch("src.data.api_wrappers.api_pro_bar.ProBarSync")
    def test_preview_pro_bar_sync(self, mock_sync_class):
        """Test preview_pro_bar_sync function."""
        mock_sync = MagicMock()
        mock_sync_class.return_value = mock_sync
        mock_sync.preview_sync.return_value = {
            "sync_needed": True,
            "stock_count": 5000,
            "mode": "full",
        }

        result = preview_pro_bar_sync(force_full=True)

        mock_sync_class.assert_called_once_with()
        mock_sync.preview_sync.assert_called_once()
        assert result["sync_needed"] is True
        assert result["stock_count"] == 5000


class TestProBarIntegration:
    """Integration tests with real Tushare API."""

    def test_real_api_call(self):
        """Test with real API (requires valid token)."""
        token = os.environ.get("TUSHARE_TOKEN")
        if not token:
            pytest.skip("TUSHARE_TOKEN not configured")

        result = get_pro_bar(
            "000001.SZ",
            start_date="20240101",
            end_date="20240131",
        )

        # Verify structure.
        assert isinstance(result, pd.DataFrame)

        if not result.empty:
            # Check base fields.
            for field in EXPECTED_BASE_FIELDS:
                assert field in result.columns, f"Missing base field: {field}"

            # Check factor fields (should be present by default).
            for field in EXPECTED_FACTOR_FIELDS:
                assert field in result.columns, f"Missing factor field: {field}"

            # Check adj_factor is present (default behavior).
            assert "adj_factor" in result.columns


if __name__ == "__main__":
    pytest.main([__file__, "-v"])