1318 lines
194 KiB
Plaintext
1318 lines
194 KiB
Plaintext
|
|
{
|
|||
|
|
"cells": [
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"execution_count": 1,
|
|||
|
|
"id": "79a7758178bafdd3",
|
|||
|
|
"metadata": {
|
|||
|
|
"ExecuteTime": {
|
|||
|
|
"end_time": "2025-02-23T14:19:35.036653Z",
|
|||
|
|
"start_time": "2025-02-23T14:19:34.973432Z"
|
|||
|
|
},
|
|||
|
|
"jupyter": {
|
|||
|
|
"source_hidden": true
|
|||
|
|
}
|
|||
|
|
},
|
|||
|
|
"outputs": [],
|
|||
|
|
"source": [
|
|||
|
|
"%load_ext autoreload\n",
|
|||
|
|
"%autoreload 2\n",
|
|||
|
|
"\n"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"execution_count": 2,
|
|||
|
|
"id": "a79cafb06a7e0e43",
|
|||
|
|
"metadata": {
|
|||
|
|
"ExecuteTime": {
|
|||
|
|
"end_time": "2025-02-23T14:20:18.946211Z",
|
|||
|
|
"start_time": "2025-02-23T14:19:35.036653Z"
|
|||
|
|
}
|
|||
|
|
},
|
|||
|
|
"outputs": [
|
|||
|
|
{
|
|||
|
|
"name": "stdout",
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"text": [
|
|||
|
|
"daily data\n",
|
|||
|
|
"daily basic\n",
|
|||
|
|
"inner merge on ['ts_code', 'trade_date']\n",
|
|||
|
|
"stk limit\n",
|
|||
|
|
"left merge on ['ts_code', 'trade_date']\n",
|
|||
|
|
"money flow\n",
|
|||
|
|
"left merge on ['ts_code', 'trade_date']\n",
|
|||
|
|
"<class 'pandas.core.frame.DataFrame'>\n",
|
|||
|
|
"RangeIndex: 8296325 entries, 0 to 8296324\n",
|
|||
|
|
"Data columns (total 21 columns):\n",
|
|||
|
|
" # Column Dtype \n",
|
|||
|
|
"--- ------ ----- \n",
|
|||
|
|
" 0 ts_code object \n",
|
|||
|
|
" 1 trade_date datetime64[ns]\n",
|
|||
|
|
" 2 open float64 \n",
|
|||
|
|
" 3 close float64 \n",
|
|||
|
|
" 4 high float64 \n",
|
|||
|
|
" 5 low float64 \n",
|
|||
|
|
" 6 vol float64 \n",
|
|||
|
|
" 7 turnover_rate float64 \n",
|
|||
|
|
" 8 pe_ttm float64 \n",
|
|||
|
|
" 9 circ_mv float64 \n",
|
|||
|
|
" 10 volume_ratio float64 \n",
|
|||
|
|
" 11 is_st bool \n",
|
|||
|
|
" 12 up_limit float64 \n",
|
|||
|
|
" 13 down_limit float64 \n",
|
|||
|
|
" 14 buy_sm_vol float64 \n",
|
|||
|
|
" 15 sell_sm_vol float64 \n",
|
|||
|
|
" 16 buy_lg_vol float64 \n",
|
|||
|
|
" 17 sell_lg_vol float64 \n",
|
|||
|
|
" 18 buy_elg_vol float64 \n",
|
|||
|
|
" 19 sell_elg_vol float64 \n",
|
|||
|
|
" 20 net_mf_vol float64 \n",
|
|||
|
|
"dtypes: bool(1), datetime64[ns](1), float64(18), object(1)\n",
|
|||
|
|
"memory usage: 1.2+ GB\n",
|
|||
|
|
"None\n"
|
|||
|
|
]
|
|||
|
|
}
|
|||
|
|
],
|
|||
|
|
"source": [
|
|||
|
|
"from utils.utils import read_and_merge_h5_data\n",
|
|||
|
|
"\n",
|
|||
|
|
"print('daily data')\n",
|
|||
|
|
"df = read_and_merge_h5_data('../../data/daily_data.h5', key='daily_data',\n",
|
|||
|
|
" columns=['ts_code', 'trade_date', 'open', 'close', 'high', 'low', 'vol'],\n",
|
|||
|
|
" df=None)\n",
|
|||
|
|
"\n",
|
|||
|
|
"print('daily basic')\n",
|
|||
|
|
"df = read_and_merge_h5_data('../../data/daily_basic.h5', key='daily_basic',\n",
|
|||
|
|
" columns=['ts_code', 'trade_date', 'turnover_rate', 'pe_ttm', 'circ_mv', 'volume_ratio',\n",
|
|||
|
|
" 'is_st'], df=df, join='inner')\n",
|
|||
|
|
"\n",
|
|||
|
|
"print('stk limit')\n",
|
|||
|
|
"df = read_and_merge_h5_data('../../data/stk_limit.h5', key='stk_limit',\n",
|
|||
|
|
" columns=['ts_code', 'trade_date', 'pre_close', 'up_limit', 'down_limit'],\n",
|
|||
|
|
" df=df)\n",
|
|||
|
|
"print('money flow')\n",
|
|||
|
|
"df = read_and_merge_h5_data('../../data/money_flow.h5', key='money_flow',\n",
|
|||
|
|
" columns=['ts_code', 'trade_date', 'buy_sm_vol', 'sell_sm_vol', 'buy_lg_vol', 'sell_lg_vol',\n",
|
|||
|
|
" 'buy_elg_vol', 'sell_elg_vol', 'net_mf_vol'],\n",
|
|||
|
|
" df=df)\n",
|
|||
|
|
"print(df.info())"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"execution_count": 3,
|
|||
|
|
"id": "38879273d3574ae3",
|
|||
|
|
"metadata": {
|
|||
|
|
"ExecuteTime": {
|
|||
|
|
"end_time": "2025-02-23T14:20:21.108154Z",
|
|||
|
|
"start_time": "2025-02-23T14:20:19.025121Z"
|
|||
|
|
}
|
|||
|
|
},
|
|||
|
|
"outputs": [
|
|||
|
|
{
|
|||
|
|
"name": "stdout",
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"text": [
|
|||
|
|
"industry\n",
|
|||
|
|
"left merge on ['ts_code']\n"
|
|||
|
|
]
|
|||
|
|
}
|
|||
|
|
],
|
|||
|
|
"source": [
|
|||
|
|
"print('industry')\n",
|
|||
|
|
"df = read_and_merge_h5_data('../../data/industry_data.h5', key='industry_data',\n",
|
|||
|
|
" columns=['ts_code', 'l2_code'],\n",
|
|||
|
|
" df=df, on=['ts_code'], join='left')\n"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"execution_count": 4,
|
|||
|
|
"id": "a4eec8c93f6a7cc3",
|
|||
|
|
"metadata": {
|
|||
|
|
"ExecuteTime": {
|
|||
|
|
"end_time": "2025-02-23T14:20:21.712957Z",
|
|||
|
|
"start_time": "2025-02-23T14:20:21.602953Z"
|
|||
|
|
}
|
|||
|
|
},
|
|||
|
|
"outputs": [],
|
|||
|
|
"source": [
|
|||
|
|
"import pandas as pd\n",
|
|||
|
|
"import numpy as np\n",
|
|||
|
|
"\n",
|
|||
|
|
"\n",
|
|||
|
|
"def calculate_indicators(df):\n",
|
|||
|
|
" \"\"\"\n",
|
|||
|
|
" 计算四个指标:当日涨跌幅、5日移动平均、RSI、MACD。\n",
|
|||
|
|
" \"\"\"\n",
|
|||
|
|
" df['daily_return'] = (df['close'] - df['pre_close']) / df['pre_close'] * 100\n",
|
|||
|
|
" # df['5_day_ma'] = df['close'].rolling(window=5).mean()\n",
|
|||
|
|
" delta = df['close'].diff()\n",
|
|||
|
|
" gain = delta.where(delta > 0, 0)\n",
|
|||
|
|
" loss = -delta.where(delta < 0, 0)\n",
|
|||
|
|
" avg_gain = gain.rolling(window=14).mean()\n",
|
|||
|
|
" avg_loss = loss.rolling(window=14).mean()\n",
|
|||
|
|
" rs = avg_gain / avg_loss\n",
|
|||
|
|
" df['RSI'] = 100 - (100 / (1 + rs))\n",
|
|||
|
|
"\n",
|
|||
|
|
" # 计算MACD\n",
|
|||
|
|
" ema12 = df['close'].ewm(span=12, adjust=False).mean()\n",
|
|||
|
|
" ema26 = df['close'].ewm(span=26, adjust=False).mean()\n",
|
|||
|
|
" df['MACD'] = ema12 - ema26\n",
|
|||
|
|
" df['Signal_line'] = df['MACD'].ewm(span=9, adjust=False).mean()\n",
|
|||
|
|
" df['MACD_hist'] = df['MACD'] - df['Signal_line']\n",
|
|||
|
|
"\n",
|
|||
|
|
" return df\n",
|
|||
|
|
"\n",
|
|||
|
|
"\n",
|
|||
|
|
"def generate_index_indicators(h5_filename):\n",
|
|||
|
|
" df = pd.read_hdf(h5_filename, key='index_data')\n",
|
|||
|
|
" df['trade_date'] = pd.to_datetime(df['trade_date'], format='%Y%m%d')\n",
|
|||
|
|
" df = df.sort_values('trade_date')\n",
|
|||
|
|
"\n",
|
|||
|
|
" # 计算每个ts_code的相关指标\n",
|
|||
|
|
" df_indicators = []\n",
|
|||
|
|
" for ts_code in df['ts_code'].unique():\n",
|
|||
|
|
" df_index = df[df['ts_code'] == ts_code].copy()\n",
|
|||
|
|
" df_index = calculate_indicators(df_index)\n",
|
|||
|
|
" df_indicators.append(df_index)\n",
|
|||
|
|
"\n",
|
|||
|
|
" # 合并所有指数的结果\n",
|
|||
|
|
" df_all_indicators = pd.concat(df_indicators, ignore_index=True)\n",
|
|||
|
|
"\n",
|
|||
|
|
" # 保留trade_date列,并将同一天的数据按ts_code合并成一行\n",
|
|||
|
|
" df_final = df_all_indicators.pivot_table(\n",
|
|||
|
|
" index='trade_date',\n",
|
|||
|
|
" columns='ts_code',\n",
|
|||
|
|
" values=['daily_return', 'RSI', 'MACD', 'Signal_line', 'MACD_hist'],\n",
|
|||
|
|
" aggfunc='last'\n",
|
|||
|
|
" )\n",
|
|||
|
|
"\n",
|
|||
|
|
" df_final.columns = [f\"{col[1]}_{col[0]}\" for col in df_final.columns]\n",
|
|||
|
|
" df_final = df_final.reset_index()\n",
|
|||
|
|
"\n",
|
|||
|
|
" return df_final\n",
|
|||
|
|
"\n",
|
|||
|
|
"\n",
|
|||
|
|
"# 使用函数\n",
|
|||
|
|
"h5_filename = '../../data/index_data.h5'\n",
|
|||
|
|
"index_data = generate_index_indicators(h5_filename)\n",
|
|||
|
|
"index_data = index_data.dropna()\n"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"execution_count": 5,
|
|||
|
|
"id": "c4e9e1d31da6dba6",
|
|||
|
|
"metadata": {
|
|||
|
|
"ExecuteTime": {
|
|||
|
|
"end_time": "2025-02-23T14:20:21.860318Z",
|
|||
|
|
"start_time": "2025-02-23T14:20:21.735442Z"
|
|||
|
|
}
|
|||
|
|
},
|
|||
|
|
"outputs": [],
|
|||
|
|
"source": [
|
|||
|
|
"import numpy as np\n",
|
|||
|
|
"import talib\n",
|
|||
|
|
"\n",
|
|||
|
|
"\n",
|
|||
|
|
"def get_technical_factor(df):\n",
|
|||
|
|
" # 按股票和日期排序\n",
|
|||
|
|
" df = df.sort_values(by=['ts_code', 'trade_date'])\n",
|
|||
|
|
" grouped = df.groupby('ts_code', group_keys=False)\n",
|
|||
|
|
"\n",
|
|||
|
|
" # 计算 up 和 down\n",
|
|||
|
|
" df['up'] = (df['high'] - df[['close', 'open']].max(axis=1)) / df['close']\n",
|
|||
|
|
" df['down'] = (df[['close', 'open']].min(axis=1) - df['low']) / df['close']\n",
|
|||
|
|
"\n",
|
|||
|
|
" # 计算 ATR\n",
|
|||
|
|
" df['atr_14'] = grouped.apply(\n",
|
|||
|
|
" lambda x: pd.Series(talib.ATR(x['high'].values, x['low'].values, x['close'].values, timeperiod=14),\n",
|
|||
|
|
" index=x.index)\n",
|
|||
|
|
" )\n",
|
|||
|
|
" df['atr_6'] = grouped.apply(\n",
|
|||
|
|
" lambda x: pd.Series(talib.ATR(x['high'].values, x['low'].values, x['close'].values, timeperiod=6),\n",
|
|||
|
|
" index=x.index)\n",
|
|||
|
|
" )\n",
|
|||
|
|
"\n",
|
|||
|
|
" # 计算 OBV 及其均线\n",
|
|||
|
|
" df['obv'] = grouped.apply(\n",
|
|||
|
|
" lambda x: pd.Series(talib.OBV(x['close'].values, x['vol'].values), index=x.index)\n",
|
|||
|
|
" )\n",
|
|||
|
|
" df['maobv_6'] = grouped.apply(\n",
|
|||
|
|
" lambda x: pd.Series(talib.SMA(x['obv'].values, timeperiod=6), index=x.index)\n",
|
|||
|
|
" )\n",
|
|||
|
|
" df['obv-maobv_6'] = df['obv'] - df['maobv_6']\n",
|
|||
|
|
"\n",
|
|||
|
|
" # 计算 RSI\n",
|
|||
|
|
" df['rsi_3'] = grouped.apply(\n",
|
|||
|
|
" lambda x: pd.Series(talib.RSI(x['close'].values, timeperiod=3), index=x.index)\n",
|
|||
|
|
" )\n",
|
|||
|
|
" df['rsi_6'] = grouped.apply(\n",
|
|||
|
|
" lambda x: pd.Series(talib.RSI(x['close'].values, timeperiod=6), index=x.index)\n",
|
|||
|
|
" )\n",
|
|||
|
|
" df['rsi_9'] = grouped.apply(\n",
|
|||
|
|
" lambda x: pd.Series(talib.RSI(x['close'].values, timeperiod=9), index=x.index)\n",
|
|||
|
|
" )\n",
|
|||
|
|
"\n",
|
|||
|
|
" # 计算 return_10 和 return_20\n",
|
|||
|
|
" df['return_5'] = grouped['close'].apply(lambda x: x / x.shift(5) - 1)\n",
|
|||
|
|
" df['return_10'] = grouped['close'].apply(lambda x: x / x.shift(10) - 1)\n",
|
|||
|
|
" df['return_20'] = grouped['close'].apply(lambda x: x / x.shift(20) - 1)\n",
|
|||
|
|
"\n",
|
|||
|
|
" # 计算 avg_close_5\n",
|
|||
|
|
" df['avg_close_5'] = grouped['close'].apply(lambda x: x.rolling(window=5).mean() / x)\n",
|
|||
|
|
"\n",
|
|||
|
|
" # 计算标准差指标\n",
|
|||
|
|
" df['std_return_5'] = grouped['close'].apply(lambda x: x.pct_change().rolling(window=5).std())\n",
|
|||
|
|
" df['std_return_15'] = grouped['close'].apply(lambda x: x.pct_change().rolling(window=15).std())\n",
|
|||
|
|
" df['std_return_25'] = grouped['close'].apply(lambda x: x.pct_change().rolling(window=25).std())\n",
|
|||
|
|
" df['std_return_90'] = grouped['close'].apply(lambda x: x.pct_change().rolling(window=90).std())\n",
|
|||
|
|
" df['std_return_90_2'] = grouped['close'].apply(lambda x: x.shift(10).pct_change().rolling(window=90).std())\n",
|
|||
|
|
"\n",
|
|||
|
|
" # 计算比值指标\n",
|
|||
|
|
" df['std_return_5 / std_return_90'] = df['std_return_5'] / df['std_return_90']\n",
|
|||
|
|
" df['std_return_5 / std_return_25'] = df['std_return_5'] / df['std_return_25']\n",
|
|||
|
|
"\n",
|
|||
|
|
" # 计算标准差差值\n",
|
|||
|
|
" df['std_return_90 - std_return_90_2'] = df['std_return_90'] - df['std_return_90_2']\n",
|
|||
|
|
"\n",
|
|||
|
|
" return df\n",
|
|||
|
|
"\n",
|
|||
|
|
"\n",
|
|||
|
|
"def get_act_factor(df, cat=True):\n",
|
|||
|
|
" # 按股票和日期排序\n",
|
|||
|
|
" df = df.sort_values(by=['ts_code', 'trade_date'])\n",
|
|||
|
|
" grouped = df.groupby('ts_code', group_keys=False)\n",
|
|||
|
|
" # 计算 EMA 指标\n",
|
|||
|
|
" df['ema_5'] = grouped['close'].apply(\n",
|
|||
|
|
" lambda x: pd.Series(talib.EMA(x.values, timeperiod=5), index=x.index)\n",
|
|||
|
|
" )\n",
|
|||
|
|
" df['ema_13'] = grouped['close'].apply(\n",
|
|||
|
|
" lambda x: pd.Series(talib.EMA(x.values, timeperiod=13), index=x.index)\n",
|
|||
|
|
" )\n",
|
|||
|
|
" df['ema_20'] = grouped['close'].apply(\n",
|
|||
|
|
" lambda x: pd.Series(talib.EMA(x.values, timeperiod=20), index=x.index)\n",
|
|||
|
|
" )\n",
|
|||
|
|
" df['ema_60'] = grouped['close'].apply(\n",
|
|||
|
|
" lambda x: pd.Series(talib.EMA(x.values, timeperiod=60), index=x.index)\n",
|
|||
|
|
" )\n",
|
|||
|
|
"\n",
|
|||
|
|
" # 计算 act_factor1, act_factor2, act_factor3, act_factor4\n",
|
|||
|
|
" df['act_factor1'] = grouped['ema_5'].apply(\n",
|
|||
|
|
" lambda x: np.arctan((x / x.shift(1) - 1) * 100) * 57.3 / 50\n",
|
|||
|
|
" )\n",
|
|||
|
|
" df['act_factor2'] = grouped['ema_13'].apply(\n",
|
|||
|
|
" lambda x: np.arctan((x / x.shift(1) - 1) * 100) * 57.3 / 40\n",
|
|||
|
|
" )\n",
|
|||
|
|
" df['act_factor3'] = grouped['ema_20'].apply(\n",
|
|||
|
|
" lambda x: np.arctan((x / x.shift(1) - 1) * 100) * 57.3 / 21\n",
|
|||
|
|
" )\n",
|
|||
|
|
" df['act_factor4'] = grouped['ema_60'].apply(\n",
|
|||
|
|
" lambda x: np.arctan((x / x.shift(1) - 1) * 100) * 57.3 / 10\n",
|
|||
|
|
" )\n",
|
|||
|
|
"\n",
|
|||
|
|
" if cat:\n",
|
|||
|
|
" df['cat_af1'] = df['act_factor1'] > 0\n",
|
|||
|
|
" df['cat_af2'] = df['act_factor2'] > df['act_factor1']\n",
|
|||
|
|
" df['cat_af3'] = df['act_factor3'] > df['act_factor2']\n",
|
|||
|
|
" df['cat_af4'] = df['act_factor4'] > df['act_factor3']\n",
|
|||
|
|
"\n",
|
|||
|
|
" # 计算 act_factor5 和 act_factor6\n",
|
|||
|
|
" df['act_factor5'] = df['act_factor1'] + df['act_factor2'] + df['act_factor3'] + df['act_factor4']\n",
|
|||
|
|
" df['act_factor6'] = (df['act_factor1'] - df['act_factor2']) / np.sqrt(\n",
|
|||
|
|
" df['act_factor1'] ** 2 + df['act_factor2'] ** 2)\n",
|
|||
|
|
"\n",
|
|||
|
|
" # 根据 trade_date 截面计算排名\n",
|
|||
|
|
" df['rank_act_factor1'] = df.groupby('trade_date', group_keys=False)['act_factor1'].rank(ascending=False, pct=True)\n",
|
|||
|
|
" df['rank_act_factor2'] = df.groupby('trade_date', group_keys=False)['act_factor2'].rank(ascending=False, pct=True)\n",
|
|||
|
|
" df['rank_act_factor3'] = df.groupby('trade_date', group_keys=False)['act_factor3'].rank(ascending=False, pct=True)\n",
|
|||
|
|
"\n",
|
|||
|
|
" return df\n",
|
|||
|
|
"\n",
|
|||
|
|
"\n",
|
|||
|
|
"def get_money_flow_factor(df):\n",
|
|||
|
|
" # 计算资金流相关因子(字段名称见 tushare 数据说明)\n",
|
|||
|
|
" df['active_buy_volume_large'] = df['buy_lg_vol'] / df['net_mf_vol']\n",
|
|||
|
|
" df['active_buy_volume_big'] = df['buy_elg_vol'] / df['net_mf_vol']\n",
|
|||
|
|
" df['active_buy_volume_small'] = df['buy_sm_vol'] / df['net_mf_vol']\n",
|
|||
|
|
"\n",
|
|||
|
|
" df['buy_lg_vol_minus_sell_lg_vol'] = (df['buy_lg_vol'] - df['sell_lg_vol']) / df['net_mf_vol']\n",
|
|||
|
|
" df['buy_elg_vol_minus_sell_elg_vol'] = (df['buy_elg_vol'] - df['sell_elg_vol']) / df['net_mf_vol']\n",
|
|||
|
|
"\n",
|
|||
|
|
" df['log(circ_mv)'] = np.log(df['circ_mv'])\n",
|
|||
|
|
" return df\n",
|
|||
|
|
"\n",
|
|||
|
|
"\n",
|
|||
|
|
"def get_alpha_factor(df):\n",
|
|||
|
|
" df = df.sort_values(by=['ts_code', 'trade_date'])\n",
|
|||
|
|
" grouped = df.groupby('ts_code')\n",
|
|||
|
|
"\n",
|
|||
|
|
" # alpha_022: 当前 close 与 5 日前 close 差值\n",
|
|||
|
|
" df['alpha_022'] = grouped['close'].transform(lambda x: x - x.shift(5))\n",
|
|||
|
|
"\n",
|
|||
|
|
" # alpha_003: (close - open) / (high - low)\n",
|
|||
|
|
" df['alpha_003'] = np.where(df['high'] != df['low'],\n",
|
|||
|
|
" (df['close'] - df['open']) / (df['high'] - df['low']),\n",
|
|||
|
|
" 0)\n",
|
|||
|
|
"\n",
|
|||
|
|
" # alpha_007: 计算过去5日 close 与 vol 的相关性,并按 trade_date 排名\n",
|
|||
|
|
" df['alpha_007'] = grouped.apply(lambda x: x['close'].rolling(5).corr(x['vol'])).reset_index(level=0, drop=True)\n",
|
|||
|
|
" df['alpha_007'] = df.groupby('trade_date', group_keys=False)['alpha_007'].rank(ascending=True, pct=True)\n",
|
|||
|
|
"\n",
|
|||
|
|
" # alpha_013: 计算过去5日 close 之和 - 20日 close 之和,并按 trade_date 排名\n",
|
|||
|
|
" df['alpha_013'] = grouped['close'].transform(lambda x: x.rolling(5).sum() - x.rolling(20).sum())\n",
|
|||
|
|
" df['alpha_013'] = df.groupby('trade_date', group_keys=False)['alpha_013'].rank(ascending=True, pct=True)\n",
|
|||
|
|
"\n",
|
|||
|
|
" return df\n",
|
|||
|
|
"\n"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"execution_count": 6,
|
|||
|
|
"id": "53f86ddc0677a6d7",
|
|||
|
|
"metadata": {
|
|||
|
|
"ExecuteTime": {
|
|||
|
|
"end_time": "2025-02-23T14:20:25.061140Z",
|
|||
|
|
"start_time": "2025-02-23T14:20:21.880078Z"
|
|||
|
|
},
|
|||
|
|
"scrolled": true
|
|||
|
|
},
|
|||
|
|
"outputs": [],
|
|||
|
|
"source": [
|
|||
|
|
"def read_industry_data(h5_filename):\n",
|
|||
|
|
" # 读取 H5 文件中所有的行业数据\n",
|
|||
|
|
" industry_data = pd.read_hdf(h5_filename, key='sw_daily', columns=[\n",
|
|||
|
|
" 'ts_code', 'trade_date', 'open', 'close', 'high', 'low', 'pe', 'pb', 'vol'\n",
|
|||
|
|
" ]) # 假设 H5 文件的键是 'industry_data'\n",
|
|||
|
|
" industry_data = industry_data.reindex()\n",
|
|||
|
|
" industry_data['trade_date'] = pd.to_datetime(df['trade_date'], format='%Y%m%d')\n",
|
|||
|
|
"\n",
|
|||
|
|
" grouped = industry_data.groupby('ts_code', group_keys=False)\n",
|
|||
|
|
" industry_data['obv'] = grouped.apply(\n",
|
|||
|
|
" lambda x: pd.Series(talib.OBV(x['close'].values, x['vol'].values), index=x.index)\n",
|
|||
|
|
" )\n",
|
|||
|
|
" industry_data['return_5'] = grouped['close'].apply(lambda x: x / x.shift(5) - 1)\n",
|
|||
|
|
" # industry_data['return_20'] = grouped['close'].apply(lambda x: x / x.shift(20) - 1)\n",
|
|||
|
|
"\n",
|
|||
|
|
" # industry_data = get_act_factor(industry_data, cat=False)\n",
|
|||
|
|
" # industry_data = industry_data.sort_values(by=['trade_date', 'ts_code'])\n",
|
|||
|
|
"\n",
|
|||
|
|
" # 计算每天每个 ts_code 的因子和当天所有 ts_code 的中位数的偏差\n",
|
|||
|
|
" factor_columns = ['obv', 'return_5', 'return_20', 'act_factor1', 'act_factor2', 'act_factor3', 'act_factor4'] # 因子列\n",
|
|||
|
|
"\n",
|
|||
|
|
" for factor in factor_columns:\n",
|
|||
|
|
" if factor in industry_data.columns:\n",
|
|||
|
|
" # 计算每天每个 ts_code 的因子值与当天所有 ts_code 的中位数的偏差\n",
|
|||
|
|
" industry_data[f'{factor}_deviation'] = industry_data.groupby('trade_date')[factor].transform(\n",
|
|||
|
|
" lambda x: x - x.median())\n",
|
|||
|
|
"\n",
|
|||
|
|
" industry_data['return_5_percentile'] = industry_data.groupby('trade_date')['return_5'].transform(\n",
|
|||
|
|
" lambda x: x.rank(pct=True))\n",
|
|||
|
|
" industry_data = industry_data.drop(columns=['open', 'close', 'high', 'low', 'pe', 'pb', 'vol'])\n",
|
|||
|
|
"\n",
|
|||
|
|
" industry_data = industry_data.rename(\n",
|
|||
|
|
" columns={col: f'industry_{col}' for col in industry_data.columns if col not in ['ts_code', 'trade_date']})\n",
|
|||
|
|
"\n",
|
|||
|
|
" industry_data = industry_data.rename(columns={'ts_code': 'cat_l2_code'})\n",
|
|||
|
|
" return industry_data\n",
|
|||
|
|
"\n",
|
|||
|
|
"\n",
|
|||
|
|
"industry_df = read_industry_data('../../data/sw_daily.h5')\n"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"execution_count": 7,
|
|||
|
|
"id": "dbe2fd8021b9417f",
|
|||
|
|
"metadata": {
|
|||
|
|
"ExecuteTime": {
|
|||
|
|
"end_time": "2025-02-23T14:20:25.128062Z",
|
|||
|
|
"start_time": "2025-02-23T14:20:25.076707Z"
|
|||
|
|
},
|
|||
|
|
"jupyter": {
|
|||
|
|
"source_hidden": true
|
|||
|
|
},
|
|||
|
|
"scrolled": true
|
|||
|
|
},
|
|||
|
|
"outputs": [],
|
|||
|
|
"source": [
|
|||
|
|
"origin_columns = df.columns.tolist()\n",
|
|||
|
|
"origin_columns = [col for col in origin_columns if col not in ['turnover_rate', 'pe_ttm', 'volume_ratio', 'l2_code']]\n",
|
|||
|
|
"origin_columns = [col for col in origin_columns if col not in index_data.columns]\n"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"execution_count": 8,
|
|||
|
|
"id": "5f3d9aece75318cd",
|
|||
|
|
"metadata": {
|
|||
|
|
"ExecuteTime": {
|
|||
|
|
"end_time": "2025-02-23T14:21:22.858178Z",
|
|||
|
|
"start_time": "2025-02-23T14:20:25.145555Z"
|
|||
|
|
},
|
|||
|
|
"jupyter": {
|
|||
|
|
"source_hidden": true
|
|||
|
|
}
|
|||
|
|
},
|
|||
|
|
"outputs": [
|
|||
|
|
{
|
|||
|
|
"name": "stdout",
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"text": [
|
|||
|
|
"<class 'pandas.core.frame.DataFrame'>\n",
|
|||
|
|
"Index: 5538535 entries, 1962 to 5538534\n",
|
|||
|
|
"Data columns (total 71 columns):\n",
|
|||
|
|
" # Column Dtype \n",
|
|||
|
|
"--- ------ ----- \n",
|
|||
|
|
" 0 ts_code object \n",
|
|||
|
|
" 1 trade_date datetime64[ns]\n",
|
|||
|
|
" 2 open float64 \n",
|
|||
|
|
" 3 close float64 \n",
|
|||
|
|
" 4 high float64 \n",
|
|||
|
|
" 5 low float64 \n",
|
|||
|
|
" 6 vol float64 \n",
|
|||
|
|
" 7 turnover_rate float64 \n",
|
|||
|
|
" 8 pe_ttm float64 \n",
|
|||
|
|
" 9 circ_mv float64 \n",
|
|||
|
|
" 10 volume_ratio float64 \n",
|
|||
|
|
" 11 is_st bool \n",
|
|||
|
|
" 12 up_limit float64 \n",
|
|||
|
|
" 13 down_limit float64 \n",
|
|||
|
|
" 14 buy_sm_vol float64 \n",
|
|||
|
|
" 15 sell_sm_vol float64 \n",
|
|||
|
|
" 16 buy_lg_vol float64 \n",
|
|||
|
|
" 17 sell_lg_vol float64 \n",
|
|||
|
|
" 18 buy_elg_vol float64 \n",
|
|||
|
|
" 19 sell_elg_vol float64 \n",
|
|||
|
|
" 20 net_mf_vol float64 \n",
|
|||
|
|
" 21 cat_l2_code object \n",
|
|||
|
|
" 22 up float64 \n",
|
|||
|
|
" 23 down float64 \n",
|
|||
|
|
" 24 atr_14 float64 \n",
|
|||
|
|
" 25 atr_6 float64 \n",
|
|||
|
|
" 26 obv float64 \n",
|
|||
|
|
" 27 maobv_6 float64 \n",
|
|||
|
|
" 28 obv-maobv_6 float64 \n",
|
|||
|
|
" 29 rsi_3 float64 \n",
|
|||
|
|
" 30 rsi_6 float64 \n",
|
|||
|
|
" 31 rsi_9 float64 \n",
|
|||
|
|
" 32 return_5 float64 \n",
|
|||
|
|
" 33 return_10 float64 \n",
|
|||
|
|
" 34 return_20 float64 \n",
|
|||
|
|
" 35 avg_close_5 float64 \n",
|
|||
|
|
" 36 std_return_5 float64 \n",
|
|||
|
|
" 37 std_return_15 float64 \n",
|
|||
|
|
" 38 std_return_25 float64 \n",
|
|||
|
|
" 39 std_return_90 float64 \n",
|
|||
|
|
" 40 std_return_90_2 float64 \n",
|
|||
|
|
" 41 std_return_5 / std_return_90 float64 \n",
|
|||
|
|
" 42 std_return_5 / std_return_25 float64 \n",
|
|||
|
|
" 43 std_return_90 - std_return_90_2 float64 \n",
|
|||
|
|
" 44 ema_5 float64 \n",
|
|||
|
|
" 45 ema_13 float64 \n",
|
|||
|
|
" 46 ema_20 float64 \n",
|
|||
|
|
" 47 ema_60 float64 \n",
|
|||
|
|
" 48 act_factor1 float64 \n",
|
|||
|
|
" 49 act_factor2 float64 \n",
|
|||
|
|
" 50 act_factor3 float64 \n",
|
|||
|
|
" 51 act_factor4 float64 \n",
|
|||
|
|
" 52 cat_af1 bool \n",
|
|||
|
|
" 53 cat_af2 bool \n",
|
|||
|
|
" 54 cat_af3 bool \n",
|
|||
|
|
" 55 cat_af4 bool \n",
|
|||
|
|
" 56 act_factor5 float64 \n",
|
|||
|
|
" 57 act_factor6 float64 \n",
|
|||
|
|
" 58 rank_act_factor1 float64 \n",
|
|||
|
|
" 59 rank_act_factor2 float64 \n",
|
|||
|
|
" 60 rank_act_factor3 float64 \n",
|
|||
|
|
" 61 active_buy_volume_large float64 \n",
|
|||
|
|
" 62 active_buy_volume_big float64 \n",
|
|||
|
|
" 63 active_buy_volume_small float64 \n",
|
|||
|
|
" 64 buy_lg_vol_minus_sell_lg_vol float64 \n",
|
|||
|
|
" 65 buy_elg_vol_minus_sell_elg_vol float64 \n",
|
|||
|
|
" 66 log(circ_mv) float64 \n",
|
|||
|
|
" 67 alpha_022 float64 \n",
|
|||
|
|
" 68 alpha_003 float64 \n",
|
|||
|
|
" 69 alpha_007 float64 \n",
|
|||
|
|
" 70 alpha_013 float64 \n",
|
|||
|
|
"dtypes: bool(5), datetime64[ns](1), float64(63), object(2)\n",
|
|||
|
|
"memory usage: 2.8+ GB\n",
|
|||
|
|
"None\n"
|
|||
|
|
]
|
|||
|
|
}
|
|||
|
|
],
|
|||
|
|
"source": [
|
|||
|
|
"def filter_data(df):\n",
|
|||
|
|
" # df = df.groupby('trade_date').apply(lambda x: x.nlargest(1000, 'act_factor1'))\n",
|
|||
|
|
" df = df[~df['is_st']]\n",
|
|||
|
|
" df = df[~df['ts_code'].str.endswith('BJ')]\n",
|
|||
|
|
" df = df[~df['ts_code'].str.startswith('30')]\n",
|
|||
|
|
" df = df[~df['ts_code'].str.startswith('68')]\n",
|
|||
|
|
" df = df[~df['ts_code'].str.startswith('8')]\n",
|
|||
|
|
" df = df.reset_index(drop=True)\n",
|
|||
|
|
" return df\n",
|
|||
|
|
"\n",
|
|||
|
|
"\n",
|
|||
|
|
"df = filter_data(df)\n",
|
|||
|
|
"df = get_technical_factor(df)\n",
|
|||
|
|
"df = get_act_factor(df)\n",
|
|||
|
|
"df = get_money_flow_factor(df)\n",
|
|||
|
|
"df = get_alpha_factor(df)\n",
|
|||
|
|
"# df = df.merge(industry_df, on=['l2_code', 'trade_date'], how='left')\n",
|
|||
|
|
"df = df.rename(columns={'l2_code': 'cat_l2_code'})\n",
|
|||
|
|
"# df = df.merge(index_data, on='trade_date', how='left')\n",
|
|||
|
|
"\n",
|
|||
|
|
"print(df.info())"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"execution_count": 9,
|
|||
|
|
"id": "f4f16d63ad18d1bc",
|
|||
|
|
"metadata": {
|
|||
|
|
"ExecuteTime": {
|
|||
|
|
"end_time": "2025-02-23T14:21:23.090304Z",
|
|||
|
|
"start_time": "2025-02-23T14:21:22.943303Z"
|
|||
|
|
}
|
|||
|
|
},
|
|||
|
|
"outputs": [
|
|||
|
|
{
|
|||
|
|
"name": "stdout",
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"text": [
|
|||
|
|
"['turnover_rate', 'pe_ttm', 'volume_ratio', 'cat_l2_code', 'up', 'down', 'atr_14', 'atr_6', 'obv', 'maobv_6', 'obv-maobv_6', 'rsi_3', 'rsi_6', 'rsi_9', 'return_5', 'return_10', 'return_20', 'avg_close_5', 'std_return_5', 'std_return_15', 'std_return_25', 'std_return_90', 'std_return_90_2', 'std_return_5 / std_return_90', 'std_return_5 / std_return_25', 'std_return_90 - std_return_90_2', 'ema_5', 'ema_13', 'ema_20', 'ema_60', 'act_factor1', 'act_factor2', 'act_factor3', 'act_factor4', 'cat_af1', 'cat_af2', 'cat_af3', 'cat_af4', 'act_factor5', 'act_factor6', 'rank_act_factor1', 'rank_act_factor2', 'rank_act_factor3', 'active_buy_volume_large', 'active_buy_volume_big', 'active_buy_volume_small', 'buy_lg_vol_minus_sell_lg_vol', 'buy_elg_vol_minus_sell_elg_vol', 'log(circ_mv)', 'alpha_022', 'alpha_003', 'alpha_007', 'alpha_013']\n"
|
|||
|
|
]
|
|||
|
|
}
|
|||
|
|
],
|
|||
|
|
"source": [
|
|||
|
|
"feature_columns = [col for col in df.columns if col not in ['trade_date',\n",
|
|||
|
|
" 'ts_code',\n",
|
|||
|
|
" 'label']]\n",
|
|||
|
|
"feature_columns = [col for col in feature_columns if 'future' not in col]\n",
|
|||
|
|
"feature_columns = [col for col in feature_columns if 'score' not in col]\n",
|
|||
|
|
"feature_columns = [col for col in feature_columns if col not in origin_columns]\n",
|
|||
|
|
"feature_columns = [col for col in feature_columns if not col.startswith('_')]\n",
|
|||
|
|
"print(feature_columns)"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"execution_count": 10,
|
|||
|
|
"id": "0ebdfb92-d88b-4b5c-a715-675dab876fc0",
|
|||
|
|
"metadata": {
|
|||
|
|
"ExecuteTime": {
|
|||
|
|
"end_time": "2025-02-23T14:21:23.168842Z",
|
|||
|
|
"start_time": "2025-02-23T14:21:23.122002Z"
|
|||
|
|
}
|
|||
|
|
},
|
|||
|
|
"outputs": [],
|
|||
|
|
"source": [
|
|||
|
|
"def create_deviation_within_dates(df, feature_columns):\n",
|
|||
|
|
" groupby_col = 'cat_l2_code' # 使用 trade_date 进行分组\n",
|
|||
|
|
" new_columns = {}\n",
|
|||
|
|
" ret_feature_columns = feature_columns[:]\n",
|
|||
|
|
"\n",
|
|||
|
|
" # 自动选择所有数值型特征\n",
|
|||
|
|
" num_features = [col for col in feature_columns if 'cat' not in col and 'index' not in col]\n",
|
|||
|
|
"\n",
|
|||
|
|
" # 遍历所有数值型特征\n",
|
|||
|
|
" for feature in num_features:\n",
|
|||
|
|
" if feature == 'trade_date': # 不需要对 'trade_date' 计算偏差\n",
|
|||
|
|
" continue\n",
|
|||
|
|
"\n",
|
|||
|
|
" grouped_median = df.groupby(['trade_date', groupby_col])[feature].transform('median')\n",
|
|||
|
|
" deviation_col_name = f'deviation_median_{feature}'\n",
|
|||
|
|
" new_columns[deviation_col_name] = df[feature] - grouped_median\n",
|
|||
|
|
" ret_feature_columns.append(deviation_col_name)\n",
|
|||
|
|
"\n",
|
|||
|
|
" grouped_mean = df.groupby(groupby_col)[feature].transform('mean')\n",
|
|||
|
|
" deviation_col_name = f'deviation_mean_{feature}'\n",
|
|||
|
|
" new_columns[deviation_col_name] = df[feature] - grouped_mean\n",
|
|||
|
|
" ret_feature_columns.append(deviation_col_name)\n",
|
|||
|
|
"\n",
|
|||
|
|
" # 将新计算的偏差特征与原始 DataFrame 合并\n",
|
|||
|
|
" df = pd.concat([df, pd.DataFrame(new_columns)], axis=1)\n",
|
|||
|
|
"\n",
|
|||
|
|
" # for feature in ['obv', 'return_20', 'act_factor1', 'act_factor2', 'act_factor3', 'act_factor4']:\n",
|
|||
|
|
" # if feature in df.columns and f'industry_{feature}' in df:\n",
|
|||
|
|
" # df[f'deviation_industry_{feature}'] = df[feature] - df[f'industry_{feature}']\n",
|
|||
|
|
"\n",
|
|||
|
|
" return df, ret_feature_columns\n"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"execution_count": 11,
|
|||
|
|
"id": "fbb968383f8cf2c7",
|
|||
|
|
"metadata": {
|
|||
|
|
"ExecuteTime": {
|
|||
|
|
"end_time": "2025-02-23T14:21:39.397925Z",
|
|||
|
|
"start_time": "2025-02-23T14:21:23.168842Z"
|
|||
|
|
}
|
|||
|
|
},
|
|||
|
|
"outputs": [
|
|||
|
|
{
|
|||
|
|
"name": "stdout",
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"text": [
|
|||
|
|
"Removed 123852 outliers.\n"
|
|||
|
|
]
|
|||
|
|
}
|
|||
|
|
],
|
|||
|
|
"source": [
|
|||
|
|
"def get_qcuts(series, quantiles):\n",
|
|||
|
|
" q = pd.qcut(series, q=quantiles, labels=False, duplicates='drop')\n",
|
|||
|
|
" return q[-1] # 返回窗口最后一个元素的分位数标签\n",
|
|||
|
|
"\n",
|
|||
|
|
"\n",
|
|||
|
|
"import pandas as pd\n",
|
|||
|
|
"\n",
|
|||
|
|
"\n",
|
|||
|
|
"def remove_outliers_label_percentile(label: pd.Series, lower_percentile: float = 0.01, upper_percentile: float = 0.99):\n",
|
|||
|
|
" if not (0 <= lower_percentile < upper_percentile <= 1):\n",
|
|||
|
|
" raise ValueError(\"Percentile values must satisfy 0 <= lower_percentile < upper_percentile <= 1.\")\n",
|
|||
|
|
"\n",
|
|||
|
|
" # Calculate lower and upper bounds based on percentiles\n",
|
|||
|
|
" lower_bound = label.quantile(lower_percentile)\n",
|
|||
|
|
" upper_bound = label.quantile(upper_percentile)\n",
|
|||
|
|
"\n",
|
|||
|
|
" # Filter out values outside the bounds\n",
|
|||
|
|
" filtered_label = label[(label >= lower_bound) & (label <= upper_bound)]\n",
|
|||
|
|
"\n",
|
|||
|
|
" # Print the number of removed outliers\n",
|
|||
|
|
" print(f\"Removed {len(label) - len(filtered_label)} outliers.\")\n",
|
|||
|
|
" return filtered_label\n",
|
|||
|
|
"\n",
|
|||
|
|
"\n",
|
|||
|
|
"def calculate_risk_adjusted_target(df, days=5):\n",
|
|||
|
|
" df = df.sort_values(by=['ts_code', 'trade_date'])\n",
|
|||
|
|
"\n",
|
|||
|
|
" df['future_close'] = df.groupby('ts_code')['close'].shift(-days)\n",
|
|||
|
|
" df['future_return'] = (df['future_close'] - df['close']) / df['close']\n",
|
|||
|
|
"\n",
|
|||
|
|
" df['future_volatility'] = df.groupby('ts_code')['future_return'].rolling(days, min_periods=1).std().reset_index(\n",
|
|||
|
|
" level=0, drop=True)\n",
|
|||
|
|
" df['sharpe_ratio'] = df['future_return'] * df['future_volatility']\n",
|
|||
|
|
" df['sharpe_ratio'].replace([np.inf, -np.inf], np.nan, inplace=True)\n",
|
|||
|
|
"\n",
|
|||
|
|
" return df['sharpe_ratio']\n",
|
|||
|
|
"\n",
|
|||
|
|
"\n",
|
|||
|
|
"future_close = df.groupby('ts_code')['close'].shift(-4)\n",
|
|||
|
|
"future_return = (future_close - df['close']) / df['close']\n",
|
|||
|
|
"labels = future_return >= 0.03\n",
|
|||
|
|
"df['label'] = labels\n",
|
|||
|
|
"df['lr_label'] = future_return\n",
|
|||
|
|
"df['lr_label'] = remove_outliers_label_percentile(df['lr_label'])\n",
|
|||
|
|
"\n",
|
|||
|
|
"# df = df.apply(lambda x: x.astype('float32') if x.dtype in ['float64', 'float32'] else x)\n",
|
|||
|
|
"df = df.sort_values(by=['trade_date', 'ts_code'])\n",
|
|||
|
|
"train_data = df[(df['trade_date'] <= '2023-01-01') & (df['trade_date'] >= '2016-01-01')]\n",
|
|||
|
|
"test_data = df[df['trade_date'] >= '2023-01-01']\n",
|
|||
|
|
"\n",
|
|||
|
|
"# train_data = train_data.merge(industry_df, on=['cat_l2_code', 'trade_date'], how='left')\n",
|
|||
|
|
"# train_data = train_data.merge(index_data, on='trade_date', how='left')\n",
|
|||
|
|
"# train_data = train_data.rename(columns={'l2_code': 'cat_l2_code'})\n",
|
|||
|
|
"# test_data = test_data.merge(industry_df, on=['cat_l2_code', 'trade_date'], how='left')\n",
|
|||
|
|
"# test_data = test_data.merge(index_data, on='trade_date', how='left')\n",
|
|||
|
|
"# test_data = test_data.rename(columns={'l2_code': 'cat_l2_code'})\n",
|
|||
|
|
"\n",
|
|||
|
|
"train_data = train_data.groupby('trade_date', group_keys=False).apply(lambda x: x.nlargest(1000, 'return_20'))\n",
|
|||
|
|
"test_data = test_data.groupby('trade_date', group_keys=False).apply(lambda x: x.nlargest(1000, 'return_20'))\n",
|
|||
|
|
"\n",
|
|||
|
|
"# train_data = get_future_data(train_data)\n",
|
|||
|
|
"\n",
|
|||
|
|
"# df = df[['ts_code', 'trade_date', 'open', 'close']]\n",
|
|||
|
|
"\n",
|
|||
|
|
"train_data, test_data = train_data.replace([np.inf, -np.inf], np.nan), test_data.replace([np.inf, -np.inf], np.nan)\n"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"execution_count": 12,
|
|||
|
|
"id": "35238cb4f45ce756",
|
|||
|
|
"metadata": {
|
|||
|
|
"ExecuteTime": {
|
|||
|
|
"end_time": "2025-02-23T14:21:56.438803Z",
|
|||
|
|
"start_time": "2025-02-23T14:21:39.431593Z"
|
|||
|
|
}
|
|||
|
|
},
|
|||
|
|
"outputs": [
|
|||
|
|
{
|
|||
|
|
"name": "stdout",
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"text": [
|
|||
|
|
"feature_columns size: 149\n",
|
|||
|
|
"feature_columns size: 149\n",
|
|||
|
|
"1171702\n",
|
|||
|
|
"最小日期: 2017-03-21\n",
|
|||
|
|
"最大日期: 2022-12-30\n",
|
|||
|
|
"402634\n",
|
|||
|
|
"最小日期: 2023-01-03\n",
|
|||
|
|
"最大日期: 2025-02-12\n"
|
|||
|
|
]
|
|||
|
|
}
|
|||
|
|
],
|
|||
|
|
"source": [
|
|||
|
|
"# feature_columns_new = feature_columns[:]\n",
|
|||
|
|
"train_data, feature_columns_new = create_deviation_within_dates(train_data, feature_columns)\n",
|
|||
|
|
"print(f'feature_columns size: {len(feature_columns_new)}')\n",
|
|||
|
|
"test_data, feature_columns_new = create_deviation_within_dates(test_data, feature_columns)\n",
|
|||
|
|
"print(f'feature_columns size: {len(feature_columns_new)}')\n",
|
|||
|
|
"\n",
|
|||
|
|
"train_data = train_data.dropna(subset=feature_columns_new)\n",
|
|||
|
|
"train_data = train_data.dropna(subset=['label'])\n",
|
|||
|
|
"train_data = train_data.reset_index(drop=True)\n",
|
|||
|
|
"\n",
|
|||
|
|
"test_data = test_data.dropna(subset=feature_columns_new)\n",
|
|||
|
|
"test_data = test_data.dropna(subset=['label'])\n",
|
|||
|
|
"test_data = test_data.reset_index(drop=True)\n",
|
|||
|
|
"\n",
|
|||
|
|
"print(len(train_data))\n",
|
|||
|
|
"print(f\"最小日期: {train_data['trade_date'].min().strftime('%Y-%m-%d')}\")\n",
|
|||
|
|
"print(f\"最大日期: {train_data['trade_date'].max().strftime('%Y-%m-%d')}\")\n",
|
|||
|
|
"print(len(test_data))\n",
|
|||
|
|
"print(f\"最小日期: {test_data['trade_date'].min().strftime('%Y-%m-%d')}\")\n",
|
|||
|
|
"print(f\"最大日期: {test_data['trade_date'].max().strftime('%Y-%m-%d')}\")"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"execution_count": 13,
|
|||
|
|
"id": "f679ccccecd09a06",
|
|||
|
|
"metadata": {
|
|||
|
|
"ExecuteTime": {
|
|||
|
|
"end_time": "2025-02-23T14:21:56.621062Z",
|
|||
|
|
"start_time": "2025-02-23T14:21:56.470090Z"
|
|||
|
|
}
|
|||
|
|
},
|
|||
|
|
"outputs": [],
|
|||
|
|
"source": [
|
|||
|
|
"cat_columns = [col for col in df.columns if col.startswith('cat')]\n",
|
|||
|
|
"for col in cat_columns:\n",
|
|||
|
|
" train_data[col] = train_data[col].astype('category')\n",
|
|||
|
|
" test_data[col] = test_data[col].astype('category')"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"execution_count": 14,
|
|||
|
|
"id": "8f134d435f71e9e2",
|
|||
|
|
"metadata": {
|
|||
|
|
"ExecuteTime": {
|
|||
|
|
"end_time": "2025-02-23T14:21:58.157508Z",
|
|||
|
|
"start_time": "2025-02-23T14:21:56.638939Z"
|
|||
|
|
},
|
|||
|
|
"jupyter": {
|
|||
|
|
"source_hidden": true
|
|||
|
|
}
|
|||
|
|
},
|
|||
|
|
"outputs": [],
|
|||
|
|
"source": [
|
|||
|
|
"from catboost import Pool\n",
|
|||
|
|
"import lightgbm as lgb\n",
|
|||
|
|
"import numpy as np\n",
|
|||
|
|
"import matplotlib.pyplot as plt\n",
|
|||
|
|
"import optuna\n",
|
|||
|
|
"from sklearn.model_selection import KFold\n",
|
|||
|
|
"from sklearn.metrics import mean_absolute_error\n",
|
|||
|
|
"import os\n",
|
|||
|
|
"import json\n",
|
|||
|
|
"import pickle\n",
|
|||
|
|
"import hashlib\n",
|
|||
|
|
"\n",
|
|||
|
|
"\n",
|
|||
|
|
"def train_light_model(train_data_df, test_data_df, params, feature_columns, callbacks, evals,\n",
|
|||
|
|
" print_feature_importance=True, num_boost_round=100,\n",
|
|||
|
|
" use_optuna=False):\n",
|
|||
|
|
" train_data_df, test_data_df = train_data_df.dropna(subset=['label']), test_data_df.dropna(subset=['label'])\n",
|
|||
|
|
" X_train = train_data_df[feature_columns]\n",
|
|||
|
|
" y_train = train_data_df['label']\n",
|
|||
|
|
"\n",
|
|||
|
|
" X_val = test_data_df[feature_columns]\n",
|
|||
|
|
" y_val = test_data_df['label']\n",
|
|||
|
|
"\n",
|
|||
|
|
" categorical_feature = [i for i, col in enumerate(feature_columns) if col.startswith('cat')]\n",
|
|||
|
|
" print(f'categorical_feature: {categorical_feature}')\n",
|
|||
|
|
" train_data = lgb.Dataset(X_train, label=y_train, categorical_feature=categorical_feature)\n",
|
|||
|
|
" val_data = lgb.Dataset(X_val, label=y_val, categorical_feature=categorical_feature)\n",
|
|||
|
|
" model = lgb.train(\n",
|
|||
|
|
" params, train_data, num_boost_round=num_boost_round,\n",
|
|||
|
|
" valid_sets=[train_data, val_data], valid_names=['train', 'valid'],\n",
|
|||
|
|
" callbacks=callbacks\n",
|
|||
|
|
" )\n",
|
|||
|
|
"\n",
|
|||
|
|
" if print_feature_importance:\n",
|
|||
|
|
" lgb.plot_metric(evals)\n",
|
|||
|
|
" # lgb.plot_tree(model, figsize=(20, 8))\n",
|
|||
|
|
" lgb.plot_importance(model, importance_type='split', max_num_features=20)\n",
|
|||
|
|
" plt.show()\n",
|
|||
|
|
" return model\n",
|
|||
|
|
"\n",
|
|||
|
|
"\n",
|
|||
|
|
"from catboost import CatBoostClassifier\n",
|
|||
|
|
"import pandas as pd\n",
|
|||
|
|
"\n",
|
|||
|
|
"\n",
|
|||
|
|
"def train_catboost(train_data_df, test_data_df, feature_columns, params=None, plot=False):\n",
|
|||
|
|
" train_data_df, test_data_df = train_data_df.dropna(subset=['label']), test_data_df.dropna(subset=['label'])\n",
|
|||
|
|
" X_train = train_data_df[feature_columns]\n",
|
|||
|
|
" y_train = train_data_df['label']\n",
|
|||
|
|
"\n",
|
|||
|
|
" X_val = test_data_df[feature_columns]\n",
|
|||
|
|
" y_val = test_data_df['label']\n",
|
|||
|
|
"\n",
|
|||
|
|
" cat_features = [i for i, col in enumerate(feature_columns) if col.startswith('cat')]\n",
|
|||
|
|
" print(f'cat_features: {cat_features}')\n",
|
|||
|
|
" train_pool = Pool(data=X_train, label=y_train, cat_features=cat_features)\n",
|
|||
|
|
" val_pool = Pool(data=X_val, label=y_val, cat_features=cat_features)\n",
|
|||
|
|
"\n",
|
|||
|
|
" model = CatBoostClassifier(**params)\n",
|
|||
|
|
" model.fit(train_pool,\n",
|
|||
|
|
" eval_set=val_pool, plot=plot)\n",
|
|||
|
|
"\n",
|
|||
|
|
" return model"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"execution_count": 15,
|
|||
|
|
"id": "4a4542e1ed6afe7d",
|
|||
|
|
"metadata": {
|
|||
|
|
"ExecuteTime": {
|
|||
|
|
"end_time": "2025-02-23T14:21:58.418136Z",
|
|||
|
|
"start_time": "2025-02-23T14:21:58.339405Z"
|
|||
|
|
}
|
|||
|
|
},
|
|||
|
|
"outputs": [],
|
|||
|
|
"source": [
|
|||
|
|
"light_params = {\n",
|
|||
|
|
" 'objective': 'binary',\n",
|
|||
|
|
" 'metric': 'average_precision',\n",
|
|||
|
|
" 'learning_rate': 0.05,\n",
|
|||
|
|
" 'is_unbalance': True,\n",
|
|||
|
|
" 'num_leaves': 2048,\n",
|
|||
|
|
" 'min_data_in_leaf': 1024,\n",
|
|||
|
|
" 'max_depth': 32,\n",
|
|||
|
|
" 'max_bin': 1024,\n",
|
|||
|
|
" 'feature_fraction': 0.7,\n",
|
|||
|
|
" 'bagging_fraction': 0.7,\n",
|
|||
|
|
" 'bagging_freq': 5,\n",
|
|||
|
|
" # 'lambda_l1': 80,\n",
|
|||
|
|
" # 'lambda_l2': 65,\n",
|
|||
|
|
" 'verbosity': -1,\n",
|
|||
|
|
" 'num_threads': 16\n",
|
|||
|
|
"}"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"execution_count": 16,
|
|||
|
|
"id": "beeb098799ecfa6a",
|
|||
|
|
"metadata": {
|
|||
|
|
"ExecuteTime": {
|
|||
|
|
"end_time": "2025-02-23T14:23:12.347650Z",
|
|||
|
|
"start_time": "2025-02-23T14:21:58.449422Z"
|
|||
|
|
}
|
|||
|
|
},
|
|||
|
|
"outputs": [
|
|||
|
|
{
|
|||
|
|
"name": "stdout",
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"text": [
|
|||
|
|
"train data size: 1171702\n",
|
|||
|
|
"categorical_feature: [3, 34, 35, 36, 37]\n",
|
|||
|
|
"Training until validation scores don't improve for 50 rounds\n",
|
|||
|
|
"Early stopping, best iteration is:\n",
|
|||
|
|
"[89]\ttrain's average_precision: 0.473589\tvalid's average_precision: 0.284686\n",
|
|||
|
|
"Evaluated only: average_precision\n"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"data": {
|
|||
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAkAAAAHFCAYAAAAaD0bAAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8pXeV/AAAACXBIWXMAAA9hAAAPYQGoP6dpAABtj0lEQVR4nO3deVzUdf4H8NfMwAw3KDfKpYLIoSJogHmHpmZmmaZ5lW65dnhsa7rZLzVLc1tFW7XcLdkuc1uvDksxEc88EPDA+wJhuJNTYI7v74+RgREGYQQGmNfz8ZiHzOf7me98vm+QefH5XiJBEAQQERERmRCxsQdARERE1NIYgIiIiMjkMAARERGRyWEAIiIiIpPDAEREREQmhwGIiIiITA4DEBEREZkcBiAiIiIyOQxAREREZHIYgIhamdjYWIhEIohEIhw8eLDWckEQ0K1bN4hEIgwePNig99i4cSNiY2Mb9ZqDBw/qHVNTaa73aImx65OamoqlS5fi1q1bzbL+pUuXQiQSGfRaY9aFyNgYgIhaKVtbW3z++ee12hMSEnD9+nXY2toavG5DAlCfPn1w/Phx9OnTx+D3NRZjjj01NRXLli1rtgA0a9YsHD9+3KDXtuXvKdGjYgAiaqUmTpyI7du3o6ioSKf9888/R2RkJLy8vFpkHAqFAkqlEnZ2doiIiICdnV2LvG9TaItjLysra1T/zp07IyIiwqD3akt1IWpqDEBErdSkSZMAAFu3btW2FRYWYvv27Xj55ZfrfE1lZSVWrFiBgIAAyGQyODs746WXXkJubq62j4+PDy5cuICEhATtrjYfHx8A1btEvvrqK/zlL39Bp06dIJPJcO3aNb27S06cOIExY8bA0dERFhYW6Nq1K+bNm/fQ7bt06RKefPJJWFlZwcnJCbNnz0ZxcXGtfj4+PpgxY0at9sGDB+vsAmzs2GfMmAEbGxtcu3YNo0aNgo2NDTw9PfGXv/wFFRUVOu91584djB8/Hra2tnBwcMCLL76IU6dOQSQS1TuTFhsbi+effx4AMGTIEG29q14zePBgBAcH49ChQ4iKioKVlZX2e7tt2zYMHz4c7u7usLS0RI8ePbBo0SKUlpbqvEddu8B8fHzw1FNP4ddff0WfPn1gaWmJgIAAfPHFFzr9jFUXotaAAYiolbKzs8P48eN1PrS2bt0KsViMiRMn1uqvVqsxduxYrFq1CpMnT8bPP/+MVatWIS4uDoMHD8a9e/cAADt37kSXLl0QGhqK48eP4/jx49i5c6fOuhYvXoy0tDR8+umn+PHHH+Hi4lLnGPfu3YsBAwYgLS0Na9aswS+//IIlS5YgOzu73m3Lzs7GoEGDcP78eWzcuBFfffUVSkpK8Prrrze2TLU0dOyAZobo6aefxrBhw7B79268/PLLWLt2LT766CNtn9LSUgwZMgTx8fH46KOP8N///heurq51fg8eNHr0aHz44YcAgA0bNmjrPXr0aG0fuVyOKVOmYPLkydizZw/mzJkDALh69SpGjRqFzz//HL/++ivmzZuH//73vxgzZkyD6pCSkoK//OUvmD9/Pnbv3o2ePXti5syZOHTo0ENf29x1IWoVBCJqVbZs2SIAEE6dOiXEx8cLAITz588LgiAIffv2FWbMmCEIgiAEBQUJgwYN0r5u69atAgBh+/btOus7deqUAEDYuHGjtu3B11aper+BAwfqXRYfH69t69q1q9C1a1fh3r17jdrGt99+WxCJREJycrJOe3R0dK338Pb2FqZPn15rHYMGDdLZhsaOffr06QIA4b///a9O31GjRgndu3fXPt+wYYMAQPjll190+r366qsCAGHLli31buv3339f671rbgMA4bfffqt3HWq1WlAoFEJCQoIAQEhJSdEue++994QHf5V7e3sLFhYWwu3bt7Vt9+7dEzp27Ci8+uqr2jZj1oXI2DgDRNSKDRo0CF27dsUXX3yBc+fO4dSpU3p3f/30009wcHDAmDFjoFQqtY/evXvDzc2tUWf6PPfccw/tc+XKFVy/fh0zZ86EhYVFg9cNAPHx8QgKCkKvXr102idPntyo9dSlIWOvIhKJas2o9OzZE7dv39Y+T0hIgK2tLZ588kmdflW7KB9Vhw4dMHTo0FrtN27cwOTJk+Hm5gaJRAJzc3MMGjQIAHDx4sWHrrd37946x4lZWFjA399fZ9v0aQ11IWpuZsYeABHpJxKJ8NJLL2H9+vUoLy+Hv78/BgwYUGff7Oxs3L17F1KptM7leXl5DX5fd3f3h/apOq6oc+fODV5vlfz8fPj6+tZqd3Nza/S6HtSQsVexsrKqFd5kMhnKy8u1z/Pz8+Hq6lrrtXW1GaKu8ZaUlGDAgAGwsLDAihUr4O/vDysrK6Snp+PZZ5/V7s6sj6OjY602mUzWoNe2hroQNTcGIKJWbsaMGfi///s/fPrpp/jggw/09nNycoKjoyN+/fXXOpc35rT5hlxXxtnZGYDmQNjGcnR0RFZWVq32utosLCxqHXwLaAKdk5NTrXZDr4mjj6OjI06ePFmrva6xGqKu8R44cACZmZk4ePCgdtYHAO7evdsk79kUmrsuRM2Nu8CIWrlOnTrhr3/9K8aMGYPp06fr7ffUU08hPz8fKpUK4eHhtR7du3fX9m3oTEB9/P39tbvn6goo9RkyZAguXLiAlJQUnfZvv/22Vl8fHx+cPXtWp+3KlSu4fPly4wdtgEGDBqG4uBi//PKLTvt3333XoNfLZDIAaFS9q0JR1WurfPbZZw1eR3N71LoQGRtngIjagFWrVj20zwsvvIBvvvkGo0aNwty5c9GvXz+Ym5vjzp07iI+Px9ixYzFu3DgAQEhICL777jts27YNXbp0gYWFBUJCQho9rg0bNmDMmDGIiIjA/Pnz4eXlhbS0NOzduxfffPON3tfNmzcPX3zxBUaPHo0VK1bA1dUV33zzDS5dulSr79SpUzFlyhTMmTMHzz33HG7fvo3Vq1drZ6Ca2/Tp07F27VpMmTIFK1asQLdu3fDLL79g7969AACxuP6/I4ODgwEAmzdvhq2tLSwsLODr61vnLqoqUVFR6NChA2bPno333nsP5ubm+Oabb2oFRmN61LoQGRt/QonaCYlEgh9++AF/+9vfsGPHDowbNw7PPPMMVq1aVSvgLFu2DIMGDcKf/vQn9OvXr8GnVj9oxIgROHToENzd3fHmm2/iySefxPLlyx96HIibmxsSEhIQGBiIP//5z5gyZQosLCzwz3/+s1bfyZMnY/Xq1di7dy+eeuopbNq0CZs2bYK/v79BY24sa2trHDhwAIMHD8bChQvx3HPPIS0tDRs3bgQAODg41Pt6X19fxMTEICUlBYMHD0bfvn3x448/1vsaR0dH/Pzzz7CyssKUKVPw8ssvw8bGBtu2bWuqzXpkj1oXImMTCYIgGHsQRERtzYcffoglS5YgLS3NoAPB2yvWhdoK7gIjInqIqpmpgIAAKBQKHDhwAOvXr8eUKVNM+kOedaG2jAGIiOghrKyssHbtWty6dQsVFRXw8vLC22+/jSVLlhh7aEbFulBbxl1gREREZHJ4EDQRERGZHAYgIiIiMjkMQERERGRyeBB0HdRqNTIzM2Fra9vkl9UnIiKi5iEIAoqLi+Hh4fHQi3EyANUhMzMTnp6exh4GERERGSA9Pf2hl2JgAKpD1U0jb968iY4dOxp5NK2PQqHAvn37MHz4cJibmxt7OK0Ka1M/1kc/1kY/1qZ+rE+1oqIieHp6NujmzwxAdaja7WVraws7Ozsjj6b1USgUsLKygp2dncn/Z3sQa1M/1kc/1kY/1qZ+rE9tDTl8hQdBExERkclhACIiIiKTwwBEREREJofHABEREbUglUoFhULRZOtTKBQwMzNDeXk5VCpVk623tZJKpQ89xb0hGICIiIhagCAIyMrKwt27d5t8vW5ubkhPTzeJa9eJxWL4+vpCKpU+0noYgIiIiFpAVfhxcXGBlZVVk4UVtVqNkpIS2NjYNMnMSGtWdaFiuVwOLy+vR6ohAxAREVEzU6lU2vDj6OjYpOtWq9WorKyEhYVFuw9AAODs7IzMzEwolcp
|
|||
|
|
"text/plain": [
|
|||
|
|
"<Figure size 640x480 with 1 Axes>"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
"metadata": {},
|
|||
|
|
"output_type": "display_data"
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"data": {
|
|||
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA4YAAAHFCAYAAACuIOfmAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8pXeV/AAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdd3xN9//A8dfNutnIkoQQKzFiz1hJzIhVe0QJGk2tkiK22KvU6Bctbayg2qBGipREaxOzBBVij5pphMz7+8Mj5+dKkGgkuO/n43Ef7fl8PudzPu9zce/7fj7nHJVGo9EghBBCCCGEEEJn6eX3AIQQQgghhBBC5C9JDIUQQgghhBBCx0liKIQQQgghhBA6ThJDIYQQQgghhNBxkhgKIYQQQgghhI6TxFAIIYQQQgghdJwkhkIIIYQQQgih4yQxFEIIIYQQQggdJ4mhEEIIIYQQQug4SQyFEEII8cFbvnw5KpUqy9ewYcPeyTHPnj1LcHAwcXFx76T//yIuLg6VSsXy5cvzeyhvLTw8nODg4PwehhA6wyC/ByCEEEIIkVtCQkIoW7asVpmjo+M7OdbZs2eZOHEinp6eODs7v5NjvC0HBwcOHDhAqVKl8nsoby08PJz//e9/khwKkUckMRRCCCHER8PNzY0aNWrk9zD+k5SUFFQqFQYGb/81Ta1WU6dOnVwcVd5JTEzE1NQ0v4chhM6RpaRCCCGE0Bk//fQT7u7umJmZYW5uTvPmzTl+/LhWm6NHj9K1a1ecnZ0xMTHB2dmZbt26ceXKFaXN8uXL6dSpEwBeXl7KstWMpZvOzs74+fllOr6npyeenp7KdlRUFCqVilWrVvHVV19RpEgR1Go1Fy9eBOD333+ncePGWFpaYmpqSr169di1a9cb48xqKWlwcDAqlYpTp07RqVMnChQogJWVFYGBgaSmpnL+/Hm8vb2xsLDA2dmZWbNmafWZMdbVq1cTGBiIvb09JiYmeHh4ZDqHAJs3b8bd3R1TU1MsLCxo2rQpBw4c0GqTMaZjx47RsWNHChUqRKlSpfDz8+N///sfgNay4Ixlu//73/9o2LAhdnZ2mJmZUbFiRWbNmkVKSkqm8+3m5saRI0do0KABpqamlCxZkhkzZpCenq7V9tGjR3z11VeULFkStVqNnZ0dPj4+nDt3TmmTnJzMlClTKFu2LGq1GltbW3r37s0///zzxvdEiPedJIZCCCGE+GikpaWRmpqq9cowbdo0unXrRvny5Vm/fj2rVq3i33//pUGDBpw9e1ZpFxcXh6urK/PmzWPHjh3MnDmTW7duUbNmTe7duwdAy5YtmTZtGvA8STlw4AAHDhygZcuWbzXuUaNGcfXqVZYsWcKWLVuws7Nj9erVNGvWDEtLS1asWMH69euxsrKiefPm2UoOX6Vz585UrlyZsLAw/P39+eabbxg6dCiffPIJLVu2ZOPGjTRq1IigoCA2bNiQaf/Ro0dz6dIlli1bxrJly7h58yaenp5cunRJabNmzRratm2LpaUla9eu5YcffuDhw4d4enqyd+/eTH22b9+e0qVL8/PPP7NkyRLGjRtHx44dAZRze+DAARwcHACIjY2le/furFq1iq1bt9K3b19mz57N559/nqnv27dv4+vrS48ePdi8eTMtWrRg1KhRrF69Wmnz77//Ur9+fb777jt69+7Nli1bWLJkCS4uLty6dQuA9PR02rZty4wZM+jevTvbtm1jxowZRERE4OnpydOnT9/6PRHivaARQgghhPjAhYSEaIAsXykpKZqrV69qDAwMNIMGDdLa799//9XY29trOnfu/Mq+U1NTNQkJCRozMzPN/PnzlfKff/5ZA2giIyMz7VO8eHFNr169MpV7eHhoPDw8lO3IyEgNoGnYsKFWuydPnmisrKw0rVu31ipPS0vTVK5cWVOrVq3XnA2N5vLlyxpAExISopRNmDBBA2jmzJmj1bZKlSoaQLNhwwalLCUlRWNra6tp3759prFWq1ZNk56erpTHxcVpDA0NNZ999pkyRkdHR03FihU1aWlpSrt///1XY2dnp6lbt26mMY0fPz5TDAMGDNBk56tqWlqaJiUlRbNy5UqNvr6+5sGDB0qdh4eHBtAcOnRIa5/y5ctrmjdvrmxPmjRJA2giIiJeeZy1a9dqAE1YWJhW+ZEjRzSAZtGiRW8cqxDvM5kxFEIIIcRHY+XKlRw5ckTrZWBgwI4dO0hNTaVnz55as4nGxsZ4eHgQFRWl9JGQkEBQUBClS5fGwMAAAwMDzM3NefLkCTExMe9k3B06dNDa3r9/Pw8ePKBXr15a401PT8fb25sjR47w5MmTtzpWq1attLbLlSuHSqWiRYsWSpmBgQGlS5fWWj6boXv37qhUKmW7ePHi1K1bl8jISADOnz/PzZs3+fTTT9HT+/+vmubm5nTo0IGDBw+SmJj42vjf5Pjx47Rp0wZra2v09fUxNDSkZ8+epKWlceHCBa229vb21KpVS6usUqVKWrH99ttvuLi40KRJk1cec+vWrRQsWJDWrVtrvSdVqlTB3t5e68+QEB8iufmMEEIIIT4a5cqVy/LmM3fu3AGgZs2aWe73YgLTvXt3du3axbhx46hZsyaWlpaoVCp8fHze2XLBjCWSL483YzllVh48eICZmVmOj2VlZaW1bWRkhKmpKcbGxpnK4+PjM+1vb2+fZdnJkycBuH//PpA5Jnh+h9j09HQePnyodYOZrNq+ytWrV2nQoAGurq7Mnz8fZ2dnjI2NOXz4MAMGDMj0HllbW2fqQ61Wa7X7559/KFas2GuPe+fOHR49eoSRkVGW9RnLjIX4UEliKIQQQoiPno2NDQC//PILxYsXf2W7x48fs3XrViZMmMDIkSOV8qSkJB48eJDt4xkbG5OUlJSp/N69e8pYXvTiDNyL4124cOEr7y5auHDhbI8nN92+fTvLsowELOO/GdfmvejmzZvo6elRqFAhrfKX43+dTZs28eTJEzZs2KD1Xp44cSLbfbzM1taW69evv7aNjY0N1tbWbN++Pct6CwuLtz6+EO8DSQyFEEII8dFr3rw5BgYGxMbGvnbZokqlQqPRoFartcqXLVtGWlqaVllGm6xmEZ2dnTl16pRW2YULFzh//nyWieHL6tWrR8GCBTl79iwDBw58Y/u8tHbtWgIDA5Vk7sqVK+zfv5+ePXsC4OrqSpEiRVizZg3Dhg1T2j158oSwsDDlTqVv8uL5NTExUcoz+nvxPdJoNCxduvStY2rRogXjx49n9+7dNGrUKMs2rVq1Yt26daSlpVG7du23PpYQ7ytJDIUQQgjx0XN2dmbSpEmMGTOGS5cu4e3tTaFChbhz5w6HDx/GzMyMiRMnYmlpScOGDZk9ezY2NjY4OzuzZ88efvjhBwoWLKjVp5ubGwDff/89FhYWGBsbU6JECaytrfn000/p0aMH/fv3p0OHDly5coVZs2Zha2ubrfGam5uzcOFCevXqxYMHD+jYsSN2dnb8888/nDx5kn/++YfFixfn9mnKlrt379KuXTv8/f15/PgxEyZMwNjYmFGjRgHPl+XOmjULX19fWrVqxeeff05SUhKzZ8/m0aNHzJgxI1vHqVixIgAzZ86kRYsW6OvrU6lSJZo2bYqRkRHdunVjxIgRPHv2jMWLF/Pw4cO3jmnIkCH89NNPtG3blpEjR1KrVi2ePn3Knj17aNWqFV5eXnTt2pXQ0FB8fHz48ssvqVWrFoaGhly/fp3IyEjatm1Lu3bt3noMQuQ3ufmMEEIIIXTCqFGj+OWXX7hw4QK9evWiefPmjBgxgitXrtCwYUOl3Zo1a/Dy8mLEiBG0b9+eo0ePEhERQYECBbT6K1GiBPPmzePkyZN4enpSs2ZNtmzZAjy/TnHWrFns2LGDVq1asXjxYhYvXoyLi0u2x9ujRw8iIyNJSEjg888/p0mTJnz55ZccO3aMxo0b585JeQvTpk2jePHi9O7dmz59+uDg4EBkZCSlSpVS2nTv3p1NmzZx//59unTpQu/evbG0tCQyMpL69etn6zjdu3fns88+Y9GiRbi7u1OzZk1u3rxJ2bJlCQsL4+HDh7R
|
|||
|
|
"text/plain": [
|
|||
|
|
"<Figure size 640x480 with 1 Axes>"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
"metadata": {},
|
|||
|
|
"output_type": "display_data"
|
|||
|
|
}
|
|||
|
|
],
|
|||
|
|
"source": [
|
|||
|
|
"print('train data size: ', len(train_data))\n",
|
|||
|
|
"\n",
|
|||
|
|
"evals = {}\n",
|
|||
|
|
"model = train_light_model(train_data, test_data, light_params, feature_columns_new,\n",
|
|||
|
|
" [lgb.log_evaluation(period=500),\n",
|
|||
|
|
" lgb.callback.record_evaluation(evals),\n",
|
|||
|
|
" lgb.early_stopping(50, first_metric_only=True)\n",
|
|||
|
|
" ], evals,\n",
|
|||
|
|
" num_boost_round=1000, use_optuna=False,\n",
|
|||
|
|
" print_feature_importance=True)"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"execution_count": 17,
|
|||
|
|
"id": "445dff84-70b2-4fc9-a9b6-1251993324d6",
|
|||
|
|
"metadata": {
|
|||
|
|
"ExecuteTime": {
|
|||
|
|
"end_time": "2025-02-23T14:23:12.546231Z",
|
|||
|
|
"start_time": "2025-02-23T14:23:12.397923Z"
|
|||
|
|
}
|
|||
|
|
},
|
|||
|
|
"outputs": [],
|
|||
|
|
"source": [
|
|||
|
|
"catboost_params = {\n",
|
|||
|
|
" 'loss_function': 'Logloss', # 适用于二分类\n",
|
|||
|
|
" 'eval_metric': 'AUC', # 评估指标\n",
|
|||
|
|
" 'iterations': 5000,\n",
|
|||
|
|
" 'learning_rate': 0.01,\n",
|
|||
|
|
" 'depth': 10, # 控制模型复杂度\n",
|
|||
|
|
" # 'l2_leaf_reg': 3, # L2 正则化\n",
|
|||
|
|
" 'verbose': 500,\n",
|
|||
|
|
" 'early_stopping_rounds': 100,\n",
|
|||
|
|
" # 'one_hot_max_size': 50,\n",
|
|||
|
|
" # 'class_weights': [0.6, 1.2]\n",
|
|||
|
|
" 'task_type': 'GPU'\n",
|
|||
|
|
"}\n",
|
|||
|
|
"\n",
|
|||
|
|
"# model = train_catboost(train_data, test_data, feature_columns_new, catboost_params, plot=True)"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"execution_count": 18,
|
|||
|
|
"id": "c8783e36d104ec15",
|
|||
|
|
"metadata": {
|
|||
|
|
"ExecuteTime": {
|
|||
|
|
"end_time": "2025-02-23T14:23:12.687156Z",
|
|||
|
|
"start_time": "2025-02-23T14:23:12.546231Z"
|
|||
|
|
}
|
|||
|
|
},
|
|||
|
|
"outputs": [],
|
|||
|
|
"source": [
|
|||
|
|
"from sklearn.linear_model import SGDClassifier, LinearRegression, LogisticRegression\n",
|
|||
|
|
"\n",
|
|||
|
|
"\n",
|
|||
|
|
"def train_gbdt_lr(model, X_train: pd.DataFrame, y_train: pd.Series, model_type: str = 'lightgbm'):\n",
|
|||
|
|
" if model_type not in ['lightgbm', 'catboost']:\n",
|
|||
|
|
" raise ValueError(\"model_type must be either 'lightgbm' or 'catboost'\")\n",
|
|||
|
|
"\n",
|
|||
|
|
" # Step 1: Use the pre-trained GBDT model to transform the data into leaf indices\n",
|
|||
|
|
" X_train_lr = predict(model, X_train, model_type=model_type)\n",
|
|||
|
|
" # Raw predictions\n",
|
|||
|
|
" # Convert raw predictions to leaf indices (CatBoost doesn't directly support leaf indices)\n",
|
|||
|
|
" # X_train_lr = np.array([np.argmax(row) for row in X_train_lr])\n",
|
|||
|
|
"\n",
|
|||
|
|
" # # Step 2: One-hot encode the leaf indices to create new features\n",
|
|||
|
|
" # X_train_lr = pd.DataFrame(X_train_lr).add_prefix('tree_') # Add prefix for clarity\n",
|
|||
|
|
"\n",
|
|||
|
|
" # # One-hot encoding for leaf indices\n",
|
|||
|
|
" # X_train_lr = pd.get_dummies(X_train_lr, columns=X_train_lr.columns)\n",
|
|||
|
|
"\n",
|
|||
|
|
" # Step 3: Train a Logistic Regression model on the transformed features\n",
|
|||
|
|
" lr_model = SGDClassifier(loss='log_loss')\n",
|
|||
|
|
" lr_model.fit(X_train_lr, y_train)\n",
|
|||
|
|
"\n",
|
|||
|
|
" return lr_model\n",
|
|||
|
|
"\n",
|
|||
|
|
"\n",
|
|||
|
|
"def predict(model, X: pd.DataFrame, model_type: str = 'lightgbm'):\n",
|
|||
|
|
" if model_type == 'lightgbm':\n",
|
|||
|
|
" X_train_lr = model.predict(X, pred_leaf=True) # Get leaf indices for training data\n",
|
|||
|
|
" elif model_type == 'catboost':\n",
|
|||
|
|
" X_train_lr = model.calc_leaf_indexes(X, ntree_start=int(model.tree_count_ * 0.3),\n",
|
|||
|
|
" ntree_end=int(model.tree_count_ * 0.7))\n",
|
|||
|
|
" return X_train_lr"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"execution_count": 19,
|
|||
|
|
"id": "1a169a7d962c7ba7",
|
|||
|
|
"metadata": {
|
|||
|
|
"ExecuteTime": {
|
|||
|
|
"end_time": "2025-02-23T14:23:12.799750Z",
|
|||
|
|
"start_time": "2025-02-23T14:23:12.720458Z"
|
|||
|
|
}
|
|||
|
|
},
|
|||
|
|
"outputs": [],
|
|||
|
|
"source": [
|
|||
|
|
"\n",
|
|||
|
|
"# train_data = train_data.dropna(subset=['lr_label'])\n",
|
|||
|
|
"# lr_model = train_gbdt_lr(model, train_data[feature_columns_new], train_data['label'], model_type='lightgbm')"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"execution_count": 20,
|
|||
|
|
"id": "52a6ffb090e92f84",
|
|||
|
|
"metadata": {
|
|||
|
|
"ExecuteTime": {
|
|||
|
|
"end_time": "2025-02-23T14:23:12.911717Z",
|
|||
|
|
"start_time": "2025-02-23T14:23:12.834986Z"
|
|||
|
|
}
|
|||
|
|
},
|
|||
|
|
"outputs": [
|
|||
|
|
{
|
|||
|
|
"name": "stdout",
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"text": [
|
|||
|
|
"-0.14485596707818926\n"
|
|||
|
|
]
|
|||
|
|
}
|
|||
|
|
],
|
|||
|
|
"source": [
|
|||
|
|
"print(train_data['lr_label'].min())"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"execution_count": 35,
|
|||
|
|
"id": "63cd2f8d16f05b38",
|
|||
|
|
"metadata": {
|
|||
|
|
"ExecuteTime": {
|
|||
|
|
"end_time": "2025-02-23T14:37:28.748351Z",
|
|||
|
|
"start_time": "2025-02-23T14:37:28.538732Z"
|
|||
|
|
}
|
|||
|
|
},
|
|||
|
|
"outputs": [],
|
|||
|
|
"source": [
|
|||
|
|
"from tqdm import tqdm\n",
|
|||
|
|
"\n",
|
|||
|
|
"\n",
|
|||
|
|
"def incremental_training(test_data: pd.DataFrame,\n",
|
|||
|
|
" model, lr_model,\n",
|
|||
|
|
" days: int,\n",
|
|||
|
|
" back_days: int,\n",
|
|||
|
|
" feature_columns: list,\n",
|
|||
|
|
" params: dict,\n",
|
|||
|
|
" model_type: str = 'lightgbm'):\n",
|
|||
|
|
"\n",
|
|||
|
|
" test_data = test_data.sort_values(by='trade_date')\n",
|
|||
|
|
" scores = []\n",
|
|||
|
|
" unique_trade_dates = sorted(test_data['trade_date'].unique())\n",
|
|||
|
|
"\n",
|
|||
|
|
" for i in tqdm(range(0, len(unique_trade_dates), days)):\n",
|
|||
|
|
" # Get the current window of trade dates\n",
|
|||
|
|
" current_dates = unique_trade_dates[i:i + days]\n",
|
|||
|
|
" window_data = test_data[test_data['trade_date'].isin(current_dates)]\n",
|
|||
|
|
" X = window_data[feature_columns]\n",
|
|||
|
|
"\n",
|
|||
|
|
" if lr_model is not None:\n",
|
|||
|
|
" window_scores = lr_model.predict_proba(model.predict(X, pred_leaf=True, num_iteration=1000))[:, -1]\n",
|
|||
|
|
" scores.extend(window_scores)\n",
|
|||
|
|
"\n",
|
|||
|
|
" # Prepare data for incremental training\n",
|
|||
|
|
" current_dates = unique_trade_dates[max(0, i - back_days):i + days]\n",
|
|||
|
|
" window_data = test_data[test_data['trade_date'].isin(current_dates)]\n",
|
|||
|
|
" X_train = window_data[feature_columns]\n",
|
|||
|
|
" y_train = window_data['label'] # Assuming 'label' is what you're predicting\n",
|
|||
|
|
" # Incrementally train the model\n",
|
|||
|
|
" if len(y_train.unique()) > 1:\n",
|
|||
|
|
" if model_type == 'lightgbm':\n",
|
|||
|
|
" categorical_feature = [i for i, col in enumerate(feature_columns) if col.startswith('cat')]\n",
|
|||
|
|
" train_data = lgb.Dataset(X_train, label=y_train, categorical_feature=categorical_feature)\n",
|
|||
|
|
" model = lgb.train(params,\n",
|
|||
|
|
" train_set=train_data,\n",
|
|||
|
|
" num_boost_round=100,\n",
|
|||
|
|
" init_model=model,\n",
|
|||
|
|
" keep_training_booster=True)\n",
|
|||
|
|
" elif model_type == 'catboost':\n",
|
|||
|
|
" from catboost import Pool\n",
|
|||
|
|
" train_data = Pool(data=X_train, label=y_train,\n",
|
|||
|
|
" cat_features=[col for col in feature_columns if col.startswith('cat')])\n",
|
|||
|
|
" # model.set_params(**params)\n",
|
|||
|
|
" model.fit(train_data, init_model=model)\n",
|
|||
|
|
" lr_model = LogisticRegression(max_iter=10000)\n",
|
|||
|
|
" lr_model.fit(model.predict(X_train, pred_leaf=True, num_iteration=1000), y_train)\n",
|
|||
|
|
" else:\n",
|
|||
|
|
" print(current_dates)\n",
|
|||
|
|
"\n",
|
|||
|
|
" # Add the scores as a new 'score' column to the test_data\n",
|
|||
|
|
" test_data['score'] = scores\n",
|
|||
|
|
" return test_data"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"execution_count": null,
|
|||
|
|
"id": "751a6df9-d90b-4053-8769-c6c3b6654406",
|
|||
|
|
"metadata": {
|
|||
|
|
"ExecuteTime": {
|
|||
|
|
"start_time": "2025-02-23T14:37:28.771726Z"
|
|||
|
|
}
|
|||
|
|
},
|
|||
|
|
"outputs": [
|
|||
|
|
{
|
|||
|
|
"name": "stderr",
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"text": [
|
|||
|
|
" 9%|▉ | 9/102 [00:56<13:10, 8.50s/it]E:\\Python\\anaconda\\envs\\try_trader\\lib\\site-packages\\sklearn\\linear_model\\_logistic.py:460: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
|
|||
|
|
"STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
|
|||
|
|
"\n",
|
|||
|
|
"Increase the number of iterations (max_iter) or scale the data as shown in:\n",
|
|||
|
|
" https://scikit-learn.org/stable/modules/preprocessing.html\n",
|
|||
|
|
"Please also refer to the documentation for alternative solver options:\n",
|
|||
|
|
" https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
|
|||
|
|
" n_iter_i = _check_optimize_result(\n",
|
|||
|
|
" 10%|▉ | 10/102 [01:10<15:47, 10.29s/it]E:\\Python\\anaconda\\envs\\try_trader\\lib\\site-packages\\sklearn\\linear_model\\_logistic.py:460: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
|
|||
|
|
"STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
|
|||
|
|
"\n",
|
|||
|
|
"Increase the number of iterations (max_iter) or scale the data as shown in:\n",
|
|||
|
|
" https://scikit-learn.org/stable/modules/preprocessing.html\n",
|
|||
|
|
"Please also refer to the documentation for alternative solver options:\n",
|
|||
|
|
" https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
|
|||
|
|
" n_iter_i = _check_optimize_result(\n",
|
|||
|
|
" 12%|█▏ | 12/102 [01:37<17:54, 11.94s/it]E:\\Python\\anaconda\\envs\\try_trader\\lib\\site-packages\\sklearn\\linear_model\\_logistic.py:460: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
|
|||
|
|
"STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
|
|||
|
|
"\n",
|
|||
|
|
"Increase the number of iterations (max_iter) or scale the data as shown in:\n",
|
|||
|
|
" https://scikit-learn.org/stable/modules/preprocessing.html\n",
|
|||
|
|
"Please also refer to the documentation for alternative solver options:\n",
|
|||
|
|
" https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
|
|||
|
|
" n_iter_i = _check_optimize_result(\n",
|
|||
|
|
" 14%|█▎ | 14/102 [02:06<19:15, 13.13s/it]E:\\Python\\anaconda\\envs\\try_trader\\lib\\site-packages\\sklearn\\linear_model\\_logistic.py:460: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
|
|||
|
|
"STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
|
|||
|
|
"\n",
|
|||
|
|
"Increase the number of iterations (max_iter) or scale the data as shown in:\n",
|
|||
|
|
" https://scikit-learn.org/stable/modules/preprocessing.html\n",
|
|||
|
|
"Please also refer to the documentation for alternative solver options:\n",
|
|||
|
|
" https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
|
|||
|
|
" n_iter_i = _check_optimize_result(\n",
|
|||
|
|
" 15%|█▍ | 15/102 [02:21<20:00, 13.80s/it]E:\\Python\\anaconda\\envs\\try_trader\\lib\\site-packages\\sklearn\\linear_model\\_logistic.py:460: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
|
|||
|
|
"STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
|
|||
|
|
"\n",
|
|||
|
|
"Increase the number of iterations (max_iter) or scale the data as shown in:\n",
|
|||
|
|
" https://scikit-learn.org/stable/modules/preprocessing.html\n",
|
|||
|
|
"Please also refer to the documentation for alternative solver options:\n",
|
|||
|
|
" https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
|
|||
|
|
" n_iter_i = _check_optimize_result(\n",
|
|||
|
|
" 18%|█▊ | 18/102 [03:06<20:25, 14.59s/it]E:\\Python\\anaconda\\envs\\try_trader\\lib\\site-packages\\sklearn\\linear_model\\_logistic.py:460: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
|
|||
|
|
"STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
|
|||
|
|
"\n",
|
|||
|
|
"Increase the number of iterations (max_iter) or scale the data as shown in:\n",
|
|||
|
|
" https://scikit-learn.org/stable/modules/preprocessing.html\n",
|
|||
|
|
"Please also refer to the documentation for alternative solver options:\n",
|
|||
|
|
" https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
|
|||
|
|
" n_iter_i = _check_optimize_result(\n",
|
|||
|
|
" 21%|██ | 21/102 [03:48<18:57, 14.05s/it]E:\\Python\\anaconda\\envs\\try_trader\\lib\\site-packages\\sklearn\\linear_model\\_logistic.py:460: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
|
|||
|
|
"STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
|
|||
|
|
"\n",
|
|||
|
|
"Increase the number of iterations (max_iter) or scale the data as shown in:\n",
|
|||
|
|
" https://scikit-learn.org/stable/modules/preprocessing.html\n",
|
|||
|
|
"Please also refer to the documentation for alternative solver options:\n",
|
|||
|
|
" https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
|
|||
|
|
" n_iter_i = _check_optimize_result(\n",
|
|||
|
|
" 24%|██▎ | 24/102 [05:13<30:33, 23.51s/it]E:\\Python\\anaconda\\envs\\try_trader\\lib\\site-packages\\sklearn\\linear_model\\_logistic.py:460: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
|
|||
|
|
"STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
|
|||
|
|
"\n",
|
|||
|
|
"Increase the number of iterations (max_iter) or scale the data as shown in:\n",
|
|||
|
|
" https://scikit-learn.org/stable/modules/preprocessing.html\n",
|
|||
|
|
"Please also refer to the documentation for alternative solver options:\n",
|
|||
|
|
" https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
|
|||
|
|
" n_iter_i = _check_optimize_result(\n",
|
|||
|
|
" 25%|██▍ | 25/102 [05:50<35:04, 27.33s/it]E:\\Python\\anaconda\\envs\\try_trader\\lib\\site-packages\\sklearn\\linear_model\\_logistic.py:460: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
|
|||
|
|
"STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
|
|||
|
|
"\n",
|
|||
|
|
"Increase the number of iterations (max_iter) or scale the data as shown in:\n",
|
|||
|
|
" https://scikit-learn.org/stable/modules/preprocessing.html\n",
|
|||
|
|
"Please also refer to the documentation for alternative solver options:\n",
|
|||
|
|
" https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
|
|||
|
|
" n_iter_i = _check_optimize_result(\n",
|
|||
|
|
" 26%|██▋ | 27/102 [06:19<25:45, 20.60s/it]E:\\Python\\anaconda\\envs\\try_trader\\lib\\site-packages\\sklearn\\linear_model\\_logistic.py:460: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
|
|||
|
|
"STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
|
|||
|
|
"\n",
|
|||
|
|
"Increase the number of iterations (max_iter) or scale the data as shown in:\n",
|
|||
|
|
" https://scikit-learn.org/stable/modules/preprocessing.html\n",
|
|||
|
|
"Please also refer to the documentation for alternative solver options:\n",
|
|||
|
|
" https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
|
|||
|
|
" n_iter_i = _check_optimize_result(\n",
|
|||
|
|
" 27%|██▋ | 28/102 [06:33<22:54, 18.58s/it]E:\\Python\\anaconda\\envs\\try_trader\\lib\\site-packages\\sklearn\\linear_model\\_logistic.py:460: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
|
|||
|
|
"STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
|
|||
|
|
"\n",
|
|||
|
|
"Increase the number of iterations (max_iter) or scale the data as shown in:\n",
|
|||
|
|
" https://scikit-learn.org/stable/modules/preprocessing.html\n",
|
|||
|
|
"Please also refer to the documentation for alternative solver options:\n",
|
|||
|
|
" https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
|
|||
|
|
" n_iter_i = _check_optimize_result(\n",
|
|||
|
|
" 28%|██▊ | 29/102 [06:46<20:50, 17.13s/it]E:\\Python\\anaconda\\envs\\try_trader\\lib\\site-packages\\sklearn\\linear_model\\_logistic.py:460: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
|
|||
|
|
"STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
|
|||
|
|
"\n",
|
|||
|
|
"Increase the number of iterations (max_iter) or scale the data as shown in:\n",
|
|||
|
|
" https://scikit-learn.org/stable/modules/preprocessing.html\n",
|
|||
|
|
"Please also refer to the documentation for alternative solver options:\n",
|
|||
|
|
" https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
|
|||
|
|
" n_iter_i = _check_optimize_result(\n",
|
|||
|
|
" 29%|██▉ | 30/102 [06:59<18:49, 15.68s/it]E:\\Python\\anaconda\\envs\\try_trader\\lib\\site-packages\\sklearn\\linear_model\\_logistic.py:460: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
|
|||
|
|
"STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
|
|||
|
|
"\n",
|
|||
|
|
"Increase the number of iterations (max_iter) or scale the data as shown in:\n",
|
|||
|
|
" https://scikit-learn.org/stable/modules/preprocessing.html\n",
|
|||
|
|
"Please also refer to the documentation for alternative solver options:\n",
|
|||
|
|
" https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
|
|||
|
|
" n_iter_i = _check_optimize_result(\n",
|
|||
|
|
" 30%|███ | 31/102 [07:26<17:01, 14.39s/it]\n"
|
|||
|
|
]
|
|||
|
|
}
|
|||
|
|
],
|
|||
|
|
"source": [
|
|||
|
|
"predictions_test = incremental_training(test_data, model, None, 5, 0, feature_columns_new, light_params, model_type='lightgbm')\n",
|
|||
|
|
"predictions_test = predictions_test.loc[predictions_test.groupby('trade_date')['score'].idxmax()]\n",
|
|||
|
|
"predictions_test[['trade_date', 'score', 'ts_code']].to_csv('predictions_test.tsv', index=False)"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"execution_count": null,
|
|||
|
|
"id": "6329ae6f358dabe7",
|
|||
|
|
"metadata": {
|
|||
|
|
"ExecuteTime": {
|
|||
|
|
"end_time": "2025-02-23T14:44:59.546639600Z",
|
|||
|
|
"start_time": "2025-02-22T21:50:47.121759Z"
|
|||
|
|
}
|
|||
|
|
},
|
|||
|
|
"outputs": [],
|
|||
|
|
"source": [
|
|||
|
|
"from lightgbm import LGBMClassifier"
|
|||
|
|
]
|
|||
|
|
}
|
|||
|
|
],
|
|||
|
|
"metadata": {
|
|||
|
|
"kernelspec": {
|
|||
|
|
"display_name": "Python 3 (ipykernel)",
|
|||
|
|
"language": "python",
|
|||
|
|
"name": "python3"
|
|||
|
|
},
|
|||
|
|
"language_info": {
|
|||
|
|
"codemirror_mode": {
|
|||
|
|
"name": "ipython",
|
|||
|
|
"version": 3
|
|||
|
|
},
|
|||
|
|
"file_extension": ".py",
|
|||
|
|
"mimetype": "text/x-python",
|
|||
|
|
"name": "python",
|
|||
|
|
"nbconvert_exporter": "python",
|
|||
|
|
"pygments_lexer": "ipython3",
|
|||
|
|
"version": "3.11.11"
|
|||
|
|
}
|
|||
|
|
},
|
|||
|
|
"nbformat": 4,
|
|||
|
|
"nbformat_minor": 5
|
|||
|
|
}
|