Files
NewStock/main/train/V1.1.ipynb

1118 lines
300 KiB
Plaintext
Raw Normal View History

2025-02-12 00:21:33 +08:00
{
"cells": [
{
"cell_type": "code",
"id": "79a7758178bafdd3",
"metadata": {
"jupyter": {
"source_hidden": true
2025-02-15 23:33:34 +08:00
},
"ExecuteTime": {
"end_time": "2025-02-11T16:44:40.335452Z",
"start_time": "2025-02-11T16:44:39.871705Z"
2025-02-12 00:21:33 +08:00
}
},
"source": [
"%load_ext autoreload\n",
"%autoreload 2\n",
"\n",
"import pandas as pd\n",
"def read_and_merge_h5_data(h5_filename, key, columns, df=None):\n",
" \"\"\"\n",
" 读取 HDF5 文件中的数据,根据指定的 columns 筛选数据,\n",
" 如果传入 df 参数,则将其与读取的数据根据 ts_code 和 trade_date 合并。\n",
"\n",
" 参数:\n",
" - h5_filename: HDF5 文件名\n",
" - key: 数据存储在 HDF5 文件中的 key\n",
" - columns: 要读取的列名列表\n",
" - df: 需要合并的 DataFrame如果为空则不进行合并\n",
"\n",
" 返回:\n",
" - 合并后的 DataFrame\n",
" \"\"\"\n",
" # 处理 _ 开头的列名\n",
" processed_columns = []\n",
" for col in columns:\n",
" if col.startswith('_'):\n",
" processed_columns.append(col[1:]) # 去掉下划线\n",
" else:\n",
" processed_columns.append(col)\n",
"\n",
" # 从 HDF5 文件读取数据,选择需要的列\n",
" data = pd.read_hdf(h5_filename, key=key, columns=processed_columns)\n",
"\n",
" # 修改列名,如果列名以前有 _加上 _\n",
" for col in data.columns:\n",
" if col not in columns: # 只有不在 columns 中的列才需要加下划线\n",
" new_col = f'_{col}'\n",
" data.rename(columns={col: new_col}, inplace=True)\n",
"\n",
" # 如果传入的 df 不为空,则进行合并\n",
" if df is not None and not df.empty:\n",
" # 确保两个 DataFrame 都有 ts_code 和 trade_date 列\n",
" df['trade_date'] = pd.to_datetime(df['trade_date'], format='%Y%m%d')\n",
" data['trade_date'] = pd.to_datetime(data['trade_date'], format='%Y%m%d')\n",
"\n",
" # 根据 ts_code 和 trade_date 合并\n",
" merged_df = pd.merge(df, data, on=['ts_code', 'trade_date'], how='left')\n",
" else:\n",
" # 如果 df 为空,则直接返回读取的数据\n",
" merged_df = data\n",
"\n",
" return merged_df\n",
"\n"
2025-02-15 23:33:34 +08:00
],
"outputs": [],
"execution_count": 1
2025-02-12 00:21:33 +08:00
},
{
"cell_type": "code",
"id": "a79cafb06a7e0e43",
"metadata": {
"ExecuteTime": {
2025-02-15 23:33:34 +08:00
"end_time": "2025-02-11T16:45:23.542844Z",
"start_time": "2025-02-11T16:44:40.341453Z"
2025-02-12 00:21:33 +08:00
}
},
"source": [
"print('daily data')\n",
"df = read_and_merge_h5_data('../../data/daily_data.h5', key='daily_data',\n",
" columns=['ts_code', 'trade_date', 'open', 'close', 'high', 'low', 'vol'],\n",
" df=None)\n",
"\n",
"print('daily basic')\n",
"df = read_and_merge_h5_data('../../data/daily_basic.h5', key='daily_basic',\n",
" columns=['ts_code', 'trade_date', 'turnover_rate', 'pe_ttm', 'circ_mv', 'volume_ratio',\n",
" 'is_st'], df=df)\n",
"\n",
"print('stk limit')\n",
"df = read_and_merge_h5_data('../../data/stk_limit.h5', key='stk_limit',\n",
" columns=['ts_code', 'trade_date', 'pre_close', 'up_limit', 'down_limit'],\n",
" df=df)\n",
"print('money flow')\n",
"df = read_and_merge_h5_data('../../data/money_flow.h5', key='money_flow',\n",
" columns=['ts_code', 'trade_date', 'buy_sm_vol', 'sell_sm_vol', 'buy_lg_vol', 'sell_lg_vol',\n",
" 'buy_elg_vol', 'sell_elg_vol', 'net_mf_vol'],\n",
" df=df)"
2025-02-15 23:33:34 +08:00
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"daily data\n",
"daily basic\n",
"stk limit\n",
"money flow\n"
]
}
],
"execution_count": 2
2025-02-12 00:21:33 +08:00
},
{
"cell_type": "code",
"id": "c4e9e1d31da6dba6",
"metadata": {
"jupyter": {
"source_hidden": true
2025-02-15 23:33:34 +08:00
},
"ExecuteTime": {
"end_time": "2025-02-11T16:45:23.963120Z",
"start_time": "2025-02-11T16:45:23.605890Z"
2025-02-12 00:21:33 +08:00
}
},
"source": [
"origin_columns = df.columns.tolist()"
2025-02-15 23:33:34 +08:00
],
"outputs": [],
"execution_count": 3
2025-02-12 00:21:33 +08:00
},
{
"cell_type": "code",
"id": "a735bc02ceb4d872",
"metadata": {
"jupyter": {
"source_hidden": true
2025-02-15 23:33:34 +08:00
},
"ExecuteTime": {
"end_time": "2025-02-11T16:45:24.539606Z",
"start_time": "2025-02-11T16:45:24.462439Z"
2025-02-12 00:21:33 +08:00
}
},
"source": [
"import numpy as np\n",
"import talib\n",
"\n",
"def get_technical_factor(df):\n",
" # 按股票和日期排序\n",
" df = df.sort_values(by=['ts_code', 'trade_date'])\n",
" grouped = df.groupby('ts_code', group_keys=False)\n",
"\n",
" # 计算 up 和 down\n",
" df['up'] = (df['high'] - df[['close', 'open']].max(axis=1)) / df['close']\n",
" df['down'] = (df[['close', 'open']].min(axis=1) - df['low']) / df['close']\n",
"\n",
" # 计算 ATR\n",
" df['atr_14'] = grouped.apply(\n",
" lambda x: pd.Series(talib.ATR(x['high'].values, x['low'].values, x['close'].values, timeperiod=14), index=x.index)\n",
" )\n",
" df['atr_6'] = grouped.apply(\n",
" lambda x: pd.Series(talib.ATR(x['high'].values, x['low'].values, x['close'].values, timeperiod=6), index=x.index)\n",
" )\n",
"\n",
" # 计算 OBV 及其均线\n",
" df['obv'] = grouped.apply(\n",
" lambda x: pd.Series(talib.OBV(x['close'].values, x['vol'].values), index=x.index)\n",
" )\n",
" df['maobv_6'] = grouped.apply(\n",
" lambda x: pd.Series(talib.SMA(x['obv'].values, timeperiod=6), index=x.index)\n",
" )\n",
" df['obv-maobv_6'] = df['obv'] - df['maobv_6']\n",
"\n",
" # 计算 RSI\n",
" df['rsi_3'] = grouped.apply(\n",
" lambda x: pd.Series(talib.RSI(x['close'].values, timeperiod=3), index=x.index)\n",
" )\n",
" df['rsi_6'] = grouped.apply(\n",
" lambda x: pd.Series(talib.RSI(x['close'].values, timeperiod=6), index=x.index)\n",
" )\n",
" df['rsi_9'] = grouped.apply(\n",
" lambda x: pd.Series(talib.RSI(x['close'].values, timeperiod=9), index=x.index)\n",
" )\n",
"\n",
" # 计算 return_10 和 return_20\n",
" df['return_10'] = grouped['close'].apply(lambda x: x / x.shift(10) - 1)\n",
" df['return_20'] = grouped['close'].apply(lambda x: x / x.shift(20) - 1)\n",
"\n",
" # 计算 avg_close_5\n",
" df['avg_close_5'] = grouped['close'].apply(lambda x: x.rolling(window=5).mean() / x)\n",
"\n",
" # 计算标准差指标\n",
" df['std_return_5'] = grouped['close'].apply(lambda x: x.pct_change().shift(-1).rolling(window=5).std())\n",
" df['std_return_15'] = grouped['close'].apply(lambda x: x.pct_change().shift(-1).rolling(window=15).std())\n",
" df['std_return_25'] = grouped['close'].apply(lambda x: x.pct_change().shift(-1).rolling(window=25).std())\n",
" df['std_return_90'] = grouped['close'].apply(lambda x: x.pct_change().shift(-1).rolling(window=90).std())\n",
" df['std_return_90_2'] = grouped['close'].apply(lambda x: x.shift(10).pct_change().shift(-1).rolling(window=90).std())\n",
"\n",
" # 计算比值指标\n",
" df['std_return_5 / std_return_90'] = df['std_return_5'] / df['std_return_90']\n",
" df['std_return_5 / std_return_25'] = df['std_return_5'] / df['std_return_25']\n",
"\n",
" # 计算标准差差值\n",
" df['std_return_90 - std_return_90_2'] = df['std_return_90'] - df['std_return_90_2']\n",
"\n",
" return df\n",
"\n",
"\n",
"def get_act_factor(df):\n",
" # 按股票和日期排序\n",
" df = df.sort_values(by=['ts_code', 'trade_date'])\n",
" grouped = df.groupby('ts_code', group_keys=False)\n",
" # 计算 EMA 指标\n",
" df['ema_5'] = grouped['close'].apply(\n",
" lambda x: pd.Series(talib.EMA(x.values, timeperiod=5), index=x.index)\n",
" )\n",
" df['ema_13'] = grouped['close'].apply(\n",
" lambda x: pd.Series(talib.EMA(x.values, timeperiod=13), index=x.index)\n",
" )\n",
" df['ema_20'] = grouped['close'].apply(\n",
" lambda x: pd.Series(talib.EMA(x.values, timeperiod=20), index=x.index)\n",
" )\n",
" df['ema_60'] = grouped['close'].apply(\n",
" lambda x: pd.Series(talib.EMA(x.values, timeperiod=60), index=x.index)\n",
" )\n",
"\n",
" # 计算 act_factor1, act_factor2, act_factor3, act_factor4\n",
" df['act_factor1'] = grouped['ema_5'].apply(\n",
" lambda x: np.arctan((x / x.shift(1) - 1) * 100) * 57.3 / 50\n",
" )\n",
" df['act_factor2'] = grouped['ema_13'].apply(\n",
" lambda x: np.arctan((x / x.shift(1) - 1) * 100) * 57.3 / 40\n",
" )\n",
" df['act_factor3'] = grouped['ema_20'].apply(\n",
" lambda x: np.arctan((x / x.shift(1) - 1) * 100) * 57.3 / 21\n",
" )\n",
" df['act_factor4'] = grouped['ema_60'].apply(\n",
" lambda x: np.arctan((x / x.shift(1) - 1) * 100) * 57.3 / 10\n",
" )\n",
"\n",
" # 计算 act_factor5 和 act_factor6\n",
" df['act_factor5'] = df['act_factor1'] + df['act_factor2'] + df['act_factor3'] + df['act_factor4']\n",
" df['act_factor6'] = (df['act_factor1'] - df['act_factor2']) / np.sqrt(df['act_factor1']**2 + df['act_factor2']**2)\n",
"\n",
" # 根据 trade_date 截面计算排名\n",
" df['rank_act_factor1'] = df.groupby('trade_date', group_keys=False)['act_factor1'].rank(ascending=False, pct=True)\n",
" df['rank_act_factor2'] = df.groupby('trade_date', group_keys=False)['act_factor2'].rank(ascending=False, pct=True)\n",
" df['rank_act_factor3'] = df.groupby('trade_date', group_keys=False)['act_factor3'].rank(ascending=False, pct=True)\n",
"\n",
" return df\n",
"\n",
"\n",
"def get_money_flow_factor(df):\n",
" # 计算资金流相关因子(字段名称见 tushare 数据说明)\n",
" df['active_buy_volume_large'] = df['buy_lg_vol'] / df['net_mf_vol']\n",
" df['active_buy_volume_big'] = df['buy_elg_vol'] / df['net_mf_vol']\n",
" df['active_buy_volume_small'] = df['buy_sm_vol'] / df['net_mf_vol']\n",
"\n",
" df['buy_lg_vol_minus_sell_lg_vol'] = (df['buy_lg_vol'] - df['sell_lg_vol']) / df['net_mf_vol']\n",
" df['buy_elg_vol_minus_sell_elg_vol'] = (df['buy_elg_vol'] - df['sell_elg_vol']) / df['net_mf_vol']\n",
"\n",
" return df\n",
"\n",
"\n",
"def get_alpha_factor(df):\n",
" df = df.sort_values(by=['ts_code', 'trade_date'])\n",
" grouped = df.groupby('ts_code')\n",
"\n",
" # alpha_022: 当前 close 与 5 日前 close 差值\n",
" df['alpha_022'] = grouped['close'].apply(lambda x: x - x.shift(5))\n",
"\n",
" # alpha_003: (close - open) / (high - low)\n",
" df['alpha_003'] = np.where(df['high'] != df['low'],\n",
" (df['close'] - df['open']) / (df['high'] - df['low']),\n",
" 0)\n",
"\n",
" # alpha_007: 计算过去5日内 close 与 vol 的相关性,再按 trade_date 横截面排名\n",
" df['alpha_007'] = grouped.apply(\n",
" lambda x: pd.Series(x['close'].rolling(5).corr(x['vol']), index=x.index)\n",
" ).reset_index(level=0, drop=True)\n",
" df['alpha_007'] = df.groupby('trade_date', group_keys=False)['alpha_007'].rank(ascending=True, pct=True)\n",
"\n",
" # alpha_013: 计算过去5日 close 之和 - 20日 close 之和,再按 trade_date 横截面排名\n",
" df['alpha_013'] = grouped.apply(\n",
" lambda x: pd.Series(x['close'].rolling(5).sum() - x['close'].rolling(20).sum(), index=x.index)\n",
" ).reset_index(level=0, drop=True)\n",
" df['alpha_013'] = df.groupby('trade_date', group_keys=False)['alpha_013'].rank(ascending=True, pct=True)\n",
"\n",
" return df\n",
"\n",
"\n",
"def get_future_data(df):\n",
" df = df.sort_values(by=['ts_code', 'trade_date'])\n",
" # 预先对 ts_code 分组,使用 transform 保持原 DataFrame 形状\n",
" grouped = df.groupby('ts_code')\n",
"\n",
" df['future_return1'] = (grouped['close'].transform(lambda x: x.shift(-1)) - df['close']) / df['close']\n",
" df['future_return2'] = (grouped['open'].transform(lambda x: x.shift(-2)) - grouped['open'].transform(lambda x: x.shift(-1))) / grouped['open'].transform(lambda x: x.shift(-1))\n",
" df['future_return3'] = (grouped['close'].transform(lambda x: x.shift(-2)) - grouped['close'].transform(lambda x: x.shift(-1))) / grouped['close'].transform(lambda x: x.shift(-1))\n",
" df['future_return4'] = (grouped['close'].transform(lambda x: x.shift(-2)) - grouped['open'].transform(lambda x: x.shift(-1))) / grouped['open'].transform(lambda x: x.shift(-1))\n",
" df['future_return5'] = (grouped['close'].transform(lambda x: x.shift(-5)) - grouped['open'].transform(lambda x: x.shift(-1))) / grouped['open'].transform(lambda x: x.shift(-1))\n",
" df['future_return6'] = (grouped['close'].transform(lambda x: x.shift(-10)) - grouped['open'].transform(lambda x: x.shift(-1))) / grouped['open'].transform(lambda x: x.shift(-1))\n",
" df['future_return7'] = (grouped['close'].transform(lambda x: x.shift(-20)) - grouped['open'].transform(lambda x: x.shift(-1))) / grouped['open'].transform(lambda x: x.shift(-1))\n",
"\n",
" df['future_close1'] = (grouped['close'].transform(lambda x: x.shift(-1)) - df['close']) / df['close']\n",
" df['future_close2'] = (grouped['close'].transform(lambda x: x.shift(-2)) - df['close']) / df['close']\n",
" df['future_close3'] = (grouped['close'].transform(lambda x: x.shift(-3)) - df['close']) / df['close']\n",
" df['future_close4'] = (grouped['close'].transform(lambda x: x.shift(-4)) - df['close']) / df['close']\n",
" df['future_close5'] = (grouped['close'].transform(lambda x: x.shift(-5)) - df['close']) / df['close']\n",
"\n",
" df['future_af11'] = grouped['act_factor1'].transform(lambda x: x.shift(-1))\n",
" df['future_af12'] = grouped['act_factor1'].transform(lambda x: x.shift(-2))\n",
" df['future_af13'] = grouped['act_factor1'].transform(lambda x: x.shift(-3))\n",
" df['future_af14'] = grouped['act_factor1'].transform(lambda x: x.shift(-4))\n",
" df['future_af15'] = grouped['act_factor1'].transform(lambda x: x.shift(-5))\n",
"\n",
" df['future_af21'] = grouped['act_factor2'].transform(lambda x: x.shift(-1))\n",
" df['future_af22'] = grouped['act_factor2'].transform(lambda x: x.shift(-2))\n",
" df['future_af23'] = grouped['act_factor2'].transform(lambda x: x.shift(-3))\n",
" df['future_af24'] = grouped['act_factor2'].transform(lambda x: x.shift(-4))\n",
" df['future_af25'] = grouped['act_factor2'].transform(lambda x: x.shift(-5))\n",
"\n",
" df['future_af31'] = grouped['act_factor3'].transform(lambda x: x.shift(-1))\n",
" df['future_af32'] = grouped['act_factor3'].transform(lambda x: x.shift(-2))\n",
" df['future_af33'] = grouped['act_factor3'].transform(lambda x: x.shift(-3))\n",
" df['future_af34'] = grouped['act_factor3'].transform(lambda x: x.shift(-4))\n",
" df['future_af35'] = grouped['act_factor3'].transform(lambda x: x.shift(-5))\n",
"\n",
" return df\n"
2025-02-15 23:33:34 +08:00
],
"outputs": [],
"execution_count": 4
2025-02-12 00:21:33 +08:00
},
{
"cell_type": "code",
"id": "53f86ddc0677a6d7",
"metadata": {
2025-02-15 23:33:34 +08:00
"scrolled": true,
2025-02-12 00:21:33 +08:00
"ExecuteTime": {
2025-02-15 23:33:34 +08:00
"end_time": "2025-02-11T16:47:14.618805Z",
"start_time": "2025-02-11T16:45:24.573259Z"
}
2025-02-12 00:21:33 +08:00
},
2025-02-15 23:33:34 +08:00
"source": [
"df = get_technical_factor(df)\n",
"df = get_act_factor(df)\n",
"df = get_money_flow_factor(df)\n",
"df = get_future_data(df)\n",
"# df = df.drop(columns=origin_columns)\n",
"\n",
"print(df.info())"
],
2025-02-12 00:21:33 +08:00
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<class 'pandas.core.frame.DataFrame'>\n",
2025-02-15 23:33:34 +08:00
"Index: 8375079 entries, 1964 to 8375077\n",
2025-02-12 00:21:33 +08:00
"Data columns (total 87 columns):\n",
" # Column Dtype \n",
"--- ------ ----- \n",
" 0 ts_code object \n",
" 1 trade_date datetime64[ns]\n",
" 2 open float64 \n",
" 3 close float64 \n",
" 4 high float64 \n",
" 5 low float64 \n",
" 6 vol float64 \n",
" 7 turnover_rate float64 \n",
" 8 pe_ttm float64 \n",
" 9 circ_mv float64 \n",
" 10 volume_ratio float64 \n",
" 11 is_st object \n",
" 12 up_limit float64 \n",
" 13 down_limit float64 \n",
" 14 buy_sm_vol float64 \n",
" 15 sell_sm_vol float64 \n",
" 16 buy_lg_vol float64 \n",
" 17 sell_lg_vol float64 \n",
" 18 buy_elg_vol float64 \n",
" 19 sell_elg_vol float64 \n",
" 20 net_mf_vol float64 \n",
" 21 up float64 \n",
" 22 down float64 \n",
" 23 atr_14 float64 \n",
" 24 atr_6 float64 \n",
" 25 obv float64 \n",
" 26 maobv_6 float64 \n",
" 27 obv-maobv_6 float64 \n",
" 28 rsi_3 float64 \n",
" 29 rsi_6 float64 \n",
" 30 rsi_9 float64 \n",
" 31 return_10 float64 \n",
" 32 return_20 float64 \n",
" 33 avg_close_5 float64 \n",
" 34 std_return_5 float64 \n",
" 35 std_return_15 float64 \n",
" 36 std_return_25 float64 \n",
" 37 std_return_90 float64 \n",
" 38 std_return_90_2 float64 \n",
" 39 std_return_5 / std_return_90 float64 \n",
" 40 std_return_5 / std_return_25 float64 \n",
" 41 std_return_90 - std_return_90_2 float64 \n",
" 42 ema_5 float64 \n",
" 43 ema_13 float64 \n",
" 44 ema_20 float64 \n",
" 45 ema_60 float64 \n",
" 46 act_factor1 float64 \n",
" 47 act_factor2 float64 \n",
" 48 act_factor3 float64 \n",
" 49 act_factor4 float64 \n",
" 50 act_factor5 float64 \n",
" 51 act_factor6 float64 \n",
" 52 rank_act_factor1 float64 \n",
" 53 rank_act_factor2 float64 \n",
" 54 rank_act_factor3 float64 \n",
" 55 active_buy_volume_large float64 \n",
" 56 active_buy_volume_big float64 \n",
" 57 active_buy_volume_small float64 \n",
" 58 buy_lg_vol_minus_sell_lg_vol float64 \n",
" 59 buy_elg_vol_minus_sell_elg_vol float64 \n",
" 60 future_return1 float64 \n",
" 61 future_return2 float64 \n",
" 62 future_return3 float64 \n",
" 63 future_return4 float64 \n",
" 64 future_return5 float64 \n",
" 65 future_return6 float64 \n",
" 66 future_return7 float64 \n",
" 67 future_close1 float64 \n",
" 68 future_close2 float64 \n",
" 69 future_close3 float64 \n",
" 70 future_close4 float64 \n",
" 71 future_close5 float64 \n",
" 72 future_af11 float64 \n",
" 73 future_af12 float64 \n",
" 74 future_af13 float64 \n",
" 75 future_af14 float64 \n",
" 76 future_af15 float64 \n",
" 77 future_af21 float64 \n",
" 78 future_af22 float64 \n",
" 79 future_af23 float64 \n",
" 80 future_af24 float64 \n",
" 81 future_af25 float64 \n",
" 82 future_af31 float64 \n",
" 83 future_af32 float64 \n",
" 84 future_af33 float64 \n",
" 85 future_af34 float64 \n",
" 86 future_af35 float64 \n",
"dtypes: datetime64[ns](1), float64(84), object(2)\n",
"memory usage: 5.5+ GB\n",
"None\n"
]
}
],
2025-02-15 23:33:34 +08:00
"execution_count": 5
2025-02-12 00:21:33 +08:00
},
{
"cell_type": "code",
"id": "dbe2fd8021b9417f",
"metadata": {
"jupyter": {
"source_hidden": true
},
2025-02-15 23:33:34 +08:00
"scrolled": true,
"ExecuteTime": {
"end_time": "2025-02-11T16:48:20.074570Z",
"start_time": "2025-02-11T16:47:14.683689Z"
2025-02-12 00:21:33 +08:00
}
2025-02-15 23:33:34 +08:00
},
2025-02-12 00:21:33 +08:00
"source": [
"def filter_data(df):\n",
" df = df.groupby('trade_date').apply(lambda x: x.nlargest(1000, 'act_factor3'))\n",
" df = df[df['is_st'] == False]\n",
" df = df[~df['ts_code'].str.endswith('BJ')]\n",
" df = df[~df['ts_code'].str.startswith('30')]\n",
" df = df[~df['ts_code'].str.startswith('68')]\n",
" df = df[~df['ts_code'].str.startswith('8')]\n",
" df = df.reset_index(drop=True)\n",
" return df\n",
"\n",
"\n",
"df = filter_data(df)\n",
"print(df.info())"
2025-02-15 23:33:34 +08:00
],
2025-02-12 00:21:33 +08:00
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2025-02-15 23:33:34 +08:00
"<class 'pandas.core.frame.DataFrame'>\n",
"RangeIndex: 1275521 entries, 0 to 1275520\n",
"Data columns (total 87 columns):\n",
" # Column Non-Null Count Dtype \n",
"--- ------ -------------- ----- \n",
" 0 ts_code 1275521 non-null object \n",
" 1 trade_date 1275521 non-null datetime64[ns]\n",
" 2 open 1275521 non-null float64 \n",
" 3 close 1275521 non-null float64 \n",
" 4 high 1275521 non-null float64 \n",
" 5 low 1275521 non-null float64 \n",
" 6 vol 1275521 non-null float64 \n",
" 7 turnover_rate 1275521 non-null float64 \n",
" 8 pe_ttm 1108329 non-null float64 \n",
" 9 circ_mv 1275521 non-null float64 \n",
" 10 volume_ratio 1275360 non-null float64 \n",
" 11 is_st 1275521 non-null object \n",
" 12 up_limit 1275193 non-null float64 \n",
" 13 down_limit 1275193 non-null float64 \n",
" 14 buy_sm_vol 1274976 non-null float64 \n",
" 15 sell_sm_vol 1274976 non-null float64 \n",
" 16 buy_lg_vol 1274976 non-null float64 \n",
" 17 sell_lg_vol 1274976 non-null float64 \n",
" 18 buy_elg_vol 1274976 non-null float64 \n",
" 19 sell_elg_vol 1274976 non-null float64 \n",
" 20 net_mf_vol 1274976 non-null float64 \n",
" 21 up 1275521 non-null float64 \n",
" 22 down 1275521 non-null float64 \n",
" 23 atr_14 1261700 non-null float64 \n",
" 24 atr_6 1269533 non-null float64 \n",
" 25 obv 1275521 non-null float64 \n",
" 26 maobv_6 1270528 non-null float64 \n",
" 27 obv-maobv_6 1270528 non-null float64 \n",
" 28 rsi_3 1272523 non-null float64 \n",
" 29 rsi_6 1269533 non-null float64 \n",
" 30 rsi_9 1266569 non-null float64 \n",
" 31 return_10 1265585 non-null float64 \n",
" 32 return_20 1256061 non-null float64 \n",
" 33 avg_close_5 1271525 non-null float64 \n",
" 34 std_return_5 1271155 non-null float64 \n",
" 35 std_return_15 1261330 non-null float64 \n",
" 36 std_return_25 1250488 non-null float64 \n",
" 37 std_return_90 1182896 non-null float64 \n",
" 38 std_return_90_2 1172558 non-null float64 \n",
" 39 std_return_5 / std_return_90 1182896 non-null float64 \n",
" 40 std_return_5 / std_return_25 1250488 non-null float64 \n",
" 41 std_return_90 - std_return_90_2 1172558 non-null float64 \n",
" 42 ema_5 1271525 non-null float64 \n",
" 43 ema_13 1263633 non-null float64 \n",
" 44 ema_20 1256985 non-null float64 \n",
" 45 ema_60 1214858 non-null float64 \n",
" 46 act_factor1 1270528 non-null float64 \n",
" 47 act_factor2 1262663 non-null float64 \n",
" 48 act_factor3 1256061 non-null float64 \n",
" 49 act_factor4 1213762 non-null float64 \n",
" 50 act_factor5 1213762 non-null float64 \n",
" 51 act_factor6 1262663 non-null float64 \n",
" 52 rank_act_factor1 1270528 non-null float64 \n",
" 53 rank_act_factor2 1262663 non-null float64 \n",
" 54 rank_act_factor3 1256061 non-null float64 \n",
" 55 active_buy_volume_large 1274968 non-null float64 \n",
" 56 active_buy_volume_big 1274948 non-null float64 \n",
" 57 active_buy_volume_small 1274976 non-null float64 \n",
" 58 buy_lg_vol_minus_sell_lg_vol 1274968 non-null float64 \n",
" 59 buy_elg_vol_minus_sell_elg_vol 1274953 non-null float64 \n",
" 60 future_return1 1275151 non-null float64 \n",
" 61 future_return2 1274804 non-null float64 \n",
" 62 future_return3 1274804 non-null float64 \n",
" 63 future_return4 1274804 non-null float64 \n",
" 64 future_return5 1273605 non-null float64 \n",
" 65 future_return6 1271215 non-null float64 \n",
" 66 future_return7 1265850 non-null float64 \n",
" 67 future_close1 1275151 non-null float64 \n",
" 68 future_close2 1274804 non-null float64 \n",
" 69 future_close3 1274411 non-null float64 \n",
" 70 future_close4 1274026 non-null float64 \n",
" 71 future_close5 1273605 non-null float64 \n",
" 72 future_af11 1271155 non-null float64 \n",
" 73 future_af12 1271806 non-null float64 \n",
" 74 future_af13 1272411 non-null float64 \n",
" 75 future_af14 1273024 non-null float64 \n",
" 76 future_af15 1273605 non-null float64 \n",
" 77 future_af21 1263263 non-null float64 \n",
" 78 future_af22 1263890 non-null float64 \n",
" 79 future_af23 1264475 non-null float64 \n",
" 80 future_af24 1265074 non-null float64 \n",
" 81 future_af25 1265639 non-null float64 \n",
" 82 future_af31 1256615 non-null float64 \n",
" 83 future_af32 1257201 non-null float64 \n",
" 84 future_af33 1257745 non-null float64 \n",
" 85 future_af34 1258301 non-null float64 \n",
" 86 future_af35 1258826 non-null float64 \n",
"dtypes: datetime64[ns](1), float64(84), object(2)\n",
"memory usage: 846.6+ MB\n",
"None\n"
2025-02-12 00:21:33 +08:00
]
}
],
2025-02-15 23:33:34 +08:00
"execution_count": 6
},
{
"cell_type": "code",
"id": "5f3d9aece75318cd",
"metadata": {
"ExecuteTime": {
"end_time": "2025-02-11T16:48:20.605749Z",
"start_time": "2025-02-11T16:48:20.177294Z"
}
},
2025-02-12 00:21:33 +08:00
"source": [
"def remove_outliers_iqr(series, lower_quantile=0.05, upper_quantile=0.95, threshold=1.5):\n",
" Q1 = series.quantile(lower_quantile)\n",
" Q3 = series.quantile(upper_quantile)\n",
" IQR = Q3 - Q1\n",
" lower_bound = Q1 - threshold * IQR\n",
" upper_bound = Q3 + threshold * IQR\n",
" # 过滤掉低于下边界或高于上边界的极值\n",
" return (series >= lower_bound) & (series <= upper_bound)\n",
"\n",
"\n",
"def neutralize_labels(labels, features, feature_columns, z_threshold=3, method='regression'):\n",
" labels_no_outliers = remove_outliers_iqr(labels)\n",
" return labels_no_outliers\n",
"\n",
"\n",
"train_data = df[df['trade_date'] <= '2024-01-01']\n",
"test_data = df[df['trade_date'] >= '2024-01-01']\n",
"\n",
"feature_columns = [col for col in df.columns if col not in ['trade_date',\n",
" 'ts_code',\n",
" 'label']]\n",
"feature_columns = [col for col in feature_columns if 'future' not in col]\n",
"feature_columns = [col for col in feature_columns if 'score' not in col]\n",
"feature_columns = [col for col in feature_columns if col not in origin_columns]\n",
"\n",
"# for column in [column for column in train_data.columns if 'future' in column]:\n",
"# label_index = neutralize_labels(train_data[column], train_data, feature_columns, z_threshold=3, method='regression')\n",
"# train_data = train_data[label_index]\n",
"# label_index = neutralize_labels(test_data[column], test_data, feature_columns, z_threshold=3, method='regression')\n",
"# test_data = test_data[label_index]\n",
"\n",
"print(feature_columns)\n",
"print(f\"最小日期: {train_data['trade_date'].min().strftime('%Y-%m-%d')}\")\n",
"print(f\"最大日期: {train_data['trade_date'].max().strftime('%Y-%m-%d')}\")\n",
"print(len(test_data))\n",
"print(f\"最小日期: {test_data['trade_date'].min().strftime('%Y-%m-%d')}\")\n",
"print(f\"最大日期: {test_data['trade_date'].max().strftime('%Y-%m-%d')}\")"
2025-02-15 23:33:34 +08:00
],
2025-02-12 00:21:33 +08:00
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2025-02-15 23:33:34 +08:00
"['up', 'down', 'atr_14', 'atr_6', 'obv', 'maobv_6', 'obv-maobv_6', 'rsi_3', 'rsi_6', 'rsi_9', 'return_10', 'return_20', 'avg_close_5', 'std_return_5', 'std_return_15', 'std_return_25', 'std_return_90', 'std_return_90_2', 'std_return_5 / std_return_90', 'std_return_5 / std_return_25', 'std_return_90 - std_return_90_2', 'ema_5', 'ema_13', 'ema_20', 'ema_60', 'act_factor1', 'act_factor2', 'act_factor3', 'act_factor4', 'act_factor5', 'act_factor6', 'rank_act_factor1', 'rank_act_factor2', 'rank_act_factor3', 'active_buy_volume_large', 'active_buy_volume_big', 'active_buy_volume_small', 'buy_lg_vol_minus_sell_lg_vol', 'buy_elg_vol_minus_sell_elg_vol']\n",
"最小日期: 2017-01-03\n",
2025-02-12 00:21:33 +08:00
"最大日期: 2023-12-29\n",
2025-02-15 23:33:34 +08:00
"138753\n",
2025-02-12 00:21:33 +08:00
"最小日期: 2024-01-02\n",
2025-02-15 23:33:34 +08:00
"最大日期: 2025-02-11\n"
2025-02-12 00:21:33 +08:00
]
}
],
2025-02-15 23:33:34 +08:00
"execution_count": 7
},
{
"cell_type": "code",
"id": "f4f16d63ad18d1bc",
"metadata": {
"ExecuteTime": {
"end_time": "2025-02-11T16:48:22.847190Z",
"start_time": "2025-02-11T16:48:20.619639Z"
}
},
2025-02-12 00:21:33 +08:00
"source": [
"def get_qcuts(series, quantiles):\n",
" q = pd.qcut(series, q=quantiles, labels=False, duplicates='drop')\n",
" return q[-1] # 返回窗口最后一个元素的分位数标签\n",
"\n",
"\n",
"window = 5\n",
"quantiles = 20\n",
"\n",
"\n",
"def get_label(df):\n",
" labels = df['future_af11'] - df['act_factor1']\n",
" # labels = df['future_close3']\n",
" return labels\n",
"\n",
"# train_data = get_future_data(train_data)\n",
"train_data['label'] = get_label(train_data)\n",
"test_data['label'] = get_label(test_data)\n",
"\n",
"train_data = train_data.dropna(subset=['label'])\n",
"# test_data = test_data.dropna(subset=['label'])\n",
"train_data = train_data.replace([np.inf, -np.inf], np.nan).dropna()\n",
"# test_data = test_data.replace([np.inf, -np.inf], np.nan).dropna()\n",
"\n",
"# train_data, test_data = train_data.replace([np.inf, -np.inf], np.nan), test_data.replace([np.inf, -np.inf], np.nan)\n",
"# train_data = train_data.dropna(subset=['label'])\n",
"# train_data = train_data.dropna(subset=feature_columns)\n",
"# # test_data = test_data.dropna(subset=feature_columns)\n",
"train_data = train_data.reset_index(drop=True)\n",
"# test_data = test_data.reset_index(drop=True)\n",
"\n",
"print(len(train_data))\n",
"print(f\"最小日期: {train_data['trade_date'].min().strftime('%Y-%m-%d')}\")\n",
"print(f\"最大日期: {train_data['trade_date'].max().strftime('%Y-%m-%d')}\")\n",
"print(len(test_data))\n",
"print(f\"最小日期: {test_data['trade_date'].min().strftime('%Y-%m-%d')}\")\n",
"print(f\"最大日期: {test_data['trade_date'].max().strftime('%Y-%m-%d')}\")\n"
2025-02-15 23:33:34 +08:00
],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"C:\\Users\\liaozhaorun\\AppData\\Local\\Temp\\ipykernel_5800\\2658667834.py:16: SettingWithCopyWarning: \n",
"A value is trying to be set on a copy of a slice from a DataFrame.\n",
"Try using .loc[row_indexer,col_indexer] = value instead\n",
"\n",
"See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
" train_data['label'] = get_label(train_data)\n",
"C:\\Users\\liaozhaorun\\AppData\\Local\\Temp\\ipykernel_5800\\2658667834.py:17: SettingWithCopyWarning: \n",
"A value is trying to be set on a copy of a slice from a DataFrame.\n",
"Try using .loc[row_indexer,col_indexer] = value instead\n",
"\n",
"See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
" test_data['label'] = get_label(test_data)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"902614\n",
"最小日期: 2017-06-05\n",
"最大日期: 2023-12-29\n",
"138753\n",
"最小日期: 2024-01-02\n",
"最大日期: 2025-02-11\n"
]
}
],
"execution_count": 8
2025-02-12 00:21:33 +08:00
},
{
"cell_type": "code",
"id": "8f134d435f71e9e2",
"metadata": {
"jupyter": {
"source_hidden": true
2025-02-15 23:33:34 +08:00
},
"ExecuteTime": {
"end_time": "2025-02-11T16:48:24.323978Z",
"start_time": "2025-02-11T16:48:22.880681Z"
2025-02-12 00:21:33 +08:00
}
},
"source": [
"import lightgbm as lgb\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import optuna\n",
"from sklearn.model_selection import KFold\n",
"from sklearn.metrics import mean_absolute_error\n",
"import os\n",
"import json\n",
"import pickle\n",
"import hashlib\n",
"\n",
"\n",
"def objective(trial, X, y, num_boost_round, params):\n",
" # 参数网格\n",
" X, y = X.reset_index(drop=True), y.reset_index(drop=True)\n",
" param_grid = {\n",
" \"n_estimators\": trial.suggest_categorical(\"n_estimators\", [10000]),\n",
" \"learning_rate\": trial.suggest_float(\"learning_rate\", 0.01, 0.3),\n",
" \"num_leaves\": trial.suggest_int(\"num_leaves\", 20, 3000, step=25),\n",
" \"max_depth\": trial.suggest_int(\"max_depth\", 3, 16),\n",
" \"min_data_in_leaf\": trial.suggest_int(\"min_data_in_leaf\", 200, 10000, step=100),\n",
" \"lambda_l1\": trial.suggest_int(\"lambda_l1\", 0, 100, step=5),\n",
" \"lambda_l2\": trial.suggest_int(\"lambda_l2\", 0, 100, step=5),\n",
" \"min_gain_to_split\": trial.suggest_float(\"min_gain_to_split\", 0, 15),\n",
" \"bagging_fraction\": trial.suggest_float(\"bagging_fraction\", 0.2, 0.95, step=0.1),\n",
" \"bagging_freq\": trial.suggest_categorical(\"bagging_freq\", [1]),\n",
" \"feature_fraction\": trial.suggest_float(\"feature_fraction\", 0.2, 0.95, step=0.1),\n",
" \"random_state\": 1,\n",
" \"objective\": 'regression',\n",
" 'verbosity': -1\n",
" }\n",
" # 5折交叉验证\n",
" cv = KFold(n_splits=5, shuffle=False)\n",
"\n",
" cv_scores = np.empty(5)\n",
" for idx, (train_idx, test_idx) in enumerate(cv.split(X, y)):\n",
" X_train, X_test = X.iloc[train_idx], X.iloc[test_idx]\n",
" y_train, y_test = y[train_idx], y[test_idx]\n",
"\n",
" # LGBM建模\n",
" model = lgb.LGBMRegressor(**param_grid, num_boost_round=num_boost_round)\n",
" model.fit(\n",
" X_train,\n",
" y_train,\n",
" eval_set=[(X_test, y_test)],\n",
" eval_metric=\"l2\",\n",
" callbacks=[\n",
" # LightGBMPruningCallback(trial, \"l2\"),\n",
" lgb.early_stopping(50, first_metric_only=True),\n",
" lgb.log_evaluation(period=-1)\n",
" ],\n",
" )\n",
" # 模型预测\n",
" preds = model.predict(X_test)\n",
" # 优化指标logloss最小\n",
" cv_scores[idx] = mean_absolute_error(y_test, preds)\n",
"\n",
" return np.mean(cv_scores)\n",
"\n",
"def generate_key(params, feature_columns, num_boost_round):\n",
" key_data = {\n",
" \"params\": params,\n",
" \"feature_columns\": feature_columns,\n",
" \"num_boost_round\": num_boost_round\n",
" }\n",
" # 转换成排序后的 JSON 字符串,再生成 md5 hash\n",
" key_str = json.dumps(key_data, sort_keys=True)\n",
" return hashlib.md5(key_str.encode('utf-8')).hexdigest()\n",
"\n",
"def train_light_model(df, params, feature_columns, callbacks, evals,\n",
"                      print_feature_importance=True, num_boost_round=100,\n",
"                      use_optuna=False):\n",
"    \"\"\"Train a LightGBM model on a chronological train/validation split.\n",
"\n",
"    The most recent 10% of distinct trade_date values (at least one date) form\n",
"    the validation set; all earlier dates form the training set.\n",
"\n",
"    Parameters\n",
"    - df: DataFrame with 'trade_date', 'label' and every column in feature_columns.\n",
"    - params: LightGBM parameter dict; it is copied and never modified in place.\n",
"    - feature_columns: names of the feature columns.\n",
"    - callbacks: callbacks forwarded to lgb.train.\n",
"    - evals: dict populated by an lgb.record_evaluation callback; plotted with\n",
"      lgb.plot_metric when print_feature_importance is True.\n",
"    - print_feature_importance: plot the metric curve, the first tree and the\n",
"      split-based feature importance.\n",
"    - num_boost_round: maximum number of boosting rounds.\n",
"    - use_optuna: run 20 optuna trials of objective() and merge the best\n",
"      parameters into the (copied) params before training.\n",
"\n",
"    Returns the trained booster, or the model stored in 'light_model.pkl' when\n",
"    that file's cache key matches params/feature_columns/num_boost_round.\n",
"    \"\"\"\n",
"    cache_file = 'light_model.pkl'\n",
"    cache_key = generate_key(params, feature_columns, num_boost_round)\n",
"\n",
"    # Reuse a cached model if one exists for exactly these inputs.\n",
"    if os.path.exists(cache_file):\n",
"        try:\n",
"            # NOTE: pickle.load can execute arbitrary code; only load trusted cache files.\n",
"            with open(cache_file, 'rb') as f:\n",
"                cache_data = pickle.load(f)\n",
"            if cache_data.get('key') == cache_key:\n",
"                print(\"加载缓存模型...\")\n",
"                return cache_data.get('model')\n",
"            else:\n",
"                print(\"缓存模型的参数与当前参数不匹配,重新训练模型。\")\n",
"        except Exception as e:\n",
"            print(f\"加载缓存失败: {e},重新训练模型。\")\n",
"    else:\n",
"        print(\"未发现缓存模型,开始训练新模型。\")\n",
"\n",
"    # Copy so that optuna's best parameters do not leak into the caller's dict.\n",
"    params = dict(params)\n",
"\n",
"    # Date ascending, label descending within each date.\n",
"    df_sorted = df.sort_values(by=['trade_date', 'label'], ascending=[True, False])\n",
"    unique_dates = df_sorted['trade_date'].unique()\n",
"    # At least one validation date: with a count of 0, unique_dates[-0:] would\n",
"    # select every date.\n",
"    val_date_count = max(1, int(len(unique_dates) * 0.1))\n",
"    val_dates = unique_dates[-val_date_count:]\n",
"    # Select rows with a boolean mask; iloc[...] on index labels treated the\n",
"    # labels as positions and could pick rows from the wrong dates.\n",
"    is_val = df_sorted['trade_date'].isin(val_dates)\n",
"    train_df = df_sorted[~is_val]\n",
"    val_df = df_sorted[is_val]\n",
"\n",
"    X_train = train_df[feature_columns]\n",
"    y_train = train_df['label']\n",
"\n",
"    X_val = val_df[feature_columns]\n",
"    y_val = val_df['label']\n",
"\n",
"    train_data = lgb.Dataset(X_train, label=y_train)\n",
"    val_data = lgb.Dataset(X_val, label=y_val)\n",
"    if use_optuna:\n",
"        # objective() (defined above) appears to slice X with .iloc but y with [],\n",
"        # so pass a fresh positional index to keep both aligned.\n",
"        X_opt = X_train.reset_index(drop=True)\n",
"        y_opt = y_train.reset_index(drop=True)\n",
"        study = optuna.create_study(direction='minimize')\n",
"        study.optimize(lambda trial: objective(trial, X_opt, y_opt, num_boost_round, params), n_trials=20)\n",
"\n",
"        print(f\"Best parameters: {study.best_trial.params}\")\n",
"        print(f\"Best score: {study.best_trial.value}\")\n",
"\n",
"        params.update(study.best_trial.params)\n",
"    model = lgb.train(\n",
"        params, train_data, num_boost_round=num_boost_round,\n",
"        valid_sets=[train_data, val_data], valid_names=['train', 'valid'],\n",
"        callbacks=callbacks\n",
"    )\n",
"\n",
"    # Optional diagnostics: metric curve, first tree and split importance.\n",
"    if print_feature_importance:\n",
"        lgb.plot_metric(evals)\n",
"        lgb.plot_tree(model, figsize=(20, 8))\n",
"        lgb.plot_importance(model, importance_type='split', max_num_features=20)\n",
"        plt.show()\n",
"    # Cache writing is currently disabled:\n",
"    # with open(cache_file, 'wb') as f:\n",
"    #     pickle.dump({'key': cache_key,\n",
"    #                  'model': model,\n",
"    #                  'feature_columns': feature_columns}, f)\n",
"    # print(\"模型训练完成并已保存缓存。\")\n",
"    return model\n",
"\n",
"\n",
"from catboost import CatBoostRegressor\n",
"import pandas as pd\n",
"\n",
"\n",
"def train_catboost(df, num_boost_round, params=None):\n",
"    \"\"\"Train a CatBoost regressor on a chronological train/validation split.\n",
"\n",
"    - df: DataFrame with 'date', 'instrument', 'label' and factor columns. Every\n",
"      other column is a feature, except names containing 'future' or 'score'.\n",
"    - num_boost_round: number of boosting iterations (CatBoost `iterations`).\n",
"    - params: optional dict of extra CatBoostRegressor keyword arguments;\n",
"      None means no extra arguments.\n",
"\n",
"    The most recent 10% of distinct dates (at least one) are passed to fit() as\n",
"    eval_set. Returns the fitted CatBoostRegressor.\n",
"    \"\"\"\n",
"    # **None would raise TypeError, so treat the default as an empty dict.\n",
"    params = {} if params is None else dict(params)\n",
"\n",
"    # Drop identifiers, the target, and columns whose names contain 'future'\n",
"    # (presumably look-ahead values) or 'score'.\n",
"    feature_columns = [\n",
"        col for col in df.columns\n",
"        if col not in ('date', 'instrument', 'label')\n",
"        and 'future' not in col\n",
"        and 'score' not in col\n",
"    ]\n",
"\n",
"    # Date ascending, label descending within each date.\n",
"    df_sorted = df.sort_values(by=['date', 'label'], ascending=[True, False])\n",
"    unique_dates = df_sorted['date'].unique()\n",
"    # At least one validation date: unique_dates[-0:] would select every date.\n",
"    val_date_count = max(1, int(len(unique_dates) * 0.1))\n",
"    val_dates = unique_dates[-val_date_count:]\n",
"    # Boolean masks keep the split aligned with df_sorted (iloc on index labels did not).\n",
"    is_val = df_sorted['date'].isin(val_dates)\n",
"    train_df = df_sorted[~is_val]\n",
"    val_df = df_sorted[is_val]\n",
"\n",
"    X_train = train_df[feature_columns]\n",
"    y_train = train_df['label']\n",
"\n",
"    X_val = val_df[feature_columns]\n",
"    y_val = val_df['label']\n",
"\n",
"    model = CatBoostRegressor(**params, iterations=num_boost_round)\n",
"    model.fit(X_train,\n",
"              y_train,\n",
"              eval_set=(X_val, y_val))\n",
"\n",
"    return model"
2025-02-15 23:33:34 +08:00
],
"outputs": [],
"execution_count": 9
2025-02-12 00:21:33 +08:00
},
{
"cell_type": "code",
"id": "4a4542e1ed6afe7d",
"metadata": {
"ExecuteTime": {
2025-02-15 23:33:34 +08:00
"end_time": "2025-02-11T16:48:24.622644Z",
"start_time": "2025-02-11T16:48:24.550645Z"
2025-02-12 00:21:33 +08:00
}
},
"source": [
"light_params = {\n",
"    'objective': 'regression',\n",
"    'metric': 'l2',\n",
"    'learning_rate': 0.05,\n",
"    # 'is_unbalance' was dropped: LightGBM applies it only to binary/multiclassova objectives.\n",
"    # With max_depth=10 a tree has at most 2**10 = 1024 leaves, so num_leaves matches that bound.\n",
"    'num_leaves': 1024,\n",
"    'min_data_in_leaf': 16,\n",
"    'max_depth': 10,\n",
"    'max_bin': 1024,\n",
"    'nthread': 2,\n",
"    'feature_fraction': 0.7,\n",
"    'bagging_fraction': 0.7,\n",
"    'bagging_freq': 5,\n",
"    'lambda_l1': 80,\n",
"    'lambda_l2': 65,\n",
"    'verbosity': -1\n",
"}"
2025-02-15 23:33:34 +08:00
],
"outputs": [],
"execution_count": 10
2025-02-12 00:21:33 +08:00
},
{
"cell_type": "code",
"id": "beeb098799ecfa6a",
"metadata": {
"ExecuteTime": {
2025-02-15 23:33:34 +08:00
"end_time": "2025-02-11T16:50:03.925800Z",
"start_time": "2025-02-11T16:48:24.655023Z"
2025-02-12 00:21:33 +08:00
}
},
"source": [
"# Report the training-set size, then fit LightGBM. `evals` is filled by the\n",
"# record_evaluation callback and plotted inside train_light_model.\n",
"print('train data size: ', len(train_data))\n",
"\n",
"evals = {}\n",
"light_model = train_light_model(\n",
"    df=train_data,\n",
"    params=light_params,\n",
"    feature_columns=feature_columns,\n",
"    callbacks=[\n",
"        lgb.log_evaluation(period=500),\n",
"        lgb.callback.record_evaluation(evals),\n",
"        lgb.early_stopping(50, first_metric_only=True),\n",
"    ],\n",
"    evals=evals,\n",
"    print_feature_importance=True,\n",
"    num_boost_round=1000,\n",
"    use_optuna=False,\n",
")"
2025-02-15 23:33:34 +08:00
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"train data size: 902614\n",
"未发现缓存模型,开始训练新模型。\n",
"Training until validation scores don't improve for 50 rounds\n",
"[500]\ttrain's l2: 0.309611\tvalid's l2: 0.256064\n",
"[1000]\ttrain's l2: 0.29282\tvalid's l2: 0.253211\n",
"Did not meet early stopping. Best iteration is:\n",
"[975]\ttrain's l2: 0.293549\tvalid's l2: 0.25321\n",
"Evaluated only: l2\n"
]
},
{
"data": {
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
],
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAlIAAAHFCAYAAAA5VBcVAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8pXeV/AAAACXBIWXMAAA9hAAAPYQGoP6dpAABgfElEQVR4nO3deXhTVeI+8Dd7mu4LXYBu7EtZC0hBNlmtIC4osgkqKoM4IOMojPgTGASccQBxgBG/KuOODuCKYpFdEBBbUDYRgbK0lK5pmzZJk/v74yZp03RJQ9sb2vfzPPdJ7sldTnIEXs8991yZIAgCiIiIiKjO5FJXgIiIiOhWxSBFRERE5CEGKSIiIiIPMUgREREReYhBioiIiMhDDFJEREREHmKQIiIiIvIQgxQRERGRhxikiIiIiDzEIEXkZTZt2gSZTAaZTIY9e/a4fC4IAtq1aweZTIahQ4d6dI7169dj06ZNddpnz5491dapvjTUORqj7tU5deoUFi9ejIsXLzbI8RcvXgyZTObRvlL+LkRNBYMUkZfy9/fHW2+95VK+d+9enD9/Hv7+/h4f25Mg1bt3bxw6dAi9e/f2+LxSkbLup06dwpIlSxosSM2cOROHDh3yaN9buU2JvAWDFJGXmjhxIrZs2QK9Xu9U/tZbbyEpKQkxMTGNUg+z2YyysjIEBASgf//+CAgIaJTz1odbse4Gg6FO27du3Rr9+/f36Fy30u9C5K0YpIi81KRJkwAAH330kaOsoKAAW7ZswaOPPlrlPiaTCcuWLUOnTp2g0WjQokULPPLII7hx44Zjm7i4OJw8eRJ79+51XEKMi4sDUH6p57333sNf/vIXtGrVChqNBr///nu1l4EOHz6McePGITQ0FFqtFm3btsW8efNq/X5nzpzBmDFjoNPpEBYWhlmzZqGwsNBlu7i4OMyYMcOlfOjQoU6XNuta9xkzZsDPzw+///47kpOT4efnh+joaPzlL3+B0Wh0OteVK1cwYcIE+Pv7IygoCFOmTMHRo0chk8lq7NnbtGkTHnjgAQDAsGHDHL+3fZ+hQ4ciISEB+/btw4ABA6DT6Rxtu3nzZowaNQpRUVHw8fFB586dsWDBAhQXFzudo6pLe3FxcRg7diy+/fZb9O7dGz4+PujUqRPefvttp+2k+l2ImhIGKSIvFRAQgAkTJjj94/fRRx9BLpdj4sSJLttbrVaMHz8eK1euxOTJk/H1119j5cqVSElJwdChQ1FSUgIA2LZtG9q0aYNevXrh0KFDOHToELZt2+Z0rIULFyI9PR3/+c9/8OWXXyI8PLzKOu7YsQODBg1Ceno6Vq1ahW+++QaLFi3C9evXa/xu169fx5AhQ/Drr79i/fr1eO+991BUVIQ5c+bU9Wdy4W7dAbHH6u6778bw4cPx+eef49FHH8Xq1avxyiuvOLYpLi7GsGHDsHv3brzyyiv45JNPEBERUWUbVHbXXXdh+fLlAIB169Y5fu+77rrLsU1GRgamTp2KyZMnY/v27Zg9ezYA4Ny5c0hOTsZbb72Fb7/9FvPmzcMnn3yCcePGufU7HD9+HH/5y1/wzDPP4PPPP0f37t3x2GOPYd++fbXu29C/C1GTIhCRV3nnnXcEAMLRo0eF3bt3CwCEX3/9VRAEQejbt68wY8YMQRAEoWvXrsKQIUMc+3300UcCAGHLli1Oxzt69KgAQFi/fr2jrPK+dvbzDR48uNrPdu/e7Shr27at0LZtW6GkpKRO3/H5558XZDKZkJaW5lQ+cuRIl3PExsYK06dPdznGkCFDnL5DXes+ffp0AYDwySefOG2bnJwsdOzY0bG+bt06AYDwzTffOG335JNPCgCEd955p8bv+umnn7qcu+J3ACB8//33NR7DarUKZrNZ2Lt3rwBAOH78uOOzl156Saj8V3lsbKyg1WqFS5cuOcpKSkqEkJAQ4cknn3SUSfm7EDUV7JEi8mJDhgxB27Zt8fbbb+OXX37B0a
NHq72s99VXXyEoKAjjxo1DWVmZY+nZsyciIyPrdGfW/fffX+s2v/32G86fP4/HHnsMWq3W7WMDwO7du9G1a1f06NHDqXzy5Ml1Ok5V3Km7nUwmc+nh6d69Oy5duuRY37t3L/z9/TFmzBin7eyXXm9WcHAw7rjjDpfyP/74A5MnT0ZkZCQUCgVUKhWGDBkCADh9+nStx+3Zs6fTODqtVosOHTo4fbfqeMPvQnSrUEpdASKqnkwmwyOPPIK1a9eitLQUHTp0wKBBg6rc9vr168jPz4dara7y8+zsbLfPGxUVVes29nFXrVu3dvu4djk5OYiPj3cpj4yMrPOxKnOn7nY6nc4lBGo0GpSWljrWc3JyEBER4bJvVWWeqKq+RUVFGDRoELRaLZYtW4YOHTpAp9Ph8uXLuO+++xyXaWsSGhrqUqbRaNza1xt+F6JbBYMUkZebMWMG/t//+3/4z3/+g5dffrna7cLCwhAaGopvv/22ys/rMl2CO/MStWjRAoA44LiuQkNDkZmZ6VJeVZlWq3UZ5AyIwTAsLMyl3NM5laoTGhqKI0eOuJRXVVdPVFXfXbt24dq1a9izZ4+jFwoA8vPz6+Wc9aGhfxeiWwUv7RF5uVatWuGvf/0rxo0bh+nTp1e73dixY5GTkwOLxYI+ffq4LB07dnRs627PRE06dOjguOxYVdCpybBhw3Dy5EkcP37cqfzDDz902TYuLg4nTpxwKvvtt99w9uzZulfaA0OGDEFhYSG++eYbp/KPP/7Yrf01Gg0A1On3tocr+752b7zxhtvHaGg3+7sQNRXskSK6BaxcubLWbR566CF88MEHSE5Oxty5c9GvXz+oVCpcuXIFu3fvxvjx43HvvfcCALp164aPP/4YmzdvRps2baDVatGtW7c612vdunUYN24c+vfvj2eeeQYxMTFIT0/Hjh078MEHH1S737x58/D222/jrrvuwrJlyxAREYEPPvgAZ86ccdl22rRpmDp1KmbPno37778fly5dwj/+8Q9Hj1hDmz59OlavXo2pU6di2bJlaNeuHb755hvs2LEDACCX1/z/owkJCQCAjRs3wt/fH1qtFvHx8VVeerMbMGAAgoODMWvWLLz00ktQqVT44IMPXIKnlG72dyFqKvhfOlEToVAo8MUXX+Bvf/sbtm7dinvvvRf33HMPVq5c6RKUlixZgiFDhuDxxx9Hv3793L6lvrLRo0dj3759iIqKwp///GeMGTMGS5curXWcTGRkJPbu3YsuXbrgT3/6E6ZOnQqtVot///vfLttOnjwZ//jHP7Bjxw6MHTsWGzZswIYNG9ChQweP6lxXvr6+2LVrF4YOHYrnnnsO999/P9LT07F+/XoAQFBQUI37x8fHY82aNTh+/DiGDh2Kvn374ssvv6xxn9DQUHz99dfQ6XSYOnUqHn30Ufj5+WHz5s319bVu2s3+LkRNhUwQBEHqShAR3WqWL1+ORYsWIT093aMB900Vfxdqbnhpj4ioFvaesk6dOsFsNmPXrl1Yu3Ytpk6d2qzDAn8XIgYpIqJa6XQ6rF69GhcvXoTRaERMTAyef/55LFq0SOqqSYq/CxEv7RERERF5jIPNiYiIiDzEIEVERETkIQYpIiIiIg9xsHkVrFYrrl27Bn9//3p/3AQRERE1DEEQUFhYiJYtWzbapLAMUlW4du0aoqOjpa4GEREReeDy5cuNNgUHg1QV7A93vXDhAkJCQiSuTfNmNpvx3XffYdSoUVCpVFJXp9lje3gPtoX3YFt4j9zcXMTHx9fpIe03i0GqCvbLef7+/ggICJC4Ns2b2WyGTqdDQEAA/4LyAmwP78G28B5sC+9hNpsBoFGH5XCwOREREZGHGKSIiIiIPMQgRUREROQhjpEiIiJqRBaLxTGWh+pOrVY32tQG7mCQIiIiagSCICAzMxP5+flSV+WWJpfLER8fD7VaLXVVADBIERERNQp7iAoPD4dOp+OEzx6wT5idkZGBmJgYr/gNGaSIiI
gamMVicYSo0NBQqatzS2vRogWuXbuGsrIyr5huQvKLjOvXr0d8fDy0Wi0SExOxf//+arfds2cPZDKZy3LmzBmn7bZ
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"<Figure size 2000x800 with 1 Axes>"
],
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAU4AAAJ8CAYAAACRCUhSAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8pXeV/AAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOy9WXMk93UlfrL2fV+wA93oDQ30QjZFiosoa6UkT9hhz8s8zSeYeZkPMq8TEzEPM+Gw/bfkUFi2JYdkS7JISSR7IZu9EOgFO6pQ+5ZZlZVVufwf4HuZVSgAhaXRQKNOBKK70UBlVlbm/d3fveeeIxiGYWCAAQYYYIC+YXnZJzDAAAMMcNowCJwDDDDAAPvEIHAOMMAAA+wTg8A5wAADDLBPDALnAAMMMMA+MQicAwwwwAD7xCBwDjDAAAPsE4PAOcAAAwywTwwC5wADDDDAPmF72ScwwKsHwzAgyzLq9TparRZ0XUf3gJrFYoHNZoPH44HX64XFYoEgCC/pjAcYYH8QBiOXAxwlFEVBNpuFy+WC3W6Hx+OBpmkQBAEulwuKogBAR0AVRRF+vx+hUGgQPAc4FRhknAMcGQzDQC6XQzweh8vlwv3791EoFGAYBmKxGC5evIhPPvkEPp8P9XodgiDgW9/6FkKhEFZXVxEMBgeBc4BTgUHGOUBf0DQNKysrkGUZ09PTcLlc0DQN7XYbrVYLiqIgl8vh7//+7/GXf/mXGB4e5qxTEATYbDZYLBZomgZVVaFpGgzDgKqqEEURXq8X4XB4EDgHOBUYBM5XDFRf1DRt379rtVrhdrshCALXJHVdh6qqaDabWF1dhaIoiEajsNlssFqtsNvtcDqdcDgcsFqtaLfbcLvdkGUZtVoNsiyj3W53nA/VN10uF3w+H/x+/6DGOcCpwiBwvmIwDANLS0uIRCL7/t1sNot4PI5WqwVVVWEYBiwWC+x2OxwOBzRNg91uh9/vhyAI2wIdBe1qtYp2u80BkrJN+nld1znzVFUVFosFPp8PgUAAFsuA6DHAycegxvkKgrbH5XIZfr8f2WwWPp8PiqLAarVClmWcO3cOpVIJuq4jk8kgFovBYrFAVVW43W7YbLZtgfHu3buwWCy4ceMGrFYrf1/XdXz55ZdIJpNwOp2IRCIQRRGKoqBWq8Hr9aLdbkMQBIyOjmJhYQE+nw+SJCEUCmF4eBi1Wg0bGxsYHx8fZJ4DnHgMAucris3NTWSzWYyNjaFUKiGdTvMW+vz585AkCcFgENlsFna7Hbquw2KxwGq1QtO0nlv92dlZCIKAdruNdrvN39c0DRsbG/B6vfD7/bDb7dA0DalUCqqqwm63o9FoIBAIoNVqodFoQNM02Gw2RKNRWCwWOJ1OVKtVGIYxCJwDnHgMtuqvCGibfO/ePQSDQVy8eBHAVlDrVT/sFaAymQympqb2HbjMt1C1WoUoihAEgbf4drudz8EwDG4KtdttKIoCTdPgdDoRjUZhtVoHgXOAE49BxnmKQA0aXddx4cIFCIKARqOBRqPBvMhyuYwLFy5AkqR9v34wGDzQeZkDXTAYRDAY5FpnoVCAoihQVRW6rkMQBFitVjgcDng8HgQCgUFzaIBTh0HGeUpAGeXq6iqazSbC4TATzD0eDxwOB//sywpAhmGgWq2iVqvB4XDA5XLB5XLxOdlsW+s00Zjoq9lswm63I5FIDJpDA5wKDALnCQR9JKqqckbZbrdhs9ngdrvh9Xpht9sBHF+QNAyDaUVOp3NbgKMJoFKphMnJSciyjE8++QStVgtjY2NYX1/HzZs3IYoi1tbWIAgCZFnGm2++iXg8jmq1ClVVEYvFjuX9DDDAYTAInIdAu93G5uZmR4f5KKDrOrLZLAzD4IzSbrdjdHR0z0BpGAay2SxUVd13UBUEAcPDwx2/R7eHpmn48ssvUalUcPnyZTgcDrTbbSazt9tt/PVf/zW++93v4ty5c/D5fExpcrlc0HUdTqcTuq
5DlmXoug4AcLlcaDQakCQJQ0NDcDqd+7xaAwxw/BgEzkOgXq+j0WggEolwVxrY3njRNI2DKzVHem1JaXbbarXy3+l3s9kshoeHsbq6Co/Hg6GhoZ4NHwBYW1vD2NjYtuPTedG/qXFEf66trSEcDjO/kqZ7AHBjx+l0wm63w263M0eTGjqqqsJms0FRFIiiiGazya/RS+SDOKI+nw9er7cnN3SAAU4iBs2hQ0LXdfzqV79CpVLB17/+dSwtLSGbzUIQBHg8HrzzzjtYW1vDzMwMyuUy/vCHPyCRSMDv9yOVSqFarcLn82FkZAQWiwXz8/P4/ve/j3v37nEds1wu4/z589jc3EQmk4Hf798WiABAlmX88z//M775zW+i0WhgeXkZly9fxvz8PCwWCxPPQ6EQrl69irt37yIQCKBUKiGRSCAcDqPdbsPpdMLr9XJQpMbN559/DgC4ceNGRzOHzsVms0FVVUiSBEVRYBgGk98BdPy8YRg8lVSv13lqaYABTgMGGechQBknAO4au1wuHlG0Wq2w2WxMSCfuI2WTTqcTsixDEAToug6Px8OZGNU0PR4PVFVFsVhEMpnctXmiaRoKhQJarRbC4TArEdF0DgU2v9/P22sKqE6nE9lsFqFQiIMafVGgE0URFosFHo+Hg6AgCLBYLPi3f/s3XL58GSMjIxyANU2DLMtwuVxotVoQBAHJZBIbGxtwuVxoNpsIhUJwuVyoVCrQNG1bqWCAAU4iBhnnISAIAk/GEM2GgqMgCNu2zObfo0BKnWbz7xIoAANbVCRz9tYLNpsNIyMjWFpa4t+j1wbA8m69qEqqqgIAotHojoGre401Z47f+MY30Gq1EAwG4fV6sbm5iaWlJbTbbUxMTKDVasHpdDKlijQ4Y7EYHA4HIpEINjY2BgT4Y0CvXOlVv+ZH/Z4HGechQJ1mM3RdR61WQyqVwh//+Ef81//6X4+k4UF0nn6aQ5qmcfNlP6DJoYPeUJqmoVgsotlswuFwsPiHeVadgi1lvIqioNlsQhAExGIxOJ3OV/4hPk7Qwka7I2rYmR97WvSpRHPaPwMasKjX62g2mzu+Z3qmiKmyn3t/EDiPABRAK5UKms0m/H4/C1ac5hvwIKDbiQjw1CSim9dMgPd6vQgGg/D5fGfyWr1oGIaBcrkMSZLg9XrhcDg6xFqsVisPJ5CWgKIoaLVaGB4eZsrbaQIxUqhs5nQ6YbPZOsS0qfFJcog2mw2NRgM2m61vLvEgcO4TjUYD5XIZNpsN8XgcsiyjXC5DEASEw2GWZTuLMAwDjUYDpVKJb1KXywWr1dqxmuu6ztlmq9VCu92G1+tFJBI5s9fuqGEYBlqtFtLpNKamplCv1/HHP/4RiqLgwoULzGb49NNPEQgE0Gw2cfXqVUxMTKBer6NeryOZTL7st7EnKLvUdR12ux2SJKHZbCIej6NUKuHevXuQZZm5xZcuXcLi4iJzoxOJBK5evQqbzYZsNgu/3w+v17vncQeBc5+o1WpYW1vjD8fj8SAUCvW1jX7V0W63kU6nMTY2hlarhS+++AKqqiIUCqFer2NqagqNRgPpdJppV9euXYPX60WxWITFYjmQHN5ZAmWRhUIB0WiUrxdlja1Wixejf/qnf0IsFsNbb70Fv9/PJRP602KxsIQgUdKazSZkWcbQ0FDHNNpJBJUh5ufnUSqVcPHiRXzxxRfY2NjAD3/4Q/h8PmaFkJAN/dlut6HrOmw2G1qtFkRRBAAMDQ0NMs6jBK3gpVIJ7XYbwWBwR13KVxGapqFcLkPTNEQiEdjt9g5yfLvdxhdffIHFxUV88MEH8Pl8KBaL0DSNyfDUzScVJL/fD5/PB1mWUSwWeSE6DTiMYPRhjmmxWJDJZFCv1+FyuXiHQ6LStB232+1MLdM0DZIkMU3MXPMz82ndbjf8fn8Ha+KwoPrqYeB2u5mHTDsV2q2QkhftbhqNBg+NSJKERqPBQjLEECEmiM1mg9PphMfjgc/nY/ZLPx
gEzl1AN1e9XkelUoHNZkM4HD6RxXM616P4OGkxMCvBt9ttzM/Po1qtYmpqiutGAPihLRaLcLlciEQikCSJp5e6i+6
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
],
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAwMAAAHFCAYAAACuDCWjAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8pXeV/AAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdeXxM1//48ddkG9lDJCIaSS0hkiAau0osCUJqLUWR6kctVUtqizX2fW+LUmLXxVJbEUuopSiiilIkUkRJhRBkvb8//HK/RoIsQ0Lez8djHsy9557lPZOZOfece65GURQFIYQQQgghRKFjkN8VEEIIIYQQQuQP6QwIIYQQQghRSElnQAghhBBCiEJKOgNCCCGEEEIUUtIZEEIIIYQQopCSzoAQQgghhBCFlHQGhBBCCCGEKKSkMyCEEEIIIUQhJZ0BIYQQQgghCinpDAghhHhjhYWFodFosnwMGjTolZR57tw5QkNDiY6OfiX550V0dDQajYawsLD8rkqubd++ndDQ0PyuhhCFhlF+V0AIIYTIq2XLllGxYkWdbY6Ojq+krHPnzjF27Fh8fX1xcXF5JWXkVsmSJTly5Ahly5bN76rk2vbt2/n666+lQyDEayKdASGEEG88Dw8PvL2987saeZKSkoJGo8HIKPdfzVqtllq1aumxVq/Pw4cPMTMzy+9qCFHoyDQhIYQQb73vv/+e2rVrY25ujoWFBU2aNOHUqVM6aX7//Xc++ugjXFxcMDU1xcXFhY4dO3L16lU1TVhYGB9++CEADRo0UKckZUzLcXFxISgoKFP5vr6++Pr6qs8jIiLQaDSsXLmSL7/8klKlSqHVarl06RIAu3fvplGjRlhZWWFmZkbdunXZs2fPS9uZ1TSh0NBQNBoNf/zxBx9++CHW1tYUK1aM4OBgUlNTuXDhAk2bNsXS0hIXFxemTZumk2dGXVetWkVwcDAODg6Ympri4+OTKYYAmzdvpnbt2piZmWFpaYmfnx9HjhzRSZNRp5MnT9KuXTuKFi1K2bJlCQoK4uuvvwbQmfKVMSXr66+/pn79+tjb22Nubo6npyfTpk0jJSUlU7w9PDw4fvw477//PmZmZpQpU4YpU6aQnp6uk/bu3bt8+eWXlClTBq1Wi729PQEBAfz1119qmuTkZCZMmEDFihXRarXY2dnxySefcPv27Ze+JkIUdNIZEEII8cZLS0sjNTVV55Fh0qRJdOzYkUqVKvHDDz+wcuVK7t+/z/vvv8+5c+fUdNHR0VSoUIE5c+awc+dOpk6dSmxsLNWrVycuLg6A5s2bM2nSJODJD9MjR45w5MgRmjdvnqt6h4SEEBMTw8KFC9myZQv29vasWrUKf39/rKysWL58OT/88APFihWjSZMm2eoQPE/79u2pUqUK69evp0ePHsyePZuBAwfSqlUrmjdvzsaNG2nYsCFDhw5lw4YNmY4fPnw4V65cYcmSJSxZsoQbN27g6+vLlStX1DRr1qyhZcuWWFlZsXbtWr777jvi4+Px9fXl4MGDmfJs06YN5cqV48cff2ThwoWMGjWKdu3aAaixPXLkCCVLlgTg8uXLdOrUiZUrV7J161Y+/fRTpk+fTs+ePTPlffPmTTp37szHH3/M5s2badasGSEhIaxatUpNc//+ferVq8eiRYv45JNP2LJlCwsXLsTV1ZXY2FgA0tPTadmyJVOmTKFTp05s27aNKVOmEB4ejq+vL48ePcr1ayJEgaAIIYQQb6hly5YpQJaPlJQUJSYmRjEyMlK++OILnePu37+vODg4KO3bt39u3qmpqcqDBw8Uc3NzZe7cuer2H3/8UQGUffv2ZTrG2dlZ6datW6btPj4+io+Pj/p83759CqDUr19fJ11iYqJSrFgxJTAwUGd7WlqaUqVKFaVGjRoviIaiREVFKYCybNkydduYMWMUQJk5c6ZO2qpVqyqAsmHDBnVbSkqKYmdnp7Rp0yZTXatVq6akp6er26Ojox
VjY2Plf//7n1pHR0dHxdPTU0lLS1PT3b9/X7G3t1fq1KmTqU6jR4/O1IbPP/9cyc7Pk7S0NCUlJUVZsWKFYmhoqNy5c0fd5+PjowDK0aNHdY6pVKmS0qRJE/X5uHHjFEAJDw9/bjlr165VAGX9+vU6248fP64AyjfffPPSugpRkMnIgBBCiDfeihUrOH78uM7DyMiInTt3kpqaSteuXXVGDYoUKYKPjw8RERFqHg8ePGDo0KGUK1cOIyMjjIyMsLCwIDExkfPnz7+Serdt21bn+eHDh7lz5w7dunXTqW96ejpNmzbl+PHjJCYm5qqsFi1a6Dx3c3NDo9HQrFkzdZuRkRHlypXTmRqVoVOnTmg0GvW5s7MzderUYd++fQBcuHCBGzdu0KVLFwwM/u/nhYWFBW3btuW3337j4cOHL2z/y5w6dYoPPvgAW1tbDA0NMTY2pmvXrqSlpXHx4kWdtA4ODtSoUUNnW+XKlXXa9ssvv+Dq6krjxo2fW+bWrVuxsbEhMDBQ5zWpWrUqDg4OOu8hId5EcgGxEEKIN56bm1uWFxD/+++/AFSvXj3L457+0dqpUyf27NnDqFGjqF69OlZWVmg0GgICAl7ZVJCM6S/P1jdjqkxW7ty5g7m5eY7LKlasmM5zExMTzMzMKFKkSKbtCQkJmY53cHDIctvp06cB+O+//4DMbYInKzulp6cTHx+vc5FwVmmfJyYmhvfff58KFSowd+5cXFxcKFKkCMeOHePzzz/P9BrZ2tpmykOr1eqku337NqVLl35huf/++y93797FxMQky/0ZU8iEeFNJZ0AIIcRbq3jx4gD89NNPODs7PzfdvXv32Lp1K2PGjGHYsGHq9qSkJO7cuZPt8ooUKUJSUlKm7XFxcWpdnvb0mfan6zt//vznrgpUokSJbNdHn27evJnltowf3Rn/Zsy1f9qNGzcwMDCgaNGiOtufbf+LbNq0icTERDZs2KDzWkZGRmY7j2fZ2dlx7dq1F6YpXrw4tra27NixI8v9lpaWuS5fiIJAOgNCCCHeWk2aNMHIyIjLly+/cEqKRqNBURS0Wq3O9iVLlpCWlqazLSNNVqMFLi4u/PHHHzrbLl68yIULF7LsDDyrbt262NjYcO7cOfr27fvS9K/T2rVrCQ4OVn/AX716lcOHD9O1a1cAKlSoQKlSpVizZg2DBg1S0yUmJrJ+/Xp1haGXeTq+pqam6vaM/J5+jRRFYfHixbluU7NmzRg9ejR79+6lYcOGWaZp0aIF69atIy0tjZo1a+a6LCEKKukMCCGEeGu5uLgwbtw4RowYwZUrV2jatClFixbl33//5dixY5ibmzN27FisrKyoX78+06dPp3jx4ri4uLB//36+++47bGxsdPL08PAA4Ntvv8XS0pIiRYrw7rvvYmtrS5cuXfj444/p06cPbdu25erVq0ybNg07O7ts1dfCwoL58+fTrVs37ty5Q7t27bC3t+f27ducPn2a27dvs2DBAn2HKVtu3bpF69at6dGjB/fu3WPMmDEUKVKEkJAQ4MmUq2nTptG5c2datGhBz549SUpKYvr06dy9e5cpU6ZkqxxPT08Apk6dSrNmzTA0NKRy5cr4+flhYmJCx44dGTJkCI8fP2bBggXEx8fnuk0DBgzg+++/p2XLlgwbNowaNWrw6NEj9u/fT4sWLWjQoAEfffQRq1evJiAggP79+1OjRg2MjY25du0a+/bto2XLlrRu3TrXdRAiv8kFxEIIId5qISEh/PTTT1y8eJFu3brRpEkThgwZwtWrV6lfv76abs2aNTRo0IAhQ4bQpk0bfv/9d8LDw7G2ttbJ791332XOnDmcPn0aX19fqlevzpYtW4An1x1MmzaNnTt30qJFCxYsWMCCBQtwdXXNdn0//vhj9u3bx4MHD+jZsyeNGzemf//+nDx5kkaNGuknKLkwadIknJ2d+eSTT+jevTslS5Zk3759Onc77tSpE5s2beK///6jQ4cOfPLJJ1hZWb
Fv3z7q1auXrXI6derE//73P7755htq165N9erVuXHjBhUrVmT9+vXEx8fTpk0bvvjiC6pWrcq8efNy3SZLS0sOHjz
},
"metadata": {},
"output_type": "display_data"
}
],
"execution_count": 11
2025-02-12 00:21:33 +08:00
},
{
"cell_type": "code",
"id": "5bb96ca8492e74d",
"metadata": {
"ExecuteTime": {
2025-02-15 23:33:34 +08:00
"end_time": "2025-02-11T16:50:05.435066Z",
"start_time": "2025-02-11T16:50:03.988772Z"
2025-02-12 00:21:33 +08:00
}
},
"source": [
"# assign() returns a new frame, so the score column is not written into a slice\n",
"# of another DataFrame (the source of the SettingWithCopyWarning).\n",
"test_data = test_data.assign(score=light_model.predict(test_data[feature_columns]))\n",
"# Keep the highest-scoring row for each trade_date.\n",
"predictions = test_data.loc[test_data.groupby('trade_date')['score'].idxmax()]"
2025-02-15 23:33:34 +08:00
],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"C:\\Users\\liaozhaorun\\AppData\\Local\\Temp\\ipykernel_5800\\1422049760.py:1: SettingWithCopyWarning: \n",
"A value is trying to be set on a copy of a slice from a DataFrame.\n",
"Try using .loc[row_indexer,col_indexer] = value instead\n",
"\n",
"See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
" test_data['score'] = light_model.predict(test_data[feature_columns])\n"
]
}
],
"execution_count": 12
2025-02-12 00:21:33 +08:00
},
{
"cell_type": "code",
"id": "5d1522a7538db91b",
"metadata": {
"ExecuteTime": {
2025-02-15 23:33:34 +08:00
"end_time": "2025-02-11T16:50:05.530654Z",
"start_time": "2025-02-11T16:50:05.439067Z"
2025-02-12 00:21:33 +08:00
}
},
"source": [
"# The file is named .tsv, so write it tab-separated (to_csv defaults to ',').\n",
"predictions[['trade_date', 'score', 'ts_code']].to_csv('predictions.tsv', sep='\\t', index=False)"
2025-02-15 23:33:34 +08:00
],
"outputs": [],
"execution_count": 13
2025-02-12 00:21:33 +08:00
},
{
"cell_type": "code",
"id": "b427ce41-9739-4e9e-bea8-5f2551fec5d7",
2025-02-15 23:33:34 +08:00
"metadata": {
"ExecuteTime": {
"end_time": "2025-02-11T16:50:05.576808Z",
"start_time": "2025-02-11T16:50:05.563061Z"
}
},
"source": [],
2025-02-12 00:21:33 +08:00
"outputs": [],
2025-02-15 23:33:34 +08:00
"execution_count": null
2025-02-12 00:21:33 +08:00
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.19"
}
},
"nbformat": 4,
"nbformat_minor": 5
}