refactor: 提取数据同步逻辑为抽象基类
新增 base_sync.py 模块，提供三层抽象结构，统一数据同步流程：
- BaseDataSync: 所有同步类型的基础抽象（客户端、股票代码获取、交易日历）
- StockBasedSync: 按股票同步的抽象类（适用于 daily、pro_bar）
- DateBasedSync: 按日期同步的抽象类（适用于 bak_basic）
This commit is contained in:
@@ -5,12 +5,10 @@ Data available from 2016 onwards.
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
from typing import Optional, List
|
||||
from datetime import datetime, timedelta
|
||||
from tqdm import tqdm
|
||||
from typing import Optional
|
||||
|
||||
from src.data.client import TushareClient
|
||||
from src.data.storage import ThreadSafeStorage, Storage
|
||||
from src.data.db_manager import ensure_table
|
||||
from src.data.api_wrappers.base_sync import DateBasedSync
|
||||
|
||||
|
||||
def get_bak_basic(
|
||||
@@ -75,6 +73,34 @@ def get_bak_basic(
|
||||
return data
|
||||
|
||||
|
||||
class BakBasicSync(DateBasedSync):
    """Batch sync manager for the historical stock list (``bak_basic``).

    Extends ``DateBasedSync``: data is fetched one trade date at a time,
    and both full and incremental sync are supported through the base
    class. Data is available from 2016 onwards, which is why the default
    start date is ``20160101``.

    Example:
        >>> sync = BakBasicSync()
        >>> results = sync.sync_all()  # incremental sync
        >>> results = sync.sync_all(force_full=True)  # full sync
        >>> preview = sync.preview_sync()  # preview
    """

    # Name of the storage table this sync manager targets.
    table_name = "bak_basic"
    # Earliest date used when no explicit start date applies.
    default_start_date = "20160101"

    def fetch_single_date(self, trade_date: str) -> pd.DataFrame:
        """Fetch the historical stock list for one trade date.

        Args:
            trade_date: Trade date in YYYYMMDD format.

        Returns:
            DataFrame returned by ``get_bak_basic`` for that date.
        """
        daily_snapshot = get_bak_basic(trade_date=trade_date)
        return daily_snapshot
|
||||
|
||||
|
||||
def sync_bak_basic(
|
||||
start_date: Optional[str] = None,
|
||||
end_date: Optional[str] = None,
|
||||
@@ -94,152 +120,12 @@ def sync_bak_basic(
|
||||
Returns:
|
||||
pd.DataFrame with synced data
|
||||
"""
|
||||
from src.data.db_manager import ensure_table
|
||||
|
||||
TABLE_NAME = "bak_basic"
|
||||
storage = Storage()
|
||||
thread_storage = ThreadSafeStorage()
|
||||
|
||||
# Default end date
|
||||
if end_date is None:
|
||||
end_date = datetime.now().strftime("%Y%m%d")
|
||||
|
||||
# Check if table exists
|
||||
table_exists = storage.exists(TABLE_NAME)
|
||||
|
||||
if not table_exists or force_full:
|
||||
# ===== FULL SYNC =====
|
||||
# 1. Create table with schema
|
||||
# 2. Create composite index (trade_date, ts_code)
|
||||
# 3. Full sync from start_date
|
||||
|
||||
if not table_exists:
|
||||
print(f"[sync_bak_basic] Table '{TABLE_NAME}' doesn't exist, creating...")
|
||||
|
||||
# Fetch sample to get schema
|
||||
sample = get_bak_basic(trade_date=end_date)
|
||||
if sample.empty:
|
||||
sample = get_bak_basic(trade_date="20240102")
|
||||
|
||||
if sample.empty:
|
||||
print("[sync_bak_basic] Cannot create table: no sample data available")
|
||||
return pd.DataFrame()
|
||||
|
||||
# Create table with schema
|
||||
columns = []
|
||||
for col in sample.columns:
|
||||
dtype = str(sample[col].dtype)
|
||||
if col == "trade_date":
|
||||
col_type = "DATE"
|
||||
elif "int" in dtype:
|
||||
col_type = "INTEGER"
|
||||
elif "float" in dtype:
|
||||
col_type = "DOUBLE"
|
||||
else:
|
||||
col_type = "VARCHAR"
|
||||
columns.append(f'"{col}" {col_type}')
|
||||
|
||||
columns_sql = ", ".join(columns)
|
||||
create_sql = f'CREATE TABLE IF NOT EXISTS "{TABLE_NAME}" ({columns_sql}, PRIMARY KEY ("trade_date", "ts_code"))'
|
||||
|
||||
try:
|
||||
storage._connection.execute(create_sql)
|
||||
print(f"[sync_bak_basic] Created table '{TABLE_NAME}'")
|
||||
except Exception as e:
|
||||
print(f"[sync_bak_basic] Error creating table: {e}")
|
||||
|
||||
# Create composite index
|
||||
try:
|
||||
storage._connection.execute(f"""
|
||||
CREATE INDEX IF NOT EXISTS "idx_bak_basic_date_code"
|
||||
ON "{TABLE_NAME}"("trade_date", "ts_code")
|
||||
""")
|
||||
print(f"[sync_bak_basic] Created composite index on (trade_date, ts_code)")
|
||||
except Exception as e:
|
||||
print(f"[sync_bak_basic] Error creating index: {e}")
|
||||
|
||||
# Determine sync dates
|
||||
sync_start = start_date or "20160101"
|
||||
mode = "FULL"
|
||||
print(f"[sync_bak_basic] Mode: {mode} SYNC from {sync_start} to {end_date}")
|
||||
|
||||
else:
|
||||
# ===== INCREMENTAL SYNC =====
|
||||
# Check last date in table, sync from last_date + 1
|
||||
|
||||
try:
|
||||
result = storage._connection.execute(
|
||||
f'SELECT MAX("trade_date") FROM "{TABLE_NAME}"'
|
||||
).fetchone()
|
||||
last_date = result[0] if result and result[0] else None
|
||||
except Exception as e:
|
||||
print(f"[sync_bak_basic] Error getting last date: {e}")
|
||||
last_date = None
|
||||
|
||||
if last_date is None:
|
||||
# Table exists but empty, do full sync
|
||||
sync_start = start_date or "20160101"
|
||||
mode = "FULL (empty table)"
|
||||
else:
|
||||
# Incremental from last_date + 1
|
||||
# Handle both YYYYMMDD and YYYY-MM-DD formats
|
||||
last_date_str = str(last_date).replace("-", "")
|
||||
last_dt = datetime.strptime(last_date_str, "%Y%m%d")
|
||||
next_dt = last_dt + timedelta(days=1)
|
||||
sync_start = next_dt.strftime("%Y%m%d")
|
||||
mode = "INCREMENTAL"
|
||||
|
||||
# Skip if already up to date
|
||||
if sync_start > end_date:
|
||||
print(f"[sync_bak_basic] Data is up-to-date (last: {last_date}), skipping sync")
|
||||
return pd.DataFrame()
|
||||
|
||||
print(f"[sync_bak_basic] Mode: {mode} from {sync_start} to {end_date} (last: {last_date})")
|
||||
|
||||
# ===== FETCH AND SAVE DATA =====
|
||||
all_data: List[pd.DataFrame] = []
|
||||
current = datetime.strptime(sync_start, "%Y%m%d")
|
||||
end_dt = datetime.strptime(end_date, "%Y%m%d")
|
||||
|
||||
# Calculate total days for progress bar
|
||||
total_days = (end_dt - current).days + 1
|
||||
print(f"[sync_bak_basic] Fetching data for {total_days} days...")
|
||||
|
||||
with tqdm(total=total_days, desc="Syncing dates") as pbar:
|
||||
while current <= end_dt:
|
||||
date_str = current.strftime("%Y%m%d")
|
||||
try:
|
||||
data = get_bak_basic(trade_date=date_str)
|
||||
if not data.empty:
|
||||
all_data.append(data)
|
||||
pbar.set_postfix({"date": date_str, "records": len(data)})
|
||||
except Exception as e:
|
||||
print(f" {date_str}: ERROR - {e}")
|
||||
|
||||
current += timedelta(days=1)
|
||||
pbar.update(1)
|
||||
|
||||
if not all_data:
|
||||
print("[sync_bak_basic] No data fetched")
|
||||
return pd.DataFrame()
|
||||
|
||||
# Combine and save
|
||||
combined = pd.concat(all_data, ignore_index=True)
|
||||
|
||||
# Convert trade_date to datetime for proper DATE type storage
|
||||
combined["trade_date"] = pd.to_datetime(combined["trade_date"], format="%Y%m%d")
|
||||
|
||||
print(f"[sync_bak_basic] Total records: {len(combined)}")
|
||||
|
||||
# Delete existing data for the date range and append new data
|
||||
# Convert sync_start to date format for comparison with DATE column
|
||||
sync_start_date = pd.to_datetime(sync_start, format="%Y%m%d").date()
|
||||
storage._connection.execute(f'DELETE FROM "{TABLE_NAME}" WHERE "trade_date" >= ?', [sync_start_date])
|
||||
thread_storage.queue_save(TABLE_NAME, combined)
|
||||
thread_storage.flush()
|
||||
|
||||
print(f"[sync_bak_basic] Saved {len(combined)} records to DuckDB")
|
||||
return combined
|
||||
sync_manager = BakBasicSync()
|
||||
return sync_manager.sync_all(
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
force_full=force_full,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -9,21 +9,9 @@ batch synchronization (DailySync class) for daily market data.
|
||||
|
||||
import pandas as pd
|
||||
from typing import Optional, List, Literal, Dict
|
||||
from datetime import datetime, timedelta
|
||||
from tqdm import tqdm
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
import threading
|
||||
|
||||
from src.data.client import TushareClient
|
||||
from src.data.storage import ThreadSafeStorage, Storage
|
||||
from src.data.utils import get_today_date, get_next_date, DEFAULT_START_DATE
|
||||
from src.config.settings import get_settings
|
||||
from src.data.api_wrappers.api_trade_cal import (
|
||||
get_first_trading_day,
|
||||
get_last_trading_day,
|
||||
sync_trade_cal_cache,
|
||||
)
|
||||
from src.data.api_wrappers.api_stock_basic import _get_csv_path, sync_all_stocks
|
||||
from src.data.api_wrappers.base_sync import StockBasedSync
|
||||
|
||||
|
||||
def get_daily(
|
||||
@@ -90,426 +78,27 @@ def get_daily(
|
||||
return data
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# DailySync - 日线数据批量同步类
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class DailySync:
|
||||
class DailySync(StockBasedSync):
|
||||
"""日线数据批量同步管理器,支持全量/增量同步。
|
||||
|
||||
功能特性:
|
||||
- 多线程并发获取(ThreadPoolExecutor)
|
||||
- 增量同步(自动检测上次同步位置)
|
||||
- 内存缓存(避免重复磁盘读取)
|
||||
- 异常立即停止(确保数据一致性)
|
||||
- 预览模式(预览同步数据量,不实际写入)
|
||||
继承自 StockBasedSync,使用多线程按股票并发获取数据。
|
||||
|
||||
Example:
|
||||
>>> sync = DailySync()
|
||||
>>> results = sync.sync_all() # 增量同步
|
||||
>>> results = sync.sync_all(force_full=True) # 全量同步
|
||||
>>> preview = sync.preview_sync() # 预览
|
||||
"""
|
||||
|
||||
# 默认工作线程数(从配置读取,默认10)
|
||||
DEFAULT_MAX_WORKERS = get_settings().threads
|
||||
table_name = "daily"
|
||||
|
||||
def __init__(self, max_workers: Optional[int] = None):
    """Initialize the sync manager.

    Args:
        max_workers: Number of worker threads. Falls back to
            ``DEFAULT_MAX_WORKERS`` when omitted (or falsy, e.g. 0).
    """
    # Storage handle read by _load_daily_data() and sync_all(); without
    # this assignment those methods fail with AttributeError. Mirrors
    # ProBarSync.__init__, which creates the same kind of handle.
    self.storage = ThreadSafeStorage()
    # Single shared client so rate limiting applies across worker threads.
    self.client = TushareClient()
    self.max_workers = max_workers or self.DEFAULT_MAX_WORKERS
    # Event semantics are inverted on purpose: "set" means "keep running";
    # a worker clears it to tell the other threads to stop.
    self._stop_flag = threading.Event()
    self._stop_flag.set()
    # In-memory cache of the "daily" table (None until first load).
    self._cached_daily_data: Optional[pd.DataFrame] = None
|
||||
|
||||
def _load_daily_data(self) -> pd.DataFrame:
    """Return the "daily" table, loading it from storage at most once.

    The loaded DataFrame is kept on the instance so repeated calls do not
    hit storage again; call ``clear_cache()`` to force a reload.

    Returns:
        The cached DataFrame, or the one freshly loaded from storage.
    """
    cached = self._cached_daily_data
    if cached is not None:
        return cached

    # Cache miss: load once and remember the result on the instance.
    cached = self.storage.load("daily")
    self._cached_daily_data = cached
    return cached
|
||||
|
||||
def clear_cache(self) -> None:
    """Drop the cached daily DataFrame so the next load re-reads storage."""
    self._cached_daily_data = None
|
||||
|
||||
def get_all_stock_codes(self, only_listed: bool = True) -> list:
    """Collect stock codes, preferring stock_basic.csv over daily data.

    stock_basic.csv is refreshed first and used as the primary source so
    that the full stock universe is covered (avoids look-ahead bias in
    backtests). If the CSV is missing, unreadable, empty or lacks a
    ``ts_code`` column, codes are taken from the locally stored daily data.

    Args:
        only_listed: If True, keep only rows whose ``list_status`` is "L"
            (applies only when the CSV has that column). Pass False to
            include delisted stocks as well.

    Returns:
        List of unique stock codes; empty list when no source has any.
    """
    # Refresh stock_basic.csv before reading it.
    print("[DailySync] Ensuring stock_basic.csv is up-to-date...")
    sync_all_stocks()

    csv_path = _get_csv_path()
    if csv_path.exists():
        print(f"[DailySync] Reading stock_basic from CSV: {csv_path}")
        try:
            basics = pd.read_csv(csv_path, encoding="utf-8-sig")
            if basics.empty or "ts_code" not in basics.columns:
                print(
                    "[DailySync] stock_basic.csv exists but no ts_code column or empty"
                )
            elif only_listed and "list_status" in basics.columns:
                # Keep currently listed stocks only.
                listed_mask = basics["list_status"] == "L"
                codes = basics.loc[listed_mask, "ts_code"].unique().tolist()
                total = len(basics["ts_code"].unique())
                print(
                    f"[DailySync] Found {len(codes)} listed stocks (filtered from {total} total)"
                )
                return codes
            else:
                codes = basics["ts_code"].unique().tolist()
                print(
                    f"[DailySync] Found {len(codes)} stock codes from stock_basic.csv"
                )
                return codes
        except Exception as e:
            # Best effort: any CSV problem falls through to the daily data.
            print(f"[DailySync] Error reading stock_basic.csv: {e}")

    # Fallback: derive the codes from the stored daily table.
    print(
        "[DailySync] stock_basic.csv not available, falling back to daily data..."
    )
    daily_frame = self._load_daily_data()
    if not daily_frame.empty and "ts_code" in daily_frame.columns:
        codes = daily_frame["ts_code"].unique().tolist()
        print(f"[DailySync] Found {len(codes)} stock codes from daily data")
        return codes

    print("[DailySync] No stock codes found in local storage")
    return []
|
||||
|
||||
def get_global_last_date(self) -> Optional[str]:
    """Return the latest ``trade_date`` in the local daily data.

    Returns:
        The maximum ``trade_date`` as a string, or None when the data is
        empty or has no ``trade_date`` column.
    """
    frame = self._load_daily_data()
    has_dates = not frame.empty and "trade_date" in frame.columns
    if not has_dates:
        return None
    latest = frame["trade_date"].max()
    return str(latest)
|
||||
|
||||
def get_global_first_date(self) -> Optional[str]:
    """Return the earliest ``trade_date`` in the local daily data.

    Returns:
        The minimum ``trade_date`` as a string, or None when the data is
        empty or has no ``trade_date`` column.
    """
    frame = self._load_daily_data()
    has_dates = not frame.empty and "trade_date" in frame.columns
    if not has_dates:
        return None
    earliest = frame["trade_date"].min()
    return str(earliest)
|
||||
|
||||
def get_trade_calendar_bounds(
    self, start_date: str, end_date: str
) -> tuple[Optional[str], Optional[str]]:
    """Look up the first and last trading day within a date range.

    Args:
        start_date: Range start (YYYYMMDD).
        end_date: Range end (YYYYMMDD).

    Returns:
        ``(first_trading_day, last_trading_day)``; ``(None, None)`` if
        either calendar lookup raises.
    """
    try:
        bounds = (
            get_first_trading_day(start_date, end_date),
            get_last_trading_day(start_date, end_date),
        )
    except Exception as e:
        # Callers treat (None, None) as "calendar unavailable".
        print(f"[ERROR] Failed to get trade calendar bounds: {e}")
        return (None, None)
    return bounds
|
||||
|
||||
def check_sync_needed(
    self,
    force_full: bool = False,
    table_name: str = "daily",
) -> tuple[bool, Optional[str], Optional[str], Optional[str]]:
    """Decide from the trade calendar whether a sync is required.

    Decision order:
      1. ``force_full``: sync from ``DEFAULT_START_DATE`` to today.
      2. No local data (table missing, empty, or without ``trade_date``):
         sync from ``DEFAULT_START_DATE`` to today.
      3. Trade calendar unavailable: sync from ``DEFAULT_START_DATE`` to
         today, reporting the local last date.
      4. Local last date >= last calendar trading day: no sync needed.
      5. Otherwise: sync from the day after the local last date up to the
         last calendar trading day.

    Args:
        force_full: If True, always report that a sync is needed.
        table_name: Table whose local data is inspected (default "daily").

    Returns:
        ``(sync_needed, start_date, end_date, local_last_date)``. The two
        dates are None when no sync is needed; ``local_last_date`` is None
        unless local data was found and a sync is still required.
    """
    if force_full:
        print("[DailySync] Force full sync requested")
        return (True, DEFAULT_START_DATE, get_today_date(), None)

    # Inspect what is already stored for this table.
    store = Storage()
    if store.exists(table_name):
        local_frame = store.load(table_name)
    else:
        local_frame = pd.DataFrame()

    has_dates = not local_frame.empty and "trade_date" in local_frame.columns
    if not has_dates:
        print(
            f"[DailySync] No local data found for table '{table_name}', full sync needed"
        )
        return (True, DEFAULT_START_DATE, get_today_date(), None)

    local_last_date = str(local_frame["trade_date"].max())
    print(f"[DailySync] Local data last date: {local_last_date}")

    # Latest trading day according to the calendar.
    today = get_today_date()
    _, cal_last = self.get_trade_calendar_bounds(DEFAULT_START_DATE, today)
    if cal_last is None:
        print("[DailySync] Failed to get trade calendar, proceeding with sync")
        return (True, DEFAULT_START_DATE, today, local_last_date)

    print(f"[DailySync] Calendar last trading day: {cal_last}")
    print(
        f"[DailySync] Comparing: local={local_last_date} (type={type(local_last_date).__name__}), "
        f"cal={cal_last} (type={type(cal_last).__name__})"
    )

    try:
        # Compare as integers (YYYYMMDD orders correctly numerically).
        local_as_int = int(local_last_date)
        calendar_as_int = int(cal_last)
        is_current = local_as_int >= calendar_as_int
        print(
            f"[DailySync] Comparing integers: local={local_as_int} >= cal={calendar_as_int} = "
            f"{is_current}"
        )
        if is_current:
            print(
                "[DailySync] Local data is up-to-date, SKIPPING sync (no tokens consumed)"
            )
            return (False, None, None, None)
    except (ValueError, TypeError) as e:
        # Non-numeric dates: fall through and sync incrementally anyway.
        print(f"[ERROR] Date comparison failed: {e}")

    # Incremental range: day after the local last date up to the calendar end.
    sync_start = get_next_date(local_last_date)
    print(f"[DailySync] Incremental sync needed from {sync_start} to {cal_last}")
    return (True, sync_start, cal_last, local_last_date)
|
||||
|
||||
def preview_sync(
    self,
    force_full: bool = False,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
    sample_size: int = 3,
) -> dict:
    """Preview the volume of a pending sync without saving daily data.

    Reports how many stocks would be synced, the date range, an estimated
    record count and sample rows from the first few stocks. This method
    does not save fetched rows itself, but it still refreshes the trade
    calendar cache and (via ``get_all_stock_codes``) the stock list, and
    it queries the API for the sample stocks.

    Args:
        force_full: If True, preview a full sync.
        start_date: Manual start date; used only when ``check_sync_needed``
            yields no date range (see the NOTE in the body).
        end_date: Manual end date; defaults to today and is replaced by
            the range returned from ``check_sync_needed`` when available.
        sample_size: Number of stocks queried for the sample (default 3).

    Returns:
        Dict with keys ``sync_needed`` (bool), ``stock_count`` (int),
        ``start_date`` / ``end_date`` (str or None), ``estimated_records``
        (int), ``sample_data`` (DataFrame) and ``mode`` ("full",
        "incremental", "partial" or "none").
    """

    def _nothing_to_sync() -> dict:
        # Result shared by the "up-to-date" and "no stocks" early exits.
        return {
            "sync_needed": False,
            "stock_count": 0,
            "start_date": None,
            "end_date": None,
            "estimated_records": 0,
            "sample_data": pd.DataFrame(),
            "mode": "none",
        }

    print("\n" + "=" * 60)
    print("[DailySync] Preview Mode - Analyzing sync requirements...")
    print("=" * 60)

    # Make sure the trade calendar cache is current before comparing dates.
    print("[DailySync] Syncing trade calendar cache...")
    sync_trade_cal_cache()

    if end_date is None:
        end_date = get_today_date()

    sync_needed, cal_start, cal_end, local_last = self.check_sync_needed(force_full)

    if not sync_needed:
        print("\n" + "=" * 60)
        print("[DailySync] Preview Result")
        print("=" * 60)
        print(" Sync Status: NOT NEEDED")
        print(" Reason: Local data is up-to-date with trade calendar")
        print("=" * 60)
        return _nothing_to_sync()

    # NOTE(review): check_sync_needed presumably always returns a non-empty
    # range when a sync is needed, so the manual start_date/end_date only
    # take effect in the fallback branch - confirm this precedence is wanted.
    if cal_start and cal_end:
        sync_start_date = cal_start
        end_date = cal_end
    else:
        # end_date was already defaulted to today above.
        sync_start_date = start_date or DEFAULT_START_DATE

    # Classify the sync mode for reporting.
    if force_full:
        mode = "full"
        print(f"[DailySync] Mode: FULL SYNC from {sync_start_date} to {end_date}")
    elif local_last and cal_start and sync_start_date == get_next_date(local_last):
        mode = "incremental"
        # Same wording as sync_all (the original had the typo "INCREmental").
        print("[DailySync] Mode: INCREMENTAL SYNC (bandwidth optimized)")
        print(f"[DailySync] Sync from: {sync_start_date} to {end_date}")
    else:
        mode = "partial"
        print(f"[DailySync] Mode: SYNC from {sync_start_date} to {end_date}")

    stock_codes = self.get_all_stock_codes()
    if not stock_codes:
        print("[DailySync] No stocks found to sync")
        return _nothing_to_sync()

    stock_count = len(stock_codes)
    print(f"[DailySync] Total stocks to sync: {stock_count}")

    # Query only the first few stocks to estimate the volume.
    print(f"[DailySync] Fetching sample data from {sample_size} stocks...")
    sample_data_list = []
    for ts_code in stock_codes[:sample_size]:
        try:
            data = self.client.query(
                "pro_bar",
                ts_code=ts_code,
                start_date=sync_start_date,
                end_date=end_date,
                factors="tor,vr",
            )
            if not data.empty:
                sample_data_list.append(data)
                print(f" - {ts_code}: {len(data)} records")
        except Exception as e:
            # A failing sample must not abort the preview.
            print(f" - {ts_code}: Error fetching - {e}")

    sample_df = (
        pd.concat(sample_data_list, ignore_index=True)
        if sample_data_list
        else pd.DataFrame()
    )

    # Extrapolate: average rows per sampled stock times the stock count.
    if not sample_df.empty:
        avg_records_per_stock = len(sample_df) / len(sample_data_list)
        estimated_records = int(avg_records_per_stock * stock_count)
    else:
        estimated_records = 0

    print("\n" + "=" * 60)
    print("[DailySync] Preview Result")
    print("=" * 60)
    print(f" Sync Mode: {mode.upper()}")
    print(f" Date Range: {sync_start_date} to {end_date}")
    print(f" Stocks to Sync: {stock_count}")
    print(f" Sample Stocks Checked: {len(sample_data_list)}/{sample_size}")
    print(f" Estimated Total Records: ~{estimated_records:,}")

    if not sample_df.empty:
        preview_cols = [
            "ts_code",
            "trade_date",
            "open",
            "high",
            "low",
            "close",
            "vol",
        ]
        available_cols = [c for c in preview_cols if c in sample_df.columns]
        sample_display = sample_df[available_cols].head(10)
        # Report the number of rows actually printed (at most 10), not the
        # size of the whole sample.
        print(f"\n Sample Data Preview (first {len(sample_display)} rows):")
        print(" " + "-" * 56)
        for _, row in sample_display.iterrows():
            print(f" {row.to_dict()}")
        print(" " + "-" * 56)

    print("=" * 60)

    return {
        "sync_needed": True,
        "stock_count": stock_count,
        "start_date": sync_start_date,
        "end_date": end_date,
        "estimated_records": estimated_records,
        "sample_data": sample_df,
        "mode": mode,
    }
|
||||
|
||||
def sync_single_stock(
|
||||
def fetch_single_stock(
|
||||
self,
|
||||
ts_code: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
) -> pd.DataFrame:
|
||||
"""同步单只股票的日线数据。
|
||||
"""获取单只股票的日线数据。
|
||||
|
||||
Args:
|
||||
ts_code: 股票代码
|
||||
@@ -517,13 +106,8 @@ class DailySync:
|
||||
end_date: 结束日期(YYYYMMDD)
|
||||
|
||||
Returns:
|
||||
包含日线市场数据的 DataFrame
|
||||
包含日线数据的 DataFrame
|
||||
"""
|
||||
# 检查是否应该停止同步(用于异常处理)
|
||||
if not self._stop_flag.is_set():
|
||||
return pd.DataFrame()
|
||||
|
||||
try:
|
||||
# 使用共享客户端进行跨线程速率限制
|
||||
data = self.client.query(
|
||||
"pro_bar",
|
||||
@@ -533,205 +117,6 @@ class DailySync:
|
||||
factors="tor,vr",
|
||||
)
|
||||
return data
|
||||
except Exception as e:
|
||||
# 设置停止标志以通知其他线程停止
|
||||
self._stop_flag.clear()
|
||||
print(f"[ERROR] Exception syncing {ts_code}: {e}")
|
||||
raise
|
||||
|
||||
def sync_all(
    self,
    force_full: bool = False,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
    max_workers: Optional[int] = None,
    dry_run: bool = False,
) -> Dict[str, pd.DataFrame]:
    """Sync daily data for every stock code found locally.

    Steps:
      1. Refresh the trade calendar cache.
      2. Ask ``check_sync_needed`` whether anything must be fetched; if
         not, return immediately without querying per-stock data.
      3. Fetch all stocks concurrently with a thread pool.
      4. Skip stocks that return empty data (possibly delisted).
      5. On the first failing stock, cancel pending work and write nothing.

    Args:
        force_full: If True, force a full reload.
        start_date: Manual start date; used only when ``check_sync_needed``
            yields no date range.
        end_date: Manual end date; defaults to today and is replaced by
            the range returned from ``check_sync_needed`` when available.
        max_workers: Worker thread count; defaults to ``self.max_workers``.
        dry_run: If True, only print what would be synced.

    Returns:
        Mapping of ts_code to fetched DataFrame. Empty when the sync is
        skipped, no stocks are found, or ``dry_run`` is set. After an
        aborted run it holds the data fetched so far, which was NOT saved.
    """
    print("\n" + "=" * 60)
    print("[DailySync] Starting daily data sync...")
    print("=" * 60)

    # Refresh the trade calendar cache first.
    print("[DailySync] Syncing trade calendar cache...")
    sync_trade_cal_cache()

    if end_date is None:
        end_date = get_today_date()

    sync_needed, cal_start, cal_end, local_last = self.check_sync_needed(force_full)

    if not sync_needed:
        # Nothing to fetch: no API tokens are consumed.
        print("\n" + "=" * 60)
        print("[DailySync] Sync Summary")
        print("=" * 60)
        print(" Sync: SKIPPED (local data up-to-date with trade calendar)")
        print(" Tokens saved: 0 consumed")
        print("=" * 60)
        return {}

    # Prefer the range computed by check_sync_needed (incremental start).
    if cal_start and cal_end:
        sync_start_date = cal_start
        end_date = cal_end
    else:
        sync_start_date = start_date or DEFAULT_START_DATE
        if end_date is None:
            end_date = get_today_date()

    # Classify the sync mode for reporting.
    if force_full:
        mode = "full"
        print(f"[DailySync] Mode: FULL SYNC from {sync_start_date} to {end_date}")
    elif local_last and cal_start and sync_start_date == get_next_date(local_last):
        mode = "incremental"
        print(f"[DailySync] Mode: INCREMENTAL SYNC (bandwidth optimized)")
        print(f"[DailySync] Sync from: {sync_start_date} to {end_date}")
    else:
        mode = "partial"
        print(f"[DailySync] Mode: SYNC from {sync_start_date} to {end_date}")

    stock_codes = self.get_all_stock_codes()
    if not stock_codes:
        print("[DailySync] No stocks found to sync")
        return {}

    print(f"[DailySync] Total stocks to sync: {len(stock_codes)}")
    print(f"[DailySync] Using {max_workers or self.max_workers} worker threads")

    if dry_run:
        print("\n" + "=" * 60)
        print("[DailySync] DRY RUN MODE - No data will be written")
        print("=" * 60)
        print(f" Would sync {len(stock_codes)} stocks")
        print(f" Date range: {sync_start_date} to {end_date}")
        print(f" Mode: {mode}")
        print("=" * 60)
        return {}

    # Re-arm the stop flag for this run (set == keep running).
    self._stop_flag.set()

    results: Dict[str, pd.DataFrame] = {}
    error_occurred = False

    def sync_task(ts_code: str) -> tuple[str, pd.DataFrame]:
        """Fetch one stock; exceptions surface through the Future."""
        fetched = self.sync_single_stock(
            ts_code=ts_code,
            start_date=sync_start_date,
            end_date=end_date,
        )
        return (ts_code, fetched)

    pool_size = max_workers or self.max_workers
    with ThreadPoolExecutor(max_workers=pool_size) as executor:
        # Remember which future belongs to which stock code.
        future_to_code = {
            executor.submit(sync_task, code): code for code in stock_codes
        }

        error_count = 0
        empty_count = 0
        success_count = 0

        pbar = tqdm(total=len(stock_codes), desc="Syncing stocks")

        try:
            for future in as_completed(future_to_code):
                ts_code = future_to_code[future]

                try:
                    _, data = future.result()
                    if data is None or data.empty:
                        # Empty result: stock may be delisted or unavailable.
                        empty_count += 1
                        print(
                            f"[DailySync] Stock {ts_code}: empty data (skipped, may be delisted)"
                        )
                    else:
                        results[ts_code] = data
                        success_count += 1
                except Exception as e:
                    # First failure: cancel pending work and leave the loop;
                    # the outer handler below records the abort.
                    error_occurred = True
                    print(f"\n[ERROR] Sync aborted due to exception: {e}")
                    executor.shutdown(wait=False, cancel_futures=True)
                    raise

                pbar.update(1)

        except Exception:
            error_count = 1
            print("[DailySync] Sync stopped due to exception")
        finally:
            pbar.close()

    # Persist everything in one batch, but only for an error-free run.
    if results and not error_occurred:
        for frame in results.values():
            if not frame.empty:
                self.storage.queue_save("daily", frame)
        self.storage.flush()
        total_rows = sum(len(frame) for frame in results.values())
        print(f"\n[DailySync] Saved {total_rows} rows to storage")

    print("\n" + "=" * 60)
    print("[DailySync] Sync Summary")
    print("=" * 60)
    print(f" Total stocks: {len(stock_codes)}")
    print(f" Updated: {success_count}")
    print(f" Skipped (empty/delisted): {empty_count}")
    print(
        f" Errors: {error_count} (aborted on first error)"
        if error_count
        else " Errors: 0"
    )
    print(f" Date range: {sync_start_date} to {end_date}")
    print("=" * 60)

    return results
|
||||
|
||||
|
||||
def sync_daily(
|
||||
|
||||
@@ -8,21 +8,9 @@ volume ratio (vr), and adjustment factors.
|
||||
|
||||
import pandas as pd
|
||||
from typing import Optional, List, Literal, Dict
|
||||
from datetime import datetime, timedelta
|
||||
from tqdm import tqdm
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
import threading
|
||||
|
||||
from src.data.client import TushareClient
|
||||
from src.data.storage import ThreadSafeStorage, Storage
|
||||
from src.data.utils import get_today_date, get_next_date, DEFAULT_START_DATE
|
||||
from src.config.settings import get_settings
|
||||
from src.data.api_wrappers.api_trade_cal import (
|
||||
get_first_trading_day,
|
||||
get_last_trading_day,
|
||||
sync_trade_cal_cache,
|
||||
)
|
||||
from src.data.api_wrappers.api_stock_basic import _get_csv_path, sync_all_stocks
|
||||
from src.data.api_wrappers.base_sync import StockBasedSync
|
||||
|
||||
|
||||
def get_pro_bar(
|
||||
@@ -138,429 +126,28 @@ def get_pro_bar(
|
||||
return data
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# ProBarSync - Pro Bar 数据批量同步类
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class ProBarSync:
|
||||
class ProBarSync(StockBasedSync):
|
||||
"""Pro Bar 数据批量同步管理器,支持全量/增量同步。
|
||||
|
||||
功能特性:
|
||||
- 多线程并发获取(ThreadPoolExecutor)
|
||||
- 增量同步(自动检测上次同步位置)
|
||||
- 内存缓存(避免重复磁盘读取)
|
||||
- 异常立即停止(确保数据一致性)
|
||||
- 预览模式(预览同步数据量,不实际写入)
|
||||
- 默认获取全部数据列(tor, vr, adj_factor)
|
||||
继承自 StockBasedSync,使用多线程按股票并发获取数据。
|
||||
默认获取全部数据列(tor, vr, adj_factor)。
|
||||
|
||||
Example:
|
||||
>>> sync = ProBarSync()
|
||||
>>> results = sync.sync_all() # 增量同步
|
||||
>>> results = sync.sync_all(force_full=True) # 全量同步
|
||||
>>> preview = sync.preview_sync() # 预览
|
||||
"""
|
||||
|
||||
# 默认工作线程数(从配置读取,默认10)
|
||||
DEFAULT_MAX_WORKERS = get_settings().threads
|
||||
table_name = "pro_bar"
|
||||
|
||||
def __init__(self, max_workers: Optional[int] = None):
    """Initialize the sync manager.

    Args:
        max_workers: Number of worker threads. Falls back to
            ``DEFAULT_MAX_WORKERS`` when omitted (or falsy, e.g. 0).
    """
    self.storage = ThreadSafeStorage()
    self.client = TushareClient()
    self.max_workers = max_workers if max_workers else self.DEFAULT_MAX_WORKERS

    # "set" means "not stopped"; the flag starts in the running state.
    running = threading.Event()
    running.set()
    self._stop_flag = running

    # In-memory cache of the "pro_bar" table (None until first load).
    self._cached_pro_bar_data: Optional[pd.DataFrame] = None
|
||||
|
||||
def _load_pro_bar_data(self) -> pd.DataFrame:
    """Return the "pro_bar" table, loading it from storage at most once.

    The loaded DataFrame is kept on the instance so repeated calls do not
    hit storage again; call ``clear_cache()`` to force a reload.

    Returns:
        The cached DataFrame, or the one freshly loaded from storage.
    """
    cached = self._cached_pro_bar_data
    if cached is not None:
        return cached

    # Cache miss: load once and remember the result on the instance.
    cached = self.storage.load("pro_bar")
    self._cached_pro_bar_data = cached
    return cached
|
||||
|
||||
def clear_cache(self) -> None:
    """Drop the cached Pro Bar DataFrame so the next load re-reads storage."""
    self._cached_pro_bar_data = None
|
||||
|
||||
def get_all_stock_codes(self, only_listed: bool = True) -> list:
|
||||
"""从本地存储获取所有股票代码。
|
||||
|
||||
优先使用 stock_basic.csv 以确保包含所有股票,
|
||||
避免回测中的前视偏差。
|
||||
|
||||
Args:
|
||||
only_listed: 若为 True,仅返回当前上市股票(L 状态)。
|
||||
设为 False 可包含退市股票(用于完整回测)。
|
||||
|
||||
Returns:
|
||||
股票代码列表
|
||||
"""
|
||||
# 确保 stock_basic.csv 是最新的
|
||||
print("[ProBarSync] Ensuring stock_basic.csv is up-to-date...")
|
||||
sync_all_stocks()
|
||||
|
||||
# 从 stock_basic.csv 文件获取
|
||||
stock_csv_path = _get_csv_path()
|
||||
|
||||
if stock_csv_path.exists():
|
||||
print(f"[ProBarSync] Reading stock_basic from CSV: {stock_csv_path}")
|
||||
try:
|
||||
stock_df = pd.read_csv(stock_csv_path, encoding="utf-8-sig")
|
||||
if not stock_df.empty and "ts_code" in stock_df.columns:
|
||||
# 根据 list_status 过滤
|
||||
if only_listed and "list_status" in stock_df.columns:
|
||||
listed_stocks = stock_df[stock_df["list_status"] == "L"]
|
||||
codes = listed_stocks["ts_code"].unique().tolist()
|
||||
total = len(stock_df["ts_code"].unique())
|
||||
print(
|
||||
f"[ProBarSync] Found {len(codes)} listed stocks (filtered from {total} total)"
|
||||
)
|
||||
else:
|
||||
codes = stock_df["ts_code"].unique().tolist()
|
||||
print(
|
||||
f"[ProBarSync] Found {len(codes)} stock codes from stock_basic.csv"
|
||||
)
|
||||
return codes
|
||||
else:
|
||||
print(
|
||||
f"[ProBarSync] stock_basic.csv exists but no ts_code column or empty"
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"[ProBarSync] Error reading stock_basic.csv: {e}")
|
||||
|
||||
# 回退:从 Pro Bar 存储获取
|
||||
print(
|
||||
"[ProBarSync] stock_basic.csv not available, falling back to pro_bar data..."
|
||||
)
|
||||
pro_bar_data = self._load_pro_bar_data()
|
||||
if not pro_bar_data.empty and "ts_code" in pro_bar_data.columns:
|
||||
codes = pro_bar_data["ts_code"].unique().tolist()
|
||||
print(f"[ProBarSync] Found {len(codes)} stock codes from pro_bar data")
|
||||
return codes
|
||||
|
||||
print("[ProBarSync] No stock codes found in local storage")
|
||||
return []
|
||||
|
||||
def get_global_last_date(self) -> Optional[str]:
|
||||
"""获取全局最后交易日期。
|
||||
|
||||
Returns:
|
||||
最后交易日期字符串,若无数据则返回 None
|
||||
"""
|
||||
pro_bar_data = self._load_pro_bar_data()
|
||||
if pro_bar_data.empty or "trade_date" not in pro_bar_data.columns:
|
||||
return None
|
||||
return str(pro_bar_data["trade_date"].max())
|
||||
|
||||
def get_global_first_date(self) -> Optional[str]:
|
||||
"""获取全局最早交易日期。
|
||||
|
||||
Returns:
|
||||
最早交易日期字符串,若无数据则返回 None
|
||||
"""
|
||||
pro_bar_data = self._load_pro_bar_data()
|
||||
if pro_bar_data.empty or "trade_date" not in pro_bar_data.columns:
|
||||
return None
|
||||
return str(pro_bar_data["trade_date"].min())
|
||||
|
||||
def get_trade_calendar_bounds(
|
||||
self, start_date: str, end_date: str
|
||||
) -> tuple[Optional[str], Optional[str]]:
|
||||
"""从交易日历获取首尾交易日。
|
||||
|
||||
Args:
|
||||
start_date: 开始日期(YYYYMMDD 格式)
|
||||
end_date: 结束日期(YYYYMMDD 格式)
|
||||
|
||||
Returns:
|
||||
(首交易日, 尾交易日) 元组,若出错则返回 (None, None)
|
||||
"""
|
||||
try:
|
||||
first_day = get_first_trading_day(start_date, end_date)
|
||||
last_day = get_last_trading_day(start_date, end_date)
|
||||
return (first_day, last_day)
|
||||
except Exception as e:
|
||||
print(f"[ERROR] Failed to get trade calendar bounds: {e}")
|
||||
return (None, None)
|
||||
|
||||
def check_sync_needed(
|
||||
self,
|
||||
force_full: bool = False,
|
||||
table_name: str = "pro_bar",
|
||||
) -> tuple[bool, Optional[str], Optional[str], Optional[str]]:
|
||||
"""基于交易日历检查是否需要同步。
|
||||
|
||||
该方法比较本地数据日期范围与交易日历,
|
||||
以确定是否需要获取新数据。
|
||||
|
||||
逻辑:
|
||||
- 若 force_full:需要同步,返回 (True, 20180101, today)
|
||||
- 若无本地数据:需要同步,返回 (True, 20180101, today)
|
||||
- 若存在本地数据:
|
||||
- 从交易日历获取最后交易日
|
||||
- 若本地最后日期 >= 日历最后日期:无需同步
|
||||
- 否则:从本地最后日期+1 到最新交易日同步
|
||||
|
||||
Args:
|
||||
force_full: 若为 True,始终返回需要同步
|
||||
table_name: 要检查的表名(默认: "pro_bar")
|
||||
|
||||
Returns:
|
||||
(需要同步, 起始日期, 结束日期, 本地最后日期)
|
||||
- 需要同步: True 表示应继续同步
|
||||
- 起始日期: 同步起始日期(无需同步时为 None)
|
||||
- 结束日期: 同步结束日期(无需同步时为 None)
|
||||
- 本地最后日期: 本地数据最后日期(用于增量同步)
|
||||
"""
|
||||
# 若 force_full,始终同步
|
||||
if force_full:
|
||||
print("[ProBarSync] Force full sync requested")
|
||||
return (True, DEFAULT_START_DATE, get_today_date(), None)
|
||||
|
||||
# 检查特定表的本地数据是否存在
|
||||
storage = Storage()
|
||||
table_data = (
|
||||
storage.load(table_name) if storage.exists(table_name) else pd.DataFrame()
|
||||
)
|
||||
|
||||
if table_data.empty or "trade_date" not in table_data.columns:
|
||||
print(
|
||||
f"[ProBarSync] No local data found for table '{table_name}', full sync needed"
|
||||
)
|
||||
return (True, DEFAULT_START_DATE, get_today_date(), None)
|
||||
|
||||
# 获取本地数据最后日期
|
||||
local_last_date = str(table_data["trade_date"].max())
|
||||
|
||||
print(f"[ProBarSync] Local data last date: {local_last_date}")
|
||||
|
||||
# 从交易日历获取最新交易日
|
||||
today = get_today_date()
|
||||
_, cal_last = self.get_trade_calendar_bounds(DEFAULT_START_DATE, today)
|
||||
|
||||
if cal_last is None:
|
||||
print("[ProBarSync] Failed to get trade calendar, proceeding with sync")
|
||||
return (True, DEFAULT_START_DATE, today, local_last_date)
|
||||
|
||||
print(f"[ProBarSync] Calendar last trading day: {cal_last}")
|
||||
|
||||
# 比较本地最后日期与日历最后日期
|
||||
print(
|
||||
f"[ProBarSync] Comparing: local={local_last_date} (type={type(local_last_date).__name__}), "
|
||||
f"cal={cal_last} (type={type(cal_last).__name__})"
|
||||
)
|
||||
try:
|
||||
local_last_int = int(local_last_date)
|
||||
cal_last_int = int(cal_last)
|
||||
print(
|
||||
f"[ProBarSync] Comparing integers: local={local_last_int} >= cal={cal_last_int} = "
|
||||
f"{local_last_int >= cal_last_int}"
|
||||
)
|
||||
if local_last_int >= cal_last_int:
|
||||
print(
|
||||
"[ProBarSync] Local data is up-to-date, SKIPPING sync (no tokens consumed)"
|
||||
)
|
||||
return (False, None, None, None)
|
||||
except (ValueError, TypeError) as e:
|
||||
print(f"[ERROR] Date comparison failed: {e}")
|
||||
|
||||
# 需要从本地最后日期+1 同步到最新交易日
|
||||
sync_start = get_next_date(local_last_date)
|
||||
print(f"[ProBarSync] Incremental sync needed from {sync_start} to {cal_last}")
|
||||
return (True, sync_start, cal_last, local_last_date)
|
||||
|
||||
def preview_sync(
|
||||
self,
|
||||
force_full: bool = False,
|
||||
start_date: Optional[str] = None,
|
||||
end_date: Optional[str] = None,
|
||||
sample_size: int = 3,
|
||||
) -> dict:
|
||||
"""预览同步数据量和样本(不实际同步)。
|
||||
|
||||
该方法提供即将同步的数据的预览,包括:
|
||||
- 将同步的股票数量
|
||||
- 同步日期范围
|
||||
- 预估总记录数
|
||||
- 前几只股票的样本数据
|
||||
|
||||
Args:
|
||||
force_full: 若为 True,预览全量同步(从 20180101)
|
||||
start_date: 手动指定起始日期(覆盖自动检测)
|
||||
end_date: 手动指定结束日期(默认为今天)
|
||||
sample_size: 预览用样本股票数量(默认: 3)
|
||||
|
||||
Returns:
|
||||
包含预览信息的字典:
|
||||
{
|
||||
'sync_needed': bool,
|
||||
'stock_count': int,
|
||||
'start_date': str,
|
||||
'end_date': str,
|
||||
'estimated_records': int,
|
||||
'sample_data': pd.DataFrame,
|
||||
'mode': str, # 'full' 或 'incremental'
|
||||
}
|
||||
"""
|
||||
print("\n" + "=" * 60)
|
||||
print("[ProBarSync] Preview Mode - Analyzing sync requirements...")
|
||||
print("=" * 60)
|
||||
|
||||
# 首先确保交易日历缓存是最新的
|
||||
print("[ProBarSync] Syncing trade calendar cache...")
|
||||
sync_trade_cal_cache()
|
||||
|
||||
# 确定日期范围
|
||||
if end_date is None:
|
||||
end_date = get_today_date()
|
||||
|
||||
# 检查是否需要同步
|
||||
sync_needed, cal_start, cal_end, local_last = self.check_sync_needed(force_full)
|
||||
|
||||
if not sync_needed:
|
||||
print("\n" + "=" * 60)
|
||||
print("[ProBarSync] Preview Result")
|
||||
print("=" * 60)
|
||||
print(" Sync Status: NOT NEEDED")
|
||||
print(" Reason: Local data is up-to-date with trade calendar")
|
||||
print("=" * 60)
|
||||
return {
|
||||
"sync_needed": False,
|
||||
"stock_count": 0,
|
||||
"start_date": None,
|
||||
"end_date": None,
|
||||
"estimated_records": 0,
|
||||
"sample_data": pd.DataFrame(),
|
||||
"mode": "none",
|
||||
}
|
||||
|
||||
# 使用 check_sync_needed 返回的日期
|
||||
if cal_start and cal_end:
|
||||
sync_start_date = cal_start
|
||||
end_date = cal_end
|
||||
else:
|
||||
sync_start_date = start_date or DEFAULT_START_DATE
|
||||
if end_date is None:
|
||||
end_date = get_today_date()
|
||||
|
||||
# 确定同步模式
|
||||
if force_full:
|
||||
mode = "full"
|
||||
print(f"[ProBarSync] Mode: FULL SYNC from {sync_start_date} to {end_date}")
|
||||
elif local_last and cal_start and sync_start_date == get_next_date(local_last):
|
||||
mode = "incremental"
|
||||
print(f"[ProBarSync] Mode: INCREMENTAL SYNC (bandwidth optimized)")
|
||||
print(f"[ProBarSync] Sync from: {sync_start_date} to {end_date}")
|
||||
else:
|
||||
mode = "partial"
|
||||
print(f"[ProBarSync] Mode: SYNC from {sync_start_date} to {end_date}")
|
||||
|
||||
# 获取所有股票代码
|
||||
stock_codes = self.get_all_stock_codes()
|
||||
if not stock_codes:
|
||||
print("[ProBarSync] No stocks found to sync")
|
||||
return {
|
||||
"sync_needed": False,
|
||||
"stock_count": 0,
|
||||
"start_date": None,
|
||||
"end_date": None,
|
||||
"estimated_records": 0,
|
||||
"sample_data": pd.DataFrame(),
|
||||
"mode": "none",
|
||||
}
|
||||
|
||||
stock_count = len(stock_codes)
|
||||
print(f"[ProBarSync] Total stocks to sync: {stock_count}")
|
||||
|
||||
# 从前几只股票获取样本数据
|
||||
print(f"[ProBarSync] Fetching sample data from {sample_size} stocks...")
|
||||
sample_data_list = []
|
||||
sample_codes = stock_codes[:sample_size]
|
||||
|
||||
for ts_code in sample_codes:
|
||||
try:
|
||||
# 使用 get_pro_bar 获取样本数据(包含所有字段)
|
||||
data = get_pro_bar(
|
||||
ts_code=ts_code,
|
||||
start_date=sync_start_date,
|
||||
end_date=end_date,
|
||||
)
|
||||
if not data.empty:
|
||||
sample_data_list.append(data)
|
||||
print(f" - {ts_code}: {len(data)} records")
|
||||
except Exception as e:
|
||||
print(f" - {ts_code}: Error fetching - {e}")
|
||||
|
||||
# 合并样本数据
|
||||
sample_df = (
|
||||
pd.concat(sample_data_list, ignore_index=True)
|
||||
if sample_data_list
|
||||
else pd.DataFrame()
|
||||
)
|
||||
|
||||
# 基于样本估算总记录数
|
||||
if not sample_df.empty:
|
||||
avg_records_per_stock = len(sample_df) / len(sample_data_list)
|
||||
estimated_records = int(avg_records_per_stock * stock_count)
|
||||
else:
|
||||
estimated_records = 0
|
||||
|
||||
# 显示预览结果
|
||||
print("\n" + "=" * 60)
|
||||
print("[ProBarSync] Preview Result")
|
||||
print("=" * 60)
|
||||
print(f" Sync Mode: {mode.upper()}")
|
||||
print(f" Date Range: {sync_start_date} to {end_date}")
|
||||
print(f" Stocks to Sync: {stock_count}")
|
||||
print(f" Sample Stocks Checked: {len(sample_data_list)}/{sample_size}")
|
||||
print(f" Estimated Total Records: ~{estimated_records:,}")
|
||||
|
||||
if not sample_df.empty:
|
||||
print(f"\n Sample Data Preview (first {len(sample_df)} rows):")
|
||||
print(" " + "-" * 56)
|
||||
# 以紧凑格式显示样本数据
|
||||
preview_cols = [
|
||||
"ts_code",
|
||||
"trade_date",
|
||||
"open",
|
||||
"high",
|
||||
"low",
|
||||
"close",
|
||||
"vol",
|
||||
"tor",
|
||||
"vr",
|
||||
]
|
||||
available_cols = [c for c in preview_cols if c in sample_df.columns]
|
||||
sample_display = sample_df[available_cols].head(10)
|
||||
for idx, row in sample_display.iterrows():
|
||||
print(f" {row.to_dict()}")
|
||||
print(" " + "-" * 56)
|
||||
|
||||
print("=" * 60)
|
||||
|
||||
return {
|
||||
"sync_needed": True,
|
||||
"stock_count": stock_count,
|
||||
"start_date": sync_start_date,
|
||||
"end_date": end_date,
|
||||
"estimated_records": estimated_records,
|
||||
"sample_data": sample_df,
|
||||
"mode": mode,
|
||||
}
|
||||
|
||||
def sync_single_stock(
|
||||
def fetch_single_stock(
|
||||
self,
|
||||
ts_code: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
) -> pd.DataFrame:
|
||||
"""同步单只股票的 Pro Bar 数据。
|
||||
"""获取单只股票的 Pro Bar 数据。
|
||||
|
||||
Args:
|
||||
ts_code: 股票代码
|
||||
@@ -570,11 +157,6 @@ class ProBarSync:
|
||||
Returns:
|
||||
包含 Pro Bar 数据的 DataFrame
|
||||
"""
|
||||
# 检查是否应该停止同步(用于异常处理)
|
||||
if not self._stop_flag.is_set():
|
||||
return pd.DataFrame()
|
||||
|
||||
try:
|
||||
# 使用 get_pro_bar 获取数据(默认包含所有字段,传递共享 client)
|
||||
data = get_pro_bar(
|
||||
ts_code=ts_code,
|
||||
@@ -583,205 +165,6 @@ class ProBarSync:
|
||||
client=self.client, # 传递共享客户端以确保限流
|
||||
)
|
||||
return data
|
||||
except Exception as e:
|
||||
# 设置停止标志以通知其他线程停止
|
||||
self._stop_flag.clear()
|
||||
print(f"[ERROR] Exception syncing {ts_code}: {e}")
|
||||
raise
|
||||
|
||||
def sync_all(
|
||||
self,
|
||||
force_full: bool = False,
|
||||
start_date: Optional[str] = None,
|
||||
end_date: Optional[str] = None,
|
||||
max_workers: Optional[int] = None,
|
||||
dry_run: bool = False,
|
||||
) -> Dict[str, pd.DataFrame]:
|
||||
"""同步本地存储中所有股票的 Pro Bar 数据。
|
||||
|
||||
该函数:
|
||||
1. 从本地存储读取股票代码(pro_bar 或 stock_basic)
|
||||
2. 检查交易日历确定是否需要同步:
|
||||
- 若本地数据匹配交易日历边界,则跳过同步(节省 token)
|
||||
- 否则,从本地最后日期+1 同步到最新交易日(带宽优化)
|
||||
3. 使用多线程并发获取(带速率限制)
|
||||
4. 跳过返回空数据的股票(退市/不可用)
|
||||
5. 遇异常立即停止
|
||||
|
||||
Args:
|
||||
force_full: 若为 True,强制从 20180101 完整重载
|
||||
start_date: 手动指定起始日期(覆盖自动检测)
|
||||
end_date: 手动指定结束日期(默认为今天)
|
||||
max_workers: 工作线程数(默认: 10)
|
||||
dry_run: 若为 True,仅预览将要同步的内容,不写入数据
|
||||
|
||||
Returns:
|
||||
映射 ts_code 到 DataFrame 的字典(若跳过或 dry_run 则为空字典)
|
||||
"""
|
||||
print("\n" + "=" * 60)
|
||||
print("[ProBarSync] Starting pro_bar data sync...")
|
||||
print("=" * 60)
|
||||
|
||||
# 首先确保交易日历缓存是最新的(使用增量同步)
|
||||
print("[ProBarSync] Syncing trade calendar cache...")
|
||||
sync_trade_cal_cache()
|
||||
|
||||
# 确定日期范围
|
||||
if end_date is None:
|
||||
end_date = get_today_date()
|
||||
|
||||
# 基于交易日历检查是否需要同步
|
||||
sync_needed, cal_start, cal_end, local_last = self.check_sync_needed(force_full)
|
||||
|
||||
if not sync_needed:
|
||||
# 跳过同步 - 不消耗 token
|
||||
print("\n" + "=" * 60)
|
||||
print("[ProBarSync] Sync Summary")
|
||||
print("=" * 60)
|
||||
print(" Sync: SKIPPED (local data up-to-date with trade calendar)")
|
||||
print(" Tokens saved: 0 consumed")
|
||||
print("=" * 60)
|
||||
return {}
|
||||
|
||||
# 使用 check_sync_needed 返回的日期(会计算增量起始日期)
|
||||
if cal_start and cal_end:
|
||||
sync_start_date = cal_start
|
||||
end_date = cal_end
|
||||
else:
|
||||
# 回退到默认逻辑
|
||||
sync_start_date = start_date or DEFAULT_START_DATE
|
||||
if end_date is None:
|
||||
end_date = get_today_date()
|
||||
|
||||
# 确定同步模式
|
||||
if force_full:
|
||||
mode = "full"
|
||||
print(f"[ProBarSync] Mode: FULL SYNC from {sync_start_date} to {end_date}")
|
||||
elif local_last and cal_start and sync_start_date == get_next_date(local_last):
|
||||
mode = "incremental"
|
||||
print(f"[ProBarSync] Mode: INCREMENTAL SYNC (bandwidth optimized)")
|
||||
print(f"[ProBarSync] Sync from: {sync_start_date} to {end_date}")
|
||||
else:
|
||||
mode = "partial"
|
||||
print(f"[ProBarSync] Mode: SYNC from {sync_start_date} to {end_date}")
|
||||
|
||||
# 获取所有股票代码
|
||||
stock_codes = self.get_all_stock_codes()
|
||||
if not stock_codes:
|
||||
print("[ProBarSync] No stocks found to sync")
|
||||
return {}
|
||||
|
||||
print(f"[ProBarSync] Total stocks to sync: {len(stock_codes)}")
|
||||
print(f"[ProBarSync] Using {max_workers or self.max_workers} worker threads")
|
||||
|
||||
# 处理 dry run 模式
|
||||
if dry_run:
|
||||
print("\n" + "=" * 60)
|
||||
print("[ProBarSync] DRY RUN MODE - No data will be written")
|
||||
print("=" * 60)
|
||||
print(f" Would sync {len(stock_codes)} stocks")
|
||||
print(f" Date range: {sync_start_date} to {end_date}")
|
||||
print(f" Mode: {mode}")
|
||||
print("=" * 60)
|
||||
return {}
|
||||
|
||||
# 为新同步重置停止标志
|
||||
self._stop_flag.set()
|
||||
|
||||
# 多线程并发获取
|
||||
results: Dict[str, pd.DataFrame] = {}
|
||||
error_occurred = False
|
||||
exception_to_raise = None
|
||||
|
||||
def sync_task(ts_code: str) -> tuple[str, pd.DataFrame]:
|
||||
"""每只股票的任务函数。"""
|
||||
try:
|
||||
data = self.sync_single_stock(
|
||||
ts_code=ts_code,
|
||||
start_date=sync_start_date,
|
||||
end_date=end_date,
|
||||
)
|
||||
return (ts_code, data)
|
||||
except Exception as e:
|
||||
# 重新抛出以被 Future 捕获
|
||||
raise
|
||||
|
||||
# 使用 ThreadPoolExecutor 进行并发获取
|
||||
workers = max_workers or self.max_workers
|
||||
with ThreadPoolExecutor(max_workers=workers) as executor:
|
||||
# 提交所有任务并跟踪 futures 与股票代码的映射
|
||||
future_to_code = {
|
||||
executor.submit(sync_task, ts_code): ts_code for ts_code in stock_codes
|
||||
}
|
||||
|
||||
# 使用 as_completed 处理结果
|
||||
error_count = 0
|
||||
empty_count = 0
|
||||
success_count = 0
|
||||
|
||||
# 创建进度条
|
||||
pbar = tqdm(total=len(stock_codes), desc="Syncing pro_bar stocks")
|
||||
|
||||
try:
|
||||
# 处理完成的 futures
|
||||
for future in as_completed(future_to_code):
|
||||
ts_code = future_to_code[future]
|
||||
|
||||
try:
|
||||
_, data = future.result()
|
||||
if data is not None and not data.empty:
|
||||
results[ts_code] = data
|
||||
success_count += 1
|
||||
else:
|
||||
# 空数据 - 股票可能已退市或不可用
|
||||
empty_count += 1
|
||||
print(
|
||||
f"[ProBarSync] Stock {ts_code}: empty data (skipped, may be delisted)"
|
||||
)
|
||||
except Exception as e:
|
||||
# 发生异常 - 停止全部并中止
|
||||
error_occurred = True
|
||||
exception_to_raise = e
|
||||
print(f"\n[ERROR] Sync aborted due to exception: {e}")
|
||||
# 关闭 executor 以停止所有待处理任务
|
||||
executor.shutdown(wait=False, cancel_futures=True)
|
||||
raise exception_to_raise
|
||||
|
||||
# 更新进度条
|
||||
pbar.update(1)
|
||||
|
||||
except Exception:
|
||||
error_count = 1
|
||||
print("[ProBarSync] Sync stopped due to exception")
|
||||
finally:
|
||||
pbar.close()
|
||||
|
||||
# 批量写入所有数据(仅在无错误时)
|
||||
if results and not error_occurred:
|
||||
for ts_code, data in results.items():
|
||||
if not data.empty:
|
||||
self.storage.queue_save("pro_bar", data)
|
||||
# 一次性刷新所有排队写入
|
||||
self.storage.flush()
|
||||
total_rows = sum(len(df) for df in results.values())
|
||||
print(f"\n[ProBarSync] Saved {total_rows} rows to storage")
|
||||
|
||||
# 摘要
|
||||
print("\n" + "=" * 60)
|
||||
print("[ProBarSync] Sync Summary")
|
||||
print("=" * 60)
|
||||
print(f" Total stocks: {len(stock_codes)}")
|
||||
print(f" Updated: {success_count}")
|
||||
print(f" Skipped (empty/delisted): {empty_count}")
|
||||
print(
|
||||
f" Errors: {error_count} (aborted on first error)"
|
||||
if error_count
|
||||
else " Errors: 0"
|
||||
)
|
||||
print(f" Date range: {sync_start_date} to {end_date}")
|
||||
print("=" * 60)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def sync_pro_bar(
|
||||
|
||||
1137
src/data/api_wrappers/base_sync.py
Normal file
1137
src/data/api_wrappers/base_sync.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -189,7 +189,10 @@ class Storage:
|
||||
end_net_profit DOUBLE,
|
||||
update_flag VARCHAR(1),
|
||||
PRIMARY KEY (ts_code, end_date)
|
||||
update_flag VARCHAR(1),
|
||||
PRIMARY KEY (ts_code, end_date)
|
||||
)
|
||||
""")
|
||||
|
||||
# Create pro_bar table for pro bar data (with adj, tor, vr)
|
||||
self._connection.execute("""
|
||||
|
||||
123
src/data/sync.py
123
src/data/sync.py
@@ -3,7 +3,8 @@
|
||||
该模块作为数据同步的调度中心,统一管理各类型数据的同步流程。
|
||||
具体的同步逻辑已迁移到对应的 api_xxx.py 文件中:
|
||||
- api_daily.py: 日线数据同步 (DailySync 类)
|
||||
- api_bak_basic.py: 历史股票列表同步
|
||||
- api_bak_basic.py: 历史股票列表同步 (BakBasicSync 类)
|
||||
- api_pro_bar.py: Pro Bar 数据同步 (ProBarSync 类)
|
||||
- api_stock_basic.py: 股票基本信息同步
|
||||
- api_trade_cal.py: 交易日历同步
|
||||
|
||||
@@ -30,6 +31,7 @@ import pandas as pd
|
||||
from src.data.api_wrappers import sync_all_stocks
|
||||
from src.data.api_wrappers.api_daily import sync_daily, preview_daily_sync
|
||||
from src.data.api_wrappers.api_pro_bar import sync_pro_bar
|
||||
from src.data.api_wrappers.api_bak_basic import sync_bak_basic
|
||||
|
||||
|
||||
def preview_sync(
|
||||
@@ -135,10 +137,11 @@ def sync_all_data(
|
||||
dry_run: bool = False,
|
||||
) -> Dict[str, pd.DataFrame]:
|
||||
"""同步所有数据类型(每日同步)。
|
||||
|
||||
该函数按顺序同步所有可用的数据类型:
|
||||
1. 交易日历 (sync_trade_cal_cache)
|
||||
2. 股票基本信息 (sync_all_stocks)
|
||||
3. 日线市场数据 (sync_all)
|
||||
3. Pro Bar 数据 (sync_pro_bar)
|
||||
4. 历史股票列表 (sync_bak_basic)
|
||||
|
||||
注意:名称变更 (namechange) 不在自动同步中,如需同步请手动调用。
|
||||
@@ -167,47 +170,29 @@ def sync_all_data(
|
||||
print("=" * 60)
|
||||
|
||||
# 1. Sync trade calendar (always needed first)
|
||||
print("\n[1/6] Syncing trade calendar cache...")
|
||||
print("\n[1/4] Syncing trade calendar cache...")
|
||||
try:
|
||||
from src.data.api_wrappers import sync_trade_cal_cache
|
||||
|
||||
sync_trade_cal_cache()
|
||||
results["trade_cal"] = pd.DataFrame()
|
||||
print("[1/6] Trade calendar: OK")
|
||||
print("[1/4] Trade calendar: OK")
|
||||
except Exception as e:
|
||||
print(f"[1/6] Trade calendar: FAILED - {e}")
|
||||
print(f"[1/4] Trade calendar: FAILED - {e}")
|
||||
results["trade_cal"] = pd.DataFrame()
|
||||
|
||||
# 2. Sync stock basic info
|
||||
print("\n[2/6] Syncing stock basic info...")
|
||||
print("\n[2/4] Syncing stock basic info...")
|
||||
try:
|
||||
sync_all_stocks()
|
||||
results["stock_basic"] = pd.DataFrame()
|
||||
print("[2/6] Stock basic: OK")
|
||||
print("[2/4] Stock basic: OK")
|
||||
except Exception as e:
|
||||
print(f"[2/6] Stock basic: FAILED - {e}")
|
||||
print(f"[2/4] Stock basic: FAILED - {e}")
|
||||
results["stock_basic"] = pd.DataFrame()
|
||||
|
||||
# # 3. Sync daily market data
|
||||
# print("\n[3/6] Syncing daily market data...")
|
||||
# try:
|
||||
# daily_result = sync_daily(
|
||||
# force_full=force_full,
|
||||
# max_workers=max_workers,
|
||||
# dry_run=dry_run,
|
||||
# )
|
||||
# results["daily"] = (
|
||||
# pd.concat(daily_result.values(), ignore_index=True)
|
||||
# if daily_result
|
||||
# else pd.DataFrame()
|
||||
# )
|
||||
# print("[3/6] Daily data: OK")
|
||||
# except Exception as e:
|
||||
# print(f"[3/6] Daily data: FAILED - {e}")
|
||||
# results["daily"] = pd.DataFrame()
|
||||
|
||||
# 4. Sync Pro Bar data
|
||||
print("\n[4/6] Syncing Pro Bar data (with adj, tor, vr)...")
|
||||
# 3. Sync Pro Bar data
|
||||
print("\n[3/4] Syncing Pro Bar data (with adj, tor, vr)...")
|
||||
try:
|
||||
pro_bar_result = sync_pro_bar(
|
||||
force_full=force_full,
|
||||
@@ -219,87 +204,19 @@ def sync_all_data(
|
||||
if pro_bar_result
|
||||
else pd.DataFrame()
|
||||
)
|
||||
print(f"[4/6] Pro Bar data: OK ({len(results['pro_bar'])} records)")
|
||||
print(f"[3/4] Pro Bar data: OK ({len(results['pro_bar'])} records)")
|
||||
except Exception as e:
|
||||
print(f"[4/6] Pro Bar data: FAILED - {e}")
|
||||
print(f"[3/4] Pro Bar data: FAILED - {e}")
|
||||
results["pro_bar"] = pd.DataFrame()
|
||||
|
||||
# 5. Sync stock historical list (bak_basic)
|
||||
print("\n[5/6] Syncing stock historical list (bak_basic)...")
|
||||
try:
|
||||
bak_basic_result = sync_bak_basic(force_full=force_full)
|
||||
results["bak_basic"] = bak_basic_result
|
||||
print(f"[5/6] Bak basic: OK ({len(bak_basic_result)} records)")
|
||||
except Exception as e:
|
||||
print(f"[5/6] Bak basic: FAILED - {e}")
|
||||
results["bak_basic"] = pd.DataFrame()
|
||||
|
||||
# Summary
|
||||
print("\n" + "=" * 60)
|
||||
print("[sync_all_data] Sync Summary")
|
||||
print("=" * 60)
|
||||
for data_type, df in results.items():
|
||||
print(f" {data_type}: {len(df)} records")
|
||||
print("=" * 60)
|
||||
print("\nNote: namechange is NOT in auto-sync. To sync manually:")
|
||||
print(" from src.data.api_wrappers import sync_namechange")
|
||||
print(" sync_namechange(force=True)")
|
||||
|
||||
return results
|
||||
results: Dict[str, pd.DataFrame] = {}
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("[sync_all_data] Starting full data synchronization...")
|
||||
print("=" * 60)
|
||||
|
||||
# 1. Sync trade calendar (always needed first)
|
||||
print("\n[1/5] Syncing trade calendar cache...")
|
||||
try:
|
||||
from src.data.api_wrappers import sync_trade_cal_cache
|
||||
|
||||
sync_trade_cal_cache()
|
||||
results["trade_cal"] = pd.DataFrame()
|
||||
print("[1/5] Trade calendar: OK")
|
||||
except Exception as e:
|
||||
print(f"[1/5] Trade calendar: FAILED - {e}")
|
||||
results["trade_cal"] = pd.DataFrame()
|
||||
|
||||
# 2. Sync stock basic info
|
||||
print("\n[2/5] Syncing stock basic info...")
|
||||
try:
|
||||
sync_all_stocks()
|
||||
results["stock_basic"] = pd.DataFrame()
|
||||
print("[2/5] Stock basic: OK")
|
||||
except Exception as e:
|
||||
print(f"[2/5] Stock basic: FAILED - {e}")
|
||||
results["stock_basic"] = pd.DataFrame()
|
||||
|
||||
# 3. Sync daily market data
|
||||
print("\n[3/5] Syncing daily market data...")
|
||||
try:
|
||||
daily_result = sync_daily(
|
||||
force_full=force_full,
|
||||
max_workers=max_workers,
|
||||
dry_run=dry_run,
|
||||
)
|
||||
results["daily"] = (
|
||||
pd.concat(daily_result.values(), ignore_index=True)
|
||||
if daily_result
|
||||
else pd.DataFrame()
|
||||
)
|
||||
print("[3/5] Daily data: OK")
|
||||
except Exception as e:
|
||||
print(f"[3/5] Daily data: FAILED - {e}")
|
||||
results["daily"] = pd.DataFrame()
|
||||
|
||||
# 4. Sync stock historical list (bak_basic)
|
||||
print("\n[4/5] Syncing stock historical list (bak_basic)...")
|
||||
print("\n[4/4] Syncing stock historical list (bak_basic)...")
|
||||
try:
|
||||
bak_basic_result = sync_bak_basic(force_full=force_full)
|
||||
results["bak_basic"] = bak_basic_result
|
||||
print(f"[4/5] Bak basic: OK ({len(bak_basic_result)} records)")
|
||||
print(f"[4/4] Bak basic: OK ({len(bak_basic_result)} records)")
|
||||
except Exception as e:
|
||||
print(f"[4/5] Bak basic: FAILED - {e}")
|
||||
print(f"[4/4] Bak basic: FAILED - {e}")
|
||||
results["bak_basic"] = pd.DataFrame()
|
||||
|
||||
# Summary
|
||||
@@ -316,10 +233,6 @@ def sync_all_data(
|
||||
return results
|
||||
|
||||
|
||||
# 保留向后兼容的导入
|
||||
from src.data.api_wrappers import sync_bak_basic
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("=" * 60)
|
||||
print("Data Sync Module")
|
||||
|
||||
Reference in New Issue
Block a user