- 新增 factors/dsl.py: 纯Python DSL表达式层,通过运算符重载实现因子组合
- 新增 factors/api.py: 提供常用因子符号(close/open/high/low)和时序函数(ts_mean/ts_std/cs_rank等)
- 新增 factors/compiler.py: 因子编译器
- 新增 factors/translator.py: DSL表达式翻译器
- 新增 data/api_wrappers/api_pro_bar.py: Tushare Pro Bar API封装,支持后复权行情数据
- 新增 data/data_router.py: 数据路由功能
- 新增相关测试用例
422 lines
14 KiB
Python
422 lines
14 KiB
Python
"""Test for pro_bar (universal market) API.
|
|
|
|
Tests the pro_bar interface implementation:
|
|
- Backward-adjusted (后复权) data fetching
|
|
- All output fields including tor, vr, and adj_factor (default behavior)
|
|
- Multiple asset types support
|
|
- ProBarSync batch synchronization
|
|
"""
|
|
|
|
import pytest
|
|
import pandas as pd
|
|
from unittest.mock import patch, MagicMock
|
|
from src.data.api_wrappers.api_pro_bar import (
|
|
get_pro_bar,
|
|
ProBarSync,
|
|
sync_pro_bar,
|
|
preview_pro_bar_sync,
|
|
)
|
|
|
|
|
|
# Expected output fields according to api.md.
# Base daily-bar columns, in API order:
#   ts_code     security code
#   trade_date  trading date
#   open        opening price
#   high        highest price
#   low         lowest price
#   close       closing price
#   pre_close   previous close
#   change      price change
#   pct_chg     percentage change
#   vol         volume
#   amount      turnover amount
EXPECTED_BASE_FIELDS = (
    "ts_code trade_date open high low close pre_close change pct_chg vol amount"
).split()
|
|
|
|
# Factor columns and the pro_bar ``factors`` code that produces each:
#   turnover_rate  turnover rate (tor)
#   volume_ratio   volume ratio  (vr)
EXPECTED_FACTOR_FIELDS = "turnover_rate volume_ratio".split()
|
|
|
|
|
|
class TestGetProBar:
|
|
"""Test cases for get_pro_bar function."""
|
|
|
|
@patch("src.data.api_wrappers.api_pro_bar.TushareClient")
|
|
def test_fetch_basic(self, mock_client_class):
|
|
"""Test basic pro_bar data fetch."""
|
|
# Setup mock
|
|
mock_client = MagicMock()
|
|
mock_client_class.return_value = mock_client
|
|
mock_client.query.return_value = pd.DataFrame(
|
|
{
|
|
"ts_code": ["000001.SZ"],
|
|
"trade_date": ["20240115"],
|
|
"open": [10.5],
|
|
"high": [11.0],
|
|
"low": [10.2],
|
|
"close": [10.8],
|
|
"pre_close": [10.5],
|
|
"change": [0.3],
|
|
"pct_chg": [2.86],
|
|
"vol": [100000.0],
|
|
"amount": [1080000.0],
|
|
}
|
|
)
|
|
|
|
# Test
|
|
result = get_pro_bar("000001.SZ", start_date="20240101", end_date="20240131")
|
|
|
|
# Assert
|
|
assert isinstance(result, pd.DataFrame)
|
|
assert not result.empty
|
|
assert result["ts_code"].iloc[0] == "000001.SZ"
|
|
mock_client.query.assert_called_once()
|
|
# Verify pro_bar API is called
|
|
call_args = mock_client.query.call_args
|
|
assert call_args[0][0] == "pro_bar"
|
|
assert call_args[1]["ts_code"] == "000001.SZ"
|
|
# Default should use hfq (backward-adjusted)
|
|
assert call_args[1]["adj"] == "hfq"
|
|
|
|
@patch("src.data.api_wrappers.api_pro_bar.TushareClient")
|
|
def test_default_backward_adjusted(self, mock_client_class):
|
|
"""Test that default adjustment is backward (hfq)."""
|
|
mock_client = MagicMock()
|
|
mock_client_class.return_value = mock_client
|
|
mock_client.query.return_value = pd.DataFrame(
|
|
{
|
|
"ts_code": ["000001.SZ"],
|
|
"trade_date": ["20240115"],
|
|
"close": [100.5],
|
|
}
|
|
)
|
|
|
|
result = get_pro_bar("000001.SZ")
|
|
|
|
call_args = mock_client.query.call_args
|
|
assert call_args[1]["adj"] == "hfq"
|
|
assert call_args[1]["adjfactor"] == "True"
|
|
|
|
@patch("src.data.api_wrappers.api_pro_bar.TushareClient")
|
|
def test_default_factors_all_fields(self, mock_client_class):
|
|
"""Test that default factors includes tor and vr."""
|
|
mock_client = MagicMock()
|
|
mock_client_class.return_value = mock_client
|
|
mock_client.query.return_value = pd.DataFrame(
|
|
{
|
|
"ts_code": ["000001.SZ"],
|
|
"trade_date": ["20240115"],
|
|
"close": [10.8],
|
|
"turnover_rate": [2.5],
|
|
"volume_ratio": [1.2],
|
|
"adj_factor": [1.05],
|
|
}
|
|
)
|
|
|
|
result = get_pro_bar("000001.SZ")
|
|
|
|
call_args = mock_client.query.call_args
|
|
# Default should include both tor and vr
|
|
assert call_args[1]["factors"] == "tor,vr"
|
|
assert "turnover_rate" in result.columns
|
|
assert "volume_ratio" in result.columns
|
|
assert "adj_factor" in result.columns
|
|
|
|
@patch("src.data.api_wrappers.api_pro_bar.TushareClient")
|
|
def test_fetch_with_custom_factors(self, mock_client_class):
|
|
"""Test fetch with custom factors."""
|
|
mock_client = MagicMock()
|
|
mock_client_class.return_value = mock_client
|
|
mock_client.query.return_value = pd.DataFrame(
|
|
{
|
|
"ts_code": ["000001.SZ"],
|
|
"trade_date": ["20240115"],
|
|
"close": [10.8],
|
|
"turnover_rate": [2.5],
|
|
}
|
|
)
|
|
|
|
# Only request tor
|
|
result = get_pro_bar(
|
|
"000001.SZ",
|
|
start_date="20240101",
|
|
end_date="20240131",
|
|
factors=["tor"],
|
|
)
|
|
|
|
call_args = mock_client.query.call_args
|
|
assert call_args[1]["factors"] == "tor"
|
|
|
|
@patch("src.data.api_wrappers.api_pro_bar.TushareClient")
|
|
def test_fetch_with_no_factors(self, mock_client_class):
|
|
"""Test fetch with no factors (empty list)."""
|
|
mock_client = MagicMock()
|
|
mock_client_class.return_value = mock_client
|
|
mock_client.query.return_value = pd.DataFrame(
|
|
{
|
|
"ts_code": ["000001.SZ"],
|
|
"trade_date": ["20240115"],
|
|
"close": [10.8],
|
|
}
|
|
)
|
|
|
|
# Explicitly set factors to empty list
|
|
result = get_pro_bar(
|
|
"000001.SZ",
|
|
start_date="20240101",
|
|
end_date="20240131",
|
|
factors=[],
|
|
)
|
|
|
|
call_args = mock_client.query.call_args
|
|
# Should not include factors parameter
|
|
assert "factors" not in call_args[1]
|
|
|
|
@patch("src.data.api_wrappers.api_pro_bar.TushareClient")
|
|
def test_fetch_with_ma(self, mock_client_class):
|
|
"""Test fetch with moving averages."""
|
|
mock_client = MagicMock()
|
|
mock_client_class.return_value = mock_client
|
|
mock_client.query.return_value = pd.DataFrame(
|
|
{
|
|
"ts_code": ["000001.SZ"],
|
|
"trade_date": ["20240115"],
|
|
"close": [10.8],
|
|
"ma_5": [10.5],
|
|
"ma_10": [10.3],
|
|
"ma_v_5": [95000.0],
|
|
}
|
|
)
|
|
|
|
result = get_pro_bar(
|
|
"000001.SZ",
|
|
start_date="20240101",
|
|
end_date="20240131",
|
|
ma=[5, 10],
|
|
)
|
|
|
|
call_args = mock_client.query.call_args
|
|
assert call_args[1]["ma"] == "5,10"
|
|
assert "ma_5" in result.columns
|
|
assert "ma_10" in result.columns
|
|
assert "ma_v_5" in result.columns
|
|
|
|
@patch("src.data.api_wrappers.api_pro_bar.TushareClient")
|
|
def test_fetch_index_data(self, mock_client_class):
|
|
"""Test fetching index data."""
|
|
mock_client = MagicMock()
|
|
mock_client_class.return_value = mock_client
|
|
mock_client.query.return_value = pd.DataFrame(
|
|
{
|
|
"ts_code": ["000001.SH"],
|
|
"trade_date": ["20240115"],
|
|
"close": [2900.5],
|
|
}
|
|
)
|
|
|
|
result = get_pro_bar(
|
|
"000001.SH",
|
|
asset="I",
|
|
start_date="20240101",
|
|
end_date="20240131",
|
|
)
|
|
|
|
call_args = mock_client.query.call_args
|
|
assert call_args[1]["asset"] == "I"
|
|
assert call_args[1]["ts_code"] == "000001.SH"
|
|
|
|
@patch("src.data.api_wrappers.api_pro_bar.TushareClient")
|
|
def test_forward_adjustment(self, mock_client_class):
|
|
"""Test forward adjustment (qfq)."""
|
|
mock_client = MagicMock()
|
|
mock_client_class.return_value = mock_client
|
|
mock_client.query.return_value = pd.DataFrame(
|
|
{
|
|
"ts_code": ["000001.SZ"],
|
|
"trade_date": ["20240115"],
|
|
"close": [10.8],
|
|
}
|
|
)
|
|
|
|
result = get_pro_bar("000001.SZ", adj="qfq")
|
|
|
|
call_args = mock_client.query.call_args
|
|
assert call_args[1]["adj"] == "qfq"
|
|
|
|
@patch("src.data.api_wrappers.api_pro_bar.TushareClient")
|
|
def test_no_adjustment(self, mock_client_class):
|
|
"""Test no adjustment."""
|
|
mock_client = MagicMock()
|
|
mock_client_class.return_value = mock_client
|
|
mock_client.query.return_value = pd.DataFrame(
|
|
{
|
|
"ts_code": ["000001.SZ"],
|
|
"trade_date": ["20240115"],
|
|
"close": [10.8],
|
|
}
|
|
)
|
|
|
|
result = get_pro_bar("000001.SZ", adj=None)
|
|
|
|
call_args = mock_client.query.call_args
|
|
assert "adj" not in call_args[1]
|
|
|
|
@patch("src.data.api_wrappers.api_pro_bar.TushareClient")
|
|
def test_empty_response(self, mock_client_class):
|
|
"""Test handling empty response."""
|
|
mock_client = MagicMock()
|
|
mock_client_class.return_value = mock_client
|
|
mock_client.query.return_value = pd.DataFrame()
|
|
|
|
result = get_pro_bar("INVALID.SZ")
|
|
|
|
assert isinstance(result, pd.DataFrame)
|
|
assert result.empty
|
|
|
|
@patch("src.data.api_wrappers.api_pro_bar.TushareClient")
|
|
def test_date_column_rename(self, mock_client_class):
|
|
"""Test that 'date' column is renamed to 'trade_date'."""
|
|
mock_client = MagicMock()
|
|
mock_client_class.return_value = mock_client
|
|
mock_client.query.return_value = pd.DataFrame(
|
|
{
|
|
"ts_code": ["000001.SZ"],
|
|
"date": ["20240115"], # API returns 'date' instead of 'trade_date'
|
|
"close": [10.8],
|
|
}
|
|
)
|
|
|
|
result = get_pro_bar("000001.SZ")
|
|
|
|
assert "trade_date" in result.columns
|
|
assert "date" not in result.columns
|
|
assert result["trade_date"].iloc[0] == "20240115"
|
|
|
|
|
|
class TestProBarSync:
|
|
"""Test cases for ProBarSync class."""
|
|
|
|
@patch("src.data.api_wrappers.api_pro_bar.sync_all_stocks")
|
|
@patch("src.data.api_wrappers.api_pro_bar.pd.read_csv")
|
|
@patch("src.data.api_wrappers.api_pro_bar._get_csv_path")
|
|
def test_get_all_stock_codes(self, mock_get_path, mock_read_csv, mock_sync_stocks):
|
|
"""Test getting all stock codes."""
|
|
from pathlib import Path
|
|
from unittest.mock import MagicMock
|
|
|
|
# Create a mock path that exists
|
|
mock_path = MagicMock(spec=Path)
|
|
mock_path.exists.return_value = True
|
|
mock_get_path.return_value = mock_path
|
|
|
|
mock_read_csv.return_value = pd.DataFrame(
|
|
{
|
|
"ts_code": ["000001.SZ", "600000.SH"],
|
|
"list_status": ["L", "L"],
|
|
}
|
|
)
|
|
|
|
sync = ProBarSync()
|
|
codes = sync.get_all_stock_codes()
|
|
|
|
assert len(codes) == 2
|
|
assert "000001.SZ" in codes
|
|
assert "600000.SH" in codes
|
|
|
|
@patch("src.data.api_wrappers.api_pro_bar.Storage")
|
|
def test_check_sync_needed_force_full(self, mock_storage_class):
|
|
"""Test check_sync_needed with force_full=True."""
|
|
mock_storage = MagicMock()
|
|
mock_storage_class.return_value = mock_storage
|
|
mock_storage.exists.return_value = False
|
|
|
|
sync = ProBarSync()
|
|
needed, start, end, local_last = sync.check_sync_needed(force_full=True)
|
|
|
|
assert needed is True
|
|
assert start == "20180101" # DEFAULT_START_DATE
|
|
assert local_last is None
|
|
@patch("src.data.api_wrappers.api_pro_bar.Storage")
|
|
def test_check_sync_needed_force_full(self, mock_storage_class):
|
|
"""Test check_sync_needed with force_full=True."""
|
|
mock_storage = MagicMock()
|
|
mock_storage_class.return_value = mock_storage
|
|
mock_storage.exists.return_value = False
|
|
|
|
sync = ProBarSync()
|
|
needed, start, end, local_last = sync.check_sync_needed(force_full=True)
|
|
|
|
assert needed is True
|
|
assert start == "20180101" # DEFAULT_START_DATE
|
|
assert local_last is None
|
|
|
|
|
|
class TestSyncProBar:
|
|
"""Test cases for sync_pro_bar function."""
|
|
|
|
@patch("src.data.api_wrappers.api_pro_bar.ProBarSync")
|
|
def test_sync_pro_bar(self, mock_sync_class):
|
|
"""Test sync_pro_bar function."""
|
|
mock_sync = MagicMock()
|
|
mock_sync_class.return_value = mock_sync
|
|
mock_sync.sync_all.return_value = {"000001.SZ": pd.DataFrame({"close": [10.5]})}
|
|
|
|
result = sync_pro_bar(force_full=True, max_workers=5)
|
|
|
|
mock_sync_class.assert_called_once_with(max_workers=5)
|
|
mock_sync.sync_all.assert_called_once()
|
|
assert "000001.SZ" in result
|
|
|
|
@patch("src.data.api_wrappers.api_pro_bar.ProBarSync")
|
|
def test_preview_pro_bar_sync(self, mock_sync_class):
|
|
"""Test preview_pro_bar_sync function."""
|
|
mock_sync = MagicMock()
|
|
mock_sync_class.return_value = mock_sync
|
|
mock_sync.preview_sync.return_value = {
|
|
"sync_needed": True,
|
|
"stock_count": 5000,
|
|
"mode": "full",
|
|
}
|
|
|
|
result = preview_pro_bar_sync(force_full=True)
|
|
|
|
mock_sync_class.assert_called_once_with()
|
|
mock_sync.preview_sync.assert_called_once()
|
|
assert result["sync_needed"] is True
|
|
assert result["stock_count"] == 5000
|
|
|
|
|
|
class TestProBarIntegration:
|
|
"""Integration tests with real Tushare API."""
|
|
|
|
def test_real_api_call(self):
|
|
"""Test with real API (requires valid token)."""
|
|
import os
|
|
|
|
token = os.environ.get("TUSHARE_TOKEN")
|
|
if not token:
|
|
pytest.skip("TUSHARE_TOKEN not configured")
|
|
|
|
result = get_pro_bar(
|
|
"000001.SZ",
|
|
start_date="20240101",
|
|
end_date="20240131",
|
|
)
|
|
|
|
# Verify structure
|
|
assert isinstance(result, pd.DataFrame)
|
|
if not result.empty:
|
|
# Check base fields
|
|
for field in EXPECTED_BASE_FIELDS:
|
|
assert field in result.columns, f"Missing base field: {field}"
|
|
# Check factor fields (should be present by default)
|
|
for field in EXPECTED_FACTOR_FIELDS:
|
|
assert field in result.columns, f"Missing factor field: {field}"
|
|
# Check adj_factor is present (default behavior)
|
|
assert "adj_factor" in result.columns
|
|
|
|
|
|
if __name__ == "__main__":
    # pytest.main returns the session exit code. Propagate it so a direct
    # `python <this file>` run exits non-zero when tests fail.
    raise SystemExit(pytest.main([__file__, "-v"]))
|