refactor: 代码审查修复 - 日期过滤、性能优化、数据泄露防护

- 修复 data_loader.py 财务数据日期过滤,支持按范围加载
- 优化 MADClipper 使用窗口函数替代 join,提升性能
- 修复训练日期边界问题,添加1天间隔避免数据泄露
- 新增 .gitignore 规则忽略训练输出目录
This commit is contained in:
2026-02-25 21:11:19 +08:00
parent 593ec99466
commit a9e4746239
24 changed files with 3597 additions and 56 deletions

4
.gitignore vendored
View File

@@ -76,5 +76,9 @@ temp/
# 数据目录(允许跟踪,但忽略内容) # 数据目录(允许跟踪,但忽略内容)
data/* data/*
# 训练输出目录(不需要版本控制)
src/training/output/*
!src/training/output/.gitkeep
# AI Agent 工作目录 # AI Agent 工作目录
/.sisyphus/ /.sisyphus/

View File

@@ -0,0 +1,467 @@
# Classify2_load_model.ipynb 代码流程详解文档
## 概述
本文档详细描述了 `Classify2_load_model.ipynb` 文件中完整的代码流程该notebook实现了一个股票分类模型的特征工程与数据预处理流程。整个流程涵盖了从原始数据加载、多源数据合并、因子计算、数据清洗、特征标准化等多个环节最终输出可供机器学习模型使用的特征矩阵。
---
## 第一部分:环境配置与初始化
### 1.1 基础环境设置
notebook开头的第一个代码单元负责基础的开发环境配置。首先启用Jupyter的自动重载功能通过`%load_ext autoreload``%autoreload 2`指令,使得在修改导入的模块后能够自动重新加载,这对于开发调试阶段非常实用。随后进行垃圾回收机制的相关设置,导入`gc`模块用于手动管理内存因为后续处理的数据量较大数GB级别及时释放不需要的对象可以避免内存溢出。
系统路径的配置通过`sys.path.append`完成将项目根目录添加到Python的模块搜索路径中这样可以直接导入项目内的自定义模块。打印工作目录用于确认当前运行环境。数据处理的核心库`pandas`被导入用于表格数据处理,同时导入项目内部的因子模块和工具函数,包括`get_rolling_factor``get_simple_factor`用于计算技术因子,`read_industry_data`用于读取行业分类数据,`calculate_score`用于计算目标变量。最后通过`warnings.filterwarnings("ignore")`忽略所有警告信息,保持输出界面的整洁。
### 1.2 并行计算配置
第二个代码单元配置了Modin框架的并行计算参数。通过设置环境变量`os.environ["MODIN_CPUS"] = "4"`指定使用4个CPU核心进行并行计算。Modin是一个pandas的并行替代库能够在多核CPU上并行执行pandas操作显著加速大规模数据的处理速度。这一配置对于后续处理数千万条记录的数据框至关重要。
---
## 第二部分:核心数据加载与合并
### 2.1 日线数据加载
第三个代码单元是整个流程中最关键的数据加载步骤依次从HDF5文件中读取多个数据源并合并。HDF5是一种高效的分块压缩存储格式非常适合存储大规模数值型数据。
**第一步:加载日线行情数据**
使用`read_and_merge_h5_data`函数从`daily_data.h5`文件中读取日线行情数据。该函数是项目自定义的工具函数,其核心逻辑在`main/utils/utils.py`中实现。读取的字段包括股票代码`ts_code`、交易日期`trade_date`、开盘价`open`、收盘价`close`、最高价`high`、最低价`low`、成交量`vol`、成交额`amount`以及涨跌幅`pct_chg`。这些字段构成分析的基础价格信息。首次调用该函数时`df`参数为空,因此直接返回读取的数据作为初始数据框。
**第二步:加载日线基础数据**
继续调用`read_and_merge_h5_data`读取`daily_basic.h5`文件获取每日股票的基础财务指标。本次读取使用内连接inner join方式与已有数据进行合并这意味着只保留两个数据集中都存在的股票记录。读取的字段包括`turnover_rate`(换手率)、`pe_ttm`(滚动市盈率)、`circ_mv`(流通市值)、`total_mv`(总市值)以及`volume_ratio`(量比)。这些指标对于评估股票的流动性和估值水平非常重要。
**第三步:加载涨跌停限制数据**
`stk_limit.h5`文件中读取股票的涨跌停价格信息。读取的字段包括`pre_close`(前一日收盘价)、`up_limit`(涨停价)和`down_limit`(跌停价)。这些信息用于后续识别股票的涨停状态,是构建目标变量的重要依据。
**第四步:加载资金流数据**
`money_flow.h5`文件中读取资金流信息,这是计算资金流因子的基础数据。读取的字段包括各类成交量的分解:`buy_sm_vol``sell_sm_vol`代表小单买卖成交量、`buy_lg_vol``sell_lg_vol`代表大单买卖成交量、`buy_elg_vol``sell_elg_vol`代表超大单买卖成交量,以及`net_mf_vol`(净资金流成交量)。通过分析这些不同档位的资金流向,可以判断机构投资者和散户的交易行为。
**第五步:加载筹码分布数据**
最后从`cyq_perf.h5`文件中读取筹码分布相关指标。这些数据反映的是股票持仓成本分布情况,包括历史最低价`his_low`、历史最高价`his_high`、不同成本分位线的价格(`cost_5pct``cost_15pct``cost_50pct``cost_85pct``cost_95pct`)、加权平均成本`weight_avg`以及获利盘比例`winner_rate`。这些指标是计算筹码分布因子的核心数据。
经过上述五个步骤的依次合并最终生成的数据框包含9436343条记录、33个字段占用内存约2.3GB。
### 2.2 行业数据加载与合并
第四个代码单元处理行业分类数据的加载与合并。首先使用`read_and_merge_h5_data`函数从`industry_data.h5`文件中读取行业分类数据,提取股票代码、二级行业代码`l2_code`以及行业纳入日期`in_date`
随后定义了`merge_with_industry_data`函数用于将行业数据与主数据框进行时间匹配合并。该函数的核心逻辑分为以下几个步骤首先确保日期字段转换为datetime类型然后分别对行业数据和主数据按股票代码和日期排序接着使用`pd.merge_asof`函数进行向后合并direction='backward'),这意味着对于每一股票的每一个交易日,会匹配该日之前(包括当日)最近的行业变更记录;对于交易日期早于所有行业变更日期的记录,使用每个股票的最早行业代码进行填充。
这一合并策略的设计目的是确保在任意交易日期都能获取到该股票当前所属的行业分类,这对于后续进行行业中性化处理至关重要。
---
## 第三部分:指数数据处理
### 3.1 指数数据读取
第五个代码单元实现指数数据的读取与指标计算。定义了两个核心函数:`calculate_indicators`用于计算单个指数的技术指标,`generate_index_indicators`用于批量处理多个指数。
`calculate_indicators`函数中,首先对数据按交易日期排序,然后计算以下指标:
**当日涨跌幅**:通过`(close - pre_close) / pre_close * 100`计算,这是最基础的收益率指标。
**RSI指标**相对强弱指数计算过程包括首先计算价格变动delta然后分离上涨和下跌部分分别计算14日滚动平均最后通过公式`100 - (100 / (1 + rs))`得到RSI值。RSI是衡量股票近期涨跌动能的经典技术指标。
**MACD指标**移动平均收敛发散指标计算过程包括12日EMA减去26日EMA得到MACD线9日EMA得到信号线MACD与信号线的差值得到MACD柱。MACD是判断趋势方向和动量变化的重要工具。
**情绪因子**
- **上涨比例up_ratio_20d**过去20天上涨天数占比反映市场整体情绪
- **成交量变化率volume_change_rate**当日成交量与20日均量的比率反映成交量异常变化
- **波动率volatility**过去20天日收益率的标准差反映市场波动程度
- **成交额变化率amount_change_rate**当日成交额与20日均额的比率
`generate_index_indicators`函数则对所有指数分别计算上述指标,然后将结果透视为宽表格式,使每个指数的指标成为独立的列。最终输出包含多个指数技术指标的数据框,这些指标可以作为市场情绪的代理变量加入模型。
### 3.2 行业指数数据处理
第六个代码单元使用talib库和numpy库进行更专业的技术指标计算。talib是专业的技术分析库包含数百种经典技术指标的实现。
`sw_daily.h5`文件读取行业日线数据后,依次计算:
**OBV能量潮指标**:通过`talib.OBV`函数计算,将成交量与价格变动方向结合,衡量资金流入流出的强度。
**短期收益率return_5**5日收益率计算公式为`x / x.shift(5) - 1`
**中期收益率return_20**20日收益率用于捕捉中短期动量。
**动量因子act_factor**:通过`get_act_factor`函数计算该函数对不同周期的EMA指数移动平均进行arctan变换得到平滑的动量因子。这是项目自定义的核心因子之一。
**收益率分位数排名**:将收益率在截面内转换为百分位排名,消除不同行业间的基数差异,便于跨行业比较。
最终对所有新增字段添加`industry_`前缀,并重命名`ts_code``cat_l2_code`,形成行业级别的因子数据,可用于后续的行业对标分析。
---
## 第四部分:财务数据加载
### 4.1 财务指标数据
第九至十一个代码单元负责加载各类财务数据,为模型增加基本面因子。
**财务指标数据fina_indicator**:从`fina_indicator.h5`文件读取,包含`undist_profit_ps`(每股未分配利润)、`ocfps`(每股经营现金流)、`bps`(每股净资产)、`roa`(资产收益率)和`roe`(净资产收益率)。这些指标反映企业的盈利能力和财务健康状况。
**现金流数据cashflow**:从`cashflow.h5`文件读取,包含`n_cashflow_act`(经营活动净现金流),用于评估企业现金流的真实状况。
**资产负债表数据balancesheet**:从`balancesheet.h5`文件读取,包含`money_cap`(货币资金)和`total_liab`(总负债),用于计算企业的偿债能力。
这些财务数据通过`ann_date`(公告日期)与交易数据进行时间匹配,确保使用最新披露的财务信息。需要注意的是,财务数据存在滞后性,通常季报在季度结束后一段时间才披露,因此在匹配时使用向后合并策略。
---
## 第五部分:数据过滤与清洗
### 5.1 数据过滤规则
第十一个代码单元定义了`filter_data`函数,应用一系列过滤规则清理数据:
**排除ST股票**`df[~df['is_st']]`过滤掉所有ST特别处理股票这类股票存在较大风险。
**排除北京股票**`df[~df['ts_code'].str.endswith('BJ')]`排除北京交易所的股票。
**排除创业板和科创板**`df[~df['ts_code'].str.startswith('30')]`排除创业板股票,`df[~df['ts_code'].str.startswith('68')]`排除科创板股票。这一设计可能是因为这两个板块的股票特性与主板有较大差异。
**排除ST类代码**`df[~df['ts_code'].str.startswith('8')]`排除了以8开头的股票代码通常这类代码也属于风险警示类。
**时间范围过滤**`df[df['trade_date'] >= '2019-01-01']]`只保留2019年及以后的数据确保数据质量和一致性。
**删除冗余列**:如果存在`in_date`列则删除,避免与交易日期混淆。
过滤后的数据量从9436343条减少到5087384条约减少46%。
---
## 第六部分:因子计算
### 6.1 财务因子计算
第十二个代码单元是因子计算的核心部分,首先计算一系列财务相关因子:
**现金流因子cashflow_to_ev_factor**:将现金流数据与企业价值进行对比,计算企业现金流的相对估值水平。
**市净率因子book_to_price_ratio**:通过`caculate_book_to_price_ratio`函数计算,每股净资产与股价的比值,是价值投资的经典指标。
### 6.2 基础技术因子
继续计算各类技术因子:
**换手率均值turnover_rate_mean_5**5日平均换手率反映股票的交易活跃程度。
**收益率方差variance_20**20日收益率方差衡量短期波动性。
**BBI比率因子bbi_ratio_factor**BBI是多空指数的简称通过不同周期均线的组合判断多空趋势。
**日偏离度daily_deviation**:当日涨跌幅与市场平均涨跌幅的差异,反映个股的相对强弱。
**日行业偏离度daily_industry_deviation**:相对于所在行业平均涨跌幅的偏离,反映行业内相对表现。
### 6.3 滚动因子与简单因子
调用`get_rolling_factor``get_simple_factor`函数批量计算大量因子。这两个函数定义在`main/factor/factor.py`中,是项目最核心的因子计算模块。
**get_rolling_factor函数**主要计算的因子包括:
**资金流因子组**
- `lg_elg_net_buy_vol`:超大单加大单的净买入量
- `flow_lg_elg_intensity`:主力资金流强度,等于净买入量除以成交量
- `sm_net_buy_vol`:散户净买入量
- `flow_divergence_diff`:散户与主力背离度,主力净买入减去散户净买入
- `flow_divergence_ratio`:背离比率形式
- `lg_elg_buy_prop`:主力买入占比
- `flow_struct_buy_change`资金流结构变动1日变化率
- `flow_lg_elg_accel`:资金流加速度,二阶导数
**筹码分布因子组**
- `chip_concentration_range`95%成本价 - 5%成本价)/ 当前价格,反映筹码集中程度
- `chip_skewness`:(加权平均成本 - 50%成本价)/ 50%成本价,反映成本分布偏斜方向
- `floating_chip_proxy`:浮筹比例,结合获利盘和价格位置
- `cost_support_15pct_change`15%成本线变化率,反映支撑位变动
- `cat_winner_price_zone`:获利盘价格区域分类
**资金与筹码结合因子**
- `flow_chip_consistency`:资金流与筹码结构一致性
- `profit_taking_vs_absorb`:获利了结压力与承接盘强度
**收益率分布因子**
- `upside_vol``downside_vol`:上涨和下跌方向的标准差
- `vol_ratio`:上下波动率比值
- `return_skew``return_kurtosis`:收益率的偏度和峰度
**成交量因子**
- `volume_change_rate`:成交量变化率
- `cat_volume_breakout`:成交量突破信号
- `turnover_deviation`:换手率偏离度
- `cat_turnover_spike`:换手率激增信号
- `avg_volume_ratio``cat_volume_ratio_breakout`:量比相关因子
**talib技术指标**
- `atr_14``atr_6`14日和6日平均真实波幅
- `obv``maobv_6`能量潮及其6日均线
- `rsi_3`3日RSI
**收益率因子**
- `return_5``return_20`5日和20日收益率
- `std_return_5``std_return_90``std_return_90_2`:不同周期的收益率标准差
**EMA与动量因子**
- 各周期EMA5、13、20、60日
- `act_factor1``act_factor4`基于EMA的动量因子
**协方差因子**
- `cov`:高价与成交量的滚动协方差
- `delta_cov`:协方差差分
- `alpha_22_improved`改进的alpha因子
**其他因子**
- `alpha_003`:收盘价与开盘价的相对位置
- `alpha_007`收盘价与成交量的5日滚动相关性
- `alpha_013`5日与20日累计收益差值
- `vol_break`价格突破85%成本线且量比大于2的信号
- `weight_roc5`加权成本5日变化率
- `price_cost_divergence`:价格与成本相关性
- `smallcap_concentration`:小盘股筹码集中度
- `cost_stability`:筹码稳定性指数
- `liquidity_risk`:筹码流动性风险
**get_simple_factor函数**在滚动因子的基础上进一步计算衍生因子:
- `momentum_factor`动量因子等于成交量变化率加0.5倍换手率偏离度
- `resonance_factor`:共振因子,量比乘以涨跌幅
- `log_close`:收盘价对数
- `cat_vol_spike`:成交量激增分类变量
- `up``down`:上下影线比例
- `obv_maobv_6`OBV与6日均线差值
- `std_return_5_over_std_return_90`:短期与长期波动率比值
- `act_factor5``act_factor6`:综合动量因子
- `active_buy_volume_*`:各档位主动买入占比
- `ctrl_strength`:控制强度,反映成本区间与历史区间的比值
- `low_cost_dev``asymmetry``lock_factor`:各类筹码特征因子
- `cat_golden_resonance`:黄金共振信号
### 6.4 资金流因子
第十三个代码单元继续计算`money_factor.py`中定义的各类资金流相关因子。这些因子主要分析大单资金的流动特征和动量属性。
**lg_flow_mom_corr**大单资金流与20至60日滚动窗口的动量相关性衡量资金流趋势的持续性。
**lg_flow_accel**:大单资金流加速度,捕捉资金流入流出的加速变化。
**profit_pressure**:盈利压力因子,结合价格位置和资金流向。
**underwater_resistance**:水下阻力因子,分析亏损区域的支持强度。
**cost_conc_std_20**20日成本集中度标准差反映筹码分布的稳定程度。
**profit_decay_20**20日盈利衰减因子衡量获利盘随时间的减少程度。
**vol_amp_loss_20**20日波动幅度损失因子。
**vol_drop_profit_cnt_5**5日内量价齐升计数。
**lg_flow_vol_interact_20**20日大单资金流与成交量交互因子。
**cost_break_confirm_cnt_5**5日内成本突破确认计数。
**atr_norm_channel_pos_14**14日ATR标准化通道位置。
**turnover_diff_skew_20**20日换手率差值偏度。
**lg_sm_flow_diverge_20**20日大小单资金流背离。
**pullback_strong_20_20**20日回撤强度因子。
**vol_wgt_hist_pos_20**20日成交量加权历史位置。
**vol_adj_roc_20**20日成交量调整变动率。
### 6.5 截面排名因子
第十三个代码单元的后半部分定义了一系列`cs_rank_*`开头的截面排名因子,这些因子都是通过截面排名方式计算得出,可以消除不同股票间的基数差异。
**cs_rank_net_lg_flow_val**:大单净流入值截面排名
**cs_rank_flow_divergence**:资金流背离截面排名
**cs_rank_industry_adj_lg_flow**:行业调整后大单资金流截面排名
**cs_rank_elg_buy_ratio**:超大单买入占比截面排名
**cs_rank_rel_profit_margin**:相对盈利-margin截面排名
**cs_rank_cost_breadth**:成本宽度截面排名
**cs_rank_dist_to_upper_cost**:到上方成本距离截面排名
**cs_rank_winner_rate**:获利盘比例截面排名
**cs_rank_intraday_range**:日内振幅截面排名
**cs_rank_close_pos_in_range**:收盘价在成本区间位置截面排名
**cs_rank_opening_gap**:跳空缺口截面排名(需要前收盘价)
**cs_rank_pos_in_hist_range**:历史区间位置截面排名
**cs_rank_vol_x_profit_margin**成交量与盈利margin交互截面排名
**cs_rank_lg_flow_price_concordance**:大单资金流与价格一致性截面排名
**cs_rank_turnover_per_winner**:单位获利盘换手率截面排名
**cs_rank_ind_cap_neutral_pe**:行业市值中性市盈率截面排名
**cs_rank_volume_ratio**:量比截面排名
**cs_rank_elg_buy_sell_sm_ratio**:超大单买卖与小单比值截面排名
**cs_rank_cost_dist_vol_ratio**:成本分布成交量比值截面排名
**cs_rank_size**:市值规模截面排名
经过上述所有因子计算数据框最终包含181个字段内存占用约6.5GB。
---
## 第七部分:数据预处理函数定义
### 7.1 特征漂移检测
第十四个代码单元定义了`remove_shifted_features`函数,用于检测训练集和测试集之间是否存在特征分布漂移。漂移检测使用两种统计方法:
**KS检验Kolmogorov-Smirnov检验**比较两个分布是否来自同一分布通过p值判断p值小于阈值默认0.05)则认为存在显著差异。
**Wasserstein距离**衡量两个分布之间的Earth Mover距离距离大于阈值默认0.1)则认为漂移严重。
如果特征同时满足两个条件,则认为该特征存在漂移并将其移除。这一步骤对于确保模型在测试集上的泛化能力非常重要。
### 7.2 标准化处理函数
第十五个代码单元定义了三个核心的标准化处理函数:
**cs_mad_filter截面MAD去极值函数**
MADMedian Absolute Deviation中位绝对偏差是一种稳健的极值检测方法。具体步骤包括按日期分组计算每列的中位数然后计算每个值与中位数的绝对偏差再次取中位数得到MAD最后将超出`[median - k * MAD, median + k * MAD]`范围的值截断到边界。默认k=3scale_factor=1.4826使得MAD约等于正态分布的标准差。
**cs_neutralize_market_cap_numpy市值中性化函数**
对每个交易日的每只股票使用OLS回归将因子值对市值取对数进行回归提取残差作为中性化后的因子值。这一步骤消除因子中与市值相关的系统性偏差使因子更能反映股票的真实特性而非市值效应。
**cs_zscore_standardize截面Z-Score标准化函数**
对每个截面每日计算各因子的均值和标准差然后进行Z-Score变换`Z = (value - mean) / (std + epsilon)`。这使得不同量纲的因子具有可比性且均值为0、标准差为1。
**fill_nan_with_daily_median截面中位数填充函数**
对每个交易日分别计算各因子的中位数,用该中位数填充该日内的缺失值。这一方法比全局中位数填充更能反映数据的时间特性。
### 7.3 其他预处理函数
第十五个代码单元还定义了其他辅助预处理函数:
**remove_outliers_label_percentile**对标签进行百分位去极值默认保留1%到99%分位之间的数据。
**calculate_risk_adjusted_target**:计算风险调整后的目标变量,结合未来收益率和未来波动性。
**calculate_score**:计算综合评分目标,考虑未来收益减去风险惩罚(最大回撤)。
**remove_highly_correlated_features**:移除高度相关的特征,避免多重共线性问题。
**cross_sectional_standardization**截面标准化使用StandardScaler进行标准化。
**neutralize_manual_revised**:手动实现的行业市值中性化函数,对每个行业分别进行回归取残差。
**mad_filter**全局MAD去极值函数简化版
**percentile_filter**:百分位去极值函数。
**iqr_filter**:四分位距标准化函数。
**quantile_filter**:滚动分位数去极值函数。
**select_top_features_by_rankic**基于RankIC选择最优特征。
**create_deviation_within_dates**:创建截面偏差特征,计算行业内各因子与均值的偏差。
---
## 第八部分:目标变量构建
### 8.1 未来收益率计算
第十六个代码单元负责构建预测目标变量。
**未来收益率**:通过`df.groupby('ts_code')['close'].shift(-days) / df['close'] - 1`计算days默认为5即预测未来5日的收益率。shift(-days)表示向后移动获取未来价格。
**涨停标签**构建分类目标变量首先识别当日涨幅超过5%的股票(`cat_up_limit`然后计算过去5日内是否存在涨停通过rolling窗口的max函数再向后shift(5)避免未来信息泄露最后fillna(0)并转换为整数类型。
**过滤异常收益率**:使用`between`方法保留1%到99%分位之间的收益率,去除极端值。
---
## 第九部分:特征列筛选
### 9.1 特征列筛选逻辑
第十七个代码单元进行特征列的最终筛选。通过一系列排除规则,从所有可用列中筛选出建模所需的特征列:
**排除非特征列**:排除`trade_date``ts_code``label`等标识列;排除包含`future``label``score``gen``is_st``pe_ttm``circ_mv``code`等关键词的列;排除原始基础数据列`origin_columns`;排除下划线开头的临时列。
**排除特定特征**:手动排除存在问题的特征如`intraday_lg_flow_corr_20``cap_neutral_cost_metric``hurst_net_mf_vol_60`等;排除财务因子`roa``roe`
最终筛选出191个特征列用于建模。
---
## 第十部分:缺失值处理与标准化流程
### 10.1 缺失值填充
第十八个代码单元将所有特征的缺失值填充为0。这一简化处理的假设是缺失值可能代表数据不可用或极小概率事件用0填充对模型影响相对中性。
### 10.2 完整的预处理流水线
第十九个代码单元展示了完整的预处理流程执行:
**第一步:特征列确定**——合并行业数据、指数数据后确定最终特征列表。
**第二步MAD去极值**——执行第一轮截面MAD去极值处理全量特征。
**第三步:市值中性化**——对第一轮去极值后的特征进行截面市值中性化。
**第四步第二轮MAD去极值**——对中性化后的特征再次进行去极值处理。
**第五步Z-Score标准化**——最后对所有特征进行截面Z-Score标准化。
整个流水线确保了最终进入模型的特征具有:稳定的分布(去极值)、去除市值偏差(中性化)、可比性(标准化)。
---
## 总结
`Classify2_load_model.ipynb`实现了一个完整、规范的量化因子工程流程。整个流程涵盖了数据加载、多源数据合并、因子计算、特征筛选、预处理标准化等关键步骤。
**数据层面**整合了日线行情、资金流、筹码分布、财务指标、行业分类、指数数据等多维度数据源形成了超过180个特征的基础特征池。
**因子层面**:因子体系覆盖了资金流因子、筹码分布因子、技术指标因子、动量因子、波动率因子、截面排名因子等多个类别,形成了多角度、全方位的特征体系。
**预处理层面**建立了完整的预处理流水线包括MAD去极值、市值中性化、Z-Score标准化等标准化步骤确保特征的质量和稳定性。
整个notebook的设计体现了量化投资特征工程的最佳实践为后续的机器学习模型训练提供了高质量的数据基础。

View File

@@ -20,6 +20,7 @@ Example:
""" """
from src.data.api_wrappers.api_daily import get_daily, sync_daily, preview_daily_sync, DailySync from src.data.api_wrappers.api_daily import get_daily, sync_daily, preview_daily_sync, DailySync
from src.data.api_wrappers.financial_data.api_income import get_income, sync_income, IncomeSync
from src.data.api_wrappers.api_bak_basic import get_bak_basic, sync_bak_basic from src.data.api_wrappers.api_bak_basic import get_bak_basic, sync_bak_basic
from src.data.api_wrappers.api_namechange import get_namechange, sync_namechange from src.data.api_wrappers.api_namechange import get_namechange, sync_namechange
from src.data.api_wrappers.api_stock_basic import get_stock_basic, sync_all_stocks from src.data.api_wrappers.api_stock_basic import get_stock_basic, sync_all_stocks
@@ -37,6 +38,10 @@ __all__ = [
"sync_daily", "sync_daily",
"preview_daily_sync", "preview_daily_sync",
"DailySync", "DailySync",
# Income statement
"get_income",
"sync_income",
"IncomeSync",
# Historical stock list # Historical stock list
"get_bak_basic", "get_bak_basic",
"sync_bak_basic", "sync_bak_basic",

View File

@@ -0,0 +1,587 @@
"""财务数据统一同步调度中心。
该模块作为财务数据同步的调度中心,统一管理各类型财务数据的同步流程。
支持全量同步和增量同步两种模式。
财务数据类型:
- income: 利润表 (已实现)
- balance: 资产负债表 (预留)
- cashflow: 现金流量表 (预留)
同步模式:
1. 全量同步 (force_full=True):
- 检查表是否存在,如不存在则建表+建索引
- 从默认开始日期 (20180101) 同步到当前季度
2. 增量同步 (force_full=False):
- 获取表中最新季度 (MAX(end_date))
- 计算当前季度(如果当前日期未到季末,则用前一季度)
- 如果最新季度 == 当前季度,不同步(避免消耗流量)
- 否则从最新季度+1 同步到当前季度
使用方式:
# 增量同步利润表数据(推荐)
from src.data.api_wrappers.financial_data.api_financial_sync import sync_financial
sync_financial()
# 全量同步利润表数据
from src.data.api_wrappers.financial_data.api_financial_sync import sync_financial
sync_financial(force_full=True)
# 预览同步
from src.data.api_wrappers.financial_data.api_financial_sync import preview_sync
preview = preview_sync()
"""
from typing import Optional, Dict, List
from datetime import datetime
import pandas as pd
from src.data.storage import Storage, ThreadSafeStorage
from src.data.utils import (
get_today_date,
get_quarters_in_range,
date_to_quarter,
DEFAULT_START_DATE,
)
from src.data.api_wrappers.financial_data.api_income import get_income
# =============================================================================
# 财务数据表结构定义
# =============================================================================
# 各财务数据表的表名和字段定义
FINANCIAL_TABLES = {
"income": {
"table_name": "financial_income",
"api_name": "income_vip",
"period_field": "end_date", # 用于存储最新季度的字段
"get_data_func": get_income,
},
# 预留:资产负债表
# "balance": {
# "table_name": "financial_balance",
# "api_name": "balance Sheet_vip",
# "period_field": "end_date",
# "get_data_func": get_balance,
# },
# 预留:现金流量表
# "cashflow": {
# "table_name": "financial_cashflow",
# "api_name": "cashflow_vip",
# "period_field": "end_date",
# "get_data_func": get_cashflow,
# },
}
# =============================================================================
# 财务数据同步核心类
# =============================================================================
class FinancialSync:
"""财务数据统一同步管理器。
支持全量同步和增量同步,自动检测数据状态并选择最优同步策略。
功能特性:
- 全量/增量同步自动切换
- 自动建表和索引(如不存在)
- 智能季度计算(当前季度未到季末时使用前一季度)
- 流量保护(最新季度==当前季度时不请求API
Example:
>>> sync = FinancialSync()
>>> sync.sync_all() # 增量同步所有财务数据
>>> sync.sync_all(force_full=True) # 全量同步
>>> sync.sync_income() # 只同步利润表
"""
def __init__(self):
"""初始化同步管理器"""
self.storage = Storage()
self.thread_storage = ThreadSafeStorage()
def _create_table_if_not_exists(self, table_name: str) -> None:
"""如果表不存在则创建表和索引。
Args:
table_name: 表名
"""
if self.storage.exists(table_name):
print(f"[FinancialSync] 表 {table_name} 已存在,跳过建表")
return
print(f"[FinancialSync] 表 {table_name} 不存在,创建表和索引...")
# 根据表名创建不同的表结构
if table_name == "financial_income":
self.storage._connection.execute(f"""
CREATE TABLE IF NOT EXISTS {table_name} (
ts_code VARCHAR(16) NOT NULL,
ann_date DATE,
f_ann_date DATE,
end_date DATE NOT NULL,
report_type INTEGER,
comp_type INTEGER,
basic_eps DOUBLE,
diluted_eps DOUBLE,
PRIMARY KEY (ts_code, end_date)
)
""")
# 创建索引
self.storage._connection.execute(f"""
CREATE INDEX IF NOT EXISTS idx_financial_ann
ON {table_name}(ts_code, ann_date)
""")
else:
# 默认表结构
self.storage._connection.execute(f"""
CREATE TABLE IF NOT EXISTS {table_name} (
ts_code VARCHAR(16) NOT NULL,
end_date DATE NOT NULL,
PRIMARY KEY (ts_code, end_date)
)
""")
print(f"[FinancialSync] 表 {table_name} 创建完成")
def _get_latest_quarter(
self, table_name: str, period_field: str = "end_date"
) -> Optional[str]:
"""获取表中最新季度。
Args:
table_name: 表名
period_field: 季度字段名
Returns:
最新季度字符串 (YYYYMMDD),如无数据返回 None
"""
try:
result = self.storage._connection.execute(f"""
SELECT MAX({period_field}) FROM {table_name}
""").fetchone()
if result and result[0]:
# 转换为字符串格式
latest = result[0]
if hasattr(latest, "strftime"):
return latest.strftime("%Y%m%d")
return str(latest)
return None
except Exception as e:
print(f"[FinancialSync] 获取最新季度失败: {e}")
return None
def _get_current_quarter(self) -> str:
"""获取当前季度(考虑是否到季末)。
如果当前日期未到季度最后一天,则返回前一季度。
这样可以避免获取尚无数据的未来季度。
Returns:
当前季度字符串 (YYYYMMDD)
"""
today = get_today_date()
current_quarter = date_to_quarter(today)
# 检查今天是否到了当前季度的最后一天
if today < current_quarter:
# 未到季末,返回前一季度
return self._get_prev_quarter(current_quarter)
return current_quarter
def _get_prev_quarter(self, quarter: str) -> str:
"""获取前一季度。
Args:
quarter: 季度字符串 (YYYYMMDD)
Returns:
前一季度字符串 (YYYYMMDD)
"""
year = int(quarter[:4])
month_day = quarter[4:]
if month_day == "0331":
# Q1 -> 去年 Q4
return f"{year - 1}1231"
elif month_day == "0630":
# Q2 -> Q1
return f"{year}0331"
elif month_day == "0930":
# Q3 -> Q2
return f"{year}0630"
else: # "1231"
# Q4 -> Q3
return f"{year}0930"
def _get_next_quarter(self, quarter: str) -> str:
"""获取下一季度。
Args:
quarter: 季度字符串 (YYYYMMDD)
Returns:
下一季度字符串 (YYYYMMDD)
"""
year = int(quarter[:4])
month_day = quarter[4:]
if month_day == "0331":
# Q1 -> Q2
return f"{year}0630"
elif month_day == "0630":
# Q2 -> Q3
return f"{year}0930"
elif month_day == "0930":
# Q3 -> Q4
return f"{year}1231"
else: # "1231"
# Q4 -> 明年 Q1
return f"{year + 1}0331"
def _check_incremental_needed(
self,
table_name: str,
period_field: str = "end_date",
) -> tuple[bool, Optional[str], Optional[str]]:
"""检查增量同步是否需要。
Args:
table_name: 表名
period_field: 季度字段名
Returns:
(是否需要同步, 起始季度, 目标季度)
- 如果不需要同步,返回 (False, None, None)
"""
# 获取表中最新季度
latest_quarter = self._get_latest_quarter(table_name, period_field)
if latest_quarter is None:
# 无本地数据,需要全量同步
print(f"[FinancialSync] 表 {table_name} 无数据,需要全量同步")
return (True, DEFAULT_START_DATE, self._get_current_quarter())
print(f"[FinancialSync] 表 {table_name} 最新季度: {latest_quarter}")
# 获取当前季度(考虑是否到季末)
current_quarter = self._get_current_quarter()
print(f"[FinancialSync] 当前季度: {current_quarter}")
# 比较:如果最新季度 >= 当前季度,不需要同步
if latest_quarter >= current_quarter:
print(
f"[FinancialSync] 最新季度 {latest_quarter} >= 当前季度 {current_quarter},跳过增量同步"
)
return (False, None, None)
# 需要增量同步:从最新季度+1 到 当前季度
start_quarter = self._get_next_quarter(latest_quarter)
print(f"[FinancialSync] 增量同步: {start_quarter} -> {current_quarter}")
return (True, start_quarter, current_quarter)
def _sync_single_table(
self,
table_config: Dict,
start_quarter: str,
end_quarter: str,
) -> int:
"""同步单个财务数据表。
Args:
table_config: 表配置字典
start_quarter: 起始季度
end_quarter: 目标季度
Returns:
同步的记录数
"""
table_name = table_config["table_name"]
get_data_func = table_config["get_data_func"]
# 获取需要同步的季度列表
quarters = get_quarters_in_range(start_quarter, end_quarter)
print(f"[FinancialSync] 计划同步 {len(quarters)} 个季度: {quarters}")
total_records = 0
# 对每个季度调用 API 获取数据
for period in quarters:
try:
df = get_data_func(period)
if df.empty:
print(f"[WARN] 季度 {period} 无数据")
continue
# 只保留合并报表 (report_type='1',注意是字符串)
if "report_type" in df.columns:
df = df[df["report_type"] == "1"]
if not df.empty:
self.thread_storage.queue_save(table_name, df)
print(f"[FinancialSync] 季度 {period} -> {len(df)} 条记录")
total_records += len(df)
except Exception as e:
print(f"[ERROR] 获取季度 {period} 数据失败: {e}")
# 刷新缓存到数据库
self.thread_storage.flush()
return total_records
def sync_income(
self,
force_full: bool = False,
) -> Dict:
"""同步利润表数据。
Args:
force_full: 若为 True强制全量同步
Returns:
同步结果字典
"""
table_config = FINANCIAL_TABLES["income"]
table_name = table_config["table_name"]
period_field = table_config["period_field"]
print("\n" + "=" * 60)
print(f"[FinancialSync] 开始同步利润表 (force_full={force_full})")
print("=" * 60)
# 1. 全量同步:建表
if force_full:
self._create_table_if_not_exists(table_name)
start_quarter = DEFAULT_START_DATE
end_quarter = self._get_current_quarter()
else:
# 2. 增量同步:检查是否需要
needed, start_quarter, end_quarter = self._check_incremental_needed(
table_name, period_field
)
if not needed:
return {
"status": "skipped",
"message": "数据已是最新",
"table": table_name,
}
# 检查表是否存在,不存在则创建
if not self.storage.exists(table_name):
self._create_table_if_not_exists(table_name)
# 3. 执行同步
print(f"[FinancialSync] 同步范围: {start_quarter} -> {end_quarter}")
total_records = self._sync_single_table(
table_config, start_quarter, end_quarter
)
result = {
"status": "success",
"table": table_name,
"start_quarter": start_quarter,
"end_quarter": end_quarter,
"records": total_records,
}
print(f"[FinancialSync] 利润表同步完成: {total_records} 条记录")
return result
def sync_all(
self,
force_full: bool = False,
) -> Dict[str, Dict]:
"""同步所有财务数据表。
Args:
force_full: 若为 True强制全量同步
Returns:
各表同步结果字典
"""
results = {}
print("\n" + "=" * 60)
print(f"[FinancialSync] 开始同步所有财务数据 (force_full={force_full})")
print("=" * 60)
# 同步各财务数据表
for data_type, table_config in FINANCIAL_TABLES.items():
try:
if data_type == "income":
result = self.sync_income(force_full=force_full)
else:
# 预留其他表的同步逻辑
print(f"[FinancialSync] {data_type} 暂未实现,跳过")
result = {"status": "not_implemented"}
results[data_type] = result
except Exception as e:
print(f"[ERROR] 同步 {data_type} 失败: {e}")
results[data_type] = {"status": "error", "error": str(e)}
# 打印汇总
print("\n" + "=" * 60)
print("[FinancialSync] 同步汇总")
print("=" * 60)
for data_type, result in results.items():
status = result.get("status", "unknown")
records = result.get("records", 0)
print(f" {data_type}: {status} ({records} records)")
print("=" * 60)
return results
# =============================================================================
# 便捷函数
# =============================================================================
def sync_financial(
data_type: str = "income",
force_full: bool = False,
) -> Dict:
"""同步财务数据(便捷函数)。
Args:
data_type: 财务数据类型 ('income', 'balance', 'cashflow')
force_full: 若为 True强制全量同步
Returns:
同步结果字典
Example:
>>> # 增量同步利润表
>>> sync_financial()
>>> # 全量同步
>>> sync_financial(force_full=True)
"""
syncer = FinancialSync()
if data_type == "income":
return syncer.sync_income(force_full=force_full)
else:
raise ValueError(f"不支持的财务数据类型: {data_type}")
def sync_all_financial(force_full: bool = False) -> Dict[str, Dict]:
"""同步所有财务数据(便捷函数)。
Args:
force_full: 若为 True强制全量同步
Returns:
各表同步结果字典
Example:
>>> # 增量同步所有财务数据
>>> sync_all_financial()
>>> # 全量同步
>>> sync_all_financial(force_full=True)
"""
syncer = FinancialSync()
return syncer.sync_all(force_full=force_full)
def preview_sync() -> Dict:
"""预览同步信息(不实际同步)。
Returns:
预览信息字典:
{
'income': {
'sync_needed': bool,
'latest_quarter': str,
'current_quarter': str,
'start_quarter': str,
'end_quarter': str,
},
...
}
"""
syncer = FinancialSync()
preview = {}
for data_type, table_config in FINANCIAL_TABLES.items():
if data_type != "income":
continue
table_name = table_config["table_name"]
period_field = table_config["period_field"]
# 获取最新季度
latest_quarter = syncer._get_latest_quarter(table_name, period_field)
current_quarter = syncer._get_current_quarter()
# 检查是否需要同步
needed, start_quarter, end_quarter = syncer._check_incremental_needed(
table_name, period_field
)
preview[data_type] = {
"sync_needed": needed,
"latest_quarter": latest_quarter,
"current_quarter": current_quarter,
"start_quarter": start_quarter,
"end_quarter": end_quarter,
}
return preview
# =============================================================================
# 主程序入口
# =============================================================================
if __name__ == "__main__":
import sys
print("=" * 60)
print("财务数据同步模块")
print("=" * 60)
print("\n使用方式:")
print(" # 预览同步信息")
print(
" from src.data.api_wrappers.financial_data.api_financial_sync import preview_sync"
)
print(" preview = preview_sync()")
print("")
print(" # 增量同步(推荐)")
print(
" from src.data.api_wrappers.financial_data.api_financial_sync import sync_financial"
)
print(" sync_financial()")
print("")
print(" # 全量同步")
print(" sync_financial(force_full=True)")
print("")
print(" # 同步所有财务数据")
print(
" from src.data.api_wrappers.financial_data.api_financial_sync import sync_all_financial"
)
print(" sync_all_financial()")
print("=" * 60)
# 默认执行增量同步
if len(sys.argv) > 1 and sys.argv[1] == "--full":
print("\n[Main] 执行全量同步...")
result = sync_all_financial(force_full=True)
else:
print("\n[Main] 执行增量同步...")
result = sync_financial()
print("\n[Main] 执行完成!")
print(f"结果: {result}")

View File

@@ -0,0 +1,139 @@
"""利润表数据接口 (VIP 版本)
使用 Tushare VIP 接口 (income_vip) 获取利润表数据。
按季度同步,一次请求获取一个季度的全部上市公司数据。
接口说明:
- income_vip: 获取某一季度全部上市公司利润表数据
- 需要 5000 积分才能调用
- period 参数为报告期(季度最后一天,如 20231231
"""
import pandas as pd
from typing import Optional, List
from tqdm import tqdm
from src.data.client import TushareClient
from src.data.storage import ThreadSafeStorage
from src.data.utils import get_today_date, get_quarters_in_range
def get_income(
period: str,
fields: Optional[str] = None,
) -> pd.DataFrame:
"""获取利润表数据 (VIP 接口)
从 Tushare 获取指定季度的全部上市公司利润表数据。
Args:
period: 报告期,季度最后一天日期 (如 '20231231', '20230930')
- 0331: 一季报
- 0630: 半年报
- 0930: 三季报
- 1231: 年报
fields: 指定返回字段,默认返回全部字段
Returns:
pd.DataFrame 包含利润表数据:
- ts_code: 股票代码
- ann_date: 公告日期
- end_date: 报告期
- basic_eps: 基本每股收益
- report_type: 报告类型 (1=合并报表)
Example:
>>> data = get_income('20231231')
>>> print(data[['ts_code', 'ann_date', 'basic_eps']].head())
"""
client = TushareClient()
# 默认字段返回全部字段利润表有100+字段)
if fields is None:
fields = "ts_code,ann_date,f_ann_date,end_date,report_type,comp_type,end_type,basic_eps,diluted_eps,total_revenue,revenue,int_income,prem_earned,comm_income,n_commis_income,n_oth_income,n_oth_b_income,prem_income,out_prem,une_prem_reser,reins_income,n_sec_tb_income,n_sec_uw_income,n_asset_mg_income,oth_b_income,fv_value_chg_gain,invest_income,ass_invest_income,forex_gain,total_cogs,oper_cost,int_exp,comm_exp,biz_tax_surchg,sell_exp,admin_exp,fin_exp,assets_impair_loss,prem_refund,compens_payout,reser_insur_liab,div_payt,reins_exp,oper_exp,compens_payout_refu,insur_reser_refu,reins_cost_refund,other_bus_cost,operate_profit,non_oper_income,non_oper_exp,nca_disploss,total_profit,income_tax,n_income,n_income_attr_p,minority_gain,oth_compr_income,t_compr_income,compr_inc_attr_p,compr_inc_attr_m_s,ebit,ebitda,insurance_exp,undist_profit,distable_profit,rd_exp,fin_exp_int_exp,fin_exp_int_inc,transfer_surplus_rese,transfer_housing_imprest,transfer_oth,adj_lossgain,withdra_legal_surplus,withdra_legal_pubfund,withdra_biz_devfund,withdra_rese_fund,withdra_oth_ersu,workers_welfare,distr_profit_shrhder,prfshare_payable_dvd,comshare_payable_dvd,capit_comstock_div,net_after_nr_lp_correct,credit_impa_loss,net_expo_hedging_benefits,oth_impair_loss_assets,total_opcost,amodcost_fin_assets,oth_income,asset_disp_income,continued_net_profit,end_net_profit,update_flag"
params = {"fields": fields, "period": period}
return client.query("income_vip", **params)
# =============================================================================
# IncomeSync - 利润表数据批量同步类
# =============================================================================
class IncomeSync:
"""利润表数据批量同步管理器 (VIP 版本)
功能特性:
- 按季度同步,每次请求获取该季度全部上市公司数据
- 使用 income_vip 接口
- 只保留合并报表report_type=1
- 使用 ThreadSafeStorage 安全写入
Example:
>>> sync = IncomeSync()
>>> sync.sync(start_date='20200101', end_date='20231231')
"""
def __init__(self):
"""初始化同步管理器"""
self.storage = ThreadSafeStorage()
self.client = TushareClient()
def sync(
self,
start_date: str,
end_date: Optional[str] = None,
) -> None:
"""同步利润表数据
Args:
start_date: 开始日期 YYYYMMDD
end_date: 结束日期 YYYYMMDD默认为今天
"""
if end_date is None:
end_date = get_today_date()
# 获取日期范围内的所有季度
quarters = get_quarters_in_range(start_date, end_date)
print(f"[IncomeSync] 计划同步 {len(quarters)} 个季度: {quarters}")
# 对每个季度调用 income_vip 获取全部股票数据
for period in tqdm(quarters, desc="Syncing income by quarter"):
try:
df = get_income(period)
if df.empty:
print(f"[WARN] 季度 {period} 无数据")
continue
# 只保留合并报表 (report_type='1',注意是字符串)
if "report_type" in df.columns:
df = df[df["report_type"] == "1"]
if not df.empty:
self.storage.queue_save("financial_income", df)
print(f"[IncomeSync] 季度 {period} -> {len(df)} 条记录")
except Exception as e:
print(f"[ERROR] 获取季度 {period} 数据失败: {e}")
# 刷新缓存到数据库
self.storage.flush()
print(f"[IncomeSync] 同步完成,共处理 {len(quarters)} 个季度")
def sync_income(
start_date: str,
end_date: Optional[str] = None,
) -> None:
"""同步利润表数据(便捷函数)
Args:
start_date: 开始日期 YYYYMMDD
end_date: 结束日期 YYYYMMDD默认为今天
Example:
>>> sync_income('20200101')
>>> sync_income('20200101', '20231231')
"""
syncer = IncomeSync()
syncer.sync(start_date, end_date)

View File

@@ -0,0 +1,145 @@
利润表
接口income可以通过数据工具调试和查看数据。
描述:获取上市公司财务利润表数据
积分用户需要至少2000积分才可以调取具体请参阅积分获取办法
提示当前接口只能按单只股票获取其历史数据如果需要获取某一季度全部上市公司数据请使用income_vip接口参数一致需积攒5000积分。
输入参数
名称 类型 必选 描述
ts_code str Y 股票代码
ann_date str N 公告日期YYYYMMDD格式下同
f_ann_date str N 实际公告日期
start_date str N 公告日开始日期
end_date str N 公告日结束日期
period str N 报告期(每个季度最后一天的日期比如20171231表示年报20170630半年报20170930三季报)
report_type str N 报告类型,参考文档最下方说明
comp_type str N 公司类型1一般工商业2银行3保险4证券
输出参数
名称 类型 默认显示 描述
ts_code str Y TS代码
ann_date str Y 公告日期
f_ann_date str Y 实际公告日期
end_date str Y 报告期
report_type str Y 报告类型 见底部表
comp_type str Y 公司类型(1一般工商业2银行3保险4证券)
end_type str Y 报告期类型
basic_eps float Y 基本每股收益
diluted_eps float Y 稀释每股收益
total_revenue float Y 营业总收入
revenue float Y 营业收入
int_income float Y 利息收入
prem_earned float Y 已赚保费
comm_income float Y 手续费及佣金收入
n_commis_income float Y 手续费及佣金净收入
n_oth_income float Y 其他经营净收益
n_oth_b_income float Y 加:其他业务净收益
prem_income float Y 保险业务收入
out_prem float Y 减:分出保费
une_prem_reser float Y 提取未到期责任准备金
reins_income float Y 其中:分保费收入
n_sec_tb_income float Y 代理买卖证券业务净收入
n_sec_uw_income float Y 证券承销业务净收入
n_asset_mg_income float Y 受托客户资产管理业务净收入
oth_b_income float Y 其他业务收入
fv_value_chg_gain float Y 加:公允价值变动净收益
invest_income float Y 加:投资净收益
ass_invest_income float Y 其中:对联营企业和合营企业的投资收益
forex_gain float Y 加:汇兑净收益
total_cogs float Y 营业总成本
oper_cost float Y 减:营业成本
int_exp float Y 减:利息支出
comm_exp float Y 减:手续费及佣金支出
biz_tax_surchg float Y 减:营业税金及附加
sell_exp float Y 减:销售费用
admin_exp float Y 减:管理费用
fin_exp float Y 减:财务费用
assets_impair_loss float Y 减:资产减值损失
prem_refund float Y 退保金
compens_payout float Y 赔付总支出
reser_insur_liab float Y 提取保险责任准备金
div_payt float Y 保户红利支出
reins_exp float Y 分保费用
oper_exp float Y 营业支出
compens_payout_refu float Y 减:摊回赔付支出
insur_reser_refu float Y 减:摊回保险责任准备金
reins_cost_refund float Y 减:摊回分保费用
other_bus_cost float Y 其他业务成本
operate_profit float Y 营业利润
non_oper_income float Y 加:营业外收入
non_oper_exp float Y 减:营业外支出
nca_disploss float Y 其中:减:非流动资产处置净损失
total_profit float Y 利润总额
income_tax float Y 所得税费用
n_income float Y 净利润(含少数股东损益)
n_income_attr_p float Y 净利润(不含少数股东损益)
minority_gain float Y 少数股东损益
oth_compr_income float Y 其他综合收益
t_compr_income float Y 综合收益总额
compr_inc_attr_p float Y 归属于母公司(或股东)的综合收益总额
compr_inc_attr_m_s float Y 归属于少数股东的综合收益总额
ebit float Y 息税前利润
ebitda float Y 息税折旧摊销前利润
insurance_exp float Y 保险业务支出
undist_profit float Y 年初未分配利润
distable_profit float Y 可分配利润
rd_exp float Y 研发费用
fin_exp_int_exp float Y 财务费用:利息费用
fin_exp_int_inc float Y 财务费用:利息收入
transfer_surplus_rese float Y 盈余公积转入
transfer_housing_imprest float Y 住房周转金转入
transfer_oth float Y 其他转入
adj_lossgain float Y 调整以前年度损益
withdra_legal_surplus float Y 提取法定盈余公积
withdra_legal_pubfund float Y 提取法定公益金
withdra_biz_devfund float Y 提取企业发展基金
withdra_rese_fund float Y 提取储备基金
withdra_oth_ersu float Y 提取任意盈余公积金
workers_welfare float Y 职工奖金福利
distr_profit_shrhder float Y 可供股东分配的利润
prfshare_payable_dvd float Y 应付优先股股利
comshare_payable_dvd float Y 应付普通股股利
capit_comstock_div float Y 转作股本的普通股股利
net_after_nr_lp_correct float N 扣除非经常性损益后的净利润(更正前)
credit_impa_loss float N 信用减值损失
net_expo_hedging_benefits float N 净敞口套期收益
oth_impair_loss_assets float N 其他资产减值损失
total_opcost float N 营业总成本(二)
amodcost_fin_assets float N 以摊余成本计量的金融资产终止确认收益
oth_income float N 其他收益
asset_disp_income float N 资产处置收益
continued_net_profit float N 持续经营净利润
end_net_profit float N 终止经营净利润
update_flag str Y 更新标识
接口使用说明
pro = ts.pro_api()
df = pro.income(ts_code='600000.SH', start_date='20180101', end_date='20180730', fields='ts_code,ann_date,f_ann_date,end_date,report_type,comp_type,basic_eps,diluted_eps')
获取某一季度全部股票数据
df = pro.income_vip(period='20181231',fields='ts_code,ann_date,f_ann_date,end_date,report_type,comp_type,basic_eps,diluted_eps')
数据样例
ts_code ann_date f_ann_date end_date report_type comp_type basic_eps diluted_eps \
0 600000.SH 20180428 20180428 20180331 1 2 0.46 0.46
1 600000.SH 20180428 20180428 20180331 1 2 0.46 0.46
2 600000.SH 20180428 20180428 20171231 1 2 1.84 1.84
主要报表类型说明
代码 | 类型 | 说明
---- | ----- | ---- |
1 | 合并报表 | 上市公司最新报表(默认)
2 | 单季合并 | 单一季度的合并报表
3 | 调整单季合并表 | 调整后的单季合并报表(如果有)
4 | 调整合并报表 | 本年度公布上年同期的财务报表数据,报告期为上年度
5 | 调整前合并报表 | 数据发生变更,将原数据进行保留,即调整前的原数据
6 | 母公司报表 | 该公司母公司的财务报表数据
7 | 母公司单季表 | 母公司的单季度表
8 | 母公司调整单季表 | 母公司调整后的单季表
9 | 母公司调整表 | 该公司母公司的本年度公布上年同期的财务报表数据
10 | 母公司调整前报表 | 母公司调整之前的原始财务报表数据
11 | 母公司调整前合并报表 | 母公司调整之前合并报表原数据
12 | 母公司调整前报表 | 母公司报表发生变更前保留的原数据

View File

@@ -90,6 +90,113 @@ class Storage:
CREATE INDEX IF NOT EXISTS idx_daily_date_code ON daily(trade_date, ts_code) CREATE INDEX IF NOT EXISTS idx_daily_date_code ON daily(trade_date, ts_code)
""") """)
# Create financial_income table for income statement data
# 完整的利润表字段94列全部
self._connection.execute("""
CREATE TABLE IF NOT EXISTS financial_income (
ts_code VARCHAR(16) NOT NULL,
ann_date DATE,
f_ann_date DATE,
end_date DATE NOT NULL,
report_type INTEGER,
comp_type INTEGER,
end_type VARCHAR(10),
basic_eps DOUBLE,
diluted_eps DOUBLE,
total_revenue DOUBLE,
revenue DOUBLE,
int_income DOUBLE,
prem_earned DOUBLE,
comm_income DOUBLE,
n_commis_income DOUBLE,
n_oth_income DOUBLE,
n_oth_b_income DOUBLE,
prem_income DOUBLE,
out_prem DOUBLE,
une_prem_reser DOUBLE,
reins_income DOUBLE,
n_sec_tb_income DOUBLE,
n_sec_uw_income DOUBLE,
n_asset_mg_income DOUBLE,
oth_b_income DOUBLE,
fv_value_chg_gain DOUBLE,
invest_income DOUBLE,
ass_invest_income DOUBLE,
forex_gain DOUBLE,
total_cogs DOUBLE,
oper_cost DOUBLE,
int_exp DOUBLE,
comm_exp DOUBLE,
biz_tax_surchg DOUBLE,
sell_exp DOUBLE,
admin_exp DOUBLE,
fin_exp DOUBLE,
assets_impair_loss DOUBLE,
prem_refund DOUBLE,
compens_payout DOUBLE,
reser_insur_liab DOUBLE,
div_payt DOUBLE,
reins_exp DOUBLE,
oper_exp DOUBLE,
compens_payout_refu DOUBLE,
insur_reser_refu DOUBLE,
reins_cost_refund DOUBLE,
other_bus_cost DOUBLE,
operate_profit DOUBLE,
non_oper_income DOUBLE,
non_oper_exp DOUBLE,
nca_disploss DOUBLE,
total_profit DOUBLE,
income_tax DOUBLE,
n_income DOUBLE,
n_income_attr_p DOUBLE,
minority_gain DOUBLE,
oth_compr_income DOUBLE,
t_compr_income DOUBLE,
compr_inc_attr_p DOUBLE,
compr_inc_attr_m_s DOUBLE,
ebit DOUBLE,
ebitda DOUBLE,
insurance_exp DOUBLE,
undist_profit DOUBLE,
distable_profit DOUBLE,
rd_exp DOUBLE,
fin_exp_int_exp DOUBLE,
fin_exp_int_inc DOUBLE,
transfer_surplus_rese DOUBLE,
transfer_housing_imprest DOUBLE,
transfer_oth DOUBLE,
adj_lossgain DOUBLE,
withdra_legal_surplus DOUBLE,
withdra_legal_pubfund DOUBLE,
withdra_biz_devfund DOUBLE,
withdra_rese_fund DOUBLE,
withdra_oth_ersu DOUBLE,
workers_welfare DOUBLE,
distr_profit_shrhder DOUBLE,
prfshare_payable_dvd DOUBLE,
comshare_payable_dvd DOUBLE,
capit_comstock_div DOUBLE,
net_after_nr_lp_correct DOUBLE,
credit_impa_loss DOUBLE,
net_expo_hedging_benefits DOUBLE,
oth_impair_loss_assets DOUBLE,
total_opcost DOUBLE,
amodcost_fin_assets DOUBLE,
oth_income DOUBLE,
asset_disp_income DOUBLE,
continued_net_profit DOUBLE,
end_net_profit DOUBLE,
update_flag VARCHAR(1),
PRIMARY KEY (ts_code, end_date)
)
""")
# Create index for financial_income
self._connection.execute("""
CREATE INDEX IF NOT EXISTS idx_financial_ann ON financial_income(ts_code, ann_date)
""")
def save(self, name: str, data: pd.DataFrame, mode: str = "append") -> dict: def save(self, name: str, data: pd.DataFrame, mode: str = "append") -> dict:
"""Save data to DuckDB. """Save data to DuckDB.
@@ -104,13 +211,35 @@ class Storage:
if data.empty: if data.empty:
return {"status": "skipped", "rows": 0} return {"status": "skipped", "rows": 0}
# Ensure date column is proper type # 确保日期列是正确的类型 (YYYYMMDD -> date)
# trade_date: 日线数据日期
if "trade_date" in data.columns: if "trade_date" in data.columns:
data = data.copy() data = data.copy()
data["trade_date"] = pd.to_datetime( data["trade_date"] = pd.to_datetime(
data["trade_date"], format="%Y%m%d" data["trade_date"], format="%Y%m%d"
).dt.date ).dt.date
# ann_date: 公告日期
if "ann_date" in data.columns:
data = data.copy()
data["ann_date"] = pd.to_datetime(
data["ann_date"], format="%Y%m%d", errors="coerce"
).dt.date
# f_ann_date: 最终公告日期
if "f_ann_date" in data.columns:
data = data.copy()
data["f_ann_date"] = pd.to_datetime(
data["f_ann_date"], format="%Y%m%d", errors="coerce"
).dt.date
# end_date: 报告期/期末日期
if "end_date" in data.columns:
data = data.copy()
data["end_date"] = pd.to_datetime(
data["end_date"], format="%Y%m%d", errors="coerce"
).dt.date
# Register DataFrame as temporary view # Register DataFrame as temporary view
self._connection.register("temp_data", data) self._connection.register("temp_data", data)

View File

@@ -4,7 +4,7 @@
""" """
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import Optional from typing import Optional, List
# 默认全量同步开始日期 # 默认全量同步开始日期
@@ -73,3 +73,74 @@ def format_date(dt: datetime) -> str:
YYYYMMDD 格式的日期字符串 YYYYMMDD 格式的日期字符串
""" """
return dt.strftime("%Y%m%d") return dt.strftime("%Y%m%d")
def is_quarter_end(date_str: str) -> bool:
"""判断是否为季度最后一天。
Args:
date_str: YYYYMMDD 格式的日期
Returns:
是否为季度最后一天
"""
month_day = date_str[4:]
return month_day in ("0331", "0630", "0930", "1231")
def date_to_quarter(date_str: str) -> str:
"""将日期转换为对应季度的最后一天。
Args:
date_str: YYYYMMDD 格式的日期
Returns:
季度最后一天,格式为 YYYYMMDD
例如: 20230115 -> 20230331
"""
year = date_str[:4]
month = int(date_str[4:6])
if month <= 3:
return year + "0331"
elif month <= 6:
return year + "0630"
elif month <= 9:
return year + "0930"
else:
return year + "1231"
def get_quarters_in_range(start_date: str, end_date: str) -> List[str]:
"""获取日期范围内的所有季度列表。
Args:
start_date: 开始日期 YYYYMMDD
end_date: 结束日期 YYYYMMDD
Returns:
季度列表,格式为 YYYYMMDD按时间倒序排列
例如: ['20231231', '20230930', '20230630', '20230331']
"""
quarters = []
# 将开始日期和结束日期都转换为季度
start_quarter = date_to_quarter(start_date)
end_quarter = date_to_quarter(end_date)
# 解析年份
start_year = int(start_date[:4])
end_year = int(end_date[:4])
# 遍历所有年份和季度
for year in range(end_year, start_year - 1, -1):
year_str = str(year)
# 季度顺序: Q4, Q3, Q2, Q1 (倒序)
for quarter in ["1231", "0930", "0630", "0331"]:
quarter_date = year_str + quarter
# 只包含在范围内的季度
if quarter_date >= start_quarter and quarter_date <= end_quarter:
quarters.append(quarter_date)
return quarters

1535
src/factors/FACTOR_GUIDE.md Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -18,6 +18,52 @@
- CompositeFactor: 组合因子 - CompositeFactor: 组合因子
- ScalarFactor: 标量运算因子 - ScalarFactor: 标量运算因子
因子分类目录:
- momentum/: 动量因子MA、收益率排名等
- financial/: 财务因子EPS、ROE等
- valuation/: 估值因子PE、PB、PS等
- technical/: 技术指标因子RSI、MACD、布林带等
- quality/: 质量因子(盈利能力、稳定性等)
- sentiment/: 情绪因子(换手率、资金流向等)
- volume/: 成交量因子OBV、成交量比率等
- volatility/: 波动率因子历史波动率、GARCH等
数据加载和执行Phase 3-4
- DataLoader: 数据加载器
- FactorEngine: 因子执行引擎
使用示例:
# 使用通用因子(参数化)
from src.factors import MovingAverageFactor, ReturnRankFactor
from src.factors import DataLoader, FactorEngine
ma5 = MovingAverageFactor(period=5) # 5日MA
ma10 = MovingAverageFactor(period=10) # 10日MA
ret5 = ReturnRankFactor(period=5) # 5日收益率排名
loader = DataLoader(data_dir="data")
engine = FactorEngine(loader)
result = engine.compute(ma5, stock_codes=["000001.SZ"], start_date="20240101", end_date="20240131")
"""
因子框架提供以下核心功能
1. 类型安全的因子定义截面因子时序因子
2. 数据泄露防护机制
3. 因子组合和运算
4. 高效的数据加载和计算引擎
基础数据类型Phase 1
- DataSpec: 数据需求规格
- FactorContext: 计算上下文
- FactorData: 数据容器
因子基类Phase 2
- BaseFactor: 抽象基类
- CrossSectionalFactor: 日期截面因子基类
- TimeSeriesFactor: 时间序列因子基类
- CompositeFactor: 组合因子
- ScalarFactor: 标量运算因子
动量因子momentum/ 动量因子momentum/
- MovingAverageFactor: 移动平均线时序因子 - MovingAverageFactor: 移动平均线时序因子
- ReturnRankFactor: 收益率排名截面因子 - ReturnRankFactor: 收益率排名截面因子

View File

@@ -72,8 +72,8 @@ class DataLoader:
if cache_key in self._cache: if cache_key in self._cache:
df = self._cache[cache_key] df = self._cache[cache_key]
else: else:
# 读取 H5 文件 # 读取 H5 文件(传入日期范围以支持过滤)
df = self._read_h5(spec.source) df = self._read_h5(spec.source, date_range=date_range)
# 列选择 - 只保留需要的列 # 列选择 - 只保留需要的列
missing_cols = set(spec.columns) - set(df.columns) missing_cols = set(spec.columns) - set(df.columns)
@@ -107,7 +107,11 @@ class DataLoader:
"""清空缓存""" """清空缓存"""
self._cache.clear() self._cache.clear()
def _read_h5(self, source: str) -> pl.DataFrame: def _read_h5(
self,
source: str,
date_range: Optional[Tuple[str, str]] = None,
) -> pl.DataFrame:
"""读取数据 - 从 DuckDB 加载为 Polars DataFrame。 """读取数据 - 从 DuckDB 加载为 Polars DataFrame。
迁移说明: 迁移说明:
@@ -117,6 +121,7 @@ class DataLoader:
Args: Args:
source: 表名(对应 DuckDB 中的表,如 "daily" source: 表名(对应 DuckDB 中的表,如 "daily"
date_range: 日期范围限制 (start_date, end_date),可选
Returns: Returns:
Polars DataFrame Polars DataFrame
@@ -125,11 +130,38 @@ class DataLoader:
Exception: 数据库查询错误 Exception: 数据库查询错误
""" """
from src.data.storage import Storage from src.data.storage import Storage
from src.data.api_wrappers.api_trade_cal import get_trading_days
from src.data.utils import get_today_date
from src.factors.financial.utils import expand_period_to_trading_days
storage = Storage() storage = Storage()
# 如果 DataLoader 有 date_range传递给 Storage 进行过滤 # 特殊处理财务数据:将报告期展开到交易日
# 实现查询下推,只加载必要数据 if source == "financial_income":
# 确定日期范围
start_date = date_range[0] if date_range else "20180101"
end_date = date_range[1] if date_range else get_today_date()
# 1. 加载原始财务数据(报告期粒度),按日期范围过滤
# 注意financial_income 使用 end_date 字段作为报告期
df = storage.load_polars(
"financial_income",
start_date=start_date,
end_date=end_date,
)
if len(df) == 0:
return pl.DataFrame()
# 2. 获取交易日历从2018年开始到当前确保有足够的历史数据用于前向填充
# 需要从数据的最小日期开始,确保能获取到足够的交易日
trade_start = "20180101" if start_date > "20180101" else start_date
trade_dates = get_trading_days(trade_start, get_today_date())
# 3. 展开到交易日(前向填充)
return expand_period_to_trading_days(df, trade_dates)
# 其他数据源保持原有逻辑
return storage.load_polars(source) return storage.load_polars(source)
def _merge_dataframes(self, dataframes: List[pl.DataFrame]) -> pl.DataFrame: def _merge_dataframes(self, dataframes: List[pl.DataFrame]) -> pl.DataFrame:

View File

@@ -4,7 +4,10 @@
因子分类: 因子分类:
- financial: 财务因子 - financial: 财务因子
- (待添加) - EPSFactor: 每股收益排名因子
已添加因子:
- EPSFactor: 每股收益排名基于basic_eps
待添加因子: 待添加因子:
- PERankFactor: 市盈率排名 - PERankFactor: 市盈率排名
@@ -12,4 +15,6 @@
- DividendFactor: 股息率因子 - DividendFactor: 股息率因子
""" """
__all__ = [] from src.factors.financial.eps_factor import EPSFactor
__all__ = ["EPSFactor"]

View File

@@ -0,0 +1,66 @@
"""EPS因子
每股收益(EPS)排名因子实现
"""
from typing import List
import polars as pl
from src.factors.base import CrossSectionalFactor
from src.factors.data_spec import DataSpec, FactorData
class EPSFactor(CrossSectionalFactor):
"""每股收益(EPS)排名因子
计算逻辑使用最新报告期的basic_eps每天对所有股票进行截面排名
Attributes:
name: 因子名称 "eps_rank"
category: 因子分类 "financial"
data_specs: 数据需求规格
Example:
>>> from src.factors import FactorEngine, DataLoader
>>> from src.factors.financial.eps_factor import EPSFactor
>>> loader = DataLoader('data')
>>> engine = FactorEngine(loader)
>>> eps_factor = EPSFactor()
>>> result = engine.compute(eps_factor, start_date='20210101', end_date='20210131')
"""
name: str = "eps_rank"
category: str = "financial"
description: str = "每股收益截面排名因子"
data_specs: List[DataSpec] = [
DataSpec(
"financial_income", ["ts_code", "trade_date", "basic_eps"], lookback_days=1
)
]
def compute(self, data: FactorData) -> pl.Series:
"""计算EPS排名
Args:
data: FactorData包含当前日期的截面数据
Returns:
EPS排名的0-1标准化值0-1之间
"""
# 获取当前日期的截面数据
cs = data.get_cross_section()
if len(cs) == 0:
return pl.Series(name=self.name, values=[])
# 提取EPS值填充缺失值为0
eps = cs["basic_eps"].fill_null(0)
# 计算排名并归一化到0-1
if len(eps) > 1 and eps.max() != eps.min():
ranks = eps.rank(method="average") / len(eps)
else:
# 数据不足或全部相同返回0.5
ranks = pl.Series(name=self.name, values=[0.5] * len(eps))
return ranks

View File

@@ -0,0 +1,82 @@
"""财务因子工具函数
提供财务数据处理的工具函数:
- expand_period_to_trading_days: 将报告期数据展开到每个交易日(前向填充)
"""
from typing import List
import polars as pl
def expand_period_to_trading_days(
financial_df: pl.DataFrame,
trade_dates: List[str],
) -> pl.DataFrame:
"""将财务数据(报告期粒度)展开到每个交易日(前向填充)
核心逻辑:对于每个交易日,找到该日期之前最新的已公告报告期数据。
例如2020年报(20201231)公告于20210428则在2021-04-28之后的每个
交易日都使用该年报数据直到2021一季报公告。
Args:
financial_df: 财务数据DataFrame包含 ts_code, ann_date, end_date, ...
trade_dates: 交易日列表YYYYMMDD格式已排序
Returns:
DataFrame包含 trade_date, ts_code 和所有财务字段
Example:
>>> financial_df = pl.DataFrame({
... 'ts_code': ['000001.SZ'],
... 'ann_date': ['20210428'],
... 'end_date': ['20210331'],
... 'basic_eps': [0.5]
... })
>>> trade_dates = ['20210428', '20210429', '20210430']
>>> result = expand_period_to_trading_days(financial_df, trade_dates)
>>> print(result)
shape: (3, 5)
┌───────────┬───────────┬────────────┬────────────┬───────────┐
│ ts_code ┆ ann_date ┆ end_date ┆ basic_eps ┆ trade_date│
│ --- ┆ --- ┆ --- ┆ --- ┆ --- │
│ str ┆ str ┆ str ┆ f64 ┆ str │
╞═══════════╪═══════════╪════════════╪════════════╪═══════════╡
│ 000001.SZ ┆ 20210428 ┆ 20210331 ┆ 0.5 ┆ 20210428 │
│ 000001.SZ ┆ 20210428 ┆ 20210331 ┆ 0.5 ┆ 20210429 │
│ 000001.SZ ┆ 20210428 ┆ 20210331 ┆ 0.5 ┆ 20210430 │
└───────────┴───────────┴────────────┴────────────┴───────────┘
"""
if len(financial_df) == 0:
return pl.DataFrame()
results = []
# 按股票分组处理
for ts_code in financial_df["ts_code"].unique():
stock_data = financial_df.filter(pl.col("ts_code") == ts_code)
# 按报告期排序end_date升序
stock_data = stock_data.sort("end_date")
rows = []
for trade_date in trade_dates:
# 找到该交易日之前最新的已公告报告期
# 条件1: end_date <= trade_date报告期不晚于交易日
# 条件2: ann_date <= trade_date已公告
applicable = stock_data.filter(
(pl.col("end_date") <= trade_date) & (pl.col("ann_date") <= trade_date)
)
if len(applicable) > 0:
# 取最新的一条end_date最大的
latest = applicable.tail(1).with_columns(
[pl.lit(trade_date).alias("trade_date")]
)
rows.append(latest)
if rows:
results.append(pl.concat(rows))
if results:
return pl.concat(results)
return pl.DataFrame()

View File

@@ -0,0 +1,20 @@
"""质量因子模块
本模块提供质量类因子:
- 盈利能力ROE、ROA、毛利率、净利率
- 盈利稳定性:盈利波动率、盈利持续性
- 财务健康度:资产负债率、流动比率等
使用示例:
>>> from src.factors.quality import ROEFactor
>>> factor = ROEFactor()
"""
# 在此处导入具体的质量因子
# from .roe import ROEFactor
# from .roa import ROAFactor
# from .profit_stability import ProfitStabilityFactor
__all__ = [
# 添加你的质量因子
]

View File

@@ -0,0 +1,20 @@
"""情绪因子模块
本模块提供市场情绪类因子:
- 换手率、换手率变化率
- 资金流向、主力净流入
- 波动率、振幅等
使用示例:
>>> from src.factors.sentiment import TurnoverFactor
>>> factor = TurnoverFactor(period=20)
"""
# 在此处导入具体的情绪因子
# from .turnover import TurnoverFactor
# from .money_flow import MoneyFlowFactor
# from .amplitude import AmplitudeFactor
__all__ = [
# 添加你的情绪因子
]

View File

@@ -0,0 +1,20 @@
"""技术指标因子模块
本模块提供技术分析类因子:
- 移动平均线(MA)、指数移动平均(EMA)
- 相对强弱指标(RSI)、MACD、KDJ
- 布林带(Bollinger Bands)等
使用示例:
>>> from src.factors.technical import RSIFactor
>>> factor = RSIFactor(period=14)
"""
# 在此处导入具体的技术指标因子
# from .rsi import RSIFactor
# from .macd import MACDFactor
# from .bollinger import BollingerFactor
__all__ = [
# 添加你的技术指标因子
]

View File

@@ -0,0 +1,18 @@
"""估值因子模块
本模块提供估值类因子:
- 市盈率(PE)、市净率(PB)、市销率(PS)等估值指标
- 估值排名、估值分位数等衍生因子
使用示例:
>>> from src.factors.valuation import PERankFactor
>>> factor = PERankFactor()
"""
# 在此处导入具体的估值因子
# from .pe_rank import PERankFactor
# from .pb_rank import PBRankFactor
__all__ = [
# 添加你的估值因子
]

View File

@@ -0,0 +1,21 @@
"""波动率因子模块
本模块提供波动率相关因子:
- 历史波动率(Historical Volatility)
- 实现波动率(Realized Volatility)
- GARCH类波动率预测
- 波动率风险指标等
使用示例:
>>> from src.factors.volatility import HistoricalVolFactor
>>> factor = HistoricalVolFactor(period=20)
"""
# 在此处导入具体的波动率因子
# from .historical_vol import HistoricalVolFactor
# from .realized_vol import RealizedVolFactor
# from .garch_vol import GARCHVolFactor
__all__ = [
# 添加你的波动率因子
]

View File

@@ -0,0 +1,20 @@
"""成交量因子模块
本模块提供成交量相关因子:
- 成交量移动平均
- 成交量比率(VR)、能量潮(OBV)
- 量价配合指标等
使用示例:
>>> from src.factors.volume import OBVFactor
>>> factor = OBVFactor()
"""
# 在此处导入具体的成交量因子
# from .obv import OBVFactor
# from .volume_ratio import VolumeRatioFactor
# from .volume_ma import VolumeMAFactor
__all__ = [
# 添加你的成交量因子
]

View File

@@ -8,6 +8,7 @@ from src.pipeline.processors.processors import (
MinMaxScaler, MinMaxScaler,
RankTransformer, RankTransformer,
Neutralizer, Neutralizer,
MADClipper,
) )
__all__ = [ __all__ = [
@@ -18,4 +19,5 @@ __all__ = [
"MinMaxScaler", "MinMaxScaler",
"RankTransformer", "RankTransformer",
"Neutralizer", "Neutralizer",
"MADClipper",
] ]

View File

@@ -3,9 +3,8 @@
提供常用的数据预处理和转换处理器。 提供常用的数据预处理和转换处理器。
""" """
from typing import List, Optional, Dict, Any from typing import List, Optional
import polars as pl import polars as pl
import numpy as np
from src.pipeline.core import BaseProcessor, PipelineStage from src.pipeline.core import BaseProcessor, PipelineStage
from src.pipeline.registry import PluginRegistry from src.pipeline.registry import PluginRegistry
@@ -227,6 +226,64 @@ class Neutralizer(BaseProcessor):
return result return result
@PluginRegistry.register_processor("mad_clipper")
class MADClipper(BaseProcessor):
"""MAD去极值处理器 - 基于每日截面的中位数绝对偏差去除极值
使用3倍MAD作为阈值比标准差方法更稳健对异常值不敏感。
阈值范围: [median - n*MAD, median + n*MAD]
"""
stage = PipelineStage.TRAIN
def __init__(
self,
columns: Optional[List[str]] = None,
n_mad: float = 3.0,
):
super().__init__(columns)
self.n_mad = n_mad
def fit(self, data: pl.DataFrame) -> "MADClipper":
cols = _get_numeric_columns(data, self.columns)
bounds = {}
for col in cols:
# 按日期分组计算每个截面的 median 和 MAD
daily_stats = data.group_by("trade_date").agg(
pl.col(col).median().alias("median"),
(pl.col(col) - pl.col(col).median()).abs().median().alias("mad"),
)
bounds[col] = daily_stats
self._fitted_params = {"bounds": bounds, "columns": cols}
self._is_fitted = True
return self
def transform(self, data: pl.DataFrame) -> pl.DataFrame:
"""使用窗口函数进行MAD去极值避免join操作提升性能"""
result = data
bounds = self._fitted_params.get("bounds", {})
for col in bounds.keys():
if col not in result.columns:
continue
# 使用窗口函数直接计算每个截面的median和MAD避免join
# 1. 计算每个日期截面的median
median = pl.col(col).median().over("trade_date")
# 2. 计算每个日期截面的MAD
mad = (pl.col(col) - median).abs().median().over("trade_date")
# 3. 计算上下界并clip
lower = median - self.n_mad * mad
upper = median + self.n_mad * mad
result = result.with_columns(pl.col(col).clip(lower, upper).alias(col))
return result
__all__ = [ __all__ = [
"DropNAProcessor", "DropNAProcessor",
"FillNAProcessor", "FillNAProcessor",
@@ -235,4 +292,5 @@ __all__ = [
"MinMaxScaler", "MinMaxScaler",
"RankTransformer", "RankTransformer",
"Neutralizer", "Neutralizer",
"MADClipper",
] ]

View File

@@ -12,13 +12,16 @@ from src.training.pipeline import run_training
if __name__ == "__main__": if __name__ == "__main__":
# 运行完整训练流程 # 运行完整训练流程
# 训练集20180101 - 20230101 # 训练集20190101 - 20231231
# 测试20230101 - 20240101 # 验证20240102 - 20240531 (与训练集间隔1天避免数据泄露)
# 测试集20240602 - 20241231 (与验证集间隔1天避免数据泄露)
result = run_training( result = run_training(
train_start="20190101", train_start="20190101",
train_end="20250101", train_end="20231231",
test_start="20250101", val_start="20240102",
test_end="20260101", val_end="20240531",
test_start="20240602",
test_end="20241231",
top_n=5, top_n=5,
output_path="output/top_stocks.tsv", output_path="output/top_stocks.tsv",
) )

View File

@@ -30,10 +30,12 @@ def prepare_data(
data_dir: str = "data", data_dir: str = "data",
train_start: str = "20180101", train_start: str = "20180101",
train_end: str = "20230101", train_end: str = "20230101",
test_start: str = "20230101", val_start: str = "20230101",
val_end: str = "20230601",
test_start: str = "20230601",
test_end: str = "20240101", test_end: str = "20240101",
) -> Tuple[pl.DataFrame, pl.DataFrame]: ) -> tuple[pl.DataFrame, pl.DataFrame, pl.DataFrame]:
"""准备训练和测试数据 """准备训练、验证和测试数据
从DuckDB加载原始日线数据计算所需因子并生成标签。 从DuckDB加载原始日线数据计算所需因子并生成标签。
@@ -41,11 +43,13 @@ def prepare_data(
data_dir: 数据目录 data_dir: 数据目录
train_start: 训练集开始日期 train_start: 训练集开始日期
train_end: 训练集结束日期 train_end: 训练集结束日期
val_start: 验证集开始日期
val_end: 验证集结束日期
test_start: 测试集开始日期 test_start: 测试集开始日期
test_end: 测试集结束日期 test_end: 测试集结束日期
Returns: Returns:
(train_data, test_data): 训练集和测试集的DataFrame (train_data, val_data, test_data): 训练集、验证集和测试集的DataFrame
""" """
from src.data.storage import Storage from src.data.storage import Storage
@@ -56,47 +60,56 @@ def prepare_data(
lookback_days = 20 # 足够计算MA10和5日收益率 lookback_days = 20 # 足够计算MA10和5日收益率
start_with_lookback = str(int(train_start) - 10000) # 往前取一年 start_with_lookback = str(int(train_start) - 10000) # 往前取一年
# 查询训练集数据 # 查询全部数据包含train、val、test然后再拆分
# 注意DuckDB 中 trade_date 是 DATE 类型,需要转换 # 注意DuckDB 中 trade_date 是 DATE 类型,需要转换
start_dt = f"{start_with_lookback[:4]}-{start_with_lookback[4:6]}-{start_with_lookback[6:8]}" start_dt = f"{start_with_lookback[:4]}-{start_with_lookback[4:6]}-{start_with_lookback[6:8]}"
end_dt = f"{train_end[:4]}-{train_end[4:6]}-{train_end[6:8]}" end_dt = f"{test_end[:4]}-{test_end[4:6]}-{test_end[6:8]}"
train_query = f"""
all_query = f"""
SELECT ts_code, trade_date, close, pre_close SELECT ts_code, trade_date, close, pre_close
FROM daily FROM daily
WHERE trade_date >= '{start_dt}' AND trade_date <= '{end_dt}' WHERE trade_date >= '{start_dt}' AND trade_date <= '{end_dt}'
ORDER BY ts_code, trade_date ORDER BY ts_code, trade_date
""" """
train_raw = storage._connection.sql(train_query).pl() all_raw = storage._connection.sql(all_query).pl()
# 转换 trade_date 为字符串格式 # 转换 trade_date 为字符串格式
train_raw = train_raw.with_columns( all_raw = all_raw.with_columns(
pl.col("trade_date").dt.strftime("%Y-%m-%d").alias("trade_date")
)
# 查询测试集数据(也需要历史数据计算因子)
test_start_dt = f"{test_start[:4]}-{test_start[4:6]}-{test_start[6:8]}"
test_end_dt = f"{test_end[:4]}-{test_end[4:6]}-{test_end[6:8]}"
test_query = f"""
SELECT ts_code, trade_date, close, pre_close
FROM daily
WHERE trade_date >= '{test_start_dt}' AND trade_date <= '{test_end_dt}'
ORDER BY ts_code, trade_date
"""
test_raw = storage._connection.sql(test_query).pl()
# 转换 trade_date 为字符串格式
test_raw = test_raw.with_columns(
pl.col("trade_date").dt.strftime("%Y-%m-%d").alias("trade_date") pl.col("trade_date").dt.strftime("%Y-%m-%d").alias("trade_date")
) )
# 过滤不符合条件的股票 # 过滤不符合条件的股票
train_raw = _filter_invalid_stocks(train_raw) all_raw = _filter_invalid_stocks(all_raw)
test_raw = _filter_invalid_stocks(test_raw) print(f"[PrepareData] After filtering: total={len(all_raw)}")
print(f"[PrepareData] After filtering: train={len(train_raw)}, test={len(test_raw)}")
# 计算因子和标签 # 计算因子和标签(需要全局数据一次性计算)
train_data = _compute_features_and_label(train_raw, train_start, train_end) all_data = _compute_features_and_label(
test_data = _compute_features_and_label(test_raw, test_start, test_end) all_raw,
start_date=train_start,
end_date=test_end
)
return train_data, test_data # 转换日期格式用于比较
train_start_fmt = f"{train_start[:4]}-{train_start[4:6]}-{train_start[6:8]}"
train_end_fmt = f"{train_end[:4]}-{train_end[4:6]}-{train_end[6:8]}"
val_start_fmt = f"{val_start[:4]}-{val_start[4:6]}-{val_start[6:8]}"
val_end_fmt = f"{val_end[:4]}-{val_end[4:6]}-{val_end[6:8]}"
test_start_fmt = f"{test_start[:4]}-{test_start[4:6]}-{test_start[6:8]}"
test_end_fmt = f"{test_end[:4]}-{test_end[4:6]}-{test_end[6:8]}"
# 拆分数据
train_data = all_data.filter(
(pl.col("trade_date") >= train_start_fmt) & (pl.col("trade_date") <= train_end_fmt)
)
val_data = all_data.filter(
(pl.col("trade_date") >= val_start_fmt) & (pl.col("trade_date") <= val_end_fmt)
)
test_data = all_data.filter(
(pl.col("trade_date") >= test_start_fmt) & (pl.col("trade_date") <= test_end_fmt)
)
print(f"[PrepareData] Split result: train={len(train_data)}, val={len(val_data)}, test={len(test_data)}")
return train_data, val_data, test_data
def _filter_invalid_stocks(df: pl.DataFrame) -> pl.DataFrame: def _filter_invalid_stocks(df: pl.DataFrame) -> pl.DataFrame:
@@ -254,6 +267,7 @@ def create_pipeline() -> ProcessingPipeline:
def train_model( def train_model(
train_data: pl.DataFrame, train_data: pl.DataFrame,
val_data: Optional[pl.DataFrame],
feature_cols: List[str], feature_cols: List[str],
label_col: str = "label", label_col: str = "label",
model_params: Optional[dict] = None, model_params: Optional[dict] = None,
@@ -262,6 +276,7 @@ def train_model(
Args: Args:
train_data: 训练数据 train_data: 训练数据
val_data: 验证数据(用于早停)
feature_cols: 特征列名列表 feature_cols: 特征列名列表
label_col: 标签列名 label_col: 标签列名
model_params: 模型参数字典 model_params: 模型参数字典
@@ -273,21 +288,39 @@ def train_model(
pipeline = create_pipeline() pipeline = create_pipeline()
print("[TrainModel] Pipeline created: FillNA(0)") print("[TrainModel] Pipeline created: FillNA(0)")
# 准备特征和标签 # 准备训练特征和标签
X_train = train_data.select(feature_cols) X_train = train_data.select(feature_cols)
y_train = train_data[label_col] y_train = train_data[label_col]
print(f"[TrainModel] Raw samples: {len(X_train)}, features: {feature_cols}") print(f"[TrainModel] Train samples: {len(X_train)}, features: {feature_cols}")
# 处理数据 # 处理训练数据
X_train_processed = pipeline.fit_transform(X_train, stage=PipelineStage.TRAIN) X_train_processed = pipeline.fit_transform(X_train, stage=PipelineStage.TRAIN)
print(f"[TrainModel] After processing: {len(X_train_processed)} samples") print(f"[TrainModel] After processing: {len(X_train_processed)} samples")
# 过滤有效标签(排除-1等无效值 # 过滤训练集有效标签(排除-1等无效值
valid_mask = y_train.is_in([0, 1]) valid_mask = y_train.is_in([0, 1])
X_train_processed = X_train_processed.filter(valid_mask) X_train_processed = X_train_processed.filter(valid_mask)
y_train = y_train.filter(valid_mask) y_train = y_train.filter(valid_mask)
print(f"[TrainModel] After filtering valid labels: {len(X_train_processed)} samples") print(f"[TrainModel] After filtering valid labels: {len(X_train_processed)} samples")
print(f"[TrainModel] Label distribution: {dict(y_train.value_counts().sort('label').iter_rows())}") print(f"[TrainModel] Train label distribution: {dict(y_train.value_counts().sort('label').iter_rows())}")
# 准备验证集
X_val_processed = None
y_val = None
if val_data is not None and len(val_data) > 0:
X_val = val_data.select(feature_cols)
y_val = val_data[label_col]
print(f"[TrainModel] Val samples: {len(X_val)}")
# 处理验证集数据(使用训练集的参数)
X_val_processed = pipeline.transform(X_val, stage=PipelineStage.TEST)
# 过滤验证集有效标签
val_valid_mask = y_val.is_in([0, 1])
X_val_processed = X_val_processed.filter(val_valid_mask)
y_val = y_val.filter(val_valid_mask)
print(f"[TrainModel] Val after filtering: {len(X_val_processed)} samples")
print(f"[TrainModel] Val label distribution: {dict(y_val.value_counts().sort('label').iter_rows())}")
# 创建模型 # 创建模型
params = model_params or { params = model_params or {
@@ -302,8 +335,12 @@ def train_model(
params=params, params=params,
) )
# 训练模型 # 训练模型(使用验证集早停)
print("[TrainModel] Training LightGBM...") print("[TrainModel] Training LightGBM...")
if X_val_processed is not None and y_val is not None:
print("[TrainModel] Using validation set for early stopping")
model.fit(X_train_processed, y_train, X_val_processed, y_val)
else:
model.fit(X_train_processed, y_train) model.fit(X_train_processed, y_train)
print("[TrainModel] Training completed!") print("[TrainModel] Training completed!")
@@ -382,7 +419,9 @@ def run_training(
output_path: str = "output/top_stocks.tsv", output_path: str = "output/top_stocks.tsv",
train_start: str = "20180101", train_start: str = "20180101",
train_end: str = "20230101", train_end: str = "20230101",
test_start: str = "20230101", val_start: str = "20230101",
val_end: str = "20230601",
test_start: str = "20230601",
test_end: str = "20240101", test_end: str = "20240101",
top_n: int = 5, top_n: int = 5,
) -> pl.DataFrame: ) -> pl.DataFrame:
@@ -393,6 +432,8 @@ def run_training(
output_path: 输出文件路径 output_path: 输出文件路径
train_start: 训练集开始日期 train_start: 训练集开始日期
train_end: 训练集结束日期 train_end: 训练集结束日期
val_start: 验证集开始日期
val_end: 验证集结束日期
test_start: 测试集开始日期 test_start: 测试集开始日期
test_end: 测试集结束日期 test_end: 测试集结束日期
top_n: 每日选股数量 top_n: 每日选股数量
@@ -402,18 +443,22 @@ def run_training(
""" """
print(f"[Training] Starting training pipeline...") print(f"[Training] Starting training pipeline...")
print(f"[Training] Train period: {train_start} -> {train_end}") print(f"[Training] Train period: {train_start} -> {train_end}")
print(f"[Training] Val period: {val_start} -> {val_end}")
print(f"[Training] Test period: {test_start} -> {test_end}") print(f"[Training] Test period: {test_start} -> {test_end}")
# 1. 准备数据 # 1. 准备数据
print("[Training] Preparing data...") print("[Training] Preparing data...")
train_data, test_data = prepare_data( train_data, val_data, test_data = prepare_data(
data_dir=data_dir, data_dir=data_dir,
train_start=train_start, train_start=train_start,
train_end=train_end, train_end=train_end,
val_start=val_start,
val_end=val_end,
test_start=test_start, test_start=test_start,
test_end=test_end, test_end=test_end,
) )
print(f"[Training] Train samples: {len(train_data)}") print(f"[Training] Train samples: {len(train_data)}")
print(f"[Training] Val samples: {len(val_data)}")
print(f"[Training] Test samples: {len(test_data)}") print(f"[Training] Test samples: {len(test_data)}")
# 2. 定义特征列 # 2. 定义特征列
@@ -424,6 +469,7 @@ def run_training(
print("[Training] Training model...") print("[Training] Training model...")
model, pipeline = train_model( model, pipeline = train_model(
train_data=train_data, train_data=train_data,
val_data=val_data,
feature_cols=feature_cols, feature_cols=feature_cols,
label_col=label_col, label_col=label_col,
) )