This commit is contained in:
liaozhaorun
2025-04-03 00:45:07 +08:00
parent 01092b8cae
commit ea3955f80f
36 changed files with 44862 additions and 0 deletions

1317
code/train/ClassifyLR.ipynb Normal file

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

1385
code/train/DoubleRank.ipynb Normal file

File diff suppressed because it is too large Load Diff

File diff suppressed because one or more lines are too long

1250
code/train/RollingRank.py Normal file

File diff suppressed because it is too large Load Diff

1723
code/train/TRank.ipynb Normal file

File diff suppressed because it is too large Load Diff

1362
code/train/Transformer.ipynb Normal file

File diff suppressed because one or more lines are too long

1570
code/train/UpdateRank.ipynb Normal file

File diff suppressed because one or more lines are too long

896
code/train/V1-copy.ipynb Normal file
View File

@@ -0,0 +1,896 @@
{
"cells": [
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-02-09T14:52:54.170824Z",
"start_time": "2025-02-09T14:52:53.544850Z"
}
},
"cell_type": "code",
"source": [
"%load_ext autoreload\n",
"%autoreload 2\n",
"\n",
"from utils.utils import read_and_merge_h5_data"
],
"id": "79a7758178bafdd3",
"outputs": [],
"execution_count": 1
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-02-09T14:53:36.873700Z",
"start_time": "2025-02-09T14:52:54.170824Z"
}
},
"cell_type": "code",
"source": [
"print('daily data')\n",
"df = read_and_merge_h5_data('../../data/daily_data.h5', key='daily_data',\n",
" columns=['ts_code', 'trade_date', 'open', 'close', 'high', 'low', 'vol'],\n",
" df=None)\n",
"\n",
"print('daily basic')\n",
"df = read_and_merge_h5_data('../../data/daily_basic.h5', key='daily_basic_with_st',\n",
" columns=['ts_code', 'trade_date', 'turnover_rate', 'pe_ttm', 'circ_mv', 'volume_ratio',\n",
" 'is_st'], df=df)\n",
"\n",
"print('stk limit')\n",
"df = read_and_merge_h5_data('../../data/stk_limit.h5', key='stk_limit',\n",
" columns=['ts_code', 'trade_date', 'pre_close', 'up_limit', 'down_limit'],\n",
" df=df)\n",
"print('money flow')\n",
"df = read_and_merge_h5_data('../../data/money_flow.h5', key='money_flow',\n",
" columns=['ts_code', 'trade_date', 'buy_sm_vol', 'sell_sm_vol', 'buy_lg_vol', 'sell_lg_vol',\n",
" 'buy_elg_vol', 'sell_elg_vol', 'net_mf_vol'],\n",
" df=df)"
],
"id": "a79cafb06a7e0e43",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"daily data\n",
"daily basic\n",
"stk limit\n",
"money flow\n"
]
}
],
"execution_count": 2
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-02-09T14:53:37.426404Z",
"start_time": "2025-02-09T14:53:36.955552Z"
}
},
"cell_type": "code",
"source": "origin_columns = df.columns.tolist()",
"id": "c4e9e1d31da6dba6",
"outputs": [],
"execution_count": 3
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-02-09T14:53:38.164112Z",
"start_time": "2025-02-09T14:53:38.070007Z"
}
},
"cell_type": "code",
"source": [
"import numpy as np\n",
"import talib\n",
"\n",
"\n",
"def get_technical_factor(df):\n",
" df['up'] = (df['high'] - df[['close', 'open']].max(axis=1)) / df['close']\n",
" df['down'] = (df[['close', 'open']].min(axis=1) - df['low']) / df['close']\n",
"\n",
" df['atr_14'] = talib.ATR(df['high'], df['low'], df['close'], timeperiod=14)\n",
" df['atr_6'] = talib.ATR(df['high'], df['low'], df['close'], timeperiod=6)\n",
"\n",
" df['obv'] = talib.OBV(df['close'], df['vol'])\n",
" df['maobv_6'] = talib.SMA(df['obv'], timeperiod=6)\n",
" df['obv-maobv_6'] = df['obv'] - df['maobv_6']\n",
"\n",
" df['rsi_3'] = talib.RSI(df['close'], timeperiod=3)\n",
" df['rsi_6'] = talib.RSI(df['close'], timeperiod=6)\n",
" df['rsi_9'] = talib.RSI(df['close'], timeperiod=9)\n",
"\n",
" df['return_10'] = df['close'] / df['close'].shift(10) - 1\n",
" df['return_20'] = df['close'] / df['close'].shift(20) - 1\n",
"\n",
" # # 计算 _rank_return_10 和 _rank_return_20\n",
" # df['_rank_return_10'] = df['return_10'].rank(pct=True)\n",
" # df['_rank_return_20'] = df['return_20'].rank(pct=True)\n",
"\n",
" # 计算 avg_close_5\n",
" df['avg_close_5'] = df['close'].rolling(window=5).mean() / df['close']\n",
"\n",
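" # Compute std_return_5, std_return_15, std_return_25, std_return_90, std_return_90_2\n",
" # NOTE(review): shift(-1) moves the next row's return into the current row, so each rolling window includes one step of future data - confirm this is intended\n",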
" df['std_return_5'] = df['close'].pct_change().shift(-1).rolling(window=5).std()\n",
" df['std_return_15'] = df['close'].pct_change().shift(-1).rolling(window=15).std()\n",
" df['std_return_25'] = df['close'].pct_change().shift(-1).rolling(window=25).std()\n",
" df['std_return_90'] = df['close'].pct_change().shift(-1).rolling(window=90).std()\n",
" df['std_return_90_2'] = df['close'].shift(10).pct_change().shift(-1).rolling(window=90).std()\n",
"\n",
" # Compute std_return_5 / std_return_90 and std_return_5 / std_return_25\n",
" df['std_return_5 / std_return_90'] = df['std_return_5'] / df['std_return_90']\n",
" df['std_return_5 / std_return_25'] = df['std_return_5'] / df['std_return_25']\n",
"\n",
" # Compute std_return_90 - std_return_90_2\n",
" df['std_return_90 - std_return_90_2'] = df['std_return_90'] - df['std_return_90_2']\n",
" return df\n",
"\n",
"\n",
"def get_act_factor(df):\n",
" # 计算 m_ta_ema(close, 5), m_ta_ema(close, 13), m_ta_ema(close, 20), m_ta_ema(close, 60)\n",
" df['ema_5'] = talib.EMA(df['close'], timeperiod=5)\n",
" df['ema_13'] = talib.EMA(df['close'], timeperiod=13)\n",
" df['ema_20'] = talib.EMA(df['close'], timeperiod=20)\n",
" df['ema_60'] = talib.EMA(df['close'], timeperiod=60)\n",
"\n",
" # 计算 act_factor1, act_factor2, act_factor3, act_factor4\n",
" df['act_factor1'] = np.arctan((df['ema_5'] / df['ema_5'].shift(1) - 1) * 100) * 57.3 / 50\n",
" df['act_factor2'] = np.arctan((df['ema_13'] / df['ema_13'].shift(1) - 1) * 100) * 57.3 / 40\n",
" df['act_factor3'] = np.arctan((df['ema_20'] / df['ema_20'].shift(1) - 1) * 100) * 57.3 / 21\n",
" df['act_factor4'] = np.arctan((df['ema_60'] / df['ema_60'].shift(1) - 1) * 100) * 57.3 / 10\n",
"\n",
" # 计算 act_factor5 和 act_factor6\n",
" df['act_factor5'] = df['act_factor1'] + df['act_factor2'] + df['act_factor3'] + df['act_factor4']\n",
" df['act_factor6'] = (df['act_factor1'] - df['act_factor2']) / np.sqrt(\n",
" df['act_factor1'] ** 2 + df['act_factor2'] ** 2)\n",
"\n",
" # 根据 'trade_date' 进行分组,在每个组内分别计算 'act_factor1', 'act_factor2', 'act_factor3' 的排名\n",
" df['rank_act_factor1'] = df.groupby('trade_date')['act_factor1'].rank(ascending=False, pct=True)\n",
" df['rank_act_factor2'] = df.groupby('trade_date')['act_factor2'].rank(ascending=False, pct=True)\n",
" df['rank_act_factor3'] = df.groupby('trade_date')['act_factor3'].rank(ascending=False, pct=True)\n",
"\n",
" return df\n",
"\n",
"\n",
"def get_money_flow_factor(df):\n",
" df['active_buy_volume_large'] = df['buy_lg_vol'] / df['net_mf_vol']\n",
" df['active_buy_volume_big'] = df['buy_elg_vol'] / df['net_mf_vol']\n",
" df['active_buy_volume_small'] = df['buy_sm_vol'] / df['net_mf_vol']\n",
"\n",
" df['buy_lg_vol - sell_lg_vol'] = (df['buy_lg_vol'] - df['sell_lg_vol']) / df['net_mf_vol']\n",
" df['buy_elg_vol - sell_elg_vol'] = (df['buy_elg_vol'] - df['sell_elg_vol']) / df['net_mf_vol']\n",
"\n",
" # # 你还提到了一些其他字段:\n",
" # df['net_active_buy_volume_main'] = df['net_mf_vol'] / df['buy_sm_vol']\n",
" # df['netflow_amount_main'] = df['net_mf_vol'] / df['buy_sm_vol'] # 这里假设 'net_mf_vol' 是主流资金流\n",
"\n",
" # df['active_sell_volume_large'] = df['sell_lg_vol'] / df['sell_sm_vol']\n",
" # df['active_sell_volume_big'] = df['sell_elg_vol'] / df['sell_sm_vol']\n",
" # df['active_sell_volume_small'] = df['sell_sm_vol'] / df['sell_sm_vol']\n",
"\n",
" return df\n",
"\n",
"\n",
"def get_alpha_factor(df):\n",
" df['alpha_022'] = df['close'] - df['close'].shift(5)\n",
"\n",
" # alpha_003: (close - open) / (high - low)\n",
" df['alpha_003'] = (df['close'] - df['open']) / (df['high'] - df['low'])\n",
"\n",
" # alpha_007: rank(correlation(close, volume, 5))\n",
" df['alpha_007'] = df['close'].rolling(5).corr(df['vol']).rank(axis=1)\n",
"\n",
" # alpha_013: rank(sum(close, 5) - sum(close, 20))\n",
" df['alpha_013'] = (df['close'].rolling(5).sum() - df['close'].rolling(20).sum()).rank(axis=1)\n",
" return df\n",
"\n",
"\n",
"def get_future_data(df):\n",
" df['future_return1'] = (df['close'].shift(-1) - df['close']) / df['close']\n",
" df['future_return2'] = (df['open'].shift(-2) - df['open'].shift(-1)) / df['open'].shift(-1)\n",
" df['future_return3'] = (df['close'].shift(-2) - df['close'].shift(-1)) / df['close'].shift(-1)\n",
" df['future_return4'] = (df['close'].shift(-2) - df['open'].shift(-1)) / df['open'].shift(-1)\n",
" df['future_return5'] = (df['close'].shift(-5) - df['open'].shift(-1)) / df['open'].shift(-1)\n",
" df['future_return6'] = (df['close'].shift(-10) - df['open'].shift(-1)) / df['open'].shift(-1)\n",
" df['future_return7'] = (df['close'].shift(-20) - df['open'].shift(-1)) / df['open'].shift(-1)\n",
" df['future_close1'] = (df['close'].shift(-1) - df['close']) / df['close']\n",
" df['future_close2'] = (df['close'].shift(-2) - df['close']) / df['close']\n",
" df['future_close3'] = (df['close'].shift(-3) - df['close']) / df['close']\n",
" df['future_close4'] = (df['close'].shift(-4) - df['close']) / df['close']\n",
" df['future_close5'] = (df['close'].shift(-5) - df['close']) / df['close']\n",
" df['future_af11'] = df['act_factor1'].shift(-1)\n",
" df['future_af12'] = df['act_factor1'].shift(-2)\n",
" df['future_af13'] = df['act_factor1'].shift(-3)\n",
" df['future_af14'] = df['act_factor1'].shift(-4)\n",
" df['future_af15'] = df['act_factor1'].shift(-5)\n",
" df['future_af21'] = df['act_factor2'].shift(-1)\n",
" df['future_af22'] = df['act_factor2'].shift(-2)\n",
" df['future_af23'] = df['act_factor2'].shift(-3)\n",
" df['future_af24'] = df['act_factor2'].shift(-4)\n",
" df['future_af25'] = df['act_factor2'].shift(-5)\n",
" df['future_af31'] = df['act_factor3'].shift(-1)\n",
" df['future_af32'] = df['act_factor3'].shift(-2)\n",
" df['future_af33'] = df['act_factor3'].shift(-3)\n",
" df['future_af34'] = df['act_factor3'].shift(-4)\n",
" df['future_af35'] = df['act_factor3'].shift(-5)\n",
"\n",
" return df\n"
],
"id": "a735bc02ceb4d872",
"outputs": [],
"execution_count": 4
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-02-09T14:53:49.153376Z",
"start_time": "2025-02-09T14:53:38.164112Z"
}
},
"cell_type": "code",
"source": [
"df = get_technical_factor(df)\n",
"df = get_act_factor(df)\n",
"df = get_money_flow_factor(df)\n",
"df = get_future_data(df)\n",
"# df = df.drop(columns=origin_columns)\n",
"\n",
"print(df.info())"
],
"id": "53f86ddc0677a6d7",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<class 'pandas.core.frame.DataFrame'>\n",
"RangeIndex: 8364308 entries, 0 to 8364307\n",
"Data columns (total 83 columns):\n",
" # Column Dtype \n",
"--- ------ ----- \n",
" 0 ts_code object \n",
" 1 trade_date datetime64[ns]\n",
" 2 open float64 \n",
" 3 close float64 \n",
" 4 high float64 \n",
" 5 low float64 \n",
" 6 vol float64 \n",
" 7 is_st object \n",
" 8 up_limit float64 \n",
" 9 down_limit float64 \n",
" 10 buy_sm_vol float64 \n",
" 11 sell_sm_vol float64 \n",
" 12 buy_lg_vol float64 \n",
" 13 sell_lg_vol float64 \n",
" 14 buy_elg_vol float64 \n",
" 15 sell_elg_vol float64 \n",
" 16 net_mf_vol float64 \n",
" 17 up float64 \n",
" 18 down float64 \n",
" 19 atr_14 float64 \n",
" 20 atr_6 float64 \n",
" 21 obv float64 \n",
" 22 maobv_6 float64 \n",
" 23 obv-maobv_6 float64 \n",
" 24 rsi_3 float64 \n",
" 25 rsi_6 float64 \n",
" 26 rsi_9 float64 \n",
" 27 return_10 float64 \n",
" 28 return_20 float64 \n",
" 29 avg_close_5 float64 \n",
" 30 std_return_5 float64 \n",
" 31 std_return_15 float64 \n",
" 32 std_return_25 float64 \n",
" 33 std_return_90 float64 \n",
" 34 std_return_90_2 float64 \n",
" 35 std_return_5 / std_return_90 float64 \n",
" 36 std_return_5 / std_return_25 float64 \n",
" 37 std_return_90 - std_return_90_2 float64 \n",
" 38 ema_5 float64 \n",
" 39 ema_13 float64 \n",
" 40 ema_20 float64 \n",
" 41 ema_60 float64 \n",
" 42 act_factor1 float64 \n",
" 43 act_factor2 float64 \n",
" 44 act_factor3 float64 \n",
" 45 act_factor4 float64 \n",
" 46 act_factor5 float64 \n",
" 47 act_factor6 float64 \n",
" 48 rank_act_factor1 float64 \n",
" 49 rank_act_factor2 float64 \n",
" 50 rank_act_factor3 float64 \n",
" 51 active_buy_volume_large float64 \n",
" 52 active_buy_volume_big float64 \n",
" 53 active_buy_volume_small float64 \n",
" 54 buy_lg_vol - sell_lg_vol float64 \n",
" 55 buy_elg_vol - sell_elg_vol float64 \n",
" 56 future_return1 float64 \n",
" 57 future_return2 float64 \n",
" 58 future_return3 float64 \n",
" 59 future_return4 float64 \n",
" 60 future_return5 float64 \n",
" 61 future_return6 float64 \n",
" 62 future_return7 float64 \n",
" 63 future_close1 float64 \n",
" 64 future_close2 float64 \n",
" 65 future_close3 float64 \n",
" 66 future_close4 float64 \n",
" 67 future_close5 float64 \n",
" 68 future_af11 float64 \n",
" 69 future_af12 float64 \n",
" 70 future_af13 float64 \n",
" 71 future_af14 float64 \n",
" 72 future_af15 float64 \n",
" 73 future_af21 float64 \n",
" 74 future_af22 float64 \n",
" 75 future_af23 float64 \n",
" 76 future_af24 float64 \n",
" 77 future_af25 float64 \n",
" 78 future_af31 float64 \n",
" 79 future_af32 float64 \n",
" 80 future_af33 float64 \n",
" 81 future_af34 float64 \n",
" 82 future_af35 float64 \n",
"dtypes: datetime64[ns](1), float64(80), object(2)\n",
"memory usage: 5.2+ GB\n",
"None\n"
]
}
],
"execution_count": 5
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-02-09T14:55:28.712343Z",
"start_time": "2025-02-09T14:53:49.279168Z"
}
},
"cell_type": "code",
"source": [
"def filter_data(df):\n",
" df = df.groupby('trade_date').apply(lambda x: x.nlargest(1000, 'act_factor3'))\n",
" df = df[df['is_st'] == False]\n",
" df = df[df['is_st'] == False]\n",
" df = df[~df['ts_code'].str.startswith('30')]\n",
" df = df[~df['ts_code'].str.startswith('68')]\n",
" df = df[~df['ts_code'].str.startswith('8')]\n",
" df = df.reset_index(drop=True)\n",
" return df\n",
"\n",
"\n",
"df = filter_data(df)\n",
"print(df.info())"
],
"id": "dbe2fd8021b9417f",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<class 'pandas.core.frame.DataFrame'>\n",
"RangeIndex: 1136157 entries, 0 to 1136156\n",
"Data columns (total 83 columns):\n",
" # Column Non-Null Count Dtype \n",
"--- ------ -------------- ----- \n",
" 0 ts_code 1136157 non-null object \n",
" 1 trade_date 1136157 non-null datetime64[ns]\n",
" 2 open 1136157 non-null float64 \n",
" 3 close 1136157 non-null float64 \n",
" 4 high 1136157 non-null float64 \n",
" 5 low 1136157 non-null float64 \n",
" 6 vol 1136157 non-null float64 \n",
" 7 is_st 1136157 non-null object \n",
" 8 up_limit 1135878 non-null float64 \n",
" 9 down_limit 1135878 non-null float64 \n",
" 10 buy_sm_vol 1135663 non-null float64 \n",
" 11 sell_sm_vol 1135663 non-null float64 \n",
" 12 buy_lg_vol 1135663 non-null float64 \n",
" 13 sell_lg_vol 1135663 non-null float64 \n",
" 14 buy_elg_vol 1135663 non-null float64 \n",
" 15 sell_elg_vol 1135663 non-null float64 \n",
" 16 net_mf_vol 1135663 non-null float64 \n",
" 17 up 1136157 non-null float64 \n",
" 18 down 1136157 non-null float64 \n",
" 19 atr_14 1136157 non-null float64 \n",
" 20 atr_6 1136157 non-null float64 \n",
" 21 obv 1136157 non-null float64 \n",
" 22 maobv_6 1136157 non-null float64 \n",
" 23 obv-maobv_6 1136157 non-null float64 \n",
" 24 rsi_3 1136157 non-null float64 \n",
" 25 rsi_6 1136157 non-null float64 \n",
" 26 rsi_9 1136157 non-null float64 \n",
" 27 return_10 1136157 non-null float64 \n",
" 28 return_20 1136157 non-null float64 \n",
" 29 avg_close_5 1136157 non-null float64 \n",
" 30 std_return_5 1136157 non-null float64 \n",
" 31 std_return_15 1136157 non-null float64 \n",
" 32 std_return_25 1136157 non-null float64 \n",
" 33 std_return_90 1136131 non-null float64 \n",
" 34 std_return_90_2 1136129 non-null float64 \n",
" 35 std_return_5 / std_return_90 1136131 non-null float64 \n",
" 36 std_return_5 / std_return_25 1136157 non-null float64 \n",
" 37 std_return_90 - std_return_90_2 1136129 non-null float64 \n",
" 38 ema_5 1136157 non-null float64 \n",
" 39 ema_13 1136157 non-null float64 \n",
" 40 ema_20 1136157 non-null float64 \n",
" 41 ema_60 1136153 non-null float64 \n",
" 42 act_factor1 1136157 non-null float64 \n",
" 43 act_factor2 1136157 non-null float64 \n",
" 44 act_factor3 1136157 non-null float64 \n",
" 45 act_factor4 1136152 non-null float64 \n",
" 46 act_factor5 1136152 non-null float64 \n",
" 47 act_factor6 1136157 non-null float64 \n",
" 48 rank_act_factor1 1136157 non-null float64 \n",
" 49 rank_act_factor2 1136157 non-null float64 \n",
" 50 rank_act_factor3 1136157 non-null float64 \n",
" 51 active_buy_volume_large 1135659 non-null float64 \n",
" 52 active_buy_volume_big 1135636 non-null float64 \n",
" 53 active_buy_volume_small 1135663 non-null float64 \n",
" 54 buy_lg_vol - sell_lg_vol 1135660 non-null float64 \n",
" 55 buy_elg_vol - sell_elg_vol 1135640 non-null float64 \n",
" 56 future_return1 1136157 non-null float64 \n",
" 57 future_return2 1136157 non-null float64 \n",
" 58 future_return3 1136157 non-null float64 \n",
" 59 future_return4 1136157 non-null float64 \n",
" 60 future_return5 1136157 non-null float64 \n",
" 61 future_return6 1136157 non-null float64 \n",
" 62 future_return7 1136157 non-null float64 \n",
" 63 future_close1 1136157 non-null float64 \n",
" 64 future_close2 1136157 non-null float64 \n",
" 65 future_close3 1136157 non-null float64 \n",
" 66 future_close4 1136157 non-null float64 \n",
" 67 future_close5 1136157 non-null float64 \n",
" 68 future_af11 1136157 non-null float64 \n",
" 69 future_af12 1136157 non-null float64 \n",
" 70 future_af13 1136157 non-null float64 \n",
" 71 future_af14 1136157 non-null float64 \n",
" 72 future_af15 1136157 non-null float64 \n",
" 73 future_af21 1136157 non-null float64 \n",
" 74 future_af22 1136157 non-null float64 \n",
" 75 future_af23 1136157 non-null float64 \n",
" 76 future_af24 1136157 non-null float64 \n",
" 77 future_af25 1136157 non-null float64 \n",
" 78 future_af31 1136157 non-null float64 \n",
" 79 future_af32 1136157 non-null float64 \n",
" 80 future_af33 1136157 non-null float64 \n",
" 81 future_af34 1136157 non-null float64 \n",
" 82 future_af35 1136157 non-null float64 \n",
"dtypes: datetime64[ns](1), float64(80), object(2)\n",
"memory usage: 719.5+ MB\n",
"None\n"
]
}
],
"execution_count": 6
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-02-09T15:00:45.828404Z",
"start_time": "2025-02-09T15:00:45.294830Z"
}
},
"cell_type": "code",
"source": [
"def remove_outliers_iqr(series, lower_quantile=0.05, upper_quantile=0.95, threshold=1.5):\n",
" Q1 = series.quantile(lower_quantile)\n",
" Q3 = series.quantile(upper_quantile)\n",
" IQR = Q3 - Q1\n",
" lower_bound = Q1 - threshold * IQR\n",
" upper_bound = Q3 + threshold * IQR\n",
" # 过滤掉低于下边界或高于上边界的极值\n",
" return (series >= lower_bound) & (series <= upper_bound)\n",
"\n",
"\n",
"def neutralize_labels(labels, features, feature_columns, z_threshold=3, method='regression'):\n",
" labels_no_outliers = remove_outliers_iqr(labels)\n",
" return labels_no_outliers\n",
"\n",
"\n",
"train_data = df[df['trade_date'] <= '2023-01-01']\n",
"test_data = df[df['trade_date'] >= '2023-01-01']\n",
"\n",
"feature_columns = [col for col in df.columns if col not in ['trade_date',\n",
" 'ts_code',\n",
" 'label']]\n",
"feature_columns = [col for col in feature_columns if 'future' not in col]\n",
"feature_columns = [col for col in feature_columns if 'score' not in col]\n",
"feature_columns = [col for col in feature_columns if col not in origin_columns]\n",
"\n",
"# for column in [column for column in train_data.columns if 'future' in column]:\n",
"# label_index = neutralize_labels(train_data[column], train_data, feature_columns, z_threshold=3, method='regression')\n",
"# train_data = train_data[label_index]\n",
"# label_index = neutralize_labels(test_data[column], test_data, feature_columns, z_threshold=3, method='regression')\n",
"# test_data = test_data[label_index]\n",
"\n",
"print(len(train_data))\n",
"print(len(test_data))"
],
"id": "5f3d9aece75318cd",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['up', 'down', 'atr_14', 'atr_6', 'obv', 'maobv_6', 'obv-maobv_6', 'rsi_3', 'rsi_6', 'rsi_9', 'return_10', 'return_20', 'avg_close_5', 'std_return_5', 'std_return_15', 'std_return_25', 'std_return_90', 'std_return_90_2', 'std_return_5 / std_return_90', 'std_return_5 / std_return_25', 'std_return_90 - std_return_90_2', 'ema_5', 'ema_13', 'ema_20', 'ema_60', 'act_factor1', 'act_factor2', 'act_factor3', 'act_factor4', 'act_factor5', 'act_factor6', 'rank_act_factor1', 'rank_act_factor2', 'rank_act_factor3', 'active_buy_volume_large', 'active_buy_volume_big', 'active_buy_volume_small', 'buy_lg_vol - sell_lg_vol', 'buy_elg_vol - sell_elg_vol']\n"
]
}
],
"execution_count": 19
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-02-09T14:56:05.319915Z",
"start_time": "2025-02-09T14:56:03.355725Z"
}
},
"cell_type": "code",
"source": [
"def get_qcuts(series, quantiles):\n",
" q = pd.qcut(series, q=quantiles, labels=False, duplicates='drop')\n",
" return q[-1] # 返回窗口最后一个元素的分位数标签\n",
"\n",
"\n",
"window = 5\n",
"quantiles = 20\n",
"\n",
"\n",
"def get_label(df):\n",
" labels = df['future_af13'] - df['act_factor1']\n",
" # labels = df['future_close3']\n",
" return labels\n",
"\n",
"\n",
"train_data['label'], test_data['label'] = get_label(train_data), get_label(test_data)\n",
"\n",
"train_data, test_data = train_data.dropna(subset=['label']), test_data.dropna(subset=['label'])\n",
"train_data, test_data = train_data.replace([np.inf, -np.inf], np.nan).dropna(), test_data.replace([np.inf, -np.inf],\n",
" np.nan).dropna()\n",
"train_data, test_data = train_data.reset_index(drop=True), test_data.reset_index(drop=True)\n",
"\n",
"print(len(train_data))\n",
"print(len(test_data))"
],
"id": "f4f16d63ad18d1bc",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"875004\n",
"最小日期: 2017-01-03\n",
"最大日期: 2022-12-30\n",
"260581\n",
"最小日期: 2023-01-03\n",
"最大日期: 2025-01-27\n"
]
}
],
"execution_count": 13
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-02-09T14:56:05.480695Z",
"start_time": "2025-02-09T14:56:05.367238Z"
}
},
"cell_type": "code",
"source": [
"import lightgbm as lgb\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import optuna\n",
"from sklearn.model_selection import KFold\n",
"from sklearn.metrics import mean_absolute_error\n",
"import os\n",
"import json\n",
"import pickle\n",
"import hashlib\n",
"\n",
"\n",
"def objective(trial, X, y, num_boost_round, params):\n",
" # 参数网格\n",
" X, y = X.reset_index(drop=True), y.reset_index(drop=True)\n",
" param_grid = {\n",
" \"n_estimators\": trial.suggest_categorical(\"n_estimators\", [10000]),\n",
" \"learning_rate\": trial.suggest_float(\"learning_rate\", 0.01, 0.3),\n",
" \"num_leaves\": trial.suggest_int(\"num_leaves\", 20, 3000, step=25),\n",
" \"max_depth\": trial.suggest_int(\"max_depth\", 3, 16),\n",
" \"min_data_in_leaf\": trial.suggest_int(\"min_data_in_leaf\", 200, 10000, step=100),\n",
" \"lambda_l1\": trial.suggest_int(\"lambda_l1\", 0, 100, step=5),\n",
" \"lambda_l2\": trial.suggest_int(\"lambda_l2\", 0, 100, step=5),\n",
" \"min_gain_to_split\": trial.suggest_float(\"min_gain_to_split\", 0, 15),\n",
" \"bagging_fraction\": trial.suggest_float(\"bagging_fraction\", 0.2, 0.95, step=0.1),\n",
" \"bagging_freq\": trial.suggest_categorical(\"bagging_freq\", [1]),\n",
" \"feature_fraction\": trial.suggest_float(\"feature_fraction\", 0.2, 0.95, step=0.1),\n",
" \"random_state\": 1,\n",
" \"objective\": 'regression',\n",
" 'verbosity': -1\n",
" }\n",
" # 5折交叉验证\n",
" cv = KFold(n_splits=5, shuffle=False)\n",
"\n",
" cv_scores = np.empty(5)\n",
" for idx, (train_idx, test_idx) in enumerate(cv.split(X, y)):\n",
" X_train, X_test = X.iloc[train_idx], X.iloc[test_idx]\n",
" y_train, y_test = y[train_idx], y[test_idx]\n",
"\n",
" # LGBM建模\n",
" model = lgb.LGBMRegressor(**param_grid, num_boost_round=num_boost_round)\n",
" model.fit(\n",
" X_train,\n",
" y_train,\n",
" eval_set=[(X_test, y_test)],\n",
" eval_metric=\"l2\",\n",
" callbacks=[\n",
" # LightGBMPruningCallback(trial, \"l2\"),\n",
" lgb.early_stopping(50, first_metric_only=True),\n",
" lgb.log_evaluation(period=-1)\n",
" ],\n",
" )\n",
" # 模型预测\n",
" preds = model.predict(X_test)\n",
" # Score this fold by mean absolute error (lower is better)\n",
" cv_scores[idx] = mean_absolute_error(y_test, preds)\n",
"\n",
" return np.mean(cv_scores)\n",
"\n",
"def generate_key(params, feature_columns, num_boost_round):\n",
" key_data = {\n",
" \"params\": params,\n",
" \"feature_columns\": feature_columns,\n",
" \"num_boost_round\": num_boost_round\n",
" }\n",
" # 转换成排序后的 JSON 字符串,再生成 md5 hash\n",
" key_str = json.dumps(key_data, sort_keys=True)\n",
" return hashlib.md5(key_str.encode('utf-8')).hexdigest()\n",
"\n",
"def train_light_model(df, params, feature_columns, callbacks, evals,\n",
" print_feature_importance=True, num_boost_round=100,\n",
" use_optuna=False):\n",
" cache_file = 'light_model.pkl'\n",
" cache_key = generate_key(params, feature_columns, num_boost_round)\n",
"\n",
" # 检查缓存文件是否存在\n",
" if os.path.exists(cache_file):\n",
" try:\n",
" with open(cache_file, 'rb') as f:\n",
" cache_data = pickle.load(f)\n",
" if cache_data.get('key') == cache_key:\n",
" print(\"加载缓存模型...\")\n",
" return cache_data.get('model')\n",
" else:\n",
" print(\"缓存模型的参数与当前参数不匹配,重新训练模型。\")\n",
" except Exception as e:\n",
" print(f\"加载缓存失败: {e},重新训练模型。\")\n",
" else:\n",
" print(\"未发现缓存模型,开始训练新模型。\")\n",
" # 确保数据按照 date 和 label 排序\n",
" df_sorted = df.sort_values(by=['trade_date', 'label'], ascending=[True, False]) # 按日期升序、标签降序排序\n",
" df_sorted = df_sorted.sort_values(by='trade_date')\n",
" unique_dates = df_sorted['trade_date'].unique()\n",
" val_date_count = int(len(unique_dates) * 0.1)\n",
" val_dates = unique_dates[-val_date_count:]\n",
" val_indices = df_sorted[df_sorted['trade_date'].isin(val_dates)].index\n",
" train_indices = df_sorted[~df_sorted['trade_date'].isin(val_dates)].index\n",
"\n",
" # 获取训练集和验证集的样本\n",
" train_df = df_sorted.iloc[train_indices]\n",
" val_df = df_sorted.iloc[val_indices]\n",
"\n",
" X_train = train_df[feature_columns]\n",
" y_train = train_df['label']\n",
"\n",
" X_val = val_df[feature_columns]\n",
" y_val = val_df['label']\n",
"\n",
" train_data = lgb.Dataset(X_train, label=y_train)\n",
" val_data = lgb.Dataset(X_val, label=y_val)\n",
" if use_optuna:\n",
" # study = optuna.create_study(direction='minimize' if classify else 'maximize')\n",
" study = optuna.create_study(direction='minimize')\n",
" study.optimize(lambda trial: objective(trial, X_train, y_train, num_boost_round, params), n_trials=20)\n",
"\n",
" print(f\"Best parameters: {study.best_trial.params}\")\n",
" print(f\"Best score: {study.best_trial.value}\")\n",
"\n",
" params.update(study.best_trial.params)\n",
" model = lgb.train(\n",
" params, train_data, num_boost_round=num_boost_round,\n",
" valid_sets=[train_data, val_data], valid_names=['train', 'valid'],\n",
" callbacks=callbacks\n",
" )\n",
"\n",
" # 打印特征重要性(如果需要)\n",
" if print_feature_importance:\n",
" lgb.plot_metric(evals)\n",
" lgb.plot_tree(model, figsize=(20, 8))\n",
" lgb.plot_importance(model, importance_type='split', max_num_features=20)\n",
" plt.show()\n",
" # with open(cache_file, 'wb') as f:\n",
" # pickle.dump({'key': cache_key,\n",
" # 'model': model,\n",
" # 'feature_columns': feature_columns}, f)\n",
" # print(\"模型训练完成并已保存缓存。\")\n",
" return model\n",
"\n",
"\n",
"from catboost import CatBoostRegressor\n",
"import pandas as pd\n",
"\n",
"\n",
"def train_catboost(df, num_boost_round, params=None):\n",
" \"\"\"\n",
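" Train a CatBoost regression model (CatBoostRegressor).\n",
" - df: DataFrame containing the factor columns plus 'date', 'instrument' and 'label'\n",
" - num_boost_round: number of boosting iterations\n",
" - params: keyword arguments forwarded to CatBoostRegressor (must be a dict; the default None cannot be unpacked)\n",
"\n",
" The last 10% of unique dates are held out as the validation set.\n",
"\n",
" Returns the fitted model.\n",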
" \"\"\"\n",
" df_sorted = df.sort_values(by=['date', 'label'], ascending=[True, False])\n",
"\n",
" # 提取特征和标签\n",
" feature_columns = [col for col in df.columns if col not in ['date',\n",
" 'instrument',\n",
" 'label']]\n",
" feature_columns = [col for col in feature_columns if 'future' not in col]\n",
" feature_columns = [col for col in feature_columns if 'score' not in col]\n",
"\n",
" df_sorted = df_sorted.sort_values(by='date')\n",
" unique_dates = df_sorted['date'].unique()\n",
" val_date_count = int(len(unique_dates) * 0.1)\n",
" val_dates = unique_dates[-val_date_count:]\n",
" val_indices = df_sorted[df_sorted['date'].isin(val_dates)].index\n",
" train_indices = df_sorted[~df_sorted['date'].isin(val_dates)].index\n",
"\n",
" # 获取训练集和验证集的样本\n",
" train_df = df_sorted.iloc[train_indices].sort_values(by=['date', 'label'], ascending=[True, False])\n",
" val_df = df_sorted.iloc[val_indices].sort_values(by=['date', 'label'], ascending=[True, False])\n",
"\n",
" X_train = train_df[feature_columns]\n",
" y_train = train_df['label']\n",
"\n",
" X_val = val_df[feature_columns]\n",
" y_val = val_df['label']\n",
"\n",
" model = CatBoostRegressor(**params, iterations=num_boost_round)\n",
" model.fit(X_train,\n",
" y_train,\n",
" eval_set=(X_val, y_val))\n",
"\n",
" return model"
],
"id": "8f134d435f71e9e2",
"outputs": [],
"execution_count": 14
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-02-09T14:56:05.576927Z",
"start_time": "2025-02-09T14:56:05.480695Z"
}
},
"cell_type": "code",
"source": [
"light_params = {\n",
" 'objective': 'regression',\n",
" 'metric': 'l2',\n",
" 'learning_rate': 0.05,\n",
" 'is_unbalance': True,\n",
" 'num_leaves': 2048,\n",
" 'min_data_in_leaf': 16,\n",
" 'max_depth': 32,\n",
" 'max_bin': 1024,\n",
" 'nthread': 2,\n",
" 'feature_fraction': 0.7,\n",
" 'bagging_fraction': 0.7,\n",
" 'bagging_freq': 5,\n",
" 'lambda_l1': 80,\n",
" 'lambda_l2': 65,\n",
" 'verbosity': -1\n",
"}"
],
"id": "4a4542e1ed6afe7d",
"outputs": [],
"execution_count": 15
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-02-09T14:57:25.341222Z",
"start_time": "2025-02-09T14:56:05.640256Z"
}
},
"cell_type": "code",
"source": [
"print('train data size: ', len(train_data))\n",
"df = train_data\n",
"\n",
"evals = {}\n",
"light_model = train_light_model(train_data, light_params, feature_columns,\n",
" [lgb.log_evaluation(period=500),\n",
" lgb.callback.record_evaluation(evals),\n",
" lgb.early_stopping(50, first_metric_only=True)\n",
" ], evals,\n",
" num_boost_round=1000, use_optuna=False,\n",
" print_feature_importance=False)"
],
"id": "beeb098799ecfa6a",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"train data size: 875004\n",
"未发现缓存模型,开始训练新模型。\n",
"Training until validation scores don't improve for 50 rounds\n",
"Early stopping, best iteration is:\n",
"[378]\ttrain's l2: 0.435049\tvalid's l2: 0.589178\n",
"Evaluated only: l2\n"
]
}
],
"execution_count": 16
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-02-09T14:57:27.394697Z",
"start_time": "2025-02-09T14:57:25.373274Z"
}
},
"cell_type": "code",
"source": [
"test_data['score'] = light_model.predict(test_data[feature_columns])\n",
"predictions = test_data.loc[test_data.groupby('trade_date')['score'].idxmax()]"
],
"id": "5bb96ca8492e74d",
"outputs": [],
"execution_count": 17
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-02-09T14:57:27.489570Z",
"start_time": "2025-02-09T14:57:27.397368Z"
}
},
"cell_type": "code",
"source": "predictions[['trade_date', 'score', 'ts_code']].to_csv('predictions.csv', index=False)",
"id": "5d1522a7538db91b",
"outputs": [],
"execution_count": 18
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

BIN
code/train/best_model.pth Normal file

Binary file not shown.

File diff suppressed because it is too large Load Diff

Binary file not shown.

File diff suppressed because it is too large Load Diff

Binary file not shown.

Binary file not shown.

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

384
code/train/code.ipynb Normal file
View File

@@ -0,0 +1,384 @@
{
"cells": [
{
"metadata": {},
"cell_type": "markdown",
"source": [
"\n",
"\n",
"def get_technical_factor(df):\n",
" # 按股票和日期排序\n",
" df = df.sort_values(by=['ts_code', 'trade_date'])\n",
" grouped = df.groupby('ts_code', group_keys=False)\n",
"\n",
" df['return_skew'] = grouped['pct_chg'].rolling(window=5).skew().reset_index(0, drop=True)\n",
" df['return_kurtosis'] = grouped['pct_chg'].rolling(window=5).kurt().reset_index(0, drop=True)\n",
"\n",
" # 因子 1短期成交量变化率\n",
" df['volume_change_rate'] = (\n",
" grouped['vol'].rolling(window=2).mean() /\n",
" grouped['vol'].rolling(window=5).mean() - 1\n",
" ).reset_index(level=0, drop=True) # 确保索引对齐\n",
"\n",
" # 因子 2成交量突破信号\n",
" max_volume = grouped['vol'].rolling(window=5).max().reset_index(level=0, drop=True) # 确保索引对齐\n",
" df['cat_volume_breakout'] = (df['vol'] > max_volume)\n",
"\n",
" # 因子 3换手率均线偏离度\n",
" mean_turnover = grouped['turnover_rate'].rolling(window=3).mean().reset_index(level=0, drop=True)\n",
" std_turnover = grouped['turnover_rate'].rolling(window=3).std().reset_index(level=0, drop=True)\n",
" df['turnover_deviation'] = (df['turnover_rate'] - mean_turnover) / std_turnover\n",
"\n",
" # 因子 4换手率激增信号\n",
" df['cat_turnover_spike'] = (df['turnover_rate'] > mean_turnover + 2 * std_turnover)\n",
"\n",
" # 因子 5量比均值\n",
" df['avg_volume_ratio'] = grouped['volume_ratio'].rolling(window=3).mean().reset_index(level=0, drop=True)\n",
"\n",
" # 因子 6量比突破信号\n",
" max_volume_ratio = grouped['volume_ratio'].rolling(window=5).max().reset_index(level=0, drop=True)\n",
" df['cat_volume_ratio_breakout'] = (df['volume_ratio'] > max_volume_ratio)\n",
"\n",
" # 因子 7成交量与换手率的综合动量因子\n",
" alpha = 0.5\n",
" df['momentum_factor'] = df['volume_change_rate'] + alpha * df['turnover_deviation']\n",
"\n",
" # 因子 8量价共振因子\n",
" df['price_change_rate'] = grouped['close'].pct_change()\n",
" df['resonance_factor'] = df['volume_ratio'] * df['price_change_rate']\n",
"\n",
" # 计算 up 和 down\n",
" df['log_close'] = np.log(df['close'])\n",
"\n",
" df['vol_spike'] = grouped.apply(\n",
" lambda x: pd.Series(x['vol'].rolling(20).mean(), index=x.index)\n",
" )\n",
" df['cat_vol_spike'] = df['vol'] > 2 * df['vol_spike']\n",
" df['vol_std_5'] = df['vol'].pct_change().rolling(5).std()\n",
"\n",
" df['up'] = (df['high'] - df[['close', 'open']].max(axis=1)) / df['close']\n",
" df['down'] = (df[['close', 'open']].min(axis=1) - df['low']) / df['close']\n",
"\n",
" # 计算 ATR\n",
" df['atr_14'] = grouped.apply(\n",
" lambda x: pd.Series(talib.ATR(x['high'].values, x['low'].values, x['close'].values, timeperiod=14),\n",
" index=x.index)\n",
" )\n",
" df['atr_6'] = grouped.apply(\n",
" lambda x: pd.Series(talib.ATR(x['high'].values, x['low'].values, x['close'].values, timeperiod=6),\n",
" index=x.index)\n",
" )\n",
"\n",
" # 计算 OBV 及其均线\n",
" df['obv'] = grouped.apply(\n",
" lambda x: pd.Series(talib.OBV(x['close'].values, x['vol'].values), index=x.index)\n",
" )\n",
" df['maobv_6'] = grouped.apply(\n",
" lambda x: pd.Series(talib.SMA(x['obv'].values, timeperiod=6), index=x.index)\n",
" )\n",
" df['obv-maobv_6'] = df['obv'] - df['maobv_6']\n",
"\n",
" # 计算 RSI\n",
" df['rsi_3'] = grouped.apply(\n",
" lambda x: pd.Series(talib.RSI(x['close'].values, timeperiod=3), index=x.index)\n",
" )\n",
" df['rsi_6'] = grouped.apply(\n",
" lambda x: pd.Series(talib.RSI(x['close'].values, timeperiod=6), index=x.index)\n",
" )\n",
" df['rsi_9'] = grouped.apply(\n",
" lambda x: pd.Series(talib.RSI(x['close'].values, timeperiod=9), index=x.index)\n",
" )\n",
"\n",
" # 计算 return_10 和 return_20\n",
" df['return_5'] = grouped['close'].apply(lambda x: x / x.shift(5) - 1)\n",
" df['return_10'] = grouped['close'].apply(lambda x: x / x.shift(10) - 1)\n",
" df['return_20'] = grouped['close'].apply(lambda x: x / x.shift(20) - 1)\n",
"\n",
" # df['avg_close_5'] = grouped['close'].apply(lambda x: x.rolling(window=5).mean() / x)\n",
"\n",
" # 计算标准差指标\n",
" df['std_return_5'] = grouped['close'].apply(lambda x: x.pct_change().rolling(window=5).std())\n",
" df['std_return_15'] = grouped['close'].apply(lambda x: x.pct_change().rolling(window=15).std())\n",
" df['std_return_25'] = grouped['close'].apply(lambda x: x.pct_change().rolling(window=25).std())\n",
" df['std_return_90'] = grouped['close'].apply(lambda x: x.pct_change().rolling(window=90).std())\n",
" df['std_return_90_2'] = grouped['close'].apply(lambda x: x.shift(10).pct_change().rolling(window=90).std())\n",
"\n",
" # 计算比值指标\n",
" df['std_return_5 / std_return_90'] = df['std_return_5'] / df['std_return_90']\n",
" df['std_return_5 / std_return_25'] = df['std_return_5'] / df['std_return_25']\n",
"\n",
" # 计算标准差差值\n",
" df['std_return_90 - std_return_90_2'] = df['std_return_90'] - df['std_return_90_2']\n",
"\n",
" return df\n",
"\n",
"\n",
"def get_act_factor(df, cat=True):\n",
" # 按股票和日期排序\n",
" df = df.sort_values(by=['ts_code', 'trade_date'])\n",
" grouped = df.groupby('ts_code', group_keys=False)\n",
" # 计算 EMA 指标\n",
" df['_ema_5'] = grouped['close'].apply(\n",
" lambda x: pd.Series(talib.EMA(x.values, timeperiod=5), index=x.index)\n",
" )\n",
" df['_ema_13'] = grouped['close'].apply(\n",
" lambda x: pd.Series(talib.EMA(x.values, timeperiod=13), index=x.index)\n",
" )\n",
" df['_ema_20'] = grouped['close'].apply(\n",
" lambda x: pd.Series(talib.EMA(x.values, timeperiod=20), index=x.index)\n",
" )\n",
" df['_ema_60'] = grouped['close'].apply(\n",
" lambda x: pd.Series(talib.EMA(x.values, timeperiod=60), index=x.index)\n",
" )\n",
"\n",
" # 计算 act_factor1, act_factor2, act_factor3, act_factor4\n",
" df['act_factor1'] = grouped['_ema_5'].apply(\n",
" lambda x: np.arctan((x / x.shift(1) - 1) * 100) * 57.3 / 50\n",
" )\n",
" df['act_factor2'] = grouped['_ema_13'].apply(\n",
" lambda x: np.arctan((x / x.shift(1) - 1) * 100) * 57.3 / 40\n",
" )\n",
" df['act_factor3'] = grouped['_ema_20'].apply(\n",
" lambda x: np.arctan((x / x.shift(1) - 1) * 100) * 57.3 / 21\n",
" )\n",
" df['act_factor4'] = grouped['_ema_60'].apply(\n",
" lambda x: np.arctan((x / x.shift(1) - 1) * 100) * 57.3 / 10\n",
" )\n",
"\n",
" if cat:\n",
" df['cat_af1'] = df['act_factor1'] > 0\n",
" df['cat_af2'] = df['act_factor2'] > df['act_factor1']\n",
" df['cat_af3'] = df['act_factor3'] > df['act_factor2']\n",
" df['cat_af4'] = df['act_factor4'] > df['act_factor3']\n",
"\n",
" # 计算 act_factor5 和 act_factor6\n",
" df['act_factor5'] = df['act_factor1'] + df['act_factor2'] + df['act_factor3'] + df['act_factor4']\n",
" df['act_factor6'] = (df['act_factor1'] - df['act_factor2']) / np.sqrt(\n",
" df['act_factor1'] ** 2 + df['act_factor2'] ** 2)\n",
"\n",
" # 根据 trade_date 截面计算排名\n",
" df['rank_act_factor1'] = df.groupby('trade_date', group_keys=False)['act_factor1'].rank(ascending=False, pct=True)\n",
" df['rank_act_factor2'] = df.groupby('trade_date', group_keys=False)['act_factor2'].rank(ascending=False, pct=True)\n",
" df['rank_act_factor3'] = df.groupby('trade_date', group_keys=False)['act_factor3'].rank(ascending=False, pct=True)\n",
"\n",
" return df\n",
"\n",
"\n",
"def get_money_flow_factor(df):\n",
" # 计算资金流相关因子(字段名称见 tushare 数据说明)\n",
" df['active_buy_volume_large'] = df['buy_lg_vol'] / df['net_mf_vol']\n",
" df['active_buy_volume_big'] = df['buy_elg_vol'] / df['net_mf_vol']\n",
" df['active_buy_volume_small'] = df['buy_sm_vol'] / df['net_mf_vol']\n",
"\n",
" df['buy_lg_vol_minus_sell_lg_vol'] = (df['buy_lg_vol'] - df['sell_lg_vol']) / df['net_mf_vol']\n",
" df['buy_elg_vol_minus_sell_elg_vol'] = (df['buy_elg_vol'] - df['sell_elg_vol']) / df['net_mf_vol']\n",
"\n",
" df['log(circ_mv)'] = np.log(df['circ_mv'])\n",
" return df\n",
"\n",
"\n",
"def get_alpha_factor(df):\n",
" df = df.sort_values(by=['ts_code', 'trade_date'])\n",
" grouped = df.groupby('ts_code')\n",
"\n",
" # alpha_022: 当前 close 与 5 日前 close 差值\n",
" df['alpha_022'] = grouped['close'].transform(lambda x: x - x.shift(5))\n",
"\n",
" # alpha_003: (close - open) / (high - low)\n",
" df['alpha_003'] = np.where(df['high'] != df['low'],\n",
" (df['close'] - df['open']) / (df['high'] - df['low']),\n",
" 0)\n",
"\n",
" # alpha_007: 计算过去5日 close 与 vol 的相关性,并按 trade_date 排名\n",
" df['alpha_007'] = grouped.apply(lambda x: x['close'].rolling(5).corr(x['vol'])).reset_index(level=0, drop=True)\n",
" df['alpha_007'] = df.groupby('trade_date', group_keys=False)['alpha_007'].rank(ascending=True, pct=True)\n",
"\n",
" # alpha_013: 计算过去5日 close 之和 - 20日 close 之和,并按 trade_date 排名\n",
" df['alpha_013'] = grouped['close'].transform(lambda x: x.rolling(5).sum() - x.rolling(20).sum())\n",
" df['alpha_013'] = df.groupby('trade_date', group_keys=False)['alpha_013'].rank(ascending=True, pct=True)\n",
"\n",
" return df\n",
"\n",
"\n",
"def get_limit_factor(df):\n",
" # 按股票和日期排序\n",
" df = df.sort_values(by=['ts_code', 'trade_date'])\n",
"\n",
" # 分组处理\n",
" grouped = df.groupby('ts_code', group_keys=False)\n",
"\n",
" # 1. 今日是否涨停/跌停\n",
" df['cat_up_limit'] = (df['close'] == df['up_limit']).astype(int) # 是否涨停1表示涨停0表示未涨停\n",
" df['cat_down_limit'] = (df['close'] == df['down_limit']).astype(int) # 是否跌停1表示跌停0表示未跌停\n",
"\n",
" # 2. 最近涨跌停次数过去20个交易日\n",
" df['up_limit_count_10d'] = grouped['cat_up_limit'].rolling(window=10, min_periods=1).sum().reset_index(level=0,\n",
" drop=True)\n",
" df['down_limit_count_10d'] = grouped['cat_down_limit'].rolling(window=10, min_periods=1).sum().reset_index(level=0,\n",
" drop=True)\n",
"\n",
" # 3. 最近连续涨跌停天数\n",
" def calculate_consecutive_limits(series):\n",
" \"\"\"\n",
" 计算连续涨停/跌停天数。\n",
" \"\"\"\n",
" consecutive_up = series * (series.groupby((series != series.shift()).cumsum()).cumcount() + 1)\n",
" consecutive_down = series * (series.groupby((series != series.shift()).cumsum()).cumcount() + 1)\n",
" return consecutive_up, consecutive_down\n",
"\n",
" # 连续涨停天数\n",
" df['consecutive_up_limit'] = grouped['cat_up_limit'].apply(\n",
" lambda x: calculate_consecutive_limits(x)[0]\n",
" ).reset_index(level=0, drop=True)\n",
"\n",
" # 连续跌停天数\n",
" # df['consecutive_down_limit'] = grouped['cat_down_limit'].apply(\n",
" # lambda x: calculate_consecutive_limits(x)[1]\n",
" # ).reset_index(level=0, drop=True)\n",
"\n",
" return df\n",
"\n",
"\n",
"def get_cyp_perf_factor(df):\n",
" # 预处理:按股票代码和时间排序\n",
" df = df.sort_values(by=['ts_code', 'trade_date'])\n",
"\n",
" # 按股票代码分组处理\n",
" grouped = df.groupby('ts_code', group_keys=False)\n",
"\n",
" df['ctrl_strength'] = (df['cost_85pct'] - df['cost_15pct']) / (df['his_high'] - df['his_low'])\n",
"\n",
" df['low_cost_dev'] = (df['close'] - df['cost_5pct']) / (df['cost_50pct'] - df['cost_5pct'])\n",
"\n",
" df['asymmetry'] = (df['cost_95pct'] - df['cost_50pct']) / (df['cost_50pct'] - df['cost_5pct'])\n",
"\n",
" df['lock_factor'] = df['turnover_rate'] * (\n",
" 1 - (df['cost_95pct'] - df['cost_5pct']) / (df['his_high'] - df['his_low']))\n",
"\n",
" df['vol_break'] = np.where((df['close'] > df['cost_85pct']) & (df['volume_ratio'] > 2), 1, 0)\n",
"\n",
" df['weight_roc5'] = grouped['weight_avg'].apply(lambda x: x.pct_change(5))\n",
"\n",
" def rolling_corr(group):\n",
" roc_close = group['close'].pct_change()\n",
" roc_weight = group['weight_avg'].pct_change()\n",
" return roc_close.rolling(10).corr(roc_weight)\n",
"\n",
" df['price_cost_divergence'] = grouped.apply(rolling_corr)\n",
"\n",
" def calc_atr(group):\n",
" high, low, close = group['high'], group['low'], group['close']\n",
" tr = np.maximum(high - low,\n",
" np.maximum(abs(high - close.shift()),\n",
" abs(low - close.shift())))\n",
" return tr.rolling(14).mean()\n",
"\n",
" df['atr_14'] = grouped.apply(calc_atr)\n",
" df['cost_atr_adj'] = (df['cost_95pct'] - df['cost_5pct']) / df['atr_14']\n",
"\n",
" # 12. 小盘股筹码集中度\n",
" df['smallcap_concentration'] = (1 / df['circ_mv']) * (df['cost_85pct'] - df['cost_15pct'])\n",
"\n",
" # 16. 筹码稳定性指数 (20日波动率)\n",
" df['weight_std20'] = grouped['weight_avg'].apply(lambda x: x.rolling(20).std())\n",
" df['cost_stability'] = df['weight_std20'] / grouped['weight_avg'].transform(lambda x: x.rolling(20).mean())\n",
"\n",
" # 17. 成本区间突破标记\n",
" df['high_cost_break_days'] = grouped.apply(lambda g: g['close'].gt(g['cost_95pct']).rolling(5).sum())\n",
"\n",
" # 18. 黄金筹码共振 (复合事件)\n",
" df['cat_golden_resonance'] = ((df['close'] > df['weight_avg']) &\n",
" (df['volume_ratio'] > 1.5) &\n",
" (df['winner_rate'] > 0.7))\n",
"\n",
" # 20. 筹码-流动性风险\n",
" df['liquidity_risk'] = (df['cost_95pct'] - df['cost_5pct']) * (\n",
" 1 / grouped['vol'].transform(lambda x: x.rolling(10).mean()))\n",
"\n",
" df.drop(columns=['weight_std20'], inplace=True, errors='ignore')\n",
"\n",
" return df\n",
"\n",
"\n",
"def get_mv_factors(df):\n",
" \"\"\"\n",
" 计算多个因子并生成最终的综合因子。\n",
"\n",
" 参数:\n",
" df (pd.DataFrame): 包含 ts_code, trade_date, turnover_rate, pe_ttm, pb, ps, circ_mv, volume_ratio, vol 等列的数据框。\n",
"\n",
" 返回:\n",
" pd.DataFrame: 包含新增因子和最终综合因子的数据框。\n",
" \"\"\"\n",
" # 按 ts_code 和 trade_date 排序\n",
" df = df.sort_values(by=['ts_code', 'trade_date'])\n",
"\n",
" # 按 ts_code 分组\n",
" grouped = df.groupby('ts_code', group_keys=False)\n",
"\n",
" # 1. 市值流动比因子\n",
" df['mv_turnover_ratio'] = df['turnover_rate'] / df['circ_mv']\n",
"\n",
" # 2. 市值调整成交量因子\n",
" df['mv_adjusted_volume'] = df['vol'] / df['circ_mv']\n",
"\n",
" # 3. 市值加权换手率因子\n",
" df['mv_weighted_turnover'] = df['turnover_rate'] * (1 / df['circ_mv'])\n",
"\n",
" # 4. 非线性市值成交量因子\n",
" df['nonlinear_mv_volume'] = df['vol'] / df['circ_mv']\n",
"\n",
" # 5. 市值量比因子\n",
" df['mv_volume_ratio'] = df['volume_ratio'] / df['circ_mv']\n",
"\n",
" # 6. 市值动量因子\n",
" df['mv_momentum'] = df['turnover_rate'] * df['volume_ratio'] / df['circ_mv']\n",
"\n",
" # 7. 市值波动率因子\n",
" df['turnover_std'] = grouped['turnover_rate'].rolling(window=20).std().reset_index(level=0, drop=True)\n",
" df['mv_volatility'] = grouped.apply(lambda x: x['turnover_std'] / x['circ_mv']).reset_index(level=0, drop=True)\n",
"\n",
" # 8. 市值成长性因子\n",
" df['volume_growth'] = grouped['vol'].pct_change(periods=20).reset_index(level=0, drop=True)\n",
" df['mv_growth'] = grouped.apply(lambda x: x['volume_growth'] / x['circ_mv']).reset_index(level=0, drop=True)\n",
"\n",
" # # 标准化因子\n",
" # factor_columns = [\n",
" # 'mv_turnover_ratio', 'mv_adjusted_volume', 'mv_weighted_turnover',\n",
" # 'nonlinear_mv_volume', 'mv_volume_ratio', 'mv_momentum',\n",
" # 'mv_volatility', 'mv_growth'\n",
" # ]\n",
" # scaler = StandardScaler()\n",
" # df[factor_columns] = scaler.fit_transform(df[factor_columns])\n",
" #\n",
" # # 加权合成因子\n",
" # weights = [0.2, 0.15, 0.15, 0.1, 0.1, 0.1, 0.1, 0.1] # 各因子权重\n",
" # df['final_combined_factor'] = df[factor_columns].dot(weights)\n",
"\n",
" return df"
],
"id": "505e825945e4b8cf"
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

1359
code/train/predictions.tsv Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,262 @@
trade_date,score,ts_code
2024-03-01,0.48944956369028353,600515.SH
2024-03-04,0.4910973129007265,601377.SH
2024-03-05,0.49755441935533035,601377.SH
2024-03-06,0.48530227310531543,000950.SZ
2024-03-07,0.4852382924633875,002252.SZ
2024-03-08,0.4295690224121115,000650.SZ
2024-03-11,0.4772666567296332,002539.SZ
2024-03-12,0.44912255948878027,002030.SZ
2024-03-13,0.48036470172022555,603366.SH
2024-03-14,0.47455902892485136,002616.SZ
2024-03-15,0.4898609892396381,002616.SZ
2024-03-18,0.4898609892396381,002616.SZ
2024-03-19,0.4585803401938102,600749.SH
2024-03-20,0.4717940942781549,601616.SH
2024-03-21,0.4898609892396381,002616.SZ
2024-03-22,0.4921881826907514,600268.SH
2024-03-25,0.4529948903273269,000910.SZ
2024-03-26,0.468535149074664,600163.SH
2024-03-27,0.45345883846621166,600268.SH
2024-03-28,0.4629678911934166,603689.SH
2024-03-29,0.4576699667200496,600479.SH
2024-04-01,0.4880682627555449,000407.SZ
2024-04-02,0.4658882567202438,002228.SZ
2024-04-03,0.46661866733947643,600279.SH
2024-04-08,0.4348058590086616,601199.SH
2024-04-09,0.440404827771838,600279.SH
2024-04-10,0.45436095343545857,600279.SH
2024-04-11,0.4860610115557819,002807.SZ
2024-04-12,0.49251811801038575,002807.SZ
2024-04-15,0.45282134310413785,002267.SZ
2024-04-16,0.44573083040693495,601200.SH
2024-04-17,0.4755082230960848,600681.SH
2024-04-18,0.46969708080904654,600681.SH
2024-04-19,0.4763960575203315,601963.SH
2024-04-22,0.4797611675632624,601006.SH
2024-04-23,0.4417060612494253,600905.SH
2024-04-24,0.467078149060861,600681.SH
2024-04-25,0.48593234463506524,601016.SH
2024-04-26,0.485704669384595,600248.SH
2024-04-29,0.4856232616952012,000088.SZ
2024-04-30,0.47658127432737996,600717.SH
2024-05-06,0.485599247202529,600057.SH
2024-05-07,0.48502517567793635,600928.SH
2024-05-08,0.4875800938455082,601997.SH
2024-05-09,0.4764814479757618,601990.SH
2024-05-10,0.4767299756060024,601187.SH
2024-05-13,0.48339009051626075,002252.SZ
2024-05-14,0.49804720312809037,601236.SH
2024-05-15,0.4754440788180303,601669.SH
2024-05-16,0.47188100886784723,002252.SZ
2024-05-17,0.4738812084576224,000505.SZ
2024-05-20,0.47680745811007763,600300.SH
2024-05-21,0.4914822821325402,000919.SZ
2024-05-22,0.43568399742751623,603777.SH
2024-05-23,0.48230237417866073,000919.SZ
2024-05-24,0.45217189085141946,600846.SH
2024-05-27,0.4822386303700534,000012.SZ
2024-05-28,0.4911043379651643,000012.SZ
2024-05-29,0.4593967921412618,000012.SZ
2024-05-30,0.46327623228814613,600033.SH
2024-05-31,0.48434217554317305,601006.SH
2024-06-03,0.49496311907072943,601158.SH
2024-06-04,0.49311491712360267,601158.SH
2024-06-05,0.4662101696226407,600681.SH
2024-06-06,0.44598437609141306,600507.SH
2024-06-07,0.46683069546117434,002936.SZ
2024-06-11,0.48392216170099567,002936.SZ
2024-06-12,0.48392216170099567,002936.SZ
2024-06-13,0.4856066420239958,601555.SH
2024-06-14,0.4626378515738043,600273.SH
2024-06-17,0.4930611701154713,601868.SH
2024-06-18,0.4850760483461876,600909.SH
2024-06-19,0.48789755402053436,600909.SH
2024-06-20,0.47168404865503394,600918.SH
2024-06-21,0.4780215106901423,002939.SZ
2024-06-24,0.4670168740500108,600507.SH
2024-06-25,0.4462181161419812,600507.SH
2024-06-26,0.47871585166988867,601108.SH
2024-06-27,0.4550530100208326,601108.SH
2024-06-28,0.4682668022638408,600056.SH
2024-07-01,0.4956846242134755,601555.SH
2024-07-02,0.4855061855012652,601375.SH
2024-07-03,0.48208040944932096,000166.SZ
2024-07-04,0.4741422882624783,002763.SZ
2024-07-05,0.4207747223564723,600279.SH
2024-07-08,0.4505916830616898,000778.SZ
2024-07-09,0.47552387696457243,600108.SH
2024-07-10,0.48433484852590997,601990.SH
2024-07-11,0.4670143394143526,600597.SH
2024-07-12,0.48598259773635294,601228.SH
2024-07-15,0.46747598383004435,600704.SH
2024-07-16,0.46517324677556354,601228.SH
2024-07-17,0.4669329317249588,600681.SH
2024-07-18,0.4342287048693295,603886.SH
2024-07-19,0.4688897638752658,002936.SZ
2024-07-22,0.4655336697859469,601008.SH
2024-07-23,0.4385299287889017,000959.SZ
2024-07-24,0.40980850428167526,600059.SH
2024-07-25,0.4354808457559223,600928.SH
2024-07-26,0.43863748866816843,002936.SZ
2024-07-29,0.4757869671100767,002936.SZ
2024-07-30,0.45055950278777573,600056.SH
2024-07-31,0.43441391812299834,000672.SZ
2024-08-01,0.46163952351711074,600597.SH
2024-08-02,0.45912773679288216,601528.SH
2024-08-05,0.42637336325271274,000012.SZ
2024-08-06,0.44417354831155165,000498.SZ
2024-08-07,0.46761174293945373,002029.SZ
2024-08-08,0.44084265961702834,600061.SH
2024-08-09,0.44756951600219536,600061.SH
2024-08-12,0.43865468410238356,600517.SH
2024-08-13,0.40119132713124916,600798.SH
2024-08-14,0.4147687048594177,600279.SH
2024-08-15,0.4285641100025751,600283.SH
2024-08-16,0.43481178678490995,600279.SH
2024-08-19,0.4334514206803078,000055.SZ
2024-08-20,0.4271544196243579,600925.SH
2024-08-21,0.44140030057394675,601718.SH
2024-08-22,0.4260043286627239,002266.SZ
2024-08-23,0.4272469992425721,002266.SZ
2024-08-26,0.43298605818458036,601718.SH
2024-08-27,0.42993335252901094,600308.SH
2024-08-28,0.44014909636073596,601216.SH
2024-08-29,0.4549086267823673,002108.SZ
2024-08-30,0.4432950051114155,603817.SH
2024-09-02,0.44924217940667366,603759.SH
2024-09-03,0.45929491873592476,603759.SH
2024-09-04,0.4555291720504659,603817.SH
2024-09-05,0.45486187663776934,000581.SZ
2024-09-06,0.458767486527876,600016.SH
2024-09-09,0.42622859802922114,000725.SZ
2024-09-10,0.47344109719180894,002239.SZ
2024-09-11,0.4602775423090333,600050.SH
2024-09-12,0.4581305095531178,603856.SH
2024-09-13,0.48196532955068866,002697.SZ
2024-09-18,0.48196532955068866,002697.SZ
2024-09-19,0.4448663482243619,600210.SH
2024-09-20,0.4894480847057984,600293.SH
2024-09-23,0.48392216170099567,002239.SZ
2024-09-24,0.43493260807615913,603182.SH
2024-09-25,0.4279590881784511,601369.SH
2024-09-26,0.457735727285402,600526.SH
2024-09-27,0.464687096497251,002818.SZ
2024-09-30,0.21397500413643297,600081.SH
2024-10-08,0.22763716829204592,605333.SH
2024-10-09,0.44498910797127905,600495.SH
2024-10-10,0.4635725472634731,600251.SH
2024-10-11,0.46083590103602623,000967.SZ
2024-10-14,0.47310129519500743,600251.SH
2024-10-15,0.4432757845922387,601727.SH
2024-10-16,0.49142274407028197,002267.SZ
2024-10-17,0.4969633584025033,600032.SH
2024-10-18,0.49050625194789943,002267.SZ
2024-10-21,0.4839204785725789,000709.SZ
2024-10-22,0.47843667953797847,002237.SZ
2024-10-23,0.4785259848937853,601577.SH
2024-10-24,0.48592022305116356,601577.SH
2024-10-25,0.4799806860019888,600820.SH
2024-10-28,0.4616976069244002,002135.SZ
2024-10-29,0.48907368099203313,600330.SH
2024-10-30,0.47786058189302977,600516.SH
2024-10-31,0.46572447890972274,600969.SH
2024-11-01,0.46269570933167936,600261.SH
2024-11-04,0.4819410916205754,600249.SH
2024-11-05,0.4866489103957071,603167.SH
2024-11-06,0.485704669384595,600423.SH
2024-11-07,0.4906974157098039,600249.SH
2024-11-08,0.47099968069944914,002328.SZ
2024-11-11,0.47924756292999116,600103.SH
2024-11-12,0.4864515748363791,600103.SH
2024-11-13,0.4682636704183581,600419.SH
2024-11-14,0.47314641913653194,000533.SZ
2024-11-15,0.48200633444426155,600103.SH
2024-11-18,0.4734312305146418,600300.SH
2024-11-19,0.481591587838489,601113.SH
2024-11-20,0.485704669384595,600493.SH
2024-11-21,0.48103812387602335,600515.SH
2024-11-22,0.48907368099203313,600284.SH
2024-11-25,0.49351892138415066,603111.SH
2024-11-26,0.49311491712360267,601187.SH
2024-11-27,0.49823564043257407,601187.SH
2024-11-28,0.49478522113574214,002091.SZ
2024-11-29,0.4744243783704977,002390.SZ
2024-12-02,0.47703304354620096,002390.SZ
2024-12-03,0.48674050526244056,002566.SZ
2024-12-04,0.4959275225246577,002753.SZ
2024-12-05,0.48839819807517926,002753.SZ
2024-12-06,0.4870618149295468,603128.SH
2024-12-09,0.49249410351771356,600425.SH
2024-12-10,0.4959275225246577,002772.SZ
2024-12-11,0.48502517567793635,600035.SH
2024-12-12,0.48907368099203313,600035.SH
2024-12-13,0.4901095168698787,600035.SH
2024-12-16,0.485704669384595,603577.SH
2024-12-17,0.42271811250681296,000533.SZ
2024-12-18,0.45253512664134277,000026.SZ
2024-12-19,0.4668403018194925,601000.SH
2024-12-20,0.47643624443298405,000026.SZ
2024-12-23,0.4405301681206755,600305.SH
2024-12-24,0.44052806294048863,000883.SZ
2024-12-25,0.448810439825342,000589.SZ
2024-12-26,0.47338854249265216,600582.SH
2024-12-27,0.4495112908067394,000830.SZ
2024-12-30,0.463164320301011,601006.SH
2024-12-31,0.4773178744459276,600004.SH
2025-01-02,0.47702414327290926,600572.SH
2025-01-03,0.4474782636368997,601163.SH
2025-01-06,0.4333771722554744,600821.SH
2025-01-07,0.44770531932040636,600004.SH
2025-01-08,0.47082512104142743,600116.SH
2025-01-09,0.4541624257750102,600004.SH
2025-01-10,0.4505883376349118,600905.SH
2025-01-13,0.4505883376349118,600905.SH
2025-01-14,0.4629277092655053,000088.SZ
2025-01-15,0.4849467618585074,601222.SH
2025-01-16,0.4904213884330184,600273.SH
2025-01-17,0.49427977331421474,002267.SZ
2025-01-20,0.4921881826907514,000088.SZ
2025-01-21,0.49980885843191936,002233.SZ
2025-01-22,0.49226959038014517,603817.SH
2025-01-23,0.4845623679956794,600731.SH
2025-01-24,0.4894480847057984,002443.SZ
2025-01-27,0.480828304199342,600475.SH
2025-02-05,0.49381191740852215,600475.SH
2025-02-06,0.48051107405830695,600219.SH
2025-02-07,0.4912337545022996,002365.SZ
2025-02-10,0.4891404807504195,002606.SZ
2025-02-11,0.48918316877240914,002454.SZ
2025-02-12,0.48841224882795287,605138.SH
2025-02-13,0.49427977331421474,603022.SH
2025-02-14,0.4961990011809636,002454.SZ
2025-02-17,0.49980885843191936,603172.SH
2025-02-18,0.4859761480409009,600526.SH
2025-02-19,0.47951904158629705,600526.SH
2025-02-20,0.4865640468808261,002972.SZ
2025-02-21,0.4878226668596109,002972.SZ
2025-02-24,0.49427977331421474,002972.SZ
2025-02-25,0.4959275225246577,002972.SZ
2025-02-26,0.49559758720502334,000850.SZ
2025-02-27,0.4962265008275085,600969.SH
2025-02-28,0.45857905703475876,000931.SZ
2025-03-03,0.475929514181944,600704.SH
2025-03-04,0.4855766924008325,603176.SH
2025-03-05,0.48947041607005387,600749.SH
2025-03-06,0.48947041607005387,600749.SH
2025-03-07,0.48852295338866114,002948.SZ
2025-03-10,0.48013938012027435,600749.SH
2025-03-11,0.4697442502219659,603916.SH
2025-03-12,0.46825376442882116,600969.SH
2025-03-13,0.485704669384595,601311.SH
2025-03-14,0.47290759442027386,600784.SH
2025-03-17,0.48236846890282975,002204.SZ
2025-03-18,0.4753161955607809,600784.SH
2025-03-19,0.4898609892396381,002627.SZ
2025-03-20,0.4898609892396381,002627.SZ
2025-03-21,0.47660773064492085,000589.SZ
2025-03-24,0.4753538392607698,000589.SZ
2025-03-25,0.4628733846203298,000589.SZ
2025-03-26,0.45948501496487415,603367.SH
2025-03-27,0.47591884404751766,600017.SH
2025-03-28,0.4765671044505851,600925.SH
1 trade_date,score,ts_code
2 2024-03-01,0.48944956369028353,600515.SH
3 2024-03-04,0.4910973129007265,601377.SH
4 2024-03-05,0.49755441935533035,601377.SH
5 2024-03-06,0.48530227310531543,000950.SZ
6 2024-03-07,0.4852382924633875,002252.SZ
7 2024-03-08,0.4295690224121115,000650.SZ
8 2024-03-11,0.4772666567296332,002539.SZ
9 2024-03-12,0.44912255948878027,002030.SZ
10 2024-03-13,0.48036470172022555,603366.SH
11 2024-03-14,0.47455902892485136,002616.SZ
12 2024-03-15,0.4898609892396381,002616.SZ
13 2024-03-18,0.4898609892396381,002616.SZ
14 2024-03-19,0.4585803401938102,600749.SH
15 2024-03-20,0.4717940942781549,601616.SH
16 2024-03-21,0.4898609892396381,002616.SZ
17 2024-03-22,0.4921881826907514,600268.SH
18 2024-03-25,0.4529948903273269,000910.SZ
19 2024-03-26,0.468535149074664,600163.SH
20 2024-03-27,0.45345883846621166,600268.SH
21 2024-03-28,0.4629678911934166,603689.SH
22 2024-03-29,0.4576699667200496,600479.SH
23 2024-04-01,0.4880682627555449,000407.SZ
24 2024-04-02,0.4658882567202438,002228.SZ
25 2024-04-03,0.46661866733947643,600279.SH
26 2024-04-08,0.4348058590086616,601199.SH
27 2024-04-09,0.440404827771838,600279.SH
28 2024-04-10,0.45436095343545857,600279.SH
29 2024-04-11,0.4860610115557819,002807.SZ
30 2024-04-12,0.49251811801038575,002807.SZ
31 2024-04-15,0.45282134310413785,002267.SZ
32 2024-04-16,0.44573083040693495,601200.SH
33 2024-04-17,0.4755082230960848,600681.SH
34 2024-04-18,0.46969708080904654,600681.SH
35 2024-04-19,0.4763960575203315,601963.SH
36 2024-04-22,0.4797611675632624,601006.SH
37 2024-04-23,0.4417060612494253,600905.SH
38 2024-04-24,0.467078149060861,600681.SH
39 2024-04-25,0.48593234463506524,601016.SH
40 2024-04-26,0.485704669384595,600248.SH
41 2024-04-29,0.4856232616952012,000088.SZ
42 2024-04-30,0.47658127432737996,600717.SH
43 2024-05-06,0.485599247202529,600057.SH
44 2024-05-07,0.48502517567793635,600928.SH
45 2024-05-08,0.4875800938455082,601997.SH
46 2024-05-09,0.4764814479757618,601990.SH
47 2024-05-10,0.4767299756060024,601187.SH
48 2024-05-13,0.48339009051626075,002252.SZ
49 2024-05-14,0.49804720312809037,601236.SH
50 2024-05-15,0.4754440788180303,601669.SH
51 2024-05-16,0.47188100886784723,002252.SZ
52 2024-05-17,0.4738812084576224,000505.SZ
53 2024-05-20,0.47680745811007763,600300.SH
54 2024-05-21,0.4914822821325402,000919.SZ
55 2024-05-22,0.43568399742751623,603777.SH
56 2024-05-23,0.48230237417866073,000919.SZ
57 2024-05-24,0.45217189085141946,600846.SH
58 2024-05-27,0.4822386303700534,000012.SZ
59 2024-05-28,0.4911043379651643,000012.SZ
60 2024-05-29,0.4593967921412618,000012.SZ
61 2024-05-30,0.46327623228814613,600033.SH
62 2024-05-31,0.48434217554317305,601006.SH
63 2024-06-03,0.49496311907072943,601158.SH
64 2024-06-04,0.49311491712360267,601158.SH
65 2024-06-05,0.4662101696226407,600681.SH
66 2024-06-06,0.44598437609141306,600507.SH
67 2024-06-07,0.46683069546117434,002936.SZ
68 2024-06-11,0.48392216170099567,002936.SZ
69 2024-06-12,0.48392216170099567,002936.SZ
70 2024-06-13,0.4856066420239958,601555.SH
71 2024-06-14,0.4626378515738043,600273.SH
72 2024-06-17,0.4930611701154713,601868.SH
73 2024-06-18,0.4850760483461876,600909.SH
74 2024-06-19,0.48789755402053436,600909.SH
75 2024-06-20,0.47168404865503394,600918.SH
76 2024-06-21,0.4780215106901423,002939.SZ
77 2024-06-24,0.4670168740500108,600507.SH
78 2024-06-25,0.4462181161419812,600507.SH
79 2024-06-26,0.47871585166988867,601108.SH
80 2024-06-27,0.4550530100208326,601108.SH
81 2024-06-28,0.4682668022638408,600056.SH
82 2024-07-01,0.4956846242134755,601555.SH
83 2024-07-02,0.4855061855012652,601375.SH
84 2024-07-03,0.48208040944932096,000166.SZ
85 2024-07-04,0.4741422882624783,002763.SZ
86 2024-07-05,0.4207747223564723,600279.SH
87 2024-07-08,0.4505916830616898,000778.SZ
88 2024-07-09,0.47552387696457243,600108.SH
89 2024-07-10,0.48433484852590997,601990.SH
90 2024-07-11,0.4670143394143526,600597.SH
91 2024-07-12,0.48598259773635294,601228.SH
92 2024-07-15,0.46747598383004435,600704.SH
93 2024-07-16,0.46517324677556354,601228.SH
94 2024-07-17,0.4669329317249588,600681.SH
95 2024-07-18,0.4342287048693295,603886.SH
96 2024-07-19,0.4688897638752658,002936.SZ
97 2024-07-22,0.4655336697859469,601008.SH
98 2024-07-23,0.4385299287889017,000959.SZ
99 2024-07-24,0.40980850428167526,600059.SH
100 2024-07-25,0.4354808457559223,600928.SH
101 2024-07-26,0.43863748866816843,002936.SZ
102 2024-07-29,0.4757869671100767,002936.SZ
103 2024-07-30,0.45055950278777573,600056.SH
104 2024-07-31,0.43441391812299834,000672.SZ
105 2024-08-01,0.46163952351711074,600597.SH
106 2024-08-02,0.45912773679288216,601528.SH
107 2024-08-05,0.42637336325271274,000012.SZ
108 2024-08-06,0.44417354831155165,000498.SZ
109 2024-08-07,0.46761174293945373,002029.SZ
110 2024-08-08,0.44084265961702834,600061.SH
111 2024-08-09,0.44756951600219536,600061.SH
112 2024-08-12,0.43865468410238356,600517.SH
113 2024-08-13,0.40119132713124916,600798.SH
114 2024-08-14,0.4147687048594177,600279.SH
115 2024-08-15,0.4285641100025751,600283.SH
116 2024-08-16,0.43481178678490995,600279.SH
117 2024-08-19,0.4334514206803078,000055.SZ
118 2024-08-20,0.4271544196243579,600925.SH
119 2024-08-21,0.44140030057394675,601718.SH
120 2024-08-22,0.4260043286627239,002266.SZ
121 2024-08-23,0.4272469992425721,002266.SZ
122 2024-08-26,0.43298605818458036,601718.SH
123 2024-08-27,0.42993335252901094,600308.SH
124 2024-08-28,0.44014909636073596,601216.SH
125 2024-08-29,0.4549086267823673,002108.SZ
126 2024-08-30,0.4432950051114155,603817.SH
127 2024-09-02,0.44924217940667366,603759.SH
128 2024-09-03,0.45929491873592476,603759.SH
129 2024-09-04,0.4555291720504659,603817.SH
130 2024-09-05,0.45486187663776934,000581.SZ
131 2024-09-06,0.458767486527876,600016.SH
132 2024-09-09,0.42622859802922114,000725.SZ
133 2024-09-10,0.47344109719180894,002239.SZ
134 2024-09-11,0.4602775423090333,600050.SH
135 2024-09-12,0.4581305095531178,603856.SH
136 2024-09-13,0.48196532955068866,002697.SZ
137 2024-09-18,0.48196532955068866,002697.SZ
138 2024-09-19,0.4448663482243619,600210.SH
139 2024-09-20,0.4894480847057984,600293.SH
140 2024-09-23,0.48392216170099567,002239.SZ
141 2024-09-24,0.43493260807615913,603182.SH
142 2024-09-25,0.4279590881784511,601369.SH
143 2024-09-26,0.457735727285402,600526.SH
144 2024-09-27,0.464687096497251,002818.SZ
145 2024-09-30,0.21397500413643297,600081.SH
146 2024-10-08,0.22763716829204592,605333.SH
147 2024-10-09,0.44498910797127905,600495.SH
148 2024-10-10,0.4635725472634731,600251.SH
149 2024-10-11,0.46083590103602623,000967.SZ
150 2024-10-14,0.47310129519500743,600251.SH
151 2024-10-15,0.4432757845922387,601727.SH
152 2024-10-16,0.49142274407028197,002267.SZ
153 2024-10-17,0.4969633584025033,600032.SH
154 2024-10-18,0.49050625194789943,002267.SZ
155 2024-10-21,0.4839204785725789,000709.SZ
156 2024-10-22,0.47843667953797847,002237.SZ
157 2024-10-23,0.4785259848937853,601577.SH
158 2024-10-24,0.48592022305116356,601577.SH
159 2024-10-25,0.4799806860019888,600820.SH
160 2024-10-28,0.4616976069244002,002135.SZ
161 2024-10-29,0.48907368099203313,600330.SH
162 2024-10-30,0.47786058189302977,600516.SH
163 2024-10-31,0.46572447890972274,600969.SH
164 2024-11-01,0.46269570933167936,600261.SH
165 2024-11-04,0.4819410916205754,600249.SH
166 2024-11-05,0.4866489103957071,603167.SH
167 2024-11-06,0.485704669384595,600423.SH
168 2024-11-07,0.4906974157098039,600249.SH
169 2024-11-08,0.47099968069944914,002328.SZ
170 2024-11-11,0.47924756292999116,600103.SH
171 2024-11-12,0.4864515748363791,600103.SH
172 2024-11-13,0.4682636704183581,600419.SH
173 2024-11-14,0.47314641913653194,000533.SZ
174 2024-11-15,0.48200633444426155,600103.SH
175 2024-11-18,0.4734312305146418,600300.SH
176 2024-11-19,0.481591587838489,601113.SH
177 2024-11-20,0.485704669384595,600493.SH
178 2024-11-21,0.48103812387602335,600515.SH
179 2024-11-22,0.48907368099203313,600284.SH
180 2024-11-25,0.49351892138415066,603111.SH
181 2024-11-26,0.49311491712360267,601187.SH
182 2024-11-27,0.49823564043257407,601187.SH
183 2024-11-28,0.49478522113574214,002091.SZ
184 2024-11-29,0.4744243783704977,002390.SZ
185 2024-12-02,0.47703304354620096,002390.SZ
186 2024-12-03,0.48674050526244056,002566.SZ
187 2024-12-04,0.4959275225246577,002753.SZ
188 2024-12-05,0.48839819807517926,002753.SZ
189 2024-12-06,0.4870618149295468,603128.SH
190 2024-12-09,0.49249410351771356,600425.SH
191 2024-12-10,0.4959275225246577,002772.SZ
192 2024-12-11,0.48502517567793635,600035.SH
193 2024-12-12,0.48907368099203313,600035.SH
194 2024-12-13,0.4901095168698787,600035.SH
195 2024-12-16,0.485704669384595,603577.SH
196 2024-12-17,0.42271811250681296,000533.SZ
197 2024-12-18,0.45253512664134277,000026.SZ
198 2024-12-19,0.4668403018194925,601000.SH
199 2024-12-20,0.47643624443298405,000026.SZ
200 2024-12-23,0.4405301681206755,600305.SH
201 2024-12-24,0.44052806294048863,000883.SZ
202 2024-12-25,0.448810439825342,000589.SZ
203 2024-12-26,0.47338854249265216,600582.SH
204 2024-12-27,0.4495112908067394,000830.SZ
205 2024-12-30,0.463164320301011,601006.SH
206 2024-12-31,0.4773178744459276,600004.SH
207 2025-01-02,0.47702414327290926,600572.SH
208 2025-01-03,0.4474782636368997,601163.SH
209 2025-01-06,0.4333771722554744,600821.SH
210 2025-01-07,0.44770531932040636,600004.SH
211 2025-01-08,0.47082512104142743,600116.SH
212 2025-01-09,0.4541624257750102,600004.SH
213 2025-01-10,0.4505883376349118,600905.SH
214 2025-01-13,0.4505883376349118,600905.SH
215 2025-01-14,0.4629277092655053,000088.SZ
216 2025-01-15,0.4849467618585074,601222.SH
217 2025-01-16,0.4904213884330184,600273.SH
218 2025-01-17,0.49427977331421474,002267.SZ
219 2025-01-20,0.4921881826907514,000088.SZ
220 2025-01-21,0.49980885843191936,002233.SZ
221 2025-01-22,0.49226959038014517,603817.SH
222 2025-01-23,0.4845623679956794,600731.SH
223 2025-01-24,0.4894480847057984,002443.SZ
224 2025-01-27,0.480828304199342,600475.SH
225 2025-02-05,0.49381191740852215,600475.SH
226 2025-02-06,0.48051107405830695,600219.SH
227 2025-02-07,0.4912337545022996,002365.SZ
228 2025-02-10,0.4891404807504195,002606.SZ
229 2025-02-11,0.48918316877240914,002454.SZ
230 2025-02-12,0.48841224882795287,605138.SH
231 2025-02-13,0.49427977331421474,603022.SH
232 2025-02-14,0.4961990011809636,002454.SZ
233 2025-02-17,0.49980885843191936,603172.SH
234 2025-02-18,0.4859761480409009,600526.SH
235 2025-02-19,0.47951904158629705,600526.SH
236 2025-02-20,0.4865640468808261,002972.SZ
237 2025-02-21,0.4878226668596109,002972.SZ
238 2025-02-24,0.49427977331421474,002972.SZ
239 2025-02-25,0.4959275225246577,002972.SZ
240 2025-02-26,0.49559758720502334,000850.SZ
241 2025-02-27,0.4962265008275085,600969.SH
242 2025-02-28,0.45857905703475876,000931.SZ
243 2025-03-03,0.475929514181944,600704.SH
244 2025-03-04,0.4855766924008325,603176.SH
245 2025-03-05,0.48947041607005387,600749.SH
246 2025-03-06,0.48947041607005387,600749.SH
247 2025-03-07,0.48852295338866114,002948.SZ
248 2025-03-10,0.48013938012027435,600749.SH
249 2025-03-11,0.4697442502219659,603916.SH
250 2025-03-12,0.46825376442882116,600969.SH
251 2025-03-13,0.485704669384595,601311.SH
252 2025-03-14,0.47290759442027386,600784.SH
253 2025-03-17,0.48236846890282975,002204.SZ
254 2025-03-18,0.4753161955607809,600784.SH
255 2025-03-19,0.4898609892396381,002627.SZ
256 2025-03-20,0.4898609892396381,002627.SZ
257 2025-03-21,0.47660773064492085,000589.SZ
258 2025-03-24,0.4753538392607698,000589.SZ
259 2025-03-25,0.4628733846203298,000589.SZ
260 2025-03-26,0.45948501496487415,603367.SH
261 2025-03-27,0.47591884404751766,600017.SH
262 2025-03-28,0.4765671044505851,600925.SH

File diff suppressed because it is too large Load Diff

738
code/train/utils/factor.py Normal file
View File

@@ -0,0 +1,738 @@
import numpy as np
import talib
import pandas as pd
def get_technical_factor(df):
    """Append per-stock technical factors to a daily-bar DataFrame.

    Rows are sorted by ``ts_code`` and ``trade_date``; every rolling or lagged
    quantity is then computed inside a single stock's own history.

    Args:
        df: DataFrame with the columns ``ts_code``, ``trade_date``, ``open``,
            ``high``, ``low``, ``close``, ``vol``, ``pct_chg``,
            ``turnover_rate`` and ``volume_ratio``.

    Returns:
        A sorted copy of ``df`` with the factor columns appended.
    """
    # Sort by stock and date so each window walks forward through one stock.
    df = df.sort_values(by=['ts_code', 'trade_date'])

    def stock_col(column):
        # Regroup on every call so columns created further down are visible.
        return df.groupby('ts_code', group_keys=False)[column]

    def per_stock(func):
        # Run ``func`` on each stock's rows and stitch the Series together.
        # GroupBy.apply is avoided: with a single stock it returns a wide
        # frame instead of a row-aligned Series.
        parts = [func(rows) for _, rows in df.groupby('ts_code', sort=False)]
        return pd.concat(parts) if parts else pd.Series(index=df.index, dtype=float)

    def rolling_stat(column, window, stat, lag=0):
        # Per-stock rolling statistic; ``lag=1`` keeps today out of the window.
        return stock_col(column).transform(
            lambda s: getattr(s.shift(lag).rolling(window), stat)())

    def talib_on(column, indicator, period):
        # Per-stock single-input talib indicator, re-attached to the row index.
        return stock_col(column).transform(
            lambda s: pd.Series(indicator(s.values, timeperiod=period), index=s.index))

    df['return_skew'] = rolling_stat('pct_chg', 5, 'skew')
    df['return_kurtosis'] = rolling_stat('pct_chg', 5, 'kurt')

    # Factor 1: short-term volume change rate (2-day mean vs 5-day mean).
    df['volume_change_rate'] = (
        rolling_stat('vol', 2, 'mean') / rolling_stat('vol', 5, 'mean') - 1
    )

    # Factor 2: volume breakout. The reference is the max of the PREVIOUS five
    # sessions; a window that contains today's value can never be exceeded.
    df['cat_volume_breakout'] = df['vol'] > rolling_stat('vol', 5, 'max', lag=1)

    # Factor 3: turnover deviation from its 3-day mean, in 3-day std units.
    mean_turnover = rolling_stat('turnover_rate', 3, 'mean')
    std_turnover = rolling_stat('turnover_rate', 3, 'std')
    df['turnover_deviation'] = (df['turnover_rate'] - mean_turnover) / std_turnover

    # Factor 4: turnover spike. Measured against the previous three sessions:
    # inside a 3-sample window the z-score is bounded by 2/sqrt(3) < 2, so a
    # window that includes today can never fire.
    prior_mean_turnover = rolling_stat('turnover_rate', 3, 'mean', lag=1)
    prior_std_turnover = rolling_stat('turnover_rate', 3, 'std', lag=1)
    df['cat_turnover_spike'] = (
        df['turnover_rate'] > prior_mean_turnover + 2 * prior_std_turnover
    )

    # Factor 5: 3-day mean of the volume ratio.
    df['avg_volume_ratio'] = rolling_stat('volume_ratio', 3, 'mean')

    # Factor 6: volume-ratio breakout against the previous five sessions.
    df['cat_volume_ratio_breakout'] = (
        df['volume_ratio'] > rolling_stat('volume_ratio', 5, 'max', lag=1)
    )

    # Factor 7: combined volume / turnover momentum.
    alpha = 0.5
    df['momentum_factor'] = df['volume_change_rate'] + alpha * df['turnover_deviation']

    # Factor 8: volume-price resonance.
    df['price_change_rate'] = stock_col('close').pct_change()
    df['resonance_factor'] = df['volume_ratio'] * df['price_change_rate']

    df['log_close'] = np.log(df['close'])
    df['vol_spike'] = rolling_stat('vol', 20, 'mean')
    df['cat_vol_spike'] = df['vol'] > 2 * df['vol_spike']
    # Computed per stock so the window never mixes two different tickers.
    df['vol_std_5'] = stock_col('vol').transform(
        lambda s: s.pct_change().rolling(5).std())

    # Upper and lower candle shadows relative to the close.
    df['up'] = (df['high'] - df[['close', 'open']].max(axis=1)) / df['close']
    df['down'] = (df[['close', 'open']].min(axis=1) - df['low']) / df['close']

    # ATR
    for period in (14, 6):
        df[f'atr_{period}'] = per_stock(
            lambda g, p=period: pd.Series(
                talib.ATR(g['high'].values, g['low'].values, g['close'].values,
                          timeperiod=p),
                index=g.index))

    # OBV and its 6-day moving average
    df['obv'] = per_stock(
        lambda g: pd.Series(talib.OBV(g['close'].values, g['vol'].values),
                            index=g.index))
    df['maobv_6'] = talib_on('obv', talib.SMA, 6)
    df['obv-maobv_6'] = df['obv'] - df['maobv_6']

    # RSI
    for period in (3, 6, 9):
        df[f'rsi_{period}'] = talib_on('close', talib.RSI, period)

    # Trailing returns over 5, 10 and 20 sessions.
    for horizon in (5, 10, 20):
        df[f'return_{horizon}'] = stock_col('close').transform(
            lambda s, h=horizon: s / s.shift(h) - 1)

    # Rolling standard deviation of daily returns.
    for window in (5, 15, 25, 90):
        df[f'std_return_{window}'] = stock_col('close').transform(
            lambda s, w=window: s.pct_change().rolling(window=w).std())
    # Same 90-day volatility, measured ten sessions earlier.
    df['std_return_90_2'] = stock_col('close').transform(
        lambda s: s.shift(10).pct_change().rolling(window=90).std())

    # Volatility ratios
    df['std_return_5 / std_return_90'] = df['std_return_5'] / df['std_return_90']
    df['std_return_5 / std_return_25'] = df['std_return_5'] / df['std_return_25']
    # Change in long-horizon volatility
    df['std_return_90 - std_return_90_2'] = df['std_return_90'] - df['std_return_90_2']
    return df
def get_act_factor(df, cat=True):
    """Add EMA-slope ("act") factors and their cross-sectional ranks.

    For each EMA period (5, 13, 20, 60) the one-day relative change of the
    EMA is passed through ``arctan`` and scaled, giving ``act_factor1`` to
    ``act_factor4``. The helper ``_ema_*`` columns stay in the result.

    Args:
        df: DataFrame with ``ts_code``, ``trade_date`` and ``close``.
        cat: When true, also add the boolean ``cat_af1`` .. ``cat_af4`` flags.

    Returns:
        A copy of ``df`` sorted by stock and date with the new columns.
    """
    frame = df.sort_values(by=['ts_code', 'trade_date'])
    by_code = frame.groupby('ts_code', group_keys=False)

    # (EMA period, divisor applied to the arctan slope)
    ema_specs = ((5, 50), (13, 40), (20, 21), (60, 10))

    # EMA of the close, per stock.
    for period, _ in ema_specs:
        frame[f'_ema_{period}'] = by_code['close'].apply(
            lambda closes, p=period: pd.Series(
                talib.EMA(closes.values, timeperiod=p), index=closes.index)
        )

    # Slope of each EMA: arctan of the daily percentage change, scaled.
    for number, (period, divisor) in enumerate(ema_specs, start=1):
        frame[f'act_factor{number}'] = by_code[f'_ema_{period}'].apply(
            lambda ema, d=divisor: np.arctan((ema / ema.shift(1) - 1) * 100) * 57.3 / d
        )

    if cat:
        frame['cat_af1'] = frame['act_factor1'] > 0
        # Each later flag compares a factor with the one just before it.
        for number in (2, 3, 4):
            frame[f'cat_af{number}'] = (
                frame[f'act_factor{number}'] > frame[f'act_factor{number - 1}']
            )

    # Composite factors built from the four slopes.
    frame['act_factor5'] = (frame['act_factor1'] + frame['act_factor2']
                            + frame['act_factor3'] + frame['act_factor4'])
    norm = np.sqrt(frame['act_factor1'] ** 2 + frame['act_factor2'] ** 2)
    frame['act_factor6'] = (frame['act_factor1'] - frame['act_factor2']) / norm

    # Percentile rank of the first three factors within each trading day.
    for number in (1, 2, 3):
        source = f'act_factor{number}'
        frame[f'rank_{source}'] = frame.groupby('trade_date', group_keys=False)[source].rank(
            ascending=False, pct=True)
    return frame
def get_money_flow_factor(df):
    """Add money-flow ratio factors to ``df`` in place and return it.

    Each ratio divides a buy volume, or a buy-minus-sell volume, by
    ``net_mf_vol``. ``log(circ_mv)`` is added as well. Field names follow the
    tushare money-flow data description.

    Args:
        df: DataFrame with ``buy_lg_vol``, ``buy_elg_vol``, ``buy_sm_vol``,
            ``sell_lg_vol``, ``sell_elg_vol``, ``net_mf_vol`` and ``circ_mv``.

    Returns:
        The same DataFrame object, with six columns added.
    """
    net_volume = df['net_mf_vol']
    numerators = {
        'active_buy_volume_large': df['buy_lg_vol'],
        'active_buy_volume_big': df['buy_elg_vol'],
        'active_buy_volume_small': df['buy_sm_vol'],
        'buy_lg_vol_minus_sell_lg_vol': df['buy_lg_vol'] - df['sell_lg_vol'],
        'buy_elg_vol_minus_sell_elg_vol': df['buy_elg_vol'] - df['sell_elg_vol'],
    }
    for column, numerator in numerators.items():
        df[column] = numerator / net_volume

    df['log(circ_mv)'] = np.log(df['circ_mv'])
    return df
def get_alpha_factor(df):
    """Add the alpha-style factors (improved alpha 22, alpha 003/007/013).

    Rolling covariance, rolling standard deviation and differencing are
    evaluated inside each stock so a window never spans two tickers.

    Args:
        df: DataFrame with ``ts_code``, ``trade_date``, ``open``, ``high``,
            ``low``, ``close`` and a volume column. The volume column is
            ``vol``; ``volume`` is used only when ``vol`` is absent.

    Returns:
        A copy of ``df`` sorted by stock and date with the columns ``cov``,
        ``delta_cov``, ``_rank_stddev``, ``alpha_22_improved``, ``alpha_003``,
        ``alpha_007`` and ``alpha_013`` added.
    """
    df = df.sort_values(by=['ts_code', 'trade_date'])
    # Prefer ``vol``; ``volume`` is accepted as a fallback.
    vol_col = 'vol' if 'vol' in df.columns else 'volume'

    def stock_col(column):
        # Regroup on every call so freshly added columns are visible.
        return df.groupby('ts_code', group_keys=False)[column]

    def per_stock(func):
        # Run ``func`` on each stock's rows and stitch the Series together.
        # GroupBy.apply is avoided: with a single stock it returns a wide
        # frame instead of a row-aligned Series.
        parts = [func(rows) for _, rows in df.groupby('ts_code', sort=False)]
        return pd.concat(parts) if parts else pd.Series(index=df.index, dtype=float)

    # Improved alpha 22
    window_high_volume = 5
    window_close_stddev = 20
    period_delta = 5
    df['cov'] = per_stock(
        lambda g: g['high'].rolling(window_high_volume).cov(g[vol_col]))
    df['delta_cov'] = stock_col('cov').transform(lambda s: s.diff(period_delta))
    close_std = stock_col('close').transform(
        lambda s: s.rolling(window_close_stddev).std())
    # NOTE(review): the percentile rank spans every row of the frame, as in
    # the original code; confirm whether a per-date rank was intended.
    df['_rank_stddev'] = close_std.rank(pct=True)
    df['alpha_22_improved'] = -1 * df['delta_cov'] * df['_rank_stddev']

    # alpha_003: (close - open) / (high - low), 0 when the bar has no range.
    df['alpha_003'] = np.where(df['high'] != df['low'],
                               (df['close'] - df['open']) / (df['high'] - df['low']),
                               0)

    # alpha_007: 5-day correlation of close and volume, ranked per trade_date.
    df['alpha_007'] = per_stock(lambda g: g['close'].rolling(5).corr(g[vol_col]))
    df['alpha_007'] = df.groupby('trade_date', group_keys=False)['alpha_007'].rank(
        ascending=True, pct=True)

    # alpha_013: 5-day close sum minus 20-day close sum, ranked per trade_date.
    df['alpha_013'] = stock_col('close').transform(
        lambda s: s.rolling(5).sum() - s.rolling(20).sum())
    df['alpha_013'] = df.groupby('trade_date', group_keys=False)['alpha_013'].rank(
        ascending=True, pct=True)
    return df
def get_limit_factor(df):
    """Add limit-up / limit-down flags, 10-day counts and the limit-up streak.

    Args:
        df: DataFrame with ``ts_code``, ``trade_date``, ``close``,
            ``up_limit`` and ``down_limit``.

    Returns:
        A copy of ``df`` sorted by stock and date with ``cat_up_limit``,
        ``cat_down_limit``, ``up_limit_count_10d``, ``down_limit_count_10d``
        and ``consecutive_up_limit`` added.
    """
    # Sort by stock and date.
    df = df.sort_values(by=['ts_code', 'trade_date'])

    def stock_col(column):
        # Regroup on every call so freshly added columns are visible.
        return df.groupby('ts_code', group_keys=False)[column]

    # 1. Did the close hit the limit today? (1 = yes, 0 = no)
    df['cat_up_limit'] = (df['close'] == df['up_limit']).astype(int)
    df['cat_down_limit'] = (df['close'] == df['down_limit']).astype(int)

    # 2. Number of limit hits over the last 10 sessions, per stock.
    def hits_last_10(flag_column):
        return stock_col(flag_column).transform(
            lambda s: s.rolling(window=10, min_periods=1).sum())

    df['up_limit_count_10d'] = hits_last_10('cat_up_limit')
    df['down_limit_count_10d'] = hits_last_10('cat_down_limit')

    # 3. Length of the current run of consecutive limit days.
    def streak_length(flags):
        """Return, per row, the 1-based position inside the current run of 1s (0 elsewhere)."""
        run_id = (flags != flags.shift()).cumsum()
        return flags * (flags.groupby(run_id).cumcount() + 1)

    # transform keeps the original row index. Resetting that index, as the
    # previous code did, wrote values to the wrong rows after sorting.
    df['consecutive_up_limit'] = stock_col('cat_up_limit').transform(streak_length)
    # A consecutive limit-down streak is intentionally not produced here.
    return df
def get_cyp_perf_factor(df):
    """Add chip-distribution (holding cost) factors.

    Args:
        df: DataFrame with ``ts_code``, ``trade_date``, ``close``, ``high``,
            ``low``, ``vol``, ``turnover_rate``, ``volume_ratio``,
            ``circ_mv``, ``winner_rate``, ``weight_avg``, ``his_high``,
            ``his_low`` and the cost percentiles ``cost_5pct``,
            ``cost_15pct``, ``cost_50pct``, ``cost_85pct``, ``cost_95pct``.

    Returns:
        A copy of ``df`` sorted by stock and date with the factor columns
        added. An existing ``atr_14`` column is overwritten with a 14-day
        simple mean of the true range.
    """
    # Pre-processing: sort by stock code and date.
    df = df.sort_values(by=['ts_code', 'trade_date'])

    def stock_col(column):
        # Regroup on every call so freshly added columns are visible.
        return df.groupby('ts_code', group_keys=False)[column]

    def per_stock(func):
        # Run ``func`` on each stock's rows and stitch the Series together.
        # GroupBy.apply is avoided: with a single stock it returns a wide
        # frame instead of a row-aligned Series.
        parts = [func(rows) for _, rows in df.groupby('ts_code', sort=False)]
        return pd.concat(parts) if parts else pd.Series(index=df.index, dtype=float)

    df['ctrl_strength'] = (df['cost_85pct'] - df['cost_15pct']) / (df['his_high'] - df['his_low'])
    df['low_cost_dev'] = (df['close'] - df['cost_5pct']) / (df['cost_50pct'] - df['cost_5pct'])
    df['asymmetry'] = (df['cost_95pct'] - df['cost_50pct']) / (df['cost_50pct'] - df['cost_5pct'])
    df['lock_factor'] = df['turnover_rate'] * (
        1 - (df['cost_95pct'] - df['cost_5pct']) / (df['his_high'] - df['his_low']))
    df['vol_break'] = np.where((df['close'] > df['cost_85pct']) & (df['volume_ratio'] > 2), 1, 0)
    df['weight_roc5'] = stock_col('weight_avg').transform(lambda s: s.pct_change(5))

    def rolling_corr(group):
        # 10-day correlation between daily changes of close and average cost.
        roc_close = group['close'].pct_change()
        roc_weight = group['weight_avg'].pct_change()
        return roc_close.rolling(10).corr(roc_weight)

    df['price_cost_divergence'] = per_stock(rolling_corr)

    def calc_atr(group):
        # True range = max(high-low, |high-prev close|, |low-prev close|),
        # then averaged over 14 sessions.
        high, low, close = group['high'], group['low'], group['close']
        tr = np.maximum(high - low,
                        np.maximum(abs(high - close.shift()),
                                   abs(low - close.shift())))
        return tr.rolling(14).mean()

    df['atr_14'] = per_stock(calc_atr)
    df['cost_atr_adj'] = (df['cost_95pct'] - df['cost_5pct']) / df['atr_14']

    # 12. Small-cap chip concentration
    df['smallcap_concentration'] = (1 / df['circ_mv']) * (df['cost_85pct'] - df['cost_15pct'])

    # 16. Chip stability index (20-day std of the average cost over its mean)
    df['weight_std20'] = stock_col('weight_avg').transform(lambda s: s.rolling(20).std())
    df['cost_stability'] = df['weight_std20'] / stock_col('weight_avg').transform(
        lambda s: s.rolling(20).mean())

    # 17. Days in the last five with the close above the 95% cost level
    df['high_cost_break_days'] = per_stock(
        lambda g: g['close'].gt(g['cost_95pct']).astype(float).rolling(5).sum())

    # 18. "Golden chip resonance" (composite event)
    df['cat_golden_resonance'] = ((df['close'] > df['weight_avg']) &
                                  (df['volume_ratio'] > 1.5) &
                                  (df['winner_rate'] > 0.7))

    # 20. Chip / liquidity risk
    df['liquidity_risk'] = (df['cost_95pct'] - df['cost_5pct']) * (
        1 / stock_col('vol').transform(lambda s: s.rolling(10).mean()))

    # weight_std20 is only an intermediate value.
    df.drop(columns=['weight_std20'], inplace=True, errors='ignore')
    return df
def get_mv_factors(df):
    """Add market-cap-normalised liquidity factors.

    Args:
        df (pd.DataFrame): Frame with ``ts_code``, ``trade_date``,
            ``turnover_rate``, ``circ_mv``, ``volume_ratio`` and ``vol``.

    Returns:
        pd.DataFrame: A copy sorted by ``ts_code`` and ``trade_date`` with the
        new factor columns. No combined factor is produced.
    """
    # Sort by ts_code and trade_date.
    df = df.sort_values(by=['ts_code', 'trade_date'])
    # Group by ts_code for the time-series factors.
    grouped = df.groupby('ts_code', group_keys=False)

    # 1. Turnover relative to float market cap
    df['mv_turnover_ratio'] = df['turnover_rate'] / df['circ_mv']
    # 2. Volume adjusted by float market cap
    df['mv_adjusted_volume'] = df['vol'] / df['circ_mv']
    # 3. Market-cap weighted turnover
    df['mv_weighted_turnover'] = df['turnover_rate'] * (1 / df['circ_mv'])
    # 4. "Non-linear" market-cap volume
    # NOTE(review): this is the same formula as mv_adjusted_volume; confirm
    # whether a non-linear transform was intended.
    df['nonlinear_mv_volume'] = df['vol'] / df['circ_mv']
    # 5. Volume ratio relative to float market cap
    df['mv_volume_ratio'] = df['volume_ratio'] / df['circ_mv']
    # 6. Market-cap momentum
    df['mv_momentum'] = df['turnover_rate'] * df['volume_ratio'] / df['circ_mv']

    # 7. Market-cap volatility: 20-day std of turnover, per stock.
    df['turnover_std'] = grouped['turnover_rate'].transform(
        lambda s: s.rolling(window=20).std())
    # Row-wise ratio. No grouping is needed, and no index reset: resetting a
    # row-aligned result wrote values to the wrong rows after sorting.
    df['mv_volatility'] = df['turnover_std'] / df['circ_mv']

    # 8. Market-cap growth: 20-day volume change, per stock. pct_change on a
    # groupby already returns a row-aligned Series.
    df['volume_growth'] = grouped['vol'].pct_change(periods=20)
    df['mv_growth'] = df['volume_growth'] / df['circ_mv']
    return df
import numpy as np
import talib
def get_rolling_factor(df):
    """Compute every factor that needs a stock's price/volume history.

    All rolling, lagged and talib quantities are evaluated per ``ts_code``
    after sorting by stock and date, so no window mixes two tickers.

    Args:
        df: DataFrame with ``ts_code``, ``trade_date``, ``open``, ``high``,
            ``low``, ``close``, ``vol``, ``pct_chg``, ``turnover_rate``,
            ``volume_ratio``, ``circ_mv``, ``up_limit``, ``down_limit``,
            ``weight_avg`` and the cost percentiles ``cost_5pct``,
            ``cost_15pct``, ``cost_85pct``, ``cost_95pct``.

    Returns:
        tuple: ``(df, new_columns)``, where ``df`` is the sorted copy with the
        factor columns and ``new_columns`` lists the column names that were
        not present in the input.
    """
    old_columns = df.columns.tolist()[:]
    # Sort by stock and date.
    df = df.sort_values(by=['ts_code', 'trade_date'])

    def stock_col(column):
        # Regroup on every call so columns created further down are visible.
        return df.groupby('ts_code', group_keys=False)[column]

    def per_stock(func):
        # Run ``func`` on each stock's rows and stitch the Series together.
        # GroupBy.apply is avoided: with a single stock it returns a wide
        # frame instead of a row-aligned Series.
        parts = [func(rows) for _, rows in df.groupby('ts_code', sort=False)]
        return pd.concat(parts) if parts else pd.Series(index=df.index, dtype=float)

    def rolling_stat(column, window, stat, lag=0):
        # Per-stock rolling statistic; ``lag=1`` keeps today out of the window.
        return stock_col(column).transform(
            lambda s: getattr(s.shift(lag).rolling(window), stat)())

    def talib_on(column, indicator, period):
        # Per-stock single-input talib indicator, re-attached to the row index.
        return stock_col(column).transform(
            lambda s: pd.Series(indicator(s.values, timeperiod=period), index=s.index))

    # NOTE(review): this uses the NEXT session's open, i.e. future data;
    # confirm it is used as a label and not as a model input. The shift is
    # done per stock so a stock's last row does not read another stock's open.
    df["gap_next_open"] = (stock_col('open').shift(-1) - df["close"]) / df["close"]

    df['return_skew'] = rolling_stat('pct_chg', 5, 'skew')
    df['return_kurtosis'] = rolling_stat('pct_chg', 5, 'kurt')

    # Factor 1: short-term volume change rate (2-day mean vs 10-day mean).
    df['volume_change_rate'] = (
        rolling_stat('vol', 2, 'mean') / rolling_stat('vol', 10, 'mean') - 1
    )

    # Factor 2: volume breakout. The reference is the max of the PREVIOUS five
    # sessions; a window that contains today's value can never be exceeded.
    df['cat_volume_breakout'] = df['vol'] > rolling_stat('vol', 5, 'max', lag=1)

    # Factor 3: turnover deviation from its 3-day mean, in 3-day std units.
    mean_turnover = rolling_stat('turnover_rate', 3, 'mean')
    std_turnover = rolling_stat('turnover_rate', 3, 'std')
    df['turnover_deviation'] = (df['turnover_rate'] - mean_turnover) / std_turnover

    # Factor 4: turnover spike. Measured against the previous three sessions:
    # inside a 3-sample window the z-score is bounded by 2/sqrt(3) < 2, so a
    # window that includes today can never fire.
    prior_mean_turnover = rolling_stat('turnover_rate', 3, 'mean', lag=1)
    prior_std_turnover = rolling_stat('turnover_rate', 3, 'std', lag=1)
    df['cat_turnover_spike'] = (
        df['turnover_rate'] > prior_mean_turnover + 2 * prior_std_turnover
    )

    # Factor 5: 3-day mean of the volume ratio.
    df['avg_volume_ratio'] = rolling_stat('volume_ratio', 3, 'mean')

    # Factor 6: volume-ratio breakout against the previous five sessions.
    df['cat_volume_ratio_breakout'] = (
        df['volume_ratio'] > rolling_stat('volume_ratio', 5, 'max', lag=1)
    )

    df['vol_spike'] = rolling_stat('vol', 20, 'mean')
    # Computed per stock so the window never mixes two different tickers.
    df['vol_std_5'] = stock_col('vol').transform(
        lambda s: s.pct_change().rolling(5).std())

    # ATR
    for period in (14, 6):
        df[f'atr_{period}'] = per_stock(
            lambda g, p=period: pd.Series(
                talib.ATR(g['high'].values, g['low'].values, g['close'].values,
                          timeperiod=p),
                index=g.index))

    # OBV and its 6-day moving average
    df['obv'] = per_stock(
        lambda g: pd.Series(talib.OBV(g['close'].values, g['vol'].values),
                            index=g.index))
    df['maobv_6'] = talib_on('obv', talib.SMA, 6)

    # RSI
    for period in (3, 6, 9):
        df[f'rsi_{period}'] = talib_on('close', talib.RSI, period)

    # Trailing returns over 5, 10 and 20 sessions.
    for horizon in (5, 10, 20):
        df[f'return_{horizon}'] = stock_col('close').transform(
            lambda s, h=horizon: s / s.shift(h) - 1)

    # Rolling standard deviation of daily returns.
    for window in (5, 15, 25, 90):
        df[f'std_return_{window}'] = stock_col('close').transform(
            lambda s, w=window: s.pct_change().rolling(window=w).std())
    # Same 90-day volatility, measured ten sessions earlier.
    df['std_return_90_2'] = stock_col('close').transform(
        lambda s: s.shift(10).pct_change().rolling(window=90).std())

    # EMA of the close and the scaled arctan slope of each EMA.
    ema_specs = ((5, 50), (13, 40), (20, 21), (60, 10))
    for period, _ in ema_specs:
        df[f'_ema_{period}'] = talib_on('close', talib.EMA, period)
    for number, (period, divisor) in enumerate(ema_specs, start=1):
        df[f'act_factor{number}'] = stock_col(f'_ema_{period}').transform(
            lambda s, d=divisor: np.arctan((s / s.shift(1) - 1) * 100) * 57.3 / d)

    # Cross-sectional percentile rank per trade_date.
    for number in (1, 2, 3):
        source = f'act_factor{number}'
        df[f'rank_{source}'] = df.groupby('trade_date', group_keys=False)[source].rank(
            ascending=False, pct=True)

    df['log(circ_mv)'] = np.log(df['circ_mv'])

    # Improved alpha 22: change in the high/volume covariance, scaled by the
    # rank of close volatility. Windows are evaluated per stock.
    window_high_volume = 5
    window_close_stddev = 20
    period_delta = 5
    df['cov'] = per_stock(
        lambda g: g['high'].rolling(window_high_volume).cov(g['vol']))
    df['delta_cov'] = stock_col('cov').transform(lambda s: s.diff(period_delta))
    close_std = stock_col('close').transform(
        lambda s: s.rolling(window_close_stddev).std())
    # NOTE(review): the percentile rank spans every row of the frame, as in
    # the original code; confirm whether a per-date rank was intended.
    df['_rank_stddev'] = close_std.rank(pct=True)
    df['alpha_22_improved'] = -1 * df['delta_cov'] * df['_rank_stddev']

    # alpha_003: (close - open) / (high - low), 0 when the bar has no range.
    df['alpha_003'] = np.where(df['high'] != df['low'],
                               (df['close'] - df['open']) / (df['high'] - df['low']),
                               0)

    # alpha_007: 5-day correlation of close and volume, ranked per trade_date.
    df['alpha_007'] = per_stock(lambda g: g['close'].rolling(5).corr(g['vol']))
    df['alpha_007'] = df.groupby('trade_date', group_keys=False)['alpha_007'].rank(
        ascending=True, pct=True)

    # alpha_013: 5-day close sum minus 20-day close sum, ranked per trade_date.
    df['alpha_013'] = stock_col('close').transform(
        lambda s: s.rolling(5).sum() - s.rolling(20).sum())
    df['alpha_013'] = df.groupby('trade_date', group_keys=False)['alpha_013'].rank(
        ascending=True, pct=True)

    # Limit flags (True when the close equals the daily limit price).
    df['cat_up_limit'] = (df['close'] == df['up_limit'])
    df['cat_down_limit'] = (df['close'] == df['down_limit'])

    def hits_last_10(flag_column):
        # Number of limit hits in the last 10 sessions, per stock.
        return df[flag_column].astype(float).groupby(df['ts_code']).transform(
            lambda s: s.rolling(window=10, min_periods=1).sum())

    df['up_limit_count_10d'] = hits_last_10('cat_up_limit')
    df['down_limit_count_10d'] = hits_last_10('cat_down_limit')

    # 3. Length of the current run of consecutive limit-up days.
    def streak_length(flags):
        """Return, per row, the 1-based position inside the current run of 1s (0 elsewhere)."""
        run_id = (flags != flags.shift()).cumsum()
        return flags * (flags.groupby(run_id).cumcount() + 1)

    # transform keeps the original row index. Resetting that index, as the
    # previous code did, wrote values to the wrong rows after sorting.
    df['consecutive_up_limit'] = (
        df['cat_up_limit'].astype(int).groupby(df['ts_code']).transform(streak_length)
    )

    df['vol_break'] = np.where((df['close'] > df['cost_85pct']) & (df['volume_ratio'] > 2), 1, 0)
    df['weight_roc5'] = stock_col('weight_avg').transform(lambda s: s.pct_change(5))

    def rolling_corr(group):
        # 10-day correlation between daily changes of close and average cost.
        roc_close = group['close'].pct_change()
        roc_weight = group['weight_avg'].pct_change()
        return roc_close.rolling(10).corr(roc_weight)

    df['price_cost_divergence'] = per_stock(rolling_corr)
    df['smallcap_concentration'] = (1 / df['circ_mv']) * (df['cost_85pct'] - df['cost_15pct'])

    # 16. Chip stability index (20-day std of the average cost over its mean)
    df['weight_std20'] = stock_col('weight_avg').transform(lambda s: s.rolling(20).std())
    df['cost_stability'] = df['weight_std20'] / stock_col('weight_avg').transform(
        lambda s: s.rolling(20).mean())

    # 17. Days in the last five with the close above the 95% cost level
    df['high_cost_break_days'] = per_stock(
        lambda g: g['close'].gt(g['cost_95pct']).astype(float).rolling(5).sum())

    # 20. Chip / liquidity risk
    df['liquidity_risk'] = (df['cost_95pct'] - df['cost_5pct']) * (
        1 / stock_col('vol').transform(lambda s: s.rolling(10).mean()))

    # 7. Market-cap volatility: 20-day std of turnover over float market cap.
    df['turnover_std'] = rolling_stat('turnover_rate', 20, 'std')
    df['mv_volatility'] = df['turnover_std'] / df['circ_mv']

    # 8. Market-cap growth: 20-day volume change over float market cap.
    df['volume_growth'] = stock_col('vol').pct_change(periods=20)
    df['mv_growth'] = df['volume_growth'] / df['circ_mv']

    # weight_std20 is only an intermediate value.
    df.drop(columns=['weight_std20'], inplace=True, errors='ignore')
    new_columns = [col for col in df.columns.tolist()[:] if col not in old_columns]
    return df, new_columns
def get_simple_factor(df):
    """Derive the row-wise factors that need no price history.

    Every value is computed from columns of the same row, several of which
    (``volume_change_rate``, ``vol_spike``, ``obv``, ``act_factor1`` ...) must
    already exist in ``df``. Columns whose name starts with ``_`` are removed
    before returning.

    Args:
        df: DataFrame holding the raw fields and the pre-computed rolling
            factor columns referenced below.

    Returns:
        tuple: ``(frame, new_columns)``, where ``frame`` is a copy sorted by
        ``ts_code`` and ``trade_date`` and ``new_columns`` lists the column
        names that were not in the input.
    """
    known_columns = df.columns.tolist()[:]
    frame = df.sort_values(by=['ts_code', 'trade_date'])

    # Momentum and price/volume resonance.
    alpha = 0.5
    frame['momentum_factor'] = frame['volume_change_rate'] + alpha * frame['turnover_deviation']
    frame['resonance_factor'] = frame['volume_ratio'] * frame['pct_chg']
    frame['log_close'] = np.log(frame['close'])
    frame['cat_vol_spike'] = frame['vol'] > 2 * frame['vol_spike']

    # Upper and lower candle shadows relative to the close.
    body = frame[['close', 'open']]
    frame['up'] = (frame['high'] - body.max(axis=1)) / frame['close']
    frame['down'] = (body.min(axis=1) - frame['low']) / frame['close']
    frame['obv-maobv_6'] = frame['obv'] - frame['maobv_6']

    # Volatility ratios and the change in long-horizon volatility.
    for long_window in ('90', '25'):
        frame[f'std_return_5 / std_return_{long_window}'] = (
            frame['std_return_5'] / frame[f'std_return_{long_window}']
        )
    frame['std_return_90 - std_return_90_2'] = frame['std_return_90'] - frame['std_return_90_2']

    # Ordering flags between consecutive act factors.
    frame['cat_af1'] = frame['act_factor1'] > 0
    for number in (2, 3, 4):
        frame[f'cat_af{number}'] = frame[f'act_factor{number}'] > frame[f'act_factor{number - 1}']

    # Composite act factors.
    frame['act_factor5'] = (frame['act_factor1'] + frame['act_factor2']
                            + frame['act_factor3'] + frame['act_factor4'])
    norm = np.sqrt(frame['act_factor1'] ** 2 + frame['act_factor2'] ** 2)
    frame['act_factor6'] = (frame['act_factor1'] - frame['act_factor2']) / norm

    # Money-flow ratios, all relative to the net money-flow volume.
    net_volume = frame['net_mf_vol']
    money_flow = (
        ('active_buy_volume_large', frame['buy_lg_vol']),
        ('active_buy_volume_big', frame['buy_elg_vol']),
        ('active_buy_volume_small', frame['buy_sm_vol']),
        ('buy_lg_vol_minus_sell_lg_vol', frame['buy_lg_vol'] - frame['sell_lg_vol']),
        ('buy_elg_vol_minus_sell_elg_vol', frame['buy_elg_vol'] - frame['sell_elg_vol']),
    )
    for column, numerator in money_flow:
        frame[column] = numerator / net_volume
    frame['log(circ_mv)'] = np.log(frame['circ_mv'])

    # Chip-distribution factors.
    history_range = frame['his_high'] - frame['his_low']
    lower_half = frame['cost_50pct'] - frame['cost_5pct']
    frame['ctrl_strength'] = (frame['cost_85pct'] - frame['cost_15pct']) / history_range
    frame['low_cost_dev'] = (frame['close'] - frame['cost_5pct']) / lower_half
    frame['asymmetry'] = (frame['cost_95pct'] - frame['cost_50pct']) / lower_half
    frame['lock_factor'] = frame['turnover_rate'] * (
        1 - (frame['cost_95pct'] - frame['cost_5pct']) / history_range)
    above_85 = frame['close'] > frame['cost_85pct']
    frame['cat_vol_break'] = above_85 & (frame['volume_ratio'] > 2)
    frame['cost_atr_adj'] = (frame['cost_95pct'] - frame['cost_5pct']) / frame['atr_14']

    # 12. Small-cap chip concentration
    frame['smallcap_concentration'] = (1 / frame['circ_mv']) * (
        frame['cost_85pct'] - frame['cost_15pct'])
    frame['cat_golden_resonance'] = ((frame['close'] > frame['weight_avg']) &
                                     (frame['volume_ratio'] > 1.5) &
                                     (frame['winner_rate'] > 0.7))

    # Market-cap normalised liquidity factors.
    frame['mv_turnover_ratio'] = frame['turnover_rate'] / frame['circ_mv']
    frame['mv_adjusted_volume'] = frame['vol'] / frame['circ_mv']
    frame['mv_weighted_turnover'] = frame['turnover_rate'] * (1 / frame['circ_mv'])
    frame['nonlinear_mv_volume'] = frame['vol'] / frame['circ_mv']
    frame['mv_volume_ratio'] = frame['volume_ratio'] / frame['circ_mv']
    frame['mv_momentum'] = frame['turnover_rate'] * frame['volume_ratio'] / frame['circ_mv']

    # Helper columns are marked with a leading underscore; remove them.
    helper_columns = [name for name in frame.columns if name.startswith('_')]
    frame = frame.drop(columns=helper_columns, errors='ignore')

    new_columns = [name for name in frame.columns.tolist() if name not in known_columns]
    return frame, new_columns
def calculate_indicators(df):
    """Compute return, RSI, MACD and sentiment indicators for one index series.

    The frame is sorted by ``trade_date`` and treated as a single time series,
    so the caller is expected to pass the rows of one ``ts_code`` only.

    Args:
        df: DataFrame with ``trade_date``, ``close``, ``pre_close``, ``vol``
            and ``amount``.

    Returns:
        A sorted copy of ``df`` with ``daily_return``, ``RSI``, ``MACD``,
        ``Signal_line``, ``MACD_hist``, ``up_ratio``, ``up_ratio_20d``,
        ``volume_mean``, ``volume_change_rate``, ``volatility``,
        ``amount_mean`` and ``amount_change_rate`` added.
    """
    frame = df.sort_values('trade_date')
    close = frame['close']

    # Daily return in percent, relative to the previous close.
    frame['daily_return'] = (close - frame['pre_close']) / frame['pre_close'] * 100

    # RSI from 14-day simple averages of gains and losses.
    change = close.diff()
    gains = change.where(change > 0, 0)
    losses = -change.where(change < 0, 0)
    relative_strength = gains.rolling(window=14).mean() / losses.rolling(window=14).mean()
    frame['RSI'] = 100 - (100 / (1 + relative_strength))

    # MACD (12/26 EMA difference), its 9-period signal line and the histogram.
    fast_ema = close.ewm(span=12, adjust=False).mean()
    slow_ema = close.ewm(span=26, adjust=False).mean()
    frame['MACD'] = fast_ema - slow_ema
    frame['Signal_line'] = frame['MACD'].ewm(span=9, adjust=False).mean()
    frame['MACD_hist'] = frame['MACD'] - frame['Signal_line']

    # Sentiment 1: share of up days over the last 20 sessions.
    frame['up_ratio'] = frame['daily_return'].map(lambda ret: 1 if ret > 0 else 0)
    frame['up_ratio_20d'] = frame['up_ratio'].rolling(window=20).mean()

    # Sentiment 2: volume relative to its 20-day mean, in percent.
    frame['volume_mean'] = frame['vol'].rolling(window=20).mean()
    frame['volume_change_rate'] = (frame['vol'] - frame['volume_mean']) / frame['volume_mean'] * 100

    # Sentiment 3: 20-day standard deviation of the daily return.
    frame['volatility'] = frame['daily_return'].rolling(window=20).std()

    # Sentiment 4: turnover amount relative to its 20-day mean, in percent.
    frame['amount_mean'] = frame['amount'].rolling(window=20).mean()
    frame['amount_change_rate'] = (frame['amount'] - frame['amount_mean']) / frame['amount_mean'] * 100
    return frame
def generate_index_indicators(h5_filename):
    """Build a wide, one-row-per-date table of indicators for every index.

    Reads the ``index_data`` key of the HDF5 file, runs
    ``calculate_indicators`` on each ``ts_code`` separately and pivots the
    result so each indicator of each index becomes a column named
    ``<ts_code>_<indicator>``.

    Args:
        h5_filename: Path of the HDF5 file; ``trade_date`` is parsed with the
            ``%Y%m%d`` format.

    Returns:
        DataFrame with a ``trade_date`` column plus one column per
        (index, indicator) pair.
    """
    raw = pd.read_hdf(h5_filename, key='index_data')
    raw['trade_date'] = pd.to_datetime(raw['trade_date'], format='%Y%m%d')
    raw = raw.sort_values('trade_date')

    # Indicators are time series, so compute them index by index.
    per_index = [
        calculate_indicators(raw[raw['ts_code'] == code].copy())
        for code in raw['ts_code'].unique()
    ]
    combined = pd.concat(per_index, ignore_index=True)

    # One row per trade_date, one column per (indicator, ts_code) pair.
    indicator_names = ['daily_return', 'RSI', 'MACD', 'Signal_line',
                       'MACD_hist', 'up_ratio_20d', 'volume_change_rate', 'volatility',
                       'amount_change_rate', 'amount_mean']
    wide = combined.pivot_table(
        index='trade_date',
        columns='ts_code',
        values=indicator_names,
        aggfunc='last'
    )
    # Flatten the (indicator, ts_code) column pairs into "<ts_code>_<indicator>".
    wide.columns = [f"{code}_{indicator}" for indicator, code in wide.columns]
    return wide.reset_index()