{
"cells": [
{
"cell_type": "code",
"id": "79a7758178bafdd3",
"metadata": {
"ExecuteTime": {
"end_time": "2025-03-01T04:00:54.782179Z",
"start_time": "2025-03-01T04:00:54.471051Z"
}
},
"source": [
"# Enable live-reload of edited local modules during development.\n",
"%load_ext autoreload\n",
"# %autoreload 2\n",
"\n",
"import pandas as pd\n",
" \n",
"\n",
" \n",
"# Show every column when displaying the wide factor tables below.\n",
"pd.set_option('display.max_columns', None)\n"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The autoreload module is not an IPython extension.\n"
]
}
],
"execution_count": 1
},
{
"cell_type": "code",
"id": "a79cafb06a7e0e43",
"metadata": {
"jupyter": {
"source_hidden": true
},
"ExecuteTime": {
"end_time": "2025-03-01T04:01:45.158165Z",
"start_time": "2025-03-01T04:00:54.784691Z"
}
},
"source": [
"\n",
"from code.utils.utils import read_and_merge_h5_data\n",
"\n",
"print('daily data')\n",
"df = read_and_merge_h5_data('../../data/daily_data.h5', key='daily_data',\n",
" columns=['ts_code', 'trade_date', 'open', 'close', 'high', 'low', 'vol'],\n",
" df=None)\n",
"\n",
"print('daily basic')\n",
"df = read_and_merge_h5_data('../../data/daily_basic.h5', key='daily_basic',\n",
" columns=['ts_code', 'trade_date', 'turnover_rate', 'pe_ttm', 'circ_mv', 'volume_ratio',\n",
" 'is_st'], df=df, join='inner')\n",
"\n",
"print('stk limit')\n",
"df = read_and_merge_h5_data('../../data/stk_limit.h5', key='stk_limit',\n",
" columns=['ts_code', 'trade_date', 'pre_close', 'up_limit', 'down_limit'],\n",
" df=df)\n",
"print('money flow')\n",
"df = read_and_merge_h5_data('../../data/money_flow.h5', key='money_flow',\n",
" columns=['ts_code', 'trade_date', 'buy_sm_vol', 'sell_sm_vol', 'buy_lg_vol', 'sell_lg_vol',\n",
" 'buy_elg_vol', 'sell_elg_vol', 'net_mf_vol'],\n",
" df=df)\n",
"print(df.info())"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"daily data\n",
"daily basic\n",
"inner merge on ['ts_code', 'trade_date']\n",
"stk limit\n",
"left merge on ['ts_code', 'trade_date']\n",
"money flow\n",
"left merge on ['ts_code', 'trade_date']\n",
"<class 'pandas.core.frame.DataFrame'>\n",
"RangeIndex: 8350183 entries, 0 to 8350182\n",
"Data columns (total 21 columns):\n",
" # Column Dtype \n",
"--- ------ ----- \n",
" 0 ts_code object \n",
" 1 trade_date datetime64[ns]\n",
" 2 open float64 \n",
" 3 close float64 \n",
" 4 high float64 \n",
" 5 low float64 \n",
" 6 vol float64 \n",
" 7 turnover_rate float64 \n",
" 8 pe_ttm float64 \n",
" 9 circ_mv float64 \n",
" 10 volume_ratio float64 \n",
" 11 is_st bool \n",
" 12 up_limit float64 \n",
" 13 down_limit float64 \n",
" 14 buy_sm_vol float64 \n",
" 15 sell_sm_vol float64 \n",
" 16 buy_lg_vol float64 \n",
" 17 sell_lg_vol float64 \n",
" 18 buy_elg_vol float64 \n",
" 19 sell_elg_vol float64 \n",
" 20 net_mf_vol float64 \n",
"dtypes: bool(1), datetime64[ns](1), float64(18), object(1)\n",
"memory usage: 1.3+ GB\n",
"None\n"
]
}
],
"execution_count": 2
},
{
"cell_type": "code",
"id": "f7a55c19-b7dc-4d2f-a478-cffab11690df",
"metadata": {
"ExecuteTime": {
"end_time": "2025-03-01T04:01:49.907736Z",
"start_time": "2025-03-01T04:01:45.533679Z"
}
},
"source": [
"# Attach the level-2 industry code per stock (static mapping keyed on ts_code only).\n",
"print('industry')\n",
"df = read_and_merge_h5_data('../../data/industry_data.h5', key='industry_data',\n",
" columns=['ts_code', 'l2_code'],\n",
" df=df, on=['ts_code'], join='left')\n"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"industry\n",
"left merge on ['ts_code']\n"
]
}
],
"execution_count": 3
},
{
"cell_type": "code",
"id": "4077d4449d406c86",
"metadata": {
"ExecuteTime": {
"end_time": "2025-03-01T04:01:50.070194Z",
"start_time": "2025-03-01T04:01:49.936309Z"
}
},
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"\n",
"def calculate_indicators(df):\n",
" \"\"\"\n",
" Compute indicators for one index: daily return (%), 14-day RSI and MACD\n",
" (12/26 EMA difference, 9-EMA signal line, histogram). Mutates and returns\n",
" ``df``; assumes rows are already sorted by trade_date.\n",
" \"\"\"\n",
" df['daily_return'] = (df['close'] - df['pre_close']) / df['pre_close'] * 100\n",
" # df['5_day_ma'] = df['close'].rolling(window=5).mean()\n",
" # RSI from a 14-day simple rolling mean of gains/losses (not Wilder smoothing).\n",
" delta = df['close'].diff()\n",
" gain = delta.where(delta > 0, 0)\n",
" loss = -delta.where(delta < 0, 0)\n",
" avg_gain = gain.rolling(window=14).mean()\n",
" avg_loss = loss.rolling(window=14).mean()\n",
" rs = avg_gain / avg_loss\n",
" df['RSI'] = 100 - (100 / (1 + rs))\n",
"\n",
" # MACD: fast EMA(12) minus slow EMA(26); signal = EMA(9) of MACD.\n",
" ema12 = df['close'].ewm(span=12, adjust=False).mean()\n",
" ema26 = df['close'].ewm(span=26, adjust=False).mean()\n",
" df['MACD'] = ema12 - ema26\n",
" df['Signal_line'] = df['MACD'].ewm(span=9, adjust=False).mean()\n",
" df['MACD_hist'] = df['MACD'] - df['Signal_line']\n",
"\n",
" return df\n",
"\n",
"def generate_index_indicators(h5_filename):\n",
" \"\"\"\n",
" Read index bars from ``h5_filename`` (HDF key 'index_data'), compute the\n",
" indicators per index code, then pivot to one row per trade_date with\n",
" columns named '<ts_code>_<indicator>'.\n",
" \"\"\"\n",
" df = pd.read_hdf(h5_filename, key='index_data')\n",
" df['trade_date'] = pd.to_datetime(df['trade_date'], format='%Y%m%d')\n",
" df = df.sort_values('trade_date')\n",
"\n",
" # Compute indicators independently for each index code.\n",
" df_indicators = []\n",
" for ts_code in df['ts_code'].unique():\n",
" df_index = df[df['ts_code'] == ts_code].copy()\n",
" df_index = calculate_indicators(df_index)\n",
" df_indicators.append(df_index)\n",
"\n",
" # Stack the per-index results back together.\n",
" df_all_indicators = pd.concat(df_indicators, ignore_index=True)\n",
"\n",
" # One row per trade_date; one column per (indicator, index code) pair.\n",
" df_final = df_all_indicators.pivot_table(\n",
" index='trade_date',\n",
" columns='ts_code',\n",
" values=['daily_return', 'RSI', 'MACD', 'Signal_line', 'MACD_hist'],\n",
" aggfunc='last'\n",
" )\n",
"\n",
" # Flatten the MultiIndex columns to '<ts_code>_<indicator>'.\n",
" df_final.columns = [f\"{col[1]}_{col[0]}\" for col in df_final.columns]\n",
" df_final = df_final.reset_index()\n",
"\n",
" return df_final\n",
"\n",
"\n",
"# Build market-index features; drop warm-up rows whose indicators are NaN.\n",
"h5_filename = '../../data/index_data.h5'\n",
"index_data = generate_index_indicators(h5_filename)\n",
"index_data = index_data.dropna()\n"
],
"outputs": [],
"execution_count": 4
},
{
"cell_type": "code",
"id": "c4e9e1d31da6dba6",
"metadata": {
"jupyter": {
"source_hidden": true
},
"ExecuteTime": {
"end_time": "2025-03-01T04:12:27.210806Z",
"start_time": "2025-03-01T04:12:26.858559Z"
}
},
"source": [
"import numpy as np\n",
"import talib\n",
"\n",
"def get_technical_factor(df):\n",
" \"\"\"\n",
" Add per-stock technical factors: candle shadows (up/down), ATR, OBV,\n",
" RSI, multi-horizon returns and return-volatility ratios. Expects\n",
" columns ts_code, trade_date, open/high/low/close, vol.\n",
" \"\"\"\n",
" # Sort by stock then date so groupwise rolling windows are chronological.\n",
" df = df.sort_values(by=['ts_code', 'trade_date'])\n",
" grouped = df.groupby('ts_code', group_keys=False)\n",
"\n",
" # Upper/lower shadow lengths relative to close.\n",
" df['log_close'] = np.log(df['close'])\n",
"\n",
" df['up'] = (df['high'] - df[['close', 'open']].max(axis=1)) / df['close']\n",
" df['down'] = (df[['close', 'open']].min(axis=1) - df['low']) / df['close']\n",
"\n",
" # Average True Range (14- and 6-day), computed per stock.\n",
" df['atr_14'] = grouped.apply(\n",
" lambda x: pd.Series(talib.ATR(x['high'].values, x['low'].values, x['close'].values, timeperiod=14), index=x.index)\n",
" )\n",
" df['atr_6'] = grouped.apply(\n",
" lambda x: pd.Series(talib.ATR(x['high'].values, x['low'].values, x['close'].values, timeperiod=6), index=x.index)\n",
" )\n",
"\n",
" # On-Balance Volume and its 6-day moving average.\n",
" df['obv'] = grouped.apply(\n",
" lambda x: pd.Series(talib.OBV(x['close'].values, x['vol'].values), index=x.index)\n",
" )\n",
" df['maobv_6'] = grouped.apply(\n",
" lambda x: pd.Series(talib.SMA(x['obv'].values, timeperiod=6), index=x.index)\n",
" )\n",
" df['obv-maobv_6'] = df['obv'] - df['maobv_6']\n",
"\n",
" # RSI at 3/6/9-day horizons.\n",
" df['rsi_3'] = grouped.apply(\n",
" lambda x: pd.Series(talib.RSI(x['close'].values, timeperiod=3), index=x.index)\n",
" )\n",
" df['rsi_6'] = grouped.apply(\n",
" lambda x: pd.Series(talib.RSI(x['close'].values, timeperiod=6), index=x.index)\n",
" )\n",
" df['rsi_9'] = grouped.apply(\n",
" lambda x: pd.Series(talib.RSI(x['close'].values, timeperiod=9), index=x.index)\n",
" )\n",
"\n",
" # Trailing 5/10/20-day returns.\n",
" df['return_5'] = grouped['close'].apply(lambda x: x / x.shift(5) - 1)\n",
" df['return_10'] = grouped['close'].apply(lambda x: x / x.shift(10) - 1)\n",
" df['return_20'] = grouped['close'].apply(lambda x: x / x.shift(20) - 1)\n",
"\n",
" # 5-day moving average relative to the current close.\n",
" df['avg_close_5'] = grouped['close'].apply(lambda x: x.rolling(window=5).mean() / x)\n",
"\n",
" # Rolling daily-return volatility at several horizons.\n",
" df['std_return_5'] = grouped['close'].apply(lambda x: x.pct_change().rolling(window=5).std())\n",
" df['std_return_15'] = grouped['close'].apply(lambda x: x.pct_change().rolling(window=15).std())\n",
" df['std_return_25'] = grouped['close'].apply(lambda x: x.pct_change().rolling(window=25).std())\n",
" df['std_return_90'] = grouped['close'].apply(lambda x: x.pct_change().rolling(window=90).std())\n",
" df['std_return_90_2'] = grouped['close'].apply(lambda x: x.shift(10).pct_change().rolling(window=90).std())\n",
"\n",
" # Short-to-long volatility ratios.\n",
" df['std_return_5 / std_return_90'] = df['std_return_5'] / df['std_return_90']\n",
" df['std_return_5 / std_return_25'] = df['std_return_5'] / df['std_return_25']\n",
"\n",
" # Volatility change versus the 10-day-lagged window.\n",
" df['std_return_90 - std_return_90_2'] = df['std_return_90'] - df['std_return_90_2']\n",
"\n",
" return df\n",
"\n",
"\n",
"def get_act_factor(df, cat=True):\n",
" \"\"\"\n",
" Add EMA-slope ('act') factors: arctan of each EMA's one-day percentage\n",
" change converted to degrees (* 57.3) and rescaled per horizon. With\n",
" cat=True also adds boolean trend flags and per-date percentile ranks.\n",
" \"\"\"\n",
" # Sort by stock then date so the EMAs run chronologically per stock.\n",
" df = df.sort_values(by=['ts_code', 'trade_date'])\n",
" grouped = df.groupby('ts_code', group_keys=False)\n",
" # EMAs at 5/13/20/60-day horizons.\n",
" df['ema_5'] = grouped['close'].apply(\n",
" lambda x: pd.Series(talib.EMA(x.values, timeperiod=5), index=x.index)\n",
" )\n",
" df['ema_13'] = grouped['close'].apply(\n",
" lambda x: pd.Series(talib.EMA(x.values, timeperiod=13), index=x.index)\n",
" )\n",
" df['ema_20'] = grouped['close'].apply(\n",
" lambda x: pd.Series(talib.EMA(x.values, timeperiod=20), index=x.index)\n",
" )\n",
" df['ema_60'] = grouped['close'].apply(\n",
" lambda x: pd.Series(talib.EMA(x.values, timeperiod=60), index=x.index)\n",
" )\n",
"\n",
" # Slope factors per EMA horizon (divisors 50/40/21/10 are scaling constants).\n",
" df['act_factor1'] = grouped['ema_5'].apply(\n",
" lambda x: np.arctan((x / x.shift(1) - 1) * 100) * 57.3 / 50\n",
" )\n",
" df['act_factor2'] = grouped['ema_13'].apply(\n",
" lambda x: np.arctan((x / x.shift(1) - 1) * 100) * 57.3 / 40\n",
" )\n",
" df['act_factor3'] = grouped['ema_20'].apply(\n",
" lambda x: np.arctan((x / x.shift(1) - 1) * 100) * 57.3 / 21\n",
" )\n",
" df['act_factor4'] = grouped['ema_60'].apply(\n",
" lambda x: np.arctan((x / x.shift(1) - 1) * 100) * 57.3 / 10\n",
" )\n",
"\n",
" if cat:\n",
" # Boolean trend-alignment flags across horizons.\n",
" df['cat_af1'] = df['act_factor1'] > 0\n",
" df['cat_af2'] = df['act_factor2'] > df['act_factor1']\n",
" df['cat_af3'] = df['act_factor3'] > df['act_factor2']\n",
" df['cat_af4'] = df['act_factor4'] > df['act_factor3']\n",
"\n",
" # Aggregate slope and normalized short-vs-medium slope difference.\n",
" df['act_factor5'] = df['act_factor1'] + df['act_factor2'] + df['act_factor3'] + df['act_factor4']\n",
" df['act_factor6'] = (df['act_factor1'] - df['act_factor2']) / np.sqrt(df['act_factor1']**2 + df['act_factor2']**2)\n",
"\n",
" # Cross-sectional percentile ranks within each trade_date.\n",
" df['rank_act_factor1'] = df.groupby('trade_date', group_keys=False)['act_factor1'].rank(ascending=False, pct=True)\n",
" df['rank_act_factor2'] = df.groupby('trade_date', group_keys=False)['act_factor2'].rank(ascending=False, pct=True)\n",
" df['rank_act_factor3'] = df.groupby('trade_date', group_keys=False)['act_factor3'].rank(ascending=False, pct=True)\n",
"\n",
" return df\n",
"\n",
"\n",
"def get_money_flow_factor(df):\n",
" \"\"\"Add money-flow ratios (field names per tushare moneyflow docs) and log market cap.\"\"\"\n",
" df['active_buy_volume_large'] = df['buy_lg_vol'] / df['net_mf_vol']\n",
" df['active_buy_volume_big'] = df['buy_elg_vol'] / df['net_mf_vol']\n",
" df['active_buy_volume_small'] = df['buy_sm_vol'] / df['net_mf_vol']\n",
"\n",
" df['buy_lg_vol_minus_sell_lg_vol'] = (df['buy_lg_vol'] - df['sell_lg_vol']) / df['net_mf_vol']\n",
" df['buy_elg_vol_minus_sell_elg_vol'] = (df['buy_elg_vol'] - df['sell_elg_vol']) / df['net_mf_vol']\n",
"\n",
" df['log(circ_mv)'] = np.log(df['circ_mv'])\n",
" return df\n",
"\n",
"\n",
"def get_alpha_factor(df):\n",
" \"\"\"Add several short-horizon 'alpha' factors (022, 003, 007, 013).\"\"\"\n",
" df = df.sort_values(by=['ts_code', 'trade_date'])\n",
" grouped = df.groupby('ts_code')\n",
"\n",
" # alpha_022: close minus close 5 days ago.\n",
" df['alpha_022'] = grouped['close'].transform(lambda x: x - x.shift(5))\n",
"\n",
" # alpha_003: candle body relative to the day's range; 0 when high == low.\n",
" df['alpha_003'] = np.where(df['high'] != df['low'],\n",
" (df['close'] - df['open']) / (df['high'] - df['low']),\n",
" 0)\n",
"\n",
" # alpha_007: 5-day close/volume correlation, ranked within each trade_date.\n",
" df['alpha_007'] = grouped.apply(lambda x: x['close'].rolling(5).corr(x['vol'])).reset_index(level=0, drop=True)\n",
" df['alpha_007'] = df.groupby('trade_date', group_keys=False)['alpha_007'].rank(ascending=True, pct=True)\n",
"\n",
" # alpha_013: 5-day close sum minus 20-day close sum, ranked within each trade_date.\n",
" df['alpha_013'] = grouped['close'].transform(lambda x: x.rolling(5).sum() - x.rolling(20).sum())\n",
" df['alpha_013'] = df.groupby('trade_date', group_keys=False)['alpha_013'].rank(ascending=True, pct=True)\n",
"\n",
" return df\n",
"\n"
],
"outputs": [],
"execution_count": 2
},
{
"cell_type": "code",
"execution_count": 6,
"id": "a735bc02ceb4d872",
"metadata": {
"ExecuteTime": {
"end_time": "2025-03-01T04:00:27.633429700Z",
"start_time": "2025-02-27T16:36:47.161749Z"
}
},
"outputs": [],
"source": [
"def read_industry_data(h5_filename):\n",
" \"\"\"\n",
" Load industry daily bars (HDF key 'sw_daily') and build per-industry\n",
" factors (OBV, 5/20-day returns, EMA 'act' factors) plus each factor's\n",
" deviation from the cross-sectional median per trade_date. Price columns\n",
" are dropped, factor columns get an 'industry_' prefix, and ts_code is\n",
" renamed to 'cat_l2_code' so the frame merges onto the stock panel.\n",
" \"\"\"\n",
" industry_data = pd.read_hdf(h5_filename, key='sw_daily', columns=[\n",
" 'ts_code', 'trade_date', 'open', 'close', 'high', 'low', 'pe', 'pb', 'vol'\n",
" ])\n",
" industry_data = industry_data.sort_values(by=['ts_code', 'trade_date'])\n",
" # BUG FIX: reindex() with no arguments is a no-op; reset the row index so\n",
" # it is contiguous after sorting.\n",
" industry_data = industry_data.reset_index(drop=True)\n",
" industry_data['trade_date'] = pd.to_datetime(industry_data['trade_date'], format='%Y%m%d')\n",
"\n",
" grouped = industry_data.groupby('ts_code', group_keys=False)\n",
" industry_data['obv'] = grouped.apply(\n",
" lambda x: pd.Series(talib.OBV(x['close'].values, x['vol'].values), index=x.index)\n",
" )\n",
" industry_data['return_5'] = grouped['close'].apply(lambda x: x / x.shift(5) - 1)\n",
" industry_data['return_20'] = grouped['close'].apply(lambda x: x / x.shift(20) - 1)\n",
"\n",
" industry_data = get_act_factor(industry_data, cat=False)\n",
"\n",
" # Deviation of each industry's factor from the median across all\n",
" # industries on the same trade_date.\n",
" factor_columns = ['obv', 'return_5', 'return_20', 'act_factor1', 'act_factor2', 'act_factor3', 'act_factor4']\n",
"\n",
" for factor in factor_columns:\n",
" if factor in industry_data.columns:\n",
" industry_data[f'{factor}_deviation'] = industry_data.groupby('trade_date')[factor].transform(lambda x: x - x.median())\n",
"\n",
" industry_data['return_5_percentile'] = industry_data.groupby('trade_date')['return_5'].transform(lambda x: x.rank(pct=True))\n",
" industry_data = industry_data.drop(columns=['open', 'close', 'high', 'low', 'pe', 'pb', 'vol'])\n",
"\n",
" # Prefix every factor column so names cannot collide with stock-level factors.\n",
" industry_data = industry_data.rename(\n",
" columns={col: f'industry_{col}' for col in industry_data.columns if col not in ['ts_code', 'trade_date']})\n",
"\n",
" industry_data = industry_data.rename(columns={'ts_code': 'cat_l2_code'})\n",
" return industry_data\n",
"\n",
"industry_df = read_industry_data('../../data/sw_daily.h5')\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "53f86ddc0677a6d7",
"metadata": {
"ExecuteTime": {
"end_time": "2025-03-01T04:00:27.633429700Z",
"start_time": "2025-02-27T16:36:56.276375Z"
},
"scrolled": true
},
"outputs": [],
"source": [
"# Raw input columns to exclude from the model feature set later; keep\n",
"# turnover_rate/pe_ttm/volume_ratio/l2_code and index-derived columns as features.\n",
"origin_columns = df.columns.tolist()\n",
"origin_columns = [col for col in origin_columns if col not in ['turnover_rate', 'pe_ttm', 'volume_ratio', 'l2_code']]\n",
"origin_columns = [col for col in origin_columns if col not in index_data.columns]\n"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "dbe2fd8021b9417f",
"metadata": {
"ExecuteTime": {
"end_time": "2025-03-01T04:00:27.633429700Z",
"start_time": "2025-02-27T16:36:56.417914Z"
},
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<class 'pandas.core.frame.DataFrame'>\n",
"Index: 5567976 entries, 1962 to 5567975\n",
"Data columns (total 72 columns):\n",
" # Column Dtype \n",
"--- ------ ----- \n",
" 0 ts_code object \n",
" 1 trade_date datetime64[ns]\n",
" 2 open float64 \n",
" 3 close float64 \n",
" 4 high float64 \n",
" 5 low float64 \n",
" 6 vol float64 \n",
" 7 turnover_rate float64 \n",
" 8 pe_ttm float64 \n",
" 9 circ_mv float64 \n",
" 10 volume_ratio float64 \n",
" 11 is_st bool \n",
" 12 up_limit float64 \n",
" 13 down_limit float64 \n",
" 14 buy_sm_vol float64 \n",
" 15 sell_sm_vol float64 \n",
" 16 buy_lg_vol float64 \n",
" 17 sell_lg_vol float64 \n",
" 18 buy_elg_vol float64 \n",
" 19 sell_elg_vol float64 \n",
" 20 net_mf_vol float64 \n",
" 21 cat_l2_code object \n",
" 22 log_close float64 \n",
" 23 up float64 \n",
" 24 down float64 \n",
" 25 atr_14 float64 \n",
" 26 atr_6 float64 \n",
" 27 obv float64 \n",
" 28 maobv_6 float64 \n",
" 29 obv-maobv_6 float64 \n",
" 30 rsi_3 float64 \n",
" 31 rsi_6 float64 \n",
" 32 rsi_9 float64 \n",
" 33 return_5 float64 \n",
" 34 return_10 float64 \n",
" 35 return_20 float64 \n",
" 36 avg_close_5 float64 \n",
" 37 std_return_5 float64 \n",
" 38 std_return_15 float64 \n",
" 39 std_return_25 float64 \n",
" 40 std_return_90 float64 \n",
" 41 std_return_90_2 float64 \n",
" 42 std_return_5 / std_return_90 float64 \n",
" 43 std_return_5 / std_return_25 float64 \n",
" 44 std_return_90 - std_return_90_2 float64 \n",
" 45 ema_5 float64 \n",
" 46 ema_13 float64 \n",
" 47 ema_20 float64 \n",
" 48 ema_60 float64 \n",
" 49 act_factor1 float64 \n",
" 50 act_factor2 float64 \n",
" 51 act_factor3 float64 \n",
" 52 act_factor4 float64 \n",
" 53 cat_af1 bool \n",
" 54 cat_af2 bool \n",
" 55 cat_af3 bool \n",
" 56 cat_af4 bool \n",
" 57 act_factor5 float64 \n",
" 58 act_factor6 float64 \n",
" 59 rank_act_factor1 float64 \n",
" 60 rank_act_factor2 float64 \n",
" 61 rank_act_factor3 float64 \n",
" 62 active_buy_volume_large float64 \n",
" 63 active_buy_volume_big float64 \n",
" 64 active_buy_volume_small float64 \n",
" 65 buy_lg_vol_minus_sell_lg_vol float64 \n",
" 66 buy_elg_vol_minus_sell_elg_vol float64 \n",
" 67 log(circ_mv) float64 \n",
" 68 alpha_022 float64 \n",
" 69 alpha_003 float64 \n",
" 70 alpha_007 float64 \n",
" 71 alpha_013 float64 \n",
"dtypes: bool(5), datetime64[ns](1), float64(64), object(2)\n",
"memory usage: 2.8+ GB\n",
"None\n"
]
}
],
"source": [
"def filter_data(df):\n",
" \"\"\"\n",
" Restrict the universe: drop ST stocks, Beijing-exchange tickers ('.BJ'\n",
" suffix) and codes starting with '30', '68' or '8'.\n",
" \"\"\"\n",
" # df = df.groupby('trade_date').apply(lambda x: x.nlargest(1000, 'act_factor1'))\n",
" df = df[~df['is_st']]\n",
" df = df[~df['ts_code'].str.endswith('BJ')]\n",
" df = df[~df['ts_code'].str.startswith('30')]\n",
" df = df[~df['ts_code'].str.startswith('68')]\n",
" df = df[~df['ts_code'].str.startswith('8')]\n",
" df = df.reset_index(drop=True)\n",
" return df\n",
"\n",
"\n",
"# Build the full factor panel on the filtered universe.\n",
"df = filter_data(df)\n",
"df = get_technical_factor(df)\n",
"df = get_act_factor(df)\n",
"df = get_money_flow_factor(df)\n",
"df = get_alpha_factor(df)\n",
"# df = df.merge(industry_df, on=['l2_code', 'trade_date'], how='left')\n",
"df = df.rename(columns={'l2_code': 'cat_l2_code'})\n",
"# df = df.merge(index_data, on='trade_date', how='left')\n",
"\n",
"print(df.info())"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "d345bcc43b15579e",
"metadata": {
"ExecuteTime": {
"end_time": "2025-03-01T04:00:27.648343Z",
"start_time": "2025-02-27T16:38:24.536432Z"
}
},
"outputs": [],
"source": [
"def create_deviation_within_dates(df, feature_columns):\n",
" \"\"\"\n",
" For each numeric feature, add a 'deviation_mean_<feature>' column: the\n",
" feature minus its mean within each (trade_date, cat_l2_code) group.\n",
" Returns the augmented frame and the extended feature-name list.\n",
" \"\"\"\n",
" groupby_col = 'cat_l2_code' # group by industry code within each trade_date\n",
" new_columns = {}\n",
" ret_feature_columns = feature_columns[:]\n",
"\n",
" # BUG FIX: the original assigned num_features twice, so the first filter\n",
" # ('index' not in col) was dead code; apply both exclusions at once.\n",
" num_features = [col for col in feature_columns\n",
" if 'cat' not in col and 'index' not in col and 'industry' not in col]\n",
"\n",
" for feature in num_features:\n",
" if feature == 'trade_date': # no deviation for the date key itself\n",
" continue\n",
"\n",
" # grouped_median = df.groupby(['trade_date', groupby_col])[feature].transform('median')\n",
" # deviation_col_name = f'deviation_median_{feature}'\n",
" # new_columns[deviation_col_name] = df[feature] - grouped_median\n",
" # ret_feature_columns.append(deviation_col_name)\n",
"\n",
" grouped_mean = df.groupby(['trade_date', groupby_col])[feature].transform('mean')\n",
" deviation_col_name = f'deviation_mean_{feature}'\n",
" new_columns[deviation_col_name] = df[feature] - grouped_mean\n",
" ret_feature_columns.append(deviation_col_name)\n",
"\n",
" # Attach all deviation columns in a single concat to avoid frame fragmentation.\n",
" df = pd.concat([df, pd.DataFrame(new_columns)], axis=1)\n",
"\n",
" # for feature in ['obv', 'return_20', 'act_factor1', 'act_factor2', 'act_factor3', 'act_factor4']:\n",
" # df[f'deviation_industry_{feature}'] = df[feature] - df[f'industry_{feature}']\n",
"\n",
" return df, ret_feature_columns\n"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "5f3d9aece75318cd",
"metadata": {
"ExecuteTime": {
"end_time": "2025-03-01T04:00:27.648343Z",
"start_time": "2025-02-27T16:38:24.876179Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Removed 124442 outliers.\n"
]
}
],
"source": [
"def get_qcuts(series, quantiles):\n",
" \"\"\"Quantile-bucket ``series`` and return the last element's bucket label.\"\"\"\n",
" q = pd.qcut(series, q=quantiles, labels=False, duplicates='drop')\n",
" # BUG FIX: q keeps the original (non-positional) index, so q[-1] would be a\n",
" # label lookup and raise KeyError; take the last element positionally.\n",
" return q.iloc[-1]\n",
"\n",
"\n",
"def remove_outliers_label_percentile(label: pd.Series, lower_percentile: float = 0.01, upper_percentile: float = 0.99):\n",
" \"\"\"\n",
" Drop label values outside the [lower_percentile, upper_percentile]\n",
" quantile bounds. Returns the filtered Series; assigning it back to a\n",
" DataFrame column leaves NaN on the removed rows (dropped later).\n",
" \"\"\"\n",
" if not (0 <= lower_percentile < upper_percentile <= 1):\n",
" raise ValueError(\"Percentile values must satisfy 0 <= lower_percentile < upper_percentile <= 1.\")\n",
"\n",
" lower_bound = label.quantile(lower_percentile)\n",
" upper_bound = label.quantile(upper_percentile)\n",
"\n",
" filtered_label = label[(label >= lower_bound) & (label <= upper_bound)]\n",
"\n",
" print(f\"Removed {len(label) - len(filtered_label)} outliers.\")\n",
" return filtered_label\n",
"\n",
"\n",
"def calculate_risk_adjusted_target(df, days=5):\n",
" \"\"\"\n",
" Build a risk-adjusted forward-return target over ``days`` days.\n",
" NOTE(review): 'sharpe_ratio' MULTIPLIES return by volatility instead of\n",
" dividing; kept as-is (unused below), but confirm the intended formula\n",
" before using this target.\n",
" \"\"\"\n",
" df = df.sort_values(by=['ts_code', 'trade_date'])\n",
"\n",
" df['future_close'] = df.groupby('ts_code')['close'].shift(-days)\n",
" df['future_return'] = (df['future_close'] - df['close']) / df['close']\n",
"\n",
" df['future_volatility'] = df.groupby('ts_code')['future_return'].rolling(days, min_periods=1).std().reset_index(\n",
" level=0, drop=True)\n",
" df['sharpe_ratio'] = df['future_return'] * df['future_volatility']\n",
" # BUG FIX: inplace replace on a column selection may act on a copy\n",
" # (chained assignment); assign the result back explicitly.\n",
" df['sharpe_ratio'] = df['sharpe_ratio'].replace([np.inf, -np.inf], np.nan)\n",
"\n",
" return df['sharpe_ratio']\n",
"\n",
"\n",
"# Label: 4-day forward return with the extreme 1% tails NaN-ed out.\n",
"future_close = df.groupby('ts_code')['close'].shift(-4)\n",
"future_return = (future_close - df['close']) / df['close']\n",
"df['label'] = future_return\n",
"df['label'] = remove_outliers_label_percentile(df['label'])\n",
"\n",
"# df = df.apply(lambda x: x.astype('float32') if x.dtype in ['float64', 'float32'] else x)\n",
"df = df.sort_values(by=['trade_date', 'ts_code'])\n",
"train_data = df[(df['trade_date'] <= '2023-01-01') & (df['trade_date'] >= '2016-01-01')]\n",
"test_data = df[(df['trade_date'] >= '2023-01-01') & (df['trade_date'] <= '2025-02-26')]\n",
"\n",
"# Keep only the 1000 strongest 20-day-momentum names per day.\n",
"train_data = train_data.groupby('trade_date', group_keys=False).apply(lambda x: x.nlargest(1000, 'return_20'))\n",
"test_data = test_data.groupby('trade_date', group_keys=False).apply(lambda x: x.nlargest(1000, 'return_20'))\n",
"\n",
"train_data = train_data.merge(industry_df, on=['cat_l2_code', 'trade_date'], how='left')\n",
"train_data = train_data.merge(index_data, on='trade_date', how='left')\n",
"test_data = test_data.merge(industry_df, on=['cat_l2_code', 'trade_date'], how='left')\n",
"test_data = test_data.merge(index_data, on='trade_date', how='left')\n",
"\n",
"train_data, test_data = train_data.replace([np.inf, -np.inf], np.nan), test_data.replace([np.inf, -np.inf], np.nan)\n"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "93d47ef451968346",
"metadata": {
"ExecuteTime": {
"end_time": "2025-03-01T04:00:27.648343Z",
"start_time": "2025-02-27T16:38:53.216704Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['turnover_rate', 'pe_ttm', 'volume_ratio', 'cat_l2_code', 'log_close', 'up', 'down', 'atr_14', 'atr_6', 'obv', 'maobv_6', 'obv-maobv_6', 'rsi_3', 'rsi_6', 'rsi_9', 'return_5', 'return_10', 'return_20', 'avg_close_5', 'std_return_5', 'std_return_15', 'std_return_25', 'std_return_90', 'std_return_90_2', 'std_return_5 / std_return_90', 'std_return_5 / std_return_25', 'std_return_90 - std_return_90_2', 'ema_5', 'ema_13', 'ema_20', 'ema_60', 'act_factor1', 'act_factor2', 'act_factor3', 'act_factor4', 'cat_af1', 'cat_af2', 'cat_af3', 'cat_af4', 'act_factor5', 'act_factor6', 'rank_act_factor1', 'rank_act_factor2', 'rank_act_factor3', 'active_buy_volume_large', 'active_buy_volume_big', 'active_buy_volume_small', 'buy_lg_vol_minus_sell_lg_vol', 'buy_elg_vol_minus_sell_elg_vol', 'log(circ_mv)', 'alpha_022', 'alpha_003', 'alpha_007', 'alpha_013', 'industry_obv', 'industry_return_5', 'industry_return_20', 'industry_ema_5', 'industry_ema_13', 'industry_ema_20', 'industry_ema_60', 'industry_act_factor1', 'industry_act_factor2', 'industry_act_factor3', 'industry_act_factor4', 'industry_act_factor5', 'industry_act_factor6', 'industry_rank_act_factor1', 'industry_rank_act_factor2', 'industry_rank_act_factor3', 'industry_obv_deviation', 'industry_return_5_deviation', 'industry_return_20_deviation', 'industry_act_factor1_deviation', 'industry_act_factor2_deviation', 'industry_act_factor3_deviation', 'industry_act_factor4_deviation', 'industry_return_5_percentile', '000852.SH_MACD', '000905.SH_MACD', '399006.SZ_MACD', '000852.SH_MACD_hist', '000905.SH_MACD_hist', '399006.SZ_MACD_hist', '000852.SH_RSI', '000905.SH_RSI', '399006.SZ_RSI', '000852.SH_Signal_line', '000905.SH_Signal_line', '399006.SZ_Signal_line', '000852.SH_daily_return', '000905.SH_daily_return', '399006.SZ_daily_return']\n"
]
}
],
"source": [
"# Model features: everything except keys/label, raw input columns, leakage\n",
"# columns ('future*', 'score*') and private columns (leading underscore).\n",
"feature_columns = [col for col in train_data.columns if col not in ['trade_date',\n",
" 'ts_code',\n",
" 'label']]\n",
"feature_columns = [col for col in feature_columns if 'future' not in col]\n",
"feature_columns = [col for col in feature_columns if 'score' not in col]\n",
"feature_columns = [col for col in feature_columns if col not in origin_columns]\n",
"feature_columns = [col for col in feature_columns if not col.startswith('_')]\n",
"print(feature_columns)"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "572576eea818c865",
"metadata": {
"ExecuteTime": {
"end_time": "2025-03-01T04:00:27.648343Z",
"start_time": "2025-02-27T16:38:53.484283Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"feature_columns size: 221\n",
"feature_columns size: 221\n",
"1136428\n",
"最小日期: 2017-04-06\n",
"最大日期: 2022-12-30\n",
"408601\n",
"最小日期: 2023-01-03\n",
"最大日期: 2025-02-26\n"
]
}
],
"source": [
"# Add within-(date, industry) mean-deviation features, then drop rows with\n",
"# missing features; the test split keeps rows without labels for inference.\n",
"feature_columns_new = feature_columns[:]\n",
"train_data, feature_columns_new = create_deviation_within_dates(train_data, feature_columns)\n",
"print(f'feature_columns size: {len(feature_columns_new)}')\n",
"test_data, feature_columns_new = create_deviation_within_dates(test_data, feature_columns)\n",
"print(f'feature_columns size: {len(feature_columns_new)}')\n",
"\n",
"train_data = train_data.dropna(subset=feature_columns_new)\n",
"train_data = train_data.dropna(subset=['label'])\n",
"train_data = train_data.reset_index(drop=True)\n",
"\n",
"# print(test_data.tail())\n",
"test_data = test_data.dropna(subset=feature_columns_new)\n",
"# test_data = test_data.dropna(subset=['label'])\n",
"test_data = test_data.reset_index(drop=True)\n",
"\n",
"print(len(train_data))\n",
"print(f\"最小日期: {train_data['trade_date'].min().strftime('%Y-%m-%d')}\")\n",
"print(f\"最大日期: {train_data['trade_date'].max().strftime('%Y-%m-%d')}\")\n",
"print(len(test_data))\n",
"print(f\"最小日期: {test_data['trade_date'].min().strftime('%Y-%m-%d')}\")\n",
"print(f\"最大日期: {test_data['trade_date'].max().strftime('%Y-%m-%d')}\")"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "2d7e37432f551aea",
"metadata": {
"ExecuteTime": {
"end_time": "2025-03-01T04:00:27.648343Z",
"start_time": "2025-02-27T16:39:44.144179Z"
}
},
"outputs": [],
"source": [
"# Cast boolean/categorical factor columns (prefixed 'cat') to pandas\n",
"# 'category' dtype for downstream encoders.\n",
"cat_columns = [col for col in df.columns if col.startswith('cat')]\n",
"for col in cat_columns:\n",
" train_data[col] = train_data[col].astype('category')\n",
" test_data[col] = test_data[col].astype('category')"
]
},
{
"metadata": {},
"cell_type": "code",
"outputs": [],
"execution_count": null,
"source": [
"import torch\n",
"import torch.nn as nn\n",
"\n",
"class MultiTaskNNTransformerModel(nn.Module):\n",
" \"\"\"\n",
" 3-layer MLP feature extractor followed by a (seq_len=1) nn.Transformer\n",
" and a task-specific head for regression, classification or ranking.\n",
" \"\"\"\n",
" def __init__(self, input_dim, hidden_dim_nn, hidden_dim_transformer, num_heads, num_layers,\n",
" output_dim, task_type=\"regression\", num_classes=None):\n",
" super(MultiTaskNNTransformerModel, self).__init__()\n",
"\n",
" # Validate the task configuration.\n",
" if task_type not in [\"regression\", \"classification\", \"ranking\"]:\n",
" raise ValueError(\"task_type must be 'regression', 'classification', or 'ranking'\")\n",
" if task_type == \"classification\" and num_classes is None:\n",
" raise ValueError(\"num_classes must be specified for classification tasks\")\n",
"\n",
" # 3-layer fully connected feature extractor.\n",
" self.fc1 = nn.Linear(input_dim, hidden_dim_nn)\n",
" self.fc2 = nn.Linear(hidden_dim_nn, hidden_dim_nn)\n",
" self.fc3 = nn.Linear(hidden_dim_nn, hidden_dim_transformer)\n",
" self.relu = nn.ReLU()\n",
"\n",
" # Full encoder-decoder Transformer (forward feeds x as both src and tgt).\n",
" self.transformer = nn.Transformer(\n",
" d_model=hidden_dim_transformer, # Transformer model dimension\n",
" nhead=num_heads, # number of attention heads\n",
" num_encoder_layers=num_layers, # encoder depth\n",
" num_decoder_layers=num_layers, # decoder depth\n",
" dim_feedforward=hidden_dim_transformer * 2 # feed-forward hidden size\n",
" )\n",
"\n",
" # Output head.\n",
" if task_type == \"classification\":\n",
" self.output_layer = nn.Linear(hidden_dim_transformer, num_classes) # class logits\n",
" else:\n",
" self.output_layer = nn.Linear(hidden_dim_transformer, output_dim) # regression / ranking\n",
"\n",
" # Task-dependent output activation.\n",
" self.task_type = task_type\n",
" if task_type == \"classification\":\n",
" self.activation = nn.Softmax(dim=1) # class probabilities\n",
" elif task_type == \"ranking\":\n",
" self.activation = nn.Sigmoid() # scores in (0, 1)\n",
" else:\n",
" self.activation = None # raw output for regression\n",
"\n",
" def forward(self, x):\n",
" \"\"\"\n",
" Forward pass.\n",
"\n",
" :param x: input tensor (batch_size, input_dim)\n",
" :return: model output (batch_size, output_dim)\n",
" \"\"\"\n",
" # Step 1: extract features with the 3-layer MLP.\n",
" x = self.relu(self.fc1(x)) # layer 1\n",
" x = self.relu(self.fc2(x)) # layer 2\n",
" x = self.relu(self.fc3(x)) # layer 3\n",
"\n",
" # Step 2: add a length-1 sequence dim (batch_size, 1, hidden_dim_transformer).\n",
" x = x.unsqueeze(1)\n",
"\n",
" # Step 3: run the Transformer with x as both source and target.\n",
" transformer_output = self.transformer(x, x).squeeze(1) # drop the seq dim\n",
"\n",
" # Step 4: project to the task output.\n",
" output = self.output_layer(transformer_output)\n",
"\n",
" # Step 5: apply the task-specific activation, if any.\n",
" if self.activation is not None:\n",
" output = self.activation(output)\n",
"\n",
" return output"
],
"id": "d8a5eba1e446b79a"
},
{
"cell_type": "code",
"execution_count": 25,
"id": "8f134d435f71e9e2",
"metadata": {
"ExecuteTime": {
"end_time": "2025-03-01T04:00:27.668543Z",
"start_time": "2025-02-27T16:39:44.618725Z"
},
"jupyter": {
"source_hidden": true
}
},
"outputs": [],
"source": [
"from sklearn.preprocessing import StandardScaler\n",
"\n",
"\n",
"# Dataset yielding one (features, labels) pair per trade_date cross-section.\n",
"class TradeDateDataset(Dataset):\n",
" def __init__(self, data_df, feature_columns, label_column, scaler=None, one_hot_encoder=None):\n",
" self.data = []\n",
"\n",
" # One sample per trading day: all stocks of that day stacked together.\n",
" grouped = data_df.groupby('trade_date')\n",
" for trade_date, group in grouped:\n",
" X = group[feature_columns].values\n",
" y = group[label_column].values\n",
"\n",
" # Standardize the numeric feature columns with the pre-fitted scaler.\n",
" if scaler is not None:\n",
" numeric_columns = group[feature_columns].select_dtypes(include=['float64', 'int64']).columns\n",
" X[:, [feature_columns.index(col) for col in numeric_columns]] = scaler.transform(\n",
" group[numeric_columns]\n",
" )\n",
"\n",
" # One-hot encode the categorical columns and append them to X.\n",
" if one_hot_encoder is not None:\n",
" categorical_columns = group[feature_columns].select_dtypes(include=['object', 'category']).columns\n",
" X_categorical = one_hot_encoder.transform(group[categorical_columns]).toarray()\n",
" X = np.hstack([X, X_categorical]) # concat one-hot block with numeric features\n",
"\n",
" self.data.append((X, y))\n",
"\n",
" def __len__(self):\n",
" return len(self.data)\n",
"\n",
" def __getitem__(self, idx):\n",
" return self.data[idx]\n",
"\n",
"def train_transformer_model(train_data_df, test_data_df, feature_columns, label_column,\n",
" output_dim=1, task='regression',\n",
" hidden_dim=64, num_heads=4, num_layers=2,\n",
" dropout=0.1, learning_rate=0.001, num_epochs=10, batch_size=32):\n",
" # 数据预处理\n",
" train_data_df = train_data_df.dropna(subset=[label_column])\n",
" test_data_df = test_data_df.dropna(subset=[label_column])\n",
"\n",
" # 标准化数值型特征\n",
" scaler = StandardScaler()\n",
" numeric_columns = train_data_df[feature_columns].select_dtypes(include=['float64', 'int64']).columns\n",
" scaler.fit(train_data_df[numeric_columns])\n",
"\n",
" # 对类别型特征进行One-Hot编码\n",
" one_hot_encoder = OneHotEncoder(sparse=False, handle_unknown='ignore')\n",
" categorical_columns = train_data_df[feature_columns].select_dtypes(include=['object', 'category']).columns\n",
" one_hot_encoder.fit(train_data_df[categorical_columns])\n",
"\n",
" # 创建Dataset\n",
" train_dataset = TradeDateDataset(train_data_df, feature_columns, label_column, scaler, one_hot_encoder)\n",
" test_dataset = TradeDateDataset(test_data_df, feature_columns, label_column, scaler, one_hot_encoder)\n",
"\n",
" # 创建DataLoader\n",
" train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)\n",
" test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)\n",
"\n",
" # 初始化模型、损失函数和优化器\n",
" input_dim = len(numeric_columns) + len(one_hot_encoder.get_feature_names_out())\n",
" model = TransformerModel(input_dim, hidden_dim, output_dim, num_heads, num_layers, dropout)\n",
"\n",
" if task == 'classification':\n",
" criterion = nn.CrossEntropyLoss()\n",
" elif task == 'regression':\n",
" criterion = nn.MSELoss()\n",
" else:\n",
" raise ValueError(\"Unsupported task type. Choose 'classification' or 'regression'.\")\n",
"\n",
" optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)\n",
"\n",
" # 训练循环\n",
" train_losses, val_losses = [], []\n",
" for epoch in range(num_epochs):\n",
" model.train()\n",
" running_loss = 0.0\n",
"\n",
" for X_batch, y_batch in train_loader:\n",
" X_batch = torch.tensor(X_batch, dtype=torch.float32)\n",
" if task == 'classification':\n",
" y_batch = torch.tensor(y_batch, dtype=torch.long)\n",
" elif task == 'regression':\n",
" y_batch = torch.tensor(y_batch, dtype=torch.float32).unsqueeze(1)\n",
"\n",
" # 前向传播\n",
" output = model(X_batch)\n",
"\n",
" # 计算损失\n",
" loss = criterion(output, y_batch)\n",
"\n",
" # 反向传播\n",
" optimizer.zero_grad()\n",
" loss.backward()\n",
" optimizer.step()\n",
"\n",
" running_loss += loss.item()\n",
"\n",
" train_losses.append(running_loss / len(train_loader))\n",
" print(f\"Epoch {epoch+1}/{num_epochs}, Loss: {train_losses[-1]:.4f}\")\n",
"\n",
" # 测试阶段\n",
" model.eval()\n",
" val_loss = 0.0\n",
" with torch.no_grad():\n",
" for X_batch, y_batch in test_loader:\n",
" X_batch = torch.tensor(X_batch, dtype=torch.float32)\n",
" if task == 'classification':\n",
" y_batch = torch.tensor(y_batch, dtype=torch.long)\n",
" elif task == 'regression':\n",
" y_batch = torch.tensor(y_batch, dtype=torch.float32).unsqueeze(1)\n",
"\n",
" output = model(X_batch)\n",
" loss = criterion(output, y_batch)\n",
" val_loss += loss.item()\n",
"\n",
" val_losses.append(val_loss / len(test_loader))\n",
" print(f\"Validation Loss: {val_losses[-1]:.4f}\")\n",
"\n",
" # 可视化损失曲线\n",
" plt.plot(train_losses, label=\"Training Loss\")\n",
" plt.plot(val_losses, label=\"Validation Loss\")\n",
" plt.legend()\n",
" plt.show()\n",
"\n",
" return model, scaler, one_hot_encoder"
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "4a4542e1ed6afe7d",
"metadata": {
"ExecuteTime": {
"end_time": "2025-03-01T04:00:27.668543Z",
"start_time": "2025-02-27T16:39:44.838044Z"
}
},
"outputs": [],
"source": [
"light_params = {\n",
" # 'objective': 'regression',\n",
" # 'metric': 'l2',\n",
" 'objective': 'quantile', # 分位回归\n",
" 'metric': 'quantile', # 使用 quantile 作为评估指标\n",
" 'alpha': 0.75, # 90% 分位数\n",
" 'learning_rate': 0.05,\n",
" 'is_unbalance': True,\n",
" 'num_leaves': 1024,\n",
" 'min_data_in_leaf': 128,\n",
" 'max_depth': 32,\n",
" 'max_bin': 1024,\n",
" 'feature_fraction': 0.7,\n",
" 'bagging_fraction': 0.7,\n",
" 'bagging_freq': 5,\n",
" 'lambda_l1': 1,\n",
" 'lambda_l2': 1,\n",
" # 'boosting_type': 'dart',\n",
" 'verbosity': -1,\n",
" 'seed': 17\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "beeb098799ecfa6a",
"metadata": {
"ExecuteTime": {
"end_time": "2025-03-01T04:00:27.668543Z",
"start_time": "2025-02-27T16:39:45.091206Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"train data size: 1136428\n",
"categorical_feature: [3, 35, 36, 37, 38]\n",
"Training until validation scores don't improve for 50 rounds\n",
"Early stopping, best iteration is:\n",
"[31]\ttrain's quantile: 0.014632\tvalid's quantile: 0.0152371\n",
"Evaluated only: quantile\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAksAAAHFCAYAAADi7703AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8pXeV/AAAACXBIWXMAAA9hAAAPYQGoP6dpAABvFElEQVR4nO3deVxU9f7H8dcwbLJvyqIgIO67oLlkaKam5W0xWy2t7HfNuqW2d5f2tJuVdUttcWnPulbXSlNMRStXFHNfUVxABERAFAY4vz9GRhEkQHRY3s/H4zxgznzPOZ/PDDAfvud7vsdkGIaBiIiIiJTLwd4BiIiIiNRmKpZEREREKqBiSURERKQCKpZEREREKqBiSURERKQCKpZEREREKqBiSURERKQCKpZEREREKqBiSURERKQCKpZE6qA5c+ZgMpkwmUwsX768zPOGYRAVFYXJZKJfv37VOsa0adOYM2dOlbZZvnz5BWOqKZfqGJcj9gvZtm0bzz//PPv3778k+3/++ecxmUzV2taer4tIbaFiSaQO8/T0ZObMmWXWx8fHs3fvXjw9Pau97+oUS926dWPVqlV069at2se1F3vGvm3bNl544YVLViyNGTOGVatWVWvbuvyeitQUFUsiddhtt93GvHnzyM7OLrV+5syZ9OrVi7CwsMsSh8ViobCwEC8vL3r27ImXl9dlOW5NqIux5+XlVal9s2bN6NmzZ7WOVZdeF5FLRcWSSB12xx13APDll1/a1p04cYJ58+Zx3333lbtNQUEBL7/8Mm3atMHFxYXGjRtz7733cuzYMVub8PBwtm7dSnx8vO10X3h4OHD2tMynn37KY489RtOmTXFxcWHPnj0XPGWzZs0ahg0bhr+/P66urrRo0YLx48f/aX47duzg2muvxc3NjYCAAMaOHUtOTk6ZduHh4YwePbrM+n79+pU6DVnV2EePHo2Hhwd79uxh6NCheHh4EBoaymOPPUZ+fn6pYx06dIhbbrkFT09PfHx8uOuuu1i3bh0mk6nCHro5c+YwYsQIAPr37297vUu26devHx06dGDFihX07t0bNzc323s7d+5cBg0aRHBwMI0aNaJt27Y8/fTTnDx5stQxyjsNFx4ezvXXX8/PP/9Mt27daNSoEW3atGHWrFml2tnrdRGpTVQsidRhXl5e3HLLLaU+4L788kscHBy47bbbyrQvLi7mhhtuYPLkydx555389NNPTJ48mbi4OPr168epU6cA+O6774iMjKRr166sWrWKVatW8d1335Xa1zPPPENycjIzZszghx9+oEmTJuXGuGjRIvr27UtycjJvvvkmCxcu5B//+AdHjx6tMLejR48SGxvLli1bmDZtGp9++im5ubk8/PDDVX2Zyqhs7GDtefrLX/7CgAED+N///sd9993HW2+9xWuvvWZrc/LkSfr378+yZct47bXX+PrrrwkMDCz3PTjfddddx6uvvgrAe++9Z3u9r7vuOlublJQURo4cyZ133smCBQsYN24cALt372bo0KHMnDmTn3/+mfHjx/P1118zbNiwSr0OmzZt4rHHHmPChAn873//o1OnTtx///2sWLHiT7e91K+LSK1iiEidM3v2bAMw1q1bZyxbtswAjC1bthiGYRjdu3c3Ro8ebRiGYbRv396IjY21bffll18agDFv3rxS+1u3bp0BGNOmTbOtO3/bEiXHu+qqqy743LJly2zrWrRoYbRo0cI4depUlXJ86qmnDJPJZCQmJpZaP3DgwDLHaN68uTFq1Kgy+4iNjS2VQ1VjHzVqlAEYX3/9dam2Q4cONVq3bm17/N577xmAsXDhwlLt/vrXvxqAMXv27Apz/eabb8oc+9wcAOOXX36pcB/FxcWGxWIx4uPjDcDYtGmT7bnnnnvOOP/PffPmzQ1XV1fjwIEDtnWnTp0y/Pz8jL/+9a+2dfZ8XURqC/UsidRxsbGxtGjRglmzZrF582bWrVt3wVNwP/74Iz
4+PgwbNozCwkLb0qVLF4KCgqp0xdPw4cP/tM2uXbvYu3cv999/P66urpXeN8CyZcto3749nTt3LrX+zjvvrNJ+ylOZ2EuYTKYyPTWdOnXiwIEDtsfx8fF4enpy7bXXlmpXcpr0Yvn6+nL11VeXWb9v3z7uvPNOgoKCMJvNODk5ERsbC8D27dv/dL9dunQpNa7N1dWVVq1alcrtQmrD6yJyuTjaOwARuTgmk4l7772Xd955h9OnT9OqVSv69u1bbtujR4+SlZWFs7Nzuc+np6dX+rjBwcF/2qZkHFSzZs0qvd8SGRkZRERElFkfFBRU5X2drzKxl3BzcytT6Lm4uHD69Gnb44yMDAIDA8tsW9666igv3tzcXPr27Yurqysvv/wyrVq1ws3NjYMHD3LzzTfbTqlWxN/fv8w6FxeXSm1bG14XkctFxZJIPTB69Gj+9a9/MWPGDF555ZULtgsICMDf35+ff/653OerMtVAZebtady4MWAd5FtV/v7+pKamlllf3jpXV9cyA4vBWvwFBASUWV/dOYcuxN/fn7Vr15ZZX16s1VFevEuXLuXIkSMsX77c1psEkJWVVSPHrAmX+nURuVx0Gk6kHmjatClPPPEEw4YNY9SoURdsd/3115ORkUFRURExMTFlltatW9vaVraHoSKtWrWynSIsr5ipSP/+/dm6dSubNm0qtf6LL74o0zY8PJw//vij1Lpdu3axc+fOqgddDbGxseTk5LBw4cJS67/66qtKbe/i4gJQpde7pIAq2bbE+++/X+l9XGoX+7qI1BbqWRKpJyZPnvynbW6//XY+//xzhg4dyqOPPkqPHj1wcnLi0KFDLFu2jBtuuIGbbroJgI4dO/LVV18xd+5cIiMjcXV1pWPHjlWO67333mPYsGH07NmTCRMmEBYWRnJyMosWLeLzzz+/4Hbjx49n1qxZXHfddbz88ssEBgby+eefs2PHjjJt7777bkaOHMm4ceMYPnw4Bw4c4N///retZ+tSGzVqFG+99RYjR47k5ZdfJioqioULF7Jo0SIAHBwq/r+0Q4cOAHzwwQd4enri6upKREREuafJSvTu3RtfX1/Gjh3Lc889h5OTE59//nmZ4tKeLvZ1Eakt9JMq0oCYzWbmz5/Ps88+y7fffstNN93EjTfeyOTJk8sUQy+88AKxsbE88MAD9OjRo9KXo59v8ODBrFixguDgYB555BGuvfZaXnzxxT8dtxIUFER8fDzt2rXjwQcfZOTIkbi6uvLuu++WaXvnnXfy73//m0WLFnH99dczffp0pk+fTqtWraoVc1W5u7uzdOlS+vXrx5NPPsnw4cNJTk5m2rRpAPj4+FS4fUREBFOnTmXTpk3069eP7t2788MPP1S4jb+/Pz/99BNubm6MHDmS++67Dw8PD+bOnVtTaV20i31dRGoLk2EYhr2DEBGpj1599VX+8Y9/kJycXK1B7vWVXhepa3QaTkSkBpT0eLVp0waLxcLSpUt55513GDlyZIMuCPS6SH2gYklEpAa4ubnx1ltvsX//fvLz8wkLC+Opp57iH//4h71Dsyu9LlIf6DSciIiISAU0wFtERESkAiqWRERERCqgYklERESkAhrgXU3FxcUcOXIET0/PGr91goiIiFwahmGQk5NDSEhIpSdGVbFUTUeOHCE0NNTeYYiIiEg1HDx4sNLTV6hYqqaSG44mJSXh5+dn52guD4vFwuLFixk0aBBOTk72DueyaYh5K2flXJ81xLyV89mcs7OzCQ0NrdKNw1UsVVPJqTdPT0+8vLzsHM3lYbFYcHNzw8vLq8H8skHDzFs5K+f6rCHmrZzL5lyVITQa4C0iIiJSARVLIiIiIhVQsSQiIiJSAY1ZEhERqUWKi4spKCio0X1aLBYcHR05ffo0RUVFNbrv2sbJyQmz2Vyj+1SxJCIiUksUFBSQlJREcXFxje7XMAyCgoI4ePBgg5gb0MfHB39//xrbn4olERGRWsAwDFJSUjCbzYSGhl
Z6wsTKKC4uJjc3Fw8Pjxrdb21jGAZ5eXmkpaXVaA+aiiUREZFaoLCwkLy8PEJCQnBzc6vRfZec2nN1da3XxRJAo0a
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAsgAAAHFCAYAAADv3Q81AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8pXeV/AAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdd1RUR/vA8e/SmyC9KAIq2Cs2NBELoIjYYokaxNjim2KPPWqMomKPlVhQEyN5DWpioigKithAhbyCLfYGKhFBUanz+8PD/bkCAkZRkvmcsyfZO3PvnXlY3IfZ2RmVEEIgSZIkSZIkSRIAGm+7AZIkSZIkSZL0LpEJsiRJkiRJkiQ9RybIkiRJkiRJkvQcmSBLkiRJkiRJ0nNkgixJkiRJkiRJz5EJsiRJkiRJkiQ9RybIkiRJkiRJkvQcmSBLkiRJkiRJ0nNkgixJkiRJkiRJz5EJsiRJ0j/Ihg0bUKlUhT7GjRv3Ru555swZZsyYwdWrV9/I9f+Oq1evolKp2LBhw9tuyivbtWsXM2bMeNvNkKR/Fa233QBJkiTp9QsODqZmzZpqx+zs7N7Ivc6cOcPXX39NmzZtcHR0fCP3eFW2trYcPXqUatWqve2mvLJdu3axYsUKmSRLUhmSCbIkSdI/UN26dWnSpMnbbsbfkp2djUqlQkvr1d+qdHV1adGixWtsVdl5/PgxBgYGb7sZkvSvJKdYSJIk/Qv99NNPuLm5YWhoiJGRER06dCAuLk6tzokTJ/jwww9xdHREX18fR0dH+vbty7Vr15Q6GzZsoFevXgC0bdtWmc6RP6XB0dGRgQMHFrh/mzZtaNOmjfL8wIEDqFQqvv/+e8aOHUulSpXQ1dXl4sWLAOzbt4/27dtjbGyMgYEBrVq1Yv/+/cX2s7ApFjNmzEClUvG///2PXr16YWJigpmZGWPGjCEnJ4fz58/TsWNHKlSogKOjI4GBgWrXzG/rDz/8wJgxY7CxsUFfXx93d/cCMQT49ddfcXNzw8DAgAoVKuDp6cnRo0fV6uS36dSpU/Ts2RNTU1OqVavGwIEDWbFiBYDadJn86SwrVqygdevWWFlZYWhoSL169QgMDCQ7O7tAvOvWrUtsbCzvv/8+BgYGVK1alblz55KXl6dW98GDB4wdO5aqVauiq6uLlZUVnTp14ty5c0qdrKwsZs2aRc2aNdHV1cXS0pKPP/6Ye/fuFfszkaTyQCbIkiRJ/0C5ubnk5OSoPfIFBATQt29fateuzX//+1++//57Hj58yPvvv8+ZM2eUelevXqVGjRosWbKEPXv2MG/ePJKSkmjatCkpKSkA+Pj4EBAQADxL1o4ePcrRo0fx8fF5pXZPmjSJ69evs3r1anbu3ImVlRU//PADXl5eGBsbs3HjRv773/9iZmZGhw4dSpQkF6V37940aNCA0NBQhg4dyuLFixk9ejTdunXDx8eH7du3065dOyZMmMC2bdsKnD958mQuX77M2rVrWbt2Lbdv36ZNmzZcvnxZqfPjjz/StWtXjI2N2bJlC+vWrSM1NZU2bdoQHR1d4Jo9evSgevXqbN26ldWrV/PVV1/Rs2dPACW2R48exdbWFoBLly7Rr18/vv/+e3777TcGDx7M/Pnz+eSTTwpcOzk5mf79+/PRRx/x66+/4u3tzaRJk/jhhx+UOg8fPuS9994jKCiIjz/+mJ07d7J69WpcXFxISkoCIC8vj65duzJ37lz69evH77//zty5cwkPD6dNmzY8efLklX8mkvTOEJIkSdI/RnBwsAAKfWRnZ4vr168LLS0t8cUXX6id9/DhQ2FjYyN69+5d5LVzcnLEo0ePhKGhoVi6dKlyfOvWrQIQkZGRBc5xcHAQ/v7+BY67u7sLd3d35XlkZKQAROvWrdXqZWRkCDMzM+Hr66t2PDc3VzRo0EA0a9bsJdEQ4sqVKwIQwcHByrHp06cLQCxcuFCtbsOGDQUgtm3bphzLzs4WlpaWokePHg
Xa2rhxY5GXl6ccv3r1qtDW1hZDhgxR2mhnZyfq1asncnNzlXoPHz4UVlZWomXLlgXaNG3atAJ9+Oyzz0RJ3q5zc3NFdna22LRpk9DU1BT3799Xytzd3QUgjh8/rnZO7dq1RYcOHZTnM2fOFIAIDw8v8j5btmwRgAgNDVU7HhsbKwCxcuXKYtsqSe86OYIsSZL0D7Rp0yZiY2PVHlpaWuzZs4ecnBwGDBigNrqsp6eHu7s7Bw4cUK7x6NEjJkyYQPXq1dHS0kJLSwsjIyMyMjI4e/bsG2n3Bx98oPb8yJEj3L9/H39/f7X25uXl0bFjR2JjY8nIyHile3Xu3Fntea1atVCpVHh7eyvHtLS0qF69utq0knz9+vVDpVIpzx0cHGjZsiWRkZEAnD9/ntu3b+Pn54eGxv+/3RoZGfHBBx9w7NgxHj9+/NL+FycuLo4uXbpgbm6OpqYm2traDBgwgNzcXC5cuKBW18bGhmbNmqkdq1+/vlrfdu/ejYuLCx4eHkXe87fffqNixYr4+vqq/UwaNmyIjY2N2mtIksor+SU9SZKkf6BatWoV+iW9O3fuANC0adNCz3s+kevXrx/79+/nq6++omnTphgbG6NSqejUqdMb+xg9f+rAi+3Nn2ZQmPv372NoaFjqe5mZmak919HRwcDAAD09vQLH09PTC5xvY2NT6LE//vgDgL/++gso2Cd4tqJIXl4eqampal/EK6xuUa5fv877779PjRo1WLp0KY6Ojujp6RETE8Nnn31W4Gdkbm5e4Bq6urpq9e7du0eVKlVeet87d+7w4MEDdHR0Ci3Pn34jSeWZTJAlSZL+RSwsLAD4+eefcXBwKLJeWloav/32G9OnT2fixInK8czMTO7fv1/i++np6ZGZmVngeEpKitKW5z0/Ivt8e5ctW1bkahTW1tYlbs/rlJycXOix/EQ0/7/5c3efd/v2bTQ0NDA1NVU7/mL/X2bHjh1kZGSwbds2tZ9lfHx8ia/xIktLS27evPnSOhYWFpibmxMWFlZoeYUKFV75/pL0rpAJsiRJ0r9Ihw4d0NLS4tKlSy/9OF+lUiGEQFdXV+342rVryc3NVTuWX6ewUWVHR0f+97//qR27cOEC58+fLzRBflGrVq2oWLEiZ86c4fPPPy+2flnasmULY8aMUZLaa9euceTIEQYMGABAjRo1qFSpEj/++CPjxo1T6mVkZBAaGqqsbFGc5+Orr6+vHM+/3vM/IyEEa9aseeU+eXt7M23aNCIiImjXrl2hdTp37kxISAi5ubk0b978le8lSe8ymSBLkiT9izg6OjJz5kymTJnC5cuX6dixI6ampty5c4eYmBgMDQ35+uuvMTY2pnXr1syfPx8LCwscHR05ePAg69ato2LFimrXrFu3LgDfffcdFSpUQE9PDycnJ8zNzfHz8+Ojjz7i008/5YMPPuDatWsEBgZiaWlZovYaGRmxbNky/P39uX//Pj179sTKyop79+7xxx9/cO/ePVatWvW6w1Qid+/epXv37gwdOpS0tDSmT5+Onp4ekyZNAp5NVwkMDKR///507tyZTz75hMzMTObPn8+DBw+YO3duie5Tr149AObNm4e3tzeamprUr18fT09PdHR06Nu3L+PHj+fp06esWrWK1NTUV+7TqFGj+Omnn+jatSsTJ06kWbNmPHnyhIMHD9K5c2fatm3Lhx9+yObNm+nUqRMjR46kWbNmaGtrc/PmTSIjI+natSvdu3d/5TZI0jvhbX9LUJIkSXp98lexiI2NfWm9HTt2iLZt2wpjY2Ohq6srHBwcRM+ePcW+ffuUOjdv3hQffPCBMDU1FRUqVBAdO3YUCQkJha5MsWTJEuHk5CQ0NTXVVo3Iy8sTgYGBomrVqkJPT080adJEREREFLmKxdatWwtt78GDB4WPj48wMzMT2traolKlSsLHx6fI+vletorFvXv31Or6+/sLQ0PDAtdwd3cXderUKdDW77//XowYMUJYWloKXV1d8f
7774sTJ04UOH/Hjh2iefPmQk9PTxgaGor27duLw4cPq9Upqk1CCJGZmSmGDBkiLC0thUqlEoC4cuWKEEKInTt3igY
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"print('train data size: ', len(train_data))\n",
"\n",
"evals = {}\n",
"model = train_light_model(train_data, test_data, light_params, feature_columns_new,\n",
" [lgb.log_evaluation(period=500),\n",
" lgb.callback.record_evaluation(evals),\n",
" lgb.early_stopping(50, first_metric_only=True)\n",
" ], evals,\n",
" num_boost_round=500, use_optuna=False,\n",
" print_feature_importance=True)"
]
},
{
"cell_type": "code",
"execution_count": 28,
"id": "63235069-dc59-48fb-961a-e80373e41a61",
"metadata": {
"ExecuteTime": {
"end_time": "2025-03-01T04:00:27.668543Z",
"start_time": "2025-02-27T16:41:14.573661Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"train data size: 1136428\n"
]
}
],
"source": [
"print('train data size: ', len(train_data))\n",
"\n",
"catboost_params = {\n",
" 'loss_function': 'MAE', # 90% 分位回归\n",
" 'iterations': 5000, # 训练轮数\n",
" 'learning_rate': 0.05, # 学习率,较低以防止过拟合\n",
" 'depth': 10, # 树的深度,防止过拟合\n",
" # 'l1_leaf_reg': 10.0, # l1 正则化,提高泛化能力\n",
" # 'bagging_temperature': 1, # 降低过拟合\n",
" # 'subsample': 0.8, # 每轮随机 80% 的样本,减少过拟合\n",
" 'colsample_bylevel': 0.8, # 每层 80% 特征子集,防止过拟合\n",
" 'random_seed': 42, # 固定随机种子,保证可复现\n",
" 'verbose': 500, # 每 100 轮打印一次信息\n",
" 'early_stopping_rounds': 100, # 早停,防止过拟合\n",
" # 'task_type': 'GPU'\n",
"}\n",
"\n",
"# model = train_catboost(train_data, test_data, feature_columns_new, catboost_params)"
]
},
{
"cell_type": "code",
"execution_count": 29,
"id": "465944b1d463e4b1",
"metadata": {
"ExecuteTime": {
"end_time": "2025-03-01T04:00:27.668543Z",
"start_time": "2025-02-27T16:41:14.863552Z"
}
},
"outputs": [],
"source": [
"from tqdm import tqdm\n",
"\n",
"def incremental_training(test_data: pd.DataFrame,\n",
" model,\n",
" days: int,\n",
" back_days: int,\n",
" feature_columns: list,\n",
" params: dict,\n",
" model_type: str = 'lightgbm'):\n",
" if model_type not in ['lightgbm', 'catboost']:\n",
" raise ValueError(\"model_type must be either 'lightgbm' or 'catboost'\")\n",
"\n",
" test_data = test_data.sort_values(by='trade_date')\n",
" scores = []\n",
" unique_trade_dates = sorted(test_data['trade_date'].unique())\n",
"\n",
" new_model = None\n",
" for i in tqdm(range(0, len(unique_trade_dates), days)):\n",
" # Get the current window of trade dates\n",
" current_dates = unique_trade_dates[i:i + days]\n",
" window_data = test_data[test_data['trade_date'].isin(current_dates)]\n",
" window_data = window_data.sort_values(by=['ts_code', 'trade_date'])\n",
" X = window_data[feature_columns]\n",
"\n",
" if new_model is not None:\n",
" window_scores = new_model.predict(X, prediction_type='RawFormulaVal')\n",
" else:\n",
" window_scores = model.predict(X, prediction_type='RawFormulaVal')\n",
" scores.extend(window_scores)\n",
"\n",
" # # Prepare data for incremental training\n",
" # current_dates = unique_trade_dates[max(0, i - back_days):i + days]\n",
" # window_data = test_data[test_data['trade_date'].isin(current_dates)]\n",
" # X_train = window_data[feature_columns]\n",
" window_data = window_data.dropna(subset=['label'])\n",
" X_train = window_data[feature_columns]\n",
" y_train = window_data['label'] # Assuming 'label' is what you're predicting\n",
" # Incrementally train the model\n",
" if len(y_train.unique()) > 1:\n",
" if model_type == 'lightgbm':\n",
" categorical_feature = [i for i, col in enumerate(feature_columns) if col.startswith('cat')]\n",
" train_data = lgb.Dataset(X_train, label=y_train, categorical_feature=categorical_feature)\n",
" new_model = lgb.train(params,\n",
" train_set=train_data,\n",
" num_boost_round=100,\n",
" init_model=model,\n",
" keep_training_booster=True)\n",
" elif model_type == 'catboost':\n",
" from catboost import Pool\n",
" train_data = Pool(data=X_train, label=y_train, cat_features=[col for col in feature_columns if col.startswith('cat')])\n",
" # model.set_params(**params)\n",
" model.fit(train_data, init_model=model)\n",
" else:\n",
" print(current_dates)\n",
"\n",
" # Add the scores as a new 'score' column to the test_data\n",
" test_data['score'] = scores\n",
" return test_data"
]
},
{
"cell_type": "code",
"execution_count": 32,
"id": "e3ac761d8f0b5d31",
"metadata": {
"ExecuteTime": {
"end_time": "2025-03-01T04:00:27.672656900Z",
"start_time": "2025-02-27T16:41:15.107891Z"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|████████████████████████████████████████████████████████████████████████████████| 104/104 [00:51<00:00, 2.03it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[Timestamp('2025-02-25 00:00:00'), Timestamp('2025-02-26 00:00:00')]\n"
]
}
],
"source": [
"\n",
"predictions_test = incremental_training(test_data, model, 5, 0, feature_columns_new, light_params, model_type='lightgbm')\n",
"predictions_test = predictions_test.loc[predictions_test.groupby('trade_date')['score'].idxmax()]\n",
"predictions_test[['trade_date', 'score', 'ts_code']].to_csv('predictions_test.tsv', index=False)\n"
]
},
{
"cell_type": "code",
"execution_count": 31,
"id": "b427ce41-9739-4e9e-bea8-5f2551fec5d7",
"metadata": {
"ExecuteTime": {
"end_time": "2025-03-01T04:00:27.672656900Z",
"start_time": "2025-02-27T16:42:45.229233Z"
},
"jupyter": {
"source_hidden": true
}
},
"outputs": [],
"source": [
"import joblib\n",
"import lightgbm as lgb\n",
"import pandas as pd\n",
"\n",
"# 假设你已经训练好了一个 LightGBM 模型\n",
"# model = lgb.train(params, train_data, ...)\n",
"\n",
"def save_model_with_info(model, params, feature_columns, train_data, info, save_path):\n",
" \"\"\"\n",
" 保存 LightGBM 模型及其相关信息。\n",
" \n",
" 参数:\n",
" model: 训练好的 LightGBM 模型 (lgb.Booster)。\n",
" params: 模型的参数 (dict)。\n",
" feature_columns: 特征列名列表 (list)。\n",
" train_data: 训练数据 (pd.DataFrame),包含 'trade_date' 列。\n",
" info: 额外信息 (str 或 dict)。\n",
" save_path: 保存路径 (str)。\n",
" \"\"\"\n",
" # 提取训练数据的 trade_date 的最大值和最小值\n",
" if 'trade_date' not in train_data.columns:\n",
" raise ValueError(\"训练数据中必须包含 'trade_date' 列。\")\n",
" \n",
" trade_date_min = train_data['trade_date'].min()\n",
" trade_date_max = train_data['trade_date'].max()\n",
" \n",
" # 构建保存的信息字典\n",
" model_info = {\n",
" 'model': model, # 保存模型本身\n",
" 'params': params, # 模型参数\n",
" 'feature_columns': feature_columns, # 特征列名\n",
" 'trade_date_range': {\n",
" 'min': trade_date_min,\n",
" 'max': trade_date_max\n",
" }, # trade_date 的范围\n",
" 'info': info # 额外信息\n",
" }\n",
" \n",
" # 使用 joblib 保存模型及相关信息\n",
" joblib.dump(model_info, save_path)\n",
" print(f\"模型及相关信息已成功保存到 {save_path}\")\n",
"\n",
" \n",
"\n",
"# info = \"Update Regression + 滚动new model + 5days\"\n",
" \n",
"# # 保存模型及相关信息\n",
"# save_path = \"../model/lightgbm_model_UpdateRegression_2025-2-25.pkl\"\n",
"# save_model_with_info(model, light_params, feature_columns, train_data, info, save_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8f9a2b7b-11fe-4eb5-aa11-c4066fe418a1",
"metadata": {
"ExecuteTime": {
"end_time": "2025-03-01T04:00:27.672656900Z",
"start_time": "2025-02-27T16:42:45.450227Z"
}
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.19"
}
},
"nbformat": 4,
"nbformat_minor": 5
}