diff --git a/.gitignore b/.gitignore index e610afc..958ce7b 100644 --- a/.gitignore +++ b/.gitignore @@ -72,5 +72,5 @@ cover/ tmp/ temp/ -# 数据目录(允许跟踪) -data/ +# 数据目录(允许跟踪,但忽略内容) +data/* diff --git a/pyproject.toml b/pyproject.toml deleted file mode 100644 index 7d51d5e..0000000 --- a/pyproject.toml +++ /dev/null @@ -1,21 +0,0 @@ -[project] -name = "ProStock" -version = "0.1.0" -description = "A股量化投资框架" -readme = "README.md" -requires-python = ">=3.10,<3.14" -dependencies = [ - "pandas>=2.0.0", - "numpy>=1.24.0", - "tushare>=2.0.0", - "pydantic>=2.0.0", - "pydantic-settings>=2.0.0", - "tqdm>=4.65.0", -] - -[build-system] -requires = ["hatchling"] -build-backend = "hatchling.build" - -[tool.uv] -package = false diff --git a/src/data/api.md b/src/data/api.md index 30f6332..ba96d04 100644 --- a/src/data/api.md +++ b/src/data/api.md @@ -123,4 +123,60 @@ delist_date str N 退市日期 is_hs str N 是否沪深港通标的,N否 H沪股通 S深股通 act_name str Y 实控人名称 act_ent_type str Y 实控人企业性质 -说明:旧版上的PE/PB/股本等字段,请在行情接口“每日指标”中获取。 \ No newline at end of file +说明:旧版上的PE/PB/股本等字段,请在行情接口“每日指标”中获取。 + + +交易日历 +接口:trade_cal,可以通过数据工具调试和查看数据。 +描述:获取各大交易所交易日历数据,默认提取的是上交所 +积分:需2000积分 + +输入参数 + +名称 类型 必选 描述 +exchange str N 交易所 SSE上交所,SZSE深交所,CFFEX 中金所,SHFE 上期所,CZCE 郑商所,DCE 大商所,INE 上能源 +start_date str N 开始日期 (格式:YYYYMMDD 下同) +end_date str N 结束日期 +is_open str N 是否交易 '0'休市 '1'交易 +输出参数 + +名称 类型 默认显示 描述 +exchange str Y 交易所 SSE上交所 SZSE深交所 +cal_date str Y 日历日期 +is_open str Y 是否交易 0休市 1交易 +pretrade_date str Y 上一个交易日 +接口示例 + + +pro = ts.pro_api() + + +df = pro.trade_cal(exchange='', start_date='20180101', end_date='20181231') +或者 + + +df = pro.query('trade_cal', start_date='20180101', end_date='20181231') +数据样例 + + exchange cal_date is_open +0 SSE 20180101 0 +1 SSE 20180102 1 +2 SSE 20180103 1 +3 SSE 20180104 1 +4 SSE 20180105 1 +5 SSE 20180106 0 +6 SSE 20180107 0 +7 SSE 20180108 1 +8 SSE 20180109 1 +9 SSE 20180110 1 +10 SSE 20180111 1 +11 SSE 20180112 1 +12 SSE 20180113 0 +13 SSE 20180114 0 +14 SSE 20180115 1 +15 SSE 20180116 
"""Data synchronization module.

This module provides data fetching functions with intelligent sync logic:
- If local file doesn't exist: fetch all data (full load from 20180101)
- If local file exists: incremental update (fetch from latest date + 1 day)
- Multi-threaded concurrent fetching for improved performance
- Stop immediately on any exception

Currently supported data types:
- daily: Daily market data (with turnover rate and volume ratio)

Usage:
    # Sync all stocks (full load on first run, incremental afterwards)
    sync_all()

    # Force full reload
    sync_all(force_full=True)
"""

import pandas as pd
from typing import Optional, Dict, Callable
from datetime import datetime, timedelta
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed
import threading
import sys

from src.data.client import TushareClient
from src.data.storage import Storage
from src.data.daily import get_daily
from src.data.trade_cal import (
    get_first_trading_day,
    get_last_trading_day,
    sync_trade_cal_cache,
)


# Default full sync start date
DEFAULT_START_DATE = "20180101"

# Snapshot of the date at import time. Kept only for backward compatibility;
# do NOT use it for "current date" logic -- see get_today_date().
TODAY = datetime.now().strftime("%Y%m%d")


def get_today_date() -> str:
    """Return today's date in YYYYMMDD format.

    FIX: computed at call time instead of returning the module-level TODAY
    snapshot, so a long-running process that crosses midnight still syncs
    up to the correct current date.
    """
    return datetime.now().strftime("%Y%m%d")


def get_next_date(date_str: str) -> str:
    """Return the calendar day after the given date.

    Args:
        date_str: Date in YYYYMMDD format

    Returns:
        Next date in YYYYMMDD format
    """
    dt = datetime.strptime(date_str, "%Y%m%d")
    return (dt + timedelta(days=1)).strftime("%Y%m%d")


class DataSync:
    """Data synchronization manager with full/incremental sync support."""

    # Default number of worker threads
    DEFAULT_MAX_WORKERS = 10

    def __init__(self, max_workers: Optional[int] = None):
        """Initialize sync manager.

        Args:
            max_workers: Number of worker threads (default: 10)
        """
        self.storage = Storage()
        self.client = TushareClient()
        self.max_workers = max_workers or self.DEFAULT_MAX_WORKERS
        # Event semantics are inverted vs. the name: set() == "keep running".
        self._stop_flag = threading.Event()
        self._stop_flag.set()  # Initially not stopped
        self._cached_daily_data: Optional[pd.DataFrame] = None  # in-memory cache of "daily"

    def _load_daily_data(self) -> pd.DataFrame:
        """Load daily data from storage, memoized to avoid repeated disk reads.

        Call clear_cache() to force a reload on the next access.

        Returns:
            DataFrame with daily data (cached or loaded from storage)
        """
        if self._cached_daily_data is None:
            self._cached_daily_data = self.storage.load("daily")
        return self._cached_daily_data

    def clear_cache(self) -> None:
        """Drop the cached daily data so the next access re-reads storage."""
        self._cached_daily_data = None

    def get_all_stock_codes(self, only_listed: bool = True) -> list:
        """Get all stock codes from local storage.

        Prefers stock_basic.csv so that backtests see every stock
        (avoids look-ahead/survivorship bias); falls back to the
        daily store when the CSV is unavailable.

        Args:
            only_listed: If True, only return currently listed stocks (L status).
                Set to False to include delisted stocks (for full backtest).

        Returns:
            List of stock codes (empty list when nothing is available)
        """
        # Import here to avoid a circular import with stock_basic.
        from src.data.stock_basic import sync_all_stocks, _get_csv_path

        # First, ensure stock_basic.csv is up-to-date with all stocks.
        print("[DataSync] Ensuring stock_basic.csv is up-to-date...")
        sync_all_stocks()

        stock_csv_path = _get_csv_path()

        if stock_csv_path.exists():
            print(f"[DataSync] Reading stock_basic from CSV: {stock_csv_path}")
            try:
                stock_df = pd.read_csv(stock_csv_path, encoding="utf-8-sig")
                if not stock_df.empty and "ts_code" in stock_df.columns:
                    if only_listed and "list_status" in stock_df.columns:
                        listed_stocks = stock_df[stock_df["list_status"] == "L"]
                        codes = listed_stocks["ts_code"].unique().tolist()
                        total = len(stock_df["ts_code"].unique())
                        print(
                            f"[DataSync] Found {len(codes)} listed stocks (filtered from {total} total)"
                        )
                    else:
                        codes = stock_df["ts_code"].unique().tolist()
                        print(
                            f"[DataSync] Found {len(codes)} stock codes from stock_basic.csv"
                        )
                    return codes
                else:
                    print(
                        f"[DataSync] stock_basic.csv exists but no ts_code column or empty"
                    )
            except Exception as e:
                print(f"[DataSync] Error reading stock_basic.csv: {e}")

        # Fallback: derive codes from the (cached) daily store.
        print("[DataSync] stock_basic.csv not available, falling back to daily data...")
        daily_data = self._load_daily_data()
        if not daily_data.empty and "ts_code" in daily_data.columns:
            codes = daily_data["ts_code"].unique().tolist()
            print(f"[DataSync] Found {len(codes)} stock codes from daily data")
            return codes

        print("[DataSync] No stock codes found in local storage")
        return []

    def get_global_last_date(self) -> Optional[str]:
        """Get the global last trade date across all stocks.

        Returns:
            Last trade date string or None when no data is stored
        """
        daily_data = self._load_daily_data()
        if daily_data.empty or "trade_date" not in daily_data.columns:
            return None
        return str(daily_data["trade_date"].max())

    def get_global_first_date(self) -> Optional[str]:
        """Get the global first trade date across all stocks.

        Returns:
            First trade date string or None when no data is stored
        """
        daily_data = self._load_daily_data()
        if daily_data.empty or "trade_date" not in daily_data.columns:
            return None
        return str(daily_data["trade_date"].min())

    def get_trade_calendar_bounds(
        self, start_date: str, end_date: str
    ) -> tuple[Optional[str], Optional[str]]:
        """Get the first and last trading day from the trade calendar.

        Args:
            start_date: Start date in YYYYMMDD format
            end_date: End date in YYYYMMDD format

        Returns:
            Tuple of (first_trading_day, last_trading_day) or (None, None) on error
        """
        try:
            first_day = get_first_trading_day(start_date, end_date)
            last_day = get_last_trading_day(start_date, end_date)
            return (first_day, last_day)
        except Exception as e:
            print(f"[ERROR] Failed to get trade calendar bounds: {e}")
            return (None, None)

    def check_sync_needed(
        self, force_full: bool = False
    ) -> tuple[bool, Optional[str], Optional[str], Optional[str]]:
        """Decide whether a sync is required by comparing local data to the calendar.

        Logic:
        - force_full or no local data: sync from DEFAULT_START_DATE to today.
        - Otherwise: if the local max trade_date already covers the calendar's
          latest trading day, skip entirely (no API tokens consumed); else sync
          the missing tail [local_last + 1, calendar_last].

        Args:
            force_full: If True, always report that a full sync is needed

        Returns:
            Tuple of (sync_needed, start_date, end_date, local_last_date)
            - sync_needed: True if sync should proceed, False to skip
            - start_date / end_date: sync window (None when sync not needed)
            - local_last_date: local data last date (for incremental sync)
        """
        if force_full:
            print("[DataSync] Force full sync requested")
            return (True, DEFAULT_START_DATE, get_today_date(), None)

        daily_data = self._load_daily_data()
        if daily_data.empty or "trade_date" not in daily_data.columns:
            print("[DataSync] No local data found, full sync needed")
            return (True, DEFAULT_START_DATE, get_today_date(), None)

        # Only the newest local date matters for the incremental decision.
        local_last_date = str(daily_data["trade_date"].max())
        print(f"[DataSync] Local data last date: {local_last_date}")

        today = get_today_date()
        _, cal_last = self.get_trade_calendar_bounds(DEFAULT_START_DATE, today)

        if cal_last is None:
            # Calendar unavailable: fail open and sync rather than risk staleness.
            print("[DataSync] Failed to get trade calendar, proceeding with sync")
            return (True, DEFAULT_START_DATE, today, local_last_date)

        print(f"[DataSync] Calendar last trading day: {cal_last}")
        print(
            f"[DataSync] Comparing: local={local_last_date} (type={type(local_last_date).__name__}), cal={cal_last} (type={type(cal_last).__name__})"
        )
        try:
            local_last_int = int(local_last_date)
            cal_last_int = int(cal_last)
            print(
                f"[DataSync] Comparing integers: local={local_last_int} >= cal={cal_last_int} = {local_last_int >= cal_last_int}"
            )
            if local_last_int >= cal_last_int:
                print(
                    "[DataSync] Local data is up-to-date, SKIPPING sync (no tokens consumed)"
                )
                return (False, None, None, None)
        except (ValueError, TypeError) as e:
            # Unparseable dates: fall through and sync defensively.
            print(f"[ERROR] Date comparison failed: {e}")

        sync_start = get_next_date(local_last_date)
        print(f"[DataSync] Incremental sync needed from {sync_start} to {cal_last}")
        return (True, sync_start, cal_last, local_last_date)

    def sync_single_stock(
        self,
        ts_code: str,
        start_date: str,
        end_date: str,
    ) -> pd.DataFrame:
        """Fetch daily data for a single stock via the shared rate-limited client.

        Args:
            ts_code: Stock code
            start_date: Start date (YYYYMMDD)
            end_date: End date (YYYYMMDD)

        Returns:
            DataFrame with daily market data (empty if a stop was requested)

        Raises:
            Exception: re-raises any client error after clearing the run flag
                so sibling worker threads stop picking up work.
        """
        # A cleared flag means some other worker hit an exception: bail out.
        if not self._stop_flag.is_set():
            return pd.DataFrame()

        try:
            data = self.client.query(
                "pro_bar",
                ts_code=ts_code,
                start_date=start_date,
                end_date=end_date,
                factors="tor,vr",
            )
            return data
        except Exception as e:
            # Signal all other threads to stop, then propagate.
            self._stop_flag.clear()
            print(f"[ERROR] Exception syncing {ts_code}: {e}")
            raise

    def sync_all(
        self,
        force_full: bool = False,
        start_date: Optional[str] = None,
        end_date: Optional[str] = None,
        max_workers: Optional[int] = None,
    ) -> Dict[str, pd.DataFrame]:
        """Sync daily data for all stocks in local storage.

        Steps:
        1. Refresh the trade-calendar cache (incremental).
        2. Ask check_sync_needed() whether anything is missing; skip if not.
        3. Fetch concurrently, skip stocks returning empty data (delisted),
           and abort the whole run on the first exception.
        4. Append all fetched rows to storage in a single write.

        NOTE(review): when check_sync_needed() returns calendar-derived dates,
        they take precedence over the manual ``start_date``/``end_date``
        arguments; the manual dates only apply as a fallback -- confirm this
        is the intended contract before relying on manual windows.

        Args:
            force_full: If True, force full reload from 20180101
            start_date: Manual start date (fallback when calendar dates unavailable)
            end_date: Manual end date (defaults to today)
            max_workers: Number of worker threads (default: 10)

        Returns:
            Dict mapping ts_code to DataFrame (empty if sync skipped)
        """
        print("\n" + "=" * 60)
        print("[DataSync] Starting daily data sync...")
        print("=" * 60)

        # Keep the trade-calendar cache current first (incremental sync).
        print("[DataSync] Syncing trade calendar cache...")
        sync_trade_cal_cache()

        if end_date is None:
            end_date = get_today_date()

        sync_needed, cal_start, cal_end, local_last = self.check_sync_needed(force_full)

        if not sync_needed:
            print("\n" + "=" * 60)
            print("[DataSync] Sync Summary")
            print("=" * 60)
            print("  Sync: SKIPPED (local data up-to-date with trade calendar)")
            print("  Tokens saved: 0 consumed")
            print("=" * 60)
            return {}

        # Calendar-derived window wins; manual arguments are the fallback.
        if cal_start and cal_end:
            sync_start_date = cal_start
            end_date = cal_end
        else:
            sync_start_date = start_date or DEFAULT_START_DATE
            if end_date is None:
                end_date = get_today_date()

        if force_full:
            print(f"[DataSync] Mode: FULL SYNC from {sync_start_date} to {end_date}")
        elif local_last and cal_start and sync_start_date == get_next_date(local_last):
            print(f"[DataSync] Mode: INCREMENTAL SYNC (bandwidth optimized)")
            print(f"[DataSync] Sync from: {sync_start_date} to {end_date}")
        else:
            print(f"[DataSync] Mode: SYNC from {sync_start_date} to {end_date}")

        stock_codes = self.get_all_stock_codes()
        if not stock_codes:
            print("[DataSync] No stocks found to sync")
            return {}

        print(f"[DataSync] Total stocks to sync: {len(stock_codes)}")
        print(f"[DataSync] Using {max_workers or self.max_workers} worker threads")

        # Arm the run flag for this sync pass.
        self._stop_flag.set()

        results: Dict[str, pd.DataFrame] = {}
        error_occurred = False
        exception_to_raise = None

        def sync_task(ts_code: str) -> tuple[str, pd.DataFrame]:
            """Worker task: fetch one stock; exceptions propagate to the Future."""
            data = self.sync_single_stock(
                ts_code=ts_code,
                start_date=sync_start_date,
                end_date=end_date,
            )
            return (ts_code, data)

        workers = max_workers or self.max_workers
        with ThreadPoolExecutor(max_workers=workers) as executor:
            future_to_code = {
                executor.submit(sync_task, ts_code): ts_code for ts_code in stock_codes
            }

            error_count = 0
            empty_count = 0
            success_count = 0

            pbar = tqdm(total=len(stock_codes), desc="Syncing stocks")

            try:
                for future in as_completed(future_to_code):
                    ts_code = future_to_code[future]

                    try:
                        _, data = future.result()
                        if data is not None and not data.empty:
                            results[ts_code] = data
                            success_count += 1
                        else:
                            # Empty data: likely delisted/unavailable; skip it.
                            empty_count += 1
                            print(
                                f"[DataSync] Stock {ts_code}: empty data (skipped, may be delisted)"
                            )
                    except Exception as e:
                        # First error aborts the whole run; cancel what we can.
                        error_occurred = True
                        exception_to_raise = e
                        print(f"\n[ERROR] Sync aborted due to exception: {e}")
                        executor.shutdown(wait=False, cancel_futures=True)
                        raise exception_to_raise

                    pbar.update(1)

            except Exception:
                # Swallowed here so the summary below still prints; results
                # are NOT saved when error_occurred is set.
                error_count = 1
                print("[DataSync] Sync stopped due to exception")
            finally:
                pbar.close()

        # Single append-write, only on a clean run.
        if results and not error_occurred:
            combined_data = pd.concat(results.values(), ignore_index=True)
            self.storage.save("daily", combined_data, mode="append")
            print(f"\n[DataSync] Saved {len(combined_data)} rows to storage")

        print("\n" + "=" * 60)
        print("[DataSync] Sync Summary")
        print("=" * 60)
        print(f"  Total stocks: {len(stock_codes)}")
        print(f"  Updated: {success_count}")
        print(f"  Skipped (empty/delisted): {empty_count}")
        print(
            f"  Errors: {error_count} (aborted on first error)"
            if error_count
            else "  Errors: 0"
        )
        print(f"  Date range: {sync_start_date} to {end_date}")
        print("=" * 60)

        return results


# Convenience functions


def sync_all(
    force_full: bool = False,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
    max_workers: Optional[int] = None,
) -> Dict[str, pd.DataFrame]:
    """Sync daily data for all stocks (main entry point).

    Args:
        force_full: If True, force full reload from 20180101
        start_date: Manual start date (YYYYMMDD)
        end_date: Manual end date (defaults to today)
        max_workers: Number of worker threads (default: 10)

    Returns:
        Dict mapping ts_code to DataFrame

    Example:
        >>> result = sync_all()                       # full on first run, else incremental
        >>> result = sync_all(force_full=True)        # force full reload
        >>> result = sync_all(start_date='20240101', end_date='20240131')
        >>> result = sync_all(max_workers=20)
    """
    sync_manager = DataSync(max_workers=max_workers)
    return sync_manager.sync_all(
        force_full=force_full,
        start_date=start_date,
        end_date=end_date,
    )


if __name__ == "__main__":
    print("=" * 60)
    print("Data Sync Module")
    print("=" * 60)
    print("\nUsage:")
    print("  from src.data.sync import sync_all")
    print("  result = sync_all()  # Incremental sync")
    print("  result = sync_all(force_full=True)  # Full reload")
    print("\n" + "=" * 60)

    # Run sync
    result = sync_all()
    print(f"\nSynced {len(result)} stocks")
+""" + +import pandas as pd +from typing import Optional, Literal +from pathlib import Path +from src.data.client import TushareClient +from src.data.config import get_config + + +# Trading calendar cache file path +def _get_cache_path() -> Path: + """Get the cache file path for trade calendar.""" + cfg = get_config() + return cfg.data_path_resolved / "trade_cal.h5" + + +def _save_to_cache(data: pd.DataFrame) -> None: + """Save trade calendar data to local cache. + + Args: + data: Trade calendar DataFrame + """ + if data.empty: + return + + cache_path = _get_cache_path() + cache_path.parent.mkdir(parents=True, exist_ok=True) + + try: + with pd.HDFStore(cache_path, mode="a") as store: + store.put("trade_cal", data, format="table") + print(f"[trade_cal] Saved {len(data)} records to cache: {cache_path}") + except Exception as e: + print(f"[trade_cal] Error saving to cache: {e}") + + +def _load_from_cache() -> pd.DataFrame: + """Load trade calendar data from local cache. + + Returns: + Trade calendar DataFrame or empty DataFrame if cache doesn't exist + """ + cache_path = _get_cache_path() + + if not cache_path.exists(): + return pd.DataFrame() + + try: + with pd.HDFStore(cache_path, mode="r") as store: + if "trade_cal" in store.keys(): + data = store["trade_cal"] + print(f"[trade_cal] Loaded {len(data)} records from cache") + return data + except Exception as e: + print(f"[trade_cal] Error loading from cache: {e}") + + return pd.DataFrame() + + +def _get_cached_date_range() -> tuple[Optional[str], Optional[str]]: + """Get the date range of cached trade calendar. 
+ + Returns: + Tuple of (min_date, max_date) or (None, None) if cache empty + """ + data = _load_from_cache() + if data.empty or "cal_date" not in data.columns: + return (None, None) + + return (str(data["cal_date"].min()), str(data["cal_date"].max())) + + +def sync_trade_cal_cache( + start_date: str = "20180101", + end_date: Optional[str] = None, +) -> pd.DataFrame: + """Sync trade calendar data to local cache with incremental updates. + + This function checks if we have cached data and only fetches new data + from the last cached date onwards. + + Args: + start_date: Initial start date for full sync (default: 20180101) + end_date: End date (defaults to today) + + Returns: + Full trade calendar DataFrame (cached + new) + """ + if end_date is None: + from datetime import datetime + + end_date = datetime.now().strftime("%Y%m%d") + + client = TushareClient() + + # Check cached data range + cached_min, cached_max = _get_cached_date_range() + + if cached_min and cached_max: + print(f"[trade_cal] Cache found: {cached_min} to {cached_max}") + # Only fetch new data after the cached max date + fetch_start = str(int(cached_max) + 1) + print(f"[trade_cal] Fetching incremental data from {fetch_start} to {end_date}") + + if int(fetch_start) > int(end_date): + print("[trade_cal] Cache is up-to-date, no new data needed") + return _load_from_cache() + + # Fetch new data + new_data = client.query( + "trade_cal", + start_date=fetch_start, + end_date=end_date, + exchange="SSE", + ) + + if new_data.empty: + print("[trade_cal] No new data returned") + return _load_from_cache() + + print(f"[trade_cal] Fetched {len(new_data)} new records") + + # Load cached data and merge + cached_data = _load_from_cache() + if not cached_data.empty: + combined = pd.concat([cached_data, new_data], ignore_index=True) + # Remove duplicates by cal_date + combined = combined.drop_duplicates( + subset=["cal_date", "exchange"], keep="first" + ) + combined = 
combined.sort_values("cal_date").reset_index(drop=True) + else: + combined = new_data + + # Save combined data to cache + _save_to_cache(combined) + return combined + else: + # No cache, fetch all data + print(f"[trade_cal] No cache found, fetching from {start_date} to {end_date}") + data = client.query( + "trade_cal", + start_date=start_date, + end_date=end_date, + exchange="SSE", + ) + + if data.empty: + print("[trade_cal] No data returned") + return data + + _save_to_cache(data) + return data + + +def get_trade_cal( + start_date: str, + end_date: str, + exchange: Literal["SSE", "SZSE", "BSE"] = "SSE", + is_open: Optional[Literal["0", "1"]] = None, + use_cache: bool = True, +) -> pd.DataFrame: + """Fetch trading calendar data with optional local caching. + + This interface retrieves trading calendar information including + whether each date is a trading day. Uses cached data when available + to reduce API calls and improve performance. + + Args: + start_date: Start date in YYYYMMDD format + end_date: End date in YYYYMMDD format + exchange: Exchange - SSE (Shanghai), SZSE (Shenzhen), BSE (Beijing) + is_open: Open status - "1" for trading day, "0" for non-trading day + use_cache: Whether to use and update local cache (default: True) + + Returns: + pd.DataFrame with trade calendar containing: + - cal_date: Calendar date (YYYYMMDD) + - exchange: Exchange code + - is_open: Whether it's a trading day (1/0) + - pretrade_date: Previous trading day + + Example: + >>> # Get all trading days in January 2024 + >>> cal = get_trade_cal('20240101', '20240131') + >>> trading_days = cal[cal['is_open'] == '1'] + >>> + >>> # Get first and last trading day of a period + >>> cal = get_trade_cal('20180101', '20240101') + >>> first_trade_day = cal[cal['is_open'] == '1'].iloc[0]['cal_date'] + >>> last_trade_day = cal[cal['is_open'] == '1'].iloc[-1]['cal_date'] + """ + # Use cache if enabled + if use_cache and exchange == "SSE": + # Sync cache first (incremental) + sync_trade_cal_cache() 
+ + # Load from cache and filter by date range + cached_data = _load_from_cache() + if not cached_data.empty and "cal_date" in cached_data.columns: + # Filter by date range and exchange + filtered = cached_data[ + (cached_data["cal_date"] >= start_date) + & (cached_data["cal_date"] <= end_date) + & (cached_data["exchange"] == exchange) + ] + + # Apply is_open filter if specified + if is_open is not None: + # Handle type mismatch: HDF5 stores is_open as int, but API returns str + filtered = filtered[filtered["is_open"].astype(str) == str(is_open)] + + if not filtered.empty: + print(f"[get_trade_cal] Retrieved {len(filtered)} records from cache") + return filtered + + # Fallback to API if cache not available or disabled + client = TushareClient() + + # Build parameters + params = { + "start_date": start_date, + "end_date": end_date, + "exchange": exchange, + } + + if is_open is not None: + params["is_open"] = is_open + + # Fetch data + data = client.query("trade_cal", **params) + + if data.empty: + print("[get_trade_cal] No data returned") + + return data + + +def get_trading_days( + start_date: str, + end_date: str, + exchange: Literal["SSE", "SZSE", "BSE"] = "SSE", +) -> list: + """Get list of trading days in a date range. + + Args: + start_date: Start date in YYYYMMDD format + end_date: End date in YYYYMMDD format + exchange: Exchange code + + Returns: + List of trading dates (YYYYMMDD strings) + """ + cal = get_trade_cal(start_date, end_date, exchange=exchange, is_open="1") + if cal.empty: + return [] + return cal["cal_date"].tolist() + + +def get_first_trading_day( + start_date: str, + end_date: str, + exchange: Literal["SSE", "SZSE", "BSE"] = "SSE", +) -> Optional[str]: + """Get the first trading day in a date range. 
+ + Args: + start_date: Start date in YYYYMMDD format + end_date: End date in YYYYMMDD format + exchange: Exchange code + + Returns: + First trading date (YYYYMMDD) or None if no trading days + """ + trading_days = get_trading_days(start_date, end_date, exchange) + if not trading_days: + return None + # Trading days are sorted in descending order (newest first) from cache + return trading_days[-1] + + +def get_last_trading_day( + start_date: str, + end_date: str, + exchange: Literal["SSE", "SZSE", "BSE"] = "SSE", +) -> Optional[str]: + """Get the last trading day in a date range. + + Args: + start_date: Start date in YYYYMMDD format + end_date: End date in YYYYMMDD format + exchange: Exchange code + + Returns: + Last trading date (YYYYMMDD) or None if no trading days + """ + trading_days = get_trading_days(start_date, end_date, exchange) + if not trading_days: + return None + # Trading days are sorted in descending order (newest first) from cache + return trading_days[0] + + +if __name__ == "__main__": + # Example usage + start = "20180101" + end = "20240101" + + print(f"Trade calendar from {start} to {end}") + + cal = get_trade_cal(start, end) + print(f"Total records: {len(cal)}") + + first_day = get_first_trading_day(start, end) + last_day = get_last_trading_day(start, end) + print(f"First trading day: {first_day}") + print(f"Last trading day: {last_day}") diff --git a/tests/test_daily_storage.py b/tests/test_daily_storage.py new file mode 100644 index 0000000..fcd5048 --- /dev/null +++ b/tests/test_daily_storage.py @@ -0,0 +1,190 @@ +"""Tests for data/daily.h5 storage validation. + +Validates two key points: +1. All stocks from stock_basic.csv are saved in daily.h5 +2. 
"""Tests for data/daily.h5 storage validation.

Validates two key points:
1. All stocks from stock_basic.csv are saved in daily.h5
2. No abnormal data with very few data points (< 10 rows per stock)
"""

import pytest
import pandas as pd
from pathlib import Path
from src.data.storage import Storage
from src.data.stock_basic import _get_csv_path


class TestDailyStorageValidation:
    """Test daily.h5 storage integrity and completeness."""

    @pytest.fixture
    def storage(self):
        """Create storage instance."""
        return Storage()

    @pytest.fixture
    def stock_basic_df(self):
        """Load stock basic data from CSV.

        FIX: read with encoding="utf-8-sig" to match how the file is written
        elsewhere in the project; a default-utf-8 read can leave a BOM glued
        to the first column name.
        """
        csv_path = _get_csv_path()
        if not csv_path.exists():
            pytest.skip(f"stock_basic.csv not found at {csv_path}")
        return pd.read_csv(csv_path, encoding="utf-8-sig")

    @pytest.fixture
    def daily_df(self, storage):
        """Load daily data from HDF5 (handles '/daily' vs 'daily' key forms)."""
        if not storage.exists("daily"):
            pytest.skip("daily.h5 not found")
        file_path = storage._get_file_path("daily")
        try:
            with pd.HDFStore(file_path, mode="r") as store:
                if "/daily" in store.keys():
                    return store["/daily"]
                elif "daily" in store.keys():
                    return store["daily"]
            return pd.DataFrame()
        except Exception as e:
            pytest.skip(f"Error loading daily.h5: {e}")

    def test_all_stocks_saved(self, storage, stock_basic_df, daily_df):
        """Verify all stocks from stock_basic are saved in daily.h5.

        Ensures data completeness: every stock in stock_basic should have
        corresponding rows in daily.h5.
        """
        if daily_df.empty:
            pytest.fail("daily.h5 is empty")

        expected_codes = set(stock_basic_df["ts_code"].dropna().unique())
        actual_codes = set(daily_df["ts_code"].dropna().unique())

        missing_codes = expected_codes - actual_codes

        if missing_codes:
            missing_list = sorted(missing_codes)
            sample = missing_list[:20]
            msg = f"Found {len(missing_codes)} stocks missing from daily.h5:\n"
            msg += f"Sample missing: {sample}\n"
            if len(missing_list) > 20:
                msg += f"... and {len(missing_list) - 20} more"
            pytest.fail(msg)

        assert len(actual_codes) > 0, "No stocks found in daily.h5"
        print(
            f"[TEST] All {len(expected_codes)} stocks from stock_basic are present in daily.h5"
        )

    def test_no_stock_with_insufficient_data(self, storage, daily_df):
        """Verify no stock has abnormally few data points (< 10 rows).

        Stocks with very few data points may indicate sync failures,
        delisted stocks not properly handled, or data corruption.
        """
        if daily_df.empty:
            pytest.fail("daily.h5 is empty")

        stock_counts = daily_df.groupby("ts_code").size()

        insufficient_stocks = stock_counts[stock_counts < 10]

        if not insufficient_stocks.empty:
            # groupby().size() never yields 0, so every offender has 1-9 rows
            # (the previous "empty stocks" bucket was dead code).
            sample = insufficient_stocks.sort_values().head(20)
            msg = (
                f"Found {len(insufficient_stocks)} stocks with insufficient "
                f"data (< 10 rows):\nSample (ts_code: count):\n"
            )
            for code, count in sample.items():
                msg += f"  {code}: {count} rows\n"
            pytest.fail(msg)

        print(f"[TEST] All stocks have sufficient data (>= 10 rows)")

    def test_data_integrity_basic(self, storage, daily_df):
        """Basic data integrity checks: required columns present, no null keys."""
        if daily_df.empty:
            pytest.fail("daily.h5 is empty")

        required_columns = ["ts_code", "trade_date"]
        missing_columns = [
            col for col in required_columns if col not in daily_df.columns
        ]

        if missing_columns:
            pytest.fail(f"Missing required columns: {missing_columns}")

        null_ts_code = daily_df["ts_code"].isna().sum()
        null_trade_date = daily_df["trade_date"].isna().sum()

        if null_ts_code > 0:
            pytest.fail(f"Found {null_ts_code} rows with null ts_code")
        if null_trade_date > 0:
            pytest.fail(f"Found {null_trade_date} rows with null trade_date")

        print(f"[TEST] Data integrity check passed")

    def test_stock_data_coverage_report(self, storage, daily_df):
        """Generate a summary report of stock data coverage (informational)."""
        if daily_df.empty:
            pytest.skip("daily.h5 is empty - cannot generate report")

        stock_counts = daily_df.groupby("ts_code").size()

        total_stocks = len(stock_counts)
        min_count = stock_counts.min()
        max_count = stock_counts.max()
        median_count = stock_counts.median()
        mean_count = stock_counts.mean()

        # Distribution buckets
        very_low = (stock_counts < 10).sum()
        low = ((stock_counts >= 10) & (stock_counts < 100)).sum()
        medium = ((stock_counts >= 100) & (stock_counts < 500)).sum()
        high = (stock_counts >= 500).sum()

        report = f"""
=== Stock Data Coverage Report ===
Total stocks: {total_stocks}
Data points per stock:
  Min: {min_count}
  Max: {max_count}
  Median: {median_count:.0f}
  Mean: {mean_count:.1f}

Distribution:
  < 10 rows: {very_low} stocks ({very_low / total_stocks * 100:.1f}%)
  10-99: {low} stocks ({low / total_stocks * 100:.1f}%)
  100-499: {medium} stocks ({medium / total_stocks * 100:.1f}%)
  >= 500: {high} stocks ({high / total_stocks * 100:.1f}%)
"""
        print(report)

        # Informational test: assert only that there is something to report.
        assert total_stocks > 0


if __name__ == "__main__":
    pytest.main([__file__, "-v", "-s"])


# ---------------------------------------------------------------------------
# tests/test_tushare_api.py
# ---------------------------------------------------------------------------
"""Tushare API verification script - quickly builds a `pro` object for debugging.

NOTE(review): this module performs live API calls at import time, so pytest
collection of the tests/ directory triggers network access -- consider moving
the calls under the __main__ guard.
"""

import os

os.environ.setdefault("DATA_PATH", "data")

from src.data.config import get_config
import tushare as ts

config = get_config()
token = config.tushare_token

if not token:
    raise ValueError("请在 config/.env.local 中配置 TUSHARE_TOKEN")

pro = ts.pro_api(token)
print(f"pro_api 对象已创建,token: {token[:10]}...")

df = pro.query('daily', ts_code='000001.SZ', start_date='20180702', end_date='20180718')
print(df)