refactor(sync): 引入 SyncRegistry 注册表模式管理同步任务
- 新增 sync_registry.py 模块,提供统一的同步任务注册和管理机制 - 在 api_wrappers/__init__.py 中实现自动注册逻辑,新增接口无需修改 sync.py - 重构 sync_all_data() 函数,使用注册表模式替代手动罗列,代码从 400+ 行精简至 293 行 - 新增 selected 参数,支持选择性执行特定同步任务 - 新增 list_sync_tasks() 函数,方便查看所有已注册任务
This commit is contained in:
232
src/data/sync.py
232
src/data/sync.py
@@ -24,6 +24,10 @@
|
||||
from src.data.api_wrappers import sync_namechange
|
||||
sync_namechange(force=True)
|
||||
|
||||
【架构说明】
|
||||
本模块使用 SyncRegistry 注册表模式管理同步任务,避免手动罗列各个接口。
|
||||
同步任务在 api_wrappers/__init__.py 中自动注册,新增接口无需修改 sync.py。
|
||||
|
||||
使用方式:
|
||||
# 预览同步(检查数据量,不写入)
|
||||
from src.data.sync import preview_sync
|
||||
@@ -35,18 +39,23 @@
|
||||
|
||||
# 强制全量重载
|
||||
result = sync_all_data(force_full=True)
|
||||
|
||||
# 查看已注册的所有同步任务
|
||||
from src.data.sync_registry import sync_registry
|
||||
tasks = sync_registry.list_tasks()
|
||||
for task in tasks:
|
||||
print(f"{task.name}: {task.display_name}")
|
||||
"""
|
||||
|
||||
from typing import Optional, Dict, Union, Any
|
||||
|
||||
import pandas as pd
|
||||
|
||||
# 导入以触发自动注册
|
||||
from src.data import api_wrappers # noqa: F401
|
||||
from src.data.sync_registry import sync_registry
|
||||
from src.data.api_wrappers import sync_all_stocks
|
||||
from src.data.api_wrappers.api_daily import sync_daily, preview_daily_sync
|
||||
from src.data.api_wrappers.api_pro_bar import sync_pro_bar
|
||||
from src.data.api_wrappers.api_bak_basic import sync_bak_basic
|
||||
from src.data.api_wrappers.api_daily_basic import sync_daily_basic
|
||||
from src.data.api_wrappers.api_stock_st import sync_stock_st
|
||||
|
||||
|
||||
def preview_sync(
|
||||
@@ -150,19 +159,24 @@ def sync_all_data(
|
||||
force_full: bool = False,
|
||||
max_workers: Optional[int] = None,
|
||||
dry_run: bool = False,
|
||||
selected: Optional[list[str]] = None,
|
||||
) -> dict[str, Any]:
|
||||
"""同步所有每日更新的数据类型。
|
||||
|
||||
【重要】本函数仅同步每日更新的数据,不包含季度/低频数据。
|
||||
|
||||
该函数按顺序同步以下每日更新的数据类型:
|
||||
1. 交易日历 (sync_trade_cal_cache)
|
||||
2. 股票基本信息 (sync_all_stocks)
|
||||
3. 日线数据 (sync_daily)
|
||||
4. Pro Bar 数据 (sync_pro_bar)
|
||||
5. 每日指标数据 (sync_daily_basic)
|
||||
6. 历史股票列表 (sync_bak_basic)
|
||||
7. ST股票列表 (sync_stock_st)
|
||||
【自动注册机制】
|
||||
同步任务在 api_wrappers/__init__.py 中自动注册到 SyncRegistry。
|
||||
当前注册的同步任务(按执行顺序):
|
||||
1. trade_cal: 交易日历缓存
|
||||
2. stock_basic: 股票基本信息
|
||||
3. pro_bar: Pro Bar 数据(复权、换手率、量比)
|
||||
4. daily_basic: 每日指标(PE、PB、换手率、市值)
|
||||
5. bak_basic: 历史股票列表
|
||||
6. stock_st: ST股票列表
|
||||
|
||||
新增接口时,只需在 api_wrappers/__init__.py 中添加注册代码,
|
||||
无需修改本函数。
|
||||
|
||||
【不包含的同步(需单独调用)】
|
||||
- 财务数据: 利润表、资产负债表、现金流量表(季度更新)
|
||||
@@ -177,6 +191,8 @@ def sync_all_data(
|
||||
force_full: 若为 True,强制所有数据类型完整重载
|
||||
max_workers: 日线数据同步的工作线程数(默认: 10)
|
||||
dry_run: 若为 True,仅显示将要同步的内容,不写入数据
|
||||
selected: 只同步指定的任务列表,None表示同步所有
|
||||
例如: selected=["trade_cal", "stock_basic"] 只同步交易日历和股票基本信息
|
||||
|
||||
Returns:
|
||||
映射数据类型到同步结果的字典
|
||||
@@ -189,163 +205,73 @@ def sync_all_data(
|
||||
>>>
|
||||
>>> # Dry run
|
||||
>>> result = sync_all_data(dry_run=True)
|
||||
>>>
|
||||
>>> # 只同步特定任务
|
||||
>>> result = sync_all_data(selected=["trade_cal", "stock_basic"])
|
||||
>>>
|
||||
>>> # 查看所有可用任务
|
||||
>>> from src.data.sync_registry import sync_registry
|
||||
>>> tasks = sync_registry.list_tasks()
|
||||
>>> for t in tasks:
|
||||
... print(f"{t.name}: {t.display_name}")
|
||||
"""
|
||||
results: dict[str, Any] = {}
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("[sync_all_data] Starting full data synchronization...")
|
||||
print("=" * 60)
|
||||
|
||||
# 1. Sync trade calendar (always needed first)
|
||||
print("\n[1/7] Syncing trade calendar cache...")
|
||||
try:
|
||||
from src.data.api_wrappers import sync_trade_cal_cache
|
||||
|
||||
sync_trade_cal_cache()
|
||||
results["trade_cal"] = pd.DataFrame()
|
||||
print("[1/7] Trade calendar: OK")
|
||||
except Exception as e:
|
||||
print(f"[1/7] Trade calendar: FAILED - {e}")
|
||||
results["trade_cal"] = pd.DataFrame()
|
||||
|
||||
# 2. Sync stock basic info
|
||||
print("\n[2/7] Syncing stock basic info...")
|
||||
try:
|
||||
sync_all_stocks()
|
||||
results["stock_basic"] = pd.DataFrame()
|
||||
print("[2/7] Stock basic: OK")
|
||||
except Exception as e:
|
||||
print(f"[2/7] Stock basic: FAILED - {e}")
|
||||
results["stock_basic"] = pd.DataFrame()
|
||||
|
||||
# 3. Sync daily market data
|
||||
# print("\n[3/7] Syncing daily market data...")
|
||||
# try:
|
||||
# # 确保表存在
|
||||
# from src.data.api_wrappers.api_daily import DailySync
|
||||
#
|
||||
# DailySync().ensure_table_exists()
|
||||
#
|
||||
# daily_result = sync_daily(
|
||||
# force_full=force_full,
|
||||
# max_workers=max_workers,
|
||||
# dry_run=dry_run,
|
||||
# )
|
||||
# results["daily"] = daily_result
|
||||
# total_daily_records = (
|
||||
# sum(len(df) for df in daily_result.values()) if daily_result else 0
|
||||
# )
|
||||
# print(
|
||||
# f"[3/7] Daily data: OK ({total_daily_records} records from {len(daily_result)} stocks)"
|
||||
# )
|
||||
# except Exception as e:
|
||||
# print(f"[3/7] Daily data: FAILED - {e}")
|
||||
# results["daily"] = pd.DataFrame()
|
||||
|
||||
# 4. Sync Pro Bar data
|
||||
print("\n[4/7] Syncing Pro Bar data (with adj, tor, vr)...")
|
||||
try:
|
||||
# 确保表存在
|
||||
from src.data.api_wrappers.api_pro_bar import ProBarSync
|
||||
|
||||
ProBarSync().ensure_table_exists()
|
||||
|
||||
pro_bar_result = sync_pro_bar(
|
||||
force_full=force_full,
|
||||
max_workers=max_workers,
|
||||
dry_run=dry_run,
|
||||
)
|
||||
results["pro_bar"] = pro_bar_result
|
||||
total_pro_bar_records = (
|
||||
sum(len(df) for df in pro_bar_result.values()) if pro_bar_result else 0
|
||||
)
|
||||
print(
|
||||
f"[4/7] Pro Bar data: OK ({total_pro_bar_records} records from {len(pro_bar_result)} stocks)"
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"[4/7] Pro Bar data: FAILED - {e}")
|
||||
results["pro_bar"] = pd.DataFrame()
|
||||
|
||||
# 5. Sync daily basic indicators
|
||||
print(
|
||||
"\n[5/7] Syncing daily basic indicators (PE, PB, turnover rate, market value)..."
|
||||
return sync_registry.sync_all(
|
||||
force_full=force_full,
|
||||
max_workers=max_workers,
|
||||
dry_run=dry_run,
|
||||
selected=selected,
|
||||
)
|
||||
try:
|
||||
# 确保表存在
|
||||
from src.data.api_wrappers.api_daily_basic import DailyBasicSync
|
||||
|
||||
DailyBasicSync().ensure_table_exists()
|
||||
|
||||
daily_basic_result = sync_daily_basic(force_full=force_full, dry_run=dry_run)
|
||||
results["daily_basic"] = daily_basic_result
|
||||
print(f"[5/7] Daily basic: OK ({len(daily_basic_result)} records)")
|
||||
except Exception as e:
|
||||
print(f"[5/7] Daily basic: FAILED - {e}")
|
||||
results["daily_basic"] = pd.DataFrame()
|
||||
def list_sync_tasks() -> list[dict[str, Any]]:
|
||||
"""列出所有已注册的同步任务。
|
||||
|
||||
# 6. Sync stock historical list (bak_basic)
|
||||
print("\n[6/7] Syncing stock historical list (bak_basic)...")
|
||||
try:
|
||||
# 确保表存在
|
||||
from src.data.api_wrappers.api_bak_basic import BakBasicSync
|
||||
Returns:
|
||||
任务信息列表,每个任务包含 name, display_name, description, order, enabled
|
||||
|
||||
BakBasicSync().ensure_table_exists()
|
||||
|
||||
bak_basic_result = sync_bak_basic(force_full=force_full)
|
||||
results["bak_basic"] = bak_basic_result
|
||||
print(f"[6/7] Bak basic: OK ({len(bak_basic_result)} records)")
|
||||
except Exception as e:
|
||||
print(f"[6/7] Bak basic: FAILED - {e}")
|
||||
results["bak_basic"] = pd.DataFrame()
|
||||
|
||||
# 7. Sync ST stock list
|
||||
print("\n[7/7] Syncing ST stock list...")
|
||||
try:
|
||||
# 确保表存在
|
||||
from src.data.api_wrappers.api_stock_st import StockSTSync
|
||||
|
||||
StockSTSync().ensure_table_exists()
|
||||
|
||||
stock_st_result = sync_stock_st(force_full=force_full)
|
||||
results["stock_st"] = stock_st_result
|
||||
print(f"[7/7] ST stock list: OK ({len(stock_st_result)} records)")
|
||||
except Exception as e:
|
||||
print(f"[7/7] ST stock list: FAILED - {e}")
|
||||
results["stock_st"] = pd.DataFrame()
|
||||
|
||||
# Summary
|
||||
print("\n" + "=" * 60)
|
||||
print("[sync_all_data] Sync Summary")
|
||||
print("=" * 60)
|
||||
for data_type, data in results.items():
|
||||
if isinstance(data, dict):
|
||||
# 日线和 Pro Bar 返回的是 dict[str, DataFrame]
|
||||
total_records = sum(len(df) for df in data.values())
|
||||
print(f" {data_type}: {len(data)} stocks, {total_records} total records")
|
||||
else:
|
||||
# daily_basic 和 bak_basic 返回的是 DataFrame
|
||||
print(f" {data_type}: {len(data)} records")
|
||||
print("=" * 60)
|
||||
print("\nNote: namechange is NOT in auto-sync. To sync manually:")
|
||||
print(" from src.data.api_wrappers import sync_namechange")
|
||||
print(" sync_namechange(force=True)")
|
||||
|
||||
return results
|
||||
Example:
|
||||
>>> tasks = list_sync_tasks()
|
||||
>>> for task in tasks:
|
||||
... print(f"{task['order']:2d}. {task['name']}: {task['display_name']}")
|
||||
"""
|
||||
tasks = sync_registry.list_tasks()
|
||||
return [
|
||||
{
|
||||
"name": t.name,
|
||||
"display_name": t.display_name,
|
||||
"description": t.description,
|
||||
"order": t.order,
|
||||
"enabled": t.enabled,
|
||||
}
|
||||
for t in tasks
|
||||
]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("=" * 60)
|
||||
print("Data Sync Module")
|
||||
print("=" * 60)
|
||||
print("\nRegistered sync tasks:")
|
||||
print("-" * 60)
|
||||
|
||||
tasks = list_sync_tasks()
|
||||
for task in tasks:
|
||||
status = "[启用]" if task["enabled"] else "[禁用]"
|
||||
print(f" {status} {task['order']:2d}. {task['name']}: {task['display_name']}")
|
||||
|
||||
print("-" * 60)
|
||||
print(f"\nTotal: {len(tasks)} tasks")
|
||||
print("\nUsage:")
|
||||
print(" # Sync all data types at once (RECOMMENDED)")
|
||||
print(" from src.data.sync import sync_all_data")
|
||||
print(" result = sync_all_data() # Incremental sync all")
|
||||
print(" result = sync_all_data(force_full=True) # Full reload")
|
||||
print("")
|
||||
print(" # Or sync individual data types:")
|
||||
print(" from src.data.sync import sync_all, preview_sync")
|
||||
print(" from src.data.api_wrappers import sync_daily_basic, sync_bak_basic")
|
||||
print(" # Sync selected data types only")
|
||||
print(" result = sync_all_data(selected=['trade_cal', 'pro_bar'])")
|
||||
print("")
|
||||
print(" # List all available sync tasks")
|
||||
print(" tasks = list_sync_tasks()")
|
||||
print("")
|
||||
print(" # Preview before sync (recommended)")
|
||||
print(" preview = preview_sync()")
|
||||
@@ -356,10 +282,6 @@ if __name__ == "__main__":
|
||||
print(" # Actual sync")
|
||||
print(" result = sync_all() # Incremental sync")
|
||||
print(" result = sync_all(force_full=True) # Full reload")
|
||||
print("")
|
||||
print(" # bak_basic sync")
|
||||
print(" result = sync_bak_basic() # Incremental sync")
|
||||
print(" result = sync_bak_basic(force_full=True) # Full reload")
|
||||
print("\n" + "=" * 60)
|
||||
|
||||
# Run sync_all_data by default
|
||||
|
||||
Reference in New Issue
Block a user