Files
ProStock/tests/test_daily.py
liaozhaorun e625a53162 feat: 初始化 ProStock 项目基础结构和配置
- 添加项目规则文档(开发规范、安全规则、配置管理)
- 实现数据模块核心功能(API 客户端、限流器、存储管理、配置加载)
- 添加 .gitignore 和 .kilocodeignore 配置
- 配置环境变量模板
- 编写 daily 模块单元测试
2026-01-31 03:04:51 +08:00

420 lines
14 KiB
Python

"""Tests for the daily market data API.

Checks the ``get_daily`` implementation against the requirements in api.md:
- all output fields of A-share daily quotes (A股日线行情)
- tor: turnover rate (换手率)
- vr: volume ratio (量比)
"""
import pytest
import pandas as pd
from unittest.mock import Mock, patch
from src.data.daily import get_daily
from src.data.client import TushareClient
# Columns that api.md says the daily endpoint must return.
# Base quote columns, present on every response:
EXPECTED_BASE_FIELDS = [
    "ts_code",     # stock code
    "trade_date",  # trading date
    "open",        # opening price
    "high",        # daily high
    "low",         # daily low
    "close",       # closing price
    "pre_close",   # previous close
    "change",      # absolute price change
    "pct_chg",     # percentage price change
    "vol",         # traded volume
    "amount",      # traded amount
]

# Optional factor columns, requested via ``factors=[...]``:
EXPECTED_FACTOR_FIELDS = [
    "tor",  # turnover rate
    "vr",   # volume ratio
]
def _sample_daily_frame(*overrides):
    """Return a mock ``daily`` response with one row per dict in *overrides*.

    Every row starts from a fixed 000001.SZ quote for 20240102; each dict
    replaces or adds fields of that row (e.g. ``{'tor': 2.5}`` adds a factor
    column). With no arguments, the single default row is returned.
    """
    default_row = {
        'ts_code': '000001.SZ',
        'trade_date': '20240102',
        'open': 10.5,
        'high': 11.0,
        'low': 10.2,
        'close': 10.8,
        'pre_close': 10.3,
        'change': 0.5,
        'pct_chg': 4.85,
        'vol': 1000000,
        'amount': 10800000,
    }
    rows = [{**default_row, **row} for row in (overrides or ({},))]
    return pd.DataFrame(rows)


def _get_daily_mocked(mock_data, *args, **kwargs):
    """Call ``get_daily(*args, **kwargs)`` with TushareClient stubbed out.

    ``TushareClient.__init__`` is replaced by a no-op so no token is needed,
    and ``TushareClient.query`` returns *mock_data* for any arguments.
    """
    with patch.object(TushareClient, '__init__', lambda self, token=None: None):
        with patch.object(TushareClient, 'query', return_value=mock_data):
            return get_daily(*args, **kwargs)


def _check_basic_fetch():
    """Test 1: a plain fetch returns the single mocked row."""
    result = _get_daily_mocked(
        _sample_daily_frame(),
        '000001.SZ', start_date='20240101', end_date='20240131',
    )
    print(f"获取数据形状: {result.shape}")
    print(f"获取数据列: {result.columns.tolist()}")
    print(f"数据内容:\n{result}")
    # `and` short-circuits, so an empty result fails the check instead of
    # raising IndexError on .iloc[0].
    return (
        isinstance(result, pd.DataFrame)
        and len(result) == 1
        and result['ts_code'].iloc[0] == '000001.SZ'
    )


def _check_fetch_with_factors():
    """Test 2: requesting tor/vr keeps base columns and adds factor columns."""
    result = _get_daily_mocked(
        _sample_daily_frame({'tor': 2.5, 'vr': 1.2}),
        '000001.SZ',
        start_date='20240101',
        end_date='20240131',
        factors=['tor', 'vr'],
    )
    print(f"获取数据形状: {result.shape}")
    print(f"获取数据列: {result.columns.tolist()}")
    print(f"数据内容:\n{result}")
    missing_base = [f for f in EXPECTED_BASE_FIELDS if f not in result.columns]
    missing_factors = [f for f in EXPECTED_FACTOR_FIELDS if f not in result.columns]
    print(f"\n基础列检查: {'全部存在 ✓' if not missing_base else f'缺失: {missing_base}'}")
    print(f"因子列检查: {'全部存在 ✓' if not missing_factors else f'缺失: {missing_factors}'}")
    return isinstance(result, pd.DataFrame) and not missing_base and not missing_factors


def _check_output_fields():
    """Test 3: every expected base column is present in the result."""
    result = _get_daily_mocked(_sample_daily_frame({'ts_code': '600000.SH'}), '600000.SH')
    print(f"获取数据形状: {result.shape}")
    print(f"获取数据列: {result.columns.tolist()}")
    print(f"期望基础列: {EXPECTED_BASE_FIELDS}")
    missing = set(EXPECTED_BASE_FIELDS) - set(result.columns)
    print(f"缺失列: {missing if missing else ''}")
    return not missing


def _check_empty_result():
    """Test 4: an empty API response yields an empty result."""
    result = _get_daily_mocked(pd.DataFrame(), 'INVALID.SZ')
    print(f"获取数据是否为空: {result.empty}")
    return result.empty


def _check_date_range():
    """Test 5: a date-range query returns every mocked row."""
    mock_data = _sample_daily_frame(
        {},
        {
            'trade_date': '20240103',
            'open': 10.6,
            'high': 11.1,
            'low': 10.3,
            'close': 10.9,
            'pre_close': 10.8,
            'change': 0.1,
            'pct_chg': 0.93,
            'vol': 1100000,
            'amount': 11900000,
        },
    )
    result = _get_daily_mocked(
        mock_data, '000001.SZ', start_date='20240101', end_date='20240131',
    )
    print(f"获取数据数量: {len(result)}")
    print("期望数量: 2")
    print(f"数据内容:\n{result}")
    return len(result) == 2


def _check_with_adj():
    """Test 6: passing an adjustment type still returns a DataFrame."""
    result = _get_daily_mocked(_sample_daily_frame(), '000001.SZ', adj='qfq')
    print(f"获取数据形状: {result.shape}")
    print(f"数据内容:\n{result}")
    return isinstance(result, pd.DataFrame)


def run_tests_with_print():
    """Run the mocked daily-API checks with verbose console output.

    Each check prints what it fetched and returns a bool. A check that raises
    is reported as a failure instead of aborting the remaining checks, so the
    summary is always printed.

    Returns:
        True if every check passed, False otherwise.
    """
    # (section header, summary name, check function)
    checks = [
        ("【测试1】基本日线数据获取", "基本日线数据获取", _check_basic_fetch),
        ("【测试2】获取含换手率和量比的数据", "含因子数据获取", _check_fetch_with_factors),
        ("【测试3】输出字段完整性检查", "输出字段完整性", _check_output_fields),
        ("【测试4】空结果处理", "空结果处理", _check_empty_result),
        ("【测试5】日期范围查询", "日期范围查询", _check_date_range),
        ("【测试6】带复权参数查询", "复权参数查询", _check_with_adj),
    ]

    print("=" * 60)
    print("Daily API 测试开始")
    print("=" * 60)

    test_results = []
    for title, name, check in checks:
        print(f"\n{title}")
        print("-" * 40)
        try:
            ok = bool(check())
        except Exception as exc:  # runner boundary: record the failure, keep going
            print(f"测试异常: {exc!r}")
            ok = False
        print(f"\n测试结果: {'通过 ✓' if ok else '失败 ✗'}")
        test_results.append((name, ok))

    # Summary
    print("\n" + "=" * 60)
    print("测试汇总")
    print("=" * 60)
    passed_count = sum(ok for _, ok in test_results)
    total = len(test_results)
    print(f"总测试数: {total}")
    print(f"通过: {passed_count}")
    print(f"失败: {total - passed_count}")
    rate = passed_count / total * 100 if total else 0.0
    print(f"通过率: {rate:.1f}%")
    print("\n详细结果:")
    for name, ok in test_results:
        status = "通过 ✓" if ok else "失败 ✗"
        print(f" - {name}: {status}")
    return all(ok for _, ok in test_results)
class TestGetDaily:
    """Unit tests for ``get_daily`` with the Tushare client mocked out."""

    @staticmethod
    def _daily_frame(*overrides):
        """Return a mock ``daily`` response with one row per dict in *overrides*.

        Each row starts from a fixed 000001.SZ quote for 20240102 and is
        updated with the given dict; with no arguments, one default row.
        """
        default_row = {
            'ts_code': '000001.SZ',
            'trade_date': '20240102',
            'open': 10.5,
            'high': 11.0,
            'low': 10.2,
            'close': 10.8,
            'pre_close': 10.3,
            'change': 0.5,
            'pct_chg': 4.85,
            'vol': 1000000,
            'amount': 10800000,
        }
        return pd.DataFrame([{**default_row, **row} for row in (overrides or ({},))])

    @staticmethod
    def _fetch(mock_data, *args, **kwargs):
        """Call ``get_daily`` with TushareClient's ``__init__`` and ``query`` patched.

        ``query`` returns *mock_data* for any arguments, so no token or
        network access is needed.
        """
        with patch.object(TushareClient, '__init__', lambda self, token=None: None):
            with patch.object(TushareClient, 'query', return_value=mock_data):
                return get_daily(*args, **kwargs)

    def test_fetch_basic(self):
        """Test basic daily data fetch."""
        result = self._fetch(
            self._daily_frame(), '000001.SZ', start_date='20240101', end_date='20240131',
        )
        assert isinstance(result, pd.DataFrame)
        assert len(result) == 1
        assert result['ts_code'].iloc[0] == '000001.SZ'

    def test_fetch_with_factors(self):
        """Test fetch with tor (turnover rate) and vr (volume ratio) factors."""
        result = self._fetch(
            self._daily_frame({'tor': 2.5, 'vr': 1.2}),
            '000001.SZ',
            start_date='20240101',
            end_date='20240131',
            factors=['tor', 'vr'],
        )
        assert isinstance(result, pd.DataFrame)
        for field in EXPECTED_BASE_FIELDS:
            assert field in result.columns, f"Missing base field: {field}"
        for field in EXPECTED_FACTOR_FIELDS:
            assert field in result.columns, f"Missing factor field: {field}"

    def test_output_fields_completeness(self):
        """Verify all required output fields are returned."""
        result = self._fetch(self._daily_frame({'ts_code': '600000.SH'}), '600000.SH')
        assert set(EXPECTED_BASE_FIELDS).issubset(result.columns.tolist()), \
            f"Missing fields: {set(EXPECTED_BASE_FIELDS) - set(result.columns)}"

    def test_empty_result(self):
        """Test handling of empty results."""
        result = self._fetch(pd.DataFrame(), 'INVALID.SZ')
        assert result.empty

    def test_date_range_query(self):
        """Test query with date range."""
        mock_data = self._daily_frame(
            {},
            {
                'trade_date': '20240103',
                'open': 10.6,
                'high': 11.1,
                'low': 10.3,
                'close': 10.9,
                'pre_close': 10.8,
                'change': 0.1,
                'pct_chg': 0.93,
                'vol': 1100000,
                'amount': 11900000,
            },
        )
        result = self._fetch(
            mock_data, '000001.SZ', start_date='20240101', end_date='20240131',
        )
        assert len(result) == 2

    def test_with_adj(self):
        """Test fetch with adjustment type."""
        result = self._fetch(self._daily_frame(), '000001.SZ', adj='qfq')
        assert isinstance(result, pd.DataFrame)
def test_integration():
    """Integration test against the real Tushare API.

    Skipped unless the TUSHARE_TOKEN environment variable is set.
    """
    import os

    token = os.environ.get('TUSHARE_TOKEN')
    if not token:
        pytest.skip("TUSHARE_TOKEN not configured")

    result = get_daily(
        '000001.SZ', start_date='20240101', end_date='20240131', factors=['tor', 'vr'],
    )

    assert isinstance(result, pd.DataFrame)
    # 000001.SZ traded in January 2024, so an empty frame most likely means
    # the request failed or was misconfigured; the previous
    # `if not result.empty` guard let that case pass without checking anything.
    assert not result.empty, "Expected daily data for 000001.SZ in 2024-01"
    missing = [
        field
        for field in EXPECTED_BASE_FIELDS + EXPECTED_FACTOR_FIELDS
        if field not in result.columns
    ]
    assert not missing, f"Missing fields: {missing}"
if __name__ == '__main__':
    import sys

    # Run the verbose, human-readable checks first.
    printed_ok = run_tests_with_print()
    print("\n" + "=" * 60)
    print("运行 pytest 单元测试")
    print("=" * 60 + "\n")
    pytest_code = pytest.main([__file__, '-v'])
    # Propagate failures to the process exit status (previously always 0):
    # pytest's own code wins when non-zero, else 1 if the printed checks failed.
    sys.exit(int(pytest_code) or (0 if printed_ok else 1))