Files
NewStock/main/train/UpdateClassify.ipynb

1407 lines
212 KiB
Plaintext
Raw Normal View History

{
"cells": [
{
"cell_type": "code",
"id": "79a7758178bafdd3",
"metadata": {
"jupyter": {
"source_hidden": true
},
"ExecuteTime": {
2025-03-01 18:10:27 +08:00
"end_time": "2025-03-01T10:01:24.391606Z",
"start_time": "2025-03-01T10:01:24.069449Z"
}
},
"source": [
2025-03-01 18:10:27 +08:00
"# %load_ext autoreload\n",
"# %autoreload 2\n",
"\n",
"import pandas as pd\n",
"\n",
"pd.set_option('display.max_columns', None)\n"
],
"outputs": [],
"execution_count": 1
},
{
"cell_type": "code",
"id": "a79cafb06a7e0e43",
"metadata": {
"ExecuteTime": {
2025-03-01 18:10:27 +08:00
"end_time": "2025-03-01T10:02:01.510095Z",
"start_time": "2025-03-01T10:01:24.392790Z"
}
},
"source": [
2025-04-28 11:02:52 +08:00
"from code.utils.utils import read_and_merge_h5_data\n",
"\n",
"print('daily data')\n",
"df = read_and_merge_h5_data('../../data/daily_data.h5', key='daily_data',\n",
" columns=['ts_code', 'trade_date', 'open', 'close', 'high', 'low', 'vol'],\n",
" df=None)\n",
"\n",
"print('daily basic')\n",
"df = read_and_merge_h5_data('../../data/daily_basic.h5', key='daily_basic',\n",
" columns=['ts_code', 'trade_date', 'turnover_rate', 'pe_ttm', 'circ_mv', 'volume_ratio',\n",
" 'is_st'], df=df, join='inner')\n",
"\n",
"print('stk limit')\n",
"df = read_and_merge_h5_data('../../data/stk_limit.h5', key='stk_limit',\n",
" columns=['ts_code', 'trade_date', 'pre_close', 'up_limit', 'down_limit'],\n",
" df=df)\n",
"print('money flow')\n",
"df = read_and_merge_h5_data('../../data/money_flow.h5', key='money_flow',\n",
" columns=['ts_code', 'trade_date', 'buy_sm_vol', 'sell_sm_vol', 'buy_lg_vol', 'sell_lg_vol',\n",
" 'buy_elg_vol', 'sell_elg_vol', 'net_mf_vol'],\n",
" df=df)\n",
"print(df.info())"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"daily data\n",
"daily basic\n",
"inner merge on ['ts_code', 'trade_date']\n",
"stk limit\n",
"left merge on ['ts_code', 'trade_date']\n",
"money flow\n",
"left merge on ['ts_code', 'trade_date']\n",
"<class 'pandas.core.frame.DataFrame'>\n",
2025-03-01 18:10:27 +08:00
"RangeIndex: 8360947 entries, 0 to 8360946\n",
"Data columns (total 21 columns):\n",
" # Column Dtype \n",
"--- ------ ----- \n",
" 0 ts_code object \n",
" 1 trade_date datetime64[ns]\n",
" 2 open float64 \n",
" 3 close float64 \n",
" 4 high float64 \n",
" 5 low float64 \n",
" 6 vol float64 \n",
" 7 turnover_rate float64 \n",
" 8 pe_ttm float64 \n",
" 9 circ_mv float64 \n",
" 10 volume_ratio float64 \n",
" 11 is_st bool \n",
" 12 up_limit float64 \n",
" 13 down_limit float64 \n",
" 14 buy_sm_vol float64 \n",
" 15 sell_sm_vol float64 \n",
" 16 buy_lg_vol float64 \n",
" 17 sell_lg_vol float64 \n",
" 18 buy_elg_vol float64 \n",
" 19 sell_elg_vol float64 \n",
" 20 net_mf_vol float64 \n",
"dtypes: bool(1), datetime64[ns](1), float64(18), object(1)\n",
2025-03-01 18:10:27 +08:00
"memory usage: 1.3+ GB\n",
"None\n"
]
}
],
"execution_count": 2
},
{
"cell_type": "code",
"id": "a4eec8c93f6a7cc3",
"metadata": {
"ExecuteTime": {
2025-03-01 18:10:27 +08:00
"end_time": "2025-03-01T10:02:04.058688Z",
"start_time": "2025-03-01T10:02:01.684757Z"
}
},
"source": [
"print('industry')\n",
"df = read_and_merge_h5_data('../../data/industry_data.h5', key='industry_data',\n",
" columns=['ts_code', 'l2_code'],\n",
" df=df, on=['ts_code'], join='left')\n"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"industry\n",
"left merge on ['ts_code']\n"
]
}
],
"execution_count": 3
},
{
"cell_type": "code",
"id": "c4e9e1d31da6dba6",
"metadata": {
"jupyter": {
"source_hidden": true
},
"ExecuteTime": {
2025-03-01 18:10:27 +08:00
"end_time": "2025-03-01T10:02:04.116043Z",
"start_time": "2025-03-01T10:02:04.067264Z"
}
},
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"\n",
2025-03-01 18:10:27 +08:00
"\n",
"def calculate_indicators(df):\n",
" \"\"\"\n",
" 计算四个指标当日涨跌幅、5日移动平均、RSI、MACD。\n",
" \"\"\"\n",
" df['daily_return'] = (df['close'] - df['pre_close']) / df['pre_close'] * 100\n",
" # df['5_day_ma'] = df['close'].rolling(window=5).mean()\n",
" delta = df['close'].diff()\n",
" gain = delta.where(delta > 0, 0)\n",
" loss = -delta.where(delta < 0, 0)\n",
" avg_gain = gain.rolling(window=14).mean()\n",
" avg_loss = loss.rolling(window=14).mean()\n",
" rs = avg_gain / avg_loss\n",
" df['RSI'] = 100 - (100 / (1 + rs))\n",
"\n",
" # 计算MACD\n",
" ema12 = df['close'].ewm(span=12, adjust=False).mean()\n",
" ema26 = df['close'].ewm(span=26, adjust=False).mean()\n",
" df['MACD'] = ema12 - ema26\n",
" df['Signal_line'] = df['MACD'].ewm(span=9, adjust=False).mean()\n",
" df['MACD_hist'] = df['MACD'] - df['Signal_line']\n",
"\n",
" return df\n",
"\n",
2025-03-01 18:10:27 +08:00
"\n",
"def generate_index_indicators(h5_filename):\n",
" df = pd.read_hdf(h5_filename, key='index_data')\n",
" df['trade_date'] = pd.to_datetime(df['trade_date'], format='%Y%m%d')\n",
" df = df.sort_values('trade_date')\n",
"\n",
" # 计算每个ts_code的相关指标\n",
" df_indicators = []\n",
" for ts_code in df['ts_code'].unique():\n",
" df_index = df[df['ts_code'] == ts_code].copy()\n",
" df_index = calculate_indicators(df_index)\n",
" df_indicators.append(df_index)\n",
"\n",
" # 合并所有指数的结果\n",
" df_all_indicators = pd.concat(df_indicators, ignore_index=True)\n",
"\n",
" # 保留trade_date列并将同一天的数据按ts_code合并成一行\n",
" df_final = df_all_indicators.pivot_table(\n",
" index='trade_date',\n",
" columns='ts_code',\n",
" values=['daily_return', 'RSI', 'MACD', 'Signal_line', 'MACD_hist'],\n",
" aggfunc='last'\n",
" )\n",
"\n",
" df_final.columns = [f\"{col[1]}_{col[0]}\" for col in df_final.columns]\n",
" df_final = df_final.reset_index()\n",
"\n",
" return df_final\n",
"\n",
"\n",
"# 使用函数\n",
"h5_filename = '../../data/index_data.h5'\n",
"index_data = generate_index_indicators(h5_filename)\n",
"index_data = index_data.dropna()\n"
],
"outputs": [],
"execution_count": 4
},
{
"cell_type": "code",
"id": "99776e73-f310-47c0-953e-2cf73ff13310",
"metadata": {
"jupyter": {
"source_hidden": true
},
"ExecuteTime": {
2025-03-01 18:10:27 +08:00
"end_time": "2025-03-01T10:02:04.160778Z",
"start_time": "2025-03-01T10:02:04.134780Z"
}
},
"source": [
"import numpy as np\n",
"import talib\n",
"\n",
2025-03-01 18:10:27 +08:00
"\n",
"def get_technical_factor(df):\n",
" # 按股票和日期排序\n",
" df = df.sort_values(by=['ts_code', 'trade_date'])\n",
" grouped = df.groupby('ts_code', group_keys=False)\n",
"\n",
" # 计算 up 和 down\n",
" df['up'] = (df['high'] - df[['close', 'open']].max(axis=1)) / df['close']\n",
" df['down'] = (df[['close', 'open']].min(axis=1) - df['low']) / df['close']\n",
"\n",
" # 计算 ATR\n",
" df['atr_14'] = grouped.apply(\n",
2025-03-01 18:10:27 +08:00
" lambda x: pd.Series(talib.ATR(x['high'].values, x['low'].values, x['close'].values, timeperiod=14),\n",
" index=x.index)\n",
" )\n",
" df['atr_6'] = grouped.apply(\n",
2025-03-01 18:10:27 +08:00
" lambda x: pd.Series(talib.ATR(x['high'].values, x['low'].values, x['close'].values, timeperiod=6),\n",
" index=x.index)\n",
" )\n",
"\n",
" # 计算 OBV 及其均线\n",
" df['obv'] = grouped.apply(\n",
" lambda x: pd.Series(talib.OBV(x['close'].values, x['vol'].values), index=x.index)\n",
" )\n",
" df['maobv_6'] = grouped.apply(\n",
" lambda x: pd.Series(talib.SMA(x['obv'].values, timeperiod=6), index=x.index)\n",
" )\n",
" df['obv-maobv_6'] = df['obv'] - df['maobv_6']\n",
"\n",
" # 计算 RSI\n",
" df['rsi_3'] = grouped.apply(\n",
" lambda x: pd.Series(talib.RSI(x['close'].values, timeperiod=3), index=x.index)\n",
" )\n",
" df['rsi_6'] = grouped.apply(\n",
" lambda x: pd.Series(talib.RSI(x['close'].values, timeperiod=6), index=x.index)\n",
" )\n",
" df['rsi_9'] = grouped.apply(\n",
" lambda x: pd.Series(talib.RSI(x['close'].values, timeperiod=9), index=x.index)\n",
" )\n",
"\n",
" # 计算 return_10 和 return_20\n",
" df['return_5'] = grouped['close'].apply(lambda x: x / x.shift(5) - 1)\n",
" df['return_10'] = grouped['close'].apply(lambda x: x / x.shift(10) - 1)\n",
" df['return_20'] = grouped['close'].apply(lambda x: x / x.shift(20) - 1)\n",
"\n",
" # 计算 avg_close_5\n",
" df['avg_close_5'] = grouped['close'].apply(lambda x: x.rolling(window=5).mean() / x)\n",
"\n",
" # 计算标准差指标\n",
" df['std_return_5'] = grouped['close'].apply(lambda x: x.pct_change().rolling(window=5).std())\n",
" df['std_return_15'] = grouped['close'].apply(lambda x: x.pct_change().rolling(window=15).std())\n",
" df['std_return_25'] = grouped['close'].apply(lambda x: x.pct_change().rolling(window=25).std())\n",
" df['std_return_90'] = grouped['close'].apply(lambda x: x.pct_change().rolling(window=90).std())\n",
" df['std_return_90_2'] = grouped['close'].apply(lambda x: x.shift(10).pct_change().rolling(window=90).std())\n",
"\n",
" # 计算比值指标\n",
" df['std_return_5 / std_return_90'] = df['std_return_5'] / df['std_return_90']\n",
" df['std_return_5 / std_return_25'] = df['std_return_5'] / df['std_return_25']\n",
"\n",
" # 计算标准差差值\n",
" df['std_return_90 - std_return_90_2'] = df['std_return_90'] - df['std_return_90_2']\n",
"\n",
" return df\n",
"\n",
"\n",
"def get_act_factor(df, cat=True):\n",
" # 按股票和日期排序\n",
" df = df.sort_values(by=['ts_code', 'trade_date'])\n",
" grouped = df.groupby('ts_code', group_keys=False)\n",
" # 计算 EMA 指标\n",
" df['ema_5'] = grouped['close'].apply(\n",
" lambda x: pd.Series(talib.EMA(x.values, timeperiod=5), index=x.index)\n",
" )\n",
" df['ema_13'] = grouped['close'].apply(\n",
" lambda x: pd.Series(talib.EMA(x.values, timeperiod=13), index=x.index)\n",
" )\n",
" df['ema_20'] = grouped['close'].apply(\n",
" lambda x: pd.Series(talib.EMA(x.values, timeperiod=20), index=x.index)\n",
" )\n",
" df['ema_60'] = grouped['close'].apply(\n",
" lambda x: pd.Series(talib.EMA(x.values, timeperiod=60), index=x.index)\n",
" )\n",
"\n",
" # 计算 act_factor1, act_factor2, act_factor3, act_factor4\n",
" df['act_factor1'] = grouped['ema_5'].apply(\n",
" lambda x: np.arctan((x / x.shift(1) - 1) * 100) * 57.3 / 50\n",
" )\n",
" df['act_factor2'] = grouped['ema_13'].apply(\n",
" lambda x: np.arctan((x / x.shift(1) - 1) * 100) * 57.3 / 40\n",
" )\n",
" df['act_factor3'] = grouped['ema_20'].apply(\n",
" lambda x: np.arctan((x / x.shift(1) - 1) * 100) * 57.3 / 21\n",
" )\n",
" df['act_factor4'] = grouped['ema_60'].apply(\n",
" lambda x: np.arctan((x / x.shift(1) - 1) * 100) * 57.3 / 10\n",
" )\n",
"\n",
" if cat:\n",
" df['cat_af1'] = df['act_factor1'] > 0\n",
" df['cat_af2'] = df['act_factor2'] > df['act_factor1']\n",
" df['cat_af3'] = df['act_factor3'] > df['act_factor2']\n",
" df['cat_af4'] = df['act_factor4'] > df['act_factor3']\n",
"\n",
" # 计算 act_factor5 和 act_factor6\n",
" df['act_factor5'] = df['act_factor1'] + df['act_factor2'] + df['act_factor3'] + df['act_factor4']\n",
2025-03-01 18:10:27 +08:00
" df['act_factor6'] = (df['act_factor1'] - df['act_factor2']) / np.sqrt(\n",
" df['act_factor1'] ** 2 + df['act_factor2'] ** 2)\n",
"\n",
" # 根据 trade_date 截面计算排名\n",
" df['rank_act_factor1'] = df.groupby('trade_date', group_keys=False)['act_factor1'].rank(ascending=False, pct=True)\n",
" df['rank_act_factor2'] = df.groupby('trade_date', group_keys=False)['act_factor2'].rank(ascending=False, pct=True)\n",
" df['rank_act_factor3'] = df.groupby('trade_date', group_keys=False)['act_factor3'].rank(ascending=False, pct=True)\n",
"\n",
" return df\n",
"\n",
"\n",
"def get_money_flow_factor(df):\n",
" # 计算资金流相关因子(字段名称见 tushare 数据说明)\n",
" df['active_buy_volume_large'] = df['buy_lg_vol'] / df['net_mf_vol']\n",
" df['active_buy_volume_big'] = df['buy_elg_vol'] / df['net_mf_vol']\n",
" df['active_buy_volume_small'] = df['buy_sm_vol'] / df['net_mf_vol']\n",
"\n",
" df['buy_lg_vol_minus_sell_lg_vol'] = (df['buy_lg_vol'] - df['sell_lg_vol']) / df['net_mf_vol']\n",
" df['buy_elg_vol_minus_sell_elg_vol'] = (df['buy_elg_vol'] - df['sell_elg_vol']) / df['net_mf_vol']\n",
"\n",
" df['log(circ_mv)'] = np.log(df['circ_mv'])\n",
" return df\n",
"\n",
"\n",
"def get_alpha_factor(df):\n",
" df = df.sort_values(by=['ts_code', 'trade_date'])\n",
" grouped = df.groupby('ts_code')\n",
"\n",
" # alpha_022: 当前 close 与 5 日前 close 差值\n",
" df['alpha_022'] = grouped['close'].transform(lambda x: x - x.shift(5))\n",
"\n",
" # alpha_003: (close - open) / (high - low)\n",
" df['alpha_003'] = np.where(df['high'] != df['low'],\n",
" (df['close'] - df['open']) / (df['high'] - df['low']),\n",
" 0)\n",
"\n",
" # alpha_007: 计算过去5日 close 与 vol 的相关性,并按 trade_date 排名\n",
" df['alpha_007'] = grouped.apply(lambda x: x['close'].rolling(5).corr(x['vol'])).reset_index(level=0, drop=True)\n",
" df['alpha_007'] = df.groupby('trade_date', group_keys=False)['alpha_007'].rank(ascending=True, pct=True)\n",
"\n",
" # alpha_013: 计算过去5日 close 之和 - 20日 close 之和,并按 trade_date 排名\n",
" df['alpha_013'] = grouped['close'].transform(lambda x: x.rolling(5).sum() - x.rolling(20).sum())\n",
" df['alpha_013'] = df.groupby('trade_date', group_keys=False)['alpha_013'].rank(ascending=True, pct=True)\n",
"\n",
" return df\n",
"\n"
],
"outputs": [],
"execution_count": 5
},
{
"cell_type": "code",
"id": "c0ad4f64-a6c6-4e43-b7ba-3f2a36baaa64",
"metadata": {
"ExecuteTime": {
2025-03-01 18:10:27 +08:00
"end_time": "2025-03-01T10:02:11.090680Z",
"start_time": "2025-03-01T10:02:04.160778Z"
}
},
"source": [
"def read_industry_data(h5_filename):\n",
" # 读取 H5 文件中所有的行业数据\n",
" industry_data = pd.read_hdf(h5_filename, key='sw_daily', columns=[\n",
" 'ts_code', 'trade_date', 'open', 'close', 'high', 'low', 'pe', 'pb', 'vol'\n",
" ]) # 假设 H5 文件的键是 'industry_data'\n",
2025-03-01 18:10:27 +08:00
" industry_data = industry_data.sort_values(by=['ts_code', 'trade_date'])\n",
" industry_data = industry_data.reindex()\n",
2025-03-01 18:10:27 +08:00
" industry_data['trade_date'] = pd.to_datetime(industry_data['trade_date'], format='%Y%m%d')\n",
"\n",
" grouped = industry_data.groupby('ts_code', group_keys=False)\n",
" industry_data['obv'] = grouped.apply(\n",
" lambda x: pd.Series(talib.OBV(x['close'].values, x['vol'].values), index=x.index)\n",
" )\n",
" industry_data['return_5'] = grouped['close'].apply(lambda x: x / x.shift(5) - 1)\n",
2025-03-01 18:10:27 +08:00
" industry_data['return_20'] = grouped['close'].apply(lambda x: x / x.shift(20) - 1)\n",
"\n",
2025-03-01 18:10:27 +08:00
" industry_data = get_act_factor(industry_data, cat=False)\n",
" industry_data = industry_data.sort_values(by=['trade_date', 'ts_code'])\n",
"\n",
" # 计算每天每个 ts_code 的因子和当天所有 ts_code 的中位数的偏差\n",
" factor_columns = ['obv', 'return_5', 'return_20', 'act_factor1', 'act_factor2', 'act_factor3', 'act_factor4'] # 因子列\n",
"\n",
" for factor in factor_columns:\n",
" if factor in industry_data.columns:\n",
" # 计算每天每个 ts_code 的因子值与当天所有 ts_code 的中位数的偏差\n",
2025-03-01 18:10:27 +08:00
" industry_data[f'{factor}_deviation'] = industry_data.groupby('trade_date')[factor].transform(\n",
" lambda x: x - x.median())\n",
"\n",
2025-03-01 18:10:27 +08:00
" industry_data['return_5_percentile'] = industry_data.groupby('trade_date')['return_5'].transform(\n",
" lambda x: x.rank(pct=True))\n",
" industry_data = industry_data.drop(columns=['open', 'close', 'high', 'low', 'pe', 'pb', 'vol'])\n",
"\n",
" industry_data = industry_data.rename(\n",
" columns={col: f'industry_{col}' for col in industry_data.columns if col not in ['ts_code', 'trade_date']})\n",
2025-03-01 18:10:27 +08:00
"\n",
" industry_data = industry_data.rename(columns={'ts_code': 'cat_l2_code'})\n",
" return industry_data\n",
"\n",
2025-03-01 18:10:27 +08:00
"\n",
"industry_df = read_industry_data('../../data/sw_daily.h5')\n"
],
2025-03-01 18:10:27 +08:00
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"C:\\Users\\liaozhaorun\\AppData\\Local\\Temp\\ipykernel_11956\\4080813444.py:11: DeprecationWarning: DataFrameGroupBy.apply operated on the grouping columns. This behavior is deprecated, and in a future version of pandas the grouping columns will be excluded from the operation. Either pass `include_groups=False` to exclude the groupings or explicitly select the grouping columns after groupby to silence this warning.\n",
" industry_data['obv'] = grouped.apply(\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n",
"E:\\Python\\anaconda\\envs\\new_trader\\Lib\\site-packages\\numpy\\lib\\nanfunctions.py:1215: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(a, axis, out=out, keepdims=keepdims)\n"
]
}
],
"execution_count": 6
},
{
"cell_type": "code",
"id": "53f86ddc0677a6d7",
"metadata": {
"scrolled": true,
"ExecuteTime": {
2025-03-01 18:10:27 +08:00
"end_time": "2025-03-01T10:02:11.106136Z",
"start_time": "2025-03-01T10:02:11.102732Z"
}
},
"source": [
"origin_columns = df.columns.tolist()\n",
"origin_columns = [col for col in origin_columns if col not in ['turnover_rate', 'pe_ttm', 'volume_ratio', 'l2_code']]\n",
"origin_columns = [col for col in origin_columns if col not in index_data.columns]\n"
],
"outputs": [],
"execution_count": 7
},
{
"cell_type": "code",
"id": "5f3d9aece75318cd",
"metadata": {
2025-03-01 18:10:27 +08:00
"scrolled": true,
"ExecuteTime": {
2025-03-01 18:10:27 +08:00
"end_time": "2025-03-01T10:03:15.010037Z",
"start_time": "2025-03-01T10:02:11.119055Z"
}
},
"source": [
"def filter_data(df):\n",
" # df = df.groupby('trade_date').apply(lambda x: x.nlargest(1000, 'act_factor1'))\n",
" df = df[~df['is_st']]\n",
" df = df[~df['ts_code'].str.endswith('BJ')]\n",
" df = df[~df['ts_code'].str.startswith('30')]\n",
" df = df[~df['ts_code'].str.startswith('68')]\n",
" df = df[~df['ts_code'].str.startswith('8')]\n",
" df = df.reset_index(drop=True)\n",
" return df\n",
"\n",
"\n",
"df = filter_data(df)\n",
"df = get_technical_factor(df)\n",
"df = get_act_factor(df)\n",
"df = get_money_flow_factor(df)\n",
"df = get_alpha_factor(df)\n",
"# df = df.merge(industry_df, on=['l2_code', 'trade_date'], how='left')\n",
"df = df.rename(columns={'l2_code': 'cat_l2_code'})\n",
"# df = df.merge(index_data, on='trade_date', how='left')\n",
"\n",
"print(df.info())"
],
"outputs": [
2025-03-01 18:10:27 +08:00
{
"name": "stderr",
"output_type": "stream",
"text": [
"C:\\Users\\liaozhaorun\\AppData\\Local\\Temp\\ipykernel_11956\\18893331.py:15: DeprecationWarning: DataFrameGroupBy.apply operated on the grouping columns. This behavior is deprecated, and in a future version of pandas the grouping columns will be excluded from the operation. Either pass `include_groups=False` to exclude the groupings or explicitly select the grouping columns after groupby to silence this warning.\n",
" df['atr_14'] = grouped.apply(\n",
"C:\\Users\\liaozhaorun\\AppData\\Local\\Temp\\ipykernel_11956\\18893331.py:19: DeprecationWarning: DataFrameGroupBy.apply operated on the grouping columns. This behavior is deprecated, and in a future version of pandas the grouping columns will be excluded from the operation. Either pass `include_groups=False` to exclude the groupings or explicitly select the grouping columns after groupby to silence this warning.\n",
" df['atr_6'] = grouped.apply(\n",
"C:\\Users\\liaozhaorun\\AppData\\Local\\Temp\\ipykernel_11956\\18893331.py:25: DeprecationWarning: DataFrameGroupBy.apply operated on the grouping columns. This behavior is deprecated, and in a future version of pandas the grouping columns will be excluded from the operation. Either pass `include_groups=False` to exclude the groupings or explicitly select the grouping columns after groupby to silence this warning.\n",
" df['obv'] = grouped.apply(\n",
"C:\\Users\\liaozhaorun\\AppData\\Local\\Temp\\ipykernel_11956\\18893331.py:28: DeprecationWarning: DataFrameGroupBy.apply operated on the grouping columns. This behavior is deprecated, and in a future version of pandas the grouping columns will be excluded from the operation. Either pass `include_groups=False` to exclude the groupings or explicitly select the grouping columns after groupby to silence this warning.\n",
" df['maobv_6'] = grouped.apply(\n",
"C:\\Users\\liaozhaorun\\AppData\\Local\\Temp\\ipykernel_11956\\18893331.py:34: DeprecationWarning: DataFrameGroupBy.apply operated on the grouping columns. This behavior is deprecated, and in a future version of pandas the grouping columns will be excluded from the operation. Either pass `include_groups=False` to exclude the groupings or explicitly select the grouping columns after groupby to silence this warning.\n",
" df['rsi_3'] = grouped.apply(\n",
"C:\\Users\\liaozhaorun\\AppData\\Local\\Temp\\ipykernel_11956\\18893331.py:37: DeprecationWarning: DataFrameGroupBy.apply operated on the grouping columns. This behavior is deprecated, and in a future version of pandas the grouping columns will be excluded from the operation. Either pass `include_groups=False` to exclude the groupings or explicitly select the grouping columns after groupby to silence this warning.\n",
" df['rsi_6'] = grouped.apply(\n",
"C:\\Users\\liaozhaorun\\AppData\\Local\\Temp\\ipykernel_11956\\18893331.py:40: DeprecationWarning: DataFrameGroupBy.apply operated on the grouping columns. This behavior is deprecated, and in a future version of pandas the grouping columns will be excluded from the operation. Either pass `include_groups=False` to exclude the groupings or explicitly select the grouping columns after groupby to silence this warning.\n",
" df['rsi_9'] = grouped.apply(\n",
"C:\\Users\\liaozhaorun\\AppData\\Local\\Temp\\ipykernel_11956\\18893331.py:146: DeprecationWarning: DataFrameGroupBy.apply operated on the grouping columns. This behavior is deprecated, and in a future version of pandas the grouping columns will be excluded from the operation. Either pass `include_groups=False` to exclude the groupings or explicitly select the grouping columns after groupby to silence this warning.\n",
" df['alpha_007'] = grouped.apply(lambda x: x['close'].rolling(5).corr(x['vol'])).reset_index(level=0, drop=True)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"<class 'pandas.core.frame.DataFrame'>\n",
2025-03-01 18:10:27 +08:00
"Index: 5573859 entries, 1962 to 5573857\n",
"Data columns (total 71 columns):\n",
" # Column Dtype \n",
"--- ------ ----- \n",
" 0 ts_code object \n",
" 1 trade_date datetime64[ns]\n",
" 2 open float64 \n",
" 3 close float64 \n",
" 4 high float64 \n",
" 5 low float64 \n",
" 6 vol float64 \n",
" 7 turnover_rate float64 \n",
" 8 pe_ttm float64 \n",
" 9 circ_mv float64 \n",
" 10 volume_ratio float64 \n",
" 11 is_st bool \n",
" 12 up_limit float64 \n",
" 13 down_limit float64 \n",
" 14 buy_sm_vol float64 \n",
" 15 sell_sm_vol float64 \n",
" 16 buy_lg_vol float64 \n",
" 17 sell_lg_vol float64 \n",
" 18 buy_elg_vol float64 \n",
" 19 sell_elg_vol float64 \n",
" 20 net_mf_vol float64 \n",
" 21 cat_l2_code object \n",
" 22 up float64 \n",
" 23 down float64 \n",
" 24 atr_14 float64 \n",
" 25 atr_6 float64 \n",
" 26 obv float64 \n",
" 27 maobv_6 float64 \n",
" 28 obv-maobv_6 float64 \n",
" 29 rsi_3 float64 \n",
" 30 rsi_6 float64 \n",
" 31 rsi_9 float64 \n",
" 32 return_5 float64 \n",
" 33 return_10 float64 \n",
" 34 return_20 float64 \n",
" 35 avg_close_5 float64 \n",
" 36 std_return_5 float64 \n",
" 37 std_return_15 float64 \n",
" 38 std_return_25 float64 \n",
" 39 std_return_90 float64 \n",
" 40 std_return_90_2 float64 \n",
" 41 std_return_5 / std_return_90 float64 \n",
" 42 std_return_5 / std_return_25 float64 \n",
" 43 std_return_90 - std_return_90_2 float64 \n",
" 44 ema_5 float64 \n",
" 45 ema_13 float64 \n",
" 46 ema_20 float64 \n",
" 47 ema_60 float64 \n",
" 48 act_factor1 float64 \n",
" 49 act_factor2 float64 \n",
" 50 act_factor3 float64 \n",
" 51 act_factor4 float64 \n",
" 52 cat_af1 bool \n",
" 53 cat_af2 bool \n",
" 54 cat_af3 bool \n",
" 55 cat_af4 bool \n",
" 56 act_factor5 float64 \n",
" 57 act_factor6 float64 \n",
" 58 rank_act_factor1 float64 \n",
" 59 rank_act_factor2 float64 \n",
" 60 rank_act_factor3 float64 \n",
" 61 active_buy_volume_large float64 \n",
" 62 active_buy_volume_big float64 \n",
" 63 active_buy_volume_small float64 \n",
" 64 buy_lg_vol_minus_sell_lg_vol float64 \n",
" 65 buy_elg_vol_minus_sell_elg_vol float64 \n",
" 66 log(circ_mv) float64 \n",
" 67 alpha_022 float64 \n",
" 68 alpha_003 float64 \n",
" 69 alpha_007 float64 \n",
" 70 alpha_013 float64 \n",
"dtypes: bool(5), datetime64[ns](1), float64(63), object(2)\n",
"memory usage: 2.8+ GB\n",
"None\n"
]
}
],
"execution_count": 8
},
{
"cell_type": "code",
"id": "0ebdfb92-d88b-4b5c-a715-675dab876fc0",
"metadata": {
"ExecuteTime": {
2025-03-01 18:10:27 +08:00
"end_time": "2025-03-01T10:03:15.119034Z",
"start_time": "2025-03-01T10:03:15.095810Z"
}
},
"source": [
"def create_deviation_within_dates(df, feature_columns):\n",
" groupby_col = 'cat_l2_code' # 使用 trade_date 进行分组\n",
" new_columns = {}\n",
" ret_feature_columns = feature_columns[:]\n",
"\n",
" # 自动选择所有数值型特征\n",
" num_features = [col for col in feature_columns if 'cat' not in col and 'index' not in col]\n",
2025-03-01 18:10:27 +08:00
" num_features = [col for col in feature_columns if 'cat' not in col and 'industry' not in col]\n",
"\n",
" # 遍历所有数值型特征\n",
" for feature in num_features:\n",
" if feature == 'trade_date': # 不需要对 'trade_date' 计算偏差\n",
" continue\n",
"\n",
2025-03-01 18:10:27 +08:00
" # grouped_median = df.groupby(['trade_date', groupby_col])[feature].transform('median')\n",
" # deviation_col_name = f'deviation_median_{feature}'\n",
" # new_columns[deviation_col_name] = df[feature] - grouped_median\n",
" # ret_feature_columns.append(deviation_col_name)\n",
"\n",
2025-03-01 18:10:27 +08:00
" grouped_mean = df.groupby(['trade_date', groupby_col])[feature].transform('mean')\n",
" deviation_col_name = f'deviation_mean_{feature}'\n",
" new_columns[deviation_col_name] = df[feature] - grouped_mean\n",
" ret_feature_columns.append(deviation_col_name)\n",
"\n",
" # 将新计算的偏差特征与原始 DataFrame 合并\n",
" df = pd.concat([df, pd.DataFrame(new_columns)], axis=1)\n",
"\n",
" # for feature in ['obv', 'return_20', 'act_factor1', 'act_factor2', 'act_factor3', 'act_factor4']:\n",
" # df[f'deviation_industry_{feature}'] = df[feature] - df[f'industry_{feature}']\n",
"\n",
" return df, ret_feature_columns\n"
],
"outputs": [],
2025-03-01 18:10:27 +08:00
"execution_count": 9
},
{
"cell_type": "code",
"id": "fbb968383f8cf2c7",
"metadata": {
"ExecuteTime": {
2025-03-01 18:10:27 +08:00
"end_time": "2025-03-01T10:03:33.404454Z",
"start_time": "2025-03-01T10:03:15.217954Z"
}
},
"source": [
2025-03-01 18:10:27 +08:00
"from sklearn.preprocessing import StandardScaler\n",
"\n",
"\n",
"def get_qcuts(series, quantiles):\n",
" q = pd.qcut(series, q=quantiles, labels=False, duplicates='drop')\n",
" return q[-1] # 返回窗口最后一个元素的分位数标签\n",
"\n",
"\n",
"window = 5\n",
"quantiles = 20\n",
"\n",
2025-03-01 18:10:27 +08:00
"\n",
"def calculate_risk_adjusted_target(df, days=5):\n",
" df = df.sort_values(by=['ts_code', 'trade_date'])\n",
"\n",
" df['future_close'] = df.groupby('ts_code')['close'].shift(-days)\n",
" df['future_return'] = (df['future_close'] - df['close']) / df['close']\n",
"\n",
2025-03-01 18:10:27 +08:00
" df['future_volatility'] = df.groupby('ts_code')['future_return'].rolling(days, min_periods=1).std().reset_index(\n",
" level=0, drop=True)\n",
" df['sharpe_ratio'] = df['future_return'] * df['future_volatility']\n",
" df['sharpe_ratio'].replace([np.inf, -np.inf], np.nan, inplace=True)\n",
"\n",
" return df['sharpe_ratio']\n",
"\n",
2025-03-01 18:10:27 +08:00
"\n",
"future_close = df.groupby('ts_code')['close'].shift(-4)\n",
"future_return = (future_close - df['close']) / df['close']\n",
"df['label'] = future_return\n",
"\n",
"# df = df.apply(lambda x: x.astype('float32') if x.dtype in ['float64', 'float32'] else x)\n",
"df = df.sort_values(by=['trade_date', 'ts_code'])\n",
"train_data = df[(df['trade_date'] <= '2023-01-01') & (df['trade_date'] >= '2016-01-01')]\n",
2025-03-01 18:10:27 +08:00
"test_data = df[(df['trade_date'] >= '2023-01-01') & (df['trade_date'] <= '2025-02-26')]\n",
"\n",
2025-03-01 18:10:27 +08:00
"train_data = train_data.groupby('trade_date', group_keys=False).apply(\n",
" lambda x: x.nlargest(1000, 'return_20')\n",
")\n",
"test_data = test_data.groupby('trade_date', group_keys=False).apply(\n",
" lambda x: x.nlargest(1000, 'return_20')\n",
")\n",
"\n",
2025-03-01 18:10:27 +08:00
"train_data = train_data.merge(industry_df, on=['cat_l2_code', 'trade_date'], how='left')\n",
"train_data = train_data.merge(index_data, on='trade_date', how='left')\n",
"test_data = test_data.merge(industry_df, on=['cat_l2_code', 'trade_date'], how='left')\n",
"test_data = test_data.merge(index_data, on='trade_date', how='left')\n",
"\n",
2025-03-01 18:10:27 +08:00
"train_data, test_data = train_data.replace([np.inf, -np.inf], np.nan), test_data.replace([np.inf, -np.inf], np.nan)"
],
2025-03-01 18:10:27 +08:00
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"C:\\Users\\liaozhaorun\\AppData\\Local\\Temp\\ipykernel_11956\\1029324748.py:36: DeprecationWarning: DataFrameGroupBy.apply operated on the grouping columns. This behavior is deprecated, and in a future version of pandas the grouping columns will be excluded from the operation. Either pass `include_groups=False` to exclude the groupings or explicitly select the grouping columns after groupby to silence this warning.\n",
" train_data = train_data.groupby('trade_date', group_keys=False).apply(\n",
"C:\\Users\\liaozhaorun\\AppData\\Local\\Temp\\ipykernel_11956\\1029324748.py:39: DeprecationWarning: DataFrameGroupBy.apply operated on the grouping columns. This behavior is deprecated, and in a future version of pandas the grouping columns will be excluded from the operation. Either pass `include_groups=False` to exclude the groupings or explicitly select the grouping columns after groupby to silence this warning.\n",
" test_data = test_data.groupby('trade_date', group_keys=False).apply(\n"
]
}
],
"execution_count": 10
},
{
"cell_type": "code",
"id": "de8c2f6c770d2439",
"metadata": {
"ExecuteTime": {
2025-03-01 18:10:27 +08:00
"end_time": "2025-03-01T10:03:33.451800Z",
"start_time": "2025-03-01T10:03:33.446956Z"
}
},
"source": [
2025-03-01 18:10:27 +08:00
"feature_columns = [col for col in train_data.columns if col not in ['trade_date',\n",
" 'ts_code',\n",
" 'label']]\n",
"feature_columns = [col for col in feature_columns if 'future' not in col]\n",
"feature_columns = [col for col in feature_columns if 'score' not in col]\n",
"feature_columns = [col for col in feature_columns if col not in origin_columns]\n",
"feature_columns = [col for col in feature_columns if not col.startswith('_')]\n",
"print(feature_columns)"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2025-03-01 18:10:27 +08:00
"['turnover_rate', 'pe_ttm', 'volume_ratio', 'cat_l2_code', 'up', 'down', 'atr_14', 'atr_6', 'obv', 'maobv_6', 'obv-maobv_6', 'rsi_3', 'rsi_6', 'rsi_9', 'return_5', 'return_10', 'return_20', 'avg_close_5', 'std_return_5', 'std_return_15', 'std_return_25', 'std_return_90', 'std_return_90_2', 'std_return_5 / std_return_90', 'std_return_5 / std_return_25', 'std_return_90 - std_return_90_2', 'ema_5', 'ema_13', 'ema_20', 'ema_60', 'act_factor1', 'act_factor2', 'act_factor3', 'act_factor4', 'cat_af1', 'cat_af2', 'cat_af3', 'cat_af4', 'act_factor5', 'act_factor6', 'rank_act_factor1', 'rank_act_factor2', 'rank_act_factor3', 'active_buy_volume_large', 'active_buy_volume_big', 'active_buy_volume_small', 'buy_lg_vol_minus_sell_lg_vol', 'buy_elg_vol_minus_sell_elg_vol', 'log(circ_mv)', 'alpha_022', 'alpha_003', 'alpha_007', 'alpha_013', 'industry_obv', 'industry_return_5', 'industry_return_20', 'industry_ema_5', 'industry_ema_13', 'industry_ema_20', 'industry_ema_60', 'industry_act_factor1', 'industry_act_factor2', 'industry_act_factor3', 'industry_act_factor4', 'industry_act_factor5', 'industry_act_factor6', 'industry_rank_act_factor1', 'industry_rank_act_factor2', 'industry_rank_act_factor3', 'industry_obv_deviation', 'industry_return_5_deviation', 'industry_return_20_deviation', 'industry_act_factor1_deviation', 'industry_act_factor2_deviation', 'industry_act_factor3_deviation', 'industry_act_factor4_deviation', 'industry_return_5_percentile', '000852.SH_MACD', '000905.SH_MACD', '399006.SZ_MACD', '000852.SH_MACD_hist', '000905.SH_MACD_hist', '399006.SZ_MACD_hist', '000852.SH_RSI', '000905.SH_RSI', '399006.SZ_RSI', '000852.SH_Signal_line', '000905.SH_Signal_line', '399006.SZ_Signal_line', '000852.SH_daily_return', '000905.SH_daily_return', '399006.SZ_daily_return']\n"
]
}
],
2025-03-01 18:10:27 +08:00
"execution_count": 11
},
{
"cell_type": "code",
"id": "20ffa7229c9d2f86",
"metadata": {
"ExecuteTime": {
2025-03-01 18:10:27 +08:00
"end_time": "2025-03-01T10:03:47.388358Z",
"start_time": "2025-03-01T10:03:33.451800Z"
}
},
"source": [
"feature_columns_new = feature_columns[:]\n",
"train_data, feature_columns_new = create_deviation_within_dates(train_data, feature_columns)\n",
"print(f'feature_columns size: {len(feature_columns_new)}')\n",
"test_data, feature_columns_new = create_deviation_within_dates(test_data, feature_columns)\n",
"print(f'feature_columns size: {len(feature_columns_new)}')\n",
"\n",
"train_data = train_data.dropna(subset=feature_columns_new)\n",
"train_data = train_data.dropna(subset=['label'])\n",
"train_data = train_data.reset_index(drop=True)\n",
"\n",
"test_data = test_data.dropna(subset=feature_columns_new)\n",
2025-03-01 18:10:27 +08:00
"# test_data = test_data.dropna(subset=['label'])\n",
"test_data = test_data.reset_index(drop=True)\n",
"\n",
"print(len(train_data))\n",
"print(f\"最小日期: {train_data['trade_date'].min().strftime('%Y-%m-%d')}\")\n",
"print(f\"最大日期: {train_data['trade_date'].max().strftime('%Y-%m-%d')}\")\n",
"print(len(test_data))\n",
"print(f\"最小日期: {test_data['trade_date'].min().strftime('%Y-%m-%d')}\")\n",
"print(f\"最大日期: {test_data['trade_date'].max().strftime('%Y-%m-%d')}\")"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2025-03-01 18:10:27 +08:00
"feature_columns size: 155\n",
"feature_columns size: 155\n",
"1166756\n",
"最小日期: 2017-04-06\n",
"最大日期: 2022-12-30\n",
2025-03-01 18:10:27 +08:00
"408601\n",
"最小日期: 2023-01-03\n",
2025-03-01 18:10:27 +08:00
"最大日期: 2025-02-26\n"
]
}
],
2025-03-01 18:10:27 +08:00
"execution_count": 12
},
{
"cell_type": "code",
"id": "35238cb4f45ce756",
"metadata": {
"ExecuteTime": {
2025-03-01 18:10:27 +08:00
"end_time": "2025-03-01T10:03:47.536921Z",
"start_time": "2025-03-01T10:03:47.432730Z"
}
},
"source": [
"cat_columns = [col for col in df.columns if col.startswith('cat')]\n",
"for col in cat_columns:\n",
" train_data[col] = train_data[col].astype('category')\n",
" test_data[col] = test_data[col].astype('category')"
],
"outputs": [],
2025-03-01 18:10:27 +08:00
"execution_count": 13
},
{
"cell_type": "code",
"id": "8f134d435f71e9e2",
"metadata": {
"jupyter": {
"source_hidden": true
},
"ExecuteTime": {
2025-03-01 18:10:27 +08:00
"end_time": "2025-03-01T10:03:48.200997Z",
"start_time": "2025-03-01T10:03:47.581003Z"
}
},
"source": [
"from catboost import Pool\n",
"import lightgbm as lgb\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
2025-03-01 18:10:27 +08:00
"\n",
"\n",
"def train_light_model(train_data_df, test_data_df, params, feature_columns, callbacks, evals,\n",
" print_feature_importance=True, num_boost_round=100,\n",
" use_optuna=False):\n",
" train_data_df, test_data_df = train_data_df.dropna(subset=['label']), test_data_df.dropna(subset=['label'])\n",
2025-03-01 18:10:27 +08:00
" categorical_feature = [i for i, col in enumerate(feature_columns) if col.startswith('cat')]\n",
" print(f'categorical_feature: {categorical_feature}')\n",
"\n",
" X_train = train_data_df[feature_columns]\n",
" y_train = train_data_df['label']\n",
"\n",
" X_val = test_data_df[feature_columns]\n",
" y_val = test_data_df['label']\n",
"\n",
2025-03-01 18:10:27 +08:00
" scaler = StandardScaler()\n",
" numeric_columns = X_train.select_dtypes(include=['float64', 'int64']).columns\n",
" X_train.loc[:, numeric_columns] = scaler.fit_transform(X_train[numeric_columns])\n",
" X_val.loc[:, numeric_columns] = scaler.transform(X_val[numeric_columns])\n",
"\n",
" train_data = lgb.Dataset(X_train, label=y_train, categorical_feature=categorical_feature)\n",
" val_data = lgb.Dataset(X_val, label=y_val, categorical_feature=categorical_feature)\n",
" model = lgb.train(\n",
" params, train_data, num_boost_round=num_boost_round,\n",
" valid_sets=[train_data, val_data], valid_names=['train', 'valid'],\n",
" callbacks=callbacks\n",
" )\n",
"\n",
" if print_feature_importance:\n",
" lgb.plot_metric(evals)\n",
" # lgb.plot_tree(model, figsize=(20, 8))\n",
" lgb.plot_importance(model, importance_type='split', max_num_features=20)\n",
" plt.show()\n",
2025-03-01 18:10:27 +08:00
" return model, scaler\n",
"\n",
"\n",
"from catboost import CatBoostClassifier\n",
"import pandas as pd\n",
"\n",
"\n",
"def train_catboost(train_data_df, test_data_df, feature_columns, params=None, plot=False):\n",
" train_data_df, test_data_df = train_data_df.dropna(subset=['label']), test_data_df.dropna(subset=['label'])\n",
" X_train = train_data_df[feature_columns]\n",
" y_train = train_data_df['label']\n",
"\n",
" X_val = test_data_df[feature_columns]\n",
" y_val = test_data_df['label']\n",
"\n",
" cat_features = [i for i, col in enumerate(feature_columns) if col.startswith('cat')]\n",
" print(f'cat_features: {cat_features}')\n",
" train_pool = Pool(data=X_train, label=y_train, cat_features=cat_features)\n",
" val_pool = Pool(data=X_val, label=y_val, cat_features=cat_features)\n",
"\n",
" model = CatBoostClassifier(**params)\n",
" model.fit(train_pool,\n",
" eval_set=val_pool, plot=plot)\n",
"\n",
" return model"
],
"outputs": [],
2025-03-01 18:10:27 +08:00
"execution_count": 14
},
{
"cell_type": "code",
"id": "4a4542e1ed6afe7d",
"metadata": {
"ExecuteTime": {
2025-03-01 18:10:27 +08:00
"end_time": "2025-03-01T10:03:48.253523Z",
"start_time": "2025-03-01T10:03:48.248246Z"
}
},
"source": [
"light_params = {\n",
" 'objective': 'binary',\n",
" 'metric': 'average_precision',\n",
" 'learning_rate': 0.05,\n",
" 'is_unbalance': True,\n",
" 'num_leaves': 2048,\n",
" 'min_data_in_leaf': 1024,\n",
" 'max_depth': 32,\n",
" 'max_bin': 1024,\n",
2025-03-01 18:10:27 +08:00
" # 'feature_fraction': 0.5,\n",
" # 'bagging_fraction': 0.5,\n",
" # 'bagging_freq': 5,\n",
" # 'lambda_l1': 80,\n",
" # 'lambda_l2': 65,\n",
" 'verbosity': -1,\n",
2025-03-01 18:10:27 +08:00
" 'num_threads': 16\n",
"}"
],
"outputs": [],
2025-03-01 18:10:27 +08:00
"execution_count": 15
},
{
"cell_type": "code",
"id": "beeb098799ecfa6a",
"metadata": {
"ExecuteTime": {
2025-03-01 18:10:27 +08:00
"end_time": "2025-03-01T10:05:12.645870Z",
"start_time": "2025-03-01T10:03:48.301057Z"
}
},
"source": [
"print('train data size: ', len(train_data))\n",
2025-03-01 18:10:27 +08:00
"\n",
"evals = {}\n",
2025-03-01 18:10:27 +08:00
"model, scaler = train_light_model(train_data, test_data, light_params, feature_columns_new,\n",
" [lgb.log_evaluation(period=500),\n",
" lgb.callback.record_evaluation(evals),\n",
" lgb.early_stopping(50, first_metric_only=True)\n",
" ], evals,\n",
" num_boost_round=1000, use_optuna=False,\n",
" print_feature_importance=True)"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2025-03-01 18:10:27 +08:00
"train data size: 1166756\n",
"categorical_feature: [3, 34, 35, 36, 37]\n",
"Training until validation scores don't improve for 50 rounds\n",
"Early stopping, best iteration is:\n",
2025-03-01 18:10:27 +08:00
"[92]\ttrain's average_precision: 0.816465\tvalid's average_precision: 0.50322\n",
"Evaluated only: average_precision\n"
]
},
{
"data": {
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
],
2025-03-01 18:10:27 +08:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAj0AAAHHCAYAAABUcOnjAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjEsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvc2/+5QAAAAlwSFlzAAAPYQAAD2EBqD+naQAAX21JREFUeJzt3Xl8E2X+B/BP7jRt05aetBTKJTcFyiGgIgoiKAquFyCXK65CV7CLCIqcK1VUrCIr6g/EdWVldRFdObRUDkEE5L5vaCm9S+/mnt8faUND0pKmadMyn/frFZs8mZk8+TY0H595ZkYiCIIAIiIiotuc1NsdICIiImoIDD1EREQkCgw9REREJAoMPURERCQKDD1EREQkCgw9REREJAoMPURERCQKDD1EREQkCgw9REREJAoMPUSENWvWQCKR4PLly/X2GgsWLIBEImky2/W2y5cvQyKRYM2aNW6tL5FIsGDBAo/2iaipY+ghakCV4UIikWDXrl0OzwuCgOjoaEgkEjz88MNuvcY//vEPt78oqXbWrl2LpKQkb3eDiFzE0EPkBWq1GmvXrnVo37FjB65evQqVSuX2tt0JPePHj0d5eTlatWrl9ut6y9y5c1FeXu6V167P0NOqVSuUl5dj/Pjxbq1fXl6OuXPnerhXRE0bQw+RF4wYMQLffPMNTCaTXfvatWsRFxeHiIiIBulHaWkpAEAmk0GtVjep3USVfZfL5VCr1V7uza3pdDpYLBaXl5dIJFCr1ZDJZG69nlqthlwud2tdotsVQw+RF4wZMwZ5eXlITk62tRkMBnz77bcYO3as03UsFguSkpLQpUsXqNVqhIeH4y9/+QuuX79uWyYmJgYnTpzAjh07bLvR7r33XgA3dq3t2LEDU6dORVhYGFq0aGH33M1zejZv3oxBgwbB398fWq0Wffr0cTpCdbNdu3ahT58+UKvVaNu2LT755BOHZWqas3LzfJTKeTsnT57E2LFjERQUhLvuusvuuZvXj4+Px4YNG9C1a1eoVCp06dIFW7ZscXit7du3o3fv3nZ9dWWe0L333ouNGzfiypUrtlrHxMTYtimRSPD1119j7ty5iIqKgkajQVFREfLz8zFz5kx069YNfn5+0Gq1GD58OI4cOXLL+kyaNAl+fn5IT0/HqFGj4Ofnh9DQUMycORNms9mlGp4/fx6TJk1CYGAgAgICMHnyZJSVldmtW15ejpdeegkhISHw9/fHI488gvT0dM4ToiaP/xtA5AUxMTHo378//v3vf2P48OEArAGjsLAQTz/9ND788EOHdf7yl79gzZo1mDx5Ml566SVcunQJH330EQ4dOoTdu3dDoVAgKSkJf/3rX+Hn54fXX38dABAeHm63nalTpyI0NBTz5s2zjZY4s2bNGjz77LPo0qUL5syZg8DAQBw6dAhbtmypNpgBwLFjx/DAAw8gNDQUCxYsgMlkwvz58x364Y4nnngC7du3x5IlSyAIQo3L7tq1C+vXr8fUqVPh7++PDz/8EH/605+QmpqK4OBgAMChQ4fw4IMPonnz5li4cCHMZjMWLVqE0NDQW/bl9ddfR2FhIa5evYr3338fAODn52e3zOLFi6FUKjFz5kzo9XoolUqcPHkSGzZswBNPPIHWrVsjKysLn3zyCQYNGoSTJ08iMjKyxtc1m80YNmwY+vXrh3fffRdbt27Fe++9h7Zt2+LFF1+8Zb+ffPJJtG7dGomJiTh48CD+7//+D2FhYXj77bdty0yaNAn/+c9/MH78eNx5553YsWMHHnrooVtum6jRE4iowXz++ecCAGH//v3CRx99JPj7+wtlZWWCIAjCE088IQwePFgQBEFo1aqV8NBDD9nW+/XXXwUAwldffWW3vS1btji0d+nSRRg0aFC1r33XXXcJJpPJ6XOXLl0SBEEQCgoKBH9/f6Ffv35CeXm53bIWi6XG9zhq1ChBrVYLV65csbWdPHlSkMlkQtU/OZcuXRIACJ9//rnDNgAI8+fPtz2eP3++AEAYM2aMw7KVz928vlKpFM6fP29rO3LkiABAWL58ua1t5MiRgkajEdLT021t586dE+RyucM2nXnooYeEVq1aObRv27ZNACC0adPG9vutpNPpBLPZbNd26dIlQaVSCYsWLbJru7k+EydOFADYLScIgtCzZ08hLi7OoQbOavjss8/aLTd69GghODjY9vjAgQMCAGHGjBl2y02aNMlhm0RNDXdvEXnJk08+ifLycvz4448oLi7Gjz/+WO0IyjfffIOAgAAMHToUubm5tltcXBz8/Pywbds2l193ypQpt5wnkpycjOLiYsyePdthvkxNu33MZjN++uknjBo1Ci1btrS1d+rUCcOGDXO5j9V54YUXXF52yJAhaNu2re1x9+7dodVqcfHiRVtft27dilGjRtmNrrRr1842+lZXEydOhI+Pj12bSqWCVCq19SEvLw9+fn7o0KEDDh486NJ2b67D3XffbXtf7qybl5eHoqIiALDtApw6dardcn/9619d2j5RY8bdW0ReEhoaiiFDhmDt2rUoKyuD2WzG448/7nTZc+fOobCwEGFhYU6fz87Odvl1W7dufctlLly4AADo2rWry9sFgJycHJSXl6N9+/YOz3Xo0AGbNm2q1fZu5krfK1UNXZWCgoJsc6Cys7NRXl6Odu3aOSznrM0dzvprsVjwwQcf4B//+AcuXbpkNxencrdbTdRqtcPut6rv61ZurktQUBAA4Pr169Bqtbhy5QqkUqlD3z1VEyJvYugh8qKxY8diypQpyMzMxPDhwxEYGOh0OYvFgrCwMHz11VdOn3dlDkqlm0cevKW6EaObJ+RWVZu+VzeaJdxiLpAnOevvkiVL8MYbb+DZZ5/F4sWL0axZM0ilUsyYMcOlo7vcPZrrVus3ZF2IvIWhh8iLRo8ejb/85S/4/fffsW7dumqXa9u2LbZu3YqBAwfe8ovfE4edV+4WOn78eK3+Dz80NBQ+Pj44d+6cw3Nnzpyxe1w5wlBQUGDXfuXKlVr21j1hYWFQq9U4f/68w3PO2pxxp9bffvstBg8ejFWrVtm1FxQUICQkpNbb87RWrVrBYrHg0qVLdiN2rtaEqDHjnB4iL/Lz88PHH3+MBQsWYOTIkdUu9+STT8JsNmPx4sUOz5lMJrvg4Ovr6xAkauuBBx6Av78/EhMTodPp7J6raURAJpNh2LBh2LBhA1JTU23tp06dwk8//WS3rFarRUhICHbu3GnX/o9//KNOfXeVTCbDkCFDsGHDBly7ds3Wfv78eWzevNmlbfj6+qKwsLDWr3tzDb/55hukp6fXajv1pXLu1c2/h+XLl3ujO0QexZEeIi+bOHHiLZcZNGgQ/vKXvyAxMRGHDx/GAw88AIVCgXPnzuGbb77BBx98YJsPFBcXh48//hh///vf0a5dO4SFheG+++6rVZ+0Wi3ef/99PPfcc+jTp4/t3DhHjhxBWVkZvvjii2rXXbhwIbZs2YK7774bU6dOhclkwvLly9GlSxccPXrUbtnnnnsOb731Fp577jn07t0bO3fuxNmzZ2vV17pYsGABfv75ZwwcOBAvvvgizGYzPvroI3Tt2hWHDx++5fpxcXFYt24dEhIS0KdPH/j5+dUYXgHg4YcfxqJFizB58mQMGDAAx44dw1dffYU2bdp46F3VTVxcHP70pz8hKSkJeXl5tkPWK38vTekElkQ3Y+ghaiJWrlyJuLg4fPLJJ3jttdcgl8sRExODZ555BgMHDrQtN2/ePFy5cgVLly5FcXExBg0aVOvQAwB//vOfERYWhrfeeguLFy+GQqFAx44d8fLLL9e4Xvfu3fHTTz8hISEB8+bNQ4sWLbBw4UJkZGQ4hJ558+YhJycH3377Lf7zn/9g+PDh2Lx5c7UTtj0tLi4OmzdvxsyZM/HGG28gOjoaixYtwqlTp3D69Olbrj916lQcPnwYn3/+Od5//320atXqlqHntddeQ2lpKdauXYt169ahV69e2LhxI2bPnu2pt1Vn//znPxEREYF///vf+O677zBkyBCsW7cOHTp0aBJnvyaqjkTg7DUiIjujRo3CiRMnnM5NEqv
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
],
2025-03-01 18:10:27 +08:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAscAAAHHCAYAAABTHvWzAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjEsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvc2/+5QAAAAlwSFlzAAAPYQAAD2EBqD+naQABAABJREFUeJzs3XdYFcfewPHvoRcBpYkNsCMqiFgAC1gAldhi1GjsLaBE0QQRbNiNvcYSC96rXo3GlogFFTv23hVboqBiFFQiHGDePwz7egIoJBaQ+TzPeZLdmd2d2SE5vzM7O6MSQggkSZIkSZIkSULrYxdAkiRJkiRJkvILGRxLkiRJkiRJ0l9kcCxJkiRJkiRJf5HBsSRJkiRJkiT9RQbHkiRJkiRJkvQXGRxLkiRJkiRJ0l9kcCxJkiRJkiRJf5HBsSRJkiRJkiT9RQbHkiRJkiRJkvQXGRxLkiRJn6yIiAhUKhW3b9/+2EWRJKmAkMGxJEnSJyQzGMzuM2zYsPdyzcOHDxMeHs7Tp0/fy/kLs+TkZMLDw9m7d+/HLookFRo6H7sAkiRJ0rs3duxYypYtq7GvWrVq7+Vahw8fZsyYMfTo0YOiRYu+l2v8U127duXLL79EX1//YxflH0lOTmbMmDEAeHl5fdzCSFIhIYNjSZKkT1Dz5s2pVavWxy7Gv/LixQuMjY3/1Tm0tbXR1tZ+RyX6cDIyMkhNTf3YxZCkQkkOq5AkSSqEtm3bRoMGDTA2NsbExAQ/Pz8uXryokefcuXP06NGDcuXKYWBggI2NDb169eLx48dKnvDwcIKDgwEoW7asMoTj9u3b3L59G5VKRURERJbrq1QqwsPDNc6jUqm4dOkSnTt3plixYtSvX19JX7lyJa6urhgaGmJubs6XX37Jb7/99tZ6Zjfm2N7ens8++4y9e/dSq1YtDA0NqV69ujJ0YcOGDVSvXh0DAwNcXV05ffq0xjl79OhBkSJFuHnzJr6+vhgbG1OyZEnGjh2LEEIj74sXL/j2228pU6YM+vr6VK5cmWnTpmXJp1KpCAwMZNWqVVStWhV9fX0WLlyIlZUVAGPGjFHubeZ9y037vH5vb9y4ofTum5mZ0bNnT5KTk7Pcs5UrV1KnTh2MjIwoVqwYDRs2ZOfOnRp5cvP3I0kFlew5liRJ+gQlJiaSkJCgsc/S0hKA//73v3Tv3h1fX1++//57kpOTWbBgAfXr1+f06dPY29sDEBUVxc2bN+nZsyc2NjZcvHiRxYsXc/HiRY4cOYJKpeLzzz/n2rVr/O9//2PmzJnKNaysrHj06FGey92+fXsqVqzIxIkTlQBywoQJjBw5kg4dOtCnTx8ePXrE3LlzadiwIadPn/5HQzlu3LhB586d+frrr+nSpQvTpk2jZcuWLFy4kLCwMPr37w/ApEmT6NChA1evXkVL6//7k9LT02nWrBlubm5MmTKF7du3M3r0aNLS0hg7diwAQghatWpFdHQ0vXv3pkaNGuzYsYPg4GDu3bvHzJkzNcq0Z88efvrpJwIDA7G0tMTZ2ZkFCxYQEBBA27Zt+fzzzwFwcnICctc+r+vQoQNly5Zl0qRJnDp1iiVLlmBtbc3333+v5BkzZgzh4eF4eHgwduxY9PT0OHr0KHv27MHHxwfI/d+PJBVYQpIkSfpkLF++XADZfoQQ4tmzZ6Jo0aKib9++GsfFx8cLMzMzjf3JyclZzv+///1PAGL//v3KvqlTpwpA3Lp1SyPvrVu3BCCWL1+e5TyAGD16tLI9evRoAYhOnTpp5Lt9+7bQ1tYWEyZM0Nh//vx5oaOjk2V/Tvfj9bLZ2dkJQBw+fFjZt2PHDgEIQ0NDcefOHWX/okWLBCCio6OVfd27dxeA+Oabb5R9GRkZws/PT+jp6YlHjx4JIYTYtGmTAMT48eM1yvTFF18IlUolbty4oXE/tLS0xMWLFzXyPnr0KMu9ypTb9sm8t7169dLI27ZtW2FhYaFsX79+XWhpaYm2bduK9PR0jbwZGRlCiLz9/UhSQSWHVUiSJH2C5s+fT1RUlMYHXvU2Pn36lE6dOpGQkKB8tLW1qVu3LtHR0co5DA0NlX9/+fIlCQkJuLm5AXDq1Kn3Um5/f3+N7Q0bNpCRkUGHDh00ymtjY0PFihU1ypsXjo6OuLu7K9t169YFoHHjxtja2mbZf/PmzSznCAwMVP49c1hEamoqu3btAiAyMhJtbW0GDhyocdy3336LEIJt27Zp7Pf09MTR0THXdchr+/z93jZo0IDHjx+TlJQEwKZNm8jIyGDUqFEaveSZ9YO8/f1IUkElh1VIkiR9gurUqZPtC3nXr18HXgWB2TE1NVX+/Y8//mDMmDGsWbOGhw8fauRLTEx8h6X9f3+fYeP69esIIahYsWK2+XV1df/RdV4PgAHMzMwAKFOmTLb7nzx5orFfS0uLcuXKaeyrVKkSgDK++c6dO5QsWRITExONfFWqVFHSX/f3ur9NXtvn73UuVqwY8KpupqamxMbGoqWl9cYAPS9/P5JUUMngWJIkqRDJyMgAXo0btbGxyZKuo/P/XwsdOnTg8OHDBAcHU6NGDYoUKUJGRgbNmjVTzvMmfx/zmik9PT3HY17vDc0sr0qlYtu2bdnOOlGkSJG3liM7Oc1gkdN+8bcX6N6Hv9f9bfLaPu+ibnn5+5Gkgkr+FUuSJBUi5cuXB8Da2pqmTZvmmO/Jkyfs3r2bMWPGMGrUKGV/Zs/h63IKgjN7Jv++OMjfe0zfVl4hBGXLllV6ZvODjIwMbt68qVGma9euASgvpNnZ2bFr1y6ePXum0Xt85coVJf1tcrq3eWmf3CpfvjwZGRlcunSJGjVq5JgH3v73I0kFmRxzLEmSVIj4+vpiamrKxIkTUavVWdIzZ5jI7GX8e6/irFmzshyTORfx34NgU1NTLC0t2b9/v8b+H374Idfl/fzzz9HW1mbMmDFZyiKEyDJt2Yc0b948jbLMmzcPXV1dmjRpAkCLFi1IT0/XyAcwc+ZMVCoVzZs3f+s1jIyMgKz3Ni/tk1tt2rRBS0uLsWPHZul5zrxObv9+JKkgkz3HkiRJhYipqSkLFiyga9eu1KxZky+//BIrKyvu3r3L1q1bqVevHvPmzcPU1JSGDRsyZcoU1Go1pUqVYufOndy6dSvLOV1dXQEYPnw4X375Jbq6urRs2RJjY2P69OnD5MmT6dOnD7Vq1WL//v1KD2tulC9fnvHjxxMaGsrt27dp06YNJiYm3Lp1i40bN9KvXz++++67d3Z/csvAwIDt27fTvXt36taty7Zt29i6dSthYWHK3MQtW7akUaNGDB8+nNu3b+Ps7MzOnTvZvHkzQUFBSi/smxgaGuLo6MjatWupVKkS5ubmVKtWjWrVquW6fXKrQoUKDB8+nHHjxtGgQQM+//xz9PX1OX78OCVLlmTSpEm5/vuRpALtI82SIUmSJL0HmVOXHT9+/I35oqOjha+vrzAzMxMGBgaifPnyokePHuLEiRNKnt9//120bdtWFC1aVJiZmYn27duL+/fvZzu12Lhx40SpUqWElpaWxtRpycnJonfv3sLMzEyYmJiIDh06iIcPH+Y4lVvmNGh/9/PPP4v69esLY2NjYWxsLBwcHMSAAQPE1atXc3U//j6Vm5+fX5a8gBgwYIDGvszp6KZOnars6969uzA2NhaxsbHCx8dHGBkZieLFi4vRo0dnmQLt2bNnYvDgwaJkyZJCV1dXVKxYUUydOlWZGu1N1850+PBh4erqKvT09DTuW27bJ6d7m929EUKIZcuWCRcXF6Gvry+KFSsmPD09RVRUlEae3Pz9SFJBpRLiA7xlIEmSJEmfiB49erB+/XqeP3/+sYsiSdJ7IMccS5IkSZIkSdJfZHAsSZIkSZIkSX+RwbEkSZIkSZIk/UWOOZYkSZIkSZKkv8ieY0mSJEmSJEn6iwyOJUmSJEmSJOkvchEQScqDjIwM7t+/j4mJSY7LukqSJEmSlL8IIXj27BklS5ZES+vNfcMyOJakPLh//z5lypT52MWQJEmSJOkf+O233yhduvQb88jgWJLywMTEBIBbt25hbm7+kUsjvY1arWb
},
"metadata": {},
"output_type": "display_data"
}
],
2025-03-01 18:10:27 +08:00
"execution_count": 16
},
{
"cell_type": "code",
"id": "445dff84-70b2-4fc9-a9b6-1251993324d6",
"metadata": {
"ExecuteTime": {
2025-03-01 18:10:27 +08:00
"end_time": "2025-03-01T10:05:12.793641Z",
"start_time": "2025-03-01T10:05:12.787565Z"
}
},
"source": [
"# catboost_params = {\n",
"# 'loss_function': 'CrossEntropy', # 适用于二分类\n",
"# 'eval_metric': 'AUC', # 评估指标\n",
"# 'iterations': 1000,\n",
"# 'learning_rate': 0.01,\n",
"# 'depth': , # 控制模型复杂度\n",
"# # 'l2_leaf_reg': 3, # L2 正则化\n",
"# 'verbose': 500,\n",
"# 'early_stopping_rounds': 100,\n",
"# # 'one_hot_max_size': 50,\n",
"# # 'class_weights': [0.6, 1.2]\n",
"# # 'task_type': 'GPU'\n",
"# }\n",
"\n",
"# model = train_catboost(train_data, test_data, feature_columns_new, catboost_params, plot=True)"
],
"outputs": [],
2025-03-01 18:10:27 +08:00
"execution_count": 17
},
{
"cell_type": "code",
2025-03-01 18:10:27 +08:00
"id": "7bc246ddd6b2cdd1",
"metadata": {
"ExecuteTime": {
2025-03-01 18:10:27 +08:00
"end_time": "2025-03-01T10:05:12.823810Z",
"start_time": "2025-03-01T10:05:12.801658Z"
}
},
"source": [
"from tqdm import tqdm\n",
"\n",
"\n",
2025-03-01 18:10:27 +08:00
"def incremental_training(test_data: pd.DataFrame,\n",
" model,\n",
" scaler,\n",
" days: int,\n",
" back_days: int,\n",
" feature_columns: list,\n",
" params: dict,\n",
" model_type: str = 'lightgbm',\n",
" ):\n",
" if model_type not in ['lightgbm', 'catboost']:\n",
" raise ValueError(\"model_type must be either 'lightgbm' or 'catboost'\")\n",
"\n",
2025-03-01 18:10:27 +08:00
" test_data = test_data.sort_values(by='trade_date')\n",
" scores = []\n",
" unique_trade_dates = sorted(test_data['trade_date'].unique())\n",
2025-03-01 18:10:27 +08:00
"\n",
" new_model = None\n",
" for i in tqdm(range(0, len(unique_trade_dates))):\n",
" # Get the current window of trade dates\n",
2025-03-01 18:10:27 +08:00
" current_dates = [unique_trade_dates[i]]\n",
" window_data = test_data[test_data['trade_date'].isin(current_dates)]\n",
" X = window_data[feature_columns]\n",
2025-03-01 18:10:27 +08:00
" numeric_columns = X.select_dtypes(include=['float64', 'int64']).columns\n",
" X.loc[:, numeric_columns] = scaler.transform(X[numeric_columns])\n",
"\n",
2025-03-01 18:10:27 +08:00
" if new_model is not None:\n",
" window_scores = new_model.predict(X, prediction_type='RawFormulaVal')\n",
" else:\n",
" window_scores = model.predict(X, prediction_type='RawFormulaVal')\n",
" scores.extend(window_scores)\n",
"\n",
2025-03-01 18:10:27 +08:00
" # # Prepare data for incremental training\n",
" # current_dates = unique_trade_dates[max(0, i - back_days):i + days]\n",
" # window_data = test_data[test_data['trade_date'].isin(current_dates)]\n",
" # X_train = window_data[feature_columns]\n",
" current_dates = unique_trade_dates[max(0, i - days):i + 1]\n",
" window_data = test_data[test_data['trade_date'].isin(current_dates)]\n",
" X_train = window_data[feature_columns]\n",
2025-03-01 18:10:27 +08:00
" y_train = window_data['label'] # Assuming 'label' is what you're predicting\n",
" # Incrementally train the model\n",
" if len(y_train.unique()) > 1:\n",
" numeric_columns = X.select_dtypes(include=['float64', 'int64']).columns\n",
" X_train.loc[:, numeric_columns] = scaler.transform(X_train[numeric_columns])\n",
" if model_type == 'lightgbm':\n",
" categorical_feature = [i for i, col in enumerate(feature_columns) if col.startswith('cat')]\n",
" train_data = lgb.Dataset(X_train, label=y_train, categorical_feature=categorical_feature)\n",
" new_model = lgb.train(params,\n",
" train_set=train_data,\n",
" num_boost_round=100,\n",
" init_model=model,\n",
" keep_training_booster=True)\n",
" # print(f\"Number of trees: {model.num_trees()}\")\n",
" elif model_type == 'catboost':\n",
" from catboost import Pool\n",
" train_data = Pool(data=X_train, label=y_train,\n",
" cat_features=[col for col in feature_columns if col.startswith('cat')])\n",
" # model.set_params(**params)\n",
" model.fit(train_data, init_model=model)\n",
" else:\n",
" print(current_dates)\n",
"\n",
" # Add the scores as a new 'score' column to the test_data\n",
" test_data['score'] = scores\n",
2025-03-01 18:10:27 +08:00
" return test_data"
],
"outputs": [],
2025-03-01 18:10:27 +08:00
"execution_count": 18
},
{
"cell_type": "code",
2025-03-01 18:10:27 +08:00
"id": "34698ca4f5fb933",
"metadata": {
"ExecuteTime": {
2025-03-01 18:10:27 +08:00
"end_time": "2025-03-01T10:09:12.009268Z",
"start_time": "2025-03-01T10:05:12.875190Z"
}
},
"source": [
2025-03-01 18:10:27 +08:00
"predictions_test = incremental_training(test_data, model, scaler, 5, 0, feature_columns_new, light_params,\n",
" model_type='lightgbm')\n"
],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
2025-03-01 18:10:27 +08:00
"100%|██████████| 517/517 [03:58<00:00, 2.17it/s]\n"
]
}
],
2025-03-01 18:10:27 +08:00
"execution_count": 19
},
{
"cell_type": "code",
"id": "36ccaa730ab46718",
"metadata": {
"ExecuteTime": {
2025-03-01 18:10:27 +08:00
"end_time": "2025-03-01T10:09:12.265271Z",
"start_time": "2025-03-01T10:09:12.167363Z"
}
},
"source": [
"predictions_test = predictions_test.loc[predictions_test.groupby('trade_date')['score'].idxmax()]\n",
"predictions_test[['trade_date', 'score', 'ts_code']].to_csv('predictions_test.tsv', index=False)\n"
],
"outputs": [],
2025-03-01 18:10:27 +08:00
"execution_count": 20
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.19"
}
},
"nbformat": 4,
"nbformat_minor": 5
}