refactor(sync): 引入 SyncRegistry 注册表模式管理同步任务

- 新增 sync_registry.py 模块,提供统一的同步任务注册和管理机制
- 在 api_wrappers/__init__.py 中实现自动注册逻辑,新增接口无需修改 sync.py
- 重构 sync_all_data() 函数,使用注册表模式替代手动罗列,代码从 400+ 行精简至 293 行
- 新增 selected 参数,支持选择性执行特定同步任务
- 新增 list_sync_tasks() 函数,方便查看所有已注册任务
This commit is contained in:
2026-03-05 21:11:18 +08:00
parent 5a1f278df8
commit aefe6d06cf
3 changed files with 485 additions and 155 deletions

View File

@@ -88,6 +88,7 @@ __all__ = [
# Historical stock list
"get_bak_basic",
"sync_bak_basic",
"BakBasicSync",
# Namechange
"get_namechange",
"sync_namechange",
@@ -105,3 +106,77 @@ __all__ = [
"sync_stock_st",
"StockSTSync",
]
# =============================================================================
# 自动注册同步任务到 SyncRegistry
# 这样 sync.py 不需要手动罗列各个接口
# =============================================================================
try:
from src.data.sync_registry import sync_registry
# 1. Trade Calendar - 最高优先级,其他任务可能依赖
sync_registry.register_func(
name="trade_cal",
sync_func=sync_trade_cal_cache,
display_name="交易日历",
description="交易日期缓存",
order=1,
)
# 2. Stock Basic - 基础数据
sync_registry.register_func(
name="stock_basic",
sync_func=sync_all_stocks,
display_name="股票基本信息",
description="所有上市/退市股票的基础信息",
order=2,
)
# 3. Pro Bar - 通用行情(推荐用于替代日线)
from src.data.api_wrappers.api_pro_bar import ProBarSync
sync_registry.register_class(
name="pro_bar",
sync_class=ProBarSync,
display_name="Pro Bar 数据",
description="包含复权因子、换手率、量比的数据",
order=10,
)
# 4. Daily Basic - 每日指标
from src.data.api_wrappers.api_daily_basic import DailyBasicSync
sync_registry.register_class(
name="daily_basic",
sync_class=DailyBasicSync,
display_name="每日指标",
description="市盈率、市净率、换手率、市值等指标",
order=20,
)
# 5. Bak Basic - 历史股票列表
from src.data.api_wrappers.api_bak_basic import BakBasicSync
sync_registry.register_class(
name="bak_basic",
sync_class=BakBasicSync,
display_name="历史股票列表",
description="历史股票列表(包含退市股票)",
order=30,
)
# 6. ST Stock - ST股票列表
from src.data.api_wrappers.api_stock_st import StockSTSync
sync_registry.register_class(
name="stock_st",
sync_class=StockSTSync,
display_name="ST股票列表",
description="ST股票历史记录",
order=40,
)
except ImportError:
# sync_registry 可能不存在(首次导入),忽略
pass

View File

@@ -24,6 +24,10 @@
from src.data.api_wrappers import sync_namechange
sync_namechange(force=True)
【架构说明】
本模块使用 SyncRegistry 注册表模式管理同步任务,避免手动罗列各个接口。
同步任务在 api_wrappers/__init__.py 中自动注册,新增接口无需修改 sync.py。
使用方式:
# 预览同步(检查数据量,不写入)
from src.data.sync import preview_sync
@@ -35,18 +39,23 @@
# 强制全量重载
result = sync_all_data(force_full=True)
# 查看已注册的所有同步任务
from src.data.sync_registry import sync_registry
tasks = sync_registry.list_tasks()
for task in tasks:
print(f"{task.name}: {task.display_name}")
"""
from typing import Optional, Dict, Union, Any
import pandas as pd
# 导入以触发自动注册
from src.data import api_wrappers # noqa: F401
from src.data.sync_registry import sync_registry
from src.data.api_wrappers import sync_all_stocks
from src.data.api_wrappers.api_daily import sync_daily, preview_daily_sync
from src.data.api_wrappers.api_pro_bar import sync_pro_bar
from src.data.api_wrappers.api_bak_basic import sync_bak_basic
from src.data.api_wrappers.api_daily_basic import sync_daily_basic
from src.data.api_wrappers.api_stock_st import sync_stock_st
def preview_sync(
@@ -150,19 +159,24 @@ def sync_all_data(
force_full: bool = False,
max_workers: Optional[int] = None,
dry_run: bool = False,
selected: Optional[list[str]] = None,
) -> dict[str, Any]:
"""同步所有每日更新的数据类型。
【重要】本函数仅同步每日更新的数据,不包含季度/低频数据。
该函数按顺序同步以下每日更新的数据类型:
1. 交易日历 (sync_trade_cal_cache)
2. 股票基本信息 (sync_all_stocks)
3. 日线数据 (sync_daily)
4. Pro Bar 数据 (sync_pro_bar)
5. 每日指标数据 (sync_daily_basic)
6. 历史股票列表 (sync_bak_basic)
7. ST股票列表 (sync_stock_st)
【自动注册机制】
同步任务在 api_wrappers/__init__.py 中自动注册到 SyncRegistry。
当前注册的同步任务(按执行顺序):
1. trade_cal: 交易日历缓存
2. stock_basic: 股票基本信息
3. pro_bar: Pro Bar 数据(复权、换手率、量比)
4. daily_basic: 每日指标PE、PB、换手率、市值
5. bak_basic: 历史股票列表
6. stock_st: ST股票列表
新增接口时,只需在 api_wrappers/__init__.py 中添加注册代码,
无需修改本函数。
【不包含的同步(需单独调用)】
- 财务数据: 利润表、资产负债表、现金流量表(季度更新)
@@ -177,6 +191,8 @@ def sync_all_data(
force_full: 若为 True强制所有数据类型完整重载
max_workers: 日线数据同步的工作线程数(默认: 10
dry_run: 若为 True仅显示将要同步的内容不写入数据
selected: 只同步指定的任务列表None表示同步所有
例如: selected=["trade_cal", "stock_basic"] 只同步交易日历和股票基本信息
Returns:
映射数据类型到同步结果的字典
@@ -189,163 +205,73 @@ def sync_all_data(
>>>
>>> # Dry run
>>> result = sync_all_data(dry_run=True)
>>>
>>> # 只同步特定任务
>>> result = sync_all_data(selected=["trade_cal", "stock_basic"])
>>>
>>> # 查看所有可用任务
>>> from src.data.sync_registry import sync_registry
>>> tasks = sync_registry.list_tasks()
>>> for t in tasks:
... print(f"{t.name}: {t.display_name}")
"""
results: dict[str, Any] = {}
print("\n" + "=" * 60)
print("[sync_all_data] Starting full data synchronization...")
print("=" * 60)
# 1. Sync trade calendar (always needed first)
print("\n[1/7] Syncing trade calendar cache...")
try:
from src.data.api_wrappers import sync_trade_cal_cache
sync_trade_cal_cache()
results["trade_cal"] = pd.DataFrame()
print("[1/7] Trade calendar: OK")
except Exception as e:
print(f"[1/7] Trade calendar: FAILED - {e}")
results["trade_cal"] = pd.DataFrame()
# 2. Sync stock basic info
print("\n[2/7] Syncing stock basic info...")
try:
sync_all_stocks()
results["stock_basic"] = pd.DataFrame()
print("[2/7] Stock basic: OK")
except Exception as e:
print(f"[2/7] Stock basic: FAILED - {e}")
results["stock_basic"] = pd.DataFrame()
# 3. Sync daily market data
# print("\n[3/7] Syncing daily market data...")
# try:
# # 确保表存在
# from src.data.api_wrappers.api_daily import DailySync
#
# DailySync().ensure_table_exists()
#
# daily_result = sync_daily(
# force_full=force_full,
# max_workers=max_workers,
# dry_run=dry_run,
# )
# results["daily"] = daily_result
# total_daily_records = (
# sum(len(df) for df in daily_result.values()) if daily_result else 0
# )
# print(
# f"[3/7] Daily data: OK ({total_daily_records} records from {len(daily_result)} stocks)"
# )
# except Exception as e:
# print(f"[3/7] Daily data: FAILED - {e}")
# results["daily"] = pd.DataFrame()
# 4. Sync Pro Bar data
print("\n[4/7] Syncing Pro Bar data (with adj, tor, vr)...")
try:
# 确保表存在
from src.data.api_wrappers.api_pro_bar import ProBarSync
ProBarSync().ensure_table_exists()
pro_bar_result = sync_pro_bar(
return sync_registry.sync_all(
force_full=force_full,
max_workers=max_workers,
dry_run=dry_run,
selected=selected,
)
results["pro_bar"] = pro_bar_result
total_pro_bar_records = (
sum(len(df) for df in pro_bar_result.values()) if pro_bar_result else 0
)
print(
f"[4/7] Pro Bar data: OK ({total_pro_bar_records} records from {len(pro_bar_result)} stocks)"
)
except Exception as e:
print(f"[4/7] Pro Bar data: FAILED - {e}")
results["pro_bar"] = pd.DataFrame()
# 5. Sync daily basic indicators
print(
"\n[5/7] Syncing daily basic indicators (PE, PB, turnover rate, market value)..."
)
try:
# 确保表存在
from src.data.api_wrappers.api_daily_basic import DailyBasicSync
DailyBasicSync().ensure_table_exists()
def list_sync_tasks() -> list[dict[str, Any]]:
"""列出所有已注册的同步任务。
daily_basic_result = sync_daily_basic(force_full=force_full, dry_run=dry_run)
results["daily_basic"] = daily_basic_result
print(f"[5/7] Daily basic: OK ({len(daily_basic_result)} records)")
except Exception as e:
print(f"[5/7] Daily basic: FAILED - {e}")
results["daily_basic"] = pd.DataFrame()
Returns:
任务信息列表,每个任务包含 name, display_name, description, order, enabled
# 6. Sync stock historical list (bak_basic)
print("\n[6/7] Syncing stock historical list (bak_basic)...")
try:
# 确保表存在
from src.data.api_wrappers.api_bak_basic import BakBasicSync
BakBasicSync().ensure_table_exists()
bak_basic_result = sync_bak_basic(force_full=force_full)
results["bak_basic"] = bak_basic_result
print(f"[6/7] Bak basic: OK ({len(bak_basic_result)} records)")
except Exception as e:
print(f"[6/7] Bak basic: FAILED - {e}")
results["bak_basic"] = pd.DataFrame()
# 7. Sync ST stock list
print("\n[7/7] Syncing ST stock list...")
try:
# 确保表存在
from src.data.api_wrappers.api_stock_st import StockSTSync
StockSTSync().ensure_table_exists()
stock_st_result = sync_stock_st(force_full=force_full)
results["stock_st"] = stock_st_result
print(f"[7/7] ST stock list: OK ({len(stock_st_result)} records)")
except Exception as e:
print(f"[7/7] ST stock list: FAILED - {e}")
results["stock_st"] = pd.DataFrame()
# Summary
print("\n" + "=" * 60)
print("[sync_all_data] Sync Summary")
print("=" * 60)
for data_type, data in results.items():
if isinstance(data, dict):
# 日线和 Pro Bar 返回的是 dict[str, DataFrame]
total_records = sum(len(df) for df in data.values())
print(f" {data_type}: {len(data)} stocks, {total_records} total records")
else:
# daily_basic 和 bak_basic 返回的是 DataFrame
print(f" {data_type}: {len(data)} records")
print("=" * 60)
print("\nNote: namechange is NOT in auto-sync. To sync manually:")
print(" from src.data.api_wrappers import sync_namechange")
print(" sync_namechange(force=True)")
return results
Example:
>>> tasks = list_sync_tasks()
>>> for task in tasks:
... print(f"{task['order']:2d}. {task['name']}: {task['display_name']}")
"""
tasks = sync_registry.list_tasks()
return [
{
"name": t.name,
"display_name": t.display_name,
"description": t.description,
"order": t.order,
"enabled": t.enabled,
}
for t in tasks
]
if __name__ == "__main__":
print("=" * 60)
print("Data Sync Module")
print("=" * 60)
print("\nRegistered sync tasks:")
print("-" * 60)
tasks = list_sync_tasks()
for task in tasks:
status = "[启用]" if task["enabled"] else "[禁用]"
print(f" {status} {task['order']:2d}. {task['name']}: {task['display_name']}")
print("-" * 60)
print(f"\nTotal: {len(tasks)} tasks")
print("\nUsage:")
print(" # Sync all data types at once (RECOMMENDED)")
print(" from src.data.sync import sync_all_data")
print(" result = sync_all_data() # Incremental sync all")
print(" result = sync_all_data(force_full=True) # Full reload")
print("")
print(" # Or sync individual data types:")
print(" from src.data.sync import sync_all, preview_sync")
print(" from src.data.api_wrappers import sync_daily_basic, sync_bak_basic")
print(" # Sync selected data types only")
print(" result = sync_all_data(selected=['trade_cal', 'pro_bar'])")
print("")
print(" # List all available sync tasks")
print(" tasks = list_sync_tasks()")
print("")
print(" # Preview before sync (recommended)")
print(" preview = preview_sync()")
@@ -356,10 +282,6 @@ if __name__ == "__main__":
print(" # Actual sync")
print(" result = sync_all() # Incremental sync")
print(" result = sync_all(force_full=True) # Full reload")
print("")
print(" # bak_basic sync")
print(" result = sync_bak_basic() # Incremental sync")
print(" result = sync_bak_basic(force_full=True) # Full reload")
print("\n" + "=" * 60)
# Run sync_all_data by default

333
src/data/sync_registry.py Normal file
View File

@@ -0,0 +1,333 @@
"""数据同步注册表模块。
该模块提供统一的同步任务注册和管理机制,避免在 sync.py 中手动罗列各个接口。
使用方式:
# 在 api_xxx.py 文件中注册同步任务
from src.data.sync_registry import sync_registry, SyncTask
sync_registry.register(
SyncTask(
name="pro_bar",
display_name="Pro Bar 数据",
description="包含复权因子、换手率、量比的数据",
sync_func=lambda **kwargs: ProBarSync().sync_all(**kwargs),
preview_func=lambda **kwargs: ProBarSync().preview_sync(**kwargs),
order=10, # 执行顺序
)
)
# 在 sync.py 中统一执行
from src.data.sync_registry import sync_registry
results = sync_registry.sync_all(force_full=False, dry_run=False)
"""
from dataclasses import dataclass, field
from typing import Callable, Optional, Any
from collections import OrderedDict
import pandas as pd
@dataclass
class SyncTask:
"""同步任务定义。
Attributes:
name: 任务唯一标识名
display_name: 显示名称(用于日志)
description: 任务描述
sync_func: 同步函数,接收 force_full, dry_run 等参数
preview_func: 预览函数(可选)
order: 执行顺序数字越小越先执行默认100
enabled: 是否启用默认True
"""
name: str
display_name: str
description: str
sync_func: Callable[..., pd.DataFrame | dict[str, pd.DataFrame]]
preview_func: Optional[Callable[..., Any]] = None
order: int = 100
enabled: bool = True
class SyncRegistry:
"""同步任务注册表。
统一管理所有数据同步任务,支持自动发现和批量执行。
Example:
>>> registry = SyncRegistry()
>>>
>>> # 注册类方式同步器
>>> registry.register_class("daily", DailySync, "日线数据", "股票日线行情", order=10)
>>>
>>> # 注册函数方式同步器
>>> registry.register_func("stock_basic", sync_all_stocks, "股票基本信息", order=5)
>>>
>>> # 批量执行所有同步
>>> results = registry.sync_all()
>>>
>>> # 只执行特定任务
>>> results = registry.sync_selected(["stock_basic", "daily"])
"""
def __init__(self):
self._tasks: OrderedDict[str, SyncTask] = OrderedDict()
def register(self, task: SyncTask) -> "SyncRegistry":
"""注册同步任务。
Args:
task: 同步任务定义
Returns:
self支持链式调用
"""
if task.name in self._tasks:
print(
f"[SyncRegistry] Warning: Task '{task.name}' already registered, overwriting"
)
self._tasks[task.name] = task
return self
def register_class(
self,
name: str,
sync_class: type,
display_name: str,
description: str,
order: int = 100,
) -> "SyncRegistry":
"""注册基于类的同步器。
Args:
name: 任务名
sync_class: 同步器类(必须有 sync_all() 和 ensure_table_exists() 方法)
display_name: 显示名称
description: 描述
order: 执行顺序
Returns:
self支持链式调用
"""
def sync_func(**kwargs) -> dict[str, pd.DataFrame]:
instance = sync_class()
instance.ensure_table_exists()
return instance.sync_all(**kwargs)
def preview_func(**kwargs) -> Any:
instance = sync_class()
return instance.preview_sync(**kwargs)
return self.register(
SyncTask(
name=name,
display_name=display_name,
description=description,
sync_func=sync_func,
preview_func=preview_func,
order=order,
)
)
def register_func(
self,
name: str,
sync_func: Callable[..., pd.DataFrame],
display_name: str,
description: str = "",
order: int = 100,
) -> "SyncRegistry":
"""注册基于函数的同步器。
Args:
name: 任务名
sync_func: 同步函数
display_name: 显示名称
description: 描述
order: 执行顺序
Returns:
self支持链式调用
"""
return self.register(
SyncTask(
name=name,
display_name=display_name,
description=description,
sync_func=sync_func,
order=order,
)
)
def get_task(self, name: str) -> Optional[SyncTask]:
"""获取指定任务。
Args:
name: 任务名
Returns:
SyncTask 或 None
"""
return self._tasks.get(name)
def list_tasks(self, enabled_only: bool = True) -> list[SyncTask]:
"""获取所有任务列表(按 order 排序)。
Args:
enabled_only: 是否只返回启用的任务
Returns:
排序后的任务列表
"""
tasks = self._tasks.values()
if enabled_only:
tasks = [t for t in tasks if t.enabled]
return sorted(tasks, key=lambda t: t.order)
def sync_all(
self,
force_full: bool = False,
dry_run: bool = False,
max_workers: Optional[int] = None,
selected: Optional[list[str]] = None,
) -> dict[str, Any]:
"""执行所有同步任务。
Args:
force_full: 是否强制完整重载
dry_run: 是否仅预览
max_workers: 工作线程数(传递给支持的任务)
selected: 只执行指定的任务列表None表示执行所有
Returns:
每个任务的执行结果字典
"""
tasks = self.list_tasks(enabled_only=True)
if selected:
tasks = [t for t in tasks if t.name in selected]
total = len(tasks)
results: dict[str, Any] = {}
print("\n" + "=" * 60)
print("[SyncRegistry] Starting data synchronization...")
print(f"[SyncRegistry] Total tasks: {total}")
print("=" * 60)
for idx, task in enumerate(tasks, 1):
print(f"\n[{idx}/{total}] Syncing {task.display_name}...")
if task.description:
print(f" Description: {task.description}")
try:
# 构建参数
kwargs: dict[str, Any] = {
"force_full": force_full,
"dry_run": dry_run,
}
if max_workers is not None:
kwargs["max_workers"] = max_workers
# 执行同步
result = task.sync_func(**kwargs)
results[task.name] = result
# 输出统计信息
if isinstance(result, dict):
# 返回 dict[str, DataFrame],如日线数据
total_records = sum(len(df) for df in result.values())
print(
f"[{idx}/{total}] {task.display_name}: OK ({len(result)} items, {total_records} records)"
)
elif isinstance(result, pd.DataFrame):
# 返回 DataFrame
print(
f"[{idx}/{total}] {task.display_name}: OK ({len(result)} records)"
)
else:
print(f"[{idx}/{total}] {task.display_name}: OK")
except Exception as e:
print(f"[{idx}/{total}] {task.display_name}: FAILED - {e}")
results[task.name] = pd.DataFrame()
# Summary
print("\n" + "=" * 60)
print("[SyncRegistry] Sync Summary")
print("=" * 60)
success_count = sum(
1
for name in results
if not (isinstance(results[name], pd.DataFrame) and results[name].empty)
)
print(
f"Total tasks: {total}, Success: {success_count}, Failed: {total - success_count}"
)
for name, data in results.items():
task = self._tasks.get(name)
display_name = task.display_name if task else name
if isinstance(data, dict):
total_records = sum(len(df) for df in data.values())
print(f" {display_name}: {len(data)} items, {total_records} records")
elif isinstance(data, pd.DataFrame):
status = "OK" if not data.empty else "EMPTY/FAILED"
print(f" {display_name}: {len(data)} records ({status})")
else:
print(f" {display_name}: Completed")
print("=" * 60)
return results
def preview_all(
self,
force_full: bool = False,
selected: Optional[list[str]] = None,
) -> dict[str, Any]:
"""预览所有启用的任务。
Args:
force_full: 是否预览完整重载
selected: 只预览指定的任务列表
Returns:
每个任务的预览结果
"""
tasks = self.list_tasks(enabled_only=True)
if selected:
tasks = [t for t in tasks if t.name in selected]
results: dict[str, Any] = {}
print("\n" + "=" * 60)
print("[SyncRegistry] Previewing sync tasks...")
print("=" * 60)
for task in tasks:
if task.preview_func is None:
print(f"\n[{task.display_name}] Preview not supported")
continue
print(f"\n[{task.display_name}] Previewing...")
try:
result = task.preview_func(force_full=force_full)
results[task.name] = result
except Exception as e:
print(f"[{task.display_name}] Preview failed: {e}")
results[task.name] = None
return results
# Global registry instance
sync_registry = SyncRegistry()