"""Test for daily market data API. Tests the daily interface implementation against api.md requirements: - A股日线行情所有输出字段 - tor 换手率 - vr 量比 """ import pytest import pandas as pd from unittest.mock import Mock, patch from src.data.daily import get_daily from src.data.client import TushareClient # Expected output fields according to api.md EXPECTED_BASE_FIELDS = [ 'ts_code', # 股票代码 'trade_date', # 交易日期 'open', # 开盘价 'high', # 最高价 'low', # 最低价 'close', # 收盘价 'pre_close', # 昨收价 'change', # 涨跌额 'pct_chg', # 涨跌幅 'vol', # 成交量 'amount', # 成交额 ] EXPECTED_FACTOR_FIELDS = [ 'tor', # 换手率 'vr', # 量比 ] def run_tests_with_print(): """Run all tests and print results.""" print("=" * 60) print("Daily API 测试开始") print("=" * 60) test_results = [] # Test 1: Basic daily data fetch print("\n【测试1】基本日线数据获取") print("-" * 40) mock_data = pd.DataFrame({ 'ts_code': ['000001.SZ'], 'trade_date': ['20240102'], 'open': [10.5], 'high': [11.0], 'low': [10.2], 'close': [10.8], 'pre_close': [10.3], 'change': [0.5], 'pct_chg': [4.85], 'vol': [1000000], 'amount': [10800000], }) with patch.object(TushareClient, '__init__', lambda self, token=None: None): with patch.object(TushareClient, 'query', return_value=mock_data): result = get_daily('000001.SZ', start_date='20240101', end_date='20240131') print(f"获取数据形状: {result.shape}") print(f"获取数据列: {result.columns.tolist()}") print(f"数据内容:\n{result}") # Verify tests_passed = isinstance(result, pd.DataFrame) tests_passed &= len(result) == 1 tests_passed &= result['ts_code'].iloc[0] == '000001.SZ' print(f"\n测试结果: {'通过 ✓' if tests_passed else '失败 ✗'}") test_results.append(("基本日线数据获取", tests_passed)) # Test 2: Fetch with factors print("\n【测试2】获取含换手率和量比的数据") print("-" * 40) mock_data_factors = pd.DataFrame({ 'ts_code': ['000001.SZ'], 'trade_date': ['20240102'], 'open': [10.5], 'high': [11.0], 'low': [10.2], 'close': [10.8], 'pre_close': [10.3], 'change': [0.5], 'pct_chg': [4.85], 'vol': [1000000], 'amount': [10800000], 'tor': [2.5], 'vr': [1.2], }) with patch.object(TushareClient, '__init__', lambda self, token=None: None): with patch.object(TushareClient, 'query', return_value=mock_data_factors): result = get_daily( '000001.SZ', start_date='20240101', end_date='20240131', factors=['tor', 'vr'], ) print(f"获取数据形状: {result.shape}") print(f"获取数据列: {result.columns.tolist()}") print(f"数据内容:\n{result}") # Verify columns tests_passed = isinstance(result, pd.DataFrame) missing_base = [f for f in EXPECTED_BASE_FIELDS if f not in result.columns] missing_factors = [f for f in EXPECTED_FACTOR_FIELDS if f not in result.columns] print(f"\n基础列检查: {'全部存在 ✓' if not missing_base else f'缺失: {missing_base} ✗'}") print(f"因子列检查: {'全部存在 ✓' if not missing_factors else f'缺失: {missing_factors} ✗'}") tests_passed &= len(missing_base) == 0 tests_passed &= len(missing_factors) == 0 print(f"\n测试结果: {'通过 ✓' if tests_passed else '失败 ✗'}") test_results.append(("含因子数据获取", tests_passed)) # Test 3: Output fields completeness print("\n【测试3】输出字段完整性检查") print("-" * 40) mock_data = pd.DataFrame({ 'ts_code': ['600000.SH'], 'trade_date': ['20240102'], 'open': [10.5], 'high': [11.0], 'low': [10.2], 'close': [10.8], 'pre_close': [10.3], 'change': [0.5], 'pct_chg': [4.85], 'vol': [1000000], 'amount': [10800000], }) with patch.object(TushareClient, '__init__', lambda self, token=None: None): with patch.object(TushareClient, 'query', return_value=mock_data): result = get_daily('600000.SH') print(f"获取数据形状: {result.shape}") print(f"获取数据列: {result.columns.tolist()}") print(f"期望基础列: {EXPECTED_BASE_FIELDS}") missing = set(EXPECTED_BASE_FIELDS) - set(result.columns) print(f"缺失列: {missing if missing else '无'}") tests_passed = set(EXPECTED_BASE_FIELDS).issubset(result.columns.tolist()) print(f"\n测试结果: {'通过 ✓' if tests_passed else '失败 ✗'}") test_results.append(("输出字段完整性", tests_passed)) # Test 4: Empty result print("\n【测试4】空结果处理") print("-" * 40) mock_data = pd.DataFrame() with patch.object(TushareClient, '__init__', lambda self, token=None: None): with patch.object(TushareClient, 'query', return_value=mock_data): result = get_daily('INVALID.SZ') print(f"获取数据是否为空: {result.empty}") tests_passed = result.empty print(f"\n测试结果: {'通过 ✓' if tests_passed else '失败 ✗'}") test_results.append(("空结果处理", tests_passed)) # Test 5: Date range query print("\n【测试5】日期范围查询") print("-" * 40) mock_data = pd.DataFrame({ 'ts_code': ['000001.SZ', '000001.SZ'], 'trade_date': ['20240102', '20240103'], 'open': [10.5, 10.6], 'high': [11.0, 11.1], 'low': [10.2, 10.3], 'close': [10.8, 10.9], 'pre_close': [10.3, 10.8], 'change': [0.5, 0.1], 'pct_chg': [4.85, 0.93], 'vol': [1000000, 1100000], 'amount': [10800000, 11900000], }) with patch.object(TushareClient, '__init__', lambda self, token=None: None): with patch.object(TushareClient, 'query', return_value=mock_data): result = get_daily( '000001.SZ', start_date='20240101', end_date='20240131', ) print(f"获取数据数量: {len(result)}") print(f"期望数量: 2") print(f"数据内容:\n{result}") tests_passed = len(result) == 2 print(f"\n测试结果: {'通过 ✓' if tests_passed else '失败 ✗'}") test_results.append(("日期范围查询", tests_passed)) # Test 6: With adjustment print("\n【测试6】带复权参数查询") print("-" * 40) mock_data = pd.DataFrame({ 'ts_code': ['000001.SZ'], 'trade_date': ['20240102'], 'open': [10.5], 'high': [11.0], 'low': [10.2], 'close': [10.8], 'pre_close': [10.3], 'change': [0.5], 'pct_chg': [4.85], 'vol': [1000000], 'amount': [10800000], }) with patch.object(TushareClient, '__init__', lambda self, token=None: None): with patch.object(TushareClient, 'query', return_value=mock_data): result = get_daily('000001.SZ', adj='qfq') print(f"获取数据形状: {result.shape}") print(f"数据内容:\n{result}") tests_passed = isinstance(result, pd.DataFrame) print(f"\n测试结果: {'通过 ✓' if tests_passed else '失败 ✗'}") test_results.append(("复权参数查询", tests_passed)) # Summary print("\n" + "=" * 60) print("测试汇总") print("=" * 60) passed = sum(1 for _, r in test_results if r) total = len(test_results) print(f"总测试数: {total}") print(f"通过: {passed}") print(f"失败: {total - passed}") print(f"通过率: {passed/total*100:.1f}%") print("\n详细结果:") for name, passed in test_results: status = "通过 ✓" if passed else "失败 ✗" print(f" - {name}: {status}") return all(r for _, r in test_results) class TestGetDaily: """Test cases for simplified get_daily function.""" def test_fetch_basic(self): """Test basic daily data fetch.""" mock_data = pd.DataFrame({ 'ts_code': ['000001.SZ'], 'trade_date': ['20240102'], 'open': [10.5], 'high': [11.0], 'low': [10.2], 'close': [10.8], 'pre_close': [10.3], 'change': [0.5], 'pct_chg': [4.85], 'vol': [1000000], 'amount': [10800000], }) with patch.object(TushareClient, '__init__', lambda self, token=None: None): with patch.object(TushareClient, 'query', return_value=mock_data): result = get_daily('000001.SZ', start_date='20240101', end_date='20240131') assert isinstance(result, pd.DataFrame) assert len(result) == 1 assert result['ts_code'].iloc[0] == '000001.SZ' def test_fetch_with_factors(self): """Test fetch with tor and vr factors.""" mock_data = pd.DataFrame({ 'ts_code': ['000001.SZ'], 'trade_date': ['20240102'], 'open': [10.5], 'high': [11.0], 'low': [10.2], 'close': [10.8], 'pre_close': [10.3], 'change': [0.5], 'pct_chg': [4.85], 'vol': [1000000], 'amount': [10800000], 'tor': [2.5], # 换手率 'vr': [1.2], # 量比 }) with patch.object(TushareClient, '__init__', lambda self, token=None: None): with patch.object(TushareClient, 'query', return_value=mock_data): result = get_daily( '000001.SZ', start_date='20240101', end_date='20240131', factors=['tor', 'vr'], ) assert isinstance(result, pd.DataFrame) # Check all base fields are present for field in EXPECTED_BASE_FIELDS: assert field in result.columns, f"Missing base field: {field}" # Check factor fields are present for field in EXPECTED_FACTOR_FIELDS: assert field in result.columns, f"Missing factor field: {field}" def test_output_fields_completeness(self): """Verify all required output fields are returned.""" mock_data = pd.DataFrame({ 'ts_code': ['600000.SH'], 'trade_date': ['20240102'], 'open': [10.5], 'high': [11.0], 'low': [10.2], 'close': [10.8], 'pre_close': [10.3], 'change': [0.5], 'pct_chg': [4.85], 'vol': [1000000], 'amount': [10800000], }) with patch.object(TushareClient, '__init__', lambda self, token=None: None): with patch.object(TushareClient, 'query', return_value=mock_data): result = get_daily('600000.SH') # Verify all base fields are present assert set(EXPECTED_BASE_FIELDS).issubset(result.columns.tolist()), \ f"Missing fields: {set(EXPECTED_BASE_FIELDS) - set(result.columns)}" def test_empty_result(self): """Test handling of empty results.""" mock_data = pd.DataFrame() with patch.object(TushareClient, '__init__', lambda self, token=None: None): with patch.object(TushareClient, 'query', return_value=mock_data): result = get_daily('INVALID.SZ') assert result.empty def test_date_range_query(self): """Test query with date range.""" mock_data = pd.DataFrame({ 'ts_code': ['000001.SZ', '000001.SZ'], 'trade_date': ['20240102', '20240103'], 'open': [10.5, 10.6], 'high': [11.0, 11.1], 'low': [10.2, 10.3], 'close': [10.8, 10.9], 'pre_close': [10.3, 10.8], 'change': [0.5, 0.1], 'pct_chg': [4.85, 0.93], 'vol': [1000000, 1100000], 'amount': [10800000, 11900000], }) with patch.object(TushareClient, '__init__', lambda self, token=None: None): with patch.object(TushareClient, 'query', return_value=mock_data): result = get_daily( '000001.SZ', start_date='20240101', end_date='20240131', ) assert len(result) == 2 def test_with_adj(self): """Test fetch with adjustment type.""" mock_data = pd.DataFrame({ 'ts_code': ['000001.SZ'], 'trade_date': ['20240102'], 'open': [10.5], 'high': [11.0], 'low': [10.2], 'close': [10.8], 'pre_close': [10.3], 'change': [0.5], 'pct_chg': [4.85], 'vol': [1000000], 'amount': [10800000], }) with patch.object(TushareClient, '__init__', lambda self, token=None: None): with patch.object(TushareClient, 'query', return_value=mock_data): result = get_daily('000001.SZ', adj='qfq') assert isinstance(result, pd.DataFrame) def test_integration(): """Integration test with real Tushare API (requires valid token).""" import os token = os.environ.get('TUSHARE_TOKEN') if not token: pytest.skip("TUSHARE_TOKEN not configured") result = get_daily('000001.SZ', start_date='20240101', end_date='20240131', factors=['tor', 'vr']) # Verify structure assert isinstance(result, pd.DataFrame) if not result.empty: # Check base fields for field in EXPECTED_BASE_FIELDS: assert field in result.columns, f"Missing base field: {field}" # Check factor fields for field in EXPECTED_FACTOR_FIELDS: assert field in result.columns, f"Missing factor field: {field}" if __name__ == '__main__': # 运行详细的打印测试 run_tests_with_print() print("\n" + "=" * 60) print("运行 pytest 单元测试") print("=" * 60 + "\n") pytest.main([__file__, '-v'])