1718 lines
85 KiB
Plaintext
1718 lines
85 KiB
Plaintext
{
|
||
"cells": [
|
||
{
|
||
"cell_type": "code",
|
||
"id": "79a7758178bafdd3",
|
||
"metadata": {
|
||
"jupyter": {
|
||
"source_hidden": true
|
||
},
|
||
"ExecuteTime": {
|
||
"end_time": "2025-03-19T14:44:14.217154Z",
|
||
"start_time": "2025-03-19T14:44:13.935883Z"
|
||
}
|
||
},
|
||
"source": [
|
||
"# %load_ext autoreload\n",
|
||
"# %autoreload 2\n",
|
||
"\n",
|
||
"import pandas as pd\n",
|
||
"import warnings\n",
|
||
"\n",
|
||
"warnings.filterwarnings(\"ignore\")\n",
|
||
"\n",
|
||
"pd.set_option('display.max_columns', None)\n"
|
||
],
|
||
"outputs": [],
|
||
"execution_count": 1
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"id": "a79cafb06a7e0e43",
|
||
"metadata": {
|
||
"scrolled": true,
|
||
"ExecuteTime": {
|
||
"end_time": "2025-03-19T14:45:19.808772Z",
|
||
"start_time": "2025-03-19T14:44:14.221159Z"
|
||
}
|
||
},
|
||
"source": [
|
||
"from code.utils.utils import read_and_merge_h5_data\n",
"\n",
"\n",
"def _load_base_frame(data_dir='../../data'):\n",
"    \"\"\"Read the per-stock daily tables and merge them on (ts_code, trade_date).\n",
"\n",
"    Each spec is (progress label, HDF5 key, columns, extra merge kwargs) and the file\n",
"    read is '<data_dir>/<key>.h5'. The first table seeds the frame (df=None),\n",
"    daily_basic is inner-joined, and the remaining tables use the default join of\n",
"    read_and_merge_h5_data.\n",
"    \"\"\"\n",
"    specs = [\n",
"        ('daily data', 'daily_data',\n",
"         ['ts_code', 'trade_date', 'open', 'close', 'high', 'low', 'vol', 'pct_chg'], {}),\n",
"        ('daily basic', 'daily_basic',\n",
"         ['ts_code', 'trade_date', 'turnover_rate', 'pe_ttm', 'circ_mv', 'volume_ratio',\n",
"          'is_st'], {'join': 'inner'}),\n",
"        ('stk limit', 'stk_limit',\n",
"         ['ts_code', 'trade_date', 'pre_close', 'up_limit', 'down_limit'], {}),\n",
"        ('money flow', 'money_flow',\n",
"         ['ts_code', 'trade_date', 'buy_sm_vol', 'sell_sm_vol', 'buy_lg_vol', 'sell_lg_vol',\n",
"          'buy_elg_vol', 'sell_elg_vol', 'net_mf_vol'], {}),\n",
"        ('cyq perf', 'cyq_perf',\n",
"         ['ts_code', 'trade_date', 'his_low', 'his_high', 'cost_5pct', 'cost_15pct',\n",
"          'cost_50pct', 'cost_85pct', 'cost_95pct', 'weight_avg', 'winner_rate'], {}),\n",
"    ]\n",
"    merged = None\n",
"    for label, key, columns, extra in specs:\n",
"        print(label)\n",
"        merged = read_and_merge_h5_data(f'{data_dir}/{key}.h5', key=key, columns=columns,\n",
"                                        df=merged, **extra)\n",
"    return merged\n",
"\n",
"\n",
"df = _load_base_frame()\n",
"print(df.info())"
|
||
],
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"daily data\n",
|
||
"daily basic\n",
|
||
"inner merge on ['ts_code', 'trade_date']\n",
|
||
"stk limit\n",
|
||
"left merge on ['ts_code', 'trade_date']\n",
|
||
"money flow\n",
|
||
"left merge on ['ts_code', 'trade_date']\n",
|
||
"cyq perf\n",
|
||
"left merge on ['ts_code', 'trade_date']\n",
|
||
"<class 'pandas.core.frame.DataFrame'>\n",
|
||
"RangeIndex: 8418207 entries, 0 to 8418206\n",
|
||
"Data columns (total 31 columns):\n",
|
||
" # Column Dtype \n",
|
||
"--- ------ ----- \n",
|
||
" 0 ts_code object \n",
|
||
" 1 trade_date datetime64[ns]\n",
|
||
" 2 open float64 \n",
|
||
" 3 close float64 \n",
|
||
" 4 high float64 \n",
|
||
" 5 low float64 \n",
|
||
" 6 vol float64 \n",
|
||
" 7 pct_chg float64 \n",
|
||
" 8 turnover_rate float64 \n",
|
||
" 9 pe_ttm float64 \n",
|
||
" 10 circ_mv float64 \n",
|
||
" 11 volume_ratio float64 \n",
|
||
" 12 is_st bool \n",
|
||
" 13 up_limit float64 \n",
|
||
" 14 down_limit float64 \n",
|
||
" 15 buy_sm_vol float64 \n",
|
||
" 16 sell_sm_vol float64 \n",
|
||
" 17 buy_lg_vol float64 \n",
|
||
" 18 sell_lg_vol float64 \n",
|
||
" 19 buy_elg_vol float64 \n",
|
||
" 20 sell_elg_vol float64 \n",
|
||
" 21 net_mf_vol float64 \n",
|
||
" 22 his_low float64 \n",
|
||
" 23 his_high float64 \n",
|
||
" 24 cost_5pct float64 \n",
|
||
" 25 cost_15pct float64 \n",
|
||
" 26 cost_50pct float64 \n",
|
||
" 27 cost_85pct float64 \n",
|
||
" 28 cost_95pct float64 \n",
|
||
" 29 weight_avg float64 \n",
|
||
" 30 winner_rate float64 \n",
|
||
"dtypes: bool(1), datetime64[ns](1), float64(28), object(1)\n",
|
||
"memory usage: 1.9+ GB\n",
|
||
"None\n"
|
||
]
|
||
}
|
||
],
|
||
"execution_count": 2
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"id": "cac01788dac10678",
|
||
"metadata": {
|
||
"jupyter": {
|
||
"source_hidden": true
|
||
},
|
||
"ExecuteTime": {
|
||
"end_time": "2025-03-19T14:45:35.459083Z",
|
||
"start_time": "2025-03-19T14:45:20.342424Z"
|
||
}
|
||
},
|
||
"source": [
|
||
"print('industry')\n",
"industry_df = read_and_merge_h5_data('../../data/industry_data.h5', key='industry_data',\n",
"                                     columns=['ts_code', 'l2_code', 'in_date'],\n",
"                                     df=None, on=['ts_code'], join='left')\n",
"\n",
"\n",
"def merge_with_industry_data(df, industry_df):\n",
"    \"\"\"Attach each stock's industry code (l2_code) as of every trade date.\n",
"\n",
"    Uses pd.merge_asof(direction='backward') by ts_code, so every row receives the\n",
"    industry record with the latest in_date <= trade_date. Rows dated before a\n",
"    stock's first in_date fall back to that stock's earliest known l2_code (first\n",
"    non-null value in in_date order).\n",
"\n",
"    The caller's frames are not modified. Returns a new frame sorted by\n",
"    (trade_date, ts_code) with a fresh RangeIndex and industry_df's extra columns\n",
"    (l2_code, in_date) appended.\n",
"    \"\"\"\n",
"    # Convert dates on local copies only; skip copying df when trade_date is already datetime.\n",
"    if not pd.api.types.is_datetime64_any_dtype(df['trade_date']):\n",
"        df = df.assign(trade_date=pd.to_datetime(df['trade_date']))\n",
"    industry = industry_df.assign(in_date=pd.to_datetime(industry_df['in_date']))\n",
"\n",
"    # merge_asof requires both sides sorted by the asof key (the dates), not by ts_code.\n",
"    left = df.sort_values(['trade_date', 'ts_code'])\n",
"    right = industry.sort_values(['in_date', 'ts_code'])\n",
"\n",
"    merged = pd.merge_asof(\n",
"        left,\n",
"        right,\n",
"        by='ts_code',\n",
"        left_on='trade_date',\n",
"        right_on='in_date',\n",
"        direction='backward',\n",
"    )\n",
"\n",
"    # Fill trade dates that precede every in_date of the stock with its earliest code.\n",
"    earliest_l2_code = right.groupby('ts_code')['l2_code'].first()\n",
"    merged['l2_code'] = merged['l2_code'].fillna(merged['ts_code'].map(earliest_l2_code))\n",
"\n",
"    return merged.reset_index(drop=True)\n",
"\n",
"\n",
"df = merge_with_industry_data(df, industry_df)"
|
||
],
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"industry\n"
|
||
]
|
||
}
|
||
],
|
||
"execution_count": 3
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"id": "c4e9e1d31da6dba6",
|
||
"metadata": {
|
||
"jupyter": {
|
||
"source_hidden": true
|
||
},
|
||
"ExecuteTime": {
|
||
"end_time": "2025-03-19T14:45:35.759821Z",
|
||
"start_time": "2025-03-19T14:45:35.492414Z"
|
||
}
|
||
},
|
||
"source": [
|
||
"def calculate_indicators(df):\n",
"    \"\"\"Compute technical and sentiment indicators for the rows of a single index.\n",
"\n",
"    Expects the columns trade_date, close, pre_close, vol and amount for one ts_code.\n",
"    Returns a copy sorted by trade_date with these columns added:\n",
"      daily_return                     (close - pre_close) / pre_close, in %\n",
"      RSI                              14-day RSI from simple rolling means of gains/losses\n",
"      MACD, Signal_line, MACD_hist     12/26-day EMA gap, its 9-day EMA, and their difference\n",
"      up_ratio, up_ratio_20d           up-day flag (1/0) and its 20-day mean\n",
"      volume_mean, volume_change_rate  20-day mean volume and % deviation from it\n",
"      volatility                       20-day std of daily_return\n",
"      amount_mean, amount_change_rate  20-day mean amount and % deviation from it\n",
"    \"\"\"\n",
"    df = df.sort_values('trade_date')\n",
"    df['daily_return'] = (df['close'] - df['pre_close']) / df['pre_close'] * 100\n",
"\n",
"    # RSI (simple-moving-average variant).\n",
"    delta = df['close'].diff()\n",
"    gain = delta.where(delta > 0, 0)\n",
"    loss = -delta.where(delta < 0, 0)\n",
"    avg_gain = gain.rolling(window=14).mean()\n",
"    avg_loss = loss.rolling(window=14).mean()\n",
"    rs = avg_gain / avg_loss\n",
"    df['RSI'] = 100 - (100 / (1 + rs))\n",
"\n",
"    # MACD.\n",
"    ema12 = df['close'].ewm(span=12, adjust=False).mean()\n",
"    ema26 = df['close'].ewm(span=26, adjust=False).mean()\n",
"    df['MACD'] = ema12 - ema26\n",
"    df['Signal_line'] = df['MACD'].ewm(span=9, adjust=False).mean()\n",
"    df['MACD_hist'] = df['MACD'] - df['Signal_line']\n",
"\n",
"    # Sentiment 1: share of up days over the last 20 days (NaN returns count as not up).\n",
"    df['up_ratio'] = (df['daily_return'] > 0).astype('int64')\n",
"    df['up_ratio_20d'] = df['up_ratio'].rolling(window=20).mean()\n",
"\n",
"    # Sentiment 2: volume relative to its 20-day mean, in %.\n",
"    df['volume_mean'] = df['vol'].rolling(window=20).mean()\n",
"    df['volume_change_rate'] = (df['vol'] - df['volume_mean']) / df['volume_mean'] * 100\n",
"\n",
"    # Sentiment 3: 20-day volatility of daily returns.\n",
"    df['volatility'] = df['daily_return'].rolling(window=20).std()\n",
"\n",
"    # Sentiment 4: traded amount relative to its 20-day mean, in %.\n",
"    df['amount_mean'] = df['amount'].rolling(window=20).mean()\n",
"    df['amount_change_rate'] = (df['amount'] - df['amount_mean']) / df['amount_mean'] * 100\n",
"\n",
"    return df\n",
"\n",
"\n",
"def generate_index_indicators(h5_filename, key='index_data'):\n",
"    \"\"\"Load index daily bars from HDF5 and return one wide row of indicators per trade_date.\n",
"\n",
"    calculate_indicators runs separately on every index (ts_code). The results are\n",
"    pivoted so that each output column is named '<ts_code>_<indicator>', with\n",
"    trade_date as a regular column.\n",
"    \"\"\"\n",
"    df = pd.read_hdf(h5_filename, key=key)\n",
"    df['trade_date'] = pd.to_datetime(df['trade_date'], format='%Y%m%d')\n",
"    df = df.sort_values('trade_date')\n",
"\n",
"    # One groupby pass (groups in order of first appearance, like .unique()) instead of\n",
"    # re-scanning the whole frame with a boolean mask for every index.\n",
"    df_all_indicators = pd.concat(\n",
"        [calculate_indicators(group.copy()) for _, group in df.groupby('ts_code', sort=False)],\n",
"        ignore_index=True,\n",
"    )\n",
"\n",
"    # One row per trade_date, one column per (indicator, ts_code) pair.\n",
"    df_final = df_all_indicators.pivot_table(\n",
"        index='trade_date',\n",
"        columns='ts_code',\n",
"        values=['daily_return', 'RSI', 'MACD', 'Signal_line',\n",
"                'MACD_hist', 'up_ratio_20d', 'volume_change_rate', 'volatility',\n",
"                'amount_change_rate', 'amount_mean'],\n",
"        aggfunc='last',\n",
"    )\n",
"    df_final.columns = [f'{ts_code}_{indicator}' for indicator, ts_code in df_final.columns]\n",
"    return df_final.reset_index()\n",
"\n",
"\n",
"# Build the wide index-indicator table. dropna() removes every trade_date on which any\n",
"# index column is NaN (e.g. rolling-window warm-up rows).\n",
"h5_filename = '../../data/index_data.h5'\n",
"index_data = generate_index_indicators(h5_filename)\n",
"index_data = index_data.dropna()\n"
|
||
],
|
||
"outputs": [],
|
||
"execution_count": 4
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"id": "a735bc02ceb4d872",
|
||
"metadata": {
|
||
"jupyter": {
|
||
"source_hidden": true
|
||
},
|
||
"ExecuteTime": {
|
||
"end_time": "2025-03-19T14:45:35.904531Z",
|
||
"start_time": "2025-03-19T14:45:35.790296Z"
|
||
}
|
||
},
|
||
"source": [
|
||
"import numpy as np\n",
|
||
"import talib\n",
|
||
"\n",
|
||
"\n",
|
||
"def get_rolling_factor(df):\n",
"    \"\"\"Compute per-stock rolling / time-series factors.\n",
"\n",
"    Sorts by (ts_code, trade_date) and adds factor columns computed within each\n",
"    stock's own history; cross-sectional ranks (alpha_007, alpha_013, _rank_stddev)\n",
"    are taken per trade_date. Every derived Series is assigned back by index label,\n",
"    so the results do not depend on the incoming row order.\n",
"\n",
"    Returns:\n",
"        (df, new_columns): the sorted frame with the factors added, and the names of\n",
"        the columns that were not present on input.\n",
"    \"\"\"\n",
"    old_columns = df.columns.tolist()\n",
"    df = df.sort_values(by=['ts_code', 'trade_date'])\n",
"    # group_keys=False: apply()/transform() results keep the original row labels and are\n",
"    # assigned straight back. groupby().rolling() results carry a leading ts_code level\n",
"    # instead, which is why only those are followed by reset_index(level=0, drop=True).\n",
"    grouped = df.groupby('ts_code', group_keys=False)\n",
"\n",
"    df['return_skew'] = grouped['pct_chg'].rolling(window=5).skew().reset_index(0, drop=True)\n",
"    df['return_kurtosis'] = grouped['pct_chg'].rolling(window=5).kurt().reset_index(0, drop=True)\n",
"\n",
"    # Factor 1: short-term volume change rate (2-day vs. 10-day mean volume).\n",
"    df['volume_change_rate'] = (\n",
"        grouped['vol'].rolling(window=2).mean() /\n",
"        grouped['vol'].rolling(window=10).mean() - 1\n",
"    ).reset_index(level=0, drop=True)\n",
"\n",
"    # Factor 2: volume breakout -- today's volume above the max of the previous 5 days.\n",
"    # A window that includes today can never be strictly exceeded, so shift first.\n",
"    prev_max_volume = grouped['vol'].transform(lambda x: x.shift(1).rolling(window=5).max())\n",
"    df['cat_volume_breakout'] = (df['vol'] > prev_max_volume)\n",
"\n",
"    # Factor 3: turnover deviation -- z-score of turnover within its 3-day window.\n",
"    mean_turnover = grouped['turnover_rate'].rolling(window=3).mean().reset_index(level=0, drop=True)\n",
"    std_turnover = grouped['turnover_rate'].rolling(window=3).std().reset_index(level=0, drop=True)\n",
"    df['turnover_deviation'] = (df['turnover_rate'] - mean_turnover) / std_turnover\n",
"\n",
"    # Factor 4: turnover spike -- above mean + 2 * std of the previous 3 days. Inside a\n",
"    # 3-point window that contains today the z-score is at most 2/sqrt(3) ~ 1.15, so the\n",
"    # threshold has to come from the prior window.\n",
"    prev_mean_turnover = grouped['turnover_rate'].transform(lambda x: x.shift(1).rolling(window=3).mean())\n",
"    prev_std_turnover = grouped['turnover_rate'].transform(lambda x: x.shift(1).rolling(window=3).std())\n",
"    df['cat_turnover_spike'] = (df['turnover_rate'] > prev_mean_turnover + 2 * prev_std_turnover)\n",
"\n",
"    # Factor 5: 3-day mean volume ratio.\n",
"    df['avg_volume_ratio'] = grouped['volume_ratio'].rolling(window=3).mean().reset_index(level=0, drop=True)\n",
"\n",
"    # Factor 6: volume-ratio breakout above the max of the previous 5 days.\n",
"    prev_max_volume_ratio = grouped['volume_ratio'].transform(lambda x: x.shift(1).rolling(window=5).max())\n",
"    df['cat_volume_ratio_breakout'] = (df['volume_ratio'] > prev_max_volume_ratio)\n",
"\n",
"    # NOTE(review): despite its name, vol_spike is the 20-day mean volume.\n",
"    df['vol_spike'] = grouped['vol'].transform(lambda x: x.rolling(20).mean())\n",
"    # Per stock, so a stock's first rows do not mix in the previous stock's volumes.\n",
"    df['vol_std_5'] = grouped['vol'].transform(lambda x: x.pct_change().rolling(5).std())\n",
"\n",
"    # OBV and its 6-day SMA.\n",
"    df['obv'] = grouped.apply(\n",
"        lambda x: pd.Series(talib.OBV(x['close'].values, x['vol'].values), index=x.index)\n",
"    )\n",
"    df['maobv_6'] = grouped.apply(\n",
"        lambda x: pd.Series(talib.SMA(x['obv'].values, timeperiod=6), index=x.index)\n",
"    )\n",
"\n",
"    # Simple close-to-close returns over 5 / 10 / 20 days.\n",
"    df['return_5'] = grouped['close'].apply(lambda x: x / x.shift(5) - 1)\n",
"    df['return_10'] = grouped['close'].apply(lambda x: x / x.shift(10) - 1)\n",
"    df['return_20'] = grouped['close'].apply(lambda x: x / x.shift(20) - 1)\n",
"\n",
"    # Rolling std of daily close returns; std_return_90_2 uses closes lagged by 10 days.\n",
"    df['std_return_5'] = grouped['close'].apply(lambda x: x.pct_change().rolling(window=5).std())\n",
"    df['std_return_15'] = grouped['close'].apply(lambda x: x.pct_change().rolling(window=15).std())\n",
"    df['std_return_25'] = grouped['close'].apply(lambda x: x.pct_change().rolling(window=25).std())\n",
"    df['std_return_90'] = grouped['close'].apply(lambda x: x.pct_change().rolling(window=90).std())\n",
"    df['std_return_90_2'] = grouped['close'].apply(lambda x: x.shift(10).pct_change().rolling(window=90).std())\n",
"\n",
"    # EMAs of close; the '_' prefix marks helper columns that later steps drop.\n",
"    for span in (5, 13, 20, 60):\n",
"        df[f'_ema_{span}'] = grouped['close'].apply(\n",
"            lambda x, span=span: pd.Series(talib.EMA(x.values, timeperiod=span), index=x.index)\n",
"        )\n",
"\n",
"    # act_factor1-4: arctan of each EMA's daily % change, converted to degrees\n",
"    # (57.3 ~ 180/pi) and divided by a per-EMA scale.\n",
"    for idx, (span, scale) in enumerate(((5, 50), (13, 40), (20, 21), (60, 10)), start=1):\n",
"        df[f'act_factor{idx}'] = grouped[f'_ema_{span}'].apply(\n",
"            lambda x, scale=scale: np.arctan((x / x.shift(1) - 1) * 100) * 57.3 / scale\n",
"        )\n",
"\n",
"    df['log(circ_mv)'] = np.log(df['circ_mv'])\n",
"\n",
"    # alpha_22 (improved): -delta(rolling cov(high, vol, 5), 5) * rank(20-day close std).\n",
"    # Rolling and diff steps stay within one stock; the rank is cross-sectional per\n",
"    # trade_date, so it uses no future data.\n",
"    window_high_volume = 5\n",
"    window_close_stddev = 20\n",
"    period_delta = 5\n",
"    df['cov'] = grouped.apply(lambda x: x['high'].rolling(window_high_volume).cov(x['vol']))\n",
"    df['delta_cov'] = grouped['cov'].diff(period_delta)\n",
"    close_std = grouped['close'].transform(lambda x: x.rolling(window_close_stddev).std())\n",
"    df['_rank_stddev'] = close_std.groupby(df['trade_date']).rank(pct=True)\n",
"    df['alpha_22_improved'] = -1 * df['delta_cov'] * df['_rank_stddev']\n",
"\n",
"    df['alpha_003'] = np.where(df['high'] != df['low'],\n",
"                               (df['close'] - df['open']) / (df['high'] - df['low']),\n",
"                               0)\n",
"\n",
"    # 5-day close/volume correlation per stock, then ranked per trade_date. The apply()\n",
"    # result already carries df's labels; resetting its index would misalign rows.\n",
"    df['alpha_007'] = grouped.apply(lambda x: x['close'].rolling(5).corr(x['vol']))\n",
"    df['alpha_007'] = df.groupby('trade_date', group_keys=False)['alpha_007'].rank(ascending=True, pct=True)\n",
"\n",
"    df['alpha_013'] = grouped['close'].transform(lambda x: x.rolling(5).sum() - x.rolling(20).sum())\n",
"    df['alpha_013'] = df.groupby('trade_date', group_keys=False)['alpha_013'].rank(ascending=True, pct=True)\n",
"\n",
"    # Limit-up / limit-down flags (close equals the limit price) and their 10-day counts.\n",
"    df['cat_up_limit'] = (df['close'] == df['up_limit'])\n",
"    df['cat_down_limit'] = (df['close'] == df['down_limit'])\n",
"    df['up_limit_count_10d'] = grouped['cat_up_limit'].rolling(window=10, min_periods=1).sum().reset_index(level=0,\n",
"                                                                                                          drop=True)\n",
"    df['down_limit_count_10d'] = grouped['cat_down_limit'].rolling(window=10, min_periods=1).sum().reset_index(level=0,\n",
"                                                                                                              drop=True)\n",
"\n",
"    def _run_length(flags):\n",
"        \"\"\"Length of the current run of True values (0 on rows where the flag is False).\"\"\"\n",
"        run_id = (flags != flags.shift()).cumsum()\n",
"        return flags * (flags.groupby(run_id).cumcount() + 1)\n",
"\n",
"    # Consecutive limit-up days; transform keeps df's row labels.\n",
"    df['consecutive_up_limit'] = grouped['cat_up_limit'].transform(_run_length)\n",
"\n",
"    df['vol_break'] = np.where((df['close'] > df['cost_85pct']) & (df['volume_ratio'] > 2), 1, 0)\n",
"\n",
"    df['weight_roc5'] = grouped['weight_avg'].apply(lambda x: x.pct_change(5))\n",
"\n",
"    def rolling_corr(group):\n",
"        \"\"\"10-day correlation between daily % changes of close and weight_avg.\"\"\"\n",
"        roc_close = group['close'].pct_change()\n",
"        roc_weight = group['weight_avg'].pct_change()\n",
"        return roc_close.rolling(10).corr(roc_weight)\n",
"\n",
"    df['price_cost_divergence'] = grouped.apply(rolling_corr)\n",
"\n",
"    df['smallcap_concentration'] = (1 / df['circ_mv']) * (df['cost_85pct'] - df['cost_15pct'])\n",
"\n",
"    # Chip stability: 20-day std of weight_avg relative to its 20-day mean.\n",
"    df['weight_std20'] = grouped['weight_avg'].apply(lambda x: x.rolling(20).std())\n",
"    df['cost_stability'] = df['weight_std20'] / grouped['weight_avg'].transform(lambda x: x.rolling(20).mean())\n",
"\n",
"    # Number of the last 5 days with close above cost_95pct.\n",
"    df['high_cost_break_days'] = grouped.apply(lambda g: g['close'].gt(g['cost_95pct']).rolling(5).sum())\n",
"\n",
"    # Chip / liquidity risk: cost_95pct - cost_5pct spread over 10-day mean volume.\n",
"    df['liquidity_risk'] = (df['cost_95pct'] - df['cost_5pct']) * (\n",
"        1 / grouped['vol'].transform(lambda x: x.rolling(10).mean()))\n",
"\n",
"    # Market-cap volatility: 20-day turnover std scaled by circ_mv (a row-wise ratio, so\n",
"    # no groupby is needed and alignment is by label).\n",
"    df['turnover_std'] = grouped['turnover_rate'].rolling(window=20).std().reset_index(level=0, drop=True)\n",
"    df['mv_volatility'] = df['turnover_std'] / df['circ_mv']\n",
"\n",
"    # Market-cap growth: 20-day volume growth scaled by circ_mv.\n",
"    df['volume_growth'] = grouped['vol'].pct_change(periods=20)\n",
"    df['mv_growth'] = df['volume_growth'] / df['circ_mv']\n",
"\n",
"    df.drop(columns=['weight_std20'], inplace=True, errors='ignore')\n",
"    new_columns = [col for col in df.columns if col not in old_columns]\n",
"\n",
"    return df, new_columns\n",
|
||
"\n",
|
||
"\n",
|
||
"def get_simple_factor(df):\n",
"    \"\"\"Compute same-day (row-wise) factors from raw and rolling columns.\n",
"\n",
"    Requires volume_change_rate, turnover_deviation, std_return_5/25/90/90_2 and\n",
"    act_factor1-4 (as produced by get_rolling_factor) plus the raw price, money-flow\n",
"    and chip-distribution columns. If 'atr_14' is absent it is computed per stock with\n",
"    talib.ATR(high, low, close, timeperiod=14) and used without being stored.\n",
"    Ratios whose denominator is 0 yield NaN instead of +/-inf. Columns starting with\n",
"    '_' are dropped at the end.\n",
"\n",
"    Returns:\n",
"        (df, new_columns): the frame sorted by (ts_code, trade_date) and the names of\n",
"        the columns that were not present on input.\n",
"    \"\"\"\n",
"    old_columns = df.columns.tolist()\n",
"    df = df.sort_values(by=['ts_code', 'trade_date'])\n",
"\n",
"    def _safe_div(num, den):\n",
"        # Zero denominators (zero net flow, flat chip range, zero std) become NaN, not inf.\n",
"        return num / den.where(den != 0)\n",
"\n",
"    if 'atr_14' in df.columns:\n",
"        atr_14 = df['atr_14']\n",
"    else:\n",
"        atr_14 = df.groupby('ts_code', group_keys=False).apply(\n",
"            lambda x: pd.Series(\n",
"                talib.ATR(x['high'].values, x['low'].values, x['close'].values, timeperiod=14),\n",
"                index=x.index,\n",
"            )\n",
"        )\n",
"\n",
"    alpha = 0.5\n",
"    df['momentum_factor'] = df['volume_change_rate'] + alpha * df['turnover_deviation']\n",
"    df['resonance_factor'] = df['volume_ratio'] * df['pct_chg']\n",
"    df['log_close'] = np.log(df['close'])\n",
"\n",
"    # Upper / lower shadow relative to close.\n",
"    df['up'] = (df['high'] - df[['close', 'open']].max(axis=1)) / df['close']\n",
"    df['down'] = (df[['close', 'open']].min(axis=1) - df['low']) / df['close']\n",
"\n",
"    # Short- vs. long-horizon volatility ratios and difference.\n",
"    df['std_return_5 / std_return_90'] = _safe_div(df['std_return_5'], df['std_return_90'])\n",
"    df['std_return_5 / std_return_25'] = _safe_div(df['std_return_5'], df['std_return_25'])\n",
"    df['std_return_90 - std_return_90_2'] = df['std_return_90'] - df['std_return_90_2']\n",
"\n",
"    df['cat_af1'] = df['act_factor1'] > 0\n",
"    df['cat_af2'] = df['act_factor2'] > df['act_factor1']\n",
"    df['cat_af3'] = df['act_factor3'] > df['act_factor2']\n",
"    df['cat_af4'] = df['act_factor4'] > df['act_factor3']\n",
"\n",
"    # act_factor5: sum of act_factor1-4; act_factor6: normalised act_factor1 - act_factor2.\n",
"    df['act_factor5'] = df['act_factor1'] + df['act_factor2'] + df['act_factor3'] + df['act_factor4']\n",
"    df['act_factor6'] = (df['act_factor1'] - df['act_factor2']) / np.sqrt(\n",
"        df['act_factor1'] ** 2 + df['act_factor2'] ** 2)\n",
"\n",
"    # Money-flow volumes relative to net money-flow volume.\n",
"    df['active_buy_volume_large'] = _safe_div(df['buy_lg_vol'], df['net_mf_vol'])\n",
"    df['active_buy_volume_big'] = _safe_div(df['buy_elg_vol'], df['net_mf_vol'])\n",
"    df['active_buy_volume_small'] = _safe_div(df['buy_sm_vol'], df['net_mf_vol'])\n",
"    df['buy_lg_vol_minus_sell_lg_vol'] = _safe_div(df['buy_lg_vol'] - df['sell_lg_vol'], df['net_mf_vol'])\n",
"    df['buy_elg_vol_minus_sell_elg_vol'] = _safe_div(df['buy_elg_vol'] - df['sell_elg_vol'], df['net_mf_vol'])\n",
"\n",
"    df['log(circ_mv)'] = np.log(df['circ_mv'])\n",
"\n",
"    # Chip-distribution shape factors.\n",
"    his_range = df['his_high'] - df['his_low']\n",
"    df['ctrl_strength'] = _safe_div(df['cost_85pct'] - df['cost_15pct'], his_range)\n",
"    df['low_cost_dev'] = _safe_div(df['close'] - df['cost_5pct'], df['cost_50pct'] - df['cost_5pct'])\n",
"    df['asymmetry'] = _safe_div(df['cost_95pct'] - df['cost_50pct'], df['cost_50pct'] - df['cost_5pct'])\n",
"    df['lock_factor'] = df['turnover_rate'] * (\n",
"        1 - _safe_div(df['cost_95pct'] - df['cost_5pct'], his_range))\n",
"    df['cat_vol_break'] = (df['close'] > df['cost_85pct']) & (df['volume_ratio'] > 2)\n",
"    df['cost_atr_adj'] = _safe_div(df['cost_95pct'] - df['cost_5pct'], atr_14)\n",
"\n",
"    # Small-cap chip concentration.\n",
"    df['smallcap_concentration'] = (1 / df['circ_mv']) * (df['cost_85pct'] - df['cost_15pct'])\n",
"\n",
"    df['cat_golden_resonance'] = ((df['close'] > df['weight_avg']) &\n",
"                                  (df['volume_ratio'] > 1.5) &\n",
"                                  (df['winner_rate'] > 0.7))\n",
"\n",
"    # Market-cap scaled factors.\n",
"    # NOTE(review): mv_turnover_ratio == mv_weighted_turnover and\n",
"    # mv_adjusted_volume == nonlinear_mv_volume (same formulas); consider dropping one of each.\n",
"    df['mv_turnover_ratio'] = df['turnover_rate'] / df['circ_mv']\n",
"    df['mv_adjusted_volume'] = df['vol'] / df['circ_mv']\n",
"    df['mv_weighted_turnover'] = df['turnover_rate'] * (1 / df['circ_mv'])\n",
"    df['nonlinear_mv_volume'] = df['vol'] / df['circ_mv']\n",
"    df['mv_volume_ratio'] = df['volume_ratio'] / df['circ_mv']\n",
"    df['mv_momentum'] = df['turnover_rate'] * df['volume_ratio'] / df['circ_mv']\n",
"\n",
"    drop_columns = [col for col in df.columns if col.startswith('_')]\n",
"    df.drop(columns=drop_columns, inplace=True, errors='ignore')\n",
"\n",
"    new_columns = [col for col in df.columns if col not in old_columns]\n",
"    return df, new_columns\n"
|
||
],
|
||
"outputs": [],
|
||
"execution_count": 5
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"id": "53f86ddc0677a6d7",
|
||
"metadata": {
|
||
"jupyter": {
|
||
"source_hidden": true
|
||
},
|
||
"scrolled": true,
|
||
"ExecuteTime": {
|
||
"end_time": "2025-03-19T14:45:43.390734Z",
|
||
"start_time": "2025-03-19T14:45:35.933466Z"
|
||
}
|
||
},
|
||
"source": [
|
||
"from code.utils.factor import get_act_factor\n",
"\n",
"\n",
"def read_industry_data(h5_filename, key='sw_daily'):\n",
"    \"\"\"Load industry index daily bars and build industry-level factors.\n",
"\n",
"    Reads ts_code, trade_date, OHLC, pe, pb and vol from `key`. Per industry\n",
"    (ts_code) it computes OBV, return_5, return_20 and the factors added by\n",
"    get_act_factor(cat=False), plus the daily cross-sectional percentile of return_5 and\n",
"    return_20. It drops the raw price / valuation / volume columns, prefixes the\n",
"    remaining factor columns with 'industry_', and renames ts_code to cat_l2_code.\n",
"    The result is sorted by (trade_date, cat_l2_code).\n",
"    \"\"\"\n",
"    industry_data = pd.read_hdf(h5_filename, key=key, columns=[\n",
"        'ts_code', 'trade_date', 'open', 'close', 'high', 'low', 'pe', 'pb', 'vol'\n",
"    ])\n",
"    industry_data['trade_date'] = pd.to_datetime(industry_data['trade_date'], format='%Y%m%d')\n",
"    # A fresh index in (ts_code, trade_date) order; the old `.reindex()` call was a no-op.\n",
"    industry_data = industry_data.sort_values(by=['ts_code', 'trade_date']).reset_index(drop=True)\n",
"\n",
"    grouped = industry_data.groupby('ts_code', group_keys=False)\n",
"    industry_data['obv'] = grouped.apply(\n",
"        lambda x: pd.Series(talib.OBV(x['close'].values, x['vol'].values), index=x.index)\n",
"    )\n",
"    industry_data['return_5'] = grouped['close'].apply(lambda x: x / x.shift(5) - 1)\n",
"    industry_data['return_20'] = grouped['close'].apply(lambda x: x / x.shift(20) - 1)\n",
"\n",
"    industry_data = get_act_factor(industry_data, cat=False)\n",
"    industry_data = industry_data.sort_values(by=['trade_date', 'ts_code'])\n",
"\n",
"    # Daily percentile rank across industries.\n",
"    by_date = industry_data.groupby('trade_date')\n",
"    industry_data['return_5_percentile'] = by_date['return_5'].rank(pct=True)\n",
"    industry_data['return_20_percentile'] = by_date['return_20'].rank(pct=True)\n",
"    industry_data = industry_data.drop(columns=['open', 'close', 'high', 'low', 'pe', 'pb', 'vol'])\n",
"\n",
"    industry_data = industry_data.rename(\n",
"        columns={col: f'industry_{col}' for col in industry_data.columns if col not in ['ts_code', 'trade_date']})\n",
"    industry_data = industry_data.rename(columns={'ts_code': 'cat_l2_code'})\n",
"    return industry_data\n",
"\n",
"\n",
"industry_df = read_industry_data('../../data/sw_daily.h5')\n"
|
||
],
|
||
"outputs": [],
|
||
"execution_count": 6
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"id": "dbe2fd8021b9417f",
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2025-03-19T14:45:43.435718Z",
|
||
"start_time": "2025-03-19T14:45:43.429472Z"
|
||
}
|
||
},
|
||
"source": [
|
||
"# df's columns minus the listed ones, minus every column also present in index_data\n",
"# (this removes trade_date), minus any column whose name contains 'cyq'.\n",
"origin_columns = [\n",
"    col for col in df.columns.tolist()\n",
"    if col not in {'turnover_rate', 'pe_ttm', 'volume_ratio', 'vol', 'pct_chg', 'l2_code', 'winner_rate'}\n",
"    and col not in index_data.columns\n",
"    and 'cyq' not in col\n",
"]\n",
"print(origin_columns)"
|
||
],
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"['ts_code', 'open', 'close', 'high', 'low', 'circ_mv', 'is_st', 'up_limit', 'down_limit', 'buy_sm_vol', 'sell_sm_vol', 'buy_lg_vol', 'sell_lg_vol', 'buy_elg_vol', 'sell_elg_vol', 'net_mf_vol', 'his_low', 'his_high', 'cost_5pct', 'cost_15pct', 'cost_50pct', 'cost_85pct', 'cost_95pct', 'weight_avg', 'in_date']\n"
|
||
]
|
||
}
|
||
],
|
||
"execution_count": 7
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"id": "85c3e3d0235ffffa",
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2025-03-19T14:45:43.684712Z",
|
||
"start_time": "2025-03-19T14:45:43.492971Z"
|
||
}
|
||
},
|
||
"source": [
|
||
"# Preview the rows flagged as ST.\n",
"print(df.loc[df['is_st'], ['ts_code', 'trade_date', 'is_st']])"
|
||
],
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" ts_code trade_date is_st\n",
|
||
"29 000037.SZ 2017-01-03 True\n",
|
||
"72 000408.SZ 2017-01-03 True\n",
|
||
"95 000504.SZ 2017-01-03 True\n",
|
||
"96 000505.SZ 2017-01-03 True\n",
|
||
"101 000511.SZ 2017-01-03 True\n",
|
||
"... ... ... ...\n",
|
||
"8417201 603869.SH 2025-03-13 True\n",
|
||
"8417206 603879.SH 2025-03-13 True\n",
|
||
"8417253 603959.SH 2025-03-13 True\n",
|
||
"8417635 688282.SH 2025-03-13 True\n",
|
||
"8417639 688287.SH 2025-03-13 True\n",
|
||
"\n",
|
||
"[191519 rows x 3 columns]\n"
|
||
]
|
||
}
|
||
],
|
||
"execution_count": 8
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"id": "92d84ce15a562ec6",
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2025-03-19T14:47:42.395141Z",
|
||
"start_time": "2025-03-19T14:45:43.718152Z"
|
||
}
|
||
},
|
||
"source": [
|
||
"def filter_data(df, start_date='20180101'):\n",
"    \"\"\"Restrict df to the tradable universe and drop helper columns.\n",
"\n",
"    Removes rows where is_st is True, codes ending in 'BJ', and codes starting with\n",
"    '30', '68' or '8'; keeps rows with trade_date >= start_date; drops columns whose\n",
"    name starts with '_'. Returns a new frame with a fresh RangeIndex.\n",
"    \"\"\"\n",
"    codes = df['ts_code'].str\n",
"    # A single combined mask avoids copying the multi-GB frame once per filter.\n",
"    keep = (\n",
"        ~df['is_st']\n",
"        & ~codes.endswith('BJ')\n",
"        & ~codes.startswith(('30', '68', '8'))\n",
"        & (df['trade_date'] >= start_date)\n",
"    )\n",
"    df = df.loc[keep]\n",
"    df = df.drop(columns=[col for col in df.columns if col.startswith('_')])\n",
"    return df.reset_index(drop=True)\n",
"\n",
"\n",
"df = filter_data(df)\n",
"# df = get_technical_factor(df)\n",
"# df = get_act_factor(df)\n",
"# df = get_money_flow_factor(df)\n",
"# df = get_alpha_factor(df)\n",
"# df = get_limit_factor(df)\n",
"# df = get_cyp_perf_factor(df)\n",
"# df = get_mv_factors(df)\n",
"df, _ = get_rolling_factor(df)\n",
"\n",
"# df = df.merge(industry_df, on=['l2_code', 'trade_date'], how='left')\n",
"df = df.rename(columns={'l2_code': 'cat_l2_code'})\n",
"# df = df.merge(index_data, on='trade_date', how='left')\n",
"\n",
"print(df.info())"
|
||
],
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"<class 'pandas.core.frame.DataFrame'>\n",
|
||
"Index: 5084263 entries, 0 to 5084262\n",
|
||
"Data columns (total 85 columns):\n",
|
||
" # Column Dtype \n",
|
||
"--- ------ ----- \n",
|
||
" 0 ts_code object \n",
|
||
" 1 trade_date datetime64[ns]\n",
|
||
" 2 open float64 \n",
|
||
" 3 close float64 \n",
|
||
" 4 high float64 \n",
|
||
" 5 low float64 \n",
|
||
" 6 vol float64 \n",
|
||
" 7 pct_chg float64 \n",
|
||
" 8 turnover_rate float64 \n",
|
||
" 9 pe_ttm float64 \n",
|
||
" 10 circ_mv float64 \n",
|
||
" 11 volume_ratio float64 \n",
|
||
" 12 is_st bool \n",
|
||
" 13 up_limit float64 \n",
|
||
" 14 down_limit float64 \n",
|
||
" 15 buy_sm_vol float64 \n",
|
||
" 16 sell_sm_vol float64 \n",
|
||
" 17 buy_lg_vol float64 \n",
|
||
" 18 sell_lg_vol float64 \n",
|
||
" 19 buy_elg_vol float64 \n",
|
||
" 20 sell_elg_vol float64 \n",
|
||
" 21 net_mf_vol float64 \n",
|
||
" 22 his_low float64 \n",
|
||
" 23 his_high float64 \n",
|
||
" 24 cost_5pct float64 \n",
|
||
" 25 cost_15pct float64 \n",
|
||
" 26 cost_50pct float64 \n",
|
||
" 27 cost_85pct float64 \n",
|
||
" 28 cost_95pct float64 \n",
|
||
" 29 weight_avg float64 \n",
|
||
" 30 winner_rate float64 \n",
|
||
" 31 cat_l2_code object \n",
|
||
" 32 in_date datetime64[ns]\n",
|
||
" 33 return_skew float64 \n",
|
||
" 34 return_kurtosis float64 \n",
|
||
" 35 volume_change_rate float64 \n",
|
||
" 36 cat_volume_breakout bool \n",
|
||
" 37 turnover_deviation float64 \n",
|
||
" 38 cat_turnover_spike bool \n",
|
||
" 39 avg_volume_ratio float64 \n",
|
||
" 40 cat_volume_ratio_breakout bool \n",
|
||
" 41 vol_spike float64 \n",
|
||
" 42 vol_std_5 float64 \n",
|
||
" 43 obv float64 \n",
|
||
" 44 maobv_6 float64 \n",
|
||
" 45 return_5 float64 \n",
|
||
" 46 return_10 float64 \n",
|
||
" 47 return_20 float64 \n",
|
||
" 48 std_return_5 float64 \n",
|
||
" 49 std_return_15 float64 \n",
|
||
" 50 std_return_25 float64 \n",
|
||
" 51 std_return_90 float64 \n",
|
||
" 52 std_return_90_2 float64 \n",
|
||
" 53 _ema_5 float64 \n",
|
||
" 54 _ema_13 float64 \n",
|
||
" 55 _ema_20 float64 \n",
|
||
" 56 _ema_60 float64 \n",
|
||
" 57 act_factor1 float64 \n",
|
||
" 58 act_factor2 float64 \n",
|
||
" 59 act_factor3 float64 \n",
|
||
" 60 act_factor4 float64 \n",
|
||
" 61 log(circ_mv) float64 \n",
|
||
" 62 cov float64 \n",
|
||
" 63 delta_cov float64 \n",
|
||
" 64 _rank_stddev float64 \n",
|
||
" 65 alpha_22_improved float64 \n",
|
||
" 66 alpha_003 float64 \n",
|
||
" 67 alpha_007 float64 \n",
|
||
" 68 alpha_013 float64 \n",
|
||
" 69 cat_up_limit bool \n",
|
||
" 70 cat_down_limit bool \n",
|
||
" 71 up_limit_count_10d float64 \n",
|
||
" 72 down_limit_count_10d float64 \n",
|
||
" 73 consecutive_up_limit int64 \n",
|
||
" 74 vol_break int32 \n",
|
||
" 75 weight_roc5 float64 \n",
|
||
" 76 price_cost_divergence float64 \n",
|
||
" 77 smallcap_concentration float64 \n",
|
||
" 78 cost_stability float64 \n",
|
||
" 79 high_cost_break_days float64 \n",
|
||
" 80 liquidity_risk float64 \n",
|
||
" 81 turnover_std float64 \n",
|
||
" 82 mv_volatility float64 \n",
|
||
" 83 volume_growth float64 \n",
|
||
" 84 mv_growth float64 \n",
|
||
"dtypes: bool(6), datetime64[ns](2), float64(73), int32(1), int64(1), object(2)\n",
|
||
"memory usage: 3.0+ GB\n",
|
||
"None\n"
|
||
]
|
||
}
|
||
],
|
||
"execution_count": 9
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"id": "f4f16d63ad18d1bc",
|
||
"metadata": {
|
||
"jupyter": {
|
||
"source_hidden": true
|
||
},
|
||
"ExecuteTime": {
|
||
"end_time": "2025-03-19T14:47:42.645061Z",
|
||
"start_time": "2025-03-19T14:47:42.637235Z"
|
||
}
|
||
},
|
||
"source": [
|
||
"def create_deviation_within_dates(df, feature_columns):\n",
|
||
" groupby_col = 'cat_l2_code' # 使用 trade_date 进行分组\n",
|
||
" new_columns = {}\n",
|
||
" ret_feature_columns = feature_columns[:]\n",
|
||
"\n",
|
||
" # 自动选择所有数值型特征\n",
|
||
" num_features = [col for col in feature_columns if 'cat' not in col and 'index' not in col]\n",
|
||
"\n",
|
||
" # num_features = ['vol', 'pct_chg', 'turnover_rate', 'volume_ratio', 'cat_vol_spike', 'obv', 'maobv_6', 'return_5', 'return_10', 'return_20', 'std_return_5', 'std_return_15', 'std_return_90', 'std_return_90_2', 'act_factor1', 'act_factor2', 'act_factor3', 'act_factor4', 'act_factor5', 'act_factor6', 'rank_act_factor1', 'rank_act_factor2', 'rank_act_factor3', 'active_buy_volume_large', 'active_buy_volume_big', 'active_buy_volume_small', 'alpha_022', 'alpha_003', 'alpha_007', 'alpha_013']\n",
|
||
" num_features = [col for col in num_features if 'cat' not in col and 'industry' not in col]\n",
|
||
" num_features = [col for col in num_features if 'limit' not in col]\n",
|
||
" num_features = [col for col in num_features if 'cyq' not in col]\n",
|
||
"\n",
|
||
" # 遍历所有数值型特征\n",
|
||
" for feature in num_features:\n",
|
||
" if feature == 'trade_date': # 不需要对 'trade_date' 计算偏差\n",
|
||
" continue\n",
|
||
"\n",
|
||
" # grouped_mean = df.groupby(['trade_date'])[feature].transform('mean')\n",
|
||
" # deviation_col_name = f'deviation_mean_{feature}'\n",
|
||
" # new_columns[deviation_col_name] = df[feature] - grouped_mean\n",
|
||
" # ret_feature_columns.append(deviation_col_name)\n",
|
||
"\n",
|
||
" grouped_mean = df.groupby(['trade_date', groupby_col])[feature].transform('mean')\n",
|
||
" deviation_col_name = f'deviation_mean_{feature}'\n",
|
||
" new_columns[deviation_col_name] = df[feature] - grouped_mean\n",
|
||
" ret_feature_columns.append(deviation_col_name)\n",
|
||
"\n",
|
||
" # 将新计算的偏差特征与原始 DataFrame 合并\n",
|
||
" df = pd.concat([df, pd.DataFrame(new_columns)], axis=1)\n",
|
||
"\n",
|
||
" # for feature in ['obv', 'return_20', 'act_factor1', 'act_factor2', 'act_factor3', 'act_factor4']:\n",
|
||
" # df[f'deviation_industry_{feature}'] = df[feature] - df[f'industry_{feature}']\n",
|
||
"\n",
|
||
" return df, ret_feature_columns\n"
|
||
],
|
||
"outputs": [],
|
||
"execution_count": 10
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"id": "40e6b68a91b30c79",
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2025-03-19T14:47:42.840210Z",
|
||
"start_time": "2025-03-19T14:47:42.740146Z"
|
||
}
|
||
},
|
||
"source": [
|
||
"import pandas as pd\n",
|
||
"\n",
|
||
"\n",
|
||
"def remove_outliers_label_percentile(label: pd.Series, lower_percentile: float = 0.01, upper_percentile: float = 0.99,\n",
|
||
" log=True):\n",
|
||
" if not (0 <= lower_percentile < upper_percentile <= 1):\n",
|
||
" raise ValueError(\"Percentile values must satisfy 0 <= lower_percentile < upper_percentile <= 1.\")\n",
|
||
"\n",
|
||
" # Calculate lower and upper bounds based on percentiles\n",
|
||
" lower_bound = label.quantile(lower_percentile)\n",
|
||
" upper_bound = label.quantile(upper_percentile)\n",
|
||
"\n",
|
||
" # Filter out values outside the bounds\n",
|
||
" filtered_label = label[(label >= lower_bound) & (label <= upper_bound)]\n",
|
||
"\n",
|
||
" # Print the number of removed outliers\n",
|
||
" if log:\n",
|
||
" print(f\"Removed {len(label) - len(filtered_label)} outliers.\")\n",
|
||
" return filtered_label\n",
|
||
"\n",
|
||
"\n",
|
||
"def calculate_risk_adjusted_target(df, days=5):\n",
|
||
" df = df.sort_values(by=['ts_code', 'trade_date'])\n",
|
||
"\n",
|
||
" df['future_close'] = df.groupby('ts_code')['close'].shift(-days)\n",
|
||
" df['future_open'] = df.groupby('ts_code')['open'].shift(-1)\n",
|
||
" df['future_return'] = (df['future_close'] - df['future_open']) / df['future_open']\n",
|
||
"\n",
|
||
" df['future_volatility'] = df.groupby('ts_code')['future_return'].rolling(days, min_periods=1).std().reset_index(\n",
|
||
" level=0, drop=True)\n",
|
||
" sharpe_ratio = df['future_return'] * df['future_volatility']\n",
|
||
" sharpe_ratio.replace([np.inf, -np.inf], np.nan, inplace=True)\n",
|
||
"\n",
|
||
" return sharpe_ratio\n",
|
||
"\n",
|
||
"\n",
|
||
"def calculate_score(df, days=5, lambda_param=1.0):\n",
|
||
" def calculate_max_drawdown(prices):\n",
|
||
" peak = prices.iloc[0] # 初始化峰值\n",
|
||
" max_drawdown = 0 # 初始化最大回撤\n",
|
||
"\n",
|
||
" for price in prices:\n",
|
||
" if price > peak:\n",
|
||
" peak = price # 更新峰值\n",
|
||
" else:\n",
|
||
" drawdown = (peak - price) / peak # 计算当前回撤\n",
|
||
" max_drawdown = max(max_drawdown, drawdown) # 更新最大回撤\n",
|
||
"\n",
|
||
" return max_drawdown\n",
|
||
"\n",
|
||
" def compute_stock_score(stock_df):\n",
|
||
" stock_df = stock_df.sort_values(by=['trade_date'])\n",
|
||
" future_return = stock_df['future_return']\n",
|
||
" volatility = stock_df['close'].pct_change().rolling(days).std().shift(-days)\n",
|
||
" max_drawdown = stock_df['close'].rolling(days).apply(calculate_max_drawdown, raw=False).shift(-days)\n",
|
||
" score = future_return - lambda_param * max_drawdown\n",
|
||
"\n",
|
||
" return score\n",
|
||
"\n",
|
||
" scores = df.groupby('ts_code').apply(lambda x: compute_stock_score(x))\n",
|
||
" scores = scores.reset_index(level=0, drop=True)\n",
|
||
"\n",
|
||
" return scores\n",
|
||
"\n",
|
||
"\n",
|
||
"import gc\n",
|
||
"\n",
|
||
"gc.collect()"
|
||
],
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"0"
|
||
]
|
||
},
|
||
"execution_count": 11,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"execution_count": 11
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"id": "47c12bb34062ae7a",
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2025-03-19T15:54:02.848116Z",
|
||
"start_time": "2025-03-19T15:51:53.055184Z"
|
||
}
|
||
},
|
||
"source": [
|
||
"days = 5\n",
|
||
"\n",
|
||
"import gc\n",
|
||
"\n",
|
||
"gc.collect()\n",
|
||
"\n",
|
||
"df['future_return'] = df.groupby('ts_code', group_keys=False)['close'].apply(lambda x: x.shift(-days) / x - 1)\n",
|
||
"\n",
|
||
"\n",
|
||
"# df['future_return'] = (df.groupby('ts_code')['close'].shift(-days) - df.groupby('ts_code')['open'].shift(-1)) / \\\n",
|
||
"# df.groupby('ts_code')['open'].shift(-1)\n",
|
||
"\n",
|
||
"\n",
|
||
"def symmetric_log_transform(values):\n",
|
||
" return np.sign(values) * np.log1p(np.abs(values))\n",
|
||
"\n",
|
||
"\n",
|
||
"# df['future_score'] = df['future_return']\n",
|
||
"df['future_score'] = calculate_score(df, days=days, lambda_param=0.5)\n",
|
||
"# df['future_score'] = df.groupby('ts_code')['future_return'].shift(-1).rolling(window=2).mean()\n",
|
||
"df['future_score'] = symmetric_log_transform(df['future_score'])\n",
|
||
"\n",
|
||
"# df['label'] = remove_outliers_label_percentile(df['label'])\n",
|
||
"train_data = df[(df['trade_date'] <= '2024-01-01') & (df['trade_date'] >= '2000-01-01')]\n",
|
||
"test_data = df[(df['trade_date'] >= '2024-01-01')]\n",
|
||
"\n",
|
||
"\n",
|
||
"def select_pre_zt_stocks_dynamic(\n",
|
||
" stock_df,\n",
|
||
" vol_spike_multiplier=1.5,\n",
|
||
" min_return=0.03, # 最小累计涨幅(例如 3%)\n",
|
||
" min_main_net_inflow=1e6, # 最小主力资金净流入(例如 100 万元)\n",
|
||
" window=30, # 计算历史均值的窗口大小\n",
|
||
" signal_days=1 # 异动信号需要连续出现的天数\n",
|
||
"):\n",
|
||
" # 排序数据\n",
|
||
" stock_df = stock_df.sort_values(by=['trade_date', 'ts_code'])\n",
|
||
" stock_df = stock_df.groupby('trade_date', group_keys=False).apply(\n",
|
||
" lambda x: x.nlargest(512, 'return_20')\n",
|
||
" )\n",
|
||
"\n",
|
||
" return stock_df\n",
|
||
"\n",
|
||
"\n",
|
||
"train_data = select_pre_zt_stocks_dynamic(train_data)\n",
|
||
"test_data = select_pre_zt_stocks_dynamic(test_data)\n",
|
||
"# # train_data = train_data[train_data['circ_mv'] < 2000000]\n",
|
||
"# # test_data = test_data[test_data['circ_mv'] < 2000000]\n",
|
||
"#\n",
|
||
"# train_data = train_data[train_data['vol'] > 1.5 * train_data['vol_spike']]\n",
|
||
"# test_data = test_data[test_data['vol'] > 1.5 * test_data['vol_spike']]\n",
|
||
"# train_data = train_data.groupby('trade_date', group_keys=False).apply(\n",
|
||
"# lambda x: x.nsmallest(500, 'return_20')\n",
|
||
"# )\n",
|
||
"#\n",
|
||
"# test_data = test_data.groupby('trade_date', group_keys=False).apply(\n",
|
||
"# lambda x: x.nsmallest(500, 'return_20')\n",
|
||
"# )\n",
|
||
"\n",
|
||
"# train_data, _ = get_simple_factor(train_data)\n",
|
||
"# test_data, _ = get_simple_factor(test_data)\n",
|
||
"\n",
|
||
"train_data['label'] = train_data.groupby('trade_date', group_keys=False)['future_score'].transform(\n",
|
||
" lambda x: pd.qcut(x, q=5, labels=False, duplicates='drop')\n",
|
||
")\n",
|
||
"test_data['label'] = test_data.groupby('trade_date', group_keys=False)['future_score'].transform(\n",
|
||
" lambda x: pd.qcut(x, q=5, labels=False, duplicates='drop')\n",
|
||
")\n",
|
||
"# train_data['label'] = train_data['label'] == 4\n",
|
||
"# test_data['label'] = test_data['label'] == 4\n",
|
||
"\n",
|
||
"industry_df = industry_df.sort_values(by=['trade_date'])\n",
|
||
"index_data = index_data.sort_values(by=['trade_date'])\n",
|
||
"\n",
|
||
"train_data = train_data.merge(industry_df, on=['cat_l2_code', 'trade_date'], how='left')\n",
|
||
"train_data = train_data.merge(index_data, on='trade_date', how='left')\n",
|
||
"test_data = test_data.merge(industry_df, on=['cat_l2_code', 'trade_date'], how='left')\n",
|
||
"test_data = test_data.merge(index_data, on='trade_date', how='left')\n",
|
||
"\n",
|
||
"train_data, test_data = train_data.replace([np.inf, -np.inf], np.nan), test_data.replace([np.inf, -np.inf], np.nan)\n",
|
||
"\n",
|
||
"feature_columns = [col for col in train_data.columns if col not in ['trade_date',\n",
|
||
" 'ts_code',\n",
|
||
" 'label']]\n",
|
||
"feature_columns = [col for col in feature_columns if 'future' not in col]\n",
|
||
"feature_columns = [col for col in feature_columns if 'score' not in col]\n",
|
||
"feature_columns = [col for col in feature_columns if col not in origin_columns]\n",
|
||
"feature_columns = [col for col in feature_columns if not col.startswith('_')]\n",
|
||
"print(feature_columns)\n",
|
||
"\n",
|
||
"feature_columns_new = feature_columns[:]\n",
|
||
"train_data, feature_columns_new = create_deviation_within_dates(train_data, feature_columns)\n",
|
||
"test_data, feature_columns_new = create_deviation_within_dates(test_data, feature_columns)\n",
|
||
"print(f'feature_columns size: {len(feature_columns_new)}')\n",
|
||
"\n",
|
||
"train_data = train_data.dropna(subset=feature_columns_new)\n",
|
||
"train_data = train_data.dropna(subset=['label'])\n",
|
||
"train_data = train_data.reset_index(drop=True)\n",
|
||
"\n",
|
||
"# print(test_data.tail())\n",
|
||
"test_data = test_data.dropna(subset=feature_columns_new)\n",
|
||
"# test_data = test_data.dropna(subset=['label'])\n",
|
||
"test_data = test_data.reset_index(drop=True)\n",
|
||
"\n",
|
||
"print(len(train_data))\n",
|
||
"print(f\"最小日期: {train_data['trade_date'].min().strftime('%Y-%m-%d')}\")\n",
|
||
"print(f\"最大日期: {train_data['trade_date'].max().strftime('%Y-%m-%d')}\")\n",
|
||
"print(len(test_data))\n",
|
||
"print(f\"最小日期: {test_data['trade_date'].min().strftime('%Y-%m-%d')}\")\n",
|
||
"print(f\"最大日期: {test_data['trade_date'].max().strftime('%Y-%m-%d')}\")\n",
|
||
"\n",
|
||
"cat_columns = [col for col in df.columns if col.startswith('cat')]\n",
|
||
"for col in cat_columns:\n",
|
||
" train_data[col] = train_data[col].astype('category')\n",
|
||
" test_data[col] = test_data[col].astype('category')\n",
|
||
"\n",
|
||
"\n",
|
||
"# feature_columns_new.remove('cat_l2_code')"
|
||
],
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"['vol', 'pct_chg', 'turnover_rate', 'pe_ttm', 'volume_ratio', 'winner_rate', 'cat_l2_code', 'return_skew', 'return_kurtosis', 'volume_change_rate', 'cat_volume_breakout', 'turnover_deviation', 'cat_turnover_spike', 'avg_volume_ratio', 'cat_volume_ratio_breakout', 'vol_spike', 'vol_std_5', 'obv', 'maobv_6', 'return_5', 'return_10', 'return_20', 'std_return_5', 'std_return_15', 'std_return_25', 'std_return_90', 'std_return_90_2', 'act_factor1', 'act_factor2', 'act_factor3', 'act_factor4', 'log(circ_mv)', 'cov', 'delta_cov', 'alpha_22_improved', 'alpha_003', 'alpha_007', 'alpha_013', 'cat_up_limit', 'cat_down_limit', 'up_limit_count_10d', 'down_limit_count_10d', 'consecutive_up_limit', 'vol_break', 'weight_roc5', 'price_cost_divergence', 'smallcap_concentration', 'cost_stability', 'high_cost_break_days', 'liquidity_risk', 'turnover_std', 'mv_volatility', 'volume_growth', 'mv_growth', 'industry_obv', 'industry_return_5', 'industry_return_20', 'industry__ema_5', 'industry__ema_13', 'industry__ema_20', 'industry__ema_60', 'industry_act_factor1', 'industry_act_factor2', 'industry_act_factor3', 'industry_act_factor4', 'industry_act_factor5', 'industry_act_factor6', 'industry_rank_act_factor1', 'industry_rank_act_factor2', 'industry_rank_act_factor3', 'industry_return_5_percentile', 'industry_return_20_percentile', '000852.SH_MACD', '000905.SH_MACD', '399006.SZ_MACD', '000852.SH_MACD_hist', '000905.SH_MACD_hist', '399006.SZ_MACD_hist', '000852.SH_RSI', '000905.SH_RSI', '399006.SZ_RSI', '000852.SH_Signal_line', '000905.SH_Signal_line', '399006.SZ_Signal_line', '000852.SH_amount_change_rate', '000905.SH_amount_change_rate', '399006.SZ_amount_change_rate', '000852.SH_amount_mean', '000905.SH_amount_mean', '399006.SZ_amount_mean', '000852.SH_daily_return', '000905.SH_daily_return', '399006.SZ_daily_return', '000852.SH_up_ratio_20d', '000905.SH_up_ratio_20d', '399006.SZ_up_ratio_20d', '000852.SH_volatility', '000905.SH_volatility', '399006.SZ_volatility', 
'000852.SH_volume_change_rate', '000905.SH_volume_change_rate', '399006.SZ_volume_change_rate']\n",
|
||
"feature_columns size: 177\n",
|
||
"541133\n",
|
||
"最小日期: 2018-06-04\n",
|
||
"最大日期: 2023-12-29\n",
|
||
"106062\n",
|
||
"最小日期: 2024-01-02\n",
|
||
"最大日期: 2025-03-13\n"
|
||
]
|
||
}
|
||
],
|
||
"execution_count": 40
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"id": "1c46817a-b5dd-4bec-8bb4-e6e80bfd9d66",
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2025-03-19T15:54:03.016065Z",
|
||
"start_time": "2025-03-19T15:54:03.011956Z"
|
||
}
|
||
},
|
||
"source": [
|
||
"# test_data = df[(df['trade_date'] >= '2024-04-01')]\n",
|
||
"# test_data = test_data.merge(industry_df, on=['cat_l2_code', 'trade_date'], how='left')\n",
|
||
"# test_data = test_data.merge(index_data, on='trade_date', how='left')"
|
||
],
|
||
"outputs": [],
|
||
"execution_count": 41
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"id": "da2bb202843d9275",
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2025-03-19T15:54:24.393873Z",
|
||
"start_time": "2025-03-19T15:54:03.049056Z"
|
||
}
|
||
},
|
||
"source": [
|
||
"def remove_highly_correlated_features(df, feature_columns, threshold=0.9):\n",
|
||
" numeric_features = df[feature_columns].select_dtypes(include=[np.number]).columns.tolist()\n",
|
||
" if not numeric_features:\n",
|
||
" raise ValueError(\"No numeric features found in the provided data.\")\n",
|
||
"\n",
|
||
" corr_matrix = df[numeric_features].corr().abs()\n",
|
||
" upper = corr_matrix.where(np.triu(np.ones(corr_matrix.shape), k=1).astype(bool))\n",
|
||
" to_drop = [column for column in upper.columns if any(upper[column] > threshold)]\n",
|
||
" remaining_features = [col for col in feature_columns if col not in to_drop\n",
|
||
" or 'act' in col or 'af' in col]\n",
|
||
" return remaining_features\n",
|
||
"\n",
|
||
"\n",
|
||
"feature_columns_new = remove_highly_correlated_features(train_data, feature_columns_new)\n",
|
||
"keep_columns = [col for col in train_data.columns if\n",
|
||
" col in feature_columns_new or col in ['ts_code', 'trade_date', 'label', 'future_return',\n",
|
||
" 'future_score']]\n",
|
||
"train_data = train_data[keep_columns]\n",
|
||
"# test_data = test_data[keep_columns]\n",
|
||
"print(f'feature_columns size: {len(feature_columns_new)}')\n",
|
||
"# print(feature_columns_new)\n",
|
||
"# 2. 按 trade_date 分组"
|
||
],
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"feature_columns size: 154\n"
|
||
]
|
||
}
|
||
],
|
||
"execution_count": 42
|
||
},
|
||
{
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2025-03-19T15:54:24.499139Z",
|
||
"start_time": "2025-03-19T15:54:24.490686Z"
|
||
}
|
||
},
|
||
"cell_type": "code",
|
||
"source": [
|
||
"import torch\n",
|
||
"\n",
|
||
"num_heads = 4 # Transformer头数\n",
|
||
"transformer_input_dim = 64\n",
|
||
"num_layers = 2 # Transformer层数\n",
|
||
"hidden_dim = 64 # 隐藏层维度\n",
|
||
"num_epochs = 10 # 训练轮数\n",
|
||
"learning_rate = 0.001 # 学习率\n",
|
||
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
||
"validation_days = 180\n",
|
||
"\n",
|
||
"print(device)"
|
||
],
|
||
"id": "8113bba62de693cc",
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"cuda\n"
|
||
]
|
||
}
|
||
],
|
||
"execution_count": 43
|
||
},
|
||
{
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2025-03-19T15:57:09.407047Z",
|
||
"start_time": "2025-03-19T15:54:24.596522Z"
|
||
}
|
||
},
|
||
"cell_type": "code",
|
||
"source": [
|
||
"from tqdm import tqdm\n",
|
||
"import torch.optim as optim\n",
|
||
"from sklearn.preprocessing import LabelEncoder, StandardScaler\n",
|
||
"from sklearn.metrics import accuracy_score\n",
|
||
"import numpy as np\n",
|
||
"import pandas as pd\n",
|
||
"\n",
|
||
"import torch\n",
|
||
"import torch.nn as nn\n",
|
||
"\n",
|
||
"\n",
|
||
"class StockPredictionModel(nn.Module):\n",
|
||
" def __init__(self, input_dim, hidden_dim, num_classes, num_heads, num_layers, transformer_input_dim):\n",
|
||
" super(StockPredictionModel, self).__init__()\n",
|
||
"\n",
|
||
" # 输入全连接层:将原始输入特征映射到 Transformer 的输入维度\n",
|
||
" self.fc_input = nn.Sequential(\n",
|
||
" nn.Linear(input_dim, hidden_dim),\n",
|
||
" nn.ReLU(),\n",
|
||
" nn.Dropout(0.3),\n",
|
||
" nn.Linear(hidden_dim, hidden_dim),\n",
|
||
" nn.ReLU(),\n",
|
||
" nn.Dropout(0.3),\n",
|
||
" nn.Linear(hidden_dim, hidden_dim),\n",
|
||
" )\n",
|
||
"\n",
|
||
" # Transformer Encoder\n",
|
||
" self.transformer = nn.Transformer(\n",
|
||
" d_model=transformer_input_dim,\n",
|
||
" nhead=num_heads,\n",
|
||
" num_encoder_layers=num_layers,\n",
|
||
" batch_first=True\n",
|
||
" )\n",
|
||
"\n",
|
||
" self.dropout = nn.Dropout(0.3)\n",
|
||
"\n",
|
||
" # 全连接层\n",
|
||
" self.fc_output = nn.Sequential(\n",
|
||
" nn.Linear(transformer_input_dim, hidden_dim),\n",
|
||
" nn.ReLU(),\n",
|
||
" nn.Dropout(0.3),\n",
|
||
" nn.Linear(hidden_dim, hidden_dim),\n",
|
||
" nn.ReLU(),\n",
|
||
" nn.Dropout(0.3),\n",
|
||
" nn.Linear(hidden_dim, num_classes),\n",
|
||
" )\n",
|
||
"\n",
|
||
" def forward(self, x):\n",
|
||
" # x: (batch_size, num_stocks, input_dim)\n",
|
||
"\n",
|
||
" # 输入全连接层处理\n",
|
||
" x = (self.fc_input(x))\n",
|
||
"\n",
|
||
" # Transformer处理\n",
|
||
" transformer_out = self.transformer(x, x) # Self-attention\n",
|
||
"\n",
|
||
" # 全连接层处理\n",
|
||
" out = self.fc_output(transformer_out)\n",
|
||
"\n",
|
||
" return out\n",
|
||
"\n",
|
||
"\n",
|
||
"# 数据预处理函数\n",
|
||
"def preprocess_data(data, cat_columns, feature_columns, label_encoders=None, scaler=None):\n",
|
||
" \"\"\"\n",
|
||
" 预处理数据:\n",
|
||
" - 对分类特征进行 LabelEncoder 编码。\n",
|
||
" - 对数值特征进行标准化。\n",
|
||
" - 将数据按天分组并转换为 NumPy 数组。\n",
|
||
" \"\"\"\n",
|
||
" # 分离分类特征和数值特征\n",
|
||
" numeric_cols = [col for col in feature_columns if col not in cat_columns]\n",
|
||
"\n",
|
||
" # 初始化编码器和标准化器(如果是训练阶段)\n",
|
||
" if label_encoders is None:\n",
|
||
" label_encoders = {col: LabelEncoder() for col in cat_columns}\n",
|
||
" X_cat = np.array([label_encoders[col].fit_transform(data[col]) for col in cat_columns]).T\n",
|
||
" else:\n",
|
||
" X_cat = np.array([label_encoders[col].transform(data[col]) for col in cat_columns]).T\n",
|
||
" if scaler is None:\n",
|
||
" scaler = StandardScaler()\n",
|
||
" X_num = scaler.fit_transform(data[numeric_cols])\n",
|
||
" else:\n",
|
||
" X_num = scaler.transform(data[numeric_cols])\n",
|
||
"\n",
|
||
"\n",
|
||
" # 处理分类特征\n",
|
||
"\n",
|
||
" # 处理数值特征\n",
|
||
"\n",
|
||
" # 合并特征\n",
|
||
" X_processed = np.hstack([X_num, X_cat])\n",
|
||
"\n",
|
||
" # 将处理后的数据与日期对齐\n",
|
||
" processed_data = pd.DataFrame(X_processed, columns=numeric_cols + cat_columns)\n",
|
||
" processed_data['trade_date'] = data['trade_date'].values # 保留日期列\n",
|
||
" processed_data['label'] = data['label'].values # 保留标签列\n",
|
||
"\n",
|
||
" # 按天分组\n",
|
||
" grouped_X = []\n",
|
||
" grouped_y = []\n",
|
||
" for date, group in processed_data.groupby('trade_date'):\n",
|
||
" X_day = group[numeric_cols + cat_columns].values # 使用预处理后的特征\n",
|
||
" y_day = group['label'].values # 使用原始标签\n",
|
||
" grouped_X.append(X_day)\n",
|
||
" grouped_y.append(y_day)\n",
|
||
"\n",
|
||
" # 获取每个分类特征的最大类别数(用于嵌入层)\n",
|
||
" cat_dims = [len(label_encoders[col].classes_) for col in cat_columns]\n",
|
||
"\n",
|
||
" return grouped_X, grouped_y, label_encoders, scaler, cat_dims\n",
|
||
"\n",
|
||
"\n",
|
||
"# 训练函数(逐天训练)\n",
|
||
"def train_by_day(model, X_train, y_train, criterion, optimizer, device):\n",
|
||
" model.train()\n",
|
||
" total_loss = 0\n",
|
||
" all_preds, all_labels = [], []\n",
|
||
"\n",
|
||
" for day_idx in range(len(X_train)):\n",
|
||
" # 取出当天的数据\n",
|
||
" X_batch = torch.tensor(X_train[day_idx], dtype=torch.float32).to(device)\n",
|
||
" y_batch = torch.tensor(y_train[day_idx], dtype=torch.long).to(device)\n",
|
||
"\n",
|
||
" # 添加 batch 维度(从 2D 变为 3D)\n",
|
||
" if X_batch.dim() == 2:\n",
|
||
" X_batch = X_batch.unsqueeze(0) # 形状变为 (1, num_stocks, num_features)\n",
|
||
"\n",
|
||
" # 前向传播\n",
|
||
" outputs = model(X_batch) # (batch_size, num_stocks, num_classes)\n",
|
||
" loss = criterion(outputs.view(-1, outputs.size(-1)), y_batch.view(-1))\n",
|
||
"\n",
|
||
" # 反向传播\n",
|
||
" optimizer.zero_grad()\n",
|
||
" loss.backward()\n",
|
||
" optimizer.step()\n",
|
||
"\n",
|
||
" # 统计损失和预测结果\n",
|
||
" total_loss += loss.item()\n",
|
||
" preds = torch.argmax(outputs, dim=-1).cpu().numpy()\n",
|
||
" labels = y_batch.cpu().numpy()\n",
|
||
" all_preds.extend(preds.flatten())\n",
|
||
" all_labels.extend(labels.flatten())\n",
|
||
"\n",
|
||
" avg_loss = total_loss / len(X_train)\n",
|
||
" accuracy = accuracy_score(all_labels, all_preds)\n",
|
||
" return avg_loss, accuracy\n",
|
||
"\n",
|
||
"\n",
|
||
"# 验证/测试函数(逐天验证)\n",
|
||
"def evaluate_by_day(model, X_val, y_val, criterion, device):\n",
|
||
" model.eval()\n",
|
||
" total_loss = 0\n",
|
||
" all_preds, all_labels = [], []\n",
|
||
"\n",
|
||
" with torch.no_grad():\n",
|
||
" for day_idx in range(len(X_val)):\n",
|
||
" # 取出当天的数据\n",
|
||
" X_batch = torch.tensor(X_val[day_idx], dtype=torch.float32).to(device)\n",
|
||
" y_batch = torch.tensor(y_val[day_idx], dtype=torch.long).to(device)\n",
|
||
" # 添加 batch 维度(从 2D 变为 3D)\n",
|
||
" if X_batch.dim() == 2:\n",
|
||
" X_batch = X_batch.unsqueeze(0) # 形状变为 (1, num_stocks, num_features)\n",
|
||
"\n",
|
||
" # 前向传播\n",
|
||
" outputs = model(X_batch)\n",
|
||
" loss = criterion(outputs.view(-1, outputs.size(-1)), y_batch.view(-1))\n",
|
||
"\n",
|
||
" # 统计损失和预测结果\n",
|
||
" total_loss += loss.item()\n",
|
||
" preds = torch.argmax(outputs, dim=-1).cpu().numpy()\n",
|
||
" labels = y_batch.cpu().numpy()\n",
|
||
" all_preds.extend(preds.flatten())\n",
|
||
" all_labels.extend(labels.flatten())\n",
|
||
"\n",
|
||
" avg_loss = total_loss / len(X_val)\n",
|
||
" accuracy = accuracy_score(all_labels, all_preds)\n",
|
||
" return avg_loss, accuracy\n",
|
||
"\n",
|
||
"\n",
|
||
"# 主程序\n",
|
||
"gc.collect()\n",
|
||
"# 数据切分\n",
|
||
"all_dates = train_data['trade_date'].unique() # 获取所有唯一的 trade_date\n",
|
||
"split_date = all_dates[-validation_days] # 划分点为倒数第 validation_days 天\n",
|
||
"train_data_split = train_data[train_data['trade_date'] < split_date] # 训练集\n",
|
||
"val_data_split = train_data[train_data['trade_date'] >= split_date] # 验证集\n",
|
||
"\n",
|
||
"# 找到分类特征列\n",
|
||
"cat_columns = [col for col in train_data.columns if col.startswith(\"cat\")]\n",
|
||
"\n",
|
||
"# 预处理数据\n",
|
||
"X_train_processed, y_train_processed, label_encoders, scaler, cat_dims = preprocess_data(train_data_split, cat_columns,\n",
|
||
" feature_columns_new)\n",
|
||
"X_val_processed, y_val_processed, _, _, _ = preprocess_data(val_data_split, cat_columns, feature_columns_new,\n",
|
||
" label_encoders, scaler)\n",
|
||
"X_test_processed, y_test_processed, _, _, _ = preprocess_data(test_data, cat_columns, feature_columns_new,\n",
|
||
" label_encoders, scaler)\n",
|
||
"\n",
|
||
"# 超参数\n",
|
||
"input_dim = X_train_processed[0].shape[1] # 直接从数据获取输入维度\n",
|
||
"print(input_dim)\n",
|
||
"num_classes = len(train_data[\"label\"].unique())\n",
|
||
"\n",
|
||
"\n",
|
||
"# 模型、损失函数和优化器\n",
|
||
"model = StockPredictionModel(\n",
|
||
" input_dim=input_dim,\n",
|
||
" hidden_dim=hidden_dim,\n",
|
||
" num_classes=num_classes,\n",
|
||
" num_heads=num_heads,\n",
|
||
" num_layers=num_layers,\n",
|
||
" transformer_input_dim=transformer_input_dim,\n",
|
||
").to(device)\n",
|
||
"criterion = nn.CrossEntropyLoss()\n",
|
||
"optimizer = optim.Adam(model.parameters(), lr=learning_rate)\n",
|
||
"\n",
|
||
"# 训练和验证\n",
|
||
"best_val_accuracy = 0\n",
|
||
"for epoch in tqdm(range(num_epochs)):\n",
|
||
" train_loss, train_acc = train_by_day(model, X_train_processed, y_train_processed, criterion, optimizer, device)\n",
|
||
" val_loss, val_acc = evaluate_by_day(model, X_val_processed, y_val_processed, criterion, device)\n",
|
||
"\n",
|
||
" print(f\"Epoch {epoch + 1}/{num_epochs}: \"\n",
|
||
" f\"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, \"\n",
|
||
" f\"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}\")\n",
|
||
"\n",
|
||
" # 保存最佳模型\n",
|
||
" # if val_acc > best_val_accuracy:\n",
|
||
" # best_val_accuracy = val_acc\n",
|
||
" # torch.save(model.state_dict(), \"best_model.pth\")\n",
|
||
"\n",
|
||
"# 测试\n",
|
||
"# model.load_state_dict(torch.load(\"best_model.pth\"))\n",
|
||
"# test_loss, test_acc = evaluate_by_day(model, X_test_processed, y_test_processed, criterion, device)\n",
|
||
"# print(f\"Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}\")"
|
||
],
|
||
"id": "beeb098799ecfa6a",
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"154\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" 10%|█ | 1/10 [00:26<03:58, 26.47s/it]"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Epoch 1/10: Train Loss: 1.6094, Train Acc: 0.2037, Val Loss: 1.6091, Val Acc: 0.2088\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" 20%|██ | 2/10 [00:53<03:33, 26.72s/it]"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Epoch 2/10: Train Loss: 1.6091, Train Acc: 0.2046, Val Loss: 1.6091, Val Acc: 0.2088\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" 30%|███ | 3/10 [01:20<03:07, 26.80s/it]"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Epoch 3/10: Train Loss: 1.6091, Train Acc: 0.2038, Val Loss: 1.6091, Val Acc: 0.2088\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" 40%|████ | 4/10 [01:44<02:35, 25.92s/it]"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Epoch 4/10: Train Loss: 1.6091, Train Acc: 0.2043, Val Loss: 1.6091, Val Acc: 0.2088\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" 50%|█████ | 5/10 [02:08<02:05, 25.18s/it]"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Epoch 5/10: Train Loss: 1.6091, Train Acc: 0.2046, Val Loss: 1.6091, Val Acc: 0.2088\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" 60%|██████ | 6/10 [02:35<01:42, 25.72s/it]"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Epoch 6/10: Train Loss: 1.6091, Train Acc: 0.2042, Val Loss: 1.6091, Val Acc: 0.2088\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" 60%|██████ | 6/10 [02:41<01:47, 26.93s/it]\n"
|
||
]
|
||
},
|
||
{
|
||
"ename": "KeyboardInterrupt",
|
||
"evalue": "",
|
||
"output_type": "error",
|
||
"traceback": [
|
||
"\u001B[1;31m---------------------------------------------------------------------------\u001B[0m",
|
||
"\u001B[1;31mKeyboardInterrupt\u001B[0m Traceback (most recent call last)",
|
||
"Cell \u001B[1;32mIn[44], line 223\u001B[0m\n\u001B[0;32m 221\u001B[0m best_val_accuracy \u001B[38;5;241m=\u001B[39m \u001B[38;5;241m0\u001B[39m\n\u001B[0;32m 222\u001B[0m \u001B[38;5;28;01mfor\u001B[39;00m epoch \u001B[38;5;129;01min\u001B[39;00m tqdm(\u001B[38;5;28mrange\u001B[39m(num_epochs)):\n\u001B[1;32m--> 223\u001B[0m train_loss, train_acc \u001B[38;5;241m=\u001B[39m train_by_day(model, X_train_processed, y_train_processed, criterion, optimizer, device)\n\u001B[0;32m 224\u001B[0m val_loss, val_acc \u001B[38;5;241m=\u001B[39m evaluate_by_day(model, X_val_processed, y_val_processed, criterion, device)\n\u001B[0;32m 226\u001B[0m \u001B[38;5;28mprint\u001B[39m(\u001B[38;5;124mf\u001B[39m\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mEpoch \u001B[39m\u001B[38;5;132;01m{\u001B[39;00mepoch\u001B[38;5;250m \u001B[39m\u001B[38;5;241m+\u001B[39m\u001B[38;5;250m \u001B[39m\u001B[38;5;241m1\u001B[39m\u001B[38;5;132;01m}\u001B[39;00m\u001B[38;5;124m/\u001B[39m\u001B[38;5;132;01m{\u001B[39;00mnum_epochs\u001B[38;5;132;01m}\u001B[39;00m\u001B[38;5;124m: \u001B[39m\u001B[38;5;124m\"\u001B[39m\n\u001B[0;32m 227\u001B[0m \u001B[38;5;124mf\u001B[39m\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mTrain Loss: \u001B[39m\u001B[38;5;132;01m{\u001B[39;00mtrain_loss\u001B[38;5;132;01m:\u001B[39;00m\u001B[38;5;124m.4f\u001B[39m\u001B[38;5;132;01m}\u001B[39;00m\u001B[38;5;124m, Train Acc: \u001B[39m\u001B[38;5;132;01m{\u001B[39;00mtrain_acc\u001B[38;5;132;01m:\u001B[39;00m\u001B[38;5;124m.4f\u001B[39m\u001B[38;5;132;01m}\u001B[39;00m\u001B[38;5;124m, \u001B[39m\u001B[38;5;124m\"\u001B[39m\n\u001B[0;32m 228\u001B[0m \u001B[38;5;124mf\u001B[39m\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mVal Loss: \u001B[39m\u001B[38;5;132;01m{\u001B[39;00mval_loss\u001B[38;5;132;01m:\u001B[39;00m\u001B[38;5;124m.4f\u001B[39m\u001B[38;5;132;01m}\u001B[39;00m\u001B[38;5;124m, Val Acc: 
\u001B[39m\u001B[38;5;132;01m{\u001B[39;00mval_acc\u001B[38;5;132;01m:\u001B[39;00m\u001B[38;5;124m.4f\u001B[39m\u001B[38;5;132;01m}\u001B[39;00m\u001B[38;5;124m\"\u001B[39m)\n",
|
||
"Cell \u001B[1;32mIn[44], line 132\u001B[0m, in \u001B[0;36mtrain_by_day\u001B[1;34m(model, X_train, y_train, criterion, optimizer, device)\u001B[0m\n\u001B[0;32m 129\u001B[0m X_batch \u001B[38;5;241m=\u001B[39m X_batch\u001B[38;5;241m.\u001B[39munsqueeze(\u001B[38;5;241m0\u001B[39m) \u001B[38;5;66;03m# 形状变为 (1, num_stocks, num_features)\u001B[39;00m\n\u001B[0;32m 131\u001B[0m \u001B[38;5;66;03m# 前向传播\u001B[39;00m\n\u001B[1;32m--> 132\u001B[0m outputs \u001B[38;5;241m=\u001B[39m model(X_batch) \u001B[38;5;66;03m# (batch_size, num_stocks, num_classes)\u001B[39;00m\n\u001B[0;32m 133\u001B[0m loss \u001B[38;5;241m=\u001B[39m criterion(outputs\u001B[38;5;241m.\u001B[39mview(\u001B[38;5;241m-\u001B[39m\u001B[38;5;241m1\u001B[39m, outputs\u001B[38;5;241m.\u001B[39msize(\u001B[38;5;241m-\u001B[39m\u001B[38;5;241m1\u001B[39m)), y_batch\u001B[38;5;241m.\u001B[39mview(\u001B[38;5;241m-\u001B[39m\u001B[38;5;241m1\u001B[39m))\n\u001B[0;32m 135\u001B[0m \u001B[38;5;66;03m# 反向传播\u001B[39;00m\n",
|
||
"File \u001B[1;32mE:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1739\u001B[0m, in \u001B[0;36mModule._wrapped_call_impl\u001B[1;34m(self, *args, **kwargs)\u001B[0m\n\u001B[0;32m 1737\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_compiled_call_impl(\u001B[38;5;241m*\u001B[39margs, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs) \u001B[38;5;66;03m# type: ignore[misc]\u001B[39;00m\n\u001B[0;32m 1738\u001B[0m \u001B[38;5;28;01melse\u001B[39;00m:\n\u001B[1;32m-> 1739\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_call_impl(\u001B[38;5;241m*\u001B[39margs, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs)\n",
|
||
"File \u001B[1;32mE:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1750\u001B[0m, in \u001B[0;36mModule._call_impl\u001B[1;34m(self, *args, **kwargs)\u001B[0m\n\u001B[0;32m 1745\u001B[0m \u001B[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001B[39;00m\n\u001B[0;32m 1746\u001B[0m \u001B[38;5;66;03m# this function, and just call forward.\u001B[39;00m\n\u001B[0;32m 1747\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m (\u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_backward_hooks \u001B[38;5;129;01mor\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_backward_pre_hooks \u001B[38;5;129;01mor\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_forward_hooks \u001B[38;5;129;01mor\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_forward_pre_hooks\n\u001B[0;32m 1748\u001B[0m \u001B[38;5;129;01mor\u001B[39;00m _global_backward_pre_hooks \u001B[38;5;129;01mor\u001B[39;00m _global_backward_hooks\n\u001B[0;32m 1749\u001B[0m \u001B[38;5;129;01mor\u001B[39;00m _global_forward_hooks \u001B[38;5;129;01mor\u001B[39;00m _global_forward_pre_hooks):\n\u001B[1;32m-> 1750\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m forward_call(\u001B[38;5;241m*\u001B[39margs, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs)\n\u001B[0;32m 1752\u001B[0m result \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;01mNone\u001B[39;00m\n\u001B[0;32m 1753\u001B[0m called_always_called_hooks \u001B[38;5;241m=\u001B[39m \u001B[38;5;28mset\u001B[39m()\n",
|
||
"Cell \u001B[1;32mIn[44], line 54\u001B[0m, in \u001B[0;36mStockPredictionModel.forward\u001B[1;34m(self, x)\u001B[0m\n\u001B[0;32m 50\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mforward\u001B[39m(\u001B[38;5;28mself\u001B[39m, x):\n\u001B[0;32m 51\u001B[0m \u001B[38;5;66;03m# x: (batch_size, num_stocks, input_dim)\u001B[39;00m\n\u001B[0;32m 52\u001B[0m \n\u001B[0;32m 53\u001B[0m \u001B[38;5;66;03m# 输入全连接层处理\u001B[39;00m\n\u001B[1;32m---> 54\u001B[0m x \u001B[38;5;241m=\u001B[39m (\u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mfc_input(x))\n\u001B[0;32m 56\u001B[0m \u001B[38;5;66;03m# Transformer处理\u001B[39;00m\n\u001B[0;32m 57\u001B[0m transformer_out \u001B[38;5;241m=\u001B[39m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mtransformer(x, x) \u001B[38;5;66;03m# Self-attention\u001B[39;00m\n",
|
||
"File \u001B[1;32mE:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1739\u001B[0m, in \u001B[0;36mModule._wrapped_call_impl\u001B[1;34m(self, *args, **kwargs)\u001B[0m\n\u001B[0;32m 1737\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_compiled_call_impl(\u001B[38;5;241m*\u001B[39margs, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs) \u001B[38;5;66;03m# type: ignore[misc]\u001B[39;00m\n\u001B[0;32m 1738\u001B[0m \u001B[38;5;28;01melse\u001B[39;00m:\n\u001B[1;32m-> 1739\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_call_impl(\u001B[38;5;241m*\u001B[39margs, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs)\n",
|
||
"File \u001B[1;32mE:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1750\u001B[0m, in \u001B[0;36mModule._call_impl\u001B[1;34m(self, *args, **kwargs)\u001B[0m\n\u001B[0;32m 1745\u001B[0m \u001B[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001B[39;00m\n\u001B[0;32m 1746\u001B[0m \u001B[38;5;66;03m# this function, and just call forward.\u001B[39;00m\n\u001B[0;32m 1747\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m (\u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_backward_hooks \u001B[38;5;129;01mor\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_backward_pre_hooks \u001B[38;5;129;01mor\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_forward_hooks \u001B[38;5;129;01mor\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_forward_pre_hooks\n\u001B[0;32m 1748\u001B[0m \u001B[38;5;129;01mor\u001B[39;00m _global_backward_pre_hooks \u001B[38;5;129;01mor\u001B[39;00m _global_backward_hooks\n\u001B[0;32m 1749\u001B[0m \u001B[38;5;129;01mor\u001B[39;00m _global_forward_hooks \u001B[38;5;129;01mor\u001B[39;00m _global_forward_pre_hooks):\n\u001B[1;32m-> 1750\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m forward_call(\u001B[38;5;241m*\u001B[39margs, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs)\n\u001B[0;32m 1752\u001B[0m result \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;01mNone\u001B[39;00m\n\u001B[0;32m 1753\u001B[0m called_always_called_hooks \u001B[38;5;241m=\u001B[39m \u001B[38;5;28mset\u001B[39m()\n",
|
||
"File \u001B[1;32mE:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\torch\\nn\\modules\\container.py:250\u001B[0m, in \u001B[0;36mSequential.forward\u001B[1;34m(self, input)\u001B[0m\n\u001B[0;32m 248\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mforward\u001B[39m(\u001B[38;5;28mself\u001B[39m, \u001B[38;5;28minput\u001B[39m):\n\u001B[0;32m 249\u001B[0m \u001B[38;5;28;01mfor\u001B[39;00m module \u001B[38;5;129;01min\u001B[39;00m \u001B[38;5;28mself\u001B[39m:\n\u001B[1;32m--> 250\u001B[0m \u001B[38;5;28minput\u001B[39m \u001B[38;5;241m=\u001B[39m module(\u001B[38;5;28minput\u001B[39m)\n\u001B[0;32m 251\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28minput\u001B[39m\n",
|
||
"File \u001B[1;32mE:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1739\u001B[0m, in \u001B[0;36mModule._wrapped_call_impl\u001B[1;34m(self, *args, **kwargs)\u001B[0m\n\u001B[0;32m 1737\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_compiled_call_impl(\u001B[38;5;241m*\u001B[39margs, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs) \u001B[38;5;66;03m# type: ignore[misc]\u001B[39;00m\n\u001B[0;32m 1738\u001B[0m \u001B[38;5;28;01melse\u001B[39;00m:\n\u001B[1;32m-> 1739\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_call_impl(\u001B[38;5;241m*\u001B[39margs, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs)\n",
|
||
"File \u001B[1;32mE:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1750\u001B[0m, in \u001B[0;36mModule._call_impl\u001B[1;34m(self, *args, **kwargs)\u001B[0m\n\u001B[0;32m 1745\u001B[0m \u001B[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001B[39;00m\n\u001B[0;32m 1746\u001B[0m \u001B[38;5;66;03m# this function, and just call forward.\u001B[39;00m\n\u001B[0;32m 1747\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m (\u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_backward_hooks \u001B[38;5;129;01mor\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_backward_pre_hooks \u001B[38;5;129;01mor\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_forward_hooks \u001B[38;5;129;01mor\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_forward_pre_hooks\n\u001B[0;32m 1748\u001B[0m \u001B[38;5;129;01mor\u001B[39;00m _global_backward_pre_hooks \u001B[38;5;129;01mor\u001B[39;00m _global_backward_hooks\n\u001B[0;32m 1749\u001B[0m \u001B[38;5;129;01mor\u001B[39;00m _global_forward_hooks \u001B[38;5;129;01mor\u001B[39;00m _global_forward_pre_hooks):\n\u001B[1;32m-> 1750\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m forward_call(\u001B[38;5;241m*\u001B[39margs, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs)\n\u001B[0;32m 1752\u001B[0m result \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;01mNone\u001B[39;00m\n\u001B[0;32m 1753\u001B[0m called_always_called_hooks \u001B[38;5;241m=\u001B[39m \u001B[38;5;28mset\u001B[39m()\n",
|
||
"File \u001B[1;32mE:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\torch\\nn\\modules\\linear.py:125\u001B[0m, in \u001B[0;36mLinear.forward\u001B[1;34m(self, input)\u001B[0m\n\u001B[0;32m 124\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mforward\u001B[39m(\u001B[38;5;28mself\u001B[39m, \u001B[38;5;28minput\u001B[39m: Tensor) \u001B[38;5;241m-\u001B[39m\u001B[38;5;241m>\u001B[39m Tensor:\n\u001B[1;32m--> 125\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m F\u001B[38;5;241m.\u001B[39mlinear(\u001B[38;5;28minput\u001B[39m, \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mweight, \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mbias)\n",
|
||
"\u001B[1;31mKeyboardInterrupt\u001B[0m: "
|
||
]
|
||
}
|
||
],
|
||
"execution_count": 44
|
||
},
|
||
{
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2025-03-19T15:57:09.420051500Z",
|
||
"start_time": "2025-03-19T15:47:37.553874Z"
|
||
}
|
||
},
|
||
"cell_type": "code",
|
||
"source": [
|
||
"def preprocess_test_data(data, cat_columns, feature_columns, label_encoders, scaler):\n",
|
||
" \"\"\"对测试数据进行预处理\"\"\"\n",
|
||
" numeric_cols = [col for col in feature_columns if col not in cat_columns]\n",
|
||
"\n",
|
||
" # 处理分类特征\n",
|
||
" X_cat = np.array([label_encoders[col].transform(data[col]) for col in cat_columns]).T\n",
|
||
"\n",
|
||
" # 处理数值特征\n",
|
||
" X_num = scaler.transform(data[numeric_cols])\n",
|
||
"\n",
|
||
" # 合并特征\n",
|
||
" X_processed = np.hstack([X_num, X_cat])\n",
|
||
"\n",
|
||
" # 保留原始 ts_code 和 trade_date\n",
|
||
" processed_data = data[['ts_code', 'trade_date']].copy()\n",
|
||
" processed_data['features'] = list(X_processed)\n",
|
||
"\n",
|
||
" return processed_data\n",
|
||
"\n",
|
||
"\n",
|
||
"def predict_and_save(model, test_data, cat_columns, feature_columns, label_encoders, scaler, output_path):\n",
|
||
" # 预处理测试数据\n",
|
||
" processed_data = preprocess_test_data(test_data, cat_columns, feature_columns, label_encoders, scaler)\n",
|
||
"\n",
|
||
" # 按天分组\n",
|
||
" grouped_data = processed_data.groupby('trade_date')\n",
|
||
"\n",
|
||
" # 存储预测结果\n",
|
||
" results = []\n",
|
||
"\n",
|
||
" model.eval()\n",
|
||
" with torch.no_grad():\n",
|
||
" for date, group in grouped_data:\n",
|
||
" # 准备输入数据\n",
|
||
" X_batch = np.stack(group['features'].values) # 形状 (num_stocks, num_features)\n",
|
||
"\n",
|
||
" X_batch = torch.tensor(X_batch, dtype=torch.float32).to(device)\n",
|
||
"\n",
|
||
" if X_batch.dim() == 2:\n",
|
||
" X_batch = X_batch.unsqueeze(0) # 形状变为 (1, num_stocks, num_features)\n",
|
||
"\n",
|
||
" # 预测\n",
|
||
" outputs = model(X_batch)\n",
|
||
" probs = torch.softmax(outputs, dim=-1).cpu().numpy() # 转换为概率\n",
|
||
"\n",
|
||
" # 取最后一列作为 score(假设最后一列是目标类别)\n",
|
||
" scores = probs[0, :, -1] # 形状 (num_stocks,)\n",
|
||
"\n",
|
||
" # 收集结果\n",
|
||
" for i in range(len(group)):\n",
|
||
" results.append({\n",
|
||
" 'trade_date': date,\n",
|
||
" 'score': scores[i],\n",
|
||
" 'ts_code': group['ts_code'].values[i],\n",
|
||
" })\n",
|
||
" # 保存为 DataFrame\n",
|
||
" result_df = pd.DataFrame(results)\n",
|
||
" result_df = result_df.loc[result_df.groupby('trade_date')['score'].idxmax()]\n",
|
||
" result_df.to_csv(output_path, index=False)\n",
|
||
" print(f\"预测结果已保存至 {output_path}\")\n",
|
||
"\n",
|
||
"\n",
|
||
"predict_and_save(\n",
|
||
" model=model,\n",
|
||
" test_data=test_data,\n",
|
||
" cat_columns=cat_columns,\n",
|
||
" feature_columns=feature_columns_new,\n",
|
||
" label_encoders=label_encoders,\n",
|
||
" scaler=scaler,\n",
|
||
" output_path=\"predictions_test.tsv\"\n",
|
||
")"
|
||
],
|
||
"id": "5d1522a7538db91b",
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"预测结果已保存至 predictions_test.tsv\n"
|
||
]
|
||
}
|
||
],
|
||
"execution_count": 38
|
||
}
|
||
],
|
||
"metadata": {
|
||
"kernelspec": {
|
||
"display_name": "Python 3 (ipykernel)",
|
||
"language": "python",
|
||
"name": "python3"
|
||
},
|
||
"language_info": {
|
||
"codemirror_mode": {
|
||
"name": "ipython",
|
||
"version": 3
|
||
},
|
||
"file_extension": ".py",
|
||
"mimetype": "text/x-python",
|
||
"name": "python",
|
||
"nbconvert_exporter": "python",
|
||
"pygments_lexer": "ipython3",
|
||
"version": "3.11.11"
|
||
}
|
||
},
|
||
"nbformat": 4,
|
||
"nbformat_minor": 5
|
||
}
|