refactor: 调整项目结构,新增数据同步和交易日历模块
- 移除 pyproject.toml,改用 uv 管理项目
- 新增 data/* 忽略规则
- 新增数据同步模块 sync.py
- 新增交易日历模块 trade_cal.py
- 新增相关测试用例
- 更新 API 文档
This commit is contained in:
4
.gitignore
vendored
4
.gitignore
vendored
@@ -72,5 +72,5 @@ cover/
|
||||
tmp/
|
||||
temp/
|
||||
|
||||
# 数据目录(允许跟踪)
|
||||
data/
|
||||
# 数据目录(允许跟踪,但忽略内容)
|
||||
data/*
|
||||
|
||||
@@ -1,21 +0,0 @@
|
||||
[project]
|
||||
name = "ProStock"
|
||||
version = "0.1.0"
|
||||
description = "A股量化投资框架"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10,<3.14"
|
||||
dependencies = [
|
||||
"pandas>=2.0.0",
|
||||
"numpy>=1.24.0",
|
||||
"tushare>=2.0.0",
|
||||
"pydantic>=2.0.0",
|
||||
"pydantic-settings>=2.0.0",
|
||||
"tqdm>=4.65.0",
|
||||
]
|
||||
|
||||
[build-system]
|
||||
requires = ["hatchling"]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
[tool.uv]
|
||||
package = false
|
||||
@@ -123,4 +123,60 @@ delist_date str N 退市日期
|
||||
is_hs str N 是否沪深港通标的,N否 H沪股通 S深股通
|
||||
act_name str Y 实控人名称
|
||||
act_ent_type str Y 实控人企业性质
|
||||
说明:旧版上的PE/PB/股本等字段,请在行情接口“每日指标”中获取。
|
||||
说明:旧版上的PE/PB/股本等字段,请在行情接口“每日指标”中获取。
|
||||
|
||||
|
||||
交易日历
|
||||
接口:trade_cal,可以通过数据工具调试和查看数据。
|
||||
描述:获取各大交易所交易日历数据,默认提取的是上交所
|
||||
积分:需2000积分
|
||||
|
||||
输入参数
|
||||
|
||||
名称 类型 必选 描述
|
||||
exchange str N 交易所 SSE上交所,SZSE深交所,CFFEX 中金所,SHFE 上期所,CZCE 郑商所,DCE 大商所,INE 上能源
|
||||
start_date str N 开始日期 (格式:YYYYMMDD 下同)
|
||||
end_date str N 结束日期
|
||||
is_open str N 是否交易 '0'休市 '1'交易
|
||||
输出参数
|
||||
|
||||
名称 类型 默认显示 描述
|
||||
exchange str Y 交易所 SSE上交所 SZSE深交所
|
||||
cal_date str Y 日历日期
|
||||
is_open str Y 是否交易 0休市 1交易
|
||||
pretrade_date str Y 上一个交易日
|
||||
接口示例
|
||||
|
||||
|
||||
pro = ts.pro_api()
|
||||
|
||||
|
||||
df = pro.trade_cal(exchange='', start_date='20180101', end_date='20181231')
|
||||
或者
|
||||
|
||||
|
||||
df = pro.query('trade_cal', start_date='20180101', end_date='20181231')
|
||||
数据样例
|
||||
|
||||
exchange cal_date is_open
|
||||
0 SSE 20180101 0
|
||||
1 SSE 20180102 1
|
||||
2 SSE 20180103 1
|
||||
3 SSE 20180104 1
|
||||
4 SSE 20180105 1
|
||||
5 SSE 20180106 0
|
||||
6 SSE 20180107 0
|
||||
7 SSE 20180108 1
|
||||
8 SSE 20180109 1
|
||||
9 SSE 20180110 1
|
||||
10 SSE 20180111 1
|
||||
11 SSE 20180112 1
|
||||
12 SSE 20180113 0
|
||||
13 SSE 20180114 0
|
||||
14 SSE 20180115 1
|
||||
15 SSE 20180116 1
|
||||
16 SSE 20180117 1
|
||||
17 SSE 20180118 1
|
||||
18 SSE 20180119 1
|
||||
19 SSE 20180120 0
|
||||
20 SSE 20180121 0
|
||||
550
src/data/sync.py
Normal file
550
src/data/sync.py
Normal file
@@ -0,0 +1,550 @@
|
||||
"""Data synchronization module.
|
||||
|
||||
This module provides data fetching functions with intelligent sync logic:
|
||||
- If local file doesn't exist: fetch all data (full load from 20180101)
|
||||
- If local file exists: incremental update (fetch from latest date + 1 day)
|
||||
- Multi-threaded concurrent fetching for improved performance
|
||||
- Stop immediately on any exception
|
||||
|
||||
Currently supported data types:
|
||||
- daily: Daily market data (with turnover rate and volume ratio)
|
||||
|
||||
Usage:
|
||||
# Sync all stocks (full load)
|
||||
sync_all()
|
||||
|
||||
# Sync all stocks (incremental)
|
||||
sync_all()
|
||||
|
||||
# Force full reload
|
||||
sync_all(force_full=True)
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
from typing import Optional, Dict, Callable
|
||||
from datetime import datetime, timedelta
|
||||
from tqdm import tqdm
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
import threading
|
||||
import sys
|
||||
|
||||
from src.data.client import TushareClient
|
||||
from src.data.storage import Storage
|
||||
from src.data.daily import get_daily
|
||||
from src.data.trade_cal import (
|
||||
get_first_trading_day,
|
||||
get_last_trading_day,
|
||||
sync_trade_cal_cache,
|
||||
)
|
||||
|
||||
|
||||
# Default full sync start date
|
||||
DEFAULT_START_DATE = "20180101"
|
||||
|
||||
# Today's date in YYYYMMDD format
|
||||
TODAY = datetime.now().strftime("%Y%m%d")
|
||||
|
||||
|
||||
def get_today_date() -> str:
    """Return the current date as a ``YYYYMMDD`` string.

    Computed on every call rather than returning the module-level
    ``TODAY`` constant: that constant is frozen at import time, so a
    long-running process that crosses midnight would keep seeing a
    stale date.

    Returns:
        Today's date in YYYYMMDD format.
    """
    return datetime.now().strftime("%Y%m%d")
|
||||
|
||||
|
||||
def get_next_date(date_str: str) -> str:
    """Return the calendar day immediately following *date_str*.

    Args:
        date_str: Date in YYYYMMDD format.

    Returns:
        The next calendar day, also in YYYYMMDD format.
    """
    parsed = datetime.strptime(date_str, "%Y%m%d")
    return (parsed + timedelta(days=1)).strftime("%Y%m%d")
|
||||
|
||||
|
||||
class DataSync:
|
||||
"""Data synchronization manager with full/incremental sync support."""
|
||||
|
||||
# Default number of worker threads
|
||||
DEFAULT_MAX_WORKERS = 10
|
||||
|
||||
def __init__(self, max_workers: Optional[int] = None):
|
||||
"""Initialize sync manager.
|
||||
|
||||
Args:
|
||||
max_workers: Number of worker threads (default: 10)
|
||||
"""
|
||||
self.storage = Storage()
|
||||
self.client = TushareClient()
|
||||
self.max_workers = max_workers or self.DEFAULT_MAX_WORKERS
|
||||
self._stop_flag = threading.Event()
|
||||
self._stop_flag.set() # Initially not stopped
|
||||
self._cached_daily_data: Optional[pd.DataFrame] = None # Cache for daily data
|
||||
|
||||
def _load_daily_data(self) -> pd.DataFrame:
|
||||
"""Load daily data from storage with caching.
|
||||
|
||||
This method caches the daily data in memory to avoid repeated disk reads.
|
||||
Call clear_cache() to force reload.
|
||||
|
||||
Returns:
|
||||
DataFrame with daily data (cached or loaded from storage)
|
||||
"""
|
||||
if self._cached_daily_data is None:
|
||||
self._cached_daily_data = self.storage.load("daily")
|
||||
return self._cached_daily_data
|
||||
|
||||
def clear_cache(self) -> None:
|
||||
"""Clear the cached daily data to force reload on next access."""
|
||||
self._cached_daily_data = None
|
||||
|
||||
def get_all_stock_codes(self, only_listed: bool = True) -> list:
|
||||
"""Get all stock codes from local storage.
|
||||
|
||||
This function prioritizes stock_basic.csv to ensure all stocks
|
||||
are included for backtesting to avoid look-ahead bias.
|
||||
|
||||
Args:
|
||||
only_listed: If True, only return currently listed stocks (L status).
|
||||
Set to False to include delisted stocks (for full backtest).
|
||||
|
||||
Returns:
|
||||
List of stock codes
|
||||
"""
|
||||
# Import sync_all_stocks here to avoid circular imports
|
||||
from src.data.stock_basic import sync_all_stocks, _get_csv_path
|
||||
|
||||
# First, ensure stock_basic.csv is up-to-date with all stocks
|
||||
print("[DataSync] Ensuring stock_basic.csv is up-to-date...")
|
||||
sync_all_stocks()
|
||||
|
||||
# Get from stock_basic.csv file
|
||||
stock_csv_path = _get_csv_path()
|
||||
|
||||
if stock_csv_path.exists():
|
||||
print(f"[DataSync] Reading stock_basic from CSV: {stock_csv_path}")
|
||||
try:
|
||||
stock_df = pd.read_csv(stock_csv_path, encoding="utf-8-sig")
|
||||
if not stock_df.empty and "ts_code" in stock_df.columns:
|
||||
# Filter by list_status if only_listed is True
|
||||
if only_listed and "list_status" in stock_df.columns:
|
||||
listed_stocks = stock_df[stock_df["list_status"] == "L"]
|
||||
codes = listed_stocks["ts_code"].unique().tolist()
|
||||
total = len(stock_df["ts_code"].unique())
|
||||
print(
|
||||
f"[DataSync] Found {len(codes)} listed stocks (filtered from {total} total)"
|
||||
)
|
||||
else:
|
||||
codes = stock_df["ts_code"].unique().tolist()
|
||||
print(
|
||||
f"[DataSync] Found {len(codes)} stock codes from stock_basic.csv"
|
||||
)
|
||||
return codes
|
||||
else:
|
||||
print(
|
||||
f"[DataSync] stock_basic.csv exists but no ts_code column or empty"
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"[DataSync] Error reading stock_basic.csv: {e}")
|
||||
|
||||
# Fallback: try daily storage if stock_basic not available (using cached data)
|
||||
print("[DataSync] stock_basic.csv not available, falling back to daily data...")
|
||||
daily_data = self._load_daily_data()
|
||||
if not daily_data.empty and "ts_code" in daily_data.columns:
|
||||
codes = daily_data["ts_code"].unique().tolist()
|
||||
print(f"[DataSync] Found {len(codes)} stock codes from daily data")
|
||||
return codes
|
||||
|
||||
print("[DataSync] No stock codes found in local storage")
|
||||
return []
|
||||
|
||||
def get_global_last_date(self) -> Optional[str]:
|
||||
"""Get the global last trade date across all stocks.
|
||||
|
||||
Returns:
|
||||
Last trade date string or None
|
||||
"""
|
||||
daily_data = self._load_daily_data()
|
||||
if daily_data.empty or "trade_date" not in daily_data.columns:
|
||||
return None
|
||||
return str(daily_data["trade_date"].max())
|
||||
|
||||
def get_global_first_date(self) -> Optional[str]:
|
||||
"""Get the global first trade date across all stocks.
|
||||
|
||||
Returns:
|
||||
First trade date string or None
|
||||
"""
|
||||
daily_data = self._load_daily_data()
|
||||
if daily_data.empty or "trade_date" not in daily_data.columns:
|
||||
return None
|
||||
return str(daily_data["trade_date"].min())
|
||||
|
||||
def get_trade_calendar_bounds(
|
||||
self, start_date: str, end_date: str
|
||||
) -> tuple[Optional[str], Optional[str]]:
|
||||
"""Get the first and last trading day from trade calendar.
|
||||
|
||||
Args:
|
||||
start_date: Start date in YYYYMMDD format
|
||||
end_date: End date in YYYYMMDD format
|
||||
|
||||
Returns:
|
||||
Tuple of (first_trading_day, last_trading_day) or (None, None) if error
|
||||
"""
|
||||
try:
|
||||
first_day = get_first_trading_day(start_date, end_date)
|
||||
last_day = get_last_trading_day(start_date, end_date)
|
||||
return (first_day, last_day)
|
||||
except Exception as e:
|
||||
print(f"[ERROR] Failed to get trade calendar bounds: {e}")
|
||||
return (None, None)
|
||||
|
||||
def check_sync_needed(
|
||||
self, force_full: bool = False
|
||||
) -> tuple[bool, Optional[str], Optional[str], Optional[str]]:
|
||||
"""Check if sync is needed based on trade calendar.
|
||||
|
||||
This method compares local data date range with trade calendar
|
||||
to determine if new data needs to be fetched.
|
||||
|
||||
Logic:
|
||||
- If force_full: sync needed, return (True, 20180101, today)
|
||||
- If no local data: sync needed, return (True, 20180101, today)
|
||||
- If local data exists:
|
||||
- Get the last trading day from trade calendar
|
||||
- If local last date >= calendar last date: NO sync needed
|
||||
- Otherwise: sync needed from local_last_date + 1 to latest trade day
|
||||
|
||||
Args:
|
||||
force_full: If True, always return sync needed
|
||||
|
||||
Returns:
|
||||
Tuple of (sync_needed, start_date, end_date, local_last_date)
|
||||
- sync_needed: True if sync should proceed, False to skip
|
||||
- start_date: Sync start date (None if sync not needed)
|
||||
- end_date: Sync end date (None if sync not needed)
|
||||
- local_last_date: Local data last date (for incremental sync)
|
||||
"""
|
||||
# If force_full, always sync
|
||||
if force_full:
|
||||
print("[DataSync] Force full sync requested")
|
||||
return (True, DEFAULT_START_DATE, get_today_date(), None)
|
||||
|
||||
# Check if local data exists (using cached data)
|
||||
daily_data = self._load_daily_data()
|
||||
if daily_data.empty or "trade_date" not in daily_data.columns:
|
||||
print("[DataSync] No local data found, full sync needed")
|
||||
return (True, DEFAULT_START_DATE, get_today_date(), None)
|
||||
|
||||
# Get local data last date (we only care about the latest date, not the first)
|
||||
local_last_date = str(daily_data["trade_date"].max())
|
||||
|
||||
print(f"[DataSync] Local data last date: {local_last_date}")
|
||||
|
||||
# Get the latest trading day from trade calendar
|
||||
today = get_today_date()
|
||||
_, cal_last = self.get_trade_calendar_bounds(DEFAULT_START_DATE, today)
|
||||
|
||||
if cal_last is None:
|
||||
print("[DataSync] Failed to get trade calendar, proceeding with sync")
|
||||
return (True, DEFAULT_START_DATE, today, local_last_date)
|
||||
|
||||
print(f"[DataSync] Calendar last trading day: {cal_last}")
|
||||
|
||||
# Compare local last date with calendar last date
|
||||
# If local data is already up-to-date or newer, no sync needed
|
||||
print(
|
||||
f"[DataSync] Comparing: local={local_last_date} (type={type(local_last_date).__name__}), cal={cal_last} (type={type(cal_last).__name__})"
|
||||
)
|
||||
try:
|
||||
local_last_int = int(local_last_date)
|
||||
cal_last_int = int(cal_last)
|
||||
print(
|
||||
f"[DataSync] Comparing integers: local={local_last_int} >= cal={cal_last_int} = {local_last_int >= cal_last_int}"
|
||||
)
|
||||
if local_last_int >= cal_last_int:
|
||||
print(
|
||||
"[DataSync] Local data is up-to-date, SKIPPING sync (no tokens consumed)"
|
||||
)
|
||||
return (False, None, None, None)
|
||||
except (ValueError, TypeError) as e:
|
||||
print(f"[ERROR] Date comparison failed: {e}")
|
||||
|
||||
# Need to sync from local_last_date + 1 to latest trade day
|
||||
sync_start = get_next_date(local_last_date)
|
||||
print(f"[DataSync] Incremental sync needed from {sync_start} to {cal_last}")
|
||||
return (True, sync_start, cal_last, local_last_date)
|
||||
|
||||
def sync_single_stock(
|
||||
self,
|
||||
ts_code: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
) -> pd.DataFrame:
|
||||
"""Sync daily data for a single stock.
|
||||
|
||||
Args:
|
||||
ts_code: Stock code
|
||||
start_date: Start date (YYYYMMDD)
|
||||
end_date: End date (YYYYMMDD)
|
||||
|
||||
Returns:
|
||||
DataFrame with daily market data
|
||||
"""
|
||||
# Check if sync should stop (for exception handling)
|
||||
if not self._stop_flag.is_set():
|
||||
return pd.DataFrame()
|
||||
|
||||
try:
|
||||
# Use shared client for rate limiting across threads
|
||||
data = self.client.query(
|
||||
"pro_bar",
|
||||
ts_code=ts_code,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
factors="tor,vr",
|
||||
)
|
||||
return data
|
||||
except Exception as e:
|
||||
# Set stop flag to signal other threads to stop
|
||||
self._stop_flag.clear()
|
||||
print(f"[ERROR] Exception syncing {ts_code}: {e}")
|
||||
raise
|
||||
|
||||
def sync_all(
|
||||
self,
|
||||
force_full: bool = False,
|
||||
start_date: Optional[str] = None,
|
||||
end_date: Optional[str] = None,
|
||||
max_workers: Optional[int] = None,
|
||||
) -> Dict[str, pd.DataFrame]:
|
||||
"""Sync daily data for all stocks in local storage.
|
||||
|
||||
This function:
|
||||
1. Reads stock codes from local storage (daily or stock_basic)
|
||||
2. Checks trade calendar to determine if sync is needed:
|
||||
- If local data matches trade calendar bounds, SKIP sync (save tokens)
|
||||
- Otherwise, sync from local_last_date + 1 to latest trade day (bandwidth optimized)
|
||||
3. Uses multi-threaded concurrent fetching with rate limiting
|
||||
4. Skips updating stocks that return empty data (delisted/unavailable)
|
||||
5. Stops immediately on any exception
|
||||
|
||||
Args:
|
||||
force_full: If True, force full reload from 20180101
|
||||
start_date: Manual start date (overrides auto-detection)
|
||||
end_date: Manual end date (defaults to today)
|
||||
max_workers: Number of worker threads (default: 10)
|
||||
|
||||
Returns:
|
||||
Dict mapping ts_code to DataFrame (empty if sync skipped)
|
||||
"""
|
||||
print("\n" + "=" * 60)
|
||||
print("[DataSync] Starting daily data sync...")
|
||||
print("=" * 60)
|
||||
|
||||
# First, ensure trade calendar cache is up-to-date (uses incremental sync)
|
||||
print("[DataSync] Syncing trade calendar cache...")
|
||||
sync_trade_cal_cache()
|
||||
|
||||
# Determine date range
|
||||
if end_date is None:
|
||||
end_date = get_today_date()
|
||||
|
||||
# Check if sync is needed based on trade calendar
|
||||
sync_needed, cal_start, cal_end, local_last = self.check_sync_needed(force_full)
|
||||
|
||||
if not sync_needed:
|
||||
# Sync skipped - no tokens consumed
|
||||
print("\n" + "=" * 60)
|
||||
print("[DataSync] Sync Summary")
|
||||
print("=" * 60)
|
||||
print(" Sync: SKIPPED (local data up-to-date with trade calendar)")
|
||||
print(" Tokens saved: 0 consumed")
|
||||
print("=" * 60)
|
||||
return {}
|
||||
|
||||
# Use dates from check_sync_needed (which calculates incremental start if needed)
|
||||
if cal_start and cal_end:
|
||||
sync_start_date = cal_start
|
||||
end_date = cal_end
|
||||
else:
|
||||
# Fallback to default logic
|
||||
sync_start_date = start_date or DEFAULT_START_DATE
|
||||
if end_date is None:
|
||||
end_date = get_today_date()
|
||||
|
||||
# Determine sync mode
|
||||
if force_full:
|
||||
print(f"[DataSync] Mode: FULL SYNC from {sync_start_date} to {end_date}")
|
||||
elif local_last and cal_start and sync_start_date == get_next_date(local_last):
|
||||
print(f"[DataSync] Mode: INCREMENTAL SYNC (bandwidth optimized)")
|
||||
print(f"[DataSync] Sync from: {sync_start_date} to {end_date}")
|
||||
else:
|
||||
print(f"[DataSync] Mode: SYNC from {sync_start_date} to {end_date}")
|
||||
|
||||
# Get all stock codes
|
||||
stock_codes = self.get_all_stock_codes()
|
||||
if not stock_codes:
|
||||
print("[DataSync] No stocks found to sync")
|
||||
return {}
|
||||
|
||||
print(f"[DataSync] Total stocks to sync: {len(stock_codes)}")
|
||||
print(f"[DataSync] Using {max_workers or self.max_workers} worker threads")
|
||||
|
||||
# Reset stop flag for new sync
|
||||
self._stop_flag.set()
|
||||
|
||||
# Multi-threaded concurrent fetching
|
||||
results: Dict[str, pd.DataFrame] = {}
|
||||
error_occurred = False
|
||||
exception_to_raise = None
|
||||
|
||||
def sync_task(ts_code: str) -> tuple[str, pd.DataFrame]:
|
||||
"""Task function for each stock."""
|
||||
try:
|
||||
data = self.sync_single_stock(
|
||||
ts_code=ts_code,
|
||||
start_date=sync_start_date,
|
||||
end_date=end_date,
|
||||
)
|
||||
return (ts_code, data)
|
||||
except Exception as e:
|
||||
# Re-raise to be caught by Future
|
||||
raise
|
||||
|
||||
# Use ThreadPoolExecutor for concurrent fetching
|
||||
workers = max_workers or self.max_workers
|
||||
with ThreadPoolExecutor(max_workers=workers) as executor:
|
||||
# Submit all tasks and track futures with their stock codes
|
||||
future_to_code = {
|
||||
executor.submit(sync_task, ts_code): ts_code for ts_code in stock_codes
|
||||
}
|
||||
|
||||
# Process results using as_completed
|
||||
error_count = 0
|
||||
empty_count = 0
|
||||
success_count = 0
|
||||
|
||||
# Create progress bar
|
||||
pbar = tqdm(total=len(stock_codes), desc="Syncing stocks")
|
||||
|
||||
try:
|
||||
# Process futures as they complete
|
||||
for future in as_completed(future_to_code):
|
||||
ts_code = future_to_code[future]
|
||||
|
||||
try:
|
||||
_, data = future.result()
|
||||
if data is not None and not data.empty:
|
||||
results[ts_code] = data
|
||||
success_count += 1
|
||||
else:
|
||||
# Empty data - stock may be delisted or unavailable
|
||||
empty_count += 1
|
||||
print(
|
||||
f"[DataSync] Stock {ts_code}: empty data (skipped, may be delisted)"
|
||||
)
|
||||
except Exception as e:
|
||||
# Exception occurred - stop all and abort
|
||||
error_occurred = True
|
||||
exception_to_raise = e
|
||||
print(f"\n[ERROR] Sync aborted due to exception: {e}")
|
||||
# Shutdown executor to stop all pending tasks
|
||||
executor.shutdown(wait=False, cancel_futures=True)
|
||||
raise exception_to_raise
|
||||
|
||||
# Update progress bar
|
||||
pbar.update(1)
|
||||
|
||||
except Exception:
|
||||
error_count = 1
|
||||
print("[DataSync] Sync stopped due to exception")
|
||||
finally:
|
||||
pbar.close()
|
||||
|
||||
# Write all data at once (only if no error)
|
||||
if results and not error_occurred:
|
||||
combined_data = pd.concat(results.values(), ignore_index=True)
|
||||
self.storage.save("daily", combined_data, mode="append")
|
||||
print(f"\n[DataSync] Saved {len(combined_data)} rows to storage")
|
||||
|
||||
# Summary
|
||||
print("\n" + "=" * 60)
|
||||
print("[DataSync] Sync Summary")
|
||||
print("=" * 60)
|
||||
print(f" Total stocks: {len(stock_codes)}")
|
||||
print(f" Updated: {success_count}")
|
||||
print(f" Skipped (empty/delisted): {empty_count}")
|
||||
print(
|
||||
f" Errors: {error_count} (aborted on first error)"
|
||||
if error_count
|
||||
else " Errors: 0"
|
||||
)
|
||||
print(f" Date range: {sync_start_date} to {end_date}")
|
||||
print("=" * 60)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
# Convenience functions
|
||||
|
||||
|
||||
def sync_all(
    force_full: bool = False,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
    max_workers: Optional[int] = None,
) -> Dict[str, pd.DataFrame]:
    """Sync daily data for all stocks (main entry point).

    Args:
        force_full: If True, force full reload from 20180101.
        start_date: Manual start date (YYYYMMDD).
        end_date: Manual end date (defaults to today).
        max_workers: Number of worker threads (default: 10).

    Returns:
        Dict mapping ts_code to DataFrame.

    Example:
        >>> result = sync_all()                 # first run: full load
        >>> result = sync_all()                 # later runs: incremental
        >>> result = sync_all(force_full=True)  # force full reload
        >>> result = sync_all(start_date='20240101', end_date='20240131')
        >>> result = sync_all(max_workers=20)   # custom thread count
    """
    manager = DataSync(max_workers=max_workers)
    return manager.sync_all(
        force_full=force_full,
        start_date=start_date,
        end_date=end_date,
    )
|
||||
|
||||
|
||||
if __name__ == "__main__":
    divider = "=" * 60
    print(divider)
    print("Data Sync Module")
    print(divider)
    print("\nUsage:")
    print(" from src.data.sync import sync_all")
    print(" result = sync_all() # Incremental sync")
    print(" result = sync_all(force_full=True) # Full reload")
    print("\n" + divider)

    # Run an incremental sync when executed as a script.
    result = sync_all()
    print(f"\nSynced {len(result)} stocks")
|
||||
321
src/data/trade_cal.py
Normal file
321
src/data/trade_cal.py
Normal file
@@ -0,0 +1,321 @@
|
||||
"""Trade calendar interface.
|
||||
|
||||
Fetch trading calendar data from Tushare to determine market open/close dates.
|
||||
With local caching for performance optimization.
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
from typing import Optional, Literal
|
||||
from pathlib import Path
|
||||
from src.data.client import TushareClient
|
||||
from src.data.config import get_config
|
||||
|
||||
|
||||
# Trading calendar cache file path
|
||||
def _get_cache_path() -> Path:
    """Return the HDF5 cache file path for the trade calendar."""
    config = get_config()
    return config.data_path_resolved / "trade_cal.h5"
|
||||
|
||||
|
||||
def _save_to_cache(data: pd.DataFrame) -> None:
    """Persist trade calendar data to the local HDF5 cache.

    Args:
        data: Trade calendar DataFrame; an empty frame is ignored.
    """
    if data.empty:
        return

    cache_path = _get_cache_path()
    # Make sure the data directory exists before writing.
    cache_path.parent.mkdir(parents=True, exist_ok=True)

    try:
        with pd.HDFStore(cache_path, mode="a") as store:
            store.put("trade_cal", data, format="table")
            print(f"[trade_cal] Saved {len(data)} records to cache: {cache_path}")
    except Exception as e:
        print(f"[trade_cal] Error saving to cache: {e}")
|
||||
|
||||
|
||||
def _load_from_cache() -> pd.DataFrame:
    """Load trade calendar data from the local HDF5 cache.

    Returns:
        Cached DataFrame, or an empty DataFrame when the cache file is
        missing, unreadable, or has no ``trade_cal`` key.
    """
    cache_path = _get_cache_path()
    if not cache_path.exists():
        return pd.DataFrame()

    try:
        with pd.HDFStore(cache_path, mode="r") as store:
            # BUG FIX: HDFStore.keys() returns paths with a leading "/"
            # (e.g. "/trade_cal"), so `"trade_cal" in store.keys()` was
            # always False and the cache was never actually read.
            # HDFStore's own __contains__ accepts both spellings.
            if "trade_cal" in store:
                data = store["trade_cal"]
                print(f"[trade_cal] Loaded {len(data)} records from cache")
                return data
    except Exception as e:
        print(f"[trade_cal] Error loading from cache: {e}")

    return pd.DataFrame()
|
||||
|
||||
|
||||
def _get_cached_date_range() -> tuple[Optional[str], Optional[str]]:
    """Return (min_date, max_date) of the cached calendar.

    Returns:
        Tuple of date strings, or (None, None) when the cache is empty
        or has no cal_date column.
    """
    cached = _load_from_cache()
    if cached.empty or "cal_date" not in cached.columns:
        return (None, None)
    return (str(cached["cal_date"].min()), str(cached["cal_date"].max()))
|
||||
|
||||
|
||||
def sync_trade_cal_cache(
    start_date: str = "20180101",
    end_date: Optional[str] = None,
) -> pd.DataFrame:
    """Sync trade calendar data to the local cache, incrementally.

    Only fetches dates after the newest cached record; a full fetch is
    made when no cache exists yet.

    Args:
        start_date: Initial start date for a full sync (default: 20180101).
        end_date: End date (defaults to today).

    Returns:
        Full trade calendar DataFrame (cached + newly fetched rows).
    """
    from datetime import datetime, timedelta

    if end_date is None:
        end_date = datetime.now().strftime("%Y%m%d")

    client = TushareClient()

    cached_min, cached_max = _get_cached_date_range()

    if not (cached_min and cached_max):
        # No cache yet: fetch the whole range.
        print(f"[trade_cal] No cache found, fetching from {start_date} to {end_date}")
        data = client.query(
            "trade_cal",
            start_date=start_date,
            end_date=end_date,
            exchange="SSE",
        )
        if data.empty:
            print("[trade_cal] No data returned")
            return data
        _save_to_cache(data)
        return data

    print(f"[trade_cal] Cache found: {cached_min} to {cached_max}")
    # BUG FIX: compute the next calendar day with real date arithmetic.
    # The old code did str(int(cached_max) + 1), which produces invalid
    # dates across month/year boundaries (e.g. 20231231 -> "20231232").
    next_day = datetime.strptime(cached_max, "%Y%m%d") + timedelta(days=1)
    fetch_start = next_day.strftime("%Y%m%d")
    print(f"[trade_cal] Fetching incremental data from {fetch_start} to {end_date}")

    if int(fetch_start) > int(end_date):
        print("[trade_cal] Cache is up-to-date, no new data needed")
        return _load_from_cache()

    new_data = client.query(
        "trade_cal",
        start_date=fetch_start,
        end_date=end_date,
        exchange="SSE",
    )
    if new_data.empty:
        print("[trade_cal] No new data returned")
        return _load_from_cache()

    print(f"[trade_cal] Fetched {len(new_data)} new records")

    # Merge with the existing cache, dedupe, and persist.
    cached_data = _load_from_cache()
    if cached_data.empty:
        combined = new_data
    else:
        combined = pd.concat([cached_data, new_data], ignore_index=True)
        # Remove duplicates by (cal_date, exchange), keeping cached rows.
        combined = combined.drop_duplicates(
            subset=["cal_date", "exchange"], keep="first"
        )
        combined = combined.sort_values("cal_date").reset_index(drop=True)

    _save_to_cache(combined)
    return combined
|
||||
|
||||
|
||||
def get_trade_cal(
    start_date: str,
    end_date: str,
    exchange: Literal["SSE", "SZSE", "BSE"] = "SSE",
    is_open: Optional[Literal["0", "1"]] = None,
    use_cache: bool = True,
) -> pd.DataFrame:
    """Fetch trading calendar data with optional local caching.

    Retrieves trading calendar information including whether each date
    is a trading day. Uses cached data when available (SSE only) to
    reduce API calls; falls back to the API otherwise.

    Args:
        start_date: Start date in YYYYMMDD format.
        end_date: End date in YYYYMMDD format.
        exchange: Exchange - SSE (Shanghai), SZSE (Shenzhen), BSE (Beijing).
        is_open: Open status - "1" for trading day, "0" for non-trading day.
        use_cache: Whether to use and update the local cache (default: True).

    Returns:
        pd.DataFrame with trade calendar containing:
        - cal_date: Calendar date (YYYYMMDD)
        - exchange: Exchange code
        - is_open: Whether it's a trading day (1/0)
        - pretrade_date: Previous trading day

    Example:
        >>> cal = get_trade_cal('20240101', '20240131')
        >>> trading_days = cal[cal['is_open'] == '1']
    """
    if use_cache and exchange == "SSE":
        # Bring the cache up to date first (incremental fetch).
        sync_trade_cal_cache()

        cached_data = _load_from_cache()
        if not cached_data.empty and "cal_date" in cached_data.columns:
            # BUG FIX: HDF5 may round-trip cal_date as int64, in which
            # case comparing against the str bounds raises/misbehaves.
            # Compare as strings — the same type-mismatch workaround the
            # code below already applies to is_open.
            cal_dates = cached_data["cal_date"].astype(str)
            filtered = cached_data[
                (cal_dates >= start_date)
                & (cal_dates <= end_date)
                & (cached_data["exchange"] == exchange)
            ]

            if is_open is not None:
                # Handle type mismatch: HDF5 stores is_open as int,
                # but the API returns str.
                filtered = filtered[filtered["is_open"].astype(str) == str(is_open)]

            if not filtered.empty:
                print(f"[get_trade_cal] Retrieved {len(filtered)} records from cache")
                return filtered

    # Fallback to the API when cache is unavailable or disabled.
    client = TushareClient()

    params = {
        "start_date": start_date,
        "end_date": end_date,
        "exchange": exchange,
    }
    if is_open is not None:
        params["is_open"] = is_open

    data = client.query("trade_cal", **params)
    if data.empty:
        print("[get_trade_cal] No data returned")

    return data
|
||||
|
||||
|
||||
def get_trading_days(
    start_date: str,
    end_date: str,
    exchange: Literal["SSE", "SZSE", "BSE"] = "SSE",
) -> list:
    """Return the trading days within a date range.

    Args:
        start_date: Start date in YYYYMMDD format.
        end_date: End date in YYYYMMDD format.
        exchange: Exchange code.

    Returns:
        List of trading dates (YYYYMMDD strings); empty when none.
    """
    calendar = get_trade_cal(start_date, end_date, exchange=exchange, is_open="1")
    return [] if calendar.empty else calendar["cal_date"].tolist()
|
||||
|
||||
|
||||
def get_first_trading_day(
    start_date: str,
    end_date: str,
    exchange: Literal["SSE", "SZSE", "BSE"] = "SSE",
) -> Optional[str]:
    """Get the earliest trading day in a date range.

    Args:
        start_date: Start date in YYYYMMDD format.
        end_date: End date in YYYYMMDD format.
        exchange: Exchange code.

    Returns:
        First (earliest) trading date (YYYYMMDD) or None if no trading days.
    """
    trading_days = get_trading_days(start_date, end_date, exchange)
    if not trading_days:
        return None
    # BUG FIX: the old code assumed the list was newest-first and took
    # trading_days[-1], but the cache path sorts cal_date ASCENDING, so
    # it actually returned the LAST trading day. min() is correct
    # regardless of sort order (YYYYMMDD strings sort chronologically).
    return min(trading_days)
|
||||
|
||||
|
||||
def get_last_trading_day(
    start_date: str,
    end_date: str,
    exchange: Literal["SSE", "SZSE", "BSE"] = "SSE",
) -> Optional[str]:
    """Get the last (most recent) trading day in a date range.

    Args:
        start_date: Start date in YYYYMMDD format
        end_date: End date in YYYYMMDD format
        exchange: Exchange code

    Returns:
        Last trading date (YYYYMMDD) or None if no trading days
    """
    trading_days = get_trading_days(start_date, end_date, exchange)
    if not trading_days:
        return None
    # YYYYMMDD strings sort lexicographically in chronological order, so
    # max() is the latest day no matter how the cache/API ordered the rows.
    # (The previous code indexed [0], silently assuming a descending sort.)
    return max(trading_days)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Demo: fetch the calendar for a fixed window and report its endpoints.
    demo_start, demo_end = "20180101", "20240101"

    print(f"Trade calendar from {demo_start} to {demo_end}")

    calendar = get_trade_cal(demo_start, demo_end)
    print(f"Total records: {len(calendar)}")

    print(f"First trading day: {get_first_trading_day(demo_start, demo_end)}")
    print(f"Last trading day: {get_last_trading_day(demo_start, demo_end)}")
||||
190
tests/test_daily_storage.py
Normal file
190
tests/test_daily_storage.py
Normal file
@@ -0,0 +1,190 @@
|
||||
"""Tests for data/daily.h5 storage validation.
|
||||
|
||||
Validates two key points:
|
||||
1. All stocks from stock_basic.csv are saved in daily.h5
|
||||
2. No abnormal data with very few data points (< 10 rows per stock)
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import pandas as pd
|
||||
from pathlib import Path
|
||||
from src.data.storage import Storage
|
||||
from src.data.stock_basic import _get_csv_path
|
||||
|
||||
|
||||
class TestDailyStorageValidation:
    """Test daily.h5 storage integrity and completeness.

    These tests validate an on-disk dataset rather than pure logic: they
    skip when the underlying files are absent, and fail only when files
    exist but their contents are incomplete or malformed.
    """

    @pytest.fixture
    def storage(self):
        """Create a fresh storage instance per test."""
        return Storage()

    @pytest.fixture
    def stock_basic_df(self):
        """Load stock basic data from CSV; skip if the file is missing."""
        csv_path = _get_csv_path()
        if not csv_path.exists():
            pytest.skip(f"stock_basic.csv not found at {csv_path}")
        return pd.read_csv(csv_path)

    @pytest.fixture
    def daily_df(self, storage):
        """Load daily data from HDF5; skip if unavailable or unreadable.

        Returns an empty DataFrame when the store exists but contains no
        'daily' key.
        """
        if not storage.exists("daily"):
            pytest.skip("daily.h5 not found")
        # HDF5 stores keys with leading slash, so we need to handle both '/daily' and 'daily'
        # NOTE(review): uses the private Storage._get_file_path — consider a
        # public accessor on Storage instead.
        file_path = storage._get_file_path("daily")
        try:
            with pd.HDFStore(file_path, mode="r") as store:
                if "/daily" in store.keys():
                    return store["/daily"]
                elif "daily" in store.keys():
                    return store["daily"]
                return pd.DataFrame()
        except Exception as e:
            # Any read error (corrupt file, format mismatch) skips rather than fails.
            pytest.skip(f"Error loading daily.h5: {e}")

    def test_all_stocks_saved(self, storage, stock_basic_df, daily_df):
        """Verify all stocks from stock_basic are saved in daily.h5.

        This test ensures data completeness - every stock in stock_basic
        should have corresponding data in daily.h5.
        """
        if daily_df.empty:
            pytest.fail("daily.h5 is empty")

        # Get unique stock codes from both sources
        expected_codes = set(stock_basic_df["ts_code"].dropna().unique())
        actual_codes = set(daily_df["ts_code"].dropna().unique())

        # Check for missing stocks
        missing_codes = expected_codes - actual_codes

        if missing_codes:
            missing_list = sorted(missing_codes)
            # Show first 20 missing stocks as sample
            sample = missing_list[:20]
            msg = f"Found {len(missing_codes)} stocks missing from daily.h5:\n"
            msg += f"Sample missing: {sample}\n"
            if len(missing_list) > 20:
                msg += f"... and {len(missing_list) - 20} more"
            pytest.fail(msg)

        # All stocks present
        assert len(actual_codes) > 0, "No stocks found in daily.h5"
        print(
            f"[TEST] All {len(expected_codes)} stocks from stock_basic are present in daily.h5"
        )

    def test_no_stock_with_insufficient_data(self, storage, daily_df):
        """Verify no stock has abnormally few data points (< 10 rows).

        Stocks with very few data points may indicate sync failures,
        delisted stocks not properly handled, or data corruption.
        """
        if daily_df.empty:
            pytest.fail("daily.h5 is empty")

        # Count rows per stock
        stock_counts = daily_df.groupby("ts_code").size()

        # Find stocks with less than 10 data points
        insufficient_stocks = stock_counts[stock_counts < 10]

        if not insufficient_stocks.empty:
            # Separate into categories for better reporting
            # NOTE(review): groupby(...).size() never yields a 0 count — a
            # ts_code with no rows produces no group at all — so this
            # "empty stocks" branch looks like dead code; confirm and remove.
            empty_stocks = stock_counts[stock_counts == 0]
            very_few_stocks = stock_counts[(stock_counts > 0) & (stock_counts < 10)]

            msg = f"Found {len(insufficient_stocks)} stocks with insufficient data (< 10 rows):\n"

            if not empty_stocks.empty:
                msg += f"\nEmpty stocks (0 rows): {len(empty_stocks)}\n"
                sample = sorted(empty_stocks.index[:10].tolist())
                msg += f"Sample: {sample}"

            if not very_few_stocks.empty:
                msg += f"\nVery few data points (1-9 rows): {len(very_few_stocks)}\n"
                # Show counts for these stocks
                sample = very_few_stocks.sort_values().head(20)
                msg += "Sample (ts_code: count):\n"
                for code, count in sample.items():
                    msg += f"  {code}: {count} rows\n"

            pytest.fail(msg)

        print(f"[TEST] All stocks have sufficient data (>= 10 rows)")

    def test_data_integrity_basic(self, storage, daily_df):
        """Basic data integrity checks for daily.h5."""
        if daily_df.empty:
            pytest.fail("daily.h5 is empty")

        # Check required columns exist
        required_columns = ["ts_code", "trade_date"]
        missing_columns = [
            col for col in required_columns if col not in daily_df.columns
        ]

        if missing_columns:
            pytest.fail(f"Missing required columns: {missing_columns}")

        # Check for null values in key columns
        null_ts_code = daily_df["ts_code"].isna().sum()
        null_trade_date = daily_df["trade_date"].isna().sum()

        if null_ts_code > 0:
            pytest.fail(f"Found {null_ts_code} rows with null ts_code")
        if null_trade_date > 0:
            pytest.fail(f"Found {null_trade_date} rows with null trade_date")

        print(f"[TEST] Data integrity check passed")

    def test_stock_data_coverage_report(self, storage, daily_df):
        """Generate a summary report of stock data coverage.

        This test provides visibility into data distribution without failing.
        """
        if daily_df.empty:
            pytest.skip("daily.h5 is empty - cannot generate report")

        stock_counts = daily_df.groupby("ts_code").size()

        # Calculate statistics
        total_stocks = len(stock_counts)
        min_count = stock_counts.min()
        max_count = stock_counts.max()
        median_count = stock_counts.median()
        mean_count = stock_counts.mean()

        # Distribution buckets
        very_low = (stock_counts < 10).sum()
        low = ((stock_counts >= 10) & (stock_counts < 100)).sum()
        medium = ((stock_counts >= 100) & (stock_counts < 500)).sum()
        high = (stock_counts >= 500).sum()

        report = f"""
=== Stock Data Coverage Report ===
Total stocks: {total_stocks}
Data points per stock:
  Min: {min_count}
  Max: {max_count}
  Median: {median_count:.0f}
  Mean: {mean_count:.1f}

Distribution:
  < 10 rows: {very_low} stocks ({very_low / total_stocks * 100:.1f}%)
  10-99: {low} stocks ({low / total_stocks * 100:.1f}%)
  100-499: {medium} stocks ({medium / total_stocks * 100:.1f}%)
  >= 500: {high} stocks ({high / total_stocks * 100:.1f}%)
"""
        print(report)

        # This is an informational test - it should not fail
        # But we assert to mark it as passed
        assert total_stocks > 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Run this module's tests directly: verbose output, stdout not captured.
    cli_args = [__file__, "-v", "-s"]
    pytest.main(cli_args)
|
||||
20
tests/test_tushare_api.py
Normal file
20
tests/test_tushare_api.py
Normal file
@@ -0,0 +1,20 @@
|
||||
"""Tushare API verification script - quickly creates a `pro` object for debugging.

NOTE(review): although this file lives under tests/, it is a plain script:
importing it (e.g. during pytest collection) reads config, requires a token,
and performs a live network query. Consider guarding the calls under
``if __name__ == "__main__":`` or renaming it so pytest does not collect it.
"""

import os

# Must be set before importing project config so the default data dir resolves.
os.environ.setdefault("DATA_PATH", "data")

from src.data.config import get_config
import tushare as ts

# Read the Tushare token from project configuration (config/.env.local).
config = get_config()
token = config.tushare_token

if not token:
    raise ValueError("请在 config/.env.local 中配置 TUSHARE_TOKEN")

pro = ts.pro_api(token)
# Only the first 10 characters of the token are echoed, to avoid leaking it.
print(f"pro_api 对象已创建,token: {token[:10]}...")

# Smoke query: daily bars for 000001.SZ over a fixed historical window.
df = pro.query('daily', ts_code='000001.SZ', start_date='20180702', end_date='20180718')
print(df)
|
||||
Reference in New Issue
Block a user