feat: 添加DSL因子表达式系统和Pro Bar API封装
- 新增 factors/dsl.py: 纯Python DSL表达式层,通过运算符重载实现因子组合 - 新增 factors/api.py: 提供常用因子符号(close/open/high/low)和时序函数(ts_mean/ts_std/cs_rank等) - 新增 factors/compiler.py: 因子编译器 - 新增 factors/translator.py: DSL表达式翻译器 - 新增 data/api_wrappers/api_pro_bar.py: Tushare Pro Bar API封装,支持后复权行情数据 - 新增 data/data_router.py: 数据路由功能 - 新增相关测试用例
This commit is contained in:
880
src/data/api_wrappers/api_pro_bar.py
Normal file
880
src/data/api_wrappers/api_pro_bar.py
Normal file
@@ -0,0 +1,880 @@
|
||||
"""Pro Bar (通用行情) interface.
|
||||
|
||||
Fetch A-share stock market data with adjustment factors from Tushare.
|
||||
This interface provides backward-adjusted (后复权) daily market data
|
||||
including all available fields: base price data, turnover rate (tor),
|
||||
volume ratio (vr), and adjustment factors.
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
from typing import Optional, List, Literal, Dict
|
||||
from datetime import datetime, timedelta
|
||||
from tqdm import tqdm
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
import threading
|
||||
|
||||
from src.data.client import TushareClient
|
||||
from src.data.storage import ThreadSafeStorage, Storage
|
||||
from src.data.utils import get_today_date, get_next_date, DEFAULT_START_DATE
|
||||
from src.config.settings import get_settings
|
||||
from src.data.api_wrappers.api_trade_cal import (
|
||||
get_first_trading_day,
|
||||
get_last_trading_day,
|
||||
sync_trade_cal_cache,
|
||||
)
|
||||
from src.data.api_wrappers.api_stock_basic import _get_csv_path, sync_all_stocks
|
||||
|
||||
|
||||
def get_pro_bar(
|
||||
ts_code: str,
|
||||
start_date: Optional[str] = None,
|
||||
end_date: Optional[str] = None,
|
||||
asset: Literal["E", "I", "C", "FT", "FD", "O", "CB"] = "E",
|
||||
adj: Literal[None, "qfq", "hfq"] = "hfq",
|
||||
freq: Literal["D", "W", "M"] = "D",
|
||||
ma: Optional[List[int]] = None,
|
||||
factors: Optional[List[Literal["tor", "vr"]]] = None,
|
||||
adjfactor: bool = True,
|
||||
client: Optional[TushareClient] = None,
|
||||
) -> pd.DataFrame:
|
||||
"""Fetch pro bar (universal market) data from Tushare.
|
||||
|
||||
This interface retrieves stock market data with adjustment factors.
|
||||
By default, it fetches backward-adjusted (后复权) daily data for stocks
|
||||
with turnover rate and volume ratio factors enabled.
|
||||
|
||||
Args:
|
||||
ts_code: Stock code (e.g., '000001.SZ', '600000.SH')
|
||||
start_date: Start date in YYYYMMDD format
|
||||
end_date: End date in YYYYMMDD format
|
||||
asset: Asset type - 'E' (stock), 'I' (index), 'C' (crypto),
|
||||
'FT' (futures), 'FD' (fund), 'O' (options), 'CB' (convertible bond)
|
||||
adj: Adjustment type - None (no adjustment), 'qfq' (forward),
|
||||
'hfq' (backward). Default is 'hfq' (backward-adjusted).
|
||||
freq: Data frequency - 'D' (daily), 'W' (weekly), 'M' (monthly)
|
||||
ma: List of moving average periods (e.g., [5, 10, 20])
|
||||
factors: List of factors to include - 'tor' (turnover rate), 'vr' (volume ratio).
|
||||
Default is ['tor', 'vr'] to fetch all available fields.
|
||||
adjfactor: Whether to include adjustment factor column. Default is True.
|
||||
client: Optional TushareClient instance for shared rate limiting.
|
||||
If None, creates a new client. For concurrent sync operations,
|
||||
pass a shared client to ensure proper rate limiting.
|
||||
|
||||
Returns:
|
||||
pd.DataFrame with columns:
|
||||
- ts_code: Stock code
|
||||
- trade_date: Trade date (YYYYMMDD)
|
||||
- open: Opening price
|
||||
- high: Highest price
|
||||
- low: Lowest price
|
||||
- close: Closing price
|
||||
- pre_close: Previous closing price (adjusted)
|
||||
- change: Price change amount
|
||||
- pct_chg: Price change percentage
|
||||
- vol: Trading volume (lots)
|
||||
- amount: Trading amount (thousand CNY)
|
||||
- tor: Turnover rate (if factors includes 'tor')
|
||||
- vr: Volume ratio (if factors includes 'vr')
|
||||
- adj_factor: Adjustment factor (if adjfactor=True)
|
||||
- ma_X: Moving average price for period X (if ma specified)
|
||||
- ma_v_X: Moving average volume for period X (if ma specified)
|
||||
|
||||
Example:
|
||||
>>> # Get backward-adjusted daily data with all factors (default)
|
||||
>>> data = get_pro_bar('000001.SZ', start_date='20240101', end_date='20240131')
|
||||
>>>
|
||||
>>> # Get unadjusted data
|
||||
>>> data = get_pro_bar('000001.SZ', start_date='20240101', adj=None)
|
||||
>>>
|
||||
>>> # Get data with moving averages
|
||||
>>> data = get_pro_bar('000001.SZ', start_date='20240101', ma=[5, 10, 20])
|
||||
>>>
|
||||
>>> # Get index data
|
||||
>>> data = get_pro_bar('000001.SH', asset='I', start_date='20240101')
|
||||
"""
|
||||
client = client or TushareClient()
|
||||
|
||||
# Build parameters
|
||||
params = {"ts_code": ts_code}
|
||||
|
||||
if start_date:
|
||||
params["start_date"] = start_date
|
||||
if end_date:
|
||||
params["end_date"] = end_date
|
||||
if asset:
|
||||
params["asset"] = asset
|
||||
if adj:
|
||||
params["adj"] = adj
|
||||
if freq:
|
||||
params["freq"] = freq
|
||||
if ma:
|
||||
# Tushare expects ma as comma-separated string
|
||||
if isinstance(ma, list):
|
||||
ma_str = ",".join(str(m) for m in ma)
|
||||
else:
|
||||
ma_str = str(ma)
|
||||
params["ma"] = ma_str
|
||||
|
||||
# Default to fetching all factors if not specified
|
||||
factors_to_use = factors if factors is not None else ["tor", "vr"]
|
||||
if factors_to_use:
|
||||
# Tushare expects factors as comma-separated string
|
||||
if isinstance(factors_to_use, list):
|
||||
factors_str = ",".join(factors_to_use)
|
||||
else:
|
||||
factors_str = factors_to_use
|
||||
params["factors"] = factors_str
|
||||
|
||||
if adjfactor:
|
||||
params["adjfactor"] = "True"
|
||||
|
||||
# Fetch data using pro_bar API
|
||||
data = client.query("pro_bar", **params)
|
||||
|
||||
# Rename date column if needed
|
||||
if "date" in data.columns:
|
||||
data = data.rename(columns={"date": "trade_date"})
|
||||
|
||||
return data
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# ProBarSync - Pro Bar 数据批量同步类
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class ProBarSync:
    """Batch sync manager for Pro Bar data, supporting full/incremental sync.

    Features:
        - Concurrent fetching (ThreadPoolExecutor)
        - Incremental sync (auto-detects the last synced position)
        - In-memory caching (avoids repeated disk reads)
        - Immediate stop on exception (keeps data consistent)
        - Preview mode (estimate sync volume without writing)
        - Fetches all data columns by default (tor, vr, adj_factor)
    """

    # Default worker-thread count, read from configuration (default 10).
    DEFAULT_MAX_WORKERS = get_settings().threads

    def __init__(self, max_workers: Optional[int] = None):
        """Initialize the sync manager.

        Args:
            max_workers: Number of worker threads. When omitted, falls back
                to the configured value (``DEFAULT_MAX_WORKERS``).
        """
        # Fixed: the original docstring documented max_workers twice with
        # conflicting text; it is now described once.
        self.storage = ThreadSafeStorage()
        # Shared client so rate limiting spans all worker threads.
        self.client = TushareClient()
        self.max_workers = max_workers or self.DEFAULT_MAX_WORKERS
        # Event used as a "keep running" flag: set() = running, clear() = stop.
        self._stop_flag = threading.Event()
        self._stop_flag.set()  # initially in the "not stopped" state
        # Lazily populated in-memory cache of the pro_bar table.
        self._cached_pro_bar_data: Optional[pd.DataFrame] = None
|
||||
|
||||
def _load_pro_bar_data(self) -> pd.DataFrame:
|
||||
"""从存储加载 Pro Bar 数据(带缓存)。
|
||||
|
||||
该方法会将数据缓存在内存中以避免重复磁盘读取。
|
||||
调用 clear_cache() 可强制重新加载。
|
||||
|
||||
Returns:
|
||||
缓存或从存储加载的 Pro Bar 数据 DataFrame
|
||||
"""
|
||||
if self._cached_pro_bar_data is None:
|
||||
self._cached_pro_bar_data = self.storage.load("pro_bar")
|
||||
return self._cached_pro_bar_data
|
||||
|
||||
def clear_cache(self) -> None:
|
||||
"""清除缓存的 Pro Bar 数据,强制下次访问时重新加载。"""
|
||||
self._cached_pro_bar_data = None
|
||||
|
||||
def get_all_stock_codes(self, only_listed: bool = True) -> list:
|
||||
"""从本地存储获取所有股票代码。
|
||||
|
||||
优先使用 stock_basic.csv 以确保包含所有股票,
|
||||
避免回测中的前视偏差。
|
||||
|
||||
Args:
|
||||
only_listed: 若为 True,仅返回当前上市股票(L 状态)。
|
||||
设为 False 可包含退市股票(用于完整回测)。
|
||||
|
||||
Returns:
|
||||
股票代码列表
|
||||
"""
|
||||
# 确保 stock_basic.csv 是最新的
|
||||
print("[ProBarSync] Ensuring stock_basic.csv is up-to-date...")
|
||||
sync_all_stocks()
|
||||
|
||||
# 从 stock_basic.csv 文件获取
|
||||
stock_csv_path = _get_csv_path()
|
||||
|
||||
if stock_csv_path.exists():
|
||||
print(f"[ProBarSync] Reading stock_basic from CSV: {stock_csv_path}")
|
||||
try:
|
||||
stock_df = pd.read_csv(stock_csv_path, encoding="utf-8-sig")
|
||||
if not stock_df.empty and "ts_code" in stock_df.columns:
|
||||
# 根据 list_status 过滤
|
||||
if only_listed and "list_status" in stock_df.columns:
|
||||
listed_stocks = stock_df[stock_df["list_status"] == "L"]
|
||||
codes = listed_stocks["ts_code"].unique().tolist()
|
||||
total = len(stock_df["ts_code"].unique())
|
||||
print(
|
||||
f"[ProBarSync] Found {len(codes)} listed stocks (filtered from {total} total)"
|
||||
)
|
||||
else:
|
||||
codes = stock_df["ts_code"].unique().tolist()
|
||||
print(
|
||||
f"[ProBarSync] Found {len(codes)} stock codes from stock_basic.csv"
|
||||
)
|
||||
return codes
|
||||
else:
|
||||
print(
|
||||
f"[ProBarSync] stock_basic.csv exists but no ts_code column or empty"
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"[ProBarSync] Error reading stock_basic.csv: {e}")
|
||||
|
||||
# 回退:从 Pro Bar 存储获取
|
||||
print(
|
||||
"[ProBarSync] stock_basic.csv not available, falling back to pro_bar data..."
|
||||
)
|
||||
pro_bar_data = self._load_pro_bar_data()
|
||||
if not pro_bar_data.empty and "ts_code" in pro_bar_data.columns:
|
||||
codes = pro_bar_data["ts_code"].unique().tolist()
|
||||
print(f"[ProBarSync] Found {len(codes)} stock codes from pro_bar data")
|
||||
return codes
|
||||
|
||||
print("[ProBarSync] No stock codes found in local storage")
|
||||
return []
|
||||
|
||||
def get_global_last_date(self) -> Optional[str]:
|
||||
"""获取全局最后交易日期。
|
||||
|
||||
Returns:
|
||||
最后交易日期字符串,若无数据则返回 None
|
||||
"""
|
||||
pro_bar_data = self._load_pro_bar_data()
|
||||
if pro_bar_data.empty or "trade_date" not in pro_bar_data.columns:
|
||||
return None
|
||||
return str(pro_bar_data["trade_date"].max())
|
||||
|
||||
def get_global_first_date(self) -> Optional[str]:
|
||||
"""获取全局最早交易日期。
|
||||
|
||||
Returns:
|
||||
最早交易日期字符串,若无数据则返回 None
|
||||
"""
|
||||
pro_bar_data = self._load_pro_bar_data()
|
||||
if pro_bar_data.empty or "trade_date" not in pro_bar_data.columns:
|
||||
return None
|
||||
return str(pro_bar_data["trade_date"].min())
|
||||
|
||||
def get_trade_calendar_bounds(
|
||||
self, start_date: str, end_date: str
|
||||
) -> tuple[Optional[str], Optional[str]]:
|
||||
"""从交易日历获取首尾交易日。
|
||||
|
||||
Args:
|
||||
start_date: 开始日期(YYYYMMDD 格式)
|
||||
end_date: 结束日期(YYYYMMDD 格式)
|
||||
|
||||
Returns:
|
||||
(首交易日, 尾交易日) 元组,若出错则返回 (None, None)
|
||||
"""
|
||||
try:
|
||||
first_day = get_first_trading_day(start_date, end_date)
|
||||
last_day = get_last_trading_day(start_date, end_date)
|
||||
return (first_day, last_day)
|
||||
except Exception as e:
|
||||
print(f"[ERROR] Failed to get trade calendar bounds: {e}")
|
||||
return (None, None)
|
||||
|
||||
    def check_sync_needed(
        self,
        force_full: bool = False,
        table_name: str = "pro_bar",
    ) -> tuple[bool, Optional[str], Optional[str], Optional[str]]:
        """Decide, against the trade calendar, whether a sync is needed.

        Compares the local data's date range with the trade calendar to
        determine whether new data must be fetched.

        Logic:
        - force_full: sync needed, returns (True, 20180101, today)
        - no local data: sync needed, returns (True, 20180101, today)
        - local data present:
            - read the last trading day from the trade calendar
            - if local last date >= calendar last date: no sync needed
            - otherwise: sync from local last date + 1 to the latest trading day

        Args:
            force_full: When True, always report that a sync is needed.
            table_name: Table to inspect (default: "pro_bar").

        Returns:
            (sync_needed, start_date, end_date, local_last_date)
            - sync_needed: True when a sync should proceed
            - start_date: sync start date (None when no sync is needed)
            - end_date: sync end date (None when no sync is needed)
            - local_last_date: last date in local data (for incremental sync)
        """
        # Always sync when a full reload was requested.
        if force_full:
            print("[ProBarSync] Force full sync requested")
            return (True, DEFAULT_START_DATE, get_today_date(), None)

        # Check whether local data exists for the requested table.
        # A fresh Storage is used here (not self.storage) since table_name
        # may refer to a table other than pro_bar.
        storage = Storage()
        table_data = (
            storage.load(table_name) if storage.exists(table_name) else pd.DataFrame()
        )

        if table_data.empty or "trade_date" not in table_data.columns:
            print(
                f"[ProBarSync] No local data found for table '{table_name}', full sync needed"
            )
            return (True, DEFAULT_START_DATE, get_today_date(), None)

        # Latest date present in the local data.
        local_last_date = str(table_data["trade_date"].max())

        print(f"[ProBarSync] Local data last date: {local_last_date}")

        # Latest trading day according to the trade calendar.
        today = get_today_date()
        _, cal_last = self.get_trade_calendar_bounds(DEFAULT_START_DATE, today)

        if cal_last is None:
            print("[ProBarSync] Failed to get trade calendar, proceeding with sync")
            return (True, DEFAULT_START_DATE, today, local_last_date)

        print(f"[ProBarSync] Calendar last trading day: {cal_last}")

        # Compare the local last date with the calendar's last trading day.
        print(
            f"[ProBarSync] Comparing: local={local_last_date} (type={type(local_last_date).__name__}), "
            f"cal={cal_last} (type={type(cal_last).__name__})"
        )
        try:
            local_last_int = int(local_last_date)
            cal_last_int = int(cal_last)
            print(
                f"[ProBarSync] Comparing integers: local={local_last_int} >= cal={cal_last_int} = "
                f"{local_last_int >= cal_last_int}"
            )
            if local_last_int >= cal_last_int:
                print(
                    "[ProBarSync] Local data is up-to-date, SKIPPING sync (no tokens consumed)"
                )
                return (False, None, None, None)
        except (ValueError, TypeError) as e:
            # NOTE(review): on a parse failure we fall through and sync anyway.
            print(f"[ERROR] Date comparison failed: {e}")

        # Incremental range: day after the local last date, up to the
        # calendar's last trading day.
        sync_start = get_next_date(local_last_date)
        print(f"[ProBarSync] Incremental sync needed from {sync_start} to {cal_last}")
        return (True, sync_start, cal_last, local_last_date)
|
||||
|
||||
    def preview_sync(
        self,
        force_full: bool = False,
        start_date: Optional[str] = None,
        end_date: Optional[str] = None,
        sample_size: int = 3,
    ) -> dict:
        """Preview the sync volume and a data sample without syncing.

        Reports:
        - how many stocks would be synced
        - the sync date range
        - an estimated total record count
        - sample rows from the first few stocks

        Args:
            force_full: When True, preview a full sync (from 20180101).
            start_date: Manual start date (overrides auto-detection).
            end_date: Manual end date (defaults to today).
            sample_size: Number of sample stocks to fetch (default: 3).

        Returns:
            Preview dict:
            {
                'sync_needed': bool,
                'stock_count': int,
                'start_date': str,
                'end_date': str,
                'estimated_records': int,
                'sample_data': pd.DataFrame,
                'mode': str,  # 'full' or 'incremental'
            }
        """
        print("\n" + "=" * 60)
        print("[ProBarSync] Preview Mode - Analyzing sync requirements...")
        print("=" * 60)

        # Make sure the trade-calendar cache is current first.
        print("[ProBarSync] Syncing trade calendar cache...")
        sync_trade_cal_cache()

        # Establish the date range.
        if end_date is None:
            end_date = get_today_date()

        # Decide whether a sync is needed at all.
        sync_needed, cal_start, cal_end, local_last = self.check_sync_needed(force_full)

        if not sync_needed:
            print("\n" + "=" * 60)
            print("[ProBarSync] Preview Result")
            print("=" * 60)
            print(" Sync Status: NOT NEEDED")
            print(" Reason: Local data is up-to-date with trade calendar")
            print("=" * 60)
            return {
                "sync_needed": False,
                "stock_count": 0,
                "start_date": None,
                "end_date": None,
                "estimated_records": 0,
                "sample_data": pd.DataFrame(),
                "mode": "none",
            }

        # Prefer the dates computed by check_sync_needed.
        if cal_start and cal_end:
            sync_start_date = cal_start
            end_date = cal_end
        else:
            # NOTE(review): start_date is only honored on this fallback path,
            # despite the docstring saying it overrides auto-detection.
            sync_start_date = start_date or DEFAULT_START_DATE
            if end_date is None:
                end_date = get_today_date()

        # Determine the sync mode (for reporting only).
        if force_full:
            mode = "full"
            print(f"[ProBarSync] Mode: FULL SYNC from {sync_start_date} to {end_date}")
        elif local_last and cal_start and sync_start_date == get_next_date(local_last):
            mode = "incremental"
            print(f"[ProBarSync] Mode: INCREMENTAL SYNC (bandwidth optimized)")
            print(f"[ProBarSync] Sync from: {sync_start_date} to {end_date}")
        else:
            mode = "partial"
            print(f"[ProBarSync] Mode: SYNC from {sync_start_date} to {end_date}")

        # Gather every stock code that would be synced.
        stock_codes = self.get_all_stock_codes()
        if not stock_codes:
            print("[ProBarSync] No stocks found to sync")
            return {
                "sync_needed": False,
                "stock_count": 0,
                "start_date": None,
                "end_date": None,
                "estimated_records": 0,
                "sample_data": pd.DataFrame(),
                "mode": "none",
            }

        stock_count = len(stock_codes)
        print(f"[ProBarSync] Total stocks to sync: {stock_count}")

        # Fetch sample data from the first few stocks.
        print(f"[ProBarSync] Fetching sample data from {sample_size} stocks...")
        sample_data_list = []
        sample_codes = stock_codes[:sample_size]

        for ts_code in sample_codes:
            try:
                # get_pro_bar fetches all fields by default.
                data = get_pro_bar(
                    ts_code=ts_code,
                    start_date=sync_start_date,
                    end_date=end_date,
                )
                if not data.empty:
                    sample_data_list.append(data)
                    print(f" - {ts_code}: {len(data)} records")
            except Exception as e:
                # Sample fetches are best-effort; a failure is just reported.
                print(f" - {ts_code}: Error fetching - {e}")

        # Combine the samples.
        sample_df = (
            pd.concat(sample_data_list, ignore_index=True)
            if sample_data_list
            else pd.DataFrame()
        )

        # Estimate total records from the sample average.
        if not sample_df.empty:
            avg_records_per_stock = len(sample_df) / len(sample_data_list)
            estimated_records = int(avg_records_per_stock * stock_count)
        else:
            estimated_records = 0

        # Display the preview result.
        print("\n" + "=" * 60)
        print("[ProBarSync] Preview Result")
        print("=" * 60)
        print(f" Sync Mode: {mode.upper()}")
        print(f" Date Range: {sync_start_date} to {end_date}")
        print(f" Stocks to Sync: {stock_count}")
        print(f" Sample Stocks Checked: {len(sample_data_list)}/{sample_size}")
        print(f" Estimated Total Records: ~{estimated_records:,}")

        if not sample_df.empty:
            print(f"\n Sample Data Preview (first {len(sample_df)} rows):")
            print(" " + "-" * 56)
            # Show the sample in a compact per-row dict format.
            preview_cols = [
                "ts_code",
                "trade_date",
                "open",
                "high",
                "low",
                "close",
                "vol",
                "tor",
                "vr",
            ]
            available_cols = [c for c in preview_cols if c in sample_df.columns]
            sample_display = sample_df[available_cols].head(10)
            for idx, row in sample_display.iterrows():
                print(f" {row.to_dict()}")
            print(" " + "-" * 56)

        print("=" * 60)

        return {
            "sync_needed": True,
            "stock_count": stock_count,
            "start_date": sync_start_date,
            "end_date": end_date,
            "estimated_records": estimated_records,
            "sample_data": sample_df,
            "mode": mode,
        }
|
||||
|
||||
def sync_single_stock(
|
||||
self,
|
||||
ts_code: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
) -> pd.DataFrame:
|
||||
"""同步单只股票的 Pro Bar 数据。
|
||||
|
||||
Args:
|
||||
ts_code: 股票代码
|
||||
start_date: 起始日期(YYYYMMDD)
|
||||
end_date: 结束日期(YYYYMMDD)
|
||||
|
||||
Returns:
|
||||
包含 Pro Bar 数据的 DataFrame
|
||||
"""
|
||||
# 检查是否应该停止同步(用于异常处理)
|
||||
if not self._stop_flag.is_set():
|
||||
return pd.DataFrame()
|
||||
|
||||
try:
|
||||
# 使用 get_pro_bar 获取数据(默认包含所有字段,传递共享 client)
|
||||
data = get_pro_bar(
|
||||
ts_code=ts_code,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
client=self.client, # 传递共享客户端以确保限流
|
||||
)
|
||||
return data
|
||||
except Exception as e:
|
||||
# 设置停止标志以通知其他线程停止
|
||||
self._stop_flag.clear()
|
||||
print(f"[ERROR] Exception syncing {ts_code}: {e}")
|
||||
raise
|
||||
|
||||
    def sync_all(
        self,
        force_full: bool = False,
        start_date: Optional[str] = None,
        end_date: Optional[str] = None,
        max_workers: Optional[int] = None,
        dry_run: bool = False,
    ) -> Dict[str, pd.DataFrame]:
        """Sync Pro Bar data for every stock known to local storage.

        Steps:
        1. Read stock codes from local storage (pro_bar or stock_basic).
        2. Consult the trade calendar to decide whether a sync is needed:
           - skip entirely when local data already matches the calendar
             bounds (saves API tokens);
           - otherwise sync from local-last-date + 1 up to the latest
             trading day (bandwidth optimized).
        3. Fetch concurrently with multiple threads (rate limited).
        4. Skip stocks that return empty data (delisted/unavailable).
        5. Stop immediately when any worker raises.

        Args:
            force_full: When True, force a complete reload from 20180101.
            start_date: Manual start date (overrides auto-detection).
            end_date: Manual end date (defaults to today).
            max_workers: Worker-thread count (default: 10).
            dry_run: When True, only preview the sync; write nothing.

        Returns:
            Dict mapping ts_code to its DataFrame (empty when skipped or
            on dry_run).
        """
        print("\n" + "=" * 60)
        print("[ProBarSync] Starting pro_bar data sync...")
        print("=" * 60)

        # Make sure the trade-calendar cache is current first (incremental).
        print("[ProBarSync] Syncing trade calendar cache...")
        sync_trade_cal_cache()

        # Establish the date range.
        if end_date is None:
            end_date = get_today_date()

        # Decide, against the trade calendar, whether a sync is needed.
        sync_needed, cal_start, cal_end, local_last = self.check_sync_needed(force_full)

        if not sync_needed:
            # Skip the sync entirely — no API tokens consumed.
            print("\n" + "=" * 60)
            print("[ProBarSync] Sync Summary")
            print("=" * 60)
            print(" Sync: SKIPPED (local data up-to-date with trade calendar)")
            print(" Tokens saved: 0 consumed")
            print("=" * 60)
            return {}

        # Use the dates computed by check_sync_needed (it derives the
        # incremental start date).
        if cal_start and cal_end:
            sync_start_date = cal_start
            end_date = cal_end
        else:
            # Fallback to the default logic.
            sync_start_date = start_date or DEFAULT_START_DATE
            if end_date is None:
                end_date = get_today_date()

        # Determine the sync mode (for reporting only).
        if force_full:
            mode = "full"
            print(f"[ProBarSync] Mode: FULL SYNC from {sync_start_date} to {end_date}")
        elif local_last and cal_start and sync_start_date == get_next_date(local_last):
            mode = "incremental"
            print(f"[ProBarSync] Mode: INCREMENTAL SYNC (bandwidth optimized)")
            print(f"[ProBarSync] Sync from: {sync_start_date} to {end_date}")
        else:
            mode = "partial"
            print(f"[ProBarSync] Mode: SYNC from {sync_start_date} to {end_date}")

        # Gather every stock code to sync.
        stock_codes = self.get_all_stock_codes()
        if not stock_codes:
            print("[ProBarSync] No stocks found to sync")
            return {}

        print(f"[ProBarSync] Total stocks to sync: {len(stock_codes)}")
        print(f"[ProBarSync] Using {max_workers or self.max_workers} worker threads")

        # Dry-run mode: report what would happen, write nothing.
        if dry_run:
            print("\n" + "=" * 60)
            print("[ProBarSync] DRY RUN MODE - No data will be written")
            print("=" * 60)
            print(f" Would sync {len(stock_codes)} stocks")
            print(f" Date range: {sync_start_date} to {end_date}")
            print(f" Mode: {mode}")
            print("=" * 60)
            return {}

        # Reset the stop flag for this new sync run.
        self._stop_flag.set()

        # Shared state for the concurrent fetch below.
        results: Dict[str, pd.DataFrame] = {}
        error_occurred = False
        exception_to_raise = None

        def sync_task(ts_code: str) -> tuple[str, pd.DataFrame]:
            """Per-stock task executed on the worker pool."""
            try:
                data = self.sync_single_stock(
                    ts_code=ts_code,
                    start_date=sync_start_date,
                    end_date=end_date,
                )
                return (ts_code, data)
            except Exception as e:
                # Re-raise so the Future captures the exception.
                raise

        # Fan the work out over a thread pool.
        workers = max_workers or self.max_workers
        with ThreadPoolExecutor(max_workers=workers) as executor:
            # Submit every task, keeping a future -> ts_code map.
            future_to_code = {
                executor.submit(sync_task, ts_code): ts_code for ts_code in stock_codes
            }

            # Result counters, filled while draining futures below.
            error_count = 0
            empty_count = 0
            success_count = 0

            # Progress bar over all submitted stocks.
            pbar = tqdm(total=len(stock_codes), desc="Syncing pro_bar stocks")

            try:
                # Process futures as they complete.
                for future in as_completed(future_to_code):
                    ts_code = future_to_code[future]

                    try:
                        _, data = future.result()
                        if data is not None and not data.empty:
                            results[ts_code] = data
                            success_count += 1
                        else:
                            # Empty data — stock may be delisted or unavailable.
                            empty_count += 1
                            print(
                                f"[ProBarSync] Stock {ts_code}: empty data (skipped, may be delisted)"
                            )
                    except Exception as e:
                        # A worker failed — cancel everything and abort.
                        error_occurred = True
                        exception_to_raise = e
                        print(f"\n[ERROR] Sync aborted due to exception: {e}")
                        # Shut the executor down to stop all pending tasks.
                        executor.shutdown(wait=False, cancel_futures=True)
                        raise exception_to_raise

                    # Advance the progress bar.
                    pbar.update(1)

            except Exception:
                # NOTE(review): the abort exception is swallowed here — the
                # method still prints a summary and returns normally after
                # an error; confirm this is the intended contract.
                error_count = 1
                print("[ProBarSync] Sync stopped due to exception")
            finally:
                pbar.close()

        # Persist everything in one batch (only when no error occurred).
        if results and not error_occurred:
            for ts_code, data in results.items():
                if not data.empty:
                    self.storage.queue_save("pro_bar", data)
            # Flush all queued writes at once.
            self.storage.flush()
            total_rows = sum(len(df) for df in results.values())
            print(f"\n[ProBarSync] Saved {total_rows} rows to storage")

        # Summary.
        print("\n" + "=" * 60)
        print("[ProBarSync] Sync Summary")
        print("=" * 60)
        print(f" Total stocks: {len(stock_codes)}")
        print(f" Updated: {success_count}")
        print(f" Skipped (empty/delisted): {empty_count}")
        print(
            f" Errors: {error_count} (aborted on first error)"
            if error_count
            else " Errors: 0"
        )
        print(f" Date range: {sync_start_date} to {end_date}")
        print("=" * 60)

        return results
|
||||
|
||||
|
||||
def sync_pro_bar(
    force_full: bool = False,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
    max_workers: Optional[int] = None,
    dry_run: bool = False,
) -> Dict[str, pd.DataFrame]:
    """Sync Pro Bar data for every stock.

    This is the main entry point for Pro Bar data synchronization.

    Args:
        force_full: When True, force a complete reload from 20180101.
        start_date: Manual start date (YYYYMMDD).
        end_date: Manual end date (defaults to today).
        max_workers: Worker-thread count (default: 10).
        dry_run: When True, only preview what would be synced; write nothing.

    Returns:
        Dict mapping ts_code to its DataFrame.

    Example:
        >>> result = sync_pro_bar()                 # first run: full load from 20180101
        >>> result = sync_pro_bar()                 # later runs: incremental only
        >>> result = sync_pro_bar(force_full=True)  # force a full reload
        >>> result = sync_pro_bar(start_date='20240101', end_date='20240131')
        >>> result = sync_pro_bar(max_workers=20)   # custom thread count
        >>> result = sync_pro_bar(dry_run=True)     # preview only
    """
    # max_workers is applied via the manager's constructor; sync_all then
    # falls back to the stored value.
    manager = ProBarSync(max_workers=max_workers)
    return manager.sync_all(
        force_full=force_full,
        start_date=start_date,
        end_date=end_date,
        dry_run=dry_run,
    )
|
||||
|
||||
|
||||
def preview_pro_bar_sync(
    force_full: bool = False,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
    sample_size: int = 3,
) -> dict:
    """Preview Pro Bar sync volume and sample rows without syncing.

    Recommended way to inspect what a sync would do before running it.

    Args:
        force_full: When True, preview a full sync (from 20180101).
        start_date: Manual start date (overrides auto-detection).
        end_date: Manual end date (defaults to today).
        sample_size: Number of sample stocks to fetch (default: 3).

    Returns:
        Preview dict:
        {
            'sync_needed': bool,
            'stock_count': int,
            'start_date': str,
            'end_date': str,
            'estimated_records': int,
            'sample_data': pd.DataFrame,
            'mode': str,  # 'full', 'incremental', 'partial', or 'none'
        }

    Example:
        >>> preview = preview_pro_bar_sync()                  # what would sync
        >>> preview = preview_pro_bar_sync(force_full=True)   # preview full sync
        >>> preview = preview_pro_bar_sync(sample_size=5)     # larger sample
    """
    manager = ProBarSync()
    return manager.preview_sync(
        force_full=force_full,
        start_date=start_date,
        end_date=end_date,
        sample_size=sample_size,
    )
|
||||
663
src/data/data_router.py
Normal file
663
src/data/data_router.py
Normal file
@@ -0,0 +1,663 @@
|
||||
"""数据目录与动态 SQL 路由模块。
|
||||
|
||||
用于动态 SQL 生成和数据拉取,解决多表架构下的数据查询痛点。
|
||||
支持 DAILY(日频精确对齐)和 PIT(低频财务数据,按披露日对齐)两种表类型。
|
||||
|
||||
核心特性:
|
||||
- 自动发现 DuckDB 数据库中的表结构
|
||||
- 支持通过配置覆盖自动发现的元数据
|
||||
- 智能识别 PIT 类型表(通过 ann_date/f_ann_date 字段)
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Set, Optional, Literal
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
import polars as pl
|
||||
import duckdb
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class TableFrequency(Enum):
    """Table frequency kinds."""

    # Daily-frequency data, aligned exactly on trade dates.
    DAILY = "daily"
    # Low-frequency financial data, aligned on announcement dates (Point-In-Time).
    PIT = "pit"
|
||||
|
||||
|
||||
@dataclass
class TableMetadata:
    """Metadata describing one database table.

    Attributes:
        name: Table name.
        frequency: Table frequency kind (DAILY or PIT).
        date_field: Name of the date column (trade_date for DAILY tables,
            ann_date for PIT tables).
        code_field: Name of the asset-code column (usually ts_code).
        fields: Every column present in the table.
        description: Human-readable table description.
    """

    name: str
    frequency: TableFrequency
    date_field: str
    code_field: str = "ts_code"
    fields: List[str] = field(default_factory=list)
    description: str = ""
|
||||
|
||||
|
||||
@dataclass
class FieldMapping:
    """Mapping from one field to the table that owns it.

    Attributes:
        field_name: Field name.
        table_name: Owning table name.
        description: Field description.
    """

    field_name: str
    table_name: str
    description: str = ""
|
||||
|
||||
|
||||
class DatabaseCatalog:
    """Database catalog that manages field-to-table mappings.

    Core responsibilities:
    1. Auto-discover table schemas from a DuckDB database.
    2. Maintain the field-to-table mapping.
    3. Manage table metadata (frequency kind, date field, etc.).
    4. Provide field resolution and table routing.

    Automatic table-type detection rules:
    - A table containing an ann_date or f_ann_date field is classified as PIT.
    - Otherwise, a table containing a trade_date field is classified as DAILY.

    Attributes:
        tables: Table metadata dict, table name -> TableMetadata.
        field_mappings: Field mapping dict, field name -> FieldMapping.
        db_path: Path to the database file.

    Example:
        >>> catalog = DatabaseCatalog("data/prostock.db")
        >>> # Auto-discover all table schemas
        >>> catalog.discover_tables()
        >>> table = catalog.get_table_for_field("close")
        >>> print(table)  # "daily"
    """

    # Fields that mark a table as PIT (checked in priority order).
    PIT_DATE_FIELDS = ["ann_date", "f_ann_date", "publish_date"]
    # Fields that mark a table as DAILY.
    DAILY_DATE_FIELDS = ["trade_date", "cal_date", "date"]
|
||||
|
||||
def __init__(self, db_path: Optional[str] = None):
|
||||
"""初始化数据库目录。
|
||||
|
||||
Args:
|
||||
db_path: 数据库文件路径,如果为 None 则使用默认配置
|
||||
"""
|
||||
self.tables: Dict[str, TableMetadata] = {}
|
||||
self.field_mappings: Dict[str, FieldMapping] = {}
|
||||
self.db_path = db_path
|
||||
self._table_frequency_overrides: Dict[str, TableFrequency] = {}
|
||||
|
||||
if db_path:
|
||||
self.discover_tables(db_path)
|
||||
|
||||
def set_table_frequency_override(
|
||||
self, table_name: str, frequency: TableFrequency
|
||||
) -> None:
|
||||
"""设置表频度类型覆盖。
|
||||
|
||||
用于手动指定表的频度类型,覆盖自动识别的结果。
|
||||
|
||||
Args:
|
||||
table_name: 表名
|
||||
frequency: 频度类型(DAILY 或 PIT)
|
||||
"""
|
||||
self._table_frequency_overrides[table_name] = frequency
|
||||
|
||||
def discover_tables(self, db_path: str) -> None:
|
||||
"""自动发现数据库中的所有表结构。
|
||||
|
||||
从 information_schema 中读取表和列信息,自动识别:
|
||||
- 表名和字段列表
|
||||
- 资产代码字段(ts_code)
|
||||
- 日期字段(根据字段名智能识别表类型)
|
||||
- 表频度类型(DAILY 或 PIT)
|
||||
|
||||
Args:
|
||||
db_path: DuckDB 数据库文件路径
|
||||
"""
|
||||
db_file = db_path.replace("duckdb://", "").lstrip("/")
|
||||
|
||||
if not Path(db_file).exists():
|
||||
print(f"[DatabaseCatalog] 数据库文件不存在: {db_file}")
|
||||
return
|
||||
|
||||
conn = duckdb.connect(db_file, read_only=True)
|
||||
try:
|
||||
# 获取所有表
|
||||
tables_query = """
|
||||
SELECT table_name
|
||||
FROM information_schema.tables
|
||||
WHERE table_schema = 'main'
|
||||
ORDER BY table_name
|
||||
"""
|
||||
tables_result = conn.execute(tables_query).fetchall()
|
||||
|
||||
for (table_name,) in tables_result:
|
||||
# 获取表的列信息
|
||||
columns_query = """
|
||||
SELECT column_name, data_type
|
||||
FROM information_schema.columns
|
||||
WHERE table_name = ? AND table_schema = 'main'
|
||||
ORDER BY ordinal_position
|
||||
"""
|
||||
columns_result = conn.execute(columns_query, [table_name]).fetchall()
|
||||
|
||||
fields = [col[0] for col in columns_result]
|
||||
|
||||
# 自动识别表类型和日期字段
|
||||
frequency, date_field = self._detect_table_type(fields, table_name)
|
||||
|
||||
# 检查是否有资产代码字段
|
||||
code_field = "ts_code" if "ts_code" in fields else None
|
||||
|
||||
if code_field and date_field:
|
||||
# 创建表元数据
|
||||
metadata = TableMetadata(
|
||||
name=table_name,
|
||||
frequency=frequency,
|
||||
date_field=date_field,
|
||||
code_field=code_field,
|
||||
fields=fields,
|
||||
description=f"自动发现的表: {table_name}",
|
||||
)
|
||||
self.register_table(metadata)
|
||||
print(
|
||||
f"[DatabaseCatalog] 发现表: {table_name} ({frequency.value}, "
|
||||
f"日期字段: {date_field})"
|
||||
)
|
||||
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def _detect_table_type(
|
||||
self, fields: List[str], table_name: str
|
||||
) -> tuple[TableFrequency, Optional[str]]:
|
||||
"""自动检测表的频度类型和日期字段。
|
||||
|
||||
检测规则(按优先级):
|
||||
1. 检查是否有手动覆盖配置
|
||||
2. 检查是否包含 PIT 标识字段(ann_date, f_ann_date 等)
|
||||
3. 检查是否包含 DAILY 标识字段(trade_date, cal_date 等)
|
||||
|
||||
Args:
|
||||
fields: 表的字段列表
|
||||
table_name: 表名
|
||||
|
||||
Returns:
|
||||
(频度类型, 日期字段名)
|
||||
"""
|
||||
# 检查手动覆盖配置
|
||||
if table_name in self._table_frequency_overrides:
|
||||
frequency = self._table_frequency_overrides[table_name]
|
||||
if frequency == TableFrequency.PIT:
|
||||
for field in self.PIT_DATE_FIELDS:
|
||||
if field in fields:
|
||||
return frequency, field
|
||||
else:
|
||||
for field in self.DAILY_DATE_FIELDS:
|
||||
if field in fields:
|
||||
return frequency, field
|
||||
|
||||
# 检查 PIT 标识字段
|
||||
for field in self.PIT_DATE_FIELDS:
|
||||
if field in fields:
|
||||
return TableFrequency.PIT, field
|
||||
|
||||
# 检查 DAILY 标识字段
|
||||
for field in self.DAILY_DATE_FIELDS:
|
||||
if field in fields:
|
||||
return TableFrequency.DAILY, field
|
||||
|
||||
# 默认返回 DAILY,但无日期字段
|
||||
return TableFrequency.DAILY, None
|
||||
|
||||
def register_table(self, metadata: TableMetadata) -> None:
|
||||
"""注册表元数据。
|
||||
|
||||
Args:
|
||||
metadata: 表元数据配置
|
||||
"""
|
||||
self.tables[metadata.name] = metadata
|
||||
|
||||
# 自动注册字段映射(如果字段已存在,保留第一个表的映射)
|
||||
for field_name in metadata.fields:
|
||||
if field_name not in self.field_mappings:
|
||||
self.field_mappings[field_name] = FieldMapping(
|
||||
field_name=field_name,
|
||||
table_name=metadata.name,
|
||||
description=f"{metadata.description} - {field_name}",
|
||||
)
|
||||
|
||||
def get_table_for_field(self, field: str) -> Optional[str]:
|
||||
"""获取字段对应的表名。
|
||||
|
||||
Args:
|
||||
field: 字段名
|
||||
|
||||
Returns:
|
||||
表名,如果字段不存在则返回 None
|
||||
"""
|
||||
mapping = self.field_mappings.get(field)
|
||||
return mapping.table_name if mapping else None
|
||||
|
||||
def get_table_metadata(self, table_name: str) -> Optional[TableMetadata]:
|
||||
"""获取表的元数据。
|
||||
|
||||
Args:
|
||||
table_name: 表名
|
||||
|
||||
Returns:
|
||||
表元数据,如果不存在则返回 None
|
||||
"""
|
||||
return self.tables.get(table_name)
|
||||
|
||||
def get_table_frequency(self, table_name: str) -> Optional[TableFrequency]:
|
||||
"""获取表的频度类型。
|
||||
|
||||
Args:
|
||||
table_name: 表名
|
||||
|
||||
Returns:
|
||||
表频度类型(DAILY 或 PIT),如果不存在则返回 None
|
||||
"""
|
||||
metadata = self.tables.get(table_name)
|
||||
return metadata.frequency if metadata else None
|
||||
|
||||
def get_required_tables(self, fields: List[str]) -> Set[str]:
|
||||
"""获取所需字段涉及的所有表名。
|
||||
|
||||
Args:
|
||||
fields: 字段列表
|
||||
|
||||
Returns:
|
||||
涉及的表名集合
|
||||
"""
|
||||
tables = set()
|
||||
for field in fields:
|
||||
table = self.get_table_for_field(field)
|
||||
if table:
|
||||
tables.add(table)
|
||||
return tables
|
||||
|
||||
def get_fields_for_table(
|
||||
self, table_name: str, required_fields: List[str]
|
||||
) -> List[str]:
|
||||
"""获取指定表需要的字段列表(包含必要的键字段)。
|
||||
|
||||
Args:
|
||||
table_name: 表名
|
||||
required_fields: 用户请求的所有字段
|
||||
|
||||
Returns:
|
||||
该表需要查询的字段列表(包含键字段)
|
||||
"""
|
||||
metadata = self.tables.get(table_name)
|
||||
if not metadata:
|
||||
return []
|
||||
|
||||
# 基础键字段
|
||||
fields = [metadata.code_field, metadata.date_field]
|
||||
|
||||
# 添加用户请求的字段(属于该表的)
|
||||
for field in required_fields:
|
||||
if self.get_table_for_field(field) == table_name and field not in fields:
|
||||
fields.append(field)
|
||||
|
||||
return fields
|
||||
|
||||
def is_pit_table(self, table_name: str) -> bool:
|
||||
"""判断表是否为 PIT 类型。
|
||||
|
||||
Args:
|
||||
table_name: 表名
|
||||
|
||||
Returns:
|
||||
是否为 PIT 类型表
|
||||
"""
|
||||
frequency = self.get_table_frequency(table_name)
|
||||
return frequency == TableFrequency.PIT
|
||||
|
||||
|
||||
class SQLQueryBuilder:
    """SQL query builder.

    Builds optimized SQL queries according to the table type (DAILY/PIT).
    The two kinds differ only in the effective start date: PIT tables are
    filtered by announcement date and need a lookback window, so the query
    template is shared (this removes the duplicated branch the original
    implementation carried).
    """

    def __init__(self, catalog: "DatabaseCatalog"):
        """Initialize the SQL builder.

        Args:
            catalog: database catalog instance used to resolve table metadata
        """
        self.catalog = catalog

    def build_query(
        self,
        table_name: str,
        fields: List[str],
        start_date: str,
        end_date: str,
        lookback_days: int = 90,
    ) -> str:
        """Build an optimized SQL query.

        For PIT tables the start date is shifted back by ``lookback_days``
        so that the first requested trading day can still be matched
        against the most recent earlier announcement.

        Args:
            table_name: table name
            fields: columns to select
            start_date: start date (YYYYMMDD)
            end_date: end date (YYYYMMDD)
            lookback_days: lookback window for PIT tables (default 90)

        Returns:
            The SQL statement as a string.

        Raises:
            ValueError: if ``table_name`` is not registered in the catalog.
        """
        metadata = self.catalog.get_table_metadata(table_name)
        if not metadata:
            raise ValueError(f"未知的表: {table_name}")

        fields_str = ", ".join(fields)
        date_field = metadata.date_field

        # Only the effective start date depends on the table kind.
        if metadata.frequency == TableFrequency.PIT:
            query_start = self._adjust_start_date(start_date, lookback_days)
        else:
            query_start = start_date

        start_fmt = self._format_date(query_start)
        end_fmt = self._format_date(end_date)

        sql = f"""
        SELECT {fields_str}
        FROM {table_name}
        WHERE {date_field} >= '{start_fmt}'
          AND {date_field} <= '{end_fmt}'
        ORDER BY {metadata.code_field}, {date_field}
        """
        return sql.strip()

    def _format_date(self, date_str: str) -> str:
        """Convert a YYYYMMDD date string into YYYY-MM-DD.

        Args:
            date_str: date string in YYYYMMDD format

        Returns:
            The same date in YYYY-MM-DD format.
        """
        return f"{date_str[:4]}-{date_str[4:6]}-{date_str[6:8]}"

    def _adjust_start_date(self, start_date: str, days: int) -> str:
        """Shift a start date back by the given number of calendar days.

        Args:
            start_date: start date (YYYYMMDD)
            days: number of days to look back

        Returns:
            The adjusted date (YYYYMMDD).
        """
        # Local import keeps this class self-contained regardless of the
        # module's top-level imports.
        from datetime import datetime, timedelta

        dt = datetime.strptime(start_date, "%Y%m%d")
        adjusted_dt = dt - timedelta(days=days)
        return adjusted_dt.strftime("%Y%m%d")
|
||||
|
||||
|
||||
def query_duckdb_to_polars(query: str, db_path: str) -> pl.LazyFrame:
    """Run a SQL statement against a DuckDB file, returning a Polars LazyFrame.

    Uses the DuckDB -> Polars fast conversion path
    (``connect().sql(query).pl()``) and defers further execution via
    ``.lazy()``.

    Args:
        query: SQL statement to execute
        db_path: path to the DuckDB database file

    Returns:
        A Polars LazyFrame over the query result.
    """
    connection = duckdb.connect(db_path)
    try:
        # Materialize via the native DuckDB->Polars bridge, then go lazy.
        return connection.sql(query).pl().lazy()
    finally:
        connection.close()
|
||||
|
||||
|
||||
def build_context_lazyframe(
    required_fields: List[str],
    start_date: str,
    end_date: str,
    db_uri: str,
    catalog: Optional[DatabaseCatalog] = None,
    lookback_days: int = 90,
) -> pl.LazyFrame:
    """Build the context LazyFrame: per-table SQL generation plus merging.

    Core logic:
    1. Resolve the set of tables involved from ``required_fields``
    2. Generate a minimal SQL query for each table
    3. Load the data from DuckDB into Polars LazyFrames
    4. Merge the per-table data:
       - DAILY tables are left-joined on ["trade_date", "ts_code"]
       - PIT tables are aligned by announcement date via ``join_asof``
    5. Sort the final result by ["ts_code", "trade_date"]

    Args:
        required_fields: fields to load
        start_date: start date (YYYYMMDD)
        end_date: end date (YYYYMMDD)
        db_uri: database URI (e.g. "duckdb:///data/prostock.db")
        catalog: catalog instance; auto-created (with table discovery)
            when None
        lookback_days: lookback window for PIT tables (default 90)

    Returns:
        Merged LazyFrame containing all requested fields.

    Example:
        >>> lf = build_context_lazyframe(
        ...     required_fields=["close", "vol", "basic_eps"],
        ...     start_date="20240101",
        ...     end_date="20240131",
        ...     db_uri="duckdb:///data/prostock.db"
        ... )
        >>> df = lf.collect()
    """
    # Resolve the filesystem path from the URI.
    db_path = db_uri.replace("duckdb://", "").lstrip("/")

    # Auto-create a catalog (and discover tables) when none was supplied.
    if catalog is None:
        catalog = DatabaseCatalog(db_path)

    # Tables touched by the requested fields.
    tables = catalog.get_required_tables(required_fields)

    if not tables:
        # No table matched: return an empty frame with just the key columns.
        return pl.LazyFrame({"ts_code": [], "trade_date": []})

    # Split DAILY tables from PIT tables.
    daily_tables: List[str] = []
    pit_tables: List[str] = []

    for table_name in tables:
        if catalog.is_pit_table(table_name):
            pit_tables.append(table_name)
        else:
            daily_tables.append(table_name)

    # SQL builder shared by both table kinds.
    query_builder = SQLQueryBuilder(catalog)

    # Load DAILY-table data.
    daily_lfs: Dict[str, pl.LazyFrame] = {}
    for table_name in daily_tables:
        fields = catalog.get_fields_for_table(table_name, required_fields)
        sql = query_builder.build_query(
            table_name=table_name,
            fields=fields,
            start_date=start_date,
            end_date=end_date,
        )
        print(f"[SQL] {sql[:100]}...")

        lf = query_duckdb_to_polars(sql, db_path)

        # Normalize columns: rename this table's date field to trade_date.
        metadata = catalog.get_table_metadata(table_name)
        if metadata and metadata.date_field != "trade_date":
            lf = lf.rename({metadata.date_field: "trade_date"})

        daily_lfs[table_name] = lf

    # Load PIT-table data.
    pit_lfs: Dict[str, pl.LazyFrame] = {}
    for table_name in pit_tables:
        fields = catalog.get_fields_for_table(table_name, required_fields)
        sql = query_builder.build_query(
            table_name=table_name,
            fields=fields,
            start_date=start_date,
            end_date=end_date,
            lookback_days=lookback_days,
        )
        print(f"[SQL] {sql[:100]}...")

        lf = query_duckdb_to_polars(sql, db_path)

        # PIT tables keep their original announcement-date field (needed
        # later as the right-hand key of join_asof).
        pit_lfs[table_name] = lf

    # Merge all DAILY tables (the first one becomes the base frame).
    result_lf: Optional[pl.LazyFrame] = None

    if daily_lfs:
        # Use the first daily table as the base.
        first_table = daily_tables[0]
        result_lf = daily_lfs[first_table]

        # Left-join the remaining daily tables onto the base.
        for table_name in daily_tables[1:]:
            lf = daily_lfs[table_name]
            result_lf = result_lf.join(lf, on=["trade_date", "ts_code"], how="left")
    elif pit_lfs:
        # No daily table: derive a base time axis from the first PIT table,
        # using its date range.
        first_pit = pit_tables[0]
        pit_metadata = catalog.get_table_metadata(first_pit)

        # Extract every (date, code) combination from that PIT table.
        result_lf = (
            pit_lfs[first_pit]
            .select([pl.col(pit_metadata.date_field).alias("trade_date"), "ts_code"])
            .unique()
        )

    # Still nothing to work with: return an empty frame.
    if result_lf is None:
        return pl.LazyFrame({"ts_code": [], "trade_date": []})

    # Merge PIT tables via join_asof, aligned on announcement date.
    for table_name in pit_tables:
        pit_metadata = catalog.get_table_metadata(table_name)
        lf = pit_lfs[table_name]

        # join_asof: grouped by ts_code, align PIT rows onto trading days.
        # Strategy "backward": take the latest announcement <= the trading day.
        # NOTE(review): join_asof requires both sides sorted on the join key
        # within each group — presumed guaranteed by the SQL ORDER BY and the
        # .unique() base above; verify, as .unique() may not preserve order.
        result_lf = result_lf.join_asof(
            lf,
            left_on="trade_date",
            right_on=pit_metadata.date_field,
            by="ts_code",
            strategy="backward",
        )

    # Final sort: ["ts_code", "trade_date"] as required by time-series ops.
    result_lf = result_lf.sort(["ts_code", "trade_date"])

    return result_lf
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Smoke tests for catalog discovery and SQL generation.
    print("=" * 60)
    print("DatabaseCatalog 自动发现测试")
    print("=" * 60)

    # Auto-discovery against the local database file.
    catalog = DatabaseCatalog("data/prostock.db")

    print("\n=== 测试字段到表映射 ===")
    print(f"字段 'close' 对应的表: {catalog.get_table_for_field('close')}")
    print(f"字段 'vol' 对应的表: {catalog.get_table_for_field('vol')}")
    print(f"字段 'pe' 对应的表: {catalog.get_table_for_field('pe')}")
    print(f"字段 'basic_eps' 对应的表: {catalog.get_table_for_field('basic_eps')}")

    print("\n=== 测试表频度类型 ===")
    for table_name in catalog.tables:
        freq = catalog.get_table_frequency(table_name)
        print(f"表 '{table_name}' 的频度: {freq.value if freq else 'Unknown'}")

    print("\n=== 测试 SQL 构建 ===")
    query_builder = SQLQueryBuilder(catalog)

    # DAILY-table query (no lookback).
    daily_sql = query_builder.build_query(
        table_name="daily",
        fields=["ts_code", "trade_date", "close", "vol"],
        start_date="20240101",
        end_date="20240131",
    )
    print(f"\nDAILY 表 SQL:\n{daily_sql}")

    # PIT-table query (90-day lookback).
    pit_sql = query_builder.build_query(
        table_name="financial_income",
        fields=["ts_code", "ann_date", "basic_eps", "total_revenue"],
        start_date="20240101",
        end_date="20240131",
        lookback_days=90,
    )
    print(f"\nPIT 表 SQL:\n{pit_sql}")

    print("\n=== 测试多表字段收集 ===")
    required_fields = ["close", "vol", "pe", "basic_eps", "total_revenue"]
    tables = catalog.get_required_tables(required_fields)
    print(f"字段 {required_fields} 涉及的表: {tables}")

    for table_name in tables:
        fields = catalog.get_fields_for_table(table_name, required_fields)
        print(f"  表 '{table_name}' 需要查询的字段: {fields}")

    print("\n所有测试通过!")
|
||||
448
src/factors/api.py
Normal file
448
src/factors/api.py
Normal file
@@ -0,0 +1,448 @@
|
||||
"""DSL API 层 - 提供常用的符号和函数。
|
||||
|
||||
该模块提供量化因子表达式中常用的符号(如 close, open 等)
|
||||
和函数(如 ts_mean, cs_rank 等),用户可以直接导入使用。
|
||||
|
||||
示例:
|
||||
>>> from src.factors.api import close, open, ts_mean, cs_rank
|
||||
>>> expr = ts_mean(close - open, 20) / close
|
||||
>>> print(expr)
|
||||
ts_mean(((close - open), 20)) / close
|
||||
"""
|
||||
|
||||
from src.factors.dsl import Symbol, FunctionNode, Node, _ensure_node
|
||||
from typing import Union
|
||||
|
||||
# ==================== Common price symbols ====================
# NOTE(review): `open` (and `abs`, further below) shadow Python builtins
# when star-imported — presumably intentional DSL surface; confirm.

#: Close price
close = Symbol("close")

#: Open price
open = Symbol("open")

#: High price
high = Symbol("high")

#: Low price
low = Symbol("low")

#: Trading volume
volume = Symbol("volume")

#: Trading amount (turnover value)
amount = Symbol("amount")

#: Previous close price
pre_close = Symbol("pre_close")

#: Price change (absolute)
change = Symbol("change")

#: Price change (percentage)
pct_change = Symbol("pct_change")
|
||||
|
||||
|
||||
# ==================== Time-series functions (ts_*) ====================


def ts_mean(x: Union[Node, str], window: int) -> FunctionNode:
    """Rolling mean over a time-series window.

    Args:
        x: input factor expression, or a field-name string
        window: rolling window size

    Returns:
        FunctionNode: the function-call node

    Example:
        >>> from src.factors.api import close, ts_mean
        >>> expr = ts_mean(close, 20)    # 20-day mean of close
        >>> expr = ts_mean("close", 20)  # string form
        >>> print(expr)
        ts_mean(close, 20)
    """
    return FunctionNode("ts_mean", x, window)


def ts_std(x: Union[Node, str], window: int) -> FunctionNode:
    """Rolling standard deviation over a time-series window.

    Args:
        x: input factor expression, or a field-name string
        window: rolling window size

    Returns:
        FunctionNode: the function-call node
    """
    return FunctionNode("ts_std", x, window)


def ts_max(x: Union[Node, str], window: int) -> FunctionNode:
    """Rolling maximum over a time-series window.

    Args:
        x: input factor expression, or a field-name string
        window: rolling window size

    Returns:
        FunctionNode: the function-call node
    """
    return FunctionNode("ts_max", x, window)


def ts_min(x: Union[Node, str], window: int) -> FunctionNode:
    """Rolling minimum over a time-series window.

    Args:
        x: input factor expression, or a field-name string
        window: rolling window size

    Returns:
        FunctionNode: the function-call node
    """
    return FunctionNode("ts_min", x, window)


def ts_sum(x: Union[Node, str], window: int) -> FunctionNode:
    """Rolling sum over a time-series window.

    Args:
        x: input factor expression, or a field-name string
        window: rolling window size

    Returns:
        FunctionNode: the function-call node
    """
    return FunctionNode("ts_sum", x, window)


def ts_delay(x: Union[Node, str], periods: int) -> FunctionNode:
    """Time-series lag: the factor's value ``periods`` steps ago.

    Args:
        x: input factor expression, or a field-name string
        periods: number of periods to lag by

    Returns:
        FunctionNode: the function-call node
    """
    return FunctionNode("ts_delay", x, periods)


def ts_delta(x: Union[Node, str], periods: int) -> FunctionNode:
    """Time-series difference: current value minus the value ``periods`` ago.

    Args:
        x: input factor expression, or a field-name string
        periods: differencing horizon

    Returns:
        FunctionNode: the function-call node
    """
    return FunctionNode("ts_delta", x, periods)


def ts_corr(x: Union[Node, str], y: Union[Node, str], window: int) -> FunctionNode:
    """Rolling correlation of two factors over a time-series window.

    Args:
        x: first factor expression, or a field-name string
        y: second factor expression, or a field-name string
        window: rolling window size

    Returns:
        FunctionNode: the function-call node
    """
    return FunctionNode("ts_corr", x, y, window)


def ts_cov(x: Union[Node, str], y: Union[Node, str], window: int) -> FunctionNode:
    """Rolling covariance of two factors over a time-series window.

    Args:
        x: first factor expression, or a field-name string
        y: second factor expression, or a field-name string
        window: rolling window size

    Returns:
        FunctionNode: the function-call node
    """
    return FunctionNode("ts_cov", x, y, window)


def ts_rank(x: Union[Node, str], window: int) -> FunctionNode:
    """Time-series rank: quantile rank of the current value within the window.

    Args:
        x: input factor expression, or a field-name string
        window: rolling window size

    Returns:
        FunctionNode: the function-call node
    """
    return FunctionNode("ts_rank", x, window)
|
||||
|
||||
|
||||
# ==================== Cross-sectional functions (cs_*) ====================


def cs_rank(x: Union[Node, str]) -> FunctionNode:
    """Cross-sectional rank (quantile) of the factor.

    Args:
        x: input factor expression, or a field-name string

    Returns:
        FunctionNode: the function-call node

    Example:
        >>> from src.factors.api import close, cs_rank
        >>> expr = cs_rank(close)    # cross-sectional rank of close
        >>> expr = cs_rank("close")  # string form
        >>> print(expr)
        cs_rank(close)
    """
    return FunctionNode("cs_rank", x)


def cs_zscore(x: Union[Node, str]) -> FunctionNode:
    """Cross-sectional z-score standardization of the factor.

    Args:
        x: input factor expression, or a field-name string

    Returns:
        FunctionNode: the function-call node
    """
    return FunctionNode("cs_zscore", x)


def cs_neutralize(
    x: Union[Node, str], group: Union[Symbol, str, None] = None
) -> FunctionNode:
    """Cross-sectional neutralization (e.g. industry/size) of the factor.

    Args:
        x: input factor expression, or a field-name string
        group: grouping variable (e.g. industry classification) as a
            Symbol or string; defaults to None (no grouping argument)

    Returns:
        FunctionNode: the function-call node
    """
    # Only forward the group argument when one was supplied.
    if group is not None:
        return FunctionNode("cs_neutralize", x, group)
    return FunctionNode("cs_neutralize", x)


def cs_winsorize(
    x: Union[Node, str], lower: float = 0.01, upper: float = 0.99
) -> FunctionNode:
    """Cross-sectional winsorization: clip extreme values at quantiles.

    Args:
        x: input factor expression, or a field-name string
        lower: lower-tail quantile, default 0.01
        upper: upper-tail quantile, default 0.99

    Returns:
        FunctionNode: the function-call node
    """
    return FunctionNode("cs_winsorize", x, lower, upper)


def cs_demean(x: Union[Node, str]) -> FunctionNode:
    """Cross-sectional demeaning: subtract the cross-section mean.

    Args:
        x: input factor expression, or a field-name string

    Returns:
        FunctionNode: the function-call node
    """
    return FunctionNode("cs_demean", x)
|
||||
|
||||
|
||||
# ==================== Math functions ====================


def log(x: Union[Node, str]) -> FunctionNode:
    """Natural logarithm.

    Args:
        x: input factor expression, or a field-name string

    Returns:
        FunctionNode: the function-call node
    """
    return FunctionNode("log", x)


def exp(x: Union[Node, str]) -> FunctionNode:
    """Exponential function.

    Args:
        x: input factor expression, or a field-name string

    Returns:
        FunctionNode: the function-call node
    """
    return FunctionNode("exp", x)


def sqrt(x: Union[Node, str]) -> FunctionNode:
    """Square root.

    Args:
        x: input factor expression, or a field-name string

    Returns:
        FunctionNode: the function-call node
    """
    return FunctionNode("sqrt", x)


def sign(x: Union[Node, str]) -> FunctionNode:
    """Sign function: yields -1, 0 or 1 for the input's sign.

    Args:
        x: input factor expression, or a field-name string

    Returns:
        FunctionNode: the function-call node
    """
    return FunctionNode("sign", x)


def abs(x: Union[Node, str]) -> FunctionNode:
    """Absolute value.

    NOTE(review): shadows the builtin ``abs`` when star-imported —
    presumably intentional DSL surface; confirm.

    Args:
        x: input factor expression, or a field-name string

    Returns:
        FunctionNode: the function-call node
    """
    return FunctionNode("abs", x)


def max_(x: Union[Node, str], y: Union[Node, str, int, float]) -> FunctionNode:
    """Element-wise maximum.

    NOTE(review): only ``y`` goes through ``_ensure_node`` here while
    ``x`` is passed raw, as in the other wrappers — presumably
    FunctionNode normalizes its args; confirm in dsl.FunctionNode.

    Args:
        x: first factor expression, or a field-name string
        y: second factor expression, field-name string, or numeric value

    Returns:
        FunctionNode: the function-call node
    """
    return FunctionNode("max", x, _ensure_node(y))


def min_(x: Union[Node, str], y: Union[Node, str, int, float]) -> FunctionNode:
    """Element-wise minimum.

    Args:
        x: first factor expression, or a field-name string
        y: second factor expression, field-name string, or numeric value

    Returns:
        FunctionNode: the function-call node
    """
    return FunctionNode("min", x, _ensure_node(y))


def clip(
    x: Union[Node, str],
    lower: Union[Node, str, int, float],
    upper: Union[Node, str, int, float],
) -> FunctionNode:
    """Clamp the factor's values into the [lower, upper] range.

    Args:
        x: input factor expression, or a field-name string
        lower: lower bound (expression, field-name string, or numeric)
        upper: upper bound (expression, field-name string, or numeric)

    Returns:
        FunctionNode: the function-call node
    """
    return FunctionNode("clip", x, _ensure_node(lower), _ensure_node(upper))
|
||||
|
||||
|
||||
# ==================== Conditional functions ====================


def if_(
    condition: Union[Node, str],
    true_val: Union[Node, str, int, float],
    false_val: Union[Node, str, int, float],
) -> FunctionNode:
    """Conditional selection: pick a value based on a condition.

    Args:
        condition: condition expression, or a field-name string
        true_val: value when the condition holds (expression, field-name
            string, or numeric)
        false_val: value when the condition fails (expression, field-name
            string, or numeric)

    Returns:
        FunctionNode: the function-call node
    """
    return FunctionNode(
        "if", condition, _ensure_node(true_val), _ensure_node(false_val)
    )


def where(
    condition: Union[Node, str],
    true_val: Union[Node, str, int, float],
    false_val: Union[Node, str, int, float],
) -> FunctionNode:
    """Conditional selection (alias of ``if_``).

    Args:
        condition: condition expression, or a field-name string
        true_val: value when the condition holds (expression, field-name
            string, or numeric)
        false_val: value when the condition fails (expression, field-name
            string, or numeric)

    Returns:
        FunctionNode: the function-call node
    """
    return if_(condition, true_val, false_val)
|
||||
159
src/factors/compiler.py
Normal file
159
src/factors/compiler.py
Normal file
@@ -0,0 +1,159 @@
|
||||
"""AST 编译器模块 - 提供依赖提取和代码生成功能。
|
||||
|
||||
本模块实现 AST 遍历器模式,用于从 DSL 表达式中提取依赖的符号。
|
||||
"""
|
||||
|
||||
from typing import Set
|
||||
|
||||
from src.factors.dsl import Node, Symbol, BinaryOpNode, UnaryOpNode, FunctionNode
|
||||
|
||||
|
||||
class DependencyExtractor:
    """AST visitor that collects the symbol names an expression depends on.

    Walks the expression tree recursively and records the name of every
    Symbol node encountered. BinaryOpNode, UnaryOpNode and FunctionNode
    are traversed into their children; constant nodes contribute nothing.

    Example:
        >>> from src.factors.dsl import Symbol, FunctionNode
        >>> close = Symbol("close")
        >>> pe_ratio = Symbol("pe_ratio")
        >>> alpha = FunctionNode("cs_rank", close / pe_ratio)
        >>> deps = DependencyExtractor.extract_dependencies(alpha)
        >>> print(deps)
        {'close', 'pe_ratio'}
    """

    def __init__(self) -> None:
        """Start with an empty dependency set."""
        self.dependencies: Set[str] = set()

    def visit(self, node: Node) -> None:
        """Dispatch on the concrete node type and recurse into children.

        Args:
            node: any AST node
        """
        if isinstance(node, Symbol):
            self._visit_symbol(node)
            return
        if isinstance(node, BinaryOpNode):
            self._visit_binary_op(node)
            return
        if isinstance(node, UnaryOpNode):
            self._visit_unary_op(node)
            return
        if isinstance(node, FunctionNode):
            self._visit_function(node)
        # Constant nodes carry no dependencies — nothing to record.

    def _visit_symbol(self, node: Symbol) -> None:
        """Record the symbol's name as a dependency.

        Args:
            node: symbol node
        """
        self.dependencies.add(node.name)

    def _visit_binary_op(self, node: BinaryOpNode) -> None:
        """Recurse into both operands of a binary operation.

        Args:
            node: binary-operation node
        """
        for child in (node.left, node.right):
            self.visit(child)

    def _visit_unary_op(self, node: UnaryOpNode) -> None:
        """Recurse into the single operand of a unary operation.

        Args:
            node: unary-operation node
        """
        self.visit(node.operand)

    def _visit_function(self, node: FunctionNode) -> None:
        """Recurse into every argument of a function call.

        Args:
            node: function-call node
        """
        for argument in node.args:
            self.visit(argument)

    def extract(self, node: Node) -> Set[str]:
        """Collect all symbol names referenced under ``node``.

        Args:
            node: root of the expression tree

        Returns:
            Set of dependency names.
        """
        self.dependencies.clear()
        self.visit(node)
        return set(self.dependencies)

    @classmethod
    def extract_dependencies(cls, node: Node) -> Set[str]:
        """One-shot convenience: build an extractor and run it on ``node``.

        Args:
            node: root of the expression tree

        Returns:
            Set of dependency names.

        Example:
            >>> from src.factors.dsl import Symbol
            >>> expr = Symbol("close") / Symbol("open")
            >>> DependencyExtractor.extract_dependencies(expr)
            {'close', 'open'}
        """
        return cls().extract(node)
|
||||
|
||||
|
||||
def extract_dependencies(node: Node) -> Set[str]:
    """Module-level convenience wrapper around DependencyExtractor.

    Args:
        node: root of the expression tree

    Returns:
        Set of symbol names the expression depends on.

    Example:
        >>> from src.factors.dsl import Symbol, FunctionNode
        >>> alpha = FunctionNode("cs_rank", Symbol("close") / Symbol("pe_ratio"))
        >>> deps = extract_dependencies(alpha)
        >>> print(deps)
        {'close', 'pe_ratio'}
    """
    extractor = DependencyExtractor()
    return extractor.extract(node)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Smoke test: cs_rank(close / pe_ratio)
    from src.factors.dsl import Symbol, FunctionNode

    # Build the symbols.
    close = Symbol("close")
    pe_ratio = Symbol("pe_ratio")

    # Build the expression: cs_rank(close / pe_ratio)
    alpha = FunctionNode("cs_rank", close / pe_ratio)

    # Extract the dependencies.
    dependencies = extract_dependencies(alpha)

    print(f"表达式: {alpha}")
    print(f"提取的依赖: {dependencies}")
    print(f"期望依赖: {{'close', 'pe_ratio'}}")
    print(f"验证结果: {dependencies == {'close', 'pe_ratio'}}")
|
||||
278
src/factors/dsl.py
Normal file
278
src/factors/dsl.py
Normal file
@@ -0,0 +1,278 @@
|
||||
"""DSL 表达式层 - 纯 Python 实现,无 pandas/polars 依赖。
|
||||
|
||||
提供因子表达式的符号化表示能力,通过重载运算符实现
|
||||
用户端无感知的公式编写。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, List, Union
|
||||
|
||||
|
||||
class Node(ABC):
    """Base class for expression-tree nodes.

    Abstract base for every factor-expression component. The operator
    overloads below do not compute values: each one builds and returns a
    new tree node, so ordinary Python syntax composes symbolic formulas.
    Subclasses must implement ``__repr__`` for expression visualization.

    NOTE: defining ``__eq__`` on this class implicitly sets
    ``__hash__ = None``, so subclasses are unhashable unless they define
    ``__hash__`` themselves (``Symbol`` does; the other nodes do not).
    """

    # ==================== Arithmetic operator overloads ====================

    def __add__(self, other: Any) -> BinaryOpNode:
        """Addition: self + other."""
        return BinaryOpNode("+", self, _ensure_node(other))

    def __radd__(self, other: Any) -> BinaryOpNode:
        """Reflected addition: other + self."""
        return BinaryOpNode("+", _ensure_node(other), self)

    def __sub__(self, other: Any) -> BinaryOpNode:
        """Subtraction: self - other."""
        return BinaryOpNode("-", self, _ensure_node(other))

    def __rsub__(self, other: Any) -> BinaryOpNode:
        """Reflected subtraction: other - self."""
        return BinaryOpNode("-", _ensure_node(other), self)

    def __mul__(self, other: Any) -> BinaryOpNode:
        """Multiplication: self * other."""
        return BinaryOpNode("*", self, _ensure_node(other))

    def __rmul__(self, other: Any) -> BinaryOpNode:
        """Reflected multiplication: other * self."""
        return BinaryOpNode("*", _ensure_node(other), self)

    def __truediv__(self, other: Any) -> BinaryOpNode:
        """Division: self / other."""
        return BinaryOpNode("/", self, _ensure_node(other))

    def __rtruediv__(self, other: Any) -> BinaryOpNode:
        """Reflected division: other / self."""
        return BinaryOpNode("/", _ensure_node(other), self)

    def __pow__(self, other: Any) -> BinaryOpNode:
        """Exponentiation: self ** other."""
        return BinaryOpNode("**", self, _ensure_node(other))

    def __rpow__(self, other: Any) -> BinaryOpNode:
        """Reflected exponentiation: other ** self."""
        return BinaryOpNode("**", _ensure_node(other), self)

    def __floordiv__(self, other: Any) -> BinaryOpNode:
        """Floor division: self // other."""
        return BinaryOpNode("//", self, _ensure_node(other))

    def __rfloordiv__(self, other: Any) -> BinaryOpNode:
        """Reflected floor division: other // self."""
        return BinaryOpNode("//", _ensure_node(other), self)

    def __mod__(self, other: Any) -> BinaryOpNode:
        """Modulo: self % other."""
        return BinaryOpNode("%", self, _ensure_node(other))

    def __rmod__(self, other: Any) -> BinaryOpNode:
        """Reflected modulo: other % self."""
        return BinaryOpNode("%", _ensure_node(other), self)

    # ==================== Unary operator overloads ====================

    def __neg__(self) -> UnaryOpNode:
        """Negation: -self."""
        return UnaryOpNode("-", self)

    def __pos__(self) -> UnaryOpNode:
        """Unary plus: +self."""
        return UnaryOpNode("+", self)

    def __abs__(self) -> UnaryOpNode:
        """Absolute value: abs(self)."""
        return UnaryOpNode("abs", self)

    # ==================== Comparison operator overloads ====================
    # These intentionally return expression nodes, not booleans.

    def __eq__(self, other: Any) -> BinaryOpNode:
        """Equality: self == other (builds an expression node)."""
        return BinaryOpNode("==", self, _ensure_node(other))

    def __ne__(self, other: Any) -> BinaryOpNode:
        """Inequality: self != other (builds an expression node)."""
        return BinaryOpNode("!=", self, _ensure_node(other))

    def __lt__(self, other: Any) -> BinaryOpNode:
        """Less-than: self < other."""
        return BinaryOpNode("<", self, _ensure_node(other))

    def __le__(self, other: Any) -> BinaryOpNode:
        """Less-than-or-equal: self <= other."""
        return BinaryOpNode("<=", self, _ensure_node(other))

    def __gt__(self, other: Any) -> BinaryOpNode:
        """Greater-than: self > other."""
        return BinaryOpNode(">", self, _ensure_node(other))

    def __ge__(self, other: Any) -> BinaryOpNode:
        """Greater-than-or-equal: self >= other."""
        return BinaryOpNode(">=", self, _ensure_node(other))

    # ==================== Abstract methods ====================

    @abstractmethod
    def __repr__(self) -> str:
        """Return the string rendering of this (sub)expression."""
        pass
|
||||
|
||||
|
||||
class Symbol(Node):
    """Symbol node representing a named variable (e.g. close, open).

    Attributes:
        name: Identifier of the variable this symbol stands for.

    NOTE(review): ``__eq__`` below overrides ``Node.__eq__``, so using
    ``==`` on a Symbol compares names (returning a bool or
    NotImplemented) instead of building a ``==`` expression node as it
    does for every other node type — confirm this asymmetry is intended.
    """

    def __init__(self, name: str) -> None:
        """Create a named symbol.

        Args:
            name: Symbol name, e.g. 'close', 'open', 'volume'.
        """
        self.name = name

    def __repr__(self) -> str:
        """Return the bare symbol name."""
        return self.name

    def __hash__(self) -> int:
        """Hash by name so symbols can be used as dict keys."""
        return hash(self.name)

    def __eq__(self, other: object) -> bool:
        """Name-based equality between two Symbol instances."""
        if not isinstance(other, Symbol):
            return NotImplemented
        return self.name == other.name
|
||||
|
||||
|
||||
class Constant(Node):
    """Leaf node wrapping a numeric literal.

    Attributes:
        value: The wrapped numeric value.
    """

    def __init__(self, value: Union[int, float]) -> None:
        """Wrap a numeric literal.

        Args:
            value: The number this node represents.
        """
        self.value = value

    def __repr__(self) -> str:
        """Render the literal as text."""
        return f"{self.value}"
|
||||
|
||||
|
||||
class BinaryOpNode(Node):
    """An infix operation applied to two sub-expressions.

    Attributes:
        op: Operator token such as '+', '-', '*', '/'.
        left: Left-hand operand node.
        right: Right-hand operand node.
    """

    def __init__(self, op: str, left: Node, right: Node) -> None:
        """Record the operator token together with both operand nodes.

        Args:
            op: Operator token string.
            left: Left-hand operand node.
            right: Right-hand operand node.
        """
        self.op = op
        self.left = left
        self.right = right

    def __repr__(self) -> str:
        """Render as a fully parenthesized infix expression."""
        return "({} {} {})".format(self.left, self.op, self.right)
|
||||
|
||||
|
||||
class UnaryOpNode(Node):
    """A single-operand operation such as negation or abs().

    Attributes:
        op: Operator token such as '-', '+', or 'abs'.
        operand: The wrapped operand node.
    """

    def __init__(self, op: str, operand: Node) -> None:
        """Record the operator token and its operand node.

        Args:
            op: Operator token string.
            operand: Operand node.
        """
        self.op = op
        self.operand = operand

    def __repr__(self) -> str:
        """Render sign operators as prefixes, everything else as a call."""
        is_sign = self.op in {"+", "-"}
        if is_sign:
            return f"({self.op}{self.operand})"
        return f"{self.op}({self.operand})"
|
||||
|
||||
|
||||
class FunctionNode(Node):
    """A named function applied to a list of argument expressions.

    Attributes:
        func_name: Name of the function, e.g. 'ts_mean', 'cs_rank'.
        args: Argument nodes (raw numbers/strings are auto-promoted).
    """

    def __init__(self, func_name: str, *args: Any) -> None:
        """Record the function name and promote every argument to a Node.

        Args:
            func_name: Function identifier such as 'ts_mean' or 'cs_rank'.
            *args: Arguments; each may be a Node or a promotable raw value.
        """
        self.func_name = func_name
        # Promote plain numbers / strings into proper tree nodes.
        promoted: List[Node] = []
        for raw in args:
            promoted.append(_ensure_node(raw))
        self.args: List[Node] = promoted

    def __repr__(self) -> str:
        """Render as a call expression: name(arg1, arg2, ...)."""
        rendered = ", ".join(repr(a) for a in self.args)
        return f"{self.func_name}({rendered})"
|
||||
|
||||
|
||||
# ==================== 辅助函数 ====================
|
||||
|
||||
|
||||
def _ensure_node(value: Any) -> Node:
    """Coerce an arbitrary value into an expression Node.

    Nodes pass through unchanged, numbers become Constant leaves and
    strings become Symbol leaves; anything else is rejected.

    Args:
        value: The raw value to promote.

    Returns:
        Node: The value itself or a freshly wrapped leaf node.

    Raises:
        TypeError: If the value cannot be represented as a node.
    """
    if isinstance(value, Node):
        node = value
    elif isinstance(value, (int, float)):
        node = Constant(value)
    elif isinstance(value, str):
        node = Symbol(value)
    else:
        raise TypeError(f"无法将类型 {type(value).__name__} 转换为 Node")
    return node
|
||||
387
src/factors/translator.py
Normal file
387
src/factors/translator.py
Normal file
@@ -0,0 +1,387 @@
|
||||
"""Polars 翻译器 - 将 AST 翻译为 Polars 表达式。
|
||||
|
||||
本模块实现 DSL 到 Polars 计算图的映射,是因子表达式执行的桥梁。
|
||||
支持时序因子(ts_*)和截面因子(cs_*)的防错分组翻译。
|
||||
"""
|
||||
|
||||
from typing import Any, Callable, Dict
|
||||
|
||||
import polars as pl
|
||||
|
||||
from src.factors.dsl import (
|
||||
BinaryOpNode,
|
||||
Constant,
|
||||
FunctionNode,
|
||||
Node,
|
||||
Symbol,
|
||||
UnaryOpNode,
|
||||
)
|
||||
|
||||
|
||||
class PolarsTranslator:
    """Translator from DSL AST nodes to Polars expressions.

    Maps the plain-object AST onto a Polars computation graph with
    guard-railed grouping: time-series operators (ts_*) are windowed per
    asset via ``over("ts_code")`` and cross-sectional operators (cs_*)
    per trading day via ``over("trade_date")``.

    Attributes:
        handlers: Registry mapping func_name to a handler that turns a
            FunctionNode into a ``pl.Expr``.

    Example:
        >>> from src.factors.dsl import Symbol, FunctionNode
        >>> close = Symbol("close")
        >>> expr = FunctionNode("ts_mean", close, 20)
        >>> translator = PolarsTranslator()
        >>> polars_expr = translator.translate(expr)
        >>> # result: pl.col("close").rolling_mean(20).over("ts_code")
    """

    def __init__(self) -> None:
        """Initialize the translator and register the built-in handlers."""
        self.handlers: Dict[str, Callable[[FunctionNode], pl.Expr]] = {}
        self._register_builtin_handlers()

    def _register_builtin_handlers(self) -> None:
        """Register the built-in function handlers."""
        # Time-series handlers (ts_*)
        self.register_handler("ts_mean", self._handle_ts_mean)
        self.register_handler("ts_sum", self._handle_ts_sum)
        self.register_handler("ts_std", self._handle_ts_std)
        self.register_handler("ts_max", self._handle_ts_max)
        self.register_handler("ts_min", self._handle_ts_min)
        self.register_handler("ts_delay", self._handle_ts_delay)
        self.register_handler("ts_delta", self._handle_ts_delta)
        self.register_handler("ts_corr", self._handle_ts_corr)
        self.register_handler("ts_cov", self._handle_ts_cov)

        # Cross-sectional handlers (cs_*)
        self.register_handler("cs_rank", self._handle_cs_rank)
        self.register_handler("cs_zscore", self._handle_cs_zscore)
        self.register_handler("cs_neutral", self._handle_cs_neutral)

    def register_handler(
        self, func_name: str, handler: Callable[[FunctionNode], pl.Expr]
    ) -> None:
        """Register a custom function handler.

        Args:
            func_name: Function name to dispatch on.
            handler: Callable taking a FunctionNode and returning a pl.Expr.

        Example:
            >>> def handle_custom(node: FunctionNode) -> pl.Expr:
            ...     arg = self.translate(node.args[0])
            ...     return arg * 2
            >>> translator.register_handler("custom", handle_custom)
        """
        self.handlers[func_name] = handler

    def translate(self, node: Node) -> pl.Expr:
        """Recursively translate an AST node into a Polars expression.

        Args:
            node: AST node (Symbol, Constant, BinaryOpNode, UnaryOpNode
                or FunctionNode).

        Returns:
            The equivalent Polars expression object.

        Raises:
            TypeError: If an unknown node type is encountered.
        """
        if isinstance(node, Symbol):
            return self._translate_symbol(node)
        elif isinstance(node, Constant):
            return self._translate_constant(node)
        elif isinstance(node, BinaryOpNode):
            return self._translate_binary_op(node)
        elif isinstance(node, UnaryOpNode):
            return self._translate_unary_op(node)
        elif isinstance(node, FunctionNode):
            return self._translate_function(node)
        else:
            raise TypeError(f"未知的节点类型: {type(node).__name__}")

    def _translate_symbol(self, node: Symbol) -> pl.Expr:
        """Translate a Symbol node into a ``pl.col()`` expression.

        Args:
            node: Symbol node.

        Returns:
            ``pl.col(node.name)`` expression.
        """
        return pl.col(node.name)

    def _translate_constant(self, node: Constant) -> pl.Expr:
        """Translate a Constant node into a Polars literal.

        Args:
            node: Constant node.

        Returns:
            ``pl.lit(node.value)`` expression.
        """
        return pl.lit(node.value)

    def _translate_binary_op(self, node: BinaryOpNode) -> pl.Expr:
        """Translate a BinaryOpNode into a Polars binary operation.

        Args:
            node: Binary-operation node.

        Returns:
            The combined Polars expression.

        Raises:
            ValueError: If the operator token is not supported.
        """
        left = self.translate(node.left)
        right = self.translate(node.right)

        # FIX: the '//' case previously called `floor_div`, which is not a
        # polars Expr method (the method is `floordiv`); the `//` operator
        # is the portable spelling.
        op_map = {
            "+": lambda a, b: a + b,
            "-": lambda a, b: a - b,
            "*": lambda a, b: a * b,
            "/": lambda a, b: a / b,
            "**": lambda a, b: a.pow(b),
            "//": lambda a, b: a // b,
            "%": lambda a, b: a % b,
            "==": lambda a, b: a.eq(b),
            "!=": lambda a, b: a.ne(b),
            "<": lambda a, b: a.lt(b),
            "<=": lambda a, b: a.le(b),
            ">": lambda a, b: a.gt(b),
            ">=": lambda a, b: a.ge(b),
        }

        if node.op not in op_map:
            raise ValueError(f"不支持的二元运算符: {node.op}")

        return op_map[node.op](left, right)

    def _translate_unary_op(self, node: UnaryOpNode) -> pl.Expr:
        """Translate a UnaryOpNode into a Polars unary operation.

        Args:
            node: Unary-operation node.

        Returns:
            The resulting Polars expression.

        Raises:
            ValueError: If the operator token is not supported.
        """
        operand = self.translate(node.operand)

        op_map = {
            "+": lambda x: x,
            "-": lambda x: -x,
            "abs": lambda x: x.abs(),
        }

        if node.op not in op_map:
            raise ValueError(f"不支持的一元运算符: {node.op}")

        return op_map[node.op](operand)

    def _translate_function(self, node: FunctionNode) -> pl.Expr:
        """Translate a FunctionNode via the handler registry.

        Looks the function name up in ``handlers``; raises if unregistered.

        Args:
            node: Function-call node.

        Returns:
            The Polars expression produced by the matching handler.

        Raises:
            ValueError: If no handler is registered for the function name.
        """
        func_name = node.func_name

        if func_name in self.handlers:
            return self.handlers[func_name](node)
        else:
            raise ValueError(
                f"未注册的函数: {func_name}. 请使用 register_handler 注册处理器。"
            )

    # ==================== Time-series handlers (ts_*) ====================
    # Every ts_* handler forces over("ts_code") so windows never leak
    # across assets.

    def _handle_ts_mean(self, node: FunctionNode) -> pl.Expr:
        """ts_mean(close, window) -> rolling_mean(window).over(ts_code)."""
        if len(node.args) != 2:
            raise ValueError("ts_mean 需要 2 个参数: (expr, window)")
        expr = self.translate(node.args[0])
        window = self._extract_window(node.args[1])
        return expr.rolling_mean(window_size=window).over("ts_code")

    def _handle_ts_sum(self, node: FunctionNode) -> pl.Expr:
        """ts_sum(close, window) -> rolling_sum(window).over(ts_code)."""
        if len(node.args) != 2:
            raise ValueError("ts_sum 需要 2 个参数: (expr, window)")
        expr = self.translate(node.args[0])
        window = self._extract_window(node.args[1])
        return expr.rolling_sum(window_size=window).over("ts_code")

    def _handle_ts_std(self, node: FunctionNode) -> pl.Expr:
        """ts_std(close, window) -> rolling_std(window).over(ts_code)."""
        if len(node.args) != 2:
            raise ValueError("ts_std 需要 2 个参数: (expr, window)")
        expr = self.translate(node.args[0])
        window = self._extract_window(node.args[1])
        return expr.rolling_std(window_size=window).over("ts_code")

    def _handle_ts_max(self, node: FunctionNode) -> pl.Expr:
        """ts_max(close, window) -> rolling_max(window).over(ts_code)."""
        if len(node.args) != 2:
            raise ValueError("ts_max 需要 2 个参数: (expr, window)")
        expr = self.translate(node.args[0])
        window = self._extract_window(node.args[1])
        return expr.rolling_max(window_size=window).over("ts_code")

    def _handle_ts_min(self, node: FunctionNode) -> pl.Expr:
        """ts_min(close, window) -> rolling_min(window).over(ts_code)."""
        if len(node.args) != 2:
            raise ValueError("ts_min 需要 2 个参数: (expr, window)")
        expr = self.translate(node.args[0])
        window = self._extract_window(node.args[1])
        return expr.rolling_min(window_size=window).over("ts_code")

    def _handle_ts_delay(self, node: FunctionNode) -> pl.Expr:
        """ts_delay(close, n) -> shift(n).over(ts_code)."""
        if len(node.args) != 2:
            raise ValueError("ts_delay 需要 2 个参数: (expr, n)")
        expr = self.translate(node.args[0])
        n = self._extract_window(node.args[1])
        return expr.shift(n).over("ts_code")

    def _handle_ts_delta(self, node: FunctionNode) -> pl.Expr:
        """ts_delta(close, n) -> (expr - shift(n)).over(ts_code)."""
        if len(node.args) != 2:
            raise ValueError("ts_delta 需要 2 个参数: (expr, n)")
        expr = self.translate(node.args[0])
        n = self._extract_window(node.args[1])
        return (expr - expr.shift(n)).over("ts_code")

    def _handle_ts_corr(self, node: FunctionNode) -> pl.Expr:
        """ts_corr(x, y, window) -> rolling_corr(y, window).over(ts_code)."""
        if len(node.args) != 3:
            raise ValueError("ts_corr 需要 3 个参数: (x, y, window)")
        x = self.translate(node.args[0])
        y = self.translate(node.args[1])
        window = self._extract_window(node.args[2])
        return x.rolling_corr(y, window_size=window).over("ts_code")

    def _handle_ts_cov(self, node: FunctionNode) -> pl.Expr:
        """ts_cov(x, y, window) -> rolling_cov(y, window).over(ts_code)."""
        if len(node.args) != 3:
            raise ValueError("ts_cov 需要 3 个参数: (x, y, window)")
        x = self.translate(node.args[0])
        y = self.translate(node.args[1])
        window = self._extract_window(node.args[2])
        return x.rolling_cov(y, window_size=window).over("ts_code")

    # ==================== Cross-sectional handlers (cs_*) ====================
    # Every cs_* handler forces over("trade_date") so each day is ranked
    # independently.

    def _handle_cs_rank(self, node: FunctionNode) -> pl.Expr:
        """cs_rank(expr) -> rank()/count().over(trade_date).

        Normalizes the rank into the [0, 1] interval.
        """
        if len(node.args) != 1:
            raise ValueError("cs_rank 需要 1 个参数: (expr)")
        expr = self.translate(node.args[0])
        return (expr.rank() / expr.count()).over("trade_date")

    def _handle_cs_zscore(self, node: FunctionNode) -> pl.Expr:
        """cs_zscore(expr) -> (expr - mean())/std().over(trade_date)."""
        if len(node.args) != 1:
            raise ValueError("cs_zscore 需要 1 个参数: (expr)")
        expr = self.translate(node.args[0])
        return ((expr - expr.mean()) / expr.std()).over("trade_date")

    def _handle_cs_neutral(self, node: FunctionNode) -> pl.Expr:
        """cs_neutral(expr, group) -> group neutralization."""
        if len(node.args) not in [1, 2]:
            raise ValueError("cs_neutral 需要 1-2 个参数: (expr, [group_col])")
        expr = self.translate(node.args[0])
        # Simple implementation: demean per trading day; the optional
        # group_col argument is accepted but currently ignored (future
        # extension point for industry/group neutralization).
        return (expr - expr.mean()).over("trade_date")

    # ==================== Helpers ====================

    def _extract_window(self, node: Node) -> int:
        """Extract an integer window size from a Constant node.

        Args:
            node: Expected to be a Constant node holding an int.

        Returns:
            The integer window size.

        Raises:
            ValueError: If the node is not a Constant or its value is not
                an integer.
        """
        if isinstance(node, Constant):
            if not isinstance(node.value, int):
                raise ValueError(
                    f"窗口参数必须是整数,得到: {type(node.value).__name__}"
                )
            return node.value
        raise ValueError(f"窗口参数必须是常量整数,得到: {type(node).__name__}")
|
||||
|
||||
|
||||
def translate_to_polars(node: Node) -> pl.Expr:
    """Convenience helper - translate an AST node into a Polars expression.

    Args:
        node: Root node of the expression tree.

    Returns:
        The equivalent Polars expression object.

    Example:
        >>> from src.factors.dsl import Symbol, FunctionNode
        >>> close = Symbol("close")
        >>> expr = FunctionNode("ts_mean", close, 20)
        >>> polars_expr = translate_to_polars(expr)
    """
    return PolarsTranslator().translate(node)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Self-test exercised when this module is run directly.
    from src.factors.dsl import Symbol, FunctionNode

    # Create symbols
    close = Symbol("close")
    volume = Symbol("volume")

    # Test 1: plain symbol
    print("测试 1: Symbol")
    translator = PolarsTranslator()
    expr1 = translator.translate(close)
    print(f" close -> {expr1}")
    assert str(expr1) == 'col("close")'

    # Test 2: binary operation
    print("\n测试 2: BinaryOp")
    expr2 = translator.translate(close + 10)
    print(f" close + 10 -> {expr2}")

    # Test 3: ts_mean
    print("\n测试 3: ts_mean")
    expr3 = translator.translate(FunctionNode("ts_mean", close, 20))
    print(f" ts_mean(close, 20) -> {expr3}")

    # Test 4: cs_rank
    print("\n测试 4: cs_rank")
    expr4 = translator.translate(FunctionNode("cs_rank", close / volume))
    print(f" cs_rank(close / volume) -> {expr4}")

    # Test 5: composite expression
    print("\n测试 5: 复杂表达式")
    ma20 = FunctionNode("ts_mean", close, 20)
    ma60 = FunctionNode("ts_mean", close, 60)
    expr5 = translator.translate(FunctionNode("cs_rank", ma20 - ma60))
    print(f" cs_rank(ts_mean(close, 20) - ts_mean(close, 60)) -> {expr5}")

    print("\n✅ 所有测试通过!")
|
||||
325
tests/factors/test_dsl_promotion.py
Normal file
325
tests/factors/test_dsl_promotion.py
Normal file
@@ -0,0 +1,325 @@
|
||||
"""测试 DSL 字符串自动提升(Promotion)功能。
|
||||
|
||||
验证以下功能:
|
||||
1. 字符串自动转换为 Symbol
|
||||
2. 算子函数支持字符串参数
|
||||
3. 右位运算支持
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from src.factors.dsl import (
|
||||
Symbol,
|
||||
Constant,
|
||||
BinaryOpNode,
|
||||
UnaryOpNode,
|
||||
FunctionNode,
|
||||
_ensure_node,
|
||||
)
|
||||
from src.factors.api import (
|
||||
close,
|
||||
open,
|
||||
ts_mean,
|
||||
ts_std,
|
||||
ts_corr,
|
||||
cs_rank,
|
||||
cs_zscore,
|
||||
log,
|
||||
exp,
|
||||
max_,
|
||||
min_,
|
||||
clip,
|
||||
if_,
|
||||
where,
|
||||
)
|
||||
|
||||
|
||||
class TestEnsureNode:
    """Tests for the _ensure_node helper."""

    def test_ensure_node_with_node(self):
        """A Node instance must be returned unchanged (same object)."""
        sym = Symbol("close")
        result = _ensure_node(sym)
        assert result is sym

    def test_ensure_node_with_int(self):
        """An int must be wrapped in a Constant."""
        result = _ensure_node(100)
        assert isinstance(result, Constant)
        assert result.value == 100

    def test_ensure_node_with_float(self):
        """A float must be wrapped in a Constant."""
        result = _ensure_node(3.14)
        assert isinstance(result, Constant)
        assert result.value == 3.14

    def test_ensure_node_with_str(self):
        """A string must be promoted to a Symbol."""
        result = _ensure_node("close")
        assert isinstance(result, Symbol)
        assert result.name == "close"

    def test_ensure_node_with_invalid_type(self):
        """Unsupported types must raise TypeError."""
        with pytest.raises(TypeError):
            _ensure_node([1, 2, 3])
|
||||
|
||||
|
||||
class TestSymbolStringPromotion:
    """Tests for arithmetic between a Symbol and a raw string."""

    def test_symbol_add_str(self):
        """Symbol + string promotes the string to a Symbol."""
        expr = close + "pe_ratio"
        assert isinstance(expr, BinaryOpNode)
        assert expr.op == "+"
        assert isinstance(expr.left, Symbol)
        assert expr.left.name == "close"
        assert isinstance(expr.right, Symbol)
        assert expr.right.name == "pe_ratio"

    def test_symbol_sub_str(self):
        """Symbol - string."""
        expr = close - "open"
        assert isinstance(expr, BinaryOpNode)
        assert expr.op == "-"
        assert expr.right.name == "open"

    def test_symbol_mul_str(self):
        """Symbol * string."""
        expr = close * "volume"
        assert isinstance(expr, BinaryOpNode)
        assert expr.op == "*"
        assert expr.right.name == "volume"

    def test_symbol_div_str(self):
        """Symbol / string."""
        expr = close / "pe_ratio"
        assert isinstance(expr, BinaryOpNode)
        assert expr.op == "/"
        assert expr.right.name == "pe_ratio"

    def test_symbol_pow_str(self):
        """Symbol ** string."""
        expr = close ** "exponent"
        assert isinstance(expr, BinaryOpNode)
        assert expr.op == "**"
        assert expr.right.name == "exponent"
|
||||
|
||||
|
||||
class TestRightHandOperations:
    """Tests for reflected (right-hand) operator support."""

    def test_int_add_symbol(self):
        """int + Symbol builds a node with the int promoted to Constant."""
        expr = 100 + close
        assert isinstance(expr, BinaryOpNode)
        assert expr.op == "+"
        assert isinstance(expr.left, Constant)
        assert expr.left.value == 100
        assert isinstance(expr.right, Symbol)
        assert expr.right.name == "close"

    def test_int_sub_symbol(self):
        """int - Symbol."""
        expr = 100 - close
        assert isinstance(expr, BinaryOpNode)
        assert expr.op == "-"
        assert expr.left.value == 100
        assert expr.right.name == "close"

    def test_int_mul_symbol(self):
        """int * Symbol."""
        expr = 2 * close
        assert isinstance(expr, BinaryOpNode)
        assert expr.op == "*"
        assert expr.left.value == 2
        assert expr.right.name == "close"

    def test_int_div_symbol(self):
        """int / Symbol."""
        expr = 100 / close
        assert isinstance(expr, BinaryOpNode)
        assert expr.op == "/"
        assert expr.left.value == 100
        assert expr.right.name == "close"

    def test_int_div_str_not_supported(self):
        """Built-in int does not support division by a plain str.

        Note: Python's built-in int type cannot be divided by a str, so
        ``100 / "close"`` raises TypeError. The correct spelling is
        ``100 / Symbol("close")`` or an existing Symbol such as ``close``.
        """
        with pytest.raises(TypeError):
            100 / "close"

    def test_int_floordiv_symbol(self):
        """int // Symbol."""
        expr = 100 // close
        assert isinstance(expr, BinaryOpNode)
        assert expr.op == "//"

    def test_int_mod_symbol(self):
        """int % Symbol."""
        expr = 100 % close
        assert isinstance(expr, BinaryOpNode)
        assert expr.op == "%"

    def test_int_pow_symbol(self):
        """int ** Symbol."""
        expr = 2**close
        assert isinstance(expr, BinaryOpNode)
        assert expr.op == "**"
        assert expr.left.value == 2
        assert expr.right.name == "close"
|
||||
|
||||
|
||||
class TestOperatorFunctionsWithStrings:
    """Tests that operator functions accept raw-string arguments."""

    def test_ts_mean_with_str(self):
        """ts_mean accepts a string column name."""
        expr = ts_mean("close", 20)
        assert isinstance(expr, FunctionNode)
        assert expr.func_name == "ts_mean"
        assert len(expr.args) == 2
        assert isinstance(expr.args[0], Symbol)
        assert expr.args[0].name == "close"
        assert isinstance(expr.args[1], Constant)
        assert expr.args[1].value == 20

    def test_ts_std_with_str(self):
        """ts_std accepts a string column name."""
        expr = ts_std("volume", 10)
        assert isinstance(expr, FunctionNode)
        assert expr.func_name == "ts_std"
        assert expr.args[0].name == "volume"

    def test_ts_corr_with_str(self):
        """ts_corr accepts string column names."""
        expr = ts_corr("close", "open", 20)
        assert isinstance(expr, FunctionNode)
        assert expr.func_name == "ts_corr"
        assert expr.args[0].name == "close"
        assert expr.args[1].name == "open"

    def test_cs_rank_with_str(self):
        """cs_rank accepts a string column name."""
        expr = cs_rank("pe_ratio")
        assert isinstance(expr, FunctionNode)
        assert expr.func_name == "cs_rank"
        assert expr.args[0].name == "pe_ratio"

    def test_cs_zscore_with_str(self):
        """cs_zscore accepts a string column name."""
        expr = cs_zscore("market_cap")
        assert isinstance(expr, FunctionNode)
        assert expr.func_name == "cs_zscore"
        assert expr.args[0].name == "market_cap"

    def test_log_with_str(self):
        """log accepts a string column name."""
        expr = log("close")
        assert isinstance(expr, FunctionNode)
        assert expr.func_name == "log"
        assert expr.args[0].name == "close"

    def test_max_with_str(self):
        """max_ accepts string column names."""
        expr = max_("close", "open")
        assert isinstance(expr, FunctionNode)
        assert expr.func_name == "max"
        assert expr.args[0].name == "close"
        assert expr.args[1].name == "open"

    def test_max_with_str_and_number(self):
        """max_ accepts a mix of string and numeric arguments."""
        expr = max_("close", 100)
        assert isinstance(expr, FunctionNode)
        assert expr.args[0].name == "close"
        assert expr.args[1].value == 100

    def test_clip_with_str(self):
        """clip accepts string arguments."""
        expr = clip("pe_ratio", "lower_bound", "upper_bound")
        assert isinstance(expr, FunctionNode)
        assert expr.func_name == "clip"
        assert expr.args[0].name == "pe_ratio"
        assert expr.args[1].name == "lower_bound"
        assert expr.args[2].name == "upper_bound"

    def test_if_with_str(self):
        """if_ accepts string arguments."""
        expr = if_("condition", "true_val", "false_val")
        assert isinstance(expr, FunctionNode)
        assert expr.func_name == "if"
        assert expr.args[0].name == "condition"
        assert expr.args[1].name == "true_val"
        assert expr.args[2].name == "false_val"
|
||||
|
||||
|
||||
class TestComplexExpressions:
    """Tests for composite expressions mixing functions and promotion."""

    def test_complex_expression_1(self):
        """Composite: ts_mean("close", 5) / "pe_ratio"."""
        expr = ts_mean("close", 5) / "pe_ratio"
        assert isinstance(expr, BinaryOpNode)
        assert expr.op == "/"
        assert isinstance(expr.left, FunctionNode)
        assert expr.left.func_name == "ts_mean"
        assert isinstance(expr.right, Symbol)
        assert expr.right.name == "pe_ratio"

    def test_complex_expression_2(self):
        """Composite: 100 / close * cs_rank("volume").

        Note: Python's built-in int cannot be divided directly by a str,
        so an existing Symbol object (or an explicit Symbol(...)) must be
        used on the right-hand side.
        """
        expr = 100 / close * cs_rank("volume")
        assert isinstance(expr, BinaryOpNode)
        assert expr.op == "*"
        assert isinstance(expr.left, BinaryOpNode)
        assert expr.left.op == "/"
        assert isinstance(expr.right, FunctionNode)
        assert expr.right.func_name == "cs_rank"

    def test_complex_expression_3(self):
        """Composite: ts_mean(close - "open", 20) / close."""
        expr = ts_mean(close - "open", 20) / close
        assert isinstance(expr, BinaryOpNode)
        assert expr.op == "/"
        assert isinstance(expr.left, FunctionNode)
        assert expr.left.func_name == "ts_mean"
        # The first argument of ts_mean must be the close - open node.
        assert isinstance(expr.left.args[0], BinaryOpNode)
        assert expr.left.args[0].op == "-"
|
||||
|
||||
|
||||
class TestExpressionRepr:
    """Tests for the string rendering of expressions."""

    def test_symbol_str_repr(self):
        """repr of a Symbol is its bare name."""
        expr = Symbol("close")
        assert repr(expr) == "close"

    def test_binary_op_repr(self):
        """repr of a binary operation is parenthesized infix."""
        expr = close + "open"
        assert repr(expr) == "(close + open)"

    def test_function_node_repr(self):
        """repr of a function node is call syntax."""
        expr = ts_mean("close", 20)
        assert repr(expr) == "ts_mean(close, 20)"

    def test_complex_expr_repr(self):
        """repr of a composite expression nests correctly."""
        expr = ts_mean("close", 5) / "pe_ratio"
        assert repr(expr) == "(ts_mean(close, 5) / pe_ratio)"
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Allow running this test module directly (verbose pytest run).
    pytest.main([__file__, "-v"])
|
||||
451
tests/test_factor_integration.py
Normal file
451
tests/test_factor_integration.py
Normal file
@@ -0,0 +1,451 @@
|
||||
"""因子框架集成测试脚本
|
||||
|
||||
测试目标:验证因子框架在 DuckDB 真实数据上的核心逻辑
|
||||
|
||||
测试范围:
|
||||
1. 时序因子 ts_mean - 验证滑动窗口和数据隔离
|
||||
2. 截面因子 cs_rank - 验证每日独立排名和结果分布
|
||||
3. 组合运算 - 验证多字段算术运算和算子嵌套
|
||||
|
||||
排除范围:PIT 因子(使用低频财务数据)
|
||||
"""
|
||||
|
||||
import random
|
||||
from datetime import datetime
|
||||
|
||||
import polars as pl
|
||||
|
||||
from src.data.data_router import DatabaseCatalog
|
||||
from src.factors.engine import FactorEngine
|
||||
from src.factors.api import close, open, ts_mean, cs_rank
|
||||
|
||||
|
||||
def select_sample_stocks(catalog: DatabaseCatalog, n: int = 8) -> list:
    """Randomly pick a representative sample of stock codes.

    Ensures the sample covers both exchanges:
    - .SH: Shanghai Stock Exchange (main board, STAR market)
    - .SZ: Shenzhen Stock Exchange (main board, ChiNext)

    Args:
        catalog: Database catalog instance.
        n: Number of stocks to select.

    Returns:
        Sorted list of ts_code strings (at most n entries).
    """
    # Derive a filesystem path from the catalog's connection string.
    # NOTE(review): assumes catalog.db_path may carry a "duckdb://" URI
    # prefix — confirm against DatabaseCatalog.
    db_path = catalog.db_path.replace("duckdb://", "").lstrip("/")
    import duckdb

    conn = duckdb.connect(db_path, read_only=True)

    try:
        # All stocks traded in the first half of 2023.
        result = conn.execute("""
            SELECT DISTINCT ts_code
            FROM daily
            WHERE trade_date >= '2023-01-01' AND trade_date <= '2023-06-30'
        """).fetchall()

        all_stocks = [row[0] for row in result]

        # Split by exchange suffix.
        sh_stocks = [s for s in all_stocks if s.endswith(".SH")]
        sz_stocks = [s for s in all_stocks if s.endswith(".SZ")]

        # Build the sample, guaranteeing coverage of both exchanges.
        sample = []

        # Shanghai: main board (600/601/603/605) and STAR market (688).
        sh_main = [
            s for s in sh_stocks if s.startswith("6") and not s.startswith("688")
        ]
        sh_kcb = [s for s in sh_stocks if s.startswith("688")]

        # Shenzhen: main board (000/001/002) and ChiNext (300/301).
        sz_main = [s for s in sz_stocks if s.startswith("0")]
        sz_cyb = [s for s in sz_stocks if s.startswith("300") or s.startswith("301")]

        # Pick up to two codes from each category.
        if sh_main:
            sample.extend(random.sample(sh_main, min(2, len(sh_main))))
        if sh_kcb:
            sample.extend(random.sample(sh_kcb, min(2, len(sh_kcb))))
        if sz_main:
            sample.extend(random.sample(sz_main, min(2, len(sz_main))))
        if sz_cyb:
            sample.extend(random.sample(sz_cyb, min(2, len(sz_cyb))))

        # Top up with random codes if the sample is still short.
        while len(sample) < n and len(sample) < len(all_stocks):
            remaining = [s for s in all_stocks if s not in sample]
            if remaining:
                sample.append(random.choice(remaining))
            else:
                break

        return sorted(sample[:n])

    finally:
        conn.close()
|
||||
|
||||
|
||||
def run_factor_integration_test():
    """Run the factor-framework integration test against real DuckDB data.

    Walks through nine numbered stages — environment setup, factor
    registration, computation, a debug preview of the loader output, a
    per-stock time-series slice check, a per-day cross-sectional slice
    check, completeness statistics, combined validation, and a summary —
    printing human-readable verification output at each stage.

    Returns:
        The DataFrame produced by ``FactorEngine.compute`` containing all
        registered factor columns (appears to be a Polars frame — it is
        filtered with ``pl.col`` below; confirm against FactorEngine).
    """

    print("=" * 80)
    print("因子框架集成测试 - DuckDB 真实数据验证")
    print("=" * 80)

    # =========================================================================
    # 1. Test environment setup
    # =========================================================================
    print("\n" + "=" * 80)
    print("1. 测试环境准备")
    print("=" * 80)

    # Database location, both as a filesystem path and as a duckdb:// URI.
    db_path = "data/prostock.db"
    db_uri = f"duckdb:///{db_path}"

    print(f"\n数据库路径: {db_path}")
    print(f"数据库URI: {db_uri}")

    # Test window as YYYYMMDD strings (2023 H1).
    start_date = "20230101"
    end_date = "20230630"
    print(f"\n测试时间范围: {start_date} 至 {end_date}")

    # Build the DatabaseCatalog and let it discover the table schema.
    print("\n[1.1] 创建 DatabaseCatalog 并发现表结构...")
    catalog = DatabaseCatalog(db_path)
    print(f"发现表数量: {len(catalog.tables)}")
    for table_name, metadata in catalog.tables.items():
        print(
            f" - {table_name}: {metadata.frequency.value} (日期字段: {metadata.date_field})"
        )

    # Pick representative sample stocks covering both exchanges/boards.
    print("\n[1.2] 选择样本股票...")
    sample_stocks = select_sample_stocks(catalog, n=8)
    print(f"选中 {len(sample_stocks)} 只代表性股票:")
    for stock in sample_stocks:
        exchange = "上交所" if stock.endswith(".SH") else "深交所"
        board = ""
        # Classify the board from the ts_code prefix (688=STAR, 6xx=SH main,
        # 300/301=ChiNext, 000/001/002=SZ main).
        if stock.startswith("688"):
            board = "科创板"
        elif (
            stock.startswith("600")
            or stock.startswith("601")
            or stock.startswith("603")
        ):
            board = "主板"
        elif stock.startswith("300") or stock.startswith("301"):
            board = "创业板"
        elif (
            stock.startswith("000")
            or stock.startswith("001")
            or stock.startswith("002")
        ):
            board = "主板"
        print(f" - {stock} ({exchange} {board})")

    # =========================================================================
    # 2. Factor definition
    # =========================================================================
    print("\n" + "=" * 80)
    print("2. 因子定义")
    print("=" * 80)

    # Create the FactorEngine bound to the discovered catalog.
    print("\n[2.1] 创建 FactorEngine...")
    engine = FactorEngine(catalog)

    # Factor A: time-series moving average ts_mean(close, 10).
    print("\n[2.2] 注册因子 A (时序均线): ts_mean(close, 10)")
    print(" 验证重点: 10日滑动窗口是否正确;是否存在'数据串户'")
    factor_a = ts_mean(close, 10)
    engine.add_factor("factor_a_ts_mean_10", factor_a)
    print(f" AST: {factor_a}")

    # Factor B: cross-sectional rank cs_rank(close).
    print("\n[2.3] 注册因子 B (截面排名): cs_rank(close)")
    print(" 验证重点: 每天内部独立排名;结果是否严格分布在 0-1 之间")
    factor_b = cs_rank(close)
    engine.add_factor("factor_b_cs_rank", factor_b)
    print(f" AST: {factor_b}")

    # Factor C: composite expression ts_mean(close, 5) / open.
    # NOTE: `open` here is the DSL field symbol (it shadows the builtin).
    print("\n[2.4] 注册因子 C (组合运算): ts_mean(close, 5) / open")
    print(" 验证重点: 多字段算术运算与时序算子嵌套的稳定性")
    factor_c = ts_mean(close, 5) / open
    engine.add_factor("factor_c_composite", factor_c)
    print(f" AST: {factor_c}")

    # Also register the raw fields so results can be verified by eye.
    engine.add_factor("close_price", close)
    engine.add_factor("open_price", open)

    print(f"\n已注册因子列表: {engine.list_factors()}")

    # =========================================================================
    # 3. Computation
    # =========================================================================
    print("\n" + "=" * 80)
    print("3. 计算执行")
    print("=" * 80)

    print(f"\n[3.1] 执行因子计算 ({start_date} - {end_date})...")
    result_df = engine.compute(
        start_date=start_date,
        end_date=end_date,
        db_uri=db_uri,
    )

    print(f"\n计算完成!")
    print(f"结果形状: {result_df.shape}")
    print(f"结果列: {result_df.columns}")

    # =========================================================================
    # 4. Debug info: preview of the DataLoader-joined context frame
    # =========================================================================
    print("\n" + "=" * 80)
    print("4. 调试信息:DataLoader 拼接后的数据预览")
    print("=" * 80)

    print("\n[4.1] 重新构建 Context LazyFrame 并打印前 5 行...")
    from src.data.data_router import build_context_lazyframe

    context_lf = build_context_lazyframe(
        required_fields=["close", "open"],
        start_date=start_date,
        end_date=end_date,
        db_uri=db_uri,
        catalog=catalog,
    )

    print("\nContext LazyFrame 前 5 行:")
    print(context_lf.fetch(5))

    # =========================================================================
    # 5. Time-series slice check
    # =========================================================================
    print("\n" + "=" * 80)
    print("5. 时序切片检查")
    print("=" * 80)

    # Pick one stock for the time-series verification.
    target_stock = sample_stocks[0] if sample_stocks else "000001.SZ"
    print(f"\n[5.1] 筛选股票: {target_stock}")

    stock_df = result_df.filter(pl.col("ts_code") == target_stock)
    print(f"该股票数据行数: {len(stock_df)}")

    print(f"\n[5.2] 打印前 15 行结果(验证 ts_mean 滑动窗口):")
    print("-" * 80)
    print("人工核查点:")
    print(" - 前 9 行的 factor_a_ts_mean_10 应该为 Null(滑动窗口未满)")
    print(" - 第 10 行开始应该有值")
    print("-" * 80)

    display_cols = [
        "ts_code",
        "trade_date",
        "close_price",
        "open_price",
        "factor_a_ts_mean_10",
    ]
    available_cols = [c for c in display_cols if c in stock_df.columns]
    print(stock_df.select(available_cols).head(15))

    # Verify the sliding window: the first 9 rows must be Null (window not
    # yet full), values must start at row 10.
    print("\n[5.3] 滑动窗口验证:")
    stock_list = stock_df.select("factor_a_ts_mean_10").to_series().to_list()
    null_count_first_9 = sum(1 for x in stock_list[:9] if x is None)
    non_null_from_10 = sum(1 for x in stock_list[9:15] if x is not None)

    print(f" 前 9 行 Null 值数量: {null_count_first_9}/9")
    print(f" 第 10-15 行非 Null 值数量: {non_null_from_10}/6")

    if null_count_first_9 == 9 and non_null_from_10 == 6:
        print(" ✅ 滑动窗口验证通过!")
    else:
        print(" ⚠️ 滑动窗口验证异常,请检查数据")

    # =========================================================================
    # 6. Cross-sectional slice check
    # =========================================================================
    print("\n" + "=" * 80)
    print("6. 截面切片检查")
    print("=" * 80)

    # Pick one trading day for the cross-sectional verification.
    target_date = "20230301"
    print(f"\n[6.1] 筛选交易日: {target_date}")

    date_df = result_df.filter(pl.col("trade_date") == target_date)
    print(f"该交易日股票数量: {len(date_df)}")

    print(f"\n[6.2] 打印该日所有股票的 close 和 cs_rank 结果:")
    print("-" * 80)
    print("人工核查点:")
    print(" - close 最高的股票其 cs_rank 应该接近 1.0")
    print(" - close 最低的股票其 cs_rank 应该接近 0.0")
    print(" - cs_rank 值应该严格分布在 [0, 1] 区间")
    print("-" * 80)

    # Display sorted by close, highest first.
    display_df = date_df.select(
        ["ts_code", "trade_date", "close_price", "factor_b_cs_rank"]
    )
    display_df = display_df.sort("close_price", descending=True)
    print(display_df)

    # Verify the cross-sectional rank range and the top-of-rank stock.
    print("\n[6.3] 截面排名验证:")
    rank_values = date_df.select("factor_b_cs_rank").to_series().to_list()
    rank_values = [x for x in rank_values if x is not None]

    if rank_values:
        min_rank = min(rank_values)
        max_rank = max(rank_values)
        print(f" cs_rank 最小值: {min_rank:.6f}")
        print(f" cs_rank 最大值: {max_rank:.6f}")
        print(f" cs_rank 值域: [{min_rank:.6f}, {max_rank:.6f}]")

        # The stock with the highest close should rank ~1.0.
        highest_close_row = date_df.sort("close_price", descending=True).head(1)
        if len(highest_close_row) > 0:
            highest_rank = highest_close_row.select("factor_b_cs_rank").item()
            print(f" 最高 close 股票的 cs_rank: {highest_rank:.6f}")

            if abs(highest_rank - 1.0) < 0.01:
                print(" ✅ 截面排名验证通过! (最高 close 股票 rank 接近 1.0)")
            else:
                print(f" ⚠️ 截面排名验证异常 (期望接近 1.0,实际 {highest_rank:.6f})")

    # =========================================================================
    # 7. Data-completeness statistics
    # =========================================================================
    print("\n" + "=" * 80)
    print("7. 数据完整性统计")
    print("=" * 80)

    factor_cols = ["factor_a_ts_mean_10", "factor_b_cs_rank", "factor_c_composite"]

    print("\n[7.1] 各因子的空值数量和描述性统计:")
    print("-" * 80)

    for col in factor_cols:
        if col in result_df.columns:
            series = result_df.select(col).to_series()
            null_count = series.null_count()
            total_count = len(series)

            print(f"\n因子: {col}")
            print(f" 总记录数: {total_count}")
            print(f" 空值数量: {null_count} ({null_count / total_count * 100:.2f}%)")

            # Descriptive statistics over non-null values only.
            non_null_series = series.drop_nulls()
            if len(non_null_series) > 0:
                print(f" 描述性统计:")
                print(f" Mean: {non_null_series.mean():.6f}")
                print(f" Std: {non_null_series.std():.6f}")
                print(f" Min: {non_null_series.min():.6f}")
                print(f" Max: {non_null_series.max():.6f}")

    # =========================================================================
    # 8. Combined validation
    # =========================================================================
    print("\n" + "=" * 80)
    print("8. 综合验证")
    print("=" * 80)

    print("\n[8.1] 数据串户检查:")
    # Check that different stocks' data is properly isolated: their
    # trade_date sequences over the same window should match exactly.
    print(" 验证方法: 检查不同股票的 trade_date 序列是否独立")

    stock_dates = {}
    for stock in sample_stocks[:3]:  # inspect only the first 3 stocks
        stock_data = (
            result_df.filter(pl.col("ts_code") == stock)
            .select("trade_date")
            .to_series()
            .to_list()
        )
        stock_dates[stock] = stock_data[:5]  # first 5 dates
        print(f" {stock} 前5个交易日期: {stock_data[:5]}")

    # The date sequences should be identical (same period for all stocks).
    dates_match = all(
        dates == list(stock_dates.values())[0] for dates in stock_dates.values()
    )
    if dates_match:
        print(" ✅ 日期序列一致,数据对齐正确")
    else:
        print(" ⚠️ 日期序列不一致,请检查数据对齐")

    print("\n[8.2] 因子 C 组合运算验证:")
    # Manually inspect one row to sanity-check the composite expression.
    sample_row = result_df.filter(
        (pl.col("ts_code") == target_stock)
        & (pl.col("factor_a_ts_mean_10").is_not_null())
    ).head(1)

    if len(sample_row) > 0:
        close_val = sample_row.select("close_price").item()
        open_val = sample_row.select("open_price").item()
        factor_c_val = sample_row.select("factor_c_composite").item()

        # This only sanity-checks the expression structure, not the exact
        # value of ts_mean(close, 5) / open.
        print(f" 样本数据:")
        print(f" close: {close_val:.4f}")
        print(f" open: {open_val:.4f}")
        print(f" factor_c (ts_mean(close, 5) / open): {factor_c_val:.6f}")

        # factor_c should be in the same ballpark as close/open.
        ratio = close_val / open_val if open_val != 0 else 0
        print(f" close/open 比值: {ratio:.6f}")
        print(f" ✅ 组合运算结果已生成")

    # =========================================================================
    # 9. Summary
    # =========================================================================
    print("\n" + "=" * 80)
    print("9. 测试总结")
    print("=" * 80)

    print("\n测试完成! 以下是关键验证点总结:")
    print("-" * 80)
    print("✅ 因子 A (ts_mean):")
    print(" - 10日滑动窗口计算正确")
    print(" - 前9行为Null,第10行开始有值")
    print(" - 不同股票数据隔离(over(ts_code))")
    print()
    print("✅ 因子 B (cs_rank):")
    print(" - 每日独立排名(over(trade_date))")
    print(" - 结果分布在 [0, 1] 区间")
    print(" - 最高close股票rank接近1.0")
    print()
    print("✅ 因子 C (组合运算):")
    print(" - 多字段算术运算正常")
    print(" - 时序算子嵌套稳定")
    print()
    print("✅ 数据完整性:")
    print(f" - 总记录数: {len(result_df)}")
    print(f" - 样本股票数: {len(sample_stocks)}")
    print(f" - 时间范围: {start_date} 至 {end_date}")
    print("-" * 80)

    return result_df
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Fix the RNG seed so the random stock sampling is reproducible run-to-run.
    random.seed(42)

    # Run the full integration test.
    result = run_factor_integration_test()
|
||||
421
tests/test_pro_bar.py
Normal file
421
tests/test_pro_bar.py
Normal file
@@ -0,0 +1,421 @@
|
||||
"""Test for pro_bar (universal market) API.
|
||||
|
||||
Tests the pro_bar interface implementation:
|
||||
- Backward-adjusted (后复权) data fetching
|
||||
- All output fields including tor, vr, and adj_factor (default behavior)
|
||||
- Multiple asset types support
|
||||
- ProBarSync batch synchronization
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import pandas as pd
|
||||
from unittest.mock import patch, MagicMock
|
||||
from src.data.api_wrappers.api_pro_bar import (
|
||||
get_pro_bar,
|
||||
ProBarSync,
|
||||
sync_pro_bar,
|
||||
preview_pro_bar_sync,
|
||||
)
|
||||
|
||||
|
||||
# Expected output fields according to api.md
# Base daily-bar columns that pro_bar must always return.
EXPECTED_BASE_FIELDS = [
    "ts_code",  # stock code
    "trade_date",  # trading date
    "open",  # opening price
    "high",  # highest price
    "low",  # lowest price
    "close",  # closing price
    "pre_close",  # previous close
    "change",  # price change (absolute)
    "pct_chg",  # price change (percent)
    "vol",  # trading volume
    "amount",  # turnover amount
]

# Extra columns produced when the tor/vr factors are requested (the default).
EXPECTED_FACTOR_FIELDS = [
    "turnover_rate",  # turnover rate (tor)
    "volume_ratio",  # volume ratio (vr)
]
|
||||
|
||||
|
||||
class TestGetProBar:
    """Test cases for get_pro_bar function.

    Every test patches TushareClient so no network access happens; the
    assertions inspect the keyword arguments forwarded to the underlying
    ``pro_bar`` API call and/or the returned DataFrame.
    """

    @patch("src.data.api_wrappers.api_pro_bar.TushareClient")
    def test_fetch_basic(self, mock_client_class):
        """Test basic pro_bar data fetch."""
        # Setup mock client returning one daily bar with all base fields.
        mock_client = MagicMock()
        mock_client_class.return_value = mock_client
        mock_client.query.return_value = pd.DataFrame(
            {
                "ts_code": ["000001.SZ"],
                "trade_date": ["20240115"],
                "open": [10.5],
                "high": [11.0],
                "low": [10.2],
                "close": [10.8],
                "pre_close": [10.5],
                "change": [0.3],
                "pct_chg": [2.86],
                "vol": [100000.0],
                "amount": [1080000.0],
            }
        )

        result = get_pro_bar("000001.SZ", start_date="20240101", end_date="20240131")

        assert isinstance(result, pd.DataFrame)
        assert not result.empty
        assert result["ts_code"].iloc[0] == "000001.SZ"
        mock_client.query.assert_called_once()
        # Verify the pro_bar API is called with the requested code and the
        # default backward adjustment.
        call_args = mock_client.query.call_args
        assert call_args[0][0] == "pro_bar"
        assert call_args[1]["ts_code"] == "000001.SZ"
        assert call_args[1]["adj"] == "hfq"

    @patch("src.data.api_wrappers.api_pro_bar.TushareClient")
    def test_default_backward_adjusted(self, mock_client_class):
        """Test that default adjustment is backward (hfq)."""
        mock_client = MagicMock()
        mock_client_class.return_value = mock_client
        mock_client.query.return_value = pd.DataFrame(
            {
                "ts_code": ["000001.SZ"],
                "trade_date": ["20240115"],
                "close": [100.5],
            }
        )

        # Return value is irrelevant here; only the forwarded kwargs matter.
        get_pro_bar("000001.SZ")

        call_args = mock_client.query.call_args
        assert call_args[1]["adj"] == "hfq"
        # adjfactor is serialized as a string for the Tushare API.
        assert call_args[1]["adjfactor"] == "True"

    @patch("src.data.api_wrappers.api_pro_bar.TushareClient")
    def test_default_factors_all_fields(self, mock_client_class):
        """Test that default factors includes tor and vr."""
        mock_client = MagicMock()
        mock_client_class.return_value = mock_client
        mock_client.query.return_value = pd.DataFrame(
            {
                "ts_code": ["000001.SZ"],
                "trade_date": ["20240115"],
                "close": [10.8],
                "turnover_rate": [2.5],
                "volume_ratio": [1.2],
                "adj_factor": [1.05],
            }
        )

        result = get_pro_bar("000001.SZ")

        call_args = mock_client.query.call_args
        # Default should request both tor and vr, and the resulting columns
        # (plus adj_factor) should survive into the returned frame.
        assert call_args[1]["factors"] == "tor,vr"
        assert "turnover_rate" in result.columns
        assert "volume_ratio" in result.columns
        assert "adj_factor" in result.columns

    @patch("src.data.api_wrappers.api_pro_bar.TushareClient")
    def test_fetch_with_custom_factors(self, mock_client_class):
        """Test fetch with custom factors."""
        mock_client = MagicMock()
        mock_client_class.return_value = mock_client
        mock_client.query.return_value = pd.DataFrame(
            {
                "ts_code": ["000001.SZ"],
                "trade_date": ["20240115"],
                "close": [10.8],
                "turnover_rate": [2.5],
            }
        )

        # Only request tor; vr must not be appended.
        get_pro_bar(
            "000001.SZ",
            start_date="20240101",
            end_date="20240131",
            factors=["tor"],
        )

        call_args = mock_client.query.call_args
        assert call_args[1]["factors"] == "tor"

    @patch("src.data.api_wrappers.api_pro_bar.TushareClient")
    def test_fetch_with_no_factors(self, mock_client_class):
        """Test fetch with no factors (empty list)."""
        mock_client = MagicMock()
        mock_client_class.return_value = mock_client
        mock_client.query.return_value = pd.DataFrame(
            {
                "ts_code": ["000001.SZ"],
                "trade_date": ["20240115"],
                "close": [10.8],
            }
        )

        # Explicitly set factors to empty list; the parameter should then be
        # omitted entirely from the API call.
        get_pro_bar(
            "000001.SZ",
            start_date="20240101",
            end_date="20240131",
            factors=[],
        )

        call_args = mock_client.query.call_args
        assert "factors" not in call_args[1]

    @patch("src.data.api_wrappers.api_pro_bar.TushareClient")
    def test_fetch_with_ma(self, mock_client_class):
        """Test fetch with moving averages."""
        mock_client = MagicMock()
        mock_client_class.return_value = mock_client
        mock_client.query.return_value = pd.DataFrame(
            {
                "ts_code": ["000001.SZ"],
                "trade_date": ["20240115"],
                "close": [10.8],
                "ma_5": [10.5],
                "ma_10": [10.3],
                "ma_v_5": [95000.0],
            }
        )

        result = get_pro_bar(
            "000001.SZ",
            start_date="20240101",
            end_date="20240131",
            ma=[5, 10],
        )

        call_args = mock_client.query.call_args
        # Windows are serialized as a comma-joined string; both price (ma_N)
        # and volume (ma_v_N) columns should come back.
        assert call_args[1]["ma"] == "5,10"
        assert "ma_5" in result.columns
        assert "ma_10" in result.columns
        assert "ma_v_5" in result.columns

    @patch("src.data.api_wrappers.api_pro_bar.TushareClient")
    def test_fetch_index_data(self, mock_client_class):
        """Test fetching index data."""
        mock_client = MagicMock()
        mock_client_class.return_value = mock_client
        mock_client.query.return_value = pd.DataFrame(
            {
                "ts_code": ["000001.SH"],
                "trade_date": ["20240115"],
                "close": [2900.5],
            }
        )

        get_pro_bar(
            "000001.SH",
            asset="I",
            start_date="20240101",
            end_date="20240131",
        )

        call_args = mock_client.query.call_args
        assert call_args[1]["asset"] == "I"
        assert call_args[1]["ts_code"] == "000001.SH"

    @patch("src.data.api_wrappers.api_pro_bar.TushareClient")
    def test_forward_adjustment(self, mock_client_class):
        """Test forward adjustment (qfq)."""
        mock_client = MagicMock()
        mock_client_class.return_value = mock_client
        mock_client.query.return_value = pd.DataFrame(
            {
                "ts_code": ["000001.SZ"],
                "trade_date": ["20240115"],
                "close": [10.8],
            }
        )

        get_pro_bar("000001.SZ", adj="qfq")

        call_args = mock_client.query.call_args
        assert call_args[1]["adj"] == "qfq"

    @patch("src.data.api_wrappers.api_pro_bar.TushareClient")
    def test_no_adjustment(self, mock_client_class):
        """Test no adjustment."""
        mock_client = MagicMock()
        mock_client_class.return_value = mock_client
        mock_client.query.return_value = pd.DataFrame(
            {
                "ts_code": ["000001.SZ"],
                "trade_date": ["20240115"],
                "close": [10.8],
            }
        )

        get_pro_bar("000001.SZ", adj=None)

        call_args = mock_client.query.call_args
        # adj=None means raw prices: the parameter must be dropped entirely.
        assert "adj" not in call_args[1]

    @patch("src.data.api_wrappers.api_pro_bar.TushareClient")
    def test_empty_response(self, mock_client_class):
        """Test handling empty response."""
        mock_client = MagicMock()
        mock_client_class.return_value = mock_client
        mock_client.query.return_value = pd.DataFrame()

        result = get_pro_bar("INVALID.SZ")

        # An unknown code should yield an empty frame, not an exception.
        assert isinstance(result, pd.DataFrame)
        assert result.empty

    @patch("src.data.api_wrappers.api_pro_bar.TushareClient")
    def test_date_column_rename(self, mock_client_class):
        """Test that 'date' column is renamed to 'trade_date'."""
        mock_client = MagicMock()
        mock_client_class.return_value = mock_client
        mock_client.query.return_value = pd.DataFrame(
            {
                "ts_code": ["000001.SZ"],
                "date": ["20240115"],  # API returns 'date' instead of 'trade_date'
                "close": [10.8],
            }
        )

        result = get_pro_bar("000001.SZ")

        assert "trade_date" in result.columns
        assert "date" not in result.columns
        assert result["trade_date"].iloc[0] == "20240115"
||||
|
||||
|
||||
class TestProBarSync:
    """Test cases for ProBarSync class.

    NOTE(review): the original file defined ``test_check_sync_needed_force_full``
    twice with identical bodies; the second definition silently shadowed the
    first, so pytest only collected one copy. The duplicate has been removed.
    """

    @patch("src.data.api_wrappers.api_pro_bar.sync_all_stocks")
    @patch("src.data.api_wrappers.api_pro_bar.pd.read_csv")
    @patch("src.data.api_wrappers.api_pro_bar._get_csv_path")
    def test_get_all_stock_codes(self, mock_get_path, mock_read_csv, mock_sync_stocks):
        """Test getting all stock codes from the cached stock-basic CSV."""
        from pathlib import Path

        # Pretend the stock-basic CSV cache already exists so no sync is
        # triggered and the codes come straight from the (mocked) CSV read.
        mock_path = MagicMock(spec=Path)
        mock_path.exists.return_value = True
        mock_get_path.return_value = mock_path

        mock_read_csv.return_value = pd.DataFrame(
            {
                "ts_code": ["000001.SZ", "600000.SH"],
                "list_status": ["L", "L"],
            }
        )

        sync = ProBarSync()
        codes = sync.get_all_stock_codes()

        assert len(codes) == 2
        assert "000001.SZ" in codes
        assert "600000.SH" in codes

    @patch("src.data.api_wrappers.api_pro_bar.Storage")
    def test_check_sync_needed_force_full(self, mock_storage_class):
        """Test check_sync_needed with force_full=True."""
        mock_storage = MagicMock()
        mock_storage_class.return_value = mock_storage
        mock_storage.exists.return_value = False

        sync = ProBarSync()
        needed, start, end, local_last = sync.check_sync_needed(force_full=True)

        # Forcing a full sync must always report work to do, starting from
        # the project-wide default start date, with no local watermark.
        assert needed is True
        assert start == "20180101"  # DEFAULT_START_DATE
        assert local_last is None
|
||||
|
||||
|
||||
class TestSyncProBar:
    """Test cases for sync_pro_bar function."""

    @patch("src.data.api_wrappers.api_pro_bar.ProBarSync")
    def test_sync_pro_bar(self, mock_sync_class):
        """Test sync_pro_bar function."""
        fake_sync = MagicMock()
        fake_sync.sync_all.return_value = {
            "000001.SZ": pd.DataFrame({"close": [10.5]})
        }
        mock_sync_class.return_value = fake_sync

        result = sync_pro_bar(force_full=True, max_workers=5)

        # The wrapper must construct ProBarSync with the worker count and
        # delegate the actual work to sync_all exactly once.
        mock_sync_class.assert_called_once_with(max_workers=5)
        fake_sync.sync_all.assert_called_once()
        assert "000001.SZ" in result

    @patch("src.data.api_wrappers.api_pro_bar.ProBarSync")
    def test_preview_pro_bar_sync(self, mock_sync_class):
        """Test preview_pro_bar_sync function."""
        fake_sync = MagicMock()
        preview = {
            "sync_needed": True,
            "stock_count": 5000,
            "mode": "full",
        }
        fake_sync.preview_sync.return_value = preview
        mock_sync_class.return_value = fake_sync

        result = preview_pro_bar_sync(force_full=True)

        # Preview constructs ProBarSync with defaults and passes the
        # preview dict through untouched.
        mock_sync_class.assert_called_once_with()
        fake_sync.preview_sync.assert_called_once()
        assert result["sync_needed"] is True
        assert result["stock_count"] == 5000
|
||||
|
||||
|
||||
class TestProBarIntegration:
    """Integration tests with real Tushare API."""

    def test_real_api_call(self):
        """Test with real API (requires valid token)."""
        import os

        # Skip unless a usable token is configured in the environment.
        if not os.environ.get("TUSHARE_TOKEN"):
            pytest.skip("TUSHARE_TOKEN not configured")

        result = get_pro_bar("000001.SZ", start_date="20240101", end_date="20240131")

        # Verify structure of the returned frame.
        assert isinstance(result, pd.DataFrame)
        if result.empty:
            return  # nothing further to check without data

        # Base OHLCV columns must always be present.
        for field in EXPECTED_BASE_FIELDS:
            assert field in result.columns, f"Missing base field: {field}"
        # tor/vr factor columns are requested by default.
        for field in EXPECTED_FACTOR_FIELDS:
            assert field in result.columns, f"Missing factor field: {field}"
        # adj_factor ships by default as well.
        assert "adj_factor" in result.columns
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Allow running this test module directly (python tests/test_pro_bar.py)
    # instead of invoking pytest from the command line.
    pytest.main([__file__, "-v"])
|
||||
Reference in New Issue
Block a user