Files
ProStock/tests/test_pro_bar.py
liaozhaorun 0698b9d919 feat: 添加DSL因子表达式系统和Pro Bar API封装
- 新增 factors/dsl.py: 纯Python DSL表达式层,通过运算符重载实现因子组合
- 新增 factors/api.py: 提供常用因子符号(close/open/high/low)和时序函数(ts_mean/ts_std/cs_rank等)
- 新增 factors/compiler.py: 因子编译器
- 新增 factors/translator.py: DSL表达式翻译器
- 新增 data/api_wrappers/api_pro_bar.py: Tushare Pro Bar API封装,支持后复权行情数据
- 新增 data/data_router.py: 数据路由功能
- 新增相关测试用例
2026-02-27 22:43:45 +08:00

422 lines
14 KiB
Python

"""Test for pro_bar (universal market) API.
Tests the pro_bar interface implementation:
- Backward-adjusted (后复权) data fetching
- All output fields including tor, vr, and adj_factor (default behavior)
- Multiple asset types support
- ProBarSync batch synchronization
"""
import pytest
import pandas as pd
from unittest.mock import patch, MagicMock
from src.data.api_wrappers.api_pro_bar import (
get_pro_bar,
ProBarSync,
sync_pro_bar,
preview_pro_bar_sync,
)
# Output columns the pro_bar wrapper is expected to return (field list from api.md).
EXPECTED_BASE_FIELDS = [
    "ts_code",     # security code
    "trade_date",  # trading date
    "open",        # opening price
    "high",        # highest price
    "low",         # lowest price
    "close",       # closing price
    "pre_close",   # previous close
    "change",      # absolute price change
    "pct_chg",     # percentage change
    "vol",         # traded volume
    "amount",      # traded amount
]

# Columns added by the default factors: "tor" -> turnover_rate, "vr" -> volume_ratio.
EXPECTED_FACTOR_FIELDS = ["turnover_rate", "volume_ratio"]
class TestGetProBar:
    """Unit tests for ``get_pro_bar`` with ``TushareClient`` mocked out.

    Each test patches ``TushareClient`` in the wrapper module, makes
    ``client.query`` return a canned DataFrame, and then inspects the
    arguments the wrapper forwarded to the ``pro_bar`` endpoint.
    """

    @staticmethod
    def _mock_client(mock_client_class, data=None):
        """Wire the patched client class so ``query`` returns ``pd.DataFrame(data)``.

        Args:
            mock_client_class: The patched ``TushareClient`` class mock.
            data: Column mapping for the DataFrame returned by ``query``;
                ``None`` yields an empty DataFrame.

        Returns:
            The client mock, so tests can inspect ``query.call_args``.
        """
        mock_client = MagicMock()
        mock_client_class.return_value = mock_client
        mock_client.query.return_value = pd.DataFrame(data)
        return mock_client

    @patch("src.data.api_wrappers.api_pro_bar.TushareClient")
    def test_fetch_basic(self, mock_client_class):
        """Test basic pro_bar data fetch."""
        mock_client = self._mock_client(
            mock_client_class,
            {
                "ts_code": ["000001.SZ"],
                "trade_date": ["20240115"],
                "open": [10.5],
                "high": [11.0],
                "low": [10.2],
                "close": [10.8],
                "pre_close": [10.5],
                "change": [0.3],
                "pct_chg": [2.86],
                "vol": [100000.0],
                "amount": [1080000.0],
            },
        )

        result = get_pro_bar("000001.SZ", start_date="20240101", end_date="20240131")

        assert isinstance(result, pd.DataFrame)
        assert not result.empty
        assert result["ts_code"].iloc[0] == "000001.SZ"
        mock_client.query.assert_called_once()
        # Verify the pro_bar endpoint is the one being queried.
        call_args = mock_client.query.call_args
        assert call_args.args[0] == "pro_bar"
        assert call_args.kwargs["ts_code"] == "000001.SZ"
        # Default should use hfq (backward-adjusted).
        assert call_args.kwargs["adj"] == "hfq"

    @patch("src.data.api_wrappers.api_pro_bar.TushareClient")
    def test_default_backward_adjusted(self, mock_client_class):
        """Test that default adjustment is backward (hfq) with adjfactor requested."""
        mock_client = self._mock_client(
            mock_client_class,
            {
                "ts_code": ["000001.SZ"],
                "trade_date": ["20240115"],
                "close": [100.5],
            },
        )

        get_pro_bar("000001.SZ")

        call_args = mock_client.query.call_args
        assert call_args.kwargs["adj"] == "hfq"
        assert call_args.kwargs["adjfactor"] == "True"

    @patch("src.data.api_wrappers.api_pro_bar.TushareClient")
    def test_default_factors_all_fields(self, mock_client_class):
        """Test that default factors include tor and vr and all columns survive."""
        mock_client = self._mock_client(
            mock_client_class,
            {
                "ts_code": ["000001.SZ"],
                "trade_date": ["20240115"],
                "close": [10.8],
                "turnover_rate": [2.5],
                "volume_ratio": [1.2],
                "adj_factor": [1.05],
            },
        )

        result = get_pro_bar("000001.SZ")

        call_args = mock_client.query.call_args
        # Default should request both tor and vr.
        assert call_args.kwargs["factors"] == "tor,vr"
        assert "turnover_rate" in result.columns
        assert "volume_ratio" in result.columns
        assert "adj_factor" in result.columns

    @patch("src.data.api_wrappers.api_pro_bar.TushareClient")
    def test_fetch_with_custom_factors(self, mock_client_class):
        """Test fetch with a custom factor list (tor only)."""
        mock_client = self._mock_client(
            mock_client_class,
            {
                "ts_code": ["000001.SZ"],
                "trade_date": ["20240115"],
                "close": [10.8],
                "turnover_rate": [2.5],
            },
        )

        get_pro_bar(
            "000001.SZ",
            start_date="20240101",
            end_date="20240131",
            factors=["tor"],
        )

        call_args = mock_client.query.call_args
        assert call_args.kwargs["factors"] == "tor"

    @patch("src.data.api_wrappers.api_pro_bar.TushareClient")
    def test_fetch_with_no_factors(self, mock_client_class):
        """Test that an empty factor list omits the factors parameter."""
        mock_client = self._mock_client(
            mock_client_class,
            {
                "ts_code": ["000001.SZ"],
                "trade_date": ["20240115"],
                "close": [10.8],
            },
        )

        get_pro_bar(
            "000001.SZ",
            start_date="20240101",
            end_date="20240131",
            factors=[],
        )

        call_args = mock_client.query.call_args
        # An explicitly empty list must not be forwarded as factors="".
        assert "factors" not in call_args.kwargs

    @patch("src.data.api_wrappers.api_pro_bar.TushareClient")
    def test_fetch_with_ma(self, mock_client_class):
        """Test fetch with moving averages joined into a comma string."""
        mock_client = self._mock_client(
            mock_client_class,
            {
                "ts_code": ["000001.SZ"],
                "trade_date": ["20240115"],
                "close": [10.8],
                "ma_5": [10.5],
                "ma_10": [10.3],
                "ma_v_5": [95000.0],
            },
        )

        result = get_pro_bar(
            "000001.SZ",
            start_date="20240101",
            end_date="20240131",
            ma=[5, 10],
        )

        call_args = mock_client.query.call_args
        assert call_args.kwargs["ma"] == "5,10"
        assert "ma_5" in result.columns
        assert "ma_10" in result.columns
        assert "ma_v_5" in result.columns

    @patch("src.data.api_wrappers.api_pro_bar.TushareClient")
    def test_fetch_index_data(self, mock_client_class):
        """Test fetching index data (asset='I')."""
        mock_client = self._mock_client(
            mock_client_class,
            {
                "ts_code": ["000001.SH"],
                "trade_date": ["20240115"],
                "close": [2900.5],
            },
        )

        get_pro_bar(
            "000001.SH",
            asset="I",
            start_date="20240101",
            end_date="20240131",
        )

        call_args = mock_client.query.call_args
        assert call_args.kwargs["asset"] == "I"
        assert call_args.kwargs["ts_code"] == "000001.SH"

    @patch("src.data.api_wrappers.api_pro_bar.TushareClient")
    def test_forward_adjustment(self, mock_client_class):
        """Test forward adjustment (qfq)."""
        mock_client = self._mock_client(
            mock_client_class,
            {
                "ts_code": ["000001.SZ"],
                "trade_date": ["20240115"],
                "close": [10.8],
            },
        )

        get_pro_bar("000001.SZ", adj="qfq")

        call_args = mock_client.query.call_args
        assert call_args.kwargs["adj"] == "qfq"

    @patch("src.data.api_wrappers.api_pro_bar.TushareClient")
    def test_no_adjustment(self, mock_client_class):
        """Test that adj=None omits the adj parameter entirely."""
        mock_client = self._mock_client(
            mock_client_class,
            {
                "ts_code": ["000001.SZ"],
                "trade_date": ["20240115"],
                "close": [10.8],
            },
        )

        get_pro_bar("000001.SZ", adj=None)

        call_args = mock_client.query.call_args
        assert "adj" not in call_args.kwargs

    @patch("src.data.api_wrappers.api_pro_bar.TushareClient")
    def test_empty_response(self, mock_client_class):
        """Test handling of an empty API response."""
        self._mock_client(mock_client_class)

        result = get_pro_bar("INVALID.SZ")

        assert isinstance(result, pd.DataFrame)
        assert result.empty

    @patch("src.data.api_wrappers.api_pro_bar.TushareClient")
    def test_date_column_rename(self, mock_client_class):
        """Test that a 'date' column is renamed to 'trade_date'."""
        self._mock_client(
            mock_client_class,
            {
                "ts_code": ["000001.SZ"],
                "date": ["20240115"],  # API returns 'date' instead of 'trade_date'
                "close": [10.8],
            },
        )

        result = get_pro_bar("000001.SZ")

        assert "trade_date" in result.columns
        assert "date" not in result.columns
        assert result["trade_date"].iloc[0] == "20240115"
class TestProBarSync:
    """Unit tests for the ``ProBarSync`` batch synchronizer."""

    @patch("src.data.api_wrappers.api_pro_bar.sync_all_stocks")
    @patch("src.data.api_wrappers.api_pro_bar.pd.read_csv")
    @patch("src.data.api_wrappers.api_pro_bar._get_csv_path")
    def test_get_all_stock_codes(self, mock_get_path, mock_read_csv, mock_sync_stocks):
        """Test that get_all_stock_codes returns every ts_code from the stock CSV.

        ``mock_sync_stocks`` is unused but keeps ``sync_all_stocks`` patched so
        no real synchronisation can be triggered by this test.
        """
        from pathlib import Path

        # A path mock that reports the CSV as already present.
        mock_path = MagicMock(spec=Path)
        mock_path.exists.return_value = True
        mock_get_path.return_value = mock_path
        mock_read_csv.return_value = pd.DataFrame(
            {
                "ts_code": ["000001.SZ", "600000.SH"],
                "list_status": ["L", "L"],
            }
        )

        sync = ProBarSync()
        codes = sync.get_all_stock_codes()

        assert len(codes) == 2
        assert "000001.SZ" in codes
        assert "600000.SH" in codes

    # NOTE: this test was previously defined twice with identical bodies; the
    # second definition silently shadowed the first, so only one is kept.
    @patch("src.data.api_wrappers.api_pro_bar.Storage")
    def test_check_sync_needed_force_full(self, mock_storage_class):
        """Test check_sync_needed with force_full=True."""
        mock_storage = MagicMock()
        mock_storage_class.return_value = mock_storage
        mock_storage.exists.return_value = False

        sync = ProBarSync()
        needed, start, end, local_last = sync.check_sync_needed(force_full=True)

        assert needed is True
        assert start == "20180101"  # DEFAULT_START_DATE
        assert local_last is None
class TestSyncProBar:
    """Tests for the module-level ``sync_pro_bar`` / ``preview_pro_bar_sync`` helpers.

    ``ProBarSync`` is patched, so only the delegation to it is exercised.
    """

    @patch("src.data.api_wrappers.api_pro_bar.ProBarSync")
    def test_sync_pro_bar(self, mock_sync_class):
        """Test that sync_pro_bar builds ProBarSync(max_workers=...) and surfaces sync_all's data."""
        instance = MagicMock()
        instance.sync_all.return_value = {"000001.SZ": pd.DataFrame({"close": [10.5]})}
        mock_sync_class.return_value = instance

        outcome = sync_pro_bar(force_full=True, max_workers=5)

        mock_sync_class.assert_called_once_with(max_workers=5)
        instance.sync_all.assert_called_once()
        assert "000001.SZ" in outcome

    @patch("src.data.api_wrappers.api_pro_bar.ProBarSync")
    def test_preview_pro_bar_sync(self, mock_sync_class):
        """Test that preview_pro_bar_sync builds ProBarSync() and surfaces preview_sync's dict."""
        canned_preview = {
            "sync_needed": True,
            "stock_count": 5000,
            "mode": "full",
        }
        instance = MagicMock()
        instance.preview_sync.return_value = canned_preview
        mock_sync_class.return_value = instance

        preview = preview_pro_bar_sync(force_full=True)

        mock_sync_class.assert_called_once_with()
        instance.preview_sync.assert_called_once()
        assert preview["sync_needed"] is True
        assert preview["stock_count"] == 5000
class TestProBarIntegration:
    """Live tests against the real Tushare API; skipped when TUSHARE_TOKEN is unset."""

    def test_real_api_call(self):
        """Fetch one month of 000001.SZ and check the returned column set."""
        import os

        if not os.environ.get("TUSHARE_TOKEN"):
            pytest.skip("TUSHARE_TOKEN not configured")

        frame = get_pro_bar(
            "000001.SZ",
            start_date="20240101",
            end_date="20240131",
        )

        assert isinstance(frame, pd.DataFrame)
        # NOTE(review): an empty frame passes without any column checks;
        # consider asserting non-empty once the live data is known to be stable.
        if frame.empty:
            return

        columns = set(frame.columns)
        # Base OHLCV-style fields.
        for field in EXPECTED_BASE_FIELDS:
            assert field in columns, f"Missing base field: {field}"
        # Factor fields requested by default (tor, vr).
        for field in EXPECTED_FACTOR_FIELDS:
            assert field in columns, f"Missing factor field: {field}"
        # adj_factor is expected by default as well.
        assert "adj_factor" in columns
if __name__ == "__main__":
    # Propagate pytest's exit status so running this file directly exits
    # non-zero when any test fails (pytest.main only *returns* the code).
    raise SystemExit(pytest.main([__file__, "-v"]))