feat: add DSL factor expression system and Pro Bar API wrapper

- Add factors/dsl.py: pure-Python DSL expression layer; factor composition via operator overloading
- Add factors/api.py: common factor symbols (close/open/high/low) and time-series functions (ts_mean/ts_std/cs_rank, etc.)
- Add factors/compiler.py: factor compiler
- Add factors/translator.py: DSL expression translator
- Add data/api_wrappers/api_pro_bar.py: Tushare Pro Bar API wrapper supporting backward-adjusted (后复权) market data
- Add data/data_router.py: data routing
- Add the corresponding test cases
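Illustrative sketch (factors/dsl.py itself is not included in this diff): "factor composition via operator overloading" usually means that arithmetic on factor symbols builds an expression tree instead of computing values immediately. All names below are hypothetical, not the actual DSL API:

    class Expr:
        """Minimal DSL node: operators build a tree; evaluation happens later."""

        def __init__(self, op, *args):
            self.op, self.args = op, args

        def __add__(self, other):
            return Expr("add", self, other)

        def __sub__(self, other):
            return Expr("sub", self, other)

        def __truediv__(self, other):
            return Expr("div", self, other)


    class Field(Expr):
        def __init__(self, name):
            super().__init__("field", name)


    close, open_ = Field("close"), Field("open")
    intraday_ret = (close - open_) / open_  # an Expr tree, compiled/executed later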
src/data/api_wrappers/api_pro_bar.py (new file, 880 lines)
@@ -0,0 +1,880 @@
"""Pro Bar (通用行情) interface.

Fetch A-share stock market data with adjustment factors from Tushare.

This interface provides backward-adjusted (后复权) daily market data
including all available fields: base price data, turnover rate (tor),
volume ratio (vr), and adjustment factors.
"""

import pandas as pd
from typing import Optional, List, Literal, Dict
from datetime import datetime, timedelta
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed
import threading

from src.data.client import TushareClient
from src.data.storage import ThreadSafeStorage, Storage
from src.data.utils import get_today_date, get_next_date, DEFAULT_START_DATE
from src.config.settings import get_settings
from src.data.api_wrappers.api_trade_cal import (
    get_first_trading_day,
    get_last_trading_day,
    sync_trade_cal_cache,
)
from src.data.api_wrappers.api_stock_basic import _get_csv_path, sync_all_stocks

def get_pro_bar(
    ts_code: str,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
    asset: Literal["E", "I", "C", "FT", "FD", "O", "CB"] = "E",
    adj: Literal[None, "qfq", "hfq"] = "hfq",
    freq: Literal["D", "W", "M"] = "D",
    ma: Optional[List[int]] = None,
    factors: Optional[List[Literal["tor", "vr"]]] = None,
    adjfactor: bool = True,
    client: Optional[TushareClient] = None,
) -> pd.DataFrame:
    """Fetch pro bar (universal market) data from Tushare.

    This interface retrieves stock market data with adjustment factors.
    By default, it fetches backward-adjusted (后复权) daily data for stocks
    with turnover rate and volume ratio factors enabled.

    Args:
        ts_code: Stock code (e.g., '000001.SZ', '600000.SH')
        start_date: Start date in YYYYMMDD format
        end_date: End date in YYYYMMDD format
        asset: Asset type - 'E' (stock), 'I' (index), 'C' (crypto),
            'FT' (futures), 'FD' (fund), 'O' (options), 'CB' (convertible bond)
        adj: Adjustment type - None (no adjustment), 'qfq' (forward),
            'hfq' (backward). Default is 'hfq' (backward-adjusted).
        freq: Data frequency - 'D' (daily), 'W' (weekly), 'M' (monthly)
        ma: List of moving average periods (e.g., [5, 10, 20])
        factors: List of factors to include - 'tor' (turnover rate), 'vr' (volume ratio).
            Default is ['tor', 'vr'] to fetch all available fields.
        adjfactor: Whether to include adjustment factor column. Default is True.
        client: Optional TushareClient instance for shared rate limiting.
            If None, creates a new client. For concurrent sync operations,
            pass a shared client to ensure proper rate limiting.

    Returns:
        pd.DataFrame with columns:
            - ts_code: Stock code
            - trade_date: Trade date (YYYYMMDD)
            - open: Opening price
            - high: Highest price
            - low: Lowest price
            - close: Closing price
            - pre_close: Previous closing price (adjusted)
            - change: Price change amount
            - pct_chg: Price change percentage
            - vol: Trading volume (lots)
            - amount: Trading amount (thousand CNY)
            - tor: Turnover rate (if factors includes 'tor')
            - vr: Volume ratio (if factors includes 'vr')
            - adj_factor: Adjustment factor (if adjfactor=True)
            - ma_X: Moving average price for period X (if ma specified)
            - ma_v_X: Moving average volume for period X (if ma specified)

    Example:
        >>> # Get backward-adjusted daily data with all factors (default)
        >>> data = get_pro_bar('000001.SZ', start_date='20240101', end_date='20240131')
        >>>
        >>> # Get unadjusted data
        >>> data = get_pro_bar('000001.SZ', start_date='20240101', adj=None)
        >>>
        >>> # Get data with moving averages
        >>> data = get_pro_bar('000001.SZ', start_date='20240101', ma=[5, 10, 20])
        >>>
        >>> # Get index data
        >>> data = get_pro_bar('000001.SH', asset='I', start_date='20240101')
    """
    client = client or TushareClient()

    # Build parameters
    params = {"ts_code": ts_code}

    if start_date:
        params["start_date"] = start_date
    if end_date:
        params["end_date"] = end_date
    if asset:
        params["asset"] = asset
    if adj:
        params["adj"] = adj
    if freq:
        params["freq"] = freq
    if ma:
        # Tushare expects ma as comma-separated string
        if isinstance(ma, list):
            ma_str = ",".join(str(m) for m in ma)
        else:
            ma_str = str(ma)
        params["ma"] = ma_str

    # Default to fetching all factors if not specified
    factors_to_use = factors if factors is not None else ["tor", "vr"]
    if factors_to_use:
        # Tushare expects factors as comma-separated string
        if isinstance(factors_to_use, list):
            factors_str = ",".join(factors_to_use)
        else:
            factors_str = factors_to_use
        params["factors"] = factors_str

    if adjfactor:
        params["adjfactor"] = "True"

    # Fetch data using pro_bar API
    data = client.query("pro_bar", **params)

    # Rename date column if needed
    if "date" in data.columns:
        data = data.rename(columns={"date": "trade_date"})

    return data

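# Editor's note (illustrative addition, not in the original commit): with
# Tushare-style backward adjustment, hfq prices are generally the raw prices
# multiplied by adj_factor, so a caller can recover approximate unadjusted
# prices from the default output. Sketch only -- verify against raw daily data:
#
#   bars = get_pro_bar("000001.SZ", start_date="20240101", end_date="20240131")
#   raw_close = bars["close"] / bars["adj_factor"]
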
# =============================================================================
# ProBarSync - batch sync class for Pro Bar data
# =============================================================================


class ProBarSync:
    """Batch sync manager for Pro Bar data, supporting full and incremental sync.

    Features:
        - Concurrent fetching across multiple threads (ThreadPoolExecutor)
        - Incremental sync (automatically detects the last synced position)
        - In-memory caching (avoids repeated disk reads)
        - Immediate stop on exception (to keep data consistent)
        - Preview mode (estimate the sync volume without writing anything)
        - Fetches all data columns by default (tor, vr, adj_factor)
    """

    # Default number of worker threads (read from settings, default 10)
    DEFAULT_MAX_WORKERS = get_settings().threads

    def __init__(self, max_workers: Optional[int] = None):
        """Initialize the sync manager.

        Args:
            max_workers: Number of worker threads. Defaults to the configured
                value (10 unless overridden in settings).
        """
        self.storage = ThreadSafeStorage()
        self.client = TushareClient()
        self.max_workers = max_workers or self.DEFAULT_MAX_WORKERS
        self._stop_flag = threading.Event()
        self._stop_flag.set()  # Initially in the "not stopped" state
        self._cached_pro_bar_data: Optional[pd.DataFrame] = None  # Data cache

    def _load_pro_bar_data(self) -> pd.DataFrame:
        """Load Pro Bar data from storage (cached).

        The data is cached in memory to avoid repeated disk reads.
        Call clear_cache() to force a reload.

        Returns:
            The Pro Bar DataFrame, from cache or loaded from storage.
        """
        if self._cached_pro_bar_data is None:
            self._cached_pro_bar_data = self.storage.load("pro_bar")
        return self._cached_pro_bar_data

    def clear_cache(self) -> None:
        """Clear the cached Pro Bar data, forcing a reload on next access."""
        self._cached_pro_bar_data = None

    def get_all_stock_codes(self, only_listed: bool = True) -> list:
        """Get all stock codes from local storage.

        Prefers stock_basic.csv to ensure all stocks are covered and to
        avoid look-ahead bias in backtests.

        Args:
            only_listed: If True, return only currently listed stocks (status 'L').
                Set to False to include delisted stocks (for complete backtests).

        Returns:
            List of stock codes.
        """
        # Make sure stock_basic.csv is up-to-date
        print("[ProBarSync] Ensuring stock_basic.csv is up-to-date...")
        sync_all_stocks()

        # Read from the stock_basic.csv file
        stock_csv_path = _get_csv_path()

        if stock_csv_path.exists():
            print(f"[ProBarSync] Reading stock_basic from CSV: {stock_csv_path}")
            try:
                stock_df = pd.read_csv(stock_csv_path, encoding="utf-8-sig")
                if not stock_df.empty and "ts_code" in stock_df.columns:
                    # Filter by list_status
                    if only_listed and "list_status" in stock_df.columns:
                        listed_stocks = stock_df[stock_df["list_status"] == "L"]
                        codes = listed_stocks["ts_code"].unique().tolist()
                        total = len(stock_df["ts_code"].unique())
                        print(
                            f"[ProBarSync] Found {len(codes)} listed stocks (filtered from {total} total)"
                        )
                    else:
                        codes = stock_df["ts_code"].unique().tolist()
                        print(
                            f"[ProBarSync] Found {len(codes)} stock codes from stock_basic.csv"
                        )
                    return codes
                else:
                    print(
                        "[ProBarSync] stock_basic.csv exists but is empty or has no ts_code column"
                    )
            except Exception as e:
                print(f"[ProBarSync] Error reading stock_basic.csv: {e}")

        # Fallback: read codes from the Pro Bar store
        print(
            "[ProBarSync] stock_basic.csv not available, falling back to pro_bar data..."
        )
        pro_bar_data = self._load_pro_bar_data()
        if not pro_bar_data.empty and "ts_code" in pro_bar_data.columns:
            codes = pro_bar_data["ts_code"].unique().tolist()
            print(f"[ProBarSync] Found {len(codes)} stock codes from pro_bar data")
            return codes

        print("[ProBarSync] No stock codes found in local storage")
        return []

    def get_global_last_date(self) -> Optional[str]:
        """Get the latest trade date across the whole store.

        Returns:
            The last trade date as a string, or None if there is no data.
        """
        pro_bar_data = self._load_pro_bar_data()
        if pro_bar_data.empty or "trade_date" not in pro_bar_data.columns:
            return None
        return str(pro_bar_data["trade_date"].max())

    def get_global_first_date(self) -> Optional[str]:
        """Get the earliest trade date across the whole store.

        Returns:
            The first trade date as a string, or None if there is no data.
        """
        pro_bar_data = self._load_pro_bar_data()
        if pro_bar_data.empty or "trade_date" not in pro_bar_data.columns:
            return None
        return str(pro_bar_data["trade_date"].min())

    def get_trade_calendar_bounds(
        self, start_date: str, end_date: str
    ) -> tuple[Optional[str], Optional[str]]:
        """Get the first and last trading days from the trade calendar.

        Args:
            start_date: Start date (YYYYMMDD)
            end_date: End date (YYYYMMDD)

        Returns:
            A (first trading day, last trading day) tuple, or (None, None) on error.
        """
        try:
            first_day = get_first_trading_day(start_date, end_date)
            last_day = get_last_trading_day(start_date, end_date)
            return (first_day, last_day)
        except Exception as e:
            print(f"[ERROR] Failed to get trade calendar bounds: {e}")
            return (None, None)

    def check_sync_needed(
        self,
        force_full: bool = False,
        table_name: str = "pro_bar",
    ) -> tuple[bool, Optional[str], Optional[str], Optional[str]]:
        """Check whether a sync is needed, based on the trade calendar.

        Compares the local data's date range against the trade calendar to
        decide whether new data must be fetched.

        Logic:
            - If force_full: sync needed, returns (True, 20180101, today)
            - If there is no local data: sync needed, returns (True, 20180101, today)
            - If local data exists:
                - Get the last trading day from the trade calendar
                - If local last date >= calendar last date: no sync needed
                - Otherwise: sync from local last date + 1 to the latest trading day

        Args:
            force_full: If True, always report that a sync is needed
            table_name: Table to check (default: "pro_bar")

        Returns:
            (sync_needed, start_date, end_date, local_last_date)
            - sync_needed: True if a sync should proceed
            - start_date: Sync start date (None if no sync is needed)
            - end_date: Sync end date (None if no sync is needed)
            - local_last_date: Last date in local data (used for incremental sync)
        """
        # Always sync when force_full is set
        if force_full:
            print("[ProBarSync] Force full sync requested")
            return (True, DEFAULT_START_DATE, get_today_date(), None)

        # Check whether local data exists for the given table
        storage = Storage()
        table_data = (
            storage.load(table_name) if storage.exists(table_name) else pd.DataFrame()
        )

        if table_data.empty or "trade_date" not in table_data.columns:
            print(
                f"[ProBarSync] No local data found for table '{table_name}', full sync needed"
            )
            return (True, DEFAULT_START_DATE, get_today_date(), None)

        # Last date present in the local data
        local_last_date = str(table_data["trade_date"].max())

        print(f"[ProBarSync] Local data last date: {local_last_date}")

        # Latest trading day from the trade calendar
        today = get_today_date()
        _, cal_last = self.get_trade_calendar_bounds(DEFAULT_START_DATE, today)

        if cal_last is None:
            print("[ProBarSync] Failed to get trade calendar, proceeding with sync")
            return (True, DEFAULT_START_DATE, today, local_last_date)

        print(f"[ProBarSync] Calendar last trading day: {cal_last}")

        # Compare the local last date with the calendar last date
        print(
            f"[ProBarSync] Comparing: local={local_last_date} (type={type(local_last_date).__name__}), "
            f"cal={cal_last} (type={type(cal_last).__name__})"
        )
        try:
            local_last_int = int(local_last_date)
            cal_last_int = int(cal_last)
            print(
                f"[ProBarSync] Comparing integers: local={local_last_int} >= cal={cal_last_int} = "
                f"{local_last_int >= cal_last_int}"
            )
            if local_last_int >= cal_last_int:
                print(
                    "[ProBarSync] Local data is up-to-date, SKIPPING sync (no tokens consumed)"
                )
                return (False, None, None, None)
        except (ValueError, TypeError) as e:
            print(f"[ERROR] Date comparison failed: {e}")

        # Sync from the local last date + 1 to the latest trading day
        sync_start = get_next_date(local_last_date)
        print(f"[ProBarSync] Incremental sync needed from {sync_start} to {cal_last}")
        return (True, sync_start, cal_last, local_last_date)

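    # Worked example (illustrative): if the local pro_bar table ends at
    # trade_date 20240130 and the calendar's last trading day is 20240131,
    # check_sync_needed() returns (True, "20240131", "20240131", "20240130"),
    # so only the missing tail is fetched; if local data already ends at
    # 20240131, the result is (False, None, None, None) and no API tokens
    # are spent.
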
    def preview_sync(
        self,
        force_full: bool = False,
        start_date: Optional[str] = None,
        end_date: Optional[str] = None,
        sample_size: int = 3,
    ) -> dict:
        """Preview the sync volume and a data sample (without actually syncing).

        Provides a preview of the data that would be synced, including:
            - Number of stocks that would be synced
            - Sync date range
            - Estimated total record count
            - Sample data from the first few stocks

        Args:
            force_full: If True, preview a full sync (from 20180101)
            start_date: Manually specified start date (overrides auto-detection)
            end_date: Manually specified end date (defaults to today)
            sample_size: Number of sample stocks to preview (default: 3)

        Returns:
            Dict with the preview information:
            {
                'sync_needed': bool,
                'stock_count': int,
                'start_date': str,
                'end_date': str,
                'estimated_records': int,
                'sample_data': pd.DataFrame,
                'mode': str,  # 'full', 'incremental', 'partial', or 'none'
            }
        """
        print("\n" + "=" * 60)
        print("[ProBarSync] Preview Mode - Analyzing sync requirements...")
        print("=" * 60)

        # First make sure the trade calendar cache is up-to-date
        print("[ProBarSync] Syncing trade calendar cache...")
        sync_trade_cal_cache()

        # Determine the date range
        if end_date is None:
            end_date = get_today_date()

        # Check whether a sync is needed
        sync_needed, cal_start, cal_end, local_last = self.check_sync_needed(force_full)

        if not sync_needed:
            print("\n" + "=" * 60)
            print("[ProBarSync] Preview Result")
            print("=" * 60)
            print("  Sync Status: NOT NEEDED")
            print("  Reason: Local data is up-to-date with trade calendar")
            print("=" * 60)
            return {
                "sync_needed": False,
                "stock_count": 0,
                "start_date": None,
                "end_date": None,
                "estimated_records": 0,
                "sample_data": pd.DataFrame(),
                "mode": "none",
            }

        # Use the dates returned by check_sync_needed
        if cal_start and cal_end:
            sync_start_date = cal_start
            end_date = cal_end
        else:
            sync_start_date = start_date or DEFAULT_START_DATE
            if end_date is None:
                end_date = get_today_date()

        # Determine the sync mode
        if force_full:
            mode = "full"
            print(f"[ProBarSync] Mode: FULL SYNC from {sync_start_date} to {end_date}")
        elif local_last and cal_start and sync_start_date == get_next_date(local_last):
            mode = "incremental"
            print("[ProBarSync] Mode: INCREMENTAL SYNC (bandwidth optimized)")
            print(f"[ProBarSync] Sync from: {sync_start_date} to {end_date}")
        else:
            mode = "partial"
            print(f"[ProBarSync] Mode: SYNC from {sync_start_date} to {end_date}")

        # Collect all stock codes
        stock_codes = self.get_all_stock_codes()
        if not stock_codes:
            print("[ProBarSync] No stocks found to sync")
            return {
                "sync_needed": False,
                "stock_count": 0,
                "start_date": None,
                "end_date": None,
                "estimated_records": 0,
                "sample_data": pd.DataFrame(),
                "mode": "none",
            }

        stock_count = len(stock_codes)
        print(f"[ProBarSync] Total stocks to sync: {stock_count}")

        # Fetch sample data from the first few stocks
        print(f"[ProBarSync] Fetching sample data from {sample_size} stocks...")
        sample_data_list = []
        sample_codes = stock_codes[:sample_size]

        for ts_code in sample_codes:
            try:
                # Fetch sample data via get_pro_bar (all fields included)
                data = get_pro_bar(
                    ts_code=ts_code,
                    start_date=sync_start_date,
                    end_date=end_date,
                )
                if not data.empty:
                    sample_data_list.append(data)
                    print(f"  - {ts_code}: {len(data)} records")
            except Exception as e:
                print(f"  - {ts_code}: Error fetching - {e}")

        # Merge the sample data
        sample_df = (
            pd.concat(sample_data_list, ignore_index=True)
            if sample_data_list
            else pd.DataFrame()
        )

        # Estimate the total record count from the sample
        if not sample_df.empty:
            avg_records_per_stock = len(sample_df) / len(sample_data_list)
            estimated_records = int(avg_records_per_stock * stock_count)
        else:
            estimated_records = 0

        # Show the preview result
        print("\n" + "=" * 60)
        print("[ProBarSync] Preview Result")
        print("=" * 60)
        print(f"  Sync Mode: {mode.upper()}")
        print(f"  Date Range: {sync_start_date} to {end_date}")
        print(f"  Stocks to Sync: {stock_count}")
        print(f"  Sample Stocks Checked: {len(sample_data_list)}/{sample_size}")
        print(f"  Estimated Total Records: ~{estimated_records:,}")

        if not sample_df.empty:
            # Show the sample data in a compact format
            preview_cols = [
                "ts_code",
                "trade_date",
                "open",
                "high",
                "low",
                "close",
                "vol",
                "tor",
                "vr",
            ]
            available_cols = [c for c in preview_cols if c in sample_df.columns]
            sample_display = sample_df[available_cols].head(10)
            print(f"\n  Sample Data Preview (first {len(sample_display)} rows):")
            print("  " + "-" * 56)
            for _, row in sample_display.iterrows():
                print(f"  {row.to_dict()}")
            print("  " + "-" * 56)

        print("=" * 60)

        return {
            "sync_needed": True,
            "stock_count": stock_count,
            "start_date": sync_start_date,
            "end_date": end_date,
            "estimated_records": estimated_records,
            "sample_data": sample_df,
            "mode": mode,
        }

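    # Estimation sketch (illustrative): with sample_size=3 and samples of
    # 19, 21, and 20 records, avg_records_per_stock = 60 / 3 = 20; for 5,000
    # stocks the preview reports estimated_records = 20 * 5000 = 100,000.
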
    def sync_single_stock(
        self,
        ts_code: str,
        start_date: str,
        end_date: str,
    ) -> pd.DataFrame:
        """Sync Pro Bar data for a single stock.

        Args:
            ts_code: Stock code
            start_date: Start date (YYYYMMDD)
            end_date: End date (YYYYMMDD)

        Returns:
            DataFrame with the Pro Bar data.
        """
        # Check whether the sync should stop (used for exception handling)
        if not self._stop_flag.is_set():
            return pd.DataFrame()

        try:
            # Fetch via get_pro_bar (all fields by default, shared client)
            data = get_pro_bar(
                ts_code=ts_code,
                start_date=start_date,
                end_date=end_date,
                client=self.client,  # Shared client for proper rate limiting
            )
            return data
        except Exception as e:
            # Set the stop flag to tell the other threads to stop
            self._stop_flag.clear()
            print(f"[ERROR] Exception syncing {ts_code}: {e}")
            raise

    def sync_all(
        self,
        force_full: bool = False,
        start_date: Optional[str] = None,
        end_date: Optional[str] = None,
        max_workers: Optional[int] = None,
        dry_run: bool = False,
    ) -> Dict[str, pd.DataFrame]:
        """Sync Pro Bar data for every stock in local storage.

        This function:
        1. Reads stock codes from local storage (pro_bar or stock_basic)
        2. Checks the trade calendar to decide whether a sync is needed:
           - If local data matches the calendar bounds, the sync is skipped
             (saves tokens)
           - Otherwise, syncs from the local last date + 1 to the latest
             trading day (bandwidth optimized)
        3. Fetches concurrently across multiple threads (rate limited)
        4. Skips stocks that return empty data (delisted/unavailable)
        5. Stops immediately on any exception

        Args:
            force_full: If True, force a full reload from 20180101
            start_date: Manually specified start date (overrides auto-detection)
            end_date: Manually specified end date (defaults to today)
            max_workers: Number of worker threads (defaults to the instance's
                configured value)
            dry_run: If True, only preview what would be synced; write nothing

        Returns:
            Dict mapping ts_code to DataFrame (empty if skipped or dry_run).
        """
        print("\n" + "=" * 60)
        print("[ProBarSync] Starting pro_bar data sync...")
        print("=" * 60)

        # First make sure the trade calendar cache is up-to-date (incremental)
        print("[ProBarSync] Syncing trade calendar cache...")
        sync_trade_cal_cache()

        # Determine the date range
        if end_date is None:
            end_date = get_today_date()

        # Check against the trade calendar whether a sync is needed
        sync_needed, cal_start, cal_end, local_last = self.check_sync_needed(force_full)

        if not sync_needed:
            # Skip the sync - no tokens consumed
            print("\n" + "=" * 60)
            print("[ProBarSync] Sync Summary")
            print("=" * 60)
            print("  Sync: SKIPPED (local data up-to-date with trade calendar)")
            print("  Tokens saved: 0 consumed")
            print("=" * 60)
            return {}

        # Use the dates from check_sync_needed (computes the incremental start)
        if cal_start and cal_end:
            sync_start_date = cal_start
            end_date = cal_end
        else:
            # Fall back to the default logic
            sync_start_date = start_date or DEFAULT_START_DATE
            if end_date is None:
                end_date = get_today_date()

        # Determine the sync mode
        if force_full:
            mode = "full"
            print(f"[ProBarSync] Mode: FULL SYNC from {sync_start_date} to {end_date}")
        elif local_last and cal_start and sync_start_date == get_next_date(local_last):
            mode = "incremental"
            print("[ProBarSync] Mode: INCREMENTAL SYNC (bandwidth optimized)")
            print(f"[ProBarSync] Sync from: {sync_start_date} to {end_date}")
        else:
            mode = "partial"
            print(f"[ProBarSync] Mode: SYNC from {sync_start_date} to {end_date}")

        # Collect all stock codes
        stock_codes = self.get_all_stock_codes()
        if not stock_codes:
            print("[ProBarSync] No stocks found to sync")
            return {}

        print(f"[ProBarSync] Total stocks to sync: {len(stock_codes)}")
        print(f"[ProBarSync] Using {max_workers or self.max_workers} worker threads")

        # Handle dry run mode
        if dry_run:
            print("\n" + "=" * 60)
            print("[ProBarSync] DRY RUN MODE - No data will be written")
            print("=" * 60)
            print(f"  Would sync {len(stock_codes)} stocks")
            print(f"  Date range: {sync_start_date} to {end_date}")
            print(f"  Mode: {mode}")
            print("=" * 60)
            return {}

        # Reset the stop flag for a fresh sync
        self._stop_flag.set()

        # Concurrent fetching across multiple threads
        results: Dict[str, pd.DataFrame] = {}
        error_occurred = False
        exception_to_raise = None

        def sync_task(ts_code: str) -> tuple[str, pd.DataFrame]:
            """Per-stock task function; exceptions propagate to the Future."""
            data = self.sync_single_stock(
                ts_code=ts_code,
                start_date=sync_start_date,
                end_date=end_date,
            )
            return (ts_code, data)

        # Concurrent fetching via ThreadPoolExecutor
        workers = max_workers or self.max_workers
        with ThreadPoolExecutor(max_workers=workers) as executor:
            # Submit all tasks, tracking the future -> stock code mapping
            future_to_code = {
                executor.submit(sync_task, ts_code): ts_code for ts_code in stock_codes
            }

            # Process results with as_completed
            error_count = 0
            empty_count = 0
            success_count = 0

            # Create the progress bar
            pbar = tqdm(total=len(stock_codes), desc="Syncing pro_bar stocks")

            try:
                # Handle futures as they complete
                for future in as_completed(future_to_code):
                    ts_code = future_to_code[future]

                    try:
                        _, data = future.result()
                        if data is not None and not data.empty:
                            results[ts_code] = data
                            success_count += 1
                        else:
                            # Empty data - the stock may be delisted or unavailable
                            empty_count += 1
                            print(
                                f"[ProBarSync] Stock {ts_code}: empty data (skipped, may be delisted)"
                            )
                    except Exception as e:
                        # An exception occurred - stop everything and abort
                        error_occurred = True
                        exception_to_raise = e
                        print(f"\n[ERROR] Sync aborted due to exception: {e}")
                        # Shut down the executor to stop all pending tasks
                        executor.shutdown(wait=False, cancel_futures=True)
                        raise exception_to_raise

                    # Update the progress bar
                    pbar.update(1)

            except Exception:
                error_count = 1
                print("[ProBarSync] Sync stopped due to exception")
            finally:
                pbar.close()

        # Batch-write all data (only when no error occurred)
        if results and not error_occurred:
            for ts_code, data in results.items():
                if not data.empty:
                    self.storage.queue_save("pro_bar", data)
            # Flush all queued writes at once
            self.storage.flush()
            total_rows = sum(len(df) for df in results.values())
            print(f"\n[ProBarSync] Saved {total_rows} rows to storage")

        # Summary
        print("\n" + "=" * 60)
        print("[ProBarSync] Sync Summary")
        print("=" * 60)
        print(f"  Total stocks: {len(stock_codes)}")
        print(f"  Updated: {success_count}")
        print(f"  Skipped (empty/delisted): {empty_count}")
        print(
            f"  Errors: {error_count} (aborted on first error)"
            if error_count
            else "  Errors: 0"
        )
        print(f"  Date range: {sync_start_date} to {end_date}")
        print("=" * 60)

        return results


def sync_pro_bar(
    force_full: bool = False,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
    max_workers: Optional[int] = None,
    dry_run: bool = False,
) -> Dict[str, pd.DataFrame]:
    """Sync Pro Bar data for all stocks.

    This is the main entry point for Pro Bar data sync.

    Args:
        force_full: If True, force a full reload from 20180101
        start_date: Manually specified start date (YYYYMMDD)
        end_date: Manually specified end date (defaults to today)
        max_workers: Number of worker threads (default: 10)
        dry_run: If True, only preview what would be synced; write nothing

    Returns:
        Dict mapping ts_code to DataFrame

    Example:
        >>> # First sync (full load from 20180101)
        >>> result = sync_pro_bar()
        >>>
        >>> # Subsequent sync (incremental - new data only)
        >>> result = sync_pro_bar()
        >>>
        >>> # Force a full reload
        >>> result = sync_pro_bar(force_full=True)
        >>>
        >>> # Manually specified date range
        >>> result = sync_pro_bar(start_date='20240101', end_date='20240131')
        >>>
        >>> # Custom thread count
        >>> result = sync_pro_bar(max_workers=20)
        >>>
        >>> # Dry run (preview only)
        >>> result = sync_pro_bar(dry_run=True)
    """
    sync_manager = ProBarSync(max_workers=max_workers)
    return sync_manager.sync_all(
        force_full=force_full,
        start_date=start_date,
        end_date=end_date,
        dry_run=dry_run,
    )


def preview_pro_bar_sync(
    force_full: bool = False,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
    sample_size: int = 3,
) -> dict:
    """Preview the Pro Bar sync volume and a data sample (without actually syncing).

    This is the recommended way to inspect what would be synced before
    actually running the sync.

    Args:
        force_full: If True, preview a full sync (from 20180101)
        start_date: Manually specified start date (overrides auto-detection)
        end_date: Manually specified end date (defaults to today)
        sample_size: Number of sample stocks to preview (default: 3)

    Returns:
        Dict with the preview information:
        {
            'sync_needed': bool,
            'stock_count': int,
            'start_date': str,
            'end_date': str,
            'estimated_records': int,
            'sample_data': pd.DataFrame,
            'mode': str,  # 'full', 'incremental', 'partial', or 'none'
        }

    Example:
        >>> # Preview what would be synced
        >>> preview = preview_pro_bar_sync()
        >>>
        >>> # Preview a full sync
        >>> preview = preview_pro_bar_sync(force_full=True)
        >>>
        >>> # Preview with more samples
        >>> preview = preview_pro_bar_sync(sample_size=5)
    """
    sync_manager = ProBarSync()
    return sync_manager.preview_sync(
        force_full=force_full,
        start_date=start_date,
        end_date=end_date,
        sample_size=sample_size,
    )

src/data/data_router.py (new file, 663 lines)
@@ -0,0 +1,663 @@
"""Data catalog and dynamic SQL routing module.

Generates SQL dynamically and pulls data, addressing the pain points of
querying across a multi-table schema. Supports two table types: DAILY
(daily frequency, exact alignment) and PIT (low-frequency financial data,
aligned by announcement date).

Core features:
    - Automatically discovers table schemas in a DuckDB database
    - Supports overriding the auto-discovered metadata via configuration
    - Detects PIT-style tables (via ann_date/f_ann_date fields)
"""

from typing import Dict, List, Set, Optional
from dataclasses import dataclass, field
from enum import Enum
import polars as pl
import duckdb
from pathlib import Path


class TableFrequency(Enum):
    """Table frequency type."""

    DAILY = "daily"  # Daily data, exact alignment
    PIT = "pit"  # Low-frequency data, aligned by announcement date (Point-In-Time)


@dataclass
class TableMetadata:
    """Table metadata configuration.

    Attributes:
        name: Table name
        frequency: Table frequency type (DAILY or PIT)
        date_field: Date field name (trade_date for DAILY tables, ann_date for PIT tables)
        code_field: Asset code field name (usually ts_code)
        fields: List of all fields in the table
        description: Table description
    """

    name: str
    frequency: TableFrequency
    date_field: str
    code_field: str = "ts_code"
    fields: List[str] = field(default_factory=list)
    description: str = ""


@dataclass
class FieldMapping:
    """Field mapping configuration.

    Attributes:
        field_name: Field name
        table_name: Name of the table that owns the field
        description: Field description
    """

    field_name: str
    table_name: str
    description: str = ""


class DatabaseCatalog:
    """Database catalog that manages field-to-table mappings.

    Core responsibilities:
    1. Automatically discover table schemas from a DuckDB database
    2. Maintain the field-to-table mapping
    3. Manage table metadata (frequency type, date field, etc.)
    4. Provide field resolution and table routing

    Automatic table-type detection rules:
    - If a table has an ann_date or f_ann_date field, it is classified as PIT
    - Otherwise, if it has a trade_date field, it is classified as DAILY

    Attributes:
        tables: Table metadata dict, table name -> TableMetadata
        field_mappings: Field mapping dict, field name -> FieldMapping
        db_path: Database file path

    Example:
        >>> # Table schemas are discovered automatically on construction
        >>> catalog = DatabaseCatalog("data/prostock.db")
        >>> table = catalog.get_table_for_field("close")
        >>> print(table)  # "daily"
    """

    # Fields that identify PIT-style tables (in priority order)
    PIT_DATE_FIELDS = ["ann_date", "f_ann_date", "publish_date"]
    # Fields that identify DAILY-style tables
    DAILY_DATE_FIELDS = ["trade_date", "cal_date", "date"]

    def __init__(self, db_path: Optional[str] = None):
        """Initialize the database catalog.

        Args:
            db_path: Database file path; if None, the default configuration is used.
        """
        self.tables: Dict[str, TableMetadata] = {}
        self.field_mappings: Dict[str, FieldMapping] = {}
        self.db_path = db_path
        self._table_frequency_overrides: Dict[str, TableFrequency] = {}

        if db_path:
            self.discover_tables(db_path)

    def set_table_frequency_override(
        self, table_name: str, frequency: TableFrequency
    ) -> None:
        """Set a frequency-type override for a table.

        Manually pins a table's frequency type, overriding auto-detection.

        Args:
            table_name: Table name
            frequency: Frequency type (DAILY or PIT)
        """
        self._table_frequency_overrides[table_name] = frequency

    def discover_tables(self, db_path: str) -> None:
        """Automatically discover every table schema in the database.

        Reads table and column information from information_schema and
        automatically determines:
        - Table names and field lists
        - The asset code field (ts_code)
        - The date field (the table type is inferred from field names)
        - The table frequency type (DAILY or PIT)

        Args:
            db_path: DuckDB database file path
        """
        db_file = db_path.replace("duckdb://", "").lstrip("/")

        if not Path(db_file).exists():
            print(f"[DatabaseCatalog] Database file does not exist: {db_file}")
            return

        conn = duckdb.connect(db_file, read_only=True)
        try:
            # Fetch all tables
            tables_query = """
                SELECT table_name
                FROM information_schema.tables
                WHERE table_schema = 'main'
                ORDER BY table_name
            """
            tables_result = conn.execute(tables_query).fetchall()

            for (table_name,) in tables_result:
                # Fetch the table's column information
                columns_query = """
                    SELECT column_name, data_type
                    FROM information_schema.columns
                    WHERE table_name = ? AND table_schema = 'main'
                    ORDER BY ordinal_position
                """
                columns_result = conn.execute(columns_query, [table_name]).fetchall()

                fields = [col[0] for col in columns_result]

                # Detect the table type and date field automatically
                frequency, date_field = self._detect_table_type(fields, table_name)

                # Check for an asset code field
                code_field = "ts_code" if "ts_code" in fields else None

                if code_field and date_field:
                    # Create the table metadata
                    metadata = TableMetadata(
                        name=table_name,
                        frequency=frequency,
                        date_field=date_field,
                        code_field=code_field,
                        fields=fields,
                        description=f"Auto-discovered table: {table_name}",
                    )
                    self.register_table(metadata)
                    print(
                        f"[DatabaseCatalog] Discovered table: {table_name} ({frequency.value}, "
                        f"date field: {date_field})"
                    )

        finally:
            conn.close()

    def _detect_table_type(
        self, fields: List[str], table_name: str
    ) -> tuple[TableFrequency, Optional[str]]:
        """Automatically detect a table's frequency type and date field.

        Detection rules (in priority order):
        1. Check for a manual override
        2. Check for PIT marker fields (ann_date, f_ann_date, ...)
        3. Check for DAILY marker fields (trade_date, cal_date, ...)

        Args:
            fields: The table's field list
            table_name: Table name

        Returns:
            (frequency type, date field name)
        """
        # Check for a manual override
        if table_name in self._table_frequency_overrides:
            frequency = self._table_frequency_overrides[table_name]
            if frequency == TableFrequency.PIT:
                for field in self.PIT_DATE_FIELDS:
                    if field in fields:
                        return frequency, field
            else:
                for field in self.DAILY_DATE_FIELDS:
                    if field in fields:
                        return frequency, field

        # Check for PIT marker fields
        for field in self.PIT_DATE_FIELDS:
            if field in fields:
                return TableFrequency.PIT, field

        # Check for DAILY marker fields
        for field in self.DAILY_DATE_FIELDS:
            if field in fields:
                return TableFrequency.DAILY, field

        # Default to DAILY, but with no date field
        return TableFrequency.DAILY, None

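    # Detection example (illustrative): a financial-statement table with
    # fields ["ts_code", "ann_date", "f_ann_date", "basic_eps"] is detected
    # as (TableFrequency.PIT, "ann_date"); a quote table with
    # ["ts_code", "trade_date", "close"] becomes (TableFrequency.DAILY,
    # "trade_date"); a table matching neither falls back to (DAILY, None)
    # and is skipped during discovery because it has no usable date field.
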
    def register_table(self, metadata: TableMetadata) -> None:
        """Register table metadata.

        Args:
            metadata: Table metadata configuration
        """
        self.tables[metadata.name] = metadata

        # Auto-register field mappings (if a field already exists, the first
        # table's mapping is kept)
        for field_name in metadata.fields:
            if field_name not in self.field_mappings:
                self.field_mappings[field_name] = FieldMapping(
                    field_name=field_name,
                    table_name=metadata.name,
                    description=f"{metadata.description} - {field_name}",
                )

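    # Routing example (illustrative): shared key columns such as ts_code
    # appear in every table but stay mapped to whichever table was registered
    # first, while a field like basic_eps that exists only in an income table
    # routes to that table.
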
    def get_table_for_field(self, field: str) -> Optional[str]:
        """Get the name of the table that owns a field.

        Args:
            field: Field name

        Returns:
            Table name, or None if the field is unknown.
        """
        mapping = self.field_mappings.get(field)
        return mapping.table_name if mapping else None

    def get_table_metadata(self, table_name: str) -> Optional[TableMetadata]:
        """Get a table's metadata.

        Args:
            table_name: Table name

        Returns:
            Table metadata, or None if the table is unknown.
        """
        return self.tables.get(table_name)

    def get_table_frequency(self, table_name: str) -> Optional[TableFrequency]:
        """Get a table's frequency type.

        Args:
            table_name: Table name

        Returns:
            Frequency type (DAILY or PIT), or None if the table is unknown.
        """
        metadata = self.tables.get(table_name)
        return metadata.frequency if metadata else None

    def get_required_tables(self, fields: List[str]) -> Set[str]:
        """Get the set of tables covering the required fields.

        Args:
            fields: Field list

        Returns:
            Set of table names involved.
        """
        tables = set()
        for field in fields:
            table = self.get_table_for_field(field)
            if table:
                tables.add(table)
        return tables

    def get_fields_for_table(
        self, table_name: str, required_fields: List[str]
    ) -> List[str]:
        """Get the fields to query from a given table (including key fields).

        Args:
            table_name: Table name
            required_fields: All fields requested by the user

        Returns:
            The fields to query from this table (key fields included).
        """
        metadata = self.tables.get(table_name)
        if not metadata:
            return []

        # Base key fields
        fields = [metadata.code_field, metadata.date_field]

        # Add the requested fields that belong to this table
        for field in required_fields:
            if self.get_table_for_field(field) == table_name and field not in fields:
                fields.append(field)

        return fields

    def is_pit_table(self, table_name: str) -> bool:
        """Check whether a table is PIT-style.

        Args:
            table_name: Table name

        Returns:
            True if the table is a PIT-style table.
        """
        frequency = self.get_table_frequency(table_name)
        return frequency == TableFrequency.PIT


class SQLQueryBuilder:
    """SQL query builder.

    Builds optimized SQL queries according to the table type (DAILY/PIT).
    """

    def __init__(self, catalog: DatabaseCatalog):
        """Initialize the SQL builder.

        Args:
            catalog: Database catalog instance
        """
        self.catalog = catalog

    def build_query(
        self,
        table_name: str,
        fields: List[str],
        start_date: str,
        end_date: str,
        lookback_days: int = 90,
    ) -> str:
        """Build an optimized SQL query.

        For PIT-style tables, the start date is automatically pushed back by
        lookback_days so that the start of the range can still match the most
        recent earlier announcement.

        Args:
            table_name: Table name
            fields: Fields to query
            start_date: Start date (YYYYMMDD)
            end_date: End date (YYYYMMDD)
            lookback_days: Lookback days for PIT tables (default: 90)

        Returns:
            The generated SQL query string.
        """
        metadata = self.catalog.get_table_metadata(table_name)
        if not metadata:
            raise ValueError(f"Unknown table: {table_name}")

        # Build the field list
        fields_str = ", ".join(fields)

        # Build the WHERE clause according to the table type
        if metadata.frequency == TableFrequency.PIT:
            # PIT table: query by announcement date, with lookback
            date_field = metadata.date_field
            query_start = self._adjust_start_date(start_date, lookback_days)
            query_start_fmt = self._format_date(query_start)
            end_date_fmt = self._format_date(end_date)

            sql = f"""
                SELECT {fields_str}
                FROM {table_name}
                WHERE {date_field} >= '{query_start_fmt}'
                  AND {date_field} <= '{end_date_fmt}'
                ORDER BY {metadata.code_field}, {date_field}
            """
        else:
            # DAILY table: query directly by trade date
            date_field = metadata.date_field
            start_date_fmt = self._format_date(start_date)
            end_date_fmt = self._format_date(end_date)
            sql = f"""
                SELECT {fields_str}
                FROM {table_name}
                WHERE {date_field} >= '{start_date_fmt}'
                  AND {date_field} <= '{end_date_fmt}'
                ORDER BY {metadata.code_field}, {date_field}
            """

        return sql.strip()

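    # Example (illustrative): for a PIT table "income" with date_field
    # "ann_date", build_query(table_name="income", fields=["ts_code",
    # "ann_date", "basic_eps"], start_date="20240101", end_date="20240131")
    # produces roughly:
    #
    #   SELECT ts_code, ann_date, basic_eps
    #   FROM income
    #   WHERE ann_date >= '2023-10-03'   -- 20240101 minus the 90-day lookback
    #     AND ann_date <= '2024-01-31'
    #   ORDER BY ts_code, ann_date
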
    def _format_date(self, date_str: str) -> str:
        """Convert a YYYYMMDD string to YYYY-MM-DD.

        Args:
            date_str: Date string (YYYYMMDD)

        Returns:
            Formatted date string (YYYY-MM-DD).
        """
        return f"{date_str[:4]}-{date_str[4:6]}-{date_str[6:8]}"

    def _adjust_start_date(self, start_date: str, days: int) -> str:
        """Shift a start date backwards by the given number of days.

        Args:
            start_date: Start date (YYYYMMDD)
            days: Number of days to look back

        Returns:
            The adjusted date (YYYYMMDD).
        """
        from datetime import datetime, timedelta

        dt = datetime.strptime(start_date, "%Y%m%d")
        adjusted_dt = dt - timedelta(days=days)
        return adjusted_dt.strftime("%Y%m%d")


def query_duckdb_to_polars(query: str, db_path: str) -> pl.LazyFrame:
    """Run a DuckDB query and return a Polars LazyFrame.

    Uses duckdb.connect().sql(query).pl() for a fast data hand-off.

    Args:
        query: SQL query string
        db_path: DuckDB database file path

    Returns:
        Polars LazyFrame
    """
    conn = duckdb.connect(db_path)
    try:
        # Fast DuckDB -> Polars conversion
        df = conn.sql(query).pl()
        return df.lazy()
    finally:
        conn.close()


def build_context_lazyframe(
    required_fields: List[str],
    start_date: str,
    end_date: str,
    db_uri: str,
    catalog: Optional[DatabaseCatalog] = None,
    lookback_days: int = 90,
) -> pl.LazyFrame:
    """Build the context LazyFrame: generate SQL per required field set and merge the data.

    Core logic:
    1. Resolve required_fields back to the tables that own them
    2. Generate a minimal SQL query per table
    3. Load the data from DuckDB into Polars LazyFrames
    4. Merge the per-table data:
       - DAILY tables are left-joined on ["trade_date", "ts_code"]
       - PIT tables are aligned by announcement date via join_asof
    5. Finally sort by ["ts_code", "trade_date"]

    Args:
        required_fields: Fields to load
        start_date: Start date (YYYYMMDD)
        end_date: End date (YYYYMMDD)
        db_uri: Database connection URI (e.g. "duckdb:///data/prostock.db")
        catalog: Database catalog instance; if None, one is created and the
            tables are discovered automatically
        lookback_days: Lookback days for PIT tables (default: 90)

    Returns:
        The merged LazyFrame containing all requested fields.

    Example:
        >>> lf = build_context_lazyframe(
        ...     required_fields=["close", "vol", "basic_eps"],
        ...     start_date="20240101",
        ...     end_date="20240131",
        ...     db_uri="duckdb:///data/prostock.db"
        ... )
        >>> df = lf.collect()
    """
    # Resolve the database path
    db_path = db_uri.replace("duckdb://", "").lstrip("/")

    # If no catalog is given, create one and discover the tables
    if catalog is None:
        catalog = DatabaseCatalog(db_path)

    # Tables involved
    tables = catalog.get_required_tables(required_fields)

    if not tables:
        # No tables involved: return an empty frame
        return pl.LazyFrame({"ts_code": [], "trade_date": []})

    # Split DAILY tables from PIT tables
    daily_tables: List[str] = []
    pit_tables: List[str] = []

    for table_name in tables:
        if catalog.is_pit_table(table_name):
            pit_tables.append(table_name)
        else:
            daily_tables.append(table_name)

    # Build the SQL query builder
    query_builder = SQLQueryBuilder(catalog)

    # Load the DAILY tables
    daily_lfs: Dict[str, pl.LazyFrame] = {}
    for table_name in daily_tables:
        fields = catalog.get_fields_for_table(table_name, required_fields)
        sql = query_builder.build_query(
            table_name=table_name,
            fields=fields,
            start_date=start_date,
            end_date=end_date,
        )
        print(f"[SQL] {sql[:100]}...")

        lf = query_duckdb_to_polars(sql, db_path)

        # Normalize column names: rename the table's date_field to trade_date
        metadata = catalog.get_table_metadata(table_name)
        if metadata and metadata.date_field != "trade_date":
            lf = lf.rename({metadata.date_field: "trade_date"})

        daily_lfs[table_name] = lf

    # Load the PIT tables
    pit_lfs: Dict[str, pl.LazyFrame] = {}
    for table_name in pit_tables:
        fields = catalog.get_fields_for_table(table_name, required_fields)
        sql = query_builder.build_query(
            table_name=table_name,
            fields=fields,
            start_date=start_date,
            end_date=end_date,
            lookback_days=lookback_days,
        )
        print(f"[SQL] {sql[:100]}...")

        lf = query_duckdb_to_polars(sql, db_path)

        # PIT tables keep their original announcement-date field (for join_asof)
        pit_lfs[table_name] = lf

    # Merge all DAILY tables (the first daily table is the base)
    result_lf: Optional[pl.LazyFrame] = None

    if daily_lfs:
        # Use the first daily table as the base
        first_table = daily_tables[0]
        result_lf = daily_lfs[first_table]

        # Join the remaining daily tables
        for table_name in daily_tables[1:]:
            lf = daily_lfs[table_name]
            result_lf = result_lf.join(lf, on=["trade_date", "ts_code"], how="left")
    elif pit_lfs:
        # No daily tables: derive the base timeline from a PIT table,
        # using the first PIT table's date range
        first_pit = pit_tables[0]
        pit_metadata = catalog.get_table_metadata(first_pit)

        # Extract all (date, stock code) combinations from the PIT table
        result_lf = (
            pit_lfs[first_pit]
            .select([pl.col(pit_metadata.date_field).alias("trade_date"), "ts_code"])
            .unique()
        )

    # Still no result: return an empty frame
    if result_lf is None:
        return pl.LazyFrame({"ts_code": [], "trade_date": []})

    # Merge the PIT tables (join_asof aligned by announcement date)
    for table_name in pit_tables:
        pit_metadata = catalog.get_table_metadata(table_name)
        lf = pit_lfs[table_name]

        # join_asof: grouped by ts_code, align PIT rows to trading days.
        # The backward strategy uses the latest announcement whose date is
        # on or before the current trading day.
        result_lf = result_lf.join_asof(
            lf,
            left_on="trade_date",
            right_on=pit_metadata.date_field,
            by="ts_code",
            strategy="backward",
        )

    # Final sort by ["ts_code", "trade_date"], as time-series ops require
    result_lf = result_lf.sort(["ts_code", "trade_date"])

    return result_lf


if __name__ == "__main__":
    # Test code
    print("=" * 60)
    print("DatabaseCatalog auto-discovery test")
    print("=" * 60)

    # Test auto-discovery
    catalog = DatabaseCatalog("data/prostock.db")

    print("\n=== Field-to-table mapping ===")
    print(f"Table for field 'close': {catalog.get_table_for_field('close')}")
    print(f"Table for field 'vol': {catalog.get_table_for_field('vol')}")
    print(f"Table for field 'pe': {catalog.get_table_for_field('pe')}")
    print(f"Table for field 'basic_eps': {catalog.get_table_for_field('basic_eps')}")

    print("\n=== Table frequency types ===")
    for table_name in catalog.tables:
        freq = catalog.get_table_frequency(table_name)
        print(f"Frequency of table '{table_name}': {freq.value if freq else 'Unknown'}")

    print("\n=== SQL building ===")
    query_builder = SQLQueryBuilder(catalog)

    daily_sql = query_builder.build_query(
        table_name="daily",
        fields=["ts_code", "trade_date", "close", "vol"],
        start_date="20240101",
        end_date="20240131",
    )
    print(f"\nDAILY table SQL:\n{daily_sql}")

    pit_sql = query_builder.build_query(
        table_name="financial_income",
        fields=["ts_code", "ann_date", "basic_eps", "total_revenue"],
        start_date="20240101",
        end_date="20240131",
        lookback_days=90,
    )
    print(f"\nPIT table SQL:\n{pit_sql}")

    print("\n=== Multi-table field collection ===")
    required_fields = ["close", "vol", "pe", "basic_eps", "total_revenue"]
    tables = catalog.get_required_tables(required_fields)
    print(f"Tables involved for fields {required_fields}: {tables}")

    for table_name in tables:
        fields = catalog.get_fields_for_table(table_name, required_fields)
        print(f"  Fields to query from table '{table_name}': {fields}")

    print("\nAll tests passed!")
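    # Appended illustration (not in the original commit): how the PIT
    # alignment above behaves on synthetic data. join_asof with
    # strategy="backward" attaches, per ts_code, the latest announcement
    # dated on or before each trading day.
    print("\n=== join_asof alignment demo (synthetic data) ===")
    trades = pl.DataFrame(
        {
            "ts_code": ["000001.SZ"] * 3,
            "trade_date": ["20240110", "20240115", "20240120"],
        }
    ).with_columns(pl.col("trade_date").str.strptime(pl.Date, "%Y%m%d"))
    reports = pl.DataFrame(
        {
            "ts_code": ["000001.SZ"] * 2,
            "ann_date": ["20240105", "20240116"],
            "basic_eps": [0.50, 0.55],
        }
    ).with_columns(pl.col("ann_date").str.strptime(pl.Date, "%Y%m%d"))
    aligned = trades.join_asof(
        reports,
        left_on="trade_date",
        right_on="ann_date",
        by="ts_code",
        strategy="backward",
    )
    # 2024-01-10 and 2024-01-15 pick up basic_eps=0.50; 2024-01-20 gets 0.55
    print(aligned)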