Files
NewStock/code/train/V1.1.ipynb

1066 lines
50 KiB
Plaintext
Raw Normal View History

2025-02-12 00:21:33 +08:00
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "79a7758178bafdd3",
"metadata": {
"ExecuteTime": {
"end_time": "2025-02-09T15:47:01.758712Z",
"start_time": "2025-02-09T15:47:01.615180Z"
},
"jupyter": {
"source_hidden": true
}
},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2\n",
"\n",
"import pandas as pd\n",
"def read_and_merge_h5_data(h5_filename, key, columns, df=None):\n",
" \"\"\"\n",
" 读取 HDF5 文件中的数据,根据指定的 columns 筛选数据,\n",
" 如果传入 df 参数,则将其与读取的数据根据 ts_code 和 trade_date 合并。\n",
"\n",
" 参数:\n",
" - h5_filename: HDF5 文件名\n",
" - key: 数据存储在 HDF5 文件中的 key\n",
" - columns: 要读取的列名列表\n",
" - df: 需要合并的 DataFrame如果为空则不进行合并\n",
"\n",
" 返回:\n",
" - 合并后的 DataFrame\n",
" \"\"\"\n",
" # 处理 _ 开头的列名\n",
" processed_columns = []\n",
" for col in columns:\n",
" if col.startswith('_'):\n",
" processed_columns.append(col[1:]) # 去掉下划线\n",
" else:\n",
" processed_columns.append(col)\n",
"\n",
" # 从 HDF5 文件读取数据,选择需要的列\n",
" data = pd.read_hdf(h5_filename, key=key, columns=processed_columns)\n",
"\n",
" # 修改列名,如果列名以前有 _加上 _\n",
" for col in data.columns:\n",
" if col not in columns: # 只有不在 columns 中的列才需要加下划线\n",
" new_col = f'_{col}'\n",
" data.rename(columns={col: new_col}, inplace=True)\n",
"\n",
" # 如果传入的 df 不为空,则进行合并\n",
" if df is not None and not df.empty:\n",
" # 确保两个 DataFrame 都有 ts_code 和 trade_date 列\n",
" df['trade_date'] = pd.to_datetime(df['trade_date'], format='%Y%m%d')\n",
" data['trade_date'] = pd.to_datetime(data['trade_date'], format='%Y%m%d')\n",
"\n",
" # 根据 ts_code 和 trade_date 合并\n",
" merged_df = pd.merge(df, data, on=['ts_code', 'trade_date'], how='left')\n",
" else:\n",
" # 如果 df 为空,则直接返回读取的数据\n",
" merged_df = data\n",
"\n",
" return merged_df\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "a79cafb06a7e0e43",
"metadata": {
"ExecuteTime": {
"end_time": "2025-02-09T15:47:44.572473Z",
"start_time": "2025-02-09T15:47:01.772245Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"daily data\n",
"daily basic\n",
"stk limit\n",
"money flow\n"
]
}
],
"source": [
"print('daily data')\n",
"df = read_and_merge_h5_data('../../data/daily_data.h5', key='daily_data',\n",
" columns=['ts_code', 'trade_date', 'open', 'close', 'high', 'low', 'vol'],\n",
" df=None)\n",
"\n",
"print('daily basic')\n",
"df = read_and_merge_h5_data('../../data/daily_basic.h5', key='daily_basic',\n",
" columns=['ts_code', 'trade_date', 'turnover_rate', 'pe_ttm', 'circ_mv', 'volume_ratio',\n",
" 'is_st'], df=df)\n",
"\n",
"print('stk limit')\n",
"df = read_and_merge_h5_data('../../data/stk_limit.h5', key='stk_limit',\n",
" columns=['ts_code', 'trade_date', 'pre_close', 'up_limit', 'down_limit'],\n",
" df=df)\n",
"print('money flow')\n",
"df = read_and_merge_h5_data('../../data/money_flow.h5', key='money_flow',\n",
" columns=['ts_code', 'trade_date', 'buy_sm_vol', 'sell_sm_vol', 'buy_lg_vol', 'sell_lg_vol',\n",
" 'buy_elg_vol', 'sell_elg_vol', 'net_mf_vol'],\n",
" df=df)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "c4e9e1d31da6dba6",
"metadata": {
"ExecuteTime": {
"end_time": "2025-02-09T15:47:44.700071Z",
"start_time": "2025-02-09T15:47:44.603849Z"
},
"jupyter": {
"source_hidden": true
}
},
"outputs": [],
"source": [
"origin_columns = df.columns.tolist()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "a735bc02ceb4d872",
"metadata": {
"ExecuteTime": {
"end_time": "2025-02-09T15:47:44.856898Z",
"start_time": "2025-02-09T15:47:44.752186Z"
},
"jupyter": {
"source_hidden": true
}
},
"outputs": [],
"source": [
"import numpy as np\n",
"import talib\n",
"\n",
"def get_technical_factor(df):\n",
" # 按股票和日期排序\n",
" df = df.sort_values(by=['ts_code', 'trade_date'])\n",
" grouped = df.groupby('ts_code', group_keys=False)\n",
"\n",
" # 计算 up 和 down\n",
" df['up'] = (df['high'] - df[['close', 'open']].max(axis=1)) / df['close']\n",
" df['down'] = (df[['close', 'open']].min(axis=1) - df['low']) / df['close']\n",
"\n",
" # 计算 ATR\n",
" df['atr_14'] = grouped.apply(\n",
" lambda x: pd.Series(talib.ATR(x['high'].values, x['low'].values, x['close'].values, timeperiod=14), index=x.index)\n",
" )\n",
" df['atr_6'] = grouped.apply(\n",
" lambda x: pd.Series(talib.ATR(x['high'].values, x['low'].values, x['close'].values, timeperiod=6), index=x.index)\n",
" )\n",
"\n",
" # 计算 OBV 及其均线\n",
" df['obv'] = grouped.apply(\n",
" lambda x: pd.Series(talib.OBV(x['close'].values, x['vol'].values), index=x.index)\n",
" )\n",
" df['maobv_6'] = grouped.apply(\n",
" lambda x: pd.Series(talib.SMA(x['obv'].values, timeperiod=6), index=x.index)\n",
" )\n",
" df['obv-maobv_6'] = df['obv'] - df['maobv_6']\n",
"\n",
" # 计算 RSI\n",
" df['rsi_3'] = grouped.apply(\n",
" lambda x: pd.Series(talib.RSI(x['close'].values, timeperiod=3), index=x.index)\n",
" )\n",
" df['rsi_6'] = grouped.apply(\n",
" lambda x: pd.Series(talib.RSI(x['close'].values, timeperiod=6), index=x.index)\n",
" )\n",
" df['rsi_9'] = grouped.apply(\n",
" lambda x: pd.Series(talib.RSI(x['close'].values, timeperiod=9), index=x.index)\n",
" )\n",
"\n",
" # 计算 return_10 和 return_20\n",
" df['return_10'] = grouped['close'].apply(lambda x: x / x.shift(10) - 1)\n",
" df['return_20'] = grouped['close'].apply(lambda x: x / x.shift(20) - 1)\n",
"\n",
" # 计算 avg_close_5\n",
" df['avg_close_5'] = grouped['close'].apply(lambda x: x.rolling(window=5).mean() / x)\n",
"\n",
" # 计算标准差指标\n",
" df['std_return_5'] = grouped['close'].apply(lambda x: x.pct_change().shift(-1).rolling(window=5).std())\n",
" df['std_return_15'] = grouped['close'].apply(lambda x: x.pct_change().shift(-1).rolling(window=15).std())\n",
" df['std_return_25'] = grouped['close'].apply(lambda x: x.pct_change().shift(-1).rolling(window=25).std())\n",
" df['std_return_90'] = grouped['close'].apply(lambda x: x.pct_change().shift(-1).rolling(window=90).std())\n",
" df['std_return_90_2'] = grouped['close'].apply(lambda x: x.shift(10).pct_change().shift(-1).rolling(window=90).std())\n",
"\n",
" # 计算比值指标\n",
" df['std_return_5 / std_return_90'] = df['std_return_5'] / df['std_return_90']\n",
" df['std_return_5 / std_return_25'] = df['std_return_5'] / df['std_return_25']\n",
"\n",
" # 计算标准差差值\n",
" df['std_return_90 - std_return_90_2'] = df['std_return_90'] - df['std_return_90_2']\n",
"\n",
" return df\n",
"\n",
"\n",
"def get_act_factor(df):\n",
" # 按股票和日期排序\n",
" df = df.sort_values(by=['ts_code', 'trade_date'])\n",
" grouped = df.groupby('ts_code', group_keys=False)\n",
" # 计算 EMA 指标\n",
" df['ema_5'] = grouped['close'].apply(\n",
" lambda x: pd.Series(talib.EMA(x.values, timeperiod=5), index=x.index)\n",
" )\n",
" df['ema_13'] = grouped['close'].apply(\n",
" lambda x: pd.Series(talib.EMA(x.values, timeperiod=13), index=x.index)\n",
" )\n",
" df['ema_20'] = grouped['close'].apply(\n",
" lambda x: pd.Series(talib.EMA(x.values, timeperiod=20), index=x.index)\n",
" )\n",
" df['ema_60'] = grouped['close'].apply(\n",
" lambda x: pd.Series(talib.EMA(x.values, timeperiod=60), index=x.index)\n",
" )\n",
"\n",
" # 计算 act_factor1, act_factor2, act_factor3, act_factor4\n",
" df['act_factor1'] = grouped['ema_5'].apply(\n",
" lambda x: np.arctan((x / x.shift(1) - 1) * 100) * 57.3 / 50\n",
" )\n",
" df['act_factor2'] = grouped['ema_13'].apply(\n",
" lambda x: np.arctan((x / x.shift(1) - 1) * 100) * 57.3 / 40\n",
" )\n",
" df['act_factor3'] = grouped['ema_20'].apply(\n",
" lambda x: np.arctan((x / x.shift(1) - 1) * 100) * 57.3 / 21\n",
" )\n",
" df['act_factor4'] = grouped['ema_60'].apply(\n",
" lambda x: np.arctan((x / x.shift(1) - 1) * 100) * 57.3 / 10\n",
" )\n",
"\n",
" # 计算 act_factor5 和 act_factor6\n",
" df['act_factor5'] = df['act_factor1'] + df['act_factor2'] + df['act_factor3'] + df['act_factor4']\n",
" df['act_factor6'] = (df['act_factor1'] - df['act_factor2']) / np.sqrt(df['act_factor1']**2 + df['act_factor2']**2)\n",
"\n",
" # 根据 trade_date 截面计算排名\n",
" df['rank_act_factor1'] = df.groupby('trade_date', group_keys=False)['act_factor1'].rank(ascending=False, pct=True)\n",
" df['rank_act_factor2'] = df.groupby('trade_date', group_keys=False)['act_factor2'].rank(ascending=False, pct=True)\n",
" df['rank_act_factor3'] = df.groupby('trade_date', group_keys=False)['act_factor3'].rank(ascending=False, pct=True)\n",
"\n",
" return df\n",
"\n",
"\n",
"def get_money_flow_factor(df):\n",
" # 计算资金流相关因子(字段名称见 tushare 数据说明)\n",
" df['active_buy_volume_large'] = df['buy_lg_vol'] / df['net_mf_vol']\n",
" df['active_buy_volume_big'] = df['buy_elg_vol'] / df['net_mf_vol']\n",
" df['active_buy_volume_small'] = df['buy_sm_vol'] / df['net_mf_vol']\n",
"\n",
" df['buy_lg_vol_minus_sell_lg_vol'] = (df['buy_lg_vol'] - df['sell_lg_vol']) / df['net_mf_vol']\n",
" df['buy_elg_vol_minus_sell_elg_vol'] = (df['buy_elg_vol'] - df['sell_elg_vol']) / df['net_mf_vol']\n",
"\n",
" return df\n",
"\n",
"\n",
"def get_alpha_factor(df):\n",
" df = df.sort_values(by=['ts_code', 'trade_date'])\n",
" grouped = df.groupby('ts_code')\n",
"\n",
" # alpha_022: 当前 close 与 5 日前 close 差值\n",
" df['alpha_022'] = grouped['close'].apply(lambda x: x - x.shift(5))\n",
"\n",
" # alpha_003: (close - open) / (high - low)\n",
" df['alpha_003'] = np.where(df['high'] != df['low'],\n",
" (df['close'] - df['open']) / (df['high'] - df['low']),\n",
" 0)\n",
"\n",
" # alpha_007: 计算过去5日内 close 与 vol 的相关性,再按 trade_date 横截面排名\n",
" df['alpha_007'] = grouped.apply(\n",
" lambda x: pd.Series(x['close'].rolling(5).corr(x['vol']), index=x.index)\n",
" ).reset_index(level=0, drop=True)\n",
" df['alpha_007'] = df.groupby('trade_date', group_keys=False)['alpha_007'].rank(ascending=True, pct=True)\n",
"\n",
" # alpha_013: 计算过去5日 close 之和 - 20日 close 之和,再按 trade_date 横截面排名\n",
" df['alpha_013'] = grouped.apply(\n",
" lambda x: pd.Series(x['close'].rolling(5).sum() - x['close'].rolling(20).sum(), index=x.index)\n",
" ).reset_index(level=0, drop=True)\n",
" df['alpha_013'] = df.groupby('trade_date', group_keys=False)['alpha_013'].rank(ascending=True, pct=True)\n",
"\n",
" return df\n",
"\n",
"\n",
"def get_future_data(df):\n",
" df = df.sort_values(by=['ts_code', 'trade_date'])\n",
" # 预先对 ts_code 分组,使用 transform 保持原 DataFrame 形状\n",
" grouped = df.groupby('ts_code')\n",
"\n",
" df['future_return1'] = (grouped['close'].transform(lambda x: x.shift(-1)) - df['close']) / df['close']\n",
" df['future_return2'] = (grouped['open'].transform(lambda x: x.shift(-2)) - grouped['open'].transform(lambda x: x.shift(-1))) / grouped['open'].transform(lambda x: x.shift(-1))\n",
" df['future_return3'] = (grouped['close'].transform(lambda x: x.shift(-2)) - grouped['close'].transform(lambda x: x.shift(-1))) / grouped['close'].transform(lambda x: x.shift(-1))\n",
" df['future_return4'] = (grouped['close'].transform(lambda x: x.shift(-2)) - grouped['open'].transform(lambda x: x.shift(-1))) / grouped['open'].transform(lambda x: x.shift(-1))\n",
" df['future_return5'] = (grouped['close'].transform(lambda x: x.shift(-5)) - grouped['open'].transform(lambda x: x.shift(-1))) / grouped['open'].transform(lambda x: x.shift(-1))\n",
" df['future_return6'] = (grouped['close'].transform(lambda x: x.shift(-10)) - grouped['open'].transform(lambda x: x.shift(-1))) / grouped['open'].transform(lambda x: x.shift(-1))\n",
" df['future_return7'] = (grouped['close'].transform(lambda x: x.shift(-20)) - grouped['open'].transform(lambda x: x.shift(-1))) / grouped['open'].transform(lambda x: x.shift(-1))\n",
"\n",
" df['future_close1'] = (grouped['close'].transform(lambda x: x.shift(-1)) - df['close']) / df['close']\n",
" df['future_close2'] = (grouped['close'].transform(lambda x: x.shift(-2)) - df['close']) / df['close']\n",
" df['future_close3'] = (grouped['close'].transform(lambda x: x.shift(-3)) - df['close']) / df['close']\n",
" df['future_close4'] = (grouped['close'].transform(lambda x: x.shift(-4)) - df['close']) / df['close']\n",
" df['future_close5'] = (grouped['close'].transform(lambda x: x.shift(-5)) - df['close']) / df['close']\n",
"\n",
" df['future_af11'] = grouped['act_factor1'].transform(lambda x: x.shift(-1))\n",
" df['future_af12'] = grouped['act_factor1'].transform(lambda x: x.shift(-2))\n",
" df['future_af13'] = grouped['act_factor1'].transform(lambda x: x.shift(-3))\n",
" df['future_af14'] = grouped['act_factor1'].transform(lambda x: x.shift(-4))\n",
" df['future_af15'] = grouped['act_factor1'].transform(lambda x: x.shift(-5))\n",
"\n",
" df['future_af21'] = grouped['act_factor2'].transform(lambda x: x.shift(-1))\n",
" df['future_af22'] = grouped['act_factor2'].transform(lambda x: x.shift(-2))\n",
" df['future_af23'] = grouped['act_factor2'].transform(lambda x: x.shift(-3))\n",
" df['future_af24'] = grouped['act_factor2'].transform(lambda x: x.shift(-4))\n",
" df['future_af25'] = grouped['act_factor2'].transform(lambda x: x.shift(-5))\n",
"\n",
" df['future_af31'] = grouped['act_factor3'].transform(lambda x: x.shift(-1))\n",
" df['future_af32'] = grouped['act_factor3'].transform(lambda x: x.shift(-2))\n",
" df['future_af33'] = grouped['act_factor3'].transform(lambda x: x.shift(-3))\n",
" df['future_af34'] = grouped['act_factor3'].transform(lambda x: x.shift(-4))\n",
" df['future_af35'] = grouped['act_factor3'].transform(lambda x: x.shift(-5))\n",
"\n",
" return df\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "53f86ddc0677a6d7",
"metadata": {
"ExecuteTime": {
"end_time": "2025-02-09T15:49:48.641755Z",
"start_time": "2025-02-09T15:47:44.862968Z"
},
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<class 'pandas.core.frame.DataFrame'>\n",
"Index: 8369694 entries, 1964 to 8129890\n",
"Data columns (total 87 columns):\n",
" # Column Dtype \n",
"--- ------ ----- \n",
" 0 ts_code object \n",
" 1 trade_date datetime64[ns]\n",
" 2 open float64 \n",
" 3 close float64 \n",
" 4 high float64 \n",
" 5 low float64 \n",
" 6 vol float64 \n",
" 7 turnover_rate float64 \n",
" 8 pe_ttm float64 \n",
" 9 circ_mv float64 \n",
" 10 volume_ratio float64 \n",
" 11 is_st object \n",
" 12 up_limit float64 \n",
" 13 down_limit float64 \n",
" 14 buy_sm_vol float64 \n",
" 15 sell_sm_vol float64 \n",
" 16 buy_lg_vol float64 \n",
" 17 sell_lg_vol float64 \n",
" 18 buy_elg_vol float64 \n",
" 19 sell_elg_vol float64 \n",
" 20 net_mf_vol float64 \n",
" 21 up float64 \n",
" 22 down float64 \n",
" 23 atr_14 float64 \n",
" 24 atr_6 float64 \n",
" 25 obv float64 \n",
" 26 maobv_6 float64 \n",
" 27 obv-maobv_6 float64 \n",
" 28 rsi_3 float64 \n",
" 29 rsi_6 float64 \n",
" 30 rsi_9 float64 \n",
" 31 return_10 float64 \n",
" 32 return_20 float64 \n",
" 33 avg_close_5 float64 \n",
" 34 std_return_5 float64 \n",
" 35 std_return_15 float64 \n",
" 36 std_return_25 float64 \n",
" 37 std_return_90 float64 \n",
" 38 std_return_90_2 float64 \n",
" 39 std_return_5 / std_return_90 float64 \n",
" 40 std_return_5 / std_return_25 float64 \n",
" 41 std_return_90 - std_return_90_2 float64 \n",
" 42 ema_5 float64 \n",
" 43 ema_13 float64 \n",
" 44 ema_20 float64 \n",
" 45 ema_60 float64 \n",
" 46 act_factor1 float64 \n",
" 47 act_factor2 float64 \n",
" 48 act_factor3 float64 \n",
" 49 act_factor4 float64 \n",
" 50 act_factor5 float64 \n",
" 51 act_factor6 float64 \n",
" 52 rank_act_factor1 float64 \n",
" 53 rank_act_factor2 float64 \n",
" 54 rank_act_factor3 float64 \n",
" 55 active_buy_volume_large float64 \n",
" 56 active_buy_volume_big float64 \n",
" 57 active_buy_volume_small float64 \n",
" 58 buy_lg_vol_minus_sell_lg_vol float64 \n",
" 59 buy_elg_vol_minus_sell_elg_vol float64 \n",
" 60 future_return1 float64 \n",
" 61 future_return2 float64 \n",
" 62 future_return3 float64 \n",
" 63 future_return4 float64 \n",
" 64 future_return5 float64 \n",
" 65 future_return6 float64 \n",
" 66 future_return7 float64 \n",
" 67 future_close1 float64 \n",
" 68 future_close2 float64 \n",
" 69 future_close3 float64 \n",
" 70 future_close4 float64 \n",
" 71 future_close5 float64 \n",
" 72 future_af11 float64 \n",
" 73 future_af12 float64 \n",
" 74 future_af13 float64 \n",
" 75 future_af14 float64 \n",
" 76 future_af15 float64 \n",
" 77 future_af21 float64 \n",
" 78 future_af22 float64 \n",
" 79 future_af23 float64 \n",
" 80 future_af24 float64 \n",
" 81 future_af25 float64 \n",
" 82 future_af31 float64 \n",
" 83 future_af32 float64 \n",
" 84 future_af33 float64 \n",
" 85 future_af34 float64 \n",
" 86 future_af35 float64 \n",
"dtypes: datetime64[ns](1), float64(84), object(2)\n",
"memory usage: 5.5+ GB\n",
"None\n"
]
}
],
"source": [
"df = get_technical_factor(df)\n",
"df = get_act_factor(df)\n",
"df = get_money_flow_factor(df)\n",
"df = get_future_data(df)\n",
"# df = df.drop(columns=origin_columns)\n",
"\n",
"print(df.info())"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "dbe2fd8021b9417f",
"metadata": {
"ExecuteTime": {
"end_time": "2025-02-09T15:50:31.240188Z",
"start_time": "2025-02-09T15:49:48.842399Z"
},
"jupyter": {
"source_hidden": true
},
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<class 'pandas.core.frame.DataFrame'>\n",
"RangeIndex: 1312816 entries, 0 to 1312815\n",
"Data columns (total 87 columns):\n",
" # Column Non-Null Count Dtype \n",
"--- ------ -------------- ----- \n",
" 0 ts_code 1312816 non-null object \n",
" 1 trade_date 1312816 non-null datetime64[ns]\n",
" 2 open 1312816 non-null float64 \n",
" 3 close 1312816 non-null float64 \n",
" 4 high 1312816 non-null float64 \n",
" 5 low 1312816 non-null float64 \n",
" 6 vol 1312816 non-null float64 \n",
" 7 turnover_rate 1312816 non-null float64 \n",
" 8 pe_ttm 1124776 non-null float64 \n",
" 9 circ_mv 1312816 non-null float64 \n",
" 10 volume_ratio 1312655 non-null float64 \n",
" 11 is_st 1312816 non-null object \n",
" 12 up_limit 1312488 non-null float64 \n",
" 13 down_limit 1312488 non-null float64 \n",
" 14 buy_sm_vol 1312268 non-null float64 \n",
" 15 sell_sm_vol 1312268 non-null float64 \n",
" 16 buy_lg_vol 1312268 non-null float64 \n",
" 17 sell_lg_vol 1312268 non-null float64 \n",
" 18 buy_elg_vol 1312268 non-null float64 \n",
" 19 sell_elg_vol 1312268 non-null float64 \n",
" 20 net_mf_vol 1312268 non-null float64 \n",
" 21 up 1312816 non-null float64 \n",
" 22 down 1312816 non-null float64 \n",
" 23 atr_14 1298995 non-null float64 \n",
" 24 atr_6 1306828 non-null float64 \n",
" 25 obv 1312816 non-null float64 \n",
" 26 maobv_6 1307823 non-null float64 \n",
" 27 obv-maobv_6 1307823 non-null float64 \n",
" 28 rsi_3 1309818 non-null float64 \n",
" 29 rsi_6 1306828 non-null float64 \n",
" 30 rsi_9 1303864 non-null float64 \n",
" 31 return_10 1302880 non-null float64 \n",
" 32 return_20 1293356 non-null float64 \n",
" 33 avg_close_5 1308820 non-null float64 \n",
" 34 std_return_5 1308425 non-null float64 \n",
" 35 std_return_15 1298600 non-null float64 \n",
" 36 std_return_25 1287756 non-null float64 \n",
" 37 std_return_90 1220157 non-null float64 \n",
" 38 std_return_90_2 1209808 non-null float64 \n",
" 39 std_return_5 / std_return_90 1220157 non-null float64 \n",
" 40 std_return_5 / std_return_25 1287756 non-null float64 \n",
" 41 std_return_90 - std_return_90_2 1209808 non-null float64 \n",
" 42 ema_5 1308820 non-null float64 \n",
" 43 ema_13 1300928 non-null float64 \n",
" 44 ema_20 1294280 non-null float64 \n",
" 45 ema_60 1252148 non-null float64 \n",
" 46 act_factor1 1307823 non-null float64 \n",
" 47 act_factor2 1299958 non-null float64 \n",
" 48 act_factor3 1293356 non-null float64 \n",
" 49 act_factor4 1251052 non-null float64 \n",
" 50 act_factor5 1251052 non-null float64 \n",
" 51 act_factor6 1299958 non-null float64 \n",
" 52 rank_act_factor1 1307823 non-null float64 \n",
" 53 rank_act_factor2 1299958 non-null float64 \n",
" 54 rank_act_factor3 1293356 non-null float64 \n",
" 55 active_buy_volume_large 1312260 non-null float64 \n",
" 56 active_buy_volume_big 1312240 non-null float64 \n",
" 57 active_buy_volume_small 1312268 non-null float64 \n",
" 58 buy_lg_vol_minus_sell_lg_vol 1312260 non-null float64 \n",
" 59 buy_elg_vol_minus_sell_elg_vol 1312245 non-null float64 \n",
" 60 future_return1 1312421 non-null float64 \n",
" 61 future_return2 1311991 non-null float64 \n",
" 62 future_return3 1311991 non-null float64 \n",
" 63 future_return4 1311991 non-null float64 \n",
" 64 future_return5 1310501 non-null float64 \n",
" 65 future_return6 1307894 non-null float64 \n",
" 66 future_return7 1301736 non-null float64 \n",
" 67 future_close1 1312421 non-null float64 \n",
" 68 future_close2 1311991 non-null float64 \n",
" 69 future_close3 1311562 non-null float64 \n",
" 70 future_close4 1311085 non-null float64 \n",
" 71 future_close5 1310501 non-null float64 \n",
" 72 future_af11 1308425 non-null float64 \n",
" 73 future_af12 1308993 non-null float64 \n",
" 74 future_af13 1309562 non-null float64 \n",
" 75 future_af14 1310083 non-null float64 \n",
" 76 future_af15 1310501 non-null float64 \n",
" 77 future_af21 1300533 non-null float64 \n",
" 78 future_af22 1301077 non-null float64 \n",
" 79 future_af23 1301626 non-null float64 \n",
" 80 future_af24 1302133 non-null float64 \n",
" 81 future_af25 1302535 non-null float64 \n",
" 82 future_af31 1293885 non-null float64 \n",
" 83 future_af32 1294388 non-null float64 \n",
" 84 future_af33 1294896 non-null float64 \n",
" 85 future_af34 1295360 non-null float64 \n",
" 86 future_af35 1295722 non-null float64 \n",
"dtypes: datetime64[ns](1), float64(84), object(2)\n",
"memory usage: 871.4+ MB\n",
"None\n"
]
}
],
"source": [
"def filter_data(df):\n",
" df = df.groupby('trade_date').apply(lambda x: x.nlargest(1000, 'act_factor3'))\n",
" df = df[df['is_st'] == False]\n",
" df = df[~df['ts_code'].str.endswith('BJ')]\n",
" df = df[~df['ts_code'].str.startswith('30')]\n",
" df = df[~df['ts_code'].str.startswith('68')]\n",
" df = df[~df['ts_code'].str.startswith('8')]\n",
" df = df.reset_index(drop=True)\n",
" return df\n",
"\n",
"\n",
"df = filter_data(df)\n",
"print(df.info())"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "5f3d9aece75318cd",
"metadata": {
"ExecuteTime": {
"end_time": "2025-02-09T15:50:31.869896Z",
"start_time": "2025-02-09T15:50:31.350003Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['up', 'down', 'atr_14', 'atr_6', 'obv', 'maobv_6', 'obv-maobv_6', 'rsi_3', 'rsi_6', 'rsi_9', 'return_10', 'return_20', 'avg_close_5', 'std_return_5', 'std_return_15', 'std_return_25', 'std_return_90', 'std_return_90_2', 'std_return_5 / std_return_90', 'std_return_5 / std_return_25', 'std_return_90 - std_return_90_2', 'ema_5', 'ema_13', 'ema_20', 'ema_60', 'act_factor1', 'act_factor2', 'act_factor3', 'act_factor4', 'act_factor5', 'act_factor6', 'rank_act_factor1', 'rank_act_factor2', 'rank_act_factor3', 'active_buy_volume_large', 'active_buy_volume_big', 'active_buy_volume_small', 'buy_lg_vol_minus_sell_lg_vol', 'buy_elg_vol_minus_sell_elg_vol']\n",
"最小日期: 2017-01-03\n",
"最大日期: 2023-12-29\n",
"152435\n",
"最小日期: 2024-01-02\n",
"最大日期: 2025-02-10\n"
]
}
],
"source": [
"def remove_outliers_iqr(series, lower_quantile=0.05, upper_quantile=0.95, threshold=1.5):\n",
" Q1 = series.quantile(lower_quantile)\n",
" Q3 = series.quantile(upper_quantile)\n",
" IQR = Q3 - Q1\n",
" lower_bound = Q1 - threshold * IQR\n",
" upper_bound = Q3 + threshold * IQR\n",
" # 过滤掉低于下边界或高于上边界的极值\n",
" return (series >= lower_bound) & (series <= upper_bound)\n",
"\n",
"\n",
"def neutralize_labels(labels, features, feature_columns, z_threshold=3, method='regression'):\n",
" labels_no_outliers = remove_outliers_iqr(labels)\n",
" return labels_no_outliers\n",
"\n",
"\n",
"train_data = df[df['trade_date'] <= '2024-01-01']\n",
"test_data = df[df['trade_date'] >= '2024-01-01']\n",
"\n",
"feature_columns = [col for col in df.columns if col not in ['trade_date',\n",
" 'ts_code',\n",
" 'label']]\n",
"feature_columns = [col for col in feature_columns if 'future' not in col]\n",
"feature_columns = [col for col in feature_columns if 'score' not in col]\n",
"feature_columns = [col for col in feature_columns if col not in origin_columns]\n",
"\n",
"# for column in [column for column in train_data.columns if 'future' in column]:\n",
"# label_index = neutralize_labels(train_data[column], train_data, feature_columns, z_threshold=3, method='regression')\n",
"# train_data = train_data[label_index]\n",
"# label_index = neutralize_labels(test_data[column], test_data, feature_columns, z_threshold=3, method='regression')\n",
"# test_data = test_data[label_index]\n",
"\n",
"print(feature_columns)\n",
"print(f\"最小日期: {train_data['trade_date'].min().strftime('%Y-%m-%d')}\")\n",
"print(f\"最大日期: {train_data['trade_date'].max().strftime('%Y-%m-%d')}\")\n",
"print(len(test_data))\n",
"print(f\"最小日期: {test_data['trade_date'].min().strftime('%Y-%m-%d')}\")\n",
"print(f\"最大日期: {test_data['trade_date'].max().strftime('%Y-%m-%d')}\")"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "f4f16d63ad18d1bc",
"metadata": {
"ExecuteTime": {
"end_time": "2025-02-09T15:50:34.382521Z",
"start_time": "2025-02-09T15:50:31.885550Z"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"C:\\Users\\liaozhaorun\\AppData\\Local\\Temp\\ipykernel_113788\\2866503568.py:16: SettingWithCopyWarning: \n",
"A value is trying to be set on a copy of a slice from a DataFrame.\n",
"Try using .loc[row_indexer,col_indexer] = value instead\n",
"\n",
"See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
" train_data['label'] = get_label(train_data)\n",
"C:\\Users\\liaozhaorun\\AppData\\Local\\Temp\\ipykernel_113788\\2866503568.py:17: SettingWithCopyWarning: \n",
"A value is trying to be set on a copy of a slice from a DataFrame.\n",
"Try using .loc[row_indexer,col_indexer] = value instead\n",
"\n",
"See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
" test_data['label'] = get_label(test_data)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"913920\n",
"最小日期: 2017-06-05\n",
"最大日期: 2023-12-29\n",
"152435\n",
"最小日期: 2024-01-02\n",
"最大日期: 2025-02-10\n"
]
}
],
"source": [
"def get_qcuts(series, quantiles):\n",
" q = pd.qcut(series, q=quantiles, labels=False, duplicates='drop')\n",
" return q[-1] # 返回窗口最后一个元素的分位数标签\n",
"\n",
"\n",
"window = 5\n",
"quantiles = 20\n",
"\n",
"\n",
"def get_label(df):\n",
" labels = df['future_af11'] - df['act_factor1']\n",
" # labels = df['future_close3']\n",
" return labels\n",
"\n",
"# train_data = get_future_data(train_data)\n",
"train_data['label'] = get_label(train_data)\n",
"test_data['label'] = get_label(test_data)\n",
"\n",
"train_data = train_data.dropna(subset=['label'])\n",
"# test_data = test_data.dropna(subset=['label'])\n",
"train_data = train_data.replace([np.inf, -np.inf], np.nan).dropna()\n",
"# test_data = test_data.replace([np.inf, -np.inf], np.nan).dropna()\n",
"\n",
"# train_data, test_data = train_data.replace([np.inf, -np.inf], np.nan), test_data.replace([np.inf, -np.inf], np.nan)\n",
"# train_data = train_data.dropna(subset=['label'])\n",
"# train_data = train_data.dropna(subset=feature_columns)\n",
"# # test_data = test_data.dropna(subset=feature_columns)\n",
"train_data = train_data.reset_index(drop=True)\n",
"# test_data = test_data.reset_index(drop=True)\n",
"\n",
"print(len(train_data))\n",
"print(f\"最小日期: {train_data['trade_date'].min().strftime('%Y-%m-%d')}\")\n",
"print(f\"最大日期: {train_data['trade_date'].max().strftime('%Y-%m-%d')}\")\n",
"print(len(test_data))\n",
"print(f\"最小日期: {test_data['trade_date'].min().strftime('%Y-%m-%d')}\")\n",
"print(f\"最大日期: {test_data['trade_date'].max().strftime('%Y-%m-%d')}\")\n"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "8f134d435f71e9e2",
"metadata": {
"ExecuteTime": {
"end_time": "2025-02-09T15:50:34.548390Z",
"start_time": "2025-02-09T15:50:34.434660Z"
},
"jupyter": {
"source_hidden": true
}
},
"outputs": [],
"source": [
"import lightgbm as lgb\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import optuna\n",
"from sklearn.model_selection import KFold\n",
"from sklearn.metrics import mean_absolute_error\n",
"import os\n",
"import json\n",
"import pickle\n",
"import hashlib\n",
"\n",
"\n",
"def objective(trial, X, y, num_boost_round, params):\n",
" # 参数网格\n",
" X, y = X.reset_index(drop=True), y.reset_index(drop=True)\n",
" param_grid = {\n",
" \"n_estimators\": trial.suggest_categorical(\"n_estimators\", [10000]),\n",
" \"learning_rate\": trial.suggest_float(\"learning_rate\", 0.01, 0.3),\n",
" \"num_leaves\": trial.suggest_int(\"num_leaves\", 20, 3000, step=25),\n",
" \"max_depth\": trial.suggest_int(\"max_depth\", 3, 16),\n",
" \"min_data_in_leaf\": trial.suggest_int(\"min_data_in_leaf\", 200, 10000, step=100),\n",
" \"lambda_l1\": trial.suggest_int(\"lambda_l1\", 0, 100, step=5),\n",
" \"lambda_l2\": trial.suggest_int(\"lambda_l2\", 0, 100, step=5),\n",
" \"min_gain_to_split\": trial.suggest_float(\"min_gain_to_split\", 0, 15),\n",
" \"bagging_fraction\": trial.suggest_float(\"bagging_fraction\", 0.2, 0.95, step=0.1),\n",
" \"bagging_freq\": trial.suggest_categorical(\"bagging_freq\", [1]),\n",
" \"feature_fraction\": trial.suggest_float(\"feature_fraction\", 0.2, 0.95, step=0.1),\n",
" \"random_state\": 1,\n",
" \"objective\": 'regression',\n",
" 'verbosity': -1\n",
" }\n",
" # 5折交叉验证\n",
" cv = KFold(n_splits=5, shuffle=False)\n",
"\n",
" cv_scores = np.empty(5)\n",
" for idx, (train_idx, test_idx) in enumerate(cv.split(X, y)):\n",
" X_train, X_test = X.iloc[train_idx], X.iloc[test_idx]\n",
" y_train, y_test = y[train_idx], y[test_idx]\n",
"\n",
" # LGBM建模\n",
" model = lgb.LGBMRegressor(**param_grid, num_boost_round=num_boost_round)\n",
" model.fit(\n",
" X_train,\n",
" y_train,\n",
" eval_set=[(X_test, y_test)],\n",
" eval_metric=\"l2\",\n",
" callbacks=[\n",
" # LightGBMPruningCallback(trial, \"l2\"),\n",
" lgb.early_stopping(50, first_metric_only=True),\n",
" lgb.log_evaluation(period=-1)\n",
" ],\n",
" )\n",
" # 模型预测\n",
" preds = model.predict(X_test)\n",
" # 优化指标logloss最小\n",
" cv_scores[idx] = mean_absolute_error(y_test, preds)\n",
"\n",
" return np.mean(cv_scores)\n",
"\n",
"def generate_key(params, feature_columns, num_boost_round):\n",
" key_data = {\n",
" \"params\": params,\n",
" \"feature_columns\": feature_columns,\n",
" \"num_boost_round\": num_boost_round\n",
" }\n",
" # 转换成排序后的 JSON 字符串,再生成 md5 hash\n",
" key_str = json.dumps(key_data, sort_keys=True)\n",
" return hashlib.md5(key_str.encode('utf-8')).hexdigest()\n",
"\n",
"def train_light_model(df, params, feature_columns, callbacks, evals,\n",
" print_feature_importance=True, num_boost_round=100,\n",
" use_optuna=False):\n",
" cache_file = 'light_model.pkl'\n",
" cache_key = generate_key(params, feature_columns, num_boost_round)\n",
"\n",
" # 检查缓存文件是否存在\n",
" if os.path.exists(cache_file):\n",
" try:\n",
" with open(cache_file, 'rb') as f:\n",
" cache_data = pickle.load(f)\n",
" if cache_data.get('key') == cache_key:\n",
" print(\"加载缓存模型...\")\n",
" return cache_data.get('model')\n",
" else:\n",
" print(\"缓存模型的参数与当前参数不匹配,重新训练模型。\")\n",
" except Exception as e:\n",
" print(f\"加载缓存失败: {e},重新训练模型。\")\n",
" else:\n",
" print(\"未发现缓存模型,开始训练新模型。\")\n",
" # 确保数据按照 date 和 label 排序\n",
" df_sorted = df.sort_values(by=['trade_date', 'label'], ascending=[True, False]) # 按日期升序、标签降序排序\n",
" df_sorted = df_sorted.sort_values(by='trade_date')\n",
" unique_dates = df_sorted['trade_date'].unique()\n",
" val_date_count = int(len(unique_dates) * 0.1)\n",
" val_dates = unique_dates[-val_date_count:]\n",
" val_indices = df_sorted[df_sorted['trade_date'].isin(val_dates)].index\n",
" train_indices = df_sorted[~df_sorted['trade_date'].isin(val_dates)].index\n",
"\n",
" # 获取训练集和验证集的样本\n",
" train_df = df_sorted.iloc[train_indices]\n",
" val_df = df_sorted.iloc[val_indices]\n",
"\n",
" X_train = train_df[feature_columns]\n",
" y_train = train_df['label']\n",
"\n",
" X_val = val_df[feature_columns]\n",
" y_val = val_df['label']\n",
"\n",
" train_data = lgb.Dataset(X_train, label=y_train)\n",
" val_data = lgb.Dataset(X_val, label=y_val)\n",
" if use_optuna:\n",
" # study = optuna.create_study(direction='minimize' if classify else 'maximize')\n",
" study = optuna.create_study(direction='minimize')\n",
" study.optimize(lambda trial: objective(trial, X_train, y_train, num_boost_round, params), n_trials=20)\n",
"\n",
" print(f\"Best parameters: {study.best_trial.params}\")\n",
" print(f\"Best score: {study.best_trial.value}\")\n",
"\n",
" params.update(study.best_trial.params)\n",
" model = lgb.train(\n",
" params, train_data, num_boost_round=num_boost_round,\n",
" valid_sets=[train_data, val_data], valid_names=['train', 'valid'],\n",
" callbacks=callbacks\n",
" )\n",
"\n",
" # 打印特征重要性(如果需要)\n",
" if print_feature_importance:\n",
" lgb.plot_metric(evals)\n",
" lgb.plot_tree(model, figsize=(20, 8))\n",
" lgb.plot_importance(model, importance_type='split', max_num_features=20)\n",
" plt.show()\n",
" # with open(cache_file, 'wb') as f:\n",
" # pickle.dump({'key': cache_key,\n",
" # 'model': model,\n",
" # 'feature_columns': feature_columns}, f)\n",
" # print(\"模型训练完成并已保存缓存。\")\n",
" return model\n",
"\n",
"\n",
"from catboost import CatBoostRegressor\n",
"import pandas as pd\n",
"\n",
"\n",
"def train_catboost(df, num_boost_round, params=None):\n",
" \"\"\"\n",
" 训练 CatBoost 排序模型\n",
" - df: 包含因子、date、instrument 和 label 的 DataFrame\n",
" - num_boost_round: 训练的轮数\n",
" - print_feature_importance: 是否打印特征重要性\n",
" - plot: 是否绘制特征重要性图\n",
" - split_date: 用于划分训练集和验证集的日期(比如 '2020-01-01'\n",
"\n",
" 返回训练好的模型\n",
" \"\"\"\n",
" df_sorted = df.sort_values(by=['date', 'label'], ascending=[True, False])\n",
"\n",
" # 提取特征和标签\n",
" feature_columns = [col for col in df.columns if col not in ['date',\n",
" 'instrument',\n",
" 'label']]\n",
" feature_columns = [col for col in feature_columns if 'future' not in col]\n",
" feature_columns = [col for col in feature_columns if 'score' not in col]\n",
"\n",
" df_sorted = df_sorted.sort_values(by='date')\n",
" unique_dates = df_sorted['date'].unique()\n",
" val_date_count = int(len(unique_dates) * 0.1)\n",
" val_dates = unique_dates[-val_date_count:]\n",
" val_indices = df_sorted[df_sorted['date'].isin(val_dates)].index\n",
" train_indices = df_sorted[~df_sorted['date'].isin(val_dates)].index\n",
"\n",
" # 获取训练集和验证集的样本\n",
" train_df = df_sorted.iloc[train_indices].sort_values(by=['date', 'label'], ascending=[True, False])\n",
" val_df = df_sorted.iloc[val_indices].sort_values(by=['date', 'label'], ascending=[True, False])\n",
"\n",
" X_train = train_df[feature_columns]\n",
" y_train = train_df['label']\n",
"\n",
" X_val = val_df[feature_columns]\n",
" y_val = val_df['label']\n",
"\n",
" model = CatBoostRegressor(**params, iterations=num_boost_round)\n",
" model.fit(X_train,\n",
" y_train,\n",
" eval_set=(X_val, y_val))\n",
"\n",
" return model"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "4a4542e1ed6afe7d",
"metadata": {
"ExecuteTime": {
"end_time": "2025-02-09T15:50:34.626736Z",
"start_time": "2025-02-09T15:50:34.548390Z"
}
},
"outputs": [],
"source": [
"light_params = {\n",
" 'objective': 'regression',\n",
" 'metric': 'l2',\n",
" 'learning_rate': 0.05,\n",
" 'is_unbalance': True,\n",
" 'num_leaves': 2048,\n",
" 'min_data_in_leaf': 16,\n",
" 'max_depth': 10,\n",
" 'max_bin': 1024,\n",
" 'nthread': 2,\n",
" 'feature_fraction': 0.7,\n",
" 'bagging_fraction': 0.7,\n",
" 'bagging_freq': 5,\n",
" 'lambda_l1': 80,\n",
" 'lambda_l2': 65,\n",
" 'verbosity': -1\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "beeb098799ecfa6a",
"metadata": {
"ExecuteTime": {
"end_time": "2025-02-09T15:52:13.316938Z",
"start_time": "2025-02-09T15:50:34.658007Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"train data size: 913920\n",
"未发现缓存模型,开始训练新模型。\n",
"Training until validation scores don't improve for 50 rounds\n",
"[500]\ttrain's l2: 0.309564\tvalid's l2: 0.257146\n"
]
}
],
"source": [
"print('train data size: ', len(train_data))\n",
"\n",
"evals = {}\n",
"light_model = train_light_model(train_data, light_params, feature_columns,\n",
" [lgb.log_evaluation(period=500),\n",
" lgb.callback.record_evaluation(evals),\n",
" lgb.early_stopping(50, first_metric_only=True)\n",
" ], evals,\n",
" num_boost_round=1000, use_optuna=False,\n",
" print_feature_importance=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5bb96ca8492e74d",
"metadata": {
"ExecuteTime": {
"end_time": "2025-02-09T15:52:16.272002Z",
"start_time": "2025-02-09T15:52:13.379954Z"
}
},
"outputs": [],
"source": [
"test_data['score'] = light_model.predict(test_data[feature_columns])\n",
"predictions = test_data.loc[test_data.groupby('trade_date')['score'].idxmax()]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5d1522a7538db91b",
"metadata": {
"ExecuteTime": {
"end_time": "2025-02-09T15:52:16.381786Z",
"start_time": "2025-02-09T15:52:16.272002Z"
}
},
"outputs": [],
"source": [
"predictions[['trade_date', 'score', 'ts_code']].to_csv('predictions.tsv', index=False)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b427ce41-9739-4e9e-bea8-5f2551fec5d7",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.19"
}
},
"nbformat": 4,
"nbformat_minor": 5
}