1388 lines
59 KiB
Plaintext
1388 lines
59 KiB
Plaintext
{
|
||
"cells": [
|
||
{
|
||
"cell_type": "code",
|
||
"id": "79a7758178bafdd3",
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2025-03-02T12:44:04.673079Z",
|
||
"start_time": "2025-03-02T12:44:04.247257Z"
|
||
}
|
||
},
|
||
"source": [
|
||
"# %load_ext autoreload\n",
|
||
"# %autoreload 2\n",
|
||
"\n",
|
||
"import pandas as pd\n",
|
||
"import warnings\n",
|
||
"\n",
|
||
"warnings.filterwarnings(\"ignore\")\n",
|
||
"\n",
|
||
"pd.set_option('display.max_columns', None)\n"
|
||
],
|
||
"outputs": [],
|
||
"execution_count": 1
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"id": "a79cafb06a7e0e43",
|
||
"metadata": {
|
||
"jupyter": {
|
||
"source_hidden": true
|
||
},
|
||
"ExecuteTime": {
|
||
"end_time": "2025-03-02T12:44:43.595370Z",
|
||
"start_time": "2025-03-02T12:44:04.688084Z"
|
||
}
|
||
},
|
||
"source": [
|
||
"from code.utils.utils import read_and_merge_h5_data\n",
|
||
"\n",
|
||
"print('daily data')\n",
|
||
"df = read_and_merge_h5_data('../../data/daily_data.h5', key='daily_data',\n",
|
||
" columns=['ts_code', 'trade_date', 'open', 'close', 'high', 'low', 'vol'],\n",
|
||
" df=None)\n",
|
||
"\n",
|
||
"print('daily basic')\n",
|
||
"df = read_and_merge_h5_data('../../data/daily_basic.h5', key='daily_basic',\n",
|
||
" columns=['ts_code', 'trade_date', 'turnover_rate', 'pe_ttm', 'circ_mv', 'volume_ratio',\n",
|
||
" 'is_st'], df=df, join='inner')\n",
|
||
"\n",
|
||
"print('stk limit')\n",
|
||
"df = read_and_merge_h5_data('../../data/stk_limit.h5', key='stk_limit',\n",
|
||
" columns=['ts_code', 'trade_date', 'pre_close', 'up_limit', 'down_limit'],\n",
|
||
" df=df)\n",
|
||
"print('money flow')\n",
|
||
"df = read_and_merge_h5_data('../../data/money_flow.h5', key='money_flow',\n",
|
||
" columns=['ts_code', 'trade_date', 'buy_sm_vol', 'sell_sm_vol', 'buy_lg_vol', 'sell_lg_vol',\n",
|
||
" 'buy_elg_vol', 'sell_elg_vol', 'net_mf_vol'],\n",
|
||
" df=df)\n",
|
||
"print(df.info())"
|
||
],
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"daily data\n",
|
||
"daily basic\n",
|
||
"inner merge on ['ts_code', 'trade_date']\n",
|
||
"stk limit\n",
|
||
"left merge on ['ts_code', 'trade_date']\n",
|
||
"money flow\n",
|
||
"left merge on ['ts_code', 'trade_date']\n",
|
||
"<class 'pandas.core.frame.DataFrame'>\n",
|
||
"RangeIndex: 8369855 entries, 0 to 8369854\n",
|
||
"Data columns (total 21 columns):\n",
|
||
" # Column Dtype \n",
|
||
"--- ------ ----- \n",
|
||
" 0 ts_code object \n",
|
||
" 1 trade_date datetime64[ns]\n",
|
||
" 2 open float64 \n",
|
||
" 3 close float64 \n",
|
||
" 4 high float64 \n",
|
||
" 5 low float64 \n",
|
||
" 6 vol float64 \n",
|
||
" 7 turnover_rate float64 \n",
|
||
" 8 pe_ttm float64 \n",
|
||
" 9 circ_mv float64 \n",
|
||
" 10 volume_ratio float64 \n",
|
||
" 11 is_st bool \n",
|
||
" 12 up_limit float64 \n",
|
||
" 13 down_limit float64 \n",
|
||
" 14 buy_sm_vol float64 \n",
|
||
" 15 sell_sm_vol float64 \n",
|
||
" 16 buy_lg_vol float64 \n",
|
||
" 17 sell_lg_vol float64 \n",
|
||
" 18 buy_elg_vol float64 \n",
|
||
" 19 sell_elg_vol float64 \n",
|
||
" 20 net_mf_vol float64 \n",
|
||
"dtypes: bool(1), datetime64[ns](1), float64(18), object(1)\n",
|
||
"memory usage: 1.3+ GB\n",
|
||
"None\n"
|
||
]
|
||
}
|
||
],
|
||
"execution_count": 2
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"id": "f7a55c19-b7dc-4d2f-a478-cffab11690df",
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2025-03-02T12:44:46.323145Z",
|
||
"start_time": "2025-03-02T12:44:43.776850Z"
|
||
}
|
||
},
|
||
"source": [
|
||
"print('industry')\n",
|
||
"df = read_and_merge_h5_data('../../data/industry_data.h5', key='industry_data',\n",
|
||
" columns=['ts_code', 'l2_code'],\n",
|
||
" df=df, on=['ts_code'], join='left')\n"
|
||
],
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"industry\n",
|
||
"left merge on ['ts_code']\n"
|
||
]
|
||
}
|
||
],
|
||
"execution_count": 3
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"id": "4077d4449d406c86",
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2025-03-02T12:44:46.389069Z",
|
||
"start_time": "2025-03-02T12:44:46.332410Z"
|
||
}
|
||
},
|
||
"source": [
|
||
"\n",
|
||
"\n",
|
||
"\n",
|
||
"def calculate_indicators(df):\n",
|
||
" \"\"\"\n",
|
||
" 计算四个指标:当日涨跌幅、5日移动平均、RSI、MACD。\n",
|
||
" \"\"\"\n",
|
||
" df = df.sort_values('trade_date')\n",
|
||
" df['daily_return'] = (df['close'] - df['pre_close']) / df['pre_close'] * 100\n",
|
||
" # df['5_day_ma'] = df['close'].rolling(window=5).mean()\n",
|
||
" delta = df['close'].diff()\n",
|
||
" gain = delta.where(delta > 0, 0)\n",
|
||
" loss = -delta.where(delta < 0, 0)\n",
|
||
" avg_gain = gain.rolling(window=14).mean()\n",
|
||
" avg_loss = loss.rolling(window=14).mean()\n",
|
||
" rs = avg_gain / avg_loss\n",
|
||
" df['RSI'] = 100 - (100 / (1 + rs))\n",
|
||
"\n",
|
||
" # 计算MACD\n",
|
||
" ema12 = df['close'].ewm(span=12, adjust=False).mean()\n",
|
||
" ema26 = df['close'].ewm(span=26, adjust=False).mean()\n",
|
||
" df['MACD'] = ema12 - ema26\n",
|
||
" df['Signal_line'] = df['MACD'].ewm(span=9, adjust=False).mean()\n",
|
||
" df['MACD_hist'] = df['MACD'] - df['Signal_line']\n",
|
||
"\n",
|
||
" return df\n",
|
||
"\n",
|
||
"\n",
|
||
"def generate_index_indicators(h5_filename):\n",
|
||
" df = pd.read_hdf(h5_filename, key='index_data')\n",
|
||
" df['trade_date'] = pd.to_datetime(df['trade_date'], format='%Y%m%d')\n",
|
||
" df = df.sort_values('trade_date')\n",
|
||
"\n",
|
||
" # 计算每个ts_code的相关指标\n",
|
||
" df_indicators = []\n",
|
||
" for ts_code in df['ts_code'].unique():\n",
|
||
" df_index = df[df['ts_code'] == ts_code].copy()\n",
|
||
" df_index = calculate_indicators(df_index)\n",
|
||
" df_indicators.append(df_index)\n",
|
||
"\n",
|
||
" # 合并所有指数的结果\n",
|
||
" df_all_indicators = pd.concat(df_indicators, ignore_index=True)\n",
|
||
"\n",
|
||
" # 保留trade_date列,并将同一天的数据按ts_code合并成一行\n",
|
||
" df_final = df_all_indicators.pivot_table(\n",
|
||
" index='trade_date',\n",
|
||
" columns='ts_code',\n",
|
||
" values=['daily_return', 'RSI', 'MACD', 'Signal_line', 'MACD_hist'],\n",
|
||
" aggfunc='last'\n",
|
||
" )\n",
|
||
"\n",
|
||
" df_final.columns = [f\"{col[1]}_{col[0]}\" for col in df_final.columns]\n",
|
||
" df_final = df_final.reset_index()\n",
|
||
"\n",
|
||
" return df_final\n",
|
||
"\n",
|
||
"\n",
|
||
"# 使用函数\n",
|
||
"h5_filename = '../../data/index_data.h5'\n",
|
||
"index_data = generate_index_indicators(h5_filename)\n",
|
||
"index_data = index_data.dropna()\n"
|
||
],
|
||
"outputs": [],
|
||
"execution_count": 4
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"id": "c4e9e1d31da6dba6",
|
||
"metadata": {
|
||
"jupyter": {
|
||
"source_hidden": true
|
||
},
|
||
"ExecuteTime": {
|
||
"end_time": "2025-03-02T12:44:46.438183Z",
|
||
"start_time": "2025-03-02T12:44:46.409533Z"
|
||
}
|
||
},
|
||
"source": [
|
||
"import numpy as np\n",
|
||
"import talib\n",
|
||
"\n",
|
||
"\n",
|
||
"def get_technical_factor(df):\n",
|
||
" # 按股票和日期排序\n",
|
||
" df = df.sort_values(by=['ts_code', 'trade_date'])\n",
|
||
" grouped = df.groupby('ts_code', group_keys=False)\n",
|
||
"\n",
|
||
" # 计算 up 和 down\n",
|
||
" df['log_close'] = np.log(df['close'])\n",
|
||
"\n",
|
||
" df['up'] = (df['high'] - df[['close', 'open']].max(axis=1)) / df['close']\n",
|
||
" df['down'] = (df[['close', 'open']].min(axis=1) - df['low']) / df['close']\n",
|
||
"\n",
|
||
" # 计算 ATR\n",
|
||
" df['atr_14'] = grouped.apply(\n",
|
||
" lambda x: pd.Series(talib.ATR(x['high'].values, x['low'].values, x['close'].values, timeperiod=14),\n",
|
||
" index=x.index)\n",
|
||
" )\n",
|
||
" df['atr_6'] = grouped.apply(\n",
|
||
" lambda x: pd.Series(talib.ATR(x['high'].values, x['low'].values, x['close'].values, timeperiod=6),\n",
|
||
" index=x.index)\n",
|
||
" )\n",
|
||
"\n",
|
||
" # 计算 OBV 及其均线\n",
|
||
" df['obv'] = grouped.apply(\n",
|
||
" lambda x: pd.Series(talib.OBV(x['close'].values, x['vol'].values), index=x.index)\n",
|
||
" )\n",
|
||
" df['maobv_6'] = grouped.apply(\n",
|
||
" lambda x: pd.Series(talib.SMA(x['obv'].values, timeperiod=6), index=x.index)\n",
|
||
" )\n",
|
||
" df['obv-maobv_6'] = df['obv'] - df['maobv_6']\n",
|
||
"\n",
|
||
" # 计算 RSI\n",
|
||
" df['rsi_3'] = grouped.apply(\n",
|
||
" lambda x: pd.Series(talib.RSI(x['close'].values, timeperiod=3), index=x.index)\n",
|
||
" )\n",
|
||
" df['rsi_6'] = grouped.apply(\n",
|
||
" lambda x: pd.Series(talib.RSI(x['close'].values, timeperiod=6), index=x.index)\n",
|
||
" )\n",
|
||
" df['rsi_9'] = grouped.apply(\n",
|
||
" lambda x: pd.Series(talib.RSI(x['close'].values, timeperiod=9), index=x.index)\n",
|
||
" )\n",
|
||
"\n",
|
||
" # 计算 return_10 和 return_20\n",
|
||
" df['return_5'] = grouped['close'].apply(lambda x: x / x.shift(5) - 1)\n",
|
||
" df['return_10'] = grouped['close'].apply(lambda x: x / x.shift(10) - 1)\n",
|
||
" df['return_20'] = grouped['close'].apply(lambda x: x / x.shift(20) - 1)\n",
|
||
"\n",
|
||
" # 计算 avg_close_5\n",
|
||
" df['avg_close_5'] = grouped['close'].apply(lambda x: x.rolling(window=5).mean() / x)\n",
|
||
"\n",
|
||
" # 计算标准差指标\n",
|
||
" df['std_return_5'] = grouped['close'].apply(lambda x: x.pct_change().rolling(window=5).std())\n",
|
||
" df['std_return_15'] = grouped['close'].apply(lambda x: x.pct_change().rolling(window=15).std())\n",
|
||
" df['std_return_25'] = grouped['close'].apply(lambda x: x.pct_change().rolling(window=25).std())\n",
|
||
" df['std_return_90'] = grouped['close'].apply(lambda x: x.pct_change().rolling(window=90).std())\n",
|
||
" df['std_return_90_2'] = grouped['close'].apply(lambda x: x.shift(10).pct_change().rolling(window=90).std())\n",
|
||
"\n",
|
||
" # 计算比值指标\n",
|
||
" df['std_return_5 / std_return_90'] = df['std_return_5'] / df['std_return_90']\n",
|
||
" df['std_return_5 / std_return_25'] = df['std_return_5'] / df['std_return_25']\n",
|
||
"\n",
|
||
" # 计算标准差差值\n",
|
||
" df['std_return_90 - std_return_90_2'] = df['std_return_90'] - df['std_return_90_2']\n",
|
||
"\n",
|
||
" return df\n",
|
||
"\n",
|
||
"\n",
|
||
"def get_act_factor(df, cat=True):\n",
|
||
" # 按股票和日期排序\n",
|
||
" df = df.sort_values(by=['ts_code', 'trade_date'])\n",
|
||
" grouped = df.groupby('ts_code', group_keys=False)\n",
|
||
" # 计算 EMA 指标\n",
|
||
" df['ema_5'] = grouped['close'].apply(\n",
|
||
" lambda x: pd.Series(talib.EMA(x.values, timeperiod=5), index=x.index)\n",
|
||
" )\n",
|
||
" df['ema_13'] = grouped['close'].apply(\n",
|
||
" lambda x: pd.Series(talib.EMA(x.values, timeperiod=13), index=x.index)\n",
|
||
" )\n",
|
||
" df['ema_20'] = grouped['close'].apply(\n",
|
||
" lambda x: pd.Series(talib.EMA(x.values, timeperiod=20), index=x.index)\n",
|
||
" )\n",
|
||
" df['ema_60'] = grouped['close'].apply(\n",
|
||
" lambda x: pd.Series(talib.EMA(x.values, timeperiod=60), index=x.index)\n",
|
||
" )\n",
|
||
"\n",
|
||
" # 计算 act_factor1, act_factor2, act_factor3, act_factor4\n",
|
||
" df['act_factor1'] = grouped['ema_5'].apply(\n",
|
||
" lambda x: np.arctan((x / x.shift(1) - 1) * 100) * 57.3 / 50\n",
|
||
" )\n",
|
||
" df['act_factor2'] = grouped['ema_13'].apply(\n",
|
||
" lambda x: np.arctan((x / x.shift(1) - 1) * 100) * 57.3 / 40\n",
|
||
" )\n",
|
||
" df['act_factor3'] = grouped['ema_20'].apply(\n",
|
||
" lambda x: np.arctan((x / x.shift(1) - 1) * 100) * 57.3 / 21\n",
|
||
" )\n",
|
||
" df['act_factor4'] = grouped['ema_60'].apply(\n",
|
||
" lambda x: np.arctan((x / x.shift(1) - 1) * 100) * 57.3 / 10\n",
|
||
" )\n",
|
||
"\n",
|
||
" if cat:\n",
|
||
" df['cat_af1'] = df['act_factor1'] > 0\n",
|
||
" df['cat_af2'] = df['act_factor2'] > df['act_factor1']\n",
|
||
" df['cat_af3'] = df['act_factor3'] > df['act_factor2']\n",
|
||
" df['cat_af4'] = df['act_factor4'] > df['act_factor3']\n",
|
||
"\n",
|
||
" # 计算 act_factor5 和 act_factor6\n",
|
||
" df['act_factor5'] = df['act_factor1'] + df['act_factor2'] + df['act_factor3'] + df['act_factor4']\n",
|
||
" df['act_factor6'] = (df['act_factor1'] - df['act_factor2']) / np.sqrt(\n",
|
||
" df['act_factor1'] ** 2 + df['act_factor2'] ** 2)\n",
|
||
"\n",
|
||
" # 根据 trade_date 截面计算排名\n",
|
||
" df['rank_act_factor1'] = df.groupby('trade_date', group_keys=False)['act_factor1'].rank(ascending=False, pct=True)\n",
|
||
" df['rank_act_factor2'] = df.groupby('trade_date', group_keys=False)['act_factor2'].rank(ascending=False, pct=True)\n",
|
||
" df['rank_act_factor3'] = df.groupby('trade_date', group_keys=False)['act_factor3'].rank(ascending=False, pct=True)\n",
|
||
"\n",
|
||
" return df\n",
|
||
"\n",
|
||
"\n",
|
||
"def get_money_flow_factor(df):\n",
|
||
" # 计算资金流相关因子(字段名称见 tushare 数据说明)\n",
|
||
" df['active_buy_volume_large'] = df['buy_lg_vol'] / df['net_mf_vol']\n",
|
||
" df['active_buy_volume_big'] = df['buy_elg_vol'] / df['net_mf_vol']\n",
|
||
" df['active_buy_volume_small'] = df['buy_sm_vol'] / df['net_mf_vol']\n",
|
||
"\n",
|
||
" df['buy_lg_vol_minus_sell_lg_vol'] = (df['buy_lg_vol'] - df['sell_lg_vol']) / df['net_mf_vol']\n",
|
||
" df['buy_elg_vol_minus_sell_elg_vol'] = (df['buy_elg_vol'] - df['sell_elg_vol']) / df['net_mf_vol']\n",
|
||
"\n",
|
||
" df['log(circ_mv)'] = np.log(df['circ_mv'])\n",
|
||
" return df\n",
|
||
"\n",
|
||
"\n",
|
||
"def get_alpha_factor(df):\n",
|
||
" df = df.sort_values(by=['ts_code', 'trade_date'])\n",
|
||
" grouped = df.groupby('ts_code')\n",
|
||
"\n",
|
||
" # alpha_022: 当前 close 与 5 日前 close 差值\n",
|
||
" df['alpha_022'] = grouped['close'].transform(lambda x: x - x.shift(5))\n",
|
||
"\n",
|
||
" # alpha_003: (close - open) / (high - low)\n",
|
||
" df['alpha_003'] = np.where(df['high'] != df['low'],\n",
|
||
" (df['close'] - df['open']) / (df['high'] - df['low']),\n",
|
||
" 0)\n",
|
||
"\n",
|
||
" # alpha_007: 计算过去5日 close 与 vol 的相关性,并按 trade_date 排名\n",
|
||
" df['alpha_007'] = grouped.apply(lambda x: x['close'].rolling(5).corr(x['vol'])).reset_index(level=0, drop=True)\n",
|
||
" df['alpha_007'] = df.groupby('trade_date', group_keys=False)['alpha_007'].rank(ascending=True, pct=True)\n",
|
||
"\n",
|
||
" # alpha_013: 计算过去5日 close 之和 - 20日 close 之和,并按 trade_date 排名\n",
|
||
" df['alpha_013'] = grouped['close'].transform(lambda x: x.rolling(5).sum() - x.rolling(20).sum())\n",
|
||
" df['alpha_013'] = df.groupby('trade_date', group_keys=False)['alpha_013'].rank(ascending=True, pct=True)\n",
|
||
"\n",
|
||
" return df\n",
|
||
"\n"
|
||
],
|
||
"outputs": [],
|
||
"execution_count": 5
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"id": "a735bc02ceb4d872",
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2025-03-02T12:44:54.093568Z",
|
||
"start_time": "2025-03-02T12:44:46.451967Z"
|
||
}
|
||
},
|
||
"source": [
|
||
"def read_industry_data(h5_filename):\n",
|
||
" # 读取 H5 文件中所有的行业数据\n",
|
||
" industry_data = pd.read_hdf(h5_filename, key='sw_daily', columns=[\n",
|
||
" 'ts_code', 'trade_date', 'open', 'close', 'high', 'low', 'pe', 'pb', 'vol'\n",
|
||
" ]) # 假设 H5 文件的键是 'industry_data'\n",
|
||
" industry_data = industry_data.sort_values(by=['ts_code', 'trade_date'])\n",
|
||
" industry_data = industry_data.reindex()\n",
|
||
" industry_data['trade_date'] = pd.to_datetime(industry_data['trade_date'], format='%Y%m%d')\n",
|
||
"\n",
|
||
" grouped = industry_data.groupby('ts_code', group_keys=False)\n",
|
||
" industry_data['obv'] = grouped.apply(\n",
|
||
" lambda x: pd.Series(talib.OBV(x['close'].values, x['vol'].values), index=x.index)\n",
|
||
" )\n",
|
||
" industry_data['return_5'] = grouped['close'].apply(lambda x: x / x.shift(5) - 1)\n",
|
||
" industry_data['return_20'] = grouped['close'].apply(lambda x: x / x.shift(20) - 1)\n",
|
||
"\n",
|
||
" industry_data = get_act_factor(industry_data, cat=False)\n",
|
||
" # industry_data = industry_data.sort_values(by=['trade_date', 'ts_code'])\n",
|
||
"\n",
|
||
" # 计算每天每个 ts_code 的因子和当天所有 ts_code 的中位数的偏差\n",
|
||
" factor_columns = ['obv', 'return_5', 'return_20', 'act_factor1', 'act_factor2', 'act_factor3', 'act_factor4'] # 因子列\n",
|
||
"\n",
|
||
" for factor in factor_columns:\n",
|
||
" if factor in industry_data.columns:\n",
|
||
" # 计算每天每个 ts_code 的因子值与当天所有 ts_code 的中位数的偏差\n",
|
||
" industry_data[f'{factor}_deviation'] = industry_data.groupby('trade_date')[factor].transform(\n",
|
||
" lambda x: x - x.median())\n",
|
||
"\n",
|
||
" industry_data['return_5_percentile'] = industry_data.groupby('trade_date')['return_5'].transform(\n",
|
||
" lambda x: x.rank(pct=True))\n",
|
||
" industry_data = industry_data.drop(columns=['open', 'close', 'high', 'low', 'pe', 'pb', 'vol'])\n",
|
||
"\n",
|
||
" industry_data = industry_data.rename(\n",
|
||
" columns={col: f'industry_{col}' for col in industry_data.columns if col not in ['ts_code', 'trade_date']})\n",
|
||
"\n",
|
||
" industry_data = industry_data.rename(columns={'ts_code': 'cat_l2_code'})\n",
|
||
" return industry_data\n",
|
||
"\n",
|
||
"\n",
|
||
"industry_df = read_industry_data('../../data/sw_daily.h5')\n"
|
||
],
|
||
"outputs": [],
|
||
"execution_count": 6
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"id": "53f86ddc0677a6d7",
|
||
"metadata": {
|
||
"scrolled": true,
|
||
"ExecuteTime": {
|
||
"end_time": "2025-03-02T12:44:54.102298Z",
|
||
"start_time": "2025-03-02T12:44:54.093568Z"
|
||
}
|
||
},
|
||
"source": [
|
||
"origin_columns = df.columns.tolist()\n",
|
||
"origin_columns = [col for col in origin_columns if col not in ['turnover_rate', 'pe_ttm', 'volume_ratio', 'l2_code']]\n",
|
||
"origin_columns = [col for col in origin_columns if col not in index_data.columns]\n"
|
||
],
|
||
"outputs": [],
|
||
"execution_count": 7
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"id": "dbe2fd8021b9417f",
|
||
"metadata": {
|
||
"scrolled": true,
|
||
"ExecuteTime": {
|
||
"end_time": "2025-03-02T12:46:07.996377Z",
|
||
"start_time": "2025-03-02T12:44:54.115006Z"
|
||
}
|
||
},
|
||
"source": [
|
||
"def filter_data(df):\n",
|
||
" # df = df.groupby('trade_date').apply(lambda x: x.nlargest(1000, 'act_factor1'))\n",
|
||
" df = df[~df['is_st']]\n",
|
||
" df = df[~df['ts_code'].str.endswith('BJ')]\n",
|
||
" df = df[~df['ts_code'].str.startswith('30')]\n",
|
||
" df = df[~df['ts_code'].str.startswith('68')]\n",
|
||
" df = df[~df['ts_code'].str.startswith('8')]\n",
|
||
" df = df.reset_index(drop=True)\n",
|
||
" return df\n",
|
||
"\n",
|
||
"\n",
|
||
"df = filter_data(df)\n",
|
||
"df = get_technical_factor(df)\n",
|
||
"df = get_act_factor(df)\n",
|
||
"df = get_money_flow_factor(df)\n",
|
||
"df = get_alpha_factor(df)\n",
|
||
"# df = df.merge(industry_df, on=['l2_code', 'trade_date'], how='left')\n",
|
||
"df = df.rename(columns={'l2_code': 'cat_l2_code'})\n",
|
||
"# df = df.merge(index_data, on='trade_date', how='left')\n",
|
||
"\n",
|
||
"print(df.info())"
|
||
],
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"<class 'pandas.core.frame.DataFrame'>\n",
|
||
"Index: 5732462 entries, 1964 to 5732460\n",
|
||
"Data columns (total 72 columns):\n",
|
||
" # Column Dtype \n",
|
||
"--- ------ ----- \n",
|
||
" 0 ts_code object \n",
|
||
" 1 trade_date datetime64[ns]\n",
|
||
" 2 open float64 \n",
|
||
" 3 close float64 \n",
|
||
" 4 high float64 \n",
|
||
" 5 low float64 \n",
|
||
" 6 vol float64 \n",
|
||
" 7 turnover_rate float64 \n",
|
||
" 8 pe_ttm float64 \n",
|
||
" 9 circ_mv float64 \n",
|
||
" 10 volume_ratio float64 \n",
|
||
" 11 is_st bool \n",
|
||
" 12 up_limit float64 \n",
|
||
" 13 down_limit float64 \n",
|
||
" 14 buy_sm_vol float64 \n",
|
||
" 15 sell_sm_vol float64 \n",
|
||
" 16 buy_lg_vol float64 \n",
|
||
" 17 sell_lg_vol float64 \n",
|
||
" 18 buy_elg_vol float64 \n",
|
||
" 19 sell_elg_vol float64 \n",
|
||
" 20 net_mf_vol float64 \n",
|
||
" 21 cat_l2_code object \n",
|
||
" 22 log_close float64 \n",
|
||
" 23 up float64 \n",
|
||
" 24 down float64 \n",
|
||
" 25 atr_14 float64 \n",
|
||
" 26 atr_6 float64 \n",
|
||
" 27 obv float64 \n",
|
||
" 28 maobv_6 float64 \n",
|
||
" 29 obv-maobv_6 float64 \n",
|
||
" 30 rsi_3 float64 \n",
|
||
" 31 rsi_6 float64 \n",
|
||
" 32 rsi_9 float64 \n",
|
||
" 33 return_5 float64 \n",
|
||
" 34 return_10 float64 \n",
|
||
" 35 return_20 float64 \n",
|
||
" 36 avg_close_5 float64 \n",
|
||
" 37 std_return_5 float64 \n",
|
||
" 38 std_return_15 float64 \n",
|
||
" 39 std_return_25 float64 \n",
|
||
" 40 std_return_90 float64 \n",
|
||
" 41 std_return_90_2 float64 \n",
|
||
" 42 std_return_5 / std_return_90 float64 \n",
|
||
" 43 std_return_5 / std_return_25 float64 \n",
|
||
" 44 std_return_90 - std_return_90_2 float64 \n",
|
||
" 45 ema_5 float64 \n",
|
||
" 46 ema_13 float64 \n",
|
||
" 47 ema_20 float64 \n",
|
||
" 48 ema_60 float64 \n",
|
||
" 49 act_factor1 float64 \n",
|
||
" 50 act_factor2 float64 \n",
|
||
" 51 act_factor3 float64 \n",
|
||
" 52 act_factor4 float64 \n",
|
||
" 53 cat_af1 bool \n",
|
||
" 54 cat_af2 bool \n",
|
||
" 55 cat_af3 bool \n",
|
||
" 56 cat_af4 bool \n",
|
||
" 57 act_factor5 float64 \n",
|
||
" 58 act_factor6 float64 \n",
|
||
" 59 rank_act_factor1 float64 \n",
|
||
" 60 rank_act_factor2 float64 \n",
|
||
" 61 rank_act_factor3 float64 \n",
|
||
" 62 active_buy_volume_large float64 \n",
|
||
" 63 active_buy_volume_big float64 \n",
|
||
" 64 active_buy_volume_small float64 \n",
|
||
" 65 buy_lg_vol_minus_sell_lg_vol float64 \n",
|
||
" 66 buy_elg_vol_minus_sell_elg_vol float64 \n",
|
||
" 67 log(circ_mv) float64 \n",
|
||
" 68 alpha_022 float64 \n",
|
||
" 69 alpha_003 float64 \n",
|
||
" 70 alpha_007 float64 \n",
|
||
" 71 alpha_013 float64 \n",
|
||
"dtypes: bool(5), datetime64[ns](1), float64(64), object(2)\n",
|
||
"memory usage: 2.9+ GB\n",
|
||
"None\n"
|
||
]
|
||
}
|
||
],
|
||
"execution_count": 8
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"id": "d345bcc43b15579e",
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2025-03-02T12:46:08.743665Z",
|
||
"start_time": "2025-03-02T12:46:08.728004Z"
|
||
}
|
||
},
|
||
"source": [
|
||
"def create_deviation_within_dates(df, feature_columns):\n",
|
||
" groupby_col = 'cat_l2_code' # 使用 trade_date 进行分组\n",
|
||
" new_columns = {}\n",
|
||
" ret_feature_columns = feature_columns[:]\n",
|
||
"\n",
|
||
" # 自动选择所有数值型特征\n",
|
||
" # num_features = [col for col in feature_columns if 'cat' not in col and 'index' not in col]\n",
|
||
" num_features = [col for col in feature_columns if 'cat' not in col and 'industry' not in col]\n",
|
||
"\n",
|
||
" # 遍历所有数值型特征\n",
|
||
" for feature in num_features:\n",
|
||
" if feature == 'trade_date': # 不需要对 'trade_date' 计算偏差\n",
|
||
" continue\n",
|
||
"\n",
|
||
" # grouped_median = df.groupby(['trade_date', groupby_col])[feature].transform('median')\n",
|
||
" # deviation_col_name = f'deviation_median_{feature}'\n",
|
||
" # new_columns[deviation_col_name] = df[feature] - grouped_median\n",
|
||
" # ret_feature_columns.append(deviation_col_name)\n",
|
||
"\n",
|
||
" grouped_mean = df.groupby(['trade_date', groupby_col])[feature].transform('mean')\n",
|
||
" deviation_col_name = f'deviation_mean_{feature}'\n",
|
||
" new_columns[deviation_col_name] = df[feature] - grouped_mean\n",
|
||
" ret_feature_columns.append(deviation_col_name)\n",
|
||
"\n",
|
||
" # 将新计算的偏差特征与原始 DataFrame 合并\n",
|
||
" df = pd.concat([df, pd.DataFrame(new_columns)], axis=1)\n",
|
||
"\n",
|
||
" # for feature in ['obv', 'return_20', 'act_factor1', 'act_factor2', 'act_factor3', 'act_factor4']:\n",
|
||
" # df[f'deviation_industry_{feature}'] = df[feature] - df[f'industry_{feature}']\n",
|
||
"\n",
|
||
" return df, ret_feature_columns\n"
|
||
],
|
||
"outputs": [],
|
||
"execution_count": 9
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"id": "5f3d9aece75318cd",
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2025-03-02T12:46:08.949623Z",
|
||
"start_time": "2025-03-02T12:46:08.931989Z"
|
||
}
|
||
},
|
||
"source": [
|
||
"def get_qcuts(series, quantiles):\n",
|
||
" q = pd.qcut(series, q=quantiles, labels=False, duplicates='drop')\n",
|
||
" return q[-1] # 返回窗口最后一个元素的分位数标签\n",
|
||
"\n",
|
||
"\n",
|
||
"import pandas as pd\n",
|
||
"\n",
|
||
"\n",
|
||
"def remove_outliers_label_percentile(label: pd.Series, lower_percentile: float = 0.01, upper_percentile: float = 0.99,\n",
|
||
" log=True):\n",
|
||
" if not (0 <= lower_percentile < upper_percentile <= 1):\n",
|
||
" raise ValueError(\"Percentile values must satisfy 0 <= lower_percentile < upper_percentile <= 1.\")\n",
|
||
"\n",
|
||
" # Calculate lower and upper bounds based on percentiles\n",
|
||
" lower_bound = label.quantile(lower_percentile)\n",
|
||
" upper_bound = label.quantile(upper_percentile)\n",
|
||
"\n",
|
||
" # Filter out values outside the bounds\n",
|
||
" filtered_label = label[(label >= lower_bound) & (label <= upper_bound)]\n",
|
||
"\n",
|
||
" # Print the number of removed outliers\n",
|
||
" if log:\n",
|
||
" print(f\"Removed {len(label) - len(filtered_label)} outliers.\")\n",
|
||
" return filtered_label\n",
|
||
"\n",
|
||
"\n",
|
||
"def calculate_risk_adjusted_target(df, days=5):\n",
|
||
" df = df.sort_values(by=['ts_code', 'trade_date'])\n",
|
||
"\n",
|
||
" df['future_close'] = df.groupby('ts_code')['close'].shift(-days)\n",
|
||
" df['future_return'] = (df['future_close'] - df['close']) / df['close']\n",
|
||
"\n",
|
||
" df['future_volatility'] = df.groupby('ts_code')['future_return'].rolling(days, min_periods=1).std().reset_index(\n",
|
||
" level=0, drop=True)\n",
|
||
" df['sharpe_ratio'] = df['future_return'] * df['future_volatility']\n",
|
||
" df['sharpe_ratio'].replace([np.inf, -np.inf], np.nan, inplace=True)\n",
|
||
"\n",
|
||
" return df['sharpe_ratio']\n",
|
||
"\n",
|
||
"\n"
|
||
],
|
||
"outputs": [],
|
||
"execution_count": 10
|
||
},
|
||
{
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2025-03-02T13:16:09.082156Z",
|
||
"start_time": "2025-03-02T13:14:58.041672Z"
|
||
}
|
||
},
|
||
"cell_type": "code",
|
||
"source": [
|
||
"days = 3\n",
|
||
"future_close = df.groupby('ts_code')['close'].shift(-days)\n",
|
||
"future_return = (future_close - df['close']) / df['close']\n",
|
||
"df['label'] = future_return\n",
|
||
"\n",
|
||
"# df['label'] = remove_outliers_label_percentile(df['label'])\n",
|
||
"\n",
|
||
"# df = df.apply(lambda x: x.astype('float32') if x.dtype in ['float64', 'float32'] else x)\n",
|
||
"df = df.sort_values(by=['trade_date', 'ts_code'])\n",
|
||
"train_data = df[(df['trade_date'] <= '2023-01-01') & (df['trade_date'] >= '2016-01-01')]\n",
|
||
"test_data = df[(df['trade_date'] >= '2023-01-01')]\n",
|
||
"\n",
|
||
"train_data = train_data.groupby('trade_date', group_keys=False).apply(\n",
|
||
" lambda x: x.nsmallest(3000, 'log(circ_mv)')\n",
|
||
")\n",
|
||
"test_data = test_data.groupby('trade_date', group_keys=False).apply(\n",
|
||
" lambda x: x.nsmallest(3000, 'log(circ_mv)')\n",
|
||
")\n",
|
||
"train_data = train_data.groupby('trade_date', group_keys=False).apply(\n",
|
||
" lambda x: x.nlargest(1000, 'return_20')\n",
|
||
")\n",
|
||
"test_data = test_data.groupby('trade_date', group_keys=False).apply(\n",
|
||
" lambda x: x.nlargest(1000, 'return_20')\n",
|
||
")\n",
|
||
"\n",
|
||
"industry_df = industry_df.sort_values(by=['trade_date'])\n",
|
||
"index_data = index_data.sort_values(by=['trade_date'])\n",
|
||
"\n",
|
||
"train_data = train_data.merge(industry_df, on=['cat_l2_code', 'trade_date'], how='left')\n",
|
||
"train_data = train_data.merge(index_data, on='trade_date', how='left')\n",
|
||
"test_data = test_data.merge(industry_df, on=['cat_l2_code', 'trade_date'], how='left')\n",
|
||
"test_data = test_data.merge(index_data, on='trade_date', how='left')\n",
|
||
"\n",
|
||
"train_data, test_data = train_data.replace([np.inf, -np.inf], np.nan), test_data.replace([np.inf, -np.inf], np.nan)\n",
|
||
"\n",
|
||
"feature_columns = [col for col in train_data.columns if col not in ['trade_date',\n",
|
||
" 'ts_code',\n",
|
||
" 'label']]\n",
|
||
"feature_columns = [col for col in feature_columns if 'future' not in col]\n",
|
||
"feature_columns = [col for col in feature_columns if 'score' not in col]\n",
|
||
"feature_columns = [col for col in feature_columns if col not in origin_columns]\n",
|
||
"feature_columns = [col for col in feature_columns if not col.startswith('_')]\n",
|
||
"print(feature_columns)\n",
|
||
"\n",
|
||
"feature_columns_new = feature_columns[:]\n",
|
||
"train_data, feature_columns_new = create_deviation_within_dates(train_data, feature_columns)\n",
|
||
"print(f'feature_columns size: {len(feature_columns_new)}')\n",
|
||
"test_data, feature_columns_new = create_deviation_within_dates(test_data, feature_columns)\n",
|
||
"print(f'feature_columns size: {len(feature_columns_new)}')\n",
|
||
"\n",
|
||
"train_data = train_data.dropna(subset=feature_columns_new)\n",
|
||
"train_data = train_data.dropna(subset=['label'])\n",
|
||
"train_data['label'] = remove_outliers_label_percentile(train_data['label'])\n",
|
||
"train_data = train_data.dropna(subset=['label'])\n",
|
||
"train_data = train_data.reset_index(drop=True)\n",
|
||
"\n",
|
||
"# print(test_data.tail())\n",
|
||
"test_data = test_data.dropna(subset=feature_columns_new)\n",
|
||
"# test_data = test_data.dropna(subset=['label'])\n",
|
||
"test_data = test_data.reset_index(drop=True)\n",
|
||
"\n",
|
||
"print(len(train_data))\n",
|
||
"print(f\"最小日期: {train_data['trade_date'].min().strftime('%Y-%m-%d')}\")\n",
|
||
"print(f\"最大日期: {train_data['trade_date'].max().strftime('%Y-%m-%d')}\")\n",
|
||
"print(len(test_data))\n",
|
||
"print(f\"最小日期: {test_data['trade_date'].min().strftime('%Y-%m-%d')}\")\n",
|
||
"print(f\"最大日期: {test_data['trade_date'].max().strftime('%Y-%m-%d')}\")\n",
|
||
"\n",
|
||
"cat_columns = [col for col in df.columns if col.startswith('cat')]\n",
|
||
"for col in cat_columns:\n",
|
||
" train_data[col] = train_data[col].astype('category')\n",
|
||
" test_data[col] = test_data[col].astype('category')"
|
||
],
|
||
"id": "cf7de0b77db39655",
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"['turnover_rate', 'pe_ttm', 'volume_ratio', 'cat_l2_code', 'log_close', 'up', 'down', 'atr_14', 'atr_6', 'obv', 'maobv_6', 'obv-maobv_6', 'rsi_3', 'rsi_6', 'rsi_9', 'return_5', 'return_10', 'return_20', 'avg_close_5', 'std_return_5', 'std_return_15', 'std_return_25', 'std_return_90', 'std_return_90_2', 'std_return_5 / std_return_90', 'std_return_5 / std_return_25', 'std_return_90 - std_return_90_2', 'ema_5', 'ema_13', 'ema_20', 'ema_60', 'act_factor1', 'act_factor2', 'act_factor3', 'act_factor4', 'cat_af1', 'cat_af2', 'cat_af3', 'cat_af4', 'act_factor5', 'act_factor6', 'rank_act_factor1', 'rank_act_factor2', 'rank_act_factor3', 'active_buy_volume_large', 'active_buy_volume_big', 'active_buy_volume_small', 'buy_lg_vol_minus_sell_lg_vol', 'buy_elg_vol_minus_sell_elg_vol', 'log(circ_mv)', 'alpha_022', 'alpha_003', 'alpha_007', 'alpha_013', 'industry_obv', 'industry_return_5', 'industry_return_20', 'industry_ema_5', 'industry_ema_13', 'industry_ema_20', 'industry_ema_60', 'industry_act_factor1', 'industry_act_factor2', 'industry_act_factor3', 'industry_act_factor4', 'industry_act_factor5', 'industry_act_factor6', 'industry_rank_act_factor1', 'industry_rank_act_factor2', 'industry_rank_act_factor3', 'industry_obv_deviation', 'industry_return_5_deviation', 'industry_return_20_deviation', 'industry_act_factor1_deviation', 'industry_act_factor2_deviation', 'industry_act_factor3_deviation', 'industry_act_factor4_deviation', 'industry_return_5_percentile', '000852.SH_MACD', '000905.SH_MACD', '399006.SZ_MACD', '000852.SH_MACD_hist', '000905.SH_MACD_hist', '399006.SZ_MACD_hist', '000852.SH_RSI', '000905.SH_RSI', '399006.SZ_RSI', '000852.SH_Signal_line', '000905.SH_Signal_line', '399006.SZ_Signal_line', '000852.SH_daily_return', '000905.SH_daily_return', '399006.SZ_daily_return']\n",
|
||
"feature_columns size: 157\n",
|
||
"feature_columns size: 157\n",
|
||
"Removed 23208 outliers.\n",
|
||
"1137173\n",
|
||
"最小日期: 2017-04-06\n",
|
||
"最大日期: 2022-12-30\n",
|
||
"396093\n",
|
||
"最小日期: 2023-01-03\n",
|
||
"最大日期: 2025-02-28\n"
|
||
]
|
||
}
|
||
],
|
||
"execution_count": 68
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"id": "8f134d435f71e9e2",
|
||
"metadata": {
|
||
"jupyter": {
|
||
"source_hidden": true
|
||
},
|
||
"ExecuteTime": {
|
||
"end_time": "2025-03-02T13:23:29.071382Z",
|
||
"start_time": "2025-03-02T13:23:29.059163Z"
|
||
}
|
||
},
|
||
"source": [
|
||
"\n",
|
||
"import numpy as np\n",
|
||
"\n",
|
||
"import pandas as pd\n",
|
||
"from sklearn.linear_model import LinearRegression\n",
|
||
"from sklearn.preprocessing import StandardScaler, OneHotEncoder\n",
|
||
"\n",
|
||
"\n",
|
||
"def train_sgd_model(train_data: pd.DataFrame,\n",
|
||
" feature_columns: list,\n",
|
||
" params: dict, print_feature_importance=True):\n",
|
||
" # Initialize scaler and encoder\n",
|
||
" scaler = StandardScaler()\n",
|
||
" encoder = OneHotEncoder(handle_unknown='ignore')\n",
|
||
"\n",
|
||
" # Extract features and labels\n",
|
||
" X_train = train_data[feature_columns]\n",
|
||
" y_train = train_data['label']\n",
|
||
"\n",
|
||
" numeric_columns = X_train.select_dtypes(include=['float64', 'int64']).columns\n",
|
||
" categorical_columns = [col for col in feature_columns if col.startswith('cat')]\n",
|
||
"\n",
|
||
" X_train.loc[:, numeric_columns] = scaler.fit_transform(X_train[numeric_columns])\n",
|
||
" X_train_categorical = encoder.fit_transform(X_train[categorical_columns]).toarray()\n",
|
||
"\n",
|
||
" # Combine numeric and categorical features\n",
|
||
" X_train_processed = pd.concat([\n",
|
||
" pd.DataFrame(X_train[numeric_columns], columns=numeric_columns, index=X_train.index),\n",
|
||
" pd.DataFrame(X_train_categorical, columns=encoder.get_feature_names_out(categorical_columns),\n",
|
||
" index=X_train.index)\n",
|
||
" ], axis=1)\n",
|
||
"\n",
|
||
" # Train the model\n",
|
||
" # model = SGDRegressor(**params)\n",
|
||
" model = LinearRegression()\n",
|
||
" model.fit(X_train_processed, y_train)\n",
|
||
"\n",
|
||
" # 特征重要性可视化\n",
|
||
" if print_feature_importance:\n",
|
||
" coefficients = model.coef_\n",
|
||
"\n",
|
||
" # 创建一个字典,存储特征名称和对应的系数\n",
|
||
" feature_importance = dict(zip(X_train_processed.columns, coefficients))\n",
|
||
"\n",
|
||
" # 按系数绝对值排序\n",
|
||
" sorted_importance = sorted(feature_importance.items(), key=lambda x: abs(x[1]), reverse=True)\n",
|
||
"\n",
|
||
" # 打印特征重要性\n",
|
||
" print(\"Feature Importance:\")\n",
|
||
" for feature, importance in sorted_importance:\n",
|
||
" print(f\"{feature}: {importance:.4f}\")\n",
|
||
"\n",
|
||
" return model, scaler, encoder"
|
||
],
|
||
"outputs": [],
|
||
"execution_count": 79
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"id": "beeb098799ecfa6a",
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2025-03-02T13:23:56.411933Z",
|
||
"start_time": "2025-03-02T13:23:29.076456Z"
|
||
}
|
||
},
|
||
"source": [
|
||
"print('train data size: ', len(train_data))\n",
|
||
"import gc\n",
|
||
"\n",
|
||
"gc.collect()\n",
|
||
"params = {\n",
|
||
" 'alpha': 0.0001, # 正则化强度\n",
|
||
" 'max_iter': 1000, # 最大迭代次数\n",
|
||
" 'tol': 1e-3, # 收敛阈值\n",
|
||
" 'eta0': 0.01, # 初始学习率\n",
|
||
" 'learning_rate': 'constant'\n",
|
||
"}\n",
|
||
"\n",
|
||
"model, scaler, encoder = train_sgd_model(train_data, feature_columns_new, params)"
|
||
],
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"train data size: 1137173\n",
|
||
"Feature Importance:\n",
|
||
"deviation_mean_ema_13: 0.0841\n",
|
||
"deviation_mean_ema_20: -0.0679\n",
|
||
"ema_13: -0.0630\n",
|
||
"ema_20: 0.0550\n",
|
||
"industry_ema_20: 0.0238\n",
|
||
"industry_ema_13: -0.0227\n",
|
||
"act_factor3: -0.0217\n",
|
||
"deviation_mean_ema_5: -0.0200\n",
|
||
"industry_obv: 0.0178\n",
|
||
"industry_obv_deviation: -0.0176\n",
|
||
"rsi_6: 0.0155\n",
|
||
"act_factor2: 0.0152\n",
|
||
"ema_5: 0.0118\n",
|
||
"industry_act_factor4: -0.0092\n",
|
||
"rsi_3: -0.0091\n",
|
||
"deviation_mean_rsi_6: -0.0088\n",
|
||
"cat_l2_code_801056.SI: 0.0087\n",
|
||
"cat_l2_code_801125.SI: 0.0087\n",
|
||
"deviation_mean_act_factor2: -0.0074\n",
|
||
"deviation_mean_act_factor3: 0.0073\n",
|
||
"industry_act_factor1: 0.0073\n",
|
||
"alpha_013: -0.0073\n",
|
||
"rank_act_factor2: 0.0072\n",
|
||
"cat_l2_code_801039.SI: 0.0070\n",
|
||
"cat_l2_code_801194.SI: -0.0069\n",
|
||
"cat_l2_code_801782.SI: 0.0067\n",
|
||
"industry_act_factor4_deviation: 0.0066\n",
|
||
"industry_act_factor1_deviation: -0.0064\n",
|
||
"rank_act_factor3: -0.0062\n",
|
||
"deviation_mean_alpha_013: 0.0061\n",
|
||
"cat_l2_code_801737.SI: 0.0061\n",
|
||
"deviation_mean_rank_act_factor2: -0.0058\n",
|
||
"industry_act_factor2_deviation: 0.0057\n",
|
||
"log_close: -0.0057\n",
|
||
"cat_l2_code_801972.SI: -0.0056\n",
|
||
"cat_l2_code_801735.SI: 0.0055\n",
|
||
"cat_l2_code_801044.SI: 0.0055\n",
|
||
"atr_14: -0.0052\n",
|
||
"deviation_mean_log_close: 0.0050\n",
|
||
"turnover_rate: -0.0050\n",
|
||
"cat_l2_code_801995.SI: -0.0049\n",
|
||
"cat_l2_code_801081.SI: 0.0049\n",
|
||
"399006.SZ_RSI: 0.0048\n",
|
||
"cat_l2_code_801191.SI: -0.0048\n",
|
||
"deviation_mean_act_factor5: 0.0048\n",
|
||
"cat_l2_code_801054.SI: 0.0047\n",
|
||
"deviation_mean_atr_14: 0.0047\n",
|
||
"deviation_mean_rank_act_factor3: 0.0046\n",
|
||
"cat_l2_code_801231.SI: -0.0044\n",
|
||
"deviation_mean_rsi_3: 0.0044\n",
|
||
"industry_act_factor2: -0.0044\n",
|
||
"cat_l2_code_801085.SI: 0.0043\n",
|
||
"cat_l2_code_801084.SI: 0.0042\n",
|
||
"cat_l2_code_801154.SI: -0.0041\n",
|
||
"cat_l2_code_801111.SI: 0.0041\n",
|
||
"000905.SH_MACD_hist: 0.0039\n",
|
||
"399006.SZ_MACD_hist: -0.0038\n",
|
||
"cat_l2_code_801721.SI: -0.0036\n",
|
||
"cat_l2_code_801181.SI: -0.0036\n",
|
||
"cat_l2_code_801011.SI: 0.0035\n",
|
||
"cat_l2_code_801994.SI: -0.0035\n",
|
||
"return_10: 0.0034\n",
|
||
"cat_l2_code_801076.SI: -0.0034\n",
|
||
"cat_l2_code_801017.SI: 0.0033\n",
|
||
"cat_l2_code_801016.SI: 0.0033\n",
|
||
"deviation_mean_rsi_9: 0.0033\n",
|
||
"industry_ema_60: -0.0032\n",
|
||
"cat_l2_code_801018.SI: -0.0032\n",
|
||
"cat_l2_code_801202.SI: -0.0031\n",
|
||
"cat_l2_code_801114.SI: -0.0031\n",
|
||
"cat_l2_code_801203.SI: -0.0031\n",
|
||
"std_return_5 / std_return_25: 0.0031\n",
|
||
"cat_l2_code_801736.SI: 0.0031\n",
|
||
"cat_l2_code_801993.SI: -0.0030\n",
|
||
"act_factor5: -0.0030\n",
|
||
"act_factor6: 0.0030\n",
|
||
"industry_act_factor5: -0.0029\n",
|
||
"cat_l2_code_801012.SI: -0.0028\n",
|
||
"atr_6: 0.0028\n",
|
||
"cat_l2_code_801092.SI: -0.0028\n",
|
||
"000905.SH_MACD: 0.0028\n",
|
||
"cat_l2_code_801077.SI: 0.0028\n",
|
||
"cat_l2_code_801112.SI: -0.0028\n",
|
||
"cat_l2_code_801744.SI: -0.0028\n",
|
||
"cat_l2_code_801055.SI: 0.0027\n",
|
||
"cat_l2_code_801034.SI: 0.0027\n",
|
||
"deviation_mean_atr_6: -0.0027\n",
|
||
"cat_l2_code_801183.SI: -0.0026\n",
|
||
"cat_l2_code_801769.SI: -0.0026\n",
|
||
"deviation_mean_act_factor6: -0.0025\n",
|
||
"industry_return_20: 0.0024\n",
|
||
"deviation_mean_turnover_rate: 0.0024\n",
|
||
"cat_l2_code_801952.SI: 0.0024\n",
|
||
"industry_return_20_deviation: -0.0024\n",
|
||
"000905.SH_RSI: -0.0024\n",
|
||
"cat_l2_code_801971.SI: -0.0023\n",
|
||
"deviation_mean_std_return_5 / std_return_25: -0.0023\n",
|
||
"cat_l2_code_801095.SI: 0.0023\n",
|
||
"act_factor4: 0.0023\n",
|
||
"cat_l2_code_801178.SI: -0.0023\n",
|
||
"std_return_15: 0.0023\n",
|
||
"rsi_9: -0.0023\n",
|
||
"deviation_mean_return_10: -0.0022\n",
|
||
"cat_l2_code_801712.SI: 0.0022\n",
|
||
"cat_l2_code_801731.SI: 0.0022\n",
|
||
"up: 0.0022\n",
|
||
"cat_l2_code_801083.SI: 0.0021\n",
|
||
"cat_l2_code_801204.SI: -0.0021\n",
|
||
"cat_l2_code_801723.SI: -0.0021\n",
|
||
"cat_l2_code_801992.SI: -0.0021\n",
|
||
"cat_l2_code_801155.SI: -0.0021\n",
|
||
"cat_l2_code_801179.SI: -0.0021\n",
|
||
"alpha_007: -0.0020\n",
|
||
"std_return_25: 0.0020\n",
|
||
"cat_l2_code_801742.SI: -0.0020\n",
|
||
"deviation_mean_avg_close_5: -0.0020\n",
|
||
"deviation_mean_std_return_25: -0.0019\n",
|
||
"cat_l2_code_801767.SI: -0.0019\n",
|
||
"cat_l2_code_801086.SI: 0.0019\n",
|
||
"deviation_mean_up: -0.0019\n",
|
||
"act_factor1: -0.0019\n",
|
||
"industry_act_factor3_deviation: 0.0019\n",
|
||
"std_return_5: 0.0018\n",
|
||
"deviation_mean_std_return_90 - std_return_90_2: 0.0018\n",
|
||
"000852.SH_Signal_line: 0.0018\n",
|
||
"deviation_mean_std_return_90_2: 0.0018\n",
|
||
"cat_l2_code_801218.SI: -0.0018\n",
|
||
"rank_act_factor1: -0.0017\n",
|
||
"deviation_mean_std_return_15: -0.0017\n",
|
||
"000905.SH_Signal_line: 0.0017\n",
|
||
"cat_l2_code_801783.SI: -0.0017\n",
|
||
"deviation_mean_act_factor4: -0.0016\n",
|
||
"std_return_90 - std_return_90_2: -0.0016\n",
|
||
"industry_return_5_deviation: -0.0016\n",
|
||
"cat_l2_code_801045.SI: 0.0016\n",
|
||
"cat_l2_code_801765.SI: -0.0016\n",
|
||
"industry_rank_act_factor3: -0.0016\n",
|
||
"cat_l2_code_801784.SI: -0.0015\n",
|
||
"industry_rank_act_factor2: 0.0015\n",
|
||
"cat_l2_code_801082.SI: 0.0015\n",
|
||
"cat_l2_code_801014.SI: 0.0015\n",
|
||
"cat_l2_code_801141.SI: -0.0015\n",
|
||
"cat_l2_code_801124.SI: 0.0015\n",
|
||
"cat_af2_1.0: -0.0015\n",
|
||
"cat_af2_0.0: 0.0015\n",
|
||
"cat_l2_code_801113.SI: 0.0014\n",
|
||
"cat_l2_code_801142.SI: -0.0014\n",
|
||
"cat_l2_code_801015.SI: -0.0014\n",
|
||
"cat_l2_code_801104.SI: 0.0014\n",
|
||
"399006.SZ_daily_return: 0.0014\n",
|
||
"000852.SH_MACD: 0.0013\n",
|
||
"alpha_003: -0.0013\n",
|
||
"cat_l2_code_801219.SI: -0.0013\n",
|
||
"deviation_mean_ema_60: 0.0013\n",
|
||
"industry_ema_5: 0.0013\n",
|
||
"cat_l2_code_801078.SI: 0.0013\n",
|
||
"000852.SH_MACD_hist: -0.0012\n",
|
||
"industry_rank_act_factor1: -0.0012\n",
|
||
"cat_l2_code_801101.SI: 0.0012\n",
|
||
"cat_l2_code_801033.SI: 0.0012\n",
|
||
"cat_l2_code_801152.SI: -0.0012\n",
|
||
"cat_l2_code_801128.SI: -0.0012\n",
|
||
"return_5: -0.0012\n",
|
||
"deviation_mean_std_return_90: -0.0012\n",
|
||
"down: -0.0012\n",
|
||
"deviation_mean_rank_act_factor1: 0.0012\n",
|
||
"cat_l2_code_801096.SI: -0.0011\n",
|
||
"deviation_mean_log(circ_mv): -0.0011\n",
|
||
"deviation_mean_alpha_007: 0.0011\n",
|
||
"maobv_6: -0.0011\n",
|
||
"ema_60: -0.0011\n",
|
||
"cat_l2_code_801963.SI: -0.0011\n",
|
||
"cat_l2_code_801738.SI: -0.0011\n",
|
||
"cat_l2_code_801093.SI: 0.0011\n",
|
||
"obv: -0.0011\n",
|
||
"cat_l2_code_801074.SI: 0.0011\n",
|
||
"cat_l2_code_801193.SI: 0.0010\n",
|
||
"log(circ_mv): 0.0010\n",
|
||
"cat_l2_code_801785.SI: -0.0010\n",
|
||
"cat_l2_code_801115.SI: -0.0010\n",
|
||
"cat_l2_code_801206.SI: 0.0010\n",
|
||
"cat_l2_code_801724.SI: -0.0010\n",
|
||
"deviation_mean_maobv_6: 0.0010\n",
|
||
"cat_l2_code_801991.SI: -0.0009\n",
|
||
"cat_l2_code_801161.SI: -0.0009\n",
|
||
"deviation_mean_obv: 0.0009\n",
|
||
"cat_l2_code_801143.SI: 0.0009\n",
|
||
"cat_l2_code_801053.SI: 0.0008\n",
|
||
"cat_l2_code_801156.SI: -0.0008\n",
|
||
"volume_ratio: -0.0008\n",
|
||
"cat_l2_code_801962.SI: -0.0008\n",
|
||
"std_return_90: -0.0008\n",
|
||
"cat_l2_code_801711.SI: 0.0008\n",
|
||
"buy_elg_vol_minus_sell_elg_vol: -0.0008\n",
|
||
"cat_l2_code_801038.SI: -0.0008\n",
|
||
"cat_l2_code_801072.SI: 0.0008\n",
|
||
"cat_l2_code_801037.SI: 0.0007\n",
|
||
"deviation_mean_active_buy_volume_large: -0.0007\n",
|
||
"cat_l2_code_801981.SI: -0.0007\n",
|
||
"active_buy_volume_large: 0.0007\n",
|
||
"cat_l2_code_801032.SI: -0.0007\n",
|
||
"deviation_mean_buy_elg_vol_minus_sell_elg_vol: 0.0007\n",
|
||
"399006.SZ_MACD: -0.0007\n",
|
||
"obv-maobv_6: 0.0007\n",
|
||
"cat_l2_code_801733.SI: 0.0006\n",
|
||
"cat_l2_code_801103.SI: -0.0006\n",
|
||
"deviation_mean_obv-maobv_6: -0.0006\n",
|
||
"std_return_5 / std_return_90: -0.0006\n",
|
||
"cat_l2_code_801163.SI: -0.0006\n",
|
||
"000852.SH_daily_return: -0.0006\n",
|
||
"cat_l2_code_801043.SI: 0.0006\n",
|
||
"cat_l2_code_801131.SI: -0.0005\n",
|
||
"cat_af1_1.0: -0.0005\n",
|
||
"cat_af1_0.0: 0.0005\n",
|
||
"cat_l2_code_801764.SI: -0.0005\n",
|
||
"cat_l2_code_801153.SI: -0.0005\n",
|
||
"deviation_mean_act_factor1: -0.0005\n",
|
||
"cat_l2_code_801951.SI: -0.0005\n",
|
||
"cat_l2_code_801726.SI: -0.0005\n",
|
||
"deviation_mean_std_return_5: -0.0005\n",
|
||
"cat_l2_code_801036.SI: 0.0005\n",
|
||
"industry_act_factor6: 0.0005\n",
|
||
"industry_return_5_percentile: 0.0005\n",
|
||
"cat_af4_1.0: 0.0005\n",
|
||
"cat_af4_0.0: -0.0005\n",
|
||
"cat_l2_code_801223.SI: -0.0005\n",
|
||
"cat_l2_code_801881.SI: -0.0004\n",
|
||
"399006.SZ_Signal_line: 0.0004\n",
|
||
"cat_l2_code_801743.SI: -0.0004\n",
|
||
"deviation_mean_active_buy_volume_small: 0.0004\n",
|
||
"return_20: -0.0004\n",
|
||
"000905.SH_daily_return: -0.0003\n",
|
||
"cat_l2_code_801982.SI: -0.0003\n",
|
||
"pe_ttm: -0.0003\n",
|
||
"cat_l2_code_801713.SI: 0.0003\n",
|
||
"cat_l2_code_801129.SI: 0.0003\n",
|
||
"active_buy_volume_small: -0.0003\n",
|
||
"deviation_mean_return_5: -0.0003\n",
|
||
"deviation_mean_399006.SZ_Signal_line: -0.0003\n",
|
||
"cat_l2_code_801126.SI: 0.0003\n",
|
||
"deviation_mean_pe_ttm: 0.0002\n",
|
||
"deviation_mean_alpha_003: 0.0002\n",
|
||
"cat_l2_code_801116.SI: -0.0002\n",
|
||
"cat_af3_0.0: 0.0002\n",
|
||
"cat_af3_1.0: -0.0002\n",
|
||
"cat_l2_code_801132.SI: -0.0002\n",
|
||
"deviation_mean_399006.SZ_daily_return: 0.0002\n",
|
||
"std_return_90_2: -0.0002\n",
|
||
"deviation_mean_alpha_022: 0.0002\n",
|
||
"deviation_mean_std_return_5 / std_return_90: -0.0002\n",
|
||
"deviation_mean_000852.SH_RSI: -0.0002\n",
|
||
"deviation_mean_return_20: 0.0002\n",
|
||
"deviation_mean_down: -0.0002\n",
|
||
"active_buy_volume_big: 0.0002\n",
|
||
"deviation_mean_volume_ratio: -0.0002\n",
|
||
"cat_l2_code_801133.SI: 0.0002\n",
|
||
"industry_act_factor3: 0.0002\n",
|
||
"alpha_022: -0.0001\n",
|
||
"deviation_mean_399006.SZ_MACD: -0.0001\n",
|
||
"avg_close_5: -0.0001\n",
|
||
"deviation_mean_000852.SH_MACD: 0.0001\n",
|
||
"cat_l2_code_801722.SI: -0.0001\n",
|
||
"cat_l2_code_801766.SI: 0.0001\n",
|
||
"deviation_mean_active_buy_volume_big: -0.0001\n",
|
||
"cat_l2_code_801127.SI: -0.0001\n",
|
||
"deviation_mean_399006.SZ_MACD_hist: 0.0001\n",
|
||
"deviation_mean_000905.SH_Signal_line: -0.0001\n",
|
||
"cat_l2_code_801745.SI: -0.0001\n",
|
||
"deviation_mean_000905.SH_RSI: 0.0001\n",
|
||
"deviation_mean_000905.SH_MACD_hist: -0.0001\n",
|
||
"deviation_mean_399006.SZ_RSI: -0.0001\n",
|
||
"industry_return_5: 0.0001\n",
|
||
"cat_l2_code_801102.SI: -0.0001\n",
|
||
"cat_l2_code_801741.SI: 0.0001\n",
|
||
"000852.SH_RSI: 0.0001\n",
|
||
"deviation_mean_000852.SH_daily_return: -0.0001\n",
|
||
"deviation_mean_000852.SH_Signal_line: -0.0001\n",
|
||
"deviation_mean_000852.SH_MACD_hist: -0.0001\n",
|
||
"cat_l2_code_801145.SI: -0.0000\n",
|
||
"deviation_mean_buy_lg_vol_minus_sell_lg_vol: 0.0000\n",
|
||
"cat_l2_code_801051.SI: -0.0000\n",
|
||
"cat_l2_code_801151.SI: -0.0000\n",
|
||
"deviation_mean_000905.SH_daily_return: -0.0000\n",
|
||
"deviation_mean_000905.SH_MACD: -0.0000\n",
|
||
"buy_lg_vol_minus_sell_lg_vol: 0.0000\n"
|
||
]
|
||
}
|
||
],
|
||
"execution_count": 80
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"id": "465944b1d463e4b1",
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2025-03-02T13:23:57.174623Z",
|
||
"start_time": "2025-03-02T13:23:57.148826Z"
|
||
}
|
||
},
|
||
"source": [
|
||
"from tqdm import tqdm\n",
|
||
"\n",
|
||
"\n",
|
||
"def incremental_training(test_data: pd.DataFrame,\n",
|
||
" model,\n",
|
||
" scaler,\n",
|
||
" encoder,\n",
|
||
" days: int,\n",
|
||
" back_days: int,\n",
|
||
" feature_columns: list,\n",
|
||
" params: dict\n",
|
||
" ):\n",
|
||
" test_data = test_data.sort_values(by='trade_date')\n",
|
||
" scores = []\n",
|
||
" unique_trade_dates = sorted(test_data['trade_date'].unique())\n",
|
||
"\n",
|
||
" new_model = None\n",
|
||
" for i in tqdm(range(0, len(unique_trade_dates))):\n",
|
||
" # Get the current window of trade dates\n",
|
||
" current_dates = [unique_trade_dates[i]]\n",
|
||
" window_data = test_data[test_data['trade_date'].isin(current_dates)]\n",
|
||
" X = window_data[feature_columns]\n",
|
||
" numeric_columns = X.select_dtypes(include=['float64', 'int64']).columns\n",
|
||
" categorical_columns = [col for col in feature_columns if col.startswith('cat')]\n",
|
||
" X.loc[:, numeric_columns] = scaler.transform(X[numeric_columns])\n",
|
||
" X_categorical = encoder.transform(X[categorical_columns]).toarray()\n",
|
||
"\n",
|
||
" # Combine numeric and categorical features\n",
|
||
" X_processed = pd.concat([\n",
|
||
" pd.DataFrame(X[numeric_columns], columns=numeric_columns, index=X.index),\n",
|
||
" pd.DataFrame(X_categorical, columns=encoder.get_feature_names_out(categorical_columns), index=X.index)\n",
|
||
" ], axis=1)\n",
|
||
" X_processed = X_processed.fillna(0)\n",
|
||
"\n",
|
||
" if new_model is not None:\n",
|
||
" window_scores = new_model.predict(X_processed)\n",
|
||
" else:\n",
|
||
" window_scores = model.predict(X_processed)\n",
|
||
" scores.extend(window_scores)\n",
|
||
"\n",
|
||
" # # Prepare data for incremental training\n",
|
||
" # current_dates = unique_trade_dates[max(0, i - back_days):i + days]\n",
|
||
" # window_data = test_data[test_data['trade_date'].isin(current_dates)]\n",
|
||
" # X_train = window_data[feature_columns]\n",
|
||
" current_dates = unique_trade_dates[max(0, i - days - back_days):i + 1 - back_days]\n",
|
||
" window_data = test_data[test_data['trade_date'].isin(current_dates)]\n",
|
||
" window_data['label'] = remove_outliers_label_percentile(window_data['label'], log=False)\n",
|
||
" window_data = window_data.dropna(subset=feature_columns)\n",
|
||
" window_data = window_data.dropna(subset=['label'])\n",
|
||
" X_train = window_data[feature_columns]\n",
|
||
" y_train = window_data['label']\n",
|
||
" # Incrementally train the model\n",
|
||
" if len(y_train.unique()) > 1:\n",
|
||
" X_train.loc[:, numeric_columns] = scaler.transform(X_train[numeric_columns])\n",
|
||
" X_train_categorical = encoder.transform(X_train[categorical_columns]).toarray()\n",
|
||
" X_train_processed = pd.concat([\n",
|
||
" pd.DataFrame(X_train[numeric_columns], columns=numeric_columns, index=X_train.index),\n",
|
||
" pd.DataFrame(X_train_categorical, columns=encoder.get_feature_names_out(categorical_columns),\n",
|
||
" index=X_train.index)\n",
|
||
" ], axis=1)\n",
|
||
" X_train_processed = X_train_processed.fillna(0)\n",
|
||
" model = model.partial_fit(X_train_processed, y_train)\n",
|
||
" else:\n",
|
||
" print(current_dates)\n",
|
||
"\n",
|
||
" # Add the scores as a new 'score' column to the test_data\n",
|
||
" test_data['score'] = scores\n",
|
||
" return test_data"
|
||
],
|
||
"outputs": [],
|
||
"execution_count": 81
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"id": "e3ac761d8f0b5d31",
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2025-03-02T13:25:27.698736Z",
|
||
"start_time": "2025-03-02T13:25:25.801431Z"
|
||
}
|
||
},
|
||
"source": [
|
||
"import gc\n",
|
||
"\n",
|
||
"gc.collect()\n",
|
||
"\n",
|
||
"# predictions_test = incremental_training(test_data, model, scaler, encoder, 10, days, feature_columns_new, params)\n",
|
||
"X_test = test_data[feature_columns_new]\n",
|
||
"numeric_columns = X_test.select_dtypes(include=['float64', 'int64']).columns\n",
|
||
"categorical_columns = [col for col in feature_columns if col.startswith('cat')]\n",
|
||
"\n",
|
||
"X_test.loc[:, numeric_columns] = scaler.transform(X_test[numeric_columns])\n",
|
||
"X_test_categorical = encoder.transform(X_test[categorical_columns]).toarray()\n",
|
||
"\n",
|
||
"# Combine numeric and categorical features\n",
|
||
"X_test_processed = pd.concat([\n",
|
||
" pd.DataFrame(X_test[numeric_columns], columns=numeric_columns, index=X_test.index),\n",
|
||
" pd.DataFrame(X_test_categorical, columns=encoder.get_feature_names_out(categorical_columns), index=X_test.index)\n",
|
||
"], axis=1)\n",
|
||
"predictions_test = test_data[['ts_code', 'trade_date']]\n",
|
||
"predictions_test['score'] = model.predict(X_test_processed)\n",
|
||
"predictions_test = predictions_test.loc[predictions_test.groupby('trade_date')['score'].idxmax()]\n",
|
||
"predictions_test[['trade_date', 'score', 'ts_code']].to_csv('predictions_test.tsv', index=False)\n"
|
||
],
|
||
"outputs": [],
|
||
"execution_count": 84
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"id": "b427ce41-9739-4e9e-bea8-5f2551fec5d7",
|
||
"metadata": {
|
||
"jupyter": {
|
||
"source_hidden": true
|
||
},
|
||
"ExecuteTime": {
|
||
"end_time": "2025-03-02T13:24:01.404626100Z",
|
||
"start_time": "2025-03-02T13:17:17.216317Z"
|
||
}
|
||
},
|
||
"source": [
|
||
"import joblib\n",
|
||
"import pandas as pd\n",
|
||
"\n",
|
||
"\n",
|
||
"# 假设你已经训练好了一个 LightGBM 模型\n",
|
||
"# model = lgb.train(params, train_data, ...)\n",
|
||
"\n",
|
||
"def save_model_with_info(model, params, feature_columns, train_data, info, save_path):\n",
|
||
" \"\"\"\n",
|
||
" 保存 LightGBM 模型及其相关信息。\n",
|
||
" \n",
|
||
" 参数:\n",
|
||
" model: 训练好的 LightGBM 模型 (lgb.Booster)。\n",
|
||
" params: 模型的参数 (dict)。\n",
|
||
" feature_columns: 特征列名列表 (list)。\n",
|
||
" train_data: 训练数据 (pd.DataFrame),包含 'trade_date' 列。\n",
|
||
" info: 额外信息 (str 或 dict)。\n",
|
||
" save_path: 保存路径 (str)。\n",
|
||
" \"\"\"\n",
|
||
" # 提取训练数据的 trade_date 的最大值和最小值\n",
|
||
" if 'trade_date' not in train_data.columns:\n",
|
||
" raise ValueError(\"训练数据中必须包含 'trade_date' 列。\")\n",
|
||
"\n",
|
||
" trade_date_min = train_data['trade_date'].min()\n",
|
||
" trade_date_max = train_data['trade_date'].max()\n",
|
||
"\n",
|
||
" # 构建保存的信息字典\n",
|
||
" model_info = {\n",
|
||
" 'model': model, # 保存模型本身\n",
|
||
" 'params': params, # 模型参数\n",
|
||
" 'feature_columns': feature_columns, # 特征列名\n",
|
||
" 'trade_date_range': {\n",
|
||
" 'min': trade_date_min,\n",
|
||
" 'max': trade_date_max\n",
|
||
" }, # trade_date 的范围\n",
|
||
" 'info': info # 额外信息\n",
|
||
" }\n",
|
||
"\n",
|
||
" # 使用 joblib 保存模型及相关信息\n",
|
||
" joblib.dump(model_info, save_path)\n",
|
||
" print(f\"模型及相关信息已成功保存到 {save_path}\")\n",
|
||
"\n",
|
||
"# info = \"Update Regression + 滚动new model + 5days\"\n",
|
||
"\n",
|
||
"# # 保存模型及相关信息\n",
|
||
"# save_path = \"../model/lightgbm_model_UpdateRegression_2025-2-25.pkl\"\n",
|
||
"# save_model_with_info(model, light_params, feature_columns, train_data, info, save_path)"
|
||
],
|
||
"outputs": [],
|
||
"execution_count": 73
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"id": "8f9a2b7b-11fe-4eb5-aa11-c4066fe418a1",
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2025-03-02T13:24:01.407058800Z",
|
||
"start_time": "2025-03-02T13:17:17.289631Z"
|
||
}
|
||
},
|
||
"source": [],
|
||
"outputs": [],
|
||
"execution_count": null
|
||
}
|
||
],
|
||
"metadata": {
|
||
"kernelspec": {
|
||
"display_name": "Python 3 (ipykernel)",
|
||
"language": "python",
|
||
"name": "python3"
|
||
},
|
||
"language_info": {
|
||
"codemirror_mode": {
|
||
"name": "ipython",
|
||
"version": 3
|
||
},
|
||
"file_extension": ".py",
|
||
"mimetype": "text/x-python",
|
||
"name": "python",
|
||
"nbconvert_exporter": "python",
|
||
"pygments_lexer": "ipython3",
|
||
"version": "3.8.19"
|
||
}
|
||
},
|
||
"nbformat": 4,
|
||
"nbformat_minor": 5
|
||
}
|