"""Simplified daily market data interface. A single function to fetch A股日线行情 data from Tushare. Supports all output fields including tor (换手率) and vr (量比). This module provides both single-stock fetching (get_daily) and batch synchronization (DailySync class) for daily market data. """ import pandas as pd from typing import Optional, List, Literal, Dict from datetime import datetime, timedelta from tqdm import tqdm from concurrent.futures import ThreadPoolExecutor, as_completed import threading from src.data.client import TushareClient from src.data.storage import ThreadSafeStorage, Storage from src.data.utils import get_today_date, get_next_date, DEFAULT_START_DATE from src.data.api_wrappers.api_trade_cal import ( get_first_trading_day, get_last_trading_day, sync_trade_cal_cache, ) from src.data.api_wrappers.api_stock_basic import _get_csv_path, sync_all_stocks def get_daily( ts_code: str, start_date: Optional[str] = None, end_date: Optional[str] = None, trade_date: Optional[str] = None, adj: Literal[None, "qfq", "hfq"] = None, factors: Optional[List[Literal["tor", "vr"]]] = None, adjfactor: bool = False, ) -> pd.DataFrame: """Fetch daily market data for A-share stocks. This is a simplified interface that combines rate limiting, API calls, and error handling into a single function. 
Args: ts_code: Stock code (e.g., '000001.SZ', '600000.SH') start_date: Start date in YYYYMMDD format end_date: End date in YYYYMMDD format trade_date: Specific trade date in YYYYMMDD format adj: Adjustment type - None, 'qfq' (forward), 'hfq' (backward) factors: List of factors to include - 'tor' (turnover rate), 'vr' (volume ratio) adjfactor: Whether to include adjustment factor Returns: pd.DataFrame with daily market data containing: - Base fields: ts_code, trade_date, open, high, low, close, pre_close, change, pct_chg, vol, amount - Factor fields (if requested): tor, vr - Adjustment factor (if adjfactor=True): adjfactor Example: >>> data = get_daily('000001.SZ', start_date='20240101', end_date='20240131') >>> data = get_daily('600000.SH', factors=['tor', 'vr']) """ # Initialize client client = TushareClient() # Build parameters params = {"ts_code": ts_code} if start_date: params["start_date"] = start_date if end_date: params["end_date"] = end_date if trade_date: params["trade_date"] = trade_date if adj: params["adj"] = adj if factors: # Tushare expects factors as comma-separated string, not list if isinstance(factors, list): factors_str = ",".join(factors) else: factors_str = factors params["factors"] = factors_str if adjfactor: params["adjfactor"] = "True" # Fetch data using pro_bar (supports factors like tor, vr) data = client.query("pro_bar", **params) return data # ============================================================================= # DailySync - 日线数据批量同步类 # ============================================================================= class DailySync: """日线数据批量同步管理器,支持全量/增量同步。 功能特性: - 多线程并发获取(ThreadPoolExecutor) - 增量同步(自动检测上次同步位置) - 内存缓存(避免重复磁盘读取) - 异常立即停止(确保数据一致性) - 预览模式(预览同步数据量,不实际写入) """ # 默认工作线程数 DEFAULT_MAX_WORKERS = 10 def __init__(self, max_workers: Optional[int] = None): """初始化同步管理器。 Args: max_workers: 工作线程数(默认: 10) """ self.storage = ThreadSafeStorage() self.client = TushareClient() self.max_workers = max_workers or self.DEFAULT_MAX_WORKERS 
        # Inverted "stop" flag: set() means "keep running"; a worker clears it
        # to signal the other threads to stop after an error.
        self._stop_flag = threading.Event()
        self._stop_flag.set()  # initially in the "not stopped" state
        # In-memory cache of the daily table to avoid repeated disk reads
        self._cached_daily_data: Optional[pd.DataFrame] = None

    def _load_daily_data(self) -> pd.DataFrame:
        """Load daily data from storage (with caching).

        The data is cached in memory to avoid repeated disk reads.
        Call clear_cache() to force a reload.

        Returns:
            Cached or freshly loaded daily data DataFrame.
        """
        if self._cached_daily_data is None:
            self._cached_daily_data = self.storage.load("daily")
        return self._cached_daily_data

    def clear_cache(self) -> None:
        """Clear the cached daily data, forcing a reload on next access."""
        self._cached_daily_data = None

    def get_all_stock_codes(self, only_listed: bool = True) -> list:
        """Get all stock codes from local storage.

        Prefers stock_basic.csv so all stocks are covered, avoiding
        look-ahead bias in backtests.

        Args:
            only_listed: If True, return only currently listed stocks
                (status 'L'). Set False to include delisted stocks
                (for complete backtests).

        Returns:
            List of stock codes.
        """
        # Make sure stock_basic.csv is current before reading it
        print("[DailySync] Ensuring stock_basic.csv is up-to-date...")
        sync_all_stocks()

        # Primary source: the stock_basic.csv file
        stock_csv_path = _get_csv_path()
        if stock_csv_path.exists():
            print(f"[DailySync] Reading stock_basic from CSV: {stock_csv_path}")
            try:
                stock_df = pd.read_csv(stock_csv_path, encoding="utf-8-sig")
                if not stock_df.empty and "ts_code" in stock_df.columns:
                    # Filter by list_status when requested and available
                    if only_listed and "list_status" in stock_df.columns:
                        listed_stocks = stock_df[stock_df["list_status"] == "L"]
                        codes = listed_stocks["ts_code"].unique().tolist()
                        total = len(stock_df["ts_code"].unique())
                        print(
                            f"[DailySync] Found {len(codes)} listed stocks (filtered from {total} total)"
                        )
                    else:
                        codes = stock_df["ts_code"].unique().tolist()
                        print(
                            f"[DailySync] Found {len(codes)} stock codes from stock_basic.csv"
                        )
                    return codes
                else:
                    print(
                        f"[DailySync] stock_basic.csv exists but no ts_code column or empty"
                    )
            except Exception as e:
                # Best-effort read; fall through to the daily-data fallback below
                print(f"[DailySync] Error reading stock_basic.csv: {e}")

        # Fallback: derive codes from the local daily data
        print(
            "[DailySync] stock_basic.csv not available, falling back to daily data..."
        )
        daily_data = self._load_daily_data()
        if not daily_data.empty and "ts_code" in daily_data.columns:
            codes = daily_data["ts_code"].unique().tolist()
            print(f"[DailySync] Found {len(codes)} stock codes from daily data")
            return codes

        print("[DailySync] No stock codes found in local storage")
        return []

    def get_global_last_date(self) -> Optional[str]:
        """Get the latest trade date across all local daily data.

        Returns:
            Latest trade date string, or None when no data exists.
        """
        daily_data = self._load_daily_data()
        if daily_data.empty or "trade_date" not in daily_data.columns:
            return None
        return str(daily_data["trade_date"].max())

    def get_global_first_date(self) -> Optional[str]:
        """Get the earliest trade date across all local daily data.

        Returns:
            Earliest trade date string, or None when no data exists.
        """
        daily_data = self._load_daily_data()
        if daily_data.empty or "trade_date" not in daily_data.columns:
            return None
        return str(daily_data["trade_date"].min())

    def get_trade_calendar_bounds(
        self, start_date: str, end_date: str
    ) -> tuple[Optional[str], Optional[str]]:
        """Get the first and last trading day from the trade calendar.

        Args:
            start_date: Start date (YYYYMMDD format).
            end_date: End date (YYYYMMDD format).

        Returns:
            (first trading day, last trading day) tuple,
            or (None, None) on error.
        """
        try:
            first_day = get_first_trading_day(start_date, end_date)
            last_day = get_last_trading_day(start_date, end_date)
            return (first_day, last_day)
        except Exception as e:
            print(f"[ERROR] Failed to get trade calendar bounds: {e}")
            return (None, None)

    def check_sync_needed(
        self,
        force_full: bool = False,
        table_name: str = "daily",
    ) -> tuple[bool, Optional[str], Optional[str], Optional[str]]:
        """Check whether a sync is needed based on the trade calendar.

        Compares the local data date range against the trade calendar
        to decide whether new data must be fetched.

        Logic:
        - If force_full: sync needed, returns (True, 20180101, today)
        - If no local data: sync needed, returns (True, 20180101, today)
        - If local data exists:
          - Get the last trading day from the trade calendar
          - If local last date >= calendar last date: no sync needed
          - Otherwise: sync from local last date + 1 to the latest trading day

        Args:
            force_full: If True, always report that a sync is needed.
            table_name: Table to check (default: "daily").

        Returns:
            (sync_needed, start_date, end_date, local_last_date)
            - sync_needed: True when a sync should proceed
            - start_date: sync start date (None when no sync needed)
            - end_date: sync end date (None when no sync needed)
            - local_last_date: last date in local data (for incremental sync)
        """
        # force_full always syncs
        if force_full:
            print("[DailySync] Force full sync requested")
            return (True, DEFAULT_START_DATE,
get_today_date(), None) # 检查特定表的本地数据是否存在 storage = Storage() table_data = ( storage.load(table_name) if storage.exists(table_name) else pd.DataFrame() ) if table_data.empty or "trade_date" not in table_data.columns: print( f"[DailySync] No local data found for table '{table_name}', full sync needed" ) return (True, DEFAULT_START_DATE, get_today_date(), None) # 获取本地数据最后日期 local_last_date = str(table_data["trade_date"].max()) print(f"[DailySync] Local data last date: {local_last_date}") # 从交易日历获取最新交易日 today = get_today_date() _, cal_last = self.get_trade_calendar_bounds(DEFAULT_START_DATE, today) if cal_last is None: print("[DailySync] Failed to get trade calendar, proceeding with sync") return (True, DEFAULT_START_DATE, today, local_last_date) print(f"[DailySync] Calendar last trading day: {cal_last}") # 比较本地最后日期与日历最后日期 print( f"[DailySync] Comparing: local={local_last_date} (type={type(local_last_date).__name__}), " f"cal={cal_last} (type={type(cal_last).__name__})" ) try: local_last_int = int(local_last_date) cal_last_int = int(cal_last) print( f"[DailySync] Comparing integers: local={local_last_int} >= cal={cal_last_int} = " f"{local_last_int >= cal_last_int}" ) if local_last_int >= cal_last_int: print( "[DailySync] Local data is up-to-date, SKIPPING sync (no tokens consumed)" ) return (False, None, None, None) except (ValueError, TypeError) as e: print(f"[ERROR] Date comparison failed: {e}") # 需要从本地最后日期+1 同步到最新交易日 sync_start = get_next_date(local_last_date) print(f"[DailySync] Incremental sync needed from {sync_start} to {cal_last}") return (True, sync_start, cal_last, local_last_date) def preview_sync( self, force_full: bool = False, start_date: Optional[str] = None, end_date: Optional[str] = None, sample_size: int = 3, ) -> dict: """预览同步数据量和样本(不实际同步)。 该方法提供即将同步的数据的预览,包括: - 将同步的股票数量 - 同步日期范围 - 预估总记录数 - 前几只股票的样本数据 Args: force_full: 若为 True,预览全量同步(从 20180101) start_date: 手动指定起始日期(覆盖自动检测) end_date: 手动指定结束日期(默认为今天) sample_size: 预览用样本股票数量(默认: 3) Returns: 包含预览信息的字典: 
{ 'sync_needed': bool, 'stock_count': int, 'start_date': str, 'end_date': str, 'estimated_records': int, 'sample_data': pd.DataFrame, 'mode': str, # 'full' 或 'incremental' } """ print("\n" + "=" * 60) print("[DailySync] Preview Mode - Analyzing sync requirements...") print("=" * 60) # 首先确保交易日历缓存是最新的 print("[DailySync] Syncing trade calendar cache...") sync_trade_cal_cache() # 确定日期范围 if end_date is None: end_date = get_today_date() # 检查是否需要同步 sync_needed, cal_start, cal_end, local_last = self.check_sync_needed(force_full) if not sync_needed: print("\n" + "=" * 60) print("[DailySync] Preview Result") print("=" * 60) print(" Sync Status: NOT NEEDED") print(" Reason: Local data is up-to-date with trade calendar") print("=" * 60) return { "sync_needed": False, "stock_count": 0, "start_date": None, "end_date": None, "estimated_records": 0, "sample_data": pd.DataFrame(), "mode": "none", } # 使用 check_sync_needed 返回的日期 if cal_start and cal_end: sync_start_date = cal_start end_date = cal_end else: sync_start_date = start_date or DEFAULT_START_DATE if end_date is None: end_date = get_today_date() # 确定同步模式 if force_full: mode = "full" print(f"[DailySync] Mode: FULL SYNC from {sync_start_date} to {end_date}") elif local_last and cal_start and sync_start_date == get_next_date(local_last): mode = "incremental" print(f"[DailySync] Mode: INCREmental SYNC (bandwidth optimized)") print(f"[DailySync] Sync from: {sync_start_date} to {end_date}") else: mode = "partial" print(f"[DailySync] Mode: SYNC from {sync_start_date} to {end_date}") # 获取所有股票代码 stock_codes = self.get_all_stock_codes() if not stock_codes: print("[DailySync] No stocks found to sync") return { "sync_needed": False, "stock_count": 0, "start_date": None, "end_date": None, "estimated_records": 0, "sample_data": pd.DataFrame(), "mode": "none", } stock_count = len(stock_codes) print(f"[DailySync] Total stocks to sync: {stock_count}") # 从前几只股票获取样本数据 print(f"[DailySync] Fetching sample data from {sample_size} stocks...") 
sample_data_list = [] sample_codes = stock_codes[:sample_size] for ts_code in sample_codes: try: data = self.client.query( "pro_bar", ts_code=ts_code, start_date=sync_start_date, end_date=end_date, factors="tor,vr", ) if not data.empty: sample_data_list.append(data) print(f" - {ts_code}: {len(data)} records") except Exception as e: print(f" - {ts_code}: Error fetching - {e}") # 合并样本数据 sample_df = ( pd.concat(sample_data_list, ignore_index=True) if sample_data_list else pd.DataFrame() ) # 基于样本估算总记录数 if not sample_df.empty: avg_records_per_stock = len(sample_df) / len(sample_data_list) estimated_records = int(avg_records_per_stock * stock_count) else: estimated_records = 0 # 显示预览结果 print("\n" + "=" * 60) print("[DailySync] Preview Result") print("=" * 60) print(f" Sync Mode: {mode.upper()}") print(f" Date Range: {sync_start_date} to {end_date}") print(f" Stocks to Sync: {stock_count}") print(f" Sample Stocks Checked: {len(sample_data_list)}/{sample_size}") print(f" Estimated Total Records: ~{estimated_records:,}") if not sample_df.empty: print(f"\n Sample Data Preview (first {len(sample_df)} rows):") print(" " + "-" * 56) # 以紧凑格式显示样本数据 preview_cols = [ "ts_code", "trade_date", "open", "high", "low", "close", "vol", ] available_cols = [c for c in preview_cols if c in sample_df.columns] sample_display = sample_df[available_cols].head(10) for idx, row in sample_display.iterrows(): print(f" {row.to_dict()}") print(" " + "-" * 56) print("=" * 60) return { "sync_needed": True, "stock_count": stock_count, "start_date": sync_start_date, "end_date": end_date, "estimated_records": estimated_records, "sample_data": sample_df, "mode": mode, } def sync_single_stock( self, ts_code: str, start_date: str, end_date: str, ) -> pd.DataFrame: """同步单只股票的日线数据。 Args: ts_code: 股票代码 start_date: 起始日期(YYYYMMDD) end_date: 结束日期(YYYYMMDD) Returns: 包含日线市场数据的 DataFrame """ # 检查是否应该停止同步(用于异常处理) if not self._stop_flag.is_set(): return pd.DataFrame() try: # 使用共享客户端进行跨线程速率限制 data = self.client.query( 
"pro_bar", ts_code=ts_code, start_date=start_date, end_date=end_date, factors="tor,vr", ) return data except Exception as e: # 设置停止标志以通知其他线程停止 self._stop_flag.clear() print(f"[ERROR] Exception syncing {ts_code}: {e}") raise def sync_all( self, force_full: bool = False, start_date: Optional[str] = None, end_date: Optional[str] = None, max_workers: Optional[int] = None, dry_run: bool = False, ) -> Dict[str, pd.DataFrame]: """同步本地存储中所有股票的日线数据。 该函数: 1. 从本地存储读取股票代码(daily 或 stock_basic) 2. 检查交易日历确定是否需要同步: - 若本地数据匹配交易日历边界,则跳过同步(节省 token) - 否则,从本地最后日期+1 同步到最新交易日(带宽优化) 3. 使用多线程并发获取(带速率限制) 4. 跳过返回空数据的股票(退市/不可用) 5. 遇异常立即停止 Args: force_full: 若为 True,强制从 20180101 完整重载 start_date: 手动指定起始日期(覆盖自动检测) end_date: 手动指定结束日期(默认为今天) max_workers: 工作线程数(默认: 10) dry_run: 若为 True,仅预览将要同步的内容,不写入数据 Returns: 映射 ts_code 到 DataFrame 的字典(若跳过或 dry_run 则为空字典) """ print("\n" + "=" * 60) print("[DailySync] Starting daily data sync...") print("=" * 60) # 首先确保交易日历缓存是最新的(使用增量同步) print("[DailySync] Syncing trade calendar cache...") sync_trade_cal_cache() # 确定日期范围 if end_date is None: end_date = get_today_date() # 基于交易日历检查是否需要同步 sync_needed, cal_start, cal_end, local_last = self.check_sync_needed(force_full) if not sync_needed: # 跳过同步 - 不消耗 token print("\n" + "=" * 60) print("[DailySync] Sync Summary") print("=" * 60) print(" Sync: SKIPPED (local data up-to-date with trade calendar)") print(" Tokens saved: 0 consumed") print("=" * 60) return {} # 使用 check_sync_needed 返回的日期(会计算增量起始日期) if cal_start and cal_end: sync_start_date = cal_start end_date = cal_end else: # 回退到默认逻辑 sync_start_date = start_date or DEFAULT_START_DATE if end_date is None: end_date = get_today_date() # 确定同步模式 if force_full: mode = "full" print(f"[DailySync] Mode: FULL SYNC from {sync_start_date} to {end_date}") elif local_last and cal_start and sync_start_date == get_next_date(local_last): mode = "incremental" print(f"[DailySync] Mode: INCREMENTAL SYNC (bandwidth optimized)") print(f"[DailySync] Sync from: {sync_start_date} to 
{end_date}") else: mode = "partial" print(f"[DailySync] Mode: SYNC from {sync_start_date} to {end_date}") # 获取所有股票代码 stock_codes = self.get_all_stock_codes() if not stock_codes: print("[DailySync] No stocks found to sync") return {} print(f"[DailySync] Total stocks to sync: {len(stock_codes)}") print(f"[DailySync] Using {max_workers or self.max_workers} worker threads") # 处理 dry run 模式 if dry_run: print("\n" + "=" * 60) print("[DailySync] DRY RUN MODE - No data will be written") print("=" * 60) print(f" Would sync {len(stock_codes)} stocks") print(f" Date range: {sync_start_date} to {end_date}") print(f" Mode: {mode}") print("=" * 60) return {} # 为新同步重置停止标志 self._stop_flag.set() # 多线程并发获取 results: Dict[str, pd.DataFrame] = {} error_occurred = False exception_to_raise = None def sync_task(ts_code: str) -> tuple[str, pd.DataFrame]: """每只股票的任务函数。""" try: data = self.sync_single_stock( ts_code=ts_code, start_date=sync_start_date, end_date=end_date, ) return (ts_code, data) except Exception as e: # 重新抛出以被 Future 捕获 raise # 使用 ThreadPoolExecutor 进行并发获取 workers = max_workers or self.max_workers with ThreadPoolExecutor(max_workers=workers) as executor: # 提交所有任务并跟踪 futures 与股票代码的映射 future_to_code = { executor.submit(sync_task, ts_code): ts_code for ts_code in stock_codes } # 使用 as_completed 处理结果 error_count = 0 empty_count = 0 success_count = 0 # 创建进度条 pbar = tqdm(total=len(stock_codes), desc="Syncing stocks") try: # 处理完成的 futures for future in as_completed(future_to_code): ts_code = future_to_code[future] try: _, data = future.result() if data is not None and not data.empty: results[ts_code] = data success_count += 1 else: # 空数据 - 股票可能已退市或不可用 empty_count += 1 print( f"[DailySync] Stock {ts_code}: empty data (skipped, may be delisted)" ) except Exception as e: # 发生异常 - 停止全部并中止 error_occurred = True exception_to_raise = e print(f"\n[ERROR] Sync aborted due to exception: {e}") # 关闭 executor 以停止所有待处理任务 executor.shutdown(wait=False, cancel_futures=True) raise exception_to_raise 
                    # Advance the progress bar once per completed future
                    pbar.update(1)
            except Exception:
                # Abort path: the re-raised exception from the loop lands here
                error_count = 1
                print("[DailySync] Sync stopped due to exception")
            finally:
                pbar.close()

        # Batch-write all data (only when no error occurred)
        if results and not error_occurred:
            for ts_code, data in results.items():
                if not data.empty:
                    self.storage.queue_save("daily", data)

            # Flush all queued writes in one pass
            self.storage.flush()
            total_rows = sum(len(df) for df in results.values())
            print(f"\n[DailySync] Saved {total_rows} rows to storage")

        # Summary
        print("\n" + "=" * 60)
        print("[DailySync] Sync Summary")
        print("=" * 60)
        print(f"  Total stocks: {len(stock_codes)}")
        print(f"  Updated: {success_count}")
        print(f"  Skipped (empty/delisted): {empty_count}")
        print(
            f"  Errors: {error_count} (aborted on first error)"
            if error_count
            else "  Errors: 0"
        )
        print(f"  Date range: {sync_start_date} to {end_date}")
        print("=" * 60)

        return results


def sync_daily(
    force_full: bool = False,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
    max_workers: Optional[int] = None,
    dry_run: bool = False,
) -> Dict[str, pd.DataFrame]:
    """Sync daily data for all stocks.

    This is the main entry point for daily data synchronization.

    Args:
        force_full: If True, force a full reload from 20180101.
        start_date: Manually specified start date (YYYYMMDD).
        end_date: Manually specified end date (defaults to today).
        max_workers: Number of worker threads (default: 10).
        dry_run: If True, only preview what would be synced; write nothing.

    Returns:
        Dict mapping ts_code to DataFrame.

    Example:
        >>> # First sync (full load from 20180101)
        >>> result = sync_daily()
        >>>
        >>> # Subsequent syncs (incremental - new data only)
        >>> result = sync_daily()
        >>>
        >>> # Force a complete reload
        >>> result = sync_daily(force_full=True)
        >>>
        >>> # Manually specified date range
        >>> result = sync_daily(start_date='20240101', end_date='20240131')
        >>>
        >>> # Custom thread count
        >>> result = sync_daily(max_workers=20)
        >>>
        >>> # Dry run (preview only)
        >>> result = sync_daily(dry_run=True)
    """
    sync_manager = DailySync(max_workers=max_workers)
    return sync_manager.sync_all(
        force_full=force_full,
        start_date=start_date,
        end_date=end_date,
        dry_run=dry_run,
    )


def preview_daily_sync(
    force_full: bool = False,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
    sample_size: int = 3,
) -> dict:
    """Preview daily sync volume and sample data (does not actually sync).

    This is the recommended way to inspect what would be synced before
    actually syncing.

    Args:
        force_full: If True, preview a full sync (from 20180101).
        start_date: Manually specified start date (overrides auto-detection).
        end_date: Manually specified end date (defaults to today).
        sample_size: Number of sample stocks for the preview (default: 3).

    Returns:
        Dict with preview information:
        {
            'sync_needed': bool,
            'stock_count': int,
            'start_date': str,
            'end_date': str,
            'estimated_records': int,
            'sample_data': pd.DataFrame,
            'mode': str,  # 'full', 'incremental', 'partial', or 'none'
        }

    Example:
        >>> # Preview what would be synced
        >>> preview = preview_daily_sync()
        >>>
        >>> # Preview a full sync
        >>> preview = preview_daily_sync(force_full=True)
        >>>
        >>> # Preview with more samples
        >>> preview = preview_daily_sync(sample_size=5)
    """
    sync_manager = DailySync()
    return sync_manager.preview_sync(
        force_full=force_full,
        start_date=start_date,
        end_date=end_date,
        sample_size=sample_size,
    )