Files
ProStock/src/data/api_wrappers/api_trade_cal.py
liaozhaorun 317ecd87e7 feat(data): 封装ST股票列表接口(stock_st)
- 新增 api_stock_st.py,实现ST股票数据获取和日期遍历同步
- 更新 sync.py,将ST股票同步加入第7步流程
- 移除 base_sync.py 中未使用的 get_last_n_trading_days 导入
2026-03-03 22:04:22 +08:00

338 lines
10 KiB
Python

"""Trade calendar interface.
Fetch trading calendar data from Tushare to determine market open/close dates,
using a local HDF5 cache to avoid repeated API calls.
"""
import pandas as pd
from typing import Optional, Literal
from pathlib import Path
from src.data.client import TushareClient
from src.config.settings import get_settings
# Session-scoped flag (one per Python process). sync_trade_cal_cache() sets it
# to True once it has synced the cache, and short-circuits on later calls
# unless it is invoked with force=True.
_cache_synced = False
# Trading calendar cache file path
def _get_cache_path() -> Path:
    """Return the location of the trade calendar HDF5 cache file.

    The file is named ``trade_cal.h5`` and sits directly under the
    ``data_path_resolved`` directory exposed by the project settings.
    """
    return get_settings().data_path_resolved / "trade_cal.h5"
def _save_to_cache(data: pd.DataFrame) -> None:
    """Persist the trade calendar to the local HDF5 cache (best-effort).

    An empty frame is ignored. Any error raised while writing is reported
    on stdout and swallowed, so a failed save never propagates to callers.

    Args:
        data: Trade calendar DataFrame, stored under the ``trade_cal`` key.
    """
    if data.empty:
        return

    target = _get_cache_path()
    # The data directory may not exist yet on a first run.
    target.parent.mkdir(parents=True, exist_ok=True)

    try:
        with pd.HDFStore(target, mode="a") as hdf:
            hdf.put("trade_cal", data, format="table")
        print(f"[trade_cal] Saved {len(data)} records to cache: {target}")
    except Exception as err:
        print(f"[trade_cal] Error saving to cache: {err}")
def _load_from_cache() -> pd.DataFrame:
    """Read the trade calendar back from the local HDF5 cache.

    Returns:
        The cached trade calendar. An empty DataFrame is returned when the
        cache file is missing, holds no ``trade_cal`` node, or cannot be
        read (the read error is printed, not raised).
    """
    source = _get_cache_path()

    if source.exists():
        try:
            with pd.HDFStore(source, mode="r") as hdf:
                # HDFStore.keys() reports node names with a leading slash.
                has_calendar = "/trade_cal" in hdf.keys()
                if has_calendar:
                    calendar = hdf["/trade_cal"]
                    print(f"[trade_cal] Loaded {len(calendar)} records from cache")
                    return calendar
        except Exception as err:
            print(f"[trade_cal] Error loading from cache: {err}")

    return pd.DataFrame()
def _get_cached_date_range() -> tuple[Optional[str], Optional[str]]:
    """Report the earliest and latest ``cal_date`` held in the cache.

    Returns:
        ``(min_date, max_date)`` as strings, or ``(None, None)`` when the
        cache is empty or has no ``cal_date`` column.
    """
    calendar = _load_from_cache()

    if not calendar.empty and "cal_date" in calendar.columns:
        dates = calendar["cal_date"]
        return (str(dates.min()), str(dates.max()))

    return (None, None)
def sync_trade_cal_cache(
    start_date: str = "20180101",
    end_date: Optional[str] = None,
    force: bool = False,
) -> pd.DataFrame:
    """Sync SSE trade calendar data to the local cache with incremental updates.

    When a cache already exists, only the days after the newest cached
    ``cal_date`` are requested from Tushare and merged into the cached
    frame; otherwise the whole ``start_date``..``end_date`` range is fetched.

    Args:
        start_date: First date (YYYYMMDD) of a full sync. Only used when no
            cache exists yet (default: 20180101).
        end_date: Last date (YYYYMMDD) to fetch. Defaults to today.
        force: If True, run the sync even if one already completed in this
            session.

    Returns:
        The full trade calendar DataFrame (cached rows plus any new rows).
        It may be empty when there is no cache and the API returned nothing.

    Raises:
        ValueError: If the newest cached ``cal_date`` cannot be parsed as a
            YYYYMMDD date.
    """
    global _cache_synced
    from datetime import datetime, timedelta

    # Skip if already synced in this session (unless forced).
    if _cache_synced and not force:
        return _load_from_cache()

    if end_date is None:
        end_date = datetime.now().strftime("%Y%m%d")

    cached_min, cached_max = _get_cached_date_range()

    if not (cached_min and cached_max):
        # No usable cache: fetch the whole requested range.
        print(f"[trade_cal] No cache found, fetching from {start_date} to {end_date}")
        data = TushareClient().query(
            "trade_cal",
            start_date=start_date,
            end_date=end_date,
            exchange="SSE",
        )
        if data.empty:
            print("[trade_cal] No data returned")
            return data

        # Mark as synced to avoid redundant syncs in this session.
        _cache_synced = True
        _save_to_cache(data)
        return data

    print(f"[trade_cal] Cache found: {cached_min} to {cached_max}")

    # Use real calendar arithmetic: adding 1 to the integer form of a
    # YYYYMMDD value yields invalid dates at month ends (20231231 -> 20231232).
    next_day = datetime.strptime(cached_max, "%Y%m%d") + timedelta(days=1)
    fetch_start = next_day.strftime("%Y%m%d")
    print(f"[trade_cal] Fetching incremental data from {fetch_start} to {end_date}")

    if int(fetch_start) > int(end_date):
        print("[trade_cal] Cache is up-to-date, no new data needed")
        # Nothing left to fetch, so the sync is complete for this session;
        # without this every later call would repeat the cache reads above.
        _cache_synced = True
        return _load_from_cache()

    # The API client is only created once a request is actually needed.
    new_data = TushareClient().query(
        "trade_cal",
        start_date=fetch_start,
        end_date=end_date,
        exchange="SSE",
    )
    if new_data.empty:
        # Not marked as synced: a later call in this session will retry.
        print("[trade_cal] No new data returned")
        return _load_from_cache()

    print(f"[trade_cal] Fetched {len(new_data)} new records")

    cached_data = _load_from_cache()
    if cached_data.empty:
        combined = new_data
    else:
        combined = pd.concat([cached_data, new_data], ignore_index=True)
        # On overlap the already-cached row wins (keep="first").
        combined = combined.drop_duplicates(
            subset=["cal_date", "exchange"], keep="first"
        )
        combined = combined.sort_values("cal_date").reset_index(drop=True)

    # Mark as synced to avoid redundant syncs in this session.
    _cache_synced = True
    _save_to_cache(combined)
    return combined
def get_trade_cal(
    start_date: str,
    end_date: str,
    exchange: Literal["SSE", "SZSE", "BSE"] = "SSE",
    is_open: Optional[Literal["0", "1"]] = None,
    use_cache: bool = True,
) -> pd.DataFrame:
    """Fetch trading calendar data with optional local caching.

    For the SSE exchange the local cache is synced first and the request is
    answered from it, provided the cache spans the whole requested range.
    In every other case (another exchange, ``use_cache=False``, a range that
    extends beyond the cached dates, or no matching cached rows) the data is
    requested from the Tushare API directly.

    Args:
        start_date: Start date in YYYYMMDD format
        end_date: End date in YYYYMMDD format
        exchange: Exchange - SSE (Shanghai), SZSE (Shenzhen), BSE (Beijing)
        is_open: Open status - "1" for trading day, "0" for non-trading day
        use_cache: Whether to use and update local cache (default: True)

    Returns:
        pd.DataFrame with trade calendar containing:
        - cal_date: Calendar date (YYYYMMDD)
        - exchange: Exchange code
        - is_open: Whether it's a trading day (1/0); the dtype may differ
          between cached and API results, so prefer the ``is_open``
          argument over comparing this column to a literal
        - pretrade_date: Previous trading day

    Example:
        >>> # Get all trading days in January 2024
        >>> trading_days = get_trade_cal('20240101', '20240131', is_open='1')
        >>>
        >>> # Get first and last trading day of a period
        >>> cal = get_trade_cal('20180101', '20240101', is_open='1')
        >>> first_trade_day = cal['cal_date'].min()
        >>> last_trade_day = cal['cal_date'].max()
    """
    if use_cache and exchange == "SSE":
        # The sync returns the complete cached calendar, so its result is
        # used directly instead of reading the cache file a second time.
        cached_data = sync_trade_cal_cache()

        if not cached_data.empty and "cal_date" in cached_data.columns:
            cal_dates = cached_data["cal_date"]

            # Only answer from the cache when it spans the whole request;
            # otherwise the caller would silently get a truncated calendar
            # (e.g. a start date earlier than the first cached day).
            cache_covers_range = (
                str(cal_dates.min()) <= start_date
                and end_date <= str(cal_dates.max())
            )
            if cache_covers_range:
                in_range = (
                    (cal_dates >= start_date)
                    & (cal_dates <= end_date)
                    & (cached_data["exchange"] == exchange)
                )
                filtered = cached_data[in_range]

                if is_open is not None:
                    # Compare as strings: the stored is_open dtype may not
                    # match the str value passed by the caller.
                    filtered = filtered[
                        filtered["is_open"].astype(str) == str(is_open)
                    ]

                if not filtered.empty:
                    print(f"[get_trade_cal] Retrieved {len(filtered)} records from cache")
                    return filtered

    # Fallback to API if the cache is disabled, unavailable or insufficient.
    params = {
        "start_date": start_date,
        "end_date": end_date,
        "exchange": exchange,
    }
    if is_open is not None:
        params["is_open"] = is_open

    data = TushareClient().query("trade_cal", **params)
    if data.empty:
        print("[get_trade_cal] No data returned")
    return data
def get_trading_days(
    start_date: str,
    end_date: str,
    exchange: Literal["SSE", "SZSE", "BSE"] = "SSE",
) -> list:
    """Return the open-market dates that fall inside a date range.

    Args:
        start_date: Start date in YYYYMMDD format
        end_date: End date in YYYYMMDD format
        exchange: Exchange code

    Returns:
        The ``cal_date`` values of all rows with ``is_open == "1"``, in the
        order delivered by ``get_trade_cal``; an empty list if none exist.
    """
    open_days = get_trade_cal(start_date, end_date, exchange=exchange, is_open="1")
    return [] if open_days.empty else open_days["cal_date"].tolist()
def get_first_trading_day(
    start_date: str,
    end_date: str,
    exchange: Literal["SSE", "SZSE", "BSE"] = "SSE",
) -> Optional[str]:
    """Return the earliest trading day inside a date range.

    Args:
        start_date: Start date in YYYYMMDD format
        end_date: End date in YYYYMMDD format
        exchange: Exchange code

    Returns:
        The smallest trading date (YYYYMMDD), or None when the range
        contains no trading days.
    """
    days = get_trading_days(start_date, end_date, exchange)
    # min() rather than days[0]: the list order is not relied upon.
    return min(days) if days else None
def get_last_trading_day(
    start_date: str,
    end_date: str,
    exchange: Literal["SSE", "SZSE", "BSE"] = "SSE",
) -> Optional[str]:
    """Return the latest trading day inside a date range.

    Args:
        start_date: Start date in YYYYMMDD format
        end_date: End date in YYYYMMDD format
        exchange: Exchange code

    Returns:
        The largest trading date (YYYYMMDD), or None when the range
        contains no trading days.
    """
    days = get_trading_days(start_date, end_date, exchange)
    # max() rather than days[-1]: the list order is not relied upon.
    return max(days) if days else None
if __name__ == "__main__":
    # Manual smoke run: print the size of the calendar for a fixed range and
    # the first/last trading day found in it.
    start, end = "20180101", "20240101"
    print(f"Trade calendar from {start} to {end}")

    cal = get_trade_cal(start, end)
    print(f"Total records: {len(cal)}")

    # Both lookups run before either result is printed.
    first_day, last_day = (
        get_first_trading_day(start, end),
        get_last_trading_day(start, end),
    )
    print(f"First trading day: {first_day}")
    print(f"Last trading day: {last_day}")