refactor: restructure the API interface modules into an api_wrappers package
- Move the standalone API modules (daily, stock_basic, trade_cal) into api_wrappers/ - Rewrite sync.py on top of the new wrapper structure, with extended sync features - Update the test files for the new module layout - Add a pytest.ini configuration file
@@ -1,847 +0,0 @@
# ProStock Data Interface Wrapper Specification

## 1. Overview

This document defines the standard for adding new Tushare API interface wrappers under the `src/data/` directory. All non-special interfaces (factor and fundamental data) must follow this specification to ensure:

- a uniform code style
- automatic sync support
- consistent incremental-update logic
- reduced storage write pressure
## 2. Interface Classification

### 2.1 Special interfaces (excluded from unified sync)

The following interfaces have their own sync logic and do not participate in the automatic sync mechanism defined in this document:

| Interface type | Example | Notes |
|---------|------|------|
| Trading calendar | `trade_cal` | Global data, fetched by date range |
| Stock basic info | `stock_basic` | One-off full fetch, stored as CSV |
| Auxiliary data | industry and concept classifications | Low-frequency updates, managed independently |

### 2.2 Standard interfaces (must follow this specification)

All factor, market, and financial data fetched **per stock** or **per date** must follow this specification.
## 3. File Structure

### 3.1 File naming

```
{data_type}.py
```

Examples:
- `daily.py` - daily bars
- `moneyflow.py` - money flow
- `limit_list.py` - limit up/down data
- `stk_holdernumber.py` - shareholder counts
### 3.2 File location

```
src/data/
├── __init__.py          # exports the public interface
├── client.py            # TushareClient (existing)
├── config.py            # configuration management (existing)
├── storage.py           # storage management (existing)
├── rate_limiter.py      # rate limiting (existing)
├── trade_cal.py         # trading calendar (special interface)
├── stock_basic.py       # stock basics (special interface)
├── daily.py             # daily bars (reference example)
└── {new_data_type}.py   # new interface file
```
## 4. Interface Design Specification

### 4.1 Data-fetching functions

#### 4.1.1 Stock-based interfaces

Applies to: daily bars, minute bars, money flow, etc.

```python
def get_{data_type}(
    ts_code: str,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
    # other optional parameters...
) -> pd.DataFrame:
    """Fetch {data description}.

    Args:
        ts_code: Stock code (e.g. '000001.SZ')
        start_date: Start date (YYYYMMDD format)
        end_date: End date (YYYYMMDD format)
        # other parameters...

    Returns:
        pd.DataFrame with the following columns:
        - ts_code: stock code
        - trade_date: trade date
        # other columns...

    Example:
        >>> data = get_{data_type}('000001.SZ', start_date='20240101', end_date='20240131')
    """
    client = TushareClient()

    params = {"ts_code": ts_code}
    if start_date:
        params["start_date"] = start_date
    if end_date:
        params["end_date"] = end_date
    # other parameters...

    data = client.query("{api_name}", **params)
    return data
```
#### 4.1.2 Date-based interfaces

Applies to: daily limit up/down lists, daily dragon-tiger (top) lists, daily chip distribution, etc.

```python
def get_{data_type}(
    trade_date: Optional[str] = None,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
    ts_code: Optional[str] = None,
    # other optional parameters...
) -> pd.DataFrame:
    """Fetch {data description}.

    **Prefer fetching by date** (recommended):
    - use trade_date to fetch whole-market data for a single day
    - or use start_date + end_date to fetch a date range

    Args:
        trade_date: Trade date (YYYYMMDD format), fetches whole-market data for one day
        start_date: Start date (YYYYMMDD format)
        end_date: End date (YYYYMMDD format)
        ts_code: Stock code (optional, filters to a specific stock)
        # other parameters...

    Returns:
        pd.DataFrame with the following columns:
        - ts_code: stock code
        - trade_date: trade date
        # other columns...

    Example:
        >>> # Fetch whole-market data for one day (recommended)
        >>> data = get_{data_type}(trade_date='20240115')
        >>> # Fetch a date range
        >>> data = get_{data_type}(start_date='20240101', end_date='20240131')
    """
    client = TushareClient()

    params = {}
    if trade_date:
        params["trade_date"] = trade_date
    if start_date:
        params["start_date"] = start_date
    if end_date:
        params["end_date"] = end_date
    if ts_code:
        params["ts_code"] = ts_code
    # other parameters...

    data = client.query("{api_name}", **params)
    return data
```
### 4.2 Key design principles

#### 4.2.1 Prefer fetching by date

It is **strongly recommended** to implement date-based fetching first:

1. **Higher efficiency**: one request returns data for the whole market
2. **Fewer API calls**: N days = N calls, instead of N days × M stocks
3. **Better fit for incremental updates**: check local data day by day and fetch only the missing dates
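To put point 2 in rough numbers (both figures below are hypothetical, picked only for illustration):

```python
# Hypothetical sizes: ~250 trading days in a year, ~5000 listed stocks.
trading_days = 250
stocks = 5000

# Date-based sync: one whole-market request per trading day.
calls_by_date = trading_days

# Stock-based sync refreshed daily: one request per stock per trading day.
calls_by_stock = trading_days * stocks

print(calls_by_date, calls_by_stock)  # → 250 1250000
```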
#### 4.2.2 Unified date field

- Always use `trade_date` as the date column name
- Date format: `YYYYMMDD` string
- If the API returns a different column name (e.g. `date`, `end_date`), rename it to `trade_date` before returning
#### 4.2.3 Stock code field

- Always use `ts_code` as the stock code column name
- Format: `{code}.{exchange}`, e.g. `000001.SZ`, `600000.SH`
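A quick way to sanity-check the convention is a small validation helper. The helper below is illustrative, not part of the project, and only covers the `.SZ`/`.SH` suffixes shown above:

```python
import re

# Matches the '{code}.{exchange}' convention: six digits, then .SZ or .SH.
# (Other exchange suffixes would need to be added if the project uses them.)
TS_CODE_RE = re.compile(r"^\d{6}\.(SZ|SH)$")

def is_valid_ts_code(ts_code: str) -> bool:
    return bool(TS_CODE_RE.match(ts_code))

print(is_valid_ts_code("000001.SZ"))  # → True
print(is_valid_ts_code("000001"))     # → False
```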
## 5. Sync Integration Specification

### 5.1 Registering a new data type in sync.py

Add sync support for the new data type in the `DataSync` class:

```python
class DataSync:
    """Data synchronization manager with full/incremental sync support."""

    DEFAULT_MAX_WORKERS = 10

    # Data type configuration
    DATASET_CONFIG = {
        "daily": {
            "api_name": "pro_bar",
            "fetch_by": "stock",  # fetched per stock
            "date_field": "trade_date",
            "key_fields": ["ts_code", "trade_date"],
        },
        "moneyflow": {
            "api_name": "moneyflow",
            "fetch_by": "stock",  # fetched per stock
            "date_field": "trade_date",
            "key_fields": ["ts_code", "trade_date"],
        },
        "limit_list": {
            "api_name": "limit_list",
            "fetch_by": "date",  # fetched per date (preferred)
            "date_field": "trade_date",
            "key_fields": ["ts_code", "trade_date"],
        },
        # New data types...
        "{new_data_type}": {
            "api_name": "{tushare_api_name}",
            "fetch_by": "date",  # "date" or "stock"
            "date_field": "trade_date",
            "key_fields": ["ts_code", "trade_date"],  # primary key used for deduplication
        },
    }
```
### 5.2 Implementing the sync methods

#### 5.2.1 Date-based sync method (recommended)

```python
def sync_by_date(
    self,
    dataset_name: str,
    start_date: str,
    end_date: str,
) -> pd.DataFrame:
    """Sync data by date (fetch all stocks for each date).

    This is the RECOMMENDED approach for date-based data like:
    - limit_list (涨跌停)
    - top_list (龙虎榜)
    - cyq_perf (筹码分布)

    Args:
        dataset_name: Name of the dataset in DATASET_CONFIG
        start_date: Start date (YYYYMMDD)
        end_date: End date (YYYYMMDD)

    Returns:
        Combined DataFrame with all data
    """
    from src.data.trade_cal import get_trading_days

    config = self.DATASET_CONFIG[dataset_name]
    api_name = config["api_name"]
    date_field = config["date_field"]

    # Get trading days in the range
    trading_days = get_trading_days(start_date, end_date)
    if not trading_days:
        print(f"[DataSync] No trading days in range {start_date} to {end_date}")
        return pd.DataFrame()

    print(f"[DataSync] Fetching {dataset_name} for {len(trading_days)} trading days")

    all_data = []

    for trade_date in tqdm(trading_days, desc=f"Syncing {dataset_name}"):
        # _stop_flag is a "running" flag; a cleared flag means stop was requested
        if not self._stop_flag.is_set():
            break

        try:
            data = self.client.query(
                api_name,
                trade_date=trade_date,
            )
            if not data.empty:
                all_data.append(data)
        except Exception as e:
            self._stop_flag.clear()
            print(f"[ERROR] Failed to fetch {dataset_name} for {trade_date}: {e}")
            raise

    if not all_data:
        return pd.DataFrame()

    # Combine all data
    combined = pd.concat(all_data, ignore_index=True)

    # Ensure the date field is consistent
    if date_field not in combined.columns and "trade_date" in combined.columns:
        combined = combined.rename(columns={"trade_date": date_field})

    return combined
```
#### 5.2.2 Stock-based sync method

```python
def sync_by_stock(
    self,
    dataset_name: str,
    ts_code: str,
    start_date: str,
    end_date: str,
) -> pd.DataFrame:
    """Sync data by stock (fetch all dates for each stock).

    Use this for stock-based data like:
    - daily (日线行情)
    - moneyflow (资金流向)
    - stk_holdernumber (股东人数)

    Args:
        dataset_name: Name of the dataset in DATASET_CONFIG
        ts_code: Stock code
        start_date: Start date (YYYYMMDD)
        end_date: End date (YYYYMMDD)

    Returns:
        DataFrame with data for the stock
    """
    config = self.DATASET_CONFIG[dataset_name]
    api_name = config["api_name"]

    # _stop_flag is a "running" flag; a cleared flag means stop was requested
    if not self._stop_flag.is_set():
        return pd.DataFrame()

    try:
        data = self.client.query(
            api_name,
            ts_code=ts_code,
            start_date=start_date,
            end_date=end_date,
        )
        return data
    except Exception as e:
        self._stop_flag.clear()
        print(f"[ERROR] Exception syncing {dataset_name} for {ts_code}: {e}")
        raise
```
### 5.3 Incremental update logic

#### 5.3.1 Generic incremental sync check

```python
def check_incremental_sync(
    self,
    dataset_name: str,
    force_full: bool = False,
) -> tuple[bool, Optional[str], Optional[str], Optional[str]]:
    """Check if incremental sync is needed for a dataset.

    Args:
        dataset_name: Name of the dataset
        force_full: If True, force full sync

    Returns:
        Tuple of (sync_needed, start_date, end_date, local_last_date)
    """
    config = self.DATASET_CONFIG[dataset_name]
    date_field = config["date_field"]

    # If force_full, always sync from the default start
    if force_full:
        print(f"[DataSync] Force full sync for {dataset_name}")
        return (True, DEFAULT_START_DATE, get_today_date(), None)

    # Check local data
    local_data = self.storage.load(dataset_name)
    if local_data.empty or date_field not in local_data.columns:
        print(f"[DataSync] No local {dataset_name} data, full sync needed")
        return (True, DEFAULT_START_DATE, get_today_date(), None)

    # Get the latest local date
    local_last_date = str(local_data[date_field].max())
    print(f"[DataSync] Local {dataset_name} last date: {local_last_date}")

    # Get the calendar's last trading day
    today = get_today_date()
    _, cal_last = self.get_trade_calendar_bounds(DEFAULT_START_DATE, today)

    if cal_last is None:
        print("[DataSync] Failed to get trade calendar, proceeding with sync")
        return (True, DEFAULT_START_DATE, today, local_last_date)

    print(f"[DataSync] Calendar last trading day: {cal_last}")

    # Compare dates
    if int(local_last_date) >= int(cal_last):
        print(f"[DataSync] {dataset_name} is up-to-date, skipping sync")
        return (False, None, None, None)

    # Need incremental sync
    sync_start = get_next_date(local_last_date)
    print(f"[DataSync] Incremental sync for {dataset_name} from {sync_start} to {cal_last}")
    return (True, sync_start, cal_last, local_last_date)
```
#### 5.3.2 Full sync entry point

```python
def sync_dataset(
    self,
    dataset_name: str,
    force_full: bool = False,
    max_workers: Optional[int] = None,
) -> pd.DataFrame:
    """Sync a dataset with automatic incremental update.

    This is the main entry point for syncing any dataset.

    Args:
        dataset_name: Name of the dataset in DATASET_CONFIG
        force_full: If True, force full reload
        max_workers: Number of worker threads (for stock-based sync)

    Returns:
        DataFrame with synced data
    """
    print("\n" + "=" * 60)
    print(f"[DataSync] Starting {dataset_name} sync...")
    print("=" * 60)

    # Ensure the trade calendar is up-to-date
    sync_trade_cal_cache()

    # Check if sync is needed
    sync_needed, start_date, end_date, local_last = self.check_incremental_sync(
        dataset_name, force_full
    )

    if not sync_needed:
        print(f"[DataSync] {dataset_name} is up-to-date, skipping")
        return pd.DataFrame()

    config = self.DATASET_CONFIG[dataset_name]
    fetch_by = config["fetch_by"]

    # Fetch data based on strategy
    if fetch_by == "date":
        # Fetch by date (all stocks per day)
        data = self.sync_by_date(dataset_name, start_date, end_date)
    else:
        # Fetch by stock (all dates per stock)
        data = self._sync_all_stocks(dataset_name, start_date, end_date, max_workers)

    if data.empty:
        print(f"[DataSync] No new data for {dataset_name}")
        return pd.DataFrame()

    # Save to storage (single write)
    self.storage.save(dataset_name, data, mode="append")

    print(f"[DataSync] Synced {len(data)} rows for {dataset_name}")
    return data

def _sync_all_stocks(
    self,
    dataset_name: str,
    start_date: str,
    end_date: str,
    max_workers: Optional[int] = None,
) -> pd.DataFrame:
    """Sync data for all stocks (stock-based fetch)."""
    stock_codes = self.get_all_stock_codes()
    if not stock_codes:
        return pd.DataFrame()

    print(f"[DataSync] Syncing {dataset_name} for {len(stock_codes)} stocks")

    self._stop_flag.set()
    results = []

    workers = max_workers or self.max_workers
    with ThreadPoolExecutor(max_workers=workers) as executor:
        future_to_code = {
            executor.submit(
                self.sync_by_stock, dataset_name, ts_code, start_date, end_date
            ): ts_code
            for ts_code in stock_codes
        }

        with tqdm(total=len(stock_codes), desc=f"Syncing {dataset_name}") as pbar:
            for future in as_completed(future_to_code):
                try:
                    data = future.result()
                    if not data.empty:
                        results.append(data)
                except Exception:
                    executor.shutdown(wait=False, cancel_futures=True)
                    raise
                pbar.update(1)

    if not results:
        return pd.DataFrame()

    return pd.concat(results, ignore_index=True)
```
## 6. Storage Specification

### 6.1 Using the Storage class

All data is stored in HDF5 via the `Storage` class:

```python
from src.data.storage import Storage

storage = Storage()

# Save data (incremental merge is automatic)
storage.save("dataset_name", data, mode="append")

# Load data
all_data = storage.load("dataset_name")
filtered_data = storage.load("dataset_name", start_date="20240101", end_date="20240131")

# Get the latest date
last_date = storage.get_last_date("dataset_name")

# Check existence
exists = storage.exists("dataset_name")
```
### 6.2 Incremental write strategy

**Key principle**: write all data **in one batch** after the requests complete, not row by row:

```python
# ❌ Wrong: write per iteration (poor performance)
for date in dates:
    data = fetch(date)
    storage.save("dataset", data, mode="append")  # many writes

# ✅ Right: batch write (good performance)
all_data = []
for date in dates:
    data = fetch(date)
    all_data.append(data)
combined = pd.concat(all_data, ignore_index=True)
storage.save("dataset", combined, mode="append")  # single write
```
### 6.3 Deduplication strategy

`Storage.save()` deduplicates automatically, based on the configured `key_fields`:

```python
# Implementation inside storage.py
combined = pd.concat([existing, data], ignore_index=True)
combined = combined.drop_duplicates(
    subset=["ts_code", "trade_date"],  # uses key_fields
    keep="last",  # keep the newest data
)
```
## 7. Full Example: Adding a Limit Up/Down Interface

### 7.1 Create limit_list.py

```python
"""Limit up/down list interface.

Fetch stocks that hit limit up or limit down for a specific trade date.
This is a date-based interface (recommended approach).
"""
import pandas as pd
from typing import Optional
from src.data.client import TushareClient


def get_limit_list(
    trade_date: Optional[str] = None,
    ts_code: Optional[str] = None,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
) -> pd.DataFrame:
    """Fetch limit up/down data.

    **Prefer fetching by date** (recommended):
    - use trade_date to fetch whole-market limit data for a single day
    - or use start_date + end_date to fetch a date range

    Args:
        trade_date: Trade date (YYYYMMDD format), fetches whole-market data for one day
        ts_code: Stock code (optional filter)
        start_date: Start date (YYYYMMDD format)
        end_date: End date (YYYYMMDD format)

    Returns:
        pd.DataFrame with the following columns:
        - ts_code: stock code
        - trade_date: trade date
        - name: stock name
        - close: closing price
        - pct_chg: percentage change
        - amp: amplitude
        - fc_ratio: sealed order amount / daily turnover
        - fl_ratio: sealed order volume / float shares
        - fd_amount: sealed order amount
        - first_time: first time the limit was hit
        - last_time: last time the limit was sealed
        - open_times: number of times the limit was broken open
        - strth: limit strength
        - limit: limit type (U = limit up, D = limit down)

    Example:
        >>> # Fetch whole-market limit data for one day (recommended)
        >>> data = get_limit_list(trade_date='20240115')
        >>> # Fetch a date range
        >>> data = get_limit_list(start_date='20240101', end_date='20240131')
    """
    client = TushareClient()

    params = {}
    if trade_date:
        params["trade_date"] = trade_date
    if ts_code:
        params["ts_code"] = ts_code
    if start_date:
        params["start_date"] = start_date
    if end_date:
        params["end_date"] = end_date

    data = client.query("limit_list", **params)
    return data
```
### 7.2 Register in sync.py

```python
class DataSync:
    """Data synchronization manager with full/incremental sync support."""

    DATASET_CONFIG = {
        # ... other entries ...
        "limit_list": {
            "api_name": "limit_list",
            "fetch_by": "date",  # fetched per date
            "date_field": "trade_date",
            "key_fields": ["ts_code", "trade_date"],
        },
    }

    # ... other methods ...

    def sync_limit_list(
        self,
        force_full: bool = False,
    ) -> pd.DataFrame:
        """Sync limit list data."""
        return self.sync_dataset("limit_list", force_full)


# Convenience function
def sync_limit_list(force_full: bool = False) -> pd.DataFrame:
    """Sync limit up/down data."""
    sync_manager = DataSync()
    return sync_manager.sync_limit_list(force_full)
```
### 7.3 Update __init__.py

```python
from src.data.limit_list import get_limit_list

__all__ = [
    # ... other exports ...
    "get_limit_list",
]
```
## 8. Testing Specification

### 8.1 Test file layout

```
tests/
├── test_sync.py             # tests for the sync module
├── test_daily.py            # tests for the daily module
└── test_{new_module}.py     # tests for new modules
```
### 8.2 Test template

```python
"""Tests for {module_name} module."""
import pytest
from unittest.mock import patch, MagicMock
import pandas as pd
from src.data.{module_name} import get_{data_type}


class Test{DataType}:
    """Test cases for {data_type} data fetching."""

    @patch("src.data.{module_name}.TushareClient")
    def test_get_{data_type}_by_date(self, mock_client_class):
        """Test fetching data by date."""
        # Setup mock
        mock_client = MagicMock()
        mock_client_class.return_value = mock_client
        mock_client.query.return_value = pd.DataFrame({
            "ts_code": ["000001.SZ"],
            "trade_date": ["20240115"],
            # ... other columns ...
        })

        # Call function
        result = get_{data_type}(trade_date="20240115")

        # Verify
        assert not result.empty
        mock_client.query.assert_called_once_with(
            "{api_name}",
            trade_date="20240115",
        )

    @patch("src.data.{module_name}.TushareClient")
    def test_get_{data_type}_by_stock(self, mock_client_class):
        """Test fetching data by stock code."""
        # Setup mock
        mock_client = MagicMock()
        mock_client_class.return_value = mock_client
        mock_client.query.return_value = pd.DataFrame({
            "ts_code": ["000001.SZ"],
            "trade_date": ["20240115"],
            # ... other columns ...
        })

        # Call function
        result = get_{data_type}(
            ts_code="000001.SZ",
            start_date="20240101",
            end_date="20240131",
        )

        # Verify
        assert not result.empty
        mock_client.query.assert_called_once()
```
## 9. Checklist

Before submitting a new interface, confirm the following:

### 9.1 File structure
- [ ] File lives at `src/data/{data_type}.py`
- [ ] `src/data/__init__.py` updated to export the public interface
- [ ] Test file `tests/test_{data_type}.py` created

### 9.2 Interface implementation
- [ ] The fetch function uses `TushareClient`
- [ ] The function has a complete Google-style docstring
- [ ] Date parameters use the `YYYYMMDD` format
- [ ] The returned DataFrame contains `ts_code` and `trade_date` columns
- [ ] Date-based fetching is implemented first (if the API supports it)

### 9.3 Sync integration
- [ ] Registered in `DataSync.DATASET_CONFIG`
- [ ] `fetch_by` set correctly ("date" or "stock")
- [ ] `date_field` and `key_fields` set correctly
- [ ] A dedicated sync method implemented, or the generic one reused
- [ ] Incremental update logic is correct (checks the latest local date)

### 9.4 Storage optimization
- [ ] All data written in one batch (not row by row)
- [ ] Incremental saves use `storage.save(mode="append")`
- [ ] Deduplication fields configured correctly

### 9.5 Testing
- [ ] Unit tests written
- [ ] TushareClient is mocked
- [ ] Tests cover both success and failure cases
## 10. FAQ

### Q1: What if the API's date column is not named trade_date?

Rename it before returning:

```python
data = client.query("api_name", **params)
if "end_date" in data.columns:
    data = data.rename(columns={"end_date": "trade_date"})
return data
```
### Q2: How do I handle pagination (limit/offset)?

Tushare Pro APIs usually do not require manual pagination, but if needed:

```python
all_data = []
offset = 0
limit = 5000

while True:
    data = client.query(
        "api_name",
        trade_date=trade_date,
        limit=limit,
        offset=offset,
    )
    if not data.empty:
        all_data.append(data)
    if data.empty or len(data) < limit:
        break
    offset += limit

return pd.concat(all_data, ignore_index=True) if all_data else pd.DataFrame()
```
### Q3: How do I handle interfaces that need extra parameters?

Add them to the function signature and pass them through to client.query:

```python
def get_data(
    ts_code: str,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
    fields: Optional[list] = None,  # extra parameter
) -> pd.DataFrame:
    client = TushareClient()

    params = {"ts_code": ts_code}
    if start_date:
        params["start_date"] = start_date
    if end_date:
        params["end_date"] = end_date
    if fields:
        params["fields"] = ",".join(fields)

    return client.query("api_name", **params)
```
### Q4: How do I handle data without a trade_date column?

If the data genuinely has no date column (e.g. static data), either:
1. classify it as a "special interface" and manage it independently, or
2. add a `sync_date` column recording when it was synced
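A minimal sketch of option 2 (the `sync_date` name follows the convention above; the helper function itself is illustrative, not part of the project):

```python
from datetime import date

import pandas as pd

def tag_sync_date(data: pd.DataFrame) -> pd.DataFrame:
    """Stamp static data with the date it was synced, in YYYYMMDD format."""
    data = data.copy()  # do not mutate the caller's frame
    data["sync_date"] = date.today().strftime("%Y%m%d")
    return data
```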
### Q5: What if the API does not support fetching by date?

If the API only supports fetching by stock:
1. set `fetch_by: "stock"` in `DATASET_CONFIG`
2. sync it with the `_sync_all_stocks` method
3. note in the docstring that this is a stock-based interface
---

**Last updated**: 2026-02-01
@@ -6,7 +6,7 @@ Provides simplified interfaces for fetching and storing Tushare data.
 from src.data.config import Config, get_config
 from src.data.client import TushareClient
 from src.data.storage import Storage
-from src.data.stock_basic import get_stock_basic, sync_all_stocks
+from src.data.api_wrappers import get_stock_basic, sync_all_stocks
 
 __all__ = [
     "Config",
244
src/data/api_wrappers/API_INTERFACE_SPEC.md
Normal file
@@ -0,0 +1,244 @@
# ProStock Data Interface Wrapper Specification

## 1. Overview

This document defines the standard for adding new Tushare API wrappers. All non-special interfaces must follow it to ensure:
- a consistent code style
- automatic sync support
- consistent incremental-update logic
- reduced storage write pressure
## 2. Interface Classification

### 2.1 Special interfaces (excluded from unified sync)

The following interfaces have their own sync logic and do not participate in the automatic sync mechanism:

| Interface type | Example | Notes |
|---------|------|------|
| Trading calendar | `trade_cal` | Global data, fetched by date range |
| Stock basic info | `stock_basic` | One-off full fetch, stored as CSV |
| Auxiliary data | industry and concept classifications | Low-frequency updates, managed independently |

### 2.2 Standard interfaces (must follow this specification)

All factor, market, and financial data fetched per stock or per date must follow this specification.
## 3. File Structure Requirements

### 3.1 File naming

```
{data_type}.py
```

Examples: `daily.py`, `moneyflow.py`, `limit_list.py`

### 3.2 File location

All interface files must live under the `src/data/` directory.

### 3.3 Export requirements

New interfaces must be exported from `src/data/__init__.py`:

```python
from src.data.{module_name} import get_{data_type}

__all__ = [
    # ... other exports ...
    "get_{data_type}",
]
```
## 4. Interface Design Specification

### 4.1 Function signature requirements

Fetch functions must return a `pd.DataFrame`, and their parameters must match one of the following:

#### 4.1.1 Date-based interfaces (preferred)

Applies to: limit up/down lists, dragon-tiger (top) lists, chip distribution, etc.

**Required signature**:

```python
def get_{data_type}(
    trade_date: Optional[str] = None,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
    ts_code: Optional[str] = None,
    # other optional parameters...
) -> pd.DataFrame:
```

**Requirements**:
- prefer `trade_date` to fetch whole-market data for a single day
- support `start_date + end_date` for range queries
- accept `ts_code` as an optional filter

#### 4.1.2 Stock-based interfaces

Applies to: daily bars, money flow, etc.

**Required signature**:

```python
def get_{data_type}(
    ts_code: str,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
    # other optional parameters...
) -> pd.DataFrame:
```
### 4.2 Docstring requirements

Every function must carry a complete Google-style docstring with:
- a description of what the function does
- an `Args` section covering every parameter
- a `Returns` section describing the columns of the returned DataFrame
- an `Example` section with a usage example

### 4.3 Date format requirements

- All date parameters and return values use the `YYYYMMDD` string format
- Always use `trade_date` as the date column name
- If the API returns a different date column name (e.g. `date`, `end_date`), it must be renamed to `trade_date` before returning

### 4.4 Stock code requirements

- Always use `ts_code` as the stock code column name
- Format: `{code}.{exchange}`, e.g. `000001.SZ`, `600000.SH`
### 4.5 Token-bucket rate limiting requirements

All API calls must go through `TushareClient`, which automatically enforces the token-bucket rate limit.
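The project's actual limiter lives in `rate_limiter.py`; as an illustration of the token-bucket idea it enforces, here is a minimal self-contained sketch (not the project's implementation — class name and parameters are made up):

```python
import threading
import time

class TokenBucket:
    """Minimal token-bucket limiter: refill at `rate` tokens/sec, up to `capacity`."""

    def __init__(self, rate: float, capacity: int) -> None:
        self.rate = rate
        self.capacity = capacity
        self.tokens = float(capacity)   # start full: allow an initial burst
        self.last = time.monotonic()
        self.lock = threading.Lock()

    def acquire(self) -> None:
        """Block until one token is available, then consume it."""
        while True:
            with self.lock:
                now = time.monotonic()
                # Refill proportionally to elapsed time, capped at capacity
                self.tokens = min(self.capacity, self.tokens + (now - self.last) * self.rate)
                self.last = now
                if self.tokens >= 1.0:
                    self.tokens -= 1.0
                    return
                wait = (1.0 - self.tokens) / self.rate
            time.sleep(wait)

# A client wrapper would call bucket.acquire() before every API request.
bucket = TokenBucket(rate=10.0, capacity=5)
for _ in range(5):
    bucket.acquire()  # the initial burst passes without sleeping
```

Calls beyond the burst capacity then block until the bucket refills, which caps the sustained request rate at `rate` per second.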
## 5. Sync Integration Specification

### 5.1 DATASET_CONFIG registration requirements

New interfaces must be registered in `DataSync.DATASET_CONFIG` with the following entries:

```python
"{new_data_type}": {
    "api_name": "{tushare_api_name}",  # Tushare API name
    "fetch_by": "date",  # "date" or "stock"
    "date_field": "trade_date",
    "key_fields": ["ts_code", "trade_date"],  # primary key used for deduplication
}
```
### 5.2 fetch_by rules

- **Prefer `"date"`**: whenever the API supports fetching whole-market data by date
- Use `"stock"` only when the API does not support fetching by date

### 5.3 Sync method requirements

Either implement a dedicated sync method or reuse the generic one:

```python
def sync_{data_type}(self, force_full: bool = False) -> pd.DataFrame:
    """Sync {data description}."""
    return self.sync_dataset("{data_type}", force_full)
```

Also provide a convenience function:

```python
def sync_{data_type}(force_full: bool = False) -> pd.DataFrame:
    """Sync {data description}."""
    sync_manager = DataSync()
    return sync_manager.sync_{data_type}(force_full)
```

### 5.4 Incremental update requirements

- Incremental update logic is mandatory (automatically check the latest local date)
- Support forced full syncs via the `force_full` parameter
## 6. Storage Specification

### 6.1 Storage backend

All data is stored in HDF5 via the `Storage` class.

### 6.2 Write strategy

**Requirement**: write all data **in one batch** after the requests complete, never row by row.

### 6.3 Deduplication requirements

Deduplicate on the columns configured in `key_fields`, which defaults to `["ts_code", "trade_date"]`.
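The effect of key-column deduplication with `keep="last"` can be seen in a small standalone example (the sample rows are made up):

```python
import pandas as pd

# Rows already on disk vs. a re-fetched batch; 000001.SZ/20240115 appears in both.
existing = pd.DataFrame({
    "ts_code": ["000001.SZ"],
    "trade_date": ["20240115"],
    "close": [10.0],
})
incoming = pd.DataFrame({
    "ts_code": ["000001.SZ", "000002.SZ"],
    "trade_date": ["20240115", "20240115"],
    "close": [10.5, 20.0],
})

combined = pd.concat([existing, incoming], ignore_index=True)
deduped = combined.drop_duplicates(subset=["ts_code", "trade_date"], keep="last")
# keep="last" means the freshly fetched row (close=10.5) wins over the stored one.
```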
## 7. Testing Specification

### 7.1 Test file requirements

A matching test file must be created: `tests/test_{data_type}.py`

### 7.2 Test coverage requirements

- test fetching by date
- test fetching by stock (if supported)
- `TushareClient` must be mocked
- cover both success and failure cases
## 8. Full Workflow for a New Interface

### 8.1 Create the interface file

1. Create `{data_type}.py` under `src/data/`
2. Implement the fetch function following Section 4

### 8.2 Register sync support

1. Register it in `DataSync.DATASET_CONFIG` in `sync.py`
2. Implement the matching sync method
3. Provide the convenience function

### 8.3 Update the exports

Export the interface function from `src/data/__init__.py`.

### 8.4 Create the tests

Create `tests/test_{data_type}.py` covering the key scenarios.
## 9. Checklist

### 9.1 File structure
- [ ] File lives at `src/data/{data_type}.py`
- [ ] `src/data/__init__.py` updated to export the public interface
- [ ] Test file `tests/test_{data_type}.py` created

### 9.2 Interface implementation
- [ ] The fetch function uses `TushareClient`
- [ ] The function has a complete Google-style docstring
- [ ] Date parameters use the `YYYYMMDD` format
- [ ] The returned DataFrame contains `ts_code` and `trade_date` columns
- [ ] Date-based fetching is implemented first (if the API supports it)

### 9.3 Sync integration
- [ ] Registered in `DataSync.DATASET_CONFIG`
- [ ] `fetch_by` set correctly ("date" or "stock")
- [ ] `date_field` and `key_fields` set correctly
- [ ] A dedicated sync method implemented, or the generic one reused
- [ ] Incremental update logic is correct (checks the latest local date)

### 9.4 Storage optimization
- [ ] All data written in one batch (not row by row)
- [ ] Incremental saves use `storage.save(mode="append")`
- [ ] Deduplication fields configured correctly

### 9.5 Testing
- [ ] Unit tests written
- [ ] TushareClient is mocked
- [ ] Tests cover both success and failure cases
---

**Last updated**: 2026-02-01
40
src/data/api_wrappers/__init__.py
Normal file
@@ -0,0 +1,40 @@
"""Tushare API wrapper modules.

This package contains simplified interfaces for fetching data from Tushare API.
All wrapper files follow the naming convention: api_{data_type}.py

Available APIs:
- api_daily: Daily market data (日线行情)
- api_stock_basic: Stock basic information (股票基本信息)
- api_trade_cal: Trading calendar (交易日历)

Example:
    >>> from src.data.api_wrappers import get_daily, get_stock_basic, get_trade_cal
    >>> data = get_daily('000001.SZ', start_date='20240101', end_date='20240131')
    >>> stocks = get_stock_basic()
    >>> calendar = get_trade_cal('20240101', '20240131')
"""

from src.data.api_wrappers.api_daily import get_daily
from src.data.api_wrappers.api_stock_basic import get_stock_basic, sync_all_stocks
from src.data.api_wrappers.api_trade_cal import (
    get_trade_cal,
    get_trading_days,
    get_first_trading_day,
    get_last_trading_day,
    sync_trade_cal_cache,
)

__all__ = [
    # Daily market data
    "get_daily",
    # Stock basic information
    "get_stock_basic",
    "sync_all_stocks",
    # Trade calendar
    "get_trade_cal",
    "get_trading_days",
    "get_first_trading_day",
    "get_last_trading_day",
    "sync_trade_cal_cache",
]
@@ -179,4 +179,75 @@ df = pro.query('trade_cal', start_date='20180101', end_date='20181231')
17 SSE 20180118 1
18 SSE 20180119 1
19 SSE 20180120 0
20 SSE 20180121 0
Daily indicators
API: daily_basic (can be inspected and debugged with the data tool).
Update time: between 15:00 and 17:00 on each trading day.
Description: fetches the key daily fundamental indicators for all stocks, useful for stock screening and report display. A single request returns at most 6000 rows; the full history can be pulled by looping over trading days.
Credits: at least 2000 credits are required to call this API; 5000 credits removes the total-volume limit (see the credit rules for details).

Input parameters

Name        Type  Required  Description
ts_code     str   Y         stock code (one of the two)
trade_date  str   N         trade date (one of the two)
start_date  str   N         start date (YYYYMMDD)
end_date    str   N         end date (YYYYMMDD)
Note: all dates use the YYYYMMDD format, e.g. 20181010
Output parameters

Name             Type   Description
ts_code          str    TS stock code
trade_date       str    trade date
close            float  closing price
turnover_rate    float  turnover rate (%)
turnover_rate_f  float  turnover rate (free float)
volume_ratio     float  volume ratio
pe               float  P/E (total market cap / net profit; empty when loss-making)
pe_ttm           float  P/E (TTM; empty when loss-making)
pb               float  P/B (total market cap / net assets)
ps               float  P/S
ps_ttm           float  P/S (TTM)
dv_ratio         float  dividend yield (%)
dv_ttm           float  dividend yield (TTM) (%)
total_share      float  total shares (10k shares)
float_share      float  float shares (10k shares)
free_share       float  free-float shares (10k)
total_mv         float  total market cap (10k CNY)
circ_mv          float  float market cap (10k CNY)
接口用法
|
||||
|
||||
|
||||
pro = ts.pro_api()
|
||||
|
||||
df = pro.daily_basic(ts_code='', trade_date='20180726', fields='ts_code,trade_date,turnover_rate,volume_ratio,pe,pb')
|
||||
或者
|
||||
|
||||
|
||||
df = pro.query('daily_basic', ts_code='', trade_date='20180726',fields='ts_code,trade_date,turnover_rate,volume_ratio,pe,pb')
|
||||
数据样例
|
||||
|
||||
ts_code trade_date turnover_rate volume_ratio pe pb
|
||||
0 600230.SH 20180726 2.4584 0.72 8.6928 3.7203
|
||||
1 600237.SH 20180726 1.4737 0.88 166.4001 1.8868
|
||||
2 002465.SZ 20180726 0.7489 0.72 71.8943 2.6391
|
||||
3 300732.SZ 20180726 6.7083 0.77 21.8101 3.2513
|
||||
4 600007.SH 20180726 0.0381 0.61 23.7696 2.3774
|
||||
5 300068.SZ 20180726 1.4583 0.52 27.8166 1.7549
|
||||
6 300552.SZ 20180726 2.0728 0.95 56.8004 2.9279
|
||||
7 601369.SH 20180726 0.2088 0.95 44.1163 1.8001
|
||||
8 002518.SZ 20180726 0.5814 0.76 15.1004 2.5626
|
||||
9 002913.SZ 20180726 12.1096 1.03 33.1279 2.9217
|
||||
10 601818.SH 20180726 0.1893 0.86 6.3064 0.7209
|
||||
11 600926.SH 20180726 0.6065 0.46 9.1772 0.9808
|
||||
12 002166.SZ 20180726 0.7582 0.82 16.9868 3.3452
|
||||
13 600841.SH 20180726 0.3754 1.02 66.2647 2.2302
|
||||
14 300634.SZ 20180726 23.1127 1.26 120.3053 14.3168
|
||||
15 300126.SZ 20180726 1.2304 1.11 348.4306 1.5171
|
||||
16 300718.SZ 20180726 17.6612 0.92 32.0239 3.8661
|
||||
17 000708.SZ 20180726 0.5575 0.70 10.3674 1.0276
|
||||
18 002626.SZ 20180726 0.6187 0.83 22.7580 4.2446
|
||||
19 600816.SH 20180726 0.6745 0.65 11.0778 3.2214
|
||||
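Once the daily_basic rows are in a DataFrame, screening is plain pandas. A minimal sketch using a few rows from the sample above (the P/E and P/B thresholds are illustrative, not from the source):

```python
import pandas as pd

# A few rows taken from the daily_basic sample data above
df = pd.DataFrame(
    {
        "ts_code": ["600230.SH", "601818.SH", "600926.SH", "000708.SZ"],
        "trade_date": ["20180726"] * 4,
        "pe": [8.6928, 6.3064, 9.1772, 10.3674],
        "pb": [3.7203, 0.7209, 0.9808, 1.0276],
    }
)

# Simple value screen: low P/E and low P/B (thresholds are illustrative)
cheap = df[(df["pe"] < 20) & (df["pb"] < 2)]
print(sorted(cheap["ts_code"]))  # ['000708.SZ', '600926.SH', '601818.SH']
```

Because loss-making stocks have an empty (NaN) `pe`, they drop out of such comparisons automatically.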
@@ -3,6 +3,7 @@
A single function to fetch A-share daily market (A股日线行情) data from Tushare.
Supports all output fields including tor (turnover rate) and vr (volume ratio).
"""

import pandas as pd
from typing import Optional, List, Literal
from src.data.client import TushareClient
@@ -33,7 +34,7 @@ def get_daily(

    Returns:
        pd.DataFrame with daily market data containing:
        - Base fields: ts_code, trade_date, open, high, low, close, pre_close,
          change, pct_chg, vol, amount
        - Factor fields (if requested): tor, vr
        - Adjustment factor (if adjfactor=True): adjfactor
@@ -3,6 +3,7 @@
Fetch basic stock information including code, name, listing date, etc.
This is a special interface - call once to get all stocks (listed and delisted).
"""

import os
import pandas as pd
from pathlib import Path
@@ -1,4 +1,5 @@
"""Simplified HDF5 storage for data persistence."""

import os
import pandas as pd
from pathlib import Path
@@ -47,7 +48,9 @@ class Storage:
            # Merge with existing data
            existing = store[name]
            combined = pd.concat([existing, data], ignore_index=True)
            combined = combined.drop_duplicates(subset=["ts_code", "trade_date"], keep="last")
            combined = combined.drop_duplicates(
                subset=["ts_code", "trade_date"], keep="last"
            )
            store.put(name, combined, format="table")

        print(f"[Storage] Saved {len(data)} rows to {file_path}")
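The merge step above is what makes re-syncs idempotent: `drop_duplicates(keep="last")` on the `(ts_code, trade_date)` key lets freshly fetched rows overwrite their stored counterparts. A small self-contained illustration of that semantics:

```python
import pandas as pd

# Rows already in the store, plus a re-fetched batch that overlaps on 20240102
existing = pd.DataFrame(
    {"ts_code": ["000001.SZ"] * 2, "trade_date": ["20240101", "20240102"], "close": [10.0, 10.5]}
)
fresh = pd.DataFrame(
    {"ts_code": ["000001.SZ"] * 2, "trade_date": ["20240102", "20240103"], "close": [10.6, 10.8]}
)

combined = pd.concat([existing, fresh], ignore_index=True)
# keep="last": the freshly fetched row (close=10.6) wins on the overlapping key
combined = combined.drop_duplicates(subset=["ts_code", "trade_date"], keep="last")
print(combined.sort_values("trade_date")["close"].tolist())  # [10.0, 10.6, 10.8]
```

Since `fresh` is concatenated after `existing`, "last" always refers to the newly fetched data.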
@@ -57,10 +60,13 @@ class Storage:
            print(f"[Storage] Error saving {name}: {e}")
            return {"status": "error", "error": str(e)}

    def load(self, name: str,
             start_date: Optional[str] = None,
             end_date: Optional[str] = None,
             ts_code: Optional[str] = None) -> pd.DataFrame:
    def load(
        self,
        name: str,
        start_date: Optional[str] = None,
        end_date: Optional[str] = None,
        ts_code: Optional[str] = None,
    ) -> pd.DataFrame:
        """Load data from HDF5 file.

        Args:
@@ -80,14 +86,25 @@ class Storage:

        try:
            with pd.HDFStore(file_path, mode="r") as store:
                if name not in store.keys():
                keys = store.keys()
                # Handle both '/daily' and 'daily' keys
                actual_key = None
                if name in keys:
                    actual_key = name
                elif f"/{name}" in keys:
                    actual_key = f"/{name}"

                if actual_key is None:
                    return pd.DataFrame()

                data = store[name]
                data = store[actual_key]

                # Apply filters
                if start_date and end_date and "trade_date" in data.columns:
                    data = data[(data["trade_date"] >= start_date) & (data["trade_date"] <= end_date)]
                    data = data[
                        (data["trade_date"] >= start_date)
                        & (data["trade_date"] <= end_date)
                    ]

                if ts_code and "ts_code" in data.columns:
                    data = data[data["ts_code"] == ts_code]
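The range filter in `load` compares `trade_date` values as strings. That is sound only because zero-padded YYYYMMDD strings order lexicographically the same way the underlying dates do, so no datetime parsing is needed:

```python
import pandas as pd

df = pd.DataFrame(
    {"trade_date": ["20231229", "20240102", "20240131", "20240201"], "close": [9.9, 10.0, 10.4, 10.7]}
)

# Zero-padded YYYYMMDD strings sort lexicographically in date order,
# so plain string comparison implements the date-range filter
window = df[(df["trade_date"] >= "20240101") & (df["trade_date"] <= "20240131")]
print(window["trade_date"].tolist())  # ['20240102', '20240131']
```

A format like D/M/YYYY would break this invariant, which is why the project standardizes on YYYYMMDD throughout.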
src/data/sync.py
@@ -5,11 +5,15 @@ This module provides data fetching functions with intelligent sync logic:
- If local file exists: incremental update (fetch from latest date + 1 day)
- Multi-threaded concurrent fetching for improved performance
- Stop immediately on any exception
- Preview mode: check data volume and samples before actual sync

Currently supported data types:
- daily: Daily market data (with turnover rate and volume ratio)

Usage:
    # Preview sync (check data volume and samples without writing)
    preview_sync()

    # Sync all stocks (full load)
    sync_all()

@@ -18,6 +22,9 @@ Usage:

    # Force full reload
    sync_all(force_full=True)

    # Dry run (preview only, no write)
    sync_all(dry_run=True)
"""

import pandas as pd
@@ -30,8 +37,8 @@ import sys

from src.data.client import TushareClient
from src.data.storage import Storage
from src.data.daily import get_daily
from src.data.trade_cal import (
from src.data.api_wrappers import get_daily
from src.data.api_wrappers import (
    get_first_trading_day,
    get_last_trading_day,
    sync_trade_cal_cache,
@@ -114,7 +121,8 @@ class DataSync:
            List of stock codes
        """
        # Import sync_all_stocks here to avoid circular imports
        from src.data.stock_basic import sync_all_stocks, _get_csv_path
        from src.data.api_wrappers import sync_all_stocks
        from src.data.api_wrappers.api_stock_basic import _get_csv_path

        # First, ensure stock_basic.csv is up-to-date with all stocks
        print("[DataSync] Ensuring stock_basic.csv is up-to-date...")
@@ -278,6 +286,184 @@ class DataSync:
        print(f"[DataSync] Incremental sync needed from {sync_start} to {cal_last}")
        return (True, sync_start, cal_last, local_last_date)
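The incremental window that `check_sync_needed` returns boils down to: take every calendar day strictly after the latest locally stored `trade_date`. A minimal sketch of that decision, with a hypothetical in-memory trade calendar standing in for the cached one:

```python
# Hypothetical in-memory trade calendar (the real code reads the cached calendar)
calendar = ["20240129", "20240130", "20240131", "20240201", "20240202"]
local_last = "20240131"  # latest trade_date already present in local storage

# Pending trading days are those strictly after the local high-water mark
pending = [d for d in calendar if d > local_last]
if pending:
    sync_start, cal_last = pending[0], pending[-1]
    print(sync_start, cal_last)  # 20240201 20240202
else:
    print("up to date")
```

When `pending` is empty the local store already covers the calendar and the sync is skipped, which is exactly the "NOT NEEDED" branch in `preview_sync` below.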

    def preview_sync(
        self,
        force_full: bool = False,
        start_date: Optional[str] = None,
        end_date: Optional[str] = None,
        sample_size: int = 3,
    ) -> dict:
        """Preview sync data volume and samples without actually syncing.

        This method provides a preview of what would be synced, including:
        - Number of stocks to be synced
        - Date range for sync
        - Estimated total records
        - Sample data from first few stocks

        Args:
            force_full: If True, preview full sync from 20180101
            start_date: Manual start date (overrides auto-detection)
            end_date: Manual end date (defaults to today)
            sample_size: Number of sample stocks to fetch for preview (default: 3)

        Returns:
            Dictionary with preview information:
            {
                'sync_needed': bool,
                'stock_count': int,
                'start_date': str,
                'end_date': str,
                'estimated_records': int,
                'sample_data': pd.DataFrame,
                'mode': str,  # 'full' or 'incremental'
            }
        """
        print("\n" + "=" * 60)
        print("[DataSync] Preview Mode - Analyzing sync requirements...")
        print("=" * 60)

        # First, ensure trade calendar cache is up-to-date
        print("[DataSync] Syncing trade calendar cache...")
        sync_trade_cal_cache()

        # Determine date range
        if end_date is None:
            end_date = get_today_date()

        # Check if sync is needed
        sync_needed, cal_start, cal_end, local_last = self.check_sync_needed(force_full)

        if not sync_needed:
            print("\n" + "=" * 60)
            print("[DataSync] Preview Result")
            print("=" * 60)
            print("  Sync Status: NOT NEEDED")
            print("  Reason: Local data is up-to-date with trade calendar")
            print("=" * 60)
            return {
                "sync_needed": False,
                "stock_count": 0,
                "start_date": None,
                "end_date": None,
                "estimated_records": 0,
                "sample_data": pd.DataFrame(),
                "mode": "none",
            }

        # Use dates from check_sync_needed
        if cal_start and cal_end:
            sync_start_date = cal_start
            end_date = cal_end
        else:
            sync_start_date = start_date or DEFAULT_START_DATE
            if end_date is None:
                end_date = get_today_date()

        # Determine sync mode
        if force_full:
            mode = "full"
            print(f"[DataSync] Mode: FULL SYNC from {sync_start_date} to {end_date}")
        elif local_last and cal_start and sync_start_date == get_next_date(local_last):
            mode = "incremental"
            print("[DataSync] Mode: INCREMENTAL SYNC (bandwidth optimized)")
            print(f"[DataSync] Sync from: {sync_start_date} to {end_date}")
        else:
            mode = "partial"
            print(f"[DataSync] Mode: SYNC from {sync_start_date} to {end_date}")

        # Get all stock codes
        stock_codes = self.get_all_stock_codes()
        if not stock_codes:
            print("[DataSync] No stocks found to sync")
            return {
                "sync_needed": False,
                "stock_count": 0,
                "start_date": None,
                "end_date": None,
                "estimated_records": 0,
                "sample_data": pd.DataFrame(),
                "mode": "none",
            }

        stock_count = len(stock_codes)
        print(f"[DataSync] Total stocks to sync: {stock_count}")

        # Fetch sample data from first few stocks
        print(f"[DataSync] Fetching sample data from {sample_size} stocks...")
        sample_data_list = []
        sample_codes = stock_codes[:sample_size]

        for ts_code in sample_codes:
            try:
                data = self.client.query(
                    "pro_bar",
                    ts_code=ts_code,
                    start_date=sync_start_date,
                    end_date=end_date,
                    factors="tor,vr",
                )
                if not data.empty:
                    sample_data_list.append(data)
                    print(f"  - {ts_code}: {len(data)} records")
            except Exception as e:
                print(f"  - {ts_code}: Error fetching - {e}")

        # Combine sample data
        sample_df = (
            pd.concat(sample_data_list, ignore_index=True)
            if sample_data_list
            else pd.DataFrame()
        )

        # Estimate total records based on sample
        if not sample_df.empty:
            avg_records_per_stock = len(sample_df) / len(sample_data_list)
            estimated_records = int(avg_records_per_stock * stock_count)
        else:
            estimated_records = 0

        # Display preview results
        print("\n" + "=" * 60)
        print("[DataSync] Preview Result")
        print("=" * 60)
        print(f"  Sync Mode: {mode.upper()}")
        print(f"  Date Range: {sync_start_date} to {end_date}")
        print(f"  Stocks to Sync: {stock_count}")
        print(f"  Sample Stocks Checked: {len(sample_data_list)}/{sample_size}")
        print(f"  Estimated Total Records: ~{estimated_records:,}")

        if not sample_df.empty:
            print(f"\n  Sample Data Preview (first {len(sample_df)} rows):")
            print("  " + "-" * 56)
            # Display sample data in a compact format
            preview_cols = [
                "ts_code",
                "trade_date",
                "open",
                "high",
                "low",
                "close",
                "vol",
            ]
            available_cols = [c for c in preview_cols if c in sample_df.columns]
            sample_display = sample_df[available_cols].head(10)
            for idx, row in sample_display.iterrows():
                print(f"  {row.to_dict()}")
            print("  " + "-" * 56)

        print("=" * 60)

        return {
            "sync_needed": True,
            "stock_count": stock_count,
            "start_date": sync_start_date,
            "end_date": end_date,
            "estimated_records": estimated_records,
            "sample_data": sample_df,
            "mode": mode,
        }
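The record estimate above is just the sample mean scaled to the full universe: mean records per sampled stock times the number of stocks. A caller can use the returned dict as a cheap gate before a real sync; the numbers and the one-million threshold here are invented for illustration:

```python
# Hypothetical preview dict, shaped like the return value documented above
preview = {
    "sync_needed": True,
    "stock_count": 5000,
    "estimated_records": 110_000,
    "mode": "incremental",
}

# Reverse the estimate: average records per stock = estimate / stock count
avg_per_stock = preview["estimated_records"] / preview["stock_count"]
print(round(avg_per_stock))  # 22

# Illustrative gate: only proceed automatically for modest workloads
if preview["sync_needed"] and preview["estimated_records"] < 1_000_000:
    print("proceed with sync")
```

Because only `sample_size` stocks are fetched, the estimate is rough for heterogeneous universes (newly listed stocks have far fewer rows than long-listed ones).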

    def sync_single_stock(
        self,
        ts_code: str,
@@ -320,6 +506,7 @@ class DataSync:
        start_date: Optional[str] = None,
        end_date: Optional[str] = None,
        max_workers: Optional[int] = None,
        dry_run: bool = False,
    ) -> Dict[str, pd.DataFrame]:
        """Sync daily data for all stocks in local storage.

@@ -337,9 +524,10 @@ class DataSync:
            start_date: Manual start date (overrides auto-detection)
            end_date: Manual end date (defaults to today)
            max_workers: Number of worker threads (default: 10)
            dry_run: If True, only preview what would be synced without writing data

        Returns:
            Dict mapping ts_code to DataFrame (empty if sync skipped)
            Dict mapping ts_code to DataFrame (empty if sync skipped or dry_run)
        """
        print("\n" + "=" * 60)
        print("[DataSync] Starting daily data sync...")
@@ -378,11 +566,14 @@ class DataSync:

        # Determine sync mode
        if force_full:
            mode = "full"
            print(f"[DataSync] Mode: FULL SYNC from {sync_start_date} to {end_date}")
        elif local_last and cal_start and sync_start_date == get_next_date(local_last):
            mode = "incremental"
            print("[DataSync] Mode: INCREMENTAL SYNC (bandwidth optimized)")
            print(f"[DataSync] Sync from: {sync_start_date} to {end_date}")
        else:
            mode = "partial"
            print(f"[DataSync] Mode: SYNC from {sync_start_date} to {end_date}")

        # Get all stock codes
@@ -394,6 +585,17 @@ class DataSync:
        print(f"[DataSync] Total stocks to sync: {len(stock_codes)}")
        print(f"[DataSync] Using {max_workers or self.max_workers} worker threads")

        # Handle dry run mode
        if dry_run:
            print("\n" + "=" * 60)
            print("[DataSync] DRY RUN MODE - No data will be written")
            print("=" * 60)
            print(f"  Would sync {len(stock_codes)} stocks")
            print(f"  Date range: {sync_start_date} to {end_date}")
            print(f"  Mode: {mode}")
            print("=" * 60)
            return {}

        # Reset stop flag for new sync
        self._stop_flag.set()
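The "stop immediately on any exception" behavior mentioned in the module docstring is typically built on a shared `threading.Event` that workers check between units of work. A minimal single-threaded sketch of that pattern (the names and the `"BAD"` sentinel are hypothetical, not from the commit):

```python
import threading

stop_flag = threading.Event()
results = []

def worker(codes):
    for code in codes:
        if stop_flag.is_set():  # cooperative stop: bail out between units of work
            return
        results.append(code)
        if code == "BAD":       # simulate a failure that requests a global stop
            stop_flag.set()

worker(["000001.SZ", "BAD", "000002.SZ"])
print(results)  # ['000001.SZ', 'BAD']
```

With a real `ThreadPoolExecutor`, every worker polls the same event, so a failure in one thread stops the others at their next check rather than mid-request.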
@@ -492,11 +694,62 @@ class DataSync:
# Convenience functions


def preview_sync(
    force_full: bool = False,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
    sample_size: int = 3,
    max_workers: Optional[int] = None,
) -> dict:
    """Preview sync data volume and samples without actually syncing.

    This is the recommended way to check what would be synced before
    running the actual synchronization.

    Args:
        force_full: If True, preview full sync from 20180101
        start_date: Manual start date (overrides auto-detection)
        end_date: Manual end date (defaults to today)
        sample_size: Number of sample stocks to fetch for preview (default: 3)
        max_workers: Number of worker threads (not used in preview, for API compatibility)

    Returns:
        Dictionary with preview information:
        {
            'sync_needed': bool,
            'stock_count': int,
            'start_date': str,
            'end_date': str,
            'estimated_records': int,
            'sample_data': pd.DataFrame,
            'mode': str,  # 'full', 'incremental', 'partial', or 'none'
        }

    Example:
        >>> # Preview what would be synced
        >>> preview = preview_sync()
        >>>
        >>> # Preview full sync
        >>> preview = preview_sync(force_full=True)
        >>>
        >>> # Preview with more samples
        >>> preview = preview_sync(sample_size=5)
    """
    sync_manager = DataSync(max_workers=max_workers)
    return sync_manager.preview_sync(
        force_full=force_full,
        start_date=start_date,
        end_date=end_date,
        sample_size=sample_size,
    )

def sync_all(
    force_full: bool = False,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
    max_workers: Optional[int] = None,
    dry_run: bool = False,
) -> Dict[str, pd.DataFrame]:
    """Sync daily data for all stocks.

@@ -507,6 +760,7 @@ def sync_all(
        start_date: Manual start date (YYYYMMDD)
        end_date: Manual end date (defaults to today)
        max_workers: Number of worker threads (default: 10)
        dry_run: If True, only preview what would be synced without writing data

    Returns:
        Dict mapping ts_code to DataFrame
@@ -526,12 +780,16 @@ def sync_all(
        >>>
        >>> # Custom thread count
        >>> result = sync_all(max_workers=20)
        >>>
        >>> # Dry run (preview only)
        >>> result = sync_all(dry_run=True)
    """
    sync_manager = DataSync(max_workers=max_workers)
    return sync_manager.sync_all(
        force_full=force_full,
        start_date=start_date,
        end_date=end_date,
        dry_run=dry_run,
    )

@@ -540,11 +798,32 @@ if __name__ == "__main__":
    print("Data Sync Module")
    print("=" * 60)
    print("\nUsage:")
    print("  from src.data.sync import sync_all")
    print("  from src.data.sync import sync_all, preview_sync")
    print("")
    print("  # Preview before sync (recommended)")
    print("  preview = preview_sync()")
    print("")
    print("  # Dry run (preview only)")
    print("  result = sync_all(dry_run=True)")
    print("")
    print("  # Actual sync")
    print("  result = sync_all()  # Incremental sync")
    print("  result = sync_all(force_full=True)  # Full reload")
    print("\n" + "=" * 60)

    # Run sync
    result = sync_all()
    print(f"\nSynced {len(result)} stocks")
    # Run preview first
    print("\n[Main] Running preview first...")
    preview = preview_sync()

    if preview["sync_needed"]:
        # Ask for confirmation
        print("\n" + "=" * 60)
        response = input("Proceed with sync? (y/n): ").strip().lower()
        if response in ("y", "yes"):
            print("\n[Main] Starting actual sync...")
            result = sync_all()
            print(f"\nSynced {len(result)} stocks")
        else:
            print("\n[Main] Sync cancelled by user")
    else:
        print("\n[Main] No sync needed - data is up to date")