refactor: 重构 API 接口模块,整合为 api_wrappers 目录结构

- 将独立 API 模块 (daily, stock_basic, trade_cal) 整合至 api_wrappers/
- 重写 sync.py 使用新的 wrapper 结构,支持更多同步功能
- 更新测试文件适配新的模块结构
- 添加 pytest.ini 配置文件
This commit is contained in:
2026-02-21 03:43:30 +08:00
parent e81d39ae0d
commit 9965ce5706
15 changed files with 1042 additions and 952 deletions

3
pytest.ini Normal file
View File

@@ -0,0 +1,3 @@
[pytest]
pythonpath = .
testpaths = tests

View File

@@ -1,847 +0,0 @@
# ProStock 数据接口封装规范
## 1. 概述
本文档定义了在 `src/data/` 目录下新增 Tushare API 接口封装的标准规范。所有非特殊接口(因子和基础数据)必须遵循此规范,以确保:
- 代码风格统一
- 自动 sync 支持
- 增量更新逻辑一致
- 减少存储写入压力
## 2. 接口分类
### 2.1 特殊接口(不参与统一 sync
以下接口有独立的同步逻辑,不参与本文档定义的自动 sync 机制:
| 接口类型 | 示例 | 说明 |
|---------|------|------|
| 交易日历 | `trade_cal` | 全局数据,按日期范围获取 |
| 股票基础信息 | `stock_basic` | 一次性全量获取CSV 存储 |
| 辅助数据 | 行业分类、概念分类 | 低频更新,独立管理 |
### 2.2 标准接口(必须遵循本规范)
所有**按股票**或**按日期**获取的因子数据、行情数据、财务数据等,必须遵循本规范。
## 3. 文件结构
### 3.1 文件命名
```
{data_type}.py
```
示例:
- `daily.py` - 日线行情
- `moneyflow.py` - 资金流向
- `limit_list.py` - 涨跌停数据
- `stk_holdernumber.py` - 股东人数
### 3.2 文件位置
```
src/data/
├── __init__.py # 导出公共接口
├── client.py # TushareClient已有
├── config.py # 配置管理(已有)
├── storage.py # 存储管理(已有)
├── rate_limiter.py # 速率限制(已有)
├── trade_cal.py # 交易日历(特殊接口)
├── stock_basic.py # 股票基础(特殊接口)
├── daily.py # 日线行情(参考示例)
└── {new_data_type}.py # 新增接口文件
```
## 4. 接口设计规范
### 4.1 数据获取函数
#### 4.1.1 按股票获取的接口
适用于:日线行情、分钟线、资金流向等
```python
def get_{data_type}(
ts_code: str,
start_date: Optional[str] = None,
end_date: Optional[str] = None,
# 其他可选参数...
) -> pd.DataFrame:
"""获取 {数据描述}
Args:
ts_code: 股票代码(如 '000001.SZ'
start_date: 开始日期YYYYMMDD格式
end_date: 结束日期YYYYMMDD格式
# 其他参数说明...
Returns:
pd.DataFrame 包含以下字段:
- ts_code: 股票代码
- trade_date: 交易日期
# 其他字段...
Example:
>>> data = get_{data_type}('000001.SZ', start_date='20240101', end_date='20240131')
"""
client = TushareClient()
params = {"ts_code": ts_code}
if start_date:
params["start_date"] = start_date
if end_date:
params["end_date"] = end_date
# 其他参数...
data = client.query("{api_name}", **params)
return data
```
#### 4.1.2 按日期获取的接口
适用于:每日涨跌停、每日龙虎榜、每日筹码分布等
```python
def get_{data_type}(
trade_date: Optional[str] = None,
start_date: Optional[str] = None,
end_date: Optional[str] = None,
ts_code: Optional[str] = None,
# 其他可选参数...
) -> pd.DataFrame:
"""获取 {数据描述}
**优先按日期获取**(推荐):
- 使用 trade_date 获取单日全市场数据
- 或使用 start_date + end_date 获取区间数据
Args:
trade_date: 交易日期YYYYMMDD格式获取单日全市场数据
start_date: 开始日期YYYYMMDD格式
end_date: 结束日期YYYYMMDD格式
ts_code: 股票代码(可选,用于过滤特定股票)
# 其他参数说明...
Returns:
pd.DataFrame 包含以下字段:
- ts_code: 股票代码
- trade_date: 交易日期
# 其他字段...
Example:
>>> # 获取单日全市场数据(推荐)
>>> data = get_{data_type}(trade_date='20240115')
>>> # 获取区间数据
>>> data = get_{data_type}(start_date='20240101', end_date='20240131')
"""
client = TushareClient()
params = {}
if trade_date:
params["trade_date"] = trade_date
if start_date:
params["start_date"] = start_date
if end_date:
params["end_date"] = end_date
if ts_code:
params["ts_code"] = ts_code
# 其他参数...
data = client.query("{api_name}", **params)
return data
```
### 4.2 关键设计原则
#### 4.2.1 优先按日期获取
**强烈建议**优先实现按日期获取的接口:
1. **效率更高**:一次请求获取全市场数据
2. **API 调用更少**N 天 = N 次调用,而非 N 天 × M 只股票
3. **更适合增量更新**:按天检查本地数据,只获取缺失日期
#### 4.2.2 日期字段统一
- 统一使用 `trade_date` 作为日期字段名
- 日期格式:`YYYYMMDD` 字符串
- 如果 API 返回其他字段名(如 `date``end_date`),在返回前重命名为 `trade_date`
#### 4.2.3 股票代码字段
- 统一使用 `ts_code` 作为股票代码字段名
- 格式:`{code}.{exchange}`,如 `000001.SZ``600000.SH`
## 5. Sync 集成规范
### 5.1 在 sync.py 中注册新数据类型
`DataSync` 类中添加新数据类型的同步支持:
```python
class DataSync:
"""Data synchronization manager with full/incremental sync support."""
DEFAULT_MAX_WORKERS = 10
# 数据类型配置
DATASET_CONFIG = {
"daily": {
"api_name": "pro_bar",
"fetch_by": "stock", # 按股票获取
"date_field": "trade_date",
"key_fields": ["ts_code", "trade_date"],
},
"moneyflow": {
"api_name": "moneyflow",
"fetch_by": "stock", # 按股票获取
"date_field": "trade_date",
"key_fields": ["ts_code", "trade_date"],
},
"limit_list": {
"api_name": "limit_list",
"fetch_by": "date", # 按日期获取(优先)
"date_field": "trade_date",
"key_fields": ["ts_code", "trade_date"],
},
# 新增数据类型...
"{new_data_type}": {
"api_name": "{tushare_api_name}",
"fetch_by": "date", # "date" 或 "stock"
"date_field": "trade_date",
"key_fields": ["ts_code", "trade_date"], # 用于去重的主键
},
}
```
### 5.2 实现同步方法
#### 5.2.1 按日期获取的同步方法(推荐)
```python
def sync_by_date(
self,
dataset_name: str,
start_date: str,
end_date: str,
) -> pd.DataFrame:
"""Sync data by date (fetch all stocks for each date).
This is the RECOMMENDED approach for date-based data like:
- limit_list (涨跌停)
- top_list (龙虎榜)
- cyq_perf (筹码分布)
Args:
dataset_name: Name of the dataset in DATASET_CONFIG
start_date: Start date (YYYYMMDD)
end_date: End date (YYYYMMDD)
Returns:
Combined DataFrame with all data
"""
from src.data.trade_cal import get_trading_days
config = self.DATASET_CONFIG[dataset_name]
api_name = config["api_name"]
date_field = config["date_field"]
# Get trading days in the range
trading_days = get_trading_days(start_date, end_date)
if not trading_days:
print(f"[DataSync] No trading days in range {start_date} to {end_date}")
return pd.DataFrame()
print(f"[DataSync] Fetching {dataset_name} for {len(trading_days)} trading days")
all_data = []
error_occurred = False
for trade_date in tqdm(trading_days, desc=f"Syncing {dataset_name}"):
if not self._stop_flag.is_set():
break
try:
data = self.client.query(
api_name,
trade_date=trade_date,
)
if not data.empty:
all_data.append(data)
except Exception as e:
self._stop_flag.clear()
error_occurred = True
print(f"[ERROR] Failed to fetch {dataset_name} for {trade_date}: {e}")
raise
if error_occurred or not all_data:
return pd.DataFrame()
# Combine all data
combined = pd.concat(all_data, ignore_index=True)
# Ensure date field is consistent
if date_field not in combined.columns and "trade_date" in combined.columns:
combined = combined.rename(columns={"trade_date": date_field})
return combined
```
#### 5.2.2 按股票获取的同步方法
```python
def sync_by_stock(
self,
dataset_name: str,
ts_code: str,
start_date: str,
end_date: str,
) -> pd.DataFrame:
"""Sync data by stock (fetch all dates for each stock).
Use this for stock-based data like:
- daily (日线行情)
- moneyflow (资金流向)
- stk_holdernumber (股东人数)
Args:
dataset_name: Name of the dataset in DATASET_CONFIG
ts_code: Stock code
start_date: Start date (YYYYMMDD)
end_date: End date (YYYYMMDD)
Returns:
DataFrame with data for the stock
"""
config = self.DATASET_CONFIG[dataset_name]
api_name = config["api_name"]
if not self._stop_flag.is_set():
return pd.DataFrame()
try:
data = self.client.query(
api_name,
ts_code=ts_code,
start_date=start_date,
end_date=end_date,
)
return data
except Exception as e:
self._stop_flag.clear()
print(f"[ERROR] Exception syncing {dataset_name} for {ts_code}: {e}")
raise
```
### 5.3 增量更新逻辑
#### 5.3.1 通用增量更新检查
```python
def check_incremental_sync(
self,
dataset_name: str,
force_full: bool = False,
) -> tuple[bool, Optional[str], Optional[str], Optional[str]]:
"""Check if incremental sync is needed for a dataset.
Args:
dataset_name: Name of the dataset
force_full: If True, force full sync
Returns:
Tuple of (sync_needed, start_date, end_date, local_last_date)
"""
config = self.DATASET_CONFIG[dataset_name]
date_field = config["date_field"]
# If force_full, always sync from default start
if force_full:
print(f"[DataSync] Force full sync for {dataset_name}")
return (True, DEFAULT_START_DATE, get_today_date(), None)
# Check local data
local_data = self.storage.load(dataset_name)
if local_data.empty or date_field not in local_data.columns:
print(f"[DataSync] No local {dataset_name} data, full sync needed")
return (True, DEFAULT_START_DATE, get_today_date(), None)
# Get local last date
local_last_date = str(local_data[date_field].max())
print(f"[DataSync] Local {dataset_name} last date: {local_last_date}")
# Get calendar last trading day
today = get_today_date()
_, cal_last = self.get_trade_calendar_bounds(DEFAULT_START_DATE, today)
if cal_last is None:
print(f"[DataSync] Failed to get trade calendar, proceeding with sync")
return (True, DEFAULT_START_DATE, today, local_last_date)
print(f"[DataSync] Calendar last trading day: {cal_last}")
# Compare dates
if int(local_last_date) >= int(cal_last):
print(f"[DataSync] {dataset_name} is up-to-date, skipping sync")
return (False, None, None, None)
# Need incremental sync
sync_start = get_next_date(local_last_date)
print(f"[DataSync] Incremental sync for {dataset_name} from {sync_start} to {cal_last}")
return (True, sync_start, cal_last, local_last_date)
```
#### 5.3.2 完整的同步入口
```python
def sync_dataset(
self,
dataset_name: str,
force_full: bool = False,
max_workers: Optional[int] = None,
) -> pd.DataFrame:
"""Sync a dataset with automatic incremental update.
This is the main entry point for syncing any dataset.
Args:
dataset_name: Name of the dataset in DATASET_CONFIG
force_full: If True, force full reload
max_workers: Number of worker threads (for stock-based sync)
Returns:
DataFrame with synced data
"""
print("\n" + "=" * 60)
print(f"[DataSync] Starting {dataset_name} sync...")
print("=" * 60)
# Ensure trade calendar is up-to-date
sync_trade_cal_cache()
# Check if sync is needed
sync_needed, start_date, end_date, local_last = self.check_incremental_sync(
dataset_name, force_full
)
if not sync_needed:
print(f"[DataSync] {dataset_name} is up-to-date, skipping")
return pd.DataFrame()
config = self.DATASET_CONFIG[dataset_name]
fetch_by = config["fetch_by"]
# Fetch data based on strategy
if fetch_by == "date":
# Fetch by date (all stocks per day)
data = self.sync_by_date(dataset_name, start_date, end_date)
else:
# Fetch by stock (all dates per stock)
data = self._sync_all_stocks(dataset_name, start_date, end_date, max_workers)
if data.empty:
print(f"[DataSync] No new data for {dataset_name}")
return pd.DataFrame()
# Save to storage (single write)
self.storage.save(dataset_name, data, mode="append")
print(f"[DataSync] Synced {len(data)} rows for {dataset_name}")
return data
def _sync_all_stocks(
self,
dataset_name: str,
start_date: str,
end_date: str,
max_workers: Optional[int] = None,
) -> pd.DataFrame:
"""Sync data for all stocks (stock-based fetch)."""
stock_codes = self.get_all_stock_codes()
if not stock_codes:
return pd.DataFrame()
print(f"[DataSync] Syncing {dataset_name} for {len(stock_codes)} stocks")
self._stop_flag.set()
results = []
workers = max_workers or self.max_workers
with ThreadPoolExecutor(max_workers=workers) as executor:
future_to_code = {
executor.submit(
self.sync_by_stock, dataset_name, ts_code, start_date, end_date
): ts_code
for ts_code in stock_codes
}
with tqdm(total=len(stock_codes), desc=f"Syncing {dataset_name}") as pbar:
for future in as_completed(future_to_code):
try:
data = future.result()
if not data.empty:
results.append(data)
except Exception as e:
executor.shutdown(wait=False, cancel_futures=True)
raise
pbar.update(1)
if not results:
return pd.DataFrame()
return pd.concat(results, ignore_index=True)
```
## 6. 存储规范
### 6.1 Storage 类使用
所有数据通过 `Storage` 类进行 HDF5 存储:
```python
from src.data.storage import Storage
storage = Storage()
# 保存数据(自动增量合并)
storage.save("dataset_name", data, mode="append")
# 加载数据
all_data = storage.load("dataset_name")
filtered_data = storage.load("dataset_name", start_date="20240101", end_date="20240131")
# 获取最新日期
last_date = storage.get_last_date("dataset_name")
# 检查是否存在
exists = storage.exists("dataset_name")
```
### 6.2 增量写入策略
**关键原则**:所有数据在请求完成后**一次性写入**,而非逐条写入:
```python
# ❌ 错误:逐条写入(性能差)
for date in dates:
data = fetch(date)
storage.save("dataset", data, mode="append") # 多次写入
# ✅ 正确:批量写入(性能好)
all_data = []
for date in dates:
data = fetch(date)
all_data.append(data)
combined = pd.concat(all_data, ignore_index=True)
storage.save("dataset", combined, mode="append") # 单次写入
```
### 6.3 去重策略
`Storage.save()` 方法会自动去重,基于配置中的 `key_fields`
```python
# storage.py 中的实现
combined = pd.concat([existing, data], ignore_index=True)
combined = combined.drop_duplicates(
subset=["ts_code", "trade_date"], # 使用 key_fields
keep="last" # 保留最新数据
)
```
## 7. 完整示例:新增涨跌停数据接口
### 7.1 创建 limit_list.py
```python
"""Limit up/down list interface.
Fetch stocks that hit limit up or limit down for a specific trade date.
This is a date-based interface (recommended approach).
"""
import pandas as pd
from typing import Optional
from src.data.client import TushareClient
def get_limit_list(
trade_date: Optional[str] = None,
ts_code: Optional[str] = None,
start_date: Optional[str] = None,
end_date: Optional[str] = None,
) -> pd.DataFrame:
"""获取涨跌停数据。
**优先按日期获取**(推荐):
- 使用 trade_date 获取单日全市场涨跌停数据
- 或使用 start_date + end_date 获取区间数据
Args:
trade_date: 交易日期YYYYMMDD格式获取单日全市场数据
ts_code: 股票代码(可选,用于过滤)
start_date: 开始日期YYYYMMDD格式
end_date: 结束日期YYYYMMDD格式
Returns:
pd.DataFrame 包含以下字段:
- ts_code: 股票代码
- trade_date: 交易日期
- name: 股票名称
- close: 收盘价
- pct_chg: 涨跌幅
- amp: 振幅
- fc_ratio: 封单金额/日成交额
- fl_ratio: 封单手数/流通股本
- fd_amount: 封单金额
- first_time: 首次涨停时间
- last_time: 最后封板时间
- open_times: 打开次数
- strth: 涨停强度
- limit: 涨停类型U涨停D跌停
Example:
>>> # 获取单日全市场涨跌停数据(推荐)
>>> data = get_limit_list(trade_date='20240115')
>>> # 获取区间数据
>>> data = get_limit_list(start_date='20240101', end_date='20240131')
"""
client = TushareClient()
params = {}
if trade_date:
params["trade_date"] = trade_date
if ts_code:
params["ts_code"] = ts_code
if start_date:
params["start_date"] = start_date
if end_date:
params["end_date"] = end_date
data = client.query("limit_list", **params)
return data
```
### 7.2 在 sync.py 中注册
```python
class DataSync:
"""Data synchronization manager with full/incremental sync support."""
DATASET_CONFIG = {
# ... 其他配置 ...
"limit_list": {
"api_name": "limit_list",
"fetch_by": "date", # 按日期获取
"date_field": "trade_date",
"key_fields": ["ts_code", "trade_date"],
},
}
# ... 其他方法 ...
def sync_limit_list(
self,
force_full: bool = False,
) -> pd.DataFrame:
"""Sync limit list data."""
return self.sync_dataset("limit_list", force_full)
# 便捷函数
def sync_limit_list(force_full: bool = False) -> pd.DataFrame:
"""Sync limit up/down data."""
sync_manager = DataSync()
return sync_manager.sync_limit_list(force_full)
```
### 7.3 更新 __init__.py
```python
from src.data.limit_list import get_limit_list
__all__ = [
# ... 其他导出 ...
"get_limit_list",
]
```
## 8. 测试规范
### 8.1 测试文件结构
```
tests/
├── test_sync.py # sync 模块测试
├── test_daily.py # daily 模块测试
└── test_{new_module}.py # 新增模块测试
```
### 8.2 测试模板
```python
"""Tests for {module_name} module."""
import pytest
from unittest.mock import patch, MagicMock
import pandas as pd
from src.data.{module_name} import get_{data_type}
class Test{DataType}:
"""Test cases for {data_type} data fetching."""
@patch("src.data.{module_name}.TushareClient")
def test_get_{data_type}_by_date(self, mock_client_class):
"""Test fetching data by date."""
# Setup mock
mock_client = MagicMock()
mock_client_class.return_value = mock_client
mock_client.query.return_value = pd.DataFrame({
"ts_code": ["000001.SZ"],
"trade_date": ["20240115"],
# ... 其他字段 ...
})
# Call function
result = get_{data_type}(trade_date="20240115")
# Verify
assert not result.empty
mock_client.query.assert_called_once_with(
"{api_name}",
trade_date="20240115",
)
@patch("src.data.{module_name}.TushareClient")
def test_get_{data_type}_by_stock(self, mock_client_class):
"""Test fetching data by stock code."""
# Setup mock
mock_client = MagicMock()
mock_client_class.return_value = mock_client
mock_client.query.return_value = pd.DataFrame({
"ts_code": ["000001.SZ"],
"trade_date": ["20240115"],
# ... 其他字段 ...
})
# Call function
result = get_{data_type}(
ts_code="000001.SZ",
start_date="20240101",
end_date="20240131",
)
# Verify
assert not result.empty
mock_client.query.assert_called_once()
```
## 9. 检查清单
在提交新接口前,请确认以下事项:
### 9.1 文件结构
- [ ] 文件位于 `src/data/{data_type}.py`
- [ ] 已更新 `src/data/__init__.py` 导出公共接口
- [ ] 已创建 `tests/test_{data_type}.py` 测试文件
### 9.2 接口实现
- [ ] 数据获取函数使用 `TushareClient`
- [ ] 函数包含完整的 Google 风格文档字符串
- [ ] 日期参数使用 `YYYYMMDD` 格式
- [ ] 返回的 DataFrame 包含 `ts_code``trade_date` 字段
- [ ] 优先实现按日期获取的接口(如果 API 支持)
### 9.3 Sync 集成
- [ ] 已在 `DataSync.DATASET_CONFIG` 中注册
- [ ] 正确设置 `fetch_by`"date" 或 "stock"
- [ ] 正确设置 `date_field``key_fields`
- [ ] 已实现对应的 sync 方法或复用通用方法
- [ ] 增量更新逻辑正确(检查本地最新日期)
### 9.4 存储优化
- [ ] 所有数据一次性写入(非逐条)
- [ ] 使用 `storage.save(mode="append")` 进行增量保存
- [ ] 去重字段配置正确
### 9.5 测试
- [ ] 已编写单元测试
- [ ] 已 mock TushareClient
- [ ] 测试覆盖正常和异常情况
## 10. 常见问题
### Q1: API 返回的日期字段名不是 trade_date 怎么办?
在返回前重命名:
```python
data = client.query("api_name", **params)
if "end_date" in data.columns:
data = data.rename(columns={"end_date": "trade_date"})
return data
```
### Q2: 如何处理分页limit/offset
Tushare Pro API 通常不需要手动分页,但如果需要:
```python
all_data = []
offset = 0
limit = 5000
while True:
data = client.query(
"api_name",
trade_date=trade_date,
limit=limit,
offset=offset,
)
if data.empty or len(data) < limit:
all_data.append(data)
break
all_data.append(data)
offset += limit
return pd.concat(all_data, ignore_index=True)
```
### Q3: 如何处理需要额外参数的接口?
在函数签名中添加参数,并传递给 client.query
```python
def get_data(
ts_code: str,
start_date: Optional[str] = None,
end_date: Optional[str] = None,
fields: Optional[list] = None, # 额外参数
) -> pd.DataFrame:
params = {"ts_code": ts_code}
if start_date:
params["start_date"] = start_date
if end_date:
params["end_date"] = end_date
if fields:
params["fields"] = ",".join(fields)
return client.query("api_name", **params)
```
### Q4: 如何处理没有 trade_date 字段的数据?
如果数据确实不包含日期字段(如静态数据),可以:
1. 将其归类为"特殊接口",独立管理
2. 或者添加一个 `sync_date` 字段记录同步时间
### Q5: 如何处理按日期获取但 API 不支持的情况?
如果 API 只支持按股票获取:
1.`DATASET_CONFIG` 中设置 `fetch_by: "stock"`
2. 使用 `_sync_all_stocks` 方法进行同步
3. 在文档中说明这是按股票获取的接口
---
**最后更新**: 2026-02-01

View File

@@ -6,7 +6,7 @@ Provides simplified interfaces for fetching and storing Tushare data.
from src.data.config import Config, get_config
from src.data.client import TushareClient
from src.data.storage import Storage
from src.data.stock_basic import get_stock_basic, sync_all_stocks
from src.data.api_wrappers import get_stock_basic, sync_all_stocks
__all__ = [
"Config",

View File

@@ -0,0 +1,244 @@
# ProStock 数据接口封装规范
## 1. 概述
本文档定义了新增 Tushare API 接口封装的标准规范。所有非特殊接口必须遵循此规范,确保:
- 代码风格统一
- 自动 sync 支持
- 增量更新逻辑一致
- 减少存储写入压力
## 2. 接口分类
### 2.1 特殊接口(不参与统一 sync)
以下接口有独立的同步逻辑,不参与自动 sync 机制:
| 接口类型 | 示例 | 说明 |
|---------|------|------|
| 交易日历 | `trade_cal` | 全局数据,按日期范围获取 |
| 股票基础信息 | `stock_basic` | 一次性全量获取CSV 存储 |
| 辅助数据 | 行业分类、概念分类 | 低频更新,独立管理 |
### 2.2 标准接口(必须遵循本规范)
所有按股票或按日期获取的因子数据、行情数据、财务数据等,必须遵循本规范。
## 3. 文件结构要求
### 3.1 文件命名
```
{data_type}.py
```
示例:`daily.py`、`moneyflow.py`、`limit_list.py`
### 3.2 文件位置
所有接口文件必须位于 `src/data/` 目录下。
### 3.3 导出要求
新接口必须在 `src/data/__init__.py` 中导出:
```python
from src.data.{module_name} import get_{data_type}
__all__ = [
# ... 其他导出 ...
"get_{data_type}",
]
```
## 4. 接口设计规范
### 4.1 数据获取函数签名要求
函数必须返回 `pd.DataFrame`,参数必须包含以下之一:
#### 4.1.1 按日期获取的接口(优先)
适用于:涨跌停、龙虎榜、筹码分布等。
**函数签名要求**
```python
def get_{data_type}(
trade_date: Optional[str] = None,
start_date: Optional[str] = None,
end_date: Optional[str] = None,
ts_code: Optional[str] = None,
# 其他可选参数...
) -> pd.DataFrame:
```
**要求**
- 优先使用 `trade_date` 获取单日全市场数据
- 支持 `start_date + end_date` 获取区间数据
- `ts_code` 作为可选过滤参数
#### 4.1.2 按股票获取的接口
适用于:日线行情、资金流向等。
**函数签名要求**
```python
def get_{data_type}(
ts_code: str,
start_date: Optional[str] = None,
end_date: Optional[str] = None,
# 其他可选参数...
) -> pd.DataFrame:
```
### 4.2 文档字符串要求
函数必须包含 Google 风格的完整文档字符串,包含:
- 函数功能描述
- `Args` 部分:所有参数说明
- `Returns` 部分:返回的 DataFrame 包含的字段说明
- `Example` 部分:使用示例
### 4.3 日期格式要求
- 所有日期参数和返回值使用 `YYYYMMDD` 字符串格式
- 统一使用 `trade_date` 作为日期字段名
- 如果 API 返回其他日期字段名(如 `date` 或 `end_date`),必须在返回前重命名为 `trade_date`
### 4.4 股票代码要求
- 统一使用 `ts_code` 作为股票代码字段名
- 格式:`{code}.{exchange}`,如 `000001.SZ`、`600000.SH`
### 4.5 令牌桶限速要求
所有 API 调用必须通过 `TushareClient`,自动满足令牌桶限速要求。
## 5. Sync 集成规范
### 5.1 DATASET_CONFIG 注册要求
新接口必须在 `DataSync.DATASET_CONFIG` 中注册,配置项:
```python
"{new_data_type}": {
"api_name": "{tushare_api_name}", # Tushare API 名称
"fetch_by": "date", # "date" 或 "stock"
"date_field": "trade_date",
"key_fields": ["ts_code", "trade_date"], # 用于去重的主键
}
```
### 5.2 fetch_by 取值规则
- **优先使用 `"date"`**:如果 API 支持按日期获取全市场数据
- 仅当 API 不支持按日期获取时才使用 `"stock"`
### 5.3 sync 方法要求
必须实现对应的 sync 方法或复用通用方法:
```python
def sync_{data_type}(self, force_full: bool = False) -> pd.DataFrame:
"""Sync {数据描述}"""
return self.sync_dataset("{data_type}", force_full)
```
同时提供便捷函数:
```python
def sync_{data_type}(force_full: bool = False) -> pd.DataFrame:
"""Sync {数据描述}"""
sync_manager = DataSync()
return sync_manager.sync_{data_type}(force_full)
```
### 5.4 增量更新要求
- 必须实现增量更新逻辑(自动检查本地最新日期)
- 使用 `force_full` 参数支持强制全量同步
## 6. 存储规范
### 6.1 存储方式
所有数据通过 `Storage` 类进行 HDF5 存储。
### 6.2 写入策略
**要求**:所有数据在请求完成后**一次性写入**,而非逐条写入。
### 6.3 去重要求
使用 `key_fields` 配置的字段进行去重,默认使用 `["ts_code", "trade_date"]`
## 7. 测试规范
### 7.1 测试文件要求
必须创建对应的测试文件:`tests/test_{data_type}.py`
### 7.2 测试覆盖要求
- 测试按日期获取
- 测试按股票获取(如果支持)
- 必须 mock `TushareClient`
- 测试覆盖正常和异常情况
## 8. 新增接口完整流程
### 8.1 创建接口文件
1.`src/data/` 下创建 `{data_type}.py`
2. 实现数据获取函数,遵循第 4 节规范
### 8.2 注册 sync 支持
1.`sync.py``DataSync.DATASET_CONFIG` 中注册
2. 实现对应的 sync 方法
3. 提供便捷函数
### 8.3 更新导出
`src/data/__init__.py` 中导出接口函数。
### 8.4 创建测试
创建 `tests/test_{data_type}.py`,覆盖关键场景。
## 9. 检查清单
### 9.1 文件结构
- [ ] 文件位于 `src/data/{data_type}.py`
- [ ] 已更新 `src/data/__init__.py` 导出公共接口
- [ ] 已创建 `tests/test_{data_type}.py` 测试文件
### 9.2 接口实现
- [ ] 数据获取函数使用 `TushareClient`
- [ ] 函数包含完整的 Google 风格文档字符串
- [ ] 日期参数使用 `YYYYMMDD` 格式
- [ ] 返回的 DataFrame 包含 `ts_code``trade_date` 字段
- [ ] 优先实现按日期获取的接口(如果 API 支持)
### 9.3 Sync 集成
- [ ] 已在 `DataSync.DATASET_CONFIG` 中注册
- [ ] 正确设置 `fetch_by`("date" 或 "stock")
- [ ] 正确设置 `date_field` 和 `key_fields`
- [ ] 已实现对应的 sync 方法或复用通用方法
- [ ] 增量更新逻辑正确(检查本地最新日期)
### 9.4 存储优化
- [ ] 所有数据一次性写入(非逐条)
- [ ] 使用 `storage.save(mode="append")` 进行增量保存
- [ ] 去重字段配置正确
### 9.5 测试
- [ ] 已编写单元测试
- [ ] 已 mock TushareClient
- [ ] 测试覆盖正常和异常情况
---
**最后更新**: 2026-02-01

View File

@@ -0,0 +1,40 @@
"""Tushare API wrapper modules.
This package contains simplified interfaces for fetching data from Tushare API.
All wrapper files follow the naming convention: api_{data_type}.py
Available APIs:
- api_daily: Daily market data (日线行情)
- api_stock_basic: Stock basic information (股票基本信息)
- api_trade_cal: Trading calendar (交易日历)
Example:
>>> from src.data.api_wrappers import get_daily, get_stock_basic, get_trade_cal
>>> data = get_daily('000001.SZ', start_date='20240101', end_date='20240131')
>>> stocks = get_stock_basic()
>>> calendar = get_trade_cal('20240101', '20240131')
"""
from src.data.api_wrappers.api_daily import get_daily
from src.data.api_wrappers.api_stock_basic import get_stock_basic, sync_all_stocks
from src.data.api_wrappers.api_trade_cal import (
get_trade_cal,
get_trading_days,
get_first_trading_day,
get_last_trading_day,
sync_trade_cal_cache,
)
__all__ = [
# Daily market data
"get_daily",
# Stock basic information
"get_stock_basic",
"sync_all_stocks",
# Trade calendar
"get_trade_cal",
"get_trading_days",
"get_first_trading_day",
"get_last_trading_day",
"sync_trade_cal_cache",
]

View File

@@ -179,4 +179,75 @@ df = pro.query('trade_cal', start_date='20180101', end_date='20181231')
17 SSE 20180118 1
18 SSE 20180119 1
19 SSE 20180120 0
20 SSE 20180121 0
每日指标
接口daily_basic可以通过数据工具调试和查看数据。
更新时间交易日每日15点17点之间
描述获取全部股票每日重要的基本面指标可用于选股分析、报表展示等。单次请求最大返回6000条数据可按日线循环提取全部历史。
积分至少2000积分才可以调取5000积分无总量限制具体请参阅积分获取办法
输入参数
名称 类型 必选 描述
ts_code str Y 股票代码(二选一)
trade_date str N 交易日期 (二选一)
start_date str N 开始日期(YYYYMMDD)
end_date str N 结束日期(YYYYMMDD)
日期都填YYYYMMDD格式比如20181010
输出参数
名称 类型 描述
ts_code str TS股票代码
trade_date str 交易日期
close float 当日收盘价
turnover_rate float 换手率(%
turnover_rate_f float 换手率(自由流通股)
volume_ratio float 量比
pe float 市盈率(总市值/净利润, 亏损的PE为空
pe_ttm float 市盈率TTM亏损的PE为空
pb float 市净率(总市值/净资产)
ps float 市销率
ps_ttm float 市销率TTM
dv_ratio float 股息率 %
dv_ttm float 股息率TTM%
total_share float 总股本 (万股)
float_share float 流通股本 (万股)
free_share float 自由流通股本 (万)
total_mv float 总市值 (万元)
circ_mv float 流通市值(万元)
接口用法
pro = ts.pro_api()
df = pro.daily_basic(ts_code='', trade_date='20180726', fields='ts_code,trade_date,turnover_rate,volume_ratio,pe,pb')
或者
df = pro.query('daily_basic', ts_code='', trade_date='20180726',fields='ts_code,trade_date,turnover_rate,volume_ratio,pe,pb')
数据样例
ts_code trade_date turnover_rate volume_ratio pe pb
0 600230.SH 20180726 2.4584 0.72 8.6928 3.7203
1 600237.SH 20180726 1.4737 0.88 166.4001 1.8868
2 002465.SZ 20180726 0.7489 0.72 71.8943 2.6391
3 300732.SZ 20180726 6.7083 0.77 21.8101 3.2513
4 600007.SH 20180726 0.0381 0.61 23.7696 2.3774
5 300068.SZ 20180726 1.4583 0.52 27.8166 1.7549
6 300552.SZ 20180726 2.0728 0.95 56.8004 2.9279
7 601369.SH 20180726 0.2088 0.95 44.1163 1.8001
8 002518.SZ 20180726 0.5814 0.76 15.1004 2.5626
9 002913.SZ 20180726 12.1096 1.03 33.1279 2.9217
10 601818.SH 20180726 0.1893 0.86 6.3064 0.7209
11 600926.SH 20180726 0.6065 0.46 9.1772 0.9808
12 002166.SZ 20180726 0.7582 0.82 16.9868 3.3452
13 600841.SH 20180726 0.3754 1.02 66.2647 2.2302
14 300634.SZ 20180726 23.1127 1.26 120.3053 14.3168
15 300126.SZ 20180726 1.2304 1.11 348.4306 1.5171
16 300718.SZ 20180726 17.6612 0.92 32.0239 3.8661
17 000708.SZ 20180726 0.5575 0.70 10.3674 1.0276
18 002626.SZ 20180726 0.6187 0.83 22.7580 4.2446
19 600816.SH 20180726 0.6745 0.65 11.0778 3.2214

View File

@@ -3,6 +3,7 @@
A single function to fetch A股日线行情 data from Tushare.
Supports all output fields including tor (换手率) and vr (量比).
"""
import pandas as pd
from typing import Optional, List, Literal
from src.data.client import TushareClient
@@ -33,7 +34,7 @@ def get_daily(
Returns:
pd.DataFrame with daily market data containing:
- Base fields: ts_code, trade_date, open, high, low, close, pre_close,
- Base fields: ts_code, trade_date, open, high, low, close, pre_close,
change, pct_chg, vol, amount
- Factor fields (if requested): tor, vr
- Adjustment factor (if adjfactor=True): adjfactor

View File

@@ -3,6 +3,7 @@
Fetch basic stock information including code, name, listing date, etc.
This is a special interface - call once to get all stocks (listed and delisted).
"""
import os
import pandas as pd
from pathlib import Path

View File

@@ -1,4 +1,5 @@
"""Simplified HDF5 storage for data persistence."""
import os
import pandas as pd
from pathlib import Path
@@ -47,7 +48,9 @@ class Storage:
# Merge with existing data
existing = store[name]
combined = pd.concat([existing, data], ignore_index=True)
combined = combined.drop_duplicates(subset=["ts_code", "trade_date"], keep="last")
combined = combined.drop_duplicates(
subset=["ts_code", "trade_date"], keep="last"
)
store.put(name, combined, format="table")
print(f"[Storage] Saved {len(data)} rows to {file_path}")
@@ -57,10 +60,13 @@ class Storage:
print(f"[Storage] Error saving {name}: {e}")
return {"status": "error", "error": str(e)}
def load(self, name: str,
start_date: Optional[str] = None,
end_date: Optional[str] = None,
ts_code: Optional[str] = None) -> pd.DataFrame:
def load(
self,
name: str,
start_date: Optional[str] = None,
end_date: Optional[str] = None,
ts_code: Optional[str] = None,
) -> pd.DataFrame:
"""Load data from HDF5 file.
Args:
@@ -80,14 +86,25 @@ class Storage:
try:
with pd.HDFStore(file_path, mode="r") as store:
if name not in store.keys():
keys = store.keys()
# Handle both '/daily' and 'daily' keys
actual_key = None
if name in keys:
actual_key = name
elif f"/{name}" in keys:
actual_key = f"/{name}"
if actual_key is None:
return pd.DataFrame()
data = store[name]
data = store[actual_key]
# Apply filters
if start_date and end_date and "trade_date" in data.columns:
data = data[(data["trade_date"] >= start_date) & (data["trade_date"] <= end_date)]
data = data[
(data["trade_date"] >= start_date)
& (data["trade_date"] <= end_date)
]
if ts_code and "ts_code" in data.columns:
data = data[data["ts_code"] == ts_code]

View File

@@ -5,11 +5,15 @@ This module provides data fetching functions with intelligent sync logic:
- If local file exists: incremental update (fetch from latest date + 1 day)
- Multi-threaded concurrent fetching for improved performance
- Stop immediately on any exception
- Preview mode: check data volume and samples before actual sync
Currently supported data types:
- daily: Daily market data (with turnover rate and volume ratio)
Usage:
# Preview sync (check data volume and samples without writing)
preview_sync()
# Sync all stocks (full load)
sync_all()
@@ -18,6 +22,9 @@ Usage:
# Force full reload
sync_all(force_full=True)
# Dry run (preview only, no write)
sync_all(dry_run=True)
"""
import pandas as pd
@@ -30,8 +37,8 @@ import sys
from src.data.client import TushareClient
from src.data.storage import Storage
from src.data.daily import get_daily
from src.data.trade_cal import (
from src.data.api_wrappers import get_daily
from src.data.api_wrappers import (
get_first_trading_day,
get_last_trading_day,
sync_trade_cal_cache,
@@ -114,7 +121,8 @@ class DataSync:
List of stock codes
"""
# Import sync_all_stocks here to avoid circular imports
from src.data.stock_basic import sync_all_stocks, _get_csv_path
from src.data.api_wrappers import sync_all_stocks
from src.data.api_wrappers.api_stock_basic import _get_csv_path
# First, ensure stock_basic.csv is up-to-date with all stocks
print("[DataSync] Ensuring stock_basic.csv is up-to-date...")
@@ -278,6 +286,184 @@ class DataSync:
print(f"[DataSync] Incremental sync needed from {sync_start} to {cal_last}")
return (True, sync_start, cal_last, local_last_date)
def preview_sync(
self,
force_full: bool = False,
start_date: Optional[str] = None,
end_date: Optional[str] = None,
sample_size: int = 3,
) -> dict:
"""Preview sync data volume and samples without actually syncing.
This method provides a preview of what would be synced, including:
- Number of stocks to be synced
- Date range for sync
- Estimated total records
- Sample data from first few stocks
Args:
force_full: If True, preview full sync from 20180101
start_date: Manual start date (overrides auto-detection)
end_date: Manual end date (defaults to today)
sample_size: Number of sample stocks to fetch for preview (default: 3)
Returns:
Dictionary with preview information:
{
'sync_needed': bool,
'stock_count': int,
'start_date': str,
'end_date': str,
'estimated_records': int,
'sample_data': pd.DataFrame,
'mode': str, # 'full' or 'incremental'
}
"""
print("\n" + "=" * 60)
print("[DataSync] Preview Mode - Analyzing sync requirements...")
print("=" * 60)
# First, ensure trade calendar cache is up-to-date
print("[DataSync] Syncing trade calendar cache...")
sync_trade_cal_cache()
# Determine date range
if end_date is None:
end_date = get_today_date()
# Check if sync is needed
sync_needed, cal_start, cal_end, local_last = self.check_sync_needed(force_full)
if not sync_needed:
print("\n" + "=" * 60)
print("[DataSync] Preview Result")
print("=" * 60)
print(" Sync Status: NOT NEEDED")
print(" Reason: Local data is up-to-date with trade calendar")
print("=" * 60)
return {
"sync_needed": False,
"stock_count": 0,
"start_date": None,
"end_date": None,
"estimated_records": 0,
"sample_data": pd.DataFrame(),
"mode": "none",
}
# Use dates from check_sync_needed
if cal_start and cal_end:
sync_start_date = cal_start
end_date = cal_end
else:
sync_start_date = start_date or DEFAULT_START_DATE
if end_date is None:
end_date = get_today_date()
# Determine sync mode
if force_full:
mode = "full"
print(f"[DataSync] Mode: FULL SYNC from {sync_start_date} to {end_date}")
elif local_last and cal_start and sync_start_date == get_next_date(local_last):
mode = "incremental"
print(f"[DataSync] Mode: INCREMENTAL SYNC (bandwidth optimized)")
print(f"[DataSync] Sync from: {sync_start_date} to {end_date}")
else:
mode = "partial"
print(f"[DataSync] Mode: SYNC from {sync_start_date} to {end_date}")
# Get all stock codes
stock_codes = self.get_all_stock_codes()
if not stock_codes:
print("[DataSync] No stocks found to sync")
return {
"sync_needed": False,
"stock_count": 0,
"start_date": None,
"end_date": None,
"estimated_records": 0,
"sample_data": pd.DataFrame(),
"mode": "none",
}
stock_count = len(stock_codes)
print(f"[DataSync] Total stocks to sync: {stock_count}")
# Fetch sample data from first few stocks
print(f"[DataSync] Fetching sample data from {sample_size} stocks...")
sample_data_list = []
sample_codes = stock_codes[:sample_size]
for ts_code in sample_codes:
try:
data = self.client.query(
"pro_bar",
ts_code=ts_code,
start_date=sync_start_date,
end_date=end_date,
factors="tor,vr",
)
if not data.empty:
sample_data_list.append(data)
print(f" - {ts_code}: {len(data)} records")
except Exception as e:
print(f" - {ts_code}: Error fetching - {e}")
# Combine sample data
sample_df = (
pd.concat(sample_data_list, ignore_index=True)
if sample_data_list
else pd.DataFrame()
)
# Estimate total records based on sample
if not sample_df.empty:
avg_records_per_stock = len(sample_df) / len(sample_data_list)
estimated_records = int(avg_records_per_stock * stock_count)
else:
estimated_records = 0
# Display preview results
print("\n" + "=" * 60)
print("[DataSync] Preview Result")
print("=" * 60)
print(f" Sync Mode: {mode.upper()}")
print(f" Date Range: {sync_start_date} to {end_date}")
print(f" Stocks to Sync: {stock_count}")
print(f" Sample Stocks Checked: {len(sample_data_list)}/{sample_size}")
print(f" Estimated Total Records: ~{estimated_records:,}")
if not sample_df.empty:
print(f"\n Sample Data Preview (first {len(sample_df)} rows):")
print(" " + "-" * 56)
# Display sample data in a compact format
preview_cols = [
"ts_code",
"trade_date",
"open",
"high",
"low",
"close",
"vol",
]
available_cols = [c for c in preview_cols if c in sample_df.columns]
sample_display = sample_df[available_cols].head(10)
for idx, row in sample_display.iterrows():
print(f" {row.to_dict()}")
print(" " + "-" * 56)
print("=" * 60)
return {
"sync_needed": True,
"stock_count": stock_count,
"start_date": sync_start_date,
"end_date": end_date,
"estimated_records": estimated_records,
"sample_data": sample_df,
"mode": mode,
}
def sync_single_stock(
self,
ts_code: str,
@@ -320,6 +506,7 @@ class DataSync:
start_date: Optional[str] = None,
end_date: Optional[str] = None,
max_workers: Optional[int] = None,
dry_run: bool = False,
) -> Dict[str, pd.DataFrame]:
"""Sync daily data for all stocks in local storage.
@@ -337,9 +524,10 @@ class DataSync:
start_date: Manual start date (overrides auto-detection)
end_date: Manual end date (defaults to today)
max_workers: Number of worker threads (default: 10)
dry_run: If True, only preview what would be synced without writing data
Returns:
Dict mapping ts_code to DataFrame (empty if sync skipped)
Dict mapping ts_code to DataFrame (empty if sync skipped or dry_run)
"""
print("\n" + "=" * 60)
print("[DataSync] Starting daily data sync...")
@@ -378,11 +566,14 @@ class DataSync:
# Determine sync mode
if force_full:
mode = "full"
print(f"[DataSync] Mode: FULL SYNC from {sync_start_date} to {end_date}")
elif local_last and cal_start and sync_start_date == get_next_date(local_last):
mode = "incremental"
print(f"[DataSync] Mode: INCREMENTAL SYNC (bandwidth optimized)")
print(f"[DataSync] Sync from: {sync_start_date} to {end_date}")
else:
mode = "partial"
print(f"[DataSync] Mode: SYNC from {sync_start_date} to {end_date}")
# Get all stock codes
@@ -394,6 +585,17 @@ class DataSync:
print(f"[DataSync] Total stocks to sync: {len(stock_codes)}")
print(f"[DataSync] Using {max_workers or self.max_workers} worker threads")
# Handle dry run mode
if dry_run:
print("\n" + "=" * 60)
print("[DataSync] DRY RUN MODE - No data will be written")
print("=" * 60)
print(f" Would sync {len(stock_codes)} stocks")
print(f" Date range: {sync_start_date} to {end_date}")
print(f" Mode: {mode}")
print("=" * 60)
return {}
# Reset stop flag for new sync
self._stop_flag.set()
@@ -492,11 +694,62 @@ class DataSync:
# Convenience functions
def preview_sync(
    force_full: bool = False,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
    sample_size: int = 3,
    max_workers: Optional[int] = None,
) -> dict:
    """Module-level convenience wrapper around ``DataSync.preview_sync``.

    Builds a throwaway ``DataSync`` instance and asks it what a sync run
    would do, without writing anything to storage. This is the recommended
    first step before calling ``sync_all()``.

    Args:
        force_full: If True, preview a full sync starting from 20180101.
        start_date: Manual start date (YYYYMMDD); overrides auto-detection.
        end_date: Manual end date (defaults to today).
        sample_size: Number of sample stocks fetched for the preview (default: 3).
        max_workers: Accepted only for API symmetry with ``sync_all``; the
            preview itself does not use worker threads.

    Returns:
        Dictionary describing the pending sync:
            {
                'sync_needed': bool,
                'stock_count': int,
                'start_date': str,
                'end_date': str,
                'estimated_records': int,
                'sample_data': pd.DataFrame,
                'mode': str,  # 'full', 'incremental', 'partial', or 'none'
            }

    Example:
        >>> preview = preview_sync()                 # check pending work
        >>> preview = preview_sync(force_full=True)  # preview a full reload
        >>> preview = preview_sync(sample_size=5)    # fetch more samples
    """
    return DataSync(max_workers=max_workers).preview_sync(
        force_full=force_full,
        start_date=start_date,
        end_date=end_date,
        sample_size=sample_size,
    )
def sync_all(
force_full: bool = False,
start_date: Optional[str] = None,
end_date: Optional[str] = None,
max_workers: Optional[int] = None,
dry_run: bool = False,
) -> Dict[str, pd.DataFrame]:
"""Sync daily data for all stocks.
@@ -507,6 +760,7 @@ def sync_all(
start_date: Manual start date (YYYYMMDD)
end_date: Manual end date (defaults to today)
max_workers: Number of worker threads (default: 10)
dry_run: If True, only preview what would be synced without writing data
Returns:
Dict mapping ts_code to DataFrame
@@ -526,12 +780,16 @@ def sync_all(
>>>
>>> # Custom thread count
>>> result = sync_all(max_workers=20)
>>>
>>> # Dry run (preview only)
>>> result = sync_all(dry_run=True)
"""
sync_manager = DataSync(max_workers=max_workers)
return sync_manager.sync_all(
force_full=force_full,
start_date=start_date,
end_date=end_date,
dry_run=dry_run,
)
@@ -540,11 +798,32 @@ if __name__ == "__main__":
print("Data Sync Module")
print("=" * 60)
print("\nUsage:")
print(" from src.data.sync import sync_all")
print(" from src.data.sync import sync_all, preview_sync")
print("")
print(" # Preview before sync (recommended)")
print(" preview = preview_sync()")
print("")
print(" # Dry run (preview only)")
print(" result = sync_all(dry_run=True)")
print("")
print(" # Actual sync")
print(" result = sync_all() # Incremental sync")
print(" result = sync_all(force_full=True) # Full reload")
print("\n" + "=" * 60)
# Run sync
result = sync_all()
print(f"\nSynced {len(result)} stocks")
# Run preview first
print("\n[Main] Running preview first...")
preview = preview_sync()
if preview["sync_needed"]:
# Ask for confirmation
print("\n" + "=" * 60)
response = input("Proceed with sync? (y/n): ").strip().lower()
if response in ("y", "yes"):
print("\n[Main] Starting actual sync...")
result = sync_all()
print(f"\nSynced {len(result)} stocks")
else:
print("\n[Main] Sync cancelled by user")
else:
print("\n[Main] No sync needed - data is up to date")

View File

@@ -5,29 +5,30 @@ Tests the daily interface implementation against api.md requirements:
- tor 换手率
- vr 量比
"""
import pytest
import pandas as pd
from src.data.daily import get_daily
from src.data.api_wrappers import get_daily
# Expected output fields according to api.md
EXPECTED_BASE_FIELDS = [
'ts_code', # 股票代码
'trade_date', # 交易日期
'open', # 开盘价
'high', # 最高价
'low', # 最低价
'close', # 收盘价
'pre_close', # 昨收价
'change', # 涨跌额
'pct_chg', # 涨跌幅
'vol', # 成交量
'amount', # 成交额
"ts_code", # 股票代码
"trade_date", # 交易日期
"open", # 开盘价
"high", # 最高价
"low", # 最低价
"close", # 收盘价
"pre_close", # 昨收价
"change", # 涨跌额
"pct_chg", # 涨跌幅
"vol", # 成交量
"amount", # 成交额
]
EXPECTED_FACTOR_FIELDS = [
'turnover_rate', # 换手率 (tor)
'volume_ratio', # 量比 (vr)
"turnover_rate", # 换手率 (tor)
"volume_ratio", # 量比 (vr)
]
@@ -36,19 +37,19 @@ class TestGetDaily:
def test_fetch_basic(self):
"""Test basic daily data fetch with real API."""
result = get_daily('000001.SZ', start_date='20240101', end_date='20240131')
result = get_daily("000001.SZ", start_date="20240101", end_date="20240131")
assert isinstance(result, pd.DataFrame)
assert len(result) >= 1
assert result['ts_code'].iloc[0] == '000001.SZ'
assert result["ts_code"].iloc[0] == "000001.SZ"
def test_fetch_with_factors(self):
"""Test fetch with tor and vr factors."""
result = get_daily(
'000001.SZ',
start_date='20240101',
end_date='20240131',
factors=['tor', 'vr'],
"000001.SZ",
start_date="20240101",
end_date="20240131",
factors=["tor", "vr"],
)
assert isinstance(result, pd.DataFrame)
@@ -61,25 +62,26 @@ class TestGetDaily:
def test_output_fields_completeness(self):
"""Verify all required output fields are returned."""
result = get_daily('600000.SH')
result = get_daily("600000.SH")
# Verify all base fields are present
assert set(EXPECTED_BASE_FIELDS).issubset(result.columns.tolist()), \
assert set(EXPECTED_BASE_FIELDS).issubset(result.columns.tolist()), (
f"Missing fields: {set(EXPECTED_BASE_FIELDS) - set(result.columns)}"
)
def test_empty_result(self):
"""Test handling of empty results."""
# 使用真实 API 测试无效股票代码的空结果
result = get_daily('INVALID.SZ')
result = get_daily("INVALID.SZ")
assert isinstance(result, pd.DataFrame)
assert result.empty
def test_date_range_query(self):
"""Test query with date range."""
result = get_daily(
'000001.SZ',
start_date='20240101',
end_date='20240131',
"000001.SZ",
start_date="20240101",
end_date="20240131",
)
assert isinstance(result, pd.DataFrame)
@@ -87,7 +89,7 @@ class TestGetDaily:
def test_with_adj(self):
"""Test fetch with adjustment type."""
result = get_daily('000001.SZ', adj='qfq')
result = get_daily("000001.SZ", adj="qfq")
assert isinstance(result, pd.DataFrame)
@@ -95,11 +97,14 @@ class TestGetDaily:
def test_integration():
"""Integration test with real Tushare API (requires valid token)."""
import os
token = os.environ.get('TUSHARE_TOKEN')
token = os.environ.get("TUSHARE_TOKEN")
if not token:
pytest.skip("TUSHARE_TOKEN not configured")
result = get_daily('000001.SZ', start_date='20240101', end_date='20240131', factors=['tor', 'vr'])
result = get_daily(
"000001.SZ", start_date="20240101", end_date="20240131", factors=["tor", "vr"]
)
# Verify structure
assert isinstance(result, pd.DataFrame)
@@ -112,6 +117,6 @@ def test_integration():
assert field in result.columns, f"Missing factor field: {field}"
if __name__ == '__main__':
if __name__ == "__main__":
# 运行 pytest 单元测试真实API调用
pytest.main([__file__, '-v'])
pytest.main([__file__, "-v"])

View File

@@ -9,7 +9,7 @@ import pytest
import pandas as pd
from pathlib import Path
from src.data.storage import Storage
from src.data.stock_basic import _get_csv_path
from src.data.api_wrappers.api_stock_basic import _get_csv_path
class TestDailyStorageValidation:

View File

@@ -5,6 +5,7 @@ Tests the sync module's full/incremental sync logic for daily data:
- Incremental sync when local data exists (from last_date + 1)
- Data integrity validation
"""
import pytest
import pandas as pd
from unittest.mock import Mock, patch, MagicMock
@@ -17,6 +18,8 @@ from src.data.sync import (
get_next_date,
DEFAULT_START_DATE,
)
from src.data.storage import Storage
from src.data.client import TushareClient
class TestDateUtilities:
@@ -63,30 +66,32 @@ class TestDataSync:
def test_get_all_stock_codes_from_daily(self, mock_storage):
"""Test getting stock codes from daily data."""
with patch('src.data.sync.Storage', return_value=mock_storage):
with patch("src.data.sync.Storage", return_value=mock_storage):
sync = DataSync()
sync.storage = mock_storage
mock_storage.load.return_value = pd.DataFrame({
'ts_code': ['000001.SZ', '000001.SZ', '600000.SH'],
})
mock_storage.load.return_value = pd.DataFrame(
{
"ts_code": ["000001.SZ", "000001.SZ", "600000.SH"],
}
)
codes = sync.get_all_stock_codes()
assert len(codes) == 2
assert '000001.SZ' in codes
assert '600000.SH' in codes
assert "000001.SZ" in codes
assert "600000.SH" in codes
def test_get_all_stock_codes_fallback(self, mock_storage):
"""Test fallback to stock_basic when daily is empty."""
with patch('src.data.sync.Storage', return_value=mock_storage):
with patch("src.data.sync.Storage", return_value=mock_storage):
sync = DataSync()
sync.storage = mock_storage
# First call (daily) returns empty, second call (stock_basic) returns data
mock_storage.load.side_effect = [
pd.DataFrame(), # daily empty
pd.DataFrame({'ts_code': ['000001.SZ', '600000.SH']}), # stock_basic
pd.DataFrame({"ts_code": ["000001.SZ", "600000.SH"]}), # stock_basic
]
codes = sync.get_all_stock_codes()
@@ -95,21 +100,23 @@ class TestDataSync:
def test_get_global_last_date(self, mock_storage):
"""Test getting global last date."""
with patch('src.data.sync.Storage', return_value=mock_storage):
with patch("src.data.sync.Storage", return_value=mock_storage):
sync = DataSync()
sync.storage = mock_storage
mock_storage.load.return_value = pd.DataFrame({
'ts_code': ['000001.SZ', '600000.SH'],
'trade_date': ['20240102', '20240103'],
})
mock_storage.load.return_value = pd.DataFrame(
{
"ts_code": ["000001.SZ", "600000.SH"],
"trade_date": ["20240102", "20240103"],
}
)
last_date = sync.get_global_last_date()
assert last_date == '20240103'
assert last_date == "20240103"
def test_get_global_last_date_empty(self, mock_storage):
"""Test getting last date from empty storage."""
with patch('src.data.sync.Storage', return_value=mock_storage):
with patch("src.data.sync.Storage", return_value=mock_storage):
sync = DataSync()
sync.storage = mock_storage
@@ -120,18 +127,23 @@ class TestDataSync:
def test_sync_single_stock(self, mock_storage):
"""Test syncing a single stock."""
with patch('src.data.sync.Storage', return_value=mock_storage):
with patch('src.data.sync.get_daily', return_value=pd.DataFrame({
'ts_code': ['000001.SZ'],
'trade_date': ['20240102'],
})):
with patch("src.data.sync.Storage", return_value=mock_storage):
with patch(
"src.data.sync.get_daily",
return_value=pd.DataFrame(
{
"ts_code": ["000001.SZ"],
"trade_date": ["20240102"],
}
),
):
sync = DataSync()
sync.storage = mock_storage
result = sync.sync_single_stock(
ts_code='000001.SZ',
start_date='20240101',
end_date='20240102',
ts_code="000001.SZ",
start_date="20240101",
end_date="20240102",
)
assert isinstance(result, pd.DataFrame)
@@ -139,15 +151,15 @@ class TestDataSync:
def test_sync_single_stock_empty(self, mock_storage):
"""Test syncing a stock with no data."""
with patch('src.data.sync.Storage', return_value=mock_storage):
with patch('src.data.sync.get_daily', return_value=pd.DataFrame()):
with patch("src.data.sync.Storage", return_value=mock_storage):
with patch("src.data.sync.get_daily", return_value=pd.DataFrame()):
sync = DataSync()
sync.storage = mock_storage
result = sync.sync_single_stock(
ts_code='INVALID.SZ',
start_date='20240101',
end_date='20240102',
ts_code="INVALID.SZ",
start_date="20240101",
end_date="20240102",
)
assert result.empty
@@ -158,40 +170,46 @@ class TestSyncAll:
def test_full_sync_mode(self, mock_storage):
"""Test full sync mode when force_full=True."""
with patch('src.data.sync.Storage', return_value=mock_storage):
with patch('src.data.sync.get_daily', return_value=pd.DataFrame()):
with patch("src.data.sync.Storage", return_value=mock_storage):
with patch("src.data.sync.get_daily", return_value=pd.DataFrame()):
sync = DataSync()
sync.storage = mock_storage
sync.sync_single_stock = Mock(return_value=pd.DataFrame())
mock_storage.load.return_value = pd.DataFrame({
'ts_code': ['000001.SZ'],
})
mock_storage.load.return_value = pd.DataFrame(
{
"ts_code": ["000001.SZ"],
}
)
result = sync.sync_all(force_full=True)
# Verify sync_single_stock was called with default start date
sync.sync_single_stock.assert_called_once()
call_args = sync.sync_single_stock.call_args
assert call_args[1]['start_date'] == DEFAULT_START_DATE
assert call_args[1]["start_date"] == DEFAULT_START_DATE
def test_incremental_sync_mode(self, mock_storage):
"""Test incremental sync mode when data exists."""
with patch('src.data.sync.Storage', return_value=mock_storage):
with patch("src.data.sync.Storage", return_value=mock_storage):
sync = DataSync()
sync.storage = mock_storage
sync.sync_single_stock = Mock(return_value=pd.DataFrame())
# Mock existing data with last date
mock_storage.load.side_effect = [
pd.DataFrame({
'ts_code': ['000001.SZ'],
'trade_date': ['20240102'],
}), # get_all_stock_codes
pd.DataFrame({
'ts_code': ['000001.SZ'],
'trade_date': ['20240102'],
}), # get_global_last_date
pd.DataFrame(
{
"ts_code": ["000001.SZ"],
"trade_date": ["20240102"],
}
), # get_all_stock_codes
pd.DataFrame(
{
"ts_code": ["000001.SZ"],
"trade_date": ["20240102"],
}
), # get_global_last_date
]
result = sync.sync_all(force_full=False)
@@ -199,28 +217,30 @@ class TestSyncAll:
# Verify sync_single_stock was called with next date
sync.sync_single_stock.assert_called_once()
call_args = sync.sync_single_stock.call_args
assert call_args[1]['start_date'] == '20240103'
assert call_args[1]["start_date"] == "20240103"
def test_manual_start_date(self, mock_storage):
"""Test sync with manual start date."""
with patch('src.data.sync.Storage', return_value=mock_storage):
with patch("src.data.sync.Storage", return_value=mock_storage):
sync = DataSync()
sync.storage = mock_storage
sync.sync_single_stock = Mock(return_value=pd.DataFrame())
mock_storage.load.return_value = pd.DataFrame({
'ts_code': ['000001.SZ'],
})
mock_storage.load.return_value = pd.DataFrame(
{
"ts_code": ["000001.SZ"],
}
)
result = sync.sync_all(force_full=False, start_date='20230601')
result = sync.sync_all(force_full=False, start_date="20230601")
sync.sync_single_stock.assert_called_once()
call_args = sync.sync_single_stock.call_args
assert call_args[1]['start_date'] == '20230601'
assert call_args[1]["start_date"] == "20230601"
def test_no_stocks_found(self, mock_storage):
"""Test sync when no stocks are found."""
with patch('src.data.sync.Storage', return_value=mock_storage):
with patch("src.data.sync.Storage", return_value=mock_storage):
sync = DataSync()
sync.storage = mock_storage
@@ -236,7 +256,7 @@ class TestSyncAllConvenienceFunction:
def test_sync_all_function(self):
"""Test sync_all convenience function."""
with patch('src.data.sync.DataSync') as MockSync:
with patch("src.data.sync.DataSync") as MockSync:
mock_instance = Mock()
mock_instance.sync_all.return_value = {}
MockSync.return_value = mock_instance
@@ -251,5 +271,5 @@ class TestSyncAllConvenienceFunction:
)
if __name__ == '__main__':
pytest.main([__file__, '-v'])
if __name__ == "__main__":
pytest.main([__file__, "-v"])

256
tests/test_sync_real.py Normal file
View File

@@ -0,0 +1,256 @@
"""Tests for data sync with REAL data (read-only).
Tests verify:
1. get_global_last_date() correctly reads local data's max date
2. Incremental sync date calculation (local_last_date + 1)
3. Full sync date calculation (20180101)
4. Multi-stock scenario with real data
⚠️ IMPORTANT: These tests ONLY read data, no write operations.
- NO sync_all() calls (writes daily.h5)
- NO check_sync_needed() calls (writes trade_cal.h5)
"""
import pytest
import pandas as pd
from pathlib import Path
from src.data.sync import (
DataSync,
get_next_date,
DEFAULT_START_DATE,
)
from src.data.storage import Storage
class TestDataSyncReadOnly:
    """Read-only tests for data sync - verify date calculation logic.

    Every test here only READS local storage (daily.h5); nothing is written,
    so the suite is safe to run against a real data directory.
    """

    @pytest.fixture
    def storage(self):
        """Create a Storage instance for direct file checks."""
        return Storage()

    @pytest.fixture
    def data_sync(self):
        """Create a DataSync instance under test."""
        return DataSync()

    @pytest.fixture
    def daily_exists(self, storage):
        """Return True if the local daily.h5 dataset exists."""
        return storage.exists("daily")

    def test_daily_h5_exists(self, storage):
        """Verify daily.h5 data file exists before running tests."""
        assert storage.exists("daily"), (
            "daily.h5 not found. Please run full sync first: "
            "uv run python -c 'from src.data.sync import sync_all; sync_all(force_full=True)'"
        )

    def test_get_global_last_date(self, data_sync, daily_exists):
        """Test get_global_last_date returns correct max date from local data."""
        if not daily_exists:
            pytest.skip("daily.h5 not found")

        last_date = data_sync.get_global_last_date()

        # Must be a valid 8-digit YYYYMMDD string
        assert last_date is not None, "get_global_last_date returned None"
        assert isinstance(last_date, str), f"Expected str, got {type(last_date)}"
        assert len(last_date) == 8, f"Expected 8-digit date, got {last_date}"
        assert last_date.isdigit(), f"Expected numeric date, got {last_date}"

        # Cross-check against the raw storage contents
        daily_data = data_sync.storage.load("daily")
        expected_max = str(daily_data["trade_date"].max())
        assert last_date == expected_max, (
            f"get_global_last_date returned {last_date}, "
            f"but actual max date is {expected_max}"
        )
        print(f"[TEST] Local data last date: {last_date}")

    def test_incremental_sync_date_calculation(self, data_sync, daily_exists):
        """Test incremental sync: start_date = local_last_date + 1 day.

        This verifies that when local data exists, incremental sync should
        fetch data from the calendar day after local_last_date, not from
        20180101.
        """
        if not daily_exists:
            pytest.skip("daily.h5 not found")

        from datetime import datetime, timedelta

        # Get local last date
        local_last_date = data_sync.get_global_last_date()
        assert local_last_date is not None, "No local data found"

        # Calculate expected incremental start date
        expected_start_date = get_next_date(local_last_date)

        # BUGFIX: the previous check used int(date) + 1, which breaks at
        # month/year boundaries (20240131 + 1 != 20240201). Use real date
        # arithmetic to compute the expected next calendar day instead.
        expected = (
            datetime.strptime(local_last_date, "%Y%m%d") + timedelta(days=1)
        ).strftime("%Y%m%d")
        assert expected_start_date == expected, (
            f"Incremental start date calculation error: "
            f"expected {expected}, got {expected_start_date}"
        )

        print(
            f"[TEST] Incremental sync: local_last={local_last_date}, "
            f"start_date should be {expected_start_date}"
        )

        # Verify this is NOT 20180101 (would be full sync)
        assert expected_start_date != DEFAULT_START_DATE, (
            f"Incremental sync should NOT start from {DEFAULT_START_DATE}"
        )

    def test_full_sync_date_calculation(self):
        """Test full sync: start_date = 20180101 when force_full=True.

        This verifies that force_full=True always starts from 20180101.
        """
        # Full sync should always use DEFAULT_START_DATE
        full_sync_start = DEFAULT_START_DATE
        assert full_sync_start == "20180101", (
            f"Full sync should start from 20180101, got {full_sync_start}"
        )
        print(f"[TEST] Full sync start date: {full_sync_start}")

    def test_date_comparison_logic(self, data_sync, daily_exists):
        """Test date comparison: incremental vs full sync selection logic.

        Verify that:
        - If local_last_date < today: incremental sync needed
        - If local_last_date >= today: no sync needed
        """
        if not daily_exists:
            pytest.skip("daily.h5 not found")

        from datetime import datetime

        local_last_date = data_sync.get_global_last_date()
        today = datetime.now().strftime("%Y%m%d")

        # YYYYMMDD strings compare correctly as integers (same digit count)
        local_last_int = int(local_last_date)
        today_int = int(today)

        print(
            f"[TEST] Date comparison: local_last={local_last_date} ({local_last_int}), "
            f"today={today} ({today_int})"
        )

        if local_last_int < today_int:
            print("[TEST] Local data is older than today - sync needed")
            # Incremental sync should fetch from local_last_date + 1
            sync_start = get_next_date(local_last_date)
            assert int(sync_start) > local_last_int, (
                "Sync start should be after local last"
            )
        else:
            print("[TEST] Local data is up-to-date - no sync needed")

    def test_get_all_stock_codes_real_data(self, data_sync, daily_exists):
        """Test get_all_stock_codes returns multiple real stock codes."""
        if not daily_exists:
            pytest.skip("daily.h5 not found")

        codes = data_sync.get_all_stock_codes()

        # Must be a non-empty list with enough stocks for multi-stock tests
        assert isinstance(codes, list), f"Expected list, got {type(codes)}"
        assert len(codes) > 0, "No stock codes found"
        assert len(codes) >= 10, (
            f"Expected at least 10 stocks for multi-stock test, got {len(codes)}"
        )

        # Verify format (should be like 000001.SZ, 600000.SH)
        sample_codes = codes[:5]
        for code in sample_codes:
            assert "." in code, f"Invalid stock code format: {code}"
            suffix = code.split(".")[-1]
            assert suffix in ["SZ", "SH"], f"Invalid exchange suffix: {suffix}"

        print(f"[TEST] Found {len(codes)} stock codes (sample: {sample_codes})")

    def test_multi_stock_date_range(self, data_sync, daily_exists):
        """Test that multiple stocks share the same date range in local data.

        This verifies that local data has consistent date coverage across stocks.
        """
        if not daily_exists:
            pytest.skip("daily.h5 not found")

        from datetime import datetime

        daily_data = data_sync.storage.load("daily")

        # Per-stock and global date coverage
        stock_dates = daily_data.groupby("ts_code")["trade_date"].agg(["min", "max"])
        global_min = str(daily_data["trade_date"].min())
        global_max = str(daily_data["trade_date"].max())

        print(f"[TEST] Global date range: {global_min} to {global_max}")
        print(f"[TEST] Total stocks: {len(stock_dates)}")

        # Verify we have data for multiple stocks
        assert len(stock_dates) >= 10, (
            f"Expected at least 10 stocks, got {len(stock_dates)}"
        )

        # BUGFIX: subtracting YYYYMMDD integers does not yield a day count
        # (20240115 - 20180101 = 60014). Use real date arithmetic so the
        # threshold below actually means ~100 calendar days of coverage.
        days_span = (
            datetime.strptime(global_max, "%Y%m%d")
            - datetime.strptime(global_min, "%Y%m%d")
        ).days
        assert days_span > 100, (
            f"Date range too small: {days_span} days. "
            f"Expected at least 100 days of data."
        )
        print(f"[TEST] Date span: {days_span} days")
class TestDateUtilities:
    """Unit tests for the pure date-helper logic used by the sync module."""

    def test_get_next_date(self):
        """get_next_date must roll over days, months, and years correctly."""
        expectations = {
            "20240101": "20240102",  # plain next day
            "20240131": "20240201",  # month boundary
            "20241231": "20250101",  # year boundary
        }
        for current_day, following_day in expectations.items():
            assert get_next_date(current_day) == following_day

    def test_incremental_vs_full_sync_logic(self):
        """Contrast the two sync start-date rules.

        Incremental: start_date = local_last_date + 1
        Full: start_date = 20180101
        """
        # Scenario 1: local data exists -> continue from the next day
        known_last_date = "20240115"
        resume_from = get_next_date(known_last_date)
        assert resume_from == "20240116"
        assert resume_from != DEFAULT_START_DATE

        # Scenario 2: forced full reload always restarts at the epoch date
        reload_from = DEFAULT_START_DATE  # "20180101"
        assert reload_from == "20180101"
        assert resume_from != reload_from

        print("[TEST] Incremental vs Full sync logic verified")
if __name__ == "__main__":
    # Allow running this test module directly; -s keeps the diagnostic
    # print() output visible alongside pytest's verbose report.
    pytest.main([__file__, "-v", "-s"])