- 添加项目规则文档(开发规范、安全规则、配置管理) - 实现数据模块核心功能(API 客户端、限流器、存储管理、配置加载) - 添加 .gitignore 和 .kilocodeignore 配置 - 配置环境变量模板 - 编写 daily 模块单元测试
420 lines
14 KiB
Python
420 lines
14 KiB
Python
"""Tests for the daily market data API.

Tests the daily interface implementation against the api.md requirements:

- A股日线行情所有输出字段 (all output fields of A-share daily quotes)
- tor 换手率 (turnover rate)
- vr 量比 (volume ratio)
"""
|
|
import pytest
|
|
import pandas as pd
|
|
from unittest.mock import Mock, patch
|
|
from src.data.daily import get_daily
|
|
from src.data.client import TushareClient
|
|
|
|
|
|
# Base output fields that api.md requires from the A-share daily quote
# interface; every get_daily result is checked against this list.
EXPECTED_BASE_FIELDS = [
    'ts_code',     # stock code
    'trade_date',  # trading date
    'open',        # opening price
    'high',        # highest price
    'low',         # lowest price
    'close',       # closing price
    'pre_close',   # previous close
    'change',      # price change
    'pct_chg',     # percentage change
    'vol',         # trading volume
    'amount',      # trading amount
]

# Factor columns the tests expect when get_daily is called with
# factors=['tor', 'vr'].
EXPECTED_FACTOR_FIELDS = [
    'tor',  # turnover rate
    'vr',   # volume ratio
]
def _single_row_daily_frame(ts_code='000001.SZ', **extra_columns):
    """Return a one-row mock daily frame that populates every base field.

    ``extra_columns`` (e.g. ``tor=2.5``) are appended after the base columns,
    in keyword order.
    """
    columns = {
        'ts_code': [ts_code],
        'trade_date': ['20240102'],
        'open': [10.5],
        'high': [11.0],
        'low': [10.2],
        'close': [10.8],
        'pre_close': [10.3],
        'change': [0.5],
        'pct_chg': [4.85],
        'vol': [1000000],
        'amount': [10800000],
    }
    columns.update({name: [value] for name, value in extra_columns.items()})
    return pd.DataFrame(columns)


def _get_daily_with_mock(mock_data, *args, **kwargs):
    """Call ``get_daily(*args, **kwargs)`` with ``TushareClient`` stubbed.

    ``TushareClient.__init__`` becomes a no-op and ``TushareClient.query``
    returns ``mock_data`` for every call.
    """
    with patch.object(TushareClient, '__init__', lambda self, token=None: None):
        with patch.object(TushareClient, 'query', return_value=mock_data):
            return get_daily(*args, **kwargs)


def _check_basic_fetch():
    """Scenario 1: a dated query returns the single mocked row."""
    result = _get_daily_with_mock(
        _single_row_daily_frame(),
        '000001.SZ', start_date='20240101', end_date='20240131',
    )
    print(f"获取数据形状: {result.shape}")
    print(f"获取数据列: {result.columns.tolist()}")
    print(f"数据内容:\n{result}")
    # Short-circuit so an empty or column-less result is reported as a
    # failure instead of raising on ``.iloc[0]``.
    return (
        isinstance(result, pd.DataFrame)
        and len(result) == 1
        and 'ts_code' in result.columns
        and result['ts_code'].iloc[0] == '000001.SZ'
    )


def _check_factor_fetch():
    """Scenario 2: requesting tor/vr yields every base and factor column."""
    result = _get_daily_with_mock(
        _single_row_daily_frame(tor=2.5, vr=1.2),
        '000001.SZ', start_date='20240101', end_date='20240131',
        factors=['tor', 'vr'],
    )
    print(f"获取数据形状: {result.shape}")
    print(f"获取数据列: {result.columns.tolist()}")
    print(f"数据内容:\n{result}")

    missing_base = [field for field in EXPECTED_BASE_FIELDS if field not in result.columns]
    missing_factors = [field for field in EXPECTED_FACTOR_FIELDS if field not in result.columns]
    print(f"\n基础列检查: {'全部存在 ✓' if not missing_base else f'缺失: {missing_base} ✗'}")
    print(f"因子列检查: {'全部存在 ✓' if not missing_factors else f'缺失: {missing_factors} ✗'}")
    return isinstance(result, pd.DataFrame) and not missing_base and not missing_factors


def _check_field_completeness():
    """Scenario 3: a query without a date range still returns every base field."""
    result = _get_daily_with_mock(_single_row_daily_frame('600000.SH'), '600000.SH')
    print(f"获取数据形状: {result.shape}")
    print(f"获取数据列: {result.columns.tolist()}")
    print(f"期望基础列: {EXPECTED_BASE_FIELDS}")

    missing = set(EXPECTED_BASE_FIELDS) - set(result.columns)
    print(f"缺失列: {missing if missing else '无'}")
    return not missing


def _check_empty_result():
    """Scenario 4: an empty API response yields an empty result."""
    result = _get_daily_with_mock(pd.DataFrame(), 'INVALID.SZ')
    print(f"获取数据是否为空: {result.empty}")
    return result.empty


def _check_date_range():
    """Scenario 5: both rows of a two-day mocked response are returned."""
    mock_data = pd.DataFrame({
        'ts_code': ['000001.SZ', '000001.SZ'],
        'trade_date': ['20240102', '20240103'],
        'open': [10.5, 10.6],
        'high': [11.0, 11.1],
        'low': [10.2, 10.3],
        'close': [10.8, 10.9],
        'pre_close': [10.3, 10.8],
        'change': [0.5, 0.1],
        'pct_chg': [4.85, 0.93],
        'vol': [1000000, 1100000],
        'amount': [10800000, 11900000],
    })
    result = _get_daily_with_mock(
        mock_data, '000001.SZ', start_date='20240101', end_date='20240131',
    )
    print(f"获取数据数量: {len(result)}")
    print("期望数量: 2")
    print(f"数据内容:\n{result}")
    return len(result) == 2


def _check_adjustment():
    """Scenario 6: passing ``adj='qfq'`` still returns a DataFrame."""
    result = _get_daily_with_mock(_single_row_daily_frame(), '000001.SZ', adj='qfq')
    print(f"获取数据形状: {result.shape}")
    print(f"数据内容:\n{result}")
    return isinstance(result, pd.DataFrame)


def run_tests_with_print():
    """Run every print-based scenario, print a summary, and report success.

    Each scenario runs against a stubbed ``TushareClient``. An exception
    raised inside a scenario is printed and counted as a failure, so the
    remaining scenarios and the summary still run.

    Returns:
        True if every scenario passed, otherwise False.
    """
    print("=" * 60)
    print("Daily API 测试开始")
    print("=" * 60)

    # (section header, summary name, check function)
    scenarios = [
        ("【测试1】基本日线数据获取", "基本日线数据获取", _check_basic_fetch),
        ("【测试2】获取含换手率和量比的数据", "含因子数据获取", _check_factor_fetch),
        ("【测试3】输出字段完整性检查", "输出字段完整性", _check_field_completeness),
        ("【测试4】空结果处理", "空结果处理", _check_empty_result),
        ("【测试5】日期范围查询", "日期范围查询", _check_date_range),
        ("【测试6】带复权参数查询", "复权参数查询", _check_adjustment),
    ]

    test_results = []
    for header, name, check in scenarios:
        print(f"\n{header}")
        print("-" * 40)
        try:
            ok = bool(check())
        except Exception as exc:  # runner boundary: record and continue
            print(f"执行异常: {exc!r}")
            ok = False
        print(f"\n测试结果: {'通过 ✓' if ok else '失败 ✗'}")
        test_results.append((name, ok))

    # Summary
    print("\n" + "=" * 60)
    print("测试汇总")
    print("=" * 60)
    passed_count = sum(1 for _, ok in test_results if ok)
    total = len(test_results)
    print(f"总测试数: {total}")
    print(f"通过: {passed_count}")
    print(f"失败: {total - passed_count}")
    print(f"通过率: {passed_count / total * 100:.1f}%")
    print("\n详细结果:")
    for name, ok in test_results:
        status = "通过 ✓" if ok else "失败 ✗"
        print(f" - {name}: {status}")

    return all(ok for _, ok in test_results)
class TestGetDaily:
    """Unit tests for the simplified ``get_daily`` function.

    Every test stubs ``TushareClient.__init__`` with a no-op and makes
    ``TushareClient.query`` return a canned DataFrame, so no real client is
    constructed.
    """

    @staticmethod
    def _daily_frame(ts_code='000001.SZ', **extra_columns):
        """Build a one-row daily frame with all base fields plus ``extra_columns``."""
        row = {
            'ts_code': ts_code,
            'trade_date': '20240102',
            'open': 10.5,
            'high': 11.0,
            'low': 10.2,
            'close': 10.8,
            'pre_close': 10.3,
            'change': 0.5,
            'pct_chg': 4.85,
            'vol': 1000000,
            'amount': 10800000,
            **extra_columns,
        }
        return pd.DataFrame({column: [value] for column, value in row.items()})

    @staticmethod
    def _call_get_daily(mock_data, *args, **kwargs):
        """Run ``get_daily`` while every client query returns ``mock_data``."""
        stub_init = patch.object(TushareClient, '__init__', lambda self, token=None: None)
        stub_query = patch.object(TushareClient, 'query', return_value=mock_data)
        with stub_init, stub_query:
            return get_daily(*args, **kwargs)

    def test_fetch_basic(self):
        """A dated query returns the single mocked row."""
        result = self._call_get_daily(
            self._daily_frame(),
            '000001.SZ', start_date='20240101', end_date='20240131',
        )

        assert isinstance(result, pd.DataFrame)
        assert len(result) == 1
        assert result['ts_code'].iloc[0] == '000001.SZ'

    def test_fetch_with_factors(self):
        """Requesting tor (turnover rate) and vr (volume ratio) keeps all columns."""
        result = self._call_get_daily(
            self._daily_frame(tor=2.5, vr=1.2),
            '000001.SZ', start_date='20240101', end_date='20240131',
            factors=['tor', 'vr'],
        )

        assert isinstance(result, pd.DataFrame)
        groups = (('base', EXPECTED_BASE_FIELDS), ('factor', EXPECTED_FACTOR_FIELDS))
        for label, fields in groups:
            for field in fields:
                assert field in result.columns, f"Missing {label} field: {field}"

    def test_output_fields_completeness(self):
        """Every required base output field is present."""
        result = self._call_get_daily(self._daily_frame('600000.SH'), '600000.SH')

        missing = set(EXPECTED_BASE_FIELDS) - set(result.columns)
        assert not missing, f"Missing fields: {missing}"

    def test_empty_result(self):
        """An empty API response yields an empty result."""
        assert self._call_get_daily(pd.DataFrame(), 'INVALID.SZ').empty

    def test_date_range_query(self):
        """Both rows of a two-day response are returned."""
        two_days = pd.DataFrame({
            'ts_code': ['000001.SZ'] * 2,
            'trade_date': ['20240102', '20240103'],
            'open': [10.5, 10.6],
            'high': [11.0, 11.1],
            'low': [10.2, 10.3],
            'close': [10.8, 10.9],
            'pre_close': [10.3, 10.8],
            'change': [0.5, 0.1],
            'pct_chg': [4.85, 0.93],
            'vol': [1000000, 1100000],
            'amount': [10800000, 11900000],
        })

        result = self._call_get_daily(
            two_days, '000001.SZ', start_date='20240101', end_date='20240131',
        )

        assert len(result) == 2

    def test_with_adj(self):
        """Passing an adjustment type still returns a DataFrame."""
        result = self._call_get_daily(self._daily_frame(), '000001.SZ', adj='qfq')

        assert isinstance(result, pd.DataFrame)
def test_integration():
    """Integration test against the real Tushare API.

    Skipped unless the ``TUSHARE_TOKEN`` environment variable is set. When
    the API returns no rows, the test is skipped with an explicit reason
    instead of passing vacuously. An empty frame cannot show that the
    required base and factor columns are produced.
    """
    import os

    if not os.environ.get('TUSHARE_TOKEN'):
        pytest.skip("TUSHARE_TOKEN not configured")

    result = get_daily('000001.SZ', start_date='20240101', end_date='20240131', factors=['tor', 'vr'])

    # Verify structure
    assert isinstance(result, pd.DataFrame)
    if result.empty:
        pytest.skip("Tushare returned no rows for 000001.SZ in 2024-01; column checks not exercised")

    for field in EXPECTED_BASE_FIELDS:
        assert field in result.columns, f"Missing base field: {field}"
    for field in EXPECTED_FACTOR_FIELDS:
        assert field in result.columns, f"Missing factor field: {field}"
if __name__ == '__main__':
    import sys

    # Run the verbose, print-based scenarios first.
    print_scenarios_ok = run_tests_with_print()

    print("\n" + "=" * 60)
    print("运行 pytest 单元测试")
    print("=" * 60 + "\n")
    pytest_status = int(pytest.main([__file__, '-v']))

    # Propagate failures to the shell/CI. A non-zero pytest status wins;
    # otherwise exit 1 if any print-based scenario failed.
    sys.exit(pytest_status or (0 if print_scenarios_ok else 1))