897 lines
41 KiB
Plaintext
897 lines
41 KiB
Plaintext
{
|
||
"cells": [
|
||
{
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2025-02-09T14:52:54.170824Z",
|
||
"start_time": "2025-02-09T14:52:53.544850Z"
|
||
}
|
||
},
|
||
"cell_type": "code",
|
||
"source": [
|
||
"%load_ext autoreload\n",
|
||
"%autoreload 2\n",
|
||
"\n",
|
||
"from code.utils.utils import read_and_merge_h5_data"
|
||
],
|
||
"id": "79a7758178bafdd3",
|
||
"outputs": [],
|
||
"execution_count": 1
|
||
},
|
||
{
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2025-02-09T14:53:36.873700Z",
|
||
"start_time": "2025-02-09T14:52:54.170824Z"
|
||
}
|
||
},
|
||
"cell_type": "code",
|
||
"source": [
|
||
"print('daily data')\n",
|
||
"df = read_and_merge_h5_data('../../data/daily_data.h5', key='daily_data',\n",
|
||
" columns=['ts_code', 'trade_date', 'open', 'close', 'high', 'low', 'vol'],\n",
|
||
" df=None)\n",
|
||
"\n",
|
||
"print('daily basic')\n",
|
||
"df = read_and_merge_h5_data('../../data/daily_basic.h5', key='daily_basic_with_st',\n",
|
||
" columns=['ts_code', 'trade_date', 'turnover_rate', 'pe_ttm', 'circ_mv', 'volume_ratio',\n",
|
||
" 'is_st'], df=df)\n",
|
||
"\n",
|
||
"print('stk limit')\n",
|
||
"df = read_and_merge_h5_data('../../data/stk_limit.h5', key='stk_limit',\n",
|
||
" columns=['ts_code', 'trade_date', 'pre_close', 'up_limit', 'down_limit'],\n",
|
||
" df=df)\n",
|
||
"print('money flow')\n",
|
||
"df = read_and_merge_h5_data('../../data/money_flow.h5', key='money_flow',\n",
|
||
" columns=['ts_code', 'trade_date', 'buy_sm_vol', 'sell_sm_vol', 'buy_lg_vol', 'sell_lg_vol',\n",
|
||
" 'buy_elg_vol', 'sell_elg_vol', 'net_mf_vol'],\n",
|
||
" df=df)"
|
||
],
|
||
"id": "a79cafb06a7e0e43",
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"daily data\n",
|
||
"daily basic\n",
|
||
"stk limit\n",
|
||
"money flow\n"
|
||
]
|
||
}
|
||
],
|
||
"execution_count": 2
|
||
},
|
||
{
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2025-02-09T14:53:37.426404Z",
|
||
"start_time": "2025-02-09T14:53:36.955552Z"
|
||
}
|
||
},
|
||
"cell_type": "code",
|
||
"source": "origin_columns = df.columns.tolist()",
|
||
"id": "c4e9e1d31da6dba6",
|
||
"outputs": [],
|
||
"execution_count": 3
|
||
},
|
||
{
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2025-02-09T14:53:38.164112Z",
|
||
"start_time": "2025-02-09T14:53:38.070007Z"
|
||
}
|
||
},
|
||
"cell_type": "code",
|
||
"source": [
|
||
"import numpy as np\n",
|
||
"import talib\n",
|
||
"\n",
|
||
"\n",
|
||
"def get_technical_factor(df):\n",
|
||
" df['up'] = (df['high'] - df[['close', 'open']].max(axis=1)) / df['close']\n",
|
||
" df['down'] = (df[['close', 'open']].min(axis=1) - df['low']) / df['close']\n",
|
||
"\n",
|
||
" df['atr_14'] = talib.ATR(df['high'], df['low'], df['close'], timeperiod=14)\n",
|
||
" df['atr_6'] = talib.ATR(df['high'], df['low'], df['close'], timeperiod=6)\n",
|
||
"\n",
|
||
" df['obv'] = talib.OBV(df['close'], df['vol'])\n",
|
||
" df['maobv_6'] = talib.SMA(df['obv'], timeperiod=6)\n",
|
||
" df['obv-maobv_6'] = df['obv'] - df['maobv_6']\n",
|
||
"\n",
|
||
" df['rsi_3'] = talib.RSI(df['close'], timeperiod=3)\n",
|
||
" df['rsi_6'] = talib.RSI(df['close'], timeperiod=6)\n",
|
||
" df['rsi_9'] = talib.RSI(df['close'], timeperiod=9)\n",
|
||
"\n",
|
||
" df['return_10'] = df['close'] / df['close'].shift(10) - 1\n",
|
||
" df['return_20'] = df['close'] / df['close'].shift(20) - 1\n",
|
||
"\n",
|
||
" # # 计算 _rank_return_10 和 _rank_return_20\n",
|
||
" # df['_rank_return_10'] = df['return_10'].rank(pct=True)\n",
|
||
" # df['_rank_return_20'] = df['return_20'].rank(pct=True)\n",
|
||
"\n",
|
||
" # 计算 avg_close_5\n",
|
||
" df['avg_close_5'] = df['close'].rolling(window=5).mean() / df['close']\n",
|
||
"\n",
|
||
" # 计算 std_return_5, std_return_15, std_return_25, std_return_252, std_return_2522\n",
|
||
" df['std_return_5'] = df['close'].pct_change().shift(-1).rolling(window=5).std()\n",
|
||
" df['std_return_15'] = df['close'].pct_change().shift(-1).rolling(window=15).std()\n",
|
||
" df['std_return_25'] = df['close'].pct_change().shift(-1).rolling(window=25).std()\n",
|
||
" df['std_return_90'] = df['close'].pct_change().shift(-1).rolling(window=90).std()\n",
|
||
" df['std_return_90_2'] = df['close'].shift(10).pct_change().shift(-1).rolling(window=90).std()\n",
|
||
"\n",
|
||
" # 计算 std_return_5 / std_return_252 和 std_return_5 / std_return_25\n",
|
||
" df['std_return_5 / std_return_90'] = df['std_return_5'] / df['std_return_90']\n",
|
||
" df['std_return_5 / std_return_25'] = df['std_return_5'] / df['std_return_25']\n",
|
||
"\n",
|
||
" # 计算 std_return_252 - std_return_2522\n",
|
||
" df['std_return_90 - std_return_90_2'] = df['std_return_90'] - df['std_return_90_2']\n",
|
||
" return df\n",
|
||
"\n",
|
||
"\n",
|
||
"def get_act_factor(df):\n",
|
||
" # 计算 m_ta_ema(close, 5), m_ta_ema(close, 13), m_ta_ema(close, 20), m_ta_ema(close, 60)\n",
|
||
" df['ema_5'] = talib.EMA(df['close'], timeperiod=5)\n",
|
||
" df['ema_13'] = talib.EMA(df['close'], timeperiod=13)\n",
|
||
" df['ema_20'] = talib.EMA(df['close'], timeperiod=20)\n",
|
||
" df['ema_60'] = talib.EMA(df['close'], timeperiod=60)\n",
|
||
"\n",
|
||
" # 计算 act_factor1, act_factor2, act_factor3, act_factor4\n",
|
||
" df['act_factor1'] = np.arctan((df['ema_5'] / df['ema_5'].shift(1) - 1) * 100) * 57.3 / 50\n",
|
||
" df['act_factor2'] = np.arctan((df['ema_13'] / df['ema_13'].shift(1) - 1) * 100) * 57.3 / 40\n",
|
||
" df['act_factor3'] = np.arctan((df['ema_20'] / df['ema_20'].shift(1) - 1) * 100) * 57.3 / 21\n",
|
||
" df['act_factor4'] = np.arctan((df['ema_60'] / df['ema_60'].shift(1) - 1) * 100) * 57.3 / 10\n",
|
||
"\n",
|
||
" # 计算 act_factor5 和 act_factor6\n",
|
||
" df['act_factor5'] = df['act_factor1'] + df['act_factor2'] + df['act_factor3'] + df['act_factor4']\n",
|
||
" df['act_factor6'] = (df['act_factor1'] - df['act_factor2']) / np.sqrt(\n",
|
||
" df['act_factor1'] ** 2 + df['act_factor2'] ** 2)\n",
|
||
"\n",
|
||
" # 根据 'trade_date' 进行分组,在每个组内分别计算 'act_factor1', 'act_factor2', 'act_factor3' 的排名\n",
|
||
" df['rank_act_factor1'] = df.groupby('trade_date')['act_factor1'].rank(ascending=False, pct=True)\n",
|
||
" df['rank_act_factor2'] = df.groupby('trade_date')['act_factor2'].rank(ascending=False, pct=True)\n",
|
||
" df['rank_act_factor3'] = df.groupby('trade_date')['act_factor3'].rank(ascending=False, pct=True)\n",
|
||
"\n",
|
||
" return df\n",
|
||
"\n",
|
||
"\n",
|
||
"def get_money_flow_factor(df):\n",
|
||
" df['active_buy_volume_large'] = df['buy_lg_vol'] / df['net_mf_vol']\n",
|
||
" df['active_buy_volume_big'] = df['buy_elg_vol'] / df['net_mf_vol']\n",
|
||
" df['active_buy_volume_small'] = df['buy_sm_vol'] / df['net_mf_vol']\n",
|
||
"\n",
|
||
" df['buy_lg_vol - sell_lg_vol'] = (df['buy_lg_vol'] - df['sell_lg_vol']) / df['net_mf_vol']\n",
|
||
" df['buy_elg_vol - sell_elg_vol'] = (df['buy_elg_vol'] - df['sell_elg_vol']) / df['net_mf_vol']\n",
|
||
"\n",
|
||
" # # 你还提到了一些其他字段:\n",
|
||
" # df['net_active_buy_volume_main'] = df['net_mf_vol'] / df['buy_sm_vol']\n",
|
||
" # df['netflow_amount_main'] = df['net_mf_vol'] / df['buy_sm_vol'] # 这里假设 'net_mf_vol' 是主流资金流\n",
|
||
"\n",
|
||
" # df['active_sell_volume_large'] = df['sell_lg_vol'] / df['sell_sm_vol']\n",
|
||
" # df['active_sell_volume_big'] = df['sell_elg_vol'] / df['sell_sm_vol']\n",
|
||
" # df['active_sell_volume_small'] = df['sell_sm_vol'] / df['sell_sm_vol']\n",
|
||
"\n",
|
||
" return df\n",
|
||
"\n",
|
||
"\n",
|
||
"def get_alpha_factor(df):\n",
|
||
" df['alpha_022'] = df['close'] - df['close'].shift(5)\n",
|
||
"\n",
|
||
" # alpha_003: (close - open) / (high - low)\n",
|
||
" df['alpha_003'] = (df['close'] - df['open']) / (df['high'] - df['low'])\n",
|
||
"\n",
|
||
" # alpha_007: rank(correlation(close, volume, 5))\n",
|
||
" df['alpha_007'] = df['close'].rolling(5).corr(df['vol']).rank(axis=1)\n",
|
||
"\n",
|
||
" # alpha_013: rank(sum(close, 5) - sum(close, 20))\n",
|
||
" df['alpha_013'] = (df['close'].rolling(5).sum() - df['close'].rolling(20).sum()).rank(axis=1)\n",
|
||
" return df\n",
|
||
"\n",
|
||
"\n",
|
||
"def get_future_data(df):\n",
|
||
" df['future_return1'] = (df['close'].shift(-1) - df['close']) / df['close']\n",
|
||
" df['future_return2'] = (df['open'].shift(-2) - df['open'].shift(-1)) / df['open'].shift(-1)\n",
|
||
" df['future_return3'] = (df['close'].shift(-2) - df['close'].shift(-1)) / df['close'].shift(-1)\n",
|
||
" df['future_return4'] = (df['close'].shift(-2) - df['open'].shift(-1)) / df['open'].shift(-1)\n",
|
||
" df['future_return5'] = (df['close'].shift(-5) - df['open'].shift(-1)) / df['open'].shift(-1)\n",
|
||
" df['future_return6'] = (df['close'].shift(-10) - df['open'].shift(-1)) / df['open'].shift(-1)\n",
|
||
" df['future_return7'] = (df['close'].shift(-20) - df['open'].shift(-1)) / df['open'].shift(-1)\n",
|
||
" df['future_close1'] = (df['close'].shift(-1) - df['close']) / df['close']\n",
|
||
" df['future_close2'] = (df['close'].shift(-2) - df['close']) / df['close']\n",
|
||
" df['future_close3'] = (df['close'].shift(-3) - df['close']) / df['close']\n",
|
||
" df['future_close4'] = (df['close'].shift(-4) - df['close']) / df['close']\n",
|
||
" df['future_close5'] = (df['close'].shift(-5) - df['close']) / df['close']\n",
|
||
" df['future_af11'] = df['act_factor1'].shift(-1)\n",
|
||
" df['future_af12'] = df['act_factor1'].shift(-2)\n",
|
||
" df['future_af13'] = df['act_factor1'].shift(-3)\n",
|
||
" df['future_af14'] = df['act_factor1'].shift(-4)\n",
|
||
" df['future_af15'] = df['act_factor1'].shift(-5)\n",
|
||
" df['future_af21'] = df['act_factor2'].shift(-1)\n",
|
||
" df['future_af22'] = df['act_factor2'].shift(-2)\n",
|
||
" df['future_af23'] = df['act_factor2'].shift(-3)\n",
|
||
" df['future_af24'] = df['act_factor2'].shift(-4)\n",
|
||
" df['future_af25'] = df['act_factor2'].shift(-5)\n",
|
||
" df['future_af31'] = df['act_factor3'].shift(-1)\n",
|
||
" df['future_af32'] = df['act_factor3'].shift(-2)\n",
|
||
" df['future_af33'] = df['act_factor3'].shift(-3)\n",
|
||
" df['future_af34'] = df['act_factor3'].shift(-4)\n",
|
||
" df['future_af35'] = df['act_factor3'].shift(-5)\n",
|
||
"\n",
|
||
" return df\n"
|
||
],
|
||
"id": "a735bc02ceb4d872",
|
||
"outputs": [],
|
||
"execution_count": 4
|
||
},
|
||
{
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2025-02-09T14:53:49.153376Z",
|
||
"start_time": "2025-02-09T14:53:38.164112Z"
|
||
}
|
||
},
|
||
"cell_type": "code",
|
||
"source": [
|
||
"df = get_technical_factor(df)\n",
|
||
"df = get_act_factor(df)\n",
|
||
"df = get_money_flow_factor(df)\n",
|
||
"df = get_future_data(df)\n",
|
||
"# df = df.drop(columns=origin_columns)\n",
|
||
"\n",
|
||
"print(df.info())"
|
||
],
|
||
"id": "53f86ddc0677a6d7",
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"<class 'pandas.core.frame.DataFrame'>\n",
|
||
"RangeIndex: 8364308 entries, 0 to 8364307\n",
|
||
"Data columns (total 83 columns):\n",
|
||
" # Column Dtype \n",
|
||
"--- ------ ----- \n",
|
||
" 0 ts_code object \n",
|
||
" 1 trade_date datetime64[ns]\n",
|
||
" 2 open float64 \n",
|
||
" 3 close float64 \n",
|
||
" 4 high float64 \n",
|
||
" 5 low float64 \n",
|
||
" 6 vol float64 \n",
|
||
" 7 is_st object \n",
|
||
" 8 up_limit float64 \n",
|
||
" 9 down_limit float64 \n",
|
||
" 10 buy_sm_vol float64 \n",
|
||
" 11 sell_sm_vol float64 \n",
|
||
" 12 buy_lg_vol float64 \n",
|
||
" 13 sell_lg_vol float64 \n",
|
||
" 14 buy_elg_vol float64 \n",
|
||
" 15 sell_elg_vol float64 \n",
|
||
" 16 net_mf_vol float64 \n",
|
||
" 17 up float64 \n",
|
||
" 18 down float64 \n",
|
||
" 19 atr_14 float64 \n",
|
||
" 20 atr_6 float64 \n",
|
||
" 21 obv float64 \n",
|
||
" 22 maobv_6 float64 \n",
|
||
" 23 obv-maobv_6 float64 \n",
|
||
" 24 rsi_3 float64 \n",
|
||
" 25 rsi_6 float64 \n",
|
||
" 26 rsi_9 float64 \n",
|
||
" 27 return_10 float64 \n",
|
||
" 28 return_20 float64 \n",
|
||
" 29 avg_close_5 float64 \n",
|
||
" 30 std_return_5 float64 \n",
|
||
" 31 std_return_15 float64 \n",
|
||
" 32 std_return_25 float64 \n",
|
||
" 33 std_return_90 float64 \n",
|
||
" 34 std_return_90_2 float64 \n",
|
||
" 35 std_return_5 / std_return_90 float64 \n",
|
||
" 36 std_return_5 / std_return_25 float64 \n",
|
||
" 37 std_return_90 - std_return_90_2 float64 \n",
|
||
" 38 ema_5 float64 \n",
|
||
" 39 ema_13 float64 \n",
|
||
" 40 ema_20 float64 \n",
|
||
" 41 ema_60 float64 \n",
|
||
" 42 act_factor1 float64 \n",
|
||
" 43 act_factor2 float64 \n",
|
||
" 44 act_factor3 float64 \n",
|
||
" 45 act_factor4 float64 \n",
|
||
" 46 act_factor5 float64 \n",
|
||
" 47 act_factor6 float64 \n",
|
||
" 48 rank_act_factor1 float64 \n",
|
||
" 49 rank_act_factor2 float64 \n",
|
||
" 50 rank_act_factor3 float64 \n",
|
||
" 51 active_buy_volume_large float64 \n",
|
||
" 52 active_buy_volume_big float64 \n",
|
||
" 53 active_buy_volume_small float64 \n",
|
||
" 54 buy_lg_vol - sell_lg_vol float64 \n",
|
||
" 55 buy_elg_vol - sell_elg_vol float64 \n",
|
||
" 56 future_return1 float64 \n",
|
||
" 57 future_return2 float64 \n",
|
||
" 58 future_return3 float64 \n",
|
||
" 59 future_return4 float64 \n",
|
||
" 60 future_return5 float64 \n",
|
||
" 61 future_return6 float64 \n",
|
||
" 62 future_return7 float64 \n",
|
||
" 63 future_close1 float64 \n",
|
||
" 64 future_close2 float64 \n",
|
||
" 65 future_close3 float64 \n",
|
||
" 66 future_close4 float64 \n",
|
||
" 67 future_close5 float64 \n",
|
||
" 68 future_af11 float64 \n",
|
||
" 69 future_af12 float64 \n",
|
||
" 70 future_af13 float64 \n",
|
||
" 71 future_af14 float64 \n",
|
||
" 72 future_af15 float64 \n",
|
||
" 73 future_af21 float64 \n",
|
||
" 74 future_af22 float64 \n",
|
||
" 75 future_af23 float64 \n",
|
||
" 76 future_af24 float64 \n",
|
||
" 77 future_af25 float64 \n",
|
||
" 78 future_af31 float64 \n",
|
||
" 79 future_af32 float64 \n",
|
||
" 80 future_af33 float64 \n",
|
||
" 81 future_af34 float64 \n",
|
||
" 82 future_af35 float64 \n",
|
||
"dtypes: datetime64[ns](1), float64(80), object(2)\n",
|
||
"memory usage: 5.2+ GB\n",
|
||
"None\n"
|
||
]
|
||
}
|
||
],
|
||
"execution_count": 5
|
||
},
|
||
{
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2025-02-09T14:55:28.712343Z",
|
||
"start_time": "2025-02-09T14:53:49.279168Z"
|
||
}
|
||
},
|
||
"cell_type": "code",
|
||
"source": [
|
||
"def filter_data(df):\n",
|
||
" df = df.groupby('trade_date').apply(lambda x: x.nlargest(1000, 'act_factor3'))\n",
|
||
" df = df[df['is_st'] == False]\n",
|
||
" df = df[df['is_st'] == False]\n",
|
||
" df = df[~df['ts_code'].str.startswith('30')]\n",
|
||
" df = df[~df['ts_code'].str.startswith('68')]\n",
|
||
" df = df[~df['ts_code'].str.startswith('8')]\n",
|
||
" df = df.reset_index(drop=True)\n",
|
||
" return df\n",
|
||
"\n",
|
||
"\n",
|
||
"df = filter_data(df)\n",
|
||
"print(df.info())"
|
||
],
|
||
"id": "dbe2fd8021b9417f",
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"<class 'pandas.core.frame.DataFrame'>\n",
|
||
"RangeIndex: 1136157 entries, 0 to 1136156\n",
|
||
"Data columns (total 83 columns):\n",
|
||
" # Column Non-Null Count Dtype \n",
|
||
"--- ------ -------------- ----- \n",
|
||
" 0 ts_code 1136157 non-null object \n",
|
||
" 1 trade_date 1136157 non-null datetime64[ns]\n",
|
||
" 2 open 1136157 non-null float64 \n",
|
||
" 3 close 1136157 non-null float64 \n",
|
||
" 4 high 1136157 non-null float64 \n",
|
||
" 5 low 1136157 non-null float64 \n",
|
||
" 6 vol 1136157 non-null float64 \n",
|
||
" 7 is_st 1136157 non-null object \n",
|
||
" 8 up_limit 1135878 non-null float64 \n",
|
||
" 9 down_limit 1135878 non-null float64 \n",
|
||
" 10 buy_sm_vol 1135663 non-null float64 \n",
|
||
" 11 sell_sm_vol 1135663 non-null float64 \n",
|
||
" 12 buy_lg_vol 1135663 non-null float64 \n",
|
||
" 13 sell_lg_vol 1135663 non-null float64 \n",
|
||
" 14 buy_elg_vol 1135663 non-null float64 \n",
|
||
" 15 sell_elg_vol 1135663 non-null float64 \n",
|
||
" 16 net_mf_vol 1135663 non-null float64 \n",
|
||
" 17 up 1136157 non-null float64 \n",
|
||
" 18 down 1136157 non-null float64 \n",
|
||
" 19 atr_14 1136157 non-null float64 \n",
|
||
" 20 atr_6 1136157 non-null float64 \n",
|
||
" 21 obv 1136157 non-null float64 \n",
|
||
" 22 maobv_6 1136157 non-null float64 \n",
|
||
" 23 obv-maobv_6 1136157 non-null float64 \n",
|
||
" 24 rsi_3 1136157 non-null float64 \n",
|
||
" 25 rsi_6 1136157 non-null float64 \n",
|
||
" 26 rsi_9 1136157 non-null float64 \n",
|
||
" 27 return_10 1136157 non-null float64 \n",
|
||
" 28 return_20 1136157 non-null float64 \n",
|
||
" 29 avg_close_5 1136157 non-null float64 \n",
|
||
" 30 std_return_5 1136157 non-null float64 \n",
|
||
" 31 std_return_15 1136157 non-null float64 \n",
|
||
" 32 std_return_25 1136157 non-null float64 \n",
|
||
" 33 std_return_90 1136131 non-null float64 \n",
|
||
" 34 std_return_90_2 1136129 non-null float64 \n",
|
||
" 35 std_return_5 / std_return_90 1136131 non-null float64 \n",
|
||
" 36 std_return_5 / std_return_25 1136157 non-null float64 \n",
|
||
" 37 std_return_90 - std_return_90_2 1136129 non-null float64 \n",
|
||
" 38 ema_5 1136157 non-null float64 \n",
|
||
" 39 ema_13 1136157 non-null float64 \n",
|
||
" 40 ema_20 1136157 non-null float64 \n",
|
||
" 41 ema_60 1136153 non-null float64 \n",
|
||
" 42 act_factor1 1136157 non-null float64 \n",
|
||
" 43 act_factor2 1136157 non-null float64 \n",
|
||
" 44 act_factor3 1136157 non-null float64 \n",
|
||
" 45 act_factor4 1136152 non-null float64 \n",
|
||
" 46 act_factor5 1136152 non-null float64 \n",
|
||
" 47 act_factor6 1136157 non-null float64 \n",
|
||
" 48 rank_act_factor1 1136157 non-null float64 \n",
|
||
" 49 rank_act_factor2 1136157 non-null float64 \n",
|
||
" 50 rank_act_factor3 1136157 non-null float64 \n",
|
||
" 51 active_buy_volume_large 1135659 non-null float64 \n",
|
||
" 52 active_buy_volume_big 1135636 non-null float64 \n",
|
||
" 53 active_buy_volume_small 1135663 non-null float64 \n",
|
||
" 54 buy_lg_vol - sell_lg_vol 1135660 non-null float64 \n",
|
||
" 55 buy_elg_vol - sell_elg_vol 1135640 non-null float64 \n",
|
||
" 56 future_return1 1136157 non-null float64 \n",
|
||
" 57 future_return2 1136157 non-null float64 \n",
|
||
" 58 future_return3 1136157 non-null float64 \n",
|
||
" 59 future_return4 1136157 non-null float64 \n",
|
||
" 60 future_return5 1136157 non-null float64 \n",
|
||
" 61 future_return6 1136157 non-null float64 \n",
|
||
" 62 future_return7 1136157 non-null float64 \n",
|
||
" 63 future_close1 1136157 non-null float64 \n",
|
||
" 64 future_close2 1136157 non-null float64 \n",
|
||
" 65 future_close3 1136157 non-null float64 \n",
|
||
" 66 future_close4 1136157 non-null float64 \n",
|
||
" 67 future_close5 1136157 non-null float64 \n",
|
||
" 68 future_af11 1136157 non-null float64 \n",
|
||
" 69 future_af12 1136157 non-null float64 \n",
|
||
" 70 future_af13 1136157 non-null float64 \n",
|
||
" 71 future_af14 1136157 non-null float64 \n",
|
||
" 72 future_af15 1136157 non-null float64 \n",
|
||
" 73 future_af21 1136157 non-null float64 \n",
|
||
" 74 future_af22 1136157 non-null float64 \n",
|
||
" 75 future_af23 1136157 non-null float64 \n",
|
||
" 76 future_af24 1136157 non-null float64 \n",
|
||
" 77 future_af25 1136157 non-null float64 \n",
|
||
" 78 future_af31 1136157 non-null float64 \n",
|
||
" 79 future_af32 1136157 non-null float64 \n",
|
||
" 80 future_af33 1136157 non-null float64 \n",
|
||
" 81 future_af34 1136157 non-null float64 \n",
|
||
" 82 future_af35 1136157 non-null float64 \n",
|
||
"dtypes: datetime64[ns](1), float64(80), object(2)\n",
|
||
"memory usage: 719.5+ MB\n",
|
||
"None\n"
|
||
]
|
||
}
|
||
],
|
||
"execution_count": 6
|
||
},
|
||
{
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2025-02-09T15:00:45.828404Z",
|
||
"start_time": "2025-02-09T15:00:45.294830Z"
|
||
}
|
||
},
|
||
"cell_type": "code",
|
||
"source": [
|
||
"def remove_outliers_iqr(series, lower_quantile=0.05, upper_quantile=0.95, threshold=1.5):\n",
|
||
" Q1 = series.quantile(lower_quantile)\n",
|
||
" Q3 = series.quantile(upper_quantile)\n",
|
||
" IQR = Q3 - Q1\n",
|
||
" lower_bound = Q1 - threshold * IQR\n",
|
||
" upper_bound = Q3 + threshold * IQR\n",
|
||
" # 过滤掉低于下边界或高于上边界的极值\n",
|
||
" return (series >= lower_bound) & (series <= upper_bound)\n",
|
||
"\n",
|
||
"\n",
|
||
"def neutralize_labels(labels, features, feature_columns, z_threshold=3, method='regression'):\n",
|
||
" labels_no_outliers = remove_outliers_iqr(labels)\n",
|
||
" return labels_no_outliers\n",
|
||
"\n",
|
||
"\n",
|
||
"train_data = df[df['trade_date'] <= '2023-01-01']\n",
|
||
"test_data = df[df['trade_date'] >= '2023-01-01']\n",
|
||
"\n",
|
||
"feature_columns = [col for col in df.columns if col not in ['trade_date',\n",
|
||
" 'ts_code',\n",
|
||
" 'label']]\n",
|
||
"feature_columns = [col for col in feature_columns if 'future' not in col]\n",
|
||
"feature_columns = [col for col in feature_columns if 'score' not in col]\n",
|
||
"feature_columns = [col for col in feature_columns if col not in origin_columns]\n",
|
||
"\n",
|
||
"# for column in [column for column in train_data.columns if 'future' in column]:\n",
|
||
"# label_index = neutralize_labels(train_data[column], train_data, feature_columns, z_threshold=3, method='regression')\n",
|
||
"# train_data = train_data[label_index]\n",
|
||
"# label_index = neutralize_labels(test_data[column], test_data, feature_columns, z_threshold=3, method='regression')\n",
|
||
"# test_data = test_data[label_index]\n",
|
||
"\n",
|
||
"print(len(train_data))\n",
|
||
"print(len(test_data))"
|
||
],
|
||
"id": "5f3d9aece75318cd",
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"['up', 'down', 'atr_14', 'atr_6', 'obv', 'maobv_6', 'obv-maobv_6', 'rsi_3', 'rsi_6', 'rsi_9', 'return_10', 'return_20', 'avg_close_5', 'std_return_5', 'std_return_15', 'std_return_25', 'std_return_90', 'std_return_90_2', 'std_return_5 / std_return_90', 'std_return_5 / std_return_25', 'std_return_90 - std_return_90_2', 'ema_5', 'ema_13', 'ema_20', 'ema_60', 'act_factor1', 'act_factor2', 'act_factor3', 'act_factor4', 'act_factor5', 'act_factor6', 'rank_act_factor1', 'rank_act_factor2', 'rank_act_factor3', 'active_buy_volume_large', 'active_buy_volume_big', 'active_buy_volume_small', 'buy_lg_vol - sell_lg_vol', 'buy_elg_vol - sell_elg_vol']\n"
|
||
]
|
||
}
|
||
],
|
||
"execution_count": 19
|
||
},
|
||
{
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2025-02-09T14:56:05.319915Z",
|
||
"start_time": "2025-02-09T14:56:03.355725Z"
|
||
}
|
||
},
|
||
"cell_type": "code",
|
||
"source": [
|
||
"def get_qcuts(series, quantiles):\n",
|
||
" q = pd.qcut(series, q=quantiles, labels=False, duplicates='drop')\n",
|
||
" return q[-1] # 返回窗口最后一个元素的分位数标签\n",
|
||
"\n",
|
||
"\n",
|
||
"window = 5\n",
|
||
"quantiles = 20\n",
|
||
"\n",
|
||
"\n",
|
||
"def get_label(df):\n",
|
||
" labels = df['future_af13'] - df['act_factor1']\n",
|
||
" # labels = df['future_close3']\n",
|
||
" return labels\n",
|
||
"\n",
|
||
"\n",
|
||
"train_data['label'], test_data['label'] = get_label(train_data), get_label(test_data)\n",
|
||
"\n",
|
||
"train_data, test_data = train_data.dropna(subset=['label']), test_data.dropna(subset=['label'])\n",
|
||
"train_data, test_data = train_data.replace([np.inf, -np.inf], np.nan).dropna(), test_data.replace([np.inf, -np.inf],\n",
|
||
" np.nan).dropna()\n",
|
||
"train_data, test_data = train_data.reset_index(drop=True), test_data.reset_index(drop=True)\n",
|
||
"\n",
|
||
"print(len(train_data))\n",
|
||
"print(len(test_data))"
|
||
],
|
||
"id": "f4f16d63ad18d1bc",
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"875004\n",
|
||
"最小日期: 2017-01-03\n",
|
||
"最大日期: 2022-12-30\n",
|
||
"260581\n",
|
||
"最小日期: 2023-01-03\n",
|
||
"最大日期: 2025-01-27\n"
|
||
]
|
||
}
|
||
],
|
||
"execution_count": 13
|
||
},
|
||
{
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2025-02-09T14:56:05.480695Z",
|
||
"start_time": "2025-02-09T14:56:05.367238Z"
|
||
}
|
||
},
|
||
"cell_type": "code",
|
||
"source": [
|
||
"import lightgbm as lgb\n",
|
||
"import numpy as np\n",
|
||
"import matplotlib.pyplot as plt\n",
|
||
"import optuna\n",
|
||
"from sklearn.model_selection import KFold\n",
|
||
"from sklearn.metrics import mean_absolute_error\n",
|
||
"import os\n",
|
||
"import json\n",
|
||
"import pickle\n",
|
||
"import hashlib\n",
|
||
"\n",
|
||
"\n",
|
||
"def objective(trial, X, y, num_boost_round, params):\n",
|
||
" # 参数网格\n",
|
||
" X, y = X.reset_index(drop=True), y.reset_index(drop=True)\n",
|
||
" param_grid = {\n",
|
||
" \"n_estimators\": trial.suggest_categorical(\"n_estimators\", [10000]),\n",
|
||
" \"learning_rate\": trial.suggest_float(\"learning_rate\", 0.01, 0.3),\n",
|
||
" \"num_leaves\": trial.suggest_int(\"num_leaves\", 20, 3000, step=25),\n",
|
||
" \"max_depth\": trial.suggest_int(\"max_depth\", 3, 16),\n",
|
||
" \"min_data_in_leaf\": trial.suggest_int(\"min_data_in_leaf\", 200, 10000, step=100),\n",
|
||
" \"lambda_l1\": trial.suggest_int(\"lambda_l1\", 0, 100, step=5),\n",
|
||
" \"lambda_l2\": trial.suggest_int(\"lambda_l2\", 0, 100, step=5),\n",
|
||
" \"min_gain_to_split\": trial.suggest_float(\"min_gain_to_split\", 0, 15),\n",
|
||
" \"bagging_fraction\": trial.suggest_float(\"bagging_fraction\", 0.2, 0.95, step=0.1),\n",
|
||
" \"bagging_freq\": trial.suggest_categorical(\"bagging_freq\", [1]),\n",
|
||
" \"feature_fraction\": trial.suggest_float(\"feature_fraction\", 0.2, 0.95, step=0.1),\n",
|
||
" \"random_state\": 1,\n",
|
||
" \"objective\": 'regression',\n",
|
||
" 'verbosity': -1\n",
|
||
" }\n",
|
||
" # 5折交叉验证\n",
|
||
" cv = KFold(n_splits=5, shuffle=False)\n",
|
||
"\n",
|
||
" cv_scores = np.empty(5)\n",
|
||
" for idx, (train_idx, test_idx) in enumerate(cv.split(X, y)):\n",
|
||
" X_train, X_test = X.iloc[train_idx], X.iloc[test_idx]\n",
|
||
" y_train, y_test = y[train_idx], y[test_idx]\n",
|
||
"\n",
|
||
" # LGBM建模\n",
|
||
" model = lgb.LGBMRegressor(**param_grid, num_boost_round=num_boost_round)\n",
|
||
" model.fit(\n",
|
||
" X_train,\n",
|
||
" y_train,\n",
|
||
" eval_set=[(X_test, y_test)],\n",
|
||
" eval_metric=\"l2\",\n",
|
||
" callbacks=[\n",
|
||
" # LightGBMPruningCallback(trial, \"l2\"),\n",
|
||
" lgb.early_stopping(50, first_metric_only=True),\n",
|
||
" lgb.log_evaluation(period=-1)\n",
|
||
" ],\n",
|
||
" )\n",
|
||
" # 模型预测\n",
|
||
" preds = model.predict(X_test)\n",
|
||
" # 优化指标logloss最小\n",
|
||
" cv_scores[idx] = mean_absolute_error(y_test, preds)\n",
|
||
"\n",
|
||
" return np.mean(cv_scores)\n",
|
||
"\n",
|
||
"def generate_key(params, feature_columns, num_boost_round):\n",
|
||
" key_data = {\n",
|
||
" \"params\": params,\n",
|
||
" \"feature_columns\": feature_columns,\n",
|
||
" \"num_boost_round\": num_boost_round\n",
|
||
" }\n",
|
||
" # 转换成排序后的 JSON 字符串,再生成 md5 hash\n",
|
||
" key_str = json.dumps(key_data, sort_keys=True)\n",
|
||
" return hashlib.md5(key_str.encode('utf-8')).hexdigest()\n",
|
||
"\n",
|
||
"def train_light_model(df, params, feature_columns, callbacks, evals,\n",
|
||
" print_feature_importance=True, num_boost_round=100,\n",
|
||
" use_optuna=False):\n",
|
||
" cache_file = 'light_model.pkl'\n",
|
||
" cache_key = generate_key(params, feature_columns, num_boost_round)\n",
|
||
"\n",
|
||
" # 检查缓存文件是否存在\n",
|
||
" if os.path.exists(cache_file):\n",
|
||
" try:\n",
|
||
" with open(cache_file, 'rb') as f:\n",
|
||
" cache_data = pickle.load(f)\n",
|
||
" if cache_data.get('key') == cache_key:\n",
|
||
" print(\"加载缓存模型...\")\n",
|
||
" return cache_data.get('model')\n",
|
||
" else:\n",
|
||
" print(\"缓存模型的参数与当前参数不匹配,重新训练模型。\")\n",
|
||
" except Exception as e:\n",
|
||
" print(f\"加载缓存失败: {e},重新训练模型。\")\n",
|
||
" else:\n",
|
||
" print(\"未发现缓存模型,开始训练新模型。\")\n",
|
||
" # 确保数据按照 date 和 label 排序\n",
|
||
" df_sorted = df.sort_values(by=['trade_date', 'label'], ascending=[True, False]) # 按日期升序、标签降序排序\n",
|
||
" df_sorted = df_sorted.sort_values(by='trade_date')\n",
|
||
" unique_dates = df_sorted['trade_date'].unique()\n",
|
||
" val_date_count = int(len(unique_dates) * 0.1)\n",
|
||
" val_dates = unique_dates[-val_date_count:]\n",
|
||
" val_indices = df_sorted[df_sorted['trade_date'].isin(val_dates)].index\n",
|
||
" train_indices = df_sorted[~df_sorted['trade_date'].isin(val_dates)].index\n",
|
||
"\n",
|
||
" # 获取训练集和验证集的样本\n",
|
||
" train_df = df_sorted.iloc[train_indices]\n",
|
||
" val_df = df_sorted.iloc[val_indices]\n",
|
||
"\n",
|
||
" X_train = train_df[feature_columns]\n",
|
||
" y_train = train_df['label']\n",
|
||
"\n",
|
||
" X_val = val_df[feature_columns]\n",
|
||
" y_val = val_df['label']\n",
|
||
"\n",
|
||
" train_data = lgb.Dataset(X_train, label=y_train)\n",
|
||
" val_data = lgb.Dataset(X_val, label=y_val)\n",
|
||
" if use_optuna:\n",
|
||
" # study = optuna.create_study(direction='minimize' if classify else 'maximize')\n",
|
||
" study = optuna.create_study(direction='minimize')\n",
|
||
" study.optimize(lambda trial: objective(trial, X_train, y_train, num_boost_round, params), n_trials=20)\n",
|
||
"\n",
|
||
" print(f\"Best parameters: {study.best_trial.params}\")\n",
|
||
" print(f\"Best score: {study.best_trial.value}\")\n",
|
||
"\n",
|
||
" params.update(study.best_trial.params)\n",
|
||
" model = lgb.train(\n",
|
||
" params, train_data, num_boost_round=num_boost_round,\n",
|
||
" valid_sets=[train_data, val_data], valid_names=['train', 'valid'],\n",
|
||
" callbacks=callbacks\n",
|
||
" )\n",
|
||
"\n",
|
||
" # 打印特征重要性(如果需要)\n",
|
||
" if print_feature_importance:\n",
|
||
" lgb.plot_metric(evals)\n",
|
||
" lgb.plot_tree(model, figsize=(20, 8))\n",
|
||
" lgb.plot_importance(model, importance_type='split', max_num_features=20)\n",
|
||
" plt.show()\n",
|
||
" # with open(cache_file, 'wb') as f:\n",
|
||
" # pickle.dump({'key': cache_key,\n",
|
||
" # 'model': model,\n",
|
||
" # 'feature_columns': feature_columns}, f)\n",
|
||
" # print(\"模型训练完成并已保存缓存。\")\n",
|
||
" return model\n",
|
||
"\n",
|
||
"\n",
|
||
"from catboost import CatBoostRegressor\n",
|
||
"import pandas as pd\n",
|
||
"\n",
|
||
"\n",
|
||
"def train_catboost(df, num_boost_round, params=None):\n",
|
||
" \"\"\"\n",
|
||
" 训练 CatBoost 排序模型\n",
|
||
" - df: 包含因子、date、instrument 和 label 的 DataFrame\n",
|
||
" - num_boost_round: 训练的轮数\n",
|
||
" - print_feature_importance: 是否打印特征重要性\n",
|
||
" - plot: 是否绘制特征重要性图\n",
|
||
" - split_date: 用于划分训练集和验证集的日期(比如 '2020-01-01')\n",
|
||
"\n",
|
||
" 返回训练好的模型\n",
|
||
" \"\"\"\n",
|
||
" df_sorted = df.sort_values(by=['date', 'label'], ascending=[True, False])\n",
|
||
"\n",
|
||
" # 提取特征和标签\n",
|
||
" feature_columns = [col for col in df.columns if col not in ['date',\n",
|
||
" 'instrument',\n",
|
||
" 'label']]\n",
|
||
" feature_columns = [col for col in feature_columns if 'future' not in col]\n",
|
||
" feature_columns = [col for col in feature_columns if 'score' not in col]\n",
|
||
"\n",
|
||
" df_sorted = df_sorted.sort_values(by='date')\n",
|
||
" unique_dates = df_sorted['date'].unique()\n",
|
||
" val_date_count = int(len(unique_dates) * 0.1)\n",
|
||
" val_dates = unique_dates[-val_date_count:]\n",
|
||
" val_indices = df_sorted[df_sorted['date'].isin(val_dates)].index\n",
|
||
" train_indices = df_sorted[~df_sorted['date'].isin(val_dates)].index\n",
|
||
"\n",
|
||
" # 获取训练集和验证集的样本\n",
|
||
" train_df = df_sorted.iloc[train_indices].sort_values(by=['date', 'label'], ascending=[True, False])\n",
|
||
" val_df = df_sorted.iloc[val_indices].sort_values(by=['date', 'label'], ascending=[True, False])\n",
|
||
"\n",
|
||
" X_train = train_df[feature_columns]\n",
|
||
" y_train = train_df['label']\n",
|
||
"\n",
|
||
" X_val = val_df[feature_columns]\n",
|
||
" y_val = val_df['label']\n",
|
||
"\n",
|
||
" model = CatBoostRegressor(**params, iterations=num_boost_round)\n",
|
||
" model.fit(X_train,\n",
|
||
" y_train,\n",
|
||
" eval_set=(X_val, y_val))\n",
|
||
"\n",
|
||
" return model"
|
||
],
|
||
"id": "8f134d435f71e9e2",
|
||
"outputs": [],
|
||
"execution_count": 14
|
||
},
|
||
{
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2025-02-09T14:56:05.576927Z",
|
||
"start_time": "2025-02-09T14:56:05.480695Z"
|
||
}
|
||
},
|
||
"cell_type": "code",
|
||
"source": [
|
||
"light_params = {\n",
|
||
" 'objective': 'regression',\n",
|
||
" 'metric': 'l2',\n",
|
||
" 'learning_rate': 0.05,\n",
|
||
" 'is_unbalance': True,\n",
|
||
" 'num_leaves': 2048,\n",
|
||
" 'min_data_in_leaf': 16,\n",
|
||
" 'max_depth': 32,\n",
|
||
" 'max_bin': 1024,\n",
|
||
" 'nthread': 2,\n",
|
||
" 'feature_fraction': 0.7,\n",
|
||
" 'bagging_fraction': 0.7,\n",
|
||
" 'bagging_freq': 5,\n",
|
||
" 'lambda_l1': 80,\n",
|
||
" 'lambda_l2': 65,\n",
|
||
" 'verbosity': -1\n",
|
||
"}"
|
||
],
|
||
"id": "4a4542e1ed6afe7d",
|
||
"outputs": [],
|
||
"execution_count": 15
|
||
},
|
||
{
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2025-02-09T14:57:25.341222Z",
|
||
"start_time": "2025-02-09T14:56:05.640256Z"
|
||
}
|
||
},
|
||
"cell_type": "code",
|
||
"source": [
|
||
"print('train data size: ', len(train_data))\n",
|
||
"df = train_data\n",
|
||
"\n",
|
||
"evals = {}\n",
|
||
"light_model = train_light_model(train_data, light_params, feature_columns,\n",
|
||
" [lgb.log_evaluation(period=500),\n",
|
||
" lgb.callback.record_evaluation(evals),\n",
|
||
" lgb.early_stopping(50, first_metric_only=True)\n",
|
||
" ], evals,\n",
|
||
" num_boost_round=1000, use_optuna=False,\n",
|
||
" print_feature_importance=False)"
|
||
],
|
||
"id": "beeb098799ecfa6a",
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"train data size: 875004\n",
|
||
"未发现缓存模型,开始训练新模型。\n",
|
||
"Training until validation scores don't improve for 50 rounds\n",
|
||
"Early stopping, best iteration is:\n",
|
||
"[378]\ttrain's l2: 0.435049\tvalid's l2: 0.589178\n",
|
||
"Evaluated only: l2\n"
|
||
]
|
||
}
|
||
],
|
||
"execution_count": 16
|
||
},
|
||
{
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2025-02-09T14:57:27.394697Z",
|
||
"start_time": "2025-02-09T14:57:25.373274Z"
|
||
}
|
||
},
|
||
"cell_type": "code",
|
||
"source": [
|
||
"test_data['score'] = light_model.predict(test_data[feature_columns])\n",
|
||
"predictions = test_data.loc[test_data.groupby('trade_date')['score'].idxmax()]"
|
||
],
|
||
"id": "5bb96ca8492e74d",
|
||
"outputs": [],
|
||
"execution_count": 17
|
||
},
|
||
{
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2025-02-09T14:57:27.489570Z",
|
||
"start_time": "2025-02-09T14:57:27.397368Z"
|
||
}
|
||
},
|
||
"cell_type": "code",
|
||
"source": "predictions[['trade_date', 'score', 'ts_code']].to_csv('predictions.csv', index=False)",
|
||
"id": "5d1522a7538db91b",
|
||
"outputs": [],
|
||
"execution_count": 18
|
||
}
|
||
],
|
||
"metadata": {
|
||
"kernelspec": {
|
||
"display_name": "Python 3",
|
||
"language": "python",
|
||
"name": "python3"
|
||
},
|
||
"language_info": {
|
||
"codemirror_mode": {
|
||
"name": "ipython",
|
||
"version": 2
|
||
},
|
||
"file_extension": ".py",
|
||
"mimetype": "text/x-python",
|
||
"name": "python",
|
||
"nbconvert_exporter": "python",
|
||
"pygments_lexer": "ipython2",
|
||
"version": "2.7.6"
|
||
}
|
||
},
|
||
"nbformat": 4,
|
||
"nbformat_minor": 5
|
||
}
|