feat(data): 添加每日筹码及胜率数据接口 (cyq_perf)
- 新增 api_cyq_perf 模块,支持筹码分布数据获取和同步 - 在 sync_registry 中注册 cyq_perf 同步器
This commit is contained in:
278
tests/test_cyq_perf.py
Normal file
278
tests/test_cyq_perf.py
Normal file
@@ -0,0 +1,278 @@
|
||||
"""Tests for cyq_perf API wrapper.
|
||||
|
||||
Tests for src.data.api_wrappers.api_cyq_perf module.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import pandas as pd
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
from src.data.api_wrappers.api_cyq_perf import (
|
||||
get_cyq_perf,
|
||||
sync_cyq_perf,
|
||||
preview_cyq_perf_sync,
|
||||
CyqPerfSync,
|
||||
)
|
||||
|
||||
|
||||
class TestCyqPerf:
|
||||
"""Test suite for cyq_perf API wrapper."""
|
||||
|
||||
@patch("src.data.api_wrappers.api_cyq_perf.TushareClient")
|
||||
def test_get_cyq_perf_by_stock(self, mock_client_class):
|
||||
"""Test fetching chip distribution data by stock code."""
|
||||
# Setup mock
|
||||
mock_client = MagicMock()
|
||||
mock_client_class.return_value = mock_client
|
||||
mock_client.query.return_value = pd.DataFrame(
|
||||
{
|
||||
"ts_code": ["000001.SZ"],
|
||||
"trade_date": ["20240101"],
|
||||
"his_low": [8.50],
|
||||
"his_high": [12.00],
|
||||
"cost_5pct": [8.80],
|
||||
"cost_15pct": [9.20],
|
||||
"cost_50pct": [10.00],
|
||||
"cost_85pct": [10.80],
|
||||
"cost_95pct": [11.20],
|
||||
"weight_avg": [10.00],
|
||||
"winner_rate": [5.50],
|
||||
}
|
||||
)
|
||||
|
||||
# Test
|
||||
result = get_cyq_perf(
|
||||
ts_code="000001.SZ",
|
||||
start_date="20240101",
|
||||
end_date="20240131",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert not result.empty
|
||||
assert "ts_code" in result.columns
|
||||
assert "trade_date" in result.columns
|
||||
assert "cost_50pct" in result.columns
|
||||
assert "winner_rate" in result.columns
|
||||
assert result["ts_code"].iloc[0] == "000001.SZ"
|
||||
mock_client.query.assert_called_once()
|
||||
|
||||
# Verify parameters
|
||||
call_args = mock_client.query.call_args
|
||||
assert call_args[0][0] == "cyq_perf"
|
||||
assert call_args[1]["ts_code"] == "000001.SZ"
|
||||
assert call_args[1]["start_date"] == "20240101"
|
||||
assert call_args[1]["end_date"] == "20240131"
|
||||
|
||||
@patch("src.data.api_wrappers.api_cyq_perf.TushareClient")
|
||||
def test_get_cyq_perf_with_shared_client(self, mock_client_class):
|
||||
"""Test fetching data with shared client for rate limiting."""
|
||||
# Setup mock for shared client
|
||||
shared_client = MagicMock()
|
||||
shared_client.query.return_value = pd.DataFrame(
|
||||
{
|
||||
"ts_code": ["600000.SH"],
|
||||
"trade_date": ["20240115"],
|
||||
"his_low": [5.00],
|
||||
"his_high": [8.00],
|
||||
"cost_5pct": [5.50],
|
||||
"cost_15pct": [6.00],
|
||||
"cost_50pct": [6.50],
|
||||
"cost_85pct": [7.00],
|
||||
"cost_95pct": [7.50],
|
||||
"weight_avg": [6.50],
|
||||
"winner_rate": [3.20],
|
||||
}
|
||||
)
|
||||
|
||||
# Test with shared client
|
||||
result = get_cyq_perf(
|
||||
ts_code="600000.SH",
|
||||
start_date="20240101",
|
||||
end_date="20240131",
|
||||
client=shared_client,
|
||||
)
|
||||
|
||||
# Assert shared client was used, not new instance
|
||||
mock_client_class.assert_not_called()
|
||||
shared_client.query.assert_called_once()
|
||||
assert not result.empty
|
||||
assert result["ts_code"].iloc[0] == "600000.SH"
|
||||
|
||||
@patch("src.data.api_wrappers.api_cyq_perf.TushareClient")
|
||||
def test_get_cyq_perf_empty_response(self, mock_client_class):
|
||||
"""Test handling empty response."""
|
||||
mock_client = MagicMock()
|
||||
mock_client_class.return_value = mock_client
|
||||
mock_client.query.return_value = pd.DataFrame()
|
||||
|
||||
result = get_cyq_perf(
|
||||
ts_code="000001.SZ",
|
||||
start_date="20240101",
|
||||
end_date="20240131",
|
||||
)
|
||||
|
||||
assert result.empty
|
||||
|
||||
@patch("src.data.api_wrappers.api_cyq_perf.TushareClient")
|
||||
def test_get_cyq_perf_date_column_rename(self, mock_client_class):
|
||||
"""Test that 'date' column is renamed to 'trade_date'."""
|
||||
mock_client = MagicMock()
|
||||
mock_client_class.return_value = mock_client
|
||||
# Return data with 'date' column instead of 'trade_date'
|
||||
mock_client.query.return_value = pd.DataFrame(
|
||||
{
|
||||
"ts_code": ["000001.SZ"],
|
||||
"date": ["20240101"], # Note: 'date' not 'trade_date'
|
||||
"cost_50pct": [10.00],
|
||||
"winner_rate": [5.50],
|
||||
}
|
||||
)
|
||||
|
||||
result = get_cyq_perf(ts_code="000001.SZ")
|
||||
|
||||
# Assert 'date' was renamed to 'trade_date'
|
||||
assert "trade_date" in result.columns
|
||||
assert "date" not in result.columns
|
||||
assert result["trade_date"].iloc[0] == "20240101"
|
||||
|
||||
|
||||
class TestCyqPerfSync:
|
||||
"""Test suite for CyqPerfSync class."""
|
||||
|
||||
@patch("src.data.api_wrappers.api_cyq_perf.TushareClient")
|
||||
@patch("src.data.api_wrappers.base_sync.Storage")
|
||||
@patch("src.data.api_wrappers.base_sync.ThreadSafeStorage")
|
||||
@patch("src.data.api_wrappers.base_sync.sync_trade_cal_cache")
|
||||
@patch("src.data.api_wrappers.base_sync.sync_all_stocks")
|
||||
def test_cyq_perf_sync_class_structure(
|
||||
self,
|
||||
mock_sync_stocks,
|
||||
mock_sync_cal,
|
||||
mock_storage_class,
|
||||
mock_base_storage_class,
|
||||
mock_client_class,
|
||||
):
|
||||
"""Test CyqPerfSync class structure and attributes."""
|
||||
# Verify class attributes
|
||||
assert CyqPerfSync.table_name == "cyq_perf"
|
||||
assert "ts_code" in CyqPerfSync.TABLE_SCHEMA
|
||||
assert "trade_date" in CyqPerfSync.TABLE_SCHEMA
|
||||
assert "cost_5pct" in CyqPerfSync.TABLE_SCHEMA
|
||||
assert "cost_95pct" in CyqPerfSync.TABLE_SCHEMA
|
||||
assert "winner_rate" in CyqPerfSync.TABLE_SCHEMA
|
||||
assert CyqPerfSync.PRIMARY_KEY == ("ts_code", "trade_date")
|
||||
|
||||
@patch("src.data.api_wrappers.api_cyq_perf.get_cyq_perf")
|
||||
def test_fetch_single_stock(self, mock_get_cyq_perf):
|
||||
"""Test fetch_single_stock method."""
|
||||
mock_get_cyq_perf.return_value = pd.DataFrame(
|
||||
{
|
||||
"ts_code": ["000001.SZ", "000001.SZ"],
|
||||
"trade_date": ["20240101", "20240102"],
|
||||
"cost_50pct": [10.00, 10.10],
|
||||
"winner_rate": [5.50, 5.60],
|
||||
}
|
||||
)
|
||||
|
||||
sync = CyqPerfSync()
|
||||
# Mock the client to avoid real initialization
|
||||
sync.client = MagicMock()
|
||||
|
||||
result = sync.fetch_single_stock(
|
||||
ts_code="000001.SZ",
|
||||
start_date="20240101",
|
||||
end_date="20240102",
|
||||
)
|
||||
|
||||
assert not result.empty
|
||||
assert len(result) == 2
|
||||
mock_get_cyq_perf.assert_called_once()
|
||||
|
||||
|
||||
class TestSyncCyqPerf:
|
||||
"""Test suite for sync_cyq_perf convenience function."""
|
||||
|
||||
@patch("src.data.api_wrappers.api_cyq_perf.CyqPerfSync")
|
||||
def test_sync_cyq_perf_calls_sync_all(self, mock_sync_class):
|
||||
"""Test that sync_cyq_perf calls sync_all on CyqPerfSync."""
|
||||
mock_sync_instance = MagicMock()
|
||||
mock_sync_class.return_value = mock_sync_instance
|
||||
mock_sync_instance.sync_all.return_value = {
|
||||
"000001.SZ": pd.DataFrame({"ts_code": ["000001.SZ"]})
|
||||
}
|
||||
|
||||
result = sync_cyq_perf()
|
||||
|
||||
mock_sync_class.assert_called_once_with(max_workers=None)
|
||||
mock_sync_instance.sync_all.assert_called_once()
|
||||
assert isinstance(result, dict)
|
||||
|
||||
@patch("src.data.api_wrappers.api_cyq_perf.CyqPerfSync")
|
||||
def test_sync_cyq_perf_with_params(self, mock_sync_class):
|
||||
"""Test sync_cyq_perf with parameters."""
|
||||
mock_sync_instance = MagicMock()
|
||||
mock_sync_class.return_value = mock_sync_instance
|
||||
mock_sync_instance.sync_all.return_value = {}
|
||||
|
||||
result = sync_cyq_perf(
|
||||
force_full=True,
|
||||
start_date="20240101",
|
||||
end_date="20240131",
|
||||
max_workers=20,
|
||||
dry_run=True,
|
||||
)
|
||||
|
||||
mock_sync_class.assert_called_once_with(max_workers=20)
|
||||
mock_sync_instance.sync_all.assert_called_once_with(
|
||||
force_full=True,
|
||||
start_date="20240101",
|
||||
end_date="20240131",
|
||||
dry_run=True,
|
||||
)
|
||||
|
||||
|
||||
class TestPreviewCyqPerfSync:
|
||||
"""Test suite for preview_cyq_perf_sync convenience function."""
|
||||
|
||||
@patch("src.data.api_wrappers.api_cyq_perf.CyqPerfSync")
|
||||
def test_preview_cyq_perf_sync(self, mock_sync_class):
|
||||
"""Test preview_cyq_perf_sync function."""
|
||||
mock_sync_instance = MagicMock()
|
||||
mock_sync_class.return_value = mock_sync_instance
|
||||
mock_sync_instance.preview_sync.return_value = {
|
||||
"sync_needed": True,
|
||||
"stock_count": 5000,
|
||||
"start_date": "20240101",
|
||||
"end_date": "20240131",
|
||||
"estimated_records": 100000,
|
||||
"sample_data": pd.DataFrame(),
|
||||
"mode": "incremental",
|
||||
}
|
||||
|
||||
result = preview_cyq_perf_sync()
|
||||
|
||||
mock_sync_class.assert_called_once_with()
|
||||
mock_sync_instance.preview_sync.assert_called_once()
|
||||
assert result["sync_needed"] is True
|
||||
assert result["stock_count"] == 5000
|
||||
|
||||
@patch("src.data.api_wrappers.api_cyq_perf.CyqPerfSync")
|
||||
def test_preview_cyq_perf_sync_with_params(self, mock_sync_class):
|
||||
"""Test preview with custom parameters."""
|
||||
mock_sync_instance = MagicMock()
|
||||
mock_sync_class.return_value = mock_sync_instance
|
||||
mock_sync_instance.preview_sync.return_value = {}
|
||||
|
||||
preview_cyq_perf_sync(
|
||||
force_full=True,
|
||||
start_date="20240101",
|
||||
end_date="20240131",
|
||||
sample_size=5,
|
||||
)
|
||||
|
||||
mock_sync_instance.preview_sync.assert_called_once_with(
|
||||
force_full=True,
|
||||
start_date="20240101",
|
||||
end_date="20240131",
|
||||
sample_size=5,
|
||||
)
|
||||
Reference in New Issue
Block a user