diff --git a/src/data/api_wrappers/__init__.py b/src/data/api_wrappers/__init__.py index d1b29fc..d5dbb2f 100644 --- a/src/data/api_wrappers/__init__.py +++ b/src/data/api_wrappers/__init__.py @@ -88,6 +88,7 @@ __all__ = [ # Historical stock list "get_bak_basic", "sync_bak_basic", + "BakBasicSync", # Namechange "get_namechange", "sync_namechange", @@ -105,3 +106,77 @@ __all__ = [ "sync_stock_st", "StockSTSync", ] + +# ============================================================================= +# 自动注册同步任务到 SyncRegistry +# 这样 sync.py 不需要手动罗列各个接口 +# ============================================================================= + +try: + from src.data.sync_registry import sync_registry + + # 1. Trade Calendar - 最高优先级,其他任务可能依赖 + sync_registry.register_func( + name="trade_cal", + sync_func=sync_trade_cal_cache, + display_name="交易日历", + description="交易日期缓存", + order=1, + ) + + # 2. Stock Basic - 基础数据 + sync_registry.register_func( + name="stock_basic", + sync_func=sync_all_stocks, + display_name="股票基本信息", + description="所有上市/退市股票的基础信息", + order=2, + ) + + # 3. Pro Bar - 通用行情(推荐用于替代日线) + from src.data.api_wrappers.api_pro_bar import ProBarSync + + sync_registry.register_class( + name="pro_bar", + sync_class=ProBarSync, + display_name="Pro Bar 数据", + description="包含复权因子、换手率、量比的数据", + order=10, + ) + + # 4. Daily Basic - 每日指标 + from src.data.api_wrappers.api_daily_basic import DailyBasicSync + + sync_registry.register_class( + name="daily_basic", + sync_class=DailyBasicSync, + display_name="每日指标", + description="市盈率、市净率、换手率、市值等指标", + order=20, + ) + + # 5. Bak Basic - 历史股票列表 + from src.data.api_wrappers.api_bak_basic import BakBasicSync + + sync_registry.register_class( + name="bak_basic", + sync_class=BakBasicSync, + display_name="历史股票列表", + description="历史股票列表(包含退市股票)", + order=30, + ) + + # 6. ST Stock - ST股票列表 + from src.data.api_wrappers.api_stock_st import StockSTSync + + sync_registry.register_class( + name="stock_st", + sync_class=StockSTSync, + display_name="ST股票列表", + description="ST股票历史记录", + order=40, + ) + +except ImportError: + # sync_registry 可能不存在(首次导入),忽略 + pass diff --git a/src/data/sync.py b/src/data/sync.py index 42c082c..baa6223 100644 --- a/src/data/sync.py +++ b/src/data/sync.py @@ -24,6 +24,10 @@ from src.data.api_wrappers import sync_namechange sync_namechange(force=True) +【架构说明】 +本模块使用 SyncRegistry 注册表模式管理同步任务,避免手动罗列各个接口。 +同步任务在 api_wrappers/__init__.py 中自动注册,新增接口无需修改 sync.py。 + 使用方式: # 预览同步(检查数据量,不写入) from src.data.sync import preview_sync @@ -35,18 +39,23 @@ # 强制全量重载 result = sync_all_data(force_full=True) + + # 查看已注册的所有同步任务 + from src.data.sync_registry import sync_registry + tasks = sync_registry.list_tasks() + for task in tasks: + print(f"{task.name}: {task.display_name}") """ from typing import Optional, Dict, Union, Any import pandas as pd +# 导入以触发自动注册 +from src.data import api_wrappers # noqa: F401 +from src.data.sync_registry import sync_registry from src.data.api_wrappers import sync_all_stocks from src.data.api_wrappers.api_daily import sync_daily, preview_daily_sync -from src.data.api_wrappers.api_pro_bar import sync_pro_bar -from src.data.api_wrappers.api_bak_basic import sync_bak_basic -from src.data.api_wrappers.api_daily_basic import sync_daily_basic -from src.data.api_wrappers.api_stock_st import sync_stock_st def preview_sync( @@ -150,19 +159,24 @@ def sync_all_data( force_full: bool = False, max_workers: Optional[int] = None, dry_run: bool = False, + selected: Optional[list[str]] = None, ) -> dict[str, Any]: """同步所有每日更新的数据类型。 【重要】本函数仅同步每日更新的数据,不包含季度/低频数据。 - 该函数按顺序同步以下每日更新的数据类型: - 1. 交易日历 (sync_trade_cal_cache) - 2. 股票基本信息 (sync_all_stocks) - 3. 日线数据 (sync_daily) - 4. Pro Bar 数据 (sync_pro_bar) - 5. 每日指标数据 (sync_daily_basic) - 6. 历史股票列表 (sync_bak_basic) - 7. ST股票列表 (sync_stock_st) + 【自动注册机制】 + 同步任务在 api_wrappers/__init__.py 中自动注册到 SyncRegistry。 + 当前注册的同步任务(按执行顺序): + 1. trade_cal: 交易日历缓存 + 2. stock_basic: 股票基本信息 + 3. pro_bar: Pro Bar 数据(复权、换手率、量比) + 4. daily_basic: 每日指标(PE、PB、换手率、市值) + 5. bak_basic: 历史股票列表 + 6. stock_st: ST股票列表 + + 新增接口时,只需在 api_wrappers/__init__.py 中添加注册代码, + 无需修改本函数。 【不包含的同步(需单独调用)】 - 财务数据: 利润表、资产负债表、现金流量表(季度更新) @@ -177,6 +191,8 @@ def sync_all_data( force_full: 若为 True,强制所有数据类型完整重载 max_workers: 日线数据同步的工作线程数(默认: 10) dry_run: 若为 True,仅显示将要同步的内容,不写入数据 + selected: 只同步指定的任务列表,None表示同步所有 + 例如: selected=["trade_cal", "stock_basic"] 只同步交易日历和股票基本信息 Returns: 映射数据类型到同步结果的字典 @@ -189,163 +205,73 @@ def sync_all_data( >>> >>> # Dry run >>> result = sync_all_data(dry_run=True) + >>> + >>> # 只同步特定任务 + >>> result = sync_all_data(selected=["trade_cal", "stock_basic"]) + >>> + >>> # 查看所有可用任务 + >>> from src.data.sync_registry import sync_registry + >>> tasks = sync_registry.list_tasks() + >>> for t in tasks: + ... print(f"{t.name}: {t.display_name}") """ - results: dict[str, Any] = {} - - print("\n" + "=" * 60) - print("[sync_all_data] Starting full data synchronization...") - print("=" * 60) - - # 1. Sync trade calendar (always needed first) - print("\n[1/7] Syncing trade calendar cache...") - try: - from src.data.api_wrappers import sync_trade_cal_cache - - sync_trade_cal_cache() - results["trade_cal"] = pd.DataFrame() - print("[1/7] Trade calendar: OK") - except Exception as e: - print(f"[1/7] Trade calendar: FAILED - {e}") - results["trade_cal"] = pd.DataFrame() - - # 2. Sync stock basic info - print("\n[2/7] Syncing stock basic info...") - try: - sync_all_stocks() - results["stock_basic"] = pd.DataFrame() - print("[2/7] Stock basic: OK") - except Exception as e: - print(f"[2/7] Stock basic: FAILED - {e}") - results["stock_basic"] = pd.DataFrame() - - # 3. Sync daily market data - # print("\n[3/7] Syncing daily market data...") - # try: - # # 确保表存在 - # from src.data.api_wrappers.api_daily import DailySync - # - # DailySync().ensure_table_exists() - # - # daily_result = sync_daily( - # force_full=force_full, - # max_workers=max_workers, - # dry_run=dry_run, - # ) - # results["daily"] = daily_result - # total_daily_records = ( - # sum(len(df) for df in daily_result.values()) if daily_result else 0 - # ) - # print( - # f"[3/7] Daily data: OK ({total_daily_records} records from {len(daily_result)} stocks)" - # ) - # except Exception as e: - # print(f"[3/7] Daily data: FAILED - {e}") - # results["daily"] = pd.DataFrame() - - # 4. Sync Pro Bar data - print("\n[4/7] Syncing Pro Bar data (with adj, tor, vr)...") - try: - # 确保表存在 - from src.data.api_wrappers.api_pro_bar import ProBarSync - - ProBarSync().ensure_table_exists() - - pro_bar_result = sync_pro_bar( - force_full=force_full, - max_workers=max_workers, - dry_run=dry_run, - ) - results["pro_bar"] = pro_bar_result - total_pro_bar_records = ( - sum(len(df) for df in pro_bar_result.values()) if pro_bar_result else 0 - ) - print( - f"[4/7] Pro Bar data: OK ({total_pro_bar_records} records from {len(pro_bar_result)} stocks)" - ) - except Exception as e: - print(f"[4/7] Pro Bar data: FAILED - {e}") - results["pro_bar"] = pd.DataFrame() - - # 5. Sync daily basic indicators - print( - "\n[5/7] Syncing daily basic indicators (PE, PB, turnover rate, market value)..." + return sync_registry.sync_all( + force_full=force_full, + max_workers=max_workers, + dry_run=dry_run, + selected=selected, ) - try: - # 确保表存在 - from src.data.api_wrappers.api_daily_basic import DailyBasicSync - DailyBasicSync().ensure_table_exists() - daily_basic_result = sync_daily_basic(force_full=force_full, dry_run=dry_run) - results["daily_basic"] = daily_basic_result - print(f"[5/7] Daily basic: OK ({len(daily_basic_result)} records)") - except Exception as e: - print(f"[5/7] Daily basic: FAILED - {e}") - results["daily_basic"] = pd.DataFrame() +def list_sync_tasks() -> list[dict[str, Any]]: + """列出所有已注册的同步任务。 - # 6. Sync stock historical list (bak_basic) - print("\n[6/7] Syncing stock historical list (bak_basic)...") - try: - # 确保表存在 - from src.data.api_wrappers.api_bak_basic import BakBasicSync + Returns: + 任务信息列表,每个任务包含 name, display_name, description, order, enabled - BakBasicSync().ensure_table_exists() - - bak_basic_result = sync_bak_basic(force_full=force_full) - results["bak_basic"] = bak_basic_result - print(f"[6/7] Bak basic: OK ({len(bak_basic_result)} records)") - except Exception as e: - print(f"[6/7] Bak basic: FAILED - {e}") - results["bak_basic"] = pd.DataFrame() - - # 7. Sync ST stock list - print("\n[7/7] Syncing ST stock list...") - try: - # 确保表存在 - from src.data.api_wrappers.api_stock_st import StockSTSync - - StockSTSync().ensure_table_exists() - - stock_st_result = sync_stock_st(force_full=force_full) - results["stock_st"] = stock_st_result - print(f"[7/7] ST stock list: OK ({len(stock_st_result)} records)") - except Exception as e: - print(f"[7/7] ST stock list: FAILED - {e}") - results["stock_st"] = pd.DataFrame() - - # Summary - print("\n" + "=" * 60) - print("[sync_all_data] Sync Summary") - print("=" * 60) - for data_type, data in results.items(): - if isinstance(data, dict): - # 日线和 Pro Bar 返回的是 dict[str, DataFrame] - total_records = sum(len(df) for df in data.values()) - print(f" {data_type}: {len(data)} stocks, {total_records} total records") - else: - # daily_basic 和 bak_basic 返回的是 DataFrame - print(f" {data_type}: {len(data)} records") - print("=" * 60) - print("\nNote: namechange is NOT in auto-sync. To sync manually:") - print(" from src.data.api_wrappers import sync_namechange") - print(" sync_namechange(force=True)") - - return results + Example: + >>> tasks = list_sync_tasks() + >>> for task in tasks: + ... print(f"{task['order']:2d}. {task['name']}: {task['display_name']}") + """ + tasks = sync_registry.list_tasks() + return [ + { + "name": t.name, + "display_name": t.display_name, + "description": t.description, + "order": t.order, + "enabled": t.enabled, + } + for t in tasks + ] if __name__ == "__main__": print("=" * 60) print("Data Sync Module") print("=" * 60) + print("\nRegistered sync tasks:") + print("-" * 60) + + tasks = list_sync_tasks() + for task in tasks: + status = "[启用]" if task["enabled"] else "[禁用]" + print(f" {status} {task['order']:2d}. {task['name']}: {task['display_name']}") + + print("-" * 60) + print(f"\nTotal: {len(tasks)} tasks") print("\nUsage:") print(" # Sync all data types at once (RECOMMENDED)") print(" from src.data.sync import sync_all_data") print(" result = sync_all_data() # Incremental sync all") print(" result = sync_all_data(force_full=True) # Full reload") print("") - print(" # Or sync individual data types:") - print(" from src.data.sync import sync_all, preview_sync") - print(" from src.data.api_wrappers import sync_daily_basic, sync_bak_basic") + print(" # Sync selected data types only") + print(" result = sync_all_data(selected=['trade_cal', 'pro_bar'])") + print("") + print(" # List all available sync tasks") + print(" tasks = list_sync_tasks()") print("") print(" # Preview before sync (recommended)") print(" preview = preview_sync()") @@ -356,10 +282,6 @@ if __name__ == "__main__": print(" # Actual sync") print(" result = sync_all() # Incremental sync") print(" result = sync_all(force_full=True) # Full reload") - print("") - print(" # bak_basic sync") - print(" result = sync_bak_basic() # Incremental sync") - print(" result = sync_bak_basic(force_full=True) # Full reload") print("\n" + "=" * 60) # Run sync_all_data by default diff --git a/src/data/sync_registry.py b/src/data/sync_registry.py new file mode 100644 index 0000000..ead58a5 --- /dev/null +++ b/src/data/sync_registry.py @@ -0,0 +1,333 @@ +"""数据同步注册表模块。 + +该模块提供统一的同步任务注册和管理机制,避免在 sync.py 中手动罗列各个接口。 + +使用方式: + # 在 api_xxx.py 文件中注册同步任务 + from src.data.sync_registry import sync_registry, SyncTask + + sync_registry.register( + SyncTask( + name="pro_bar", + display_name="Pro Bar 数据", + description="包含复权因子、换手率、量比的数据", + sync_func=lambda **kwargs: ProBarSync().sync_all(**kwargs), + preview_func=lambda **kwargs: ProBarSync().preview_sync(**kwargs), + order=10, # 执行顺序 + ) + ) + + # 在 sync.py 中统一执行 + from src.data.sync_registry import sync_registry + results = sync_registry.sync_all(force_full=False, dry_run=False) +""" + +from dataclasses import dataclass, field +from typing import Callable, Optional, Any +from collections import OrderedDict + +import pandas as pd + + +@dataclass +class SyncTask: + """同步任务定义。 + + Attributes: + name: 任务唯一标识名 + display_name: 显示名称(用于日志) + description: 任务描述 + sync_func: 同步函数,接收 force_full, dry_run 等参数 + preview_func: 预览函数(可选) + order: 执行顺序(数字越小越先执行,默认100) + enabled: 是否启用(默认True) + """ + + name: str + display_name: str + description: str + sync_func: Callable[..., pd.DataFrame | dict[str, pd.DataFrame]] + preview_func: Optional[Callable[..., Any]] = None + order: int = 100 + enabled: bool = True + + +class SyncRegistry: + """同步任务注册表。 + + 统一管理所有数据同步任务,支持自动发现和批量执行。 + + Example: + >>> registry = SyncRegistry() + >>> + >>> # 注册类方式同步器 + >>> registry.register_class("daily", DailySync, "日线数据", "股票日线行情", order=10) + >>> + >>> # 注册函数方式同步器 + >>> registry.register_func("stock_basic", sync_all_stocks, "股票基本信息", order=5) + >>> + >>> # 批量执行所有同步 + >>> results = registry.sync_all() + >>> + >>> # 只执行特定任务 + >>> results = registry.sync_selected(["stock_basic", "daily"]) + """ + + def __init__(self): + self._tasks: OrderedDict[str, SyncTask] = OrderedDict() + + def register(self, task: SyncTask) -> "SyncRegistry": + """注册同步任务。 + + Args: + task: 同步任务定义 + + Returns: + self,支持链式调用 + """ + if task.name in self._tasks: + print( + f"[SyncRegistry] Warning: Task '{task.name}' already registered, overwriting" + ) + + self._tasks[task.name] = task + return self + + def register_class( + self, + name: str, + sync_class: type, + display_name: str, + description: str, + order: int = 100, + ) -> "SyncRegistry": + """注册基于类的同步器。 + + Args: + name: 任务名 + sync_class: 同步器类(必须有 sync_all() 和 ensure_table_exists() 方法) + display_name: 显示名称 + description: 描述 + order: 执行顺序 + + Returns: + self,支持链式调用 + """ + + def sync_func(**kwargs) -> dict[str, pd.DataFrame]: + instance = sync_class() + instance.ensure_table_exists() + return instance.sync_all(**kwargs) + + def preview_func(**kwargs) -> Any: + instance = sync_class() + return instance.preview_sync(**kwargs) + + return self.register( + SyncTask( + name=name, + display_name=display_name, + description=description, + sync_func=sync_func, + preview_func=preview_func, + order=order, + ) + ) + + def register_func( + self, + name: str, + sync_func: Callable[..., pd.DataFrame], + display_name: str, + description: str = "", + order: int = 100, + ) -> "SyncRegistry": + """注册基于函数的同步器。 + + Args: + name: 任务名 + sync_func: 同步函数 + display_name: 显示名称 + description: 描述 + order: 执行顺序 + + Returns: + self,支持链式调用 + """ + return self.register( + SyncTask( + name=name, + display_name=display_name, + description=description, + sync_func=sync_func, + order=order, + ) + ) + + def get_task(self, name: str) -> Optional[SyncTask]: + """获取指定任务。 + + Args: + name: 任务名 + + Returns: + SyncTask 或 None + """ + return self._tasks.get(name) + + def list_tasks(self, enabled_only: bool = True) -> list[SyncTask]: + """获取所有任务列表(按 order 排序)。 + + Args: + enabled_only: 是否只返回启用的任务 + + Returns: + 排序后的任务列表 + """ + tasks = self._tasks.values() + if enabled_only: + tasks = [t for t in tasks if t.enabled] + return sorted(tasks, key=lambda t: t.order) + + def sync_all( + self, + force_full: bool = False, + dry_run: bool = False, + max_workers: Optional[int] = None, + selected: Optional[list[str]] = None, + ) -> dict[str, Any]: + """执行所有同步任务。 + + Args: + force_full: 是否强制完整重载 + dry_run: 是否仅预览 + max_workers: 工作线程数(传递给支持的任务) + selected: 只执行指定的任务列表,None表示执行所有 + + Returns: + 每个任务的执行结果字典 + """ + tasks = self.list_tasks(enabled_only=True) + + if selected: + tasks = [t for t in tasks if t.name in selected] + + total = len(tasks) + results: dict[str, Any] = {} + + print("\n" + "=" * 60) + print("[SyncRegistry] Starting data synchronization...") + print(f"[SyncRegistry] Total tasks: {total}") + print("=" * 60) + + for idx, task in enumerate(tasks, 1): + print(f"\n[{idx}/{total}] Syncing {task.display_name}...") + if task.description: + print(f" Description: {task.description}") + + try: + # 构建参数 + kwargs: dict[str, Any] = { + "force_full": force_full, + "dry_run": dry_run, + } + if max_workers is not None: + kwargs["max_workers"] = max_workers + + # 执行同步 + result = task.sync_func(**kwargs) + results[task.name] = result + + # 输出统计信息 + if isinstance(result, dict): + # 返回 dict[str, DataFrame],如日线数据 + total_records = sum(len(df) for df in result.values()) + print( + f"[{idx}/{total}] {task.display_name}: OK ({len(result)} items, {total_records} records)" + ) + elif isinstance(result, pd.DataFrame): + # 返回 DataFrame + print( + f"[{idx}/{total}] {task.display_name}: OK ({len(result)} records)" + ) + else: + print(f"[{idx}/{total}] {task.display_name}: OK") + + except Exception as e: + print(f"[{idx}/{total}] {task.display_name}: FAILED - {e}") + results[task.name] = pd.DataFrame() + + # Summary + print("\n" + "=" * 60) + print("[SyncRegistry] Sync Summary") + print("=" * 60) + + success_count = sum( + 1 + for name in results + if not (isinstance(results[name], pd.DataFrame) and results[name].empty) + ) + print( + f"Total tasks: {total}, Success: {success_count}, Failed: {total - success_count}" + ) + + for name, data in results.items(): + task = self._tasks.get(name) + display_name = task.display_name if task else name + + if isinstance(data, dict): + total_records = sum(len(df) for df in data.values()) + print(f" {display_name}: {len(data)} items, {total_records} records") + elif isinstance(data, pd.DataFrame): + status = "OK" if not data.empty else "EMPTY/FAILED" + print(f" {display_name}: {len(data)} records ({status})") + else: + print(f" {display_name}: Completed") + + print("=" * 60) + + return results + + def preview_all( + self, + force_full: bool = False, + selected: Optional[list[str]] = None, + ) -> dict[str, Any]: + """预览所有启用的任务。 + + Args: + force_full: 是否预览完整重载 + selected: 只预览指定的任务列表 + + Returns: + 每个任务的预览结果 + """ + tasks = self.list_tasks(enabled_only=True) + + if selected: + tasks = [t for t in tasks if t.name in selected] + + results: dict[str, Any] = {} + + print("\n" + "=" * 60) + print("[SyncRegistry] Previewing sync tasks...") + print("=" * 60) + + for task in tasks: + if task.preview_func is None: + print(f"\n[{task.display_name}] Preview not supported") + continue + + print(f"\n[{task.display_name}] Previewing...") + try: + result = task.preview_func(force_full=force_full) + results[task.name] = result + except Exception as e: + print(f"[{task.display_name}] Preview failed: {e}") + results[task.name] = None + + return results + + +# Global registry instance +sync_registry = SyncRegistry()