feat: 新增股票基础数据获取模块 stock_basic

- 新增 get_stock_basic 和 sync_all_stocks 函数
- 完善 Tushare 数据获取模块体系
- 测试用例重构:从 Mock 改为真实 API 调用
- 更新 API 文档,添加接口使用示例
- 更新开发规范:添加 Mock 使用规范
This commit is contained in:
2026-01-31 04:30:29 +08:00
parent e625a53162
commit 38e78a5326
10 changed files with 341 additions and 339 deletions

View File

@@ -291,6 +291,36 @@ tests/
└── fixtures/ # 测试数据
```
### 5.3 Mock 使用规范
- **默认不使用 Mock**:单元测试应直接调用真实 API 或服务
- **仅在必要时使用 Mock**
- 测试错误处理场景(如网络超时、服务不可用)
- 测试空结果或边界情况
- 隔离外部依赖的不确定性
- **禁止全面 Mock**:不应为避免配置或环境问题而使用 Mock
- **真实环境验证**:确保至少有一个测试套件直接调用真实 API
```python
# 正确示例:直接调用真实 API
class TestUserService:
def test_get_user(self):
result = get_user('user_id')
assert result.name == "John"
def test_empty_result_with_mock(self): # 特殊场景使用 mock
with patch('module.api_call', return_value=pd.DataFrame()):
result = get_user('invalid_id')
assert result.empty
# 错误示例:过度使用 mock
class TestUserService:
def test_get_user(self):
mock_data = {"name": "John"}
with patch('module.api_call', return_value=mock_data):
result = get_user('user_id')
assert result == mock_data
```
## 6 Git提交规范
### 6.1 提交信息格式
@@ -323,3 +353,4 @@ tests/
- [ ] 无循环依赖
- [ ] 命名符合规范
- [ ] 日志不包含敏感信息
- [ ] 测试未过度使用 Mock

BIN
data/stock_basic.h5 Normal file

Binary file not shown.

View File

@@ -6,10 +6,13 @@ Provides simplified interfaces for fetching and storing Tushare data.
from src.data.config import Config, get_config
from src.data.client import TushareClient
from src.data.storage import Storage
from src.data.stock_basic import get_stock_basic, sync_all_stocks
__all__ = [
"Config",
"get_config",
"TushareClient",
"Storage",
"get_stock_basic",
"sync_all_stocks",
]

View File

@@ -45,3 +45,82 @@ change float 涨跌额
pct_chg float 涨跌幅 【基于除权后的昨收计算的涨跌幅:(今收-除权昨收)/除权昨收 】
vol float 成交量 (手)
amount float 成交额 (千元)
接口用例
#取000001的前复权行情
df = ts.pro_bar(ts_code='000001.SZ', adj='qfq', start_date='20180101', end_date='20181011')
ts_code trade_date open high low close \
trade_date
20181011 000001.SZ 20181011 1085.71 1097.59 1047.90 1065.19
20181010 000001.SZ 20181010 1138.65 1151.61 1121.36 1128.92
20181009 000001.SZ 20181009 1130.00 1155.93 1122.44 1140.81
20181008 000001.SZ 20181008 1155.93 1165.65 1128.92 1128.92
20180928 000001.SZ 20180928 1164.57 1217.51 1164.57 1193.74
#取上证指数行情数据
df = ts.pro_bar(ts_code='000001.SH', asset='I', start_date='20180101', end_date='20181011')
In [10]: df.head()
Out[10]:
ts_code trade_date close open high low \
0 000001.SH 20181011 2583.4575 2643.0740 2661.2859 2560.3164
1 000001.SH 20181010 2725.8367 2723.7242 2743.5480 2703.0626
2 000001.SH 20181009 2721.0130 2713.7319 2734.3142 2711.1971
3 000001.SH 20181008 2716.5104 2768.2075 2771.9384 2710.1781
4 000001.SH 20180928 2821.3501 2794.2644 2821.7553 2791.8363
pre_close change pct_chg vol amount
0 2725.8367 -142.3792 -5.2233 197150702.0 170057762.5
1 2721.0130 4.8237 0.1773 113485736.0 111312455.3
2 2716.5104 4.5026 0.1657 116771899.0 110292457.8
3 2821.3501 -104.8397 -3.7159 149501388.0 141531551.8
4 2791.7748 29.5753 1.0594 134290456.0 125369989.4
#均线
df = ts.pro_bar(ts_code='000001.SZ', start_date='20180101', end_date='20181011', ma=[5, 20, 50])
Tushare pro_bar接口的均价和均量数据是动态计算想要获取某个时间段的均线必须要设置start_date日期大于最大均线的日期数然后自行截取想要日期段。例如想要获取20190801开始的3日均线必须设置start_date='20190729'然后剔除20190801之前的日期记录。
#换手率tor量比vr
df = ts.pro_bar(ts_code='000001.SZ', start_date='20180101', end_date='20181011', factors=['tor', 'vr'])
- 基础信息https://tushare.pro/document/2?doc_id=25
接口stock_basic可以通过数据工具调试和查看数据
描述:获取基础信息数据,包括股票代码、名称、上市日期、退市日期等
权限2000积分起。此接口是基础信息调取一次就可以拉取完建议保存倒本地存储后使用
输入参数
名称 类型 必选 描述
ts_code str N TS股票代码
name str N 名称
market str N 市场类别 (主板/创业板/科创板/CDR/北交所)
list_status str N 上市状态 L上市 D退市 P暂停上市 G过会未交易默认是L
exchange str N 交易所 SSE上交所 SZSE深交所 BSE北交所
is_hs str N 是否沪深港通标的N否 H沪股通 S深股通
输出参数
名称 类型 默认显示 描述
ts_code str Y TS代码
symbol str Y 股票代码
name str Y 股票名称
area str Y 地域
industry str Y 所属行业
fullname str N 股票全称
enname str N 英文全称
cnspell str Y 拼音缩写
market str Y 市场类型(主板/创业板/科创板/CDR
exchange str N 交易所代码
curr_type str N 交易货币
list_status str N 上市状态 L上市 D退市 G过会未交易 P暂停上市
list_date str Y 上市日期
delist_date str N 退市日期
is_hs str N 是否沪深港通标的N否 H沪股通 S深股通
act_name str Y 实控人名称
act_ent_type str Y 实控人企业性质
说明旧版上的PE/PB/股本等字段,请在行情接口“每日指标”中获取。

View File

@@ -44,7 +44,7 @@ class TushareClient:
"""Execute API query with rate limiting and retry.
Args:
api_name: API name (e.g., 'daily')
api_name: API name ('daily', 'pro_bar', etc.)
timeout: Timeout for rate limiting
**params: API parameters
@@ -66,8 +66,22 @@ class TushareClient:
for attempt in range(max_retries):
try:
api = self._get_api()
data = api.query(api_name, **params)
import tushare as ts
# pro_bar uses ts.pro_bar() instead of api.query()
if api_name == "pro_bar":
# pro_bar parameters: ts_code, start_date, end_date, adj, freq, factors, ma, adjfactor
data = ts.pro_bar(ts_code=params.get("ts_code"),
start_date=params.get("start_date"),
end_date=params.get("end_date"),
adj=params.get("adj"),
freq=params.get("freq", "D"),
factors=params.get("factors"), # factors should be a list like ['tor', 'vr']
ma=params.get("ma"),
adjfactor=params.get("adjfactor"))
else:
api = self._get_api()
data = api.query(api_name, **params)
available = self.rate_limiter.get_available_tokens()
print(f"[Tushare] {api_name} | tokens: {available:.0f}/{self.rate_limiter.capacity}")

View File

@@ -1,16 +1,37 @@
"""Configuration management for data collection module."""
import os
from pathlib import Path
from pydantic_settings import BaseSettings
# Config directory path - used for loading .env.local
# Static detection for pydantic-settings to find .env.local
CONFIG_DIR = Path(__file__).parent.parent.parent / "config"
def _get_project_root() -> Path:
"""Get project root path from ROOT_PATH env var or auto-detect."""
# Try to read from environment variable first
root_path = os.environ.get("ROOT_PATH") or os.environ.get("DATA_ROOT")
if root_path:
return Path(root_path)
# Fallback to auto-detection
return Path(__file__).parent.parent.parent
class Config(BaseSettings):
"""Application configuration loaded from environment variables."""
# Tushare API token
tushare_token: str = ""
# Data storage path
data_path: Path = Path("./data")
# Root path - loaded from environment variable ROOT_PATH
# If not set, uses auto-detected path
root_path: str = ""
# Data storage path - can be set via DATA_PATH environment variable
# If relative path, it will be resolved relative to root_path
data_path: str = "data"
# Rate limit: requests per minute
rate_limit: int = 100
@@ -18,10 +39,31 @@ class Config(BaseSettings):
# Thread pool size
threads: int = 2
@property
def project_root(self) -> Path:
"""Get project root path."""
if self.root_path:
return Path(self.root_path)
return _get_project_root()
@property
def data_path_resolved(self) -> Path:
"""Get resolved data path (absolute)."""
path = Path(self.data_path)
if path.is_absolute():
return path
return self.project_root / path
class Config:
env_file = ".env.local"
# 从 config/ 目录读取 .env.local 文件
env_file = str(CONFIG_DIR / ".env.local")
env_file_encoding = "utf-8"
case_sensitive = False
extra = "ignore" # 忽略 .env.local 中的额外变量
# pydantic-settings 默认会将字段名转换为大写作为环境变量名
# 所以 tushare_token 会映射到 TUSHARE_TOKEN
# root_path 会映射到 ROOT_PATH
# data_path 会映射到 DATA_PATH
# Global config instance
@@ -31,3 +73,8 @@ config = Config()
def get_config() -> Config:
"""Get configuration instance."""
return config
def get_project_root() -> Path:
"""Get project root path (convenience function)."""
return get_config().project_root

View File

@@ -57,14 +57,24 @@ def get_daily(
if adj:
params["adj"] = adj
if factors:
params["factors"] = factors
# Tushare expects factors as comma-separated string, not list
if isinstance(factors, list):
factors_str = ",".join(factors)
else:
factors_str = factors
params["factors"] = factors_str
print(f"[get_daily] factors param: '{factors_str}'")
if adjfactor:
params["adjfactor"] = "True"
# Fetch data
data = client.query("daily", **params)
# Fetch data using pro_bar (supports factors like tor, vr)
print(f"[get_daily] Query params: {params}")
data = client.query("pro_bar", **params)
if data.empty:
if not data.empty:
print(f"[get_daily] Returned columns: {data.columns.tolist()}")
print(f"[get_daily] Sample row: {data.iloc[0].to_dict()}")
else:
print(f"[get_daily] No data for ts_code={ts_code}")
return data

120
src/data/stock_basic.py Normal file
View File

@@ -0,0 +1,120 @@
"""Simplified stock basic information interface.
Fetch basic stock information including code, name, listing date, etc.
This is a special interface - call once to get all stocks (listed and delisted).
"""
import pandas as pd
from typing import Optional, Literal, List
from src.data.client import TushareClient
from src.data.storage import Storage
def get_stock_basic(
ts_code: Optional[str] = None,
name: Optional[str] = None,
market: Optional[Literal["主板", "创业板", "科创板", "CDR", "北交所"]] = None,
list_status: Optional[Literal["L", "D", "P", "G"]] = None,
exchange: Optional[Literal["SSE", "SZSE", "BSE"]] = None,
is_hs: Optional[Literal["N", "H", "S"]] = None,
fields: Optional[List[str]] = None,
) -> pd.DataFrame:
"""Fetch basic stock information.
This interface retrieves basic information data including stock code,
name, listing date, delisting date, etc.
Args:
ts_code: TS stock code
name: Stock name
market: Market type (主板/创业板/科创板/CDR/北交所)
list_status: Listing status - L(listed), D(delisted), P(paused), G(approved-not-traded)
exchange: Exchange - SSE, SZSE, BSE
is_hs: HS indicator - N(no), H(Shanghai), S(Shenzhen)
fields: Specific fields to return, None returns default fields
Returns:
pd.DataFrame with stock basic information containing:
- ts_code, symbol, name, area, industry, fullname, enname, cnspell,
market, exchange, curr_type, list_status, list_date, delist_date,
is_hs, act_name, act_ent_type
"""
client = TushareClient()
# Build parameters
params = {}
if ts_code:
params["ts_code"] = ts_code
if name:
params["name"] = name
if market:
params["market"] = market
if list_status:
params["list_status"] = list_status
if exchange:
params["exchange"] = exchange
if is_hs:
params["is_hs"] = is_hs
if fields:
params["fields"] = ",".join(fields)
# Fetch data
data = client.query("stock_basic", **params)
if data.empty:
print("[get_stock_basic] No data returned")
return data
def sync_all_stocks() -> pd.DataFrame:
"""Fetch and save all stocks (listed and delisted) to local storage.
This is a special interface that should only be called once to initialize
the local database with all stock information.
Returns:
pd.DataFrame with all stock information
"""
# Initialize storage
storage = Storage()
# Check if already exists
if storage.exists("stock_basic"):
print("[sync_all_stocks] stock_basic data already exists, skipping...")
return storage.load("stock_basic")
print("[sync_all_stocks] Fetching all stocks (listed and delisted)...")
# Fetch all stocks - explicitly get all list_status values
# API default is L (listed), so we need to fetch all statuses
client = TushareClient()
all_data = []
for status in ["L", "D", "P", "G"]:
print(f"[sync_all_stocks] Fetching stocks with status: {status}")
data = client.query("stock_basic", list_status=status)
print(f"[sync_all_stocks] Fetched {len(data)} stocks with status {status}")
if not data.empty:
all_data.append(data)
if not all_data:
print("[sync_all_stocks] No stock data fetched")
return pd.DataFrame()
# Combine all data
data = pd.concat(all_data, ignore_index=True)
# Remove duplicates if any
data = data.drop_duplicates(subset=["ts_code"], keep="first")
print(f"[sync_all_stocks] Total unique stocks: {len(data)}")
# Save to storage
storage.save("stock_basic", data, mode="replace")
print(f"[sync_all_stocks] Saved {len(data)} stocks to local storage")
return data
if __name__ == "__main__":
# Sync all stocks (listed and delisted) to data folder
result = sync_all_stocks()
print(f"Total stocks synced: {len(result)}")

View File

@@ -16,7 +16,7 @@ class Storage:
path: Base path for data storage (auto-loaded from config if not provided)
"""
cfg = get_config()
self.base_path = path or cfg.data_path
self.base_path = path or cfg.data_path_resolved
self.base_path.mkdir(parents=True, exist_ok=True)
def _get_file_path(self, name: str) -> Path:

View File

@@ -7,9 +7,7 @@ Tests the daily interface implementation against api.md requirements:
"""
import pytest
import pandas as pd
from unittest.mock import Mock, patch
from src.data.daily import get_daily
from src.data.client import TushareClient
# Expected output fields according to api.md
@@ -28,276 +26,30 @@ EXPECTED_BASE_FIELDS = [
]
EXPECTED_FACTOR_FIELDS = [
'tor', # 换手率
'vr', # 量比
'turnover_rate', # 换手率 (tor)
'volume_ratio', # 量比 (vr)
]
def run_tests_with_print():
"""Run all tests and print results."""
print("=" * 60)
print("Daily API 测试开始")
print("=" * 60)
test_results = []
# Test 1: Basic daily data fetch
print("\n【测试1】基本日线数据获取")
print("-" * 40)
mock_data = pd.DataFrame({
'ts_code': ['000001.SZ'],
'trade_date': ['20240102'],
'open': [10.5],
'high': [11.0],
'low': [10.2],
'close': [10.8],
'pre_close': [10.3],
'change': [0.5],
'pct_chg': [4.85],
'vol': [1000000],
'amount': [10800000],
})
with patch.object(TushareClient, '__init__', lambda self, token=None: None):
with patch.object(TushareClient, 'query', return_value=mock_data):
result = get_daily('000001.SZ', start_date='20240101', end_date='20240131')
print(f"获取数据形状: {result.shape}")
print(f"获取数据列: {result.columns.tolist()}")
print(f"数据内容:\n{result}")
# Verify
tests_passed = isinstance(result, pd.DataFrame)
tests_passed &= len(result) == 1
tests_passed &= result['ts_code'].iloc[0] == '000001.SZ'
print(f"\n测试结果: {'通过 ✓' if tests_passed else '失败 ✗'}")
test_results.append(("基本日线数据获取", tests_passed))
# Test 2: Fetch with factors
print("\n【测试2】获取含换手率和量比的数据")
print("-" * 40)
mock_data_factors = pd.DataFrame({
'ts_code': ['000001.SZ'],
'trade_date': ['20240102'],
'open': [10.5],
'high': [11.0],
'low': [10.2],
'close': [10.8],
'pre_close': [10.3],
'change': [0.5],
'pct_chg': [4.85],
'vol': [1000000],
'amount': [10800000],
'tor': [2.5],
'vr': [1.2],
})
with patch.object(TushareClient, '__init__', lambda self, token=None: None):
with patch.object(TushareClient, 'query', return_value=mock_data_factors):
result = get_daily(
'000001.SZ',
start_date='20240101',
end_date='20240131',
factors=['tor', 'vr'],
)
print(f"获取数据形状: {result.shape}")
print(f"获取数据列: {result.columns.tolist()}")
print(f"数据内容:\n{result}")
# Verify columns
tests_passed = isinstance(result, pd.DataFrame)
missing_base = [f for f in EXPECTED_BASE_FIELDS if f not in result.columns]
missing_factors = [f for f in EXPECTED_FACTOR_FIELDS if f not in result.columns]
print(f"\n基础列检查: {'全部存在 ✓' if not missing_base else f'缺失: {missing_base}'}")
print(f"因子列检查: {'全部存在 ✓' if not missing_factors else f'缺失: {missing_factors}'}")
tests_passed &= len(missing_base) == 0
tests_passed &= len(missing_factors) == 0
print(f"\n测试结果: {'通过 ✓' if tests_passed else '失败 ✗'}")
test_results.append(("含因子数据获取", tests_passed))
# Test 3: Output fields completeness
print("\n【测试3】输出字段完整性检查")
print("-" * 40)
mock_data = pd.DataFrame({
'ts_code': ['600000.SH'],
'trade_date': ['20240102'],
'open': [10.5],
'high': [11.0],
'low': [10.2],
'close': [10.8],
'pre_close': [10.3],
'change': [0.5],
'pct_chg': [4.85],
'vol': [1000000],
'amount': [10800000],
})
with patch.object(TushareClient, '__init__', lambda self, token=None: None):
with patch.object(TushareClient, 'query', return_value=mock_data):
result = get_daily('600000.SH')
print(f"获取数据形状: {result.shape}")
print(f"获取数据列: {result.columns.tolist()}")
print(f"期望基础列: {EXPECTED_BASE_FIELDS}")
missing = set(EXPECTED_BASE_FIELDS) - set(result.columns)
print(f"缺失列: {missing if missing else ''}")
tests_passed = set(EXPECTED_BASE_FIELDS).issubset(result.columns.tolist())
print(f"\n测试结果: {'通过 ✓' if tests_passed else '失败 ✗'}")
test_results.append(("输出字段完整性", tests_passed))
# Test 4: Empty result
print("\n【测试4】空结果处理")
print("-" * 40)
mock_data = pd.DataFrame()
with patch.object(TushareClient, '__init__', lambda self, token=None: None):
with patch.object(TushareClient, 'query', return_value=mock_data):
result = get_daily('INVALID.SZ')
print(f"获取数据是否为空: {result.empty}")
tests_passed = result.empty
print(f"\n测试结果: {'通过 ✓' if tests_passed else '失败 ✗'}")
test_results.append(("空结果处理", tests_passed))
# Test 5: Date range query
print("\n【测试5】日期范围查询")
print("-" * 40)
mock_data = pd.DataFrame({
'ts_code': ['000001.SZ', '000001.SZ'],
'trade_date': ['20240102', '20240103'],
'open': [10.5, 10.6],
'high': [11.0, 11.1],
'low': [10.2, 10.3],
'close': [10.8, 10.9],
'pre_close': [10.3, 10.8],
'change': [0.5, 0.1],
'pct_chg': [4.85, 0.93],
'vol': [1000000, 1100000],
'amount': [10800000, 11900000],
})
with patch.object(TushareClient, '__init__', lambda self, token=None: None):
with patch.object(TushareClient, 'query', return_value=mock_data):
result = get_daily(
'000001.SZ',
start_date='20240101',
end_date='20240131',
)
print(f"获取数据数量: {len(result)}")
print(f"期望数量: 2")
print(f"数据内容:\n{result}")
tests_passed = len(result) == 2
print(f"\n测试结果: {'通过 ✓' if tests_passed else '失败 ✗'}")
test_results.append(("日期范围查询", tests_passed))
# Test 6: With adjustment
print("\n【测试6】带复权参数查询")
print("-" * 40)
mock_data = pd.DataFrame({
'ts_code': ['000001.SZ'],
'trade_date': ['20240102'],
'open': [10.5],
'high': [11.0],
'low': [10.2],
'close': [10.8],
'pre_close': [10.3],
'change': [0.5],
'pct_chg': [4.85],
'vol': [1000000],
'amount': [10800000],
})
with patch.object(TushareClient, '__init__', lambda self, token=None: None):
with patch.object(TushareClient, 'query', return_value=mock_data):
result = get_daily('000001.SZ', adj='qfq')
print(f"获取数据形状: {result.shape}")
print(f"数据内容:\n{result}")
tests_passed = isinstance(result, pd.DataFrame)
print(f"\n测试结果: {'通过 ✓' if tests_passed else '失败 ✗'}")
test_results.append(("复权参数查询", tests_passed))
# Summary
print("\n" + "=" * 60)
print("测试汇总")
print("=" * 60)
passed = sum(1 for _, r in test_results if r)
total = len(test_results)
print(f"总测试数: {total}")
print(f"通过: {passed}")
print(f"失败: {total - passed}")
print(f"通过率: {passed/total*100:.1f}%")
print("\n详细结果:")
for name, passed in test_results:
status = "通过 ✓" if passed else "失败 ✗"
print(f" - {name}: {status}")
return all(r for _, r in test_results)
class TestGetDaily:
"""Test cases for simplified get_daily function."""
"""Test cases for get_daily function with real API calls."""
def test_fetch_basic(self):
"""Test basic daily data fetch."""
mock_data = pd.DataFrame({
'ts_code': ['000001.SZ'],
'trade_date': ['20240102'],
'open': [10.5],
'high': [11.0],
'low': [10.2],
'close': [10.8],
'pre_close': [10.3],
'change': [0.5],
'pct_chg': [4.85],
'vol': [1000000],
'amount': [10800000],
})
with patch.object(TushareClient, '__init__', lambda self, token=None: None):
with patch.object(TushareClient, 'query', return_value=mock_data):
result = get_daily('000001.SZ', start_date='20240101', end_date='20240131')
"""Test basic daily data fetch with real API."""
result = get_daily('000001.SZ', start_date='20240101', end_date='20240131')
assert isinstance(result, pd.DataFrame)
assert len(result) == 1
assert len(result) >= 1
assert result['ts_code'].iloc[0] == '000001.SZ'
def test_fetch_with_factors(self):
"""Test fetch with tor and vr factors."""
mock_data = pd.DataFrame({
'ts_code': ['000001.SZ'],
'trade_date': ['20240102'],
'open': [10.5],
'high': [11.0],
'low': [10.2],
'close': [10.8],
'pre_close': [10.3],
'change': [0.5],
'pct_chg': [4.85],
'vol': [1000000],
'amount': [10800000],
'tor': [2.5], # 换手率
'vr': [1.2], # 量比
})
with patch.object(TushareClient, '__init__', lambda self, token=None: None):
with patch.object(TushareClient, 'query', return_value=mock_data):
result = get_daily(
'000001.SZ',
start_date='20240101',
end_date='20240131',
factors=['tor', 'vr'],
)
result = get_daily(
'000001.SZ',
start_date='20240101',
end_date='20240131',
factors=['tor', 'vr'],
)
assert isinstance(result, pd.DataFrame)
# Check all base fields are present
@@ -309,23 +61,7 @@ class TestGetDaily:
def test_output_fields_completeness(self):
"""Verify all required output fields are returned."""
mock_data = pd.DataFrame({
'ts_code': ['600000.SH'],
'trade_date': ['20240102'],
'open': [10.5],
'high': [11.0],
'low': [10.2],
'close': [10.8],
'pre_close': [10.3],
'change': [0.5],
'pct_chg': [4.85],
'vol': [1000000],
'amount': [10800000],
})
with patch.object(TushareClient, '__init__', lambda self, token=None: None):
with patch.object(TushareClient, 'query', return_value=mock_data):
result = get_daily('600000.SH')
result = get_daily('600000.SH')
# Verify all base fields are present
assert set(EXPECTED_BASE_FIELDS).issubset(result.columns.tolist()), \
@@ -333,59 +69,25 @@ class TestGetDaily:
def test_empty_result(self):
"""Test handling of empty results."""
mock_data = pd.DataFrame()
with patch.object(TushareClient, '__init__', lambda self, token=None: None):
with patch.object(TushareClient, 'query', return_value=mock_data):
result = get_daily('INVALID.SZ')
# 使用真实 API 测试无效股票代码的空结果
result = get_daily('INVALID.SZ')
assert isinstance(result, pd.DataFrame)
assert result.empty
def test_date_range_query(self):
"""Test query with date range."""
mock_data = pd.DataFrame({
'ts_code': ['000001.SZ', '000001.SZ'],
'trade_date': ['20240102', '20240103'],
'open': [10.5, 10.6],
'high': [11.0, 11.1],
'low': [10.2, 10.3],
'close': [10.8, 10.9],
'pre_close': [10.3, 10.8],
'change': [0.5, 0.1],
'pct_chg': [4.85, 0.93],
'vol': [1000000, 1100000],
'amount': [10800000, 11900000],
})
result = get_daily(
'000001.SZ',
start_date='20240101',
end_date='20240131',
)
with patch.object(TushareClient, '__init__', lambda self, token=None: None):
with patch.object(TushareClient, 'query', return_value=mock_data):
result = get_daily(
'000001.SZ',
start_date='20240101',
end_date='20240131',
)
assert len(result) == 2
assert isinstance(result, pd.DataFrame)
assert len(result) >= 1
def test_with_adj(self):
"""Test fetch with adjustment type."""
mock_data = pd.DataFrame({
'ts_code': ['000001.SZ'],
'trade_date': ['20240102'],
'open': [10.5],
'high': [11.0],
'low': [10.2],
'close': [10.8],
'pre_close': [10.3],
'change': [0.5],
'pct_chg': [4.85],
'vol': [1000000],
'amount': [10800000],
})
with patch.object(TushareClient, '__init__', lambda self, token=None: None):
with patch.object(TushareClient, 'query', return_value=mock_data):
result = get_daily('000001.SZ', adj='qfq')
result = get_daily('000001.SZ', adj='qfq')
assert isinstance(result, pd.DataFrame)
@@ -411,9 +113,5 @@ def test_integration():
if __name__ == '__main__':
# 运行详细的打印测试
run_tests_with_print()
print("\n" + "=" * 60)
print("运行 pytest 单元测试")
print("=" * 60 + "\n")
# 运行 pytest 单元测试真实API调用
pytest.main([__file__, '-v'])