feat: add DSL factor expression system and Pro Bar API wrapper

- Add factors/dsl.py: pure-Python DSL expression layer that composes factors via operator overloading (see the sketch below)
- Add factors/api.py: common factor symbols (close/open/high/low) and time-series/cross-sectional functions (ts_mean/ts_std/cs_rank, etc.)
- Add factors/compiler.py: factor compiler
- Add factors/translator.py: DSL expression translator
- Add data/api_wrappers/api_pro_bar.py: Tushare Pro Bar API wrapper with backward-adjusted (hfq) market data support
- Add data/data_router.py: data routing module
- Add related test cases
commit 0698b9d919 (parent a56433e440), 2026-02-27 22:43:45 +08:00
9 changed files with 4012 additions and 0 deletions
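
As context for the factors/dsl.py entry above, here is a minimal sketch of how operator-overloading factor composition can work. It is an illustration only; the actual dsl.py API added by this commit is not shown in this diff, and every name below is hypothetical.

class Expr:
    """Illustrative expression node (hypothetical, not the dsl.py implementation)."""
    def __add__(self, other): return BinOp("+", self, other)
    def __sub__(self, other): return BinOp("-", self, other)
    def __mul__(self, other): return BinOp("*", self, other)
    def __truediv__(self, other): return BinOp("/", self, other)

class Field(Expr):
    def __init__(self, name): self.name = name

class BinOp(Expr):
    def __init__(self, op, left, right): self.op, self.left, self.right = op, left, right

# Hypothetical usage: composing symbols builds an expression tree that a
# compiler/translator can later lower to Polars expressions or SQL.
close, open_ = Field("close"), Field("open")
intraday_return = (close - open_) / open_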

880
src/data/api_wrappers/api_pro_bar.py Normal file
@@ -0,0 +1,880 @@
"""Pro Bar (通用行情) interface.
Fetch A-share stock market data with adjustment factors from Tushare.
This interface provides backward-adjusted (后复权) daily market data
including all available fields: base price data, turnover rate (tor),
volume ratio (vr), and adjustment factors.
"""
import pandas as pd
from typing import Optional, List, Literal, Dict
from datetime import datetime, timedelta
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed
import threading
from src.data.client import TushareClient
from src.data.storage import ThreadSafeStorage, Storage
from src.data.utils import get_today_date, get_next_date, DEFAULT_START_DATE
from src.config.settings import get_settings
from src.data.api_wrappers.api_trade_cal import (
get_first_trading_day,
get_last_trading_day,
sync_trade_cal_cache,
)
from src.data.api_wrappers.api_stock_basic import _get_csv_path, sync_all_stocks
def get_pro_bar(
ts_code: str,
start_date: Optional[str] = None,
end_date: Optional[str] = None,
asset: Literal["E", "I", "C", "FT", "FD", "O", "CB"] = "E",
adj: Literal[None, "qfq", "hfq"] = "hfq",
freq: Literal["D", "W", "M"] = "D",
ma: Optional[List[int]] = None,
factors: Optional[List[Literal["tor", "vr"]]] = None,
adjfactor: bool = True,
client: Optional[TushareClient] = None,
) -> pd.DataFrame:
"""Fetch pro bar (universal market) data from Tushare.
This interface retrieves stock market data with adjustment factors.
By default, it fetches backward-adjusted (后复权) daily data for stocks
with turnover rate and volume ratio factors enabled.
Args:
ts_code: Stock code (e.g., '000001.SZ', '600000.SH')
start_date: Start date in YYYYMMDD format
end_date: End date in YYYYMMDD format
asset: Asset type - 'E' (stock), 'I' (index), 'C' (crypto),
'FT' (futures), 'FD' (fund), 'O' (options), 'CB' (convertible bond)
adj: Adjustment type - None (no adjustment), 'qfq' (forward),
'hfq' (backward). Default is 'hfq' (backward-adjusted).
freq: Data frequency - 'D' (daily), 'W' (weekly), 'M' (monthly)
ma: List of moving average periods (e.g., [5, 10, 20])
factors: List of factors to include - 'tor' (turnover rate), 'vr' (volume ratio).
Default is ['tor', 'vr'] to fetch all available fields.
adjfactor: Whether to include adjustment factor column. Default is True.
client: Optional TushareClient instance for shared rate limiting.
If None, creates a new client. For concurrent sync operations,
pass a shared client to ensure proper rate limiting.
Returns:
pd.DataFrame with columns:
- ts_code: Stock code
- trade_date: Trade date (YYYYMMDD)
- open: Opening price
- high: Highest price
- low: Lowest price
- close: Closing price
- pre_close: Previous closing price (adjusted)
- change: Price change amount
- pct_chg: Price change percentage
- vol: Trading volume (lots)
- amount: Trading amount (thousand CNY)
- tor: Turnover rate (if factors includes 'tor')
- vr: Volume ratio (if factors includes 'vr')
- adj_factor: Adjustment factor (if adjfactor=True)
- ma_X: Moving average price for period X (if ma specified)
- ma_v_X: Moving average volume for period X (if ma specified)
Example:
>>> # Get backward-adjusted daily data with all factors (default)
>>> data = get_pro_bar('000001.SZ', start_date='20240101', end_date='20240131')
>>>
>>> # Get unadjusted data
>>> data = get_pro_bar('000001.SZ', start_date='20240101', adj=None)
>>>
>>> # Get data with moving averages
>>> data = get_pro_bar('000001.SZ', start_date='20240101', ma=[5, 10, 20])
>>>
>>> # Get index data
>>> data = get_pro_bar('000001.SH', asset='I', start_date='20240101')
"""
client = client or TushareClient()
# Build parameters
params = {"ts_code": ts_code}
if start_date:
params["start_date"] = start_date
if end_date:
params["end_date"] = end_date
if asset:
params["asset"] = asset
if adj:
params["adj"] = adj
if freq:
params["freq"] = freq
if ma:
# Tushare expects ma as comma-separated string
if isinstance(ma, list):
ma_str = ",".join(str(m) for m in ma)
else:
ma_str = str(ma)
params["ma"] = ma_str
# Default to fetching all factors if not specified
factors_to_use = factors if factors is not None else ["tor", "vr"]
if factors_to_use:
# Tushare expects factors as comma-separated string
if isinstance(factors_to_use, list):
factors_str = ",".join(factors_to_use)
else:
factors_str = factors_to_use
params["factors"] = factors_str
if adjfactor:
params["adjfactor"] = "True"
# Fetch data using pro_bar API
data = client.query("pro_bar", **params)
# Rename date column if needed
if "date" in data.columns:
data = data.rename(columns={"date": "trade_date"})
return data
# =============================================================================
# ProBarSync - batch sync manager for Pro Bar data
# =============================================================================
class ProBarSync:
"""Pro Bar 数据批量同步管理器,支持全量/增量同步。
功能特性:
- 多线程并发获取ThreadPoolExecutor
- 增量同步(自动检测上次同步位置)
- 内存缓存(避免重复磁盘读取)
- 异常立即停止(确保数据一致性)
- 预览模式(预览同步数据量,不实际写入)
- 默认获取全部数据列tor, vr, adj_factor
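Example (illustrative; the worker count below is arbitrary):
>>> sync = ProBarSync(max_workers=8)
>>> info = sync.preview_sync()   # estimate the volume, nothing is written
>>> results = sync.sync_all()    # incremental sync of all stocks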
"""
# Default number of worker threads (read from settings; the configured default is 10)
DEFAULT_MAX_WORKERS = get_settings().threads
def __init__(self, max_workers: Optional[int] = None):
"""初始化同步管理器。
max_workers: 工作线程数(默认从配置读取,若未指定则使用配置值)
max_workers: 工作线程数(默认: 10
"""
self.storage = ThreadSafeStorage()
self.client = TushareClient()
self.max_workers = max_workers or self.DEFAULT_MAX_WORKERS
self._stop_flag = threading.Event()
self._stop_flag.set() # 初始为未停止状态
self._cached_pro_bar_data: Optional[pd.DataFrame] = None # 数据缓存
def _load_pro_bar_data(self) -> pd.DataFrame:
"""从存储加载 Pro Bar 数据(带缓存)。
该方法会将数据缓存在内存中以避免重复磁盘读取。
调用 clear_cache() 可强制重新加载。
Returns:
缓存或从存储加载的 Pro Bar 数据 DataFrame
"""
if self._cached_pro_bar_data is None:
self._cached_pro_bar_data = self.storage.load("pro_bar")
return self._cached_pro_bar_data
def clear_cache(self) -> None:
"""清除缓存的 Pro Bar 数据,强制下次访问时重新加载。"""
self._cached_pro_bar_data = None
def get_all_stock_codes(self, only_listed: bool = True) -> list:
"""从本地存储获取所有股票代码。
优先使用 stock_basic.csv 以确保包含所有股票,
避免回测中的前视偏差。
Args:
only_listed: 若为 True仅返回当前上市股票L 状态)。
设为 False 可包含退市股票(用于完整回测)。
Returns:
股票代码列表
"""
# 确保 stock_basic.csv 是最新的
print("[ProBarSync] Ensuring stock_basic.csv is up-to-date...")
sync_all_stocks()
# 从 stock_basic.csv 文件获取
stock_csv_path = _get_csv_path()
if stock_csv_path.exists():
print(f"[ProBarSync] Reading stock_basic from CSV: {stock_csv_path}")
try:
stock_df = pd.read_csv(stock_csv_path, encoding="utf-8-sig")
if not stock_df.empty and "ts_code" in stock_df.columns:
# 根据 list_status 过滤
if only_listed and "list_status" in stock_df.columns:
listed_stocks = stock_df[stock_df["list_status"] == "L"]
codes = listed_stocks["ts_code"].unique().tolist()
total = len(stock_df["ts_code"].unique())
print(
f"[ProBarSync] Found {len(codes)} listed stocks (filtered from {total} total)"
)
else:
codes = stock_df["ts_code"].unique().tolist()
print(
f"[ProBarSync] Found {len(codes)} stock codes from stock_basic.csv"
)
return codes
else:
print(
f"[ProBarSync] stock_basic.csv exists but no ts_code column or empty"
)
except Exception as e:
print(f"[ProBarSync] Error reading stock_basic.csv: {e}")
# 回退:从 Pro Bar 存储获取
print(
"[ProBarSync] stock_basic.csv not available, falling back to pro_bar data..."
)
pro_bar_data = self._load_pro_bar_data()
if not pro_bar_data.empty and "ts_code" in pro_bar_data.columns:
codes = pro_bar_data["ts_code"].unique().tolist()
print(f"[ProBarSync] Found {len(codes)} stock codes from pro_bar data")
return codes
print("[ProBarSync] No stock codes found in local storage")
return []
def get_global_last_date(self) -> Optional[str]:
"""获取全局最后交易日期。
Returns:
最后交易日期字符串,若无数据则返回 None
"""
pro_bar_data = self._load_pro_bar_data()
if pro_bar_data.empty or "trade_date" not in pro_bar_data.columns:
return None
return str(pro_bar_data["trade_date"].max())
def get_global_first_date(self) -> Optional[str]:
"""获取全局最早交易日期。
Returns:
最早交易日期字符串,若无数据则返回 None
"""
pro_bar_data = self._load_pro_bar_data()
if pro_bar_data.empty or "trade_date" not in pro_bar_data.columns:
return None
return str(pro_bar_data["trade_date"].min())
def get_trade_calendar_bounds(
self, start_date: str, end_date: str
) -> tuple[Optional[str], Optional[str]]:
"""从交易日历获取首尾交易日。
Args:
start_date: 开始日期YYYYMMDD 格式)
end_date: 结束日期YYYYMMDD 格式)
Returns:
(首交易日, 尾交易日) 元组,若出错则返回 (None, None)
"""
try:
first_day = get_first_trading_day(start_date, end_date)
last_day = get_last_trading_day(start_date, end_date)
return (first_day, last_day)
except Exception as e:
print(f"[ERROR] Failed to get trade calendar bounds: {e}")
return (None, None)
def check_sync_needed(
self,
force_full: bool = False,
table_name: str = "pro_bar",
) -> tuple[bool, Optional[str], Optional[str], Optional[str]]:
"""基于交易日历检查是否需要同步。
该方法比较本地数据日期范围与交易日历,
以确定是否需要获取新数据。
逻辑:
- 若 force_full需要同步返回 (True, 20180101, today)
- 若无本地数据:需要同步,返回 (True, 20180101, today)
- 若存在本地数据:
- 从交易日历获取最后交易日
- 若本地最后日期 >= 日历最后日期:无需同步
- 否则:从本地最后日期+1 到最新交易日同步
Args:
force_full: 若为 True始终返回需要同步
table_name: 要检查的表名(默认: "pro_bar"
Returns:
(需要同步, 起始日期, 结束日期, 本地最后日期)
- 需要同步: True 表示应继续同步
- 起始日期: 同步起始日期(无需同步时为 None
- 结束日期: 同步结束日期(无需同步时为 None
- 本地最后日期: 本地数据最后日期(用于增量同步)
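Example return values (illustrative dates; 20180101 is DEFAULT_START_DATE):
- No local data: (True, "20180101", "20260227", None)
- Local data already up to date: (False, None, None, None)
- Local data two days behind: (True, "20260226", "20260227", "20260225")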
"""
# 若 force_full始终同步
if force_full:
print("[ProBarSync] Force full sync requested")
return (True, DEFAULT_START_DATE, get_today_date(), None)
# 检查特定表的本地数据是否存在
storage = Storage()
table_data = (
storage.load(table_name) if storage.exists(table_name) else pd.DataFrame()
)
if table_data.empty or "trade_date" not in table_data.columns:
print(
f"[ProBarSync] No local data found for table '{table_name}', full sync needed"
)
return (True, DEFAULT_START_DATE, get_today_date(), None)
# 获取本地数据最后日期
local_last_date = str(table_data["trade_date"].max())
print(f"[ProBarSync] Local data last date: {local_last_date}")
# 从交易日历获取最新交易日
today = get_today_date()
_, cal_last = self.get_trade_calendar_bounds(DEFAULT_START_DATE, today)
if cal_last is None:
print("[ProBarSync] Failed to get trade calendar, proceeding with sync")
return (True, DEFAULT_START_DATE, today, local_last_date)
print(f"[ProBarSync] Calendar last trading day: {cal_last}")
# 比较本地最后日期与日历最后日期
print(
f"[ProBarSync] Comparing: local={local_last_date} (type={type(local_last_date).__name__}), "
f"cal={cal_last} (type={type(cal_last).__name__})"
)
try:
local_last_int = int(local_last_date)
cal_last_int = int(cal_last)
print(
f"[ProBarSync] Comparing integers: local={local_last_int} >= cal={cal_last_int} = "
f"{local_last_int >= cal_last_int}"
)
if local_last_int >= cal_last_int:
print(
"[ProBarSync] Local data is up-to-date, SKIPPING sync (no tokens consumed)"
)
return (False, None, None, None)
except (ValueError, TypeError) as e:
print(f"[ERROR] Date comparison failed: {e}")
# 需要从本地最后日期+1 同步到最新交易日
sync_start = get_next_date(local_last_date)
print(f"[ProBarSync] Incremental sync needed from {sync_start} to {cal_last}")
return (True, sync_start, cal_last, local_last_date)
def preview_sync(
self,
force_full: bool = False,
start_date: Optional[str] = None,
end_date: Optional[str] = None,
sample_size: int = 3,
) -> dict:
"""预览同步数据量和样本(不实际同步)。
该方法提供即将同步的数据的预览,包括:
- 将同步的股票数量
- 同步日期范围
- 预估总记录数
- 前几只股票的样本数据
Args:
force_full: 若为 True预览全量同步从 20180101
start_date: 手动指定起始日期(覆盖自动检测)
end_date: 手动指定结束日期(默认为今天)
sample_size: 预览用样本股票数量(默认: 3
Returns:
包含预览信息的字典:
{
'sync_needed': bool,
'stock_count': int,
'start_date': str,
'end_date': str,
'estimated_records': int,
'sample_data': pd.DataFrame,
'mode': str, # 'full''incremental'
}
"""
print("\n" + "=" * 60)
print("[ProBarSync] Preview Mode - Analyzing sync requirements...")
print("=" * 60)
# 首先确保交易日历缓存是最新的
print("[ProBarSync] Syncing trade calendar cache...")
sync_trade_cal_cache()
# 确定日期范围
if end_date is None:
end_date = get_today_date()
# 检查是否需要同步
sync_needed, cal_start, cal_end, local_last = self.check_sync_needed(force_full)
if not sync_needed:
print("\n" + "=" * 60)
print("[ProBarSync] Preview Result")
print("=" * 60)
print(" Sync Status: NOT NEEDED")
print(" Reason: Local data is up-to-date with trade calendar")
print("=" * 60)
return {
"sync_needed": False,
"stock_count": 0,
"start_date": None,
"end_date": None,
"estimated_records": 0,
"sample_data": pd.DataFrame(),
"mode": "none",
}
# 使用 check_sync_needed 返回的日期
if cal_start and cal_end:
sync_start_date = cal_start
end_date = cal_end
else:
sync_start_date = start_date or DEFAULT_START_DATE
if end_date is None:
end_date = get_today_date()
# 确定同步模式
if force_full:
mode = "full"
print(f"[ProBarSync] Mode: FULL SYNC from {sync_start_date} to {end_date}")
elif local_last and cal_start and sync_start_date == get_next_date(local_last):
mode = "incremental"
print(f"[ProBarSync] Mode: INCREMENTAL SYNC (bandwidth optimized)")
print(f"[ProBarSync] Sync from: {sync_start_date} to {end_date}")
else:
mode = "partial"
print(f"[ProBarSync] Mode: SYNC from {sync_start_date} to {end_date}")
# 获取所有股票代码
stock_codes = self.get_all_stock_codes()
if not stock_codes:
print("[ProBarSync] No stocks found to sync")
return {
"sync_needed": False,
"stock_count": 0,
"start_date": None,
"end_date": None,
"estimated_records": 0,
"sample_data": pd.DataFrame(),
"mode": "none",
}
stock_count = len(stock_codes)
print(f"[ProBarSync] Total stocks to sync: {stock_count}")
# 从前几只股票获取样本数据
print(f"[ProBarSync] Fetching sample data from {sample_size} stocks...")
sample_data_list = []
sample_codes = stock_codes[:sample_size]
for ts_code in sample_codes:
try:
# 使用 get_pro_bar 获取样本数据(包含所有字段)
data = get_pro_bar(
ts_code=ts_code,
start_date=sync_start_date,
end_date=end_date,
)
if not data.empty:
sample_data_list.append(data)
print(f" - {ts_code}: {len(data)} records")
except Exception as e:
print(f" - {ts_code}: Error fetching - {e}")
# 合并样本数据
sample_df = (
pd.concat(sample_data_list, ignore_index=True)
if sample_data_list
else pd.DataFrame()
)
# 基于样本估算总记录数
if not sample_df.empty:
avg_records_per_stock = len(sample_df) / len(sample_data_list)
estimated_records = int(avg_records_per_stock * stock_count)
else:
estimated_records = 0
# 显示预览结果
print("\n" + "=" * 60)
print("[ProBarSync] Preview Result")
print("=" * 60)
print(f" Sync Mode: {mode.upper()}")
print(f" Date Range: {sync_start_date} to {end_date}")
print(f" Stocks to Sync: {stock_count}")
print(f" Sample Stocks Checked: {len(sample_data_list)}/{sample_size}")
print(f" Estimated Total Records: ~{estimated_records:,}")
if not sample_df.empty:
print(f"\n Sample Data Preview (first {len(sample_df)} rows):")
print(" " + "-" * 56)
# 以紧凑格式显示样本数据
preview_cols = [
"ts_code",
"trade_date",
"open",
"high",
"low",
"close",
"vol",
"tor",
"vr",
]
available_cols = [c for c in preview_cols if c in sample_df.columns]
sample_display = sample_df[available_cols].head(10)
for idx, row in sample_display.iterrows():
print(f" {row.to_dict()}")
print(" " + "-" * 56)
print("=" * 60)
return {
"sync_needed": True,
"stock_count": stock_count,
"start_date": sync_start_date,
"end_date": end_date,
"estimated_records": estimated_records,
"sample_data": sample_df,
"mode": mode,
}
def sync_single_stock(
self,
ts_code: str,
start_date: str,
end_date: str,
) -> pd.DataFrame:
"""同步单只股票的 Pro Bar 数据。
Args:
ts_code: 股票代码
start_date: 起始日期YYYYMMDD
end_date: 结束日期YYYYMMDD
Returns:
包含 Pro Bar 数据的 DataFrame
"""
# 检查是否应该停止同步(用于异常处理)
if not self._stop_flag.is_set():
return pd.DataFrame()
try:
# 使用 get_pro_bar 获取数据(默认包含所有字段,传递共享 client
data = get_pro_bar(
ts_code=ts_code,
start_date=start_date,
end_date=end_date,
client=self.client, # 传递共享客户端以确保限流
)
return data
except Exception as e:
# 设置停止标志以通知其他线程停止
self._stop_flag.clear()
print(f"[ERROR] Exception syncing {ts_code}: {e}")
raise
def sync_all(
self,
force_full: bool = False,
start_date: Optional[str] = None,
end_date: Optional[str] = None,
max_workers: Optional[int] = None,
dry_run: bool = False,
) -> Dict[str, pd.DataFrame]:
"""同步本地存储中所有股票的 Pro Bar 数据。
该函数:
1. 从本地存储读取股票代码pro_bar 或 stock_basic
2. 检查交易日历确定是否需要同步:
- 若本地数据匹配交易日历边界,则跳过同步(节省 token
- 否则,从本地最后日期+1 同步到最新交易日(带宽优化)
3. 使用多线程并发获取(带速率限制)
4. 跳过返回空数据的股票(退市/不可用)
5. 遇异常立即停止
Args:
force_full: 若为 True强制从 20180101 完整重载
start_date: 手动指定起始日期(覆盖自动检测)
end_date: 手动指定结束日期(默认为今天)
max_workers: 工作线程数(默认: 10
dry_run: 若为 True仅预览将要同步的内容不写入数据
Returns:
映射 ts_code 到 DataFrame 的字典(若跳过或 dry_run 则为空字典)
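Example (illustrative; flags and worker count are arbitrary):
>>> sync = ProBarSync(max_workers=8)
>>> sync.sync_all(dry_run=True)   # preview only, nothing is written
>>> results = sync.sync_all()     # incremental sync; {ts_code: DataFrame of new rows}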
"""
print("\n" + "=" * 60)
print("[ProBarSync] Starting pro_bar data sync...")
print("=" * 60)
# 首先确保交易日历缓存是最新的(使用增量同步)
print("[ProBarSync] Syncing trade calendar cache...")
sync_trade_cal_cache()
# 确定日期范围
if end_date is None:
end_date = get_today_date()
# 基于交易日历检查是否需要同步
sync_needed, cal_start, cal_end, local_last = self.check_sync_needed(force_full)
if not sync_needed:
# 跳过同步 - 不消耗 token
print("\n" + "=" * 60)
print("[ProBarSync] Sync Summary")
print("=" * 60)
print(" Sync: SKIPPED (local data up-to-date with trade calendar)")
print(" Tokens saved: 0 consumed")
print("=" * 60)
return {}
# 使用 check_sync_needed 返回的日期(会计算增量起始日期)
if cal_start and cal_end:
sync_start_date = cal_start
end_date = cal_end
else:
# 回退到默认逻辑
sync_start_date = start_date or DEFAULT_START_DATE
if end_date is None:
end_date = get_today_date()
# 确定同步模式
if force_full:
mode = "full"
print(f"[ProBarSync] Mode: FULL SYNC from {sync_start_date} to {end_date}")
elif local_last and cal_start and sync_start_date == get_next_date(local_last):
mode = "incremental"
print(f"[ProBarSync] Mode: INCREMENTAL SYNC (bandwidth optimized)")
print(f"[ProBarSync] Sync from: {sync_start_date} to {end_date}")
else:
mode = "partial"
print(f"[ProBarSync] Mode: SYNC from {sync_start_date} to {end_date}")
# 获取所有股票代码
stock_codes = self.get_all_stock_codes()
if not stock_codes:
print("[ProBarSync] No stocks found to sync")
return {}
print(f"[ProBarSync] Total stocks to sync: {len(stock_codes)}")
print(f"[ProBarSync] Using {max_workers or self.max_workers} worker threads")
# 处理 dry run 模式
if dry_run:
print("\n" + "=" * 60)
print("[ProBarSync] DRY RUN MODE - No data will be written")
print("=" * 60)
print(f" Would sync {len(stock_codes)} stocks")
print(f" Date range: {sync_start_date} to {end_date}")
print(f" Mode: {mode}")
print("=" * 60)
return {}
# 为新同步重置停止标志
self._stop_flag.set()
# 多线程并发获取
results: Dict[str, pd.DataFrame] = {}
error_occurred = False
exception_to_raise = None
def sync_task(ts_code: str) -> tuple[str, pd.DataFrame]:
"""每只股票的任务函数。"""
try:
data = self.sync_single_stock(
ts_code=ts_code,
start_date=sync_start_date,
end_date=end_date,
)
return (ts_code, data)
except Exception as e:
# 重新抛出以被 Future 捕获
raise
# 使用 ThreadPoolExecutor 进行并发获取
workers = max_workers or self.max_workers
with ThreadPoolExecutor(max_workers=workers) as executor:
# 提交所有任务并跟踪 futures 与股票代码的映射
future_to_code = {
executor.submit(sync_task, ts_code): ts_code for ts_code in stock_codes
}
# 使用 as_completed 处理结果
error_count = 0
empty_count = 0
success_count = 0
# 创建进度条
pbar = tqdm(total=len(stock_codes), desc="Syncing pro_bar stocks")
try:
# 处理完成的 futures
for future in as_completed(future_to_code):
ts_code = future_to_code[future]
try:
_, data = future.result()
if data is not None and not data.empty:
results[ts_code] = data
success_count += 1
else:
# 空数据 - 股票可能已退市或不可用
empty_count += 1
print(
f"[ProBarSync] Stock {ts_code}: empty data (skipped, may be delisted)"
)
except Exception as e:
# 发生异常 - 停止全部并中止
error_occurred = True
exception_to_raise = e
print(f"\n[ERROR] Sync aborted due to exception: {e}")
# 关闭 executor 以停止所有待处理任务
executor.shutdown(wait=False, cancel_futures=True)
raise exception_to_raise
# 更新进度条
pbar.update(1)
except Exception:
error_count = 1
print("[ProBarSync] Sync stopped due to exception")
finally:
pbar.close()
# 批量写入所有数据(仅在无错误时)
if results and not error_occurred:
for ts_code, data in results.items():
if not data.empty:
self.storage.queue_save("pro_bar", data)
# 一次性刷新所有排队写入
self.storage.flush()
total_rows = sum(len(df) for df in results.values())
print(f"\n[ProBarSync] Saved {total_rows} rows to storage")
# 摘要
print("\n" + "=" * 60)
print("[ProBarSync] Sync Summary")
print("=" * 60)
print(f" Total stocks: {len(stock_codes)}")
print(f" Updated: {success_count}")
print(f" Skipped (empty/delisted): {empty_count}")
print(
f" Errors: {error_count} (aborted on first error)"
if error_count
else " Errors: 0"
)
print(f" Date range: {sync_start_date} to {end_date}")
print("=" * 60)
return results
def sync_pro_bar(
force_full: bool = False,
start_date: Optional[str] = None,
end_date: Optional[str] = None,
max_workers: Optional[int] = None,
dry_run: bool = False,
) -> Dict[str, pd.DataFrame]:
"""同步所有股票的 Pro Bar 数据。
这是 Pro Bar 数据同步的主要入口点。
Args:
force_full: 若为 True强制从 20180101 完整重载
start_date: 手动指定起始日期YYYYMMDD
end_date: 手动指定结束日期(默认为今天)
max_workers: 工作线程数(默认: 10
dry_run: 若为 True仅预览将要同步的内容不写入数据
Returns:
映射 ts_code 到 DataFrame 的字典
Example:
>>> # 首次同步(从 20180101 全量加载)
>>> result = sync_pro_bar()
>>>
>>> # 后续同步(增量 - 仅新数据)
>>> result = sync_pro_bar()
>>>
>>> # 强制完整重载
>>> result = sync_pro_bar(force_full=True)
>>>
>>> # 手动指定日期范围
>>> result = sync_pro_bar(start_date='20240101', end_date='20240131')
>>>
>>> # 自定义线程数
>>> result = sync_pro_bar(max_workers=20)
>>>
>>> # Dry run仅预览
>>> result = sync_pro_bar(dry_run=True)
"""
sync_manager = ProBarSync(max_workers=max_workers)
return sync_manager.sync_all(
force_full=force_full,
start_date=start_date,
end_date=end_date,
dry_run=dry_run,
)
def preview_pro_bar_sync(
force_full: bool = False,
start_date: Optional[str] = None,
end_date: Optional[str] = None,
sample_size: int = 3,
) -> dict:
"""预览 Pro Bar 同步数据量和样本(不实际同步)。
这是推荐的方式,可在实际同步前检查将要同步的内容。
Args:
force_full: 若为 True预览全量同步从 20180101
start_date: 手动指定起始日期(覆盖自动检测)
end_date: 手动指定结束日期(默认为今天)
sample_size: 预览用样本股票数量(默认: 3
Returns:
包含预览信息的字典:
{
'sync_needed': bool,
'stock_count': int,
'start_date': str,
'end_date': str,
'estimated_records': int,
'sample_data': pd.DataFrame,
'mode': str, # 'full', 'incremental', 'partial', 或 'none'
}
Example:
>>> # 预览将要同步的内容
>>> preview = preview_pro_bar_sync()
>>>
>>> # 预览全量同步
>>> preview = preview_pro_bar_sync(force_full=True)
>>>
>>> # 预览更多样本
>>> preview = preview_pro_bar_sync(sample_size=5)
"""
sync_manager = ProBarSync()
return sync_manager.preview_sync(
force_full=force_full,
start_date=start_date,
end_date=end_date,
sample_size=sample_size,
)

663
src/data/data_router.py Normal file

@@ -0,0 +1,663 @@
"""数据目录与动态 SQL 路由模块。
用于动态 SQL 生成和数据拉取,解决多表架构下的数据查询痛点。
支持 DAILY日频精确对齐和 PIT低频财务数据按披露日对齐两种表类型。
核心特性:
- 自动发现 DuckDB 数据库中的表结构
- 支持通过配置覆盖自动发现的元数据
- 智能识别 PIT 类型表(通过 ann_date/f_ann_date 字段)
"""
from typing import Dict, List, Set, Optional, Literal
from dataclasses import dataclass, field
from enum import Enum
import polars as pl
import duckdb
from pathlib import Path
class TableFrequency(Enum):
"""表频度类型。"""
DAILY = "daily" # 日频数据,精确对齐
PIT = "pit" # 低频数据,按披露日对齐 (Point-In-Time)
@dataclass
class TableMetadata:
"""表元数据配置。
Attributes:
name: 表名
frequency: 表频度类型DAILY 或 PIT
date_field: 日期字段名DAILY 表为 trade_datePIT 表为 ann_date
code_field: 资产代码字段名(通常为 ts_code
fields: 表中所有字段列表
description: 表描述
"""
name: str
frequency: TableFrequency
date_field: str
code_field: str = "ts_code"
fields: List[str] = field(default_factory=list)
description: str = ""
@dataclass
class FieldMapping:
"""字段映射配置。
Attributes:
field_name: 字段名
table_name: 所属表名
description: 字段描述
"""
field_name: str
table_name: str
description: str = ""
class DatabaseCatalog:
"""数据库目录类,管理字段到表的映射关系。
核心职责:
1. 自动从 DuckDB 数据库中发现表结构
2. 维护字段到表的映射关系
3. 管理表的元数据(频度类型、日期字段等)
4. 提供字段解析和表路由功能
表类型自动识别规则:
- 如果表包含 ann_date 或 f_ann_date 字段,识别为 PIT 类型
- 否则,如果包含 trade_date 字段,识别为 DAILY 类型
Attributes:
tables: 表元数据字典,表名 -> TableMetadata
field_mappings: 字段映射字典,字段名 -> FieldMapping
db_path: 数据库文件路径
Example:
>>> catalog = DatabaseCatalog("data/prostock.db")
>>> # 自动发现所有表结构
>>> catalog.discover_tables()
>>> table = catalog.get_table_for_field("close")
>>> print(table) # "daily"
"""
# PIT 类型表的标识字段(优先级顺序)
PIT_DATE_FIELDS = ["ann_date", "f_ann_date", "publish_date"]
# DAILY 类型表的标识字段
DAILY_DATE_FIELDS = ["trade_date", "cal_date", "date"]
def __init__(self, db_path: Optional[str] = None):
"""初始化数据库目录。
Args:
db_path: 数据库文件路径,如果为 None 则使用默认配置
"""
self.tables: Dict[str, TableMetadata] = {}
self.field_mappings: Dict[str, FieldMapping] = {}
self.db_path = db_path
self._table_frequency_overrides: Dict[str, TableFrequency] = {}
if db_path:
self.discover_tables(db_path)
def set_table_frequency_override(
self, table_name: str, frequency: TableFrequency
) -> None:
"""设置表频度类型覆盖。
用于手动指定表的频度类型,覆盖自动识别的结果。
Args:
table_name: 表名
frequency: 频度类型DAILY 或 PIT
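Example (illustrative; "weekly_report" is a hypothetical table name):
>>> catalog = DatabaseCatalog()  # no db_path yet, so nothing is discovered
>>> catalog.set_table_frequency_override("weekly_report", TableFrequency.PIT)
>>> catalog.discover_tables("data/prostock.db")  # the override is applied during discovery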
"""
self._table_frequency_overrides[table_name] = frequency
def discover_tables(self, db_path: str) -> None:
"""自动发现数据库中的所有表结构。
从 information_schema 中读取表和列信息,自动识别:
- 表名和字段列表
- 资产代码字段ts_code
- 日期字段(根据字段名智能识别表类型)
- 表频度类型DAILY 或 PIT
Args:
db_path: DuckDB 数据库文件路径
"""
db_file = db_path.replace("duckdb://", "").lstrip("/")
if not Path(db_file).exists():
print(f"[DatabaseCatalog] 数据库文件不存在: {db_file}")
return
conn = duckdb.connect(db_file, read_only=True)
try:
# 获取所有表
tables_query = """
SELECT table_name
FROM information_schema.tables
WHERE table_schema = 'main'
ORDER BY table_name
"""
tables_result = conn.execute(tables_query).fetchall()
for (table_name,) in tables_result:
# 获取表的列信息
columns_query = """
SELECT column_name, data_type
FROM information_schema.columns
WHERE table_name = ? AND table_schema = 'main'
ORDER BY ordinal_position
"""
columns_result = conn.execute(columns_query, [table_name]).fetchall()
fields = [col[0] for col in columns_result]
# 自动识别表类型和日期字段
frequency, date_field = self._detect_table_type(fields, table_name)
# 检查是否有资产代码字段
code_field = "ts_code" if "ts_code" in fields else None
if code_field and date_field:
# 创建表元数据
metadata = TableMetadata(
name=table_name,
frequency=frequency,
date_field=date_field,
code_field=code_field,
fields=fields,
description=f"自动发现的表: {table_name}",
)
self.register_table(metadata)
print(
f"[DatabaseCatalog] 发现表: {table_name} ({frequency.value}, "
f"日期字段: {date_field})"
)
finally:
conn.close()
def _detect_table_type(
self, fields: List[str], table_name: str
) -> tuple[TableFrequency, Optional[str]]:
"""自动检测表的频度类型和日期字段。
检测规则(按优先级):
1. 检查是否有手动覆盖配置
2. 检查是否包含 PIT 标识字段ann_date, f_ann_date 等)
3. 检查是否包含 DAILY 标识字段trade_date, cal_date 等)
Args:
fields: 表的字段列表
table_name: 表名
Returns:
(频度类型, 日期字段名)
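Example (illustrative field lists; assumes no override is registered for these table names):
fields=["ts_code", "ann_date", "basic_eps"] -> (TableFrequency.PIT, "ann_date")
fields=["ts_code", "trade_date", "close"] -> (TableFrequency.DAILY, "trade_date")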
"""
# 检查手动覆盖配置
if table_name in self._table_frequency_overrides:
frequency = self._table_frequency_overrides[table_name]
if frequency == TableFrequency.PIT:
for field in self.PIT_DATE_FIELDS:
if field in fields:
return frequency, field
else:
for field in self.DAILY_DATE_FIELDS:
if field in fields:
return frequency, field
# 检查 PIT 标识字段
for field in self.PIT_DATE_FIELDS:
if field in fields:
return TableFrequency.PIT, field
# 检查 DAILY 标识字段
for field in self.DAILY_DATE_FIELDS:
if field in fields:
return TableFrequency.DAILY, field
# 默认返回 DAILY但无日期字段
return TableFrequency.DAILY, None
def register_table(self, metadata: TableMetadata) -> None:
"""注册表元数据。
Args:
metadata: 表元数据配置
"""
self.tables[metadata.name] = metadata
# 自动注册字段映射(如果字段已存在,保留第一个表的映射)
for field_name in metadata.fields:
if field_name not in self.field_mappings:
self.field_mappings[field_name] = FieldMapping(
field_name=field_name,
table_name=metadata.name,
description=f"{metadata.description} - {field_name}",
)
def get_table_for_field(self, field: str) -> Optional[str]:
"""获取字段对应的表名。
Args:
field: 字段名
Returns:
表名,如果字段不存在则返回 None
"""
mapping = self.field_mappings.get(field)
return mapping.table_name if mapping else None
def get_table_metadata(self, table_name: str) -> Optional[TableMetadata]:
"""获取表的元数据。
Args:
table_name: 表名
Returns:
表元数据,如果不存在则返回 None
"""
return self.tables.get(table_name)
def get_table_frequency(self, table_name: str) -> Optional[TableFrequency]:
"""获取表的频度类型。
Args:
table_name: 表名
Returns:
表频度类型DAILY 或 PIT如果不存在则返回 None
"""
metadata = self.tables.get(table_name)
return metadata.frequency if metadata else None
def get_required_tables(self, fields: List[str]) -> Set[str]:
"""获取所需字段涉及的所有表名。
Args:
fields: 字段列表
Returns:
涉及的表名集合
"""
tables = set()
for field in fields:
table = self.get_table_for_field(field)
if table:
tables.add(table)
return tables
def get_fields_for_table(
self, table_name: str, required_fields: List[str]
) -> List[str]:
"""获取指定表需要的字段列表(包含必要的键字段)。
Args:
table_name: 表名
required_fields: 用户请求的所有字段
Returns:
该表需要查询的字段列表(包含键字段)
"""
metadata = self.tables.get(table_name)
if not metadata:
return []
# 基础键字段
fields = [metadata.code_field, metadata.date_field]
# 添加用户请求的字段(属于该表的)
for field in required_fields:
if self.get_table_for_field(field) == table_name and field not in fields:
fields.append(field)
return fields
def is_pit_table(self, table_name: str) -> bool:
"""判断表是否为 PIT 类型。
Args:
table_name: 表名
Returns:
是否为 PIT 类型表
"""
frequency = self.get_table_frequency(table_name)
return frequency == TableFrequency.PIT
class SQLQueryBuilder:
"""SQL 查询构建器。
根据表类型DAILY/PIT构建优化的 SQL 查询。
"""
def __init__(self, catalog: DatabaseCatalog):
"""初始化 SQL 构建器。
Args:
catalog: 数据库目录实例
"""
self.catalog = catalog
def build_query(
self,
table_name: str,
fields: List[str],
start_date: str,
end_date: str,
lookback_days: int = 90,
) -> str:
"""构建优化的 SQL 查询。
对于 PIT 类型表,会自动向前回溯 lookback_days 天,
以确保起始日期能匹配到最近的旧数据。
Args:
table_name: 表名
fields: 需要查询的字段列表
start_date: 开始日期YYYYMMDD 格式)
end_date: 结束日期YYYYMMDD 格式)
lookback_days: PIT 表回溯天数默认90天
Returns:
构建好的 SQL 查询语句
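Example (illustrative; the table and field names are hypothetical):
build_query("income", ["ts_code", "ann_date", "basic_eps"],
start_date="20240101", end_date="20240131", lookback_days=90) produces:
SELECT ts_code, ann_date, basic_eps
FROM income
WHERE ann_date >= '2023-10-03'
AND ann_date <= '2024-01-31'
ORDER BY ts_code, ann_date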
"""
metadata = self.catalog.get_table_metadata(table_name)
if not metadata:
raise ValueError(f"未知的表: {table_name}")
# 构建字段列表
fields_str = ", ".join(fields)
# 根据表类型构建 WHERE 条件
if metadata.frequency == TableFrequency.PIT:
# PIT 表:按公告日期查询,需要向前回溯
date_field = metadata.date_field
query_start = self._adjust_start_date(start_date, lookback_days)
query_start_fmt = self._format_date(query_start)
end_date_fmt = self._format_date(end_date)
sql = f"""
SELECT {fields_str}
FROM {table_name}
WHERE {date_field} >= '{query_start_fmt}'
AND {date_field} <= '{end_date_fmt}'
ORDER BY {metadata.code_field}, {date_field}
"""
else:
# DAILY 表:直接按交易日期查询
date_field = metadata.date_field
start_date_fmt = self._format_date(start_date)
end_date_fmt = self._format_date(end_date)
sql = f"""
SELECT {fields_str}
FROM {table_name}
WHERE {date_field} >= '{start_date_fmt}'
AND {date_field} <= '{end_date_fmt}'
ORDER BY {metadata.code_field}, {date_field}
"""
return sql.strip()
def _format_date(self, date_str: str) -> str:
"""将 YYYYMMDD 格式转换为 YYYY-MM-DD 格式。
Args:
date_str: 日期字符串YYYYMMDD 格式)
Returns:
格式化后的日期字符串YYYY-MM-DD 格式)
"""
return f"{date_str[:4]}-{date_str[4:6]}-{date_str[6:8]}"
def _adjust_start_date(self, start_date: str, days: int) -> str:
"""调整开始日期(向前回溯指定天数)。
Args:
start_date: 开始日期YYYYMMDD 格式)
days: 回溯天数
Returns:
调整后的日期YYYYMMDD 格式)
"""
from datetime import datetime, timedelta
dt = datetime.strptime(start_date, "%Y%m%d")
adjusted_dt = dt - timedelta(days=days)
return adjusted_dt.strftime("%Y%m%d")
def query_duckdb_to_polars(query: str, db_path: str) -> pl.LazyFrame:
"""执行 DuckDB 查询并返回 Polars LazyFrame。
使用 duckdb.connect().sql(query).pl() 实现高速数据流转。
Args:
query: SQL 查询语句
db_path: DuckDB 数据库文件路径
Returns:
Polars LazyFrame
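Example (illustrative; assumes a "daily" table with these columns):
>>> lf = query_duckdb_to_polars(
...     "SELECT ts_code, trade_date, close FROM daily LIMIT 5",
...     "data/prostock.db",
... )
>>> df = lf.collect()  # materialize the LazyFrame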
"""
conn = duckdb.connect(db_path)
try:
# DuckDB -> Polars 高速转换
df = conn.sql(query).pl()
return df.lazy()
finally:
conn.close()
def build_context_lazyframe(
required_fields: List[str],
start_date: str,
end_date: str,
db_uri: str,
catalog: Optional[DatabaseCatalog] = None,
lookback_days: int = 90,
) -> pl.LazyFrame:
"""构建上下文 LazyFrame根据所需字段动态生成 SQL 并合并数据。
核心逻辑:
1. 根据 required_fields 反查涉及的表名
2. 对每个表生成精简的 SQL 查询
3. 从 DuckDB 加载数据到 Polars LazyFrame
4. 合并不同表的数据:
- DAILY 表按 ["trade_date", "ts_code"] 进行 left_join
- PIT 表使用 join_asof 按公告日期对齐
5. 最终按 ["ts_code", "trade_date"] 排序
Args:
required_fields: 需要的字段列表
start_date: 开始日期YYYYMMDD 格式)
end_date: 结束日期YYYYMMDD 格式)
db_uri: 数据库连接 URI"duckdb:///data/prostock.db"
catalog: 数据库目录实例,如果为 None 则自动创建并发现表
lookback_days: PIT 表回溯天数默认90天
Returns:
合并后的 LazyFrame包含所有请求的字段
Example:
>>> lf = build_context_lazyframe(
... required_fields=["close", "vol", "basic_eps"],
... start_date="20240101",
... end_date="20240131",
... db_uri="duckdb:///data/prostock.db"
... )
>>> df = lf.collect()
"""
# 解析数据库路径
db_path = db_uri.replace("duckdb://", "").lstrip("/")
# 如果没有提供 catalog自动创建并发现表
if catalog is None:
catalog = DatabaseCatalog(db_path)
# 获取涉及的表
tables = catalog.get_required_tables(required_fields)
if not tables:
# 如果没有涉及的表,返回空 DataFrame
return pl.LazyFrame({"ts_code": [], "trade_date": []})
# 分离 DAILY 表和 PIT 表
daily_tables: List[str] = []
pit_tables: List[str] = []
for table_name in tables:
if catalog.is_pit_table(table_name):
pit_tables.append(table_name)
else:
daily_tables.append(table_name)
# 构建 SQL 查询器
query_builder = SQLQueryBuilder(catalog)
# 加载 DAILY 表数据
daily_lfs: Dict[str, pl.LazyFrame] = {}
for table_name in daily_tables:
fields = catalog.get_fields_for_table(table_name, required_fields)
sql = query_builder.build_query(
table_name=table_name,
fields=fields,
start_date=start_date,
end_date=end_date,
)
print(f"[SQL] {sql[:100]}...")
lf = query_duckdb_to_polars(sql, db_path)
# 统一列名:将表的 date_field 重命名为 trade_date
metadata = catalog.get_table_metadata(table_name)
if metadata and metadata.date_field != "trade_date":
lf = lf.rename({metadata.date_field: "trade_date"})
daily_lfs[table_name] = lf
# 加载 PIT 表数据
pit_lfs: Dict[str, pl.LazyFrame] = {}
for table_name in pit_tables:
fields = catalog.get_fields_for_table(table_name, required_fields)
sql = query_builder.build_query(
table_name=table_name,
fields=fields,
start_date=start_date,
end_date=end_date,
lookback_days=lookback_days,
)
print(f"[SQL] {sql[:100]}...")
lf = query_duckdb_to_polars(sql, db_path)
# PIT 表保持原始公告日期字段(用于 join_asof
pit_lfs[table_name] = lf
# 合并所有 DAILY 表(以第一个 daily 表为基准)
result_lf: Optional[pl.LazyFrame] = None
if daily_lfs:
# 使用第一个 daily 表作为基准
first_table = daily_tables[0]
result_lf = daily_lfs[first_table]
# 合并其他 daily 表
for table_name in daily_tables[1:]:
lf = daily_lfs[table_name]
result_lf = result_lf.join(lf, on=["trade_date", "ts_code"], how="left")
elif pit_lfs:
# 如果没有 daily 表,从 PIT 表创建基准时间轴
# 使用第一个 PIT 表的日期范围
first_pit = pit_tables[0]
pit_metadata = catalog.get_table_metadata(first_pit)
# 从 PIT 表提取所有日期和股票代码组合
result_lf = (
pit_lfs[first_pit]
.select([pl.col(pit_metadata.date_field).alias("trade_date"), "ts_code"])
.unique()
)
# 如果没有结果,返回空 DataFrame
if result_lf is None:
return pl.LazyFrame({"ts_code": [], "trade_date": []})
# 合并 PIT 表(使用 join_asof 按公告日期对齐)
for table_name in pit_tables:
pit_metadata = catalog.get_table_metadata(table_name)
lf = pit_lfs[table_name]
# join_asof: group by ts_code and align the PIT data to trading days.
# The "backward" strategy uses the most recent announcement with a date <= the current trade date.
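# Illustrative example (hypothetical dates): a report announced on 2024-04-25 is
# attached to every trade_date from 2024-04-25 onward for that ts_code, until a
# newer announcement appears.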
result_lf = result_lf.join_asof(
lf,
left_on="trade_date",
right_on=pit_metadata.date_field,
by="ts_code",
strategy="backward",
)
# 最终排序:按 ["ts_code", "trade_date"] 确保时序计算要求
result_lf = result_lf.sort(["ts_code", "trade_date"])
return result_lf
if __name__ == "__main__":
# 测试代码
print("=" * 60)
print("DatabaseCatalog 自动发现测试")
print("=" * 60)
# 测试自动发现
catalog = DatabaseCatalog("data/prostock.db")
print("\n=== 测试字段到表映射 ===")
print(f"字段 'close' 对应的表: {catalog.get_table_for_field('close')}")
print(f"字段 'vol' 对应的表: {catalog.get_table_for_field('vol')}")
print(f"字段 'pe' 对应的表: {catalog.get_table_for_field('pe')}")
print(f"字段 'basic_eps' 对应的表: {catalog.get_table_for_field('basic_eps')}")
print("\n=== 测试表频度类型 ===")
for table_name in catalog.tables:
freq = catalog.get_table_frequency(table_name)
print(f"'{table_name}' 的频度: {freq.value if freq else 'Unknown'}")
print("\n=== 测试 SQL 构建 ===")
query_builder = SQLQueryBuilder(catalog)
daily_sql = query_builder.build_query(
table_name="daily",
fields=["ts_code", "trade_date", "close", "vol"],
start_date="20240101",
end_date="20240131",
)
print(f"\nDAILY 表 SQL:\n{daily_sql}")
pit_sql = query_builder.build_query(
table_name="financial_income",
fields=["ts_code", "ann_date", "basic_eps", "total_revenue"],
start_date="20240101",
end_date="20240131",
lookback_days=90,
)
print(f"\nPIT 表 SQL:\n{pit_sql}")
print("\n=== 测试多表字段收集 ===")
required_fields = ["close", "vol", "pe", "basic_eps", "total_revenue"]
tables = catalog.get_required_tables(required_fields)
print(f"字段 {required_fields} 涉及的表: {tables}")
for table_name in tables:
fields = catalog.get_fields_for_table(table_name, required_fields)
print(f"'{table_name}' 需要查询的字段: {fields}")
print("\n所有测试通过!")