- 新增 api_stock_st.py,实现ST股票数据获取和日期遍历同步 - 更新 sync.py,将ST股票同步加入第7步流程 - 移除 base_sync.py 中未使用的 get_last_n_trading_days 导入
338 lines
10 KiB
Python
338 lines
10 KiB
Python
"""Trade calendar interface.

Fetch trading calendar data from Tushare to determine market open/close dates.
Uses local caching for performance optimization.
"""
|
|
|
|
import pandas as pd
|
|
from typing import Optional, Literal
|
|
from pathlib import Path
|
|
from src.data.client import TushareClient
|
|
from src.config.settings import get_settings
|
|
|
|
# Session-scoped marker: set to True by sync_trade_cal_cache() after a sync,
# so that later calls in the same process can skip the sync step.
_cache_synced: bool = False
|
|
|
|
|
|
# Trading calendar cache file path
|
|
def _get_cache_path() -> Path:
    """Return the path of the HDF5 file used to cache the trade calendar.

    The file is named ``trade_cal.h5`` and is located under the
    ``data_path_resolved`` path taken from the project settings.
    """
    data_root = get_settings().data_path_resolved
    return data_root / "trade_cal.h5"
|
|
|
|
|
|
def _save_to_cache(data: pd.DataFrame) -> None:
    """Persist the trade calendar to the local HDF5 cache.

    An empty frame is ignored. Errors raised while writing the store are
    reported on stdout and not re-raised.

    Args:
        data: Trade calendar DataFrame
    """
    if not data.empty:
        target = _get_cache_path()
        # Make sure the directory holding the cache file exists.
        target.parent.mkdir(parents=True, exist_ok=True)

        try:
            with pd.HDFStore(target, mode="a") as hdf:
                hdf.put("trade_cal", data, format="table")
            print(f"[trade_cal] Saved {len(data)} records to cache: {target}")
        except Exception as exc:
            print(f"[trade_cal] Error saving to cache: {exc}")
|
|
|
|
|
|
def _load_from_cache() -> pd.DataFrame:
    """Read the trade calendar back from the local HDF5 cache.

    Errors raised while reading the store are reported on stdout and not
    re-raised.

    Returns:
        The cached trade calendar, or an empty DataFrame when the cache file
        is missing, has no ``/trade_cal`` key, or cannot be read.
    """
    source = _get_cache_path()

    if source.exists():
        try:
            with pd.HDFStore(source, mode="r") as hdf:
                # Keys reported by the store carry a leading slash.
                if "/trade_cal" in hdf.keys():
                    frame = hdf["/trade_cal"]
                    print(f"[trade_cal] Loaded {len(frame)} records from cache")
                    return frame
        except Exception as exc:
            print(f"[trade_cal] Error loading from cache: {exc}")

    return pd.DataFrame()
|
|
|
|
|
|
def _get_cached_date_range() -> tuple[Optional[str], Optional[str]]:
    """Report the earliest and latest ``cal_date`` held in the cache.

    Returns:
        ``(min_date, max_date)`` as strings, or ``(None, None)`` when the
        cache is empty or has no ``cal_date`` column.
    """
    cached = _load_from_cache()

    if not cached.empty and "cal_date" in cached.columns:
        dates = cached["cal_date"]
        return (str(dates.min()), str(dates.max()))

    return (None, None)
|
|
|
|
|
|
def sync_trade_cal_cache(
    start_date: str = "20180101",
    end_date: Optional[str] = None,
    force: bool = False,
) -> pd.DataFrame:
    """Sync trade calendar data to the local cache with incremental updates.

    When a usable cache exists, only the days after the latest cached
    ``cal_date`` are requested from the ``trade_cal`` endpoint (exchange
    ``"SSE"``) and merged into the cached frame. Otherwise the whole
    ``start_date``..``end_date`` range is fetched.

    Args:
        start_date: Initial start date for a full sync (default: 20180101).
            Not used when a cache with a ``cal_date`` column already exists.
        end_date: End date in YYYYMMDD format (defaults to today).
        force: If True, force sync even if already synced in this session.

    Returns:
        Full trade calendar DataFrame (cached + new). Empty when there is no
        cache and the API returned nothing.

    Raises:
        ValueError: If the latest cached ``cal_date`` is not a valid YYYYMMDD
            date, or ``end_date`` is not numeric.
    """
    global _cache_synced

    # Function-scope import, as in the original implementation.
    from datetime import datetime, timedelta

    # Skip if already synced in this session (unless forced).
    if _cache_synced and not force:
        return _load_from_cache()

    if end_date is None:
        end_date = datetime.now().strftime("%Y%m%d")

    # Read the cache once; the same frame serves the range check and the merge.
    cached_data = _load_from_cache()
    has_cache = not cached_data.empty and "cal_date" in cached_data.columns

    if not has_cache:
        # No usable cache: fetch the whole requested range.
        print(f"[trade_cal] No cache found, fetching from {start_date} to {end_date}")
        data = TushareClient().query(
            "trade_cal",
            start_date=start_date,
            end_date=end_date,
            exchange="SSE",
        )

        if data.empty:
            print("[trade_cal] No data returned")
            return data

        # Mark as synced to avoid redundant syncs in this session.
        _cache_synced = True
        _save_to_cache(data)
        return data

    cached_min = str(cached_data["cal_date"].min())
    cached_max = str(cached_data["cal_date"].max())
    print(f"[trade_cal] Cache found: {cached_min} to {cached_max}")

    # Advance by one real calendar day. Adding 1 to the YYYYMMDD integer
    # produces invalid dates at month/year ends (e.g. 20240131 -> 20240132).
    next_day = datetime.strptime(cached_max, "%Y%m%d") + timedelta(days=1)
    fetch_start = next_day.strftime("%Y%m%d")
    print(f"[trade_cal] Fetching incremental data from {fetch_start} to {end_date}")

    if int(fetch_start) > int(end_date):
        print("[trade_cal] Cache is up-to-date, no new data needed")
        # The cache is current, so later calls in this session can skip the
        # range check entirely.
        _cache_synced = True
        return cached_data

    # Build the client only when a request is actually needed.
    new_data = TushareClient().query(
        "trade_cal",
        start_date=fetch_start,
        end_date=end_date,
        exchange="SSE",
    )

    if new_data.empty:
        # Flag deliberately left unset so the next call tries the API again.
        print("[trade_cal] No new data returned")
        return cached_data

    print(f"[trade_cal] Fetched {len(new_data)} new records")

    # Merge, keeping the cached row when a (cal_date, exchange) pair repeats.
    combined = pd.concat([cached_data, new_data], ignore_index=True)
    combined = combined.drop_duplicates(
        subset=["cal_date", "exchange"], keep="first"
    )
    combined = combined.sort_values("cal_date").reset_index(drop=True)

    # Mark as synced to avoid redundant syncs in this session.
    _cache_synced = True
    _save_to_cache(combined)
    return combined
|
|
|
|
|
|
def get_trade_cal(
    start_date: str,
    end_date: str,
    exchange: Literal["SSE", "SZSE", "BSE"] = "SSE",
    is_open: Optional[Literal["0", "1"]] = None,
    use_cache: bool = True,
) -> pd.DataFrame:
    """Fetch trading calendar data with optional local caching.

    This interface retrieves trading calendar information including
    whether each date is a trading day. The local cache is consulted only
    for exchange ``"SSE"`` and only when it spans the whole requested range;
    in every other case the data is requested from the API.

    Args:
        start_date: Start date in YYYYMMDD format
        end_date: End date in YYYYMMDD format
        exchange: Exchange - SSE (Shanghai), SZSE (Shenzhen), BSE (Beijing)
        is_open: Open status - "1" for trading day, "0" for non-trading day
        use_cache: Whether to use and update local cache (default: True)

    Returns:
        pd.DataFrame with trade calendar containing:
        - cal_date: Calendar date (YYYYMMDD)
        - exchange: Exchange code
        - is_open: Whether it's a trading day (1/0)
        - pretrade_date: Previous trading day

    Example:
        >>> # Get all trading days in January 2024
        >>> cal = get_trade_cal('20240101', '20240131')
        >>> trading_days = cal[cal['is_open'] == '1']
        >>>
        >>> # Get first and last trading day of a period
        >>> cal = get_trade_cal('20180101', '20240101')
        >>> first_trade_day = cal[cal['is_open'] == '1'].iloc[0]['cal_date']
        >>> last_trade_day = cal[cal['is_open'] == '1'].iloc[-1]['cal_date']
    """
    # Use cache if enabled
    if use_cache and exchange == "SSE":
        # Sync cache first (incremental)
        sync_trade_cal_cache()

        # Load from cache and filter by date range
        cached_data = _load_from_cache()
        if not cached_data.empty and "cal_date" in cached_data.columns:
            cal_dates = cached_data["cal_date"]
            # Serve from the cache only when it spans the whole request.
            # A range that starts before or ends after the cached data would
            # otherwise come back silently truncated.
            cache_covers_range = (
                str(cal_dates.min()) <= start_date
                and end_date <= str(cal_dates.max())
            )

            if cache_covers_range:
                # Filter by date range and exchange
                filtered = cached_data[
                    (cal_dates >= start_date)
                    & (cal_dates <= end_date)
                    & (cached_data["exchange"] == exchange)
                ]

                # Apply is_open filter if specified
                if is_open is not None:
                    # Compare as strings so an integer-typed cached column
                    # still matches the "0"/"1" argument.
                    filtered = filtered[
                        filtered["is_open"].astype(str) == str(is_open)
                    ]

                if not filtered.empty:
                    print(
                        f"[get_trade_cal] Retrieved {len(filtered)} records from cache"
                    )
                    return filtered

    # Fallback to API if the cache is disabled, unavailable, does not cover
    # the requested range, or produced no rows.
    client = TushareClient()

    # Build parameters
    params = {
        "start_date": start_date,
        "end_date": end_date,
        "exchange": exchange,
    }

    if is_open is not None:
        params["is_open"] = is_open

    # Fetch data
    data = client.query("trade_cal", **params)

    if data.empty:
        print("[get_trade_cal] No data returned")

    return data
|
|
|
|
|
|
def get_trading_days(
    start_date: str,
    end_date: str,
    exchange: Literal["SSE", "SZSE", "BSE"] = "SSE",
) -> list:
    """Collect the trading days that fall inside a date range.

    Args:
        start_date: Start date in YYYYMMDD format
        end_date: End date in YYYYMMDD format
        exchange: Exchange code

    Returns:
        The ``cal_date`` values of the open days as a list, in the order the
        calendar lookup returned them; an empty list when there are none.
    """
    open_days = get_trade_cal(start_date, end_date, exchange=exchange, is_open="1")
    return [] if open_days.empty else open_days["cal_date"].tolist()
|
|
|
|
|
|
def get_first_trading_day(
    start_date: str,
    end_date: str,
    exchange: Literal["SSE", "SZSE", "BSE"] = "SSE",
) -> Optional[str]:
    """Find the earliest trading day inside a date range.

    Args:
        start_date: Start date in YYYYMMDD format
        end_date: End date in YYYYMMDD format
        exchange: Exchange code

    Returns:
        The smallest trading date in the range, or None if the range holds
        no trading days.
    """
    days = get_trading_days(start_date, end_date, exchange)
    # min() rather than days[0]: the list order is not relied upon.
    return min(days) if days else None
|
|
|
|
|
|
def get_last_trading_day(
    start_date: str,
    end_date: str,
    exchange: Literal["SSE", "SZSE", "BSE"] = "SSE",
) -> Optional[str]:
    """Find the latest trading day inside a date range.

    Args:
        start_date: Start date in YYYYMMDD format
        end_date: End date in YYYYMMDD format
        exchange: Exchange code

    Returns:
        The largest trading date in the range, or None if the range holds
        no trading days.
    """
    days = get_trading_days(start_date, end_date, exchange)
    # max() rather than days[-1]: the list order is not relied upon.
    return max(days) if days else None
|
|
|
|
|
|
if __name__ == "__main__":
    # Demo run: print the calendar size and the boundary trading days for a
    # fixed sample period.
    period_start, period_end = "20180101", "20240101"

    print(f"Trade calendar from {period_start} to {period_end}")

    calendar = get_trade_cal(period_start, period_end)
    print(f"Total records: {len(calendar)}")

    earliest = get_first_trading_day(period_start, period_end)
    latest = get_last_trading_day(period_start, period_end)
    print(f"First trading day: {earliest}")
    print(f"Last trading day: {latest}")
|