- 新增 api_stock_st.py，实现 ST 股票数据获取和按日期遍历同步
- 更新 sync.py，将 ST 股票同步加入第 7 步流程
- 移除 base_sync.py 中未使用的 get_last_n_trading_days 导入
144 lines
4.7 KiB
Python
144 lines
4.7 KiB
Python
"""Test suite for stock_st API wrapper."""
|
|
|
|
import pytest
|
|
import pandas as pd
|
|
from unittest.mock import patch, MagicMock
|
|
|
|
from src.data.api_wrappers.api_stock_st import get_stock_st, sync_stock_st, StockSTSync
|
|
|
|
|
|
class TestStockST:
    """Exercise ``get_stock_st`` against a mocked ``TushareClient``.

    Every test patches the client class, stubs its ``query`` method with a
    canned DataFrame, and checks what ``get_stock_st`` hands back.
    """

    @staticmethod
    def _st_frame(*rows):
        """Return a canned stock_st DataFrame for the mocked ``query`` call.

        Args:
            *rows: ``(ts_code, name, trade_date)`` tuples, one per output row.
                Every row is given the same ``type``/``type_name`` values that
                all fixtures in this class use.
        """
        ts_codes, names, trade_dates = (list(column) for column in zip(*rows))
        n_rows = len(rows)
        return pd.DataFrame(
            {
                "ts_code": ts_codes,
                "name": names,
                "trade_date": trade_dates,
                "type": ["ST"] * n_rows,
                "type_name": ["风险警示板"] * n_rows,
            }
        )

    @patch("src.data.api_wrappers.api_stock_st.TushareClient")
    def test_get_by_date(self, mock_client_class):
        """Querying a single trade_date yields every ST row with all columns."""
        client = mock_client_class.return_value
        client.query.return_value = self._st_frame(
            ("300313.SZ", "*ST天山", "20240101"),
            ("605081.SH", "*ST太和", "20240101"),
            ("300391.SZ", "*ST长药", "20240101"),
        )

        result = get_stock_st(trade_date="20240101")

        assert not result.empty
        assert len(result) == 3
        expected_columns = {"ts_code", "name", "trade_date", "type", "type_name"}
        assert expected_columns.issubset(result.columns)
        client.query.assert_called_once()

    @patch("src.data.api_wrappers.api_stock_st.TushareClient")
    def test_get_by_stock(self, mock_client_class):
        """Querying one ts_code over a date window yields its ST history."""
        client = mock_client_class.return_value
        client.query.return_value = self._st_frame(
            ("300313.SZ", "*ST天山", "20240101"),
            ("300313.SZ", "*ST天山", "20240102"),
        )

        result = get_stock_st(
            ts_code="300313.SZ", start_date="20240101", end_date="20240102"
        )

        assert not result.empty
        assert len(result) == 2
        client.query.assert_called_once()

    @patch("src.data.api_wrappers.api_stock_st.TushareClient")
    def test_empty_response(self, mock_client_class):
        """An empty API payload comes back as an empty DataFrame."""
        mock_client_class.return_value.query.return_value = pd.DataFrame()

        assert get_stock_st(trade_date="20240101").empty

    @patch("src.data.api_wrappers.api_stock_st.TushareClient")
    def test_get_by_date_range(self, mock_client_class):
        """Querying a start/end date range returns the stubbed ST rows."""
        client = mock_client_class.return_value
        client.query.return_value = self._st_frame(
            ("300313.SZ", "*ST天山", "20240101"),
        )

        result = get_stock_st(start_date="20240101", end_date="20240131")

        assert not result.empty
        client.query.assert_called_once()
|
|
|
|
|
|
class TestStockSTSync:
    """Test suite for the ``StockSTSync`` table-sync class."""

    @patch("src.data.api_wrappers.api_stock_st.TushareClient")
    def test_sync_class_attributes(self, mock_client_class):
        """StockSTSync exposes the expected table name, start date, schema and key.

        ``TushareClient`` is patched here just as in ``test_fetch_single_date``.
        Constructing ``StockSTSync`` then never reaches a real client, whether
        or not its constructor creates one. The mock itself is not inspected.
        """
        sync = StockSTSync()

        assert sync.table_name == "stock_st"
        assert sync.default_start_date == "20160101"
        # Every column the API returns must have a schema entry.
        for column in ("ts_code", "trade_date", "name", "type", "type_name"):
            assert column in sync.TABLE_SCHEMA, column
        assert sync.PRIMARY_KEY == ("trade_date", "ts_code")

    @patch("src.data.api_wrappers.api_stock_st.TushareClient")
    def test_fetch_single_date(self, mock_client_class):
        """fetch_single_date returns the rows the (mocked) API yields for a date."""
        # Any TushareClient instantiated under the patch is this mock.
        mock_client = MagicMock()
        mock_client_class.return_value = mock_client
        mock_client.query.return_value = pd.DataFrame(
            {
                "ts_code": ["300313.SZ"],
                "name": ["*ST天山"],
                "trade_date": ["20240101"],
                "type": ["ST"],
                "type_name": ["风险警示板"],
            }
        )

        sync = StockSTSync()
        result = sync.fetch_single_date("20240101")

        assert not result.empty
        assert len(result) == 1
|
|
|
|
|
|
if __name__ == "__main__":
    # pytest.main() returns an exit code instead of exiting. Propagate it so
    # that running this file directly reports test failures to the shell/CI
    # (the original discarded it and always exited with status 0).
    raise SystemExit(pytest.main([__file__, "-v"]))
|