"""Trade calendar interface. Fetch trading calendar data from Tushare to determine market open/close dates. With local caching for performance optimization. """ import pandas as pd from typing import Optional, Literal from pathlib import Path from src.data.client import TushareClient from src.config.settings import get_settings # Module-level flag to track if cache has been synced in this session _cache_synced = False # Trading calendar cache file path def _get_cache_path() -> Path: """Get the cache file path for trade calendar.""" cfg = get_settings() return cfg.data_path_resolved / "trade_cal.h5" def _save_to_cache(data: pd.DataFrame) -> None: """Save trade calendar data to local cache. Args: data: Trade calendar DataFrame """ if data.empty: return cache_path = _get_cache_path() cache_path.parent.mkdir(parents=True, exist_ok=True) try: with pd.HDFStore(cache_path, mode="a") as store: store.put("trade_cal", data, format="table") print(f"[trade_cal] Saved {len(data)} records to cache: {cache_path}") except Exception as e: print(f"[trade_cal] Error saving to cache: {e}") def _load_from_cache() -> pd.DataFrame: """Load trade calendar data from local cache. Returns: Trade calendar DataFrame or empty DataFrame if cache doesn't exist """ cache_path = _get_cache_path() if not cache_path.exists(): return pd.DataFrame() try: with pd.HDFStore(cache_path, mode="r") as store: # HDF5 keys include leading slash (e.g., '/trade_cal') if "/trade_cal" in store.keys(): data = store["/trade_cal"] print(f"[trade_cal] Loaded {len(data)} records from cache") return data except Exception as e: print(f"[trade_cal] Error loading from cache: {e}") return pd.DataFrame() def _get_cached_date_range() -> tuple[Optional[str], Optional[str]]: """Get the date range of cached trade calendar. Returns: Tuple of (min_date, max_date) or (None, None) if cache empty """ data = _load_from_cache() if data.empty or "cal_date" not in data.columns: return (None, None) return (str(data["cal_date"].min()), str(data["cal_date"].max())) def sync_trade_cal_cache( start_date: str = "20180101", end_date: Optional[str] = None, force: bool = False, ) -> pd.DataFrame: """Sync trade calendar data to local cache with incremental updates. This function checks if we have cached data and only fetches new data from the last cached date onwards. Args: start_date: Initial start date for full sync (default: 20180101) end_date: End date (defaults to today) force: If True, force sync even if already synced in this session Returns: Full trade calendar DataFrame (cached + new) """ global _cache_synced # Skip if already synced in this session (unless forced) if _cache_synced and not force: return _load_from_cache() if end_date is None: from datetime import datetime end_date = datetime.now().strftime("%Y%m%d") client = TushareClient() # Check cached data range cached_min, cached_max = _get_cached_date_range() if cached_min and cached_max: print(f"[trade_cal] Cache found: {cached_min} to {cached_max}") # Only fetch new data after the cached max date fetch_start = str(int(cached_max) + 1) print(f"[trade_cal] Fetching incremental data from {fetch_start} to {end_date}") if int(fetch_start) > int(end_date): print("[trade_cal] Cache is up-to-date, no new data needed") return _load_from_cache() # Fetch new data new_data = client.query( "trade_cal", start_date=fetch_start, end_date=end_date, exchange="SSE", ) if new_data.empty: print("[trade_cal] No new data returned") return _load_from_cache() print(f"[trade_cal] Fetched {len(new_data)} new records") # Load cached data and merge cached_data = _load_from_cache() if not cached_data.empty: combined = pd.concat([cached_data, new_data], ignore_index=True) # Remove duplicates by cal_date combined = combined.drop_duplicates( subset=["cal_date", "exchange"], keep="first" ) combined = combined.sort_values("cal_date").reset_index(drop=True) else: combined = new_data # Save combined data to cache # Mark as synced to avoid redundant syncs in this session _cache_synced = True _save_to_cache(combined) return combined else: # No cache, fetch all data print(f"[trade_cal] No cache found, fetching from {start_date} to {end_date}") data = client.query( "trade_cal", start_date=start_date, end_date=end_date, exchange="SSE", ) if data.empty: print("[trade_cal] No data returned") return data # Mark as synced to avoid redundant syncs in this session _cache_synced = True _save_to_cache(data) return data def get_trade_cal( start_date: str, end_date: str, exchange: Literal["SSE", "SZSE", "BSE"] = "SSE", is_open: Optional[Literal["0", "1"]] = None, use_cache: bool = True, ) -> pd.DataFrame: """Fetch trading calendar data with optional local caching. This interface retrieves trading calendar information including whether each date is a trading day. Uses cached data when available to reduce API calls and improve performance. Args: start_date: Start date in YYYYMMDD format end_date: End date in YYYYMMDD format exchange: Exchange - SSE (Shanghai), SZSE (Shenzhen), BSE (Beijing) is_open: Open status - "1" for trading day, "0" for non-trading day use_cache: Whether to use and update local cache (default: True) Returns: pd.DataFrame with trade calendar containing: - cal_date: Calendar date (YYYYMMDD) - exchange: Exchange code - is_open: Whether it's a trading day (1/0) - pretrade_date: Previous trading day Example: >>> # Get all trading days in January 2024 >>> cal = get_trade_cal('20240101', '20240131') >>> trading_days = cal[cal['is_open'] == '1'] >>> >>> # Get first and last trading day of a period >>> cal = get_trade_cal('20180101', '20240101') >>> first_trade_day = cal[cal['is_open'] == '1'].iloc[0]['cal_date'] >>> last_trade_day = cal[cal['is_open'] == '1'].iloc[-1]['cal_date'] """ # Use cache if enabled if use_cache and exchange == "SSE": # Sync cache first (incremental) sync_trade_cal_cache() # Load from cache and filter by date range cached_data = _load_from_cache() if not cached_data.empty and "cal_date" in cached_data.columns: # Filter by date range and exchange filtered = cached_data[ (cached_data["cal_date"] >= start_date) & (cached_data["cal_date"] <= end_date) & (cached_data["exchange"] == exchange) ] # Apply is_open filter if specified if is_open is not None: # Handle type mismatch: HDF5 stores is_open as int, but API returns str filtered = filtered[filtered["is_open"].astype(str) == str(is_open)] if not filtered.empty: print(f"[get_trade_cal] Retrieved {len(filtered)} records from cache") return filtered # Fallback to API if cache not available or disabled client = TushareClient() # Build parameters params = { "start_date": start_date, "end_date": end_date, "exchange": exchange, } if is_open is not None: params["is_open"] = is_open # Fetch data data = client.query("trade_cal", **params) if data.empty: print("[get_trade_cal] No data returned") return data def get_trading_days( start_date: str, end_date: str, exchange: Literal["SSE", "SZSE", "BSE"] = "SSE", ) -> list: """Get list of trading days in a date range. Args: start_date: Start date in YYYYMMDD format end_date: End date in YYYYMMDD format exchange: Exchange code Returns: List of trading dates (YYYYMMDD strings) """ cal = get_trade_cal(start_date, end_date, exchange=exchange, is_open="1") if cal.empty: return [] return cal["cal_date"].tolist() def get_first_trading_day( start_date: str, end_date: str, exchange: Literal["SSE", "SZSE", "BSE"] = "SSE", ) -> Optional[str]: """Get the first trading day in a date range. Args: start_date: Start date in YYYYMMDD format end_date: End date in YYYYMMDD format exchange: Exchange code Returns: First trading date (YYYYMMDD) or None if no trading days """ trading_days = get_trading_days(start_date, end_date, exchange) if not trading_days: return None # Return the earliest trading day return min(trading_days) def get_last_trading_day( start_date: str, end_date: str, exchange: Literal["SSE", "SZSE", "BSE"] = "SSE", ) -> Optional[str]: """Get the last trading day in a date range. Args: start_date: Start date in YYYYMMDD format end_date: End date in YYYYMMDD format exchange: Exchange code Returns: Last trading date (YYYYMMDD) or None if no trading days """ trading_days = get_trading_days(start_date, end_date, exchange) if not trading_days: return None # Return the latest trading day return max(trading_days) if __name__ == "__main__": # Example usage start = "20180101" end = "20240101" print(f"Trade calendar from {start} to {end}") cal = get_trade_cal(start, end) print(f"Total records: {len(cal)}") first_day = get_first_trading_day(start, end) last_day = get_last_trading_day(start, end) print(f"First trading day: {first_day}") print(f"Last trading day: {last_day}")