Files
NewStock/main/train/Classify.ipynb
2025-06-01 15:59:29 +08:00

1246 lines
58 KiB
Plaintext
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
{
"cells": [
{
"cell_type": "code",
"id": "79a7758178bafdd3",
"metadata": {
"jupyter": {
"source_hidden": true
},
"ExecuteTime": {
"end_time": "2025-03-09T09:24:33.048709Z",
"start_time": "2025-03-09T09:24:32.439746Z"
}
},
"source": [
"# %load_ext autoreload\n",
"# %autoreload 2\n",
"\n",
"import pandas as pd\n",
"import warnings\n",
"\n",
"warnings.filterwarnings(\"ignore\")\n",
"\n",
"pd.set_option('display.max_columns', None)\n"
],
"outputs": [],
"execution_count": 1
},
{
"cell_type": "code",
"id": "a79cafb06a7e0e43",
"metadata": {
"ExecuteTime": {
"end_time": "2025-03-09T09:25:28.191722Z",
"start_time": "2025-03-09T09:24:33.069363Z"
}
},
"source": [
"from code.utils.utils import read_and_merge_h5_data\n",
"\n",
"print('daily data')\n",
"df = read_and_merge_h5_data('../../data/daily_data.h5', key='daily_data',\n",
" columns=['ts_code', 'trade_date', 'open', 'close', 'high', 'low', 'vol'],\n",
" df=None)\n",
"\n",
"print('daily basic')\n",
"df = read_and_merge_h5_data('../../data/daily_basic.h5', key='daily_basic',\n",
" columns=['ts_code', 'trade_date', 'turnover_rate', 'pe_ttm', 'circ_mv', 'volume_ratio',\n",
" 'is_st'], df=df, join='inner')\n",
"\n",
"print('stk limit')\n",
"df = read_and_merge_h5_data('../../data/stk_limit.h5', key='stk_limit',\n",
" columns=['ts_code', 'trade_date', 'pre_close', 'up_limit', 'down_limit'],\n",
" df=df)\n",
"print('money flow')\n",
"df = read_and_merge_h5_data('../../data/money_flow.h5', key='money_flow',\n",
" columns=['ts_code', 'trade_date', 'buy_sm_vol', 'sell_sm_vol', 'buy_lg_vol', 'sell_lg_vol',\n",
" 'buy_elg_vol', 'sell_elg_vol', 'net_mf_vol'],\n",
" df=df)\n",
"print(df.info())"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"daily data\n",
"daily basic\n",
"inner merge on ['ts_code', 'trade_date']\n",
"stk limit\n",
"left merge on ['ts_code', 'trade_date']\n",
"money flow\n",
"left merge on ['ts_code', 'trade_date']\n",
"<class 'pandas.core.frame.DataFrame'>\n",
"RangeIndex: 8391351 entries, 0 to 8391350\n",
"Data columns (total 21 columns):\n",
" # Column Dtype \n",
"--- ------ ----- \n",
" 0 ts_code object \n",
" 1 trade_date datetime64[ns]\n",
" 2 open float64 \n",
" 3 close float64 \n",
" 4 high float64 \n",
" 5 low float64 \n",
" 6 vol float64 \n",
" 7 turnover_rate float64 \n",
" 8 pe_ttm float64 \n",
" 9 circ_mv float64 \n",
" 10 volume_ratio float64 \n",
" 11 is_st bool \n",
" 12 up_limit float64 \n",
" 13 down_limit float64 \n",
" 14 buy_sm_vol float64 \n",
" 15 sell_sm_vol float64 \n",
" 16 buy_lg_vol float64 \n",
" 17 sell_lg_vol float64 \n",
" 18 buy_elg_vol float64 \n",
" 19 sell_elg_vol float64 \n",
" 20 net_mf_vol float64 \n",
"dtypes: bool(1), datetime64[ns](1), float64(18), object(1)\n",
"memory usage: 1.3+ GB\n",
"None\n"
]
}
],
"execution_count": 2
},
{
"cell_type": "code",
"id": "38879273d3574ae3",
"metadata": {
"ExecuteTime": {
"end_time": "2025-03-09T09:25:35.747048Z",
"start_time": "2025-03-09T09:25:28.639248Z"
}
},
"source": [
"print('industry')\n",
"df = read_and_merge_h5_data('../../data/industry_data.h5', key='industry_data',\n",
" columns=['ts_code', 'l2_code'],\n",
" df=df, on=['ts_code'], join='left')\n"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"industry\n",
"left merge on ['ts_code']\n"
]
}
],
"execution_count": 3
},
{
"cell_type": "code",
"id": "a4eec8c93f6a7cc3",
"metadata": {
"jupyter": {
"source_hidden": true
},
"ExecuteTime": {
"end_time": "2025-03-09T09:25:35.910003Z",
"start_time": "2025-03-09T09:25:35.778298Z"
}
},
"source": [
"def calculate_indicators(df):\n",
"    \"\"\"\n",
"    Compute index-level indicators: daily return, RSI(14), MACD(12/26/9),\n",
"    and four 20-day sentiment factors (up ratio, volume change rate,\n",
"    volatility, amount change rate).\n",
"\n",
"    Expects columns: trade_date, close, pre_close, vol, amount.\n",
"    Returns the DataFrame (sorted by trade_date) with the columns added.\n",
"    \"\"\"\n",
"    df = df.sort_values('trade_date')\n",
"    # Daily percentage return relative to the previous close\n",
"    df['daily_return'] = (df['close'] - df['pre_close']) / df['pre_close'] * 100\n",
"    # df['5_day_ma'] = df['close'].rolling(window=5).mean()\n",
"    # RSI(14) built from simple moving averages of gains and losses\n",
"    delta = df['close'].diff()\n",
"    gain = delta.where(delta > 0, 0)\n",
"    loss = -delta.where(delta < 0, 0)\n",
"    avg_gain = gain.rolling(window=14).mean()\n",
"    avg_loss = loss.rolling(window=14).mean()\n",
"    rs = avg_gain / avg_loss\n",
"    df['RSI'] = 100 - (100 / (1 + rs))\n",
"\n",
"    # MACD: EMA(12) - EMA(26); signal line is the EMA(9) of MACD\n",
"    ema12 = df['close'].ewm(span=12, adjust=False).mean()\n",
"    ema26 = df['close'].ewm(span=26, adjust=False).mean()\n",
"    df['MACD'] = ema12 - ema26\n",
"    df['Signal_line'] = df['MACD'].ewm(span=9, adjust=False).mean()\n",
"    df['MACD_hist'] = df['MACD'] - df['Signal_line']\n",
"\n",
"    # Sentiment factor 1: fraction of up days over the past 20 days\n",
"    df['up_ratio'] = df['daily_return'].apply(lambda x: 1 if x > 0 else 0)\n",
"    df['up_ratio_20d'] = df['up_ratio'].rolling(window=20).mean()  # share of up days, last 20\n",
"\n",
"    # Sentiment factor 2: volume change rate vs. the 20-day average volume\n",
"    df['volume_mean'] = df['vol'].rolling(window=20).mean()  # 20-day average volume\n",
"    df['volume_change_rate'] = (df['vol'] - df['volume_mean']) / df['volume_mean'] * 100  # pct deviation\n",
"\n",
"    # Sentiment factor 3: volatility = 20-day std of daily returns\n",
"    df['volatility'] = df['daily_return'].rolling(window=20).std()\n",
"\n",
"    # Sentiment factor 4: turnover-amount change rate vs. the 20-day average\n",
"    df['amount_mean'] = df['amount'].rolling(window=20).mean()  # 20-day average amount\n",
"    df['amount_change_rate'] = (df['amount'] - df['amount_mean']) / df['amount_mean'] * 100\n",
"\n",
"    return df\n",
"\n",
"\n",
"def generate_index_indicators(h5_filename):\n",
"    \"\"\"\n",
"    Read index bars from an H5 file (key 'index_data'), compute the\n",
"    calculate_indicators() columns per index code, then pivot to one row\n",
"    per trade_date with columns named '<ts_code>_<indicator>'.\n",
"    \"\"\"\n",
"    df = pd.read_hdf(h5_filename, key='index_data')\n",
"    df['trade_date'] = pd.to_datetime(df['trade_date'], format='%Y%m%d')\n",
"    df = df.sort_values('trade_date')\n",
"\n",
"    # Compute the indicators independently for each index code\n",
"    df_indicators = []\n",
"    for ts_code in df['ts_code'].unique():\n",
"        df_index = df[df['ts_code'] == ts_code].copy()\n",
"        df_index = calculate_indicators(df_index)\n",
"        df_indicators.append(df_index)\n",
"\n",
"    # Stack the per-index results back together\n",
"    df_all_indicators = pd.concat(df_indicators, ignore_index=True)\n",
"\n",
"    # Wide format: one row per trade_date, one column per (indicator, ts_code)\n",
"    df_final = df_all_indicators.pivot_table(\n",
"        index='trade_date',\n",
"        columns='ts_code',\n",
"        values=['daily_return', 'RSI', 'MACD', 'Signal_line',\n",
"                'MACD_hist', 'up_ratio_20d', 'volume_change_rate', 'volatility',\n",
"                'amount_change_rate', 'amount_mean'],\n",
"        aggfunc='last'\n",
"    )\n",
"\n",
"    # Flatten the MultiIndex columns to '<ts_code>_<indicator>'\n",
"    df_final.columns = [f\"{col[1]}_{col[0]}\" for col in df_final.columns]\n",
"    df_final = df_final.reset_index()\n",
"\n",
"    return df_final\n",
"\n",
"\n",
"# Build the index-level feature table; drop warm-up rows that contain NaNs\n",
"h5_filename = '../../data/index_data.h5'\n",
"index_data = generate_index_indicators(h5_filename)\n",
"index_data = index_data.dropna()\n"
],
"outputs": [],
"execution_count": 4
},
{
"cell_type": "code",
"id": "c4e9e1d31da6dba6",
"metadata": {
"jupyter": {
"source_hidden": true
},
"ExecuteTime": {
"end_time": "2025-03-09T09:25:35.982553Z",
"start_time": "2025-03-09T09:25:35.933623Z"
}
},
"source": [
"import numpy as np\n",
"import talib\n",
"\n",
"\n",
"def get_technical_factor(df):\n",
"    \"\"\"\n",
"    Per-stock technical factors: candle shadow ratios, ATR(6/14), OBV and its\n",
"    6-day SMA gap, RSI(3/6/9), horizon returns, and rolling return-volatility\n",
"    levels and ratios. Expects ts_code/trade_date/open/high/low/close/vol.\n",
"    \"\"\"\n",
"    # Sort so every per-stock window sees rows in date order\n",
"    df = df.sort_values(by=['ts_code', 'trade_date'])\n",
"    grouped = df.groupby('ts_code', group_keys=False)\n",
"\n",
"    df['log_close'] = np.log(df['close'])\n",
"\n",
"    # Volume spike: today's volume above 2x the stock's own 20-day average.\n",
"    # Fix: the rolling mean must be computed per ts_code; the previous\n",
"    # frame-wide rolling window leaked the prior stock's last 19 volumes into\n",
"    # each new stock's window (the frame is sorted by ts_code first).\n",
"    df['cat_vol_spike'] = df['vol'] > 2 * grouped['vol'].transform(lambda x: x.rolling(20).mean())\n",
"\n",
"    # Upper/lower shadow lengths relative to close\n",
"    df['up'] = (df['high'] - df[['close', 'open']].max(axis=1)) / df['close']\n",
"    df['down'] = (df[['close', 'open']].min(axis=1) - df['low']) / df['close']\n",
"\n",
"    # Average True Range over 14 and 6 days\n",
"    df['atr_14'] = grouped.apply(\n",
"        lambda x: pd.Series(talib.ATR(x['high'].values, x['low'].values, x['close'].values, timeperiod=14),\n",
"                            index=x.index)\n",
"    )\n",
"    df['atr_6'] = grouped.apply(\n",
"        lambda x: pd.Series(talib.ATR(x['high'].values, x['low'].values, x['close'].values, timeperiod=6),\n",
"                            index=x.index)\n",
"    )\n",
"\n",
"    # On-balance volume and its gap to the 6-day SMA\n",
"    df['obv'] = grouped.apply(\n",
"        lambda x: pd.Series(talib.OBV(x['close'].values, x['vol'].values), index=x.index)\n",
"    )\n",
"    df['maobv_6'] = grouped.apply(\n",
"        lambda x: pd.Series(talib.SMA(x['obv'].values, timeperiod=6), index=x.index)\n",
"    )\n",
"    df['obv-maobv_6'] = df['obv'] - df['maobv_6']\n",
"\n",
"    # RSI at several short horizons\n",
"    df['rsi_3'] = grouped.apply(\n",
"        lambda x: pd.Series(talib.RSI(x['close'].values, timeperiod=3), index=x.index)\n",
"    )\n",
"    df['rsi_6'] = grouped.apply(\n",
"        lambda x: pd.Series(talib.RSI(x['close'].values, timeperiod=6), index=x.index)\n",
"    )\n",
"    df['rsi_9'] = grouped.apply(\n",
"        lambda x: pd.Series(talib.RSI(x['close'].values, timeperiod=9), index=x.index)\n",
"    )\n",
"\n",
"    # Simple 5/10/20-day returns\n",
"    df['return_5'] = grouped['close'].apply(lambda x: x / x.shift(5) - 1)\n",
"    df['return_10'] = grouped['close'].apply(lambda x: x / x.shift(10) - 1)\n",
"    df['return_20'] = grouped['close'].apply(lambda x: x / x.shift(20) - 1)\n",
"\n",
"    # 5-day mean close relative to today's close\n",
"    df['avg_close_5'] = grouped['close'].apply(lambda x: x.rolling(window=5).mean() / x)\n",
"\n",
"    # Rolling std of daily returns at several horizons ('_2' is lagged 10 days)\n",
"    df['std_return_5'] = grouped['close'].apply(lambda x: x.pct_change().rolling(window=5).std())\n",
"    df['std_return_15'] = grouped['close'].apply(lambda x: x.pct_change().rolling(window=15).std())\n",
"    df['std_return_25'] = grouped['close'].apply(lambda x: x.pct_change().rolling(window=25).std())\n",
"    df['std_return_90'] = grouped['close'].apply(lambda x: x.pct_change().rolling(window=90).std())\n",
"    df['std_return_90_2'] = grouped['close'].apply(lambda x: x.shift(10).pct_change().rolling(window=90).std())\n",
"\n",
"    # Short-vs-long volatility ratios\n",
"    df['std_return_5 / std_return_90'] = df['std_return_5'] / df['std_return_90']\n",
"    df['std_return_5 / std_return_25'] = df['std_return_5'] / df['std_return_25']\n",
"\n",
"    # Change in 90-day volatility vs. its 10-day-lagged value\n",
"    df['std_return_90 - std_return_90_2'] = df['std_return_90'] - df['std_return_90_2']\n",
"\n",
"    return df\n",
"\n",
"\n",
"def get_act_factor(df, cat=True):\n",
"    \"\"\"\n",
"    EMA-slope ('activity') factors per stock: arctan of the one-day EMA change\n",
"    for spans 5/13/20/60, their sum and normalized difference, and daily\n",
"    cross-sectional percentile ranks.\n",
"\n",
"    When `cat` is True, also adds boolean ordering flags between the factors.\n",
"    \"\"\"\n",
"    # Sort so per-stock EMAs see rows in date order\n",
"    df = df.sort_values(by=['ts_code', 'trade_date'])\n",
"    grouped = df.groupby('ts_code', group_keys=False)\n",
"    # EMAs of close at several spans\n",
"    df['ema_5'] = grouped['close'].apply(\n",
"        lambda x: pd.Series(talib.EMA(x.values, timeperiod=5), index=x.index)\n",
"    )\n",
"    df['ema_13'] = grouped['close'].apply(\n",
"        lambda x: pd.Series(talib.EMA(x.values, timeperiod=13), index=x.index)\n",
"    )\n",
"    df['ema_20'] = grouped['close'].apply(\n",
"        lambda x: pd.Series(talib.EMA(x.values, timeperiod=20), index=x.index)\n",
"    )\n",
"    df['ema_60'] = grouped['close'].apply(\n",
"        lambda x: pd.Series(talib.EMA(x.values, timeperiod=60), index=x.index)\n",
"    )\n",
"\n",
"    # arctan of the EMA pct-change (x100), converted to degrees (57.3 ~ 180/pi)\n",
"    # and divided by a per-span constant (50/40/21/10) -- empirical scaling\n",
"    # whose origin is not documented here; TODO confirm the constants\n",
"    df['act_factor1'] = grouped['ema_5'].apply(\n",
"        lambda x: np.arctan((x / x.shift(1) - 1) * 100) * 57.3 / 50\n",
"    )\n",
"    df['act_factor2'] = grouped['ema_13'].apply(\n",
"        lambda x: np.arctan((x / x.shift(1) - 1) * 100) * 57.3 / 40\n",
"    )\n",
"    df['act_factor3'] = grouped['ema_20'].apply(\n",
"        lambda x: np.arctan((x / x.shift(1) - 1) * 100) * 57.3 / 21\n",
"    )\n",
"    df['act_factor4'] = grouped['ema_60'].apply(\n",
"        lambda x: np.arctan((x / x.shift(1) - 1) * 100) * 57.3 / 10\n",
"    )\n",
"\n",
"    if cat:\n",
"        # Sign of the fast factor and ordering between adjacent-span factors\n",
"        df['cat_af1'] = df['act_factor1'] > 0\n",
"        df['cat_af2'] = df['act_factor2'] > df['act_factor1']\n",
"        df['cat_af3'] = df['act_factor3'] > df['act_factor2']\n",
"        df['cat_af4'] = df['act_factor4'] > df['act_factor3']\n",
"\n",
"    # Sum of all four, and the fast-minus-mid difference normalized by magnitude\n",
"    df['act_factor5'] = df['act_factor1'] + df['act_factor2'] + df['act_factor3'] + df['act_factor4']\n",
"    df['act_factor6'] = (df['act_factor1'] - df['act_factor2']) / np.sqrt(\n",
"        df['act_factor1'] ** 2 + df['act_factor2'] ** 2)\n",
"\n",
"    # Daily cross-sectional percentile ranks (descending: 1 = weakest)\n",
"    df['rank_act_factor1'] = df.groupby('trade_date', group_keys=False)['act_factor1'].rank(ascending=False, pct=True)\n",
"    df['rank_act_factor2'] = df.groupby('trade_date', group_keys=False)['act_factor2'].rank(ascending=False, pct=True)\n",
"    df['rank_act_factor3'] = df.groupby('trade_date', group_keys=False)['act_factor3'].rank(ascending=False, pct=True)\n",
"\n",
"    return df\n",
"\n",
"\n",
"def get_money_flow_factor(df):\n",
"    \"\"\"\n",
"    Money-flow factors from tushare moneyflow fields: buy volume by order\n",
"    size and large/extra-large net buys, each scaled by net money flow,\n",
"    plus the log of float market cap.\n",
"    \"\"\"\n",
"    net_flow = df['net_mf_vol']\n",
"    # Buy volume by order size, as a multiple of net money flow\n",
"    df['active_buy_volume_large'] = df['buy_lg_vol'] / net_flow\n",
"    df['active_buy_volume_big'] = df['buy_elg_vol'] / net_flow\n",
"    df['active_buy_volume_small'] = df['buy_sm_vol'] / net_flow\n",
"\n",
"    # Net (buy - sell) volume for large and extra-large orders, same scaling\n",
"    for size in ('lg', 'elg'):\n",
"        df[f'buy_{size}_vol_minus_sell_{size}_vol'] = (df[f'buy_{size}_vol'] - df[f'sell_{size}_vol']) / net_flow\n",
"\n",
"    # Log float market cap (size proxy)\n",
"    df['log(circ_mv)'] = np.log(df['circ_mv'])\n",
"    return df\n",
"\n",
"\n",
"def get_alpha_factor(df):\n",
"    \"\"\"\n",
"    A few numbered price-volume alpha factors per stock; alpha_007 and\n",
"    alpha_013 are converted to daily cross-sectional percentile ranks.\n",
"    \"\"\"\n",
"    df = df.sort_values(by=['ts_code', 'trade_date'])\n",
"    grouped = df.groupby('ts_code')\n",
"\n",
"    # alpha_022: close minus the close five days earlier\n",
"    df['alpha_022'] = grouped['close'].transform(lambda x: x - x.shift(5))\n",
"\n",
"    # alpha_003: candle body over daily range; 0 when high == low (no div-by-zero)\n",
"    df['alpha_003'] = np.where(df['high'] != df['low'],\n",
"                               (df['close'] - df['open']) / (df['high'] - df['low']),\n",
"                               0)\n",
"\n",
"    # alpha_007: 5-day rolling close/volume correlation, ranked per trade_date\n",
"    df['alpha_007'] = grouped.apply(lambda x: x['close'].rolling(5).corr(x['vol'])).reset_index(level=0, drop=True)\n",
"    df['alpha_007'] = df.groupby('trade_date', group_keys=False)['alpha_007'].rank(ascending=True, pct=True)\n",
"\n",
"    # alpha_013: 5-day close sum minus 20-day close sum, ranked per trade_date\n",
"    df['alpha_013'] = grouped['close'].transform(lambda x: x.rolling(5).sum() - x.rolling(20).sum())\n",
"    df['alpha_013'] = df.groupby('trade_date', group_keys=False)['alpha_013'].rank(ascending=True, pct=True)\n",
"\n",
"    return df\n",
"\n",
"\n",
"def get_limit_factor(df):\n",
"    \"\"\"\n",
"    Price-limit factors per stock: limit-hit flags, 10-day limit-hit counts,\n",
"    and the current consecutive up-limit streak length.\n",
"    \"\"\"\n",
"    # Sort so per-stock windows see rows in date order\n",
"    df = df.sort_values(by=['ts_code', 'trade_date'])\n",
"\n",
"    grouped = df.groupby('ts_code', group_keys=False)\n",
"\n",
"    # 1. Did today's close hit the limit price? (1 = yes, 0 = no)\n",
"    df['cat_up_limit'] = (df['close'] == df['up_limit']).astype(int)\n",
"    df['cat_down_limit'] = (df['close'] == df['down_limit']).astype(int)\n",
"\n",
"    # 2. Number of limit hits over the past 10 trading days\n",
"    df['up_limit_count_10d'] = grouped['cat_up_limit'].rolling(window=10, min_periods=1).sum().reset_index(level=0,\n",
"                                                                                                           drop=True)\n",
"    df['down_limit_count_10d'] = grouped['cat_down_limit'].rolling(window=10, min_periods=1).sum().reset_index(level=0,\n",
"                                                                                                               drop=True)\n",
"\n",
"    # 3. Length of the current consecutive-limit streak. The original helper\n",
"    # computed the identical expression twice and returned a redundant pair;\n",
"    # a single run-length computation gives the same values.\n",
"    def run_length(flags):\n",
"        \"\"\"Length of the run of equal values ending at each row, zeroed where flags == 0.\"\"\"\n",
"        run_id = (flags != flags.shift()).cumsum()\n",
"        return flags * (flags.groupby(run_id).cumcount() + 1)\n",
"\n",
"    # Consecutive up-limit days\n",
"    df['consecutive_up_limit'] = grouped['cat_up_limit'].apply(run_length).reset_index(level=0, drop=True)\n",
"    # A down-limit streak could be added symmetrically via run_length(cat_down_limit)\n",
"\n",
"    return df"
],
"outputs": [],
"execution_count": 5
},
{
"cell_type": "code",
"id": "53f86ddc0677a6d7",
"metadata": {
"scrolled": true,
"ExecuteTime": {
"end_time": "2025-03-09T09:25:46.591961Z",
"start_time": "2025-03-09T09:25:35.982553Z"
}
},
"source": [
"def read_industry_data(h5_filename):\n",
"    \"\"\"\n",
"    Load industry daily bars (H5 key 'sw_daily') and build industry-level\n",
"    factors: OBV, 5/20-day returns, the act factors, cross-sectional mean\n",
"    deviations, and return percentiles. Output columns are prefixed\n",
"    'industry_' and keyed by ('cat_l2_code', 'trade_date').\n",
"    \"\"\"\n",
"    industry_data = pd.read_hdf(h5_filename, key='sw_daily', columns=[\n",
"        'ts_code', 'trade_date', 'open', 'close', 'high', 'low', 'pe', 'pb', 'vol'\n",
"    ])\n",
"    industry_data = industry_data.sort_values(by=['ts_code', 'trade_date'])\n",
"    # NOTE(review): reindex() with no arguments is effectively a no-op copy;\n",
"    # reset_index(drop=True) was probably intended -- confirm before changing.\n",
"    industry_data = industry_data.reindex()\n",
"    industry_data['trade_date'] = pd.to_datetime(industry_data['trade_date'], format='%Y%m%d')\n",
"\n",
"    grouped = industry_data.groupby('ts_code', group_keys=False)\n",
"    # On-balance volume and horizon returns per industry index\n",
"    industry_data['obv'] = grouped.apply(\n",
"        lambda x: pd.Series(talib.OBV(x['close'].values, x['vol'].values), index=x.index)\n",
"    )\n",
"    industry_data['return_5'] = grouped['close'].apply(lambda x: x / x.shift(5) - 1)\n",
"    industry_data['return_20'] = grouped['close'].apply(lambda x: x / x.shift(20) - 1)\n",
"\n",
"    industry_data = get_act_factor(industry_data, cat=False)\n",
"    industry_data = industry_data.sort_values(by=['trade_date', 'ts_code'])\n",
"\n",
"    # Deviation of each industry's factor from the daily cross-sectional mean\n",
"    # (the code uses the mean, not the median)\n",
"    factor_columns = ['obv', 'return_5', 'return_20', 'act_factor1', 'act_factor2', 'act_factor3', 'act_factor4']\n",
"\n",
"    for factor in factor_columns:\n",
"        if factor in industry_data.columns:\n",
"            # Per-day deviation of this industry's value from the daily mean\n",
"            industry_data[f'{factor}_deviation'] = industry_data.groupby('trade_date')[factor].transform(\n",
"                lambda x: x - x.mean())\n",
"\n",
"    # Daily percentile rank of the 5/20-day returns across industries\n",
"    industry_data['return_5_percentile'] = industry_data.groupby('trade_date')['return_5'].transform(\n",
"        lambda x: x.rank(pct=True))\n",
"    industry_data['return_20_percentile'] = industry_data.groupby('trade_date')['return_20'].transform(\n",
"        lambda x: x.rank(pct=True))\n",
"    # Raw OHLC/valuation columns are inputs only; drop them from the output\n",
"    industry_data = industry_data.drop(columns=['open', 'close', 'high', 'low', 'pe', 'pb', 'vol'])\n",
"\n",
"    industry_data = industry_data.rename(\n",
"        columns={col: f'industry_{col}' for col in industry_data.columns if col not in ['ts_code', 'trade_date']})\n",
"\n",
"    # The industry index code joins against the stocks' cat_l2_code column\n",
"    industry_data = industry_data.rename(columns={'ts_code': 'cat_l2_code'})\n",
"    return industry_data\n",
"\n",
"\n",
"industry_df = read_industry_data('../../data/sw_daily.h5')\n"
],
"outputs": [],
"execution_count": 6
},
{
"cell_type": "code",
"id": "dbe2fd8021b9417f",
"metadata": {
"jupyter": {
"source_hidden": true
},
"scrolled": true,
"ExecuteTime": {
"end_time": "2025-03-09T09:25:46.623555Z",
"start_time": "2025-03-09T09:25:46.617044Z"
}
},
"source": [
"# Raw input columns to exclude from the model's feature list later on.\n",
"# turnover_rate/pe_ttm/volume_ratio/l2_code/vol and anything also present in\n",
"# index_data are removed from this exclusion list so they stay usable.\n",
"origin_columns = df.columns.tolist()\n",
"origin_columns = [col for col in origin_columns if col not in ['turnover_rate', 'pe_ttm', 'volume_ratio', 'l2_code', 'vol']]\n",
"origin_columns = [col for col in origin_columns if col not in index_data.columns]\n"
],
"outputs": [],
"execution_count": 7
},
{
"cell_type": "code",
"id": "5f3d9aece75318cd",
"metadata": {
"jupyter": {
"source_hidden": true
},
"ExecuteTime": {
"end_time": "2025-03-09T09:27:43.441075Z",
"start_time": "2025-03-09T09:25:46.654855Z"
}
},
"source": [
"def filter_data(df):\n",
"    \"\"\"\n",
"    Restrict the universe: drop ST names and tickers from the Beijing\n",
"    exchange ('...BJ') and the 30/68/8-prefixed boards; reset the index.\n",
"    \"\"\"\n",
"    # df = df.groupby('trade_date').apply(lambda x: x.nlargest(1000, 'act_factor1'))\n",
"    code = df['ts_code']\n",
"    keep = (\n",
"        ~df['is_st']\n",
"        & ~code.str.endswith('BJ')\n",
"        & ~code.str.startswith(('30', '68', '8'))\n",
"    )\n",
"    return df[keep].reset_index(drop=True)\n",
"\n",
"\n",
"# Feature-engineering pipeline: filter the universe, then layer on each\n",
"# factor family in turn.\n",
"df = filter_data(df)\n",
"df = get_technical_factor(df)\n",
"df = get_act_factor(df)\n",
"df = get_money_flow_factor(df)\n",
"df = get_alpha_factor(df)\n",
"df = get_limit_factor(df)\n",
"# df = df.merge(industry_df, on=['l2_code', 'trade_date'], how='left')\n",
"# The 'cat' prefix marks the industry code as categorical downstream\n",
"df = df.rename(columns={'l2_code': 'cat_l2_code'})\n",
"# df = df.merge(index_data, on='trade_date', how='left')\n",
"\n",
"print(df.info())"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<class 'pandas.core.frame.DataFrame'>\n",
"Index: 5745009 entries, 1964 to 5745008\n",
"Data columns (total 78 columns):\n",
" # Column Dtype \n",
"--- ------ ----- \n",
" 0 ts_code object \n",
" 1 trade_date datetime64[ns]\n",
" 2 open float64 \n",
" 3 close float64 \n",
" 4 high float64 \n",
" 5 low float64 \n",
" 6 vol float64 \n",
" 7 turnover_rate float64 \n",
" 8 pe_ttm float64 \n",
" 9 circ_mv float64 \n",
" 10 volume_ratio float64 \n",
" 11 is_st bool \n",
" 12 up_limit float64 \n",
" 13 down_limit float64 \n",
" 14 buy_sm_vol float64 \n",
" 15 sell_sm_vol float64 \n",
" 16 buy_lg_vol float64 \n",
" 17 sell_lg_vol float64 \n",
" 18 buy_elg_vol float64 \n",
" 19 sell_elg_vol float64 \n",
" 20 net_mf_vol float64 \n",
" 21 cat_l2_code object \n",
" 22 log_close float64 \n",
" 23 cat_vol_spike bool \n",
" 24 up float64 \n",
" 25 down float64 \n",
" 26 atr_14 float64 \n",
" 27 atr_6 float64 \n",
" 28 obv float64 \n",
" 29 maobv_6 float64 \n",
" 30 obv-maobv_6 float64 \n",
" 31 rsi_3 float64 \n",
" 32 rsi_6 float64 \n",
" 33 rsi_9 float64 \n",
" 34 return_5 float64 \n",
" 35 return_10 float64 \n",
" 36 return_20 float64 \n",
" 37 avg_close_5 float64 \n",
" 38 std_return_5 float64 \n",
" 39 std_return_15 float64 \n",
" 40 std_return_25 float64 \n",
" 41 std_return_90 float64 \n",
" 42 std_return_90_2 float64 \n",
" 43 std_return_5 / std_return_90 float64 \n",
" 44 std_return_5 / std_return_25 float64 \n",
" 45 std_return_90 - std_return_90_2 float64 \n",
" 46 ema_5 float64 \n",
" 47 ema_13 float64 \n",
" 48 ema_20 float64 \n",
" 49 ema_60 float64 \n",
" 50 act_factor1 float64 \n",
" 51 act_factor2 float64 \n",
" 52 act_factor3 float64 \n",
" 53 act_factor4 float64 \n",
" 54 cat_af1 bool \n",
" 55 cat_af2 bool \n",
" 56 cat_af3 bool \n",
" 57 cat_af4 bool \n",
" 58 act_factor5 float64 \n",
" 59 act_factor6 float64 \n",
" 60 rank_act_factor1 float64 \n",
" 61 rank_act_factor2 float64 \n",
" 62 rank_act_factor3 float64 \n",
" 63 active_buy_volume_large float64 \n",
" 64 active_buy_volume_big float64 \n",
" 65 active_buy_volume_small float64 \n",
" 66 buy_lg_vol_minus_sell_lg_vol float64 \n",
" 67 buy_elg_vol_minus_sell_elg_vol float64 \n",
" 68 log(circ_mv) float64 \n",
" 69 alpha_022 float64 \n",
" 70 alpha_003 float64 \n",
" 71 alpha_007 float64 \n",
" 72 alpha_013 float64 \n",
" 73 cat_up_limit int32 \n",
" 74 cat_down_limit int32 \n",
" 75 up_limit_count_10d float64 \n",
" 76 down_limit_count_10d float64 \n",
" 77 consecutive_up_limit int64 \n",
"dtypes: bool(6), datetime64[ns](1), float64(66), int32(2), int64(1), object(2)\n",
"memory usage: 3.1+ GB\n",
"None\n"
]
}
],
"execution_count": 8
},
{
"cell_type": "code",
"id": "f4f16d63ad18d1bc",
"metadata": {
"ExecuteTime": {
"end_time": "2025-03-09T09:27:44.357846Z",
"start_time": "2025-03-09T09:27:44.334722Z"
}
},
"source": [
"def create_deviation_within_dates(df, feature_columns):\n",
"    \"\"\"\n",
"    For each numeric feature, add a 'deviation_mean_<feature>' column: the\n",
"    feature minus its cross-sectional mean within each (trade_date, industry)\n",
"    group. Returns the augmented DataFrame and the extended feature list.\n",
"    \"\"\"\n",
"    groupby_col = 'cat_l2_code'  # within each date, group by L2 industry code\n",
"    new_columns = {}\n",
"    ret_feature_columns = feature_columns[:]\n",
"\n",
"    # Numeric features only: skip categorical ('cat') and industry-level columns\n",
"    # num_features = [col for col in feature_columns if 'cat' not in col and 'index' not in col]\n",
"    num_features = [col for col in feature_columns if 'cat' not in col and 'industry' not in col]\n",
"\n",
"    # Compute the per-group mean deviation for every numeric feature\n",
"    for feature in num_features:\n",
"        if feature == 'trade_date':  # no deviation for the date key itself\n",
"            continue\n",
"\n",
"        # grouped_median = df.groupby(['trade_date', groupby_col])[feature].transform('median')\n",
"        # deviation_col_name = f'deviation_median_{feature}'\n",
"        # new_columns[deviation_col_name] = df[feature] - grouped_median\n",
"        # ret_feature_columns.append(deviation_col_name)\n",
"\n",
"        grouped_mean = df.groupby(['trade_date', groupby_col])[feature].transform('mean')\n",
"        deviation_col_name = f'deviation_mean_{feature}'\n",
"        new_columns[deviation_col_name] = df[feature] - grouped_mean\n",
"        ret_feature_columns.append(deviation_col_name)\n",
"\n",
"    # Append all deviation columns in one concat to avoid fragmented inserts\n",
"    df = pd.concat([df, pd.DataFrame(new_columns)], axis=1)\n",
"\n",
"    # for feature in ['obv', 'return_20', 'act_factor1', 'act_factor2', 'act_factor3', 'act_factor4']:\n",
"    #     df[f'deviation_industry_{feature}'] = df[feature] - df[f'industry_{feature}']\n",
"\n",
"    return df, ret_feature_columns\n"
],
"outputs": [],
"execution_count": 9
},
{
"cell_type": "code",
"id": "0ebdfb92-d88b-4b5c-a715-675dab876fc0",
"metadata": {
"ExecuteTime": {
"end_time": "2025-03-09T09:27:44.532047Z",
"start_time": "2025-03-09T09:27:44.519506Z"
}
},
"source": [
"def get_qcuts(series, quantiles):\n",
"    \"\"\"Quantile-bin `series` and return the bin label of its last element.\n",
"\n",
"    Uses positional indexing: label-based q[-1] raises KeyError on the\n",
"    default integer index in modern pandas.\n",
"    \"\"\"\n",
"    q = pd.qcut(series, q=quantiles, labels=False, duplicates='drop')\n",
"    return q.iloc[-1]  # quantile label of the window's last element\n",
"\n",
"\n",
"import pandas as pd\n",
"\n",
"\n",
"def remove_outliers_label_percentile(label: pd.Series, lower_percentile: float = 0.001, upper_percentile: float = 0.999,\n",
"                                     log=True):\n",
"    \"\"\"Keep only the label values inside the given percentile band.\n",
"\n",
"    Raises ValueError unless 0 <= lower_percentile < upper_percentile <= 1.\n",
"    Optionally reports how many observations were discarded.\n",
"    \"\"\"\n",
"    valid_band = 0 <= lower_percentile < upper_percentile <= 1\n",
"    if not valid_band:\n",
"        raise ValueError(\"Percentile values must satisfy 0 <= lower_percentile < upper_percentile <= 1.\")\n",
"\n",
"    # Percentile cut points defining the accepted band\n",
"    lo, hi = label.quantile(lower_percentile), label.quantile(upper_percentile)\n",
"\n",
"    # Inclusive on both ends (Series.between default), matching >= / <=\n",
"    filtered_label = label[label.between(lo, hi)]\n",
"\n",
"    if log:\n",
"        print(f\"Removed {len(label) - len(filtered_label)} outliers.\")\n",
"    return filtered_label\n",
"\n",
"\n",
"def calculate_risk_adjusted_target(df, days=5):\n",
"    \"\"\"\n",
"    Build a risk-adjusted forward-return target per stock over `days` trading\n",
"    days: future return scaled by the rolling std of future returns, with\n",
"    +/-inf mapped to NaN. Returns the 'sharpe_ratio' Series.\n",
"\n",
"    NOTE(review): a Sharpe-style ratio would *divide* by volatility; this\n",
"    multiplies, which down-weights low-volatility names instead. Kept as-is\n",
"    to preserve the existing target definition -- confirm intent.\n",
"    \"\"\"\n",
"    df = df.sort_values(by=['ts_code', 'trade_date'])\n",
"\n",
"    df['future_close'] = df.groupby('ts_code')['close'].shift(-days)\n",
"    df['future_return'] = (df['future_close'] - df['close']) / df['close']\n",
"\n",
"    df['future_volatility'] = df.groupby('ts_code')['future_return'].rolling(days, min_periods=1).std().reset_index(\n",
"        level=0, drop=True)\n",
"    df['sharpe_ratio'] = df['future_return'] * df['future_volatility']\n",
"    # Assign back instead of the chained .replace(..., inplace=True): the\n",
"    # chained form operates on a temporary under pandas Copy-on-Write and\n",
"    # silently leaves the inf values in place.\n",
"    df['sharpe_ratio'] = df['sharpe_ratio'].replace([np.inf, -np.inf], np.nan)\n",
"\n",
"    return df['sharpe_ratio']\n",
"\n",
"\n"
],
"outputs": [],
"execution_count": 10
},
{
"cell_type": "code",
"id": "fbb968383f8cf2c7",
"metadata": {
"ExecuteTime": {
"end_time": "2025-03-09T09:37:55.728757Z",
"start_time": "2025-03-09T09:36:53.140751Z"
}
},
"source": [
"# ---- Label construction and train/test assembly ----\n",
"days = 5\n",
"df = df.sort_values(by=['ts_code', 'trade_date'])\n",
"# df['future_return'] = df.groupby('ts_code', group_keys=False)['close'].apply(lambda x: x.shift(-days) / x - 1)\n",
"# Forward return: enter at the next day's open, exit at the close `days` days ahead\n",
"df['future_return'] = (df.groupby('ts_code')['close'].shift(-days) - df.groupby('ts_code')['open'].shift(-1)) / df.groupby('ts_code')['open'].shift(-1)\n",
"# df = df.sort_values(by=['ts_code', 'trade_date'])\n",
"\n",
"# df['future_return'] = calculate_risk_adjusted_target(df, days=days)\n",
"# df['future_return'] = df.groupby('ts_code', group_keys=False)['future_return'].apply(\n",
"#     lambda x: pd.Series(talib.SMA(x.values, timeperiod=10), index=x.index)\n",
"# )\n",
"# df['future_return'] = remove_outliers_label_percentile(df['future_return'])\n",
"# df['label'] = df['future_return'].transform(\n",
"#     lambda x: pd.qcut(x, q=5, labels=False, duplicates='drop')\n",
"# )\n",
"# df['label'] = (df['label'] == 4)\n",
"# Binary classification target: forward return above 1%\n",
"df['label'] = df['future_return'] > 0.01\n",
"\n",
"\n",
"# df = df.apply(lambda x: x.astype('float32') if x.dtype in ['float64', 'float32'] else x)\n",
"# df = df.sort_values(by=['trade_date', 'ts_code'])\n",
"# NOTE(review): both masks include 2025-01-01 itself, so that boundary date\n",
"# lands in train and test -- confirm the one-day overlap is intended.\n",
"train_data = df[(df['trade_date'] <= '2025-01-01') & (df['trade_date'] >= '2000-01-01')]\n",
"test_data = df[(df['trade_date'] >= '2025-01-01')]\n",
"\n",
"# Daily universe: 3000 smallest by log float market cap, keep only\n",
"# volume-spike names, then the 1000 strongest 20-day performers\n",
"train_data = train_data.groupby('trade_date', group_keys=False).apply(\n",
"    lambda x: x.nsmallest(3000, 'log(circ_mv)')\n",
")\n",
"test_data = test_data.groupby('trade_date', group_keys=False).apply(\n",
"    lambda x: x.nsmallest(3000, 'log(circ_mv)')\n",
")\n",
"train_data = train_data[train_data['cat_vol_spike']]\n",
"test_data = test_data[test_data['cat_vol_spike']]\n",
"train_data = train_data.groupby('trade_date', group_keys=False).apply(\n",
"    lambda x: x.nlargest(1000, 'return_20')\n",
")\n",
"test_data = test_data.groupby('trade_date', group_keys=False).apply(\n",
"    lambda x: x.nlargest(1000, 'return_20')\n",
")\n",
"\n",
"industry_df = industry_df.sort_values(by=['trade_date'])\n",
"index_data = index_data.sort_values(by=['trade_date'])\n",
"\n",
"# Join the industry-level and index-level feature tables onto both splits\n",
"train_data = train_data.merge(industry_df, on=['cat_l2_code', 'trade_date'], how='left')\n",
"train_data = train_data.merge(index_data, on='trade_date', how='left')\n",
"test_data = test_data.merge(industry_df, on=['cat_l2_code', 'trade_date'], how='left')\n",
"test_data = test_data.merge(index_data, on='trade_date', how='left')\n",
"\n",
"train_data, test_data = train_data.replace([np.inf, -np.inf], np.nan), test_data.replace([np.inf, -np.inf], np.nan)\n",
"\n",
"# Feature list: engineered columns only -- drop keys, the label, anything\n",
"# leaking the future, and the raw inputs collected in origin_columns\n",
"feature_columns = [col for col in train_data.columns if col not in ['trade_date',\n",
"                                                                    'ts_code',\n",
"                                                                    'label']]\n",
"feature_columns = [col for col in feature_columns if 'future' not in col]\n",
"feature_columns = [col for col in feature_columns if 'score' not in col]\n",
"feature_columns = [col for col in feature_columns if col not in origin_columns]\n",
"feature_columns = [col for col in feature_columns if not col.startswith('_')]\n",
"\n",
"# Add within-(date, industry) mean-deviation features to both splits\n",
"feature_columns_new = feature_columns[:]\n",
"train_data, feature_columns_new = create_deviation_within_dates(train_data, feature_columns)\n",
"print(f'feature_columns size: {len(feature_columns_new)}')\n",
"test_data, feature_columns_new = create_deviation_within_dates(test_data, feature_columns)\n",
"print(f'feature_columns size: {len(feature_columns_new)}')\n",
"print(feature_columns_new)\n",
"\n",
"# Drop rows missing any feature; train rows must also have a label\n",
"train_data = train_data.dropna(subset=feature_columns_new)\n",
"train_data = train_data.dropna(subset=['label'])\n",
"train_data = train_data.reset_index(drop=True)\n",
"\n",
"# print(test_data.tail())\n",
"test_data = test_data.dropna(subset=feature_columns_new)\n",
"# test_data = test_data.dropna(subset=['label'])\n",
"test_data = test_data.reset_index(drop=True)\n",
"\n",
"print(len(train_data))\n",
"print(f\"最小日期: {train_data['trade_date'].min().strftime('%Y-%m-%d')}\")\n",
"print(f\"最大日期: {train_data['trade_date'].max().strftime('%Y-%m-%d')}\")\n",
"print(len(test_data))\n",
"print(f\"最小日期: {test_data['trade_date'].min().strftime('%Y-%m-%d')}\")\n",
"print(f\"最大日期: {test_data['trade_date'].max().strftime('%Y-%m-%d')}\")\n",
"\n",
"# Columns prefixed 'cat' are treated as categoricals by the model\n",
"cat_columns = [col for col in df.columns if col.startswith('cat')]\n",
"for col in cat_columns:\n",
"    train_data[col] = train_data[col].astype('category')\n",
"    test_data[col] = test_data[col].astype('category')\n",
"\n",
"import gc\n",
"gc.collect()"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"feature_columns size: 199\n",
"feature_columns size: 199\n",
"['vol', 'turnover_rate', 'pe_ttm', 'volume_ratio', 'cat_l2_code', 'log_close', 'cat_vol_spike', 'up', 'down', 'atr_14', 'atr_6', 'obv', 'maobv_6', 'obv-maobv_6', 'rsi_3', 'rsi_6', 'rsi_9', 'return_5', 'return_10', 'return_20', 'avg_close_5', 'std_return_5', 'std_return_15', 'std_return_25', 'std_return_90', 'std_return_90_2', 'std_return_5 / std_return_90', 'std_return_5 / std_return_25', 'std_return_90 - std_return_90_2', 'ema_5', 'ema_13', 'ema_20', 'ema_60', 'act_factor1', 'act_factor2', 'act_factor3', 'act_factor4', 'cat_af1', 'cat_af2', 'cat_af3', 'cat_af4', 'act_factor5', 'act_factor6', 'rank_act_factor1', 'rank_act_factor2', 'rank_act_factor3', 'active_buy_volume_large', 'active_buy_volume_big', 'active_buy_volume_small', 'buy_lg_vol_minus_sell_lg_vol', 'buy_elg_vol_minus_sell_elg_vol', 'log(circ_mv)', 'alpha_022', 'alpha_003', 'alpha_007', 'alpha_013', 'cat_up_limit', 'cat_down_limit', 'up_limit_count_10d', 'down_limit_count_10d', 'consecutive_up_limit', 'industry_obv', 'industry_return_5', 'industry_return_20', 'industry_ema_5', 'industry_ema_13', 'industry_ema_20', 'industry_ema_60', 'industry_act_factor1', 'industry_act_factor2', 'industry_act_factor3', 'industry_act_factor4', 'industry_act_factor5', 'industry_act_factor6', 'industry_rank_act_factor1', 'industry_rank_act_factor2', 'industry_rank_act_factor3', 'industry_obv_deviation', 'industry_return_5_deviation', 'industry_return_20_deviation', 'industry_act_factor1_deviation', 'industry_act_factor2_deviation', 'industry_act_factor3_deviation', 'industry_act_factor4_deviation', 'industry_return_5_percentile', 'industry_return_20_percentile', '000852.SH_MACD', '000905.SH_MACD', '399006.SZ_MACD', '000852.SH_MACD_hist', '000905.SH_MACD_hist', '399006.SZ_MACD_hist', '000852.SH_RSI', '000905.SH_RSI', '399006.SZ_RSI', '000852.SH_Signal_line', '000905.SH_Signal_line', '399006.SZ_Signal_line', '000852.SH_amount_change_rate', '000905.SH_amount_change_rate', '399006.SZ_amount_change_rate', 
'000852.SH_amount_mean', '000905.SH_amount_mean', '399006.SZ_amount_mean', '000852.SH_daily_return', '000905.SH_daily_return', '399006.SZ_daily_return', '000852.SH_up_ratio_20d', '000905.SH_up_ratio_20d', '399006.SZ_up_ratio_20d', '000852.SH_volatility', '000905.SH_volatility', '399006.SZ_volatility', '000852.SH_volume_change_rate', '000905.SH_volume_change_rate', '399006.SZ_volume_change_rate', 'deviation_mean_vol', 'deviation_mean_turnover_rate', 'deviation_mean_pe_ttm', 'deviation_mean_volume_ratio', 'deviation_mean_log_close', 'deviation_mean_up', 'deviation_mean_down', 'deviation_mean_atr_14', 'deviation_mean_atr_6', 'deviation_mean_obv', 'deviation_mean_maobv_6', 'deviation_mean_obv-maobv_6', 'deviation_mean_rsi_3', 'deviation_mean_rsi_6', 'deviation_mean_rsi_9', 'deviation_mean_return_5', 'deviation_mean_return_10', 'deviation_mean_return_20', 'deviation_mean_avg_close_5', 'deviation_mean_std_return_5', 'deviation_mean_std_return_15', 'deviation_mean_std_return_25', 'deviation_mean_std_return_90', 'deviation_mean_std_return_90_2', 'deviation_mean_std_return_5 / std_return_90', 'deviation_mean_std_return_5 / std_return_25', 'deviation_mean_std_return_90 - std_return_90_2', 'deviation_mean_ema_5', 'deviation_mean_ema_13', 'deviation_mean_ema_20', 'deviation_mean_ema_60', 'deviation_mean_act_factor1', 'deviation_mean_act_factor2', 'deviation_mean_act_factor3', 'deviation_mean_act_factor4', 'deviation_mean_act_factor5', 'deviation_mean_act_factor6', 'deviation_mean_rank_act_factor1', 'deviation_mean_rank_act_factor2', 'deviation_mean_rank_act_factor3', 'deviation_mean_active_buy_volume_large', 'deviation_mean_active_buy_volume_big', 'deviation_mean_active_buy_volume_small', 'deviation_mean_buy_lg_vol_minus_sell_lg_vol', 'deviation_mean_buy_elg_vol_minus_sell_elg_vol', 'deviation_mean_log(circ_mv)', 'deviation_mean_alpha_022', 'deviation_mean_alpha_003', 'deviation_mean_alpha_007', 'deviation_mean_alpha_013', 'deviation_mean_up_limit_count_10d', 
'deviation_mean_down_limit_count_10d', 'deviation_mean_consecutive_up_limit', 'deviation_mean_000852.SH_MACD', 'deviation_mean_000905.SH_MACD', 'deviation_mean_399006.SZ_MACD', 'deviation_mean_000852.SH_MACD_hist', 'deviation_mean_000905.SH_MACD_hist', 'deviation_mean_399006.SZ_MACD_hist', 'deviation_mean_000852.SH_RSI', 'deviation_mean_000905.SH_RSI', 'deviation_mean_399006.SZ_RSI', 'deviation_mean_000852.SH_Signal_line', 'deviation_mean_000905.SH_Signal_line', 'deviation_mean_399006.SZ_Signal_line', 'deviation_mean_000852.SH_amount_change_rate', 'deviation_mean_000905.SH_amount_change_rate', 'deviation_mean_399006.SZ_amount_change_rate', 'deviation_mean_000852.SH_amount_mean', 'deviation_mean_000905.SH_amount_mean', 'deviation_mean_399006.SZ_amount_mean', 'deviation_mean_000852.SH_daily_return', 'deviation_mean_000905.SH_daily_return', 'deviation_mean_399006.SZ_daily_return', 'deviation_mean_000852.SH_up_ratio_20d', 'deviation_mean_000905.SH_up_ratio_20d', 'deviation_mean_399006.SZ_up_ratio_20d', 'deviation_mean_000852.SH_volatility', 'deviation_mean_000905.SH_volatility', 'deviation_mean_399006.SZ_volatility', 'deviation_mean_000852.SH_volume_change_rate', 'deviation_mean_000905.SH_volume_change_rate', 'deviation_mean_399006.SZ_volume_change_rate']\n",
"267444\n",
"最小日期: 2017-04-06\n",
"最大日期: 2024-12-31\n",
"4790\n",
"最小日期: 2025-01-02\n",
"最大日期: 2025-03-06\n"
]
},
{
"data": {
"text/plain": [
"0"
]
},
"execution_count": 31,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 31
},
{
"cell_type": "code",
"id": "26c4c873b83ecc38",
"metadata": {
"ExecuteTime": {
"end_time": "2025-03-09T09:37:56.091041Z",
"start_time": "2025-03-09T09:37:56.072555Z"
}
},
"source": [
"def remove_highly_correlated_features(df, feature_columns, threshold=0.8):\n",
"    \"\"\"Return ``feature_columns`` minus numeric features that are highly correlated.\n",
"\n",
"    For every pair of numeric features whose absolute Pearson correlation\n",
"    exceeds ``threshold``, the later column (in correlation-matrix order) is\n",
"    dropped; non-numeric feature columns are always kept.\n",
"    \"\"\"\n",
"    numeric_cols = df[feature_columns].select_dtypes(include=[np.number]).columns.tolist()\n",
"    if not numeric_cols:\n",
"        raise ValueError(\"No numeric features found in the provided data.\")\n",
"\n",
"    abs_corr = df[numeric_cols].corr().abs()\n",
"    # Keep only the strict upper triangle so each pair is examined exactly once.\n",
"    upper_tri = abs_corr.where(np.triu(np.ones(abs_corr.shape), k=1).astype(bool))\n",
"    dropped = {col for col in upper_tri.columns if (upper_tri[col] > threshold).any()}\n",
"    return [col for col in feature_columns if col not in dropped]\n",
"\n",
"# feature_columns_new = remove_highly_correlated_features(train_data, feature_columns_new)\n",
"# keep_columns = [col for col in train_data.columns if col in feature_columns_new or col in ['ts_code', 'trade_date', 'label']]\n",
"# train_data = train_data[keep_columns]\n",
"# test_data = test_data[keep_columns]\n",
"# print(f'feature_columns size: {len(feature_columns_new)}')"
],
"outputs": [],
"execution_count": 32
},
{
"cell_type": "code",
"id": "8f134d435f71e9e2",
"metadata": {
"jupyter": {
"source_hidden": true
},
"ExecuteTime": {
"end_time": "2025-03-09T09:37:56.626336Z",
"start_time": "2025-03-09T09:37:56.606252Z"
}
},
"source": [
"from sklearn.preprocessing import StandardScaler\n",
"from catboost import Pool\n",
"import lightgbm as lgb\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"\n",
"def train_light_model(train_data_df, test_data_df, params, feature_columns, callbacks, evals,\n",
"                      print_feature_importance=True, num_boost_round=100,\n",
"                      use_optuna=False):\n",
"    \"\"\"Train a LightGBM booster on rows that carry a non-null 'label'.\n",
"\n",
"    Columns whose name starts with 'cat' are passed to LightGBM as\n",
"    categorical features.  ``evals`` is only read for plotting; it is\n",
"    expected to be filled by a ``lgb.record_evaluation`` callback supplied\n",
"    via ``callbacks``.  ``use_optuna`` is currently unused.\n",
"    Returns the trained booster.\n",
"    \"\"\"\n",
"    labeled_train = train_data_df.dropna(subset=['label'])\n",
"    labeled_val = test_data_df.dropna(subset=['label'])\n",
"\n",
"    cat_idx = [i for i, col in enumerate(feature_columns) if col.startswith('cat')]\n",
"    print(f'categorical_feature: {cat_idx}')\n",
"\n",
"    dtrain = lgb.Dataset(labeled_train[feature_columns], label=labeled_train['label'],\n",
"                         categorical_feature=cat_idx)\n",
"    dval = lgb.Dataset(labeled_val[feature_columns], label=labeled_val['label'],\n",
"                       categorical_feature=cat_idx)\n",
"\n",
"    model = lgb.train(\n",
"        params, dtrain, num_boost_round=num_boost_round,\n",
"        valid_sets=[dtrain, dval], valid_names=['train', 'valid'],\n",
"        callbacks=callbacks\n",
"    )\n",
"\n",
"    if print_feature_importance:\n",
"        lgb.plot_metric(evals)\n",
"        # lgb.plot_tree(model, figsize=(20, 8))\n",
"        lgb.plot_importance(model, importance_type='split', max_num_features=20)\n",
"        plt.show()\n",
"    return model\n",
"\n",
"\n",
"from catboost import CatBoostClassifier\n",
"import pandas as pd\n",
"\n",
"\n",
"def train_catboost(train_data_df, test_data_df, feature_columns, params=None, plot=False):\n",
"    \"\"\"Fit a CatBoostClassifier on rows with a non-null 'label'.\n",
"\n",
"    Non-categorical numeric feature columns are standardized; the fitted\n",
"    ``StandardScaler`` is returned alongside the model so the identical\n",
"    transform can be applied at inference time.  Columns whose name starts\n",
"    with 'cat' are treated as CatBoost categorical features and are\n",
"    deliberately left unscaled.\n",
"    Returns ``(model, scaler)``.\n",
"    \"\"\"\n",
"    params = dict(params) if params else {}  # tolerate the params=None default\n",
"\n",
"    train_data_df = train_data_df.dropna(subset=['label'])\n",
"    test_data_df = test_data_df.dropna(subset=['label'])\n",
"\n",
"    # Work on copies so the callers' frames are not mutated and the scaled\n",
"    # .loc assignments below cannot hit SettingWithCopy on a view.\n",
"    X_train = train_data_df[feature_columns].copy()\n",
"    y_train = train_data_df['label']\n",
"\n",
"    X_val = test_data_df[feature_columns].copy()\n",
"    y_val = test_data_df['label']\n",
"\n",
"    cat_features = [i for i, col in enumerate(feature_columns) if col.startswith('cat')]\n",
"    print(f'cat_features: {cat_features}')\n",
"\n",
"    # Scale only non-categorical numeric columns: standardizing an int-typed\n",
"    # 'cat*' column would produce floats, which CatBoost rejects for features\n",
"    # listed in cat_features.\n",
"    scaler = StandardScaler()\n",
"    numeric_columns = [c for c in X_train.select_dtypes(include=['float64', 'int64']).columns\n",
"                       if not c.startswith('cat')]\n",
"    X_train.loc[:, numeric_columns] = scaler.fit_transform(X_train[numeric_columns])\n",
"    X_val.loc[:, numeric_columns] = scaler.transform(X_val[numeric_columns])\n",
"\n",
"    train_pool = Pool(data=X_train, label=y_train, cat_features=cat_features)\n",
"    val_pool = Pool(data=X_val, label=y_val, cat_features=cat_features)\n",
"\n",
"    model = CatBoostClassifier(**params)\n",
"    model.fit(train_pool,\n",
"              eval_set=val_pool, plot=plot, use_best_model=True)\n",
"\n",
"    return model, scaler"
],
"outputs": [],
"execution_count": 33
},
{
"cell_type": "code",
"id": "4a4542e1ed6afe7d",
"metadata": {
"ExecuteTime": {
"end_time": "2025-03-09T09:37:56.652596Z",
"start_time": "2025-03-09T09:37:56.644890Z"
}
},
"source": [
"light_params = {\n",
" 'objective': 'binary',\n",
" 'metric': 'average_precision',\n",
" 'learning_rate': 0.05,\n",
" 'is_unbalance': True,\n",
" 'num_leaves': 2048,\n",
" 'min_data_in_leaf': 1024,\n",
" 'max_depth': 32,\n",
" 'max_bin': 1024,\n",
" 'feature_fraction': 0.7,\n",
" 'bagging_fraction': 0.7,\n",
" 'bagging_freq': 5,\n",
" # 'lambda_l1': 80,\n",
" # 'lambda_l2': 65,\n",
" 'verbosity': -1,\n",
" 'num_threads' : 16\n",
"}"
],
"outputs": [],
"execution_count": 34
},
{
"cell_type": "code",
"id": "beeb098799ecfa6a",
"metadata": {
"ExecuteTime": {
"end_time": "2025-03-09T09:37:56.965558Z",
"start_time": "2025-03-09T09:37:56.895916Z"
}
},
"source": [
"print('train data size: ', len(train_data))\n",
" \n",
"evals = {}\n",
"\n",
"gc.collect()\n",
"\n",
"# model, scaler = train_light_model(train_data, test_data, light_params, feature_columns_new,\n",
"# [lgb.log_evaluation(period=500),\n",
"# lgb.callback.record_evaluation(evals),\n",
"# lgb.early_stopping(50, first_metric_only=True)\n",
"# ], evals,\n",
"# num_boost_round=1000, use_optuna=False,\n",
"# print_feature_importance=True)"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"train data size: 267444\n"
]
},
{
"data": {
"text/plain": [
"0"
]
},
"execution_count": 35,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 35
},
{
"cell_type": "code",
"id": "445dff84-70b2-4fc9-a9b6-1251993324d6",
"metadata": {
"ExecuteTime": {
"end_time": "2025-03-09T09:38:20.416159Z",
"start_time": "2025-03-09T09:37:57.335011Z"
}
},
"source": [
"catboost_params = {\n",
" 'loss_function': 'CrossEntropy', # 适用于二分类\n",
" 'eval_metric': 'CrossEntropy', # 评估指标\n",
" 'iterations': 1000,\n",
" 'learning_rate': 0.05,\n",
" 'depth': 10, # 控制模型复杂度\n",
" 'l2_leaf_reg': 3, # L2 正则化\n",
" 'verbose': 500,\n",
" 'early_stopping_rounds': 100,\n",
" # 'one_hot_max_size': 50,\n",
" # 'class_weights': [0.6, 1.2]\n",
" 'task_type': 'GPU'\n",
"}\n",
"feature_weights = {col: 2.0 if 'act_factor' in col or 'af' in col else 1.0 for col in feature_columns_new}\n",
"catboost_params['feature_weights'] = feature_weights\n",
"\n",
"gc.collect()\n",
"\n",
"model, scaler = train_catboost(train_data, test_data, feature_columns_new, catboost_params, plot=True)"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"cat_features: [4, 6, 37, 38, 39, 40, 56, 57]\n"
]
},
{
"data": {
"text/plain": [
"MetricVisualizer(layout=Layout(align_self='stretch', height='500px'))"
],
"application/vnd.jupyter.widget-view+json": {
"version_major": 2,
"version_minor": 0,
"model_id": "53df0a673f9d460ea35611b7b2dc649f"
}
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"0:\tlearn: 0.6773014\ttest: 0.6802322\tbest: 0.6802322 (0)\ttotal: 81.9ms\tremaining: 1m 21s\n",
"bestTest = 0.5604671088\n",
"bestIteration = 82\n",
"Shrink model to first 83 iterations.\n"
]
}
],
"execution_count": 36
},
{
"cell_type": "code",
"id": "5d1522a7538db91b",
"metadata": {
"ExecuteTime": {
"end_time": "2025-03-09T09:38:20.430959Z",
"start_time": "2025-03-09T09:38:20.426274Z"
}
},
"source": [
"# score_df = train_data\n",
"# score_df['score'] = model.predict_proba(score_df[feature_columns_new])[:, -1]\n",
"# # score_df['score'] = model.predict(score_df[feature_columns_new])\n",
"# predictions_train = score_df.loc[score_df.groupby('trade_date')['score'].idxmax()]\n",
"# # predictions_train = predictions_train[predictions_train['score'] > 0.]\n",
"# predictions_train[['trade_date', 'score', 'ts_code']].to_csv('predictions_train.tsv', index=False)"
],
"outputs": [],
"execution_count": 37
},
{
"cell_type": "code",
"id": "a3d7e881-c9b7-48a9-ba7f-20cd28c33f37",
"metadata": {
"ExecuteTime": {
"end_time": "2025-03-09T09:38:20.881701Z",
"start_time": "2025-03-09T09:38:20.658046Z"
}
},
"source": [
"score_df = test_data.copy()\n",
"numeric_columns = score_df[feature_columns_new].select_dtypes(include=['float64', 'int64']).columns\n",
"score_df.loc[:, numeric_columns] = scaler.transform(score_df[numeric_columns])\n",
"score_df['score'] = model.predict_proba(score_df[feature_columns_new])[:, -1]\n",
"# score_df['score'] = model.predict(score_df[feature_columns_new])\n",
"predictions_test = score_df.loc[score_df.groupby('trade_date')['score'].idxmax()]\n",
"# predictions_test = predictions_test[predictions_test['score'] > 0.5]\n",
"predictions_test[['trade_date', 'score', 'ts_code']].to_csv('predictions_test.tsv', index=False)"
],
"outputs": [],
"execution_count": 38
},
{
"cell_type": "code",
"id": "ebae809e26a7b594",
"metadata": {
"ExecuteTime": {
"end_time": "2025-03-09T09:38:21.084940Z",
"start_time": "2025-03-09T09:38:21.067825Z"
}
},
"source": "# Spot-check the daily top pick against its realized label/return\nprint(predictions_test[predictions_test['trade_date'] > '2018-01-01'][['label', 'score', 'future_return', 'ts_code', 'trade_date']].head(10))",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" label score future_return ts_code trade_date\n",
"8 False 0.373536 -0.058994 603881.SH 2025-01-02\n",
"117 True 0.394262 0.067581 002577.SZ 2025-01-03\n",
"232 True 0.450861 0.104442 603308.SH 2025-01-06\n",
"275 True 0.333439 0.052550 000573.SZ 2025-01-07\n",
"409 False 0.367638 -0.021196 600446.SH 2025-01-08\n",
"491 True 0.342494 0.056258 601208.SH 2025-01-09\n",
"562 True 0.376280 0.059358 002437.SZ 2025-01-10\n",
"607 True 0.384604 0.156388 002811.SZ 2025-01-13\n",
"719 False 0.303999 0.008529 603348.SH 2025-01-14\n",
"730 False 0.298599 0.015870 002552.SZ 2025-01-15\n"
]
}
],
"execution_count": 39
},
{
"cell_type": "code",
"id": "751a6df9-d90b-4053-8769-c6c3b6654406",
"metadata": {
"ExecuteTime": {
"end_time": "2025-03-09T09:38:21.371026Z",
"start_time": "2025-03-09T09:38:21.360910Z"
}
},
"source": [
"# df['future_return'] = df.groupby('ts_code', group_keys=False)['close'].apply(lambda x: x.shift(-days) / x - 1)\n",
"# # df['future_return'] = (df.groupby('ts_code')['close'].shift(-days) - df.groupby('ts_code')['open'].shift(-1)) / df.groupby('ts_code')['open'].shift(-1)\n",
"# # df = df.sort_values(by=['ts_code', 'trade_date'])\n",
"#\n",
"# # df['future_return'] = calculate_risk_adjusted_target(df, days=days)\n",
"# # df['future_return'] = df.groupby('ts_code', group_keys=False)['future_return'].apply(\n",
"# # lambda x: pd.Series(talib.SMA(x.values, timeperiod=10), index=x.index)\n",
"# # )\n",
"# df['future_return'] = remove_outliers_label_percentile(df['future_return'], lower_percentile=0.001, upper_percentile=0.999)\n",
"# print(df[(df['ts_code'] == '002095.SZ') & (df['trade_date'] >= '2023-01-03')].head()[['ts_code', 'trade_date', 'close', 'future_return']])\n"
],
"outputs": [],
"execution_count": 40
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.11"
}
},
"nbformat": 4,
"nbformat_minor": 5
}