Files
NewStock/main/train/V1-copy.ipynb
2025-04-28 11:02:52 +08:00

897 lines
41 KiB
Plaintext
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
{
"cells": [
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-02-09T14:52:54.170824Z",
"start_time": "2025-02-09T14:52:53.544850Z"
}
},
"cell_type": "code",
"source": [
"%load_ext autoreload\n",
"%autoreload 2\n",
"\n",
"from code.utils.utils import read_and_merge_h5_data"
],
"id": "79a7758178bafdd3",
"outputs": [],
"execution_count": 1
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-02-09T14:53:36.873700Z",
"start_time": "2025-02-09T14:52:54.170824Z"
}
},
"cell_type": "code",
"source": [
"# Each (label, path, key, columns) entry is read and merged into one frame;\n",
"# a single loop replaces four copy-pasted read-and-merge blocks. The merge\n",
"# order and the progress prints are identical to the original.\n",
"datasets = [\n",
"    ('daily data', '../../data/daily_data.h5', 'daily_data',\n",
"     ['ts_code', 'trade_date', 'open', 'close', 'high', 'low', 'vol']),\n",
"    ('daily basic', '../../data/daily_basic.h5', 'daily_basic_with_st',\n",
"     ['ts_code', 'trade_date', 'turnover_rate', 'pe_ttm', 'circ_mv', 'volume_ratio', 'is_st']),\n",
"    ('stk limit', '../../data/stk_limit.h5', 'stk_limit',\n",
"     ['ts_code', 'trade_date', 'pre_close', 'up_limit', 'down_limit']),\n",
"    ('money flow', '../../data/money_flow.h5', 'money_flow',\n",
"     ['ts_code', 'trade_date', 'buy_sm_vol', 'sell_sm_vol', 'buy_lg_vol', 'sell_lg_vol',\n",
"      'buy_elg_vol', 'sell_elg_vol', 'net_mf_vol']),\n",
"]\n",
"\n",
"df = None\n",
"for label, path, key, cols in datasets:\n",
"    print(label)\n",
"    df = read_and_merge_h5_data(path, key=key, columns=cols, df=df)"
],
"id": "a79cafb06a7e0e43",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"daily data\n",
"daily basic\n",
"stk limit\n",
"money flow\n"
]
}
],
"execution_count": 2
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-02-09T14:53:37.426404Z",
"start_time": "2025-02-09T14:53:36.955552Z"
}
},
"cell_type": "code",
"source": "origin_columns = list(df.columns)  # snapshot of raw input columns, used later to exclude them from features",
"id": "c4e9e1d31da6dba6",
"outputs": [],
"execution_count": 3
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-02-09T14:53:38.164112Z",
"start_time": "2025-02-09T14:53:38.070007Z"
}
},
"cell_type": "code",
"source": [
"import numpy as np\n",
"import talib\n",
"\n",
"\n",
"def get_technical_factor(df):\n",
"    \"\"\"Add price/volume technical-indicator columns in place and return df.\n",
"\n",
"    NOTE(review): everything below is computed over the whole frame, which\n",
"    concatenates many ts_code series back to back, so rolling/shift based\n",
"    values bleed across stock boundaries -- confirm whether a\n",
"    groupby('ts_code') is intended.\n",
"    NOTE(review): the std_return_* features use pct_change().shift(-1),\n",
"    i.e. the NEXT day's return, which looks ahead of the bar it is\n",
"    attached to; verify this leakage is intentional.\n",
"    \"\"\"\n",
"    # Upper/lower candle shadows, normalised by the close.\n",
"    body_top = df[['close', 'open']].max(axis=1)\n",
"    body_bottom = df[['close', 'open']].min(axis=1)\n",
"    df['up'] = (df['high'] - body_top) / df['close']\n",
"    df['down'] = (body_bottom - df['low']) / df['close']\n",
"\n",
"    # Average true range over two horizons.\n",
"    for period in (14, 6):\n",
"        df[f'atr_{period}'] = talib.ATR(df['high'], df['low'], df['close'], timeperiod=period)\n",
"\n",
"    # On-balance volume and its gap to a 6-day moving average.\n",
"    df['obv'] = talib.OBV(df['close'], df['vol'])\n",
"    df['maobv_6'] = talib.SMA(df['obv'], timeperiod=6)\n",
"    df['obv-maobv_6'] = df['obv'] - df['maobv_6']\n",
"\n",
"    # Short-horizon RSI family.\n",
"    for period in (3, 6, 9):\n",
"        df[f'rsi_{period}'] = talib.RSI(df['close'], timeperiod=period)\n",
"\n",
"    # Trailing returns over 10 and 20 bars.\n",
"    for horizon in (10, 20):\n",
"        df[f'return_{horizon}'] = df['close'] / df['close'].shift(horizon) - 1\n",
"\n",
"    # 5-day mean close relative to today's close.\n",
"    df['avg_close_5'] = df['close'].rolling(window=5).mean() / df['close']\n",
"\n",
"    # Rolling std of next-day returns over several windows (see leakage note).\n",
"    next_ret = df['close'].pct_change().shift(-1)\n",
"    for window in (5, 15, 25, 90):\n",
"        df[f'std_return_{window}'] = next_ret.rolling(window=window).std()\n",
"    df['std_return_90_2'] = df['close'].shift(10).pct_change().shift(-1).rolling(window=90).std()\n",
"\n",
"    # Volatility ratios and the long-horizon volatility change.\n",
"    df['std_return_5 / std_return_90'] = df['std_return_5'] / df['std_return_90']\n",
"    df['std_return_5 / std_return_25'] = df['std_return_5'] / df['std_return_25']\n",
"    df['std_return_90 - std_return_90_2'] = df['std_return_90'] - df['std_return_90_2']\n",
"    return df\n",
"\n",
"\n",
"def get_act_factor(df):\n",
"    \"\"\"Add EMA-slope ('act') factors and their daily cross-sectional ranks.\"\"\"\n",
"    # EMAs of the close over four horizons.\n",
"    for span in (5, 13, 20, 60):\n",
"        df[f'ema_{span}'] = talib.EMA(df['close'], timeperiod=span)\n",
"\n",
"    # Slope of each EMA mapped through arctan; 57.3 approximates degrees\n",
"    # per radian, and each horizon gets its own rescaling divisor.\n",
"    slope_specs = [(1, 'ema_5', 50), (2, 'ema_13', 40), (3, 'ema_20', 21), (4, 'ema_60', 10)]\n",
"    for idx, ema_col, divisor in slope_specs:\n",
"        pct_move = (df[ema_col] / df[ema_col].shift(1) - 1) * 100\n",
"        df[f'act_factor{idx}'] = np.arctan(pct_move) * 57.3 / divisor\n",
"\n",
"    # Composites: sum of all four, and the normalised fast/slow spread.\n",
"    df['act_factor5'] = df['act_factor1'] + df['act_factor2'] + df['act_factor3'] + df['act_factor4']\n",
"    df['act_factor6'] = (df['act_factor1'] - df['act_factor2']) / np.sqrt(\n",
"        df['act_factor1'] ** 2 + df['act_factor2'] ** 2)\n",
"\n",
"    # Per-trading-day percentile rank (descending) of factors 1..3.\n",
"    for idx in (1, 2, 3):\n",
"        df[f'rank_act_factor{idx}'] = df.groupby('trade_date')[f'act_factor{idx}'].rank(ascending=False, pct=True)\n",
"\n",
"    return df\n",
"\n",
"\n",
"def get_money_flow_factor(df):\n",
"    \"\"\"Add money-flow ratio factors, each normalised by net_mf_vol.\n",
"\n",
"    NOTE(review): net_mf_vol can be zero, so these ratios may produce\n",
"    inf/-inf; a later cell replaces inf with NaN and drops those rows.\n",
"    \"\"\"\n",
"    ratio_specs = {\n",
"        'active_buy_volume_large': df['buy_lg_vol'],\n",
"        'active_buy_volume_big': df['buy_elg_vol'],\n",
"        'active_buy_volume_small': df['buy_sm_vol'],\n",
"        'buy_lg_vol - sell_lg_vol': df['buy_lg_vol'] - df['sell_lg_vol'],\n",
"        'buy_elg_vol - sell_elg_vol': df['buy_elg_vol'] - df['sell_elg_vol'],\n",
"    }\n",
"    for name, numerator in ratio_specs.items():\n",
"        df[name] = numerator / df['net_mf_vol']\n",
"\n",
"    return df\n",
"\n",
"\n",
"def get_alpha_factor(df):\n",
"    \"\"\"Add Alpha101-style factor columns and return df.\n",
"\n",
"    Fix: the original called Series.rank(axis=1), which is invalid for a\n",
"    1-D Series (pandas rejects axis=1). The 'rank' in these alpha formulas\n",
"    is the daily cross-sectional rank, so it is computed per trade_date\n",
"    here, matching the rank_act_factor* columns in get_act_factor.\n",
"    NOTE(review): this function is not called anywhere in this notebook.\n",
"    \"\"\"\n",
"    # alpha_022: 5-day close momentum (raw difference).\n",
"    df['alpha_022'] = df['close'] - df['close'].shift(5)\n",
"\n",
"    # alpha_003: candle body position within the day's range.\n",
"    df['alpha_003'] = (df['close'] - df['open']) / (df['high'] - df['low'])\n",
"\n",
"    # alpha_007: cross-sectional rank of the 5-day close/volume correlation.\n",
"    corr_5 = df['close'].rolling(5).corr(df['vol'])\n",
"    df['alpha_007'] = corr_5.groupby(df['trade_date']).rank(pct=True)\n",
"\n",
"    # alpha_013: cross-sectional rank of (5-day sum - 20-day sum) of close.\n",
"    sum_gap = df['close'].rolling(5).sum() - df['close'].rolling(20).sum()\n",
"    df['alpha_013'] = sum_gap.groupby(df['trade_date']).rank(pct=True)\n",
"    return df\n",
"\n",
"\n",
"def get_future_data(df):\n",
"    \"\"\"Add forward-looking label candidates (future returns / act factors).\n",
"\n",
"    Columns are named 'future_*' and are excluded from the model features\n",
"    later in this notebook.\n",
"    NOTE(review): shifts run over the whole frame, so rows near a ts_code\n",
"    boundary pick up the next stock's values -- confirm a per-stock\n",
"    groupby is not required.\n",
"    \"\"\"\n",
"    next_open = df['open'].shift(-1)\n",
"    # Returns under several entry/exit conventions.\n",
"    df['future_return1'] = (df['close'].shift(-1) - df['close']) / df['close']\n",
"    df['future_return2'] = (df['open'].shift(-2) - next_open) / next_open\n",
"    df['future_return3'] = (df['close'].shift(-2) - df['close'].shift(-1)) / df['close'].shift(-1)\n",
"    # future_return4..7: close h bars ahead vs the next open, h in (2,5,10,20).\n",
"    for i, horizon in enumerate((2, 5, 10, 20), start=4):\n",
"        df[f'future_return{i}'] = (df['close'].shift(-horizon) - next_open) / next_open\n",
"    # Close-to-close returns 1..5 days ahead (future_close1 duplicates future_return1).\n",
"    for k in range(1, 6):\n",
"        df[f'future_close{k}'] = (df['close'].shift(-k) - df['close']) / df['close']\n",
"    # act_factor1..3 shifted 1..5 days into the future.\n",
"    for factor in (1, 2, 3):\n",
"        for k in range(1, 6):\n",
"            df[f'future_af{factor}{k}'] = df[f'act_factor{factor}'].shift(-k)\n",
"\n",
"    return df\n"
],
"id": "a735bc02ceb4d872",
"outputs": [],
"execution_count": 4
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-02-09T14:53:49.153376Z",
"start_time": "2025-02-09T14:53:38.164112Z"
}
},
"cell_type": "code",
"source": [
"# Apply each factor stage in order; every function mutates and returns df.\n",
"for add_factors in (get_technical_factor, get_act_factor, get_money_flow_factor, get_future_data):\n",
"    df = add_factors(df)\n",
"# df = df.drop(columns=origin_columns)\n",
"\n",
"print(df.info())"
],
"id": "53f86ddc0677a6d7",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<class 'pandas.core.frame.DataFrame'>\n",
"RangeIndex: 8364308 entries, 0 to 8364307\n",
"Data columns (total 83 columns):\n",
" # Column Dtype \n",
"--- ------ ----- \n",
" 0 ts_code object \n",
" 1 trade_date datetime64[ns]\n",
" 2 open float64 \n",
" 3 close float64 \n",
" 4 high float64 \n",
" 5 low float64 \n",
" 6 vol float64 \n",
" 7 is_st object \n",
" 8 up_limit float64 \n",
" 9 down_limit float64 \n",
" 10 buy_sm_vol float64 \n",
" 11 sell_sm_vol float64 \n",
" 12 buy_lg_vol float64 \n",
" 13 sell_lg_vol float64 \n",
" 14 buy_elg_vol float64 \n",
" 15 sell_elg_vol float64 \n",
" 16 net_mf_vol float64 \n",
" 17 up float64 \n",
" 18 down float64 \n",
" 19 atr_14 float64 \n",
" 20 atr_6 float64 \n",
" 21 obv float64 \n",
" 22 maobv_6 float64 \n",
" 23 obv-maobv_6 float64 \n",
" 24 rsi_3 float64 \n",
" 25 rsi_6 float64 \n",
" 26 rsi_9 float64 \n",
" 27 return_10 float64 \n",
" 28 return_20 float64 \n",
" 29 avg_close_5 float64 \n",
" 30 std_return_5 float64 \n",
" 31 std_return_15 float64 \n",
" 32 std_return_25 float64 \n",
" 33 std_return_90 float64 \n",
" 34 std_return_90_2 float64 \n",
" 35 std_return_5 / std_return_90 float64 \n",
" 36 std_return_5 / std_return_25 float64 \n",
" 37 std_return_90 - std_return_90_2 float64 \n",
" 38 ema_5 float64 \n",
" 39 ema_13 float64 \n",
" 40 ema_20 float64 \n",
" 41 ema_60 float64 \n",
" 42 act_factor1 float64 \n",
" 43 act_factor2 float64 \n",
" 44 act_factor3 float64 \n",
" 45 act_factor4 float64 \n",
" 46 act_factor5 float64 \n",
" 47 act_factor6 float64 \n",
" 48 rank_act_factor1 float64 \n",
" 49 rank_act_factor2 float64 \n",
" 50 rank_act_factor3 float64 \n",
" 51 active_buy_volume_large float64 \n",
" 52 active_buy_volume_big float64 \n",
" 53 active_buy_volume_small float64 \n",
" 54 buy_lg_vol - sell_lg_vol float64 \n",
" 55 buy_elg_vol - sell_elg_vol float64 \n",
" 56 future_return1 float64 \n",
" 57 future_return2 float64 \n",
" 58 future_return3 float64 \n",
" 59 future_return4 float64 \n",
" 60 future_return5 float64 \n",
" 61 future_return6 float64 \n",
" 62 future_return7 float64 \n",
" 63 future_close1 float64 \n",
" 64 future_close2 float64 \n",
" 65 future_close3 float64 \n",
" 66 future_close4 float64 \n",
" 67 future_close5 float64 \n",
" 68 future_af11 float64 \n",
" 69 future_af12 float64 \n",
" 70 future_af13 float64 \n",
" 71 future_af14 float64 \n",
" 72 future_af15 float64 \n",
" 73 future_af21 float64 \n",
" 74 future_af22 float64 \n",
" 75 future_af23 float64 \n",
" 76 future_af24 float64 \n",
" 77 future_af25 float64 \n",
" 78 future_af31 float64 \n",
" 79 future_af32 float64 \n",
" 80 future_af33 float64 \n",
" 81 future_af34 float64 \n",
" 82 future_af35 float64 \n",
"dtypes: datetime64[ns](1), float64(80), object(2)\n",
"memory usage: 5.2+ GB\n",
"None\n"
]
}
],
"execution_count": 5
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-02-09T14:55:28.712343Z",
"start_time": "2025-02-09T14:53:49.279168Z"
}
},
"cell_type": "code",
"source": [
"def filter_data(df):\n",
"    \"\"\"Restrict df to the tradable universe used for training.\n",
"\n",
"    Keeps the 1000 highest act_factor3 stocks per trading day, drops ST\n",
"    stocks, and excludes ChiNext (30), STAR Market (68) and Beijing (8)\n",
"    ticker prefixes.\n",
"    Fix: the original applied the is_st filter twice on consecutive lines;\n",
"    the duplicate (a no-op) is removed, and the three startswith filters\n",
"    are merged into one tuple-based call.\n",
"    \"\"\"\n",
"    df = df.groupby('trade_date').apply(lambda x: x.nlargest(1000, 'act_factor3'))\n",
"    df = df[df['is_st'] == False]\n",
"    df = df[~df['ts_code'].str.startswith(('30', '68', '8'))]\n",
"    df = df.reset_index(drop=True)\n",
"    return df\n",
"\n",
"\n",
"df = filter_data(df)\n",
"print(df.info())"
],
"id": "dbe2fd8021b9417f",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<class 'pandas.core.frame.DataFrame'>\n",
"RangeIndex: 1136157 entries, 0 to 1136156\n",
"Data columns (total 83 columns):\n",
" # Column Non-Null Count Dtype \n",
"--- ------ -------------- ----- \n",
" 0 ts_code 1136157 non-null object \n",
" 1 trade_date 1136157 non-null datetime64[ns]\n",
" 2 open 1136157 non-null float64 \n",
" 3 close 1136157 non-null float64 \n",
" 4 high 1136157 non-null float64 \n",
" 5 low 1136157 non-null float64 \n",
" 6 vol 1136157 non-null float64 \n",
" 7 is_st 1136157 non-null object \n",
" 8 up_limit 1135878 non-null float64 \n",
" 9 down_limit 1135878 non-null float64 \n",
" 10 buy_sm_vol 1135663 non-null float64 \n",
" 11 sell_sm_vol 1135663 non-null float64 \n",
" 12 buy_lg_vol 1135663 non-null float64 \n",
" 13 sell_lg_vol 1135663 non-null float64 \n",
" 14 buy_elg_vol 1135663 non-null float64 \n",
" 15 sell_elg_vol 1135663 non-null float64 \n",
" 16 net_mf_vol 1135663 non-null float64 \n",
" 17 up 1136157 non-null float64 \n",
" 18 down 1136157 non-null float64 \n",
" 19 atr_14 1136157 non-null float64 \n",
" 20 atr_6 1136157 non-null float64 \n",
" 21 obv 1136157 non-null float64 \n",
" 22 maobv_6 1136157 non-null float64 \n",
" 23 obv-maobv_6 1136157 non-null float64 \n",
" 24 rsi_3 1136157 non-null float64 \n",
" 25 rsi_6 1136157 non-null float64 \n",
" 26 rsi_9 1136157 non-null float64 \n",
" 27 return_10 1136157 non-null float64 \n",
" 28 return_20 1136157 non-null float64 \n",
" 29 avg_close_5 1136157 non-null float64 \n",
" 30 std_return_5 1136157 non-null float64 \n",
" 31 std_return_15 1136157 non-null float64 \n",
" 32 std_return_25 1136157 non-null float64 \n",
" 33 std_return_90 1136131 non-null float64 \n",
" 34 std_return_90_2 1136129 non-null float64 \n",
" 35 std_return_5 / std_return_90 1136131 non-null float64 \n",
" 36 std_return_5 / std_return_25 1136157 non-null float64 \n",
" 37 std_return_90 - std_return_90_2 1136129 non-null float64 \n",
" 38 ema_5 1136157 non-null float64 \n",
" 39 ema_13 1136157 non-null float64 \n",
" 40 ema_20 1136157 non-null float64 \n",
" 41 ema_60 1136153 non-null float64 \n",
" 42 act_factor1 1136157 non-null float64 \n",
" 43 act_factor2 1136157 non-null float64 \n",
" 44 act_factor3 1136157 non-null float64 \n",
" 45 act_factor4 1136152 non-null float64 \n",
" 46 act_factor5 1136152 non-null float64 \n",
" 47 act_factor6 1136157 non-null float64 \n",
" 48 rank_act_factor1 1136157 non-null float64 \n",
" 49 rank_act_factor2 1136157 non-null float64 \n",
" 50 rank_act_factor3 1136157 non-null float64 \n",
" 51 active_buy_volume_large 1135659 non-null float64 \n",
" 52 active_buy_volume_big 1135636 non-null float64 \n",
" 53 active_buy_volume_small 1135663 non-null float64 \n",
" 54 buy_lg_vol - sell_lg_vol 1135660 non-null float64 \n",
" 55 buy_elg_vol - sell_elg_vol 1135640 non-null float64 \n",
" 56 future_return1 1136157 non-null float64 \n",
" 57 future_return2 1136157 non-null float64 \n",
" 58 future_return3 1136157 non-null float64 \n",
" 59 future_return4 1136157 non-null float64 \n",
" 60 future_return5 1136157 non-null float64 \n",
" 61 future_return6 1136157 non-null float64 \n",
" 62 future_return7 1136157 non-null float64 \n",
" 63 future_close1 1136157 non-null float64 \n",
" 64 future_close2 1136157 non-null float64 \n",
" 65 future_close3 1136157 non-null float64 \n",
" 66 future_close4 1136157 non-null float64 \n",
" 67 future_close5 1136157 non-null float64 \n",
" 68 future_af11 1136157 non-null float64 \n",
" 69 future_af12 1136157 non-null float64 \n",
" 70 future_af13 1136157 non-null float64 \n",
" 71 future_af14 1136157 non-null float64 \n",
" 72 future_af15 1136157 non-null float64 \n",
" 73 future_af21 1136157 non-null float64 \n",
" 74 future_af22 1136157 non-null float64 \n",
" 75 future_af23 1136157 non-null float64 \n",
" 76 future_af24 1136157 non-null float64 \n",
" 77 future_af25 1136157 non-null float64 \n",
" 78 future_af31 1136157 non-null float64 \n",
" 79 future_af32 1136157 non-null float64 \n",
" 80 future_af33 1136157 non-null float64 \n",
" 81 future_af34 1136157 non-null float64 \n",
" 82 future_af35 1136157 non-null float64 \n",
"dtypes: datetime64[ns](1), float64(80), object(2)\n",
"memory usage: 719.5+ MB\n",
"None\n"
]
}
],
"execution_count": 6
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-02-09T15:00:45.828404Z",
"start_time": "2025-02-09T15:00:45.294830Z"
}
},
"cell_type": "code",
"source": [
"def remove_outliers_iqr(series, lower_quantile=0.05, upper_quantile=0.95, threshold=1.5):\n",
"    \"\"\"Return a boolean mask that is True where `series` lies inside the\n",
"    Tukey-style fence [Q1 - threshold*IQR, Q3 + threshold*IQR], with the\n",
"    'quartiles' taken at configurable quantiles (default 5% / 95%).\"\"\"\n",
"    q_low = series.quantile(lower_quantile)\n",
"    q_high = series.quantile(upper_quantile)\n",
"    spread = q_high - q_low\n",
"    lower_bound = q_low - threshold * spread\n",
"    upper_bound = q_high + threshold * spread\n",
"    # between() is inclusive on both ends, same as >= lower and <= upper.\n",
"    return series.between(lower_bound, upper_bound)\n",
"\n",
"\n",
"def neutralize_labels(labels, features, feature_columns, z_threshold=3, method='regression'):\n",
"    \"\"\"Return an outlier mask for `labels`.\n",
"\n",
"    The features/feature_columns/z_threshold/method parameters are kept\n",
"    for interface compatibility but are currently unused -- the function\n",
"    only applies the IQR filter to the labels.\n",
"    \"\"\"\n",
"    return remove_outliers_iqr(labels)\n",
"\n",
"\n",
"# Time-based split. Fix: the original used '<=' and '>=' with the SAME\n",
"# cutoff, so a row dated exactly 2023-01-01 would land in both sets; the\n",
"# test side now uses a strict '>' to keep the split disjoint. (2023-01-01\n",
"# is not a trading day in the recorded data, so outputs are unchanged.)\n",
"train_data = df[df['trade_date'] <= '2023-01-01']\n",
"test_data = df[df['trade_date'] > '2023-01-01']\n",
"\n",
"# Model features: drop identifiers/label, anything forward-looking\n",
"# ('future'), score columns, and the raw input columns.\n",
"feature_columns = [col for col in df.columns if col not in ['trade_date',\n",
"                                                            'ts_code',\n",
"                                                            'label']]\n",
"feature_columns = [col for col in feature_columns if 'future' not in col]\n",
"feature_columns = [col for col in feature_columns if 'score' not in col]\n",
"feature_columns = [col for col in feature_columns if col not in origin_columns]\n",
"\n",
"# Optional label-outlier filtering, currently disabled:\n",
"# for column in [column for column in train_data.columns if 'future' in column]:\n",
"#     label_index = neutralize_labels(train_data[column], train_data, feature_columns, z_threshold=3, method='regression')\n",
"#     train_data = train_data[label_index]\n",
"#     label_index = neutralize_labels(test_data[column], test_data, feature_columns, z_threshold=3, method='regression')\n",
"#     test_data = test_data[label_index]\n",
"\n",
"print(len(train_data))\n",
"print(len(test_data))"
],
"id": "5f3d9aece75318cd",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['up', 'down', 'atr_14', 'atr_6', 'obv', 'maobv_6', 'obv-maobv_6', 'rsi_3', 'rsi_6', 'rsi_9', 'return_10', 'return_20', 'avg_close_5', 'std_return_5', 'std_return_15', 'std_return_25', 'std_return_90', 'std_return_90_2', 'std_return_5 / std_return_90', 'std_return_5 / std_return_25', 'std_return_90 - std_return_90_2', 'ema_5', 'ema_13', 'ema_20', 'ema_60', 'act_factor1', 'act_factor2', 'act_factor3', 'act_factor4', 'act_factor5', 'act_factor6', 'rank_act_factor1', 'rank_act_factor2', 'rank_act_factor3', 'active_buy_volume_large', 'active_buy_volume_big', 'active_buy_volume_small', 'buy_lg_vol - sell_lg_vol', 'buy_elg_vol - sell_elg_vol']\n"
]
}
],
"execution_count": 19
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-02-09T14:56:05.319915Z",
"start_time": "2025-02-09T14:56:03.355725Z"
}
},
"cell_type": "code",
"source": [
"def get_qcuts(series, quantiles):\n",
"    \"\"\"Bucket `series` into `quantiles` equal-frequency bins and return the\n",
"    bin label of the LAST element of the window.\n",
"\n",
"    Fix: the original returned q[-1], which is label-based indexing and\n",
"    raises KeyError unless the index happens to contain -1; positional\n",
"    q.iloc[-1] is what 'last element of the window' requires.\n",
"    NOTE(review): this helper is not called anywhere in this notebook, and\n",
"    `pd` is only imported in a later cell -- calling it before that cell\n",
"    runs would raise NameError.\n",
"    \"\"\"\n",
"    q = pd.qcut(series, q=quantiles, labels=False, duplicates='drop')\n",
"    return q.iloc[-1]\n",
"\n",
"\n",
"# Windowing parameters for get_qcuts (currently unused in this notebook).\n",
"window = 5\n",
"quantiles = 20\n",
"\n",
"\n",
"def get_label(df):\n",
"    \"\"\"Regression target: act_factor1 three days ahead minus today's value.\"\"\"\n",
"    # Alternative target kept from the original: df['future_close3'].\n",
"    return df['future_af13'] - df['act_factor1']\n",
"\n",
"\n",
"# Fix: train_data/test_data are slices of df, so assigning a new column\n",
"# directly is a chained assignment on a view (SettingWithCopyWarning,\n",
"# potentially lost writes). Take explicit copies first.\n",
"train_data, test_data = train_data.copy(), test_data.copy()\n",
"\n",
"train_data['label'], test_data['label'] = get_label(train_data), get_label(test_data)\n",
"\n",
"# Drop rows without a label, then rows with any NaN/inf in any column.\n",
"train_data, test_data = train_data.dropna(subset=['label']), test_data.dropna(subset=['label'])\n",
"train_data, test_data = train_data.replace([np.inf, -np.inf], np.nan).dropna(), test_data.replace([np.inf, -np.inf],\n",
"                                                                                                  np.nan).dropna()\n",
"train_data, test_data = train_data.reset_index(drop=True), test_data.reset_index(drop=True)\n",
"\n",
"print(len(train_data))\n",
"print(len(test_data))"
],
"id": "f4f16d63ad18d1bc",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"875004\n",
"最小日期: 2017-01-03\n",
"最大日期: 2022-12-30\n",
"260581\n",
"最小日期: 2023-01-03\n",
"最大日期: 2025-01-27\n"
]
}
],
"execution_count": 13
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-02-09T14:56:05.480695Z",
"start_time": "2025-02-09T14:56:05.367238Z"
}
},
"cell_type": "code",
"source": [
"import lightgbm as lgb\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import optuna\n",
"from sklearn.model_selection import KFold\n",
"from sklearn.metrics import mean_absolute_error\n",
"import os\n",
"import json\n",
"import pickle\n",
"import hashlib\n",
"\n",
"\n",
"def objective(trial, X, y, num_boost_round, params):\n",
"    \"\"\"Optuna objective: mean MAE over 5-fold (unshuffled) CV of an\n",
"    LGBMRegressor whose hyper-parameters are sampled from `trial`.\n",
"\n",
"    NOTE(review): `params` is accepted but never used in this function.\n",
"    NOTE(review): the constructor receives both n_estimators (sampled) and\n",
"    num_boost_round (a LightGBM alias of n_estimators) -- LightGBM warns\n",
"    about conflicting aliases; confirm which value is intended to win.\n",
"    \"\"\"\n",
"    # Reset indices so positional CV split indices line up with y lookups.\n",
"    X, y = X.reset_index(drop=True), y.reset_index(drop=True)\n",
"    # Hyper-parameter search space.\n",
"    param_grid = {\n",
"        \"n_estimators\": trial.suggest_categorical(\"n_estimators\", [10000]),\n",
"        \"learning_rate\": trial.suggest_float(\"learning_rate\", 0.01, 0.3),\n",
"        \"num_leaves\": trial.suggest_int(\"num_leaves\", 20, 3000, step=25),\n",
"        \"max_depth\": trial.suggest_int(\"max_depth\", 3, 16),\n",
"        \"min_data_in_leaf\": trial.suggest_int(\"min_data_in_leaf\", 200, 10000, step=100),\n",
"        \"lambda_l1\": trial.suggest_int(\"lambda_l1\", 0, 100, step=5),\n",
"        \"lambda_l2\": trial.suggest_int(\"lambda_l2\", 0, 100, step=5),\n",
"        \"min_gain_to_split\": trial.suggest_float(\"min_gain_to_split\", 0, 15),\n",
"        \"bagging_fraction\": trial.suggest_float(\"bagging_fraction\", 0.2, 0.95, step=0.1),\n",
"        \"bagging_freq\": trial.suggest_categorical(\"bagging_freq\", [1]),\n",
"        \"feature_fraction\": trial.suggest_float(\"feature_fraction\", 0.2, 0.95, step=0.1),\n",
"        \"random_state\": 1,\n",
"        \"objective\": 'regression',\n",
"        'verbosity': -1\n",
"    }\n",
"    # 5-fold cross-validation, unshuffled to keep the row (time) ordering.\n",
"    cv = KFold(n_splits=5, shuffle=False)\n",
"\n",
"    cv_scores = np.empty(5)\n",
"    for idx, (train_idx, test_idx) in enumerate(cv.split(X, y)):\n",
"        X_train, X_test = X.iloc[train_idx], X.iloc[test_idx]\n",
"        y_train, y_test = y[train_idx], y[test_idx]\n",
"\n",
"        # Fit LightGBM with early stopping on the held-out fold.\n",
"        model = lgb.LGBMRegressor(**param_grid, num_boost_round=num_boost_round)\n",
"        model.fit(\n",
"            X_train,\n",
"            y_train,\n",
"            eval_set=[(X_test, y_test)],\n",
"            eval_metric=\"l2\",\n",
"            callbacks=[\n",
"                # LightGBMPruningCallback(trial, \"l2\"),\n",
"                lgb.early_stopping(50, first_metric_only=True),\n",
"                lgb.log_evaluation(period=-1)\n",
"            ],\n",
"        )\n",
"        # Score the fold; note the selection metric here is MAE even though\n",
"        # the early-stopping eval_metric above is l2.\n",
"        preds = model.predict(X_test)\n",
"        cv_scores[idx] = mean_absolute_error(y_test, preds)\n",
"\n",
"    return np.mean(cv_scores)\n",
"\n",
"def generate_key(params, feature_columns, num_boost_round):\n",
"    \"\"\"Build a deterministic MD5 cache key from the training configuration.\"\"\"\n",
"    # Sorted-key JSON makes the serialisation order-independent.\n",
"    key_payload = json.dumps(\n",
"        {\n",
"            \"params\": params,\n",
"            \"feature_columns\": feature_columns,\n",
"            \"num_boost_round\": num_boost_round,\n",
"        },\n",
"        sort_keys=True,\n",
"    )\n",
"    return hashlib.md5(key_payload.encode('utf-8')).hexdigest()\n",
"\n",
"def train_light_model(df, params, feature_columns, callbacks, evals,\n",
"                      print_feature_importance=True, num_boost_round=100,\n",
"                      use_optuna=False):\n",
"    \"\"\"Train a LightGBM regressor with a date-based train/validation split.\n",
"\n",
"    The last 10% of unique trade dates form the validation set. If a model\n",
"    cached in light_model.pkl matches (params, feature_columns,\n",
"    num_boost_round), it is returned instead of retraining.\n",
"\n",
"    Fix: the original selected train/validation rows with .iloc fed with\n",
"    INDEX LABELS collected from the sorted frame. After sort_values the\n",
"    labels are a permutation of positions, so .iloc silently picked the\n",
"    wrong rows and the validation set was NOT the last 10% of dates.\n",
"    Boolean-mask .loc selection is used here instead. The redundant second\n",
"    sort_values('trade_date') is also dropped (pandas sorts are stable),\n",
"    and val_date_count is floored at 1 so tiny inputs cannot make the\n",
"    validation slice swallow every date.\n",
"    \"\"\"\n",
"    cache_file = 'light_model.pkl'\n",
"    cache_key = generate_key(params, feature_columns, num_boost_round)\n",
"\n",
"    # Return a cached model when its key matches the current configuration.\n",
"    if os.path.exists(cache_file):\n",
"        try:\n",
"            with open(cache_file, 'rb') as f:\n",
"                cache_data = pickle.load(f)\n",
"            if cache_data.get('key') == cache_key:\n",
"                print(\"加载缓存模型...\")\n",
"                return cache_data.get('model')\n",
"            else:\n",
"                print(\"缓存模型的参数与当前参数不匹配,重新训练模型。\")\n",
"        except Exception as e:\n",
"            print(f\"加载缓存失败: {e},重新训练模型。\")\n",
"    else:\n",
"        print(\"未发现缓存模型,开始训练新模型。\")\n",
"    # Sort by date ascending, label descending within a date.\n",
"    df_sorted = df.sort_values(by=['trade_date', 'label'], ascending=[True, False])\n",
"    unique_dates = df_sorted['trade_date'].unique()\n",
"    val_date_count = max(1, int(len(unique_dates) * 0.1))\n",
"    val_dates = unique_dates[-val_date_count:]\n",
"    val_mask = df_sorted['trade_date'].isin(val_dates)\n",
"\n",
"    # Label-safe row selection (see fix note in the docstring).\n",
"    train_df = df_sorted.loc[~val_mask]\n",
"    val_df = df_sorted.loc[val_mask]\n",
"\n",
"    X_train = train_df[feature_columns]\n",
"    y_train = train_df['label']\n",
"\n",
"    X_val = val_df[feature_columns]\n",
"    y_val = val_df['label']\n",
"\n",
"    train_data = lgb.Dataset(X_train, label=y_train)\n",
"    val_data = lgb.Dataset(X_val, label=y_val)\n",
"    if use_optuna:\n",
"        # Hyper-parameter search; best trial's params overwrite `params`.\n",
"        study = optuna.create_study(direction='minimize')\n",
"        study.optimize(lambda trial: objective(trial, X_train, y_train, num_boost_round, params), n_trials=20)\n",
"\n",
"        print(f\"Best parameters: {study.best_trial.params}\")\n",
"        print(f\"Best score: {study.best_trial.value}\")\n",
"\n",
"        params.update(study.best_trial.params)\n",
"    model = lgb.train(\n",
"        params, train_data, num_boost_round=num_boost_round,\n",
"        valid_sets=[train_data, val_data], valid_names=['train', 'valid'],\n",
"        callbacks=callbacks\n",
"    )\n",
"\n",
"    # Optional diagnostics: metric curves, one tree, and split importances.\n",
"    if print_feature_importance:\n",
"        lgb.plot_metric(evals)\n",
"        lgb.plot_tree(model, figsize=(20, 8))\n",
"        lgb.plot_importance(model, importance_type='split', max_num_features=20)\n",
"        plt.show()\n",
"    # Caching of the freshly trained model is currently disabled:\n",
"    # with open(cache_file, 'wb') as f:\n",
"    #     pickle.dump({'key': cache_key, 'model': model,\n",
"    #                  'feature_columns': feature_columns}, f)\n",
"    return model\n",
"\n",
"\n",
"from catboost import CatBoostRegressor\n",
"import pandas as pd\n",
"\n",
"\n",
"def train_catboost(df, num_boost_round, params=None):\n",
"    \"\"\"Train a CatBoostRegressor with a date-based train/validation split.\n",
"\n",
"    df must contain feature columns plus 'date' and 'label' (and optionally\n",
"    'instrument'); the last 10% of unique dates are held out for validation\n",
"    and the model trains for `num_boost_round` iterations.\n",
"\n",
"    Fix: like train_light_model, the original used .iloc with index LABELS\n",
"    taken from the sorted frame, which selects the wrong rows after\n",
"    sort_values; boolean-mask .loc selection is used here instead, and\n",
"    val_date_count is floored at 1.\n",
"    NOTE(review): this helper keys on 'date'/'instrument' while the rest of\n",
"    the notebook uses 'trade_date'/'ts_code'; it is not called here.\n",
"    \"\"\"\n",
"    df_sorted = df.sort_values(by=['date', 'label'], ascending=[True, False])\n",
"\n",
"    # Feature columns: drop identifiers/label and anything forward-looking.\n",
"    feature_columns = [col for col in df.columns if col not in ['date',\n",
"                                                                'instrument',\n",
"                                                                'label']]\n",
"    feature_columns = [col for col in feature_columns if 'future' not in col]\n",
"    feature_columns = [col for col in feature_columns if 'score' not in col]\n",
"\n",
"    unique_dates = df_sorted['date'].unique()\n",
"    val_date_count = max(1, int(len(unique_dates) * 0.1))\n",
"    val_dates = unique_dates[-val_date_count:]\n",
"    val_mask = df_sorted['date'].isin(val_dates)\n",
"\n",
"    # Label-safe selection; frames already carry the date/label sort order.\n",
"    train_df = df_sorted.loc[~val_mask]\n",
"    val_df = df_sorted.loc[val_mask]\n",
"\n",
"    X_train = train_df[feature_columns]\n",
"    y_train = train_df['label']\n",
"\n",
"    X_val = val_df[feature_columns]\n",
"    y_val = val_df['label']\n",
"\n",
"    model = CatBoostRegressor(**params, iterations=num_boost_round)\n",
"    model.fit(X_train,\n",
"              y_train,\n",
"              eval_set=(X_val, y_val))\n",
"\n",
"    return model"
],
"id": "8f134d435f71e9e2",
"outputs": [],
"execution_count": 14
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-02-09T14:56:05.576927Z",
"start_time": "2025-02-09T14:56:05.480695Z"
}
},
"cell_type": "code",
"source": [
"# LightGBM training parameters, passed straight to lgb.train.\n",
"light_params = {\n",
"    'objective': 'regression',\n",
"    'metric': 'l2',\n",
"    'learning_rate': 0.05,\n",
"    # NOTE(review): is_unbalance is a binary-classification option and has\n",
"    # no effect under objective='regression' -- confirm it can be removed.\n",
"    'is_unbalance': True,\n",
"    # Tree complexity controls.\n",
"    'num_leaves': 2048,\n",
"    'min_data_in_leaf': 16,\n",
"    'max_depth': 32,\n",
"    'max_bin': 1024,\n",
"    'nthread': 2,\n",
"    # Column/row subsampling per iteration.\n",
"    'feature_fraction': 0.7,\n",
"    'bagging_fraction': 0.7,\n",
"    'bagging_freq': 5,\n",
"    # L1/L2 regularisation strengths.\n",
"    'lambda_l1': 80,\n",
"    'lambda_l2': 65,\n",
"    'verbosity': -1\n",
"}"
],
"id": "4a4542e1ed6afe7d",
"outputs": [],
"execution_count": 15
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-02-09T14:57:25.341222Z",
"start_time": "2025-02-09T14:56:05.640256Z"
}
},
"cell_type": "code",
"source": [
"print('train data size: ', len(train_data))\n",
"# NOTE(review): this rebinding of `df` is not read by any later cell.\n",
"df = train_data\n",
"\n",
"# evals is filled in-place by record_evaluation during training.\n",
"evals = {}\n",
"training_callbacks = [\n",
"    lgb.log_evaluation(period=500),\n",
"    lgb.callback.record_evaluation(evals),\n",
"    lgb.early_stopping(50, first_metric_only=True),\n",
"]\n",
"light_model = train_light_model(train_data, light_params, feature_columns,\n",
"                                training_callbacks, evals,\n",
"                                num_boost_round=1000, use_optuna=False,\n",
"                                print_feature_importance=False)"
],
"id": "beeb098799ecfa6a",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"train data size: 875004\n",
"未发现缓存模型,开始训练新模型。\n",
"Training until validation scores don't improve for 50 rounds\n",
"Early stopping, best iteration is:\n",
"[378]\ttrain's l2: 0.435049\tvalid's l2: 0.589178\n",
"Evaluated only: l2\n"
]
}
],
"execution_count": 16
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-02-09T14:57:27.394697Z",
"start_time": "2025-02-09T14:57:25.373274Z"
}
},
"cell_type": "code",
"source": [
"# Score the test set, then keep, for every trading day, the row with the\n",
"# highest predicted score (the single stock the model would pick).\n",
"test_data['score'] = light_model.predict(test_data[feature_columns])\n",
"best_idx = test_data.groupby('trade_date')['score'].idxmax()\n",
"predictions = test_data.loc[best_idx]"
],
"id": "5bb96ca8492e74d",
"outputs": [],
"execution_count": 17
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-02-09T14:57:27.489570Z",
"start_time": "2025-02-09T14:57:27.397368Z"
}
},
"cell_type": "code",
"source": "predictions.to_csv('predictions.csv', index=False, columns=['trade_date', 'score', 'ts_code'])  # persist daily picks",
"id": "5d1522a7538db91b",
"outputs": [],
"execution_count": 18
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}