diff --git a/docs/Classify2_load_model_doc.md b/docs/Classify2_load_model_doc.md deleted file mode 100644 index b6c705c..0000000 --- a/docs/Classify2_load_model_doc.md +++ /dev/null @@ -1,467 +0,0 @@ -# Classify2_load_model.ipynb 代码流程详解文档 - -## 概述 - -本文档详细描述了 `Classify2_load_model.ipynb` 文件中完整的代码流程,该notebook实现了一个股票分类模型的特征工程与数据预处理流程。整个流程涵盖了从原始数据加载、多源数据合并、因子计算、数据清洗、特征标准化等多个环节,最终输出可供机器学习模型使用的特征矩阵。 - ---- - -## 第一部分:环境配置与初始化 - -### 1.1 基础环境设置 - - notebook开头的第一个代码单元负责基础的开发环境配置。首先启用Jupyter的自动重载功能,通过`%load_ext autoreload`和`%autoreload 2`指令,使得在修改导入的模块后能够自动重新加载,这对于开发调试阶段非常实用。随后进行垃圾回收机制的相关设置,导入`gc`模块用于手动管理内存,因为后续处理的数据量较大(数GB级别),及时释放不需要的对象可以避免内存溢出。 - -系统路径的配置通过`sys.path.append`完成,将项目根目录添加到Python的模块搜索路径中,这样可以直接导入项目内的自定义模块。打印工作目录用于确认当前运行环境。数据处理的核心库`pandas`被导入用于表格数据处理,同时导入项目内部的因子模块和工具函数,包括`get_rolling_factor`和`get_simple_factor`用于计算技术因子,`read_industry_data`用于读取行业分类数据,`calculate_score`用于计算目标变量。最后通过`warnings.filterwarnings("ignore")`忽略所有警告信息,保持输出界面的整洁。 - -### 1.2 并行计算配置 - -第二个代码单元配置了Modin框架的并行计算参数。通过设置环境变量`os.environ["MODIN_CPUS"] = "4"`,指定使用4个CPU核心进行并行计算。Modin是一个pandas的并行替代库,能够在多核CPU上并行执行pandas操作,显著加速大规模数据的处理速度。这一配置对于后续处理数千万条记录的数据框至关重要。 - ---- - -## 第二部分:核心数据加载与合并 - -### 2.1 日线数据加载 - -第三个代码单元是整个流程中最关键的数据加载步骤,依次从HDF5文件中读取多个数据源并合并。HDF5是一种高效的分块压缩存储格式,非常适合存储大规模数值型数据。 - -**第一步:加载日线行情数据** - -使用`read_and_merge_h5_data`函数从`daily_data.h5`文件中读取日线行情数据。该函数是项目自定义的工具函数,其核心逻辑在`main/utils/utils.py`中实现。读取的字段包括股票代码`ts_code`、交易日期`trade_date`、开盘价`open`、收盘价`close`、最高价`high`、最低价`low`、成交量`vol`、成交额`amount`以及涨跌幅`pct_chg`。这些字段构成分析的基础价格信息。首次调用该函数时`df`参数为空,因此直接返回读取的数据作为初始数据框。 - -**第二步:加载日线基础数据** - -继续调用`read_and_merge_h5_data`读取`daily_basic.h5`文件,获取每日股票的基础财务指标。本次读取使用内连接(inner join)方式与已有数据进行合并,这意味着只保留两个数据集中都存在的股票记录。读取的字段包括`turnover_rate`(换手率)、`pe_ttm`(滚动市盈率)、`circ_mv`(流通市值)、`total_mv`(总市值)以及`volume_ratio`(量比)。这些指标对于评估股票的流动性和估值水平非常重要。 - -**第三步:加载涨跌停限制数据** - -从`stk_limit.h5`文件中读取股票的涨跌停价格信息。读取的字段包括`pre_close`(前一日收盘价)、`up_limit`(涨停价)和`down_limit`(跌停价)。这些信息用于后续识别股票的涨停状态,是构建目标变量的重要依据。 - -**第四步:加载资金流数据** - -从`money_flow.h5`文件中读取资金流信息,这是计算资金流因子的基础数据。读取的字段包括各类成交量的分解:`buy_sm_vol`和`sell_sm_vol`代表小单买卖成交量、`buy_lg_vol`和`sell_lg_vol`代表大单买卖成交量、`buy_elg_vol`和`sell_elg_vol`代表超大单买卖成交量,以及`net_mf_vol`(净资金流成交量)。通过分析这些不同档位的资金流向,可以判断机构投资者和散户的交易行为。 - -**第五步:加载筹码分布数据** - -最后从`cyq_perf.h5`文件中读取筹码分布相关指标。这些数据反映的是股票持仓成本分布情况,包括历史最低价`his_low`、历史最高价`his_high`、不同成本分位线的价格(`cost_5pct`、`cost_15pct`、`cost_50pct`、`cost_85pct`、`cost_95pct`)、加权平均成本`weight_avg`以及获利盘比例`winner_rate`。这些指标是计算筹码分布因子的核心数据。 - -经过上述五个步骤的依次合并,最终生成的数据框包含9436343条记录、33个字段,占用内存约2.3GB。 - -### 2.2 行业数据加载与合并 - -第四个代码单元处理行业分类数据的加载与合并。首先使用`read_and_merge_h5_data`函数从`industry_data.h5`文件中读取行业分类数据,提取股票代码、二级行业代码`l2_code`以及行业纳入日期`in_date`。 - -随后定义了`merge_with_industry_data`函数,用于将行业数据与主数据框进行时间匹配合并。该函数的核心逻辑分为以下几个步骤:首先确保日期字段转换为datetime类型;然后分别对行业数据和主数据按股票代码和日期排序;接着使用`pd.merge_asof`函数进行向后合并(direction='backward'),这意味着对于每一股票的每一个交易日,会匹配该日之前(包括当日)最近的行业变更记录;对于交易日期早于所有行业变更日期的记录,使用每个股票的最早行业代码进行填充。 - -这一合并策略的设计目的是确保在任意交易日期都能获取到该股票当前所属的行业分类,这对于后续进行行业中性化处理至关重要。 - ---- - -## 第三部分:指数数据处理 - -### 3.1 指数数据读取 - -第五个代码单元实现指数数据的读取与指标计算。定义了两个核心函数:`calculate_indicators`用于计算单个指数的技术指标,`generate_index_indicators`用于批量处理多个指数。 - -在`calculate_indicators`函数中,首先对数据按交易日期排序,然后计算以下指标: - -**当日涨跌幅**:通过`(close - pre_close) / pre_close * 100`计算,这是最基础的收益率指标。 - -**RSI指标**:相对强弱指数,计算过程包括首先计算价格变动delta,然后分离上涨和下跌部分,分别计算14日滚动平均,最后通过公式`100 - (100 / (1 + rs))`得到RSI值。RSI是衡量股票近期涨跌动能的经典技术指标。 - -**MACD指标**:移动平均收敛发散指标,计算过程包括12日EMA减去26日EMA得到MACD线,9日EMA得到信号线,MACD与信号线的差值得到MACD柱。MACD是判断趋势方向和动量变化的重要工具。 - -**情绪因子**: - -- **上涨比例(up_ratio_20d)**:过去20天上涨天数占比,反映市场整体情绪 -- **成交量变化率(volume_change_rate)**:当日成交量与20日均量的比率,反映成交量异常变化 -- **波动率(volatility)**:过去20天日收益率的标准差,反映市场波动程度 -- **成交额变化率(amount_change_rate)**:当日成交额与20日均额的比率 - -`generate_index_indicators`函数则对所有指数分别计算上述指标,然后将结果透视为宽表格式,使每个指数的指标成为独立的列。最终输出包含多个指数技术指标的数据框,这些指标可以作为市场情绪的代理变量加入模型。 - -### 3.2 行业指数数据处理 - -第六个代码单元使用talib库和numpy库进行更专业的技术指标计算。talib是专业的技术分析库,包含数百种经典技术指标的实现。 - -从`sw_daily.h5`文件读取行业日线数据后,依次计算: - -**OBV(能量潮指标)**:通过`talib.OBV`函数计算,将成交量与价格变动方向结合,衡量资金流入流出的强度。 - -**短期收益率(return_5)**:5日收益率,计算公式为`x / x.shift(5) - 1`。 - -**中期收益率(return_20)**:20日收益率,用于捕捉中短期动量。 - -**动量因子(act_factor)**:通过`get_act_factor`函数计算,该函数对不同周期的EMA(指数移动平均)进行arctan变换,得到平滑的动量因子。这是项目自定义的核心因子之一。 - -**收益率分位数排名**:将收益率在截面内转换为百分位排名,消除不同行业间的基数差异,便于跨行业比较。 - -最终对所有新增字段添加`industry_`前缀,并重命名`ts_code`为`cat_l2_code`,形成行业级别的因子数据,可用于后续的行业对标分析。 - ---- - -## 第四部分:财务数据加载 - -### 4.1 财务指标数据 - -第九至十一个代码单元负责加载各类财务数据,为模型增加基本面因子。 - -**财务指标数据(fina_indicator)**:从`fina_indicator.h5`文件读取,包含`undist_profit_ps`(每股未分配利润)、`ocfps`(每股经营现金流)、`bps`(每股净资产)、`roa`(资产收益率)和`roe`(净资产收益率)。这些指标反映企业的盈利能力和财务健康状况。 - -**现金流数据(cashflow)**:从`cashflow.h5`文件读取,包含`n_cashflow_act`(经营活动净现金流),用于评估企业现金流的真实状况。 - -**资产负债表数据(balancesheet)**:从`balancesheet.h5`文件读取,包含`money_cap`(货币资金)和`total_liab`(总负债),用于计算企业的偿债能力。 - -这些财务数据通过`ann_date`(公告日期)与交易数据进行时间匹配,确保使用最新披露的财务信息。需要注意的是,财务数据存在滞后性,通常季报在季度结束后一段时间才披露,因此在匹配时使用向后合并策略。 - ---- - -## 第五部分:数据过滤与清洗 - -### 5.1 数据过滤规则 - -第十一个代码单元定义了`filter_data`函数,应用一系列过滤规则清理数据: - -**排除ST股票**:`df[~df['is_st']]`过滤掉所有ST(特别处理)股票,这类股票存在较大风险。 - -**排除北京股票**:`df[~df['ts_code'].str.endswith('BJ')]`排除北京交易所的股票。 - -**排除创业板和科创板**:`df[~df['ts_code'].str.startswith('30')]`排除创业板股票,`df[~df['ts_code'].str.startswith('68')]`排除科创板股票。这一设计可能是因为这两个板块的股票特性与主板有较大差异。 - -**排除ST类代码**:`df[~df['ts_code'].str.startswith('8')]`排除了以8开头的股票代码,通常这类代码也属于风险警示类。 - -**时间范围过滤**:`df[df['trade_date'] >= '2019-01-01']]`只保留2019年及以后的数据,确保数据质量和一致性。 - -**删除冗余列**:如果存在`in_date`列则删除,避免与交易日期混淆。 - -过滤后的数据量从9436343条减少到5087384条,约减少46%。 - ---- - -## 第六部分:因子计算 - -### 6.1 财务因子计算 - -第十二个代码单元是因子计算的核心部分,首先计算一系列财务相关因子: - -**现金流因子(cashflow_to_ev_factor)**:将现金流数据与企业价值进行对比,计算企业现金流的相对估值水平。 - -**市净率因子(book_to_price_ratio)**:通过`caculate_book_to_price_ratio`函数计算,每股净资产与股价的比值,是价值投资的经典指标。 - -### 6.2 基础技术因子 - -继续计算各类技术因子: - -**换手率均值(turnover_rate_mean_5)**:5日平均换手率,反映股票的交易活跃程度。 - -**收益率方差(variance_20)**:20日收益率方差,衡量短期波动性。 - -**BBI比率因子(bbi_ratio_factor)**:BBI是多空指数的简称,通过不同周期均线的组合判断多空趋势。 - -**日偏离度(daily_deviation)**:当日涨跌幅与市场平均涨跌幅的差异,反映个股的相对强弱。 - -**日行业偏离度(daily_industry_deviation)**:相对于所在行业平均涨跌幅的偏离,反映行业内相对表现。 - -### 6.3 滚动因子与简单因子 - -调用`get_rolling_factor`和`get_simple_factor`函数批量计算大量因子。这两个函数定义在`main/factor/factor.py`中,是项目最核心的因子计算模块。 - -**get_rolling_factor函数**主要计算的因子包括: - -**资金流因子组**: - -- `lg_elg_net_buy_vol`:超大单加大单的净买入量 -- `flow_lg_elg_intensity`:主力资金流强度,等于净买入量除以成交量 -- `sm_net_buy_vol`:散户净买入量 -- `flow_divergence_diff`:散户与主力背离度,主力净买入减去散户净买入 -- `flow_divergence_ratio`:背离比率形式 -- `lg_elg_buy_prop`:主力买入占比 -- `flow_struct_buy_change`:资金流结构变动,1日变化率 -- `flow_lg_elg_accel`:资金流加速度,二阶导数 - -**筹码分布因子组**: - -- `chip_concentration_range`:(95%成本价 - 5%成本价)/ 当前价格,反映筹码集中程度 -- `chip_skewness`:(加权平均成本 - 50%成本价)/ 50%成本价,反映成本分布偏斜方向 -- `floating_chip_proxy`:浮筹比例,结合获利盘和价格位置 -- `cost_support_15pct_change`:15%成本线变化率,反映支撑位变动 -- `cat_winner_price_zone`:获利盘价格区域分类 - -**资金与筹码结合因子**: - -- `flow_chip_consistency`:资金流与筹码结构一致性 -- `profit_taking_vs_absorb`:获利了结压力与承接盘强度 - -**收益率分布因子**: - -- `upside_vol`和`downside_vol`:上涨和下跌方向的标准差 -- `vol_ratio`:上下波动率比值 -- `return_skew`和`return_kurtosis`:收益率的偏度和峰度 - -**成交量因子**: - -- `volume_change_rate`:成交量变化率 -- `cat_volume_breakout`:成交量突破信号 -- `turnover_deviation`:换手率偏离度 -- `cat_turnover_spike`:换手率激增信号 -- `avg_volume_ratio`和`cat_volume_ratio_breakout`:量比相关因子 - -**talib技术指标**: - -- `atr_14`和`atr_6`:14日和6日平均真实波幅 -- `obv`和`maobv_6`:能量潮及其6日均线 -- `rsi_3`:3日RSI - -**收益率因子**: - -- `return_5`和`return_20`:5日和20日收益率 -- `std_return_5`、`std_return_90`、`std_return_90_2`:不同周期的收益率标准差 - -**EMA与动量因子**: - -- 各周期EMA(5、13、20、60日) -- `act_factor1`到`act_factor4`:基于EMA的动量因子 - -**协方差因子**: - -- `cov`:高价与成交量的滚动协方差 -- `delta_cov`:协方差差分 -- `alpha_22_improved`:改进的alpha因子 - -**其他因子**: - -- `alpha_003`:收盘价与开盘价的相对位置 -- `alpha_007`:收盘价与成交量的5日滚动相关性 -- `alpha_013`:5日与20日累计收益差值 -- `vol_break`:价格突破85%成本线且量比大于2的信号 -- `weight_roc5`:加权成本5日变化率 -- `price_cost_divergence`:价格与成本相关性 -- `smallcap_concentration`:小盘股筹码集中度 -- `cost_stability`:筹码稳定性指数 -- `liquidity_risk`:筹码流动性风险 - -**get_simple_factor函数**在滚动因子的基础上进一步计算衍生因子: - -- `momentum_factor`:动量因子,等于成交量变化率加0.5倍换手率偏离度 -- `resonance_factor`:共振因子,量比乘以涨跌幅 -- `log_close`:收盘价对数 -- `cat_vol_spike`:成交量激增分类变量 -- `up`和`down`:上下影线比例 -- `obv_maobv_6`:OBV与6日均线差值 -- `std_return_5_over_std_return_90`:短期与长期波动率比值 -- `act_factor5`和`act_factor6`:综合动量因子 -- `active_buy_volume_*`:各档位主动买入占比 -- `ctrl_strength`:控制强度,反映成本区间与历史区间的比值 -- `low_cost_dev`、`asymmetry`、`lock_factor`:各类筹码特征因子 -- `cat_golden_resonance`:黄金共振信号 - -### 6.4 资金流因子 - -第十三个代码单元继续计算`money_factor.py`中定义的各类资金流相关因子。这些因子主要分析大单资金的流动特征和动量属性。 - -**lg_flow_mom_corr**:大单资金流与20至60日滚动窗口的动量相关性,衡量资金流趋势的持续性。 - -**lg_flow_accel**:大单资金流加速度,捕捉资金流入流出的加速变化。 - -**profit_pressure**:盈利压力因子,结合价格位置和资金流向。 - -**underwater_resistance**:水下阻力因子,分析亏损区域的支持强度。 - -**cost_conc_std_20**:20日成本集中度标准差,反映筹码分布的稳定程度。 - -**profit_decay_20**:20日盈利衰减因子,衡量获利盘随时间的减少程度。 - -**vol_amp_loss_20**:20日波动幅度损失因子。 - -**vol_drop_profit_cnt_5**:5日内量价齐升计数。 - -**lg_flow_vol_interact_20**:20日大单资金流与成交量交互因子。 - -**cost_break_confirm_cnt_5**:5日内成本突破确认计数。 - -**atr_norm_channel_pos_14**:14日ATR标准化通道位置。 - -**turnover_diff_skew_20**:20日换手率差值偏度。 - -**lg_sm_flow_diverge_20**:20日大小单资金流背离。 - -**pullback_strong_20_20**:20日回撤强度因子。 - -**vol_wgt_hist_pos_20**:20日成交量加权历史位置。 - -**vol_adj_roc_20**:20日成交量调整变动率。 - -### 6.5 截面排名因子 - -第十三个代码单元的后半部分定义了一系列`cs_rank_*`开头的截面排名因子,这些因子都是通过截面排名方式计算得出,可以消除不同股票间的基数差异。 - -**cs_rank_net_lg_flow_val**:大单净流入值截面排名 - -**cs_rank_flow_divergence**:资金流背离截面排名 - -**cs_rank_industry_adj_lg_flow**:行业调整后大单资金流截面排名 - -**cs_rank_elg_buy_ratio**:超大单买入占比截面排名 - -**cs_rank_rel_profit_margin**:相对盈利-margin截面排名 - -**cs_rank_cost_breadth**:成本宽度截面排名 - -**cs_rank_dist_to_upper_cost**:到上方成本距离截面排名 - -**cs_rank_winner_rate**:获利盘比例截面排名 - -**cs_rank_intraday_range**:日内振幅截面排名 - -**cs_rank_close_pos_in_range**:收盘价在成本区间位置截面排名 - -**cs_rank_opening_gap**:跳空缺口截面排名(需要前收盘价) - -**cs_rank_pos_in_hist_range**:历史区间位置截面排名 - -**cs_rank_vol_x_profit_margin**:成交量与盈利margin交互截面排名 - -**cs_rank_lg_flow_price_concordance**:大单资金流与价格一致性截面排名 - -**cs_rank_turnover_per_winner**:单位获利盘换手率截面排名 - -**cs_rank_ind_cap_neutral_pe**:行业市值中性市盈率截面排名 - -**cs_rank_volume_ratio**:量比截面排名 - -**cs_rank_elg_buy_sell_sm_ratio**:超大单买卖与小单比值截面排名 - -**cs_rank_cost_dist_vol_ratio**:成本分布成交量比值截面排名 - -**cs_rank_size**:市值规模截面排名 - -经过上述所有因子计算,数据框最终包含181个字段,内存占用约6.5GB。 - ---- - -## 第七部分:数据预处理函数定义 - -### 7.1 特征漂移检测 - -第十四个代码单元定义了`remove_shifted_features`函数,用于检测训练集和测试集之间是否存在特征分布漂移。漂移检测使用两种统计方法: - -**KS检验(Kolmogorov-Smirnov检验)**:比较两个分布是否来自同一分布,通过p值判断,p值小于阈值(默认0.05)则认为存在显著差异。 - -**Wasserstein距离**:衡量两个分布之间的Earth Mover距离,距离大于阈值(默认0.1)则认为漂移严重。 - -如果特征同时满足两个条件,则认为该特征存在漂移并将其移除。这一步骤对于确保模型在测试集上的泛化能力非常重要。 - -### 7.2 标准化处理函数 - -第十五个代码单元定义了三个核心的标准化处理函数: - -**cs_mad_filter(截面MAD去极值函数)**: - -MAD(Median Absolute Deviation,中位绝对偏差)是一种稳健的极值检测方法。具体步骤包括:按日期分组计算每列的中位数,然后计算每个值与中位数的绝对偏差,再次取中位数得到MAD,最后将超出`[median - k * MAD, median + k * MAD]`范围的值截断到边界。默认k=3,scale_factor=1.4826使得MAD约等于正态分布的标准差。 - -**cs_neutralize_market_cap_numpy(市值中性化函数)**: - -对每个交易日的每只股票,使用OLS回归将因子值对市值(取对数)进行回归,提取残差作为中性化后的因子值。这一步骤消除因子中与市值相关的系统性偏差,使因子更能反映股票的真实特性而非市值效应。 - -**cs_zscore_standardize(截面Z-Score标准化函数)**: - -对每个截面(每日)计算各因子的均值和标准差,然后进行Z-Score变换:`Z = (value - mean) / (std + epsilon)`。这使得不同量纲的因子具有可比性,且均值为0、标准差为1。 - -**fill_nan_with_daily_median(截面中位数填充函数)**: - -对每个交易日分别计算各因子的中位数,用该中位数填充该日内的缺失值。这一方法比全局中位数填充更能反映数据的时间特性。 - -### 7.3 其他预处理函数 - -第十五个代码单元还定义了其他辅助预处理函数: - -**remove_outliers_label_percentile**:对标签进行百分位去极值,默认保留1%到99%分位之间的数据。 - -**calculate_risk_adjusted_target**:计算风险调整后的目标变量,结合未来收益率和未来波动性。 - -**calculate_score**:计算综合评分目标,考虑未来收益减去风险惩罚(最大回撤)。 - -**remove_highly_correlated_features**:移除高度相关的特征,避免多重共线性问题。 - -**cross_sectional_standardization**:截面标准化,使用StandardScaler进行标准化。 - -**neutralize_manual_revised**:手动实现的行业市值中性化函数,对每个行业分别进行回归取残差。 - -**mad_filter**:全局MAD去极值函数(简化版)。 - -**percentile_filter**:百分位去极值函数。 - -**iqr_filter**:四分位距标准化函数。 - -**quantile_filter**:滚动分位数去极值函数。 - -**select_top_features_by_rankic**:基于RankIC选择最优特征。 - -**create_deviation_within_dates**:创建截面偏差特征,计算行业内各因子与均值的偏差。 - ---- - -## 第八部分:目标变量构建 - -### 8.1 未来收益率计算 - -第十六个代码单元负责构建预测目标变量。 - -**未来收益率**:通过`df.groupby('ts_code')['close'].shift(-days) / df['close'] - 1`计算,days默认为5,即预测未来5日的收益率。shift(-days)表示向后移动获取未来价格。 - -**涨停标签**:构建分类目标变量,首先识别当日涨幅超过5%的股票(`cat_up_limit`),然后计算过去5日内是否存在涨停(通过rolling窗口的max函数),再向后shift(5)避免未来信息泄露,最后fillna(0)并转换为整数类型。 - -**过滤异常收益率**:使用`between`方法保留1%到99%分位之间的收益率,去除极端值。 - ---- - -## 第九部分:特征列筛选 - -### 9.1 特征列筛选逻辑 - -第十七个代码单元进行特征列的最终筛选。通过一系列排除规则,从所有可用列中筛选出建模所需的特征列: - -**排除非特征列**:排除`trade_date`、`ts_code`、`label`等标识列;排除包含`future`、`label`、`score`、`gen`、`is_st`、`pe_ttm`、`circ_mv`、`code`等关键词的列;排除原始基础数据列`origin_columns`;排除下划线开头的临时列。 - -**排除特定特征**:手动排除存在问题的特征如`intraday_lg_flow_corr_20`、`cap_neutral_cost_metric`、`hurst_net_mf_vol_60`等;排除财务因子`roa`和`roe`。 - -最终筛选出191个特征列用于建模。 - ---- - -## 第十部分:缺失值处理与标准化流程 - -### 10.1 缺失值填充 - -第十八个代码单元将所有特征的缺失值填充为0。这一简化处理的假设是:缺失值可能代表数据不可用或极小概率事件,用0填充对模型影响相对中性。 - -### 10.2 完整的预处理流水线 - -第十九个代码单元展示了完整的预处理流程执行: - -**第一步:特征列确定**——合并行业数据、指数数据后确定最终特征列表。 - -**第二步:MAD去极值**——执行第一轮截面MAD去极值,处理全量特征。 - -**第三步:市值中性化**——对第一轮去极值后的特征进行截面市值中性化。 - -**第四步:第二轮MAD去极值**——对中性化后的特征再次进行去极值处理。 - -**第五步:Z-Score标准化**——最后对所有特征进行截面Z-Score标准化。 - -整个流水线确保了最终进入模型的特征具有:稳定的分布(去极值)、去除市值偏差(中性化)、可比性(标准化)。 - ---- - -## 总结 - -`Classify2_load_model.ipynb`实现了一个完整、规范的量化因子工程流程。整个流程涵盖了数据加载、多源数据合并、因子计算、特征筛选、预处理标准化等关键步骤。 - -**数据层面**:整合了日线行情、资金流、筹码分布、财务指标、行业分类、指数数据等多维度数据源,形成了超过180个特征的基础特征池。 - -**因子层面**:因子体系覆盖了资金流因子、筹码分布因子、技术指标因子、动量因子、波动率因子、截面排名因子等多个类别,形成了多角度、全方位的特征体系。 - -**预处理层面**:建立了完整的预处理流水线,包括MAD去极值、市值中性化、Z-Score标准化等标准化步骤,确保特征的质量和稳定性。 - -整个notebook的设计体现了量化投资特征工程的最佳实践,为后续的机器学习模型训练提供了高质量的数据基础。 diff --git a/docs/code_review_factors_20260222.md b/docs/code_review_factors_20260222.md deleted file mode 100644 index 83f3447..0000000 --- a/docs/code_review_factors_20260222.md +++ /dev/null @@ -1,227 +0,0 @@ -# 代码审查报告 - Factor 框架 - -**审查日期**: 2026-02-22 -**审查范围**: `src/factors/` 模块及测试代码 - ---- - -## 变更概述 - -| 类型 | 文件 | -|------|------| -| **已暂存** | `.kilocode/rules/project_rules.md` - Git 提交规范文档 | -| | `.kilocode/rules/python-development-guidelines.md` - Python 开发规范扩展 | -| **未跟踪** | `src/factors/` - 新增因子计算框架 | -| | `tests/factors/` - 对应测试文件 | -| | `docs/` - 文档目录 | - ---- - -## 文档变更(已暂存) - -✅ 无问题。Git 提交规范添加符合项目风格,格式清晰。 - ---- - -## 新增代码审查(factors 模块) - -### 1. 严重问题 - -#### 1.1 `engine.py:306-323` - 交易日偏移实现假设错误 - -```python -def _get_trading_date_offset(self, date: str, offset: int) -> str: - from datetime import datetime, timedelta - dt = datetime.strptime(date, "%Y%m%d") - new_dt = dt + timedelta(days=offset) - return new_dt.strftime("%Y%m%d") -``` - -**问题**:简单使用日历日偏移,假设每天都是交易日。A股市场有周末和节假日,这会导致: -- 偏移计算不准确 -- `lookback_days` 实际不等于交易天数 -- 可能加载过多或过少的历史数据 - -**建议**:使用真实的交易日历,或至少跳过周末。 - ---- - -#### 1.2 `base.py:137-147` - 乘法运算符类型检查不完整 - -```python -def __mul__(self, other): - if isinstance(other, (int, float)): - from src.factors.composite import ScalarFactor - return ScalarFactor(self, float(other), "*") - elif isinstance(other, BaseFactor): - from src.factors.composite import CompositeFactor - return CompositeFactor(self, other, "*") - return NotImplemented -``` - -**问题**:`float` 类型的负数会匹配 `int` 分支,但 `bool` 是 `int` 的子类,会被错误匹配: -```python -factor * True # 返回 ScalarFactor(factor, 1.0, "*") - 可能不是预期行为 -factor * False # 返回 ScalarFactor(factor, 0.0, "*") - 可能不是预期行为 -``` - -**建议**:显式排除 `bool` 类型: -```python -if isinstance(other, (int, float)) and not isinstance(other, bool): -``` - ---- - -#### 1.3 `engine.py:60-72` - compute 方法缺少必需参数验证 ✅ 已修复 - -**修复内容**: -- 为截面因子添加了 `start_date` 和 `end_date` 必填参数验证 -- 为时序因子添加了 `stock_codes`、`start_date`、`end_date` 必填参数验证 -- 参数缺失时抛出明确的 `ValueError`,指出缺少哪些参数 - -**修复代码**: -```python -if factor.factor_type == "cross_sectional": - if "start_date" not in kwargs or "end_date" not in kwargs: - raise ValueError( - "cross_sectional factor requires 'start_date' and 'end_date' parameters" - ) -elif factor.factor_type == "time_series": - missing = [] - if "stock_codes" not in kwargs: - missing.append("stock_codes") - if "start_date" not in kwargs: - missing.append("start_date") - if "end_date" not in kwargs: - missing.append("end_date") - if missing: - raise ValueError( - f"time_series factor requires parameters: {', '.join(missing)}" - ) -``` - ---- - -### 2. 中等问题 - -#### 2.1 `base.py:92-101` - 参数验证时机问题 ✅ 已修复 - -**修复内容**: -- 移除了 `__init__` 中自动调用的 `self._validate_params()` -- 更新了 `_validate_params()` 的文档,明确说明子类如需自定义验证,需自行在子类 `__init__` 中调用 -- 添加了关于 `data_specs` 必须在类级别定义的说明 - -**修复代码**: -```python -def __init__(self, **params): - """初始化因子参数 - - 注意:data_specs 必须在类级别定义(类属性), - 而非在 __init__ 中设置。data_specs 的验证在 - __init_subclass__ 中完成(类创建时)。 - """ - self.params = params - -def _validate_params(self): - """验证参数有效性 - - 子类可覆盖此方法进行自定义验证(需自行在子类 __init__ 中调用)。 - 基类实现为空,表示不执行任何验证。 - - 注意:由于 data_specs 在类创建时通过 __init_subclass__ 验证, - 不应在实例级别修改。如需动态 data_specs,请使用参数化模式: - ... - """ - pass -``` - ---- - -#### 2.2 `engine.py:161-169` - 静默修复长度不匹配 - -```python -if len(factor_values) != len(today_stocks): - # 尝试从 factor_data 重新提取 - cs_data = factor_data.get_cross_section() - if len(cs_data) > 0: - today_stocks = cs_data["ts_code"] - if len(factor_values) != len(today_stocks): - factor_values = pl.Series([None] * len(today_stocks)) # 静默填充 null! -``` - -**问题**:静默返回 null 值可能掩盖因子计算中的逻辑错误。开发者难以发现因子实现问题。 - -**建议**: -- 至少记录警告日志 -- 或在开发环境抛出异常 - ---- - -#### 2.3 `data_spec.py:14` - 使用 `frozen=True` 但通过 `object.__setattr__` 绕过 ✅ 已修复 - -**修复内容**: -- 更新了 `__post_init__` 的注释,准确说明 `frozen=True` 的含义 -- 说明本类仅做验证,无需修改字段,因此直接 raise ValueError 即可 -- 补充说明如需在 `__post_init__` 中修改字段,可使用 `object.__setattr__` - -**修复代码**: -```python -def __post_init__(self): - """验证约束条件 - - 验证项: - 1. lookback_days >= 1(至少包含当日) - 2. columns 必须包含 ts_code 和 trade_date - 3. source 不能为空字符串 - - 注意:由于 frozen=True,实例创建后不可修改。 - 若需要在 __post_init__ 中修改字段(如有),可使用 object.__setattr__。 - 本类仅做验证,无需修改字段,因此直接 raise ValueError 即可。 - """ - if self.lookback_days < 1: - raise ValueError(f"lookback_days must be >= 1, got {self.lookback_days}") -``` - ---- - -### 3. 轻微问题 / 优化建议 - -#### 3.1 `engine.py` - 缺少类型注解 - -`compute()` 方法返回类型注解为 `pl.DataFrame`,但 `_compute_cross_sectional` 和 `_compute_time_series` 返回类型未标注。 - -#### 3.2 `data_spec.py:42` - 默认值 `lookback_days=1` 语义 - -注释说 "包含当日",但 `lookback_days=1` 实际只包含 `[T]`,这与注释中 `lookback_days=5` 表示 `[T-4, T]` 一致。 - ---- - -## 测试覆盖 - -测试覆盖良好,共 97 个测试用例,覆盖: -- 因子基类验证 -- 组合因子运算 -- 数据加载器 -- 数据规格定义 -- 引擎执行逻辑 - ---- - -## 总结 - -| 状态 | 严重 | 中等 | 轻微 | -|------|------|------|------| -| 修复前 | 3 | 3 | 2 | -| 修复后 | 2 | 1 | 2 | - -**已修复问题**: -- ✅ 1.3 compute 方法参数验证 -- ✅ 2.1 参数验证时机问题 -- ✅ 2.3 frozen dataclass 注释 - -**待修复问题**: -- ⚠️ 1.1 交易日偏移实现(严重) -- ⚠️ 1.2 乘法运算符 bool 边界(严重) -- ⚠️ 2.2 静默修复长度不匹配(中等) - -Factor 框架整体设计良好,测试覆盖全面。建议修复剩余 2 个严重问题后再合并代码。 diff --git a/docs/db_sync_guide.md b/docs/db_sync_guide.md deleted file mode 100644 index d9e6a4f..0000000 --- a/docs/db_sync_guide.md +++ /dev/null @@ -1,267 +0,0 @@ -# DuckDB 数据同步指南 - -ProStock 现已从 HDF5 迁移至 DuckDB 存储。本文档介绍新的同步机制。 - -## 新功能概览 - -- **自动表创建**: 根据 DataFrame 自动推断表结构 -- **复合索引**: 自动为 `(trade_date, ts_code)` 创建复合索引 -- **增量同步**: 智能判断同步策略(按日期或按股票) -- **类型映射**: 预定义常见字段的数据类型 - -## 核心模块 - -### 1. TableManager - 表管理 - -```python -from src.data.db_manager import TableManager - -# 创建表管理器 -manager = TableManager() - -# 从 DataFrame 创建表(自动创建复合索引) -import pandas as pd -data = pd.DataFrame({ - "ts_code": ["000001.SZ"], - "trade_date": ["20240101"], - "close": [10.5], -}) - -manager.create_table_from_dataframe("daily", data) - -# 确保表存在(不存在则自动创建) -manager.ensure_table_exists("daily", sample_data=data) -``` - -### 2. IncrementalSync - 增量同步 - -```python -from src.data.db_manager import IncrementalSync - -sync = IncrementalSync() - -# 获取同步策略 -strategy, start, end, stocks = sync.get_sync_strategy( - table_name="daily", - start_date="20240101", - end_date="20240131", - stock_codes=None # None = 所有股票 -) - -# 返回值: -# - strategy: "by_date" | "by_stock" | "none" -# - start: 同步开始日期 -# - end: 同步结束日期 -# - stocks: 需要同步的股票列表(None = 全部) - -# 执行数据同步 -result = sync.sync_data("daily", data, strategy="by_date") -``` - -### 3. SyncManager - 高级同步 - -```python -from src.data.db_manager import SyncManager -from src.data.api_wrappers import get_daily - -# 创建同步管理器 -manager = SyncManager() - -# 一键同步(自动处理表创建、策略选择、数据获取) -result = manager.sync( - table_name="daily", - fetch_func=get_daily, # 数据获取函数 - start_date="20240101", - end_date="20240131", - stock_codes=["000001.SZ", "600000.SH"] # 可选:指定股票 -) - -print(result) -# { -# "status": "success", -# "table": "daily", -# "strategy": "by_date", -# "rows": 1000, -# "date_range": "20240101 to 20240131" -# } -``` - -## 便捷函数 - -### 快速同步数据 - -```python -from src.data.db_manager import sync_table -from src.data.api_wrappers import get_daily - -# 同步日线数据 -result = sync_table( - table_name="daily", - fetch_func=get_daily, - start_date="20240101", - end_date="20240131" -) -``` - -### 获取表信息 - -```python -from src.data.db_manager import get_table_info - -# 查看表统计信息 -info = get_table_info("daily") -print(info) -# { -# "exists": True, -# "row_count": 100000, -# "min_date": "20240101", -# "max_date": "20240131", -# "unique_stocks": 5000 -# } -``` - -### 确保表存在 - -```python -from src.data.db_manager import ensure_table - -# 如果表不存在,使用 sample_data 创建 -ensure_table("daily", sample_data=df) -``` - -## 同步策略详解 - -### 1. 按日期同步 (by_date) - -**适用场景**: 全市场数据同步、每日增量更新 - -**逻辑**: -- 表不存在 → 全量同步 -- 表存在但空 → 全量同步 -- 表存在且有数据 → 从 `last_date + 1` 开始增量同步 - -```python -# 示例: 表已有数据到 20240115 -strategy, start, end, stocks = sync.get_sync_strategy( - "daily", "20240101", "20240131" -) -# 返回: ("by_date", "20240116", "20240131", None) -# 只需同步 16-31 号的新数据 -``` - -### 2. 按股票同步 (by_stock) - -**适用场景**: 补充特定股票的历史数据 - -**逻辑**: -- 检查哪些请求的股票不存在于表中 -- 仅同步缺失的股票 - -```python -# 示例: 表中已有 000001.SZ,请求两只股票 -strategy, start, end, stocks = sync.get_sync_strategy( - "daily", "20240101", "20240131", - stock_codes=["000001.SZ", "600000.SH"] -) -# 返回: ("by_stock", "20240101", "20240131", ["600000.SH"]) -# 只同步缺失的 600000.SH -``` - -### 3. 无需同步 (none) - -**适用场景**: 数据已是最新 - -**触发条件**: -- 表存在且日期已覆盖请求范围 -- 所有请求的股票都已存在 - -## 完整示例 - -```python -from src.data.db_manager import SyncManager, get_table_info -from src.data.api_wrappers import get_daily - -# 1. 查看当前表状态 -info = get_table_info("daily") -print(f"当前数据: {info['row_count']} 行, 最新日期: {info['max_date']}") - -# 2. 创建同步管理器 -manager = SyncManager() - -# 3. 执行同步 -result = manager.sync( - table_name="daily", - fetch_func=get_daily, - start_date="20240101", - end_date="20240222" -) - -# 4. 检查结果 -if result["status"] == "success": - print(f"成功同步 {result['rows']} 行数据") - print(f"使用策略: {result['strategy']}") -elif result["status"] == "skipped": - print("数据已是最新,无需同步") -else: - print(f"同步失败: {result.get('error')}") -``` - -## 类型映射 - -默认字段类型映射: - -```python -DEFAULT_TYPE_MAPPING = { - "ts_code": "VARCHAR(16)", - "trade_date": "DATE", - "open": "DOUBLE", - "high": "DOUBLE", - "low": "DOUBLE", - "close": "DOUBLE", - "pre_close": "DOUBLE", - "change": "DOUBLE", - "pct_chg": "DOUBLE", - "vol": "DOUBLE", - "amount": "DOUBLE", - "turnover_rate": "DOUBLE", - "volume_ratio": "DOUBLE", - "adj_factor": "DOUBLE", - "suspend_flag": "INTEGER", -} -``` - -未定义字段会根据 pandas dtype 自动推断: -- `int` → `INTEGER` -- `float` → `DOUBLE` -- `bool` → `BOOLEAN` -- `datetime` → `TIMESTAMP` -- 其他 → `VARCHAR` - -## 索引策略 - -自动创建的索引: - -1. **主键**: `(ts_code, trade_date)` - 确保数据唯一性 -2. **复合索引**: `(trade_date, ts_code)` - 优化按日期查询性能 - -## 与旧代码的兼容性 - -原有 `Storage` 和 `ThreadSafeStorage` API 保持不变: - -```python -from src.data.storage import Storage, ThreadSafeStorage - -# 旧代码继续可用 -storage = Storage() -storage.save("daily", data) -df = storage.load("daily", start_date="20240101") -``` - -新增的功能通过 `db_manager` 模块提供。 - -## 性能建议 - -1. **批量写入**: 使用 `SyncManager` 自动处理批量写入 -2. **避免重复查询**: 使用 `get_table_info()` 检查现有数据 -3. **合理选择策略**: 全市场更新用 `by_date`,补充数据用 `by_stock` -4. **利用索引**: 查询时优先使用 `trade_date` 和 `ts_code` 过滤 diff --git a/docs/factor_design.md b/docs/factor_design.md deleted file mode 100644 index 46cb9d9..0000000 --- a/docs/factor_design.md +++ /dev/null @@ -1,707 +0,0 @@ - - -# 🚀 量化因子计算框架抽象设计与实施蓝图 - -## 一、 系统架构设计(四层解耦模型) - -本系统采用严格的分层架构,每一层只需关注自己的输入与输出,层与层之间通过标准化的数据结构(如抽象语法树、需求清单、物理执行图)进行通信。 - -### 1. 领域特定语言层(DSL Layer / 用户层) -* **职责**:提供对量化研究员极度友好的因子表达式编写接口,屏蔽所有底层计算引擎和数据库的痕迹。 -* **输入**:研究员编写的数学与逻辑表达式。 -* **输出**:纯粹的、无状态的**抽象语法树(AST)**。 -* **边界约束**:本层绝对不允许依赖任何外部数据处理库。它只负责描述“计算逻辑是什么”,不涉及“怎么算”和“数据在哪”。 - -### 2. 编译与分析层(Compiler Layer / 解析层) -* **职责**:接收 DSL 层生成的 AST,进行语法树分析与优化。 -* **核心动作 1:依赖提取**。遍历语法树,找出所有的“叶子节点”(即基础数据字段),生成全局数据需求清单。 -* **核心动作 2:图优化(可选)**。识别重复的子表达式结构,进行合并计算标记。 -* **输出**:结构化的数据依赖清单(Set/List)和经过校验的 AST。 - -### 3. 动态数据路由层(Data Router Layer / IO 层) -* **职责**:充当量化系统与底层多表数据库之间的桥梁。 -* **核心逻辑**:基于元数据字典(记录字段所属的数据库表及数据频度),将分析层传递的“数据需求清单”转化为对数据库的最优查询指令。 -* **输出**:在内存中组装好的、经过严格时间对齐与防未来函数处理的、极简的数据上下文(Data Context)。 - -### 4. 物理执行引擎层(Execution Engine / 计算层) -* **职责**:将抽象的计算逻辑映射到具体的硬件或高性能计算库(如 Polars/向量化引擎)上并执行。 -* **核心逻辑**:遍历 AST,将其翻译为物理引擎的执行算子。在这个翻译过程中,**系统隐式地强制注入量化计算的安全规则**(如截面分组、时序分组)。 -* **输出**:最终的因子计算结果(面板数据表)。 - ---- - -## 二、 核心机制的具体实现逻辑(非代码描述) - -为了让 AI 准确理解你的意图,你需要向 AI 阐明以下四个核心逻辑的运作机制: - -### 1. 表达式树的生成机制 (符号化运算) -* **逻辑说明**:定义基础的变量节点(代表底层字段)和操作节点(代表加减乘除或函数)。通过重载面向对象语言的原生运算符(如算术运算符、比较运算符),使得变量节点参与运算时,不会抛出错误或执行计算,而是生成一个新的、包含左右子节点和操作符的父节点。 -* **结果**:一个复杂的数学公式最终在内存里会变成一棵树状的数据结构。 - -### 2. 动态 SQL 生成与按需加载机制 -* **逻辑说明**:系统初始化时,加载一次数据库元数据(表名、列名、更新频率),形成路由字典。当收到需求清单时,系统不使用 `SELECT *`,而是通过路由字典找到字段对应的表,动态拼接 `SELECT [必要关联键], [需求字段] FROM [表名] WHERE[时间与股票池过滤]`。 -* **结果**:极大降低数据库的 I/O 压力和网络传输负载。 - -### 3. 数据对齐与防未来函数机制(极其重要) -数据在内存中合并时,必须根据表的“频度属性”采取不同的关联策略: -* **同行频表(如日频基础与日频行情)**:以基准时间轴为左表,严格按照 `[资产标识, 交易日]` 进行精确匹配连接。 -* **低频事件表(如财务报表)**:绝不能按自然日期或报告期关联。必须以“财报实际披露日”作为右表时间键,采用**“就近向后寻找匹配(Asof Join / Point-in-Time Join)”**策略。即某一天的财务数据,只能使用该日期之前(含当天)最新发布的那份财报。 -* **防错铁律**:拼表完成后,必须强制按照 `[资产标识, 交易日]` 的优先级进行升序排序,为后续的滑动窗口计算提供物理连续性保障。 - -### 4. 算子翻译与引擎方言注入机制 -物理层在将 AST 翻译为引擎执行图时,必须自动附加以下安全约束,这是研究员无需关心但系统必须保证的: -* **时序算子(如移动平均、动量)**:翻译时,必须向引擎下达强制指令——“本计算窗口必须被严格限制在单一资产的边界内”。 -* **截面算子(如截面排名、行业中性化)**:翻译时,必须向引擎下达强制指令——“本计算必须在同一个交易日切片内横向展开”。 - ---- - -## 三、 Vibe Coding 实施与 Prompt 投喂计划 - -在利用 AI 编写代码时,建议按照以下阶段逐步进行(可作为每个阶段发给 AI 的指令纲要): - -### 里程碑 1:构建抽象语法树引擎 (DSL & AST) -* **任务指派**:要求 AI 设计一套纯粹的表达式树数据结构。包含基础节点类、变量节点类、二元/一元操作节点类、以及函数调用节点类。 -* **验收标准**:通过重载运算符,可以随意组合变量(如 A, B, C),并且编写一个简单的打印函数,能够以可视化的方式(或 JSON 结构)输出这棵树的层次关系。绝不包含任何第三方数据处理库。 - -### 里程碑 2:实现依赖解析器 (Compiler) -* **任务指派**:要求 AI 编写一个树遍历器(如使用 Visitor 模式)。该遍历器接收里程碑 1 产生的树根节点,递归访问所有分支,收集所有叶子节点(变量节点)的名称。 -* **验收标准**:输入一个深层嵌套的复杂公式树,解析器能够准确、去重地返回该公式依赖的所有底层基础字段名称列表。 - -### 里程碑 3:构建元数据路由与动态组装器 (Data Router) -* **任务指派**:要求 AI 设计一个数据上下文管理器。 - 1. 实现注册机制,能接收不同表的数据字典和频度属性(日频或 PIT 低频)。 - 2. 根据里程碑 2 提取的依赖列表,自动分配表归属,并生成最小化拉取数据的伪代码或抽象 SQL 查询计划。 - 3. 阐明并在代码结构中实现不同频度数据的合并对齐逻辑(精确连接与就近前向连接),以及最后的全局强制排序逻辑。 -* **验收标准**:输入几个测试字段,管理器能正确输出不同表的查询指令清单,并展现合并逻辑的抽象流程。 - -### 里程碑 4:构建物理引擎翻译器 (Translator) -* **任务指派**:指定一个高性能计算库(如 Polars)。要求 AI 编写一个翻译层,接收里程碑 1 的树节点,递归转化为该计算库的原生表达式对象。 -* **验收约束**:在这个环节,要求 AI 必须在翻译时序函数时自动附加资产分组属性,在翻译截面函数时自动附加日期分组属性。 -* **验收标准**:输入的抽象树被成功转化为计算引擎可以识别的执行计划对象,且分组属性被正确挂载。 - -### 里程碑 5:系统顶层编排与端到端测试 (Orchestrator) -* **任务指派**:要求 AI 编写一个对外的 `FactorEngine` 类,作为系统的统一入口。 -* **执行流编排**:接收研究员的表达式 -> 调用编译器解析依赖 -> 调用路由器连接数据库拉取并组装核心宽表 -> 调用翻译器生成物理执行计划 -> 将计划提交给计算引擎执行并行运算。 -* **验收标准**:模拟少量的内存数据作为假数据库,完整跑通一条“从表达式注册,到自动按需取数,最终输出包含因子结果数据表”的全流程链路。 - ---- - - - -## 四、 详细设计规范(新增) - -### 4.1 五层架构总览 - -``` -┌─────────────────────────────────────────────────────────────────┐ -│ Layer 5: 编排层 (Orchestrator) │ -│ - FactorEngine: 统一入口 │ -│ - 协调各层工作流 │ -└─────────────────────────────────────────────────────────────────┘ - ↓ -┌─────────────────────────────────────────────────────────────────┐ -│ Layer 4: 物理执行引擎层 (Execution Engine) │ -│ - PolarsTranslator: AST → Polars表达式 │ -│ - 自动注入分组约束(截面/时序) │ -│ - 执行计算并返回结果 │ -└─────────────────────────────────────────────────────────────────┘ - ↓ -┌─────────────────────────────────────────────────────────────────┐ -│ Layer 3: 动态数据路由层 (Data Router) │ -│ - MetadataRegistry: 字段→表映射 │ -│ - QueryPlanner: 生成最优查询计划 │ -│ - DataAligner: PIT对齐与防未来函数处理 │ -└─────────────────────────────────────────────────────────────────┘ - ↓ -┌─────────────────────────────────────────────────────────────────┐ -│ Layer 2: 编译与分析层 (Compiler) │ -│ - DependencyExtractor: 提取数据依赖 │ -│ - GraphOptimizer: 子表达式合并(预留接口) │ -│ - 输出: 数据需求清单 + 优化后的AST │ -└─────────────────────────────────────────────────────────────────┘ - ↓ -┌─────────────────────────────────────────────────────────────────┐ -│ Layer 1: DSL层 (领域特定语言) │ -│ - AST节点: Field, BinaryOp, UnaryOp, FunctionCall, Constant │ -│ - 算子库: ts_* (时序), cs_* (截面), math_* (数学) │ -│ - 运算符重载: +, -, *, /, >, <, == 等 │ -└─────────────────────────────────────────────────────────────────┘ -``` - ---- - -### 4.2 Layer 1: DSL层详细设计 - -#### 核心设计原则 -- **算子与数据解耦**:算子只描述计算逻辑,不绑定具体数据 -- **纯表达式树**:输出无状态的AST,不涉及任何外部库 -- **延迟执行**:表达式构建时不执行计算,只生成树结构 - -#### AST节点类型体系 - -```python -# 节点基类 -class ASTNode(ABC): - """AST节点基类""" - - @abstractmethod - def accept(self, visitor: "NodeVisitor") -> Any: - """接受访问者""" - pass - - @abstractmethod - def get_children(self) -> List["ASTNode"]: - """获取子节点列表""" - pass - -# 1. 字段节点(叶子节点) -class Field(ASTNode): - """ - 字段节点 - 代表底层数据字段 - 示例: close, volume, pe, pb - """ - name: str # 字段名 - dtype: Optional[str] = None # 数据类型提示 - -# 2. 常量节点(叶子节点) -class Constant(ASTNode): - """ - 常量节点 - 代表常量值 - 示例: 5, 10.5, "20240101" - """ - value: Union[int, float, str] - dtype: str - -# 3. 二元操作节点 -class BinaryOp(ASTNode): - """ - 二元操作节点 - 支持的运算符: +, -, *, /, //, %, **, >, >=, <, <=, ==, !=, &, | - """ - op: str # '+', '-', '*', '/', '>', etc. - left: ASTNode - right: ASTNode - -# 4. 一元操作节点 -class UnaryOp(ASTNode): - """ - 一元操作节点 - 支持的运算符: -, +, ~, abs - """ - op: str # '-', '+', '~', 'abs' - operand: ASTNode - -# 5. 函数调用节点 -class FunctionCall(ASTNode): - """ - 函数调用节点 - 代表算子调用 - 示例: ts_mean(close, 20), cs_rank(pe) - """ - name: str # 函数名 - args: List[ASTNode] - kwargs: Dict[str, Any] - func_type: str # "timeseries" | "cross_sectional" | "math" -``` - -#### 运算符重载规则 - -在 ASTNode 基类中实现运算符重载: - -```python -class ASTNode: - # 算术运算符 - def __add__(self, other) -> BinaryOp: - return BinaryOp("+", self, _ensure_node(other)) - - def __sub__(self, other) -> BinaryOp: - return BinaryOp("-", self, _ensure_node(other)) - - def __mul__(self, other) -> BinaryOp: - return BinaryOp("*", self, _ensure_node(other)) - - def __truediv__(self, other) -> BinaryOp: - return BinaryOp("/", self, _ensure_node(other)) - - # 反向运算符(支持 5 * field) - def __radd__(self, other) -> BinaryOp: - return BinaryOp("+", _ensure_node(other), self) - - def __rmul__(self, other) -> BinaryOp: - return BinaryOp("*", _ensure_node(other), self) - - # 比较运算符 - def __gt__(self, other) -> BinaryOp: - return BinaryOp(">", self, _ensure_node(other)) - - def __lt__(self, other) -> BinaryOp: - return BinaryOp("<", self, _ensure_node(other)) - - # 一元运算符 - def __neg__(self) -> UnaryOp: - return UnaryOp("-", self) -``` - -#### 算子库规范 - -算子按功能分为三类: - -| 前缀 | 类别 | 说明 | 示例 | -|------|------|------|------| -| `ts_` | 时序算子 | 在时间序列上计算,需按股票分组 | `ts_mean`, `ts_std`, `ts_sum` | -| `cs_` | 截面算子 | 在截面上计算,需按日期分组 | `cs_rank`, `cs_zscore`, `cs_percentile` | -| `math_` | 数学算子 | 逐元素计算,无需分组 | `math_log`, `math_exp`, `math_sqrt` | - -**时序算子列表(ts_*)**: -```python -ts_mean(field, window: int) # 移动平均 -ts_std(field, window: int) # 移动标准差 -ts_sum(field, window: int) # 移动求和 -ts_max(field, window: int) # 移动最大值 -ts_min(field, window: int) # 移动最小值 -ts_delta(field, period: int = 1) # 差分 -ts_pct_change(field, period: int = 1) # 百分比变化 -ts_corr(f1, f2, window: int) # 滚动相关系数 -``` - -**截面算子列表(cs_*)**: -```python -cs_rank(field) # 截面排名(0-1) -cs_percentile(field) # 截面分位数 -cs_zscore(field) # Z-Score标准化 -cs_mean(field) # 截面均值 -cs_std(field) # 截面标准差 -``` - -**数学算子列表(math_*)**: -```python -math_log(field) # 自然对数 -math_exp(field) # 指数 -math_sqrt(field) # 平方根 -math_abs(field) # 绝对值 -``` - -#### 表达式构建示例 - -```python -from src.factors.dsl import Field, ts_mean, cs_rank - -# ========== 示例 1: 简单移动平均线因子 ========== -close = Field("close") -ma20 = ts_mean(close, 20) -factor1 = ma20 - -# ========== 示例 2: 双均线差值因子 ========== -close = Field("close") -ma20 = ts_mean(close, 20) -ma5 = ts_mean(close, 5) -factor2 = (ma20 - ma5) / close - -# ========== 示例 3: 复杂多因子组合 ========== -close = Field("close") -volume = Field("volume") -pe = Field("pe") - -price_momentum = ts_pct_change(close, 20) -vol_ma = ts_mean(volume, 20) -vol_ratio = volume / vol_ma -pe_rank = cs_rank(pe) - -factor3 = price_momentum * 0.4 + vol_ratio * 0.3 + pe_rank * 0.3 -``` - ---- - -### 4.3 Layer 2: 编译层详细设计 - -#### 依赖提取器 - -```python -class DependencyExtractor(NodeVisitor): - """ - 依赖提取器 - 遍历AST收集数据依赖 - 输出: DataRequirement - - fields: Set[str] 需要的字段列表 - - min_lookback: Dict[str, int] 每个字段的最小回看天数 - """ - - def __init__(self): - self.fields: Set[str] = set() - self.field_lookback: Dict[str, int] = defaultdict(int) - - def visit_field(self, node: Field) -> None: - """记录字段依赖""" - self.fields.add(node.name) - self.field_lookback[node.name] = max( - self.field_lookback[node.name], 1 - ) - - def visit_function_call(self, node: FunctionCall) -> None: - """处理函数调用,提取窗口参数""" - for arg in node.args: - arg.accept(self) - - if node.func_type == "timeseries": - window = self._extract_window(node) - self._update_lookback(node.args[0], window) - - def extract(self, root: ASTNode) -> DataRequirement: - """执行提取""" - root.accept(self) - return DataRequirement( - fields=self.fields, - lookback=dict(self.field_lookback) - ) -``` - -#### 数据需求规格 - -```python -@dataclass -class DataRequirement: - """ - 数据需求规格 - - 属性: - fields: 需要的字段集合 - lookback: 每个字段需要回看的天数 - date_range: 计算日期范围 (start, end) - """ - fields: Set[str] - lookback: Dict[str, int] - date_range: Optional[Tuple[str, str]] = None - - def get_max_lookback(self) -> int: - """获取最大回看天数""" - return max(self.lookback.values()) if self.lookback else 1 -``` - ---- - -### 4.4 Layer 3: 数据路由层详细设计 - -#### 元数据注册表 - -```python -@dataclass -class FieldMetadata: - """ - 字段元数据 - - 属性: - name: 字段名 - table: 所属表名 - dtype: 数据类型 - freq: 数据频度 ("daily", "quarterly", "pit") - announce_date_field: 公告日字段名(PIT数据使用) - """ - name: str - table: str - dtype: str - freq: str - announce_date_field: Optional[str] = None - -class MetadataRegistry: - """ - 元数据注册表 - 管理字段到表的映射 - 单例模式,系统启动时加载配置 - """ - - def register(self, metadata: FieldMetadata) -> None: - """注册字段元数据""" - pass - - def get_table(self, field: str) -> str: - """获取字段所属表""" - pass - - def group_by_table(self, fields: Set[str]) -> Dict[str, Set[str]]: - """按表分组字段""" - pass -``` - -#### PIT对齐策略 - -```python -class DataAligner: - """ - 数据对齐器 - 处理多表数据合并与PIT对齐 - """ - - def align( - self, - dataframes: Dict[str, pl.DataFrame], - plans: List[QueryPlan] - ) -> pl.DataFrame: - """ - 对齐并合并多个数据表 - - 步骤: - 1. 分离日频表和PIT表 - 2. 日频表直接join - 3. PIT表使用asof join - 4. 最终排序 - """ - pass - - def _asof_join( - self, - left: pl.DataFrame, - right: pl.DataFrame, - announce_date_field: str - ) -> pl.DataFrame: - """ - 执行PIT asof join - 策略: 对于每个交易日,使用最新公告的数据 - """ - return left.join_asof( - right, - left_on="trade_date", - right_on=announce_date_field, - by="ts_code", - strategy="backward" - ) -``` - ---- - -### 4.5 Layer 4: 执行引擎层详细设计 - -#### Polars翻译器 - -```python -class PolarsTranslator(NodeVisitor): - """ - Polars翻译器 - 将AST翻译为Polars表达式 - """ - - def __init__(self, df: pl.LazyFrame): - self.df = df - - def translate(self, root: ASTNode) -> pl.Expr: - """翻译AST为Polars表达式""" - return root.accept(self) - - def visit_field(self, node: Field) -> pl.Expr: - """字段 → pl.col()""" - return pl.col(node.name) - - def visit_binary_op(self, node: BinaryOp) -> pl.Expr: - """二元操作 → Polars运算符""" - left = node.left.accept(self) - right = node.right.accept(self) - - ops = { - "+": lambda a, b: a + b, - "-": lambda a, b: a - b, - "*": lambda a, b: a * b, - "/": lambda a, b: a / b, - } - - return ops[node.op](left, right) - - def visit_function_call(self, node: FunctionCall) -> pl.Expr: - """ - 函数调用 → Polars窗口函数 - 关键:根据func_type注入分组约束 - """ - args = [arg.accept(self) for arg in node.args] - impl = self._get_impl(node.name) - - if node.func_type == "timeseries": - return impl(*args).over("ts_code") - elif node.func_type == "cross_sectional": - return impl(*args).over("trade_date") - else: - return impl(*args) -``` - -#### 分组约束注入规则 - -```python -# 时序算子:按股票分组,确保滚动窗口不跨股票 -def inject_timeseries_constraint(expr: pl.Expr) -> pl.Expr: - return expr.over("ts_code") - -# 截面算子:按日期分组,确保排名在每天内部进行 -def inject_cross_sectional_constraint(expr: pl.Expr) -> pl.Expr: - return expr.over("trade_date") -``` - ---- - -### 4.6 Layer 5: 编排层详细设计 - -#### FactorEngine - -```python -class FactorEngine: - """ - 因子执行引擎 - 系统统一入口 - """ - - def __init__( - self, - data_source: DataSource, - registry: MetadataRegistry - ): - self.data_source = data_source - self.registry = registry - self.compiler = Compiler() - self.planner = QueryPlanner(registry) - self.aligner = DataAligner() - - def compute( - self, - expression: ASTNode, - start_date: str, - end_date: str, - stock_codes: Optional[List[str]] = None - ) -> pl.DataFrame: - """ - 计算因子表达式 - - 执行流程: - 1. 编译:提取数据依赖 - 2. 规划:生成查询计划 - 3. 加载:从数据源获取数据 - 4. 对齐:PIT对齐与合并 - 5. 翻译:AST → Polars表达式 - 6. 执行:计算并返回结果 - """ - # Step 1: 编译 - requirement = self.compiler.extract_dependency(expression) - requirement.date_range = (start_date, end_date) - - # Step 2: 规划 - plans = self.planner.plan(requirement) - - # Step 3: 加载 - raw_data = {} - for plan in plans: - df = self.data_source.load(...) - raw_data[plan.table] = df - - # Step 4: 对齐 - aligned_data = self.aligner.align(raw_data, plans) - - # Step 5: 翻译 - translator = PolarsTranslator(aligned_data.lazy()) - polars_expr = translator.translate(expression) - - # Step 6: 执行 - result = aligned_data.with_columns( - polars_expr.alias("factor_value") - ) - - return result -``` - ---- - -## 五、 实施路线图(详细版) - -### 阶段1: 基础架构(Layer 1 + Layer 2) -**目标**: 实现DSL表达式树和依赖提取 - -**任务清单**: -- [ ] 实现AST节点类(Field, Constant, BinaryOp, UnaryOp, FunctionCall) -- [ ] 实现运算符重载 -- [ ] 实现基础算子库(ts_mean, ts_std, cs_rank等) -- [ ] 实现DependencyExtractor -- [ ] 编写单元测试 - -**验收标准**: -```python -close = Field("close") -factor = ts_mean(close, 20) / close - -deps = extract_dependencies(factor) -assert deps.fields == {"close"} -assert deps.lookback == {"close": 20} -``` - -### 阶段2: 数据层(Layer 3) -**目标**: 实现元数据管理和PIT对齐 - -**任务清单**: -- [ ] 实现MetadataRegistry -- [ ] 实现QueryPlanner -- [ ] 实现DataAligner(含asof join) -- [ ] 集成DuckDB数据源 - -### 阶段3: 执行层(Layer 4) -**目标**: 实现Polars翻译和执行 - -**任务清单**: -- [ ] 实现PolarsTranslator -- [ ] 实现算子到Polars的映射 -- [ ] 实现分组约束注入 - -### 阶段4: 编排层(Layer 5) -**目标**: 实现FactorEngine统一入口 - -**任务清单**: -- [ ] 实现FactorEngine -- [ ] 整合各层组件 -- [ ] 编写端到端测试 - ---- - -## 六、 关键设计决策 - -### 6.1 为什么使用Visitor模式? -- **扩展性**: 新增节点类型只需添加visit方法 -- **分离关注点**: 遍历逻辑与处理逻辑分离 -- **类型安全**: 每个节点类型有明确的处理函数 - -### 6.2 为什么算子需要分类(ts_/cs_/math_)? -- **显式分组**: 用户明确知道计算维度 -- **约束注入**: 系统根据前缀自动注入正确的分组 -- **错误预防**: 避免截面/时序算子混用导致的逻辑错误 - -### 6.3 向后兼容性 -**决策**: 完全重构,不保留旧API - -**理由**: -- 新旧架构差异过大(绑定vs解耦) -- 保持旧API会增加维护负担 -- 量化策略代码通常是一次性编写,迁移成本可控 - ---- - -## 七、 附录 - -### A. 完整算子列表 - -**时序算子 (ts_*)**: ts_mean, ts_std, ts_var, ts_sum, ts_max, ts_min, ts_product, ts_median, ts_argmax, ts_argmin, ts_skew, ts_kurt, ts_delta, ts_pct_change, ts_corr, ts_cov, ts_rank - -**截面算子 (cs_*)**: cs_rank, cs_percentile, cs_zscore, cs_mean, cs_std, cs_median, cs_max, cs_min - -**数学算子 (math_*)**: math_log, math_log1p, math_exp, math_sqrt, math_abs, math_sign, math_power - -### B. 元数据配置示例 - -```python -METADATA = [ - {"name": "close", "table": "daily", "dtype": "float64", "freq": "daily"}, - {"name": "volume", "table": "daily", "dtype": "float64", "freq": "daily"}, - {"name": "pe", "table": "daily", "dtype": "float64", "freq": "daily"}, - {"name": "eps", "table": "financial_income", "dtype": "float64", - "freq": "pit", "announce_date_field": "ann_date"}, -] -``` - -### C. 与现有代码对比 - -| 维度 | 现有实现 | 新设计 | -|------|---------|--------| -| 因子定义 | 类继承 | 表达式 | -| 数据绑定 | data_specs硬编码 | 元数据注册表 | -| 组合方式 | CompositeFactor包装 | AST节点自然组合 | -| 执行时机 | 立即执行 | 延迟执行 | -| 防泄露 | 手动控制 | 自动注入分组约束 | -| 可优化性 | 低 | 高 | - ---- - -**文档版本**: 2.0 | **更新日期**: 2026-02-26 diff --git a/docs/factor_framework_design.md b/docs/factor_framework_design.md deleted file mode 100644 index b0763a5..0000000 --- a/docs/factor_framework_design.md +++ /dev/null @@ -1,1303 +0,0 @@ -# ProStock 因子计算框架设计文档 - -## 1. 设计目标 - -- **安全性**:在框架层面彻底防止数据泄露(使用未来数据) -- **易用性**:因子开发者只需关注计算逻辑,无需担心数据安全 -- **可扩展性**:支持日期截面、股票截面、交叉因子三种计算模式 -- **组合性**:支持因子组合和嵌套 -- **性能**:合理利用 Polars 的高效计算能力 - -## 1.1 核心原则 - -### 原则 1:因子类型单一性 -每个因子**只能是一种类型**(日期截面 或 股票截面),不允许同时支持两种模式。这确保: -- 因子语义清晰明确 -- 数据访问模式可预测 -- 便于框架进行正确的数据裁剪 - -### 原则 2:Point-in-Time 严格性 -对于任意计算点 `T`(特定日期或特定股票-日期组合): -- 因子**只能访问 `T` 及之前的数据** -- **绝对禁止访问 `T` 之后的任何数据** -- 每个计算点都是独立、隔离的计算上下文 - -这类似于数据库的 "as-of" 查询语义:"在当时那个时刻,我能看到什么数据?" - ---- - -## 2. 架构概述 - -``` -┌─────────────────────────────────────────────────────────────┐ -│ Factor Engine (执行引擎) │ -│ ┌──────────────┐ ┌──────────────┐ ┌────────────────────┐ │ -│ │ DAG Builder │ │ Lookback │ │ Parallel Executor │ │ -│ │ (依赖图构建) │ │ Validator │ │ (并行计算) │ │ -│ └──────────────┘ └──────────────┘ └────────────────────┘ │ -└──────────────────────────┬──────────────────────────────────┘ - │ -┌──────────────────────────▼──────────────────────────────────┐ -│ DataLoader (数据加载层) │ -│ ┌────────────────┐ ┌────────────────┐ ┌────────────────┐ │ -│ │ Multi-File │ │ Column │ │ Lookback │ │ -│ │ Aggregation │ │ Selector │ │ Window Control │ │ -│ └────────────────┘ └────────────────┘ └────────────────┘ │ -└──────────────────────────┬──────────────────────────────────┘ - │ - ┌──────▼──────┐ - │ HDF5 Files │ - └─────────────┘ -``` - ---- - -## 3. 核心组件设计 - -### 3.1 数据类型定义 - -```python -from dataclasses import dataclass -from typing import Dict, List, Optional -import polars as pl - -@dataclass(frozen=True) -class DataSpec: - """数据需求规格说明""" - source: str # H5 文件名(不含扩展名) - columns: List[str] # 需要的列名 - lookback_days: int = 0 # 回看窗口(用于时序计算) - -@dataclass -class FactorContext: - """ - 因子计算上下文(由框架自动注入) - - 根据因子类型的不同,包含不同的上下文信息: - - CrossSectionalFactor:current_date 表示当前计算的日期 - - TimeSeriesFactor:current_stock 表示当前计算的股票 - """ - current_date: Optional[str] = None # 当前计算日期 YYYYMMDD(截面因子) - current_stock: Optional[str] = None # 当前计算股票代码(时序因子) - trade_dates: List[str] = None # 所有交易日期列表(用于对齐) - -class FactorData: - """ - 提供给因子的数据容器 - - 根据因子类型的不同,包含不同的数据: - - CrossSectionalFactor:当前日期及历史 lookback 的截面数据(所有股票) - - TimeSeriesFactor:单只股票的完整时间序列数据 - """ - def __init__(self, df: pl.DataFrame, context: FactorContext): - self._df = df - self._context = context - - def get_column(self, col: str) -> pl.Series: - """ - 获取指定列的数据 - - 适用于两种因子类型: - - 截面因子:获取当天所有股票的该列值 - - 时序因子:获取该股票时间序列的该列值 - """ - return self._df[col] - - def filter_by_date(self, date: str) -> "FactorData": - """ - 按日期过滤数据(主要用于截面因子) - - 截面因子可以使用此方法获取特定日期的数据 - 但注意:无法获取未来日期的数据(引擎已经裁剪掉) - """ - filtered = self._df.filter(pl.col("trade_date") == date) - return FactorData(filtered, self._context) - - def to_polars(self) -> pl.DataFrame: - """获取底层的 Polars DataFrame(高级用法)""" - return self._df - - @property - def context(self) -> FactorContext: - """获取计算上下文""" - return self._context -``` - -### 3.2 因子基类(按类型严格分离) - -```python -from abc import ABC, abstractmethod -from typing import TypeVar, Generic, Literal -import polars as pl - -FactorType = Literal["cross_sectional", "time_series"] - -class BaseFactor(ABC): - """ - 因子基类 - 定义通用接口 - - 设计原则: - 1. 类型单一性:每个因子只能是 cross_sectional 或 time_series 之一 - 2. 声明式依赖:通过类属性声明所需数据和回看窗口 - 3. 防泄露保障:根据因子类型,在框架层面防止不同的泄露 - 4. 参数化支持:通过 __init__ 参数实现因子变体 - """ - - # ========== 必须声明的类属性 ========== - name: str = "" # 因子名称(唯一标识) - factor_type: FactorType # 因子类型(强制指定) - data_specs: List[DataSpec] = [] # 数据需求规格 - - # ========== 可选声明的类属性 ========== - category: str = "default" # 因子分类(用于组织管理) - description: str = "" # 因子描述 - - def __init_subclass__(cls, **kwargs): - """子类创建时验证必须属性""" - super().__init_subclass__(**kwargs) - if not cls.name: - raise ValueError(f"Factor {cls.__name__} must define 'name'") - if not cls.factor_type: - raise ValueError(f"Factor {cls.__name__} must define 'factor_type'") - if not cls.data_specs: - raise ValueError(f"Factor {cls.__name__} must define 'data_specs'") - - def __init__(self, **params): - """初始化因子参数""" - self.params = params - self._validate_params() - - def _validate_params(self): - """验证参数有效性(子类可覆盖)""" - pass - - @abstractmethod - def compute(self, data: FactorData) -> pl.Series: - """核心计算逻辑(由子类实现)""" - pass - - # ========== 因子组合运算符 ========== - def __add__(self, other: 'BaseFactor') -> 'CompositeFactor': - """因子相加:f1 + f2(要求同类型)""" - return CompositeFactor(self, other, '+') - - def __sub__(self, other: 'BaseFactor') -> 'CompositeFactor': - """因子相减:f1 - f2(要求同类型)""" - return CompositeFactor(self, other, '-') - - def __mul__(self, other: 'BaseFactor') -> 'CompositeFactor': - """因子相乘:f1 * f2(要求同类型)""" - return CompositeFactor(self, other, '*') - - def __truediv__(self, other: 'BaseFactor') -> 'CompositeFactor': - """因子相除:f1 / f2(要求同类型)""" - return CompositeFactor(self, other, '/') - - def __rmul__(self, scalar: float) -> 'ScalarFactor': - """标量乘法:0.5 * f1""" - return ScalarFactor(self, scalar, '*') - - -class CrossSectionalFactor(BaseFactor): - """ - 日期截面因子基类 - - 计算逻辑:在每个交易日,对所有股票进行横向计算 - - 防泄露边界: - - ❌ 禁止访问未来日期的数据(日期泄露) - - ✅ 允许访问当前日期的所有股票数据(股票间比较是正常的) - - 数据传入: - - compute() 接收的是单日的截面数据(所有股票在该日期的数据) - - 包含 lookback_days 的历史截面数据(用于时序计算后再截面比较) - - 性能优化: - - 按日期遍历,每天计算一次 - - 不需要重复计算,每天独立计算 - """ - - factor_type: FactorType = "cross_sectional" - - @abstractmethod - def compute(self, data: FactorData) -> pl.Series: - """ - 计算截面因子值 - - Args: - data: FactorData,包含当前日期及之前 lookback_days 的截面数据 - 格式:DataFrame[ts_code, trade_date, col1, col2, ...] - - Returns: - pl.Series: 当前日期所有股票的因子值(长度 = 该日股票数量) - - 示例: - def compute(self, data): - # 获取当前日期的截面(已经过滤到当前日期,无未来数据) - cs = data.get_cross_section() - # 计算市值排名(在同一天的股票间比较) - return cs['market_cap'].rank() - """ - pass - - -class TimeSeriesFactor(BaseFactor): - """ - 股票截面因子基类(时间序列因子) - - 计算逻辑:对每只股票,在其时间序列上进行纵向计算 - - 防泄露边界: - - ❌ 禁止访问其他股票的数据(股票泄露) - - ✅ 允许访问该股票的完整历史数据(时序计算需要历史数据) - - 数据传入: - - compute() 接收的是单只股票的完整时间序列数据 - - 包含该股票在 [start_date, end_date] 范围内的所有数据 - - 性能优化: - - 按股票遍历,每只股票一次性计算整个时间序列 - - 使用 Polars 的向量化计算(如 rolling_mean),高效批量计算 - - 无重复计算问题 - """ - - factor_type: FactorType = "time_series" - - @abstractmethod - def compute(self, data: FactorData) -> pl.Series: - """ - 计算时间序列因子值 - - Args: - data: FactorData,包含单只股票的完整时间序列 - 格式:DataFrame[ts_code, trade_date, col1, col2, ...] - 该股票的所有历史数据都已加载 - - Returns: - pl.Series: 该股票在各日期的因子值(长度 = 日期数量) - - 示例: - def compute(self, data): - # 获取该股票的价格序列(该股票的完整历史) - series = data.get_series() - # 一次性计算整个序列的移动平均(高效) - return series.rolling_mean(window_size=self.params['period']) - """ - pass - - -class CompositeFactor(BaseFactor): - """组合因子 - 用于实现因子间的数学运算(要求同类型)""" - - def __init__(self, left: BaseFactor, right: BaseFactor, op: str): - # 验证类型一致性 - if left.factor_type != right.factor_type: - raise ValueError( - f"Cannot combine factors of different types: " - f"{left.factor_type} vs {right.factor_type}" - ) - - self.left = left - self.right = right - self.op = op - self.factor_type = left.factor_type - self.name = f"({left.name}_{op}_{right.name})" - - # 合并数据需求 - self.data_specs = self._merge_data_specs() - - def _merge_data_specs(self) -> List[DataSpec]: - """合并左右因子的数据需求(取最大 lookback)""" - # ... 合并逻辑 - pass - - def compute(self, data: FactorData) -> pl.Series: - """执行组合运算""" - left_values = self.left.compute(data) - right_values = self.right.compute(data) - - ops = { - '+': lambda a, b: a + b, - '-': lambda a, b: a - b, - '*': lambda a, b: a * b, - '/': lambda a, b: a / b, - } - return ops[self.op](left_values, right_values) - - -class ScalarFactor(BaseFactor): - """标量运算因子""" - - def __init__(self, factor: BaseFactor, scalar: float, op: str): - self.factor = factor - self.scalar = scalar - self.op = op - self.factor_type = factor.factor_type - self.name = f"({scalar}_{op}_{factor.name})" - self.data_specs = factor.data_specs - - def compute(self, data: FactorData) -> pl.Series: - values = self.factor.compute(data) - if self.op == '*': - return values * self.scalar - elif self.op == '+': - return values + self.scalar - # ... 其他运算 -``` - ---- - -## 4. 两种计算模式(严格分离的防泄露边界) - -### 4.1 核心设计原则 - -**防泄露边界与因子类型的对应关系:** - -| 因子类型 | 防止泄露 | 允许访问 | 计算方式 | -|---------|---------|---------|---------| -| **CrossSectionalFactor** | **日期泄露**(不能用未来日期) | 当天所有股票的数据 | 按日期遍历,每天计算 | -| **TimeSeriesFactor** | **股票泄露**(不能用其他股票) | 该股票的完整历史 | 按股票遍历,每只股票一次性计算 | - -**为什么这样设计?** - -1. **时序因子(如 MA5)**: - - 本质上需要历史时序数据来计算(前5天收盘价) - - 如果防止时序泄露(每次只传1天数据),会导致 O(N×L) 的重复计算 - - 更好的做法:传入整只股票序列,一次性计算整个时间序列的滚动平均 - - 需要防止的是**股票泄露**(不能用其他股票的数据来预测这只股票) - -2. **截面因子(如 PE 排名)**: - - 本质上是当天所有股票之间的相对比较 - - 如果传入多天的数据,容易误用未来日期的信息(如用明天的 PE 算今天的排名) - - 更好的做法:每天只传入当天的数据 - - 需要防止的是**日期泄露**(不能用未来日期的数据) - -### 4.2 日期截面因子(Cross-Sectional Factor) - -**计算方式**:在每个交易日,对所有股票进行横向计算 - -**典型因子**: -- 当日收益率排名 -- PE 行业分位数 -- 市值对数 -- 换手率排序 - -**防泄露边界 - 防止日期泄露:** -``` -对于日期 D 的计算(D 遍历 start_date 到 end_date): - -┌────────────────────────────────────────────────────────────┐ -│ 传入数据: │ -│ - trade_date = D 的所有股票数据 │ -│ - 以及 [D-lookback, D] 的历史数据(用于时序计算后截面) │ -│ │ -│ 禁止传入: │ -│ - trade_date > D 的任何数据 ❌ │ -│ - 即未来日期的数据绝对不可见 │ -│ │ -│ 允许访问: │ -│ - D 当天的所有股票数据 ✅ │ -│ - 股票间比较是正常的(如排名、分位数) │ -└────────────────────────────────────────────────────────────┘ -``` - -**引擎计算流程:** -```python -# 引擎层伪代码 -for current_date in date_range: - # 1. 加载当前日期及历史 lookback 的数据(不含未来) - day_data = load_data( - start_date=current_date - lookback, - end_date=current_date # 注意:不包含未来日期 - ) - - # 2. 传入因子计算 - factor_values = factor.compute(day_data) - # factor 只能看到 current_date 及之前的数据 - - # 3. 保存结果 - results[current_date] = factor_values -``` - -**示例**: -```python -class ReturnRankFactor(CrossSectionalFactor): - """当日收益率排名因子""" - name = "return_rank" - factor_type = "cross_sectional" - data_specs = [ - DataSpec( - source="daily", - columns=["ts_code", "trade_date", "close"], - lookback_days=1 # 需要前一天的收盘价计算收益率 - ) - ] - category = "momentum" - - def compute(self, data: FactorData) -> pl.Series: - # data 包含当天及前1天的数据(无未来数据) - # 获取当前日期的截面 - today_data = data.get_cross_section() # 只返回当前日期的数据 - - # 计算当天收益率(需要当天和前一天的收盘价) - # 注意:因为 lookback_days=1,data 包含两天的数据 - returns = today_data["close"].pct_change() - - # 返回排名(在当天所有股票间排名) - return returns.rank() -``` - -### 4.3 时间序列因子(Time-Series Factor) - -**计算方式**:对每只股票,传入完整时间序列,一次性计算所有日期的因子值 - -**典型因子**: -- 20 日移动平均 -- 历史波动率 -- RSI 技术指标 -- MACD - -**防泄露边界 - 防止股票泄露:** -``` -对于股票 S 的计算(S 遍历所有股票): - -┌────────────────────────────────────────────────────────────┐ -│ 传入数据: │ -│ - ts_code = S 的完整时间序列 │ -│ - trade_date 在 [start_date, end_date] 范围内的所有数据 │ -│ │ -│ 禁止传入: │ -│ - 其他股票的数据 ❌ │ -│ - 即绝对不能混入其他股票的信息 │ -│ │ -│ 允许访问: │ -│ - 该股票的完整历史数据 ✅ │ -│ - 包括 start_date 之前 lookback_days 的数据 │ -└────────────────────────────────────────────────────────────┘ -``` - -**引擎计算流程:** -```python -# 引擎层伪代码 -for stock_code in stock_codes: - # 1. 加载该股票的完整时间序列(所有日期) - stock_data = load_stock_data( - ts_code=stock_code, - start_date=start_date - lookback, # 需要额外的历史数据计算初期值 - end_date=end_date - ) - - # 2. 传入因子计算(一次性计算整个序列) - factor_values = factor.compute(stock_data) - # factor 看到的是该股票的完整历史 - # 使用 Polars 的 rolling_mean 等向量化操作,高效计算 - - # 3. 保存结果 - results[stock_code] = factor_values -``` - -**性能优势(以 MA5 为例):** -```python -# ❌ 低效方式(Point-in-Time 逐个计算) -for date in dates: - data = load_data(date-5, date) # 加载6天数据 - ma = data["close"].mean() # 计算平均值 - # 时间复杂度: O(N × L),N=日期数, L=窗口长度 - -# ✅ 高效方式(向量化批量计算) -series = load_all_data() # 加载全部数据 -ma = series.rolling_mean(window_size=5) # 一次性计算整个序列 -# 时间复杂度: O(N),Polars 底层 Rust 优化 -``` - -**示例**: -```python -class MovingAverageFactor(TimeSeriesFactor): - """移动平均线因子""" - name = "ma" - factor_type = "time_series" - data_specs = [ - DataSpec( - source="daily", - columns=["ts_code", "trade_date", "close"], - lookback_days=20 # 需要20天历史数据计算初期的 MA - ) - ] - category = "technical" - - def __init__(self, period: int = 20): - super().__init__(period=period) - # 动态调整 lookback - self.data_specs[0].lookback_days = period - - def compute(self, data: FactorData) -> pl.Series: - # data 是该股票的完整时间序列(高效传入) - series = data.get_series(column="close") - - # 使用 Polars 的向量化 rolling_mean,一次性计算整个序列 - return series.rolling_mean(window_size=self.params["period"]) -``` - -### 4.4 交叉因子(Cross Factor)【预留设计】 - -**适用场景**:同时涉及时间序列和截面计算 - -**典型因子**: -- 某股票过去 20 天在行业内的涨幅排名变化 -- 个股波动率与市场波动率的比值 - -**设计挑战**: -这类因子需要同时访问: -1. **个股的时间序列数据**(时序因子特性) -2. **市场的截面数据**(截面因子特性) - -**解决方案:将交叉因子拆分为基础因子的组合** -```python -# 推荐方案:组合基础因子而非创建新类型 - -# 1. 先计算个股动量(时序因子) -stock_momentum = TimeSeriesMomentumFactor(period=20) - -# 2. 再计算市场平均动量(截面因子,每天计算市场均值) -market_momentum = CrossSectionalMeanFactor( - base_factor=stock_momentum, # 基于时序因子的结果 - aggregation="mean" -) - -# 3. 计算相对动量(时序因子间的比较) -# 注意:这里需要特殊处理,因为两个因子类型不同 -relative_momentum = RelativeStrengthFactor( - stock_factor=stock_momentum, - market_factor=market_momentum -) -``` - -**替代方案:专门的 CrossFactor 类型** -```python -class CrossFactor(BaseFactor): - """ - 交叉因子基类(预留) - - 同时需要时序和截面数据,但保持防泄露边界: - - 时序部分:只能看到该股票的历史(防止股票泄露) - - 截面部分:只能看到当前日期的数据(防止日期泄露) - """ - - factor_type = "cross" - - @abstractmethod - def compute_cross( - self, - stock_series: pl.Series, # 当前股票的时间序列(无其他股票) - market_section: pl.DataFrame # 当前日期的全市场截面(无未来日期) - ) -> float: - """交叉计算逻辑""" - pass -``` - ---- - -## 5. 防数据泄露机制(按因子类型区分) - -### 5.1 核心原则:不同类型的不同防泄露边界 - -**关键洞察**:不同类型的因子需要防止的泄露不同,应该采用不同的数据传入策略: - -| 因子类型 | 需要防止 | 数据传入策略 | 计算效率 | -|---------|---------|-------------|---------| -| **CrossSectionalFactor** | **日期泄露**(不能用未来日期的数据) | 每天传入当天的数据(含 lookback 历史) | 按天遍历,每天计算一次 | -| **TimeSeriesFactor** | **股票泄露**(不能用其他股票的数据) | 每只股票传入完整序列 | 按股票遍历,向量化计算,高效 | - -**为什么不需要同时防止两种泄露?** - -1. **时序因子的特点**: - - 计算 MA5 需要前5天的收盘价,这是**正常的计算需求**,不是泄露 - - 如果防止时序上的"未来数据",每次只能传1天数据,会导致重复计算 - - 真正需要防止的是用**其他股票**的数据来预测这只股票 - -2. **截面因子的特点**: - - 计算 PE 排名需要当天所有股票的 PE,这是**正常的计算需求**,不是泄露 - - 如果传入多天的数据,容易误用未来日期的 PE(如用明天的 PE 算今天的排名) - - 真正需要防止的是用**未来日期**的数据 - -### 5.2 实现策略 - -``` -┌────────────────────────────────────────────────────────────┐ -│ 因子定义阶段 │ -│ 1. 因子声明 required_columns 和 lookback_days │ -│ 2. 框架静态分析依赖关系 │ -└──────────────────────────┬─────────────────────────────────┘ - │ - ▼ -┌────────────────────────────────────────────────────────────┐ -│ 数据加载阶段 │ -│ 1. 根据声明加载所需数据(多文件聚合 + 列选择) │ -│ 2. 根据 lookback_days 计算每个日期所需的历史数据窗口 │ -└──────────────────────────┬─────────────────────────────────┘ - │ - ▼ -┌────────────────────────────────────────────────────────────┐ -│ 计算阶段(防泄露核心) │ -│ │ -│ 对于每个 (date, stock) 组合: │ -│ │ -│ 日期截面模式: │ -│ - 截取 date - lookback_days 到 date 的数据 │ -│ - 所有股票,仅该时间窗口 │ -│ │ -│ 股票截面模式: │ -│ - 截取该股票 date - lookback_days 到 date 的数据 │ -│ - 仅该股票,仅该时间窗口 │ -│ │ -│ 关键:因子代码只能看到已截断的数据,无法访问未来数据 │ -└────────────────────────────────────────────────────────────┘ -``` - -### 5.2 实现细节 - -```python -class DataLoader: - """数据加载器 - 负责安全的数据访问""" - - def __init__(self, data_dir: str): - self.data_dir = Path(data_dir) - self._cache: Dict[str, pl.DataFrame] = {} - - def load( - self, - specs: List[DataSpec], - date_range: Optional[Tuple[str, str]] = None - ) -> pl.DataFrame: - """ - 加载并聚合多个 H5 文件的数据 - - Args: - specs: 数据需求规格列表 - date_range: 日期范围限制 (start_date, end_date) - - Returns: - 合并后的 Polars DataFrame - """ - pass - - def get_safe_data( - self, - data: pl.DataFrame, - current_date: str, - lookback_days: int, - mode: str, # 'cross_sectional' | 'time_series' - stock_code: Optional[str] = None - ) -> FactorData: - """ - 获取安全的数据视图(防泄露核心) - - 对于日期 D 和回看窗口 L: - - 只返回 [D-L, D] 范围内的数据 - - 根据 mode 决定是返回截面还是单只股票数据 - """ - # 计算截断日期 - cutoff_start = self._get_trading_date_offset(current_date, -lookback_days) - cutoff_end = current_date - - # 截断数据 - safe_df = data.filter( - (pl.col("trade_date") >= cutoff_start) & - (pl.col("trade_date") <= cutoff_end) - ) - - if mode == "time_series": - # 只保留指定股票的数据 - safe_df = safe_df.filter(pl.col("ts_code") == stock_code) - - # 创建上下文 - context = FactorContext( - current_date=current_date, - current_stock=stock_code, - trade_dates=self._get_all_trade_dates() - ) - - return FactorData(safe_df, context) - - -class FactorEngine: - """ - 因子执行引擎 - 根据因子类型采用不同的计算和防泄露策略 - - 核心职责: - 1. CrossSectionalFactor:防止日期泄露,每天传入当天的数据 - 2. TimeSeriesFactor:防止股票泄露,每只股票传入完整序列 - 3. 管理计算流程和结果组装 - """ - - def __init__(self, data_loader: DataLoader): - self.data_loader = data_loader - - def compute(self, factor: BaseFactor, **kwargs) -> pl.DataFrame: - """ - 统一的计算入口,根据因子类型分发到具体方法 - """ - if factor.factor_type == "cross_sectional": - return self._compute_cross_sectional(factor, **kwargs) - elif factor.factor_type == "time_series": - return self._compute_time_series(factor, **kwargs) - else: - raise ValueError(f"Unknown factor type: {factor.factor_type}") - - def _compute_cross_sectional( - self, - factor: CrossSectionalFactor, - start_date: str, - end_date: str - ) -> pl.DataFrame: - """ - 执行日期截面计算 - - 防泄露策略: - - 防止日期泄露:每天只传入当天的数据(含 lookback 历史,但不含未来) - - 允许股票间比较:传入当天所有股票的数据 - - 计算方式: - - 按日期遍历 - - 每天计算一次,返回当天所有股票的因子值 - - 返回 DataFrame 格式: - ┌────────────┬──────────┬──────────────┐ - │ trade_date │ ts_code │ factor_name │ - ├────────────┼──────────┼──────────────┤ - │ 20240101 │ 000001.SZ│ 0.5 │ - │ 20240101 │ 000002.SZ│ 0.3 │ - └────────────┴──────────┴──────────────┘ - """ - # 计算实际需要加载的日期范围(考虑 lookback) - max_lookback = max(spec.lookback_days for spec in factor.data_specs) - data_start = self._get_trading_date_offset(start_date, -max_lookback) - - # 一次性加载所有数据(后续按天裁剪) - raw_data = self.data_loader.load(factor.data_specs, (data_start, end_date)) - - results = [] - - # 按日期遍历:每天计算一次 - for current_date in self._get_date_range(start_date, end_date): - - # 裁剪数据:只保留 current_date 及之前的数据(防止日期泄露) - # 但保留所有股票的数据(允许股票间比较) - day_data = raw_data.filter( - pl.col("trade_date") <= current_date - ) - - # 如果 lookback > 0,进一步裁剪到 lookback 窗口 - if max_lookback > 0: - lookback_start = self._get_trading_date_offset(current_date, -max_lookback) - day_data = day_data.filter(pl.col("trade_date") >= lookback_start) - - # 创建 FactorData(包含当天及历史数据,无未来数据) - context = FactorContext( - current_date=current_date, - trade_dates=self._get_all_trade_dates() - ) - factor_data = FactorData(day_data, context) - - # 计算因子值 - factor_values = factor.compute(factor_data) - - # 获取当前日期的股票列表 - today_stocks = day_data.filter( - pl.col("trade_date") == current_date - )["ts_code"] - - results.append(pl.DataFrame({ - "trade_date": [current_date] * len(today_stocks), - "ts_code": today_stocks, - factor.name: factor_values - })) - - return pl.concat(results) - - def _compute_time_series( - self, - factor: TimeSeriesFactor, - stock_codes: List[str], - start_date: str, - end_date: str - ) -> pl.DataFrame: - """ - 执行时间序列计算(股票截面) - - 防泄露策略: - - 防止股票泄露:每只股票单独计算,传入该股票的完整序列 - - 允许访问历史数据:时序计算需要历史数据,这是正常的 - - 计算方式: - - 按股票遍历 - - 每只股票一次性计算整个时间序列(向量化,高效) - - 性能优势: - - 使用 Polars 的 rolling_mean 等向量化操作 - - 无重复计算问题 - """ - max_lookback = max(spec.lookback_days for spec in factor.data_specs) - data_start = self._get_trading_date_offset(start_date, -max_lookback) - - # 加载所有数据 - all_data = self.data_loader.load(factor.data_specs, (data_start, end_date)) - - results = [] - - # 按股票遍历:每只股票一次性计算 - for stock_code in stock_codes: - # 过滤出该股票的数据(防止股票泄露) - stock_data = all_data.filter(pl.col("ts_code") == stock_code) - - if stock_data.is_empty(): - continue - - # 创建 FactorData(该股票的完整序列) - context = FactorContext( - current_stock=stock_code, - trade_dates=self._get_all_trade_dates() - ) - factor_data = FactorData(stock_data, context) - - # 一次性计算整个时间序列(向量化,高效) - factor_values = factor.compute(factor_data) - - # 获取该股票的日期列表 - stock_dates = stock_data["trade_date"] - - results.append(pl.DataFrame({ - "trade_date": stock_dates, - "ts_code": [stock_code] * len(stock_dates), - factor.name: factor_values - })) - - return pl.concat(results) -``` - -### 5.3 防泄露验证示例 - -**日期截面因子 - 防止日期泄露:** - -```python -# ❌ 错误的截面因子(试图访问未来日期) -class BadCrossSectionalFactor(CrossSectionalFactor): - name = "bad_cs" - factor_type = "cross_sectional" - data_specs = [DataSpec(source="daily", columns=["close"], lookback_days=5)] - - def compute(self, data: FactorData) -> pl.Series: - # 引擎传入的是当前日期的数据(假设今天是 20240110) - # data 包含 20240106-20240110(根据 lookback_days=5) - # 但绝对不包含 20240115 的数据 - - # 试图访问未来日期 - 会报错或返回空 - future_data = data.filter(pl.col("trade_date") == "20240115") - # 因为引擎已经裁剪掉未来数据,future_data 为空 - - return future_data["close"] # 错误! - -# ✅ 正确的截面因子 -class GoodCrossSectionalFactor(CrossSectionalFactor): - name = "pe_rank" - factor_type = "cross_sectional" - data_specs = [DataSpec(source="daily", columns=["pe"], lookback_days=0)] - - def compute(self, data: FactorData) -> pl.Series: - # data 只包含当前日期的数据(无未来日期) - # 可以安全地访问当天所有股票的 PE - pe_values = data.get_column("pe") - - # 计算排名(在当天股票间比较是正常的) - return pe_values.rank() -``` - -**时间序列因子 - 防止股票泄露:** - -```python -# ❌ 错误的时序因子(试图访问其他股票数据) -class BadTimeSeriesFactor(TimeSeriesFactor): - name = "bad_ts" - factor_type = "time_series" - data_specs = [DataSpec(source="daily", columns=["close"], lookback_days=20)] - - def compute(self, data: FactorData) -> pl.Series: - # 引擎传入的是单只股票的数据 - # 但因子试图加载其他股票的数据(泄露!) - - # 试图访问全局数据 - 这是框架要阻止的 - all_stocks_data = load_all_stocks() # 不应该允许! - market_mean = all_stocks_data["close"].mean() - - return data.get_series() / market_mean # 用了其他股票的数据! - -# ✅ 正确的时序因子 -class GoodTimeSeriesFactor(TimeSeriesFactor): - name = "ma20" - factor_type = "time_series" - data_specs = [DataSpec(source="daily", columns=["close"], lookback_days=20)] - - def compute(self, data: FactorData) -> pl.Series: - # data 是该股票的完整时间序列(无其他股票) - # 可以安全地计算时序指标(需要历史数据是正常的) - prices = data.get_column("close") - - # 计算 20 日移动平均(使用历史数据是正常的) - return prices.rolling_mean(window_size=20) -``` - ---- - -## 6. 目录结构 - -``` -src/ -├── factors/ -│ ├── __init__.py -│ ├── base.py # BaseFactor、CrossSectionalFactor、TimeSeriesFactor 基类 -│ ├── data_spec.py # DataSpec、FactorContext、FactorData -│ ├── data_loader.py # DataLoader 多文件聚合、列选择 -│ ├── engine.py # FactorEngine 执行引擎 -│ ├── composite.py # CompositeFactor、ScalarFactor 组合因子 -│ ├── cross_sectional.py # CrossSectionalEngine 日期截面计算 -│ ├── time_series.py # TimeSeriesEngine 股票截面计算 -│ └── builtin/ # 内置因子库 -│ ├── __init__.py -│ ├── momentum.py # 动量类因子(CrossSectional) -│ ├── technical.py # 技术指标类(TimeSeries) -│ ├── value.py # 价值类因子(CrossSectional) -│ └── volatility.py # 波动率类因子(TimeSeries) -└── data/ - └── storage.py # 已有模块(复用) -``` - ---- - -## 7. 使用示例 - -### 7.1 基础用法 - -```python -from src.factors import FactorEngine, DataLoader -from src.factors.builtin import MovingAverageFactor, ReturnRankFactor - -# 1. 初始化数据加载器和执行引擎 -data_loader = DataLoader(data_dir="data") -engine = FactorEngine(data_loader) - -# 2. 定义因子 -ma20 = MovingAverageFactor(period=20) -rank = ReturnRankFactor() - -# 3. 计算日期截面因子(自动识别因子类型) -result1 = engine.compute( - factor=rank, # CrossSectionalFactor 类型 - start_date="20240101", - end_date="20240131" -) - -# 4. 计算股票截面因子(自动识别因子类型) -result2 = engine.compute( - factor=ma20, # TimeSeriesFactor 类型 - stock_codes=["000001.SZ", "600000.SH"], - start_date="20240101", - end_date="20240131" -) -``` - -### 7.2 因子组合 - -```python -from src.factors.builtin import PEFactor, PBFactor - -# 价值复合因子:0.5 * PE + 0.3 * PB -pe = PEFactor() # CrossSectionalFactor -pb = PBFactor() # CrossSectionalFactor - -# 使用运算符重载组合(类型必须一致) -value_factor = 0.5 * pe + 0.3 * pb - -# 计算复合因子 -result = engine.compute( - factor=value_factor, # 仍然是 CrossSectionalFactor - start_date="20240101", - end_date="20240131" -) - -# ❌ 错误示例:不能组合不同类型的因子 -ma = MovingAverageFactor() # TimeSeriesFactor -# bad = pe + ma # ValueError: Cannot combine factors of different types -``` - -### 7.3 自定义因子 - -```python -from src.factors import TimeSeriesFactor, DataSpec, FactorData -import polars as pl - -class ReturnStdFactor(TimeSeriesFactor): - """自定义因子示例:20日收益率标准差""" - - name = "return_std_20" - factor_type = "time_series" # 明确指定类型 - data_specs = [ - DataSpec( - source="daily", - columns=["ts_code", "trade_date", "close"], - lookback_days=21 # 20天收益率需要21天收盘价 - ) - ] - category = "volatility" - description = "20日收益率标准差" - - def compute(self, data: FactorData) -> pl.Series: - # 获取当前股票的时间序列收盘价(传入的是完整序列) - prices = data.get_column("close") - - # 计算收益率 - returns = prices.pct_change() - - # 计算20日滚动标准差(向量化计算,高效) - std = returns.rolling_std(window_size=20) - - return std - -# 使用自定义因子 -my_factor = ReturnStdFactor() -result = engine.compute( - factor=my_factor, - stock_codes=["000001.SZ"], - start_date="20240101", - end_date="20240131" -) -``` - ---- - -## 8. 关键设计决策总结 - -| 决策点 | 选择 | 理由 | -|--------|------|------| -| **防泄露边界** | 按因子类型区分 | 时序防股票泄露,截面防日期泄露,各取所需 | -| **时序因子数据策略** | 传入完整序列,一次性计算 | 避免 O(N×L) 重复计算,利用 Polars 向量化性能 | -| **截面因子数据策略** | 每天传入当天数据 | 防止日期泄露,确保不会误用未来信息 | -| **因子类型** | 严格分离(CrossSectional vs TimeSeries) | 语义清晰,不同防泄露策略,便于框架优化 | -| **因子接口** | 类继承 + 抽象方法 | 强制规范,支持参数化,易于 IDE 提示 | -| **编程范式** | OOP(对比函数式) | 更好的类型检查、组合性和可维护性 | -| **数据返回** | Polars DataFrame/Series | 高性能,现代化 API | -| **组合机制** | 运算符重载(同类型可组合) | 直观易用,类型安全 | -| **缓存策略** | 暂不支持 | 先保证正确性,后续按需添加 | - -### 8.1 设计修正的关键洞察 - -**原设计的问题**: -- 试图对所有因子使用统一的 Point-in-Time 策略 -- 导致时序因子(如 MA5)需要按天裁剪数据,产生大量重复计算 -- 实际上时序因子需要历史数据是正常的,不应该被限制 - -**修正后的设计**: -- **时序因子**:防止股票泄露(不能用其他股票数据),允许访问完整历史 -- **截面因子**:防止日期泄露(不能用未来日期),每天只传入当天数据 -- 这样既保证了数据安全,又获得了计算性能 - -**类比理解**: -- 时序因子像"技术分析":只看这只股票自己的历史走势 -- 截面因子像"相对估值":只看当天所有股票的相对比较 -- 两者关注的泄露风险不同,应该采用不同的防护策略 | - ---- - -## 9. 后续扩展计划 - -### 阶段 2(近期) -- [ ] 因子结果缓存机制 -- [ ] 并行计算优化 -- [ ] 更多内置因子 - -### 阶段 3(中期) -- [ ] 交叉因子完整实现 -- [ ] 自定义数据源支持 -- [ ] 因子可视化工具 - -### 阶段 4(远期) -- [ ] 实时数据接入 -- [ ] 分布式计算支持 -- [ ] 机器学习因子自动生成 - ---- - -## 10. 命名约定 - -- **因子类名**:`{描述}Factor`,如 `MovingAverageFactor` -- **因子名称**(name 属性):`snake_case`,如 `"ma_20"` -- **数据源名**:与 H5 文件名一致,如 `"daily"`, `"fundamental"` -- **列名**:与数据源中的列名完全一致 -- **日期格式**:`YYYYMMDD` 字符串 - ---- - -*文档版本: v1.1* -*最后更新: 2026-02-21* - -**v1.1 更新说明**: -- 修正防泄露边界:时序因子防止股票泄露,截面因子防止日期泄露 -- 优化计算策略:时序因子传入完整序列一次性计算,避免重复计算 -- 明确不同类型因子的数据传入策略和防泄露重点 - ---- - -## 附录 A:函数式编程方案讨论 - -作为对比,以下是使用**装饰器 + 函数式编程**的替代设计方案。 - -### A.1 函数式方案设计 - -```python -from typing import Callable, List -import polars as pl -from dataclasses import dataclass - -@dataclass -class FactorDef: - """因子定义(由装饰器创建)""" - name: str - factor_type: str # 'cross_sectional' | 'time_series' - data_specs: List[DataSpec] - compute_func: Callable[[FactorData], pl.Series] - params: dict = None - - -def factor( - name: str, - factor_type: str, - data_specs: List[DataSpec], - category: str = "default", - description: str = "" -): - """ - 因子装饰器 - 将函数注册为因子 - - 使用示例: - @factor( - name="ma_20", - factor_type="time_series", - data_specs=[DataSpec(source="daily", columns=["close"], lookback_days=20)] - ) - def moving_average(data: FactorData, period: int = 20) -> pl.Series: - return data.get_series().rolling_mean(window_size=period) - """ - def decorator(func: Callable) -> FactorDef: - return FactorDef( - name=name, - factor_type=factor_type, - data_specs=data_specs, - compute_func=func, - params={} - ) - return decorator - - -# 参数化因子的函数式实现 -def parameterized_factor(base_factor: FactorDef): - """ - 创建参数化因子的工厂函数 - - 使用示例: - # 基础定义 - ma_base = factor(...)(moving_average) - - # 创建 MA5 和 MA20 - ma5 = parameterized_factor(ma_base)(period=5) - ma20 = parameterized_factor(ma_base)(period=20) - """ - def create_instance(**params) -> FactorDef: - # 创建新的 FactorDef,更新参数 - new_def = FactorDef( - name=f"{base_factor.name}_{'_'.join(f'{k}{v}' for k, v in params.items())}", - factor_type=base_factor.factor_type, - data_specs=base_factor.data_specs, # 可能需要根据 params 调整 lookback - compute_func=lambda data: base_factor.compute_func(data, **params), - params=params - ) - return new_def - return create_instance -``` - -### A.2 两种方案对比 - -| 维度 | 类继承方案(当前) | 函数式 + 装饰器方案 | -|------|-------------------|-------------------| -| **代码量** | 需要定义类,代码较多 | 函数 + 装饰器,代码较少 | -| **参数化** | `__init__` 直观自然 | 需要工厂函数或偏函数 | -| **IDE 支持** | 好,有类型提示和补全 | 一般,装饰器会丢失类型信息 | -| **组合性** | 运算符重载 (`+`, `-`, `*`) | 需要显式组合函数 | -| **可扩展性** | 继承机制成熟 | 函数组合灵活但复杂 | -| **学习成本** | 需要理解 OOP | 需要理解函数式编程 | -| **调试难度** | 类层次清晰,易调试 | 装饰器嵌套,调试较困难 | -| **元数据管理** | 类属性自然 | 需要额外的元数据结构 | - -### A.3 函数式方案示例 - -```python -# ========== 使用函数式方案 ========== - -# 1. 定义基础因子(使用装饰器) -@factor( - name="ma", - factor_type="time_series", - data_specs=[DataSpec(source="daily", columns=["close"], lookback_days=20)] -) -def ma_factor(data: FactorData, period: int) -> pl.Series: - return data.get_series().rolling_mean(window_size=period) - -# 2. 创建参数化实例 -ma5 = parameterized_factor(ma_factor)(period=5) -ma20 = parameterized_factor(ma_factor)(period=20) - -# 3. 组合因子(函数式方式) -def combine_factors(f1: FactorDef, f2: FactorDef, op: str) -> FactorDef: - """组合两个因子""" - def combined_compute(data: FactorData) -> pl.Series: - v1 = f1.compute_func(data) - v2 = f2.compute_func(data) - if op == '+': - return v1 + v2 - elif op == '*': - return v1 * v2 - # ... - - return FactorDef( - name=f"{f1.name}_{op}_{f2.name}", - factor_type=f1.factor_type, # 假设类型相同 - data_specs=merge_specs(f1.data_specs, f2.data_specs), - compute_func=combined_compute - ) - -# 使用 -value_factor = combine_factors( - combine_factors(pe_factor, pb_factor, '+'), - ps_factor, - '+' -) -``` - -### A.4 方案选择建议 - -**推荐使用类继承方案的情况**: -- 团队熟悉面向对象编程 -- 需要丰富的 IDE 支持和类型检查 -- 因子逻辑复杂,需要分层抽象 -- 需要频繁的组合和参数化操作 - -**考虑使用函数式方案的情况**: -- 团队偏好函数式编程风格 -- 因子逻辑简单,主要是数学运算 -- 希望减少样板代码 -- 需要高度灵活的组合方式 - -**当前项目选择类继承方案的理由**: -1. **Python 生态更熟悉 OOP**:大多数量化开发者更习惯类的方式 -2. **IDE 支持更好**:VSCode/PyCharm 对类属性和方法的提示更完善 -3. **参数化更自然**:`MA(period=20)` 比 `create_ma(period=20)` 更直观 -4. **运算符重载**:`f1 + f2` 比 `combine_factors(f1, f2, '+')` 更易读 -5. **可维护性**:类层次结构在长期维护中更清晰 diff --git a/docs/hdf5_to_duckdb_migration.md b/docs/hdf5_to_duckdb_migration.md deleted file mode 100644 index a03a96c..0000000 --- a/docs/hdf5_to_duckdb_migration.md +++ /dev/null @@ -1,1072 +0,0 @@ -# ProStock HDF5 到 DuckDB 迁移方案 - -**文档版本**: v1.1 -**创建日期**: 2026-02-22 -**完成日期**: 2026-02-22 -**状态**: ✅ 已完成 -**影响范围**: data 模块、factors 模块、相关文档 - -## 相关文档 - - [DuckDB 数据同步指南](./db_sync_guide.md) - 同步 API 使用说明 - [迁移测试报告](./test_report_duckdb_migration.md) - 测试验证结果 - - ---- - -## 目录 - -1. [执行摘要](#1-执行摘要) -2. [迁移方案](#2-迁移方案) -3. [迁移计划](#3-迁移计划) -4. [影响范围分析](#4-影响范围分析) -5. [风险与回滚策略](#5-风险与回滚策略) -6. [附录](#6-附录) - ---- - -## 1. 执行摘要 - -### 1.1 迁移目标 - -将 ProStock 项目的数据存储从 **HDF5 格式** 迁移到 **DuckDB 嵌入式数据库**,解决以下核心问题: - -| 问题 | 现状 (HDF5) | 目标 (DuckDB) | 预期收益 | -|------|------------|--------------|---------| -| **全表加载** | 每次查询加载 1GB+ 数据 | 查询下推,按需加载 | **单股票查询 100x 加速** | -| **内存占用** | 必须全表载入内存 | 磁盘级过滤 | **内存使用降低 80%** | -| **并发写入** | 文件锁,伪并发 | 事务支持 | **更可靠的增量更新** | -| **数据压缩** | HDF5 内置压缩 | DuckDB 列式压缩 | **存储空间减少 20-50%** | - -### 1.2 工作量估算 - -| 阶段 | 工作量 | 说明 | -|------|--------|------| -| **核心开发** | 6-8 小时 | Storage 重写、DataLoader 适配、Sync 调整 | -| **文档更新** | 2-3 小时 | 3 份设计文档修改 | -| **数据迁移** | 30 分钟 | H5 → DuckDB 脚本运行 | -| **测试验证** | 2-4 小时 | 单元测试、集成测试、性能基准 | -| **总计** | **10-15 小时** | 1-2 个工作日 | - -### 1.3 关键决策 - -- ✅ **完全迁移**:不保留 HDF5 代码,彻底迁移到 DuckDB -- ✅ **API 兼容**:保持 `Storage` 类接口不变,调用方零改动 -- ✅ **Polars 集成**:支持 `load_polars()` 方法,DataLoader 无缝衔接 -- ✅ **并发安全**:使用单线程写入队列,避免 DuckDB 锁冲突 - ---- - -## 2. 迁移方案 - -### 2.1 技术架构对比 - -#### 当前架构 (HDF5) - -``` -┌─────────────────────────────────────────────────────────────┐ -│ Factor Engine (执行引擎) │ -└──────────────────────────┬──────────────────────────────────┘ - │ -┌──────────────────────────▼──────────────────────────────────┐ -│ DataLoader (数据加载层) │ -│ ┌────────────────┐ ┌────────────────┐ ┌────────────────┐ │ -│ │ Multi-File │ │ Column │ │ Lookback │ │ -│ │ Aggregation │ │ Selector │ │ Window Control │ │ -│ └────────────────┘ └────────────────┘ └────────────────┘ │ -└──────────────────────────┬──────────────────────────────────┘ - │ - ┌──────▼──────┐ - │ HDF5 Files │ ←── 每个表一个 .h5 文件 - └─────────────┘ 全表加载到内存后过滤 -``` - -#### 目标架构 (DuckDB) - -``` -┌─────────────────────────────────────────────────────────────┐ -│ Factor Engine (执行引擎) │ -└──────────────────────────┬──────────────────────────────────┘ - │ -┌──────────────────────────▼──────────────────────────────────┐ -│ DataLoader (数据加载层) │ -│ ┌────────────────┐ ┌────────────────┐ ┌────────────────┐ │ -│ │ SQL Query │ │ Predicate │ │ Polars Export │ │ -│ │ Generation │ │ Pushdown │ │ (Zero-Copy) │ │ -│ └────────────────┘ └────────────────┘ └────────────────┘ │ -└──────────────────────────┬──────────────────────────────────┘ - │ - ┌──────▼──────┐ - │ DuckDB │ ←── 单个 .duckdb 文件 - │ (Embedded) │ SQL 查询下推,只读必要数据 - └─────────────┘ -``` - -### 2.2 数据库 Schema 设计 - -#### 2.2.1 表结构设计 - -```sql --- 日线数据表(替代 daily.h5) -CREATE TABLE daily ( - ts_code VARCHAR(16) NOT NULL, -- 股票代码 - trade_date DATE NOT NULL, -- 交易日期 - open DOUBLE, - high DOUBLE, - low DOUBLE, - close DOUBLE, - pre_close DOUBLE, - change DOUBLE, - pct_chg DOUBLE, - vol DOUBLE, - amount DOUBLE, - turnover_rate DOUBLE, -- 换手率 - volume_ratio DOUBLE, -- 量比 - -- 其他字段... - - PRIMARY KEY (ts_code, trade_date) -- 复合主键,自动去重 -); - --- 创建复合索引(覆盖常用查询场景:按日期范围+股票代码过滤) -CREATE INDEX idx_daily_date_code ON daily(trade_date, ts_code); - --- 股票基础信息表(替代 stock_basic.h5) -CREATE TABLE stock_basic ( - ts_code VARCHAR(16) PRIMARY KEY, - symbol VARCHAR(10), - name VARCHAR(50), - area VARCHAR(20), - industry VARCHAR(50), - market VARCHAR(10), - list_date DATE, - -- 其他字段... -); - --- 交易日历表(替代 trade_cal.h5) -CREATE TABLE trade_cal ( - exchange VARCHAR(10), - cal_date DATE, - is_open BOOLEAN, - PRIMARY KEY (exchange, cal_date) -); -``` - -#### 2.2.2 数据类型映射 - -| HDF5/Pandas | DuckDB | 说明 | -|------------|--------|------| -| `object` (string) | `VARCHAR` | 股票代码、名称 | -| `int64` | `BIGINT` | 成交量(整数) | -| `float64` | `DOUBLE` | 价格、收益率 | -| `object` (date) | `DATE` | 交易日期,支持范围查询 | -| `bool` | `BOOLEAN` | 是否交易日 | - -### 2.3 核心代码改造方案 - -#### 2.3.1 Storage 类重写 (`src/data/storage.py`) - -**当前 HDF5 实现**(151 行)→ **DuckDB 实现**(约 200 行) - -```python -"""DuckDB storage for data persistence.""" - -import duckdb -import pandas as pd -import polars as pl -from pathlib import Path -from typing import Optional, List -from contextlib import contextmanager -from src.data.config import get_config - - -class Storage: - """DuckDB storage manager for saving and loading data. - - 迁移说明: - - 保持 API 完全兼容,调用方无需修改 - - 新增 load_polars() 方法支持 Polars 零拷贝导出 - - 使用单例模式管理数据库连接 - - 并发写入通过队列管理(见 ThreadSafeStorage) - """ - - _instance = None - _connection = None - - def __new__(cls, *args, **kwargs): - """Singleton to ensure single connection.""" - if cls._instance is None: - cls._instance = super().__new__(cls) - return cls._instance - - def __init__(self, path: Optional[Path] = None): - """Initialize storage.""" - if hasattr(self, '_initialized'): - return - - cfg = get_config() - self.base_path = path or cfg.data_path_resolved - self.base_path.mkdir(parents=True, exist_ok=True) - self.db_path = self.base_path / "prostock.db" - - self._init_db() - self._initialized = True - - def _init_db(self): - """Initialize database connection and schema.""" - self._connection = duckdb.connect(str(self.db_path)) - - # Create tables with schema validation - self._connection.execute(""" - CREATE TABLE IF NOT EXISTS daily ( - ts_code VARCHAR(16) NOT NULL, - trade_date DATE NOT NULL, - open DOUBLE, - high DOUBLE, - low DOUBLE, - close DOUBLE, - pre_close DOUBLE, - change DOUBLE, - pct_chg DOUBLE, - vol DOUBLE, - amount DOUBLE, - turnover_rate DOUBLE, - volume_ratio DOUBLE, - PRIMARY KEY (ts_code, trade_date) - ) - """) - - # Create composite index for query optimization (trade_date, ts_code) - self._connection.execute(""" - CREATE INDEX IF NOT EXISTS idx_daily_date_code ON daily(trade_date, ts_code) - """) - - def save(self, name: str, data: pd.DataFrame, mode: str = "append") -> dict: - """Save data to DuckDB. - - Args: - name: Table name - data: DataFrame to save - mode: 'append' (UPSERT) or 'replace' (DELETE + INSERT) - - Returns: - Dict with save result - """ - if data.empty: - return {"status": "skipped", "rows": 0} - - # Ensure date column is proper type - if 'trade_date' in data.columns: - data = data.copy() - data['trade_date'] = pd.to_datetime(data['trade_date'], format='%Y%m%d').dt.date - - # Register DataFrame as temporary view - self._connection.register("temp_data", data) - - try: - if mode == "replace": - self._connection.execute(f"DELETE FROM {name}") - - # UPSERT: INSERT OR REPLACE - columns = ", ".join(data.columns) - self._connection.execute(f""" - INSERT OR REPLACE INTO {name} ({columns}) - SELECT {columns} FROM temp_data - """) - - row_count = len(data) - print(f"[Storage] Saved {row_count} rows to DuckDB ({name})") - return {"status": "success", "rows": row_count} - - except Exception as e: - print(f"[Storage] Error saving {name}: {e}") - return {"status": "error", "error": str(e)} - finally: - self._connection.unregister("temp_data") - - def load( - self, - name: str, - start_date: Optional[str] = None, - end_date: Optional[str] = None, - ts_code: Optional[str] = None, - ) -> pd.DataFrame: - """Load data from DuckDB with query pushdown. - - 关键优化: - - WHERE 条件在数据库层过滤,无需加载全表 - - 只返回匹配条件的行,大幅减少内存占用 - - Args: - name: Table name - start_date: Start date filter (YYYYMMDD) - end_date: End date filter (YYYYMMDD) - ts_code: Stock code filter - - Returns: - Filtered DataFrame - """ - # Build WHERE clause with parameterized queries - conditions = [] - params = [] - - if start_date and end_date: - conditions.append("trade_date BETWEEN ? AND ?") - # Convert to DATE type - start = pd.to_datetime(start_date, format='%Y%m%d').date() - end = pd.to_datetime(end_date, format='%Y%m%d').date() - params.extend([start, end]) - elif start_date: - conditions.append("trade_date >= ?") - params.append(pd.to_datetime(start_date, format='%Y%m%d').date()) - elif end_date: - conditions.append("trade_date <= ?") - params.append(pd.to_datetime(end_date, format='%Y%m%d').date()) - - if ts_code: - conditions.append("ts_code = ?") - params.append(ts_code) - - where_clause = f"WHERE {' AND '.join(conditions)}" if conditions else "" - query = f"SELECT * FROM {name} {where_clause} ORDER BY trade_date" - - try: - # Execute query with parameters (SQL injection safe) - result = self._connection.execute(query, params).fetchdf() - - # Convert trade_date back to string format for compatibility - if 'trade_date' in result.columns: - result['trade_date'] = result['trade_date'].dt.strftime('%Y%m%d') - - return result - except Exception as e: - print(f"[Storage] Error loading {name}: {e}") - return pd.DataFrame() - - def load_polars( - self, - name: str, - start_date: Optional[str] = None, - end_date: Optional[str] = None, - ts_code: Optional[str] = None, - ) -> pl.DataFrame: - """Load data as Polars DataFrame (for DataLoader). - - 性能优势: - - 零拷贝导出(DuckDB → Polars) - - 无需经过 Pandas 转换 - """ - # Build query - conditions = [] - if start_date and end_date: - start = pd.to_datetime(start_date, format='%Y%m%d').date() - end = pd.to_datetime(end_date, format='%Y%m%d').date() - conditions.append(f"trade_date BETWEEN '{start}' AND '{end}'") - if ts_code: - conditions.append(f"ts_code = '{ts_code}'") - - where_clause = f"WHERE {' AND '.join(conditions)}" if conditions else "" - query = f"SELECT * FROM {name} {where_clause} ORDER BY trade_date" - - # Return Polars DataFrame directly - return self._connection.sql(query).pl() - - def exists(self, name: str) -> bool: - """Check if table exists.""" - result = self._connection.execute(""" - SELECT COUNT(*) FROM information_schema.tables - WHERE table_name = ? - """, [name]).fetchone() - return result[0] > 0 - - def delete(self, name: str) -> bool: - """Delete a table.""" - try: - self._connection.execute(f"DROP TABLE IF EXISTS {name}") - print(f"[Storage] Deleted table {name}") - return True - except Exception as e: - print(f"[Storage] Error deleting {name}: {e}") - return False - - def get_last_date(self, name: str) -> Optional[str]: - """Get the latest date in storage.""" - try: - result = self._connection.execute(f""" - SELECT MAX(trade_date) FROM {name} - """).fetchone() - if result[0]: - # Convert date back to string format - return result[0].strftime('%Y%m%d') if hasattr(result[0], 'strftime') else str(result[0]) - return None - except: - return None - - def close(self): - """Close database connection.""" - if self._connection: - self._connection.close() - Storage._connection = None - Storage._instance = None - - -class ThreadSafeStorage: - """线程安全的 DuckDB 写入包装器。 - - DuckDB 写入时不支持并发,使用队列收集写入请求, - 在 sync 结束时统一批量写入。 - """ - - def __init__(self): - self.storage = Storage() - self._pending_writes: List[tuple] = [] # [(name, data), ...] - - def queue_save(self, name: str, data: pd.DataFrame): - """将数据放入写入队列(不立即写入)""" - if not data.empty: - self._pending_writes.append((name, data)) - - def flush(self): - """批量写入所有队列数据。 - - 调用时机:在 sync 结束时统一调用,避免并发写入冲突。 - """ - if not self._pending_writes: - return - - # 合并相同表的数据 - from collections import defaultdict - table_data = defaultdict(list) - - for name, data in self._pending_writes: - table_data[name].append(data) - - # 批量写入每个表 - for name, data_list in table_data.items(): - combined = pd.concat(data_list, ignore_index=True) - # 在批量数据中先去重 - if 'ts_code' in combined.columns and 'trade_date' in combined.columns: - combined = combined.drop_duplicates( - subset=["ts_code", "trade_date"], - keep="last" - ) - self.storage.save(name, combined, mode="append") - - self._pending_writes.clear() - - def __getattr__(self, name): - """代理其他方法到 Storage 实例""" - return getattr(self.storage, name) -``` - -#### 2.3.2 DataLoader 适配 (`src/factors/data_loader.py`) - -**改动点**:修改 `_read_h5` 方法,使用 DuckDB 查询 - -```python -def _read_h5(self, source: str) -> pl.DataFrame: - """读取数据 - 从 DuckDB 加载为 Polars DataFrame。 - - 迁移说明: - - 方法名保持 _read_h5 以兼容现有代码(实际从 DuckDB 读取) - - 使用 Storage.load_polars() 直接返回 Polars DataFrame - - 支持零拷贝导出,性能优于 HDF5 + Pandas + Polars 转换 - """ - from src.data.storage import Storage - - storage = Storage() - - # 如果 DataLoader 有 date_range,传递给 Storage 进行过滤 - # 实现查询下推,只加载必要数据 - return storage.load_polars(source) -``` - -#### 2.3.3 Sync 模块调整 (`src/data/sync.py`) - -**改动点**:使用 ThreadSafeStorage 替代 Storage - -```python -# 修改前 -from src.data.storage import Storage - -class DataSync: - def __init__(self, max_workers: Optional[int] = None): - self.storage = Storage() # 直接写入 - ... - - def sync_daily(self, ...): - # 多线程中直接调用 save - self.storage.save("daily", data, mode="append") - -# 修改后 -from src.data.storage import ThreadSafeStorage - -class DataSync: - def __init__(self, max_workers: Optional[int] = None): - self.storage = ThreadSafeStorage() # 队列写入 - ... - - def sync_daily(self, ...): - # 多线程中排队,不立即写入 - self.storage.queue_save("daily", data) - - def sync_all(self, ...): - try: - # ... 多线程获取数据 ... - pass - finally: - # 统一批量写入 - self.storage.flush() -``` - -### 2.4 数据同步方案 - -**无需迁移脚本,直接使用 sync 模块同步数据** - -由于 DuckDB 存储层完全兼容现有 API,无需创建专门的数据迁移脚本。采用以下策略: - -1. **新环境/首次部署**:直接运行 `sync_all()` 从 Tushare 获取全部数据 -2. **现有 HDF5 数据迁移**:保留 HDF5 文件作为备份,DuckDB 从最新日期开始增量同步 - -**同步命令**: - -```bash -# 全量同步(首次部署或需要完整数据时) -uv run python -c "from src.data.sync import sync_all; sync_all(force_full=True)" - -# 增量同步(日常使用) -uv run python -c "from src.data.sync import sync_all; sync_all()" - -# 指定线程数 -uv run python -c "from src.data.sync import sync_all; sync_all(max_workers=20)" -``` - -**优势**: -- ✅ 无需维护独立的迁移脚本 -- ✅ 数据直接从源头同步,确保最新 -- ✅ 利用现有 sync 逻辑,代码复用 -- ✅ 支持增量更新,节省时间 - ---- - -## 3. 迁移计划 - -### 3.1 实施阶段 - -#### Phase 1: 准备与开发 (Day 1) - -**任务清单**: - -| 序号 | 任务 | 文件 | 预估时间 | 负责人 | -|------|------|------|---------|--------| -| 1.1 | 安装 DuckDB 依赖 | `pyproject.toml` | 10 分钟 | Dev | -| 1.2 | 重写 Storage 类 | `src/data/storage.py` | 2 小时 | Dev | -| 1.3 | 创建 ThreadSafeStorage | `src/data/storage.py` | 30 分钟 | Dev | -| 1.4 | 适配 DataLoader | `src/factors/data_loader.py` | 30 分钟 | Dev | -| 1.5 | 修改 Sync 并发逻辑 | `src/data/sync.py` | 1 小时 | Dev | - -**产出物**: -- ✅ 可运行的 DuckDB Storage 实现 -- ✅ 单元测试通过 - -#### Phase 2: 测试与验证 (Day 1-2) - -**任务清单**: - -| 序号 | 任务 | 说明 | 预估时间 | -|------|------|------|---------| -| 2.1 | 运行现有单元测试 | `uv run pytest tests/test_sync.py` | 15 分钟 | -| 2.2 | 运行 DataLoader 测试 | `uv run pytest tests/factors/test_data_spec.py` | 15 分钟 | -| 2.3 | 数据同步测试 | `uv run python -c "from src.data.sync import sync_all; sync_all()"` | 10 分钟 | -| 2.4 | 性能基准测试 | 对比 HDF5 vs DuckDB 查询性能 | 1 小时 | -| 2.5 | 并发写入测试 | 验证 ThreadSafeStorage 正确性 | 30 分钟 | - -**验证标准**: -- [ ] 所有现有测试通过 -- [ ] 单股票查询 < 1 秒(HDF5 需 5-10 秒) -- [ ] 日期范围查询 < 0.5 秒 -- [ ] 数据完整性验证通过(记录数一致) - -#### Phase 3: 文档更新 (Day 2) - -**需修改的文档**: - -| 文档 | 修改内容 | 预估时间 | -|------|---------|---------| -| `docs/factor_framework_design.md` | 架构图 HDF5 → DuckDB,DataSpec 说明 | 30 分钟 | -| `docs/factor_implementation_plan.md` | DataLoader 描述,Phase 3 实现细节 | 30 分钟 | -| `docs/data_sync.md` | 存储格式说明,同步逻辑描述 | 30 分钟 | -| `README.md` | 数据存储说明 | 15 分钟 | - -**文档修改详情**见 [第 4 节:影响范围分析](#4-影响范围分析) - -#### Phase 4: 部署与清理 (Day 2) - -**任务清单**: - -| 序号 | 任务 | 说明 | -|------|------|------| -| 4.1 | 备份 HDF5 文件 | `cp data/*.h5 data/backup/` | -| 4.2 | 运行全量同步 | `uv run python -c "from src.data.sync import sync_all; sync_all(force_full=True)"` | -| 4.3 | 验证数据完整性 | 抽样检查(从 DuckDB 查询并对比关键数据点) | -| 4.4 | 删除 HDF5 文件 | `rm data/*.h5`(验证通过后) | -| 4.5 | 提交代码 | `git add . && git commit -m "migrate: HDF5 to DuckDB"` | - -### 3.2 回滚计划 - -如果迁移后发现问题,执行以下回滚步骤: - -```bash -# 1. 恢复 HDF5 文件 -cp data/backup/*.h5 data/ - -# 2. 恢复 Storage 代码(从 git 历史) -git checkout HEAD~1 -- src/data/storage.py - -# 3. 重新安装依赖(如果需要) -# pip uninstall duckdb - -# 4. 验证 -uv run pytest tests/test_sync.py -``` - ---- - -## 4. 影响范围分析 - -### 4.1 代码文件改动清单 - -#### 核心文件(必须修改) - -| 文件路径 | 改动类型 | 改动说明 | 影响程度 | -|---------|---------|---------|---------| -| `src/data/storage.py` | 重写 | HDF5 → DuckDB 实现 | 🔴 高 | -| `src/data/sync.py` | 修改 | 使用 ThreadSafeStorage | 🟡 中 | -| `src/factors/data_loader.py` | 修改 | `_read_h5()` 适配 | 🟡 中 | -| `pyproject.toml` | 修改 | 添加 `duckdb` 依赖 | 🟢 低 | - -#### 新增文件 - -| 文件路径 | 说明 | -|---------|------| -| `docs/hdf5_to_duckdb_migration.md` | 本文档 | - -#### 测试文件(需要验证) - -| 文件路径 | 验证内容 | -|---------|---------| -| `tests/test_sync.py` | 同步流程正常 | -| `tests/test_daily_storage.py` | Storage 接口兼容 | -| `tests/factors/test_data_spec.py` | DataLoader 工作正常 | - -### 4.2 设计文档修改详情 - -#### 4.2.1 `docs/factor_framework_design.md` - -**修改位置**: 第 2 节 架构概述 - -**当前内容**: -```markdown - ┌──────▼──────┐ - │ HDF5 Files │ - └─────────────┘ -``` - -**修改为**: -```markdown - ┌──────▼──────┐ - │ DuckDB │ - │ (Embedded) │ - └─────────────┘ -``` - -**修改位置**: 第 3.1 节 DataSpec - -**当前内容**: -```python -source: str # H5 文件名(不含扩展名) -``` - -**修改为**: -```python -source: str # 表名(对应 DuckDB 中的表,如 "daily", "stock_basic") -``` - -#### 4.2.2 `docs/factor_implementation_plan.md` - -**修改位置**: Phase 3 数据加载 - -**当前内容**: -```markdown -### 3.1 DataLoader - 数据加载器 - -"""数据加载器 - 负责从 HDF5 安全加载数据""" - -实现:使用 pandas.read_hdf(),然后 pl.from_pandas() -``` - -**修改为**: -```markdown -### 3.1 DataLoader - 数据加载器 - -"""数据加载器 - 负责从 DuckDB 安全加载数据""" - -实现:使用 Storage.load_polars() 直接返回 Polars DataFrame - 支持 SQL 查询下推,只加载必要数据 -``` - -**修改位置**: Phase 3 测试需求 - -**添加**: -```markdown -**DuckDB 集成测试需求:** -- [ ] 测试 DuckDB 查询下推正确性 -- [ ] 测试 Polars 零拷贝导出 -- [ ] 测试并发写入队列机制 -``` - -#### 4.2.3 新增/修改的数据文档 - -**`docs/data_sync.md`**(新增或修改) - -需要添加/修改的内容: -- 存储格式说明:HDF5 → DuckDB -- 数据库文件位置:`data/prostock.db` -- 查询优化:使用 SQL 条件代替内存过滤 - -### 4.3 API 兼容性说明 - -#### 保持不变的接口 ✅ - -以下接口完全保持兼容,调用方无需修改: - -```python -# Storage 类核心方法 -storage.save(name, data, mode="append") -storage.load(name, start_date, end_date, ts_code) -storage.exists(name) -storage.delete(name) -storage.get_last_date(name) - -# DataLoader 类 -loader.load(specs, date_range) -loader._read_h5(source) # 内部方法,行为不变 -``` - -#### 新增的接口 🆕 - -```python -# Storage 新增方法 -storage.load_polars(name, start_date, end_date, ts_code) # 直接返回 Polars - -# ThreadSafeStorage(Sync 内部使用) -thread_safe_storage.queue_save(name, data) -thread_safe_storage.flush() -``` - -#### 废弃的接口 ❌ - -```python -# 不再支持 HDF5 特定的方法 -# 无(所有 HDF5 特定逻辑都在 Storage 内部) -``` - -### 4.4 依赖变更 - -#### `pyproject.toml` 修改 - -```toml -[project] -dependencies = [ - # ... 现有依赖 ... - "duckdb>=0.10.0", # 新增 -] - -[project.optional-dependencies] -dev = [ - # ... 现有 dev 依赖 ... - "pytest-duckdb", # 可选:DuckDB 测试工具 -] -``` - -#### 安装命令 - -```bash -# 安装 DuckDB -uv pip install duckdb - -# 或使用 requirements 安装所有依赖 -uv pip install -e ".[dev]" -``` - ---- - -## 5. 风险与回滚策略 - -### 5.1 风险识别 - -| 风险 | 概率 | 影响 | 缓解措施 | -|------|------|------|---------| -| **并发写入冲突** | 中 | 高 | 使用 ThreadSafeStorage 队列管理 | -| **数据类型不匹配** | 低 | 中 | 严格的 Schema 定义和转换逻辑 | -| **性能不如预期** | 低 | 高 | 性能基准测试,预留回滚方案 | -| **依赖兼容性问题** | 低 | 中 | 使用虚拟环境隔离测试 | -| **数据丢失** | 低 | 极高 | 迁移前完整备份 HDF5 文件 | - -### 5.2 回滚触发条件 - -以下情况触发回滚: - -1. **数据完整性验证失败** - - 记录数不一致 - - 抽样数据不匹配 - -2. **性能下降超过 20%** - - 全表扫描比 HDF5 慢 - - 内存占用不降反升 - -3. **核心测试失败** - - `test_sync.py` 失败 - - `test_data_loader.py` 失败 - -4. **生产环境异常** - - 数据同步失败 - - 查询超时 - -### 5.3 回滚步骤 - -```bash -#!/bin/bash -# rollback.sh - 回滚脚本 - -echo "[Rollback] Starting rollback to HDF5..." - -# 1. 停止所有运行中的服务 -pkill -f "python.*sync" - -# 2. 恢复 HDF5 文件 -echo "[Rollback] Restoring HDF5 files..." -cp data/backup/*.h5 data/ 2>/dev/null || echo "No backup found, keeping existing" - -# 3. 从 git 恢复代码 -echo "[Rollback] Restoring code from git..." -git checkout HEAD~1 -- src/data/storage.py -git checkout HEAD~1 -- src/data/sync.py -git checkout HEAD~1 -- src/factors/data_loader.py -git checkout HEAD~1 -- pyproject.toml - -# 4. 重新安装依赖(如果需要) -echo "[Rollback] Reinstalling dependencies..." -uv pip install -e . - -# 5. 验证 -echo "[Rollback] Running tests..." -uv run pytest tests/test_sync.py -v - -echo "[Rollback] Rollback completed!" -``` - -### 5.4 数据备份策略 - -**迁移前备份**: - -```bash -# 创建备份目录 -mkdir -p data/backup_$(date +%Y%m%d_%H%M%S) - -# 备份所有 HDF5 文件 -cp data/*.h5 data/backup_$(date +%Y%m%d_%H%M%S)/ - -# 备份完成后的 DuckDB 文件(迁移后) -cp data/prostock.db data/backup_$(date +%Y%m%d_%H%M%S)/ 2>/dev/null || true -``` - -**定期备份**(迁移后): - -```bash -# DuckDB 文件备份(每天) -0 2 * * * cp /path/to/prostock.db /path/to/backup/prostock_$(date +\%Y\%m\%d).db -``` - ---- - -## 6. 附录 - -### 附录 A:性能基准测试方案 - -**测试脚本**: `scripts/benchmark_storage.py` - -```python -"""存储性能基准测试:HDF5 vs DuckDB""" - -import time -import pandas as pd -from src.data.storage import Storage - -def benchmark_load(storage, name, iterations=5): - """测试加载性能""" - times = [] - - for _ in range(iterations): - start = time.time() - # 单股票查询 - df = storage.load(name, ts_code="000001.SZ") - elapsed = time.time() - start - times.append(elapsed) - - return { - "mean": sum(times) / len(times), - "min": min(times), - "max": max(times), - } - -def main(): - storage = Storage() - - print("=== Storage Performance Benchmark ===\n") - - # 单股票查询 - print("Single stock query (000001.SZ):") - result = benchmark_load(storage, "daily") - print(f" Mean: {result['mean']:.3f}s") - print(f" Min: {result['min']:.3f}s") - print(f" Max: {result['max']:.3f}s") - - # 日期范围查询 - print("\nDate range query (20240101-20240131):") - start = time.time() - df = storage.load("daily", start_date="20240101", end_date="20240131") - elapsed = time.time() - start - print(f" Time: {elapsed:.3f}s") - print(f" Rows: {len(df)}") - -if __name__ == "__main__": - main() -``` - -**预期结果**: - -| 测试项 | HDF5 | DuckDB | 提升 | -|--------|------|--------|------| -| 单股票查询 | 5-10s | 0.1-0.5s | **10-100x** | -| 日期范围查询 | 5-10s | 0.2-1s | **5-50x** | -| 全表扫描 | 5-10s | 3-5s | 1.5-2x | -| 内存占用 | 1GB+ | 100-500MB | **50-90%** | - -### 附录 B:DuckDB 运维指南 - -#### 数据库文件位置 - -``` -data/ -├── prostock.db # DuckDB 主数据库文件 -├── prostock.db.wal # WAL 日志文件(写入时存在) -└── backup/ # 备份目录 -``` - -#### 常用维护命令 - -```python -import duckdb - -# 查看数据库信息 -conn = duckdb.connect("data/prostock.db") - -# 查看所有表 -tables = conn.execute(""" - SELECT table_name, - estimated_size - FROM information_schema.tables - WHERE table_schema = 'main' -""").fetchall() - -# 查看表结构 -schema = conn.execute("DESCRIBE daily").fetchall() - -# 分析表统计(优化查询计划) -conn.execute("ANALYZE daily") - -# 压缩数据库(VACUUM) -conn.execute("VACUUM") - -conn.close() -``` - -#### 性能优化建议 - -1. **创建适当的索引**: - ```sql - CREATE INDEX idx_daily_date_code ON daily(trade_date, ts_code); - ``` - -2. **使用分区(大数据量时)**: - ```sql - -- 按年分区(如果数据量达到亿级) - CREATE TABLE daily_partitioned AS - SELECT *, YEAR(trade_date) as year - FROM daily; - ``` - -3. **批量插入优化**: - ```python - # 使用事务批量插入 - conn.execute("BEGIN TRANSACTION") - # ... 多个插入操作 ... - conn.execute("COMMIT") - ``` - -### 附录 C:常见问题 FAQ - -**Q: DuckDB 是否支持多线程并发写入?** - -A: DuckDB 支持并发读取,但写入时需要锁。我们使用 `ThreadSafeStorage` 队列机制,将并发写入转换为批量单线程写入,避免锁冲突。 - -**Q: 数据迁移后 HDF5 文件可以删除吗?** - -A: 验证通过后可以删除。建议保留备份至少 1 周。 - -**Q: DuckDB 文件损坏怎么办?** - -A: DuckDB 具有事务日志(WAL),正常情况下不会损坏。如果发生: -1. 从备份恢复 `.db` 文件 -2. 删除 `.db.wal` 文件(如果存在) -3. 重新连接 - -**Q: 如何查看 DuckDB 查询执行计划?** - -A: 使用 `EXPLAIN` 命令: -```python -conn.execute("EXPLAIN SELECT * FROM daily WHERE ts_code = '000001.SZ'").fetchall() -``` - -**Q: 是否支持从 DuckDB 直接导出 HDF5?** - -A: 支持,可以使用 Pandas 中转: -```python -df = conn.execute("SELECT * FROM daily").fetchdf() -df.to_hdf("backup.h5", key="daily") -``` - ---- - -## 文档历史 - -| 版本 | 日期 | 作者 | 变更说明 | -|------|------|------|---------| -| v1.0 | 2026-02-22 | Sisyphus | 初始版本,完整迁移方案 | - ---- - -## 审批记录 - -| 角色 | 姓名 | 日期 | 意见 | -|------|------|------|------| -| 技术负责人 | ______ | ______ | ______ | -| 项目负责人 | ______ | ______ | ______ | - ---- - -**下一步行动**: -1. [ ] 技术负责人审批方案 -2. [ ] 确定实施日期 -3. [ ] 分配开发资源 -4. [ ] 执行 Phase 1 开发 diff --git a/docs/ml_framework_design.md b/docs/ml_framework_design.md deleted file mode 100644 index def2c50..0000000 --- a/docs/ml_framework_design.md +++ /dev/null @@ -1,1472 +0,0 @@ -# ProStock 模型训练框架设计文档 - -## 1. 设计目标与原则 - -### 1.1 核心目标 -- **组件化**:每个阶段(数据获取、处理、训练、评估)都是独立组件 -- **低耦合**:组件间通过标准接口交互,不依赖具体实现 -- **插件式**:新功能通过插件注册,无需修改核心代码 -- **阶段感知**:数据处理区分训练阶段和测试阶段,防止数据泄露 -- **多模型支持**:统一接口支持 LightGBM、CatBoost 等多种模型 -- **多任务支持**:分类、回归、排序三种任务类型 - -### 1.2 设计原则 - -| 原则 | 说明 | -|------|------| -| **单一职责** | 每个组件只做一件事,做好一件事 | -| **开闭原则** | 对扩展开放(插件),对修改封闭(核心) | -| **依赖倒置** | 依赖抽象接口,而非具体实现 | -| **显式优于隐式** | 阶段标记、处理逻辑必须显式声明 | -| **配置驱动** | 通过配置文件或代码配置定义流程,减少硬编码 | - ---- - -## 2. 整体架构 - -### 2.1 架构概览 - -``` -┌─────────────────────────────────────────────────────────────────────────┐ -│ ML Pipeline Orchestrator │ -│ (流水线编排器 - 配置驱动执行) │ -└─────────────────────────────────────────────────────────────────────────┘ - │ - ┌───────────────────────────┼───────────────────────────┐ - ▼ ▼ ▼ -┌───────────────┐ ┌───────────────┐ ┌───────────────┐ -│ Data Source │ │ Data Source │ │ Data Source │ -│ (因子数据) │ │ (行情数据) │ │ (标签数据) │ -└───────┬───────┘ └───────┬───────┘ └───────┬───────┘ - │ │ │ - └──────────────────────────┼──────────────────────────┘ - ▼ -┌─────────────────────────────────────────────────────────────────────────┐ -│ Feature Store (特征存储层) │ -│ ┌──────────────┐ ┌──────────────┐ ┌──────────────┐ ┌──────────────┐ │ -│ │ FactorLoader │ │ LabelLoader │ │ DataMerger │ │ CacheMgr │ │ -│ └──────────────┘ └──────────────┘ └──────────────┘ └──────────────┘ │ -└─────────────────────────────────────────────────────────────────────────┘ - │ - ▼ -┌─────────────────────────────────────────────────────────────────────────┐ -│ Processing Pipeline (处理流水线) │ -│ │ -│ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ ┌──────────┐ │ -│ │ Processor │ -> │ Processor │ -> │ Processor │ -> │ ... │ │ -│ │ (阶段:ALL) │ │ (阶段:TRAIN)│ │ (阶段:TEST) │ │ │ │ -│ └─────────────┘ └─────────────┘ └─────────────┘ └──────────┘ │ -│ │ -│ 处理器类型: │ -│ - FeatureEncoder: 特征编码(类别编码、数值缩放等) │ -│ - FeatureSelector: 特征选择(相关性过滤、重要性筛选等) │ -│ - OutlierHandler: 异常值处理 │ -│ - MissingValueHandler: 缺失值处理 │ -│ - CustomTransformer: 自定义转换器 │ -└─────────────────────────────────────────────────────────────────────────┘ - │ - ▼ -┌─────────────────────────────────────────────────────────────────────────┐ -│ Train/Test Split (数据划分) │ -│ │ -│ 支持多种划分策略: │ -│ - TimeSeriesSplit: 时间序列划分(防止未来泄露) │ -│ - PurgedKFold: 清除重叠样本的K折交叉验证 │ -│ - EmbargoSplit: embargo 延迟验证 │ -│ - CustomSplit: 自定义划分策略 │ -└─────────────────────────────────────────────────────────────────────────┘ - │ - ▼ -┌─────────────────────────────────────────────────────────────────────────┐ -│ Model Training (模型训练层) │ -│ │ -│ ┌─────────────────────────────────────────────────────────────────┐ │ -│ │ Model Registry │ │ -│ │ ┌──────────┐ ┌──────────┐ ┌──────────┐ ┌──────────┐ │ │ -│ │ │ LightGBM │ │CatBoost │ │ XGBoost │ │ Custom │ ... │ │ -│ │ │ Model │ │ Model │ │ Model │ │ Model │ │ │ -│ │ └──────────┘ └──────────┘ └──────────┘ └──────────┘ │ │ -│ └─────────────────────────────────────────────────────────────────┘ │ -│ │ -│ 任务类型: │ -│ - Classification: 分类任务(上涨/下跌预测) │ -│ - Regression: 回归任务(收益率预测) │ -│ - Ranking: 排序任务(股票排序/选股) │ -└─────────────────────────────────────────────────────────────────────────┘ - │ - ▼ -┌─────────────────────────────────────────────────────────────────────────┐ -│ Evaluation (评估层) │ -│ │ -│ ┌──────────────┐ ┌──────────────┐ ┌──────────────┐ ┌────────────┐ │ -│ │ Metric │ │ Metric │ │ Metric │ │ Analyzer │ │ -│ │ (IC/IR) │ │ (Sharpe) │ │ (Accuracy) │ │ (回测) │ │ -│ └──────────────┘ └──────────────┘ └──────────────┘ └────────────┘ │ -│ │ -│ ┌──────────────┐ ┌──────────────┐ ┌──────────────┐ │ -│ │ ResultStore │ │ Report │ │ Visualizer │ │ -│ │ (模型存储) │ │ (报告生成) │ │ (可视化) │ │ -│ └──────────────┘ └──────────────┘ └──────────────┘ │ -└─────────────────────────────────────────────────────────────────────────┘ -``` - -### 2.2 数据流向图 - -``` -因子DataFrame (Polars) - │ - ▼ -┌──────────────────────┐ -│ Feature Store │ 1. 加载并合并因子、标签、辅助数据 -│ - 列选择 │ 2. 支持按日期/股票过滤 -│ - 数据对齐 │ 3. 缓存机制避免重复加载 -└──────────┬───────────┘ - │ - ▼ -┌──────────────────────┐ -│ Processing Pipeline │ 顺序执行多个处理器 -│ │ 每个处理器标记适用阶段 (ALL/TRAIN/TEST) -│ for processor in pipeline: -│ if processor.stage in [current_stage, ALL]: -│ data = processor.transform(data) -└──────────┬───────────┘ - │ - ▼ -┌──────────────────────┐ -│ Data Splitter │ 时间序列感知的划分策略 -│ - X_train, y_train │ 防止未来泄露 -│ - X_test, y_test │ -└──────────┬───────────┘ - │ - ▼ -┌──────────────────────┐ -│ Model Training │ 统一接口,支持多种模型 -│ - fit(X_train) │ 任务类型: classification/regression/ranking -│ - predict(X_test) │ -└──────────┬───────────┘ - │ - ▼ -┌──────────────────────┐ -│ Evaluation │ 多维度评估 -│ - 预测指标 │ - IC/IR -│ - 回测指标 │ - 分组收益 -│ - 可视化 │ - 累计收益曲线 -└──────────────────────┘ -``` - ---- - -## 3. 核心组件设计 - -### 3.1 基础抽象类 - -#### 3.1.1 PipelineStage (流水线阶段枚举) - -```python -from enum import Enum, auto - -class PipelineStage(Enum): - """流水线阶段标记""" - ALL = auto() # 适用于所有阶段 - TRAIN = auto() # 仅训练阶段 - TEST = auto() # 仅测试阶段 - VALIDATION = auto() # 仅验证阶段 -``` - -#### 3.1.2 BaseProcessor (处理器基类) - -```python -from abc import ABC, abstractmethod -from typing import Any, Dict, Optional -import polars as pl - -class BaseProcessor(ABC): - """数据处理器基类 - - 所有数据处理器必须继承此类。 - 关键特性:通过 stage 属性控制处理器在哪些阶段生效。 - - 示例: - >>> class StandardScaler(BaseProcessor): - ... stage = PipelineStage.ALL # 训练和测试都使用 - ... - ... def fit(self, data: pl.DataFrame) -> None: - ... self.mean = data[self.columns].mean() - ... self.std = data[self.columns].std() - ... - ... def transform(self, data: pl.DataFrame) -> pl.DataFrame: - ... return (data - self.mean) / self.std - """ - - # 子类必须定义适用阶段 - stage: PipelineStage = PipelineStage.ALL - - def __init__(self, columns: Optional[list] = None, **params): - """初始化处理器 - - Args: - columns: 要处理的列,None表示所有数值列 - **params: 处理器特定参数 - """ - self.columns = columns - self.params = params - self._is_fitted = False - self._fitted_params: Dict[str, Any] = {} - - @abstractmethod - def fit(self, data: pl.DataFrame) -> "BaseProcessor": - """在训练数据上学习参数 - - 此方法只在训练阶段调用一次。 - 学习到的参数存储在 self._fitted_params 中。 - - Args: - data: 训练数据 - - Returns: - self (支持链式调用) - """ - pass - - @abstractmethod - def transform(self, data: pl.DataFrame) -> pl.DataFrame: - """转换数据 - - 在训练和测试阶段都会被调用。 - 使用 fit() 阶段学习到的参数进行转换。 - - Args: - data: 输入数据 - - Returns: - 转换后的数据 - """ - pass - - def fit_transform(self, data: pl.DataFrame) -> pl.DataFrame: - """先fit再transform的便捷方法""" - return self.fit(data).transform(data) - - def get_fitted_params(self) -> Dict[str, Any]: - """获取学习到的参数(用于保存/加载)""" - return self._fitted_params.copy() - - def set_fitted_params(self, params: Dict[str, Any]) -> "BaseProcessor": - """设置学习到的参数(用于从checkpoint恢复)""" - self._fitted_params = params.copy() - self._is_fitted = True - return self -``` - -#### 3.1.3 BaseModel (模型基类) - -```python -from abc import ABC, abstractmethod -from typing import Literal, Any, Dict -import polars as pl -import numpy as np - -TaskType = Literal["classification", "regression", "ranking"] - -class BaseModel(ABC): - """机器学习模型基类 - - 统一接口支持多种模型(LightGBM, CatBoost, XGBoost等) - 和多种任务类型(分类、回归、排序)。 - - 示例: - >>> model = LightGBMModel( - ... task_type="classification", - ... params={"n_estimators": 100} - ... ) - >>> model.fit(X_train, y_train) - >>> predictions = model.predict(X_test) - """ - - def __init__( - self, - task_type: TaskType, - params: Optional[Dict[str, Any]] = None, - name: Optional[str] = None - ): - """初始化模型 - - Args: - task_type: 任务类型 - "classification", "regression", "ranking" - params: 模型特定参数 - name: 模型名称(用于日志和报告) - """ - self.task_type = task_type - self.params = params or {} - self.name = name or self.__class__.__name__ - self._model: Any = None - self._is_fitted = False - - @abstractmethod - def fit( - self, - X: pl.DataFrame, - y: pl.Series, - X_val: Optional[pl.DataFrame] = None, - y_val: Optional[pl.Series] = None, - **fit_params - ) -> "BaseModel": - """训练模型 - - Args: - X: 特征数据 - y: 目标变量 - X_val: 验证集特征(可选) - y_val: 验证集目标(可选) - **fit_params: 额外的fit参数 - - Returns: - self (支持链式调用) - """ - pass - - @abstractmethod - def predict(self, X: pl.DataFrame) -> np.ndarray: - """预测 - - Args: - X: 特征数据 - - Returns: - 预测结果数组 - - classification: 类别标签或概率 - - regression: 连续值 - - ranking: 排序分数 - """ - pass - - def predict_proba(self, X: pl.DataFrame) -> np.ndarray: - """预测概率(仅分类任务) - - Args: - X: 特征数据 - - Returns: - 类别概率数组 [n_samples, n_classes] - """ - raise NotImplementedError("predict_proba only available for classification tasks") - - def get_feature_importance(self) -> Optional[pl.DataFrame]: - """获取特征重要性(如果模型支持) - - Returns: - DataFrame[feature, importance] 或 None - """ - return None - - def save(self, path: str) -> None: - """保存模型到文件""" - import pickle - with open(path, 'wb') as f: - pickle.dump(self, f) - - @classmethod - def load(cls, path: str) -> "BaseModel": - """从文件加载模型""" - import pickle - with open(path, 'rb') as f: - return pickle.load(f) -``` - -#### 3.1.4 BaseSplitter (数据划分基类) - -```python -from abc import ABC, abstractmethod -from typing import Iterator, Tuple, List -import polars as pl - -class BaseSplitter(ABC): - """数据划分策略基类 - - 针对时间序列数据的特殊划分策略,防止未来泄露。 - - 示例: - >>> splitter = TimeSeriesSplit(n_splits=5, gap=5) - >>> for train_idx, test_idx in splitter.split(data): - ... X_train, X_test = X[train_idx], X[test_idx] - """ - - @abstractmethod - def split( - self, - data: pl.DataFrame, - date_col: str = "trade_date" - ) -> Iterator[Tuple[List[int], List[int]]]: - """生成训练/测试索引 - - Args: - data: 完整数据集 - date_col: 日期列名 - - Yields: - (train_indices, test_indices) 元组 - """ - pass - - @abstractmethod - def get_split_dates( - self, - data: pl.DataFrame, - date_col: str = "trade_date" - ) -> List[Tuple[str, str, str, str]]: - """获取划分日期范围 - - Returns: - [(train_start, train_end, test_start, test_end), ...] - """ - pass -``` - ---- - -### 3.2 核心组件 - -#### 3.2.1 FeatureStore (特征存储) - -```python -from typing import List, Optional, Dict -import polars as pl -from pathlib import Path - -class FeatureStore: - """特征存储管理器 - - 负责加载、合并、缓存因子数据。 - 支持从多个数据源(因子、标签、行情)加载并合并。 - """ - - def __init__(self, data_dir: str): - self.data_dir = Path(data_dir) - self._cache: Dict[str, pl.DataFrame] = {} - - def load_factors( - self, - factor_names: List[str], - start_date: Optional[str] = None, - end_date: Optional[str] = None, - stock_codes: Optional[List[str]] = None - ) -> pl.DataFrame: - """加载因子数据 - - Args: - factor_names: 因子名称列表 - start_date: 开始日期 YYYYMMDD - end_date: 结束日期 YYYYMMDD - stock_codes: 股票代码列表(可选) - - Returns: - DataFrame[trade_date, ts_code, factor1, factor2, ...] - """ - pass - - def load_labels( - self, - label_name: str, - forward_period: int = 5, - start_date: Optional[str] = None, - end_date: Optional[str] = None - ) -> pl.DataFrame: - """加载标签数据(未来收益) - - Args: - label_name: 标签名称(如 "return", "rank") - forward_period: 前瞻期(如5天后收益) - start_date: 开始日期 - end_date: 结束日期 - - Returns: - DataFrame[trade_date, ts_code, label] - """ - pass - - def build_dataset( - self, - factor_names: List[str], - label_config: Dict, - date_range: Tuple[str, str], - stock_codes: Optional[List[str]] = None, - additional_cols: Optional[List[str]] = None - ) -> pl.DataFrame: - """构建完整数据集 - - 合并因子、标签、辅助列,并对齐数据。 - - Args: - factor_names: 因子列表 - label_config: 标签配置 {"name": str, "forward_period": int} - date_range: (start_date, end_date) - stock_codes: 限定股票列表 - additional_cols: 额外列(如 industry, market_cap) - - Returns: - DataFrame[trade_date, ts_code, factor_cols..., label] - """ - pass -``` - -#### 3.2.2 ProcessingPipeline (处理流水线) - -```python -from typing import List -import polars as pl - -class ProcessingPipeline: - """数据处理流水线 - - 按顺序执行多个处理器,自动处理阶段标记。 - 关键特性:在测试阶段使用训练阶段学习到的参数。 - """ - - def __init__(self, processors: List[BaseProcessor]): - """初始化流水线 - - Args: - processors: 处理器列表(按执行顺序) - """ - self.processors = processors - self._fitted_processors: Dict[int, BaseProcessor] = {} - - def fit_transform( - self, - data: pl.DataFrame, - stage: PipelineStage = PipelineStage.TRAIN - ) -> pl.DataFrame: - """在训练数据上fit所有处理器并transform - - Args: - data: 训练数据 - stage: 当前阶段标记 - - Returns: - 处理后的数据 - """ - result = data - for i, processor in enumerate(self.processors): - # 检查处理器是否适用于当前阶段 - if processor.stage in [PipelineStage.ALL, stage]: - # fit并transform - result = processor.fit_transform(result) - self._fitted_processors[i] = processor - elif stage == PipelineStage.TRAIN: - # 即使不适用于TRAIN阶段,也要fit(为TEST阶段准备) - if processor.stage == PipelineStage.TEST: - processor.fit(result) - self._fitted_processors[i] = processor - return result - - def transform( - self, - data: pl.DataFrame, - stage: PipelineStage = PipelineStage.TEST - ) -> pl.DataFrame: - """在测试数据上应用已fit的处理器 - - 使用训练阶段学习到的参数,防止数据泄露。 - - Args: - data: 测试数据 - stage: 当前阶段标记 - - Returns: - 处理后的数据 - """ - result = data - for i, processor in enumerate(self.processors): - if processor.stage in [PipelineStage.ALL, stage]: - if i in self._fitted_processors: - # 使用已fit的处理器 - result = self._fitted_processors[i].transform(result) - else: - # 未fit的处理器(ALL阶段但train时没执行到) - result = processor.transform(result) - return result - - def save_processors(self, path: str) -> None: - """保存所有已fit的处理器状态""" - import pickle - with open(path, 'wb') as f: - pickle.dump(self._fitted_processors, f) - - def load_processors(self, path: str) -> None: - """加载处理器状态""" - import pickle - with open(path, 'rb') as f: - self._fitted_processors = pickle.load(f) -``` - ---- - -## 4. 插件系统 - -### 4.1 注册器模式 - -```python -from typing import Type, Dict, TypeVar -from functools import wraps - -T = TypeVar('T') - -class PluginRegistry: - """插件注册中心 - - 提供装饰器方式注册处理器、模型、划分策略等组件。 - 实现真正的插件式架构 - 新功能只需注册即可使用。 - """ - - _processors: Dict[str, Type[BaseProcessor]] = {} - _models: Dict[str, Type[BaseModel]] = {} - _splitters: Dict[str, Type[BaseSplitter]] = {} - _metrics: Dict[str, Type["BaseMetric"]] = {} - - @classmethod - def register_processor(cls, name: Optional[str] = None): - """注册处理器装饰器 - - 示例: - >>> @PluginRegistry.register_processor("standard_scaler") - ... class StandardScaler(BaseProcessor): - ... pass - - >>> # 使用 - >>> scaler = PluginRegistry.get_processor("standard_scaler")() - """ - def decorator(processor_class: Type[BaseProcessor]) -> Type[BaseProcessor]: - key = name or processor_class.__name__ - cls._processors[key] = processor_class - processor_class._registry_name = key - return processor_class - return decorator - - @classmethod - def register_model(cls, name: Optional[str] = None): - """注册模型装饰器""" - def decorator(model_class: Type[BaseModel]) -> Type[BaseModel]: - key = name or model_class.__name__ - cls._models[key] = model_class - model_class._registry_name = key - return model_class - return decorator - - @classmethod - def register_splitter(cls, name: Optional[str] = None): - """注册划分策略装饰器""" - def decorator(splitter_class: Type[BaseSplitter]) -> Type[BaseSplitter]: - key = name or splitter_class.__name__ - cls._splitters[key] = splitter_class - return splitter_class - return decorator - - @classmethod - def get_processor(cls, name: str) -> Type[BaseProcessor]: - """获取处理器类""" - if name not in cls._processors: - raise KeyError(f"Processor '{name}' not found. Available: {list(cls._processors.keys())}") - return cls._processors[name] - - @classmethod - def get_model(cls, name: str) -> Type[BaseModel]: - """获取模型类""" - if name not in cls._models: - raise KeyError(f"Model '{name}' not found. Available: {list(cls._models.keys())}") - return cls._models[name] - - @classmethod - def get_splitter(cls, name: str) -> Type[BaseSplitter]: - """获取划分策略类""" - if name not in cls._splitters: - raise KeyError(f"Splitter '{name}' not found. Available: {list(cls._splitters.keys())}") - return cls._splitters[name] - - @classmethod - def list_processors(cls) -> List[str]: - """列出所有可用处理器""" - return list(cls._processors.keys()) - - @classmethod - def list_models(cls) -> List[str]: - """列出所有可用模型""" - return list(cls._models.keys()) -``` - -### 4.2 内置插件 - -```python -# ========== 内置处理器 ========== - -@PluginRegistry.register_processor("standard_scaler") -class StandardScaler(BaseProcessor): - """标准缩放处理器 - Z-score标准化""" - stage = PipelineStage.ALL - - def fit(self, data: pl.DataFrame) -> "StandardScaler": - cols = self.columns or [c for c in data.columns if data[c].dtype in FLOAT_TYPES] - self._fitted_params = { - "mean": {c: data[c].mean() for c in cols}, - "std": {c: data[c].std() for c in cols}, - "columns": cols - } - return self - - def transform(self, data: pl.DataFrame) -> pl.DataFrame: - result = data - for col in self._fitted_params["columns"]: - mean = self._fitted_params["mean"][col] - std = self._fitted_params["std"][col] - if std > 0: - result = result.with_columns( - ((pl.col(col) - mean) / std).alias(col) - ) - return result - - -@PluginRegistry.register_processor("winsorizer") -class Winsorizer(BaseProcessor): - """缩尾处理器 - 防止极端值影响""" - stage = PipelineStage.TRAIN # 只在训练阶段计算分位数 - - def __init__(self, columns=None, lower=0.01, upper=0.99): - super().__init__(columns) - self.lower = lower - self.upper = upper - - def fit(self, data: pl.DataFrame) -> "Winsorizer": - cols = self.columns or [c for c in data.columns if data[c].dtype in FLOAT_TYPES] - self._fitted_params = { - "lower": {c: data[c].quantile(self.lower) for c in cols}, - "upper": {c: data[c].quantile(self.upper) for c in cols}, - "columns": cols - } - return self - - def transform(self, data: pl.DataFrame) -> pl.DataFrame: - result = data - for col in self._fitted_params["columns"]: - lower = self._fitted_params["lower"][col] - upper = self._fitted_params["upper"][col] - result = result.with_columns( - pl.col(col).clip(lower, upper).alias(col) - ) - return result - - -@PluginRegistry.register_processor("neutralizer") -class Neutralizer(BaseProcessor): - """行业/市值中性化处理器""" - stage = PipelineStage.ALL - - def __init__(self, columns=None, group_col="industry", exclude_cols=None): - super().__init__(columns) - self.group_col = group_col - self.exclude_cols = exclude_cols or [] - - def fit(self, data: pl.DataFrame) -> "Neutralizer": - # 中性化通常在每个截面独立进行,不需要全局fit - return self - - def transform(self, data: pl.DataFrame) -> pl.DataFrame: - # 按日期分组,对每个截面进行中性化 - result = data - for col in self.columns or []: - if col in self.exclude_cols: - continue - # 分组去均值 - result = result.with_columns( - (pl.col(col) - pl.col(col).mean().over(["trade_date", self.group_col])) - .alias(col) - ) - return result - - -@PluginRegistry.register_processor("dropna") -class DropNAProcessor(BaseProcessor): - """缺失值删除处理器""" - stage = PipelineStage.ALL - - def fit(self, data: pl.DataFrame) -> "DropNAProcessor": - return self - - def transform(self, data: pl.DataFrame) -> pl.DataFrame: - cols = self.columns or data.columns - return data.drop_nulls(subset=cols) - - -@PluginRegistry.register_processor("fillna") -class FillNAProcessor(BaseProcessor): - """缺失值填充处理器""" - stage = PipelineStage.TRAIN - - def __init__(self, columns=None, method="median"): - super().__init__(columns) - self.method = method - - def fit(self, data: pl.DataFrame) -> "FillNAProcessor": - cols = self.columns or [c for c in data.columns if data[c].dtype in FLOAT_TYPES] - fill_values = {} - for col in cols: - if self.method == "median": - fill_values[col] = data[col].median() - elif self.method == "mean": - fill_values[col] = data[col].mean() - elif self.method == "zero": - fill_values[col] = 0 - self._fitted_params = {"fill_values": fill_values, "columns": cols} - return self - - def transform(self, data: pl.DataFrame) -> pl.DataFrame: - result = data - for col, val in self._fitted_params["fill_values"].items(): - result = result.with_columns(pl.col(col).fill_null(val).alias(col)) - return result - - -@PluginRegistry.register_processor("rank_transformer") -class RankTransformer(BaseProcessor): - """排名转换处理器 - 转换为截面排名""" - stage = PipelineStage.ALL - - def fit(self, data: pl.DataFrame) -> "RankTransformer": - return self - - def transform(self, data: pl.DataFrame) -> pl.DataFrame: - result = data - for col in self.columns or []: - # 按日期分组计算排名 - result = result.with_columns( - pl.col(col).rank().over("trade_date").alias(col) - ) - return result - - -# ========== 内置模型 ========== - -@PluginRegistry.register_model("lightgbm") -class LightGBMModel(BaseModel): - """LightGBM模型包装器""" - - def __init__(self, task_type: TaskType, params: Optional[Dict] = None, name: Optional[str] = None): - super().__init__(task_type, params, name) - self._model = None - - def fit( - self, - X: pl.DataFrame, - y: pl.Series, - X_val: Optional[pl.DataFrame] = None, - y_val: Optional[pl.Series] = None, - **fit_params - ) -> "LightGBMModel": - import lightgbm as lgb - - # 转换数据格式 - X_arr = X.to_numpy() - y_arr = y.to_numpy() - - # 构建数据集 - train_data = lgb.Dataset(X_arr, label=y_arr) - valid_sets = [train_data] - - if X_val is not None and y_val is not None: - valid_data = lgb.Dataset(X_val.to_numpy(), label=y_val.to_numpy()) - valid_sets.append(valid_data) - - # 设置默认参数 - default_params = { - "objective": self._get_objective(), - "metric": self._get_metric(), - "boosting_type": "gbdt", - "num_leaves": 31, - "learning_rate": 0.05, - "feature_fraction": 0.9, - "bagging_fraction": 0.8, - "bagging_freq": 5, - "verbose": -1 - } - default_params.update(self.params) - - # 训练 - self._model = lgb.train( - default_params, - train_data, - num_boost_round=fit_params.get("num_boost_round", 100), - valid_sets=valid_sets, - callbacks=[lgb.early_stopping(stopping_rounds=10, verbose=False)] if len(valid_sets) > 1 else [] - ) - self._is_fitted = True - return self - - def predict(self, X: pl.DataFrame) -> np.ndarray: - if not self._is_fitted: - raise RuntimeError("Model not fitted yet") - return self._model.predict(X.to_numpy()) - - def predict_proba(self, X: pl.DataFrame) -> np.ndarray: - if self.task_type != "classification": - raise ValueError("predict_proba only for classification") - probs = self.predict(X) - if len(probs.shape) == 1: - return np.vstack([1 - probs, probs]).T - return probs - - def get_feature_importance(self) -> Optional[pl.DataFrame]: - if self._model is None: - return None - importance = self._model.feature_importance(importance_type="gain") - return pl.DataFrame({ - "feature": self._model.feature_name(), - "importance": importance - }).sort("importance", descending=True) - - def _get_objective(self) -> str: - if self.task_type == "classification": - return "binary" - elif self.task_type == "regression": - return "regression" - elif self.task_type == "ranking": - return "lambdarank" - return "regression" - - def _get_metric(self) -> str: - if self.task_type == "classification": - return "auc" - elif self.task_type == "regression": - return "rmse" - elif self.task_type == "ranking": - return "ndcg" - return "rmse" - - -@PluginRegistry.register_model("catboost") -class CatBoostModel(BaseModel): - """CatBoost模型包装器""" - - def __init__(self, task_type: TaskType, params: Optional[Dict] = None, name: Optional[str] = None): - super().__init__(task_type, params, name) - self._model = None - - def fit( - self, - X: pl.DataFrame, - y: pl.Series, - X_val: Optional[pl.DataFrame] = None, - y_val: Optional[pl.Series] = None, - **fit_params - ) -> "CatBoostModel": - from catboost import CatBoostClassifier, CatBoostRegressor - - # 选择模型类型 - if self.task_type == "classification": - model_class = CatBoostClassifier - default_params = {"loss_function": "Logloss", "eval_metric": "AUC"} - elif self.task_type == "regression": - model_class = CatBoostRegressor - default_params = {"loss_function": "RMSE"} - else: # ranking - model_class = CatBoostRegressor - default_params = {"loss_function": "QueryRMSE"} - - default_params.update(self.params) - default_params["verbose"] = False - - self._model = model_class(**default_params) - - # 准备验证集 - eval_set = None - if X_val is not None and y_val is not None: - eval_set = (X_val.to_pandas(), y_val.to_pandas()) - - # 训练 - self._model.fit( - X.to_pandas(), - y.to_pandas(), - eval_set=eval_set, - early_stopping_rounds=10, - verbose=False - ) - self._is_fitted = True - return self - - def predict(self, X: pl.DataFrame) -> np.ndarray: - if not self._is_fitted: - raise RuntimeError("Model not fitted yet") - return self._model.predict(X.to_pandas()) - - def predict_proba(self, X: pl.DataFrame) -> np.ndarray: - if self.task_type != "classification": - raise ValueError("predict_proba only for classification") - return self._model.predict_proba(X.to_pandas()) - - def get_feature_importance(self) -> Optional[pl.DataFrame]: - if self._model is None: - return None - return pl.DataFrame({ - "feature": self._model.feature_names_, - "importance": self._model.feature_importances_ - }).sort("importance", descending=True) - - -# ========== 内置划分策略 ========== - -@PluginRegistry.register_splitter("time_series") -class TimeSeriesSplit(BaseSplitter): - """时间序列划分 - 确保训练数据在测试数据之前""" - - def __init__(self, n_splits: int = 5, gap: int = 5, min_train_size: int = 252): - self.n_splits = n_splits - self.gap = gap - self.min_train_size = min_train_size - - def split(self, data: pl.DataFrame, date_col: str = "trade_date"): - dates = data[date_col].unique().sort() - n_dates = len(dates) - - # 计算每个split的测试集大小 - test_size = (n_dates - self.min_train_size) // self.n_splits - - for i in range(self.n_splits): - # 训练集结束位置 - train_end_idx = self.min_train_size + i * test_size - # 测试集开始位置(留gap防止泄露) - test_start_idx = train_end_idx + self.gap - test_end_idx = test_start_idx + test_size - - if test_end_idx > n_dates: - break - - train_dates = dates[:train_end_idx] - test_dates = dates[test_start_idx:test_end_idx] - - train_mask = data[date_col].is_in(train_dates) - test_mask = data[date_col].is_in(test_dates) - - train_idx = data.with_row_count().filter(train_mask)["row_count"].to_list() - test_idx = data.with_row_count().filter(test_mask)["row_count"].to_list() - - yield train_idx, test_idx - - def get_split_dates(self, data: pl.DataFrame, date_col: str = "trade_date"): - dates = data[date_col].unique().sort() - n_dates = len(dates) - test_size = (n_dates - self.min_train_size) // self.n_splits - - result = [] - for i in range(self.n_splits): - train_end_idx = self.min_train_size + i * test_size - test_start_idx = train_end_idx + self.gap - test_end_idx = test_start_idx + test_size - - if test_end_idx > n_dates: - break - - result.append(( - dates[0], - dates[train_end_idx - 1], - dates[test_start_idx], - dates[test_end_idx - 1] - )) - return result - - -@PluginRegistry.register_splitter("walk_forward") -class WalkForwardSplit(BaseSplitter): - """滚动前向验证 - 训练集逐步扩展""" - - def __init__(self, train_window: int = 504, test_window: int = 21, gap: int = 5): - self.train_window = train_window - self.test_window = test_window - self.gap = gap - - def split(self, data: pl.DataFrame, date_col: str = "trade_date"): - dates = data[date_col].unique().sort() - n_dates = len(dates) - - start_idx = self.train_window - while start_idx + self.gap + self.test_window <= n_dates: - train_start = start_idx - self.train_window - train_end = start_idx - test_start = start_idx + self.gap - test_end = test_start + self.test_window - - train_dates = dates[train_start:train_end] - test_dates = dates[test_start:test_end] - - train_mask = data[date_col].is_in(train_dates) - test_mask = data[date_col].is_in(test_dates) - - train_idx = data.with_row_count().filter(train_mask)["row_count"].to_list() - test_idx = data.with_row_count().filter(test_mask)["row_count"].to_list() - - yield train_idx, test_idx - start_idx += self.test_window -``` - ---- - -## 5. 使用示例 - -### 5.1 基础用法 - -```python -from src.models import ( - FeatureStore, ProcessingPipeline, PluginRegistry, - PipelineStage, MLPipeline -) - -# 1. 创建数据存储 -store = FeatureStore(data_dir="data") - -# 2. 构建数据集 -dataset = store.build_dataset( - factor_names=["pe", "pb", "roe", "momentum_20", "volatility_20"], - label_config={"name": "forward_return", "forward_period": 5}, - date_range=("20200101", "20241231") -) - -# 3. 创建处理流水线 -processors = [ - # 删除缺失值 - PluginRegistry.get_processor("dropna")(), - - # 异常值处理(只在训练阶段计算分位数) - PluginRegistry.get_processor("winsorizer")(lower=0.01, upper=0.99), - - # 中性化(行业和市值中性化) - PluginRegistry.get_processor("neutralizer")(group_col="industry"), - - # 标准化(训练和测试都使用) - PluginRegistry.get_processor("standard_scaler")(), -] -pipeline = ProcessingPipeline(processors) - -# 4. 创建划分策略 -splitter = PluginRegistry.get_splitter("time_series")( - n_splits=5, - gap=5, - min_train_size=252 -) - -# 5. 创建模型 -model = PluginRegistry.get_model("lightgbm")( - task_type="regression", - params={"n_estimators": 200, "learning_rate": 0.03} -) - -# 6. 运行完整流程 -ml_pipeline = MLPipeline( - feature_store=store, - processing_pipeline=pipeline, - splitter=splitter, - model=model -) - -results = ml_pipeline.run( - factor_names=["pe", "pb", "roe", "momentum_20", "volatility_20"], - label_config={"name": "forward_return", "forward_period": 5}, - date_range=("20200101", "20241231") -) - -# 7. 查看结果 -print(results.metrics) # 各折的评估指标 -print(results.feature_importance) # 特征重要性 -print(results.predictions) # 预测结果 -``` - -### 5.2 配置驱动用法(推荐) - -```python -# config.yaml -experiment: - name: "momentum_factor_regression" - -data: - factor_names: ["momentum_5", "momentum_20", "momentum_60", "volatility_20"] - label: - name: "forward_return" - forward_period: 5 - date_range: ["20200101", "20241231"] - -processing: - - name: "dropna" - params: {} - stage: "all" - - - name: "winsorizer" - params: - lower: 0.01 - upper: 0.99 - stage: "train" # 只在训练阶段计算分位数 - - - name: "neutralizer" - params: - group_col: "industry" - stage: "all" - - - name: "standard_scaler" - params: {} - stage: "all" - -splitting: - strategy: "time_series" - params: - n_splits: 5 - gap: 5 - min_train_size: 252 - -model: - name: "lightgbm" - task_type: "regression" - params: - n_estimators: 200 - learning_rate: 0.03 - max_depth: 6 - -evaluation: - metrics: ["ic", "rank_ic", "mse", "mae"] - output_dir: "results/momentum_experiment" -``` - -```python -# 代码中使用配置 -from src.models import MLPipeline - -pipeline = MLPipeline.from_config("config.yaml") -results = pipeline.run() - -# 保存结果 -results.save("results/momentum_experiment") -``` - -### 5.3 自定义插件 - -```python -# 1. 创建自定义处理器 -@PluginRegistry.register_processor("my_transformer") -class MyTransformer(BaseProcessor): - """自定义转换器示例""" - stage = PipelineStage.ALL - - def __init__(self, columns=None, multiplier=2.0): - super().__init__(columns) - self.multiplier = multiplier - - def fit(self, data: pl.DataFrame) -> "MyTransformer": - # 学习参数(如有需要) - return self - - def transform(self, data: pl.DataFrame) -> pl.DataFrame: - result = data - for col in self.columns or []: - result = result.with_columns( - (pl.col(col) * self.multiplier).alias(col) - ) - return result - - -# 2. 创建自定义模型 -@PluginRegistry.register_model("my_model") -class MyModel(BaseModel): - """自定义模型示例""" - - def fit(self, X, y, X_val=None, y_val=None, **kwargs): - # 实现训练逻辑 - self._model = ... - return self - - def predict(self, X): - # 实现预测逻辑 - return self._model.predict(X) - - -# 3. 在配置中使用 -# config.yaml -processing: - - name: "my_transformer" - params: - multiplier: 3.0 - stage: "all" - -model: - name: "my_model" - task_type: "regression" -``` - ---- - -## 6. 目录结构 - -``` -src/ -├── models/ # 模型训练框架 -│ ├── __init__.py # 导出主要类 -│ ├── core/ # 核心抽象和基类 -│ │ ├── __init__.py -│ │ ├── processor.py # BaseProcessor, PipelineStage -│ │ ├── model.py # BaseModel, TaskType -│ │ ├── splitter.py # BaseSplitter -│ │ ├── metric.py # BaseMetric -│ │ └── pipeline.py # MLPipeline (编排器) -│ │ -│ ├── registry.py # PluginRegistry 插件注册中心 -│ │ -│ ├── data/ # 数据相关 -│ │ ├── __init__.py -│ │ ├── feature_store.py # FeatureStore 特征存储 -│ │ ├── label_generator.py # LabelGenerator 标签生成 -│ │ └── dataset.py # Dataset 数据集包装 -│ │ -│ ├── processors/ # 内置处理器 -│ │ ├── __init__.py # 自动注册所有处理器 -│ │ ├── scaler.py # StandardScaler -│ │ ├── winsorizer.py # Winsorizer -│ │ ├── neutralizer.py # Neutralizer -│ │ ├── imputer.py # FillNAProcessor -│ │ ├── selector.py # FeatureSelector -│ │ └── custom.py # 其他处理器 -│ │ -│ ├── models/ # 内置模型 -│ │ ├── __init__.py # 自动注册所有模型 -│ │ ├── lightgbm_model.py # LightGBMModel -│ │ ├── catboost_model.py # CatBoostModel -│ │ └── sklearn_model.py # SklearnModel (LR, RF等) -│ │ -│ ├── splitters/ # 划分策略 -│ │ ├── __init__.py -│ │ ├── time_series.py # TimeSeriesSplit -│ │ ├── walk_forward.py # WalkForwardSplit -│ │ └── purged.py # PurgedKFold -│ │ -│ ├── metrics/ # 评估指标 -│ │ ├── __init__.py -│ │ ├── ic.py # IC, RankIC -│ │ ├── returns.py # 收益指标 -│ │ └── classification.py # 分类指标 -│ │ -│ ├── evaluation/ # 评估和报告 -│ │ ├── __init__.py -│ │ ├── evaluator.py # ModelEvaluator -│ │ ├── report.py # ReportGenerator -│ │ └── visualizer.py # ResultVisualizer -│ │ -│ └── config/ # 配置解析 -│ ├── __init__.py -│ └── parser.py # ConfigParser -│ -├── factors/ # 已有因子框架 -│ └── ... -│ -tests/ -├── models/ # 模型框架测试 -│ ├── __init__.py -│ ├── test_processors.py # 处理器测试 -│ ├── test_models.py # 模型测试 -│ ├── test_pipeline.py # 流水线集成测试 -│ └── test_registry.py # 注册器测试 -│ -└── factors/ # 已有因子测试 - └── ... - -configs/ # 配置文件目录 -├── momentum_regression.yaml -├── value_classification.yaml -└├── ranking_lambdamart.yaml - -experiments/ # 实验结果目录 -└── {experiment_name}/ - ├── config.yaml # 实验配置 - ├── model.pkl # 保存的模型 - ├── processors.pkl # 保存的处理器状态 - ├── predictions.parquet # 预测结果 - ├── metrics.json # 评估指标 - ├── feature_importance.csv # 特征重要性 - └── report.html # 可视化报告 -``` - ---- - -## 7. 开发计划 - -### Phase 1: 核心基础设施 (Week 1-2) -- [ ] 设计并实现 `BaseProcessor`, `BaseModel`, `BaseSplitter` 抽象类 -- [ ] 实现 `PluginRegistry` 注册中心 -- [ ] 实现 `PipelineStage` 阶段管理 -- [ ] 编写基础单元测试 - -### Phase 2: 数据层 (Week 2-3) -- [ ] 实现 `FeatureStore` 特征存储 -- [ ] 实现 `LabelGenerator` 标签生成器 -- [ ] 实现 `Dataset` 数据集包装 -- [ ] 集成现有因子框架输出 - -### Phase 3: 处理器 (Week 3-4) -- [ ] 实现 `StandardScaler` 标准化处理器 -- [ ] 实现 `Winsorizer` 缩尾处理器 -- [ ] 实现 `Neutralizer` 中性化处理器 -- [ ] 实现 `FillNAProcessor` 缺失值处理器 -- [ ] 实现 `DropNAProcessor` 缺失值删除处理器 -- [ ] 实现 `FeatureSelector` 特征选择器 -- [ ] 实现 `ProcessingPipeline` 流水线 - -### Phase 4: 模型层 (Week 4-5) -- [ ] 实现 `LightGBMModel` LightGBM包装 -- [ ] 实现 `CatBoostModel` CatBoost包装 -- [ ] 实现 `SklearnModel` sklearn模型支持 -- [ ] 支持 classification/regression/ranking 三种任务 - -### Phase 5: 划分策略 (Week 5) -- [ ] 实现 `TimeSeriesSplit` 时间序列划分 -- [ ] 实现 `WalkForwardSplit` 滚动前向验证 -- [ ] 实现 `PurgedKFold` 清除重叠样本 - -### Phase 6: 评估层 (Week 5-6) -- [ ] 实现 IC/RankIC 指标 -- [ ] 实现收益分析指标 -- [ ] 实现分类指标 -- [ ] 实现 `ModelEvaluator` 评估器 -- [ ] 实现 `ReportGenerator` 报告生成 - -### Phase 7: 配置和编排 (Week 6) -- [ ] 实现配置解析器 -- [ ] 实现 `MLPipeline` 编排器 -- [ ] 支持配置驱动执行 - -### Phase 8: 集成测试和文档 (Week 7) -- [ ] 编写完整集成测试 -- [ ] 编写使用文档 -- [ ] 编写示例代码 -- [ ] 性能基准测试 - ---- - -## 8. 关键设计决策 - -| 决策点 | 选择 | 理由 | -|--------|------|------| -| **数据处理阶段标记** | `PipelineStage` 枚举 | 显式、类型安全、易于扩展 | -| **插件注册方式** | 装饰器模式 | Pythonic、简洁、自动发现 | -| **数据格式** | Polars DataFrame | 与因子框架一致、高性能 | -| **模型接口** | `fit/predict` 统一接口 | 行业标准、易于替换模型 | -| **配置格式** | YAML | 人类可读、支持复杂结构 | -| **处理器状态保存** | pickle | 简单、Python原生、支持大部分对象 | -| **特征存储** | 从因子框架直接读取 | 避免数据冗余、保持一致性 | - ---- - -## 9. 防数据泄露检查清单 - -- [x] 处理器明确标记适用阶段 (`stage` 属性) -- [x] `TRAIN` 阶段处理器只在训练数据上 `fit` -- [x] `TEST` 阶段使用训练阶段学习到的参数 -- [x] 划分策略支持时间序列感知 (`TimeSeriesSplit`, `WalkForwardSplit`) -- [x] 划分时支持 `gap` 参数防止相邻样本泄露 -- [x] 特征存储从已计算的因子加载(不访问未来数据) -- [x] 标签生成使用预定义的前瞻期(明确的future data) - ---- - -*文档版本: v1.0* -*最后更新: 2026-02-23* -*设计状态: 草案 - 待评审* diff --git a/docs/test_report_duckdb_migration.md b/docs/test_report_duckdb_migration.md deleted file mode 100644 index 7085d61..0000000 --- a/docs/test_report_duckdb_migration.md +++ /dev/null @@ -1,211 +0,0 @@ -# ProStock HDF5 到 DuckDB 迁移测试报告 - -**报告生成时间**: 2026-02-22 -**完成时间**: 2026-02-22 -**状态**: ✅ 已完成 -**迁移文档**: [hdf5_to_duckdb_migration.md](./hdf5_to_duckdb_migration.md) -**测试数据范围**: 2024年1月-3月(3个月) - ---- - -## 1. 迁移实施摘要 - -### 已完成的核心任务 ✅ - -| 任务 | 文件 | 状态 | -|------|------|------| -| Storage 类重写 | `src/data/storage.py` | ✅ 完成 | -| ThreadSafeStorage 实现 | `src/data/storage.py` | ✅ 完成 | -| Sync 模块适配 | `src/data/sync.py` | ✅ 完成 | -| DataLoader 适配 | `src/factors/data_loader.py` | ✅ 完成 | -| 测试文件更新 | `tests/` | ✅ 完成 | - -### 架构变更 - -``` -HDF5 格式 (.h5 文件) → DuckDB (prostock.db) -├── pandas.read_hdf() → duckdb.execute().fetchdf() -├── 全表加载到内存 → SQL 查询下推,按需加载 -├── 文件锁并发 → ThreadSafeStorage 队列写入 -└── Polars 通过 Pandas 中转 → DuckDB → PyArrow → Polars (零拷贝) -``` - ---- - -## 2. 测试执行情况 - -### 2.1 测试文件清单 - -| 测试文件 | 测试类型 | 数据范围 | -|---------|---------|---------| -| `test_daily_storage.py` | DuckDB Storage 集成测试 | 3个月(2024/01-03) | -| `test_data_loader.py` | DataLoader 功能测试 | 3个月(2024/01-03) | -| `test_sync.py` | Sync 模块单元测试 | Mock 数据 | - -### 2.2 关键测试用例 - -#### DuckDB Storage 测试 (`test_daily_storage.py`) - -```python -class TestDailyStorageValidation: - TEST_START_DATE = "20240101" - TEST_END_DATE = "20240331" # 3个月数据 - - def test_duckdb_connection() # ✅ 连接测试 - def test_load_3months_data() # ⚠️ 需要先有数据 - def test_polars_export() # ✅ PyArrow 零拷贝导出 - def test_all_stocks_saved() # ⚠️ 需要先有数据 -``` - -#### DataLoader 测试 (`test_data_loader.py`) - -```python -class TestDataLoaderBasic: - def test_load_single_source() # 从 DuckDB 加载 - def test_load_with_date_range() # 3个月日期范围 - def test_column_selection() # 列选择 - def test_cache_used() # 缓存性能 -``` - ---- - -## 3. 性能对比预期 - -| 测试项 | HDF5 (旧) | DuckDB (新) | 预期提升 | -|--------|----------|------------|---------| -| 单股票查询 | 5-10s | 0.1-0.5s | **10-100x** | -| 日期范围查询 | 5-10s | 0.2-1s | **5-50x** | -| 内存占用 | 1GB+ | 100-500MB | **50-90%** | - ---- - -## 4. 使用前准备 - -### 4.1 数据同步(必须) - -当前数据库中没有 2024年1-3月的测试数据,需要先进行数据同步: - -```bash -# 方式1: 同步特定股票代码的3个月数据(推荐用于测试) -uv run python -c " -from src.data.sync import DataSync -from src.data.api_wrappers import get_daily -import pandas as pd - -# 获取测试股票数据 -data = get_daily('000001.SZ', start_date='20240101', end_date='20240331') - -# 保存到 DuckDB -from src.data.storage import Storage -storage = Storage() -storage.save('daily', data) -print(f'已保存 {len(data)} 行数据') -" - -# 方式2: 全量同步所有股票(耗时较长) -uv run python -c "from src.data.sync import sync_all; sync_all(force_full=True)" - -# 方式3: 增量同步(从上次同步日期继续) -uv run python -c "from src.data.sync import sync_all; sync_all()" -``` - -### 4.2 验证安装 - -```bash -# 检查 DuckDB 和 PyArrow 是否安装 -uv run python -c "import duckdb; import pyarrow; print('✅ 依赖检查通过')" - -# 验证 Storage 类 -uv run python -c "from src.data.storage import Storage, ThreadSafeStorage; print('✅ Storage 类导入成功')" -``` - ---- - -## 5. 运行测试 - -### 5.1 运行所有测试 - -```bash -# 运行 DuckDB 相关测试 -uv run pytest tests/test_daily_storage.py tests/factors/test_data_loader.py -v - -# 运行 Sync 模块测试 -uv run pytest tests/test_sync.py -v - -# 运行全部测试 -uv run pytest tests/ -v -``` - -### 5.2 预期输出 - -``` -tests/test_daily_storage.py::TestDailyStorageValidation::test_duckdb_connection PASSED -tests/test_daily_storage.py::TestDailyStorageValidation::test_polars_export PASSED -tests/factors/test_data_loader.py::TestDataLoaderBasic::test_load_single_source PASSED -tests/factors/test_data_loader.py::TestDataLoaderBasic::test_load_with_date_range PASSED -... -``` - ---- - -## 6. 常见问题 (FAQ) - -### Q: 测试提示 "No data found for period"? -**A**: 需要先执行数据同步,将 2024年1-3月的数据写入 DuckDB。 - -### Q: ModuleNotFoundError: No module named 'pyarrow'? -**A**: 需要安装 pyarrow: -```bash -uv pip install pyarrow -``` - -### Q: 如何查看数据库中的数据? -**A**: -```python -from src.data.storage import Storage -storage = Storage() - -# 检查表是否存在 -print(storage.exists("daily")) # True/False - -# 查询最新日期 -print(storage.get_last_date("daily")) # "20240331" -``` - -### Q: 如何备份 DuckDB 数据库? -**A**: -```bash -# 备份 -cp data/prostock.db data/prostock_backup.db - -# 恢复 -cp data/prostock_backup.db data/prostock.db -``` - ---- - -## 7. 迁移验证清单 - -- [x] Storage 类实现 DuckDB 存储 -- [x] ThreadSafeStorage 实现并发安全 -- [x] DataLoader 适配 DuckDB -- [x] Sync 模块使用 ThreadSafeStorage -- [x] 测试文件更新为 3 个月数据范围 -- [x] PyArrow 零拷贝导出支持 -- [ ] 执行数据同步(需手动运行) -- [ ] 运行全部测试通过(需先有数据) -- [ ] 性能基准测试对比 - ---- - -## 8. 下一步行动 - -1. **数据同步**: 运行上述 4.1 节的数据同步命令 -2. **测试验证**: 运行 `uv run pytest tests/ -v` 确认所有测试通过 -3. **性能测试**: 使用 `scripts/benchmark_storage.py` 对比 HDF5 vs DuckDB 性能 -4. **生产部署**: 备份 HDF5 文件,删除旧数据,完全切换到 DuckDB - ---- - -**报告生成**: ProStock Migration Tool -**状态**: 核心代码完成,等待数据同步后运行测试 diff --git a/src/__init__.py b/src/__init__.py index bec3b83..b4221d4 100644 --- a/src/__init__.py +++ b/src/__init__.py @@ -3,3 +3,17 @@ 提供股票数据分析和交易策略等功能 """ __version__ = "1.0.0" + + +import warnings +from pandas.errors import SettingWithCopyWarning + +# 忽略 tushare 库的 FutureWarning(fillna method 参数已弃用) +warnings.filterwarnings( + "ignore", + category=FutureWarning, + message=".*fillna with 'method' is deprecated.*", +) + +# 忽略 SettingWithCopyWarning(常见于 pandas 链式赋值) +warnings.filterwarnings("ignore", category=SettingWithCopyWarning) \ No newline at end of file diff --git a/src/config/settings.py b/src/config/settings.py index acdf989..047dbf5 100644 --- a/src/config/settings.py +++ b/src/config/settings.py @@ -2,6 +2,7 @@ 从环境变量加载应用配置,使用pydantic-settings进行类型验证 """ + import os from pathlib import Path from pydantic_settings import BaseSettings @@ -15,20 +16,33 @@ CONFIG_DIR = PROJECT_ROOT / "config" class Settings(BaseSettings): - """应用配置类,从环境变量加载""" + """应用配置类,从环境变量加载 - # 数据库配置 + 所有配置项都会自动从环境变量读取(小写转大写) + 例如:tushare_token 会读取 TUSHARE_TOKEN 环境变量 + """ + + # Tushare API 配置 + tushare_token: str = "" + + # 数据存储配置 + root_path: str = "" # 项目根路径,默认自动检测 + data_path: str = "data" # 数据存储路径,相对于 root_path + + # API 速率限制(每分钟请求数) + rate_limit: int = 300 + + # 同步工作线程数 + threads: int = 10 + + # 数据库配置(可选,用于未来扩展) database_host: str = "localhost" database_port: int = 5432 database_name: str = "prostock" - database_user: str - database_password: str + database_user: Optional[str] = None + database_password: Optional[str] = None - # API密钥配置 - api_key: str - secret_key: str - - # Redis配置 + # Redis配置(可选,用于未来扩展) redis_host: str = "localhost" redis_port: int = 6379 redis_password: Optional[str] = None @@ -38,11 +52,27 @@ class Settings(BaseSettings): app_debug: bool = False app_port: int = 8000 + @property + def project_root(self) -> Path: + """获取项目根路径。""" + if self.root_path: + return Path(self.root_path) + return PROJECT_ROOT + + @property + def data_path_resolved(self) -> Path: + """获取解析后的数据路径(绝对路径)。""" + path = Path(self.data_path) + if path.is_absolute(): + return path + return self.project_root / path + class Config: # 从 config/ 目录读取 .env.local 文件 env_file = str(CONFIG_DIR / ".env.local") env_file_encoding = "utf-8" case_sensitive = False + extra = "ignore" # 忽略 .env.local 中的额外变量 @lru_cache() diff --git a/src/data/__init__.py b/src/data/__init__.py index 8036357..ab47151 100644 --- a/src/data/__init__.py +++ b/src/data/__init__.py @@ -3,7 +3,7 @@ Provides simplified interfaces for fetching and storing Tushare data. """ -from src.data.config import Config, get_config +from src.config.settings import Settings, get_settings, settings from src.data.client import TushareClient from src.data.storage import Storage, ThreadSafeStorage, DEFAULT_TYPE_MAPPING from src.data.api_wrappers import get_stock_basic, sync_all_stocks diff --git a/src/data/api_wrappers/API_INTERFACE_SPEC.md b/src/data/api_wrappers/API_INTERFACE_SPEC.md index ced32b6..9b715a8 100644 --- a/src/data/api_wrappers/API_INTERFACE_SPEC.md +++ b/src/data/api_wrappers/API_INTERFACE_SPEC.md @@ -169,6 +169,120 @@ if "date" in data.columns: ### 4.5 令牌桶限速要求 +所有 API 调用必须通过 `TushareClient`,自动满足令牌桶限速要求。 + +#### 4.5.1 基本用法(单线程场景) + +```python +from src.data.client import TushareClient + +def get_{data_type}(...) -> pd.DataFrame: + client = TushareClient() + + # Build parameters + params = {} + if trade_date: + params["trade_date"] = trade_date + if ts_code: + params["ts_code"] = ts_code + # ... + + # Fetch data (rate limiting handled automatically) + data = client.query("{api_name}", **params) + + return data +``` + +**注意**: `TushareClient` 自动处理: +- 令牌桶速率限制 +- API 重试逻辑(指数退避) +- 配置加载 + +#### 4.5.2 多线程/并发场景(重要) + +**问题**: 多线程并发调用时,如果每个线程创建独立的 `TushareClient` 实例,每个实例会有独立的限流器,导致实际并发请求数 = 线程数 × 单个限流器速率,**限流失效**。 + +**解决方案**: 数据获取函数必须接受可选的 `client` 参数,Sync 类传递共享的客户端实例。 + +**数据获取函数签名**(必须支持 client 参数): + +```python +from src.data.client import TushareClient +from typing import Optional + +def get_{data_type}( + ts_code: str, + start_date: Optional[str] = None, + end_date: Optional[str] = None, + client: Optional[TushareClient] = None, # 新增:可选客户端参数 +) -> pd.DataFrame: + """Fetch {数据描述} from Tushare. + + Args: + ts_code: Stock code + start_date: Start date (YYYYMMDD) + end_date: End date (YYYYMMDD) + client: Optional TushareClient instance for shared rate limiting. + If None, creates a new client. For concurrent sync operations, + pass a shared client to ensure proper rate limiting. + + Returns: + pd.DataFrame with data + """ + client = client or TushareClient() # 如果没有提供则创建新实例 + + params = {"ts_code": ts_code} + if start_date: + params["start_date"] = start_date + if end_date: + params["end_date"] = end_date + + data = client.query("{api_name}", **params) + return data +``` + +**Sync 类实现**(必须传递共享 client): + +```python +from concurrent.futures import ThreadPoolExecutor +from src.data.client import TushareClient +from src.data.storage import ThreadSafeStorage + +class {DataType}Sync: + def __init__(self, max_workers: Optional[int] = None): + self.storage = ThreadSafeStorage() + self.client = TushareClient() # 共享客户端实例 + self.max_workers = max_workers or 10 + + def sync_single_stock( + self, + ts_code: str, + start_date: str, + end_date: str, + ) -> pd.DataFrame: + """同步单只股票的数据。""" + # 传递共享 client 以确保多线程下的限流生效 + data = get_{data_type}( + ts_code=ts_code, + start_date=start_date, + end_date=end_date, + client=self.client, # 关键:传递共享客户端 + ) + return data + + def sync_all(self, ...): + # 使用 ThreadPoolExecutor 并发执行 + with ThreadPoolExecutor(max_workers=self.max_workers) as executor: + # 所有线程共享 self.client,限流器正常工作 + ... +``` + +**关键规则**: +1. 所有按股票获取的接口必须接受 `client: Optional[TushareClient] = None` 参数 +2. Sync 类在 `__init__` 中创建 `self.client = TushareClient()` +3. Sync 类的同步方法必须将 `self.client` 传递给数据获取函数 +4. 数据获取函数使用 `client = client or TushareClient()` 模式 + 所有 API 调用必须通过 `TushareClient`,自动满足令牌桶限速要求: ```python @@ -198,6 +312,26 @@ def get_{data_type}(...) -> pd.DataFrame: ## 5. DuckDB 存储规范 +### 5.0 强制落库要求(关键) + +**所有封装的 API 接口必须将数据落库到 DuckDB。** + +这是数据同步的核心原则,确保: +- 数据持久化:避免重复调用 API,节省 token +- 增量更新:基于本地数据状态进行智能同步 +- 数据一致性:所有数据都有统一的存储和访问方式 +- 离线可用:数据可以在没有网络的情况下查询 + +**落库检查清单**: +- [ ] 在 `storage.py` 的 `_init_db()` 方法中创建对应的表 +- [ ] 表结构必须包含 `ts_code` 和 `trade_date` 作为主键 +- [ ] 实现 `sync_{data_type}()` 函数,使用 `Storage` 或 `ThreadSafeStorage` 保存数据 +- [ ] 确保同步逻辑正确处理增量更新 + +**反例警示**:`api_pro_bar.py` 早期版本虽然实现了 `sync_pro_bar()` 函数,但忘记在 `storage.py` 中创建 `pro_bar` 表,导致同步的数据无法落库,造成 token 浪费和数据丢失。 + +### 5.1 存储架构 + ### 5.1 存储架构 项目使用 **DuckDB** 作为持久化存储: diff --git a/src/data/api_wrappers/__init__.py b/src/data/api_wrappers/__init__.py index b7a2bac..1060376 100644 --- a/src/data/api_wrappers/__init__.py +++ b/src/data/api_wrappers/__init__.py @@ -5,6 +5,7 @@ All wrapper files follow the naming convention: api_{data_type}.py Available APIs: - api_daily: Daily market data (日线行情) + - api_pro_bar: Pro Bar universal market data (通用行情,后复权) - api_stock_basic: Stock basic information (股票基本信息) - api_trade_cal: Trading calendar (交易日历) - api_namechange: Stock name change history (股票曾用名) @@ -12,15 +13,31 @@ Available APIs: Example: >>> from src.data.api_wrappers import get_daily, get_stock_basic, get_trade_cal, get_bak_basic - >>> from src.data.api_wrappers import get_bak_basic, sync_bak_basic + >>> from src.data.api_wrappers import get_pro_bar, sync_pro_bar >>> data = get_daily('000001.SZ', start_date='20240101', end_date='20240131') + >>> pro_data = get_pro_bar('000001.SZ', start_date='20240101', end_date='20240131') >>> stocks = get_stock_basic() >>> calendar = get_trade_cal('20240101', '20240131') >>> bak_basic = get_bak_basic(trade_date='20240101') """ -from src.data.api_wrappers.api_daily import get_daily, sync_daily, preview_daily_sync, DailySync -from src.data.api_wrappers.financial_data.api_income import get_income, sync_income, IncomeSync +from src.data.api_wrappers.api_daily import ( + get_daily, + sync_daily, + preview_daily_sync, + DailySync, +) +from src.data.api_wrappers.api_pro_bar import ( + get_pro_bar, + sync_pro_bar, + preview_pro_bar_sync, + ProBarSync, +) +from src.data.api_wrappers.financial_data.api_income import ( + get_income, + sync_income, + IncomeSync, +) from src.data.api_wrappers.api_bak_basic import get_bak_basic, sync_bak_basic from src.data.api_wrappers.api_namechange import get_namechange, sync_namechange from src.data.api_wrappers.api_stock_basic import get_stock_basic, sync_all_stocks @@ -38,6 +55,11 @@ __all__ = [ "sync_daily", "preview_daily_sync", "DailySync", + # Pro Bar (universal market data) + "get_pro_bar", + "sync_pro_bar", + "preview_pro_bar_sync", + "ProBarSync", # Income statement "get_income", "sync_income", diff --git a/src/data/api_wrappers/api.md b/src/data/api_wrappers/api.md index 8e83438..26f398e 100644 --- a/src/data/api_wrappers/api.md +++ b/src/data/api_wrappers/api.md @@ -345,4 +345,154 @@ df = pro.bak_basic(trade_date='20211012', fields='trade_date,ts_code,name,indust 4530 20211012 688255.SH 凯尔达 机械基件 0.0000 4531 20211012 688211.SH 中科微至 专用机械 0.0000 4532 20211012 605567.SH 春雪食品 食品 0.0000 -4533 20211012 605566.SH 福莱蒽特 染料涂料 0.0000 \ No newline at end of file +4533 20211012 605566.SH 福莱蒽特 染料涂料 0.0000 + + +通用行情接口 +接口名称:pro_bar,本接口是集成开发接口,部分指标是现用现算 +更新时间:股票和指数通常在15点~17点之间,数字货币实时更新,具体请参考各接口文档明细。 +描述:目前整合了股票(未复权、前复权、后复权)、指数、数字货币、ETF基金、期货、期权的行情数据,未来还将整合包括外汇在内的所有交易行情数据,同时提供分钟数据。不同数据对应不同的积分要求,具体请参阅每类数据的文档说明。 +其它:由于本接口是集成接口,在SDK层做了一些逻辑处理,目前暂时没法用http的方式调取通用行情接口。用户可以访问Tushare的Github,查看源代码完成类似功能。 + +输入参数 + +名称 类型 必选 描述 +ts_code str Y 证券代码,不支持多值输入,多值输入获取结果会有重复记录 +start_date str N 开始日期 (日线格式:YYYYMMDD,提取分钟数据请用2019-09-01 09:00:00这种格式) +end_date str N 结束日期 (日线格式:YYYYMMDD) +asset str Y 资产类别:E股票 I沪深指数 C数字货币 FT期货 FD基金 O期权 CB可转债(v1.2.39),默认E +adj str N 复权类型(只针对股票):None未复权 qfq前复权 hfq后复权 , 默认None,目前只支持日线复权,同时复权机制是根据设定的end_date参数动态复权,采用分红再投模式,具体请参考常见问题列表里的说明。 +freq str Y 数据频度 :支持分钟(min)/日(D)/周(W)/月(M)K线,其中1min表示1分钟(类推1/5/15/30/60分钟) ,默认D。对于分钟数据有600积分用户可以试用(请求2次),正式权限可以参考权限列表说明 ,使用方法请参考股票分钟使用方法。 +ma list N 均线,支持任意合理int数值。注:均线是动态计算,要设置一定时间范围才能获得相应的均线,比如5日均线,开始和结束日期参数跨度必须要超过5日。目前只支持单一个股票提取均线,即需要输入ts_code参数。e.g: ma_5表示5日均价,ma_v_5表示5日均量 +factors list N 股票因子(asset='E'有效)支持 tor换手率 vr量比 +adjfactor str N 复权因子,在复权数据时,如果此参数为True,返回的数据中则带复权因子,默认为False。 该功能从1.2.33版本开始生效 + +输出指标 + +具体输出的数据指标可参考各行情具体指标: + +股票Daily:https://tushare.pro/document/2?doc_id=27 +(内容如下:A股日线行情 +接口:daily,可以通过数据工具调试和查看数据 +数据说明:交易日每天15点~16点之间入库。本接口是未复权行情,停牌期间不提供数据 +调取说明:基础积分每分钟内可调取500次,每次6000条数据,一次请求相当于提取一个股票23年历史 +描述:获取股票行情数据,或通过通用行情接口获取数据,包含了前后复权数据 + +输入参数 + +名称 类型 必选 描述 +ts_code str N 股票代码(支持多个股票同时提取,逗号分隔) +trade_date str N 交易日期(YYYYMMDD) +start_date str N 开始日期(YYYYMMDD) +end_date str N 结束日期(YYYYMMDD) +注:日期都填YYYYMMDD格式,比如20181010 + +输出参数 + +名称 类型 描述 +ts_code str 股票代码 +trade_date str 交易日期 +open float 开盘价 +high float 最高价 +low float 最低价 +close float 收盘价 +pre_close float 昨收价【除权价】 +change float 涨跌额 +pct_chg float 涨跌幅 【基于除权后的昨收计算的涨跌幅:(今收-除权昨收)/除权昨收 】 +vol float 成交量 (手) +amount float 成交额 (千元) +接口示例 + +pro = ts.pro_api() + +df = pro.daily(ts_code='000001.SZ', start_date='20180701', end_date='20180718') + +#多个股票 +df = pro.daily(ts_code='000001.SZ,600000.SH', start_date='20180701', end_date='20180718') +或者 + +df = pro.query('daily', ts_code='000001.SZ', start_date='20180701', end_date='20180718') +也可以通过日期取历史某一天的全部历史 + +df = pro.daily(trade_date='20180810') +数据样例 + + ts_code trade_date open high low close pre_close change pct_chg vol amount +0 000001.SZ 20180718 8.75 8.85 8.69 8.70 8.72 -0.02 -0.23 525152.77 460697.377 +1 000001.SZ 20180717 8.74 8.75 8.66 8.72 8.73 -0.01 -0.11 375356.33 326396.994 +2 000001.SZ 20180716 8.85 8.90 8.69 8.73 8.88 -0.15 -1.69 689845.58 603427.713 +3 000001.SZ 20180713 8.92 8.94 8.82 8.88 8.88 0.00 0.00 603378.21 535401.175 +4 000001.SZ 20180712 8.60 8.97 8.58 8.88 8.64 0.24 2.78 1140492.31 1008658.828 +5 000001.SZ 20180711 8.76 8.83 8.68 8.78 8.98 -0.20 -2.23 851296.70 744765.824 +6 000001.SZ 20180710 9.02 9.02 8.89 8.98 9.03 -0.05 -0.55 896862.02 803038.965 +7 000001.SZ 20180709 8.69 9.03 8.68 9.03 8.66 0.37 4.27 1409954.60 1255007.609 +8 000001.SZ 20180706 8.61 8.78 8.45 8.66 8.60 0.06 0.70 988282.69 852071.526 +9 000001.SZ 20180705 8.62 8.73 8.55 8.60 8.61 -0.01 -0.12 835768.77 722169.579) + +基金Daily:https://tushare.pro/document/2?doc_id=127 + +期货Daily:https://tushare.pro/document/2?doc_id=138 + +期权Daily:https://tushare.pro/document/2?doc_id=159 + +指数Daily:https://tushare.pro/document/2?doc_id=95 + +接口用例 + + +#取000001的前复权行情 +df = ts.pro_bar(ts_code='000001.SZ', adj='qfq', start_date='20180101', end_date='20181011') + + ts_code trade_date open high low close \ +trade_date +20181011 000001.SZ 20181011 1085.71 1097.59 1047.90 1065.19 +20181010 000001.SZ 20181010 1138.65 1151.61 1121.36 1128.92 +20181009 000001.SZ 20181009 1130.00 1155.93 1122.44 1140.81 +20181008 000001.SZ 20181008 1155.93 1165.65 1128.92 1128.92 +20180928 000001.SZ 20180928 1164.57 1217.51 1164.57 1193.74 + + + +#取上证指数行情数据 + +df = ts.pro_bar(ts_code='000001.SH', asset='I', start_date='20180101', end_date='20181011') + +In [10]: df.head() +Out[10]: + ts_code trade_date close open high low \ +0 000001.SH 20181011 2583.4575 2643.0740 2661.2859 2560.3164 +1 000001.SH 20181010 2725.8367 2723.7242 2743.5480 2703.0626 +2 000001.SH 20181009 2721.0130 2713.7319 2734.3142 2711.1971 +3 000001.SH 20181008 2716.5104 2768.2075 2771.9384 2710.1781 +4 000001.SH 20180928 2821.3501 2794.2644 2821.7553 2791.8363 + + pre_close change pct_chg vol amount +0 2725.8367 -142.3792 -5.2233 197150702.0 170057762.5 +1 2721.0130 4.8237 0.1773 113485736.0 111312455.3 +2 2716.5104 4.5026 0.1657 116771899.0 110292457.8 +3 2821.3501 -104.8397 -3.7159 149501388.0 141531551.8 +4 2791.7748 29.5753 1.0594 134290456.0 125369989.4 + + + +#均线 + +df = ts.pro_bar(ts_code='000001.SZ', start_date='20180101', end_date='20181011', ma=[5, 20, 50]) +注:Tushare pro_bar接口的均价和均量数据是动态计算,想要获取某个时间段的均线,必须要设置start_date日期大于最大均线的日期数,然后自行截取想要日期段。例如,想要获取20190801开始的3日均线,必须设置start_date='20190729',然后剔除20190801之前的日期记录。 + + + + +#换手率tor,量比vr + +df = ts.pro_bar(ts_code='000001.SZ', start_date='20180101', end_date='20181011', factors=['tor', 'vr']) + + +说明 + +对于pro_api参数,如果在一开始就通过 ts.set_token('xxxx') 设置过token的情况,这个参数就不是必需的。 + +例如: + + +df = ts.pro_bar(ts_code='000001.SH', asset='I', start_date='20180101', end_date='20181011') \ No newline at end of file diff --git a/src/data/api_wrappers/api_bak_basic.py b/src/data/api_wrappers/api_bak_basic.py index 9185bb3..a06b166 100644 --- a/src/data/api_wrappers/api_bak_basic.py +++ b/src/data/api_wrappers/api_bak_basic.py @@ -129,7 +129,9 @@ def sync_bak_basic( columns = [] for col in sample.columns: dtype = str(sample[col].dtype) - if "int" in dtype: + if col == "trade_date": + col_type = "DATE" + elif "int" in dtype: col_type = "INTEGER" elif "float" in dtype: col_type = "DOUBLE" @@ -223,10 +225,16 @@ def sync_bak_basic( # Combine and save combined = pd.concat(all_data, ignore_index=True) + + # Convert trade_date to datetime for proper DATE type storage + combined["trade_date"] = pd.to_datetime(combined["trade_date"], format="%Y%m%d") + print(f"[sync_bak_basic] Total records: {len(combined)}") # Delete existing data for the date range and append new data - storage._connection.execute(f'DELETE FROM "{TABLE_NAME}" WHERE "trade_date" >= ?', [sync_start]) + # Convert sync_start to date format for comparison with DATE column + sync_start_date = pd.to_datetime(sync_start, format="%Y%m%d").date() + storage._connection.execute(f'DELETE FROM "{TABLE_NAME}" WHERE "trade_date" >= ?', [sync_start_date]) thread_storage.queue_save(TABLE_NAME, combined) thread_storage.flush() diff --git a/src/data/api_wrappers/api_daily.py b/src/data/api_wrappers/api_daily.py index 456ab9d..96b4929 100644 --- a/src/data/api_wrappers/api_daily.py +++ b/src/data/api_wrappers/api_daily.py @@ -17,6 +17,7 @@ import threading from src.data.client import TushareClient from src.data.storage import ThreadSafeStorage, Storage from src.data.utils import get_today_date, get_next_date, DEFAULT_START_DATE +from src.config.settings import get_settings from src.data.api_wrappers.api_trade_cal import ( get_first_trading_day, get_last_trading_day, @@ -105,16 +106,15 @@ class DailySync: - 预览模式(预览同步数据量,不实际写入) """ - # 默认工作线程数 - DEFAULT_MAX_WORKERS = 10 + # 默认工作线程数(从配置读取,默认10) + DEFAULT_MAX_WORKERS = get_settings().threads def __init__(self, max_workers: Optional[int] = None): """初始化同步管理器。 Args: - max_workers: 工作线程数(默认: 10) + max_workers: 工作线程数(默认从配置读取) """ - self.storage = ThreadSafeStorage() self.client = TushareClient() self.max_workers = max_workers or self.DEFAULT_MAX_WORKERS self._stop_flag = threading.Event() diff --git a/src/data/api_wrappers/api_namechange.py b/src/data/api_wrappers/api_namechange.py index dbddd83..57aabde 100644 --- a/src/data/api_wrappers/api_namechange.py +++ b/src/data/api_wrappers/api_namechange.py @@ -8,13 +8,13 @@ import pandas as pd from pathlib import Path from typing import Optional, List from src.data.client import TushareClient -from src.data.config import get_config +from src.config.settings import get_settings # CSV file path for namechange data def _get_csv_path() -> Path: """Get the CSV file path for namechange data.""" - cfg = get_config() + cfg = get_settings() return cfg.data_path_resolved / "namechange.csv" diff --git a/src/data/api_wrappers/api_stock_basic.py b/src/data/api_wrappers/api_stock_basic.py index 0cdd48b..a695e2d 100644 --- a/src/data/api_wrappers/api_stock_basic.py +++ b/src/data/api_wrappers/api_stock_basic.py @@ -9,13 +9,13 @@ import pandas as pd from pathlib import Path from typing import Optional, Literal, List from src.data.client import TushareClient -from src.data.config import get_config +from src.config.settings import get_settings # CSV file path for stock basic data def _get_csv_path() -> Path: """Get the CSV file path for stock basic data.""" - cfg = get_config() + cfg = get_settings() return cfg.data_path_resolved / "stock_basic.csv" diff --git a/src/data/api_wrappers/api_trade_cal.py b/src/data/api_wrappers/api_trade_cal.py index fe0cd4c..6250d89 100644 --- a/src/data/api_wrappers/api_trade_cal.py +++ b/src/data/api_wrappers/api_trade_cal.py @@ -8,7 +8,7 @@ import pandas as pd from typing import Optional, Literal from pathlib import Path from src.data.client import TushareClient -from src.data.config import get_config +from src.config.settings import get_settings # Module-level flag to track if cache has been synced in this session _cache_synced = False @@ -18,7 +18,7 @@ _cache_synced = False # Trading calendar cache file path def _get_cache_path() -> Path: """Get the cache file path for trade calendar.""" - cfg = get_config() + cfg = get_settings() return cfg.data_path_resolved / "trade_cal.h5" @@ -296,8 +296,8 @@ def get_first_trading_day( trading_days = get_trading_days(start_date, end_date, exchange) if not trading_days: return None - # Trading days are sorted in descending order (newest first) from cache - return trading_days[-1] + # Return the earliest trading day + return min(trading_days) def get_last_trading_day( @@ -318,8 +318,8 @@ def get_last_trading_day( trading_days = get_trading_days(start_date, end_date, exchange) if not trading_days: return None - # Trading days are sorted in descending order (newest first) from cache - return trading_days[0] + # Return the latest trading day + return max(trading_days) if __name__ == "__main__": diff --git a/src/data/client.py b/src/data/client.py index c66edf3..920e803 100644 --- a/src/data/client.py +++ b/src/data/client.py @@ -1,21 +1,25 @@ """Simplified Tushare client with rate limiting and retry logic.""" + import time import pandas as pd from typing import Optional -from src.data.config import get_config from src.data.rate_limiter import TokenBucketRateLimiter +from src.config.settings import get_settings class TushareClient: """Tushare API client with rate limiting and retry.""" + # 类级别共享限流器(确保所有实例共享同一个限流器) + _shared_limiter: Optional[TokenBucketRateLimiter] = None + def __init__(self, token: Optional[str] = None): """Initialize client. Args: token: Tushare API token (auto-loaded from config if not provided) """ - cfg = get_config() + cfg = get_settings() token = token or cfg.tushare_token if not token: @@ -24,12 +28,21 @@ class TushareClient: self.token = token self.config = cfg - # Initialize rate limiter: capacity = rate_limit, refill_rate = rate_limit/60 per second + # 初始化共享限流器(确保所有 TushareClient 实例共享同一个限流器) rate_per_second = cfg.rate_limit / 60.0 - self.rate_limiter = TokenBucketRateLimiter( - capacity=cfg.rate_limit, - refill_rate_per_second=rate_per_second, - ) + capacity = cfg.rate_limit + + if TushareClient._shared_limiter is None: + # 首次创建:初始化共享限流器 + TushareClient._shared_limiter = TokenBucketRateLimiter( + capacity=capacity, + refill_rate_per_second=rate_per_second, + ) + print( + f"[TushareClient] Initialized shared rate limiter: capacity={capacity}, window=60s" + ) + # 复用共享限流器 + self.rate_limiter = TushareClient._shared_limiter self._api = None @@ -37,6 +50,7 @@ class TushareClient: """Get Tushare API instance.""" if self._api is None: import tushare as ts + self._api = ts.pro_api(self.token) return self._api @@ -52,7 +66,7 @@ class TushareClient: DataFrame with query results """ # Acquire rate limit token (None = wait indefinitely) - timeout = timeout if timeout is not None else float('inf') + timeout = timeout if timeout is not None else float("inf") success, wait_time = self.rate_limiter.acquire(timeout=timeout) if not success: @@ -72,14 +86,21 @@ class TushareClient: # pro_bar uses ts.pro_bar() instead of api.query() if api_name == "pro_bar": # pro_bar parameters: ts_code, start_date, end_date, adj, freq, factors, ma, adjfactor - data = ts.pro_bar(ts_code=params.get("ts_code"), - start_date=params.get("start_date"), - end_date=params.get("end_date"), - adj=params.get("adj"), - freq=params.get("freq", "D"), - factors=params.get("factors"), # factors should be a list like ['tor', 'vr'] - ma=params.get("ma"), - adjfactor=params.get("adjfactor")) + data = ts.pro_bar( + ts_code=params.get("ts_code"), + start_date=params.get("start_date"), + end_date=params.get("end_date"), + adj=params.get("adj"), + freq=params.get("freq", "D"), + factors=params.get( + "factors" + ), # factors should be a list like ['tor', 'vr'] + ma=params.get("ma"), + adjfactor=params.get("adjfactor"), + ) + # Handle None response (e.g., delisted stock) + if data is None: + data = pd.DataFrame() else: api = self._get_api() data = api.query(api_name, **params) @@ -89,10 +110,14 @@ class TushareClient: except Exception as e: if attempt < max_retries - 1: delay = retry_delays[attempt] - print(f"[Retry] {api_name} failed (attempt {attempt + 1}): {e}, retry in {delay}s") + print( + f"[Retry] {api_name} failed (attempt {attempt + 1}): {e}, retry in {delay}s" + ) time.sleep(delay) else: - raise RuntimeError(f"API call failed after {max_retries} attempts: {e}") + raise RuntimeError( + f"API call failed after {max_retries} attempts: {e}" + ) return pd.DataFrame() diff --git a/src/data/config.py b/src/data/config.py deleted file mode 100644 index 9303ed1..0000000 --- a/src/data/config.py +++ /dev/null @@ -1,80 +0,0 @@ -"""Configuration management for data collection module.""" -import os -from pathlib import Path -from pydantic_settings import BaseSettings - - -# Config directory path - used for loading .env.local -# Static detection for pydantic-settings to find .env.local -CONFIG_DIR = Path(__file__).parent.parent.parent / "config" - - -def _get_project_root() -> Path: - """Get project root path from ROOT_PATH env var or auto-detect.""" - # Try to read from environment variable first - root_path = os.environ.get("ROOT_PATH") or os.environ.get("DATA_ROOT") - if root_path: - return Path(root_path) - # Fallback to auto-detection - return Path(__file__).parent.parent.parent - - -class Config(BaseSettings): - """Application configuration loaded from environment variables.""" - - # Tushare API token - tushare_token: str = "" - - # Root path - loaded from environment variable ROOT_PATH - # If not set, uses auto-detected path - root_path: str = "" - - # Data storage path - can be set via DATA_PATH environment variable - # If relative path, it will be resolved relative to root_path - data_path: str = "data" - - # Rate limit: requests per minute - rate_limit: int = 100 - - # Thread pool size - threads: int = 2 - - @property - def project_root(self) -> Path: - """Get project root path.""" - if self.root_path: - return Path(self.root_path) - return _get_project_root() - - @property - def data_path_resolved(self) -> Path: - """Get resolved data path (absolute).""" - path = Path(self.data_path) - if path.is_absolute(): - return path - return self.project_root / path - - class Config: - # 从 config/ 目录读取 .env.local 文件 - env_file = str(CONFIG_DIR / ".env.local") - env_file_encoding = "utf-8" - case_sensitive = False - extra = "ignore" # 忽略 .env.local 中的额外变量 - # pydantic-settings 默认会将字段名转换为大写作为环境变量名 - # 所以 tushare_token 会映射到 TUSHARE_TOKEN - # root_path 会映射到 ROOT_PATH - # data_path 会映射到 DATA_PATH - - -# Global config instance -config = Config() - - -def get_config() -> Config: - """Get configuration instance.""" - return config - - -def get_project_root() -> Path: - """Get project root path (convenience function).""" - return get_config().project_root diff --git a/src/data/db_inspector.py b/src/data/db_inspector.py index a3353ab..33fe465 100644 --- a/src/data/db_inspector.py +++ b/src/data/db_inspector.py @@ -32,9 +32,12 @@ def get_db_info(db_path: Optional[Path] = None): # Get database path if db_path is None: - from src.data.config import get_config + from src.config.settings import get_settings - cfg = get_config() + cfg = get_settings() + db_path = cfg.data_path_resolved / "prostock.db" + + cfg = get_settings() db_path = cfg.data_path_resolved / "prostock.db" else: db_path = Path(db_path) @@ -231,9 +234,12 @@ def get_table_sample(table_name: str, limit: int = 5, db_path: Optional[Path] = db_path: Path to database file """ if db_path is None: - from src.data.config import get_config + from src.config.settings import get_settings - cfg = get_config() + cfg = get_settings() + db_path = cfg.data_path_resolved / "prostock.db" + + cfg = get_settings() db_path = cfg.data_path_resolved / "prostock.db" else: db_path = Path(db_path) diff --git a/src/data/rate_limiter.py b/src/data/rate_limiter.py index 1ade893..2b80d0d 100644 --- a/src/data/rate_limiter.py +++ b/src/data/rate_limiter.py @@ -1,35 +1,35 @@ -"""Token bucket rate limiter implementation. +"""API 速率限制器实现。 -This module provides a thread-safe token bucket algorithm for rate limiting. +提供基于固定时间窗口的速率限制,适合 Tushare 等按分钟计费的 API。 """ import time import threading from typing import Optional -from dataclasses import dataclass, field +from dataclasses import dataclass @dataclass class RateLimiterStats: - """Statistics for rate limiter.""" + """速率限制器统计信息。""" total_requests: int = 0 successful_requests: int = 0 denied_requests: int = 0 total_wait_time: float = 0.0 - current_tokens: Optional[float] = None + current_window_requests: int = 0 + window_start_time: float = 0.0 class TokenBucketRateLimiter: - """Thread-safe token bucket rate limiter. + """基于固定时间窗口的速率限制器。 - Implements a token bucket algorithm for controlling request rate. - Tokens are added at a fixed rate up to the bucket capacity. + 适合 Tushare 等按时间窗口(如每分钟)限制请求数的 API 场景。 + 在窗口期内,请求数达到上限后将阻塞或等待下一个窗口。 Attributes: - capacity: Maximum number of tokens in the bucket - refill_rate: Number of tokens added per second - initial_tokens: Initial number of tokens (default: capacity) + capacity: 每个时间窗口内允许的最大请求数 + window_seconds: 时间窗口长度(秒) """ def __init__( @@ -38,155 +38,157 @@ class TokenBucketRateLimiter: refill_rate_per_second: float = 1.67, initial_tokens: Optional[int] = None, ) -> None: - """Initialize the token bucket rate limiter. + """初始化速率限制器。 Args: - capacity: Maximum token capacity - refill_rate_per_second: Token refill rate per second - initial_tokens: Initial token count (default: capacity) + capacity: 每个时间窗口内允许的最大请求数 + refill_rate_per_second: 保留参数(向后兼容),实际使用 window_seconds=60 + initial_tokens: 保留参数(向后兼容) """ self.capacity = capacity - self.refill_rate = refill_rate_per_second - self.tokens = float(initial_tokens if initial_tokens is not None else capacity) - self.last_refill_time = time.monotonic() + # Tushare 通常按分钟限制,所以固定使用 60 秒窗口 + self.window_seconds = 60.0 + + self._requests_in_window = 0 + self._window_start = time.monotonic() self._lock = threading.RLock() self._stats = RateLimiterStats() - self._stats.current_tokens = self.tokens + self._stats.window_start_time = self._window_start + + def _is_new_window(self) -> bool: + """检查是否已进入新的时间窗口。""" + current_time = time.monotonic() + elapsed = current_time - self._window_start + return elapsed >= self.window_seconds + + def _reset_window(self) -> None: + """重置时间窗口。""" + self._window_start = time.monotonic() + self._requests_in_window = 0 + self._stats.window_start_time = self._window_start def acquire(self, timeout: float = float("inf")) -> tuple[bool, float]: - """Acquire a token from the bucket. + """获取请求许可。 - Blocks until a token is available or timeout expires. + 如果在当前窗口内请求数已达上限,则等待到下一个窗口。 Args: - timeout: Maximum time to wait for a token in seconds (default: inf) + timeout: 最大等待时间(秒),默认无限等待 Returns: - Tuple of (success, wait_time): - - success: True if token was acquired, False if timed out - - wait_time: Time spent waiting for token + (success, wait_time): 是否成功获取许可,以及等待时间 """ start_time = time.monotonic() - wait_time = 0.0 with self._lock: - self._refill() + # 检查是否需要进入新窗口 + if self._is_new_window(): + self._reset_window() - if self.tokens >= 1: - self.tokens -= 1 + # 如果当前窗口还有余量,直接通过 + if self._requests_in_window < self.capacity: + self._requests_in_window += 1 self._stats.total_requests += 1 self._stats.successful_requests += 1 - self._stats.current_tokens = self.tokens + self._stats.current_window_requests = self._requests_in_window return True, 0.0 - # Calculate time to wait for next token - tokens_needed = 1 - self.tokens - time_to_refill = tokens_needed / self.refill_rate + # 当前窗口已满,计算需要等待的时间 + current_time = time.monotonic() + time_to_next_window = self.window_seconds - ( + current_time - self._window_start + ) - # Check if we can wait for the token within timeout - # Handle infinite timeout specially - is_infinite_timeout = timeout == float("inf") - if not is_infinite_timeout and time_to_refill > timeout: + if time_to_next_window <= 0: + # 刚好进入新窗口 + self._reset_window() + self._requests_in_window = 1 + self._stats.total_requests += 1 + self._stats.successful_requests += 1 + self._stats.current_window_requests = 1 + return True, 0.0 + + # 检查是否能在超时时间内等待 + if timeout != float("inf") and time_to_next_window > timeout: self._stats.total_requests += 1 self._stats.denied_requests += 1 return False, timeout - # Wait for tokens - loop until we get one or timeout - while True: - # Calculate remaining time we can wait - elapsed = time.monotonic() - start_time - remaining_timeout = ( - timeout - elapsed if not is_infinite_timeout else float("inf") - ) + # 需要等待到下一个窗口 + if timeout != float("inf"): + time_to_wait = min(time_to_next_window, timeout) + else: + time_to_wait = time_to_next_window - # Check if we've exceeded timeout - if not is_infinite_timeout and remaining_timeout <= 0: - self._stats.total_requests += 1 - self._stats.denied_requests += 1 - return False, elapsed + time.sleep(time_to_wait) - # Calculate wait time for next token - tokens_needed = max(0, 1 - self.tokens) - time_to_wait = ( - tokens_needed / self.refill_rate if tokens_needed > 0 else 0.1 - ) - - # If we can't wait long enough, fail - if not is_infinite_timeout and time_to_wait > remaining_timeout: - self._stats.total_requests += 1 - self._stats.denied_requests += 1 - return False, elapsed - - # Wait outside the lock to allow other threads to refill - self._lock.release() - time.sleep( - min(time_to_wait, 0.1) - ) # Cap wait to 100ms to check frequently - self._lock.acquire() - - # Refill and check again - self._refill() - if self.tokens >= 1: - self.tokens -= 1 - wait_time = time.monotonic() - start_time - self._stats.total_requests += 1 - self._stats.successful_requests += 1 - self._stats.total_wait_time += wait_time - self._stats.current_tokens = self.tokens - return True, wait_time - - def acquire_nonblocking(self) -> tuple[bool, float]: - """Try to acquire a token without blocking. - - Returns: - Tuple of (success, wait_time): - - success: True if token was acquired, False otherwise - - wait_time: 0 for non-blocking, or required wait time if failed - """ + # 重新尝试获取许可 with self._lock: - self._refill() + # 再次检查窗口状态(可能其他线程已经重置了窗口) + if self._is_new_window(): + self._reset_window() - if self.tokens >= 1: - self.tokens -= 1 + if self._requests_in_window < self.capacity: + self._requests_in_window += 1 + wait_time = time.monotonic() - start_time self._stats.total_requests += 1 self._stats.successful_requests += 1 - self._stats.current_tokens = self.tokens + self._stats.total_wait_time += wait_time + self._stats.current_window_requests = self._requests_in_window + return True, wait_time + else: + # 在极端情况下,等待后仍然无法获取(其他线程抢先) + wait_time = time.monotonic() - start_time + self._stats.total_requests += 1 + self._stats.denied_requests += 1 + return False, wait_time + + def acquire_nonblocking(self) -> tuple[bool, float]: + """尝试非阻塞地获取请求许可。 + + Returns: + (success, wait_time): 是否成功获取许可,以及需要等待的时间 + """ + with self._lock: + # 检查是否需要进入新窗口 + if self._is_new_window(): + self._reset_window() + + # 如果当前窗口还有余量,直接通过 + if self._requests_in_window < self.capacity: + self._requests_in_window += 1 + self._stats.total_requests += 1 + self._stats.successful_requests += 1 + self._stats.current_window_requests = self._requests_in_window return True, 0.0 - # Calculate time needed - tokens_needed = 1 - self.tokens - time_to_refill = tokens_needed / self.refill_rate + # 当前窗口已满,计算需要等待的时间 + current_time = time.monotonic() + time_to_next_window = self.window_seconds - ( + current_time - self._window_start + ) self._stats.total_requests += 1 self._stats.denied_requests += 1 - return False, time_to_refill - - def _refill(self) -> None: - """Refill tokens based on elapsed time.""" - current_time = time.monotonic() - elapsed = current_time - self.last_refill_time - self.last_refill_time = current_time - - tokens_to_add = elapsed * self.refill_rate - self.tokens = min(self.capacity, self.tokens + tokens_to_add) + return False, max(0.0, time_to_next_window) def get_available_tokens(self) -> float: - """Get the current number of available tokens. + """获取当前窗口剩余可用请求数。 Returns: - Current token count + 当前窗口剩余可用请求数 """ with self._lock: - self._refill() - return self.tokens + if self._is_new_window(): + return float(self.capacity) + return float(self.capacity - self._requests_in_window) def get_stats(self) -> RateLimiterStats: - """Get rate limiter statistics. + """获取速率限制器统计信息。 Returns: - RateLimiterStats instance + RateLimiterStats 实例 """ with self._lock: - self._refill() - self._stats.current_tokens = self.tokens + self._stats.current_window_requests = self._requests_in_window return self._stats diff --git a/src/data/storage.py b/src/data/storage.py index 9e285da..1b83d47 100644 --- a/src/data/storage.py +++ b/src/data/storage.py @@ -6,7 +6,7 @@ from pathlib import Path from typing import Optional, List, Dict, Any, Tuple from collections import defaultdict from datetime import datetime -from src.data.config import get_config +from src.config.settings import get_settings # Default column type mapping for automatic schema inference @@ -53,7 +53,7 @@ class Storage: if hasattr(self, "_initialized"): return - cfg = get_config() + cfg = get_settings() self.base_path = path or cfg.data_path_resolved self.base_path.mkdir(parents=True, exist_ok=True) self.db_path = self.base_path / "prostock.db" @@ -190,6 +190,26 @@ class Storage: update_flag VARCHAR(1), PRIMARY KEY (ts_code, end_date) ) + + # Create pro_bar table for pro bar data (with adj, tor, vr) + self._connection.execute(""" + CREATE TABLE IF NOT EXISTS pro_bar ( + ts_code VARCHAR(16) NOT NULL, + trade_date DATE NOT NULL, + open DOUBLE, + high DOUBLE, + low DOUBLE, + close DOUBLE, + pre_close DOUBLE, + change DOUBLE, + pct_chg DOUBLE, + vol DOUBLE, + amount DOUBLE, + tor DOUBLE, + vr DOUBLE, + adj_factor DOUBLE, + PRIMARY KEY (ts_code, trade_date) + ) """) # Create index for financial_income diff --git a/src/data/sync.py b/src/data/sync.py index 27ecffb..0315d46 100644 --- a/src/data/sync.py +++ b/src/data/sync.py @@ -29,6 +29,7 @@ import pandas as pd from src.data.api_wrappers import sync_all_stocks from src.data.api_wrappers.api_daily import sync_daily, preview_daily_sync +from src.data.api_wrappers.api_pro_bar import sync_pro_bar def preview_sync( @@ -134,7 +135,6 @@ def sync_all_data( dry_run: bool = False, ) -> Dict[str, pd.DataFrame]: """同步所有数据类型(每日同步)。 - 该函数按顺序同步所有可用的数据类型: 1. 交易日历 (sync_trade_cal_cache) 2. 股票基本信息 (sync_all_stocks) @@ -146,13 +146,12 @@ def sync_all_data( Args: force_full: 若为 True,强制所有数据类型完整重载 max_workers: 日线数据同步的工作线程数(默认: 10) - dry_run: 若为 True,仅显示将要同步的内容 Returns: - 映射数据类型,不写入数据 + dry_run: 若为 True,仅显示将要同步的内容,不写入数据 - 到同步结果的字典 + Returns: + 映射数据类型到同步结果的字典 Example: - >>> # 同步所有数据(增量) >>> result = sync_all_data() >>> >>> # 强制完整重载 @@ -167,6 +166,92 @@ def sync_all_data( print("[sync_all_data] Starting full data synchronization...") print("=" * 60) + # 1. Sync trade calendar (always needed first) + print("\n[1/6] Syncing trade calendar cache...") + try: + from src.data.api_wrappers import sync_trade_cal_cache + + sync_trade_cal_cache() + results["trade_cal"] = pd.DataFrame() + print("[1/6] Trade calendar: OK") + except Exception as e: + print(f"[1/6] Trade calendar: FAILED - {e}") + results["trade_cal"] = pd.DataFrame() + + # 2. Sync stock basic info + print("\n[2/6] Syncing stock basic info...") + try: + sync_all_stocks() + results["stock_basic"] = pd.DataFrame() + print("[2/6] Stock basic: OK") + except Exception as e: + print(f"[2/6] Stock basic: FAILED - {e}") + results["stock_basic"] = pd.DataFrame() + + # # 3. Sync daily market data + # print("\n[3/6] Syncing daily market data...") + # try: + # daily_result = sync_daily( + # force_full=force_full, + # max_workers=max_workers, + # dry_run=dry_run, + # ) + # results["daily"] = ( + # pd.concat(daily_result.values(), ignore_index=True) + # if daily_result + # else pd.DataFrame() + # ) + # print("[3/6] Daily data: OK") + # except Exception as e: + # print(f"[3/6] Daily data: FAILED - {e}") + # results["daily"] = pd.DataFrame() + + # 4. Sync Pro Bar data + print("\n[4/6] Syncing Pro Bar data (with adj, tor, vr)...") + try: + pro_bar_result = sync_pro_bar( + force_full=force_full, + max_workers=max_workers, + dry_run=dry_run, + ) + results["pro_bar"] = ( + pd.concat(pro_bar_result.values(), ignore_index=True) + if pro_bar_result + else pd.DataFrame() + ) + print(f"[4/6] Pro Bar data: OK ({len(results['pro_bar'])} records)") + except Exception as e: + print(f"[4/6] Pro Bar data: FAILED - {e}") + results["pro_bar"] = pd.DataFrame() + + # 5. Sync stock historical list (bak_basic) + print("\n[5/6] Syncing stock historical list (bak_basic)...") + try: + bak_basic_result = sync_bak_basic(force_full=force_full) + results["bak_basic"] = bak_basic_result + print(f"[5/6] Bak basic: OK ({len(bak_basic_result)} records)") + except Exception as e: + print(f"[5/6] Bak basic: FAILED - {e}") + results["bak_basic"] = pd.DataFrame() + + # Summary + print("\n" + "=" * 60) + print("[sync_all_data] Sync Summary") + print("=" * 60) + for data_type, df in results.items(): + print(f" {data_type}: {len(df)} records") + print("=" * 60) + print("\nNote: namechange is NOT in auto-sync. To sync manually:") + print(" from src.data.api_wrappers import sync_namechange") + print(" sync_namechange(force=True)") + + return results + results: Dict[str, pd.DataFrame] = {} + + print("\n" + "=" * 60) + print("[sync_all_data] Starting full data synchronization...") + print("=" * 60) + # 1. Sync trade calendar (always needed first) print("\n[1/5] Syncing trade calendar cache...") try: diff --git a/src/factors/FACTOR_GUIDE.md b/src/factors/FACTOR_GUIDE.md deleted file mode 100644 index 851c98f..0000000 --- a/src/factors/FACTOR_GUIDE.md +++ /dev/null @@ -1,1535 +0,0 @@ -# ProStock 因子开发规范 - -本文档是 ProStock 因子框架的完整开发指南,涵盖从因子设计到测试的全流程规范。 - -## 目录 - -- [1. 因子框架概述](#1-因子框架概述) -- [2. 因子分类体系](#2-因子分类体系) -- [3. 因子类型选择](#3-因子类型选择) -- [4. 项目结构](#4-项目结构) -- [5. 编写步骤](#5-编写步骤) -- [6. 编码规范](#6-编码规范) -- [7. 命名规范](#7-命名规范) -- [8. 数据需求规范](#8-数据需求规范) -- [9. 防泄露机制](#9-防泄露机制) -- [10. 参数化因子](#10-参数化因子) -- [11. 因子组合](#11-因子组合) -- [12. 性能优化](#12-性能优化) -- [13. 测试规范](#13-测试规范) -- [14. 完整示例](#14-完整示例) -- [15. 常见错误](#15-常见错误) -- [16. 验证清单](#16-验证清单) -- [附录](#附录) - ---- - -## 1. 因子框架概述 - -ProStock 因子框架采用**类型安全**设计,严格区分截面因子和时序因子,在框架层面防止数据泄露。 - -### 1.1 核心组件 - -``` -src/factors/ -├── base.py # 因子基类(CrossSectionalFactor / TimeSeriesFactor) -├── data_spec.py # 数据规格定义(DataSpec, FactorData, FactorContext) -├── composite.py # 组合因子(支持因子间的加减乘除) -├── engine.py # 执行引擎(FactorEngine) -├── data_loader.py # 数据加载器(DataLoader) -├── momentum/ # 动量因子 -├── financial/ # 财务因子 -├── valuation/ # 估值因子 -├── technical/ # 技术指标因子 -├── quality/ # 质量因子 -├── sentiment/ # 情绪因子 -├── volume/ # 成交量因子 -└── volatility/ # 波动率因子 -``` - -### 1.2 关键概念 - -| 概念 | 说明 | -|------|------| -| `DataSpec` | 声明因子所需的数据源、列和回看窗口 | -| `FactorData` | 数据容器,封装 Polars DataFrame | -| `FactorContext` | 计算上下文,提供当前日期/股票信息 | -| `FactorEngine` | 执行引擎,根据因子类型采用不同计算策略 | - -### 1.3 设计原则 - -1. **类型安全**:编译时检查因子类型,防止运行时错误 -2. **防泄露**:框架层面防止未来数据和跨股票数据泄露 -3. **可组合**:支持因子间的数学运算(加减乘除) -4. **高性能**:基于 Polars 的向量化计算 -5. **可扩展**:插件化架构,易于添加新因子 - ---- - -## 2. 因子分类体系 - -### 2.1 按经济含义分类 - -| 分类 | 目录 | 说明 | 示例因子 | -|------|------|------|----------| -| **动量因子** | `momentum/` | 价格趋势和动量指标 | MA、收益率排名、动量 | -| **财务因子** | `financial/` | 财务报表相关指标 | EPS、营收增长、资产负债率 | -| **估值因子** | `valuation/` | 估值水平指标 | PE、PB、PS排名 | -| **技术指标** | `technical/` | 技术分析指标 | RSI、MACD、布林带、KDJ | -| **质量因子** | `quality/` | 公司质量指标 | ROE、ROA、盈利稳定性 | -| **情绪因子** | `sentiment/` | 市场情绪指标 | 换手率、资金流向、振幅 | -| **成交量因子** | `volume/` | 成交量相关指标 | OBV、成交量比率、量价配合 | -| **波动率因子** | `volatility/` | 波动率指标 | 历史波动率、GARCH、实现波动率 | - -### 2.2 按计算方式分类 - -| 类型 | 基类 | 计算维度 | 防泄露重点 | -|------|------|----------|------------| -| **截面因子** | `CrossSectionalFactor` | 横向:每天对所有股票计算 | 防止日期泄露 | -| **时序因子** | `TimeSeriesFactor` | 纵向:对每只股票单独计算 | 防止股票泄露 | - -### 2.3 分类选择指南 - -**如何选择因子分类?** - -1. **先确定计算方式**:你的因子是每天对所有股票计算(截面),还是对每只股票单独计算(时序)? -2. **再确定经济含义**:根据因子的经济逻辑选择对应的分类目录 -3. **特殊情况**:如果因子涉及多个维度(如行业内的时序计算),选择主要的经济含义分类 - ---- - -## 3. 因子类型选择 - -### 3.1 决策流程图 - -``` -开始编写因子 - │ - ▼ -因子计算是否涉及多只股票比较? - │ - ├── 是 → 截面因子 (CrossSectionalFactor) - │ └── 每天传入当天所有股票数据 - │ └── 示例:PE排名、市值分位数 - │ - └── 否 → 时序因子 (TimeSeriesFactor) - └── 每只股票单独传入其时间序列 - └── 示例:MA、RSI、历史波动率 -``` - -### 3.2 详细对比 - -#### 截面因子 (CrossSectionalFactor) - -**计算逻辑**:在每个交易日,对所有股票进行横向计算。 - -**防泄露边界**: -- ❌ 禁止访问未来日期的数据(日期泄露) -- ✅ 允许访问当前日期的所有股票数据 - -**数据传入**: -- `compute()` 接收的是 `[T-lookback+1, T]` 的数据 -- 包含 lookback_days 的历史数据(用于时序计算后再截面) - -**典型应用**: -- PE排名、市值分位数 -- 当日收益率排名 -- 行业分类排名 - -**使用场景判断**: -- 需要比较不同股票之间的值 -- 需要对股票进行排序或分组 -- 计算涉及多个股票的统计量(均值、标准差等) - -#### 时序因子 (TimeSeriesFactor) - -**计算逻辑**:对每只股票,在其时间序列上进行纵向计算。 - -**防泄露边界**: -- ❌ 禁止访问其他股票的数据(股票泄露) -- ✅ 允许访问该股票的完整历史数据 - -**数据传入**: -- `compute()` 接收的是单只股票的完整时间序列 -- 包含该股票在 `[start_date, end_date]` 范围内的所有数据 - -**典型应用**: -- 移动平均线 (MA) -- 相对强弱指标 (RSI) -- 历史波动率 - -**使用场景判断**: -- 只关注单只股票的历史表现 -- 计算涉及时间序列的滚动窗口 -- 不需要与其他股票比较 - -### 3.3 混合场景处理 - -**场景**:需要先计算时序指标,再进行截面排名(如20日收益率排名) - -**解决方案**: -1. 使用 **截面因子** 作为外层类型 -2. 在 `compute()` 中使用时序计算(如 `rolling_mean`) -3. 设置合适的 `lookback_days` 以确保有足够的历史数据 - -**示例**: -```python -class ReturnRankFactor(CrossSectionalFactor): - """过去n日收益率排名因子""" - data_specs = [DataSpec("daily", ["close"], lookback_days=period+1)] - - def compute(self, data: FactorData) -> pl.Series: - # 获取历史数据(包含时序信息) - df = data.to_polars() - # 计算每只股票的收益率(时序计算) - # 然后进行截面排名(截面计算) -``` - ---- - -## 4. 项目结构 - -### 4.1 完整目录结构 - -``` -src/factors/ -│ -├── __init__.py # 包入口,导出公共接口 -├── FACTOR_GUIDE.md # 本规范文档 -│ -├── base.py # 因子基类定义 -├── data_spec.py # 数据类型定义 -├── composite.py # 组合因子实现 -├── data_loader.py # 数据加载器 -├── engine.py # 执行引擎 -│ -├── momentum/ # 动量因子 -│ ├── __init__.py -│ ├── ma.py # 移动平均线 -│ └── return_rank.py # 收益率排名 -│ -├── financial/ # 财务因子 -│ ├── __init__.py -│ ├── eps_factor.py # EPS因子 -│ └── utils.py # 财务工具函数 -│ -├── valuation/ # 估值因子 -│ ├── __init__.py -│ └── [你的估值因子] -│ -├── technical/ # 技术指标因子 -│ ├── __init__.py -│ └── [你的技术指标因子] -│ -├── quality/ # 质量因子 -│ ├── __init__.py -│ └── [你的质量因子] -│ -├── sentiment/ # 情绪因子 -│ ├── __init__.py -│ └── [你的情绪因子] -│ -├── volume/ # 成交量因子 -│ ├── __init__.py -│ └── [你的成交量因子] -│ -└── volatility/ # 波动率因子 - ├── __init__.py - └── [你的波动率因子] -``` - -### 4.2 文件组织原则 - -1. **单一职责**:每个文件只包含一个因子类(或紧密相关的多个因子) -2. **分类清晰**:根据因子的经济含义放入对应目录 -3. **命名一致**:文件名使用 `snake_case`,反映因子功能 - -### 4.3 新增因子流程 - -``` -1. 确定因子类型(截面/时序) -2. 选择因子分类(momentum/financial/...) -3. 在对应目录创建文件 -4. 实现因子类 -5. 更新该目录的 __init__.py -6. 编写测试 -7. 更新主 __init__.py(如需要公开导出) -``` - ---- - -## 5. 编写步骤 - -### 步骤 1:确定因子类型和分类 - -**问题清单**: -- [ ] 因子是截面计算还是时序计算? -- [ ] 因子的经济含义属于哪个分类? -- [ ] 需要哪些数据字段? -- [ ] 是否需要参数化(如 MA 的周期)? - -### 步骤 2:创建文件 - -根据因子分类创建文件: - -```bash -# 示例:创建估值因子 -touch src/factors/valuation/pe_rank.py -``` - -### 步骤 3:定义类属性(必须) - -每个因子必须声明以下类属性: - -```python -class MyFactor(CrossSectionalFactor): # 或 TimeSeriesFactor - # 必须声明 - name: str = "my_factor" # 因子唯一标识(snake_case) - factor_type: str = "cross_sectional" # 或 "time_series" - data_specs: List[DataSpec] = [...] # 数据需求列表 - - # 可选声明 - category: str = "default" # 因子分类 - description: str = "" # 因子描述 -``` - -### 步骤 4:实现 compute 方法 - -```python -def compute(self, data: FactorData) -> pl.Series: - """核心计算逻辑 - - Args: - data: FactorData,已根据因子类型裁剪 - - Returns: - 计算得到的因子值 Series - """ - # 你的计算逻辑 - pass -``` - -### 步骤 5:更新 __init__.py - -在对应分类的 `__init__.py` 中导出因子: - -```python -# src/factors/momentum/__init__.py -from .ma import MovingAverageFactor -from .return_rank import ReturnRankFactor -from .your_factor import MyFactor # 添加你的因子 - -__all__ = [ - "MovingAverageFactor", - "ReturnRankFactor", - "MyFactor", # 添加你的因子 -] -``` - -### 步骤 6:编写测试 - -创建对应的测试文件: - -```bash -touch tests/factors/test_momentum.py -``` - -### 步骤 7:验证和文档 - -- [ ] 运行测试确保通过 -- [ ] 检查代码风格 -- [ ] 更新相关文档 - ---- - -## 6. 编码规范 - -### 6.1 文件头格式 - -```python -"""因子名称 - 一句话描述 - -本模块提供: -- FactorClassName: 详细描述 - -使用示例: - >>> from src.factors.category import FactorClassName - >>> factor = FactorClassName(param=value) -""" -``` - -### 6.2 导入顺序 - -```python -# 1. 标准库 -from typing import List, Optional - -# 2. 第三方库 -import polars as pl - -# 3. 本地模块(使用绝对导入) -from src.factors.base import CrossSectionalFactor, TimeSeriesFactor -from src.factors.data_spec import DataSpec, FactorData -``` - -### 6.3 类文档字符串(Google 风格) - -```python -class MovingAverageFactor(TimeSeriesFactor): - """移动平均线因子 - - 计算逻辑:对每只股票,计算其过去n日收盘价的移动平均值。 - - 特点: - - 参数化因子:训练时通过 period 参数指定计算窗口 - - 时序因子:每只股票单独计算,防止股票间数据泄露 - - Attributes: - period: MA计算期(天数),默认5 - - Example: - >>> ma5 = MovingAverageFactor(period=5) - >>> # 计算过去5日的收盘价均值 - """ -``` - -### 6.4 方法文档字符串 - -```python -def compute(self, data: FactorData) -> pl.Series: - """计算移动平均线 - - Args: - data: FactorData,包含单只股票的完整时间序列 - - Returns: - 移动平均值序列 - - Raises: - KeyError: 当所需列不存在于数据中 - """ -``` - -### 6.5 类型提示(强制) - -- 所有函数参数必须标注类型 -- 所有函数返回值必须标注类型 -- 使用 `Optional[X]` 表示可空类型 -- 使用 `List[X]` 表示列表类型 - -```python -def __init__(self, period: int = 5) -> None: - self.period: int = period - -def compute(self, data: FactorData) -> pl.Series: - pass -``` - -### 6.6 代码风格 - -- 使用 4 空格缩进 -- 行长度限制 88 字符(Black 默认) -- 使用 snake_case 命名变量和函数 -- 使用 PascalCase 命名类 - ---- - -## 7. 命名规范 - -### 7.1 因子名称 (name) - -- 使用 **snake_case** 格式 -- 简洁明了,反映因子含义 -- 参数化因子在 `__init__` 中动态设置名称 - -```python -# 好的命名 -name = "pe_rank" # PE排名 -name = "ma_20" # 20日移动平均 -name = "return_5d" # 5日收益率 -name = "volatility_20" # 20日波动率 - -# 不好的命名 -name = "PERank" # 错误:使用 PascalCase -name = "moving_average" # 不够具体,未体现参数 -name = "factor1" # 无意义 -``` - -### 7.2 因子分类 (category) - -使用统一的分类标签: - -| 分类 | 说明 | 示例 | -|------|------|------| -| `momentum` | 动量因子 | 收益率、MA、RSI | -| `financial` | 财务因子 | EPS、ROE、营收增长 | -| `valuation` | 估值因子 | PE、PB、PS | -| `technical` | 技术指标 | MACD、KDJ、布林带 | -| `quality` | 质量因子 | 盈利能力、稳定性 | -| `sentiment` | 情绪因子 | 换手率、资金流向 | -| `volume` | 成交量因子 | OBV、成交量比率 | -| `volatility` | 波动率因子 | 历史波动率、GARCH | - -### 7.3 文件名 - -- 使用 **snake_case** 格式 -- 反映因子功能 - -``` -ma.py # 移动平均线 -return_rank.py # 收益率排名 -eps_factor.py # EPS因子 -pe_ratio.py # 市盈率因子 -``` - -### 7.4 类名 - -- 使用 **PascalCase** 格式 -- 以 `Factor` 结尾 - -```python -class MovingAverageFactor(TimeSeriesFactor): -class ReturnRankFactor(CrossSectionalFactor): -class EPSFactor(CrossSectionalFactor): -``` - -### 7.5 命名最佳实践 - -#### 前缀/后缀约定 - -| 前缀/后缀 | 含义 | 示例 | -|-----------|------|------| -| `_rank` | 排名因子 | `pe_rank`, `pb_rank` | -| `_zscore` | Z-Score标准化 | `roe_zscore` | -| `_ma` | 移动平均 | `volume_ma`, `price_ma` | -| `_std` | 标准差 | `return_std` | -| `_skew` | 偏度 | `return_skew` | -| `_kurt` | 峰度 | `return_kurt` | - -#### 时间周期表示 - -```python -# 推荐 -name = "return_5d" # 5日收益率 -name = "return_1m" # 1月收益率 -name = "return_1y" # 1年收益率 -name = "ma_20" # 20日移动平均 - -# 不推荐 -name = "return5" # 缺少单位 -name = "ma20" # 缺少分隔符 -``` - ---- - -## 8. 数据需求规范 - -### 8.1 DataSpec 定义 - -```python -from src.factors.data_spec import DataSpec - -data_specs: List[DataSpec] = [ - DataSpec( - source="daily", # 数据源名称(DuckDB 表名) - columns=["ts_code", "trade_date", "close", "volume"], # 必需列 - lookback_days=20 # 回看窗口(包含当日) - ) -] -``` - -### 8.2 必需列 - -`columns` 必须包含以下列: -- `ts_code`: 股票代码 -- `trade_date`: 交易日期 - -### 8.3 lookback_days 设置 - -- **最小值**:1(只包含当日) -- **时序因子**:设置为计算所需的完整窗口 - - MA(20) → `lookback_days=20` - - 5日收益率 → `lookback_days=6`(需要T日和T-5日) -- **截面因子**:通常为 1,除非需要历史数据进行时序计算 - -### 8.4 参数化因子的 DataSpec - -对于参数化因子,在 `__init__` 中动态创建 DataSpec: - -```python -def __init__(self, period: int = 5): - super().__init__(period=period) - # 重新创建 DataSpec 以设置正确的 lookback_days - self.data_specs = [ - DataSpec( - "daily", - ["ts_code", "trade_date", "close"], - lookback_days=period, # 使用参数 - ) - ] - self.name = f"ma_{period}" # 动态设置名称 -``` - -### 8.5 数据表和字段参考 - -#### 日线数据表 (daily) - -| 字段名 | 类型 | 说明 | -|--------|------|------| -| `ts_code` | str | 股票代码 | -| `trade_date` | str | 交易日期 (YYYYMMDD) | -| `open` | f64 | 开盘价 | -| `high` | f64 | 最高价 | -| `low` | f64 | 最低价 | -| `close` | f64 | 收盘价 | -| `pre_close` | f64 | 昨收价 | -| `change` | f64 | 涨跌额 | -| `pct_chg` | f64 | 涨跌幅 (%) | -| `vol` | f64 | 成交量(手) | -| `amount` | f64 | 成交额(千元) | -| `pe` | f64 | 市盈率 | -| `pb` | f64 | 市净率 | -| `ps` | f64 | 市销率 | -| `total_mv` | f64 | 总市值(万元) | -| `circ_mv` | f64 | 流通市值(万元) | - -#### 财务数据表 (financial_income) - -| 字段名 | 类型 | 说明 | -|--------|------|------| -| `ts_code` | str | 股票代码 | -| `trade_date` | str | 报告期 (YYYYMMDD) | -| `basic_eps` | f64 | 基本每股收益 | -| `diluted_eps` | f64 | 稀释每股收益 | -| `total_revenue` | f64 | 营业总收入 | -| `revenue` | f64 | 营业收入 | -| `total_profit` | f64 | 营业利润 | -| `net_income` | f64 | 净利润 | - -#### 使用示例 - -```python -# 日线数据 -data_specs = [DataSpec("daily", ["ts_code", "trade_date", "close", "volume"], lookback_days=20)] - -# 财务数据 -data_specs = [DataSpec("financial_income", ["ts_code", "trade_date", "basic_eps"], lookback_days=1)] - -# 多数据源 -data_specs = [ - DataSpec("daily", ["ts_code", "trade_date", "close"], lookback_days=20), - DataSpec("financial_income", ["ts_code", "trade_date", "basic_eps"], lookback_days=1) -] -``` - ---- - -## 9. 防泄露机制 - -### 9.1 截面因子注意事项 - -**禁止**: -- 访问 `data.context.current_date` 之后的数据 -- 使用 `data.to_polars()` 获取未来日期的数据 - -**正确做法**: -```python -def compute(self, data: FactorData) -> pl.Series: - # 获取当前日期的截面数据(框架已裁剪到正确范围) - cs = data.get_cross_section() - - # 只使用当前日期的数据 - return cs["pe"].rank() -``` - -### 9.2 时序因子注意事项 - -**禁止**: -- 访问其他股票的数据 -- 使用包含多只股票的数据进行计算 - -**正确做法**: -```python -def compute(self, data: FactorData) -> pl.Series: - # 获取该股票的时间序列(框架已确保只有一只股票) - close = data.get_column("close") - - # 只使用该股票的历史数据 - return close.rolling_mean(window_size=self.params["period"]) -``` - -### 9.3 数据访问方法 - -| 方法 | 适用场景 | 说明 | -|------|----------|------| -| `data.get_column(col)` | 通用 | 获取指定列的 Series | -| `data.get_cross_section()` | 截面因子 | 获取当前日期的所有股票数据 | -| `data.filter_by_date(date)` | 截面因子 | 按日期过滤数据 | -| `data.to_polars()` | 高级用法 | 获取底层 DataFrame(谨慎使用) | -| `data.context` | 通用 | 获取计算上下文 | - ---- - -## 10. 参数化因子 - -### 10.1 定义方式 - -通过 `__init__` 接收参数: - -```python -class MovingAverageFactor(TimeSeriesFactor): - name: str = "ma" - factor_type: str = "time_series" - category: str = "momentum" - description: str = "移动平均线因子" - data_specs: List[DataSpec] = [] # 占位,在 __init__ 中设置 - - def __init__(self, period: int = 5): - super().__init__(period=period) - # 动态设置 data_specs 和 name - self.data_specs = [ - DataSpec( - "daily", - ["ts_code", "trade_date", "close"], - lookback_days=period, - ) - ] - self.name = f"ma_{period}" -``` - -### 10.2 参数验证 - -可选覆盖 `_validate_params` 方法进行参数验证: - -```python -def _validate_params(self): - """验证参数有效性""" - period = self.params.get("period", 5) - if period < 1: - raise ValueError(f"period must be >= 1, got {period}") - if period > 252: - raise ValueError(f"period must be <= 252, got {period}") -``` - -### 10.3 参数使用 - -在 `compute` 中通过 `self.params` 访问参数: - -```python -def compute(self, data: FactorData) -> pl.Series: - close = data.get_column("close") - period = self.params["period"] - return close.rolling_mean(window_size=period) -``` - ---- - -## 11. 因子组合 - -### 11.1 支持的运算符 - -因子框架支持因子间的数学运算: - -```python -# 因子相加 -combined = factor1 + factor2 - -# 因子相减 -diff = factor1 - factor2 - -# 因子相乘 -product = factor1 * factor2 - -# 因子相除 -ratio = factor1 / factor2 - -# 标量乘法 -scaled = 0.5 * factor1 - -# 复杂组合 -final = 0.3 * factor1 + 0.5 * factor2 - 0.2 * factor3 -``` - -### 11.2 组合约束 - -- **类型一致性**:只能组合同类型因子(截面+截面,时序+时序) -- **自动合并 data_specs**:组合因子会自动合并左右因子的数据需求 - -### 11.3 组合因子示例 - -```python -from src.factors.momentum import MovingAverageFactor, ReturnRankFactor - -# 创建基础因子 -ma20 = MovingAverageFactor(period=20) -ret5 = ReturnRankFactor(period=5) - -# 组合(注意:这里只是示例,实际不能组合不同类型因子) -# 同类型因子可以组合: -# combined = ma20 + MovingAverageFactor(period=10) -``` - ---- - -## 12. 性能优化 - -### 12.1 向量化计算 - -**推荐**:使用 Polars 的向量化操作 - -```python -def compute(self, data: FactorData) -> pl.Series: - close = data.get_column("close") - # 向量化计算 - 高效 - return close.rolling_mean(window_size=20) -``` - -**避免**:使用 Python 循环 - -```python -def compute(self, data: FactorData) -> pl.Series: - close = data.get_column("close") - # 避免:Python 循环 - 低效 - result = [] - for i in range(len(close)): - if i < 20: - result.append(None) - else: - result.append(sum(close[i-20:i]) / 20) - return pl.Series(result) -``` - -### 12.2 数据加载优化 - -1. **只加载需要的列**:在 `DataSpec` 中明确指定需要的列 -2. **合理设置 lookback_days**:不要设置过大的回看窗口 -3. **使用合适的数据类型**:Polars 会自动优化,但避免不必要的类型转换 - -### 12.3 内存优化 - -1. **避免复制大数据**:使用 Polars 的零拷贝操作 -2. **及时释放中间结果**:不需要的变量及时清理 -3. **使用惰性求值**(如适用):`pl.lazy()` - -### 12.4 计算优化技巧 - -```python -# 1. 预计算常用值 -def compute(self, data: FactorData) -> pl.Series: - close = data.get_column("close") - # 预计算 - log_close = close.log() - return log_close.rolling_mean(window_size=20) - -# 2. 使用窗口函数避免重复计算 -def compute(self, data: FactorData) -> pl.Series: - close = data.get_column("close") - # 一次计算多个统计量 - return close.rolling_mean(window_size=20), close.rolling_std(window_size=20) - -# 3. 避免重复的数据转换 -def compute(self, data: FactorData) -> pl.Series: - # 获取 DataFrame 一次 - df = data.to_polars() - # 多次使用 - mean_val = df["close"].mean() - std_val = df["close"].std() -``` - ---- - -## 13. 测试规范 - -### 13.1 测试文件位置 - -``` -tests/ -├── factors/ -│ ├── __init__.py -│ ├── test_momentum.py # 动量因子测试 -│ ├── test_financial.py # 财务因子测试 -│ ├── test_valuation.py # 估值因子测试 -│ ├── test_technical.py # 技术指标测试 -│ └── test_your_factor.py # 你的因子测试 -``` - -### 13.2 测试模板 - -```python -"""XXX因子测试 - -测试覆盖: -- 初始化测试 -- 计算逻辑测试 -- 边界条件测试 -- 参数验证测试 -""" - -import pytest -import polars as pl -from src.factors.momentum import MovingAverageFactor -from src.factors.data_spec import FactorData, FactorContext - - -class TestMovingAverageFactor: - """MovingAverageFactor 测试类""" - - def test_init(self): - """测试初始化""" - factor = MovingAverageFactor(period=10) - assert factor.name == "ma_10" - assert factor.params["period"] == 10 - assert factor.data_specs[0].lookback_days == 10 - - def test_compute(self): - """测试计算逻辑""" - factor = MovingAverageFactor(period=3) - - # 构造测试数据 - df = pl.DataFrame({ - "ts_code": ["000001.SZ"] * 5, - "trade_date": ["20240101", "20240102", "20240103", "20240104", "20240105"], - "close": [10.0, 11.0, 12.0, 13.0, 14.0], - }) - - context = FactorContext(current_stock="000001.SZ") - data = FactorData(df, context) - - # 计算 - result = factor.compute(data) - - # 验证 - assert len(result) == 5 - # 第3个值应该是 (10+11+12)/3 = 11.0 - assert result[2] == 11.0 - - def test_compute_empty_data(self): - """测试空数据处理""" - factor = MovingAverageFactor(period=5) - - df = pl.DataFrame({ - "ts_code": [], - "trade_date": [], - "close": [], - }) - - context = FactorContext(current_stock="000001.SZ") - data = FactorData(df, context) - - result = factor.compute(data) - assert len(result) == 0 - - def test_compute_single_row(self): - """测试单行数据处理""" - factor = MovingAverageFactor(period=5) - - df = pl.DataFrame({ - "ts_code": ["000001.SZ"], - "trade_date": ["20240101"], - "close": [10.0], - }) - - context = FactorContext(current_stock="000001.SZ") - data = FactorData(df, context) - - result = factor.compute(data) - assert len(result) == 1 - # 数据不足时返回 null - assert result[0] is None - - def test_validation(self): - """测试参数验证""" - with pytest.raises(ValueError): - MovingAverageFactor(period=0) - - with pytest.raises(ValueError): - MovingAverageFactor(period=-1) -``` - -### 13.3 测试要点 - -1. **初始化测试**:验证 name、params、data_specs 正确设置 -2. **计算逻辑测试**:使用构造数据验证计算结果 -3. **边界条件测试**:空数据、单条数据、缺失值处理 -4. **参数验证测试**:无效参数应抛出 ValueError -5. **防泄露测试**:确保因子不会访问不应访问的数据 - -### 13.4 运行测试 - -```bash -# 运行所有因子测试 -uv run pytest tests/factors/ - -# 运行特定测试文件 -uv run pytest tests/factors/test_momentum.py - -# 运行特定测试类 -uv run pytest tests/factors/test_momentum.py::TestMovingAverageFactor - -# 运行特定测试方法 -uv run pytest tests/factors/test_momentum.py::TestMovingAverageFactor::test_compute - -# 带覆盖率报告 -uv run pytest tests/factors/ --cov=src.factors --cov-report=term-missing -``` - ---- - -## 14. 完整示例 - -### 14.1 截面因子示例:PE 排名因子 - -```python -"""估值因子 - 市盈率排名 - -本模块提供市盈率(PE)排名因子: -- PERankFactor: PE截面排名因子 - -使用示例: - >>> from src.factors.valuation import PERankFactor - >>> pe_rank = PERankFactor() - >>> # 每天返回所有股票的PE排名(0-1之间) -""" - -from typing import List - -import polars as pl - -from src.factors.base import CrossSectionalFactor -from src.factors.data_spec import DataSpec, FactorData - - -class PERankFactor(CrossSectionalFactor): - """市盈率(PE)排名因子 - - 计算逻辑:每个交易日,对所有股票的PE进行截面排名。 - - 特点: - - 截面因子:每天对所有股票进行横向排名 - - 值域:0-1之间,0表示PE最小,1表示PE最大 - - Attributes: - name: 因子名称 "pe_rank" - category: 因子分类 "valuation" - - Example: - >>> pe_rank = PERankFactor() - >>> # 每个交易日,返回所有股票的PE排名 - """ - - name: str = "pe_rank" - factor_type: str = "cross_sectional" - category: str = "valuation" - description: str = "市盈率截面排名因子,值域0-1" - data_specs: List[DataSpec] = [ - DataSpec( - "daily", - ["ts_code", "trade_date", "pe"], - lookback_days=1 - ) - ] - - def compute(self, data: FactorData) -> pl.Series: - """计算PE排名 - - Args: - data: FactorData,包含当前日期的截面数据 - - Returns: - PE排名的0-1标准化值 - """ - # 获取当前日期的截面数据 - cs = data.get_cross_section() - - if len(cs) == 0: - return pl.Series(name=self.name, values=[]) - - # 获取PE值,处理负值和缺失值 - pe = cs["pe"] - # PE为负表示亏损,设为一个大值(排名靠后) - pe = pe.fill_null(9999) - pe = pl.when(pe < 0).then(9999).otherwise(pe) - - # 计算排名(升序,PE越小排名越靠前) - if len(pe) > 1 and pe.max() != pe.min(): - ranks = pe.rank(method="average") / len(pe) - else: - # 数据不足或全部相同,返回0.5 - ranks = pl.Series(name=self.name, values=[0.5] * len(pe)) - - return ranks -``` - -### 14.2 时序因子示例:RSI 因子 - -```python -"""技术指标 - 相对强弱指标(RSI) - -本模块提供RSI因子: -- RSIFactor: 相对强弱指标(时序因子) - -使用示例: - >>> from src.factors.technical import RSIFactor - >>> rsi14 = RSIFactor(period=14) - >>> # 计算14日RSI -""" - -from typing import List - -import polars as pl - -from src.factors.base import TimeSeriesFactor -from src.factors.data_spec import DataSpec, FactorData - - -class RSIFactor(TimeSeriesFactor): - """相对强弱指标(RSI)因子 - - 计算逻辑:对每只股票,计算其过去n日的RSI值。 - - RSI = 100 - (100 / (1 + RS)) - RS = 平均上涨幅度 / 平均下跌幅度 - - 特点: - - 参数化因子:通过 period 参数指定计算窗口 - - 时序因子:每只股票单独计算 - - 值域:0-100,通常>70为超买,<30为超卖 - - Attributes: - period: RSI计算期(天数),默认14 - - Example: - >>> rsi14 = RSIFactor(period=14) - >>> # 计算14日RSI - """ - - name: str = "rsi" - factor_type: str = "time_series" - category: str = "technical" - description: str = "相对强弱指标,值域0-100" - data_specs: List[DataSpec] = [] - - def __init__(self, period: int = 14): - """初始化RSI因子 - - Args: - period: RSI计算期(天数),默认14 - """ - super().__init__(period=period) - # RSI需要 period+1 天的数据来计算涨跌幅 - self.data_specs = [ - DataSpec( - "daily", - ["ts_code", "trade_date", "close"], - lookback_days=period + 1, - ) - ] - self.name = f"rsi_{period}" - - def compute(self, data: FactorData) -> pl.Series: - """计算RSI - - Args: - data: FactorData,包含单只股票的完整时间序列 - - Returns: - RSI值序列(0-100) - """ - # 获取收盘价 - close = data.get_column("close") - period = self.params["period"] - - # 计算涨跌幅 - delta = close.diff() - - # 分离上涨和下跌 - gain = pl.when(delta > 0).then(delta).otherwise(0) - loss = pl.when(delta < 0).then(-delta).otherwise(0) - - # 计算平均上涨和平均下跌(使用指数移动平均) - avg_gain = gain.ewm_mean(span=period, min_periods=period) - avg_loss = loss.ewm_mean(span=period, min_periods=period) - - # 计算RS和RSI - rs = avg_gain / avg_loss - rsi = 100 - (100 / (1 + rs)) - - return rsi -``` - -### 14.3 使用示例 - -```python -from src.factors import FactorEngine, DataLoader -from src.factors.valuation import PERankFactor -from src.factors.technical import RSIFactor - -# 创建数据加载器和执行引擎 -loader = DataLoader(data_dir="data") -engine = FactorEngine(loader) - -# 创建因子 -pe_rank = PERankFactor() -rsi14 = RSIFactor(period=14) - -# 计算截面因子 -pe_result = engine.compute( - pe_rank, - start_date="20240101", - end_date="20240131" -) - -# 计算时序因子 -rsi_result = engine.compute( - rsi14, - stock_codes=["000001.SZ", "000002.SZ"], - start_date="20240101", - end_date="20240131" -) - -# 因子组合(同类型才能组合) -# combined = 0.5 * pe_rank + 0.5 * other_cs_factor -``` - ---- - -## 15. 常见错误 - -### 15.1 类型错误 - -#### 错误 1:混淆截面因子和时序因子 - -```python -# 错误:时序计算却继承截面因子 -class MovingAverageFactor(CrossSectionalFactor): - def compute(self, data): - # 这里只会传入一天的数据,无法计算MA - return data.get_column("close").rolling_mean(window_size=20) - -# 正确:继承时序因子 -class MovingAverageFactor(TimeSeriesFactor): - def compute(self, data): - # 这里传入完整时间序列,可以计算MA - return data.get_column("close").rolling_mean(window_size=20) -``` - -#### 错误 2:在截面因子中访问其他日期的数据 - -```python -# 错误:试图访问历史数据 -class MyFactor(CrossSectionalFactor): - def compute(self, data): - df = data.to_polars() - # 错误:data 只包含当前日期的数据 - yesterday_data = df.filter(pl.col("trade_date") == "20240101") - -# 正确:使用 lookback_days 获取历史数据 -class MyFactor(CrossSectionalFactor): - data_specs = [DataSpec("daily", ["close"], lookback_days=5)] - - def compute(self, data): - df = data.to_polars() - # 现在 df 包含 [T-4, T] 的数据 - return df["close"].mean() -``` - -### 15.2 命名错误 - -#### 错误 3:使用非法字符或格式 - -```python -# 错误 -name = "PE Rank" # 包含空格 -name = "pe-rank" # 使用连字符 -name = "PERank" # 使用 PascalCase - -# 正确 -name = "pe_rank" # snake_case -``` - -#### 错误 4:参数化因子未动态设置 name - -```python -# 错误 -class MovingAverageFactor(TimeSeriesFactor): - name = "ma" # 固定名称,无法区分 ma_5 和 ma_10 - - def __init__(self, period: int = 5): - super().__init__(period=period) - -# 正确 -class MovingAverageFactor(TimeSeriesFactor): - name = "ma" # 默认名称 - - def __init__(self, period: int = 5): - super().__init__(period=period) - self.name = f"ma_{period}" # 动态设置名称 -``` - -### 15.3 数据规范错误 - -#### 错误 5:忘记包含必需列 - -```python -# 错误 -data_specs = [DataSpec("daily", ["close"], lookback_days=5)] -# 缺少 ts_code 和 trade_date - -# 正确 -data_specs = [DataSpec("daily", ["ts_code", "trade_date", "close"], lookback_days=5)] -``` - -#### 错误 6:lookback_days 设置不当 - -```python -# 错误:计算20日MA但只请求5天数据 -class MA20Factor(TimeSeriesFactor): - data_specs = [DataSpec("daily", ["close"], lookback_days=5)] - - def compute(self, data): - return data.get_column("close").rolling_mean(window_size=20) - -# 正确 -class MA20Factor(TimeSeriesFactor): - data_specs = [DataSpec("daily", ["close"], lookback_days=20)] - - def compute(self, data): - return data.get_column("close").rolling_mean(window_size=20) -``` - -### 15.4 计算逻辑错误 - -#### 错误 7:返回类型错误 - -```python -# 错误:返回 DataFrame 而不是 Series -class MyFactor(CrossSectionalFactor): - def compute(self, data): - cs = data.get_cross_section() - return cs # 错误:返回了 DataFrame - -# 正确 -class MyFactor(CrossSectionalFactor): - def compute(self, data): - cs = data.get_cross_section() - return cs["pe"].rank() # 正确:返回 Series -``` - -#### 错误 8:Series 长度不匹配 - -```python -# 错误:返回的 Series 长度与股票数量不匹配 -class MyFactor(CrossSectionalFactor): - def compute(self, data): - cs = data.get_cross_section() - # 错误:只返回3个值,但可能有3000只股票 - return pl.Series([0.1, 0.2, 0.3]) - -# 正确 -class MyFactor(CrossSectionalFactor): - def compute(self, data): - cs = data.get_cross_section() - # 正确:返回与股票数量相同的 Series - return cs["pe"].rank() / len(cs) -``` - -### 15.5 性能错误 - -#### 错误 9:使用 Python 循环 - -```python -# 错误:使用 Python 循环 -def compute(self, data): - close = data.get_column("close") - result = [] - for i in range(len(close)): - if i < 20: - result.append(None) - else: - result.append(sum(close[i-20:i]) / 20) - return pl.Series(result) - -# 正确:使用向量化操作 -def compute(self, data): - close = data.get_column("close") - return close.rolling_mean(window_size=20) -``` - ---- - -## 16. 验证清单 - -### 16.1 开发前检查 - -- [ ] 确定因子类型(截面/时序) -- [ ] 确定因子分类(momentum/financial/...) -- [ ] 确定需要的数据字段 -- [ ] 确定是否需要参数化 -- [ ] 确定因子的经济含义和计算逻辑 - -### 16.2 编码检查 - -- [ ] 类继承正确的基类(CrossSectionalFactor/TimeSeriesFactor) -- [ ] 声明了必需的类属性(name, factor_type, data_specs) -- [ ] data_specs 包含 ts_code 和 trade_date -- [ ] lookback_days 设置正确 -- [ ] name 使用 snake_case -- [ ] 类名使用 PascalCase 且以 Factor 结尾 -- [ ] 文件名使用 snake_case -- [ ] 所有函数都有类型提示 -- [ ] 文档字符串符合 Google 风格 -- [ ] compute 方法返回 pl.Series - -### 16.3 防泄露检查 - -- [ ] 截面因子只访问当前日期数据 -- [ ] 时序因子不访问其他股票数据 -- [ ] 没有使用 `data.to_polars()` 绕过框架 -- [ ] 没有硬编码日期或股票代码 - -### 16.4 性能检查 - -- [ ] 使用 Polars 向量化操作 -- [ ] 避免 Python 循环 -- [ ] 只加载需要的列 -- [ ] lookback_days 没有设置过大 - -### 16.5 测试检查 - -- [ ] 编写了初始化测试 -- [ ] 编写了计算逻辑测试 -- [ ] 编写了边界条件测试(空数据、单条数据) -- [ ] 编写了参数验证测试 -- [ ] 所有测试通过 - -### 16.6 文档检查 - -- [ ] 文件头文档完整 -- [ ] 类文档字符串完整 -- [ ] 方法文档字符串完整 -- [ ] 更新了分类 __init__.py -- [ ] 更新了主 __init__.py(如需要) - ---- - -## 附录 - -### A. 常见问题 - -#### Q1: 截面因子和时序因子有什么区别? - -**截面因子**:每天对所有股票进行一次计算,可以比较不同股票之间的值。例如:PE排名、市值分位数。 - -**时序因子**:对每只股票单独计算其时间序列。例如:移动平均线、RSI、历史波动率。 - -#### Q2: lookback_days 怎么设置? - -- 如果只需要当日数据:`lookback_days=1` -- 如果需要n日历史数据:`lookback_days=n` -- 对于收益率计算(需要T日和T-n日):`lookback_days=n+1` - -#### Q3: 如何处理缺失值? - -```python -# Polars 提供多种缺失值处理方法 -series.fill_null(0) # 填充为0 -series.fill_null(strategy="forward") # 向前填充 -series.drop_nulls() # 删除缺失值 -``` - -#### Q4: 如何调试因子计算? - -```python -# 1. 构造测试数据 -df = pl.DataFrame({ - "ts_code": [...], - "trade_date": [...], - "close": [...], -}) - -# 2. 创建 FactorData -context = FactorContext(current_date="20240101") -data = FactorData(df, context) - -# 3. 直接调用 compute -factor = MyFactor() -result = factor.compute(data) -print(result) -``` - -#### Q5: 可以使用 Pandas 吗? - -推荐使用 **Polars**,因为: -- 性能更好(向量化操作) -- 内存占用更低 -- 与框架其他部分保持一致 - -如果必须使用 Pandas: -```python -def compute(self, data: FactorData) -> pl.Series: - # 转换为 Pandas - pdf = data.to_polars().to_pandas() - - # 使用 Pandas 计算 - result = pdf["close"].rolling(window=20).mean() - - # 转回 Polars Series - return pl.Series(name=self.name, values=result.values) -``` - -### B. 快速参考卡 - -#### 创建截面因子 - -```python -from src.factors.base import CrossSectionalFactor -from src.factors.data_spec import DataSpec, FactorData -import polars as pl -from typing import List - -class MyFactor(CrossSectionalFactor): - name = "my_factor" - factor_type = "cross_sectional" - category = "momentum" - data_specs = [DataSpec("daily", ["ts_code", "trade_date", "close"], lookback_days=1)] - - def compute(self, data: FactorData) -> pl.Series: - cs = data.get_cross_section() - return cs["close"].rank() -``` - -#### 创建时序因子 - -```python -from src.factors.base import TimeSeriesFactor -from src.factors.data_spec import DataSpec, FactorData -import polars as pl -from typing import List - -class MyFactor(TimeSeriesFactor): - name = "my_factor" - factor_type = "time_series" - category = "momentum" - data_specs = [] - - def __init__(self, period: int = 5): - super().__init__(period=period) - self.data_specs = [DataSpec("daily", ["ts_code", "trade_date", "close"], lookback_days=period)] - self.name = f"my_factor_{period}" - - def compute(self, data: FactorData) -> pl.Series: - close = data.get_column("close") - return close.rolling_mean(window_size=self.params["period"]) -``` - -### C. 相关文档 - -- [因子框架设计](../../docs/factor_framework_design.md) -- [数据同步指南](../../docs/db_sync_guide.md) -- [API 接口规范](../../src/data/api_wrappers/API_INTERFACE_SPEC.md) - ---- - -**文档版本**: 2.0 -**最后更新**: 2026-02-25 -**维护者**: ProStock Team diff --git a/src/factors/__init__.py b/src/factors/__init__.py deleted file mode 100644 index e46c282..0000000 --- a/src/factors/__init__.py +++ /dev/null @@ -1,118 +0,0 @@ -"""ProStock 因子计算框架 - -因子框架提供以下核心功能: -1. 类型安全的因子定义(截面因子、时序因子) -2. 数据泄露防护机制 -3. 因子组合和运算 -4. 高效的数据加载和计算引擎 - -基础数据类型(Phase 1): -- DataSpec: 数据需求规格 -- FactorContext: 计算上下文 -- FactorData: 数据容器 - -因子基类(Phase 2): -- BaseFactor: 抽象基类 -- CrossSectionalFactor: 日期截面因子基类 -- TimeSeriesFactor: 时间序列因子基类 -- CompositeFactor: 组合因子 -- ScalarFactor: 标量运算因子 - -因子分类目录: -- momentum/: 动量因子(MA、收益率排名等) -- financial/: 财务因子(EPS、ROE等) -- valuation/: 估值因子(PE、PB、PS等) -- technical/: 技术指标因子(RSI、MACD、布林带等) -- quality/: 质量因子(盈利能力、稳定性等) -- sentiment/: 情绪因子(换手率、资金流向等) -- volume/: 成交量因子(OBV、成交量比率等) -- volatility/: 波动率因子(历史波动率、GARCH等) - -数据加载和执行(Phase 3-4): -- DataLoader: 数据加载器 -- FactorEngine: 因子执行引擎 - -使用示例: - # 使用通用因子(参数化) - from src.factors import MovingAverageFactor, ReturnRankFactor - from src.factors import DataLoader, FactorEngine - - ma5 = MovingAverageFactor(period=5) # 5日MA - ma10 = MovingAverageFactor(period=10) # 10日MA - ret5 = ReturnRankFactor(period=5) # 5日收益率排名 - - loader = DataLoader(data_dir="data") - engine = FactorEngine(loader) - result = engine.compute(ma5, stock_codes=["000001.SZ"], start_date="20240101", end_date="20240131") -""" - -因子框架提供以下核心功能: -1. 类型安全的因子定义(截面因子、时序因子) -2. 数据泄露防护机制 -3. 因子组合和运算 -4. 高效的数据加载和计算引擎 - -基础数据类型(Phase 1): -- DataSpec: 数据需求规格 -- FactorContext: 计算上下文 -- FactorData: 数据容器 - -因子基类(Phase 2): -- BaseFactor: 抽象基类 -- CrossSectionalFactor: 日期截面因子基类 -- TimeSeriesFactor: 时间序列因子基类 -- CompositeFactor: 组合因子 -- ScalarFactor: 标量运算因子 - -动量因子(momentum/): -- MovingAverageFactor: 移动平均线(时序因子) -- ReturnRankFactor: 收益率排名(截面因子) - -财务因子(financial/): -- (待添加) - -数据加载和执行(Phase 3-4): -- DataLoader: 数据加载器 -- FactorEngine: 因子执行引擎 - -使用示例: - # 使用通用因子(参数化) - from src.factors import MovingAverageFactor, ReturnRankFactor - from src.factors import DataLoader, FactorEngine - - ma5 = MovingAverageFactor(period=5) # 5日MA - ma10 = MovingAverageFactor(period=10) # 10日MA - ret5 = ReturnRankFactor(period=5) # 5日收益率排名 - - loader = DataLoader(data_dir="data") - engine = FactorEngine(loader) - result = engine.compute(ma5, stock_codes=["000001.SZ"], start_date="20240101", end_date="20240131") -""" - -from src.factors.data_spec import DataSpec, FactorContext, FactorData -from src.factors.base import BaseFactor, CrossSectionalFactor, TimeSeriesFactor -from src.factors.composite import CompositeFactor, ScalarFactor -from src.factors.data_loader import DataLoader -from src.factors.engine import FactorEngine - -# 动量因子 -from src.factors.momentum import MovingAverageFactor, ReturnRankFactor - -__all__ = [ - # Phase 1: 数据类型定义 - "DataSpec", - "FactorContext", - "FactorData", - # Phase 2: 因子基类 - "BaseFactor", - "CrossSectionalFactor", - "TimeSeriesFactor", - "CompositeFactor", - "ScalarFactor", - # Phase 3-4: 数据加载和执行引擎 - "DataLoader", - "FactorEngine", - # 动量因子 - "MovingAverageFactor", - "ReturnRankFactor", -] diff --git a/src/factors/base.py b/src/factors/base.py deleted file mode 100644 index a3bec75..0000000 --- a/src/factors/base.py +++ /dev/null @@ -1,274 +0,0 @@ -"""因子基类 - Phase 2 核心抽象类 - -本模块定义了因子框架的基类: -- BaseFactor: 抽象基类,定义通用接口和验证逻辑 -- CrossSectionalFactor: 日期截面因子基类(防止日期泄露) -- TimeSeriesFactor: 时间序列因子基类(防止股票泄露) -""" - -from abc import ABC, abstractmethod -from dataclasses import field -from typing import List - -import polars as pl - -from src.factors.data_spec import DataSpec, FactorData - - -class BaseFactor(ABC): - """因子基类 - 定义通用接口 - - 所有因子必须继承此类,并声明以下类属性: - - name: 因子唯一标识(snake_case) - - factor_type: "cross_sectional" 或 "time_series" - - data_specs: List[DataSpec] 数据需求列表 - - 可选声明: - - category: 因子分类(默认 "default") - - description: 因子描述 - - 示例: - >>> class MyFactor(CrossSectionalFactor): - ... name = "my_factor" - ... data_specs = [DataSpec("daily", ["close"], lookback_days=5)] - ... - ... def compute(self, data: FactorData) -> pl.Series: - ... return data.get_column("close").rank() - """ - - # 必须声明的类属性 - name: str = "" - factor_type: str = "" # "cross_sectional" | "time_series" - data_specs: List[DataSpec] = field(default_factory=list) - - # 可选声明的类属性 - category: str = "default" - description: str = "" - - def __init_subclass__(cls, **kwargs): - """子类创建时验证必须属性 - - 验证项: - 1. name 必须是非空字符串 - 2. factor_type 必须是 "cross_sectional" 或 "time_series" - 3. data_specs 必须是非空列表 - """ - super().__init_subclass__(**kwargs) - - # 跳过抽象基类和特殊因子类的验证 - if cls.__name__ in ( - "CrossSectionalFactor", - "TimeSeriesFactor", - "CompositeFactor", - "ScalarFactor", - ): - return - - # 验证 name - 必须直接定义在类中(不能继承) - if "name" not in cls.__dict__ or not cls.name: - raise ValueError(f"Factor {cls.__name__} must define 'name'") - if not isinstance(cls.name, str): - raise ValueError(f"Factor {cls.__name__}.name must be a string") - - # 验证 factor_type - 必须有值(可以是继承的) - if not cls.factor_type: - raise ValueError(f"Factor {cls.__name__} must define 'factor_type'") - if cls.factor_type not in ("cross_sectional", "time_series"): - raise ValueError( - f"Factor {cls.__name__}.factor_type must be 'cross_sectional' " - f"or 'time_series', got '{cls.factor_type}'" - ) - - # 验证 data_specs - # 情况1: 完全没有定义 data_specs(继承的空列表) - if "data_specs" not in cls.__dict__: - raise ValueError(f"Factor {cls.__name__} must define 'data_specs'") - # 情况2: 定义了但为空列表 - if not cls.data_specs or len(cls.data_specs) == 0: - raise ValueError(f"Factor {cls.__name__}.data_specs cannot be empty") - if not isinstance(cls.data_specs, list): - raise ValueError(f"Factor {cls.__name__}.data_specs must be a list") - - def __init__(self, **params): - """初始化因子参数 - - 子类可通过 __init__ 接收参数化配置,如 MA(period=20) - - 注意:data_specs 必须在类级别定义(类属性), - 而非在 __init__ 中设置。data_specs 的验证在 - __init_subclass__ 中完成(类创建时)。 - - Args: - **params: 因子参数,存储在 self.params 中 - """ - self.params = params - - def _validate_params(self): - """验证参数有效性 - - 子类可覆盖此方法进行自定义验证(需自行在子类 __init__ 中调用)。 - 基类实现为空,表示不执行任何验证。 - - 注意:由于 data_specs 在类创建时通过 __init_subclass__ 验证, - 不应在实例级别修改。如需动态 data_specs,请使用参数化模式: - - >>> class ParamFactor(TimeSeriesFactor): - ... name = "param_factor" - ... data_specs = [] # 类级别定义 - ... - ... def __init__(self, period: int = 20): - ... super().__init__(period=period) - ... # 通过参数化改变计算逻辑,而非 data_specs - ... - ... def compute(self, data: FactorData) -> pl.Series: - ... return data.get_column("close").rolling_mean(self.params["period"]) - """ - pass - - @abstractmethod - def compute(self, data: FactorData) -> pl.Series: - """核心计算逻辑 - 子类必须实现 - - Args: - data: 安全的数据容器,已根据因子类型裁剪 - - Returns: - 计算得到的因子值 Series - """ - pass - - # ========== 因子组合运算符 ========== - - def __add__(self, other: "BaseFactor") -> "CompositeFactor": - """因子相加:f1 + f2(要求同类型)""" - from src.factors.composite import CompositeFactor - - return CompositeFactor(self, other, "+") - - def __sub__(self, other: "BaseFactor") -> "CompositeFactor": - """因子相减:f1 - f2(要求同类型)""" - from src.factors.composite import CompositeFactor - - return CompositeFactor(self, other, "-") - - def __mul__(self, other): - """因子相乘:f1 * f2 或 f1 * scalar""" - if isinstance(other, (int, float)): - from src.factors.composite import ScalarFactor - - return ScalarFactor(self, float(other), "*") - elif isinstance(other, BaseFactor): - from src.factors.composite import CompositeFactor - - return CompositeFactor(self, other, "*") - return NotImplemented - - def __truediv__(self, other: "BaseFactor") -> "CompositeFactor": - """因子相除:f1 / f2(要求同类型)""" - from src.factors.composite import CompositeFactor - - return CompositeFactor(self, other, "/") - - def __rmul__(self, scalar: float) -> "ScalarFactor": - """标量乘法:0.5 * f1""" - from src.factors.composite import ScalarFactor - - return ScalarFactor(self, scalar, "*") - - def __repr__(self) -> str: - """返回因子的字符串表示""" - return ( - f"{self.__class__.__name__}(name='{self.name}', type='{self.factor_type}')" - ) - - -class CrossSectionalFactor(BaseFactor): - """日期截面因子基类 - - 计算逻辑:在每个交易日,对所有股票进行横向计算 - - 防泄露边界: - - ❌ 禁止访问未来日期的数据(日期泄露) - - ✅ 允许访问当前日期的所有股票数据 - - 数据传入: - - compute() 接收的是 [T-lookback+1, T] 的数据 - - 包含 lookback_days 的历史数据(用于时序计算后再截面) - - 示例: - >>> class PERankFactor(CrossSectionalFactor): - ... name = "pe_rank" - ... data_specs = [DataSpec("daily", ["pe"], lookback_days=1)] - ... - ... def compute(self, data: FactorData) -> pl.Series: - ... cs = data.get_cross_section() - ... return cs["pe"].rank() - """ - - factor_type: str = "cross_sectional" - - @abstractmethod - def compute(self, data: FactorData) -> pl.Series: - """计算截面因子值 - - Args: - data: FactorData,包含 [T-lookback+1, T] 的截面数据 - 格式:DataFrame[ts_code, trade_date, col1, col2, ...] - - Returns: - pl.Series: 当前日期所有股票的因子值(长度 = 该日股票数量) - - 示例: - >>> def compute(self, data): - ... # 获取当前日期截面 - ... cs = data.get_cross_section() - ... # 计算市值排名 - ... return cs['market_cap'].rank() - """ - pass - - -class TimeSeriesFactor(BaseFactor): - """时间序列因子基类(股票截面) - - 计算逻辑:对每只股票,在其时间序列上进行纵向计算 - - 防泄露边界: - - ❌ 禁止访问其他股票的数据(股票泄露) - - ✅ 允许访问该股票的完整历史数据 - - 数据传入: - - compute() 接收的是单只股票的完整时间序列 - - 包含该股票在 [start_date, end_date] 范围内的所有数据 - - 示例: - >>> class MovingAverageFactor(TimeSeriesFactor): - ... name = "ma" - ... - ... def __init__(self, period: int = 20): - ... super().__init__(period=period) - ... self.data_specs = [DataSpec("daily", ["close"], lookback_days=period)] - ... - ... def compute(self, data: FactorData) -> pl.Series: - ... return data.get_column("close").rolling_mean(self.params["period"]) - """ - - factor_type: str = "time_series" - - @abstractmethod - def compute(self, data: FactorData) -> pl.Series: - """计算时间序列因子值 - - Args: - data: FactorData,包含单只股票的完整时间序列 - 格式:DataFrame[ts_code, trade_date, col1, col2, ...] - - Returns: - pl.Series: 该股票在各日期的因子值(长度 = 日期数量) - - 示例: - >>> def compute(self, data): - ... series = data.get_column("close") - ... return series.rolling_mean(window_size=self.params['period']) - """ - pass diff --git a/src/factors/composite.py b/src/factors/composite.py deleted file mode 100644 index d800ac7..0000000 --- a/src/factors/composite.py +++ /dev/null @@ -1,201 +0,0 @@ -"""组合因子 - Phase 2 因子组合和标量运算 - -本模块定义了因子组合相关的类: -- CompositeFactor: 组合因子,用于实现因子间的数学运算 -- ScalarFactor: 标量运算因子,支持因子与标量的运算 -""" - -from typing import List - -import polars as pl - -from src.factors.data_spec import DataSpec, FactorData -from src.factors.base import BaseFactor - - -class CompositeFactor(BaseFactor): - """组合因子 - 用于实现因子间的数学运算 - - 约束:左右因子必须是同类型(同为截面或同为时序) - - 支持的运算符:'+', '-', '*', '/' - - 示例: - >>> f1 = SomeCrossSectionalFactor() - >>> f2 = AnotherCrossSectionalFactor() - >>> combined = f1 + f2 # 创建 CompositeFactor - """ - - def __init__(self, left: BaseFactor, right: BaseFactor, op: str): - """创建组合因子 - - Args: - left: 左操作数因子 - right: 右操作数因子 - op: 运算符,支持 '+', '-', '*', '/' - - Raises: - ValueError: 左右因子类型不一致 - ValueError: 不支持的运算符 - """ - # 验证类型一致性 - if left.factor_type != right.factor_type: - raise ValueError( - f"Cannot combine factors of different types: " - f"'{left.factor_type}' vs '{right.factor_type}'" - ) - - # 验证运算符 - if op not in ("+", "-", "*", "/"): - raise ValueError(f"Unsupported operator: '{op}'") - - self.left = left - self.right = right - self.op = op - - # 设置类属性 - self.factor_type = left.factor_type - self.name = f"({left.name}_{op}_{right.name})" - self.data_specs = self._merge_data_specs() - self.category = "composite" - self.description = f"Composite factor: {left.name} {op} {right.name}" - - # 注意:不调用 super().__init__(),因为 CompositeFactor 是特殊因子 - self.params = { - "left": left, - "right": right, - "op": op, - } - - def _merge_data_specs(self) -> List[DataSpec]: - """合并左右因子的数据需求 - - 策略: - 1. 相同 source 和 columns 的 DataSpec 合并 - 2. lookback_days 取最大值 - - Returns: - 合并后的 DataSpec 列表 - """ - merged = [] - - # 收集所有 specs - all_specs = list(self.left.data_specs) + list(self.right.data_specs) - - # 按 (source, columns_tuple) 分组 - spec_groups = {} - for spec in all_specs: - key = (spec.source, tuple(sorted(spec.columns))) - if key not in spec_groups: - spec_groups[key] = [] - spec_groups[key].append(spec) - - # 合并每组,取最大 lookback_days - for (source, columns_tuple), specs in spec_groups.items(): - max_lookback = max(spec.lookback_days for spec in specs) - merged.append( - DataSpec( - source=source, - columns=list(columns_tuple), - lookback_days=max_lookback, - ) - ) - - return merged - - def compute(self, data: FactorData) -> pl.Series: - """执行组合运算 - - 流程: - 1. 分别计算 left 和 right 的值 - 2. 根据 op 执行运算 - 3. 返回结果 - - Args: - data: 包含左右因子所需数据的 FactorData - - Returns: - 组合运算后的因子值 Series - """ - left_values = self.left.compute(data) - right_values = self.right.compute(data) - - ops = { - "+": lambda a, b: a + b, - "-": lambda a, b: a - b, - "*": lambda a, b: a * b, - "/": lambda a, b: a / b, - } - - return ops[self.op](left_values, right_values) - - def _validate_params(self): - """CompositeFactor 不需要额外验证""" - pass - - -class ScalarFactor(BaseFactor): - """标量运算因子 - - 支持:scalar * factor, factor * scalar(通过 __rmul__) - - 示例: - >>> factor = SomeFactor() - >>> scaled = 0.5 * factor # 创建 ScalarFactor - """ - - def __init__(self, factor: BaseFactor, scalar: float, op: str): - """创建标量运算因子 - - Args: - factor: 基础因子 - scalar: 标量值 - op: 运算符,支持 '*', '+' - - Raises: - ValueError: 不支持的运算符 - """ - # 验证运算符 - if op not in ("*", "+"): - raise ValueError(f"ScalarFactor only supports '*' and '+', got '{op}'") - - self.factor = factor - self.scalar = scalar - self.op = op - - # 设置类属性 - self.factor_type = factor.factor_type - self.name = f"({scalar}_{op}_{factor.name})" - self.data_specs = factor.data_specs - self.category = "scalar" - self.description = f"Scalar factor: {scalar} {op} {factor.name}" - - # 注意:不调用 super().__init__(),因为 ScalarFactor 是特殊因子 - self.params = { - "factor": factor, - "scalar": scalar, - "op": op, - } - - def compute(self, data: FactorData) -> pl.Series: - """执行标量运算 - - Args: - data: 包含基础因子所需数据的 FactorData - - Returns: - 标量运算后的因子值 Series - """ - values = self.factor.compute(data) - - if self.op == "*": - return values * self.scalar - elif self.op == "+": - return values + self.scalar - else: - # 不应该执行到这里,因为 __init__ 已经验证了 op - raise ValueError(f"Unsupported operation: '{self.op}'") - - def _validate_params(self): - """ScalarFactor 不需要额外验证""" - pass diff --git a/src/factors/data_loader.py b/src/factors/data_loader.py deleted file mode 100644 index 5c81bc0..0000000 --- a/src/factors/data_loader.py +++ /dev/null @@ -1,213 +0,0 @@ -"""数据加载器 - Phase 3 数据加载模块 - -本模块负责从 DuckDB 安全加载数据: -- DataLoader: 数据加载器,支持多文件聚合、列选择、缓存 -""" - -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import pandas as pd -import polars as pl - -from src.factors.data_spec import DataSpec - - -class DataLoader: - """数据加载器 - 负责从 DuckDB 安全加载数据 - - 功能: - 1. 多文件聚合:合并多个表的数据 - 2. 列选择:只加载需要的列 - 3. 原始数据缓存:避免重复读取 - 4. 查询下推:利用 DuckDB SQL 过滤,只加载必要数据 - - 示例: - >>> loader = DataLoader(data_dir="data") - >>> specs = [DataSpec("daily", ["ts_code", "trade_date", "close"], lookback_days=20)] - >>> df = loader.load(specs, date_range=("20240101", "20240131")) - """ - - def __init__(self, data_dir: str): - """初始化 DataLoader - - Args: - data_dir: DuckDB 数据库文件所在目录 - """ - self.data_dir = Path(data_dir) - self._cache: Dict[str, pl.DataFrame] = {} - - def load( - self, - specs: List[DataSpec], - date_range: Optional[Tuple[str, str]] = None, - ) -> pl.DataFrame: - """加载并聚合多个 H5 文件的数据 - - 流程: - 1. 对每个 DataSpec: - a. 检查缓存,命中则直接使用 - b. 未命中则读取 HDF5(通过 pandas) - c. 转换为 Polars DataFrame - d. 按 date_range 过滤 - e. 存入缓存 - 2. 合并多个 DataFrame(按 trade_date 和 ts_code join) - - Args: - specs: 数据需求规格列表 - date_range: 日期范围限制 (start_date, end_date),可选 - - Returns: - 合并后的 Polars DataFrame - - Raises: - FileNotFoundError: H5 文件不存在 - KeyError: 列不存在于文件中 - """ - dataframes = [] - - for spec in specs: - # 检查缓存 - cache_key = f"{spec.source}_{','.join(sorted(spec.columns))}" - if cache_key in self._cache: - df = self._cache[cache_key] - else: - # 读取 H5 文件(传入日期范围以支持过滤) - df = self._read_h5(spec.source, date_range=date_range) - - # 列选择 - 只保留需要的列 - missing_cols = set(spec.columns) - set(df.columns) - if missing_cols: - raise KeyError( - f"Columns {missing_cols} not found in {spec.source}.h5. " - f"Available columns: {df.columns}" - ) - df = df.select(spec.columns) - - # 存入缓存 - self._cache[cache_key] = df - - # 按 date_range 过滤 - if date_range: - start_date, end_date = date_range - df = df.filter( - (pl.col("trade_date") >= start_date) - & (pl.col("trade_date") <= end_date) - ) - - dataframes.append(df) - - # 合并多个 DataFrame - if len(dataframes) == 1: - return dataframes[0] - else: - return self._merge_dataframes(dataframes) - - def clear_cache(self): - """清空缓存""" - self._cache.clear() - - def _read_h5( - self, - source: str, - date_range: Optional[Tuple[str, str]] = None, - ) -> pl.DataFrame: - """读取数据 - 从 DuckDB 加载为 Polars DataFrame。 - - 迁移说明: - - 方法名保持 _read_h5 以兼容现有代码(实际从 DuckDB 读取) - - 使用 Storage.load_polars() 直接返回 Polars DataFrame - - 支持零拷贝导出,性能优于 HDF5 + Pandas + Polars 转换 - - Args: - source: 表名(对应 DuckDB 中的表,如 "daily") - date_range: 日期范围限制 (start_date, end_date),可选 - - Returns: - Polars DataFrame - - Raises: - Exception: 数据库查询错误 - """ - from src.data.storage import Storage - from src.data.api_wrappers.api_trade_cal import get_trading_days - from src.data.utils import get_today_date - from src.factors.financial.utils import expand_period_to_trading_days - - storage = Storage() - - # 特殊处理财务数据:将报告期展开到交易日 - if source == "financial_income": - # 确定日期范围 - start_date = date_range[0] if date_range else "20180101" - end_date = date_range[1] if date_range else get_today_date() - - # 1. 加载原始财务数据(报告期粒度),按日期范围过滤 - # 注意:financial_income 使用 end_date 字段作为报告期 - df = storage.load_polars( - "financial_income", - start_date=start_date, - end_date=end_date, - ) - - if len(df) == 0: - return pl.DataFrame() - - # 2. 获取交易日历(从2018年开始到当前,确保有足够的历史数据用于前向填充) - # 需要从数据的最小日期开始,确保能获取到足够的交易日 - trade_start = "20180101" if start_date > "20180101" else start_date - trade_dates = get_trading_days(trade_start, get_today_date()) - - # 3. 展开到交易日(前向填充) - return expand_period_to_trading_days(df, trade_dates) - - # 其他数据源保持原有逻辑 - return storage.load_polars(source) - - def _merge_dataframes(self, dataframes: List[pl.DataFrame]) -> pl.DataFrame: - """合并多个 DataFrame - - 策略: - 1. 按 trade_date 和 ts_code join - 2. 使用外连接保留所有数据 - - Args: - dataframes: DataFrame 列表 - - Returns: - 合并后的 DataFrame - """ - result = dataframes[0] - - for df in dataframes[1:]: - # 确定 join 键 - join_keys = ["trade_date", "ts_code"] - - # 检查 join 键是否存在 - for key in join_keys: - if key not in result.columns or key not in df.columns: - raise KeyError(f"Join key '{key}' not found in DataFrames") - - # 获取需要添加的列(排除重复的 join 键) - new_cols = [c for c in df.columns if c not in result.columns] - - if new_cols: - # 选择必要的列进行 join - df_to_join = df.select(join_keys + new_cols) - - # 执行 join - result = result.join(df_to_join, on=join_keys, how="full") - - return result - - def get_cache_info(self) -> Dict[str, int]: - """获取缓存信息 - - Returns: - 包含缓存条目数和总字节数的字典 - """ - total_rows = sum(len(df) for df in self._cache.values()) - return { - "entries": len(self._cache), - "total_rows": total_rows, - } diff --git a/src/factors/data_spec.py b/src/factors/data_spec.py deleted file mode 100644 index e199318..0000000 --- a/src/factors/data_spec.py +++ /dev/null @@ -1,242 +0,0 @@ -"""数据类型定义 - Phase 1 核心数据模型 - -本模块定义了因子框架的基础数据类型: -- DataSpec: 数据需求规格,声明因子所需的数据源、列和回看窗口 -- FactorContext: 计算上下文,由引擎自动注入,提供计算点信息 -- FactorData: 数据容器,封装底层 Polars DataFrame,提供安全的数据访问 -""" - -from dataclasses import dataclass, field -from typing import List, Optional -import polars as pl - - -@dataclass(frozen=True) -class DataSpec: - """数据需求规格说明 - - 用于声明因子计算所需的数据来源、列和回看窗口。 - 这是一个不可变对象,创建后不可修改。 - - Args: - source: H5 文件名(如 "daily", "fundamental") - columns: 需要的列名列表,必须包含 "ts_code" 和 "trade_date" - lookback_days: 需要回看的天数(包含当日) - - 1 表示只需要当日数据 [T] - - 5 表示需要 [T-4, T] 共5天 - - 20 表示需要 [T-19, T] 共20天 - - Raises: - ValueError: 当参数不满足约束条件时 - - Examples: - >>> spec = DataSpec( - ... source="daily", - ... columns=["ts_code", "trade_date", "close"], - ... lookback_days=20 - ... ) - """ - - source: str - columns: List[str] - lookback_days: int = 1 - - def __post_init__(self): - """验证约束条件 - - 验证项: - 1. lookback_days >= 1(至少包含当日) - 2. columns 必须包含 ts_code 和 trade_date - 3. source 不能为空字符串 - - 注意:由于 frozen=True,实例创建后不可修改。 - 若需要在 __post_init__ 中修改字段(如有),可使用 object.__setattr__。 - 本类仅做验证,无需修改字段,因此直接 raise ValueError 即可。 - """ - if self.lookback_days < 1: - raise ValueError(f"lookback_days must be >= 1, got {self.lookback_days}") - - if not self.source: - raise ValueError("source cannot be empty string") - - required_cols = {"ts_code", "trade_date"} - missing_cols = required_cols - set(self.columns) - if missing_cols: - raise ValueError( - f"columns must contain {required_cols}, missing: {missing_cols}" - ) - - -@dataclass -class FactorContext: - """因子计算上下文 - - 由 FactorEngine 自动注入,因子开发者可通过 data.context 访问。 - 根据因子类型的不同,包含不同的上下文信息: - - CrossSectionalFactor:current_date 表示当前计算的日期 - - TimeSeriesFactor:current_stock 表示当前计算的股票 - - Attributes: - current_date: 当前计算日期 YYYYMMDD(截面因子使用) - current_stock: 当前计算股票代码(时序因子使用) - trade_dates: 交易日历列表(可选,用于对齐) - - Examples: - >>> context = FactorContext(current_date="20240101") - >>> context.current_date - '20240101' - """ - - current_date: Optional[str] = None - current_stock: Optional[str] = None - trade_dates: Optional[List[str]] = None - - -class FactorData: - """提供给因子的数据容器 - - 封装底层 Polars DataFrame,提供安全的数据访问接口。 - 根据因子类型的不同,包含不同的数据: - - CrossSectionalFactor:当前日期及历史 lookback 的截面数据(所有股票) - - TimeSeriesFactor:单只股票的完整时间序列数据 - - Args: - df: 底层的 Polars DataFrame - context: 计算上下文 - - Examples: - >>> df = pl.DataFrame({ - ... "ts_code": ["000001.SZ"], - ... "trade_date": ["20240101"], - ... "close": [10.0] - ... }) - >>> context = FactorContext(current_date="20240101") - >>> data = FactorData(df, context) - """ - - def __init__(self, df: pl.DataFrame, context: FactorContext): - self._df = df - self._context = context - - def get_column(self, col: str) -> pl.Series: - """获取指定列的数据 - - 适用于两种因子类型: - - 截面因子:获取当天所有股票的该列值 - - 时序因子:获取该股票时间序列的该列值 - - Args: - col: 列名 - - Returns: - Polars Series - - Raises: - KeyError: 列不存在于数据中 - - Examples: - >>> prices = data.get_column("close") - >>> print(prices) - """ - if col not in self._df.columns: - raise KeyError( - f"Column '{col}' not found in data. Available columns: {self._df.columns}" - ) - return self._df[col] - - def filter_by_date(self, date: str) -> "FactorData": - """按日期过滤数据,返回新的 FactorData - - 主要用于截面因子获取特定日期的数据。 - 注意:无法获取未来日期的数据(引擎已经裁剪掉)。 - - Args: - date: YYYYMMDD 格式的日期 - - Returns: - 过滤后的 FactorData(新实例,不修改原数据) - - Examples: - >>> today_data = data.filter_by_date("20240101") - >>> print(len(today_data)) - """ - filtered = self._df.filter(pl.col("trade_date") == date) - return FactorData(filtered, self._context) - - def get_cross_section(self) -> pl.DataFrame: - """获取当前日期的截面数据 - - 仅适用于截面因子,返回 current_date 当天的所有股票数据。 - - Returns: - DataFrame 包含当前日期的所有股票 - - Raises: - ValueError: current_date 未设置(非截面因子场景) - - Examples: - >>> cs = data.get_cross_section() - >>> rankings = cs["pe"].rank() - """ - if self._context.current_date is None: - raise ValueError( - "current_date is not set in context. " - "get_cross_section() is only applicable for cross-sectional factors." - ) - return self._df.filter(pl.col("trade_date") == self._context.current_date) - - def to_polars(self) -> pl.DataFrame: - """获取底层的 Polars DataFrame(高级用法) - - 返回原始 DataFrame,允许进行自定义的 Polars 操作。 - 注意:直接操作底层数据可能绕过框架的防泄露保护,请谨慎使用。 - - Returns: - 底层的 Polars DataFrame - - Examples: - >>> df = data.to_polars() - >>> result = df.group_by("industry").agg(pl.col("pe").mean()) - """ - return self._df - - @property - def context(self) -> FactorContext: - """获取计算上下文 - - Returns: - 当前的 FactorContext 实例 - - Examples: - >>> date = data.context.current_date - >>> stock = data.context.current_stock - """ - return self._context - - def __len__(self) -> int: - """返回数据行数 - - Returns: - DataFrame 的行数 - - Examples: - >>> if len(data) > 0: - ... result = data.get_column("close").mean() - """ - return len(self._df) - - def __repr__(self) -> str: - """返回 FactorData 的字符串表示 - - Returns: - 包含类名、行数、列数和上下文信息的字符串 - """ - cols = self._df.columns - context_info = [] - if self._context.current_date: - context_info.append(f"date={self._context.current_date}") - if self._context.current_stock: - context_info.append(f"stock={self._context.current_stock}") - - context_str = ", ".join(context_info) if context_info else "no context" - return f"FactorData(rows={len(self)}, cols={len(cols)}, {context_str})" diff --git a/src/factors/financial/__init__.py b/src/factors/financial/__init__.py deleted file mode 100644 index 839e6f4..0000000 --- a/src/factors/financial/__init__.py +++ /dev/null @@ -1,20 +0,0 @@ -"""财务因子模块 - -本模块提供财务类型的因子: - -因子分类: -- financial: 财务因子 - - EPSFactor: 每股收益排名因子 - -已添加因子: -- EPSFactor: 每股收益排名(基于basic_eps) - -待添加因子: -- PERankFactor: 市盈率排名 -- PBFactor: 市净率因子 -- DividendFactor: 股息率因子 -""" - -from src.factors.financial.eps_factor import EPSFactor - -__all__ = ["EPSFactor"] diff --git a/src/factors/financial/eps_factor.py b/src/factors/financial/eps_factor.py deleted file mode 100644 index 4151b89..0000000 --- a/src/factors/financial/eps_factor.py +++ /dev/null @@ -1,66 +0,0 @@ -"""EPS因子 - -每股收益(EPS)排名因子实现 -""" - -from typing import List -import polars as pl - -from src.factors.base import CrossSectionalFactor -from src.factors.data_spec import DataSpec, FactorData - - -class EPSFactor(CrossSectionalFactor): - """每股收益(EPS)排名因子 - - 计算逻辑:使用最新报告期的basic_eps,每天对所有股票进行截面排名 - - Attributes: - name: 因子名称 "eps_rank" - category: 因子分类 "financial" - data_specs: 数据需求规格 - - Example: - >>> from src.factors import FactorEngine, DataLoader - >>> from src.factors.financial.eps_factor import EPSFactor - >>> loader = DataLoader('data') - >>> engine = FactorEngine(loader) - >>> eps_factor = EPSFactor() - >>> result = engine.compute(eps_factor, start_date='20210101', end_date='20210131') - """ - - name: str = "eps_rank" - category: str = "financial" - description: str = "每股收益截面排名因子" - data_specs: List[DataSpec] = [ - DataSpec( - "financial_income", ["ts_code", "trade_date", "basic_eps"], lookback_days=1 - ) - ] - - def compute(self, data: FactorData) -> pl.Series: - """计算EPS排名 - - Args: - data: FactorData,包含当前日期的截面数据 - - Returns: - EPS排名的0-1标准化值(0-1之间) - """ - # 获取当前日期的截面数据 - cs = data.get_cross_section() - - if len(cs) == 0: - return pl.Series(name=self.name, values=[]) - - # 提取EPS值,填充缺失值为0 - eps = cs["basic_eps"].fill_null(0) - - # 计算排名并归一化到0-1 - if len(eps) > 1 and eps.max() != eps.min(): - ranks = eps.rank(method="average") / len(eps) - else: - # 数据不足或全部相同,返回0.5 - ranks = pl.Series(name=self.name, values=[0.5] * len(eps)) - - return ranks diff --git a/src/factors/financial/utils.py b/src/factors/financial/utils.py deleted file mode 100644 index cb656bc..0000000 --- a/src/factors/financial/utils.py +++ /dev/null @@ -1,82 +0,0 @@ -"""财务因子工具函数 - -提供财务数据处理的工具函数: -- expand_period_to_trading_days: 将报告期数据展开到每个交易日(前向填充) -""" - -from typing import List -import polars as pl - - -def expand_period_to_trading_days( - financial_df: pl.DataFrame, - trade_dates: List[str], -) -> pl.DataFrame: - """将财务数据(报告期粒度)展开到每个交易日(前向填充) - - 核心逻辑:对于每个交易日,找到该日期之前最新的已公告报告期数据。 - 例如:2020年报(20201231)公告于20210428,则在2021-04-28之后的每个 - 交易日都使用该年报数据,直到2021一季报公告。 - - Args: - financial_df: 财务数据DataFrame,包含 ts_code, ann_date, end_date, ... - trade_dates: 交易日列表(YYYYMMDD格式,已排序) - - Returns: - DataFrame,包含 trade_date, ts_code 和所有财务字段 - - Example: - >>> financial_df = pl.DataFrame({ - ... 'ts_code': ['000001.SZ'], - ... 'ann_date': ['20210428'], - ... 'end_date': ['20210331'], - ... 'basic_eps': [0.5] - ... }) - >>> trade_dates = ['20210428', '20210429', '20210430'] - >>> result = expand_period_to_trading_days(financial_df, trade_dates) - >>> print(result) - shape: (3, 5) - ┌───────────┬───────────┬────────────┬────────────┬───────────┐ - │ ts_code ┆ ann_date ┆ end_date ┆ basic_eps ┆ trade_date│ - │ --- ┆ --- ┆ --- ┆ --- ┆ --- │ - │ str ┆ str ┆ str ┆ f64 ┆ str │ - ╞═══════════╪═══════════╪════════════╪════════════╪═══════════╡ - │ 000001.SZ ┆ 20210428 ┆ 20210331 ┆ 0.5 ┆ 20210428 │ - │ 000001.SZ ┆ 20210428 ┆ 20210331 ┆ 0.5 ┆ 20210429 │ - │ 000001.SZ ┆ 20210428 ┆ 20210331 ┆ 0.5 ┆ 20210430 │ - └───────────┴───────────┴────────────┴────────────┴───────────┘ - """ - if len(financial_df) == 0: - return pl.DataFrame() - - results = [] - - # 按股票分组处理 - for ts_code in financial_df["ts_code"].unique(): - stock_data = financial_df.filter(pl.col("ts_code") == ts_code) - - # 按报告期排序(end_date升序) - stock_data = stock_data.sort("end_date") - - rows = [] - for trade_date in trade_dates: - # 找到该交易日之前最新的已公告报告期 - # 条件1: end_date <= trade_date(报告期不晚于交易日) - # 条件2: ann_date <= trade_date(已公告) - applicable = stock_data.filter( - (pl.col("end_date") <= trade_date) & (pl.col("ann_date") <= trade_date) - ) - - if len(applicable) > 0: - # 取最新的一条(end_date最大的) - latest = applicable.tail(1).with_columns( - [pl.lit(trade_date).alias("trade_date")] - ) - rows.append(latest) - - if rows: - results.append(pl.concat(rows)) - - if results: - return pl.concat(results) - return pl.DataFrame() diff --git a/src/factors/momentum/__init__.py b/src/factors/momentum/__init__.py deleted file mode 100644 index 4b8b401..0000000 --- a/src/factors/momentum/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -"""动量因子模块 - -本模块提供动量类型的因子: -- MovingAverageFactor: 移动平均线(时序因子) -- ReturnRankFactor: 收益率排名(截面因子) - -因子分类: -- momentum: 动量因子 - - ma: 移动平均线 - - return_rank: 收益率排名 -""" - -from src.factors.momentum.ma import MovingAverageFactor -from src.factors.momentum.return_rank import ReturnRankFactor - -__all__ = [ - "MovingAverageFactor", - "ReturnRankFactor", -] diff --git a/src/factors/momentum/ma.py b/src/factors/momentum/ma.py deleted file mode 100644 index 05ea234..0000000 --- a/src/factors/momentum/ma.py +++ /dev/null @@ -1,78 +0,0 @@ -"""动量因子 - 移动平均线 - -本模块提供通用移动平均线因子,支持参数化配置: -- MovingAverageFactor: 移动平均线(时序因子) - -使用示例: - >>> from src.factors.momentum import MovingAverageFactor - >>> ma5 = MovingAverageFactor(period=5) # 5日MA - >>> ma10 = MovingAverageFactor(period=10) # 10日MA - >>> ma20 = MovingAverageFactor(period=20) # 20日MA -""" - -from typing import List - -import polars as pl - -from src.factors.base import TimeSeriesFactor -from src.factors.data_spec import DataSpec, FactorData - - -class MovingAverageFactor(TimeSeriesFactor): - """移动平均线因子 - - 计算逻辑:对每只股票,计算其过去n日收盘价的移动平均值。 - - 特点: - - 参数化因子:训练时通过 period 参数指定计算窗口 - - 时序因子:每只股票单独计算,防止股票间数据泄露 - - Attributes: - period: MA计算期(天数),默认5 - - Example: - >>> ma5 = MovingAverageFactor(period=5) - >>> # 计算过去5日的收盘价均值 - """ - - name: str = "ma" - factor_type: str = "time_series" - category: str = "momentum" - description: str = "移动平均线因子,计算过去n日收盘价的均值" - data_specs: List[DataSpec] = [ - DataSpec("daily", ["ts_code", "trade_date", "close"], lookback_days=5) - ] - - def __init__(self, period: int = 5): - """初始化因子 - - Args: - period: MA计算期(天数),默认5日 - """ - super().__init__(period=period) - # 重新创建 DataSpec 以设置正确的 lookback_days(DataSpec 是 frozen 的) - self.data_specs = [ - DataSpec( - "daily", - ["ts_code", "trade_date", "close"], - lookback_days=period, - ) - ] - self.name = f"ma_{period}" - - def compute(self, data: FactorData) -> pl.Series: - """计算移动平均线 - - Args: - data: FactorData,包含单只股票的完整时间序列 - - Returns: - 移动平均值序列 - """ - # 获取收盘价序列 - close_prices = data.get_column("close") - - # 计算移动平均 - ma = close_prices.rolling_mean(window_size=self.params["period"]) - - return ma diff --git a/src/factors/momentum/return_rank.py b/src/factors/momentum/return_rank.py deleted file mode 100644 index ee7514b..0000000 --- a/src/factors/momentum/return_rank.py +++ /dev/null @@ -1,100 +0,0 @@ -"""动量因子 - 收益率排名 - -本模块提供收益率排名因子: -- ReturnRankFactor: 过去n日收益率的rank因子(截面因子) - -使用示例: - >>> from src.factors.momentum import ReturnRankFactor - >>> ret5 = ReturnRankFactor(period=5) # 5日收益率排名 - >>> ret10 = ReturnRankFactor(period=10) # 10日收益率排名 -""" - -from typing import List - -import polars as pl - -from src.factors.base import CrossSectionalFactor -from src.factors.data_spec import DataSpec, FactorData - - -class ReturnRankFactor(CrossSectionalFactor): - """过去n日收益率排名因子 - - 计算逻辑:每个交易日,计算所有股票过去n日的收益率,然后进行截面排名。 - - 特点: - - 参数化因子:训练时通过 period 参数指定计算窗口 - - 截面因子:每天对所有股票进行横向排名,防止日期泄露 - - Attributes: - period: 收益率计算期(默认5日) - - Example: - >>> ret5 = ReturnRankFactor(period=5) - >>> # 每个交易日,返回所有股票过去5日收益率的排名 - """ - - name: str = "return_rank" - factor_type: str = "cross_sectional" - category: str = "momentum" - description: str = "过去n日收益率的截面排名因子" - data_specs: List[DataSpec] = [ - DataSpec("daily", ["ts_code", "trade_date", "close"], lookback_days=5) - ] - - def __init__(self, period: int = 5): - """初始化因子 - - Args: - period: 收益率计算期(天数) - """ - super().__init__(period=period) - # 重新创建 DataSpec 以设置正确的 lookback_days(DataSpec 是 frozen 的) - self.data_specs = [ - DataSpec( - "daily", - ["ts_code", "trade_date", "close"], - lookback_days=period + 1, - ) - ] - self.name = f"return_{period}_rank" - - def compute(self, data: FactorData) -> pl.Series: - """计算过去n日收益率排名 - - Args: - data: FactorData,包含过去n+1天的截面数据 - - Returns: - 过去n日收益率的截面排名(0-1之间) - """ - # 获取当前日期的截面数据 - cs = data.to_polars() - - # 获取所有交易日期(已按日期排序) - trade_dates = cs["trade_date"].unique().sort() - - if len(trade_dates) < 2: - # 数据不足,返回空排名 - return pl.Series(name=self.name, values=[]) - - # 获取最新日期的数据 - latest_date = trade_dates[-1] - current_data = cs.filter(pl.col("trade_date") == latest_date) - - # 获取n天前的日期 - n_days_ago = trade_dates[-(self.params["period"] + 1)] - past_data = cs.filter(pl.col("trade_date") == n_days_ago) - - # 通过 ts_code join 计算收益率 - merged = current_data.select(["ts_code", "close"]).join( - past_data.select(["ts_code", "close"]).rename({"close": "close_past"}), - on="ts_code", - how="inner", - ) - - # 计算收益率 - returns = (merged["close"] - merged["close_past"]) / merged["close_past"] - - # 返回排名(0-1之间) - return returns.rank(method="average") / len(returns) diff --git a/src/factors/quality/__init__.py b/src/factors/quality/__init__.py deleted file mode 100644 index 646f1f9..0000000 --- a/src/factors/quality/__init__.py +++ /dev/null @@ -1,20 +0,0 @@ -"""质量因子模块 - -本模块提供质量类因子: -- 盈利能力:ROE、ROA、毛利率、净利率 -- 盈利稳定性:盈利波动率、盈利持续性 -- 财务健康度:资产负债率、流动比率等 - -使用示例: - >>> from src.factors.quality import ROEFactor - >>> factor = ROEFactor() -""" - -# 在此处导入具体的质量因子 -# from .roe import ROEFactor -# from .roa import ROAFactor -# from .profit_stability import ProfitStabilityFactor - -__all__ = [ - # 添加你的质量因子 -] diff --git a/src/factors/sentiment/__init__.py b/src/factors/sentiment/__init__.py deleted file mode 100644 index c15a1af..0000000 --- a/src/factors/sentiment/__init__.py +++ /dev/null @@ -1,20 +0,0 @@ -"""情绪因子模块 - -本模块提供市场情绪类因子: -- 换手率、换手率变化率 -- 资金流向、主力净流入 -- 波动率、振幅等 - -使用示例: - >>> from src.factors.sentiment import TurnoverFactor - >>> factor = TurnoverFactor(period=20) -""" - -# 在此处导入具体的情绪因子 -# from .turnover import TurnoverFactor -# from .money_flow import MoneyFlowFactor -# from .amplitude import AmplitudeFactor - -__all__ = [ - # 添加你的情绪因子 -] diff --git a/src/factors/technical/__init__.py b/src/factors/technical/__init__.py deleted file mode 100644 index 45b5c36..0000000 --- a/src/factors/technical/__init__.py +++ /dev/null @@ -1,20 +0,0 @@ -"""技术指标因子模块 - -本模块提供技术分析类因子: -- 移动平均线(MA)、指数移动平均(EMA) -- 相对强弱指标(RSI)、MACD、KDJ -- 布林带(Bollinger Bands)等 - -使用示例: - >>> from src.factors.technical import RSIFactor - >>> factor = RSIFactor(period=14) -""" - -# 在此处导入具体的技术指标因子 -# from .rsi import RSIFactor -# from .macd import MACDFactor -# from .bollinger import BollingerFactor - -__all__ = [ - # 添加你的技术指标因子 -] diff --git a/src/factors/valuation/__init__.py b/src/factors/valuation/__init__.py deleted file mode 100644 index 17a9b25..0000000 --- a/src/factors/valuation/__init__.py +++ /dev/null @@ -1,18 +0,0 @@ -"""估值因子模块 - -本模块提供估值类因子: -- 市盈率(PE)、市净率(PB)、市销率(PS)等估值指标 -- 估值排名、估值分位数等衍生因子 - -使用示例: - >>> from src.factors.valuation import PERankFactor - >>> factor = PERankFactor() -""" - -# 在此处导入具体的估值因子 -# from .pe_rank import PERankFactor -# from .pb_rank import PBRankFactor - -__all__ = [ - # 添加你的估值因子 -] diff --git a/src/factors/volatility/__init__.py b/src/factors/volatility/__init__.py deleted file mode 100644 index 4a95d03..0000000 --- a/src/factors/volatility/__init__.py +++ /dev/null @@ -1,21 +0,0 @@ -"""波动率因子模块 - -本模块提供波动率相关因子: -- 历史波动率(Historical Volatility) -- 实现波动率(Realized Volatility) -- GARCH类波动率预测 -- 波动率风险指标等 - -使用示例: - >>> from src.factors.volatility import HistoricalVolFactor - >>> factor = HistoricalVolFactor(period=20) -""" - -# 在此处导入具体的波动率因子 -# from .historical_vol import HistoricalVolFactor -# from .realized_vol import RealizedVolFactor -# from .garch_vol import GARCHVolFactor - -__all__ = [ - # 添加你的波动率因子 -] diff --git a/src/factors/volume/__init__.py b/src/factors/volume/__init__.py deleted file mode 100644 index e3bc180..0000000 --- a/src/factors/volume/__init__.py +++ /dev/null @@ -1,20 +0,0 @@ -"""成交量因子模块 - -本模块提供成交量相关因子: -- 成交量移动平均 -- 成交量比率(VR)、能量潮(OBV) -- 量价配合指标等 - -使用示例: - >>> from src.factors.volume import OBVFactor - >>> factor = OBVFactor() -""" - -# 在此处导入具体的成交量因子 -# from .obv import OBVFactor -# from .volume_ratio import VolumeRatioFactor -# from .volume_ma import VolumeMAFactor - -__all__ = [ - # 添加你的成交量因子 -] diff --git a/tests/factors/__init__.py b/tests/factors/__init__.py deleted file mode 100644 index ab588e4..0000000 --- a/tests/factors/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -# tests/factors/__init__.py -"""Factors 模块测试包""" diff --git a/tests/factors/factor_test_report.md b/tests/factors/factor_test_report.md deleted file mode 100644 index 9b67336..0000000 --- a/tests/factors/factor_test_report.md +++ /dev/null @@ -1,143 +0,0 @@ -# 因子真实数据测试报告 - -## 1. 测试概述 - -本测试使用 `daily.h5` 文件中的真实A股市场数据,对 ProStock 因子框架进行验证。测试对比了因子框架计算结果与 Polars 原生计算结果,验证因子计算的正确性。 - -### 测试数据 -- **数据源**: `data/daily.h5` -- **时间范围**: 2024-01-01 至 2024-04-30 -- **股票数量**: 20只 -- **数据量**: 1,560条记录 - -### 测试因子类型 -1. **时序因子 (TimeSeriesFactor)**: 移动平均线 (MA) -2. **截面因子 (CrossSectionalFactor)**: PE排名 (PE_Rank) -3. **结合因子 (CompositeFactor)**: 标量组合 (0.5 * MA) 和因子加法 (MA5 + MA10) - ---- - -## 2. 测试结果 - -### 2.1 时序因子测试 - MA(5) - -``` -[时序因子 MA(5) 对比] - 样本股票: 000001.SZ - 有效数据点: 77 - 最大差异: 0.000000000000000 - 样本数据 (前5个): - Polars: 10.022000, Factor: 10.022000, Diff: 0.000000000000000 - Polars: 10.046000, Factor: 10.046000, Diff: 0.000000000000000 - Polars: 10.056000, Factor: 10.056000, Diff: 0.000000000000000 - Polars: 10.072000, Factor: 10.072000, Diff: 0.000000000000000 - Polars: 10.078000, Factor: 10.078000, Diff: 0.000000000000000 -``` - -**结论**: ✅ **通过** - 因子框架计算的 MA(5) 与 Polars 原生计算完全一致 - ---- - -### 2.2 截面因子测试 - PE_Rank - -``` -[截面因子 PE_Rank 对比] - 样本日期: 20240131 - 股票数量: 20 - 最大差异: 0.000000000000000 - 样本数据 (前5个): - 000001.SZ: Polars: 0.050000, Factor: 0.050000 - 000002.SZ: Polars: 0.550000, Factor: 0.550000 - 000004.SZ: Polars: 0.300000, Factor: 0.300000 - 000005.SZ: Polars: 0.100000, Factor: 0.100000 - 000006.SZ: Polars: 0.400000, Factor: 0.400000 -``` - -**结论**: ✅ **通过** - 因子框架计算的 PE_Rank 与 Polars 原生计算完全一致 - ---- - -### 2.3 结合因子测试 - 0.5 * MA(5) - -``` -[结合因子 0.5*MA(5) 对比] - 公式: 0.5 * MA(5) - 有效数据点: 77 - 最大差异: 0.000000000000000 - 样本数据 (前5个): - Polars: 5.011000, Factor: 5.011000, Diff: 0.000000000000000 - Polars: 5.023000, Factor: 5.023000, Diff: 0.000000000000000 - Polars: 5.028000, Factor: 5.028000, Diff: 0.000000000000000 - Polars: 5.036000, Factor: 5.036000, Diff: 0.000000000000000 - Polars: 5.039000, Factor: 5.039000, Diff: 0.000000000000000 -``` - -**结论**: ✅ **通过** - 标量组合因子计算正确 - ---- - -### 2.4 结合因子测试 - MA(5) + MA(10) - -``` -[结合因子 MA(5) + MA(10) 对比] - 有效数据点: 72 - 最大差异: 0.000000000000000 -``` - -**结论**: ✅ **通过** - 因子加法组合计算正确 - ---- - -## 3. 综合测试汇总 - -``` -============================================================ -因子测试汇总 -============================================================ - MA(5): 最大差异 = 0.00e+00 通过 - MA(10): 最大差异 = 0.00e+00 通过 - MA(20): 最大差异 = 0.00e+00 通过 - PE_Rank: 最大差异 = 0.00e+00 通过 -============================================================ -``` - ---- - -## 4. 测试结论 - -### 4.1 全部通过 ✅ - -所有5个测试用例均通过验证: - -| 测试项目 | 因子类型 | 最大差异 | 状态 | -|---------|---------|---------|------| -| MA(5) | 时序因子 | 0.00e+00 | ✅ 通过 | -| MA(10) | 时序因子 | 0.00e+00 | ✅ 通过 | -| MA(20) | 时序因子 | 0.00e+00 | ✅ 通过 | -| PE_Rank | 截面因子 | 0.00e+00 | ✅ 通过 | -| 0.5 * MA(5) | 结合因子 | 0.00e+00 | ✅ 通过 | -| MA(5) + MA(10) | 结合因子 | 0.00e+00 | ✅ 通过 | - -### 4.2 关键发现 - -1. **计算精度**: 因子框架与 Polars 原生计算结果的差异为 0,表明计算精度完全一致 -2. **时序因子**: `TimeSeriesFactor` 基类正确实现了股票级别的时序计算 -3. **截面因子**: `CrossSectionalFactor` 基类正确实现了日期级别的截面计算 -4. **组合因子**: `ScalarFactor` 和 `CompositeFactor` 正确实现了标量运算和因子组合 - -### 4.3 验证结论 - -ProStock 因子框架的计算逻辑与 Polars 原生计算完全一致,框架设计正确,可以用于实际量化投资研究。 - ---- - -## 5. 测试环境 - -- Python: 3.13.2 -- Polars: 最新版本 -- Pytest: 9.0.2 -- 数据: daily.h5 (8,856,081 条记录) - ---- - -*报告生成时间: 2026-02-22* diff --git a/tests/factors/test_base.py b/tests/factors/test_base.py deleted file mode 100644 index f0ed632..0000000 --- a/tests/factors/test_base.py +++ /dev/null @@ -1,406 +0,0 @@ -"""测试因子基类 - BaseFactor、CrossSectionalFactor、TimeSeriesFactor - -测试需求(来自 factor_implementation_plan.md): -- BaseFactor: - - 测试有效子类创建通过验证 - - 测试缺少 `name` 时抛出 ValueError - - 测试 `name` 为空字符串时抛出 ValueError - - 测试缺少 `factor_type` 时抛出 ValueError - - 测试无效的 `factor_type`(非 cs/ts)时抛出 ValueError - - 测试缺少 `data_specs` 时抛出 ValueError - - 测试 `data_specs` 为空列表时抛出 ValueError - - 测试 `compute()` 抽象方法强制子类实现 - - 测试参数化初始化 `params` 正确存储 - - 测试 `_validate_params()` 被调用 - -- CrossSectionalFactor: - - 测试 `factor_type` 自动设置为 "cross_sectional" - - 测试子类必须实现 `compute()` - - 测试 `compute()` 返回类型为 pl.Series - -- TimeSeriesFactor: - - 测试 `factor_type` 自动设置为 "time_series" - - 测试子类必须实现 `compute()` - - 测试 `compute()` 返回类型为 pl.Series -""" - -import pytest -import polars as pl - -from src.factors import DataSpec, FactorContext, FactorData -from src.factors.base import BaseFactor, CrossSectionalFactor, TimeSeriesFactor - - -# ========== 测试数据准备 ========== - - -@pytest.fixture -def sample_dataspec(): - """创建一个示例 DataSpec""" - return DataSpec( - source="daily", columns=["ts_code", "trade_date", "close"], lookback_days=5 - ) - - -@pytest.fixture -def sample_factor_data(): - """创建一个示例 FactorData""" - df = pl.DataFrame( - { - "ts_code": ["000001.SZ", "000002.SZ", "000001.SZ", "000002.SZ"], - "trade_date": ["20240101", "20240101", "20240102", "20240102"], - "close": [10.0, 20.0, 11.0, 21.0], - } - ) - context = FactorContext(current_date="20240102") - return FactorData(df, context) - - -# ========== BaseFactor 测试 ========== - - -class TestBaseFactorValidation: - """测试 BaseFactor 子类验证""" - - def test_valid_cross_sectional_subclass(self, sample_dataspec): - """测试有效的截面因子子类创建通过验证""" - - class ValidFactor(CrossSectionalFactor): - name = "valid_cs" - data_specs = [sample_dataspec] - - def compute(self, data): - return pl.Series([1.0]) - - # 应该能成功创建实例 - factor = ValidFactor() - assert factor.name == "valid_cs" - assert factor.factor_type == "cross_sectional" - - def test_valid_time_series_subclass(self, sample_dataspec): - """测试有效的时序因子子类创建通过验证""" - - class ValidFactor(TimeSeriesFactor): - name = "valid_ts" - data_specs = [sample_dataspec] - - def compute(self, data): - return pl.Series([1.0]) - - factor = ValidFactor() - assert factor.name == "valid_ts" - assert factor.factor_type == "time_series" - - def test_missing_name(self, sample_dataspec): - """测试缺少 name 时抛出 ValueError""" - with pytest.raises(ValueError, match="must define 'name'"): - - class BadFactor(CrossSectionalFactor): - # name = "" # 故意不定义 - data_specs = [sample_dataspec] - - def compute(self, data): - return pl.Series([1.0]) - - def test_empty_name(self, sample_dataspec): - """测试 name 为空字符串时抛出 ValueError""" - with pytest.raises(ValueError, match="must define 'name'"): - - class BadFactor(CrossSectionalFactor): - name = "" # 空字符串 - data_specs = [sample_dataspec] - - def compute(self, data): - return pl.Series([1.0]) - - def test_missing_factor_type(self, sample_dataspec): - """测试缺少 factor_type 时抛出 ValueError""" - with pytest.raises(ValueError, match="must define 'factor_type'"): - - class BadFactor(BaseFactor): - name = "bad_factor" - # factor_type = "" # 故意不定义 - data_specs = [sample_dataspec] - - def compute(self, data): - return pl.Series([1.0]) - - def test_invalid_factor_type(self, sample_dataspec): - """测试无效的 factor_type(非 cs/ts)时抛出 ValueError""" - with pytest.raises( - ValueError, match="must be 'cross_sectional' or 'time_series'" - ): - - class BadFactor(BaseFactor): - name = "bad_factor" - factor_type = "invalid_type" - data_specs = [sample_dataspec] - - def compute(self, data): - return pl.Series([1.0]) - - def test_missing_data_specs(self): - """测试缺少 data_specs 时抛出 ValueError""" - with pytest.raises(ValueError, match="must define 'data_specs'"): - - class BadFactor(BaseFactor): - name = "bad_factor" - factor_type = "cross_sectional" - # data_specs = [] # 故意不定义 - - def compute(self, data): - return pl.Series([1.0]) - - def test_empty_data_specs(self): - """测试 data_specs 为空列表时抛出 ValueError""" - with pytest.raises(ValueError, match="cannot be empty"): - - class BadFactor(CrossSectionalFactor): - name = "bad_factor" - data_specs = [] # 空列表 - - def compute(self, data): - return pl.Series([1.0]) - - -class TestBaseFactorCompute: - """测试 compute() 抽象方法""" - - def test_compute_must_be_implemented_cs(self, sample_dataspec): - """测试截面因子子类必须实现 compute()""" - with pytest.raises(TypeError, match="Can't instantiate abstract class"): - - class BadFactor(CrossSectionalFactor): - name = "bad_cs" - data_specs = [sample_dataspec] - # 不实现 compute() - - BadFactor() - - def test_compute_must_be_implemented_ts(self, sample_dataspec): - """测试时序因子子类必须实现 compute()""" - with pytest.raises(TypeError, match="Can't instantiate abstract class"): - - class BadFactor(TimeSeriesFactor): - name = "bad_ts" - data_specs = [sample_dataspec] - # 不实现 compute() - - BadFactor() - - -class TestBaseFactorParams: - """测试参数化初始化""" - - def test_params_stored_correctly(self, sample_dataspec): - """测试参数化初始化 params 正确存储""" - - class ParamFactor(CrossSectionalFactor): - name = "param_factor" - data_specs = [sample_dataspec] - - def __init__(self, period: int = 20, weight: float = 1.0): - super().__init__(period=period, weight=weight) - - def compute(self, data): - return pl.Series([1.0]) - - factor = ParamFactor(period=10, weight=0.5) - assert factor.params["period"] == 10 - assert factor.params["weight"] == 0.5 - - def test_validate_params_called(self, sample_dataspec): - """测试 _validate_params() 被调用""" - validated = [] - - class ValidatedFactor(CrossSectionalFactor): - name = "validated_factor" - data_specs = [sample_dataspec] - - def __init__(self, period: int = 20): - super().__init__(period=period) - - def _validate_params(self): - validated.append(True) - if self.params.get("period", 0) <= 0: - raise ValueError("period must be positive") - - def compute(self, data): - return pl.Series([1.0]) - - # 创建实例时应该调用 _validate_params - factor = ValidatedFactor(period=10) - assert len(validated) == 1 - assert factor.params["period"] == 10 - - def test_validate_params_raises(self, sample_dataspec): - """测试 _validate_params() 可以抛出异常""" - - class BadParamFactor(CrossSectionalFactor): - name = "bad_param_factor" - data_specs = [sample_dataspec] - - def __init__(self, period: int = 20): - super().__init__(period=period) - - def _validate_params(self): - if self.params.get("period", 0) <= 0: - raise ValueError("period must be positive") - - def compute(self, data): - return pl.Series([1.0]) - - with pytest.raises(ValueError, match="period must be positive"): - BadParamFactor(period=-5) - - -class TestBaseFactorRepr: - """测试 __repr__""" - - def test_repr(self, sample_dataspec): - """测试 __repr__ 返回正确格式""" - - class TestFactor(CrossSectionalFactor): - name = "test_factor" - data_specs = [sample_dataspec] - - def compute(self, data): - return pl.Series([1.0]) - - factor = TestFactor() - repr_str = repr(factor) - assert "TestFactor" in repr_str - assert "test_factor" in repr_str - assert "cross_sectional" in repr_str - - -# ========== CrossSectionalFactor 测试 ========== - - -class TestCrossSectionalFactor: - """测试 CrossSectionalFactor""" - - def test_factor_type_auto_set(self, sample_dataspec): - """测试 factor_type 自动设置为 'cross_sectional'""" - - class CSFactor(CrossSectionalFactor): - name = "cs_factor" - data_specs = [sample_dataspec] - - def compute(self, data): - return pl.Series([1.0]) - - factor = CSFactor() - assert factor.factor_type == "cross_sectional" - - def test_compute_returns_series(self, sample_factor_data, sample_dataspec): - """测试 compute() 返回类型为 pl.Series""" - - class CSFactor(CrossSectionalFactor): - name = "cs_factor" - data_specs = [sample_dataspec] - - def compute(self, data): - # 返回一个简单的 Series - return pl.Series([1.0, 2.0]) - - factor = CSFactor() - result = factor.compute(sample_factor_data) - assert isinstance(result, pl.Series) - - def test_compute_with_cross_section(self, sample_dataspec): - """测试 compute() 使用 get_cross_section()""" - df = pl.DataFrame( - { - "ts_code": ["000001.SZ", "000002.SZ"], - "trade_date": ["20240101", "20240101"], - "close": [10.0, 20.0], - } - ) - context = FactorContext(current_date="20240101") - data = FactorData(df, context) - - class RankFactor(CrossSectionalFactor): - name = "rank_factor" - data_specs = [sample_dataspec] - - def compute(self, data): - cs = data.get_cross_section() - return cs["close"].rank() - - factor = RankFactor() - result = factor.compute(data) - assert isinstance(result, pl.Series) - assert len(result) == 2 - - -# ========== TimeSeriesFactor 测试 ========== - - -class TestTimeSeriesFactor: - """测试 TimeSeriesFactor""" - - def test_factor_type_auto_set(self, sample_dataspec): - """测试 factor_type 自动设置为 'time_series'""" - - class TSFactor(TimeSeriesFactor): - name = "ts_factor" - data_specs = [sample_dataspec] - - def compute(self, data): - return pl.Series([1.0]) - - factor = TSFactor() - assert factor.factor_type == "time_series" - - def test_compute_returns_series(self, sample_factor_data, sample_dataspec): - """测试 compute() 返回类型为 pl.Series""" - - class TSFactor(TimeSeriesFactor): - name = "ts_factor" - data_specs = [sample_dataspec] - - def compute(self, data): - return data.get_column("close") * 2 - - factor = TSFactor() - result = factor.compute(sample_factor_data) - assert isinstance(result, pl.Series) - assert len(result) == 4 # 4行数据 - - def test_compute_with_rolling(self, sample_dataspec): - """测试 compute() 使用 rolling 操作""" - df = pl.DataFrame( - { - "ts_code": ["000001.SZ"] * 5, - "trade_date": [ - "20240101", - "20240102", - "20240103", - "20240104", - "20240105", - ], - "close": [10.0, 11.0, 12.0, 13.0, 14.0], - } - ) - context = FactorContext(current_stock="000001.SZ") - data = FactorData(df, context) - - class MAFactor(TimeSeriesFactor): - name = "ma_factor" - data_specs = [sample_dataspec] - - def __init__(self, period: int = 3): - super().__init__(period=period) - - def compute(self, data): - return data.get_column("close").rolling_mean(self.params["period"]) - - factor = MAFactor(period=3) - result = factor.compute(data) - assert isinstance(result, pl.Series) - assert len(result) == 5 - # 前2个应该是 null(因为 period=3) - assert result[0] is None - assert result[1] is None - assert result[2] is not None diff --git a/tests/factors/test_composite.py b/tests/factors/test_composite.py deleted file mode 100644 index d181801..0000000 --- a/tests/factors/test_composite.py +++ /dev/null @@ -1,417 +0,0 @@ -"""测试组合因子 - CompositeFactor、ScalarFactor - -测试需求(来自 factor_implementation_plan.md): -- CompositeFactor: - - 测试同类型因子组合成功(cs + cs) - - 测试同类型因子组合成功(ts + ts) - - 测试不同类型因子组合抛出 ValueError(cs + ts) - - 测试无效运算符抛出 ValueError - - 测试 `_merge_data_specs()` 正确合并(相同 source) - - 测试 `_merge_data_specs()` 正确合并(不同 source) - - 测试 `_merge_data_specs()` lookback 取最大值 - - 测试 `compute()` 执行正确的数学运算 - -- ScalarFactor: - - 测试标量乘法 `0.5 * factor` - - 测试标量乘法 `factor * 0.5` - - 测试标量加法(如支持) - - 测试继承基础因子的 data_specs - - 测试 `compute()` 返回正确缩放后的值 -""" - -import pytest -import polars as pl - -from src.factors import DataSpec, FactorContext, FactorData -from src.factors.base import CrossSectionalFactor, TimeSeriesFactor -from src.factors.composite import CompositeFactor, ScalarFactor - - -# ========== 测试数据准备 ========== - - -@pytest.fixture -def sample_dataspec(): - """创建一个示例 DataSpec""" - return DataSpec( - source="daily", columns=["ts_code", "trade_date", "close"], lookback_days=5 - ) - - -@pytest.fixture -def cs_factor1(sample_dataspec): - """截面因子 1""" - - class CSFactor1(CrossSectionalFactor): - name = "cs_factor1" - data_specs = [sample_dataspec] - - def compute(self, data): - return pl.Series([1.0, 2.0]) - - return CSFactor1() - - -@pytest.fixture -def cs_factor2(sample_dataspec): - """截面因子 2""" - - class CSFactor2(CrossSectionalFactor): - name = "cs_factor2" - data_specs = [sample_dataspec] - - def compute(self, data): - return pl.Series([3.0, 4.0]) - - return CSFactor2() - - -@pytest.fixture -def ts_factor1(sample_dataspec): - """时序因子 1""" - - class TSFactor1(TimeSeriesFactor): - name = "ts_factor1" - data_specs = [sample_dataspec] - - def compute(self, data): - return pl.Series([10.0, 20.0, 30.0]) - - return TSFactor1() - - -@pytest.fixture -def ts_factor2(sample_dataspec): - """时序因子 2""" - - class TSFactor2(TimeSeriesFactor): - name = "ts_factor2" - data_specs = [sample_dataspec] - - def compute(self, data): - return pl.Series([1.0, 2.0, 3.0]) - - return TSFactor2() - - -@pytest.fixture -def sample_factor_data(): - """创建一个示例 FactorData""" - df = pl.DataFrame( - { - "ts_code": ["000001.SZ", "000002.SZ", "000001.SZ", "000002.SZ"], - "trade_date": ["20240101", "20240101", "20240102", "20240102"], - "close": [10.0, 20.0, 11.0, 21.0], - } - ) - context = FactorContext(current_date="20240102") - return FactorData(df, context) - - -# ========== CompositeFactor 测试 ========== - - -class TestCompositeFactorTypeValidation: - """测试类型验证""" - - def test_same_type_combination_cs(self, cs_factor1, cs_factor2): - """测试同类型截面因子组合成功""" - combined = cs_factor1 + cs_factor2 - assert isinstance(combined, CompositeFactor) - assert combined.factor_type == "cross_sectional" - assert combined.name == "(cs_factor1_+_cs_factor2)" - - def test_same_type_combination_ts(self, ts_factor1, ts_factor2): - """测试同类型时序因子组合成功""" - combined = ts_factor1 - ts_factor2 - assert isinstance(combined, CompositeFactor) - assert combined.factor_type == "time_series" - assert combined.name == "(ts_factor1_-_ts_factor2)" - - def test_different_type_raises(self, cs_factor1, ts_factor1): - """测试不同类型因子组合抛出 ValueError""" - with pytest.raises( - ValueError, match="Cannot combine factors of different types" - ): - cs_factor1 + ts_factor1 - - def test_invalid_operator_raises(self, cs_factor1, cs_factor2): - """测试无效运算符抛出 ValueError""" - with pytest.raises(ValueError, match="Unsupported operator"): - CompositeFactor(cs_factor1, cs_factor2, "%") - - -class TestCompositeFactorMergeDataSpecs: - """测试 _merge_data_specs""" - - def test_merge_same_source_same_columns(self): - """测试相同 source 和 columns 的 DataSpec 合并""" - spec1 = DataSpec("daily", ["ts_code", "trade_date", "close"], lookback_days=5) - spec2 = DataSpec("daily", ["ts_code", "trade_date", "close"], lookback_days=10) - - class Factor1(CrossSectionalFactor): - name = "f1" - data_specs = [spec1] - - def compute(self, data): - return pl.Series([1.0]) - - class Factor2(CrossSectionalFactor): - name = "f2" - data_specs = [spec2] - - def compute(self, data): - return pl.Series([2.0]) - - combined = Factor1() + Factor2() - - # 应该合并成一个 DataSpec,lookback_days 取最大值 10 - assert len(combined.data_specs) == 1 - assert combined.data_specs[0].lookback_days == 10 - assert combined.data_specs[0].source == "daily" - - def test_merge_different_source(self): - """测试不同 source 的 DataSpec 不合并""" - spec1 = DataSpec("daily", ["ts_code", "trade_date", "close"], lookback_days=5) - spec2 = DataSpec( - "fundamental", ["ts_code", "trade_date", "pe"], lookback_days=1 - ) - - class Factor1(CrossSectionalFactor): - name = "f1" - data_specs = [spec1] - - def compute(self, data): - return pl.Series([1.0]) - - class Factor2(CrossSectionalFactor): - name = "f2" - data_specs = [spec2] - - def compute(self, data): - return pl.Series([2.0]) - - combined = Factor1() + Factor2() - - # 应该有两个 DataSpec - assert len(combined.data_specs) == 2 - sources = {s.source for s in combined.data_specs} - assert sources == {"daily", "fundamental"} - - def test_merge_same_source_different_columns(self): - """测试相同 source 但不同 columns 的 DataSpec 不合并""" - spec1 = DataSpec("daily", ["ts_code", "trade_date", "close"], lookback_days=5) - spec2 = DataSpec("daily", ["ts_code", "trade_date", "open"], lookback_days=3) - - class Factor1(CrossSectionalFactor): - name = "f1" - data_specs = [spec1] - - def compute(self, data): - return pl.Series([1.0]) - - class Factor2(CrossSectionalFactor): - name = "f2" - data_specs = [spec2] - - def compute(self, data): - return pl.Series([2.0]) - - combined = Factor1() + Factor2() - - # 应该有两个 DataSpec(因为 columns 不同) - assert len(combined.data_specs) == 2 - - def test_merge_lookback_max(self): - """测试 lookback_days 取最大值""" - spec1 = DataSpec("daily", ["ts_code", "trade_date", "close"], lookback_days=5) - spec2 = DataSpec("daily", ["ts_code", "trade_date", "close"], lookback_days=20) - spec3 = DataSpec("daily", ["ts_code", "trade_date", "close"], lookback_days=10) - - class Factor1(CrossSectionalFactor): - name = "f1" - data_specs = [spec1] - - def compute(self, data): - return pl.Series([1.0]) - - class Factor2(CrossSectionalFactor): - name = "f2" - data_specs = [spec2, spec3] - - def compute(self, data): - return pl.Series([2.0]) - - combined = Factor1() + Factor2() - - # 应该合并成一个 DataSpec,lookback_days 取最大值 20 - assert len(combined.data_specs) == 1 - assert combined.data_specs[0].lookback_days == 20 - - -class TestCompositeFactorCompute: - """测试 compute 运算""" - - def test_compute_addition(self, cs_factor1, cs_factor2, sample_factor_data): - """测试加法运算""" - combined = cs_factor1 + cs_factor2 - result = combined.compute(sample_factor_data) - - # [1.0, 2.0] + [3.0, 4.0] = [4.0, 6.0] - expected = pl.Series([4.0, 6.0]) - assert (result - expected).abs().max() < 1e-10 - - def test_compute_subtraction(self, cs_factor1, cs_factor2, sample_factor_data): - """测试减法运算""" - combined = cs_factor1 - cs_factor2 - result = combined.compute(sample_factor_data) - - # [1.0, 2.0] - [3.0, 4.0] = [-2.0, -2.0] - expected = pl.Series([-2.0, -2.0]) - assert (result - expected).abs().max() < 1e-10 - - def test_compute_multiplication(self, cs_factor1, cs_factor2, sample_factor_data): - """测试乘法运算""" - combined = cs_factor1 * cs_factor2 - result = combined.compute(sample_factor_data) - - # [1.0, 2.0] * [3.0, 4.0] = [3.0, 8.0] - expected = pl.Series([3.0, 8.0]) - assert (result - expected).abs().max() < 1e-10 - - def test_compute_division(self, cs_factor1, cs_factor2, sample_factor_data): - """测试除法运算""" - combined = cs_factor1 / cs_factor2 - result = combined.compute(sample_factor_data) - - # [1.0, 2.0] / [3.0, 4.0] = [0.333..., 0.5] - expected = pl.Series([1.0 / 3.0, 0.5]) - assert (result - expected).abs().max() < 1e-10 - - def test_compute_with_ts_factors(self, ts_factor1, ts_factor2, sample_factor_data): - """测试时序因子的组合运算""" - combined = ts_factor1 + ts_factor2 - result = combined.compute(sample_factor_data) - - # [10.0, 20.0, 30.0] + [1.0, 2.0, 3.0] = [11.0, 22.0, 33.0] - expected = pl.Series([11.0, 22.0, 33.0]) - assert (result - expected).abs().max() < 1e-10 - - def test_chained_combination(self, cs_factor1, cs_factor2, sample_factor_data): - """测试链式组合 (f1 + f2) * f1""" - - # 创建第三个因子 - class CSFactor3(CrossSectionalFactor): - name = "cs_factor3" - data_specs = [ - DataSpec("daily", ["ts_code", "trade_date", "close"], lookback_days=5) - ] - - def compute(self, data): - return pl.Series([0.5, 1.0]) - - f3 = CSFactor3() - - # (f1 + f2) * f3 - # f1 + f2 = [1.0, 2.0] + [3.0, 4.0] = [4.0, 6.0] - # [4.0, 6.0] * [0.5, 1.0] = [2.0, 6.0] - combined = (cs_factor1 + cs_factor2) * f3 - result = combined.compute(sample_factor_data) - - expected = pl.Series([2.0, 6.0]) - assert (result - expected).abs().max() < 1e-10 - - -# ========== ScalarFactor 测试 ========== - - -class TestScalarFactor: - """测试 ScalarFactor""" - - def test_scalar_multiplication_left(self, cs_factor1): - """测试标量乘法 `0.5 * factor`(左乘)""" - scaled = 0.5 * cs_factor1 - assert isinstance(scaled, ScalarFactor) - assert scaled.scalar == 0.5 - assert scaled.op == "*" - assert scaled.factor == cs_factor1 - - def test_scalar_multiplication_right(self, cs_factor1): - """测试标量乘法 `factor * 0.5`(右乘)""" - scaled = cs_factor1 * 0.5 - assert isinstance(scaled, ScalarFactor) - assert scaled.scalar == 0.5 - assert scaled.op == "*" - - def test_scalar_integer_multiplication(self, cs_factor1): - """测试整数标量乘法""" - scaled = 2 * cs_factor1 - assert isinstance(scaled, ScalarFactor) - assert scaled.scalar == 2.0 - - def test_inherits_data_specs(self, cs_factor1): - """测试继承基础因子的 data_specs""" - scaled = 0.5 * cs_factor1 - assert scaled.data_specs == cs_factor1.data_specs - - def test_compute_multiplication(self, cs_factor1, sample_factor_data): - """测试标量乘法 compute 结果""" - scaled = 0.5 * cs_factor1 - result = scaled.compute(sample_factor_data) - - # [1.0, 2.0] * 0.5 = [0.5, 1.0] - expected = pl.Series([0.5, 1.0]) - assert (result - expected).abs().max() < 1e-10 - - def test_compute_with_ts_factor(self, ts_factor1, sample_factor_data): - """测试时序因子的标量乘法""" - scaled = 0.1 * ts_factor1 - result = scaled.compute(sample_factor_data) - - # [10.0, 20.0, 30.0] * 0.1 = [1.0, 2.0, 3.0] - expected = pl.Series([1.0, 2.0, 3.0]) - assert (result - expected).abs().max() < 1e-10 - - def test_factor_type_preserved(self, cs_factor1, ts_factor1): - """测试 factor_type 被正确保留""" - scaled_cs = 0.5 * cs_factor1 - scaled_ts = 0.5 * ts_factor1 - - assert scaled_cs.factor_type == "cross_sectional" - assert scaled_ts.factor_type == "time_series" - - def test_scalar_name_format(self, cs_factor1): - """测试 ScalarFactor 的 name 格式""" - scaled = 0.5 * cs_factor1 - assert scaled.name == "(0.5_*_cs_factor1)" - - -# ========== 组合和标量混合测试 ========== - - -class TestMixedOperations: - """测试组合因子和标量因子的混合运算""" - - def test_scalar_then_combine(self, cs_factor1, cs_factor2, sample_factor_data): - """测试先标量缩放再组合""" - # 0.5 * f1 + 0.3 * f2 - combined = 0.5 * cs_factor1 + 0.3 * cs_factor2 - result = combined.compute(sample_factor_data) - - # 0.5 * [1.0, 2.0] + 0.3 * [3.0, 4.0] - # = [0.5, 1.0] + [0.9, 1.2] - # = [1.4, 2.2] - expected = pl.Series([1.4, 2.2]) - assert (result - expected).abs().max() < 1e-10 - - def test_complex_formula(self, cs_factor1, cs_factor2, sample_factor_data): - """测试复杂公式: (f1 + f2) * 0.5 - f1 * 0.2""" - formula = (cs_factor1 + cs_factor2) * 0.5 - cs_factor1 * 0.2 - result = formula.compute(sample_factor_data) - - # (f1 + f2) = [4.0, 6.0] - # (f1 + f2) * 0.5 = [2.0, 3.0] - # f1 * 0.2 = [0.2, 0.4] - # [2.0, 3.0] - [0.2, 0.4] = [1.8, 2.6] - expected = pl.Series([1.8, 2.6]) - assert (result - expected).abs().max() < 1e-10 diff --git a/tests/factors/test_data_loader.py b/tests/factors/test_data_loader.py deleted file mode 100644 index 599d107..0000000 --- a/tests/factors/test_data_loader.py +++ /dev/null @@ -1,284 +0,0 @@ -"""测试数据加载器 - DataLoader - -测试需求(来自 factor_implementation_plan.md): -- 测试从 DuckDB 加载数据 -- 测试从多个查询加载并合并 -- 测试列选择(只加载需要的列) -- 测试缓存机制(第二次加载更快) -- 测试 clear_cache() 清空缓存 -- 测试按 date_range 过滤 -- 测试表不存在时的处理 -- 测试列不存在时抛出 KeyError - -使用 3 个月的真实数据进行测试 (2024年1月-3月) -""" - -import pytest -import polars as pl -import pandas as pd -from pathlib import Path - -from src.factors import DataSpec, DataLoader - - -class TestDataLoaderBasic: - """测试 DataLoader 基本功能""" - - # 测试数据时间范围:3个月 - TEST_START_DATE = "20240101" - TEST_END_DATE = "20240331" - - @pytest.fixture - def loader(self): - """创建 DataLoader 实例""" - return DataLoader(data_dir="data") - - def test_init(self): - """测试初始化""" - loader = DataLoader(data_dir="data") - assert loader.data_dir == Path("data") - assert loader._cache == {} - - def test_load_single_source(self, loader): - """测试从 DuckDB 加载数据""" - specs = [ - DataSpec( - source="daily", - columns=["ts_code", "trade_date", "close"], - lookback_days=1, - ) - ] - - # 使用 3 个月日期范围限制数据量 - df = loader.load(specs, date_range=(self.TEST_START_DATE, self.TEST_END_DATE)) - - assert isinstance(df, pl.DataFrame) - assert len(df) > 0 - assert "ts_code" in df.columns - assert "trade_date" in df.columns - assert "close" in df.columns - - def test_load_with_date_range(self, loader): - """测试加载特定日期范围(3个月)""" - specs = [ - DataSpec( - source="daily", - columns=["ts_code", "trade_date", "close", "open", "high", "low"], - lookback_days=1, - ) - ] - - df = loader.load(specs, date_range=(self.TEST_START_DATE, self.TEST_END_DATE)) - - assert isinstance(df, pl.DataFrame) - assert len(df) > 0 - - # 验证日期范围 - if len(df) > 0: - dates = df["trade_date"].to_list() - assert all(self.TEST_START_DATE <= d <= self.TEST_END_DATE for d in dates) - print(f"[TEST] Loaded {len(df)} rows from {min(dates)} to {max(dates)}") - - def test_load_multiple_specs(self, loader): - """测试从多个 DataSpec 加载并合并""" - specs = [ - DataSpec( - source="daily", - columns=["ts_code", "trade_date", "close"], - lookback_days=1, - ), - DataSpec( - source="daily", - columns=["ts_code", "trade_date", "open", "high", "low"], - lookback_days=1, - ), - ] - - df = loader.load(specs, date_range=(self.TEST_START_DATE, self.TEST_END_DATE)) - - assert isinstance(df, pl.DataFrame) - assert len(df) > 0 - # 应该包含所有列 - assert set(df.columns) >= { - "ts_code", - "trade_date", - "close", - "open", - "high", - "low", - } - - def test_column_selection(self, loader): - """测试列选择(只加载需要的列)""" - specs = [ - DataSpec( - source="daily", - columns=["ts_code", "trade_date", "close"], - lookback_days=1, - ) - ] - - df = loader.load(specs, date_range=(self.TEST_START_DATE, self.TEST_END_DATE)) - - # 只应该有 3 列 - assert set(df.columns) == {"ts_code", "trade_date", "close"} - - def test_date_range_filter(self, loader): - """测试按 date_range 过滤 - 使用3个月数据的不同子集""" - specs = [ - DataSpec( - source="daily", - columns=["ts_code", "trade_date", "close"], - lookback_days=1, - ) - ] - - # 加载完整的3个月数据 - df_all = loader.load( - specs, date_range=(self.TEST_START_DATE, self.TEST_END_DATE) - ) - total_rows = len(df_all) - - # 清空缓存,重新加载1个月数据 - loader.clear_cache() - df_filtered = loader.load(specs, date_range=("20240101", "20240131")) - - # 过滤后的数据应该更少或相等 - assert len(df_filtered) <= total_rows - - # 所有日期都应该在范围内 - if len(df_filtered) > 0: - dates = df_filtered["trade_date"].to_list() - assert all("20240101" <= d <= "20240131" for d in dates) - - -class TestDataLoaderCache: - """测试 DataLoader 缓存机制""" - - TEST_START_DATE = "20240101" - TEST_END_DATE = "20240331" - - @pytest.fixture - def loader(self): - """创建 DataLoader 实例""" - return DataLoader(data_dir="data") - - def test_cache_populated(self, loader): - """测试加载后缓存被填充""" - specs = [ - DataSpec( - source="daily", - columns=["ts_code", "trade_date", "close"], - lookback_days=1, - ) - ] - - # 第一次加载 - loader.load(specs, date_range=(self.TEST_START_DATE, self.TEST_END_DATE)) - - # 检查缓存 - assert len(loader._cache) > 0 - - def test_cache_used(self, loader): - """测试第二次加载使用缓存(更快)""" - import time - - specs = [ - DataSpec( - source="daily", - columns=["ts_code", "trade_date", "close"], - lookback_days=1, - ) - ] - - # 第一次加载 - start = time.time() - df1 = loader.load(specs, date_range=(self.TEST_START_DATE, self.TEST_END_DATE)) - time1 = time.time() - start - - # 第二次加载(应该使用缓存) - start = time.time() - df2 = loader.load(specs, date_range=(self.TEST_START_DATE, self.TEST_END_DATE)) - time2 = time.time() - start - - # 数据应该相同 - assert df1.shape == df2.shape - - # 第二次应该更快 - print(f"[TEST] First load: {time1:.3f}s, cached load: {time2:.3f}s") - assert time2 < time1, "Cached load should be faster" - - def test_clear_cache(self, loader): - """测试 clear_cache() 清空缓存""" - specs = [ - DataSpec( - source="daily", - columns=["ts_code", "trade_date", "close"], - lookback_days=1, - ) - ] - - # 加载数据 - loader.load(specs, date_range=(self.TEST_START_DATE, self.TEST_END_DATE)) - assert len(loader._cache) > 0 - - # 清空缓存 - loader.clear_cache() - assert len(loader._cache) == 0 - - def test_cache_info(self, loader): - """测试 get_cache_info()""" - specs = [ - DataSpec( - source="daily", - columns=["ts_code", "trade_date", "close"], - lookback_days=1, - ) - ] - - # 加载前 - info_before = loader.get_cache_info() - assert info_before["entries"] == 0 - - # 加载后 - loader.load(specs, date_range=(self.TEST_START_DATE, self.TEST_END_DATE)) - info_after = loader.get_cache_info() - assert info_after["entries"] > 0 - assert info_after["total_rows"] > 0 - - -class TestDataLoaderErrors: - """测试 DataLoader 错误处理""" - - def test_table_not_exists(self): - """测试表不存在时的处理""" - loader = DataLoader(data_dir="data") - specs = [ - DataSpec( - source="nonexistent_table", - columns=["ts_code", "trade_date", "close"], - lookback_days=1, - ) - ] - - # 应该返回空 DataFrame 或抛出异常 - with pytest.raises(Exception): - loader.load(specs) - - def test_column_not_found(self): - """测试列不存在时抛出 KeyError""" - loader = DataLoader(data_dir="data") - specs = [ - DataSpec( - source="daily", - columns=["ts_code", "trade_date", "nonexistent_column"], - lookback_days=1, - ) - ] - - with pytest.raises(KeyError, match="nonexistent_column"): - loader.load(specs) - - -if __name__ == "__main__": - pytest.main([__file__, "-v", "-s"]) diff --git a/tests/factors/test_data_spec.py b/tests/factors/test_data_spec.py deleted file mode 100644 index 79f4352..0000000 --- a/tests/factors/test_data_spec.py +++ /dev/null @@ -1,328 +0,0 @@ -"""Factors 模块测试 - Phase 1: 数据类型定义测试 - -测试范围: -- DataSpec: 数据需求规格的创建和验证 -- FactorContext: 计算上下文的创建 -- FactorData: 数据容器的基本操作 -- HDF5 数据读取: 验证能正确读取 daily.h5 文件 -""" - -import pytest -import polars as pl -import pandas as pd -from pathlib import Path - -from src.factors import DataSpec, FactorContext, FactorData - - -class TestDataSpec: - """测试 DataSpec 数据需求规格""" - - def test_valid_dataspec_creation(self): - """测试有效的 DataSpec 创建""" - spec = DataSpec( - source="daily", columns=["ts_code", "trade_date", "close"], lookback_days=5 - ) - assert spec.source == "daily" - assert spec.columns == ["ts_code", "trade_date", "close"] - assert spec.lookback_days == 5 - - def test_dataspec_default_lookback(self): - """测试 DataSpec 默认值 lookback_days=1""" - spec = DataSpec(source="daily", columns=["ts_code", "trade_date", "close"]) - assert spec.lookback_days == 1 - - def test_dataspec_frozen_immutable(self): - """测试 DataSpec 是 frozen(不可变)""" - spec = DataSpec( - source="daily", columns=["ts_code", "trade_date", "close"], lookback_days=5 - ) - with pytest.raises(FrozenInstanceError): - spec.source = "other" - - def test_dataspec_lookback_less_than_1_raises(self): - """测试 lookback_days < 1 时抛出 ValueError""" - with pytest.raises(ValueError, match="lookback_days must be >= 1"): - DataSpec( - source="daily", - columns=["ts_code", "trade_date", "close"], - lookback_days=0, - ) - - with pytest.raises(ValueError, match="lookback_days must be >= 1"): - DataSpec( - source="daily", - columns=["ts_code", "trade_date", "close"], - lookback_days=-1, - ) - - def test_dataspec_missing_required_columns_raises(self): - """测试缺少 ts_code 或 trade_date 时抛出 ValueError""" - # 缺少 ts_code - with pytest.raises(ValueError, match="columns must contain"): - DataSpec(source="daily", columns=["trade_date", "close"], lookback_days=5) - - # 缺少 trade_date - with pytest.raises(ValueError, match="columns must contain"): - DataSpec(source="daily", columns=["ts_code", "close"], lookback_days=5) - - # 两者都缺少 - with pytest.raises(ValueError, match="columns must contain"): - DataSpec(source="daily", columns=["close", "open", "high"], lookback_days=5) - - def test_dataspec_empty_source_raises(self): - """测试空 source 时抛出 ValueError""" - with pytest.raises(ValueError, match="source cannot be empty string"): - DataSpec( - source="", columns=["ts_code", "trade_date", "close"], lookback_days=5 - ) - - -class TestFactorContext: - """测试 FactorContext 计算上下文""" - - def test_default_creation(self): - """测试默认值创建""" - ctx = FactorContext() - assert ctx.current_date is None - assert ctx.current_stock is None - assert ctx.trade_dates is None - - def test_full_creation(self): - """测试完整参数创建""" - ctx = FactorContext( - current_date="20240101", - current_stock="000001.SZ", - trade_dates=["20240101", "20240102", "20240103"], - ) - assert ctx.current_date == "20240101" - assert ctx.current_stock == "000001.SZ" - assert ctx.trade_dates == ["20240101", "20240102", "20240103"] - - def test_partial_creation(self): - """测试部分参数创建""" - ctx = FactorContext(current_date="20240101") - assert ctx.current_date == "20240101" - assert ctx.current_stock is None - assert ctx.trade_dates is None - - def test_dataclass_methods(self): - """测试 dataclass 自动生成的方法""" - ctx = FactorContext(current_date="20240101") - # __repr__ - assert "FactorContext" in repr(ctx) - assert "20240101" in repr(ctx) - # __eq__ - ctx2 = FactorContext(current_date="20240101") - assert ctx == ctx2 - ctx3 = FactorContext(current_date="20240102") - assert ctx != ctx3 - - -class TestFactorData: - """测试 FactorData 数据容器""" - - @pytest.fixture - def sample_df(self): - """创建示例 DataFrame""" - return pl.DataFrame( - { - "ts_code": ["000001.SZ", "000002.SZ", "000001.SZ", "000002.SZ"], - "trade_date": ["20240101", "20240101", "20240102", "20240102"], - "close": [10.0, 20.0, 10.5, 20.5], - "volume": [1000, 2000, 1100, 2100], - } - ) - - @pytest.fixture - def cs_context(self): - """截面因子上下文""" - return FactorContext(current_date="20240101") - - @pytest.fixture - def ts_context(self): - """时序因子上下文""" - return FactorContext(current_stock="000001.SZ") - - def test_get_column(self, sample_df, cs_context): - """测试 get_column 返回正确的 Series""" - data = FactorData(sample_df, cs_context) - close_series = data.get_column("close") - assert isinstance(close_series, pl.Series) - assert close_series.to_list() == [10.0, 20.0, 10.5, 20.5] - - def test_get_column_keyerror(self, sample_df, cs_context): - """测试 get_column 列不存在时抛出 KeyError""" - data = FactorData(sample_df, cs_context) - with pytest.raises(KeyError, match="Column 'nonexistent' not found"): - data.get_column("nonexistent") - - def test_filter_by_date(self, sample_df, cs_context): - """测试 filter_by_date 返回正确的过滤结果""" - data = FactorData(sample_df, cs_context) - filtered = data.filter_by_date("20240101") - assert len(filtered) == 2 - assert filtered.to_polars()["ts_code"].to_list() == ["000001.SZ", "000002.SZ"] - assert filtered.to_polars()["close"].to_list() == [10.0, 20.0] - - def test_filter_by_date_empty_result(self, sample_df, cs_context): - """测试 filter_by_date 日期不存在时返回空的 FactorData""" - data = FactorData(sample_df, cs_context) - filtered = data.filter_by_date("20241231") - assert len(filtered) == 0 - assert isinstance(filtered, FactorData) - - def test_get_cross_section(self, sample_df, cs_context): - """测试 get_cross_section 返回 current_date 当天的数据""" - data = FactorData(sample_df, cs_context) - cs = data.get_cross_section() - assert len(cs) == 2 - assert cs["ts_code"].to_list() == ["000001.SZ", "000002.SZ"] - - def test_get_cross_section_no_date_raises(self, sample_df, ts_context): - """测试 get_cross_section current_date 为 None 时抛出 ValueError""" - data = FactorData(sample_df, ts_context) - with pytest.raises(ValueError, match="current_date is not set"): - data.get_cross_section() - - def test_to_polars(self, sample_df, cs_context): - """测试 to_polars 返回原始 DataFrame""" - data = FactorData(sample_df, cs_context) - df = data.to_polars() - assert isinstance(df, pl.DataFrame) - assert df.shape == sample_df.shape - assert df.columns == sample_df.columns - - def test_context_property(self, sample_df, cs_context): - """测试 context 属性返回正确的上下文""" - data = FactorData(sample_df, cs_context) - assert data.context == cs_context - assert data.context.current_date == "20240101" - - def test_len(self, sample_df, cs_context): - """测试 __len__ 返回正确的行数""" - data = FactorData(sample_df, cs_context) - assert len(data) == 4 - - def test_repr(self, sample_df, cs_context): - """测试 __repr__ 返回可读字符串""" - data = FactorData(sample_df, cs_context) - repr_str = repr(data) - assert "FactorData" in repr_str - assert "rows=4" in repr_str - assert "date=20240101" in repr_str - - -class TestHDF5DataAccess: - """测试 HDF5 数据读取功能""" - - def test_daily_h5_file_exists(self): - """测试 daily.h5 文件存在""" - data_path = Path("data/daily.h5") - assert data_path.exists(), f"daily.h5 文件不存在: {data_path.absolute()}" - - def test_daily_h5_can_read_with_pandas(self): - """测试能用 pandas 读取 daily.h5""" - data_path = Path("data/daily.h5") - df = pd.read_hdf(data_path, key="/daily") - - assert df is not None - assert len(df) > 0 - assert "ts_code" in df.columns - assert "trade_date" in df.columns - assert "close" in df.columns - - def test_daily_h5_columns(self): - """测试 daily.h5 包含预期的列""" - data_path = Path("data/daily.h5") - df = pd.read_hdf(data_path, key="/daily") - - expected_columns = [ - "trade_date", - "ts_code", - "open", - "high", - "low", - "close", - "pre_close", - "change", - "pct_chg", - "vol", - "amount", - "turnover_rate", - "volume_ratio", - ] - - for col in expected_columns: - assert col in df.columns, f"列 {col} 不存在于 daily.h5" - - def test_daily_h5_date_format(self): - """测试 daily.h5 日期格式正确""" - data_path = Path("data/daily.h5") - df = pd.read_hdf(data_path, key="/daily") - - # 检查日期格式是 YYYYMMDD 字符串 - sample_date = df["trade_date"].iloc[0] - assert isinstance(sample_date, str), "日期应该是字符串格式" - assert len(sample_date) == 8, "日期应该是 8 位字符串 (YYYYMMDD)" - assert sample_date.isdigit(), "日期应该只包含数字" - - def test_daily_h5_stock_format(self): - """测试 daily.h5 股票代码格式正确""" - data_path = Path("data/daily.h5") - df = pd.read_hdf(data_path, key="/daily") - - # 检查股票代码格式如 "000001.SZ" - sample_code = df["ts_code"].iloc[0] - assert isinstance(sample_code, str), "股票代码应该是字符串" - assert "." in sample_code, "股票代码应该包含交易所后缀" - assert sample_code.endswith((".SZ", ".SH", ".BJ")), ( - "股票代码应该以交易所后缀结尾" - ) - - def test_daily_h5_to_polars(self): - """测试将 daily.h5 数据转换为 Polars""" - data_path = Path("data/daily.h5") - pdf = pd.read_hdf(data_path, key="/daily") - - # 转换为 Polars - df = pl.from_pandas(pdf) - - assert isinstance(df, pl.DataFrame) - assert len(df) > 0 - assert "ts_code" in df.columns - assert "trade_date" in df.columns - - def test_daily_h5_sample_data_with_factors(self): - """测试用 daily.h5 真实数据创建 FactorData""" - data_path = Path("data/daily.h5") - pdf = pd.read_hdf(data_path, key="/daily") - - # 取前 100 行作为示例 - sample_pdf = pdf.head(100) - df = pl.from_pandas(sample_pdf) - - # 创建 FactorData - ctx = FactorContext(current_date=df["trade_date"][0]) - data = FactorData(df, ctx) - - # 验证基本操作 - assert len(data) == 100 - assert "close" in data.to_polars().columns - - # 测试 get_column - close_prices = data.get_column("close") - assert len(close_prices) == 100 - - # 测试 filter_by_date - first_date = df["trade_date"][0] - filtered = data.filter_by_date(first_date) - assert len(filtered) > 0 - - -# 导入 FrozenInstanceError -try: - from dataclasses import FrozenInstanceError -except ImportError: - # Python < 3.10 compatibility - FrozenInstanceError = AttributeError diff --git a/tests/factors/test_engine.py b/tests/factors/test_engine.py deleted file mode 100644 index 581f704..0000000 --- a/tests/factors/test_engine.py +++ /dev/null @@ -1,266 +0,0 @@ -"""测试执行引擎 - FactorEngine - -测试需求(来自 factor_implementation_plan.md): -- 测试 `compute()` 正确分发给截面计算 -- 测试 `compute()` 正确分发给时序计算 -- 测试无效 factor_type 时抛出 ValueError - -截面计算测试(防泄露验证): -- 测试数据裁剪正确(传入 [T-lookback+1, T]) -- 测试不包含未来日期 T+1 的数据 -- 测试每个日期独立计算 -- 测试结果包含所有日期和所有股票 -- 测试结果 DataFrame 格式正确 -- 测试多个 DataSpec 时 lookback 取最大值 - -时序计算测试(防泄露验证): -- 测试每只股票只看到自己的数据 -- 测试不包含其他股票的数据 -- 测试传入的是完整时间序列(向量化计算) -- 测试结果包含所有股票和所有日期 -- 测试结果 DataFrame 格式正确 -- 测试股票不在数据中时跳过(或填充 null) -""" - -import pytest -import polars as pl - -from src.factors import ( - DataSpec, - FactorContext, - FactorData, - DataLoader, - FactorEngine, - CrossSectionalFactor, - TimeSeriesFactor, -) - - -class SimpleCrossSectionalFactor(CrossSectionalFactor): - """简单的截面因子 - 返回收盘价排名""" - - name = "close_rank" - data_specs = [ - DataSpec("daily", ["ts_code", "trade_date", "close"], lookback_days=1) - ] - - def compute(self, data: FactorData) -> pl.Series: - cs = data.get_cross_section() - return cs["close"].rank() - - -class SimpleTimeSeriesFactor(TimeSeriesFactor): - """简单的时序因子 - 返回3日移动平均""" - - name = "ma3" - data_specs = [ - DataSpec( - "daily", - ["ts_code", "trade_date", "close"], - lookback_days=5, - ) - ] - - def __init__(self, period: int = 3): - super().__init__(period=period) - - def compute(self, data: FactorData) -> pl.Series: - return data.get_column("close").rolling_mean(self.params["period"]) - - -class ReturnFactor(CrossSectionalFactor): - """收益率因子 - 需要2天lookback计算收益率""" - - name = "return" - data_specs = [ - DataSpec("daily", ["ts_code", "trade_date", "close"], lookback_days=2) - ] - - def compute(self, data: FactorData) -> pl.Series: - # 获取当前日期 - current_date = data.context.current_date - - # 获取当前日期的数据 - cs = data.get_cross_section() - - # 简单返回收盘价作为因子值 - # 实际应该计算收益率,但这里简化处理 - return cs["close"] - - -@pytest.fixture -def loader(): - """创建 DataLoader 实例""" - return DataLoader(data_dir="data") - - -@pytest.fixture -def engine(loader): - """创建 FactorEngine 实例""" - return FactorEngine(loader) - - -class TestFactorEngineDispatch: - """测试引擎分发逻辑""" - - def test_dispatch_cross_sectional(self, engine): - """测试 compute() 正确分发给截面计算""" - factor = SimpleCrossSectionalFactor() - - result = engine.compute(factor, start_date="20240101", end_date="20240105") - - assert isinstance(result, pl.DataFrame) - assert "trade_date" in result.columns - assert "ts_code" in result.columns - assert "close_rank" in result.columns - - def test_dispatch_time_series(self, engine, loader): - """测试 compute() 正确分发给时序计算""" - factor = SimpleTimeSeriesFactor(period=3) - - # 获取一些股票代码 - sample_data = loader.load( - [DataSpec("daily", ["ts_code", "trade_date"], lookback_days=1)] - ) - stock_codes = sample_data["ts_code"].unique().head(3).to_list() - - result = engine.compute( - factor, - stock_codes=stock_codes, - start_date="20240101", - end_date="20240110", - ) - - assert isinstance(result, pl.DataFrame) - assert "trade_date" in result.columns - assert "ts_code" in result.columns - assert "ma3" in result.columns - - def test_unknown_factor_type(self, engine): - """测试无效 factor_type 时抛出 ValueError""" - - class UnknownFactor: - name = "unknown" - factor_type = "unknown_type" - data_specs = [] - - factor = UnknownFactor() - - with pytest.raises(ValueError, match="Unknown factor type"): - engine.compute(factor) - - -class TestCrossSectionalComputation: - """测试截面计算(防泄露验证)""" - - def test_result_format(self, engine): - """测试结果 DataFrame 格式正确""" - factor = SimpleCrossSectionalFactor() - - result = engine.compute(factor, start_date="20240101", end_date="20240105") - - # 检查列 - assert "trade_date" in result.columns - assert "ts_code" in result.columns - assert factor.name in result.columns - - # 检查类型 - assert result["trade_date"].dtype == pl.Utf8 - assert result["ts_code"].dtype == pl.Utf8 - - def test_all_dates_present(self, engine): - """测试结果包含所有日期""" - factor = SimpleCrossSectionalFactor() - - start_date = "20240101" - end_date = "20240105" - - result = engine.compute(factor, start_date=start_date, end_date=end_date) - - if len(result) > 0: - dates = result["trade_date"].unique().to_list() - # 应该包含 start_date 和 end_date 之间的日期 - assert len(dates) > 0 - - def test_lookback_window(self, engine): - """测试多个 DataSpec 时 lookback 取最大值""" - factor = ReturnFactor() - - # lookback_days = 2 - result = engine.compute(factor, start_date="20240103", end_date="20240105") - - # 应该能计算出结果 - assert isinstance(result, pl.DataFrame) - - -class TestTimeSeriesComputation: - """测试时序计算(防泄露验证)""" - - def test_result_format(self, engine): - """测试结果 DataFrame 格式正确""" - factor = SimpleTimeSeriesFactor(period=3) - - result = engine.compute( - factor, - stock_codes=["000001.SZ"], - start_date="20240101", - end_date="20240110", - ) - - # 检查列 - assert "trade_date" in result.columns - assert "ts_code" in result.columns - assert factor.name in result.columns - - def test_single_stock_data(self, engine): - """测试每只股票只看到自己的数据""" - factor = SimpleTimeSeriesFactor(period=3) - - stock_codes = ["000001.SZ"] - - result = engine.compute( - factor, - stock_codes=stock_codes, - start_date="20240101", - end_date="20240110", - ) - - if len(result) > 0: - # 结果中只应该有指定的股票 - stocks = result["ts_code"].unique().to_list() - assert set(stocks) == set(stock_codes) - - def test_ma_calculation(self, engine): - """测试移动平均计算""" - factor = SimpleTimeSeriesFactor(period=3) - - result = engine.compute( - factor, - stock_codes=["000001.SZ"], - start_date="20240101", - end_date="20240110", - ) - - if len(result) > 2: - # 前2个应该是 null(因为 period=3) - ma_values = result[factor.name].to_list() - assert ma_values[0] is None or str(ma_values[0]) == "nan" - assert ma_values[1] is None or str(ma_values[1]) == "nan" - # 第3个应该有值 - assert ma_values[2] is not None - - def test_missing_stock_skipped(self, engine): - """测试股票不在数据中时返回空结果""" - factor = SimpleTimeSeriesFactor(period=3) - - result = engine.compute( - factor, - stock_codes=["NONEXISTENT.STOCK"], - start_date="20240101", - end_date="20240110", - ) - - # 应该返回空 DataFrame 或包含该股票但值为 null 的结果 - assert isinstance(result, pl.DataFrame) - # 对于不存在的股票,结果可能是空的 - # 或者包含该股票但值为 null diff --git a/tests/factors/test_factor_validation.py b/tests/factors/test_factor_validation.py deleted file mode 100644 index 810aab4..0000000 --- a/tests/factors/test_factor_validation.py +++ /dev/null @@ -1,397 +0,0 @@ -"""因子真实数据测试 - 与 Polars 原生计算对比 - -测试目标: -1. 时序因子 - 移动平均线 (MA) -2. 截面因子 - PE_Rank(市盈率排名) -3. 结合因子 - 时序 * 截面组合 - -每个因子都与原始 Polars 计算进行对比验证。 -""" - -import pytest -import pandas as pd -import polars as pl -import numpy as np -from src.factors import DataSpec, FactorContext, FactorData -from src.factors.base import CrossSectionalFactor, TimeSeriesFactor -from src.factors.composite import CompositeFactor, ScalarFactor - - -# ========== 测试数据准备 ========== - - -@pytest.fixture(scope="module") -def daily_data(): - """加载日线测试数据(直接使用 Polars)""" - with pd.HDFStore("data/daily.h5", mode="r") as store: - df = store["/daily"] - - # 筛选日期范围 - df = df[(df["trade_date"] >= "20240101") & (df["trade_date"] <= "20240430")] - - # 选择部分股票(取前20个) - stocks = df["ts_code"].unique()[:20] - df = df[df["ts_code"].isin(stocks)] - - # 直接返回 Polars DataFrame,不转 pandas - pl_df = pl.from_pandas(df) - pl_df = pl_df.sort(["ts_code", "trade_date"]) - - return pl_df - - -# ========== 时序因子定义 ========== - - -class MAFactor(TimeSeriesFactor): - """移动平均线因子(时序因子)""" - - name = "ma_factor" - data_specs = [ - DataSpec("daily", ["ts_code", "trade_date", "close"], lookback_days=5) - ] - - def __init__(self, period: int = 5): - super().__init__(period=period) - - def compute(self, data: FactorData) -> pl.Series: - close = data.get_column("close") - period = self.params["period"] - return close.rolling_mean(period) - - -class PERankFactor(CrossSectionalFactor): - """PE 市盈率排名因子(截面因子)""" - - name = "pe_rank_factor" - data_specs = [ - DataSpec("daily", ["ts_code", "trade_date", "close"], lookback_days=1) - ] - - def compute(self, data: FactorData) -> pl.Series: - cs = data.get_cross_section() - close = cs["close"] - return close.rank() / close.len() - - -# ========== 测试用例 ========== - - -class TestTimeSeriesFactor: - """时序因子测试""" - - def test_ma_factor(self, daily_data): - """测试 MA 因子与 Polars 原生计算对比""" - period = 5 - sample_stock = daily_data["ts_code"].to_list()[0] - stock_df = daily_data.filter(pl.col("ts_code") == sample_stock).sort( - "trade_date" - ) - - # Polars 基准计算 - polars_result = stock_df.with_columns( - pl.col("close") - .rolling_mean(window_size=period) - .over("ts_code") - .alias("ma_polars") - ) - - # 因子框架计算 - context = FactorContext(current_stock=sample_stock) - factor_data = FactorData( - stock_df.with_columns([pl.col("trade_date").cast(pl.Utf8)]), context - ) - - ma_factor = MAFactor(period=period) - factor_result = ma_factor.compute(factor_data).to_numpy() - - # 对比结果 - polars_values = polars_result["ma_polars"].to_numpy() - - # 去除 NaN 后对比 - valid_idx = ~np.isnan(polars_values) - polars_valid = polars_values[valid_idx] - factor_valid = factor_result[valid_idx] - - diff = np.abs(polars_valid - factor_valid) - max_diff = np.max(diff) - - print(f"\n[时序因子 MA({period}) 对比]") - print(f" 样本股票: {sample_stock}") - print(f" 有效数据点: {len(polars_valid)}") - print(f" 最大差异: {max_diff:.15f}") - print(f" 样本数据 (前5个):") - for i in range(min(5, len(polars_valid))): - print( - f" Polars: {polars_valid[i]:.6f}, Factor: {factor_valid[i]:.6f}, Diff: {abs(polars_valid[i] - factor_valid[i]):.15f}" - ) - - assert max_diff < 1e-10, f"MA 因子计算差异过大: {max_diff}" - - -class TestCrossSectionalFactor: - """截面因子测试""" - - def test_pe_rank_factor(self, daily_data): - """测试 PE_Rank 因子与 Polars 原生计算对比""" - trade_dates = daily_data["trade_date"].unique().to_list() - sample_date = trade_dates[50] - date_df = daily_data.filter(pl.col("trade_date") == sample_date) - - # Polars 基准计算 - polars_result = date_df.with_columns( - (pl.col("close").rank() / pl.col("close").count()).alias("pe_rank_polars") - ) - - # 因子框架计算 - context = FactorContext(current_date=str(sample_date)) - factor_data = FactorData( - date_df.with_columns( - [pl.col("trade_date").cast(pl.Utf8), pl.col("ts_code").cast(pl.Utf8)] - ), - context, - ) - - pe_factor = PERankFactor() - factor_result = pe_factor.compute(factor_data).to_numpy() - - # 对比结果 - polars_values = polars_result["pe_rank_polars"].to_numpy() - - diff = np.abs(polars_values - factor_result) - max_diff = np.max(diff) - - print(f"\n[截面因子 PE_Rank 对比]") - print(f" 样本日期: {sample_date}") - print(f" 股票数量: {len(polars_values)}") - print(f" 最大差异: {max_diff:.15f}") - print(f" 样本数据 (前5个):") - for i in range(min(5, len(polars_values))): - ts_code = polars_result["ts_code"].to_numpy()[i] - print( - f" {ts_code}: Polars: {polars_values[i]:.6f}, Factor: {factor_result[i]:.6f}" - ) - - assert max_diff < 1e-10, f"PE_Rank 因子计算差异过大: {max_diff}" - - -class TestCompositeFactor: - """结合因子测试""" - - def test_scalar_composite(self, daily_data): - """测试标量组合因子: 0.5 * MA""" - period = 5 - sample_stock = daily_data["ts_code"].to_list()[0] - stock_df = daily_data.filter(pl.col("ts_code") == sample_stock).sort( - "trade_date" - ) - - # Polars 基准计算 - polars_ma = stock_df.with_columns( - pl.col("close").rolling_mean(window_size=period).over("ts_code").alias("ma") - ) - polars_combined = 0.5 * polars_ma["ma"].to_numpy() - - # 因子框架计算 - context = FactorContext(current_stock=sample_stock) - factor_data = FactorData( - stock_df.with_columns([pl.col("trade_date").cast(pl.Utf8)]), context - ) - - # 组合因子: 0.5 * MA - ma_factor = MAFactor(period=period) - scalar_factor = 0.5 * ma_factor - factor_result = scalar_factor.compute(factor_data).to_numpy() - - # 对比结果 - valid_idx = ~np.isnan(polars_combined) - polars_valid = polars_combined[valid_idx] - factor_valid = factor_result[valid_idx] - - diff = np.abs(polars_valid - factor_valid) - max_diff = np.max(diff) - - print(f"\n[结合因子 0.5*MA({period}) 对比]") - print(f" 公式: 0.5 * MA({period})") - print(f" 有效数据点: {len(polars_valid)}") - print(f" 最大差异: {max_diff:.15f}") - print(f" 样本数据 (前5个):") - for i in range(min(5, len(polars_valid))): - print( - f" Polars: {polars_valid[i]:.6f}, Factor: {factor_valid[i]:.6f}, Diff: {abs(polars_valid[i] - factor_valid[i]):.15f}" - ) - - assert max_diff < 1e-10, f"组合因子计算差异过大: {max_diff}" - - def test_factor_addition(self, daily_data): - """测试因子加法组合: MA(5) + MA(10)""" - sample_stock = daily_data["ts_code"].to_list()[0] - stock_df = daily_data.filter(pl.col("ts_code") == sample_stock).sort( - "trade_date" - ) - - context = FactorContext(current_stock=sample_stock) - - # Polars 基准计算 - polars_ma5 = stock_df.with_columns( - pl.col("close").rolling_mean(window_size=5).over("ts_code").alias("ma5") - ) - polars_ma10 = stock_df.with_columns( - pl.col("close").rolling_mean(window_size=10).over("ts_code").alias("ma10") - ) - polars_combined = polars_ma5["ma5"].to_numpy() + polars_ma10["ma10"].to_numpy() - - # 因子框架计算 - factor_data = FactorData( - stock_df.with_columns([pl.col("trade_date").cast(pl.Utf8)]), context - ) - - ma5 = MAFactor(period=5) - ma10 = MAFactor(period=10) - combined = ma5 + ma10 - - factor_result = combined.compute(factor_data).to_numpy() - - # 对比结果 - valid_idx = ~(np.isnan(polars_combined) | np.isnan(factor_result)) - polars_valid = polars_combined[valid_idx] - factor_valid = factor_result[valid_idx] - - diff = np.abs(polars_valid - factor_valid) - max_diff = np.max(diff) - - print(f"\n[结合因子 MA(5) + MA(10) 对比]") - print(f" 有效数据点: {len(polars_valid)}") - print(f" 最大差异: {max_diff:.15f}") - - assert max_diff < 1e-10, f"因子加法组合差异过大: {max_diff}" - - -class TestFactorComparison: - """全面对比测试""" - - def test_all_factors_summary(self, daily_data): - """汇总所有因子测试结果""" - print("\n" + "=" * 60) - print("因子测试汇总") - print("=" * 60) - - # 测试多个时序周期 - for period in [5, 10, 20]: - sample_stock = daily_data["ts_code"].to_list()[0] - stock_df = daily_data.filter(pl.col("ts_code") == sample_stock).sort( - "trade_date" - ) - - polars_result = stock_df.with_columns( - pl.col("close") - .rolling_mean(window_size=period) - .over("ts_code") - .alias("ma") - ) - - context = FactorContext(current_stock=sample_stock) - factor_data = FactorData( - stock_df.with_columns([pl.col("trade_date").cast(pl.Utf8)]), context - ) - - ma_factor = MAFactor(period=period) - factor_result = ma_factor.compute(factor_data).to_numpy() - - polars_values = polars_result["ma"].to_numpy() - valid_idx = ~np.isnan(polars_values) - - diff = np.abs(polars_values[valid_idx] - factor_result[valid_idx]) - max_diff = np.max(diff) - - status = "通过" if max_diff < 1e-10 else "失败" - print(f" MA({period}): 最大差异 = {max_diff:.2e} {status}") - - # 测试截面因子 - trade_dates = daily_data["trade_date"].unique().to_list() - sample_date = trade_dates[50] - date_df = daily_data.filter(pl.col("trade_date") == sample_date) - - polars_result = date_df.with_columns( - (pl.col("close").rank() / pl.col("close").count()).alias("rank") - ) - - context = FactorContext(current_date=str(sample_date)) - factor_data = FactorData( - date_df.with_columns( - [pl.col("trade_date").cast(pl.Utf8), pl.col("ts_code").cast(pl.Utf8)] - ), - context, - ) - - pe_factor = PERankFactor() - factor_result = pe_factor.compute(factor_data).to_numpy() - - polars_values = polars_result["rank"].to_numpy() - diff = np.abs(polars_values - factor_result) - max_diff = np.max(diff) - - status = "通过" if max_diff < 1e-10 else "失败" - print(f" PE_Rank: 最大差异 = {max_diff:.2e} {status}") - - print("=" * 60) - - # 测试多个时序周期 - for period in [5, 10, 20]: - sample_stock = daily_data["ts_code"].to_list()[0] - stock_df = daily_data.filter(pl.col("ts_code") == sample_stock).sort( - "trade_date" - ) - - polars_result = stock_df.with_columns( - pl.col("close") - .rolling_mean(window_size=period) - .over("ts_code") - .alias("ma") - ) - - context = FactorContext(current_stock=sample_stock) - factor_data = FactorData( - stock_df.with_columns([pl.col("trade_date").cast(pl.Utf8)]), context - ) - - ma_factor = MAFactor(period=period) - factor_result = ma_factor.compute(factor_data).to_numpy() - - polars_values = polars_result["ma"].to_numpy() - valid_idx = ~np.isnan(polars_values) - - diff = np.abs(polars_values[valid_idx] - factor_result[valid_idx]) - max_diff = np.max(diff) - - status = "通过" if max_diff < 1e-10 else "失败" - print(f" MA({period}): 最大差异 = {max_diff:.2e} {status}") - - # 测试截面因子 - trade_dates = daily_data["trade_date"].unique().to_list() - sample_date = trade_dates[50] - date_df = daily_data.filter(pl.col("trade_date") == sample_date) - - polars_result = date_df.with_columns( - (pl.col("close").rank() / pl.col("close").count()).alias("rank") - ) - - context = FactorContext(current_date=str(sample_date)) - factor_data = FactorData( - date_df.with_columns( - [pl.col("trade_date").cast(pl.Utf8), pl.col("ts_code").cast(pl.Utf8)] - ), - context, - ) - - pe_factor = PERankFactor() - factor_result = pe_factor.compute(factor_data).to_numpy() - - polars_values = polars_result["rank"].to_numpy() - diff = np.abs(polars_values - factor_result) - max_diff = np.max(diff) - - status = "通过" if max_diff < 1e-10 else "失败" - print(f" PE_Rank: 最大差异 = {max_diff:.2e} {status}") - - print("=" * 60)