diff --git a/.gitignore b/.gitignore index 50e6e55..5c3e46a 100644 --- a/.gitignore +++ b/.gitignore @@ -76,5 +76,9 @@ temp/ # 数据目录(允许跟踪,但忽略内容) data/* +# 训练输出目录(不需要版本控制) +src/training/output/* +!src/training/output/.gitkeep + # AI Agent 工作目录 /.sisyphus/ diff --git a/docs/Classify2_load_model_doc.md b/docs/Classify2_load_model_doc.md new file mode 100644 index 0000000..b6c705c --- /dev/null +++ b/docs/Classify2_load_model_doc.md @@ -0,0 +1,467 @@ +# Classify2_load_model.ipynb 代码流程详解文档 + +## 概述 + +本文档详细描述了 `Classify2_load_model.ipynb` 文件中完整的代码流程,该notebook实现了一个股票分类模型的特征工程与数据预处理流程。整个流程涵盖了从原始数据加载、多源数据合并、因子计算、数据清洗、特征标准化等多个环节,最终输出可供机器学习模型使用的特征矩阵。 + +--- + +## 第一部分:环境配置与初始化 + +### 1.1 基础环境设置 + + notebook开头的第一个代码单元负责基础的开发环境配置。首先启用Jupyter的自动重载功能,通过`%load_ext autoreload`和`%autoreload 2`指令,使得在修改导入的模块后能够自动重新加载,这对于开发调试阶段非常实用。随后进行垃圾回收机制的相关设置,导入`gc`模块用于手动管理内存,因为后续处理的数据量较大(数GB级别),及时释放不需要的对象可以避免内存溢出。 + +系统路径的配置通过`sys.path.append`完成,将项目根目录添加到Python的模块搜索路径中,这样可以直接导入项目内的自定义模块。打印工作目录用于确认当前运行环境。数据处理的核心库`pandas`被导入用于表格数据处理,同时导入项目内部的因子模块和工具函数,包括`get_rolling_factor`和`get_simple_factor`用于计算技术因子,`read_industry_data`用于读取行业分类数据,`calculate_score`用于计算目标变量。最后通过`warnings.filterwarnings("ignore")`忽略所有警告信息,保持输出界面的整洁。 + +### 1.2 并行计算配置 + +第二个代码单元配置了Modin框架的并行计算参数。通过设置环境变量`os.environ["MODIN_CPUS"] = "4"`,指定使用4个CPU核心进行并行计算。Modin是一个pandas的并行替代库,能够在多核CPU上并行执行pandas操作,显著加速大规模数据的处理速度。这一配置对于后续处理数千万条记录的数据框至关重要。 + +--- + +## 第二部分:核心数据加载与合并 + +### 2.1 日线数据加载 + +第三个代码单元是整个流程中最关键的数据加载步骤,依次从HDF5文件中读取多个数据源并合并。HDF5是一种高效的分块压缩存储格式,非常适合存储大规模数值型数据。 + +**第一步:加载日线行情数据** + +使用`read_and_merge_h5_data`函数从`daily_data.h5`文件中读取日线行情数据。该函数是项目自定义的工具函数,其核心逻辑在`main/utils/utils.py`中实现。读取的字段包括股票代码`ts_code`、交易日期`trade_date`、开盘价`open`、收盘价`close`、最高价`high`、最低价`low`、成交量`vol`、成交额`amount`以及涨跌幅`pct_chg`。这些字段构成分析的基础价格信息。首次调用该函数时`df`参数为空,因此直接返回读取的数据作为初始数据框。 + +**第二步:加载日线基础数据** + +继续调用`read_and_merge_h5_data`读取`daily_basic.h5`文件,获取每日股票的基础财务指标。本次读取使用内连接(inner join)方式与已有数据进行合并,这意味着只保留两个数据集中都存在的股票记录。读取的字段包括`turnover_rate`(换手率)、`pe_ttm`(滚动市盈率)、`circ_mv`(流通市值)、`total_mv`(总市值)以及`volume_ratio`(量比)。这些指标对于评估股票的流动性和估值水平非常重要。 + +**第三步:加载涨跌停限制数据** + +从`stk_limit.h5`文件中读取股票的涨跌停价格信息。读取的字段包括`pre_close`(前一日收盘价)、`up_limit`(涨停价)和`down_limit`(跌停价)。这些信息用于后续识别股票的涨停状态,是构建目标变量的重要依据。 + +**第四步:加载资金流数据** + +从`money_flow.h5`文件中读取资金流信息,这是计算资金流因子的基础数据。读取的字段包括各类成交量的分解:`buy_sm_vol`和`sell_sm_vol`代表小单买卖成交量、`buy_lg_vol`和`sell_lg_vol`代表大单买卖成交量、`buy_elg_vol`和`sell_elg_vol`代表超大单买卖成交量,以及`net_mf_vol`(净资金流成交量)。通过分析这些不同档位的资金流向,可以判断机构投资者和散户的交易行为。 + +**第五步:加载筹码分布数据** + +最后从`cyq_perf.h5`文件中读取筹码分布相关指标。这些数据反映的是股票持仓成本分布情况,包括历史最低价`his_low`、历史最高价`his_high`、不同成本分位线的价格(`cost_5pct`、`cost_15pct`、`cost_50pct`、`cost_85pct`、`cost_95pct`)、加权平均成本`weight_avg`以及获利盘比例`winner_rate`。这些指标是计算筹码分布因子的核心数据。 + +经过上述五个步骤的依次合并,最终生成的数据框包含9436343条记录、33个字段,占用内存约2.3GB。 + +### 2.2 行业数据加载与合并 + +第四个代码单元处理行业分类数据的加载与合并。首先使用`read_and_merge_h5_data`函数从`industry_data.h5`文件中读取行业分类数据,提取股票代码、二级行业代码`l2_code`以及行业纳入日期`in_date`。 + +随后定义了`merge_with_industry_data`函数,用于将行业数据与主数据框进行时间匹配合并。该函数的核心逻辑分为以下几个步骤:首先确保日期字段转换为datetime类型;然后分别对行业数据和主数据按股票代码和日期排序;接着使用`pd.merge_asof`函数进行向后合并(direction='backward'),这意味着对于每一股票的每一个交易日,会匹配该日之前(包括当日)最近的行业变更记录;对于交易日期早于所有行业变更日期的记录,使用每个股票的最早行业代码进行填充。 + +这一合并策略的设计目的是确保在任意交易日期都能获取到该股票当前所属的行业分类,这对于后续进行行业中性化处理至关重要。 + +--- + +## 第三部分:指数数据处理 + +### 3.1 指数数据读取 + +第五个代码单元实现指数数据的读取与指标计算。定义了两个核心函数:`calculate_indicators`用于计算单个指数的技术指标,`generate_index_indicators`用于批量处理多个指数。 + +在`calculate_indicators`函数中,首先对数据按交易日期排序,然后计算以下指标: + +**当日涨跌幅**:通过`(close - pre_close) / pre_close * 100`计算,这是最基础的收益率指标。 + +**RSI指标**:相对强弱指数,计算过程包括首先计算价格变动delta,然后分离上涨和下跌部分,分别计算14日滚动平均,最后通过公式`100 - (100 / (1 + rs))`得到RSI值。RSI是衡量股票近期涨跌动能的经典技术指标。 + +**MACD指标**:移动平均收敛发散指标,计算过程包括12日EMA减去26日EMA得到MACD线,9日EMA得到信号线,MACD与信号线的差值得到MACD柱。MACD是判断趋势方向和动量变化的重要工具。 + +**情绪因子**: + +- **上涨比例(up_ratio_20d)**:过去20天上涨天数占比,反映市场整体情绪 +- **成交量变化率(volume_change_rate)**:当日成交量与20日均量的比率,反映成交量异常变化 +- **波动率(volatility)**:过去20天日收益率的标准差,反映市场波动程度 +- **成交额变化率(amount_change_rate)**:当日成交额与20日均额的比率 + +`generate_index_indicators`函数则对所有指数分别计算上述指标,然后将结果透视为宽表格式,使每个指数的指标成为独立的列。最终输出包含多个指数技术指标的数据框,这些指标可以作为市场情绪的代理变量加入模型。 + +### 3.2 行业指数数据处理 + +第六个代码单元使用talib库和numpy库进行更专业的技术指标计算。talib是专业的技术分析库,包含数百种经典技术指标的实现。 + +从`sw_daily.h5`文件读取行业日线数据后,依次计算: + +**OBV(能量潮指标)**:通过`talib.OBV`函数计算,将成交量与价格变动方向结合,衡量资金流入流出的强度。 + +**短期收益率(return_5)**:5日收益率,计算公式为`x / x.shift(5) - 1`。 + +**中期收益率(return_20)**:20日收益率,用于捕捉中短期动量。 + +**动量因子(act_factor)**:通过`get_act_factor`函数计算,该函数对不同周期的EMA(指数移动平均)进行arctan变换,得到平滑的动量因子。这是项目自定义的核心因子之一。 + +**收益率分位数排名**:将收益率在截面内转换为百分位排名,消除不同行业间的基数差异,便于跨行业比较。 + +最终对所有新增字段添加`industry_`前缀,并重命名`ts_code`为`cat_l2_code`,形成行业级别的因子数据,可用于后续的行业对标分析。 + +--- + +## 第四部分:财务数据加载 + +### 4.1 财务指标数据 + +第九至十一个代码单元负责加载各类财务数据,为模型增加基本面因子。 + +**财务指标数据(fina_indicator)**:从`fina_indicator.h5`文件读取,包含`undist_profit_ps`(每股未分配利润)、`ocfps`(每股经营现金流)、`bps`(每股净资产)、`roa`(资产收益率)和`roe`(净资产收益率)。这些指标反映企业的盈利能力和财务健康状况。 + +**现金流数据(cashflow)**:从`cashflow.h5`文件读取,包含`n_cashflow_act`(经营活动净现金流),用于评估企业现金流的真实状况。 + +**资产负债表数据(balancesheet)**:从`balancesheet.h5`文件读取,包含`money_cap`(货币资金)和`total_liab`(总负债),用于计算企业的偿债能力。 + +这些财务数据通过`ann_date`(公告日期)与交易数据进行时间匹配,确保使用最新披露的财务信息。需要注意的是,财务数据存在滞后性,通常季报在季度结束后一段时间才披露,因此在匹配时使用向后合并策略。 + +--- + +## 第五部分:数据过滤与清洗 + +### 5.1 数据过滤规则 + +第十一个代码单元定义了`filter_data`函数,应用一系列过滤规则清理数据: + +**排除ST股票**:`df[~df['is_st']]`过滤掉所有ST(特别处理)股票,这类股票存在较大风险。 + +**排除北京股票**:`df[~df['ts_code'].str.endswith('BJ')]`排除北京交易所的股票。 + +**排除创业板和科创板**:`df[~df['ts_code'].str.startswith('30')]`排除创业板股票,`df[~df['ts_code'].str.startswith('68')]`排除科创板股票。这一设计可能是因为这两个板块的股票特性与主板有较大差异。 + +**排除ST类代码**:`df[~df['ts_code'].str.startswith('8')]`排除了以8开头的股票代码,通常这类代码也属于风险警示类。 + +**时间范围过滤**:`df[df['trade_date'] >= '2019-01-01']]`只保留2019年及以后的数据,确保数据质量和一致性。 + +**删除冗余列**:如果存在`in_date`列则删除,避免与交易日期混淆。 + +过滤后的数据量从9436343条减少到5087384条,约减少46%。 + +--- + +## 第六部分:因子计算 + +### 6.1 财务因子计算 + +第十二个代码单元是因子计算的核心部分,首先计算一系列财务相关因子: + +**现金流因子(cashflow_to_ev_factor)**:将现金流数据与企业价值进行对比,计算企业现金流的相对估值水平。 + +**市净率因子(book_to_price_ratio)**:通过`caculate_book_to_price_ratio`函数计算,每股净资产与股价的比值,是价值投资的经典指标。 + +### 6.2 基础技术因子 + +继续计算各类技术因子: + +**换手率均值(turnover_rate_mean_5)**:5日平均换手率,反映股票的交易活跃程度。 + +**收益率方差(variance_20)**:20日收益率方差,衡量短期波动性。 + +**BBI比率因子(bbi_ratio_factor)**:BBI是多空指数的简称,通过不同周期均线的组合判断多空趋势。 + +**日偏离度(daily_deviation)**:当日涨跌幅与市场平均涨跌幅的差异,反映个股的相对强弱。 + +**日行业偏离度(daily_industry_deviation)**:相对于所在行业平均涨跌幅的偏离,反映行业内相对表现。 + +### 6.3 滚动因子与简单因子 + +调用`get_rolling_factor`和`get_simple_factor`函数批量计算大量因子。这两个函数定义在`main/factor/factor.py`中,是项目最核心的因子计算模块。 + +**get_rolling_factor函数**主要计算的因子包括: + +**资金流因子组**: + +- `lg_elg_net_buy_vol`:超大单加大单的净买入量 +- `flow_lg_elg_intensity`:主力资金流强度,等于净买入量除以成交量 +- `sm_net_buy_vol`:散户净买入量 +- `flow_divergence_diff`:散户与主力背离度,主力净买入减去散户净买入 +- `flow_divergence_ratio`:背离比率形式 +- `lg_elg_buy_prop`:主力买入占比 +- `flow_struct_buy_change`:资金流结构变动,1日变化率 +- `flow_lg_elg_accel`:资金流加速度,二阶导数 + +**筹码分布因子组**: + +- `chip_concentration_range`:(95%成本价 - 5%成本价)/ 当前价格,反映筹码集中程度 +- `chip_skewness`:(加权平均成本 - 50%成本价)/ 50%成本价,反映成本分布偏斜方向 +- `floating_chip_proxy`:浮筹比例,结合获利盘和价格位置 +- `cost_support_15pct_change`:15%成本线变化率,反映支撑位变动 +- `cat_winner_price_zone`:获利盘价格区域分类 + +**资金与筹码结合因子**: + +- `flow_chip_consistency`:资金流与筹码结构一致性 +- `profit_taking_vs_absorb`:获利了结压力与承接盘强度 + +**收益率分布因子**: + +- `upside_vol`和`downside_vol`:上涨和下跌方向的标准差 +- `vol_ratio`:上下波动率比值 +- `return_skew`和`return_kurtosis`:收益率的偏度和峰度 + +**成交量因子**: + +- `volume_change_rate`:成交量变化率 +- `cat_volume_breakout`:成交量突破信号 +- `turnover_deviation`:换手率偏离度 +- `cat_turnover_spike`:换手率激增信号 +- `avg_volume_ratio`和`cat_volume_ratio_breakout`:量比相关因子 + +**talib技术指标**: + +- `atr_14`和`atr_6`:14日和6日平均真实波幅 +- `obv`和`maobv_6`:能量潮及其6日均线 +- `rsi_3`:3日RSI + +**收益率因子**: + +- `return_5`和`return_20`:5日和20日收益率 +- `std_return_5`、`std_return_90`、`std_return_90_2`:不同周期的收益率标准差 + +**EMA与动量因子**: + +- 各周期EMA(5、13、20、60日) +- `act_factor1`到`act_factor4`:基于EMA的动量因子 + +**协方差因子**: + +- `cov`:高价与成交量的滚动协方差 +- `delta_cov`:协方差差分 +- `alpha_22_improved`:改进的alpha因子 + +**其他因子**: + +- `alpha_003`:收盘价与开盘价的相对位置 +- `alpha_007`:收盘价与成交量的5日滚动相关性 +- `alpha_013`:5日与20日累计收益差值 +- `vol_break`:价格突破85%成本线且量比大于2的信号 +- `weight_roc5`:加权成本5日变化率 +- `price_cost_divergence`:价格与成本相关性 +- `smallcap_concentration`:小盘股筹码集中度 +- `cost_stability`:筹码稳定性指数 +- `liquidity_risk`:筹码流动性风险 + +**get_simple_factor函数**在滚动因子的基础上进一步计算衍生因子: + +- `momentum_factor`:动量因子,等于成交量变化率加0.5倍换手率偏离度 +- `resonance_factor`:共振因子,量比乘以涨跌幅 +- `log_close`:收盘价对数 +- `cat_vol_spike`:成交量激增分类变量 +- `up`和`down`:上下影线比例 +- `obv_maobv_6`:OBV与6日均线差值 +- `std_return_5_over_std_return_90`:短期与长期波动率比值 +- `act_factor5`和`act_factor6`:综合动量因子 +- `active_buy_volume_*`:各档位主动买入占比 +- `ctrl_strength`:控制强度,反映成本区间与历史区间的比值 +- `low_cost_dev`、`asymmetry`、`lock_factor`:各类筹码特征因子 +- `cat_golden_resonance`:黄金共振信号 + +### 6.4 资金流因子 + +第十三个代码单元继续计算`money_factor.py`中定义的各类资金流相关因子。这些因子主要分析大单资金的流动特征和动量属性。 + +**lg_flow_mom_corr**:大单资金流与20至60日滚动窗口的动量相关性,衡量资金流趋势的持续性。 + +**lg_flow_accel**:大单资金流加速度,捕捉资金流入流出的加速变化。 + +**profit_pressure**:盈利压力因子,结合价格位置和资金流向。 + +**underwater_resistance**:水下阻力因子,分析亏损区域的支持强度。 + +**cost_conc_std_20**:20日成本集中度标准差,反映筹码分布的稳定程度。 + +**profit_decay_20**:20日盈利衰减因子,衡量获利盘随时间的减少程度。 + +**vol_amp_loss_20**:20日波动幅度损失因子。 + +**vol_drop_profit_cnt_5**:5日内量价齐升计数。 + +**lg_flow_vol_interact_20**:20日大单资金流与成交量交互因子。 + +**cost_break_confirm_cnt_5**:5日内成本突破确认计数。 + +**atr_norm_channel_pos_14**:14日ATR标准化通道位置。 + +**turnover_diff_skew_20**:20日换手率差值偏度。 + +**lg_sm_flow_diverge_20**:20日大小单资金流背离。 + +**pullback_strong_20_20**:20日回撤强度因子。 + +**vol_wgt_hist_pos_20**:20日成交量加权历史位置。 + +**vol_adj_roc_20**:20日成交量调整变动率。 + +### 6.5 截面排名因子 + +第十三个代码单元的后半部分定义了一系列`cs_rank_*`开头的截面排名因子,这些因子都是通过截面排名方式计算得出,可以消除不同股票间的基数差异。 + +**cs_rank_net_lg_flow_val**:大单净流入值截面排名 + +**cs_rank_flow_divergence**:资金流背离截面排名 + +**cs_rank_industry_adj_lg_flow**:行业调整后大单资金流截面排名 + +**cs_rank_elg_buy_ratio**:超大单买入占比截面排名 + +**cs_rank_rel_profit_margin**:相对盈利-margin截面排名 + +**cs_rank_cost_breadth**:成本宽度截面排名 + +**cs_rank_dist_to_upper_cost**:到上方成本距离截面排名 + +**cs_rank_winner_rate**:获利盘比例截面排名 + +**cs_rank_intraday_range**:日内振幅截面排名 + +**cs_rank_close_pos_in_range**:收盘价在成本区间位置截面排名 + +**cs_rank_opening_gap**:跳空缺口截面排名(需要前收盘价) + +**cs_rank_pos_in_hist_range**:历史区间位置截面排名 + +**cs_rank_vol_x_profit_margin**:成交量与盈利margin交互截面排名 + +**cs_rank_lg_flow_price_concordance**:大单资金流与价格一致性截面排名 + +**cs_rank_turnover_per_winner**:单位获利盘换手率截面排名 + +**cs_rank_ind_cap_neutral_pe**:行业市值中性市盈率截面排名 + +**cs_rank_volume_ratio**:量比截面排名 + +**cs_rank_elg_buy_sell_sm_ratio**:超大单买卖与小单比值截面排名 + +**cs_rank_cost_dist_vol_ratio**:成本分布成交量比值截面排名 + +**cs_rank_size**:市值规模截面排名 + +经过上述所有因子计算,数据框最终包含181个字段,内存占用约6.5GB。 + +--- + +## 第七部分:数据预处理函数定义 + +### 7.1 特征漂移检测 + +第十四个代码单元定义了`remove_shifted_features`函数,用于检测训练集和测试集之间是否存在特征分布漂移。漂移检测使用两种统计方法: + +**KS检验(Kolmogorov-Smirnov检验)**:比较两个分布是否来自同一分布,通过p值判断,p值小于阈值(默认0.05)则认为存在显著差异。 + +**Wasserstein距离**:衡量两个分布之间的Earth Mover距离,距离大于阈值(默认0.1)则认为漂移严重。 + +如果特征同时满足两个条件,则认为该特征存在漂移并将其移除。这一步骤对于确保模型在测试集上的泛化能力非常重要。 + +### 7.2 标准化处理函数 + +第十五个代码单元定义了三个核心的标准化处理函数: + +**cs_mad_filter(截面MAD去极值函数)**: + +MAD(Median Absolute Deviation,中位绝对偏差)是一种稳健的极值检测方法。具体步骤包括:按日期分组计算每列的中位数,然后计算每个值与中位数的绝对偏差,再次取中位数得到MAD,最后将超出`[median - k * MAD, median + k * MAD]`范围的值截断到边界。默认k=3,scale_factor=1.4826使得MAD约等于正态分布的标准差。 + +**cs_neutralize_market_cap_numpy(市值中性化函数)**: + +对每个交易日的每只股票,使用OLS回归将因子值对市值(取对数)进行回归,提取残差作为中性化后的因子值。这一步骤消除因子中与市值相关的系统性偏差,使因子更能反映股票的真实特性而非市值效应。 + +**cs_zscore_standardize(截面Z-Score标准化函数)**: + +对每个截面(每日)计算各因子的均值和标准差,然后进行Z-Score变换:`Z = (value - mean) / (std + epsilon)`。这使得不同量纲的因子具有可比性,且均值为0、标准差为1。 + +**fill_nan_with_daily_median(截面中位数填充函数)**: + +对每个交易日分别计算各因子的中位数,用该中位数填充该日内的缺失值。这一方法比全局中位数填充更能反映数据的时间特性。 + +### 7.3 其他预处理函数 + +第十五个代码单元还定义了其他辅助预处理函数: + +**remove_outliers_label_percentile**:对标签进行百分位去极值,默认保留1%到99%分位之间的数据。 + +**calculate_risk_adjusted_target**:计算风险调整后的目标变量,结合未来收益率和未来波动性。 + +**calculate_score**:计算综合评分目标,考虑未来收益减去风险惩罚(最大回撤)。 + +**remove_highly_correlated_features**:移除高度相关的特征,避免多重共线性问题。 + +**cross_sectional_standardization**:截面标准化,使用StandardScaler进行标准化。 + +**neutralize_manual_revised**:手动实现的行业市值中性化函数,对每个行业分别进行回归取残差。 + +**mad_filter**:全局MAD去极值函数(简化版)。 + +**percentile_filter**:百分位去极值函数。 + +**iqr_filter**:四分位距标准化函数。 + +**quantile_filter**:滚动分位数去极值函数。 + +**select_top_features_by_rankic**:基于RankIC选择最优特征。 + +**create_deviation_within_dates**:创建截面偏差特征,计算行业内各因子与均值的偏差。 + +--- + +## 第八部分:目标变量构建 + +### 8.1 未来收益率计算 + +第十六个代码单元负责构建预测目标变量。 + +**未来收益率**:通过`df.groupby('ts_code')['close'].shift(-days) / df['close'] - 1`计算,days默认为5,即预测未来5日的收益率。shift(-days)表示向后移动获取未来价格。 + +**涨停标签**:构建分类目标变量,首先识别当日涨幅超过5%的股票(`cat_up_limit`),然后计算过去5日内是否存在涨停(通过rolling窗口的max函数),再向后shift(5)避免未来信息泄露,最后fillna(0)并转换为整数类型。 + +**过滤异常收益率**:使用`between`方法保留1%到99%分位之间的收益率,去除极端值。 + +--- + +## 第九部分:特征列筛选 + +### 9.1 特征列筛选逻辑 + +第十七个代码单元进行特征列的最终筛选。通过一系列排除规则,从所有可用列中筛选出建模所需的特征列: + +**排除非特征列**:排除`trade_date`、`ts_code`、`label`等标识列;排除包含`future`、`label`、`score`、`gen`、`is_st`、`pe_ttm`、`circ_mv`、`code`等关键词的列;排除原始基础数据列`origin_columns`;排除下划线开头的临时列。 + +**排除特定特征**:手动排除存在问题的特征如`intraday_lg_flow_corr_20`、`cap_neutral_cost_metric`、`hurst_net_mf_vol_60`等;排除财务因子`roa`和`roe`。 + +最终筛选出191个特征列用于建模。 + +--- + +## 第十部分:缺失值处理与标准化流程 + +### 10.1 缺失值填充 + +第十八个代码单元将所有特征的缺失值填充为0。这一简化处理的假设是:缺失值可能代表数据不可用或极小概率事件,用0填充对模型影响相对中性。 + +### 10.2 完整的预处理流水线 + +第十九个代码单元展示了完整的预处理流程执行: + +**第一步:特征列确定**——合并行业数据、指数数据后确定最终特征列表。 + +**第二步:MAD去极值**——执行第一轮截面MAD去极值,处理全量特征。 + +**第三步:市值中性化**——对第一轮去极值后的特征进行截面市值中性化。 + +**第四步:第二轮MAD去极值**——对中性化后的特征再次进行去极值处理。 + +**第五步:Z-Score标准化**——最后对所有特征进行截面Z-Score标准化。 + +整个流水线确保了最终进入模型的特征具有:稳定的分布(去极值)、去除市值偏差(中性化)、可比性(标准化)。 + +--- + +## 总结 + +`Classify2_load_model.ipynb`实现了一个完整、规范的量化因子工程流程。整个流程涵盖了数据加载、多源数据合并、因子计算、特征筛选、预处理标准化等关键步骤。 + +**数据层面**:整合了日线行情、资金流、筹码分布、财务指标、行业分类、指数数据等多维度数据源,形成了超过180个特征的基础特征池。 + +**因子层面**:因子体系覆盖了资金流因子、筹码分布因子、技术指标因子、动量因子、波动率因子、截面排名因子等多个类别,形成了多角度、全方位的特征体系。 + +**预处理层面**:建立了完整的预处理流水线,包括MAD去极值、市值中性化、Z-Score标准化等标准化步骤,确保特征的质量和稳定性。 + +整个notebook的设计体现了量化投资特征工程的最佳实践,为后续的机器学习模型训练提供了高质量的数据基础。 diff --git a/src/data/api_wrappers/__init__.py b/src/data/api_wrappers/__init__.py index 13f3b89..b7a2bac 100644 --- a/src/data/api_wrappers/__init__.py +++ b/src/data/api_wrappers/__init__.py @@ -20,6 +20,7 @@ Example: """ from src.data.api_wrappers.api_daily import get_daily, sync_daily, preview_daily_sync, DailySync +from src.data.api_wrappers.financial_data.api_income import get_income, sync_income, IncomeSync from src.data.api_wrappers.api_bak_basic import get_bak_basic, sync_bak_basic from src.data.api_wrappers.api_namechange import get_namechange, sync_namechange from src.data.api_wrappers.api_stock_basic import get_stock_basic, sync_all_stocks @@ -37,6 +38,10 @@ __all__ = [ "sync_daily", "preview_daily_sync", "DailySync", + # Income statement + "get_income", + "sync_income", + "IncomeSync", # Historical stock list "get_bak_basic", "sync_bak_basic", diff --git a/src/data/api_wrappers/financial_data/api_financial_sync.py b/src/data/api_wrappers/financial_data/api_financial_sync.py new file mode 100644 index 0000000..a85ad4a --- /dev/null +++ b/src/data/api_wrappers/financial_data/api_financial_sync.py @@ -0,0 +1,587 @@ +"""财务数据统一同步调度中心。 + +该模块作为财务数据同步的调度中心,统一管理各类型财务数据的同步流程。 +支持全量同步和增量同步两种模式。 + +财务数据类型: +- income: 利润表 (已实现) +- balance: 资产负债表 (预留) +- cashflow: 现金流量表 (预留) + +同步模式: +1. 全量同步 (force_full=True): + - 检查表是否存在,如不存在则建表+建索引 + - 从默认开始日期 (20180101) 同步到当前季度 + +2. 增量同步 (force_full=False): + - 获取表中最新季度 (MAX(end_date)) + - 计算当前季度(如果当前日期未到季末,则用前一季度) + - 如果最新季度 == 当前季度,不同步(避免消耗流量) + - 否则从最新季度+1 同步到当前季度 + +使用方式: + # 增量同步利润表数据(推荐) + from src.data.api_wrappers.financial_data.api_financial_sync import sync_financial + sync_financial() + + # 全量同步利润表数据 + from src.data.api_wrappers.financial_data.api_financial_sync import sync_financial + sync_financial(force_full=True) + + # 预览同步 + from src.data.api_wrappers.financial_data.api_financial_sync import preview_sync + preview = preview_sync() +""" + +from typing import Optional, Dict, List +from datetime import datetime + +import pandas as pd + +from src.data.storage import Storage, ThreadSafeStorage +from src.data.utils import ( + get_today_date, + get_quarters_in_range, + date_to_quarter, + DEFAULT_START_DATE, +) +from src.data.api_wrappers.financial_data.api_income import get_income + + +# ============================================================================= +# 财务数据表结构定义 +# ============================================================================= + +# 各财务数据表的表名和字段定义 +FINANCIAL_TABLES = { + "income": { + "table_name": "financial_income", + "api_name": "income_vip", + "period_field": "end_date", # 用于存储最新季度的字段 + "get_data_func": get_income, + }, + # 预留:资产负债表 + # "balance": { + # "table_name": "financial_balance", + # "api_name": "balance Sheet_vip", + # "period_field": "end_date", + # "get_data_func": get_balance, + # }, + # 预留:现金流量表 + # "cashflow": { + # "table_name": "financial_cashflow", + # "api_name": "cashflow_vip", + # "period_field": "end_date", + # "get_data_func": get_cashflow, + # }, +} + + +# ============================================================================= +# 财务数据同步核心类 +# ============================================================================= + + +class FinancialSync: + """财务数据统一同步管理器。 + + 支持全量同步和增量同步,自动检测数据状态并选择最优同步策略。 + + 功能特性: + - 全量/增量同步自动切换 + - 自动建表和索引(如不存在) + - 智能季度计算(当前季度未到季末时使用前一季度) + - 流量保护(最新季度==当前季度时不请求API) + + Example: + >>> sync = FinancialSync() + >>> sync.sync_all() # 增量同步所有财务数据 + >>> sync.sync_all(force_full=True) # 全量同步 + >>> sync.sync_income() # 只同步利润表 + """ + + def __init__(self): + """初始化同步管理器""" + self.storage = Storage() + self.thread_storage = ThreadSafeStorage() + + def _create_table_if_not_exists(self, table_name: str) -> None: + """如果表不存在则创建表和索引。 + + Args: + table_name: 表名 + """ + if self.storage.exists(table_name): + print(f"[FinancialSync] 表 {table_name} 已存在,跳过建表") + return + + print(f"[FinancialSync] 表 {table_name} 不存在,创建表和索引...") + + # 根据表名创建不同的表结构 + if table_name == "financial_income": + self.storage._connection.execute(f""" + CREATE TABLE IF NOT EXISTS {table_name} ( + ts_code VARCHAR(16) NOT NULL, + ann_date DATE, + f_ann_date DATE, + end_date DATE NOT NULL, + report_type INTEGER, + comp_type INTEGER, + basic_eps DOUBLE, + diluted_eps DOUBLE, + PRIMARY KEY (ts_code, end_date) + ) + """) + # 创建索引 + self.storage._connection.execute(f""" + CREATE INDEX IF NOT EXISTS idx_financial_ann + ON {table_name}(ts_code, ann_date) + """) + else: + # 默认表结构 + self.storage._connection.execute(f""" + CREATE TABLE IF NOT EXISTS {table_name} ( + ts_code VARCHAR(16) NOT NULL, + end_date DATE NOT NULL, + PRIMARY KEY (ts_code, end_date) + ) + """) + + print(f"[FinancialSync] 表 {table_name} 创建完成") + + def _get_latest_quarter( + self, table_name: str, period_field: str = "end_date" + ) -> Optional[str]: + """获取表中最新季度。 + + Args: + table_name: 表名 + period_field: 季度字段名 + + Returns: + 最新季度字符串 (YYYYMMDD),如无数据返回 None + """ + try: + result = self.storage._connection.execute(f""" + SELECT MAX({period_field}) FROM {table_name} + """).fetchone() + + if result and result[0]: + # 转换为字符串格式 + latest = result[0] + if hasattr(latest, "strftime"): + return latest.strftime("%Y%m%d") + return str(latest) + return None + except Exception as e: + print(f"[FinancialSync] 获取最新季度失败: {e}") + return None + + def _get_current_quarter(self) -> str: + """获取当前季度(考虑是否到季末)。 + + 如果当前日期未到季度最后一天,则返回前一季度。 + 这样可以避免获取尚无数据的未来季度。 + + Returns: + 当前季度字符串 (YYYYMMDD) + """ + today = get_today_date() + current_quarter = date_to_quarter(today) + + # 检查今天是否到了当前季度的最后一天 + if today < current_quarter: + # 未到季末,返回前一季度 + return self._get_prev_quarter(current_quarter) + + return current_quarter + + def _get_prev_quarter(self, quarter: str) -> str: + """获取前一季度。 + + Args: + quarter: 季度字符串 (YYYYMMDD) + + Returns: + 前一季度字符串 (YYYYMMDD) + """ + year = int(quarter[:4]) + month_day = quarter[4:] + + if month_day == "0331": + # Q1 -> 去年 Q4 + return f"{year - 1}1231" + elif month_day == "0630": + # Q2 -> Q1 + return f"{year}0331" + elif month_day == "0930": + # Q3 -> Q2 + return f"{year}0630" + else: # "1231" + # Q4 -> Q3 + return f"{year}0930" + + def _get_next_quarter(self, quarter: str) -> str: + """获取下一季度。 + + Args: + quarter: 季度字符串 (YYYYMMDD) + + Returns: + 下一季度字符串 (YYYYMMDD) + """ + year = int(quarter[:4]) + month_day = quarter[4:] + + if month_day == "0331": + # Q1 -> Q2 + return f"{year}0630" + elif month_day == "0630": + # Q2 -> Q3 + return f"{year}0930" + elif month_day == "0930": + # Q3 -> Q4 + return f"{year}1231" + else: # "1231" + # Q4 -> 明年 Q1 + return f"{year + 1}0331" + + def _check_incremental_needed( + self, + table_name: str, + period_field: str = "end_date", + ) -> tuple[bool, Optional[str], Optional[str]]: + """检查增量同步是否需要。 + + Args: + table_name: 表名 + period_field: 季度字段名 + + Returns: + (是否需要同步, 起始季度, 目标季度) + - 如果不需要同步,返回 (False, None, None) + """ + # 获取表中最新季度 + latest_quarter = self._get_latest_quarter(table_name, period_field) + + if latest_quarter is None: + # 无本地数据,需要全量同步 + print(f"[FinancialSync] 表 {table_name} 无数据,需要全量同步") + return (True, DEFAULT_START_DATE, self._get_current_quarter()) + + print(f"[FinancialSync] 表 {table_name} 最新季度: {latest_quarter}") + + # 获取当前季度(考虑是否到季末) + current_quarter = self._get_current_quarter() + print(f"[FinancialSync] 当前季度: {current_quarter}") + + # 比较:如果最新季度 >= 当前季度,不需要同步 + if latest_quarter >= current_quarter: + print( + f"[FinancialSync] 最新季度 {latest_quarter} >= 当前季度 {current_quarter},跳过增量同步" + ) + return (False, None, None) + + # 需要增量同步:从最新季度+1 到 当前季度 + start_quarter = self._get_next_quarter(latest_quarter) + print(f"[FinancialSync] 增量同步: {start_quarter} -> {current_quarter}") + + return (True, start_quarter, current_quarter) + + def _sync_single_table( + self, + table_config: Dict, + start_quarter: str, + end_quarter: str, + ) -> int: + """同步单个财务数据表。 + + Args: + table_config: 表配置字典 + start_quarter: 起始季度 + end_quarter: 目标季度 + + Returns: + 同步的记录数 + """ + table_name = table_config["table_name"] + get_data_func = table_config["get_data_func"] + + # 获取需要同步的季度列表 + quarters = get_quarters_in_range(start_quarter, end_quarter) + print(f"[FinancialSync] 计划同步 {len(quarters)} 个季度: {quarters}") + + total_records = 0 + + # 对每个季度调用 API 获取数据 + for period in quarters: + try: + df = get_data_func(period) + if df.empty: + print(f"[WARN] 季度 {period} 无数据") + continue + + # 只保留合并报表 (report_type='1',注意是字符串) + if "report_type" in df.columns: + df = df[df["report_type"] == "1"] + + if not df.empty: + self.thread_storage.queue_save(table_name, df) + print(f"[FinancialSync] 季度 {period} -> {len(df)} 条记录") + total_records += len(df) + + except Exception as e: + print(f"[ERROR] 获取季度 {period} 数据失败: {e}") + + # 刷新缓存到数据库 + self.thread_storage.flush() + + return total_records + + def sync_income( + self, + force_full: bool = False, + ) -> Dict: + """同步利润表数据。 + + Args: + force_full: 若为 True,强制全量同步 + + Returns: + 同步结果字典 + """ + table_config = FINANCIAL_TABLES["income"] + table_name = table_config["table_name"] + period_field = table_config["period_field"] + + print("\n" + "=" * 60) + print(f"[FinancialSync] 开始同步利润表 (force_full={force_full})") + print("=" * 60) + + # 1. 全量同步:建表 + if force_full: + self._create_table_if_not_exists(table_name) + start_quarter = DEFAULT_START_DATE + end_quarter = self._get_current_quarter() + else: + # 2. 增量同步:检查是否需要 + needed, start_quarter, end_quarter = self._check_incremental_needed( + table_name, period_field + ) + + if not needed: + return { + "status": "skipped", + "message": "数据已是最新", + "table": table_name, + } + + # 检查表是否存在,不存在则创建 + if not self.storage.exists(table_name): + self._create_table_if_not_exists(table_name) + + # 3. 执行同步 + print(f"[FinancialSync] 同步范围: {start_quarter} -> {end_quarter}") + total_records = self._sync_single_table( + table_config, start_quarter, end_quarter + ) + + result = { + "status": "success", + "table": table_name, + "start_quarter": start_quarter, + "end_quarter": end_quarter, + "records": total_records, + } + + print(f"[FinancialSync] 利润表同步完成: {total_records} 条记录") + return result + + def sync_all( + self, + force_full: bool = False, + ) -> Dict[str, Dict]: + """同步所有财务数据表。 + + Args: + force_full: 若为 True,强制全量同步 + + Returns: + 各表同步结果字典 + """ + results = {} + + print("\n" + "=" * 60) + print(f"[FinancialSync] 开始同步所有财务数据 (force_full={force_full})") + print("=" * 60) + + # 同步各财务数据表 + for data_type, table_config in FINANCIAL_TABLES.items(): + try: + if data_type == "income": + result = self.sync_income(force_full=force_full) + else: + # 预留其他表的同步逻辑 + print(f"[FinancialSync] {data_type} 暂未实现,跳过") + result = {"status": "not_implemented"} + + results[data_type] = result + + except Exception as e: + print(f"[ERROR] 同步 {data_type} 失败: {e}") + results[data_type] = {"status": "error", "error": str(e)} + + # 打印汇总 + print("\n" + "=" * 60) + print("[FinancialSync] 同步汇总") + print("=" * 60) + for data_type, result in results.items(): + status = result.get("status", "unknown") + records = result.get("records", 0) + print(f" {data_type}: {status} ({records} records)") + print("=" * 60) + + return results + + +# ============================================================================= +# 便捷函数 +# ============================================================================= + + +def sync_financial( + data_type: str = "income", + force_full: bool = False, +) -> Dict: + """同步财务数据(便捷函数)。 + + Args: + data_type: 财务数据类型 ('income', 'balance', 'cashflow') + force_full: 若为 True,强制全量同步 + + Returns: + 同步结果字典 + + Example: + >>> # 增量同步利润表 + >>> sync_financial() + >>> # 全量同步 + >>> sync_financial(force_full=True) + """ + syncer = FinancialSync() + + if data_type == "income": + return syncer.sync_income(force_full=force_full) + else: + raise ValueError(f"不支持的财务数据类型: {data_type}") + + +def sync_all_financial(force_full: bool = False) -> Dict[str, Dict]: + """同步所有财务数据(便捷函数)。 + + Args: + force_full: 若为 True,强制全量同步 + + Returns: + 各表同步结果字典 + + Example: + >>> # 增量同步所有财务数据 + >>> sync_all_financial() + >>> # 全量同步 + >>> sync_all_financial(force_full=True) + """ + syncer = FinancialSync() + return syncer.sync_all(force_full=force_full) + + +def preview_sync() -> Dict: + """预览同步信息(不实际同步)。 + + Returns: + 预览信息字典: + { + 'income': { + 'sync_needed': bool, + 'latest_quarter': str, + 'current_quarter': str, + 'start_quarter': str, + 'end_quarter': str, + }, + ... + } + """ + syncer = FinancialSync() + preview = {} + + for data_type, table_config in FINANCIAL_TABLES.items(): + if data_type != "income": + continue + + table_name = table_config["table_name"] + period_field = table_config["period_field"] + + # 获取最新季度 + latest_quarter = syncer._get_latest_quarter(table_name, period_field) + current_quarter = syncer._get_current_quarter() + + # 检查是否需要同步 + needed, start_quarter, end_quarter = syncer._check_incremental_needed( + table_name, period_field + ) + + preview[data_type] = { + "sync_needed": needed, + "latest_quarter": latest_quarter, + "current_quarter": current_quarter, + "start_quarter": start_quarter, + "end_quarter": end_quarter, + } + + return preview + + +# ============================================================================= +# 主程序入口 +# ============================================================================= + + +if __name__ == "__main__": + import sys + + print("=" * 60) + print("财务数据同步模块") + print("=" * 60) + print("\n使用方式:") + print(" # 预览同步信息") + print( + " from src.data.api_wrappers.financial_data.api_financial_sync import preview_sync" + ) + print(" preview = preview_sync()") + print("") + print(" # 增量同步(推荐)") + print( + " from src.data.api_wrappers.financial_data.api_financial_sync import sync_financial" + ) + print(" sync_financial()") + print("") + print(" # 全量同步") + print(" sync_financial(force_full=True)") + print("") + print(" # 同步所有财务数据") + print( + " from src.data.api_wrappers.financial_data.api_financial_sync import sync_all_financial" + ) + print(" sync_all_financial()") + print("=" * 60) + + # 默认执行增量同步 + if len(sys.argv) > 1 and sys.argv[1] == "--full": + print("\n[Main] 执行全量同步...") + result = sync_all_financial(force_full=True) + else: + print("\n[Main] 执行增量同步...") + result = sync_financial() + + print("\n[Main] 执行完成!") + print(f"结果: {result}") diff --git a/src/data/api_wrappers/financial_data/api_income.py b/src/data/api_wrappers/financial_data/api_income.py new file mode 100644 index 0000000..fd66ca2 --- /dev/null +++ b/src/data/api_wrappers/financial_data/api_income.py @@ -0,0 +1,139 @@ +"""利润表数据接口 (VIP 版本) + +使用 Tushare VIP 接口 (income_vip) 获取利润表数据。 +按季度同步,一次请求获取一个季度的全部上市公司数据。 + +接口说明: +- income_vip: 获取某一季度全部上市公司利润表数据 +- 需要 5000 积分才能调用 +- period 参数为报告期(季度最后一天,如 20231231) +""" + +import pandas as pd +from typing import Optional, List +from tqdm import tqdm + +from src.data.client import TushareClient +from src.data.storage import ThreadSafeStorage +from src.data.utils import get_today_date, get_quarters_in_range + + +def get_income( + period: str, + fields: Optional[str] = None, +) -> pd.DataFrame: + """获取利润表数据 (VIP 接口) + + 从 Tushare 获取指定季度的全部上市公司利润表数据。 + + Args: + period: 报告期,季度最后一天日期 (如 '20231231', '20230930') + - 0331: 一季报 + - 0630: 半年报 + - 0930: 三季报 + - 1231: 年报 + fields: 指定返回字段,默认返回全部字段 + + Returns: + pd.DataFrame 包含利润表数据: + - ts_code: 股票代码 + - ann_date: 公告日期 + - end_date: 报告期 + - basic_eps: 基本每股收益 + - report_type: 报告类型 (1=合并报表) + + Example: + >>> data = get_income('20231231') + >>> print(data[['ts_code', 'ann_date', 'basic_eps']].head()) + """ + client = TushareClient() + # 默认字段:返回全部字段(利润表有100+字段) + if fields is None: + fields = "ts_code,ann_date,f_ann_date,end_date,report_type,comp_type,end_type,basic_eps,diluted_eps,total_revenue,revenue,int_income,prem_earned,comm_income,n_commis_income,n_oth_income,n_oth_b_income,prem_income,out_prem,une_prem_reser,reins_income,n_sec_tb_income,n_sec_uw_income,n_asset_mg_income,oth_b_income,fv_value_chg_gain,invest_income,ass_invest_income,forex_gain,total_cogs,oper_cost,int_exp,comm_exp,biz_tax_surchg,sell_exp,admin_exp,fin_exp,assets_impair_loss,prem_refund,compens_payout,reser_insur_liab,div_payt,reins_exp,oper_exp,compens_payout_refu,insur_reser_refu,reins_cost_refund,other_bus_cost,operate_profit,non_oper_income,non_oper_exp,nca_disploss,total_profit,income_tax,n_income,n_income_attr_p,minority_gain,oth_compr_income,t_compr_income,compr_inc_attr_p,compr_inc_attr_m_s,ebit,ebitda,insurance_exp,undist_profit,distable_profit,rd_exp,fin_exp_int_exp,fin_exp_int_inc,transfer_surplus_rese,transfer_housing_imprest,transfer_oth,adj_lossgain,withdra_legal_surplus,withdra_legal_pubfund,withdra_biz_devfund,withdra_rese_fund,withdra_oth_ersu,workers_welfare,distr_profit_shrhder,prfshare_payable_dvd,comshare_payable_dvd,capit_comstock_div,net_after_nr_lp_correct,credit_impa_loss,net_expo_hedging_benefits,oth_impair_loss_assets,total_opcost,amodcost_fin_assets,oth_income,asset_disp_income,continued_net_profit,end_net_profit,update_flag" + + params = {"fields": fields, "period": period} + return client.query("income_vip", **params) + + +# ============================================================================= +# IncomeSync - 利润表数据批量同步类 +# ============================================================================= + + +class IncomeSync: + """利润表数据批量同步管理器 (VIP 版本) + + 功能特性: + - 按季度同步,每次请求获取该季度全部上市公司数据 + - 使用 income_vip 接口 + - 只保留合并报表(report_type=1) + - 使用 ThreadSafeStorage 安全写入 + + Example: + >>> sync = IncomeSync() + >>> sync.sync(start_date='20200101', end_date='20231231') + """ + + def __init__(self): + """初始化同步管理器""" + self.storage = ThreadSafeStorage() + self.client = TushareClient() + + def sync( + self, + start_date: str, + end_date: Optional[str] = None, + ) -> None: + """同步利润表数据 + + Args: + start_date: 开始日期 YYYYMMDD + end_date: 结束日期 YYYYMMDD(默认为今天) + """ + if end_date is None: + end_date = get_today_date() + + # 获取日期范围内的所有季度 + quarters = get_quarters_in_range(start_date, end_date) + print(f"[IncomeSync] 计划同步 {len(quarters)} 个季度: {quarters}") + + # 对每个季度调用 income_vip 获取全部股票数据 + for period in tqdm(quarters, desc="Syncing income by quarter"): + try: + df = get_income(period) + if df.empty: + print(f"[WARN] 季度 {period} 无数据") + continue + + # 只保留合并报表 (report_type='1',注意是字符串) + if "report_type" in df.columns: + df = df[df["report_type"] == "1"] + + if not df.empty: + self.storage.queue_save("financial_income", df) + print(f"[IncomeSync] 季度 {period} -> {len(df)} 条记录") + + except Exception as e: + print(f"[ERROR] 获取季度 {period} 数据失败: {e}") + + # 刷新缓存到数据库 + self.storage.flush() + print(f"[IncomeSync] 同步完成,共处理 {len(quarters)} 个季度") + + +def sync_income( + start_date: str, + end_date: Optional[str] = None, +) -> None: + """同步利润表数据(便捷函数) + + Args: + start_date: 开始日期 YYYYMMDD + end_date: 结束日期 YYYYMMDD(默认为今天) + + Example: + >>> sync_income('20200101') + >>> sync_income('20200101', '20231231') + """ + syncer = IncomeSync() + syncer.sync(start_date, end_date) diff --git a/src/data/api_wrappers/financial_data/financial_api.md b/src/data/api_wrappers/financial_data/financial_api.md new file mode 100644 index 0000000..6e0d629 --- /dev/null +++ b/src/data/api_wrappers/financial_data/financial_api.md @@ -0,0 +1,145 @@ +利润表 +接口:income,可以通过数据工具调试和查看数据。 +描述:获取上市公司财务利润表数据 +积分:用户需要至少2000积分才可以调取,具体请参阅积分获取办法 + +提示:当前接口只能按单只股票获取其历史数据,如果需要获取某一季度全部上市公司数据,请使用income_vip接口(参数一致),需积攒5000积分。 + +输入参数 + +名称 类型 必选 描述 +ts_code str Y 股票代码 +ann_date str N 公告日期(YYYYMMDD格式,下同) +f_ann_date str N 实际公告日期 +start_date str N 公告日开始日期 +end_date str N 公告日结束日期 +period str N 报告期(每个季度最后一天的日期,比如20171231表示年报,20170630半年报,20170930三季报) +report_type str N 报告类型,参考文档最下方说明 +comp_type str N 公司类型(1一般工商业2银行3保险4证券) +输出参数 + +名称 类型 默认显示 描述 +ts_code str Y TS代码 +ann_date str Y 公告日期 +f_ann_date str Y 实际公告日期 +end_date str Y 报告期 +report_type str Y 报告类型 见底部表 +comp_type str Y 公司类型(1一般工商业2银行3保险4证券) +end_type str Y 报告期类型 +basic_eps float Y 基本每股收益 +diluted_eps float Y 稀释每股收益 +total_revenue float Y 营业总收入 +revenue float Y 营业收入 +int_income float Y 利息收入 +prem_earned float Y 已赚保费 +comm_income float Y 手续费及佣金收入 +n_commis_income float Y 手续费及佣金净收入 +n_oth_income float Y 其他经营净收益 +n_oth_b_income float Y 加:其他业务净收益 +prem_income float Y 保险业务收入 +out_prem float Y 减:分出保费 +une_prem_reser float Y 提取未到期责任准备金 +reins_income float Y 其中:分保费收入 +n_sec_tb_income float Y 代理买卖证券业务净收入 +n_sec_uw_income float Y 证券承销业务净收入 +n_asset_mg_income float Y 受托客户资产管理业务净收入 +oth_b_income float Y 其他业务收入 +fv_value_chg_gain float Y 加:公允价值变动净收益 +invest_income float Y 加:投资净收益 +ass_invest_income float Y 其中:对联营企业和合营企业的投资收益 +forex_gain float Y 加:汇兑净收益 +total_cogs float Y 营业总成本 +oper_cost float Y 减:营业成本 +int_exp float Y 减:利息支出 +comm_exp float Y 减:手续费及佣金支出 +biz_tax_surchg float Y 减:营业税金及附加 +sell_exp float Y 减:销售费用 +admin_exp float Y 减:管理费用 +fin_exp float Y 减:财务费用 +assets_impair_loss float Y 减:资产减值损失 +prem_refund float Y 退保金 +compens_payout float Y 赔付总支出 +reser_insur_liab float Y 提取保险责任准备金 +div_payt float Y 保户红利支出 +reins_exp float Y 分保费用 +oper_exp float Y 营业支出 +compens_payout_refu float Y 减:摊回赔付支出 +insur_reser_refu float Y 减:摊回保险责任准备金 +reins_cost_refund float Y 减:摊回分保费用 +other_bus_cost float Y 其他业务成本 +operate_profit float Y 营业利润 +non_oper_income float Y 加:营业外收入 +non_oper_exp float Y 减:营业外支出 +nca_disploss float Y 其中:减:非流动资产处置净损失 +total_profit float Y 利润总额 +income_tax float Y 所得税费用 +n_income float Y 净利润(含少数股东损益) +n_income_attr_p float Y 净利润(不含少数股东损益) +minority_gain float Y 少数股东损益 +oth_compr_income float Y 其他综合收益 +t_compr_income float Y 综合收益总额 +compr_inc_attr_p float Y 归属于母公司(或股东)的综合收益总额 +compr_inc_attr_m_s float Y 归属于少数股东的综合收益总额 +ebit float Y 息税前利润 +ebitda float Y 息税折旧摊销前利润 +insurance_exp float Y 保险业务支出 +undist_profit float Y 年初未分配利润 +distable_profit float Y 可分配利润 +rd_exp float Y 研发费用 +fin_exp_int_exp float Y 财务费用:利息费用 +fin_exp_int_inc float Y 财务费用:利息收入 +transfer_surplus_rese float Y 盈余公积转入 +transfer_housing_imprest float Y 住房周转金转入 +transfer_oth float Y 其他转入 +adj_lossgain float Y 调整以前年度损益 +withdra_legal_surplus float Y 提取法定盈余公积 +withdra_legal_pubfund float Y 提取法定公益金 +withdra_biz_devfund float Y 提取企业发展基金 +withdra_rese_fund float Y 提取储备基金 +withdra_oth_ersu float Y 提取任意盈余公积金 +workers_welfare float Y 职工奖金福利 +distr_profit_shrhder float Y 可供股东分配的利润 +prfshare_payable_dvd float Y 应付优先股股利 +comshare_payable_dvd float Y 应付普通股股利 +capit_comstock_div float Y 转作股本的普通股股利 +net_after_nr_lp_correct float N 扣除非经常性损益后的净利润(更正前) +credit_impa_loss float N 信用减值损失 +net_expo_hedging_benefits float N 净敞口套期收益 +oth_impair_loss_assets float N 其他资产减值损失 +total_opcost float N 营业总成本(二) +amodcost_fin_assets float N 以摊余成本计量的金融资产终止确认收益 +oth_income float N 其他收益 +asset_disp_income float N 资产处置收益 +continued_net_profit float N 持续经营净利润 +end_net_profit float N 终止经营净利润 +update_flag str Y 更新标识 +接口使用说明 + +pro = ts.pro_api() + +df = pro.income(ts_code='600000.SH', start_date='20180101', end_date='20180730', fields='ts_code,ann_date,f_ann_date,end_date,report_type,comp_type,basic_eps,diluted_eps') +获取某一季度全部股票数据 + +df = pro.income_vip(period='20181231',fields='ts_code,ann_date,f_ann_date,end_date,report_type,comp_type,basic_eps,diluted_eps') +数据样例 + + ts_code ann_date f_ann_date end_date report_type comp_type basic_eps diluted_eps \ +0 600000.SH 20180428 20180428 20180331 1 2 0.46 0.46 +1 600000.SH 20180428 20180428 20180331 1 2 0.46 0.46 +2 600000.SH 20180428 20180428 20171231 1 2 1.84 1.84 +主要报表类型说明 + +代码 | 类型 | 说明 +---- | ----- | ---- | +1 | 合并报表 | 上市公司最新报表(默认) +2 | 单季合并 | 单一季度的合并报表 +3 | 调整单季合并表 | 调整后的单季合并报表(如果有) +4 | 调整合并报表 | 本年度公布上年同期的财务报表数据,报告期为上年度 +5 | 调整前合并报表 | 数据发生变更,将原数据进行保留,即调整前的原数据 +6 | 母公司报表 | 该公司母公司的财务报表数据 +7 | 母公司单季表 | 母公司的单季度表 +8 | 母公司调整单季表 | 母公司调整后的单季表 +9 | 母公司调整表 | 该公司母公司的本年度公布上年同期的财务报表数据 +10 | 母公司调整前报表 | 母公司调整之前的原始财务报表数据 +11 | 母公司调整前合并报表 | 母公司调整之前合并报表原数据 +12 | 母公司调整前报表 | 母公司报表发生变更前保留的原数据 \ No newline at end of file diff --git a/src/data/storage.py b/src/data/storage.py index c526e76..9e285da 100644 --- a/src/data/storage.py +++ b/src/data/storage.py @@ -90,6 +90,113 @@ class Storage: CREATE INDEX IF NOT EXISTS idx_daily_date_code ON daily(trade_date, ts_code) """) + # Create financial_income table for income statement data + # 完整的利润表字段(94列全部) + self._connection.execute(""" + CREATE TABLE IF NOT EXISTS financial_income ( + ts_code VARCHAR(16) NOT NULL, + ann_date DATE, + f_ann_date DATE, + end_date DATE NOT NULL, + report_type INTEGER, + comp_type INTEGER, + end_type VARCHAR(10), + basic_eps DOUBLE, + diluted_eps DOUBLE, + total_revenue DOUBLE, + revenue DOUBLE, + int_income DOUBLE, + prem_earned DOUBLE, + comm_income DOUBLE, + n_commis_income DOUBLE, + n_oth_income DOUBLE, + n_oth_b_income DOUBLE, + prem_income DOUBLE, + out_prem DOUBLE, + une_prem_reser DOUBLE, + reins_income DOUBLE, + n_sec_tb_income DOUBLE, + n_sec_uw_income DOUBLE, + n_asset_mg_income DOUBLE, + oth_b_income DOUBLE, + fv_value_chg_gain DOUBLE, + invest_income DOUBLE, + ass_invest_income DOUBLE, + forex_gain DOUBLE, + total_cogs DOUBLE, + oper_cost DOUBLE, + int_exp DOUBLE, + comm_exp DOUBLE, + biz_tax_surchg DOUBLE, + sell_exp DOUBLE, + admin_exp DOUBLE, + fin_exp DOUBLE, + assets_impair_loss DOUBLE, + prem_refund DOUBLE, + compens_payout DOUBLE, + reser_insur_liab DOUBLE, + div_payt DOUBLE, + reins_exp DOUBLE, + oper_exp DOUBLE, + compens_payout_refu DOUBLE, + insur_reser_refu DOUBLE, + reins_cost_refund DOUBLE, + other_bus_cost DOUBLE, + operate_profit DOUBLE, + non_oper_income DOUBLE, + non_oper_exp DOUBLE, + nca_disploss DOUBLE, + total_profit DOUBLE, + income_tax DOUBLE, + n_income DOUBLE, + n_income_attr_p DOUBLE, + minority_gain DOUBLE, + oth_compr_income DOUBLE, + t_compr_income DOUBLE, + compr_inc_attr_p DOUBLE, + compr_inc_attr_m_s DOUBLE, + ebit DOUBLE, + ebitda DOUBLE, + insurance_exp DOUBLE, + undist_profit DOUBLE, + distable_profit DOUBLE, + rd_exp DOUBLE, + fin_exp_int_exp DOUBLE, + fin_exp_int_inc DOUBLE, + transfer_surplus_rese DOUBLE, + transfer_housing_imprest DOUBLE, + transfer_oth DOUBLE, + adj_lossgain DOUBLE, + withdra_legal_surplus DOUBLE, + withdra_legal_pubfund DOUBLE, + withdra_biz_devfund DOUBLE, + withdra_rese_fund DOUBLE, + withdra_oth_ersu DOUBLE, + workers_welfare DOUBLE, + distr_profit_shrhder DOUBLE, + prfshare_payable_dvd DOUBLE, + comshare_payable_dvd DOUBLE, + capit_comstock_div DOUBLE, + net_after_nr_lp_correct DOUBLE, + credit_impa_loss DOUBLE, + net_expo_hedging_benefits DOUBLE, + oth_impair_loss_assets DOUBLE, + total_opcost DOUBLE, + amodcost_fin_assets DOUBLE, + oth_income DOUBLE, + asset_disp_income DOUBLE, + continued_net_profit DOUBLE, + end_net_profit DOUBLE, + update_flag VARCHAR(1), + PRIMARY KEY (ts_code, end_date) + ) + """) + + # Create index for financial_income + self._connection.execute(""" + CREATE INDEX IF NOT EXISTS idx_financial_ann ON financial_income(ts_code, ann_date) + """) + def save(self, name: str, data: pd.DataFrame, mode: str = "append") -> dict: """Save data to DuckDB. @@ -104,13 +211,35 @@ class Storage: if data.empty: return {"status": "skipped", "rows": 0} - # Ensure date column is proper type + # 确保日期列是正确的类型 (YYYYMMDD -> date) + # trade_date: 日线数据日期 if "trade_date" in data.columns: data = data.copy() data["trade_date"] = pd.to_datetime( data["trade_date"], format="%Y%m%d" ).dt.date + # ann_date: 公告日期 + if "ann_date" in data.columns: + data = data.copy() + data["ann_date"] = pd.to_datetime( + data["ann_date"], format="%Y%m%d", errors="coerce" + ).dt.date + + # f_ann_date: 最终公告日期 + if "f_ann_date" in data.columns: + data = data.copy() + data["f_ann_date"] = pd.to_datetime( + data["f_ann_date"], format="%Y%m%d", errors="coerce" + ).dt.date + + # end_date: 报告期/期末日期 + if "end_date" in data.columns: + data = data.copy() + data["end_date"] = pd.to_datetime( + data["end_date"], format="%Y%m%d", errors="coerce" + ).dt.date + # Register DataFrame as temporary view self._connection.register("temp_data", data) diff --git a/src/data/utils.py b/src/data/utils.py index 350fc76..b63bf2c 100644 --- a/src/data/utils.py +++ b/src/data/utils.py @@ -4,7 +4,7 @@ """ from datetime import datetime, timedelta -from typing import Optional +from typing import Optional, List # 默认全量同步开始日期 @@ -73,3 +73,74 @@ def format_date(dt: datetime) -> str: YYYYMMDD 格式的日期字符串 """ return dt.strftime("%Y%m%d") + + +def is_quarter_end(date_str: str) -> bool: + """判断是否为季度最后一天。 + + Args: + date_str: YYYYMMDD 格式的日期 + + Returns: + 是否为季度最后一天 + """ + month_day = date_str[4:] + return month_day in ("0331", "0630", "0930", "1231") + + +def date_to_quarter(date_str: str) -> str: + """将日期转换为对应季度的最后一天。 + + Args: + date_str: YYYYMMDD 格式的日期 + + Returns: + 季度最后一天,格式为 YYYYMMDD + 例如: 20230115 -> 20230331 + """ + year = date_str[:4] + month = int(date_str[4:6]) + + if month <= 3: + return year + "0331" + elif month <= 6: + return year + "0630" + elif month <= 9: + return year + "0930" + else: + return year + "1231" + + +def get_quarters_in_range(start_date: str, end_date: str) -> List[str]: + """获取日期范围内的所有季度列表。 + + Args: + start_date: 开始日期 YYYYMMDD + end_date: 结束日期 YYYYMMDD + + Returns: + 季度列表,格式为 YYYYMMDD,按时间倒序排列 + 例如: ['20231231', '20230930', '20230630', '20230331'] + """ + quarters = [] + + # 将开始日期和结束日期都转换为季度 + start_quarter = date_to_quarter(start_date) + end_quarter = date_to_quarter(end_date) + + # 解析年份 + start_year = int(start_date[:4]) + end_year = int(end_date[:4]) + + # 遍历所有年份和季度 + for year in range(end_year, start_year - 1, -1): + year_str = str(year) + # 季度顺序: Q4, Q3, Q2, Q1 (倒序) + for quarter in ["1231", "0930", "0630", "0331"]: + quarter_date = year_str + quarter + + # 只包含在范围内的季度 + if quarter_date >= start_quarter and quarter_date <= end_quarter: + quarters.append(quarter_date) + + return quarters diff --git a/src/factors/FACTOR_GUIDE.md b/src/factors/FACTOR_GUIDE.md new file mode 100644 index 0000000..851c98f --- /dev/null +++ b/src/factors/FACTOR_GUIDE.md @@ -0,0 +1,1535 @@ +# ProStock 因子开发规范 + +本文档是 ProStock 因子框架的完整开发指南,涵盖从因子设计到测试的全流程规范。 + +## 目录 + +- [1. 因子框架概述](#1-因子框架概述) +- [2. 因子分类体系](#2-因子分类体系) +- [3. 因子类型选择](#3-因子类型选择) +- [4. 项目结构](#4-项目结构) +- [5. 编写步骤](#5-编写步骤) +- [6. 编码规范](#6-编码规范) +- [7. 命名规范](#7-命名规范) +- [8. 数据需求规范](#8-数据需求规范) +- [9. 防泄露机制](#9-防泄露机制) +- [10. 参数化因子](#10-参数化因子) +- [11. 因子组合](#11-因子组合) +- [12. 性能优化](#12-性能优化) +- [13. 测试规范](#13-测试规范) +- [14. 完整示例](#14-完整示例) +- [15. 常见错误](#15-常见错误) +- [16. 验证清单](#16-验证清单) +- [附录](#附录) + +--- + +## 1. 因子框架概述 + +ProStock 因子框架采用**类型安全**设计,严格区分截面因子和时序因子,在框架层面防止数据泄露。 + +### 1.1 核心组件 + +``` +src/factors/ +├── base.py # 因子基类(CrossSectionalFactor / TimeSeriesFactor) +├── data_spec.py # 数据规格定义(DataSpec, FactorData, FactorContext) +├── composite.py # 组合因子(支持因子间的加减乘除) +├── engine.py # 执行引擎(FactorEngine) +├── data_loader.py # 数据加载器(DataLoader) +├── momentum/ # 动量因子 +├── financial/ # 财务因子 +├── valuation/ # 估值因子 +├── technical/ # 技术指标因子 +├── quality/ # 质量因子 +├── sentiment/ # 情绪因子 +├── volume/ # 成交量因子 +└── volatility/ # 波动率因子 +``` + +### 1.2 关键概念 + +| 概念 | 说明 | +|------|------| +| `DataSpec` | 声明因子所需的数据源、列和回看窗口 | +| `FactorData` | 数据容器,封装 Polars DataFrame | +| `FactorContext` | 计算上下文,提供当前日期/股票信息 | +| `FactorEngine` | 执行引擎,根据因子类型采用不同计算策略 | + +### 1.3 设计原则 + +1. **类型安全**:编译时检查因子类型,防止运行时错误 +2. **防泄露**:框架层面防止未来数据和跨股票数据泄露 +3. **可组合**:支持因子间的数学运算(加减乘除) +4. **高性能**:基于 Polars 的向量化计算 +5. **可扩展**:插件化架构,易于添加新因子 + +--- + +## 2. 因子分类体系 + +### 2.1 按经济含义分类 + +| 分类 | 目录 | 说明 | 示例因子 | +|------|------|------|----------| +| **动量因子** | `momentum/` | 价格趋势和动量指标 | MA、收益率排名、动量 | +| **财务因子** | `financial/` | 财务报表相关指标 | EPS、营收增长、资产负债率 | +| **估值因子** | `valuation/` | 估值水平指标 | PE、PB、PS排名 | +| **技术指标** | `technical/` | 技术分析指标 | RSI、MACD、布林带、KDJ | +| **质量因子** | `quality/` | 公司质量指标 | ROE、ROA、盈利稳定性 | +| **情绪因子** | `sentiment/` | 市场情绪指标 | 换手率、资金流向、振幅 | +| **成交量因子** | `volume/` | 成交量相关指标 | OBV、成交量比率、量价配合 | +| **波动率因子** | `volatility/` | 波动率指标 | 历史波动率、GARCH、实现波动率 | + +### 2.2 按计算方式分类 + +| 类型 | 基类 | 计算维度 | 防泄露重点 | +|------|------|----------|------------| +| **截面因子** | `CrossSectionalFactor` | 横向:每天对所有股票计算 | 防止日期泄露 | +| **时序因子** | `TimeSeriesFactor` | 纵向:对每只股票单独计算 | 防止股票泄露 | + +### 2.3 分类选择指南 + +**如何选择因子分类?** + +1. **先确定计算方式**:你的因子是每天对所有股票计算(截面),还是对每只股票单独计算(时序)? +2. **再确定经济含义**:根据因子的经济逻辑选择对应的分类目录 +3. **特殊情况**:如果因子涉及多个维度(如行业内的时序计算),选择主要的经济含义分类 + +--- + +## 3. 因子类型选择 + +### 3.1 决策流程图 + +``` +开始编写因子 + │ + ▼ +因子计算是否涉及多只股票比较? + │ + ├── 是 → 截面因子 (CrossSectionalFactor) + │ └── 每天传入当天所有股票数据 + │ └── 示例:PE排名、市值分位数 + │ + └── 否 → 时序因子 (TimeSeriesFactor) + └── 每只股票单独传入其时间序列 + └── 示例:MA、RSI、历史波动率 +``` + +### 3.2 详细对比 + +#### 截面因子 (CrossSectionalFactor) + +**计算逻辑**:在每个交易日,对所有股票进行横向计算。 + +**防泄露边界**: +- ❌ 禁止访问未来日期的数据(日期泄露) +- ✅ 允许访问当前日期的所有股票数据 + +**数据传入**: +- `compute()` 接收的是 `[T-lookback+1, T]` 的数据 +- 包含 lookback_days 的历史数据(用于时序计算后再截面) + +**典型应用**: +- PE排名、市值分位数 +- 当日收益率排名 +- 行业分类排名 + +**使用场景判断**: +- 需要比较不同股票之间的值 +- 需要对股票进行排序或分组 +- 计算涉及多个股票的统计量(均值、标准差等) + +#### 时序因子 (TimeSeriesFactor) + +**计算逻辑**:对每只股票,在其时间序列上进行纵向计算。 + +**防泄露边界**: +- ❌ 禁止访问其他股票的数据(股票泄露) +- ✅ 允许访问该股票的完整历史数据 + +**数据传入**: +- `compute()` 接收的是单只股票的完整时间序列 +- 包含该股票在 `[start_date, end_date]` 范围内的所有数据 + +**典型应用**: +- 移动平均线 (MA) +- 相对强弱指标 (RSI) +- 历史波动率 + +**使用场景判断**: +- 只关注单只股票的历史表现 +- 计算涉及时间序列的滚动窗口 +- 不需要与其他股票比较 + +### 3.3 混合场景处理 + +**场景**:需要先计算时序指标,再进行截面排名(如20日收益率排名) + +**解决方案**: +1. 使用 **截面因子** 作为外层类型 +2. 在 `compute()` 中使用时序计算(如 `rolling_mean`) +3. 设置合适的 `lookback_days` 以确保有足够的历史数据 + +**示例**: +```python +class ReturnRankFactor(CrossSectionalFactor): + """过去n日收益率排名因子""" + data_specs = [DataSpec("daily", ["close"], lookback_days=period+1)] + + def compute(self, data: FactorData) -> pl.Series: + # 获取历史数据(包含时序信息) + df = data.to_polars() + # 计算每只股票的收益率(时序计算) + # 然后进行截面排名(截面计算) +``` + +--- + +## 4. 项目结构 + +### 4.1 完整目录结构 + +``` +src/factors/ +│ +├── __init__.py # 包入口,导出公共接口 +├── FACTOR_GUIDE.md # 本规范文档 +│ +├── base.py # 因子基类定义 +├── data_spec.py # 数据类型定义 +├── composite.py # 组合因子实现 +├── data_loader.py # 数据加载器 +├── engine.py # 执行引擎 +│ +├── momentum/ # 动量因子 +│ ├── __init__.py +│ ├── ma.py # 移动平均线 +│ └── return_rank.py # 收益率排名 +│ +├── financial/ # 财务因子 +│ ├── __init__.py +│ ├── eps_factor.py # EPS因子 +│ └── utils.py # 财务工具函数 +│ +├── valuation/ # 估值因子 +│ ├── __init__.py +│ └── [你的估值因子] +│ +├── technical/ # 技术指标因子 +│ ├── __init__.py +│ └── [你的技术指标因子] +│ +├── quality/ # 质量因子 +│ ├── __init__.py +│ └── [你的质量因子] +│ +├── sentiment/ # 情绪因子 +│ ├── __init__.py +│ └── [你的情绪因子] +│ +├── volume/ # 成交量因子 +│ ├── __init__.py +│ └── [你的成交量因子] +│ +└── volatility/ # 波动率因子 + ├── __init__.py + └── [你的波动率因子] +``` + +### 4.2 文件组织原则 + +1. **单一职责**:每个文件只包含一个因子类(或紧密相关的多个因子) +2. **分类清晰**:根据因子的经济含义放入对应目录 +3. **命名一致**:文件名使用 `snake_case`,反映因子功能 + +### 4.3 新增因子流程 + +``` +1. 确定因子类型(截面/时序) +2. 选择因子分类(momentum/financial/...) +3. 在对应目录创建文件 +4. 实现因子类 +5. 更新该目录的 __init__.py +6. 编写测试 +7. 更新主 __init__.py(如需要公开导出) +``` + +--- + +## 5. 编写步骤 + +### 步骤 1:确定因子类型和分类 + +**问题清单**: +- [ ] 因子是截面计算还是时序计算? +- [ ] 因子的经济含义属于哪个分类? +- [ ] 需要哪些数据字段? +- [ ] 是否需要参数化(如 MA 的周期)? + +### 步骤 2:创建文件 + +根据因子分类创建文件: + +```bash +# 示例:创建估值因子 +touch src/factors/valuation/pe_rank.py +``` + +### 步骤 3:定义类属性(必须) + +每个因子必须声明以下类属性: + +```python +class MyFactor(CrossSectionalFactor): # 或 TimeSeriesFactor + # 必须声明 + name: str = "my_factor" # 因子唯一标识(snake_case) + factor_type: str = "cross_sectional" # 或 "time_series" + data_specs: List[DataSpec] = [...] # 数据需求列表 + + # 可选声明 + category: str = "default" # 因子分类 + description: str = "" # 因子描述 +``` + +### 步骤 4:实现 compute 方法 + +```python +def compute(self, data: FactorData) -> pl.Series: + """核心计算逻辑 + + Args: + data: FactorData,已根据因子类型裁剪 + + Returns: + 计算得到的因子值 Series + """ + # 你的计算逻辑 + pass +``` + +### 步骤 5:更新 __init__.py + +在对应分类的 `__init__.py` 中导出因子: + +```python +# src/factors/momentum/__init__.py +from .ma import MovingAverageFactor +from .return_rank import ReturnRankFactor +from .your_factor import MyFactor # 添加你的因子 + +__all__ = [ + "MovingAverageFactor", + "ReturnRankFactor", + "MyFactor", # 添加你的因子 +] +``` + +### 步骤 6:编写测试 + +创建对应的测试文件: + +```bash +touch tests/factors/test_momentum.py +``` + +### 步骤 7:验证和文档 + +- [ ] 运行测试确保通过 +- [ ] 检查代码风格 +- [ ] 更新相关文档 + +--- + +## 6. 编码规范 + +### 6.1 文件头格式 + +```python +"""因子名称 - 一句话描述 + +本模块提供: +- FactorClassName: 详细描述 + +使用示例: + >>> from src.factors.category import FactorClassName + >>> factor = FactorClassName(param=value) +""" +``` + +### 6.2 导入顺序 + +```python +# 1. 标准库 +from typing import List, Optional + +# 2. 第三方库 +import polars as pl + +# 3. 本地模块(使用绝对导入) +from src.factors.base import CrossSectionalFactor, TimeSeriesFactor +from src.factors.data_spec import DataSpec, FactorData +``` + +### 6.3 类文档字符串(Google 风格) + +```python +class MovingAverageFactor(TimeSeriesFactor): + """移动平均线因子 + + 计算逻辑:对每只股票,计算其过去n日收盘价的移动平均值。 + + 特点: + - 参数化因子:训练时通过 period 参数指定计算窗口 + - 时序因子:每只股票单独计算,防止股票间数据泄露 + + Attributes: + period: MA计算期(天数),默认5 + + Example: + >>> ma5 = MovingAverageFactor(period=5) + >>> # 计算过去5日的收盘价均值 + """ +``` + +### 6.4 方法文档字符串 + +```python +def compute(self, data: FactorData) -> pl.Series: + """计算移动平均线 + + Args: + data: FactorData,包含单只股票的完整时间序列 + + Returns: + 移动平均值序列 + + Raises: + KeyError: 当所需列不存在于数据中 + """ +``` + +### 6.5 类型提示(强制) + +- 所有函数参数必须标注类型 +- 所有函数返回值必须标注类型 +- 使用 `Optional[X]` 表示可空类型 +- 使用 `List[X]` 表示列表类型 + +```python +def __init__(self, period: int = 5) -> None: + self.period: int = period + +def compute(self, data: FactorData) -> pl.Series: + pass +``` + +### 6.6 代码风格 + +- 使用 4 空格缩进 +- 行长度限制 88 字符(Black 默认) +- 使用 snake_case 命名变量和函数 +- 使用 PascalCase 命名类 + +--- + +## 7. 命名规范 + +### 7.1 因子名称 (name) + +- 使用 **snake_case** 格式 +- 简洁明了,反映因子含义 +- 参数化因子在 `__init__` 中动态设置名称 + +```python +# 好的命名 +name = "pe_rank" # PE排名 +name = "ma_20" # 20日移动平均 +name = "return_5d" # 5日收益率 +name = "volatility_20" # 20日波动率 + +# 不好的命名 +name = "PERank" # 错误:使用 PascalCase +name = "moving_average" # 不够具体,未体现参数 +name = "factor1" # 无意义 +``` + +### 7.2 因子分类 (category) + +使用统一的分类标签: + +| 分类 | 说明 | 示例 | +|------|------|------| +| `momentum` | 动量因子 | 收益率、MA、RSI | +| `financial` | 财务因子 | EPS、ROE、营收增长 | +| `valuation` | 估值因子 | PE、PB、PS | +| `technical` | 技术指标 | MACD、KDJ、布林带 | +| `quality` | 质量因子 | 盈利能力、稳定性 | +| `sentiment` | 情绪因子 | 换手率、资金流向 | +| `volume` | 成交量因子 | OBV、成交量比率 | +| `volatility` | 波动率因子 | 历史波动率、GARCH | + +### 7.3 文件名 + +- 使用 **snake_case** 格式 +- 反映因子功能 + +``` +ma.py # 移动平均线 +return_rank.py # 收益率排名 +eps_factor.py # EPS因子 +pe_ratio.py # 市盈率因子 +``` + +### 7.4 类名 + +- 使用 **PascalCase** 格式 +- 以 `Factor` 结尾 + +```python +class MovingAverageFactor(TimeSeriesFactor): +class ReturnRankFactor(CrossSectionalFactor): +class EPSFactor(CrossSectionalFactor): +``` + +### 7.5 命名最佳实践 + +#### 前缀/后缀约定 + +| 前缀/后缀 | 含义 | 示例 | +|-----------|------|------| +| `_rank` | 排名因子 | `pe_rank`, `pb_rank` | +| `_zscore` | Z-Score标准化 | `roe_zscore` | +| `_ma` | 移动平均 | `volume_ma`, `price_ma` | +| `_std` | 标准差 | `return_std` | +| `_skew` | 偏度 | `return_skew` | +| `_kurt` | 峰度 | `return_kurt` | + +#### 时间周期表示 + +```python +# 推荐 +name = "return_5d" # 5日收益率 +name = "return_1m" # 1月收益率 +name = "return_1y" # 1年收益率 +name = "ma_20" # 20日移动平均 + +# 不推荐 +name = "return5" # 缺少单位 +name = "ma20" # 缺少分隔符 +``` + +--- + +## 8. 数据需求规范 + +### 8.1 DataSpec 定义 + +```python +from src.factors.data_spec import DataSpec + +data_specs: List[DataSpec] = [ + DataSpec( + source="daily", # 数据源名称(DuckDB 表名) + columns=["ts_code", "trade_date", "close", "volume"], # 必需列 + lookback_days=20 # 回看窗口(包含当日) + ) +] +``` + +### 8.2 必需列 + +`columns` 必须包含以下列: +- `ts_code`: 股票代码 +- `trade_date`: 交易日期 + +### 8.3 lookback_days 设置 + +- **最小值**:1(只包含当日) +- **时序因子**:设置为计算所需的完整窗口 + - MA(20) → `lookback_days=20` + - 5日收益率 → `lookback_days=6`(需要T日和T-5日) +- **截面因子**:通常为 1,除非需要历史数据进行时序计算 + +### 8.4 参数化因子的 DataSpec + +对于参数化因子,在 `__init__` 中动态创建 DataSpec: + +```python +def __init__(self, period: int = 5): + super().__init__(period=period) + # 重新创建 DataSpec 以设置正确的 lookback_days + self.data_specs = [ + DataSpec( + "daily", + ["ts_code", "trade_date", "close"], + lookback_days=period, # 使用参数 + ) + ] + self.name = f"ma_{period}" # 动态设置名称 +``` + +### 8.5 数据表和字段参考 + +#### 日线数据表 (daily) + +| 字段名 | 类型 | 说明 | +|--------|------|------| +| `ts_code` | str | 股票代码 | +| `trade_date` | str | 交易日期 (YYYYMMDD) | +| `open` | f64 | 开盘价 | +| `high` | f64 | 最高价 | +| `low` | f64 | 最低价 | +| `close` | f64 | 收盘价 | +| `pre_close` | f64 | 昨收价 | +| `change` | f64 | 涨跌额 | +| `pct_chg` | f64 | 涨跌幅 (%) | +| `vol` | f64 | 成交量(手) | +| `amount` | f64 | 成交额(千元) | +| `pe` | f64 | 市盈率 | +| `pb` | f64 | 市净率 | +| `ps` | f64 | 市销率 | +| `total_mv` | f64 | 总市值(万元) | +| `circ_mv` | f64 | 流通市值(万元) | + +#### 财务数据表 (financial_income) + +| 字段名 | 类型 | 说明 | +|--------|------|------| +| `ts_code` | str | 股票代码 | +| `trade_date` | str | 报告期 (YYYYMMDD) | +| `basic_eps` | f64 | 基本每股收益 | +| `diluted_eps` | f64 | 稀释每股收益 | +| `total_revenue` | f64 | 营业总收入 | +| `revenue` | f64 | 营业收入 | +| `total_profit` | f64 | 营业利润 | +| `net_income` | f64 | 净利润 | + +#### 使用示例 + +```python +# 日线数据 +data_specs = [DataSpec("daily", ["ts_code", "trade_date", "close", "volume"], lookback_days=20)] + +# 财务数据 +data_specs = [DataSpec("financial_income", ["ts_code", "trade_date", "basic_eps"], lookback_days=1)] + +# 多数据源 +data_specs = [ + DataSpec("daily", ["ts_code", "trade_date", "close"], lookback_days=20), + DataSpec("financial_income", ["ts_code", "trade_date", "basic_eps"], lookback_days=1) +] +``` + +--- + +## 9. 防泄露机制 + +### 9.1 截面因子注意事项 + +**禁止**: +- 访问 `data.context.current_date` 之后的数据 +- 使用 `data.to_polars()` 获取未来日期的数据 + +**正确做法**: +```python +def compute(self, data: FactorData) -> pl.Series: + # 获取当前日期的截面数据(框架已裁剪到正确范围) + cs = data.get_cross_section() + + # 只使用当前日期的数据 + return cs["pe"].rank() +``` + +### 9.2 时序因子注意事项 + +**禁止**: +- 访问其他股票的数据 +- 使用包含多只股票的数据进行计算 + +**正确做法**: +```python +def compute(self, data: FactorData) -> pl.Series: + # 获取该股票的时间序列(框架已确保只有一只股票) + close = data.get_column("close") + + # 只使用该股票的历史数据 + return close.rolling_mean(window_size=self.params["period"]) +``` + +### 9.3 数据访问方法 + +| 方法 | 适用场景 | 说明 | +|------|----------|------| +| `data.get_column(col)` | 通用 | 获取指定列的 Series | +| `data.get_cross_section()` | 截面因子 | 获取当前日期的所有股票数据 | +| `data.filter_by_date(date)` | 截面因子 | 按日期过滤数据 | +| `data.to_polars()` | 高级用法 | 获取底层 DataFrame(谨慎使用) | +| `data.context` | 通用 | 获取计算上下文 | + +--- + +## 10. 参数化因子 + +### 10.1 定义方式 + +通过 `__init__` 接收参数: + +```python +class MovingAverageFactor(TimeSeriesFactor): + name: str = "ma" + factor_type: str = "time_series" + category: str = "momentum" + description: str = "移动平均线因子" + data_specs: List[DataSpec] = [] # 占位,在 __init__ 中设置 + + def __init__(self, period: int = 5): + super().__init__(period=period) + # 动态设置 data_specs 和 name + self.data_specs = [ + DataSpec( + "daily", + ["ts_code", "trade_date", "close"], + lookback_days=period, + ) + ] + self.name = f"ma_{period}" +``` + +### 10.2 参数验证 + +可选覆盖 `_validate_params` 方法进行参数验证: + +```python +def _validate_params(self): + """验证参数有效性""" + period = self.params.get("period", 5) + if period < 1: + raise ValueError(f"period must be >= 1, got {period}") + if period > 252: + raise ValueError(f"period must be <= 252, got {period}") +``` + +### 10.3 参数使用 + +在 `compute` 中通过 `self.params` 访问参数: + +```python +def compute(self, data: FactorData) -> pl.Series: + close = data.get_column("close") + period = self.params["period"] + return close.rolling_mean(window_size=period) +``` + +--- + +## 11. 因子组合 + +### 11.1 支持的运算符 + +因子框架支持因子间的数学运算: + +```python +# 因子相加 +combined = factor1 + factor2 + +# 因子相减 +diff = factor1 - factor2 + +# 因子相乘 +product = factor1 * factor2 + +# 因子相除 +ratio = factor1 / factor2 + +# 标量乘法 +scaled = 0.5 * factor1 + +# 复杂组合 +final = 0.3 * factor1 + 0.5 * factor2 - 0.2 * factor3 +``` + +### 11.2 组合约束 + +- **类型一致性**:只能组合同类型因子(截面+截面,时序+时序) +- **自动合并 data_specs**:组合因子会自动合并左右因子的数据需求 + +### 11.3 组合因子示例 + +```python +from src.factors.momentum import MovingAverageFactor, ReturnRankFactor + +# 创建基础因子 +ma20 = MovingAverageFactor(period=20) +ret5 = ReturnRankFactor(period=5) + +# 组合(注意:这里只是示例,实际不能组合不同类型因子) +# 同类型因子可以组合: +# combined = ma20 + MovingAverageFactor(period=10) +``` + +--- + +## 12. 性能优化 + +### 12.1 向量化计算 + +**推荐**:使用 Polars 的向量化操作 + +```python +def compute(self, data: FactorData) -> pl.Series: + close = data.get_column("close") + # 向量化计算 - 高效 + return close.rolling_mean(window_size=20) +``` + +**避免**:使用 Python 循环 + +```python +def compute(self, data: FactorData) -> pl.Series: + close = data.get_column("close") + # 避免:Python 循环 - 低效 + result = [] + for i in range(len(close)): + if i < 20: + result.append(None) + else: + result.append(sum(close[i-20:i]) / 20) + return pl.Series(result) +``` + +### 12.2 数据加载优化 + +1. **只加载需要的列**:在 `DataSpec` 中明确指定需要的列 +2. **合理设置 lookback_days**:不要设置过大的回看窗口 +3. **使用合适的数据类型**:Polars 会自动优化,但避免不必要的类型转换 + +### 12.3 内存优化 + +1. **避免复制大数据**:使用 Polars 的零拷贝操作 +2. **及时释放中间结果**:不需要的变量及时清理 +3. **使用惰性求值**(如适用):`pl.lazy()` + +### 12.4 计算优化技巧 + +```python +# 1. 预计算常用值 +def compute(self, data: FactorData) -> pl.Series: + close = data.get_column("close") + # 预计算 + log_close = close.log() + return log_close.rolling_mean(window_size=20) + +# 2. 使用窗口函数避免重复计算 +def compute(self, data: FactorData) -> pl.Series: + close = data.get_column("close") + # 一次计算多个统计量 + return close.rolling_mean(window_size=20), close.rolling_std(window_size=20) + +# 3. 避免重复的数据转换 +def compute(self, data: FactorData) -> pl.Series: + # 获取 DataFrame 一次 + df = data.to_polars() + # 多次使用 + mean_val = df["close"].mean() + std_val = df["close"].std() +``` + +--- + +## 13. 测试规范 + +### 13.1 测试文件位置 + +``` +tests/ +├── factors/ +│ ├── __init__.py +│ ├── test_momentum.py # 动量因子测试 +│ ├── test_financial.py # 财务因子测试 +│ ├── test_valuation.py # 估值因子测试 +│ ├── test_technical.py # 技术指标测试 +│ └── test_your_factor.py # 你的因子测试 +``` + +### 13.2 测试模板 + +```python +"""XXX因子测试 + +测试覆盖: +- 初始化测试 +- 计算逻辑测试 +- 边界条件测试 +- 参数验证测试 +""" + +import pytest +import polars as pl +from src.factors.momentum import MovingAverageFactor +from src.factors.data_spec import FactorData, FactorContext + + +class TestMovingAverageFactor: + """MovingAverageFactor 测试类""" + + def test_init(self): + """测试初始化""" + factor = MovingAverageFactor(period=10) + assert factor.name == "ma_10" + assert factor.params["period"] == 10 + assert factor.data_specs[0].lookback_days == 10 + + def test_compute(self): + """测试计算逻辑""" + factor = MovingAverageFactor(period=3) + + # 构造测试数据 + df = pl.DataFrame({ + "ts_code": ["000001.SZ"] * 5, + "trade_date": ["20240101", "20240102", "20240103", "20240104", "20240105"], + "close": [10.0, 11.0, 12.0, 13.0, 14.0], + }) + + context = FactorContext(current_stock="000001.SZ") + data = FactorData(df, context) + + # 计算 + result = factor.compute(data) + + # 验证 + assert len(result) == 5 + # 第3个值应该是 (10+11+12)/3 = 11.0 + assert result[2] == 11.0 + + def test_compute_empty_data(self): + """测试空数据处理""" + factor = MovingAverageFactor(period=5) + + df = pl.DataFrame({ + "ts_code": [], + "trade_date": [], + "close": [], + }) + + context = FactorContext(current_stock="000001.SZ") + data = FactorData(df, context) + + result = factor.compute(data) + assert len(result) == 0 + + def test_compute_single_row(self): + """测试单行数据处理""" + factor = MovingAverageFactor(period=5) + + df = pl.DataFrame({ + "ts_code": ["000001.SZ"], + "trade_date": ["20240101"], + "close": [10.0], + }) + + context = FactorContext(current_stock="000001.SZ") + data = FactorData(df, context) + + result = factor.compute(data) + assert len(result) == 1 + # 数据不足时返回 null + assert result[0] is None + + def test_validation(self): + """测试参数验证""" + with pytest.raises(ValueError): + MovingAverageFactor(period=0) + + with pytest.raises(ValueError): + MovingAverageFactor(period=-1) +``` + +### 13.3 测试要点 + +1. **初始化测试**:验证 name、params、data_specs 正确设置 +2. **计算逻辑测试**:使用构造数据验证计算结果 +3. **边界条件测试**:空数据、单条数据、缺失值处理 +4. **参数验证测试**:无效参数应抛出 ValueError +5. **防泄露测试**:确保因子不会访问不应访问的数据 + +### 13.4 运行测试 + +```bash +# 运行所有因子测试 +uv run pytest tests/factors/ + +# 运行特定测试文件 +uv run pytest tests/factors/test_momentum.py + +# 运行特定测试类 +uv run pytest tests/factors/test_momentum.py::TestMovingAverageFactor + +# 运行特定测试方法 +uv run pytest tests/factors/test_momentum.py::TestMovingAverageFactor::test_compute + +# 带覆盖率报告 +uv run pytest tests/factors/ --cov=src.factors --cov-report=term-missing +``` + +--- + +## 14. 完整示例 + +### 14.1 截面因子示例:PE 排名因子 + +```python +"""估值因子 - 市盈率排名 + +本模块提供市盈率(PE)排名因子: +- PERankFactor: PE截面排名因子 + +使用示例: + >>> from src.factors.valuation import PERankFactor + >>> pe_rank = PERankFactor() + >>> # 每天返回所有股票的PE排名(0-1之间) +""" + +from typing import List + +import polars as pl + +from src.factors.base import CrossSectionalFactor +from src.factors.data_spec import DataSpec, FactorData + + +class PERankFactor(CrossSectionalFactor): + """市盈率(PE)排名因子 + + 计算逻辑:每个交易日,对所有股票的PE进行截面排名。 + + 特点: + - 截面因子:每天对所有股票进行横向排名 + - 值域:0-1之间,0表示PE最小,1表示PE最大 + + Attributes: + name: 因子名称 "pe_rank" + category: 因子分类 "valuation" + + Example: + >>> pe_rank = PERankFactor() + >>> # 每个交易日,返回所有股票的PE排名 + """ + + name: str = "pe_rank" + factor_type: str = "cross_sectional" + category: str = "valuation" + description: str = "市盈率截面排名因子,值域0-1" + data_specs: List[DataSpec] = [ + DataSpec( + "daily", + ["ts_code", "trade_date", "pe"], + lookback_days=1 + ) + ] + + def compute(self, data: FactorData) -> pl.Series: + """计算PE排名 + + Args: + data: FactorData,包含当前日期的截面数据 + + Returns: + PE排名的0-1标准化值 + """ + # 获取当前日期的截面数据 + cs = data.get_cross_section() + + if len(cs) == 0: + return pl.Series(name=self.name, values=[]) + + # 获取PE值,处理负值和缺失值 + pe = cs["pe"] + # PE为负表示亏损,设为一个大值(排名靠后) + pe = pe.fill_null(9999) + pe = pl.when(pe < 0).then(9999).otherwise(pe) + + # 计算排名(升序,PE越小排名越靠前) + if len(pe) > 1 and pe.max() != pe.min(): + ranks = pe.rank(method="average") / len(pe) + else: + # 数据不足或全部相同,返回0.5 + ranks = pl.Series(name=self.name, values=[0.5] * len(pe)) + + return ranks +``` + +### 14.2 时序因子示例:RSI 因子 + +```python +"""技术指标 - 相对强弱指标(RSI) + +本模块提供RSI因子: +- RSIFactor: 相对强弱指标(时序因子) + +使用示例: + >>> from src.factors.technical import RSIFactor + >>> rsi14 = RSIFactor(period=14) + >>> # 计算14日RSI +""" + +from typing import List + +import polars as pl + +from src.factors.base import TimeSeriesFactor +from src.factors.data_spec import DataSpec, FactorData + + +class RSIFactor(TimeSeriesFactor): + """相对强弱指标(RSI)因子 + + 计算逻辑:对每只股票,计算其过去n日的RSI值。 + + RSI = 100 - (100 / (1 + RS)) + RS = 平均上涨幅度 / 平均下跌幅度 + + 特点: + - 参数化因子:通过 period 参数指定计算窗口 + - 时序因子:每只股票单独计算 + - 值域:0-100,通常>70为超买,<30为超卖 + + Attributes: + period: RSI计算期(天数),默认14 + + Example: + >>> rsi14 = RSIFactor(period=14) + >>> # 计算14日RSI + """ + + name: str = "rsi" + factor_type: str = "time_series" + category: str = "technical" + description: str = "相对强弱指标,值域0-100" + data_specs: List[DataSpec] = [] + + def __init__(self, period: int = 14): + """初始化RSI因子 + + Args: + period: RSI计算期(天数),默认14 + """ + super().__init__(period=period) + # RSI需要 period+1 天的数据来计算涨跌幅 + self.data_specs = [ + DataSpec( + "daily", + ["ts_code", "trade_date", "close"], + lookback_days=period + 1, + ) + ] + self.name = f"rsi_{period}" + + def compute(self, data: FactorData) -> pl.Series: + """计算RSI + + Args: + data: FactorData,包含单只股票的完整时间序列 + + Returns: + RSI值序列(0-100) + """ + # 获取收盘价 + close = data.get_column("close") + period = self.params["period"] + + # 计算涨跌幅 + delta = close.diff() + + # 分离上涨和下跌 + gain = pl.when(delta > 0).then(delta).otherwise(0) + loss = pl.when(delta < 0).then(-delta).otherwise(0) + + # 计算平均上涨和平均下跌(使用指数移动平均) + avg_gain = gain.ewm_mean(span=period, min_periods=period) + avg_loss = loss.ewm_mean(span=period, min_periods=period) + + # 计算RS和RSI + rs = avg_gain / avg_loss + rsi = 100 - (100 / (1 + rs)) + + return rsi +``` + +### 14.3 使用示例 + +```python +from src.factors import FactorEngine, DataLoader +from src.factors.valuation import PERankFactor +from src.factors.technical import RSIFactor + +# 创建数据加载器和执行引擎 +loader = DataLoader(data_dir="data") +engine = FactorEngine(loader) + +# 创建因子 +pe_rank = PERankFactor() +rsi14 = RSIFactor(period=14) + +# 计算截面因子 +pe_result = engine.compute( + pe_rank, + start_date="20240101", + end_date="20240131" +) + +# 计算时序因子 +rsi_result = engine.compute( + rsi14, + stock_codes=["000001.SZ", "000002.SZ"], + start_date="20240101", + end_date="20240131" +) + +# 因子组合(同类型才能组合) +# combined = 0.5 * pe_rank + 0.5 * other_cs_factor +``` + +--- + +## 15. 常见错误 + +### 15.1 类型错误 + +#### 错误 1:混淆截面因子和时序因子 + +```python +# 错误:时序计算却继承截面因子 +class MovingAverageFactor(CrossSectionalFactor): + def compute(self, data): + # 这里只会传入一天的数据,无法计算MA + return data.get_column("close").rolling_mean(window_size=20) + +# 正确:继承时序因子 +class MovingAverageFactor(TimeSeriesFactor): + def compute(self, data): + # 这里传入完整时间序列,可以计算MA + return data.get_column("close").rolling_mean(window_size=20) +``` + +#### 错误 2:在截面因子中访问其他日期的数据 + +```python +# 错误:试图访问历史数据 +class MyFactor(CrossSectionalFactor): + def compute(self, data): + df = data.to_polars() + # 错误:data 只包含当前日期的数据 + yesterday_data = df.filter(pl.col("trade_date") == "20240101") + +# 正确:使用 lookback_days 获取历史数据 +class MyFactor(CrossSectionalFactor): + data_specs = [DataSpec("daily", ["close"], lookback_days=5)] + + def compute(self, data): + df = data.to_polars() + # 现在 df 包含 [T-4, T] 的数据 + return df["close"].mean() +``` + +### 15.2 命名错误 + +#### 错误 3:使用非法字符或格式 + +```python +# 错误 +name = "PE Rank" # 包含空格 +name = "pe-rank" # 使用连字符 +name = "PERank" # 使用 PascalCase + +# 正确 +name = "pe_rank" # snake_case +``` + +#### 错误 4:参数化因子未动态设置 name + +```python +# 错误 +class MovingAverageFactor(TimeSeriesFactor): + name = "ma" # 固定名称,无法区分 ma_5 和 ma_10 + + def __init__(self, period: int = 5): + super().__init__(period=period) + +# 正确 +class MovingAverageFactor(TimeSeriesFactor): + name = "ma" # 默认名称 + + def __init__(self, period: int = 5): + super().__init__(period=period) + self.name = f"ma_{period}" # 动态设置名称 +``` + +### 15.3 数据规范错误 + +#### 错误 5:忘记包含必需列 + +```python +# 错误 +data_specs = [DataSpec("daily", ["close"], lookback_days=5)] +# 缺少 ts_code 和 trade_date + +# 正确 +data_specs = [DataSpec("daily", ["ts_code", "trade_date", "close"], lookback_days=5)] +``` + +#### 错误 6:lookback_days 设置不当 + +```python +# 错误:计算20日MA但只请求5天数据 +class MA20Factor(TimeSeriesFactor): + data_specs = [DataSpec("daily", ["close"], lookback_days=5)] + + def compute(self, data): + return data.get_column("close").rolling_mean(window_size=20) + +# 正确 +class MA20Factor(TimeSeriesFactor): + data_specs = [DataSpec("daily", ["close"], lookback_days=20)] + + def compute(self, data): + return data.get_column("close").rolling_mean(window_size=20) +``` + +### 15.4 计算逻辑错误 + +#### 错误 7:返回类型错误 + +```python +# 错误:返回 DataFrame 而不是 Series +class MyFactor(CrossSectionalFactor): + def compute(self, data): + cs = data.get_cross_section() + return cs # 错误:返回了 DataFrame + +# 正确 +class MyFactor(CrossSectionalFactor): + def compute(self, data): + cs = data.get_cross_section() + return cs["pe"].rank() # 正确:返回 Series +``` + +#### 错误 8:Series 长度不匹配 + +```python +# 错误:返回的 Series 长度与股票数量不匹配 +class MyFactor(CrossSectionalFactor): + def compute(self, data): + cs = data.get_cross_section() + # 错误:只返回3个值,但可能有3000只股票 + return pl.Series([0.1, 0.2, 0.3]) + +# 正确 +class MyFactor(CrossSectionalFactor): + def compute(self, data): + cs = data.get_cross_section() + # 正确:返回与股票数量相同的 Series + return cs["pe"].rank() / len(cs) +``` + +### 15.5 性能错误 + +#### 错误 9:使用 Python 循环 + +```python +# 错误:使用 Python 循环 +def compute(self, data): + close = data.get_column("close") + result = [] + for i in range(len(close)): + if i < 20: + result.append(None) + else: + result.append(sum(close[i-20:i]) / 20) + return pl.Series(result) + +# 正确:使用向量化操作 +def compute(self, data): + close = data.get_column("close") + return close.rolling_mean(window_size=20) +``` + +--- + +## 16. 验证清单 + +### 16.1 开发前检查 + +- [ ] 确定因子类型(截面/时序) +- [ ] 确定因子分类(momentum/financial/...) +- [ ] 确定需要的数据字段 +- [ ] 确定是否需要参数化 +- [ ] 确定因子的经济含义和计算逻辑 + +### 16.2 编码检查 + +- [ ] 类继承正确的基类(CrossSectionalFactor/TimeSeriesFactor) +- [ ] 声明了必需的类属性(name, factor_type, data_specs) +- [ ] data_specs 包含 ts_code 和 trade_date +- [ ] lookback_days 设置正确 +- [ ] name 使用 snake_case +- [ ] 类名使用 PascalCase 且以 Factor 结尾 +- [ ] 文件名使用 snake_case +- [ ] 所有函数都有类型提示 +- [ ] 文档字符串符合 Google 风格 +- [ ] compute 方法返回 pl.Series + +### 16.3 防泄露检查 + +- [ ] 截面因子只访问当前日期数据 +- [ ] 时序因子不访问其他股票数据 +- [ ] 没有使用 `data.to_polars()` 绕过框架 +- [ ] 没有硬编码日期或股票代码 + +### 16.4 性能检查 + +- [ ] 使用 Polars 向量化操作 +- [ ] 避免 Python 循环 +- [ ] 只加载需要的列 +- [ ] lookback_days 没有设置过大 + +### 16.5 测试检查 + +- [ ] 编写了初始化测试 +- [ ] 编写了计算逻辑测试 +- [ ] 编写了边界条件测试(空数据、单条数据) +- [ ] 编写了参数验证测试 +- [ ] 所有测试通过 + +### 16.6 文档检查 + +- [ ] 文件头文档完整 +- [ ] 类文档字符串完整 +- [ ] 方法文档字符串完整 +- [ ] 更新了分类 __init__.py +- [ ] 更新了主 __init__.py(如需要) + +--- + +## 附录 + +### A. 常见问题 + +#### Q1: 截面因子和时序因子有什么区别? + +**截面因子**:每天对所有股票进行一次计算,可以比较不同股票之间的值。例如:PE排名、市值分位数。 + +**时序因子**:对每只股票单独计算其时间序列。例如:移动平均线、RSI、历史波动率。 + +#### Q2: lookback_days 怎么设置? + +- 如果只需要当日数据:`lookback_days=1` +- 如果需要n日历史数据:`lookback_days=n` +- 对于收益率计算(需要T日和T-n日):`lookback_days=n+1` + +#### Q3: 如何处理缺失值? + +```python +# Polars 提供多种缺失值处理方法 +series.fill_null(0) # 填充为0 +series.fill_null(strategy="forward") # 向前填充 +series.drop_nulls() # 删除缺失值 +``` + +#### Q4: 如何调试因子计算? + +```python +# 1. 构造测试数据 +df = pl.DataFrame({ + "ts_code": [...], + "trade_date": [...], + "close": [...], +}) + +# 2. 创建 FactorData +context = FactorContext(current_date="20240101") +data = FactorData(df, context) + +# 3. 直接调用 compute +factor = MyFactor() +result = factor.compute(data) +print(result) +``` + +#### Q5: 可以使用 Pandas 吗? + +推荐使用 **Polars**,因为: +- 性能更好(向量化操作) +- 内存占用更低 +- 与框架其他部分保持一致 + +如果必须使用 Pandas: +```python +def compute(self, data: FactorData) -> pl.Series: + # 转换为 Pandas + pdf = data.to_polars().to_pandas() + + # 使用 Pandas 计算 + result = pdf["close"].rolling(window=20).mean() + + # 转回 Polars Series + return pl.Series(name=self.name, values=result.values) +``` + +### B. 快速参考卡 + +#### 创建截面因子 + +```python +from src.factors.base import CrossSectionalFactor +from src.factors.data_spec import DataSpec, FactorData +import polars as pl +from typing import List + +class MyFactor(CrossSectionalFactor): + name = "my_factor" + factor_type = "cross_sectional" + category = "momentum" + data_specs = [DataSpec("daily", ["ts_code", "trade_date", "close"], lookback_days=1)] + + def compute(self, data: FactorData) -> pl.Series: + cs = data.get_cross_section() + return cs["close"].rank() +``` + +#### 创建时序因子 + +```python +from src.factors.base import TimeSeriesFactor +from src.factors.data_spec import DataSpec, FactorData +import polars as pl +from typing import List + +class MyFactor(TimeSeriesFactor): + name = "my_factor" + factor_type = "time_series" + category = "momentum" + data_specs = [] + + def __init__(self, period: int = 5): + super().__init__(period=period) + self.data_specs = [DataSpec("daily", ["ts_code", "trade_date", "close"], lookback_days=period)] + self.name = f"my_factor_{period}" + + def compute(self, data: FactorData) -> pl.Series: + close = data.get_column("close") + return close.rolling_mean(window_size=self.params["period"]) +``` + +### C. 相关文档 + +- [因子框架设计](../../docs/factor_framework_design.md) +- [数据同步指南](../../docs/db_sync_guide.md) +- [API 接口规范](../../src/data/api_wrappers/API_INTERFACE_SPEC.md) + +--- + +**文档版本**: 2.0 +**最后更新**: 2026-02-25 +**维护者**: ProStock Team diff --git a/src/factors/__init__.py b/src/factors/__init__.py index df8edf8..e46c282 100644 --- a/src/factors/__init__.py +++ b/src/factors/__init__.py @@ -18,6 +18,52 @@ - CompositeFactor: 组合因子 - ScalarFactor: 标量运算因子 +因子分类目录: +- momentum/: 动量因子(MA、收益率排名等) +- financial/: 财务因子(EPS、ROE等) +- valuation/: 估值因子(PE、PB、PS等) +- technical/: 技术指标因子(RSI、MACD、布林带等) +- quality/: 质量因子(盈利能力、稳定性等) +- sentiment/: 情绪因子(换手率、资金流向等) +- volume/: 成交量因子(OBV、成交量比率等) +- volatility/: 波动率因子(历史波动率、GARCH等) + +数据加载和执行(Phase 3-4): +- DataLoader: 数据加载器 +- FactorEngine: 因子执行引擎 + +使用示例: + # 使用通用因子(参数化) + from src.factors import MovingAverageFactor, ReturnRankFactor + from src.factors import DataLoader, FactorEngine + + ma5 = MovingAverageFactor(period=5) # 5日MA + ma10 = MovingAverageFactor(period=10) # 10日MA + ret5 = ReturnRankFactor(period=5) # 5日收益率排名 + + loader = DataLoader(data_dir="data") + engine = FactorEngine(loader) + result = engine.compute(ma5, stock_codes=["000001.SZ"], start_date="20240101", end_date="20240131") +""" + +因子框架提供以下核心功能: +1. 类型安全的因子定义(截面因子、时序因子) +2. 数据泄露防护机制 +3. 因子组合和运算 +4. 高效的数据加载和计算引擎 + +基础数据类型(Phase 1): +- DataSpec: 数据需求规格 +- FactorContext: 计算上下文 +- FactorData: 数据容器 + +因子基类(Phase 2): +- BaseFactor: 抽象基类 +- CrossSectionalFactor: 日期截面因子基类 +- TimeSeriesFactor: 时间序列因子基类 +- CompositeFactor: 组合因子 +- ScalarFactor: 标量运算因子 + 动量因子(momentum/): - MovingAverageFactor: 移动平均线(时序因子) - ReturnRankFactor: 收益率排名(截面因子) diff --git a/src/factors/data_loader.py b/src/factors/data_loader.py index 82ddcdb..5c81bc0 100644 --- a/src/factors/data_loader.py +++ b/src/factors/data_loader.py @@ -72,8 +72,8 @@ class DataLoader: if cache_key in self._cache: df = self._cache[cache_key] else: - # 读取 H5 文件 - df = self._read_h5(spec.source) + # 读取 H5 文件(传入日期范围以支持过滤) + df = self._read_h5(spec.source, date_range=date_range) # 列选择 - 只保留需要的列 missing_cols = set(spec.columns) - set(df.columns) @@ -107,7 +107,11 @@ class DataLoader: """清空缓存""" self._cache.clear() - def _read_h5(self, source: str) -> pl.DataFrame: + def _read_h5( + self, + source: str, + date_range: Optional[Tuple[str, str]] = None, + ) -> pl.DataFrame: """读取数据 - 从 DuckDB 加载为 Polars DataFrame。 迁移说明: @@ -117,6 +121,7 @@ class DataLoader: Args: source: 表名(对应 DuckDB 中的表,如 "daily") + date_range: 日期范围限制 (start_date, end_date),可选 Returns: Polars DataFrame @@ -125,11 +130,38 @@ class DataLoader: Exception: 数据库查询错误 """ from src.data.storage import Storage + from src.data.api_wrappers.api_trade_cal import get_trading_days + from src.data.utils import get_today_date + from src.factors.financial.utils import expand_period_to_trading_days storage = Storage() - # 如果 DataLoader 有 date_range,传递给 Storage 进行过滤 - # 实现查询下推,只加载必要数据 + # 特殊处理财务数据:将报告期展开到交易日 + if source == "financial_income": + # 确定日期范围 + start_date = date_range[0] if date_range else "20180101" + end_date = date_range[1] if date_range else get_today_date() + + # 1. 加载原始财务数据(报告期粒度),按日期范围过滤 + # 注意:financial_income 使用 end_date 字段作为报告期 + df = storage.load_polars( + "financial_income", + start_date=start_date, + end_date=end_date, + ) + + if len(df) == 0: + return pl.DataFrame() + + # 2. 获取交易日历(从2018年开始到当前,确保有足够的历史数据用于前向填充) + # 需要从数据的最小日期开始,确保能获取到足够的交易日 + trade_start = "20180101" if start_date > "20180101" else start_date + trade_dates = get_trading_days(trade_start, get_today_date()) + + # 3. 展开到交易日(前向填充) + return expand_period_to_trading_days(df, trade_dates) + + # 其他数据源保持原有逻辑 return storage.load_polars(source) def _merge_dataframes(self, dataframes: List[pl.DataFrame]) -> pl.DataFrame: diff --git a/src/factors/financial/__init__.py b/src/factors/financial/__init__.py index 9c34516..839e6f4 100644 --- a/src/factors/financial/__init__.py +++ b/src/factors/financial/__init__.py @@ -4,7 +4,10 @@ 因子分类: - financial: 财务因子 - - (待添加) + - EPSFactor: 每股收益排名因子 + +已添加因子: +- EPSFactor: 每股收益排名(基于basic_eps) 待添加因子: - PERankFactor: 市盈率排名 @@ -12,4 +15,6 @@ - DividendFactor: 股息率因子 """ -__all__ = [] +from src.factors.financial.eps_factor import EPSFactor + +__all__ = ["EPSFactor"] diff --git a/src/factors/financial/eps_factor.py b/src/factors/financial/eps_factor.py new file mode 100644 index 0000000..4151b89 --- /dev/null +++ b/src/factors/financial/eps_factor.py @@ -0,0 +1,66 @@ +"""EPS因子 + +每股收益(EPS)排名因子实现 +""" + +from typing import List +import polars as pl + +from src.factors.base import CrossSectionalFactor +from src.factors.data_spec import DataSpec, FactorData + + +class EPSFactor(CrossSectionalFactor): + """每股收益(EPS)排名因子 + + 计算逻辑:使用最新报告期的basic_eps,每天对所有股票进行截面排名 + + Attributes: + name: 因子名称 "eps_rank" + category: 因子分类 "financial" + data_specs: 数据需求规格 + + Example: + >>> from src.factors import FactorEngine, DataLoader + >>> from src.factors.financial.eps_factor import EPSFactor + >>> loader = DataLoader('data') + >>> engine = FactorEngine(loader) + >>> eps_factor = EPSFactor() + >>> result = engine.compute(eps_factor, start_date='20210101', end_date='20210131') + """ + + name: str = "eps_rank" + category: str = "financial" + description: str = "每股收益截面排名因子" + data_specs: List[DataSpec] = [ + DataSpec( + "financial_income", ["ts_code", "trade_date", "basic_eps"], lookback_days=1 + ) + ] + + def compute(self, data: FactorData) -> pl.Series: + """计算EPS排名 + + Args: + data: FactorData,包含当前日期的截面数据 + + Returns: + EPS排名的0-1标准化值(0-1之间) + """ + # 获取当前日期的截面数据 + cs = data.get_cross_section() + + if len(cs) == 0: + return pl.Series(name=self.name, values=[]) + + # 提取EPS值,填充缺失值为0 + eps = cs["basic_eps"].fill_null(0) + + # 计算排名并归一化到0-1 + if len(eps) > 1 and eps.max() != eps.min(): + ranks = eps.rank(method="average") / len(eps) + else: + # 数据不足或全部相同,返回0.5 + ranks = pl.Series(name=self.name, values=[0.5] * len(eps)) + + return ranks diff --git a/src/factors/financial/utils.py b/src/factors/financial/utils.py new file mode 100644 index 0000000..cb656bc --- /dev/null +++ b/src/factors/financial/utils.py @@ -0,0 +1,82 @@ +"""财务因子工具函数 + +提供财务数据处理的工具函数: +- expand_period_to_trading_days: 将报告期数据展开到每个交易日(前向填充) +""" + +from typing import List +import polars as pl + + +def expand_period_to_trading_days( + financial_df: pl.DataFrame, + trade_dates: List[str], +) -> pl.DataFrame: + """将财务数据(报告期粒度)展开到每个交易日(前向填充) + + 核心逻辑:对于每个交易日,找到该日期之前最新的已公告报告期数据。 + 例如:2020年报(20201231)公告于20210428,则在2021-04-28之后的每个 + 交易日都使用该年报数据,直到2021一季报公告。 + + Args: + financial_df: 财务数据DataFrame,包含 ts_code, ann_date, end_date, ... + trade_dates: 交易日列表(YYYYMMDD格式,已排序) + + Returns: + DataFrame,包含 trade_date, ts_code 和所有财务字段 + + Example: + >>> financial_df = pl.DataFrame({ + ... 'ts_code': ['000001.SZ'], + ... 'ann_date': ['20210428'], + ... 'end_date': ['20210331'], + ... 'basic_eps': [0.5] + ... }) + >>> trade_dates = ['20210428', '20210429', '20210430'] + >>> result = expand_period_to_trading_days(financial_df, trade_dates) + >>> print(result) + shape: (3, 5) + ┌───────────┬───────────┬────────────┬────────────┬───────────┐ + │ ts_code ┆ ann_date ┆ end_date ┆ basic_eps ┆ trade_date│ + │ --- ┆ --- ┆ --- ┆ --- ┆ --- │ + │ str ┆ str ┆ str ┆ f64 ┆ str │ + ╞═══════════╪═══════════╪════════════╪════════════╪═══════════╡ + │ 000001.SZ ┆ 20210428 ┆ 20210331 ┆ 0.5 ┆ 20210428 │ + │ 000001.SZ ┆ 20210428 ┆ 20210331 ┆ 0.5 ┆ 20210429 │ + │ 000001.SZ ┆ 20210428 ┆ 20210331 ┆ 0.5 ┆ 20210430 │ + └───────────┴───────────┴────────────┴────────────┴───────────┘ + """ + if len(financial_df) == 0: + return pl.DataFrame() + + results = [] + + # 按股票分组处理 + for ts_code in financial_df["ts_code"].unique(): + stock_data = financial_df.filter(pl.col("ts_code") == ts_code) + + # 按报告期排序(end_date升序) + stock_data = stock_data.sort("end_date") + + rows = [] + for trade_date in trade_dates: + # 找到该交易日之前最新的已公告报告期 + # 条件1: end_date <= trade_date(报告期不晚于交易日) + # 条件2: ann_date <= trade_date(已公告) + applicable = stock_data.filter( + (pl.col("end_date") <= trade_date) & (pl.col("ann_date") <= trade_date) + ) + + if len(applicable) > 0: + # 取最新的一条(end_date最大的) + latest = applicable.tail(1).with_columns( + [pl.lit(trade_date).alias("trade_date")] + ) + rows.append(latest) + + if rows: + results.append(pl.concat(rows)) + + if results: + return pl.concat(results) + return pl.DataFrame() diff --git a/src/factors/quality/__init__.py b/src/factors/quality/__init__.py new file mode 100644 index 0000000..646f1f9 --- /dev/null +++ b/src/factors/quality/__init__.py @@ -0,0 +1,20 @@ +"""质量因子模块 + +本模块提供质量类因子: +- 盈利能力:ROE、ROA、毛利率、净利率 +- 盈利稳定性:盈利波动率、盈利持续性 +- 财务健康度:资产负债率、流动比率等 + +使用示例: + >>> from src.factors.quality import ROEFactor + >>> factor = ROEFactor() +""" + +# 在此处导入具体的质量因子 +# from .roe import ROEFactor +# from .roa import ROAFactor +# from .profit_stability import ProfitStabilityFactor + +__all__ = [ + # 添加你的质量因子 +] diff --git a/src/factors/sentiment/__init__.py b/src/factors/sentiment/__init__.py new file mode 100644 index 0000000..c15a1af --- /dev/null +++ b/src/factors/sentiment/__init__.py @@ -0,0 +1,20 @@ +"""情绪因子模块 + +本模块提供市场情绪类因子: +- 换手率、换手率变化率 +- 资金流向、主力净流入 +- 波动率、振幅等 + +使用示例: + >>> from src.factors.sentiment import TurnoverFactor + >>> factor = TurnoverFactor(period=20) +""" + +# 在此处导入具体的情绪因子 +# from .turnover import TurnoverFactor +# from .money_flow import MoneyFlowFactor +# from .amplitude import AmplitudeFactor + +__all__ = [ + # 添加你的情绪因子 +] diff --git a/src/factors/technical/__init__.py b/src/factors/technical/__init__.py new file mode 100644 index 0000000..45b5c36 --- /dev/null +++ b/src/factors/technical/__init__.py @@ -0,0 +1,20 @@ +"""技术指标因子模块 + +本模块提供技术分析类因子: +- 移动平均线(MA)、指数移动平均(EMA) +- 相对强弱指标(RSI)、MACD、KDJ +- 布林带(Bollinger Bands)等 + +使用示例: + >>> from src.factors.technical import RSIFactor + >>> factor = RSIFactor(period=14) +""" + +# 在此处导入具体的技术指标因子 +# from .rsi import RSIFactor +# from .macd import MACDFactor +# from .bollinger import BollingerFactor + +__all__ = [ + # 添加你的技术指标因子 +] diff --git a/src/factors/valuation/__init__.py b/src/factors/valuation/__init__.py new file mode 100644 index 0000000..17a9b25 --- /dev/null +++ b/src/factors/valuation/__init__.py @@ -0,0 +1,18 @@ +"""估值因子模块 + +本模块提供估值类因子: +- 市盈率(PE)、市净率(PB)、市销率(PS)等估值指标 +- 估值排名、估值分位数等衍生因子 + +使用示例: + >>> from src.factors.valuation import PERankFactor + >>> factor = PERankFactor() +""" + +# 在此处导入具体的估值因子 +# from .pe_rank import PERankFactor +# from .pb_rank import PBRankFactor + +__all__ = [ + # 添加你的估值因子 +] diff --git a/src/factors/volatility/__init__.py b/src/factors/volatility/__init__.py new file mode 100644 index 0000000..4a95d03 --- /dev/null +++ b/src/factors/volatility/__init__.py @@ -0,0 +1,21 @@ +"""波动率因子模块 + +本模块提供波动率相关因子: +- 历史波动率(Historical Volatility) +- 实现波动率(Realized Volatility) +- GARCH类波动率预测 +- 波动率风险指标等 + +使用示例: + >>> from src.factors.volatility import HistoricalVolFactor + >>> factor = HistoricalVolFactor(period=20) +""" + +# 在此处导入具体的波动率因子 +# from .historical_vol import HistoricalVolFactor +# from .realized_vol import RealizedVolFactor +# from .garch_vol import GARCHVolFactor + +__all__ = [ + # 添加你的波动率因子 +] diff --git a/src/factors/volume/__init__.py b/src/factors/volume/__init__.py new file mode 100644 index 0000000..e3bc180 --- /dev/null +++ b/src/factors/volume/__init__.py @@ -0,0 +1,20 @@ +"""成交量因子模块 + +本模块提供成交量相关因子: +- 成交量移动平均 +- 成交量比率(VR)、能量潮(OBV) +- 量价配合指标等 + +使用示例: + >>> from src.factors.volume import OBVFactor + >>> factor = OBVFactor() +""" + +# 在此处导入具体的成交量因子 +# from .obv import OBVFactor +# from .volume_ratio import VolumeRatioFactor +# from .volume_ma import VolumeMAFactor + +__all__ = [ + # 添加你的成交量因子 +] diff --git a/src/pipeline/processors/__init__.py b/src/pipeline/processors/__init__.py index f33779e..be33a77 100644 --- a/src/pipeline/processors/__init__.py +++ b/src/pipeline/processors/__init__.py @@ -8,6 +8,7 @@ from src.pipeline.processors.processors import ( MinMaxScaler, RankTransformer, Neutralizer, + MADClipper, ) __all__ = [ @@ -18,4 +19,5 @@ __all__ = [ "MinMaxScaler", "RankTransformer", "Neutralizer", + "MADClipper", ] diff --git a/src/pipeline/processors/processors.py b/src/pipeline/processors/processors.py index 31985c2..72bb958 100644 --- a/src/pipeline/processors/processors.py +++ b/src/pipeline/processors/processors.py @@ -3,9 +3,8 @@ 提供常用的数据预处理和转换处理器。 """ -from typing import List, Optional, Dict, Any +from typing import List, Optional import polars as pl -import numpy as np from src.pipeline.core import BaseProcessor, PipelineStage from src.pipeline.registry import PluginRegistry @@ -227,6 +226,64 @@ class Neutralizer(BaseProcessor): return result +@PluginRegistry.register_processor("mad_clipper") +class MADClipper(BaseProcessor): + """MAD去极值处理器 - 基于每日截面的中位数绝对偏差去除极值 + + 使用3倍MAD作为阈值,比标准差方法更稳健,对异常值不敏感。 + 阈值范围: [median - n*MAD, median + n*MAD] + """ + + stage = PipelineStage.TRAIN + + def __init__( + self, + columns: Optional[List[str]] = None, + n_mad: float = 3.0, + ): + super().__init__(columns) + self.n_mad = n_mad + + def fit(self, data: pl.DataFrame) -> "MADClipper": + cols = _get_numeric_columns(data, self.columns) + bounds = {} + + for col in cols: + # 按日期分组计算每个截面的 median 和 MAD + daily_stats = data.group_by("trade_date").agg( + pl.col(col).median().alias("median"), + (pl.col(col) - pl.col(col).median()).abs().median().alias("mad"), + ) + bounds[col] = daily_stats + + self._fitted_params = {"bounds": bounds, "columns": cols} + self._is_fitted = True + return self + + def transform(self, data: pl.DataFrame) -> pl.DataFrame: + """使用窗口函数进行MAD去极值,避免join操作提升性能""" + result = data + bounds = self._fitted_params.get("bounds", {}) + + for col in bounds.keys(): + if col not in result.columns: + continue + + # 使用窗口函数直接计算每个截面的median和MAD,避免join + # 1. 计算每个日期截面的median + median = pl.col(col).median().over("trade_date") + # 2. 计算每个日期截面的MAD + mad = (pl.col(col) - median).abs().median().over("trade_date") + + # 3. 计算上下界并clip + lower = median - self.n_mad * mad + upper = median + self.n_mad * mad + + result = result.with_columns(pl.col(col).clip(lower, upper).alias(col)) + + return result + + __all__ = [ "DropNAProcessor", "FillNAProcessor", @@ -235,4 +292,5 @@ __all__ = [ "MinMaxScaler", "RankTransformer", "Neutralizer", + "MADClipper", ] diff --git a/src/training/main.py b/src/training/main.py index 97cd868..4876dad 100644 --- a/src/training/main.py +++ b/src/training/main.py @@ -12,13 +12,16 @@ from src.training.pipeline import run_training if __name__ == "__main__": # 运行完整训练流程 - # 训练集:20180101 - 20230101 - # 测试集:20230101 - 20240101 + # 训练集:20190101 - 20231231 + # 验证集:20240102 - 20240531 (与训练集间隔1天,避免数据泄露) + # 测试集:20240602 - 20241231 (与验证集间隔1天,避免数据泄露) result = run_training( train_start="20190101", - train_end="20250101", - test_start="20250101", - test_end="20260101", + train_end="20231231", + val_start="20240102", + val_end="20240531", + test_start="20240602", + test_end="20241231", top_n=5, output_path="output/top_stocks.tsv", ) diff --git a/src/training/pipeline.py b/src/training/pipeline.py index 8d93577..cf5b859 100644 --- a/src/training/pipeline.py +++ b/src/training/pipeline.py @@ -30,10 +30,12 @@ def prepare_data( data_dir: str = "data", train_start: str = "20180101", train_end: str = "20230101", - test_start: str = "20230101", + val_start: str = "20230101", + val_end: str = "20230601", + test_start: str = "20230601", test_end: str = "20240101", -) -> Tuple[pl.DataFrame, pl.DataFrame]: - """准备训练和测试数据 +) -> tuple[pl.DataFrame, pl.DataFrame, pl.DataFrame]: + """准备训练、验证和测试数据 从DuckDB加载原始日线数据,计算所需因子并生成标签。 @@ -41,11 +43,13 @@ def prepare_data( data_dir: 数据目录 train_start: 训练集开始日期 train_end: 训练集结束日期 + val_start: 验证集开始日期 + val_end: 验证集结束日期 test_start: 测试集开始日期 test_end: 测试集结束日期 Returns: - (train_data, test_data): 训练集和测试集的DataFrame + (train_data, val_data, test_data): 训练集、验证集和测试集的DataFrame """ from src.data.storage import Storage @@ -56,47 +60,56 @@ def prepare_data( lookback_days = 20 # 足够计算MA10和5日收益率 start_with_lookback = str(int(train_start) - 10000) # 往前取一年 - # 查询训练集数据 + # 查询全部数据(包含train、val、test),然后再拆分 # 注意:DuckDB 中 trade_date 是 DATE 类型,需要转换 start_dt = f"{start_with_lookback[:4]}-{start_with_lookback[4:6]}-{start_with_lookback[6:8]}" - end_dt = f"{train_end[:4]}-{train_end[4:6]}-{train_end[6:8]}" - train_query = f""" + end_dt = f"{test_end[:4]}-{test_end[4:6]}-{test_end[6:8]}" + + all_query = f""" SELECT ts_code, trade_date, close, pre_close FROM daily WHERE trade_date >= '{start_dt}' AND trade_date <= '{end_dt}' ORDER BY ts_code, trade_date """ - train_raw = storage._connection.sql(train_query).pl() + all_raw = storage._connection.sql(all_query).pl() # 转换 trade_date 为字符串格式 - train_raw = train_raw.with_columns( - pl.col("trade_date").dt.strftime("%Y-%m-%d").alias("trade_date") - ) - - # 查询测试集数据(也需要历史数据计算因子) - test_start_dt = f"{test_start[:4]}-{test_start[4:6]}-{test_start[6:8]}" - test_end_dt = f"{test_end[:4]}-{test_end[4:6]}-{test_end[6:8]}" - test_query = f""" - SELECT ts_code, trade_date, close, pre_close - FROM daily - WHERE trade_date >= '{test_start_dt}' AND trade_date <= '{test_end_dt}' - ORDER BY ts_code, trade_date - """ - test_raw = storage._connection.sql(test_query).pl() - # 转换 trade_date 为字符串格式 - test_raw = test_raw.with_columns( + all_raw = all_raw.with_columns( pl.col("trade_date").dt.strftime("%Y-%m-%d").alias("trade_date") ) # 过滤不符合条件的股票 - train_raw = _filter_invalid_stocks(train_raw) - test_raw = _filter_invalid_stocks(test_raw) - print(f"[PrepareData] After filtering: train={len(train_raw)}, test={len(test_raw)}") + all_raw = _filter_invalid_stocks(all_raw) + print(f"[PrepareData] After filtering: total={len(all_raw)}") - # 计算因子和标签 - train_data = _compute_features_and_label(train_raw, train_start, train_end) - test_data = _compute_features_and_label(test_raw, test_start, test_end) + # 计算因子和标签(需要全局数据一次性计算) + all_data = _compute_features_and_label( + all_raw, + start_date=train_start, + end_date=test_end + ) - return train_data, test_data + # 转换日期格式用于比较 + train_start_fmt = f"{train_start[:4]}-{train_start[4:6]}-{train_start[6:8]}" + train_end_fmt = f"{train_end[:4]}-{train_end[4:6]}-{train_end[6:8]}" + val_start_fmt = f"{val_start[:4]}-{val_start[4:6]}-{val_start[6:8]}" + val_end_fmt = f"{val_end[:4]}-{val_end[4:6]}-{val_end[6:8]}" + test_start_fmt = f"{test_start[:4]}-{test_start[4:6]}-{test_start[6:8]}" + test_end_fmt = f"{test_end[:4]}-{test_end[4:6]}-{test_end[6:8]}" + + # 拆分数据 + train_data = all_data.filter( + (pl.col("trade_date") >= train_start_fmt) & (pl.col("trade_date") <= train_end_fmt) + ) + val_data = all_data.filter( + (pl.col("trade_date") >= val_start_fmt) & (pl.col("trade_date") <= val_end_fmt) + ) + test_data = all_data.filter( + (pl.col("trade_date") >= test_start_fmt) & (pl.col("trade_date") <= test_end_fmt) + ) + + print(f"[PrepareData] Split result: train={len(train_data)}, val={len(val_data)}, test={len(test_data)}") + + return train_data, val_data, test_data def _filter_invalid_stocks(df: pl.DataFrame) -> pl.DataFrame: @@ -254,6 +267,7 @@ def create_pipeline() -> ProcessingPipeline: def train_model( train_data: pl.DataFrame, + val_data: Optional[pl.DataFrame], feature_cols: List[str], label_col: str = "label", model_params: Optional[dict] = None, @@ -262,6 +276,7 @@ def train_model( Args: train_data: 训练数据 + val_data: 验证数据(用于早停) feature_cols: 特征列名列表 label_col: 标签列名 model_params: 模型参数字典 @@ -273,21 +288,39 @@ def train_model( pipeline = create_pipeline() print("[TrainModel] Pipeline created: FillNA(0)") - # 准备特征和标签 + # 准备训练特征和标签 X_train = train_data.select(feature_cols) y_train = train_data[label_col] - print(f"[TrainModel] Raw samples: {len(X_train)}, features: {feature_cols}") + print(f"[TrainModel] Train samples: {len(X_train)}, features: {feature_cols}") - # 处理数据 + # 处理训练数据 X_train_processed = pipeline.fit_transform(X_train, stage=PipelineStage.TRAIN) print(f"[TrainModel] After processing: {len(X_train_processed)} samples") - # 过滤有效标签(排除-1等无效值) + # 过滤训练集有效标签(排除-1等无效值) valid_mask = y_train.is_in([0, 1]) X_train_processed = X_train_processed.filter(valid_mask) y_train = y_train.filter(valid_mask) print(f"[TrainModel] After filtering valid labels: {len(X_train_processed)} samples") - print(f"[TrainModel] Label distribution: {dict(y_train.value_counts().sort('label').iter_rows())}") + print(f"[TrainModel] Train label distribution: {dict(y_train.value_counts().sort('label').iter_rows())}") + + # 准备验证集 + X_val_processed = None + y_val = None + if val_data is not None and len(val_data) > 0: + X_val = val_data.select(feature_cols) + y_val = val_data[label_col] + print(f"[TrainModel] Val samples: {len(X_val)}") + + # 处理验证集数据(使用训练集的参数) + X_val_processed = pipeline.transform(X_val, stage=PipelineStage.TEST) + + # 过滤验证集有效标签 + val_valid_mask = y_val.is_in([0, 1]) + X_val_processed = X_val_processed.filter(val_valid_mask) + y_val = y_val.filter(val_valid_mask) + print(f"[TrainModel] Val after filtering: {len(X_val_processed)} samples") + print(f"[TrainModel] Val label distribution: {dict(y_val.value_counts().sort('label').iter_rows())}") # 创建模型 params = model_params or { @@ -302,9 +335,13 @@ def train_model( params=params, ) - # 训练模型 + # 训练模型(使用验证集早停) print("[TrainModel] Training LightGBM...") - model.fit(X_train_processed, y_train) + if X_val_processed is not None and y_val is not None: + print("[TrainModel] Using validation set for early stopping") + model.fit(X_train_processed, y_train, X_val_processed, y_val) + else: + model.fit(X_train_processed, y_train) print("[TrainModel] Training completed!") return model, pipeline @@ -382,7 +419,9 @@ def run_training( output_path: str = "output/top_stocks.tsv", train_start: str = "20180101", train_end: str = "20230101", - test_start: str = "20230101", + val_start: str = "20230101", + val_end: str = "20230601", + test_start: str = "20230601", test_end: str = "20240101", top_n: int = 5, ) -> pl.DataFrame: @@ -393,6 +432,8 @@ def run_training( output_path: 输出文件路径 train_start: 训练集开始日期 train_end: 训练集结束日期 + val_start: 验证集开始日期 + val_end: 验证集结束日期 test_start: 测试集开始日期 test_end: 测试集结束日期 top_n: 每日选股数量 @@ -402,18 +443,22 @@ def run_training( """ print(f"[Training] Starting training pipeline...") print(f"[Training] Train period: {train_start} -> {train_end}") + print(f"[Training] Val period: {val_start} -> {val_end}") print(f"[Training] Test period: {test_start} -> {test_end}") # 1. 准备数据 print("[Training] Preparing data...") - train_data, test_data = prepare_data( + train_data, val_data, test_data = prepare_data( data_dir=data_dir, train_start=train_start, train_end=train_end, + val_start=val_start, + val_end=val_end, test_start=test_start, test_end=test_end, ) print(f"[Training] Train samples: {len(train_data)}") + print(f"[Training] Val samples: {len(val_data)}") print(f"[Training] Test samples: {len(test_data)}") # 2. 定义特征列 @@ -424,6 +469,7 @@ def run_training( print("[Training] Training model...") model, pipeline = train_model( train_data=train_data, + val_data=val_data, feature_cols=feature_cols, label_col=label_col, )