refactor: 提取数据同步逻辑为抽象基类

新增 base_sync.py 模块,提供三层抽象结构统一数据同步流程:
- BaseDataSync: 所有同步类型的基础抽象(客户端、股票代码获取、交易日历)
- StockBasedSync: 按股票同步抽象类(适用于 daily, pro_bar)
- DateBasedSync: 按日期同步抽象类(适用于 bak_basic)
This commit is contained in:
2026-02-27 23:34:12 +08:00
parent 0698b9d919
commit 484bcd0ab7
6 changed files with 1255 additions and 1548 deletions

View File

@@ -5,12 +5,10 @@ Data available from 2016 onwards.
"""
import pandas as pd
from typing import Optional, List
from datetime import datetime, timedelta
from tqdm import tqdm
from typing import Optional
from src.data.client import TushareClient
from src.data.storage import ThreadSafeStorage, Storage
from src.data.db_manager import ensure_table
from src.data.api_wrappers.base_sync import DateBasedSync
def get_bak_basic(
@@ -75,6 +73,34 @@ def get_bak_basic(
return data
class BakBasicSync(DateBasedSync):
    """Batch sync manager for the historical stock list (full/incremental).

    Inherits from DateBasedSync, which fetches data one trade date at a
    time. Data is available from 2016 onwards.

    Example:
        >>> sync = BakBasicSync()
        >>> results = sync.sync_all()                 # incremental sync
        >>> results = sync.sync_all(force_full=True)  # full sync
        >>> preview = sync.preview_sync()             # preview only
    """

    table_name = "bak_basic"
    default_start_date = "20160101"

    def fetch_single_date(self, trade_date: str) -> pd.DataFrame:
        """Fetch the historical stock list for a single trade date.

        Args:
            trade_date: Trade date in YYYYMMDD format.

        Returns:
            DataFrame with every stock row for that date.
        """
        return get_bak_basic(trade_date=trade_date)
def sync_bak_basic(
start_date: Optional[str] = None,
end_date: Optional[str] = None,
@@ -94,152 +120,12 @@ def sync_bak_basic(
Returns:
pd.DataFrame with synced data
"""
from src.data.db_manager import ensure_table
TABLE_NAME = "bak_basic"
storage = Storage()
thread_storage = ThreadSafeStorage()
# Default end date
if end_date is None:
end_date = datetime.now().strftime("%Y%m%d")
# Check if table exists
table_exists = storage.exists(TABLE_NAME)
if not table_exists or force_full:
# ===== FULL SYNC =====
# 1. Create table with schema
# 2. Create composite index (trade_date, ts_code)
# 3. Full sync from start_date
if not table_exists:
print(f"[sync_bak_basic] Table '{TABLE_NAME}' doesn't exist, creating...")
# Fetch sample to get schema
sample = get_bak_basic(trade_date=end_date)
if sample.empty:
sample = get_bak_basic(trade_date="20240102")
if sample.empty:
print("[sync_bak_basic] Cannot create table: no sample data available")
return pd.DataFrame()
# Create table with schema
columns = []
for col in sample.columns:
dtype = str(sample[col].dtype)
if col == "trade_date":
col_type = "DATE"
elif "int" in dtype:
col_type = "INTEGER"
elif "float" in dtype:
col_type = "DOUBLE"
else:
col_type = "VARCHAR"
columns.append(f'"{col}" {col_type}')
columns_sql = ", ".join(columns)
create_sql = f'CREATE TABLE IF NOT EXISTS "{TABLE_NAME}" ({columns_sql}, PRIMARY KEY ("trade_date", "ts_code"))'
try:
storage._connection.execute(create_sql)
print(f"[sync_bak_basic] Created table '{TABLE_NAME}'")
except Exception as e:
print(f"[sync_bak_basic] Error creating table: {e}")
# Create composite index
try:
storage._connection.execute(f"""
CREATE INDEX IF NOT EXISTS "idx_bak_basic_date_code"
ON "{TABLE_NAME}"("trade_date", "ts_code")
""")
print(f"[sync_bak_basic] Created composite index on (trade_date, ts_code)")
except Exception as e:
print(f"[sync_bak_basic] Error creating index: {e}")
# Determine sync dates
sync_start = start_date or "20160101"
mode = "FULL"
print(f"[sync_bak_basic] Mode: {mode} SYNC from {sync_start} to {end_date}")
else:
# ===== INCREMENTAL SYNC =====
# Check last date in table, sync from last_date + 1
try:
result = storage._connection.execute(
f'SELECT MAX("trade_date") FROM "{TABLE_NAME}"'
).fetchone()
last_date = result[0] if result and result[0] else None
except Exception as e:
print(f"[sync_bak_basic] Error getting last date: {e}")
last_date = None
if last_date is None:
# Table exists but empty, do full sync
sync_start = start_date or "20160101"
mode = "FULL (empty table)"
else:
# Incremental from last_date + 1
# Handle both YYYYMMDD and YYYY-MM-DD formats
last_date_str = str(last_date).replace("-", "")
last_dt = datetime.strptime(last_date_str, "%Y%m%d")
next_dt = last_dt + timedelta(days=1)
sync_start = next_dt.strftime("%Y%m%d")
mode = "INCREMENTAL"
# Skip if already up to date
if sync_start > end_date:
print(f"[sync_bak_basic] Data is up-to-date (last: {last_date}), skipping sync")
return pd.DataFrame()
print(f"[sync_bak_basic] Mode: {mode} from {sync_start} to {end_date} (last: {last_date})")
# ===== FETCH AND SAVE DATA =====
all_data: List[pd.DataFrame] = []
current = datetime.strptime(sync_start, "%Y%m%d")
end_dt = datetime.strptime(end_date, "%Y%m%d")
# Calculate total days for progress bar
total_days = (end_dt - current).days + 1
print(f"[sync_bak_basic] Fetching data for {total_days} days...")
with tqdm(total=total_days, desc="Syncing dates") as pbar:
while current <= end_dt:
date_str = current.strftime("%Y%m%d")
try:
data = get_bak_basic(trade_date=date_str)
if not data.empty:
all_data.append(data)
pbar.set_postfix({"date": date_str, "records": len(data)})
except Exception as e:
print(f" {date_str}: ERROR - {e}")
current += timedelta(days=1)
pbar.update(1)
if not all_data:
print("[sync_bak_basic] No data fetched")
return pd.DataFrame()
# Combine and save
combined = pd.concat(all_data, ignore_index=True)
# Convert trade_date to datetime for proper DATE type storage
combined["trade_date"] = pd.to_datetime(combined["trade_date"], format="%Y%m%d")
print(f"[sync_bak_basic] Total records: {len(combined)}")
# Delete existing data for the date range and append new data
# Convert sync_start to date format for comparison with DATE column
sync_start_date = pd.to_datetime(sync_start, format="%Y%m%d").date()
storage._connection.execute(f'DELETE FROM "{TABLE_NAME}" WHERE "trade_date" >= ?', [sync_start_date])
thread_storage.queue_save(TABLE_NAME, combined)
thread_storage.flush()
print(f"[sync_bak_basic] Saved {len(combined)} records to DuckDB")
return combined
sync_manager = BakBasicSync()
return sync_manager.sync_all(
start_date=start_date,
end_date=end_date,
force_full=force_full,
)
if __name__ == "__main__":

View File

@@ -9,21 +9,9 @@ batch synchronization (DailySync class) for daily market data.
import pandas as pd
from typing import Optional, List, Literal, Dict
from datetime import datetime, timedelta
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed
import threading
from src.data.client import TushareClient
from src.data.storage import ThreadSafeStorage, Storage
from src.data.utils import get_today_date, get_next_date, DEFAULT_START_DATE
from src.config.settings import get_settings
from src.data.api_wrappers.api_trade_cal import (
get_first_trading_day,
get_last_trading_day,
sync_trade_cal_cache,
)
from src.data.api_wrappers.api_stock_basic import _get_csv_path, sync_all_stocks
from src.data.api_wrappers.base_sync import StockBasedSync
def get_daily(
@@ -90,426 +78,27 @@ def get_daily(
return data
# =============================================================================
# DailySync - 日线数据批量同步类
# =============================================================================
class DailySync:
class DailySync(StockBasedSync):
"""日线数据批量同步管理器,支持全量/增量同步。
功能特性:
- 多线程并发获取ThreadPoolExecutor
- 增量同步(自动检测上次同步位置)
- 内存缓存(避免重复磁盘读取)
- 异常立即停止(确保数据一致性)
- 预览模式(预览同步数据量,不实际写入)
继承自 StockBasedSync使用多线程按股票并发获取数据。
Example:
>>> sync = DailySync()
>>> results = sync.sync_all() # 增量同步
>>> results = sync.sync_all(force_full=True) # 全量同步
>>> preview = sync.preview_sync() # 预览
"""
# 默认工作线程数从配置读取默认10
DEFAULT_MAX_WORKERS = get_settings().threads
table_name = "daily"
def __init__(self, max_workers: Optional[int] = None):
    """Initialize the sync manager.

    Args:
        max_workers: Number of worker threads (defaults to the configured
            value, DEFAULT_MAX_WORKERS, when not given).
    """
    # BUGFIX: self.storage was never assigned here, yet _load_daily_data()
    # and sync_all() read self.storage — ProBarSync.__init__ assigns it the
    # same way, so DailySync is made consistent with it.
    self.storage = ThreadSafeStorage()
    self.client = TushareClient()
    self.max_workers = max_workers or self.DEFAULT_MAX_WORKERS
    self._stop_flag = threading.Event()
    self._stop_flag.set()  # initially in the "not stopped" state
    self._cached_daily_data: Optional[pd.DataFrame] = None  # daily-data cache
def _load_daily_data(self) -> pd.DataFrame:
    """Return the daily data, loading it from storage on first access.

    The result is memoized on the instance to avoid repeated disk reads;
    call clear_cache() to force a reload on the next access.

    Returns:
        The cached (or freshly loaded) daily-data DataFrame.
    """
    cached = self._cached_daily_data
    if cached is None:
        cached = self.storage.load("daily")
        self._cached_daily_data = cached
    return cached
def clear_cache(self) -> None:
    """Drop the cached daily data; the next access reloads from storage."""
    self._cached_daily_data = None
def get_all_stock_codes(self, only_listed: bool = True) -> list:
    """Return all stock codes known to local storage.

    Prefers stock_basic.csv so that every stock is included, which avoids
    look-ahead bias in backtests; falls back to codes present in the local
    daily data when the CSV is unavailable.

    Args:
        only_listed: If True, return only currently listed stocks
            (list_status == "L"). Set False to include delisted stocks
            for complete backtests.

    Returns:
        List of ts_code strings (empty when nothing is available).
    """
    # Make sure stock_basic.csv is current before reading it.
    print("[DailySync] Ensuring stock_basic.csv is up-to-date...")
    sync_all_stocks()
    # Primary source: the stock_basic.csv file.
    stock_csv_path = _get_csv_path()
    if stock_csv_path.exists():
        print(f"[DailySync] Reading stock_basic from CSV: {stock_csv_path}")
        try:
            stock_df = pd.read_csv(stock_csv_path, encoding="utf-8-sig")
            if not stock_df.empty and "ts_code" in stock_df.columns:
                # Filter by list_status when requested and available.
                if only_listed and "list_status" in stock_df.columns:
                    listed_stocks = stock_df[stock_df["list_status"] == "L"]
                    codes = listed_stocks["ts_code"].unique().tolist()
                    total = len(stock_df["ts_code"].unique())
                    print(
                        f"[DailySync] Found {len(codes)} listed stocks (filtered from {total} total)"
                    )
                else:
                    codes = stock_df["ts_code"].unique().tolist()
                    print(
                        f"[DailySync] Found {len(codes)} stock codes from stock_basic.csv"
                    )
                return codes
            else:
                print(
                    f"[DailySync] stock_basic.csv exists but no ts_code column or empty"
                )
        except Exception as e:
            # Best-effort: fall through to the daily-data fallback below.
            print(f"[DailySync] Error reading stock_basic.csv: {e}")
    # Fallback: derive the codes from locally stored daily data.
    print(
        "[DailySync] stock_basic.csv not available, falling back to daily data..."
    )
    daily_data = self._load_daily_data()
    if not daily_data.empty and "ts_code" in daily_data.columns:
        codes = daily_data["ts_code"].unique().tolist()
        print(f"[DailySync] Found {len(codes)} stock codes from daily data")
        return codes
    print("[DailySync] No stock codes found in local storage")
    return []
def get_global_last_date(self) -> Optional[str]:
    """Return the latest trade_date in the local daily data.

    Returns:
        The maximum trade_date as a string, or None when no data exists.
    """
    df = self._load_daily_data()
    if df.empty or "trade_date" not in df.columns:
        return None
    return str(df["trade_date"].max())
def get_global_first_date(self) -> Optional[str]:
    """Return the earliest trade_date in the local daily data.

    Returns:
        The minimum trade_date as a string, or None when no data exists.
    """
    df = self._load_daily_data()
    if df.empty or "trade_date" not in df.columns:
        return None
    return str(df["trade_date"].min())
def get_trade_calendar_bounds(
    self, start_date: str, end_date: str
) -> tuple[Optional[str], Optional[str]]:
    """Return the first and last trading day within a date range.

    Args:
        start_date: Range start (YYYYMMDD).
        end_date: Range end (YYYYMMDD).

    Returns:
        (first_trading_day, last_trading_day), or (None, None) on error.
    """
    try:
        return (
            get_first_trading_day(start_date, end_date),
            get_last_trading_day(start_date, end_date),
        )
    except Exception as e:
        print(f"[ERROR] Failed to get trade calendar bounds: {e}")
        return (None, None)
def check_sync_needed(
    self,
    force_full: bool = False,
    table_name: str = "daily",
) -> tuple[bool, Optional[str], Optional[str], Optional[str]]:
    """Decide from the trade calendar whether a sync is needed.

    Compares the local data's date range against the trade calendar to
    determine whether new data must be fetched.

    Logic:
        - force_full: sync needed -> (True, DEFAULT_START_DATE, today, None)
        - no local data: sync needed -> (True, DEFAULT_START_DATE, today, None)
        - local data exists:
            - read the last trading day from the calendar
            - local last date >= calendar last date: no sync needed
            - otherwise: sync from local last date + 1 to the latest trading day

    Args:
        force_full: If True, always report that a sync is needed.
        table_name: Table to inspect (default: "daily").

    Returns:
        (sync_needed, start_date, end_date, local_last_date)
        - sync_needed: True when a sync should proceed
        - start_date: sync start date (None when no sync is needed)
        - end_date: sync end date (None when no sync is needed)
        - local_last_date: last date present locally (for incremental sync)
    """
    # force_full always syncs.
    if force_full:
        print("[DailySync] Force full sync requested")
        return (True, DEFAULT_START_DATE, get_today_date(), None)
    # Check whether local data exists for the requested table.
    storage = Storage()
    table_data = (
        storage.load(table_name) if storage.exists(table_name) else pd.DataFrame()
    )
    if table_data.empty or "trade_date" not in table_data.columns:
        print(
            f"[DailySync] No local data found for table '{table_name}', full sync needed"
        )
        return (True, DEFAULT_START_DATE, get_today_date(), None)
    # Last date present in the local data.
    local_last_date = str(table_data["trade_date"].max())
    print(f"[DailySync] Local data last date: {local_last_date}")
    # Latest trading day according to the trade calendar.
    today = get_today_date()
    _, cal_last = self.get_trade_calendar_bounds(DEFAULT_START_DATE, today)
    if cal_last is None:
        # Calendar unavailable: err on the side of syncing.
        print("[DailySync] Failed to get trade calendar, proceeding with sync")
        return (True, DEFAULT_START_DATE, today, local_last_date)
    print(f"[DailySync] Calendar last trading day: {cal_last}")
    # Compare the local last date with the calendar's last date.
    print(
        f"[DailySync] Comparing: local={local_last_date} (type={type(local_last_date).__name__}), "
        f"cal={cal_last} (type={type(cal_last).__name__})"
    )
    try:
        # Compare as integers so "20240105" vs "2024-01-05"-style mismatches
        # surface as ValueError instead of a silently wrong string compare.
        local_last_int = int(local_last_date)
        cal_last_int = int(cal_last)
        print(
            f"[DailySync] Comparing integers: local={local_last_int} >= cal={cal_last_int} = "
            f"{local_last_int >= cal_last_int}"
        )
        if local_last_int >= cal_last_int:
            print(
                "[DailySync] Local data is up-to-date, SKIPPING sync (no tokens consumed)"
            )
            return (False, None, None, None)
    except (ValueError, TypeError) as e:
        # On a malformed date fall through and sync anyway.
        print(f"[ERROR] Date comparison failed: {e}")
    # Sync from local last date + 1 through the latest trading day.
    sync_start = get_next_date(local_last_date)
    print(f"[DailySync] Incremental sync needed from {sync_start} to {cal_last}")
    return (True, sync_start, cal_last, local_last_date)
def preview_sync(
    self,
    force_full: bool = False,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
    sample_size: int = 3,
) -> dict:
    """Preview the volume and a sample of the data a sync would fetch.

    Shows what an actual sync would do, without writing anything:
    - number of stocks that would be synced
    - the sync date range
    - an estimated total record count
    - sample data from the first few stocks

    Args:
        force_full: If True, preview a full sync (from DEFAULT_START_DATE).
        start_date: Manually specified start date (overrides auto-detect).
        end_date: Manually specified end date (defaults to today).
        sample_size: Number of stocks to sample for the preview (default: 3).

    Returns:
        Dict with preview info:
        {
            'sync_needed': bool,
            'stock_count': int,
            'start_date': str,
            'end_date': str,
            'estimated_records': int,
            'sample_data': pd.DataFrame,
            'mode': str,  # 'full' / 'incremental' / 'partial' / 'none'
        }
    """
    print("\n" + "=" * 60)
    print("[DailySync] Preview Mode - Analyzing sync requirements...")
    print("=" * 60)
    # Make sure the trade-calendar cache is current first.
    print("[DailySync] Syncing trade calendar cache...")
    sync_trade_cal_cache()
    # Determine the date range.
    if end_date is None:
        end_date = get_today_date()
    # Check whether a sync is needed at all.
    sync_needed, cal_start, cal_end, local_last = self.check_sync_needed(force_full)
    if not sync_needed:
        print("\n" + "=" * 60)
        print("[DailySync] Preview Result")
        print("=" * 60)
        print(" Sync Status: NOT NEEDED")
        print(" Reason: Local data is up-to-date with trade calendar")
        print("=" * 60)
        return {
            "sync_needed": False,
            "stock_count": 0,
            "start_date": None,
            "end_date": None,
            "estimated_records": 0,
            "sample_data": pd.DataFrame(),
            "mode": "none",
        }
    # Use the dates computed by check_sync_needed when available.
    if cal_start and cal_end:
        sync_start_date = cal_start
        end_date = cal_end
    else:
        sync_start_date = start_date or DEFAULT_START_DATE
        if end_date is None:
            end_date = get_today_date()
    # Determine the sync mode.
    if force_full:
        mode = "full"
        print(f"[DailySync] Mode: FULL SYNC from {sync_start_date} to {end_date}")
    elif local_last and cal_start and sync_start_date == get_next_date(local_last):
        mode = "incremental"
        # BUGFIX: was "INCREmental SYNC"; normalized to match sync_all()
        # and ProBarSync.preview_sync() log output.
        print(f"[DailySync] Mode: INCREMENTAL SYNC (bandwidth optimized)")
        print(f"[DailySync] Sync from: {sync_start_date} to {end_date}")
    else:
        mode = "partial"
        print(f"[DailySync] Mode: SYNC from {sync_start_date} to {end_date}")
    # Collect all stock codes.
    stock_codes = self.get_all_stock_codes()
    if not stock_codes:
        print("[DailySync] No stocks found to sync")
        return {
            "sync_needed": False,
            "stock_count": 0,
            "start_date": None,
            "end_date": None,
            "estimated_records": 0,
            "sample_data": pd.DataFrame(),
            "mode": "none",
        }
    stock_count = len(stock_codes)
    print(f"[DailySync] Total stocks to sync: {stock_count}")
    # Fetch sample data from the first few stocks.
    print(f"[DailySync] Fetching sample data from {sample_size} stocks...")
    sample_data_list = []
    sample_codes = stock_codes[:sample_size]
    for ts_code in sample_codes:
        try:
            data = self.client.query(
                "pro_bar",
                ts_code=ts_code,
                start_date=sync_start_date,
                end_date=end_date,
                factors="tor,vr",
            )
            if not data.empty:
                sample_data_list.append(data)
                print(f" - {ts_code}: {len(data)} records")
        except Exception as e:
            # Preview is best-effort: report and keep sampling.
            print(f" - {ts_code}: Error fetching - {e}")
    # Combine the samples.
    sample_df = (
        pd.concat(sample_data_list, ignore_index=True)
        if sample_data_list
        else pd.DataFrame()
    )
    # Estimate the total record count from the sample average.
    if not sample_df.empty:
        avg_records_per_stock = len(sample_df) / len(sample_data_list)
        estimated_records = int(avg_records_per_stock * stock_count)
    else:
        estimated_records = 0
    # Display the preview result.
    print("\n" + "=" * 60)
    print("[DailySync] Preview Result")
    print("=" * 60)
    print(f" Sync Mode: {mode.upper()}")
    print(f" Date Range: {sync_start_date} to {end_date}")
    print(f" Stocks to Sync: {stock_count}")
    print(f" Sample Stocks Checked: {len(sample_data_list)}/{sample_size}")
    print(f" Estimated Total Records: ~{estimated_records:,}")
    if not sample_df.empty:
        print(f"\n Sample Data Preview (first {len(sample_df)} rows):")
        print(" " + "-" * 56)
        # Show the sample rows in a compact format.
        preview_cols = [
            "ts_code",
            "trade_date",
            "open",
            "high",
            "low",
            "close",
            "vol",
        ]
        available_cols = [c for c in preview_cols if c in sample_df.columns]
        sample_display = sample_df[available_cols].head(10)
        for _, row in sample_display.iterrows():  # index unused
            print(f" {row.to_dict()}")
        print(" " + "-" * 56)
    print("=" * 60)
    return {
        "sync_needed": True,
        "stock_count": stock_count,
        "start_date": sync_start_date,
        "end_date": end_date,
        "estimated_records": estimated_records,
        "sample_data": sample_df,
        "mode": mode,
    }
def sync_single_stock(
def fetch_single_stock(
self,
ts_code: str,
start_date: str,
end_date: str,
) -> pd.DataFrame:
"""同步单只股票的日线数据。
"""获取单只股票的日线数据。
Args:
ts_code: 股票代码
@@ -517,13 +106,8 @@ class DailySync:
end_date: 结束日期YYYYMMDD
Returns:
包含日线市场数据的 DataFrame
包含日线数据的 DataFrame
"""
# 检查是否应该停止同步(用于异常处理)
if not self._stop_flag.is_set():
return pd.DataFrame()
try:
# 使用共享客户端进行跨线程速率限制
data = self.client.query(
"pro_bar",
@@ -533,205 +117,6 @@ class DailySync:
factors="tor,vr",
)
return data
except Exception as e:
# 设置停止标志以通知其他线程停止
self._stop_flag.clear()
print(f"[ERROR] Exception syncing {ts_code}: {e}")
raise
def sync_all(
    self,
    force_full: bool = False,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
    max_workers: Optional[int] = None,
    dry_run: bool = False,
) -> Dict[str, pd.DataFrame]:
    """Sync daily data for every stock known to local storage.

    This function:
    1. Reads stock codes from local storage (daily or stock_basic)
    2. Consults the trade calendar to decide whether to sync:
       - skip entirely when local data matches the calendar (saves tokens)
       - otherwise sync from local last date + 1 to the latest trading day
    3. Fetches concurrently with multiple threads (rate-limited client)
    4. Skips stocks that return empty data (delisted/unavailable)
    5. Aborts immediately on the first exception

    Args:
        force_full: If True, force a full reload from DEFAULT_START_DATE.
        start_date: Manually specified start date (overrides auto-detect).
        end_date: Manually specified end date (defaults to today).
        max_workers: Number of worker threads (defaults to self.max_workers).
        dry_run: If True, only report what would be synced; write nothing.

    Returns:
        Dict mapping ts_code -> DataFrame (empty when skipped or dry_run).
    """
    print("\n" + "=" * 60)
    print("[DailySync] Starting daily data sync...")
    print("=" * 60)
    # Make sure the trade-calendar cache is current first (incremental).
    print("[DailySync] Syncing trade calendar cache...")
    sync_trade_cal_cache()
    # Determine the date range.
    if end_date is None:
        end_date = get_today_date()
    # Decide from the trade calendar whether a sync is needed.
    sync_needed, cal_start, cal_end, local_last = self.check_sync_needed(force_full)
    if not sync_needed:
        # Skip the sync - no tokens consumed.
        print("\n" + "=" * 60)
        print("[DailySync] Sync Summary")
        print("=" * 60)
        print(" Sync: SKIPPED (local data up-to-date with trade calendar)")
        print(" Tokens saved: 0 consumed")
        print("=" * 60)
        return {}
    # Use the dates computed by check_sync_needed (incremental start date).
    if cal_start and cal_end:
        sync_start_date = cal_start
        end_date = cal_end
    else:
        # Fall back to the default logic.
        sync_start_date = start_date or DEFAULT_START_DATE
        if end_date is None:
            end_date = get_today_date()
    # Determine the sync mode (for reporting only).
    if force_full:
        mode = "full"
        print(f"[DailySync] Mode: FULL SYNC from {sync_start_date} to {end_date}")
    elif local_last and cal_start and sync_start_date == get_next_date(local_last):
        mode = "incremental"
        print(f"[DailySync] Mode: INCREMENTAL SYNC (bandwidth optimized)")
        print(f"[DailySync] Sync from: {sync_start_date} to {end_date}")
    else:
        mode = "partial"
        print(f"[DailySync] Mode: SYNC from {sync_start_date} to {end_date}")
    # Collect all stock codes.
    stock_codes = self.get_all_stock_codes()
    if not stock_codes:
        print("[DailySync] No stocks found to sync")
        return {}
    print(f"[DailySync] Total stocks to sync: {len(stock_codes)}")
    print(f"[DailySync] Using {max_workers or self.max_workers} worker threads")
    # Handle dry-run mode: report and exit before any fetching.
    if dry_run:
        print("\n" + "=" * 60)
        print("[DailySync] DRY RUN MODE - No data will be written")
        print("=" * 60)
        print(f" Would sync {len(stock_codes)} stocks")
        print(f" Date range: {sync_start_date} to {end_date}")
        print(f" Mode: {mode}")
        print("=" * 60)
        return {}
    # Reset the stop flag for the new sync run.
    self._stop_flag.set()
    # Concurrent fetching across worker threads.
    results: Dict[str, pd.DataFrame] = {}
    error_occurred = False
    exception_to_raise = None

    def sync_task(ts_code: str) -> tuple[str, pd.DataFrame]:
        """Per-stock task function."""
        try:
            data = self.sync_single_stock(
                ts_code=ts_code,
                start_date=sync_start_date,
                end_date=end_date,
            )
            return (ts_code, data)
        except Exception as e:
            # Re-raise so the Future captures it.
            raise

    # Fan out with a ThreadPoolExecutor.
    workers = max_workers or self.max_workers
    with ThreadPoolExecutor(max_workers=workers) as executor:
        # Submit everything and keep the future -> ts_code mapping.
        future_to_code = {
            executor.submit(sync_task, ts_code): ts_code for ts_code in stock_codes
        }
        # Process results with as_completed.
        error_count = 0
        empty_count = 0
        success_count = 0
        # Progress bar over all submitted stocks.
        pbar = tqdm(total=len(stock_codes), desc="Syncing stocks")
        try:
            # Handle futures as they complete.
            for future in as_completed(future_to_code):
                ts_code = future_to_code[future]
                try:
                    _, data = future.result()
                    if data is not None and not data.empty:
                        results[ts_code] = data
                        success_count += 1
                    else:
                        # Empty data - the stock may be delisted/unavailable.
                        empty_count += 1
                        print(
                            f"[DailySync] Stock {ts_code}: empty data (skipped, may be delisted)"
                        )
                except Exception as e:
                    # An exception occurred - stop everything and abort.
                    error_occurred = True
                    exception_to_raise = e
                    print(f"\n[ERROR] Sync aborted due to exception: {e}")
                    # Shut down the executor to cancel pending tasks.
                    executor.shutdown(wait=False, cancel_futures=True)
                    raise exception_to_raise
                # Advance the progress bar.
                pbar.update(1)
        except Exception:
            # The abort above lands here; recorded for the summary below.
            error_count = 1
            print("[DailySync] Sync stopped due to exception")
        finally:
            pbar.close()
    # Batch-write everything (only when no error occurred).
    if results and not error_occurred:
        for ts_code, data in results.items():
            if not data.empty:
                self.storage.queue_save("daily", data)
        # Flush all queued writes in one pass.
        self.storage.flush()
        total_rows = sum(len(df) for df in results.values())
        print(f"\n[DailySync] Saved {total_rows} rows to storage")
    # Summary.
    print("\n" + "=" * 60)
    print("[DailySync] Sync Summary")
    print("=" * 60)
    print(f" Total stocks: {len(stock_codes)}")
    print(f" Updated: {success_count}")
    print(f" Skipped (empty/delisted): {empty_count}")
    print(
        f" Errors: {error_count} (aborted on first error)"
        if error_count
        else " Errors: 0"
    )
    print(f" Date range: {sync_start_date} to {end_date}")
    print("=" * 60)
    return results
def sync_daily(

View File

@@ -8,21 +8,9 @@ volume ratio (vr), and adjustment factors.
import pandas as pd
from typing import Optional, List, Literal, Dict
from datetime import datetime, timedelta
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed
import threading
from src.data.client import TushareClient
from src.data.storage import ThreadSafeStorage, Storage
from src.data.utils import get_today_date, get_next_date, DEFAULT_START_DATE
from src.config.settings import get_settings
from src.data.api_wrappers.api_trade_cal import (
get_first_trading_day,
get_last_trading_day,
sync_trade_cal_cache,
)
from src.data.api_wrappers.api_stock_basic import _get_csv_path, sync_all_stocks
from src.data.api_wrappers.base_sync import StockBasedSync
def get_pro_bar(
@@ -138,429 +126,28 @@ def get_pro_bar(
return data
# =============================================================================
# ProBarSync - Pro Bar 数据批量同步类
# =============================================================================
class ProBarSync:
class ProBarSync(StockBasedSync):
"""Pro Bar 数据批量同步管理器,支持全量/增量同步。
功能特性:
- 多线程并发获取ThreadPoolExecutor
- 增量同步(自动检测上次同步位置)
- 内存缓存(避免重复磁盘读取)
- 异常立即停止(确保数据一致性)
- 预览模式(预览同步数据量,不实际写入)
- 默认获取全部数据列tor, vr, adj_factor
继承自 StockBasedSync使用多线程按股票并发获取数据。
默认获取全部数据列tor, vr, adj_factor
Example:
>>> sync = ProBarSync()
>>> results = sync.sync_all() # 增量同步
>>> results = sync.sync_all(force_full=True) # 全量同步
>>> preview = sync.preview_sync() # 预览
"""
# 默认工作线程数从配置读取默认10
DEFAULT_MAX_WORKERS = get_settings().threads
table_name = "pro_bar"
def __init__(self, max_workers: Optional[int] = None):
    """Initialize the sync manager.

    Args:
        max_workers: Number of worker threads (defaults to the configured
            value, DEFAULT_MAX_WORKERS, when not given).
    """
    self.storage = ThreadSafeStorage()
    self.client = TushareClient()
    self.max_workers = max_workers or self.DEFAULT_MAX_WORKERS
    self._stop_flag = threading.Event()
    self._stop_flag.set()  # initially in the "not stopped" state
    self._cached_pro_bar_data: Optional[pd.DataFrame] = None  # data cache
def _load_pro_bar_data(self) -> pd.DataFrame:
    """Return the pro_bar data, loading it from storage on first access.

    The result is memoized on the instance to avoid repeated disk reads;
    call clear_cache() to force a reload on the next access.

    Returns:
        The cached (or freshly loaded) pro_bar DataFrame.
    """
    cached = self._cached_pro_bar_data
    if cached is None:
        cached = self.storage.load("pro_bar")
        self._cached_pro_bar_data = cached
    return cached
def clear_cache(self) -> None:
    """Drop the cached pro_bar data; the next access reloads from storage."""
    self._cached_pro_bar_data = None
def get_all_stock_codes(self, only_listed: bool = True) -> list:
    """Return all stock codes known to local storage.

    Prefers stock_basic.csv so that every stock is included, which avoids
    look-ahead bias in backtests; falls back to codes present in the local
    pro_bar data when the CSV is unavailable.

    Args:
        only_listed: If True, return only currently listed stocks
            (list_status == "L"). Set False to include delisted stocks
            for complete backtests.

    Returns:
        List of ts_code strings (empty when nothing is available).
    """
    # Make sure stock_basic.csv is current before reading it.
    print("[ProBarSync] Ensuring stock_basic.csv is up-to-date...")
    sync_all_stocks()
    # Primary source: the stock_basic.csv file.
    stock_csv_path = _get_csv_path()
    if stock_csv_path.exists():
        print(f"[ProBarSync] Reading stock_basic from CSV: {stock_csv_path}")
        try:
            stock_df = pd.read_csv(stock_csv_path, encoding="utf-8-sig")
            if not stock_df.empty and "ts_code" in stock_df.columns:
                # Filter by list_status when requested and available.
                if only_listed and "list_status" in stock_df.columns:
                    listed_stocks = stock_df[stock_df["list_status"] == "L"]
                    codes = listed_stocks["ts_code"].unique().tolist()
                    total = len(stock_df["ts_code"].unique())
                    print(
                        f"[ProBarSync] Found {len(codes)} listed stocks (filtered from {total} total)"
                    )
                else:
                    codes = stock_df["ts_code"].unique().tolist()
                    print(
                        f"[ProBarSync] Found {len(codes)} stock codes from stock_basic.csv"
                    )
                return codes
            else:
                print(
                    f"[ProBarSync] stock_basic.csv exists but no ts_code column or empty"
                )
        except Exception as e:
            # Best-effort: fall through to the pro_bar fallback below.
            print(f"[ProBarSync] Error reading stock_basic.csv: {e}")
    # Fallback: derive the codes from locally stored pro_bar data.
    print(
        "[ProBarSync] stock_basic.csv not available, falling back to pro_bar data..."
    )
    pro_bar_data = self._load_pro_bar_data()
    if not pro_bar_data.empty and "ts_code" in pro_bar_data.columns:
        codes = pro_bar_data["ts_code"].unique().tolist()
        print(f"[ProBarSync] Found {len(codes)} stock codes from pro_bar data")
        return codes
    print("[ProBarSync] No stock codes found in local storage")
    return []
def get_global_last_date(self) -> Optional[str]:
    """Return the latest trade_date in the local pro_bar data.

    Returns:
        The maximum trade_date as a string, or None when no data exists.
    """
    df = self._load_pro_bar_data()
    if df.empty or "trade_date" not in df.columns:
        return None
    return str(df["trade_date"].max())
def get_global_first_date(self) -> Optional[str]:
    """Return the earliest trade_date in the local pro_bar data.

    Returns:
        The minimum trade_date as a string, or None when no data exists.
    """
    df = self._load_pro_bar_data()
    if df.empty or "trade_date" not in df.columns:
        return None
    return str(df["trade_date"].min())
def get_trade_calendar_bounds(
    self, start_date: str, end_date: str
) -> tuple[Optional[str], Optional[str]]:
    """Return the first and last trading day within a date range.

    Args:
        start_date: Range start (YYYYMMDD).
        end_date: Range end (YYYYMMDD).

    Returns:
        (first_trading_day, last_trading_day), or (None, None) on error.
    """
    try:
        return (
            get_first_trading_day(start_date, end_date),
            get_last_trading_day(start_date, end_date),
        )
    except Exception as e:
        print(f"[ERROR] Failed to get trade calendar bounds: {e}")
        return (None, None)
def check_sync_needed(
    self,
    force_full: bool = False,
    table_name: str = "pro_bar",
) -> tuple[bool, Optional[str], Optional[str], Optional[str]]:
    """Decide from the trade calendar whether a sync is needed.

    Compares the local data's date range against the trade calendar to
    determine whether new data must be fetched.

    Logic:
        - force_full: sync needed -> (True, DEFAULT_START_DATE, today, None)
        - no local data: sync needed -> (True, DEFAULT_START_DATE, today, None)
        - local data exists:
            - read the last trading day from the calendar
            - local last date >= calendar last date: no sync needed
            - otherwise: sync from local last date + 1 to the latest trading day

    Args:
        force_full: If True, always report that a sync is needed.
        table_name: Table to inspect (default: "pro_bar").

    Returns:
        (sync_needed, start_date, end_date, local_last_date)
        - sync_needed: True when a sync should proceed
        - start_date: sync start date (None when no sync is needed)
        - end_date: sync end date (None when no sync is needed)
        - local_last_date: last date present locally (for incremental sync)
    """
    # force_full always syncs.
    if force_full:
        print("[ProBarSync] Force full sync requested")
        return (True, DEFAULT_START_DATE, get_today_date(), None)
    # Check whether local data exists for the requested table.
    storage = Storage()
    table_data = (
        storage.load(table_name) if storage.exists(table_name) else pd.DataFrame()
    )
    if table_data.empty or "trade_date" not in table_data.columns:
        print(
            f"[ProBarSync] No local data found for table '{table_name}', full sync needed"
        )
        return (True, DEFAULT_START_DATE, get_today_date(), None)
    # Last date present in the local data.
    local_last_date = str(table_data["trade_date"].max())
    print(f"[ProBarSync] Local data last date: {local_last_date}")
    # Latest trading day according to the trade calendar.
    today = get_today_date()
    _, cal_last = self.get_trade_calendar_bounds(DEFAULT_START_DATE, today)
    if cal_last is None:
        # Calendar unavailable: err on the side of syncing.
        print("[ProBarSync] Failed to get trade calendar, proceeding with sync")
        return (True, DEFAULT_START_DATE, today, local_last_date)
    print(f"[ProBarSync] Calendar last trading day: {cal_last}")
    # Compare the local last date with the calendar's last date.
    print(
        f"[ProBarSync] Comparing: local={local_last_date} (type={type(local_last_date).__name__}), "
        f"cal={cal_last} (type={type(cal_last).__name__})"
    )
    try:
        # Compare as integers so malformed date strings raise ValueError
        # instead of comparing incorrectly as text.
        local_last_int = int(local_last_date)
        cal_last_int = int(cal_last)
        print(
            f"[ProBarSync] Comparing integers: local={local_last_int} >= cal={cal_last_int} = "
            f"{local_last_int >= cal_last_int}"
        )
        if local_last_int >= cal_last_int:
            print(
                "[ProBarSync] Local data is up-to-date, SKIPPING sync (no tokens consumed)"
            )
            return (False, None, None, None)
    except (ValueError, TypeError) as e:
        # On a malformed date fall through and sync anyway.
        print(f"[ERROR] Date comparison failed: {e}")
    # Sync from local last date + 1 through the latest trading day.
    sync_start = get_next_date(local_last_date)
    print(f"[ProBarSync] Incremental sync needed from {sync_start} to {cal_last}")
    return (True, sync_start, cal_last, local_last_date)
def preview_sync(
    self,
    force_full: bool = False,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
    sample_size: int = 3,
) -> dict:
    """Preview the volume and a sample of the data a sync would fetch.

    Shows what an actual sync would do, without writing anything:
    - number of stocks that would be synced
    - the sync date range
    - an estimated total record count
    - sample data from the first few stocks

    Args:
        force_full: If True, preview a full sync (from DEFAULT_START_DATE).
        start_date: Manually specified start date (overrides auto-detect).
        end_date: Manually specified end date (defaults to today).
        sample_size: Number of stocks to sample for the preview (default: 3).

    Returns:
        Dict with preview info:
        {
            'sync_needed': bool,
            'stock_count': int,
            'start_date': str,
            'end_date': str,
            'estimated_records': int,
            'sample_data': pd.DataFrame,
            'mode': str,  # 'full' / 'incremental' / 'partial' / 'none'
        }
    """
    print("\n" + "=" * 60)
    print("[ProBarSync] Preview Mode - Analyzing sync requirements...")
    print("=" * 60)
    # Make sure the trade-calendar cache is current first.
    print("[ProBarSync] Syncing trade calendar cache...")
    sync_trade_cal_cache()
    # Determine the date range.
    if end_date is None:
        end_date = get_today_date()
    # Check whether a sync is needed at all.
    sync_needed, cal_start, cal_end, local_last = self.check_sync_needed(force_full)
    if not sync_needed:
        print("\n" + "=" * 60)
        print("[ProBarSync] Preview Result")
        print("=" * 60)
        print(" Sync Status: NOT NEEDED")
        print(" Reason: Local data is up-to-date with trade calendar")
        print("=" * 60)
        return {
            "sync_needed": False,
            "stock_count": 0,
            "start_date": None,
            "end_date": None,
            "estimated_records": 0,
            "sample_data": pd.DataFrame(),
            "mode": "none",
        }
    # Use the dates computed by check_sync_needed when available.
    if cal_start and cal_end:
        sync_start_date = cal_start
        end_date = cal_end
    else:
        sync_start_date = start_date or DEFAULT_START_DATE
        if end_date is None:
            end_date = get_today_date()
    # Determine the sync mode.
    if force_full:
        mode = "full"
        print(f"[ProBarSync] Mode: FULL SYNC from {sync_start_date} to {end_date}")
    elif local_last and cal_start and sync_start_date == get_next_date(local_last):
        mode = "incremental"
        print(f"[ProBarSync] Mode: INCREMENTAL SYNC (bandwidth optimized)")
        print(f"[ProBarSync] Sync from: {sync_start_date} to {end_date}")
    else:
        mode = "partial"
        print(f"[ProBarSync] Mode: SYNC from {sync_start_date} to {end_date}")
    # Collect all stock codes.
    stock_codes = self.get_all_stock_codes()
    if not stock_codes:
        print("[ProBarSync] No stocks found to sync")
        return {
            "sync_needed": False,
            "stock_count": 0,
            "start_date": None,
            "end_date": None,
            "estimated_records": 0,
            "sample_data": pd.DataFrame(),
            "mode": "none",
        }
    stock_count = len(stock_codes)
    print(f"[ProBarSync] Total stocks to sync: {stock_count}")
    # Fetch sample data from the first few stocks.
    print(f"[ProBarSync] Fetching sample data from {sample_size} stocks...")
    sample_data_list = []
    sample_codes = stock_codes[:sample_size]
    for ts_code in sample_codes:
        try:
            # Use get_pro_bar for the sample (all fields included).
            data = get_pro_bar(
                ts_code=ts_code,
                start_date=sync_start_date,
                end_date=end_date,
            )
            if not data.empty:
                sample_data_list.append(data)
                print(f" - {ts_code}: {len(data)} records")
        except Exception as e:
            # Preview is best-effort: report and keep sampling.
            print(f" - {ts_code}: Error fetching - {e}")
    # Combine the samples.
    sample_df = (
        pd.concat(sample_data_list, ignore_index=True)
        if sample_data_list
        else pd.DataFrame()
    )
    # Estimate the total record count from the sample average.
    if not sample_df.empty:
        avg_records_per_stock = len(sample_df) / len(sample_data_list)
        estimated_records = int(avg_records_per_stock * stock_count)
    else:
        estimated_records = 0
    # Display the preview result.
    print("\n" + "=" * 60)
    print("[ProBarSync] Preview Result")
    print("=" * 60)
    print(f" Sync Mode: {mode.upper()}")
    print(f" Date Range: {sync_start_date} to {end_date}")
    print(f" Stocks to Sync: {stock_count}")
    print(f" Sample Stocks Checked: {len(sample_data_list)}/{sample_size}")
    print(f" Estimated Total Records: ~{estimated_records:,}")
    if not sample_df.empty:
        print(f"\n Sample Data Preview (first {len(sample_df)} rows):")
        print(" " + "-" * 56)
        # Show the sample rows in a compact format.
        preview_cols = [
            "ts_code",
            "trade_date",
            "open",
            "high",
            "low",
            "close",
            "vol",
            "tor",
            "vr",
        ]
        available_cols = [c for c in preview_cols if c in sample_df.columns]
        sample_display = sample_df[available_cols].head(10)
        for idx, row in sample_display.iterrows():
            print(f" {row.to_dict()}")
        print(" " + "-" * 56)
    print("=" * 60)
    return {
        "sync_needed": True,
        "stock_count": stock_count,
        "start_date": sync_start_date,
        "end_date": end_date,
        "estimated_records": estimated_records,
        "sample_data": sample_df,
        "mode": mode,
    }
def sync_single_stock(
def fetch_single_stock(
self,
ts_code: str,
start_date: str,
end_date: str,
) -> pd.DataFrame:
"""同步单只股票的 Pro Bar 数据。
"""获取单只股票的 Pro Bar 数据。
Args:
ts_code: 股票代码
@@ -570,11 +157,6 @@ class ProBarSync:
Returns:
包含 Pro Bar 数据的 DataFrame
"""
# 检查是否应该停止同步(用于异常处理)
if not self._stop_flag.is_set():
return pd.DataFrame()
try:
# 使用 get_pro_bar 获取数据(默认包含所有字段,传递共享 client
data = get_pro_bar(
ts_code=ts_code,
@@ -583,205 +165,6 @@ class ProBarSync:
client=self.client, # 传递共享客户端以确保限流
)
return data
except Exception as e:
# 设置停止标志以通知其他线程停止
self._stop_flag.clear()
print(f"[ERROR] Exception syncing {ts_code}: {e}")
raise
def sync_all(
    self,
    force_full: bool = False,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
    max_workers: Optional[int] = None,
    dry_run: bool = False,
) -> Dict[str, pd.DataFrame]:
    """Sync Pro Bar data for every stock known to local storage.

    Workflow:
        1. Read stock codes from local storage (pro_bar or stock_basic).
        2. Consult the trade calendar to decide whether a sync is needed:
           - if local data already matches the trade-calendar boundary,
             skip entirely (no API tokens consumed);
           - otherwise sync from (local last date + 1) up to the latest
             trade date (bandwidth-optimized incremental sync).
        3. Fetch concurrently with a rate-limited thread pool.
        4. Skip stocks returning empty data (delisted / unavailable).
        5. Abort on the first exception raised by any worker.

    Args:
        force_full: If True, force a full reload from the default start date.
        start_date: Manually specified start date (overrides auto-detection).
        end_date: Manually specified end date (defaults to today).
        max_workers: Number of worker threads (defaults to self.max_workers).
        dry_run: If True, only preview what would be synced; writes nothing.

    Returns:
        Dict mapping ts_code to its fetched DataFrame (empty dict when the
        sync is skipped, dry_run is set, or no stocks are found).
    """
    print("\n" + "=" * 60)
    print("[ProBarSync] Starting pro_bar data sync...")
    print("=" * 60)
    # Make sure the trade-calendar cache is fresh first (incremental sync).
    print("[ProBarSync] Syncing trade calendar cache...")
    sync_trade_cal_cache()
    # Determine the date range.
    if end_date is None:
        end_date = get_today_date()
    # Decide, based on the trade calendar, whether a sync is required.
    sync_needed, cal_start, cal_end, local_last = self.check_sync_needed(force_full)
    if not sync_needed:
        # Skip the sync entirely - consumes no API tokens.
        print("\n" + "=" * 60)
        print("[ProBarSync] Sync Summary")
        print("=" * 60)
        print(" Sync: SKIPPED (local data up-to-date with trade calendar)")
        print(" Tokens saved: 0 consumed")
        print("=" * 60)
        return {}
    # Use the dates returned by check_sync_needed (it computes the
    # incremental start date).
    if cal_start and cal_end:
        sync_start_date = cal_start
        end_date = cal_end
    else:
        # Fall back to the default logic.
        sync_start_date = start_date or DEFAULT_START_DATE
        if end_date is None:
            end_date = get_today_date()
    # Determine the sync mode (for reporting only).
    if force_full:
        mode = "full"
        print(f"[ProBarSync] Mode: FULL SYNC from {sync_start_date} to {end_date}")
    elif local_last and cal_start and sync_start_date == get_next_date(local_last):
        mode = "incremental"
        print(f"[ProBarSync] Mode: INCREMENTAL SYNC (bandwidth optimized)")
        print(f"[ProBarSync] Sync from: {sync_start_date} to {end_date}")
    else:
        mode = "partial"
        print(f"[ProBarSync] Mode: SYNC from {sync_start_date} to {end_date}")
    # Fetch all stock codes to sync.
    stock_codes = self.get_all_stock_codes()
    if not stock_codes:
        print("[ProBarSync] No stocks found to sync")
        return {}
    print(f"[ProBarSync] Total stocks to sync: {len(stock_codes)}")
    print(f"[ProBarSync] Using {max_workers or self.max_workers} worker threads")
    # Handle dry-run mode: report the plan and exit without writing.
    if dry_run:
        print("\n" + "=" * 60)
        print("[ProBarSync] DRY RUN MODE - No data will be written")
        print("=" * 60)
        print(f" Would sync {len(stock_codes)} stocks")
        print(f" Date range: {sync_start_date} to {end_date}")
        print(f" Mode: {mode}")
        print("=" * 60)
        return {}
    # Reset the stop flag for the new sync run.
    # NOTE(review): the flag semantics appear inverted -- set() means
    # "keep running" and a cleared flag signals workers to stop; confirm
    # against sync_single_stock before relying on this.
    self._stop_flag.set()
    # Concurrent fetch with multiple threads.
    results: Dict[str, pd.DataFrame] = {}
    error_occurred = False
    exception_to_raise = None

    def sync_task(ts_code: str) -> tuple[str, pd.DataFrame]:
        """Per-stock task: fetch one stock's data for the resolved range."""
        try:
            data = self.sync_single_stock(
                ts_code=ts_code,
                start_date=sync_start_date,
                end_date=end_date,
            )
            return (ts_code, data)
        except Exception as e:
            # Re-raise so the exception is captured by the Future.
            raise

    # Use a ThreadPoolExecutor for concurrent fetching.
    workers = max_workers or self.max_workers
    with ThreadPoolExecutor(max_workers=workers) as executor:
        # Submit all tasks and map each future back to its stock code.
        future_to_code = {
            executor.submit(sync_task, ts_code): ts_code for ts_code in stock_codes
        }
        # Process results as they complete.
        error_count = 0
        empty_count = 0
        success_count = 0
        # Progress bar over all submitted stocks.
        pbar = tqdm(total=len(stock_codes), desc="Syncing pro_bar stocks")
        try:
            # Handle futures in completion order.
            for future in as_completed(future_to_code):
                ts_code = future_to_code[future]
                try:
                    _, data = future.result()
                    if data is not None and not data.empty:
                        results[ts_code] = data
                        success_count += 1
                    else:
                        # Empty data - stock may be delisted or unavailable.
                        empty_count += 1
                        print(
                            f"[ProBarSync] Stock {ts_code}: empty data (skipped, may be delisted)"
                        )
                except Exception as e:
                    # A worker raised - stop everything and abort the sync.
                    error_occurred = True
                    exception_to_raise = e
                    print(f"\n[ERROR] Sync aborted due to exception: {e}")
                    # Shut down the executor to cancel all pending tasks.
                    executor.shutdown(wait=False, cancel_futures=True)
                    raise exception_to_raise
                # Advance the progress bar.
                pbar.update(1)
        except Exception:
            # NOTE(review): the abort exception is swallowed here -- callers
            # only see the printed message and a partial result set.
            error_count = 1
            print("[ProBarSync] Sync stopped due to exception")
        finally:
            pbar.close()
    # Batch-write all fetched data (only when no error occurred).
    if results and not error_occurred:
        for ts_code, data in results.items():
            if not data.empty:
                self.storage.queue_save("pro_bar", data)
        # Flush all queued writes in a single pass.
        self.storage.flush()
        total_rows = sum(len(df) for df in results.values())
        print(f"\n[ProBarSync] Saved {total_rows} rows to storage")
    # Summary.
    print("\n" + "=" * 60)
    print("[ProBarSync] Sync Summary")
    print("=" * 60)
    print(f" Total stocks: {len(stock_codes)}")
    print(f" Updated: {success_count}")
    print(f" Skipped (empty/delisted): {empty_count}")
    print(
        f" Errors: {error_count} (aborted on first error)"
        if error_count
        else " Errors: 0"
    )
    print(f" Date range: {sync_start_date} to {end_date}")
    print("=" * 60)
    return results
def sync_pro_bar(

File diff suppressed because it is too large Load Diff

View File

@@ -189,7 +189,10 @@ class Storage:
end_net_profit DOUBLE,
update_flag VARCHAR(1),
PRIMARY KEY (ts_code, end_date)
update_flag VARCHAR(1),
PRIMARY KEY (ts_code, end_date)
)
""")
# Create pro_bar table for pro bar data (with adj, tor, vr)
self._connection.execute("""

View File

@@ -3,7 +3,8 @@
该模块作为数据同步的调度中心,统一管理各类型数据的同步流程。
具体的同步逻辑已迁移到对应的 api_xxx.py 文件中:
- api_daily.py: 日线数据同步 (DailySync 类)
- api_bak_basic.py: 历史股票列表同步
- api_bak_basic.py: 历史股票列表同步 (BakBasicSync 类)
- api_pro_bar.py: Pro Bar 数据同步 (ProBarSync 类)
- api_stock_basic.py: 股票基本信息同步
- api_trade_cal.py: 交易日历同步
@@ -30,6 +31,7 @@ import pandas as pd
from src.data.api_wrappers import sync_all_stocks
from src.data.api_wrappers.api_daily import sync_daily, preview_daily_sync
from src.data.api_wrappers.api_pro_bar import sync_pro_bar
from src.data.api_wrappers.api_bak_basic import sync_bak_basic
def preview_sync(
@@ -135,10 +137,11 @@ def sync_all_data(
dry_run: bool = False,
) -> Dict[str, pd.DataFrame]:
"""同步所有数据类型(每日同步)。
该函数按顺序同步所有可用的数据类型:
1. 交易日历 (sync_trade_cal_cache)
2. 股票基本信息 (sync_all_stocks)
3. 日线市场数据 (sync_all)
3. Pro Bar 数据 (sync_pro_bar)
4. 历史股票列表 (sync_bak_basic)
注意:名称变更 (namechange) 不在自动同步中,如需同步请手动调用。
@@ -167,47 +170,29 @@ def sync_all_data(
print("=" * 60)
# 1. Sync trade calendar (always needed first)
print("\n[1/6] Syncing trade calendar cache...")
print("\n[1/4] Syncing trade calendar cache...")
try:
from src.data.api_wrappers import sync_trade_cal_cache
sync_trade_cal_cache()
results["trade_cal"] = pd.DataFrame()
print("[1/6] Trade calendar: OK")
print("[1/4] Trade calendar: OK")
except Exception as e:
print(f"[1/6] Trade calendar: FAILED - {e}")
print(f"[1/4] Trade calendar: FAILED - {e}")
results["trade_cal"] = pd.DataFrame()
# 2. Sync stock basic info
print("\n[2/6] Syncing stock basic info...")
print("\n[2/4] Syncing stock basic info...")
try:
sync_all_stocks()
results["stock_basic"] = pd.DataFrame()
print("[2/6] Stock basic: OK")
print("[2/4] Stock basic: OK")
except Exception as e:
print(f"[2/6] Stock basic: FAILED - {e}")
print(f"[2/4] Stock basic: FAILED - {e}")
results["stock_basic"] = pd.DataFrame()
# # 3. Sync daily market data
# print("\n[3/6] Syncing daily market data...")
# try:
# daily_result = sync_daily(
# force_full=force_full,
# max_workers=max_workers,
# dry_run=dry_run,
# )
# results["daily"] = (
# pd.concat(daily_result.values(), ignore_index=True)
# if daily_result
# else pd.DataFrame()
# )
# print("[3/6] Daily data: OK")
# except Exception as e:
# print(f"[3/6] Daily data: FAILED - {e}")
# results["daily"] = pd.DataFrame()
# 4. Sync Pro Bar data
print("\n[4/6] Syncing Pro Bar data (with adj, tor, vr)...")
# 3. Sync Pro Bar data
print("\n[3/4] Syncing Pro Bar data (with adj, tor, vr)...")
try:
pro_bar_result = sync_pro_bar(
force_full=force_full,
@@ -219,87 +204,19 @@ def sync_all_data(
if pro_bar_result
else pd.DataFrame()
)
print(f"[4/6] Pro Bar data: OK ({len(results['pro_bar'])} records)")
print(f"[3/4] Pro Bar data: OK ({len(results['pro_bar'])} records)")
except Exception as e:
print(f"[4/6] Pro Bar data: FAILED - {e}")
print(f"[3/4] Pro Bar data: FAILED - {e}")
results["pro_bar"] = pd.DataFrame()
# 5. Sync stock historical list (bak_basic)
print("\n[5/6] Syncing stock historical list (bak_basic)...")
try:
bak_basic_result = sync_bak_basic(force_full=force_full)
results["bak_basic"] = bak_basic_result
print(f"[5/6] Bak basic: OK ({len(bak_basic_result)} records)")
except Exception as e:
print(f"[5/6] Bak basic: FAILED - {e}")
results["bak_basic"] = pd.DataFrame()
# Summary
print("\n" + "=" * 60)
print("[sync_all_data] Sync Summary")
print("=" * 60)
for data_type, df in results.items():
print(f" {data_type}: {len(df)} records")
print("=" * 60)
print("\nNote: namechange is NOT in auto-sync. To sync manually:")
print(" from src.data.api_wrappers import sync_namechange")
print(" sync_namechange(force=True)")
return results
results: Dict[str, pd.DataFrame] = {}
print("\n" + "=" * 60)
print("[sync_all_data] Starting full data synchronization...")
print("=" * 60)
# 1. Sync trade calendar (always needed first)
print("\n[1/5] Syncing trade calendar cache...")
try:
from src.data.api_wrappers import sync_trade_cal_cache
sync_trade_cal_cache()
results["trade_cal"] = pd.DataFrame()
print("[1/5] Trade calendar: OK")
except Exception as e:
print(f"[1/5] Trade calendar: FAILED - {e}")
results["trade_cal"] = pd.DataFrame()
# 2. Sync stock basic info
print("\n[2/5] Syncing stock basic info...")
try:
sync_all_stocks()
results["stock_basic"] = pd.DataFrame()
print("[2/5] Stock basic: OK")
except Exception as e:
print(f"[2/5] Stock basic: FAILED - {e}")
results["stock_basic"] = pd.DataFrame()
# 3. Sync daily market data
print("\n[3/5] Syncing daily market data...")
try:
daily_result = sync_daily(
force_full=force_full,
max_workers=max_workers,
dry_run=dry_run,
)
results["daily"] = (
pd.concat(daily_result.values(), ignore_index=True)
if daily_result
else pd.DataFrame()
)
print("[3/5] Daily data: OK")
except Exception as e:
print(f"[3/5] Daily data: FAILED - {e}")
results["daily"] = pd.DataFrame()
# 4. Sync stock historical list (bak_basic)
print("\n[4/5] Syncing stock historical list (bak_basic)...")
print("\n[4/4] Syncing stock historical list (bak_basic)...")
try:
bak_basic_result = sync_bak_basic(force_full=force_full)
results["bak_basic"] = bak_basic_result
print(f"[4/5] Bak basic: OK ({len(bak_basic_result)} records)")
print(f"[4/4] Bak basic: OK ({len(bak_basic_result)} records)")
except Exception as e:
print(f"[4/5] Bak basic: FAILED - {e}")
print(f"[4/4] Bak basic: FAILED - {e}")
results["bak_basic"] = pd.DataFrame()
# Summary
@@ -316,10 +233,6 @@ def sync_all_data(
return results
# 保留向后兼容的导入
from src.data.api_wrappers import sync_bak_basic
if __name__ == "__main__":
print("=" * 60)
print("Data Sync Module")