Files
NewStock/main/train/TRank.ipynb
2025-04-28 11:02:52 +08:00

1718 lines
85 KiB
Plaintext
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
{
"cells": [
{
"cell_type": "code",
"id": "79a7758178bafdd3",
"metadata": {
"jupyter": {
"source_hidden": true
},
"ExecuteTime": {
"end_time": "2025-03-19T14:44:14.217154Z",
"start_time": "2025-03-19T14:44:13.935883Z"
}
},
"source": [
"# %load_ext autoreload\n",
"# %autoreload 2\n",
"\n",
"import pandas as pd\n",
"import warnings\n",
"\n",
"warnings.filterwarnings(\"ignore\")\n",
"\n",
"pd.set_option('display.max_columns', None)\n"
],
"outputs": [],
"execution_count": 1
},
{
"cell_type": "code",
"id": "a79cafb06a7e0e43",
"metadata": {
"scrolled": true,
"ExecuteTime": {
"end_time": "2025-03-19T14:45:19.808772Z",
"start_time": "2025-03-19T14:44:14.221159Z"
}
},
"source": [
"from code.utils.utils import read_and_merge_h5_data\n",
"\n",
"# One entry per HDF5 source: (progress label, file stem which is also the HDF key,\n",
"# columns to load, extra keyword arguments for the merge helper).\n",
"data_sources = [\n",
"    ('daily data', 'daily_data',\n",
"     ['ts_code', 'trade_date', 'open', 'close', 'high', 'low', 'vol', 'pct_chg'],\n",
"     {}),\n",
"    ('daily basic', 'daily_basic',\n",
"     ['ts_code', 'trade_date', 'turnover_rate', 'pe_ttm', 'circ_mv', 'volume_ratio', 'is_st'],\n",
"     {'join': 'inner'}),\n",
"    ('stk limit', 'stk_limit',\n",
"     ['ts_code', 'trade_date', 'pre_close', 'up_limit', 'down_limit'],\n",
"     {}),\n",
"    ('money flow', 'money_flow',\n",
"     ['ts_code', 'trade_date', 'buy_sm_vol', 'sell_sm_vol', 'buy_lg_vol', 'sell_lg_vol',\n",
"      'buy_elg_vol', 'sell_elg_vol', 'net_mf_vol'],\n",
"     {}),\n",
"    ('cyq perf', 'cyq_perf',\n",
"     ['ts_code', 'trade_date', 'his_low', 'his_high', 'cost_5pct', 'cost_15pct', 'cost_50pct',\n",
"      'cost_85pct', 'cost_95pct', 'weight_avg', 'winner_rate'],\n",
"     {}),\n",
"]\n",
"\n",
"# The first call starts from df=None; each later call receives the frame built so far.\n",
"df = None\n",
"for label, name, columns, merge_options in data_sources:\n",
"    print(label)\n",
"    df = read_and_merge_h5_data(f'../../data/{name}.h5', key=name, columns=columns,\n",
"                                df=df, **merge_options)\n",
"\n",
"# DataFrame.info() prints its report itself and returns None, so it is not wrapped in print().\n",
"df.info()"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"daily data\n",
"daily basic\n",
"inner merge on ['ts_code', 'trade_date']\n",
"stk limit\n",
"left merge on ['ts_code', 'trade_date']\n",
"money flow\n",
"left merge on ['ts_code', 'trade_date']\n",
"cyq perf\n",
"left merge on ['ts_code', 'trade_date']\n",
"<class 'pandas.core.frame.DataFrame'>\n",
"RangeIndex: 8418207 entries, 0 to 8418206\n",
"Data columns (total 31 columns):\n",
" # Column Dtype \n",
"--- ------ ----- \n",
" 0 ts_code object \n",
" 1 trade_date datetime64[ns]\n",
" 2 open float64 \n",
" 3 close float64 \n",
" 4 high float64 \n",
" 5 low float64 \n",
" 6 vol float64 \n",
" 7 pct_chg float64 \n",
" 8 turnover_rate float64 \n",
" 9 pe_ttm float64 \n",
" 10 circ_mv float64 \n",
" 11 volume_ratio float64 \n",
" 12 is_st bool \n",
" 13 up_limit float64 \n",
" 14 down_limit float64 \n",
" 15 buy_sm_vol float64 \n",
" 16 sell_sm_vol float64 \n",
" 17 buy_lg_vol float64 \n",
" 18 sell_lg_vol float64 \n",
" 19 buy_elg_vol float64 \n",
" 20 sell_elg_vol float64 \n",
" 21 net_mf_vol float64 \n",
" 22 his_low float64 \n",
" 23 his_high float64 \n",
" 24 cost_5pct float64 \n",
" 25 cost_15pct float64 \n",
" 26 cost_50pct float64 \n",
" 27 cost_85pct float64 \n",
" 28 cost_95pct float64 \n",
" 29 weight_avg float64 \n",
" 30 winner_rate float64 \n",
"dtypes: bool(1), datetime64[ns](1), float64(28), object(1)\n",
"memory usage: 1.9+ GB\n",
"None\n"
]
}
],
"execution_count": 2
},
{
"cell_type": "code",
"id": "cac01788dac10678",
"metadata": {
"jupyter": {
"source_hidden": true
},
"ExecuteTime": {
"end_time": "2025-03-19T14:45:35.459083Z",
"start_time": "2025-03-19T14:45:20.342424Z"
}
},
"source": [
"print('industry')\n",
"industry_df = read_and_merge_h5_data('../../data/industry_data.h5', key='industry_data',\n",
" columns=['ts_code', 'l2_code', 'in_date'],\n",
" df=None, on=['ts_code'], join='left')\n",
"\n",
"\n",
"def merge_with_industry_data(df, industry_df):\n",
"    \"\"\"Attach to each daily row the industry record in force on its trade date.\n",
"\n",
"    For every row of ``df`` the most recent ``industry_df`` row of the same\n",
"    ``ts_code`` with ``in_date <= trade_date`` is joined (backward as-of merge).\n",
"    Rows dated before the stock's first ``in_date`` fall back to that stock's\n",
"    earliest ``l2_code``.\n",
"\n",
"    Neither input frame is modified. The result is ordered by\n",
"    (trade_date, ts_code), carries a fresh 0..n-1 index and keeps the matched\n",
"    ``in_date`` column (NaT where only the fallback applied).\n",
"    \"\"\"\n",
"\n",
"    def with_datetime(frame, column):\n",
"        # Convert on a copy so the caller's frame is untouched; skip the copy\n",
"        # entirely when the column already holds datetimes.\n",
"        if pd.api.types.is_datetime64_any_dtype(frame[column]):\n",
"            return frame\n",
"        return frame.assign(**{column: pd.to_datetime(frame[column])})\n",
"\n",
"    # merge_asof needs both sides sorted by their 'on' key, so the date is the\n",
"    # primary sort key and ts_code only breaks ties.\n",
"    daily = with_datetime(df, 'trade_date').sort_values(['trade_date', 'ts_code'])\n",
"    industry = with_datetime(industry_df, 'in_date').sort_values(['in_date', 'ts_code'])\n",
"\n",
"    merged = pd.merge_asof(\n",
"        daily,\n",
"        industry,\n",
"        by='ts_code',\n",
"        left_on='trade_date',\n",
"        right_on='in_date',\n",
"        direction='backward'\n",
"    )\n",
"\n",
"    # Because industry is sorted by in_date, the first non-null l2_code of each\n",
"    # stock is its earliest classification.\n",
"    earliest_l2_code = industry.groupby('ts_code')['l2_code'].first()\n",
"\n",
"    # Rows with no in_date on or before trade_date have no match; use the earliest code.\n",
"    merged['l2_code'] = merged['l2_code'].fillna(merged['ts_code'].map(earliest_l2_code))\n",
"\n",
"    return merged.reset_index(drop=True)\n",
"\n",
"\n",
"# 使用示例\n",
"df = merge_with_industry_data(df, industry_df)\n",
"# print(mdf[mdf['ts_code'] == '600751.SH'][['ts_code', 'trade_date', 'l2_code']])"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"industry\n"
]
}
],
"execution_count": 3
},
{
"cell_type": "code",
"id": "c4e9e1d31da6dba6",
"metadata": {
"jupyter": {
"source_hidden": true
},
"ExecuteTime": {
"end_time": "2025-03-19T14:45:35.759821Z",
"start_time": "2025-03-19T14:45:35.492414Z"
}
},
"source": [
"def calculate_indicators(df):\n",
"    \"\"\"Derive momentum and sentiment indicators for one price series.\n",
"\n",
"    Reads the columns trade_date, close, pre_close, vol and amount and returns\n",
"    a date-sorted copy with these columns added: daily_return, RSI (14-row\n",
"    rolling means), MACD / Signal_line / MACD_hist (12/26/9 EMAs), up_ratio,\n",
"    up_ratio_20d, volume_mean, volume_change_rate, volatility, amount_mean and\n",
"    amount_change_rate (all on 20-row windows).\n",
"    \"\"\"\n",
"    df = df.sort_values('trade_date')\n",
"    close = df['close']\n",
"    prev_close = df['pre_close']\n",
"\n",
"    # Percentage move against the previous close.\n",
"    df['daily_return'] = (close - prev_close) / prev_close * 100\n",
"\n",
"    # RSI from plain rolling means of the upward and downward close changes.\n",
"    change = close.diff()\n",
"    mean_gain = change.where(change > 0, 0).rolling(window=14).mean()\n",
"    mean_loss = (-change.where(change < 0, 0)).rolling(window=14).mean()\n",
"    df['RSI'] = 100 - (100 / (1 + mean_gain / mean_loss))\n",
"\n",
"    # MACD: fast EMA minus slow EMA, its 9-span signal line and the histogram.\n",
"    fast_ema = close.ewm(span=12, adjust=False).mean()\n",
"    slow_ema = close.ewm(span=26, adjust=False).mean()\n",
"    df['MACD'] = fast_ema - slow_ema\n",
"    df['Signal_line'] = df['MACD'].ewm(span=9, adjust=False).mean()\n",
"    df['MACD_hist'] = df['MACD'] - df['Signal_line']\n",
"\n",
"    # Sentiment 1: share of up days over the last 20 rows (missing returns count as not up).\n",
"    df['up_ratio'] = (df['daily_return'] > 0).astype('int64')\n",
"    df['up_ratio_20d'] = df['up_ratio'].rolling(window=20).mean()\n",
"\n",
"    # Sentiment 2: volume relative to its 20-row mean, in percent.\n",
"    df['volume_mean'] = df['vol'].rolling(window=20).mean()\n",
"    df['volume_change_rate'] = (df['vol'] - df['volume_mean']) / df['volume_mean'] * 100\n",
"\n",
"    # Sentiment 3: standard deviation of the daily return over 20 rows.\n",
"    df['volatility'] = df['daily_return'].rolling(window=20).std()\n",
"\n",
"    # Sentiment 4: turnover amount relative to its 20-row mean, in percent.\n",
"    df['amount_mean'] = df['amount'].rolling(window=20).mean()\n",
"    df['amount_change_rate'] = (df['amount'] - df['amount_mean']) / df['amount_mean'] * 100\n",
"\n",
"    return df\n",
"\n",
"\n",
"def generate_index_indicators(h5_filename):\n",
"    \"\"\"Build one wide row of index indicators per trade date.\n",
"\n",
"    Reads the 'index_data' key of ``h5_filename``, runs ``calculate_indicators``\n",
"    on each ``ts_code`` separately and pivots the result so every indicator of\n",
"    every index becomes a column named ``<ts_code>_<indicator>``. ``trade_date``\n",
"    is returned as an ordinary column.\n",
"    \"\"\"\n",
"    index_quotes = pd.read_hdf(h5_filename, key='index_data')\n",
"    index_quotes['trade_date'] = pd.to_datetime(index_quotes['trade_date'], format='%Y%m%d')\n",
"    index_quotes = index_quotes.sort_values('trade_date')\n",
"\n",
"    # One indicator frame per index, visited in order of first appearance.\n",
"    per_index = [\n",
"        calculate_indicators(quotes.copy())\n",
"        for _, quotes in index_quotes.groupby('ts_code', sort=False)\n",
"    ]\n",
"    stacked = pd.concat(per_index, ignore_index=True)\n",
"\n",
"    # Long -> wide: one row per date, one (indicator, ts_code) column pair each.\n",
"    wide = stacked.pivot_table(\n",
"        index='trade_date',\n",
"        columns='ts_code',\n",
"        values=['daily_return', 'RSI', 'MACD', 'Signal_line',\n",
"                'MACD_hist', 'up_ratio_20d', 'volume_change_rate', 'volatility',\n",
"                'amount_change_rate', 'amount_mean'],\n",
"        aggfunc='last'\n",
"    )\n",
"\n",
"    # Flatten the two-level column index into '<ts_code>_<indicator>'.\n",
"    wide.columns = [f'{code}_{indicator}' for indicator, code in wide.columns]\n",
"    return wide.reset_index()\n",
"\n",
"\n",
"# Build the index indicator table and discard every date that has a missing value.\n",
"h5_filename = '../../data/index_data.h5'\n",
"index_data = generate_index_indicators(h5_filename).dropna()\n"
],
"outputs": [],
"execution_count": 4
},
{
"cell_type": "code",
"id": "a735bc02ceb4d872",
"metadata": {
"jupyter": {
"source_hidden": true
},
"ExecuteTime": {
"end_time": "2025-03-19T14:45:35.904531Z",
"start_time": "2025-03-19T14:45:35.790296Z"
}
},
"source": [
"import numpy as np\n",
"import talib\n",
"\n",
"\n",
"def get_rolling_factor(df):\n",
"    \"\"\"Add per-stock rolling (time-series) factors to a daily quote frame.\n",
"\n",
"    The frame is sorted by (ts_code, trade_date) and re-indexed 0..n-1, so every\n",
"    per-stock result lines up with its row by index label. All look-back\n",
"    statistics are computed inside each ``ts_code`` group; a window never mixes\n",
"    the history of two different stocks.\n",
"\n",
"    Uses ``talib`` (OBV, SMA, EMA) imported by this notebook.\n",
"\n",
"    Returns:\n",
"        (df, new_columns): the enriched copy of ``df`` and the names of the\n",
"        columns added here. Helper columns whose name starts with ``_`` are\n",
"        still present in the returned frame.\n",
"    \"\"\"\n",
"    old_columns = df.columns.tolist()\n",
"\n",
"    # Sort by stock and date. The index is renumbered because several results\n",
"    # below come back on a plain positional index and are assigned by label.\n",
"    df = df.sort_values(by=['ts_code', 'trade_date'], ignore_index=True)\n",
"    grouped = df.groupby('ts_code', group_keys=False)\n",
"\n",
"    # groupby(...).rolling(...) returns a (ts_code, row) MultiIndex; dropping\n",
"    # level 0 restores the row index.\n",
"    df['return_skew'] = grouped['pct_chg'].rolling(window=5).skew().reset_index(0, drop=True)\n",
"    df['return_kurtosis'] = grouped['pct_chg'].rolling(window=5).kurt().reset_index(0, drop=True)\n",
"\n",
"    # Factor 1: short-term volume change (2-day mean over 10-day mean, minus 1).\n",
"    df['volume_change_rate'] = (\n",
"        grouped['vol'].rolling(window=2).mean() /\n",
"        grouped['vol'].rolling(window=10).mean() - 1\n",
"    ).reset_index(level=0, drop=True)\n",
"\n",
"    # Factor 2: volume breakout. The reference is the max of the PREVIOUS five\n",
"    # days; a window that contains today can never be exceeded by today.\n",
"    prev_max_volume = grouped['vol'].apply(lambda x: x.shift(1).rolling(window=5).max())\n",
"    df['cat_volume_breakout'] = (df['vol'] > prev_max_volume)\n",
"\n",
"    # Factor 3: turnover deviation from its own 3-day mean, in 3-day std units.\n",
"    mean_turnover = grouped['turnover_rate'].rolling(window=3).mean().reset_index(level=0, drop=True)\n",
"    std_turnover = grouped['turnover_rate'].rolling(window=3).std().reset_index(level=0, drop=True)\n",
"    df['turnover_deviation'] = (df['turnover_rate'] - mean_turnover) / std_turnover\n",
"\n",
"    # Factor 4: turnover spike. Within a 3-value sample no value can lie 2 std\n",
"    # above the mean, so the threshold is taken from the previous three days.\n",
"    prev_mean_turnover = grouped['turnover_rate'].apply(lambda x: x.shift(1).rolling(window=3).mean())\n",
"    prev_std_turnover = grouped['turnover_rate'].apply(lambda x: x.shift(1).rolling(window=3).std())\n",
"    df['cat_turnover_spike'] = (df['turnover_rate'] > prev_mean_turnover + 2 * prev_std_turnover)\n",
"\n",
"    # Factor 5: 3-day mean of the volume ratio.\n",
"    df['avg_volume_ratio'] = grouped['volume_ratio'].rolling(window=3).mean().reset_index(level=0, drop=True)\n",
"\n",
"    # Factor 6: volume-ratio breakout against the previous five days.\n",
"    prev_max_volume_ratio = grouped['volume_ratio'].apply(lambda x: x.shift(1).rolling(window=5).max())\n",
"    df['cat_volume_ratio_breakout'] = (df['volume_ratio'] > prev_max_volume_ratio)\n",
"\n",
"    # Despite its name this column is the 20-day mean volume.\n",
"    df['vol_spike'] = grouped['vol'].apply(lambda x: x.rolling(20).mean())\n",
"    # 5-day std of the day-over-day volume change, per stock.\n",
"    df['vol_std_5'] = grouped['vol'].apply(lambda x: x.pct_change().rolling(5).std())\n",
"\n",
"    # OBV and its 6-day simple moving average.\n",
"    df['obv'] = grouped.apply(\n",
"        lambda x: pd.Series(talib.OBV(x['close'].values, x['vol'].values), index=x.index)\n",
"    )\n",
"    df['maobv_6'] = grouped['obv'].apply(\n",
"        lambda x: pd.Series(talib.SMA(x.values, timeperiod=6), index=x.index)\n",
"    )\n",
"\n",
"    # Trailing close-to-close returns.\n",
"    for days in (5, 10, 20):\n",
"        df[f'return_{days}'] = grouped['close'].apply(lambda x, d=days: x / x.shift(d) - 1)\n",
"\n",
"    # Std of daily close returns over several windows; the '_2' variant is the\n",
"    # 90-day figure as of ten rows earlier.\n",
"    for window in (5, 15, 25, 90):\n",
"        df[f'std_return_{window}'] = grouped['close'].apply(\n",
"            lambda x, w=window: x.pct_change().rolling(window=w).std()\n",
"        )\n",
"    df['std_return_90_2'] = grouped['close'].apply(lambda x: x.shift(10).pct_change().rolling(window=90).std())\n",
"\n",
"    # EMAs of the close (helper columns).\n",
"    ema_periods = (5, 13, 20, 60)\n",
"    for period in ema_periods:\n",
"        df[f'_ema_{period}'] = grouped['close'].apply(\n",
"            lambda x, p=period: pd.Series(talib.EMA(x.values, timeperiod=p), index=x.index)\n",
"        )\n",
"\n",
"    # act_factor1..4: slope angle of each EMA. arctan of the one-day percentage\n",
"    # change, times 57.3 (about degrees per radian), divided by a per-EMA scale.\n",
"    for number, (period, scale) in enumerate(zip(ema_periods, (50, 40, 21, 10)), start=1):\n",
"        df[f'act_factor{number}'] = grouped[f'_ema_{period}'].apply(\n",
"            lambda x, s=scale: np.arctan((x / x.shift(1) - 1) * 100) * 57.3 / s\n",
"        )\n",
"\n",
"    df['log(circ_mv)'] = np.log(df['circ_mv'])\n",
"\n",
"    # Alpha-22 style factor: -delta(cov(high, vol, 5), 5) * rank(std(close, 20)).\n",
"    # The time-series parts are per stock; the rank is taken across stocks on\n",
"    # each trade date, like alpha_007 and alpha_013 below.\n",
"    window_high_volume = 5\n",
"    window_close_stddev = 20\n",
"    period_delta = 5\n",
"    df['cov'] = grouped.apply(lambda x: x['high'].rolling(window_high_volume).cov(x['vol']))\n",
"    df['delta_cov'] = grouped['cov'].apply(lambda x: x.diff(period_delta))\n",
"    close_stddev = grouped['close'].apply(lambda x: x.rolling(window_close_stddev).std())\n",
"    df['_rank_stddev'] = close_stddev.groupby(df['trade_date']).rank(pct=True)\n",
"    df['alpha_22_improved'] = -1 * df['delta_cov'] * df['_rank_stddev']\n",
"\n",
"    # Candle body as a fraction of the day's range; 0 when high equals low.\n",
"    df['alpha_003'] = np.where(df['high'] != df['low'],\n",
"                               (df['close'] - df['open']) / (df['high'] - df['low']),\n",
"                               0)\n",
"\n",
"    # 5-day close/volume correlation per stock, then percentile rank per date.\n",
"    df['alpha_007'] = grouped.apply(lambda x: x['close'].rolling(5).corr(x['vol']))\n",
"    df['alpha_007'] = df.groupby('trade_date', group_keys=False)['alpha_007'].rank(ascending=True, pct=True)\n",
"\n",
"    # 5-day minus 20-day sum of closes per stock, then percentile rank per date.\n",
"    df['alpha_013'] = grouped['close'].transform(lambda x: x.rolling(5).sum() - x.rolling(20).sum())\n",
"    df['alpha_013'] = df.groupby('trade_date', group_keys=False)['alpha_013'].rank(ascending=True, pct=True)\n",
"\n",
"    # Limit-up / limit-down closes and how often they occurred in the last 10 days.\n",
"    df['cat_up_limit'] = (df['close'] == df['up_limit'])\n",
"    df['cat_down_limit'] = (df['close'] == df['down_limit'])\n",
"    df['up_limit_count_10d'] = grouped['cat_up_limit'].rolling(window=10, min_periods=1).sum().reset_index(level=0,\n",
"                                                                                                          drop=True)\n",
"    df['down_limit_count_10d'] = grouped['cat_down_limit'].rolling(window=10, min_periods=1).sum().reset_index(level=0,\n",
"                                                                                                              drop=True)\n",
"\n",
"    def consecutive_true_count(flags):\n",
"        # Length of the current run of True values; 0 on False rows.\n",
"        run_id = (flags != flags.shift()).cumsum()\n",
"        return flags * (flags.groupby(run_id).cumcount() + 1)\n",
"\n",
"    # Current streak of limit-up closes.\n",
"    df['consecutive_up_limit'] = grouped['cat_up_limit'].apply(consecutive_true_count)\n",
"\n",
"    df['vol_break'] = np.where((df['close'] > df['cost_85pct']) & (df['volume_ratio'] > 2), 1, 0)\n",
"\n",
"    df['weight_roc5'] = grouped['weight_avg'].apply(lambda x: x.pct_change(5))\n",
"\n",
"    def rolling_corr(group):\n",
"        # 10-day correlation between daily changes of close and of weight_avg.\n",
"        roc_close = group['close'].pct_change()\n",
"        roc_weight = group['weight_avg'].pct_change()\n",
"        return roc_close.rolling(10).corr(roc_weight)\n",
"\n",
"    df['price_cost_divergence'] = grouped.apply(rolling_corr)\n",
"\n",
"    df['smallcap_concentration'] = (1 / df['circ_mv']) * (df['cost_85pct'] - df['cost_15pct'])\n",
"\n",
"    # 20-day coefficient of variation of weight_avg.\n",
"    df['weight_std20'] = grouped['weight_avg'].apply(lambda x: x.rolling(20).std())\n",
"    df['cost_stability'] = df['weight_std20'] / grouped['weight_avg'].transform(lambda x: x.rolling(20).mean())\n",
"\n",
"    # Number of the last 5 days on which close was above cost_95pct.\n",
"    df['high_cost_break_days'] = grouped.apply(lambda g: g['close'].gt(g['cost_95pct']).rolling(5).sum())\n",
"\n",
"    # Width of the 5%-95% cost band relative to the 10-day mean volume.\n",
"    df['liquidity_risk'] = (df['cost_95pct'] - df['cost_5pct']) * (\n",
"            1 / grouped['vol'].transform(lambda x: x.rolling(10).mean()))\n",
"\n",
"    # 20-day turnover std, raw and scaled by circulating market value.\n",
"    # The scaling is a row-wise division, so no grouping is needed.\n",
"    df['turnover_std'] = grouped['turnover_rate'].rolling(window=20).std().reset_index(level=0, drop=True)\n",
"    df['mv_volatility'] = df['turnover_std'] / df['circ_mv']\n",
"\n",
"    # 20-day volume growth, raw and scaled by circulating market value.\n",
"    df['volume_growth'] = grouped['vol'].pct_change(periods=20)\n",
"    df['mv_growth'] = df['volume_growth'] / df['circ_mv']\n",
"\n",
"    df.drop(columns=['weight_std20'], inplace=True, errors='ignore')\n",
"    new_columns = [col for col in df.columns.tolist() if col not in old_columns]\n",
"\n",
"    return df, new_columns\n",
"\n",
"\n",
"def get_simple_factor(df):\n",
"    \"\"\"Add row-wise factors built from columns that already exist.\n",
"\n",
"    Besides the raw quote, money-flow and cost-distribution columns, this reads\n",
"    volume_change_rate, turnover_deviation, std_return_5/25/90/90_2 and\n",
"    act_factor1..4, so those must be present in ``df``.\n",
"\n",
"    ``atr_14`` is used when the frame has it; otherwise a 14-period ATR is\n",
"    computed per stock with ``talib.ATR`` for this calculation only and is not\n",
"    added to the frame.\n",
"\n",
"    Returns:\n",
"        (df, new_columns): a copy sorted by (ts_code, trade_date) with every\n",
"        column starting with ``_`` removed, and the names of the added columns.\n",
"    \"\"\"\n",
"    old_columns = df.columns.tolist()\n",
"    df = df.sort_values(by=['ts_code', 'trade_date'])\n",
"\n",
"    alpha = 0.5\n",
"    df['momentum_factor'] = df['volume_change_rate'] + alpha * df['turnover_deviation']\n",
"    df['resonance_factor'] = df['volume_ratio'] * df['pct_chg']\n",
"    df['log_close'] = np.log(df['close'])\n",
"\n",
"    # Upper and lower candle shadows as a fraction of the close.\n",
"    df['up'] = (df['high'] - df[['close', 'open']].max(axis=1)) / df['close']\n",
"    df['down'] = (df[['close', 'open']].min(axis=1) - df['low']) / df['close']\n",
"\n",
"    # Short-window return std relative to longer windows.\n",
"    df['std_return_5 / std_return_90'] = df['std_return_5'] / df['std_return_90']\n",
"    df['std_return_5 / std_return_25'] = df['std_return_5'] / df['std_return_25']\n",
"\n",
"    # Change of the 90-day return std against its value ten rows earlier.\n",
"    df['std_return_90 - std_return_90_2'] = df['std_return_90'] - df['std_return_90_2']\n",
"\n",
"    # Ordering flags between the EMA slope factors.\n",
"    df['cat_af1'] = df['act_factor1'] > 0\n",
"    df['cat_af2'] = df['act_factor2'] > df['act_factor1']\n",
"    df['cat_af3'] = df['act_factor3'] > df['act_factor2']\n",
"    df['cat_af4'] = df['act_factor4'] > df['act_factor3']\n",
"\n",
"    df['act_factor5'] = df['act_factor1'] + df['act_factor2'] + df['act_factor3'] + df['act_factor4']\n",
"    df['act_factor6'] = (df['act_factor1'] - df['act_factor2']) / np.sqrt(\n",
"        df['act_factor1'] ** 2 + df['act_factor2'] ** 2)\n",
"\n",
"    # Order-size buckets relative to the net money-flow volume.\n",
"    net_volume = df['net_mf_vol']\n",
"    df['active_buy_volume_large'] = df['buy_lg_vol'] / net_volume\n",
"    df['active_buy_volume_big'] = df['buy_elg_vol'] / net_volume\n",
"    df['active_buy_volume_small'] = df['buy_sm_vol'] / net_volume\n",
"\n",
"    df['buy_lg_vol_minus_sell_lg_vol'] = (df['buy_lg_vol'] - df['sell_lg_vol']) / net_volume\n",
"    df['buy_elg_vol_minus_sell_elg_vol'] = (df['buy_elg_vol'] - df['sell_elg_vol']) / net_volume\n",
"\n",
"    df['log(circ_mv)'] = np.log(df['circ_mv'])\n",
"\n",
"    # Cost-distribution (chip) factors.\n",
"    history_range = df['his_high'] - df['his_low']\n",
"    cost_band_90 = df['cost_95pct'] - df['cost_5pct']\n",
"    lower_half_band = df['cost_50pct'] - df['cost_5pct']\n",
"\n",
"    df['ctrl_strength'] = (df['cost_85pct'] - df['cost_15pct']) / history_range\n",
"\n",
"    df['low_cost_dev'] = (df['close'] - df['cost_5pct']) / lower_half_band\n",
"\n",
"    df['asymmetry'] = (df['cost_95pct'] - df['cost_50pct']) / lower_half_band\n",
"\n",
"    df['lock_factor'] = df['turnover_rate'] * (1 - cost_band_90 / history_range)\n",
"\n",
"    df['cat_vol_break'] = (df['close'] > df['cost_85pct']) & (df['volume_ratio'] > 2)\n",
"\n",
"    # Use the caller's atr_14 when available; otherwise derive it per stock so a\n",
"    # frame without that column no longer fails with a KeyError.\n",
"    if 'atr_14' in df.columns:\n",
"        atr_14 = df['atr_14']\n",
"    else:\n",
"        atr_parts = [\n",
"            pd.Series(talib.ATR(g['high'].values, g['low'].values, g['close'].values, timeperiod=14),\n",
"                      index=g.index)\n",
"            for _, g in df.groupby('ts_code')\n",
"        ]\n",
"        atr_14 = pd.concat(atr_parts) if atr_parts else pd.Series(dtype='float64')\n",
"    df['cost_atr_adj'] = cost_band_90 / atr_14\n",
"\n",
"    # Width of the 15%-85% cost band scaled by circulating market value.\n",
"    df['smallcap_concentration'] = (1 / df['circ_mv']) * (df['cost_85pct'] - df['cost_15pct'])\n",
"\n",
"    df['cat_golden_resonance'] = ((df['close'] > df['weight_avg']) &\n",
"                                  (df['volume_ratio'] > 1.5) &\n",
"                                  (df['winner_rate'] > 0.7))\n",
"\n",
"    # Market-value scaled activity measures.\n",
"    # NOTE(review): mv_weighted_turnover repeats mv_turnover_ratio and\n",
"    # nonlinear_mv_volume repeats mv_adjusted_volume as written - confirm intent.\n",
"    df['mv_turnover_ratio'] = df['turnover_rate'] / df['circ_mv']\n",
"\n",
"    df['mv_adjusted_volume'] = df['vol'] / df['circ_mv']\n",
"\n",
"    df['mv_weighted_turnover'] = df['turnover_rate'] * (1 / df['circ_mv'])\n",
"\n",
"    df['nonlinear_mv_volume'] = df['vol'] / df['circ_mv']\n",
"\n",
"    df['mv_volume_ratio'] = df['volume_ratio'] / df['circ_mv']\n",
"\n",
"    df['mv_momentum'] = df['turnover_rate'] * df['volume_ratio'] / df['circ_mv']\n",
"\n",
"    # Helper columns (leading underscore) are not features.\n",
"    drop_columns = [col for col in df.columns if col.startswith('_')]\n",
"    df.drop(columns=drop_columns, inplace=True, errors='ignore')\n",
"\n",
"    new_columns = [col for col in df.columns.tolist() if col not in old_columns]\n",
"    return df, new_columns\n"
],
"outputs": [],
"execution_count": 5
},
{
"cell_type": "code",
"id": "53f86ddc0677a6d7",
"metadata": {
"jupyter": {
"source_hidden": true
},
"scrolled": true,
"ExecuteTime": {
"end_time": "2025-03-19T14:45:43.390734Z",
"start_time": "2025-03-19T14:45:35.933466Z"
}
},
"source": [
"from code.utils.factor import get_act_factor\n",
"\n",
"\n",
"def read_industry_data(h5_filename):\n",
"    \"\"\"Load industry index quotes and turn them into per-industry daily features.\n",
"\n",
"    Reads the 'sw_daily' key of ``h5_filename``, adds OBV, 5- and 20-day\n",
"    returns and the output of ``get_act_factor``, then the per-date percentile\n",
"    ranks of both returns. The raw quote columns are dropped, every remaining\n",
"    feature is prefixed with ``industry_`` and ``ts_code`` is renamed to\n",
"    ``cat_l2_code``. Rows are ordered by (trade_date, ts_code).\n",
"    \"\"\"\n",
"    quote_columns = ['open', 'close', 'high', 'low', 'pe', 'pb', 'vol']\n",
"    key_columns = ['ts_code', 'trade_date']\n",
"\n",
"    industry_data = pd.read_hdf(h5_filename, key='sw_daily', columns=key_columns + quote_columns)\n",
"    industry_data = industry_data.sort_values(by=key_columns)\n",
"    industry_data['trade_date'] = pd.to_datetime(industry_data['trade_date'], format='%Y%m%d')\n",
"\n",
"    # Time-series features are computed inside each industry code.\n",
"    by_industry = industry_data.groupby('ts_code', group_keys=False)\n",
"    industry_data['obv'] = by_industry.apply(\n",
"        lambda x: pd.Series(talib.OBV(x['close'].values, x['vol'].values), index=x.index)\n",
"    )\n",
"    for horizon in (5, 20):\n",
"        industry_data[f'return_{horizon}'] = by_industry['close'].apply(\n",
"            lambda x, h=horizon: x / x.shift(h) - 1\n",
"        )\n",
"\n",
"    industry_data = get_act_factor(industry_data, cat=False)\n",
"    industry_data = industry_data.sort_values(by=['trade_date', 'ts_code'])\n",
"\n",
"    # Percentile rank of each industry's return among all industries on the same date.\n",
"    by_date = industry_data.groupby('trade_date')\n",
"    industry_data['return_5_percentile'] = by_date['return_5'].rank(pct=True)\n",
"    industry_data['return_20_percentile'] = by_date['return_20'].rank(pct=True)\n",
"\n",
"    industry_data = industry_data.drop(columns=quote_columns)\n",
"\n",
"    # Prefix the features and rename the key in a single pass.\n",
"    renamed = {col: f'industry_{col}' for col in industry_data.columns if col not in key_columns}\n",
"    renamed['ts_code'] = 'cat_l2_code'\n",
"    return industry_data.rename(columns=renamed)\n",
"\n",
"\n",
"industry_df = read_industry_data('../../data/sw_daily.h5')\n"
],
"outputs": [],
"execution_count": 6
},
{
"cell_type": "code",
"id": "dbe2fd8021b9417f",
"metadata": {
"ExecuteTime": {
"end_time": "2025-03-19T14:45:43.435718Z",
"start_time": "2025-03-19T14:45:43.429472Z"
}
},
"source": [
"# Columns of df, minus a fixed exclusion list, minus every name that is also a\n",
"# column of index_data, minus names containing 'cyq'.\n",
"excluded_names = {'turnover_rate', 'pe_ttm', 'volume_ratio', 'vol', 'pct_chg', 'l2_code', 'winner_rate'}\n",
"excluded_names.update(index_data.columns)\n",
"origin_columns = [col for col in df.columns.tolist()\n",
"                  if col not in excluded_names and 'cyq' not in col]\n",
"print(origin_columns)"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['ts_code', 'open', 'close', 'high', 'low', 'circ_mv', 'is_st', 'up_limit', 'down_limit', 'buy_sm_vol', 'sell_sm_vol', 'buy_lg_vol', 'sell_lg_vol', 'buy_elg_vol', 'sell_elg_vol', 'net_mf_vol', 'his_low', 'his_high', 'cost_5pct', 'cost_15pct', 'cost_50pct', 'cost_85pct', 'cost_95pct', 'weight_avg', 'in_date']\n"
]
}
],
"execution_count": 7
},
{
"cell_type": "code",
"id": "85c3e3d0235ffffa",
"metadata": {
"ExecuteTime": {
"end_time": "2025-03-19T14:45:43.684712Z",
"start_time": "2025-03-19T14:45:43.492971Z"
}
},
"source": [
"print(df[df['is_st']][['ts_code', 'trade_date', 'is_st']])"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" ts_code trade_date is_st\n",
"29 000037.SZ 2017-01-03 True\n",
"72 000408.SZ 2017-01-03 True\n",
"95 000504.SZ 2017-01-03 True\n",
"96 000505.SZ 2017-01-03 True\n",
"101 000511.SZ 2017-01-03 True\n",
"... ... ... ...\n",
"8417201 603869.SH 2025-03-13 True\n",
"8417206 603879.SH 2025-03-13 True\n",
"8417253 603959.SH 2025-03-13 True\n",
"8417635 688282.SH 2025-03-13 True\n",
"8417639 688287.SH 2025-03-13 True\n",
"\n",
"[191519 rows x 3 columns]\n"
]
}
],
"execution_count": 8
},
{
"cell_type": "code",
"id": "92d84ce15a562ec6",
"metadata": {
"ExecuteTime": {
"end_time": "2025-03-19T14:47:42.395141Z",
"start_time": "2025-03-19T14:45:43.718152Z"
}
},
"source": [
"def filter_data(df):\n",
"    \"\"\"Restrict the frame to the stock universe and date range used downstream.\n",
"\n",
"    Removes ST-flagged rows, codes ending in 'BJ', codes starting with '30',\n",
"    '68' or '8', and rows dated before 2018-01-01. Columns whose name starts\n",
"    with '_' are dropped and the index is renumbered from 0.\n",
"    \"\"\"\n",
"    code = df['ts_code'].str\n",
"    excluded_board = code.startswith('30') | code.startswith('68') | code.startswith('8')\n",
"    keep = (\n",
"        ~df['is_st']\n",
"        & ~code.endswith('BJ')\n",
"        & ~excluded_board\n",
"        & (df['trade_date'] >= '20180101')\n",
"    )\n",
"    kept = df[keep]\n",
"\n",
"    helper_columns = [name for name in kept.columns if name.startswith('_')]\n",
"    return kept.drop(columns=helper_columns).reset_index(drop=True)\n",
"\n",
"\n",
"df = filter_data(df)\n",
"df, _ = get_rolling_factor(df)\n",
"\n",
"# The industry-level (industry_df) and index-level (index_data) tables are not\n",
"# merged in at this point; only the industry code is renamed to mark it categorical.\n",
"# df = df.merge(industry_df, on=['l2_code', 'trade_date'], how='left')\n",
"df = df.rename(columns={'l2_code': 'cat_l2_code'})\n",
"# df = df.merge(index_data, on='trade_date', how='left')\n",
"\n",
"# DataFrame.info() prints its report itself and returns None, so it is not wrapped in print().\n",
"df.info()"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<class 'pandas.core.frame.DataFrame'>\n",
"Index: 5084263 entries, 0 to 5084262\n",
"Data columns (total 85 columns):\n",
" # Column Dtype \n",
"--- ------ ----- \n",
" 0 ts_code object \n",
" 1 trade_date datetime64[ns]\n",
" 2 open float64 \n",
" 3 close float64 \n",
" 4 high float64 \n",
" 5 low float64 \n",
" 6 vol float64 \n",
" 7 pct_chg float64 \n",
" 8 turnover_rate float64 \n",
" 9 pe_ttm float64 \n",
" 10 circ_mv float64 \n",
" 11 volume_ratio float64 \n",
" 12 is_st bool \n",
" 13 up_limit float64 \n",
" 14 down_limit float64 \n",
" 15 buy_sm_vol float64 \n",
" 16 sell_sm_vol float64 \n",
" 17 buy_lg_vol float64 \n",
" 18 sell_lg_vol float64 \n",
" 19 buy_elg_vol float64 \n",
" 20 sell_elg_vol float64 \n",
" 21 net_mf_vol float64 \n",
" 22 his_low float64 \n",
" 23 his_high float64 \n",
" 24 cost_5pct float64 \n",
" 25 cost_15pct float64 \n",
" 26 cost_50pct float64 \n",
" 27 cost_85pct float64 \n",
" 28 cost_95pct float64 \n",
" 29 weight_avg float64 \n",
" 30 winner_rate float64 \n",
" 31 cat_l2_code object \n",
" 32 in_date datetime64[ns]\n",
" 33 return_skew float64 \n",
" 34 return_kurtosis float64 \n",
" 35 volume_change_rate float64 \n",
" 36 cat_volume_breakout bool \n",
" 37 turnover_deviation float64 \n",
" 38 cat_turnover_spike bool \n",
" 39 avg_volume_ratio float64 \n",
" 40 cat_volume_ratio_breakout bool \n",
" 41 vol_spike float64 \n",
" 42 vol_std_5 float64 \n",
" 43 obv float64 \n",
" 44 maobv_6 float64 \n",
" 45 return_5 float64 \n",
" 46 return_10 float64 \n",
" 47 return_20 float64 \n",
" 48 std_return_5 float64 \n",
" 49 std_return_15 float64 \n",
" 50 std_return_25 float64 \n",
" 51 std_return_90 float64 \n",
" 52 std_return_90_2 float64 \n",
" 53 _ema_5 float64 \n",
" 54 _ema_13 float64 \n",
" 55 _ema_20 float64 \n",
" 56 _ema_60 float64 \n",
" 57 act_factor1 float64 \n",
" 58 act_factor2 float64 \n",
" 59 act_factor3 float64 \n",
" 60 act_factor4 float64 \n",
" 61 log(circ_mv) float64 \n",
" 62 cov float64 \n",
" 63 delta_cov float64 \n",
" 64 _rank_stddev float64 \n",
" 65 alpha_22_improved float64 \n",
" 66 alpha_003 float64 \n",
" 67 alpha_007 float64 \n",
" 68 alpha_013 float64 \n",
" 69 cat_up_limit bool \n",
" 70 cat_down_limit bool \n",
" 71 up_limit_count_10d float64 \n",
" 72 down_limit_count_10d float64 \n",
" 73 consecutive_up_limit int64 \n",
" 74 vol_break int32 \n",
" 75 weight_roc5 float64 \n",
" 76 price_cost_divergence float64 \n",
" 77 smallcap_concentration float64 \n",
" 78 cost_stability float64 \n",
" 79 high_cost_break_days float64 \n",
" 80 liquidity_risk float64 \n",
" 81 turnover_std float64 \n",
" 82 mv_volatility float64 \n",
" 83 volume_growth float64 \n",
" 84 mv_growth float64 \n",
"dtypes: bool(6), datetime64[ns](2), float64(73), int32(1), int64(1), object(2)\n",
"memory usage: 3.0+ GB\n",
"None\n"
]
}
],
"execution_count": 9
},
{
"cell_type": "code",
"id": "f4f16d63ad18d1bc",
"metadata": {
"jupyter": {
"source_hidden": true
},
"ExecuteTime": {
"end_time": "2025-03-19T14:47:42.645061Z",
"start_time": "2025-03-19T14:47:42.637235Z"
}
},
"source": [
"def create_deviation_within_dates(df, feature_columns):\n",
"    \"\"\"Add each feature's deviation from its same-day, same-industry mean.\n",
"\n",
"    For every eligible name in ``feature_columns`` a column\n",
"    ``deviation_mean_<feature>`` is created holding the value minus the mean of\n",
"    that feature over rows sharing ``trade_date`` and ``cat_l2_code``.\n",
"\n",
"    A feature is skipped when it is ``trade_date`` or when its name contains\n",
"    any of 'cat', 'index', 'industry', 'limit' or 'cyq'.\n",
"\n",
"    Calling the function again on its own output replaces existing deviation\n",
"    columns instead of adding duplicates.\n",
"\n",
"    Returns:\n",
"        (df, ret_feature_columns): the frame with the deviation columns\n",
"        appended, and ``feature_columns`` extended by the new column names.\n",
"    \"\"\"\n",
"    groupby_col = 'cat_l2_code'  # means are taken per trade_date AND industry code\n",
"    excluded_tokens = ('cat', 'index', 'industry', 'limit', 'cyq')\n",
"    ret_feature_columns = feature_columns[:]\n",
"\n",
"    num_features = [\n",
"        col for col in feature_columns\n",
"        if col != 'trade_date' and not any(token in col for token in excluded_tokens)\n",
"    ]\n",
"\n",
"    # The grouping is the same for every feature, so it is built once.\n",
"    grouped = df.groupby(['trade_date', groupby_col])\n",
"\n",
"    new_columns = {}\n",
"    for feature in num_features:\n",
"        deviation_col_name = f'deviation_mean_{feature}'\n",
"        new_columns[deviation_col_name] = df[feature] - grouped[feature].transform('mean')\n",
"        if deviation_col_name not in ret_feature_columns:\n",
"            ret_feature_columns.append(deviation_col_name)\n",
"\n",
"    if new_columns:\n",
"        # Drop stale copies first so a repeated call does not create duplicate names.\n",
"        stale = [name for name in new_columns if name in df.columns]\n",
"        base = df.drop(columns=stale) if stale else df\n",
"        df = pd.concat([base, pd.DataFrame(new_columns)], axis=1)\n",
"\n",
"    return df, ret_feature_columns\n"
],
"outputs": [],
"execution_count": 10
},
{
"cell_type": "code",
"id": "40e6b68a91b30c79",
"metadata": {
"ExecuteTime": {
"end_time": "2025-03-19T14:47:42.840210Z",
"start_time": "2025-03-19T14:47:42.740146Z"
}
},
"source": [
"import pandas as pd\n",
"\n",
"\n",
"def remove_outliers_label_percentile(label: pd.Series, lower_percentile: float = 0.01, upper_percentile: float = 0.99,\n",
" log=True):\n",
" if not (0 <= lower_percentile < upper_percentile <= 1):\n",
" raise ValueError(\"Percentile values must satisfy 0 <= lower_percentile < upper_percentile <= 1.\")\n",
"\n",
" # Calculate lower and upper bounds based on percentiles\n",
" lower_bound = label.quantile(lower_percentile)\n",
" upper_bound = label.quantile(upper_percentile)\n",
"\n",
" # Filter out values outside the bounds\n",
" filtered_label = label[(label >= lower_bound) & (label <= upper_bound)]\n",
"\n",
" # Print the number of removed outliers\n",
" if log:\n",
" print(f\"Removed {len(label) - len(filtered_label)} outliers.\")\n",
" return filtered_label\n",
"\n",
"\n",
"def calculate_risk_adjusted_target(df, days=5):\n",
" df = df.sort_values(by=['ts_code', 'trade_date'])\n",
"\n",
" df['future_close'] = df.groupby('ts_code')['close'].shift(-days)\n",
" df['future_open'] = df.groupby('ts_code')['open'].shift(-1)\n",
" df['future_return'] = (df['future_close'] - df['future_open']) / df['future_open']\n",
"\n",
" df['future_volatility'] = df.groupby('ts_code')['future_return'].rolling(days, min_periods=1).std().reset_index(\n",
" level=0, drop=True)\n",
" sharpe_ratio = df['future_return'] * df['future_volatility']\n",
" sharpe_ratio.replace([np.inf, -np.inf], np.nan, inplace=True)\n",
"\n",
" return sharpe_ratio\n",
"\n",
"\n",
"def calculate_score(df, days=5, lambda_param=1.0):\n",
" def calculate_max_drawdown(prices):\n",
" peak = prices.iloc[0] # 初始化峰值\n",
" max_drawdown = 0 # 初始化最大回撤\n",
"\n",
" for price in prices:\n",
" if price > peak:\n",
" peak = price # 更新峰值\n",
" else:\n",
" drawdown = (peak - price) / peak # 计算当前回撤\n",
" max_drawdown = max(max_drawdown, drawdown) # 更新最大回撤\n",
"\n",
" return max_drawdown\n",
"\n",
" def compute_stock_score(stock_df):\n",
" stock_df = stock_df.sort_values(by=['trade_date'])\n",
" future_return = stock_df['future_return']\n",
" volatility = stock_df['close'].pct_change().rolling(days).std().shift(-days)\n",
" max_drawdown = stock_df['close'].rolling(days).apply(calculate_max_drawdown, raw=False).shift(-days)\n",
" score = future_return - lambda_param * max_drawdown\n",
"\n",
" return score\n",
"\n",
" scores = df.groupby('ts_code').apply(lambda x: compute_stock_score(x))\n",
" scores = scores.reset_index(level=0, drop=True)\n",
"\n",
" return scores\n",
"\n",
"\n",
"import gc\n",
"\n",
"gc.collect()"
],
"outputs": [
{
"data": {
"text/plain": [
"0"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 11
},
{
"cell_type": "code",
"id": "47c12bb34062ae7a",
"metadata": {
"ExecuteTime": {
"end_time": "2025-03-19T15:54:02.848116Z",
"start_time": "2025-03-19T15:51:53.055184Z"
}
},
"source": [
"days = 5\n",
"\n",
"import gc\n",
"\n",
"gc.collect()\n",
"\n",
"df['future_return'] = df.groupby('ts_code', group_keys=False)['close'].apply(lambda x: x.shift(-days) / x - 1)\n",
"\n",
"\n",
"# df['future_return'] = (df.groupby('ts_code')['close'].shift(-days) - df.groupby('ts_code')['open'].shift(-1)) / \\\n",
"# df.groupby('ts_code')['open'].shift(-1)\n",
"\n",
"\n",
"def symmetric_log_transform(values):\n",
" return np.sign(values) * np.log1p(np.abs(values))\n",
"\n",
"\n",
"# df['future_score'] = df['future_return']\n",
"df['future_score'] = calculate_score(df, days=days, lambda_param=0.5)\n",
"# df['future_score'] = df.groupby('ts_code')['future_return'].shift(-1).rolling(window=2).mean()\n",
"df['future_score'] = symmetric_log_transform(df['future_score'])\n",
"\n",
"# df['label'] = remove_outliers_label_percentile(df['label'])\n",
"train_data = df[(df['trade_date'] <= '2024-01-01') & (df['trade_date'] >= '2000-01-01')]\n",
"test_data = df[(df['trade_date'] >= '2024-01-01')]\n",
"\n",
"\n",
"def select_pre_zt_stocks_dynamic(\n",
" stock_df,\n",
" vol_spike_multiplier=1.5,\n",
" min_return=0.03, # 最小累计涨幅(例如 3%\n",
" min_main_net_inflow=1e6, # 最小主力资金净流入(例如 100 万元)\n",
" window=30, # 计算历史均值的窗口大小\n",
" signal_days=1 # 异动信号需要连续出现的天数\n",
"):\n",
" # 排序数据\n",
" stock_df = stock_df.sort_values(by=['trade_date', 'ts_code'])\n",
" stock_df = stock_df.groupby('trade_date', group_keys=False).apply(\n",
" lambda x: x.nlargest(512, 'return_20')\n",
" )\n",
"\n",
" return stock_df\n",
"\n",
"\n",
"train_data = select_pre_zt_stocks_dynamic(train_data)\n",
"test_data = select_pre_zt_stocks_dynamic(test_data)\n",
"# # train_data = train_data[train_data['circ_mv'] < 2000000]\n",
"# # test_data = test_data[test_data['circ_mv'] < 2000000]\n",
"#\n",
"# train_data = train_data[train_data['vol'] > 1.5 * train_data['vol_spike']]\n",
"# test_data = test_data[test_data['vol'] > 1.5 * test_data['vol_spike']]\n",
"# train_data = train_data.groupby('trade_date', group_keys=False).apply(\n",
"# lambda x: x.nsmallest(500, 'return_20')\n",
"# )\n",
"#\n",
"# test_data = test_data.groupby('trade_date', group_keys=False).apply(\n",
"# lambda x: x.nsmallest(500, 'return_20')\n",
"# )\n",
"\n",
"# train_data, _ = get_simple_factor(train_data)\n",
"# test_data, _ = get_simple_factor(test_data)\n",
"\n",
"train_data['label'] = train_data.groupby('trade_date', group_keys=False)['future_score'].transform(\n",
" lambda x: pd.qcut(x, q=5, labels=False, duplicates='drop')\n",
")\n",
"test_data['label'] = test_data.groupby('trade_date', group_keys=False)['future_score'].transform(\n",
" lambda x: pd.qcut(x, q=5, labels=False, duplicates='drop')\n",
")\n",
"# train_data['label'] = train_data['label'] == 4\n",
"# test_data['label'] = test_data['label'] == 4\n",
"\n",
"industry_df = industry_df.sort_values(by=['trade_date'])\n",
"index_data = index_data.sort_values(by=['trade_date'])\n",
"\n",
"train_data = train_data.merge(industry_df, on=['cat_l2_code', 'trade_date'], how='left')\n",
"train_data = train_data.merge(index_data, on='trade_date', how='left')\n",
"test_data = test_data.merge(industry_df, on=['cat_l2_code', 'trade_date'], how='left')\n",
"test_data = test_data.merge(index_data, on='trade_date', how='left')\n",
"\n",
"train_data, test_data = train_data.replace([np.inf, -np.inf], np.nan), test_data.replace([np.inf, -np.inf], np.nan)\n",
"\n",
"feature_columns = [col for col in train_data.columns if col not in ['trade_date',\n",
" 'ts_code',\n",
" 'label']]\n",
"feature_columns = [col for col in feature_columns if 'future' not in col]\n",
"feature_columns = [col for col in feature_columns if 'score' not in col]\n",
"feature_columns = [col for col in feature_columns if col not in origin_columns]\n",
"feature_columns = [col for col in feature_columns if not col.startswith('_')]\n",
"print(feature_columns)\n",
"\n",
"feature_columns_new = feature_columns[:]\n",
"train_data, feature_columns_new = create_deviation_within_dates(train_data, feature_columns)\n",
"test_data, feature_columns_new = create_deviation_within_dates(test_data, feature_columns)\n",
"print(f'feature_columns size: {len(feature_columns_new)}')\n",
"\n",
"train_data = train_data.dropna(subset=feature_columns_new)\n",
"train_data = train_data.dropna(subset=['label'])\n",
"train_data = train_data.reset_index(drop=True)\n",
"\n",
"# print(test_data.tail())\n",
"test_data = test_data.dropna(subset=feature_columns_new)\n",
"# test_data = test_data.dropna(subset=['label'])\n",
"test_data = test_data.reset_index(drop=True)\n",
"\n",
"print(len(train_data))\n",
"print(f\"最小日期: {train_data['trade_date'].min().strftime('%Y-%m-%d')}\")\n",
"print(f\"最大日期: {train_data['trade_date'].max().strftime('%Y-%m-%d')}\")\n",
"print(len(test_data))\n",
"print(f\"最小日期: {test_data['trade_date'].min().strftime('%Y-%m-%d')}\")\n",
"print(f\"最大日期: {test_data['trade_date'].max().strftime('%Y-%m-%d')}\")\n",
"\n",
"cat_columns = [col for col in df.columns if col.startswith('cat')]\n",
"for col in cat_columns:\n",
" train_data[col] = train_data[col].astype('category')\n",
" test_data[col] = test_data[col].astype('category')\n",
"\n",
"\n",
"# feature_columns_new.remove('cat_l2_code')"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['vol', 'pct_chg', 'turnover_rate', 'pe_ttm', 'volume_ratio', 'winner_rate', 'cat_l2_code', 'return_skew', 'return_kurtosis', 'volume_change_rate', 'cat_volume_breakout', 'turnover_deviation', 'cat_turnover_spike', 'avg_volume_ratio', 'cat_volume_ratio_breakout', 'vol_spike', 'vol_std_5', 'obv', 'maobv_6', 'return_5', 'return_10', 'return_20', 'std_return_5', 'std_return_15', 'std_return_25', 'std_return_90', 'std_return_90_2', 'act_factor1', 'act_factor2', 'act_factor3', 'act_factor4', 'log(circ_mv)', 'cov', 'delta_cov', 'alpha_22_improved', 'alpha_003', 'alpha_007', 'alpha_013', 'cat_up_limit', 'cat_down_limit', 'up_limit_count_10d', 'down_limit_count_10d', 'consecutive_up_limit', 'vol_break', 'weight_roc5', 'price_cost_divergence', 'smallcap_concentration', 'cost_stability', 'high_cost_break_days', 'liquidity_risk', 'turnover_std', 'mv_volatility', 'volume_growth', 'mv_growth', 'industry_obv', 'industry_return_5', 'industry_return_20', 'industry__ema_5', 'industry__ema_13', 'industry__ema_20', 'industry__ema_60', 'industry_act_factor1', 'industry_act_factor2', 'industry_act_factor3', 'industry_act_factor4', 'industry_act_factor5', 'industry_act_factor6', 'industry_rank_act_factor1', 'industry_rank_act_factor2', 'industry_rank_act_factor3', 'industry_return_5_percentile', 'industry_return_20_percentile', '000852.SH_MACD', '000905.SH_MACD', '399006.SZ_MACD', '000852.SH_MACD_hist', '000905.SH_MACD_hist', '399006.SZ_MACD_hist', '000852.SH_RSI', '000905.SH_RSI', '399006.SZ_RSI', '000852.SH_Signal_line', '000905.SH_Signal_line', '399006.SZ_Signal_line', '000852.SH_amount_change_rate', '000905.SH_amount_change_rate', '399006.SZ_amount_change_rate', '000852.SH_amount_mean', '000905.SH_amount_mean', '399006.SZ_amount_mean', '000852.SH_daily_return', '000905.SH_daily_return', '399006.SZ_daily_return', '000852.SH_up_ratio_20d', '000905.SH_up_ratio_20d', '399006.SZ_up_ratio_20d', '000852.SH_volatility', '000905.SH_volatility', '399006.SZ_volatility', 
'000852.SH_volume_change_rate', '000905.SH_volume_change_rate', '399006.SZ_volume_change_rate']\n",
"feature_columns size: 177\n",
"541133\n",
"最小日期: 2018-06-04\n",
"最大日期: 2023-12-29\n",
"106062\n",
"最小日期: 2024-01-02\n",
"最大日期: 2025-03-13\n"
]
}
],
"execution_count": 40
},
{
"cell_type": "code",
"id": "1c46817a-b5dd-4bec-8bb4-e6e80bfd9d66",
"metadata": {
"ExecuteTime": {
"end_time": "2025-03-19T15:54:03.016065Z",
"start_time": "2025-03-19T15:54:03.011956Z"
}
},
"source": [
"# test_data = df[(df['trade_date'] >= '2024-04-01')]\n",
"# test_data = test_data.merge(industry_df, on=['cat_l2_code', 'trade_date'], how='left')\n",
"# test_data = test_data.merge(index_data, on='trade_date', how='left')"
],
"outputs": [],
"execution_count": 41
},
{
"cell_type": "code",
"id": "da2bb202843d9275",
"metadata": {
"ExecuteTime": {
"end_time": "2025-03-19T15:54:24.393873Z",
"start_time": "2025-03-19T15:54:03.049056Z"
}
},
"source": [
"def remove_highly_correlated_features(df, feature_columns, threshold=0.9):\n",
" numeric_features = df[feature_columns].select_dtypes(include=[np.number]).columns.tolist()\n",
" if not numeric_features:\n",
" raise ValueError(\"No numeric features found in the provided data.\")\n",
"\n",
" corr_matrix = df[numeric_features].corr().abs()\n",
" upper = corr_matrix.where(np.triu(np.ones(corr_matrix.shape), k=1).astype(bool))\n",
" to_drop = [column for column in upper.columns if any(upper[column] > threshold)]\n",
" remaining_features = [col for col in feature_columns if col not in to_drop\n",
" or 'act' in col or 'af' in col]\n",
" return remaining_features\n",
"\n",
"\n",
"feature_columns_new = remove_highly_correlated_features(train_data, feature_columns_new)\n",
"keep_columns = [col for col in train_data.columns if\n",
" col in feature_columns_new or col in ['ts_code', 'trade_date', 'label', 'future_return',\n",
" 'future_score']]\n",
"train_data = train_data[keep_columns]\n",
"# test_data = test_data[keep_columns]\n",
"print(f'feature_columns size: {len(feature_columns_new)}')\n",
"# print(feature_columns_new)\n",
"# 2. 按 trade_date 分组"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"feature_columns size: 154\n"
]
}
],
"execution_count": 42
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-03-19T15:54:24.499139Z",
"start_time": "2025-03-19T15:54:24.490686Z"
}
},
"cell_type": "code",
"source": [
"import torch\n",
"\n",
"num_heads = 4 # Transformer头数\n",
"transformer_input_dim = 64\n",
"num_layers = 2 # Transformer层数\n",
"hidden_dim = 64 # 隐藏层维度\n",
"num_epochs = 10 # 训练轮数\n",
"learning_rate = 0.001 # 学习率\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"validation_days = 180\n",
"\n",
"print(device)"
],
"id": "8113bba62de693cc",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"cuda\n"
]
}
],
"execution_count": 43
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-03-19T15:57:09.407047Z",
"start_time": "2025-03-19T15:54:24.596522Z"
}
},
"cell_type": "code",
"source": [
"from tqdm import tqdm\n",
"import torch.optim as optim\n",
"from sklearn.preprocessing import LabelEncoder, StandardScaler\n",
"from sklearn.metrics import accuracy_score\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"import torch\n",
"import torch.nn as nn\n",
"\n",
"\n",
"class StockPredictionModel(nn.Module):\n",
" def __init__(self, input_dim, hidden_dim, num_classes, num_heads, num_layers, transformer_input_dim):\n",
" super(StockPredictionModel, self).__init__()\n",
"\n",
" # 输入全连接层:将原始输入特征映射到 Transformer 的输入维度\n",
" self.fc_input = nn.Sequential(\n",
" nn.Linear(input_dim, hidden_dim),\n",
" nn.ReLU(),\n",
" nn.Dropout(0.3),\n",
" nn.Linear(hidden_dim, hidden_dim),\n",
" nn.ReLU(),\n",
" nn.Dropout(0.3),\n",
" nn.Linear(hidden_dim, hidden_dim),\n",
" )\n",
"\n",
" # Transformer Encoder\n",
" self.transformer = nn.Transformer(\n",
" d_model=transformer_input_dim,\n",
" nhead=num_heads,\n",
" num_encoder_layers=num_layers,\n",
" batch_first=True\n",
" )\n",
"\n",
" self.dropout = nn.Dropout(0.3)\n",
"\n",
" # 全连接层\n",
" self.fc_output = nn.Sequential(\n",
" nn.Linear(transformer_input_dim, hidden_dim),\n",
" nn.ReLU(),\n",
" nn.Dropout(0.3),\n",
" nn.Linear(hidden_dim, hidden_dim),\n",
" nn.ReLU(),\n",
" nn.Dropout(0.3),\n",
" nn.Linear(hidden_dim, num_classes),\n",
" )\n",
"\n",
" def forward(self, x):\n",
" # x: (batch_size, num_stocks, input_dim)\n",
"\n",
" # 输入全连接层处理\n",
" x = (self.fc_input(x))\n",
"\n",
" # Transformer处理\n",
" transformer_out = self.transformer(x, x) # Self-attention\n",
"\n",
" # 全连接层处理\n",
" out = self.fc_output(transformer_out)\n",
"\n",
" return out\n",
"\n",
"\n",
"# 数据预处理函数\n",
"def preprocess_data(data, cat_columns, feature_columns, label_encoders=None, scaler=None):\n",
" \"\"\"\n",
" 预处理数据:\n",
" - 对分类特征进行 LabelEncoder 编码。\n",
" - 对数值特征进行标准化。\n",
" - 将数据按天分组并转换为 NumPy 数组。\n",
" \"\"\"\n",
" # 分离分类特征和数值特征\n",
" numeric_cols = [col for col in feature_columns if col not in cat_columns]\n",
"\n",
" # 初始化编码器和标准化器(如果是训练阶段)\n",
" if label_encoders is None:\n",
" label_encoders = {col: LabelEncoder() for col in cat_columns}\n",
" X_cat = np.array([label_encoders[col].fit_transform(data[col]) for col in cat_columns]).T\n",
" else:\n",
" X_cat = np.array([label_encoders[col].transform(data[col]) for col in cat_columns]).T\n",
" if scaler is None:\n",
" scaler = StandardScaler()\n",
" X_num = scaler.fit_transform(data[numeric_cols])\n",
" else:\n",
" X_num = scaler.transform(data[numeric_cols])\n",
"\n",
"\n",
" # 处理分类特征\n",
"\n",
" # 处理数值特征\n",
"\n",
" # 合并特征\n",
" X_processed = np.hstack([X_num, X_cat])\n",
"\n",
" # 将处理后的数据与日期对齐\n",
" processed_data = pd.DataFrame(X_processed, columns=numeric_cols + cat_columns)\n",
" processed_data['trade_date'] = data['trade_date'].values # 保留日期列\n",
" processed_data['label'] = data['label'].values # 保留标签列\n",
"\n",
" # 按天分组\n",
" grouped_X = []\n",
" grouped_y = []\n",
" for date, group in processed_data.groupby('trade_date'):\n",
" X_day = group[numeric_cols + cat_columns].values # 使用预处理后的特征\n",
" y_day = group['label'].values # 使用原始标签\n",
" grouped_X.append(X_day)\n",
" grouped_y.append(y_day)\n",
"\n",
" # 获取每个分类特征的最大类别数(用于嵌入层)\n",
" cat_dims = [len(label_encoders[col].classes_) for col in cat_columns]\n",
"\n",
" return grouped_X, grouped_y, label_encoders, scaler, cat_dims\n",
"\n",
"\n",
"# 训练函数(逐天训练)\n",
"def train_by_day(model, X_train, y_train, criterion, optimizer, device):\n",
" model.train()\n",
" total_loss = 0\n",
" all_preds, all_labels = [], []\n",
"\n",
" for day_idx in range(len(X_train)):\n",
" # 取出当天的数据\n",
" X_batch = torch.tensor(X_train[day_idx], dtype=torch.float32).to(device)\n",
" y_batch = torch.tensor(y_train[day_idx], dtype=torch.long).to(device)\n",
"\n",
" # 添加 batch 维度(从 2D 变为 3D\n",
" if X_batch.dim() == 2:\n",
" X_batch = X_batch.unsqueeze(0) # 形状变为 (1, num_stocks, num_features)\n",
"\n",
" # 前向传播\n",
" outputs = model(X_batch) # (batch_size, num_stocks, num_classes)\n",
" loss = criterion(outputs.view(-1, outputs.size(-1)), y_batch.view(-1))\n",
"\n",
" # 反向传播\n",
" optimizer.zero_grad()\n",
" loss.backward()\n",
" optimizer.step()\n",
"\n",
" # 统计损失和预测结果\n",
" total_loss += loss.item()\n",
" preds = torch.argmax(outputs, dim=-1).cpu().numpy()\n",
" labels = y_batch.cpu().numpy()\n",
" all_preds.extend(preds.flatten())\n",
" all_labels.extend(labels.flatten())\n",
"\n",
" avg_loss = total_loss / len(X_train)\n",
" accuracy = accuracy_score(all_labels, all_preds)\n",
" return avg_loss, accuracy\n",
"\n",
"\n",
"# 验证/测试函数(逐天验证)\n",
"def evaluate_by_day(model, X_val, y_val, criterion, device):\n",
" model.eval()\n",
" total_loss = 0\n",
" all_preds, all_labels = [], []\n",
"\n",
" with torch.no_grad():\n",
" for day_idx in range(len(X_val)):\n",
" # 取出当天的数据\n",
" X_batch = torch.tensor(X_val[day_idx], dtype=torch.float32).to(device)\n",
" y_batch = torch.tensor(y_val[day_idx], dtype=torch.long).to(device)\n",
" # 添加 batch 维度(从 2D 变为 3D\n",
" if X_batch.dim() == 2:\n",
" X_batch = X_batch.unsqueeze(0) # 形状变为 (1, num_stocks, num_features)\n",
"\n",
" # 前向传播\n",
" outputs = model(X_batch)\n",
" loss = criterion(outputs.view(-1, outputs.size(-1)), y_batch.view(-1))\n",
"\n",
" # 统计损失和预测结果\n",
" total_loss += loss.item()\n",
" preds = torch.argmax(outputs, dim=-1).cpu().numpy()\n",
" labels = y_batch.cpu().numpy()\n",
" all_preds.extend(preds.flatten())\n",
" all_labels.extend(labels.flatten())\n",
"\n",
" avg_loss = total_loss / len(X_val)\n",
" accuracy = accuracy_score(all_labels, all_preds)\n",
" return avg_loss, accuracy\n",
"\n",
"\n",
"# 主程序\n",
"gc.collect()\n",
"# 数据切分\n",
"all_dates = train_data['trade_date'].unique() # 获取所有唯一的 trade_date\n",
"split_date = all_dates[-validation_days] # 划分点为倒数第 validation_days 天\n",
"train_data_split = train_data[train_data['trade_date'] < split_date] # 训练集\n",
"val_data_split = train_data[train_data['trade_date'] >= split_date] # 验证集\n",
"\n",
"# 找到分类特征列\n",
"cat_columns = [col for col in train_data.columns if col.startswith(\"cat\")]\n",
"\n",
"# 预处理数据\n",
"X_train_processed, y_train_processed, label_encoders, scaler, cat_dims = preprocess_data(train_data_split, cat_columns,\n",
" feature_columns_new)\n",
"X_val_processed, y_val_processed, _, _, _ = preprocess_data(val_data_split, cat_columns, feature_columns_new,\n",
" label_encoders, scaler)\n",
"X_test_processed, y_test_processed, _, _, _ = preprocess_data(test_data, cat_columns, feature_columns_new,\n",
" label_encoders, scaler)\n",
"\n",
"# 超参数\n",
"input_dim = X_train_processed[0].shape[1] # 直接从数据获取输入维度\n",
"print(input_dim)\n",
"num_classes = len(train_data[\"label\"].unique())\n",
"\n",
"\n",
"# 模型、损失函数和优化器\n",
"model = StockPredictionModel(\n",
" input_dim=input_dim,\n",
" hidden_dim=hidden_dim,\n",
" num_classes=num_classes,\n",
" num_heads=num_heads,\n",
" num_layers=num_layers,\n",
" transformer_input_dim=transformer_input_dim,\n",
").to(device)\n",
"criterion = nn.CrossEntropyLoss()\n",
"optimizer = optim.Adam(model.parameters(), lr=learning_rate)\n",
"\n",
"# 训练和验证\n",
"best_val_accuracy = 0\n",
"for epoch in tqdm(range(num_epochs)):\n",
" train_loss, train_acc = train_by_day(model, X_train_processed, y_train_processed, criterion, optimizer, device)\n",
" val_loss, val_acc = evaluate_by_day(model, X_val_processed, y_val_processed, criterion, device)\n",
"\n",
" print(f\"Epoch {epoch + 1}/{num_epochs}: \"\n",
" f\"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, \"\n",
" f\"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}\")\n",
"\n",
" # 保存最佳模型\n",
" # if val_acc > best_val_accuracy:\n",
" # best_val_accuracy = val_acc\n",
" # torch.save(model.state_dict(), \"best_model.pth\")\n",
"\n",
"# 测试\n",
"# model.load_state_dict(torch.load(\"best_model.pth\"))\n",
"# test_loss, test_acc = evaluate_by_day(model, X_test_processed, y_test_processed, criterion, device)\n",
"# print(f\"Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}\")"
],
"id": "beeb098799ecfa6a",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"154\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" 10%|█ | 1/10 [00:26<03:58, 26.47s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/10: Train Loss: 1.6094, Train Acc: 0.2037, Val Loss: 1.6091, Val Acc: 0.2088\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" 20%|██ | 2/10 [00:53<03:33, 26.72s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 2/10: Train Loss: 1.6091, Train Acc: 0.2046, Val Loss: 1.6091, Val Acc: 0.2088\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" 30%|███ | 3/10 [01:20<03:07, 26.80s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 3/10: Train Loss: 1.6091, Train Acc: 0.2038, Val Loss: 1.6091, Val Acc: 0.2088\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" 40%|████ | 4/10 [01:44<02:35, 25.92s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 4/10: Train Loss: 1.6091, Train Acc: 0.2043, Val Loss: 1.6091, Val Acc: 0.2088\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" 50%|█████ | 5/10 [02:08<02:05, 25.18s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 5/10: Train Loss: 1.6091, Train Acc: 0.2046, Val Loss: 1.6091, Val Acc: 0.2088\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" 60%|██████ | 6/10 [02:35<01:42, 25.72s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 6/10: Train Loss: 1.6091, Train Acc: 0.2042, Val Loss: 1.6091, Val Acc: 0.2088\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" 60%|██████ | 6/10 [02:41<01:47, 26.93s/it]\n"
]
},
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001B[1;31m---------------------------------------------------------------------------\u001B[0m",
"\u001B[1;31mKeyboardInterrupt\u001B[0m Traceback (most recent call last)",
"Cell \u001B[1;32mIn[44], line 223\u001B[0m\n\u001B[0;32m 221\u001B[0m best_val_accuracy \u001B[38;5;241m=\u001B[39m \u001B[38;5;241m0\u001B[39m\n\u001B[0;32m 222\u001B[0m \u001B[38;5;28;01mfor\u001B[39;00m epoch \u001B[38;5;129;01min\u001B[39;00m tqdm(\u001B[38;5;28mrange\u001B[39m(num_epochs)):\n\u001B[1;32m--> 223\u001B[0m train_loss, train_acc \u001B[38;5;241m=\u001B[39m train_by_day(model, X_train_processed, y_train_processed, criterion, optimizer, device)\n\u001B[0;32m 224\u001B[0m val_loss, val_acc \u001B[38;5;241m=\u001B[39m evaluate_by_day(model, X_val_processed, y_val_processed, criterion, device)\n\u001B[0;32m 226\u001B[0m \u001B[38;5;28mprint\u001B[39m(\u001B[38;5;124mf\u001B[39m\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mEpoch \u001B[39m\u001B[38;5;132;01m{\u001B[39;00mepoch\u001B[38;5;250m \u001B[39m\u001B[38;5;241m+\u001B[39m\u001B[38;5;250m \u001B[39m\u001B[38;5;241m1\u001B[39m\u001B[38;5;132;01m}\u001B[39;00m\u001B[38;5;124m/\u001B[39m\u001B[38;5;132;01m{\u001B[39;00mnum_epochs\u001B[38;5;132;01m}\u001B[39;00m\u001B[38;5;124m: \u001B[39m\u001B[38;5;124m\"\u001B[39m\n\u001B[0;32m 227\u001B[0m \u001B[38;5;124mf\u001B[39m\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mTrain Loss: \u001B[39m\u001B[38;5;132;01m{\u001B[39;00mtrain_loss\u001B[38;5;132;01m:\u001B[39;00m\u001B[38;5;124m.4f\u001B[39m\u001B[38;5;132;01m}\u001B[39;00m\u001B[38;5;124m, Train Acc: \u001B[39m\u001B[38;5;132;01m{\u001B[39;00mtrain_acc\u001B[38;5;132;01m:\u001B[39;00m\u001B[38;5;124m.4f\u001B[39m\u001B[38;5;132;01m}\u001B[39;00m\u001B[38;5;124m, \u001B[39m\u001B[38;5;124m\"\u001B[39m\n\u001B[0;32m 228\u001B[0m \u001B[38;5;124mf\u001B[39m\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mVal Loss: \u001B[39m\u001B[38;5;132;01m{\u001B[39;00mval_loss\u001B[38;5;132;01m:\u001B[39;00m\u001B[38;5;124m.4f\u001B[39m\u001B[38;5;132;01m}\u001B[39;00m\u001B[38;5;124m, Val Acc: 
\u001B[39m\u001B[38;5;132;01m{\u001B[39;00mval_acc\u001B[38;5;132;01m:\u001B[39;00m\u001B[38;5;124m.4f\u001B[39m\u001B[38;5;132;01m}\u001B[39;00m\u001B[38;5;124m\"\u001B[39m)\n",
"Cell \u001B[1;32mIn[44], line 132\u001B[0m, in \u001B[0;36mtrain_by_day\u001B[1;34m(model, X_train, y_train, criterion, optimizer, device)\u001B[0m\n\u001B[0;32m 129\u001B[0m X_batch \u001B[38;5;241m=\u001B[39m X_batch\u001B[38;5;241m.\u001B[39munsqueeze(\u001B[38;5;241m0\u001B[39m) \u001B[38;5;66;03m# 形状变为 (1, num_stocks, num_features)\u001B[39;00m\n\u001B[0;32m 131\u001B[0m \u001B[38;5;66;03m# 前向传播\u001B[39;00m\n\u001B[1;32m--> 132\u001B[0m outputs \u001B[38;5;241m=\u001B[39m model(X_batch) \u001B[38;5;66;03m# (batch_size, num_stocks, num_classes)\u001B[39;00m\n\u001B[0;32m 133\u001B[0m loss \u001B[38;5;241m=\u001B[39m criterion(outputs\u001B[38;5;241m.\u001B[39mview(\u001B[38;5;241m-\u001B[39m\u001B[38;5;241m1\u001B[39m, outputs\u001B[38;5;241m.\u001B[39msize(\u001B[38;5;241m-\u001B[39m\u001B[38;5;241m1\u001B[39m)), y_batch\u001B[38;5;241m.\u001B[39mview(\u001B[38;5;241m-\u001B[39m\u001B[38;5;241m1\u001B[39m))\n\u001B[0;32m 135\u001B[0m \u001B[38;5;66;03m# 反向传播\u001B[39;00m\n",
"File \u001B[1;32mE:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1739\u001B[0m, in \u001B[0;36mModule._wrapped_call_impl\u001B[1;34m(self, *args, **kwargs)\u001B[0m\n\u001B[0;32m 1737\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_compiled_call_impl(\u001B[38;5;241m*\u001B[39margs, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs) \u001B[38;5;66;03m# type: ignore[misc]\u001B[39;00m\n\u001B[0;32m 1738\u001B[0m \u001B[38;5;28;01melse\u001B[39;00m:\n\u001B[1;32m-> 1739\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_call_impl(\u001B[38;5;241m*\u001B[39margs, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs)\n",
"File \u001B[1;32mE:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1750\u001B[0m, in \u001B[0;36mModule._call_impl\u001B[1;34m(self, *args, **kwargs)\u001B[0m\n\u001B[0;32m 1745\u001B[0m \u001B[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001B[39;00m\n\u001B[0;32m 1746\u001B[0m \u001B[38;5;66;03m# this function, and just call forward.\u001B[39;00m\n\u001B[0;32m 1747\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m (\u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_backward_hooks \u001B[38;5;129;01mor\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_backward_pre_hooks \u001B[38;5;129;01mor\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_forward_hooks \u001B[38;5;129;01mor\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_forward_pre_hooks\n\u001B[0;32m 1748\u001B[0m \u001B[38;5;129;01mor\u001B[39;00m _global_backward_pre_hooks \u001B[38;5;129;01mor\u001B[39;00m _global_backward_hooks\n\u001B[0;32m 1749\u001B[0m \u001B[38;5;129;01mor\u001B[39;00m _global_forward_hooks \u001B[38;5;129;01mor\u001B[39;00m _global_forward_pre_hooks):\n\u001B[1;32m-> 1750\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m forward_call(\u001B[38;5;241m*\u001B[39margs, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs)\n\u001B[0;32m 1752\u001B[0m result \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;01mNone\u001B[39;00m\n\u001B[0;32m 1753\u001B[0m called_always_called_hooks \u001B[38;5;241m=\u001B[39m \u001B[38;5;28mset\u001B[39m()\n",
"Cell \u001B[1;32mIn[44], line 54\u001B[0m, in \u001B[0;36mStockPredictionModel.forward\u001B[1;34m(self, x)\u001B[0m\n\u001B[0;32m 50\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mforward\u001B[39m(\u001B[38;5;28mself\u001B[39m, x):\n\u001B[0;32m 51\u001B[0m \u001B[38;5;66;03m# x: (batch_size, num_stocks, input_dim)\u001B[39;00m\n\u001B[0;32m 52\u001B[0m \n\u001B[0;32m 53\u001B[0m \u001B[38;5;66;03m# 输入全连接层处理\u001B[39;00m\n\u001B[1;32m---> 54\u001B[0m x \u001B[38;5;241m=\u001B[39m (\u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mfc_input(x))\n\u001B[0;32m 56\u001B[0m \u001B[38;5;66;03m# Transformer处理\u001B[39;00m\n\u001B[0;32m 57\u001B[0m transformer_out \u001B[38;5;241m=\u001B[39m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mtransformer(x, x) \u001B[38;5;66;03m# Self-attention\u001B[39;00m\n",
"File \u001B[1;32mE:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1739\u001B[0m, in \u001B[0;36mModule._wrapped_call_impl\u001B[1;34m(self, *args, **kwargs)\u001B[0m\n\u001B[0;32m 1737\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_compiled_call_impl(\u001B[38;5;241m*\u001B[39margs, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs) \u001B[38;5;66;03m# type: ignore[misc]\u001B[39;00m\n\u001B[0;32m 1738\u001B[0m \u001B[38;5;28;01melse\u001B[39;00m:\n\u001B[1;32m-> 1739\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_call_impl(\u001B[38;5;241m*\u001B[39margs, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs)\n",
"File \u001B[1;32mE:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1750\u001B[0m, in \u001B[0;36mModule._call_impl\u001B[1;34m(self, *args, **kwargs)\u001B[0m\n\u001B[0;32m 1745\u001B[0m \u001B[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001B[39;00m\n\u001B[0;32m 1746\u001B[0m \u001B[38;5;66;03m# this function, and just call forward.\u001B[39;00m\n\u001B[0;32m 1747\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m (\u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_backward_hooks \u001B[38;5;129;01mor\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_backward_pre_hooks \u001B[38;5;129;01mor\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_forward_hooks \u001B[38;5;129;01mor\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_forward_pre_hooks\n\u001B[0;32m 1748\u001B[0m \u001B[38;5;129;01mor\u001B[39;00m _global_backward_pre_hooks \u001B[38;5;129;01mor\u001B[39;00m _global_backward_hooks\n\u001B[0;32m 1749\u001B[0m \u001B[38;5;129;01mor\u001B[39;00m _global_forward_hooks \u001B[38;5;129;01mor\u001B[39;00m _global_forward_pre_hooks):\n\u001B[1;32m-> 1750\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m forward_call(\u001B[38;5;241m*\u001B[39margs, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs)\n\u001B[0;32m 1752\u001B[0m result \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;01mNone\u001B[39;00m\n\u001B[0;32m 1753\u001B[0m called_always_called_hooks \u001B[38;5;241m=\u001B[39m \u001B[38;5;28mset\u001B[39m()\n",
"File \u001B[1;32mE:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\torch\\nn\\modules\\container.py:250\u001B[0m, in \u001B[0;36mSequential.forward\u001B[1;34m(self, input)\u001B[0m\n\u001B[0;32m 248\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mforward\u001B[39m(\u001B[38;5;28mself\u001B[39m, \u001B[38;5;28minput\u001B[39m):\n\u001B[0;32m 249\u001B[0m \u001B[38;5;28;01mfor\u001B[39;00m module \u001B[38;5;129;01min\u001B[39;00m \u001B[38;5;28mself\u001B[39m:\n\u001B[1;32m--> 250\u001B[0m \u001B[38;5;28minput\u001B[39m \u001B[38;5;241m=\u001B[39m module(\u001B[38;5;28minput\u001B[39m)\n\u001B[0;32m 251\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28minput\u001B[39m\n",
"File \u001B[1;32mE:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1739\u001B[0m, in \u001B[0;36mModule._wrapped_call_impl\u001B[1;34m(self, *args, **kwargs)\u001B[0m\n\u001B[0;32m 1737\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_compiled_call_impl(\u001B[38;5;241m*\u001B[39margs, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs) \u001B[38;5;66;03m# type: ignore[misc]\u001B[39;00m\n\u001B[0;32m 1738\u001B[0m \u001B[38;5;28;01melse\u001B[39;00m:\n\u001B[1;32m-> 1739\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_call_impl(\u001B[38;5;241m*\u001B[39margs, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs)\n",
"File \u001B[1;32mE:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1750\u001B[0m, in \u001B[0;36mModule._call_impl\u001B[1;34m(self, *args, **kwargs)\u001B[0m\n\u001B[0;32m 1745\u001B[0m \u001B[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001B[39;00m\n\u001B[0;32m 1746\u001B[0m \u001B[38;5;66;03m# this function, and just call forward.\u001B[39;00m\n\u001B[0;32m 1747\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m (\u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_backward_hooks \u001B[38;5;129;01mor\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_backward_pre_hooks \u001B[38;5;129;01mor\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_forward_hooks \u001B[38;5;129;01mor\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_forward_pre_hooks\n\u001B[0;32m 1748\u001B[0m \u001B[38;5;129;01mor\u001B[39;00m _global_backward_pre_hooks \u001B[38;5;129;01mor\u001B[39;00m _global_backward_hooks\n\u001B[0;32m 1749\u001B[0m \u001B[38;5;129;01mor\u001B[39;00m _global_forward_hooks \u001B[38;5;129;01mor\u001B[39;00m _global_forward_pre_hooks):\n\u001B[1;32m-> 1750\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m forward_call(\u001B[38;5;241m*\u001B[39margs, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs)\n\u001B[0;32m 1752\u001B[0m result \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;01mNone\u001B[39;00m\n\u001B[0;32m 1753\u001B[0m called_always_called_hooks \u001B[38;5;241m=\u001B[39m \u001B[38;5;28mset\u001B[39m()\n",
"File \u001B[1;32mE:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\torch\\nn\\modules\\linear.py:125\u001B[0m, in \u001B[0;36mLinear.forward\u001B[1;34m(self, input)\u001B[0m\n\u001B[0;32m 124\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mforward\u001B[39m(\u001B[38;5;28mself\u001B[39m, \u001B[38;5;28minput\u001B[39m: Tensor) \u001B[38;5;241m-\u001B[39m\u001B[38;5;241m>\u001B[39m Tensor:\n\u001B[1;32m--> 125\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m F\u001B[38;5;241m.\u001B[39mlinear(\u001B[38;5;28minput\u001B[39m, \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mweight, \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mbias)\n",
"\u001B[1;31mKeyboardInterrupt\u001B[0m: "
]
}
],
"execution_count": 44
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-03-19T15:57:09.420051500Z",
"start_time": "2025-03-19T15:47:37.553874Z"
}
},
"cell_type": "code",
"source": [
"def preprocess_test_data(data, cat_columns, feature_columns, label_encoders, scaler):\n",
" \"\"\"对测试数据进行预处理\"\"\"\n",
" numeric_cols = [col for col in feature_columns if col not in cat_columns]\n",
"\n",
" # 处理分类特征\n",
" X_cat = np.array([label_encoders[col].transform(data[col]) for col in cat_columns]).T\n",
"\n",
" # 处理数值特征\n",
" X_num = scaler.transform(data[numeric_cols])\n",
"\n",
" # 合并特征\n",
" X_processed = np.hstack([X_num, X_cat])\n",
"\n",
" # 保留原始 ts_code 和 trade_date\n",
" processed_data = data[['ts_code', 'trade_date']].copy()\n",
" processed_data['features'] = list(X_processed)\n",
"\n",
" return processed_data\n",
"\n",
"\n",
"def predict_and_save(model, test_data, cat_columns, feature_columns, label_encoders, scaler, output_path):\n",
" # 预处理测试数据\n",
" processed_data = preprocess_test_data(test_data, cat_columns, feature_columns, label_encoders, scaler)\n",
"\n",
" # 按天分组\n",
" grouped_data = processed_data.groupby('trade_date')\n",
"\n",
" # 存储预测结果\n",
" results = []\n",
"\n",
" model.eval()\n",
" with torch.no_grad():\n",
" for date, group in grouped_data:\n",
" # 准备输入数据\n",
" X_batch = np.stack(group['features'].values) # 形状 (num_stocks, num_features)\n",
"\n",
" X_batch = torch.tensor(X_batch, dtype=torch.float32).to(device)\n",
"\n",
" if X_batch.dim() == 2:\n",
" X_batch = X_batch.unsqueeze(0) # 形状变为 (1, num_stocks, num_features)\n",
"\n",
" # 预测\n",
" outputs = model(X_batch)\n",
" probs = torch.softmax(outputs, dim=-1).cpu().numpy() # 转换为概率\n",
"\n",
" # 取最后一列作为 score假设最后一列是目标类别\n",
" scores = probs[0, :, -1] # 形状 (num_stocks,)\n",
"\n",
" # 收集结果\n",
" for i in range(len(group)):\n",
" results.append({\n",
" 'trade_date': date,\n",
" 'score': scores[i],\n",
" 'ts_code': group['ts_code'].values[i],\n",
" })\n",
" # 保存为 DataFrame\n",
" result_df = pd.DataFrame(results)\n",
" result_df = result_df.loc[result_df.groupby('trade_date')['score'].idxmax()]\n",
" result_df.to_csv(output_path, index=False)\n",
" print(f\"预测结果已保存至 {output_path}\")\n",
"\n",
"\n",
"predict_and_save(\n",
" model=model,\n",
" test_data=test_data,\n",
" cat_columns=cat_columns,\n",
" feature_columns=feature_columns_new,\n",
" label_encoders=label_encoders,\n",
" scaler=scaler,\n",
" output_path=\"predictions_test.tsv\"\n",
")"
],
"id": "5d1522a7538db91b",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"预测结果已保存至 predictions_test.tsv\n"
]
}
],
"execution_count": 38
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.11"
}
},
"nbformat": 4,
"nbformat_minor": 5
}