# ProStock/src/data/sync.py
"""Data synchronization module.
This module provides data fetching functions with intelligent sync logic:
- If local file doesn't exist: fetch all data (full load from 20180101)
- If local file exists: incremental update (fetch from latest date + 1 day)
- Multi-threaded concurrent fetching for improved performance
- Stop immediately on any exception
Currently supported data types:
- daily: Daily market data (with turnover rate and volume ratio)
Usage:
# Sync all stocks (full load)
sync_all()
# Sync all stocks (incremental)
sync_all()
# Force full reload
sync_all(force_full=True)
"""
import pandas as pd
from typing import Optional, Dict, Callable
from datetime import datetime, timedelta
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed
import threading
import sys
from src.data.client import TushareClient
from src.data.storage import Storage
from src.data.daily import get_daily
from src.data.trade_cal import (
get_first_trading_day,
get_last_trading_day,
sync_trade_cal_cache,
)
# Default full sync start date
DEFAULT_START_DATE = "20180101"

# Today's date in YYYYMMDD format, captured once at import time.
# Kept for backward compatibility; prefer get_today_date(), which is
# evaluated at call time.
TODAY = datetime.now().strftime("%Y%m%d")


def get_today_date() -> str:
    """Return today's date in YYYYMMDD format, evaluated at call time.

    BUGFIX: this previously returned the module-level TODAY constant,
    which is frozen at import time and goes stale in a long-running
    process (e.g. a scheduler that keeps the module loaded across
    midnight). Recomputing on every call keeps sync date ranges correct.
    """
    return datetime.now().strftime("%Y%m%d")
def get_next_date(date_str: str) -> str:
    """Return the calendar day immediately following *date_str*.

    Args:
        date_str: Date in YYYYMMDD format

    Returns:
        The next calendar date in YYYYMMDD format
    """
    parsed = datetime.strptime(date_str, "%Y%m%d")
    return (parsed + timedelta(days=1)).strftime("%Y%m%d")
class DataSync:
    """Data synchronization manager with full/incremental sync support."""

    # Default number of worker threads
    DEFAULT_MAX_WORKERS = 10

    def __init__(self, max_workers: Optional[int] = None):
        """Initialize sync manager.

        Args:
            max_workers: Number of worker threads (default: 10)
        """
        self.storage = Storage()
        self.client = TushareClient()
        self.max_workers = max_workers or self.DEFAULT_MAX_WORKERS
        # NOTE: Event semantics are inverted relative to the name —
        # set() means "keep running", clear() means "stop requested".
        self._stop_flag = threading.Event()
        self._stop_flag.set()  # Initially not stopped
        self._cached_daily_data: Optional[pd.DataFrame] = None  # Cache for daily data

    def _load_daily_data(self) -> pd.DataFrame:
        """Load daily data from storage with caching.

        This method caches the daily data in memory to avoid repeated disk
        reads. Call clear_cache() to force reload.

        Returns:
            DataFrame with daily data (cached or loaded from storage)
        """
        if self._cached_daily_data is None:
            self._cached_daily_data = self.storage.load("daily")
        return self._cached_daily_data

    def clear_cache(self) -> None:
        """Clear the cached daily data to force reload on next access."""
        self._cached_daily_data = None

    def get_all_stock_codes(self, only_listed: bool = True) -> list:
        """Get all stock codes from local storage.

        This function prioritizes stock_basic.csv to ensure all stocks
        are included for backtesting to avoid look-ahead bias.

        Args:
            only_listed: If True, only return currently listed stocks (L status).
                Set to False to include delisted stocks (for full backtest).

        Returns:
            List of stock codes
        """
        # Import sync_all_stocks here to avoid circular imports
        from src.data.stock_basic import sync_all_stocks, _get_csv_path

        # First, ensure stock_basic.csv is up-to-date with all stocks
        print("[DataSync] Ensuring stock_basic.csv is up-to-date...")
        sync_all_stocks()
        # Get from stock_basic.csv file
        stock_csv_path = _get_csv_path()
        if stock_csv_path.exists():
            print(f"[DataSync] Reading stock_basic from CSV: {stock_csv_path}")
            try:
                stock_df = pd.read_csv(stock_csv_path, encoding="utf-8-sig")
                if not stock_df.empty and "ts_code" in stock_df.columns:
                    # Filter by list_status if only_listed is True
                    if only_listed and "list_status" in stock_df.columns:
                        listed_stocks = stock_df[stock_df["list_status"] == "L"]
                        codes = listed_stocks["ts_code"].unique().tolist()
                        total = len(stock_df["ts_code"].unique())
                        print(
                            f"[DataSync] Found {len(codes)} listed stocks (filtered from {total} total)"
                        )
                    else:
                        codes = stock_df["ts_code"].unique().tolist()
                        print(
                            f"[DataSync] Found {len(codes)} stock codes from stock_basic.csv"
                        )
                    return codes
                else:
                    # CLEANUP: dropped a stray f-prefix (no placeholders).
                    print(
                        "[DataSync] stock_basic.csv exists but no ts_code column or empty"
                    )
            except Exception as e:
                # Best-effort: fall through to the daily-data fallback below.
                print(f"[DataSync] Error reading stock_basic.csv: {e}")
        # Fallback: try daily storage if stock_basic not available (using cached data)
        print("[DataSync] stock_basic.csv not available, falling back to daily data...")
        daily_data = self._load_daily_data()
        if not daily_data.empty and "ts_code" in daily_data.columns:
            codes = daily_data["ts_code"].unique().tolist()
            print(f"[DataSync] Found {len(codes)} stock codes from daily data")
            return codes
        print("[DataSync] No stock codes found in local storage")
        return []

    def get_global_last_date(self) -> Optional[str]:
        """Get the global last trade date across all stocks.

        Returns:
            Last trade date string or None if no local data
        """
        daily_data = self._load_daily_data()
        if daily_data.empty or "trade_date" not in daily_data.columns:
            return None
        return str(daily_data["trade_date"].max())

    def get_global_first_date(self) -> Optional[str]:
        """Get the global first trade date across all stocks.

        Returns:
            First trade date string or None if no local data
        """
        daily_data = self._load_daily_data()
        if daily_data.empty or "trade_date" not in daily_data.columns:
            return None
        return str(daily_data["trade_date"].min())

    def get_trade_calendar_bounds(
        self, start_date: str, end_date: str
    ) -> tuple[Optional[str], Optional[str]]:
        """Get the first and last trading day from trade calendar.

        Args:
            start_date: Start date in YYYYMMDD format
            end_date: End date in YYYYMMDD format

        Returns:
            Tuple of (first_trading_day, last_trading_day) or (None, None) on error
        """
        try:
            first_day = get_first_trading_day(start_date, end_date)
            last_day = get_last_trading_day(start_date, end_date)
            return (first_day, last_day)
        except Exception as e:
            print(f"[ERROR] Failed to get trade calendar bounds: {e}")
            return (None, None)

    def check_sync_needed(
        self, force_full: bool = False
    ) -> tuple[bool, Optional[str], Optional[str], Optional[str]]:
        """Check if sync is needed based on trade calendar.

        This method compares local data date range with trade calendar
        to determine if new data needs to be fetched.

        Logic:
        - If force_full: sync needed, return (True, 20180101, today)
        - If no local data: sync needed, return (True, 20180101, today)
        - If local data exists:
          - Get the last trading day from trade calendar
          - If local last date >= calendar last date: NO sync needed
          - Otherwise: sync needed from local_last_date + 1 to latest trade day

        Args:
            force_full: If True, always return sync needed

        Returns:
            Tuple of (sync_needed, start_date, end_date, local_last_date)
            - sync_needed: True if sync should proceed, False to skip
            - start_date: Sync start date (None if sync not needed)
            - end_date: Sync end date (None if sync not needed)
            - local_last_date: Local data last date (for incremental sync)
        """
        # If force_full, always sync
        if force_full:
            print("[DataSync] Force full sync requested")
            return (True, DEFAULT_START_DATE, get_today_date(), None)
        # Check if local data exists (using cached data)
        daily_data = self._load_daily_data()
        if daily_data.empty or "trade_date" not in daily_data.columns:
            print("[DataSync] No local data found, full sync needed")
            return (True, DEFAULT_START_DATE, get_today_date(), None)
        # Get local data last date (we only care about the latest date, not the first)
        local_last_date = str(daily_data["trade_date"].max())
        print(f"[DataSync] Local data last date: {local_last_date}")
        # Get the latest trading day from trade calendar
        today = get_today_date()
        _, cal_last = self.get_trade_calendar_bounds(DEFAULT_START_DATE, today)
        if cal_last is None:
            # Calendar unavailable — err on the side of syncing everything.
            print("[DataSync] Failed to get trade calendar, proceeding with sync")
            return (True, DEFAULT_START_DATE, today, local_last_date)
        print(f"[DataSync] Calendar last trading day: {cal_last}")
        # Compare local last date with calendar last date.
        # If local data is already up-to-date or newer, no sync needed.
        print(
            f"[DataSync] Comparing: local={local_last_date} (type={type(local_last_date).__name__}), cal={cal_last} (type={type(cal_last).__name__})"
        )
        try:
            # Integer comparison avoids string-vs-number ordering surprises.
            local_last_int = int(local_last_date)
            cal_last_int = int(cal_last)
            print(
                f"[DataSync] Comparing integers: local={local_last_int} >= cal={cal_last_int} = {local_last_int >= cal_last_int}"
            )
            if local_last_int >= cal_last_int:
                print(
                    "[DataSync] Local data is up-to-date, SKIPPING sync (no tokens consumed)"
                )
                return (False, None, None, None)
        except (ValueError, TypeError) as e:
            # Unparseable dates: fall through and attempt an incremental sync.
            print(f"[ERROR] Date comparison failed: {e}")
        # Need to sync from local_last_date + 1 to latest trade day
        sync_start = get_next_date(local_last_date)
        print(f"[DataSync] Incremental sync needed from {sync_start} to {cal_last}")
        return (True, sync_start, cal_last, local_last_date)

    def sync_single_stock(
        self,
        ts_code: str,
        start_date: str,
        end_date: str,
    ) -> pd.DataFrame:
        """Sync daily data for a single stock.

        Args:
            ts_code: Stock code
            start_date: Start date (YYYYMMDD)
            end_date: End date (YYYYMMDD)

        Returns:
            DataFrame with daily market data (empty if a stop was requested)
        """
        # Check if sync should stop (another worker hit an exception)
        if not self._stop_flag.is_set():
            return pd.DataFrame()
        try:
            # Use shared client for rate limiting across threads
            data = self.client.query(
                "pro_bar",
                ts_code=ts_code,
                start_date=start_date,
                end_date=end_date,
                factors="tor,vr",
            )
            return data
        except Exception as e:
            # Signal other threads to stop (clear == "stop requested")
            self._stop_flag.clear()
            print(f"[ERROR] Exception syncing {ts_code}: {e}")
            raise

    def sync_all(
        self,
        force_full: bool = False,
        start_date: Optional[str] = None,
        end_date: Optional[str] = None,
        max_workers: Optional[int] = None,
    ) -> Dict[str, pd.DataFrame]:
        """Sync daily data for all stocks in local storage.

        This function:
        1. Reads stock codes from local storage (daily or stock_basic)
        2. Checks trade calendar to determine if sync is needed:
           - If local data matches trade calendar bounds, SKIP sync (save tokens)
           - Otherwise, sync from local_last_date + 1 to latest trade day
        3. Uses multi-threaded concurrent fetching with rate limiting
        4. Skips updating stocks that return empty data (delisted/unavailable)
        5. Stops immediately on any exception (partial results are NOT saved)

        Args:
            force_full: If True, force full reload from 20180101
            start_date: Manual start date (overrides auto-detection)
            end_date: Manual end date (defaults to today)
            max_workers: Number of worker threads (default: 10)

        Returns:
            Dict mapping ts_code to DataFrame (empty if sync skipped)
        """
        print("\n" + "=" * 60)
        print("[DataSync] Starting daily data sync...")
        print("=" * 60)
        # First, ensure trade calendar cache is up-to-date (uses incremental sync)
        print("[DataSync] Syncing trade calendar cache...")
        sync_trade_cal_cache()
        # Remember whether the caller pinned an explicit end date BEFORE
        # defaulting it, so the calendar-derived end date cannot clobber it.
        caller_end_date = end_date
        if end_date is None:
            end_date = get_today_date()
        # Check if sync is needed based on trade calendar
        sync_needed, cal_start, cal_end, local_last = self.check_sync_needed(force_full)
        if not sync_needed:
            # Sync skipped - no tokens consumed
            print("\n" + "=" * 60)
            print("[DataSync] Sync Summary")
            print("=" * 60)
            print(" Sync: SKIPPED (local data up-to-date with trade calendar)")
            print(" Tokens saved: 0 consumed")
            print("=" * 60)
            return {}
        # BUGFIX: explicit start_date/end_date arguments now take precedence
        # over the auto-detected calendar range, as the docstring promises.
        # Previously cal_start/cal_end always won, so manual dates were
        # silently ignored whenever a sync was needed.
        if start_date is not None:
            sync_start_date = start_date
        elif cal_start:
            sync_start_date = cal_start
        else:
            sync_start_date = DEFAULT_START_DATE
        if caller_end_date is None and cal_end:
            end_date = cal_end
        # Determine sync mode (logging only)
        if force_full:
            print(f"[DataSync] Mode: FULL SYNC from {sync_start_date} to {end_date}")
        elif local_last and cal_start and sync_start_date == get_next_date(local_last):
            print("[DataSync] Mode: INCREMENTAL SYNC (bandwidth optimized)")
            print(f"[DataSync] Sync from: {sync_start_date} to {end_date}")
        else:
            print(f"[DataSync] Mode: SYNC from {sync_start_date} to {end_date}")
        # Get all stock codes
        stock_codes = self.get_all_stock_codes()
        if not stock_codes:
            print("[DataSync] No stocks found to sync")
            return {}
        print(f"[DataSync] Total stocks to sync: {len(stock_codes)}")
        print(f"[DataSync] Using {max_workers or self.max_workers} worker threads")
        # Reset stop flag for new sync (set == "keep running")
        self._stop_flag.set()
        # Multi-threaded concurrent fetching
        results: Dict[str, pd.DataFrame] = {}
        error_occurred = False

        def sync_task(ts_code: str) -> tuple[str, pd.DataFrame]:
            """Fetch one stock; exceptions propagate through the Future."""
            data = self.sync_single_stock(
                ts_code=ts_code,
                start_date=sync_start_date,
                end_date=end_date,
            )
            return (ts_code, data)

        # Use ThreadPoolExecutor for concurrent fetching
        workers = max_workers or self.max_workers
        with ThreadPoolExecutor(max_workers=workers) as executor:
            # Submit all tasks and track futures with their stock codes
            future_to_code = {
                executor.submit(sync_task, ts_code): ts_code for ts_code in stock_codes
            }
            # Process results using as_completed
            error_count = 0
            empty_count = 0
            success_count = 0
            # Create progress bar
            pbar = tqdm(total=len(stock_codes), desc="Syncing stocks")
            try:
                # Process futures as they complete
                for future in as_completed(future_to_code):
                    ts_code = future_to_code[future]
                    try:
                        _, data = future.result()
                        if data is not None and not data.empty:
                            results[ts_code] = data
                            success_count += 1
                        else:
                            # Empty data - stock may be delisted or unavailable
                            empty_count += 1
                            print(
                                f"[DataSync] Stock {ts_code}: empty data (skipped, may be delisted)"
                            )
                    except Exception as e:
                        # Exception occurred - stop all and abort
                        error_occurred = True
                        print(f"\n[ERROR] Sync aborted due to exception: {e}")
                        # Shutdown executor to stop all pending tasks
                        executor.shutdown(wait=False, cancel_futures=True)
                        raise
                    # Update progress bar
                    pbar.update(1)
            except Exception:
                # Swallowed deliberately so the summary below still prints;
                # error_occurred prevents any partial data from being saved.
                error_count = 1
                print("[DataSync] Sync stopped due to exception")
            finally:
                pbar.close()
        # Write all data at once (only if no error)
        if results and not error_occurred:
            combined_data = pd.concat(results.values(), ignore_index=True)
            self.storage.save("daily", combined_data, mode="append")
            print(f"\n[DataSync] Saved {len(combined_data)} rows to storage")
        # Summary
        print("\n" + "=" * 60)
        print("[DataSync] Sync Summary")
        print("=" * 60)
        print(f" Total stocks: {len(stock_codes)}")
        print(f" Updated: {success_count}")
        print(f" Skipped (empty/delisted): {empty_count}")
        print(
            f" Errors: {error_count} (aborted on first error)"
            if error_count
            else " Errors: 0"
        )
        print(f" Date range: {sync_start_date} to {end_date}")
        print("=" * 60)
        return results
# Convenience functions
def sync_all(
    force_full: bool = False,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
    max_workers: Optional[int] = None,
) -> Dict[str, pd.DataFrame]:
    """Sync daily data for all stocks.

    This is the main entry point for data synchronization: it builds a
    DataSync manager and delegates the whole run to it.

    Args:
        force_full: If True, force full reload from 20180101
        start_date: Manual start date (YYYYMMDD)
        end_date: Manual end date (defaults to today)
        max_workers: Number of worker threads (default: 10)

    Returns:
        Dict mapping ts_code to DataFrame

    Example:
        >>> result = sync_all()                     # first run: full load
        >>> result = sync_all()                     # later runs: incremental
        >>> result = sync_all(force_full=True)      # force full reload
        >>> result = sync_all(start_date='20240101', end_date='20240131')
        >>> result = sync_all(max_workers=20)       # custom thread count
    """
    manager = DataSync(max_workers=max_workers)
    return manager.sync_all(
        force_full=force_full,
        start_date=start_date,
        end_date=end_date,
    )
if __name__ == "__main__":
    # Print a short usage banner, then run an incremental sync.
    banner = "=" * 60
    print(banner)
    print("Data Sync Module")
    print(banner)
    print("\nUsage:")
    print(" from src.data.sync import sync_all")
    print(" result = sync_all() # Incremental sync")
    print(" result = sync_all(force_full=True) # Full reload")
    print("\n" + banner)
    # Run sync
    result = sync_all()
    print(f"\nSynced {len(result)} stocks")