From 484bcd0ab70c1fb96d3d24996181dae90b1f0793 Mon Sep 17 00:00:00 2001 From: liaozhaorun <1300336796@qq.com> Date: Fri, 27 Feb 2026 23:34:12 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20=E6=8F=90=E5=8F=96=E6=95=B0?= =?UTF-8?q?=E6=8D=AE=E5=90=8C=E6=AD=A5=E9=80=BB=E8=BE=91=E4=B8=BA=E6=8A=BD?= =?UTF-8?q?=E8=B1=A1=E5=9F=BA=E7=B1=BB=20=E6=96=B0=E5=A2=9E=20base=5Fsync.?= =?UTF-8?q?py=20=E6=A8=A1=E5=9D=97=EF=BC=8C=E6=8F=90=E4=BE=9B=E4=B8=89?= =?UTF-8?q?=E5=B1=82=E6=8A=BD=E8=B1=A1=E7=BB=93=E6=9E=84=E7=BB=9F=E4=B8=80?= =?UTF-8?q?=E6=95=B0=E6=8D=AE=E5=90=8C=E6=AD=A5=E6=B5=81=E7=A8=8B=EF=BC=9A?= =?UTF-8?q?=20-=20BaseDataSync:=20=E6=89=80=E6=9C=89=E5=90=8C=E6=AD=A5?= =?UTF-8?q?=E7=B1=BB=E5=9E=8B=E7=9A=84=E5=9F=BA=E7=A1=80=E6=8A=BD=E8=B1=A1?= =?UTF-8?q?=EF=BC=88=E5=AE=A2=E6=88=B7=E7=AB=AF=E3=80=81=E8=82=A1=E7=A5=A8?= =?UTF-8?q?=E4=BB=A3=E7=A0=81=E8=8E=B7=E5=8F=96=E3=80=81=E4=BA=A4=E6=98=93?= =?UTF-8?q?=E6=97=A5=E5=8E=86=EF=BC=89=20-=20StockBasedSync:=20=E6=8C=89?= =?UTF-8?q?=E8=82=A1=E7=A5=A8=E5=90=8C=E6=AD=A5=E6=8A=BD=E8=B1=A1=E7=B1=BB?= =?UTF-8?q?=EF=BC=88=E9=80=82=E7=94=A8=E4=BA=8E=20daily,=20pro=5Fbar?= =?UTF-8?q?=EF=BC=89=20-=20DateBasedSync:=20=E6=8C=89=E6=97=A5=E6=9C=9F?= =?UTF-8?q?=E5=90=8C=E6=AD=A5=E6=8A=BD=E8=B1=A1=E7=B1=BB=EF=BC=88=E9=80=82?= =?UTF-8?q?=E7=94=A8=E4=BA=8E=20bak=5Fbasic=EF=BC=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/data/api_wrappers/api_bak_basic.py | 188 +--- src/data/api_wrappers/api_daily.py | 657 +------------- src/data/api_wrappers/api_pro_bar.py | 657 +------------- src/data/api_wrappers/base_sync.py | 1137 ++++++++++++++++++++++++ src/data/storage.py | 3 + src/data/sync.py | 161 +--- 6 files changed, 1255 insertions(+), 1548 deletions(-) create mode 100644 src/data/api_wrappers/base_sync.py diff --git a/src/data/api_wrappers/api_bak_basic.py b/src/data/api_wrappers/api_bak_basic.py index a06b166..f56e819 100644 --- a/src/data/api_wrappers/api_bak_basic.py +++ b/src/data/api_wrappers/api_bak_basic.py @@ -5,12 +5,10 @@ Data available from 2016 onwards. """ import pandas as pd -from typing import Optional, List -from datetime import datetime, timedelta -from tqdm import tqdm +from typing import Optional + from src.data.client import TushareClient -from src.data.storage import ThreadSafeStorage, Storage -from src.data.db_manager import ensure_table +from src.data.api_wrappers.base_sync import DateBasedSync def get_bak_basic( @@ -75,6 +73,34 @@ def get_bak_basic( return data +class BakBasicSync(DateBasedSync): + """历史股票列表批量同步管理器,支持全量/增量同步。 + + 继承自 DateBasedSync,按日期顺序获取数据。 + 数据从 2016 年开始可用。 + + Example: + >>> sync = BakBasicSync() + >>> results = sync.sync_all() # 增量同步 + >>> results = sync.sync_all(force_full=True) # 全量同步 + >>> preview = sync.preview_sync() # 预览 + """ + + table_name = "bak_basic" + default_start_date = "20160101" + + def fetch_single_date(self, trade_date: str) -> pd.DataFrame: + """获取单日的历史股票列表数据。 + + Args: + trade_date: 交易日期(YYYYMMDD) + + Returns: + 包含当日所有股票数据的 DataFrame + """ + return get_bak_basic(trade_date=trade_date) + + def sync_bak_basic( start_date: Optional[str] = None, end_date: Optional[str] = None, @@ -94,152 +120,12 @@ def sync_bak_basic( Returns: pd.DataFrame with synced data """ - from src.data.db_manager import ensure_table - - TABLE_NAME = "bak_basic" - storage = Storage() - thread_storage = ThreadSafeStorage() - - # Default end date - if end_date is None: - end_date = datetime.now().strftime("%Y%m%d") - - # Check if table exists - table_exists = storage.exists(TABLE_NAME) - - if not table_exists or force_full: - # ===== FULL SYNC ===== - # 1. Create table with schema - # 2. Create composite index (trade_date, ts_code) - # 3. Full sync from start_date - - if not table_exists: - print(f"[sync_bak_basic] Table '{TABLE_NAME}' doesn't exist, creating...") - - # Fetch sample to get schema - sample = get_bak_basic(trade_date=end_date) - if sample.empty: - sample = get_bak_basic(trade_date="20240102") - - if sample.empty: - print("[sync_bak_basic] Cannot create table: no sample data available") - return pd.DataFrame() - - # Create table with schema - columns = [] - for col in sample.columns: - dtype = str(sample[col].dtype) - if col == "trade_date": - col_type = "DATE" - elif "int" in dtype: - col_type = "INTEGER" - elif "float" in dtype: - col_type = "DOUBLE" - else: - col_type = "VARCHAR" - columns.append(f'"{col}" {col_type}') - - columns_sql = ", ".join(columns) - create_sql = f'CREATE TABLE IF NOT EXISTS "{TABLE_NAME}" ({columns_sql}, PRIMARY KEY ("trade_date", "ts_code"))' - - try: - storage._connection.execute(create_sql) - print(f"[sync_bak_basic] Created table '{TABLE_NAME}'") - except Exception as e: - print(f"[sync_bak_basic] Error creating table: {e}") - - # Create composite index - try: - storage._connection.execute(f""" - CREATE INDEX IF NOT EXISTS "idx_bak_basic_date_code" - ON "{TABLE_NAME}"("trade_date", "ts_code") - """) - print(f"[sync_bak_basic] Created composite index on (trade_date, ts_code)") - except Exception as e: - print(f"[sync_bak_basic] Error creating index: {e}") - - # Determine sync dates - sync_start = start_date or "20160101" - mode = "FULL" - print(f"[sync_bak_basic] Mode: {mode} SYNC from {sync_start} to {end_date}") - - else: - # ===== INCREMENTAL SYNC ===== - # Check last date in table, sync from last_date + 1 - - try: - result = storage._connection.execute( - f'SELECT MAX("trade_date") FROM "{TABLE_NAME}"' - ).fetchone() - last_date = result[0] if result and result[0] else None - except Exception as e: - print(f"[sync_bak_basic] Error getting last date: {e}") - last_date = None - - if last_date is None: - # Table exists but empty, do full sync - sync_start = start_date or "20160101" - mode = "FULL (empty table)" - else: - # Incremental from last_date + 1 - # Handle both YYYYMMDD and YYYY-MM-DD formats - last_date_str = str(last_date).replace("-", "") - last_dt = datetime.strptime(last_date_str, "%Y%m%d") - next_dt = last_dt + timedelta(days=1) - sync_start = next_dt.strftime("%Y%m%d") - mode = "INCREMENTAL" - - # Skip if already up to date - if sync_start > end_date: - print(f"[sync_bak_basic] Data is up-to-date (last: {last_date}), skipping sync") - return pd.DataFrame() - - print(f"[sync_bak_basic] Mode: {mode} from {sync_start} to {end_date} (last: {last_date})") - - # ===== FETCH AND SAVE DATA ===== - all_data: List[pd.DataFrame] = [] - current = datetime.strptime(sync_start, "%Y%m%d") - end_dt = datetime.strptime(end_date, "%Y%m%d") - - # Calculate total days for progress bar - total_days = (end_dt - current).days + 1 - print(f"[sync_bak_basic] Fetching data for {total_days} days...") - - with tqdm(total=total_days, desc="Syncing dates") as pbar: - while current <= end_dt: - date_str = current.strftime("%Y%m%d") - try: - data = get_bak_basic(trade_date=date_str) - if not data.empty: - all_data.append(data) - pbar.set_postfix({"date": date_str, "records": len(data)}) - except Exception as e: - print(f" {date_str}: ERROR - {e}") - - current += timedelta(days=1) - pbar.update(1) - - if not all_data: - print("[sync_bak_basic] No data fetched") - return pd.DataFrame() - - # Combine and save - combined = pd.concat(all_data, ignore_index=True) - - # Convert trade_date to datetime for proper DATE type storage - combined["trade_date"] = pd.to_datetime(combined["trade_date"], format="%Y%m%d") - - print(f"[sync_bak_basic] Total records: {len(combined)}") - - # Delete existing data for the date range and append new data - # Convert sync_start to date format for comparison with DATE column - sync_start_date = pd.to_datetime(sync_start, format="%Y%m%d").date() - storage._connection.execute(f'DELETE FROM "{TABLE_NAME}" WHERE "trade_date" >= ?', [sync_start_date]) - thread_storage.queue_save(TABLE_NAME, combined) - thread_storage.flush() - - print(f"[sync_bak_basic] Saved {len(combined)} records to DuckDB") - return combined + sync_manager = BakBasicSync() + return sync_manager.sync_all( + start_date=start_date, + end_date=end_date, + force_full=force_full, + ) if __name__ == "__main__": diff --git a/src/data/api_wrappers/api_daily.py b/src/data/api_wrappers/api_daily.py index 96b4929..2cbfebf 100644 --- a/src/data/api_wrappers/api_daily.py +++ b/src/data/api_wrappers/api_daily.py @@ -9,21 +9,9 @@ batch synchronization (DailySync class) for daily market data. import pandas as pd from typing import Optional, List, Literal, Dict -from datetime import datetime, timedelta -from tqdm import tqdm -from concurrent.futures import ThreadPoolExecutor, as_completed -import threading from src.data.client import TushareClient -from src.data.storage import ThreadSafeStorage, Storage -from src.data.utils import get_today_date, get_next_date, DEFAULT_START_DATE -from src.config.settings import get_settings -from src.data.api_wrappers.api_trade_cal import ( - get_first_trading_day, - get_last_trading_day, - sync_trade_cal_cache, -) -from src.data.api_wrappers.api_stock_basic import _get_csv_path, sync_all_stocks +from src.data.api_wrappers.base_sync import StockBasedSync def get_daily( @@ -90,426 +78,27 @@ def get_daily( return data -# ============================================================================= -# DailySync - 日线数据批量同步类 -# ============================================================================= - - -class DailySync: +class DailySync(StockBasedSync): """日线数据批量同步管理器,支持全量/增量同步。 - 功能特性: - - 多线程并发获取(ThreadPoolExecutor) - - 增量同步(自动检测上次同步位置) - - 内存缓存(避免重复磁盘读取) - - 异常立即停止(确保数据一致性) - - 预览模式(预览同步数据量,不实际写入) + 继承自 StockBasedSync,使用多线程按股票并发获取数据。 + + Example: + >>> sync = DailySync() + >>> results = sync.sync_all() # 增量同步 + >>> results = sync.sync_all(force_full=True) # 全量同步 + >>> preview = sync.preview_sync() # 预览 """ - # 默认工作线程数(从配置读取,默认10) - DEFAULT_MAX_WORKERS = get_settings().threads + table_name = "daily" - def __init__(self, max_workers: Optional[int] = None): - """初始化同步管理器。 - - Args: - max_workers: 工作线程数(默认从配置读取) - """ - self.client = TushareClient() - self.max_workers = max_workers or self.DEFAULT_MAX_WORKERS - self._stop_flag = threading.Event() - self._stop_flag.set() # 初始为未停止状态 - self._cached_daily_data: Optional[pd.DataFrame] = None # 日线数据缓存 - - def _load_daily_data(self) -> pd.DataFrame: - """从存储加载日线数据(带缓存)。 - - 该方法会将数据缓存在内存中以避免重复磁盘读取。 - 调用 clear_cache() 可强制重新加载。 - - Returns: - 缓存或从存储加载的日线数据 DataFrame - """ - if self._cached_daily_data is None: - self._cached_daily_data = self.storage.load("daily") - return self._cached_daily_data - - def clear_cache(self) -> None: - """清除缓存的日线数据,强制下次访问时重新加载。""" - self._cached_daily_data = None - - def get_all_stock_codes(self, only_listed: bool = True) -> list: - """从本地存储获取所有股票代码。 - - 优先使用 stock_basic.csv 以确保包含所有股票, - 避免回测中的前视偏差。 - - Args: - only_listed: 若为 True,仅返回当前上市股票(L 状态)。 - 设为 False 可包含退市股票(用于完整回测)。 - - Returns: - 股票代码列表 - """ - # 确保 stock_basic.csv 是最新的 - print("[DailySync] Ensuring stock_basic.csv is up-to-date...") - sync_all_stocks() - - # 从 stock_basic.csv 文件获取 - stock_csv_path = _get_csv_path() - - if stock_csv_path.exists(): - print(f"[DailySync] Reading stock_basic from CSV: {stock_csv_path}") - try: - stock_df = pd.read_csv(stock_csv_path, encoding="utf-8-sig") - if not stock_df.empty and "ts_code" in stock_df.columns: - # 根据 list_status 过滤 - if only_listed and "list_status" in stock_df.columns: - listed_stocks = stock_df[stock_df["list_status"] == "L"] - codes = listed_stocks["ts_code"].unique().tolist() - total = len(stock_df["ts_code"].unique()) - print( - f"[DailySync] Found {len(codes)} listed stocks (filtered from {total} total)" - ) - else: - codes = stock_df["ts_code"].unique().tolist() - print( - f"[DailySync] Found {len(codes)} stock codes from stock_basic.csv" - ) - return codes - else: - print( - f"[DailySync] stock_basic.csv exists but no ts_code column or empty" - ) - except Exception as e: - print(f"[DailySync] Error reading stock_basic.csv: {e}") - - # 回退:从日线存储获取 - print( - "[DailySync] stock_basic.csv not available, falling back to daily data..." - ) - daily_data = self._load_daily_data() - if not daily_data.empty and "ts_code" in daily_data.columns: - codes = daily_data["ts_code"].unique().tolist() - print(f"[DailySync] Found {len(codes)} stock codes from daily data") - return codes - - print("[DailySync] No stock codes found in local storage") - return [] - - def get_global_last_date(self) -> Optional[str]: - """获取全局最后交易日期。 - - Returns: - 最后交易日期字符串,若无数据则返回 None - """ - daily_data = self._load_daily_data() - if daily_data.empty or "trade_date" not in daily_data.columns: - return None - return str(daily_data["trade_date"].max()) - - def get_global_first_date(self) -> Optional[str]: - """获取全局最早交易日期。 - - Returns: - 最早交易日期字符串,若无数据则返回 None - """ - daily_data = self._load_daily_data() - if daily_data.empty or "trade_date" not in daily_data.columns: - return None - return str(daily_data["trade_date"].min()) - - def get_trade_calendar_bounds( - self, start_date: str, end_date: str - ) -> tuple[Optional[str], Optional[str]]: - """从交易日历获取首尾交易日。 - - Args: - start_date: 开始日期(YYYYMMDD 格式) - end_date: 结束日期(YYYYMMDD 格式) - - Returns: - (首交易日, 尾交易日) 元组,若出错则返回 (None, None) - """ - try: - first_day = get_first_trading_day(start_date, end_date) - last_day = get_last_trading_day(start_date, end_date) - return (first_day, last_day) - except Exception as e: - print(f"[ERROR] Failed to get trade calendar bounds: {e}") - return (None, None) - - def check_sync_needed( - self, - force_full: bool = False, - table_name: str = "daily", - ) -> tuple[bool, Optional[str], Optional[str], Optional[str]]: - """基于交易日历检查是否需要同步。 - - 该方法比较本地数据日期范围与交易日历, - 以确定是否需要获取新数据。 - - 逻辑: - - 若 force_full:需要同步,返回 (True, 20180101, today) - - 若无本地数据:需要同步,返回 (True, 20180101, today) - - 若存在本地数据: - - 从交易日历获取最后交易日 - - 若本地最后日期 >= 日历最后日期:无需同步 - - 否则:从本地最后日期+1 到最新交易日同步 - - Args: - force_full: 若为 True,始终返回需要同步 - table_name: 要检查的表名(默认: "daily") - - Returns: - (需要同步, 起始日期, 结束日期, 本地最后日期) - - 需要同步: True 表示应继续同步 - - 起始日期: 同步起始日期(无需同步时为 None) - - 结束日期: 同步结束日期(无需同步时为 None) - - 本地最后日期: 本地数据最后日期(用于增量同步) - """ - # 若 force_full,始终同步 - if force_full: - print("[DailySync] Force full sync requested") - return (True, DEFAULT_START_DATE, get_today_date(), None) - - # 检查特定表的本地数据是否存在 - storage = Storage() - table_data = ( - storage.load(table_name) if storage.exists(table_name) else pd.DataFrame() - ) - - if table_data.empty or "trade_date" not in table_data.columns: - print( - f"[DailySync] No local data found for table '{table_name}', full sync needed" - ) - return (True, DEFAULT_START_DATE, get_today_date(), None) - - # 获取本地数据最后日期 - local_last_date = str(table_data["trade_date"].max()) - - print(f"[DailySync] Local data last date: {local_last_date}") - - # 从交易日历获取最新交易日 - today = get_today_date() - _, cal_last = self.get_trade_calendar_bounds(DEFAULT_START_DATE, today) - - if cal_last is None: - print("[DailySync] Failed to get trade calendar, proceeding with sync") - return (True, DEFAULT_START_DATE, today, local_last_date) - - print(f"[DailySync] Calendar last trading day: {cal_last}") - - # 比较本地最后日期与日历最后日期 - print( - f"[DailySync] Comparing: local={local_last_date} (type={type(local_last_date).__name__}), " - f"cal={cal_last} (type={type(cal_last).__name__})" - ) - try: - local_last_int = int(local_last_date) - cal_last_int = int(cal_last) - print( - f"[DailySync] Comparing integers: local={local_last_int} >= cal={cal_last_int} = " - f"{local_last_int >= cal_last_int}" - ) - if local_last_int >= cal_last_int: - print( - "[DailySync] Local data is up-to-date, SKIPPING sync (no tokens consumed)" - ) - return (False, None, None, None) - except (ValueError, TypeError) as e: - print(f"[ERROR] Date comparison failed: {e}") - - # 需要从本地最后日期+1 同步到最新交易日 - sync_start = get_next_date(local_last_date) - print(f"[DailySync] Incremental sync needed from {sync_start} to {cal_last}") - return (True, sync_start, cal_last, local_last_date) - - def preview_sync( - self, - force_full: bool = False, - start_date: Optional[str] = None, - end_date: Optional[str] = None, - sample_size: int = 3, - ) -> dict: - """预览同步数据量和样本(不实际同步)。 - - 该方法提供即将同步的数据的预览,包括: - - 将同步的股票数量 - - 同步日期范围 - - 预估总记录数 - - 前几只股票的样本数据 - - Args: - force_full: 若为 True,预览全量同步(从 20180101) - start_date: 手动指定起始日期(覆盖自动检测) - end_date: 手动指定结束日期(默认为今天) - sample_size: 预览用样本股票数量(默认: 3) - - Returns: - 包含预览信息的字典: - { - 'sync_needed': bool, - 'stock_count': int, - 'start_date': str, - 'end_date': str, - 'estimated_records': int, - 'sample_data': pd.DataFrame, - 'mode': str, # 'full' 或 'incremental' - } - """ - print("\n" + "=" * 60) - print("[DailySync] Preview Mode - Analyzing sync requirements...") - print("=" * 60) - - # 首先确保交易日历缓存是最新的 - print("[DailySync] Syncing trade calendar cache...") - sync_trade_cal_cache() - - # 确定日期范围 - if end_date is None: - end_date = get_today_date() - - # 检查是否需要同步 - sync_needed, cal_start, cal_end, local_last = self.check_sync_needed(force_full) - - if not sync_needed: - print("\n" + "=" * 60) - print("[DailySync] Preview Result") - print("=" * 60) - print(" Sync Status: NOT NEEDED") - print(" Reason: Local data is up-to-date with trade calendar") - print("=" * 60) - return { - "sync_needed": False, - "stock_count": 0, - "start_date": None, - "end_date": None, - "estimated_records": 0, - "sample_data": pd.DataFrame(), - "mode": "none", - } - - # 使用 check_sync_needed 返回的日期 - if cal_start and cal_end: - sync_start_date = cal_start - end_date = cal_end - else: - sync_start_date = start_date or DEFAULT_START_DATE - if end_date is None: - end_date = get_today_date() - - # 确定同步模式 - if force_full: - mode = "full" - print(f"[DailySync] Mode: FULL SYNC from {sync_start_date} to {end_date}") - elif local_last and cal_start and sync_start_date == get_next_date(local_last): - mode = "incremental" - print(f"[DailySync] Mode: INCREmental SYNC (bandwidth optimized)") - print(f"[DailySync] Sync from: {sync_start_date} to {end_date}") - else: - mode = "partial" - print(f"[DailySync] Mode: SYNC from {sync_start_date} to {end_date}") - - # 获取所有股票代码 - stock_codes = self.get_all_stock_codes() - if not stock_codes: - print("[DailySync] No stocks found to sync") - return { - "sync_needed": False, - "stock_count": 0, - "start_date": None, - "end_date": None, - "estimated_records": 0, - "sample_data": pd.DataFrame(), - "mode": "none", - } - - stock_count = len(stock_codes) - print(f"[DailySync] Total stocks to sync: {stock_count}") - - # 从前几只股票获取样本数据 - print(f"[DailySync] Fetching sample data from {sample_size} stocks...") - sample_data_list = [] - sample_codes = stock_codes[:sample_size] - - for ts_code in sample_codes: - try: - data = self.client.query( - "pro_bar", - ts_code=ts_code, - start_date=sync_start_date, - end_date=end_date, - factors="tor,vr", - ) - if not data.empty: - sample_data_list.append(data) - print(f" - {ts_code}: {len(data)} records") - except Exception as e: - print(f" - {ts_code}: Error fetching - {e}") - - # 合并样本数据 - sample_df = ( - pd.concat(sample_data_list, ignore_index=True) - if sample_data_list - else pd.DataFrame() - ) - - # 基于样本估算总记录数 - if not sample_df.empty: - avg_records_per_stock = len(sample_df) / len(sample_data_list) - estimated_records = int(avg_records_per_stock * stock_count) - else: - estimated_records = 0 - - # 显示预览结果 - print("\n" + "=" * 60) - print("[DailySync] Preview Result") - print("=" * 60) - print(f" Sync Mode: {mode.upper()}") - print(f" Date Range: {sync_start_date} to {end_date}") - print(f" Stocks to Sync: {stock_count}") - print(f" Sample Stocks Checked: {len(sample_data_list)}/{sample_size}") - print(f" Estimated Total Records: ~{estimated_records:,}") - - if not sample_df.empty: - print(f"\n Sample Data Preview (first {len(sample_df)} rows):") - print(" " + "-" * 56) - # 以紧凑格式显示样本数据 - preview_cols = [ - "ts_code", - "trade_date", - "open", - "high", - "low", - "close", - "vol", - ] - available_cols = [c for c in preview_cols if c in sample_df.columns] - sample_display = sample_df[available_cols].head(10) - for idx, row in sample_display.iterrows(): - print(f" {row.to_dict()}") - print(" " + "-" * 56) - - print("=" * 60) - - return { - "sync_needed": True, - "stock_count": stock_count, - "start_date": sync_start_date, - "end_date": end_date, - "estimated_records": estimated_records, - "sample_data": sample_df, - "mode": mode, - } - - def sync_single_stock( + def fetch_single_stock( self, ts_code: str, start_date: str, end_date: str, ) -> pd.DataFrame: - """同步单只股票的日线数据。 + """获取单只股票的日线数据。 Args: ts_code: 股票代码 @@ -517,221 +106,17 @@ class DailySync: end_date: 结束日期(YYYYMMDD) Returns: - 包含日线市场数据的 DataFrame + 包含日线数据的 DataFrame """ - # 检查是否应该停止同步(用于异常处理) - if not self._stop_flag.is_set(): - return pd.DataFrame() - - try: - # 使用共享客户端进行跨线程速率限制 - data = self.client.query( - "pro_bar", - ts_code=ts_code, - start_date=start_date, - end_date=end_date, - factors="tor,vr", - ) - return data - except Exception as e: - # 设置停止标志以通知其他线程停止 - self._stop_flag.clear() - print(f"[ERROR] Exception syncing {ts_code}: {e}") - raise - - def sync_all( - self, - force_full: bool = False, - start_date: Optional[str] = None, - end_date: Optional[str] = None, - max_workers: Optional[int] = None, - dry_run: bool = False, - ) -> Dict[str, pd.DataFrame]: - """同步本地存储中所有股票的日线数据。 - - 该函数: - 1. 从本地存储读取股票代码(daily 或 stock_basic) - 2. 检查交易日历确定是否需要同步: - - 若本地数据匹配交易日历边界,则跳过同步(节省 token) - - 否则,从本地最后日期+1 同步到最新交易日(带宽优化) - 3. 使用多线程并发获取(带速率限制) - 4. 跳过返回空数据的股票(退市/不可用) - 5. 遇异常立即停止 - - Args: - force_full: 若为 True,强制从 20180101 完整重载 - start_date: 手动指定起始日期(覆盖自动检测) - end_date: 手动指定结束日期(默认为今天) - max_workers: 工作线程数(默认: 10) - dry_run: 若为 True,仅预览将要同步的内容,不写入数据 - - Returns: - 映射 ts_code 到 DataFrame 的字典(若跳过或 dry_run 则为空字典) - """ - print("\n" + "=" * 60) - print("[DailySync] Starting daily data sync...") - print("=" * 60) - - # 首先确保交易日历缓存是最新的(使用增量同步) - print("[DailySync] Syncing trade calendar cache...") - sync_trade_cal_cache() - - # 确定日期范围 - if end_date is None: - end_date = get_today_date() - - # 基于交易日历检查是否需要同步 - sync_needed, cal_start, cal_end, local_last = self.check_sync_needed(force_full) - - if not sync_needed: - # 跳过同步 - 不消耗 token - print("\n" + "=" * 60) - print("[DailySync] Sync Summary") - print("=" * 60) - print(" Sync: SKIPPED (local data up-to-date with trade calendar)") - print(" Tokens saved: 0 consumed") - print("=" * 60) - return {} - - # 使用 check_sync_needed 返回的日期(会计算增量起始日期) - if cal_start and cal_end: - sync_start_date = cal_start - end_date = cal_end - else: - # 回退到默认逻辑 - sync_start_date = start_date or DEFAULT_START_DATE - if end_date is None: - end_date = get_today_date() - - # 确定同步模式 - if force_full: - mode = "full" - print(f"[DailySync] Mode: FULL SYNC from {sync_start_date} to {end_date}") - elif local_last and cal_start and sync_start_date == get_next_date(local_last): - mode = "incremental" - print(f"[DailySync] Mode: INCREMENTAL SYNC (bandwidth optimized)") - print(f"[DailySync] Sync from: {sync_start_date} to {end_date}") - else: - mode = "partial" - print(f"[DailySync] Mode: SYNC from {sync_start_date} to {end_date}") - - # 获取所有股票代码 - stock_codes = self.get_all_stock_codes() - if not stock_codes: - print("[DailySync] No stocks found to sync") - return {} - - print(f"[DailySync] Total stocks to sync: {len(stock_codes)}") - print(f"[DailySync] Using {max_workers or self.max_workers} worker threads") - - # 处理 dry run 模式 - if dry_run: - print("\n" + "=" * 60) - print("[DailySync] DRY RUN MODE - No data will be written") - print("=" * 60) - print(f" Would sync {len(stock_codes)} stocks") - print(f" Date range: {sync_start_date} to {end_date}") - print(f" Mode: {mode}") - print("=" * 60) - return {} - - # 为新同步重置停止标志 - self._stop_flag.set() - - # 多线程并发获取 - results: Dict[str, pd.DataFrame] = {} - error_occurred = False - exception_to_raise = None - - def sync_task(ts_code: str) -> tuple[str, pd.DataFrame]: - """每只股票的任务函数。""" - try: - data = self.sync_single_stock( - ts_code=ts_code, - start_date=sync_start_date, - end_date=end_date, - ) - return (ts_code, data) - except Exception as e: - # 重新抛出以被 Future 捕获 - raise - - # 使用 ThreadPoolExecutor 进行并发获取 - workers = max_workers or self.max_workers - with ThreadPoolExecutor(max_workers=workers) as executor: - # 提交所有任务并跟踪 futures 与股票代码的映射 - future_to_code = { - executor.submit(sync_task, ts_code): ts_code for ts_code in stock_codes - } - - # 使用 as_completed 处理结果 - error_count = 0 - empty_count = 0 - success_count = 0 - - # 创建进度条 - pbar = tqdm(total=len(stock_codes), desc="Syncing stocks") - - try: - # 处理完成的 futures - for future in as_completed(future_to_code): - ts_code = future_to_code[future] - - try: - _, data = future.result() - if data is not None and not data.empty: - results[ts_code] = data - success_count += 1 - else: - # 空数据 - 股票可能已退市或不可用 - empty_count += 1 - print( - f"[DailySync] Stock {ts_code}: empty data (skipped, may be delisted)" - ) - except Exception as e: - # 发生异常 - 停止全部并中止 - error_occurred = True - exception_to_raise = e - print(f"\n[ERROR] Sync aborted due to exception: {e}") - # 关闭 executor 以停止所有待处理任务 - executor.shutdown(wait=False, cancel_futures=True) - raise exception_to_raise - - # 更新进度条 - pbar.update(1) - - except Exception: - error_count = 1 - print("[DailySync] Sync stopped due to exception") - finally: - pbar.close() - - # 批量写入所有数据(仅在无错误时) - if results and not error_occurred: - for ts_code, data in results.items(): - if not data.empty: - self.storage.queue_save("daily", data) - # 一次性刷新所有排队写入 - self.storage.flush() - total_rows = sum(len(df) for df in results.values()) - print(f"\n[DailySync] Saved {total_rows} rows to storage") - - # 摘要 - print("\n" + "=" * 60) - print("[DailySync] Sync Summary") - print("=" * 60) - print(f" Total stocks: {len(stock_codes)}") - print(f" Updated: {success_count}") - print(f" Skipped (empty/delisted): {empty_count}") - print( - f" Errors: {error_count} (aborted on first error)" - if error_count - else " Errors: 0" + # 使用共享客户端进行跨线程速率限制 + data = self.client.query( + "pro_bar", + ts_code=ts_code, + start_date=start_date, + end_date=end_date, + factors="tor,vr", ) - print(f" Date range: {sync_start_date} to {end_date}") - print("=" * 60) - - return results + return data def sync_daily( diff --git a/src/data/api_wrappers/api_pro_bar.py b/src/data/api_wrappers/api_pro_bar.py index ce08fa0..87fd142 100644 --- a/src/data/api_wrappers/api_pro_bar.py +++ b/src/data/api_wrappers/api_pro_bar.py @@ -8,21 +8,9 @@ volume ratio (vr), and adjustment factors. import pandas as pd from typing import Optional, List, Literal, Dict -from datetime import datetime, timedelta -from tqdm import tqdm -from concurrent.futures import ThreadPoolExecutor, as_completed -import threading from src.data.client import TushareClient -from src.data.storage import ThreadSafeStorage, Storage -from src.data.utils import get_today_date, get_next_date, DEFAULT_START_DATE -from src.config.settings import get_settings -from src.data.api_wrappers.api_trade_cal import ( - get_first_trading_day, - get_last_trading_day, - sync_trade_cal_cache, -) -from src.data.api_wrappers.api_stock_basic import _get_csv_path, sync_all_stocks +from src.data.api_wrappers.base_sync import StockBasedSync def get_pro_bar( @@ -138,429 +126,28 @@ def get_pro_bar( return data -# ============================================================================= -# ProBarSync - Pro Bar 数据批量同步类 -# ============================================================================= - - -class ProBarSync: +class ProBarSync(StockBasedSync): """Pro Bar 数据批量同步管理器,支持全量/增量同步。 - 功能特性: - - 多线程并发获取(ThreadPoolExecutor) - - 增量同步(自动检测上次同步位置) - - 内存缓存(避免重复磁盘读取) - - 异常立即停止(确保数据一致性) - - 预览模式(预览同步数据量,不实际写入) - - 默认获取全部数据列(tor, vr, adj_factor) + 继承自 StockBasedSync,使用多线程按股票并发获取数据。 + 默认获取全部数据列(tor, vr, adj_factor)。 + + Example: + >>> sync = ProBarSync() + >>> results = sync.sync_all() # 增量同步 + >>> results = sync.sync_all(force_full=True) # 全量同步 + >>> preview = sync.preview_sync() # 预览 """ - # 默认工作线程数(从配置读取,默认10) - DEFAULT_MAX_WORKERS = get_settings().threads + table_name = "pro_bar" - def __init__(self, max_workers: Optional[int] = None): - """初始化同步管理器。 - - max_workers: 工作线程数(默认从配置读取,若未指定则使用配置值) - max_workers: 工作线程数(默认: 10) - """ - self.storage = ThreadSafeStorage() - self.client = TushareClient() - self.max_workers = max_workers or self.DEFAULT_MAX_WORKERS - self._stop_flag = threading.Event() - self._stop_flag.set() # 初始为未停止状态 - self._cached_pro_bar_data: Optional[pd.DataFrame] = None # 数据缓存 - - def _load_pro_bar_data(self) -> pd.DataFrame: - """从存储加载 Pro Bar 数据(带缓存)。 - - 该方法会将数据缓存在内存中以避免重复磁盘读取。 - 调用 clear_cache() 可强制重新加载。 - - Returns: - 缓存或从存储加载的 Pro Bar 数据 DataFrame - """ - if self._cached_pro_bar_data is None: - self._cached_pro_bar_data = self.storage.load("pro_bar") - return self._cached_pro_bar_data - - def clear_cache(self) -> None: - """清除缓存的 Pro Bar 数据,强制下次访问时重新加载。""" - self._cached_pro_bar_data = None - - def get_all_stock_codes(self, only_listed: bool = True) -> list: - """从本地存储获取所有股票代码。 - - 优先使用 stock_basic.csv 以确保包含所有股票, - 避免回测中的前视偏差。 - - Args: - only_listed: 若为 True,仅返回当前上市股票(L 状态)。 - 设为 False 可包含退市股票(用于完整回测)。 - - Returns: - 股票代码列表 - """ - # 确保 stock_basic.csv 是最新的 - print("[ProBarSync] Ensuring stock_basic.csv is up-to-date...") - sync_all_stocks() - - # 从 stock_basic.csv 文件获取 - stock_csv_path = _get_csv_path() - - if stock_csv_path.exists(): - print(f"[ProBarSync] Reading stock_basic from CSV: {stock_csv_path}") - try: - stock_df = pd.read_csv(stock_csv_path, encoding="utf-8-sig") - if not stock_df.empty and "ts_code" in stock_df.columns: - # 根据 list_status 过滤 - if only_listed and "list_status" in stock_df.columns: - listed_stocks = stock_df[stock_df["list_status"] == "L"] - codes = listed_stocks["ts_code"].unique().tolist() - total = len(stock_df["ts_code"].unique()) - print( - f"[ProBarSync] Found {len(codes)} listed stocks (filtered from {total} total)" - ) - else: - codes = stock_df["ts_code"].unique().tolist() - print( - f"[ProBarSync] Found {len(codes)} stock codes from stock_basic.csv" - ) - return codes - else: - print( - f"[ProBarSync] stock_basic.csv exists but no ts_code column or empty" - ) - except Exception as e: - print(f"[ProBarSync] Error reading stock_basic.csv: {e}") - - # 回退:从 Pro Bar 存储获取 - print( - "[ProBarSync] stock_basic.csv not available, falling back to pro_bar data..." - ) - pro_bar_data = self._load_pro_bar_data() - if not pro_bar_data.empty and "ts_code" in pro_bar_data.columns: - codes = pro_bar_data["ts_code"].unique().tolist() - print(f"[ProBarSync] Found {len(codes)} stock codes from pro_bar data") - return codes - - print("[ProBarSync] No stock codes found in local storage") - return [] - - def get_global_last_date(self) -> Optional[str]: - """获取全局最后交易日期。 - - Returns: - 最后交易日期字符串,若无数据则返回 None - """ - pro_bar_data = self._load_pro_bar_data() - if pro_bar_data.empty or "trade_date" not in pro_bar_data.columns: - return None - return str(pro_bar_data["trade_date"].max()) - - def get_global_first_date(self) -> Optional[str]: - """获取全局最早交易日期。 - - Returns: - 最早交易日期字符串,若无数据则返回 None - """ - pro_bar_data = self._load_pro_bar_data() - if pro_bar_data.empty or "trade_date" not in pro_bar_data.columns: - return None - return str(pro_bar_data["trade_date"].min()) - - def get_trade_calendar_bounds( - self, start_date: str, end_date: str - ) -> tuple[Optional[str], Optional[str]]: - """从交易日历获取首尾交易日。 - - Args: - start_date: 开始日期(YYYYMMDD 格式) - end_date: 结束日期(YYYYMMDD 格式) - - Returns: - (首交易日, 尾交易日) 元组,若出错则返回 (None, None) - """ - try: - first_day = get_first_trading_day(start_date, end_date) - last_day = get_last_trading_day(start_date, end_date) - return (first_day, last_day) - except Exception as e: - print(f"[ERROR] Failed to get trade calendar bounds: {e}") - return (None, None) - - def check_sync_needed( - self, - force_full: bool = False, - table_name: str = "pro_bar", - ) -> tuple[bool, Optional[str], Optional[str], Optional[str]]: - """基于交易日历检查是否需要同步。 - - 该方法比较本地数据日期范围与交易日历, - 以确定是否需要获取新数据。 - - 逻辑: - - 若 force_full:需要同步,返回 (True, 20180101, today) - - 若无本地数据:需要同步,返回 (True, 20180101, today) - - 若存在本地数据: - - 从交易日历获取最后交易日 - - 若本地最后日期 >= 日历最后日期:无需同步 - - 否则:从本地最后日期+1 到最新交易日同步 - - Args: - force_full: 若为 True,始终返回需要同步 - table_name: 要检查的表名(默认: "pro_bar") - - Returns: - (需要同步, 起始日期, 结束日期, 本地最后日期) - - 需要同步: True 表示应继续同步 - - 起始日期: 同步起始日期(无需同步时为 None) - - 结束日期: 同步结束日期(无需同步时为 None) - - 本地最后日期: 本地数据最后日期(用于增量同步) - """ - # 若 force_full,始终同步 - if force_full: - print("[ProBarSync] Force full sync requested") - return (True, DEFAULT_START_DATE, get_today_date(), None) - - # 检查特定表的本地数据是否存在 - storage = Storage() - table_data = ( - storage.load(table_name) if storage.exists(table_name) else pd.DataFrame() - ) - - if table_data.empty or "trade_date" not in table_data.columns: - print( - f"[ProBarSync] No local data found for table '{table_name}', full sync needed" - ) - return (True, DEFAULT_START_DATE, get_today_date(), None) - - # 获取本地数据最后日期 - local_last_date = str(table_data["trade_date"].max()) - - print(f"[ProBarSync] Local data last date: {local_last_date}") - - # 从交易日历获取最新交易日 - today = get_today_date() - _, cal_last = self.get_trade_calendar_bounds(DEFAULT_START_DATE, today) - - if cal_last is None: - print("[ProBarSync] Failed to get trade calendar, proceeding with sync") - return (True, DEFAULT_START_DATE, today, local_last_date) - - print(f"[ProBarSync] Calendar last trading day: {cal_last}") - - # 比较本地最后日期与日历最后日期 - print( - f"[ProBarSync] Comparing: local={local_last_date} (type={type(local_last_date).__name__}), " - f"cal={cal_last} (type={type(cal_last).__name__})" - ) - try: - local_last_int = int(local_last_date) - cal_last_int = int(cal_last) - print( - f"[ProBarSync] Comparing integers: local={local_last_int} >= cal={cal_last_int} = " - f"{local_last_int >= cal_last_int}" - ) - if local_last_int >= cal_last_int: - print( - "[ProBarSync] Local data is up-to-date, SKIPPING sync (no tokens consumed)" - ) - return (False, None, None, None) - except (ValueError, TypeError) as e: - print(f"[ERROR] Date comparison failed: {e}") - - # 需要从本地最后日期+1 同步到最新交易日 - sync_start = get_next_date(local_last_date) - print(f"[ProBarSync] Incremental sync needed from {sync_start} to {cal_last}") - return (True, sync_start, cal_last, local_last_date) - - def preview_sync( - self, - force_full: bool = False, - start_date: Optional[str] = None, - end_date: Optional[str] = None, - sample_size: int = 3, - ) -> dict: - """预览同步数据量和样本(不实际同步)。 - - 该方法提供即将同步的数据的预览,包括: - - 将同步的股票数量 - - 同步日期范围 - - 预估总记录数 - - 前几只股票的样本数据 - - Args: - force_full: 若为 True,预览全量同步(从 20180101) - start_date: 手动指定起始日期(覆盖自动检测) - end_date: 手动指定结束日期(默认为今天) - sample_size: 预览用样本股票数量(默认: 3) - - Returns: - 包含预览信息的字典: - { - 'sync_needed': bool, - 'stock_count': int, - 'start_date': str, - 'end_date': str, - 'estimated_records': int, - 'sample_data': pd.DataFrame, - 'mode': str, # 'full' 或 'incremental' - } - """ - print("\n" + "=" * 60) - print("[ProBarSync] Preview Mode - Analyzing sync requirements...") - print("=" * 60) - - # 首先确保交易日历缓存是最新的 - print("[ProBarSync] Syncing trade calendar cache...") - sync_trade_cal_cache() - - # 确定日期范围 - if end_date is None: - end_date = get_today_date() - - # 检查是否需要同步 - sync_needed, cal_start, cal_end, local_last = self.check_sync_needed(force_full) - - if not sync_needed: - print("\n" + "=" * 60) - print("[ProBarSync] Preview Result") - print("=" * 60) - print(" Sync Status: NOT NEEDED") - print(" Reason: Local data is up-to-date with trade calendar") - print("=" * 60) - return { - "sync_needed": False, - "stock_count": 0, - "start_date": None, - "end_date": None, - "estimated_records": 0, - "sample_data": pd.DataFrame(), - "mode": "none", - } - - # 使用 check_sync_needed 返回的日期 - if cal_start and cal_end: - sync_start_date = cal_start - end_date = cal_end - else: - sync_start_date = start_date or DEFAULT_START_DATE - if end_date is None: - end_date = get_today_date() - - # 确定同步模式 - if force_full: - mode = "full" - print(f"[ProBarSync] Mode: FULL SYNC from {sync_start_date} to {end_date}") - elif local_last and cal_start and sync_start_date == get_next_date(local_last): - mode = "incremental" - print(f"[ProBarSync] Mode: INCREMENTAL SYNC (bandwidth optimized)") - print(f"[ProBarSync] Sync from: {sync_start_date} to {end_date}") - else: - mode = "partial" - print(f"[ProBarSync] Mode: SYNC from {sync_start_date} to {end_date}") - - # 获取所有股票代码 - stock_codes = self.get_all_stock_codes() - if not stock_codes: - print("[ProBarSync] No stocks found to sync") - return { - "sync_needed": False, - "stock_count": 0, - "start_date": None, - "end_date": None, - "estimated_records": 0, - "sample_data": pd.DataFrame(), - "mode": "none", - } - - stock_count = len(stock_codes) - print(f"[ProBarSync] Total stocks to sync: {stock_count}") - - # 从前几只股票获取样本数据 - print(f"[ProBarSync] Fetching sample data from {sample_size} stocks...") - sample_data_list = [] - sample_codes = stock_codes[:sample_size] - - for ts_code in sample_codes: - try: - # 使用 get_pro_bar 获取样本数据(包含所有字段) - data = get_pro_bar( - ts_code=ts_code, - start_date=sync_start_date, - end_date=end_date, - ) - if not data.empty: - sample_data_list.append(data) - print(f" - {ts_code}: {len(data)} records") - except Exception as e: - print(f" - {ts_code}: Error fetching - {e}") - - # 合并样本数据 - sample_df = ( - pd.concat(sample_data_list, ignore_index=True) - if sample_data_list - else pd.DataFrame() - ) - - # 基于样本估算总记录数 - if not sample_df.empty: - avg_records_per_stock = len(sample_df) / len(sample_data_list) - estimated_records = int(avg_records_per_stock * stock_count) - else: - estimated_records = 0 - - # 显示预览结果 - print("\n" + "=" * 60) - print("[ProBarSync] Preview Result") - print("=" * 60) - print(f" Sync Mode: {mode.upper()}") - print(f" Date Range: {sync_start_date} to {end_date}") - print(f" Stocks to Sync: {stock_count}") - print(f" Sample Stocks Checked: {len(sample_data_list)}/{sample_size}") - print(f" Estimated Total Records: ~{estimated_records:,}") - - if not sample_df.empty: - print(f"\n Sample Data Preview (first {len(sample_df)} rows):") - print(" " + "-" * 56) - # 以紧凑格式显示样本数据 - preview_cols = [ - "ts_code", - "trade_date", - "open", - "high", - "low", - "close", - "vol", - "tor", - "vr", - ] - available_cols = [c for c in preview_cols if c in sample_df.columns] - sample_display = sample_df[available_cols].head(10) - for idx, row in sample_display.iterrows(): - print(f" {row.to_dict()}") - print(" " + "-" * 56) - - print("=" * 60) - - return { - "sync_needed": True, - "stock_count": stock_count, - "start_date": sync_start_date, - "end_date": end_date, - "estimated_records": estimated_records, - "sample_data": sample_df, - "mode": mode, - } - - def sync_single_stock( + def fetch_single_stock( self, ts_code: str, start_date: str, end_date: str, ) -> pd.DataFrame: - """同步单只股票的 Pro Bar 数据。 + """获取单只股票的 Pro Bar 数据。 Args: ts_code: 股票代码 @@ -570,218 +157,14 @@ class ProBarSync: Returns: 包含 Pro Bar 数据的 DataFrame """ - # 检查是否应该停止同步(用于异常处理) - if not self._stop_flag.is_set(): - return pd.DataFrame() - - try: - # 使用 get_pro_bar 获取数据(默认包含所有字段,传递共享 client) - data = get_pro_bar( - ts_code=ts_code, - start_date=start_date, - end_date=end_date, - client=self.client, # 传递共享客户端以确保限流 - ) - return data - except Exception as e: - # 设置停止标志以通知其他线程停止 - self._stop_flag.clear() - print(f"[ERROR] Exception syncing {ts_code}: {e}") - raise - - def sync_all( - self, - force_full: bool = False, - start_date: Optional[str] = None, - end_date: Optional[str] = None, - max_workers: Optional[int] = None, - dry_run: bool = False, - ) -> Dict[str, pd.DataFrame]: - """同步本地存储中所有股票的 Pro Bar 数据。 - - 该函数: - 1. 从本地存储读取股票代码(pro_bar 或 stock_basic) - 2. 检查交易日历确定是否需要同步: - - 若本地数据匹配交易日历边界,则跳过同步(节省 token) - - 否则,从本地最后日期+1 同步到最新交易日(带宽优化) - 3. 使用多线程并发获取(带速率限制) - 4. 跳过返回空数据的股票(退市/不可用) - 5. 遇异常立即停止 - - Args: - force_full: 若为 True,强制从 20180101 完整重载 - start_date: 手动指定起始日期(覆盖自动检测) - end_date: 手动指定结束日期(默认为今天) - max_workers: 工作线程数(默认: 10) - dry_run: 若为 True,仅预览将要同步的内容,不写入数据 - - Returns: - 映射 ts_code 到 DataFrame 的字典(若跳过或 dry_run 则为空字典) - """ - print("\n" + "=" * 60) - print("[ProBarSync] Starting pro_bar data sync...") - print("=" * 60) - - # 首先确保交易日历缓存是最新的(使用增量同步) - print("[ProBarSync] Syncing trade calendar cache...") - sync_trade_cal_cache() - - # 确定日期范围 - if end_date is None: - end_date = get_today_date() - - # 基于交易日历检查是否需要同步 - sync_needed, cal_start, cal_end, local_last = self.check_sync_needed(force_full) - - if not sync_needed: - # 跳过同步 - 不消耗 token - print("\n" + "=" * 60) - print("[ProBarSync] Sync Summary") - print("=" * 60) - print(" Sync: SKIPPED (local data up-to-date with trade calendar)") - print(" Tokens saved: 0 consumed") - print("=" * 60) - return {} - - # 使用 check_sync_needed 返回的日期(会计算增量起始日期) - if cal_start and cal_end: - sync_start_date = cal_start - end_date = cal_end - else: - # 回退到默认逻辑 - sync_start_date = start_date or DEFAULT_START_DATE - if end_date is None: - end_date = get_today_date() - - # 确定同步模式 - if force_full: - mode = "full" - print(f"[ProBarSync] Mode: FULL SYNC from {sync_start_date} to {end_date}") - elif local_last and cal_start and sync_start_date == get_next_date(local_last): - mode = "incremental" - print(f"[ProBarSync] Mode: INCREMENTAL SYNC (bandwidth optimized)") - print(f"[ProBarSync] Sync from: {sync_start_date} to {end_date}") - else: - mode = "partial" - print(f"[ProBarSync] Mode: SYNC from {sync_start_date} to {end_date}") - - # 获取所有股票代码 - stock_codes = self.get_all_stock_codes() - if not stock_codes: - print("[ProBarSync] No stocks found to sync") - return {} - - print(f"[ProBarSync] Total stocks to sync: {len(stock_codes)}") - print(f"[ProBarSync] Using {max_workers or self.max_workers} worker threads") - - # 处理 dry run 模式 - if dry_run: - print("\n" + "=" * 60) - print("[ProBarSync] DRY RUN MODE - No data will be written") - print("=" * 60) - print(f" Would sync {len(stock_codes)} stocks") - print(f" Date range: {sync_start_date} to {end_date}") - print(f" Mode: {mode}") - print("=" * 60) - return {} - - # 为新同步重置停止标志 - self._stop_flag.set() - - # 多线程并发获取 - results: Dict[str, pd.DataFrame] = {} - error_occurred = False - exception_to_raise = None - - def sync_task(ts_code: str) -> tuple[str, pd.DataFrame]: - """每只股票的任务函数。""" - try: - data = self.sync_single_stock( - ts_code=ts_code, - start_date=sync_start_date, - end_date=end_date, - ) - return (ts_code, data) - except Exception as e: - # 重新抛出以被 Future 捕获 - raise - - # 使用 ThreadPoolExecutor 进行并发获取 - workers = max_workers or self.max_workers - with ThreadPoolExecutor(max_workers=workers) as executor: - # 提交所有任务并跟踪 futures 与股票代码的映射 - future_to_code = { - executor.submit(sync_task, ts_code): ts_code for ts_code in stock_codes - } - - # 使用 as_completed 处理结果 - error_count = 0 - empty_count = 0 - success_count = 0 - - # 创建进度条 - pbar = tqdm(total=len(stock_codes), desc="Syncing pro_bar stocks") - - try: - # 处理完成的 futures - for future in as_completed(future_to_code): - ts_code = future_to_code[future] - - try: - _, data = future.result() - if data is not None and not data.empty: - results[ts_code] = data - success_count += 1 - else: - # 空数据 - 股票可能已退市或不可用 - empty_count += 1 - print( - f"[ProBarSync] Stock {ts_code}: empty data (skipped, may be delisted)" - ) - except Exception as e: - # 发生异常 - 停止全部并中止 - error_occurred = True - exception_to_raise = e - print(f"\n[ERROR] Sync aborted due to exception: {e}") - # 关闭 executor 以停止所有待处理任务 - executor.shutdown(wait=False, cancel_futures=True) - raise exception_to_raise - - # 更新进度条 - pbar.update(1) - - except Exception: - error_count = 1 - print("[ProBarSync] Sync stopped due to exception") - finally: - pbar.close() - - # 批量写入所有数据(仅在无错误时) - if results and not error_occurred: - for ts_code, data in results.items(): - if not data.empty: - self.storage.queue_save("pro_bar", data) - # 一次性刷新所有排队写入 - self.storage.flush() - total_rows = sum(len(df) for df in results.values()) - print(f"\n[ProBarSync] Saved {total_rows} rows to storage") - - # 摘要 - print("\n" + "=" * 60) - print("[ProBarSync] Sync Summary") - print("=" * 60) - print(f" Total stocks: {len(stock_codes)}") - print(f" Updated: {success_count}") - print(f" Skipped (empty/delisted): {empty_count}") - print( - f" Errors: {error_count} (aborted on first error)" - if error_count - else " Errors: 0" + # 使用 get_pro_bar 获取数据(默认包含所有字段,传递共享 client) + data = get_pro_bar( + ts_code=ts_code, + start_date=start_date, + end_date=end_date, + client=self.client, # 传递共享客户端以确保限流 ) - print(f" Date range: {sync_start_date} to {end_date}") - print("=" * 60) - - return results + return data def sync_pro_bar( diff --git a/src/data/api_wrappers/base_sync.py b/src/data/api_wrappers/base_sync.py new file mode 100644 index 0000000..41a90ef --- /dev/null +++ b/src/data/api_wrappers/base_sync.py @@ -0,0 +1,1137 @@ +"""数据同步基础抽象模块。 + +提供三层抽象结构统一所有数据同步流程: +- BaseDataSync: 所有同步类型的基础抽象类 +- StockBasedSync: 按股票同步的抽象类(daily, pro_bar) +- DateBasedSync: 按日期同步的抽象类(bak_basic) + +使用方式: + # 按股票同步(daily, pro_bar) + class DailySync(StockBasedSync): + table_name = "daily" + + def fetch_single_stock(self, ts_code: str, start_date: str, end_date: str) -> pd.DataFrame: + # 实现单只股票数据获取 + ... + + # 按日期同步(bak_basic) + class BakBasicSync(DateBasedSync): + table_name = "bak_basic" + + def fetch_single_date(self, trade_date: str) -> pd.DataFrame: + # 实现单日数据获取 + ... +""" + +from abc import ABC, abstractmethod +from typing import Optional, List, Dict +from datetime import datetime, timedelta +from concurrent.futures import ThreadPoolExecutor, as_completed +import threading + +import pandas as pd +from tqdm import tqdm + +from src.data.client import TushareClient +from src.data.storage import ThreadSafeStorage, Storage +from src.data.utils import get_today_date, get_next_date +from src.config.settings import get_settings +from src.data.api_wrappers.api_trade_cal import ( + get_first_trading_day, + get_last_trading_day, + sync_trade_cal_cache, +) +from src.data.api_wrappers.api_stock_basic import _get_csv_path, sync_all_stocks + + +class BaseDataSync(ABC): + """数据同步基础抽象类。 + + 提供所有同步类型共有的功能: + - 客户端初始化 + - 股票代码获取 + - 交易日历操作 + - 同步需求检查 + - 缓存管理 + + Attributes: + table_name: 子类必须定义的目标表名 + DEFAULT_START_DATE: 默认同步起始日期 + """ + + table_name: str = "" # 子类必须覆盖 + DEFAULT_START_DATE = "20180101" + DEFAULT_MAX_WORKERS = get_settings().threads + + def __init__(self, max_workers: Optional[int] = None): + """初始化同步管理器。 + + Args: + max_workers: 工作线程数(默认从配置读取) + """ + self.storage = ThreadSafeStorage() + self.client = TushareClient() + self.max_workers = max_workers or self.DEFAULT_MAX_WORKERS + self._stop_flag = threading.Event() + self._stop_flag.set() # 初始为未停止状态 + self._cached_data: Optional[pd.DataFrame] = None + + def _load_cached_data(self) -> pd.DataFrame: + """从存储加载数据(带缓存)。 + + Returns: + 缓存或从存储加载的 DataFrame + """ + if self._cached_data is None: + self._cached_data = self.storage.load(self.table_name) + return self._cached_data + + def clear_cache(self) -> None: + """清除缓存的数据,强制下次访问时重新加载。""" + self._cached_data = None + + def get_all_stock_codes(self, only_listed: bool = True) -> List[str]: + """从本地存储获取所有股票代码。 + + 优先使用 stock_basic.csv 以确保包含所有股票, + 避免回测中的前视偏差。 + + Args: + only_listed: 若为 True,仅返回当前上市股票(L 状态) + + Returns: + 股票代码列表 + """ + class_name = self.__class__.__name__ + print(f"[{class_name}] Ensuring stock_basic.csv is up-to-date...") + sync_all_stocks() + + stock_csv_path = _get_csv_path() + + if stock_csv_path.exists(): + print(f"[{class_name}] Reading stock_basic from CSV: {stock_csv_path}") + try: + stock_df = pd.read_csv(stock_csv_path, encoding="utf-8-sig") + if not stock_df.empty and "ts_code" in stock_df.columns: + if only_listed and "list_status" in stock_df.columns: + listed_stocks = stock_df[stock_df["list_status"] == "L"] + codes = listed_stocks["ts_code"].unique().tolist() + total = len(stock_df["ts_code"].unique()) + print( + f"[{class_name}] Found {len(codes)} listed stocks (filtered from {total} total)" + ) + else: + codes = stock_df["ts_code"].unique().tolist() + print( + f"[{class_name}] Found {len(codes)} stock codes from stock_basic.csv" + ) + return codes + else: + print( + f"[{class_name}] stock_basic.csv exists but no ts_code column or empty" + ) + except Exception as e: + print(f"[{class_name}] Error reading stock_basic.csv: {e}") + + # 回退:从目标表存储获取 + print( + f"[{class_name}] stock_basic.csv not available, falling back to {self.table_name} data..." + ) + cached_data = self._load_cached_data() + if not cached_data.empty and "ts_code" in cached_data.columns: + codes = cached_data["ts_code"].unique().tolist() + print( + f"[{class_name}] Found {len(codes)} stock codes from {self.table_name} data" + ) + return codes + + print(f"[{class_name}] No stock codes found in local storage") + return [] + + def get_global_last_date(self) -> Optional[str]: + """获取全局最后交易日期。 + + Returns: + 最后交易日期字符串,若无数据则返回 None + """ + data = self._load_cached_data() + if data.empty or "trade_date" not in data.columns: + return None + return str(data["trade_date"].max()) + + def get_global_first_date(self) -> Optional[str]: + """获取全局最早交易日期。 + + Returns: + 最早交易日期字符串,若无数据则返回 None + """ + data = self._load_cached_data() + if data.empty or "trade_date" not in data.columns: + return None + return str(data["trade_date"].min()) + + def get_trade_calendar_bounds( + self, start_date: str, end_date: str + ) -> tuple[Optional[str], Optional[str]]: + """从交易日历获取首尾交易日。 + + Args: + start_date: 开始日期(YYYYMMDD 格式) + end_date: 结束日期(YYYYMMDD 格式) + + Returns: + (首交易日, 尾交易日) 元组,若出错则返回 (None, None) + """ + try: + first_day = get_first_trading_day(start_date, end_date) + last_day = get_last_trading_day(start_date, end_date) + return (first_day, last_day) + except Exception as e: + print(f"[ERROR] Failed to get trade calendar bounds: {e}") + return (None, None) + + def check_sync_needed( + self, + force_full: bool = False, + ) -> tuple[bool, Optional[str], Optional[str], Optional[str]]: + """基于交易日历检查是否需要同步。 + + 逻辑: + - 若 force_full:需要同步,返回 (True, default_start, today) + - 若无本地数据:需要同步,返回 (True, default_start, today) + - 若存在本地数据: + - 从交易日历获取最后交易日 + - 若本地最后日期 >= 日历最后日期:无需同步 + - 否则:从本地最后日期+1 到最新交易日同步 + + Args: + force_full: 若为 True,始终返回需要同步 + + Returns: + (需要同步, 起始日期, 结束日期, 本地最后日期) + """ + class_name = self.__class__.__name__ + + if force_full: + print(f"[{class_name}] Force full sync requested") + return (True, self.DEFAULT_START_DATE, get_today_date(), None) + + # 检查本地数据是否存在 + storage = Storage() + table_data = ( + storage.load(self.table_name) + if storage.exists(self.table_name) + else pd.DataFrame() + ) + + if table_data.empty or "trade_date" not in table_data.columns: + print( + f"[{class_name}] No local data found for table '{self.table_name}', full sync needed" + ) + return (True, self.DEFAULT_START_DATE, get_today_date(), None) + + # 获取本地数据最后日期 + local_last_date = str(table_data["trade_date"].max()) + print(f"[{class_name}] Local data last date: {local_last_date}") + + # 从交易日历获取最新交易日 + today = get_today_date() + _, cal_last = self.get_trade_calendar_bounds(self.DEFAULT_START_DATE, today) + + if cal_last is None: + print(f"[{class_name}] Failed to get trade calendar, proceeding with sync") + return (True, self.DEFAULT_START_DATE, today, local_last_date) + + print(f"[{class_name}] Calendar last trading day: {cal_last}") + + # 比较本地最后日期与日历最后日期 + try: + local_last_int = int(local_last_date) + cal_last_int = int(cal_last) + if local_last_int >= cal_last_int: + print( + f"[{class_name}] Local data is up-to-date, SKIPPING sync (no tokens consumed)" + ) + return (False, None, None, None) + except (ValueError, TypeError) as e: + print(f"[ERROR] Date comparison failed: {e}") + + # 需要从本地最后日期+1 同步到最新交易日 + sync_start = get_next_date(local_last_date) + print(f"[{class_name}] Incremental sync needed from {sync_start} to {cal_last}") + return (True, sync_start, cal_last, local_last_date) + + def _determine_sync_mode( + self, + force_full: bool, + sync_start_date: str, + local_last: Optional[str], + ) -> str: + """确定同步模式。 + + Args: + force_full: 是否强制全量同步 + sync_start_date: 同步起始日期 + local_last: 本地最后日期 + + Returns: + 同步模式: 'full', 'incremental', 或 'partial' + """ + if force_full: + return "full" + elif local_last and sync_start_date == get_next_date(local_last): + return "incremental" + else: + return "partial" + + @abstractmethod + def preview_sync( + self, + force_full: bool = False, + start_date: Optional[str] = None, + end_date: Optional[str] = None, + sample_size: int = 3, + ) -> dict: + """预览同步数据量和样本(子类必须实现)。""" + pass + + @abstractmethod + def sync_all( + self, + force_full: bool = False, + start_date: Optional[str] = None, + end_date: Optional[str] = None, + max_workers: Optional[int] = None, + dry_run: bool = False, + ) -> Dict[str, pd.DataFrame]: + """执行同步(子类必须实现)。""" + pass + + +class StockBasedSync(BaseDataSync): + """按股票同步的抽象类。 + + 适用于需要按股票代码逐个获取数据的接口(如 daily, pro_bar)。 + 提供多线程并发获取能力。 + + 子类必须实现: + - fetch_single_stock(ts_code, start_date, end_date) -> pd.DataFrame + """ + + @abstractmethod + def fetch_single_stock( + self, + ts_code: str, + start_date: str, + end_date: str, + ) -> pd.DataFrame: + """获取单只股票的数据。 + + Args: + ts_code: 股票代码 + start_date: 起始日期(YYYYMMDD) + end_date: 结束日期(YYYYMMDD) + + Returns: + 包含股票数据的 DataFrame + """ + pass + + def _sync_single_stock_wrapper( + self, + ts_code: str, + start_date: str, + end_date: str, + ) -> pd.DataFrame: + """单只股票同步的包装器(处理停止标志和异常)。 + + Args: + ts_code: 股票代码 + start_date: 起始日期 + end_date: 结束日期 + + Returns: + 股票数据 DataFrame + """ + if not self._stop_flag.is_set(): + return pd.DataFrame() + + try: + data = self.fetch_single_stock(ts_code, start_date, end_date) + return data + except Exception as e: + self._stop_flag.clear() + print(f"[ERROR] Exception syncing {ts_code}: {e}") + raise + + def _run_concurrent_sync( + self, + stock_codes: List[str], + start_date: str, + end_date: str, + max_workers: Optional[int] = None, + dry_run: bool = False, + ) -> Dict[str, pd.DataFrame]: + """执行多线程并发同步。 + + Args: + stock_codes: 股票代码列表 + start_date: 同步起始日期 + end_date: 同步结束日期 + max_workers: 工作线程数 + dry_run: 是否为预览模式 + + Returns: + 映射 ts_code 到 DataFrame 的字典 + """ + class_name = self.__class__.__name__ + + if dry_run: + print(f"\n{'=' * 60}") + print(f"[{class_name}] DRY RUN MODE - No data will be written") + print(f"{'=' * 60}") + print(f" Would sync {len(stock_codes)} stocks") + print(f" Date range: {start_date} to {end_date}") + print(f"{'=' * 60}") + return {} + + # 为新同步重置停止标志 + self._stop_flag.set() + + results: Dict[str, pd.DataFrame] = {} + error_occurred = False + exception_to_raise = None + + def sync_task(ts_code: str) -> tuple[str, pd.DataFrame]: + """每只股票的任务函数。""" + data = self._sync_single_stock_wrapper( + ts_code=ts_code, + start_date=start_date, + end_date=end_date, + ) + return (ts_code, data) + + workers = max_workers or self.max_workers + with ThreadPoolExecutor(max_workers=workers) as executor: + future_to_code = { + executor.submit(sync_task, ts_code): ts_code for ts_code in stock_codes + } + + error_count = 0 + empty_count = 0 + success_count = 0 + + pbar = tqdm( + total=len(stock_codes), desc=f"Syncing {self.table_name} stocks" + ) + + try: + for future in as_completed(future_to_code): + ts_code = future_to_code[future] + + try: + _, data = future.result() + if data is not None and not data.empty: + results[ts_code] = data + success_count += 1 + else: + empty_count += 1 + print( + f"[{class_name}] Stock {ts_code}: empty data (skipped, may be delisted)" + ) + except Exception as e: + error_occurred = True + exception_to_raise = e + print(f"\n[ERROR] Sync aborted due to exception: {e}") + executor.shutdown(wait=False, cancel_futures=True) + raise exception_to_raise + + pbar.update(1) + + except Exception: + error_count = 1 + print(f"[{class_name}] Sync stopped due to exception") + finally: + pbar.close() + + # 批量写入所有数据(仅在无错误时) + if results and not error_occurred: + for ts_code, data in results.items(): + if not data.empty: + self.storage.queue_save(self.table_name, data) + self.storage.flush() + total_rows = sum(len(df) for df in results.values()) + print(f"\n[{class_name}] Saved {total_rows} rows to storage") + + # 打印摘要 + print(f"\n{'=' * 60}") + print(f"[{class_name}] Sync Summary") + print(f"{'=' * 60}") + print(f" Total stocks: {len(stock_codes)}") + print(f" Updated: {success_count}") + print(f" Skipped (empty/delisted): {empty_count}") + print( + f" Errors: {error_count} (aborted on first error)" + if error_count + else " Errors: 0" + ) + print(f" Date range: {start_date} to {end_date}") + print(f"{'=' * 60}") + + return results + + def preview_sync( + self, + force_full: bool = False, + start_date: Optional[str] = None, + end_date: Optional[str] = None, + sample_size: int = 3, + ) -> dict: + """预览同步数据量和样本。 + + Args: + force_full: 若为 True,预览全量同步 + start_date: 手动指定起始日期 + end_date: 手动指定结束日期 + sample_size: 预览用样本股票数量 + + Returns: + 包含预览信息的字典 + """ + class_name = self.__class__.__name__ + + print(f"\n{'=' * 60}") + print(f"[{class_name}] Preview Mode - Analyzing sync requirements...") + print(f"{'=' * 60}") + + # 确保交易日历缓存是最新的 + print(f"[{class_name}] Syncing trade calendar cache...") + sync_trade_cal_cache() + + if end_date is None: + end_date = get_today_date() + + # 检查是否需要同步 + sync_needed, cal_start, cal_end, local_last = self.check_sync_needed(force_full) + + if not sync_needed: + print(f"\n{'=' * 60}") + print(f"[{class_name}] Preview Result") + print(f"{'=' * 60}") + print(" Sync Status: NOT NEEDED") + print(" Reason: Local data is up-to-date with trade calendar") + print(f"{'=' * 60}") + return { + "sync_needed": False, + "stock_count": 0, + "start_date": None, + "end_date": None, + "estimated_records": 0, + "sample_data": pd.DataFrame(), + "mode": "none", + } + + # 使用 check_sync_needed 返回的日期 + if cal_start and cal_end: + sync_start_date = cal_start + end_date = cal_end + else: + sync_start_date = start_date or self.DEFAULT_START_DATE + if end_date is None: + end_date = get_today_date() + + # 确定同步模式 + mode = self._determine_sync_mode(force_full, sync_start_date, local_last) + + if mode == "full": + print( + f"[{class_name}] Mode: FULL SYNC from {sync_start_date} to {end_date}" + ) + elif mode == "incremental": + print(f"[{class_name}] Mode: INCREMENTAL SYNC (bandwidth optimized)") + print(f"[{class_name}] Sync from: {sync_start_date} to {end_date}") + else: + print(f"[{class_name}] Mode: SYNC from {sync_start_date} to {end_date}") + + # 获取所有股票代码 + stock_codes = self.get_all_stock_codes() + if not stock_codes: + print(f"[{class_name}] No stocks found to sync") + return { + "sync_needed": False, + "stock_count": 0, + "start_date": None, + "end_date": None, + "estimated_records": 0, + "sample_data": pd.DataFrame(), + "mode": "none", + } + + stock_count = len(stock_codes) + print(f"[{class_name}] Total stocks to sync: {stock_count}") + + # 从前几只股票获取样本数据 + print(f"[{class_name}] Fetching sample data from {sample_size} stocks...") + sample_data_list = [] + sample_codes = stock_codes[:sample_size] + + for ts_code in sample_codes: + try: + data = self.fetch_single_stock(ts_code, sync_start_date, end_date) + if not data.empty: + sample_data_list.append(data) + print(f" - {ts_code}: {len(data)} records") + except Exception as e: + print(f" - {ts_code}: Error fetching - {e}") + + # 合并样本数据 + sample_df = ( + pd.concat(sample_data_list, ignore_index=True) + if sample_data_list + else pd.DataFrame() + ) + + # 基于样本估算总记录数 + if not sample_df.empty: + avg_records_per_stock = len(sample_df) / len(sample_data_list) + estimated_records = int(avg_records_per_stock * stock_count) + else: + estimated_records = 0 + + # 显示预览结果 + print(f"\n{'=' * 60}") + print(f"[{class_name}] Preview Result") + print(f"{'=' * 60}") + print(f" Sync Mode: {mode.upper()}") + print(f" Date Range: {sync_start_date} to {end_date}") + print(f" Stocks to Sync: {stock_count}") + print(f" Sample Stocks Checked: {len(sample_data_list)}/{sample_size}") + print(f" Estimated Total Records: ~{estimated_records:,}") + + if not sample_df.empty: + print(f"\n Sample Data Preview (first {len(sample_df)} rows):") + print(" " + "-" * 56) + preview_cols = [ + "ts_code", + "trade_date", + "open", + "high", + "low", + "close", + "vol", + ] + available_cols = [c for c in preview_cols if c in sample_df.columns] + sample_display = sample_df[available_cols].head(10) + for idx, row in sample_display.iterrows(): + print(f" {row.to_dict()}") + print(" " + "-" * 56) + + print(f"{'=' * 60}") + + return { + "sync_needed": True, + "stock_count": stock_count, + "start_date": sync_start_date, + "end_date": end_date, + "estimated_records": estimated_records, + "sample_data": sample_df, + "mode": mode, + } + + def sync_all( + self, + force_full: bool = False, + start_date: Optional[str] = None, + end_date: Optional[str] = None, + max_workers: Optional[int] = None, + dry_run: bool = False, + ) -> Dict[str, pd.DataFrame]: + """同步所有股票的数据。 + + Args: + force_full: 若为 True,强制完整重载 + start_date: 手动指定起始日期 + end_date: 手动指定结束日期 + max_workers: 工作线程数 + dry_run: 若为 True,仅预览 + + Returns: + 映射 ts_code 到 DataFrame 的字典 + """ + class_name = self.__class__.__name__ + + print(f"\n{'=' * 60}") + print(f"[{class_name}] Starting {self.table_name} data sync...") + print(f"{'=' * 60}") + + # 首先确保交易日历缓存是最新的 + print(f"[{class_name}] Syncing trade calendar cache...") + sync_trade_cal_cache() + + if end_date is None: + end_date = get_today_date() + + # 基于交易日历检查是否需要同步 + sync_needed, cal_start, cal_end, local_last = self.check_sync_needed(force_full) + + if not sync_needed: + print(f"\n{'=' * 60}") + print(f"[{class_name}] Sync Summary") + print(f"{'=' * 60}") + print(" Sync: SKIPPED (local data up-to-date with trade calendar)") + print(" Tokens saved: 0 consumed") + print(f"{'=' * 60}") + return {} + + # 使用 check_sync_needed 返回的日期 + if cal_start and cal_end: + sync_start_date = cal_start + end_date = cal_end + else: + sync_start_date = start_date or self.DEFAULT_START_DATE + if end_date is None: + end_date = get_today_date() + + # 确定同步模式 + mode = self._determine_sync_mode(force_full, sync_start_date, local_last) + + if mode == "full": + print( + f"[{class_name}] Mode: FULL SYNC from {sync_start_date} to {end_date}" + ) + elif mode == "incremental": + print(f"[{class_name}] Mode: INCREMENTAL SYNC (bandwidth optimized)") + print(f"[{class_name}] Sync from: {sync_start_date} to {end_date}") + else: + print(f"[{class_name}] Mode: SYNC from {sync_start_date} to {end_date}") + + # 获取所有股票代码 + stock_codes = self.get_all_stock_codes() + if not stock_codes: + print(f"[{class_name}] No stocks found to sync") + return {} + + print(f"[{class_name}] Total stocks to sync: {len(stock_codes)}") + print(f"[{class_name}] Using {max_workers or self.max_workers} worker threads") + + # 执行并发同步 + return self._run_concurrent_sync( + stock_codes=stock_codes, + start_date=sync_start_date, + end_date=end_date, + max_workers=max_workers, + dry_run=dry_run, + ) + + +class DateBasedSync(BaseDataSync): + """按日期同步的抽象类。 + + 适用于需要按日期逐个获取数据的接口(如 bak_basic)。 + 提供顺序遍历日期范围的能力。 + + 子类必须实现: + - fetch_single_date(trade_date) -> pd.DataFrame + - default_start_date: 类属性,指定默认起始日期 + """ + + default_start_date: str = "20160101" # 子类可覆盖 + + @abstractmethod + def fetch_single_date(self, trade_date: str) -> pd.DataFrame: + """获取单日的数据。 + + Args: + trade_date: 交易日期(YYYYMMDD) + + Returns: + 包含当日数据的 DataFrame + """ + pass + + def _ensure_table_schema(self, sample_data: pd.DataFrame) -> None: + """确保表结构存在(根据样本数据创建表)。 + + Args: + sample_data: 样本数据 DataFrame + """ + storage = Storage() + + if sample_data.empty: + return + + # 构建列定义 + columns = [] + for col in sample_data.columns: + dtype = str(sample_data[col].dtype) + if col == "trade_date": + col_type = "DATE" + elif "int" in dtype: + col_type = "INTEGER" + elif "float" in dtype: + col_type = "DOUBLE" + else: + col_type = "VARCHAR" + columns.append(f'"{col}" {col_type}') + + columns_sql = ", ".join(columns) + create_sql = f'CREATE TABLE IF NOT EXISTS "{self.table_name}" ({columns_sql}, PRIMARY KEY ("trade_date", "ts_code"))' + + try: + storage._connection.execute(create_sql) + print(f"[{self.__class__.__name__}] Created table '{self.table_name}'") + except Exception as e: + print(f"[{self.__class__.__name__}] Error creating table: {e}") + + # 创建复合索引 + try: + storage._connection.execute(f""" + CREATE INDEX IF NOT EXISTS "idx_{self.table_name}_date_code" + ON "{self.table_name}"("trade_date", "ts_code") + """) + print( + f"[{self.__class__.__name__}] Created composite index on (trade_date, ts_code)" + ) + except Exception as e: + print(f"[{self.__class__.__name__}] Error creating index: {e}") + + def _get_sync_date_range( + self, + start_date: Optional[str], + end_date: Optional[str], + force_full: bool, + ) -> tuple[str, str, str]: + """确定同步日期范围。 + + Args: + start_date: 指定的起始日期 + end_date: 指定的结束日期 + force_full: 是否强制全量同步 + + Returns: + (sync_start, sync_end, mode) 元组 + """ + class_name = self.__class__.__name__ + storage = Storage() + + # 默认结束日期 + if end_date is None: + end_date = datetime.now().strftime("%Y%m%d") + + # 检查表是否存在 + table_exists = storage.exists(self.table_name) + + if not table_exists or force_full: + # 全量同步 + sync_start = start_date or self.default_start_date + mode = "FULL" + print(f"[{class_name}] Mode: {mode} SYNC from {sync_start} to {end_date}") + else: + # 增量同步 + try: + result = storage._connection.execute( + f'SELECT MAX("trade_date") FROM "{self.table_name}"' + ).fetchone() + last_date = result[0] if result and result[0] else None + except Exception as e: + print(f"[{class_name}] Error getting last date: {e}") + last_date = None + + if last_date is None: + sync_start = start_date or self.default_start_date + mode = "FULL (empty table)" + else: + # 从最后日期+1开始 + last_date_str = str(last_date).replace("-", "") + last_dt = datetime.strptime(last_date_str, "%Y%m%d") + next_dt = last_dt + timedelta(days=1) + sync_start = next_dt.strftime("%Y%m%d") + mode = "INCREMENTAL" + + # 检查是否已最新 + if sync_start > end_date: + print( + f"[{class_name}] Data is up-to-date (last: {last_date}), skipping sync" + ) + return ("", "", "up_to_date") + + print( + f"[{class_name}] Mode: {mode} from {sync_start} to {end_date} (last: {last_date})" + ) + + return (sync_start, end_date, mode) + + def _run_date_range_sync( + self, + start_date: str, + end_date: str, + dry_run: bool = False, + ) -> pd.DataFrame: + """执行日期范围的顺序同步。 + + Args: + start_date: 起始日期(YYYYMMDD) + end_date: 结束日期(YYYYMMDD) + dry_run: 是否为预览模式 + + Returns: + 合并后的数据 DataFrame + """ + class_name = self.__class__.__name__ + + if dry_run or not start_date: + return pd.DataFrame() + + all_data: List[pd.DataFrame] = [] + current = datetime.strptime(start_date, "%Y%m%d") + end_dt = datetime.strptime(end_date, "%Y%m%d") + + # 计算总天数用于进度条 + total_days = (end_dt - current).days + 1 + print(f"[{class_name}] Fetching data for {total_days} days...") + + with tqdm(total=total_days, desc="Syncing dates") as pbar: + while current <= end_dt: + date_str = current.strftime("%Y%m%d") + try: + data = self.fetch_single_date(date_str) + if not data.empty: + all_data.append(data) + pbar.set_postfix({"date": date_str, "records": len(data)}) + except Exception as e: + print(f" {date_str}: ERROR - {e}") + + current += timedelta(days=1) + pbar.update(1) + + if not all_data: + print(f"[{class_name}] No data fetched") + return pd.DataFrame() + + # 合并数据 + combined = pd.concat(all_data, ignore_index=True) + + # 转换 trade_date 为 datetime + if "trade_date" in combined.columns: + combined["trade_date"] = pd.to_datetime( + combined["trade_date"], format="%Y%m%d" + ) + + print(f"[{class_name}] Total records: {len(combined)}") + + return combined + + def _save_data(self, data: pd.DataFrame, sync_start_date: str) -> None: + """保存数据到存储。 + + Args: + data: 要保存的数据 + sync_start_date: 同步起始日期(用于删除旧数据) + """ + if data.empty: + return + + storage = Storage() + + # 删除日期范围内的旧数据 + if sync_start_date: + sync_start_date_fmt = pd.to_datetime( + sync_start_date, format="%Y%m%d" + ).date() + storage._connection.execute( + f'DELETE FROM "{self.table_name}" WHERE "trade_date" >= ?', + [sync_start_date_fmt], + ) + + # 保存新数据 + self.storage.queue_save(self.table_name, data) + self.storage.flush() + + print(f"[{self.__class__.__name__}] Saved {len(data)} records to DuckDB") + + def preview_sync( + self, + force_full: bool = False, + start_date: Optional[str] = None, + end_date: Optional[str] = None, + sample_size: int = 3, + ) -> dict: + """预览同步数据量和样本。 + + Args: + force_full: 若为 True,预览全量同步 + start_date: 手动指定起始日期 + end_date: 手动指定结束日期 + sample_size: 预览天数 + + Returns: + 包含预览信息的字典 + """ + class_name = self.__class__.__name__ + + print(f"\n{'=' * 60}") + print(f"[{class_name}] Preview Mode - Analyzing sync requirements...") + print(f"{'=' * 60}") + + # 确定日期范围 + sync_start, sync_end, mode = self._get_sync_date_range( + start_date, end_date, force_full + ) + + if mode == "up_to_date": + return { + "sync_needed": False, + "date_count": 0, + "start_date": None, + "end_date": None, + "estimated_records": 0, + "sample_data": pd.DataFrame(), + "mode": "none", + } + + # 计算天数 + if sync_start and sync_end: + start_dt = datetime.strptime(sync_start, "%Y%m%d") + end_dt = datetime.strptime(sync_end, "%Y%m%d") + date_count = (end_dt - start_dt).days + 1 + else: + date_count = 0 + + # 获取样本数据(取前几天) + print( + f"[{class_name}] Fetching sample data for {min(sample_size, date_count)} days..." + ) + sample_data_list = [] + + if sync_start and date_count > 0: + current = datetime.strptime(sync_start, "%Y%m%d") + for _ in range(min(sample_size, date_count)): + date_str = current.strftime("%Y%m%d") + try: + data = self.fetch_single_date(date_str) + if not data.empty: + sample_data_list.append(data) + print(f" - {date_str}: {len(data)} records") + except Exception as e: + print(f" - {date_str}: Error fetching - {e}") + current += timedelta(days=1) + + sample_df = ( + pd.concat(sample_data_list, ignore_index=True) + if sample_data_list + else pd.DataFrame() + ) + + # 估算总记录数 + if not sample_df.empty and date_count > 0: + avg_records_per_day = len(sample_df) / len(sample_data_list) + estimated_records = int(avg_records_per_day * date_count) + else: + estimated_records = 0 + + # 显示预览结果 + print(f"\n{'=' * 60}") + print(f"[{class_name}] Preview Result") + print(f"{'=' * 60}") + print(f" Sync Mode: {mode}") + print(f" Date Range: {sync_start} to {sync_end}") + print(f" Days to Sync: {date_count}") + print(f" Sample Days Checked: {len(sample_data_list)}") + print(f" Estimated Total Records: ~{estimated_records:,}") + + if not sample_df.empty: + print(f"\n Sample Data Preview:") + print(" " + "-" * 56) + print(f" Columns: {list(sample_df.columns)}") + print(f" Rows: {len(sample_df)}") + print(" " + "-" * 56) + + print(f"{'=' * 60}") + + return { + "sync_needed": mode != "up_to_date", + "date_count": date_count, + "start_date": sync_start, + "end_date": sync_end, + "estimated_records": estimated_records, + "sample_data": sample_df, + "mode": mode.lower().replace(" ", "_"), + } + + def sync_all( + self, + force_full: bool = False, + start_date: Optional[str] = None, + end_date: Optional[str] = None, + max_workers: Optional[int] = None, + dry_run: bool = False, + ) -> pd.DataFrame: + """同步日期范围内的所有数据。 + + Args: + force_full: 若为 True,强制完整重载 + start_date: 手动指定起始日期 + end_date: 手动指定结束日期 + max_workers: 工作线程数(本类中不使用,保持接口一致) + dry_run: 若为 True,仅预览 + + Returns: + 同步的数据 DataFrame + """ + class_name = self.__class__.__name__ + + print(f"\n{'=' * 60}") + print(f"[{class_name}] Starting {self.table_name} data sync...") + print(f"{'=' * 60}") + + # 确定日期范围 + sync_start, sync_end, mode = self._get_sync_date_range( + start_date, end_date, force_full + ) + + if mode == "up_to_date": + return pd.DataFrame() + + if dry_run: + print(f"\n{'=' * 60}") + print(f"[{class_name}] DRY RUN MODE - No data will be written") + print(f"{'=' * 60}") + print(f" Would sync from {sync_start} to {sync_end}") + print(f" Mode: {mode}") + print(f"{'=' * 60}") + return pd.DataFrame() + + # 检查表是否存在,不存在则创建 + storage = Storage() + if not storage.exists(self.table_name): + print( + f"[{class_name}] Table '{self.table_name}' doesn't exist, creating..." + ) + # 获取样本数据以推断 schema + sample = self.fetch_single_date(sync_end) + if sample.empty: + # 尝试另一个日期 + sample = self.fetch_single_date("20240102") + if not sample.empty: + self._ensure_table_schema(sample) + else: + print(f"[{class_name}] Cannot create table: no sample data available") + return pd.DataFrame() + + # 执行同步 + combined = self._run_date_range_sync(sync_start, sync_end, dry_run) + + # 保存数据 + if not combined.empty: + self._save_data(combined, sync_start) + + # 打印摘要 + print(f"\n{'=' * 60}") + print(f"[{class_name}] Sync Summary") + print(f"{'=' * 60}") + print(f" Mode: {mode}") + print(f" Date range: {sync_start} to {sync_end}") + print(f" Records: {len(combined)}") + print(f"{'=' * 60}") + + return combined diff --git a/src/data/storage.py b/src/data/storage.py index 1b83d47..07dcd66 100644 --- a/src/data/storage.py +++ b/src/data/storage.py @@ -189,7 +189,10 @@ class Storage: end_net_profit DOUBLE, update_flag VARCHAR(1), PRIMARY KEY (ts_code, end_date) + update_flag VARCHAR(1), + PRIMARY KEY (ts_code, end_date) ) + """) # Create pro_bar table for pro bar data (with adj, tor, vr) self._connection.execute(""" diff --git a/src/data/sync.py b/src/data/sync.py index 0315d46..e4ba33b 100644 --- a/src/data/sync.py +++ b/src/data/sync.py @@ -3,7 +3,8 @@ 该模块作为数据同步的调度中心,统一管理各类型数据的同步流程。 具体的同步逻辑已迁移到对应的 api_xxx.py 文件中: - api_daily.py: 日线数据同步 (DailySync 类) -- api_bak_basic.py: 历史股票列表同步 +- api_bak_basic.py: 历史股票列表同步 (BakBasicSync 类) +- api_pro_bar.py: Pro Bar 数据同步 (ProBarSync 类) - api_stock_basic.py: 股票基本信息同步 - api_trade_cal.py: 交易日历同步 @@ -30,6 +31,7 @@ import pandas as pd from src.data.api_wrappers import sync_all_stocks from src.data.api_wrappers.api_daily import sync_daily, preview_daily_sync from src.data.api_wrappers.api_pro_bar import sync_pro_bar +from src.data.api_wrappers.api_bak_basic import sync_bak_basic def preview_sync( @@ -135,30 +137,31 @@ def sync_all_data( dry_run: bool = False, ) -> Dict[str, pd.DataFrame]: """同步所有数据类型(每日同步)。 - 该函数按顺序同步所有可用的数据类型: - 1. 交易日历 (sync_trade_cal_cache) - 2. 股票基本信息 (sync_all_stocks) - 3. 日线市场数据 (sync_all) - 4. 历史股票列表 (sync_bak_basic) - 注意:名称变更 (namechange) 不在自动同步中,如需同步请手动调用。 + 该函数按顺序同步所有可用的数据类型: + 1. 交易日历 (sync_trade_cal_cache) + 2. 股票基本信息 (sync_all_stocks) + 3. Pro Bar 数据 (sync_pro_bar) + 4. 历史股票列表 (sync_bak_basic) - Args: - force_full: 若为 True,强制所有数据类型完整重载 - max_workers: 日线数据同步的工作线程数(默认: 10) - dry_run: 若为 True,仅显示将要同步的内容,不写入数据 + 注意:名称变更 (namechange) 不在自动同步中,如需同步请手动调用。 - Returns: - 映射数据类型到同步结果的字典 + Args: + force_full: 若为 True,强制所有数据类型完整重载 + max_workers: 日线数据同步的工作线程数(默认: 10) + dry_run: 若为 True,仅显示将要同步的内容,不写入数据 - Example: - >>> result = sync_all_data() - >>> - >>> # 强制完整重载 - >>> result = sync_all_data(force_full=True) - >>> - >>> # Dry run - >>> result = sync_all_data(dry_run=True) + Returns: + 映射数据类型到同步结果的字典 + + Example: + >>> result = sync_all_data() + >>> + >>> # 强制完整重载 + >>> result = sync_all_data(force_full=True) + >>> + >>> # Dry run + >>> result = sync_all_data(dry_run=True) """ results: Dict[str, pd.DataFrame] = {} @@ -167,47 +170,29 @@ def sync_all_data( print("=" * 60) # 1. Sync trade calendar (always needed first) - print("\n[1/6] Syncing trade calendar cache...") + print("\n[1/4] Syncing trade calendar cache...") try: from src.data.api_wrappers import sync_trade_cal_cache sync_trade_cal_cache() results["trade_cal"] = pd.DataFrame() - print("[1/6] Trade calendar: OK") + print("[1/4] Trade calendar: OK") except Exception as e: - print(f"[1/6] Trade calendar: FAILED - {e}") + print(f"[1/4] Trade calendar: FAILED - {e}") results["trade_cal"] = pd.DataFrame() # 2. Sync stock basic info - print("\n[2/6] Syncing stock basic info...") + print("\n[2/4] Syncing stock basic info...") try: sync_all_stocks() results["stock_basic"] = pd.DataFrame() - print("[2/6] Stock basic: OK") + print("[2/4] Stock basic: OK") except Exception as e: - print(f"[2/6] Stock basic: FAILED - {e}") + print(f"[2/4] Stock basic: FAILED - {e}") results["stock_basic"] = pd.DataFrame() - # # 3. Sync daily market data - # print("\n[3/6] Syncing daily market data...") - # try: - # daily_result = sync_daily( - # force_full=force_full, - # max_workers=max_workers, - # dry_run=dry_run, - # ) - # results["daily"] = ( - # pd.concat(daily_result.values(), ignore_index=True) - # if daily_result - # else pd.DataFrame() - # ) - # print("[3/6] Daily data: OK") - # except Exception as e: - # print(f"[3/6] Daily data: FAILED - {e}") - # results["daily"] = pd.DataFrame() - - # 4. Sync Pro Bar data - print("\n[4/6] Syncing Pro Bar data (with adj, tor, vr)...") + # 3. Sync Pro Bar data + print("\n[3/4] Syncing Pro Bar data (with adj, tor, vr)...") try: pro_bar_result = sync_pro_bar( force_full=force_full, @@ -219,87 +204,19 @@ def sync_all_data( if pro_bar_result else pd.DataFrame() ) - print(f"[4/6] Pro Bar data: OK ({len(results['pro_bar'])} records)") + print(f"[3/4] Pro Bar data: OK ({len(results['pro_bar'])} records)") except Exception as e: - print(f"[4/6] Pro Bar data: FAILED - {e}") + print(f"[3/4] Pro Bar data: FAILED - {e}") results["pro_bar"] = pd.DataFrame() - # 5. Sync stock historical list (bak_basic) - print("\n[5/6] Syncing stock historical list (bak_basic)...") - try: - bak_basic_result = sync_bak_basic(force_full=force_full) - results["bak_basic"] = bak_basic_result - print(f"[5/6] Bak basic: OK ({len(bak_basic_result)} records)") - except Exception as e: - print(f"[5/6] Bak basic: FAILED - {e}") - results["bak_basic"] = pd.DataFrame() - - # Summary - print("\n" + "=" * 60) - print("[sync_all_data] Sync Summary") - print("=" * 60) - for data_type, df in results.items(): - print(f" {data_type}: {len(df)} records") - print("=" * 60) - print("\nNote: namechange is NOT in auto-sync. To sync manually:") - print(" from src.data.api_wrappers import sync_namechange") - print(" sync_namechange(force=True)") - - return results - results: Dict[str, pd.DataFrame] = {} - - print("\n" + "=" * 60) - print("[sync_all_data] Starting full data synchronization...") - print("=" * 60) - - # 1. Sync trade calendar (always needed first) - print("\n[1/5] Syncing trade calendar cache...") - try: - from src.data.api_wrappers import sync_trade_cal_cache - - sync_trade_cal_cache() - results["trade_cal"] = pd.DataFrame() - print("[1/5] Trade calendar: OK") - except Exception as e: - print(f"[1/5] Trade calendar: FAILED - {e}") - results["trade_cal"] = pd.DataFrame() - - # 2. Sync stock basic info - print("\n[2/5] Syncing stock basic info...") - try: - sync_all_stocks() - results["stock_basic"] = pd.DataFrame() - print("[2/5] Stock basic: OK") - except Exception as e: - print(f"[2/5] Stock basic: FAILED - {e}") - results["stock_basic"] = pd.DataFrame() - - # 3. Sync daily market data - print("\n[3/5] Syncing daily market data...") - try: - daily_result = sync_daily( - force_full=force_full, - max_workers=max_workers, - dry_run=dry_run, - ) - results["daily"] = ( - pd.concat(daily_result.values(), ignore_index=True) - if daily_result - else pd.DataFrame() - ) - print("[3/5] Daily data: OK") - except Exception as e: - print(f"[3/5] Daily data: FAILED - {e}") - results["daily"] = pd.DataFrame() - # 4. Sync stock historical list (bak_basic) - print("\n[4/5] Syncing stock historical list (bak_basic)...") + print("\n[4/4] Syncing stock historical list (bak_basic)...") try: bak_basic_result = sync_bak_basic(force_full=force_full) results["bak_basic"] = bak_basic_result - print(f"[4/5] Bak basic: OK ({len(bak_basic_result)} records)") + print(f"[4/4] Bak basic: OK ({len(bak_basic_result)} records)") except Exception as e: - print(f"[4/5] Bak basic: FAILED - {e}") + print(f"[4/4] Bak basic: FAILED - {e}") results["bak_basic"] = pd.DataFrame() # Summary @@ -316,10 +233,6 @@ def sync_all_data( return results -# 保留向后兼容的导入 -from src.data.api_wrappers import sync_bak_basic - - if __name__ == "__main__": print("=" * 60) print("Data Sync Module")