Compare commits

...

3 Commits

Author SHA1 Message Date
9965ce5706 refactor: restructure the API interface modules into an api_wrappers directory
- Moved the standalone API modules (daily, stock_basic, trade_cal) into api_wrappers/
- Rewrote sync.py to use the new wrapper structure, adding more sync features
- Updated test files to match the new module layout
- Added a pytest.ini configuration file
2026-02-21 03:43:30 +08:00
e81d39ae0d chore: add .opencode to gitignore 2026-02-01 23:50:17 +08:00
8fc88b60e3 docs: add the Tushare API interface specification document 2026-02-01 23:50:03 +08:00
15 changed files with 1043 additions and 105 deletions

.gitignore vendored

@@ -46,6 +46,7 @@ logs/
# IDEs and editors
.vscode/
.idea/
.opencode/
*.swp
*.swo
*~

pytest.ini Normal file

@@ -0,0 +1,3 @@
[pytest]
pythonpath = .
testpaths = tests


@@ -6,7 +6,7 @@ Provides simplified interfaces for fetching and storing Tushare data.
from src.data.config import Config, get_config
from src.data.client import TushareClient
from src.data.storage import Storage
from src.data.stock_basic import get_stock_basic, sync_all_stocks
from src.data.api_wrappers import get_stock_basic, sync_all_stocks
__all__ = [
"Config",


@@ -0,0 +1,244 @@
# ProStock Data Interface Wrapper Specification
## 1. Overview
This document defines the standard for wrapping new Tushare API interfaces. All non-special interfaces must follow this specification to ensure:
- A consistent code style
- Automatic sync support
- Consistent incremental-update logic
- Reduced storage write pressure
## 2. Interface Classification
### 2.1 Special interfaces (excluded from unified sync)
The following interfaces have their own sync logic and do not participate in the automatic sync mechanism:
| Interface type | Example | Notes |
|---------|------|------|
| Trading calendar | `trade_cal` | Global data, fetched by date range |
| Stock basic info | `stock_basic` | One-off full fetch, stored as CSV |
| Auxiliary data | industry classification, concept classification | Updated infrequently, managed independently |
### 2.2 Standard interfaces (must follow this specification)
All factor, market, and financial data fetched per stock or per date must follow this specification.
## 3. File Structure Requirements
### 3.1 File naming
```
{data_type}.py
```
Examples: `daily.py`, `moneyflow.py`, `limit_list.py`
### 3.2 File location
All interface files must live under the `src/data/` directory.
### 3.3 Export requirements
New interfaces must be exported from `src/data/__init__.py`:
```python
from src.data.{module_name} import get_{data_type}
__all__ = [
    # ... other exports ...
    "get_{data_type}",
]
```
## 4. Interface Design Specification
### 4.1 Fetch function signature requirements
The function must return a `pd.DataFrame`, and its parameters must follow one of the patterns below:
#### 4.1.1 Date-based interfaces (preferred)
Applies to: limit-up/limit-down lists, dragon-tiger lists, chip distribution, etc.
**Required signature**:
```python
def get_{data_type}(
    trade_date: Optional[str] = None,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
    ts_code: Optional[str] = None,
    # other optional parameters...
) -> pd.DataFrame:
```
**Requirements**:
- Prefer `trade_date` for fetching full-market data for a single day
- Support `start_date + end_date` for fetching a date range
- `ts_code` is an optional filter parameter
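As an illustration, a date-based wrapper following this signature could look like the sketch below. The dataset name `limit_list` is hypothetical, and the client is passed in explicitly to keep the sketch self-contained; the real wrappers construct a `TushareClient` internally.

```python
from typing import Optional
import pandas as pd

def get_limit_list(
    client,  # expected to expose .query(api_name, **params) -> pd.DataFrame
    trade_date: Optional[str] = None,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
    ts_code: Optional[str] = None,
) -> pd.DataFrame:
    """Fetch limit-list data (sketch).

    Prefers trade_date for single-day full-market data; falls back to a
    start_date/end_date range. ts_code is applied as a client-side filter.
    """
    params = {}
    if trade_date:
        params["trade_date"] = trade_date
    else:
        if start_date:
            params["start_date"] = start_date
        if end_date:
            params["end_date"] = end_date
    df = client.query("limit_list", **params)
    if ts_code and not df.empty and "ts_code" in df.columns:
        df = df[df["ts_code"] == ts_code].reset_index(drop=True)
    return df
```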
#### 4.1.2 Stock-based interfaces
Applies to: daily quotes, money flow, etc.
**Required signature**:
```python
def get_{data_type}(
    ts_code: str,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
    # other optional parameters...
) -> pd.DataFrame:
```
### 4.2 Docstring requirements
Functions must carry a complete Google-style docstring containing:
- A description of what the function does
- An `Args` section describing every parameter
- A `Returns` section describing the fields in the returned DataFrame
- An `Example` section showing usage
### 4.3 Date format requirements
- All date parameters and return values use the `YYYYMMDD` string format
- `trade_date` is the canonical date field name
- If the API returns a different date field name (such as `date` or `end_date`), it must be renamed to `trade_date` before returning
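The rename-before-return rule can be sketched as a small normalization helper; the alias list below is an assumption based on the field names this section mentions:

```python
import pandas as pd

DATE_ALIASES = ("date", "end_date", "cal_date")  # assumed alias list

def normalize_trade_date(df: pd.DataFrame) -> pd.DataFrame:
    """Ensure the canonical trade_date column name before returning data."""
    if "trade_date" in df.columns:
        return df
    for alias in DATE_ALIASES:
        if alias in df.columns:
            return df.rename(columns={alias: "trade_date"})
    return df
```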
### 4.4 Stock code requirements
- `ts_code` is the canonical stock code field name
- Format: `{code}.{exchange}`, e.g. `000001.SZ`, `600000.SH`
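A quick format check can be sketched as follows; note the spec only shows `SZ` and `SH` examples, so including `BJ` in the suffix list is an assumption:

```python
import re

# SZ and SH come from the spec's examples; BJ is an assumption
TS_CODE_RE = re.compile(r"^\d{6}\.(SZ|SH|BJ)$")

def is_valid_ts_code(code: str) -> bool:
    """Check the {code}.{exchange} format, e.g. 000001.SZ."""
    return bool(TS_CODE_RE.match(code))
```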
### 4.5 Token bucket rate limiting
All API calls must go through `TushareClient`, which automatically enforces the token bucket rate limit.
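The token-bucket behavior that `TushareClient` is said to enforce can be illustrated with a minimal, generic limiter (a sketch, not the project's actual implementation):

```python
import time
import threading

class TokenBucket:
    """Allow at most `rate` calls per second, with bursts up to `capacity`."""

    def __init__(self, rate: float, capacity: int):
        self.rate = rate
        self.capacity = capacity
        self.tokens = float(capacity)
        self.last = time.monotonic()
        self.lock = threading.Lock()

    def acquire(self) -> None:
        """Block until a token is available, then consume it."""
        while True:
            with self.lock:
                now = time.monotonic()
                # Refill proportionally to elapsed time, capped at capacity
                self.tokens = min(self.capacity,
                                  self.tokens + (now - self.last) * self.rate)
                self.last = now
                if self.tokens >= 1:
                    self.tokens -= 1
                    return
                wait = (1 - self.tokens) / self.rate
            time.sleep(wait)
```

Each API call would invoke `acquire()` before hitting the network, so bursts up to `capacity` pass immediately and sustained traffic is throttled to `rate` calls per second.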
## 5. Sync Integration Specification
### 5.1 DATASET_CONFIG registration
New interfaces must be registered in `DataSync.DATASET_CONFIG` with the following entries:
```python
"{new_data_type}": {
    "api_name": "{tushare_api_name}",  # Tushare API name
    "fetch_by": "date",                # "date" or "stock"
    "date_field": "trade_date",
    "key_fields": ["ts_code", "trade_date"],  # primary key used for deduplication
}
```
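To illustrate how a generic driver might consume this registry, here is a speculative sketch. The `sync_dataset` signature and the `storage.save(..., mode="append")` call are assumptions based on the checklist in Section 9; the client and storage objects are passed in to keep the sketch self-contained.

```python
import pandas as pd

DATASET_CONFIG = {
    "limit_list": {                      # hypothetical dataset
        "api_name": "limit_list",
        "fetch_by": "date",
        "date_field": "trade_date",
        "key_fields": ["ts_code", "trade_date"],
    },
}

def sync_dataset(client, storage, data_type: str, start: str, end: str) -> pd.DataFrame:
    """Generic sync driver keyed off DATASET_CONFIG (sketch)."""
    cfg = DATASET_CONFIG[data_type]
    if cfg["fetch_by"] == "date":
        # One request covering the whole market for the date range
        df = client.query(cfg["api_name"], start_date=start, end_date=end)
    else:  # "stock": one request per ts_code
        frames = [
            client.query(cfg["api_name"], ts_code=code, start_date=start, end_date=end)
            for code in storage.load("stock_basic")["ts_code"].unique()
        ]
        df = pd.concat(frames, ignore_index=True) if frames else pd.DataFrame()
    if not df.empty:
        df = df.drop_duplicates(subset=cfg["key_fields"], keep="last")
        storage.save(data_type, df, mode="append")  # one batched write
    return df
```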
### 5.2 fetch_by rules
- **Prefer `"date"`**: use it whenever the API supports fetching full-market data by date
- Use `"stock"` only when the API cannot fetch by date
### 5.3 sync method requirements
A matching sync method must be implemented (or the generic method reused):
```python
def sync_{data_type}(self, force_full: bool = False) -> pd.DataFrame:
    """Sync {data description}."""
    return self.sync_dataset("{data_type}", force_full)
```
A convenience function must also be provided:
```python
def sync_{data_type}(force_full: bool = False) -> pd.DataFrame:
    """Sync {data description}."""
    sync_manager = DataSync()
    return sync_manager.sync_{data_type}(force_full)
```
### 5.4 Incremental update requirements
- Incremental update logic is mandatory (automatically detect the latest local date)
- A `force_full` parameter must allow forcing a full sync
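The incremental-start logic above can be sketched as follows; `DEFAULT_START_DATE = "20180101"` matches the full-sync baseline used elsewhere in this changeset, while the helper names are illustrative:

```python
import pandas as pd

DEFAULT_START_DATE = "20180101"  # full-sync baseline used by this project

def next_day(yyyymmdd: str) -> str:
    """Return the calendar day after a YYYYMMDD string."""
    return (pd.Timestamp(yyyymmdd) + pd.Timedelta(days=1)).strftime("%Y%m%d")

def resolve_start_date(local: pd.DataFrame, force_full: bool = False) -> str:
    """Pick the sync start date: latest local trade_date + 1 day,
    falling back to a full sync when there is no local data."""
    if force_full or local.empty or "trade_date" not in local.columns:
        return DEFAULT_START_DATE
    return next_day(local["trade_date"].max())
```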
## 6. Storage Specification
### 6.1 Storage format
All data is persisted to HDF5 via the `Storage` class.
### 6.2 Write strategy
**Requirement**: write all data **in a single batch** after the requests complete, never row by row.
### 6.3 Deduplication
Deduplicate on the fields configured in `key_fields`, defaulting to `["ts_code", "trade_date"]`.
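The merge-and-deduplicate step can be sketched independently of HDF5: combine the whole batch with the existing rows, then deduplicate once on the key fields, keeping the most recently fetched copy.

```python
import pandas as pd

def merge_batch(existing: pd.DataFrame, new_data: pd.DataFrame,
                key_fields=("ts_code", "trade_date")) -> pd.DataFrame:
    """Combine a full batch with existing rows and deduplicate once,
    keeping the most recently fetched copy of each key."""
    combined = pd.concat([existing, new_data], ignore_index=True)
    return (combined
            .drop_duplicates(subset=list(key_fields), keep="last")
            .reset_index(drop=True))
```

A single `store.put(...)` of the merged frame then replaces many small writes, which is the point of the batching requirement.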
## 7. Testing Specification
### 7.1 Test file requirement
A matching test file must be created: `tests/test_{data_type}.py`
### 7.2 Coverage requirements
- Test fetching by date
- Test fetching by stock (if supported)
- `TushareClient` must be mocked
- Cover both normal and error cases
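A minimal pytest-style sketch of the mocking requirement: a `unittest.mock.Mock` stands in for `TushareClient`, so the tests exercise the wrapper logic without network access. The helper name and the fixed data are illustrative only.

```python
import pandas as pd
from unittest.mock import Mock

def make_mock_client(rows: dict) -> Mock:
    """Build a TushareClient stand-in whose query() returns fixed data."""
    client = Mock()
    client.query.return_value = pd.DataFrame(rows)
    return client

def test_fetch_by_date_normal():
    client = make_mock_client({"ts_code": ["000001.SZ"],
                               "trade_date": ["20240102"]})
    df = client.query("daily", trade_date="20240102")
    assert not df.empty
    assert {"ts_code", "trade_date"}.issubset(df.columns)

def test_fetch_empty():
    client = make_mock_client({})
    df = client.query("daily", trade_date="19000101")
    assert df.empty
```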
## 8. Full Workflow for Adding an Interface
### 8.1 Create the interface file
1. Create `{data_type}.py` under `src/data/`
2. Implement the fetch function following Section 4
### 8.2 Register sync support
1. Register the dataset in `DataSync.DATASET_CONFIG` in `sync.py`
2. Implement the corresponding sync method
3. Provide the convenience function
### 8.3 Update exports
Export the interface function from `src/data/__init__.py`.
### 8.4 Create tests
Create `tests/test_{data_type}.py` covering the key scenarios.
## 9. Checklist
### 9.1 File structure
- [ ] File is located at `src/data/{data_type}.py`
- [ ] Public interface exported from `src/data/__init__.py`
- [ ] Test file `tests/test_{data_type}.py` created
### 9.2 Interface implementation
- [ ] Fetch function uses `TushareClient`
- [ ] Function has a complete Google-style docstring
- [ ] Date parameters use the `YYYYMMDD` format
- [ ] Returned DataFrame contains `ts_code` and `trade_date` fields
- [ ] Date-based fetching implemented where the API supports it
### 9.3 Sync integration
- [ ] Registered in `DataSync.DATASET_CONFIG`
- [ ] `fetch_by` set correctly ("date" or "stock")
- [ ] `date_field` and `key_fields` set correctly
- [ ] Corresponding sync method implemented or generic method reused
- [ ] Incremental update logic correct (checks latest local date)
### 9.4 Storage optimization
- [ ] All data written in a single batch (not row by row)
- [ ] Incremental saves use `storage.save(mode="append")`
- [ ] Deduplication fields configured correctly
### 9.5 Tests
- [ ] Unit tests written
- [ ] TushareClient mocked
- [ ] Normal and error cases covered
---
**Last updated**: 2026-02-01


@@ -0,0 +1,40 @@
"""Tushare API wrapper modules.
This package contains simplified interfaces for fetching data from Tushare API.
All wrapper files follow the naming convention: api_{data_type}.py
Available APIs:
- api_daily: Daily market data (日线行情)
- api_stock_basic: Stock basic information (股票基本信息)
- api_trade_cal: Trading calendar (交易日历)
Example:
>>> from src.data.api_wrappers import get_daily, get_stock_basic, get_trade_cal
>>> data = get_daily('000001.SZ', start_date='20240101', end_date='20240131')
>>> stocks = get_stock_basic()
>>> calendar = get_trade_cal('20240101', '20240131')
"""
from src.data.api_wrappers.api_daily import get_daily
from src.data.api_wrappers.api_stock_basic import get_stock_basic, sync_all_stocks
from src.data.api_wrappers.api_trade_cal import (
get_trade_cal,
get_trading_days,
get_first_trading_day,
get_last_trading_day,
sync_trade_cal_cache,
)
__all__ = [
# Daily market data
"get_daily",
# Stock basic information
"get_stock_basic",
"sync_all_stocks",
# Trade calendar
"get_trade_cal",
"get_trading_days",
"get_first_trading_day",
"get_last_trading_day",
"sync_trade_cal_cache",
]


@@ -179,4 +179,75 @@ df = pro.query('trade_cal', start_date='20180101', end_date='20181231')
17 SSE 20180118 1
18 SSE 20180119 1
19 SSE 20180120 0
20 SSE 20180121 0
Daily Indicators
Interface: daily_basic (can be debugged and inspected with the data tool)
Update time: between 15:00 and 17:00 on each trading day
Description: fetches key daily fundamental indicators for all stocks, usable for stock screening and report display. A single request returns at most 6000 rows; the full history can be retrieved by looping over trading days.
Points: at least 2000 points are required to call this interface; at 5000 points there is no total-volume limit. See the points guide for details.
Input parameters:

| Name | Type | Required | Description |
|------|------|----------|-------------|
| ts_code | str | Y | stock code (one of ts_code/trade_date required) |
| trade_date | str | N | trade date (one of ts_code/trade_date required) |
| start_date | str | N | start date (YYYYMMDD) |
| end_date | str | N | end date (YYYYMMDD) |

All dates use the YYYYMMDD format, e.g. 20181010.
Output parameters:

| Name | Type | Description |
|------|------|-------------|
| ts_code | str | TS stock code |
| trade_date | str | trade date |
| close | float | closing price |
| turnover_rate | float | turnover rate (%) |
| turnover_rate_f | float | turnover rate (free float) |
| volume_ratio | float | volume ratio |
| pe | float | P/E (total market cap / net profit; empty when loss-making) |
| pe_ttm | float | P/E TTM (empty when loss-making) |
| pb | float | P/B (total market cap / net assets) |
| ps | float | P/S |
| ps_ttm | float | P/S TTM |
| dv_ratio | float | dividend yield (%) |
| dv_ttm | float | dividend yield TTM (%) |
| total_share | float | total shares (10k shares) |
| float_share | float | float shares (10k shares) |
| free_share | float | free-float shares (10k shares) |
| total_mv | float | total market cap (10k CNY) |
| circ_mv | float | float market cap (10k CNY) |
Usage:
pro = ts.pro_api()
df = pro.daily_basic(ts_code='', trade_date='20180726', fields='ts_code,trade_date,turnover_rate,volume_ratio,pe,pb')
or:
df = pro.query('daily_basic', ts_code='', trade_date='20180726', fields='ts_code,trade_date,turnover_rate,volume_ratio,pe,pb')
Sample data:
ts_code trade_date turnover_rate volume_ratio pe pb
0 600230.SH 20180726 2.4584 0.72 8.6928 3.7203
1 600237.SH 20180726 1.4737 0.88 166.4001 1.8868
2 002465.SZ 20180726 0.7489 0.72 71.8943 2.6391
3 300732.SZ 20180726 6.7083 0.77 21.8101 3.2513
4 600007.SH 20180726 0.0381 0.61 23.7696 2.3774
5 300068.SZ 20180726 1.4583 0.52 27.8166 1.7549
6 300552.SZ 20180726 2.0728 0.95 56.8004 2.9279
7 601369.SH 20180726 0.2088 0.95 44.1163 1.8001
8 002518.SZ 20180726 0.5814 0.76 15.1004 2.5626
9 002913.SZ 20180726 12.1096 1.03 33.1279 2.9217
10 601818.SH 20180726 0.1893 0.86 6.3064 0.7209
11 600926.SH 20180726 0.6065 0.46 9.1772 0.9808
12 002166.SZ 20180726 0.7582 0.82 16.9868 3.3452
13 600841.SH 20180726 0.3754 1.02 66.2647 2.2302
14 300634.SZ 20180726 23.1127 1.26 120.3053 14.3168
15 300126.SZ 20180726 1.2304 1.11 348.4306 1.5171
16 300718.SZ 20180726 17.6612 0.92 32.0239 3.8661
17 000708.SZ 20180726 0.5575 0.70 10.3674 1.0276
18 002626.SZ 20180726 0.6187 0.83 22.7580 4.2446
19 600816.SH 20180726 0.6745 0.65 11.0778 3.2214


@@ -3,6 +3,7 @@
A single function to fetch A-share daily market data (A股日线行情) from Tushare.
Supports all output fields including tor (turnover rate) and vr (volume ratio).
"""
import pandas as pd
from typing import Optional, List, Literal
from src.data.client import TushareClient
@@ -33,7 +34,7 @@ def get_daily(
Returns:
pd.DataFrame with daily market data containing:
- Base fields: ts_code, trade_date, open, high, low, close, pre_close,
- Base fields: ts_code, trade_date, open, high, low, close, pre_close,
change, pct_chg, vol, amount
- Factor fields (if requested): tor, vr
- Adjustment factor (if adjfactor=True): adjfactor


@@ -3,6 +3,7 @@
Fetch basic stock information including code, name, listing date, etc.
This is a special interface - call once to get all stocks (listed and delisted).
"""
import os
import pandas as pd
from pathlib import Path


@@ -1,4 +1,5 @@
"""Simplified HDF5 storage for data persistence."""
import os
import pandas as pd
from pathlib import Path
@@ -47,7 +48,9 @@ class Storage:
# Merge with existing data
existing = store[name]
combined = pd.concat([existing, data], ignore_index=True)
combined = combined.drop_duplicates(subset=["ts_code", "trade_date"], keep="last")
combined = combined.drop_duplicates(
subset=["ts_code", "trade_date"], keep="last"
)
store.put(name, combined, format="table")
print(f"[Storage] Saved {len(data)} rows to {file_path}")
@@ -57,10 +60,13 @@ class Storage:
print(f"[Storage] Error saving {name}: {e}")
return {"status": "error", "error": str(e)}
def load(self, name: str,
start_date: Optional[str] = None,
end_date: Optional[str] = None,
ts_code: Optional[str] = None) -> pd.DataFrame:
def load(
self,
name: str,
start_date: Optional[str] = None,
end_date: Optional[str] = None,
ts_code: Optional[str] = None,
) -> pd.DataFrame:
"""Load data from HDF5 file.
Args:
@@ -80,14 +86,25 @@ class Storage:
try:
with pd.HDFStore(file_path, mode="r") as store:
if name not in store.keys():
keys = store.keys()
# Handle both '/daily' and 'daily' keys
actual_key = None
if name in keys:
actual_key = name
elif f"/{name}" in keys:
actual_key = f"/{name}"
if actual_key is None:
return pd.DataFrame()
data = store[name]
data = store[actual_key]
# Apply filters
if start_date and end_date and "trade_date" in data.columns:
data = data[(data["trade_date"] >= start_date) & (data["trade_date"] <= end_date)]
data = data[
(data["trade_date"] >= start_date)
& (data["trade_date"] <= end_date)
]
if ts_code and "ts_code" in data.columns:
data = data[data["ts_code"] == ts_code]


@@ -5,11 +5,15 @@ This module provides data fetching functions with intelligent sync logic:
- If local file exists: incremental update (fetch from latest date + 1 day)
- Multi-threaded concurrent fetching for improved performance
- Stop immediately on any exception
- Preview mode: check data volume and samples before actual sync
Currently supported data types:
- daily: Daily market data (with turnover rate and volume ratio)
Usage:
# Preview sync (check data volume and samples without writing)
preview_sync()
# Sync all stocks (full load)
sync_all()
@@ -18,6 +22,9 @@ Usage:
# Force full reload
sync_all(force_full=True)
# Dry run (preview only, no write)
sync_all(dry_run=True)
"""
import pandas as pd
@@ -30,8 +37,8 @@ import sys
from src.data.client import TushareClient
from src.data.storage import Storage
from src.data.daily import get_daily
from src.data.trade_cal import (
from src.data.api_wrappers import get_daily
from src.data.api_wrappers import (
get_first_trading_day,
get_last_trading_day,
sync_trade_cal_cache,
@@ -114,7 +121,8 @@ class DataSync:
List of stock codes
"""
# Import sync_all_stocks here to avoid circular imports
from src.data.stock_basic import sync_all_stocks, _get_csv_path
from src.data.api_wrappers import sync_all_stocks
from src.data.api_wrappers.api_stock_basic import _get_csv_path
# First, ensure stock_basic.csv is up-to-date with all stocks
print("[DataSync] Ensuring stock_basic.csv is up-to-date...")
@@ -278,6 +286,184 @@ class DataSync:
print(f"[DataSync] Incremental sync needed from {sync_start} to {cal_last}")
return (True, sync_start, cal_last, local_last_date)
def preview_sync(
self,
force_full: bool = False,
start_date: Optional[str] = None,
end_date: Optional[str] = None,
sample_size: int = 3,
) -> dict:
"""Preview sync data volume and samples without actually syncing.
This method provides a preview of what would be synced, including:
- Number of stocks to be synced
- Date range for sync
- Estimated total records
- Sample data from first few stocks
Args:
force_full: If True, preview full sync from 20180101
start_date: Manual start date (overrides auto-detection)
end_date: Manual end date (defaults to today)
sample_size: Number of sample stocks to fetch for preview (default: 3)
Returns:
Dictionary with preview information:
{
'sync_needed': bool,
'stock_count': int,
'start_date': str,
'end_date': str,
'estimated_records': int,
'sample_data': pd.DataFrame,
'mode': str, # 'full' or 'incremental'
}
"""
print("\n" + "=" * 60)
print("[DataSync] Preview Mode - Analyzing sync requirements...")
print("=" * 60)
# First, ensure trade calendar cache is up-to-date
print("[DataSync] Syncing trade calendar cache...")
sync_trade_cal_cache()
# Determine date range
if end_date is None:
end_date = get_today_date()
# Check if sync is needed
sync_needed, cal_start, cal_end, local_last = self.check_sync_needed(force_full)
if not sync_needed:
print("\n" + "=" * 60)
print("[DataSync] Preview Result")
print("=" * 60)
print(" Sync Status: NOT NEEDED")
print(" Reason: Local data is up-to-date with trade calendar")
print("=" * 60)
return {
"sync_needed": False,
"stock_count": 0,
"start_date": None,
"end_date": None,
"estimated_records": 0,
"sample_data": pd.DataFrame(),
"mode": "none",
}
# Use dates from check_sync_needed
if cal_start and cal_end:
sync_start_date = cal_start
end_date = cal_end
else:
sync_start_date = start_date or DEFAULT_START_DATE
if end_date is None:
end_date = get_today_date()
# Determine sync mode
if force_full:
mode = "full"
print(f"[DataSync] Mode: FULL SYNC from {sync_start_date} to {end_date}")
elif local_last and cal_start and sync_start_date == get_next_date(local_last):
mode = "incremental"
print(f"[DataSync] Mode: INCREMENTAL SYNC (bandwidth optimized)")
print(f"[DataSync] Sync from: {sync_start_date} to {end_date}")
else:
mode = "partial"
print(f"[DataSync] Mode: SYNC from {sync_start_date} to {end_date}")
# Get all stock codes
stock_codes = self.get_all_stock_codes()
if not stock_codes:
print("[DataSync] No stocks found to sync")
return {
"sync_needed": False,
"stock_count": 0,
"start_date": None,
"end_date": None,
"estimated_records": 0,
"sample_data": pd.DataFrame(),
"mode": "none",
}
stock_count = len(stock_codes)
print(f"[DataSync] Total stocks to sync: {stock_count}")
# Fetch sample data from first few stocks
print(f"[DataSync] Fetching sample data from {sample_size} stocks...")
sample_data_list = []
sample_codes = stock_codes[:sample_size]
for ts_code in sample_codes:
try:
data = self.client.query(
"pro_bar",
ts_code=ts_code,
start_date=sync_start_date,
end_date=end_date,
factors="tor,vr",
)
if not data.empty:
sample_data_list.append(data)
print(f" - {ts_code}: {len(data)} records")
except Exception as e:
print(f" - {ts_code}: Error fetching - {e}")
# Combine sample data
sample_df = (
pd.concat(sample_data_list, ignore_index=True)
if sample_data_list
else pd.DataFrame()
)
# Estimate total records based on sample
if not sample_df.empty:
avg_records_per_stock = len(sample_df) / len(sample_data_list)
estimated_records = int(avg_records_per_stock * stock_count)
else:
estimated_records = 0
# Display preview results
print("\n" + "=" * 60)
print("[DataSync] Preview Result")
print("=" * 60)
print(f" Sync Mode: {mode.upper()}")
print(f" Date Range: {sync_start_date} to {end_date}")
print(f" Stocks to Sync: {stock_count}")
print(f" Sample Stocks Checked: {len(sample_data_list)}/{sample_size}")
print(f" Estimated Total Records: ~{estimated_records:,}")
if not sample_df.empty:
print(f"\n Sample Data Preview (first {len(sample_df)} rows):")
print(" " + "-" * 56)
# Display sample data in a compact format
preview_cols = [
"ts_code",
"trade_date",
"open",
"high",
"low",
"close",
"vol",
]
available_cols = [c for c in preview_cols if c in sample_df.columns]
sample_display = sample_df[available_cols].head(10)
for idx, row in sample_display.iterrows():
print(f" {row.to_dict()}")
print(" " + "-" * 56)
print("=" * 60)
return {
"sync_needed": True,
"stock_count": stock_count,
"start_date": sync_start_date,
"end_date": end_date,
"estimated_records": estimated_records,
"sample_data": sample_df,
"mode": mode,
}
def sync_single_stock(
self,
ts_code: str,
@@ -320,6 +506,7 @@ class DataSync:
start_date: Optional[str] = None,
end_date: Optional[str] = None,
max_workers: Optional[int] = None,
dry_run: bool = False,
) -> Dict[str, pd.DataFrame]:
"""Sync daily data for all stocks in local storage.
@@ -337,9 +524,10 @@ class DataSync:
start_date: Manual start date (overrides auto-detection)
end_date: Manual end date (defaults to today)
max_workers: Number of worker threads (default: 10)
dry_run: If True, only preview what would be synced without writing data
Returns:
Dict mapping ts_code to DataFrame (empty if sync skipped)
Dict mapping ts_code to DataFrame (empty if sync skipped or dry_run)
"""
print("\n" + "=" * 60)
print("[DataSync] Starting daily data sync...")
@@ -378,11 +566,14 @@ class DataSync:
# Determine sync mode
if force_full:
mode = "full"
print(f"[DataSync] Mode: FULL SYNC from {sync_start_date} to {end_date}")
elif local_last and cal_start and sync_start_date == get_next_date(local_last):
mode = "incremental"
print(f"[DataSync] Mode: INCREMENTAL SYNC (bandwidth optimized)")
print(f"[DataSync] Sync from: {sync_start_date} to {end_date}")
else:
mode = "partial"
print(f"[DataSync] Mode: SYNC from {sync_start_date} to {end_date}")
# Get all stock codes
@@ -394,6 +585,17 @@ class DataSync:
print(f"[DataSync] Total stocks to sync: {len(stock_codes)}")
print(f"[DataSync] Using {max_workers or self.max_workers} worker threads")
# Handle dry run mode
if dry_run:
print("\n" + "=" * 60)
print("[DataSync] DRY RUN MODE - No data will be written")
print("=" * 60)
print(f" Would sync {len(stock_codes)} stocks")
print(f" Date range: {sync_start_date} to {end_date}")
print(f" Mode: {mode}")
print("=" * 60)
return {}
# Reset stop flag for new sync
self._stop_flag.set()
@@ -492,11 +694,62 @@ class DataSync:
# Convenience functions
def preview_sync(
force_full: bool = False,
start_date: Optional[str] = None,
end_date: Optional[str] = None,
sample_size: int = 3,
max_workers: Optional[int] = None,
) -> dict:
"""Preview sync data volume and samples without actually syncing.
This is the recommended way to check what would be synced before
running the actual synchronization.
Args:
force_full: If True, preview full sync from 20180101
start_date: Manual start date (overrides auto-detection)
end_date: Manual end date (defaults to today)
sample_size: Number of sample stocks to fetch for preview (default: 3)
max_workers: Number of worker threads (not used in preview, for API compatibility)
Returns:
Dictionary with preview information:
{
'sync_needed': bool,
'stock_count': int,
'start_date': str,
'end_date': str,
'estimated_records': int,
'sample_data': pd.DataFrame,
'mode': str, # 'full', 'incremental', 'partial', or 'none'
}
Example:
>>> # Preview what would be synced
>>> preview = preview_sync()
>>>
>>> # Preview full sync
>>> preview = preview_sync(force_full=True)
>>>
>>> # Preview with more samples
>>> preview = preview_sync(sample_size=5)
"""
sync_manager = DataSync(max_workers=max_workers)
return sync_manager.preview_sync(
force_full=force_full,
start_date=start_date,
end_date=end_date,
sample_size=sample_size,
)
def sync_all(
force_full: bool = False,
start_date: Optional[str] = None,
end_date: Optional[str] = None,
max_workers: Optional[int] = None,
dry_run: bool = False,
) -> Dict[str, pd.DataFrame]:
"""Sync daily data for all stocks.
@@ -507,6 +760,7 @@ def sync_all(
start_date: Manual start date (YYYYMMDD)
end_date: Manual end date (defaults to today)
max_workers: Number of worker threads (default: 10)
dry_run: If True, only preview what would be synced without writing data
Returns:
Dict mapping ts_code to DataFrame
@@ -526,12 +780,16 @@ def sync_all(
>>>
>>> # Custom thread count
>>> result = sync_all(max_workers=20)
>>>
>>> # Dry run (preview only)
>>> result = sync_all(dry_run=True)
"""
sync_manager = DataSync(max_workers=max_workers)
return sync_manager.sync_all(
force_full=force_full,
start_date=start_date,
end_date=end_date,
dry_run=dry_run,
)
@@ -540,11 +798,32 @@ if __name__ == "__main__":
print("Data Sync Module")
print("=" * 60)
print("\nUsage:")
print(" from src.data.sync import sync_all")
print(" from src.data.sync import sync_all, preview_sync")
print("")
print(" # Preview before sync (recommended)")
print(" preview = preview_sync()")
print("")
print(" # Dry run (preview only)")
print(" result = sync_all(dry_run=True)")
print("")
print(" # Actual sync")
print(" result = sync_all() # Incremental sync")
print(" result = sync_all(force_full=True) # Full reload")
print("\n" + "=" * 60)
# Run sync
result = sync_all()
print(f"\nSynced {len(result)} stocks")
# Run preview first
print("\n[Main] Running preview first...")
preview = preview_sync()
if preview["sync_needed"]:
# Ask for confirmation
print("\n" + "=" * 60)
response = input("Proceed with sync? (y/n): ").strip().lower()
if response in ("y", "yes"):
print("\n[Main] Starting actual sync...")
result = sync_all()
print(f"\nSynced {len(result)} stocks")
else:
print("\n[Main] Sync cancelled by user")
else:
print("\n[Main] No sync needed - data is up to date")


@@ -5,29 +5,30 @@ Tests the daily interface implementation against api.md requirements:
- tor (turnover rate)
- vr (volume ratio)
"""
import pytest
import pandas as pd
from src.data.daily import get_daily
from src.data.api_wrappers import get_daily
# Expected output fields according to api.md
EXPECTED_BASE_FIELDS = [
'ts_code', # stock code
'trade_date', # trade date
'open', # open price
'high', # high price
'low', # low price
'close', # close price
'pre_close', # previous close
'change', # price change
'pct_chg', # percent change
'vol', # volume
'amount', # turnover amount
"ts_code", # 股票代码
"trade_date", # 交易日期
"open", # 开盘价
"high", # 最高价
"low", # 最低价
"close", # 收盘价
"pre_close", # 昨收价
"change", # 涨跌额
"pct_chg", # 涨跌幅
"vol", # 成交量
"amount", # 成交额
]
EXPECTED_FACTOR_FIELDS = [
'turnover_rate', # turnover rate (tor)
'volume_ratio', # volume ratio (vr)
"turnover_rate", # turnover rate (tor)
"volume_ratio", # volume ratio (vr)
]
@@ -36,19 +37,19 @@ class TestGetDaily:
def test_fetch_basic(self):
"""Test basic daily data fetch with real API."""
result = get_daily('000001.SZ', start_date='20240101', end_date='20240131')
result = get_daily("000001.SZ", start_date="20240101", end_date="20240131")
assert isinstance(result, pd.DataFrame)
assert len(result) >= 1
assert result['ts_code'].iloc[0] == '000001.SZ'
assert result["ts_code"].iloc[0] == "000001.SZ"
def test_fetch_with_factors(self):
"""Test fetch with tor and vr factors."""
result = get_daily(
'000001.SZ',
start_date='20240101',
end_date='20240131',
factors=['tor', 'vr'],
"000001.SZ",
start_date="20240101",
end_date="20240131",
factors=["tor", "vr"],
)
assert isinstance(result, pd.DataFrame)
@@ -61,25 +62,26 @@ class TestGetDaily:
def test_output_fields_completeness(self):
"""Verify all required output fields are returned."""
result = get_daily('600000.SH')
result = get_daily("600000.SH")
# Verify all base fields are present
assert set(EXPECTED_BASE_FIELDS).issubset(result.columns.tolist()), \
assert set(EXPECTED_BASE_FIELDS).issubset(result.columns.tolist()), (
f"Missing fields: {set(EXPECTED_BASE_FIELDS) - set(result.columns)}"
)
def test_empty_result(self):
"""Test handling of empty results."""
# Test the empty result for an invalid stock code against the real API
result = get_daily('INVALID.SZ')
result = get_daily("INVALID.SZ")
assert isinstance(result, pd.DataFrame)
assert result.empty
def test_date_range_query(self):
"""Test query with date range."""
result = get_daily(
'000001.SZ',
start_date='20240101',
end_date='20240131',
"000001.SZ",
start_date="20240101",
end_date="20240131",
)
assert isinstance(result, pd.DataFrame)
@@ -87,7 +89,7 @@ class TestGetDaily:
def test_with_adj(self):
"""Test fetch with adjustment type."""
result = get_daily('000001.SZ', adj='qfq')
result = get_daily("000001.SZ", adj="qfq")
assert isinstance(result, pd.DataFrame)
@@ -95,11 +97,14 @@ class TestGetDaily:
def test_integration():
"""Integration test with real Tushare API (requires valid token)."""
import os
token = os.environ.get('TUSHARE_TOKEN')
token = os.environ.get("TUSHARE_TOKEN")
if not token:
pytest.skip("TUSHARE_TOKEN not configured")
result = get_daily('000001.SZ', start_date='20240101', end_date='20240131', factors=['tor', 'vr'])
result = get_daily(
"000001.SZ", start_date="20240101", end_date="20240131", factors=["tor", "vr"]
)
# Verify structure
assert isinstance(result, pd.DataFrame)
@@ -112,6 +117,6 @@ def test_integration():
assert field in result.columns, f"Missing factor field: {field}"
if __name__ == '__main__':
if __name__ == "__main__":
# Run pytest unit tests (real API calls)
pytest.main([__file__, '-v'])
pytest.main([__file__, "-v"])


@@ -9,7 +9,7 @@ import pytest
import pandas as pd
from pathlib import Path
from src.data.storage import Storage
from src.data.stock_basic import _get_csv_path
from src.data.api_wrappers.api_stock_basic import _get_csv_path
class TestDailyStorageValidation:


@@ -5,6 +5,7 @@ Tests the sync module's full/incremental sync logic for daily data:
- Incremental sync when local data exists (from last_date + 1)
- Data integrity validation
"""
import pytest
import pandas as pd
from unittest.mock import Mock, patch, MagicMock
@@ -17,6 +18,8 @@ from src.data.sync import (
get_next_date,
DEFAULT_START_DATE,
)
from src.data.storage import Storage
from src.data.client import TushareClient
class TestDateUtilities:
@@ -63,30 +66,32 @@ class TestDataSync:
def test_get_all_stock_codes_from_daily(self, mock_storage):
"""Test getting stock codes from daily data."""
with patch('src.data.sync.Storage', return_value=mock_storage):
with patch("src.data.sync.Storage", return_value=mock_storage):
sync = DataSync()
sync.storage = mock_storage
mock_storage.load.return_value = pd.DataFrame({
'ts_code': ['000001.SZ', '000001.SZ', '600000.SH'],
})
mock_storage.load.return_value = pd.DataFrame(
{
"ts_code": ["000001.SZ", "000001.SZ", "600000.SH"],
}
)
codes = sync.get_all_stock_codes()
assert len(codes) == 2
assert '000001.SZ' in codes
assert '600000.SH' in codes
assert "000001.SZ" in codes
assert "600000.SH" in codes
def test_get_all_stock_codes_fallback(self, mock_storage):
"""Test fallback to stock_basic when daily is empty."""
with patch('src.data.sync.Storage', return_value=mock_storage):
with patch("src.data.sync.Storage", return_value=mock_storage):
sync = DataSync()
sync.storage = mock_storage
# First call (daily) returns empty, second call (stock_basic) returns data
mock_storage.load.side_effect = [
pd.DataFrame(), # daily empty
pd.DataFrame({'ts_code': ['000001.SZ', '600000.SH']}), # stock_basic
pd.DataFrame({"ts_code": ["000001.SZ", "600000.SH"]}), # stock_basic
]
codes = sync.get_all_stock_codes()
@@ -95,21 +100,23 @@ class TestDataSync:
def test_get_global_last_date(self, mock_storage):
"""Test getting global last date."""
with patch('src.data.sync.Storage', return_value=mock_storage):
with patch("src.data.sync.Storage", return_value=mock_storage):
sync = DataSync()
sync.storage = mock_storage
mock_storage.load.return_value = pd.DataFrame({
'ts_code': ['000001.SZ', '600000.SH'],
'trade_date': ['20240102', '20240103'],
})
mock_storage.load.return_value = pd.DataFrame(
{
"ts_code": ["000001.SZ", "600000.SH"],
"trade_date": ["20240102", "20240103"],
}
)
last_date = sync.get_global_last_date()
assert last_date == '20240103'
assert last_date == "20240103"
def test_get_global_last_date_empty(self, mock_storage):
"""Test getting last date from empty storage."""
with patch('src.data.sync.Storage', return_value=mock_storage):
with patch("src.data.sync.Storage", return_value=mock_storage):
sync = DataSync()
sync.storage = mock_storage
@@ -120,18 +127,23 @@ class TestDataSync:
def test_sync_single_stock(self, mock_storage):
"""Test syncing a single stock."""
with patch('src.data.sync.Storage', return_value=mock_storage):
with patch('src.data.sync.get_daily', return_value=pd.DataFrame({
'ts_code': ['000001.SZ'],
'trade_date': ['20240102'],
})):
with patch("src.data.sync.Storage", return_value=mock_storage):
with patch(
"src.data.sync.get_daily",
return_value=pd.DataFrame(
{
"ts_code": ["000001.SZ"],
"trade_date": ["20240102"],
}
),
):
sync = DataSync()
sync.storage = mock_storage
result = sync.sync_single_stock(
ts_code='000001.SZ',
start_date='20240101',
end_date='20240102',
ts_code="000001.SZ",
start_date="20240101",
end_date="20240102",
)
assert isinstance(result, pd.DataFrame)
@@ -139,15 +151,15 @@ class TestDataSync:
def test_sync_single_stock_empty(self, mock_storage):
"""Test syncing a stock with no data."""
with patch('src.data.sync.Storage', return_value=mock_storage):
with patch('src.data.sync.get_daily', return_value=pd.DataFrame()):
with patch("src.data.sync.Storage", return_value=mock_storage):
with patch("src.data.sync.get_daily", return_value=pd.DataFrame()):
sync = DataSync()
sync.storage = mock_storage
result = sync.sync_single_stock(
ts_code='INVALID.SZ',
start_date='20240101',
end_date='20240102',
ts_code="INVALID.SZ",
start_date="20240101",
end_date="20240102",
)
assert result.empty
@@ -158,40 +170,46 @@ class TestSyncAll:
def test_full_sync_mode(self, mock_storage):
"""Test full sync mode when force_full=True."""
with patch('src.data.sync.Storage', return_value=mock_storage):
with patch('src.data.sync.get_daily', return_value=pd.DataFrame()):
with patch("src.data.sync.Storage", return_value=mock_storage):
with patch("src.data.sync.get_daily", return_value=pd.DataFrame()):
sync = DataSync()
sync.storage = mock_storage
sync.sync_single_stock = Mock(return_value=pd.DataFrame())
mock_storage.load.return_value = pd.DataFrame({
'ts_code': ['000001.SZ'],
})
mock_storage.load.return_value = pd.DataFrame(
{
"ts_code": ["000001.SZ"],
}
)
result = sync.sync_all(force_full=True)
# Verify sync_single_stock was called with default start date
sync.sync_single_stock.assert_called_once()
call_args = sync.sync_single_stock.call_args
assert call_args[1]['start_date'] == DEFAULT_START_DATE
assert call_args[1]["start_date"] == DEFAULT_START_DATE
def test_incremental_sync_mode(self, mock_storage):
"""Test incremental sync mode when data exists."""
with patch('src.data.sync.Storage', return_value=mock_storage):
with patch("src.data.sync.Storage", return_value=mock_storage):
sync = DataSync()
sync.storage = mock_storage
sync.sync_single_stock = Mock(return_value=pd.DataFrame())
# Mock existing data with last date
mock_storage.load.side_effect = [
pd.DataFrame({
'ts_code': ['000001.SZ'],
'trade_date': ['20240102'],
}), # get_all_stock_codes
pd.DataFrame({
'ts_code': ['000001.SZ'],
'trade_date': ['20240102'],
}), # get_global_last_date
pd.DataFrame(
{
"ts_code": ["000001.SZ"],
"trade_date": ["20240102"],
}
), # get_all_stock_codes
pd.DataFrame(
{
"ts_code": ["000001.SZ"],
"trade_date": ["20240102"],
}
), # get_global_last_date
]
result = sync.sync_all(force_full=False)
@@ -199,28 +217,30 @@ class TestSyncAll:
# Verify sync_single_stock was called with next date
sync.sync_single_stock.assert_called_once()
call_args = sync.sync_single_stock.call_args
assert call_args[1]['start_date'] == '20240103'
assert call_args[1]["start_date"] == "20240103"
def test_manual_start_date(self, mock_storage):
"""Test sync with manual start date."""
with patch("src.data.sync.Storage", return_value=mock_storage):
sync = DataSync()
sync.storage = mock_storage
sync.sync_single_stock = Mock(return_value=pd.DataFrame())
mock_storage.load.return_value = pd.DataFrame(
{
"ts_code": ["000001.SZ"],
}
)
result = sync.sync_all(force_full=False, start_date="20230601")
sync.sync_single_stock.assert_called_once()
call_args = sync.sync_single_stock.call_args
assert call_args[1]["start_date"] == "20230601"
def test_no_stocks_found(self, mock_storage):
"""Test sync when no stocks are found."""
with patch("src.data.sync.Storage", return_value=mock_storage):
sync = DataSync()
sync.storage = mock_storage
@@ -236,7 +256,7 @@ class TestSyncAllConvenienceFunction:
def test_sync_all_function(self):
"""Test sync_all convenience function."""
with patch("src.data.sync.DataSync") as MockSync:
mock_instance = Mock()
mock_instance.sync_all.return_value = {}
MockSync.return_value = mock_instance
@@ -251,5 +271,5 @@ class TestSyncAllConvenienceFunction:
)
if __name__ == "__main__":
pytest.main([__file__, "-v"])
256 tests/test_sync_real.py Normal file

@@ -0,0 +1,256 @@
"""Tests for data sync with REAL data (read-only).
Tests verify:
1. get_global_last_date() correctly reads local data's max date
2. Incremental sync date calculation (local_last_date + 1)
3. Full sync date calculation (20180101)
4. Multi-stock scenario with real data
⚠️ IMPORTANT: These tests ONLY read data, no write operations.
- NO sync_all() calls (writes daily.h5)
- NO check_sync_needed() calls (writes trade_cal.h5)
"""
import pytest
import pandas as pd
from pathlib import Path
from src.data.sync import (
DataSync,
get_next_date,
DEFAULT_START_DATE,
)
from src.data.storage import Storage
class TestDataSyncReadOnly:
"""Read-only tests for data sync - verify date calculation logic."""
@pytest.fixture
def storage(self):
"""Create storage instance."""
return Storage()
@pytest.fixture
def data_sync(self):
"""Create DataSync instance."""
return DataSync()
@pytest.fixture
def daily_exists(self, storage):
"""Check if daily.h5 exists."""
return storage.exists("daily")
def test_daily_h5_exists(self, storage):
"""Verify daily.h5 data file exists before running tests."""
assert storage.exists("daily"), (
"daily.h5 not found. Please run full sync first: "
"uv run python -c 'from src.data.sync import sync_all; sync_all(force_full=True)'"
)
def test_get_global_last_date(self, data_sync, daily_exists):
"""Test get_global_last_date returns correct max date from local data."""
if not daily_exists:
pytest.skip("daily.h5 not found")
last_date = data_sync.get_global_last_date()
# Verify it's a valid date string
assert last_date is not None, "get_global_last_date returned None"
assert isinstance(last_date, str), f"Expected str, got {type(last_date)}"
assert len(last_date) == 8, f"Expected 8-digit date, got {last_date}"
assert last_date.isdigit(), f"Expected numeric date, got {last_date}"
# Verify by reading storage directly
daily_data = data_sync.storage.load("daily")
expected_max = str(daily_data["trade_date"].max())
assert last_date == expected_max, (
f"get_global_last_date returned {last_date}, "
f"but actual max date is {expected_max}"
)
print(f"[TEST] Local data last date: {last_date}")
def test_incremental_sync_date_calculation(self, data_sync, daily_exists):
"""Test incremental sync: start_date = local_last_date + 1.
This verifies that when local data exists, incremental sync should
fetch data from (local_last_date + 1), not from 20180101.
"""
if not daily_exists:
pytest.skip("daily.h5 not found")
# Get local last date
local_last_date = data_sync.get_global_last_date()
assert local_last_date is not None, "No local data found"
# Calculate expected incremental start date
expected_start_date = get_next_date(local_last_date)
# Verify the calculation is correct
local_last_int = int(local_last_date)
expected_int = local_last_int + 1
actual_int = int(expected_start_date)
assert actual_int == expected_int, (
f"Incremental start date calculation error: "
f"expected {expected_int}, got {actual_int}"
)
print(
f"[TEST] Incremental sync: local_last={local_last_date}, "
f"start_date should be {expected_start_date}"
)
# Verify this is NOT 20180101 (would be full sync)
assert expected_start_date != DEFAULT_START_DATE, (
f"Incremental sync should NOT start from {DEFAULT_START_DATE}"
)
def test_full_sync_date_calculation(self):
"""Test full sync: start_date = 20180101 when force_full=True.
This verifies that force_full=True always starts from 20180101.
"""
# Full sync should always use DEFAULT_START_DATE
full_sync_start = DEFAULT_START_DATE
assert full_sync_start == "20180101", (
f"Full sync should start from 20180101, got {full_sync_start}"
)
print(f"[TEST] Full sync start date: {full_sync_start}")
def test_date_comparison_logic(self, data_sync, daily_exists):
"""Test date comparison: incremental vs full sync selection logic.
Verify that:
- If local_last_date < today: incremental sync needed
- If local_last_date >= today: no sync needed
"""
if not daily_exists:
pytest.skip("daily.h5 not found")
from datetime import datetime
local_last_date = data_sync.get_global_last_date()
today = datetime.now().strftime("%Y%m%d")
local_last_int = int(local_last_date)
today_int = int(today)
# Log the comparison
print(
f"[TEST] Date comparison: local_last={local_last_date} ({local_last_int}), "
f"today={today} ({today_int})"
)
# This test just verifies the comparison logic works
if local_last_int < today_int:
print("[TEST] Local data is older than today - sync needed")
# Incremental sync should fetch from local_last_date + 1
sync_start = get_next_date(local_last_date)
assert int(sync_start) > local_last_int, (
"Sync start should be after local last"
)
else:
print("[TEST] Local data is up-to-date - no sync needed")
def test_get_all_stock_codes_real_data(self, data_sync, daily_exists):
"""Test get_all_stock_codes returns multiple real stock codes."""
if not daily_exists:
pytest.skip("daily.h5 not found")
codes = data_sync.get_all_stock_codes()
# Verify it's a list
assert isinstance(codes, list), f"Expected list, got {type(codes)}"
assert len(codes) > 0, "No stock codes found"
# Verify multiple stocks
assert len(codes) >= 10, (
f"Expected at least 10 stocks for multi-stock test, got {len(codes)}"
)
# Verify format (should be like 000001.SZ, 600000.SH)
sample_codes = codes[:5]
for code in sample_codes:
assert "." in code, f"Invalid stock code format: {code}"
suffix = code.split(".")[-1]
assert suffix in ["SZ", "SH"], f"Invalid exchange suffix: {suffix}"
print(f"[TEST] Found {len(codes)} stock codes (sample: {sample_codes})")
def test_multi_stock_date_range(self, data_sync, daily_exists):
"""Test that multiple stocks share the same date range in local data.
This verifies that local data has consistent date coverage across stocks.
"""
if not daily_exists:
pytest.skip("daily.h5 not found")
daily_data = data_sync.storage.load("daily")
# Get date range for each stock
stock_dates = daily_data.groupby("ts_code")["trade_date"].agg(["min", "max"])
# Get global min and max
global_min = str(daily_data["trade_date"].min())
global_max = str(daily_data["trade_date"].max())
print(f"[TEST] Global date range: {global_min} to {global_max}")
print(f"[TEST] Total stocks: {len(stock_dates)}")
# Verify we have data for multiple stocks
assert len(stock_dates) >= 10, (
f"Expected at least 10 stocks, got {len(stock_dates)}"
)
# Verify date range is reasonable (at least 100 days of data).
# Note: subtracting YYYYMMDD integers does NOT yield a day count,
# so parse the dates before computing the span.
from datetime import datetime
global_min_dt = datetime.strptime(global_min, "%Y%m%d")
global_max_dt = datetime.strptime(global_max, "%Y%m%d")
days_span = (global_max_dt - global_min_dt).days
assert days_span > 100, (
f"Date range too small: {days_span} days. "
f"Expected at least 100 days of data."
)
print(f"[TEST] Date span: {days_span} days")
class TestDateUtilities:
"""Test date utility functions."""
def test_get_next_date(self):
"""Test get_next_date correctly calculates next day."""
# Test normal cases
assert get_next_date("20240101") == "20240102"
assert get_next_date("20240131") == "20240201" # Month boundary
assert get_next_date("20241231") == "20250101" # Year boundary
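The month and year boundaries asserted above fall out naturally from `datetime` arithmetic; a minimal sketch of `get_next_date` consistent with these test cases (the actual implementation in `src.data.sync` may differ) is:

```python
from datetime import datetime, timedelta

def get_next_date(date_str: str) -> str:
    """Return the next calendar day in YYYYMMDD format."""
    # Parse YYYYMMDD, add one day, format back; boundaries handled by datetime
    next_day = datetime.strptime(date_str, "%Y%m%d") + timedelta(days=1)
    return next_day.strftime("%Y%m%d")
```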
def test_incremental_vs_full_sync_logic(self):
"""Test the logic difference between incremental and full sync.
Incremental: start_date = local_last_date + 1
Full: start_date = 20180101
"""
# Scenario 1: Local data exists
local_last_date = "20240115"
incremental_start = get_next_date(local_last_date)
assert incremental_start == "20240116"
assert incremental_start != DEFAULT_START_DATE
# Scenario 2: Force full sync
full_sync_start = DEFAULT_START_DATE # "20180101"
assert full_sync_start == "20180101"
assert incremental_start != full_sync_start
print("[TEST] Incremental vs Full sync logic verified")
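Taken together, the full, incremental, and manual start-date cases exercised in this file reduce to one selection rule. A hypothetical helper sketching it (`choose_start_date` is not part of `src.data.sync`; it only illustrates the precedence these tests assume):

```python
from datetime import datetime, timedelta
from typing import Optional

DEFAULT_START_DATE = "20180101"

def choose_start_date(
    local_last_date: Optional[str],
    force_full: bool = False,
    manual_start: Optional[str] = None,
) -> str:
    """Pick a sync start date: manual > full > incremental."""
    if manual_start:  # explicit start_date always wins
        return manual_start
    if force_full or local_last_date is None:  # forced full sync or empty store
        return DEFAULT_START_DATE
    # Incremental: resume from the day after the last local date
    next_day = datetime.strptime(local_last_date, "%Y%m%d") + timedelta(days=1)
    return next_day.strftime("%Y%m%d")
```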
if __name__ == "__main__":
pytest.main([__file__, "-v", "-s"])