refactor: restructure the API interface modules into an api_wrappers package

- Move the standalone API modules (daily, stock_basic, trade_cal) into api_wrappers/
- Rewrite sync.py on top of the new wrapper structure with broader sync support
- Update the test files for the new module layout
- Add a pytest.ini configuration file
3  pytest.ini  (new file)
@@ -0,0 +1,3 @@
[pytest]
pythonpath = .
testpaths = tests
@@ -1,847 +0,0 @@
# ProStock Data Interface Wrapper Specification

## 1. Overview

This document defines the standard for adding Tushare API interface wrappers under `src/data/`. All non-special interfaces (factor and basic data) must follow this specification to ensure:
- a uniform code style
- automatic sync support
- consistent incremental-update logic
- reduced storage write pressure

## 2. Interface Classification

### 2.1 Special Interfaces (excluded from unified sync)

The following interfaces have their own sync logic and do not participate in the automatic sync mechanism defined here:

| Interface type | Example | Notes |
|---------|------|------|
| Trading calendar | `trade_cal` | Global data, fetched by date range |
| Stock basic info | `stock_basic` | One-shot full fetch, stored as CSV |
| Auxiliary data | industry/concept classification | Low-frequency updates, managed separately |

### 2.2 Standard Interfaces (must follow this specification)

All factor, quote, and financial data fetched **per stock** or **per date** must follow this specification.
## 3. File Structure

### 3.1 File Naming

```
{data_type}.py
```

Examples:
- `daily.py` - daily quotes
- `moneyflow.py` - money flow
- `limit_list.py` - limit up/down data
- `stk_holdernumber.py` - shareholder counts

### 3.2 File Location

```
src/data/
├── __init__.py          # exports the public interface
├── client.py            # TushareClient (existing)
├── config.py            # configuration management (existing)
├── storage.py           # storage management (existing)
├── rate_limiter.py      # rate limiting (existing)
├── trade_cal.py         # trading calendar (special interface)
├── stock_basic.py       # stock basics (special interface)
├── daily.py             # daily quotes (reference example)
└── {new_data_type}.py   # new interface file
```
## 4. Interface Design Specification

### 4.1 Data-Fetching Functions

#### 4.1.1 Interfaces Fetched per Stock

Applies to: daily quotes, minute bars, money flow, etc.

```python
def get_{data_type}(
    ts_code: str,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
    # other optional parameters...
) -> pd.DataFrame:
    """Fetch {data description}.

    Args:
        ts_code: Stock code (e.g. '000001.SZ')
        start_date: Start date (YYYYMMDD format)
        end_date: End date (YYYYMMDD format)
        # other parameter docs...

    Returns:
        pd.DataFrame with the following columns:
        - ts_code: stock code
        - trade_date: trade date
        # other columns...

    Example:
        >>> data = get_{data_type}('000001.SZ', start_date='20240101', end_date='20240131')
    """
    client = TushareClient()

    params = {"ts_code": ts_code}
    if start_date:
        params["start_date"] = start_date
    if end_date:
        params["end_date"] = end_date
    # other parameters...

    data = client.query("{api_name}", **params)
    return data
```
#### 4.1.2 Interfaces Fetched per Date

Applies to: daily limit lists, daily dragon-tiger lists, daily chip distribution, etc.

```python
def get_{data_type}(
    trade_date: Optional[str] = None,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
    ts_code: Optional[str] = None,
    # other optional parameters...
) -> pd.DataFrame:
    """Fetch {data description}.

    **Prefer fetching by date** (recommended):
    - use trade_date for one day of whole-market data
    - or use start_date + end_date for a date range

    Args:
        trade_date: Trade date (YYYYMMDD) for one day of whole-market data
        start_date: Start date (YYYYMMDD format)
        end_date: End date (YYYYMMDD format)
        ts_code: Stock code (optional, for filtering a specific stock)
        # other parameter docs...

    Returns:
        pd.DataFrame with the following columns:
        - ts_code: stock code
        - trade_date: trade date
        # other columns...

    Example:
        >>> # one day of whole-market data (recommended)
        >>> data = get_{data_type}(trade_date='20240115')
        >>> # a date range
        >>> data = get_{data_type}(start_date='20240101', end_date='20240131')
    """
    client = TushareClient()

    params = {}
    if trade_date:
        params["trade_date"] = trade_date
    if start_date:
        params["start_date"] = start_date
    if end_date:
        params["end_date"] = end_date
    if ts_code:
        params["ts_code"] = ts_code
    # other parameters...

    data = client.query("{api_name}", **params)
    return data
```
### 4.2 Key Design Principles

#### 4.2.1 Prefer Fetching by Date

It is **strongly recommended** to implement date-based fetching first:

1. **Higher efficiency**: one request returns the whole market for a day
2. **Fewer API calls**: N days = N calls, instead of N days × M stocks
3. **Better fit for incremental updates**: check local data day by day and fetch only the missing dates

#### 4.2.2 Consistent Date Field

- Always use `trade_date` as the date column name
- Date format: `YYYYMMDD` string
- If the API returns another column name (e.g. `date`, `end_date`), rename it to `trade_date` before returning

#### 4.2.3 Stock Code Field

- Always use `ts_code` as the stock code column name
- Format: `{code}.{exchange}`, e.g. `000001.SZ`, `600000.SH`
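As a quick sanity check on the code convention above, a validator can be sketched. This is a minimal illustration, not part of the existing codebase; the helper name `is_valid_ts_code` and the inclusion of the `.BJ` Beijing-exchange suffix are assumptions:

```python
import re

# Hypothetical helper: check that a code string follows the
# {6-digit code}.{exchange} convention, e.g. '000001.SZ' or '600000.SH'.
# The .BJ suffix (Beijing exchange) is an assumption, not stated in this spec.
TS_CODE_RE = re.compile(r"^\d{6}\.(SZ|SH|BJ)$")


def is_valid_ts_code(ts_code: str) -> bool:
    """Return True if ts_code matches the 6-digit code + exchange suffix form."""
    return bool(TS_CODE_RE.match(ts_code))
```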
## 5. Sync Integration Specification

### 5.1 Registering a New Data Type in sync.py

Add sync support for the new data type in the `DataSync` class:

```python
class DataSync:
    """Data synchronization manager with full/incremental sync support."""

    DEFAULT_MAX_WORKERS = 10

    # dataset configuration
    DATASET_CONFIG = {
        "daily": {
            "api_name": "pro_bar",
            "fetch_by": "stock",  # fetched per stock
            "date_field": "trade_date",
            "key_fields": ["ts_code", "trade_date"],
        },
        "moneyflow": {
            "api_name": "moneyflow",
            "fetch_by": "stock",  # fetched per stock
            "date_field": "trade_date",
            "key_fields": ["ts_code", "trade_date"],
        },
        "limit_list": {
            "api_name": "limit_list",
            "fetch_by": "date",  # fetched per date (preferred)
            "date_field": "trade_date",
            "key_fields": ["ts_code", "trade_date"],
        },
        # new data types...
        "{new_data_type}": {
            "api_name": "{tushare_api_name}",
            "fetch_by": "date",  # "date" or "stock"
            "date_field": "trade_date",
            "key_fields": ["ts_code", "trade_date"],  # primary key for deduplication
        },
    }
```
### 5.2 Implementing Sync Methods

#### 5.2.1 Date-Based Sync Method (recommended)

```python
def sync_by_date(
    self,
    dataset_name: str,
    start_date: str,
    end_date: str,
) -> pd.DataFrame:
    """Sync data by date (fetch all stocks for each date).

    This is the RECOMMENDED approach for date-based data like:
    - limit_list (limit up/down)
    - top_list (dragon-tiger list)
    - cyq_perf (chip distribution)

    Args:
        dataset_name: Name of the dataset in DATASET_CONFIG
        start_date: Start date (YYYYMMDD)
        end_date: End date (YYYYMMDD)

    Returns:
        Combined DataFrame with all data
    """
    from src.data.trade_cal import get_trading_days

    config = self.DATASET_CONFIG[dataset_name]
    api_name = config["api_name"]
    date_field = config["date_field"]

    # Get trading days in the range
    trading_days = get_trading_days(start_date, end_date)
    if not trading_days:
        print(f"[DataSync] No trading days in range {start_date} to {end_date}")
        return pd.DataFrame()

    print(f"[DataSync] Fetching {dataset_name} for {len(trading_days)} trading days")

    # A set flag means "keep running"; clearing it requests a stop.
    # Without this call the loop below would break on its first iteration.
    self._stop_flag.set()
    all_data = []

    for trade_date in tqdm(trading_days, desc=f"Syncing {dataset_name}"):
        if not self._stop_flag.is_set():
            break

        try:
            data = self.client.query(
                api_name,
                trade_date=trade_date,
            )
            if not data.empty:
                all_data.append(data)
        except Exception as e:
            self._stop_flag.clear()
            print(f"[ERROR] Failed to fetch {dataset_name} for {trade_date}: {e}")
            raise

    if not all_data:
        return pd.DataFrame()

    # Combine all data
    combined = pd.concat(all_data, ignore_index=True)

    # Ensure date field is consistent
    if date_field not in combined.columns and "trade_date" in combined.columns:
        combined = combined.rename(columns={"trade_date": date_field})

    return combined
```
#### 5.2.2 Stock-Based Sync Method

```python
def sync_by_stock(
    self,
    dataset_name: str,
    ts_code: str,
    start_date: str,
    end_date: str,
) -> pd.DataFrame:
    """Sync data by stock (fetch all dates for each stock).

    Use this for stock-based data like:
    - daily (daily quotes)
    - moneyflow (money flow)
    - stk_holdernumber (shareholder counts)

    Args:
        dataset_name: Name of the dataset in DATASET_CONFIG
        ts_code: Stock code
        start_date: Start date (YYYYMMDD)
        end_date: End date (YYYYMMDD)

    Returns:
        DataFrame with data for the stock
    """
    config = self.DATASET_CONFIG[dataset_name]
    api_name = config["api_name"]

    # A cleared flag means a stop was requested (the flag is set by
    # _sync_all_stocks before the workers start).
    if not self._stop_flag.is_set():
        return pd.DataFrame()

    try:
        data = self.client.query(
            api_name,
            ts_code=ts_code,
            start_date=start_date,
            end_date=end_date,
        )
        return data
    except Exception as e:
        self._stop_flag.clear()
        print(f"[ERROR] Exception syncing {dataset_name} for {ts_code}: {e}")
        raise
```
### 5.3 Incremental Update Logic

#### 5.3.1 Generic Incremental-Sync Check

```python
def check_incremental_sync(
    self,
    dataset_name: str,
    force_full: bool = False,
) -> tuple[bool, Optional[str], Optional[str], Optional[str]]:
    """Check if incremental sync is needed for a dataset.

    Args:
        dataset_name: Name of the dataset
        force_full: If True, force full sync

    Returns:
        Tuple of (sync_needed, start_date, end_date, local_last_date)
    """
    config = self.DATASET_CONFIG[dataset_name]
    date_field = config["date_field"]

    # If force_full, always sync from default start
    if force_full:
        print(f"[DataSync] Force full sync for {dataset_name}")
        return (True, DEFAULT_START_DATE, get_today_date(), None)

    # Check local data
    local_data = self.storage.load(dataset_name)
    if local_data.empty or date_field not in local_data.columns:
        print(f"[DataSync] No local {dataset_name} data, full sync needed")
        return (True, DEFAULT_START_DATE, get_today_date(), None)

    # Get local last date
    local_last_date = str(local_data[date_field].max())
    print(f"[DataSync] Local {dataset_name} last date: {local_last_date}")

    # Get calendar last trading day
    today = get_today_date()
    _, cal_last = self.get_trade_calendar_bounds(DEFAULT_START_DATE, today)

    if cal_last is None:
        print("[DataSync] Failed to get trade calendar, proceeding with sync")
        return (True, DEFAULT_START_DATE, today, local_last_date)

    print(f"[DataSync] Calendar last trading day: {cal_last}")

    # Compare dates
    if int(local_last_date) >= int(cal_last):
        print(f"[DataSync] {dataset_name} is up-to-date, skipping sync")
        return (False, None, None, None)

    # Need incremental sync
    sync_start = get_next_date(local_last_date)
    print(f"[DataSync] Incremental sync for {dataset_name} from {sync_start} to {cal_last}")
    return (True, sync_start, cal_last, local_last_date)
```
#### 5.3.2 The Full Sync Entry Point

```python
def sync_dataset(
    self,
    dataset_name: str,
    force_full: bool = False,
    max_workers: Optional[int] = None,
) -> pd.DataFrame:
    """Sync a dataset with automatic incremental update.

    This is the main entry point for syncing any dataset.

    Args:
        dataset_name: Name of the dataset in DATASET_CONFIG
        force_full: If True, force full reload
        max_workers: Number of worker threads (for stock-based sync)

    Returns:
        DataFrame with synced data
    """
    print("\n" + "=" * 60)
    print(f"[DataSync] Starting {dataset_name} sync...")
    print("=" * 60)

    # Ensure trade calendar is up-to-date
    sync_trade_cal_cache()

    # Check if sync is needed
    sync_needed, start_date, end_date, local_last = self.check_incremental_sync(
        dataset_name, force_full
    )

    if not sync_needed:
        print(f"[DataSync] {dataset_name} is up-to-date, skipping")
        return pd.DataFrame()

    config = self.DATASET_CONFIG[dataset_name]
    fetch_by = config["fetch_by"]

    # Fetch data based on strategy
    if fetch_by == "date":
        # Fetch by date (all stocks per day)
        data = self.sync_by_date(dataset_name, start_date, end_date)
    else:
        # Fetch by stock (all dates per stock)
        data = self._sync_all_stocks(dataset_name, start_date, end_date, max_workers)

    if data.empty:
        print(f"[DataSync] No new data for {dataset_name}")
        return pd.DataFrame()

    # Save to storage (single write)
    self.storage.save(dataset_name, data, mode="append")

    print(f"[DataSync] Synced {len(data)} rows for {dataset_name}")
    return data

def _sync_all_stocks(
    self,
    dataset_name: str,
    start_date: str,
    end_date: str,
    max_workers: Optional[int] = None,
) -> pd.DataFrame:
    """Sync data for all stocks (stock-based fetch)."""
    stock_codes = self.get_all_stock_codes()
    if not stock_codes:
        return pd.DataFrame()

    print(f"[DataSync] Syncing {dataset_name} for {len(stock_codes)} stocks")

    # Set the flag so workers keep running; workers clear it on error.
    self._stop_flag.set()
    results = []

    workers = max_workers or self.max_workers
    with ThreadPoolExecutor(max_workers=workers) as executor:
        future_to_code = {
            executor.submit(
                self.sync_by_stock, dataset_name, ts_code, start_date, end_date
            ): ts_code
            for ts_code in stock_codes
        }

        with tqdm(total=len(stock_codes), desc=f"Syncing {dataset_name}") as pbar:
            for future in as_completed(future_to_code):
                try:
                    data = future.result()
                    if not data.empty:
                        results.append(data)
                except Exception:
                    # Cancel pending work and re-raise the worker's error
                    executor.shutdown(wait=False, cancel_futures=True)
                    raise
                pbar.update(1)

    if not results:
        return pd.DataFrame()

    return pd.concat(results, ignore_index=True)
```
## 6. Storage Specification

### 6.1 Using the Storage Class

All data is stored as HDF5 through the `Storage` class:

```python
from src.data.storage import Storage

storage = Storage()

# save data (incremental merge is automatic)
storage.save("dataset_name", data, mode="append")

# load data
all_data = storage.load("dataset_name")
filtered_data = storage.load("dataset_name", start_date="20240101", end_date="20240131")

# get the latest date
last_date = storage.get_last_date("dataset_name")

# check existence
exists = storage.exists("dataset_name")
```

### 6.2 Incremental Write Strategy

**Key principle**: write all data **in a single batch** once the requests finish, never row by row:

```python
# ❌ Wrong: row-by-row writes (poor performance)
for date in dates:
    data = fetch(date)
    storage.save("dataset", data, mode="append")  # many writes

# ✅ Right: batched write (good performance)
all_data = []
for date in dates:
    data = fetch(date)
    all_data.append(data)
combined = pd.concat(all_data, ignore_index=True)
storage.save("dataset", combined, mode="append")  # single write
```

### 6.3 Deduplication Strategy

`Storage.save()` deduplicates automatically, based on the configured `key_fields`:

```python
# implementation in storage.py
combined = pd.concat([existing, data], ignore_index=True)
combined = combined.drop_duplicates(
    subset=["ts_code", "trade_date"],  # the key_fields
    keep="last",  # keep the newest rows
)
```
## 7. Complete Example: Adding a Limit Up/Down Interface

### 7.1 Create limit_list.py

```python
"""Limit up/down list interface.

Fetch stocks that hit limit up or limit down for a specific trade date.
This is a date-based interface (recommended approach).
"""
import pandas as pd
from typing import Optional
from src.data.client import TushareClient


def get_limit_list(
    trade_date: Optional[str] = None,
    ts_code: Optional[str] = None,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
) -> pd.DataFrame:
    """Fetch limit up/down data.

    **Prefer fetching by date** (recommended):
    - use trade_date for one day of whole-market limit data
    - or use start_date + end_date for a date range

    Args:
        trade_date: Trade date (YYYYMMDD) for one day of whole-market data
        ts_code: Stock code (optional filter)
        start_date: Start date (YYYYMMDD format)
        end_date: End date (YYYYMMDD format)

    Returns:
        pd.DataFrame with the following columns:
        - ts_code: stock code
        - trade_date: trade date
        - name: stock name
        - close: closing price
        - pct_chg: percent change
        - amp: amplitude
        - fc_ratio: sealed-order amount / daily turnover
        - fl_ratio: sealed-order lots / float shares
        - fd_amount: sealed-order amount
        - first_time: first time the limit was hit
        - last_time: last time the board was sealed
        - open_times: number of times the board reopened
        - strth: limit-up strength
        - limit: limit type (U = limit up, D = limit down)

    Example:
        >>> # one day of whole-market limit data (recommended)
        >>> data = get_limit_list(trade_date='20240115')
        >>> # a date range
        >>> data = get_limit_list(start_date='20240101', end_date='20240131')
    """
    client = TushareClient()

    params = {}
    if trade_date:
        params["trade_date"] = trade_date
    if ts_code:
        params["ts_code"] = ts_code
    if start_date:
        params["start_date"] = start_date
    if end_date:
        params["end_date"] = end_date

    data = client.query("limit_list", **params)
    return data
```

### 7.2 Register It in sync.py

```python
class DataSync:
    """Data synchronization manager with full/incremental sync support."""

    DATASET_CONFIG = {
        # ... other entries ...
        "limit_list": {
            "api_name": "limit_list",
            "fetch_by": "date",  # fetched per date
            "date_field": "trade_date",
            "key_fields": ["ts_code", "trade_date"],
        },
    }

    # ... other methods ...

    def sync_limit_list(
        self,
        force_full: bool = False,
    ) -> pd.DataFrame:
        """Sync limit list data."""
        return self.sync_dataset("limit_list", force_full)


# convenience function
def sync_limit_list(force_full: bool = False) -> pd.DataFrame:
    """Sync limit up/down data."""
    sync_manager = DataSync()
    return sync_manager.sync_limit_list(force_full)
```

### 7.3 Update __init__.py

```python
from src.data.limit_list import get_limit_list

__all__ = [
    # ... other exports ...
    "get_limit_list",
]
```
## 8. Testing Specification

### 8.1 Test File Layout

```
tests/
├── test_sync.py           # sync module tests
├── test_daily.py          # daily module tests
└── test_{new_module}.py   # tests for the new module
```

### 8.2 Test Template

```python
"""Tests for {module_name} module."""
import pytest
from unittest.mock import patch, MagicMock
import pandas as pd
from src.data.{module_name} import get_{data_type}


class Test{DataType}:
    """Test cases for {data_type} data fetching."""

    @patch("src.data.{module_name}.TushareClient")
    def test_get_{data_type}_by_date(self, mock_client_class):
        """Test fetching data by date."""
        # Setup mock
        mock_client = MagicMock()
        mock_client_class.return_value = mock_client
        mock_client.query.return_value = pd.DataFrame({
            "ts_code": ["000001.SZ"],
            "trade_date": ["20240115"],
            # ... other columns ...
        })

        # Call function
        result = get_{data_type}(trade_date="20240115")

        # Verify
        assert not result.empty
        mock_client.query.assert_called_once_with(
            "{api_name}",
            trade_date="20240115",
        )

    @patch("src.data.{module_name}.TushareClient")
    def test_get_{data_type}_by_stock(self, mock_client_class):
        """Test fetching data by stock code."""
        # Setup mock
        mock_client = MagicMock()
        mock_client_class.return_value = mock_client
        mock_client.query.return_value = pd.DataFrame({
            "ts_code": ["000001.SZ"],
            "trade_date": ["20240115"],
            # ... other columns ...
        })

        # Call function
        result = get_{data_type}(
            ts_code="000001.SZ",
            start_date="20240101",
            end_date="20240131",
        )

        # Verify
        assert not result.empty
        mock_client.query.assert_called_once()
```
## 9. Checklist

Before submitting a new interface, confirm the following:

### 9.1 File Structure
- [ ] The file lives at `src/data/{data_type}.py`
- [ ] `src/data/__init__.py` exports the public interface
- [ ] A `tests/test_{data_type}.py` test file exists

### 9.2 Interface Implementation
- [ ] The fetch function uses `TushareClient`
- [ ] The function has a complete Google-style docstring
- [ ] Date parameters use the `YYYYMMDD` format
- [ ] The returned DataFrame contains `ts_code` and `trade_date` columns
- [ ] A date-based interface is implemented first (if the API supports it)

### 9.3 Sync Integration
- [ ] Registered in `DataSync.DATASET_CONFIG`
- [ ] `fetch_by` is set correctly ("date" or "stock")
- [ ] `date_field` and `key_fields` are set correctly
- [ ] A dedicated sync method exists, or a generic one is reused
- [ ] The incremental-update logic is correct (checks the local latest date)

### 9.4 Storage Optimization
- [ ] All data is written in a single batch (not row by row)
- [ ] Incremental saves use `storage.save(mode="append")`
- [ ] The deduplication key fields are configured correctly

### 9.5 Testing
- [ ] Unit tests are written
- [ ] TushareClient is mocked
- [ ] Tests cover both normal and error cases
## 10. FAQ

### Q1: What if the API's date column is not named trade_date?

Rename it before returning:

```python
data = client.query("api_name", **params)
if "end_date" in data.columns:
    data = data.rename(columns={"end_date": "trade_date"})
return data
```

### Q2: How do I handle pagination (limit/offset)?

The Tushare Pro API usually needs no manual pagination, but if it does:

```python
all_data = []
offset = 0
limit = 5000

while True:
    data = client.query(
        "api_name",
        trade_date=trade_date,
        limit=limit,
        offset=offset,
    )
    if data.empty or len(data) < limit:
        all_data.append(data)
        break
    all_data.append(data)
    offset += limit

return pd.concat(all_data, ignore_index=True)
```

### Q3: How do I handle interfaces that need extra parameters?

Add the parameters to the function signature and pass them through to client.query:

```python
def get_data(
    ts_code: str,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
    fields: Optional[list] = None,  # extra parameter
) -> pd.DataFrame:
    params = {"ts_code": ts_code}
    if start_date:
        params["start_date"] = start_date
    if end_date:
        params["end_date"] = end_date
    if fields:
        params["fields"] = ",".join(fields)

    return client.query("api_name", **params)
```

### Q4: How do I handle data without a trade_date column?

If the data genuinely has no date column (e.g. static data), either:
1. classify it as a "special interface" and manage it separately, or
2. add a `sync_date` column recording when it was synced
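For option 2, stamping the frame can be sketched as follows. This is a minimal illustration under the spec's `YYYYMMDD` convention; the helper name `stamp_sync_date` is hypothetical, not an existing function in the codebase:

```python
import pandas as pd
from datetime import date


def stamp_sync_date(data: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of `data` with a sync_date column (YYYYMMDD string)
    recording when the rows were pulled, for static data with no trade_date."""
    data = data.copy()  # leave the caller's frame untouched
    data["sync_date"] = date.today().strftime("%Y%m%d")
    return data
```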
### Q5: What if the API does not support fetching by date?

If the API only supports per-stock fetching:
1. Set `fetch_by: "stock"` in `DATASET_CONFIG`
2. Sync via the `_sync_all_stocks` method
3. Document that the interface is fetched per stock

---

**Last updated**: 2026-02-01
@@ -6,7 +6,7 @@ Provides simplified interfaces for fetching and storing Tushare data.
 from src.data.config import Config, get_config
 from src.data.client import TushareClient
 from src.data.storage import Storage
-from src.data.stock_basic import get_stock_basic, sync_all_stocks
+from src.data.api_wrappers import get_stock_basic, sync_all_stocks

 __all__ = [
     "Config",
244  src/data/api_wrappers/API_INTERFACE_SPEC.md  (new file)
@@ -0,0 +1,244 @@
# ProStock Data Interface Wrapper Specification

## 1. Overview

This document defines the standard for adding Tushare API interface wrappers. All non-special interfaces must follow this specification to ensure:
- a uniform code style
- automatic sync support
- consistent incremental-update logic
- reduced storage write pressure

## 2. Interface Classification
### 2.1 Special Interfaces (excluded from unified sync)

The following interfaces have their own sync logic and do not participate in the automatic sync mechanism:

| Interface type | Example | Notes |
|---------|------|------|
| Trading calendar | `trade_cal` | Global data, fetched by date range |
| Stock basic info | `stock_basic` | One-shot full fetch, stored as CSV |
| Auxiliary data | industry/concept classification | Low-frequency updates, managed separately |

### 2.2 Standard Interfaces (must follow this specification)

All factor, quote, and financial data fetched per stock or per date must follow this specification.
## 3. File Structure Requirements

### 3.1 File Naming

```
{data_type}.py
```

Examples: `daily.py`, `moneyflow.py`, `limit_list.py`

### 3.2 File Location

All interface files must live under the `src/data/` directory.

### 3.3 Export Requirements

New interfaces must be exported from `src/data/__init__.py`:

```python
from src.data.{module_name} import get_{data_type}

__all__ = [
    # ... other exports ...
    "get_{data_type}",
]
```
## 4. Interface Design Specification

### 4.1 Fetch-Function Signature Requirements

The function must return a `pd.DataFrame`, and its parameters must follow one of the patterns below:

#### 4.1.1 Date-Based Interfaces (preferred)

Applies to: limit up/down, dragon-tiger lists, chip distribution, etc.

**Required signature**:

```python
def get_{data_type}(
    trade_date: Optional[str] = None,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
    ts_code: Optional[str] = None,
    # other optional parameters...
) -> pd.DataFrame:
```

**Requirements**:
- Prefer `trade_date` to fetch one day of whole-market data
- Support `start_date + end_date` for a date range
- Accept `ts_code` as an optional filter

#### 4.1.2 Stock-Based Interfaces

Applies to: daily quotes, money flow, etc.

**Required signature**:

```python
def get_{data_type}(
    ts_code: str,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
    # other optional parameters...
) -> pd.DataFrame:
```
### 4.2 Docstring Requirements

Every function must carry a complete Google-style docstring with:
- a description of what the function does
- an `Args` section documenting every parameter
- a `Returns` section documenting the columns of the returned DataFrame
- an `Example` section with a usage example
### 4.3 Date Format Requirements

- All date parameters and return values use the `YYYYMMDD` string format
- Always use `trade_date` as the date column name
- If the API returns another date column name (e.g. `date`, `end_date`), it must be renamed to `trade_date` before returning

### 4.4 Stock Code Requirements

- Always use `ts_code` as the stock code column name
- Format: `{code}.{exchange}`, e.g. `000001.SZ`, `600000.SH`
### 4.5 Token-Bucket Rate-Limiting Requirement

All API calls must go through `TushareClient`, which satisfies the token-bucket rate limit automatically.
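The limiter itself lives in `rate_limiter.py`, whose internals are not shown in this diff. As a rough illustration of the token-bucket idea the requirement refers to, a minimal thread-safe sketch might look like this (hypothetical, not the project's implementation):

```python
import threading
import time


class TokenBucket:
    """Minimal token bucket: allow at most `rate` calls per second on average,
    with bursts of up to `capacity` calls. Illustrative sketch only."""

    def __init__(self, rate: float, capacity: int):
        self.rate = rate            # tokens refilled per second
        self.capacity = capacity    # maximum burst size
        self.tokens = float(capacity)
        self.last = time.monotonic()
        self.lock = threading.Lock()

    def acquire(self) -> None:
        """Block until one token is available, then consume it."""
        while True:
            with self.lock:
                now = time.monotonic()
                # Refill tokens proportionally to elapsed time, capped at capacity
                self.tokens = min(self.capacity,
                                  self.tokens + (now - self.last) * self.rate)
                self.last = now
                if self.tokens >= 1.0:
                    self.tokens -= 1.0
                    return
            # Not enough tokens yet; wait roughly one refill interval
            time.sleep(1.0 / self.rate)
```

A client would call `bucket.acquire()` before every HTTP request, so callers never have to think about pacing themselves.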
## 5. Sync Integration Specification

### 5.1 DATASET_CONFIG Registration Requirements

New interfaces must be registered in `DataSync.DATASET_CONFIG` with the following entries:

```python
"{new_data_type}": {
    "api_name": "{tushare_api_name}",  # Tushare API name
    "fetch_by": "date",                # "date" or "stock"
    "date_field": "trade_date",
    "key_fields": ["ts_code", "trade_date"],  # primary key for deduplication
}
```
### 5.2 fetch_by Value Rules

- **Prefer `"date"`**: whenever the API can return whole-market data for a date
- Use `"stock"` only when the API cannot fetch by date
### 5.3 Sync Method Requirements

A dedicated sync method must be implemented, or a generic one reused:

```python
def sync_{data_type}(self, force_full: bool = False) -> pd.DataFrame:
    """Sync {data description}."""
    return self.sync_dataset("{data_type}", force_full)
```

A convenience function must also be provided:

```python
def sync_{data_type}(force_full: bool = False) -> pd.DataFrame:
    """Sync {data description}."""
    sync_manager = DataSync()
    return sync_manager.sync_{data_type}(force_full)
```
### 5.4 Incremental Update Requirements

- Incremental-update logic is mandatory (automatically check the local latest date)
- Support forced full sync via the `force_full` parameter
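The incremental check has to compute the first date after the local latest date. A minimal sketch of such a helper follows; the name `get_next_date` mirrors the helper referenced by `check_incremental_sync` in `sync.py`, but its real implementation is not shown in this diff, so treat this as an assumption:

```python
from datetime import datetime, timedelta


def get_next_date(yyyymmdd: str) -> str:
    """Return the calendar day after the given YYYYMMDD date, as YYYYMMDD.

    Non-trading days are fine here: the sync loop only iterates over
    actual trading days obtained from the trade calendar.
    """
    d = datetime.strptime(yyyymmdd, "%Y%m%d") + timedelta(days=1)
    return d.strftime("%Y%m%d")
```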
## 6. Storage Specification

### 6.1 Storage Backend

All data is stored as HDF5 through the `Storage` class.

### 6.2 Write Strategy

**Requirement**: write all data **in a single batch** once the requests have finished, never row by row.

### 6.3 Deduplication Requirements

Deduplicate on the columns configured in `key_fields`, defaulting to `["ts_code", "trade_date"]`.
## 7. Testing Specification

### 7.1 Test File Requirements

A matching test file must exist: `tests/test_{data_type}.py`

### 7.2 Coverage Requirements

- test date-based fetching
- test stock-based fetching (if supported)
- `TushareClient` must be mocked
- cover both normal and error cases
## 8. End-to-End Flow for a New Interface

### 8.1 Create the Interface File

1. Create `{data_type}.py` under `src/data/`
2. Implement the fetch function per section 4

### 8.2 Register Sync Support

1. Register it in `DataSync.DATASET_CONFIG` in `sync.py`
2. Implement the matching sync method
3. Provide the convenience function

### 8.3 Update the Exports

Export the interface function from `src/data/__init__.py`.

### 8.4 Create the Tests

Create `tests/test_{data_type}.py` covering the key scenarios.
## 9. Checklist

### 9.1 File Structure
- [ ] The file lives at `src/data/{data_type}.py`
- [ ] `src/data/__init__.py` exports the public interface
- [ ] A `tests/test_{data_type}.py` test file exists

### 9.2 Interface Implementation
- [ ] The fetch function uses `TushareClient`
- [ ] The function has a complete Google-style docstring
- [ ] Date parameters use the `YYYYMMDD` format
- [ ] The returned DataFrame contains `ts_code` and `trade_date` columns
- [ ] A date-based interface is implemented first (if the API supports it)

### 9.3 Sync Integration
- [ ] Registered in `DataSync.DATASET_CONFIG`
- [ ] `fetch_by` is set correctly ("date" or "stock")
- [ ] `date_field` and `key_fields` are set correctly
- [ ] A dedicated sync method exists, or a generic one is reused
- [ ] The incremental-update logic is correct (checks the local latest date)

### 9.4 Storage Optimization
- [ ] All data is written in a single batch (not row by row)
- [ ] Incremental saves use `storage.save(mode="append")`
- [ ] The deduplication key fields are configured correctly

### 9.5 Testing
- [ ] Unit tests are written
- [ ] TushareClient is mocked
- [ ] Tests cover both normal and error cases

---

**Last updated**: 2026-02-01
40  src/data/api_wrappers/__init__.py  (new file)
@@ -0,0 +1,40 @@
"""Tushare API wrapper modules.

This package contains simplified interfaces for fetching data from Tushare API.
All wrapper files follow the naming convention: api_{data_type}.py

Available APIs:
- api_daily: Daily market data
- api_stock_basic: Stock basic information
- api_trade_cal: Trading calendar

Example:
    >>> from src.data.api_wrappers import get_daily, get_stock_basic, get_trade_cal
    >>> data = get_daily('000001.SZ', start_date='20240101', end_date='20240131')
    >>> stocks = get_stock_basic()
    >>> calendar = get_trade_cal('20240101', '20240131')
"""

from src.data.api_wrappers.api_daily import get_daily
from src.data.api_wrappers.api_stock_basic import get_stock_basic, sync_all_stocks
from src.data.api_wrappers.api_trade_cal import (
    get_trade_cal,
    get_trading_days,
    get_first_trading_day,
    get_last_trading_day,
    sync_trade_cal_cache,
)

__all__ = [
    # Daily market data
    "get_daily",
    # Stock basic information
    "get_stock_basic",
    "sync_all_stocks",
    # Trade calendar
    "get_trade_cal",
    "get_trading_days",
    "get_first_trading_day",
    "get_last_trading_day",
    "sync_trade_cal_cache",
]
@@ -179,4 +179,75 @@ df = pro.query('trade_cal', start_date='20180101', end_date='20181231')
17   SSE  20180118  1
18   SSE  20180119  1
19   SSE  20180120  0
20   SSE  20180121  0

Daily indicators
Interface: daily_basic; the data can be inspected and debugged with the data tool.
Update time: between 15:00 and 17:00 on each trading day
Description: fetches the key daily fundamental indicators for all stocks, usable for stock screening, report display, etc. A single request returns at most 6000 rows; the full history can be extracted by looping over trading days.
Points: at least 2000 points are required to call this interface; 5000 points removes the volume cap. See the points-earning guide for details.

Input parameters

Name        Type  Required  Description
ts_code     str   Y         stock code (one of the two)
trade_date  str   N         trade date (one of the two)
start_date  str   N         start date (YYYYMMDD)
end_date    str   N         end date (YYYYMMDD)
Note: all dates use the YYYYMMDD format, e.g. 20181010
Output parameters

Name             Type   Description
ts_code          str    TS stock code
trade_date       str    trade date
close            float  closing price of the day
turnover_rate    float  turnover rate (%)
turnover_rate_f  float  turnover rate (free float)
volume_ratio     float  volume ratio
pe               float  P/E (total market cap / net profit; empty when loss-making)
pe_ttm           float  P/E (TTM; empty when loss-making)
pb               float  P/B (total market cap / net assets)
ps               float  P/S
ps_ttm           float  P/S (TTM)
dv_ratio         float  dividend yield (%)
dv_ttm           float  dividend yield (TTM) (%)
total_share      float  total shares (10k shares)
float_share      float  float shares (10k shares)
free_share       float  free-float shares (10k)
total_mv         float  total market cap (10k CNY)
circ_mv          float  float market cap (10k CNY)

Usage

pro = ts.pro_api()

df = pro.daily_basic(ts_code='', trade_date='20180726', fields='ts_code,trade_date,turnover_rate,volume_ratio,pe,pb')

or

df = pro.query('daily_basic', ts_code='', trade_date='20180726', fields='ts_code,trade_date,turnover_rate,volume_ratio,pe,pb')

Sample data

     ts_code    trade_date  turnover_rate  volume_ratio  pe        pb
0    600230.SH  20180726    2.4584         0.72          8.6928    3.7203
1    600237.SH  20180726    1.4737         0.88          166.4001  1.8868
2    002465.SZ  20180726    0.7489         0.72          71.8943   2.6391
3    300732.SZ  20180726    6.7083         0.77          21.8101   3.2513
4    600007.SH  20180726    0.0381         0.61          23.7696   2.3774
5    300068.SZ  20180726    1.4583         0.52          27.8166   1.7549
6    300552.SZ  20180726    2.0728         0.95          56.8004   2.9279
7    601369.SH  20180726    0.2088         0.95          44.1163   1.8001
8    002518.SZ  20180726    0.5814         0.76          15.1004   2.5626
9    002913.SZ  20180726    12.1096        1.03          33.1279   2.9217
10   601818.SH  20180726    0.1893         0.86          6.3064    0.7209
11   600926.SH  20180726    0.6065         0.46          9.1772    0.9808
12   002166.SZ  20180726    0.7582         0.82          16.9868   3.3452
13   600841.SH  20180726    0.3754         1.02          66.2647   2.2302
14   300634.SZ  20180726    23.1127        1.26          120.3053  14.3168
15   300126.SZ  20180726    1.2304         1.11          348.4306  1.5171
16   300718.SZ  20180726    17.6612        0.92          32.0239   3.8661
17   000708.SZ  20180726    0.5575         0.70          10.3674   1.0276
18   002626.SZ  20180726    0.6187         0.83          22.7580   4.2446
19   600816.SH  20180726    0.6745         0.65          11.0778   3.2214
|
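Working with the daily_basic output is plain DataFrame filtering. The sketch below is not part of the Tushare documentation: it rebuilds a few rows from the 数据样例 table above as a local DataFrame (an offline illustration, so no token is needed) and runs a simple low-PE screen on them.

```python
import pandas as pd

# Three rows copied from the sample table above (offline stand-in for
# the DataFrame that pro.daily_basic() would return).
sample = pd.DataFrame(
    {
        "ts_code": ["600230.SH", "600237.SH", "601818.SH"],
        "trade_date": ["20180726"] * 3,
        "turnover_rate": [2.4584, 1.4737, 0.1893],
        "volume_ratio": [0.72, 0.88, 0.86],
        "pe": [8.6928, 166.4001, 6.3064],
        "pb": [3.7203, 1.8868, 0.7209],
    }
)

# Example screen: PE below 10, cheapest first.  Loss-making stocks have
# NaN PE in the real data and drop out of the comparison automatically.
low_pe = sample[sample["pe"] < 10].sort_values("pe")
print(low_pe[["ts_code", "pe", "pb"]].to_string(index=False))
```

The same pattern applies unchanged to a real `pro.daily_basic(trade_date=...)` result, since the column names match the 输出参数 table.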
@@ -3,6 +3,7 @@
A single function to fetch A股日线行情 data from Tushare.
Supports all output fields including tor (换手率) and vr (量比).
"""

import pandas as pd
from typing import Optional, List, Literal
from src.data.client import TushareClient
@@ -33,7 +34,7 @@ def get_daily(

Returns:
pd.DataFrame with daily market data containing:
- Base fields: ts_code, trade_date, open, high, low, close, pre_close,
change, pct_chg, vol, amount
- Factor fields (if requested): tor, vr
- Adjustment factor (if adjfactor=True): adjfactor
@@ -3,6 +3,7 @@
Fetch basic stock information including code, name, listing date, etc.
This is a special interface - call once to get all stocks (listed and delisted).
"""

import os
import pandas as pd
from pathlib import Path
@@ -1,4 +1,5 @@
"""Simplified HDF5 storage for data persistence."""

import os
import pandas as pd
from pathlib import Path
@@ -47,7 +48,9 @@ class Storage:
# Merge with existing data
existing = store[name]
combined = pd.concat([existing, data], ignore_index=True)
combined = combined.drop_duplicates(subset=["ts_code", "trade_date"], keep="last")
combined = combined.drop_duplicates(
    subset=["ts_code", "trade_date"], keep="last"
)
store.put(name, combined, format="table")

print(f"[Storage] Saved {len(data)} rows to {file_path}")
@@ -57,10 +60,13 @@ class Storage:
print(f"[Storage] Error saving {name}: {e}")
return {"status": "error", "error": str(e)}

def load(self, name: str,
start_date: Optional[str] = None,
end_date: Optional[str] = None,
ts_code: Optional[str] = None) -> pd.DataFrame:
def load(
    self,
    name: str,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
    ts_code: Optional[str] = None,
) -> pd.DataFrame:
"""Load data from HDF5 file.

Args:
@@ -80,14 +86,25 @@ class Storage:

try:
with pd.HDFStore(file_path, mode="r") as store:
if name not in store.keys():
keys = store.keys()
# Handle both '/daily' and 'daily' keys
actual_key = None
if name in keys:
    actual_key = name
elif f"/{name}" in keys:
    actual_key = f"/{name}"

if actual_key is None:
return pd.DataFrame()

data = store[name]
data = store[actual_key]

# Apply filters
if start_date and end_date and "trade_date" in data.columns:
data = data[(data["trade_date"] >= start_date) & (data["trade_date"] <= end_date)]
data = data[
    (data["trade_date"] >= start_date)
    & (data["trade_date"] <= end_date)
]

if ts_code and "ts_code" in data.columns:
data = data[data["ts_code"] == ts_code]
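The '/daily' vs 'daily' key handling that this hunk adds to `Storage.load()` can be isolated into a small pure helper, which makes the rule easy to test on its own. This is a hypothetical refactoring sketch, not code from the commit:

```python
from typing import List, Optional


def resolve_store_key(name: str, keys: List[str]) -> Optional[str]:
    """Return the HDFStore key matching `name`, or None if absent.

    pd.HDFStore.keys() reports keys with a leading slash (e.g. '/daily'),
    while callers usually pass the bare dataset name ('daily'); accept
    either spelling, mirroring the branch added to Storage.load().
    """
    if name in keys:
        return name
    if f"/{name}" in keys:
        return f"/{name}"
    return None


# None signals "dataset not found", the case where load() returns an
# empty DataFrame.
key = resolve_store_key("daily", ["/daily", "/stock_basic"])
```

`Storage.load()` could then call `resolve_store_key(name, store.keys())` instead of inlining the two `if` branches.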
295 src/data/sync.py
@@ -5,11 +5,15 @@ This module provides data fetching functions with intelligent sync logic:
- If local file exists: incremental update (fetch from latest date + 1 day)
- Multi-threaded concurrent fetching for improved performance
- Stop immediately on any exception
- Preview mode: check data volume and samples before actual sync

Currently supported data types:
- daily: Daily market data (with turnover rate and volume ratio)

Usage:
# Preview sync (check data volume and samples without writing)
preview_sync()

# Sync all stocks (full load)
sync_all()

@@ -18,6 +22,9 @@ Usage:

# Force full reload
sync_all(force_full=True)

# Dry run (preview only, no write)
sync_all(dry_run=True)
"""

import pandas as pd
@@ -30,8 +37,8 @@ import sys

from src.data.client import TushareClient
from src.data.storage import Storage
from src.data.daily import get_daily
from src.data.trade_cal import (
from src.data.api_wrappers import get_daily
from src.data.api_wrappers import (
get_first_trading_day,
get_last_trading_day,
sync_trade_cal_cache,
@@ -114,7 +121,8 @@ class DataSync:
List of stock codes
"""
# Import sync_all_stocks here to avoid circular imports
from src.data.stock_basic import sync_all_stocks, _get_csv_path
from src.data.api_wrappers import sync_all_stocks
from src.data.api_wrappers.api_stock_basic import _get_csv_path

# First, ensure stock_basic.csv is up-to-date with all stocks
print("[DataSync] Ensuring stock_basic.csv is up-to-date...")
@@ -278,6 +286,184 @@ class DataSync:
print(f"[DataSync] Incremental sync needed from {sync_start} to {cal_last}")
return (True, sync_start, cal_last, local_last_date)

def preview_sync(
    self,
    force_full: bool = False,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
    sample_size: int = 3,
) -> dict:
"""Preview sync data volume and samples without actually syncing.

This method provides a preview of what would be synced, including:
- Number of stocks to be synced
- Date range for sync
- Estimated total records
- Sample data from first few stocks

Args:
force_full: If True, preview full sync from 20180101
start_date: Manual start date (overrides auto-detection)
end_date: Manual end date (defaults to today)
sample_size: Number of sample stocks to fetch for preview (default: 3)

Returns:
Dictionary with preview information:
{
'sync_needed': bool,
'stock_count': int,
'start_date': str,
'end_date': str,
'estimated_records': int,
'sample_data': pd.DataFrame,
'mode': str,  # 'full' or 'incremental'
}
"""
print("\n" + "=" * 60)
print("[DataSync] Preview Mode - Analyzing sync requirements...")
print("=" * 60)

# First, ensure trade calendar cache is up-to-date
print("[DataSync] Syncing trade calendar cache...")
sync_trade_cal_cache()

# Determine date range
if end_date is None:
end_date = get_today_date()

# Check if sync is needed
sync_needed, cal_start, cal_end, local_last = self.check_sync_needed(force_full)

if not sync_needed:
print("\n" + "=" * 60)
print("[DataSync] Preview Result")
print("=" * 60)
print(" Sync Status: NOT NEEDED")
print(" Reason: Local data is up-to-date with trade calendar")
print("=" * 60)
return {
    "sync_needed": False,
    "stock_count": 0,
    "start_date": None,
    "end_date": None,
    "estimated_records": 0,
    "sample_data": pd.DataFrame(),
    "mode": "none",
}

# Use dates from check_sync_needed
if cal_start and cal_end:
sync_start_date = cal_start
end_date = cal_end
else:
sync_start_date = start_date or DEFAULT_START_DATE
if end_date is None:
end_date = get_today_date()

# Determine sync mode
if force_full:
mode = "full"
print(f"[DataSync] Mode: FULL SYNC from {sync_start_date} to {end_date}")
elif local_last and cal_start and sync_start_date == get_next_date(local_last):
mode = "incremental"
print(f"[DataSync] Mode: INCREMENTAL SYNC (bandwidth optimized)")
print(f"[DataSync] Sync from: {sync_start_date} to {end_date}")
else:
mode = "partial"
print(f"[DataSync] Mode: SYNC from {sync_start_date} to {end_date}")

# Get all stock codes
stock_codes = self.get_all_stock_codes()
if not stock_codes:
print("[DataSync] No stocks found to sync")
return {
    "sync_needed": False,
    "stock_count": 0,
    "start_date": None,
    "end_date": None,
    "estimated_records": 0,
    "sample_data": pd.DataFrame(),
    "mode": "none",
}

stock_count = len(stock_codes)
print(f"[DataSync] Total stocks to sync: {stock_count}")

# Fetch sample data from first few stocks
print(f"[DataSync] Fetching sample data from {sample_size} stocks...")
sample_data_list = []
sample_codes = stock_codes[:sample_size]

for ts_code in sample_codes:
try:
data = self.client.query(
    "pro_bar",
    ts_code=ts_code,
    start_date=sync_start_date,
    end_date=end_date,
    factors="tor,vr",
)
if not data.empty:
sample_data_list.append(data)
print(f" - {ts_code}: {len(data)} records")
except Exception as e:
print(f" - {ts_code}: Error fetching - {e}")

# Combine sample data
sample_df = (
    pd.concat(sample_data_list, ignore_index=True)
    if sample_data_list
    else pd.DataFrame()
)

# Estimate total records based on sample
if not sample_df.empty:
avg_records_per_stock = len(sample_df) / len(sample_data_list)
estimated_records = int(avg_records_per_stock * stock_count)
else:
estimated_records = 0

# Display preview results
print("\n" + "=" * 60)
print("[DataSync] Preview Result")
print("=" * 60)
print(f" Sync Mode: {mode.upper()}")
print(f" Date Range: {sync_start_date} to {end_date}")
print(f" Stocks to Sync: {stock_count}")
print(f" Sample Stocks Checked: {len(sample_data_list)}/{sample_size}")
print(f" Estimated Total Records: ~{estimated_records:,}")

if not sample_df.empty:
print(f"\n Sample Data Preview (first {len(sample_df)} rows):")
print(" " + "-" * 56)
# Display sample data in a compact format
preview_cols = [
    "ts_code",
    "trade_date",
    "open",
    "high",
    "low",
    "close",
    "vol",
]
available_cols = [c for c in preview_cols if c in sample_df.columns]
sample_display = sample_df[available_cols].head(10)
for idx, row in sample_display.iterrows():
print(f" {row.to_dict()}")
print(" " + "-" * 56)

print("=" * 60)

return {
    "sync_needed": True,
    "stock_count": stock_count,
    "start_date": sync_start_date,
    "end_date": end_date,
    "estimated_records": estimated_records,
    "sample_data": sample_df,
    "mode": mode,
}

def sync_single_stock(
    self,
    ts_code: str,
@@ -320,6 +506,7 @@ class DataSync:
start_date: Optional[str] = None,
end_date: Optional[str] = None,
max_workers: Optional[int] = None,
dry_run: bool = False,
) -> Dict[str, pd.DataFrame]:
"""Sync daily data for all stocks in local storage.

@@ -337,9 +524,10 @@ class DataSync:
start_date: Manual start date (overrides auto-detection)
end_date: Manual end date (defaults to today)
max_workers: Number of worker threads (default: 10)
dry_run: If True, only preview what would be synced without writing data

Returns:
Dict mapping ts_code to DataFrame (empty if sync skipped)
Dict mapping ts_code to DataFrame (empty if sync skipped or dry_run)
"""
print("\n" + "=" * 60)
print("[DataSync] Starting daily data sync...")
@@ -378,11 +566,14 @@ class DataSync:

# Determine sync mode
if force_full:
mode = "full"
print(f"[DataSync] Mode: FULL SYNC from {sync_start_date} to {end_date}")
elif local_last and cal_start and sync_start_date == get_next_date(local_last):
mode = "incremental"
print(f"[DataSync] Mode: INCREMENTAL SYNC (bandwidth optimized)")
print(f"[DataSync] Sync from: {sync_start_date} to {end_date}")
else:
mode = "partial"
print(f"[DataSync] Mode: SYNC from {sync_start_date} to {end_date}")

# Get all stock codes
@@ -394,6 +585,17 @@ class DataSync:
print(f"[DataSync] Total stocks to sync: {len(stock_codes)}")
print(f"[DataSync] Using {max_workers or self.max_workers} worker threads")

# Handle dry run mode
if dry_run:
print("\n" + "=" * 60)
print("[DataSync] DRY RUN MODE - No data will be written")
print("=" * 60)
print(f" Would sync {len(stock_codes)} stocks")
print(f" Date range: {sync_start_date} to {end_date}")
print(f" Mode: {mode}")
print("=" * 60)
return {}

# Reset stop flag for new sync
self._stop_flag.set()

@@ -492,11 +694,62 @@ class DataSync:
# Convenience functions


def preview_sync(
    force_full: bool = False,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
    sample_size: int = 3,
    max_workers: Optional[int] = None,
) -> dict:
"""Preview sync data volume and samples without actually syncing.

This is the recommended way to check what would be synced before
running the actual synchronization.

Args:
force_full: If True, preview full sync from 20180101
start_date: Manual start date (overrides auto-detection)
end_date: Manual end date (defaults to today)
sample_size: Number of sample stocks to fetch for preview (default: 3)
max_workers: Number of worker threads (not used in preview, for API compatibility)

Returns:
Dictionary with preview information:
{
'sync_needed': bool,
'stock_count': int,
'start_date': str,
'end_date': str,
'estimated_records': int,
'sample_data': pd.DataFrame,
'mode': str,  # 'full', 'incremental', 'partial', or 'none'
}

Example:
>>> # Preview what would be synced
>>> preview = preview_sync()
>>>
>>> # Preview full sync
>>> preview = preview_sync(force_full=True)
>>>
>>> # Preview with more samples
>>> preview = preview_sync(sample_size=5)
"""
sync_manager = DataSync(max_workers=max_workers)
return sync_manager.preview_sync(
    force_full=force_full,
    start_date=start_date,
    end_date=end_date,
    sample_size=sample_size,
)


def sync_all(
    force_full: bool = False,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
    max_workers: Optional[int] = None,
    dry_run: bool = False,
) -> Dict[str, pd.DataFrame]:
"""Sync daily data for all stocks.

@@ -507,6 +760,7 @@ def sync_all(
start_date: Manual start date (YYYYMMDD)
end_date: Manual end date (defaults to today)
max_workers: Number of worker threads (default: 10)
dry_run: If True, only preview what would be synced without writing data

Returns:
Dict mapping ts_code to DataFrame
@@ -526,12 +780,16 @@ def sync_all(
>>>
>>> # Custom thread count
>>> result = sync_all(max_workers=20)
>>>
>>> # Dry run (preview only)
>>> result = sync_all(dry_run=True)
"""
sync_manager = DataSync(max_workers=max_workers)
return sync_manager.sync_all(
    force_full=force_full,
    start_date=start_date,
    end_date=end_date,
    dry_run=dry_run,
)


@@ -540,11 +798,32 @@ if __name__ == "__main__":
print("Data Sync Module")
print("=" * 60)
print("\nUsage:")
print(" from src.data.sync import sync_all")
print(" from src.data.sync import sync_all, preview_sync")
print("")
print(" # Preview before sync (recommended)")
print(" preview = preview_sync()")
print("")
print(" # Dry run (preview only)")
print(" result = sync_all(dry_run=True)")
print("")
print(" # Actual sync")
print(" result = sync_all() # Incremental sync")
print(" result = sync_all(force_full=True) # Full reload")
print("\n" + "=" * 60)

# Run sync
result = sync_all()
print(f"\nSynced {len(result)} stocks")
# Run preview first
print("\n[Main] Running preview first...")
preview = preview_sync()

if preview["sync_needed"]:
# Ask for confirmation
print("\n" + "=" * 60)
response = input("Proceed with sync? (y/n): ").strip().lower()
if response in ("y", "yes"):
print("\n[Main] Starting actual sync...")
result = sync_all()
print(f"\nSynced {len(result)} stocks")
else:
print("\n[Main] Sync cancelled by user")
else:
print("\n[Main] No sync needed - data is up to date")

@@ -5,29 +5,30 @@ Tests the daily interface implementation against api.md requirements:
- tor 换手率
- vr 量比
"""

import pytest
import pandas as pd
from src.data.daily import get_daily
from src.data.api_wrappers import get_daily


# Expected output fields according to api.md
EXPECTED_BASE_FIELDS = [
'ts_code', # 股票代码
'trade_date', # 交易日期
'open', # 开盘价
'high', # 最高价
'low', # 最低价
'close', # 收盘价
'pre_close', # 昨收价
'change', # 涨跌额
'pct_chg', # 涨跌幅
'vol', # 成交量
'amount', # 成交额
"ts_code", # 股票代码
"trade_date", # 交易日期
"open", # 开盘价
"high", # 最高价
"low", # 最低价
"close", # 收盘价
"pre_close", # 昨收价
"change", # 涨跌额
"pct_chg", # 涨跌幅
"vol", # 成交量
"amount", # 成交额
]

EXPECTED_FACTOR_FIELDS = [
'turnover_rate', # 换手率 (tor)
'volume_ratio', # 量比 (vr)
"turnover_rate", # 换手率 (tor)
"volume_ratio", # 量比 (vr)
]


@@ -36,19 +37,19 @@ class TestGetDaily:

def test_fetch_basic(self):
"""Test basic daily data fetch with real API."""
result = get_daily('000001.SZ', start_date='20240101', end_date='20240131')
result = get_daily("000001.SZ", start_date="20240101", end_date="20240131")

assert isinstance(result, pd.DataFrame)
assert len(result) >= 1
assert result['ts_code'].iloc[0] == '000001.SZ'
assert result["ts_code"].iloc[0] == "000001.SZ"

def test_fetch_with_factors(self):
"""Test fetch with tor and vr factors."""
result = get_daily(
'000001.SZ',
start_date='20240101',
end_date='20240131',
factors=['tor', 'vr'],
"000001.SZ",
start_date="20240101",
end_date="20240131",
factors=["tor", "vr"],
)

assert isinstance(result, pd.DataFrame)
@@ -61,25 +62,26 @@ class TestGetDaily:

def test_output_fields_completeness(self):
"""Verify all required output fields are returned."""
result = get_daily('600000.SH')
result = get_daily("600000.SH")

# Verify all base fields are present
assert set(EXPECTED_BASE_FIELDS).issubset(result.columns.tolist()), \
assert set(EXPECTED_BASE_FIELDS).issubset(result.columns.tolist()), (
f"Missing fields: {set(EXPECTED_BASE_FIELDS) - set(result.columns)}"
)

def test_empty_result(self):
"""Test handling of empty results."""
# 使用真实 API 测试无效股票代码的空结果
result = get_daily('INVALID.SZ')
result = get_daily("INVALID.SZ")
assert isinstance(result, pd.DataFrame)
assert result.empty

def test_date_range_query(self):
"""Test query with date range."""
result = get_daily(
'000001.SZ',
start_date='20240101',
end_date='20240131',
"000001.SZ",
start_date="20240101",
end_date="20240131",
)

assert isinstance(result, pd.DataFrame)
@@ -87,7 +89,7 @@ class TestGetDaily:

def test_with_adj(self):
"""Test fetch with adjustment type."""
result = get_daily('000001.SZ', adj='qfq')
result = get_daily("000001.SZ", adj="qfq")

assert isinstance(result, pd.DataFrame)

@@ -95,11 +97,14 @@ class TestGetDaily:
def test_integration():
"""Integration test with real Tushare API (requires valid token)."""
import os
token = os.environ.get('TUSHARE_TOKEN')

token = os.environ.get("TUSHARE_TOKEN")
if not token:
pytest.skip("TUSHARE_TOKEN not configured")

result = get_daily('000001.SZ', start_date='20240101', end_date='20240131', factors=['tor', 'vr'])
result = get_daily(
    "000001.SZ", start_date="20240101", end_date="20240131", factors=["tor", "vr"]
)

# Verify structure
assert isinstance(result, pd.DataFrame)
@@ -112,6 +117,6 @@ def test_integration():
assert field in result.columns, f"Missing factor field: {field}"


if __name__ == '__main__':
if __name__ == "__main__":
# 运行 pytest 单元测试(真实API调用)
pytest.main([__file__, '-v'])
pytest.main([__file__, "-v"])

@@ -9,7 +9,7 @@ import pytest
import pandas as pd
from pathlib import Path
from src.data.storage import Storage
from src.data.stock_basic import _get_csv_path
from src.data.api_wrappers.api_stock_basic import _get_csv_path


class TestDailyStorageValidation:

@@ -5,6 +5,7 @@ Tests the sync module's full/incremental sync logic for daily data:
|
||||
- Incremental sync when local data exists (from last_date + 1)
|
||||
- Data integrity validation
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import pandas as pd
|
||||
from unittest.mock import Mock, patch, MagicMock
|
||||
@@ -17,6 +18,8 @@ from src.data.sync import (
|
||||
get_next_date,
|
||||
DEFAULT_START_DATE,
|
||||
)
|
||||
from src.data.storage import Storage
|
||||
from src.data.client import TushareClient
|
||||
|
||||
|
||||
class TestDateUtilities:
|
||||
@@ -63,30 +66,32 @@ class TestDataSync:
|
||||
|
||||
def test_get_all_stock_codes_from_daily(self, mock_storage):
|
||||
"""Test getting stock codes from daily data."""
|
||||
with patch('src.data.sync.Storage', return_value=mock_storage):
|
||||
with patch("src.data.sync.Storage", return_value=mock_storage):
|
||||
sync = DataSync()
|
||||
sync.storage = mock_storage
|
||||
|
||||
mock_storage.load.return_value = pd.DataFrame({
|
||||
'ts_code': ['000001.SZ', '000001.SZ', '600000.SH'],
|
||||
})
|
||||
mock_storage.load.return_value = pd.DataFrame(
|
||||
{
|
||||
"ts_code": ["000001.SZ", "000001.SZ", "600000.SH"],
|
||||
}
|
||||
)
|
||||
|
||||
codes = sync.get_all_stock_codes()
|
||||
|
||||
assert len(codes) == 2
|
||||
assert '000001.SZ' in codes
|
||||
assert '600000.SH' in codes
|
||||
assert "000001.SZ" in codes
|
||||
assert "600000.SH" in codes
|
||||
|
||||
def test_get_all_stock_codes_fallback(self, mock_storage):
|
||||
"""Test fallback to stock_basic when daily is empty."""
|
||||
with patch('src.data.sync.Storage', return_value=mock_storage):
|
||||
with patch("src.data.sync.Storage", return_value=mock_storage):
|
||||
sync = DataSync()
|
||||
sync.storage = mock_storage
|
||||
|
||||
# First call (daily) returns empty, second call (stock_basic) returns data
|
||||
mock_storage.load.side_effect = [
|
||||
pd.DataFrame(), # daily empty
|
||||
pd.DataFrame({'ts_code': ['000001.SZ', '600000.SH']}), # stock_basic
|
||||
pd.DataFrame({"ts_code": ["000001.SZ", "600000.SH"]}), # stock_basic
|
||||
]
|
||||
|
||||
codes = sync.get_all_stock_codes()
|
||||
@@ -95,21 +100,23 @@ class TestDataSync:
|
||||
|
||||
def test_get_global_last_date(self, mock_storage):
|
||||
"""Test getting global last date."""
|
||||
with patch('src.data.sync.Storage', return_value=mock_storage):
|
||||
with patch("src.data.sync.Storage", return_value=mock_storage):
|
||||
sync = DataSync()
|
||||
sync.storage = mock_storage
|
||||
|
||||
mock_storage.load.return_value = pd.DataFrame({
|
||||
'ts_code': ['000001.SZ', '600000.SH'],
|
||||
'trade_date': ['20240102', '20240103'],
|
||||
})
|
||||
mock_storage.load.return_value = pd.DataFrame(
|
||||
{
|
||||
"ts_code": ["000001.SZ", "600000.SH"],
|
||||
"trade_date": ["20240102", "20240103"],
|
||||
}
|
||||
)
|
||||
|
||||
last_date = sync.get_global_last_date()
|
||||
assert last_date == '20240103'
|
||||
assert last_date == "20240103"
|
||||
|
||||
def test_get_global_last_date_empty(self, mock_storage):
|
||||
"""Test getting last date from empty storage."""
|
||||
with patch('src.data.sync.Storage', return_value=mock_storage):
|
||||
with patch("src.data.sync.Storage", return_value=mock_storage):
|
||||
sync = DataSync()
|
||||
sync.storage = mock_storage
|
||||
|
||||
@@ -120,18 +127,23 @@ class TestDataSync:
|
||||
|
||||
def test_sync_single_stock(self, mock_storage):
|
||||
"""Test syncing a single stock."""
|
||||
with patch('src.data.sync.Storage', return_value=mock_storage):
|
||||
with patch('src.data.sync.get_daily', return_value=pd.DataFrame({
|
||||
'ts_code': ['000001.SZ'],
|
||||
'trade_date': ['20240102'],
|
||||
})):
|
||||
with patch("src.data.sync.Storage", return_value=mock_storage):
|
||||
with patch(
|
||||
"src.data.sync.get_daily",
|
||||
return_value=pd.DataFrame(
|
||||
{
|
||||
"ts_code": ["000001.SZ"],
|
||||
"trade_date": ["20240102"],
|
||||
}
|
||||
),
|
||||
):
|
||||
sync = DataSync()
|
||||
sync.storage = mock_storage
|
||||
|
||||
result = sync.sync_single_stock(
|
||||
ts_code='000001.SZ',
|
||||
start_date='20240101',
|
||||
end_date='20240102',
|
||||
ts_code="000001.SZ",
|
||||
start_date="20240101",
|
||||
end_date="20240102",
|
||||
)
|
||||
|
||||
assert isinstance(result, pd.DataFrame)
|
||||
@@ -139,15 +151,15 @@ class TestDataSync:
|
||||
|
||||
def test_sync_single_stock_empty(self, mock_storage):
|
||||
"""Test syncing a stock with no data."""
|
||||
with patch('src.data.sync.Storage', return_value=mock_storage):
|
||||
with patch('src.data.sync.get_daily', return_value=pd.DataFrame()):
|
||||
with patch("src.data.sync.Storage", return_value=mock_storage):
|
||||
with patch("src.data.sync.get_daily", return_value=pd.DataFrame()):
|
||||
sync = DataSync()
|
||||
sync.storage = mock_storage
|
||||
|
||||
result = sync.sync_single_stock(
|
||||
ts_code='INVALID.SZ',
|
||||
start_date='20240101',
|
||||
end_date='20240102',
|
||||
ts_code="INVALID.SZ",
|
||||
start_date="20240101",
|
||||
end_date="20240102",
|
||||
)
|
||||
|
||||
assert result.empty
|
||||
@@ -158,40 +170,46 @@ class TestSyncAll:
|
||||
|
||||
def test_full_sync_mode(self, mock_storage):
|
||||
"""Test full sync mode when force_full=True."""
|
||||
with patch('src.data.sync.Storage', return_value=mock_storage):
|
||||
with patch('src.data.sync.get_daily', return_value=pd.DataFrame()):
|
||||
with patch("src.data.sync.Storage", return_value=mock_storage):
|
||||
with patch("src.data.sync.get_daily", return_value=pd.DataFrame()):
|
||||
sync = DataSync()
|
||||
sync.storage = mock_storage
|
||||
sync.sync_single_stock = Mock(return_value=pd.DataFrame())
|
||||
|
||||
mock_storage.load.return_value = pd.DataFrame({
|
||||
'ts_code': ['000001.SZ'],
|
||||
})
|
||||
mock_storage.load.return_value = pd.DataFrame(
|
||||
{
|
||||
"ts_code": ["000001.SZ"],
|
||||
}
|
||||
)
|
||||
|
||||
result = sync.sync_all(force_full=True)
|
||||
|
||||
# Verify sync_single_stock was called with default start date
|
||||
sync.sync_single_stock.assert_called_once()
|
||||
             call_args = sync.sync_single_stock.call_args
-            assert call_args[1]['start_date'] == DEFAULT_START_DATE
+            assert call_args[1]["start_date"] == DEFAULT_START_DATE

     def test_incremental_sync_mode(self, mock_storage):
         """Test incremental sync mode when data exists."""
-        with patch('src.data.sync.Storage', return_value=mock_storage):
+        with patch("src.data.sync.Storage", return_value=mock_storage):
             sync = DataSync()
             sync.storage = mock_storage
             sync.sync_single_stock = Mock(return_value=pd.DataFrame())

             # Mock existing data with last date
             mock_storage.load.side_effect = [
-                pd.DataFrame({
-                    'ts_code': ['000001.SZ'],
-                    'trade_date': ['20240102'],
-                }),  # get_all_stock_codes
-                pd.DataFrame({
-                    'ts_code': ['000001.SZ'],
-                    'trade_date': ['20240102'],
-                }),  # get_global_last_date
+                pd.DataFrame(
+                    {
+                        "ts_code": ["000001.SZ"],
+                        "trade_date": ["20240102"],
+                    }
+                ),  # get_all_stock_codes
+                pd.DataFrame(
+                    {
+                        "ts_code": ["000001.SZ"],
+                        "trade_date": ["20240102"],
+                    }
+                ),  # get_global_last_date
             ]

             result = sync.sync_all(force_full=False)
@@ -199,28 +217,30 @@ class TestSyncAll:
             # Verify sync_single_stock was called with next date
             sync.sync_single_stock.assert_called_once()
             call_args = sync.sync_single_stock.call_args
-            assert call_args[1]['start_date'] == '20240103'
+            assert call_args[1]["start_date"] == "20240103"

     def test_manual_start_date(self, mock_storage):
         """Test sync with manual start date."""
-        with patch('src.data.sync.Storage', return_value=mock_storage):
+        with patch("src.data.sync.Storage", return_value=mock_storage):
             sync = DataSync()
             sync.storage = mock_storage
             sync.sync_single_stock = Mock(return_value=pd.DataFrame())

-            mock_storage.load.return_value = pd.DataFrame({
-                'ts_code': ['000001.SZ'],
-            })
+            mock_storage.load.return_value = pd.DataFrame(
+                {
+                    "ts_code": ["000001.SZ"],
+                }
+            )

-            result = sync.sync_all(force_full=False, start_date='20230601')
+            result = sync.sync_all(force_full=False, start_date="20230601")

             sync.sync_single_stock.assert_called_once()
             call_args = sync.sync_single_stock.call_args
-            assert call_args[1]['start_date'] == '20230601'
+            assert call_args[1]["start_date"] == "20230601"

     def test_no_stocks_found(self, mock_storage):
         """Test sync when no stocks are found."""
-        with patch('src.data.sync.Storage', return_value=mock_storage):
+        with patch("src.data.sync.Storage", return_value=mock_storage):
             sync = DataSync()
             sync.storage = mock_storage

@@ -236,7 +256,7 @@ class TestSyncAllConvenienceFunction:

     def test_sync_all_function(self):
         """Test sync_all convenience function."""
-        with patch('src.data.sync.DataSync') as MockSync:
+        with patch("src.data.sync.DataSync") as MockSync:
             mock_instance = Mock()
             mock_instance.sync_all.return_value = {}
             MockSync.return_value = mock_instance
@@ -251,5 +271,5 @@ class TestSyncAllConvenienceFunction:
             )


-if __name__ == '__main__':
-    pytest.main([__file__, '-v'])
+if __name__ == "__main__":
+    pytest.main([__file__, "-v"])
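A note on the `side_effect` pattern used in `test_incremental_sync_mode`: when a mocked `Storage.load` must return a different DataFrame on each consecutive call, `unittest.mock` pops one item from the `side_effect` list per call, in order. A minimal standalone sketch (the call order mirrors the test's comments; the real `Storage` API is not involved):

```python
from unittest.mock import Mock

import pandas as pd

# Each call to mock_load() consumes the next item in side_effect:
# the first call stands in for get_all_stock_codes, the second
# for get_global_last_date.
mock_load = Mock(
    side_effect=[
        pd.DataFrame({"ts_code": ["000001.SZ"], "trade_date": ["20240102"]}),
        pd.DataFrame({"ts_code": ["000001.SZ"], "trade_date": ["20240102"]}),
    ]
)

codes_frame = mock_load("daily")   # first queued DataFrame
dates_frame = mock_load("daily")   # second queued DataFrame
print(dates_frame["trade_date"].iloc[0])  # 20240102
```

A third call would raise `StopIteration`, which is a useful guard: it fails loudly if the code under test reads storage more times than the test anticipated.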
256
tests/test_sync_real.py
Normal file
@@ -0,0 +1,256 @@
"""Tests for data sync with REAL data (read-only).

Tests verify:
1. get_global_last_date() correctly reads local data's max date
2. Incremental sync date calculation (local_last_date + 1)
3. Full sync date calculation (20180101)
4. Multi-stock scenario with real data

⚠️ IMPORTANT: These tests ONLY read data, no write operations.
- NO sync_all() calls (writes daily.h5)
- NO check_sync_needed() calls (writes trade_cal.h5)
"""

import pytest
import pandas as pd
from pathlib import Path

from src.data.sync import (
    DataSync,
    get_next_date,
    DEFAULT_START_DATE,
)
from src.data.storage import Storage


class TestDataSyncReadOnly:
    """Read-only tests for data sync - verify date calculation logic."""

    @pytest.fixture
    def storage(self):
        """Create storage instance."""
        return Storage()

    @pytest.fixture
    def data_sync(self):
        """Create DataSync instance."""
        return DataSync()

    @pytest.fixture
    def daily_exists(self, storage):
        """Check if daily.h5 exists."""
        return storage.exists("daily")

    def test_daily_h5_exists(self, storage):
        """Verify daily.h5 data file exists before running tests."""
        assert storage.exists("daily"), (
            "daily.h5 not found. Please run full sync first: "
            "uv run python -c 'from src.data.sync import sync_all; sync_all(force_full=True)'"
        )

    def test_get_global_last_date(self, data_sync, daily_exists):
        """Test get_global_last_date returns correct max date from local data."""
        if not daily_exists:
            pytest.skip("daily.h5 not found")

        last_date = data_sync.get_global_last_date()

        # Verify it's a valid date string
        assert last_date is not None, "get_global_last_date returned None"
        assert isinstance(last_date, str), f"Expected str, got {type(last_date)}"
        assert len(last_date) == 8, f"Expected 8-digit date, got {last_date}"
        assert last_date.isdigit(), f"Expected numeric date, got {last_date}"

        # Verify by reading storage directly
        daily_data = data_sync.storage.load("daily")
        expected_max = str(daily_data["trade_date"].max())

        assert last_date == expected_max, (
            f"get_global_last_date returned {last_date}, "
            f"but actual max date is {expected_max}"
        )

        print(f"[TEST] Local data last date: {last_date}")

    def test_incremental_sync_date_calculation(self, data_sync, daily_exists):
        """Test incremental sync: start_date = local_last_date + 1.

        This verifies that when local data exists, incremental sync should
        fetch data from (local_last_date + 1), not from 20180101.
        """
        if not daily_exists:
            pytest.skip("daily.h5 not found")

        # Get local last date
        local_last_date = data_sync.get_global_last_date()
        assert local_last_date is not None, "No local data found"

        # Calculate expected incremental start date
        expected_start_date = get_next_date(local_last_date)

        # Verify the calculation is correct
        local_last_int = int(local_last_date)
        expected_int = local_last_int + 1
        actual_int = int(expected_start_date)

        assert actual_int == expected_int, (
            f"Incremental start date calculation error: "
            f"expected {expected_int}, got {actual_int}"
        )

        print(
            f"[TEST] Incremental sync: local_last={local_last_date}, "
            f"start_date should be {expected_start_date}"
        )

        # Verify this is NOT 20180101 (would be full sync)
        assert expected_start_date != DEFAULT_START_DATE, (
            f"Incremental sync should NOT start from {DEFAULT_START_DATE}"
        )

    def test_full_sync_date_calculation(self):
        """Test full sync: start_date = 20180101 when force_full=True.

        This verifies that force_full=True always starts from 20180101.
        """
        # Full sync should always use DEFAULT_START_DATE
        full_sync_start = DEFAULT_START_DATE

        assert full_sync_start == "20180101", (
            f"Full sync should start from 20180101, got {full_sync_start}"
        )

        print(f"[TEST] Full sync start date: {full_sync_start}")

    def test_date_comparison_logic(self, data_sync, daily_exists):
        """Test date comparison: incremental vs full sync selection logic.

        Verify that:
        - If local_last_date < today: incremental sync needed
        - If local_last_date >= today: no sync needed
        """
        if not daily_exists:
            pytest.skip("daily.h5 not found")

        from datetime import datetime

        local_last_date = data_sync.get_global_last_date()
        today = datetime.now().strftime("%Y%m%d")

        local_last_int = int(local_last_date)
        today_int = int(today)

        # Log the comparison
        print(
            f"[TEST] Date comparison: local_last={local_last_date} ({local_last_int}), "
            f"today={today} ({today_int})"
        )

        # This test just verifies the comparison logic works
        if local_last_int < today_int:
            print("[TEST] Local data is older than today - sync needed")
            # Incremental sync should fetch from local_last_date + 1
            sync_start = get_next_date(local_last_date)
            assert int(sync_start) > local_last_int, (
                "Sync start should be after local last"
            )
        else:
            print("[TEST] Local data is up-to-date - no sync needed")

    def test_get_all_stock_codes_real_data(self, data_sync, daily_exists):
        """Test get_all_stock_codes returns multiple real stock codes."""
        if not daily_exists:
            pytest.skip("daily.h5 not found")

        codes = data_sync.get_all_stock_codes()

        # Verify it's a list
        assert isinstance(codes, list), f"Expected list, got {type(codes)}"
        assert len(codes) > 0, "No stock codes found"

        # Verify multiple stocks
        assert len(codes) >= 10, (
            f"Expected at least 10 stocks for multi-stock test, got {len(codes)}"
        )

        # Verify format (should be like 000001.SZ, 600000.SH)
        sample_codes = codes[:5]
        for code in sample_codes:
            assert "." in code, f"Invalid stock code format: {code}"
            suffix = code.split(".")[-1]
            assert suffix in ["SZ", "SH"], f"Invalid exchange suffix: {suffix}"

        print(f"[TEST] Found {len(codes)} stock codes (sample: {sample_codes})")

    def test_multi_stock_date_range(self, data_sync, daily_exists):
        """Test that multiple stocks share the same date range in local data.

        This verifies that local data has consistent date coverage across stocks.
        """
        if not daily_exists:
            pytest.skip("daily.h5 not found")

        daily_data = data_sync.storage.load("daily")

        # Get date range for each stock
        stock_dates = daily_data.groupby("ts_code")["trade_date"].agg(["min", "max"])

        # Get global min and max
        global_min = str(daily_data["trade_date"].min())
        global_max = str(daily_data["trade_date"].max())

        print(f"[TEST] Global date range: {global_min} to {global_max}")
        print(f"[TEST] Total stocks: {len(stock_dates)}")

        # Verify we have data for multiple stocks
        assert len(stock_dates) >= 10, (
            f"Expected at least 10 stocks, got {len(stock_dates)}"
        )

        # Verify date range is reasonable (at least ~100 calendar days of data).
        # Parse the YYYYMMDD strings to real dates first; subtracting the raw
        # integers (e.g. 20240101 - 20231231 = 8870) would not give a day count.
        days_span = (
            pd.to_datetime(global_max, format="%Y%m%d")
            - pd.to_datetime(global_min, format="%Y%m%d")
        ).days

        assert days_span > 100, (
            f"Date range too small: {days_span} days. "
            f"Expected at least 100 days of data."
        )

        print(f"[TEST] Date span: {days_span} days")

class TestDateUtilities:
    """Test date utility functions."""

    def test_get_next_date(self):
        """Test get_next_date correctly calculates next day."""
        # Test normal cases
        assert get_next_date("20240101") == "20240102"
        assert get_next_date("20240131") == "20240201"  # Month boundary
        assert get_next_date("20241231") == "20250101"  # Year boundary

    def test_incremental_vs_full_sync_logic(self):
        """Test the logic difference between incremental and full sync.

        Incremental: start_date = local_last_date + 1
        Full: start_date = 20180101
        """
        # Scenario 1: Local data exists
        local_last_date = "20240115"
        incremental_start = get_next_date(local_last_date)

        assert incremental_start == "20240116"
        assert incremental_start != DEFAULT_START_DATE

        # Scenario 2: Force full sync
        full_sync_start = DEFAULT_START_DATE  # "20180101"

        assert full_sync_start == "20180101"
        assert incremental_start != full_sync_start

        print("[TEST] Incremental vs Full sync logic verified")


if __name__ == "__main__":
    pytest.main([__file__, "-v", "-s"])
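The date arithmetic these tests pin down (incremental start = last local date + 1, full sync always from 20180101) can be sketched in a few lines. The `resolve_start_date` helper below is illustrative only, not part of `src.data.sync`; it just combines the two rules the tests assert separately:

```python
from datetime import datetime, timedelta
from typing import Optional

DEFAULT_START_DATE = "20180101"


def get_next_date(date_str: str) -> str:
    """Day after date_str (YYYYMMDD), handling month/year rollover."""
    d = datetime.strptime(date_str, "%Y%m%d")
    return (d + timedelta(days=1)).strftime("%Y%m%d")


def resolve_start_date(local_last_date: Optional[str], force_full: bool = False) -> str:
    """Full sync (or no local data) restarts at DEFAULT_START_DATE;
    incremental sync resumes the day after the last local trade date."""
    if force_full or local_last_date is None:
        return DEFAULT_START_DATE
    return get_next_date(local_last_date)


print(get_next_date("20241231"))       # 20250101
print(resolve_start_date("20240115"))  # 20240116
print(resolve_start_date(None))        # 20180101
```

Using `datetime` rather than integer arithmetic is what makes the month and year boundary cases in `test_get_next_date` fall out for free.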