{ "cells": [ { "metadata": { "ExecuteTime": { "end_time": "2025-02-09T14:52:54.170824Z", "start_time": "2025-02-09T14:52:53.544850Z" } }, "cell_type": "code", "source": [ "%load_ext autoreload\n", "%autoreload 2\n", "\n", "from code.utils.utils import read_and_merge_h5_data" ], "id": "79a7758178bafdd3", "outputs": [], "execution_count": 1 }, { "metadata": { "ExecuteTime": { "end_time": "2025-02-09T14:53:36.873700Z", "start_time": "2025-02-09T14:52:54.170824Z" } }, "cell_type": "code", "source": [ "print('daily data')\n", "df = read_and_merge_h5_data('../../data/daily_data.h5', key='daily_data',\n", " columns=['ts_code', 'trade_date', 'open', 'close', 'high', 'low', 'vol'],\n", " df=None)\n", "\n", "print('daily basic')\n", "df = read_and_merge_h5_data('../../data/daily_basic.h5', key='daily_basic_with_st',\n", " columns=['ts_code', 'trade_date', 'turnover_rate', 'pe_ttm', 'circ_mv', 'volume_ratio',\n", " 'is_st'], df=df)\n", "\n", "print('stk limit')\n", "df = read_and_merge_h5_data('../../data/stk_limit.h5', key='stk_limit',\n", " columns=['ts_code', 'trade_date', 'pre_close', 'up_limit', 'down_limit'],\n", " df=df)\n", "print('money flow')\n", "df = read_and_merge_h5_data('../../data/money_flow.h5', key='money_flow',\n", " columns=['ts_code', 'trade_date', 'buy_sm_vol', 'sell_sm_vol', 'buy_lg_vol', 'sell_lg_vol',\n", " 'buy_elg_vol', 'sell_elg_vol', 'net_mf_vol'],\n", " df=df)" ], "id": "a79cafb06a7e0e43", "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "daily data\n", "daily basic\n", "stk limit\n", "money flow\n" ] } ], "execution_count": 2 }, { "metadata": { "ExecuteTime": { "end_time": "2025-02-09T14:53:37.426404Z", "start_time": "2025-02-09T14:53:36.955552Z" } }, "cell_type": "code", "source": "origin_columns = df.columns.tolist()", "id": "c4e9e1d31da6dba6", "outputs": [], "execution_count": 3 }, { "metadata": { "ExecuteTime": { "end_time": "2025-02-09T14:53:38.164112Z", "start_time": "2025-02-09T14:53:38.070007Z" } }, "cell_type": "code", "source": [ "import numpy as np\n", "import talib\n", "\n", "\n", "def get_technical_factor(df):\n", " df['up'] = (df['high'] - df[['close', 'open']].max(axis=1)) / df['close']\n", " df['down'] = (df[['close', 'open']].min(axis=1) - df['low']) / df['close']\n", "\n", " df['atr_14'] = talib.ATR(df['high'], df['low'], df['close'], timeperiod=14)\n", " df['atr_6'] = talib.ATR(df['high'], df['low'], df['close'], timeperiod=6)\n", "\n", " df['obv'] = talib.OBV(df['close'], df['vol'])\n", " df['maobv_6'] = talib.SMA(df['obv'], timeperiod=6)\n", " df['obv-maobv_6'] = df['obv'] - df['maobv_6']\n", "\n", " df['rsi_3'] = talib.RSI(df['close'], timeperiod=3)\n", " df['rsi_6'] = talib.RSI(df['close'], timeperiod=6)\n", " df['rsi_9'] = talib.RSI(df['close'], timeperiod=9)\n", "\n", " df['return_10'] = df['close'] / df['close'].shift(10) - 1\n", " df['return_20'] = df['close'] / df['close'].shift(20) - 1\n", "\n", " # # 计算 _rank_return_10 和 _rank_return_20\n", " # df['_rank_return_10'] = df['return_10'].rank(pct=True)\n", " # df['_rank_return_20'] = df['return_20'].rank(pct=True)\n", "\n", " # 计算 avg_close_5\n", " df['avg_close_5'] = df['close'].rolling(window=5).mean() / df['close']\n", "\n", " # 计算 std_return_5, std_return_15, std_return_25, std_return_252, std_return_2522\n", " df['std_return_5'] = df['close'].pct_change().shift(-1).rolling(window=5).std()\n", " df['std_return_15'] = df['close'].pct_change().shift(-1).rolling(window=15).std()\n", " df['std_return_25'] = df['close'].pct_change().shift(-1).rolling(window=25).std()\n", " df['std_return_90'] = df['close'].pct_change().shift(-1).rolling(window=90).std()\n", " df['std_return_90_2'] = df['close'].shift(10).pct_change().shift(-1).rolling(window=90).std()\n", "\n", " # 计算 std_return_5 / std_return_252 和 std_return_5 / std_return_25\n", " df['std_return_5 / std_return_90'] = df['std_return_5'] / df['std_return_90']\n", " df['std_return_5 / std_return_25'] = df['std_return_5'] / df['std_return_25']\n", "\n", " # 计算 std_return_252 - std_return_2522\n", " df['std_return_90 - std_return_90_2'] = df['std_return_90'] - df['std_return_90_2']\n", " return df\n", "\n", "\n", "def get_act_factor(df):\n", " # 计算 m_ta_ema(close, 5), m_ta_ema(close, 13), m_ta_ema(close, 20), m_ta_ema(close, 60)\n", " df['ema_5'] = talib.EMA(df['close'], timeperiod=5)\n", " df['ema_13'] = talib.EMA(df['close'], timeperiod=13)\n", " df['ema_20'] = talib.EMA(df['close'], timeperiod=20)\n", " df['ema_60'] = talib.EMA(df['close'], timeperiod=60)\n", "\n", " # 计算 act_factor1, act_factor2, act_factor3, act_factor4\n", " df['act_factor1'] = np.arctan((df['ema_5'] / df['ema_5'].shift(1) - 1) * 100) * 57.3 / 50\n", " df['act_factor2'] = np.arctan((df['ema_13'] / df['ema_13'].shift(1) - 1) * 100) * 57.3 / 40\n", " df['act_factor3'] = np.arctan((df['ema_20'] / df['ema_20'].shift(1) - 1) * 100) * 57.3 / 21\n", " df['act_factor4'] = np.arctan((df['ema_60'] / df['ema_60'].shift(1) - 1) * 100) * 57.3 / 10\n", "\n", " # 计算 act_factor5 和 act_factor6\n", " df['act_factor5'] = df['act_factor1'] + df['act_factor2'] + df['act_factor3'] + df['act_factor4']\n", " df['act_factor6'] = (df['act_factor1'] - df['act_factor2']) / np.sqrt(\n", " df['act_factor1'] ** 2 + df['act_factor2'] ** 2)\n", "\n", " # 根据 'trade_date' 进行分组,在每个组内分别计算 'act_factor1', 'act_factor2', 'act_factor3' 的排名\n", " df['rank_act_factor1'] = df.groupby('trade_date')['act_factor1'].rank(ascending=False, pct=True)\n", " df['rank_act_factor2'] = df.groupby('trade_date')['act_factor2'].rank(ascending=False, pct=True)\n", " df['rank_act_factor3'] = df.groupby('trade_date')['act_factor3'].rank(ascending=False, pct=True)\n", "\n", " return df\n", "\n", "\n", "def get_money_flow_factor(df):\n", " df['active_buy_volume_large'] = df['buy_lg_vol'] / df['net_mf_vol']\n", " df['active_buy_volume_big'] = df['buy_elg_vol'] / df['net_mf_vol']\n", " df['active_buy_volume_small'] = df['buy_sm_vol'] / df['net_mf_vol']\n", "\n", " df['buy_lg_vol - sell_lg_vol'] = (df['buy_lg_vol'] - df['sell_lg_vol']) / df['net_mf_vol']\n", " df['buy_elg_vol - sell_elg_vol'] = (df['buy_elg_vol'] - df['sell_elg_vol']) / df['net_mf_vol']\n", "\n", " # # 你还提到了一些其他字段:\n", " # df['net_active_buy_volume_main'] = df['net_mf_vol'] / df['buy_sm_vol']\n", " # df['netflow_amount_main'] = df['net_mf_vol'] / df['buy_sm_vol'] # 这里假设 'net_mf_vol' 是主流资金流\n", "\n", " # df['active_sell_volume_large'] = df['sell_lg_vol'] / df['sell_sm_vol']\n", " # df['active_sell_volume_big'] = df['sell_elg_vol'] / df['sell_sm_vol']\n", " # df['active_sell_volume_small'] = df['sell_sm_vol'] / df['sell_sm_vol']\n", "\n", " return df\n", "\n", "\n", "def get_alpha_factor(df):\n", " df['alpha_022'] = df['close'] - df['close'].shift(5)\n", "\n", " # alpha_003: (close - open) / (high - low)\n", " df['alpha_003'] = (df['close'] - df['open']) / (df['high'] - df['low'])\n", "\n", " # alpha_007: rank(correlation(close, volume, 5))\n", " df['alpha_007'] = df['close'].rolling(5).corr(df['vol']).rank(axis=1)\n", "\n", " # alpha_013: rank(sum(close, 5) - sum(close, 20))\n", " df['alpha_013'] = (df['close'].rolling(5).sum() - df['close'].rolling(20).sum()).rank(axis=1)\n", " return df\n", "\n", "\n", "def get_future_data(df):\n", " df['future_return1'] = (df['close'].shift(-1) - df['close']) / df['close']\n", " df['future_return2'] = (df['open'].shift(-2) - df['open'].shift(-1)) / df['open'].shift(-1)\n", " df['future_return3'] = (df['close'].shift(-2) - df['close'].shift(-1)) / df['close'].shift(-1)\n", " df['future_return4'] = (df['close'].shift(-2) - df['open'].shift(-1)) / df['open'].shift(-1)\n", " df['future_return5'] = (df['close'].shift(-5) - df['open'].shift(-1)) / df['open'].shift(-1)\n", " df['future_return6'] = (df['close'].shift(-10) - df['open'].shift(-1)) / df['open'].shift(-1)\n", " df['future_return7'] = (df['close'].shift(-20) - df['open'].shift(-1)) / df['open'].shift(-1)\n", " df['future_close1'] = (df['close'].shift(-1) - df['close']) / df['close']\n", " df['future_close2'] = (df['close'].shift(-2) - df['close']) / df['close']\n", " df['future_close3'] = (df['close'].shift(-3) - df['close']) / df['close']\n", " df['future_close4'] = (df['close'].shift(-4) - df['close']) / df['close']\n", " df['future_close5'] = (df['close'].shift(-5) - df['close']) / df['close']\n", " df['future_af11'] = df['act_factor1'].shift(-1)\n", " df['future_af12'] = df['act_factor1'].shift(-2)\n", " df['future_af13'] = df['act_factor1'].shift(-3)\n", " df['future_af14'] = df['act_factor1'].shift(-4)\n", " df['future_af15'] = df['act_factor1'].shift(-5)\n", " df['future_af21'] = df['act_factor2'].shift(-1)\n", " df['future_af22'] = df['act_factor2'].shift(-2)\n", " df['future_af23'] = df['act_factor2'].shift(-3)\n", " df['future_af24'] = df['act_factor2'].shift(-4)\n", " df['future_af25'] = df['act_factor2'].shift(-5)\n", " df['future_af31'] = df['act_factor3'].shift(-1)\n", " df['future_af32'] = df['act_factor3'].shift(-2)\n", " df['future_af33'] = df['act_factor3'].shift(-3)\n", " df['future_af34'] = df['act_factor3'].shift(-4)\n", " df['future_af35'] = df['act_factor3'].shift(-5)\n", "\n", " return df\n" ], "id": "a735bc02ceb4d872", "outputs": [], "execution_count": 4 }, { "metadata": { "ExecuteTime": { "end_time": "2025-02-09T14:53:49.153376Z", "start_time": "2025-02-09T14:53:38.164112Z" } }, "cell_type": "code", "source": [ "df = get_technical_factor(df)\n", "df = get_act_factor(df)\n", "df = get_money_flow_factor(df)\n", "df = get_future_data(df)\n", "# df = df.drop(columns=origin_columns)\n", "\n", "print(df.info())" ], "id": "53f86ddc0677a6d7", "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "RangeIndex: 8364308 entries, 0 to 8364307\n", "Data columns (total 83 columns):\n", " # Column Dtype \n", "--- ------ ----- \n", " 0 ts_code object \n", " 1 trade_date datetime64[ns]\n", " 2 open float64 \n", " 3 close float64 \n", " 4 high float64 \n", " 5 low float64 \n", " 6 vol float64 \n", " 7 is_st object \n", " 8 up_limit float64 \n", " 9 down_limit float64 \n", " 10 buy_sm_vol float64 \n", " 11 sell_sm_vol float64 \n", " 12 buy_lg_vol float64 \n", " 13 sell_lg_vol float64 \n", " 14 buy_elg_vol float64 \n", " 15 sell_elg_vol float64 \n", " 16 net_mf_vol float64 \n", " 17 up float64 \n", " 18 down float64 \n", " 19 atr_14 float64 \n", " 20 atr_6 float64 \n", " 21 obv float64 \n", " 22 maobv_6 float64 \n", " 23 obv-maobv_6 float64 \n", " 24 rsi_3 float64 \n", " 25 rsi_6 float64 \n", " 26 rsi_9 float64 \n", " 27 return_10 float64 \n", " 28 return_20 float64 \n", " 29 avg_close_5 float64 \n", " 30 std_return_5 float64 \n", " 31 std_return_15 float64 \n", " 32 std_return_25 float64 \n", " 33 std_return_90 float64 \n", " 34 std_return_90_2 float64 \n", " 35 std_return_5 / std_return_90 float64 \n", " 36 std_return_5 / std_return_25 float64 \n", " 37 std_return_90 - std_return_90_2 float64 \n", " 38 ema_5 float64 \n", " 39 ema_13 float64 \n", " 40 ema_20 float64 \n", " 41 ema_60 float64 \n", " 42 act_factor1 float64 \n", " 43 act_factor2 float64 \n", " 44 act_factor3 float64 \n", " 45 act_factor4 float64 \n", " 46 act_factor5 float64 \n", " 47 act_factor6 float64 \n", " 48 rank_act_factor1 float64 \n", " 49 rank_act_factor2 float64 \n", " 50 rank_act_factor3 float64 \n", " 51 active_buy_volume_large float64 \n", " 52 active_buy_volume_big float64 \n", " 53 active_buy_volume_small float64 \n", " 54 buy_lg_vol - sell_lg_vol float64 \n", " 55 buy_elg_vol - sell_elg_vol float64 \n", " 56 future_return1 float64 \n", " 57 future_return2 float64 \n", " 58 future_return3 float64 \n", " 59 future_return4 float64 \n", " 60 future_return5 float64 \n", " 61 future_return6 float64 \n", " 62 future_return7 float64 \n", " 63 future_close1 float64 \n", " 64 future_close2 float64 \n", " 65 future_close3 float64 \n", " 66 future_close4 float64 \n", " 67 future_close5 float64 \n", " 68 future_af11 float64 \n", " 69 future_af12 float64 \n", " 70 future_af13 float64 \n", " 71 future_af14 float64 \n", " 72 future_af15 float64 \n", " 73 future_af21 float64 \n", " 74 future_af22 float64 \n", " 75 future_af23 float64 \n", " 76 future_af24 float64 \n", " 77 future_af25 float64 \n", " 78 future_af31 float64 \n", " 79 future_af32 float64 \n", " 80 future_af33 float64 \n", " 81 future_af34 float64 \n", " 82 future_af35 float64 \n", "dtypes: datetime64[ns](1), float64(80), object(2)\n", "memory usage: 5.2+ GB\n", "None\n" ] } ], "execution_count": 5 }, { "metadata": { "ExecuteTime": { "end_time": "2025-02-09T14:55:28.712343Z", "start_time": "2025-02-09T14:53:49.279168Z" } }, "cell_type": "code", "source": [ "def filter_data(df):\n", " df = df.groupby('trade_date').apply(lambda x: x.nlargest(1000, 'act_factor3'))\n", " df = df[df['is_st'] == False]\n", " df = df[df['is_st'] == False]\n", " df = df[~df['ts_code'].str.startswith('30')]\n", " df = df[~df['ts_code'].str.startswith('68')]\n", " df = df[~df['ts_code'].str.startswith('8')]\n", " df = df.reset_index(drop=True)\n", " return df\n", "\n", "\n", "df = filter_data(df)\n", "print(df.info())" ], "id": "dbe2fd8021b9417f", "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "RangeIndex: 1136157 entries, 0 to 1136156\n", "Data columns (total 83 columns):\n", " # Column Non-Null Count Dtype \n", "--- ------ -------------- ----- \n", " 0 ts_code 1136157 non-null object \n", " 1 trade_date 1136157 non-null datetime64[ns]\n", " 2 open 1136157 non-null float64 \n", " 3 close 1136157 non-null float64 \n", " 4 high 1136157 non-null float64 \n", " 5 low 1136157 non-null float64 \n", " 6 vol 1136157 non-null float64 \n", " 7 is_st 1136157 non-null object \n", " 8 up_limit 1135878 non-null float64 \n", " 9 down_limit 1135878 non-null float64 \n", " 10 buy_sm_vol 1135663 non-null float64 \n", " 11 sell_sm_vol 1135663 non-null float64 \n", " 12 buy_lg_vol 1135663 non-null float64 \n", " 13 sell_lg_vol 1135663 non-null float64 \n", " 14 buy_elg_vol 1135663 non-null float64 \n", " 15 sell_elg_vol 1135663 non-null float64 \n", " 16 net_mf_vol 1135663 non-null float64 \n", " 17 up 1136157 non-null float64 \n", " 18 down 1136157 non-null float64 \n", " 19 atr_14 1136157 non-null float64 \n", " 20 atr_6 1136157 non-null float64 \n", " 21 obv 1136157 non-null float64 \n", " 22 maobv_6 1136157 non-null float64 \n", " 23 obv-maobv_6 1136157 non-null float64 \n", " 24 rsi_3 1136157 non-null float64 \n", " 25 rsi_6 1136157 non-null float64 \n", " 26 rsi_9 1136157 non-null float64 \n", " 27 return_10 1136157 non-null float64 \n", " 28 return_20 1136157 non-null float64 \n", " 29 avg_close_5 1136157 non-null float64 \n", " 30 std_return_5 1136157 non-null float64 \n", " 31 std_return_15 1136157 non-null float64 \n", " 32 std_return_25 1136157 non-null float64 \n", " 33 std_return_90 1136131 non-null float64 \n", " 34 std_return_90_2 1136129 non-null float64 \n", " 35 std_return_5 / std_return_90 1136131 non-null float64 \n", " 36 std_return_5 / std_return_25 1136157 non-null float64 \n", " 37 std_return_90 - std_return_90_2 1136129 non-null float64 \n", " 38 ema_5 1136157 non-null float64 \n", " 39 ema_13 1136157 non-null float64 \n", " 40 ema_20 1136157 non-null float64 \n", " 41 ema_60 1136153 non-null float64 \n", " 42 act_factor1 1136157 non-null float64 \n", " 43 act_factor2 1136157 non-null float64 \n", " 44 act_factor3 1136157 non-null float64 \n", " 45 act_factor4 1136152 non-null float64 \n", " 46 act_factor5 1136152 non-null float64 \n", " 47 act_factor6 1136157 non-null float64 \n", " 48 rank_act_factor1 1136157 non-null float64 \n", " 49 rank_act_factor2 1136157 non-null float64 \n", " 50 rank_act_factor3 1136157 non-null float64 \n", " 51 active_buy_volume_large 1135659 non-null float64 \n", " 52 active_buy_volume_big 1135636 non-null float64 \n", " 53 active_buy_volume_small 1135663 non-null float64 \n", " 54 buy_lg_vol - sell_lg_vol 1135660 non-null float64 \n", " 55 buy_elg_vol - sell_elg_vol 1135640 non-null float64 \n", " 56 future_return1 1136157 non-null float64 \n", " 57 future_return2 1136157 non-null float64 \n", " 58 future_return3 1136157 non-null float64 \n", " 59 future_return4 1136157 non-null float64 \n", " 60 future_return5 1136157 non-null float64 \n", " 61 future_return6 1136157 non-null float64 \n", " 62 future_return7 1136157 non-null float64 \n", " 63 future_close1 1136157 non-null float64 \n", " 64 future_close2 1136157 non-null float64 \n", " 65 future_close3 1136157 non-null float64 \n", " 66 future_close4 1136157 non-null float64 \n", " 67 future_close5 1136157 non-null float64 \n", " 68 future_af11 1136157 non-null float64 \n", " 69 future_af12 1136157 non-null float64 \n", " 70 future_af13 1136157 non-null float64 \n", " 71 future_af14 1136157 non-null float64 \n", " 72 future_af15 1136157 non-null float64 \n", " 73 future_af21 1136157 non-null float64 \n", " 74 future_af22 1136157 non-null float64 \n", " 75 future_af23 1136157 non-null float64 \n", " 76 future_af24 1136157 non-null float64 \n", " 77 future_af25 1136157 non-null float64 \n", " 78 future_af31 1136157 non-null float64 \n", " 79 future_af32 1136157 non-null float64 \n", " 80 future_af33 1136157 non-null float64 \n", " 81 future_af34 1136157 non-null float64 \n", " 82 future_af35 1136157 non-null float64 \n", "dtypes: datetime64[ns](1), float64(80), object(2)\n", "memory usage: 719.5+ MB\n", "None\n" ] } ], "execution_count": 6 }, { "metadata": { "ExecuteTime": { "end_time": "2025-02-09T15:00:45.828404Z", "start_time": "2025-02-09T15:00:45.294830Z" } }, "cell_type": "code", "source": [ "def remove_outliers_iqr(series, lower_quantile=0.05, upper_quantile=0.95, threshold=1.5):\n", " Q1 = series.quantile(lower_quantile)\n", " Q3 = series.quantile(upper_quantile)\n", " IQR = Q3 - Q1\n", " lower_bound = Q1 - threshold * IQR\n", " upper_bound = Q3 + threshold * IQR\n", " # 过滤掉低于下边界或高于上边界的极值\n", " return (series >= lower_bound) & (series <= upper_bound)\n", "\n", "\n", "def neutralize_labels(labels, features, feature_columns, z_threshold=3, method='regression'):\n", " labels_no_outliers = remove_outliers_iqr(labels)\n", " return labels_no_outliers\n", "\n", "\n", "train_data = df[df['trade_date'] <= '2023-01-01']\n", "test_data = df[df['trade_date'] >= '2023-01-01']\n", "\n", "feature_columns = [col for col in df.columns if col not in ['trade_date',\n", " 'ts_code',\n", " 'label']]\n", "feature_columns = [col for col in feature_columns if 'future' not in col]\n", "feature_columns = [col for col in feature_columns if 'score' not in col]\n", "feature_columns = [col for col in feature_columns if col not in origin_columns]\n", "\n", "# for column in [column for column in train_data.columns if 'future' in column]:\n", "# label_index = neutralize_labels(train_data[column], train_data, feature_columns, z_threshold=3, method='regression')\n", "# train_data = train_data[label_index]\n", "# label_index = neutralize_labels(test_data[column], test_data, feature_columns, z_threshold=3, method='regression')\n", "# test_data = test_data[label_index]\n", "\n", "print(len(train_data))\n", "print(len(test_data))" ], "id": "5f3d9aece75318cd", "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "['up', 'down', 'atr_14', 'atr_6', 'obv', 'maobv_6', 'obv-maobv_6', 'rsi_3', 'rsi_6', 'rsi_9', 'return_10', 'return_20', 'avg_close_5', 'std_return_5', 'std_return_15', 'std_return_25', 'std_return_90', 'std_return_90_2', 'std_return_5 / std_return_90', 'std_return_5 / std_return_25', 'std_return_90 - std_return_90_2', 'ema_5', 'ema_13', 'ema_20', 'ema_60', 'act_factor1', 'act_factor2', 'act_factor3', 'act_factor4', 'act_factor5', 'act_factor6', 'rank_act_factor1', 'rank_act_factor2', 'rank_act_factor3', 'active_buy_volume_large', 'active_buy_volume_big', 'active_buy_volume_small', 'buy_lg_vol - sell_lg_vol', 'buy_elg_vol - sell_elg_vol']\n" ] } ], "execution_count": 19 }, { "metadata": { "ExecuteTime": { "end_time": "2025-02-09T14:56:05.319915Z", "start_time": "2025-02-09T14:56:03.355725Z" } }, "cell_type": "code", "source": [ "def get_qcuts(series, quantiles):\n", " q = pd.qcut(series, q=quantiles, labels=False, duplicates='drop')\n", " return q[-1] # 返回窗口最后一个元素的分位数标签\n", "\n", "\n", "window = 5\n", "quantiles = 20\n", "\n", "\n", "def get_label(df):\n", " labels = df['future_af13'] - df['act_factor1']\n", " # labels = df['future_close3']\n", " return labels\n", "\n", "\n", "train_data['label'], test_data['label'] = get_label(train_data), get_label(test_data)\n", "\n", "train_data, test_data = train_data.dropna(subset=['label']), test_data.dropna(subset=['label'])\n", "train_data, test_data = train_data.replace([np.inf, -np.inf], np.nan).dropna(), test_data.replace([np.inf, -np.inf],\n", " np.nan).dropna()\n", "train_data, test_data = train_data.reset_index(drop=True), test_data.reset_index(drop=True)\n", "\n", "print(len(train_data))\n", "print(len(test_data))" ], "id": "f4f16d63ad18d1bc", "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "875004\n", "最小日期: 2017-01-03\n", "最大日期: 2022-12-30\n", "260581\n", "最小日期: 2023-01-03\n", "最大日期: 2025-01-27\n" ] } ], "execution_count": 13 }, { "metadata": { "ExecuteTime": { "end_time": "2025-02-09T14:56:05.480695Z", "start_time": "2025-02-09T14:56:05.367238Z" } }, "cell_type": "code", "source": [ "import lightgbm as lgb\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "import optuna\n", "from sklearn.model_selection import KFold\n", "from sklearn.metrics import mean_absolute_error\n", "import os\n", "import json\n", "import pickle\n", "import hashlib\n", "\n", "\n", "def objective(trial, X, y, num_boost_round, params):\n", " # 参数网格\n", " X, y = X.reset_index(drop=True), y.reset_index(drop=True)\n", " param_grid = {\n", " \"n_estimators\": trial.suggest_categorical(\"n_estimators\", [10000]),\n", " \"learning_rate\": trial.suggest_float(\"learning_rate\", 0.01, 0.3),\n", " \"num_leaves\": trial.suggest_int(\"num_leaves\", 20, 3000, step=25),\n", " \"max_depth\": trial.suggest_int(\"max_depth\", 3, 16),\n", " \"min_data_in_leaf\": trial.suggest_int(\"min_data_in_leaf\", 200, 10000, step=100),\n", " \"lambda_l1\": trial.suggest_int(\"lambda_l1\", 0, 100, step=5),\n", " \"lambda_l2\": trial.suggest_int(\"lambda_l2\", 0, 100, step=5),\n", " \"min_gain_to_split\": trial.suggest_float(\"min_gain_to_split\", 0, 15),\n", " \"bagging_fraction\": trial.suggest_float(\"bagging_fraction\", 0.2, 0.95, step=0.1),\n", " \"bagging_freq\": trial.suggest_categorical(\"bagging_freq\", [1]),\n", " \"feature_fraction\": trial.suggest_float(\"feature_fraction\", 0.2, 0.95, step=0.1),\n", " \"random_state\": 1,\n", " \"objective\": 'regression',\n", " 'verbosity': -1\n", " }\n", " # 5折交叉验证\n", " cv = KFold(n_splits=5, shuffle=False)\n", "\n", " cv_scores = np.empty(5)\n", " for idx, (train_idx, test_idx) in enumerate(cv.split(X, y)):\n", " X_train, X_test = X.iloc[train_idx], X.iloc[test_idx]\n", " y_train, y_test = y[train_idx], y[test_idx]\n", "\n", " # LGBM建模\n", " model = lgb.LGBMRegressor(**param_grid, num_boost_round=num_boost_round)\n", " model.fit(\n", " X_train,\n", " y_train,\n", " eval_set=[(X_test, y_test)],\n", " eval_metric=\"l2\",\n", " callbacks=[\n", " # LightGBMPruningCallback(trial, \"l2\"),\n", " lgb.early_stopping(50, first_metric_only=True),\n", " lgb.log_evaluation(period=-1)\n", " ],\n", " )\n", " # 模型预测\n", " preds = model.predict(X_test)\n", " # 优化指标logloss最小\n", " cv_scores[idx] = mean_absolute_error(y_test, preds)\n", "\n", " return np.mean(cv_scores)\n", "\n", "def generate_key(params, feature_columns, num_boost_round):\n", " key_data = {\n", " \"params\": params,\n", " \"feature_columns\": feature_columns,\n", " \"num_boost_round\": num_boost_round\n", " }\n", " # 转换成排序后的 JSON 字符串,再生成 md5 hash\n", " key_str = json.dumps(key_data, sort_keys=True)\n", " return hashlib.md5(key_str.encode('utf-8')).hexdigest()\n", "\n", "def train_light_model(df, params, feature_columns, callbacks, evals,\n", " print_feature_importance=True, num_boost_round=100,\n", " use_optuna=False):\n", " cache_file = 'light_model.pkl'\n", " cache_key = generate_key(params, feature_columns, num_boost_round)\n", "\n", " # 检查缓存文件是否存在\n", " if os.path.exists(cache_file):\n", " try:\n", " with open(cache_file, 'rb') as f:\n", " cache_data = pickle.load(f)\n", " if cache_data.get('key') == cache_key:\n", " print(\"加载缓存模型...\")\n", " return cache_data.get('model')\n", " else:\n", " print(\"缓存模型的参数与当前参数不匹配,重新训练模型。\")\n", " except Exception as e:\n", " print(f\"加载缓存失败: {e},重新训练模型。\")\n", " else:\n", " print(\"未发现缓存模型,开始训练新模型。\")\n", " # 确保数据按照 date 和 label 排序\n", " df_sorted = df.sort_values(by=['trade_date', 'label'], ascending=[True, False]) # 按日期升序、标签降序排序\n", " df_sorted = df_sorted.sort_values(by='trade_date')\n", " unique_dates = df_sorted['trade_date'].unique()\n", " val_date_count = int(len(unique_dates) * 0.1)\n", " val_dates = unique_dates[-val_date_count:]\n", " val_indices = df_sorted[df_sorted['trade_date'].isin(val_dates)].index\n", " train_indices = df_sorted[~df_sorted['trade_date'].isin(val_dates)].index\n", "\n", " # 获取训练集和验证集的样本\n", " train_df = df_sorted.iloc[train_indices]\n", " val_df = df_sorted.iloc[val_indices]\n", "\n", " X_train = train_df[feature_columns]\n", " y_train = train_df['label']\n", "\n", " X_val = val_df[feature_columns]\n", " y_val = val_df['label']\n", "\n", " train_data = lgb.Dataset(X_train, label=y_train)\n", " val_data = lgb.Dataset(X_val, label=y_val)\n", " if use_optuna:\n", " # study = optuna.create_study(direction='minimize' if classify else 'maximize')\n", " study = optuna.create_study(direction='minimize')\n", " study.optimize(lambda trial: objective(trial, X_train, y_train, num_boost_round, params), n_trials=20)\n", "\n", " print(f\"Best parameters: {study.best_trial.params}\")\n", " print(f\"Best score: {study.best_trial.value}\")\n", "\n", " params.update(study.best_trial.params)\n", " model = lgb.train(\n", " params, train_data, num_boost_round=num_boost_round,\n", " valid_sets=[train_data, val_data], valid_names=['train', 'valid'],\n", " callbacks=callbacks\n", " )\n", "\n", " # 打印特征重要性(如果需要)\n", " if print_feature_importance:\n", " lgb.plot_metric(evals)\n", " lgb.plot_tree(model, figsize=(20, 8))\n", " lgb.plot_importance(model, importance_type='split', max_num_features=20)\n", " plt.show()\n", " # with open(cache_file, 'wb') as f:\n", " # pickle.dump({'key': cache_key,\n", " # 'model': model,\n", " # 'feature_columns': feature_columns}, f)\n", " # print(\"模型训练完成并已保存缓存。\")\n", " return model\n", "\n", "\n", "from catboost import CatBoostRegressor\n", "import pandas as pd\n", "\n", "\n", "def train_catboost(df, num_boost_round, params=None):\n", " \"\"\"\n", " 训练 CatBoost 排序模型\n", " - df: 包含因子、date、instrument 和 label 的 DataFrame\n", " - num_boost_round: 训练的轮数\n", " - print_feature_importance: 是否打印特征重要性\n", " - plot: 是否绘制特征重要性图\n", " - split_date: 用于划分训练集和验证集的日期(比如 '2020-01-01')\n", "\n", " 返回训练好的模型\n", " \"\"\"\n", " df_sorted = df.sort_values(by=['date', 'label'], ascending=[True, False])\n", "\n", " # 提取特征和标签\n", " feature_columns = [col for col in df.columns if col not in ['date',\n", " 'instrument',\n", " 'label']]\n", " feature_columns = [col for col in feature_columns if 'future' not in col]\n", " feature_columns = [col for col in feature_columns if 'score' not in col]\n", "\n", " df_sorted = df_sorted.sort_values(by='date')\n", " unique_dates = df_sorted['date'].unique()\n", " val_date_count = int(len(unique_dates) * 0.1)\n", " val_dates = unique_dates[-val_date_count:]\n", " val_indices = df_sorted[df_sorted['date'].isin(val_dates)].index\n", " train_indices = df_sorted[~df_sorted['date'].isin(val_dates)].index\n", "\n", " # 获取训练集和验证集的样本\n", " train_df = df_sorted.iloc[train_indices].sort_values(by=['date', 'label'], ascending=[True, False])\n", " val_df = df_sorted.iloc[val_indices].sort_values(by=['date', 'label'], ascending=[True, False])\n", "\n", " X_train = train_df[feature_columns]\n", " y_train = train_df['label']\n", "\n", " X_val = val_df[feature_columns]\n", " y_val = val_df['label']\n", "\n", " model = CatBoostRegressor(**params, iterations=num_boost_round)\n", " model.fit(X_train,\n", " y_train,\n", " eval_set=(X_val, y_val))\n", "\n", " return model" ], "id": "8f134d435f71e9e2", "outputs": [], "execution_count": 14 }, { "metadata": { "ExecuteTime": { "end_time": "2025-02-09T14:56:05.576927Z", "start_time": "2025-02-09T14:56:05.480695Z" } }, "cell_type": "code", "source": [ "light_params = {\n", " 'objective': 'regression',\n", " 'metric': 'l2',\n", " 'learning_rate': 0.05,\n", " 'is_unbalance': True,\n", " 'num_leaves': 2048,\n", " 'min_data_in_leaf': 16,\n", " 'max_depth': 32,\n", " 'max_bin': 1024,\n", " 'nthread': 2,\n", " 'feature_fraction': 0.7,\n", " 'bagging_fraction': 0.7,\n", " 'bagging_freq': 5,\n", " 'lambda_l1': 80,\n", " 'lambda_l2': 65,\n", " 'verbosity': -1\n", "}" ], "id": "4a4542e1ed6afe7d", "outputs": [], "execution_count": 15 }, { "metadata": { "ExecuteTime": { "end_time": "2025-02-09T14:57:25.341222Z", "start_time": "2025-02-09T14:56:05.640256Z" } }, "cell_type": "code", "source": [ "print('train data size: ', len(train_data))\n", "df = train_data\n", "\n", "evals = {}\n", "light_model = train_light_model(train_data, light_params, feature_columns,\n", " [lgb.log_evaluation(period=500),\n", " lgb.callback.record_evaluation(evals),\n", " lgb.early_stopping(50, first_metric_only=True)\n", " ], evals,\n", " num_boost_round=1000, use_optuna=False,\n", " print_feature_importance=False)" ], "id": "beeb098799ecfa6a", "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "train data size: 875004\n", "未发现缓存模型,开始训练新模型。\n", "Training until validation scores don't improve for 50 rounds\n", "Early stopping, best iteration is:\n", "[378]\ttrain's l2: 0.435049\tvalid's l2: 0.589178\n", "Evaluated only: l2\n" ] } ], "execution_count": 16 }, { "metadata": { "ExecuteTime": { "end_time": "2025-02-09T14:57:27.394697Z", "start_time": "2025-02-09T14:57:25.373274Z" } }, "cell_type": "code", "source": [ "test_data['score'] = light_model.predict(test_data[feature_columns])\n", "predictions = test_data.loc[test_data.groupby('trade_date')['score'].idxmax()]" ], "id": "5bb96ca8492e74d", "outputs": [], "execution_count": 17 }, { "metadata": { "ExecuteTime": { "end_time": "2025-02-09T14:57:27.489570Z", "start_time": "2025-02-09T14:57:27.397368Z" } }, "cell_type": "code", "source": "predictions[['trade_date', 'score', 'ts_code']].to_csv('predictions.csv', index=False)", "id": "5d1522a7538db91b", "outputs": [], "execution_count": 18 } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 2 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython2", "version": "2.7.6" } }, "nbformat": 4, "nbformat_minor": 5 }