Files
NewStock/code/train/UpdateRank.ipynb

1571 lines
191 KiB
Plaintext
Raw Normal View History

2025-04-03 00:45:07 +08:00
{
"cells": [
{
"cell_type": "code",
"id": "79a7758178bafdd3",
"metadata": {
"ExecuteTime": {
"end_time": "2025-03-26T15:10:25.019821Z",
"start_time": "2025-03-26T15:10:25.015412Z"
}
},
"source": [
"# %load_ext autoreload\n",
"# %autoreload 2\n",
"\n",
"import pandas as pd\n",
"import warnings\n",
"\n",
"warnings.filterwarnings(\"ignore\")\n",
"\n",
"pd.set_option('display.max_columns', None)\n"
],
"outputs": [],
"execution_count": 66
},
{
"cell_type": "code",
"id": "a79cafb06a7e0e43",
"metadata": {
"ExecuteTime": {
"end_time": "2025-03-26T15:11:40.164945Z",
"start_time": "2025-03-26T15:10:25.138896Z"
}
},
"source": [
"from utils.utils import read_and_merge_h5_data\n",
"\n",
"print('daily data')\n",
"df = read_and_merge_h5_data('../../data/daily_data.h5', key='daily_data',\n",
" columns=['ts_code', 'trade_date', 'open', 'close', 'high', 'low', 'vol', 'pct_chg'],\n",
" df=None)\n",
"\n",
"print('daily basic')\n",
"df = read_and_merge_h5_data('../../data/daily_basic.h5', key='daily_basic',\n",
" columns=['ts_code', 'trade_date', 'turnover_rate', 'pe_ttm', 'circ_mv', 'volume_ratio',\n",
" 'is_st'], df=df, join='inner')\n",
"\n",
"print('stk limit')\n",
"df = read_and_merge_h5_data('../../data/stk_limit.h5', key='stk_limit',\n",
" columns=['ts_code', 'trade_date', 'pre_close', 'up_limit', 'down_limit'],\n",
" df=df)\n",
"print('money flow')\n",
"df = read_and_merge_h5_data('../../data/money_flow.h5', key='money_flow',\n",
" columns=['ts_code', 'trade_date', 'buy_sm_vol', 'sell_sm_vol', 'buy_lg_vol', 'sell_lg_vol',\n",
" 'buy_elg_vol', 'sell_elg_vol', 'net_mf_vol'],\n",
" df=df)\n",
"print('cyq perf')\n",
"df = read_and_merge_h5_data('../../data/cyq_perf.h5', key='cyq_perf',\n",
" columns=['ts_code', 'trade_date', 'his_low', 'his_high', 'cost_5pct', 'cost_15pct',\n",
" 'cost_50pct',\n",
" 'cost_85pct', 'cost_95pct', 'weight_avg', 'winner_rate'],\n",
" df=df)\n",
"print(df.info())"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"daily data\n",
"daily basic\n",
"inner merge on ['ts_code', 'trade_date']\n",
"stk limit\n",
"left merge on ['ts_code', 'trade_date']\n",
"money flow\n",
"left merge on ['ts_code', 'trade_date']\n",
"cyq perf\n",
"left merge on ['ts_code', 'trade_date']\n",
"<class 'pandas.core.frame.DataFrame'>\n",
"RangeIndex: 8450470 entries, 0 to 8450469\n",
"Data columns (total 31 columns):\n",
" # Column Dtype \n",
"--- ------ ----- \n",
" 0 ts_code object \n",
" 1 trade_date datetime64[ns]\n",
" 2 open float64 \n",
" 3 close float64 \n",
" 4 high float64 \n",
" 5 low float64 \n",
" 6 vol float64 \n",
" 7 pct_chg float64 \n",
" 8 turnover_rate float64 \n",
" 9 pe_ttm float64 \n",
" 10 circ_mv float64 \n",
" 11 volume_ratio float64 \n",
" 12 is_st bool \n",
" 13 up_limit float64 \n",
" 14 down_limit float64 \n",
" 15 buy_sm_vol float64 \n",
" 16 sell_sm_vol float64 \n",
" 17 buy_lg_vol float64 \n",
" 18 sell_lg_vol float64 \n",
" 19 buy_elg_vol float64 \n",
" 20 sell_elg_vol float64 \n",
" 21 net_mf_vol float64 \n",
" 22 his_low float64 \n",
" 23 his_high float64 \n",
" 24 cost_5pct float64 \n",
" 25 cost_15pct float64 \n",
" 26 cost_50pct float64 \n",
" 27 cost_85pct float64 \n",
" 28 cost_95pct float64 \n",
" 29 weight_avg float64 \n",
" 30 winner_rate float64 \n",
"dtypes: bool(1), datetime64[ns](1), float64(28), object(1)\n",
"memory usage: 1.9+ GB\n",
"None\n"
]
}
],
"execution_count": 67
},
{
"cell_type": "code",
"id": "cac01788dac10678",
"metadata": {
"ExecuteTime": {
"end_time": "2025-03-26T15:11:55.524829Z",
"start_time": "2025-03-26T15:11:41.366630Z"
}
},
"source": [
"print('industry')\n",
"industry_df = read_and_merge_h5_data('../../data/industry_data.h5', key='industry_data',\n",
" columns=['ts_code', 'l2_code', 'in_date'],\n",
" df=None, on=['ts_code'], join='left')\n",
"\n",
"\n",
"def merge_with_industry_data(df, industry_df):\n",
" # 确保日期字段是 datetime 类型\n",
" df['trade_date'] = pd.to_datetime(df['trade_date'])\n",
" industry_df['in_date'] = pd.to_datetime(industry_df['in_date'])\n",
"\n",
" # 对 industry_df 按 ts_code 和 in_date 排序\n",
" industry_df_sorted = industry_df.sort_values(['in_date', 'ts_code'])\n",
"\n",
" # 对原始 df 按 ts_code 和 trade_date 排序\n",
" df_sorted = df.sort_values(['trade_date', 'ts_code'])\n",
"\n",
" # 使用 merge_asof 进行向后合并\n",
" merged = pd.merge_asof(\n",
" df_sorted,\n",
" industry_df_sorted,\n",
" by='ts_code', # 按 ts_code 分组\n",
" left_on='trade_date',\n",
" right_on='in_date',\n",
" direction='backward'\n",
" )\n",
"\n",
" # 获取每个 ts_code 的最早 in_date 记录\n",
" min_in_date_per_ts = (industry_df_sorted\n",
" .groupby('ts_code')\n",
" .first()\n",
" .reset_index()[['ts_code', 'l2_code']])\n",
"\n",
"    # 填充未匹配到的记录(trade_date 早于所有 in_date 的情况)\n",
" merged['l2_code'] = merged['l2_code'].fillna(\n",
" merged['ts_code'].map(min_in_date_per_ts.set_index('ts_code')['l2_code'])\n",
" )\n",
"\n",
" # 保留需要的列并重置索引\n",
" result = merged.reset_index(drop=True)\n",
" return result\n",
"\n",
"\n",
"# 使用示例\n",
"df = merge_with_industry_data(df, industry_df)\n",
"# print(mdf[mdf['ts_code'] == '600751.SH'][['ts_code', 'trade_date', 'l2_code']])"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"industry\n"
]
}
],
"execution_count": 68
},
{
"cell_type": "code",
"id": "c4e9e1d31da6dba6",
"metadata": {
"ExecuteTime": {
"end_time": "2025-03-26T15:11:55.794514Z",
"start_time": "2025-03-26T15:11:55.600258Z"
}
},
"source": [
"def calculate_indicators(df):\n",
" \"\"\"\n",
"    计算当日涨跌幅、RSI、MACD以及多个情绪因子(上涨比例、成交量/成交额变化率、波动率)。\n",
" \"\"\"\n",
" df = df.sort_values('trade_date')\n",
" df['daily_return'] = (df['close'] - df['pre_close']) / df['pre_close'] * 100\n",
" # df['5_day_ma'] = df['close'].rolling(window=5).mean()\n",
" delta = df['close'].diff()\n",
" gain = delta.where(delta > 0, 0)\n",
" loss = -delta.where(delta < 0, 0)\n",
" avg_gain = gain.rolling(window=14).mean()\n",
" avg_loss = loss.rolling(window=14).mean()\n",
" rs = avg_gain / avg_loss\n",
" df['RSI'] = 100 - (100 / (1 + rs))\n",
"\n",
" # 计算MACD\n",
" ema12 = df['close'].ewm(span=12, adjust=False).mean()\n",
" ema26 = df['close'].ewm(span=26, adjust=False).mean()\n",
" df['MACD'] = ema12 - ema26\n",
" df['Signal_line'] = df['MACD'].ewm(span=9, adjust=False).mean()\n",
" df['MACD_hist'] = df['MACD'] - df['Signal_line']\n",
"\n",
"    # 4. 情绪因子1:市场上涨比例(Up Ratio)\n",
" df['up_ratio'] = df['daily_return'].apply(lambda x: 1 if x > 0 else 0)\n",
" df['up_ratio_20d'] = df['up_ratio'].rolling(window=20).mean() # 过去20天上涨比例\n",
"\n",
"    # 5. 情绪因子2:成交量变化率(Volume Change Rate)\n",
" df['volume_mean'] = df['vol'].rolling(window=20).mean() # 过去20天的平均成交量\n",
" df['volume_change_rate'] = (df['vol'] - df['volume_mean']) / df['volume_mean'] * 100 # 成交量变化率\n",
"\n",
"    # 6. 情绪因子3:波动率(Volatility)\n",
" df['volatility'] = df['daily_return'].rolling(window=20).std() # 过去20天的日收益率标准差\n",
"\n",
"    # 7. 情绪因子4:成交额变化率(Amount Change Rate)\n",
" df['amount_mean'] = df['amount'].rolling(window=20).mean() # 过去20天的平均成交额\n",
" df['amount_change_rate'] = (df['amount'] - df['amount_mean']) / df['amount_mean'] * 100 # 成交额变化率\n",
"\n",
" return df\n",
"\n",
"\n",
"def generate_index_indicators(h5_filename):\n",
" df = pd.read_hdf(h5_filename, key='index_data')\n",
" df['trade_date'] = pd.to_datetime(df['trade_date'], format='%Y%m%d')\n",
" df = df.sort_values('trade_date')\n",
"\n",
" # 计算每个ts_code的相关指标\n",
" df_indicators = []\n",
" for ts_code in df['ts_code'].unique():\n",
" df_index = df[df['ts_code'] == ts_code].copy()\n",
" df_index = calculate_indicators(df_index)\n",
" df_indicators.append(df_index)\n",
"\n",
" # 合并所有指数的结果\n",
" df_all_indicators = pd.concat(df_indicators, ignore_index=True)\n",
"\n",
" # 保留trade_date列并将同一天的数据按ts_code合并成一行\n",
" df_final = df_all_indicators.pivot_table(\n",
" index='trade_date',\n",
" columns='ts_code',\n",
" values=['daily_return', 'RSI', 'MACD', 'Signal_line',\n",
" 'MACD_hist', 'up_ratio_20d', 'volume_change_rate', 'volatility',\n",
" 'amount_change_rate', 'amount_mean'],\n",
" aggfunc='last'\n",
" )\n",
"\n",
" df_final.columns = [f\"{col[1]}_{col[0]}\" for col in df_final.columns]\n",
" df_final = df_final.reset_index()\n",
"\n",
" return df_final\n",
"\n",
"\n",
"# 使用函数\n",
"h5_filename = '../../data/index_data.h5'\n",
"index_data = generate_index_indicators(h5_filename)\n",
"index_data = index_data.dropna()\n"
],
"outputs": [],
"execution_count": 69
},
{
"cell_type": "code",
"id": "a735bc02ceb4d872",
"metadata": {
"ExecuteTime": {
"end_time": "2025-03-26T15:11:56.026392Z",
"start_time": "2025-03-26T15:11:55.984754Z"
}
},
"source": [
"import numpy as np\n",
"import talib\n",
"\n",
"\n",
"def get_rolling_factor(df):\n",
" old_columns = df.columns.tolist()[:]\n",
" # 按股票和日期排序\n",
" df = df.sort_values(by=['ts_code', 'trade_date'])\n",
" grouped = df.groupby('ts_code', group_keys=False)\n",
"\n",
" # df[\"gap_next_open\"] = (df[\"open\"].shift(-1) - df[\"close\"]) / df[\"close\"]\n",
"\n",
" df['return_skew'] = grouped['pct_chg'].rolling(window=5).skew().reset_index(0, drop=True)\n",
" df['return_kurtosis'] = grouped['pct_chg'].rolling(window=5).kurt().reset_index(0, drop=True)\n",
"\n",
" # 因子 1短期成交量变化率\n",
" df['volume_change_rate'] = (\n",
" grouped['vol'].rolling(window=2).mean() /\n",
" grouped['vol'].rolling(window=10).mean() - 1\n",
" ).reset_index(level=0, drop=True) # 确保索引对齐\n",
"\n",
" # 因子 2成交量突破信号\n",
" max_volume = grouped['vol'].rolling(window=5).max().reset_index(level=0, drop=True) # 确保索引对齐\n",
" df['cat_volume_breakout'] = (df['vol'] > max_volume)\n",
"\n",
" # 因子 3换手率均线偏离度\n",
" mean_turnover = grouped['turnover_rate'].rolling(window=3).mean().reset_index(level=0, drop=True)\n",
" std_turnover = grouped['turnover_rate'].rolling(window=3).std().reset_index(level=0, drop=True)\n",
" df['turnover_deviation'] = (df['turnover_rate'] - mean_turnover) / std_turnover\n",
"\n",
" # 因子 4换手率激增信号\n",
" df['cat_turnover_spike'] = (df['turnover_rate'] > mean_turnover + 2 * std_turnover)\n",
"\n",
" # 因子 5量比均值\n",
" df['avg_volume_ratio'] = grouped['volume_ratio'].rolling(window=3).mean().reset_index(level=0, drop=True)\n",
"\n",
" # 因子 6量比突破信号\n",
" max_volume_ratio = grouped['volume_ratio'].rolling(window=5).max().reset_index(level=0, drop=True)\n",
" df['cat_volume_ratio_breakout'] = (df['volume_ratio'] > max_volume_ratio)\n",
"\n",
" df['vol_spike'] = grouped.apply(\n",
" lambda x: pd.Series(x['vol'].rolling(20).mean(), index=x.index)\n",
" )\n",
" df['vol_std_5'] = df['vol'].pct_change().rolling(5).std()\n",
"\n",
" # 计算 ATR\n",
" df['atr_14'] = grouped.apply(\n",
" lambda x: pd.Series(talib.ATR(x['high'].values, x['low'].values, x['close'].values, timeperiod=14),\n",
" index=x.index)\n",
" )\n",
" df['atr_6'] = grouped.apply(\n",
" lambda x: pd.Series(talib.ATR(x['high'].values, x['low'].values, x['close'].values, timeperiod=6),\n",
" index=x.index)\n",
" )\n",
"\n",
" # 计算 OBV 及其均线\n",
" df['obv'] = grouped.apply(\n",
" lambda x: pd.Series(talib.OBV(x['close'].values, x['vol'].values), index=x.index)\n",
" )\n",
" df['maobv_6'] = grouped.apply(\n",
" lambda x: pd.Series(talib.SMA(x['obv'].values, timeperiod=6), index=x.index)\n",
" )\n",
"\n",
" df['rsi_3'] = grouped.apply(\n",
" lambda x: pd.Series(talib.RSI(x['close'].values, timeperiod=3), index=x.index)\n",
" )\n",
" df['rsi_6'] = grouped.apply(\n",
" lambda x: pd.Series(talib.RSI(x['close'].values, timeperiod=6), index=x.index)\n",
" )\n",
" df['rsi_9'] = grouped.apply(\n",
" lambda x: pd.Series(talib.RSI(x['close'].values, timeperiod=9), index=x.index)\n",
" )\n",
"\n",
" # 计算 return_10 和 return_20\n",
" df['return_5'] = grouped['close'].apply(lambda x: x / x.shift(5) - 1)\n",
" df['return_10'] = grouped['close'].apply(lambda x: x / x.shift(10) - 1)\n",
" df['return_20'] = grouped['close'].apply(lambda x: x / x.shift(20) - 1)\n",
"\n",
" # df['avg_close_5'] = grouped['close'].apply(lambda x: x.rolling(window=5).mean() / x)\n",
"\n",
" # 计算标准差指标\n",
" df['std_return_5'] = grouped['close'].apply(lambda x: x.pct_change().rolling(window=5).std())\n",
" df['std_return_15'] = grouped['close'].apply(lambda x: x.pct_change().rolling(window=15).std())\n",
" df['std_return_25'] = grouped['close'].apply(lambda x: x.pct_change().rolling(window=25).std())\n",
" df['std_return_90'] = grouped['close'].apply(lambda x: x.pct_change().rolling(window=90).std())\n",
" df['std_return_90_2'] = grouped['close'].apply(lambda x: x.shift(10).pct_change().rolling(window=90).std())\n",
"\n",
" # 计算 EMA 指标\n",
" df['_ema_5'] = grouped['close'].apply(\n",
" lambda x: pd.Series(talib.EMA(x.values, timeperiod=5), index=x.index)\n",
" )\n",
" df['_ema_13'] = grouped['close'].apply(\n",
" lambda x: pd.Series(talib.EMA(x.values, timeperiod=13), index=x.index)\n",
" )\n",
" df['_ema_20'] = grouped['close'].apply(\n",
" lambda x: pd.Series(talib.EMA(x.values, timeperiod=20), index=x.index)\n",
" )\n",
" df['_ema_60'] = grouped['close'].apply(\n",
" lambda x: pd.Series(talib.EMA(x.values, timeperiod=60), index=x.index)\n",
" )\n",
"\n",
" # 计算 act_factor1, act_factor2, act_factor3, act_factor4\n",
" df['act_factor1'] = grouped['_ema_5'].apply(\n",
" lambda x: np.arctan((x / x.shift(1) - 1) * 100) * 57.3 / 50\n",
" )\n",
" df['act_factor2'] = grouped['_ema_13'].apply(\n",
" lambda x: np.arctan((x / x.shift(1) - 1) * 100) * 57.3 / 40\n",
" )\n",
" df['act_factor3'] = grouped['_ema_20'].apply(\n",
" lambda x: np.arctan((x / x.shift(1) - 1) * 100) * 57.3 / 21\n",
" )\n",
" df['act_factor4'] = grouped['_ema_60'].apply(\n",
" lambda x: np.arctan((x / x.shift(1) - 1) * 100) * 57.3 / 10\n",
" )\n",
"\n",
" # 根据 trade_date 截面计算排名\n",
" df['rank_act_factor1'] = df.groupby('trade_date', group_keys=False)['act_factor1'].rank(ascending=False, pct=True)\n",
" df['rank_act_factor2'] = df.groupby('trade_date', group_keys=False)['act_factor2'].rank(ascending=False, pct=True)\n",
" df['rank_act_factor3'] = df.groupby('trade_date', group_keys=False)['act_factor3'].rank(ascending=False, pct=True)\n",
"\n",
" df['log(circ_mv)'] = np.log(df['circ_mv'])\n",
"\n",
" def rolling_covariance(x, y, window):\n",
" return x.rolling(window).cov(y)\n",
"\n",
" def delta(series, period):\n",
" return series.diff(period)\n",
"\n",
" def rank(series):\n",
" return series.rank(pct=True)\n",
"\n",
" def stddev(series, window):\n",
" return series.rolling(window).std()\n",
"\n",
" window_high_volume = 5\n",
" window_close_stddev = 20\n",
" period_delta = 5\n",
" df['cov'] = rolling_covariance(df['high'], df['vol'], window_high_volume)\n",
" df['delta_cov'] = delta(df['cov'], period_delta)\n",
" df['_rank_stddev'] = rank(stddev(df['close'], window_close_stddev))\n",
" df['alpha_22_improved'] = -1 * df['delta_cov'] * df['_rank_stddev']\n",
"\n",
" df['alpha_003'] = np.where(df['high'] != df['low'],\n",
" (df['close'] - df['open']) / (df['high'] - df['low']),\n",
" 0)\n",
"\n",
" df['alpha_007'] = grouped.apply(lambda x: x['close'].rolling(5).corr(x['vol'])).reset_index(level=0, drop=True)\n",
" df['alpha_007'] = df.groupby('trade_date', group_keys=False)['alpha_007'].rank(ascending=True, pct=True)\n",
"\n",
" df['alpha_013'] = grouped['close'].transform(lambda x: x.rolling(5).sum() - x.rolling(20).sum())\n",
" df['alpha_013'] = df.groupby('trade_date', group_keys=False)['alpha_013'].rank(ascending=True, pct=True)\n",
"\n",
"    df['cat_up_limit'] = (df['close'] == df['up_limit'])  # 是否涨停(布尔值,True 表示涨停)\n",
"    df['cat_down_limit'] = (df['close'] == df['down_limit'])  # 是否跌停(布尔值,True 表示跌停)\n",
" df['up_limit_count_10d'] = grouped['cat_up_limit'].rolling(window=10, min_periods=1).sum().reset_index(level=0,\n",
" drop=True)\n",
" df['down_limit_count_10d'] = grouped['cat_down_limit'].rolling(window=10, min_periods=1).sum().reset_index(level=0,\n",
" drop=True)\n",
"\n",
" # 3. 最近连续涨跌停天数\n",
"    def calculate_consecutive_limits(series):\n",
"        \"\"\"Running streak length of consecutive True values in a boolean series.\n",
"\n",
"        Returns the same streak series twice (up, down) to keep the original\n",
"        tuple interface: both tuple elements were previously computed with\n",
"        byte-identical expressions, so the duplicate computation is removed.\n",
"        \"\"\"\n",
"        # Each run of equal values gets its own group id via the cumsum of\n",
"        # change points; cumcount()+1 numbers positions within the run, and\n",
"        # multiplying by the boolean series zeroes out the False runs.\n",
"        streak = series * (series.groupby((series != series.shift()).cumsum()).cumcount() + 1)\n",
"        return streak, streak\n",
"\n",
" # 连续涨停天数\n",
" df['consecutive_up_limit'] = grouped['cat_up_limit'].apply(\n",
" lambda x: calculate_consecutive_limits(x)[0]\n",
" ).reset_index(level=0, drop=True)\n",
"\n",
" df['vol_break'] = np.where((df['close'] > df['cost_85pct']) & (df['volume_ratio'] > 2), 1, 0)\n",
"\n",
" df['weight_roc5'] = grouped['weight_avg'].apply(lambda x: x.pct_change(5))\n",
"\n",
" def rolling_corr(group):\n",
" roc_close = group['close'].pct_change()\n",
" roc_weight = group['weight_avg'].pct_change()\n",
" return roc_close.rolling(10).corr(roc_weight)\n",
"\n",
" df['price_cost_divergence'] = grouped.apply(rolling_corr)\n",
"\n",
" df['smallcap_concentration'] = (1 / df['circ_mv']) * (df['cost_85pct'] - df['cost_15pct'])\n",
"\n",
" # 16. 筹码稳定性指数 (20日波动率)\n",
" df['weight_std20'] = grouped['weight_avg'].apply(lambda x: x.rolling(20).std())\n",
" df['cost_stability'] = df['weight_std20'] / grouped['weight_avg'].transform(lambda x: x.rolling(20).mean())\n",
"\n",
" # 17. 成本区间突破标记\n",
" df['high_cost_break_days'] = grouped.apply(lambda g: g['close'].gt(g['cost_95pct']).rolling(5).sum())\n",
"\n",
" # 20. 筹码-流动性风险\n",
" df['liquidity_risk'] = (df['cost_95pct'] - df['cost_5pct']) * (\n",
" 1 / grouped['vol'].transform(lambda x: x.rolling(10).mean()))\n",
"\n",
" # 7. 市值波动率因子\n",
" df['turnover_std'] = grouped['turnover_rate'].rolling(window=20).std().reset_index(level=0, drop=True)\n",
" df['mv_volatility'] = grouped.apply(lambda x: x['turnover_std'] / x['circ_mv']).reset_index(level=0, drop=True)\n",
"\n",
" # 8. 市值成长性因子\n",
" df['volume_growth'] = grouped['vol'].pct_change(periods=20).reset_index(level=0, drop=True)\n",
" df['mv_growth'] = grouped.apply(lambda x: x['volume_growth'] / x['circ_mv']).reset_index(level=0, drop=True)\n",
"\n",
" df.drop(columns=['weight_std20'], inplace=True, errors='ignore')\n",
" new_columns = [col for col in df.columns.tolist()[:] if col not in old_columns]\n",
"\n",
" return df, new_columns\n",
"\n",
"\n",
"def get_simple_factor(df):\n",
" old_columns = df.columns.tolist()[:]\n",
" df = df.sort_values(by=['ts_code', 'trade_date'])\n",
"\n",
" alpha = 0.5\n",
" df['momentum_factor'] = df['volume_change_rate'] + alpha * df['turnover_deviation']\n",
" df['resonance_factor'] = df['volume_ratio'] * df['pct_chg']\n",
" df['log_close'] = np.log(df['close'])\n",
"\n",
" df['cat_vol_spike'] = df['vol'] > 2 * df['vol_spike']\n",
"\n",
" df['up'] = (df['high'] - df[['close', 'open']].max(axis=1)) / df['close']\n",
" df['down'] = (df[['close', 'open']].min(axis=1) - df['low']) / df['close']\n",
"\n",
" df['obv-maobv_6'] = df['obv'] - df['maobv_6']\n",
"\n",
" # 计算比值指标\n",
" df['std_return_5 / std_return_90'] = df['std_return_5'] / df['std_return_90']\n",
" df['std_return_5 / std_return_25'] = df['std_return_5'] / df['std_return_25']\n",
"\n",
" # 计算标准差差值\n",
" df['std_return_90 - std_return_90_2'] = df['std_return_90'] - df['std_return_90_2']\n",
"\n",
" df['cat_af1'] = df['act_factor1'] > 0\n",
" df['cat_af2'] = df['act_factor2'] > df['act_factor1']\n",
" df['cat_af3'] = df['act_factor3'] > df['act_factor2']\n",
" df['cat_af4'] = df['act_factor4'] > df['act_factor3']\n",
"\n",
" # 计算 act_factor5 和 act_factor6\n",
" df['act_factor5'] = df['act_factor1'] + df['act_factor2'] + df['act_factor3'] + df['act_factor4']\n",
" df['act_factor6'] = (df['act_factor1'] - df['act_factor2']) / np.sqrt(\n",
" df['act_factor1'] ** 2 + df['act_factor2'] ** 2)\n",
"\n",
" df['active_buy_volume_large'] = df['buy_lg_vol'] / df['net_mf_vol']\n",
" df['active_buy_volume_big'] = df['buy_elg_vol'] / df['net_mf_vol']\n",
" df['active_buy_volume_small'] = df['buy_sm_vol'] / df['net_mf_vol']\n",
"\n",
" df['buy_lg_vol_minus_sell_lg_vol'] = (df['buy_lg_vol'] - df['sell_lg_vol']) / df['net_mf_vol']\n",
" df['buy_elg_vol_minus_sell_elg_vol'] = (df['buy_elg_vol'] - df['sell_elg_vol']) / df['net_mf_vol']\n",
"\n",
" df['log(circ_mv)'] = np.log(df['circ_mv'])\n",
"\n",
" df['ctrl_strength'] = (df['cost_85pct'] - df['cost_15pct']) / (df['his_high'] - df['his_low'])\n",
"\n",
" df['low_cost_dev'] = (df['close'] - df['cost_5pct']) / (df['cost_50pct'] - df['cost_5pct'])\n",
"\n",
" df['asymmetry'] = (df['cost_95pct'] - df['cost_50pct']) / (df['cost_50pct'] - df['cost_5pct'])\n",
"\n",
" df['lock_factor'] = df['turnover_rate'] * (\n",
" 1 - (df['cost_95pct'] - df['cost_5pct']) / (df['his_high'] - df['his_low']))\n",
"\n",
" df['cat_vol_break'] = (df['close'] > df['cost_85pct']) & (df['volume_ratio'] > 2)\n",
"\n",
" df['cost_atr_adj'] = (df['cost_95pct'] - df['cost_5pct']) / df['atr_14']\n",
"\n",
" # 12. 小盘股筹码集中度\n",
" df['smallcap_concentration'] = (1 / df['circ_mv']) * (df['cost_85pct'] - df['cost_15pct'])\n",
"\n",
" df['cat_golden_resonance'] = ((df['close'] > df['weight_avg']) &\n",
" (df['volume_ratio'] > 1.5) &\n",
" (df['winner_rate'] > 0.7))\n",
"\n",
" df['mv_turnover_ratio'] = df['turnover_rate'] / df['circ_mv']\n",
"\n",
" df['mv_adjusted_volume'] = df['vol'] / df['circ_mv']\n",
"\n",
" df['mv_weighted_turnover'] = df['turnover_rate'] * (1 / df['circ_mv'])\n",
"\n",
" df['nonlinear_mv_volume'] = df['vol'] / df['circ_mv']\n",
"\n",
" df['mv_volume_ratio'] = df['volume_ratio'] / df['circ_mv']\n",
"\n",
" df['mv_momentum'] = df['turnover_rate'] * df['volume_ratio'] / df['circ_mv']\n",
"\n",
" drop_columns = [col for col in df.columns if col.startswith('_')]\n",
" df.drop(columns=drop_columns, inplace=True, errors='ignore')\n",
"\n",
" new_columns = [col for col in df.columns.tolist()[:] if col not in old_columns]\n",
" return df, new_columns\n"
],
"outputs": [],
"execution_count": 70
},
{
"cell_type": "code",
"id": "53f86ddc0677a6d7",
"metadata": {
"scrolled": true,
"ExecuteTime": {
"end_time": "2025-03-26T15:12:02.787850Z",
"start_time": "2025-03-26T15:11:56.035514Z"
}
},
"source": [
"from utils.factor import get_act_factor\n",
"\n",
"\n",
"def read_industry_data(h5_filename):\n",
"    \"\"\"Load SW industry daily bars and derive per-industry factor columns.\n",
"\n",
"    Returns a frame keyed by ('cat_l2_code', 'trade_date'); all derived\n",
"    factor columns are prefixed with 'industry_'.\n",
"    \"\"\"\n",
"    industry_data = pd.read_hdf(h5_filename, key='sw_daily', columns=[\n",
"        'ts_code', 'trade_date', 'open', 'close', 'high', 'low', 'pe', 'pb', 'vol'\n",
"    ])\n",
"    industry_data = industry_data.sort_values(by=['ts_code', 'trade_date'])\n",
"    # FIX: reindex() with no arguments is a no-op copy; reset_index(drop=True)\n",
"    # gives the clean 0..n index the sort step was presumably meant to get.\n",
"    industry_data = industry_data.reset_index(drop=True)\n",
"    industry_data['trade_date'] = pd.to_datetime(industry_data['trade_date'], format='%Y%m%d')\n",
"\n",
"    grouped = industry_data.groupby('ts_code', group_keys=False)\n",
"    # On-balance volume per industry index.\n",
"    industry_data['obv'] = grouped.apply(\n",
"        lambda x: pd.Series(talib.OBV(x['close'].values, x['vol'].values), index=x.index)\n",
"    )\n",
"    # 5-day and 20-day simple returns per industry index.\n",
"    industry_data['return_5'] = grouped['close'].apply(lambda x: x / x.shift(5) - 1)\n",
"    industry_data['return_20'] = grouped['close'].apply(lambda x: x / x.shift(20) - 1)\n",
"\n",
"    industry_data = get_act_factor(industry_data, cat=False)\n",
"    industry_data = industry_data.sort_values(by=['trade_date', 'ts_code'])\n",
"\n",
"    # Cross-sectional percentile rank of the return factors within each day.\n",
"    industry_data['return_5_percentile'] = industry_data.groupby('trade_date')['return_5'].transform(\n",
"        lambda x: x.rank(pct=True))\n",
"    industry_data['return_20_percentile'] = industry_data.groupby('trade_date')['return_20'].transform(\n",
"        lambda x: x.rank(pct=True))\n",
"    industry_data = industry_data.drop(columns=['open', 'close', 'high', 'low', 'pe', 'pb', 'vol'])\n",
"\n",
"    # Prefix factor columns and rename the key to match the stock frame.\n",
"    industry_data = industry_data.rename(\n",
"        columns={col: f'industry_{col}' for col in industry_data.columns if col not in ['ts_code', 'trade_date']})\n",
"\n",
"    industry_data = industry_data.rename(columns={'ts_code': 'cat_l2_code'})\n",
"    return industry_data\n",
"\n",
"\n",
"industry_df = read_industry_data('../../data/sw_daily.h5')\n"
],
"outputs": [],
"execution_count": 71
},
{
"cell_type": "code",
"id": "dbe2fd8021b9417f",
"metadata": {
"scrolled": true,
"ExecuteTime": {
"end_time": "2025-03-26T15:12:02.870181Z",
"start_time": "2025-03-26T15:12:02.865559Z"
}
},
"source": [
"origin_columns = df.columns.tolist()\n",
"origin_columns = [col for col in origin_columns if\n",
" col not in ['turnover_rate', 'pe_ttm', 'volume_ratio', 'vol', 'pct_chg', 'l2_code', 'winner_rate']]\n",
"origin_columns = [col for col in origin_columns if col not in index_data.columns]\n",
"origin_columns = [col for col in origin_columns if 'cyq' not in col]\n",
"print(origin_columns)"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['ts_code', 'open', 'close', 'high', 'low', 'circ_mv', 'is_st', 'up_limit', 'down_limit', 'buy_sm_vol', 'sell_sm_vol', 'buy_lg_vol', 'sell_lg_vol', 'buy_elg_vol', 'sell_elg_vol', 'net_mf_vol', 'his_low', 'his_high', 'cost_5pct', 'cost_15pct', 'cost_50pct', 'cost_85pct', 'cost_95pct', 'weight_avg', 'in_date']\n"
]
}
],
"execution_count": 72
},
{
"cell_type": "code",
"id": "92d84ce15a562ec6",
"metadata": {
"ExecuteTime": {
"end_time": "2025-03-26T15:14:41.756942Z",
"start_time": "2025-03-26T15:12:03.028028Z"
}
},
"source": [
"def filter_data(df):\n",
" # df = df.groupby('trade_date').apply(lambda x: x.nlargest(1000, 'act_factor1'))\n",
" df = df[~df['is_st']]\n",
" df = df[~df['ts_code'].str.endswith('BJ')]\n",
" df = df[~df['ts_code'].str.startswith('30')]\n",
" df = df[~df['ts_code'].str.startswith('68')]\n",
" df = df[~df['ts_code'].str.startswith('8')]\n",
" df = df[df['trade_date'] >= '20180101']\n",
" df = df.reset_index(drop=True)\n",
" return df\n",
"\n",
"\n",
"df = filter_data(df)\n",
"# df = get_technical_factor(df)\n",
"# df = get_act_factor(df)\n",
"# df = get_money_flow_factor(df)\n",
"# df = get_alpha_factor(df)\n",
"# df = get_limit_factor(df)\n",
"# df = get_cyp_perf_factor(df)\n",
"# df = get_mv_factors(df)\n",
"df, _ = get_rolling_factor(df)\n",
"df, _ = get_simple_factor(df)\n",
"# df = df.merge(industry_df, on=['l2_code', 'trade_date'], how='left')\n",
"df = df.rename(columns={'l2_code': 'cat_l2_code'})\n",
"# df = df.merge(index_data, on='trade_date', how='left')\n",
"\n",
"print(df.info())"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<class 'pandas.core.frame.DataFrame'>\n",
"Index: 5102787 entries, 0 to 5102786\n",
"Columns: 122 entries, ts_code to mv_momentum\n",
"dtypes: bool(13), datetime64[ns](2), float64(103), int32(1), int64(1), object(2)\n",
"memory usage: 4.2+ GB\n",
"None\n"
]
}
],
"execution_count": 73
},
{
"cell_type": "code",
"id": "f4f16d63ad18d1bc",
"metadata": {
"ExecuteTime": {
"end_time": "2025-03-26T15:14:42.493112Z",
"start_time": "2025-03-26T15:14:42.475598Z"
}
},
"source": [
"from scipy.stats import ks_2samp, wasserstein_distance\n",
"from sklearn.metrics import roc_auc_score\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.preprocessing import StandardScaler\n",
"\n",
"\n",
"def remove_shifted_features(train_data, test_data, feature_columns, ks_threshold=0.1, wasserstein_threshold=0.15,\n",
" importance_threshold=0.05):\n",
" dropped_features = []\n",
"\n",
" # **统计数据漂移**\n",
" numeric_columns = train_data.select_dtypes(include=['float64', 'int64']).columns\n",
" numeric_columns = [col for col in numeric_columns if col in feature_columns]\n",
" for feature in numeric_columns:\n",
" ks_stat, p_value = ks_2samp(train_data[feature], test_data[feature])\n",
" wasserstein_dist = wasserstein_distance(train_data[feature], test_data[feature])\n",
"\n",
" if p_value < ks_threshold or wasserstein_dist > wasserstein_threshold:\n",
" dropped_features.append(feature)\n",
"\n",
" print(f\"检测到 {len(dropped_features)} 个可能漂移的特征: {dropped_features}\")\n",
"\n",
" # **应用阈值进行最终筛选**\n",
" filtered_features = [f for f in feature_columns if f not in dropped_features]\n",
"\n",
" return filtered_features, dropped_features\n",
"\n"
],
"outputs": [],
"execution_count": 74
},
{
"cell_type": "code",
"id": "9d807cb2cde5d92c",
"metadata": {
"ExecuteTime": {
"end_time": "2025-03-26T15:14:42.825272Z",
"start_time": "2025-03-26T15:14:42.814210Z"
}
},
"source": [
"def create_deviation_within_dates(df, feature_columns):\n",
" groupby_col = 'cat_l2_code' # 使用 trade_date 进行分组\n",
" new_columns = {}\n",
" ret_feature_columns = feature_columns[:]\n",
"\n",
" # 自动选择所有数值型特征\n",
" num_features = [col for col in feature_columns if 'cat' not in col and 'index' not in col]\n",
"\n",
" # num_features = ['vol', 'pct_chg', 'turnover_rate', 'volume_ratio', 'cat_vol_spike', 'obv', 'maobv_6', 'return_5', 'return_10', 'return_20', 'std_return_5', 'std_return_15', 'std_return_90', 'std_return_90_2', 'act_factor1', 'act_factor2', 'act_factor3', 'act_factor4', 'act_factor5', 'act_factor6', 'rank_act_factor1', 'rank_act_factor2', 'rank_act_factor3', 'active_buy_volume_large', 'active_buy_volume_big', 'active_buy_volume_small', 'alpha_022', 'alpha_003', 'alpha_007', 'alpha_013']\n",
" num_features = [col for col in num_features if 'cat' not in col and 'industry' not in col]\n",
" num_features = [col for col in num_features if 'limit' not in col]\n",
" num_features = [col for col in num_features if 'cyq' not in col]\n",
"\n",
" # 遍历所有数值型特征\n",
" for feature in num_features:\n",
" if feature == 'trade_date': # 不需要对 'trade_date' 计算偏差\n",
" continue\n",
"\n",
" # grouped_mean = df.groupby(['trade_date'])[feature].transform('mean')\n",
" # deviation_col_name = f'deviation_mean_{feature}'\n",
" # new_columns[deviation_col_name] = df[feature] - grouped_mean\n",
" # ret_feature_columns.append(deviation_col_name)\n",
"\n",
" grouped_mean = df.groupby(['trade_date', groupby_col])[feature].transform('mean')\n",
" deviation_col_name = f'deviation_mean_{feature}'\n",
" new_columns[deviation_col_name] = df[feature] - grouped_mean\n",
" ret_feature_columns.append(deviation_col_name)\n",
"\n",
" # 将新计算的偏差特征与原始 DataFrame 合并\n",
" df = pd.concat([df, pd.DataFrame(new_columns)], axis=1)\n",
"\n",
" # for feature in ['obv', 'return_20', 'act_factor1', 'act_factor2', 'act_factor3', 'act_factor4']:\n",
" # df[f'deviation_industry_{feature}'] = df[feature] - df[f'industry_{feature}']\n",
"\n",
" return df, ret_feature_columns\n"
],
"outputs": [],
"execution_count": 75
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-03-26T15:14:44.049228Z",
"start_time": "2025-03-26T15:14:43.001080Z"
}
},
"cell_type": "code",
"source": [
"import pandas as pd\n",
"\n",
"\n",
"def remove_outliers_label_percentile(label: pd.Series, lower_percentile: float = 0.01, upper_percentile: float = 0.99,\n",
" log=True):\n",
" if not (0 <= lower_percentile < upper_percentile <= 1):\n",
" raise ValueError(\"Percentile values must satisfy 0 <= lower_percentile < upper_percentile <= 1.\")\n",
"\n",
" # Calculate lower and upper bounds based on percentiles\n",
" lower_bound = label.quantile(lower_percentile)\n",
" upper_bound = label.quantile(upper_percentile)\n",
"\n",
" # Filter out values outside the bounds\n",
" filtered_label = label[(label >= lower_bound) & (label <= upper_bound)]\n",
"\n",
" # Print the number of removed outliers\n",
" if log:\n",
" print(f\"Removed {len(label) - len(filtered_label)} outliers.\")\n",
" return filtered_label\n",
"\n",
"\n",
"def calculate_risk_adjusted_target(df, days=5):\n",
"    \"\"\"Risk-adjusted target: forward return divided by forward volatility.\n",
"\n",
"    FIX: the ratio was previously computed as return * volatility, but the\n",
"    inf cleanup below is only reachable when dividing by a zero volatility,\n",
"    which shows division (a Sharpe-like ratio) was intended.\n",
"    \"\"\"\n",
"    df = df.sort_values(by=['ts_code', 'trade_date'])\n",
"\n",
"    # Entry at the next day's open, exit at the close `days` bars ahead.\n",
"    df['future_close'] = df.groupby('ts_code')['close'].shift(-days)\n",
"    df['future_open'] = df.groupby('ts_code')['open'].shift(-1)\n",
"    df['future_return'] = (df['future_close'] - df['future_open']) / df['future_open']\n",
"\n",
"    # Rolling std of the forward return within each stock.\n",
"    df['future_volatility'] = df.groupby('ts_code')['future_return'].rolling(days, min_periods=1).std().reset_index(\n",
"        level=0, drop=True)\n",
"    sharpe_ratio = df['future_return'] / df['future_volatility']\n",
"    sharpe_ratio.replace([np.inf, -np.inf], np.nan, inplace=True)  # zero-volatility rows -> NaN\n",
"\n",
"    return sharpe_ratio\n",
"\n",
"\n",
"def calculate_score(df, days=5, lambda_param=1.0):\n",
"    \"\"\"Score each row as future_return - lambda_param * forward max drawdown.\n",
"\n",
"    Expects 'ts_code', 'trade_date', 'close' and a precomputed 'future_return'\n",
"    column; returns a Series aligned with `df`.\n",
"    \"\"\"\n",
"\n",
"    def calculate_max_drawdown(prices):\n",
"        # Largest peak-to-trough decline within the window.\n",
"        peak = prices.iloc[0]\n",
"        max_drawdown = 0\n",
"\n",
"        for price in prices:\n",
"            if price > peak:\n",
"                peak = price\n",
"            else:\n",
"                drawdown = (peak - price) / peak\n",
"                max_drawdown = max(max_drawdown, drawdown)\n",
"\n",
"        return max_drawdown\n",
"\n",
"    def compute_stock_score(stock_df):\n",
"        stock_df = stock_df.sort_values(by=['trade_date'])\n",
"        future_return = stock_df['future_return']\n",
"        # shift(-days) aligns the drawdown of the *upcoming* window with today.\n",
"        # (A rolling volatility was also computed here originally but never\n",
"        # used; the dead computation has been removed.)\n",
"        max_drawdown = stock_df['close'].rolling(days).apply(calculate_max_drawdown, raw=False).shift(-days)\n",
"        score = future_return - lambda_param * max_drawdown\n",
"\n",
"        return score\n",
"\n",
"    scores = df.groupby('ts_code').apply(lambda x: compute_stock_score(x))\n",
"    scores = scores.reset_index(level=0, drop=True)\n",
"\n",
"    return scores\n",
"\n",
"\n",
"def remove_highly_correlated_features(df, feature_columns, threshold=0.9):\n",
"    \"\"\"Drop numeric features whose pairwise |correlation| exceeds `threshold`,\n",
"    always keeping columns whose name contains 'act' or 'af'.\"\"\"\n",
"    numeric_features = df[feature_columns].select_dtypes(include=[np.number]).columns.tolist()\n",
"    if not numeric_features:\n",
"        raise ValueError(\"No numeric features found in the provided data.\")\n",
"\n",
"    # Upper triangle of the absolute correlation matrix (diagonal excluded).\n",
"    abs_corr = df[numeric_features].corr().abs()\n",
"    upper_tri = abs_corr.where(np.triu(np.ones(abs_corr.shape), k=1).astype(bool))\n",
"    flagged = {col for col in upper_tri.columns if (upper_tri[col] > threshold).any()}\n",
"\n",
"    # Keep a column when it is not flagged or when it is a protected factor column.\n",
"    return [col for col in feature_columns\n",
"            if col not in flagged or 'act' in col or 'af' in col]\n",
"\n",
"\n",
"import pandas as pd\n",
"from sklearn.preprocessing import StandardScaler\n",
"\n",
"\n",
"def cross_sectional_standardization(df, features):\n",
"    \"\"\"Z-score `features` independently within each trade_date cross-section.\"\"\"\n",
"    result = df.sort_values(by='trade_date').copy()\n",
"\n",
"    # Fit a fresh scaler on every date's cross-section and write it back in place.\n",
"    for date in result['trade_date'].unique():\n",
"        date_mask = result['trade_date'] == date\n",
"        scaler = StandardScaler()\n",
"        result.loc[date_mask, features] = scaler.fit_transform(result.loc[date_mask, features])\n",
"\n",
"    return result\n",
"\n",
"\n",
"import gc\n",
"\n",
"gc.collect()"
],
"id": "7ba833ee11a2f4cc",
"outputs": [
{
"data": {
"text/plain": [
"0"
]
},
"execution_count": 76,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 76
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-03-26T15:14:44.297918Z",
"start_time": "2025-03-26T15:14:44.058234Z"
}
},
"cell_type": "code",
"source": "print(df[['ts_code', 'trade_date', 'act_factor1']].head())",
"id": "8491afd6ea37782",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" ts_code trade_date act_factor1\n",
"0 000001.SZ 2018-01-02 NaN\n",
"2536 000001.SZ 2018-01-03 NaN\n",
"5079 000001.SZ 2018-01-04 NaN\n",
"7623 000001.SZ 2018-01-05 NaN\n",
"10167 000001.SZ 2018-01-08 NaN\n"
]
}
],
"execution_count": 77
},
{
"cell_type": "code",
"id": "097356cb-1cd8-4947-b870-9414abfdb3d8",
"metadata": {
"ExecuteTime": {
"end_time": "2025-03-26T15:17:54.377205Z",
"start_time": "2025-03-26T15:14:44.467603Z"
}
},
"source": [
"days = 2\n",
"validation_days = 120\n",
"\n",
"import gc\n",
"\n",
"gc.collect()\n",
"\n",
"# df['future_return'] = df.groupby('ts_code', group_keys=False)['close'].apply(lambda x: x.shift(-days) / x - 1)\n",
"df['future_return'] = (df.groupby('ts_code')['close'].shift(-days) - df.groupby('ts_code')['open'].shift(-1)) / \\\n",
" df.groupby('ts_code')['open'].shift(-1)\n",
"df['future_volatility'] = (\n",
" df.groupby('ts_code')['future_return']\n",
" .transform(lambda x: x.rolling(days).std())\n",
")\n",
"\n",
"df['future_score'] = (\n",
" 0.7 * df['future_return'] +\n",
" 0.3 * df['future_volatility']\n",
")\n",
"\n",
"filter_index = df['future_return'].between(df['future_return'].quantile(0.01), df['future_return'].quantile(0.99))\n",
"filter_index = df['future_volatility'].between(df['future_volatility'].quantile(0.01),\n",
" df['future_volatility'].quantile(0.99)) | filter_index\n",
"\n",
"# df['label'] = df.groupby('trade_date', group_keys=False)['future_volatility'].transform(\n",
"# lambda x: pd.qcut(x, q=30, labels=False, duplicates='drop')\n",
"# )\n",
"\n",
"df['label'] = df.groupby('trade_date', group_keys=False)['future_score'].transform(\n",
" lambda x: pd.qcut(x, q=50, labels=False, duplicates='drop')\n",
")\n",
"\n",
"\n",
"# df['1_score'] = df.groupby('ts_code', group_keys=False)['future_score'].shift(days)\n",
"# df['2_score'] = df.groupby('ts_code', group_keys=False)['future_score'].shift(1 + days)\n",
"# df['3_score'] = df.groupby('ts_code', group_keys=False)['future_score'].shift(3 + days - 1)\n",
"\n",
"def symmetric_log_transform(values):\n",
" return np.sign(values) * np.log1p(np.abs(values))\n",
"\n",
"\n",
"train_data = df[filter_index & (df['trade_date'] <= '2023-01-01') & (df['trade_date'] >= '2000-01-01')]\n",
"test_data = df[filter_index & (df['trade_date'] >= '2023-01-01')]\n",
"\n",
"\n",
"def select_pre_zt_stocks_dynamic(stock_df):\n",
"    \"\"\"For each trade_date keep (at most) the 1000 rows with the highest return_20.\"\"\"\n",
"    ordered = stock_df.sort_values(by=['trade_date', 'ts_code'])\n",
"\n",
"    # Per-date top-N selection; group_keys=False keeps the original index shape.\n",
"    ordered = ordered.groupby('trade_date', group_keys=False).apply(\n",
"        lambda day: day.nlargest(1000, 'return_20')\n",
"    )\n",
"\n",
"    return ordered\n",
"\n",
"\n",
"# train_data = select_pre_zt_stocks_dynamic(train_data)\n",
"# test_data = select_pre_zt_stocks_dynamic(test_data)\n",
"\n",
"train_data, _ = get_simple_factor(train_data)\n",
"test_data, _ = get_simple_factor(test_data)\n",
"\n",
"# train_data['label'] = train_data.groupby('trade_date', group_keys=False)['future_score'].transform(\n",
"# lambda x: pd.qcut(x, q=50, labels=False, duplicates='drop')\n",
"# )\n",
"# test_data['label'] = test_data.groupby('trade_date', group_keys=False)['future_score'].transform(\n",
"# lambda x: pd.qcut(x, q=50, labels=False, duplicates='drop')\n",
"# )\n",
"\n",
"industry_df = industry_df.sort_values(by=['trade_date'])\n",
"index_data = index_data.sort_values(by=['trade_date'])\n",
"\n",
"train_data = train_data.merge(industry_df, on=['cat_l2_code', 'trade_date'], how='left')\n",
"# train_data = train_data.merge(index_data, on='trade_date', how='left')\n",
"test_data = test_data.merge(industry_df, on=['cat_l2_code', 'trade_date'], how='left')\n",
"# test_data = test_data.merge(index_data, on='trade_date', how='left')\n",
"\n",
"train_data, test_data = train_data.replace([np.inf, -np.inf], np.nan), test_data.replace([np.inf, -np.inf], np.nan)\n",
"\n",
"# feature_columns_new = feature_columns[:]\n",
"# train_data, _ = create_deviation_within_dates(train_data, feature_columns)\n",
"# test_data, _ = create_deviation_within_dates(test_data, feature_columns)\n",
"\n",
"# Start from all columns (the original filtered train_data.columns against\n",
"# itself, which is a no-op self-membership test).\n",
"feature_columns = list(train_data.columns)\n",
"feature_columns = [col for col in feature_columns if col not in ['trade_date',\n",
" 'ts_code',\n",
" 'label']]\n",
"feature_columns = [col for col in feature_columns if 'future' not in col]\n",
"feature_columns = [col for col in feature_columns if 'label' not in col]\n",
"feature_columns = [col for col in feature_columns if 'score' not in col]\n",
"feature_columns = [col for col in feature_columns if 'gen' not in col]\n",
"feature_columns = [col for col in feature_columns if 'cat_l2_code' not in col]\n",
"feature_columns = [col for col in feature_columns if col not in origin_columns]\n",
"feature_columns = [col for col in feature_columns if not col.startswith('_')]\n",
"print(f'feature_columns size: {len(feature_columns)}')\n",
"\n",
"feature_columns, _ = remove_shifted_features(train_data[train_data['label'] == train_data['label'].max()],\n",
" test_data[test_data['label'] == test_data['label'].max()],\n",
" feature_columns)\n",
"\n",
"feature_columns = remove_highly_correlated_features(train_data[train_data['label'] == train_data['label'].max()],\n",
" feature_columns)\n",
"keep_columns = [col for col in train_data.columns if\n",
" col in feature_columns or col in ['ts_code', 'trade_date', 'label', 'future_return',\n",
" 'future_score', 'future_volatility']]\n",
"# train_data = train_data[keep_columns]\n",
"print(f'feature_columns: {feature_columns}')\n",
"\n",
"train_data = train_data.dropna(subset=feature_columns)\n",
"train_data = train_data.dropna(subset=['label'])\n",
"train_data = train_data.reset_index(drop=True)\n",
"\n",
"# print(test_data.tail())\n",
"test_data = test_data.dropna(subset=feature_columns)\n",
"# test_data = test_data.dropna(subset=['label'])\n",
"test_data = test_data.reset_index(drop=True)\n",
"\n",
"print(len(train_data))\n",
"print(f\"最小日期: {train_data['trade_date'].min().strftime('%Y-%m-%d')}\")\n",
"print(f\"最大日期: {train_data['trade_date'].max().strftime('%Y-%m-%d')}\")\n",
"print(len(test_data))\n",
"print(f\"最小日期: {test_data['trade_date'].min().strftime('%Y-%m-%d')}\")\n",
"print(f\"最大日期: {test_data['trade_date'].max().strftime('%Y-%m-%d')}\")\n",
"\n",
"cat_columns = [col for col in feature_columns if col.startswith('cat')]\n",
"for col in cat_columns:\n",
" train_data[col] = train_data[col].astype('category')\n",
" test_data[col] = test_data[col].astype('category')\n",
"\n",
"\n",
"\n",
"# feature_columns_new.remove('cat_l2_code')"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"feature_columns size: 112\n",
"检测到 18 个可能漂移的特征: ['vol', 'pct_chg', 'turnover_rate', 'vol_std_5', 'obv', 'log(circ_mv)', 'cov', 'delta_cov', 'alpha_22_improved', 'alpha_003', 'up_limit_count_10d', 'log_close', 'up', 'down', 'mv_turnover_ratio', 'mv_adjusted_volume', 'mv_weighted_turnover', 'nonlinear_mv_volume']\n",
"feature_columns: ['pe_ttm', 'volume_ratio', 'winner_rate', 'return_skew', 'return_kurtosis', 'volume_change_rate', 'cat_volume_breakout', 'turnover_deviation', 'cat_turnover_spike', 'avg_volume_ratio', 'cat_volume_ratio_breakout', 'vol_spike', 'atr_14', 'maobv_6', 'rsi_3', 'return_5', 'return_10', 'return_20', 'std_return_5', 'std_return_15', 'std_return_90', 'std_return_90_2', 'act_factor1', 'act_factor2', 'act_factor3', 'act_factor4', 'rank_act_factor1', 'rank_act_factor2', 'rank_act_factor3', 'alpha_007', 'alpha_013', 'cat_up_limit', 'cat_down_limit', 'down_limit_count_10d', 'consecutive_up_limit', 'vol_break', 'weight_roc5', 'smallcap_concentration', 'cost_stability', 'high_cost_break_days', 'liquidity_risk', 'turnover_std', 'mv_volatility', 'volume_growth', 'mv_growth', 'momentum_factor', 'resonance_factor', 'cat_vol_spike', 'obv-maobv_6', 'std_return_5 / std_return_90', 'std_return_5 / std_return_25', 'std_return_90 - std_return_90_2', 'cat_af1', 'cat_af2', 'cat_af3', 'cat_af4', 'act_factor5', 'act_factor6', 'active_buy_volume_large', 'active_buy_volume_big', 'active_buy_volume_small', 'buy_lg_vol_minus_sell_lg_vol', 'buy_elg_vol_minus_sell_elg_vol', 'ctrl_strength', 'low_cost_dev', 'asymmetry', 'lock_factor', 'cat_vol_break', 'cost_atr_adj', 'cat_golden_resonance', 'mv_volume_ratio', 'mv_momentum', 'industry_obv', 'industry_return_5', 'industry_return_20', 'industry__ema_5', 'industry_act_factor1', 'industry_act_factor2', 'industry_act_factor3', 'industry_act_factor4', 'industry_act_factor5', 'industry_act_factor6', 'industry_rank_act_factor1', 'industry_rank_act_factor2', 'industry_rank_act_factor3', 'industry_return_5_percentile', 'industry_return_20_percentile']\n",
"2543987\n",
"最小日期: 2018-06-04\n",
"最大日期: 2022-12-30\n",
"1234512\n",
"最小日期: 2023-01-03\n",
"最大日期: 2025-03-19\n"
]
}
],
"execution_count": 78
},
{
"cell_type": "code",
"id": "8f134d435f71e9e2",
"metadata": {
"ExecuteTime": {
"end_time": "2025-03-26T15:17:54.732350Z",
"start_time": "2025-03-26T15:17:54.705934Z"
}
},
"source": [
"from sklearn.preprocessing import StandardScaler\n",
"import lightgbm as lgb\n",
"import matplotlib.pyplot as plt\n",
"from sklearn.decomposition import PCA\n",
"\n",
"\n",
"def train_light_model(train_data_df, params, feature_columns, callbacks, evals,\n",
"                      print_feature_importance=True, num_boost_round=100,\n",
"                      validation_days=180, use_pca=False, split_date=None,\n",
"                      label_column='label'):\n",
"    \"\"\"Train a LightGBM ranker on a chronological split of train_data_df.\n",
"\n",
"    The last `validation_days` distinct trade dates (or everything from\n",
"    `split_date` onwards) form the validation set. Returns (model, scaler,\n",
"    pca); `scaler` is an unfitted placeholder kept for interface\n",
"    compatibility, `pca` is None unless use_pca is True.\n",
"    \"\"\"\n",
"    # Keep everything in chronological order before splitting.\n",
"    train_data_df = train_data_df.sort_values(by='trade_date')\n",
"\n",
"    # Cross-sectional (per trade_date) standardization of numeric features.\n",
"    numeric_columns = train_data_df.select_dtypes(include=['float64', 'int64']).columns\n",
"    numeric_columns = [col for col in numeric_columns if col in feature_columns]\n",
"    train_data_df = cross_sectional_standardization(train_data_df, numeric_columns)\n",
"\n",
"    # Drop rows without a label.\n",
"    train_data_df = train_data_df.dropna(subset=[label_column])\n",
"    print('原始训练集大小: ', len(train_data_df))\n",
"\n",
"    # Chronological train / validation split.\n",
"    if split_date is None:\n",
"        all_dates = train_data_df['trade_date'].unique()\n",
"        split_date = all_dates[-validation_days]\n",
"    train_data_split = train_data_df[train_data_df['trade_date'] < split_date]\n",
"    val_data_split = train_data_df[train_data_df['trade_date'] >= split_date]\n",
"\n",
"    print(f\"划分后的训练集大小: {len(train_data_split)}, 验证集大小: {len(val_data_split)}\")\n",
"\n",
"    X_train = train_data_split[feature_columns]\n",
"    y_train = train_data_split[label_column]\n",
"\n",
"    X_val = val_data_split[feature_columns]\n",
"    y_val = val_data_split[label_column]\n",
"\n",
"    # Placeholder only; never fitted (kept so the return signature is stable).\n",
"    scaler = StandardScaler()\n",
"\n",
"    # Query-group sizes per trade_date (required by the ranking objective).\n",
"    train_groups = train_data_split.groupby('trade_date').size().tolist()\n",
"    val_groups = val_data_split.groupby('trade_date').size().tolist()\n",
"\n",
"    # Categorical features are identified by naming convention.\n",
"    categorical_feature = [col for col in feature_columns if 'cat' in col]\n",
"\n",
"    pca = None\n",
"    if use_pca:\n",
"        # Keep 95% of the variance of the numeric block; categoricals pass through.\n",
"        pca = PCA(n_components=0.95)\n",
"        numeric_features = [col for col in feature_columns if col not in categorical_feature]\n",
"        numeric_pca = pca.fit_transform(X_train[numeric_features])\n",
"        X_train = pd.concat([pd.DataFrame(numeric_pca, index=X_train.index), X_train[categorical_feature]], axis=1)\n",
"\n",
"        numeric_pca = pca.transform(X_val[numeric_features])\n",
"        X_val = pd.concat([pd.DataFrame(numeric_pca, index=X_val.index), X_val[categorical_feature]], axis=1)\n",
"\n",
"    # Recency weighting: quadratically increasing sample weight by trade date.\n",
"    # BUG FIX: the original stored these weights in params['weight'], which\n",
"    # LightGBM does not read as sample weights and which leaked a huge list\n",
"    # into the caller's params dict (later reused for incremental training).\n",
"    # Sample weights belong on the Dataset itself.\n",
"    ud = sorted(train_data_split[\"trade_date\"].unique().tolist())\n",
"    date_weights = {date: weight * weight for date, weight in zip(ud, np.linspace(1, 10, len(ud)))}\n",
"    sample_weight = train_data_split[\"trade_date\"].map(date_weights).tolist()\n",
"\n",
"    print('feature_columns size: ', len(X_train.columns.tolist()))\n",
"\n",
"    train_dataset = lgb.Dataset(\n",
"        X_train, label=y_train, group=train_groups, weight=sample_weight,\n",
"        categorical_feature=categorical_feature\n",
"    )\n",
"\n",
"    val_dataset = lgb.Dataset(\n",
"        X_val, label=y_val, group=val_groups,\n",
"        categorical_feature=categorical_feature\n",
"    )\n",
"\n",
"    # Train with logging / early stopping handled by the supplied callbacks.\n",
"    model = lgb.train(\n",
"        params, train_dataset, num_boost_round=num_boost_round,\n",
"        valid_sets=[train_dataset, val_dataset], valid_names=['train', 'valid'],\n",
"        callbacks=callbacks\n",
"    )\n",
"\n",
"    # Optional diagnostics: metric curves and top-20 split importances.\n",
"    if print_feature_importance:\n",
"        lgb.plot_metric(evals)\n",
"        lgb.plot_importance(model, importance_type='split', max_num_features=20)\n",
"        plt.show()\n",
"\n",
"    return model, scaler, pca\n"
],
"outputs": [],
"execution_count": 79
},
{
"cell_type": "code",
"id": "beeb098799ecfa6a",
"metadata": {
"ExecuteTime": {
"end_time": "2025-03-26T15:21:10.639294Z",
"start_time": "2025-03-26T15:17:54.865846Z"
}
},
"source": [
"print('train data size: ', len(train_data))\n",
"\n",
"if 'gen_volatility' in feature_columns:\n",
" feature_columns.remove('gen_volatility')\n",
"\n",
"label_gain = list(range(len(train_data['label'].unique())))\n",
"label_gain = [gain * gain for gain in label_gain]\n",
"light_params = {\n",
" 'label_gain': label_gain,\n",
" 'objective': 'lambdarank',\n",
" 'metric': 'lambdarank',\n",
" 'learning_rate': 0.03,\n",
" 'num_leaves': 1024,\n",
" 'min_data_in_leaf': 512,\n",
" 'max_depth': 32,\n",
" 'max_bin': 1024,\n",
" 'feature_fraction': 0.7,\n",
" 'bagging_fraction': 1,\n",
" 'bagging_freq': 5,\n",
" 'lambda_l1': 0.15,\n",
" 'lambda_l2': 0.15,\n",
" # 'boosting': 'dart',\n",
" 'verbosity': -1,\n",
" 'extra_trees': True,\n",
" 'max_position': 5,\n",
" 'ndcg_at': 1,\n",
" 'seed': 7\n",
"}\n",
"evals = {}\n",
"\n",
"gc.collect()\n",
"\n",
"use_pca = False\n",
"feature_contri = [2 if feat.startswith('act_factor') else 1 for feat in feature_columns]\n",
"light_params['feature_contri'] = feature_contri\n",
"print(f'feature_contri: {feature_contri}')\n",
"model, scaler, pca = train_light_model(train_data.copy().dropna(subset=['label']),\n",
" light_params, feature_columns,\n",
" [lgb.log_evaluation(period=100),\n",
" lgb.callback.record_evaluation(evals),\n",
" lgb.early_stopping(50, first_metric_only=True)\n",
" ], evals,\n",
" num_boost_round=1000, validation_days=validation_days,\n",
" print_feature_importance=True, use_pca=use_pca)\n",
"\n",
"print('train data size: ', len(train_data))"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"train data size: 2543987\n",
"feature_contri: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]\n",
"原始训练集大小: 2543987\n",
"划分后的训练集大小: 2262535, 验证集大小: 281452\n",
"feature_columns size: 87\n",
"Training until validation scores don't improve for 50 rounds\n",
"Early stopping, best iteration is:\n",
"[18]\ttrain's ndcg@1: 0.649301\tvalid's ndcg@1: 0.586988\n",
"Evaluated only: ndcg@1\n"
]
},
{
"data": {
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
],
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAjcAAAHHCAYAAABDUnkqAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjEsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvc2/+5QAAAAlwSFlzAAAPYQAAD2EBqD+naQAAewZJREFUeJzt3Xd4lFXax/HvpPdCKr33TigCKqIggqJYURAFFRVh1WV5Xbuiq7gWRFHBVRHXFcWKBUEBKaL03juEllBCOmkzz/vHIYGYQApJJpn8Ptc1V2aeecqZk2TmnnPuc47NsiwLERERERfh5uwCiIiIiJQlBTciIiLiUhTciIiIiEtRcCMiIiIuRcGNiIiIuBQFNyIiIuJSFNyIiIiIS1FwIyIiIi5FwY2IiIi4FAU3IsL06dOx2Wzs37+/3K7x/PPPY7PZqsx5nW3//v3YbDamT59equNtNhvPP/98mZZJpKpQcCNSgXKDCJvNxtKlSws8b1kWdevWxWazcd1115XqGu+9916pPxClZGbMmMGkSZOcXQwR+QsFNyJO4OPjw4wZMwpsX7x4MYcOHcLb27vU5y5NcDNs2DBOnz5N/fr1S31dZ3n66ac5ffq0U65dnsFN/fr1OX36NMOGDSvV8adPn+bpp58u41KJVA0KbkScYMCAAXz11Vfk5OTk2z5jxgxiYmKIjo6ukHKkpaUB4O7ujo+PT5Xq3sktu4eHBz4+Pk4uTdEyMjJwOBzF3t9ms+Hj44O7u3uprufj44OHh0epjhWp6hTciDjBHXfcwcmTJ5k3b17etqysLL7++muGDBlS6DEOh4NJkybRunVrfHx8iIqK4oEHHuDUqVN5+zRo0IAtW7awePHivO6vK664AjjbJbZ48WIeeughIiMjqVOnTr7n/ppzM2fOHHr16kVgYCBBQUF06dKl0Banv1q6dCldunTBx8eHxo0b8/777xfY50I5JX/NF8nNq9m6dStDhgwhNDSUSy+9NN9zfz1+zJgxzJo1izZt2uDt7U3r1q2ZO3dugWstWrSIzp075ytrcfJ4rrjiCmbPns2BAwfy6rpBgwZ557TZbHzxxRc8/fTT1K5dGz8/P5KTk0lISGDcuHG0bduWgIAAgoKC6N+/Pxs2bCiyfoYPH05AQACHDx9m0KBBBAQEEBERwbhx47Db7cWqw927dzN8+HBCQkIIDg5mxIgRpKen5zv29OnTPPzww4SHhxMYGMj111/P4cOHlccjVYbCehEnaNCgAd27d+fzzz+nf//+gAkkkpKSuP3223n77bcLHPPAAw8wffp0RowYwcMPP8y+fft45513WLduHX/88Qeenp5MmjSJv/3tbwQEBPDUU08BEBUVle88Dz30EBERETz77LN5rR+FmT59Ovfccw+tW7fmiSeeICQkhHXr1jF37tzzBmAAmzZt4uqrryYiIoLnn3+enJwcnnvuuQLlKI1bb72Vpk2b8vLLL2NZ1gX3Xbp0Kd9++y0PPfQQgYGBvP3229x8883ExsYSFhYGwLp167jmmmuoWbMm48ePx26388ILLxAREVFkWZ566imSkpI4dOgQb775JgABAQH59nnxxRfx8vJi3LhxZGZm4uXlxdatW5k1axa33norDRs2JD4+nvfff59evXqxdetWatWqdcHr2u12+vXrR7du3Xj99deZP38+b7zxBo0bN2bUqFFFlvu2226jYcOGTJgwgbVr1/Lhhx8SGRnJv//977x9hg8fzpdffsmwYcO45JJLWLx4Mddee22R5xapNCwRqTAff/yxBVirVq2y3nnnHSswMNBKT0+3LMuybr31Vqt3796WZVlW/fr1rWuvvTbvuN9//90CrM8++yzf+ebOnVtge+vWra1evXqd99qXXnqplZOTU+hz+/btsyzLshITE63AwECrW7du1unTp/Pt63A4LvgaBw0aZPn4+FgHDhzI27Z161bL3d3dOvctZ9++fRZgffzxxwXOAV
jPPfdc3uPnnnvOAqw77rijwL65z/31eC8vL2v37t152zZs2GAB1uTJk/O2DRw40PLz87MOHz6ct23Xrl2Wh4dHgXMW5tprr7Xq169fYPvChQstwGrUqFHe7zdXRkaGZbfb823bt2+f5e3tbb3wwgv5tv21fu6++24LyLefZVlWx44drZiYmAJ1UFgd3nPPPfn2u/HGG62wsLC8x2vWrLEA69FHH8233/DhwwucU6SyUreUiJPcdtttnD59mp9++omUlBR++umn87aIfPXVVwQHB9O3b19OnDiRd4uJiSEgIICFCxcW+7ojR44sMo9j3rx5pKSk8PjjjxfIZ7lQd43dbueXX35h0KBB1KtXL297y5Yt6devX7HLeD4PPvhgsfft06cPjRs3znvcrl07goKC2Lt3b15Z58+fz6BBg/K1ljRp0iSvNe1i3X333fj6+ubb5u3tjZubW14ZTp48SUBAAM2bN2ft2rXFOu9f6+Gyyy7Le12lOfbkyZMkJycD5HXdPfTQQ/n2+9vf/las84tUBuqWEnGSiIgI+vTpw4wZM0hPT8dut3PLLbcUuu+uXbtISkoiMjKy0OePHTtW7Os2bNiwyH327NkDQJs2bYp9XoDjx49z+vRpmjZtWuC55s2b8/PPP5fofH9VnLLnOje4yhUaGpqXo3Ts2DFOnz5NkyZNCuxX2LbSKKy8DoeDt956i/fee499+/bly5XJ7S67EB8fnwLdZue+rqL8tV5CQ0MBOHXqFEFBQRw4cAA3N7cCZS+rOhGpCApuRJxoyJAhjBw5kri4OPr3709ISEih+zkcDiIjI/nss88Kfb44OSK5/tqS4CznawH6a2LsuUpS9vO1TllF5OqUpcLK+/LLL/PMM89wzz338OKLL1KjRg3c3Nx49NFHizWaqrSjp4o6viLrRaS8KbgRcaIbb7yRBx54gOXLlzNz5szz7te4cWPmz59Pz549i/yAL4vh3LndOZs3by7RN/aIiAh8fX3ZtWtXged27NiR73Fui0FiYmK+7QcOHChhaUsnMjISHx8fdu/eXeC5wrYVpjR1/fXXX9O7d28++uijfNsTExMJDw8v8fnKWv369XE4HOzbty9fC1xx60SkMlDOjYgTBQQEMGXKFJ5//nkGDhx43v1uu+027HY7L774YoHncnJy8gUI/v7+BQKGkrr66qsJDAxkwoQJZGRk5HvuQt/w3d3d6devH7NmzSI2NjZv+7Zt2/jll1/y7RsUFER4eDhLlizJt/299967qLIXl7u7O3369GHWrFkcOXIkb/vu3buZM2dOsc7h7+9PUlJSia/71zr86quvOHz4cInOU15yc6P++nuYPHmyM4ojUipquRFxsrvvvrvIfXr16sUDDzzAhAkTWL9+PVdffTWenp7s2rWLr776irfeeisvXycmJoYpU6bwr3/9iyZNmhAZGcmVV15ZojIFBQXx5ptvct9999GlS5e8uWU2bNhAeno6n3zyyXmPHT9+PHPnzuWyyy7joYceIicnh8mTJ9O6dWs2btyYb9/77ruPV155hfvuu4/OnTuzZMkSdu7cWaKyXoznn3+eX3/9lZ49ezJq1CjsdjvvvPMObdq0Yf369UUeHxMTw8yZMxk7dixdunQhICDggkEqwHXXXccLL7zAiBEj6NGjB5s2beKzzz6jUaNGZfSqLk5MTAw333wzkyZN4uTJk3lDwXN/L1VpokepvhTciFQRU6dOJSYmhvfff58nn3wSDw8PGjRowJ133knPnj3z9nv22Wc5cOAAr776KikpKfTq1avEwQ3AvffeS2RkJK+88govvvginp6etGjRgr///e8XPK5du3b88ssvjB07lmeffZY6deowfvx4jh49WiC4efbZZzl+/Dhff/01X375Jf3792fOnDnnTZwuazExMcyZM4dx48bxzDPPULduXV544QW2bdvG9u3bizz+oYceYv369Xz88ce8+eab1K9fv8jg5sknnyQtLY0ZM2Ywc+
ZMOnXqxOzZs3n88cfL6mVdtP/+979ER0fz+eef891339GnTx9mzpxJ8+bNq8Rs0CI2S1lkIiL5DBo0iC1bthSaO1R
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
],
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAvoAAAHHCAYAAADOE/w7AAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjEsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvc2/+5QAAAAlwSFlzAAAPYQAAD2EBqD+naQAA7+ZJREFUeJzs3XdYFFf78PHv0osUQRQwKKiI2LF3UVGwYA8WomDEGhJ7i4pgCYoNWzQxUezGJ5YUjYqFWGOLPTZQRI2JJSoiipR9//Blfq6AFFGUvT/XtRfMzJk59727sGfPnDmjUqvVaoQQQgghhBCFik5BByCEEEIIIYTIf9LQF0IIIYQQohCShr4QQgghhBCFkDT0hRBCCCGEKISkoS+EEEIIIUQhJA19IYQQQgghCiFp6AshhBBCCFEISUNfCCGEEEKIQkga+kIIIYQQQhRC0tAXQgghPgARERGoVCpiY2MLOhQhxAdCGvpCCCHeS+kN28weY8eOfSt1Hjp0iODgYB4+fPhWjq/NEhMTCQ4OJioqqqBDEUJr6BV0AEIIIcTrTJ48GScnJ411lStXfit1HTp0iJCQEPz9/bG0tHwrdeRVr1696N69O4aGhgUdSp4kJiYSEhICgLu7e8EGI4SWkIa+EEKI91rr1q2pVatWQYfxRp48eYKpqekbHUNXVxddXd18iujdSUtL4/nz5wUdhhBaSYbuCCGE+KD99ttvNG7cGFNTU8zMzGjbti3nz5/XKHPmzBn8/f0pU6YMRkZG2Nra8umnn3L//n2lTHBwMKNGjQLAyclJGSYUGxtLbGwsKpWKiIiIDPWrVCqCg4M1jqNSqfjrr7/o2bMnRYsWpVGjRsr21atXU7NmTYyNjbGysqJ79+7cuHEj2zwzG6Pv6OhIu3btiIqKolatWhgbG1OlShVleMymTZuoUqUKRkZG1KxZk5MnT2oc09/fnyJFinD16lU8PT0xNTXF3t6eyZMno1arNco+efKEESNG4ODggKGhIS4uLsyaNStDOZVKRWBgIGvWrKFSpUoYGhqyZMkSbGxsAAgJCVGe2/TnLSevz8vPbXR0tHLWxcLCgj59+pCYmJjhOVu9ejV16tTBxMSEokWL0qRJE3bu3KlRJifvHyE+VNKjL4QQ4r326NEj7t27p7GuWLFiAKxatQo/Pz88PT2ZMWMGiYmJLF68mEaNGnHy5EkcHR0BiIyM5OrVq/Tp0wdbW1vOnz/Pt99+y/nz5/njjz9QqVR07tyZy5cvs27dOubOnavUYWNjw927d3Md98cff4yzszNfffWV0hieNm0aEydOxMfHh4CAAO7evcuCBQto0qQJJ0+ezNNwoejoaHr27MmAAQP45JNPmDVrFt7e3ixZsoQvv/ySwYMHAxAaGoqPjw+XLl1CR+f/+vlSU1Px8vKiXr16hIWFsX37diZNmkRKSgqTJ08GQK1W0759e/bu3Uvfvn2pXr06O3bsYNSoUdy6dYu5c+dqxLRnzx42bNhAYGAgxYoVo1q1aixevJhBgwbRqVMnOnfuDEDVqlWBnL0+L/Px8cHJyYnQ0FD+/PNPvvvuO4oXL86MGTOUMiEhIQQHB9OgQQMmT56MgYEBR44cYc+ePbRq1QrI+ftHiA+WWgghhHgPLV++XA1k+lCr1erHjx+rLS0t1f369dPY759//lFbWFhorE9MTMxw/HXr1qkB9b59+5R1M2fOVAPqa9euaZS9du2aGlAvX748w3EA9aRJk5TlSZMmqQF1jx49NMrFxsaqdXV11dOmTdNYf/bsWbWenl6G9Vk9Hy/HVrp0aTWgPnTokLJux44dakBtbGysvn79urL+m2++UQPqvXv3Kuv8/PzUgPrzzz9X1qWlpanbtm2rNjAwUN+9e1etVqvVW7ZsUQPqqVOnasTUtWtXtUqlUkdHR2s8Hzo6Ourz589rlL17926G5ypdTl+f9Of2008/1SjbqVMntbW1tbJ85coVtY6Ojr
pTp07q1NRUjbJpaWlqtTp37x8hPlQydEcIIcR7bdGiRURGRmo84EUv8MOHD+nRowf37t1THrq6utStW5e9e/cqxzA2NlZ+f/bsGffu3aNevXoA/Pnnn28l7oEDB2osb9q0ibS0NHx8fDTitbW1xdnZWSPe3KhYsSL169dXluvWrQtA8+bNKVWqVIb1V69ezXCMwMBA5ff0oTfPnz9n165dAGzbtg1dXV2++OILjf1GjBiBWq3mt99+01jftGlTKlasmOMccvv6vPrcNm7cmPv37xMfHw/Ali1bSEtLIygoSOPsRXp+kLv3jxAfKhm6I4QQ4r1Wp06dTC/GvXLlCvCiQZsZc3Nz5ff//vuPkJAQ1q9fz507dzTKPXr0KB+j/T+vzhR05coV1Go1zs7OmZbX19fPUz0vN+YBLCwsAHBwcMh0/YMHDzTW6+joUKZMGY115cuXB1CuB7h+/Tr29vaYmZlplHN1dVW2v+zV3LOT29fn1ZyLFi0KvMjN3NycmJgYdHR0XvtlIzfvHyE+VNLQF0II8UFKS0sDXoyztrW1zbBdT+//PuJ8fHw4dOgQo0aNonr16hQpUoS0tDS8vLyU47zOq2PE06Wmpma5z8u91OnxqlQqfvvtt0xnzylSpEi2cWQmq5l4slqvfuXi2bfh1dyzk9vXJz9yy837R4gPlbyLhRBCfJDKli0LQPHixfHw8Miy3IMHD9i9ezchISEEBQUp69N7dF+WVYM+vcf41RtpvdqTnV28arUaJycnpcf8fZCWlsbVq1c1Yrp8+TKAcjFq6dKl2bVrF48fP9bo1b948aKyPTtZPbe5eX1yqmzZsqSlpfHXX39RvXr1LMtA9u8fIT5kMkZfCCHEB8nT0xNzc3O++uorkpOTM2xPnyknvff31d7e8PDwDPukz3X/aoPe3NycYsWKsW/fPo31X3/9dY7j7dy5M7q6uoSEhGSIRa1WZ5hK8l1auHChRiwLFy5EX1+fFi1aANCmTRtSU1M1ygHMnTsXlUpF69ats63DxMQEyPjc5ub1yamOHTuio6PD5MmTM5wRSK8np+8fIT5k0qMvhBDig2Rubs7ixYvp1asXNWrUoHv37tjY2BAXF8fWrVtp2LAhCxcuxNzcnCZNmhAWFkZycjIlS5Zk586dXLt2LcMxa9asCcD48ePp3r07+vr6eHt7Y2pqSkBAANOnTycgIIBatWqxb98+pec7J8qWLcvUqVMZN24csbGxdOzYETMzM65du8bmzZvp378/I0eOzLfnJ6eMjIzYvn07fn5+1K1bl99++42tW7fy5ZdfKnPfe3t706xZM8aPH09sbCzVqlVj586d/PTTTwwdOlTpHX8dY2NjKlasyA8//ED58uWxsrKicuXKVK5cOcevT06VK1eO8ePHM2XKFBo3bkznzp0xNDTk2LFj2NvbExoamuP3jxAftAKa7UcIIYR4rfTpJI8dO/bacnv37lV7enqqLSws1EZGRuqyZcuq/f391cePH1fK3Lx5U92pUye1paWl2sLCQv3xxx+r//7770yne5wyZYq6ZMmSah0dHY3pLBMTE9V9+/ZVW1hYqM3MzNQ+Pj7qO3fuZDm9ZvrUlK/auHGjulGjRmpTU1O1qampukKFCurPPvtMfenSpRw9H69Or9m2bdsMZQH1Z599prEufYrQmTNnKuv8/PzUpqam6piYGHWrVq3UJiYm6hIlSqgnTZqUYVrKx48fq4cNG6a2t7dX6+vrq52dndUzZ85Upqt8Xd3pDh06pK5Zs6bawMBA43nL6euT1XOb2XOjVqvVy5YtU7u5uakNDQ3VRYsWVTdt2lQdGRmpUSYn7x8hPlQqtfodXJUjhBBCiPeOv78/P/74IwkJCQUdihDiLZAx+kIIIYQQQhRC0tAXQgghhBCiEJKGvhBCCCGEEIWQjNEXQgghhBCiEJIefSGEEEIIIQohaegLIYQQQghRCMkNs4TQUmlpafz999+YmZlleWt6IYQQQrxf1G
o1jx8/xt7eHh2d1/fZS0NfCC31999/4+DgUNBhCCGEECIPbty4wUcfffTaMtLQF0JLmZmZAXDt2jWsrKwKOJq3Lzk
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"train data size: 2543987\n"
]
}
],
"execution_count": 80
},
{
"cell_type": "code",
"id": "63235069-dc59-48fb-961a-e80373e41a61",
"metadata": {
"ExecuteTime": {
"end_time": "2025-03-26T15:21:10.902087Z",
"start_time": "2025-03-26T15:21:10.893389Z"
}
},
"source": [
"print('train data size: ', len(train_data))\n",
"\n",
"catboost_params = {\n",
" 'loss_function': 'QuerySoftMax', # 排序损失函数\n",
" 'custom_metric': ['NDCG', 'AverageGain:top=10'],\n",
" 'iterations': 5000, # 训练轮数\n",
" 'learning_rate': 0.05, # 学习率\n",
" 'depth': 10, # 树的深度,防止过拟合\n",
" # 'l2_leaf_reg': 10.0, # L2 正则化,提高泛化能力\n",
" # 'bagging_temperature': 1, # 降低过拟合\n",
" # 'subsample': 0.8, # 每轮随机 80% 样本\n",
" # 'colsample_bylevel': 0.8, # 每层 80% 特征\n",
" 'random_seed': 42, # 固定随机种子\n",
" 'verbose': 100, # 每 100 轮打印一次信息\n",
" 'early_stopping_rounds': 100, # 早停,防止过拟合\n",
" 'has_time': True, # 让模型知道数据有时间顺序\n",
" # 'task_type':\"GPU\",\n",
"}\n",
"\n",
"# model = train_catboost(train_data, test_data, feature_columns_new, catboost_params, plot=True)"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"train data size: 2543987\n"
]
}
],
"execution_count": 81
},
{
"cell_type": "code",
"id": "5d1522a7538db91b",
"metadata": {
"ExecuteTime": {
"end_time": "2025-03-26T15:21:10.967045Z",
"start_time": "2025-03-26T15:21:10.950621Z"
}
},
"source": [
"from tqdm import tqdm\n",
"\n",
"\n",
"def incremental_training(test_data: pd.DataFrame,\n",
" model,\n",
" scaler,\n",
" days: int,\n",
" back_days: int,\n",
" feature_columns: list,\n",
" params: dict,\n",
" model_type: str = 'lightgbm',\n",
" times=10\n",
" ):\n",
" if model_type not in ['lightgbm', 'catboost']:\n",
" raise ValueError(\"model_type must be either 'lightgbm' or 'catboost'\")\n",
"\n",
" test_data = test_data.sort_values(by='trade_date')\n",
" scores = []\n",
" unique_trade_dates = sorted(test_data['trade_date'].unique())\n",
"\n",
" new_model = None\n",
" current_times = 0\n",
" for i in tqdm(range(0, len(unique_trade_dates))):\n",
" # Get the current window of trade dates\n",
" current_dates = [unique_trade_dates[i]]\n",
" window_data = test_data[test_data['trade_date'].isin(current_dates)]\n",
" X = window_data[feature_columns]\n",
"\n",
" if new_model is not None:\n",
" window_scores = new_model.predict(X, prediction_type='RawFormulaVal')\n",
" else:\n",
" window_scores = model.predict(X, prediction_type='RawFormulaVal')\n",
" scores.extend(window_scores)\n",
" current_times += 1\n",
"\n",
" if current_times % times == 0:\n",
" current_dates = unique_trade_dates[max(0, i - days - back_days):i + 1 - back_days]\n",
" window_data = test_data[test_data['trade_date'].isin(current_dates)]\n",
" X_train = window_data[feature_columns]\n",
" y_train = window_data['label'] # Assuming 'label' is what you're predicting\n",
"\n",
" # Incrementally train the model\n",
" if len(y_train.unique()) > 1:\n",
" if model_type == 'lightgbm':\n",
" categorical_feature = [i for i, col in enumerate(feature_columns) if col.startswith('cat')]\n",
" train_groups = window_data.groupby('trade_date').size().tolist()\n",
" train_data = lgb.Dataset(X_train, label=y_train, group=train_groups,\n",
" categorical_feature=categorical_feature)\n",
" new_model = lgb.train(params,\n",
" train_set=train_data,\n",
" num_boost_round=24,\n",
" init_model=model,\n",
" keep_training_booster=True)\n",
" # print(f\"Number of trees: {model.num_trees()}\")\n",
" elif model_type == 'catboost':\n",
" from catboost import Pool\n",
" train_data = Pool(data=X_train, label=y_train,\n",
" cat_features=[col for col in feature_columns if col.startswith('cat')])\n",
" # model.set_params(**params)\n",
" model.fit(train_data, init_model=model)\n",
" # else:\n",
" # print(current_dates)\n",
"\n",
" # Add the scores as a new 'score' column to the test_data\n",
" test_data['score'] = scores\n",
" return test_data"
],
"outputs": [],
"execution_count": 82
},
{
"cell_type": "code",
"id": "bbcc55a58ee063d6",
"metadata": {
"ExecuteTime": {
"end_time": "2025-03-26T15:42:15.991756Z",
"start_time": "2025-03-26T15:36:18.479450Z"
}
},
"source": [
"numeric_columns = test_data.select_dtypes(include=['float64', 'int64']).columns\n",
"numeric_columns = [col for col in numeric_columns if col in feature_columns]\n",
"td = cross_sectional_standardization(test_data, numeric_columns)\n",
"predictions_test = incremental_training(td, model, scaler, 180, days, feature_columns, light_params,\n",
" model_type='lightgbm', times=20)\n",
"predictions_test = predictions_test.loc[predictions_test.groupby('trade_date')['score'].idxmax()]\n",
"predictions_test[['trade_date', 'score', 'ts_code']].to_csv('predictions_test.tsv', index=False)"
],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 533/533 [05:13<00:00, 1.70it/s]\n"
]
}
],
"execution_count": 86
},
{
"cell_type": "code",
"id": "020c3e3b-388b-42aa-a089-895057230122",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": ""
},
"tags": [],
"ExecuteTime": {
"end_time": "2025-03-26T15:34:41.558604Z",
"start_time": "2025-03-26T15:34:40.969894Z"
}
},
"source": "print(df[['ts_code', 'trade_date', 'act_factor1']].head())\n",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" ts_code trade_date act_factor1\n",
"0 000001.SZ 2018-01-02 NaN\n",
"2536 000001.SZ 2018-01-03 NaN\n",
"5079 000001.SZ 2018-01-04 NaN\n",
"7623 000001.SZ 2018-01-05 NaN\n",
"10167 000001.SZ 2018-01-08 NaN\n"
]
}
],
"execution_count": 84
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.11"
}
},
"nbformat": 4,
"nbformat_minor": 5
}