Files
NewStock/main/train/DoubleQuntile.ipynb

968 lines
291 KiB
Plaintext
Raw Normal View History

2025-04-03 00:45:07 +08:00
{
"cells": [
{
"cell_type": "code",
"id": "79a7758178bafdd3",
"metadata": {
"jupyter": {
"source_hidden": true
},
"ExecuteTime": {
"end_time": "2025-02-16T08:24:31.295363Z",
"start_time": "2025-02-16T08:24:31.248141Z"
}
},
"source": [
"# Load IPython's autoreload extension in mode 2, so imported modules are re-imported before each cell runs.\n",
"%load_ext autoreload\n",
"%autoreload 2\n",
"\n"
],
"outputs": [],
"execution_count": 1
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-02-16T08:26:37.921393Z",
"start_time": "2025-02-16T08:26:37.822365Z"
}
},
"cell_type": "code",
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"\n",
"def calculate_indicators(df):\n",
"    \"\"\"Append daily-return, 5-row MA, 14-row RSI and MACD columns to ``df``.\n",
"\n",
"    ``df`` must provide ``close`` and ``pre_close``. It is modified in place\n",
"    and the same object is returned. All rolling / EWM windows follow the row\n",
"    order of ``df`` exactly as it is passed in.\n",
"    \"\"\"\n",
"    close = df['close']\n",
"    prev_close = df['pre_close']\n",
"\n",
"    # Percentage move of the close against the previous close.\n",
"    df['daily_return'] = (close - prev_close) / prev_close * 100\n",
"\n",
"    # Simple moving average over 5 rows.\n",
"    df['5_day_ma'] = close.rolling(window=5).mean()\n",
"\n",
"    # RSI built from plain 14-row rolling means of the up and down moves.\n",
"    change = close.diff()\n",
"    ups = change.where(change > 0, 0)\n",
"    downs = -change.where(change < 0, 0)\n",
"    strength = ups.rolling(window=14).mean() / downs.rolling(window=14).mean()\n",
"    df['RSI'] = 100 - (100 / (1 + strength))\n",
"\n",
"    # MACD line (EMA12 - EMA26), its EMA9 signal line and the histogram.\n",
"    fast = close.ewm(span=12, adjust=False).mean()\n",
"    slow = close.ewm(span=26, adjust=False).mean()\n",
"    macd = fast - slow\n",
"    signal = macd.ewm(span=9, adjust=False).mean()\n",
"    df['MACD'] = macd\n",
"    df['Signal_line'] = signal\n",
"    df['MACD_hist'] = macd - signal\n",
"\n",
"    return df\n",
"\n",
"def generate_index_indicators(h5_filename):\n",
"    \"\"\"Read index quotes from an HDF5 file and build one wide indicator row per day.\n",
"\n",
"    Args:\n",
"        h5_filename (str): path of the HDF5 file; the quotes are read from the\n",
"            'index_data' key and must contain ts_code, trade_date, close and\n",
"            pre_close.\n",
"\n",
"    Returns:\n",
"        DataFrame: one row per trade_date (converted to datetime) with one\n",
"        '<ts_code>_<indicator>' column for every index / indicator pair.\n",
"    \"\"\"\n",
"    df = pd.read_hdf(h5_filename, key='index_data')\n",
"\n",
"    per_index = []\n",
"    for ts_code in df['ts_code'].unique():\n",
"        # The rolling / EWM indicators depend on row order, so each index is\n",
"        # put into chronological order before they are computed.\n",
"        df_index = df[df['ts_code'] == ts_code].sort_values('trade_date').copy()\n",
"        per_index.append(calculate_indicators(df_index))\n",
"\n",
"    df_all_indicators = pd.concat(per_index, ignore_index=True)\n",
"\n",
"    # Collapse to one row per day with a (indicator, ts_code) column index.\n",
"    df_final = df_all_indicators.pivot_table(\n",
"        index='trade_date',\n",
"        columns='ts_code',\n",
"        values=['daily_return', '5_day_ma', 'RSI', 'MACD', 'Signal_line', 'MACD_hist'],\n",
"        aggfunc='last'\n",
"    )\n",
"\n",
"    # Flatten the column index to '<ts_code>_<indicator>' names.\n",
"    df_final.columns = [f\"{col[1]}_{col[0]}\" for col in df_final.columns]\n",
"    df_final = df_final.reset_index()\n",
"    # Convert the pivoted frame's own dates. Taking them from the raw long\n",
"    # frame would align on the integer index and attach unrelated dates.\n",
"    df_final['trade_date'] = pd.to_datetime(df_final['trade_date'], format='%Y%m%d')\n",
"\n",
"    return df_final\n",
"\n",
"\n",
"# Build the index indicator table and drop every day that has a missing value in any column.\n",
"h5_filename = '../../data/index_data.h5'\n",
"index_data = generate_index_indicators(h5_filename).dropna()\n",
"# Preview the first rows.\n",
"print(index_data.head())\n"
],
"id": "b216cc5529d07cad",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" trade_date 000852.SH_5_day_ma 000905.SH_5_day_ma 399006.SZ_5_day_ma \\\n",
"1314 2019-09-06 4108.5648 4064.9986 1013.1790 \n",
"1315 2019-09-05 4132.3992 4081.7042 1031.9632 \n",
"1316 2019-09-04 4164.2720 4107.3572 1048.1040 \n",
"1317 2019-09-03 4204.0682 4141.2734 1071.9208 \n",
"1318 2019-09-02 4232.8298 4163.0654 1090.8244 \n",
"\n",
" 000852.SH_MACD 000905.SH_MACD 399006.SZ_MACD 000852.SH_MACD_hist \\\n",
"1314 20.380810 34.702099 15.173396 -2.837620 \n",
"1315 28.830595 41.792560 21.448935 4.902761 \n",
"1316 32.263643 44.346834 26.714792 9.561498 \n",
"1317 39.294963 49.978947 32.917542 18.983194 \n",
"1318 44.192628 52.871877 37.373758 28.626657 \n",
"\n",
" 000905.SH_MACD_hist 399006.SZ_MACD_hist 000852.SH_RSI 000905.SH_RSI \\\n",
"1314 1.270353 -9.734351 39.526629 45.136516 \n",
"1315 8.678401 -5.892400 49.265364 53.813061 \n",
"1316 13.402276 -2.099643 53.026120 56.084097 \n",
"1317 22.384958 3.578197 66.615288 68.521794 \n",
"1318 30.874127 8.928962 68.827733 70.274913 \n",
"\n",
" 399006.SZ_RSI 000852.SH_Signal_line 000905.SH_Signal_line \\\n",
"1314 38.536512 23.218429 33.431746 \n",
"1315 47.245649 23.927834 33.114158 \n",
"1316 51.408682 22.702144 30.944558 \n",
"1317 63.606060 20.311770 27.593989 \n",
"1318 69.525065 15.565971 21.997750 \n",
"\n",
" 399006.SZ_Signal_line 000852.SH_daily_return 000905.SH_daily_return \\\n",
"1314 24.907747 -2.133006 -1.914426 \n",
"1315 27.341335 1.676448 1.436414 \n",
"1316 28.814435 -0.741755 -0.725607 \n",
"1317 29.339346 0.839908 0.864581 \n",
"1318 28.444796 -0.331791 -0.402711 \n",
"\n",
" 399006.SZ_daily_return \n",
"1314 -2.676700 \n",
"1315 2.454294 \n",
"1316 0.127868 \n",
"1317 2.933411 \n",
"1318 4.066145 \n"
]
}
],
"execution_count": 5
},
{
"cell_type": "code",
"id": "a79cafb06a7e0e43",
"metadata": {
"ExecuteTime": {
"end_time": "2025-02-16T08:25:13.872207Z",
"start_time": "2025-02-16T08:24:31.906361Z"
}
},
"source": [
2025-04-28 11:02:52 +08:00
"from code.utils.utils import read_and_merge_h5_data\n",
2025-04-03 00:45:07 +08:00
"\n",
"# Daily bars (open/close/high/low/vol). df=None is passed for this first source - presumably the helper then starts a new frame (helper not visible here, TODO confirm).\n",
"print('daily data')\n",
"df = read_and_merge_h5_data('../../data/daily_data.h5', key='daily_data',\n",
" columns=['ts_code', 'trade_date', 'open', 'close', 'high', 'low', 'vol'],\n",
" df=None)\n",
"\n",
"# Daily fundamentals plus the is_st flag; this is the only source merged with join='inner'.\n",
"print('daily basic')\n",
"df = read_and_merge_h5_data('../../data/daily_basic.h5', key='daily_basic',\n",
" columns=['ts_code', 'trade_date', 'turnover_rate', 'pe_ttm', 'circ_mv', 'volume_ratio',\n",
" 'is_st'], df=df, join='inner')\n",
"\n",
"# Previous close and the up_limit / down_limit prices.\n",
"print('stk limit')\n",
"df = read_and_merge_h5_data('../../data/stk_limit.h5', key='stk_limit',\n",
" columns=['ts_code', 'trade_date', 'pre_close', 'up_limit', 'down_limit'],\n",
" df=df)\n",
"# Money-flow volumes: buy/sell columns for the sm, lg and elg buckets plus net_mf_vol.\n",
"print('money flow')\n",
"df = read_and_merge_h5_data('../../data/money_flow.h5', key='money_flow',\n",
" columns=['ts_code', 'trade_date', 'buy_sm_vol', 'sell_sm_vol', 'buy_lg_vol', 'sell_lg_vol',\n",
" 'buy_elg_vol', 'sell_elg_vol', 'net_mf_vol'],\n",
" df=df)"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"daily data\n",
"daily basic\n",
"stk limit\n",
"money flow\n"
]
}
],
"execution_count": 3
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-02-16T08:26:43.228123Z",
"start_time": "2025-02-16T08:26:41.895312Z"
}
},
"cell_type": "code",
"source": "# Left-join the per-day index indicators onto every stock row by trade_date; dates missing from index_data leave NaN in those columns.\ndf = df.merge(index_data, on='trade_date', how='left')",
"id": "7357147395bda969",
"outputs": [],
"execution_count": 6
},
{
"cell_type": "code",
"id": "c4e9e1d31da6dba6",
"metadata": {
"ExecuteTime": {
"end_time": "2025-02-16T08:26:43.275563Z",
"start_time": "2025-02-16T08:26:43.228123Z"
}
},
"source": [
"# Record the columns that exist before feature engineering, leaving turnover_rate, pe_ttm and volume_ratio out of the record.\n",
"origin_columns = [col for col in df.columns.tolist()\n",
"                  if col not in ['turnover_rate', 'pe_ttm', 'volume_ratio']]"
],
"outputs": [],
"execution_count": 7
},
{
"cell_type": "code",
"id": "a735bc02ceb4d872",
"metadata": {
"jupyter": {
"source_hidden": true
},
"ExecuteTime": {
"end_time": "2025-02-16T08:26:43.388955Z",
"start_time": "2025-02-16T08:26:43.310081Z"
}
},
"source": [
"import numpy as np\n",
"import talib\n",
"\n",
"def get_technical_factor(df):\n",
"    \"\"\"Add per-stock technical indicator columns and return the sorted frame.\n",
"\n",
"    The input is sorted by (ts_code, trade_date) into a new frame and the new\n",
"    columns are written to that frame. Columns read: open, close, high, low,\n",
"    vol. Columns added, in order: up, down, atr_14, atr_6, obv, maobv_6,\n",
"    obv-maobv_6, rsi_3, rsi_6, rsi_9, return_10, return_20, avg_close_5,\n",
"    std_return_5/15/25/90, std_return_90_2, two volatility ratios and the\n",
"    std_return_90 - std_return_90_2 difference.\n",
"    \"\"\"\n",
"    df = df.sort_values(by=['ts_code', 'trade_date'])\n",
"    grouped = df.groupby('ts_code', group_keys=False)\n",
"\n",
"    def per_stock(compute):\n",
"        # Run ``compute`` on each stock's sub-frame and re-attach that\n",
"        # sub-frame's row index so the result aligns with ``df``.\n",
"        return grouped.apply(lambda g: pd.Series(compute(g), index=g.index))\n",
"\n",
"    # Upper / lower shadow lengths relative to the close.\n",
"    body_top = df[['close', 'open']].max(axis=1)\n",
"    df['up'] = (df['high'] - body_top) / df['close']\n",
"    body_bottom = df[['close', 'open']].min(axis=1)\n",
"    df['down'] = (body_bottom - df['low']) / df['close']\n",
"\n",
"    # talib ATR with 14- and 6-bar periods.\n",
"    for period in (14, 6):\n",
"        df[f'atr_{period}'] = per_stock(\n",
"            lambda g: talib.ATR(g['high'].values, g['low'].values, g['close'].values, timeperiod=period)\n",
"        )\n",
"\n",
"    # talib OBV, its 6-bar SMA and the gap between the two.\n",
"    # NOTE(review): ``grouped`` was created before the 'obv' column existed, yet\n",
"    # the SMA step reads g['obv'] through it; this relies on the groupby seeing\n",
"    # columns added to ``df`` afterwards - confirm on the pandas version in use.\n",
"    df['obv'] = per_stock(lambda g: talib.OBV(g['close'].values, g['vol'].values))\n",
"    df['maobv_6'] = per_stock(lambda g: talib.SMA(g['obv'].values, timeperiod=6))\n",
"    df['obv-maobv_6'] = df['obv'] - df['maobv_6']\n",
"\n",
"    # talib RSI with 3-, 6- and 9-bar periods.\n",
"    for period in (3, 6, 9):\n",
"        df[f'rsi_{period}'] = per_stock(lambda g: talib.RSI(g['close'].values, timeperiod=period))\n",
"\n",
"    # Trailing 10- and 20-row close-to-close returns.\n",
"    for lag in (10, 20):\n",
"        df[f'return_{lag}'] = grouped['close'].apply(lambda s: s / s.shift(lag) - 1)\n",
"\n",
"    # 5-row mean close divided by the current close.\n",
"    df['avg_close_5'] = grouped['close'].apply(lambda s: s.rolling(window=5).mean() / s)\n",
"\n",
"    # Rolling standard deviation of one-row returns over several windows.\n",
"    for span in (5, 15, 25, 90):\n",
"        df[f'std_return_{span}'] = grouped['close'].apply(lambda s: s.pct_change().rolling(window=span).std())\n",
"    # The same 90-row statistic computed on closes shifted by 10 rows.\n",
"    df['std_return_90_2'] = grouped['close'].apply(lambda s: s.shift(10).pct_change().rolling(window=90).std())\n",
"\n",
"    # Short-window volatility relative to longer windows.\n",
"    df['std_return_5 / std_return_90'] = df['std_return_5'] / df['std_return_90']\n",
"    df['std_return_5 / std_return_25'] = df['std_return_5'] / df['std_return_25']\n",
"\n",
"    # Change of the 90-row volatility versus its 10-row-shifted version.\n",
"    df['std_return_90 - std_return_90_2'] = df['std_return_90'] - df['std_return_90_2']\n",
"\n",
"    return df\n",
"\n",
"\n",
"def get_act_factor(df):\n",
"    \"\"\"Add EMA-slope based 'act' factors, their comparison flags and daily ranks.\n",
"\n",
"    Works on a copy sorted by (ts_code, trade_date) and returns it. Columns\n",
"    added, in order: ema_5/13/20/60, act_factor1-4, cat_af1-4, act_factor5,\n",
"    act_factor6, rank_act_factor1-3.\n",
"    \"\"\"\n",
"    df = df.sort_values(by=['ts_code', 'trade_date'])\n",
"    grouped = df.groupby('ts_code', group_keys=False)\n",
"\n",
"    # talib EMA of the close per stock.\n",
"    for span in (5, 13, 20, 60):\n",
"        df[f'ema_{span}'] = grouped['close'].apply(\n",
"            lambda s: pd.Series(talib.EMA(s.values, timeperiod=span), index=s.index)\n",
"        )\n",
"\n",
"    # act_factorN = arctan(one-row EMA change in percent) * 57.3 / divisor.\n",
"    # 57.3 is presumably the radians-to-degrees factor - TODO confirm.\n",
"    for number, span, divisor in ((1, 5, 50), (2, 13, 40), (3, 20, 21), (4, 60, 10)):\n",
"        df[f'act_factor{number}'] = grouped[f'ema_{span}'].apply(\n",
"            lambda s: np.arctan((s / s.shift(1) - 1) * 100) * 57.3 / divisor\n",
"        )\n",
"\n",
"    # Boolean flags: factor 1 positive, and each later factor above the previous one.\n",
"    df['cat_af1'] = df['act_factor1'] > 0\n",
"    for number in (2, 3, 4):\n",
"        df[f'cat_af{number}'] = df[f'act_factor{number}'] > df[f'act_factor{number - 1}']\n",
"\n",
"    # Sum of the four factors, and the factor1/factor2 gap scaled by their Euclidean norm.\n",
"    df['act_factor5'] = df['act_factor1'] + df['act_factor2'] + df['act_factor3'] + df['act_factor4']\n",
"    df['act_factor6'] = (df['act_factor1'] - df['act_factor2']) / np.sqrt(df['act_factor1']**2 + df['act_factor2']**2)\n",
"\n",
"    # Cross-sectional percentile rank per trade_date (largest value ranks first).\n",
"    for number in (1, 2, 3):\n",
"        df[f'rank_act_factor{number}'] = df.groupby('trade_date', group_keys=False)[f'act_factor{number}'].rank(ascending=False, pct=True)\n",
"\n",
"    return df\n",
"\n",
"\n",
"def get_money_flow_factor(df):\n",
"    \"\"\"Add money-flow ratio columns and log(circ_mv) to ``df`` in place.\n",
"\n",
"    Every ratio uses ``net_mf_vol`` as the denominator. The same frame object\n",
"    is returned. The source comment points to the tushare data documentation\n",
"    for the meaning of the input fields.\n",
"    \"\"\"\n",
"    net_flow = df['net_mf_vol']\n",
"\n",
"    # Buy volume of each order-size bucket relative to the net flow.\n",
"    df['active_buy_volume_large'] = df['buy_lg_vol'] / net_flow\n",
"    df['active_buy_volume_big'] = df['buy_elg_vol'] / net_flow\n",
"    df['active_buy_volume_small'] = df['buy_sm_vol'] / net_flow\n",
"\n",
"    # Buy-minus-sell imbalance of the lg and elg buckets relative to the net flow.\n",
"    lg_imbalance = df['buy_lg_vol'] - df['sell_lg_vol']\n",
"    df['buy_lg_vol_minus_sell_lg_vol'] = lg_imbalance / net_flow\n",
"    elg_imbalance = df['buy_elg_vol'] - df['sell_elg_vol']\n",
"    df['buy_elg_vol_minus_sell_elg_vol'] = elg_imbalance / net_flow\n",
"\n",
"    df['log(circ_mv)'] = np.log(df['circ_mv'])\n",
"    return df\n",
"\n",
"\n",
"def get_alpha_factor(df):\n",
"    \"\"\"Add the alpha_022, alpha_003, alpha_007 and alpha_013 columns.\n",
"\n",
"    Works on a copy sorted by (ts_code, trade_date) and returns it. alpha_007\n",
"    and alpha_013 are stored as ascending percentile ranks within each\n",
"    trade_date.\n",
"    \"\"\"\n",
"    df = df.sort_values(by=['ts_code', 'trade_date'])\n",
"    grouped = df.groupby('ts_code')\n",
"\n",
"    def rank_within_day(column):\n",
"        # Ascending percentile rank of ``column`` among all rows of the same trade_date.\n",
"        return df.groupby('trade_date', group_keys=False)[column].rank(ascending=True, pct=True)\n",
"\n",
"    # alpha_022: close minus the close 5 rows earlier, per stock.\n",
"    df['alpha_022'] = grouped['close'].transform(lambda s: s - s.shift(5))\n",
"\n",
"    # alpha_003: (close - open) / (high - low), set to 0 where high equals low.\n",
"    bar_range = df['high'] - df['low']\n",
"    df['alpha_003'] = np.where(df['high'] != df['low'], (df['close'] - df['open']) / bar_range, 0)\n",
"\n",
"    # alpha_007: 5-row rolling correlation of close and vol per stock, then ranked per day.\n",
"    close_vol_corr = grouped.apply(lambda g: g['close'].rolling(5).corr(g['vol']))\n",
"    df['alpha_007'] = close_vol_corr.reset_index(level=0, drop=True)\n",
"    df['alpha_007'] = rank_within_day('alpha_007')\n",
"\n",
"    # alpha_013: 5-row close sum minus 20-row close sum per stock, then ranked per day.\n",
"    df['alpha_013'] = grouped['close'].transform(lambda s: s.rolling(5).sum() - s.rolling(20).sum())\n",
"    df['alpha_013'] = rank_within_day('alpha_013')\n",
"\n",
"    return df\n",
"\n"
],
"outputs": [],
"execution_count": 8
},
{
"cell_type": "code",
"id": "dbe2fd8021b9417f",
"metadata": {
"jupyter": {
"source_hidden": true
},
"scrolled": true,
"ExecuteTime": {
"end_time": "2025-02-16T08:27:52.356590Z",
"start_time": "2025-02-16T08:26:43.420269Z"
}
},
"source": [
"def filter_data(df):\n",
"    \"\"\"Return the rows that pass the stock-universe filter, with a fresh 0..n-1 index.\n",
"\n",
"    Dropped rows: ``is_st`` is True, ``ts_code`` ends with 'BJ', or ``ts_code``\n",
"    starts with '30', '68' or '8' (presumably board prefixes - TODO confirm).\n",
"    \"\"\"\n",
"    code = df['ts_code'].str\n",
"    keep = (\n",
"        ~df['is_st']\n",
"        & ~code.endswith('BJ')\n",
"        & ~code.startswith('30')\n",
"        & ~code.startswith('68')\n",
"        & ~code.startswith('8')\n",
"    )\n",
"    return df[keep].reset_index(drop=True)\n",
"\n",
"\n",
"# Filter the stock universe, then add the factor families one after another.\n",
"df = filter_data(df)\n",
"df = get_technical_factor(df)\n",
"df = get_act_factor(df)\n",
"df = get_money_flow_factor(df)\n",
"df = get_alpha_factor(df)\n",
"# DataFrame.info() prints the summary itself and returns None, so it is called directly; wrapping it in print() added a stray 'None' line.\n",
"df.info()"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<class 'pandas.core.frame.DataFrame'>\n",
"Index: 5453316 entries, 1962 to 5453315\n",
"Data columns (total 87 columns):\n",
" # Column Dtype \n",
"--- ------ ----- \n",
" 0 ts_code object \n",
" 1 trade_date datetime64[ns]\n",
" 2 open float64 \n",
" 3 close float64 \n",
" 4 high float64 \n",
" 5 low float64 \n",
" 6 vol float64 \n",
" 7 turnover_rate float64 \n",
" 8 pe_ttm float64 \n",
" 9 circ_mv float64 \n",
" 10 volume_ratio float64 \n",
" 11 is_st bool \n",
" 12 up_limit float64 \n",
" 13 down_limit float64 \n",
" 14 buy_sm_vol float64 \n",
" 15 sell_sm_vol float64 \n",
" 16 buy_lg_vol float64 \n",
" 17 sell_lg_vol float64 \n",
" 18 buy_elg_vol float64 \n",
" 19 sell_elg_vol float64 \n",
" 20 net_mf_vol float64 \n",
" 21 000852.SH_5_day_ma float64 \n",
" 22 000905.SH_5_day_ma float64 \n",
" 23 399006.SZ_5_day_ma float64 \n",
" 24 000852.SH_MACD float64 \n",
" 25 000905.SH_MACD float64 \n",
" 26 399006.SZ_MACD float64 \n",
" 27 000852.SH_MACD_hist float64 \n",
" 28 000905.SH_MACD_hist float64 \n",
" 29 399006.SZ_MACD_hist float64 \n",
" 30 000852.SH_RSI float64 \n",
" 31 000905.SH_RSI float64 \n",
" 32 399006.SZ_RSI float64 \n",
" 33 000852.SH_Signal_line float64 \n",
" 34 000905.SH_Signal_line float64 \n",
" 35 399006.SZ_Signal_line float64 \n",
" 36 000852.SH_daily_return float64 \n",
" 37 000905.SH_daily_return float64 \n",
" 38 399006.SZ_daily_return float64 \n",
" 39 up float64 \n",
" 40 down float64 \n",
" 41 atr_14 float64 \n",
" 42 atr_6 float64 \n",
" 43 obv float64 \n",
" 44 maobv_6 float64 \n",
" 45 obv-maobv_6 float64 \n",
" 46 rsi_3 float64 \n",
" 47 rsi_6 float64 \n",
" 48 rsi_9 float64 \n",
" 49 return_10 float64 \n",
" 50 return_20 float64 \n",
" 51 avg_close_5 float64 \n",
" 52 std_return_5 float64 \n",
" 53 std_return_15 float64 \n",
" 54 std_return_25 float64 \n",
" 55 std_return_90 float64 \n",
" 56 std_return_90_2 float64 \n",
" 57 std_return_5 / std_return_90 float64 \n",
" 58 std_return_5 / std_return_25 float64 \n",
" 59 std_return_90 - std_return_90_2 float64 \n",
" 60 ema_5 float64 \n",
" 61 ema_13 float64 \n",
" 62 ema_20 float64 \n",
" 63 ema_60 float64 \n",
" 64 act_factor1 float64 \n",
" 65 act_factor2 float64 \n",
" 66 act_factor3 float64 \n",
" 67 act_factor4 float64 \n",
" 68 cat_af1 bool \n",
" 69 cat_af2 bool \n",
" 70 cat_af3 bool \n",
" 71 cat_af4 bool \n",
" 72 act_factor5 float64 \n",
" 73 act_factor6 float64 \n",
" 74 rank_act_factor1 float64 \n",
" 75 rank_act_factor2 float64 \n",
" 76 rank_act_factor3 float64 \n",
" 77 active_buy_volume_large float64 \n",
" 78 active_buy_volume_big float64 \n",
" 79 active_buy_volume_small float64 \n",
" 80 buy_lg_vol_minus_sell_lg_vol float64 \n",
" 81 buy_elg_vol_minus_sell_elg_vol float64 \n",
" 82 log(circ_mv) float64 \n",
" 83 alpha_022 float64 \n",
" 84 alpha_003 float64 \n",
" 85 alpha_007 float64 \n",
" 86 alpha_013 float64 \n",
"dtypes: bool(5), datetime64[ns](1), float64(80), object(1)\n",
"memory usage: 3.4+ GB\n",
"None\n"
]
}
],
"execution_count": 9
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-02-16T08:27:52.874223Z",
"start_time": "2025-02-16T08:27:52.717759Z"
}
},
"cell_type": "code",
"source": [
"# Model inputs: every column that is not an identifier or the label, whose name contains\n",
"# neither 'future' nor 'score', and that was not recorded in origin_columns.\n",
"feature_columns = [\n",
"    col for col in df.columns\n",
"    if col not in ['trade_date',\n",
"                   'ts_code',\n",
"                   'label']\n",
"    and 'future' not in col\n",
"    and 'score' not in col\n",
"    and col not in origin_columns\n",
"]\n",
"print(feature_columns)"
],
"id": "2132103543a77819",
"outputs": [],
"execution_count": 10
},
{
"cell_type": "code",
"id": "5f3d9aece75318cd",
"metadata": {
"ExecuteTime": {
"end_time": "2025-02-16T08:42:38.569134Z",
"start_time": "2025-02-16T08:42:37.956365Z"
}
},
"source": [
"def get_qcuts(series, quantiles):\n",
"    \"\"\"Return the quantile-bucket number of the last element of ``series``.\n",
"\n",
"    Args:\n",
"        series: 1-D values (pandas Series or array-like) to bucket.\n",
"        quantiles: number of quantile buckets, or a list of quantile edges,\n",
"            passed to ``pd.qcut`` as ``q``.\n",
"\n",
"    Returns:\n",
"        The integer bucket of the last element (NaN if that element is NaN).\n",
"    \"\"\"\n",
"    buckets = pd.qcut(series, q=quantiles, labels=False, duplicates='drop')\n",
"    # Pick the last element by position. For Series input qcut returns a\n",
"    # Series, where [-1] is a label lookup and fails on an integer index.\n",
"    if isinstance(buckets, pd.Series):\n",
"        return buckets.iloc[-1]\n",
"    return buckets[-1]\n",
"\n",
"\n",
"# Window length and quantile-bucket count. Neither is referenced in the visible cells - presumably meant for get_qcuts (TODO confirm).\n",
"window = 5\n",
"quantiles = 20\n",
"\n",
"def calculate_risk_adjusted_target(df, days=5):\n",
"    \"\"\"Return a Sharpe-style target: forward return divided by its rolling volatility.\n",
"\n",
"    Args:\n",
"        df: frame with ``ts_code``, ``trade_date`` and ``close``.\n",
"        days: forward horizon in rows; also the rolling window used for the\n",
"            volatility of the forward returns.\n",
"\n",
"    Returns:\n",
"        Series named 'sharpe_ratio', indexed like the rows of ``df``. Infinite\n",
"        ratios (zero volatility) are replaced by NaN. The caller's frame is\n",
"        not modified; helper columns live on an internal sorted copy.\n",
"    \"\"\"\n",
"    df = df.sort_values(by=['ts_code', 'trade_date'])\n",
"\n",
"    # Close ``days`` rows ahead within each stock, and the return up to it.\n",
"    df['future_close'] = df.groupby('ts_code')['close'].shift(-days)\n",
"    df['future_return'] = (df['future_close'] - df['close']) / df['close']\n",
"\n",
"    # Rolling standard deviation of those forward returns, per stock.\n",
"    df['future_volatility'] = df.groupby('ts_code')['future_return'].rolling(days, min_periods=1).std().reset_index(level=0, drop=True)\n",
"\n",
"    # A Sharpe-style ratio divides return by volatility (the previous code\n",
"    # multiplied them). The cleaned result is assigned back explicitly, because\n",
"    # an inplace replace on a selected column is not guaranteed to update df.\n",
"    ratio = df['future_return'] / df['future_volatility']\n",
"    df['sharpe_ratio'] = ratio.replace([np.inf, -np.inf], np.nan)\n",
"\n",
"    return df['sharpe_ratio']\n",
"\n",
"def get_label(df):\n",
"    \"\"\"Return the column used as the training label: ``df['future_close5']``.\n",
"\n",
"    NOTE(review): 'future_close5' is not created by any factor function in the\n",
"    visible part of the notebook - confirm where it is produced.\n",
"    Earlier label experiments (act-factor deltas, next-row EMA gap, a clipped\n",
"    risk-adjusted target) used to be kept here as commented-out alternatives.\n",
"    \"\"\"\n",
"    return df['future_close5']\n",
"\n",
"df['label'] = get_label(df)\n",
"\n",
"# Train strictly before the split date and test from it onwards, so no trading day can land in both sets.\n",
"split_date = '2023-01-01'\n",
"train_data = df[df['trade_date'] < split_date]\n",
"test_data = df[df['trade_date'] >= split_date]\n",
"\n",
"# For every trading day keep the 1000 rows with the largest return_20.\n",
"train_data = train_data.groupby('trade_date', group_keys=False).apply(lambda x: x.nlargest(1000, 'return_20'))\n",
"test_data = test_data.groupby('trade_date', group_keys=False).apply(lambda x: x.nlargest(1000, 'return_20'))\n",
"\n",
"# NOTE(review): get_label reads 'future_close5'; a commented-out get_future_data(train_data) call used to sit here - confirm where that column is produced.\n",
"\n",
"# Turn infinities into NaN, then drop rows that miss any feature value.\n",
"train_data, test_data = train_data.replace([np.inf, -np.inf], np.nan), test_data.replace([np.inf, -np.inf], np.nan)\n",
"train_data = train_data.dropna(subset=feature_columns)\n",
"train_data = train_data.reset_index(drop=True)\n",
"train_data['label'] = get_label(train_data)\n",
"\n",
"test_data = test_data.dropna(subset=feature_columns)\n",
"test_data = test_data.reset_index(drop=True)\n",
"test_data['label'] = get_label(test_data)\n",
"\n",
"# Report the size and date range of both sets.\n",
"print(len(train_data))\n",
"print(f\"最小日期: {train_data['trade_date'].min().strftime('%Y-%m-%d')}\")\n",
"print(f\"最大日期: {train_data['trade_date'].max().strftime('%Y-%m-%d')}\")\n",
"print(len(test_data))\n",
"print(f\"最小日期: {test_data['trade_date'].min().strftime('%Y-%m-%d')}\")\n",
"print(f\"最大日期: {test_data['trade_date'].max().strftime('%Y-%m-%d')}\")\n"
],
"outputs": [
{
"ename": "NameError",
"evalue": "name 'df' is not defined",
"output_type": "error",
"traceback": [
"\u001B[1;31m---------------------------------------------------------------------------\u001B[0m",
"\u001B[1;31mNameError\u001B[0m Traceback (most recent call last)",
"Cell \u001B[1;32mIn[1], line 34\u001B[0m\n\u001B[0;32m 24\u001B[0m \u001B[38;5;66;03m# labels = df['future_af11']\u001B[39;00m\n\u001B[0;32m 25\u001B[0m \u001B[38;5;66;03m# labels = df['ema_5'].shift(-1) - df['close']\u001B[39;00m\n\u001B[0;32m 26\u001B[0m \u001B[38;5;66;03m# df['label'] = (df['future_af11'] - df['act_factor1']) / df['act_factor1']\u001B[39;00m\n\u001B[1;32m (...)\u001B[0m\n\u001B[0;32m 30\u001B[0m \u001B[38;5;66;03m# labels = df['label'].clip(lower=lower_percentile, upper=upper_percentile)\u001B[39;00m\n\u001B[0;32m 31\u001B[0m \u001B[38;5;66;03m# labels = calculate_risk_adjusted_return(df, days=3, history_days=3, method='ratio')\u001B[39;00m\n\u001B[0;32m 32\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m labels\n\u001B[1;32m---> 34\u001B[0m df[\u001B[38;5;124m'\u001B[39m\u001B[38;5;124mlabel\u001B[39m\u001B[38;5;124m'\u001B[39m] \u001B[38;5;241m=\u001B[39m get_label(\u001B[43mdf\u001B[49m)\n\u001B[0;32m 35\u001B[0m train_data \u001B[38;5;241m=\u001B[39m df[df[\u001B[38;5;124m'\u001B[39m\u001B[38;5;124mtrade_date\u001B[39m\u001B[38;5;124m'\u001B[39m] \u001B[38;5;241m<\u001B[39m\u001B[38;5;241m=\u001B[39m \u001B[38;5;124m'\u001B[39m\u001B[38;5;124m2023-01-01\u001B[39m\u001B[38;5;124m'\u001B[39m]\n\u001B[0;32m 36\u001B[0m test_data \u001B[38;5;241m=\u001B[39m df[df[\u001B[38;5;124m'\u001B[39m\u001B[38;5;124mtrade_date\u001B[39m\u001B[38;5;124m'\u001B[39m] \u001B[38;5;241m>\u001B[39m\u001B[38;5;241m=\u001B[39m \u001B[38;5;124m'\u001B[39m\u001B[38;5;124m2023-01-01\u001B[39m\u001B[38;5;124m'\u001B[39m]\n",
"\u001B[1;31mNameError\u001B[0m: name 'df' is not defined"
]
}
],
"execution_count": 1
},
{
"cell_type": "code",
"id": "8f134d435f71e9e2",
"metadata": {
"jupyter": {
"source_hidden": true
},
"ExecuteTime": {
"end_time": "2025-02-16T08:42:38.569134800Z",
"start_time": "2025-02-16T08:28:25.333768Z"
}
},
"source": [
"import lightgbm as lgb\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import optuna\n",
"from sklearn.model_selection import KFold\n",
"from sklearn.metrics import mean_absolute_error\n",
"import os\n",
"import json\n",
"import pickle\n",
"import hashlib\n",
"\n",
"def train_light_model(train_data_df, test_data_df, params, feature_columns, callbacks, evals,\n",
"                      print_feature_importance=True, num_boost_round=100,\n",
"                      use_optuna=False):\n",
"    \"\"\"Train a LightGBM booster on ``train_data_df``, validating on ``test_data_df``.\n",
"\n",
"    Both frames must hold ``feature_columns`` and a ``label`` column. Feature\n",
"    columns whose name contains 'cat' are declared categorical. ``evals`` is\n",
"    the dict plotted by ``lgb.plot_metric``; the caller is expected to fill it\n",
"    through a record-evaluation callback in ``callbacks``. ``use_optuna`` is\n",
"    accepted but not used. Returns the trained booster.\n",
"    \"\"\"\n",
"    cat_columns = [name for name in feature_columns if 'cat' in name]\n",
"\n",
"    def as_dataset(frame):\n",
"        # Wrap one frame's features and labels in a LightGBM Dataset.\n",
"        return lgb.Dataset(frame[feature_columns], label=frame['label'], categorical_feature=cat_columns)\n",
"\n",
"    train_set = as_dataset(train_data_df)\n",
"    valid_set = as_dataset(test_data_df)\n",
"\n",
"    booster = lgb.train(\n",
"        params, train_set, num_boost_round=num_boost_round,\n",
"        valid_sets=[train_set, valid_set], valid_names=['train', 'valid'],\n",
"        callbacks=callbacks\n",
"    )\n",
"\n",
"    if not print_feature_importance:\n",
"        return booster\n",
"\n",
"    # Metric curves recorded in ``evals`` plus the 20 most-split features.\n",
"    lgb.plot_metric(evals)\n",
"    lgb.plot_importance(booster, importance_type='split', max_num_features=20)\n",
"    plt.show()\n",
"    return booster\n",
"\n",
"\n",
"from catboost import CatBoostRegressor\n",
"import pandas as pd\n",
"\n",
"\n",
"def train_catboost(df, feature_columns, params=None):\n",
"    \"\"\"Fit a CatBoostRegressor, validating on the most recent 10% of trading days.\n",
"\n",
"    Args:\n",
"        df: frame holding ``trade_date``, ``label`` and every column named in\n",
"            ``feature_columns``.\n",
"        feature_columns: names of the model input columns.\n",
"        params: keyword arguments for ``CatBoostRegressor``; ``None`` means\n",
"            the library defaults.\n",
"\n",
"    Returns:\n",
"        The fitted ``CatBoostRegressor``.\n",
"\n",
"    Raises:\n",
"        ValueError: if ``df`` has fewer than two distinct trade dates, since\n",
"            no date-based train/validation split is possible then.\n",
"    \"\"\"\n",
"    trading_days = df['trade_date'].drop_duplicates().sort_values()\n",
"    if len(trading_days) < 2:\n",
"        raise ValueError('train_catboost needs at least two distinct trade_date values')\n",
"\n",
"    # Hold out the latest 10% of days, but never fewer than one: a count of\n",
"    # zero would turn the [-0:] slice into a selection of every day.\n",
"    val_day_count = max(1, int(len(trading_days) * 0.1))\n",
"    is_val = df['trade_date'].isin(trading_days.iloc[-val_day_count:])\n",
"\n",
"    # Split with boolean masks so rows are chosen by their date. Feeding index\n",
"    # labels to positional .iloc picks wrong rows once the frame is re-sorted.\n",
"    ordering = {'by': ['trade_date', 'label'], 'ascending': [True, False]}\n",
"    train_df = df[~is_val].sort_values(**ordering)\n",
"    val_df = df[is_val].sort_values(**ordering)\n",
"\n",
"    X_train = train_df[feature_columns]\n",
"    y_train = train_df['label']\n",
"\n",
"    X_val = val_df[feature_columns]\n",
"    y_val = val_df['label']\n",
"\n",
"    model = CatBoostRegressor(**(params or {}))\n",
"    model.fit(X_train,\n",
"              y_train,\n",
"              eval_set=(X_val, y_val))\n",
"\n",
"    return model"
],
"outputs": [],
"execution_count": 12
},
{
"cell_type": "code",
"id": "beeb098799ecfa6a",
"metadata": {
"ExecuteTime": {
"end_time": "2025-02-16T08:42:38.569134800Z",
"start_time": "2025-02-16T08:29:41.943281Z"
}
},
"source": [
"print('train data size: ', len(train_data))\n",
"\n",
"# LightGBM parameters for quantile regression.\n",
"light_params9 = {\n",
" # 'objective': 'regression',\n",
" # 'metric': 'l2',\n",
" 'objective': 'quantile', # quantile regression\n",
" 'metric': 'quantile', # evaluate with the quantile metric as well\n",
" 'alpha': 0.75, # quantile level 0.75, i.e. the 75th percentile (the old comment said 90%)\n",
" 'learning_rate': 0.05,\n",
" # NOTE(review): the LightGBM docs list is_unbalance for binary / multiclassova objectives - confirm it has any effect with 'quantile'.\n",
" 'is_unbalance': True,\n",
" 'num_leaves': 2048,\n",
" 'min_data_in_leaf': 256,\n",
" 'max_depth': 32,\n",
" 'max_bin': 1024,\n",
" # 'feature_fraction': 0.7,\n",
" # 'bagging_fraction': 0.7,\n",
" # 'bagging_freq': 5,\n",
" # 'lambda_l1': 10,\n",
" # 'lambda_l2': 10,\n",
" 'verbosity': -1,\n",
" # 'device': 'gpu'\n",
"}\n",
"# Filled by the record_evaluation callback below and plotted inside train_light_model.\n",
"evals = {}\n",
"# NOTE(review): test_data is passed as the validation set, so early stopping is tuned on the test period.\n",
"light_model9 = train_light_model(train_data, test_data, light_params9, feature_columns,\n",
" [lgb.log_evaluation(period=500),\n",
" lgb.callback.record_evaluation(evals),\n",
" lgb.early_stopping(50, first_metric_only=True)\n",
" ], evals,\n",
" num_boost_round=10000, use_optuna=False,\n",
" print_feature_importance=True)"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"train data size: 1170833\n",
"Training until validation scores don't improve for 50 rounds\n",
"Early stopping, best iteration is:\n",
"[147]\ttrain's quantile: 0.00152696\tvalid's quantile: 0.00164099\n",
"Evaluated only: quantile\n"
]
},
{
"data": {
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
],
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAlIAAAHFCAYAAAA5VBcVAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8pXeV/AAAACXBIWXMAAA9hAAAPYQGoP6dpAABy9ElEQVR4nO3dd3hUVf7H8fdkMukhIT2BQEInlAAJUpQuIAgCimBDUGR/iK4ClpVV17qiu66iq4iuIOq6gorYQCFIl0jvvQRCSSGhpEHq/f0xMBKSQBhChiSf1/PMk8ydc+98z8wAH849c67JMAwDEREREbliTo4uQERERKSqUpASERERsZOClIiIiIidFKRERERE7KQgJSIiImInBSkREREROylIiYiIiNhJQUpERETETgpSIiIiInZSkBKppmbOnInJZMJkMrF06dISjxuGQaNGjTCZTHTv3t2u55g6dSozZ868on2WLl1aZk0V5Vo9R2XUXpYdO3bw4osvcvDgwWty/BdffBGTyWTXvo58XUQcTUFKpJrz9vZm+vTpJbYvW7aM/fv34+3tbfex7QlS7dq1Iz4+nnbt2tn9vI7iyNp37NjBSy+9dM2C1EMPPUR8fLxd+1bl91TkailIiVRzw4cPZ86cOWRkZBTbPn36dDp16kS9evUqpY78/HwKCgqoVasWHTt2pFatWpXyvBWhKtaek5NzRe3r1q1Lx44d7XquqvS6iFQ0BSmRau7uu+8G4Msvv7RtO336NHPmzOHBBx8sdZ+8vDxeffVVmjVrhqurK4GBgTzwwAMcP37c1iYiIoLt27ezbNky2ynEiIgI4I9TPZ9//jlPPPEEderUwdXVlX379pV5Gmj16tUMHDgQf39/3NzcaNiwIePHj79s/3bt2sUtt9yCh4cHAQEBjB07lszMzBLtIiIiGDVqVInt3bt3L3Zq80prHzVqFF5eXuzbt4/+/fvj5eVFeHg4TzzxBLm5ucWe68iRIwwdOhRvb298fX259957Wbt2LSaT6ZIjezNnzuTOO+8EoEePHrbX+/w+3bt3p2XLlixfvpzOnTvj4eFhe29nz55Nnz59CA0Nxd3dnebNm/PMM8+QnZ1d7DlKO7UXERHBgAED+OWXX2jXrh3u7u40a9aMGTNmFGvnqNdF5HqgICVSzdWqVYuhQ4cW+8fvyy+/xMnJieHDh5doX1RUxKBBg3j99de55557mDdvHq+//jpxcXF0796dM2fOADB37lwaNGhA27ZtiY+PJz4+nrlz5xY71qRJk0hMTGTatGn8+OOPBAUFlVrjggUL6NKlC4mJibz11lv8/PPPPPfcc6SkpFyybykpKXTr1o1t27YxdepUPv/8c7Kysnj00Uev9GUqoby1g3XE6rbbbqNXr158//33PPjgg7z99tu88cYbtjbZ2dn06NGDJUuW8MYbb/DVV18RHBxc6ntwsVtvvZXXXnsNgPfff9/2et966622NklJSdx3333cc889zJ8/n3HjxgGwd+9e+vfvz/Tp0/nll18YP348X331FQMHDizX67B582aeeOIJJkyYwPfff0/r1q0ZPXo0y5cvv+y+1/p1EbkuGCJSLX3yyScGYKxdu9ZYsmSJARjbtm0zDMMw2rdvb4waNcowDMNo0aKF0a1bN9t+X375pQEYc+bMKXa8tWvXGoAxdepU27aL9z3v/PN17dq1zMeWLFli29awYUOjYcOGxpkzZ66oj3/5y18Mk8lkbNq0qdj23r17l3iO+vXrGyNHjixxjG7duhXrw5XWPnLkSAMwvvrqq2Jt+/fvbzRt2tR2//333zcA4+effy7W7v/+7/8MwPjkk08u2devv/66xHNf2AfA+PXXXy95jKKiIiM/P99YtmyZARibN2+2PfbCCy8YF/+TUL9+fcPNzc04dOiQbduZM2cMPz8/4//+7/9s2xz5uog4mkakRGqAbt260bBhQ2bMmMHWrVtZu3Ztmaf1fvrpJ3
x9fRk4cCAFBQW2W5s2bQgJCbmib2bdcccdl22zZ88e9u/fz+jRo3Fzcyv3sQGWLFlCixYtiI6OLrb9nnvuuaLjlKY8tZ9nMplKjPC0bt2aQ4cO2e4vW7YMb29vbrnllmLtzp96vVq1a9emZ8+eJbYfOHCAe+65h5CQEMxmMxaLhW7dugGwc+fOyx63TZs2xebRubm50aRJk2J9K8v18LqIXGvOji5ARK49k8nEAw88wLvvvsvZs2dp0qQJXbp0KbVtSkoKp06dwsXFpdTH09LSyv28oaGhl21zft5V3bp1y33c89LT04mMjCyxPSQk5IqPdbHy1H6eh4dHiRDo6urK2bNnbffT09MJDg4usW9p2+xRWr1ZWVl06dIFNzc3Xn31VZo0aYKHhweHDx/m9ttvt52mvRR/f/8S21xdXcu17/XwuohcawpSIjXEqFGj+Nvf/sa0adP4+9//Xma7gIAA/P39+eWXX0p9/EqWSyjPukSBgYGAdcLxlfL39yc5ObnE9tK2ubm5lZjkDNZgGBAQUGK7vWsqlcXf3581a9aU2F5arfYord7Fixdz7Ngxli5dahuFAjh16lSFPGdFuNavi8i1plN7IjVEnTp1eOqppxg4cCAjR44ss92AAQNIT0+nsLCQ2NjYEremTZva2pZ3ZOJSmjRpYjvtWFrQuZQePXqwfft2Nm/eXGz7//73vxJtIyIi2LJlS7Fte/bsYffu3VdetB26detGZmYmP//8c7Hts2bNKtf+rq6uAFf0ep8PV+f3Pe/DDz8s9zGutat9XUQcTSNSIjXI66+/ftk2d911F1988QX9+/fn8ccf54YbbsBisXDkyBGWLFnCoEGDGDJkCACtWrVi1qxZzJ49mwYNGuDm5karVq2uuK7333+fgQMH0rFjRyZMmEC9evVITExkwYIFfPHFF2XuN378eGbMmMGtt97Kq6++SnBwMF988QW7du0q0XbEiBHcd999jBs3jjvuuINDhw7xj3/8wzYidq2NHDmSt99+m/vuu49XX32VRo0a8fPPP7NgwQIAnJwu/f/ali1bAvDRRx/h7e2Nm5sbkZGRpZ56O69z587Url2bsWPH8sILL2CxWPjiiy9KBE9HutrXRcTR9AkVkWLMZjM//PADf/3rX/n2228ZMmQIgwcP5vXXXy8RlF566SW6devGmDFjuOGGG8r9lfqL9e3bl+XLlxMaGspjjz3GLbfcwssvv3zZeTIhISEsW7aMqKgoHn74Ye677z7c3Nx47733SrS95557+Mc//sGCBQsYMGAAH3zwAR988AFNmjSxq+Yr5enpyeLFi+nevTtPP/00d9xxB4mJiUydOhUAX1/fS+4fGRnJlClT2Lx5M927d6d9+/b8+OOPl9zH39+fefPm4eHhwX333ceDDz6Il5cXs2fPrqhuXbWrfV1EHM1kGIbh6CJERGqq1157jeeee47ExES7JtxXV3pdpKrQqT0RkUpyfqSsWbNm5Ofns3jxYt59913uu+++Gh0W9LpIVaYgJSJSSTw8PHj77bc5ePAgubm51KtXj7/85S8899xzji7NofS6SFWmU3siIiIidtJkcxERERE7KUiJiIiI2MnhQWrq1KlERkbi5uZGTEwMK1asuGT7ZcuWERMTg5ubGw0aNGDatGkl2syZM4eoqChcXV2JiooqcUX65cuXM3DgQMLCwjCZTHz33XcljpGSksKoUaMICwvDw8ODW265hb17915VX0VERKR6cehk89mzZzN+/HimTp3KjTfeyIcffki/fv3YsWNHsYtknpeQkED//v0ZM2YM//3vf/ntt98YN24cgYGBtguMxsfHM3z4cF555RWGDBnC3LlzGTZsGCtXrqRDhw4AZGdnEx0dzQMPPFDqhUkNw2Dw4MFYLBa+//57atWqxVtvvcXNN9/Mjh078PT0LFf/ioqKOHbsGN7e3hV+uQkRERG5NgzDIDMzk7CwsMsvCms40A033GCMHTu22LZmzZoZzz
zzTKntn376aaNZs2bFtv3f//2f0bFjR9v9YcOGGbfcckuxNn379jXuuuuuUo8JGHPnzi22bffu3QZgbNu2zbatoKD
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
],
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAwMAAAHFCAYAAACuDCWjAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8pXeV/AAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdeVyO2f/48dddKq20LyZlF6WaacYgylIU2YfB2JfBGJPGvmYt2YcZjEF2YyaMoUH2aZgsY+fDWBJDY6xR5K7u3x/3r+vrVsha6v18PO6HrnOd65xzvct93+c651yXSqPRaBBCCCGEEEIUOXr53QAhhBBCCCFE/pDOgBBCCCGEEEWUdAaEEEIIIYQooqQzIIQQQgghRBElnQEhhBBCCCGKKOkMCCGEEEIIUURJZ0AIIYQQQogiSjoDQgghhBBCFFHSGRBCCCGEEKKIks6AEEKId1Z0dDQqlSrX18CBA99InadOnSI8PJzExMQ3Uv6rSExMRKVSER0dnd9NeWmxsbGEh4fndzOEKDKK5XcDhBBCiFe1ePFiKleurJPm5OT0Ruo6deoUY8eOxd/fH1dX1zdSx8tydHRk3759lCtXLr+b8tJiY2P59ttvpUMgxFsinQEhhBDvPHd3d3x8fPK7Ga9ErVajUqkoVuzlP5qNjIz4+OOPX2Or3p60tDRMTEzyuxlCFDkyTUgIIUSh9+OPP1KjRg1MTU0xMzOjYcOGHD58WCfPwYMH+fTTT3F1dcXY2BhXV1fatWvHpUuXlDzR0dF88sknANStW1eZkpQ9LcfV1ZUuXbrkqN/f3x9/f39le9euXahUKpYtW8bXX39NqVKlMDIy4ty5cwBs27aN+vXrY2FhgYmJCbVq1WL79u3PPc/cpgmFh4ejUqk4duwYn3zyCSVKlMDKyoqwsDAyMjI4c+YMjRo1wtzcHFdXV6KionTKzG7r8uXLCQsLw8HBAWNjY/z8/HLEEGDDhg3UqFEDExMTzM3NCQgIYN++fTp5stv0119/0bp1aywtLSlXrhxdunTh22+/BdCZ8pU9Jevbb7+lTp062NnZYWpqioeHB1FRUajV6hzxdnd358CBA9SuXRsTExPKli1LZGQkWVlZOnnv3LnD119/TdmyZTEyMsLOzo7g4GD+97//KXkePXrEhAkTqFy5MkZGRtja2tK1a1f++++/5/5OhCjopDMghBDinZeZmUlGRobOK9ukSZNo164dVapUYc2aNSxbtox79+5Ru3ZtTp06peRLTEykUqVKzJw5ky1btjB58mSuXbvGhx9+yI0bNwBo3LgxkyZNArRfTPft28e+ffto3LjxS7V72LBhJCUlMW/ePH799Vfs7OxYvnw5gYGBWFhYsGTJEtasWYOVlRUNGzbMU4fgadq0aYOnpycxMTH07NmTGTNmMGDAAJo3b07jxo1Zt24d9erVY8iQIaxduzbH8cOHD+fChQv88MMP/PDDD1y9ehV/f38uXLig5Fm5ciXNmjXDwsKCVatWsXDhQm7fvo2/vz/x8fE5ymzZsiXly5fnp59+Yt68eYwaNYrWrVsDKLHdt28fjo6OAJw/f5727duzbNkyNm7cSPfu3ZkyZQqff/55jrKTk5Pp0KEDn332GRs2bCAoKIhhw4axfPlyJc+9e/fw9fVl/vz5dO3alV9//ZV58+ZRsWJFrl27BkBWVhbNmjUjMjKS9u3bs2nTJiIjI4mLi8Pf358HDx689O9EiAJBI4QQQryjFi9erAFyfanVak1SUpKmWLFimi+//FLnuHv37mkcHBw0bdq0eWrZGRkZmvv372tMTU01s2bNUtJ/+uknDaDZuXNnjmNcXFw0nTt3zpHu5+en8fPzU7Z37typATR16tTRyZeamqqxsrLShISE6KRnZmZqPD09NR999NEzoqHRXLx4UQNoFi9erKSNGTNGA2imTZumk9fLy0sDaNauXaukqdVqja2traZly5Y52vr+++9rsrKylP
TExESNgYGBpkePHkobnZycNB4eHprMzEwl37179zR2dnaamjVr5mjT6NGjc5zDF198ocnL15PMzEyNWq3WLF26VKOvr6+5deuWss/Pz08DaBISEnSOqVKliqZhw4bK9rhx4zSAJi4u7qn1rFq1SgNoYmJidNIPHDigATTffffdc9sqREEmIwNCCCHeeUuXLuXAgQM6r2LFirFlyxYyMjLo1KmTzqhB8eLF8fPzY9euXUoZ9+/fZ8iQIZQvX55ixYpRrFgxzMzMSE1N5fTp02+k3a1atdLZ3rt3L7du3aJz58467c3KyqJRo0YcOHCA1NTUl6qrSZMmOttubm6oVCqCgoKUtGLFilG+fHmdqVHZ2rdvj0qlUrZdXFyoWbMmO3fuBODMmTNcvXqVjh07oqf3f18vzMzMaNWqFX/++SdpaWnPPP/nOXz4ME2bNsXa2hp9fX0MDAzo1KkTmZmZnD17Vievg4MDH330kU5atWrVdM7tt99+o2LFijRo0OCpdW7cuJGSJUsSEhKi8zvx8vLCwcFB529IiHeRLCAWQgjxznNzc8t1AfG///4LwIcffpjrcY9/aW3fvj3bt29n1KhRfPjhh1hYWKBSqQgODn5jU0Gyp7882d7sqTK5uXXrFqampi9cl5WVlc62oaEhJiYmFC9ePEd6SkpKjuMdHBxyTTt69CgAN2/eBHKeE2jv7JSVlcXt27d1FgnnlvdpkpKSqF27NpUqVWLWrFm4urpSvHhx9u/fzxdffJHjd2RtbZ2jDCMjI518//33H6VLl35mvf/++y937tzB0NAw1/3ZU8iEeFdJZ0AIIUShZWNjA8DPP/+Mi4vLU/PdvXuXjRs3MmbMGIYOHaqkp6enc+vWrTzXV7x4cdLT03Ok37hxQ2nL4x6/0v54e2fPnv3UuwLZ29vnuT2vU3Jycq5p2V+6s//Nnmv/uKtXr6Knp4elpaVO+pPn/yzr168nNTWVtWvX6vwujxw5kucynmRra8uVK1eemcfGxgZra2s2b96c635zc/OXrl+IgkA6A0IIIQqthg0bUqxYMc6fP//MKSkqlQqNRoORkZFO+g8//EBmZqZOWnae3EYLXF1dOXbsmE7a2bNnOXPmTK6dgSfVqlWLkiVLcurUKfr16/fc/G/TqlWrCAsLU77AX7p0ib1799KpUycAKlWqRKlSpVi5ciUDBw5U8qWmphITE6PcYeh5Ho+vsbGxkp5d3uO/I41Gw4IFC176nIKCghg9ejQ7duygXr16ueZp0qQJq1evJjMzk+rVq790XUIUVNIZEEIIUWi5uroybtw4RowYwYULF2jUqBGWlpb8+++/7N+/H1NTU8aOHYuFhQV16tRhypQp2NjY4Orqyu7du1m4cCElS5bUKdPd3R2A77//HnNzc4oXL06ZMmWwtramY8eOfPbZZ/Tt25dWrVpx6dIloqKisLW1zVN7zczMmD17Np07d+bWrVu0bt0aOzs7/vvvP44ePcp///3H3LlzX3eY8uT69eu0aNGCnj17cvfuXcaMGUPx4sUZNmwYoJ1yFRUVRYcOHWjSpAmff/456enpTJkyhTt37hAZGZmnejw8PACYPHkyQUFB6OvrU61aNQICAjA0NKRdu3YMHjyYhw8fMnfuXG7fvv3S5xQaGsqPP/5Is2bNGDp0KB999BEPHjxg9+7dNGnShLp16/Lpp5+yYsUKgoOD+eqrr/joo48wMDDgypUr7Ny5k2bNmtGiRYuXboMQ+U0WEAshhCjUhg0bxs8//8zZs2fp3LkzDRs2ZPDgwVy6dIk6deoo+VauXEndunUZPHgwLVu25ODBg8TFxVGiRAmd8sqUKcPMmTM5evQo/v7+fPjhh/z666+Adt1BVFQUW7ZsoUmTJsydO5e5c+dSsWLFPLf3s88+Y+fOndy/f5/PP/+cBg0a8NVXX/HXX39Rv3791xOUlzBp0iRcXFzo2rUr3bp1w9HRkZ07d+o87bh9+/asX7+emzdv0rZtW7
p27YqFhQU7d+7E19c3T/W0b9+eHj168N1331GjRg0+/PBDrl69SuXKlYmJieH27du0bNmSL7/8Ei8vL7755puXPid
},
"metadata": {},
"output_type": "display_data"
}
],
"execution_count": 14
},
{
"cell_type": "code",
"id": "10bdf199fcff6b48",
"metadata": {
"ExecuteTime": {
"end_time": "2025-02-16T08:42:38.569134800Z",
"start_time": "2025-02-16T08:30:40.002290Z"
}
},
"source": [
"# Lower-quantile LightGBM model: quantile regression at alpha = 0.25.\n",
"light_params1 = {\n",
"    # 'objective': 'regression',\n",
"    # 'metric': 'l2',\n",
"    'objective': 'quantile',  # quantile regression\n",
"    'metric': 'quantile',  # use the quantile loss as the evaluation metric\n",
"    'alpha': 0.25,  # 25% quantile\n",
"    'learning_rate': 0.05,\n",
"    # 'is_unbalance' was dropped: LightGBM only uses it for binary / multiclassova objectives.\n",
"    'num_leaves': 2048,\n",
"    'min_data_in_leaf': 256,\n",
"    'max_depth': 32,\n",
"    'max_bin': 1024,\n",
"    # 'feature_fraction': 0.7,\n",
"    # 'bagging_fraction': 0.7,\n",
"    # 'bagging_freq': 5,\n",
"    # 'lambda_l1': 10,\n",
"    # 'lambda_l2': 10,\n",
"    'verbosity': -1,\n",
"    # 'device': 'gpu'\n",
"}\n",
"# Filled in by record_evaluation with the per-iteration metric history.\n",
"evals = {}\n",
"light_callbacks1 = [\n",
"    lgb.log_evaluation(period=500),\n",
"    lgb.callback.record_evaluation(evals),\n",
"    # Stop once the first metric has not improved for 50 rounds.\n",
"    lgb.early_stopping(50, first_metric_only=True),\n",
"]\n",
"# NOTE(review): test_data appears to be used as the early-stopping validation set;\n",
"# confirm it is not also the final hold-out set.\n",
"light_model1 = train_light_model(train_data, test_data, light_params1, feature_columns,\n",
"                                 light_callbacks1, evals,\n",
"                                 num_boost_round=10000, use_optuna=False,\n",
"                                 print_feature_importance=True)"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Training until validation scores don't improve for 50 rounds\n",
"Early stopping, best iteration is:\n",
"[82]\ttrain's quantile: 0.00160062\tvalid's quantile: 0.00150239\n",
"Evaluated only: quantile\n"
]
},
{
"data": {
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
],
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAlIAAAHFCAYAAAA5VBcVAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8pXeV/AAAACXBIWXMAAA9hAAAPYQGoP6dpAABrtUlEQVR4nO3dd3RU1d7G8e9k0nuDFAhJ6IQOQYpUkSKKFFEUQewvolfKtWHvYudyFbGg2K6gAoqCQlCqRDpIr4FQEkJoaZBMkvP+MTAQkkAyBCbl+aw1K5k9+5z5zTaSJ/ucs4/JMAwDERERESk1J0cXICIiIlJRKUiJiIiI2ElBSkRERMROClIiIiIidlKQEhEREbGTgpSIiIiInRSkREREROykICUiIiJiJwUpERERETspSIlUUlOnTsVkMmEymVi0aFGh1w3DoG7duphMJrp27WrXe0yaNImpU6eWaptFixYVW1NZuVLvcTVqL86WLVt48cUX2bt37xXZ/4svvojJZLJrW0eOi4ijKUiJVHI+Pj5MmTKlUPvixYvZvXs3Pj4+du/bniDVqlUr4uPjadWqld3v6yiOrH3Lli289NJLVyxI3X///cTHx9u1bUX+bypyuRSkRCq5wYMHM2PGDNLS0gq0T5kyhfbt21OrVq2rUofFYiE3NxdfX1/atWuHr6/vVXnfslARa8/KyipV/5o1a9KuXTu73qsijYtIWVOQEqnk7rjjDgC+++47W9vJkyeZMWMG9957b5Hb5OTk8Oqrr9KwYUPc3NyoVq0a99xzD0eOHLH1iYqKYvPmzSxevNh2CDEqKgo4d6jn66+/5t///jc1atTAzc2NXbt2FXsYaMWKFfTt25egoCDc3d2pU6cOo0ePvuTn27ZtG71798bT05Pg4GBGjBhBenp6oX5RUVHcfffdhdq7du1a4NBmaWu/++678fb2ZteuXfTp0wdvb28iIiL497//TXZ2doH3OnDgAIMGDcLHxwd/f3/uvPNOVq1ahclkuujM3tSpU7n11lsB6Natm228z27TtWtXmjRpwpIlS+jQoQOenp62/7bTp0+nZ8+ehIWF4eHhQaNGjXjqqafIzMws8B5FHdqLioripptu4vfff6dVq1Z4eHjQsGFDPv/88wL9HDUuIuWBgpRIJefr68ugQYMK/PL77rvvcHJyYvDgwYX65+fn069fP8aPH8+QIUOYM2cO48ePJy4ujq5du3Lq1CkAZs2aRe3atWnZsiXx8fHEx8cza9asAvsaN24ciYmJTJ48mV9++YXq1asXWeO8efPo1KkTiYmJvPfee/z22288++yzHD58+KKf7fDhw3Tp0oVNmzYxadIkvv76azIyMnjkkUdKO0yFlLR2sM5Y3XzzzXTv3p2ff/6Ze++9l/fff58333zT1iczM5Nu3bqxcOFC3nzzTb7//ntCQkKK/G9woRtvvJHXX38dgA8//NA23jfeeKOtT1JSEkOHDmXIkCHMnTuXkSNHArBz50769OnDlClT+P333xk9ejTff/89ffv2LdE4bNiwgX//+9+MGTOGn3/+mWbNmnHfffexZMmSS257pcdFpFwwRKRS+uKLLwzAWLVqlbFw4UIDMDZt2mQYhmG0adPGuPvuuw3DMIzGjRsbXbp0sW333XffGYAxY8aMAvtbtWqVARiTJk2ytV247Vln369z587FvrZw4UJbW506dYw6deoYp06dKtVnfPLJJw2TyWSsX7++QHuPHj0KvUdkZKQxfPjwQvvo0qVLgc9Q2tqHDx9uAMb3339foG+fPn2MBg0a2J5/+OGHBmD89ttvBfr93//9nwEYX3zxxUU/6w8//FDovc//DIDxxx9/XHQf+fn5hsViMRYvXmwAxoYNG2yvvfDCC8aFvxIiIyMNd3d3Y9++fba2U6dOGYGBgcb//d//2docOS4ijqYZKZEqoEuXLtSpU4fPP/+cjRs3smrVqmIP6/3666/4+/vTt29fcn
NzbY8WLVoQGhpaqiuzbrnllkv22bFjB7t37+a+++7D3d29xPsGWLhwIY0bN6Z58+YF2ocMGVKq/RSlJLWfZTKZCs3wNGvWjH379tmeL168GB8fH3r37l2g39lDr5crICCA6667rlD7nj17GDJkCKGhoZjNZlxcXOjSpQsAW7duveR+W7RoUeA8Ond3d+rXr1/gsxWnPIyLyJXm7OgCROTKM5lM3HPPPUycOJHTp09Tv359OnXqVGTfw4cPc+LECVxdXYt8PTU1tcTvGxYWdsk+Z8+7qlmzZon3e9bRo0eJjo4u1B4aGlrqfV2oJLWf5enpWSgEurm5cfr0advzo0ePEhISUmjbotrsUVS9GRkZdOrUCXd3d1599VXq16+Pp6cn+/fvZ+DAgbbDtBcTFBRUqM3Nza1E25aHcRG50hSkRKqIu+++m+eff57Jkyfz2muvFdsvODiYoKAgfv/99yJfL81yCSVZl6hatWqA9YTj0goKCiI5OblQe1Ft7u7uhU5yBmswDA4OLtRu75pKxQkKCmLlypWF2ouq1R5F1fvnn39y6NAhFi1aZJuFAjhx4kSZvGdZuNLjInKl6dCeSBVRo0YNHn/8cfr27cvw4cOL7XfTTTdx9OhR8vLyiI2NLfRo0KCBrW9JZyYupn79+rbDjkUFnYvp1q0bmzdvZsOGDQXa//e//xXqGxUVxT///FOgbceOHWzfvr30RduhS5cupKen89tvvxVonzZtWom2d3NzAyjVeJ8NV2e3Pevjjz8u8T6utMsdFxFH04yUSBUyfvz4S/a5/fbb+fbbb+nTpw+jRo3immuuwcXFhQMHDrBw4UL69evHgAEDAGjatCnTpk1j+vTp1K5dG3d3d5o2bVrquj788EP69u1Lu3btGDNmDLVq1SIxMZF58+bx7bffFrvd6NGj+fzzz7nxxht59dVXCQkJ4dtvv2Xbtm2F+g4bNoyhQ4cycuRIbrnlFvbt28dbb71lmxG70oYPH87777/P0KFDefXVV6lbty6//fYb8+bNA8DJ6eJ/1zZp0gSATz75BB8fH9zd3YmOji7y0NtZHTp0ICAggBEjRvDCCy/g4uLCt99+Wyh4OtLljouIo+knVEQKMJvNzJ49m6effpqZM2cyYMAA+vfvz/jx4wsFpZdeeokuXbrwwAMPcM0115T4kvoL9erViyVLlhAWFsajjz5K7969efnlly95nkxoaCiLFy8mJiaGhx56iKFDh+Lu7s4HH3xQqO+QIUN46623mDdvHjfddBMfffQRH330EfXr17er5tLy8vLizz//pGvXrjzxxBPccsstJCYmMmnSJAD8/f0vun10dDQTJkxgw4YNdO3alTZt2vDLL79cdJugoCDmzJmDp6cnQ4cO5d5778Xb25vp06eX1ce6bJc7LiKOZjIMw3B0ESIiVdXrr7/Os88+S2Jiol0n3FdWGhepKHRoT0TkKjk7U9awYUMsFgt//vknEydOZOjQoVU6LGhcpCJTkBIRuUo8PT15//332bt3L9nZ2dSqVYsnn3ySZ5991tGlOZTGRSoyHdoTERERsZNONhcRERGxk4KUiIiIiJ0UpERERETspJPNr6D8/HwOHTqEj49Pmd9uQkRERK4MwzBIT08nPDz8kovCKkhdQYcOHSIiIsLRZYiIiIgd9u/ff8klOBSkrqCzN3dNSEggMDDQwdWUbxaLhfnz59OzZ09cXFwcXU65prEqOY1V6Wi8Sk5jVXIVcazS0tKIiIgo0U3aFaSuoLOH83x8fPD19XVwNeWbxWLB09MTX1/fCvM/mqNorEpOY1U6Gq+S01iVXEUeq5KclqOTzUVERETspCAlIiIiYicFKRERERE76RwpERGRCiY/P5+cnBxHl1EiFosFZ2dnTp8+TV5enqPLAcDFxQWz2Vwm+1KQEhERqUBycnJISEggPz/f0aWUiGEYhIaGsn///nK1pqK/vz+hoaGXXZOClIiISAVhGAZJSUmYzWYiIiIuuV
hkeZCfn09GRgbe3t7lol7DMMjKyiIlJQWAsLCwy9qfgpSIiEgFkZubS1ZWFuHh4Xh6ejq6nBI5exjS3d29XAQpAA8
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
],
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAxQAAAHFCAYAAABrfr8yAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8pXeV/AAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdeVzN2f/A8dfttmhXaWNSthKFZrJH2UJElmE0g4ZhMIY0MXbZ92XGjH3IYMbMyFgbNIixG/vga4nEiMYaIbe6vz88+vxchQpFvZ+Px33ons/5nHPe96Z7z+csH5VWq9UihBBCCCGEEHmgV9ANEEIIIYQQQry7pEMhhBBCCCGEyDPpUAghhBBCCCHyTDoUQgghhBBCiDyTDoUQQgghhBAiz6RDIYQQQgghhMgz6VAIIYQQQggh8kw6FEIIIYQQQog8kw6FEEIIIYQQIs+kQyGEEOKdFRkZiUqlyvYRHh7+Ruo8deoUERERxMfHv5HyX0V8fDwqlYrIyMiCbkqeRUdHExERUdDNEELkgn5BN0AIIYR4VUuWLKFixYo6aSVLlnwjdZ06dYrRo0fj5+eHi4vLG6kjrxwdHdm7dy/lypUr6KbkWXR0NN9//710KoR4h0iHQgghxDvPw8MDb2/vgm7GK9FoNKhUKvT18/7RbGRkRK1atV5jq/LPgwcPMDExKehmCCHyQKY8CSGEKPR++eUXateujampKWZmZjRt2pQjR47o5Pn777/56KOPcHFxwdjYGBcXFzp16sSlS5eUPJGRkXz44YcANGjQQJlelTnFyMXFhZCQkCz1+/n54efnpzyPjY1FpVKxbNkyvvrqK0qVKoWRkRHnz58H4M8//6RRo0ZYWFhgYmJC3bp12bp160vjzG7KU0REBCqViuPHj/Phhx9iaWmJtbU1YWFhpKWlcebMGZo1a4a5uTkuLi5MmTJFp8zMti5fvpywsDAcHBwwNjbG19c3y2sIsG7dOmrXro2JiQnm5uY0adKEvXv36uTJbNPhw4dp3749VlZWlCtXjpCQEL7//nsAnelrmdPLvv/+e+rXr4+dnR2mpqZ4enoyZcoUNBpNltfbw8ODgwcPUq9ePUxMTChbtiyTJk0iIyNDJ++dO3f46quvKFu2LEZGRtjZ2REQEMD//vc/Jc/jx48ZN24cFStWxMjICFtbWz799FP++++/l74nQhQF0qEQQgjxzktPTyctLU3nkWnChAl06tSJSpUq8euvv7Js2TLu3btHvXr1OHXqlJIvPj4eNzc3Zs2axebNm5k8eTKJiYlUr16dGzduANCiRQsmTJgAPPlyu3fvXvbu3UuLFi3y1O4hQ4aQkJDAvHnzWL9+PXZ2dixfvhx/f38sLCxYunQpv/76K9bW1jRt2jRHnYrn6dChA1WrViUqKooePXowc+ZMBgwYQFBQEC1atOD333+nYcOGfP3116xevTrL+UOHDuXChQssWrSIRYsWcfXqVfz8/Lhw4YKS56effqJ169ZYWFjw888/88MPP3D79m38/PzYtWtXljLbtm1L+fLl+e2335g3bx4jRoygffv2AMpru3fvXhwdHQGIi4sjODiYZcuWsWHDBrp3787UqVP5/PPPs5R97do1Pv74Yz755BPWrVtH8+bNGTJkCMuXL1fy3Lt3Dx8fH+bPn8+nn37K+vXrmTdvHq6uriQmJgKQkZFB69atmTRpEsHBwWzcuJFJkyYRExODn58fDx8+zPN7IkShoRVCCCHeUUuWLNEC2T40Go02ISFBq6+vr/3yyy91zrt3757WwcFB26FDh+eWnZaWpr1//77W1NRU+8033yjpv/32mxbQbt++Pcs5zs7O2q5du2ZJ9/X11fr6+irPt2/frgW09evX18mXkpKitba21gYGBuqkp6ena6tWraqtUaPGC14NrfbixYtaQLtkyRIlbdSoUVpAO336dJ281apV0wLa1atXK2kajUZra2urbdu2bZa2vv/++9
qMjAwlPT4+XmtgYKD97LPPlDaWLFlS6+npqU1PT1fy3bt3T2tnZ6etU6dOljaNHDkySwxffPGFNidfT9LT07UajUb7448/atVqtfbWrVvKMV9fXy2g3b9/v845lSpV0jZt2lR5PmbMGC2gjYmJeW49P//8sxbQRkVF6aQfPHhQC2jnzJnz0rYKUdjJCIUQQoh33o8//sjBgwd1Hvr6+mzevJm0tDS6dOmiM3pRrFgxfH19iY2NVcq4f/8+X3/9NeXLl0dfXx99fX3MzMxISUnh9OnTb6Td7dq103m+Z88ebt26RdeuXXXam5GRQbNmzTh48CApKSl5qqtly5Y6z93d3VGpVDRv3lxJ09fXp3z58jrTvDIFBwejUqmU587OztSpU4ft27cDcObMGa5evUrnzp3R0/v/rxdmZma0a9eOffv28eDBgxfG/zJHjhyhVatW2NjYoFarMTAwoEuXLqSnp3P27FmdvA4ODtSoUUMnrUqVKjqx/fHHH7i6utK4cePn1rlhwwaKFy9OYGCgzntSrVo1HBwcdH6HhCiqZFG2EEKId567u3u2i7KvX78OQPXq1bM97+kvvsHBwWzdupURI0ZQvXp1LCwsUKlUBAQEvLFpLZlTeZ5tb+a0n+zcunULU1PTXNdlbW2t89zQ0BATExOKFSuWJT05OTnL+Q4ODtmmHTt2DICbN28CWWOCJztuZWRkcPv2bZ2F19nlfZ6EhATq1auHm5sb33zzDS4uLhQrVowDBw7wxRdfZHmPbGxsspRhZGSkk++///6jdOnSL6z3+vXr3LlzB0NDw2yPZ06HE6Iokw6FEEKIQqtEiRIArFq1Cmdn5+fmu3v3Lhs2bGDUqFEMHjxYSU9NTeXWrVs5rq9YsWKkpqZmSb9x44bSlqc9fcX/6fbOnj37ubs12dvb57g9r9O1a9eyTcv84p75b+bag6ddvXoVPT09rKysdNKfjf9F1qxZQ0pKCqtXr9Z5L48ePZrjMp5la2vLlStXXpinRIkS2NjYsGnTpmyPm5ub57l+IQoL6VAIIYQotJo2bYq+vj5xcXEvnF6jUqnQarUYGRnppC9atIj09HSdtMw82Y1auLi4cPz4cZ20s2fPcubMmWw7FM+qW7cuxYsX59SpU/Tt2/el+fPTzz//TFhYmNIJuHTpEnv27KFLly4AuLm5UapUKX766SfCw8OVfCkpKURFRSk7P73M06+vsbGxkp5Z3tPvkVarZeHChXmOqXnz5owcOZJt27bRsGHDbPO0bNmSlStXkp6eTs2aNfNclxCFmXQohBBCFFouLi6MGTOGYcOGceHCBZo1a4aVlRXXr1/nwIEDmJqaMnr0aCwsLKhfvz5Tp06lRIkSuLi4sGPHDn744QeKFy+uU6aHhwcACxYswNzcnGLFilGmTBlsbGzo3Lkzn3zyCX369KFdu3ZcunSJKVOmYGtrm6P2mpmZMXv2bLp27cqtW7do3749dnZ2/Pfffxw7doz//vuPuXPnvu6XKUeSkpJo06YNPXr04O7du4waNYpixYoxZMgQ4Mn0sSlTpvDxxx/TsmVLPv/8c1JTU5k6dSp37txh0qRJOarH09MTgMmTJ9O8eXPUajVVqlShSZMmGBoa0qlTJwYNGsSjR4+YO3cut2/fznNMoaGh/PLLL7Ru3ZrBgwdTo0YNHj58yI4dO2jZsiUNGjTgo48+YsWKFQQEBNC/f39q1KiBgYEBV65cYfv27bRu3Zo2bdrkuQ1CFAayKFsIIUShNmTIEFatWsXZs2fp2rUrTZs2ZdCgQVy6dIn69esr+X766ScaNGjAoEGDaNu2LX///TcxMTFYWlrqlFemTBlmzZrFsWPH8PPzo3r16qxfvx54sg5jypQpbN68mZYtWzJ37lzmzp2Lq6trjtv7ySefsH37du7fv8/nn39O48aN6d+/P4cPH6ZRo0av50XJgwkTJuDs7Mynn35Kt27dcHR0ZPv27Tp35Q4ODmbNmjXcvHmTjh078u
mnn2JhYcH27dvx8fHJUT3BwcF89tlnzJkzh9q1a1O9enWuXr1KxYoViYqK4vbt27Rt25Yvv/ySatWq8e233+Y5JnN
},
"metadata": {},
"output_type": "display_data"
}
],
"execution_count": 15
},
{
"cell_type": "code",
"id": "5bb96ca8492e74d",
"metadata": {
"ExecuteTime": {
"end_time": "2025-02-16T08:42:38.584794600Z",
"start_time": "2025-02-16T08:31:20.861518Z"
}
},
"source": [
"# Score the training set with both quantile models and keep, for every\n",
"# trade_date, the single row with the highest combined score.\n",
"train_features = train_data[feature_columns]\n",
"# .assign() returns a new frame, so train_data itself does not gain a 'score'\n",
"# column (which could otherwise leak into the features when cells are re-run).\n",
"score_df = train_data.assign(\n",
"    score=light_model9.predict(train_features) + light_model1.predict(train_features)\n",
")\n",
"# train_data['score'] = catboost_model.predict(train_data[feature_columns])\n",
"# score_df = score_df[score_df['score'] > 0]\n",
"predictions_train = score_df.loc[score_df.groupby('trade_date')['score'].idxmax()]\n",
"# NOTE(review): to_csv writes comma-separated values even though the file is named .tsv.\n",
"predictions_train[['trade_date', 'score', 'ts_code']].to_csv('predictions_train.tsv', index=False)\n"
],
"outputs": [],
"execution_count": 16
},
{
"cell_type": "code",
"id": "5d1522a7538db91b",
"metadata": {
"ExecuteTime": {
"end_time": "2025-02-16T08:42:38.584794600Z",
"start_time": "2025-02-16T08:31:26.697459Z"
}
},
"source": [
"# Score the test set with both quantile models and keep, for every\n",
"# trade_date, the single row with the highest combined score.\n",
"test_features = test_data[feature_columns]\n",
"# Use the same additive score as the training-set cell so both prediction files\n",
"# are comparable; a ratio of the two predictions is unstable when the\n",
"# lower-quantile prediction is close to zero or negative.\n",
"# .assign() returns a new frame, so test_data itself does not gain a 'score' column.\n",
"score_df = test_data.assign(\n",
"    score=light_model9.predict(test_features) + light_model1.predict(test_features)\n",
")\n",
"# test_data['score'] = catboost_model.predict(test_data[feature_columns])\n",
"# score_df = score_df[score_df['score'] > 0]\n",
"predictions_test = score_df.loc[score_df.groupby('trade_date')['score'].idxmax()]\n",
"# NOTE(review): to_csv writes comma-separated values even though the file is named .tsv.\n",
"predictions_test[['trade_date', 'score', 'ts_code']].to_csv('predictions_test.tsv', index=False)"
],
"outputs": [],
"execution_count": 17
},
{
"cell_type": "code",
"id": "b427ce41-9739-4e9e-bea8-5f2551fec5d7",
"metadata": {
"ExecuteTime": {
"end_time": "2025-02-16T08:42:38.584794600Z",
"start_time": "2025-02-16T08:31:28.801289Z"
}
},
"source": [],
"outputs": [],
"execution_count": null
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.19"
}
},
"nbformat": 4,
"nbformat_minor": 5
}