Files
NewStock/code/train/PlUpdateClassify.ipynb

1017 lines
179 KiB
Plaintext
Raw Normal View History

2025-04-03 00:45:07 +08:00
{
"cells": [
{
"cell_type": "code",
"id": "79a7758178bafdd3",
"metadata": {
"ExecuteTime": {
"end_time": "2025-02-21T15:09:12.588477Z",
"start_time": "2025-02-21T15:09:12.513776Z"
}
},
"source": [
"# %reload_ext loads autoreload on a fresh kernel and reloads it when the cell is re-run,\n",
"# so the cell is idempotent (plain %load_ext prints 'already loaded' on a second run).\n",
"%reload_ext autoreload\n",
"%autoreload 2"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The autoreload extension is already loaded. To reload it, use:\n",
" %reload_ext autoreload\n"
]
}
],
"execution_count": 13
},
{
"cell_type": "code",
"id": "a79cafb06a7e0e43",
"metadata": {
"ExecuteTime": {
"end_time": "2025-02-21T15:09:52.492559Z",
"start_time": "2025-02-21T15:09:12.603475Z"
}
},
"source": [
"from utils.utils import read_and_merge_h5_data_polars\n",
"\n",
"# Each source is merged onto the running frame on ['ts_code', 'trade_date'] (see the helper's log).\n",
"# Entry: (progress label, h5 path, h5 key, columns to read, extra merge kwargs).\n",
"daily_sources = [\n",
"    ('daily data', '../../data/daily_data.h5', 'daily_data',\n",
"     ['ts_code', 'trade_date', 'open', 'close', 'high', 'low', 'vol'], {}),\n",
"    ('daily basic', '../../data/daily_basic.h5', 'daily_basic',\n",
"     ['ts_code', 'trade_date', 'turnover_rate', 'pe_ttm', 'circ_mv', 'volume_ratio', 'is_st'],\n",
"     {'join': 'inner'}),\n",
"    ('stk limit', '../../data/stk_limit.h5', 'stk_limit',\n",
"     ['ts_code', 'trade_date', 'pre_close', 'up_limit', 'down_limit'], {}),\n",
"    ('money flow', '../../data/money_flow.h5', 'money_flow',\n",
"     ['ts_code', 'trade_date', 'buy_sm_vol', 'sell_sm_vol', 'buy_lg_vol', 'sell_lg_vol',\n",
"      'buy_elg_vol', 'sell_elg_vol', 'net_mf_vol'], {}),\n",
"]\n",
"\n",
"df = None  # the first source starts the frame\n",
"for label, path, key, columns, merge_kwargs in daily_sources:\n",
"    print(label)\n",
"    df = read_and_merge_h5_data_polars(path, key=key, columns=columns, df=df, **merge_kwargs)"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"daily data\n",
"daily basic\n",
"inner merge on ['ts_code', 'trade_date']\n",
"stk limit\n",
"left merge on ['ts_code', 'trade_date']\n",
"money flow\n",
"left merge on ['ts_code', 'trade_date']\n"
]
}
],
"execution_count": 14
},
{
"cell_type": "code",
"id": "a4eec8c93f6a7cc3",
"metadata": {
"ExecuteTime": {
"end_time": "2025-02-21T15:09:53.009911Z",
"start_time": "2025-02-21T15:09:52.621225Z"
}
},
"source": [
"print('industry')\n",
"# Merged on ts_code only, so the industry table presumably holds one row per stock\n",
"# (verify: repeated ts_codes there would duplicate rows here).\n",
"df = read_and_merge_h5_data_polars(\n",
"    '../../data/industry_data.h5',\n",
"    key='industry_data',\n",
"    columns=['ts_code', 'l2_code'],\n",
"    df=df,\n",
"    on=['ts_code'],\n",
"    join='left',\n",
")\n"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"industry\n",
"left merge on ['ts_code']\n"
]
}
],
"execution_count": 15
},
{
"cell_type": "code",
"id": "4243c838-1775-4c33-8ec5-38d3fae3e55c",
"metadata": {
"ExecuteTime": {
"end_time": "2025-02-21T15:09:53.087600Z",
"start_time": "2025-02-21T15:09:53.016442Z"
}
},
"source": [
"import polars as pl\n",
"\n",
"def print_df_info(df: pl.DataFrame):\n",
"    \"\"\"Print row/column counts, the estimated memory footprint and every column's dtype.\"\"\"\n",
"    n_rows, n_cols = df.shape\n",
"    print(f\"Shape: {n_rows} rows, {n_cols} columns\")\n",
"\n",
"    # estimated_size() is in bytes; report it in MB (1024 ** 2 bytes).\n",
"    megabytes = df.estimated_size() / 1024 ** 2\n",
"    print(f\"Memory usage: {megabytes:.2f} MB\")\n",
"\n",
"    print(\"\\nColumn types:\")\n",
"    for name, dtype in df.schema.items():\n",
"        print(f\"{name}: {dtype}\")\n",
"\n",
"\n",
"print_df_info(df)\n"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Shape: 8418162 rows, 22 columns\n",
"Memory usage: 1380.84 MB\n",
"\n",
"Column types:\n",
"ts_code: String\n",
"trade_date: String\n",
"open: Float64\n",
"close: Float64\n",
"high: Float64\n",
"low: Float64\n",
"vol: Float64\n",
"turnover_rate: Float64\n",
"pe_ttm: Float64\n",
"circ_mv: Float64\n",
"volume_ratio: Float64\n",
"is_st: Boolean\n",
"up_limit: Float64\n",
"down_limit: Float64\n",
"buy_sm_vol: Int64\n",
"sell_sm_vol: Int64\n",
"buy_lg_vol: Int64\n",
"sell_lg_vol: Int64\n",
"buy_elg_vol: Int64\n",
"sell_elg_vol: Int64\n",
"net_mf_vol: Int64\n",
"l2_code: String\n"
]
}
],
"execution_count": 16
},
{
"cell_type": "code",
"id": "99776e73-f310-47c0-953e-2cf73ff13310",
"metadata": {
"ExecuteTime": {
"end_time": "2025-02-21T15:09:53.229940Z",
"start_time": "2025-02-21T15:09:53.151622Z"
}
},
"source": [
"import numpy as np\n",
"import talib\n",
"\n",
"\n",
"def get_technical_factor(df: pl.LazyFrame) -> pl.LazyFrame:\n",
"    \"\"\"Add per-stock technical factors: candle shadows, ATR, OBV, RSI, returns and volatility.\n",
"\n",
"    TA-Lib indicators are computed once per ts_code group, aggregated together with that\n",
"    group's trade_date list, exploded back to one row per (ts_code, trade_date) and\n",
"    left-joined on BOTH keys. The previous version forwarded ``output_col`` to the TA-Lib\n",
"    function (ATR() got an unexpected keyword argument 'output_col') and joined on ts_code\n",
"    alone, which pairs every row of a stock with every indicator value of that stock.\n",
"    All remaining factors are native window expressions partitioned by ts_code.\n",
"\n",
"    Args:\n",
"        df: lazy or eager frame with ts_code, trade_date, open, high, low, close and vol.\n",
"\n",
"    Returns:\n",
"        The frame sorted by (ts_code, trade_date) with the factor columns appended.\n",
"    \"\"\"\n",
"    # Path-dependent indicators need each stock's rows in date order.\n",
"    df = df.sort(['ts_code', 'trade_date'])\n",
"\n",
"    # Upper / lower candle shadow relative to the close.\n",
"    df = df.with_columns([\n",
"        ((pl.col('high') - pl.max_horizontal(['close', 'open'])) / pl.col('close')).alias('up'),\n",
"        ((pl.min_horizontal(['close', 'open']) - pl.col('low')) / pl.col('close')).alias('down'),\n",
"    ])\n",
"\n",
"    def talib_expr(col_names, func, output_col, **func_kwargs):\n",
"        # Aggregation expression running func(*columns, **func_kwargs) on one group's float64\n",
"        # arrays. output_col only names the result; it is never passed to func.\n",
"        def apply(series_list):\n",
"            arrays = [np.asarray(s.to_numpy(), dtype=np.float64) for s in series_list]\n",
"            return pl.Series(func(*arrays, **func_kwargs))\n",
"        return pl.map_batches(col_names, apply, return_dtype=pl.Float64).alias(output_col)\n",
"\n",
"    def maobv(close, vol, timeperiod):\n",
"        # Moving average of OBV, computed on the same stock's OBV series.\n",
"        return talib.SMA(talib.OBV(close, vol), timeperiod=timeperiod)\n",
"\n",
"    # (output column, input columns, function, function kwargs)\n",
"    talib_specs = [\n",
"        ('atr_14', ['high', 'low', 'close'], talib.ATR, {'timeperiod': 14}),\n",
"        ('atr_6', ['high', 'low', 'close'], talib.ATR, {'timeperiod': 6}),\n",
"        ('obv', ['close', 'vol'], talib.OBV, {}),\n",
"        ('maobv_6', ['close', 'vol'], maobv, {'timeperiod': 6}),\n",
"        ('rsi_3', ['close'], talib.RSI, {'timeperiod': 3}),\n",
"        ('rsi_6', ['close'], talib.RSI, {'timeperiod': 6}),\n",
"        ('rsi_9', ['close'], talib.RSI, {'timeperiod': 9}),\n",
"    ]\n",
"    talib_cols = [spec[0] for spec in talib_specs]\n",
"    indicators = (\n",
"        df.group_by('ts_code', maintain_order=True)\n",
"        .agg([pl.col('trade_date')]\n",
"             + [talib_expr(cols, func, name, **kwargs) for name, cols, func, kwargs in talib_specs])\n",
"        .explode(['trade_date'] + talib_cols)\n",
"    )\n",
"    df = df.join(indicators, on=['ts_code', 'trade_date'], how='left')\n",
"    df = df.with_columns((pl.col('obv') - pl.col('maobv_6')).alias('obv-maobv_6'))\n",
"\n",
"    # Simple returns over 5 / 10 / 20 rows of the same stock.\n",
"    df = df.with_columns([\n",
"        (pl.col('close') / pl.col('close').shift(n) - 1).over('ts_code').alias(f'return_{n}')\n",
"        for n in (5, 10, 20)\n",
"    ])\n",
"\n",
"    # 5-row average close relative to the current close.\n",
"    df = df.with_columns(\n",
"        (pl.col('close').rolling_mean(window_size=5) / pl.col('close')).over('ts_code').alias('avg_close_5')\n",
"    )\n",
"\n",
"    # Rolling std of daily returns; std_return_90_2 is the 90-row std lagged by 10 rows.\n",
"    daily_return = pl.col('close').pct_change()\n",
"    df = df.with_columns(\n",
"        [daily_return.rolling_std(window_size=w).over('ts_code').alias(f'std_return_{w}')\n",
"         for w in (5, 15, 25, 90)]\n",
"        + [pl.col('close').shift(10).pct_change().rolling_std(window_size=90)\n",
"           .over('ts_code').alias('std_return_90_2')]\n",
"    )\n",
"\n",
"    # Short- vs long-window volatility ratios and the change in 90-row volatility.\n",
"    df = df.with_columns([\n",
"        (pl.col('std_return_5') / pl.col('std_return_90')).alias('std_return_5 / std_return_90'),\n",
"        (pl.col('std_return_5') / pl.col('std_return_25')).alias('std_return_5 / std_return_25'),\n",
"        (pl.col('std_return_90') - pl.col('std_return_90_2')).alias('std_return_90 - std_return_90_2'),\n",
"    ])\n",
"\n",
"    return df\n",
"\n",
"def get_act_factor(df: pl.LazyFrame, cat=True) -> pl.LazyFrame:\n",
"    \"\"\"Add EMA-slope activity factors, optional categorical flags and cross-sectional ranks.\n",
"\n",
"    act_factor1..4 = arctan(100 * one-row % change of EMA5 / EMA13 / EMA20 / EMA60) * 57.3\n",
"    divided by 50 / 40 / 21 / 10 (57.3 presumably converts radians to degrees). The change\n",
"    is now taken within each ts_code; the old whole-column shift made each stock's first\n",
"    row use the previous stock's EMA. The EMAs are computed per stock and joined back on\n",
"    (ts_code, trade_date); the old join on ts_code alone multiplied rows.\n",
"\n",
"    Args:\n",
"        df: lazy or eager frame with ts_code, trade_date and close.\n",
"        cat: also add the boolean cat_af1..cat_af4 flags.\n",
"\n",
"    Returns:\n",
"        The frame sorted by (ts_code, trade_date) with ema_*, act_factor1..6, optional\n",
"        cat_af1..4 and rank_act_factor1..3 appended.\n",
"    \"\"\"\n",
"    df = df.sort(['ts_code', 'trade_date'])\n",
"\n",
"    def ema_expr(timeperiod):\n",
"        # Per-group TA-Lib EMA of close; timeperiod is bound per call, not via a loop lambda.\n",
"        def apply(series_list):\n",
"            close = np.asarray(series_list[0].to_numpy(), dtype=np.float64)\n",
"            return pl.Series(talib.EMA(close, timeperiod=timeperiod))\n",
"        return pl.map_batches(['close'], apply, return_dtype=pl.Float64).alias(f'ema_{timeperiod}')\n",
"\n",
"    ema_periods = (5, 13, 20, 60)\n",
"    ema_cols = [f'ema_{period}' for period in ema_periods]\n",
"    emas = (\n",
"        df.group_by('ts_code', maintain_order=True)\n",
"        .agg([pl.col('trade_date')] + [ema_expr(period) for period in ema_periods])\n",
"        .explode(['trade_date'] + ema_cols)\n",
"    )\n",
"    df = df.join(emas, on=['ts_code', 'trade_date'], how='left')\n",
"\n",
"    # (EMA column, divisor) for act_factor1 .. act_factor4.\n",
"    act_specs = [('ema_5', 50), ('ema_13', 40), ('ema_20', 21), ('ema_60', 10)]\n",
"    df = df.with_columns([\n",
"        (((pl.col(ema) / pl.col(ema).shift(1) - 1) * 100).arctan() * 57.3 / divisor)\n",
"        .over('ts_code')\n",
"        .alias(f'act_factor{i}')\n",
"        for i, (ema, divisor) in enumerate(act_specs, start=1)\n",
"    ])\n",
"\n",
"    if cat:\n",
"        df = df.with_columns([\n",
"            (pl.col('act_factor1') > 0).alias('cat_af1'),\n",
"            (pl.col('act_factor2') > pl.col('act_factor1')).alias('cat_af2'),\n",
"            (pl.col('act_factor3') > pl.col('act_factor2')).alias('cat_af3'),\n",
"            (pl.col('act_factor4') > pl.col('act_factor3')).alias('cat_af4'),\n",
"        ])\n",
"\n",
"    # act_factor5: sum of the four slopes; act_factor6: normalised af1 - af2 difference.\n",
"    af1, af2 = pl.col('act_factor1'), pl.col('act_factor2')\n",
"    df = df.with_columns([\n",
"        (af1 + af2 + pl.col('act_factor3') + pl.col('act_factor4')).alias('act_factor5'),\n",
"        ((af1 - af2) / (af1 ** 2 + af2 ** 2).sqrt()).alias('act_factor6'),\n",
"    ])\n",
"\n",
"    # Cross-sectional rank within each trade date (descending: 1 = largest).\n",
"    df = df.with_columns([\n",
"        pl.col(f'act_factor{i}').rank(method='average', descending=True).over('trade_date')\n",
"        .alias(f'rank_act_factor{i}')\n",
"        for i in (1, 2, 3)\n",
"    ])\n",
"\n",
"    return df"
],
"outputs": [],
"execution_count": 17
},
{
"cell_type": "code",
"id": "53f86ddc0677a6d7",
"metadata": {
"scrolled": true,
"ExecuteTime": {
"end_time": "2025-02-21T15:09:53.292850Z",
"start_time": "2025-02-21T15:09:53.234877Z"
}
},
"source": [
"# Raw input columns that must NOT become model features (they are subtracted from the\n",
"# feature list later). turnover_rate, pe_ttm, volume_ratio and l2_code are raw inputs too,\n",
"# but are left out of this list so that they remain usable as features.\n",
"origin_columns = [\n",
"    name for name in df.columns\n",
"    if name not in ('turnover_rate', 'pe_ttm', 'volume_ratio', 'l2_code')\n",
"]\n"
],
"outputs": [],
"execution_count": 18
},
{
"cell_type": "code",
"id": "5f3d9aece75318cd",
"metadata": {
"ExecuteTime": {
"end_time": "2025-02-21T15:09:55.552350Z",
"start_time": "2025-02-21T15:09:53.310664Z"
}
},
"source": [
"def filter_data(df):\n",
"    \"\"\"Return ``df`` as a LazyFrame without ST stocks and without the excluded ts_code patterns.\n",
"\n",
"    Excluded: ts_code ending in 'BJ', or starting with '30', '68' or '8'.\n",
"    \"\"\"\n",
"    ts_code = pl.col('ts_code')\n",
"    keep = ~pl.col('is_st') & ~ts_code.str.ends_with('BJ')\n",
"    for prefix in ('30', '68', '8'):\n",
"        keep = keep & ~ts_code.str.starts_with(prefix)\n",
"    return df.lazy().filter(keep)\n",
"\n",
"\n",
"\n",
"# Universe filter -> technical factors -> activity factors; everything stays lazy until collect().\n",
"# Still disabled: get_money_flow_factor, get_alpha_factor and the industry_df / index_data merges.\n",
"tdf = (\n",
"    filter_data(df)\n",
"    .pipe(get_technical_factor)\n",
"    .pipe(get_act_factor)\n",
")\n",
"\n",
"print_df_info(tdf.collect())"
],
"outputs": [
{
"ename": "PanicException",
"evalue": "python function failed: ATR() got an unexpected keyword argument 'output_col'",
"output_type": "error",
"traceback": [
"\u001B[1;31m---------------------------------------------------------------------------\u001B[0m",
"\u001B[1;31mPanicException\u001B[0m Traceback (most recent call last)",
"Cell \u001B[1;32mIn[19], line 24\u001B[0m\n\u001B[0;32m 17\u001B[0m tdf \u001B[38;5;241m=\u001B[39m get_act_factor(tdf)\n\u001B[0;32m 18\u001B[0m \u001B[38;5;66;03m# df = get_money_flow_factor(df)\u001B[39;00m\n\u001B[0;32m 19\u001B[0m \u001B[38;5;66;03m# df = get_alpha_factor(df)\u001B[39;00m\n\u001B[0;32m 20\u001B[0m \u001B[38;5;66;03m# df = df.merge(industry_df, on=['l2_code', 'trade_date'], how='left')\u001B[39;00m\n\u001B[0;32m 21\u001B[0m \u001B[38;5;66;03m# df = df.rename(columns={'l2_code': 'cat_l2_code'})\u001B[39;00m\n\u001B[0;32m 22\u001B[0m \u001B[38;5;66;03m# df = df.merge(index_data, on='trade_date', how='left')\u001B[39;00m\n\u001B[1;32m---> 24\u001B[0m print_df_info(\u001B[43mtdf\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mcollect\u001B[49m\u001B[43m(\u001B[49m\u001B[43m)\u001B[49m)\n",
"File \u001B[1;32mE:\\Python\\anaconda\\envs\\try_trader\\lib\\site-packages\\polars\\lazyframe\\frame.py:2053\u001B[0m, in \u001B[0;36mLazyFrame.collect\u001B[1;34m(self, type_coercion, predicate_pushdown, projection_pushdown, simplify_expression, slice_pushdown, comm_subplan_elim, comm_subexpr_elim, cluster_with_columns, collapse_joins, no_optimization, streaming, engine, background, _eager, **_kwargs)\u001B[0m\n\u001B[0;32m 2051\u001B[0m \u001B[38;5;66;03m# Only for testing purposes\u001B[39;00m\n\u001B[0;32m 2052\u001B[0m callback \u001B[38;5;241m=\u001B[39m _kwargs\u001B[38;5;241m.\u001B[39mget(\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mpost_opt_callback\u001B[39m\u001B[38;5;124m\"\u001B[39m, callback)\n\u001B[1;32m-> 2053\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m wrap_df(\u001B[43mldf\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mcollect\u001B[49m\u001B[43m(\u001B[49m\u001B[43mcallback\u001B[49m\u001B[43m)\u001B[49m)\n",
"\u001B[1;31mPanicException\u001B[0m: python function failed: ATR() got an unexpected keyword argument 'output_col'"
]
}
],
"execution_count": 19
},
{
"cell_type": "code",
"id": "f4f16d63ad18d1bc",
"metadata": {
"ExecuteTime": {
"end_time": "2025-02-21T15:09:55.555352100Z",
"start_time": "2025-02-21T15:06:05.193979Z"
}
},
"source": [
"# Model features: every column of the factor pipeline except identifiers, the label,\n",
"# look-ahead ('future') / 'score' / private ('_') columns and the raw inputs in origin_columns.\n",
"# collect_schema() resolves a LazyFrame's columns explicitly; plain .columns on a LazyFrame\n",
"# emits a PerformanceWarning.\n",
"tdf_columns = tdf.collect_schema().names() if isinstance(tdf, pl.LazyFrame) else list(tdf.columns)\n",
"excluded_columns = {'trade_date', 'ts_code', 'label'} | set(origin_columns)\n",
"feature_columns = [\n",
"    col for col in tdf_columns\n",
"    if col not in excluded_columns\n",
"    and 'future' not in col\n",
"    and 'score' not in col\n",
"    and not col.startswith('_')\n",
"]\n",
"print(feature_columns)"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['turnover_rate', 'pe_ttm', 'volume_ratio', 'l2_code', 'up', 'down', 'atr_14', 'atr_6', 'rsi_3', 'rsi_6', 'rsi_9', 'return_5', 'return_10', 'return_20', 'avg_close_5', 'std_return_5', 'std_return_15', 'std_return_25', 'std_return_90', 'std_return_90_2', 'std_return_5 / std_return_90', 'std_return_5 / std_return_25', 'std_return_90 - std_return_90_2']\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"C:\\Users\\liaozhaorun\\AppData\\Local\\Temp\\ipykernel_6468\\1049439272.py:1: PerformanceWarning: Determining the column names of a LazyFrame requires resolving its schema, which is a potentially expensive operation. Use `LazyFrame.collect_schema().names()` to get the column names without this warning.\n",
" feature_columns = [col for col in tdf.columns if col not in ['trade_date',\n"
]
}
],
"execution_count": 8
},
{
"cell_type": "code",
"id": "0ebdfb92-d88b-4b5c-a715-675dab876fc0",
"metadata": {
"ExecuteTime": {
"end_time": "2025-02-21T15:06:05.360476Z",
"start_time": "2025-02-21T15:06:05.298420Z"
}
},
"source": [
"import pandas as pd\n",
"\n",
"\n",
"def create_deviation_within_dates(df, feature_columns, groupby_col='cat_l2_code'):\n",
"    \"\"\"Add same-day, same-group deviation features for every numeric feature.\n",
"\n",
"    For each feature f two columns are appended:\n",
"      * deviation_median_f = f - median(f) over rows sharing (trade_date, groupby_col)\n",
"      * deviation_mean_f   = f - mean(f)   over rows sharing (trade_date, groupby_col)\n",
"\n",
"    The mean used to be taken per groupby_col over ALL dates, which mixed values from later\n",
"    dates into earlier rows (look-ahead leakage) and disagreed with the median variant.\n",
"\n",
"    Args:\n",
"        df: pandas DataFrame with trade_date, groupby_col and the feature columns.\n",
"        feature_columns: candidate features; names containing 'cat' or 'index' (and\n",
"            trade_date itself) are skipped.\n",
"        groupby_col: cross-sectional grouping column (default: the industry code cat_l2_code).\n",
"\n",
"    Returns:\n",
"        (df with the deviation columns appended, feature_columns + the new column names)\n",
"    \"\"\"\n",
"    ret_feature_columns = list(feature_columns)\n",
"    num_features = [col for col in feature_columns\n",
"                    if 'cat' not in col and 'index' not in col and col != 'trade_date']\n",
"    if not num_features:\n",
"        return df, ret_feature_columns\n",
"\n",
"    # One groupby for all features instead of two per feature.\n",
"    grouped = df.groupby(['trade_date', groupby_col])[num_features]\n",
"    medians = grouped.transform('median')\n",
"    means = grouped.transform('mean')\n",
"\n",
"    new_columns = {}\n",
"    for feature in num_features:\n",
"        new_columns[f'deviation_median_{feature}'] = df[feature] - medians[feature]\n",
"        new_columns[f'deviation_mean_{feature}'] = df[feature] - means[feature]\n",
"    ret_feature_columns.extend(new_columns)\n",
"\n",
"    df = pd.concat([df, pd.DataFrame(new_columns, index=df.index)], axis=1)\n",
"    return df, ret_feature_columns\n"
],
"outputs": [],
"execution_count": 9
},
{
"cell_type": "code",
"id": "fbb968383f8cf2c7",
"metadata": {
"ExecuteTime": {
"end_time": "2025-02-21T15:06:05.485541Z",
"start_time": "2025-02-21T15:06:05.392972Z"
}
},
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"\n",
"def get_qcuts(series, quantiles):\n",
"    \"\"\"Return the 0-based quantile-bucket label of the LAST element of ``series``.\n",
"\n",
"    Meant for rolling windows. Works for ndarray and Series input alike: the previous\n",
"    ``q[-1]`` was a label lookup on a Series (KeyError for a default integer index).\n",
"    Buckets collapse when bin edges coincide (duplicates='drop').\n",
"    \"\"\"\n",
"    labels = pd.qcut(series, q=quantiles, labels=False, duplicates='drop')\n",
"    return np.asarray(labels)[-1]\n",
"\n",
"\n",
"# Presumably the rolling-window length and bucket count intended for get_qcuts\n",
"# (not used elsewhere in this cell).\n",
"window = 5\n",
"quantiles = 20\n",
"\n",
"def calculate_risk_adjusted_target(df, days=5):\n",
"    \"\"\"Sharpe-like forward target: future_return / future_volatility, per ts_code.\n",
"\n",
"    future_return is the ``days``-ahead close-to-close return; future_volatility is the\n",
"    rolling std (window ``days``, min_periods=1) of future_return within each stock.\n",
"    The old code multiplied the two, although the inf clean-up that followed only makes\n",
"    sense for a division. Zero volatility yields +-inf, which is mapped to NaN.\n",
"\n",
"    Works on a sorted copy, so ``df`` is not modified; the result keeps df's index labels.\n",
"    \"\"\"\n",
"    df = df.sort_values(by=['ts_code', 'trade_date'])\n",
"\n",
"    df['future_close'] = df.groupby('ts_code')['close'].shift(-days)\n",
"    df['future_return'] = (df['future_close'] - df['close']) / df['close']\n",
"\n",
"    df['future_volatility'] = (\n",
"        df.groupby('ts_code')['future_return']\n",
"        .rolling(days, min_periods=1)\n",
"        .std()\n",
"        .reset_index(level=0, drop=True)\n",
"    )\n",
"    # Plain assignment instead of a chained inplace replace, which Copy-on-Write can drop.\n",
"    df['sharpe_ratio'] = (df['future_return'] / df['future_volatility']).replace([np.inf, -np.inf], np.nan)\n",
"\n",
"    return df['sharpe_ratio']\n",
"\n",
"def get_label(df, horizon=4, threshold=0.03):\n",
"    \"\"\"Label whether a stock's close rises by at least ``threshold`` over ``horizon`` rows.\n",
"\n",
"    Side effect: adds future_close (close ``horizon`` rows ahead), future_open /\n",
"    future_high (next row) and future_return columns to ``df`` in place.\n",
"\n",
"    Returns:\n",
"        float Series: 1.0 / 0.0, or NaN where the future close is unknown (the last\n",
"        ``horizon`` rows of every stock). The old boolean version turned those rows into\n",
"        spurious negatives (NaN >= 0.03 is False), so dropna(subset=['label']) kept them.\n",
"\n",
"    Assumes each ts_code's rows are in trade_date order: the shifts are positional.\n",
"    \"\"\"\n",
"    by_code = df.groupby('ts_code')\n",
"    df['future_close'] = by_code['close'].shift(-horizon)\n",
"    df['future_open'] = by_code['open'].shift(-1)\n",
"    df['future_high'] = by_code['high'].shift(-1)\n",
"    df['future_return'] = (df['future_close'] - df['close']) / df['close']\n",
"\n",
"    has_future = df['future_return'].notna()\n",
"    return (df['future_return'] >= threshold).astype(float).where(has_future)\n",
"\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"# `df` is still the raw polars frame from the loading cells, so the pandas calls below failed\n",
"# (AttributeError: no attribute 'groupby'). Materialise the polars factor pipeline `tdf`\n",
"# as pandas instead (polars' to_pandas() needs pyarrow installed).\n",
"df = tdf.collect().to_pandas()\n",
"# The model code treats 'cat'-prefixed columns as categorical; expose the industry code that\n",
"# way, both in the frame and in the feature list that was built from the polars schema.\n",
"df = df.rename(columns={'l2_code': 'cat_l2_code'})\n",
"feature_columns = ['cat_l2_code' if col == 'l2_code' else col for col in feature_columns]\n",
"# polars loads trade_date as a string; the date filters here and the later .strftime()\n",
"# calls need datetimes.\n",
"df['trade_date'] = pd.to_datetime(df['trade_date'])\n",
"\n",
"# get_label shifts positionally within each ts_code, so every stock must be date-ordered.\n",
"df = df.sort_values(by=['ts_code', 'trade_date'])\n",
"df['label'] = get_label(df)\n",
"df = df.sort_values(by=['trade_date', 'ts_code'])\n",
"\n",
"TRAIN_START = '2016-01-01'\n",
"SPLIT_DATE = '2023-01-01'\n",
"# Half-open ranges so that no trade date can fall into both splits.\n",
"# NOTE(review): labels of the last training days still use closes a few rows ahead, i.e.\n",
"# from the test period; consider an embargo gap before SPLIT_DATE.\n",
"train_data = df[(df['trade_date'] >= TRAIN_START) & (df['trade_date'] < SPLIT_DATE)]\n",
"test_data = df[df['trade_date'] >= SPLIT_DATE]\n",
"\n",
"# TODO: industry_df (keyed by cat_l2_code and trade_date) is not built anywhere in this\n",
"# notebook; merge it only when it exists instead of failing with NameError.\n",
"if 'industry_df' in globals():\n",
"    train_data = train_data.merge(industry_df, on=['cat_l2_code', 'trade_date'], how='left')\n",
"    test_data = test_data.merge(industry_df, on=['cat_l2_code', 'trade_date'], how='left')\n",
"\n",
"# Per trade date keep the 1000 stocks with the highest 20-day return.\n",
"train_data = train_data.groupby('trade_date', group_keys=False).apply(lambda x: x.nlargest(1000, 'return_20'))\n",
"test_data = test_data.groupby('trade_date', group_keys=False).apply(lambda x: x.nlargest(1000, 'return_20'))\n",
"\n",
"train_data, test_data = train_data.replace([np.inf, -np.inf], np.nan), test_data.replace([np.inf, -np.inf], np.nan)\n"
],
"outputs": [
{
"ename": "AttributeError",
"evalue": "'DataFrame' object has no attribute 'groupby'",
"output_type": "error",
"traceback": [
"\u001B[1;31m---------------------------------------------------------------------------\u001B[0m",
"\u001B[1;31mAttributeError\u001B[0m Traceback (most recent call last)",
"Cell \u001B[1;32mIn[10], line 38\u001B[0m\n\u001B[0;32m 28\u001B[0m \u001B[38;5;66;03m# labels = df['future_af11']\u001B[39;00m\n\u001B[0;32m 29\u001B[0m \u001B[38;5;66;03m# labels = df['ema_5'].shift(-1) - df['close']\u001B[39;00m\n\u001B[0;32m 30\u001B[0m \u001B[38;5;66;03m# df['label'] = (df['future_af11'] - df['act_factor1']) / df['act_factor1']\u001B[39;00m\n\u001B[1;32m (...)\u001B[0m\n\u001B[0;32m 34\u001B[0m \u001B[38;5;66;03m# labels = df['label'].clip(lower=lower_percentile, upper=upper_percentile)\u001B[39;00m\n\u001B[0;32m 35\u001B[0m \u001B[38;5;66;03m# labels = calculate_risk_adjusted_return(df, days=3, history_days=3, method='ratio')\u001B[39;00m\n\u001B[0;32m 36\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m labels\n\u001B[1;32m---> 38\u001B[0m df[\u001B[38;5;124m'\u001B[39m\u001B[38;5;124mlabel\u001B[39m\u001B[38;5;124m'\u001B[39m] \u001B[38;5;241m=\u001B[39m \u001B[43mget_label\u001B[49m\u001B[43m(\u001B[49m\u001B[43mdf\u001B[49m\u001B[43m)\u001B[49m\n\u001B[0;32m 39\u001B[0m \u001B[38;5;66;03m# df = df.apply(lambda x: x.astype('float32') if x.dtype in ['float64', 'float32'] else x)\u001B[39;00m\n\u001B[0;32m 40\u001B[0m df \u001B[38;5;241m=\u001B[39m df\u001B[38;5;241m.\u001B[39msort_values(by\u001B[38;5;241m=\u001B[39m[\u001B[38;5;124m'\u001B[39m\u001B[38;5;124mtrade_date\u001B[39m\u001B[38;5;124m'\u001B[39m, \u001B[38;5;124m'\u001B[39m\u001B[38;5;124mts_code\u001B[39m\u001B[38;5;124m'\u001B[39m])\n",
"Cell \u001B[1;32mIn[10], line 23\u001B[0m, in \u001B[0;36mget_label\u001B[1;34m(df)\u001B[0m\n\u001B[0;32m 21\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mget_label\u001B[39m(df):\n\u001B[0;32m 22\u001B[0m \u001B[38;5;66;03m# labels = df['future_af13'] - df['act_factor1']\u001B[39;00m\n\u001B[1;32m---> 23\u001B[0m df[\u001B[38;5;124m'\u001B[39m\u001B[38;5;124mfuture_close\u001B[39m\u001B[38;5;124m'\u001B[39m] \u001B[38;5;241m=\u001B[39m \u001B[43mdf\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mgroupby\u001B[49m(\u001B[38;5;124m'\u001B[39m\u001B[38;5;124mts_code\u001B[39m\u001B[38;5;124m'\u001B[39m)[\u001B[38;5;124m'\u001B[39m\u001B[38;5;124mclose\u001B[39m\u001B[38;5;124m'\u001B[39m]\u001B[38;5;241m.\u001B[39mshift(\u001B[38;5;241m-\u001B[39m\u001B[38;5;241m4\u001B[39m)\n\u001B[0;32m 24\u001B[0m df[\u001B[38;5;124m'\u001B[39m\u001B[38;5;124mfuture_open\u001B[39m\u001B[38;5;124m'\u001B[39m] \u001B[38;5;241m=\u001B[39m df\u001B[38;5;241m.\u001B[39mgroupby(\u001B[38;5;124m'\u001B[39m\u001B[38;5;124mts_code\u001B[39m\u001B[38;5;124m'\u001B[39m)[\u001B[38;5;124m'\u001B[39m\u001B[38;5;124mopen\u001B[39m\u001B[38;5;124m'\u001B[39m]\u001B[38;5;241m.\u001B[39mshift(\u001B[38;5;241m-\u001B[39m\u001B[38;5;241m1\u001B[39m)\n\u001B[0;32m 25\u001B[0m df[\u001B[38;5;124m'\u001B[39m\u001B[38;5;124mfuture_high\u001B[39m\u001B[38;5;124m'\u001B[39m] \u001B[38;5;241m=\u001B[39m df\u001B[38;5;241m.\u001B[39mgroupby(\u001B[38;5;124m'\u001B[39m\u001B[38;5;124mts_code\u001B[39m\u001B[38;5;124m'\u001B[39m)[\u001B[38;5;124m'\u001B[39m\u001B[38;5;124mhigh\u001B[39m\u001B[38;5;124m'\u001B[39m]\u001B[38;5;241m.\u001B[39mshift(\u001B[38;5;241m-\u001B[39m\u001B[38;5;241m1\u001B[39m)\n",
"\u001B[1;31mAttributeError\u001B[0m: 'DataFrame' object has no attribute 'groupby'"
]
}
],
"execution_count": 10
},
{
"cell_type": "code",
"execution_count": 12,
"id": "de8c2f6c770d2439",
"metadata": {
"ExecuteTime": {
"end_time": "2025-02-21T15:06:05.491525100Z",
"start_time": "2025-02-20T16:12:48.951414Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Index(['ts_code', 'trade_date', 'open', 'close', 'high', 'low', 'vol',\n",
" 'turnover_rate', 'pe_ttm', 'circ_mv', 'volume_ratio', 'is_st',\n",
" 'up_limit', 'down_limit', 'buy_sm_vol', 'sell_sm_vol', 'buy_lg_vol',\n",
" 'sell_lg_vol', 'buy_elg_vol', 'sell_elg_vol', 'net_mf_vol',\n",
" 'cat_l2_code', 'up', 'down', 'atr_14', 'atr_6', 'obv', 'maobv_6',\n",
" 'obv-maobv_6', 'rsi_3', 'rsi_6', 'rsi_9', 'return_5', 'return_10',\n",
" 'return_20', 'avg_close_5', 'std_return_5', 'std_return_15',\n",
" 'std_return_25', 'std_return_90', 'std_return_90_2',\n",
" 'std_return_5 / std_return_90', 'std_return_5 / std_return_25',\n",
" 'std_return_90 - std_return_90_2', 'ema_5', 'ema_13', 'ema_20',\n",
" 'ema_60', 'act_factor1', 'act_factor2', 'act_factor3', 'act_factor4',\n",
" 'cat_af1', 'cat_af2', 'cat_af3', 'cat_af4', 'act_factor5',\n",
" 'act_factor6', 'rank_act_factor1', 'rank_act_factor2',\n",
" 'rank_act_factor3', 'active_buy_volume_large', 'active_buy_volume_big',\n",
" 'active_buy_volume_small', 'buy_lg_vol_minus_sell_lg_vol',\n",
" 'buy_elg_vol_minus_sell_elg_vol', 'log(circ_mv)', 'alpha_022',\n",
" 'alpha_003', 'alpha_007', 'alpha_013', 'future_close', 'future_open',\n",
" 'future_high', 'future_return', 'label', 'industry_obv',\n",
" 'industry_return_5', 'industry_obv_deviation',\n",
" 'industry_return_5_deviation', 'industry_return_5_percentile'],\n",
" dtype='object')\n"
]
}
],
"source": [
"print(train_data.keys())"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "20ffa7229c9d2f86",
"metadata": {
"ExecuteTime": {
"end_time": "2025-02-21T15:06:05.492528800Z",
"start_time": "2025-02-20T16:12:49.145792Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"feature_columns size: 149\n",
"feature_columns size: 149\n",
"1164839\n",
"最小日期: 2017-03-21\n",
"最大日期: 2022-12-30\n",
"397977\n",
"最小日期: 2023-01-03\n",
"最大日期: 2025-02-12\n"
]
}
],
"source": [
"train_data, feature_columns_new = create_deviation_within_dates(train_data, feature_columns)\n",
"print(f'feature_columns size: {len(feature_columns_new)}')\n",
"test_data, feature_columns_new = create_deviation_within_dates(test_data, feature_columns)\n",
"print(f'feature_columns size: {len(feature_columns_new)}')\n",
"\n",
"# Keep only rows with every feature and the label present, then renumber.\n",
"required_columns = feature_columns_new + ['label']\n",
"train_data = train_data.dropna(subset=required_columns).reset_index(drop=True)\n",
"test_data = test_data.dropna(subset=required_columns).reset_index(drop=True)\n",
"\n",
"# Row count and date span of each split.\n",
"for split_frame in (train_data, test_data):\n",
"    print(len(split_frame))\n",
"    print(f\"最小日期: {split_frame['trade_date'].min().strftime('%Y-%m-%d')}\")\n",
"    print(f\"最大日期: {split_frame['trade_date'].max().strftime('%Y-%m-%d')}\")"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "35238cb4f45ce756",
"metadata": {
"ExecuteTime": {
"end_time": "2025-02-21T15:06:05.507525300Z",
"start_time": "2025-02-20T16:13:07.227911Z"
}
},
"outputs": [],
"source": [
"# Cast every 'cat'-prefixed column to pandas' category dtype in both splits.\n",
"cat_columns = [name for name in df.columns if name[:3] == 'cat']\n",
"for col in cat_columns:\n",
"    train_data[col], test_data[col] = (train_data[col].astype('category'),\n",
"                                       test_data[col].astype('category'))"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "8f134d435f71e9e2",
"metadata": {
"ExecuteTime": {
"end_time": "2025-02-21T15:06:05.508526600Z",
"start_time": "2025-02-20T16:13:07.436171Z"
},
"jupyter": {
"source_hidden": true
}
},
"outputs": [],
"source": [
"from catboost import Pool\n",
"import lightgbm as lgb\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import optuna\n",
"from sklearn.model_selection import KFold\n",
"from sklearn.metrics import mean_absolute_error\n",
"import os\n",
"import json\n",
"import pickle\n",
"import hashlib\n",
"\n",
"def train_light_model(train_data_df, test_data_df, params, feature_columns, callbacks, evals,\n",
"                      print_feature_importance=True, num_boost_round=100,\n",
"                      use_optuna=False):\n",
"    \"\"\"Train a LightGBM booster on feature_columns, validating on the test split.\n",
"\n",
"    Rows with a missing label are dropped from both splits. Columns whose name starts with\n",
"    'cat' are declared categorical (by position in feature_columns).\n",
"\n",
"    Args:\n",
"        train_data_df, test_data_df: pandas frames holding the features and a 'label' column.\n",
"        params: LightGBM parameter dict.\n",
"        feature_columns: model inputs, in order.\n",
"        callbacks: passed straight to lgb.train (e.g. early stopping, evaluation recording).\n",
"        evals: evaluation history handed to lgb.plot_metric (presumably filled by a\n",
"            record_evaluation callback).\n",
"        print_feature_importance: plot the metric curve and the top-20 split importances.\n",
"        num_boost_round: maximum number of boosting rounds.\n",
"        use_optuna: accepted for compatibility; unused.\n",
"\n",
"    Returns:\n",
"        The trained lgb.Booster.\n",
"    \"\"\"\n",
"    labelled_train = train_data_df.dropna(subset=['label'])\n",
"    labelled_valid = test_data_df.dropna(subset=['label'])\n",
"    X_train, y_train = labelled_train[feature_columns], labelled_train['label']\n",
"    X_val, y_val = labelled_valid[feature_columns], labelled_valid['label']\n",
"\n",
"    categorical_feature = [pos for pos, name in enumerate(feature_columns) if name.startswith('cat')]\n",
"    print(f'categorical_feature: {categorical_feature}')\n",
"\n",
"    train_set = lgb.Dataset(X_train, label=y_train, categorical_feature=categorical_feature)\n",
"    valid_set = lgb.Dataset(X_val, label=y_val, categorical_feature=categorical_feature)\n",
"    model = lgb.train(\n",
"        params,\n",
"        train_set,\n",
"        num_boost_round=num_boost_round,\n",
"        valid_sets=[train_set, valid_set],\n",
"        valid_names=['train', 'valid'],\n",
"        callbacks=callbacks,\n",
"    )\n",
"\n",
"    if print_feature_importance:\n",
"        lgb.plot_metric(evals)\n",
"        lgb.plot_importance(model, importance_type='split', max_num_features=20)\n",
"        plt.show()\n",
"    return model\n",
"\n",
"\n",
"from catboost import CatBoostClassifier\n",
"import pandas as pd\n",
"\n",
"\n",
"def train_catboost(train_data_df, test_data_df, feature_columns, params=None, plot=False):\n",
"    \"\"\"Train a CatBoostClassifier on feature_columns and evaluate it on the test split.\n",
"\n",
"    Rows with a missing label are dropped from both splits; columns whose name starts with\n",
"    'cat' are passed as categorical features (by position in feature_columns).\n",
"\n",
"    Args:\n",
"        train_data_df, test_data_df: pandas frames holding the features and a 'label' column.\n",
"        feature_columns: model inputs, in order.\n",
"        params: CatBoostClassifier keyword arguments. None (the default) now means CatBoost's\n",
"            defaults; previously it raised TypeError from unpacking **None.\n",
"        plot: forwarded to model.fit for the interactive training plot.\n",
"\n",
"    Returns:\n",
"        The fitted CatBoostClassifier.\n",
"    \"\"\"\n",
"    train_data_df = train_data_df.dropna(subset=['label'])\n",
"    test_data_df = test_data_df.dropna(subset=['label'])\n",
"    X_train, y_train = train_data_df[feature_columns], train_data_df['label']\n",
"    X_val, y_val = test_data_df[feature_columns], test_data_df['label']\n",
"\n",
"    cat_features = [i for i, col in enumerate(feature_columns) if col.startswith('cat')]\n",
"    print(f'cat_features: {cat_features}')\n",
"    train_pool = Pool(data=X_train, label=y_train, cat_features=cat_features)\n",
"    val_pool = Pool(data=X_val, label=y_val, cat_features=cat_features)\n",
"\n",
"    model = CatBoostClassifier(**(params or {}))\n",
"    model.fit(train_pool, eval_set=val_pool, plot=plot)\n",
"    return model"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "4a4542e1ed6afe7d",
"metadata": {
"ExecuteTime": {
"end_time": "2025-02-21T15:06:05.509525700Z",
"start_time": "2025-02-20T16:13:09.199245Z"
}
},
"outputs": [],
"source": [
"# LightGBM parameters for the binary classifier trained by train_light_model.\n",
"light_params = dict(\n",
"    objective='binary',\n",
"    metric='average_precision',\n",
"    learning_rate=0.05,\n",
"    is_unbalance=True,\n",
"    # tree size and histogram resolution\n",
"    num_leaves=2048,\n",
"    min_data_in_leaf=1024,\n",
"    max_depth=32,\n",
"    max_bin=1024,\n",
"    # row / feature subsampling (bagging every 5 iterations)\n",
"    feature_fraction=0.7,\n",
"    bagging_fraction=0.7,\n",
"    bagging_freq=5,\n",
"    # L1/L2 regularisation currently disabled: lambda_l1=80, lambda_l2=65\n",
"    verbosity=-1,\n",
"    num_threads=16,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "beeb098799ecfa6a",
"metadata": {
"ExecuteTime": {
"end_time": "2025-02-21T15:06:05.520525800Z",
"start_time": "2025-02-20T16:13:09.280097Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"train data size: 1164839\n",
"categorical_feature: [3, 34, 35, 36, 37]\n",
"Training until validation scores don't improve for 50 rounds\n",
"Early stopping, best iteration is:\n",
"[82]\ttrain's average_precision: 0.534659\tvalid's average_precision: 0.281957\n",
"Evaluated only: average_precision\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAjcAAAHFCAYAAAAOmtghAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8pXeV/AAAACXBIWXMAAA9hAAAPYQGoP6dpAABigElEQVR4nO3deVzUdf4H8NfMMAz3fd8oKALiAZZH3uJ9V5qk6apbRlZmZba5v9Qszd01t1213K3czMxKLTPSUBHPPEFUvC8Uhlu5YYaZ7++PkYER1HEcGBhez8eDh8zn+53vvOctMi8/30skCIIAIiIiIjMhNnUBRERERMbEcENERERmheGGiIiIzArDDREREZkVhhsiIiIyKww3REREZFYYboiIiMisMNwQERGRWWG4ISIiIrPCcEPUhNatWweRSASRSIS9e/fWWy4IAkJCQiASidCvXz+DXmP16tVYt27dIz1n7969963JWBrrNZqi9vtJT0/HwoULcf369UbZ/sKFCyESiQx6rin7QmRqDDdEJmBvb48vvvii3nhycjKuXLkCe3t7g7dtSLjp2rUrDh8+jK5duxr8uqZiytrT09OxaNGiRgs3M2fOxOHDhw16bkv+OyV6XAw3RCYwceJEbN68GcXFxTrjX3zxBXr06IGAgIAmqUOpVKK6uhoODg7o3r07HBwcmuR1jaEl1l5eXv5I6/v5+aF79+4GvVZL6guRsTHcEJnApEmTAAAbN27UjhUVFWHz5s2YPn16g89RKBRYsmQJwsLCIJPJ4O7ujj/96U/Iy8vTrhMUFISzZ88iOTlZu/srKCgIQO1uivXr1+PNN9+Er68vZDIZLl++fN9dGEeOHMGoUaPg6uoKKysrtG3bFnPmzHno+zt//jyGDh0KGxsbuLm5YdasWSgpKam3XlBQEKZNm1ZvvF+/fjq75R619mnTpsHOzg6XL1/G8OHDYWdnB39/f7z55puoqqrSea1bt27hmWeegb29PZycnPD888/j2LFjEIlED5wBW7duHZ599lkAQP/+/bX9rnlOv379EBkZiX379qFnz56wsbHR/t1u2rQJgwcPhre3N6ytrdGhQwfMnz8fZWVlOq/R0G6poKAgjBw5Ejt27EDXrl1hbW2NsLAwfPnllzrrmaovRM0Bww2RCTg4OOCZZ57R+UDauHEjxGIxJk6cWG99tVqNMWPGYNmyZYiLi8Ovv/6KZcuWITExEf369UNFRQUAYOvWrWjTpg26dOmCw4cP4/Dhw9i6davOtt59911kZGTgs88+wy+//AIPD48Ga9y5cyd69+6NjIwMrFixAr/99hsWLFiAnJycB763nJwc9O3bF2fOnMHq1auxfv16lJaWYvbs2Y/apnr0rR3QzOyMHj0aAwcOxM8//4zp06fjk08+wccff6xdp6ysDP3790dSUhI+/vhjfP/99/D09Gzw7+BeI0aMwEcffQQAWLVqlbbfI0aM0K4jl8sxefJkxMXFISEhAfHx8QCAS5cuYfjw4fjiiy+wY8cOzJkzB99//z1GjRqlVx9OnTqFN998E2+88QZ+/vlnREVFYcaMGdi3b99Dn9vYfSFqFgQiajJfffWVAEA4duyYkJSUJAAQzpw5IwiCIHTr1k2YNm2aIAiCEBERIfTt21f7vI0bNwoAhM2bN+ts79ixYwIAYfXq1dqxe59bo+b1+vTpc99lSUlJ2rG2bdsKbdu2FSoqKh7pPb7zzjuCSCQSUlNTdcZjY2PrvUZgYKAwderUetvo27evznt41NqnTp0qABC+//57nXWHDx8utG/fXvt41apVAgDht99+01nvpZdeEgAIX3311QPf6w8//FDvteu+BwDC7t27H7gNtVotKJVKITk5WQAgnDp1Srvs/fffF+79NR0YGChYWVkJN27c0I5VVFQILi4uwksvvaQdM2VfiEyNMzdEJtK3b1+0bdsWX375JU6fPo1jx47dd5fU9u3b4eTkhFGjRq
G6ulr71blzZ3h5eT3SGTFPP/30Q9e5ePEirly5ghkzZsDKykrvbQNAUlISIiIi0KlTJ53xuLi4R9pOQ/SpvYZIJKo3ExIVFYUbN25oHycnJ8Pe3h5Dhw7VWa9mt+HjcnZ2xoABA+qNX716FXFxcfDy8oJEIoFUKkXfvn0BAOfOnXvodjt37qxzXJaVlRXatWun897upzn0haixWZi6AKLWSiQS4U9/+hM+/fRTVFZWol27dujdu3eD6+bk5ODOnTuwtLRscHl+fr7er+vt7f3QdWqO4/Hz89N7uzUKCgoQHBxcb9zLy+uRt3UvfWqvYWNjUy+YyWQyVFZWah8XFBTA09Oz3nMbGjNEQ/WWlpaid+/esLKywpIlS9CuXTvY2Njg5s2bGD9+vHYX44O4urrWG5PJZHo9tzn0haixMdwQmdC0adPwf//3f/jss8/w4Ycf3nc9Nzc3uLq6YseOHQ0uf5RTx/W5boq7uzsAzUGlj8rV1RXZ2dn1xhsas7KyqncgK6AJa25ubvXGDb3my/24urri6NGj9cYbqtUQDdW7Z88eZGVlYe/evdrZGgC4c+eOUV7TGBq7L0SNjbuliEzI19cXb7/9NkaNGoWpU6fed72RI0eioKAAKpUKMTEx9b7at2+vXVff/8E/SLt27bS7zBoKHw/Sv39/nD17FqdOndIZ//bbb+utGxQUhLS0NJ2xixcv4sKFC49etAH69u2LkpIS/Pbbbzrj3333nV7Pl8lkAPBI/a4JPDXPrfH555/rvY3G9rh9ITI1ztwQmdiyZcseus5zzz2HDRs2YPjw4Xj99dfxxBNPQCqV4tatW0hKSsKYMWMwbtw4AEDHjh3x3XffYdOmTWjTpg2srKzQsWPHR65r1apVGDVqFLp374433ngDAQEByMjIwM6dO7Fhw4b7Pm/OnDn48ssvMWLECCxZsgSenp7YsGEDzp8/X2/dKVOmYPLkyYiPj8fTTz+NGzduYPny5dqZo8Y2depUfPLJJ5g8eTKWLFmCkJAQ/Pbbb9i5cycAQCx+8P//IiMjAQBr166Fvb09rKysEBwc3OBuoxo9e/aEs7MzZs2ahffffx9SqRQbNmyoFwZN6XH7QmRq/AklagEkEgm2bduGv/zlL9iyZQvGjRuHsWPHYtmyZfXCy6JFi9C3b1/8+c9/xhNPPKH36cX3GjJkCPbt2wdvb2+89tprGDp0KBYvXvzQ4y68vLyQnJyM8PBwvPzyy5g8eTKsrKzw73//u966cXFxWL58OXbu3ImRI0dizZo1WLNmDdq1a2dQzY/K1tYWe/bsQb9+/TBv3jw8/fTTyMjIwOrVqwEATk5OD3x+cHAwVq5ciVOnTqFfv37o1q0bfvnllwc+x9XVFb/++itsbGwwefJkTJ8+HXZ2dti0aZOx3tZje9y+EJmaSBAEwdRFEBE1Jx999BEWLFiAjIwMgw6qNlfsC7UU3C1FRK1azYxSWFgYlEol9uzZg08//RSTJ09u1R/g7Au1ZAw3RNSq2djY4JNPPsH169dRVVWFgIAAvPPOO1iwYIGpSzMp9oVaMu6WIiIiIrPCA4qJiIjIrDDcEBERkVlhuCEiIiKz0ioPKFar1cjKyoK9vb3RL+dOREREjUMQBJSUlMDHx+eBF5NsleEmKysL/v7+pi6DiIiIDHDz5s0HXpKgVYabmpsMXrt2DS4uLiaupnlTKpX4/fffMXjwYEilUlOX06yxV/pjr/THXumPvdJfS+1VcXEx/P39H3qz4FYZbmp2Rdnb28PBwcHE1TRvSqUSNjY2cHBwaFH/AEyBvdIfe6U/9kp/7JX+WnqvHnZICQ8oJiIiIrPCcENERERmheGGiIiIzEqrPOaGiIioMahUKiiVSlOX8VBKpRIWFhaorKyESqUydTlaUqkUEonksbfDcENERPSYBEFAdnY27ty5Y+pS9CIIAry8vHDz5s1md703JycneHl5PVZdDDdERESPqSbYeHh4wMbGptkFhnup1W
qUlpbCzs7ugRfDa0qCIKC8vBy5ubkAAG9vb4O3xXBDRET0GFQqlTbYuLq6mrocvajVaigUClhZWTWbcAMA1tbWAID
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA48AAAHFCAYAAABSNcwsAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8pXeV/AAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdeVyN6f/48ddpXy2lVES2stNYs2bsWceQyBAmg7EmkjX7NrIOxpYtjJnwsZsQY9cwaCwhQvZtNIrW8/vDr/vbcU7ERJj38/E4D+7rvu7rvu73ic77XNd93Sq1Wq1GCCGEEEIIIYR4Db3c7oAQQgghhBBCiI+fJI9CCCGEEEIIId5IkkchhBBCCCGEEG8kyaMQQgghhBBCiDeS5FEIIYQQQgghxBtJ8iiEEEIIIYQQ4o0keRRCCCGEEEII8UaSPAohhBBCCCGEeCNJHoUQQgghhBBCvJEkj0IIIYT45K1YsQKVSqXz5e/v/17Oef78eYKCgoiNjX0v7f8bsbGxqFQqVqxYkdtdeWc7duwgKCgot7shhMjEILc7IIQQQgiRU0JCQihdurRGmYODw3s51/nz5xk3bhzu7u44OTm9l3O8K3t7e44ePUqJEiVyuyvvbMeOHfz444+SQArxEZHkUQghhBCfjfLly1O1atXc7sa/kpKSgkqlwsDg3T+mGRsbU7NmzRzs1YeTmJiImZlZbndDCKGDTFsVQgghxH/Gzz//jJubG+bm5lhYWNC0aVP+/PNPjTp//PEHXl5eODk5YWpqipOTE506deL69etKnRUrVtChQwcAGjRooEyRzZgm6uTkhI+Pj9b53d3dcXd3V7b379+PSqVi9erVDBkyhEKFCmFsbMyVK1cA2LNnDw0bNiRPnjyYmZlRu3Zt9u7d+8br1DVtNSgoCJVKxdmzZ+nQoQN58+bFysoKPz8/UlNTiY6OplmzZlhaWuLk5MT06dM12szo65o1a/Dz88POzg5TU1Pq16+vFUOALVu24ObmhpmZGZaWljRu3JijR49q1Mno06lTp2jfvj358+enRIkS+Pj48OOPPwJoTEHOmCL8448/Uq9ePWxtbTE3N6dChQpMnz6dlJQUrXiXL1+eyMhI6tati5mZGcWLF2fq1Kmkp6dr1P37778ZMmQIxYsXx9jYGFtbWzw8PLh48aJSJzk5mYkTJ1K6dGmMjY2xsbGhe/fuPHjw4I3viRCfA0kehRBCCPHZSEtLIzU1VeOVYfLkyXTq1ImyZcuyYcMGVq9ezT///EPdunU5f/68Ui82NhYXFxdmz57N7t27mTZtGnfu3KFatWo8fPgQgBYtWjB58mTgZSJz9OhRjh49SosWLd6p34GBgdy4cYNFixaxdetWbG1tWbNmDU2aNCFPnjysXLmSDRs2YGVlRdOmTbOVQGbF09OTSpUqERYWhq+vL7NmzWLw4MG0bduWFi1asGnTJr788ksCAgLYuHGj1vEjRozg6tWrLF26lKVLl3L79m3c3d25evWqUmft2rW0adOGPHnysG7dOpYtW8aTJ09wd3fn0KFDWm22a9eOkiVL8ssvv7Bo0SJGjx5N+/btAZTYHj16FHt7ewBiYmLo3Lkzq1evZtu2bfTs2ZMZM2bw3XffabV99+5dvL296dKlC1u2bKF58+YEBgayZs0apc4///xDnTp1+Omnn+jevTtbt25l0aJFODs7c+fOHQDS09Np06YNU6dOpXPnzmzfvp2pU6cSHh6Ou7s7z58/f+f3RIhPhloIIYQQ4hMXEhKiBnS+UlJS1Ddu3FAbGBio+/fvr3HcP//8o7azs1N7enpm2XZqaqr62bNnanNzc/WcOXOU8l9++UUNqCMiIrSOKVq0qLpbt25a5fXr11fXr19f2Y6IiFAD6nr16mnUS0hIUFtZWalbtWqlUZ6WlqauVKmSunr16q+Jhlp97do1NaAOCQlRysaOHasG1DNnztSoW7lyZTWg3rhxo1KWkpKitr
GxUbdr106rr1988YU6PT1dKY+NjVUbGhqqv/32W6WPDg4O6goVKqjT0tKUev/884/a1tZWXatWLa0+jRkzRusavv/+e3V2PqqmpaWpU1JS1KtWrVLr6+urHz9+rOyrX7++GlAfP35c45iyZcuqmzZtqmyPHz9eDajDw8OzPM+6devUgDosLEyjPDIyUg2oFyxY8Ma+CvGpk5FHIYQQQnw2Vq1aRWRkpMbLwMCA3bt3k5qaSteuXTVGJU1MTKhfvz779+9X2nj27BkBAQGULFkSAwMDDAwMsLCwICEhgQsXLryXfn/99dca20eOHOHx48d069ZNo7/p6ek0a9aMyMhIEhIS3ulcLVu21NguU6YMKpWK5s2bK2UGBgaULFlSY6puhs6dO6NSqZTtokWLUqtWLSIiIgCIjo7m9u3bfPPNN+jp/d9HTQsLC77++muOHTtGYmLia6//Tf78809at26NtbU1+vr6GBoa0rVrV9LS0rh06ZJGXTs7O6pXr65RVrFiRY1r27lzJ87OzjRq1CjLc27bto18+fLRqlUrjfekcuXK2NnZafwMCfG5kgVzhBBCCPHZKFOmjM4Fc+7duwdAtWrVdB6XOcnp3Lkze/fuZfTo0VSrVo08efKgUqnw8PB4b1MTM6ZjvtrfjKmbujx+/Bhzc/O3PpeVlZXGtpGREWZmZpiYmGiVx8fHax1vZ2ens+zMmTMAPHr0CNC+Jni58m16ejpPnjzRWBRHV92s3Lhxg7p16+Li4sKcOXNwcnLCxMSEEydO8P3332u9R9bW1lptGBsba9R78OABRYoUee157927x99//42RkZHO/RlTmoX4nEnyKIQQQojPXoECBQD49ddfKVq0aJb1nj59yrZt2xg7dizDhw9XypOSknj8+HG2z2diYkJSUpJW+cOHD5W+ZJZ5JC9zf+fNm5flqqkFCxbMdn9y0t27d3WWZSRpGX9m3CuY2e3bt9HT0yN//vwa5a9e/+ts3ryZhIQENm7cqPFenj59OtttvMrGxoa4uLjX1ilQoADW1tbs2rVL535LS8t3Pr8QnwpJHoUQQgjx2WvatCkGBgbExMS8doqkSqVCrVZjbGysUb506VLS0tI0yjLq6BqNdHJy4uzZsxplly5dIjo6Wmfy+KratWuTL18+zp8/T79+/d5Y/0Nat24dfn5+SsJ3/fp1jhw5QteuXQFwcXGhUKFCrF27Fn9/f6VeQkICYWFhygqsb5I5vqampkp5RnuZ3yO1Ws2SJUve+ZqaN2/OmDFj2LdvH19++aXOOi1btmT9+vWkpaVRo0aNdz6XEJ8ySR6FEEII8dlzcnJi/PjxjBw5kqtXr9KsWTPy58/PvXv3OHHiBObm5owbN448efJQr149ZsyYQYECBXBycuLAgQMsW7aMfPnyabRZvnx5ABYvXoylpSUmJiYUK1YMa2trvvnmG7p06ULfvn35+uuvuX79OtOnT8fGxiZb/bWwsGDevHl069aNx48f0759e2xtbXnw4AFnzpzhwYMHLFy4MKfDlC3379/nq6++wtfXl6dPnzJ27FhMTEwIDAwEXk4Bnj59Ot7e3rRs2ZLvvvuOpKQkZsyYwd9//83UqVOzdZ4KFSoAMG3aNJo3b46+vj4VK1akcePGGBkZ0alTJ4YNG8aLFy9YuHAhT548eedrGjRoED///DNt2rRh+PDhVK9enefPn3PgwAFatmxJgwYN8PLyIjQ0FA8PDwYOHEj16tUxNDQkLi6OiIgI2rRpw1dfffXOfRDiUyAL5gghhBDiPyEwMJBff/2VS5cu0a1bN5o2bcqwYcO4fv069erVU+qtXbuWBg0aMGzYMNq1a8cff/xBeHg4efPm1WivWLFizJ49mzNnzuDu7k61atXYunUr8PK+yenTp7N7925atmzJwoULWbhwIc7Oztnub5cuXYiIiODZs2d89913NGrUiIEDB3Lq1CkaNmyYM0F5B5MnT6Zo0aJ0796dHj16YG9vT0REBC
VKlFDqdO7cmc2bN/Po0SM6duxI9+7dyZMnDxEREdSpUydb5+ncuTPffvstCxYswM3NjWrVqnH79m1Kly5NWFgYT54
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"print('train data size: ', len(train_data))\n",
" \n",
"evals = {}\n",
"model = train_light_model(train_data, test_data, light_params, feature_columns_new,\n",
" [lgb.log_evaluation(period=500),\n",
" lgb.callback.record_evaluation(evals),\n",
" lgb.early_stopping(50, first_metric_only=True)\n",
" ], evals,\n",
" num_boost_round=1000, use_optuna=False,\n",
" print_feature_importance=True)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "445dff84-70b2-4fc9-a9b6-1251993324d6",
"metadata": {
"ExecuteTime": {
"end_time": "2025-02-21T15:06:05.522549500Z",
"start_time": "2025-02-20T16:14:20.677936Z"
}
},
"outputs": [],
"source": [
"# catboost_params = {\n",
"# 'loss_function': 'CrossEntropy', # 适用于二分类\n",
"# 'eval_metric': 'AUC', # 评估指标\n",
"# 'iterations': 1000,\n",
"# 'learning_rate': 0.01,\n",
"# 'depth': , # 控制模型复杂度\n",
"# # 'l2_leaf_reg': 3, # L2 正则化\n",
"# 'verbose': 500,\n",
"# 'early_stopping_rounds': 100,\n",
"# # 'one_hot_max_size': 50,\n",
"# # 'class_weights': [0.6, 1.2]\n",
"# # 'task_type': 'GPU'\n",
"# }\n",
"\n",
"# model = train_catboost(train_data, test_data, feature_columns_new, catboost_params, plot=True)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "751a6df9-d90b-4053-8769-c6c3b6654406",
"metadata": {
"ExecuteTime": {
"end_time": "2025-02-21T15:06:05.523526100Z",
"start_time": "2025-02-20T16:14:20.855152Z"
}
},
"outputs": [],
"source": [
"from tqdm import tqdm\n",
"\n",
"\n",
"def incremental_training(test_data: pd.DataFrame, model: lgb.Booster, days: int, back_days: int, feature_columns: list, params: dict):\n",
" test_data = test_data.sort_values(by='trade_date')\n",
"\n",
" scores = []\n",
"\n",
" unique_trade_dates = sorted(test_data['trade_date'].unique())\n",
" for i in tqdm(range(0, len(unique_trade_dates), days)):\n",
" # Get the current window of trade dates\n",
" current_dates = unique_trade_dates[i:i + days]\n",
" window_data = test_data[test_data['trade_date'].isin(current_dates)]\n",
" X = window_data[feature_columns]\n",
"\n",
" window_scores = model.predict(X)\n",
" scores.extend(window_scores)\n",
"\n",
" current_dates = unique_trade_dates[max(0, i - back_days):i + days]\n",
" window_data = test_data[test_data['trade_date'].isin(current_dates)]\n",
"\n",
" # Incrementally train the model with the current window data\n",
" X_train = window_data[feature_columns]\n",
" y_train = window_data['label'] # Assuming 'score' is what you're predicting\n",
" categorical_feature = [i for i, col in enumerate(feature_columns) if col.startswith('cat')]\n",
" # print(f'categorical_feature: {categorical_feature}')\n",
" train_data = lgb.Dataset(X_train, label=y_train, categorical_feature=categorical_feature)\n",
"\n",
" model = lgb.train(params,\n",
" train_set=train_data,\n",
" num_boost_round=100,\n",
" init_model=model,\n",
" keep_training_booster=True)\n",
"\n",
" # Add the scores as a new 'score' column to the test_data\n",
" test_data['score'] = scores\n",
"\n",
" return test_data\n"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "660a24f74501f98f",
"metadata": {
"ExecuteTime": {
"end_time": "2025-02-21T15:06:05.533542700Z",
"start_time": "2025-02-20T16:14:20.978664Z"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|████████████████████████████████████████████████████████████████████████████████| 102/102 [02:16<00:00, 1.34s/it]\n"
]
}
],
"source": [
"predictions_test = incremental_training(test_data, model, 5, 0, feature_columns_new, light_params)\n"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "36ccaa730ab46718",
"metadata": {
"ExecuteTime": {
"end_time": "2025-02-21T15:06:05.535526300Z",
"start_time": "2025-02-20T16:17:08.116647Z"
}
},
"outputs": [],
"source": [
"predictions_test = predictions_test.loc[predictions_test.groupby('trade_date')['score'].idxmax()]\n",
"predictions_test[['trade_date', 'score', 'ts_code']].to_csv('predictions_test.tsv', index=False)\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.19"
}
},
"nbformat": 4,
"nbformat_minor": 5
}