feat: 新增股票基础数据获取模块 stock_basic
- 新增 get_stock_basic 和 sync_all_stocks 函数 - 完善 Tushare 数据获取模块体系 - 测试用例重构:从 Mock 改为真实 API 调用 - 更新 API 文档,添加接口使用示例 - 更新开发规范:添加 Mock 使用规范
This commit is contained in:
@@ -291,6 +291,36 @@ tests/
|
||||
└── fixtures/ # 测试数据
|
||||
```
|
||||
|
||||
### 5.3 Mock 使用规范
|
||||
- **默认不使用 Mock**:单元测试应直接调用真实 API 或服务
|
||||
- **仅在必要时使用 Mock**:
|
||||
- 测试错误处理场景(如网络超时、服务不可用)
|
||||
- 测试空结果或边界情况
|
||||
- 隔离外部依赖的不确定性
|
||||
- **禁止全面 Mock**:不应为避免配置或环境问题而使用 Mock
|
||||
- **真实环境验证**:确保至少有一个测试套件直接调用真实 API
|
||||
|
||||
```python
|
||||
# 正确示例:直接调用真实 API
|
||||
class TestUserService:
|
||||
def test_get_user(self):
|
||||
result = get_user('user_id')
|
||||
assert result.name == "John"
|
||||
|
||||
def test_empty_result_with_mock(self): # 特殊场景使用 mock
|
||||
with patch('module.api_call', return_value=pd.DataFrame()):
|
||||
result = get_user('invalid_id')
|
||||
assert result.empty
|
||||
|
||||
# 错误示例:过度使用 mock
|
||||
class TestUserService:
|
||||
def test_get_user(self):
|
||||
mock_data = {"name": "John"}
|
||||
with patch('module.api_call', return_value=mock_data):
|
||||
result = get_user('user_id')
|
||||
assert result == mock_data
|
||||
```
|
||||
|
||||
## 6 Git提交规范
|
||||
|
||||
### 6.1 提交信息格式
|
||||
@@ -323,3 +353,4 @@ tests/
|
||||
- [ ] 无循环依赖
|
||||
- [ ] 命名符合规范
|
||||
- [ ] 日志不包含敏感信息
|
||||
- [ ] 测试未过度使用 Mock
|
||||
|
||||
BIN
data/stock_basic.h5
Normal file
BIN
data/stock_basic.h5
Normal file
Binary file not shown.
@@ -6,10 +6,13 @@ Provides simplified interfaces for fetching and storing Tushare data.
|
||||
from src.data.config import Config, get_config
|
||||
from src.data.client import TushareClient
|
||||
from src.data.storage import Storage
|
||||
from src.data.stock_basic import get_stock_basic, sync_all_stocks
|
||||
|
||||
__all__ = [
|
||||
"Config",
|
||||
"get_config",
|
||||
"TushareClient",
|
||||
"Storage",
|
||||
"get_stock_basic",
|
||||
"sync_all_stocks",
|
||||
]
|
||||
|
||||
@@ -44,4 +44,83 @@ pre_close float 昨收价【除权价】
|
||||
change float 涨跌额
|
||||
pct_chg float 涨跌幅 【基于除权后的昨收计算的涨跌幅:(今收-除权昨收)/除权昨收 】
|
||||
vol float 成交量 (手)
|
||||
amount float 成交额 (千元)
|
||||
amount float 成交额 (千元)
|
||||
|
||||
接口用例
|
||||
|
||||
#取000001的前复权行情
|
||||
df = ts.pro_bar(ts_code='000001.SZ', adj='qfq', start_date='20180101', end_date='20181011')
|
||||
|
||||
ts_code trade_date open high low close \
|
||||
trade_date
|
||||
20181011 000001.SZ 20181011 1085.71 1097.59 1047.90 1065.19
|
||||
20181010 000001.SZ 20181010 1138.65 1151.61 1121.36 1128.92
|
||||
20181009 000001.SZ 20181009 1130.00 1155.93 1122.44 1140.81
|
||||
20181008 000001.SZ 20181008 1155.93 1165.65 1128.92 1128.92
|
||||
20180928 000001.SZ 20180928 1164.57 1217.51 1164.57 1193.74
|
||||
|
||||
#取上证指数行情数据
|
||||
|
||||
df = ts.pro_bar(ts_code='000001.SH', asset='I', start_date='20180101', end_date='20181011')
|
||||
|
||||
In [10]: df.head()
|
||||
Out[10]:
|
||||
ts_code trade_date close open high low \
|
||||
0 000001.SH 20181011 2583.4575 2643.0740 2661.2859 2560.3164
|
||||
1 000001.SH 20181010 2725.8367 2723.7242 2743.5480 2703.0626
|
||||
2 000001.SH 20181009 2721.0130 2713.7319 2734.3142 2711.1971
|
||||
3 000001.SH 20181008 2716.5104 2768.2075 2771.9384 2710.1781
|
||||
4 000001.SH 20180928 2821.3501 2794.2644 2821.7553 2791.8363
|
||||
|
||||
pre_close change pct_chg vol amount
|
||||
0 2725.8367 -142.3792 -5.2233 197150702.0 170057762.5
|
||||
1 2721.0130 4.8237 0.1773 113485736.0 111312455.3
|
||||
2 2716.5104 4.5026 0.1657 116771899.0 110292457.8
|
||||
3 2821.3501 -104.8397 -3.7159 149501388.0 141531551.8
|
||||
4 2791.7748 29.5753 1.0594 134290456.0 125369989.4
|
||||
|
||||
#均线
|
||||
|
||||
df = ts.pro_bar(ts_code='000001.SZ', start_date='20180101', end_date='20181011', ma=[5, 20, 50])
|
||||
注:Tushare pro_bar接口的均价和均量数据是动态计算,想要获取某个时间段的均线,必须要设置start_date日期大于最大均线的日期数,然后自行截取想要日期段。例如,想要获取20190801开始的3日均线,必须设置start_date='20190729',然后剔除20190801之前的日期记录。
|
||||
|
||||
#换手率tor,量比vr
|
||||
|
||||
df = ts.pro_bar(ts_code='000001.SZ', start_date='20180101', end_date='20181011', factors=['tor', 'vr'])
|
||||
|
||||
|
||||
- 基础信息:https://tushare.pro/document/2?doc_id=25
|
||||
接口:stock_basic,可以通过数据工具调试和查看数据
|
||||
描述:获取基础信息数据,包括股票代码、名称、上市日期、退市日期等
|
||||
权限:2000积分起。此接口是基础信息,调取一次就可以拉取完,建议保存倒本地存储后使用
|
||||
|
||||
输入参数
|
||||
|
||||
名称 类型 必选 描述
|
||||
ts_code str N TS股票代码
|
||||
name str N 名称
|
||||
market str N 市场类别 (主板/创业板/科创板/CDR/北交所)
|
||||
list_status str N 上市状态 L上市 D退市 P暂停上市 G过会未交易,默认是L
|
||||
exchange str N 交易所 SSE上交所 SZSE深交所 BSE北交所
|
||||
is_hs str N 是否沪深港通标的,N否 H沪股通 S深股通
|
||||
输出参数
|
||||
|
||||
名称 类型 默认显示 描述
|
||||
ts_code str Y TS代码
|
||||
symbol str Y 股票代码
|
||||
name str Y 股票名称
|
||||
area str Y 地域
|
||||
industry str Y 所属行业
|
||||
fullname str N 股票全称
|
||||
enname str N 英文全称
|
||||
cnspell str Y 拼音缩写
|
||||
market str Y 市场类型(主板/创业板/科创板/CDR)
|
||||
exchange str N 交易所代码
|
||||
curr_type str N 交易货币
|
||||
list_status str N 上市状态 L上市 D退市 G过会未交易 P暂停上市
|
||||
list_date str Y 上市日期
|
||||
delist_date str N 退市日期
|
||||
is_hs str N 是否沪深港通标的,N否 H沪股通 S深股通
|
||||
act_name str Y 实控人名称
|
||||
act_ent_type str Y 实控人企业性质
|
||||
说明:旧版上的PE/PB/股本等字段,请在行情接口“每日指标”中获取。
|
||||
@@ -44,7 +44,7 @@ class TushareClient:
|
||||
"""Execute API query with rate limiting and retry.
|
||||
|
||||
Args:
|
||||
api_name: API name (e.g., 'daily')
|
||||
api_name: API name ('daily', 'pro_bar', etc.)
|
||||
timeout: Timeout for rate limiting
|
||||
**params: API parameters
|
||||
|
||||
@@ -66,8 +66,22 @@ class TushareClient:
|
||||
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
api = self._get_api()
|
||||
data = api.query(api_name, **params)
|
||||
import tushare as ts
|
||||
|
||||
# pro_bar uses ts.pro_bar() instead of api.query()
|
||||
if api_name == "pro_bar":
|
||||
# pro_bar parameters: ts_code, start_date, end_date, adj, freq, factors, ma, adjfactor
|
||||
data = ts.pro_bar(ts_code=params.get("ts_code"),
|
||||
start_date=params.get("start_date"),
|
||||
end_date=params.get("end_date"),
|
||||
adj=params.get("adj"),
|
||||
freq=params.get("freq", "D"),
|
||||
factors=params.get("factors"), # factors should be a list like ['tor', 'vr']
|
||||
ma=params.get("ma"),
|
||||
adjfactor=params.get("adjfactor"))
|
||||
else:
|
||||
api = self._get_api()
|
||||
data = api.query(api_name, **params)
|
||||
|
||||
available = self.rate_limiter.get_available_tokens()
|
||||
print(f"[Tushare] {api_name} | tokens: {available:.0f}/{self.rate_limiter.capacity}")
|
||||
|
||||
@@ -1,16 +1,37 @@
|
||||
"""Configuration management for data collection module."""
|
||||
import os
|
||||
from pathlib import Path
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
# Config directory path - used for loading .env.local
|
||||
# Static detection for pydantic-settings to find .env.local
|
||||
CONFIG_DIR = Path(__file__).parent.parent.parent / "config"
|
||||
|
||||
|
||||
def _get_project_root() -> Path:
|
||||
"""Get project root path from ROOT_PATH env var or auto-detect."""
|
||||
# Try to read from environment variable first
|
||||
root_path = os.environ.get("ROOT_PATH") or os.environ.get("DATA_ROOT")
|
||||
if root_path:
|
||||
return Path(root_path)
|
||||
# Fallback to auto-detection
|
||||
return Path(__file__).parent.parent.parent
|
||||
|
||||
|
||||
class Config(BaseSettings):
|
||||
"""Application configuration loaded from environment variables."""
|
||||
|
||||
# Tushare API token
|
||||
tushare_token: str = ""
|
||||
|
||||
# Data storage path
|
||||
data_path: Path = Path("./data")
|
||||
# Root path - loaded from environment variable ROOT_PATH
|
||||
# If not set, uses auto-detected path
|
||||
root_path: str = ""
|
||||
|
||||
# Data storage path - can be set via DATA_PATH environment variable
|
||||
# If relative path, it will be resolved relative to root_path
|
||||
data_path: str = "data"
|
||||
|
||||
# Rate limit: requests per minute
|
||||
rate_limit: int = 100
|
||||
@@ -18,10 +39,31 @@ class Config(BaseSettings):
|
||||
# Thread pool size
|
||||
threads: int = 2
|
||||
|
||||
@property
|
||||
def project_root(self) -> Path:
|
||||
"""Get project root path."""
|
||||
if self.root_path:
|
||||
return Path(self.root_path)
|
||||
return _get_project_root()
|
||||
|
||||
@property
|
||||
def data_path_resolved(self) -> Path:
|
||||
"""Get resolved data path (absolute)."""
|
||||
path = Path(self.data_path)
|
||||
if path.is_absolute():
|
||||
return path
|
||||
return self.project_root / path
|
||||
|
||||
class Config:
|
||||
env_file = ".env.local"
|
||||
# 从 config/ 目录读取 .env.local 文件
|
||||
env_file = str(CONFIG_DIR / ".env.local")
|
||||
env_file_encoding = "utf-8"
|
||||
case_sensitive = False
|
||||
extra = "ignore" # 忽略 .env.local 中的额外变量
|
||||
# pydantic-settings 默认会将字段名转换为大写作为环境变量名
|
||||
# 所以 tushare_token 会映射到 TUSHARE_TOKEN
|
||||
# root_path 会映射到 ROOT_PATH
|
||||
# data_path 会映射到 DATA_PATH
|
||||
|
||||
|
||||
# Global config instance
|
||||
@@ -31,3 +73,8 @@ config = Config()
|
||||
def get_config() -> Config:
|
||||
"""Get configuration instance."""
|
||||
return config
|
||||
|
||||
|
||||
def get_project_root() -> Path:
|
||||
"""Get project root path (convenience function)."""
|
||||
return get_config().project_root
|
||||
|
||||
@@ -57,14 +57,24 @@ def get_daily(
|
||||
if adj:
|
||||
params["adj"] = adj
|
||||
if factors:
|
||||
params["factors"] = factors
|
||||
# Tushare expects factors as comma-separated string, not list
|
||||
if isinstance(factors, list):
|
||||
factors_str = ",".join(factors)
|
||||
else:
|
||||
factors_str = factors
|
||||
params["factors"] = factors_str
|
||||
print(f"[get_daily] factors param: '{factors_str}'")
|
||||
if adjfactor:
|
||||
params["adjfactor"] = "True"
|
||||
|
||||
# Fetch data
|
||||
data = client.query("daily", **params)
|
||||
# Fetch data using pro_bar (supports factors like tor, vr)
|
||||
print(f"[get_daily] Query params: {params}")
|
||||
data = client.query("pro_bar", **params)
|
||||
|
||||
if data.empty:
|
||||
if not data.empty:
|
||||
print(f"[get_daily] Returned columns: {data.columns.tolist()}")
|
||||
print(f"[get_daily] Sample row: {data.iloc[0].to_dict()}")
|
||||
else:
|
||||
print(f"[get_daily] No data for ts_code={ts_code}")
|
||||
|
||||
return data
|
||||
|
||||
120
src/data/stock_basic.py
Normal file
120
src/data/stock_basic.py
Normal file
@@ -0,0 +1,120 @@
|
||||
"""Simplified stock basic information interface.
|
||||
|
||||
Fetch basic stock information including code, name, listing date, etc.
|
||||
This is a special interface - call once to get all stocks (listed and delisted).
|
||||
"""
|
||||
import pandas as pd
|
||||
from typing import Optional, Literal, List
|
||||
from src.data.client import TushareClient
|
||||
from src.data.storage import Storage
|
||||
|
||||
|
||||
def get_stock_basic(
|
||||
ts_code: Optional[str] = None,
|
||||
name: Optional[str] = None,
|
||||
market: Optional[Literal["主板", "创业板", "科创板", "CDR", "北交所"]] = None,
|
||||
list_status: Optional[Literal["L", "D", "P", "G"]] = None,
|
||||
exchange: Optional[Literal["SSE", "SZSE", "BSE"]] = None,
|
||||
is_hs: Optional[Literal["N", "H", "S"]] = None,
|
||||
fields: Optional[List[str]] = None,
|
||||
) -> pd.DataFrame:
|
||||
"""Fetch basic stock information.
|
||||
|
||||
This interface retrieves basic information data including stock code,
|
||||
name, listing date, delisting date, etc.
|
||||
|
||||
Args:
|
||||
ts_code: TS stock code
|
||||
name: Stock name
|
||||
market: Market type (主板/创业板/科创板/CDR/北交所)
|
||||
list_status: Listing status - L(listed), D(delisted), P(paused), G(approved-not-traded)
|
||||
exchange: Exchange - SSE, SZSE, BSE
|
||||
is_hs: HS indicator - N(no), H(Shanghai), S(Shenzhen)
|
||||
fields: Specific fields to return, None returns default fields
|
||||
|
||||
Returns:
|
||||
pd.DataFrame with stock basic information containing:
|
||||
- ts_code, symbol, name, area, industry, fullname, enname, cnspell,
|
||||
market, exchange, curr_type, list_status, list_date, delist_date,
|
||||
is_hs, act_name, act_ent_type
|
||||
"""
|
||||
client = TushareClient()
|
||||
|
||||
# Build parameters
|
||||
params = {}
|
||||
if ts_code:
|
||||
params["ts_code"] = ts_code
|
||||
if name:
|
||||
params["name"] = name
|
||||
if market:
|
||||
params["market"] = market
|
||||
if list_status:
|
||||
params["list_status"] = list_status
|
||||
if exchange:
|
||||
params["exchange"] = exchange
|
||||
if is_hs:
|
||||
params["is_hs"] = is_hs
|
||||
if fields:
|
||||
params["fields"] = ",".join(fields)
|
||||
|
||||
# Fetch data
|
||||
data = client.query("stock_basic", **params)
|
||||
|
||||
if data.empty:
|
||||
print("[get_stock_basic] No data returned")
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def sync_all_stocks() -> pd.DataFrame:
|
||||
"""Fetch and save all stocks (listed and delisted) to local storage.
|
||||
|
||||
This is a special interface that should only be called once to initialize
|
||||
the local database with all stock information.
|
||||
|
||||
Returns:
|
||||
pd.DataFrame with all stock information
|
||||
"""
|
||||
# Initialize storage
|
||||
storage = Storage()
|
||||
|
||||
# Check if already exists
|
||||
if storage.exists("stock_basic"):
|
||||
print("[sync_all_stocks] stock_basic data already exists, skipping...")
|
||||
return storage.load("stock_basic")
|
||||
|
||||
print("[sync_all_stocks] Fetching all stocks (listed and delisted)...")
|
||||
|
||||
# Fetch all stocks - explicitly get all list_status values
|
||||
# API default is L (listed), so we need to fetch all statuses
|
||||
client = TushareClient()
|
||||
|
||||
all_data = []
|
||||
for status in ["L", "D", "P", "G"]:
|
||||
print(f"[sync_all_stocks] Fetching stocks with status: {status}")
|
||||
data = client.query("stock_basic", list_status=status)
|
||||
print(f"[sync_all_stocks] Fetched {len(data)} stocks with status {status}")
|
||||
if not data.empty:
|
||||
all_data.append(data)
|
||||
|
||||
if not all_data:
|
||||
print("[sync_all_stocks] No stock data fetched")
|
||||
return pd.DataFrame()
|
||||
|
||||
# Combine all data
|
||||
data = pd.concat(all_data, ignore_index=True)
|
||||
# Remove duplicates if any
|
||||
data = data.drop_duplicates(subset=["ts_code"], keep="first")
|
||||
print(f"[sync_all_stocks] Total unique stocks: {len(data)}")
|
||||
|
||||
# Save to storage
|
||||
storage.save("stock_basic", data, mode="replace")
|
||||
|
||||
print(f"[sync_all_stocks] Saved {len(data)} stocks to local storage")
|
||||
return data
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Sync all stocks (listed and delisted) to data folder
|
||||
result = sync_all_stocks()
|
||||
print(f"Total stocks synced: {len(result)}")
|
||||
@@ -16,7 +16,7 @@ class Storage:
|
||||
path: Base path for data storage (auto-loaded from config if not provided)
|
||||
"""
|
||||
cfg = get_config()
|
||||
self.base_path = path or cfg.data_path
|
||||
self.base_path = path or cfg.data_path_resolved
|
||||
self.base_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def _get_file_path(self, name: str) -> Path:
|
||||
|
||||
@@ -7,9 +7,7 @@ Tests the daily interface implementation against api.md requirements:
|
||||
"""
|
||||
import pytest
|
||||
import pandas as pd
|
||||
from unittest.mock import Mock, patch
|
||||
from src.data.daily import get_daily
|
||||
from src.data.client import TushareClient
|
||||
|
||||
|
||||
# Expected output fields according to api.md
|
||||
@@ -28,276 +26,30 @@ EXPECTED_BASE_FIELDS = [
|
||||
]
|
||||
|
||||
EXPECTED_FACTOR_FIELDS = [
|
||||
'tor', # 换手率
|
||||
'vr', # 量比
|
||||
'turnover_rate', # 换手率 (tor)
|
||||
'volume_ratio', # 量比 (vr)
|
||||
]
|
||||
|
||||
|
||||
def run_tests_with_print():
|
||||
"""Run all tests and print results."""
|
||||
print("=" * 60)
|
||||
print("Daily API 测试开始")
|
||||
print("=" * 60)
|
||||
|
||||
test_results = []
|
||||
|
||||
# Test 1: Basic daily data fetch
|
||||
print("\n【测试1】基本日线数据获取")
|
||||
print("-" * 40)
|
||||
mock_data = pd.DataFrame({
|
||||
'ts_code': ['000001.SZ'],
|
||||
'trade_date': ['20240102'],
|
||||
'open': [10.5],
|
||||
'high': [11.0],
|
||||
'low': [10.2],
|
||||
'close': [10.8],
|
||||
'pre_close': [10.3],
|
||||
'change': [0.5],
|
||||
'pct_chg': [4.85],
|
||||
'vol': [1000000],
|
||||
'amount': [10800000],
|
||||
})
|
||||
|
||||
with patch.object(TushareClient, '__init__', lambda self, token=None: None):
|
||||
with patch.object(TushareClient, 'query', return_value=mock_data):
|
||||
result = get_daily('000001.SZ', start_date='20240101', end_date='20240131')
|
||||
|
||||
print(f"获取数据形状: {result.shape}")
|
||||
print(f"获取数据列: {result.columns.tolist()}")
|
||||
print(f"数据内容:\n{result}")
|
||||
|
||||
# Verify
|
||||
tests_passed = isinstance(result, pd.DataFrame)
|
||||
tests_passed &= len(result) == 1
|
||||
tests_passed &= result['ts_code'].iloc[0] == '000001.SZ'
|
||||
|
||||
print(f"\n测试结果: {'通过 ✓' if tests_passed else '失败 ✗'}")
|
||||
test_results.append(("基本日线数据获取", tests_passed))
|
||||
|
||||
# Test 2: Fetch with factors
|
||||
print("\n【测试2】获取含换手率和量比的数据")
|
||||
print("-" * 40)
|
||||
mock_data_factors = pd.DataFrame({
|
||||
'ts_code': ['000001.SZ'],
|
||||
'trade_date': ['20240102'],
|
||||
'open': [10.5],
|
||||
'high': [11.0],
|
||||
'low': [10.2],
|
||||
'close': [10.8],
|
||||
'pre_close': [10.3],
|
||||
'change': [0.5],
|
||||
'pct_chg': [4.85],
|
||||
'vol': [1000000],
|
||||
'amount': [10800000],
|
||||
'tor': [2.5],
|
||||
'vr': [1.2],
|
||||
})
|
||||
|
||||
with patch.object(TushareClient, '__init__', lambda self, token=None: None):
|
||||
with patch.object(TushareClient, 'query', return_value=mock_data_factors):
|
||||
result = get_daily(
|
||||
'000001.SZ',
|
||||
start_date='20240101',
|
||||
end_date='20240131',
|
||||
factors=['tor', 'vr'],
|
||||
)
|
||||
|
||||
print(f"获取数据形状: {result.shape}")
|
||||
print(f"获取数据列: {result.columns.tolist()}")
|
||||
print(f"数据内容:\n{result}")
|
||||
|
||||
# Verify columns
|
||||
tests_passed = isinstance(result, pd.DataFrame)
|
||||
missing_base = [f for f in EXPECTED_BASE_FIELDS if f not in result.columns]
|
||||
missing_factors = [f for f in EXPECTED_FACTOR_FIELDS if f not in result.columns]
|
||||
|
||||
print(f"\n基础列检查: {'全部存在 ✓' if not missing_base else f'缺失: {missing_base} ✗'}")
|
||||
print(f"因子列检查: {'全部存在 ✓' if not missing_factors else f'缺失: {missing_factors} ✗'}")
|
||||
|
||||
tests_passed &= len(missing_base) == 0
|
||||
tests_passed &= len(missing_factors) == 0
|
||||
|
||||
print(f"\n测试结果: {'通过 ✓' if tests_passed else '失败 ✗'}")
|
||||
test_results.append(("含因子数据获取", tests_passed))
|
||||
|
||||
# Test 3: Output fields completeness
|
||||
print("\n【测试3】输出字段完整性检查")
|
||||
print("-" * 40)
|
||||
mock_data = pd.DataFrame({
|
||||
'ts_code': ['600000.SH'],
|
||||
'trade_date': ['20240102'],
|
||||
'open': [10.5],
|
||||
'high': [11.0],
|
||||
'low': [10.2],
|
||||
'close': [10.8],
|
||||
'pre_close': [10.3],
|
||||
'change': [0.5],
|
||||
'pct_chg': [4.85],
|
||||
'vol': [1000000],
|
||||
'amount': [10800000],
|
||||
})
|
||||
|
||||
with patch.object(TushareClient, '__init__', lambda self, token=None: None):
|
||||
with patch.object(TushareClient, 'query', return_value=mock_data):
|
||||
result = get_daily('600000.SH')
|
||||
|
||||
print(f"获取数据形状: {result.shape}")
|
||||
print(f"获取数据列: {result.columns.tolist()}")
|
||||
print(f"期望基础列: {EXPECTED_BASE_FIELDS}")
|
||||
|
||||
missing = set(EXPECTED_BASE_FIELDS) - set(result.columns)
|
||||
print(f"缺失列: {missing if missing else '无'}")
|
||||
|
||||
tests_passed = set(EXPECTED_BASE_FIELDS).issubset(result.columns.tolist())
|
||||
print(f"\n测试结果: {'通过 ✓' if tests_passed else '失败 ✗'}")
|
||||
test_results.append(("输出字段完整性", tests_passed))
|
||||
|
||||
# Test 4: Empty result
|
||||
print("\n【测试4】空结果处理")
|
||||
print("-" * 40)
|
||||
mock_data = pd.DataFrame()
|
||||
|
||||
with patch.object(TushareClient, '__init__', lambda self, token=None: None):
|
||||
with patch.object(TushareClient, 'query', return_value=mock_data):
|
||||
result = get_daily('INVALID.SZ')
|
||||
|
||||
print(f"获取数据是否为空: {result.empty}")
|
||||
tests_passed = result.empty
|
||||
print(f"\n测试结果: {'通过 ✓' if tests_passed else '失败 ✗'}")
|
||||
test_results.append(("空结果处理", tests_passed))
|
||||
|
||||
# Test 5: Date range query
|
||||
print("\n【测试5】日期范围查询")
|
||||
print("-" * 40)
|
||||
mock_data = pd.DataFrame({
|
||||
'ts_code': ['000001.SZ', '000001.SZ'],
|
||||
'trade_date': ['20240102', '20240103'],
|
||||
'open': [10.5, 10.6],
|
||||
'high': [11.0, 11.1],
|
||||
'low': [10.2, 10.3],
|
||||
'close': [10.8, 10.9],
|
||||
'pre_close': [10.3, 10.8],
|
||||
'change': [0.5, 0.1],
|
||||
'pct_chg': [4.85, 0.93],
|
||||
'vol': [1000000, 1100000],
|
||||
'amount': [10800000, 11900000],
|
||||
})
|
||||
|
||||
with patch.object(TushareClient, '__init__', lambda self, token=None: None):
|
||||
with patch.object(TushareClient, 'query', return_value=mock_data):
|
||||
result = get_daily(
|
||||
'000001.SZ',
|
||||
start_date='20240101',
|
||||
end_date='20240131',
|
||||
)
|
||||
|
||||
print(f"获取数据数量: {len(result)}")
|
||||
print(f"期望数量: 2")
|
||||
print(f"数据内容:\n{result}")
|
||||
|
||||
tests_passed = len(result) == 2
|
||||
print(f"\n测试结果: {'通过 ✓' if tests_passed else '失败 ✗'}")
|
||||
test_results.append(("日期范围查询", tests_passed))
|
||||
|
||||
# Test 6: With adjustment
|
||||
print("\n【测试6】带复权参数查询")
|
||||
print("-" * 40)
|
||||
mock_data = pd.DataFrame({
|
||||
'ts_code': ['000001.SZ'],
|
||||
'trade_date': ['20240102'],
|
||||
'open': [10.5],
|
||||
'high': [11.0],
|
||||
'low': [10.2],
|
||||
'close': [10.8],
|
||||
'pre_close': [10.3],
|
||||
'change': [0.5],
|
||||
'pct_chg': [4.85],
|
||||
'vol': [1000000],
|
||||
'amount': [10800000],
|
||||
})
|
||||
|
||||
with patch.object(TushareClient, '__init__', lambda self, token=None: None):
|
||||
with patch.object(TushareClient, 'query', return_value=mock_data):
|
||||
result = get_daily('000001.SZ', adj='qfq')
|
||||
|
||||
print(f"获取数据形状: {result.shape}")
|
||||
print(f"数据内容:\n{result}")
|
||||
|
||||
tests_passed = isinstance(result, pd.DataFrame)
|
||||
print(f"\n测试结果: {'通过 ✓' if tests_passed else '失败 ✗'}")
|
||||
test_results.append(("复权参数查询", tests_passed))
|
||||
|
||||
# Summary
|
||||
print("\n" + "=" * 60)
|
||||
print("测试汇总")
|
||||
print("=" * 60)
|
||||
passed = sum(1 for _, r in test_results if r)
|
||||
total = len(test_results)
|
||||
print(f"总测试数: {total}")
|
||||
print(f"通过: {passed}")
|
||||
print(f"失败: {total - passed}")
|
||||
print(f"通过率: {passed/total*100:.1f}%")
|
||||
print("\n详细结果:")
|
||||
for name, passed in test_results:
|
||||
status = "通过 ✓" if passed else "失败 ✗"
|
||||
print(f" - {name}: {status}")
|
||||
|
||||
return all(r for _, r in test_results)
|
||||
|
||||
|
||||
class TestGetDaily:
|
||||
"""Test cases for simplified get_daily function."""
|
||||
"""Test cases for get_daily function with real API calls."""
|
||||
|
||||
def test_fetch_basic(self):
|
||||
"""Test basic daily data fetch."""
|
||||
mock_data = pd.DataFrame({
|
||||
'ts_code': ['000001.SZ'],
|
||||
'trade_date': ['20240102'],
|
||||
'open': [10.5],
|
||||
'high': [11.0],
|
||||
'low': [10.2],
|
||||
'close': [10.8],
|
||||
'pre_close': [10.3],
|
||||
'change': [0.5],
|
||||
'pct_chg': [4.85],
|
||||
'vol': [1000000],
|
||||
'amount': [10800000],
|
||||
})
|
||||
|
||||
with patch.object(TushareClient, '__init__', lambda self, token=None: None):
|
||||
with patch.object(TushareClient, 'query', return_value=mock_data):
|
||||
result = get_daily('000001.SZ', start_date='20240101', end_date='20240131')
|
||||
"""Test basic daily data fetch with real API."""
|
||||
result = get_daily('000001.SZ', start_date='20240101', end_date='20240131')
|
||||
|
||||
assert isinstance(result, pd.DataFrame)
|
||||
assert len(result) == 1
|
||||
assert len(result) >= 1
|
||||
assert result['ts_code'].iloc[0] == '000001.SZ'
|
||||
|
||||
def test_fetch_with_factors(self):
|
||||
"""Test fetch with tor and vr factors."""
|
||||
mock_data = pd.DataFrame({
|
||||
'ts_code': ['000001.SZ'],
|
||||
'trade_date': ['20240102'],
|
||||
'open': [10.5],
|
||||
'high': [11.0],
|
||||
'low': [10.2],
|
||||
'close': [10.8],
|
||||
'pre_close': [10.3],
|
||||
'change': [0.5],
|
||||
'pct_chg': [4.85],
|
||||
'vol': [1000000],
|
||||
'amount': [10800000],
|
||||
'tor': [2.5], # 换手率
|
||||
'vr': [1.2], # 量比
|
||||
})
|
||||
|
||||
with patch.object(TushareClient, '__init__', lambda self, token=None: None):
|
||||
with patch.object(TushareClient, 'query', return_value=mock_data):
|
||||
result = get_daily(
|
||||
'000001.SZ',
|
||||
start_date='20240101',
|
||||
end_date='20240131',
|
||||
factors=['tor', 'vr'],
|
||||
)
|
||||
result = get_daily(
|
||||
'000001.SZ',
|
||||
start_date='20240101',
|
||||
end_date='20240131',
|
||||
factors=['tor', 'vr'],
|
||||
)
|
||||
|
||||
assert isinstance(result, pd.DataFrame)
|
||||
# Check all base fields are present
|
||||
@@ -309,23 +61,7 @@ class TestGetDaily:
|
||||
|
||||
def test_output_fields_completeness(self):
|
||||
"""Verify all required output fields are returned."""
|
||||
mock_data = pd.DataFrame({
|
||||
'ts_code': ['600000.SH'],
|
||||
'trade_date': ['20240102'],
|
||||
'open': [10.5],
|
||||
'high': [11.0],
|
||||
'low': [10.2],
|
||||
'close': [10.8],
|
||||
'pre_close': [10.3],
|
||||
'change': [0.5],
|
||||
'pct_chg': [4.85],
|
||||
'vol': [1000000],
|
||||
'amount': [10800000],
|
||||
})
|
||||
|
||||
with patch.object(TushareClient, '__init__', lambda self, token=None: None):
|
||||
with patch.object(TushareClient, 'query', return_value=mock_data):
|
||||
result = get_daily('600000.SH')
|
||||
result = get_daily('600000.SH')
|
||||
|
||||
# Verify all base fields are present
|
||||
assert set(EXPECTED_BASE_FIELDS).issubset(result.columns.tolist()), \
|
||||
@@ -333,59 +69,25 @@ class TestGetDaily:
|
||||
|
||||
def test_empty_result(self):
|
||||
"""Test handling of empty results."""
|
||||
mock_data = pd.DataFrame()
|
||||
|
||||
with patch.object(TushareClient, '__init__', lambda self, token=None: None):
|
||||
with patch.object(TushareClient, 'query', return_value=mock_data):
|
||||
result = get_daily('INVALID.SZ')
|
||||
|
||||
# 使用真实 API 测试无效股票代码的空结果
|
||||
result = get_daily('INVALID.SZ')
|
||||
assert isinstance(result, pd.DataFrame)
|
||||
assert result.empty
|
||||
|
||||
def test_date_range_query(self):
|
||||
"""Test query with date range."""
|
||||
mock_data = pd.DataFrame({
|
||||
'ts_code': ['000001.SZ', '000001.SZ'],
|
||||
'trade_date': ['20240102', '20240103'],
|
||||
'open': [10.5, 10.6],
|
||||
'high': [11.0, 11.1],
|
||||
'low': [10.2, 10.3],
|
||||
'close': [10.8, 10.9],
|
||||
'pre_close': [10.3, 10.8],
|
||||
'change': [0.5, 0.1],
|
||||
'pct_chg': [4.85, 0.93],
|
||||
'vol': [1000000, 1100000],
|
||||
'amount': [10800000, 11900000],
|
||||
})
|
||||
result = get_daily(
|
||||
'000001.SZ',
|
||||
start_date='20240101',
|
||||
end_date='20240131',
|
||||
)
|
||||
|
||||
with patch.object(TushareClient, '__init__', lambda self, token=None: None):
|
||||
with patch.object(TushareClient, 'query', return_value=mock_data):
|
||||
result = get_daily(
|
||||
'000001.SZ',
|
||||
start_date='20240101',
|
||||
end_date='20240131',
|
||||
)
|
||||
|
||||
assert len(result) == 2
|
||||
assert isinstance(result, pd.DataFrame)
|
||||
assert len(result) >= 1
|
||||
|
||||
def test_with_adj(self):
|
||||
"""Test fetch with adjustment type."""
|
||||
mock_data = pd.DataFrame({
|
||||
'ts_code': ['000001.SZ'],
|
||||
'trade_date': ['20240102'],
|
||||
'open': [10.5],
|
||||
'high': [11.0],
|
||||
'low': [10.2],
|
||||
'close': [10.8],
|
||||
'pre_close': [10.3],
|
||||
'change': [0.5],
|
||||
'pct_chg': [4.85],
|
||||
'vol': [1000000],
|
||||
'amount': [10800000],
|
||||
})
|
||||
|
||||
with patch.object(TushareClient, '__init__', lambda self, token=None: None):
|
||||
with patch.object(TushareClient, 'query', return_value=mock_data):
|
||||
result = get_daily('000001.SZ', adj='qfq')
|
||||
result = get_daily('000001.SZ', adj='qfq')
|
||||
|
||||
assert isinstance(result, pd.DataFrame)
|
||||
|
||||
@@ -411,9 +113,5 @@ def test_integration():
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# 运行详细的打印测试
|
||||
run_tests_with_print()
|
||||
print("\n" + "=" * 60)
|
||||
print("运行 pytest 单元测试")
|
||||
print("=" * 60 + "\n")
|
||||
# 运行 pytest 单元测试(真实API调用)
|
||||
pytest.main([__file__, '-v'])
|
||||
|
||||
Reference in New Issue
Block a user