NewStock/main/train/AnalyzeData.ipynb
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "79a7758178bafdd3",
"metadata": {
"ExecuteTime": {
"end_time": "2025-04-09T16:39:30.609224Z",
"start_time": "2025-04-09T16:39:29.929606Z"
},
"jupyter": {
"source_hidden": true
}
},
"outputs": [],
"source": [
"# %load_ext autoreload\n",
"# %autoreload 2\n",
"\n",
"import pandas as pd\n",
"import warnings\n",
"\n",
"warnings.filterwarnings(\"ignore\")\n",
"\n",
"pd.set_option('display.max_columns', None)\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "2c66084a979c42dd",
"metadata": {
"ExecuteTime": {
"end_time": "2025-04-09T16:39:30.914968Z",
"start_time": "2025-04-09T16:39:30.858395Z"
},
"jupyter": {
"source_hidden": true
}
},
"outputs": [],
"source": [
"\n",
"import talib\n",
"\n",
"\n",
"def get_rolling_factor(df):\n",
" old_columns = df.columns.tolist()[:]\n",
"\n",
" # 按股票和日期排序(如果尚未排序)\n",
" df = df.sort_values(by=['ts_code', 'trade_date'])\n",
"\n",
" grouped = df.groupby('ts_code', group_keys=False)\n",
"\n",
" window = 20\n",
" df['_is_positive'] = (df['pct_chg'] > 0).astype(int)\n",
" df['_is_negative'] = (df['pct_chg'] < 0).astype(int)\n",
" df['cat_is_positive'] = (df['pct_chg'] > 0).astype(int)\n",
"\n",
" # 分离正负收益率 (用于计算各自的均值和平方均值)\n",
" # 注意:这里我们保留原始收益率用于计算,而不是 clip 到 0\n",
" df['_pos_returns'] = df['pct_chg'].where(df['pct_chg'] > 0, 0) # 非正设为0便于求和\n",
" df['_neg_returns'] = df['pct_chg'].where(df['pct_chg'] < 0, 0) # 非负设为0便于求和\n",
"\n",
" # 计算收益率的平方 (用于计算 E[X^2])\n",
" df['_pos_returns_sq'] = np.square(df['_pos_returns'])\n",
" df['_neg_returns_sq'] = np.square(df['_neg_returns']) # 平方后负数变正\n",
"\n",
" # 4. 计算滚动统计量 (使用内置函数,速度较快)\n",
" # 计算正收益日的统计量\n",
" rolling_pos_count = grouped['_is_positive'].rolling(window, min_periods=max(1, window // 2)).sum()\n",
" rolling_pos_sum = grouped['_pos_returns'].rolling(window, min_periods=max(1, window // 2)).sum()\n",
" rolling_pos_sum_sq = grouped['_pos_returns_sq'].rolling(window, min_periods=max(1, window // 2)).sum()\n",
"\n",
" # 计算负收益日的统计量\n",
" rolling_neg_count = grouped['_is_negative'].rolling(window, min_periods=max(1, window // 2)).sum()\n",
" rolling_neg_sum = grouped['_neg_returns'].rolling(window, min_periods=max(1, window // 2)).sum()\n",
" rolling_neg_sum_sq = grouped['_neg_returns_sq'].rolling(window, min_periods=max(1, window // 2)).sum()\n",
"\n",
" # 5. 计算方差和标准差\n",
" pos_mean_sq = rolling_pos_sum_sq / rolling_pos_count\n",
" pos_mean = rolling_pos_sum / rolling_pos_count\n",
" pos_var = pos_mean_sq - np.square(pos_mean)\n",
" pos_var = pos_var.where(rolling_pos_count >= 2, np.nan).clip(lower=0)\n",
" upside_vol = np.sqrt(pos_var)\n",
"\n",
" neg_mean_sq = rolling_neg_sum_sq / rolling_neg_count\n",
" neg_mean = rolling_neg_sum / rolling_neg_count # 注意 neg_mean 是负数\n",
" neg_var = neg_mean_sq - np.square(neg_mean)\n",
" neg_var = neg_var.where(rolling_neg_count >= 2, np.nan).clip(lower=0)\n",
" downside_vol = np.sqrt(neg_var)\n",
"\n",
" # rolling 操作后结果带有 MultiIndex需要去除股票代码层级以便合并\n",
" df['upside_vol'] = upside_vol.reset_index(level=0, drop=True)\n",
" df['downside_vol'] = downside_vol.reset_index(level=0, drop=True)\n",
"\n",
" df['vol_ratio'] = df['upside_vol'] / df['downside_vol']\n",
" df['vol_ratio'] = df['vol_ratio'].replace([np.inf, -np.inf], np.nan).fillna(0) # 或 fillna(np.nan)\n",
"\n",
" df['return_skew'] = grouped['pct_chg'].rolling(window=5).skew().reset_index(0, drop=True)\n",
" df['return_kurtosis'] = grouped['pct_chg'].rolling(window=5).kurt().reset_index(0, drop=True)\n",
"\n",
" # 因子 1短期成交量变化率\n",
" df['volume_change_rate'] = (\n",
" grouped['vol'].rolling(window=2).mean() /\n",
" grouped['vol'].rolling(window=10).mean() - 1\n",
" ).reset_index(level=0, drop=True) # 确保索引对齐\n",
"\n",
" # 因子 2成交量突破信号\n",
" max_volume = grouped['vol'].rolling(window=5).max().reset_index(level=0, drop=True) # 确保索引对齐\n",
" df['cat_volume_breakout'] = (df['vol'] > max_volume)\n",
"\n",
" # 因子 3换手率均线偏离度\n",
" mean_turnover = grouped['turnover_rate'].rolling(window=3).mean().reset_index(level=0, drop=True)\n",
" std_turnover = grouped['turnover_rate'].rolling(window=3).std().reset_index(level=0, drop=True)\n",
" df['turnover_deviation'] = (df['turnover_rate'] - mean_turnover) / std_turnover\n",
"\n",
" # 因子 4换手率激增信号\n",
" df['cat_turnover_spike'] = (df['turnover_rate'] > mean_turnover + 2 * std_turnover)\n",
"\n",
" # 因子 5量比均值\n",
" df['avg_volume_ratio'] = grouped['volume_ratio'].rolling(window=3).mean().reset_index(level=0, drop=True)\n",
"\n",
" # 因子 6量比突破信号\n",
" max_volume_ratio = grouped['volume_ratio'].rolling(window=5).max().reset_index(level=0, drop=True)\n",
" df['cat_volume_ratio_breakout'] = (df['volume_ratio'] > max_volume_ratio)\n",
"\n",
" df['vol_spike'] = grouped.apply(\n",
" lambda x: pd.Series(x['vol'].rolling(20).mean(), index=x.index)\n",
" )\n",
" df['vol_std_5'] = grouped['vol'].pct_change().rolling(window=5).std()\n",
"\n",
" # 计算 ATR\n",
" df['atr_14'] = grouped.apply(\n",
" lambda x: pd.Series(talib.ATR(x['high'].values, x['low'].values, x['close'].values, timeperiod=14),\n",
" index=x.index)\n",
" )\n",
" df['atr_6'] = grouped.apply(\n",
" lambda x: pd.Series(talib.ATR(x['high'].values, x['low'].values, x['close'].values, timeperiod=6),\n",
" index=x.index)\n",
" )\n",
"\n",
" # 计算 OBV 及其均线\n",
" df['obv'] = grouped.apply(\n",
" lambda x: pd.Series(talib.OBV(x['close'].values, x['vol'].values), index=x.index)\n",
" )\n",
" print(df.columns)\n",
" df['maobv_6'] = grouped.apply(\n",
" lambda x: pd.Series(talib.SMA(x['obv'].values, timeperiod=6), index=x.index)\n",
" )\n",
"\n",
" df['rsi_3'] = grouped.apply(\n",
" lambda x: pd.Series(talib.RSI(x['close'].values, timeperiod=3), index=x.index)\n",
" )\n",
" # df['rsi_6'] = grouped.apply(\n",
" # lambda x: pd.Series(talib.RSI(x['close'].values, timeperiod=6), index=x.index)\n",
" # )\n",
" # df['rsi_9'] = grouped.apply(\n",
" # lambda x: pd.Series(talib.RSI(x['close'].values, timeperiod=9), index=x.index)\n",
" # )\n",
"\n",
" # 计算 return_10 和 return_20\n",
" df['return_5'] = grouped['close'].apply(lambda x: x / x.shift(5) - 1)\n",
" # df['return_10'] = grouped['close'].apply(lambda x: x / x.shift(10) - 1)\n",
" df['return_20'] = grouped['close'].apply(lambda x: x / x.shift(20) - 1)\n",
"\n",
" # df['avg_close_5'] = grouped['close'].apply(lambda x: x.rolling(window=5).mean() / x)\n",
"\n",
" # 计算标准差指标\n",
" df['std_return_5'] = grouped['close'].apply(lambda x: x.pct_change().rolling(window=5).std())\n",
" # df['std_return_15'] = grouped['close'].apply(lambda x: x.pct_change().rolling(window=15).std())\n",
" # df['std_return_25'] = grouped['close'].apply(lambda x: x.pct_change().rolling(window=25).std())\n",
" df['std_return_90'] = grouped['close'].apply(lambda x: x.pct_change().rolling(window=90).std())\n",
" df['std_return_90_2'] = grouped['close'].apply(lambda x: x.shift(10).pct_change().rolling(window=90).std())\n",
"\n",
" # 计算 EMA 指标\n",
" df['_ema_5'] = grouped['close'].apply(\n",
" lambda x: pd.Series(talib.EMA(x.values, timeperiod=5), index=x.index)\n",
" )\n",
" df['_ema_13'] = grouped['close'].apply(\n",
" lambda x: pd.Series(talib.EMA(x.values, timeperiod=13), index=x.index)\n",
" )\n",
" df['_ema_20'] = grouped['close'].apply(\n",
" lambda x: pd.Series(talib.EMA(x.values, timeperiod=20), index=x.index)\n",
" )\n",
" df['_ema_60'] = grouped['close'].apply(\n",
" lambda x: pd.Series(talib.EMA(x.values, timeperiod=60), index=x.index)\n",
" )\n",
"\n",
" # 计算 act_factor1, act_factor2, act_factor3, act_factor4\n",
" df['act_factor1'] = grouped['_ema_5'].apply(\n",
" lambda x: np.arctan((x / x.shift(1) - 1) * 100) * 57.3 / 50\n",
" )\n",
" df['act_factor2'] = grouped['_ema_13'].apply(\n",
" lambda x: np.arctan((x / x.shift(1) - 1) * 100) * 57.3 / 40\n",
" )\n",
" df['act_factor3'] = grouped['_ema_20'].apply(\n",
" lambda x: np.arctan((x / x.shift(1) - 1) * 100) * 57.3 / 21\n",
" )\n",
" df['act_factor4'] = grouped['_ema_60'].apply(\n",
" lambda x: np.arctan((x / x.shift(1) - 1) * 100) * 57.3 / 10\n",
" )\n",
"\n",
" # 根据 trade_date 截面计算排名\n",
" df['rank_act_factor1'] = df.groupby('trade_date', group_keys=False)['act_factor1'].rank(ascending=False, pct=True)\n",
" df['rank_act_factor2'] = df.groupby('trade_date', group_keys=False)['act_factor2'].rank(ascending=False, pct=True)\n",
" df['rank_act_factor3'] = df.groupby('trade_date', group_keys=False)['act_factor3'].rank(ascending=False, pct=True)\n",
"\n",
" df['log(circ_mv)'] = np.log(df['circ_mv'])\n",
"\n",
" window_high_volume = 5\n",
" window_close_stddev = 20\n",
" period_delta = 5\n",
"\n",
" # 计算每只股票的滚动协方差\n",
" def calculate_rolling_cov(group):\n",
" return group['high'].rolling(window_high_volume).cov(group['vol'])\n",
"\n",
" df['cov'] = grouped.apply(calculate_rolling_cov)\n",
"\n",
" # 计算每只股票的协方差差分\n",
" def calculate_delta_cov(group):\n",
" return group['cov'].diff(period_delta)\n",
"\n",
" df['delta_cov'] = grouped.apply(calculate_delta_cov)\n",
"\n",
" # 计算每只股票的滚动标准差\n",
" def calculate_stddev_close(group):\n",
" return group['close'].rolling(window_close_stddev).std()\n",
"\n",
" df['_stddev_close'] = grouped.apply(calculate_stddev_close)\n",
" df['_rank_stddev'] = df.groupby('trade_date')['_stddev_close'].rank(pct=True)\n",
" df['alpha_22_improved'] = -1 * df['delta_cov'] * df['_rank_stddev']\n",
"\n",
"\n",
" df['alpha_003'] = np.where(df['high'] != df['low'],\n",
" (df['close'] - df['open']) / (df['high'] - df['low']),\n",
" 0)\n",
"\n",
" df['alpha_007'] = grouped.apply(lambda x: x['close'].rolling(5).corr(x['vol']))\n",
" df['alpha_007'] = df.groupby('trade_date', group_keys=False)['alpha_007'].rank(ascending=True, pct=True)\n",
"\n",
" df['alpha_013'] = grouped['close'].transform(lambda x: x.rolling(5).sum() - x.rolling(20).sum())\n",
" df['alpha_013'] = df.groupby('trade_date', group_keys=False)['alpha_013'].rank(ascending=True, pct=True)\n",
"\n",
" df['cat_up_limit'] = (df['close'] == df['up_limit']) # 是否涨停1表示涨停0表示未涨停\n",
" df['cat_down_limit'] = (df['close'] == df['down_limit']) # 是否跌停1表示跌停0表示未跌停\n",
" df['up_limit_count_10d'] = grouped['cat_up_limit'].rolling(window=10, min_periods=1).sum().reset_index(level=0,\n",
" drop=True)\n",
" df['down_limit_count_10d'] = grouped['cat_down_limit'].rolling(window=10, min_periods=1).sum().reset_index(level=0,\n",
" drop=True)\n",
"\n",
" # 3. 最近连续涨跌停天数\n",
" def calculate_consecutive_limits(series):\n",
" \"\"\"\n",
" 计算连续涨停/跌停天数。\n",
" \"\"\"\n",
" consecutive_up = series * (series.groupby((series != series.shift()).cumsum()).cumcount() + 1)\n",
" consecutive_down = series * (series.groupby((series != series.shift()).cumsum()).cumcount() + 1)\n",
" return consecutive_up, consecutive_down\n",
"\n",
" # 连续涨停天数\n",
" df['consecutive_up_limit'] = grouped['cat_up_limit'].apply(\n",
" lambda x: calculate_consecutive_limits(x)[0]\n",
" )\n",
"\n",
" df['vol_break'] = np.where((df['close'] > df['cost_85pct']) & (df['volume_ratio'] > 2), 1, 0)\n",
"\n",
" df['weight_roc5'] = grouped['weight_avg'].apply(lambda x: x.pct_change(5))\n",
"\n",
" def rolling_corr(group):\n",
" roc_close = group['close'].pct_change()\n",
" roc_weight = group['weight_avg'].pct_change()\n",
" return roc_close.rolling(10).corr(roc_weight)\n",
"\n",
" df['price_cost_divergence'] = grouped.apply(rolling_corr)\n",
"\n",
" df['smallcap_concentration'] = (1 / df['log(circ_mv)']) * (df['cost_85pct'] - df['cost_15pct'])\n",
"\n",
" # 16. 筹码稳定性指数 (20日波动率)\n",
" df['weight_std20'] = grouped['weight_avg'].apply(lambda x: x.rolling(20).std())\n",
" df['cost_stability'] = df['weight_std20'] / grouped['weight_avg'].transform(lambda x: x.rolling(20).mean())\n",
"\n",
" # 17. 成本区间突破标记\n",
" df['high_cost_break_days'] = grouped.apply(lambda g: g['close'].gt(g['cost_95pct']).rolling(5).sum())\n",
"\n",
" # 20. 筹码-流动性风险\n",
" df['liquidity_risk'] = (df['cost_95pct'] - df['cost_5pct']) * (\n",
" 1 / grouped['vol'].transform(lambda x: x.rolling(10).mean()))\n",
"\n",
" # # 7. 市值波动率因子\n",
" # df['turnover_std'] = df.groupby('ts_code')['turnover_rate'].transform(lambda x: x.rolling(window=20).std())\n",
" # df['mv_volatility'] = grouped.apply(lambda x: x['turnover_std'] / x['log(circ_mv)'])\n",
" #\n",
" # # 8. 市值成长性因子\n",
" # df['volume_growth'] = df.groupby('ts_code')['vol'].pct_change(periods=20)\n",
" # df['mv_growth'] = df['volume_growth'] / df['log(circ_mv)']\n",
" #\n",
" # df[\"ar\"] = df.groupby('ts_code').apply(lambda x: (x[\"high\"].div(x[\"open\"]).rolling(3).sum()) / (x[\"open\"].div(x[\"low\"]).rolling(3).sum()) * 100).reset_index(level='ts_code', drop=True)\n",
" # df[\"pre_close\"] = df.groupby('ts_code')[\"close\"].shift(1)\n",
" # df[\"br_up\"] = (df[\"high\"] - df[\"pre_close\"]).clip(lower=0)\n",
" # df[\"br_down\"] = (df[\"pre_close\"] - df[\"low\"]).clip(lower=0)\n",
" # df[\"br\"] = df.groupby('ts_code').apply(lambda x: (x[\"br_up\"].rolling(3).sum()) / (x[\"br_down\"].rolling(3).sum()) * 100).reset_index(level='ts_code', drop=True)\n",
" # df['arbr'] = df['ar'] - df['br']\n",
" # df.drop(columns=[\"pre_close\", \"br_up\", \"br_down\", 'ar', 'br'], inplace=True)\n",
"\n",
" # 7. 市值波动率因子 (使用 grouped)\n",
" df['turnover_std'] = grouped['turnover_rate'].transform(lambda x: x.rolling(window=20).std())\n",
" df['mv_volatility'] = grouped.apply(lambda x: x['turnover_std'] / x['log(circ_mv)'])\n",
"\n",
" # 8. 市值成长性因子\n",
" df['volume_growth'] = grouped['vol'].pct_change(periods=20)\n",
" df['mv_growth'] = df['volume_growth'] / df['log(circ_mv)']\n",
"\n",
" # AR 指标\n",
" df[\"ar\"] = grouped.apply(lambda x: (x[\"high\"].div(x[\"open\"]).rolling(3).sum()) / (x[\"open\"].div(x[\"low\"]).rolling(3).sum()) * 100)\n",
"\n",
" # BR 指标\n",
" df[\"pre_close\"] = grouped[\"close\"].shift(1)\n",
" df[\"br_up\"] = (df[\"high\"] - df[\"pre_close\"]).clip(lower=0)\n",
" df[\"br_down\"] = (df[\"pre_close\"] - df[\"low\"]).clip(lower=0)\n",
" df[\"br\"] = grouped.apply(lambda x: (x[\"br_up\"].rolling(3).sum()) / (x[\"br_down\"].rolling(3).sum()) * 100)\n",
"\n",
" # ARBR\n",
" df['arbr'] = df['ar'] - df['br']\n",
" df.drop(columns=[\"pre_close\", \"br_up\", \"br_down\", 'ar', 'br'], inplace=True)\n",
"\n",
" df.drop(columns=['weight_std20'], inplace=True, errors='ignore')\n",
" df.drop(\n",
" columns=['_is_positive', '_is_negative', '_pos_returns', '_neg_returns', '_pos_returns_sq', '_neg_returns_sq'],\n",
" inplace=True, errors='ignore')\n",
" new_columns = [col for col in df.columns.tolist()[:] if col not in old_columns]\n",
"\n",
" return df, new_columns\n",
"\n",
"\n",
"def get_simple_factor(df):\n",
" old_columns = df.columns.tolist()[:]\n",
" df = df.sort_values(by=['ts_code', 'trade_date'])\n",
"\n",
" alpha = 0.5\n",
" df['momentum_factor'] = df['volume_change_rate'] + alpha * df['turnover_deviation']\n",
" df['resonance_factor'] = df['volume_ratio'] * df['pct_chg']\n",
" df['log_close'] = np.log(df['close'])\n",
"\n",
" df['cat_vol_spike'] = df['vol'] > 2 * df['vol_spike']\n",
"\n",
" df['up'] = (df['high'] - df[['close', 'open']].max(axis=1)) / df['close']\n",
" df['down'] = (df[['close', 'open']].min(axis=1) - df['low']) / df['close']\n",
"\n",
" df['obv-maobv_6'] = df['obv'] - df['maobv_6']\n",
"\n",
" # 计算比值指标\n",
" df['std_return_5 / std_return_90'] = df['std_return_5'] / df['std_return_90']\n",
" # df['std_return_5 / std_return_25'] = df['std_return_5'] / df['std_return_25']\n",
"\n",
" # 计算标准差差值\n",
" df['std_return_90 - std_return_90_2'] = df['std_return_90'] - df['std_return_90_2']\n",
"\n",
" # df['cat_af1'] = df['act_factor1'] > 0\n",
" df['cat_af2'] = df['act_factor2'] > df['act_factor1']\n",
" df['cat_af3'] = df['act_factor3'] > df['act_factor2']\n",
" df['cat_af4'] = df['act_factor4'] > df['act_factor3']\n",
"\n",
" # 计算 act_factor5 和 act_factor6\n",
" df['act_factor5'] = df['act_factor1'] + df['act_factor2'] + df['act_factor3'] + df['act_factor4']\n",
" df['act_factor6'] = (df['act_factor1'] - df['act_factor2']) / np.sqrt(\n",
" df['act_factor1'] ** 2 + df['act_factor2'] ** 2)\n",
"\n",
" df['active_buy_volume_large'] = df['buy_lg_vol'] / df['net_mf_vol']\n",
" df['active_buy_volume_big'] = df['buy_elg_vol'] / df['net_mf_vol']\n",
" df['active_buy_volume_small'] = df['buy_sm_vol'] / df['net_mf_vol']\n",
"\n",
" df['buy_lg_vol_minus_sell_lg_vol'] = (df['buy_lg_vol'] - df['sell_lg_vol']) / df['net_mf_vol']\n",
" df['buy_elg_vol_minus_sell_elg_vol'] = (df['buy_elg_vol'] - df['sell_elg_vol']) / df['net_mf_vol']\n",
"\n",
" df['log(circ_mv)'] = np.log(df['circ_mv'])\n",
"\n",
" df['ctrl_strength'] = (df['cost_85pct'] - df['cost_15pct']) / (df['his_high'] - df['his_low'])\n",
"\n",
" df['low_cost_dev'] = (df['close'] - df['cost_5pct']) / (df['cost_50pct'] - df['cost_5pct'])\n",
"\n",
" df['asymmetry'] = (df['cost_95pct'] - df['cost_50pct']) / (df['cost_50pct'] - df['cost_5pct'])\n",
"\n",
" df['lock_factor'] = df['turnover_rate'] * (\n",
" 1 - (df['cost_95pct'] - df['cost_5pct']) / (df['his_high'] - df['his_low']))\n",
"\n",
" df['cat_vol_break'] = (df['close'] > df['cost_85pct']) & (df['volume_ratio'] > 2)\n",
"\n",
" df['cost_atr_adj'] = (df['cost_95pct'] - df['cost_5pct']) / df['atr_14']\n",
"\n",
" # 12. 小盘股筹码集中度\n",
" df['smallcap_concentration'] = (1 / df['log(circ_mv)']) * (df['cost_85pct'] - df['cost_15pct'])\n",
"\n",
" df['cat_golden_resonance'] = ((df['close'] > df['weight_avg']) &\n",
" (df['volume_ratio'] > 1.5) &\n",
" (df['winner_rate'] > 0.7))\n",
"\n",
" df['mv_turnover_ratio'] = df['turnover_rate'] / df['log(circ_mv)']\n",
"\n",
" df['mv_adjusted_volume'] = df['vol'] / df['log(circ_mv)']\n",
"\n",
" df['mv_weighted_turnover'] = df['turnover_rate'] * (1 / df['log(circ_mv)'])\n",
"\n",
" df['nonlinear_mv_volume'] = df['vol'] / df['log(circ_mv)']\n",
"\n",
" df['mv_volume_ratio'] = df['volume_ratio'] / df['log(circ_mv)']\n",
"\n",
" df['mv_momentum'] = df['turnover_rate'] * df['volume_ratio'] / df['log(circ_mv)']\n",
"\n",
" drop_columns = [col for col in df.columns if col.startswith('_')]\n",
" df.drop(columns=drop_columns, inplace=True, errors='ignore')\n",
"\n",
" new_columns = [col for col in df.columns.tolist()[:] if col not in old_columns]\n",
" return df, new_columns\n"
]
},
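{
"cell_type": "code",
"execution_count": null,
"id": "rolling-var-identity-check",
"metadata": {},
"outputs": [],
"source": [
"# Illustrative sanity check (toy data, not part of the original pipeline): the upside/downside\n",
"# volatility above relies on the identity Var(X) = E[X^2] - E[X]^2 computed from rolling sums.\n",
"# This compares the plain-window version of that identity against pandas' direct .var(ddof=0).\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"rng = np.random.default_rng(0)\n",
"x = pd.Series(rng.normal(size=200))\n",
"w = 20\n",
"\n",
"# population variance via the identity (denominator = window length, i.e. ddof=0)\n",
"var_identity = (x.pow(2).rolling(w).sum() / w) - (x.rolling(w).sum() / w) ** 2\n",
"var_direct = x.rolling(w).var(ddof=0)\n",
"\n",
"assert np.allclose(var_identity.dropna(), var_direct.dropna())\n",
"print('rolling E[X^2] - E[X]^2 matches rolling var(ddof=0)')\n"
]
},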
{
"cell_type": "code",
"execution_count": 3,
"id": "a79cafb06a7e0e43",
"metadata": {
"ExecuteTime": {
"end_time": "2025-04-09T16:40:19.471361Z",
"start_time": "2025-04-09T16:39:30.917824Z"
},
"jupyter": {
"source_hidden": true
},
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"daily data\n",
"daily basic\n",
"inner merge on ['ts_code', 'trade_date']\n",
"stk limit\n",
"left merge on ['ts_code', 'trade_date']\n",
"money flow\n",
"left merge on ['ts_code', 'trade_date']\n",
"cyq perf\n",
"left merge on ['ts_code', 'trade_date']\n",
"<class 'pandas.core.frame.DataFrame'>\n",
"RangeIndex: 4051406 entries, 0 to 4051405\n",
"Data columns (total 31 columns):\n",
" # Column Dtype \n",
"--- ------ ----- \n",
" 0 ts_code object \n",
" 1 trade_date datetime64[ns]\n",
" 2 open float64 \n",
" 3 close float64 \n",
" 4 high float64 \n",
" 5 low float64 \n",
" 6 vol float64 \n",
" 7 pct_chg float64 \n",
" 8 turnover_rate float64 \n",
" 9 pe_ttm float64 \n",
" 10 circ_mv float64 \n",
" 11 volume_ratio float64 \n",
" 12 is_st bool \n",
" 13 up_limit float64 \n",
" 14 down_limit float64 \n",
" 15 buy_sm_vol float64 \n",
" 16 sell_sm_vol float64 \n",
" 17 buy_lg_vol float64 \n",
" 18 sell_lg_vol float64 \n",
" 19 buy_elg_vol float64 \n",
" 20 sell_elg_vol float64 \n",
" 21 net_mf_vol float64 \n",
" 22 his_low float64 \n",
" 23 his_high float64 \n",
" 24 cost_5pct float64 \n",
" 25 cost_15pct float64 \n",
" 26 cost_50pct float64 \n",
" 27 cost_85pct float64 \n",
" 28 cost_95pct float64 \n",
" 29 weight_avg float64 \n",
" 30 winner_rate float64 \n",
"dtypes: bool(1), datetime64[ns](1), float64(28), object(1)\n",
"memory usage: 931.2+ MB\n",
"None\n"
]
}
],
"source": [
"from code.utils.utils import read_and_merge_h5_data\n",
"\n",
"print('daily data')\n",
"df1 = read_and_merge_h5_data('../../data-copy/daily_data.h5', key='daily_data',\n",
" columns=['ts_code', 'trade_date', 'open', 'close', 'high', 'low', 'vol', 'pct_chg'],\n",
" df=None)\n",
"df1 = df1[df1['trade_date'] >= '2022-01-01']\n",
"\n",
"print('daily basic')\n",
"df1 = read_and_merge_h5_data('../../data-copy/daily_basic.h5', key='daily_basic',\n",
" columns=['ts_code', 'trade_date', 'turnover_rate', 'pe_ttm', 'circ_mv', 'volume_ratio',\n",
" 'is_st'], df=df1, join='inner')\n",
"\n",
"print('stk limit')\n",
"df1 = read_and_merge_h5_data('../../data-copy/stk_limit.h5', key='stk_limit',\n",
" columns=['ts_code', 'trade_date', 'pre_close', 'up_limit', 'down_limit'],\n",
" df=df1)\n",
"print('money flow')\n",
"df1 = read_and_merge_h5_data('../../data-copy/money_flow.h5', key='money_flow',\n",
" columns=['ts_code', 'trade_date', 'buy_sm_vol', 'sell_sm_vol', 'buy_lg_vol', 'sell_lg_vol',\n",
" 'buy_elg_vol', 'sell_elg_vol', 'net_mf_vol'],\n",
" df=df1)\n",
"print('cyq perf')\n",
"df1 = read_and_merge_h5_data('../../data-copy/cyq_perf.h5', key='cyq_perf',\n",
" columns=['ts_code', 'trade_date', 'his_low', 'his_high', 'cost_5pct', 'cost_15pct',\n",
" 'cost_50pct',\n",
" 'cost_85pct', 'cost_95pct', 'weight_avg', 'winner_rate'],\n",
" df=df1)\n",
"print(df1.info())"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "cac01788dac10678",
"metadata": {
"ExecuteTime": {
"end_time": "2025-04-09T16:40:23.694912Z",
"start_time": "2025-04-09T16:40:19.488481Z"
},
"jupyter": {
"source_hidden": true
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"industry\n"
]
}
],
"source": [
"print('industry')\n",
"industry_df1 = read_and_merge_h5_data('../../data-copy/industry_data.h5', key='industry_data',\n",
" columns=['ts_code', 'l2_code', 'in_date'],\n",
" df=None, on=['ts_code'], join='left')\n",
"\n",
"\n",
"def merge_with_industry_data(df, industry_df):\n",
" # 确保日期字段是 datetime 类型\n",
" df['trade_date'] = pd.to_datetime(df['trade_date'])\n",
" industry_df['in_date'] = pd.to_datetime(industry_df['in_date'])\n",
"\n",
" # 对 industry_df 按 ts_code 和 in_date 排序\n",
" industry_df_sorted = industry_df.sort_values(['in_date', 'ts_code'])\n",
"\n",
" # 对原始 df 按 ts_code 和 trade_date 排序\n",
" df_sorted = df.sort_values(['trade_date', 'ts_code'])\n",
"\n",
" # 使用 merge_asof 进行向后合并\n",
" merged = pd.merge_asof(\n",
" df_sorted,\n",
" industry_df_sorted,\n",
" by='ts_code', # 按 ts_code 分组\n",
" left_on='trade_date',\n",
" right_on='in_date',\n",
" direction='backward'\n",
" )\n",
"\n",
" # 获取每个 ts_code 的最早 in_date 记录\n",
" min_in_date_per_ts = (industry_df_sorted\n",
" .groupby('ts_code')\n",
" .first()\n",
" .reset_index()[['ts_code', 'l2_code']])\n",
"\n",
" # 填充未匹配到的记录trade_date 早于所有 in_date 的情况)\n",
" merged['l2_code'] = merged['l2_code'].fillna(\n",
" merged['ts_code'].map(min_in_date_per_ts.set_index('ts_code')['l2_code'])\n",
" )\n",
"\n",
" # 保留需要的列并重置索引\n",
" result = merged.reset_index(drop=True)\n",
" return result\n",
"\n",
"\n",
"# 使用示例\n",
"df1 = merge_with_industry_data(df1, industry_df1)\n",
"# print(mdf[mdf['ts_code'] == '600751.SH'][['ts_code', 'trade_date', 'l2_code']])"
]
},
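{
"cell_type": "code",
"execution_count": null,
"id": "merge-asof-demo",
"metadata": {},
"outputs": [],
"source": [
"# Illustrative sketch (hypothetical toy data, not part of the original pipeline): how the backward\n",
"# merge_asof above maps each trade_date to the most recent industry membership (in_date <= trade_date).\n",
"demo_px = pd.DataFrame({\n",
" 'ts_code': ['000001.SZ'] * 3,\n",
" 'trade_date': pd.to_datetime(['2022-01-03', '2022-06-01', '2023-01-05']),\n",
"})\n",
"demo_ind = pd.DataFrame({\n",
" 'ts_code': ['000001.SZ'] * 2,\n",
" 'in_date': pd.to_datetime(['2021-12-31', '2022-12-31']),\n",
" 'l2_code': ['L2_A', 'L2_B'],\n",
"})\n",
"\n",
"# both sides must be sorted on their 'on' keys for merge_asof\n",
"print(pd.merge_asof(demo_px.sort_values('trade_date'),\n",
" demo_ind.sort_values('in_date'),\n",
" by='ts_code', left_on='trade_date', right_on='in_date',\n",
" direction='backward'))\n",
"# the first two rows map to L2_A, the 2023 row to L2_B\n"
]
},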
{
"cell_type": "code",
"execution_count": 5,
"id": "5f7a8b42681606f6",
"metadata": {
"ExecuteTime": {
"end_time": "2025-04-09T16:40:30.145830Z",
"start_time": "2025-04-09T16:40:23.712071Z"
},
"jupyter": {
"source_hidden": true
}
},
"outputs": [],
"source": [
"from code.utils.factor import get_act_factor\n",
"\n",
"\n",
"def read_industry_data(h5_filename):\n",
" # 读取 H5 文件中所有的行业数据\n",
" industry_data = pd.read_hdf(h5_filename, key='sw_daily', columns=[\n",
" 'ts_code', 'trade_date', 'open', 'close', 'high', 'low', 'pe', 'pb', 'vol'\n",
" ]) # 假设 H5 文件的键是 'industry_data'\n",
" industry_data = industry_data.sort_values(by=['ts_code', 'trade_date'])\n",
" industry_data = industry_data.reindex()\n",
" industry_data['trade_date'] = pd.to_datetime(industry_data['trade_date'], format='%Y%m%d')\n",
"\n",
" grouped = industry_data.groupby('ts_code', group_keys=False)\n",
" industry_data['obv'] = grouped.apply(\n",
" lambda x: pd.Series(talib.OBV(x['close'].values, x['vol'].values), index=x.index)\n",
" )\n",
" industry_data['return_5'] = grouped['close'].apply(lambda x: x / x.shift(5) - 1)\n",
" industry_data['return_20'] = grouped['close'].apply(lambda x: x / x.shift(20) - 1)\n",
"\n",
" industry_data = get_act_factor(industry_data, cat=False)\n",
" industry_data = industry_data.sort_values(by=['trade_date', 'ts_code'])\n",
"\n",
" # # 计算每天每个 ts_code 的因子和当天所有 ts_code 的中位数的偏差\n",
" # factor_columns = ['obv', 'return_5', 'return_20', 'act_factor1', 'act_factor2', 'act_factor3', 'act_factor4'] # 因子列\n",
" #\n",
" # for factor in factor_columns:\n",
" # if factor in industry_data.columns:\n",
" # # 计算每天每个 ts_code 的因子值与当天所有 ts_code 的中位数的偏差\n",
" # industry_data[f'{factor}_deviation'] = industry_data.groupby('trade_date')[factor].transform(\n",
" # lambda x: x - x.mean())\n",
"\n",
" industry_data['return_5_percentile'] = industry_data.groupby('trade_date')['return_5'].transform(\n",
" lambda x: x.rank(pct=True))\n",
" industry_data['return_20_percentile'] = industry_data.groupby('trade_date')['return_20'].transform(\n",
" lambda x: x.rank(pct=True))\n",
" industry_data = industry_data.drop(columns=['open', 'close', 'high', 'low', 'pe', 'pb', 'vol'])\n",
"\n",
" industry_data = industry_data.rename(\n",
" columns={col: f'industry_{col}' for col in industry_data.columns if col not in ['ts_code', 'trade_date']})\n",
"\n",
" industry_data = industry_data.rename(columns={'ts_code': 'cat_l2_code'})\n",
" return industry_data\n",
"\n",
"\n",
"industry_df1 = read_industry_data('../../data-copy/sw_daily.h5')\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "85c3e3d0235ffffa",
"metadata": {
"ExecuteTime": {
"end_time": "2025-04-09T16:41:39.580305Z",
"start_time": "2025-04-09T16:40:30.170820Z"
},
"jupyter": {
"source_hidden": true
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Index(['ts_code', 'trade_date', 'open', 'close', 'high', 'low', 'vol',\n",
" 'pct_chg', 'turnover_rate', 'pe_ttm', 'circ_mv', 'volume_ratio',\n",
" 'is_st', 'up_limit', 'down_limit', 'buy_sm_vol', 'sell_sm_vol',\n",
" 'buy_lg_vol', 'sell_lg_vol', 'buy_elg_vol', 'sell_elg_vol',\n",
" 'net_mf_vol', 'his_low', 'his_high', 'cost_5pct', 'cost_15pct',\n",
" 'cost_50pct', 'cost_85pct', 'cost_95pct', 'weight_avg', 'winner_rate',\n",
" 'l2_code', '_is_positive', '_is_negative', 'cat_is_positive',\n",
" '_pos_returns', '_neg_returns', '_pos_returns_sq', '_neg_returns_sq',\n",
" 'upside_vol', 'downside_vol', 'vol_ratio', 'return_skew',\n",
" 'return_kurtosis', 'volume_change_rate', 'cat_volume_breakout',\n",
" 'turnover_deviation', 'cat_turnover_spike', 'avg_volume_ratio',\n",
" 'cat_volume_ratio_breakout', 'vol_spike', 'vol_std_5', 'atr_14',\n",
" 'atr_6', 'obv'],\n",
" dtype='object')\n",
"<class 'pandas.core.frame.DataFrame'>\n",
"RangeIndex: 2425287 entries, 0 to 2425286\n",
"Columns: 137 entries, ts_code to industry_return_20_percentile\n",
"dtypes: bool(12), datetime64[ns](1), float64(119), int32(2), int64(1), object(2)\n",
"memory usage: 2.3+ GB\n",
"None\n"
]
}
],
"source": [
"def filter_data(df):\n",
" # df = df.groupby('trade_date').apply(lambda x: x.nlargest(1000, 'act_factor1'))\n",
" df = df[~df['is_st']]\n",
" df = df[~df['ts_code'].str.endswith('BJ')]\n",
" df = df[~df['ts_code'].str.startswith('30')]\n",
" df = df[~df['ts_code'].str.startswith('68')]\n",
" df = df[~df['ts_code'].str.startswith('8')]\n",
" if 'in_date' in df.columns:\n",
" df = df.drop(columns=['in_date'])\n",
" df = df.reset_index(drop=True)\n",
" return df\n",
"\n",
"\n",
"df1 = filter_data(df1)\n",
"df1, _ = get_rolling_factor(df1)\n",
"df1, _ = get_simple_factor(df1)\n",
"df1 = df1.rename(columns={'l2_code': 'cat_l2_code'})\n",
"df1 = df1.merge(industry_df1, on=['cat_l2_code', 'trade_date'], how='left')\n",
"\n",
"\n",
"print(df1.info())"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "5dabff1e7bdd48c0",
"metadata": {
"ExecuteTime": {
"end_time": "2025-04-09T16:42:29.604069Z",
"start_time": "2025-04-09T16:41:39.621703Z"
},
"jupyter": {
"source_hidden": true
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"daily data\n",
"daily basic\n",
"inner merge on ['ts_code', 'trade_date']\n",
"stk limit\n",
"left merge on ['ts_code', 'trade_date']\n",
"money flow\n",
"left merge on ['ts_code', 'trade_date']\n",
"cyq perf\n",
"left merge on ['ts_code', 'trade_date']\n",
"<class 'pandas.core.frame.DataFrame'>\n",
"RangeIndex: 4062142 entries, 0 to 4062141\n",
"Data columns (total 31 columns):\n",
" # Column Dtype \n",
"--- ------ ----- \n",
" 0 ts_code object \n",
" 1 trade_date datetime64[ns]\n",
" 2 open float64 \n",
" 3 close float64 \n",
" 4 high float64 \n",
" 5 low float64 \n",
" 6 vol float64 \n",
" 7 pct_chg float64 \n",
" 8 turnover_rate float64 \n",
" 9 pe_ttm float64 \n",
" 10 circ_mv float64 \n",
" 11 volume_ratio float64 \n",
" 12 is_st bool \n",
" 13 up_limit float64 \n",
" 14 down_limit float64 \n",
" 15 buy_sm_vol float64 \n",
" 16 sell_sm_vol float64 \n",
" 17 buy_lg_vol float64 \n",
" 18 sell_lg_vol float64 \n",
" 19 buy_elg_vol float64 \n",
" 20 sell_elg_vol float64 \n",
" 21 net_mf_vol float64 \n",
" 22 his_low float64 \n",
" 23 his_high float64 \n",
" 24 cost_5pct float64 \n",
" 25 cost_15pct float64 \n",
" 26 cost_50pct float64 \n",
" 27 cost_85pct float64 \n",
" 28 cost_95pct float64 \n",
" 29 weight_avg float64 \n",
" 30 winner_rate float64 \n",
"dtypes: bool(1), datetime64[ns](1), float64(28), object(1)\n",
"memory usage: 933.6+ MB\n",
"None\n"
]
}
],
"source": [
"from code.utils.utils import read_and_merge_h5_data\n",
"\n",
"print('daily data')\n",
"df2 = read_and_merge_h5_data('../../data/daily_data.h5', key='daily_data',\n",
" columns=['ts_code', 'trade_date', 'open', 'close', 'high', 'low', 'vol', 'pct_chg'],\n",
" df=None)\n",
"df2 = df2[df2['trade_date'] >= '2022-01-01']\n",
"\n",
"print('daily basic')\n",
"df2 = read_and_merge_h5_data('../../data/daily_basic.h5', key='daily_basic',\n",
" columns=['ts_code', 'trade_date', 'turnover_rate', 'pe_ttm', 'circ_mv', 'volume_ratio',\n",
" 'is_st'], df=df2, join='inner')\n",
"\n",
"print('stk limit')\n",
"df2 = read_and_merge_h5_data('../../data/stk_limit.h5', key='stk_limit',\n",
" columns=['ts_code', 'trade_date', 'pre_close', 'up_limit', 'down_limit'],\n",
" df=df2)\n",
"print('money flow')\n",
"df2 = read_and_merge_h5_data('../../data/money_flow.h5', key='money_flow',\n",
" columns=['ts_code', 'trade_date', 'buy_sm_vol', 'sell_sm_vol', 'buy_lg_vol', 'sell_lg_vol',\n",
" 'buy_elg_vol', 'sell_elg_vol', 'net_mf_vol'],\n",
" df=df2)\n",
"print('cyq perf')\n",
"df2 = read_and_merge_h5_data('../../data/cyq_perf.h5', key='cyq_perf',\n",
" columns=['ts_code', 'trade_date', 'his_low', 'his_high', 'cost_5pct', 'cost_15pct',\n",
" 'cost_50pct',\n",
" 'cost_85pct', 'cost_95pct', 'weight_avg', 'winner_rate'],\n",
" df=df2)\n",
"print(df2.info())"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "7da9e79ee7f2eeb2",
"metadata": {
"ExecuteTime": {
"end_time": "2025-04-09T16:42:33.590224Z",
"start_time": "2025-04-09T16:42:29.605171Z"
},
"jupyter": {
"source_hidden": true
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"industry\n"
]
}
],
"source": [
"print('industry')\n",
"industry_df2 = read_and_merge_h5_data('../../data/industry_data.h5', key='industry_data',\n",
" columns=['ts_code', 'l2_code', 'in_date'],\n",
" df=None, on=['ts_code'], join='left')\n",
"\n",
"\n",
"def merge_with_industry_data(df, industry_df):\n",
" # 确保日期字段是 datetime 类型\n",
" df['trade_date'] = pd.to_datetime(df['trade_date'])\n",
" industry_df['in_date'] = pd.to_datetime(industry_df['in_date'])\n",
"\n",
" # 对 industry_df 按 ts_code 和 in_date 排序\n",
" industry_df_sorted = industry_df.sort_values(['in_date', 'ts_code'])\n",
"\n",
" # 对原始 df 按 ts_code 和 trade_date 排序\n",
" df_sorted = df.sort_values(['trade_date', 'ts_code'])\n",
"\n",
" # 使用 merge_asof 进行向后合并\n",
" merged = pd.merge_asof(\n",
" df_sorted,\n",
" industry_df_sorted,\n",
" by='ts_code', # 按 ts_code 分组\n",
" left_on='trade_date',\n",
" right_on='in_date',\n",
" direction='backward'\n",
" )\n",
"\n",
" # 获取每个 ts_code 的最早 in_date 记录\n",
" min_in_date_per_ts = (industry_df_sorted\n",
" .groupby('ts_code')\n",
" .first()\n",
" .reset_index()[['ts_code', 'l2_code']])\n",
"\n",
" # 填充未匹配到的记录trade_date 早于所有 in_date 的情况)\n",
" merged['l2_code'] = merged['l2_code'].fillna(\n",
" merged['ts_code'].map(min_in_date_per_ts.set_index('ts_code')['l2_code'])\n",
" )\n",
"\n",
" # 保留需要的列并重置索引\n",
" result = merged.reset_index(drop=True)\n",
" return result\n",
"\n",
"\n",
"# 使用示例\n",
"df2 = merge_with_industry_data(df2, industry_df2)\n",
"# print(mdf[mdf['ts_code'] == '600751.SH'][['ts_code', 'trade_date', 'l2_code']])"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "7f0830ced3ce1050",
"metadata": {
"ExecuteTime": {
"end_time": "2025-04-09T16:42:39.280494Z",
"start_time": "2025-04-09T16:42:33.613600Z"
}
},
"outputs": [],
"source": [
"industry_df2 = read_industry_data('../../data/sw_daily.h5')\n"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "ee9d7511597a312b",
"metadata": {
"ExecuteTime": {
"end_time": "2025-04-09T16:43:50.865104Z",
"start_time": "2025-04-09T16:42:39.340589Z"
},
"jupyter": {
"source_hidden": true
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Index(['ts_code', 'trade_date', 'open', 'close', 'high', 'low', 'vol',\n",
" 'pct_chg', 'turnover_rate', 'pe_ttm', 'circ_mv', 'volume_ratio',\n",
" 'is_st', 'up_limit', 'down_limit', 'buy_sm_vol', 'sell_sm_vol',\n",
" 'buy_lg_vol', 'sell_lg_vol', 'buy_elg_vol', 'sell_elg_vol',\n",
" 'net_mf_vol', 'his_low', 'his_high', 'cost_5pct', 'cost_15pct',\n",
" 'cost_50pct', 'cost_85pct', 'cost_95pct', 'weight_avg', 'winner_rate',\n",
" 'l2_code', '_is_positive', '_is_negative', 'cat_is_positive',\n",
" '_pos_returns', '_neg_returns', '_pos_returns_sq', '_neg_returns_sq',\n",
" 'upside_vol', 'downside_vol', 'vol_ratio', 'return_skew',\n",
" 'return_kurtosis', 'volume_change_rate', 'cat_volume_breakout',\n",
" 'turnover_deviation', 'cat_turnover_spike', 'avg_volume_ratio',\n",
" 'cat_volume_ratio_breakout', 'vol_spike', 'vol_std_5', 'atr_14',\n",
" 'atr_6', 'obv'],\n",
" dtype='object')\n",
"<class 'pandas.core.frame.DataFrame'>\n",
"RangeIndex: 2431461 entries, 0 to 2431460\n",
"Columns: 137 entries, ts_code to industry_return_20_percentile\n",
"dtypes: bool(12), datetime64[ns](1), float64(119), int32(2), int64(1), object(2)\n",
"memory usage: 2.3+ GB\n",
"None\n"
]
}
],
"source": [
"def filter_data(df):\n",
" # df = df.groupby('trade_date').apply(lambda x: x.nlargest(1000, 'act_factor1'))\n",
" df = df[~df['is_st']]\n",
" df = df[~df['ts_code'].str.endswith('BJ')]\n",
" df = df[~df['ts_code'].str.startswith('30')]\n",
" df = df[~df['ts_code'].str.startswith('68')]\n",
" df = df[~df['ts_code'].str.startswith('8')]\n",
" if 'in_date' in df.columns:\n",
" df = df.drop(columns=['in_date'])\n",
" df = df.reset_index(drop=True)\n",
" return df\n",
"\n",
"\n",
"df2 = filter_data(df2)\n",
"df2, _ = get_rolling_factor(df2)\n",
"df2, _ = get_simple_factor(df2)\n",
"df2 = df2.rename(columns={'l2_code': 'cat_l2_code'})\n",
"df2 = df2.merge(industry_df2, on=['cat_l2_code', 'trade_date'], how='left')\n",
"\n",
"print(df2.info())"
]
},
{
"cell_type": "code",
"execution_count": 29,
"id": "4ae711775caefbe5",
"metadata": {
"ExecuteTime": {
"end_time": "2025-04-09T16:43:53.621695Z",
"start_time": "2025-04-09T16:43:50.925481Z"
}
},
"outputs": [],
"source": [
"# print(df1[df1['trade_date'] == '2025-04-07'][['ts_code', 'trade_date', 'vol_std_5', 'cov', 'delta_cov', 'alpha_22_improved', 'alpha_007', 'consecutive_up_limit', 'mv_volatility', 'volume_growth', 'mv_growth', 'arbr']].tail())\n",
"# print(df2[df2['trade_date'] == '2025-04-07'][['ts_code', 'trade_date', 'vol_std_5', 'cov', 'delta_cov', 'alpha_22_improved', 'alpha_007', 'consecutive_up_limit', 'mv_volatility', 'volume_growth', 'mv_growth', 'arbr']].tail())\n",
"# print(df1[df1['trade_date'] == '2025-04-07'].equals(df2[df2['trade_date'] == '2025-04-07']))\n",
"days = 2\n",
"df1 = df1.sort_values(by=['ts_code', 'trade_date'])\n",
"# df['future_return'] = df.groupby('ts_code', group_keys=False)['close'].apply(lambda x: x.shift(-days) / x - 1)\n",
"df1['future_return'] = (df1.groupby('ts_code')['close'].shift(-days) - df1.groupby('ts_code')['open'].shift(-1)) / \\\n",
" df1.groupby('ts_code')['open'].shift(-1)\n",
"df1['future_score'] = calculate_score(df1, days=2, lambda_param=0.3)\n",
"df1['label'] = df1.groupby('trade_date', group_keys=False)['future_score'].transform(\n",
" lambda x: pd.qcut(x, q=20, labels=False, duplicates='drop')\n",
")\n",
"\n",
"df2 = df2.sort_values(by=['ts_code', 'trade_date'])\n",
"# df['future_return'] = df.groupby('ts_code', group_keys=False)['close'].apply(lambda x: x.shift(-days) / x - 1)\n",
"df2['future_return'] = (df2.groupby('ts_code')['close'].shift(-days) - df2.groupby('ts_code')['open'].shift(-1)) / \\\n",
" df2.groupby('ts_code')['open'].shift(-1)\n",
"df2['future_score'] = calculate_score(df2, days=2, lambda_param=0.3)\n",
"df2['label'] = df2.groupby('trade_date', group_keys=False)['future_score'].transform(\n",
" lambda x: pd.qcut(x, q=20, labels=False, duplicates='drop')\n",
")"
]
},
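{
"cell_type": "code",
"execution_count": null,
"id": "qcut-label-demo",
"metadata": {},
"outputs": [],
"source": [
"# Illustrative sketch (hypothetical data): why the per-day labels above can have fewer than 20\n",
"# classes. pd.qcut with duplicates='drop' merges bins whose quantile edges coincide (ties in\n",
"# future_score), so a day can end up with 19 or fewer label values, and days where the score is\n",
"# all-NaN produce no labels at all.\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"scores = pd.Series(np.r_[np.zeros(50), np.random.default_rng(1).normal(size=150)])\n",
"labels = pd.qcut(scores, q=20, labels=False, duplicates='drop')\n",
"print('distinct labels:', labels.nunique()) # < 20 because the tied zeros collapse several bin edges\n"
]
},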
{
"cell_type": "code",
"execution_count": 30,
"id": "350bf91df8c3dfc2",
"metadata": {
"ExecuteTime": {
"end_time": "2025-04-09T16:43:53.723327Z",
"start_time": "2025-04-09T16:43:53.658090Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"日期: 2025-03-26\n",
"------------------------------\n",
"Slice 1 形状: (3086, 141)\n",
"Slice 2 形状: (3086, 141)\n",
"!!! 索引不同,尝试按 ts_code 对齐 !!!\n",
"------------------------------\n",
"使用 compare() 方法查找差异:\n",
"!!! 发现差异 (compare结果):\n",
"MultiIndex([('vol_std_5', 'self'),\n",
" ('vol_std_5', 'other')],\n",
" )\n",
" vol_std_5 \n",
" self other\n",
"ts_code \n",
"000004.SZ 1.076957 1.076957\n",
"000006.SZ 1.228637 1.228637\n",
"000007.SZ 0.533913 0.533913\n",
"000008.SZ 0.368086 0.368086\n",
"000009.SZ 0.393264 0.393264\n",
"... ... ...\n",
"605580.SH 1.164645 1.164645\n",
"605588.SH 0.314876 0.314876\n",
"605589.SH 0.562543 0.562543\n",
"605598.SH 1.057029 1.057029\n",
"605599.SH 0.193314 0.193314\n",
"\n",
"[3001 rows x 2 columns]\n",
"\n",
"存在差异的列: ['vol_std_5']\n"
]
}
],
"source": [
"# 假设 slice1 和 slice2 已经获取,并且索引和列已对齐\n",
"# (如果索引或列不同,需要先用 .sort_index() 或 .sort_index(axis=1) 对齐)\n",
"# 假设 pdf1 和 pdf2 已经是处理到最后一步的结果\n",
"date_to_compare = '2025-03-26'\n",
"\n",
"# 1. 获取两个 DataFrame 在该日期的切片\n",
"slice1 = df1[df1['trade_date'] == date_to_compare]\n",
"slice2 = df2[df2['trade_date'] == date_to_compare]\n",
"\n",
"def get_diff(slice1, slice2):\n",
" print(f\"日期: {date_to_compare}\")\n",
" print(\"-\" * 30)\n",
" print(f\"Slice 1 形状: {slice1.shape}\")\n",
" print(f\"Slice 2 形状: {slice2.shape}\")\n",
" if slice1.shape != slice2.shape:\n",
" print(\"!!! 形状不同 !!!\")\n",
"\n",
" if not slice1.index.equals(slice2.index):\n",
" print(\"!!! 索引不同,尝试按 ts_code 对齐 !!!\")\n",
" try:\n",
" slice1 = slice1.set_index('ts_code').sort_index()\n",
" slice2 = slice2.set_index('ts_code').sort_index()\n",
" except KeyError:\n",
" print(\"错误:无法按 ts_code 设置索引,请确保该列存在。\")\n",
" # 或者尝试其他对齐方式,例如 reset_index\n",
" # slice1 = slice1.reset_index(drop=True)\n",
" # slice2 = slice2.reset_index(drop=True)\n",
"\n",
" if not slice1.columns.equals(slice2.columns):\n",
" print(\"!!! 列名或顺序不同,尝试按列名排序对齐 !!!\")\n",
" slice1 = slice1.sort_index(axis=1)\n",
" slice2 = slice2.sort_index(axis=1)\n",
"\n",
" # 再次检查对齐情况\n",
" if slice1.index.equals(slice2.index) and slice1.columns.equals(slice2.columns):\n",
" print(\"-\" * 30)\n",
" print(\"使用 compare() 方法查找差异:\")\n",
" try:\n",
" # compare 会返回一个显示差异的 DataFrame\n",
" # self 列显示 slice1 的值other 列显示 slice2 的值\n",
" diff_compare = slice1.compare(slice2)\n",
"\n",
" if diff_compare.empty:\n",
" print(\"使用 compare() 未发现差异。\")\n",
" # 如果 compare 为空但 equals 仍为 False, 可能是非常细微的浮点差异或类型差异\n",
" # 可以再检查一下dtypes\n",
" if not slice1.dtypes.equals(slice2.dtypes):\n",
" print(\"!!! 发现数据类型 (dtypes) 不同 !!!\")\n",
" print(slice1.dtypes[slice1.dtypes != slice2.dtypes])\n",
" print(slice2.dtypes[slice1.dtypes != slice2.dtypes])\n",
"\n",
" else:\n",
" print(\"!!! 发现差异 (compare结果):\")\n",
" # 默认情况下compare 的列是 MultiIndex ('列名', 'self'/'other')\n",
" # 为了更清晰地显示,可以调整一下格式\n",
" # diff_compare.columns = ['_'.join(col) for col in diff_compare.columns]\n",
" print(diff_compare.columns)\n",
" print(diff_compare[diff_compare[('vol_std_5', 'self')] != diff_compare[('vol_std_5', 'other')]]) # 打印差异的头部\n",
"\n",
" # 找出哪些列存在差异\n",
" differing_columns = diff_compare.columns.get_level_values(0).unique().tolist()\n",
" print(f\"\\n存在差异的列: {differing_columns}\")\n",
"\n",
" except Exception as e:\n",
" print(f\"使用 compare() 时出错: {e}\")\n",
" else:\n",
" print(\"-\" * 30)\n",
" print(\"索引或列在对齐后仍然不匹配,无法使用 compare()。请检查对齐逻辑。\")\n",
"\n",
"get_diff(slice1, slice2)\n",
"# print(df1['trade_date'].unique().tolist()[-5:])"
]
},
{
"cell_type": "code",
"execution_count": 31,
"id": "9df2781fc6c7ae44",
"metadata": {
"ExecuteTime": {
"end_time": "2025-04-09T16:43:55.223316Z",
"start_time": "2025-04-09T16:43:53.868461Z"
},
"jupyter": {
"source_hidden": true
}
},
"outputs": [],
"source": [
"import pandas as pd\n",
"\n",
"from scipy.stats import ks_2samp, wasserstein_distance\n",
"\n",
"\n",
"def remove_shifted_features(train_data, feature_columns, ks_threshold=0.05, wasserstein_threshold=0.1, size=0.8,\n",
" log=True):\n",
" dropped_features = []\n",
"\n",
" all_dates = sorted(train_data['trade_date'].unique().tolist()) # 获取所有唯一的 trade_date\n",
" split_date = all_dates[int(len(all_dates) * size)] # 划分点为倒数第 validation_days 天\n",
" train_data_split = train_data[train_data['trade_date'] < split_date] # 训练集\n",
" val_data_split = train_data[train_data['trade_date'] >= split_date] # 验证集\n",
"\n",
" # **统计数据漂移**\n",
" numeric_columns = train_data_split.select_dtypes(include=['float64', 'int64']).columns\n",
" numeric_columns = [col for col in numeric_columns if col in feature_columns]\n",
" for feature in numeric_columns:\n",
" ks_stat, p_value = ks_2samp(train_data_split[feature], val_data_split[feature])\n",
" wasserstein_dist = wasserstein_distance(train_data_split[feature], val_data_split[feature])\n",
"\n",
" if p_value < ks_threshold or wasserstein_dist > wasserstein_threshold:\n",
" dropped_features.append(feature)\n",
" if log:\n",
" print(f\"检测到 {len(dropped_features)} 个可能漂移的特征: {dropped_features}\")\n",
"\n",
" # **应用阈值进行最终筛选**\n",
" filtered_features = [f for f in feature_columns if f not in dropped_features]\n",
"\n",
" return filtered_features, dropped_features\n",
"\n",
"\n",
"def remove_outliers_label_percentile(label: pd.Series, lower_percentile: float = 0.01, upper_percentile: float = 0.99,\n",
" log=True):\n",
" if not (0 <= lower_percentile < upper_percentile <= 1):\n",
" raise ValueError(\"Percentile values must satisfy 0 <= lower_percentile < upper_percentile <= 1.\")\n",
"\n",
" # Calculate lower and upper bounds based on percentiles\n",
" lower_bound = label.quantile(lower_percentile)\n",
" upper_bound = label.quantile(upper_percentile)\n",
"\n",
" # Filter out values outside the bounds\n",
" filtered_label = label[(label >= lower_bound) & (label <= upper_bound)]\n",
"\n",
" # Print the number of removed outliers\n",
" if log:\n",
" print(f\"Removed {len(label) - len(filtered_label)} outliers.\")\n",
" return filtered_label\n",
"\n",
"\n",
"def calculate_risk_adjusted_target(df, days=5):\n",
" df = df.sort_values(by=['ts_code', 'trade_date'])\n",
"\n",
" df['future_close'] = df.groupby('ts_code')['close'].shift(-days)\n",
" df['future_open'] = df.groupby('ts_code')['open'].shift(-1)\n",
" df['future_return'] = (df['future_close'] - df['future_open']) / df['future_open']\n",
"\n",
" df['future_volatility'] = df.groupby('ts_code')['future_return'].rolling(days, min_periods=1).std().reset_index(\n",
" level=0, drop=True)\n",
" sharpe_ratio = df['future_return'] * df['future_volatility']\n",
" sharpe_ratio.replace([np.inf, -np.inf], np.nan, inplace=True)\n",
"\n",
" return sharpe_ratio\n",
"\n",
"\n",
"def calculate_score(df, days=5, lambda_param=1.0):\n",
" def calculate_max_drawdown(prices):\n",
" peak = prices.iloc[0] # 初始化峰值\n",
" max_drawdown = 0 # 初始化最大回撤\n",
"\n",
" for price in prices:\n",
" if price > peak:\n",
" peak = price # 更新峰值\n",
" else:\n",
" drawdown = (peak - price) / peak # 计算当前回撤\n",
" max_drawdown = max(max_drawdown, drawdown) # 更新最大回撤\n",
"\n",
" return max_drawdown\n",
"\n",
" def compute_stock_score(stock_df):\n",
" stock_df = stock_df.sort_values(by=['trade_date'])\n",
" future_return = stock_df['future_return']\n",
" # 使用已有的 pct_chg 字段计算波动率\n",
" volatility = stock_df['pct_chg'].rolling(days).std().shift(-days)\n",
" max_drawdown = stock_df['close'].rolling(days).apply(calculate_max_drawdown, raw=False).shift(-days)\n",
" score = future_return - lambda_param * max_drawdown\n",
" return score\n",
"\n",
" # # 确保 DataFrame 按照股票代码和交易日期排序\n",
" # df = df.sort_values(by=['ts_code', 'trade_date'])\n",
"\n",
" # 对每个股票分别计算 score\n",
" df['score'] = df.groupby('ts_code').apply(compute_stock_score).reset_index(level=0, drop=True)\n",
"\n",
" return df['score']\n",
"\n",
"\n",
"def remove_highly_correlated_features(df, feature_columns, threshold=0.9):\n",
" numeric_features = df[feature_columns].select_dtypes(include=[np.number]).columns.tolist()\n",
" if not numeric_features:\n",
" raise ValueError(\"No numeric features found in the provided data.\")\n",
"\n",
" corr_matrix = df[numeric_features].corr().abs()\n",
" upper = corr_matrix.where(np.triu(np.ones(corr_matrix.shape), k=1).astype(bool))\n",
" to_drop = [column for column in upper.columns if any(upper[column] > threshold)]\n",
" remaining_features = [col for col in feature_columns if col not in to_drop\n",
" or 'act' in col or 'af' in col]\n",
" return remaining_features\n",
"\n",
"\n",
"def cross_sectional_standardization(df, features):\n",
" df_sorted = df.sort_values(by='trade_date') # 按时间排序\n",
" df_standardized = df_sorted.copy()\n",
"\n",
" for date in df_sorted['trade_date'].unique():\n",
" # 获取当前时间点的数据\n",
" current_data = df_standardized[df_standardized['trade_date'] == date]\n",
"\n",
" # 只对指定特征进行标准化\n",
" scaler = StandardScaler()\n",
" standardized_values = scaler.fit_transform(current_data[features])\n",
"\n",
" # 将标准化结果重新赋值回去\n",
" df_standardized.loc[df_standardized['trade_date'] == date, features] = standardized_values\n",
"\n",
" return df_standardized\n",
"\n",
"\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"\n",
"def neutralize_manual(df, features, industry_col, mkt_cap_col):\n",
" \"\"\" 手动实现简单回归以提升速度 \"\"\"\n",
"\n",
" for col in features:\n",
" residuals = []\n",
" for _, group in df.groupby(industry_col):\n",
" if len(group) > 1:\n",
" x = np.log(group[mkt_cap_col]) # 市值对数\n",
" y = group[col] # 因子值\n",
" beta = np.cov(y, x)[0, 1] / np.var(x) # 计算斜率\n",
" alpha = np.mean(y) - beta * np.mean(x) # 计算截距\n",
" resid = y - (alpha + beta * x) # 计算残差\n",
" residuals.extend(resid)\n",
" else:\n",
" residuals.extend(group[col]) # 样本不足时保留原值\n",
"\n",
" df[col] = residuals\n",
"\n",
" return df\n",
"\n",
"\n",
"import gc\n",
"\n",
"gc.collect()\n",
"\n",
"\n",
"def mad_filter(df, features, n=3):\n",
" for col in features:\n",
" median = df[col].median()\n",
" mad = np.median(np.abs(df[col] - median))\n",
" upper = median + n * mad\n",
" lower = median - n * mad\n",
" df[col] = np.clip(df[col], lower, upper) # 截断极值\n",
" return df\n",
"\n",
"\n",
"def percentile_filter(df, features, lower_percentile=0.01, upper_percentile=0.99):\n",
" for col in features:\n",
" # 按日期分组计算上下百分位数\n",
" lower_bound = df.groupby('trade_date')[col].transform(\n",
" lambda x: x.quantile(lower_percentile)\n",
" )\n",
" upper_bound = df.groupby('trade_date')[col].transform(\n",
" lambda x: x.quantile(upper_percentile)\n",
" )\n",
" # 截断超出范围的值\n",
" df[col] = np.clip(df[col], lower_bound, upper_bound)\n",
" return df\n",
"\n",
"\n",
"from scipy.stats import iqr\n",
"\n",
"\n",
"def iqr_filter(df, features):\n",
" for col in features:\n",
" df[col] = df.groupby('trade_date')[col].transform(\n",
" lambda x: (x - x.median()) / iqr(x) if iqr(x) != 0 else x\n",
" )\n",
" return df\n",
"\n",
"\n",
"def quantile_filter(df, features, lower_quantile=0.01, upper_quantile=0.99, window=60):\n",
" df = df.copy()\n",
" for col in features:\n",
" # 计算 rolling 统计量,需要按日期进行 groupby\n",
" rolling_lower = df.groupby('trade_date')[col].transform(\n",
" lambda x: x.rolling(window=min(len(x), window)).quantile(lower_quantile))\n",
" rolling_upper = df.groupby('trade_date')[col].transform(\n",
" lambda x: x.rolling(window=min(len(x), window)).quantile(upper_quantile))\n",
"\n",
" # 对数据进行裁剪\n",
" df[col] = np.clip(df[col], rolling_lower, rolling_upper)\n",
"\n",
" return df\n",
"\n",
"def time_series_quantile_filter(df, features, lower_quantile=0.01, upper_quantile=0.99, window=60):\n",
" df = df.copy()\n",
" # 确保按股票和时间排序\n",
" df = df.sort_values(['ts_code', 'trade_date'])\n",
" grouped = df.groupby('ts_code')\n",
" for col in features:\n",
" # 对每个股票的时间序列计算滚动分位数\n",
" rolling_lower = grouped[col].rolling(window=window, min_periods=window // 2).quantile(lower_quantile)\n",
" rolling_upper = grouped[col].rolling(window=window, min_periods=window // 2).quantile(upper_quantile)\n",
" # rolling结果带有多重索引需要对齐\n",
" rolling_lower = rolling_lower.reset_index(level=0, drop=True)\n",
" rolling_upper = rolling_upper.reset_index(level=0, drop=True)\n",
" # 应用 clip\n",
" df[col] = np.clip(df[col], rolling_lower, rolling_upper)\n",
" return df\n",
"\n",
"def cross_sectional_quantile_filter(df, features, lower_quantile=0.01, upper_quantile=0.99):\n",
" df = df.copy()\n",
" grouped = df.groupby('trade_date')\n",
" for col in features:\n",
" # 计算每日截面的分位数边界\n",
" lower_bound = grouped[col].transform(lambda x: x.quantile(lower_quantile))\n",
" upper_bound = grouped[col].transform(lambda x: x.quantile(upper_quantile))\n",
" # 应用 clip\n",
" df[col] = np.clip(df[col], lower_bound, upper_bound)\n",
" return df"
]
},
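{
"cell_type": "code",
"execution_count": null,
"id": "max-drawdown-check",
"metadata": {},
"outputs": [],
"source": [
"# Illustrative sketch (hypothetical data): the loop in calculate_max_drawdown above is equivalent\n",
"# to the usual vectorized form max(1 - price / running_peak), which can replace the Python loop if\n",
"# rolling(...).apply becomes a bottleneck.\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"prices = pd.Series([10.0, 11.0, 9.0, 9.5, 8.0, 12.0])\n",
"\n",
"def loop_mdd(p):\n",
" peak, mdd = p.iloc[0], 0\n",
" for v in p:\n",
" peak = max(peak, v)\n",
" mdd = max(mdd, (peak - v) / peak)\n",
" return mdd\n",
"\n",
"vec_mdd = (1 - prices / prices.cummax()).max()\n",
"assert np.isclose(loop_mdd(prices), vec_mdd)\n",
"print('max drawdown:', vec_mdd) # 11 -> 8 gives (11 - 8) / 11 = 0.2727...\n"
]
},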
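{
"cell_type": "code",
"execution_count": null,
"id": "drift-check-demo",
"metadata": {},
"outputs": [],
"source": [
"# Illustrative sketch (hypothetical data): what remove_shifted_features is testing. A feature whose\n",
"# distribution shifts between the train and validation windows gets a small KS p-value and a large\n",
"# Wasserstein distance, so it would be dropped by the thresholds above.\n",
"from scipy.stats import ks_2samp, wasserstein_distance\n",
"\n",
"rng = np.random.default_rng(42)\n",
"stable = (rng.normal(0, 1, 1000), rng.normal(0, 1, 1000)) # same distribution\n",
"shifted = (rng.normal(0, 1, 1000), rng.normal(0.5, 1, 1000)) # mean shift\n",
"\n",
"for name, (a, b) in [('stable', stable), ('shifted', shifted)]:\n",
" _, p = ks_2samp(a, b)\n",
" wd = wasserstein_distance(a, b)\n",
" print(f'{name}: KS p-value={p:.3g}, wasserstein={wd:.3f}')\n",
"# the shifted feature fails both the p < 0.05 and distance > 0.1 checks used above\n"
]
},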
{
"cell_type": "code",
"execution_count": 43,
"id": "99f677aca6a286d0",
"metadata": {
"ExecuteTime": {
"end_time": "2025-04-09T16:54:07.250024Z",
"start_time": "2025-04-09T16:53:57.299050Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[Timestamp('2025-04-01 00:00:00')] 19 [19. 0. 5. 2. 1. 6. 10. 18. 4. 12. 17. 16. 11. 8. 15. 7. 14. 9.\n",
" 13.]\n",
"[Timestamp('2025-04-02 00:00:00')] 19 [18. 19. 1. 0. 3. 7. 17. 10. 16. 5. 9. 15. 8. 6. 4. 13. 2. 11.\n",
" 14.]\n",
"[Timestamp('2025-04-03 00:00:00')] 0 [nan]\n",
"[Timestamp('2025-04-07 00:00:00')] 0 [nan]\n",
"2025-04-07 00:00:00\n",
"[Timestamp('2025-04-01 00:00:00')] 19 [19. 0. 5. 2. 1. 6. 10. 18. 4. 12. 17. 16. 11. 8. 15. 7. 14. 9.\n",
" 13.]\n",
"[Timestamp('2025-04-02 00:00:00')] 19 [18. 19. 1. 0. 3. 7. 17. 10. 16. 5. 9. 15. 8. 6. 4. 13. 2. 11.\n",
" 14.]\n",
"[Timestamp('2025-04-03 00:00:00')] 19 [ 2. 15. 19. 0. 1. 5. 18. 17. 4. 6. 16. 8. 13. 14. 9. 7. 12. 11.\n",
" 3.]\n",
"[Timestamp('2025-04-07 00:00:00')] 19 [ 0. 18. 4. 17. 1. 19. 9. 13. 7. 5. 2. 16. 15. 6. 12. 11. 3. 14.\n",
" 8.]\n",
"[Timestamp('2025-04-08 00:00:00')] 0 [nan]\n",
"[Timestamp('2025-04-09 00:00:00')] 0 [nan]\n",
"2025-04-09 00:00:00\n",
"日期: 2025-04-07\n",
"------------------------------\n",
"Slice 1 形状: (100, 159)\n",
"Slice 2 形状: (110, 159)\n",
"!!! 形状不同 !!!\n",
"!!! 索引不同,尝试按 ts_code 对齐 !!!\n",
"------------------------------\n",
"索引或列在对齐后仍然不匹配,无法使用 compare()。请检查对齐逻辑。\n"
]
}
],
"source": [
"def get_pdf(df, industry_df):\n",
" origin_columns = df.columns.tolist()\n",
" origin_columns = [col for col in origin_columns if\n",
" col not in ['turnover_rate', 'pe_ttm', 'volume_ratio', 'vol', 'pct_chg', 'l2_code', 'winner_rate']]\n",
" origin_columns = [col for col in origin_columns if 'cyq' not in col]\n",
"\n",
" days = 2\n",
" # df = df.sort_values(by=['ts_code', 'trade_date'])\n",
" # # df['future_return'] = df.groupby('ts_code', group_keys=False)['close'].apply(lambda x: x.shift(-days) / x - 1)\n",
" # df['future_return'] = (df.groupby('ts_code')['close'].shift(-days) - df.groupby('ts_code')['open'].shift(-1)) / \\\n",
" # df.groupby('ts_code')['open'].shift(-1)\n",
" # df['future_score'] = calculate_score(df, days=2, lambda_param=0.3)\n",
" # df['label'] = df.groupby('trade_date', group_keys=False)['future_score'].transform(\n",
" # lambda x: pd.qcut(x, q=20, labels=False, duplicates='drop')\n",
" # )\n",
" # df['label'] = df.groupby('trade_date', group_keys=False)['future_score'].transform(\n",
" # lambda x: pd.qcut(x.rank(method='first'), q=20, labels=False, duplicates='raise')\n",
" # )\n",
" # df['future_score'] = (\n",
" # 0.7 * df['future_return']\n",
" # * 0.3 * df['future_volatility']\n",
" # )\n",
"\n",
" def select_pre_zt_stocks_dynamic(stock_df):\n",
" def select_stocks(group):\n",
" max_stocks = 150\n",
" initial_data = group.nlargest(100, 'return_20')\n",
" unique_labels = initial_data['label'].nunique()\n",
"\n",
" print(group['trade_date'].unique().tolist(), initial_data['label'].nunique(), initial_data['label'].unique())\n",
" if unique_labels >= 20 or unique_labels == 0: # 包含标签种类为0的情况\n",
" return initial_data\n",
"\n",
" for i in range(110, max_stocks + 1, 10):\n",
" data = group.nlargest(i, 'return_20')\n",
" unique_labels = data['label'].nunique()\n",
" if unique_labels >= 20:\n",
" return data\n",
"\n",
" return group.nlargest(max_stocks, 'return_20') # 如果循环结束仍未找到足够标签,则返回最大数量的股票\n",
"\n",
" stock_df = stock_df.groupby('trade_date', group_keys=False).apply(select_stocks)\n",
" return stock_df\n",
"\n",
"\n",
" pdf = select_pre_zt_stocks_dynamic(df[(df['trade_date'] >= '2022-01-01') & (df['trade_date'] <= '2029-04-07')])\n",
" print(pdf['trade_date'].max())\n",
"\n",
" pdf = pdf.merge(industry_df, on=['cat_l2_code', 'trade_date'], how='left')\n",
" pdf = pdf.replace([np.inf, -np.inf], np.nan)\n",
"\n",
" feature_columns = [col for col in pdf.columns if col in pdf.columns]\n",
" feature_columns = [col for col in feature_columns if col not in ['trade_date',\n",
" 'ts_code',\n",
" 'label']]\n",
" feature_columns = [col for col in feature_columns if 'future' not in col]\n",
" feature_columns = [col for col in feature_columns if 'label' not in col]\n",
" feature_columns = [col for col in feature_columns if 'score' not in col]\n",
" feature_columns = [col for col in feature_columns if 'gen' not in col]\n",
" feature_columns = [col for col in feature_columns if 'pe_ttm' not in col]\n",
" feature_columns = [col for col in feature_columns if 'volatility' not in col]\n",
" feature_columns = [col for col in feature_columns if 'cat_l2_code' not in col]\n",
" feature_columns = [col for col in feature_columns if col not in origin_columns]\n",
" feature_columns = [col for col in feature_columns if not col.startswith('_')]\n",
" # feature_columns = [col for col in feature_columns if col not in ['ts_code', 'trade_date', 'vol_std_5', 'cov', 'delta_cov', 'alpha_22_improved', 'alpha_007', 'consecutive_up_limit', 'mv_volatility', 'volume_growth', 'mv_growth', 'arbr']]\n",
"\n",
" numeric_columns = pdf.select_dtypes(include=['float64', 'int64']).columns\n",
" numeric_columns = [col for col in numeric_columns if col in feature_columns]\n",
"\n",
" pdf = cross_sectional_quantile_filter(pdf, numeric_columns)\n",
" # pdf = cross_sectional_standardization(pdf, numeric_columns)\n",
"\n",
" pdf = pdf.sort_values(by=['ts_code', 'trade_date'])\n",
"\n",
" filter_index = pdf['future_return'].between(pdf['future_return'].quantile(0.01), pdf['future_return'].quantile(0.99))\n",
"\n",
" feature_columns = remove_highly_correlated_features(pdf, feature_columns)\n",
"\n",
" return pdf, feature_columns, filter_index\n",
"\n",
"pdf1, feature_columns1, filter_index1 = get_pdf(df1[df1['trade_date'] >= '2025-04-01'], industry_df1)\n",
"pdf2, feature_columns2, filter_index2 = get_pdf(df2[df2['trade_date'] >= '2025-04-01'], industry_df2)\n",
"\n",
"# date_to_compare = '2025-04-07'\n",
"slice1 = pdf1[pdf1['trade_date'] == date_to_compare]\n",
"slice2 = pdf2[pdf2['trade_date'] == date_to_compare]\n",
"get_diff(slice1, slice2)"
]
},
{
"cell_type": "code",
"execution_count": 33,
"id": "1b863e4115252d2d",
"metadata": {
"jupyter": {
"source_hidden": true
}
},
"outputs": [],
"source": [
"from sklearn.preprocessing import StandardScaler\n",
"import lightgbm as lgb\n",
"import matplotlib.pyplot as plt\n",
"\n",
"def train_light_model(train_data_df, params, feature_columns, callbacks, evals,\n",
" print_feature_importance=True, num_boost_round=100,\n",
" validation_days=180, use_pca=False, split_date=None): # 新增参数validation_days\n",
" # 确保数据按时间排序\n",
" train_data_df = train_data_df.sort_values(by='trade_date')\n",
"\n",
" numeric_columns = train_data_df.select_dtypes(include=['float64', 'int64']).columns\n",
" numeric_columns = [col for col in numeric_columns if col in feature_columns]\n",
" # X_train.loc[:, numeric_columns] = scaler.fit_transform(X_train[numeric_columns])\n",
" # X_val.loc[:, numeric_columns] = scaler.transform(X_val[numeric_columns])\n",
" # train_data_df = cross_sectional_standardization(train_data_df, numeric_columns)\n",
"\n",
" # 去除标签为空的样本\n",
" train_data_df = train_data_df.dropna(subset=['label'])\n",
" # print('原始训练集大小: ', len(train_data_df))\n",
"\n",
" # 按时间顺序划分训练集和验证集\n",
" if split_date is None:\n",
" all_dates = train_data_df['trade_date'].unique() # 获取所有唯一的 trade_date\n",
" if validation_days == 0:\n",
" split_date = all_dates[-1]\n",
" else:\n",
" split_date = all_dates[-validation_days] # 划分点为倒数第 validation_days 天\n",
" if validation_days == 0:\n",
" train_data_split = train_data_df\n",
" else:\n",
" train_data_split = train_data_df[train_data_df['trade_date'] < split_date] # 训练集\n",
" val_data_split = train_data_df[train_data_df['trade_date'] >= split_date] # 验证集\n",
"\n",
" # 打印划分结果\n",
" print(f\"划分后的训练集大小: {len(train_data_split)}, 验证集大小: {len(val_data_split)}\")\n",
"\n",
" # 提取特征和标签\n",
" X_train = train_data_split[feature_columns]\n",
" y_train = train_data_split['label']\n",
"\n",
" # 标准化数值特征\n",
" scaler = StandardScaler()\n",
"\n",
" # 计算每个 trade_date 内的样本数LTR 需要 group 信息)\n",
" train_groups = train_data_split.groupby('trade_date').size().tolist()\n",
" val_groups = val_data_split.groupby('trade_date').size().tolist()\n",
"\n",
" # 处理类别特征\n",
" categorical_feature = [col for col in feature_columns if 'cat' in col]\n",
"\n",
" # 计算权重(基于时间)\n",
" # trade_date = train_data_split['trade_date'] # 交易日期\n",
" # weights = (trade_date - trade_date.min()).dt.days / (trade_date.max() - trade_date.min()).days + 1\n",
" # weights = train_data_split.groupby('trade_date')['std_return_5'].transform(\n",
" # lambda x: x / x.mean()\n",
" # )\n",
" ud = sorted(train_data_split[\"trade_date\"].unique().tolist())\n",
" date_weights = {date: weight * weight for date, weight in zip(ud, np.linspace(1, 10, len(ud)))}\n",
" params['weight'] = train_data_split[\"trade_date\"].map(date_weights).tolist()\n",
"\n",
" train_dataset = lgb.Dataset(\n",
" X_train, label=y_train, group=train_groups,\n",
" categorical_feature=categorical_feature\n",
" )\n",
"\n",
" if validation_days > 0:\n",
" X_val = val_data_split[feature_columns]\n",
" y_val = val_data_split['label']\n",
" val_groups = val_data_split.groupby('trade_date').size().tolist()\n",
" val_dataset = lgb.Dataset(\n",
" X_val, label=y_val, group=val_groups,\n",
" categorical_feature=categorical_feature\n",
" )\n",
" # 训练模型\n",
" model = lgb.train(\n",
" params, train_dataset, num_boost_round=num_boost_round,\n",
" valid_sets=[train_dataset, val_dataset], valid_names=['train', 'valid'],\n",
" callbacks=callbacks\n",
" )\n",
" else:\n",
" model = lgb.train(\n",
" params, train_dataset, num_boost_round=num_boost_round, callbacks=callbacks\n",
" )\n",
"\n",
" # 打印特征重要性(如果需要)\n",
" if print_feature_importance:\n",
" lgb.plot_metric(evals)\n",
" lgb.plot_importance(model, importance_type='split', max_num_features=20)\n",
" plt.show()\n",
"\n",
" return model, scaler, None\n",
"\n",
"def rolling_train_predict(df, train_days, test_days, feature_columns_origin, days=5, use_pca=False, validation_days=60,\n",
" filter_index=None, params=None):\n",
" # 1. 按照交易日期排序\n",
" unique_dates = df[df['trade_date'] >= '2020-01-01']['trade_date'].unique().tolist()\n",
" unique_dates = sorted(unique_dates)\n",
" n = len(unique_dates)\n",
"\n",
" # 2. 计算需要跳过的天数,使后续窗口对齐\n",
" extra_days = (n - train_days) % test_days\n",
" start_index = extra_days # 从此索引开始滚动\n",
"\n",
" predictions_list = []\n",
"\n",
" for start in range(start_index, n - train_days - test_days + 1, test_days):\n",
"\n",
" train_dates = unique_dates[start: start + train_days]\n",
" test_dates = unique_dates[start + train_days: start + train_days + test_days]\n",
"\n",
" # 根据日期筛选数据\n",
" # train_data = df[df['trade_date'].isin(train_dates)]\n",
" train_data = df[filter_index & df['trade_date'].isin(train_dates)]\n",
" test_data = df[df['trade_date'].isin(test_dates)]\n",
"\n",
" train_data = train_data.sort_values('trade_date')\n",
" test_data = test_data.sort_values('trade_date')\n",
"\n",
" feature_columns, _ = remove_shifted_features(train_data, feature_columns_origin, size=0.8, log=False)\n",
"\n",
" train_data = train_data.dropna(subset=feature_columns)\n",
" train_data = train_data.dropna(subset=['label'])\n",
" train_data = train_data.reset_index(drop=True)\n",
"\n",
" # print(test_data.tail())\n",
" test_data = test_data.dropna(subset=feature_columns)\n",
" # test_data = test_data.dropna(subset=['label'])\n",
" test_data = test_data.reset_index(drop=True)\n",
"\n",
" # print(len(train_data))\n",
" # print(f\"最小日期: {train_data['trade_date'].min().strftime('%Y-%m-%d')}\")\n",
" print(f\"train_data最大日期: {train_data['trade_date'].max().strftime('%Y-%m-%d')}\")\n",
" # # print(len(test_data))\n",
" # print(f\"最小日期: {test_data['trade_date'].min().strftime('%Y-%m-%d')}\")\n",
" print(f\"test_data最大日期: {test_data['trade_date'].max().strftime('%Y-%m-%d')}\")\n",
"\n",
" cat_columns = [col for col in df.columns if col.startswith('cat')]\n",
" for col in cat_columns:\n",
" train_data[col] = train_data[col].astype('category')\n",
" test_data[col] = test_data[col].astype('category')\n",
"\n",
" label_gain = list(range(len(train_data['label'].unique())))\n",
" label_gain = [(gain + 1) * (gain + 1) for gain in label_gain]\n",
" params['label_gain'] = label_gain\n",
"\n",
" # ud = train_data[\"trade_date\"].unique()\n",
" # date_weights = {date: weight for date, weight in zip(ud, np.linspace(1, 2, len(unique_dates)))}\n",
" # params['weight'] = train_data[\"trade_date\"].map(date_weights).tolist()\n",
"\n",
" # print(f'feature_columns: {feature_columns}')\n",
" # feature_contri = [2 if feat.startswith('act_factor') else 1 for feat in feature_columns]\n",
" # params['feature_contri'] = feature_contri\n",
" evals = {}\n",
" model, _, _ = train_light_model(train_data.dropna(subset=['label']),\n",
" params, feature_columns,\n",
" [lgb.log_evaluation(period=100),\n",
" lgb.callback.record_evaluation(evals),\n",
" # lgb.early_stopping(100, first_metric_only=True)\n",
" ], evals,\n",
" num_boost_round=100, validation_days=validation_days,\n",
" print_feature_importance=False, use_pca=False)\n",
"\n",
" score_df = test_data.copy()\n",
" score_df['score'] = model.predict(score_df[feature_columns])\n",
" score_df = score_df.loc[score_df.groupby('trade_date')['score'].idxmax()]\n",
" score_df = score_df[['trade_date', 'score', 'ts_code']]\n",
" predictions_list.append(score_df)\n",
"\n",
" final_predictions = pd.concat(predictions_list, ignore_index=True)\n",
" return final_predictions"
]
},
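{
"cell_type": "code",
"execution_count": null,
"id": "b8e2d4f6a1c73095",
"metadata": {},
"outputs": [],
"source": [
"# Minimal sketch (toy, randomly generated data) of how train_light_model above wires
"# LambdaRank inputs: per-day query groups and time-decay sample weights both go to
"# lgb.Dataset, not into the params dict. All shapes and values here are made up.
"import numpy as np
"import pandas as pd
"import lightgbm as lgb
"
"dates = pd.Series(pd.date_range('2023-01-02', periods=8).repeat(4))
"X = pd.DataFrame(np.random.randn(len(dates), 3), columns=['f1', 'f2', 'f3'])
"y = np.random.randint(0, 3, size=len(dates))  # toy rank labels
"
"ud = sorted(dates.unique().tolist())
"date_weights = {d: w * w for d, w in zip(ud, np.linspace(1, 10, len(ud)))}
"weights = dates.map(date_weights).to_numpy()   # newer days weigh more
"groups = dates.groupby(dates).size().tolist()  # one query group per trade date
"
"demo_dataset = lgb.Dataset(X, label=y, weight=weights, group=groups)
]
},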
{
"cell_type": "code",
"execution_count": 36,
"id": "ddb5b67a9852e2",
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"train_data最大日期: 2022-12-07\n",
"test_data最大日期: 2022-12-08\n",
"划分后的训练集大小: 525, 验证集大小: 106\n",
"train_data最大日期: 2022-12-08\n",
"test_data最大日期: 2022-12-09\n",
"划分后的训练集大小: 531, 验证集大小: 109\n",
"train_data最大日期: 2022-12-09\n",
"test_data最大日期: 2022-12-12\n",
"划分后的训练集大小: 516, 验证集大小: 100\n",
"train_data最大日期: 2022-12-12\n",
"test_data最大日期: 2022-12-13\n",
"划分后的训练集大小: 528, 验证集大小: 108\n",
"train_data最大日期: 2022-12-13\n",
"test_data最大日期: 2022-12-14\n",
"划分后的训练集大小: 571, 验证集大小: 148\n",
"train_data最大日期: 2022-12-14\n",
"test_data最大日期: 2022-12-15\n",
"划分后的训练集大小: 565, 验证集大小: 100\n",
"train_data最大日期: 2022-12-15\n",
"test_data最大日期: 2022-12-16\n",
"划分后的训练集大小: 600, 验证集大小: 144\n",
"train_data最大日期: 2022-12-16\n",
"test_data最大日期: 2022-12-19\n",
"划分后的训练集大小: 597, 验证集大小: 97\n",
"train_data最大日期: 2022-12-19\n",
"test_data最大日期: 2022-12-20\n",
"划分后的训练集大小: 633, 验证集大小: 144\n",
"train_data最大日期: 2022-12-20\n",
"test_data最大日期: 2022-12-21\n",
"划分后的训练集大小: 627, 验证集大小: 142\n",
"train_data最大日期: 2022-12-21\n",
"test_data最大日期: 2022-12-22\n",
"划分后的训练集大小: 624, 验证集大小: 97\n",
"train_data最大日期: 2022-12-22\n",
"test_data最大日期: 2022-12-23\n",
"划分后的训练集大小: 605, 验证集大小: 125\n",
"train_data最大日期: 2022-12-23\n",
"test_data最大日期: 2022-12-26\n",
"划分后的训练集大小: 603, 验证集大小: 95\n",
"train_data最大日期: 2022-12-26\n",
"test_data最大日期: 2022-12-27\n",
"划分后的训练集大小: 558, 验证集大小: 99\n",
"train_data最大日期: 2022-12-27\n",
"test_data最大日期: 2022-12-28\n",
"划分后的训练集大小: 510, 验证集大小: 94\n",
"train_data最大日期: 2022-12-28\n",
"test_data最大日期: 2022-12-29\n",
"划分后的训练集大小: 508, 验证集大小: 95\n",
"train_data最大日期: 2022-12-29\n",
"test_data最大日期: 2022-12-30\n",
"划分后的训练集大小: 528, 验证集大小: 145\n",
"train_data最大日期: 2022-12-30\n",
"test_data最大日期: 2023-01-03\n",
"划分后的训练集大小: 570, 验证集大小: 137\n",
"train_data最大日期: 2023-01-03\n",
"test_data最大日期: 2023-01-04\n",
"划分后的训练集大小: 618, 验证集大小: 147\n",
"train_data最大日期: 2023-01-04\n",
"test_data最大日期: 2023-01-05\n",
"划分后的训练集大小: 666, 验证集大小: 142\n",
"train_data最大日期: 2023-01-05\n",
"test_data最大日期: 2023-01-06\n",
"划分后的训练集大小: 717, 验证集大小: 146\n",
"train_data最大日期: 2023-01-06\n",
"test_data最大日期: 2023-01-09\n",
"划分后的训练集大小: 670, 验证集大小: 98\n",
"train_data最大日期: 2023-01-09\n",
"test_data最大日期: 2023-01-10\n",
"划分后的训练集大小: 630, 验证集大小: 97\n",
"train_data最大日期: 2023-01-10\n",
"test_data最大日期: 2023-01-11\n",
"划分后的训练集大小: 589, 验证集大小: 106\n",
"train_data最大日期: 2023-01-11\n",
"test_data最大日期: 2023-01-12\n",
"划分后的训练集大小: 543, 验证集大小: 96\n",
"train_data最大日期: 2023-01-12\n",
"test_data最大日期: 2023-01-13\n",
"划分后的训练集大小: 544, 验证集大小: 147\n",
"train_data最大日期: 2023-01-13\n",
"test_data最大日期: 2023-01-16\n",
"划分后的训练集大小: 553, 验证集大小: 107\n",
"train_data最大日期: 2023-01-16\n",
"test_data最大日期: 2023-01-17\n",
"划分后的训练集大小: 573, 验证集大小: 117\n",
"train_data最大日期: 2023-01-17\n",
"test_data最大日期: 2023-01-18\n",
"划分后的训练集大小: 604, 验证集大小: 137\n",
"train_data最大日期: 2023-01-18\n",
"test_data最大日期: 2023-01-19\n",
"划分后的训练集大小: 625, 验证集大小: 117\n",
"train_data最大日期: 2023-01-19\n",
"test_data最大日期: 2023-01-20\n",
"划分后的训练集大小: 616, 验证集大小: 138\n",
"train_data最大日期: 2023-01-20\n",
"test_data最大日期: 2023-01-30\n",
"划分后的训练集大小: 609, 验证集大小: 100\n",
"train_data最大日期: 2023-01-30\n",
"test_data最大日期: 2023-01-31\n",
"划分后的训练集大小: 621, 验证集大小: 129\n",
"train_data最大日期: 2023-01-31\n",
"test_data最大日期: 2023-02-01\n",
"划分后的训练集大小: 584, 验证集大小: 100\n",
"train_data最大日期: 2023-02-01\n",
"test_data最大日期: 2023-02-02\n",
"划分后的训练集大小: 583, 验证集大小: 116\n",
"train_data最大日期: 2023-02-02\n",
"test_data最大日期: 2023-02-03\n",
"划分后的训练集大小: 553, 验证集大小: 108\n",
"train_data最大日期: 2023-02-03\n",
"test_data最大日期: 2023-02-06\n",
"划分后的训练集大小: 581, 验证集大小: 128\n",
"train_data最大日期: 2023-02-06\n",
"test_data最大日期: 2023-02-07\n",
"划分后的训练集大小: 572, 验证集大小: 120\n",
"train_data最大日期: 2023-02-07\n",
"test_data最大日期: 2023-02-08\n",
"划分后的训练集大小: 622, 验证集大小: 150\n",
"train_data最大日期: 2023-02-08\n",
"test_data最大日期: 2023-02-09\n",
"划分后的训练集大小: 656, 验证集大小: 150\n",
"train_data最大日期: 2023-02-09\n",
"test_data最大日期: 2023-02-10\n",
"划分后的训练集大小: 697, 验证集大小: 149\n",
"train_data最大日期: 2023-02-10\n",
"test_data最大日期: 2023-02-13\n",
"划分后的训练集大小: 698, 验证集大小: 129\n",
"train_data最大日期: 2023-02-13\n",
"test_data最大日期: 2023-02-14\n",
"划分后的训练集大小: 717, 验证集大小: 139\n",
"train_data最大日期: 2023-02-14\n",
"test_data最大日期: 2023-02-15\n",
"划分后的训练集大小: 715, 验证集大小: 148\n",
"train_data最大日期: 2023-02-15\n",
"test_data最大日期: 2023-02-16\n",
"划分后的训练集大小: 714, 验证集大小: 149\n",
"train_data最大日期: 2023-02-16\n",
"test_data最大日期: 2023-02-17\n",
"划分后的训练集大小: 713, 验证集大小: 148\n",
"train_data最大日期: 2023-02-17\n",
"test_data最大日期: 2023-02-20\n",
"划分后的训练集大小: 682, 验证集大小: 98\n",
"train_data最大日期: 2023-02-20\n",
"test_data最大日期: 2023-02-21\n",
"划分后的训练集大小: 681, 验证集大小: 138\n",
"train_data最大日期: 2023-02-21\n",
"test_data最大日期: 2023-02-22\n",
"划分后的训练集大小: 632, 验证集大小: 99\n",
"train_data最大日期: 2023-02-22\n",
"test_data最大日期: 2023-02-23\n",
"划分后的训练集大小: 619, 验证集大小: 136\n",
"train_data最大日期: 2023-02-23\n",
"test_data最大日期: 2023-02-24\n",
"划分后的训练集大小: 571, 验证集大小: 100\n",
"train_data最大日期: 2023-02-24\n",
"test_data最大日期: 2023-02-27\n",
"划分后的训练集大小: 621, 验证集大小: 148\n",
"train_data最大日期: 2023-02-27\n",
"test_data最大日期: 2023-02-28\n",
"划分后的训练集大小: 632, 验证集大小: 149\n",
"train_data最大日期: 2023-02-28\n",
"test_data最大日期: 2023-03-01\n",
"划分后的训练集大小: 632, 验证集大小: 99\n",
"train_data最大日期: 2023-03-01\n",
"test_data最大日期: 2023-03-02\n",
"划分后的训练集大小: 596, 验证集大小: 100\n",
"train_data最大日期: 2023-03-02\n",
"test_data最大日期: 2023-03-03\n",
"划分后的训练集大小: 595, 验证集大小: 99\n",
"train_data最大日期: 2023-03-03\n",
"test_data最大日期: 2023-03-06\n",
"划分后的训练集大小: 596, 验证集大小: 149\n",
"train_data最大日期: 2023-03-06\n",
"test_data最大日期: 2023-03-07\n",
"划分后的训练集大小: 547, 验证集大小: 100\n",
"train_data最大日期: 2023-03-07\n",
"test_data最大日期: 2023-03-08\n",
"划分后的训练集大小: 567, 验证集大小: 119\n",
"train_data最大日期: 2023-03-08\n",
"test_data最大日期: 2023-03-09\n",
"划分后的训练集大小: 585, 验证集大小: 118\n",
"train_data最大日期: 2023-03-09\n",
"test_data最大日期: 2023-03-10\n",
"划分后的训练集大小: 634, 验证集大小: 148\n",
"train_data最大日期: 2023-03-10\n",
"test_data最大日期: 2023-03-13\n",
"划分后的训练集大小: 630, 验证集大小: 145\n",
"train_data最大日期: 2023-03-13\n",
"test_data最大日期: 2023-03-14\n",
"划分后的训练集大小: 638, 验证集大小: 108\n",
"train_data最大日期: 2023-03-14\n",
"test_data最大日期: 2023-03-15\n",
"划分后的训练集大小: 665, 验证集大小: 146\n",
"train_data最大日期: 2023-03-15\n",
"test_data最大日期: 2023-03-16\n",
"划分后的训练集大小: 677, 验证集大小: 130\n",
"train_data最大日期: 2023-03-16\n",
"test_data最大日期: 2023-03-17\n",
"划分后的训练集大小: 678, 验证集大小: 149\n",
"train_data最大日期: 2023-03-17\n",
"test_data最大日期: 2023-03-20\n",
"划分后的训练集大小: 642, 验证集大小: 109\n",
"train_data最大日期: 2023-03-20\n",
"test_data最大日期: 2023-03-21\n",
"划分后的训练集大小: 663, 验证集大小: 129\n",
"train_data最大日期: 2023-03-21\n",
"test_data最大日期: 2023-03-22\n",
"划分后的训练集大小: 615, 验证集大小: 98\n",
"train_data最大日期: 2023-03-22\n",
"test_data最大日期: 2023-03-23\n",
"划分后的训练集大小: 633, 验证集大小: 148\n",
"train_data最大日期: 2023-03-23\n",
"test_data最大日期: 2023-03-24\n",
"划分后的训练集大小: 627, 验证集大小: 143\n",
"train_data最大日期: 2023-03-24\n",
"test_data最大日期: 2023-03-27\n",
"划分后的训练集大小: 646, 验证集大小: 128\n",
"train_data最大日期: 2023-03-27\n",
"test_data最大日期: 2023-03-28\n",
"划分后的训练集大小: 615, 验证集大小: 98\n",
"train_data最大日期: 2023-03-28\n",
"test_data最大日期: 2023-03-29\n",
"划分后的训练集大小: 644, 验证集大小: 127\n",
"train_data最大日期: 2023-03-29\n",
"test_data最大日期: 2023-03-30\n",
"划分后的训练集大小: 623, 验证集大小: 127\n",
"train_data最大日期: 2023-03-30\n",
"test_data最大日期: 2023-03-31\n",
"划分后的训练集大小: 577, 验证集大小: 97\n",
"train_data最大日期: 2023-03-31\n",
"test_data最大日期: 2023-04-03\n",
"划分后的训练集大小: 595, 验证集大小: 146\n",
"train_data最大日期: 2023-04-03\n",
"test_data最大日期: 2023-04-04\n",
"划分后的训练集大小: 644, 验证集大小: 147\n",
"train_data最大日期: 2023-04-04\n",
"test_data最大日期: 2023-04-06\n",
"划分后的训练集大小: 632, 验证集大小: 115\n",
"train_data最大日期: 2023-04-06\n",
"test_data最大日期: 2023-04-07\n",
"划分后的训练集大小: 651, 验证集大小: 146\n",
"train_data最大日期: 2023-04-07\n",
"test_data最大日期: 2023-04-10\n",
"划分后的训练集大小: 702, 验证集大小: 148\n",
"train_data最大日期: 2023-04-10\n",
"test_data最大日期: 2023-04-11\n",
"划分后的训练集大小: 701, 验证集大小: 145\n",
"train_data最大日期: 2023-04-11\n",
"test_data最大日期: 2023-04-12\n",
"划分后的训练集大小: 672, 验证集大小: 118\n",
"train_data最大日期: 2023-04-12\n",
"test_data最大日期: 2023-04-13\n",
"划分后的训练集大小: 694, 验证集大小: 137\n",
"train_data最大日期: 2023-04-13\n",
"test_data最大日期: 2023-04-14\n",
"划分后的训练集大小: 695, 验证集大小: 147\n",
"train_data最大日期: 2023-04-14\n",
"test_data最大日期: 2023-04-17\n",
"划分后的训练集大小: 684, 验证集大小: 137\n",
"train_data最大日期: 2023-04-17\n",
"test_data最大日期: 2023-04-18\n",
"划分后的训练集大小: 638, 验证集大小: 99\n",
"train_data最大日期: 2023-04-18\n",
"test_data最大日期: 2023-04-19\n",
"划分后的训练集大小: 649, 验证集大小: 129\n",
"train_data最大日期: 2023-04-19\n",
"test_data最大日期: 2023-04-20\n",
"划分后的训练集大小: 610, 验证集大小: 98\n",
"train_data最大日期: 2023-04-20\n",
"test_data最大日期: 2023-04-21\n",
"划分后的训练集大小: 611, 验证集大小: 148\n",
"train_data最大日期: 2023-04-21\n",
"test_data最大日期: 2023-04-24\n",
"划分后的训练集大小: 610, 验证集大小: 136\n",
"train_data最大日期: 2023-04-24\n",
"test_data最大日期: 2023-04-25\n",
"划分后的训练集大小: 657, 验证集大小: 146\n",
"train_data最大日期: 2023-04-25\n",
"test_data最大日期: 2023-04-26\n",
"划分后的训练集大小: 675, 验证集大小: 147\n",
"train_data最大日期: 2023-04-26\n",
"test_data最大日期: 2023-04-27\n",
"划分后的训练集大小: 677, 验证集大小: 100\n",
"train_data最大日期: 2023-04-27\n",
"test_data最大日期: 2023-04-28\n",
"划分后的训练集大小: 653, 验证集大小: 124\n",
"train_data最大日期: 2023-04-28\n",
"test_data最大日期: 2023-05-04\n",
"划分后的训练集大小: 664, 验证集大小: 147\n",
"train_data最大日期: 2023-05-04\n",
"test_data最大日期: 2023-05-05\n",
"划分后的训练集大小: 636, 验证集大小: 118\n",
"train_data最大日期: 2023-05-05\n",
"test_data最大日期: 2023-05-08\n",
"划分后的训练集大小: 637, 验证集大小: 148\n",
"train_data最大日期: 2023-05-08\n",
"test_data最大日期: 2023-05-09\n",
"划分后的训练集大小: 685, 验证集大小: 148\n",
"train_data最大日期: 2023-05-09\n",
"test_data最大日期: 2023-05-10\n",
"划分后的训练集大小: 658, 验证集大小: 97\n",
"train_data最大日期: 2023-05-10\n",
"test_data最大日期: 2023-05-11\n",
"划分后的训练集大小: 638, 验证集大小: 127\n",
"train_data最大日期: 2023-05-11\n",
"test_data最大日期: 2023-05-12\n",
"划分后的训练集大小: 666, 验证集大小: 146\n",
"train_data最大日期: 2023-05-12\n",
"test_data最大日期: 2023-05-15\n",
"划分后的训练集大小: 664, 验证集大小: 146\n",
"train_data最大日期: 2023-05-15\n",
"test_data最大日期: 2023-05-16\n",
"划分后的训练集大小: 621, 验证集大小: 105\n",
"train_data最大日期: 2023-05-16\n",
"test_data最大日期: 2023-05-17\n",
"划分后的训练集大小: 623, 验证集大小: 99\n",
"train_data最大日期: 2023-05-17\n",
"test_data最大日期: 2023-05-18\n",
"划分后的训练集大小: 606, 验证集大小: 110\n",
"train_data最大日期: 2023-05-18\n",
"test_data最大日期: 2023-05-19\n",
"划分后的训练集大小: 578, 验证集大小: 118\n",
"train_data最大日期: 2023-05-19\n",
"test_data最大日期: 2023-05-22\n",
"划分后的训练集大小: 540, 验证集大小: 108\n",
"train_data最大日期: 2023-05-22\n",
"test_data最大日期: 2023-05-23\n",
"划分后的训练集大小: 532, 验证集大小: 97\n",
"train_data最大日期: 2023-05-23\n",
"test_data最大日期: 2023-05-24\n",
"划分后的训练集大小: 559, 验证集大小: 126\n",
"train_data最大日期: 2023-05-24\n",
"test_data最大日期: 2023-05-25\n",
"划分后的训练集大小: 548, 验证集大小: 99\n",
"train_data最大日期: 2023-05-25\n",
"test_data最大日期: 2023-05-26\n",
"划分后的训练集大小: 526, 验证集大小: 96\n",
"train_data最大日期: 2023-05-26\n",
"test_data最大日期: 2023-05-29\n",
"划分后的训练集大小: 516, 验证集大小: 98\n",
"train_data最大日期: 2023-05-29\n",
"test_data最大日期: 2023-05-30\n",
"划分后的训练集大小: 527, 验证集大小: 108\n",
"train_data最大日期: 2023-05-30\n",
"test_data最大日期: 2023-05-31\n",
"划分后的训练集大小: 546, 验证集大小: 145\n",
"train_data最大日期: 2023-05-31\n",
"test_data最大日期: 2023-06-01\n",
"划分后的训练集大小: 594, 验证集大小: 147\n",
"train_data最大日期: 2023-06-01\n",
"test_data最大日期: 2023-06-02\n",
"划分后的训练集大小: 616, 验证集大小: 118\n",
"train_data最大日期: 2023-06-02\n",
"test_data最大日期: 2023-06-05\n",
"划分后的训练集大小: 666, 验证集大小: 148\n",
"train_data最大日期: 2023-06-05\n",
"test_data最大日期: 2023-06-06\n",
"划分后的训练集大小: 676, 验证集大小: 118\n",
"train_data最大日期: 2023-06-06\n",
"test_data最大日期: 2023-06-07\n",
"划分后的训练集大小: 626, 验证集大小: 95\n",
"train_data最大日期: 2023-06-07\n",
"test_data最大日期: 2023-06-08\n",
"划分后的训练集大小: 626, 验证集大小: 147\n",
"train_data最大日期: 2023-06-08\n",
"test_data最大日期: 2023-06-09\n",
"划分后的训练集大小: 606, 验证集大小: 98\n",
"train_data最大日期: 2023-06-09\n",
"test_data最大日期: 2023-06-12\n",
"划分后的训练集大小: 558, 验证集大小: 100\n",
"train_data最大日期: 2023-06-12\n",
"test_data最大日期: 2023-06-13\n",
"划分后的训练集大小: 579, 验证集大小: 139\n",
"train_data最大日期: 2023-06-13\n",
"test_data最大日期: 2023-06-14\n",
"划分后的训练集大小: 630, 验证集大小: 146\n"
]
},
{
"ename": "LightGBMError",
"evalue": "Forced splits file includes feature index 0, but maximum feature index in dataset is -1",
"output_type": "error",
"traceback": [
"\u001B[1;31m---------------------------------------------------------------------------\u001B[0m",
"\u001B[1;31mLightGBMError\u001B[0m Traceback (most recent call last)",
"Cell \u001B[1;32mIn[36], line 38\u001B[0m\n\u001B[0;32m 34\u001B[0m final_predictions\u001B[38;5;241m.\u001B[39mto_csv(\u001B[38;5;124m'\u001B[39m\u001B[38;5;124mpredictions_test.tsv\u001B[39m\u001B[38;5;124m'\u001B[39m, index\u001B[38;5;241m=\u001B[39m\u001B[38;5;28;01mFalse\u001B[39;00m)\n\u001B[0;32m 36\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m final_predictions\n\u001B[1;32m---> 38\u001B[0m final_predictions1 \u001B[38;5;241m=\u001B[39m train(pdf1, feature_columns1, filter_index1)\n\u001B[0;32m 39\u001B[0m final_predictions2 \u001B[38;5;241m=\u001B[39m train(pdf2, feature_columns2, filter_index2)\n",
"Cell \u001B[1;32mIn[36], line 31\u001B[0m, in \u001B[0;36mtrain\u001B[1;34m(pdf, feature_columns, filter_index)\u001B[0m\n\u001B[0;32m 4\u001B[0m light_params \u001B[38;5;241m=\u001B[39m {\n\u001B[0;32m 5\u001B[0m \u001B[38;5;124m'\u001B[39m\u001B[38;5;124mlabel_gain\u001B[39m\u001B[38;5;124m'\u001B[39m: label_gain,\n\u001B[0;32m 6\u001B[0m \u001B[38;5;124m'\u001B[39m\u001B[38;5;124mobjective\u001B[39m\u001B[38;5;124m'\u001B[39m: \u001B[38;5;124m'\u001B[39m\u001B[38;5;124mlambdarank\u001B[39m\u001B[38;5;124m'\u001B[39m,\n\u001B[1;32m (...)\u001B[0m\n\u001B[0;32m 26\u001B[0m \u001B[38;5;124m'\u001B[39m\u001B[38;5;124mseed\u001B[39m\u001B[38;5;124m'\u001B[39m: \u001B[38;5;241m7\u001B[39m\n\u001B[0;32m 27\u001B[0m }\n\u001B[0;32m 29\u001B[0m gc\u001B[38;5;241m.\u001B[39mcollect()\n\u001B[1;32m---> 31\u001B[0m final_predictions \u001B[38;5;241m=\u001B[39m rolling_train_predict(\n\u001B[0;32m 32\u001B[0m pdf[(pdf[\u001B[38;5;124m'\u001B[39m\u001B[38;5;124mtrade_date\u001B[39m\u001B[38;5;124m'\u001B[39m] \u001B[38;5;241m>\u001B[39m\u001B[38;5;241m=\u001B[39m \u001B[38;5;124m'\u001B[39m\u001B[38;5;124m2022-12-01\u001B[39m\u001B[38;5;124m'\u001B[39m) \u001B[38;5;241m&\u001B[39m (pdf[\u001B[38;5;124m'\u001B[39m\u001B[38;5;124mtrade_date\u001B[39m\u001B[38;5;124m'\u001B[39m] \u001B[38;5;241m<\u001B[39m\u001B[38;5;241m=\u001B[39m \u001B[38;5;124m'\u001B[39m\u001B[38;5;124m2029-03-26\u001B[39m\u001B[38;5;124m'\u001B[39m)], \u001B[38;5;241m5\u001B[39m, \u001B[38;5;241m1\u001B[39m, feature_columns,\n\u001B[0;32m 33\u001B[0m days\u001B[38;5;241m=\u001B[39m\u001B[38;5;241m0\u001B[39m, validation_days\u001B[38;5;241m=\u001B[39m\u001B[38;5;241m0\u001B[39m, filter_index\u001B[38;5;241m=\u001B[39mfilter_index, params\u001B[38;5;241m=\u001B[39mlight_params)\n\u001B[0;32m 34\u001B[0m final_predictions\u001B[38;5;241m.\u001B[39mto_csv(\u001B[38;5;124m'\u001B[39m\u001B[38;5;124mpredictions_test.tsv\u001B[39m\u001B[38;5;124m'\u001B[39m, index\u001B[38;5;241m=\u001B[39m\u001B[38;5;28;01mFalse\u001B[39;00m)\n\u001B[0;32m 36\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m final_predictions\n",
"Cell \u001B[1;32mIn[33], line 154\u001B[0m, in \u001B[0;36mrolling_train_predict\u001B[1;34m(df, train_days, test_days, feature_columns_origin, days, use_pca, validation_days, filter_index, params)\u001B[0m\n\u001B[0;32m 146\u001B[0m \u001B[38;5;66;03m# ud = train_data[\"trade_date\"].unique()\u001B[39;00m\n\u001B[0;32m 147\u001B[0m \u001B[38;5;66;03m# date_weights = {date: weight for date, weight in zip(ud, np.linspace(1, 2, len(unique_dates)))}\u001B[39;00m\n\u001B[0;32m 148\u001B[0m \u001B[38;5;66;03m# params['weight'] = train_data[\"trade_date\"].map(date_weights).tolist()\u001B[39;00m\n\u001B[1;32m (...)\u001B[0m\n\u001B[0;32m 151\u001B[0m \u001B[38;5;66;03m# feature_contri = [2 if feat.startswith('act_factor') else 1 for feat in feature_columns]\u001B[39;00m\n\u001B[0;32m 152\u001B[0m \u001B[38;5;66;03m# params['feature_contri'] = feature_contri\u001B[39;00m\n\u001B[0;32m 153\u001B[0m evals \u001B[38;5;241m=\u001B[39m {}\n\u001B[1;32m--> 154\u001B[0m model, _, _ \u001B[38;5;241m=\u001B[39m train_light_model(train_data\u001B[38;5;241m.\u001B[39mdropna(subset\u001B[38;5;241m=\u001B[39m[\u001B[38;5;124m'\u001B[39m\u001B[38;5;124mlabel\u001B[39m\u001B[38;5;124m'\u001B[39m]),\n\u001B[0;32m 155\u001B[0m params, feature_columns,\n\u001B[0;32m 156\u001B[0m [lgb\u001B[38;5;241m.\u001B[39mlog_evaluation(period\u001B[38;5;241m=\u001B[39m\u001B[38;5;241m100\u001B[39m),\n\u001B[0;32m 157\u001B[0m lgb\u001B[38;5;241m.\u001B[39mcallback\u001B[38;5;241m.\u001B[39mrecord_evaluation(evals),\n\u001B[0;32m 158\u001B[0m \u001B[38;5;66;03m# lgb.early_stopping(100, first_metric_only=True)\u001B[39;00m\n\u001B[0;32m 159\u001B[0m ], evals,\n\u001B[0;32m 160\u001B[0m num_boost_round\u001B[38;5;241m=\u001B[39m\u001B[38;5;241m100\u001B[39m, validation_days\u001B[38;5;241m=\u001B[39mvalidation_days,\n\u001B[0;32m 161\u001B[0m print_feature_importance\u001B[38;5;241m=\u001B[39m\u001B[38;5;28;01mFalse\u001B[39;00m, use_pca\u001B[38;5;241m=\u001B[39m\u001B[38;5;28;01mFalse\u001B[39;00m)\n\u001B[0;32m 163\u001B[0m score_df \u001B[38;5;241m=\u001B[39m test_data\u001B[38;5;241m.\u001B[39mcopy()\n\u001B[0;32m 164\u001B[0m score_df[\u001B[38;5;124m'\u001B[39m\u001B[38;5;124mscore\u001B[39m\u001B[38;5;124m'\u001B[39m] \u001B[38;5;241m=\u001B[39m model\u001B[38;5;241m.\u001B[39mpredict(score_df[feature_columns])\n",
"Cell \u001B[1;32mIn[33], line 81\u001B[0m, in \u001B[0;36mtrain_light_model\u001B[1;34m(train_data_df, params, feature_columns, callbacks, evals, print_feature_importance, num_boost_round, validation_days, use_pca, split_date)\u001B[0m\n\u001B[0;32m 75\u001B[0m model \u001B[38;5;241m=\u001B[39m lgb\u001B[38;5;241m.\u001B[39mtrain(\n\u001B[0;32m 76\u001B[0m params, train_dataset, num_boost_round\u001B[38;5;241m=\u001B[39mnum_boost_round,\n\u001B[0;32m 77\u001B[0m valid_sets\u001B[38;5;241m=\u001B[39m[train_dataset, val_dataset], valid_names\u001B[38;5;241m=\u001B[39m[\u001B[38;5;124m'\u001B[39m\u001B[38;5;124mtrain\u001B[39m\u001B[38;5;124m'\u001B[39m, \u001B[38;5;124m'\u001B[39m\u001B[38;5;124mvalid\u001B[39m\u001B[38;5;124m'\u001B[39m],\n\u001B[0;32m 78\u001B[0m callbacks\u001B[38;5;241m=\u001B[39mcallbacks\n\u001B[0;32m 79\u001B[0m )\n\u001B[0;32m 80\u001B[0m \u001B[38;5;28;01melse\u001B[39;00m:\n\u001B[1;32m---> 81\u001B[0m model \u001B[38;5;241m=\u001B[39m lgb\u001B[38;5;241m.\u001B[39mtrain(\n\u001B[0;32m 82\u001B[0m params, train_dataset, num_boost_round\u001B[38;5;241m=\u001B[39mnum_boost_round, callbacks\u001B[38;5;241m=\u001B[39mcallbacks\n\u001B[0;32m 83\u001B[0m )\n\u001B[0;32m 85\u001B[0m \u001B[38;5;66;03m# 打印特征重要性(如果需要)\u001B[39;00m\n\u001B[0;32m 86\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m print_feature_importance:\n",
"File \u001B[1;32mE:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\lightgbm\\engine.py:297\u001B[0m, in \u001B[0;36mtrain\u001B[1;34m(params, train_set, num_boost_round, valid_sets, valid_names, feval, init_model, keep_training_booster, callbacks)\u001B[0m\n\u001B[0;32m 295\u001B[0m \u001B[38;5;66;03m# construct booster\u001B[39;00m\n\u001B[0;32m 296\u001B[0m \u001B[38;5;28;01mtry\u001B[39;00m:\n\u001B[1;32m--> 297\u001B[0m booster \u001B[38;5;241m=\u001B[39m Booster(params\u001B[38;5;241m=\u001B[39mparams, train_set\u001B[38;5;241m=\u001B[39mtrain_set)\n\u001B[0;32m 298\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m is_valid_contain_train:\n\u001B[0;32m 299\u001B[0m booster\u001B[38;5;241m.\u001B[39mset_train_data_name(train_data_name)\n",
"File \u001B[1;32mE:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\lightgbm\\basic.py:3660\u001B[0m, in \u001B[0;36mBooster.__init__\u001B[1;34m(self, params, train_set, model_file, model_str)\u001B[0m\n\u001B[0;32m 3658\u001B[0m params\u001B[38;5;241m.\u001B[39mupdate(train_set\u001B[38;5;241m.\u001B[39mget_params())\n\u001B[0;32m 3659\u001B[0m params_str \u001B[38;5;241m=\u001B[39m _param_dict_to_str(params)\n\u001B[1;32m-> 3660\u001B[0m _safe_call(\n\u001B[0;32m 3661\u001B[0m _LIB\u001B[38;5;241m.\u001B[39mLGBM_BoosterCreate(\n\u001B[0;32m 3662\u001B[0m train_set\u001B[38;5;241m.\u001B[39m_handle,\n\u001B[0;32m 3663\u001B[0m _c_str(params_str),\n\u001B[0;32m 3664\u001B[0m ctypes\u001B[38;5;241m.\u001B[39mbyref(\u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_handle),\n\u001B[0;32m 3665\u001B[0m )\n\u001B[0;32m 3666\u001B[0m )\n\u001B[0;32m 3667\u001B[0m \u001B[38;5;66;03m# save reference to data\u001B[39;00m\n\u001B[0;32m 3668\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mtrain_set \u001B[38;5;241m=\u001B[39m train_set\n",
"File \u001B[1;32mE:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\lightgbm\\basic.py:313\u001B[0m, in \u001B[0;36m_safe_call\u001B[1;34m(ret)\u001B[0m\n\u001B[0;32m 305\u001B[0m \u001B[38;5;250m\u001B[39m\u001B[38;5;124;03m\"\"\"Check the return value from C API call.\u001B[39;00m\n\u001B[0;32m 306\u001B[0m \n\u001B[0;32m 307\u001B[0m \u001B[38;5;124;03mParameters\u001B[39;00m\n\u001B[1;32m (...)\u001B[0m\n\u001B[0;32m 310\u001B[0m \u001B[38;5;124;03m The return value from C API calls.\u001B[39;00m\n\u001B[0;32m 311\u001B[0m \u001B[38;5;124;03m\"\"\"\u001B[39;00m\n\u001B[0;32m 312\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m ret \u001B[38;5;241m!=\u001B[39m \u001B[38;5;241m0\u001B[39m:\n\u001B[1;32m--> 313\u001B[0m \u001B[38;5;28;01mraise\u001B[39;00m LightGBMError(_LIB\u001B[38;5;241m.\u001B[39mLGBM_GetLastError()\u001B[38;5;241m.\u001B[39mdecode(\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mutf-8\u001B[39m\u001B[38;5;124m\"\u001B[39m))\n",
"\u001B[1;31mLightGBMError\u001B[0m: Forced splits file includes feature index 0, but maximum feature index in dataset is -1"
]
}
],
"source": [
"\n",
"\n",
"def train(pdf, feature_columns, filter_index):\n",
" label_gain = list(range(len(pdf['label'].unique())))\n",
" label_gain = [gain * gain for gain in label_gain]\n",
" light_params = {\n",
" 'label_gain': label_gain,\n",
" 'objective': 'lambdarank',\n",
" 'metric': 'ndcg',\n",
" 'learning_rate': 0.03,\n",
" 'num_leaves': 32,\n",
" # 'min_data_in_leaf': 128,\n",
" 'max_depth': 8,\n",
" 'max_bin': 32,\n",
" 'feature_fraction': 0.7,\n",
" # 'bagging_fraction': 0.7,\n",
" 'bagging_freq': 5,\n",
" 'lambda_l1': 0.1,\n",
" 'lambda_l2': 0.1,\n",
" 'boosting': 'gbdt',\n",
" 'verbosity': -1,\n",
" 'extra_trees': True,\n",
" 'max_position': 5,\n",
" 'ndcg_at': 1,\n",
" 'quant_train_renew_leaf': True,\n",
" 'lambdarank_truncation_level': 3,\n",
" # 'lambdarank_position_bias_regularization': 1,\n",
" 'seed': 7\n",
" }\n",
"\n",
" gc.collect()\n",
"\n",
" final_predictions = rolling_train_predict(\n",
" pdf[(pdf['trade_date'] >= '2022-12-01') & (pdf['trade_date'] <= '2029-03-26')], 5, 1, feature_columns,\n",
" days=0, validation_days=0, filter_index=filter_index, params=light_params)\n",
" final_predictions.to_csv('predictions_test.tsv', index=False)\n",
"\n",
" return final_predictions\n",
"\n",
"final_predictions1 = train(pdf1, feature_columns1, filter_index1)\n",
"final_predictions2 = train(pdf2, feature_columns2, filter_index2)\n"
]
},
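{
"cell_type": "code",
"execution_count": null,
"id": "c9f3e5a7b2d84106",
"metadata": {},
"outputs": [],
"source": [
"# Diagnostic sketch for the LightGBMError above. \"maximum feature index in dataset
"# is -1\" means the training matrix had zero usable columns for one rolling window
"# (for example, remove_shifted_features dropping every feature), while a forced-splits
"# setting (presumably configured elsewhere in this notebook) still references feature 0.
"# assert_trainable is a hypothetical guard, not part of the original pipeline.
"def assert_trainable(train_data, feature_columns):
"    if not feature_columns:
"        raise ValueError('feature_columns is empty for this rolling window')
"    missing = [c for c in feature_columns if c not in train_data.columns]
"    if missing:
"        raise ValueError(f'features missing from train_data: {missing}')
"    if train_data[feature_columns].dropna(how='all').empty:
"        raise ValueError('no usable training rows left after filtering')
"    return True
]
},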
{
"cell_type": "code",
"execution_count": null,
"id": "e7e470a2-e1e5-42e5-a2ee-a5fc80455d95",
"metadata": {},
"outputs": [],
"source": [
"\n",
"slice1 = final_predictions1[final_predictions1['trade_date'] == date_to_compare]\n",
"slice2 = final_predictions2[final_predictions2['trade_date'] == date_to_compare]\n",
"get_diff(slice1, slice2)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.11"
}
},
"nbformat": 4,
"nbformat_minor": 5
}