feat: 新增股票基础数据获取模块 stock_basic
- 新增 get_stock_basic 和 sync_all_stocks 函数 - 完善 Tushare 数据获取模块体系 - 测试用例重构:从 Mock 改为真实 API 调用 - 更新 API 文档,添加接口使用示例 - 更新开发规范:添加 Mock 使用规范
This commit is contained in:
@@ -291,6 +291,36 @@ tests/
|
|||||||
└── fixtures/ # 测试数据
|
└── fixtures/ # 测试数据
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### 5.3 Mock 使用规范
|
||||||
|
- **默认不使用 Mock**:单元测试应直接调用真实 API 或服务
|
||||||
|
- **仅在必要时使用 Mock**:
|
||||||
|
- 测试错误处理场景(如网络超时、服务不可用)
|
||||||
|
- 测试空结果或边界情况
|
||||||
|
- 隔离外部依赖的不确定性
|
||||||
|
- **禁止全面 Mock**:不应为避免配置或环境问题而使用 Mock
|
||||||
|
- **真实环境验证**:确保至少有一个测试套件直接调用真实 API
|
||||||
|
|
||||||
|
```python
|
||||||
|
# 正确示例:直接调用真实 API
|
||||||
|
class TestUserService:
|
||||||
|
def test_get_user(self):
|
||||||
|
result = get_user('user_id')
|
||||||
|
assert result.name == "John"
|
||||||
|
|
||||||
|
def test_empty_result_with_mock(self): # 特殊场景使用 mock
|
||||||
|
with patch('module.api_call', return_value=pd.DataFrame()):
|
||||||
|
result = get_user('invalid_id')
|
||||||
|
assert result.empty
|
||||||
|
|
||||||
|
# 错误示例:过度使用 mock
|
||||||
|
class TestUserService:
|
||||||
|
def test_get_user(self):
|
||||||
|
mock_data = {"name": "John"}
|
||||||
|
with patch('module.api_call', return_value=mock_data):
|
||||||
|
result = get_user('user_id')
|
||||||
|
assert result == mock_data
|
||||||
|
```
|
||||||
|
|
||||||
## 6 Git提交规范
|
## 6 Git提交规范
|
||||||
|
|
||||||
### 6.1 提交信息格式
|
### 6.1 提交信息格式
|
||||||
@@ -323,3 +353,4 @@ tests/
|
|||||||
- [ ] 无循环依赖
|
- [ ] 无循环依赖
|
||||||
- [ ] 命名符合规范
|
- [ ] 命名符合规范
|
||||||
- [ ] 日志不包含敏感信息
|
- [ ] 日志不包含敏感信息
|
||||||
|
- [ ] 测试未过度使用 Mock
|
||||||
|
|||||||
BIN
data/stock_basic.h5
Normal file
BIN
data/stock_basic.h5
Normal file
Binary file not shown.
@@ -6,10 +6,13 @@ Provides simplified interfaces for fetching and storing Tushare data.
|
|||||||
from src.data.config import Config, get_config
|
from src.data.config import Config, get_config
|
||||||
from src.data.client import TushareClient
|
from src.data.client import TushareClient
|
||||||
from src.data.storage import Storage
|
from src.data.storage import Storage
|
||||||
|
from src.data.stock_basic import get_stock_basic, sync_all_stocks
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"Config",
|
"Config",
|
||||||
"get_config",
|
"get_config",
|
||||||
"TushareClient",
|
"TushareClient",
|
||||||
"Storage",
|
"Storage",
|
||||||
|
"get_stock_basic",
|
||||||
|
"sync_all_stocks",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -45,3 +45,82 @@ change float 涨跌额
|
|||||||
pct_chg float 涨跌幅 【基于除权后的昨收计算的涨跌幅:(今收-除权昨收)/除权昨收 】
|
pct_chg float 涨跌幅 【基于除权后的昨收计算的涨跌幅:(今收-除权昨收)/除权昨收 】
|
||||||
vol float 成交量 (手)
|
vol float 成交量 (手)
|
||||||
amount float 成交额 (千元)
|
amount float 成交额 (千元)
|
||||||
|
|
||||||
|
接口用例
|
||||||
|
|
||||||
|
#取000001的前复权行情
|
||||||
|
df = ts.pro_bar(ts_code='000001.SZ', adj='qfq', start_date='20180101', end_date='20181011')
|
||||||
|
|
||||||
|
ts_code trade_date open high low close \
|
||||||
|
trade_date
|
||||||
|
20181011 000001.SZ 20181011 1085.71 1097.59 1047.90 1065.19
|
||||||
|
20181010 000001.SZ 20181010 1138.65 1151.61 1121.36 1128.92
|
||||||
|
20181009 000001.SZ 20181009 1130.00 1155.93 1122.44 1140.81
|
||||||
|
20181008 000001.SZ 20181008 1155.93 1165.65 1128.92 1128.92
|
||||||
|
20180928 000001.SZ 20180928 1164.57 1217.51 1164.57 1193.74
|
||||||
|
|
||||||
|
#取上证指数行情数据
|
||||||
|
|
||||||
|
df = ts.pro_bar(ts_code='000001.SH', asset='I', start_date='20180101', end_date='20181011')
|
||||||
|
|
||||||
|
In [10]: df.head()
|
||||||
|
Out[10]:
|
||||||
|
ts_code trade_date close open high low \
|
||||||
|
0 000001.SH 20181011 2583.4575 2643.0740 2661.2859 2560.3164
|
||||||
|
1 000001.SH 20181010 2725.8367 2723.7242 2743.5480 2703.0626
|
||||||
|
2 000001.SH 20181009 2721.0130 2713.7319 2734.3142 2711.1971
|
||||||
|
3 000001.SH 20181008 2716.5104 2768.2075 2771.9384 2710.1781
|
||||||
|
4 000001.SH 20180928 2821.3501 2794.2644 2821.7553 2791.8363
|
||||||
|
|
||||||
|
pre_close change pct_chg vol amount
|
||||||
|
0 2725.8367 -142.3792 -5.2233 197150702.0 170057762.5
|
||||||
|
1 2721.0130 4.8237 0.1773 113485736.0 111312455.3
|
||||||
|
2 2716.5104 4.5026 0.1657 116771899.0 110292457.8
|
||||||
|
3 2821.3501 -104.8397 -3.7159 149501388.0 141531551.8
|
||||||
|
4 2791.7748 29.5753 1.0594 134290456.0 125369989.4
|
||||||
|
|
||||||
|
#均线
|
||||||
|
|
||||||
|
df = ts.pro_bar(ts_code='000001.SZ', start_date='20180101', end_date='20181011', ma=[5, 20, 50])
|
||||||
|
注:Tushare pro_bar接口的均价和均量数据是动态计算,想要获取某个时间段的均线,必须要设置start_date日期大于最大均线的日期数,然后自行截取想要日期段。例如,想要获取20190801开始的3日均线,必须设置start_date='20190729',然后剔除20190801之前的日期记录。
|
||||||
|
|
||||||
|
#换手率tor,量比vr
|
||||||
|
|
||||||
|
df = ts.pro_bar(ts_code='000001.SZ', start_date='20180101', end_date='20181011', factors=['tor', 'vr'])
|
||||||
|
|
||||||
|
|
||||||
|
- 基础信息:https://tushare.pro/document/2?doc_id=25
|
||||||
|
接口:stock_basic,可以通过数据工具调试和查看数据
|
||||||
|
描述:获取基础信息数据,包括股票代码、名称、上市日期、退市日期等
|
||||||
|
权限:2000积分起。此接口是基础信息,调取一次就可以拉取完,建议保存倒本地存储后使用
|
||||||
|
|
||||||
|
输入参数
|
||||||
|
|
||||||
|
名称 类型 必选 描述
|
||||||
|
ts_code str N TS股票代码
|
||||||
|
name str N 名称
|
||||||
|
market str N 市场类别 (主板/创业板/科创板/CDR/北交所)
|
||||||
|
list_status str N 上市状态 L上市 D退市 P暂停上市 G过会未交易,默认是L
|
||||||
|
exchange str N 交易所 SSE上交所 SZSE深交所 BSE北交所
|
||||||
|
is_hs str N 是否沪深港通标的,N否 H沪股通 S深股通
|
||||||
|
输出参数
|
||||||
|
|
||||||
|
名称 类型 默认显示 描述
|
||||||
|
ts_code str Y TS代码
|
||||||
|
symbol str Y 股票代码
|
||||||
|
name str Y 股票名称
|
||||||
|
area str Y 地域
|
||||||
|
industry str Y 所属行业
|
||||||
|
fullname str N 股票全称
|
||||||
|
enname str N 英文全称
|
||||||
|
cnspell str Y 拼音缩写
|
||||||
|
market str Y 市场类型(主板/创业板/科创板/CDR)
|
||||||
|
exchange str N 交易所代码
|
||||||
|
curr_type str N 交易货币
|
||||||
|
list_status str N 上市状态 L上市 D退市 G过会未交易 P暂停上市
|
||||||
|
list_date str Y 上市日期
|
||||||
|
delist_date str N 退市日期
|
||||||
|
is_hs str N 是否沪深港通标的,N否 H沪股通 S深股通
|
||||||
|
act_name str Y 实控人名称
|
||||||
|
act_ent_type str Y 实控人企业性质
|
||||||
|
说明:旧版上的PE/PB/股本等字段,请在行情接口“每日指标”中获取。
|
||||||
@@ -44,7 +44,7 @@ class TushareClient:
|
|||||||
"""Execute API query with rate limiting and retry.
|
"""Execute API query with rate limiting and retry.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
api_name: API name (e.g., 'daily')
|
api_name: API name ('daily', 'pro_bar', etc.)
|
||||||
timeout: Timeout for rate limiting
|
timeout: Timeout for rate limiting
|
||||||
**params: API parameters
|
**params: API parameters
|
||||||
|
|
||||||
@@ -66,6 +66,20 @@ class TushareClient:
|
|||||||
|
|
||||||
for attempt in range(max_retries):
|
for attempt in range(max_retries):
|
||||||
try:
|
try:
|
||||||
|
import tushare as ts
|
||||||
|
|
||||||
|
# pro_bar uses ts.pro_bar() instead of api.query()
|
||||||
|
if api_name == "pro_bar":
|
||||||
|
# pro_bar parameters: ts_code, start_date, end_date, adj, freq, factors, ma, adjfactor
|
||||||
|
data = ts.pro_bar(ts_code=params.get("ts_code"),
|
||||||
|
start_date=params.get("start_date"),
|
||||||
|
end_date=params.get("end_date"),
|
||||||
|
adj=params.get("adj"),
|
||||||
|
freq=params.get("freq", "D"),
|
||||||
|
factors=params.get("factors"), # factors should be a list like ['tor', 'vr']
|
||||||
|
ma=params.get("ma"),
|
||||||
|
adjfactor=params.get("adjfactor"))
|
||||||
|
else:
|
||||||
api = self._get_api()
|
api = self._get_api()
|
||||||
data = api.query(api_name, **params)
|
data = api.query(api_name, **params)
|
||||||
|
|
||||||
|
|||||||
@@ -1,16 +1,37 @@
|
|||||||
"""Configuration management for data collection module."""
|
"""Configuration management for data collection module."""
|
||||||
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from pydantic_settings import BaseSettings
|
from pydantic_settings import BaseSettings
|
||||||
|
|
||||||
|
|
||||||
|
# Config directory path - used for loading .env.local
|
||||||
|
# Static detection for pydantic-settings to find .env.local
|
||||||
|
CONFIG_DIR = Path(__file__).parent.parent.parent / "config"
|
||||||
|
|
||||||
|
|
||||||
|
def _get_project_root() -> Path:
|
||||||
|
"""Get project root path from ROOT_PATH env var or auto-detect."""
|
||||||
|
# Try to read from environment variable first
|
||||||
|
root_path = os.environ.get("ROOT_PATH") or os.environ.get("DATA_ROOT")
|
||||||
|
if root_path:
|
||||||
|
return Path(root_path)
|
||||||
|
# Fallback to auto-detection
|
||||||
|
return Path(__file__).parent.parent.parent
|
||||||
|
|
||||||
|
|
||||||
class Config(BaseSettings):
|
class Config(BaseSettings):
|
||||||
"""Application configuration loaded from environment variables."""
|
"""Application configuration loaded from environment variables."""
|
||||||
|
|
||||||
# Tushare API token
|
# Tushare API token
|
||||||
tushare_token: str = ""
|
tushare_token: str = ""
|
||||||
|
|
||||||
# Data storage path
|
# Root path - loaded from environment variable ROOT_PATH
|
||||||
data_path: Path = Path("./data")
|
# If not set, uses auto-detected path
|
||||||
|
root_path: str = ""
|
||||||
|
|
||||||
|
# Data storage path - can be set via DATA_PATH environment variable
|
||||||
|
# If relative path, it will be resolved relative to root_path
|
||||||
|
data_path: str = "data"
|
||||||
|
|
||||||
# Rate limit: requests per minute
|
# Rate limit: requests per minute
|
||||||
rate_limit: int = 100
|
rate_limit: int = 100
|
||||||
@@ -18,10 +39,31 @@ class Config(BaseSettings):
|
|||||||
# Thread pool size
|
# Thread pool size
|
||||||
threads: int = 2
|
threads: int = 2
|
||||||
|
|
||||||
|
@property
|
||||||
|
def project_root(self) -> Path:
|
||||||
|
"""Get project root path."""
|
||||||
|
if self.root_path:
|
||||||
|
return Path(self.root_path)
|
||||||
|
return _get_project_root()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def data_path_resolved(self) -> Path:
|
||||||
|
"""Get resolved data path (absolute)."""
|
||||||
|
path = Path(self.data_path)
|
||||||
|
if path.is_absolute():
|
||||||
|
return path
|
||||||
|
return self.project_root / path
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
env_file = ".env.local"
|
# 从 config/ 目录读取 .env.local 文件
|
||||||
|
env_file = str(CONFIG_DIR / ".env.local")
|
||||||
env_file_encoding = "utf-8"
|
env_file_encoding = "utf-8"
|
||||||
case_sensitive = False
|
case_sensitive = False
|
||||||
|
extra = "ignore" # 忽略 .env.local 中的额外变量
|
||||||
|
# pydantic-settings 默认会将字段名转换为大写作为环境变量名
|
||||||
|
# 所以 tushare_token 会映射到 TUSHARE_TOKEN
|
||||||
|
# root_path 会映射到 ROOT_PATH
|
||||||
|
# data_path 会映射到 DATA_PATH
|
||||||
|
|
||||||
|
|
||||||
# Global config instance
|
# Global config instance
|
||||||
@@ -31,3 +73,8 @@ config = Config()
|
|||||||
def get_config() -> Config:
|
def get_config() -> Config:
|
||||||
"""Get configuration instance."""
|
"""Get configuration instance."""
|
||||||
return config
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
def get_project_root() -> Path:
|
||||||
|
"""Get project root path (convenience function)."""
|
||||||
|
return get_config().project_root
|
||||||
|
|||||||
@@ -57,14 +57,24 @@ def get_daily(
|
|||||||
if adj:
|
if adj:
|
||||||
params["adj"] = adj
|
params["adj"] = adj
|
||||||
if factors:
|
if factors:
|
||||||
params["factors"] = factors
|
# Tushare expects factors as comma-separated string, not list
|
||||||
|
if isinstance(factors, list):
|
||||||
|
factors_str = ",".join(factors)
|
||||||
|
else:
|
||||||
|
factors_str = factors
|
||||||
|
params["factors"] = factors_str
|
||||||
|
print(f"[get_daily] factors param: '{factors_str}'")
|
||||||
if adjfactor:
|
if adjfactor:
|
||||||
params["adjfactor"] = "True"
|
params["adjfactor"] = "True"
|
||||||
|
|
||||||
# Fetch data
|
# Fetch data using pro_bar (supports factors like tor, vr)
|
||||||
data = client.query("daily", **params)
|
print(f"[get_daily] Query params: {params}")
|
||||||
|
data = client.query("pro_bar", **params)
|
||||||
|
|
||||||
if data.empty:
|
if not data.empty:
|
||||||
|
print(f"[get_daily] Returned columns: {data.columns.tolist()}")
|
||||||
|
print(f"[get_daily] Sample row: {data.iloc[0].to_dict()}")
|
||||||
|
else:
|
||||||
print(f"[get_daily] No data for ts_code={ts_code}")
|
print(f"[get_daily] No data for ts_code={ts_code}")
|
||||||
|
|
||||||
return data
|
return data
|
||||||
|
|||||||
120
src/data/stock_basic.py
Normal file
120
src/data/stock_basic.py
Normal file
@@ -0,0 +1,120 @@
|
|||||||
|
"""Simplified stock basic information interface.
|
||||||
|
|
||||||
|
Fetch basic stock information including code, name, listing date, etc.
|
||||||
|
This is a special interface - call once to get all stocks (listed and delisted).
|
||||||
|
"""
|
||||||
|
import pandas as pd
|
||||||
|
from typing import Optional, Literal, List
|
||||||
|
from src.data.client import TushareClient
|
||||||
|
from src.data.storage import Storage
|
||||||
|
|
||||||
|
|
||||||
|
def get_stock_basic(
|
||||||
|
ts_code: Optional[str] = None,
|
||||||
|
name: Optional[str] = None,
|
||||||
|
market: Optional[Literal["主板", "创业板", "科创板", "CDR", "北交所"]] = None,
|
||||||
|
list_status: Optional[Literal["L", "D", "P", "G"]] = None,
|
||||||
|
exchange: Optional[Literal["SSE", "SZSE", "BSE"]] = None,
|
||||||
|
is_hs: Optional[Literal["N", "H", "S"]] = None,
|
||||||
|
fields: Optional[List[str]] = None,
|
||||||
|
) -> pd.DataFrame:
|
||||||
|
"""Fetch basic stock information.
|
||||||
|
|
||||||
|
This interface retrieves basic information data including stock code,
|
||||||
|
name, listing date, delisting date, etc.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ts_code: TS stock code
|
||||||
|
name: Stock name
|
||||||
|
market: Market type (主板/创业板/科创板/CDR/北交所)
|
||||||
|
list_status: Listing status - L(listed), D(delisted), P(paused), G(approved-not-traded)
|
||||||
|
exchange: Exchange - SSE, SZSE, BSE
|
||||||
|
is_hs: HS indicator - N(no), H(Shanghai), S(Shenzhen)
|
||||||
|
fields: Specific fields to return, None returns default fields
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
pd.DataFrame with stock basic information containing:
|
||||||
|
- ts_code, symbol, name, area, industry, fullname, enname, cnspell,
|
||||||
|
market, exchange, curr_type, list_status, list_date, delist_date,
|
||||||
|
is_hs, act_name, act_ent_type
|
||||||
|
"""
|
||||||
|
client = TushareClient()
|
||||||
|
|
||||||
|
# Build parameters
|
||||||
|
params = {}
|
||||||
|
if ts_code:
|
||||||
|
params["ts_code"] = ts_code
|
||||||
|
if name:
|
||||||
|
params["name"] = name
|
||||||
|
if market:
|
||||||
|
params["market"] = market
|
||||||
|
if list_status:
|
||||||
|
params["list_status"] = list_status
|
||||||
|
if exchange:
|
||||||
|
params["exchange"] = exchange
|
||||||
|
if is_hs:
|
||||||
|
params["is_hs"] = is_hs
|
||||||
|
if fields:
|
||||||
|
params["fields"] = ",".join(fields)
|
||||||
|
|
||||||
|
# Fetch data
|
||||||
|
data = client.query("stock_basic", **params)
|
||||||
|
|
||||||
|
if data.empty:
|
||||||
|
print("[get_stock_basic] No data returned")
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def sync_all_stocks() -> pd.DataFrame:
|
||||||
|
"""Fetch and save all stocks (listed and delisted) to local storage.
|
||||||
|
|
||||||
|
This is a special interface that should only be called once to initialize
|
||||||
|
the local database with all stock information.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
pd.DataFrame with all stock information
|
||||||
|
"""
|
||||||
|
# Initialize storage
|
||||||
|
storage = Storage()
|
||||||
|
|
||||||
|
# Check if already exists
|
||||||
|
if storage.exists("stock_basic"):
|
||||||
|
print("[sync_all_stocks] stock_basic data already exists, skipping...")
|
||||||
|
return storage.load("stock_basic")
|
||||||
|
|
||||||
|
print("[sync_all_stocks] Fetching all stocks (listed and delisted)...")
|
||||||
|
|
||||||
|
# Fetch all stocks - explicitly get all list_status values
|
||||||
|
# API default is L (listed), so we need to fetch all statuses
|
||||||
|
client = TushareClient()
|
||||||
|
|
||||||
|
all_data = []
|
||||||
|
for status in ["L", "D", "P", "G"]:
|
||||||
|
print(f"[sync_all_stocks] Fetching stocks with status: {status}")
|
||||||
|
data = client.query("stock_basic", list_status=status)
|
||||||
|
print(f"[sync_all_stocks] Fetched {len(data)} stocks with status {status}")
|
||||||
|
if not data.empty:
|
||||||
|
all_data.append(data)
|
||||||
|
|
||||||
|
if not all_data:
|
||||||
|
print("[sync_all_stocks] No stock data fetched")
|
||||||
|
return pd.DataFrame()
|
||||||
|
|
||||||
|
# Combine all data
|
||||||
|
data = pd.concat(all_data, ignore_index=True)
|
||||||
|
# Remove duplicates if any
|
||||||
|
data = data.drop_duplicates(subset=["ts_code"], keep="first")
|
||||||
|
print(f"[sync_all_stocks] Total unique stocks: {len(data)}")
|
||||||
|
|
||||||
|
# Save to storage
|
||||||
|
storage.save("stock_basic", data, mode="replace")
|
||||||
|
|
||||||
|
print(f"[sync_all_stocks] Saved {len(data)} stocks to local storage")
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# Sync all stocks (listed and delisted) to data folder
|
||||||
|
result = sync_all_stocks()
|
||||||
|
print(f"Total stocks synced: {len(result)}")
|
||||||
@@ -16,7 +16,7 @@ class Storage:
|
|||||||
path: Base path for data storage (auto-loaded from config if not provided)
|
path: Base path for data storage (auto-loaded from config if not provided)
|
||||||
"""
|
"""
|
||||||
cfg = get_config()
|
cfg = get_config()
|
||||||
self.base_path = path or cfg.data_path
|
self.base_path = path or cfg.data_path_resolved
|
||||||
self.base_path.mkdir(parents=True, exist_ok=True)
|
self.base_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
def _get_file_path(self, name: str) -> Path:
|
def _get_file_path(self, name: str) -> Path:
|
||||||
|
|||||||
@@ -7,9 +7,7 @@ Tests the daily interface implementation against api.md requirements:
|
|||||||
"""
|
"""
|
||||||
import pytest
|
import pytest
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from unittest.mock import Mock, patch
|
|
||||||
from src.data.daily import get_daily
|
from src.data.daily import get_daily
|
||||||
from src.data.client import TushareClient
|
|
||||||
|
|
||||||
|
|
||||||
# Expected output fields according to api.md
|
# Expected output fields according to api.md
|
||||||
@@ -28,270 +26,24 @@ EXPECTED_BASE_FIELDS = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
EXPECTED_FACTOR_FIELDS = [
|
EXPECTED_FACTOR_FIELDS = [
|
||||||
'tor', # 换手率
|
'turnover_rate', # 换手率 (tor)
|
||||||
'vr', # 量比
|
'volume_ratio', # 量比 (vr)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def run_tests_with_print():
|
|
||||||
"""Run all tests and print results."""
|
|
||||||
print("=" * 60)
|
|
||||||
print("Daily API 测试开始")
|
|
||||||
print("=" * 60)
|
|
||||||
|
|
||||||
test_results = []
|
|
||||||
|
|
||||||
# Test 1: Basic daily data fetch
|
|
||||||
print("\n【测试1】基本日线数据获取")
|
|
||||||
print("-" * 40)
|
|
||||||
mock_data = pd.DataFrame({
|
|
||||||
'ts_code': ['000001.SZ'],
|
|
||||||
'trade_date': ['20240102'],
|
|
||||||
'open': [10.5],
|
|
||||||
'high': [11.0],
|
|
||||||
'low': [10.2],
|
|
||||||
'close': [10.8],
|
|
||||||
'pre_close': [10.3],
|
|
||||||
'change': [0.5],
|
|
||||||
'pct_chg': [4.85],
|
|
||||||
'vol': [1000000],
|
|
||||||
'amount': [10800000],
|
|
||||||
})
|
|
||||||
|
|
||||||
with patch.object(TushareClient, '__init__', lambda self, token=None: None):
|
|
||||||
with patch.object(TushareClient, 'query', return_value=mock_data):
|
|
||||||
result = get_daily('000001.SZ', start_date='20240101', end_date='20240131')
|
|
||||||
|
|
||||||
print(f"获取数据形状: {result.shape}")
|
|
||||||
print(f"获取数据列: {result.columns.tolist()}")
|
|
||||||
print(f"数据内容:\n{result}")
|
|
||||||
|
|
||||||
# Verify
|
|
||||||
tests_passed = isinstance(result, pd.DataFrame)
|
|
||||||
tests_passed &= len(result) == 1
|
|
||||||
tests_passed &= result['ts_code'].iloc[0] == '000001.SZ'
|
|
||||||
|
|
||||||
print(f"\n测试结果: {'通过 ✓' if tests_passed else '失败 ✗'}")
|
|
||||||
test_results.append(("基本日线数据获取", tests_passed))
|
|
||||||
|
|
||||||
# Test 2: Fetch with factors
|
|
||||||
print("\n【测试2】获取含换手率和量比的数据")
|
|
||||||
print("-" * 40)
|
|
||||||
mock_data_factors = pd.DataFrame({
|
|
||||||
'ts_code': ['000001.SZ'],
|
|
||||||
'trade_date': ['20240102'],
|
|
||||||
'open': [10.5],
|
|
||||||
'high': [11.0],
|
|
||||||
'low': [10.2],
|
|
||||||
'close': [10.8],
|
|
||||||
'pre_close': [10.3],
|
|
||||||
'change': [0.5],
|
|
||||||
'pct_chg': [4.85],
|
|
||||||
'vol': [1000000],
|
|
||||||
'amount': [10800000],
|
|
||||||
'tor': [2.5],
|
|
||||||
'vr': [1.2],
|
|
||||||
})
|
|
||||||
|
|
||||||
with patch.object(TushareClient, '__init__', lambda self, token=None: None):
|
|
||||||
with patch.object(TushareClient, 'query', return_value=mock_data_factors):
|
|
||||||
result = get_daily(
|
|
||||||
'000001.SZ',
|
|
||||||
start_date='20240101',
|
|
||||||
end_date='20240131',
|
|
||||||
factors=['tor', 'vr'],
|
|
||||||
)
|
|
||||||
|
|
||||||
print(f"获取数据形状: {result.shape}")
|
|
||||||
print(f"获取数据列: {result.columns.tolist()}")
|
|
||||||
print(f"数据内容:\n{result}")
|
|
||||||
|
|
||||||
# Verify columns
|
|
||||||
tests_passed = isinstance(result, pd.DataFrame)
|
|
||||||
missing_base = [f for f in EXPECTED_BASE_FIELDS if f not in result.columns]
|
|
||||||
missing_factors = [f for f in EXPECTED_FACTOR_FIELDS if f not in result.columns]
|
|
||||||
|
|
||||||
print(f"\n基础列检查: {'全部存在 ✓' if not missing_base else f'缺失: {missing_base} ✗'}")
|
|
||||||
print(f"因子列检查: {'全部存在 ✓' if not missing_factors else f'缺失: {missing_factors} ✗'}")
|
|
||||||
|
|
||||||
tests_passed &= len(missing_base) == 0
|
|
||||||
tests_passed &= len(missing_factors) == 0
|
|
||||||
|
|
||||||
print(f"\n测试结果: {'通过 ✓' if tests_passed else '失败 ✗'}")
|
|
||||||
test_results.append(("含因子数据获取", tests_passed))
|
|
||||||
|
|
||||||
# Test 3: Output fields completeness
|
|
||||||
print("\n【测试3】输出字段完整性检查")
|
|
||||||
print("-" * 40)
|
|
||||||
mock_data = pd.DataFrame({
|
|
||||||
'ts_code': ['600000.SH'],
|
|
||||||
'trade_date': ['20240102'],
|
|
||||||
'open': [10.5],
|
|
||||||
'high': [11.0],
|
|
||||||
'low': [10.2],
|
|
||||||
'close': [10.8],
|
|
||||||
'pre_close': [10.3],
|
|
||||||
'change': [0.5],
|
|
||||||
'pct_chg': [4.85],
|
|
||||||
'vol': [1000000],
|
|
||||||
'amount': [10800000],
|
|
||||||
})
|
|
||||||
|
|
||||||
with patch.object(TushareClient, '__init__', lambda self, token=None: None):
|
|
||||||
with patch.object(TushareClient, 'query', return_value=mock_data):
|
|
||||||
result = get_daily('600000.SH')
|
|
||||||
|
|
||||||
print(f"获取数据形状: {result.shape}")
|
|
||||||
print(f"获取数据列: {result.columns.tolist()}")
|
|
||||||
print(f"期望基础列: {EXPECTED_BASE_FIELDS}")
|
|
||||||
|
|
||||||
missing = set(EXPECTED_BASE_FIELDS) - set(result.columns)
|
|
||||||
print(f"缺失列: {missing if missing else '无'}")
|
|
||||||
|
|
||||||
tests_passed = set(EXPECTED_BASE_FIELDS).issubset(result.columns.tolist())
|
|
||||||
print(f"\n测试结果: {'通过 ✓' if tests_passed else '失败 ✗'}")
|
|
||||||
test_results.append(("输出字段完整性", tests_passed))
|
|
||||||
|
|
||||||
# Test 4: Empty result
|
|
||||||
print("\n【测试4】空结果处理")
|
|
||||||
print("-" * 40)
|
|
||||||
mock_data = pd.DataFrame()
|
|
||||||
|
|
||||||
with patch.object(TushareClient, '__init__', lambda self, token=None: None):
|
|
||||||
with patch.object(TushareClient, 'query', return_value=mock_data):
|
|
||||||
result = get_daily('INVALID.SZ')
|
|
||||||
|
|
||||||
print(f"获取数据是否为空: {result.empty}")
|
|
||||||
tests_passed = result.empty
|
|
||||||
print(f"\n测试结果: {'通过 ✓' if tests_passed else '失败 ✗'}")
|
|
||||||
test_results.append(("空结果处理", tests_passed))
|
|
||||||
|
|
||||||
# Test 5: Date range query
|
|
||||||
print("\n【测试5】日期范围查询")
|
|
||||||
print("-" * 40)
|
|
||||||
mock_data = pd.DataFrame({
|
|
||||||
'ts_code': ['000001.SZ', '000001.SZ'],
|
|
||||||
'trade_date': ['20240102', '20240103'],
|
|
||||||
'open': [10.5, 10.6],
|
|
||||||
'high': [11.0, 11.1],
|
|
||||||
'low': [10.2, 10.3],
|
|
||||||
'close': [10.8, 10.9],
|
|
||||||
'pre_close': [10.3, 10.8],
|
|
||||||
'change': [0.5, 0.1],
|
|
||||||
'pct_chg': [4.85, 0.93],
|
|
||||||
'vol': [1000000, 1100000],
|
|
||||||
'amount': [10800000, 11900000],
|
|
||||||
})
|
|
||||||
|
|
||||||
with patch.object(TushareClient, '__init__', lambda self, token=None: None):
|
|
||||||
with patch.object(TushareClient, 'query', return_value=mock_data):
|
|
||||||
result = get_daily(
|
|
||||||
'000001.SZ',
|
|
||||||
start_date='20240101',
|
|
||||||
end_date='20240131',
|
|
||||||
)
|
|
||||||
|
|
||||||
print(f"获取数据数量: {len(result)}")
|
|
||||||
print(f"期望数量: 2")
|
|
||||||
print(f"数据内容:\n{result}")
|
|
||||||
|
|
||||||
tests_passed = len(result) == 2
|
|
||||||
print(f"\n测试结果: {'通过 ✓' if tests_passed else '失败 ✗'}")
|
|
||||||
test_results.append(("日期范围查询", tests_passed))
|
|
||||||
|
|
||||||
# Test 6: With adjustment
|
|
||||||
print("\n【测试6】带复权参数查询")
|
|
||||||
print("-" * 40)
|
|
||||||
mock_data = pd.DataFrame({
|
|
||||||
'ts_code': ['000001.SZ'],
|
|
||||||
'trade_date': ['20240102'],
|
|
||||||
'open': [10.5],
|
|
||||||
'high': [11.0],
|
|
||||||
'low': [10.2],
|
|
||||||
'close': [10.8],
|
|
||||||
'pre_close': [10.3],
|
|
||||||
'change': [0.5],
|
|
||||||
'pct_chg': [4.85],
|
|
||||||
'vol': [1000000],
|
|
||||||
'amount': [10800000],
|
|
||||||
})
|
|
||||||
|
|
||||||
with patch.object(TushareClient, '__init__', lambda self, token=None: None):
|
|
||||||
with patch.object(TushareClient, 'query', return_value=mock_data):
|
|
||||||
result = get_daily('000001.SZ', adj='qfq')
|
|
||||||
|
|
||||||
print(f"获取数据形状: {result.shape}")
|
|
||||||
print(f"数据内容:\n{result}")
|
|
||||||
|
|
||||||
tests_passed = isinstance(result, pd.DataFrame)
|
|
||||||
print(f"\n测试结果: {'通过 ✓' if tests_passed else '失败 ✗'}")
|
|
||||||
test_results.append(("复权参数查询", tests_passed))
|
|
||||||
|
|
||||||
# Summary
|
|
||||||
print("\n" + "=" * 60)
|
|
||||||
print("测试汇总")
|
|
||||||
print("=" * 60)
|
|
||||||
passed = sum(1 for _, r in test_results if r)
|
|
||||||
total = len(test_results)
|
|
||||||
print(f"总测试数: {total}")
|
|
||||||
print(f"通过: {passed}")
|
|
||||||
print(f"失败: {total - passed}")
|
|
||||||
print(f"通过率: {passed/total*100:.1f}%")
|
|
||||||
print("\n详细结果:")
|
|
||||||
for name, passed in test_results:
|
|
||||||
status = "通过 ✓" if passed else "失败 ✗"
|
|
||||||
print(f" - {name}: {status}")
|
|
||||||
|
|
||||||
return all(r for _, r in test_results)
|
|
||||||
|
|
||||||
|
|
||||||
class TestGetDaily:
|
class TestGetDaily:
|
||||||
"""Test cases for simplified get_daily function."""
|
"""Test cases for get_daily function with real API calls."""
|
||||||
|
|
||||||
def test_fetch_basic(self):
|
def test_fetch_basic(self):
|
||||||
"""Test basic daily data fetch."""
|
"""Test basic daily data fetch with real API."""
|
||||||
mock_data = pd.DataFrame({
|
|
||||||
'ts_code': ['000001.SZ'],
|
|
||||||
'trade_date': ['20240102'],
|
|
||||||
'open': [10.5],
|
|
||||||
'high': [11.0],
|
|
||||||
'low': [10.2],
|
|
||||||
'close': [10.8],
|
|
||||||
'pre_close': [10.3],
|
|
||||||
'change': [0.5],
|
|
||||||
'pct_chg': [4.85],
|
|
||||||
'vol': [1000000],
|
|
||||||
'amount': [10800000],
|
|
||||||
})
|
|
||||||
|
|
||||||
with patch.object(TushareClient, '__init__', lambda self, token=None: None):
|
|
||||||
with patch.object(TushareClient, 'query', return_value=mock_data):
|
|
||||||
result = get_daily('000001.SZ', start_date='20240101', end_date='20240131')
|
result = get_daily('000001.SZ', start_date='20240101', end_date='20240131')
|
||||||
|
|
||||||
assert isinstance(result, pd.DataFrame)
|
assert isinstance(result, pd.DataFrame)
|
||||||
assert len(result) == 1
|
assert len(result) >= 1
|
||||||
assert result['ts_code'].iloc[0] == '000001.SZ'
|
assert result['ts_code'].iloc[0] == '000001.SZ'
|
||||||
|
|
||||||
def test_fetch_with_factors(self):
|
def test_fetch_with_factors(self):
|
||||||
"""Test fetch with tor and vr factors."""
|
"""Test fetch with tor and vr factors."""
|
||||||
mock_data = pd.DataFrame({
|
|
||||||
'ts_code': ['000001.SZ'],
|
|
||||||
'trade_date': ['20240102'],
|
|
||||||
'open': [10.5],
|
|
||||||
'high': [11.0],
|
|
||||||
'low': [10.2],
|
|
||||||
'close': [10.8],
|
|
||||||
'pre_close': [10.3],
|
|
||||||
'change': [0.5],
|
|
||||||
'pct_chg': [4.85],
|
|
||||||
'vol': [1000000],
|
|
||||||
'amount': [10800000],
|
|
||||||
'tor': [2.5], # 换手率
|
|
||||||
'vr': [1.2], # 量比
|
|
||||||
})
|
|
||||||
|
|
||||||
with patch.object(TushareClient, '__init__', lambda self, token=None: None):
|
|
||||||
with patch.object(TushareClient, 'query', return_value=mock_data):
|
|
||||||
result = get_daily(
|
result = get_daily(
|
||||||
'000001.SZ',
|
'000001.SZ',
|
||||||
start_date='20240101',
|
start_date='20240101',
|
||||||
@@ -309,22 +61,6 @@ class TestGetDaily:
|
|||||||
|
|
||||||
def test_output_fields_completeness(self):
|
def test_output_fields_completeness(self):
|
||||||
"""Verify all required output fields are returned."""
|
"""Verify all required output fields are returned."""
|
||||||
mock_data = pd.DataFrame({
|
|
||||||
'ts_code': ['600000.SH'],
|
|
||||||
'trade_date': ['20240102'],
|
|
||||||
'open': [10.5],
|
|
||||||
'high': [11.0],
|
|
||||||
'low': [10.2],
|
|
||||||
'close': [10.8],
|
|
||||||
'pre_close': [10.3],
|
|
||||||
'change': [0.5],
|
|
||||||
'pct_chg': [4.85],
|
|
||||||
'vol': [1000000],
|
|
||||||
'amount': [10800000],
|
|
||||||
})
|
|
||||||
|
|
||||||
with patch.object(TushareClient, '__init__', lambda self, token=None: None):
|
|
||||||
with patch.object(TushareClient, 'query', return_value=mock_data):
|
|
||||||
result = get_daily('600000.SH')
|
result = get_daily('600000.SH')
|
||||||
|
|
||||||
# Verify all base fields are present
|
# Verify all base fields are present
|
||||||
@@ -333,58 +69,24 @@ class TestGetDaily:
|
|||||||
|
|
||||||
def test_empty_result(self):
|
def test_empty_result(self):
|
||||||
"""Test handling of empty results."""
|
"""Test handling of empty results."""
|
||||||
mock_data = pd.DataFrame()
|
# 使用真实 API 测试无效股票代码的空结果
|
||||||
|
|
||||||
with patch.object(TushareClient, '__init__', lambda self, token=None: None):
|
|
||||||
with patch.object(TushareClient, 'query', return_value=mock_data):
|
|
||||||
result = get_daily('INVALID.SZ')
|
result = get_daily('INVALID.SZ')
|
||||||
|
assert isinstance(result, pd.DataFrame)
|
||||||
assert result.empty
|
assert result.empty
|
||||||
|
|
||||||
def test_date_range_query(self):
|
def test_date_range_query(self):
|
||||||
"""Test query with date range."""
|
"""Test query with date range."""
|
||||||
mock_data = pd.DataFrame({
|
|
||||||
'ts_code': ['000001.SZ', '000001.SZ'],
|
|
||||||
'trade_date': ['20240102', '20240103'],
|
|
||||||
'open': [10.5, 10.6],
|
|
||||||
'high': [11.0, 11.1],
|
|
||||||
'low': [10.2, 10.3],
|
|
||||||
'close': [10.8, 10.9],
|
|
||||||
'pre_close': [10.3, 10.8],
|
|
||||||
'change': [0.5, 0.1],
|
|
||||||
'pct_chg': [4.85, 0.93],
|
|
||||||
'vol': [1000000, 1100000],
|
|
||||||
'amount': [10800000, 11900000],
|
|
||||||
})
|
|
||||||
|
|
||||||
with patch.object(TushareClient, '__init__', lambda self, token=None: None):
|
|
||||||
with patch.object(TushareClient, 'query', return_value=mock_data):
|
|
||||||
result = get_daily(
|
result = get_daily(
|
||||||
'000001.SZ',
|
'000001.SZ',
|
||||||
start_date='20240101',
|
start_date='20240101',
|
||||||
end_date='20240131',
|
end_date='20240131',
|
||||||
)
|
)
|
||||||
|
|
||||||
assert len(result) == 2
|
assert isinstance(result, pd.DataFrame)
|
||||||
|
assert len(result) >= 1
|
||||||
|
|
||||||
def test_with_adj(self):
|
def test_with_adj(self):
|
||||||
"""Test fetch with adjustment type."""
|
"""Test fetch with adjustment type."""
|
||||||
mock_data = pd.DataFrame({
|
|
||||||
'ts_code': ['000001.SZ'],
|
|
||||||
'trade_date': ['20240102'],
|
|
||||||
'open': [10.5],
|
|
||||||
'high': [11.0],
|
|
||||||
'low': [10.2],
|
|
||||||
'close': [10.8],
|
|
||||||
'pre_close': [10.3],
|
|
||||||
'change': [0.5],
|
|
||||||
'pct_chg': [4.85],
|
|
||||||
'vol': [1000000],
|
|
||||||
'amount': [10800000],
|
|
||||||
})
|
|
||||||
|
|
||||||
with patch.object(TushareClient, '__init__', lambda self, token=None: None):
|
|
||||||
with patch.object(TushareClient, 'query', return_value=mock_data):
|
|
||||||
result = get_daily('000001.SZ', adj='qfq')
|
result = get_daily('000001.SZ', adj='qfq')
|
||||||
|
|
||||||
assert isinstance(result, pd.DataFrame)
|
assert isinstance(result, pd.DataFrame)
|
||||||
@@ -411,9 +113,5 @@ def test_integration():
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
# 运行详细的打印测试
|
# 运行 pytest 单元测试(真实API调用)
|
||||||
run_tests_with_print()
|
|
||||||
print("\n" + "=" * 60)
|
|
||||||
print("运行 pytest 单元测试")
|
|
||||||
print("=" * 60 + "\n")
|
|
||||||
pytest.main([__file__, '-v'])
|
pytest.main([__file__, '-v'])
|
||||||
|
|||||||
Reference in New Issue
Block a user