Compare commits
3 commits: 05228ce9de ... 9965ce5706

| SHA1 |
|---|
| 9965ce5706 |
| e81d39ae0d |
| 8fc88b60e3 |
.gitignore (vendored, +1)

@@ -46,6 +46,7 @@ logs/

 # IDEs and editors
 .vscode/
 .idea/
+.opencode/
 *.swp
 *.swo
 *~
pytest.ini (new file, +3)

@@ -0,0 +1,3 @@
+[pytest]
+pythonpath = .
+testpaths = tests
@@ -6,7 +6,7 @@ Provides simplified interfaces for fetching and storing Tushare data.
 from src.data.config import Config, get_config
 from src.data.client import TushareClient
 from src.data.storage import Storage
-from src.data.stock_basic import get_stock_basic, sync_all_stocks
+from src.data.api_wrappers import get_stock_basic, sync_all_stocks

 __all__ = [
     "Config",
src/data/api_wrappers/API_INTERFACE_SPEC.md (new file, +244)
@@ -0,0 +1,244 @@

# ProStock Data API Wrapper Specification

## 1. Overview

This document defines the standard for wrapping new Tushare API endpoints. All non-special endpoints must follow this specification, which ensures:

- a uniform code style
- automatic sync support
- consistent incremental-update logic
- reduced write pressure on storage
## 2. Endpoint Classification

### 2.1 Special Endpoints (excluded from unified sync)

The following endpoints have independent sync logic and do not participate in the automatic sync mechanism:

| Endpoint type | Example | Notes |
|---------|------|------|
| Trade calendar | `trade_cal` | Global data, fetched by date range |
| Stock basics | `stock_basic` | One-shot full fetch, stored as CSV |
| Auxiliary data | industry and concept classifications | Low-frequency updates, managed independently |

### 2.2 Standard Endpoints (must follow this specification)

All factor, quote, and financial data fetched per stock or per date must follow this specification.
## 3. File Structure Requirements

### 3.1 File Naming

```
{data_type}.py
```

Examples: `daily.py`, `moneyflow.py`, `limit_list.py`

### 3.2 File Location

All endpoint files must live under `src/data/`.

### 3.3 Export Requirements

New endpoints must be exported from `src/data/__init__.py`:

```python
from src.data.{module_name} import get_{data_type}

__all__ = [
    # ... other exports ...
    "get_{data_type}",
]
```
## 4. Interface Design Requirements

### 4.1 Fetch Function Signatures

Functions must return a `pd.DataFrame`, and their parameters must follow one of the two patterns below.

#### 4.1.1 Fetch by Date (preferred)

Applies to: limit-up/limit-down lists, dragon-tiger (top movers) lists, chip distribution, and similar data.

**Required signature**:

```python
def get_{data_type}(
    trade_date: Optional[str] = None,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
    ts_code: Optional[str] = None,
    # other optional parameters...
) -> pd.DataFrame:
```

**Requirements**:
- Prefer `trade_date` to fetch full-market data for a single day
- Support `start_date + end_date` for range fetches
- `ts_code` is an optional filter
#### 4.1.2 Fetch by Stock

Applies to: daily quotes, money flow, and similar data.

**Required signature**:

```python
def get_{data_type}(
    ts_code: str,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
    # other optional parameters...
) -> pd.DataFrame:
```
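As a concrete illustration of the date-first signature in 4.1.1, here is a minimal sketch of a hypothetical `get_limit_list` wrapper. The `_query` stub stands in for the real `TushareClient.query` call (whose exact signature is an assumption here), so the sketch runs without network access:

```python
from typing import Optional
import pandas as pd

def _query(api_name: str, **params) -> pd.DataFrame:
    # Stand-in for TushareClient.query; returns a fixed frame for illustration.
    return pd.DataFrame({
        "ts_code": ["000001.SZ", "600000.SH"],
        "trade_date": ["20240102", "20240102"],
    })

def get_limit_list(
    trade_date: Optional[str] = None,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
    ts_code: Optional[str] = None,
) -> pd.DataFrame:
    """Fetch a limit list, preferring single-day full-market pulls."""
    params = {}
    if trade_date:
        params["trade_date"] = trade_date   # single day, whole market
    else:
        params["start_date"] = start_date or ""
        params["end_date"] = end_date or ""
    df = _query("limit_list", **params)
    if ts_code:                             # optional client-side filter
        df = df[df["ts_code"] == ts_code]
    return df.reset_index(drop=True)
```

In a real wrapper the filtering would usually be pushed to the API via the `ts_code` parameter; the client-side filter here only shows the intended semantics.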
### 4.2 Docstring Requirements

Functions must carry a complete Google-style docstring, containing:
- a description of what the function does
- an `Args` section documenting every parameter
- a `Returns` section documenting the fields of the returned DataFrame
- an `Example` section with usage

### 4.3 Date Format Requirements

- All date parameters and return values use the `YYYYMMDD` string format
- Use `trade_date` as the canonical date field name
- If the API returns a different date field name (e.g. `date`, `end_date`), rename it to `trade_date` before returning

### 4.4 Stock Code Requirements

- Use `ts_code` as the canonical stock code field name
- Format: `{code}.{exchange}`, e.g. `000001.SZ`, `600000.SH`

### 4.5 Token-Bucket Rate Limiting

All API calls must go through `TushareClient`, which satisfies the token-bucket rate limit automatically.
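The renaming rule in 4.3 amounts to one call before returning; `end_date` here is just an example of a non-standard field name that an API might return:

```python
import pandas as pd

# A raw API result whose date column is not named trade_date.
raw = pd.DataFrame({"ts_code": ["000001.SZ"], "end_date": ["20231231"]})

# Normalize to the canonical column name required by the spec.
df = raw.rename(columns={"end_date": "trade_date"})
```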
## 5. Sync Integration Requirements

### 5.1 DATASET_CONFIG Registration

New endpoints must be registered in `DataSync.DATASET_CONFIG` with these entries:

```python
"{new_data_type}": {
    "api_name": "{tushare_api_name}",  # Tushare API name
    "fetch_by": "date",  # "date" or "stock"
    "date_field": "trade_date",
    "key_fields": ["ts_code", "trade_date"],  # primary key used for dedup
}
```

### 5.2 fetch_by Rules

- **Prefer `"date"`** when the API supports fetching full-market data by date
- Use `"stock"` only when the API cannot fetch by date

### 5.3 Sync Method Requirements

Implement a matching sync method, or reuse the generic one:

```python
def sync_{data_type}(self, force_full: bool = False) -> pd.DataFrame:
    """Sync {data description}."""
    return self.sync_dataset("{data_type}", force_full)
```

Also provide a convenience function:

```python
def sync_{data_type}(force_full: bool = False) -> pd.DataFrame:
    """Sync {data description}."""
    sync_manager = DataSync()
    return sync_manager.sync_{data_type}(force_full)
```

### 5.4 Incremental Update Requirements

- Incremental update logic is mandatory (automatically check the latest local date)
- Support forced full syncs via the `force_full` parameter
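The incremental check in 5.4 reduces to picking a start date: resume from the day after the latest local date, or fall back to a full reload. A minimal sketch, assuming the local store can report its latest `trade_date` (the helper names are hypothetical; `20180101` is the default full-sync start date used elsewhere in this repo):

```python
from datetime import datetime, timedelta

DEFAULT_START_DATE = "20180101"

def next_day(yyyymmdd: str) -> str:
    """Return the next calendar day in YYYYMMDD format."""
    d = datetime.strptime(yyyymmdd, "%Y%m%d") + timedelta(days=1)
    return d.strftime("%Y%m%d")

def resolve_sync_start(local_last_date, force_full: bool = False) -> str:
    """Return the start date for the next sync run."""
    if force_full or local_last_date is None:
        return DEFAULT_START_DATE       # full reload
    return next_day(local_last_date)    # incremental: resume after local data
```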
## 6. Storage Requirements

### 6.1 Storage Backend

All data is persisted to HDF5 through the `Storage` class.

### 6.2 Write Strategy

**Requirement**: write all data **in one shot** after the requests complete, never row by row.

### 6.3 Deduplication

Deduplicate on the fields configured in `key_fields`, defaulting to `["ts_code", "trade_date"]`.
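The one-shot write plus key-field dedup in 6.2-6.3 amounts to accumulating fetched batches in memory, then concatenating and deduplicating once before a single write. A sketch of that combine step (mirroring the `pd.concat` / `drop_duplicates` pattern used in `Storage`):

```python
import pandas as pd

def combine_batches(batches, key_fields=("ts_code", "trade_date")) -> pd.DataFrame:
    """Concatenate all fetched batches and dedup once before a single write."""
    combined = pd.concat(batches, ignore_index=True)
    # keep="last" lets a re-fetched row overwrite the stale copy
    return combined.drop_duplicates(subset=list(key_fields), keep="last")
```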
## 7. Testing Requirements

### 7.1 Test Files

A matching test file is mandatory: `tests/test_{data_type}.py`

### 7.2 Coverage

- Test fetching by date
- Test fetching by stock (if supported)
- `TushareClient` must be mocked
- Cover both normal and error paths
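A sketch of the mocking requirement in 7.2. Rather than hitting the network, the client is replaced with a `MagicMock`; `fetch_with_client` is a hypothetical stand-in for a wrapper that delegates to `TushareClient.query` (the real wrappers construct the client internally and would be patched with `unittest.mock.patch` instead):

```python
import pandas as pd
from unittest.mock import MagicMock

def fetch_with_client(client, ts_code: str) -> pd.DataFrame:
    """Stand-in for a wrapper that delegates to TushareClient.query."""
    return client.query("daily", ts_code=ts_code)

def test_fetch_uses_mocked_client():
    mock_client = MagicMock()
    mock_client.query.return_value = pd.DataFrame(
        {"ts_code": ["000001.SZ"], "trade_date": ["20240102"]}
    )
    result = fetch_with_client(mock_client, "000001.SZ")
    # Verify the wrapper called the client exactly once with the right args.
    mock_client.query.assert_called_once_with("daily", ts_code="000001.SZ")
    assert not result.empty
```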
## 8. Full Workflow for a New Endpoint

### 8.1 Create the Endpoint File

1. Create `{data_type}.py` under `src/data/`
2. Implement the fetch function following Section 4

### 8.2 Register Sync Support

1. Register the endpoint in `DataSync.DATASET_CONFIG` in `sync.py`
2. Implement the matching sync method
3. Provide the convenience function

### 8.3 Update Exports

Export the endpoint function from `src/data/__init__.py`.

### 8.4 Create Tests

Create `tests/test_{data_type}.py` covering the key scenarios.
## 9. Checklist

### 9.1 File Structure
- [ ] File lives at `src/data/{data_type}.py`
- [ ] `src/data/__init__.py` updated to export the public interface
- [ ] `tests/test_{data_type}.py` created

### 9.2 Interface Implementation
- [ ] Fetch function goes through `TushareClient`
- [ ] Function has a complete Google-style docstring
- [ ] Date parameters use the `YYYYMMDD` format
- [ ] Returned DataFrame contains `ts_code` and `trade_date` fields
- [ ] Fetch-by-date implemented first (when the API supports it)

### 9.3 Sync Integration
- [ ] Registered in `DataSync.DATASET_CONFIG`
- [ ] `fetch_by` set correctly ("date" or "stock")
- [ ] `date_field` and `key_fields` set correctly
- [ ] Matching sync method implemented, or the generic one reused
- [ ] Incremental update logic correct (checks the latest local date)

### 9.4 Storage Optimization
- [ ] All data written in one shot (not row by row)
- [ ] Incremental saves use `storage.save(mode="append")`
- [ ] Dedup fields configured correctly

### 9.5 Tests
- [ ] Unit tests written
- [ ] TushareClient mocked
- [ ] Normal and error paths covered

---

**Last updated**: 2026-02-01
src/data/api_wrappers/__init__.py (new file, +40)
@@ -0,0 +1,40 @@

"""Tushare API wrapper modules.

This package contains simplified interfaces for fetching data from Tushare API.
All wrapper files follow the naming convention: api_{data_type}.py

Available APIs:
- api_daily: Daily market data (日线行情)
- api_stock_basic: Stock basic information (股票基本信息)
- api_trade_cal: Trading calendar (交易日历)

Example:
    >>> from src.data.api_wrappers import get_daily, get_stock_basic, get_trade_cal
    >>> data = get_daily('000001.SZ', start_date='20240101', end_date='20240131')
    >>> stocks = get_stock_basic()
    >>> calendar = get_trade_cal('20240101', '20240131')
"""

from src.data.api_wrappers.api_daily import get_daily
from src.data.api_wrappers.api_stock_basic import get_stock_basic, sync_all_stocks
from src.data.api_wrappers.api_trade_cal import (
    get_trade_cal,
    get_trading_days,
    get_first_trading_day,
    get_last_trading_day,
    sync_trade_cal_cache,
)

__all__ = [
    # Daily market data
    "get_daily",
    # Stock basic information
    "get_stock_basic",
    "sync_all_stocks",
    # Trade calendar
    "get_trade_cal",
    "get_trading_days",
    "get_first_trading_day",
    "get_last_trading_day",
    "sync_trade_cal_cache",
]
@@ -179,4 +179,75 @@ df = pro.query('trade_cal', start_date='20180101', end_date='20181231')
 17 SSE 20180118 1
 18 SSE 20180119 1
 19 SSE 20180120 0
 20 SSE 20180121 0
Daily Indicators

Endpoint: daily_basic; the data can be inspected and debugged with the data tool.
Update time: between 15:00 and 17:00 on each trading day.
Description: fetches the key daily fundamental indicators for all stocks, usable for stock screening, report display, and similar purposes. A single request returns at most 6,000 rows; the full history can be pulled by looping over trading days.
Credits: at least 2,000 credits are required; 5,000 credits removes the volume cap. See the credit acquisition guide for details.

Input parameters

Name        Type  Required  Description
ts_code     str   Y         stock code (one of the two)
trade_date  str   N         trade date (one of the two)
start_date  str   N         start date (YYYYMMDD)
end_date    str   N         end date (YYYYMMDD)

Note: all dates use the YYYYMMDD format, e.g. 20181010.
Output parameters

Name             Type   Description
ts_code          str    TS stock code
trade_date       str    trade date
close            float  closing price of the day
turnover_rate    float  turnover rate (%)
turnover_rate_f  float  turnover rate (free float)
volume_ratio     float  volume ratio
pe               float  P/E (total market cap / net profit; empty when loss-making)
pe_ttm           float  P/E (TTM; empty when loss-making)
pb               float  P/B (total market cap / net assets)
ps               float  P/S
ps_ttm           float  P/S (TTM)
dv_ratio         float  dividend yield (%)
dv_ttm           float  dividend yield (TTM) (%)
total_share      float  total shares (10k shares)
float_share      float  floating shares (10k shares)
free_share       float  free-float shares (10k)
total_mv         float  total market cap (10k CNY)
circ_mv          float  circulating market cap (10k CNY)
Usage

pro = ts.pro_api()

df = pro.daily_basic(ts_code='', trade_date='20180726', fields='ts_code,trade_date,turnover_rate,volume_ratio,pe,pb')

or

df = pro.query('daily_basic', ts_code='', trade_date='20180726', fields='ts_code,trade_date,turnover_rate,volume_ratio,pe,pb')

Sample data

    ts_code    trade_date  turnover_rate  volume_ratio        pe       pb
0   600230.SH  20180726           2.4584          0.72    8.6928   3.7203
1   600237.SH  20180726           1.4737          0.88  166.4001   1.8868
2   002465.SZ  20180726           0.7489          0.72   71.8943   2.6391
3   300732.SZ  20180726           6.7083          0.77   21.8101   3.2513
4   600007.SH  20180726           0.0381          0.61   23.7696   2.3774
5   300068.SZ  20180726           1.4583          0.52   27.8166   1.7549
6   300552.SZ  20180726           2.0728          0.95   56.8004   2.9279
7   601369.SH  20180726           0.2088          0.95   44.1163   1.8001
8   002518.SZ  20180726           0.5814          0.76   15.1004   2.5626
9   002913.SZ  20180726          12.1096          1.03   33.1279   2.9217
10  601818.SH  20180726           0.1893          0.86    6.3064   0.7209
11  600926.SH  20180726           0.6065          0.46    9.1772   0.9808
12  002166.SZ  20180726           0.7582          0.82   16.9868   3.3452
13  600841.SH  20180726           0.3754          1.02   66.2647   2.2302
14  300634.SZ  20180726          23.1127          1.26  120.3053  14.3168
15  300126.SZ  20180726           1.2304          1.11  348.4306   1.5171
16  300718.SZ  20180726          17.6612          0.92   32.0239   3.8661
17  000708.SZ  20180726           0.5575          0.70   10.3674   1.0276
18  002626.SZ  20180726           0.6187          0.83   22.7580   4.2446
19  600816.SH  20180726           0.6745          0.65   11.0778   3.2214
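The daily_basic output lends itself to simple cross-sectional screens. A sketch using a frame shaped like the sample above (the three rows are copied from the sample, not live data, and the thresholds are arbitrary):

```python
import pandas as pd

# Three rows from the daily_basic sample above.
sample = pd.DataFrame({
    "ts_code": ["600230.SH", "601818.SH", "600926.SH"],
    "trade_date": ["20180726"] * 3,
    "pe": [8.6928, 6.3064, 9.1772],
    "pb": [3.7203, 0.7209, 0.9808],
})

# Cheap-value screen: PE below 10 and PB below 1.
screen = sample[(sample["pe"] < 10) & (sample["pb"] < 1)]
```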
@@ -3,6 +3,7 @@
 A single function to fetch A股日线行情 data from Tushare.
 Supports all output fields including tor (换手率) and vr (量比).
 """

 import pandas as pd
 from typing import Optional, List, Literal
 from src.data.client import TushareClient

@@ -33,7 +34,7 @@ def get_daily(

     Returns:
         pd.DataFrame with daily market data containing:
         - Base fields: ts_code, trade_date, open, high, low, close, pre_close,
           change, pct_chg, vol, amount
         - Factor fields (if requested): tor, vr
         - Adjustment factor (if adjfactor=True): adjfactor
@@ -3,6 +3,7 @@
 Fetch basic stock information including code, name, listing date, etc.
 This is a special interface - call once to get all stocks (listed and delisted).
 """

 import os
 import pandas as pd
 from pathlib import Path
@@ -1,4 +1,5 @@
 """Simplified HDF5 storage for data persistence."""

 import os
 import pandas as pd
 from pathlib import Path

@@ -47,7 +48,9 @@ class Storage:
             # Merge with existing data
             existing = store[name]
             combined = pd.concat([existing, data], ignore_index=True)
-            combined = combined.drop_duplicates(subset=["ts_code", "trade_date"], keep="last")
+            combined = combined.drop_duplicates(
+                subset=["ts_code", "trade_date"], keep="last"
+            )
             store.put(name, combined, format="table")

         print(f"[Storage] Saved {len(data)} rows to {file_path}")

@@ -57,10 +60,13 @@ class Storage:
             print(f"[Storage] Error saving {name}: {e}")
             return {"status": "error", "error": str(e)}

-    def load(self, name: str,
-             start_date: Optional[str] = None,
-             end_date: Optional[str] = None,
-             ts_code: Optional[str] = None) -> pd.DataFrame:
+    def load(
+        self,
+        name: str,
+        start_date: Optional[str] = None,
+        end_date: Optional[str] = None,
+        ts_code: Optional[str] = None,
+    ) -> pd.DataFrame:
         """Load data from HDF5 file.

         Args:

@@ -80,14 +86,25 @@ class Storage:

         try:
             with pd.HDFStore(file_path, mode="r") as store:
-                if name not in store.keys():
+                keys = store.keys()
+                # Handle both '/daily' and 'daily' keys
+                actual_key = None
+                if name in keys:
+                    actual_key = name
+                elif f"/{name}" in keys:
+                    actual_key = f"/{name}"
+
+                if actual_key is None:
                     return pd.DataFrame()

-                data = store[name]
+                data = store[actual_key]

                 # Apply filters
                 if start_date and end_date and "trade_date" in data.columns:
-                    data = data[(data["trade_date"] >= start_date) & (data["trade_date"] <= end_date)]
+                    data = data[
+                        (data["trade_date"] >= start_date)
+                        & (data["trade_date"] <= end_date)
+                    ]

                 if ts_code and "ts_code" in data.columns:
                     data = data[data["ts_code"] == ts_code]
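The key-normalization change to `Storage.load` above exists because `pd.HDFStore.keys()` reports keys with a leading slash (`'/daily'`), while callers typically pass the bare name (`'daily'`). The lookup can be isolated as a small pure helper:

```python
def resolve_store_key(name, keys):
    """Return the matching HDFStore key for name, accepting '/name' or 'name'."""
    if name in keys:
        return name
    if f"/{name}" in keys:
        return f"/{name}"
    return None
```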
src/data/sync.py (295)
@@ -5,11 +5,15 @@ This module provides data fetching functions with intelligent sync logic:
 - If local file exists: incremental update (fetch from latest date + 1 day)
 - Multi-threaded concurrent fetching for improved performance
 - Stop immediately on any exception
+- Preview mode: check data volume and samples before actual sync

 Currently supported data types:
 - daily: Daily market data (with turnover rate and volume ratio)

 Usage:
+    # Preview sync (check data volume and samples without writing)
+    preview_sync()
+
     # Sync all stocks (full load)
     sync_all()

@@ -18,6 +22,9 @@ Usage:

     # Force full reload
     sync_all(force_full=True)

+    # Dry run (preview only, no write)
+    sync_all(dry_run=True)
 """

 import pandas as pd
@@ -30,8 +37,8 @@ import sys

 from src.data.client import TushareClient
 from src.data.storage import Storage
-from src.data.daily import get_daily
-from src.data.trade_cal import (
+from src.data.api_wrappers import get_daily
+from src.data.api_wrappers import (
     get_first_trading_day,
     get_last_trading_day,
     sync_trade_cal_cache,
@@ -114,7 +121,8 @@ class DataSync:
             List of stock codes
         """
         # Import sync_all_stocks here to avoid circular imports
-        from src.data.stock_basic import sync_all_stocks, _get_csv_path
+        from src.data.api_wrappers import sync_all_stocks
+        from src.data.api_wrappers.api_stock_basic import _get_csv_path

         # First, ensure stock_basic.csv is up-to-date with all stocks
         print("[DataSync] Ensuring stock_basic.csv is up-to-date...")
@@ -278,6 +286,184 @@ class DataSync:
            print(f"[DataSync] Incremental sync needed from {sync_start} to {cal_last}")
            return (True, sync_start, cal_last, local_last_date)

    def preview_sync(
        self,
        force_full: bool = False,
        start_date: Optional[str] = None,
        end_date: Optional[str] = None,
        sample_size: int = 3,
    ) -> dict:
        """Preview sync data volume and samples without actually syncing.

        This method provides a preview of what would be synced, including:
        - Number of stocks to be synced
        - Date range for sync
        - Estimated total records
        - Sample data from first few stocks

        Args:
            force_full: If True, preview full sync from 20180101
            start_date: Manual start date (overrides auto-detection)
            end_date: Manual end date (defaults to today)
            sample_size: Number of sample stocks to fetch for preview (default: 3)

        Returns:
            Dictionary with preview information:
            {
                'sync_needed': bool,
                'stock_count': int,
                'start_date': str,
                'end_date': str,
                'estimated_records': int,
                'sample_data': pd.DataFrame,
                'mode': str,  # 'full' or 'incremental'
            }
        """
        print("\n" + "=" * 60)
        print("[DataSync] Preview Mode - Analyzing sync requirements...")
        print("=" * 60)

        # First, ensure trade calendar cache is up-to-date
        print("[DataSync] Syncing trade calendar cache...")
        sync_trade_cal_cache()

        # Determine date range
        if end_date is None:
            end_date = get_today_date()

        # Check if sync is needed
        sync_needed, cal_start, cal_end, local_last = self.check_sync_needed(force_full)

        if not sync_needed:
            print("\n" + "=" * 60)
            print("[DataSync] Preview Result")
            print("=" * 60)
            print("  Sync Status: NOT NEEDED")
            print("  Reason: Local data is up-to-date with trade calendar")
            print("=" * 60)
            return {
                "sync_needed": False,
                "stock_count": 0,
                "start_date": None,
                "end_date": None,
                "estimated_records": 0,
                "sample_data": pd.DataFrame(),
                "mode": "none",
            }

        # Use dates from check_sync_needed
        if cal_start and cal_end:
            sync_start_date = cal_start
            end_date = cal_end
        else:
            sync_start_date = start_date or DEFAULT_START_DATE
            if end_date is None:
                end_date = get_today_date()

        # Determine sync mode
        if force_full:
            mode = "full"
            print(f"[DataSync] Mode: FULL SYNC from {sync_start_date} to {end_date}")
        elif local_last and cal_start and sync_start_date == get_next_date(local_last):
            mode = "incremental"
            print(f"[DataSync] Mode: INCREMENTAL SYNC (bandwidth optimized)")
            print(f"[DataSync] Sync from: {sync_start_date} to {end_date}")
        else:
            mode = "partial"
            print(f"[DataSync] Mode: SYNC from {sync_start_date} to {end_date}")

        # Get all stock codes
        stock_codes = self.get_all_stock_codes()
        if not stock_codes:
            print("[DataSync] No stocks found to sync")
            return {
                "sync_needed": False,
                "stock_count": 0,
                "start_date": None,
                "end_date": None,
                "estimated_records": 0,
                "sample_data": pd.DataFrame(),
                "mode": "none",
            }

        stock_count = len(stock_codes)
        print(f"[DataSync] Total stocks to sync: {stock_count}")

        # Fetch sample data from first few stocks
        print(f"[DataSync] Fetching sample data from {sample_size} stocks...")
        sample_data_list = []
        sample_codes = stock_codes[:sample_size]

        for ts_code in sample_codes:
            try:
                data = self.client.query(
                    "pro_bar",
                    ts_code=ts_code,
                    start_date=sync_start_date,
                    end_date=end_date,
                    factors="tor,vr",
                )
                if not data.empty:
                    sample_data_list.append(data)
                    print(f"  - {ts_code}: {len(data)} records")
            except Exception as e:
                print(f"  - {ts_code}: Error fetching - {e}")

        # Combine sample data
        sample_df = (
            pd.concat(sample_data_list, ignore_index=True)
            if sample_data_list
            else pd.DataFrame()
        )

        # Estimate total records based on sample
        if not sample_df.empty:
            avg_records_per_stock = len(sample_df) / len(sample_data_list)
            estimated_records = int(avg_records_per_stock * stock_count)
        else:
            estimated_records = 0

        # Display preview results
        print("\n" + "=" * 60)
        print("[DataSync] Preview Result")
        print("=" * 60)
        print(f"  Sync Mode: {mode.upper()}")
        print(f"  Date Range: {sync_start_date} to {end_date}")
        print(f"  Stocks to Sync: {stock_count}")
        print(f"  Sample Stocks Checked: {len(sample_data_list)}/{sample_size}")
        print(f"  Estimated Total Records: ~{estimated_records:,}")

        if not sample_df.empty:
            print(f"\n  Sample Data Preview (first {len(sample_df)} rows):")
            print("  " + "-" * 56)
            # Display sample data in a compact format
            preview_cols = [
                "ts_code",
                "trade_date",
                "open",
                "high",
                "low",
                "close",
                "vol",
            ]
            available_cols = [c for c in preview_cols if c in sample_df.columns]
            sample_display = sample_df[available_cols].head(10)
            for idx, row in sample_display.iterrows():
                print(f"    {row.to_dict()}")
            print("  " + "-" * 56)

        print("=" * 60)

        return {
            "sync_needed": True,
            "stock_count": stock_count,
            "start_date": sync_start_date,
            "end_date": end_date,
            "estimated_records": estimated_records,
            "sample_data": sample_df,
            "mode": mode,
        }

    def sync_single_stock(
        self,
        ts_code: str,
@@ -320,6 +506,7 @@ class DataSync:
         start_date: Optional[str] = None,
         end_date: Optional[str] = None,
         max_workers: Optional[int] = None,
+        dry_run: bool = False,
     ) -> Dict[str, pd.DataFrame]:
         """Sync daily data for all stocks in local storage.

@@ -337,9 +524,10 @@ class DataSync:
             start_date: Manual start date (overrides auto-detection)
             end_date: Manual end date (defaults to today)
             max_workers: Number of worker threads (default: 10)
+            dry_run: If True, only preview what would be synced without writing data

         Returns:
-            Dict mapping ts_code to DataFrame (empty if sync skipped)
+            Dict mapping ts_code to DataFrame (empty if sync skipped or dry_run)
         """
         print("\n" + "=" * 60)
         print("[DataSync] Starting daily data sync...")
@@ -378,11 +566,14 @@ class DataSync:

        # Determine sync mode
        if force_full:
            mode = "full"
            print(f"[DataSync] Mode: FULL SYNC from {sync_start_date} to {end_date}")
        elif local_last and cal_start and sync_start_date == get_next_date(local_last):
            mode = "incremental"
            print(f"[DataSync] Mode: INCREMENTAL SYNC (bandwidth optimized)")
            print(f"[DataSync] Sync from: {sync_start_date} to {end_date}")
        else:
            mode = "partial"
            print(f"[DataSync] Mode: SYNC from {sync_start_date} to {end_date}")

        # Get all stock codes

@@ -394,6 +585,17 @@ class DataSync:
        print(f"[DataSync] Total stocks to sync: {len(stock_codes)}")
        print(f"[DataSync] Using {max_workers or self.max_workers} worker threads")

        # Handle dry run mode
        if dry_run:
            print("\n" + "=" * 60)
            print("[DataSync] DRY RUN MODE - No data will be written")
            print("=" * 60)
            print(f"  Would sync {len(stock_codes)} stocks")
            print(f"  Date range: {sync_start_date} to {end_date}")
            print(f"  Mode: {mode}")
            print("=" * 60)
            return {}

        # Reset stop flag for new sync
        self._stop_flag.set()
@@ -492,11 +694,62 @@ class DataSync:
|
||||
# Convenience functions
|
||||
|
||||
|
||||
def preview_sync(
|
||||
force_full: bool = False,
|
||||
start_date: Optional[str] = None,
|
||||
end_date: Optional[str] = None,
|
||||
sample_size: int = 3,
|
||||
max_workers: Optional[int] = None,
|
||||
) -> dict:
|
||||
"""Preview sync data volume and samples without actually syncing.
|
||||
|
||||
This is the recommended way to check what would be synced before
|
||||
running the actual synchronization.
|
||||
|
||||
Args:
|
||||
force_full: If True, preview full sync from 20180101
|
||||
start_date: Manual start date (overrides auto-detection)
|
||||
end_date: Manual end date (defaults to today)
|
||||
sample_size: Number of sample stocks to fetch for preview (default: 3)
|
||||
max_workers: Number of worker threads (not used in preview, for API compatibility)
|
||||
|
||||
Returns:
|
||||
Dictionary with preview information:
|
||||
{
|
||||
'sync_needed': bool,
|
||||
'stock_count': int,
|
||||
'start_date': str,
|
||||
'end_date': str,
|
||||
'estimated_records': int,
|
||||
'sample_data': pd.DataFrame,
|
||||
'mode': str, # 'full', 'incremental', 'partial', or 'none'
|
||||
}
|
||||
|
||||
Example:
|
||||
>>> # Preview what would be synced
|
||||
>>> preview = preview_sync()
|
||||
>>>
|
||||
>>> # Preview full sync
|
||||
>>> preview = preview_sync(force_full=True)
|
||||
>>>
|
||||
>>> # Preview with more samples
|
||||
>>> preview = preview_sync(sample_size=5)
|
||||
"""
|
||||
sync_manager = DataSync(max_workers=max_workers)
|
||||
return sync_manager.preview_sync(
|
||||
force_full=force_full,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
sample_size=sample_size,
|
||||
)
|
||||
|
||||
|
||||
def sync_all(
|
||||
force_full: bool = False,
|
||||
start_date: Optional[str] = None,
|
||||
end_date: Optional[str] = None,
|
||||
max_workers: Optional[int] = None,
|
||||
dry_run: bool = False,
|
||||
) -> Dict[str, pd.DataFrame]:
|
||||
"""Sync daily data for all stocks.
|
||||
|
||||
@@ -507,6 +760,7 @@ def sync_all(
         start_date: Manual start date (YYYYMMDD)
         end_date: Manual end date (defaults to today)
         max_workers: Number of worker threads (default: 10)
+        dry_run: If True, only preview what would be synced without writing data

     Returns:
         Dict mapping ts_code to DataFrame

@@ -526,12 +780,16 @@ def sync_all(
         >>>
         >>> # Custom thread count
         >>> result = sync_all(max_workers=20)
+        >>>
+        >>> # Dry run (preview only)
+        >>> result = sync_all(dry_run=True)
     """
     sync_manager = DataSync(max_workers=max_workers)
     return sync_manager.sync_all(
         force_full=force_full,
         start_date=start_date,
         end_date=end_date,
+        dry_run=dry_run,
     )
@@ -540,11 +798,32 @@ if __name__ == "__main__":
     print("Data Sync Module")
     print("=" * 60)
     print("\nUsage:")
-    print("  from src.data.sync import sync_all")
+    print("  from src.data.sync import sync_all, preview_sync")
+    print("")
+    print("  # Preview before sync (recommended)")
+    print("  preview = preview_sync()")
+    print("")
+    print("  # Dry run (preview only)")
+    print("  result = sync_all(dry_run=True)")
+    print("")
+    print("  # Actual sync")
     print("  result = sync_all()  # Incremental sync")
     print("  result = sync_all(force_full=True)  # Full reload")
     print("\n" + "=" * 60)

-    # Run sync
-    result = sync_all()
-    print(f"\nSynced {len(result)} stocks")
+    # Run preview first
+    print("\n[Main] Running preview first...")
+    preview = preview_sync()
+
+    if preview["sync_needed"]:
+        # Ask for confirmation
+        print("\n" + "=" * 60)
+        response = input("Proceed with sync? (y/n): ").strip().lower()
+        if response in ("y", "yes"):
+            print("\n[Main] Starting actual sync...")
+            result = sync_all()
+            print(f"\nSynced {len(result)} stocks")
+        else:
+            print("\n[Main] Sync cancelled by user")
+    else:
+        print("\n[Main] No sync needed - data is up to date")
@@ -5,29 +5,30 @@ Tests the daily interface implementation against api.md requirements:
 - tor 换手率
 - vr 量比
 """

 import pytest
 import pandas as pd
-from src.data.daily import get_daily
+from src.data.api_wrappers import get_daily


 # Expected output fields according to api.md
 EXPECTED_BASE_FIELDS = [
-    'ts_code',      # 股票代码
-    'trade_date',   # 交易日期
-    'open',         # 开盘价
-    'high',         # 最高价
-    'low',          # 最低价
-    'close',        # 收盘价
-    'pre_close',    # 昨收价
-    'change',       # 涨跌额
-    'pct_chg',      # 涨跌幅
-    'vol',          # 成交量
-    'amount',       # 成交额
+    "ts_code",      # 股票代码
+    "trade_date",   # 交易日期
+    "open",         # 开盘价
+    "high",         # 最高价
+    "low",          # 最低价
+    "close",        # 收盘价
+    "pre_close",    # 昨收价
+    "change",       # 涨跌额
+    "pct_chg",      # 涨跌幅
+    "vol",          # 成交量
+    "amount",       # 成交额
 ]

 EXPECTED_FACTOR_FIELDS = [
-    'turnover_rate',  # 换手率 (tor)
-    'volume_ratio',   # 量比 (vr)
+    "turnover_rate",  # 换手率 (tor)
+    "volume_ratio",   # 量比 (vr)
 ]
@@ -36,19 +37,19 @@ class TestGetDaily:

     def test_fetch_basic(self):
         """Test basic daily data fetch with real API."""
-        result = get_daily('000001.SZ', start_date='20240101', end_date='20240131')
+        result = get_daily("000001.SZ", start_date="20240101", end_date="20240131")

         assert isinstance(result, pd.DataFrame)
         assert len(result) >= 1
-        assert result['ts_code'].iloc[0] == '000001.SZ'
+        assert result["ts_code"].iloc[0] == "000001.SZ"

     def test_fetch_with_factors(self):
         """Test fetch with tor and vr factors."""
         result = get_daily(
-            '000001.SZ',
-            start_date='20240101',
-            end_date='20240131',
-            factors=['tor', 'vr'],
+            "000001.SZ",
+            start_date="20240101",
+            end_date="20240131",
+            factors=["tor", "vr"],
         )

         assert isinstance(result, pd.DataFrame)
@@ -61,25 +62,26 @@ class TestGetDaily:

     def test_output_fields_completeness(self):
         """Verify all required output fields are returned."""
-        result = get_daily('600000.SH')
+        result = get_daily("600000.SH")

         # Verify all base fields are present
-        assert set(EXPECTED_BASE_FIELDS).issubset(result.columns.tolist()), \
-            f"Missing fields: {set(EXPECTED_BASE_FIELDS) - set(result.columns)}"
+        assert set(EXPECTED_BASE_FIELDS).issubset(result.columns.tolist()), (
+            f"Missing fields: {set(EXPECTED_BASE_FIELDS) - set(result.columns)}"
+        )

     def test_empty_result(self):
         """Test handling of empty results."""
         # Use the real API to test the empty result for an invalid stock code
-        result = get_daily('INVALID.SZ')
+        result = get_daily("INVALID.SZ")
         assert isinstance(result, pd.DataFrame)
         assert result.empty

     def test_date_range_query(self):
         """Test query with date range."""
         result = get_daily(
-            '000001.SZ',
-            start_date='20240101',
-            end_date='20240131',
+            "000001.SZ",
+            start_date="20240101",
+            end_date="20240131",
         )

         assert isinstance(result, pd.DataFrame)
@@ -87,7 +89,7 @@ class TestGetDaily:

     def test_with_adj(self):
         """Test fetch with adjustment type."""
-        result = get_daily('000001.SZ', adj='qfq')
+        result = get_daily("000001.SZ", adj="qfq")

         assert isinstance(result, pd.DataFrame)
@@ -95,11 +97,14 @@
 def test_integration():
     """Integration test with real Tushare API (requires valid token)."""
     import os
-    token = os.environ.get('TUSHARE_TOKEN')
+
+    token = os.environ.get("TUSHARE_TOKEN")
     if not token:
         pytest.skip("TUSHARE_TOKEN not configured")

-    result = get_daily('000001.SZ', start_date='20240101', end_date='20240131', factors=['tor', 'vr'])
+    result = get_daily(
+        "000001.SZ", start_date="20240101", end_date="20240131", factors=["tor", "vr"]
+    )

     # Verify structure
     assert isinstance(result, pd.DataFrame)

@@ -112,6 +117,6 @@ def test_integration():
     assert field in result.columns, f"Missing factor field: {field}"


-if __name__ == '__main__':
+if __name__ == "__main__":
     # Run the pytest unit tests (real API calls)
-    pytest.main([__file__, '-v'])
+    pytest.main([__file__, "-v"])
@@ -9,7 +9,7 @@ import pytest
 import pandas as pd
 from pathlib import Path
 from src.data.storage import Storage
-from src.data.stock_basic import _get_csv_path
+from src.data.api_wrappers.api_stock_basic import _get_csv_path


 class TestDailyStorageValidation:

@@ -5,6 +5,7 @@ Tests the sync module's full/incremental sync logic for daily data:
 - Incremental sync when local data exists (from last_date + 1)
 - Data integrity validation
 """
+
 import pytest
 import pandas as pd
 from unittest.mock import Mock, patch, MagicMock
@@ -17,6 +18,8 @@ from src.data.sync import (
     get_next_date,
     DEFAULT_START_DATE,
 )
 from src.data.storage import Storage
+from src.data.client import TushareClient
+

 class TestDateUtilities:
@@ -63,30 +66,32 @@ class TestDataSync:

     def test_get_all_stock_codes_from_daily(self, mock_storage):
         """Test getting stock codes from daily data."""
-        with patch('src.data.sync.Storage', return_value=mock_storage):
+        with patch("src.data.sync.Storage", return_value=mock_storage):
             sync = DataSync()
             sync.storage = mock_storage

-            mock_storage.load.return_value = pd.DataFrame({
-                'ts_code': ['000001.SZ', '000001.SZ', '600000.SH'],
-            })
+            mock_storage.load.return_value = pd.DataFrame(
+                {
+                    "ts_code": ["000001.SZ", "000001.SZ", "600000.SH"],
+                }
+            )

             codes = sync.get_all_stock_codes()

             assert len(codes) == 2
-            assert '000001.SZ' in codes
-            assert '600000.SH' in codes
+            assert "000001.SZ" in codes
+            assert "600000.SH" in codes

     def test_get_all_stock_codes_fallback(self, mock_storage):
         """Test fallback to stock_basic when daily is empty."""
-        with patch('src.data.sync.Storage', return_value=mock_storage):
+        with patch("src.data.sync.Storage", return_value=mock_storage):
             sync = DataSync()
             sync.storage = mock_storage

             # First call (daily) returns empty, second call (stock_basic) returns data
             mock_storage.load.side_effect = [
                 pd.DataFrame(),  # daily empty
-                pd.DataFrame({'ts_code': ['000001.SZ', '600000.SH']}),  # stock_basic
+                pd.DataFrame({"ts_code": ["000001.SZ", "600000.SH"]}),  # stock_basic
             ]

             codes = sync.get_all_stock_codes()
@@ -95,21 +100,23 @@ class TestDataSync:

     def test_get_global_last_date(self, mock_storage):
         """Test getting global last date."""
-        with patch('src.data.sync.Storage', return_value=mock_storage):
+        with patch("src.data.sync.Storage", return_value=mock_storage):
             sync = DataSync()
             sync.storage = mock_storage

-            mock_storage.load.return_value = pd.DataFrame({
-                'ts_code': ['000001.SZ', '600000.SH'],
-                'trade_date': ['20240102', '20240103'],
-            })
+            mock_storage.load.return_value = pd.DataFrame(
+                {
+                    "ts_code": ["000001.SZ", "600000.SH"],
+                    "trade_date": ["20240102", "20240103"],
+                }
+            )

             last_date = sync.get_global_last_date()
-            assert last_date == '20240103'
+            assert last_date == "20240103"

     def test_get_global_last_date_empty(self, mock_storage):
         """Test getting last date from empty storage."""
-        with patch('src.data.sync.Storage', return_value=mock_storage):
+        with patch("src.data.sync.Storage", return_value=mock_storage):
             sync = DataSync()
             sync.storage = mock_storage

@@ -120,18 +127,23 @@ class TestDataSync:

     def test_sync_single_stock(self, mock_storage):
         """Test syncing a single stock."""
-        with patch('src.data.sync.Storage', return_value=mock_storage):
-            with patch('src.data.sync.get_daily', return_value=pd.DataFrame({
-                'ts_code': ['000001.SZ'],
-                'trade_date': ['20240102'],
-            })):
+        with patch("src.data.sync.Storage", return_value=mock_storage):
+            with patch(
+                "src.data.sync.get_daily",
+                return_value=pd.DataFrame(
+                    {
+                        "ts_code": ["000001.SZ"],
+                        "trade_date": ["20240102"],
+                    }
+                ),
+            ):
                 sync = DataSync()
                 sync.storage = mock_storage

                 result = sync.sync_single_stock(
-                    ts_code='000001.SZ',
-                    start_date='20240101',
-                    end_date='20240102',
+                    ts_code="000001.SZ",
+                    start_date="20240101",
+                    end_date="20240102",
                 )

                 assert isinstance(result, pd.DataFrame)
@@ -139,15 +151,15 @@ class TestDataSync:

     def test_sync_single_stock_empty(self, mock_storage):
         """Test syncing a stock with no data."""
-        with patch('src.data.sync.Storage', return_value=mock_storage):
-            with patch('src.data.sync.get_daily', return_value=pd.DataFrame()):
+        with patch("src.data.sync.Storage", return_value=mock_storage):
+            with patch("src.data.sync.get_daily", return_value=pd.DataFrame()):
                 sync = DataSync()
                 sync.storage = mock_storage

                 result = sync.sync_single_stock(
-                    ts_code='INVALID.SZ',
-                    start_date='20240101',
-                    end_date='20240102',
+                    ts_code="INVALID.SZ",
+                    start_date="20240101",
+                    end_date="20240102",
                 )

                 assert result.empty
@@ -158,40 +170,46 @@ class TestSyncAll:

     def test_full_sync_mode(self, mock_storage):
         """Test full sync mode when force_full=True."""
-        with patch('src.data.sync.Storage', return_value=mock_storage):
-            with patch('src.data.sync.get_daily', return_value=pd.DataFrame()):
+        with patch("src.data.sync.Storage", return_value=mock_storage):
+            with patch("src.data.sync.get_daily", return_value=pd.DataFrame()):
                 sync = DataSync()
                 sync.storage = mock_storage
                 sync.sync_single_stock = Mock(return_value=pd.DataFrame())

-                mock_storage.load.return_value = pd.DataFrame({
-                    'ts_code': ['000001.SZ'],
-                })
+                mock_storage.load.return_value = pd.DataFrame(
+                    {
+                        "ts_code": ["000001.SZ"],
+                    }
+                )

                 result = sync.sync_all(force_full=True)

                 # Verify sync_single_stock was called with default start date
                 sync.sync_single_stock.assert_called_once()
                 call_args = sync.sync_single_stock.call_args
-                assert call_args[1]['start_date'] == DEFAULT_START_DATE
+                assert call_args[1]["start_date"] == DEFAULT_START_DATE

     def test_incremental_sync_mode(self, mock_storage):
         """Test incremental sync mode when data exists."""
-        with patch('src.data.sync.Storage', return_value=mock_storage):
+        with patch("src.data.sync.Storage", return_value=mock_storage):
             sync = DataSync()
             sync.storage = mock_storage
             sync.sync_single_stock = Mock(return_value=pd.DataFrame())

             # Mock existing data with last date
             mock_storage.load.side_effect = [
-                pd.DataFrame({
-                    'ts_code': ['000001.SZ'],
-                    'trade_date': ['20240102'],
-                }),  # get_all_stock_codes
-                pd.DataFrame({
-                    'ts_code': ['000001.SZ'],
-                    'trade_date': ['20240102'],
-                }),  # get_global_last_date
+                pd.DataFrame(
+                    {
+                        "ts_code": ["000001.SZ"],
+                        "trade_date": ["20240102"],
+                    }
+                ),  # get_all_stock_codes
+                pd.DataFrame(
+                    {
+                        "ts_code": ["000001.SZ"],
+                        "trade_date": ["20240102"],
+                    }
+                ),  # get_global_last_date
             ]

             result = sync.sync_all(force_full=False)
@@ -199,28 +217,30 @@ class TestSyncAll:
             # Verify sync_single_stock was called with next date
             sync.sync_single_stock.assert_called_once()
             call_args = sync.sync_single_stock.call_args
-            assert call_args[1]['start_date'] == '20240103'
+            assert call_args[1]["start_date"] == "20240103"

     def test_manual_start_date(self, mock_storage):
         """Test sync with manual start date."""
-        with patch('src.data.sync.Storage', return_value=mock_storage):
+        with patch("src.data.sync.Storage", return_value=mock_storage):
             sync = DataSync()
             sync.storage = mock_storage
             sync.sync_single_stock = Mock(return_value=pd.DataFrame())

-            mock_storage.load.return_value = pd.DataFrame({
-                'ts_code': ['000001.SZ'],
-            })
+            mock_storage.load.return_value = pd.DataFrame(
+                {
+                    "ts_code": ["000001.SZ"],
+                }
+            )

-            result = sync.sync_all(force_full=False, start_date='20230601')
+            result = sync.sync_all(force_full=False, start_date="20230601")

             sync.sync_single_stock.assert_called_once()
             call_args = sync.sync_single_stock.call_args
-            assert call_args[1]['start_date'] == '20230601'
+            assert call_args[1]["start_date"] == "20230601"

     def test_no_stocks_found(self, mock_storage):
         """Test sync when no stocks are found."""
-        with patch('src.data.sync.Storage', return_value=mock_storage):
+        with patch("src.data.sync.Storage", return_value=mock_storage):
             sync = DataSync()
             sync.storage = mock_storage

@@ -236,7 +256,7 @@ class TestSyncAllConvenienceFunction:

     def test_sync_all_function(self):
         """Test sync_all convenience function."""
-        with patch('src.data.sync.DataSync') as MockSync:
+        with patch("src.data.sync.DataSync") as MockSync:
             mock_instance = Mock()
             mock_instance.sync_all.return_value = {}
             MockSync.return_value = mock_instance
@@ -251,5 +271,5 @@ class TestSyncAllConvenienceFunction:
         )


-if __name__ == '__main__':
-    pytest.main([__file__, '-v'])
+if __name__ == "__main__":
+    pytest.main([__file__, "-v"])
tests/test_sync_real.py (Normal file, 256 lines)
@@ -0,0 +1,256 @@
"""Tests for data sync with REAL data (read-only).

Tests verify:
1. get_global_last_date() correctly reads local data's max date
2. Incremental sync date calculation (local_last_date + 1)
3. Full sync date calculation (20180101)
4. Multi-stock scenario with real data

⚠️ IMPORTANT: These tests ONLY read data, no write operations.
- NO sync_all() calls (writes daily.h5)
- NO check_sync_needed() calls (writes trade_cal.h5)
"""

import pytest
import pandas as pd
from pathlib import Path

from src.data.sync import (
    DataSync,
    get_next_date,
    DEFAULT_START_DATE,
)
from src.data.storage import Storage


class TestDataSyncReadOnly:
    """Read-only tests for data sync - verify date calculation logic."""

    @pytest.fixture
    def storage(self):
        """Create storage instance."""
        return Storage()

    @pytest.fixture
    def data_sync(self):
        """Create DataSync instance."""
        return DataSync()

    @pytest.fixture
    def daily_exists(self, storage):
        """Check if daily.h5 exists."""
        return storage.exists("daily")

    def test_daily_h5_exists(self, storage):
        """Verify daily.h5 data file exists before running tests."""
        assert storage.exists("daily"), (
            "daily.h5 not found. Please run full sync first: "
            "uv run python -c 'from src.data.sync import sync_all; sync_all(force_full=True)'"
        )

    def test_get_global_last_date(self, data_sync, daily_exists):
        """Test get_global_last_date returns correct max date from local data."""
        if not daily_exists:
            pytest.skip("daily.h5 not found")

        last_date = data_sync.get_global_last_date()

        # Verify it's a valid date string
        assert last_date is not None, "get_global_last_date returned None"
        assert isinstance(last_date, str), f"Expected str, got {type(last_date)}"
        assert len(last_date) == 8, f"Expected 8-digit date, got {last_date}"
        assert last_date.isdigit(), f"Expected numeric date, got {last_date}"

        # Verify by reading storage directly
        daily_data = data_sync.storage.load("daily")
        expected_max = str(daily_data["trade_date"].max())

        assert last_date == expected_max, (
            f"get_global_last_date returned {last_date}, "
            f"but actual max date is {expected_max}"
        )

        print(f"[TEST] Local data last date: {last_date}")

    def test_incremental_sync_date_calculation(self, data_sync, daily_exists):
        """Test incremental sync: start_date = local_last_date + 1.

        This verifies that when local data exists, incremental sync should
        fetch data from (local_last_date + 1), not from 20180101.
        """
        if not daily_exists:
            pytest.skip("daily.h5 not found")

        # Get local last date
        local_last_date = data_sync.get_global_last_date()
        assert local_last_date is not None, "No local data found"

        # Calculate expected incremental start date
        expected_start_date = get_next_date(local_last_date)

        # Verify the calculation is correct
        local_last_int = int(local_last_date)
        expected_int = local_last_int + 1
        actual_int = int(expected_start_date)

        assert actual_int == expected_int, (
            f"Incremental start date calculation error: "
            f"expected {expected_int}, got {actual_int}"
        )

        print(
            f"[TEST] Incremental sync: local_last={local_last_date}, "
            f"start_date should be {expected_start_date}"
        )

        # Verify this is NOT 20180101 (would be full sync)
        assert expected_start_date != DEFAULT_START_DATE, (
            f"Incremental sync should NOT start from {DEFAULT_START_DATE}"
        )

    def test_full_sync_date_calculation(self):
        """Test full sync: start_date = 20180101 when force_full=True.

        This verifies that force_full=True always starts from 20180101.
        """
        # Full sync should always use DEFAULT_START_DATE
        full_sync_start = DEFAULT_START_DATE

        assert full_sync_start == "20180101", (
            f"Full sync should start from 20180101, got {full_sync_start}"
        )

        print(f"[TEST] Full sync start date: {full_sync_start}")

    def test_date_comparison_logic(self, data_sync, daily_exists):
        """Test date comparison: incremental vs full sync selection logic.

        Verify that:
        - If local_last_date < today: incremental sync needed
        - If local_last_date >= today: no sync needed
        """
        if not daily_exists:
            pytest.skip("daily.h5 not found")

        from datetime import datetime

        local_last_date = data_sync.get_global_last_date()
        today = datetime.now().strftime("%Y%m%d")

        local_last_int = int(local_last_date)
        today_int = int(today)

        # Log the comparison
        print(
            f"[TEST] Date comparison: local_last={local_last_date} ({local_last_int}), "
            f"today={today} ({today_int})"
        )

        # This test just verifies the comparison logic works
        if local_last_int < today_int:
            print("[TEST] Local data is older than today - sync needed")
            # Incremental sync should fetch from local_last_date + 1
            sync_start = get_next_date(local_last_date)
            assert int(sync_start) > local_last_int, (
                "Sync start should be after local last"
            )
        else:
            print("[TEST] Local data is up-to-date - no sync needed")

    def test_get_all_stock_codes_real_data(self, data_sync, daily_exists):
        """Test get_all_stock_codes returns multiple real stock codes."""
        if not daily_exists:
            pytest.skip("daily.h5 not found")

        codes = data_sync.get_all_stock_codes()

        # Verify it's a list
        assert isinstance(codes, list), f"Expected list, got {type(codes)}"
        assert len(codes) > 0, "No stock codes found"

        # Verify multiple stocks
        assert len(codes) >= 10, (
            f"Expected at least 10 stocks for multi-stock test, got {len(codes)}"
        )

        # Verify format (should be like 000001.SZ, 600000.SH)
        sample_codes = codes[:5]
        for code in sample_codes:
            assert "." in code, f"Invalid stock code format: {code}"
            suffix = code.split(".")[-1]
            assert suffix in ["SZ", "SH"], f"Invalid exchange suffix: {suffix}"

        print(f"[TEST] Found {len(codes)} stock codes (sample: {sample_codes})")

    def test_multi_stock_date_range(self, data_sync, daily_exists):
        """Test that multiple stocks share the same date range in local data.

        This verifies that local data has consistent date coverage across stocks.
        """
        if not daily_exists:
            pytest.skip("daily.h5 not found")

        daily_data = data_sync.storage.load("daily")

        # Get date range for each stock
        stock_dates = daily_data.groupby("ts_code")["trade_date"].agg(["min", "max"])

        # Get global min and max
        global_min = str(daily_data["trade_date"].min())
        global_max = str(daily_data["trade_date"].max())

        print(f"[TEST] Global date range: {global_min} to {global_max}")
        print(f"[TEST] Total stocks: {len(stock_dates)}")

        # Verify we have data for multiple stocks
        assert len(stock_dates) >= 10, (
            f"Expected at least 10 stocks, got {len(stock_dates)}"
        )

        # Verify the date range is plausible. Note: subtracting the YYYYMMDD
        # integers is only a coarse proxy, not an exact day count.
        global_min_int = int(global_min)
        global_max_int = int(global_max)
        days_span = global_max_int - global_min_int

        assert days_span > 100, (
            f"Date range too small: {days_span} days. "
            f"Expected at least 100 days of data."
        )

        print(f"[TEST] Date span: {days_span} days")


class TestDateUtilities:
    """Test date utility functions."""

    def test_get_next_date(self):
        """Test get_next_date correctly calculates next day."""
        # Test normal cases
        assert get_next_date("20240101") == "20240102"
        assert get_next_date("20240131") == "20240201"  # Month boundary
        assert get_next_date("20241231") == "20250101"  # Year boundary

    def test_incremental_vs_full_sync_logic(self):
        """Test the logic difference between incremental and full sync.

        Incremental: start_date = local_last_date + 1
        Full: start_date = 20180101
        """
        # Scenario 1: Local data exists
        local_last_date = "20240115"
        incremental_start = get_next_date(local_last_date)

        assert incremental_start == "20240116"
        assert incremental_start != DEFAULT_START_DATE

        # Scenario 2: Force full sync
        full_sync_start = DEFAULT_START_DATE  # "20180101"

        assert full_sync_start == "20180101"
        assert incremental_start != full_sync_start

        print("[TEST] Incremental vs Full sync logic verified")


if __name__ == "__main__":
    pytest.main([__file__, "-v", "-s"])
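The `days_span` check in `test_multi_stock_date_range` subtracts YYYYMMDD strings as integers, which overstates gaps that cross month or year boundaries. A minimal sketch of the exact calendar arithmetic, matching what `get_next_date` is expected to return (the helper names `next_day` and `day_span` are hypothetical, not part of the project):

```python
from datetime import datetime, timedelta

FMT = "%Y%m%d"


def next_day(date_str: str) -> str:
    """Next calendar day for a YYYYMMDD string, handling month/year rollover."""
    return (datetime.strptime(date_str, FMT) + timedelta(days=1)).strftime(FMT)


def day_span(min_date: str, max_date: str) -> int:
    """Exact number of calendar days between two YYYYMMDD strings."""
    return (datetime.strptime(max_date, FMT) - datetime.strptime(min_date, FMT)).days


print(next_day("20241231"))              # → 20250101 (year boundary)
print(day_span("20240101", "20240131"))  # → 30
```

By contrast, the integer subtraction used in the test gives `20240131 - 20240101 = 30` here by coincidence, but `20250101 - 20241231 = 8870`, which is why the test treats the value only as a coarse lower-bound check.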