2025-02-12 00:21:33 +08:00
|
|
|
|
{
|
|
|
|
|
|
"cells": [
|
|
|
|
|
|
{
|
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
|
"id": "79a7758178bafdd3",
|
|
|
|
|
|
"metadata": {
|
|
|
|
|
|
"jupyter": {
|
|
|
|
|
|
"source_hidden": true
|
2025-02-15 23:33:34 +08:00
|
|
|
|
},
|
|
|
|
|
|
"ExecuteTime": {
|
|
|
|
|
|
"end_time": "2025-02-14T17:33:40.203900Z",
|
|
|
|
|
|
"start_time": "2025-02-14T17:33:40.121936Z"
|
2025-02-12 00:21:33 +08:00
|
|
|
|
}
|
|
|
|
|
|
},
|
|
|
|
|
|
"source": [
|
|
|
|
|
|
"%load_ext autoreload\n",
|
|
|
|
|
|
"%autoreload 2\n",
|
|
|
|
|
|
"\n",
|
|
|
|
|
|
"import pandas as pd\n",
|
2025-02-15 23:33:34 +08:00
|
|
|
|
"def read_and_merge_h5_data(h5_filename, key, columns, df=None, join='left'):\n",
|
2025-02-12 00:21:33 +08:00
|
|
|
|
" \"\"\"\n",
|
|
|
|
|
|
" 读取 HDF5 文件中的数据,根据指定的 columns 筛选数据,\n",
|
|
|
|
|
|
" 如果传入 df 参数,则将其与读取的数据根据 ts_code 和 trade_date 合并。\n",
|
|
|
|
|
|
"\n",
|
|
|
|
|
|
" 参数:\n",
|
|
|
|
|
|
" - h5_filename: HDF5 文件名\n",
|
|
|
|
|
|
" - key: 数据存储在 HDF5 文件中的 key\n",
|
|
|
|
|
|
" - columns: 要读取的列名列表\n",
|
|
|
|
|
|
"    - df: 需要合并的 DataFrame(如果为空,则不进行合并)\n",
"    - join: 合并方式,作为 pd.merge 的 how 参数(默认 'left')\n",
|
|
|
|
|
|
"\n",
|
|
|
|
|
|
" 返回:\n",
|
|
|
|
|
|
" - 合并后的 DataFrame\n",
|
|
|
|
|
|
" \"\"\"\n",
|
|
|
|
|
|
" # 处理 _ 开头的列名\n",
|
|
|
|
|
|
" processed_columns = []\n",
|
|
|
|
|
|
" for col in columns:\n",
|
|
|
|
|
|
" if col.startswith('_'):\n",
|
|
|
|
|
|
" processed_columns.append(col[1:]) # 去掉下划线\n",
|
|
|
|
|
|
" else:\n",
|
|
|
|
|
|
" processed_columns.append(col)\n",
|
|
|
|
|
|
"\n",
|
|
|
|
|
|
" # 从 HDF5 文件读取数据,选择需要的列\n",
|
|
|
|
|
|
" data = pd.read_hdf(h5_filename, key=key, columns=processed_columns)\n",
|
|
|
|
|
|
"\n",
|
|
|
|
|
|
"    # 恢复列名:若该列在原始 columns 中以 _ 前缀出现,则重新加上 _ 前缀\n",
|
|
|
|
|
|
" for col in data.columns:\n",
|
|
|
|
|
|
" if col not in columns: # 只有不在 columns 中的列才需要加下划线\n",
|
|
|
|
|
|
" new_col = f'_{col}'\n",
|
|
|
|
|
|
" data.rename(columns={col: new_col}, inplace=True)\n",
|
|
|
|
|
|
"\n",
|
|
|
|
|
|
" # 如果传入的 df 不为空,则进行合并\n",
|
|
|
|
|
|
" if df is not None and not df.empty:\n",
|
|
|
|
|
|
" # 确保两个 DataFrame 都有 ts_code 和 trade_date 列\n",
|
|
|
|
|
|
" df['trade_date'] = pd.to_datetime(df['trade_date'], format='%Y%m%d')\n",
|
|
|
|
|
|
" data['trade_date'] = pd.to_datetime(data['trade_date'], format='%Y%m%d')\n",
|
|
|
|
|
|
"\n",
|
|
|
|
|
|
" # 根据 ts_code 和 trade_date 合并\n",
|
2025-02-15 23:33:34 +08:00
|
|
|
|
" merged_df = pd.merge(df, data, on=['ts_code', 'trade_date'], how=join)\n",
|
2025-02-12 00:21:33 +08:00
|
|
|
|
" else:\n",
|
|
|
|
|
|
" # 如果 df 为空,则直接返回读取的数据\n",
|
|
|
|
|
|
" merged_df = data\n",
|
|
|
|
|
|
"\n",
|
|
|
|
|
|
" return merged_df\n",
|
|
|
|
|
|
"\n"
|
2025-02-15 23:33:34 +08:00
|
|
|
|
],
|
2025-02-12 00:21:33 +08:00
|
|
|
|
"outputs": [
|
|
|
|
|
|
{
|
|
|
|
|
|
"name": "stdout",
|
|
|
|
|
|
"output_type": "stream",
|
|
|
|
|
|
"text": [
|
2025-02-15 23:33:34 +08:00
|
|
|
|
"The autoreload extension is already loaded. To reload it, use:\n",
|
|
|
|
|
|
" %reload_ext autoreload\n"
|
2025-02-12 00:21:33 +08:00
|
|
|
|
]
|
|
|
|
|
|
}
|
|
|
|
|
|
],
|
2025-02-15 23:33:34 +08:00
|
|
|
|
"execution_count": 32
|
|
|
|
|
|
},
|
|
|
|
|
|
{
|
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
|
"id": "a79cafb06a7e0e43",
|
|
|
|
|
|
"metadata": {
|
|
|
|
|
|
"ExecuteTime": {
|
|
|
|
|
|
"end_time": "2025-02-14T17:34:23.845554Z",
|
|
|
|
|
|
"start_time": "2025-02-14T17:33:40.211865Z"
|
|
|
|
|
|
}
|
|
|
|
|
|
},
|
2025-02-12 00:21:33 +08:00
|
|
|
|
"source": [
|
|
|
|
|
|
"print('daily data')\n",
|
|
|
|
|
|
"df = read_and_merge_h5_data('../../data/daily_data.h5', key='daily_data',\n",
|
|
|
|
|
|
" columns=['ts_code', 'trade_date', 'open', 'close', 'high', 'low', 'vol'],\n",
|
|
|
|
|
|
" df=None)\n",
|
|
|
|
|
|
"\n",
|
|
|
|
|
|
"print('daily basic')\n",
|
|
|
|
|
|
"df = read_and_merge_h5_data('../../data/daily_basic.h5', key='daily_basic',\n",
|
|
|
|
|
|
" columns=['ts_code', 'trade_date', 'turnover_rate', 'pe_ttm', 'circ_mv', 'volume_ratio',\n",
|
2025-02-15 23:33:34 +08:00
|
|
|
|
" 'is_st'], df=df, join='inner')\n",
|
2025-02-12 00:21:33 +08:00
|
|
|
|
"\n",
|
|
|
|
|
|
"print('stk limit')\n",
|
|
|
|
|
|
"df = read_and_merge_h5_data('../../data/stk_limit.h5', key='stk_limit',\n",
|
|
|
|
|
|
" columns=['ts_code', 'trade_date', 'pre_close', 'up_limit', 'down_limit'],\n",
|
|
|
|
|
|
" df=df)\n",
|
|
|
|
|
|
"print('money flow')\n",
|
|
|
|
|
|
"df = read_and_merge_h5_data('../../data/money_flow.h5', key='money_flow',\n",
|
|
|
|
|
|
" columns=['ts_code', 'trade_date', 'buy_sm_vol', 'sell_sm_vol', 'buy_lg_vol', 'sell_lg_vol',\n",
|
|
|
|
|
|
" 'buy_elg_vol', 'sell_elg_vol', 'net_mf_vol'],\n",
|
|
|
|
|
|
" df=df)"
|
2025-02-15 23:33:34 +08:00
|
|
|
|
],
|
|
|
|
|
|
"outputs": [
|
|
|
|
|
|
{
|
|
|
|
|
|
"name": "stdout",
|
|
|
|
|
|
"output_type": "stream",
|
|
|
|
|
|
"text": [
|
|
|
|
|
|
"daily data\n",
|
|
|
|
|
|
"daily basic\n",
|
|
|
|
|
|
"stk limit\n",
|
|
|
|
|
|
"money flow\n"
|
|
|
|
|
|
]
|
|
|
|
|
|
}
|
|
|
|
|
|
],
|
|
|
|
|
|
"execution_count": 33
|
2025-02-12 00:21:33 +08:00
|
|
|
|
},
|
|
|
|
|
|
{
|
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
|
"id": "c4e9e1d31da6dba6",
|
|
|
|
|
|
"metadata": {
|
|
|
|
|
|
"ExecuteTime": {
|
2025-02-15 23:33:34 +08:00
|
|
|
|
"end_time": "2025-02-14T17:34:23.956070Z",
|
|
|
|
|
|
"start_time": "2025-02-14T17:34:23.878555Z"
|
2025-02-12 00:21:33 +08:00
|
|
|
|
}
|
|
|
|
|
|
},
|
|
|
|
|
|
"source": [
|
2025-02-15 23:33:34 +08:00
|
|
|
|
"origin_columns = df.columns.tolist()\n",
|
|
|
|
|
|
"origin_columns = [col for col in origin_columns if col not in ['turnover_rate', 'pe_ttm', 'volume_ratio']]"
|
|
|
|
|
|
],
|
|
|
|
|
|
"outputs": [],
|
|
|
|
|
|
"execution_count": 34
|
2025-02-12 00:21:33 +08:00
|
|
|
|
},
|
|
|
|
|
|
{
|
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
|
"id": "a735bc02ceb4d872",
|
|
|
|
|
|
"metadata": {
|
|
|
|
|
|
"jupyter": {
|
|
|
|
|
|
"source_hidden": true
|
2025-02-15 23:33:34 +08:00
|
|
|
|
},
|
|
|
|
|
|
"ExecuteTime": {
|
|
|
|
|
|
"end_time": "2025-02-14T17:34:24.082032Z",
|
|
|
|
|
|
"start_time": "2025-02-14T17:34:23.990152Z"
|
2025-02-12 00:21:33 +08:00
|
|
|
|
}
|
|
|
|
|
|
},
|
|
|
|
|
|
"source": [
|
|
|
|
|
|
"import numpy as np\n",
|
|
|
|
|
|
"import talib\n",
|
|
|
|
|
|
"\n",
|
|
|
|
|
|
"def get_technical_factor(df):\n",
|
|
|
|
|
|
" # 按股票和日期排序\n",
|
|
|
|
|
|
" df = df.sort_values(by=['ts_code', 'trade_date'])\n",
|
|
|
|
|
|
" grouped = df.groupby('ts_code', group_keys=False)\n",
|
|
|
|
|
|
"\n",
|
|
|
|
|
|
" # 计算 up 和 down\n",
|
|
|
|
|
|
" df['up'] = (df['high'] - df[['close', 'open']].max(axis=1)) / df['close']\n",
|
|
|
|
|
|
" df['down'] = (df[['close', 'open']].min(axis=1) - df['low']) / df['close']\n",
|
|
|
|
|
|
"\n",
|
|
|
|
|
|
" # 计算 ATR\n",
|
|
|
|
|
|
" df['atr_14'] = grouped.apply(\n",
|
|
|
|
|
|
" lambda x: pd.Series(talib.ATR(x['high'].values, x['low'].values, x['close'].values, timeperiod=14), index=x.index)\n",
|
|
|
|
|
|
" )\n",
|
|
|
|
|
|
" df['atr_6'] = grouped.apply(\n",
|
|
|
|
|
|
" lambda x: pd.Series(talib.ATR(x['high'].values, x['low'].values, x['close'].values, timeperiod=6), index=x.index)\n",
|
|
|
|
|
|
" )\n",
|
|
|
|
|
|
"\n",
|
|
|
|
|
|
" # 计算 OBV 及其均线\n",
|
|
|
|
|
|
" df['obv'] = grouped.apply(\n",
|
|
|
|
|
|
" lambda x: pd.Series(talib.OBV(x['close'].values, x['vol'].values), index=x.index)\n",
|
|
|
|
|
|
" )\n",
|
|
|
|
|
|
" df['maobv_6'] = grouped.apply(\n",
|
|
|
|
|
|
" lambda x: pd.Series(talib.SMA(x['obv'].values, timeperiod=6), index=x.index)\n",
|
|
|
|
|
|
" )\n",
|
|
|
|
|
|
" df['obv-maobv_6'] = df['obv'] - df['maobv_6']\n",
|
|
|
|
|
|
"\n",
|
|
|
|
|
|
" # 计算 RSI\n",
|
|
|
|
|
|
" df['rsi_3'] = grouped.apply(\n",
|
|
|
|
|
|
" lambda x: pd.Series(talib.RSI(x['close'].values, timeperiod=3), index=x.index)\n",
|
|
|
|
|
|
" )\n",
|
|
|
|
|
|
" df['rsi_6'] = grouped.apply(\n",
|
|
|
|
|
|
" lambda x: pd.Series(talib.RSI(x['close'].values, timeperiod=6), index=x.index)\n",
|
|
|
|
|
|
" )\n",
|
|
|
|
|
|
" df['rsi_9'] = grouped.apply(\n",
|
|
|
|
|
|
" lambda x: pd.Series(talib.RSI(x['close'].values, timeperiod=9), index=x.index)\n",
|
|
|
|
|
|
" )\n",
|
|
|
|
|
|
"\n",
|
|
|
|
|
|
" # 计算 return_10 和 return_20\n",
|
|
|
|
|
|
" df['return_10'] = grouped['close'].apply(lambda x: x / x.shift(10) - 1)\n",
|
|
|
|
|
|
" df['return_20'] = grouped['close'].apply(lambda x: x / x.shift(20) - 1)\n",
|
|
|
|
|
|
"\n",
|
|
|
|
|
|
" # 计算 avg_close_5\n",
|
|
|
|
|
|
" df['avg_close_5'] = grouped['close'].apply(lambda x: x.rolling(window=5).mean() / x)\n",
|
|
|
|
|
|
"\n",
|
|
|
|
|
|
" # 计算标准差指标\n",
|
2025-02-15 23:33:34 +08:00
|
|
|
|
" df['std_return_5'] = grouped['close'].apply(lambda x: x.pct_change().rolling(window=5).std())\n",
|
|
|
|
|
|
" df['std_return_15'] = grouped['close'].apply(lambda x: x.pct_change().rolling(window=15).std())\n",
|
|
|
|
|
|
" df['std_return_25'] = grouped['close'].apply(lambda x: x.pct_change().rolling(window=25).std())\n",
|
|
|
|
|
|
" df['std_return_90'] = grouped['close'].apply(lambda x: x.pct_change().rolling(window=90).std())\n",
|
|
|
|
|
|
" df['std_return_90_2'] = grouped['close'].apply(lambda x: x.shift(10).pct_change().rolling(window=90).std())\n",
|
2025-02-12 00:21:33 +08:00
|
|
|
|
"\n",
|
|
|
|
|
|
" # 计算比值指标\n",
|
|
|
|
|
|
" df['std_return_5 / std_return_90'] = df['std_return_5'] / df['std_return_90']\n",
|
|
|
|
|
|
" df['std_return_5 / std_return_25'] = df['std_return_5'] / df['std_return_25']\n",
|
|
|
|
|
|
"\n",
|
|
|
|
|
|
" # 计算标准差差值\n",
|
|
|
|
|
|
" df['std_return_90 - std_return_90_2'] = df['std_return_90'] - df['std_return_90_2']\n",
|
|
|
|
|
|
"\n",
|
|
|
|
|
|
" return df\n",
|
|
|
|
|
|
"\n",
|
|
|
|
|
|
"\n",
|
|
|
|
|
|
"def get_act_factor(df):\n",
|
|
|
|
|
|
" # 按股票和日期排序\n",
|
|
|
|
|
|
" df = df.sort_values(by=['ts_code', 'trade_date'])\n",
|
|
|
|
|
|
" grouped = df.groupby('ts_code', group_keys=False)\n",
|
|
|
|
|
|
" # 计算 EMA 指标\n",
|
|
|
|
|
|
" df['ema_5'] = grouped['close'].apply(\n",
|
|
|
|
|
|
" lambda x: pd.Series(talib.EMA(x.values, timeperiod=5), index=x.index)\n",
|
|
|
|
|
|
" )\n",
|
|
|
|
|
|
" df['ema_13'] = grouped['close'].apply(\n",
|
|
|
|
|
|
" lambda x: pd.Series(talib.EMA(x.values, timeperiod=13), index=x.index)\n",
|
|
|
|
|
|
" )\n",
|
|
|
|
|
|
" df['ema_20'] = grouped['close'].apply(\n",
|
|
|
|
|
|
" lambda x: pd.Series(talib.EMA(x.values, timeperiod=20), index=x.index)\n",
|
|
|
|
|
|
" )\n",
|
|
|
|
|
|
" df['ema_60'] = grouped['close'].apply(\n",
|
|
|
|
|
|
" lambda x: pd.Series(talib.EMA(x.values, timeperiod=60), index=x.index)\n",
|
|
|
|
|
|
" )\n",
|
|
|
|
|
|
"\n",
|
|
|
|
|
|
" # 计算 act_factor1, act_factor2, act_factor3, act_factor4\n",
|
|
|
|
|
|
" df['act_factor1'] = grouped['ema_5'].apply(\n",
|
|
|
|
|
|
" lambda x: np.arctan((x / x.shift(1) - 1) * 100) * 57.3 / 50\n",
|
|
|
|
|
|
" )\n",
|
|
|
|
|
|
" df['act_factor2'] = grouped['ema_13'].apply(\n",
|
|
|
|
|
|
" lambda x: np.arctan((x / x.shift(1) - 1) * 100) * 57.3 / 40\n",
|
|
|
|
|
|
" )\n",
|
|
|
|
|
|
" df['act_factor3'] = grouped['ema_20'].apply(\n",
|
|
|
|
|
|
" lambda x: np.arctan((x / x.shift(1) - 1) * 100) * 57.3 / 21\n",
|
|
|
|
|
|
" )\n",
|
|
|
|
|
|
" df['act_factor4'] = grouped['ema_60'].apply(\n",
|
|
|
|
|
|
" lambda x: np.arctan((x / x.shift(1) - 1) * 100) * 57.3 / 10\n",
|
|
|
|
|
|
" )\n",
|
|
|
|
|
|
"\n",
|
|
|
|
|
|
" # 计算 act_factor5 和 act_factor6\n",
|
|
|
|
|
|
" df['act_factor5'] = df['act_factor1'] + df['act_factor2'] + df['act_factor3'] + df['act_factor4']\n",
|
|
|
|
|
|
" df['act_factor6'] = (df['act_factor1'] - df['act_factor2']) / np.sqrt(df['act_factor1']**2 + df['act_factor2']**2)\n",
|
|
|
|
|
|
"\n",
|
|
|
|
|
|
" # 根据 trade_date 截面计算排名\n",
|
|
|
|
|
|
" df['rank_act_factor1'] = df.groupby('trade_date', group_keys=False)['act_factor1'].rank(ascending=False, pct=True)\n",
|
|
|
|
|
|
" df['rank_act_factor2'] = df.groupby('trade_date', group_keys=False)['act_factor2'].rank(ascending=False, pct=True)\n",
|
|
|
|
|
|
" df['rank_act_factor3'] = df.groupby('trade_date', group_keys=False)['act_factor3'].rank(ascending=False, pct=True)\n",
|
|
|
|
|
|
"\n",
|
|
|
|
|
|
" return df\n",
|
|
|
|
|
|
"\n",
|
|
|
|
|
|
"\n",
|
|
|
|
|
|
"def get_money_flow_factor(df):\n",
|
|
|
|
|
|
" # 计算资金流相关因子(字段名称见 tushare 数据说明)\n",
|
|
|
|
|
|
" df['active_buy_volume_large'] = df['buy_lg_vol'] / df['net_mf_vol']\n",
|
|
|
|
|
|
" df['active_buy_volume_big'] = df['buy_elg_vol'] / df['net_mf_vol']\n",
|
|
|
|
|
|
" df['active_buy_volume_small'] = df['buy_sm_vol'] / df['net_mf_vol']\n",
|
|
|
|
|
|
"\n",
|
|
|
|
|
|
" df['buy_lg_vol_minus_sell_lg_vol'] = (df['buy_lg_vol'] - df['sell_lg_vol']) / df['net_mf_vol']\n",
|
|
|
|
|
|
" df['buy_elg_vol_minus_sell_elg_vol'] = (df['buy_elg_vol'] - df['sell_elg_vol']) / df['net_mf_vol']\n",
|
|
|
|
|
|
"\n",
|
2025-02-15 23:33:34 +08:00
|
|
|
|
" df['log(circ_mv)'] = np.log(df['circ_mv'])\n",
|
2025-02-12 00:21:33 +08:00
|
|
|
|
" return df\n",
|
|
|
|
|
|
"\n",
|
|
|
|
|
|
"\n",
|
|
|
|
|
|
"def get_alpha_factor(df):\n",
|
|
|
|
|
|
" df = df.sort_values(by=['ts_code', 'trade_date'])\n",
|
|
|
|
|
|
" grouped = df.groupby('ts_code')\n",
|
|
|
|
|
|
"\n",
|
|
|
|
|
|
" # alpha_022: 当前 close 与 5 日前 close 差值\n",
|
2025-02-15 23:33:34 +08:00
|
|
|
|
" df['alpha_022'] = grouped['close'].transform(lambda x: x - x.shift(5))\n",
|
2025-02-12 00:21:33 +08:00
|
|
|
|
"\n",
|
|
|
|
|
|
" # alpha_003: (close - open) / (high - low)\n",
|
|
|
|
|
|
" df['alpha_003'] = np.where(df['high'] != df['low'],\n",
|
|
|
|
|
|
" (df['close'] - df['open']) / (df['high'] - df['low']),\n",
|
|
|
|
|
|
" 0)\n",
|
|
|
|
|
|
"\n",
|
2025-02-15 23:33:34 +08:00
|
|
|
|
" # alpha_007: 计算过去5日 close 与 vol 的相关性,并按 trade_date 排名\n",
|
|
|
|
|
|
" df['alpha_007'] = grouped.apply(lambda x: x['close'].rolling(5).corr(x['vol'])).reset_index(level=0, drop=True)\n",
|
2025-02-12 00:21:33 +08:00
|
|
|
|
" df['alpha_007'] = df.groupby('trade_date', group_keys=False)['alpha_007'].rank(ascending=True, pct=True)\n",
|
|
|
|
|
|
"\n",
|
2025-02-15 23:33:34 +08:00
|
|
|
|
" # alpha_013: 计算过去5日 close 之和 - 20日 close 之和,并按 trade_date 排名\n",
|
|
|
|
|
|
" df['alpha_013'] = grouped['close'].transform(lambda x: x.rolling(5).sum() - x.rolling(20).sum())\n",
|
2025-02-12 00:21:33 +08:00
|
|
|
|
" df['alpha_013'] = df.groupby('trade_date', group_keys=False)['alpha_013'].rank(ascending=True, pct=True)\n",
|
|
|
|
|
|
"\n",
|
|
|
|
|
|
" return df\n",
|
|
|
|
|
|
"\n",
|
|
|
|
|
|
"\n",
|
|
|
|
|
|
"def get_future_data(df):\n",
|
|
|
|
|
|
" df = df.sort_values(by=['ts_code', 'trade_date'])\n",
|
|
|
|
|
|
" # 预先对 ts_code 分组,使用 transform 保持原 DataFrame 形状\n",
|
|
|
|
|
|
" grouped = df.groupby('ts_code')\n",
|
|
|
|
|
|
"\n",
|
|
|
|
|
|
" df['future_return1'] = (grouped['close'].transform(lambda x: x.shift(-1)) - df['close']) / df['close']\n",
|
|
|
|
|
|
" df['future_return2'] = (grouped['open'].transform(lambda x: x.shift(-2)) - grouped['open'].transform(lambda x: x.shift(-1))) / grouped['open'].transform(lambda x: x.shift(-1))\n",
|
|
|
|
|
|
" df['future_return3'] = (grouped['close'].transform(lambda x: x.shift(-2)) - grouped['close'].transform(lambda x: x.shift(-1))) / grouped['close'].transform(lambda x: x.shift(-1))\n",
|
|
|
|
|
|
" df['future_return4'] = (grouped['close'].transform(lambda x: x.shift(-2)) - grouped['open'].transform(lambda x: x.shift(-1))) / grouped['open'].transform(lambda x: x.shift(-1))\n",
|
|
|
|
|
|
" df['future_return5'] = (grouped['close'].transform(lambda x: x.shift(-5)) - grouped['open'].transform(lambda x: x.shift(-1))) / grouped['open'].transform(lambda x: x.shift(-1))\n",
|
|
|
|
|
|
" df['future_return6'] = (grouped['close'].transform(lambda x: x.shift(-10)) - grouped['open'].transform(lambda x: x.shift(-1))) / grouped['open'].transform(lambda x: x.shift(-1))\n",
|
|
|
|
|
|
" df['future_return7'] = (grouped['close'].transform(lambda x: x.shift(-20)) - grouped['open'].transform(lambda x: x.shift(-1))) / grouped['open'].transform(lambda x: x.shift(-1))\n",
|
|
|
|
|
|
"\n",
|
|
|
|
|
|
" df['future_close1'] = (grouped['close'].transform(lambda x: x.shift(-1)) - df['close']) / df['close']\n",
|
|
|
|
|
|
" df['future_close2'] = (grouped['close'].transform(lambda x: x.shift(-2)) - df['close']) / df['close']\n",
|
|
|
|
|
|
" df['future_close3'] = (grouped['close'].transform(lambda x: x.shift(-3)) - df['close']) / df['close']\n",
|
|
|
|
|
|
" df['future_close4'] = (grouped['close'].transform(lambda x: x.shift(-4)) - df['close']) / df['close']\n",
|
|
|
|
|
|
" df['future_close5'] = (grouped['close'].transform(lambda x: x.shift(-5)) - df['close']) / df['close']\n",
|
|
|
|
|
|
"\n",
|
|
|
|
|
|
" df['future_af11'] = grouped['act_factor1'].transform(lambda x: x.shift(-1))\n",
|
|
|
|
|
|
" df['future_af12'] = grouped['act_factor1'].transform(lambda x: x.shift(-2))\n",
|
|
|
|
|
|
" df['future_af13'] = grouped['act_factor1'].transform(lambda x: x.shift(-3))\n",
|
|
|
|
|
|
" df['future_af14'] = grouped['act_factor1'].transform(lambda x: x.shift(-4))\n",
|
|
|
|
|
|
" df['future_af15'] = grouped['act_factor1'].transform(lambda x: x.shift(-5))\n",
|
|
|
|
|
|
"\n",
|
|
|
|
|
|
" df['future_af21'] = grouped['act_factor2'].transform(lambda x: x.shift(-1))\n",
|
|
|
|
|
|
" df['future_af22'] = grouped['act_factor2'].transform(lambda x: x.shift(-2))\n",
|
|
|
|
|
|
" df['future_af23'] = grouped['act_factor2'].transform(lambda x: x.shift(-3))\n",
|
|
|
|
|
|
" df['future_af24'] = grouped['act_factor2'].transform(lambda x: x.shift(-4))\n",
|
|
|
|
|
|
" df['future_af25'] = grouped['act_factor2'].transform(lambda x: x.shift(-5))\n",
|
|
|
|
|
|
"\n",
|
|
|
|
|
|
" df['future_af31'] = grouped['act_factor3'].transform(lambda x: x.shift(-1))\n",
|
|
|
|
|
|
" df['future_af32'] = grouped['act_factor3'].transform(lambda x: x.shift(-2))\n",
|
|
|
|
|
|
" df['future_af33'] = grouped['act_factor3'].transform(lambda x: x.shift(-3))\n",
|
|
|
|
|
|
" df['future_af34'] = grouped['act_factor3'].transform(lambda x: x.shift(-4))\n",
|
|
|
|
|
|
" df['future_af35'] = grouped['act_factor3'].transform(lambda x: x.shift(-5))\n",
|
|
|
|
|
|
"\n",
|
|
|
|
|
|
" return df\n"
|
2025-02-15 23:33:34 +08:00
|
|
|
|
],
|
|
|
|
|
|
"outputs": [],
|
|
|
|
|
|
"execution_count": 35
|
2025-02-12 00:21:33 +08:00
|
|
|
|
},
|
|
|
|
|
|
{
|
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
|
"id": "53f86ddc0677a6d7",
|
|
|
|
|
|
"metadata": {
|
2025-02-15 23:33:34 +08:00
|
|
|
|
"scrolled": true,
|
2025-02-12 00:21:33 +08:00
|
|
|
|
"ExecuteTime": {
|
2025-02-15 23:33:34 +08:00
|
|
|
|
"end_time": "2025-02-14T17:36:10.108321Z",
|
|
|
|
|
|
"start_time": "2025-02-14T17:34:24.118116Z"
|
|
|
|
|
|
}
|
2025-02-12 00:21:33 +08:00
|
|
|
|
},
|
2025-02-15 23:33:34 +08:00
|
|
|
|
"source": [
|
|
|
|
|
|
"df = get_technical_factor(df)\n",
|
|
|
|
|
|
"df = get_act_factor(df)\n",
|
|
|
|
|
|
"df = get_money_flow_factor(df)\n",
|
|
|
|
|
|
"df = get_alpha_factor(df)\n",
|
|
|
|
|
|
"# df = get_future_data(df)\n",
|
|
|
|
|
|
"# df = df.drop(columns=origin_columns)\n",
|
|
|
|
|
|
"\n",
|
|
|
|
|
|
"print(df.info())"
|
|
|
|
|
|
],
|
2025-02-12 00:21:33 +08:00
|
|
|
|
"outputs": [
|
|
|
|
|
|
{
|
|
|
|
|
|
"name": "stdout",
|
|
|
|
|
|
"output_type": "stream",
|
|
|
|
|
|
"text": [
|
|
|
|
|
|
"<class 'pandas.core.frame.DataFrame'>\n",
|
2025-02-15 23:33:34 +08:00
|
|
|
|
"Index: 8296325 entries, 1962 to 8296323\n",
|
|
|
|
|
|
"Data columns (total 65 columns):\n",
|
2025-02-12 00:21:33 +08:00
|
|
|
|
" # Column Dtype \n",
|
|
|
|
|
|
"--- ------ ----- \n",
|
|
|
|
|
|
" 0 ts_code object \n",
|
|
|
|
|
|
" 1 trade_date datetime64[ns]\n",
|
|
|
|
|
|
" 2 open float64 \n",
|
|
|
|
|
|
" 3 close float64 \n",
|
|
|
|
|
|
" 4 high float64 \n",
|
|
|
|
|
|
" 5 low float64 \n",
|
|
|
|
|
|
" 6 vol float64 \n",
|
|
|
|
|
|
" 7 turnover_rate float64 \n",
|
|
|
|
|
|
" 8 pe_ttm float64 \n",
|
|
|
|
|
|
" 9 circ_mv float64 \n",
|
|
|
|
|
|
" 10 volume_ratio float64 \n",
|
2025-02-15 23:33:34 +08:00
|
|
|
|
" 11 is_st bool \n",
|
2025-02-12 00:21:33 +08:00
|
|
|
|
" 12 up_limit float64 \n",
|
|
|
|
|
|
" 13 down_limit float64 \n",
|
|
|
|
|
|
" 14 buy_sm_vol float64 \n",
|
|
|
|
|
|
" 15 sell_sm_vol float64 \n",
|
|
|
|
|
|
" 16 buy_lg_vol float64 \n",
|
|
|
|
|
|
" 17 sell_lg_vol float64 \n",
|
|
|
|
|
|
" 18 buy_elg_vol float64 \n",
|
|
|
|
|
|
" 19 sell_elg_vol float64 \n",
|
|
|
|
|
|
" 20 net_mf_vol float64 \n",
|
|
|
|
|
|
" 21 up float64 \n",
|
|
|
|
|
|
" 22 down float64 \n",
|
|
|
|
|
|
" 23 atr_14 float64 \n",
|
|
|
|
|
|
" 24 atr_6 float64 \n",
|
|
|
|
|
|
" 25 obv float64 \n",
|
|
|
|
|
|
" 26 maobv_6 float64 \n",
|
|
|
|
|
|
" 27 obv-maobv_6 float64 \n",
|
|
|
|
|
|
" 28 rsi_3 float64 \n",
|
|
|
|
|
|
" 29 rsi_6 float64 \n",
|
|
|
|
|
|
" 30 rsi_9 float64 \n",
|
|
|
|
|
|
" 31 return_10 float64 \n",
|
|
|
|
|
|
" 32 return_20 float64 \n",
|
|
|
|
|
|
" 33 avg_close_5 float64 \n",
|
|
|
|
|
|
" 34 std_return_5 float64 \n",
|
|
|
|
|
|
" 35 std_return_15 float64 \n",
|
|
|
|
|
|
" 36 std_return_25 float64 \n",
|
|
|
|
|
|
" 37 std_return_90 float64 \n",
|
|
|
|
|
|
" 38 std_return_90_2 float64 \n",
|
|
|
|
|
|
" 39 std_return_5 / std_return_90 float64 \n",
|
|
|
|
|
|
" 40 std_return_5 / std_return_25 float64 \n",
|
|
|
|
|
|
" 41 std_return_90 - std_return_90_2 float64 \n",
|
|
|
|
|
|
" 42 ema_5 float64 \n",
|
|
|
|
|
|
" 43 ema_13 float64 \n",
|
|
|
|
|
|
" 44 ema_20 float64 \n",
|
|
|
|
|
|
" 45 ema_60 float64 \n",
|
|
|
|
|
|
" 46 act_factor1 float64 \n",
|
|
|
|
|
|
" 47 act_factor2 float64 \n",
|
|
|
|
|
|
" 48 act_factor3 float64 \n",
|
|
|
|
|
|
" 49 act_factor4 float64 \n",
|
|
|
|
|
|
" 50 act_factor5 float64 \n",
|
|
|
|
|
|
" 51 act_factor6 float64 \n",
|
|
|
|
|
|
" 52 rank_act_factor1 float64 \n",
|
|
|
|
|
|
" 53 rank_act_factor2 float64 \n",
|
|
|
|
|
|
" 54 rank_act_factor3 float64 \n",
|
|
|
|
|
|
" 55 active_buy_volume_large float64 \n",
|
|
|
|
|
|
" 56 active_buy_volume_big float64 \n",
|
|
|
|
|
|
" 57 active_buy_volume_small float64 \n",
|
|
|
|
|
|
" 58 buy_lg_vol_minus_sell_lg_vol float64 \n",
|
|
|
|
|
|
" 59 buy_elg_vol_minus_sell_elg_vol float64 \n",
|
2025-02-15 23:33:34 +08:00
|
|
|
|
" 60 log(circ_mv) float64 \n",
|
|
|
|
|
|
" 61 alpha_022 float64 \n",
|
|
|
|
|
|
" 62 alpha_003 float64 \n",
|
|
|
|
|
|
" 63 alpha_007 float64 \n",
|
|
|
|
|
|
" 64 alpha_013 float64 \n",
|
|
|
|
|
|
"dtypes: bool(1), datetime64[ns](1), float64(62), object(1)\n",
|
|
|
|
|
|
"memory usage: 4.0+ GB\n",
|
2025-02-12 00:21:33 +08:00
|
|
|
|
"None\n"
|
|
|
|
|
|
]
|
|
|
|
|
|
}
|
|
|
|
|
|
],
|
2025-02-15 23:33:34 +08:00
|
|
|
|
"execution_count": 36
|
2025-02-12 00:21:33 +08:00
|
|
|
|
},
|
|
|
|
|
|
{
|
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
|
"id": "dbe2fd8021b9417f",
|
|
|
|
|
|
"metadata": {
|
|
|
|
|
|
"jupyter": {
|
|
|
|
|
|
"source_hidden": true
|
|
|
|
|
|
},
|
2025-02-15 23:33:34 +08:00
|
|
|
|
"scrolled": true,
|
|
|
|
|
|
"ExecuteTime": {
|
|
|
|
|
|
"end_time": "2025-02-14T17:36:52.088380Z",
|
|
|
|
|
|
"start_time": "2025-02-14T17:36:10.512236Z"
|
2025-02-12 00:21:33 +08:00
|
|
|
|
}
|
2025-02-15 23:33:34 +08:00
|
|
|
|
},
|
2025-02-12 00:21:33 +08:00
|
|
|
|
"source": [
|
|
|
|
|
|
"def filter_data(df):\n",
|
2025-02-15 23:33:34 +08:00
|
|
|
|
" # df = df.groupby('trade_date').apply(lambda x: x.nlargest(1000, 'act_factor1'))\n",
|
|
|
|
|
|
" df = df[~df['is_st']]\n",
|
2025-02-12 00:21:33 +08:00
|
|
|
|
" df = df[~df['ts_code'].str.endswith('BJ')]\n",
|
|
|
|
|
|
" df = df[~df['ts_code'].str.startswith('30')]\n",
|
|
|
|
|
|
" df = df[~df['ts_code'].str.startswith('68')]\n",
|
|
|
|
|
|
" df = df[~df['ts_code'].str.startswith('8')]\n",
|
|
|
|
|
|
" df = df.reset_index(drop=True)\n",
|
|
|
|
|
|
" return df\n",
|
|
|
|
|
|
"\n",
|
|
|
|
|
|
"\n",
|
|
|
|
|
|
"df = filter_data(df)\n",
|
|
|
|
|
|
"print(df.info())"
|
2025-02-15 23:33:34 +08:00
|
|
|
|
],
|
2025-02-12 00:21:33 +08:00
|
|
|
|
"outputs": [
|
|
|
|
|
|
{
|
|
|
|
|
|
"name": "stdout",
|
|
|
|
|
|
"output_type": "stream",
|
|
|
|
|
|
"text": [
|
2025-02-15 23:33:34 +08:00
|
|
|
|
"<class 'pandas.core.frame.DataFrame'>\n",
|
|
|
|
|
|
"RangeIndex: 5453316 entries, 0 to 5453315\n",
|
|
|
|
|
|
"Data columns (total 65 columns):\n",
|
|
|
|
|
|
" # Column Dtype \n",
|
|
|
|
|
|
"--- ------ ----- \n",
|
|
|
|
|
|
" 0 ts_code object \n",
|
|
|
|
|
|
" 1 trade_date datetime64[ns]\n",
|
|
|
|
|
|
" 2 open float64 \n",
|
|
|
|
|
|
" 3 close float64 \n",
|
|
|
|
|
|
" 4 high float64 \n",
|
|
|
|
|
|
" 5 low float64 \n",
|
|
|
|
|
|
" 6 vol float64 \n",
|
|
|
|
|
|
" 7 turnover_rate float64 \n",
|
|
|
|
|
|
" 8 pe_ttm float64 \n",
|
|
|
|
|
|
" 9 circ_mv float64 \n",
|
|
|
|
|
|
" 10 volume_ratio float64 \n",
|
|
|
|
|
|
" 11 is_st bool \n",
|
|
|
|
|
|
" 12 up_limit float64 \n",
|
|
|
|
|
|
" 13 down_limit float64 \n",
|
|
|
|
|
|
" 14 buy_sm_vol float64 \n",
|
|
|
|
|
|
" 15 sell_sm_vol float64 \n",
|
|
|
|
|
|
" 16 buy_lg_vol float64 \n",
|
|
|
|
|
|
" 17 sell_lg_vol float64 \n",
|
|
|
|
|
|
" 18 buy_elg_vol float64 \n",
|
|
|
|
|
|
" 19 sell_elg_vol float64 \n",
|
|
|
|
|
|
" 20 net_mf_vol float64 \n",
|
|
|
|
|
|
" 21 up float64 \n",
|
|
|
|
|
|
" 22 down float64 \n",
|
|
|
|
|
|
" 23 atr_14 float64 \n",
|
|
|
|
|
|
" 24 atr_6 float64 \n",
|
|
|
|
|
|
" 25 obv float64 \n",
|
|
|
|
|
|
" 26 maobv_6 float64 \n",
|
|
|
|
|
|
" 27 obv-maobv_6 float64 \n",
|
|
|
|
|
|
" 28 rsi_3 float64 \n",
|
|
|
|
|
|
" 29 rsi_6 float64 \n",
|
|
|
|
|
|
" 30 rsi_9 float64 \n",
|
|
|
|
|
|
" 31 return_10 float64 \n",
|
|
|
|
|
|
" 32 return_20 float64 \n",
|
|
|
|
|
|
" 33 avg_close_5 float64 \n",
|
|
|
|
|
|
" 34 std_return_5 float64 \n",
|
|
|
|
|
|
" 35 std_return_15 float64 \n",
|
|
|
|
|
|
" 36 std_return_25 float64 \n",
|
|
|
|
|
|
" 37 std_return_90 float64 \n",
|
|
|
|
|
|
" 38 std_return_90_2 float64 \n",
|
|
|
|
|
|
" 39 std_return_5 / std_return_90 float64 \n",
|
|
|
|
|
|
" 40 std_return_5 / std_return_25 float64 \n",
|
|
|
|
|
|
" 41 std_return_90 - std_return_90_2 float64 \n",
|
|
|
|
|
|
" 42 ema_5 float64 \n",
|
|
|
|
|
|
" 43 ema_13 float64 \n",
|
|
|
|
|
|
" 44 ema_20 float64 \n",
|
|
|
|
|
|
" 45 ema_60 float64 \n",
|
|
|
|
|
|
" 46 act_factor1 float64 \n",
|
|
|
|
|
|
" 47 act_factor2 float64 \n",
|
|
|
|
|
|
" 48 act_factor3 float64 \n",
|
|
|
|
|
|
" 49 act_factor4 float64 \n",
|
|
|
|
|
|
" 50 act_factor5 float64 \n",
|
|
|
|
|
|
" 51 act_factor6 float64 \n",
|
|
|
|
|
|
" 52 rank_act_factor1 float64 \n",
|
|
|
|
|
|
" 53 rank_act_factor2 float64 \n",
|
|
|
|
|
|
" 54 rank_act_factor3 float64 \n",
|
|
|
|
|
|
" 55 active_buy_volume_large float64 \n",
|
|
|
|
|
|
" 56 active_buy_volume_big float64 \n",
|
|
|
|
|
|
" 57 active_buy_volume_small float64 \n",
|
|
|
|
|
|
" 58 buy_lg_vol_minus_sell_lg_vol float64 \n",
|
|
|
|
|
|
" 59 buy_elg_vol_minus_sell_elg_vol float64 \n",
|
|
|
|
|
|
" 60 log(circ_mv) float64 \n",
|
|
|
|
|
|
" 61 alpha_022 float64 \n",
|
|
|
|
|
|
" 62 alpha_003 float64 \n",
|
|
|
|
|
|
" 63 alpha_007 float64 \n",
|
|
|
|
|
|
" 64 alpha_013 float64 \n",
|
|
|
|
|
|
"dtypes: bool(1), datetime64[ns](1), float64(62), object(1)\n",
|
|
|
|
|
|
"memory usage: 2.6+ GB\n",
|
|
|
|
|
|
"None\n"
|
2025-02-12 00:21:33 +08:00
|
|
|
|
]
|
|
|
|
|
|
}
|
|
|
|
|
|
],
|
2025-02-15 23:33:34 +08:00
|
|
|
|
"execution_count": 37
|
|
|
|
|
|
},
|
|
|
|
|
|
{
|
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
|
"id": "5f3d9aece75318cd",
|
|
|
|
|
|
"metadata": {
|
|
|
|
|
|
"ExecuteTime": {
|
|
|
|
|
|
"end_time": "2025-02-14T17:37:16.249760Z",
|
|
|
|
|
|
"start_time": "2025-02-14T17:36:52.387540Z"
|
|
|
|
|
|
}
|
|
|
|
|
|
},
|
2025-02-12 00:21:33 +08:00
|
|
|
|
"source": [
|
2025-02-15 23:33:34 +08:00
|
|
|
|
"train_data = df[df['trade_date'] <= '2023-01-01']\n",
|
|
|
|
|
|
"test_data = df[df['trade_date'] >= '2023-01-01']\n",
|
2025-02-12 00:21:33 +08:00
|
|
|
|
"\n",
|
2025-02-15 23:33:34 +08:00
|
|
|
|
"train_data = train_data.groupby('trade_date', group_keys=False).apply(lambda x: x.nlargest(1000, 'return_20'))\n",
|
|
|
|
|
|
"test_data = test_data.groupby('trade_date', group_keys=False).apply(lambda x: x.nlargest(1000, 'return_20'))\n",
|
2025-02-12 00:21:33 +08:00
|
|
|
|
"\n",
|
2025-02-15 23:33:34 +08:00
|
|
|
|
"train_data = get_future_data(train_data)\n",
|
2025-02-12 00:21:33 +08:00
|
|
|
|
"\n",
|
|
|
|
|
|
"feature_columns = [col for col in df.columns if col not in ['trade_date',\n",
|
|
|
|
|
|
" 'ts_code',\n",
|
|
|
|
|
|
" 'label']]\n",
|
|
|
|
|
|
"feature_columns = [col for col in feature_columns if 'future' not in col]\n",
|
|
|
|
|
|
"feature_columns = [col for col in feature_columns if 'score' not in col]\n",
|
|
|
|
|
|
"feature_columns = [col for col in feature_columns if col not in origin_columns]\n",
|
|
|
|
|
|
"\n",
|
|
|
|
|
|
"# for column in [column for column in train_data.columns if 'future' in column]:\n",
|
|
|
|
|
|
"# label_index = neutralize_labels(train_data[column], train_data, feature_columns, z_threshold=3, method='regression')\n",
|
|
|
|
|
|
"# train_data = train_data[label_index]\n",
|
|
|
|
|
|
"# label_index = neutralize_labels(test_data[column], test_data, feature_columns, z_threshold=3, method='regression')\n",
|
|
|
|
|
|
"# test_data = test_data[label_index]\n",
|
|
|
|
|
|
"\n",
|
2025-02-15 23:33:34 +08:00
|
|
|
|
"df = df[['ts_code', 'trade_date', 'open', 'close']]\n",
|
2025-02-12 00:21:33 +08:00
|
|
|
|
"print(feature_columns)\n",
|
|
|
|
|
|
"print(f\"最小日期: {train_data['trade_date'].min().strftime('%Y-%m-%d')}\")\n",
|
|
|
|
|
|
"print(f\"最大日期: {train_data['trade_date'].max().strftime('%Y-%m-%d')}\")\n",
|
|
|
|
|
|
"print(len(test_data))\n",
|
|
|
|
|
|
"print(f\"最小日期: {test_data['trade_date'].min().strftime('%Y-%m-%d')}\")\n",
|
|
|
|
|
|
"print(f\"最大日期: {test_data['trade_date'].max().strftime('%Y-%m-%d')}\")"
|
2025-02-15 23:33:34 +08:00
|
|
|
|
],
|
2025-02-12 00:21:33 +08:00
|
|
|
|
"outputs": [
|
|
|
|
|
|
{
|
|
|
|
|
|
"name": "stdout",
|
|
|
|
|
|
"output_type": "stream",
|
|
|
|
|
|
"text": [
|
2025-02-15 23:33:34 +08:00
|
|
|
|
"['turnover_rate', 'pe_ttm', 'volume_ratio', 'up', 'down', 'atr_14', 'atr_6', 'obv', 'maobv_6', 'obv-maobv_6', 'rsi_3', 'rsi_6', 'rsi_9', 'return_10', 'return_20', 'avg_close_5', 'std_return_5', 'std_return_15', 'std_return_25', 'std_return_90', 'std_return_90_2', 'std_return_5 / std_return_90', 'std_return_5 / std_return_25', 'std_return_90 - std_return_90_2', 'ema_5', 'ema_13', 'ema_20', 'ema_60', 'act_factor1', 'act_factor2', 'act_factor3', 'act_factor4', 'act_factor5', 'act_factor6', 'rank_act_factor1', 'rank_act_factor2', 'rank_act_factor3', 'active_buy_volume_large', 'active_buy_volume_big', 'active_buy_volume_small', 'buy_lg_vol_minus_sell_lg_vol', 'buy_elg_vol_minus_sell_elg_vol', 'log(circ_mv)', 'alpha_022', 'alpha_003', 'alpha_007', 'alpha_013']\n",
|
|
|
|
|
|
"最小日期: 2017-01-03\n",
|
|
|
|
|
|
"最大日期: 2022-12-30\n",
|
|
|
|
|
|
"507000\n",
|
|
|
|
|
|
"最小日期: 2023-01-03\n",
|
|
|
|
|
|
"最大日期: 2025-02-12\n"
|
2025-02-12 00:21:33 +08:00
|
|
|
|
]
|
|
|
|
|
|
}
|
|
|
|
|
|
],
|
2025-02-15 23:33:34 +08:00
|
|
|
|
"execution_count": 38
|
|
|
|
|
|
},
|
|
|
|
|
|
{
|
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
|
"id": "f4f16d63ad18d1bc",
|
|
|
|
|
|
"metadata": {
|
|
|
|
|
|
"ExecuteTime": {
|
|
|
|
|
|
"end_time": "2025-02-14T19:01:09.964760Z",
|
|
|
|
|
|
"start_time": "2025-02-14T19:01:05.892897Z"
|
|
|
|
|
|
}
|
|
|
|
|
|
},
|
2025-02-12 00:21:33 +08:00
|
|
|
|
"source": [
|
|
|
|
|
|
"def get_qcuts(series, quantiles):\n",
|
|
|
|
|
|
" q = pd.qcut(series, q=quantiles, labels=False, duplicates='drop')\n",
|
|
|
|
|
|
"    return q[-1]  # 返回窗口最后一个元素的分位数标签(注意:若 series 为 Series,q[-1] 为标签索引,位置取值建议用 q.iloc[-1] —— 待确认)\n",
|
|
|
|
|
|
"\n",
|
|
|
|
|
|
"\n",
|
|
|
|
|
|
"window = 5\n",
|
|
|
|
|
|
"quantiles = 20\n",
|
|
|
|
|
|
"\n",
|
2025-02-15 23:33:34 +08:00
|
|
|
|
"def calculate_risk_adjusted_target(df, days=5):\n",
|
|
|
|
|
|
" df = df.sort_values(by=['ts_code', 'trade_date'])\n",
|
|
|
|
|
|
"\n",
|
|
|
|
|
|
" df['future_close'] = df.groupby('ts_code')['close'].shift(-days)\n",
|
|
|
|
|
|
" df['past_close'] = df.groupby('ts_code')['close'].shift(days)\n",
|
|
|
|
|
|
" df['future_return'] = (df['future_close'] - df['past_close']) / df['past_close']\n",
|
|
|
|
|
|
"\n",
|
|
|
|
|
|
" df['future_volatility'] = df.groupby('ts_code')['future_return'].rolling(days, min_periods=1).std().reset_index(level=0, drop=True)\n",
|
|
|
|
|
|
"\n",
|
|
|
|
|
|
" df['sharpe_ratio'] = df['future_return'] / df['future_volatility']\n",
|
|
|
|
|
|
" df['sharpe_ratio'].replace([np.inf, -np.inf], np.nan, inplace=True)\n",
|
|
|
|
|
|
"\n",
|
|
|
|
|
|
" return df['sharpe_ratio']\n",
|
|
|
|
|
|
"\n",
|
|
|
|
|
|
"\n",
|
2025-02-12 00:21:33 +08:00
|
|
|
|
"\n",
|
|
|
|
|
|
"def get_label(df):\n",
|
2025-02-15 23:33:34 +08:00
|
|
|
|
" # labels = df['future_af13'] - df['act_factor1']\n",
|
|
|
|
|
|
" # labels = df['future_close5']\n",
|
|
|
|
|
|
" # labels = df['future_af11']\n",
|
|
|
|
|
|
" # labels = df['ema_5'].shift(-1) - df['close']\n",
|
|
|
|
|
|
" # labels = df['future_af15']\n",
|
|
|
|
|
|
" df['label'] = calculate_risk_adjusted_target(df, days=5)\n",
|
|
|
|
|
|
" lower_percentile = df['label'].quantile(0.01) # 1%分位数\n",
|
|
|
|
|
|
" upper_percentile = df['label'].quantile(0.99) # 99%分位数\n",
|
|
|
|
|
|
" labels = df['label'].clip(lower=lower_percentile, upper=upper_percentile)\n",
|
|
|
|
|
|
" # labels = calculate_risk_adjusted_return(df, days=3, history_days=3, method='ratio')\n",
|
2025-02-12 00:21:33 +08:00
|
|
|
|
" return labels\n",
|
|
|
|
|
|
"\n",
|
2025-02-15 23:33:34 +08:00
|
|
|
|
"train_data, test_data = train_data.replace([np.inf, -np.inf], np.nan), test_data.replace([np.inf, -np.inf], np.nan)\n",
|
|
|
|
|
|
"# train_data = train_data.dropna(subset=feature_columns)\n",
|
|
|
|
|
|
"train_data = train_data.dropna(subset=feature_columns)\n",
|
|
|
|
|
|
"\n",
|
2025-02-12 00:21:33 +08:00
|
|
|
|
"train_data['label'] = get_label(train_data)\n",
|
2025-02-15 23:33:34 +08:00
|
|
|
|
"# test_data['label'] = get_label(test_data)\n",
|
2025-02-12 00:21:33 +08:00
|
|
|
|
"\n",
|
2025-02-15 23:33:34 +08:00
|
|
|
|
"# train_data = train_data.dropna(subset=['label'])\n",
|
2025-02-12 00:21:33 +08:00
|
|
|
|
"# test_data = test_data.dropna(subset=['label'])\n",
|
2025-02-15 23:33:34 +08:00
|
|
|
|
"# train_data = train_data.replace([np.inf, -np.inf], np.nan).dropna()\n",
|
2025-02-12 00:21:33 +08:00
|
|
|
|
"# test_data = test_data.replace([np.inf, -np.inf], np.nan).dropna()\n",
|
|
|
|
|
|
"\n",
|
2025-02-15 23:33:34 +08:00
|
|
|
|
"train_data = train_data.dropna(subset=['label'])\n",
|
2025-02-12 00:21:33 +08:00
|
|
|
|
"train_data = train_data.reset_index(drop=True)\n",
|
2025-02-15 23:33:34 +08:00
|
|
|
|
"\n",
|
|
|
|
|
|
"test_data = test_data.dropna(subset=feature_columns)\n",
|
|
|
|
|
|
"test_data = test_data.reset_index(drop=True)\n",
|
2025-02-12 00:21:33 +08:00
|
|
|
|
"\n",
|
|
|
|
|
|
"print(len(train_data))\n",
|
|
|
|
|
|
"print(f\"最小日期: {train_data['trade_date'].min().strftime('%Y-%m-%d')}\")\n",
|
|
|
|
|
|
"print(f\"最大日期: {train_data['trade_date'].max().strftime('%Y-%m-%d')}\")\n",
|
|
|
|
|
|
"print(len(test_data))\n",
|
|
|
|
|
|
"print(f\"最小日期: {test_data['trade_date'].min().strftime('%Y-%m-%d')}\")\n",
|
|
|
|
|
|
"print(f\"最大日期: {test_data['trade_date'].max().strftime('%Y-%m-%d')}\")\n"
|
2025-02-15 23:33:34 +08:00
|
|
|
|
],
|
|
|
|
|
|
"outputs": [
|
|
|
|
|
|
{
|
|
|
|
|
|
"name": "stdout",
|
|
|
|
|
|
"output_type": "stream",
|
|
|
|
|
|
"text": [
|
|
|
|
|
|
"1067937\n",
|
|
|
|
|
|
"最小日期: 2017-06-20\n",
|
|
|
|
|
|
"最大日期: 2022-11-29\n",
|
|
|
|
|
|
"403686\n",
|
|
|
|
|
|
"最小日期: 2023-01-03\n",
|
|
|
|
|
|
"最大日期: 2025-02-12\n"
|
|
|
|
|
|
]
|
|
|
|
|
|
}
|
|
|
|
|
|
],
|
|
|
|
|
|
"execution_count": 92
|
2025-02-12 00:21:33 +08:00
|
|
|
|
},
|
|
|
|
|
|
{
|
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
|
"id": "8f134d435f71e9e2",
|
|
|
|
|
|
"metadata": {
|
|
|
|
|
|
"jupyter": {
|
|
|
|
|
|
"source_hidden": true
|
2025-02-15 23:33:34 +08:00
|
|
|
|
},
|
|
|
|
|
|
"ExecuteTime": {
|
|
|
|
|
|
"end_time": "2025-02-14T19:01:10.101871Z",
|
|
|
|
|
|
"start_time": "2025-02-14T19:01:10.011761Z"
|
2025-02-12 00:21:33 +08:00
|
|
|
|
}
|
|
|
|
|
|
},
|
|
|
|
|
|
"source": [
|
|
|
|
|
|
"import lightgbm as lgb\n",
|
|
|
|
|
|
"import numpy as np\n",
|
|
|
|
|
|
"import matplotlib.pyplot as plt\n",
|
|
|
|
|
|
"import optuna\n",
|
|
|
|
|
|
"from sklearn.model_selection import KFold\n",
|
|
|
|
|
|
"from sklearn.metrics import mean_absolute_error\n",
|
|
|
|
|
|
"import os\n",
|
|
|
|
|
|
"import json\n",
|
|
|
|
|
|
"import pickle\n",
|
|
|
|
|
|
"import hashlib\n",
|
|
|
|
|
|
"\n",
|
|
|
|
|
|
"\n",
|
|
|
|
|
|
"def objective(trial, X, y, num_boost_round, params):\n",
|
|
|
|
|
|
" # 参数网格\n",
|
|
|
|
|
|
" X, y = X.reset_index(drop=True), y.reset_index(drop=True)\n",
|
|
|
|
|
|
" param_grid = {\n",
|
|
|
|
|
|
" \"n_estimators\": trial.suggest_categorical(\"n_estimators\", [10000]),\n",
|
|
|
|
|
|
" \"learning_rate\": trial.suggest_float(\"learning_rate\", 0.01, 0.3),\n",
|
|
|
|
|
|
" \"num_leaves\": trial.suggest_int(\"num_leaves\", 20, 3000, step=25),\n",
|
|
|
|
|
|
" \"max_depth\": trial.suggest_int(\"max_depth\", 3, 16),\n",
|
|
|
|
|
|
" \"min_data_in_leaf\": trial.suggest_int(\"min_data_in_leaf\", 200, 10000, step=100),\n",
|
|
|
|
|
|
" \"lambda_l1\": trial.suggest_int(\"lambda_l1\", 0, 100, step=5),\n",
|
|
|
|
|
|
" \"lambda_l2\": trial.suggest_int(\"lambda_l2\", 0, 100, step=5),\n",
|
|
|
|
|
|
" \"min_gain_to_split\": trial.suggest_float(\"min_gain_to_split\", 0, 15),\n",
|
|
|
|
|
|
" \"bagging_fraction\": trial.suggest_float(\"bagging_fraction\", 0.2, 0.95, step=0.1),\n",
|
|
|
|
|
|
" \"bagging_freq\": trial.suggest_categorical(\"bagging_freq\", [1]),\n",
|
|
|
|
|
|
" \"feature_fraction\": trial.suggest_float(\"feature_fraction\", 0.2, 0.95, step=0.1),\n",
|
|
|
|
|
|
" \"random_state\": 1,\n",
|
|
|
|
|
|
" \"objective\": 'regression',\n",
|
|
|
|
|
|
" 'verbosity': -1\n",
|
|
|
|
|
|
" }\n",
|
|
|
|
|
|
" # 5折交叉验证\n",
|
|
|
|
|
|
" cv = KFold(n_splits=5, shuffle=False)\n",
|
|
|
|
|
|
"\n",
|
|
|
|
|
|
" cv_scores = np.empty(5)\n",
|
|
|
|
|
|
" for idx, (train_idx, test_idx) in enumerate(cv.split(X, y)):\n",
|
|
|
|
|
|
" X_train, X_test = X.iloc[train_idx], X.iloc[test_idx]\n",
|
|
|
|
|
|
" y_train, y_test = y[train_idx], y[test_idx]\n",
|
|
|
|
|
|
"\n",
|
|
|
|
|
|
" # LGBM建模\n",
|
|
|
|
|
|
" model = lgb.LGBMRegressor(**param_grid, num_boost_round=num_boost_round)\n",
|
|
|
|
|
|
" model.fit(\n",
|
|
|
|
|
|
" X_train,\n",
|
|
|
|
|
|
" y_train,\n",
|
|
|
|
|
|
" eval_set=[(X_test, y_test)],\n",
|
|
|
|
|
|
" eval_metric=\"l2\",\n",
|
|
|
|
|
|
" callbacks=[\n",
|
|
|
|
|
|
" # LightGBMPruningCallback(trial, \"l2\"),\n",
|
|
|
|
|
|
" lgb.early_stopping(50, first_metric_only=True),\n",
|
|
|
|
|
|
" lgb.log_evaluation(period=-1)\n",
|
|
|
|
|
|
" ],\n",
|
|
|
|
|
|
" )\n",
|
|
|
|
|
|
" # 模型预测\n",
|
|
|
|
|
|
" preds = model.predict(X_test)\n",
|
|
|
|
|
|
" # 优化指标logloss最小\n",
|
|
|
|
|
|
" cv_scores[idx] = mean_absolute_error(y_test, preds)\n",
|
|
|
|
|
|
"\n",
|
|
|
|
|
|
" return np.mean(cv_scores)\n",
|
|
|
|
|
|
"\n",
|
|
|
|
|
|
"def generate_key(params, feature_columns, num_boost_round):\n",
|
|
|
|
|
|
" key_data = {\n",
|
|
|
|
|
|
" \"params\": params,\n",
|
|
|
|
|
|
" \"feature_columns\": feature_columns,\n",
|
|
|
|
|
|
" \"num_boost_round\": num_boost_round\n",
|
|
|
|
|
|
" }\n",
|
|
|
|
|
|
" # 转换成排序后的 JSON 字符串,再生成 md5 hash\n",
|
|
|
|
|
|
" key_str = json.dumps(key_data, sort_keys=True)\n",
|
|
|
|
|
|
" return hashlib.md5(key_str.encode('utf-8')).hexdigest()\n",
|
|
|
|
|
|
"\n",
|
|
|
|
|
|
"def train_light_model(df, params, feature_columns, callbacks, evals,\n",
|
|
|
|
|
|
" print_feature_importance=True, num_boost_round=100,\n",
|
|
|
|
|
|
" use_optuna=False):\n",
|
2025-02-15 23:33:34 +08:00
|
|
|
|
"\n",
|
2025-02-12 00:21:33 +08:00
|
|
|
|
" df_sorted = df.sort_values(by=['trade_date', 'label'], ascending=[True, False]) # 按日期升序、标签降序排序\n",
|
|
|
|
|
|
" df_sorted = df_sorted.sort_values(by='trade_date')\n",
|
|
|
|
|
|
" unique_dates = df_sorted['trade_date'].unique()\n",
|
|
|
|
|
|
" val_date_count = int(len(unique_dates) * 0.1)\n",
|
|
|
|
|
|
" val_dates = unique_dates[-val_date_count:]\n",
|
|
|
|
|
|
" val_indices = df_sorted[df_sorted['trade_date'].isin(val_dates)].index\n",
|
|
|
|
|
|
" train_indices = df_sorted[~df_sorted['trade_date'].isin(val_dates)].index\n",
|
|
|
|
|
|
"\n",
|
|
|
|
|
|
" # 获取训练集和验证集的样本\n",
|
|
|
|
|
|
" train_df = df_sorted.iloc[train_indices]\n",
|
|
|
|
|
|
" val_df = df_sorted.iloc[val_indices]\n",
|
|
|
|
|
|
"\n",
|
|
|
|
|
|
" X_train = train_df[feature_columns]\n",
|
|
|
|
|
|
" y_train = train_df['label']\n",
|
|
|
|
|
|
"\n",
|
|
|
|
|
|
" X_val = val_df[feature_columns]\n",
|
|
|
|
|
|
" y_val = val_df['label']\n",
|
|
|
|
|
|
"\n",
|
|
|
|
|
|
" train_data = lgb.Dataset(X_train, label=y_train)\n",
|
|
|
|
|
|
" val_data = lgb.Dataset(X_val, label=y_val)\n",
|
|
|
|
|
|
" if use_optuna:\n",
|
|
|
|
|
|
" # study = optuna.create_study(direction='minimize' if classify else 'maximize')\n",
|
|
|
|
|
|
" study = optuna.create_study(direction='minimize')\n",
|
|
|
|
|
|
" study.optimize(lambda trial: objective(trial, X_train, y_train, num_boost_round, params), n_trials=20)\n",
|
|
|
|
|
|
"\n",
|
|
|
|
|
|
" print(f\"Best parameters: {study.best_trial.params}\")\n",
|
|
|
|
|
|
" print(f\"Best score: {study.best_trial.value}\")\n",
|
|
|
|
|
|
"\n",
|
|
|
|
|
|
" params.update(study.best_trial.params)\n",
|
|
|
|
|
|
" model = lgb.train(\n",
|
|
|
|
|
|
" params, train_data, num_boost_round=num_boost_round,\n",
|
|
|
|
|
|
" valid_sets=[train_data, val_data], valid_names=['train', 'valid'],\n",
|
|
|
|
|
|
" callbacks=callbacks\n",
|
|
|
|
|
|
" )\n",
|
|
|
|
|
|
"\n",
|
|
|
|
|
|
" # 打印特征重要性(如果需要)\n",
|
|
|
|
|
|
" if print_feature_importance:\n",
|
|
|
|
|
|
" lgb.plot_metric(evals)\n",
|
2025-02-15 23:33:34 +08:00
|
|
|
|
" # lgb.plot_tree(model, figsize=(20, 8))\n",
|
2025-02-12 00:21:33 +08:00
|
|
|
|
" lgb.plot_importance(model, importance_type='split', max_num_features=20)\n",
|
|
|
|
|
|
" plt.show()\n",
|
|
|
|
|
|
" # with open(cache_file, 'wb') as f:\n",
|
|
|
|
|
|
" # pickle.dump({'key': cache_key,\n",
|
|
|
|
|
|
" # 'model': model,\n",
|
|
|
|
|
|
" # 'feature_columns': feature_columns}, f)\n",
|
|
|
|
|
|
" # print(\"模型训练完成并已保存缓存。\")\n",
|
|
|
|
|
|
" return model\n",
|
|
|
|
|
|
"\n",
|
|
|
|
|
|
"\n",
|
|
|
|
|
|
"from catboost import CatBoostRegressor\n",
|
|
|
|
|
|
"import pandas as pd\n",
|
|
|
|
|
|
"\n",
|
|
|
|
|
|
"\n",
|
2025-02-15 23:33:34 +08:00
|
|
|
|
"def train_catboost(df, feature_columns, params=None):\n",
|
2025-02-12 00:21:33 +08:00
|
|
|
|
" \"\"\"\n",
|
|
|
|
|
|
" 训练 CatBoost 排序模型\n",
|
|
|
|
|
|
" - df: 包含因子、date、instrument 和 label 的 DataFrame\n",
|
|
|
|
|
|
" - num_boost_round: 训练的轮数\n",
|
|
|
|
|
|
" - print_feature_importance: 是否打印特征重要性\n",
|
|
|
|
|
|
" - plot: 是否绘制特征重要性图\n",
|
|
|
|
|
|
" - split_date: 用于划分训练集和验证集的日期(比如 '2020-01-01')\n",
|
|
|
|
|
|
"\n",
|
|
|
|
|
|
" 返回训练好的模型\n",
|
|
|
|
|
|
" \"\"\"\n",
|
2025-02-15 23:33:34 +08:00
|
|
|
|
" df_sorted = df.sort_values(by=['trade_date', 'label'], ascending=[True, False])\n",
|
2025-02-12 00:21:33 +08:00
|
|
|
|
"\n",
|
2025-02-15 23:33:34 +08:00
|
|
|
|
" df_sorted = df_sorted.sort_values(by='trade_date')\n",
|
|
|
|
|
|
" unique_dates = df_sorted['trade_date'].unique()\n",
|
2025-02-12 00:21:33 +08:00
|
|
|
|
" val_date_count = int(len(unique_dates) * 0.1)\n",
|
|
|
|
|
|
" val_dates = unique_dates[-val_date_count:]\n",
|
2025-02-15 23:33:34 +08:00
|
|
|
|
" val_indices = df_sorted[df_sorted['trade_date'].isin(val_dates)].index\n",
|
|
|
|
|
|
" train_indices = df_sorted[~df_sorted['trade_date'].isin(val_dates)].index\n",
|
2025-02-12 00:21:33 +08:00
|
|
|
|
"\n",
|
|
|
|
|
|
" # 获取训练集和验证集的样本\n",
|
2025-02-15 23:33:34 +08:00
|
|
|
|
" train_df = df_sorted.iloc[train_indices].sort_values(by=['trade_date', 'label'], ascending=[True, False])\n",
|
|
|
|
|
|
" val_df = df_sorted.iloc[val_indices].sort_values(by=['trade_date', 'label'], ascending=[True, False])\n",
|
2025-02-12 00:21:33 +08:00
|
|
|
|
"\n",
|
|
|
|
|
|
" X_train = train_df[feature_columns]\n",
|
|
|
|
|
|
" y_train = train_df['label']\n",
|
|
|
|
|
|
"\n",
|
|
|
|
|
|
" X_val = val_df[feature_columns]\n",
|
|
|
|
|
|
" y_val = val_df['label']\n",
|
|
|
|
|
|
"\n",
|
2025-02-15 23:33:34 +08:00
|
|
|
|
" model = CatBoostRegressor(**params)\n",
|
2025-02-12 00:21:33 +08:00
|
|
|
|
" model.fit(X_train,\n",
|
|
|
|
|
|
" y_train,\n",
|
|
|
|
|
|
" eval_set=(X_val, y_val))\n",
|
|
|
|
|
|
"\n",
|
|
|
|
|
|
" return model"
|
2025-02-15 23:33:34 +08:00
|
|
|
|
],
|
|
|
|
|
|
"outputs": [],
|
|
|
|
|
|
"execution_count": 93
|
2025-02-12 00:21:33 +08:00
|
|
|
|
},
|
|
|
|
|
|
{
|
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
|
"id": "4a4542e1ed6afe7d",
|
|
|
|
|
|
"metadata": {
|
|
|
|
|
|
"ExecuteTime": {
|
2025-02-15 23:33:34 +08:00
|
|
|
|
"end_time": "2025-02-14T19:02:42.420508Z",
|
|
|
|
|
|
"start_time": "2025-02-14T19:02:42.341401Z"
|
2025-02-12 00:21:33 +08:00
|
|
|
|
}
|
|
|
|
|
|
},
|
|
|
|
|
|
"source": [
|
2025-02-15 23:33:34 +08:00
|
|
|
|
"def max_drawdown_loss(y_true, y_pred):\n",
|
|
|
|
|
|
" # y_true和y_pred表示资产的实际和预测回报序列\n",
|
|
|
|
|
|
" cumulative_return = np.cumsum(y_pred) # 计算累积回报\n",
|
|
|
|
|
|
" peak = np.maximum.accumulate(cumulative_return)\n",
|
|
|
|
|
|
" drawdown = (cumulative_return - peak) / peak # 计算回撤\n",
|
|
|
|
|
|
" max_drawdown = np.min(drawdown) # 最大回撤\n",
|
|
|
|
|
|
"\n",
|
|
|
|
|
|
" loss = -max_drawdown # 最大回撤越大,损失越小,取负数使得回撤最小化\n",
|
|
|
|
|
|
" return loss, np.zeros_like(loss) # 返回损失和零梯度\n",
|
|
|
|
|
|
"\n",
|
|
|
|
|
|
"\n",
|
|
|
|
|
|
"\n",
|
2025-02-12 00:21:33 +08:00
|
|
|
|
"light_params = {\n",
|
2025-02-15 23:33:34 +08:00
|
|
|
|
" # 'objective': 'regression',\n",
|
|
|
|
|
|
" # 'metric': 'l2',\n",
|
|
|
|
|
|
" 'objective': 'quantile', # 分位回归\n",
|
|
|
|
|
|
" 'metric': 'quantile', # 使用 quantile 作为评估指标\n",
|
|
|
|
|
|
" 'alpha': 0.75, # 90% 分位数\n",
|
2025-02-12 00:21:33 +08:00
|
|
|
|
" 'learning_rate': 0.05,\n",
|
|
|
|
|
|
" 'is_unbalance': True,\n",
|
2025-02-15 23:33:34 +08:00
|
|
|
|
" 'num_leaves': 64,\n",
|
|
|
|
|
|
" 'min_data_in_leaf': 128,\n",
|
|
|
|
|
|
" 'max_depth': 6,\n",
|
2025-02-12 00:21:33 +08:00
|
|
|
|
" 'max_bin': 1024,\n",
|
|
|
|
|
|
" 'feature_fraction': 0.7,\n",
|
|
|
|
|
|
" 'bagging_fraction': 0.7,\n",
|
|
|
|
|
|
" 'bagging_freq': 5,\n",
|
2025-02-15 23:33:34 +08:00
|
|
|
|
" 'lambda_l1': 1,\n",
|
|
|
|
|
|
" 'lambda_l2': 1,\n",
|
|
|
|
|
|
" # 'boosting_type': 'dart',\n",
|
2025-02-12 00:21:33 +08:00
|
|
|
|
" 'verbosity': -1\n",
|
|
|
|
|
|
"}"
|
2025-02-15 23:33:34 +08:00
|
|
|
|
],
|
|
|
|
|
|
"outputs": [],
|
|
|
|
|
|
"execution_count": 96
|
2025-02-12 00:21:33 +08:00
|
|
|
|
},
|
|
|
|
|
|
{
|
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
|
"id": "beeb098799ecfa6a",
|
|
|
|
|
|
"metadata": {
|
|
|
|
|
|
"ExecuteTime": {
|
2025-02-15 23:33:34 +08:00
|
|
|
|
"end_time": "2025-02-14T19:05:21.315576Z",
|
|
|
|
|
|
"start_time": "2025-02-14T19:02:42.469389Z"
|
2025-02-12 00:21:33 +08:00
|
|
|
|
}
|
|
|
|
|
|
},
|
2025-02-15 23:33:34 +08:00
|
|
|
|
"source": [
|
|
|
|
|
|
"print('train data size: ', len(train_data))\n",
|
|
|
|
|
|
"\n",
|
|
|
|
|
|
"evals = {}\n",
|
|
|
|
|
|
"light_model = train_light_model(train_data, light_params, feature_columns,\n",
|
|
|
|
|
|
" [lgb.log_evaluation(period=500),\n",
|
|
|
|
|
|
" lgb.callback.record_evaluation(evals),\n",
|
|
|
|
|
|
" lgb.early_stopping(50, first_metric_only=True)\n",
|
|
|
|
|
|
" ], evals,\n",
|
|
|
|
|
|
" num_boost_round=10000, use_optuna=False,\n",
|
|
|
|
|
|
" print_feature_importance=True)"
|
|
|
|
|
|
],
|
2025-02-12 00:21:33 +08:00
|
|
|
|
"outputs": [
|
|
|
|
|
|
{
|
|
|
|
|
|
"name": "stdout",
|
|
|
|
|
|
"output_type": "stream",
|
|
|
|
|
|
"text": [
|
2025-02-15 23:33:34 +08:00
|
|
|
|
"train data size: 1067937\n",
|
2025-02-12 00:21:33 +08:00
|
|
|
|
"Training until validation scores don't improve for 50 rounds\n",
|
2025-02-15 23:33:34 +08:00
|
|
|
|
"[500]\ttrain's quantile: 0.766393\tvalid's quantile: 0.783404\n",
|
|
|
|
|
|
"[1000]\ttrain's quantile: 0.749886\tvalid's quantile: 0.775331\n",
|
|
|
|
|
|
"[1500]\ttrain's quantile: 0.739319\tvalid's quantile: 0.771256\n",
|
|
|
|
|
|
"[2000]\ttrain's quantile: 0.731186\tvalid's quantile: 0.768522\n",
|
|
|
|
|
|
"[2500]\ttrain's quantile: 0.724016\tvalid's quantile: 0.766611\n",
|
|
|
|
|
|
"[3000]\ttrain's quantile: 0.717949\tvalid's quantile: 0.765331\n",
|
|
|
|
|
|
"[3500]\ttrain's quantile: 0.712718\tvalid's quantile: 0.76426\n",
|
|
|
|
|
|
"[4000]\ttrain's quantile: 0.708171\tvalid's quantile: 0.763329\n",
|
|
|
|
|
|
"[4500]\ttrain's quantile: 0.703619\tvalid's quantile: 0.762479\n",
|
|
|
|
|
|
"[5000]\ttrain's quantile: 0.699455\tvalid's quantile: 0.761843\n",
|
|
|
|
|
|
"[5500]\ttrain's quantile: 0.695762\tvalid's quantile: 0.761218\n",
|
|
|
|
|
|
"[6000]\ttrain's quantile: 0.692351\tvalid's quantile: 0.760606\n",
|
|
|
|
|
|
"[6500]\ttrain's quantile: 0.689176\tvalid's quantile: 0.760173\n",
|
|
|
|
|
|
"[7000]\ttrain's quantile: 0.686318\tvalid's quantile: 0.759704\n",
|
|
|
|
|
|
"Early stopping, best iteration is:\n",
|
|
|
|
|
|
"[7090]\ttrain's quantile: 0.685745\tvalid's quantile: 0.759607\n",
|
|
|
|
|
|
"Evaluated only: quantile\n"
|
2025-02-12 00:21:33 +08:00
|
|
|
|
]
|
2025-02-15 23:33:34 +08:00
|
|
|
|
},
|
|
|
|
|
|
{
|
|
|
|
|
|
"data": {
|
|
|
|
|
|
"text/plain": [
|
|
|
|
|
|
"<Figure size 640x480 with 1 Axes>"
|
|
|
|
|
|
],
|
|
|
|
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAkgAAAHFCAYAAAAJ2AY0AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8pXeV/AAAACXBIWXMAAA9hAAAPYQGoP6dpAABu50lEQVR4nO3deVxU5f4H8M/MMAvDMqyyCALuC+C+gLlVappmlulN0zTt5rVFs+XmbbUstV8pWmmbRnotvaW2WoqpqLmmYiruiKgMIOuwDsPM+f0xMDILq8Ag83m/XvNi5pznnHnOV9JPz3nOOSJBEAQQERERkYnY3h0gIiIiam4YkIiIiIgsMCARERERWWBAIiIiIrLAgERERERkgQGJiIiIyAIDEhEREZEFBiQiIiIiCwxIRERERBYYkIiamdjYWIhEIohEIuzZs8dqvSAIaN++PUQiEYYOHVqv71i1ahViY2PrtM2ePXuq7FNDaazvaIq+VyUxMRFvvfUWkpOTG2X/b731FkQiUb22tWddiJo7BiSiZsrNzQ1r1qyxWh4fH4/Lly/Dzc2t3vuuT0Dq1asXDh48iF69etX7e+3Fnn1PTEzEwoULGy0gzZo1CwcPHqzXtnfynylRY2NAImqmJk2ahM2bN0Oj0ZgtX7NmDaKiotCmTZsm6YdOp0NZWRnc3d0xYMAAuLu7N8n3NoQ7se9FRUV1ah8UFIQBAwbU67vupLoQNTUGJKJm6tFHHwUAfPvtt6ZleXl52Lx5M5544gmb25SWlmLRokXo3Lkz5HI5fH19MWPGDNy8edPUJjQ0FGfOnEF8fLzpVF5oaCiAW6dc1q9fjxdeeAGtW7eGXC7HpUuXqjwdc/jwYYwdOxbe3t5QKBRo164d5s2bV+PxnTt3Dvfddx+USiV8fHwwe/Zs5OfnW7ULDQ3F9OnTrZYPHTrU7BRjXfs+ffp0uLq64tKlSxg9ejRcXV0RHByMF154AVqt1uy7rl+/jgkTJsDNzQ0eHh6YMmUKjh49CpFIVO1IXGxsLB555BEAwLBhw0z1rthm6NChCA8Px969exEdHQ2lUmn6s920aRNGjBiBgIAAODs7o0uXLnjllVdQWFho9h22TrGFhoZizJgx+P3339GrVy84Ozujc+fOWLt2rVk7e9WF6E7AgETUTLm7u2PChAlm/6h9++23EIvFmDRpklV7g8GAcePGYcmSJZg8eTJ+/fVXLFmyBHFxcRg6dCiKi4sBAFu3bkXbtm3Rs2dPHDx4EAcPHsTWrVvN9rVgwQKkpKTg008/xc8//4xWrVrZ7OP27dsxaNAgpKSkYNmyZfjtt9/w2muvIT09vdpjS09Px5AhQ3D69GmsWrUK69evR0FBAZ555pm6lslKbfsOGEeYHnjgAdxzzz348ccf8cQTT2D58uVYunSpqU1hYSGGDRuG3bt3Y+nSpfjf//4HPz8/m38Glu6//3689957AIBPPvnEVO/777/f1EatVuOxxx7D5MmTsW3bNsyZMwcAcPHiRYwePRpr1qzB77//jnnz5uF///sfxo4dW6s6nDx5Ei+88AKef/55/Pjjj4iMjMTMmTOxd+/eGrdt7LoQ3REEImpWvvrqKwGAcPToUWH37t0CAOH06dOCIAhC3759henTpwuCIAjdunUThgwZYtru22+/FQAImzdvNtvf0aNHBQDCqlWrTMsst61Q8X2DBw+uct3u3btNy9q1aye0a9dOKC4urtMx/vvf/xZEIpGQkJBgtnz48OFW3xESEiI8/vjjVvsYMmSI2THUte+PP/64AED43//+Z9Z29OjRQqdOnUyfP/nkEwGA8Ntvv5m1e+qppwQAwldffVXtsX733XdW3135GAAIf/zxR7X7MBgMgk6nE+Lj4wUAwsmTJ03r3nzzTcHyr/KQkBBBoVAIV69eNS0rLi4WvLy8hKeeesq0zJ51IWruOIJE1IwNGTIE7dq1w9q1a3Hq1CkcPXq0ytNrv/
zyCzw8PDB27FiUlZWZXj169IC/v3+drlR6+OGHa2xz4cIFXL58GTNnzoRCoaj1vgFg9+7d6NatG7p37262fPLkyXXajy216XsFkUhkNSITGRmJq1evmj7Hx8fDzc0N9913n1m7ilOgt8vT0xN333231fKkpCRMnjwZ/v7+kEgkkEqlGDJkCADg7NmzNe63R48eZvPUFAoFOnbsaHZsVWkOdSGyNyd7d4CIqiYSiTBjxgysXLkSJSUl6NixIwYNGmSzbXp6OnJzcyGTyWyuz8zMrPX3BgQE1NimYl5TUFBQrfdbISsrC2FhYVbL/f3967wvS7XpewWlUmkV7uRyOUpKSkyfs7Ky4OfnZ7WtrWX1Yau/BQUFGDRoEBQKBRYtWoSOHTtCqVTi2rVreOihh0ynS6vj7e1ttUwul9dq2+ZQFyJ7Y0AiauamT5+ON954A59++inefffdKtv5+PjA29sbv//+u831dbktQG3uq+Pr6wvAOFG3rry9vZGWlma13NYyhUJhNTkYMAY+Hx8fq+X1vSdQVby9vXHkyBGr5bb6Wh+2+rtr1y6kpqZiz549plEjAMjNzW2Q72wIjV0XInvjKTaiZq5169Z46aWXMHbsWDz++ONVthszZgyysrKg1+vRp08fq1enTp1MbWs7klCdjh07mk7/2Qow1Rk2bBjOnDmDkydPmi3/5ptvrNqGhobi77//Nlt24cIFnD9/vu6drochQ4YgPz8fv/32m9nyjRs31mp7uVwOAHWqd0Voqti2wmeffVbrfTS2260LUXPHESSiO8CSJUtqbPOPf/wDGzZswOjRozF37lz069cPUqkU169fx+7duzFu3DiMHz8eABAREYGNGzdi06ZNaNu2LRQKBSIiIurcr08++QRjx47FgAED8Pzzz6NNmzZISUnB9u3bsWHDhiq3mzdvHtauXYv7778fixYtgp+fHzZs2IBz585ZtZ06dSoee+wxzJkzBw8//DCuXr2K999/3zSC1dgef/xxLF++HI899hgWLVqE9u3b47fffsP27dsBAGJx9f+fGR4eDgD4/PPP4ebmBoVCgbCwMJunwCpER0fD09MTs2fPxptvvgmpVIoNGzZYBUp7ut26EDV3/A0maiEkEgl++ukn/Oc//8GWLVswfvx4PPjgg1iyZIlVAFq4cCGGDBmCJ598Ev369av1peOWRo4cib179yIgIADPPfcc7rvvPrz99ts1zkPx9/dHfHw8unbtin/961947LHHoFAo8PHHH1u1nTx5Mt5//31s374dY8aMwerVq7F69Wp07NixXn2uKxcXF+zatQtDhw7Fyy+/jIcffhgpKSlYtWoVAMDDw6Pa7cPCwhATE4OTJ09i6NCh6Nu3L37++edqt/H29savv/4KpVKJxx57DE888QRcXV2xadOmhjqs23a7dSFq7kSCIAj27gQR0Z3mvffew2uvvYaUlJR6TVRvqVgXail4io2IqAYVI1udO3eGTqfDrl27sHLlSjz22GMOHQJYF2rJGJCIiGqgVCqxfPlyJCcnQ6vVok2bNvj3v/+N1157zd5dsyvWhVoynmIjIiIissBJ2kREREQWGJCIiIiILDAgEREREVngJG0bDAYDUlNT4ebm1uCPLSAiIqLGIQgC8vPzERgYeNs3K2VAsiE1NRXBwcH27gYRERHVw7Vr1277VhMMSDZUPNTzypUr8PLysnNvmg+dTocdO3ZgxIgRkEql9u5Os8Ca2Ma6WGNNbGNdrLEmttWmLhqNBsHBwXV6OHdVGJBsqDit5ubmBnd3dzv3pvnQ6XRQKpVwd3fnf7TlWBPbWBdrrIltrIs11sS2utSlIabHcJI2ERERkQUGJCIiIiILDEhEREREFuwakPbu3YuxY8ciMDAQIpEIP/zwQ43bxMfHo3fv3lAoFGjbti0+/fRTqzabN29G165dIZfL0bVrV2zdurURek9ERFR3er0eJSUl1b6cnJxqbOOILycnJxgMhib5c7LrJO
3CwkJ0794dM2bMwMMPP1xj+ytXrmD06NF48skn8d///hd//vkn5syZA19fX9P2Bw8exKRJk/DOO+9g/Pjx2Lp1KyZ
|
|
|
|
|
|
},
|
|
|
|
|
|
"metadata": {},
|
|
|
|
|
|
"output_type": "display_data"
|
|
|
|
|
|
},
|
|
|
|
|
|
{
|
|
|
|
|
|
"data": {
|
|
|
|
|
|
"text/plain": [
|
|
|
|
|
|
"<Figure size 640x480 with 1 Axes>"
|
|
|
|
|
|
],
|
|
|
|
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAwMAAAHFCAYAAACuDCWjAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8pXeV/AAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdd1RUx9vA8e9SpYnSwSCgKKKgEjVGYwIWithLrFGJJRorEuwNMfYeTdSfib1EjdEYQ1RiS+xYY4tGBLFgiwoKSr3vHx7u6woq6CqW53POHti5c2fmPgu7O3fmztUoiqIghBBCCCGEeOfoFXYDhBBCCCGEEIVDOgNCCCGEEEK8o6QzIIQQQgghxDtKOgNCCCGEEEK8o6QzIIQQQgghxDtKOgNCCCGEEEK8o6QzIIQQQgghxDtKOgNCCCGEEEK8o6QzIIQQQgghxDtKOgNCCCHeWIsWLUKj0eT5CA8Pfyl1njp1ioiICOLj419K+S8iPj4ejUbDokWLCrspzy0qKoqIiIjCboYQ7wyDwm6AEEII8aIWLlxIuXLltNKcnJxeSl2nTp1i9OjR+Pn54erq+lLqeF6Ojo7s3buX0qVLF3ZTnltUVBTffvutdAiEeEWkMyCEEOKN5+XlRdWqVQu7GS8kIyMDjUaDgcHzfzQbGxvz4Ycf6rBVr05qaiqmpqaF3Qwh3jkyTUgIIcRbb9WqVdSoUQMzMzPMzc0JDAzkyJEjWnkOHjxImzZtcHV1xcTEBFdXV9q2bcuFCxfUPIsWLeLTTz8FoHbt2uqUpJxpOa6uroSEhOSq38/PDz8/P/X5jh070Gg0LF26lK+++ooSJUpgbGzMuXPnAPjjjz+oW7cuRYsWxdTUlI8++oitW7c+8zjzmiYUERGBRqPh77//5tNPP8XS0hIrKyvCwsLIzMzkzJkzBAUFYWFhgaurK5MmTdIqM6ety5YtIywsDAcHB0xMTPD19c0VQ4ANGzZQo0YNTE1NsbCwwN/fn71792rlyWnT4cOHadmyJcWLF6d06dKEhITw7bffAmhN+cqZkvXtt9/yySefYGdnh5mZGd7e3kyaNImMjIxc8fby8iImJoaPP/4YU1NTSpUqxYQJE8jOztbKe+fOHb766itKlSqFsbExdnZ2BAcH888//6h50tPT+frrrylXrhzGxsbY2try+eefc+PGjWe+JkK87qQzIIQQ4o2XlZVFZmam1iPHuHHjaNu2LeXLl2f16tUsXbqUu3fv8vHHH3Pq1Ck1X3x8PB4eHsyYMYPNmzczceJEEhMTqVatGjdv3gSgQYMGjBs3Dnj4xXTv3r3s3buXBg0aPFe7hwwZQkJCAnPnzuXXX3/Fzs6OZcuWERAQQNGiRVm8eDGrV6/GysqKwMDAfHUInqRVq1ZUqlSJtWvX0q1bN6ZPn07//v1p2rQpDRo0YN26ddSpU4dBgwbx888/59p/6NChnD9/nu+//57vv/+eK1eu4Ofnx/nz59U8K1asoEmTJhQtWpSVK1fyww8/cPv2bfz8/Ni1a1euMps3b467uztr1qxh7ty5jBgxgpYtWwKosd27dy+Ojo4AxMbG0q5dO5YuXcrGjRvp0qULkydPpnv37rnKvnr1Ku3bt+ezzz5jw4YN1K9fnyFDhrBs2TI1z927d6lVqxbz5s3j888/59dff2Xu3LmULVuWxMREALKzs2nSpAkTJkygXbt2/Pbbb0yYMIHo6Gj8/Py4f//+c78mQrwWFCGEEOINtXDhQgXI85GRkaEkJCQoBgYGSp8+fbT2u3v3ruLg4KC0atXqiWVnZmYq9+7dU8zMzJSZM2eq6WvWrFEAZfv27bn2cXFxUTp16pQr3dfXV/H19VWfb9++XQGUTz75RCtfSkqKYmVlpTRq1EgrPSsrS6lUqZLywQcfPCUaihIXF6cAysKFC9W0UaNGKYAydepUrbyVK1dWAOXnn39W0zIyMhRbW1ulefPmudr6/vvvK9nZ2W
p6fHy8YmhoqHTt2lVto5OTk+Lt7a1kZWWp+e7evavY2dkpNWvWzNWmkSNH5jqGXr16Kfn5epKVlaVkZGQoS5YsUfT19ZVbt26p23x9fRVA2b9/v9Y+5cuXVwIDA9XnkZGRCqBER0c/sZ6VK1cqgLJ27Vqt9JiYGAVQvvvuu2e2VYjXmYwMCCGEeOMtWbKEmJgYrYeBgQGbN28mMzOTjh07ao0aFClSBF9fX3bs2KGWce/ePQYNGoS7uzsGBgYYGBhgbm5OSkoKp0+ffintbtGihdbzPXv2cOvWLTp16qTV3uzsbIKCgoiJiSElJeW56mrYsKHWc09PTzQaDfXr11fTDAwMcHd315oalaNdu3ZoNBr1uYuLCzVr1mT79u0AnDlzhitXrtChQwf09P7/64W5uTktWrRg3759pKamPvX4n+XIkSM0btwYa2tr9PX1MTQ0pGPHjmRlZXH27FmtvA4ODnzwwQdaaRUrVtQ6tt9//52yZctSr169J9a5ceNGihUrRqNGjbRek8qVK+Pg4KD1NyTEm0guIBZCCPHG8/T0zPMC4mvXrgFQrVq1PPd79Etru3bt2Lp1KyNGjKBatWoULVoUjUZDcHDwS5sKkjP95fH25kyVycutW7cwMzMrcF1WVlZaz42MjDA1NaVIkSK50pOTk3Pt7+DgkGfasWPHAPjvv/+A3McED1d2ys7O5vbt21oXCeeV90kSEhL4+OOP8fDwYObMmbi6ulKkSBEOHDhAr169cr1G1tbWucowNjbWynfjxg1Kliz51HqvXbvGnTt3MDIyynN7zhQyId5U0hkQQgjx1rKxsQHgp59+wsXF5Yn5kpKS2LhxI6NGjWLw4MFqelpaGrdu3cp3fUWKFCEtLS1X+s2bN9W2POrRM+2PtnfWrFlPXBXI3t4+3+3RpatXr+aZlvOlO+dnzlz7R125cgU9PT2KFy+ulf748T/N+vXrSUlJ4eeff9Z6LY8ePZrvMh5na2vLpUuXnprHxsYGa2trNm3alOd2CwuL565fiNeBdAaEEEK8tQIDAzEwMCA2NvapU1I0Gg2KomBsbKyV/v3335OVlaWVlpMnr9ECV1dX/v77b620s2fPcubMmTw7A4/76KOPKFasGKdOnaJ3797PzP8qrVy5krCwMPUL/IULF9izZw8dO3YEwMPDgxIlSrBixQrCw8PVfCkpKaxdu1ZdYehZHo2viYmJmp5T3qOvkaIozJ8//7mPqX79+owcOZJt27ZRp06dPPM0bNiQH3/8kaysLKpXr/7cdQnxupLOgBBCiLeWq6srkZGRDBs2jPPnzxMUFETx4sW5du0aBw4cwMzMjNGjR1O0aFE++eQTJk+ejI2NDa6uruzcuZMffviBYsWKaZXp5eUFwP/+9z8sLCwoUqQIbm5uWFtb06FDBz777DN69uxJixYtuHDhApMmTcLW1jZf7TU3N2fWrFl06tSJW7du0bJlS+zs7Lhx4wbHjh3jxo0bzJkzR9dhypfr16/TrFkzunXrRlJSEqNGjaJIkSIMGTIEeDjlatKkSbRv356GDRvSvXt30tLSmDx5Mnfu3GHChAn5qsfb2xuAiRMnUr9+ffT19alYsSL+/v4YGRnRtm1bBg4cyIMHD5gzZw63b99+7mMKDQ1l1apVNGnShMGDB/PBBx9w//59du7cScOGDalduzZt2rRh+fLlBAcH069fPz744AMMDQ25dOkS27dvp0mTJjRr1uy52yBEYZMLiIUQQrzVhgwZwk8//cTZs2fp1KkTgYGBDBw4kAsXLvDJJ5+o+VasWEHt2rUZOHAgzZs35+DBg0RHR2NpaalVnpubGzNmzODYsWP4+flRrVo1fv31V+DhdQeTJk1i8+bNNGzYkDlz5jBnzhzKli2b7/Z+9tlnbN++nXv37tG9e3fq1atHv379OHz4MHXr1tVNUJ7DuHHjcHFx4fPPP6dz5844Ojqyfft2rbsdt2vXjvXr1/Pff//RunVrPv/8c4oWLc
r27dupVatWvupp164dXbt25bvvvqNGjRpUq1aNK1euUK5cOdauXcvt27dp3rw5ffr0oXLlynzzzTfPfUwWFhbs2rW
|
|
|
|
|
|
},
|
|
|
|
|
|
"metadata": {},
|
|
|
|
|
|
"output_type": "display_data"
|
2025-02-12 00:21:33 +08:00
|
|
|
|
}
|
|
|
|
|
|
],
|
2025-02-15 23:33:34 +08:00
|
|
|
|
"execution_count": 97
|
|
|
|
|
|
},
|
|
|
|
|
|
{
|
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
|
"id": "63235069-dc59-48fb-961a-e80373e41a61",
|
|
|
|
|
|
"metadata": {
|
|
|
|
|
|
"ExecuteTime": {
|
|
|
|
|
|
"end_time": "2025-02-14T19:05:21.442954Z",
|
|
|
|
|
|
"start_time": "2025-02-14T19:05:21.364837Z"
|
|
|
|
|
|
}
|
|
|
|
|
|
},
|
2025-02-12 00:21:33 +08:00
|
|
|
|
"source": [
|
|
|
|
|
|
"print('train data size: ', len(train_data))\n",
|
|
|
|
|
|
"\n",
|
2025-02-15 23:33:34 +08:00
|
|
|
|
"catboost_params = {\n",
|
|
|
|
|
|
" 'loss_function': 'MAE', # 90% 分位回归\n",
|
|
|
|
|
|
" 'iterations': 5000, # 训练轮数\n",
|
|
|
|
|
|
" 'learning_rate': 0.05, # 学习率,较低以防止过拟合\n",
|
|
|
|
|
|
" 'depth': 16, # 树的深度,防止过拟合\n",
|
|
|
|
|
|
" 'l2_leaf_reg': 10.0, # L2 正则化,提高泛化能力\n",
|
|
|
|
|
|
" 'bagging_temperature': 1, # 降低过拟合\n",
|
|
|
|
|
|
" # 'subsample': 0.8, # 每轮随机 80% 的样本,减少过拟合\n",
|
|
|
|
|
|
" 'colsample_bylevel': 0.8, # 每层 80% 特征子集,防止过拟合\n",
|
|
|
|
|
|
" 'random_seed': 42, # 固定随机种子,保证可复现\n",
|
|
|
|
|
|
" 'verbose': 500, # 每 100 轮打印一次信息\n",
|
|
|
|
|
|
" 'early_stopping_rounds': 100, # 早停,防止过拟合\n",
|
|
|
|
|
|
" # 'task_type': 'GPU'\n",
|
|
|
|
|
|
"}\n",
|
|
|
|
|
|
"\n",
|
|
|
|
|
|
"# catboost_model = train_catboost(train_data, feature_columns, catboost_params)"
|
|
|
|
|
|
],
|
|
|
|
|
|
"outputs": [
|
|
|
|
|
|
{
|
|
|
|
|
|
"name": "stdout",
|
|
|
|
|
|
"output_type": "stream",
|
|
|
|
|
|
"text": [
|
|
|
|
|
|
"train data size: 1067937\n"
|
|
|
|
|
|
]
|
|
|
|
|
|
}
|
|
|
|
|
|
],
|
|
|
|
|
|
"execution_count": 98
|
2025-02-12 00:21:33 +08:00
|
|
|
|
},
|
|
|
|
|
|
{
|
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
|
"id": "5bb96ca8492e74d",
|
|
|
|
|
|
"metadata": {
|
|
|
|
|
|
"ExecuteTime": {
|
2025-02-15 23:33:34 +08:00
|
|
|
|
"end_time": "2025-02-14T19:05:44.296879Z",
|
|
|
|
|
|
"start_time": "2025-02-14T19:05:21.450842Z"
|
2025-02-12 00:21:33 +08:00
|
|
|
|
}
|
|
|
|
|
|
},
|
|
|
|
|
|
"source": [
|
2025-02-15 23:33:34 +08:00
|
|
|
|
"score_df = test_data\n",
|
|
|
|
|
|
"score_df['score'] = light_model.predict(score_df[feature_columns])\n",
|
|
|
|
|
|
"# train_data['score'] = catboost_model.predict(train_data[feature_columns])\n",
|
|
|
|
|
|
"predictions_test = score_df.loc[score_df.groupby('trade_date')['score'].idxmax()]\n",
|
|
|
|
|
|
"predictions_test = predictions_test[predictions_test['score'] > 0]\n",
|
|
|
|
|
|
"predictions_test[['trade_date', 'score', 'ts_code']].to_csv('predictions_test.tsv', index=False)"
|
|
|
|
|
|
],
|
|
|
|
|
|
"outputs": [],
|
|
|
|
|
|
"execution_count": 99
|
2025-02-12 00:21:33 +08:00
|
|
|
|
},
|
|
|
|
|
|
{
|
|
|
|
|
|
"metadata": {
|
|
|
|
|
|
"ExecuteTime": {
|
2025-02-15 23:33:34 +08:00
|
|
|
|
"end_time": "2025-02-14T19:06:45.153554Z",
|
|
|
|
|
|
"start_time": "2025-02-14T19:05:44.329062Z"
|
2025-02-12 00:21:33 +08:00
|
|
|
|
}
|
|
|
|
|
|
},
|
2025-02-15 23:33:34 +08:00
|
|
|
|
"cell_type": "code",
|
2025-02-12 00:21:33 +08:00
|
|
|
|
"source": [
|
2025-02-15 23:33:34 +08:00
|
|
|
|
"score_df = train_data\n",
|
|
|
|
|
|
"score_df['score'] = light_model.predict(score_df[feature_columns])\n",
|
|
|
|
|
|
"# train_data['score'] = catboost_model.predict(train_data[feature_columns])\n",
|
|
|
|
|
|
"predictions_test = score_df.loc[score_df.groupby('trade_date')['score'].idxmax()]\n",
|
|
|
|
|
|
"predictions_test = predictions_test[predictions_test['score'] > 0]\n",
|
|
|
|
|
|
"predictions_test[['trade_date', 'score', 'ts_code']].to_csv('predictions_train.tsv', index=False)"
|
|
|
|
|
|
],
|
|
|
|
|
|
"id": "7359f89064a124d2",
|
|
|
|
|
|
"outputs": [],
|
|
|
|
|
|
"execution_count": 100
|
2025-02-12 00:21:33 +08:00
|
|
|
|
},
|
|
|
|
|
|
{
|
|
|
|
|
|
"cell_type": "code",
|
|
|
|
|
|
"id": "b427ce41-9739-4e9e-bea8-5f2551fec5d7",
|
2025-02-15 23:33:34 +08:00
|
|
|
|
"metadata": {
|
|
|
|
|
|
"ExecuteTime": {
|
|
|
|
|
|
"end_time": "2025-02-14T19:06:45.232334Z",
|
|
|
|
|
|
"start_time": "2025-02-14T19:06:45.218159Z"
|
|
|
|
|
|
}
|
|
|
|
|
|
},
|
|
|
|
|
|
"source": [],
|
2025-02-12 00:21:33 +08:00
|
|
|
|
"outputs": [],
|
2025-02-15 23:33:34 +08:00
|
|
|
|
"execution_count": null
|
2025-02-12 00:21:33 +08:00
|
|
|
|
}
|
|
|
|
|
|
],
|
|
|
|
|
|
"metadata": {
|
|
|
|
|
|
"kernelspec": {
|
|
|
|
|
|
"display_name": "Python 3 (ipykernel)",
|
|
|
|
|
|
"language": "python",
|
|
|
|
|
|
"name": "python3"
|
|
|
|
|
|
},
|
|
|
|
|
|
"language_info": {
|
|
|
|
|
|
"codemirror_mode": {
|
|
|
|
|
|
"name": "ipython",
|
|
|
|
|
|
"version": 3
|
|
|
|
|
|
},
|
|
|
|
|
|
"file_extension": ".py",
|
|
|
|
|
|
"mimetype": "text/x-python",
|
|
|
|
|
|
"name": "python",
|
|
|
|
|
|
"nbconvert_exporter": "python",
|
|
|
|
|
|
"pygments_lexer": "ipython3",
|
|
|
|
|
|
"version": "3.8.19"
|
|
|
|
|
|
}
|
|
|
|
|
|
},
|
|
|
|
|
|
"nbformat": 4,
|
|
|
|
|
|
"nbformat_minor": 5
|
|
|
|
|
|
}
|