Files
NewStock/main/train/UpdateSGD.ipynb
2025-04-28 11:02:52 +08:00

1388 lines
59 KiB
Plaintext
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
{
"cells": [
{
"cell_type": "code",
"id": "79a7758178bafdd3",
"metadata": {
"ExecuteTime": {
"end_time": "2025-03-02T12:44:04.673079Z",
"start_time": "2025-03-02T12:44:04.247257Z"
}
},
"source": [
"# %load_ext autoreload\n",
"# %autoreload 2\n",
"\n",
"import pandas as pd\n",
"import warnings\n",
"\n",
"warnings.filterwarnings(\"ignore\")\n",
"\n",
"pd.set_option('display.max_columns', None)\n"
],
"outputs": [],
"execution_count": 1
},
{
"cell_type": "code",
"id": "a79cafb06a7e0e43",
"metadata": {
"jupyter": {
"source_hidden": true
},
"ExecuteTime": {
"end_time": "2025-03-02T12:44:43.595370Z",
"start_time": "2025-03-02T12:44:04.688084Z"
}
},
"source": [
"from code.utils.utils import read_and_merge_h5_data\n",
"\n",
"print('daily data')\n",
"df = read_and_merge_h5_data('../../data/daily_data.h5', key='daily_data',\n",
" columns=['ts_code', 'trade_date', 'open', 'close', 'high', 'low', 'vol'],\n",
" df=None)\n",
"\n",
"print('daily basic')\n",
"df = read_and_merge_h5_data('../../data/daily_basic.h5', key='daily_basic',\n",
" columns=['ts_code', 'trade_date', 'turnover_rate', 'pe_ttm', 'circ_mv', 'volume_ratio',\n",
" 'is_st'], df=df, join='inner')\n",
"\n",
"print('stk limit')\n",
"df = read_and_merge_h5_data('../../data/stk_limit.h5', key='stk_limit',\n",
" columns=['ts_code', 'trade_date', 'pre_close', 'up_limit', 'down_limit'],\n",
" df=df)\n",
"print('money flow')\n",
"df = read_and_merge_h5_data('../../data/money_flow.h5', key='money_flow',\n",
" columns=['ts_code', 'trade_date', 'buy_sm_vol', 'sell_sm_vol', 'buy_lg_vol', 'sell_lg_vol',\n",
" 'buy_elg_vol', 'sell_elg_vol', 'net_mf_vol'],\n",
" df=df)\n",
"print(df.info())"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"daily data\n",
"daily basic\n",
"inner merge on ['ts_code', 'trade_date']\n",
"stk limit\n",
"left merge on ['ts_code', 'trade_date']\n",
"money flow\n",
"left merge on ['ts_code', 'trade_date']\n",
"<class 'pandas.core.frame.DataFrame'>\n",
"RangeIndex: 8369855 entries, 0 to 8369854\n",
"Data columns (total 21 columns):\n",
" # Column Dtype \n",
"--- ------ ----- \n",
" 0 ts_code object \n",
" 1 trade_date datetime64[ns]\n",
" 2 open float64 \n",
" 3 close float64 \n",
" 4 high float64 \n",
" 5 low float64 \n",
" 6 vol float64 \n",
" 7 turnover_rate float64 \n",
" 8 pe_ttm float64 \n",
" 9 circ_mv float64 \n",
" 10 volume_ratio float64 \n",
" 11 is_st bool \n",
" 12 up_limit float64 \n",
" 13 down_limit float64 \n",
" 14 buy_sm_vol float64 \n",
" 15 sell_sm_vol float64 \n",
" 16 buy_lg_vol float64 \n",
" 17 sell_lg_vol float64 \n",
" 18 buy_elg_vol float64 \n",
" 19 sell_elg_vol float64 \n",
" 20 net_mf_vol float64 \n",
"dtypes: bool(1), datetime64[ns](1), float64(18), object(1)\n",
"memory usage: 1.3+ GB\n",
"None\n"
]
}
],
"execution_count": 2
},
{
"cell_type": "code",
"id": "f7a55c19-b7dc-4d2f-a478-cffab11690df",
"metadata": {
"ExecuteTime": {
"end_time": "2025-03-02T12:44:46.323145Z",
"start_time": "2025-03-02T12:44:43.776850Z"
}
},
"source": [
"print('industry')\n",
"df = read_and_merge_h5_data('../../data/industry_data.h5', key='industry_data',\n",
" columns=['ts_code', 'l2_code'],\n",
" df=df, on=['ts_code'], join='left')\n"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"industry\n",
"left merge on ['ts_code']\n"
]
}
],
"execution_count": 3
},
{
"cell_type": "code",
"id": "4077d4449d406c86",
"metadata": {
"ExecuteTime": {
"end_time": "2025-03-02T12:44:46.389069Z",
"start_time": "2025-03-02T12:44:46.332410Z"
}
},
"source": [
"\n",
"\n",
"\n",
"def calculate_indicators(df):\n",
" \"\"\"\n",
" 计算四个指标当日涨跌幅、5日移动平均、RSI、MACD。\n",
" \"\"\"\n",
" df = df.sort_values('trade_date')\n",
" df['daily_return'] = (df['close'] - df['pre_close']) / df['pre_close'] * 100\n",
" # df['5_day_ma'] = df['close'].rolling(window=5).mean()\n",
" delta = df['close'].diff()\n",
" gain = delta.where(delta > 0, 0)\n",
" loss = -delta.where(delta < 0, 0)\n",
" avg_gain = gain.rolling(window=14).mean()\n",
" avg_loss = loss.rolling(window=14).mean()\n",
" rs = avg_gain / avg_loss\n",
" df['RSI'] = 100 - (100 / (1 + rs))\n",
"\n",
" # 计算MACD\n",
" ema12 = df['close'].ewm(span=12, adjust=False).mean()\n",
" ema26 = df['close'].ewm(span=26, adjust=False).mean()\n",
" df['MACD'] = ema12 - ema26\n",
" df['Signal_line'] = df['MACD'].ewm(span=9, adjust=False).mean()\n",
" df['MACD_hist'] = df['MACD'] - df['Signal_line']\n",
"\n",
" return df\n",
"\n",
"\n",
"def generate_index_indicators(h5_filename):\n",
" df = pd.read_hdf(h5_filename, key='index_data')\n",
" df['trade_date'] = pd.to_datetime(df['trade_date'], format='%Y%m%d')\n",
" df = df.sort_values('trade_date')\n",
"\n",
" # 计算每个ts_code的相关指标\n",
" df_indicators = []\n",
" for ts_code in df['ts_code'].unique():\n",
" df_index = df[df['ts_code'] == ts_code].copy()\n",
" df_index = calculate_indicators(df_index)\n",
" df_indicators.append(df_index)\n",
"\n",
" # 合并所有指数的结果\n",
" df_all_indicators = pd.concat(df_indicators, ignore_index=True)\n",
"\n",
" # 保留trade_date列并将同一天的数据按ts_code合并成一行\n",
" df_final = df_all_indicators.pivot_table(\n",
" index='trade_date',\n",
" columns='ts_code',\n",
" values=['daily_return', 'RSI', 'MACD', 'Signal_line', 'MACD_hist'],\n",
" aggfunc='last'\n",
" )\n",
"\n",
" df_final.columns = [f\"{col[1]}_{col[0]}\" for col in df_final.columns]\n",
" df_final = df_final.reset_index()\n",
"\n",
" return df_final\n",
"\n",
"\n",
"# 使用函数\n",
"h5_filename = '../../data/index_data.h5'\n",
"index_data = generate_index_indicators(h5_filename)\n",
"index_data = index_data.dropna()\n"
],
"outputs": [],
"execution_count": 4
},
{
"cell_type": "code",
"id": "c4e9e1d31da6dba6",
"metadata": {
"jupyter": {
"source_hidden": true
},
"ExecuteTime": {
"end_time": "2025-03-02T12:44:46.438183Z",
"start_time": "2025-03-02T12:44:46.409533Z"
}
},
"source": [
"import numpy as np\n",
"import talib\n",
"\n",
"\n",
"def get_technical_factor(df):\n",
" # 按股票和日期排序\n",
" df = df.sort_values(by=['ts_code', 'trade_date'])\n",
" grouped = df.groupby('ts_code', group_keys=False)\n",
"\n",
" # 计算 up 和 down\n",
" df['log_close'] = np.log(df['close'])\n",
"\n",
" df['up'] = (df['high'] - df[['close', 'open']].max(axis=1)) / df['close']\n",
" df['down'] = (df[['close', 'open']].min(axis=1) - df['low']) / df['close']\n",
"\n",
" # 计算 ATR\n",
" df['atr_14'] = grouped.apply(\n",
" lambda x: pd.Series(talib.ATR(x['high'].values, x['low'].values, x['close'].values, timeperiod=14),\n",
" index=x.index)\n",
" )\n",
" df['atr_6'] = grouped.apply(\n",
" lambda x: pd.Series(talib.ATR(x['high'].values, x['low'].values, x['close'].values, timeperiod=6),\n",
" index=x.index)\n",
" )\n",
"\n",
" # 计算 OBV 及其均线\n",
" df['obv'] = grouped.apply(\n",
" lambda x: pd.Series(talib.OBV(x['close'].values, x['vol'].values), index=x.index)\n",
" )\n",
" df['maobv_6'] = grouped.apply(\n",
" lambda x: pd.Series(talib.SMA(x['obv'].values, timeperiod=6), index=x.index)\n",
" )\n",
" df['obv-maobv_6'] = df['obv'] - df['maobv_6']\n",
"\n",
" # 计算 RSI\n",
" df['rsi_3'] = grouped.apply(\n",
" lambda x: pd.Series(talib.RSI(x['close'].values, timeperiod=3), index=x.index)\n",
" )\n",
" df['rsi_6'] = grouped.apply(\n",
" lambda x: pd.Series(talib.RSI(x['close'].values, timeperiod=6), index=x.index)\n",
" )\n",
" df['rsi_9'] = grouped.apply(\n",
" lambda x: pd.Series(talib.RSI(x['close'].values, timeperiod=9), index=x.index)\n",
" )\n",
"\n",
" # 计算 return_10 和 return_20\n",
" df['return_5'] = grouped['close'].apply(lambda x: x / x.shift(5) - 1)\n",
" df['return_10'] = grouped['close'].apply(lambda x: x / x.shift(10) - 1)\n",
" df['return_20'] = grouped['close'].apply(lambda x: x / x.shift(20) - 1)\n",
"\n",
" # 计算 avg_close_5\n",
" df['avg_close_5'] = grouped['close'].apply(lambda x: x.rolling(window=5).mean() / x)\n",
"\n",
" # 计算标准差指标\n",
" df['std_return_5'] = grouped['close'].apply(lambda x: x.pct_change().rolling(window=5).std())\n",
" df['std_return_15'] = grouped['close'].apply(lambda x: x.pct_change().rolling(window=15).std())\n",
" df['std_return_25'] = grouped['close'].apply(lambda x: x.pct_change().rolling(window=25).std())\n",
" df['std_return_90'] = grouped['close'].apply(lambda x: x.pct_change().rolling(window=90).std())\n",
" df['std_return_90_2'] = grouped['close'].apply(lambda x: x.shift(10).pct_change().rolling(window=90).std())\n",
"\n",
" # 计算比值指标\n",
" df['std_return_5 / std_return_90'] = df['std_return_5'] / df['std_return_90']\n",
" df['std_return_5 / std_return_25'] = df['std_return_5'] / df['std_return_25']\n",
"\n",
" # 计算标准差差值\n",
" df['std_return_90 - std_return_90_2'] = df['std_return_90'] - df['std_return_90_2']\n",
"\n",
" return df\n",
"\n",
"\n",
"def get_act_factor(df, cat=True):\n",
" # 按股票和日期排序\n",
" df = df.sort_values(by=['ts_code', 'trade_date'])\n",
" grouped = df.groupby('ts_code', group_keys=False)\n",
" # 计算 EMA 指标\n",
" df['ema_5'] = grouped['close'].apply(\n",
" lambda x: pd.Series(talib.EMA(x.values, timeperiod=5), index=x.index)\n",
" )\n",
" df['ema_13'] = grouped['close'].apply(\n",
" lambda x: pd.Series(talib.EMA(x.values, timeperiod=13), index=x.index)\n",
" )\n",
" df['ema_20'] = grouped['close'].apply(\n",
" lambda x: pd.Series(talib.EMA(x.values, timeperiod=20), index=x.index)\n",
" )\n",
" df['ema_60'] = grouped['close'].apply(\n",
" lambda x: pd.Series(talib.EMA(x.values, timeperiod=60), index=x.index)\n",
" )\n",
"\n",
" # 计算 act_factor1, act_factor2, act_factor3, act_factor4\n",
" df['act_factor1'] = grouped['ema_5'].apply(\n",
" lambda x: np.arctan((x / x.shift(1) - 1) * 100) * 57.3 / 50\n",
" )\n",
" df['act_factor2'] = grouped['ema_13'].apply(\n",
" lambda x: np.arctan((x / x.shift(1) - 1) * 100) * 57.3 / 40\n",
" )\n",
" df['act_factor3'] = grouped['ema_20'].apply(\n",
" lambda x: np.arctan((x / x.shift(1) - 1) * 100) * 57.3 / 21\n",
" )\n",
" df['act_factor4'] = grouped['ema_60'].apply(\n",
" lambda x: np.arctan((x / x.shift(1) - 1) * 100) * 57.3 / 10\n",
" )\n",
"\n",
" if cat:\n",
" df['cat_af1'] = df['act_factor1'] > 0\n",
" df['cat_af2'] = df['act_factor2'] > df['act_factor1']\n",
" df['cat_af3'] = df['act_factor3'] > df['act_factor2']\n",
" df['cat_af4'] = df['act_factor4'] > df['act_factor3']\n",
"\n",
" # 计算 act_factor5 和 act_factor6\n",
" df['act_factor5'] = df['act_factor1'] + df['act_factor2'] + df['act_factor3'] + df['act_factor4']\n",
" df['act_factor6'] = (df['act_factor1'] - df['act_factor2']) / np.sqrt(\n",
" df['act_factor1'] ** 2 + df['act_factor2'] ** 2)\n",
"\n",
" # 根据 trade_date 截面计算排名\n",
" df['rank_act_factor1'] = df.groupby('trade_date', group_keys=False)['act_factor1'].rank(ascending=False, pct=True)\n",
" df['rank_act_factor2'] = df.groupby('trade_date', group_keys=False)['act_factor2'].rank(ascending=False, pct=True)\n",
" df['rank_act_factor3'] = df.groupby('trade_date', group_keys=False)['act_factor3'].rank(ascending=False, pct=True)\n",
"\n",
" return df\n",
"\n",
"\n",
"def get_money_flow_factor(df):\n",
" # 计算资金流相关因子(字段名称见 tushare 数据说明)\n",
" df['active_buy_volume_large'] = df['buy_lg_vol'] / df['net_mf_vol']\n",
" df['active_buy_volume_big'] = df['buy_elg_vol'] / df['net_mf_vol']\n",
" df['active_buy_volume_small'] = df['buy_sm_vol'] / df['net_mf_vol']\n",
"\n",
" df['buy_lg_vol_minus_sell_lg_vol'] = (df['buy_lg_vol'] - df['sell_lg_vol']) / df['net_mf_vol']\n",
" df['buy_elg_vol_minus_sell_elg_vol'] = (df['buy_elg_vol'] - df['sell_elg_vol']) / df['net_mf_vol']\n",
"\n",
" df['log(circ_mv)'] = np.log(df['circ_mv'])\n",
" return df\n",
"\n",
"\n",
"def get_alpha_factor(df):\n",
" df = df.sort_values(by=['ts_code', 'trade_date'])\n",
" grouped = df.groupby('ts_code')\n",
"\n",
" # alpha_022: 当前 close 与 5 日前 close 差值\n",
" df['alpha_022'] = grouped['close'].transform(lambda x: x - x.shift(5))\n",
"\n",
" # alpha_003: (close - open) / (high - low)\n",
" df['alpha_003'] = np.where(df['high'] != df['low'],\n",
" (df['close'] - df['open']) / (df['high'] - df['low']),\n",
" 0)\n",
"\n",
" # alpha_007: 计算过去5日 close 与 vol 的相关性,并按 trade_date 排名\n",
" df['alpha_007'] = grouped.apply(lambda x: x['close'].rolling(5).corr(x['vol'])).reset_index(level=0, drop=True)\n",
" df['alpha_007'] = df.groupby('trade_date', group_keys=False)['alpha_007'].rank(ascending=True, pct=True)\n",
"\n",
" # alpha_013: 计算过去5日 close 之和 - 20日 close 之和,并按 trade_date 排名\n",
" df['alpha_013'] = grouped['close'].transform(lambda x: x.rolling(5).sum() - x.rolling(20).sum())\n",
" df['alpha_013'] = df.groupby('trade_date', group_keys=False)['alpha_013'].rank(ascending=True, pct=True)\n",
"\n",
" return df\n",
"\n"
],
"outputs": [],
"execution_count": 5
},
{
"cell_type": "code",
"id": "a735bc02ceb4d872",
"metadata": {
"ExecuteTime": {
"end_time": "2025-03-02T12:44:54.093568Z",
"start_time": "2025-03-02T12:44:46.451967Z"
}
},
"source": [
"def read_industry_data(h5_filename):\n",
" # 读取 H5 文件中所有的行业数据\n",
" industry_data = pd.read_hdf(h5_filename, key='sw_daily', columns=[\n",
" 'ts_code', 'trade_date', 'open', 'close', 'high', 'low', 'pe', 'pb', 'vol'\n",
" ]) # 假设 H5 文件的键是 'industry_data'\n",
" industry_data = industry_data.sort_values(by=['ts_code', 'trade_date'])\n",
" industry_data = industry_data.reindex()\n",
" industry_data['trade_date'] = pd.to_datetime(industry_data['trade_date'], format='%Y%m%d')\n",
"\n",
" grouped = industry_data.groupby('ts_code', group_keys=False)\n",
" industry_data['obv'] = grouped.apply(\n",
" lambda x: pd.Series(talib.OBV(x['close'].values, x['vol'].values), index=x.index)\n",
" )\n",
" industry_data['return_5'] = grouped['close'].apply(lambda x: x / x.shift(5) - 1)\n",
" industry_data['return_20'] = grouped['close'].apply(lambda x: x / x.shift(20) - 1)\n",
"\n",
" industry_data = get_act_factor(industry_data, cat=False)\n",
" # industry_data = industry_data.sort_values(by=['trade_date', 'ts_code'])\n",
"\n",
" # 计算每天每个 ts_code 的因子和当天所有 ts_code 的中位数的偏差\n",
" factor_columns = ['obv', 'return_5', 'return_20', 'act_factor1', 'act_factor2', 'act_factor3', 'act_factor4'] # 因子列\n",
"\n",
" for factor in factor_columns:\n",
" if factor in industry_data.columns:\n",
" # 计算每天每个 ts_code 的因子值与当天所有 ts_code 的中位数的偏差\n",
" industry_data[f'{factor}_deviation'] = industry_data.groupby('trade_date')[factor].transform(\n",
" lambda x: x - x.median())\n",
"\n",
" industry_data['return_5_percentile'] = industry_data.groupby('trade_date')['return_5'].transform(\n",
" lambda x: x.rank(pct=True))\n",
" industry_data = industry_data.drop(columns=['open', 'close', 'high', 'low', 'pe', 'pb', 'vol'])\n",
"\n",
" industry_data = industry_data.rename(\n",
" columns={col: f'industry_{col}' for col in industry_data.columns if col not in ['ts_code', 'trade_date']})\n",
"\n",
" industry_data = industry_data.rename(columns={'ts_code': 'cat_l2_code'})\n",
" return industry_data\n",
"\n",
"\n",
"industry_df = read_industry_data('../../data/sw_daily.h5')\n"
],
"outputs": [],
"execution_count": 6
},
{
"cell_type": "code",
"id": "53f86ddc0677a6d7",
"metadata": {
"scrolled": true,
"ExecuteTime": {
"end_time": "2025-03-02T12:44:54.102298Z",
"start_time": "2025-03-02T12:44:54.093568Z"
}
},
"source": [
"origin_columns = df.columns.tolist()\n",
"origin_columns = [col for col in origin_columns if col not in ['turnover_rate', 'pe_ttm', 'volume_ratio', 'l2_code']]\n",
"origin_columns = [col for col in origin_columns if col not in index_data.columns]\n"
],
"outputs": [],
"execution_count": 7
},
{
"cell_type": "code",
"id": "dbe2fd8021b9417f",
"metadata": {
"scrolled": true,
"ExecuteTime": {
"end_time": "2025-03-02T12:46:07.996377Z",
"start_time": "2025-03-02T12:44:54.115006Z"
}
},
"source": [
"def filter_data(df):\n",
" # df = df.groupby('trade_date').apply(lambda x: x.nlargest(1000, 'act_factor1'))\n",
" df = df[~df['is_st']]\n",
" df = df[~df['ts_code'].str.endswith('BJ')]\n",
" df = df[~df['ts_code'].str.startswith('30')]\n",
" df = df[~df['ts_code'].str.startswith('68')]\n",
" df = df[~df['ts_code'].str.startswith('8')]\n",
" df = df.reset_index(drop=True)\n",
" return df\n",
"\n",
"\n",
"df = filter_data(df)\n",
"df = get_technical_factor(df)\n",
"df = get_act_factor(df)\n",
"df = get_money_flow_factor(df)\n",
"df = get_alpha_factor(df)\n",
"# df = df.merge(industry_df, on=['l2_code', 'trade_date'], how='left')\n",
"df = df.rename(columns={'l2_code': 'cat_l2_code'})\n",
"# df = df.merge(index_data, on='trade_date', how='left')\n",
"\n",
"print(df.info())"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<class 'pandas.core.frame.DataFrame'>\n",
"Index: 5732462 entries, 1964 to 5732460\n",
"Data columns (total 72 columns):\n",
" # Column Dtype \n",
"--- ------ ----- \n",
" 0 ts_code object \n",
" 1 trade_date datetime64[ns]\n",
" 2 open float64 \n",
" 3 close float64 \n",
" 4 high float64 \n",
" 5 low float64 \n",
" 6 vol float64 \n",
" 7 turnover_rate float64 \n",
" 8 pe_ttm float64 \n",
" 9 circ_mv float64 \n",
" 10 volume_ratio float64 \n",
" 11 is_st bool \n",
" 12 up_limit float64 \n",
" 13 down_limit float64 \n",
" 14 buy_sm_vol float64 \n",
" 15 sell_sm_vol float64 \n",
" 16 buy_lg_vol float64 \n",
" 17 sell_lg_vol float64 \n",
" 18 buy_elg_vol float64 \n",
" 19 sell_elg_vol float64 \n",
" 20 net_mf_vol float64 \n",
" 21 cat_l2_code object \n",
" 22 log_close float64 \n",
" 23 up float64 \n",
" 24 down float64 \n",
" 25 atr_14 float64 \n",
" 26 atr_6 float64 \n",
" 27 obv float64 \n",
" 28 maobv_6 float64 \n",
" 29 obv-maobv_6 float64 \n",
" 30 rsi_3 float64 \n",
" 31 rsi_6 float64 \n",
" 32 rsi_9 float64 \n",
" 33 return_5 float64 \n",
" 34 return_10 float64 \n",
" 35 return_20 float64 \n",
" 36 avg_close_5 float64 \n",
" 37 std_return_5 float64 \n",
" 38 std_return_15 float64 \n",
" 39 std_return_25 float64 \n",
" 40 std_return_90 float64 \n",
" 41 std_return_90_2 float64 \n",
" 42 std_return_5 / std_return_90 float64 \n",
" 43 std_return_5 / std_return_25 float64 \n",
" 44 std_return_90 - std_return_90_2 float64 \n",
" 45 ema_5 float64 \n",
" 46 ema_13 float64 \n",
" 47 ema_20 float64 \n",
" 48 ema_60 float64 \n",
" 49 act_factor1 float64 \n",
" 50 act_factor2 float64 \n",
" 51 act_factor3 float64 \n",
" 52 act_factor4 float64 \n",
" 53 cat_af1 bool \n",
" 54 cat_af2 bool \n",
" 55 cat_af3 bool \n",
" 56 cat_af4 bool \n",
" 57 act_factor5 float64 \n",
" 58 act_factor6 float64 \n",
" 59 rank_act_factor1 float64 \n",
" 60 rank_act_factor2 float64 \n",
" 61 rank_act_factor3 float64 \n",
" 62 active_buy_volume_large float64 \n",
" 63 active_buy_volume_big float64 \n",
" 64 active_buy_volume_small float64 \n",
" 65 buy_lg_vol_minus_sell_lg_vol float64 \n",
" 66 buy_elg_vol_minus_sell_elg_vol float64 \n",
" 67 log(circ_mv) float64 \n",
" 68 alpha_022 float64 \n",
" 69 alpha_003 float64 \n",
" 70 alpha_007 float64 \n",
" 71 alpha_013 float64 \n",
"dtypes: bool(5), datetime64[ns](1), float64(64), object(2)\n",
"memory usage: 2.9+ GB\n",
"None\n"
]
}
],
"execution_count": 8
},
{
"cell_type": "code",
"id": "d345bcc43b15579e",
"metadata": {
"ExecuteTime": {
"end_time": "2025-03-02T12:46:08.743665Z",
"start_time": "2025-03-02T12:46:08.728004Z"
}
},
"source": [
"def create_deviation_within_dates(df, feature_columns):\n",
" groupby_col = 'cat_l2_code' # 使用 trade_date 进行分组\n",
" new_columns = {}\n",
" ret_feature_columns = feature_columns[:]\n",
"\n",
" # 自动选择所有数值型特征\n",
" # num_features = [col for col in feature_columns if 'cat' not in col and 'index' not in col]\n",
" num_features = [col for col in feature_columns if 'cat' not in col and 'industry' not in col]\n",
"\n",
" # 遍历所有数值型特征\n",
" for feature in num_features:\n",
" if feature == 'trade_date': # 不需要对 'trade_date' 计算偏差\n",
" continue\n",
"\n",
" # grouped_median = df.groupby(['trade_date', groupby_col])[feature].transform('median')\n",
" # deviation_col_name = f'deviation_median_{feature}'\n",
" # new_columns[deviation_col_name] = df[feature] - grouped_median\n",
" # ret_feature_columns.append(deviation_col_name)\n",
"\n",
" grouped_mean = df.groupby(['trade_date', groupby_col])[feature].transform('mean')\n",
" deviation_col_name = f'deviation_mean_{feature}'\n",
" new_columns[deviation_col_name] = df[feature] - grouped_mean\n",
" ret_feature_columns.append(deviation_col_name)\n",
"\n",
" # 将新计算的偏差特征与原始 DataFrame 合并\n",
" df = pd.concat([df, pd.DataFrame(new_columns)], axis=1)\n",
"\n",
" # for feature in ['obv', 'return_20', 'act_factor1', 'act_factor2', 'act_factor3', 'act_factor4']:\n",
" # df[f'deviation_industry_{feature}'] = df[feature] - df[f'industry_{feature}']\n",
"\n",
" return df, ret_feature_columns\n"
],
"outputs": [],
"execution_count": 9
},
{
"cell_type": "code",
"id": "5f3d9aece75318cd",
"metadata": {
"ExecuteTime": {
"end_time": "2025-03-02T12:46:08.949623Z",
"start_time": "2025-03-02T12:46:08.931989Z"
}
},
"source": [
"def get_qcuts(series, quantiles):\n",
" q = pd.qcut(series, q=quantiles, labels=False, duplicates='drop')\n",
" return q[-1] # 返回窗口最后一个元素的分位数标签\n",
"\n",
"\n",
"import pandas as pd\n",
"\n",
"\n",
"def remove_outliers_label_percentile(label: pd.Series, lower_percentile: float = 0.01, upper_percentile: float = 0.99,\n",
" log=True):\n",
" if not (0 <= lower_percentile < upper_percentile <= 1):\n",
" raise ValueError(\"Percentile values must satisfy 0 <= lower_percentile < upper_percentile <= 1.\")\n",
"\n",
" # Calculate lower and upper bounds based on percentiles\n",
" lower_bound = label.quantile(lower_percentile)\n",
" upper_bound = label.quantile(upper_percentile)\n",
"\n",
" # Filter out values outside the bounds\n",
" filtered_label = label[(label >= lower_bound) & (label <= upper_bound)]\n",
"\n",
" # Print the number of removed outliers\n",
" if log:\n",
" print(f\"Removed {len(label) - len(filtered_label)} outliers.\")\n",
" return filtered_label\n",
"\n",
"\n",
"def calculate_risk_adjusted_target(df, days=5):\n",
" df = df.sort_values(by=['ts_code', 'trade_date'])\n",
"\n",
" df['future_close'] = df.groupby('ts_code')['close'].shift(-days)\n",
" df['future_return'] = (df['future_close'] - df['close']) / df['close']\n",
"\n",
" df['future_volatility'] = df.groupby('ts_code')['future_return'].rolling(days, min_periods=1).std().reset_index(\n",
" level=0, drop=True)\n",
" df['sharpe_ratio'] = df['future_return'] * df['future_volatility']\n",
" df['sharpe_ratio'].replace([np.inf, -np.inf], np.nan, inplace=True)\n",
"\n",
" return df['sharpe_ratio']\n",
"\n",
"\n"
],
"outputs": [],
"execution_count": 10
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-03-02T13:16:09.082156Z",
"start_time": "2025-03-02T13:14:58.041672Z"
}
},
"cell_type": "code",
"source": [
"days = 3\n",
"future_close = df.groupby('ts_code')['close'].shift(-days)\n",
"future_return = (future_close - df['close']) / df['close']\n",
"df['label'] = future_return\n",
"\n",
"# df['label'] = remove_outliers_label_percentile(df['label'])\n",
"\n",
"# df = df.apply(lambda x: x.astype('float32') if x.dtype in ['float64', 'float32'] else x)\n",
"df = df.sort_values(by=['trade_date', 'ts_code'])\n",
"train_data = df[(df['trade_date'] <= '2023-01-01') & (df['trade_date'] >= '2016-01-01')]\n",
"test_data = df[(df['trade_date'] >= '2023-01-01')]\n",
"\n",
"train_data = train_data.groupby('trade_date', group_keys=False).apply(\n",
" lambda x: x.nsmallest(3000, 'log(circ_mv)')\n",
")\n",
"test_data = test_data.groupby('trade_date', group_keys=False).apply(\n",
" lambda x: x.nsmallest(3000, 'log(circ_mv)')\n",
")\n",
"train_data = train_data.groupby('trade_date', group_keys=False).apply(\n",
" lambda x: x.nlargest(1000, 'return_20')\n",
")\n",
"test_data = test_data.groupby('trade_date', group_keys=False).apply(\n",
" lambda x: x.nlargest(1000, 'return_20')\n",
")\n",
"\n",
"industry_df = industry_df.sort_values(by=['trade_date'])\n",
"index_data = index_data.sort_values(by=['trade_date'])\n",
"\n",
"train_data = train_data.merge(industry_df, on=['cat_l2_code', 'trade_date'], how='left')\n",
"train_data = train_data.merge(index_data, on='trade_date', how='left')\n",
"test_data = test_data.merge(industry_df, on=['cat_l2_code', 'trade_date'], how='left')\n",
"test_data = test_data.merge(index_data, on='trade_date', how='left')\n",
"\n",
"train_data, test_data = train_data.replace([np.inf, -np.inf], np.nan), test_data.replace([np.inf, -np.inf], np.nan)\n",
"\n",
"feature_columns = [col for col in train_data.columns if col not in ['trade_date',\n",
" 'ts_code',\n",
" 'label']]\n",
"feature_columns = [col for col in feature_columns if 'future' not in col]\n",
"feature_columns = [col for col in feature_columns if 'score' not in col]\n",
"feature_columns = [col for col in feature_columns if col not in origin_columns]\n",
"feature_columns = [col for col in feature_columns if not col.startswith('_')]\n",
"print(feature_columns)\n",
"\n",
"feature_columns_new = feature_columns[:]\n",
"train_data, feature_columns_new = create_deviation_within_dates(train_data, feature_columns)\n",
"print(f'feature_columns size: {len(feature_columns_new)}')\n",
"test_data, feature_columns_new = create_deviation_within_dates(test_data, feature_columns)\n",
"print(f'feature_columns size: {len(feature_columns_new)}')\n",
"\n",
"train_data = train_data.dropna(subset=feature_columns_new)\n",
"train_data = train_data.dropna(subset=['label'])\n",
"train_data['label'] = remove_outliers_label_percentile(train_data['label'])\n",
"train_data = train_data.dropna(subset=['label'])\n",
"train_data = train_data.reset_index(drop=True)\n",
"\n",
"# print(test_data.tail())\n",
"test_data = test_data.dropna(subset=feature_columns_new)\n",
"# test_data = test_data.dropna(subset=['label'])\n",
"test_data = test_data.reset_index(drop=True)\n",
"\n",
"print(len(train_data))\n",
"print(f\"最小日期: {train_data['trade_date'].min().strftime('%Y-%m-%d')}\")\n",
"print(f\"最大日期: {train_data['trade_date'].max().strftime('%Y-%m-%d')}\")\n",
"print(len(test_data))\n",
"print(f\"最小日期: {test_data['trade_date'].min().strftime('%Y-%m-%d')}\")\n",
"print(f\"最大日期: {test_data['trade_date'].max().strftime('%Y-%m-%d')}\")\n",
"\n",
"cat_columns = [col for col in df.columns if col.startswith('cat')]\n",
"for col in cat_columns:\n",
" train_data[col] = train_data[col].astype('category')\n",
" test_data[col] = test_data[col].astype('category')"
],
"id": "cf7de0b77db39655",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['turnover_rate', 'pe_ttm', 'volume_ratio', 'cat_l2_code', 'log_close', 'up', 'down', 'atr_14', 'atr_6', 'obv', 'maobv_6', 'obv-maobv_6', 'rsi_3', 'rsi_6', 'rsi_9', 'return_5', 'return_10', 'return_20', 'avg_close_5', 'std_return_5', 'std_return_15', 'std_return_25', 'std_return_90', 'std_return_90_2', 'std_return_5 / std_return_90', 'std_return_5 / std_return_25', 'std_return_90 - std_return_90_2', 'ema_5', 'ema_13', 'ema_20', 'ema_60', 'act_factor1', 'act_factor2', 'act_factor3', 'act_factor4', 'cat_af1', 'cat_af2', 'cat_af3', 'cat_af4', 'act_factor5', 'act_factor6', 'rank_act_factor1', 'rank_act_factor2', 'rank_act_factor3', 'active_buy_volume_large', 'active_buy_volume_big', 'active_buy_volume_small', 'buy_lg_vol_minus_sell_lg_vol', 'buy_elg_vol_minus_sell_elg_vol', 'log(circ_mv)', 'alpha_022', 'alpha_003', 'alpha_007', 'alpha_013', 'industry_obv', 'industry_return_5', 'industry_return_20', 'industry_ema_5', 'industry_ema_13', 'industry_ema_20', 'industry_ema_60', 'industry_act_factor1', 'industry_act_factor2', 'industry_act_factor3', 'industry_act_factor4', 'industry_act_factor5', 'industry_act_factor6', 'industry_rank_act_factor1', 'industry_rank_act_factor2', 'industry_rank_act_factor3', 'industry_obv_deviation', 'industry_return_5_deviation', 'industry_return_20_deviation', 'industry_act_factor1_deviation', 'industry_act_factor2_deviation', 'industry_act_factor3_deviation', 'industry_act_factor4_deviation', 'industry_return_5_percentile', '000852.SH_MACD', '000905.SH_MACD', '399006.SZ_MACD', '000852.SH_MACD_hist', '000905.SH_MACD_hist', '399006.SZ_MACD_hist', '000852.SH_RSI', '000905.SH_RSI', '399006.SZ_RSI', '000852.SH_Signal_line', '000905.SH_Signal_line', '399006.SZ_Signal_line', '000852.SH_daily_return', '000905.SH_daily_return', '399006.SZ_daily_return']\n",
"feature_columns size: 157\n",
"feature_columns size: 157\n",
"Removed 23208 outliers.\n",
"1137173\n",
"最小日期: 2017-04-06\n",
"最大日期: 2022-12-30\n",
"396093\n",
"最小日期: 2023-01-03\n",
"最大日期: 2025-02-28\n"
]
}
],
"execution_count": 68
},
{
"cell_type": "code",
"id": "8f134d435f71e9e2",
"metadata": {
"jupyter": {
"source_hidden": true
},
"ExecuteTime": {
"end_time": "2025-03-02T13:23:29.071382Z",
"start_time": "2025-03-02T13:23:29.059163Z"
}
},
"source": [
"\n",
"import numpy as np\n",
"\n",
"import pandas as pd\n",
"from sklearn.linear_model import LinearRegression\n",
"from sklearn.preprocessing import StandardScaler, OneHotEncoder\n",
"\n",
"\n",
"def train_sgd_model(train_data: pd.DataFrame,\n",
" feature_columns: list,\n",
" params: dict, print_feature_importance=True):\n",
" # Initialize scaler and encoder\n",
" scaler = StandardScaler()\n",
" encoder = OneHotEncoder(handle_unknown='ignore')\n",
"\n",
" # Extract features and labels\n",
" X_train = train_data[feature_columns]\n",
" y_train = train_data['label']\n",
"\n",
" numeric_columns = X_train.select_dtypes(include=['float64', 'int64']).columns\n",
" categorical_columns = [col for col in feature_columns if col.startswith('cat')]\n",
"\n",
" X_train.loc[:, numeric_columns] = scaler.fit_transform(X_train[numeric_columns])\n",
" X_train_categorical = encoder.fit_transform(X_train[categorical_columns]).toarray()\n",
"\n",
" # Combine numeric and categorical features\n",
" X_train_processed = pd.concat([\n",
" pd.DataFrame(X_train[numeric_columns], columns=numeric_columns, index=X_train.index),\n",
" pd.DataFrame(X_train_categorical, columns=encoder.get_feature_names_out(categorical_columns),\n",
" index=X_train.index)\n",
" ], axis=1)\n",
"\n",
" # Train the model\n",
" # model = SGDRegressor(**params)\n",
" model = LinearRegression()\n",
" model.fit(X_train_processed, y_train)\n",
"\n",
" # 特征重要性可视化\n",
" if print_feature_importance:\n",
" coefficients = model.coef_\n",
"\n",
" # 创建一个字典,存储特征名称和对应的系数\n",
" feature_importance = dict(zip(X_train_processed.columns, coefficients))\n",
"\n",
" # 按系数绝对值排序\n",
" sorted_importance = sorted(feature_importance.items(), key=lambda x: abs(x[1]), reverse=True)\n",
"\n",
" # 打印特征重要性\n",
" print(\"Feature Importance:\")\n",
" for feature, importance in sorted_importance:\n",
" print(f\"{feature}: {importance:.4f}\")\n",
"\n",
" return model, scaler, encoder"
],
"outputs": [],
"execution_count": 79
},
{
"cell_type": "code",
"id": "beeb098799ecfa6a",
"metadata": {
"ExecuteTime": {
"end_time": "2025-03-02T13:23:56.411933Z",
"start_time": "2025-03-02T13:23:29.076456Z"
}
},
"source": [
"print('train data size: ', len(train_data))\n",
"import gc\n",
"\n",
"gc.collect()\n",
"params = {\n",
" 'alpha': 0.0001, # 正则化强度\n",
" 'max_iter': 1000, # 最大迭代次数\n",
" 'tol': 1e-3, # 收敛阈值\n",
" 'eta0': 0.01, # 初始学习率\n",
" 'learning_rate': 'constant'\n",
"}\n",
"\n",
"model, scaler, encoder = train_sgd_model(train_data, feature_columns_new, params)"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"train data size: 1137173\n",
"Feature Importance:\n",
"deviation_mean_ema_13: 0.0841\n",
"deviation_mean_ema_20: -0.0679\n",
"ema_13: -0.0630\n",
"ema_20: 0.0550\n",
"industry_ema_20: 0.0238\n",
"industry_ema_13: -0.0227\n",
"act_factor3: -0.0217\n",
"deviation_mean_ema_5: -0.0200\n",
"industry_obv: 0.0178\n",
"industry_obv_deviation: -0.0176\n",
"rsi_6: 0.0155\n",
"act_factor2: 0.0152\n",
"ema_5: 0.0118\n",
"industry_act_factor4: -0.0092\n",
"rsi_3: -0.0091\n",
"deviation_mean_rsi_6: -0.0088\n",
"cat_l2_code_801056.SI: 0.0087\n",
"cat_l2_code_801125.SI: 0.0087\n",
"deviation_mean_act_factor2: -0.0074\n",
"deviation_mean_act_factor3: 0.0073\n",
"industry_act_factor1: 0.0073\n",
"alpha_013: -0.0073\n",
"rank_act_factor2: 0.0072\n",
"cat_l2_code_801039.SI: 0.0070\n",
"cat_l2_code_801194.SI: -0.0069\n",
"cat_l2_code_801782.SI: 0.0067\n",
"industry_act_factor4_deviation: 0.0066\n",
"industry_act_factor1_deviation: -0.0064\n",
"rank_act_factor3: -0.0062\n",
"deviation_mean_alpha_013: 0.0061\n",
"cat_l2_code_801737.SI: 0.0061\n",
"deviation_mean_rank_act_factor2: -0.0058\n",
"industry_act_factor2_deviation: 0.0057\n",
"log_close: -0.0057\n",
"cat_l2_code_801972.SI: -0.0056\n",
"cat_l2_code_801735.SI: 0.0055\n",
"cat_l2_code_801044.SI: 0.0055\n",
"atr_14: -0.0052\n",
"deviation_mean_log_close: 0.0050\n",
"turnover_rate: -0.0050\n",
"cat_l2_code_801995.SI: -0.0049\n",
"cat_l2_code_801081.SI: 0.0049\n",
"399006.SZ_RSI: 0.0048\n",
"cat_l2_code_801191.SI: -0.0048\n",
"deviation_mean_act_factor5: 0.0048\n",
"cat_l2_code_801054.SI: 0.0047\n",
"deviation_mean_atr_14: 0.0047\n",
"deviation_mean_rank_act_factor3: 0.0046\n",
"cat_l2_code_801231.SI: -0.0044\n",
"deviation_mean_rsi_3: 0.0044\n",
"industry_act_factor2: -0.0044\n",
"cat_l2_code_801085.SI: 0.0043\n",
"cat_l2_code_801084.SI: 0.0042\n",
"cat_l2_code_801154.SI: -0.0041\n",
"cat_l2_code_801111.SI: 0.0041\n",
"000905.SH_MACD_hist: 0.0039\n",
"399006.SZ_MACD_hist: -0.0038\n",
"cat_l2_code_801721.SI: -0.0036\n",
"cat_l2_code_801181.SI: -0.0036\n",
"cat_l2_code_801011.SI: 0.0035\n",
"cat_l2_code_801994.SI: -0.0035\n",
"return_10: 0.0034\n",
"cat_l2_code_801076.SI: -0.0034\n",
"cat_l2_code_801017.SI: 0.0033\n",
"cat_l2_code_801016.SI: 0.0033\n",
"deviation_mean_rsi_9: 0.0033\n",
"industry_ema_60: -0.0032\n",
"cat_l2_code_801018.SI: -0.0032\n",
"cat_l2_code_801202.SI: -0.0031\n",
"cat_l2_code_801114.SI: -0.0031\n",
"cat_l2_code_801203.SI: -0.0031\n",
"std_return_5 / std_return_25: 0.0031\n",
"cat_l2_code_801736.SI: 0.0031\n",
"cat_l2_code_801993.SI: -0.0030\n",
"act_factor5: -0.0030\n",
"act_factor6: 0.0030\n",
"industry_act_factor5: -0.0029\n",
"cat_l2_code_801012.SI: -0.0028\n",
"atr_6: 0.0028\n",
"cat_l2_code_801092.SI: -0.0028\n",
"000905.SH_MACD: 0.0028\n",
"cat_l2_code_801077.SI: 0.0028\n",
"cat_l2_code_801112.SI: -0.0028\n",
"cat_l2_code_801744.SI: -0.0028\n",
"cat_l2_code_801055.SI: 0.0027\n",
"cat_l2_code_801034.SI: 0.0027\n",
"deviation_mean_atr_6: -0.0027\n",
"cat_l2_code_801183.SI: -0.0026\n",
"cat_l2_code_801769.SI: -0.0026\n",
"deviation_mean_act_factor6: -0.0025\n",
"industry_return_20: 0.0024\n",
"deviation_mean_turnover_rate: 0.0024\n",
"cat_l2_code_801952.SI: 0.0024\n",
"industry_return_20_deviation: -0.0024\n",
"000905.SH_RSI: -0.0024\n",
"cat_l2_code_801971.SI: -0.0023\n",
"deviation_mean_std_return_5 / std_return_25: -0.0023\n",
"cat_l2_code_801095.SI: 0.0023\n",
"act_factor4: 0.0023\n",
"cat_l2_code_801178.SI: -0.0023\n",
"std_return_15: 0.0023\n",
"rsi_9: -0.0023\n",
"deviation_mean_return_10: -0.0022\n",
"cat_l2_code_801712.SI: 0.0022\n",
"cat_l2_code_801731.SI: 0.0022\n",
"up: 0.0022\n",
"cat_l2_code_801083.SI: 0.0021\n",
"cat_l2_code_801204.SI: -0.0021\n",
"cat_l2_code_801723.SI: -0.0021\n",
"cat_l2_code_801992.SI: -0.0021\n",
"cat_l2_code_801155.SI: -0.0021\n",
"cat_l2_code_801179.SI: -0.0021\n",
"alpha_007: -0.0020\n",
"std_return_25: 0.0020\n",
"cat_l2_code_801742.SI: -0.0020\n",
"deviation_mean_avg_close_5: -0.0020\n",
"deviation_mean_std_return_25: -0.0019\n",
"cat_l2_code_801767.SI: -0.0019\n",
"cat_l2_code_801086.SI: 0.0019\n",
"deviation_mean_up: -0.0019\n",
"act_factor1: -0.0019\n",
"industry_act_factor3_deviation: 0.0019\n",
"std_return_5: 0.0018\n",
"deviation_mean_std_return_90 - std_return_90_2: 0.0018\n",
"000852.SH_Signal_line: 0.0018\n",
"deviation_mean_std_return_90_2: 0.0018\n",
"cat_l2_code_801218.SI: -0.0018\n",
"rank_act_factor1: -0.0017\n",
"deviation_mean_std_return_15: -0.0017\n",
"000905.SH_Signal_line: 0.0017\n",
"cat_l2_code_801783.SI: -0.0017\n",
"deviation_mean_act_factor4: -0.0016\n",
"std_return_90 - std_return_90_2: -0.0016\n",
"industry_return_5_deviation: -0.0016\n",
"cat_l2_code_801045.SI: 0.0016\n",
"cat_l2_code_801765.SI: -0.0016\n",
"industry_rank_act_factor3: -0.0016\n",
"cat_l2_code_801784.SI: -0.0015\n",
"industry_rank_act_factor2: 0.0015\n",
"cat_l2_code_801082.SI: 0.0015\n",
"cat_l2_code_801014.SI: 0.0015\n",
"cat_l2_code_801141.SI: -0.0015\n",
"cat_l2_code_801124.SI: 0.0015\n",
"cat_af2_1.0: -0.0015\n",
"cat_af2_0.0: 0.0015\n",
"cat_l2_code_801113.SI: 0.0014\n",
"cat_l2_code_801142.SI: -0.0014\n",
"cat_l2_code_801015.SI: -0.0014\n",
"cat_l2_code_801104.SI: 0.0014\n",
"399006.SZ_daily_return: 0.0014\n",
"000852.SH_MACD: 0.0013\n",
"alpha_003: -0.0013\n",
"cat_l2_code_801219.SI: -0.0013\n",
"deviation_mean_ema_60: 0.0013\n",
"industry_ema_5: 0.0013\n",
"cat_l2_code_801078.SI: 0.0013\n",
"000852.SH_MACD_hist: -0.0012\n",
"industry_rank_act_factor1: -0.0012\n",
"cat_l2_code_801101.SI: 0.0012\n",
"cat_l2_code_801033.SI: 0.0012\n",
"cat_l2_code_801152.SI: -0.0012\n",
"cat_l2_code_801128.SI: -0.0012\n",
"return_5: -0.0012\n",
"deviation_mean_std_return_90: -0.0012\n",
"down: -0.0012\n",
"deviation_mean_rank_act_factor1: 0.0012\n",
"cat_l2_code_801096.SI: -0.0011\n",
"deviation_mean_log(circ_mv): -0.0011\n",
"deviation_mean_alpha_007: 0.0011\n",
"maobv_6: -0.0011\n",
"ema_60: -0.0011\n",
"cat_l2_code_801963.SI: -0.0011\n",
"cat_l2_code_801738.SI: -0.0011\n",
"cat_l2_code_801093.SI: 0.0011\n",
"obv: -0.0011\n",
"cat_l2_code_801074.SI: 0.0011\n",
"cat_l2_code_801193.SI: 0.0010\n",
"log(circ_mv): 0.0010\n",
"cat_l2_code_801785.SI: -0.0010\n",
"cat_l2_code_801115.SI: -0.0010\n",
"cat_l2_code_801206.SI: 0.0010\n",
"cat_l2_code_801724.SI: -0.0010\n",
"deviation_mean_maobv_6: 0.0010\n",
"cat_l2_code_801991.SI: -0.0009\n",
"cat_l2_code_801161.SI: -0.0009\n",
"deviation_mean_obv: 0.0009\n",
"cat_l2_code_801143.SI: 0.0009\n",
"cat_l2_code_801053.SI: 0.0008\n",
"cat_l2_code_801156.SI: -0.0008\n",
"volume_ratio: -0.0008\n",
"cat_l2_code_801962.SI: -0.0008\n",
"std_return_90: -0.0008\n",
"cat_l2_code_801711.SI: 0.0008\n",
"buy_elg_vol_minus_sell_elg_vol: -0.0008\n",
"cat_l2_code_801038.SI: -0.0008\n",
"cat_l2_code_801072.SI: 0.0008\n",
"cat_l2_code_801037.SI: 0.0007\n",
"deviation_mean_active_buy_volume_large: -0.0007\n",
"cat_l2_code_801981.SI: -0.0007\n",
"active_buy_volume_large: 0.0007\n",
"cat_l2_code_801032.SI: -0.0007\n",
"deviation_mean_buy_elg_vol_minus_sell_elg_vol: 0.0007\n",
"399006.SZ_MACD: -0.0007\n",
"obv-maobv_6: 0.0007\n",
"cat_l2_code_801733.SI: 0.0006\n",
"cat_l2_code_801103.SI: -0.0006\n",
"deviation_mean_obv-maobv_6: -0.0006\n",
"std_return_5 / std_return_90: -0.0006\n",
"cat_l2_code_801163.SI: -0.0006\n",
"000852.SH_daily_return: -0.0006\n",
"cat_l2_code_801043.SI: 0.0006\n",
"cat_l2_code_801131.SI: -0.0005\n",
"cat_af1_1.0: -0.0005\n",
"cat_af1_0.0: 0.0005\n",
"cat_l2_code_801764.SI: -0.0005\n",
"cat_l2_code_801153.SI: -0.0005\n",
"deviation_mean_act_factor1: -0.0005\n",
"cat_l2_code_801951.SI: -0.0005\n",
"cat_l2_code_801726.SI: -0.0005\n",
"deviation_mean_std_return_5: -0.0005\n",
"cat_l2_code_801036.SI: 0.0005\n",
"industry_act_factor6: 0.0005\n",
"industry_return_5_percentile: 0.0005\n",
"cat_af4_1.0: 0.0005\n",
"cat_af4_0.0: -0.0005\n",
"cat_l2_code_801223.SI: -0.0005\n",
"cat_l2_code_801881.SI: -0.0004\n",
"399006.SZ_Signal_line: 0.0004\n",
"cat_l2_code_801743.SI: -0.0004\n",
"deviation_mean_active_buy_volume_small: 0.0004\n",
"return_20: -0.0004\n",
"000905.SH_daily_return: -0.0003\n",
"cat_l2_code_801982.SI: -0.0003\n",
"pe_ttm: -0.0003\n",
"cat_l2_code_801713.SI: 0.0003\n",
"cat_l2_code_801129.SI: 0.0003\n",
"active_buy_volume_small: -0.0003\n",
"deviation_mean_return_5: -0.0003\n",
"deviation_mean_399006.SZ_Signal_line: -0.0003\n",
"cat_l2_code_801126.SI: 0.0003\n",
"deviation_mean_pe_ttm: 0.0002\n",
"deviation_mean_alpha_003: 0.0002\n",
"cat_l2_code_801116.SI: -0.0002\n",
"cat_af3_0.0: 0.0002\n",
"cat_af3_1.0: -0.0002\n",
"cat_l2_code_801132.SI: -0.0002\n",
"deviation_mean_399006.SZ_daily_return: 0.0002\n",
"std_return_90_2: -0.0002\n",
"deviation_mean_alpha_022: 0.0002\n",
"deviation_mean_std_return_5 / std_return_90: -0.0002\n",
"deviation_mean_000852.SH_RSI: -0.0002\n",
"deviation_mean_return_20: 0.0002\n",
"deviation_mean_down: -0.0002\n",
"active_buy_volume_big: 0.0002\n",
"deviation_mean_volume_ratio: -0.0002\n",
"cat_l2_code_801133.SI: 0.0002\n",
"industry_act_factor3: 0.0002\n",
"alpha_022: -0.0001\n",
"deviation_mean_399006.SZ_MACD: -0.0001\n",
"avg_close_5: -0.0001\n",
"deviation_mean_000852.SH_MACD: 0.0001\n",
"cat_l2_code_801722.SI: -0.0001\n",
"cat_l2_code_801766.SI: 0.0001\n",
"deviation_mean_active_buy_volume_big: -0.0001\n",
"cat_l2_code_801127.SI: -0.0001\n",
"deviation_mean_399006.SZ_MACD_hist: 0.0001\n",
"deviation_mean_000905.SH_Signal_line: -0.0001\n",
"cat_l2_code_801745.SI: -0.0001\n",
"deviation_mean_000905.SH_RSI: 0.0001\n",
"deviation_mean_000905.SH_MACD_hist: -0.0001\n",
"deviation_mean_399006.SZ_RSI: -0.0001\n",
"industry_return_5: 0.0001\n",
"cat_l2_code_801102.SI: -0.0001\n",
"cat_l2_code_801741.SI: 0.0001\n",
"000852.SH_RSI: 0.0001\n",
"deviation_mean_000852.SH_daily_return: -0.0001\n",
"deviation_mean_000852.SH_Signal_line: -0.0001\n",
"deviation_mean_000852.SH_MACD_hist: -0.0001\n",
"cat_l2_code_801145.SI: -0.0000\n",
"deviation_mean_buy_lg_vol_minus_sell_lg_vol: 0.0000\n",
"cat_l2_code_801051.SI: -0.0000\n",
"cat_l2_code_801151.SI: -0.0000\n",
"deviation_mean_000905.SH_daily_return: -0.0000\n",
"deviation_mean_000905.SH_MACD: -0.0000\n",
"buy_lg_vol_minus_sell_lg_vol: 0.0000\n"
]
}
],
"execution_count": 80
},
{
"cell_type": "code",
"id": "465944b1d463e4b1",
"metadata": {
"ExecuteTime": {
"end_time": "2025-03-02T13:23:57.174623Z",
"start_time": "2025-03-02T13:23:57.148826Z"
}
},
"source": [
"from tqdm import tqdm\n",
"\n",
"\n",
"def incremental_training(test_data: pd.DataFrame,\n",
" model,\n",
" scaler,\n",
" encoder,\n",
" days: int,\n",
" back_days: int,\n",
" feature_columns: list,\n",
" params: dict\n",
" ):\n",
" test_data = test_data.sort_values(by='trade_date')\n",
" scores = []\n",
" unique_trade_dates = sorted(test_data['trade_date'].unique())\n",
"\n",
" new_model = None\n",
" for i in tqdm(range(0, len(unique_trade_dates))):\n",
" # Get the current window of trade dates\n",
" current_dates = [unique_trade_dates[i]]\n",
" window_data = test_data[test_data['trade_date'].isin(current_dates)]\n",
" X = window_data[feature_columns]\n",
" numeric_columns = X.select_dtypes(include=['float64', 'int64']).columns\n",
" categorical_columns = [col for col in feature_columns if col.startswith('cat')]\n",
" X.loc[:, numeric_columns] = scaler.transform(X[numeric_columns])\n",
" X_categorical = encoder.transform(X[categorical_columns]).toarray()\n",
"\n",
" # Combine numeric and categorical features\n",
" X_processed = pd.concat([\n",
" pd.DataFrame(X[numeric_columns], columns=numeric_columns, index=X.index),\n",
" pd.DataFrame(X_categorical, columns=encoder.get_feature_names_out(categorical_columns), index=X.index)\n",
" ], axis=1)\n",
" X_processed = X_processed.fillna(0)\n",
"\n",
" if new_model is not None:\n",
" window_scores = new_model.predict(X_processed)\n",
" else:\n",
" window_scores = model.predict(X_processed)\n",
" scores.extend(window_scores)\n",
"\n",
" # # Prepare data for incremental training\n",
" # current_dates = unique_trade_dates[max(0, i - back_days):i + days]\n",
" # window_data = test_data[test_data['trade_date'].isin(current_dates)]\n",
" # X_train = window_data[feature_columns]\n",
" current_dates = unique_trade_dates[max(0, i - days - back_days):i + 1 - back_days]\n",
" window_data = test_data[test_data['trade_date'].isin(current_dates)]\n",
" window_data['label'] = remove_outliers_label_percentile(window_data['label'], log=False)\n",
" window_data = window_data.dropna(subset=feature_columns)\n",
" window_data = window_data.dropna(subset=['label'])\n",
" X_train = window_data[feature_columns]\n",
" y_train = window_data['label']\n",
" # Incrementally train the model\n",
" if len(y_train.unique()) > 1:\n",
" X_train.loc[:, numeric_columns] = scaler.transform(X_train[numeric_columns])\n",
" X_train_categorical = encoder.transform(X_train[categorical_columns]).toarray()\n",
" X_train_processed = pd.concat([\n",
" pd.DataFrame(X_train[numeric_columns], columns=numeric_columns, index=X_train.index),\n",
" pd.DataFrame(X_train_categorical, columns=encoder.get_feature_names_out(categorical_columns),\n",
" index=X_train.index)\n",
" ], axis=1)\n",
" X_train_processed = X_train_processed.fillna(0)\n",
" model = model.partial_fit(X_train_processed, y_train)\n",
" else:\n",
" print(current_dates)\n",
"\n",
" # Add the scores as a new 'score' column to the test_data\n",
" test_data['score'] = scores\n",
" return test_data"
],
"outputs": [],
"execution_count": 81
},
{
"cell_type": "code",
"id": "e3ac761d8f0b5d31",
"metadata": {
"ExecuteTime": {
"end_time": "2025-03-02T13:25:27.698736Z",
"start_time": "2025-03-02T13:25:25.801431Z"
}
},
"source": [
"import gc\n",
"\n",
"gc.collect()\n",
"\n",
"# predictions_test = incremental_training(test_data, model, scaler, encoder, 10, days, feature_columns_new, params)\n",
"X_test = test_data[feature_columns_new]\n",
"numeric_columns = X_test.select_dtypes(include=['float64', 'int64']).columns\n",
"categorical_columns = [col for col in feature_columns if col.startswith('cat')]\n",
"\n",
"X_test.loc[:, numeric_columns] = scaler.transform(X_test[numeric_columns])\n",
"X_test_categorical = encoder.transform(X_test[categorical_columns]).toarray()\n",
"\n",
"# Combine numeric and categorical features\n",
"X_test_processed = pd.concat([\n",
" pd.DataFrame(X_test[numeric_columns], columns=numeric_columns, index=X_test.index),\n",
" pd.DataFrame(X_test_categorical, columns=encoder.get_feature_names_out(categorical_columns), index=X_test.index)\n",
"], axis=1)\n",
"predictions_test = test_data[['ts_code', 'trade_date']]\n",
"predictions_test['score'] = model.predict(X_test_processed)\n",
"predictions_test = predictions_test.loc[predictions_test.groupby('trade_date')['score'].idxmax()]\n",
"predictions_test[['trade_date', 'score', 'ts_code']].to_csv('predictions_test.tsv', index=False)\n"
],
"outputs": [],
"execution_count": 84
},
{
"cell_type": "code",
"id": "b427ce41-9739-4e9e-bea8-5f2551fec5d7",
"metadata": {
"jupyter": {
"source_hidden": true
},
"ExecuteTime": {
"end_time": "2025-03-02T13:24:01.404626100Z",
"start_time": "2025-03-02T13:17:17.216317Z"
}
},
"source": [
"import joblib\n",
"import pandas as pd\n",
"\n",
"\n",
"# 假设你已经训练好了一个 LightGBM 模型\n",
"# model = lgb.train(params, train_data, ...)\n",
"\n",
"def save_model_with_info(model, params, feature_columns, train_data, info, save_path):\n",
" \"\"\"\n",
" 保存 LightGBM 模型及其相关信息。\n",
" \n",
" 参数:\n",
" model: 训练好的 LightGBM 模型 (lgb.Booster)。\n",
" params: 模型的参数 (dict)。\n",
" feature_columns: 特征列名列表 (list)。\n",
" train_data: 训练数据 (pd.DataFrame),包含 'trade_date' 列。\n",
" info: 额外信息 (str 或 dict)。\n",
" save_path: 保存路径 (str)。\n",
" \"\"\"\n",
" # 提取训练数据的 trade_date 的最大值和最小值\n",
" if 'trade_date' not in train_data.columns:\n",
" raise ValueError(\"训练数据中必须包含 'trade_date' 列。\")\n",
"\n",
" trade_date_min = train_data['trade_date'].min()\n",
" trade_date_max = train_data['trade_date'].max()\n",
"\n",
" # 构建保存的信息字典\n",
" model_info = {\n",
" 'model': model, # 保存模型本身\n",
" 'params': params, # 模型参数\n",
" 'feature_columns': feature_columns, # 特征列名\n",
" 'trade_date_range': {\n",
" 'min': trade_date_min,\n",
" 'max': trade_date_max\n",
" }, # trade_date 的范围\n",
" 'info': info # 额外信息\n",
" }\n",
"\n",
" # 使用 joblib 保存模型及相关信息\n",
" joblib.dump(model_info, save_path)\n",
" print(f\"模型及相关信息已成功保存到 {save_path}\")\n",
"\n",
"# info = \"Update Regression + 滚动new model + 5days\"\n",
"\n",
"# # 保存模型及相关信息\n",
"# save_path = \"../model/lightgbm_model_UpdateRegression_2025-2-25.pkl\"\n",
"# save_model_with_info(model, light_params, feature_columns, train_data, info, save_path)"
],
"outputs": [],
"execution_count": 73
},
{
"cell_type": "code",
"id": "8f9a2b7b-11fe-4eb5-aa11-c4066fe418a1",
"metadata": {
"ExecuteTime": {
"end_time": "2025-03-02T13:24:01.407058800Z",
"start_time": "2025-03-02T13:17:17.289631Z"
}
},
"source": [],
"outputs": [],
"execution_count": null
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.19"
}
},
"nbformat": 4,
"nbformat_minor": 5
}