Files
NewStock/main/train/AnalyzeData.ipynb
2025-06-02 22:23:44 +08:00

2460 lines
143 KiB
Plaintext
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "79a7758178bafdd3",
"metadata": {
"ExecuteTime": {
"end_time": "2025-04-09T16:39:30.609224Z",
"start_time": "2025-04-09T16:39:29.929606Z"
},
"jupyter": {
"source_hidden": true
}
},
"outputs": [],
"source": [
"# %load_ext autoreload\n",
"# %autoreload 2\n",
"\n",
"import pandas as pd\n",
"import warnings\n",
"\n",
"warnings.filterwarnings(\"ignore\")\n",
"\n",
"pd.set_option('display.max_columns', None)\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "a79cafb06a7e0e43",
"metadata": {
"ExecuteTime": {
"end_time": "2025-04-09T16:40:19.471361Z",
"start_time": "2025-04-09T16:39:30.917824Z"
},
"jupyter": {
"source_hidden": true
},
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"daily data\n",
"daily basic\n",
"inner merge on ['ts_code', 'trade_date']\n",
"stk limit\n",
"left merge on ['ts_code', 'trade_date']\n",
"money flow\n",
"left merge on ['ts_code', 'trade_date']\n",
"cyq perf\n",
"left merge on ['ts_code', 'trade_date']\n",
"<class 'pandas.core.frame.DataFrame'>\n",
"RangeIndex: 8692146 entries, 0 to 8692145\n",
"Data columns (total 33 columns):\n",
" # Column Dtype \n",
"--- ------ ----- \n",
" 0 ts_code object \n",
" 1 trade_date datetime64[ns]\n",
" 2 open float64 \n",
" 3 close float64 \n",
" 4 high float64 \n",
" 5 low float64 \n",
" 6 vol float64 \n",
" 7 amount float64 \n",
" 8 pct_chg float64 \n",
" 9 turnover_rate float64 \n",
" 10 pe_ttm float64 \n",
" 11 circ_mv float64 \n",
" 12 total_mv float64 \n",
" 13 volume_ratio float64 \n",
" 14 is_st bool \n",
" 15 up_limit float64 \n",
" 16 down_limit float64 \n",
" 17 buy_sm_vol float64 \n",
" 18 sell_sm_vol float64 \n",
" 19 buy_lg_vol float64 \n",
" 20 sell_lg_vol float64 \n",
" 21 buy_elg_vol float64 \n",
" 22 sell_elg_vol float64 \n",
" 23 net_mf_vol float64 \n",
" 24 his_low float64 \n",
" 25 his_high float64 \n",
" 26 cost_5pct float64 \n",
" 27 cost_15pct float64 \n",
" 28 cost_50pct float64 \n",
" 29 cost_85pct float64 \n",
" 30 cost_95pct float64 \n",
" 31 weight_avg float64 \n",
" 32 winner_rate float64 \n",
"dtypes: bool(1), datetime64[ns](1), float64(30), object(1)\n",
"memory usage: 2.1+ GB\n",
"None\n"
]
}
],
"source": [
"from main.utils.utils import read_and_merge_h5_data\n",
"\n",
"print('daily data')\n",
"df = read_and_merge_h5_data('/mnt/d/PyProject/NewStock/data/daily_data.h5', key='daily_data',\n",
" columns=['ts_code', 'trade_date', 'open', 'close', 'high', 'low', 'vol', 'amount', 'pct_chg'],\n",
" df=None)\n",
"\n",
"print('daily basic')\n",
"df = read_and_merge_h5_data('/mnt/d/PyProject/NewStock/data/daily_basic.h5', key='daily_basic',\n",
" columns=['ts_code', 'trade_date', 'turnover_rate', 'pe_ttm', 'circ_mv', 'total_mv', 'volume_ratio',\n",
" 'is_st'], df=df, join='inner')\n",
"\n",
"print('stk limit')\n",
"df = read_and_merge_h5_data('/mnt/d/PyProject/NewStock/data/stk_limit.h5', key='stk_limit',\n",
" columns=['ts_code', 'trade_date', 'pre_close', 'up_limit', 'down_limit'],\n",
" df=df)\n",
"print('money flow')\n",
"df = read_and_merge_h5_data('/mnt/d/PyProject/NewStock/data/money_flow.h5', key='money_flow',\n",
" columns=['ts_code', 'trade_date', 'buy_sm_vol', 'sell_sm_vol', 'buy_lg_vol', 'sell_lg_vol',\n",
" 'buy_elg_vol', 'sell_elg_vol', 'net_mf_vol'],\n",
" df=df)\n",
"print('cyq perf')\n",
"df = read_and_merge_h5_data('/mnt/d/PyProject/NewStock/data/cyq_perf.h5', key='cyq_perf',\n",
" columns=['ts_code', 'trade_date', 'his_low', 'his_high', 'cost_5pct', 'cost_15pct',\n",
" 'cost_50pct',\n",
" 'cost_85pct', 'cost_95pct', 'weight_avg', 'winner_rate'],\n",
" df=df)\n",
"print(df.info())"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "8d2acede",
"metadata": {},
"outputs": [],
"source": [
"fina_indicator_df = read_and_merge_h5_data('/mnt/d/PyProject/NewStock//data/fina_indicator.h5', key='fina_indicator',\n",
" columns=['ts_code', 'ann_date', 'undist_profit_ps', 'ocfps', 'bps', 'roa', 'roe'],\n",
" df=None)\n",
"cashflow_df = read_and_merge_h5_data('/mnt/d/PyProject/NewStock//data/cashflow.h5', key='cashflow',\n",
" columns=['ts_code', 'ann_date', 'n_cashflow_act'],\n",
" df=None)\n",
"balancesheet_df = read_and_merge_h5_data('/mnt/d/PyProject/NewStock//data/balancesheet.h5', key='balancesheet',\n",
" columns=['ts_code', 'ann_date', 'money_cap', 'total_liab'],\n",
" df=None)\n",
"top_list_df = read_and_merge_h5_data('/mnt/d/PyProject/NewStock//data/top_list.h5', key='top_list',\n",
" columns=['ts_code', 'trade_date', 'reason'],\n",
" df=None)\n",
"\n",
"top_list_df = top_list_df.sort_values(by='trade_date', ascending=False).drop_duplicates(subset=['ts_code', 'trade_date'], keep='first').sort_values(by='trade_date')\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "85c3e3d0235ffffa",
"metadata": {
"ExecuteTime": {
"end_time": "2025-04-09T16:41:39.580305Z",
"start_time": "2025-04-09T16:40:30.170820Z"
},
"jupyter": {
"source_hidden": true
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"使用 'ann_date' 作为财务数据生效日期。\n",
"警告: 从 financial_data_subset 中移除了 366 行,因为其 'ts_code' 或 'ann_date' 列存在空值。\n",
"使用 'ann_date' 作为财务数据生效日期。\n",
"警告: 从 financial_data_subset 中移除了 366 行,因为其 'ts_code' 或 'ann_date' 列存在空值。\n",
"使用 'ann_date' 作为财务数据生效日期。\n",
"警告: 从 financial_data_subset 中移除了 366 行,因为其 'ts_code' 或 'ann_date' 列存在空值。\n",
"使用 'ann_date' 作为财务数据生效日期。\n",
"警告: 从 financial_data_subset 中移除了 366 行,因为其 'ts_code' 或 'ann_date' 列存在空值。\n",
"开始计算因子: AR, BR (原地修改)...\n",
"因子 AR, BR 计算成功。\n",
"因子 AR, BR 计算流程结束。\n",
"使用 'ann_date' 作为财务数据生效日期。\n",
"使用 'ann_date' 作为财务数据生效日期。\n",
"使用 'ann_date' 作为财务数据生效日期。\n",
"使用 'ann_date' 作为财务数据生效日期。\n",
"警告: 从 financial_data_subset 中移除了 366 行,因为其 'ts_code' 或 'ann_date' 列存在空值。\n",
"计算 BBI...\n",
"--- 计算日级别偏离度 (使用 pct_chg) ---\n",
"--- 计算日级别动量基准 (使用 pct_chg) ---\n",
"日级别动量基准计算完成 (使用 pct_chg)。\n",
"日级别偏离度计算完成 (使用 pct_chg)。\n",
"--- 计算日级别行业偏离度 (使用 pct_chg 和行业基准) ---\n",
"--- 计算日级别行业动量基准 (使用 pct_chg 和 cat_l2_code) ---\n",
"错误: 计算日级别行业动量基准需要以下列: ['pct_chg', 'cat_l2_code', 'trade_date', 'ts_code']。\n",
"错误: 计算日级别行业偏离度需要以下列: ['pct_chg', 'daily_industry_positive_benchmark', 'daily_industry_negative_benchmark']。请先运行 daily_industry_momentum_benchmark(df)。\n",
"Index(['ts_code', 'trade_date', 'open', 'close', 'high', 'low', 'vol',\n",
" 'amount', 'pct_chg', 'turnover_rate', 'pe_ttm', 'circ_mv', 'total_mv',\n",
" 'volume_ratio', 'is_st', 'up_limit', 'down_limit', 'buy_sm_vol',\n",
" 'sell_sm_vol', 'buy_lg_vol', 'sell_lg_vol', 'buy_elg_vol',\n",
" 'sell_elg_vol', 'net_mf_vol', 'his_low', 'his_high', 'cost_5pct',\n",
" 'cost_15pct', 'cost_50pct', 'cost_85pct', 'cost_95pct', 'weight_avg',\n",
" 'winner_rate', 'undist_profit_ps', 'ocfps', 'roa', 'roe', 'AR', 'BR',\n",
" 'AR_BR', 'log_circ_mv', 'cashflow_to_ev_factor', 'book_to_price_ratio',\n",
" 'turnover_rate_mean_5', 'variance_20', 'bbi_ratio_factor',\n",
" 'daily_deviation', 'lg_elg_net_buy_vol', 'flow_lg_elg_intensity',\n",
" 'sm_net_buy_vol', 'flow_divergence_diff', 'flow_divergence_ratio',\n",
" 'total_buy_vol', 'lg_elg_buy_prop', 'flow_struct_buy_change',\n",
" 'lg_elg_net_buy_vol_change', 'flow_lg_elg_accel',\n",
" 'chip_concentration_range', 'chip_skewness', 'floating_chip_proxy',\n",
" 'cost_support_15pct_change', 'cat_winner_price_zone',\n",
" 'flow_chip_consistency', 'profit_taking_vs_absorb', '_is_positive',\n",
" '_is_negative', 'cat_is_positive', '_pos_returns', '_neg_returns',\n",
" '_pos_returns_sq', '_neg_returns_sq', 'upside_vol', 'downside_vol',\n",
" 'vol_ratio', 'return_skew', 'return_kurtosis', 'volume_change_rate',\n",
" 'cat_volume_breakout', 'turnover_deviation', 'cat_turnover_spike',\n",
" 'avg_volume_ratio', 'cat_volume_ratio_breakout', 'vol_spike',\n",
" 'vol_std_5', 'atr_14', 'atr_6', 'obv'],\n",
" dtype='object')\n",
"Calculating lg_flow_mom_corr_20_60...\n",
"Finished lg_flow_mom_corr_20_60.\n",
"Calculating lg_flow_accel...\n",
"Finished lg_flow_accel.\n",
"Calculating profit_pressure...\n",
"Finished profit_pressure.\n",
"Calculating underwater_resistance...\n",
"Finished underwater_resistance.\n",
"Calculating cost_conc_std_20...\n",
"Finished cost_conc_std_20.\n",
"Calculating profit_decay_20...\n",
"Finished profit_decay_20.\n",
"Calculating vol_amp_loss_20...\n",
"Finished vol_amp_loss_20.\n",
"Calculating vol_drop_profit_cnt_5...\n",
"Finished vol_drop_profit_cnt_5.\n",
"Calculating lg_flow_vol_interact_20...\n",
"Finished lg_flow_vol_interact_20.\n",
"Calculating cost_break_confirm_cnt_5...\n",
"Finished cost_break_confirm_cnt_5.\n",
"Calculating atr_norm_channel_pos_14...\n",
"Finished atr_norm_channel_pos_14.\n",
"Calculating turnover_diff_skew_20...\n",
"Finished turnover_diff_skew_20.\n",
"Calculating lg_sm_flow_diverge_20...\n",
"Finished lg_sm_flow_diverge_20.\n",
"Calculating pullback_strong_20_20...\n",
"Finished pullback_strong_20_20.\n",
"Calculating vol_wgt_hist_pos_20...\n",
"Finished vol_wgt_hist_pos_20.\n",
"Calculating vol_adj_roc_20...\n",
"Finished vol_adj_roc_20.\n",
"Calculating cs_rank_net_lg_flow_val...\n",
"Finished cs_rank_net_lg_flow_val.\n",
"Calculating cs_rank_flow_divergence...\n",
"Finished cs_rank_flow_divergence.\n",
"Calculating cs_rank_ind_adj_lg_flow...\n",
"Error calculating cs_rank_ind_adj_lg_flow: Missing 'cat_l2_code' column. Assigning NaN.\n",
"Calculating cs_rank_elg_buy_ratio...\n",
"Finished cs_rank_elg_buy_ratio.\n",
"Calculating cs_rank_rel_profit_margin...\n",
"Finished cs_rank_rel_profit_margin.\n",
"Calculating cs_rank_cost_breadth...\n",
"Finished cs_rank_cost_breadth.\n",
"Calculating cs_rank_dist_to_upper_cost...\n",
"Finished cs_rank_dist_to_upper_cost.\n",
"Calculating cs_rank_winner_rate...\n",
"Finished cs_rank_winner_rate.\n",
"Calculating cs_rank_intraday_range...\n",
"Finished cs_rank_intraday_range.\n",
"Calculating cs_rank_close_pos_in_range...\n",
"Finished cs_rank_close_pos_in_range.\n",
"Calculating cs_rank_opening_gap...\n",
"Error calculating cs_rank_opening_gap: Missing 'pre_close' column. Assigning NaN.\n",
"Calculating cs_rank_pos_in_hist_range...\n",
"Finished cs_rank_pos_in_hist_range.\n",
"Calculating cs_rank_vol_x_profit_margin...\n",
"Finished cs_rank_vol_x_profit_margin.\n",
"Calculating cs_rank_lg_flow_price_concordance...\n",
"Finished cs_rank_lg_flow_price_concordance.\n",
"Calculating cs_rank_turnover_per_winner...\n",
"Finished cs_rank_turnover_per_winner.\n",
"Calculating cs_rank_ind_cap_neutral_pe (Placeholder - requires statsmodels)...\n",
"Finished cs_rank_ind_cap_neutral_pe (Placeholder).\n",
"Calculating cs_rank_volume_ratio...\n",
"Finished cs_rank_volume_ratio.\n",
"Calculating cs_rank_elg_buy_sell_sm_ratio...\n",
"Finished cs_rank_elg_buy_sell_sm_ratio.\n",
"Calculating cs_rank_cost_dist_vol_ratio...\n",
"Finished cs_rank_cost_dist_vol_ratio.\n",
"Calculating cs_rank_size...\n",
"Finished cs_rank_size.\n",
"<class 'pandas.core.frame.DataFrame'>\n",
"RangeIndex: 2511964 entries, 0 to 2511963\n",
"Columns: 180 entries, ts_code to cs_rank_size\n",
"dtypes: bool(10), datetime64[ns](1), float64(165), int64(3), object(1)\n",
"memory usage: 3.2+ GB\n",
"None\n",
"['ts_code', 'trade_date', 'open', 'close', 'high', 'low', 'vol', 'amount', 'pct_chg', 'turnover_rate', 'pe_ttm', 'circ_mv', 'total_mv', 'volume_ratio', 'is_st', 'up_limit', 'down_limit', 'buy_sm_vol', 'sell_sm_vol', 'buy_lg_vol', 'sell_lg_vol', 'buy_elg_vol', 'sell_elg_vol', 'net_mf_vol', 'his_low', 'his_high', 'cost_5pct', 'cost_15pct', 'cost_50pct', 'cost_85pct', 'cost_95pct', 'weight_avg', 'winner_rate', 'undist_profit_ps', 'ocfps', 'roa', 'roe', 'AR', 'BR', 'AR_BR', 'log_circ_mv', 'cashflow_to_ev_factor', 'book_to_price_ratio', 'turnover_rate_mean_5', 'variance_20', 'bbi_ratio_factor', 'daily_deviation', 'lg_elg_net_buy_vol', 'flow_lg_elg_intensity', 'sm_net_buy_vol', 'flow_divergence_diff', 'flow_divergence_ratio', 'total_buy_vol', 'lg_elg_buy_prop', 'flow_struct_buy_change', 'lg_elg_net_buy_vol_change', 'flow_lg_elg_accel', 'chip_concentration_range', 'chip_skewness', 'floating_chip_proxy', 'cost_support_15pct_change', 'cat_winner_price_zone', 'flow_chip_consistency', 'profit_taking_vs_absorb', 'cat_is_positive', 'upside_vol', 'downside_vol', 'vol_ratio', 'return_skew', 'return_kurtosis', 'volume_change_rate', 'cat_volume_breakout', 'turnover_deviation', 'cat_turnover_spike', 'avg_volume_ratio', 'cat_volume_ratio_breakout', 'vol_spike', 'vol_std_5', 'atr_14', 'atr_6', 'obv', 'maobv_6', 'rsi_3', 'return_5', 'return_20', 'std_return_5', 'std_return_90', 'std_return_90_2', 'act_factor1', 'act_factor2', 'act_factor3', 'act_factor4', 'rank_act_factor1', 'rank_act_factor2', 'rank_act_factor3', 'cov', 'delta_cov', 'alpha_22_improved', 'alpha_003', 'alpha_007', 'alpha_013', 'vol_break', 'weight_roc5', 'price_cost_divergence', 'smallcap_concentration', 'cost_stability', 'high_cost_break_days', 'liquidity_risk', 'turnover_std', 'mv_volatility', 'volume_growth', 'mv_growth', 'momentum_factor', 'resonance_factor', 'log_close', 'cat_vol_spike', 'up', 'down', 'obv_maobv_6', 'std_return_5_over_std_return_90', 'std_return_90_minus_std_return_90_2', 'cat_af2', 'cat_af3', 
'cat_af4', 'act_factor5', 'act_factor6', 'active_buy_volume_large', 'active_buy_volume_big', 'active_buy_volume_small', 'buy_lg_vol_minus_sell_lg_vol', 'buy_elg_vol_minus_sell_elg_vol', 'ctrl_strength', 'low_cost_dev', 'asymmetry', 'lock_factor', 'cat_vol_break', 'cost_atr_adj', 'cat_golden_resonance', 'mv_turnover_ratio', 'mv_adjusted_volume', 'mv_weighted_turnover', 'nonlinear_mv_volume', 'mv_volume_ratio', 'mv_momentum', 'lg_flow_mom_corr_20_60', 'lg_flow_accel', 'profit_pressure', 'underwater_resistance', 'cost_conc_std_20', 'profit_decay_20', 'vol_amp_loss_20', 'vol_drop_profit_cnt_5', 'lg_flow_vol_interact_20', 'cost_break_confirm_cnt_5', 'atr_norm_channel_pos_14', 'turnover_diff_skew_20', 'lg_sm_flow_diverge_20', 'pullback_strong_20_20', 'vol_wgt_hist_pos_20', 'vol_adj_roc_20', 'cs_rank_net_lg_flow_val', 'cs_rank_flow_divergence', 'cs_rank_ind_adj_lg_flow', 'cs_rank_elg_buy_ratio', 'cs_rank_rel_profit_margin', 'cs_rank_cost_breadth', 'cs_rank_dist_to_upper_cost', 'cs_rank_winner_rate', 'cs_rank_intraday_range', 'cs_rank_close_pos_in_range', 'cs_rank_opening_gap', 'cs_rank_pos_in_hist_range', 'cs_rank_vol_x_profit_margin', 'cs_rank_lg_flow_price_concordance', 'cs_rank_turnover_per_winner', 'cs_rank_ind_cap_neutral_pe', 'cs_rank_volume_ratio', 'cs_rank_elg_buy_sell_sm_ratio', 'cs_rank_cost_dist_vol_ratio', 'cs_rank_size']\n"
]
}
],
"source": [
"# df1\n",
"\n",
"import numpy as np\n",
"from main.factor.factor import *\n",
"\n",
"def filter_data(df):\n",
" # df = df.groupby('trade_date').apply(lambda x: x.nlargest(1000, 'act_factor1'))\n",
" df = df[~df['is_st']]\n",
" df = df[~df['ts_code'].str.endswith('BJ')]\n",
" df = df[~df['ts_code'].str.startswith('30')]\n",
" df = df[~df['ts_code'].str.startswith('68')]\n",
" df = df[~df['ts_code'].str.startswith('8')]\n",
" df = df[df['trade_date'] >= '2022-01-01']\n",
" if 'in_date' in df.columns:\n",
" df = df.drop(columns=['in_date'])\n",
" df = df.reset_index(drop=True)\n",
" return df\n",
"\n",
"import gc\n",
"gc.collect()\n",
"\n",
"df = filter_data(df)\n",
"df = df.sort_values(by=['ts_code', 'trade_date'])\n",
"\n",
"# df = price_minus_deduction_price(df, n=120)\n",
"# df = price_deduction_price_diff_ratio_to_sma(df, n=120)\n",
"# df = cat_price_vs_sma_vs_deduction_price(df, n=120)\n",
"# df = cat_reason(df, top_list_df)\n",
"# df = cat_is_on_top_list(df, top_list_df)\n",
"\n",
"# df = cat_senti_mom_vol_spike(\n",
"# df,\n",
"# return_period=3,\n",
"# return_threshold=0.03, # 近3日涨幅超3%\n",
"# volume_ratio_threshold=1.3,\n",
"# current_pct_chg_min=0.0, # 当日必须收红\n",
"# current_pct_chg_max=0.05,\n",
"# ) # 当日涨幅不宜过大\n",
"\n",
"# df = cat_senti_pre_breakout(\n",
"# df,\n",
"# atr_short_N=10,\n",
"# atr_long_M=40,\n",
"# vol_atrophy_N=10,\n",
"# vol_atrophy_M=40,\n",
"# price_stab_N=5,\n",
"# price_stab_threshold=0.06,\n",
"# current_pct_chg_min_signal=0.002,\n",
"# current_pct_chg_max_signal=0.05,\n",
"# volume_ratio_signal_threshold=1.1,\n",
"# )\n",
"\n",
"# df = ts_turnover_rate_acceleration_5_20(df)\n",
"# df = ts_vol_sustain_10_30(df)\n",
"# # df = cs_turnover_rate_relative_strength_20(df)\n",
"# df = cs_amount_outlier_10(df)\n",
"# df = ts_ff_to_total_turnover_ratio(df)\n",
"# df = ts_price_volume_trend_coherence_5_20(df)\n",
"# # df = ts_turnover_rate_trend_strength_5(df)\n",
"# df = ts_ff_turnover_rate_surge_10(df)\n",
"\n",
"df = add_financial_factor(df, fina_indicator_df, factor_value_col='undist_profit_ps')\n",
"df = add_financial_factor(df, fina_indicator_df, factor_value_col='ocfps')\n",
"df = add_financial_factor(df, fina_indicator_df, factor_value_col='roa')\n",
"df = add_financial_factor(df, fina_indicator_df, factor_value_col='roe')\n",
"\n",
"calculate_arbr(df, N=26)\n",
"df['log_circ_mv'] = np.log(df['circ_mv'])\n",
"df = calculate_cashflow_to_ev_factor(df, cashflow_df, balancesheet_df)\n",
"df = caculate_book_to_price_ratio(df, fina_indicator_df)\n",
"\n",
"df = turnover_rate_n(df, n=5)\n",
"df = variance_n(df, n=20)\n",
"df = bbi_ratio_factor(df)\n",
"df = daily_deviation(df)\n",
"df = daily_industry_deviation(df)\n",
"df, _ = get_rolling_factor(df)\n",
"df, _ = get_simple_factor(df)\n",
"\n",
"df = df.rename(columns={'l1_code': 'cat_l1_code'})\n",
"df = df.rename(columns={'l2_code': 'cat_l2_code'})\n",
"\n",
"lg_flow_mom_corr(df, N=20, M=60)\n",
"lg_flow_accel(df)\n",
"profit_pressure(df)\n",
"underwater_resistance(df)\n",
"cost_conc_std(df, N=20)\n",
"profit_decay(df, N=20)\n",
"vol_amp_loss(df, N=20)\n",
"vol_drop_profit_cnt(df, N=20, M=5)\n",
"lg_flow_vol_interact(df, N=20)\n",
"cost_break_confirm_cnt(df, M=5)\n",
"atr_norm_channel_pos(df, N=14)\n",
"turnover_diff_skew(df, N=20)\n",
"lg_sm_flow_diverge(df, N=20)\n",
"pullback_strong(df, N=20, M=20)\n",
"vol_wgt_hist_pos(df, N=20)\n",
"vol_adj_roc(df, N=20)\n",
"\n",
"cs_rank_net_lg_flow_val(df)\n",
"cs_rank_flow_divergence(df)\n",
"cs_rank_industry_adj_lg_flow(df) # Needs cat_l2_code\n",
"cs_rank_elg_buy_ratio(df)\n",
"cs_rank_rel_profit_margin(df)\n",
"cs_rank_cost_breadth(df)\n",
"cs_rank_dist_to_upper_cost(df)\n",
"cs_rank_winner_rate(df)\n",
"cs_rank_intraday_range(df)\n",
"cs_rank_close_pos_in_range(df)\n",
"cs_rank_opening_gap(df) # Needs pre_close\n",
"cs_rank_pos_in_hist_range(df) # Needs his_low, his_high\n",
"cs_rank_vol_x_profit_margin(df)\n",
"cs_rank_lg_flow_price_concordance(df)\n",
"cs_rank_turnover_per_winner(df)\n",
"cs_rank_ind_cap_neutral_pe(df) # Placeholder - needs external libraries\n",
"cs_rank_volume_ratio(df) # Needs volume_ratio\n",
"cs_rank_elg_buy_sell_sm_ratio(df)\n",
"cs_rank_cost_dist_vol_ratio(df) # Needs volume_ratio\n",
"cs_rank_size(df) # Needs circ_mv\n",
"\n",
"df1 = df.copy()\n",
"\n",
"# df = df.merge(index_data, on='trade_date', how='left')\n",
"\n",
"print(df.info())\n",
"print(df.columns.tolist())"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "5dabff1e7bdd48c0",
"metadata": {
"ExecuteTime": {
"end_time": "2025-04-09T16:42:29.604069Z",
"start_time": "2025-04-09T16:41:39.621703Z"
},
"jupyter": {
"source_hidden": true
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"daily data\n",
"daily basic\n",
"inner merge on ['ts_code', 'trade_date']\n",
"stk limit\n",
"left merge on ['ts_code', 'trade_date']\n",
"money flow\n",
"left merge on ['ts_code', 'trade_date']\n",
"cyq perf\n",
"left merge on ['ts_code', 'trade_date']\n",
"<class 'pandas.core.frame.DataFrame'>\n",
"RangeIndex: 8692146 entries, 0 to 8692145\n",
"Data columns (total 33 columns):\n",
" # Column Dtype \n",
"--- ------ ----- \n",
" 0 ts_code object \n",
" 1 trade_date datetime64[ns]\n",
" 2 open float64 \n",
" 3 close float64 \n",
" 4 high float64 \n",
" 5 low float64 \n",
" 6 vol float64 \n",
" 7 amount float64 \n",
" 8 pct_chg float64 \n",
" 9 turnover_rate float64 \n",
" 10 pe_ttm float64 \n",
" 11 circ_mv float64 \n",
" 12 total_mv float64 \n",
" 13 volume_ratio float64 \n",
" 14 is_st bool \n",
" 15 up_limit float64 \n",
" 16 down_limit float64 \n",
" 17 buy_sm_vol float64 \n",
" 18 sell_sm_vol float64 \n",
" 19 buy_lg_vol float64 \n",
" 20 sell_lg_vol float64 \n",
" 21 buy_elg_vol float64 \n",
" 22 sell_elg_vol float64 \n",
" 23 net_mf_vol float64 \n",
" 24 his_low float64 \n",
" 25 his_high float64 \n",
" 26 cost_5pct float64 \n",
" 27 cost_15pct float64 \n",
" 28 cost_50pct float64 \n",
" 29 cost_85pct float64 \n",
" 30 cost_95pct float64 \n",
" 31 weight_avg float64 \n",
" 32 winner_rate float64 \n",
"dtypes: bool(1), datetime64[ns](1), float64(30), object(1)\n",
"memory usage: 2.1+ GB\n",
"None\n"
]
}
],
"source": [
"from main.utils.utils import read_and_merge_h5_data\n",
"\n",
"print('daily data')\n",
"df = read_and_merge_h5_data('/mnt/d/PyProject/NewStock/data/daily_data.h5', key='daily_data',\n",
" columns=['ts_code', 'trade_date', 'open', 'close', 'high', 'low', 'vol', 'amount', 'pct_chg'],\n",
" df=None)\n",
"\n",
"print('daily basic')\n",
"df = read_and_merge_h5_data('/mnt/d/PyProject/NewStock/data/daily_basic.h5', key='daily_basic',\n",
" columns=['ts_code', 'trade_date', 'turnover_rate', 'pe_ttm', 'circ_mv', 'total_mv', 'volume_ratio',\n",
" 'is_st'], df=df, join='inner')\n",
"\n",
"print('stk limit')\n",
"df = read_and_merge_h5_data('/mnt/d/PyProject/NewStock/data/stk_limit.h5', key='stk_limit',\n",
" columns=['ts_code', 'trade_date', 'pre_close', 'up_limit', 'down_limit'],\n",
" df=df)\n",
"print('money flow')\n",
"df = read_and_merge_h5_data('/mnt/d/PyProject/NewStock/data/money_flow.h5', key='money_flow',\n",
" columns=['ts_code', 'trade_date', 'buy_sm_vol', 'sell_sm_vol', 'buy_lg_vol', 'sell_lg_vol',\n",
" 'buy_elg_vol', 'sell_elg_vol', 'net_mf_vol'],\n",
" df=df)\n",
"print('cyq perf')\n",
"df = read_and_merge_h5_data('/mnt/d/PyProject/NewStock/data/cyq_perf.h5', key='cyq_perf',\n",
" columns=['ts_code', 'trade_date', 'his_low', 'his_high', 'cost_5pct', 'cost_15pct',\n",
" 'cost_50pct',\n",
" 'cost_85pct', 'cost_95pct', 'weight_avg', 'winner_rate'],\n",
" df=df)\n",
"print(df.info())"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "ee9d7511597a312b",
"metadata": {
"ExecuteTime": {
"end_time": "2025-04-09T16:43:50.865104Z",
"start_time": "2025-04-09T16:42:39.340589Z"
},
"jupyter": {
"source_hidden": true
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"使用 'ann_date' 作为财务数据生效日期。\n",
"警告: 从 financial_data_subset 中移除了 366 行,因为其 'ts_code' 或 'ann_date' 列存在空值。\n",
"使用 'ann_date' 作为财务数据生效日期。\n",
"警告: 从 financial_data_subset 中移除了 366 行,因为其 'ts_code' 或 'ann_date' 列存在空值。\n",
"使用 'ann_date' 作为财务数据生效日期。\n",
"警告: 从 financial_data_subset 中移除了 366 行,因为其 'ts_code' 或 'ann_date' 列存在空值。\n",
"使用 'ann_date' 作为财务数据生效日期。\n",
"警告: 从 financial_data_subset 中移除了 366 行,因为其 'ts_code' 或 'ann_date' 列存在空值。\n",
"开始计算因子: AR, BR (原地修改)...\n",
"因子 AR, BR 计算成功。\n",
"因子 AR, BR 计算流程结束。\n",
"使用 'ann_date' 作为财务数据生效日期。\n",
"使用 'ann_date' 作为财务数据生效日期。\n",
"使用 'ann_date' 作为财务数据生效日期。\n",
"使用 'ann_date' 作为财务数据生效日期。\n",
"警告: 从 financial_data_subset 中移除了 366 行,因为其 'ts_code' 或 'ann_date' 列存在空值。\n",
"计算 BBI...\n",
"--- 计算日级别偏离度 (使用 pct_chg) ---\n",
"--- 计算日级别动量基准 (使用 pct_chg) ---\n",
"日级别动量基准计算完成 (使用 pct_chg)。\n",
"日级别偏离度计算完成 (使用 pct_chg)。\n",
"--- 计算日级别行业偏离度 (使用 pct_chg 和行业基准) ---\n",
"--- 计算日级别行业动量基准 (使用 pct_chg 和 cat_l2_code) ---\n",
"错误: 计算日级别行业动量基准需要以下列: ['pct_chg', 'cat_l2_code', 'trade_date', 'ts_code']。\n",
"错误: 计算日级别行业偏离度需要以下列: ['pct_chg', 'daily_industry_positive_benchmark', 'daily_industry_negative_benchmark']。请先运行 daily_industry_momentum_benchmark(df)。\n",
"Index(['ts_code', 'trade_date', 'open', 'close', 'high', 'low', 'vol',\n",
" 'amount', 'pct_chg', 'turnover_rate', 'pe_ttm', 'circ_mv', 'total_mv',\n",
" 'volume_ratio', 'is_st', 'up_limit', 'down_limit', 'buy_sm_vol',\n",
" 'sell_sm_vol', 'buy_lg_vol', 'sell_lg_vol', 'buy_elg_vol',\n",
" 'sell_elg_vol', 'net_mf_vol', 'his_low', 'his_high', 'cost_5pct',\n",
" 'cost_15pct', 'cost_50pct', 'cost_85pct', 'cost_95pct', 'weight_avg',\n",
" 'winner_rate', 'undist_profit_ps', 'ocfps', 'roa', 'roe', 'AR', 'BR',\n",
" 'AR_BR', 'log_circ_mv', 'cashflow_to_ev_factor', 'book_to_price_ratio',\n",
" 'turnover_rate_mean_5', 'variance_20', 'bbi_ratio_factor',\n",
" 'daily_deviation', 'lg_elg_net_buy_vol', 'flow_lg_elg_intensity',\n",
" 'sm_net_buy_vol', 'flow_divergence_diff', 'flow_divergence_ratio',\n",
" 'total_buy_vol', 'lg_elg_buy_prop', 'flow_struct_buy_change',\n",
" 'lg_elg_net_buy_vol_change', 'flow_lg_elg_accel',\n",
" 'chip_concentration_range', 'chip_skewness', 'floating_chip_proxy',\n",
" 'cost_support_15pct_change', 'cat_winner_price_zone',\n",
" 'flow_chip_consistency', 'profit_taking_vs_absorb', '_is_positive',\n",
" '_is_negative', 'cat_is_positive', '_pos_returns', '_neg_returns',\n",
" '_pos_returns_sq', '_neg_returns_sq', 'upside_vol', 'downside_vol',\n",
" 'vol_ratio', 'return_skew', 'return_kurtosis', 'volume_change_rate',\n",
" 'cat_volume_breakout', 'turnover_deviation', 'cat_turnover_spike',\n",
" 'avg_volume_ratio', 'cat_volume_ratio_breakout', 'vol_spike',\n",
" 'vol_std_5', 'atr_14', 'atr_6', 'obv'],\n",
" dtype='object')\n",
"Calculating lg_flow_mom_corr_20_60...\n",
"Finished lg_flow_mom_corr_20_60.\n",
"Calculating lg_flow_accel...\n",
"Finished lg_flow_accel.\n",
"Calculating profit_pressure...\n",
"Finished profit_pressure.\n",
"Calculating underwater_resistance...\n",
"Finished underwater_resistance.\n",
"Calculating cost_conc_std_20...\n",
"Finished cost_conc_std_20.\n",
"Calculating profit_decay_20...\n",
"Finished profit_decay_20.\n",
"Calculating vol_amp_loss_20...\n",
"Finished vol_amp_loss_20.\n",
"Calculating vol_drop_profit_cnt_5...\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Finished vol_drop_profit_cnt_5.\n",
"Calculating lg_flow_vol_interact_20...\n",
"Finished lg_flow_vol_interact_20.\n",
"Calculating cost_break_confirm_cnt_5...\n",
"Finished cost_break_confirm_cnt_5.\n",
"Calculating atr_norm_channel_pos_14...\n",
"Finished atr_norm_channel_pos_14.\n",
"Calculating turnover_diff_skew_20...\n",
"Finished turnover_diff_skew_20.\n",
"Calculating lg_sm_flow_diverge_20...\n",
"Finished lg_sm_flow_diverge_20.\n",
"Calculating pullback_strong_20_20...\n",
"Finished pullback_strong_20_20.\n",
"Calculating vol_wgt_hist_pos_20...\n",
"Finished vol_wgt_hist_pos_20.\n",
"Calculating vol_adj_roc_20...\n",
"Finished vol_adj_roc_20.\n",
"Calculating cs_rank_net_lg_flow_val...\n",
"Finished cs_rank_net_lg_flow_val.\n",
"Calculating cs_rank_flow_divergence...\n",
"Finished cs_rank_flow_divergence.\n",
"Calculating cs_rank_ind_adj_lg_flow...\n",
"Error calculating cs_rank_ind_adj_lg_flow: Missing 'cat_l2_code' column. Assigning NaN.\n",
"Calculating cs_rank_elg_buy_ratio...\n",
"Finished cs_rank_elg_buy_ratio.\n",
"Calculating cs_rank_rel_profit_margin...\n",
"Finished cs_rank_rel_profit_margin.\n",
"Calculating cs_rank_cost_breadth...\n",
"Finished cs_rank_cost_breadth.\n",
"Calculating cs_rank_dist_to_upper_cost...\n",
"Finished cs_rank_dist_to_upper_cost.\n",
"Calculating cs_rank_winner_rate...\n",
"Finished cs_rank_winner_rate.\n",
"Calculating cs_rank_intraday_range...\n",
"Finished cs_rank_intraday_range.\n",
"Calculating cs_rank_close_pos_in_range...\n",
"Finished cs_rank_close_pos_in_range.\n",
"Calculating cs_rank_opening_gap...\n",
"Error calculating cs_rank_opening_gap: Missing 'pre_close' column. Assigning NaN.\n",
"Calculating cs_rank_pos_in_hist_range...\n",
"Finished cs_rank_pos_in_hist_range.\n",
"Calculating cs_rank_vol_x_profit_margin...\n",
"Finished cs_rank_vol_x_profit_margin.\n",
"Calculating cs_rank_lg_flow_price_concordance...\n",
"Finished cs_rank_lg_flow_price_concordance.\n",
"Calculating cs_rank_turnover_per_winner...\n",
"Finished cs_rank_turnover_per_winner.\n",
"Calculating cs_rank_ind_cap_neutral_pe (Placeholder - requires statsmodels)...\n",
"Finished cs_rank_ind_cap_neutral_pe (Placeholder).\n",
"Calculating cs_rank_volume_ratio...\n",
"Finished cs_rank_volume_ratio.\n",
"Calculating cs_rank_elg_buy_sell_sm_ratio...\n",
"Finished cs_rank_elg_buy_sell_sm_ratio.\n",
"Calculating cs_rank_cost_dist_vol_ratio...\n",
"Finished cs_rank_cost_dist_vol_ratio.\n",
"Calculating cs_rank_size...\n",
"Finished cs_rank_size.\n",
"<class 'pandas.core.frame.DataFrame'>\n",
"RangeIndex: 1784215 entries, 0 to 1784214\n",
"Columns: 180 entries, ts_code to cs_rank_size\n",
"dtypes: bool(10), datetime64[ns](1), float64(165), int64(3), object(1)\n",
"memory usage: 2.3+ GB\n",
"None\n",
"['ts_code', 'trade_date', 'open', 'close', 'high', 'low', 'vol', 'amount', 'pct_chg', 'turnover_rate', 'pe_ttm', 'circ_mv', 'total_mv', 'volume_ratio', 'is_st', 'up_limit', 'down_limit', 'buy_sm_vol', 'sell_sm_vol', 'buy_lg_vol', 'sell_lg_vol', 'buy_elg_vol', 'sell_elg_vol', 'net_mf_vol', 'his_low', 'his_high', 'cost_5pct', 'cost_15pct', 'cost_50pct', 'cost_85pct', 'cost_95pct', 'weight_avg', 'winner_rate', 'undist_profit_ps', 'ocfps', 'roa', 'roe', 'AR', 'BR', 'AR_BR', 'log_circ_mv', 'cashflow_to_ev_factor', 'book_to_price_ratio', 'turnover_rate_mean_5', 'variance_20', 'bbi_ratio_factor', 'daily_deviation', 'lg_elg_net_buy_vol', 'flow_lg_elg_intensity', 'sm_net_buy_vol', 'flow_divergence_diff', 'flow_divergence_ratio', 'total_buy_vol', 'lg_elg_buy_prop', 'flow_struct_buy_change', 'lg_elg_net_buy_vol_change', 'flow_lg_elg_accel', 'chip_concentration_range', 'chip_skewness', 'floating_chip_proxy', 'cost_support_15pct_change', 'cat_winner_price_zone', 'flow_chip_consistency', 'profit_taking_vs_absorb', 'cat_is_positive', 'upside_vol', 'downside_vol', 'vol_ratio', 'return_skew', 'return_kurtosis', 'volume_change_rate', 'cat_volume_breakout', 'turnover_deviation', 'cat_turnover_spike', 'avg_volume_ratio', 'cat_volume_ratio_breakout', 'vol_spike', 'vol_std_5', 'atr_14', 'atr_6', 'obv', 'maobv_6', 'rsi_3', 'return_5', 'return_20', 'std_return_5', 'std_return_90', 'std_return_90_2', 'act_factor1', 'act_factor2', 'act_factor3', 'act_factor4', 'rank_act_factor1', 'rank_act_factor2', 'rank_act_factor3', 'cov', 'delta_cov', 'alpha_22_improved', 'alpha_003', 'alpha_007', 'alpha_013', 'vol_break', 'weight_roc5', 'price_cost_divergence', 'smallcap_concentration', 'cost_stability', 'high_cost_break_days', 'liquidity_risk', 'turnover_std', 'mv_volatility', 'volume_growth', 'mv_growth', 'momentum_factor', 'resonance_factor', 'log_close', 'cat_vol_spike', 'up', 'down', 'obv_maobv_6', 'std_return_5_over_std_return_90', 'std_return_90_minus_std_return_90_2', 'cat_af2', 'cat_af3', 
'cat_af4', 'act_factor5', 'act_factor6', 'active_buy_volume_large', 'active_buy_volume_big', 'active_buy_volume_small', 'buy_lg_vol_minus_sell_lg_vol', 'buy_elg_vol_minus_sell_elg_vol', 'ctrl_strength', 'low_cost_dev', 'asymmetry', 'lock_factor', 'cat_vol_break', 'cost_atr_adj', 'cat_golden_resonance', 'mv_turnover_ratio', 'mv_adjusted_volume', 'mv_weighted_turnover', 'nonlinear_mv_volume', 'mv_volume_ratio', 'mv_momentum', 'lg_flow_mom_corr_20_60', 'lg_flow_accel', 'profit_pressure', 'underwater_resistance', 'cost_conc_std_20', 'profit_decay_20', 'vol_amp_loss_20', 'vol_drop_profit_cnt_5', 'lg_flow_vol_interact_20', 'cost_break_confirm_cnt_5', 'atr_norm_channel_pos_14', 'turnover_diff_skew_20', 'lg_sm_flow_diverge_20', 'pullback_strong_20_20', 'vol_wgt_hist_pos_20', 'vol_adj_roc_20', 'cs_rank_net_lg_flow_val', 'cs_rank_flow_divergence', 'cs_rank_ind_adj_lg_flow', 'cs_rank_elg_buy_ratio', 'cs_rank_rel_profit_margin', 'cs_rank_cost_breadth', 'cs_rank_dist_to_upper_cost', 'cs_rank_winner_rate', 'cs_rank_intraday_range', 'cs_rank_close_pos_in_range', 'cs_rank_opening_gap', 'cs_rank_pos_in_hist_range', 'cs_rank_vol_x_profit_margin', 'cs_rank_lg_flow_price_concordance', 'cs_rank_turnover_per_winner', 'cs_rank_ind_cap_neutral_pe', 'cs_rank_volume_ratio', 'cs_rank_elg_buy_sell_sm_ratio', 'cs_rank_cost_dist_vol_ratio', 'cs_rank_size']\n"
]
}
],
"source": [
"# df2\n",
"\n",
"import numpy as np\n",
"from main.factor.factor import *\n",
"\n",
"def filter_data(df):\n",
" # df = df.groupby('trade_date').apply(lambda x: x.nlargest(1000, 'act_factor1'))\n",
" df = df[~df['is_st']]\n",
" df = df[~df['ts_code'].str.endswith('BJ')]\n",
" df = df[~df['ts_code'].str.startswith('30')]\n",
" df = df[~df['ts_code'].str.startswith('68')]\n",
" df = df[~df['ts_code'].str.startswith('8')]\n",
" df = df[df['trade_date'] >= '2023-01-01']\n",
" if 'in_date' in df.columns:\n",
" df = df.drop(columns=['in_date'])\n",
" df = df.reset_index(drop=True)\n",
" return df\n",
"\n",
"import gc\n",
"gc.collect()\n",
"\n",
"df = filter_data(df)\n",
"df = df.sort_values(by=['ts_code', 'trade_date'])\n",
"\n",
"# df = price_minus_deduction_price(df, n=120)\n",
"# df = price_deduction_price_diff_ratio_to_sma(df, n=120)\n",
"# df = cat_price_vs_sma_vs_deduction_price(df, n=120)\n",
"# df = cat_reason(df, top_list_df)\n",
"# df = cat_is_on_top_list(df, top_list_df)\n",
"\n",
"# df = cat_senti_mom_vol_spike(\n",
"# df,\n",
"# return_period=3,\n",
"# return_threshold=0.03, # 近3日涨幅超3%\n",
"# volume_ratio_threshold=1.3,\n",
"# current_pct_chg_min=0.0, # 当日必须收红\n",
"# current_pct_chg_max=0.05,\n",
"# ) # 当日涨幅不宜过大\n",
"\n",
"# df = cat_senti_pre_breakout(\n",
"# df,\n",
"# atr_short_N=10,\n",
"# atr_long_M=40,\n",
"# vol_atrophy_N=10,\n",
"# vol_atrophy_M=40,\n",
"# price_stab_N=5,\n",
"# price_stab_threshold=0.06,\n",
"# current_pct_chg_min_signal=0.002,\n",
"# current_pct_chg_max_signal=0.05,\n",
"# volume_ratio_signal_threshold=1.1,\n",
"# )\n",
"\n",
"# df = ts_turnover_rate_acceleration_5_20(df)\n",
"# df = ts_vol_sustain_10_30(df)\n",
"# # df = cs_turnover_rate_relative_strength_20(df)\n",
"# df = cs_amount_outlier_10(df)\n",
"# df = ts_ff_to_total_turnover_ratio(df)\n",
"# df = ts_price_volume_trend_coherence_5_20(df)\n",
"# # df = ts_turnover_rate_trend_strength_5(df)\n",
"# df = ts_ff_turnover_rate_surge_10(df)\n",
"\n",
"df = add_financial_factor(df, fina_indicator_df, factor_value_col='undist_profit_ps')\n",
"df = add_financial_factor(df, fina_indicator_df, factor_value_col='ocfps')\n",
"df = add_financial_factor(df, fina_indicator_df, factor_value_col='roa')\n",
"df = add_financial_factor(df, fina_indicator_df, factor_value_col='roe')\n",
"\n",
"calculate_arbr(df, N=26)\n",
"df['log_circ_mv'] = np.log(df['circ_mv'])\n",
"df = calculate_cashflow_to_ev_factor(df, cashflow_df, balancesheet_df)\n",
"df = caculate_book_to_price_ratio(df, fina_indicator_df)\n",
"\n",
"df = turnover_rate_n(df, n=5)\n",
"df = variance_n(df, n=20)\n",
"df = bbi_ratio_factor(df)\n",
"df = daily_deviation(df)\n",
"df = daily_industry_deviation(df)\n",
"df, _ = get_rolling_factor(df)\n",
"df, _ = get_simple_factor(df)\n",
"\n",
"df = df.rename(columns={'l1_code': 'cat_l1_code'})\n",
"df = df.rename(columns={'l2_code': 'cat_l2_code'})\n",
"\n",
"lg_flow_mom_corr(df, N=20, M=60)\n",
"lg_flow_accel(df)\n",
"profit_pressure(df)\n",
"underwater_resistance(df)\n",
"cost_conc_std(df, N=20)\n",
"profit_decay(df, N=20)\n",
"vol_amp_loss(df, N=20)\n",
"vol_drop_profit_cnt(df, N=20, M=5)\n",
"lg_flow_vol_interact(df, N=20)\n",
"cost_break_confirm_cnt(df, M=5)\n",
"atr_norm_channel_pos(df, N=14)\n",
"turnover_diff_skew(df, N=20)\n",
"lg_sm_flow_diverge(df, N=20)\n",
"pullback_strong(df, N=20, M=20)\n",
"vol_wgt_hist_pos(df, N=20)\n",
"vol_adj_roc(df, N=20)\n",
"\n",
"cs_rank_net_lg_flow_val(df)\n",
"cs_rank_flow_divergence(df)\n",
"cs_rank_industry_adj_lg_flow(df) # Needs cat_l2_code\n",
"cs_rank_elg_buy_ratio(df)\n",
"cs_rank_rel_profit_margin(df)\n",
"cs_rank_cost_breadth(df)\n",
"cs_rank_dist_to_upper_cost(df)\n",
"cs_rank_winner_rate(df)\n",
"cs_rank_intraday_range(df)\n",
"cs_rank_close_pos_in_range(df)\n",
"cs_rank_opening_gap(df) # Needs pre_close\n",
"cs_rank_pos_in_hist_range(df) # Needs his_low, his_high\n",
"cs_rank_vol_x_profit_margin(df)\n",
"cs_rank_lg_flow_price_concordance(df)\n",
"cs_rank_turnover_per_winner(df)\n",
"cs_rank_ind_cap_neutral_pe(df) # Placeholder - needs external libraries\n",
"cs_rank_volume_ratio(df) # Needs volume_ratio\n",
"cs_rank_elg_buy_sell_sm_ratio(df)\n",
"cs_rank_cost_dist_vol_ratio(df) # Needs volume_ratio\n",
"cs_rank_size(df) # Needs circ_mv\n",
"\n",
"df2 = df  # NOTE: alias, not a copy — later in-place changes to df also change df2 (use df.copy() to snapshot)\n",
"\n",
"# df = df.merge(index_data, on='trade_date', how='left')\n",
"\n",
"df.info()  # .info() writes to stdout and returns None — wrapping it in print() emits a stray 'None'\n",
"print(df.columns.tolist())"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "770520c3",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2022-01-04 00:00:00\n",
"2023-01-03 00:00:00\n"
]
}
],
"source": [
"print(df1['trade_date'].min())\n",
"print(df2['trade_date'].min())\n"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "3cff0731",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Empty DataFrame\n",
"Columns: [ts_code, trade_date, open, close, high, low, vol, amount, pct_chg, turnover_rate, pe_ttm, circ_mv, total_mv, volume_ratio, is_st, up_limit, down_limit, buy_sm_vol, sell_sm_vol, buy_lg_vol, sell_lg_vol, buy_elg_vol, sell_elg_vol, net_mf_vol, his_low, his_high, cost_5pct, cost_15pct, cost_50pct, cost_85pct, cost_95pct, weight_avg, winner_rate, undist_profit_ps, ocfps, roa, roe, AR, BR, AR_BR, log_circ_mv, cashflow_to_ev_factor, book_to_price_ratio, turnover_rate_mean_5, variance_20, bbi_ratio_factor, daily_deviation, lg_elg_net_buy_vol, flow_lg_elg_intensity, sm_net_buy_vol, flow_divergence_diff, flow_divergence_ratio, total_buy_vol, lg_elg_buy_prop, flow_struct_buy_change, lg_elg_net_buy_vol_change, flow_lg_elg_accel, chip_concentration_range, chip_skewness, floating_chip_proxy, cost_support_15pct_change, cat_winner_price_zone, flow_chip_consistency, profit_taking_vs_absorb, cat_is_positive, upside_vol, downside_vol, vol_ratio, return_skew, return_kurtosis, volume_change_rate, cat_volume_breakout, turnover_deviation, cat_turnover_spike, avg_volume_ratio, cat_volume_ratio_breakout, vol_spike, vol_std_5, atr_14, atr_6, obv, maobv_6, rsi_3, return_5, return_20, std_return_5, std_return_90, std_return_90_2, act_factor1, act_factor2, act_factor3, act_factor4, rank_act_factor1, rank_act_factor2, rank_act_factor3, cov, delta_cov, alpha_22_improved, alpha_003, alpha_007, ...]\n",
"Index: []\n"
]
}
],
"source": [
"\n",
"print(df1[df1['ts_code'] == '002259.SZ'])"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "4ae711775caefbe5",
"metadata": {
"ExecuteTime": {
"end_time": "2025-04-09T16:43:53.621695Z",
"start_time": "2025-04-09T16:43:50.925481Z"
}
},
"outputs": [],
"source": [
"# # print(df1[df1['trade_date'] == '2025-04-07'][['ts_code', 'trade_date', 'vol_std_5', 'cov', 'delta_cov', 'alpha_22_improved', 'alpha_007', 'consecutive_up_limit', 'mv_volatility', 'volume_growth', 'mv_growth', 'arbr']].tail())\n",
"# # print(df2[df2['trade_date'] == '2025-04-07'][['ts_code', 'trade_date', 'vol_std_5', 'cov', 'delta_cov', 'alpha_22_improved', 'alpha_007', 'consecutive_up_limit', 'mv_volatility', 'volume_growth', 'mv_growth', 'arbr']].tail())\n",
"# # print(df1[df1['trade_date'] == '2025-04-07'].equals(df2[df2['trade_date'] == '2025-04-07']))\n",
"\n",
"# from main.utils.factor_processor import calculate_score\n",
"\n",
"# days = 2\n",
"# df1 = df1.sort_values(by=['ts_code', 'trade_date'])\n",
"# # df['future_return'] = df.groupby('ts_code', group_keys=False)['close'].apply(lambda x: x.shift(-days) / x - 1)\n",
"# df1['future_return'] = (df1.groupby('ts_code')['close'].shift(-days) - df1.groupby('ts_code')['open'].shift(-1)) / \\\n",
"# df1.groupby('ts_code')['open'].shift(-1)\n",
"# df1['future_score'] = calculate_score(df1, days=2, lambda_param=0.3)\n",
"# df1['label'] = df1.groupby('trade_date', group_keys=False)['future_score'].transform(\n",
"# lambda x: pd.qcut(x, q=20, labels=False, duplicates='drop')\n",
"# )\n",
"\n",
"# df2 = df2.sort_values(by=['ts_code', 'trade_date'])\n",
"# # df['future_return'] = df.groupby('ts_code', group_keys=False)['close'].apply(lambda x: x.shift(-days) / x - 1)\n",
"# df2['future_return'] = (df2.groupby('ts_code')['close'].shift(-days) - df2.groupby('ts_code')['open'].shift(-1)) / \\\n",
"# df2.groupby('ts_code')['open'].shift(-1)\n",
"# df2['future_score'] = calculate_score(df2, days=2, lambda_param=0.3)\n",
"# df2['label'] = df2.groupby('trade_date', group_keys=False)['future_score'].transform(\n",
"# lambda x: pd.qcut(x, q=20, labels=False, duplicates='drop')\n",
"# )"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "350bf91df8c3dfc2",
"metadata": {
"ExecuteTime": {
"end_time": "2025-04-09T16:43:53.723327Z",
"start_time": "2025-04-09T16:43:53.658090Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"日期: 2025-03-26\n",
"------------------------------\n",
"Slice 1 形状: (3064, 184)\n",
"Slice 2 形状: (3064, 184)\n",
"!!! 索引不同,尝试按 ts_code 对齐 !!!\n",
"------------------------------\n",
"使用 compare() 方法查找差异:\n",
"!!! 发现差异 (compare结果):\n",
"MultiIndex([( 'turnover_rate_mean_5', 'self'),\n",
" ( 'turnover_rate_mean_5', 'other'),\n",
" ( 'variance_20', 'self'),\n",
" ( 'variance_20', 'other'),\n",
" ( 'bbi_ratio_factor', 'self'),\n",
" ( 'bbi_ratio_factor', 'other'),\n",
" ( 'upside_vol', 'self'),\n",
" ( 'upside_vol', 'other'),\n",
" ( 'downside_vol', 'self'),\n",
" ( 'downside_vol', 'other'),\n",
" ( 'vol_ratio', 'self'),\n",
" ( 'vol_ratio', 'other'),\n",
" ( 'return_skew', 'self'),\n",
" ( 'return_skew', 'other'),\n",
" ( 'return_kurtosis', 'self'),\n",
" ( 'return_kurtosis', 'other'),\n",
" ( 'volume_change_rate', 'self'),\n",
" ( 'volume_change_rate', 'other'),\n",
" ( 'turnover_deviation', 'self'),\n",
" ( 'turnover_deviation', 'other'),\n",
" ( 'avg_volume_ratio', 'self'),\n",
" ( 'avg_volume_ratio', 'other'),\n",
" ( 'vol_spike', 'self'),\n",
" ( 'vol_spike', 'other'),\n",
" ( 'vol_std_5', 'self'),\n",
" ( 'vol_std_5', 'other'),\n",
" ( 'atr_14', 'self'),\n",
" ( 'atr_14', 'other'),\n",
" ( 'atr_6', 'self'),\n",
" ( 'atr_6', 'other'),\n",
" ( 'obv', 'self'),\n",
" ( 'obv', 'other'),\n",
" ( 'maobv_6', 'self'),\n",
" ( 'maobv_6', 'other'),\n",
" ( 'std_return_5', 'self'),\n",
" ( 'std_return_5', 'other'),\n",
" ( 'std_return_90', 'self'),\n",
" ( 'std_return_90', 'other'),\n",
" ( 'std_return_90_2', 'self'),\n",
" ( 'std_return_90_2', 'other'),\n",
" ( 'act_factor2', 'self'),\n",
" ( 'act_factor2', 'other'),\n",
" ( 'act_factor3', 'self'),\n",
" ( 'act_factor3', 'other'),\n",
" ( 'act_factor4', 'self'),\n",
" ( 'act_factor4', 'other'),\n",
" ( 'cov', 'self'),\n",
" ( 'cov', 'other'),\n",
" ( 'delta_cov', 'self'),\n",
" ( 'delta_cov', 'other'),\n",
" ( 'alpha_22_improved', 'self'),\n",
" ( 'alpha_22_improved', 'other'),\n",
" ( 'alpha_013', 'self'),\n",
" ( 'alpha_013', 'other'),\n",
" ( 'price_cost_divergence', 'self'),\n",
" ( 'price_cost_divergence', 'other'),\n",
" ( 'cost_stability', 'self'),\n",
" ( 'cost_stability', 'other'),\n",
" ( 'liquidity_risk', 'self'),\n",
" ( 'liquidity_risk', 'other'),\n",
" ( 'turnover_std', 'self'),\n",
" ( 'turnover_std', 'other'),\n",
" ( 'mv_volatility', 'self'),\n",
" ( 'mv_volatility', 'other'),\n",
" ( 'momentum_factor', 'self'),\n",
" ( 'momentum_factor', 'other'),\n",
" ( 'obv_maobv_6', 'self'),\n",
" ( 'obv_maobv_6', 'other'),\n",
" ( 'std_return_5_over_std_return_90', 'self'),\n",
" ( 'std_return_5_over_std_return_90', 'other'),\n",
" ('std_return_90_minus_std_return_90_2', 'self'),\n",
" ('std_return_90_minus_std_return_90_2', 'other'),\n",
" ( 'act_factor5', 'self'),\n",
" ( 'act_factor5', 'other'),\n",
" ( 'act_factor6', 'self'),\n",
" ( 'act_factor6', 'other'),\n",
" ( 'cost_atr_adj', 'self'),\n",
" ( 'cost_atr_adj', 'other'),\n",
" ( 'lg_flow_mom_corr_20_60', 'self'),\n",
" ( 'lg_flow_mom_corr_20_60', 'other'),\n",
" ( 'cost_conc_std_20', 'self'),\n",
" ( 'cost_conc_std_20', 'other'),\n",
" ( 'vol_amp_loss_20', 'self'),\n",
" ( 'vol_amp_loss_20', 'other'),\n",
" ( 'lg_flow_vol_interact_20', 'self'),\n",
" ( 'lg_flow_vol_interact_20', 'other'),\n",
" ( 'turnover_diff_skew_20', 'self'),\n",
" ( 'turnover_diff_skew_20', 'other'),\n",
" ( 'lg_sm_flow_diverge_20', 'self'),\n",
" ( 'lg_sm_flow_diverge_20', 'other'),\n",
" ( 'vol_wgt_hist_pos_20', 'self'),\n",
" ( 'vol_wgt_hist_pos_20', 'other'),\n",
" ( 'vol_adj_roc_20', 'self'),\n",
" ( 'vol_adj_roc_20', 'other')],\n",
" )\n",
" turnover_rate_mean_5 variance_20 bbi_ratio_factor \\\n",
" self other self other self \n",
"ts_code \n",
"000001.SZ NaN NaN 1.350932 1.350932 1.010161 \n",
"000002.SZ NaN NaN 2.117668 2.117668 NaN \n",
"000004.SZ NaN NaN 8.061740 8.061740 NaN \n",
"000006.SZ NaN NaN 7.298740 7.298740 NaN \n",
"000007.SZ NaN NaN NaN NaN NaN \n",
"... ... ... ... ... ... \n",
"605580.SH NaN NaN NaN NaN NaN \n",
"605588.SH NaN NaN 6.613036 6.613036 NaN \n",
"605589.SH 1.2223 1.2223 7.727591 7.727591 NaN \n",
"605598.SH NaN NaN 3.079582 3.079582 NaN \n",
"605599.SH NaN NaN 3.339119 3.339119 NaN \n",
"\n",
" upside_vol downside_vol vol_ratio \\\n",
" other self other self other self \n",
"ts_code \n",
"000001.SZ 1.010161 NaN NaN NaN NaN NaN \n",
"000002.SZ NaN NaN NaN NaN NaN NaN \n",
"000004.SZ NaN 1.074360 1.074360 2.400379 2.400379 0.447579 \n",
"000006.SZ NaN NaN NaN NaN NaN NaN \n",
"000007.SZ NaN NaN NaN 1.277065 1.277065 1.098127 \n",
"... ... ... ... ... ... ... \n",
"605580.SH NaN NaN NaN 0.699229 0.699229 2.174044 \n",
"605588.SH NaN NaN NaN NaN NaN NaN \n",
"605589.SH NaN NaN NaN 1.540787 1.540787 1.064451 \n",
"605598.SH NaN 1.374078 1.374078 NaN NaN 1.528319 \n",
"605599.SH NaN NaN NaN NaN NaN NaN \n",
"\n",
" return_skew return_kurtosis \\\n",
" other self other self other \n",
"ts_code \n",
"000001.SZ NaN NaN NaN NaN NaN \n",
"000002.SZ NaN NaN NaN NaN NaN \n",
"000004.SZ 0.447579 -1.534768 -1.534768 2.792027 2.792027 \n",
"000006.SZ NaN -0.577787 -0.577787 -0.476524 -0.476524 \n",
"000007.SZ 1.098127 NaN NaN NaN NaN \n",
"... ... ... ... ... ... \n",
"605580.SH 2.174044 2.126640 2.126640 4.642559 4.642559 \n",
"605588.SH NaN 0.609743 0.609743 -3.243014 -3.243014 \n",
"605589.SH 1.064451 NaN NaN NaN NaN \n",
"605598.SH 1.528319 0.278034 0.278034 0.049195 0.049195 \n",
"605599.SH NaN 0.030262 0.030262 -0.694891 -0.694891 \n",
"\n",
" volume_change_rate turnover_deviation \\\n",
" self other self other \n",
"ts_code \n",
"000001.SZ -0.530526 -0.530526 -0.566673 -0.566673 \n",
"000002.SZ NaN NaN -0.446047 -0.446047 \n",
"000004.SZ NaN NaN -0.948364 -0.948364 \n",
"000006.SZ NaN NaN 1.052435 1.052435 \n",
"000007.SZ NaN NaN 1.111991 1.111991 \n",
"... ... ... ... ... \n",
"605580.SH NaN NaN NaN NaN \n",
"605588.SH NaN NaN -0.678620 -0.678620 \n",
"605589.SH NaN NaN -0.938677 -0.938677 \n",
"605598.SH 0.166773 0.166773 0.933048 0.933048 \n",
"605599.SH NaN NaN 1.075588 1.075588 \n",
"\n",
" avg_volume_ratio vol_spike vol_std_5 \\\n",
" self other self other self other \n",
"ts_code \n",
"000001.SZ NaN NaN NaN NaN 0.232868 0.232868 \n",
"000002.SZ 0.876667 0.876667 NaN NaN NaN NaN \n",
"000004.SZ NaN NaN NaN NaN NaN NaN \n",
"000006.SZ NaN NaN NaN NaN NaN NaN \n",
"000007.SZ NaN NaN NaN NaN NaN NaN \n",
"... ... ... ... ... ... ... \n",
"605580.SH NaN NaN NaN NaN 1.164645 1.164645 \n",
"605588.SH 0.596667 0.596667 NaN NaN 0.314876 0.314876 \n",
"605589.SH NaN NaN NaN NaN 0.562543 0.562543 \n",
"605598.SH NaN NaN NaN NaN 1.057029 1.057029 \n",
"605599.SH NaN NaN NaN NaN 0.193314 0.193314 \n",
"\n",
" atr_14 atr_6 obv \\\n",
" self other self other self other \n",
"ts_code \n",
"000001.SZ NaN NaN NaN NaN 11801312.72 2738247.22 \n",
"000002.SZ NaN NaN NaN NaN 41291828.48 25339065.82 \n",
"000004.SZ 2.030307 2.030307 NaN NaN 6090154.93 5915712.53 \n",
"000006.SZ NaN NaN NaN NaN 27874233.68 14515751.96 \n",
"000007.SZ 1.713940 1.713940 NaN NaN 1716807.04 270569.37 \n",
"... ... ... ... ... ... ... \n",
"605580.SH NaN NaN NaN NaN 4089325.91 3674511.22 \n",
"605588.SH NaN NaN NaN NaN 1537384.91 1376415.97 \n",
"605589.SH NaN NaN NaN NaN 6078107.94 5044023.07 \n",
"605598.SH NaN NaN NaN NaN 3839018.04 1797711.65 \n",
"605599.SH NaN NaN NaN NaN 2485523.73 2575349.58 \n",
"\n",
" maobv_6 std_return_5 std_return_90 \\\n",
" self other self other self \n",
"ts_code \n",
"000001.SZ 1.315290e+07 4.089838e+06 0.004030 0.004030 0.010625 \n",
"000002.SZ 4.132689e+07 2.537412e+07 0.010652 0.010652 0.023221 \n",
"000004.SZ 6.210621e+06 6.036178e+06 NaN NaN NaN \n",
"000006.SZ 2.735946e+07 1.400098e+07 NaN NaN 0.035182 \n",
"000007.SZ 1.642407e+06 1.961695e+05 0.032904 0.032904 0.022552 \n",
"... ... ... ... ... ... \n",
"605580.SH 4.143541e+06 3.728727e+06 0.031017 0.031017 0.026001 \n",
"605588.SH 1.534916e+06 1.373947e+06 NaN NaN 0.028813 \n",
"605589.SH 6.139128e+06 5.105043e+06 0.017784 0.017784 0.023788 \n",
"605598.SH 3.879454e+06 1.838147e+06 0.030152 0.030152 0.034004 \n",
"605599.SH 2.480334e+06 2.570160e+06 NaN NaN 0.020242 \n",
"\n",
" std_return_90_2 act_factor2 act_factor3 \\\n",
" other self other self other self \n",
"ts_code \n",
"000001.SZ 0.010625 0.010835 0.010835 NaN NaN NaN \n",
"000002.SZ 0.023221 0.024306 0.024306 NaN NaN NaN \n",
"000004.SZ NaN 0.045928 0.045928 NaN NaN NaN \n",
"000006.SZ 0.035182 0.041835 0.041835 NaN NaN NaN \n",
"000007.SZ 0.022552 0.026021 0.026021 NaN NaN 0.100465 \n",
"... ... ... ... ... ... ... \n",
"605580.SH 0.026001 0.025335 0.025335 NaN NaN NaN \n",
"605588.SH 0.028813 0.029796 0.029796 NaN NaN NaN \n",
"605589.SH 0.023788 0.024998 0.024998 NaN NaN NaN \n",
"605598.SH 0.034004 0.033996 0.033996 NaN NaN NaN \n",
"605599.SH 0.020242 0.019792 0.019792 NaN NaN NaN \n",
"\n",
" act_factor4 cov \\\n",
" other self other self other \n",
"ts_code \n",
"000001.SZ NaN -0.232135 -0.232135 NaN NaN \n",
"000002.SZ NaN -0.920105 -0.920105 NaN NaN \n",
"000004.SZ NaN -3.500298 -3.500299 NaN NaN \n",
"000006.SZ NaN -0.349179 -0.349179 1.083770e+06 1.083770e+06 \n",
"000007.SZ 0.100465 -0.608592 -0.608472 NaN NaN \n",
"... ... ... ... ... ... \n",
"605580.SH NaN 1.490573 1.490573 NaN NaN \n",
"605588.SH NaN 0.353965 0.353965 NaN NaN \n",
"605589.SH NaN 0.537608 0.537608 NaN NaN \n",
"605598.SH NaN -0.092286 -0.092286 NaN NaN \n",
"605599.SH NaN 0.777405 0.777405 NaN NaN \n",
"\n",
" delta_cov alpha_22_improved \\\n",
" self other self other \n",
"ts_code \n",
"000001.SZ 6.374603e+06 6.374603e+06 -6.262257e+06 -6.262257e+06 \n",
"000002.SZ NaN NaN NaN NaN \n",
"000004.SZ NaN NaN NaN NaN \n",
"000006.SZ 1.064318e+06 1.064318e+06 -9.302364e+05 -9.302364e+05 \n",
"000007.SZ NaN NaN NaN NaN \n",
"... ... ... ... ... \n",
"605580.SH NaN NaN NaN NaN \n",
"605588.SH NaN NaN NaN NaN \n",
"605589.SH NaN NaN NaN NaN \n",
"605598.SH 4.773888e+03 4.773888e+03 -1.441203e+03 -1.441203e+03 \n",
"605599.SH NaN NaN NaN NaN \n",
"\n",
" alpha_013 price_cost_divergence cost_stability \\\n",
" self other self other self \n",
"ts_code \n",
"000001.SZ NaN NaN 0.903682 0.903682 0.001266 \n",
"000002.SZ NaN NaN -0.063527 -0.063527 0.003122 \n",
"000004.SZ NaN NaN 0.448699 0.448699 0.039623 \n",
"000006.SZ NaN NaN -0.578683 -0.578683 0.014733 \n",
"000007.SZ NaN NaN -0.437721 -0.437721 0.012050 \n",
"... ... ... ... ... ... \n",
"605580.SH NaN NaN NaN NaN 0.007148 \n",
"605588.SH NaN NaN 0.135146 0.135146 0.011301 \n",
"605589.SH NaN NaN 0.653798 0.653798 0.012510 \n",
"605598.SH NaN NaN 0.386083 0.386083 0.002406 \n",
"605599.SH NaN NaN 0.170158 0.170158 0.002202 \n",
"\n",
" liquidity_risk turnover_std mv_volatility \\\n",
" other self other self other self \n",
"ts_code \n",
"000001.SZ 0.001266 NaN NaN 0.447536 0.447536 0.026465 \n",
"000002.SZ 0.003122 NaN NaN 0.583259 0.583259 0.037007 \n",
"000004.SZ 0.039623 NaN NaN NaN NaN NaN \n",
"000006.SZ 0.014733 NaN NaN 0.879670 0.879670 0.064097 \n",
"000007.SZ 0.012050 NaN NaN 0.574015 0.574015 0.046657 \n",
"... ... ... ... ... ... ... \n",
"605580.SH 0.007148 NaN NaN 0.795605 0.795605 0.062561 \n",
"605588.SH 0.011301 NaN NaN NaN NaN NaN \n",
"605589.SH 0.012510 NaN NaN 0.824547 0.824547 0.056549 \n",
"605598.SH 0.002406 NaN NaN 0.617382 0.617382 0.046860 \n",
"605599.SH 0.002202 NaN NaN 0.430581 0.430581 0.031348 \n",
"\n",
" momentum_factor obv_maobv_6 \\\n",
" other self other self other \n",
"ts_code \n",
"000001.SZ 0.026465 -0.813862 -0.813862 -1.351591e+06 -1.351591e+06 \n",
"000002.SZ 0.037007 -0.500416 -0.500416 -3.505751e+04 -3.505751e+04 \n",
"000004.SZ NaN -0.480611 -0.480611 -1.204658e+05 -1.204658e+05 \n",
"000006.SZ 0.064097 0.690214 0.690214 5.147734e+05 5.147734e+05 \n",
"000007.SZ 0.046657 1.021245 1.021245 7.439987e+04 7.439987e+04 \n",
"... ... ... ... ... ... \n",
"605580.SH 0.062561 NaN NaN -5.421535e+04 -5.421535e+04 \n",
"605588.SH NaN -0.582460 -0.582460 NaN NaN \n",
"605589.SH 0.056549 -0.673217 -0.673217 -6.101995e+04 -6.101995e+04 \n",
"605598.SH 0.046860 0.633297 0.633297 -4.043576e+04 -4.043575e+04 \n",
"605599.SH 0.031348 0.072720 0.072720 5.189752e+03 5.189752e+03 \n",
"\n",
" std_return_5_over_std_return_90 \\\n",
" self other \n",
"ts_code \n",
"000001.SZ 0.379306 0.379306 \n",
"000002.SZ 0.458738 0.458738 \n",
"000004.SZ NaN NaN \n",
"000006.SZ 0.971216 0.971216 \n",
"000007.SZ 1.459001 1.459001 \n",
"... ... ... \n",
"605580.SH 1.192910 1.192910 \n",
"605588.SH 0.976527 0.976527 \n",
"605589.SH 0.747608 0.747608 \n",
"605598.SH 0.886730 0.886730 \n",
"605599.SH 0.981816 0.981816 \n",
"\n",
" std_return_90_minus_std_return_90_2 act_factor5 \\\n",
" self other self other \n",
"ts_code \n",
"000001.SZ -0.000210 -0.000210 -1.146214 -1.146214 \n",
"000002.SZ NaN NaN -2.818276 -2.818276 \n",
"000004.SZ -0.001451 -0.001451 -8.723004 -8.723004 \n",
"000006.SZ NaN NaN 1.744042 1.744042 \n",
"000007.SZ -0.003469 -0.003469 0.744885 0.745004 \n",
"... ... ... ... ... \n",
"605580.SH NaN NaN 2.746294 2.746294 \n",
"605588.SH -0.000984 -0.000984 0.419094 0.419094 \n",
"605589.SH NaN NaN -1.461317 -1.461317 \n",
"605598.SH 0.000008 0.000008 -1.883670 -1.883670 \n",
"605599.SH NaN NaN 1.765644 1.765644 \n",
"\n",
" act_factor6 cost_atr_adj lg_flow_mom_corr_20_60 \\\n",
" self other self other self \n",
"ts_code \n",
"000001.SZ NaN NaN NaN NaN 0.611662 \n",
"000002.SZ NaN NaN NaN NaN 0.928029 \n",
"000004.SZ NaN NaN 1.379102 1.379102 -0.146624 \n",
"000006.SZ NaN NaN NaN NaN 0.834084 \n",
"000007.SZ NaN NaN 0.816831 0.816831 0.539901 \n",
"... ... ... ... ... ... \n",
"605580.SH NaN NaN NaN NaN 0.202900 \n",
"605588.SH NaN NaN NaN NaN 0.679962 \n",
"605589.SH NaN NaN NaN NaN -0.008683 \n",
"605598.SH NaN NaN NaN NaN 0.824691 \n",
"605599.SH NaN NaN NaN NaN 0.243661 \n",
"\n",
" cost_conc_std_20 vol_amp_loss_20 \\\n",
" other self other self other \n",
"ts_code \n",
"000001.SZ 0.611662 0.012551 0.012551 NaN NaN \n",
"000002.SZ 0.928029 0.013871 0.013871 NaN NaN \n",
"000004.SZ -0.146624 NaN NaN NaN NaN \n",
"000006.SZ 0.834084 0.024574 0.024574 NaN NaN \n",
"000007.SZ 0.539901 NaN NaN NaN NaN \n",
"... ... ... ... ... ... \n",
"605580.SH 0.202899 0.005251 0.005251 NaN NaN \n",
"605588.SH 0.679961 0.010481 0.010481 NaN NaN \n",
"605589.SH -0.008683 0.024145 0.024145 NaN NaN \n",
"605598.SH 0.824690 0.013963 0.013963 NaN NaN \n",
"605599.SH 0.243661 0.007508 0.007508 NaN NaN \n",
"\n",
" lg_flow_vol_interact_20 turnover_diff_skew_20 \\\n",
" self other self other \n",
"ts_code \n",
"000001.SZ 0.083196 0.083196 NaN NaN \n",
"000002.SZ 0.134502 0.134502 NaN NaN \n",
"000004.SZ 0.135444 0.135444 NaN NaN \n",
"000006.SZ 0.190142 0.190142 NaN NaN \n",
"000007.SZ NaN NaN NaN NaN \n",
"... ... ... ... ... \n",
"605580.SH NaN NaN NaN NaN \n",
"605588.SH 0.094075 0.094075 NaN NaN \n",
"605589.SH 0.123490 0.123490 -0.580349 -0.580349 \n",
"605598.SH 0.068323 0.068323 0.586522 0.586522 \n",
"605599.SH 0.079483 0.079483 1.398138 1.398138 \n",
"\n",
" lg_sm_flow_diverge_20 vol_wgt_hist_pos_20 \\\n",
" self other self other \n",
"ts_code \n",
"000001.SZ NaN NaN NaN NaN \n",
"000002.SZ NaN NaN NaN NaN \n",
"000004.SZ NaN NaN NaN NaN \n",
"000006.SZ NaN NaN NaN NaN \n",
"000007.SZ NaN NaN NaN NaN \n",
"... ... ... ... ... \n",
"605580.SH -0.036004 -0.036004 NaN NaN \n",
"605588.SH 0.035758 0.035758 NaN NaN \n",
"605589.SH NaN NaN NaN NaN \n",
"605598.SH 0.013836 0.013836 NaN NaN \n",
"605599.SH -0.049469 -0.049469 NaN NaN \n",
"\n",
" vol_adj_roc_20 \n",
" self other \n",
"ts_code \n",
"000001.SZ -0.010456 -0.010456 \n",
"000002.SZ -0.056096 -0.056096 \n",
"000004.SZ -0.067600 -0.067600 \n",
"000006.SZ 0.008398 0.008398 \n",
"000007.SZ NaN NaN \n",
"... ... ... \n",
"605580.SH NaN NaN \n",
"605588.SH -0.008329 -0.008329 \n",
"605589.SH -0.018240 -0.018240 \n",
"605598.SH -0.015839 -0.015839 \n",
"605599.SH 0.022997 0.022997 \n",
"\n",
"[3064 rows x 94 columns]\n",
"\n",
"存在差异的列: ['turnover_rate_mean_5', 'variance_20', 'bbi_ratio_factor', 'upside_vol', 'downside_vol', 'vol_ratio', 'return_skew', 'return_kurtosis', 'volume_change_rate', 'turnover_deviation', 'avg_volume_ratio', 'vol_spike', 'vol_std_5', 'atr_14', 'atr_6', 'obv', 'maobv_6', 'std_return_5', 'std_return_90', 'std_return_90_2', 'act_factor2', 'act_factor3', 'act_factor4', 'cov', 'delta_cov', 'alpha_22_improved', 'alpha_013', 'price_cost_divergence', 'cost_stability', 'liquidity_risk', 'turnover_std', 'mv_volatility', 'momentum_factor', 'obv_maobv_6', 'std_return_5_over_std_return_90', 'std_return_90_minus_std_return_90_2', 'act_factor5', 'act_factor6', 'cost_atr_adj', 'lg_flow_mom_corr_20_60', 'cost_conc_std_20', 'vol_amp_loss_20', 'lg_flow_vol_interact_20', 'turnover_diff_skew_20', 'lg_sm_flow_diverge_20', 'vol_wgt_hist_pos_20', 'vol_adj_roc_20']\n"
]
}
],
"source": [
"# 假设 slice1 和 slice2 已经获取,并且索引和列已对齐\n",
"# (如果索引或列不同,需要先用 .sort_index() 或 .sort_index(axis=1) 对齐)\n",
"# 假设 pdf1 和 pdf2 已经是处理到最后一步的结果\n",
"date_to_compare = '2025-03-26'\n",
"\n",
"# 1. 获取两个 DataFrame 在该日期的切片\n",
"slice1 = df1[df1['trade_date'] == date_to_compare]\n",
"slice2 = df2[df2['trade_date'] == date_to_compare]\n",
"\n",
"def get_diff(slice1, slice2):\n",
" print(f\"日期: {date_to_compare}\")\n",
" print(\"-\" * 30)\n",
" print(f\"Slice 1 形状: {slice1.shape}\")\n",
" print(f\"Slice 2 形状: {slice2.shape}\")\n",
" if slice1.shape != slice2.shape:\n",
" print(\"!!! 形状不同 !!!\")\n",
"\n",
" if not slice1.index.equals(slice2.index):\n",
" print(\"!!! 索引不同,尝试按 ts_code 对齐 !!!\")\n",
" try:\n",
" slice1 = slice1.set_index('ts_code').sort_index()\n",
" slice2 = slice2.set_index('ts_code').sort_index()\n",
" except KeyError:\n",
" print(\"错误:无法按 ts_code 设置索引,请确保该列存在。\")\n",
" # 或者尝试其他对齐方式,例如 reset_index\n",
" # slice1 = slice1.reset_index(drop=True)\n",
" # slice2 = slice2.reset_index(drop=True)\n",
"\n",
" if not slice1.columns.equals(slice2.columns):\n",
" print(\"!!! 列名或顺序不同,尝试按列名排序对齐 !!!\")\n",
" slice1 = slice1.sort_index(axis=1)\n",
" slice2 = slice2.sort_index(axis=1)\n",
"\n",
" # 再次检查对齐情况\n",
" if slice1.index.equals(slice2.index) and slice1.columns.equals(slice2.columns):\n",
" print(\"-\" * 30)\n",
" print(\"使用 compare() 方法查找差异:\")\n",
" try:\n",
" # compare 会返回一个显示差异的 DataFrame\n",
" # self 列显示 slice1 的值other 列显示 slice2 的值\n",
" diff_compare = slice1.compare(slice2)\n",
"\n",
" if diff_compare.empty:\n",
" print(\"使用 compare() 未发现差异。\")\n",
" # 如果 compare 为空但 equals 仍为 False, 可能是非常细微的浮点差异或类型差异\n",
" # 可以再检查一下dtypes\n",
" if not slice1.dtypes.equals(slice2.dtypes):\n",
" print(\"!!! 发现数据类型 (dtypes) 不同 !!!\")\n",
" print(slice1.dtypes[slice1.dtypes != slice2.dtypes])\n",
" print(slice2.dtypes[slice1.dtypes != slice2.dtypes])\n",
"\n",
" else:\n",
" print(\"!!! 发现差异 (compare结果):\")\n",
" # 默认情况下compare 的列是 MultiIndex ('列名', 'self'/'other')\n",
" # 为了更清晰地显示,可以调整一下格式\n",
" # diff_compare.columns = ['_'.join(col) for col in diff_compare.columns]\n",
" print(diff_compare.columns)\n",
" print(diff_compare[diff_compare[('vol_std_5', 'self')] != diff_compare[('vol_std_5', 'other')]]) # 打印差异的头部\n",
"\n",
" # 找出哪些列存在差异\n",
" differing_columns = diff_compare.columns.get_level_values(0).unique().tolist()\n",
" print(f\"\\n存在差异的列: {differing_columns}\")\n",
"\n",
" except Exception as e:\n",
" print(f\"使用 compare() 时出错: {e}\")\n",
" else:\n",
" print(\"-\" * 30)\n",
" print(\"索引或列在对齐后仍然不匹配,无法使用 compare()。请检查对齐逻辑。\")\n",
"\n",
"get_diff(slice1, slice2)\n",
"# print(df1['trade_date'].unique().tolist()[-5:])\n"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "d56f61c4",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"set()\n",
"set()\n",
"3064\n",
"3064\n",
"3064\n",
"3064\n"
]
}
],
"source": [
"s1 = set(df1[df1['trade_date'] == date_to_compare].columns.to_list())\n",
"s2 = set(df2[df2['trade_date'] == date_to_compare].columns.to_list())\n",
"\n",
"print(s2 - s1)\n",
"print(s1 - s2)\n",
"\n",
"print(len(df1[df1['trade_date'] == date_to_compare]))\n",
"print(len(df2[df2['trade_date'] == date_to_compare]))\n",
"\n",
"print(df1[df1['trade_date'] == date_to_compare]['ts_code'].nunique())\n",
"print(df2[df2['trade_date'] == date_to_compare]['ts_code'].nunique())"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "9df2781fc6c7ae44",
"metadata": {
"ExecuteTime": {
"end_time": "2025-04-09T16:43:55.223316Z",
"start_time": "2025-04-09T16:43:53.868461Z"
},
"jupyter": {
"source_hidden": true
}
},
"outputs": [],
"source": [
"import pandas as pd\n",
"\n",
"from scipy.stats import ks_2samp, wasserstein_distance\n",
"\n",
"\n",
"def remove_shifted_features(train_data, feature_columns, ks_threshold=0.05, wasserstein_threshold=0.1, size=0.8,\n",
" log=True):\n",
" dropped_features = []\n",
"\n",
" all_dates = sorted(train_data['trade_date'].unique().tolist()) # 获取所有唯一的 trade_date\n",
"    split_date = all_dates[int(len(all_dates) * size)]  # split at the `size` fraction of unique dates (train before, validation after)\n",
" train_data_split = train_data[train_data['trade_date'] < split_date] # 训练集\n",
" val_data_split = train_data[train_data['trade_date'] >= split_date] # 验证集\n",
"\n",
" # **统计数据漂移**\n",
" numeric_columns = train_data_split.select_dtypes(include=['float64', 'int64']).columns\n",
" numeric_columns = [col for col in numeric_columns if col in feature_columns]\n",
" for feature in numeric_columns:\n",
"        # Drop NaNs first: with NaNs present ks_2samp / wasserstein_distance return nan,\n",
"        # so both threshold comparisons below are always False and drift is silently missed.\n",
"        ks_stat, p_value = ks_2samp(train_data_split[feature].dropna(), val_data_split[feature].dropna())\n",
"        wasserstein_dist = wasserstein_distance(train_data_split[feature].dropna(), val_data_split[feature].dropna())\n",
"\n",
" if p_value < ks_threshold or wasserstein_dist > wasserstein_threshold:\n",
" dropped_features.append(feature)\n",
" if log:\n",
" print(f\"检测到 {len(dropped_features)} 个可能漂移的特征: {dropped_features}\")\n",
"\n",
" # **应用阈值进行最终筛选**\n",
" filtered_features = [f for f in feature_columns if f not in dropped_features]\n",
"\n",
" return filtered_features, dropped_features\n",
"\n",
"\n",
"def remove_outliers_label_percentile(label: pd.Series, lower_percentile: float = 0.01, upper_percentile: float = 0.99,\n",
" log=True):\n",
" if not (0 <= lower_percentile < upper_percentile <= 1):\n",
" raise ValueError(\"Percentile values must satisfy 0 <= lower_percentile < upper_percentile <= 1.\")\n",
"\n",
" # Calculate lower and upper bounds based on percentiles\n",
" lower_bound = label.quantile(lower_percentile)\n",
" upper_bound = label.quantile(upper_percentile)\n",
"\n",
" # Filter out values outside the bounds\n",
" filtered_label = label[(label >= lower_bound) & (label <= upper_bound)]\n",
"\n",
" # Print the number of removed outliers\n",
" if log:\n",
" print(f\"Removed {len(label) - len(filtered_label)} outliers.\")\n",
" return filtered_label\n",
"\n",
"\n",
"def calculate_risk_adjusted_target(df, days=5):\n",
" df = df.sort_values(by=['ts_code', 'trade_date'])\n",
"\n",
" df['future_close'] = df.groupby('ts_code')['close'].shift(-days)\n",
" df['future_open'] = df.groupby('ts_code')['open'].shift(-1)\n",
" df['future_return'] = (df['future_close'] - df['future_open']) / df['future_open']\n",
"\n",
" df['future_volatility'] = df.groupby('ts_code')['future_return'].rolling(days, min_periods=1).std().reset_index(\n",
" level=0, drop=True)\n",
"    # Sharpe-style ratio: return DIVIDED by volatility. The original multiplied, which\n",
"    # contradicts the inf-handling on the next line — only division can produce inf here.\n",
"    sharpe_ratio = df['future_return'] / df['future_volatility']\n",
" sharpe_ratio.replace([np.inf, -np.inf], np.nan, inplace=True)\n",
"\n",
" return sharpe_ratio\n",
"\n",
"\n",
"def calculate_score(df, days=5, lambda_param=1.0):\n",
"    \"\"\"Score = forward return minus lambda_param * forward max drawdown.\n",
"\n",
"    Expects df to already contain 'future_return', 'pct_chg' and 'close'.\n",
"    Writes the result into df['score'] and returns that column.\n",
"    \"\"\"\n",
"    def calculate_max_drawdown(prices):\n",
"        # Classic peak-to-trough walk over the rolling price window.\n",
"        peak = prices.iloc[0]  # initialize the running peak\n",
"        max_drawdown = 0  # initialize the max drawdown\n",
"\n",
"        for price in prices:\n",
"            if price > peak:\n",
"                peak = price  # new peak\n",
"            else:\n",
"                drawdown = (peak - price) / peak  # drawdown from the peak\n",
"                max_drawdown = max(max_drawdown, drawdown)  # keep the worst one\n",
"\n",
"        return max_drawdown\n",
"\n",
"    def compute_stock_score(stock_df):\n",
"        stock_df = stock_df.sort_values(by=['trade_date'])\n",
"        future_return = stock_df['future_return']\n",
"        # Volatility from the existing pct_chg field; shift(-days) makes the\n",
"        # rolling window forward-looking.\n",
"        # NOTE(review): `volatility` is computed but never used in `score` below -- confirm intent.\n",
"        volatility = stock_df['pct_chg'].rolling(days).std().shift(-days)\n",
"        max_drawdown = stock_df['close'].rolling(days).apply(calculate_max_drawdown, raw=False).shift(-days)\n",
"        score = future_return - lambda_param * max_drawdown\n",
"        return score\n",
"\n",
"    # # Make sure the frame is sorted by stock code and trade date\n",
"    # df = df.sort_values(by=['ts_code', 'trade_date'])\n",
"\n",
"    # Compute the score separately for every stock.\n",
"    df['score'] = df.groupby('ts_code').apply(compute_stock_score).reset_index(level=0, drop=True)\n",
"\n",
"    return df['score']\n",
"\n",
"\n",
"def remove_highly_correlated_features(df, feature_columns, threshold=0.9):\n",
"    \"\"\"Drop one of each pair of numeric features whose |corr| exceeds threshold.\n",
"\n",
"    Features whose names contain 'act' or 'af' are always kept, even when\n",
"    they are highly correlated with another feature.\n",
"    \"\"\"\n",
"    numeric_features = df[feature_columns].select_dtypes(include=[np.number]).columns.tolist()\n",
"    if not numeric_features:\n",
"        raise ValueError(\"No numeric features found in the provided data.\")\n",
"\n",
"    abs_corr = df[numeric_features].corr().abs()\n",
"    # Mask down to the strict upper triangle so each pair is inspected once.\n",
"    upper_mask = np.triu(np.ones(abs_corr.shape), k=1).astype(bool)\n",
"    upper_tri = abs_corr.where(upper_mask)\n",
"\n",
"    to_drop = set()\n",
"    for column in upper_tri.columns:\n",
"        if (upper_tri[column] > threshold).any():\n",
"            to_drop.add(column)\n",
"\n",
"    kept = []\n",
"    for col in feature_columns:\n",
"        # Whitelisted name fragments are retained even when correlated.\n",
"        if col not in to_drop or 'act' in col or 'af' in col:\n",
"            kept.append(col)\n",
"    return kept\n",
"\n",
"\n",
"def cross_sectional_standardization(df, features):\n",
"    \"\"\"Z-score the given feature columns within each trade_date cross-section.\"\"\"\n",
"    # Chronological order, and work on a copy so the input frame is untouched.\n",
"    result = df.sort_values(by='trade_date').copy()\n",
"\n",
"    for date in result['trade_date'].unique():\n",
"        day_mask = result['trade_date'] == date\n",
"\n",
"        # Fit a fresh scaler on this day's cross-section only, then write the\n",
"        # standardized values back into the same rows.\n",
"        scaled = StandardScaler().fit_transform(result.loc[day_mask, features])\n",
"        result.loc[day_mask, features] = scaled\n",
"\n",
"    return result\n",
"\n",
"\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"\n",
"def neutralize_manual(df, features, industry_col, mkt_cap_col):\n",
"    \"\"\"Neutralize factor columns against log market cap within each industry.\n",
"\n",
"    For every feature, fits y = alpha + beta * log(mkt_cap) per industry group\n",
"    and replaces the factor with the regression residual (the size-neutral\n",
"    part).  Single-row groups keep their original value.\n",
"\n",
"    BUG FIXES vs. the original version:\n",
"    * residuals were gathered in groupby order and assigned positionally\n",
"      (``df[col] = list``), silently misaligning rows unless df happened to be\n",
"      sorted by industry; residuals are now assigned via index alignment.\n",
"    * the slope mixed np.cov (ddof=1) with np.var (ddof=0), inflating beta by\n",
"      n/(n-1); it is now the exact OLS estimate.\n",
"    \"\"\"\n",
"    for col in features:\n",
"        residual_parts = []\n",
"        for _, group in df.groupby(industry_col):\n",
"            if len(group) > 1:\n",
"                x = np.log(group[mkt_cap_col])  # log market cap\n",
"                y = group[col]  # factor values\n",
"                x_centered = x - x.mean()\n",
"                # OLS slope: cov(x, y) / var(x) with consistent normalization.\n",
"                beta = (x_centered * (y - y.mean())).sum() / (x_centered ** 2).sum()\n",
"                alpha = y.mean() - beta * x.mean()\n",
"                residual_parts.append(y - (alpha + beta * x))  # keeps the group's index\n",
"            else:\n",
"                residual_parts.append(group[col])  # too few samples: keep original value\n",
"\n",
"        # Concatenate and let pandas align on the original row index.\n",
"        df[col] = pd.concat(residual_parts)\n",
"\n",
"    return df\n",
"\n",
"\n",
"import gc\n",
"\n",
"gc.collect()\n",
"\n",
"\n",
"def mad_filter(df, features, n=3):\n",
"    \"\"\"Winsorize each feature to median +/- n * MAD (median absolute deviation).\"\"\"\n",
"    for col in features:\n",
"        center = df[col].median()\n",
"        mad = np.median(np.abs(df[col] - center))\n",
"        band = n * mad\n",
"        # Clip extreme values into the [center - band, center + band] range.\n",
"        df[col] = np.clip(df[col], center - band, center + band)\n",
"    return df\n",
"\n",
"\n",
"def percentile_filter(df, features, lower_percentile=0.01, upper_percentile=0.99):\n",
"    \"\"\"Winsorize features within each trade_date cross-section at the given percentiles.\"\"\"\n",
"    by_date = df.groupby('trade_date')\n",
"    for col in features:\n",
"        # Per-date percentile bounds, broadcast back to every row of that date.\n",
"        floor = by_date[col].transform(lambda s: s.quantile(lower_percentile))\n",
"        ceiling = by_date[col].transform(lambda s: s.quantile(upper_percentile))\n",
"        # Clamp values that fall outside the band.\n",
"        df[col] = np.clip(df[col], floor, ceiling)\n",
"    return df\n",
"\n",
"\n",
"from scipy.stats import iqr\n",
"\n",
"\n",
"def iqr_filter(df, features):\n",
"    \"\"\"Robust per-date scaling: (x - median) / IQR within each trade_date.\n",
"\n",
"    Dates whose IQR is zero are left unscaled to avoid division by zero.\n",
"    \"\"\"\n",
"    for col in features:\n",
"        def _scale(x):\n",
"            # Compute the IQR once per group (the original evaluated it twice).\n",
"            spread = iqr(x)\n",
"            return (x - x.median()) / spread if spread != 0 else x\n",
"        df[col] = df.groupby('trade_date')[col].transform(_scale)\n",
"    return df\n",
"\n",
"\n",
"def quantile_filter(df, features, lower_quantile=0.01, upper_quantile=0.99, window=60):\n",
"    \"\"\"Clip features to rolling quantile bounds computed within each trade_date group.\n",
"\n",
"    BUG FIX: the rolling quantile previously used the default min_periods\n",
"    (== window), so the first window-1 rows of every group had NaN bounds and\n",
"    np.clip propagated those NaNs into the data, silently destroying values.\n",
"    min_periods=1 clips early rows against the partial window instead\n",
"    (matching time_series_quantile_filter, which already sets min_periods).\n",
"    \"\"\"\n",
"    df = df.copy()\n",
"    for col in features:\n",
"        grouped = df.groupby('trade_date')[col]\n",
"        rolling_lower = grouped.transform(\n",
"            lambda x: x.rolling(window=min(len(x), window), min_periods=1).quantile(lower_quantile))\n",
"        rolling_upper = grouped.transform(\n",
"            lambda x: x.rolling(window=min(len(x), window), min_periods=1).quantile(upper_quantile))\n",
"\n",
"        # Clip each value against its rolling band.\n",
"        df[col] = np.clip(df[col], rolling_lower, rolling_upper)\n",
"\n",
"    return df\n",
"\n",
"def time_series_quantile_filter(df, features, lower_quantile=0.01, upper_quantile=0.99, window=60):\n",
"    \"\"\"Clip each stock's features to its own trailing rolling-quantile band.\n",
"\n",
"    Rows before min_periods (window // 2) observations get NaN bounds and\n",
"    therefore become NaN after clipping, exactly as before.\n",
"    \"\"\"\n",
"    df = df.copy()\n",
"    # Rolling windows only make sense on per-stock, chronologically ordered series.\n",
"    df = df.sort_values(['ts_code', 'trade_date'])\n",
"    by_stock = df.groupby('ts_code')\n",
"    for col in features:\n",
"        # transform keeps the result aligned to the original index, replacing\n",
"        # the manual reset_index dance of the earlier version.\n",
"        lower_band = by_stock[col].transform(\n",
"            lambda s: s.rolling(window=window, min_periods=window // 2).quantile(lower_quantile))\n",
"        upper_band = by_stock[col].transform(\n",
"            lambda s: s.rolling(window=window, min_periods=window // 2).quantile(upper_quantile))\n",
"        df[col] = np.clip(df[col], lower_band, upper_band)\n",
"    return df\n",
"\n",
"def cross_sectional_quantile_filter(df, features, lower_quantile=0.01, upper_quantile=0.99):\n",
"    \"\"\"Winsorize features within each daily cross-section at the given quantiles.\"\"\"\n",
"    df = df.copy()\n",
"    by_date = df.groupby('trade_date')\n",
"    for col in features:\n",
"        # Daily quantile bounds, broadcast to every row of the same date.\n",
"        floor = by_date[col].transform(lambda x: x.quantile(lower_quantile))\n",
"        ceiling = by_date[col].transform(lambda x: x.quantile(upper_quantile))\n",
"        # Clamp outliers into the daily band.\n",
"        df[col] = np.clip(df[col], floor, ceiling)\n",
"    return df"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "99f677aca6a286d0",
"metadata": {
"ExecuteTime": {
"end_time": "2025-04-09T16:54:07.250024Z",
"start_time": "2025-04-09T16:53:57.299050Z"
}
},
"outputs": [
{
"ename": "NameError",
"evalue": "name 'industry_df1' is not defined",
"output_type": "error",
"traceback": [
"\u001b[31m---------------------------------------------------------------------------\u001b[39m",
"\u001b[31mNameError\u001b[39m Traceback (most recent call last)",
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[13]\u001b[39m\u001b[32m, line 81\u001b[39m\n\u001b[32m 77\u001b[39m feature_columns = remove_highly_correlated_features(pdf, feature_columns)\n\u001b[32m 79\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m pdf, feature_columns, filter_index\n\u001b[32m---> \u001b[39m\u001b[32m81\u001b[39m pdf1, feature_columns1, filter_index1 = get_pdf(df1[df1[\u001b[33m'\u001b[39m\u001b[33mtrade_date\u001b[39m\u001b[33m'\u001b[39m] >= \u001b[33m'\u001b[39m\u001b[33m2025-04-01\u001b[39m\u001b[33m'\u001b[39m], \u001b[43mindustry_df1\u001b[49m)\n\u001b[32m 82\u001b[39m pdf2, feature_columns2, filter_index2 = get_pdf(df2[df2[\u001b[33m'\u001b[39m\u001b[33mtrade_date\u001b[39m\u001b[33m'\u001b[39m] >= \u001b[33m'\u001b[39m\u001b[33m2025-04-01\u001b[39m\u001b[33m'\u001b[39m], industry_df2)\n\u001b[32m 84\u001b[39m \u001b[38;5;66;03m# date_to_compare = '2025-04-07'\u001b[39;00m\n",
"\u001b[31mNameError\u001b[39m: name 'industry_df1' is not defined"
]
}
],
"source": [
"def get_pdf(df, industry_df):\n",
"    \"\"\"Build the modeling frame for one data source.\n",
"\n",
"    Selects the daily candidate pool, joins industry features, winsorizes\n",
"    numeric features cross-sectionally and derives the usable feature list.\n",
"    Returns (pdf, feature_columns, filter_index) where filter_index masks\n",
"    rows whose future_return lies inside its 1%-99% quantile band.\n",
"\n",
"    NOTE(review): relies on df already carrying 'return_20', 'label',\n",
"    'future_return' and 'cat_l2_code' columns produced in earlier cells.\n",
"    \"\"\"\n",
"    origin_columns = df.columns.tolist()\n",
"    origin_columns = [col for col in origin_columns if\n",
"                      col not in ['turnover_rate', 'pe_ttm', 'volume_ratio', 'vol', 'pct_chg', 'l2_code', 'winner_rate']]\n",
"    origin_columns = [col for col in origin_columns if 'cyq' not in col]\n",
"\n",
"    days = 2\n",
"    # df = df.sort_values(by=['ts_code', 'trade_date'])\n",
"    # # df['future_return'] = df.groupby('ts_code', group_keys=False)['close'].apply(lambda x: x.shift(-days) / x - 1)\n",
"    # df['future_return'] = (df.groupby('ts_code')['close'].shift(-days) - df.groupby('ts_code')['open'].shift(-1)) / \\\n",
"    #                        df.groupby('ts_code')['open'].shift(-1)\n",
"    # df['future_score'] = calculate_score(df, days=2, lambda_param=0.3)\n",
"    # df['label'] = df.groupby('trade_date', group_keys=False)['future_score'].transform(\n",
"    #     lambda x: pd.qcut(x, q=20, labels=False, duplicates='drop')\n",
"    # )\n",
"    # df['label'] = df.groupby('trade_date', group_keys=False)['future_score'].transform(\n",
"    #     lambda x: pd.qcut(x.rank(method='first'), q=20, labels=False, duplicates='raise')\n",
"    # )\n",
"    # df['future_score'] = (\n",
"    #         0.7 * df['future_return']\n",
"    #         * 0.3 * df['future_volatility']\n",
"    # )\n",
"\n",
"    # Pick up to 150 top-momentum names per day, growing the pool in steps of\n",
"    # 10 until 20 distinct label buckets appear.\n",
"    def select_pre_zt_stocks_dynamic(stock_df):\n",
"        def select_stocks(group):\n",
"            max_stocks = 150\n",
"            initial_data = group.nlargest(100, 'return_20')\n",
"            unique_labels = initial_data['label'].nunique()\n",
"\n",
"            print(group['trade_date'].unique().tolist(), initial_data['label'].nunique(), initial_data['label'].unique())\n",
"            if unique_labels >= 20 or unique_labels == 0: # also covers the zero-label-kinds case\n",
"                return initial_data\n",
"\n",
"            for i in range(110, max_stocks + 1, 10):\n",
"                data = group.nlargest(i, 'return_20')\n",
"                unique_labels = data['label'].nunique()\n",
"                if unique_labels >= 20:\n",
"                    return data\n",
"\n",
"            return group.nlargest(max_stocks, 'return_20') # loop exhausted without enough labels: return the max-size pool\n",
"\n",
"        stock_df = stock_df.groupby('trade_date', group_keys=False).apply(select_stocks)\n",
"        return stock_df\n",
"\n",
"\n",
"    pdf = select_pre_zt_stocks_dynamic(df[(df['trade_date'] >= '2022-01-01') & (df['trade_date'] <= '2029-04-07')])\n",
"    print(pdf['trade_date'].max())\n",
"\n",
"    pdf = pdf.merge(industry_df, on=['cat_l2_code', 'trade_date'], how='left')\n",
"    pdf = pdf.replace([np.inf, -np.inf], np.nan)\n",
"\n",
"    # NOTE(review): 'col in pdf.columns' is always True here, so this first\n",
"    # line is just a copy of the column list.\n",
"    feature_columns = [col for col in pdf.columns if col in pdf.columns]\n",
"    feature_columns = [col for col in feature_columns if col not in ['trade_date',\n",
"                                                                     'ts_code',\n",
"                                                                     'label']]\n",
"    feature_columns = [col for col in feature_columns if 'future' not in col]\n",
"    feature_columns = [col for col in feature_columns if 'label' not in col]\n",
"    feature_columns = [col for col in feature_columns if 'score' not in col]\n",
"    feature_columns = [col for col in feature_columns if 'gen' not in col]\n",
"    feature_columns = [col for col in feature_columns if 'pe_ttm' not in col]\n",
"    feature_columns = [col for col in feature_columns if 'volatility' not in col]\n",
"    feature_columns = [col for col in feature_columns if 'cat_l2_code' not in col]\n",
"    feature_columns = [col for col in feature_columns if col not in origin_columns]\n",
"    feature_columns = [col for col in feature_columns if not col.startswith('_')]\n",
"    # feature_columns = [col for col in feature_columns if col not in ['ts_code', 'trade_date', 'vol_std_5', 'cov', 'delta_cov', 'alpha_22_improved', 'alpha_007', 'consecutive_up_limit', 'mv_volatility', 'volume_growth', 'mv_growth', 'arbr']]\n",
"\n",
"    numeric_columns = pdf.select_dtypes(include=['float64', 'int64']).columns\n",
"    numeric_columns = [col for col in numeric_columns if col in feature_columns]\n",
"\n",
"    pdf = cross_sectional_quantile_filter(pdf, numeric_columns)\n",
"    # pdf = cross_sectional_standardization(pdf, numeric_columns)\n",
"\n",
"    pdf = pdf.sort_values(by=['ts_code', 'trade_date'])\n",
"\n",
"    filter_index = pdf['future_return'].between(pdf['future_return'].quantile(0.01), pdf['future_return'].quantile(0.99))\n",
"\n",
"    feature_columns = remove_highly_correlated_features(pdf, feature_columns)\n",
"\n",
"    return pdf, feature_columns, filter_index\n",
"\n",
"# NOTE(review): df1/df2, industry_df1/industry_df2, date_to_compare and\n",
"# get_diff are defined in other cells; running this cell standalone raises\n",
"# NameError (as recorded in this cell's saved output).\n",
"pdf1, feature_columns1, filter_index1 = get_pdf(df1[df1['trade_date'] >= '2025-04-01'], industry_df1)\n",
"pdf2, feature_columns2, filter_index2 = get_pdf(df2[df2['trade_date'] >= '2025-04-01'], industry_df2)\n",
"\n",
"# date_to_compare = '2025-04-07'\n",
"slice1 = pdf1[pdf1['trade_date'] == date_to_compare]\n",
"slice2 = pdf2[pdf2['trade_date'] == date_to_compare]\n",
"get_diff(slice1, slice2)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1b863e4115252d2d",
"metadata": {
"jupyter": {
"source_hidden": true
}
},
"outputs": [],
"source": [
"from sklearn.preprocessing import StandardScaler\n",
"import lightgbm as lgb\n",
"import matplotlib.pyplot as plt\n",
"\n",
"def train_light_model(train_data_df, params, feature_columns, callbacks, evals,\n",
"                      print_feature_importance=True, num_boost_round=100,\n",
"                      validation_days=180, use_pca=False, split_date=None): # new parameter: validation_days\n",
"    \"\"\"Train a LightGBM learning-to-rank model on a time-ordered split.\n",
"\n",
"    The last `validation_days` distinct trade dates form the validation set\n",
"    (validation_days == 0 trains on everything with no validation set).\n",
"\n",
"    NOTE(review): `params['weight']` is filled with per-row weights below, but\n",
"    LightGBM sample weights are normally supplied via lgb.Dataset(weight=...);\n",
"    confirm this params entry is actually honored by lgb.train.\n",
"    NOTE(review): `scaler` is instantiated but never fitted (the fit/transform\n",
"    lines are commented out) and is returned untouched.\n",
"    \"\"\"\n",
"    # Make sure rows are in chronological order.\n",
"    train_data_df = train_data_df.sort_values(by='trade_date')\n",
"\n",
"    numeric_columns = train_data_df.select_dtypes(include=['float64', 'int64']).columns\n",
"    numeric_columns = [col for col in numeric_columns if col in feature_columns]\n",
"    # X_train.loc[:, numeric_columns] = scaler.fit_transform(X_train[numeric_columns])\n",
"    # X_val.loc[:, numeric_columns] = scaler.transform(X_val[numeric_columns])\n",
"    # train_data_df = cross_sectional_standardization(train_data_df, numeric_columns)\n",
"\n",
"    # Drop samples with a missing label.\n",
"    train_data_df = train_data_df.dropna(subset=['label'])\n",
"    # print('原始训练集大小: ', len(train_data_df))\n",
"\n",
"    # Split train/validation chronologically.\n",
"    if split_date is None:\n",
"        all_dates = train_data_df['trade_date'].unique() # all unique trade dates\n",
"        if validation_days == 0:\n",
"            split_date = all_dates[-1]\n",
"        else:\n",
"            split_date = all_dates[-validation_days] # split at the validation_days-th date from the end\n",
"    if validation_days == 0:\n",
"        train_data_split = train_data_df\n",
"    else:\n",
"        train_data_split = train_data_df[train_data_df['trade_date'] < split_date] # training set\n",
"    val_data_split = train_data_df[train_data_df['trade_date'] >= split_date] # validation set\n",
"\n",
"    # Report the split sizes.\n",
"    print(f\"划分后的训练集大小: {len(train_data_split)}, 验证集大小: {len(val_data_split)}\")\n",
"\n",
"    # Extract features and label.\n",
"    X_train = train_data_split[feature_columns]\n",
"    y_train = train_data_split['label']\n",
"\n",
"    # Scaler for numeric features (currently unused -- see NOTE in docstring).\n",
"    scaler = StandardScaler()\n",
"\n",
"    # Per-trade_date sample counts (learning-to-rank needs group sizes).\n",
"    train_groups = train_data_split.groupby('trade_date').size().tolist()\n",
"    val_groups = val_data_split.groupby('trade_date').size().tolist()\n",
"\n",
"    # Categorical features, identified by naming convention.\n",
"    categorical_feature = [col for col in feature_columns if 'cat' in col]\n",
"\n",
"    # Time-based sample weights: later dates weigh more, quadratically (1..100).\n",
"    # trade_date = train_data_split['trade_date'] # 交易日期\n",
"    # weights = (trade_date - trade_date.min()).dt.days / (trade_date.max() - trade_date.min()).days + 1\n",
"    # weights = train_data_split.groupby('trade_date')['std_return_5'].transform(\n",
"    #     lambda x: x / x.mean()\n",
"    # )\n",
"    ud = sorted(train_data_split[\"trade_date\"].unique().tolist())\n",
"    date_weights = {date: weight * weight for date, weight in zip(ud, np.linspace(1, 10, len(ud)))}\n",
"    params['weight'] = train_data_split[\"trade_date\"].map(date_weights).tolist()\n",
"\n",
"    train_dataset = lgb.Dataset(\n",
"        X_train, label=y_train, group=train_groups,\n",
"        categorical_feature=categorical_feature\n",
"    )\n",
"\n",
"    if validation_days > 0:\n",
"        X_val = val_data_split[feature_columns]\n",
"        y_val = val_data_split['label']\n",
"        # NOTE(review): val_groups was already computed above; this recomputation is redundant.\n",
"        val_groups = val_data_split.groupby('trade_date').size().tolist()\n",
"        val_dataset = lgb.Dataset(\n",
"            X_val, label=y_val, group=val_groups,\n",
"            categorical_feature=categorical_feature\n",
"        )\n",
"        # Train with both train and validation sets tracked.\n",
"        model = lgb.train(\n",
"            params, train_dataset, num_boost_round=num_boost_round,\n",
"            valid_sets=[train_dataset, val_dataset], valid_names=['train', 'valid'],\n",
"            callbacks=callbacks\n",
"        )\n",
"    else:\n",
"        model = lgb.train(\n",
"            params, train_dataset, num_boost_round=num_boost_round, callbacks=callbacks\n",
"        )\n",
"\n",
"    # Plot metric curves and feature importances when requested.\n",
"    if print_feature_importance:\n",
"        lgb.plot_metric(evals)\n",
"        lgb.plot_importance(model, importance_type='split', max_num_features=20)\n",
"        plt.show()\n",
"\n",
"    return model, scaler, None\n",
"\n",
"def rolling_train_predict(df, train_days, test_days, feature_columns_origin, days=5, use_pca=False, validation_days=60,\n",
"                          filter_index=None, params=None):\n",
"    \"\"\"Walk-forward training: slide a train_days window, predict the next test_days.\n",
"\n",
"    For every window the feature list is re-screened for drift, a fresh\n",
"    ranking model is trained, and the single top-scoring stock per test date\n",
"    is kept.  Returns a frame with columns ['trade_date', 'score', 'ts_code'].\n",
"\n",
"    NOTE(review): the default filter_index=None would raise on\n",
"    `filter_index & ...` -- callers must pass a boolean mask.  `params` is\n",
"    mutated in place (label_gain is overwritten each window).\n",
"    \"\"\"\n",
"    # 1. Collect the sorted list of trade dates in scope.\n",
"    unique_dates = df[df['trade_date'] >= '2020-01-01']['trade_date'].unique().tolist()\n",
"    unique_dates = sorted(unique_dates)\n",
"    n = len(unique_dates)\n",
"\n",
"    # 2. Skip enough leading days so the rolling windows line up with the end.\n",
"    extra_days = (n - train_days) % test_days\n",
"    start_index = extra_days # start rolling from this index\n",
"\n",
"    predictions_list = []\n",
"\n",
"    for start in range(start_index, n - train_days - test_days + 1, test_days):\n",
"\n",
"        train_dates = unique_dates[start: start + train_days]\n",
"        test_dates = unique_dates[start + train_days: start + train_days + test_days]\n",
"\n",
"        # Slice the data by date window (train rows must also pass filter_index).\n",
"        # train_data = df[df['trade_date'].isin(train_dates)]\n",
"        train_data = df[filter_index & df['trade_date'].isin(train_dates)]\n",
"        test_data = df[df['trade_date'].isin(test_dates)]\n",
"\n",
"        train_data = train_data.sort_values('trade_date')\n",
"        test_data = test_data.sort_values('trade_date')\n",
"\n",
"        # Re-screen the feature list for drift on this window only.\n",
"        feature_columns, _ = remove_shifted_features(train_data, feature_columns_origin, size=0.8, log=False)\n",
"\n",
"        train_data = train_data.dropna(subset=feature_columns)\n",
"        train_data = train_data.dropna(subset=['label'])\n",
"        train_data = train_data.reset_index(drop=True)\n",
"\n",
"        # print(test_data.tail())\n",
"        test_data = test_data.dropna(subset=feature_columns)\n",
"        # test_data = test_data.dropna(subset=['label'])\n",
"        test_data = test_data.reset_index(drop=True)\n",
"\n",
"        # print(len(train_data))\n",
"        # print(f\"最小日期: {train_data['trade_date'].min().strftime('%Y-%m-%d')}\")\n",
"        print(f\"train_data最大日期: {train_data['trade_date'].max().strftime('%Y-%m-%d')}\")\n",
"        # # print(len(test_data))\n",
"        # print(f\"最小日期: {test_data['trade_date'].min().strftime('%Y-%m-%d')}\")\n",
"        print(f\"test_data最大日期: {test_data['trade_date'].max().strftime('%Y-%m-%d')}\")\n",
"\n",
"        # Cast 'cat*' columns to pandas category dtype for LightGBM.\n",
"        cat_columns = [col for col in df.columns if col.startswith('cat')]\n",
"        for col in cat_columns:\n",
"            train_data[col] = train_data[col].astype('category')\n",
"            test_data[col] = test_data[col].astype('category')\n",
"\n",
"        # Quadratic label gains for lambdarank, one per distinct label value.\n",
"        label_gain = list(range(len(train_data['label'].unique())))\n",
"        label_gain = [(gain + 1) * (gain + 1) for gain in label_gain]\n",
"        params['label_gain'] = label_gain\n",
"\n",
"        # ud = train_data[\"trade_date\"].unique()\n",
"        # date_weights = {date: weight for date, weight in zip(ud, np.linspace(1, 2, len(unique_dates)))}\n",
"        # params['weight'] = train_data[\"trade_date\"].map(date_weights).tolist()\n",
"\n",
"        # print(f'feature_columns: {feature_columns}')\n",
"        # feature_contri = [2 if feat.startswith('act_factor') else 1 for feat in feature_columns]\n",
"        # params['feature_contri'] = feature_contri\n",
"        evals = {}\n",
"        model, _, _ = train_light_model(train_data.dropna(subset=['label']),\n",
"                                        params, feature_columns,\n",
"                                        [lgb.log_evaluation(period=100),\n",
"                                         lgb.callback.record_evaluation(evals),\n",
"                                         # lgb.early_stopping(100, first_metric_only=True)\n",
"                                         ], evals,\n",
"                                        num_boost_round=100, validation_days=validation_days,\n",
"                                        print_feature_importance=False, use_pca=False)\n",
"\n",
"        # Keep only the single highest-scoring stock per test date.\n",
"        score_df = test_data.copy()\n",
"        score_df['score'] = model.predict(score_df[feature_columns])\n",
"        score_df = score_df.loc[score_df.groupby('trade_date')['score'].idxmax()]\n",
"        score_df = score_df[['trade_date', 'score', 'ts_code']]\n",
"        predictions_list.append(score_df)\n",
"\n",
"    final_predictions = pd.concat(predictions_list, ignore_index=True)\n",
"    return final_predictions"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ddb5b67a9852e2",
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"train_data最大日期: 2022-12-07\n",
"test_data最大日期: 2022-12-08\n",
"划分后的训练集大小: 525, 验证集大小: 106\n",
"train_data最大日期: 2022-12-08\n",
"test_data最大日期: 2022-12-09\n",
"划分后的训练集大小: 531, 验证集大小: 109\n",
"train_data最大日期: 2022-12-09\n",
"test_data最大日期: 2022-12-12\n",
"划分后的训练集大小: 516, 验证集大小: 100\n",
"train_data最大日期: 2022-12-12\n",
"test_data最大日期: 2022-12-13\n",
"划分后的训练集大小: 528, 验证集大小: 108\n",
"train_data最大日期: 2022-12-13\n",
"test_data最大日期: 2022-12-14\n",
"划分后的训练集大小: 571, 验证集大小: 148\n",
"train_data最大日期: 2022-12-14\n",
"test_data最大日期: 2022-12-15\n",
"划分后的训练集大小: 565, 验证集大小: 100\n",
"train_data最大日期: 2022-12-15\n",
"test_data最大日期: 2022-12-16\n",
"划分后的训练集大小: 600, 验证集大小: 144\n",
"train_data最大日期: 2022-12-16\n",
"test_data最大日期: 2022-12-19\n",
"划分后的训练集大小: 597, 验证集大小: 97\n",
"train_data最大日期: 2022-12-19\n",
"test_data最大日期: 2022-12-20\n",
"划分后的训练集大小: 633, 验证集大小: 144\n",
"train_data最大日期: 2022-12-20\n",
"test_data最大日期: 2022-12-21\n",
"划分后的训练集大小: 627, 验证集大小: 142\n",
"train_data最大日期: 2022-12-21\n",
"test_data最大日期: 2022-12-22\n",
"划分后的训练集大小: 624, 验证集大小: 97\n",
"train_data最大日期: 2022-12-22\n",
"test_data最大日期: 2022-12-23\n",
"划分后的训练集大小: 605, 验证集大小: 125\n",
"train_data最大日期: 2022-12-23\n",
"test_data最大日期: 2022-12-26\n",
"划分后的训练集大小: 603, 验证集大小: 95\n",
"train_data最大日期: 2022-12-26\n",
"test_data最大日期: 2022-12-27\n",
"划分后的训练集大小: 558, 验证集大小: 99\n",
"train_data最大日期: 2022-12-27\n",
"test_data最大日期: 2022-12-28\n",
"划分后的训练集大小: 510, 验证集大小: 94\n",
"train_data最大日期: 2022-12-28\n",
"test_data最大日期: 2022-12-29\n",
"划分后的训练集大小: 508, 验证集大小: 95\n",
"train_data最大日期: 2022-12-29\n",
"test_data最大日期: 2022-12-30\n",
"划分后的训练集大小: 528, 验证集大小: 145\n",
"train_data最大日期: 2022-12-30\n",
"test_data最大日期: 2023-01-03\n",
"划分后的训练集大小: 570, 验证集大小: 137\n",
"train_data最大日期: 2023-01-03\n",
"test_data最大日期: 2023-01-04\n",
"划分后的训练集大小: 618, 验证集大小: 147\n",
"train_data最大日期: 2023-01-04\n",
"test_data最大日期: 2023-01-05\n",
"划分后的训练集大小: 666, 验证集大小: 142\n",
"train_data最大日期: 2023-01-05\n",
"test_data最大日期: 2023-01-06\n",
"划分后的训练集大小: 717, 验证集大小: 146\n",
"train_data最大日期: 2023-01-06\n",
"test_data最大日期: 2023-01-09\n",
"划分后的训练集大小: 670, 验证集大小: 98\n",
"train_data最大日期: 2023-01-09\n",
"test_data最大日期: 2023-01-10\n",
"划分后的训练集大小: 630, 验证集大小: 97\n",
"train_data最大日期: 2023-01-10\n",
"test_data最大日期: 2023-01-11\n",
"划分后的训练集大小: 589, 验证集大小: 106\n",
"train_data最大日期: 2023-01-11\n",
"test_data最大日期: 2023-01-12\n",
"划分后的训练集大小: 543, 验证集大小: 96\n",
"train_data最大日期: 2023-01-12\n",
"test_data最大日期: 2023-01-13\n",
"划分后的训练集大小: 544, 验证集大小: 147\n",
"train_data最大日期: 2023-01-13\n",
"test_data最大日期: 2023-01-16\n",
"划分后的训练集大小: 553, 验证集大小: 107\n",
"train_data最大日期: 2023-01-16\n",
"test_data最大日期: 2023-01-17\n",
"划分后的训练集大小: 573, 验证集大小: 117\n",
"train_data最大日期: 2023-01-17\n",
"test_data最大日期: 2023-01-18\n",
"划分后的训练集大小: 604, 验证集大小: 137\n",
"train_data最大日期: 2023-01-18\n",
"test_data最大日期: 2023-01-19\n",
"划分后的训练集大小: 625, 验证集大小: 117\n",
"train_data最大日期: 2023-01-19\n",
"test_data最大日期: 2023-01-20\n",
"划分后的训练集大小: 616, 验证集大小: 138\n",
"train_data最大日期: 2023-01-20\n",
"test_data最大日期: 2023-01-30\n",
"划分后的训练集大小: 609, 验证集大小: 100\n",
"train_data最大日期: 2023-01-30\n",
"test_data最大日期: 2023-01-31\n",
"划分后的训练集大小: 621, 验证集大小: 129\n",
"train_data最大日期: 2023-01-31\n",
"test_data最大日期: 2023-02-01\n",
"划分后的训练集大小: 584, 验证集大小: 100\n",
"train_data最大日期: 2023-02-01\n",
"test_data最大日期: 2023-02-02\n",
"划分后的训练集大小: 583, 验证集大小: 116\n",
"train_data最大日期: 2023-02-02\n",
"test_data最大日期: 2023-02-03\n",
"划分后的训练集大小: 553, 验证集大小: 108\n",
"train_data最大日期: 2023-02-03\n",
"test_data最大日期: 2023-02-06\n",
"划分后的训练集大小: 581, 验证集大小: 128\n",
"train_data最大日期: 2023-02-06\n",
"test_data最大日期: 2023-02-07\n",
"划分后的训练集大小: 572, 验证集大小: 120\n",
"train_data最大日期: 2023-02-07\n",
"test_data最大日期: 2023-02-08\n",
"划分后的训练集大小: 622, 验证集大小: 150\n",
"train_data最大日期: 2023-02-08\n",
"test_data最大日期: 2023-02-09\n",
"划分后的训练集大小: 656, 验证集大小: 150\n",
"train_data最大日期: 2023-02-09\n",
"test_data最大日期: 2023-02-10\n",
"划分后的训练集大小: 697, 验证集大小: 149\n",
"train_data最大日期: 2023-02-10\n",
"test_data最大日期: 2023-02-13\n",
"划分后的训练集大小: 698, 验证集大小: 129\n",
"train_data最大日期: 2023-02-13\n",
"test_data最大日期: 2023-02-14\n",
"划分后的训练集大小: 717, 验证集大小: 139\n",
"train_data最大日期: 2023-02-14\n",
"test_data最大日期: 2023-02-15\n",
"划分后的训练集大小: 715, 验证集大小: 148\n",
"train_data最大日期: 2023-02-15\n",
"test_data最大日期: 2023-02-16\n",
"划分后的训练集大小: 714, 验证集大小: 149\n",
"train_data最大日期: 2023-02-16\n",
"test_data最大日期: 2023-02-17\n",
"划分后的训练集大小: 713, 验证集大小: 148\n",
"train_data最大日期: 2023-02-17\n",
"test_data最大日期: 2023-02-20\n",
"划分后的训练集大小: 682, 验证集大小: 98\n",
"train_data最大日期: 2023-02-20\n",
"test_data最大日期: 2023-02-21\n",
"划分后的训练集大小: 681, 验证集大小: 138\n",
"train_data最大日期: 2023-02-21\n",
"test_data最大日期: 2023-02-22\n",
"划分后的训练集大小: 632, 验证集大小: 99\n",
"train_data最大日期: 2023-02-22\n",
"test_data最大日期: 2023-02-23\n",
"划分后的训练集大小: 619, 验证集大小: 136\n",
"train_data最大日期: 2023-02-23\n",
"test_data最大日期: 2023-02-24\n",
"划分后的训练集大小: 571, 验证集大小: 100\n",
"train_data最大日期: 2023-02-24\n",
"test_data最大日期: 2023-02-27\n",
"划分后的训练集大小: 621, 验证集大小: 148\n",
"train_data最大日期: 2023-02-27\n",
"test_data最大日期: 2023-02-28\n",
"划分后的训练集大小: 632, 验证集大小: 149\n",
"train_data最大日期: 2023-02-28\n",
"test_data最大日期: 2023-03-01\n",
"划分后的训练集大小: 632, 验证集大小: 99\n",
"train_data最大日期: 2023-03-01\n",
"test_data最大日期: 2023-03-02\n",
"划分后的训练集大小: 596, 验证集大小: 100\n",
"train_data最大日期: 2023-03-02\n",
"test_data最大日期: 2023-03-03\n",
"划分后的训练集大小: 595, 验证集大小: 99\n",
"train_data最大日期: 2023-03-03\n",
"test_data最大日期: 2023-03-06\n",
"划分后的训练集大小: 596, 验证集大小: 149\n",
"train_data最大日期: 2023-03-06\n",
"test_data最大日期: 2023-03-07\n",
"划分后的训练集大小: 547, 验证集大小: 100\n",
"train_data最大日期: 2023-03-07\n",
"test_data最大日期: 2023-03-08\n",
"划分后的训练集大小: 567, 验证集大小: 119\n",
"train_data最大日期: 2023-03-08\n",
"test_data最大日期: 2023-03-09\n",
"划分后的训练集大小: 585, 验证集大小: 118\n",
"train_data最大日期: 2023-03-09\n",
"test_data最大日期: 2023-03-10\n",
"划分后的训练集大小: 634, 验证集大小: 148\n",
"train_data最大日期: 2023-03-10\n",
"test_data最大日期: 2023-03-13\n",
"划分后的训练集大小: 630, 验证集大小: 145\n",
"train_data最大日期: 2023-03-13\n",
"test_data最大日期: 2023-03-14\n",
"划分后的训练集大小: 638, 验证集大小: 108\n",
"train_data最大日期: 2023-03-14\n",
"test_data最大日期: 2023-03-15\n",
"划分后的训练集大小: 665, 验证集大小: 146\n",
"train_data最大日期: 2023-03-15\n",
"test_data最大日期: 2023-03-16\n",
"划分后的训练集大小: 677, 验证集大小: 130\n",
"train_data最大日期: 2023-03-16\n",
"test_data最大日期: 2023-03-17\n",
"划分后的训练集大小: 678, 验证集大小: 149\n",
"train_data最大日期: 2023-03-17\n",
"test_data最大日期: 2023-03-20\n",
"划分后的训练集大小: 642, 验证集大小: 109\n",
"train_data最大日期: 2023-03-20\n",
"test_data最大日期: 2023-03-21\n",
"划分后的训练集大小: 663, 验证集大小: 129\n",
"train_data最大日期: 2023-03-21\n",
"test_data最大日期: 2023-03-22\n",
"划分后的训练集大小: 615, 验证集大小: 98\n",
"train_data最大日期: 2023-03-22\n",
"test_data最大日期: 2023-03-23\n",
"划分后的训练集大小: 633, 验证集大小: 148\n",
"train_data最大日期: 2023-03-23\n",
"test_data最大日期: 2023-03-24\n",
"划分后的训练集大小: 627, 验证集大小: 143\n",
"train_data最大日期: 2023-03-24\n",
"test_data最大日期: 2023-03-27\n",
"划分后的训练集大小: 646, 验证集大小: 128\n",
"train_data最大日期: 2023-03-27\n",
"test_data最大日期: 2023-03-28\n",
"划分后的训练集大小: 615, 验证集大小: 98\n",
"train_data最大日期: 2023-03-28\n",
"test_data最大日期: 2023-03-29\n",
"划分后的训练集大小: 644, 验证集大小: 127\n",
"train_data最大日期: 2023-03-29\n",
"test_data最大日期: 2023-03-30\n",
"划分后的训练集大小: 623, 验证集大小: 127\n",
"train_data最大日期: 2023-03-30\n",
"test_data最大日期: 2023-03-31\n",
"划分后的训练集大小: 577, 验证集大小: 97\n",
"train_data最大日期: 2023-03-31\n",
"test_data最大日期: 2023-04-03\n",
"划分后的训练集大小: 595, 验证集大小: 146\n",
"train_data最大日期: 2023-04-03\n",
"test_data最大日期: 2023-04-04\n",
"划分后的训练集大小: 644, 验证集大小: 147\n",
"train_data最大日期: 2023-04-04\n",
"test_data最大日期: 2023-04-06\n",
"划分后的训练集大小: 632, 验证集大小: 115\n",
"train_data最大日期: 2023-04-06\n",
"test_data最大日期: 2023-04-07\n",
"划分后的训练集大小: 651, 验证集大小: 146\n",
"train_data最大日期: 2023-04-07\n",
"test_data最大日期: 2023-04-10\n",
"划分后的训练集大小: 702, 验证集大小: 148\n",
"train_data最大日期: 2023-04-10\n",
"test_data最大日期: 2023-04-11\n",
"划分后的训练集大小: 701, 验证集大小: 145\n",
"train_data最大日期: 2023-04-11\n",
"test_data最大日期: 2023-04-12\n",
"划分后的训练集大小: 672, 验证集大小: 118\n",
"train_data最大日期: 2023-04-12\n",
"test_data最大日期: 2023-04-13\n",
"划分后的训练集大小: 694, 验证集大小: 137\n",
"train_data最大日期: 2023-04-13\n",
"test_data最大日期: 2023-04-14\n",
"划分后的训练集大小: 695, 验证集大小: 147\n",
"train_data最大日期: 2023-04-14\n",
"test_data最大日期: 2023-04-17\n",
"划分后的训练集大小: 684, 验证集大小: 137\n",
"train_data最大日期: 2023-04-17\n",
"test_data最大日期: 2023-04-18\n",
"划分后的训练集大小: 638, 验证集大小: 99\n",
"train_data最大日期: 2023-04-18\n",
"test_data最大日期: 2023-04-19\n",
"划分后的训练集大小: 649, 验证集大小: 129\n",
"train_data最大日期: 2023-04-19\n",
"test_data最大日期: 2023-04-20\n",
"划分后的训练集大小: 610, 验证集大小: 98\n",
"train_data最大日期: 2023-04-20\n",
"test_data最大日期: 2023-04-21\n",
"划分后的训练集大小: 611, 验证集大小: 148\n",
"train_data最大日期: 2023-04-21\n",
"test_data最大日期: 2023-04-24\n",
"划分后的训练集大小: 610, 验证集大小: 136\n",
"train_data最大日期: 2023-04-24\n",
"test_data最大日期: 2023-04-25\n",
"划分后的训练集大小: 657, 验证集大小: 146\n",
"train_data最大日期: 2023-04-25\n",
"test_data最大日期: 2023-04-26\n",
"划分后的训练集大小: 675, 验证集大小: 147\n",
"train_data最大日期: 2023-04-26\n",
"test_data最大日期: 2023-04-27\n",
"划分后的训练集大小: 677, 验证集大小: 100\n",
"train_data最大日期: 2023-04-27\n",
"test_data最大日期: 2023-04-28\n",
"划分后的训练集大小: 653, 验证集大小: 124\n",
"train_data最大日期: 2023-04-28\n",
"test_data最大日期: 2023-05-04\n",
"划分后的训练集大小: 664, 验证集大小: 147\n",
"train_data最大日期: 2023-05-04\n",
"test_data最大日期: 2023-05-05\n",
"划分后的训练集大小: 636, 验证集大小: 118\n",
"train_data最大日期: 2023-05-05\n",
"test_data最大日期: 2023-05-08\n",
"划分后的训练集大小: 637, 验证集大小: 148\n",
"train_data最大日期: 2023-05-08\n",
"test_data最大日期: 2023-05-09\n",
"划分后的训练集大小: 685, 验证集大小: 148\n",
"train_data最大日期: 2023-05-09\n",
"test_data最大日期: 2023-05-10\n",
"划分后的训练集大小: 658, 验证集大小: 97\n",
"train_data最大日期: 2023-05-10\n",
"test_data最大日期: 2023-05-11\n",
"划分后的训练集大小: 638, 验证集大小: 127\n",
"train_data最大日期: 2023-05-11\n",
"test_data最大日期: 2023-05-12\n",
"划分后的训练集大小: 666, 验证集大小: 146\n",
"train_data最大日期: 2023-05-12\n",
"test_data最大日期: 2023-05-15\n",
"划分后的训练集大小: 664, 验证集大小: 146\n",
"train_data最大日期: 2023-05-15\n",
"test_data最大日期: 2023-05-16\n",
"划分后的训练集大小: 621, 验证集大小: 105\n",
"train_data最大日期: 2023-05-16\n",
"test_data最大日期: 2023-05-17\n",
"划分后的训练集大小: 623, 验证集大小: 99\n",
"train_data最大日期: 2023-05-17\n",
"test_data最大日期: 2023-05-18\n",
"划分后的训练集大小: 606, 验证集大小: 110\n",
"train_data最大日期: 2023-05-18\n",
"test_data最大日期: 2023-05-19\n",
"划分后的训练集大小: 578, 验证集大小: 118\n",
"train_data最大日期: 2023-05-19\n",
"test_data最大日期: 2023-05-22\n",
"划分后的训练集大小: 540, 验证集大小: 108\n",
"train_data最大日期: 2023-05-22\n",
"test_data最大日期: 2023-05-23\n",
"划分后的训练集大小: 532, 验证集大小: 97\n",
"train_data最大日期: 2023-05-23\n",
"test_data最大日期: 2023-05-24\n",
"划分后的训练集大小: 559, 验证集大小: 126\n",
"train_data最大日期: 2023-05-24\n",
"test_data最大日期: 2023-05-25\n",
"划分后的训练集大小: 548, 验证集大小: 99\n",
"train_data最大日期: 2023-05-25\n",
"test_data最大日期: 2023-05-26\n",
"划分后的训练集大小: 526, 验证集大小: 96\n",
"train_data最大日期: 2023-05-26\n",
"test_data最大日期: 2023-05-29\n",
"划分后的训练集大小: 516, 验证集大小: 98\n",
"train_data最大日期: 2023-05-29\n",
"test_data最大日期: 2023-05-30\n",
"划分后的训练集大小: 527, 验证集大小: 108\n",
"train_data最大日期: 2023-05-30\n",
"test_data最大日期: 2023-05-31\n",
"划分后的训练集大小: 546, 验证集大小: 145\n",
"train_data最大日期: 2023-05-31\n",
"test_data最大日期: 2023-06-01\n",
"划分后的训练集大小: 594, 验证集大小: 147\n",
"train_data最大日期: 2023-06-01\n",
"test_data最大日期: 2023-06-02\n",
"划分后的训练集大小: 616, 验证集大小: 118\n",
"train_data最大日期: 2023-06-02\n",
"test_data最大日期: 2023-06-05\n",
"划分后的训练集大小: 666, 验证集大小: 148\n",
"train_data最大日期: 2023-06-05\n",
"test_data最大日期: 2023-06-06\n",
"划分后的训练集大小: 676, 验证集大小: 118\n",
"train_data最大日期: 2023-06-06\n",
"test_data最大日期: 2023-06-07\n",
"划分后的训练集大小: 626, 验证集大小: 95\n",
"train_data最大日期: 2023-06-07\n",
"test_data最大日期: 2023-06-08\n",
"划分后的训练集大小: 626, 验证集大小: 147\n",
"train_data最大日期: 2023-06-08\n",
"test_data最大日期: 2023-06-09\n",
"划分后的训练集大小: 606, 验证集大小: 98\n",
"train_data最大日期: 2023-06-09\n",
"test_data最大日期: 2023-06-12\n",
"划分后的训练集大小: 558, 验证集大小: 100\n",
"train_data最大日期: 2023-06-12\n",
"test_data最大日期: 2023-06-13\n",
"划分后的训练集大小: 579, 验证集大小: 139\n",
"train_data最大日期: 2023-06-13\n",
"test_data最大日期: 2023-06-14\n",
"划分后的训练集大小: 630, 验证集大小: 146\n"
]
},
{
"ename": "LightGBMError",
"evalue": "Forced splits file includes feature index 0, but maximum feature index in dataset is -1",
"output_type": "error",
"traceback": [
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[1;31mLightGBMError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[1;32mIn[36], line 38\u001b[0m\n\u001b[0;32m 34\u001b[0m final_predictions\u001b[38;5;241m.\u001b[39mto_csv(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mpredictions_test.tsv\u001b[39m\u001b[38;5;124m'\u001b[39m, index\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m)\n\u001b[0;32m 36\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m final_predictions\n\u001b[1;32m---> 38\u001b[0m final_predictions1 \u001b[38;5;241m=\u001b[39m train(pdf1, feature_columns1, filter_index1)\n\u001b[0;32m 39\u001b[0m final_predictions2 \u001b[38;5;241m=\u001b[39m train(pdf2, feature_columns2, filter_index2)\n",
"Cell \u001b[1;32mIn[36], line 31\u001b[0m, in \u001b[0;36mtrain\u001b[1;34m(pdf, feature_columns, filter_index)\u001b[0m\n\u001b[0;32m 4\u001b[0m light_params \u001b[38;5;241m=\u001b[39m {\n\u001b[0;32m 5\u001b[0m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mlabel_gain\u001b[39m\u001b[38;5;124m'\u001b[39m: label_gain,\n\u001b[0;32m 6\u001b[0m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mobjective\u001b[39m\u001b[38;5;124m'\u001b[39m: \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mlambdarank\u001b[39m\u001b[38;5;124m'\u001b[39m,\n\u001b[1;32m (...)\u001b[0m\n\u001b[0;32m 26\u001b[0m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mseed\u001b[39m\u001b[38;5;124m'\u001b[39m: \u001b[38;5;241m7\u001b[39m\n\u001b[0;32m 27\u001b[0m }\n\u001b[0;32m 29\u001b[0m gc\u001b[38;5;241m.\u001b[39mcollect()\n\u001b[1;32m---> 31\u001b[0m final_predictions \u001b[38;5;241m=\u001b[39m rolling_train_predict(\n\u001b[0;32m 32\u001b[0m pdf[(pdf[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mtrade_date\u001b[39m\u001b[38;5;124m'\u001b[39m] \u001b[38;5;241m>\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124m2022-12-01\u001b[39m\u001b[38;5;124m'\u001b[39m) \u001b[38;5;241m&\u001b[39m (pdf[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mtrade_date\u001b[39m\u001b[38;5;124m'\u001b[39m] \u001b[38;5;241m<\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124m2029-03-26\u001b[39m\u001b[38;5;124m'\u001b[39m)], \u001b[38;5;241m5\u001b[39m, \u001b[38;5;241m1\u001b[39m, feature_columns,\n\u001b[0;32m 33\u001b[0m days\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m0\u001b[39m, validation_days\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m0\u001b[39m, filter_index\u001b[38;5;241m=\u001b[39mfilter_index, params\u001b[38;5;241m=\u001b[39mlight_params)\n\u001b[0;32m 34\u001b[0m final_predictions\u001b[38;5;241m.\u001b[39mto_csv(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mpredictions_test.tsv\u001b[39m\u001b[38;5;124m'\u001b[39m, 
index\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m)\n\u001b[0;32m 36\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m final_predictions\n",
"Cell \u001b[1;32mIn[33], line 154\u001b[0m, in \u001b[0;36mrolling_train_predict\u001b[1;34m(df, train_days, test_days, feature_columns_origin, days, use_pca, validation_days, filter_index, params)\u001b[0m\n\u001b[0;32m 146\u001b[0m \u001b[38;5;66;03m# ud = train_data[\"trade_date\"].unique()\u001b[39;00m\n\u001b[0;32m 147\u001b[0m \u001b[38;5;66;03m# date_weights = {date: weight for date, weight in zip(ud, np.linspace(1, 2, len(unique_dates)))}\u001b[39;00m\n\u001b[0;32m 148\u001b[0m \u001b[38;5;66;03m# params['weight'] = train_data[\"trade_date\"].map(date_weights).tolist()\u001b[39;00m\n\u001b[1;32m (...)\u001b[0m\n\u001b[0;32m 151\u001b[0m \u001b[38;5;66;03m# feature_contri = [2 if feat.startswith('act_factor') else 1 for feat in feature_columns]\u001b[39;00m\n\u001b[0;32m 152\u001b[0m \u001b[38;5;66;03m# params['feature_contri'] = feature_contri\u001b[39;00m\n\u001b[0;32m 153\u001b[0m evals \u001b[38;5;241m=\u001b[39m {}\n\u001b[1;32m--> 154\u001b[0m model, _, _ \u001b[38;5;241m=\u001b[39m train_light_model(train_data\u001b[38;5;241m.\u001b[39mdropna(subset\u001b[38;5;241m=\u001b[39m[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mlabel\u001b[39m\u001b[38;5;124m'\u001b[39m]),\n\u001b[0;32m 155\u001b[0m params, feature_columns,\n\u001b[0;32m 156\u001b[0m [lgb\u001b[38;5;241m.\u001b[39mlog_evaluation(period\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m100\u001b[39m),\n\u001b[0;32m 157\u001b[0m lgb\u001b[38;5;241m.\u001b[39mcallback\u001b[38;5;241m.\u001b[39mrecord_evaluation(evals),\n\u001b[0;32m 158\u001b[0m \u001b[38;5;66;03m# lgb.early_stopping(100, first_metric_only=True)\u001b[39;00m\n\u001b[0;32m 159\u001b[0m ], evals,\n\u001b[0;32m 160\u001b[0m num_boost_round\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m100\u001b[39m, validation_days\u001b[38;5;241m=\u001b[39mvalidation_days,\n\u001b[0;32m 161\u001b[0m print_feature_importance\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m, 
use_pca\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m)\n\u001b[0;32m 163\u001b[0m score_df \u001b[38;5;241m=\u001b[39m test_data\u001b[38;5;241m.\u001b[39mcopy()\n\u001b[0;32m 164\u001b[0m score_df[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mscore\u001b[39m\u001b[38;5;124m'\u001b[39m] \u001b[38;5;241m=\u001b[39m model\u001b[38;5;241m.\u001b[39mpredict(score_df[feature_columns])\n",
"Cell \u001b[1;32mIn[33], line 81\u001b[0m, in \u001b[0;36mtrain_light_model\u001b[1;34m(train_data_df, params, feature_columns, callbacks, evals, print_feature_importance, num_boost_round, validation_days, use_pca, split_date)\u001b[0m\n\u001b[0;32m 75\u001b[0m model \u001b[38;5;241m=\u001b[39m lgb\u001b[38;5;241m.\u001b[39mtrain(\n\u001b[0;32m 76\u001b[0m params, train_dataset, num_boost_round\u001b[38;5;241m=\u001b[39mnum_boost_round,\n\u001b[0;32m 77\u001b[0m valid_sets\u001b[38;5;241m=\u001b[39m[train_dataset, val_dataset], valid_names\u001b[38;5;241m=\u001b[39m[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mtrain\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mvalid\u001b[39m\u001b[38;5;124m'\u001b[39m],\n\u001b[0;32m 78\u001b[0m callbacks\u001b[38;5;241m=\u001b[39mcallbacks\n\u001b[0;32m 79\u001b[0m )\n\u001b[0;32m 80\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m---> 81\u001b[0m model \u001b[38;5;241m=\u001b[39m lgb\u001b[38;5;241m.\u001b[39mtrain(\n\u001b[0;32m 82\u001b[0m params, train_dataset, num_boost_round\u001b[38;5;241m=\u001b[39mnum_boost_round, callbacks\u001b[38;5;241m=\u001b[39mcallbacks\n\u001b[0;32m 83\u001b[0m )\n\u001b[0;32m 85\u001b[0m \u001b[38;5;66;03m# 打印特征重要性(如果需要)\u001b[39;00m\n\u001b[0;32m 86\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m print_feature_importance:\n",
"File \u001b[1;32mE:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\lightgbm\\engine.py:297\u001b[0m, in \u001b[0;36mtrain\u001b[1;34m(params, train_set, num_boost_round, valid_sets, valid_names, feval, init_model, keep_training_booster, callbacks)\u001b[0m\n\u001b[0;32m 295\u001b[0m \u001b[38;5;66;03m# construct booster\u001b[39;00m\n\u001b[0;32m 296\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m--> 297\u001b[0m booster \u001b[38;5;241m=\u001b[39m Booster(params\u001b[38;5;241m=\u001b[39mparams, train_set\u001b[38;5;241m=\u001b[39mtrain_set)\n\u001b[0;32m 298\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m is_valid_contain_train:\n\u001b[0;32m 299\u001b[0m booster\u001b[38;5;241m.\u001b[39mset_train_data_name(train_data_name)\n",
"File \u001b[1;32mE:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\lightgbm\\basic.py:3660\u001b[0m, in \u001b[0;36mBooster.__init__\u001b[1;34m(self, params, train_set, model_file, model_str)\u001b[0m\n\u001b[0;32m 3658\u001b[0m params\u001b[38;5;241m.\u001b[39mupdate(train_set\u001b[38;5;241m.\u001b[39mget_params())\n\u001b[0;32m 3659\u001b[0m params_str \u001b[38;5;241m=\u001b[39m _param_dict_to_str(params)\n\u001b[1;32m-> 3660\u001b[0m _safe_call(\n\u001b[0;32m 3661\u001b[0m _LIB\u001b[38;5;241m.\u001b[39mLGBM_BoosterCreate(\n\u001b[0;32m 3662\u001b[0m train_set\u001b[38;5;241m.\u001b[39m_handle,\n\u001b[0;32m 3663\u001b[0m _c_str(params_str),\n\u001b[0;32m 3664\u001b[0m ctypes\u001b[38;5;241m.\u001b[39mbyref(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_handle),\n\u001b[0;32m 3665\u001b[0m )\n\u001b[0;32m 3666\u001b[0m )\n\u001b[0;32m 3667\u001b[0m \u001b[38;5;66;03m# save reference to data\u001b[39;00m\n\u001b[0;32m 3668\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtrain_set \u001b[38;5;241m=\u001b[39m train_set\n",
"File \u001b[1;32mE:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\lightgbm\\basic.py:313\u001b[0m, in \u001b[0;36m_safe_call\u001b[1;34m(ret)\u001b[0m\n\u001b[0;32m 305\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"Check the return value from C API call.\u001b[39;00m\n\u001b[0;32m 306\u001b[0m \n\u001b[0;32m 307\u001b[0m \u001b[38;5;124;03mParameters\u001b[39;00m\n\u001b[1;32m (...)\u001b[0m\n\u001b[0;32m 310\u001b[0m \u001b[38;5;124;03m The return value from C API calls.\u001b[39;00m\n\u001b[0;32m 311\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[0;32m 312\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m ret \u001b[38;5;241m!=\u001b[39m \u001b[38;5;241m0\u001b[39m:\n\u001b[1;32m--> 313\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m LightGBMError(_LIB\u001b[38;5;241m.\u001b[39mLGBM_GetLastError()\u001b[38;5;241m.\u001b[39mdecode(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mutf-8\u001b[39m\u001b[38;5;124m\"\u001b[39m))\n",
"\u001b[1;31mLightGBMError\u001b[0m: Forced splits file includes feature index 0, but maximum feature index in dataset is -1"
]
}
],
"source": [
"\n",
"\n",
"def train(pdf, feature_columns, filter_index):\n",
" label_gain = list(range(len(pdf['label'].unique())))\n",
" label_gain = [gain * gain for gain in label_gain]\n",
" light_params = {\n",
" 'label_gain': label_gain,\n",
" 'objective': 'lambdarank',\n",
" 'metric': 'ndcg',\n",
" 'learning_rate': 0.03,\n",
" 'num_leaves': 32,\n",
" # 'min_data_in_leaf': 128,\n",
" 'max_depth': 8,\n",
" 'max_bin': 32,\n",
" 'feature_fraction': 0.7,\n",
" # 'bagging_fraction': 0.7,\n",
" 'bagging_freq': 5,\n",
" 'lambda_l1': 0.1,\n",
" 'lambda_l2': 0.1,\n",
" 'boosting': 'gbdt',\n",
" 'verbosity': -1,\n",
" 'extra_trees': True,\n",
" 'max_position': 5,\n",
" 'ndcg_at': 1,\n",
" 'quant_train_renew_leaf': True,\n",
" 'lambdarank_truncation_level': 3,\n",
" # 'lambdarank_position_bias_regularization': 1,\n",
" 'seed': 7\n",
" }\n",
"\n",
" gc.collect()\n",
"\n",
" final_predictions = rolling_train_predict(\n",
" pdf[(pdf['trade_date'] >= '2022-12-01') & (pdf['trade_date'] <= '2029-03-26')], 5, 1, feature_columns,\n",
" days=0, validation_days=0, filter_index=filter_index, params=light_params)\n",
" final_predictions.to_csv('predictions_test.tsv', index=False)\n",
"\n",
" return final_predictions\n",
"\n",
"final_predictions1 = train(pdf1, feature_columns1, filter_index1)\n",
"final_predictions2 = train(pdf2, feature_columns2, filter_index2)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e7e470a2-e1e5-42e5-a2ee-a5fc80455d95",
"metadata": {},
"outputs": [],
   "source": [
    "\n",
    "# Compare the two models' prediction slices for one trading day.\n",
    "# NOTE(review): date_to_compare and get_diff must be defined in earlier cells;\n",
    "# this cell fails on a fresh kernel if they have not been run.\n",
    "slice1 = final_predictions1[final_predictions1['trade_date'] == date_to_compare]\n",
    "slice2 = final_predictions2[final_predictions2['trade_date'] == date_to_compare]\n",
    "get_diff(slice1, slice2)"
   ]
}
],
"metadata": {
"kernelspec": {
"display_name": "stock",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.13.2"
}
},
"nbformat": 4,
"nbformat_minor": 5
}