Compare commits
3 Commits
05228ce9de
...
9965ce5706
| Author | SHA1 | Date | |
|---|---|---|---|
| 9965ce5706 | |||
| e81d39ae0d | |||
| 8fc88b60e3 |
1
.gitignore
vendored
1
.gitignore
vendored
@@ -46,6 +46,7 @@ logs/
|
|||||||
# IDE和编辑器
|
# IDE和编辑器
|
||||||
.vscode/
|
.vscode/
|
||||||
.idea/
|
.idea/
|
||||||
|
.opencode/
|
||||||
*.swp
|
*.swp
|
||||||
*.swo
|
*.swo
|
||||||
*~
|
*~
|
||||||
|
|||||||
3
pytest.ini
Normal file
3
pytest.ini
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
[pytest]
|
||||||
|
pythonpath = .
|
||||||
|
testpaths = tests
|
||||||
@@ -6,7 +6,7 @@ Provides simplified interfaces for fetching and storing Tushare data.
|
|||||||
from src.data.config import Config, get_config
|
from src.data.config import Config, get_config
|
||||||
from src.data.client import TushareClient
|
from src.data.client import TushareClient
|
||||||
from src.data.storage import Storage
|
from src.data.storage import Storage
|
||||||
from src.data.stock_basic import get_stock_basic, sync_all_stocks
|
from src.data.api_wrappers import get_stock_basic, sync_all_stocks
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"Config",
|
"Config",
|
||||||
|
|||||||
244
src/data/api_wrappers/API_INTERFACE_SPEC.md
Normal file
244
src/data/api_wrappers/API_INTERFACE_SPEC.md
Normal file
@@ -0,0 +1,244 @@
|
|||||||
|
# ProStock 数据接口封装规范
|
||||||
|
|
||||||
|
## 1. 概述
|
||||||
|
|
||||||
|
本文档定义了新增 Tushare API 接口封装的标准规范。所有非特殊接口必须遵循此规范,确保:
|
||||||
|
- 代码风格统一
|
||||||
|
- 自动 sync 支持
|
||||||
|
- 增量更新逻辑一致
|
||||||
|
- 减少存储写入压力
|
||||||
|
|
||||||
|
## 2. 接口分类
|
||||||
|
|
||||||
|
### 2.1 特殊接口(不参与统一 sync)
|
||||||
|
|
||||||
|
以下接口有独立的同步逻辑,不参与自动 sync 机制:
|
||||||
|
|
||||||
|
| 接口类型 | 示例 | 说明 |
|
||||||
|
|---------|------|------|
|
||||||
|
| 交易日历 | `trade_cal` | 全局数据,按日期范围获取 |
|
||||||
|
| 股票基础信息 | `stock_basic` | 一次性全量获取,CSV 存储 |
|
||||||
|
| 辅助数据 | 行业分类、概念分类 | 低频更新,独立管理 |
|
||||||
|
|
||||||
|
### 2.2 标准接口(必须遵循本规范)
|
||||||
|
|
||||||
|
所有按股票或按日期获取的因子数据、行情数据、财务数据等,必须遵循本规范。
|
||||||
|
|
||||||
|
## 3. 文件结构要求
|
||||||
|
|
||||||
|
### 3.1 文件命名
|
||||||
|
|
||||||
|
```
|
||||||
|
api_{data_type}.py
|
||||||
|
```
|
||||||
|
|
||||||
|
示例:`api_daily.py`、`api_moneyflow.py`、`api_limit_list.py`
|
||||||
|
|
||||||
|
### 3.2 文件位置
|
||||||
|
|
||||||
|
所有接口文件必须位于 `src/data/api_wrappers/` 目录下。
|
||||||
|
|
||||||
|
### 3.3 导出要求
|
||||||
|
|
||||||
|
新接口必须在 `src/data/__init__.py` 中导出:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from src.data.{module_name} import get_{data_type}
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
# ... 其他导出 ...
|
||||||
|
"get_{data_type}",
|
||||||
|
]
|
||||||
|
```
|
||||||
|
|
||||||
|
## 4. 接口设计规范
|
||||||
|
|
||||||
|
### 4.1 数据获取函数签名要求
|
||||||
|
|
||||||
|
函数必须返回 `pd.DataFrame`,参数必须包含以下之一:
|
||||||
|
|
||||||
|
#### 4.1.1 按日期获取的接口(优先)
|
||||||
|
|
||||||
|
适用于:涨跌停、龙虎榜、筹码分布等。
|
||||||
|
|
||||||
|
**函数签名要求**:
|
||||||
|
|
||||||
|
```python
|
||||||
|
def get_{data_type}(
|
||||||
|
trade_date: Optional[str] = None,
|
||||||
|
start_date: Optional[str] = None,
|
||||||
|
end_date: Optional[str] = None,
|
||||||
|
ts_code: Optional[str] = None,
|
||||||
|
# 其他可选参数...
|
||||||
|
) -> pd.DataFrame:
|
||||||
|
```
|
||||||
|
|
||||||
|
**要求**:
|
||||||
|
- 优先使用 `trade_date` 获取单日全市场数据
|
||||||
|
- 支持 `start_date + end_date` 获取区间数据
|
||||||
|
- `ts_code` 作为可选过滤参数
|
||||||
|
|
||||||
|
#### 4.1.2 按股票获取的接口
|
||||||
|
|
||||||
|
适用于:日线行情、资金流向等。
|
||||||
|
|
||||||
|
**函数签名要求**:
|
||||||
|
|
||||||
|
```python
|
||||||
|
def get_{data_type}(
|
||||||
|
ts_code: str,
|
||||||
|
start_date: Optional[str] = None,
|
||||||
|
end_date: Optional[str] = None,
|
||||||
|
# 其他可选参数...
|
||||||
|
) -> pd.DataFrame:
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4.2 文档字符串要求
|
||||||
|
|
||||||
|
函数必须包含 Google 风格的完整文档字符串,包含:
|
||||||
|
- 函数功能描述
|
||||||
|
- `Args` 部分:所有参数说明
|
||||||
|
- `Returns` 部分:返回的 DataFrame 包含的字段说明
|
||||||
|
- `Example` 部分:使用示例
|
||||||
|
|
||||||
|
### 4.3 日期格式要求
|
||||||
|
|
||||||
|
- 所有日期参数和返回值使用 `YYYYMMDD` 字符串格式
|
||||||
|
- 统一使用 `trade_date` 作为日期字段名
|
||||||
|
- 如果 API 返回其他日期字段名(如 `date`、`end_date`),必须在返回前重命名为 `trade_date`
|
||||||
|
|
||||||
|
### 4.4 股票代码要求
|
||||||
|
|
||||||
|
- 统一使用 `ts_code` 作为股票代码字段名
|
||||||
|
- 格式:`{code}.{exchange}`,如 `000001.SZ`、`600000.SH`
|
||||||
|
|
||||||
|
### 4.5 令牌桶限速要求
|
||||||
|
|
||||||
|
所有 API 调用必须通过 `TushareClient`,自动满足令牌桶限速要求。
|
||||||
|
|
||||||
|
## 5. Sync 集成规范
|
||||||
|
|
||||||
|
### 5.1 DATASET_CONFIG 注册要求
|
||||||
|
|
||||||
|
新接口必须在 `DataSync.DATASET_CONFIG` 中注册,配置项:
|
||||||
|
|
||||||
|
```python
|
||||||
|
"{new_data_type}": {
|
||||||
|
"api_name": "{tushare_api_name}", # Tushare API 名称
|
||||||
|
"fetch_by": "date", # "date" 或 "stock"
|
||||||
|
"date_field": "trade_date",
|
||||||
|
"key_fields": ["ts_code", "trade_date"], # 用于去重的主键
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 5.2 fetch_by 取值规则
|
||||||
|
|
||||||
|
- **优先使用 `"date"`**:如果 API 支持按日期获取全市场数据
|
||||||
|
- 仅当 API 不支持按日期获取时才使用 `"stock"`
|
||||||
|
|
||||||
|
### 5.3 sync 方法要求
|
||||||
|
|
||||||
|
必须实现对应的 sync 方法或复用通用方法:
|
||||||
|
|
||||||
|
```python
|
||||||
|
def sync_{data_type}(self, force_full: bool = False) -> pd.DataFrame:
|
||||||
|
"""Sync {数据描述}。"""
|
||||||
|
return self.sync_dataset("{data_type}", force_full)
|
||||||
|
```
|
||||||
|
|
||||||
|
同时提供便捷函数:
|
||||||
|
|
||||||
|
```python
|
||||||
|
def sync_{data_type}(force_full: bool = False) -> pd.DataFrame:
|
||||||
|
"""Sync {数据描述}。"""
|
||||||
|
sync_manager = DataSync()
|
||||||
|
return sync_manager.sync_{data_type}(force_full)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 5.4 增量更新要求
|
||||||
|
|
||||||
|
- 必须实现增量更新逻辑(自动检查本地最新日期)
|
||||||
|
- 使用 `force_full` 参数支持强制全量同步
|
||||||
|
|
||||||
|
## 6. 存储规范
|
||||||
|
|
||||||
|
### 6.1 存储方式
|
||||||
|
|
||||||
|
所有数据通过 `Storage` 类进行 HDF5 存储。
|
||||||
|
|
||||||
|
### 6.2 写入策略
|
||||||
|
|
||||||
|
**要求**:所有数据在请求完成后**一次性写入**,而非逐条写入。
|
||||||
|
|
||||||
|
### 6.3 去重要求
|
||||||
|
|
||||||
|
使用 `key_fields` 配置的字段进行去重,默认使用 `["ts_code", "trade_date"]`。
|
||||||
|
|
||||||
|
## 7. 测试规范
|
||||||
|
|
||||||
|
### 7.1 测试文件要求
|
||||||
|
|
||||||
|
必须创建对应的测试文件:`tests/test_{data_type}.py`
|
||||||
|
|
||||||
|
### 7.2 测试覆盖要求
|
||||||
|
|
||||||
|
- 测试按日期获取
|
||||||
|
- 测试按股票获取(如果支持)
|
||||||
|
- 必须 mock `TushareClient`
|
||||||
|
- 测试覆盖正常和异常情况
|
||||||
|
|
||||||
|
## 8. 新增接口完整流程
|
||||||
|
|
||||||
|
### 8.1 创建接口文件
|
||||||
|
|
||||||
|
1. 在 `src/data/api_wrappers/` 下创建 `api_{data_type}.py`
|
||||||
|
2. 实现数据获取函数,遵循第 4 节规范
|
||||||
|
|
||||||
|
### 8.2 注册 sync 支持
|
||||||
|
|
||||||
|
1. 在 `sync.py` 的 `DataSync.DATASET_CONFIG` 中注册
|
||||||
|
2. 实现对应的 sync 方法
|
||||||
|
3. 提供便捷函数
|
||||||
|
|
||||||
|
### 8.3 更新导出
|
||||||
|
|
||||||
|
在 `src/data/__init__.py` 中导出接口函数。
|
||||||
|
|
||||||
|
### 8.4 创建测试
|
||||||
|
|
||||||
|
创建 `tests/test_{data_type}.py`,覆盖关键场景。
|
||||||
|
|
||||||
|
## 9. 检查清单
|
||||||
|
|
||||||
|
### 9.1 文件结构
|
||||||
|
- [ ] 文件位于 `src/data/api_wrappers/api_{data_type}.py`
|
||||||
|
- [ ] 已更新 `src/data/__init__.py` 导出公共接口
|
||||||
|
- [ ] 已创建 `tests/test_{data_type}.py` 测试文件
|
||||||
|
|
||||||
|
### 9.2 接口实现
|
||||||
|
- [ ] 数据获取函数使用 `TushareClient`
|
||||||
|
- [ ] 函数包含完整的 Google 风格文档字符串
|
||||||
|
- [ ] 日期参数使用 `YYYYMMDD` 格式
|
||||||
|
- [ ] 返回的 DataFrame 包含 `ts_code` 和 `trade_date` 字段
|
||||||
|
- [ ] 优先实现按日期获取的接口(如果 API 支持)
|
||||||
|
|
||||||
|
### 9.3 Sync 集成
|
||||||
|
- [ ] 已在 `DataSync.DATASET_CONFIG` 中注册
|
||||||
|
- [ ] 正确设置 `fetch_by`("date" 或 "stock")
|
||||||
|
- [ ] 正确设置 `date_field` 和 `key_fields`
|
||||||
|
- [ ] 已实现对应的 sync 方法或复用通用方法
|
||||||
|
- [ ] 增量更新逻辑正确(检查本地最新日期)
|
||||||
|
|
||||||
|
### 9.4 存储优化
|
||||||
|
- [ ] 所有数据一次性写入(非逐条)
|
||||||
|
- [ ] 使用 `storage.save(mode="append")` 进行增量保存
|
||||||
|
- [ ] 去重字段配置正确
|
||||||
|
|
||||||
|
### 9.5 测试
|
||||||
|
- [ ] 已编写单元测试
|
||||||
|
- [ ] 已 mock TushareClient
|
||||||
|
- [ ] 测试覆盖正常和异常情况
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
**最后更新**: 2026-02-01
|
||||||
40
src/data/api_wrappers/__init__.py
Normal file
40
src/data/api_wrappers/__init__.py
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
"""Tushare API wrapper modules.
|
||||||
|
|
||||||
|
This package contains simplified interfaces for fetching data from Tushare API.
|
||||||
|
All wrapper files follow the naming convention: api_{data_type}.py
|
||||||
|
|
||||||
|
Available APIs:
|
||||||
|
- api_daily: Daily market data (日线行情)
|
||||||
|
- api_stock_basic: Stock basic information (股票基本信息)
|
||||||
|
- api_trade_cal: Trading calendar (交易日历)
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> from src.data.api_wrappers import get_daily, get_stock_basic, get_trade_cal
|
||||||
|
>>> data = get_daily('000001.SZ', start_date='20240101', end_date='20240131')
|
||||||
|
>>> stocks = get_stock_basic()
|
||||||
|
>>> calendar = get_trade_cal('20240101', '20240131')
|
||||||
|
"""
|
||||||
|
|
||||||
|
from src.data.api_wrappers.api_daily import get_daily
|
||||||
|
from src.data.api_wrappers.api_stock_basic import get_stock_basic, sync_all_stocks
|
||||||
|
from src.data.api_wrappers.api_trade_cal import (
|
||||||
|
get_trade_cal,
|
||||||
|
get_trading_days,
|
||||||
|
get_first_trading_day,
|
||||||
|
get_last_trading_day,
|
||||||
|
sync_trade_cal_cache,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
# Daily market data
|
||||||
|
"get_daily",
|
||||||
|
# Stock basic information
|
||||||
|
"get_stock_basic",
|
||||||
|
"sync_all_stocks",
|
||||||
|
# Trade calendar
|
||||||
|
"get_trade_cal",
|
||||||
|
"get_trading_days",
|
||||||
|
"get_first_trading_day",
|
||||||
|
"get_last_trading_day",
|
||||||
|
"sync_trade_cal_cache",
|
||||||
|
]
|
||||||
@@ -179,4 +179,75 @@ df = pro.query('trade_cal', start_date='20180101', end_date='20181231')
|
|||||||
17 SSE 20180118 1
|
17 SSE 20180118 1
|
||||||
18 SSE 20180119 1
|
18 SSE 20180119 1
|
||||||
19 SSE 20180120 0
|
19 SSE 20180120 0
|
||||||
20 SSE 20180121 0
|
20 SSE 20180121 0
|
||||||
|
|
||||||
|
|
||||||
|
每日指标
|
||||||
|
接口:daily_basic,可以通过数据工具调试和查看数据。
|
||||||
|
更新时间:交易日每日15点~17点之间
|
||||||
|
描述:获取全部股票每日重要的基本面指标,可用于选股分析、报表展示等。单次请求最大返回6000条数据,可按日线循环提取全部历史。
|
||||||
|
积分:至少2000积分才可以调取,5000积分无总量限制,具体请参阅积分获取办法
|
||||||
|
|
||||||
|
输入参数
|
||||||
|
|
||||||
|
名称 类型 必选 描述
|
||||||
|
ts_code str N 股票代码(二选一)
|
||||||
|
trade_date str N 交易日期 (二选一)
|
||||||
|
start_date str N 开始日期(YYYYMMDD)
|
||||||
|
end_date str N 结束日期(YYYYMMDD)
|
||||||
|
注:日期都填YYYYMMDD格式,比如20181010
|
||||||
|
|
||||||
|
输出参数
|
||||||
|
|
||||||
|
名称 类型 描述
|
||||||
|
ts_code str TS股票代码
|
||||||
|
trade_date str 交易日期
|
||||||
|
close float 当日收盘价
|
||||||
|
turnover_rate float 换手率(%)
|
||||||
|
turnover_rate_f float 换手率(自由流通股)
|
||||||
|
volume_ratio float 量比
|
||||||
|
pe float 市盈率(总市值/净利润, 亏损的PE为空)
|
||||||
|
pe_ttm float 市盈率(TTM,亏损的PE为空)
|
||||||
|
pb float 市净率(总市值/净资产)
|
||||||
|
ps float 市销率
|
||||||
|
ps_ttm float 市销率(TTM)
|
||||||
|
dv_ratio float 股息率 (%)
|
||||||
|
dv_ttm float 股息率(TTM)(%)
|
||||||
|
total_share float 总股本 (万股)
|
||||||
|
float_share float 流通股本 (万股)
|
||||||
|
free_share float 自由流通股本 (万)
|
||||||
|
total_mv float 总市值 (万元)
|
||||||
|
circ_mv float 流通市值(万元)
|
||||||
|
接口用法
|
||||||
|
|
||||||
|
|
||||||
|
pro = ts.pro_api()
|
||||||
|
|
||||||
|
df = pro.daily_basic(ts_code='', trade_date='20180726', fields='ts_code,trade_date,turnover_rate,volume_ratio,pe,pb')
|
||||||
|
或者
|
||||||
|
|
||||||
|
|
||||||
|
df = pro.query('daily_basic', ts_code='', trade_date='20180726',fields='ts_code,trade_date,turnover_rate,volume_ratio,pe,pb')
|
||||||
|
数据样例
|
||||||
|
|
||||||
|
ts_code trade_date turnover_rate volume_ratio pe pb
|
||||||
|
0 600230.SH 20180726 2.4584 0.72 8.6928 3.7203
|
||||||
|
1 600237.SH 20180726 1.4737 0.88 166.4001 1.8868
|
||||||
|
2 002465.SZ 20180726 0.7489 0.72 71.8943 2.6391
|
||||||
|
3 300732.SZ 20180726 6.7083 0.77 21.8101 3.2513
|
||||||
|
4 600007.SH 20180726 0.0381 0.61 23.7696 2.3774
|
||||||
|
5 300068.SZ 20180726 1.4583 0.52 27.8166 1.7549
|
||||||
|
6 300552.SZ 20180726 2.0728 0.95 56.8004 2.9279
|
||||||
|
7 601369.SH 20180726 0.2088 0.95 44.1163 1.8001
|
||||||
|
8 002518.SZ 20180726 0.5814 0.76 15.1004 2.5626
|
||||||
|
9 002913.SZ 20180726 12.1096 1.03 33.1279 2.9217
|
||||||
|
10 601818.SH 20180726 0.1893 0.86 6.3064 0.7209
|
||||||
|
11 600926.SH 20180726 0.6065 0.46 9.1772 0.9808
|
||||||
|
12 002166.SZ 20180726 0.7582 0.82 16.9868 3.3452
|
||||||
|
13 600841.SH 20180726 0.3754 1.02 66.2647 2.2302
|
||||||
|
14 300634.SZ 20180726 23.1127 1.26 120.3053 14.3168
|
||||||
|
15 300126.SZ 20180726 1.2304 1.11 348.4306 1.5171
|
||||||
|
16 300718.SZ 20180726 17.6612 0.92 32.0239 3.8661
|
||||||
|
17 000708.SZ 20180726 0.5575 0.70 10.3674 1.0276
|
||||||
|
18 002626.SZ 20180726 0.6187 0.83 22.7580 4.2446
|
||||||
|
19 600816.SH 20180726 0.6745 0.65 11.0778 3.2214
|
||||||
@@ -3,6 +3,7 @@
|
|||||||
A single function to fetch A股日线行情 data from Tushare.
|
A single function to fetch A股日线行情 data from Tushare.
|
||||||
Supports all output fields including tor (换手率) and vr (量比).
|
Supports all output fields including tor (换手率) and vr (量比).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from typing import Optional, List, Literal
|
from typing import Optional, List, Literal
|
||||||
from src.data.client import TushareClient
|
from src.data.client import TushareClient
|
||||||
@@ -33,7 +34,7 @@ def get_daily(
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
pd.DataFrame with daily market data containing:
|
pd.DataFrame with daily market data containing:
|
||||||
- Base fields: ts_code, trade_date, open, high, low, close, pre_close,
|
- Base fields: ts_code, trade_date, open, high, low, close, pre_close,
|
||||||
change, pct_chg, vol, amount
|
change, pct_chg, vol, amount
|
||||||
- Factor fields (if requested): tor, vr
|
- Factor fields (if requested): tor, vr
|
||||||
- Adjustment factor (if adjfactor=True): adjfactor
|
- Adjustment factor (if adjfactor=True): adjfactor
|
||||||
@@ -3,6 +3,7 @@
|
|||||||
Fetch basic stock information including code, name, listing date, etc.
|
Fetch basic stock information including code, name, listing date, etc.
|
||||||
This is a special interface - call once to get all stocks (listed and delisted).
|
This is a special interface - call once to get all stocks (listed and delisted).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -1,4 +1,5 @@
|
|||||||
"""Simplified HDF5 storage for data persistence."""
|
"""Simplified HDF5 storage for data persistence."""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -47,7 +48,9 @@ class Storage:
|
|||||||
# Merge with existing data
|
# Merge with existing data
|
||||||
existing = store[name]
|
existing = store[name]
|
||||||
combined = pd.concat([existing, data], ignore_index=True)
|
combined = pd.concat([existing, data], ignore_index=True)
|
||||||
combined = combined.drop_duplicates(subset=["ts_code", "trade_date"], keep="last")
|
combined = combined.drop_duplicates(
|
||||||
|
subset=["ts_code", "trade_date"], keep="last"
|
||||||
|
)
|
||||||
store.put(name, combined, format="table")
|
store.put(name, combined, format="table")
|
||||||
|
|
||||||
print(f"[Storage] Saved {len(data)} rows to {file_path}")
|
print(f"[Storage] Saved {len(data)} rows to {file_path}")
|
||||||
@@ -57,10 +60,13 @@ class Storage:
|
|||||||
print(f"[Storage] Error saving {name}: {e}")
|
print(f"[Storage] Error saving {name}: {e}")
|
||||||
return {"status": "error", "error": str(e)}
|
return {"status": "error", "error": str(e)}
|
||||||
|
|
||||||
def load(self, name: str,
|
def load(
|
||||||
start_date: Optional[str] = None,
|
self,
|
||||||
end_date: Optional[str] = None,
|
name: str,
|
||||||
ts_code: Optional[str] = None) -> pd.DataFrame:
|
start_date: Optional[str] = None,
|
||||||
|
end_date: Optional[str] = None,
|
||||||
|
ts_code: Optional[str] = None,
|
||||||
|
) -> pd.DataFrame:
|
||||||
"""Load data from HDF5 file.
|
"""Load data from HDF5 file.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -80,14 +86,25 @@ class Storage:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
with pd.HDFStore(file_path, mode="r") as store:
|
with pd.HDFStore(file_path, mode="r") as store:
|
||||||
if name not in store.keys():
|
keys = store.keys()
|
||||||
|
# Handle both '/daily' and 'daily' keys
|
||||||
|
actual_key = None
|
||||||
|
if name in keys:
|
||||||
|
actual_key = name
|
||||||
|
elif f"/{name}" in keys:
|
||||||
|
actual_key = f"/{name}"
|
||||||
|
|
||||||
|
if actual_key is None:
|
||||||
return pd.DataFrame()
|
return pd.DataFrame()
|
||||||
|
|
||||||
data = store[name]
|
data = store[actual_key]
|
||||||
|
|
||||||
# Apply filters
|
# Apply filters
|
||||||
if start_date and end_date and "trade_date" in data.columns:
|
if start_date and end_date and "trade_date" in data.columns:
|
||||||
data = data[(data["trade_date"] >= start_date) & (data["trade_date"] <= end_date)]
|
data = data[
|
||||||
|
(data["trade_date"] >= start_date)
|
||||||
|
& (data["trade_date"] <= end_date)
|
||||||
|
]
|
||||||
|
|
||||||
if ts_code and "ts_code" in data.columns:
|
if ts_code and "ts_code" in data.columns:
|
||||||
data = data[data["ts_code"] == ts_code]
|
data = data[data["ts_code"] == ts_code]
|
||||||
|
|||||||
295
src/data/sync.py
295
src/data/sync.py
@@ -5,11 +5,15 @@ This module provides data fetching functions with intelligent sync logic:
|
|||||||
- If local file exists: incremental update (fetch from latest date + 1 day)
|
- If local file exists: incremental update (fetch from latest date + 1 day)
|
||||||
- Multi-threaded concurrent fetching for improved performance
|
- Multi-threaded concurrent fetching for improved performance
|
||||||
- Stop immediately on any exception
|
- Stop immediately on any exception
|
||||||
|
- Preview mode: check data volume and samples before actual sync
|
||||||
|
|
||||||
Currently supported data types:
|
Currently supported data types:
|
||||||
- daily: Daily market data (with turnover rate and volume ratio)
|
- daily: Daily market data (with turnover rate and volume ratio)
|
||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
|
# Preview sync (check data volume and samples without writing)
|
||||||
|
preview_sync()
|
||||||
|
|
||||||
# Sync all stocks (full load)
|
# Sync all stocks (full load)
|
||||||
sync_all()
|
sync_all()
|
||||||
|
|
||||||
@@ -18,6 +22,9 @@ Usage:
|
|||||||
|
|
||||||
# Force full reload
|
# Force full reload
|
||||||
sync_all(force_full=True)
|
sync_all(force_full=True)
|
||||||
|
|
||||||
|
# Dry run (preview only, no write)
|
||||||
|
sync_all(dry_run=True)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
@@ -30,8 +37,8 @@ import sys
|
|||||||
|
|
||||||
from src.data.client import TushareClient
|
from src.data.client import TushareClient
|
||||||
from src.data.storage import Storage
|
from src.data.storage import Storage
|
||||||
from src.data.daily import get_daily
|
from src.data.api_wrappers import get_daily
|
||||||
from src.data.trade_cal import (
|
from src.data.api_wrappers import (
|
||||||
get_first_trading_day,
|
get_first_trading_day,
|
||||||
get_last_trading_day,
|
get_last_trading_day,
|
||||||
sync_trade_cal_cache,
|
sync_trade_cal_cache,
|
||||||
@@ -114,7 +121,8 @@ class DataSync:
|
|||||||
List of stock codes
|
List of stock codes
|
||||||
"""
|
"""
|
||||||
# Import sync_all_stocks here to avoid circular imports
|
# Import sync_all_stocks here to avoid circular imports
|
||||||
from src.data.stock_basic import sync_all_stocks, _get_csv_path
|
from src.data.api_wrappers import sync_all_stocks
|
||||||
|
from src.data.api_wrappers.api_stock_basic import _get_csv_path
|
||||||
|
|
||||||
# First, ensure stock_basic.csv is up-to-date with all stocks
|
# First, ensure stock_basic.csv is up-to-date with all stocks
|
||||||
print("[DataSync] Ensuring stock_basic.csv is up-to-date...")
|
print("[DataSync] Ensuring stock_basic.csv is up-to-date...")
|
||||||
@@ -278,6 +286,184 @@ class DataSync:
|
|||||||
print(f"[DataSync] Incremental sync needed from {sync_start} to {cal_last}")
|
print(f"[DataSync] Incremental sync needed from {sync_start} to {cal_last}")
|
||||||
return (True, sync_start, cal_last, local_last_date)
|
return (True, sync_start, cal_last, local_last_date)
|
||||||
|
|
||||||
|
def preview_sync(
|
||||||
|
self,
|
||||||
|
force_full: bool = False,
|
||||||
|
start_date: Optional[str] = None,
|
||||||
|
end_date: Optional[str] = None,
|
||||||
|
sample_size: int = 3,
|
||||||
|
) -> dict:
|
||||||
|
"""Preview sync data volume and samples without actually syncing.
|
||||||
|
|
||||||
|
This method provides a preview of what would be synced, including:
|
||||||
|
- Number of stocks to be synced
|
||||||
|
- Date range for sync
|
||||||
|
- Estimated total records
|
||||||
|
- Sample data from first few stocks
|
||||||
|
|
||||||
|
Args:
|
||||||
|
force_full: If True, preview full sync from 20180101
|
||||||
|
start_date: Manual start date (overrides auto-detection)
|
||||||
|
end_date: Manual end date (defaults to today)
|
||||||
|
sample_size: Number of sample stocks to fetch for preview (default: 3)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with preview information:
|
||||||
|
{
|
||||||
|
'sync_needed': bool,
|
||||||
|
'stock_count': int,
|
||||||
|
'start_date': str,
|
||||||
|
'end_date': str,
|
||||||
|
'estimated_records': int,
|
||||||
|
'sample_data': pd.DataFrame,
|
||||||
|
'mode': str, # 'full', 'incremental', 'partial', or 'none'
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("[DataSync] Preview Mode - Analyzing sync requirements...")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
# First, ensure trade calendar cache is up-to-date
|
||||||
|
print("[DataSync] Syncing trade calendar cache...")
|
||||||
|
sync_trade_cal_cache()
|
||||||
|
|
||||||
|
# Determine date range
|
||||||
|
if end_date is None:
|
||||||
|
end_date = get_today_date()
|
||||||
|
|
||||||
|
# Check if sync is needed
|
||||||
|
sync_needed, cal_start, cal_end, local_last = self.check_sync_needed(force_full)
|
||||||
|
|
||||||
|
if not sync_needed:
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("[DataSync] Preview Result")
|
||||||
|
print("=" * 60)
|
||||||
|
print(" Sync Status: NOT NEEDED")
|
||||||
|
print(" Reason: Local data is up-to-date with trade calendar")
|
||||||
|
print("=" * 60)
|
||||||
|
return {
|
||||||
|
"sync_needed": False,
|
||||||
|
"stock_count": 0,
|
||||||
|
"start_date": None,
|
||||||
|
"end_date": None,
|
||||||
|
"estimated_records": 0,
|
||||||
|
"sample_data": pd.DataFrame(),
|
||||||
|
"mode": "none",
|
||||||
|
}
|
||||||
|
|
||||||
|
# Use dates from check_sync_needed
|
||||||
|
if cal_start and cal_end:
|
||||||
|
sync_start_date = cal_start
|
||||||
|
end_date = cal_end
|
||||||
|
else:
|
||||||
|
sync_start_date = start_date or DEFAULT_START_DATE
|
||||||
|
if end_date is None:
|
||||||
|
end_date = get_today_date()
|
||||||
|
|
||||||
|
# Determine sync mode
|
||||||
|
if force_full:
|
||||||
|
mode = "full"
|
||||||
|
print(f"[DataSync] Mode: FULL SYNC from {sync_start_date} to {end_date}")
|
||||||
|
elif local_last and cal_start and sync_start_date == get_next_date(local_last):
|
||||||
|
mode = "incremental"
|
||||||
|
print(f"[DataSync] Mode: INCREMENTAL SYNC (bandwidth optimized)")
|
||||||
|
print(f"[DataSync] Sync from: {sync_start_date} to {end_date}")
|
||||||
|
else:
|
||||||
|
mode = "partial"
|
||||||
|
print(f"[DataSync] Mode: SYNC from {sync_start_date} to {end_date}")
|
||||||
|
|
||||||
|
# Get all stock codes
|
||||||
|
stock_codes = self.get_all_stock_codes()
|
||||||
|
if not stock_codes:
|
||||||
|
print("[DataSync] No stocks found to sync")
|
||||||
|
return {
|
||||||
|
"sync_needed": False,
|
||||||
|
"stock_count": 0,
|
||||||
|
"start_date": None,
|
||||||
|
"end_date": None,
|
||||||
|
"estimated_records": 0,
|
||||||
|
"sample_data": pd.DataFrame(),
|
||||||
|
"mode": "none",
|
||||||
|
}
|
||||||
|
|
||||||
|
stock_count = len(stock_codes)
|
||||||
|
print(f"[DataSync] Total stocks to sync: {stock_count}")
|
||||||
|
|
||||||
|
# Fetch sample data from first few stocks
|
||||||
|
print(f"[DataSync] Fetching sample data from {sample_size} stocks...")
|
||||||
|
sample_data_list = []
|
||||||
|
sample_codes = stock_codes[:sample_size]
|
||||||
|
|
||||||
|
for ts_code in sample_codes:
|
||||||
|
try:
|
||||||
|
data = self.client.query(
|
||||||
|
"pro_bar",
|
||||||
|
ts_code=ts_code,
|
||||||
|
start_date=sync_start_date,
|
||||||
|
end_date=end_date,
|
||||||
|
factors="tor,vr",
|
||||||
|
)
|
||||||
|
if not data.empty:
|
||||||
|
sample_data_list.append(data)
|
||||||
|
print(f" - {ts_code}: {len(data)} records")
|
||||||
|
except Exception as e:
|
||||||
|
print(f" - {ts_code}: Error fetching - {e}")
|
||||||
|
|
||||||
|
# Combine sample data
|
||||||
|
sample_df = (
|
||||||
|
pd.concat(sample_data_list, ignore_index=True)
|
||||||
|
if sample_data_list
|
||||||
|
else pd.DataFrame()
|
||||||
|
)
|
||||||
|
|
||||||
|
# Estimate total records based on sample
|
||||||
|
if not sample_df.empty:
|
||||||
|
avg_records_per_stock = len(sample_df) / len(sample_data_list)
|
||||||
|
estimated_records = int(avg_records_per_stock * stock_count)
|
||||||
|
else:
|
||||||
|
estimated_records = 0
|
||||||
|
|
||||||
|
# Display preview results
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("[DataSync] Preview Result")
|
||||||
|
print("=" * 60)
|
||||||
|
print(f" Sync Mode: {mode.upper()}")
|
||||||
|
print(f" Date Range: {sync_start_date} to {end_date}")
|
||||||
|
print(f" Stocks to Sync: {stock_count}")
|
||||||
|
print(f" Sample Stocks Checked: {len(sample_data_list)}/{sample_size}")
|
||||||
|
print(f" Estimated Total Records: ~{estimated_records:,}")
|
||||||
|
|
||||||
|
if not sample_df.empty:
|
||||||
|
print(f"\n Sample Data Preview (first {len(sample_df)} rows):")
|
||||||
|
print(" " + "-" * 56)
|
||||||
|
# Display sample data in a compact format
|
||||||
|
preview_cols = [
|
||||||
|
"ts_code",
|
||||||
|
"trade_date",
|
||||||
|
"open",
|
||||||
|
"high",
|
||||||
|
"low",
|
||||||
|
"close",
|
||||||
|
"vol",
|
||||||
|
]
|
||||||
|
available_cols = [c for c in preview_cols if c in sample_df.columns]
|
||||||
|
sample_display = sample_df[available_cols].head(10)
|
||||||
|
for idx, row in sample_display.iterrows():
|
||||||
|
print(f" {row.to_dict()}")
|
||||||
|
print(" " + "-" * 56)
|
||||||
|
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"sync_needed": True,
|
||||||
|
"stock_count": stock_count,
|
||||||
|
"start_date": sync_start_date,
|
||||||
|
"end_date": end_date,
|
||||||
|
"estimated_records": estimated_records,
|
||||||
|
"sample_data": sample_df,
|
||||||
|
"mode": mode,
|
||||||
|
}
|
||||||
|
|
||||||
def sync_single_stock(
|
def sync_single_stock(
|
||||||
self,
|
self,
|
||||||
ts_code: str,
|
ts_code: str,
|
||||||
@@ -320,6 +506,7 @@ class DataSync:
|
|||||||
start_date: Optional[str] = None,
|
start_date: Optional[str] = None,
|
||||||
end_date: Optional[str] = None,
|
end_date: Optional[str] = None,
|
||||||
max_workers: Optional[int] = None,
|
max_workers: Optional[int] = None,
|
||||||
|
dry_run: bool = False,
|
||||||
) -> Dict[str, pd.DataFrame]:
|
) -> Dict[str, pd.DataFrame]:
|
||||||
"""Sync daily data for all stocks in local storage.
|
"""Sync daily data for all stocks in local storage.
|
||||||
|
|
||||||
@@ -337,9 +524,10 @@ class DataSync:
|
|||||||
start_date: Manual start date (overrides auto-detection)
|
start_date: Manual start date (overrides auto-detection)
|
||||||
end_date: Manual end date (defaults to today)
|
end_date: Manual end date (defaults to today)
|
||||||
max_workers: Number of worker threads (default: 10)
|
max_workers: Number of worker threads (default: 10)
|
||||||
|
dry_run: If True, only preview what would be synced without writing data
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dict mapping ts_code to DataFrame (empty if sync skipped)
|
Dict mapping ts_code to DataFrame (empty if sync skipped or dry_run)
|
||||||
"""
|
"""
|
||||||
print("\n" + "=" * 60)
|
print("\n" + "=" * 60)
|
||||||
print("[DataSync] Starting daily data sync...")
|
print("[DataSync] Starting daily data sync...")
|
||||||
@@ -378,11 +566,14 @@ class DataSync:
|
|||||||
|
|
||||||
# Determine sync mode
|
# Determine sync mode
|
||||||
if force_full:
|
if force_full:
|
||||||
|
mode = "full"
|
||||||
print(f"[DataSync] Mode: FULL SYNC from {sync_start_date} to {end_date}")
|
print(f"[DataSync] Mode: FULL SYNC from {sync_start_date} to {end_date}")
|
||||||
elif local_last and cal_start and sync_start_date == get_next_date(local_last):
|
elif local_last and cal_start and sync_start_date == get_next_date(local_last):
|
||||||
|
mode = "incremental"
|
||||||
print(f"[DataSync] Mode: INCREMENTAL SYNC (bandwidth optimized)")
|
print(f"[DataSync] Mode: INCREMENTAL SYNC (bandwidth optimized)")
|
||||||
print(f"[DataSync] Sync from: {sync_start_date} to {end_date}")
|
print(f"[DataSync] Sync from: {sync_start_date} to {end_date}")
|
||||||
else:
|
else:
|
||||||
|
mode = "partial"
|
||||||
print(f"[DataSync] Mode: SYNC from {sync_start_date} to {end_date}")
|
print(f"[DataSync] Mode: SYNC from {sync_start_date} to {end_date}")
|
||||||
|
|
||||||
# Get all stock codes
|
# Get all stock codes
|
||||||
@@ -394,6 +585,17 @@ class DataSync:
|
|||||||
print(f"[DataSync] Total stocks to sync: {len(stock_codes)}")
|
print(f"[DataSync] Total stocks to sync: {len(stock_codes)}")
|
||||||
print(f"[DataSync] Using {max_workers or self.max_workers} worker threads")
|
print(f"[DataSync] Using {max_workers or self.max_workers} worker threads")
|
||||||
|
|
||||||
|
# Handle dry run mode
|
||||||
|
if dry_run:
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("[DataSync] DRY RUN MODE - No data will be written")
|
||||||
|
print("=" * 60)
|
||||||
|
print(f" Would sync {len(stock_codes)} stocks")
|
||||||
|
print(f" Date range: {sync_start_date} to {end_date}")
|
||||||
|
print(f" Mode: {mode}")
|
||||||
|
print("=" * 60)
|
||||||
|
return {}
|
||||||
|
|
||||||
# Reset stop flag for new sync
|
# Reset stop flag for new sync
|
||||||
self._stop_flag.set()
|
self._stop_flag.set()
|
||||||
|
|
||||||
@@ -492,11 +694,62 @@ class DataSync:
|
|||||||
# Convenience functions
|
# Convenience functions
|
||||||
|
|
||||||
|
|
||||||
|
def preview_sync(
|
||||||
|
force_full: bool = False,
|
||||||
|
start_date: Optional[str] = None,
|
||||||
|
end_date: Optional[str] = None,
|
||||||
|
sample_size: int = 3,
|
||||||
|
max_workers: Optional[int] = None,
|
||||||
|
) -> dict:
|
||||||
|
"""Preview sync data volume and samples without actually syncing.
|
||||||
|
|
||||||
|
This is the recommended way to check what would be synced before
|
||||||
|
running the actual synchronization.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
force_full: If True, preview full sync from 20180101
|
||||||
|
start_date: Manual start date (overrides auto-detection)
|
||||||
|
end_date: Manual end date (defaults to today)
|
||||||
|
sample_size: Number of sample stocks to fetch for preview (default: 3)
|
||||||
|
max_workers: Number of worker threads (not used in preview, for API compatibility)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with preview information:
|
||||||
|
{
|
||||||
|
'sync_needed': bool,
|
||||||
|
'stock_count': int,
|
||||||
|
'start_date': str,
|
||||||
|
'end_date': str,
|
||||||
|
'estimated_records': int,
|
||||||
|
'sample_data': pd.DataFrame,
|
||||||
|
'mode': str, # 'full', 'incremental', 'partial', or 'none'
|
||||||
|
}
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> # Preview what would be synced
|
||||||
|
>>> preview = preview_sync()
|
||||||
|
>>>
|
||||||
|
>>> # Preview full sync
|
||||||
|
>>> preview = preview_sync(force_full=True)
|
||||||
|
>>>
|
||||||
|
>>> # Preview with more samples
|
||||||
|
>>> preview = preview_sync(sample_size=5)
|
||||||
|
"""
|
||||||
|
sync_manager = DataSync(max_workers=max_workers)
|
||||||
|
return sync_manager.preview_sync(
|
||||||
|
force_full=force_full,
|
||||||
|
start_date=start_date,
|
||||||
|
end_date=end_date,
|
||||||
|
sample_size=sample_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def sync_all(
|
def sync_all(
|
||||||
force_full: bool = False,
|
force_full: bool = False,
|
||||||
start_date: Optional[str] = None,
|
start_date: Optional[str] = None,
|
||||||
end_date: Optional[str] = None,
|
end_date: Optional[str] = None,
|
||||||
max_workers: Optional[int] = None,
|
max_workers: Optional[int] = None,
|
||||||
|
dry_run: bool = False,
|
||||||
) -> Dict[str, pd.DataFrame]:
|
) -> Dict[str, pd.DataFrame]:
|
||||||
"""Sync daily data for all stocks.
|
"""Sync daily data for all stocks.
|
||||||
|
|
||||||
@@ -507,6 +760,7 @@ def sync_all(
|
|||||||
start_date: Manual start date (YYYYMMDD)
|
start_date: Manual start date (YYYYMMDD)
|
||||||
end_date: Manual end date (defaults to today)
|
end_date: Manual end date (defaults to today)
|
||||||
max_workers: Number of worker threads (default: 10)
|
max_workers: Number of worker threads (default: 10)
|
||||||
|
dry_run: If True, only preview what would be synced without writing data
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dict mapping ts_code to DataFrame
|
Dict mapping ts_code to DataFrame
|
||||||
@@ -526,12 +780,16 @@ def sync_all(
|
|||||||
>>>
|
>>>
|
||||||
>>> # Custom thread count
|
>>> # Custom thread count
|
||||||
>>> result = sync_all(max_workers=20)
|
>>> result = sync_all(max_workers=20)
|
||||||
|
>>>
|
||||||
|
>>> # Dry run (preview only)
|
||||||
|
>>> result = sync_all(dry_run=True)
|
||||||
"""
|
"""
|
||||||
sync_manager = DataSync(max_workers=max_workers)
|
sync_manager = DataSync(max_workers=max_workers)
|
||||||
return sync_manager.sync_all(
|
return sync_manager.sync_all(
|
||||||
force_full=force_full,
|
force_full=force_full,
|
||||||
start_date=start_date,
|
start_date=start_date,
|
||||||
end_date=end_date,
|
end_date=end_date,
|
||||||
|
dry_run=dry_run,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -540,11 +798,32 @@ if __name__ == "__main__":
|
|||||||
print("Data Sync Module")
|
print("Data Sync Module")
|
||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
print("\nUsage:")
|
print("\nUsage:")
|
||||||
print(" from src.data.sync import sync_all")
|
print(" from src.data.sync import sync_all, preview_sync")
|
||||||
|
print("")
|
||||||
|
print(" # Preview before sync (recommended)")
|
||||||
|
print(" preview = preview_sync()")
|
||||||
|
print("")
|
||||||
|
print(" # Dry run (preview only)")
|
||||||
|
print(" result = sync_all(dry_run=True)")
|
||||||
|
print("")
|
||||||
|
print(" # Actual sync")
|
||||||
print(" result = sync_all() # Incremental sync")
|
print(" result = sync_all() # Incremental sync")
|
||||||
print(" result = sync_all(force_full=True) # Full reload")
|
print(" result = sync_all(force_full=True) # Full reload")
|
||||||
print("\n" + "=" * 60)
|
print("\n" + "=" * 60)
|
||||||
|
|
||||||
# Run sync
|
# Run preview first
|
||||||
result = sync_all()
|
print("\n[Main] Running preview first...")
|
||||||
print(f"\nSynced {len(result)} stocks")
|
preview = preview_sync()
|
||||||
|
|
||||||
|
if preview["sync_needed"]:
|
||||||
|
# Ask for confirmation
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
response = input("Proceed with sync? (y/n): ").strip().lower()
|
||||||
|
if response in ("y", "yes"):
|
||||||
|
print("\n[Main] Starting actual sync...")
|
||||||
|
result = sync_all()
|
||||||
|
print(f"\nSynced {len(result)} stocks")
|
||||||
|
else:
|
||||||
|
print("\n[Main] Sync cancelled by user")
|
||||||
|
else:
|
||||||
|
print("\n[Main] No sync needed - data is up to date")
|
||||||
|
|||||||
@@ -5,29 +5,30 @@ Tests the daily interface implementation against api.md requirements:
|
|||||||
- tor 换手率
|
- tor 换手率
|
||||||
- vr 量比
|
- vr 量比
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from src.data.daily import get_daily
|
from src.data.api_wrappers import get_daily
|
||||||
|
|
||||||
|
|
||||||
# Expected output fields according to api.md
|
# Expected output fields according to api.md
|
||||||
EXPECTED_BASE_FIELDS = [
|
EXPECTED_BASE_FIELDS = [
|
||||||
'ts_code', # 股票代码
|
"ts_code", # 股票代码
|
||||||
'trade_date', # 交易日期
|
"trade_date", # 交易日期
|
||||||
'open', # 开盘价
|
"open", # 开盘价
|
||||||
'high', # 最高价
|
"high", # 最高价
|
||||||
'low', # 最低价
|
"low", # 最低价
|
||||||
'close', # 收盘价
|
"close", # 收盘价
|
||||||
'pre_close', # 昨收价
|
"pre_close", # 昨收价
|
||||||
'change', # 涨跌额
|
"change", # 涨跌额
|
||||||
'pct_chg', # 涨跌幅
|
"pct_chg", # 涨跌幅
|
||||||
'vol', # 成交量
|
"vol", # 成交量
|
||||||
'amount', # 成交额
|
"amount", # 成交额
|
||||||
]
|
]
|
||||||
|
|
||||||
EXPECTED_FACTOR_FIELDS = [
|
EXPECTED_FACTOR_FIELDS = [
|
||||||
'turnover_rate', # 换手率 (tor)
|
"turnover_rate", # 换手率 (tor)
|
||||||
'volume_ratio', # 量比 (vr)
|
"volume_ratio", # 量比 (vr)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@@ -36,19 +37,19 @@ class TestGetDaily:
|
|||||||
|
|
||||||
def test_fetch_basic(self):
|
def test_fetch_basic(self):
|
||||||
"""Test basic daily data fetch with real API."""
|
"""Test basic daily data fetch with real API."""
|
||||||
result = get_daily('000001.SZ', start_date='20240101', end_date='20240131')
|
result = get_daily("000001.SZ", start_date="20240101", end_date="20240131")
|
||||||
|
|
||||||
assert isinstance(result, pd.DataFrame)
|
assert isinstance(result, pd.DataFrame)
|
||||||
assert len(result) >= 1
|
assert len(result) >= 1
|
||||||
assert result['ts_code'].iloc[0] == '000001.SZ'
|
assert result["ts_code"].iloc[0] == "000001.SZ"
|
||||||
|
|
||||||
def test_fetch_with_factors(self):
|
def test_fetch_with_factors(self):
|
||||||
"""Test fetch with tor and vr factors."""
|
"""Test fetch with tor and vr factors."""
|
||||||
result = get_daily(
|
result = get_daily(
|
||||||
'000001.SZ',
|
"000001.SZ",
|
||||||
start_date='20240101',
|
start_date="20240101",
|
||||||
end_date='20240131',
|
end_date="20240131",
|
||||||
factors=['tor', 'vr'],
|
factors=["tor", "vr"],
|
||||||
)
|
)
|
||||||
|
|
||||||
assert isinstance(result, pd.DataFrame)
|
assert isinstance(result, pd.DataFrame)
|
||||||
@@ -61,25 +62,26 @@ class TestGetDaily:
|
|||||||
|
|
||||||
def test_output_fields_completeness(self):
|
def test_output_fields_completeness(self):
|
||||||
"""Verify all required output fields are returned."""
|
"""Verify all required output fields are returned."""
|
||||||
result = get_daily('600000.SH')
|
result = get_daily("600000.SH")
|
||||||
|
|
||||||
# Verify all base fields are present
|
# Verify all base fields are present
|
||||||
assert set(EXPECTED_BASE_FIELDS).issubset(result.columns.tolist()), \
|
assert set(EXPECTED_BASE_FIELDS).issubset(result.columns.tolist()), (
|
||||||
f"Missing fields: {set(EXPECTED_BASE_FIELDS) - set(result.columns)}"
|
f"Missing fields: {set(EXPECTED_BASE_FIELDS) - set(result.columns)}"
|
||||||
|
)
|
||||||
|
|
||||||
def test_empty_result(self):
|
def test_empty_result(self):
|
||||||
"""Test handling of empty results."""
|
"""Test handling of empty results."""
|
||||||
# 使用真实 API 测试无效股票代码的空结果
|
# 使用真实 API 测试无效股票代码的空结果
|
||||||
result = get_daily('INVALID.SZ')
|
result = get_daily("INVALID.SZ")
|
||||||
assert isinstance(result, pd.DataFrame)
|
assert isinstance(result, pd.DataFrame)
|
||||||
assert result.empty
|
assert result.empty
|
||||||
|
|
||||||
def test_date_range_query(self):
|
def test_date_range_query(self):
|
||||||
"""Test query with date range."""
|
"""Test query with date range."""
|
||||||
result = get_daily(
|
result = get_daily(
|
||||||
'000001.SZ',
|
"000001.SZ",
|
||||||
start_date='20240101',
|
start_date="20240101",
|
||||||
end_date='20240131',
|
end_date="20240131",
|
||||||
)
|
)
|
||||||
|
|
||||||
assert isinstance(result, pd.DataFrame)
|
assert isinstance(result, pd.DataFrame)
|
||||||
@@ -87,7 +89,7 @@ class TestGetDaily:
|
|||||||
|
|
||||||
def test_with_adj(self):
|
def test_with_adj(self):
|
||||||
"""Test fetch with adjustment type."""
|
"""Test fetch with adjustment type."""
|
||||||
result = get_daily('000001.SZ', adj='qfq')
|
result = get_daily("000001.SZ", adj="qfq")
|
||||||
|
|
||||||
assert isinstance(result, pd.DataFrame)
|
assert isinstance(result, pd.DataFrame)
|
||||||
|
|
||||||
@@ -95,11 +97,14 @@ class TestGetDaily:
|
|||||||
def test_integration():
|
def test_integration():
|
||||||
"""Integration test with real Tushare API (requires valid token)."""
|
"""Integration test with real Tushare API (requires valid token)."""
|
||||||
import os
|
import os
|
||||||
token = os.environ.get('TUSHARE_TOKEN')
|
|
||||||
|
token = os.environ.get("TUSHARE_TOKEN")
|
||||||
if not token:
|
if not token:
|
||||||
pytest.skip("TUSHARE_TOKEN not configured")
|
pytest.skip("TUSHARE_TOKEN not configured")
|
||||||
|
|
||||||
result = get_daily('000001.SZ', start_date='20240101', end_date='20240131', factors=['tor', 'vr'])
|
result = get_daily(
|
||||||
|
"000001.SZ", start_date="20240101", end_date="20240131", factors=["tor", "vr"]
|
||||||
|
)
|
||||||
|
|
||||||
# Verify structure
|
# Verify structure
|
||||||
assert isinstance(result, pd.DataFrame)
|
assert isinstance(result, pd.DataFrame)
|
||||||
@@ -112,6 +117,6 @@ def test_integration():
|
|||||||
assert field in result.columns, f"Missing factor field: {field}"
|
assert field in result.columns, f"Missing factor field: {field}"
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
# 运行 pytest 单元测试(真实API调用)
|
# 运行 pytest 单元测试(真实API调用)
|
||||||
pytest.main([__file__, '-v'])
|
pytest.main([__file__, "-v"])
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ import pytest
|
|||||||
import pandas as pd
|
import pandas as pd
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from src.data.storage import Storage
|
from src.data.storage import Storage
|
||||||
from src.data.stock_basic import _get_csv_path
|
from src.data.api_wrappers.api_stock_basic import _get_csv_path
|
||||||
|
|
||||||
|
|
||||||
class TestDailyStorageValidation:
|
class TestDailyStorageValidation:
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ Tests the sync module's full/incremental sync logic for daily data:
|
|||||||
- Incremental sync when local data exists (from last_date + 1)
|
- Incremental sync when local data exists (from last_date + 1)
|
||||||
- Data integrity validation
|
- Data integrity validation
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from unittest.mock import Mock, patch, MagicMock
|
from unittest.mock import Mock, patch, MagicMock
|
||||||
@@ -17,6 +18,8 @@ from src.data.sync import (
|
|||||||
get_next_date,
|
get_next_date,
|
||||||
DEFAULT_START_DATE,
|
DEFAULT_START_DATE,
|
||||||
)
|
)
|
||||||
|
from src.data.storage import Storage
|
||||||
|
from src.data.client import TushareClient
|
||||||
|
|
||||||
|
|
||||||
class TestDateUtilities:
|
class TestDateUtilities:
|
||||||
@@ -63,30 +66,32 @@ class TestDataSync:
|
|||||||
|
|
||||||
def test_get_all_stock_codes_from_daily(self, mock_storage):
|
def test_get_all_stock_codes_from_daily(self, mock_storage):
|
||||||
"""Test getting stock codes from daily data."""
|
"""Test getting stock codes from daily data."""
|
||||||
with patch('src.data.sync.Storage', return_value=mock_storage):
|
with patch("src.data.sync.Storage", return_value=mock_storage):
|
||||||
sync = DataSync()
|
sync = DataSync()
|
||||||
sync.storage = mock_storage
|
sync.storage = mock_storage
|
||||||
|
|
||||||
mock_storage.load.return_value = pd.DataFrame({
|
mock_storage.load.return_value = pd.DataFrame(
|
||||||
'ts_code': ['000001.SZ', '000001.SZ', '600000.SH'],
|
{
|
||||||
})
|
"ts_code": ["000001.SZ", "000001.SZ", "600000.SH"],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
codes = sync.get_all_stock_codes()
|
codes = sync.get_all_stock_codes()
|
||||||
|
|
||||||
assert len(codes) == 2
|
assert len(codes) == 2
|
||||||
assert '000001.SZ' in codes
|
assert "000001.SZ" in codes
|
||||||
assert '600000.SH' in codes
|
assert "600000.SH" in codes
|
||||||
|
|
||||||
def test_get_all_stock_codes_fallback(self, mock_storage):
|
def test_get_all_stock_codes_fallback(self, mock_storage):
|
||||||
"""Test fallback to stock_basic when daily is empty."""
|
"""Test fallback to stock_basic when daily is empty."""
|
||||||
with patch('src.data.sync.Storage', return_value=mock_storage):
|
with patch("src.data.sync.Storage", return_value=mock_storage):
|
||||||
sync = DataSync()
|
sync = DataSync()
|
||||||
sync.storage = mock_storage
|
sync.storage = mock_storage
|
||||||
|
|
||||||
# First call (daily) returns empty, second call (stock_basic) returns data
|
# First call (daily) returns empty, second call (stock_basic) returns data
|
||||||
mock_storage.load.side_effect = [
|
mock_storage.load.side_effect = [
|
||||||
pd.DataFrame(), # daily empty
|
pd.DataFrame(), # daily empty
|
||||||
pd.DataFrame({'ts_code': ['000001.SZ', '600000.SH']}), # stock_basic
|
pd.DataFrame({"ts_code": ["000001.SZ", "600000.SH"]}), # stock_basic
|
||||||
]
|
]
|
||||||
|
|
||||||
codes = sync.get_all_stock_codes()
|
codes = sync.get_all_stock_codes()
|
||||||
@@ -95,21 +100,23 @@ class TestDataSync:
|
|||||||
|
|
||||||
def test_get_global_last_date(self, mock_storage):
|
def test_get_global_last_date(self, mock_storage):
|
||||||
"""Test getting global last date."""
|
"""Test getting global last date."""
|
||||||
with patch('src.data.sync.Storage', return_value=mock_storage):
|
with patch("src.data.sync.Storage", return_value=mock_storage):
|
||||||
sync = DataSync()
|
sync = DataSync()
|
||||||
sync.storage = mock_storage
|
sync.storage = mock_storage
|
||||||
|
|
||||||
mock_storage.load.return_value = pd.DataFrame({
|
mock_storage.load.return_value = pd.DataFrame(
|
||||||
'ts_code': ['000001.SZ', '600000.SH'],
|
{
|
||||||
'trade_date': ['20240102', '20240103'],
|
"ts_code": ["000001.SZ", "600000.SH"],
|
||||||
})
|
"trade_date": ["20240102", "20240103"],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
last_date = sync.get_global_last_date()
|
last_date = sync.get_global_last_date()
|
||||||
assert last_date == '20240103'
|
assert last_date == "20240103"
|
||||||
|
|
||||||
def test_get_global_last_date_empty(self, mock_storage):
|
def test_get_global_last_date_empty(self, mock_storage):
|
||||||
"""Test getting last date from empty storage."""
|
"""Test getting last date from empty storage."""
|
||||||
with patch('src.data.sync.Storage', return_value=mock_storage):
|
with patch("src.data.sync.Storage", return_value=mock_storage):
|
||||||
sync = DataSync()
|
sync = DataSync()
|
||||||
sync.storage = mock_storage
|
sync.storage = mock_storage
|
||||||
|
|
||||||
@@ -120,18 +127,23 @@ class TestDataSync:
|
|||||||
|
|
||||||
def test_sync_single_stock(self, mock_storage):
|
def test_sync_single_stock(self, mock_storage):
|
||||||
"""Test syncing a single stock."""
|
"""Test syncing a single stock."""
|
||||||
with patch('src.data.sync.Storage', return_value=mock_storage):
|
with patch("src.data.sync.Storage", return_value=mock_storage):
|
||||||
with patch('src.data.sync.get_daily', return_value=pd.DataFrame({
|
with patch(
|
||||||
'ts_code': ['000001.SZ'],
|
"src.data.sync.get_daily",
|
||||||
'trade_date': ['20240102'],
|
return_value=pd.DataFrame(
|
||||||
})):
|
{
|
||||||
|
"ts_code": ["000001.SZ"],
|
||||||
|
"trade_date": ["20240102"],
|
||||||
|
}
|
||||||
|
),
|
||||||
|
):
|
||||||
sync = DataSync()
|
sync = DataSync()
|
||||||
sync.storage = mock_storage
|
sync.storage = mock_storage
|
||||||
|
|
||||||
result = sync.sync_single_stock(
|
result = sync.sync_single_stock(
|
||||||
ts_code='000001.SZ',
|
ts_code="000001.SZ",
|
||||||
start_date='20240101',
|
start_date="20240101",
|
||||||
end_date='20240102',
|
end_date="20240102",
|
||||||
)
|
)
|
||||||
|
|
||||||
assert isinstance(result, pd.DataFrame)
|
assert isinstance(result, pd.DataFrame)
|
||||||
@@ -139,15 +151,15 @@ class TestDataSync:
|
|||||||
|
|
||||||
def test_sync_single_stock_empty(self, mock_storage):
|
def test_sync_single_stock_empty(self, mock_storage):
|
||||||
"""Test syncing a stock with no data."""
|
"""Test syncing a stock with no data."""
|
||||||
with patch('src.data.sync.Storage', return_value=mock_storage):
|
with patch("src.data.sync.Storage", return_value=mock_storage):
|
||||||
with patch('src.data.sync.get_daily', return_value=pd.DataFrame()):
|
with patch("src.data.sync.get_daily", return_value=pd.DataFrame()):
|
||||||
sync = DataSync()
|
sync = DataSync()
|
||||||
sync.storage = mock_storage
|
sync.storage = mock_storage
|
||||||
|
|
||||||
result = sync.sync_single_stock(
|
result = sync.sync_single_stock(
|
||||||
ts_code='INVALID.SZ',
|
ts_code="INVALID.SZ",
|
||||||
start_date='20240101',
|
start_date="20240101",
|
||||||
end_date='20240102',
|
end_date="20240102",
|
||||||
)
|
)
|
||||||
|
|
||||||
assert result.empty
|
assert result.empty
|
||||||
@@ -158,40 +170,46 @@ class TestSyncAll:
|
|||||||
|
|
||||||
def test_full_sync_mode(self, mock_storage):
|
def test_full_sync_mode(self, mock_storage):
|
||||||
"""Test full sync mode when force_full=True."""
|
"""Test full sync mode when force_full=True."""
|
||||||
with patch('src.data.sync.Storage', return_value=mock_storage):
|
with patch("src.data.sync.Storage", return_value=mock_storage):
|
||||||
with patch('src.data.sync.get_daily', return_value=pd.DataFrame()):
|
with patch("src.data.sync.get_daily", return_value=pd.DataFrame()):
|
||||||
sync = DataSync()
|
sync = DataSync()
|
||||||
sync.storage = mock_storage
|
sync.storage = mock_storage
|
||||||
sync.sync_single_stock = Mock(return_value=pd.DataFrame())
|
sync.sync_single_stock = Mock(return_value=pd.DataFrame())
|
||||||
|
|
||||||
mock_storage.load.return_value = pd.DataFrame({
|
mock_storage.load.return_value = pd.DataFrame(
|
||||||
'ts_code': ['000001.SZ'],
|
{
|
||||||
})
|
"ts_code": ["000001.SZ"],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
result = sync.sync_all(force_full=True)
|
result = sync.sync_all(force_full=True)
|
||||||
|
|
||||||
# Verify sync_single_stock was called with default start date
|
# Verify sync_single_stock was called with default start date
|
||||||
sync.sync_single_stock.assert_called_once()
|
sync.sync_single_stock.assert_called_once()
|
||||||
call_args = sync.sync_single_stock.call_args
|
call_args = sync.sync_single_stock.call_args
|
||||||
assert call_args[1]['start_date'] == DEFAULT_START_DATE
|
assert call_args[1]["start_date"] == DEFAULT_START_DATE
|
||||||
|
|
||||||
def test_incremental_sync_mode(self, mock_storage):
|
def test_incremental_sync_mode(self, mock_storage):
|
||||||
"""Test incremental sync mode when data exists."""
|
"""Test incremental sync mode when data exists."""
|
||||||
with patch('src.data.sync.Storage', return_value=mock_storage):
|
with patch("src.data.sync.Storage", return_value=mock_storage):
|
||||||
sync = DataSync()
|
sync = DataSync()
|
||||||
sync.storage = mock_storage
|
sync.storage = mock_storage
|
||||||
sync.sync_single_stock = Mock(return_value=pd.DataFrame())
|
sync.sync_single_stock = Mock(return_value=pd.DataFrame())
|
||||||
|
|
||||||
# Mock existing data with last date
|
# Mock existing data with last date
|
||||||
mock_storage.load.side_effect = [
|
mock_storage.load.side_effect = [
|
||||||
pd.DataFrame({
|
pd.DataFrame(
|
||||||
'ts_code': ['000001.SZ'],
|
{
|
||||||
'trade_date': ['20240102'],
|
"ts_code": ["000001.SZ"],
|
||||||
}), # get_all_stock_codes
|
"trade_date": ["20240102"],
|
||||||
pd.DataFrame({
|
}
|
||||||
'ts_code': ['000001.SZ'],
|
), # get_all_stock_codes
|
||||||
'trade_date': ['20240102'],
|
pd.DataFrame(
|
||||||
}), # get_global_last_date
|
{
|
||||||
|
"ts_code": ["000001.SZ"],
|
||||||
|
"trade_date": ["20240102"],
|
||||||
|
}
|
||||||
|
), # get_global_last_date
|
||||||
]
|
]
|
||||||
|
|
||||||
result = sync.sync_all(force_full=False)
|
result = sync.sync_all(force_full=False)
|
||||||
@@ -199,28 +217,30 @@ class TestSyncAll:
|
|||||||
# Verify sync_single_stock was called with next date
|
# Verify sync_single_stock was called with next date
|
||||||
sync.sync_single_stock.assert_called_once()
|
sync.sync_single_stock.assert_called_once()
|
||||||
call_args = sync.sync_single_stock.call_args
|
call_args = sync.sync_single_stock.call_args
|
||||||
assert call_args[1]['start_date'] == '20240103'
|
assert call_args[1]["start_date"] == "20240103"
|
||||||
|
|
||||||
def test_manual_start_date(self, mock_storage):
|
def test_manual_start_date(self, mock_storage):
|
||||||
"""Test sync with manual start date."""
|
"""Test sync with manual start date."""
|
||||||
with patch('src.data.sync.Storage', return_value=mock_storage):
|
with patch("src.data.sync.Storage", return_value=mock_storage):
|
||||||
sync = DataSync()
|
sync = DataSync()
|
||||||
sync.storage = mock_storage
|
sync.storage = mock_storage
|
||||||
sync.sync_single_stock = Mock(return_value=pd.DataFrame())
|
sync.sync_single_stock = Mock(return_value=pd.DataFrame())
|
||||||
|
|
||||||
mock_storage.load.return_value = pd.DataFrame({
|
mock_storage.load.return_value = pd.DataFrame(
|
||||||
'ts_code': ['000001.SZ'],
|
{
|
||||||
})
|
"ts_code": ["000001.SZ"],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
result = sync.sync_all(force_full=False, start_date='20230601')
|
result = sync.sync_all(force_full=False, start_date="20230601")
|
||||||
|
|
||||||
sync.sync_single_stock.assert_called_once()
|
sync.sync_single_stock.assert_called_once()
|
||||||
call_args = sync.sync_single_stock.call_args
|
call_args = sync.sync_single_stock.call_args
|
||||||
assert call_args[1]['start_date'] == '20230601'
|
assert call_args[1]["start_date"] == "20230601"
|
||||||
|
|
||||||
def test_no_stocks_found(self, mock_storage):
|
def test_no_stocks_found(self, mock_storage):
|
||||||
"""Test sync when no stocks are found."""
|
"""Test sync when no stocks are found."""
|
||||||
with patch('src.data.sync.Storage', return_value=mock_storage):
|
with patch("src.data.sync.Storage", return_value=mock_storage):
|
||||||
sync = DataSync()
|
sync = DataSync()
|
||||||
sync.storage = mock_storage
|
sync.storage = mock_storage
|
||||||
|
|
||||||
@@ -236,7 +256,7 @@ class TestSyncAllConvenienceFunction:
|
|||||||
|
|
||||||
def test_sync_all_function(self):
|
def test_sync_all_function(self):
|
||||||
"""Test sync_all convenience function."""
|
"""Test sync_all convenience function."""
|
||||||
with patch('src.data.sync.DataSync') as MockSync:
|
with patch("src.data.sync.DataSync") as MockSync:
|
||||||
mock_instance = Mock()
|
mock_instance = Mock()
|
||||||
mock_instance.sync_all.return_value = {}
|
mock_instance.sync_all.return_value = {}
|
||||||
MockSync.return_value = mock_instance
|
MockSync.return_value = mock_instance
|
||||||
@@ -251,5 +271,5 @@ class TestSyncAllConvenienceFunction:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
pytest.main([__file__, '-v'])
|
pytest.main([__file__, "-v"])
|
||||||
|
|||||||
256
tests/test_sync_real.py
Normal file
256
tests/test_sync_real.py
Normal file
@@ -0,0 +1,256 @@
|
|||||||
|
"""Tests for data sync with REAL data (read-only).
|
||||||
|
|
||||||
|
Tests verify:
|
||||||
|
1. get_global_last_date() correctly reads local data's max date
|
||||||
|
2. Incremental sync date calculation (local_last_date + 1)
|
||||||
|
3. Full sync date calculation (20180101)
|
||||||
|
4. Multi-stock scenario with real data
|
||||||
|
|
||||||
|
⚠️ IMPORTANT: These tests ONLY read data, no write operations.
|
||||||
|
- NO sync_all() calls (writes daily.h5)
|
||||||
|
- NO check_sync_needed() calls (writes trade_cal.h5)
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import pandas as pd
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from src.data.sync import (
|
||||||
|
DataSync,
|
||||||
|
get_next_date,
|
||||||
|
DEFAULT_START_DATE,
|
||||||
|
)
|
||||||
|
from src.data.storage import Storage
|
||||||
|
|
||||||
|
|
||||||
|
class TestDataSyncReadOnly:
    """Read-only tests for data sync - verify date calculation logic.

    These tests only read the local "daily" dataset and exercise the date
    helpers; none of them calls sync_all() or check_sync_needed().
    """

    @pytest.fixture
    def storage(self):
        """Create storage instance."""
        return Storage()

    @pytest.fixture
    def data_sync(self):
        """Create DataSync instance."""
        return DataSync()

    @pytest.fixture
    def daily_exists(self, storage):
        """Check if daily.h5 exists."""
        return storage.exists("daily")

    def test_daily_h5_exists(self, storage):
        """Verify daily.h5 data file exists before running tests."""
        assert storage.exists("daily"), (
            "daily.h5 not found. Please run full sync first: "
            "uv run python -c 'from src.data.sync import sync_all; sync_all(force_full=True)'"
        )

    def test_get_global_last_date(self, data_sync, daily_exists):
        """Test get_global_last_date returns correct max date from local data."""
        if not daily_exists:
            pytest.skip("daily.h5 not found")

        last_date = data_sync.get_global_last_date()

        # Verify it's a valid YYYYMMDD date string
        assert last_date is not None, "get_global_last_date returned None"
        assert isinstance(last_date, str), f"Expected str, got {type(last_date)}"
        assert len(last_date) == 8, f"Expected 8-digit date, got {last_date}"
        assert last_date.isdigit(), f"Expected numeric date, got {last_date}"

        # Cross-check against the max trade_date read directly from storage
        daily_data = data_sync.storage.load("daily")
        expected_max = str(daily_data["trade_date"].max())

        assert last_date == expected_max, (
            f"get_global_last_date returned {last_date}, "
            f"but actual max date is {expected_max}"
        )

        print(f"[TEST] Local data last date: {last_date}")

    def test_incremental_sync_date_calculation(self, data_sync, daily_exists):
        """Test incremental sync: start_date = local_last_date + 1 calendar day.

        This verifies that when local data exists, incremental sync should
        fetch data from (local_last_date + 1), not from 20180101.
        """
        if not daily_exists:
            pytest.skip("daily.h5 not found")

        from datetime import datetime, timedelta

        # Get local last date
        local_last_date = data_sync.get_global_last_date()
        assert local_last_date is not None, "No local data found"

        actual_start_date = get_next_date(local_last_date)

        # Compute the expected value with real calendar arithmetic. Adding 1
        # to the YYYYMMDD integer is wrong whenever the last local date is a
        # month end or year end (e.g. 20240131 + 1 == 20240132, not 20240201).
        next_day = datetime.strptime(local_last_date, "%Y%m%d") + timedelta(days=1)
        expected_start_date = next_day.strftime("%Y%m%d")

        assert actual_start_date == expected_start_date, (
            f"Incremental start date calculation error: "
            f"expected {expected_start_date}, got {actual_start_date}"
        )

        print(
            f"[TEST] Incremental sync: local_last={local_last_date}, "
            f"start_date should be {actual_start_date}"
        )

        # Verify this is NOT 20180101 (would be full sync)
        assert actual_start_date != DEFAULT_START_DATE, (
            f"Incremental sync should NOT start from {DEFAULT_START_DATE}"
        )

    def test_full_sync_date_calculation(self):
        """Test full sync: start_date = 20180101 when force_full=True.

        This only pins the value of DEFAULT_START_DATE; it does not run a sync.
        """
        # Full sync should always use DEFAULT_START_DATE
        full_sync_start = DEFAULT_START_DATE

        assert full_sync_start == "20180101", (
            f"Full sync should start from 20180101, got {full_sync_start}"
        )

        print(f"[TEST] Full sync start date: {full_sync_start}")

    def test_date_comparison_logic(self, data_sync, daily_exists):
        """Test date comparison: incremental vs full sync selection logic.

        Verify that:
        - If local_last_date < today: incremental sync needed
        - If local_last_date >= today: no sync needed
        """
        if not daily_exists:
            pytest.skip("daily.h5 not found")

        from datetime import datetime

        local_last_date = data_sync.get_global_last_date()
        # Guard before int(): fail with a clear message instead of a
        # TypeError when no last date is available (matches sibling tests).
        assert local_last_date is not None, "No local data found"
        today = datetime.now().strftime("%Y%m%d")

        # YYYYMMDD integers order the same way as the dates they encode,
        # so integer comparison is valid here (unlike integer subtraction).
        local_last_int = int(local_last_date)
        today_int = int(today)

        # Log the comparison
        print(
            f"[TEST] Date comparison: local_last={local_last_date} ({local_last_int}), "
            f"today={today} ({today_int})"
        )

        # This test just verifies the comparison logic works
        if local_last_int < today_int:
            print("[TEST] Local data is older than today - sync needed")
            # Incremental sync should fetch from local_last_date + 1
            sync_start = get_next_date(local_last_date)
            assert int(sync_start) > local_last_int, (
                "Sync start should be after local last"
            )
        else:
            print("[TEST] Local data is up-to-date - no sync needed")

    def test_get_all_stock_codes_real_data(self, data_sync, daily_exists):
        """Test get_all_stock_codes returns multiple real stock codes."""
        if not daily_exists:
            pytest.skip("daily.h5 not found")

        codes = data_sync.get_all_stock_codes()

        # Verify it's a non-empty list
        assert isinstance(codes, list), f"Expected list, got {type(codes)}"
        assert len(codes) > 0, "No stock codes found"

        # Verify multiple stocks
        assert len(codes) >= 10, (
            f"Expected at least 10 stocks for multi-stock test, got {len(codes)}"
        )

        # Verify format (should be like 000001.SZ, 600000.SH)
        sample_codes = codes[:5]
        for code in sample_codes:
            assert "." in code, f"Invalid stock code format: {code}"
            suffix = code.split(".")[-1]
            assert suffix in ["SZ", "SH"], f"Invalid exchange suffix: {suffix}"

        print(f"[TEST] Found {len(codes)} stock codes (sample: {sample_codes})")

    def test_multi_stock_date_range(self, data_sync, daily_exists):
        """Test that local data covers many stocks over a reasonable date span.

        Checks that at least 10 stocks are present and that the global
        trade_date range spans more than 100 calendar days.
        """
        if not daily_exists:
            pytest.skip("daily.h5 not found")

        from datetime import datetime

        daily_data = data_sync.storage.load("daily")

        # Get date range for each stock
        stock_dates = daily_data.groupby("ts_code")["trade_date"].agg(["min", "max"])

        # Get global min and max
        global_min = str(daily_data["trade_date"].min())
        global_max = str(daily_data["trade_date"].max())

        print(f"[TEST] Global date range: {global_min} to {global_max}")
        print(f"[TEST] Total stocks: {len(stock_dates)}")

        # Verify we have data for multiple stocks
        assert len(stock_dates) >= 10, (
            f"Expected at least 10 stocks, got {len(stock_dates)}"
        )

        # Verify the span is more than 100 calendar days. Parse the dates
        # first: subtracting YYYYMMDD integers does not yield days
        # (20240201 - 20240101 == 100, yet that is only 31 days).
        # assumes trade_date values render as YYYYMMDD strings - TODO confirm
        first_day = datetime.strptime(global_min, "%Y%m%d")
        last_day = datetime.strptime(global_max, "%Y%m%d")
        days_span = (last_day - first_day).days

        assert days_span > 100, (
            f"Date range too small: {days_span} days. "
            f"Expected at least 100 days of data."
        )

        print(f"[TEST] Date span: {days_span} days")
|
||||||
|
|
||||||
|
|
||||||
|
class TestDateUtilities:
    """Checks for the date helper functions exported by the sync module."""

    def test_get_next_date(self):
        """get_next_date advances one day across day, month and year edges."""
        expectations = (
            ("20240101", "20240102"),  # ordinary next day
            ("20240131", "20240201"),  # month boundary
            ("20241231", "20250101"),  # year boundary
        )
        for given, wanted in expectations:
            assert get_next_date(given) == wanted

    def test_incremental_vs_full_sync_logic(self):
        """Contrast the start dates used by incremental and full sync.

        Incremental: start_date = local_last_date + 1
        Full: start_date = 20180101
        """
        # Scenario 1: local data exists, so sync resumes the following day
        last_local = "20240115"
        incremental_start = get_next_date(last_local)

        assert incremental_start == "20240116"
        assert incremental_start != DEFAULT_START_DATE

        # Scenario 2: forced full sync always starts at the default date
        full_start = DEFAULT_START_DATE  # "20180101"

        assert full_start == "20180101"
        assert incremental_start != full_start

        print("[TEST] Incremental vs Full sync logic verified")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Run this file under pytest: verbose (-v), with output capturing
    # disabled (-s) so the [TEST] print lines are shown.
    pytest_args = [__file__, "-v", "-s"]
    pytest.main(pytest_args)
|
||||||
Reference in New Issue
Block a user