feat: 新增股票基础数据获取模块 stock_basic

- 新增 get_stock_basic 和 sync_all_stocks 函数
- 完善 Tushare 数据获取模块体系
- 测试用例重构:从 Mock 改为真实 API 调用
- 更新 API 文档,添加接口使用示例
- 更新开发规范:添加 Mock 使用规范
This commit is contained in:
2026-01-31 04:30:29 +08:00
parent e625a53162
commit 38e78a5326
10 changed files with 341 additions and 339 deletions

View File

@@ -7,9 +7,7 @@ Tests the daily interface implementation against api.md requirements:
"""
import pytest
import pandas as pd
from unittest.mock import Mock, patch
from src.data.daily import get_daily
from src.data.client import TushareClient
# Expected output fields according to api.md
@@ -28,276 +26,30 @@ EXPECTED_BASE_FIELDS = [
]
EXPECTED_FACTOR_FIELDS = [
'tor', # 换手率
'vr', # 量比
'turnover_rate', # 换手率 (tor)
'volume_ratio', # 量比 (vr)
]
def run_tests_with_print():
"""Run all tests and print results."""
print("=" * 60)
print("Daily API 测试开始")
print("=" * 60)
test_results = []
# Test 1: Basic daily data fetch
print("\n【测试1】基本日线数据获取")
print("-" * 40)
mock_data = pd.DataFrame({
'ts_code': ['000001.SZ'],
'trade_date': ['20240102'],
'open': [10.5],
'high': [11.0],
'low': [10.2],
'close': [10.8],
'pre_close': [10.3],
'change': [0.5],
'pct_chg': [4.85],
'vol': [1000000],
'amount': [10800000],
})
with patch.object(TushareClient, '__init__', lambda self, token=None: None):
with patch.object(TushareClient, 'query', return_value=mock_data):
result = get_daily('000001.SZ', start_date='20240101', end_date='20240131')
print(f"获取数据形状: {result.shape}")
print(f"获取数据列: {result.columns.tolist()}")
print(f"数据内容:\n{result}")
# Verify
tests_passed = isinstance(result, pd.DataFrame)
tests_passed &= len(result) == 1
tests_passed &= result['ts_code'].iloc[0] == '000001.SZ'
print(f"\n测试结果: {'通过 ✓' if tests_passed else '失败 ✗'}")
test_results.append(("基本日线数据获取", tests_passed))
# Test 2: Fetch with factors
print("\n【测试2】获取含换手率和量比的数据")
print("-" * 40)
mock_data_factors = pd.DataFrame({
'ts_code': ['000001.SZ'],
'trade_date': ['20240102'],
'open': [10.5],
'high': [11.0],
'low': [10.2],
'close': [10.8],
'pre_close': [10.3],
'change': [0.5],
'pct_chg': [4.85],
'vol': [1000000],
'amount': [10800000],
'tor': [2.5],
'vr': [1.2],
})
with patch.object(TushareClient, '__init__', lambda self, token=None: None):
with patch.object(TushareClient, 'query', return_value=mock_data_factors):
result = get_daily(
'000001.SZ',
start_date='20240101',
end_date='20240131',
factors=['tor', 'vr'],
)
print(f"获取数据形状: {result.shape}")
print(f"获取数据列: {result.columns.tolist()}")
print(f"数据内容:\n{result}")
# Verify columns
tests_passed = isinstance(result, pd.DataFrame)
missing_base = [f for f in EXPECTED_BASE_FIELDS if f not in result.columns]
missing_factors = [f for f in EXPECTED_FACTOR_FIELDS if f not in result.columns]
print(f"\n基础列检查: {'全部存在 ✓' if not missing_base else f'缺失: {missing_base}'}")
print(f"因子列检查: {'全部存在 ✓' if not missing_factors else f'缺失: {missing_factors}'}")
tests_passed &= len(missing_base) == 0
tests_passed &= len(missing_factors) == 0
print(f"\n测试结果: {'通过 ✓' if tests_passed else '失败 ✗'}")
test_results.append(("含因子数据获取", tests_passed))
# Test 3: Output fields completeness
print("\n【测试3】输出字段完整性检查")
print("-" * 40)
mock_data = pd.DataFrame({
'ts_code': ['600000.SH'],
'trade_date': ['20240102'],
'open': [10.5],
'high': [11.0],
'low': [10.2],
'close': [10.8],
'pre_close': [10.3],
'change': [0.5],
'pct_chg': [4.85],
'vol': [1000000],
'amount': [10800000],
})
with patch.object(TushareClient, '__init__', lambda self, token=None: None):
with patch.object(TushareClient, 'query', return_value=mock_data):
result = get_daily('600000.SH')
print(f"获取数据形状: {result.shape}")
print(f"获取数据列: {result.columns.tolist()}")
print(f"期望基础列: {EXPECTED_BASE_FIELDS}")
missing = set(EXPECTED_BASE_FIELDS) - set(result.columns)
print(f"缺失列: {missing if missing else ''}")
tests_passed = set(EXPECTED_BASE_FIELDS).issubset(result.columns.tolist())
print(f"\n测试结果: {'通过 ✓' if tests_passed else '失败 ✗'}")
test_results.append(("输出字段完整性", tests_passed))
# Test 4: Empty result
print("\n【测试4】空结果处理")
print("-" * 40)
mock_data = pd.DataFrame()
with patch.object(TushareClient, '__init__', lambda self, token=None: None):
with patch.object(TushareClient, 'query', return_value=mock_data):
result = get_daily('INVALID.SZ')
print(f"获取数据是否为空: {result.empty}")
tests_passed = result.empty
print(f"\n测试结果: {'通过 ✓' if tests_passed else '失败 ✗'}")
test_results.append(("空结果处理", tests_passed))
# Test 5: Date range query
print("\n【测试5】日期范围查询")
print("-" * 40)
mock_data = pd.DataFrame({
'ts_code': ['000001.SZ', '000001.SZ'],
'trade_date': ['20240102', '20240103'],
'open': [10.5, 10.6],
'high': [11.0, 11.1],
'low': [10.2, 10.3],
'close': [10.8, 10.9],
'pre_close': [10.3, 10.8],
'change': [0.5, 0.1],
'pct_chg': [4.85, 0.93],
'vol': [1000000, 1100000],
'amount': [10800000, 11900000],
})
with patch.object(TushareClient, '__init__', lambda self, token=None: None):
with patch.object(TushareClient, 'query', return_value=mock_data):
result = get_daily(
'000001.SZ',
start_date='20240101',
end_date='20240131',
)
print(f"获取数据数量: {len(result)}")
print(f"期望数量: 2")
print(f"数据内容:\n{result}")
tests_passed = len(result) == 2
print(f"\n测试结果: {'通过 ✓' if tests_passed else '失败 ✗'}")
test_results.append(("日期范围查询", tests_passed))
# Test 6: With adjustment
print("\n【测试6】带复权参数查询")
print("-" * 40)
mock_data = pd.DataFrame({
'ts_code': ['000001.SZ'],
'trade_date': ['20240102'],
'open': [10.5],
'high': [11.0],
'low': [10.2],
'close': [10.8],
'pre_close': [10.3],
'change': [0.5],
'pct_chg': [4.85],
'vol': [1000000],
'amount': [10800000],
})
with patch.object(TushareClient, '__init__', lambda self, token=None: None):
with patch.object(TushareClient, 'query', return_value=mock_data):
result = get_daily('000001.SZ', adj='qfq')
print(f"获取数据形状: {result.shape}")
print(f"数据内容:\n{result}")
tests_passed = isinstance(result, pd.DataFrame)
print(f"\n测试结果: {'通过 ✓' if tests_passed else '失败 ✗'}")
test_results.append(("复权参数查询", tests_passed))
# Summary
print("\n" + "=" * 60)
print("测试汇总")
print("=" * 60)
passed = sum(1 for _, r in test_results if r)
total = len(test_results)
print(f"总测试数: {total}")
print(f"通过: {passed}")
print(f"失败: {total - passed}")
print(f"通过率: {passed/total*100:.1f}%")
print("\n详细结果:")
for name, passed in test_results:
status = "通过 ✓" if passed else "失败 ✗"
print(f" - {name}: {status}")
return all(r for _, r in test_results)
class TestGetDaily:
"""Test cases for simplified get_daily function."""
"""Test cases for get_daily function with real API calls."""
def test_fetch_basic(self):
"""Test basic daily data fetch."""
mock_data = pd.DataFrame({
'ts_code': ['000001.SZ'],
'trade_date': ['20240102'],
'open': [10.5],
'high': [11.0],
'low': [10.2],
'close': [10.8],
'pre_close': [10.3],
'change': [0.5],
'pct_chg': [4.85],
'vol': [1000000],
'amount': [10800000],
})
with patch.object(TushareClient, '__init__', lambda self, token=None: None):
with patch.object(TushareClient, 'query', return_value=mock_data):
result = get_daily('000001.SZ', start_date='20240101', end_date='20240131')
"""Test basic daily data fetch with real API."""
result = get_daily('000001.SZ', start_date='20240101', end_date='20240131')
assert isinstance(result, pd.DataFrame)
assert len(result) == 1
assert len(result) >= 1
assert result['ts_code'].iloc[0] == '000001.SZ'
def test_fetch_with_factors(self):
"""Test fetch with tor and vr factors."""
mock_data = pd.DataFrame({
'ts_code': ['000001.SZ'],
'trade_date': ['20240102'],
'open': [10.5],
'high': [11.0],
'low': [10.2],
'close': [10.8],
'pre_close': [10.3],
'change': [0.5],
'pct_chg': [4.85],
'vol': [1000000],
'amount': [10800000],
'tor': [2.5], # 换手率
'vr': [1.2], # 量比
})
with patch.object(TushareClient, '__init__', lambda self, token=None: None):
with patch.object(TushareClient, 'query', return_value=mock_data):
result = get_daily(
'000001.SZ',
start_date='20240101',
end_date='20240131',
factors=['tor', 'vr'],
)
result = get_daily(
'000001.SZ',
start_date='20240101',
end_date='20240131',
factors=['tor', 'vr'],
)
assert isinstance(result, pd.DataFrame)
# Check all base fields are present
@@ -309,23 +61,7 @@ class TestGetDaily:
def test_output_fields_completeness(self):
"""Verify all required output fields are returned."""
mock_data = pd.DataFrame({
'ts_code': ['600000.SH'],
'trade_date': ['20240102'],
'open': [10.5],
'high': [11.0],
'low': [10.2],
'close': [10.8],
'pre_close': [10.3],
'change': [0.5],
'pct_chg': [4.85],
'vol': [1000000],
'amount': [10800000],
})
with patch.object(TushareClient, '__init__', lambda self, token=None: None):
with patch.object(TushareClient, 'query', return_value=mock_data):
result = get_daily('600000.SH')
result = get_daily('600000.SH')
# Verify all base fields are present
assert set(EXPECTED_BASE_FIELDS).issubset(result.columns.tolist()), \
@@ -333,59 +69,25 @@ class TestGetDaily:
def test_empty_result(self):
"""Test handling of empty results."""
mock_data = pd.DataFrame()
with patch.object(TushareClient, '__init__', lambda self, token=None: None):
with patch.object(TushareClient, 'query', return_value=mock_data):
result = get_daily('INVALID.SZ')
# 使用真实 API 测试无效股票代码的空结果
result = get_daily('INVALID.SZ')
assert isinstance(result, pd.DataFrame)
assert result.empty
def test_date_range_query(self):
"""Test query with date range."""
mock_data = pd.DataFrame({
'ts_code': ['000001.SZ', '000001.SZ'],
'trade_date': ['20240102', '20240103'],
'open': [10.5, 10.6],
'high': [11.0, 11.1],
'low': [10.2, 10.3],
'close': [10.8, 10.9],
'pre_close': [10.3, 10.8],
'change': [0.5, 0.1],
'pct_chg': [4.85, 0.93],
'vol': [1000000, 1100000],
'amount': [10800000, 11900000],
})
result = get_daily(
'000001.SZ',
start_date='20240101',
end_date='20240131',
)
with patch.object(TushareClient, '__init__', lambda self, token=None: None):
with patch.object(TushareClient, 'query', return_value=mock_data):
result = get_daily(
'000001.SZ',
start_date='20240101',
end_date='20240131',
)
assert len(result) == 2
assert isinstance(result, pd.DataFrame)
assert len(result) >= 1
def test_with_adj(self):
"""Test fetch with adjustment type."""
mock_data = pd.DataFrame({
'ts_code': ['000001.SZ'],
'trade_date': ['20240102'],
'open': [10.5],
'high': [11.0],
'low': [10.2],
'close': [10.8],
'pre_close': [10.3],
'change': [0.5],
'pct_chg': [4.85],
'vol': [1000000],
'amount': [10800000],
})
with patch.object(TushareClient, '__init__', lambda self, token=None: None):
with patch.object(TushareClient, 'query', return_value=mock_data):
result = get_daily('000001.SZ', adj='qfq')
result = get_daily('000001.SZ', adj='qfq')
assert isinstance(result, pd.DataFrame)
@@ -411,9 +113,5 @@ def test_integration():
if __name__ == '__main__':
# 运行详细的打印测试
run_tests_with_print()
print("\n" + "=" * 60)
print("运行 pytest 单元测试")
print("=" * 60 + "\n")
# 运行 pytest 单元测试真实API调用
pytest.main([__file__, '-v'])