Files
NewStock/main/train/Rank2.ipynb

2245 lines
247 KiB
Plaintext
Raw Normal View History

2025-05-26 21:34:36 +08:00
{
"cells": [
{
"cell_type": "code",
2025-06-04 13:50:02 +08:00
"execution_count": 9,
2025-05-26 21:34:36 +08:00
"id": "79a7758178bafdd3",
"metadata": {
"ExecuteTime": {
"end_time": "2025-04-03T12:46:06.987506Z",
"start_time": "2025-04-03T12:46:06.259551Z"
},
"jupyter": {
"source_hidden": true
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2025-06-04 13:50:02 +08:00
"The autoreload extension is already loaded. To reload it, use:\n",
" %reload_ext autoreload\n",
"/mnt/d/PyProject/NewStock\n"
2025-05-26 21:34:36 +08:00
]
}
],
"source": [
"%load_ext autoreload\n",
"%autoreload 2\n",
"\n",
"import gc\n",
"import os\n",
"import sys\n",
2025-06-04 13:50:02 +08:00
"sys.path.append('/mnt/d/PyProject/NewStock/')\n",
2025-05-26 21:34:36 +08:00
"print(os.getcwd())\n",
"import pandas as pd\n",
"from main.factor.factor import get_rolling_factor, get_simple_factor\n",
"from main.utils.factor import read_industry_data\n",
"from main.utils.factor_processor import calculate_score\n",
"from main.utils.utils import read_and_merge_h5_data, merge_with_industry_data\n",
"\n",
"import warnings\n",
"\n",
"warnings.filterwarnings(\"ignore\")"
]
},
{
"cell_type": "code",
2025-06-04 13:50:02 +08:00
"execution_count": 10,
2025-05-26 21:34:36 +08:00
"id": "a79cafb06a7e0e43",
"metadata": {
"ExecuteTime": {
"end_time": "2025-04-03T12:47:00.212859Z",
"start_time": "2025-04-03T12:46:06.998047Z"
},
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"daily data\n",
"daily basic\n",
"inner merge on ['ts_code', 'trade_date']\n",
"stk limit\n",
"left merge on ['ts_code', 'trade_date']\n",
"money flow\n",
"left merge on ['ts_code', 'trade_date']\n",
"cyq perf\n",
"left merge on ['ts_code', 'trade_date']\n",
"<class 'pandas.core.frame.DataFrame'>\n",
2025-06-04 13:50:02 +08:00
"RangeIndex: 8692146 entries, 0 to 8692145\n",
2025-05-29 20:41:18 +08:00
"Data columns (total 33 columns):\n",
2025-05-26 21:34:36 +08:00
" # Column Dtype \n",
"--- ------ ----- \n",
" 0 ts_code object \n",
" 1 trade_date datetime64[ns]\n",
" 2 open float64 \n",
" 3 close float64 \n",
" 4 high float64 \n",
" 5 low float64 \n",
" 6 vol float64 \n",
" 7 pct_chg float64 \n",
2025-05-29 20:41:18 +08:00
" 8 amount float64 \n",
" 9 turnover_rate float64 \n",
" 10 pe_ttm float64 \n",
" 11 circ_mv float64 \n",
" 12 total_mv float64 \n",
" 13 volume_ratio float64 \n",
" 14 is_st bool \n",
" 15 up_limit float64 \n",
" 16 down_limit float64 \n",
" 17 buy_sm_vol float64 \n",
" 18 sell_sm_vol float64 \n",
" 19 buy_lg_vol float64 \n",
" 20 sell_lg_vol float64 \n",
" 21 buy_elg_vol float64 \n",
" 22 sell_elg_vol float64 \n",
" 23 net_mf_vol float64 \n",
" 24 his_low float64 \n",
" 25 his_high float64 \n",
" 26 cost_5pct float64 \n",
" 27 cost_15pct float64 \n",
" 28 cost_50pct float64 \n",
" 29 cost_85pct float64 \n",
" 30 cost_95pct float64 \n",
" 31 weight_avg float64 \n",
" 32 winner_rate float64 \n",
"dtypes: bool(1), datetime64[ns](1), float64(30), object(1)\n",
"memory usage: 2.1+ GB\n",
2025-05-26 21:34:36 +08:00
"None\n"
]
}
],
"source": [
"from main.utils.utils import read_and_merge_h5_data\n",
"\n",
"print('daily data')\n",
2025-06-04 13:50:02 +08:00
"df = read_and_merge_h5_data('/mnt/d/PyProject/NewStock/data/daily_data.h5', key='daily_data',\n",
2025-05-29 20:41:18 +08:00
" columns=['ts_code', 'trade_date', 'open', 'close', 'high', 'low', 'vol', 'pct_chg', 'amount'],\n",
2025-05-26 21:34:36 +08:00
" df=None)\n",
"\n",
"print('daily basic')\n",
2025-06-04 13:50:02 +08:00
"df = read_and_merge_h5_data('/mnt/d/PyProject/NewStock/data/daily_basic.h5', key='daily_basic',\n",
2025-05-26 21:34:36 +08:00
" columns=['ts_code', 'trade_date', 'turnover_rate', 'pe_ttm', 'circ_mv', 'total_mv', 'volume_ratio',\n",
" 'is_st'], df=df, join='inner')\n",
"\n",
"print('stk limit')\n",
2025-06-04 13:50:02 +08:00
"df = read_and_merge_h5_data('/mnt/d/PyProject/NewStock/data/stk_limit.h5', key='stk_limit',\n",
2025-05-26 21:34:36 +08:00
" columns=['ts_code', 'trade_date', 'pre_close', 'up_limit', 'down_limit'],\n",
" df=df)\n",
"print('money flow')\n",
2025-06-04 13:50:02 +08:00
"df = read_and_merge_h5_data('/mnt/d/PyProject/NewStock/data/money_flow.h5', key='money_flow',\n",
2025-05-26 21:34:36 +08:00
" columns=['ts_code', 'trade_date', 'buy_sm_vol', 'sell_sm_vol', 'buy_lg_vol', 'sell_lg_vol',\n",
" 'buy_elg_vol', 'sell_elg_vol', 'net_mf_vol'],\n",
" df=df)\n",
"print('cyq perf')\n",
2025-06-04 13:50:02 +08:00
"df = read_and_merge_h5_data('/mnt/d/PyProject/NewStock/data/cyq_perf.h5', key='cyq_perf',\n",
2025-05-26 21:34:36 +08:00
" columns=['ts_code', 'trade_date', 'his_low', 'his_high', 'cost_5pct', 'cost_15pct',\n",
" 'cost_50pct',\n",
" 'cost_85pct', 'cost_95pct', 'weight_avg', 'winner_rate'],\n",
" df=df)\n",
"print(df.info())"
]
},
{
"cell_type": "code",
2025-06-04 13:50:02 +08:00
"execution_count": 11,
2025-05-26 21:34:36 +08:00
"id": "cac01788dac10678",
"metadata": {
"ExecuteTime": {
"end_time": "2025-04-03T12:47:10.527104Z",
"start_time": "2025-04-03T12:47:00.488715Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"industry\n"
]
}
],
"source": [
"print('industry')\n",
2025-06-04 13:50:02 +08:00
"industry_df = read_and_merge_h5_data('/mnt/d/PyProject/NewStock/data/industry_data.h5', key='industry_data',\n",
2025-05-26 21:34:36 +08:00
" columns=['ts_code', 'l2_code', 'in_date'],\n",
" df=None, on=['ts_code'], join='left')\n",
"\n",
"\n",
"def merge_with_industry_data(df, industry_df):\n",
" # 确保日期字段是 datetime 类型\n",
" df['trade_date'] = pd.to_datetime(df['trade_date'])\n",
" industry_df['in_date'] = pd.to_datetime(industry_df['in_date'])\n",
"\n",
" # 对 industry_df 按 ts_code 和 in_date 排序\n",
" industry_df_sorted = industry_df.sort_values(['in_date', 'ts_code'])\n",
"\n",
" # 对原始 df 按 ts_code 和 trade_date 排序\n",
" df_sorted = df.sort_values(['trade_date', 'ts_code'])\n",
"\n",
" # 使用 merge_asof 进行向后合并\n",
" merged = pd.merge_asof(\n",
" df_sorted,\n",
" industry_df_sorted,\n",
" by='ts_code', # 按 ts_code 分组\n",
" left_on='trade_date',\n",
" right_on='in_date',\n",
" direction='backward'\n",
" )\n",
"\n",
" # 获取每个 ts_code 的最早 in_date 记录\n",
" min_in_date_per_ts = (industry_df_sorted\n",
" .groupby('ts_code')\n",
" .first()\n",
" .reset_index()[['ts_code', 'l2_code']])\n",
"\n",
" # 填充未匹配到的记录trade_date 早于所有 in_date 的情况)\n",
" merged['l2_code'] = merged['l2_code'].fillna(\n",
" merged['ts_code'].map(min_in_date_per_ts.set_index('ts_code')['l2_code'])\n",
" )\n",
"\n",
" # 保留需要的列并重置索引\n",
" result = merged.reset_index(drop=True)\n",
" return result\n",
"\n",
"\n",
"# 使用示例\n",
"df = merge_with_industry_data(df, industry_df)\n",
"# print(mdf[mdf['ts_code'] == '600751.SH'][['ts_code', 'trade_date', 'l2_code']])"
]
},
{
"cell_type": "code",
2025-06-04 13:50:02 +08:00
"execution_count": 12,
2025-05-26 21:34:36 +08:00
"id": "c4e9e1d31da6dba6",
"metadata": {
"ExecuteTime": {
"end_time": "2025-04-03T12:47:10.719252Z",
"start_time": "2025-04-03T12:47:10.541247Z"
},
"jupyter": {
"source_hidden": true
}
},
"outputs": [],
"source": [
"from main.factor.factor import *\n",
"\n",
"def calculate_indicators(df):\n",
" \"\"\"\n",
" 计算四个指标当日涨跌幅、5日移动平均、RSI、MACD。\n",
" \"\"\"\n",
" df = df.sort_values('trade_date')\n",
" df['daily_return'] = (df['close'] - df['pre_close']) / df['pre_close'] * 100\n",
" # df['5_day_ma'] = df['close'].rolling(window=5).mean()\n",
" delta = df['close'].diff()\n",
" gain = delta.where(delta > 0, 0)\n",
" loss = -delta.where(delta < 0, 0)\n",
" avg_gain = gain.rolling(window=14).mean()\n",
" avg_loss = loss.rolling(window=14).mean()\n",
" rs = avg_gain / avg_loss\n",
" df['RSI'] = 100 - (100 / (1 + rs))\n",
"\n",
" # 计算MACD\n",
" ema12 = df['close'].ewm(span=12, adjust=False).mean()\n",
" ema26 = df['close'].ewm(span=26, adjust=False).mean()\n",
" df['MACD'] = ema12 - ema26\n",
" df['Signal_line'] = df['MACD'].ewm(span=9, adjust=False).mean()\n",
" df['MACD_hist'] = df['MACD'] - df['Signal_line']\n",
"\n",
" # 4. 情绪因子1市场上涨比例Up Ratio\n",
" df['up_ratio'] = df['daily_return'].apply(lambda x: 1 if x > 0 else 0)\n",
" df['up_ratio_20d'] = df['up_ratio'].rolling(window=20).mean() # 过去20天上涨比例\n",
"\n",
" # 5. 情绪因子2成交量变化率Volume Change Rate\n",
" df['volume_mean'] = df['vol'].rolling(window=20).mean() # 过去20天的平均成交量\n",
" df['volume_change_rate'] = (df['vol'] - df['volume_mean']) / df['volume_mean'] * 100 # 成交量变化率\n",
"\n",
" # 6. 情绪因子3波动率Volatility\n",
" df['volatility'] = df['daily_return'].rolling(window=20).std() # 过去20天的日收益率标准差\n",
"\n",
" # 7. 情绪因子4成交额变化率Amount Change Rate\n",
" df['amount_mean'] = df['amount'].rolling(window=20).mean() # 过去20天的平均成交额\n",
" df['amount_change_rate'] = (df['amount'] - df['amount_mean']) / df['amount_mean'] * 100 # 成交额变化率\n",
"\n",
" # df = sentiment_panic_greed_index(df)\n",
" # df = sentiment_market_breadth_proxy(df)\n",
" # df = sentiment_reversal_indicator(df)\n",
"\n",
" return df\n",
"\n",
"\n",
"def generate_index_indicators(h5_filename):\n",
" df = pd.read_hdf(h5_filename, key='index_data')\n",
" df['trade_date'] = pd.to_datetime(df['trade_date'], format='%Y%m%d')\n",
" df = df.sort_values('trade_date')\n",
"\n",
" # 计算每个ts_code的相关指标\n",
" df_indicators = []\n",
" for ts_code in df['ts_code'].unique():\n",
" df_index = df[df['ts_code'] == ts_code].copy()\n",
" df_index = calculate_indicators(df_index)\n",
" df_indicators.append(df_index)\n",
"\n",
" # 合并所有指数的结果\n",
" df_all_indicators = pd.concat(df_indicators, ignore_index=True)\n",
"\n",
" # 保留trade_date列并将同一天的数据按ts_code合并成一行\n",
" df_final = df_all_indicators.pivot_table(\n",
" index='trade_date',\n",
" columns='ts_code',\n",
" values=['daily_return', \n",
" 'RSI', 'MACD', 'Signal_line', 'MACD_hist', \n",
" # 'sentiment_panic_greed_index',\n",
" 'up_ratio_20d', 'volume_change_rate', 'volatility',\n",
" 'amount_change_rate', 'amount_mean'],\n",
" aggfunc='last'\n",
" )\n",
"\n",
" df_final.columns = [f\"{col[1]}_{col[0]}\" for col in df_final.columns]\n",
" df_final = df_final.reset_index()\n",
"\n",
" return df_final\n",
"\n",
"\n",
"# 使用函数\n",
2025-06-04 13:50:02 +08:00
"h5_filename = '/mnt/d/PyProject/NewStock/data/index_data.h5'\n",
2025-05-26 21:34:36 +08:00
"index_data = generate_index_indicators(h5_filename)\n",
"index_data = index_data.dropna()\n"
]
},
{
"cell_type": "code",
2025-06-04 13:50:02 +08:00
"execution_count": 13,
2025-05-26 21:34:36 +08:00
"id": "a735bc02ceb4d872",
"metadata": {
"ExecuteTime": {
"end_time": "2025-04-03T12:47:10.821169Z",
"start_time": "2025-04-03T12:47:10.751831Z"
}
},
"outputs": [],
"source": [
"import talib\n",
"import numpy as np"
]
},
{
"cell_type": "code",
2025-06-04 13:50:02 +08:00
"execution_count": 14,
2025-05-26 21:34:36 +08:00
"id": "53f86ddc0677a6d7",
"metadata": {
"ExecuteTime": {
"end_time": "2025-04-03T12:47:15.944254Z",
"start_time": "2025-04-03T12:47:10.826179Z"
},
"jupyter": {
"source_hidden": true
},
"scrolled": true
},
"outputs": [],
"source": [
"from main.utils.factor import get_act_factor\n",
"\n",
"\n",
"def read_industry_data(h5_filename):\n",
" # 读取 H5 文件中所有的行业数据\n",
" industry_data = pd.read_hdf(h5_filename, key='sw_daily', columns=[\n",
" 'ts_code', 'trade_date', 'open', 'close', 'high', 'low', 'pe', 'pb', 'vol'\n",
" ]) # 假设 H5 文件的键是 'industry_data'\n",
" industry_data = industry_data.sort_values(by=['ts_code', 'trade_date'])\n",
" industry_data = industry_data.reindex()\n",
" industry_data['trade_date'] = pd.to_datetime(industry_data['trade_date'], format='%Y%m%d')\n",
"\n",
" grouped = industry_data.groupby('ts_code', group_keys=False)\n",
" industry_data['obv'] = grouped.apply(\n",
" lambda x: pd.Series(talib.OBV(x['close'].values, x['vol'].values), index=x.index)\n",
" )\n",
" industry_data['return_5'] = grouped['close'].apply(lambda x: x / x.shift(5) - 1)\n",
" industry_data['return_20'] = grouped['close'].apply(lambda x: x / x.shift(20) - 1)\n",
"\n",
" industry_data = get_act_factor(industry_data, cat=False)\n",
" industry_data = industry_data.sort_values(by=['trade_date', 'ts_code'])\n",
"\n",
" # # 计算每天每个 ts_code 的因子和当天所有 ts_code 的中位数的偏差\n",
" # factor_columns = ['obv', 'return_5', 'return_20', 'act_factor1', 'act_factor2', 'act_factor3', 'act_factor4'] # 因子列\n",
" # \n",
" # for factor in factor_columns:\n",
" # if factor in industry_data.columns:\n",
" # # 计算每天每个 ts_code 的因子值与当天所有 ts_code 的中位数的偏差\n",
" # industry_data[f'{factor}_deviation'] = industry_data.groupby('trade_date')[factor].transform(\n",
" # lambda x: x - x.mean())\n",
"\n",
" industry_data['return_5_percentile'] = industry_data.groupby('trade_date')['return_5'].transform(\n",
" lambda x: x.rank(pct=True))\n",
" industry_data['return_20_percentile'] = industry_data.groupby('trade_date')['return_20'].transform(\n",
" lambda x: x.rank(pct=True))\n",
"\n",
" # cs_rank_intraday_range(industry_data)\n",
" # cs_rank_close_pos_in_range(industry_data)\n",
"\n",
" industry_data = industry_data.drop(columns=['open', 'close', 'high', 'low', 'pe', 'pb', 'vol'])\n",
"\n",
" industry_data = industry_data.rename(\n",
" columns={col: f'industry_{col}' for col in industry_data.columns if col not in ['ts_code', 'trade_date']})\n",
"\n",
" industry_data = industry_data.rename(columns={'ts_code': 'cat_l2_code'})\n",
" return industry_data\n",
"\n",
"\n",
2025-06-04 13:50:02 +08:00
"industry_df = read_industry_data('/mnt/d/PyProject/NewStock/data/sw_daily.h5')\n"
2025-05-26 21:34:36 +08:00
]
},
{
"cell_type": "code",
2025-06-04 13:50:02 +08:00
"execution_count": 15,
2025-05-26 21:34:36 +08:00
"id": "dbe2fd8021b9417f",
"metadata": {
"ExecuteTime": {
"end_time": "2025-04-03T12:47:15.969344Z",
"start_time": "2025-04-03T12:47:15.963327Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2025-05-29 20:41:18 +08:00
"['ts_code', 'open', 'close', 'high', 'low', 'amount', 'circ_mv', 'total_mv', 'is_st', 'up_limit', 'down_limit', 'buy_sm_vol', 'sell_sm_vol', 'buy_lg_vol', 'sell_lg_vol', 'buy_elg_vol', 'sell_elg_vol', 'net_mf_vol', 'his_low', 'his_high', 'cost_5pct', 'cost_15pct', 'cost_50pct', 'cost_85pct', 'cost_95pct', 'weight_avg', 'in_date']\n"
2025-05-26 21:34:36 +08:00
]
}
],
"source": [
"origin_columns = df.columns.tolist()\n",
"origin_columns = [col for col in origin_columns if\n",
" col not in ['turnover_rate', 'pe_ttm', 'volume_ratio', 'vol', 'pct_chg', 'l2_code', 'winner_rate']]\n",
"origin_columns = [col for col in origin_columns if col not in index_data.columns]\n",
"origin_columns = [col for col in origin_columns if 'cyq' not in col]\n",
"print(origin_columns)"
]
},
{
"cell_type": "code",
2025-06-04 13:50:02 +08:00
"execution_count": 16,
2025-05-26 21:34:36 +08:00
"id": "85c3e3d0235ffffa",
"metadata": {
"ExecuteTime": {
"end_time": "2025-04-03T12:47:16.089879Z",
"start_time": "2025-04-03T12:47:15.990101Z"
}
},
"outputs": [],
"source": [
2025-06-04 13:50:02 +08:00
"fina_indicator_df = read_and_merge_h5_data('/mnt/d/PyProject/NewStock/data/fina_indicator.h5', key='fina_indicator',\n",
2025-05-26 21:34:36 +08:00
" columns=['ts_code', 'ann_date', 'undist_profit_ps', 'ocfps', 'bps'],\n",
" df=None)\n",
2025-06-04 13:50:02 +08:00
"cashflow_df = read_and_merge_h5_data('/mnt/d/PyProject/NewStock/data/cashflow.h5', key='cashflow',\n",
2025-05-26 21:34:36 +08:00
" columns=['ts_code', 'ann_date', 'n_cashflow_act'],\n",
" df=None)\n",
2025-06-04 13:50:02 +08:00
"balancesheet_df = read_and_merge_h5_data('/mnt/d/PyProject/NewStock/data/balancesheet.h5', key='balancesheet',\n",
2025-05-26 21:34:36 +08:00
" columns=['ts_code', 'ann_date', 'money_cap', 'total_liab'],\n",
" df=None)\n",
2025-06-04 13:50:02 +08:00
"top_list_df = read_and_merge_h5_data('/mnt/d/PyProject/NewStock/data/top_list.h5', key='top_list',\n",
2025-05-26 21:34:36 +08:00
" columns=['ts_code', 'trade_date', 'reason'],\n",
" df=None)\n",
"\n",
2025-06-04 13:50:02 +08:00
"top_list_df = top_list_df.sort_values(by='trade_date', ascending=False).drop_duplicates(subset=['ts_code', 'trade_date'], keep='first').sort_values(by='trade_date')\n",
"\n",
"stk_holdertrade_df = read_and_merge_h5_data('/mnt/d/PyProject/NewStock/data/stk_holdertrade.h5', key='stk_holdertrade',\n",
" columns=['ts_code', 'ann_date', 'in_de', 'change_ratio'],\n",
" df=None)"
2025-05-26 21:34:36 +08:00
]
},
{
"cell_type": "code",
2025-06-04 13:50:02 +08:00
"execution_count": 18,
2025-05-26 21:34:36 +08:00
"id": "92d84ce15a562ec6",
"metadata": {
"ExecuteTime": {
"end_time": "2025-04-03T13:08:01.612695Z",
"start_time": "2025-04-03T12:47:16.121802Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2025-06-04 13:50:02 +08:00
"正在计算股东增减持因子(优化版)...\n"
]
},
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[31m---------------------------------------------------------------------------\u001b[39m",
"\u001b[31mKeyboardInterrupt\u001b[39m Traceback (most recent call last)",
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[18]\u001b[39m\u001b[32m, line 30\u001b[39m\n\u001b[32m 23\u001b[39m df = df.sort_values(by=[\u001b[33m\"\u001b[39m\u001b[33mts_code\u001b[39m\u001b[33m\"\u001b[39m, \u001b[33m\"\u001b[39m\u001b[33mtrade_date\u001b[39m\u001b[33m\"\u001b[39m])\n\u001b[32m 25\u001b[39m \u001b[38;5;66;03m# df = price_minus_deduction_price(df, n=120)\u001b[39;00m\n\u001b[32m 26\u001b[39m \u001b[38;5;66;03m# df = price_deduction_price_diff_ratio_to_sma(df, n=120)\u001b[39;00m\n\u001b[32m 27\u001b[39m \u001b[38;5;66;03m# df = cat_price_vs_sma_vs_deduction_price(df, n=120)\u001b[39;00m\n\u001b[32m 28\u001b[39m \u001b[38;5;66;03m# df = cat_reason(df, top_list_df)\u001b[39;00m\n\u001b[32m 29\u001b[39m \u001b[38;5;66;03m# df = cat_is_on_top_list(df, top_list_df)\u001b[39;00m\n\u001b[32m---> \u001b[39m\u001b[32m30\u001b[39m df = \u001b[43mholder_trade_factors\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdf\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mstk_holdertrade_df\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 32\u001b[39m df = cat_senti_mom_vol_spike(\n\u001b[32m 33\u001b[39m df,\n\u001b[32m 34\u001b[39m return_period=\u001b[32m3\u001b[39m,\n\u001b[32m (...)\u001b[39m\u001b[32m 38\u001b[39m current_pct_chg_max=\u001b[32m0.05\u001b[39m,\n\u001b[32m 39\u001b[39m ) \u001b[38;5;66;03m# 当日涨幅不宜过大\u001b[39;00m\n\u001b[32m 41\u001b[39m df = cat_senti_pre_breakout(\n\u001b[32m 42\u001b[39m df,\n\u001b[32m 43\u001b[39m atr_short_N=\u001b[32m10\u001b[39m,\n\u001b[32m (...)\u001b[39m\u001b[32m 51\u001b[39m volume_ratio_signal_threshold=\u001b[32m1.1\u001b[39m,\n\u001b[32m 52\u001b[39m )\n",
"\u001b[36mFile \u001b[39m\u001b[32m/mnt/d/PyProject/NewStock/main/factor/money_factor.py:50\u001b[39m, in \u001b[36mholder_trade_factors\u001b[39m\u001b[34m(all_data_df, stk_holdertrade_df)\u001b[39m\n\u001b[32m 43\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m date_in_window \u001b[38;5;129;01min\u001b[39;00m future_dates:\n\u001b[32m 44\u001b[39m \u001b[38;5;66;03m# 只有当日期是实际交易日时才添加\u001b[39;00m\n\u001b[32m 45\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m date_in_window \u001b[38;5;129;01min\u001b[39;00m all_trade_dates_set:\n\u001b[32m 46\u001b[39m expanded_holder_events.append({\n\u001b[32m 47\u001b[39m \u001b[33m'\u001b[39m\u001b[33mts_code\u001b[39m\u001b[33m'\u001b[39m: ts_code,\n\u001b[32m 48\u001b[39m \u001b[33m'\u001b[39m\u001b[33mtrade_date\u001b[39m\u001b[33m'\u001b[39m: date_in_window,\n\u001b[32m 49\u001b[39m \u001b[33m'\u001b[39m\u001b[33min_de_numeric\u001b[39m\u001b[33m'\u001b[39m: row[\u001b[33m'\u001b[39m\u001b[33min_de_numeric\u001b[39m\u001b[33m'\u001b[39m],\n\u001b[32m---> \u001b[39m\u001b[32m50\u001b[39m \u001b[33m'\u001b[39m\u001b[33mchange_ratio_total_agg\u001b[39m\u001b[33m'\u001b[39m: \u001b[43mrow\u001b[49m\u001b[43m[\u001b[49m\u001b[33;43m'\u001b[39;49m\u001b[33;43mchange_ratio_total_agg\u001b[39;49m\u001b[33;43m'\u001b[39;49m\u001b[43m]\u001b[49m,\n\u001b[32m 51\u001b[39m \u001b[33m'\u001b[39m\u001b[33mchange_ratio_in_agg\u001b[39m\u001b[33m'\u001b[39m: row[\u001b[33m'\u001b[39m\u001b[33mchange_ratio_in_agg\u001b[39m\u001b[33m'\u001b[39m],\n\u001b[32m 52\u001b[39m \u001b[33m'\u001b[39m\u001b[33mchange_ratio_de_agg\u001b[39m\u001b[33m'\u001b[39m: row[\u001b[33m'\u001b[39m\u001b[33mchange_ratio_de_agg\u001b[39m\u001b[33m'\u001b[39m]\n\u001b[32m 53\u001b[39m })\n\u001b[32m 55\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m expanded_holder_events: \u001b[38;5;66;03m# 如果没有事件,直接返回原始 df\u001b[39;00m\n\u001b[32m 56\u001b[39m \u001b[38;5;66;03m# 确保返回的DataFrame与原始df具有相同的列和顺序\u001b[39;00m\n\u001b[32m 57\u001b[39m 
\u001b[38;5;66;03m# 并填充为默认值\u001b[39;00m\n\u001b[32m 58\u001b[39m default_factors = pd.DataFrame({\n\u001b[32m 59\u001b[39m \u001b[33m'\u001b[39m\u001b[33mholder_trade_type_10d\u001b[39m\u001b[33m'\u001b[39m: \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[32m 60\u001b[39m \u001b[33m'\u001b[39m\u001b[33mholder_change_ratio_sum_10d\u001b[39m\u001b[33m'\u001b[39m: \u001b[32m0.0\u001b[39m,\n\u001b[32m 61\u001b[39m \u001b[33m'\u001b[39m\u001b[33mholder_in_change_ratio_sum_10d\u001b[39m\u001b[33m'\u001b[39m: \u001b[32m0.0\u001b[39m,\n\u001b[32m 62\u001b[39m \u001b[33m'\u001b[39m\u001b[33mholder_de_change_ratio_sum_10d\u001b[39m\u001b[33m'\u001b[39m: \u001b[32m0.0\u001b[39m\n\u001b[32m 63\u001b[39m }, index=all_data_df.index)\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/miniconda3/envs/stock/lib/python3.13/site-packages/pandas/core/series.py:1121\u001b[39m, in \u001b[36mSeries.__getitem__\u001b[39m\u001b[34m(self, key)\u001b[39m\n\u001b[32m 1118\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m._values[key]\n\u001b[32m 1120\u001b[39m \u001b[38;5;28;01melif\u001b[39;00m key_is_scalar:\n\u001b[32m-> \u001b[39m\u001b[32m1121\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_get_value\u001b[49m\u001b[43m(\u001b[49m\u001b[43mkey\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 1123\u001b[39m \u001b[38;5;66;03m# Convert generator to list before going through hashable part\u001b[39;00m\n\u001b[32m 1124\u001b[39m \u001b[38;5;66;03m# (We will iterate through the generator there to check for slices)\u001b[39;00m\n\u001b[32m 1125\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m is_iterator(key):\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/miniconda3/envs/stock/lib/python3.13/site-packages/pandas/core/series.py:1239\u001b[39m, in \u001b[36mSeries._get_value\u001b[39m\u001b[34m(self, label, takeable)\u001b[39m\n\u001b[32m 1236\u001b[39m \u001b[38;5;66;03m# Similar to Index.get_value, but we do not fall back to positional\u001b[39;00m\n\u001b[32m 1237\u001b[39m loc = \u001b[38;5;28mself\u001b[39m.index.get_loc(label)\n\u001b[32m-> \u001b[39m\u001b[32m1239\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[43mis_integer\u001b[49m\u001b[43m(\u001b[49m\u001b[43mloc\u001b[49m\u001b[43m)\u001b[49m:\n\u001b[32m 1240\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m._values[loc]\n\u001b[32m 1242\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(\u001b[38;5;28mself\u001b[39m.index, MultiIndex):\n",
"\u001b[31mKeyboardInterrupt\u001b[39m: "
2025-05-26 21:34:36 +08:00
]
}
],
"source": [
"import numpy as np\n",
"from main.factor.factor import *\n",
2025-06-04 13:50:02 +08:00
"from main.factor.money_factor import * \n",
2025-05-26 21:34:36 +08:00
"\n",
2025-05-29 20:41:18 +08:00
"\n",
2025-05-26 21:34:36 +08:00
"def filter_data(df):\n",
" # df = df.groupby('trade_date').apply(lambda x: x.nlargest(1000, 'act_factor1'))\n",
2025-05-29 20:41:18 +08:00
" df = df[~df[\"is_st\"]]\n",
" df = df[~df[\"ts_code\"].str.endswith(\"BJ\")]\n",
" df = df[~df[\"ts_code\"].str.startswith(\"30\")]\n",
" df = df[~df[\"ts_code\"].str.startswith(\"68\")]\n",
" df = df[~df[\"ts_code\"].str.startswith(\"8\")]\n",
" df = df[df[\"trade_date\"] >= \"2019-01-01\"]\n",
" if \"in_date\" in df.columns:\n",
" df = df.drop(columns=[\"in_date\"])\n",
2025-05-26 21:34:36 +08:00
" df = df.reset_index(drop=True)\n",
" return df\n",
"\n",
2025-05-29 20:41:18 +08:00
"\n",
2025-05-26 21:34:36 +08:00
"gc.collect()\n",
"\n",
"df = filter_data(df)\n",
2025-05-29 20:41:18 +08:00
"df = df.sort_values(by=[\"ts_code\", \"trade_date\"])\n",
2025-05-26 21:34:36 +08:00
"\n",
"# df = price_minus_deduction_price(df, n=120)\n",
"# df = price_deduction_price_diff_ratio_to_sma(df, n=120)\n",
"# df = cat_price_vs_sma_vs_deduction_price(df, n=120)\n",
"# df = cat_reason(df, top_list_df)\n",
"# df = cat_is_on_top_list(df, top_list_df)\n",
2025-06-04 13:50:02 +08:00
"df = holder_trade_factors(df, stk_holdertrade_df)\n",
2025-05-26 21:34:36 +08:00
"\n",
2025-05-29 20:41:18 +08:00
"df = cat_senti_mom_vol_spike(\n",
" df,\n",
" return_period=3,\n",
" return_threshold=0.03, # 近3日涨幅超3%\n",
" volume_ratio_threshold=1.3,\n",
" current_pct_chg_min=0.0, # 当日必须收红\n",
2025-06-01 15:59:29 +08:00
" current_pct_chg_max=0.05,\n",
2025-05-29 20:41:18 +08:00
") # 当日涨幅不宜过大\n",
"\n",
"df = cat_senti_pre_breakout(\n",
" df,\n",
" atr_short_N=10,\n",
" atr_long_M=40,\n",
" vol_atrophy_N=10,\n",
" vol_atrophy_M=40,\n",
" price_stab_N=5,\n",
" price_stab_threshold=0.06,\n",
" current_pct_chg_min_signal=0.002,\n",
" current_pct_chg_max_signal=0.05,\n",
" volume_ratio_signal_threshold=1.1,\n",
")\n",
"\n",
2025-05-28 14:16:04 +08:00
"df = ts_turnover_rate_acceleration_5_20(df)\n",
"df = ts_vol_sustain_10_30(df)\n",
2025-05-29 20:41:18 +08:00
"# df = cs_turnover_rate_relative_strength_20(df)\n",
2025-05-28 14:16:04 +08:00
"df = cs_amount_outlier_10(df)\n",
"df = ts_ff_to_total_turnover_ratio(df)\n",
"df = ts_price_volume_trend_coherence_5_20(df)\n",
2025-05-29 20:41:18 +08:00
"# df = ts_turnover_rate_trend_strength_5(df)\n",
2025-05-28 14:16:04 +08:00
"df = ts_ff_turnover_rate_surge_10(df)\n",
"\n",
2025-05-29 20:41:18 +08:00
"df = add_financial_factor(df, fina_indicator_df, factor_value_col=\"undist_profit_ps\")\n",
"df = add_financial_factor(df, fina_indicator_df, factor_value_col=\"ocfps\")\n",
2025-05-26 21:34:36 +08:00
"calculate_arbr(df, N=26)\n",
2025-05-29 20:41:18 +08:00
"df[\"log_circ_mv\"] = np.log(df[\"circ_mv\"])\n",
2025-05-26 21:34:36 +08:00
"df = calculate_cashflow_to_ev_factor(df, cashflow_df, balancesheet_df)\n",
"df = caculate_book_to_price_ratio(df, fina_indicator_df)\n",
"df = turnover_rate_n(df, n=5)\n",
"df = variance_n(df, n=20)\n",
"df = bbi_ratio_factor(df)\n",
"df = daily_deviation(df)\n",
"df = daily_industry_deviation(df)\n",
"df, _ = get_rolling_factor(df)\n",
"df, _ = get_simple_factor(df)\n",
"\n",
2025-06-01 15:59:29 +08:00
"df = calculate_strong_inflow_signal(df)\n",
"\n",
2025-05-29 20:41:18 +08:00
"df = df.rename(columns={\"l1_code\": \"cat_l1_code\"})\n",
"df = df.rename(columns={\"l2_code\": \"cat_l2_code\"})\n",
2025-05-26 21:34:36 +08:00
"\n",
"lg_flow_mom_corr(df, N=20, M=60)\n",
"lg_flow_accel(df)\n",
"profit_pressure(df)\n",
"underwater_resistance(df)\n",
"cost_conc_std(df, N=20)\n",
"profit_decay(df, N=20)\n",
"vol_amp_loss(df, N=20)\n",
"vol_drop_profit_cnt(df, N=20, M=5)\n",
"lg_flow_vol_interact(df, N=20)\n",
"cost_break_confirm_cnt(df, M=5)\n",
"atr_norm_channel_pos(df, N=14)\n",
"turnover_diff_skew(df, N=20)\n",
"lg_sm_flow_diverge(df, N=20)\n",
"pullback_strong(df, N=20, M=20)\n",
"vol_wgt_hist_pos(df, N=20)\n",
"vol_adj_roc(df, N=20)\n",
"\n",
"cs_rank_net_lg_flow_val(df)\n",
"cs_rank_flow_divergence(df)\n",
2025-05-29 20:41:18 +08:00
"cs_rank_industry_adj_lg_flow(df) # Needs cat_l2_code\n",
2025-05-26 21:34:36 +08:00
"cs_rank_elg_buy_ratio(df)\n",
"cs_rank_rel_profit_margin(df)\n",
"cs_rank_cost_breadth(df)\n",
"cs_rank_dist_to_upper_cost(df)\n",
"cs_rank_winner_rate(df)\n",
"cs_rank_intraday_range(df)\n",
"cs_rank_close_pos_in_range(df)\n",
2025-05-29 20:41:18 +08:00
"cs_rank_opening_gap(df) # Needs pre_close\n",
"cs_rank_pos_in_hist_range(df) # Needs his_low, his_high\n",
2025-05-26 21:34:36 +08:00
"cs_rank_vol_x_profit_margin(df)\n",
"cs_rank_lg_flow_price_concordance(df)\n",
"cs_rank_turnover_per_winner(df)\n",
2025-05-29 20:41:18 +08:00
"cs_rank_ind_cap_neutral_pe(df) # Placeholder - needs external libraries\n",
"cs_rank_volume_ratio(df) # Needs volume_ratio\n",
2025-05-26 21:34:36 +08:00
"cs_rank_elg_buy_sell_sm_ratio(df)\n",
2025-05-29 20:41:18 +08:00
"cs_rank_cost_dist_vol_ratio(df) # Needs volume_ratio\n",
"cs_rank_size(df) # Needs circ_mv\n",
2025-05-26 21:34:36 +08:00
"\n",
"\n",
"# df = df.merge(index_data, on='trade_date', how='left')\n",
"\n",
"print(df.info())\n",
"print(df.columns.tolist())"
]
},
{
"cell_type": "code",
2025-06-01 15:59:29 +08:00
"execution_count": 10,
2025-05-26 21:34:36 +08:00
"id": "b87b938028afa206",
"metadata": {
"ExecuteTime": {
"end_time": "2025-04-03T13:08:03.658725Z",
"start_time": "2025-04-03T13:08:02.469611Z"
}
},
"outputs": [],
"source": [
"from scipy.stats import ks_2samp, wasserstein_distance\n",
"\n",
"\n",
"def remove_shifted_features(train_data, test_data, feature_columns, ks_threshold=0.05, wasserstein_threshold=0.1,\n",
" importance_threshold=0.05):\n",
" dropped_features = []\n",
"\n",
" # **统计数据漂移**\n",
" numeric_columns = train_data.select_dtypes(include=['float64', 'int64']).columns\n",
" numeric_columns = [col for col in numeric_columns if col in feature_columns]\n",
" for feature in numeric_columns:\n",
" ks_stat, p_value = ks_2samp(train_data[feature], test_data[feature])\n",
" wasserstein_dist = wasserstein_distance(train_data[feature], test_data[feature])\n",
"\n",
" if p_value < ks_threshold or wasserstein_dist > wasserstein_threshold:\n",
" dropped_features.append(feature)\n",
"\n",
" print(f\"检测到 {len(dropped_features)} 个可能漂移的特征: {dropped_features}\")\n",
"\n",
" # **应用阈值进行最终筛选**\n",
" filtered_features = [f for f in feature_columns if f not in dropped_features]\n",
"\n",
" return filtered_features, dropped_features\n",
"\n"
]
},
{
"cell_type": "code",
2025-06-01 15:59:29 +08:00
"execution_count": 11,
2025-05-26 21:34:36 +08:00
"id": "f4f16d63ad18d1bc",
"metadata": {
"ExecuteTime": {
"end_time": "2025-04-03T13:08:03.670700Z",
"start_time": "2025-04-03T13:08:03.665739Z"
}
},
"outputs": [],
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"import statsmodels.api as sm # 用于中性化回归\n",
"from tqdm import tqdm # 可选,用于显示进度条\n",
"\n",
"# --- 常量 ---\n",
"epsilon = 1e-10 # 防止除零\n",
"\n",
"# --- 1. 中位数去极值 (MAD) ---\n",
"\n",
"def cs_mad_filter(df: pd.DataFrame,\n",
" features: list,\n",
" k: float = 3.0,\n",
" scale_factor: float = 1.4826):\n",
" \"\"\"\n",
" 对指定特征列进行截面 MAD 去极值处理 (原地修改)。\n",
"\n",
" 方法: 对每日截面数据,计算 median 和 MAD\n",
" 将超出 [median - k * scale * MAD, median + k * scale * MAD] 范围的值\n",
" 替换为边界值 (Winsorization)。\n",
" scale_factor=1.4826 使得 MAD 约等于正态分布的标准差。\n",
"\n",
" Args:\n",
" df (pd.DataFrame): 输入 DataFrame需包含 'trade_date' 和 features 列。\n",
" features (list): 需要处理的特征列名列表。\n",
" k (float): MAD 的倍数,用于确定边界。默认为 3.0。\n",
" scale_factor (float): MAD 的缩放因子。默认为 1.4826。\n",
"\n",
" WARNING: 此函数会原地修改输入的 DataFrame 'df'。\n",
" \"\"\"\n",
" print(f\"开始截面 MAD 去极值处理 (k={k})...\")\n",
" if not all(col in df.columns for col in features):\n",
" missing = [col for col in features if col not in df.columns]\n",
" print(f\"错误: DataFrame 中缺少以下特征列: {missing}。跳过去极值处理。\")\n",
" return\n",
"\n",
" grouped = df.groupby('trade_date')\n",
"\n",
" for col in tqdm(features, desc=\"MAD Filtering\"):\n",
" try:\n",
" # 计算截面中位数\n",
" median = grouped[col].transform('median')\n",
" # 计算截面 MAD (Median Absolute Deviation from Median)\n",
" mad = (df[col] - median).abs().groupby(df['trade_date']).transform('median')\n",
"\n",
" # 计算上下边界\n",
" lower_bound = median - k * scale_factor * mad\n",
" upper_bound = median + k * scale_factor * mad\n",
"\n",
" # 原地应用 clip\n",
" df[col] = np.clip(df[col], lower_bound, upper_bound)\n",
"\n",
" except KeyError:\n",
" print(f\"警告: 列 '{col}' 可能不存在或在分组中出错,跳过此列的 MAD 处理。\")\n",
" except Exception as e:\n",
" print(f\"警告: 处理列 '{col}' 时发生错误: {e},跳过此列的 MAD 处理。\")\n",
"\n",
" print(\"截面 MAD 去极值处理完成。\")\n",
"\n",
"\n",
"# --- 2. 行业市值中性化 ---\n",
"\n",
"def cs_neutralize_industry_cap(df: pd.DataFrame,\n",
" features: list,\n",
" industry_col: str = 'cat_l2_code',\n",
" market_cap_col: str = 'circ_mv'):\n",
" \"\"\"\n",
" 对指定特征列进行截面行业和对数市值中性化 (原地修改)。\n",
" 使用 OLS 回归: feature ~ 1 + log(market_cap) + C(industry)\n",
" 将回归残差写回原特征列。\n",
"\n",
" Args:\n",
" df (pd.DataFrame): 输入 DataFrame需包含 'trade_date', features 列,\n",
" industry_col, market_cap_col。\n",
" features (list): 需要处理的特征列名列表。\n",
" industry_col (str): 行业分类列名。\n",
" market_cap_col (str): 流通市值列名。\n",
"\n",
" WARNING: 此函数会原地修改输入的 DataFrame 'df' 的 features 列。\n",
" 计算量较大,可能耗时较长。\n",
" 需要安装 statsmodels 库 (pip install statsmodels)。\n",
" \"\"\"\n",
" print(\"开始截面行业市值中性化...\")\n",
" required_cols = features + ['trade_date', industry_col, market_cap_col]\n",
" if not all(col in df.columns for col in required_cols):\n",
" missing = [col for col in required_cols if col not in df.columns]\n",
" print(f\"错误: DataFrame 中缺少必需列: {missing}。无法进行中性化。\")\n",
" return\n",
"\n",
" # 预处理:计算 log 市值,处理 industry code 可能的 NaN\n",
" log_cap_col = '_log_market_cap'\n",
" df[log_cap_col] = np.log1p(df[market_cap_col]) # log1p 处理 0 值\n",
" # df[industry_col] = df[industry_col].cat.add_categories('UnknownIndustry')\n",
" # df[industry_col] = df[industry_col].fillna('UnknownIndustry') # 填充行业 NaN\n",
" # df[industry_col] = df[industry_col].astype('category') # 转为类别ols 会自动处理\n",
"\n",
" dates = df['trade_date'].unique()\n",
" all_residuals = [] # 用于收集所有日期的残差\n",
"\n",
" for date in tqdm(dates, desc=\"Neutralizing\"):\n",
" daily_data = df.loc[df['trade_date'] == date, features + [log_cap_col, industry_col]].copy() # 使用 .loc 获取副本\n",
"\n",
" # 准备自变量 X (常数项 + log市值 + 行业哑变量)\n",
" X = daily_data[[log_cap_col]]\n",
" X = sm.add_constant(X, prepend=True) # 添加常数项\n",
" # 创建行业哑变量 (drop_first=True 避免共线性)\n",
" industry_dummies = pd.get_dummies(daily_data[industry_col], prefix=industry_col, drop_first=True)\n",
" industry_dummies = industry_dummies.astype(int)\n",
" X = pd.concat([X, industry_dummies], axis=1)\n",
"\n",
" daily_residuals = daily_data[[col for col in features]].copy() # 创建用于存储残差的df\n",
"\n",
" for col in features:\n",
" Y = daily_data[col]\n",
"\n",
" # 处理 NaN 值,确保 X 和 Y 在相同位置有有效值\n",
" valid_mask = Y.notna() & X.notna().all(axis=1)\n",
" if valid_mask.sum() < (X.shape[1] + 1): # 数据点不足以估计模型\n",
" print(f\"警告: 日期 {date}, 特征 {col} 有效数据不足 ({valid_mask.sum()}个),无法中性化,填充 NaN。\")\n",
" daily_residuals[col] = np.nan\n",
" continue\n",
"\n",
" Y_valid = Y[valid_mask]\n",
" X_valid = X[valid_mask]\n",
"\n",
" # 执行 OLS 回归\n",
" try:\n",
" model = sm.OLS(Y_valid.to_numpy(), X_valid.to_numpy())\n",
" results = model.fit()\n",
" # 将残差填回对应位置\n",
" daily_residuals.loc[valid_mask, col] = results.resid\n",
" daily_residuals.loc[~valid_mask, col] = np.nan # 原本无效的位置填充 NaN\n",
" except Exception as e:\n",
" print(f\"警告: 日期 {date}, 特征 {col} 回归失败: {e},填充 NaN。\")\n",
" daily_residuals[col] = np.nan\n",
" break\n",
"\n",
" all_residuals.append(daily_residuals)\n",
"\n",
" # 合并所有日期的残差结果\n",
" if all_residuals:\n",
" residuals_df = pd.concat(all_residuals)\n",
" # 将残差结果更新回原始 df (原地修改)\n",
" # 使用 update 比 merge 更适合基于索引的原地更新\n",
" # 确保 residuals_df 的索引与 df 中对应部分一致\n",
" df.update(residuals_df)\n",
" else:\n",
" print(\"没有有效的残差结果可以合并。\")\n",
"\n",
"\n",
" # 清理临时列\n",
" df.drop(columns=[log_cap_col], inplace=True)\n",
" print(\"截面行业市值中性化完成。\")\n",
"\n",
"\n",
"# --- 3. Z-Score 标准化 ---\n",
"\n",
"def cs_zscore_standardize(df: pd.DataFrame, features: list, epsilon: float = 1e-10):\n",
" \"\"\"\n",
" 对指定特征列进行截面 Z-Score 标准化 (原地修改)。\n",
" 方法: Z = (value - cross_sectional_mean) / (cross_sectional_std + epsilon)\n",
"\n",
" Args:\n",
" df (pd.DataFrame): 输入 DataFrame需包含 'trade_date' 和 features 列。\n",
" features (list): 需要处理的特征列名列表。\n",
" epsilon (float): 防止除以零的小常数。\n",
"\n",
" WARNING: 此函数会原地修改输入的 DataFrame 'df'。\n",
" \"\"\"\n",
" print(\"开始截面 Z-Score 标准化...\")\n",
" if not all(col in df.columns for col in features):\n",
" missing = [col for col in features if col not in df.columns]\n",
" print(f\"错误: DataFrame 中缺少以下特征列: {missing}。跳过标准化处理。\")\n",
" return\n",
"\n",
" grouped = df.groupby('trade_date')\n",
"\n",
" for col in tqdm(features, desc=\"Standardizing\"):\n",
" try:\n",
" # 使用 transform 计算截面均值和标准差\n",
" mean = grouped[col].transform('mean')\n",
" std = grouped[col].transform('std')\n",
"\n",
" # 计算 Z-Score 并原地赋值\n",
" df[col] = (df[col] - mean) / (std + epsilon)\n",
"\n",
" except KeyError:\n",
" print(f\"警告: 列 '{col}' 可能不存在或在分组中出错,跳过此列的标准化处理。\")\n",
" except Exception as e:\n",
" print(f\"警告: 处理列 '{col}' 时发生错误: {e},跳过此列的标准化处理。\")\n",
"\n",
" print(\"截面 Z-Score 标准化完成。\")\n",
"\n",
"def fill_nan_with_daily_median(df: pd.DataFrame, feature_columns: list[str]) -> pd.DataFrame:\n",
" \"\"\"\n",
" 对指定特征列进行每日截面中位数填充缺失值 (NaN)。\n",
"\n",
" 参数:\n",
" df (pd.DataFrame): 包含多日数据的DataFrame需要包含 'trade_date' 和 feature_columns 中的列。\n",
" feature_columns (list[str]): 需要进行缺失值填充的特征列名称列表。\n",
"\n",
" 返回:\n",
" pd.DataFrame: 包含缺失值填充后特征列的DataFrame。在输入DataFrame的副本上操作。\n",
" \"\"\"\n",
" processed_df = df.copy() # 在副本上操作,保留原始数据\n",
"\n",
" # 确保 trade_date 是 datetime 类型以便正确分组\n",
" processed_df['trade_date'] = pd.to_datetime(processed_df['trade_date'])\n",
"\n",
" def _fill_daily_nan(group):\n",
" # group 是某一个交易日的 DataFrame\n",
"\n",
" # 遍历指定的特征列\n",
" for feature_col in feature_columns:\n",
" # 检查列是否存在于当前分组中\n",
" if feature_col in group.columns:\n",
" # 计算当日该特征的中位数\n",
" median_val = group[feature_col].median()\n",
"\n",
" # 使用当日中位数填充该特征列的 NaN 值\n",
" # inplace=True 会直接修改 group DataFrame\n",
" group[feature_col].fillna(median_val, inplace=True)\n",
" # else:\n",
" # print(f\"Warning: Feature column '{feature_col}' not found in daily group for {group['trade_date'].iloc[0]}. Skipping.\")\n",
"\n",
" return group\n",
"\n",
" # 按交易日期分组,并应用每日填充函数\n",
" # group_keys=False 避免将分组键添加到结果索引中\n",
" filled_df = processed_df.groupby('trade_date', group_keys=False).apply(_fill_daily_nan)\n",
"\n",
" return filled_df"
]
},
{
"cell_type": "code",
2025-06-01 15:59:29 +08:00
"execution_count": 12,
2025-05-26 21:34:36 +08:00
"id": "40e6b68a91b30c79",
"metadata": {
"ExecuteTime": {
"end_time": "2025-04-03T13:08:04.694262Z",
"start_time": "2025-04-03T13:08:03.694904Z"
}
},
"outputs": [],
"source": [
"import pandas as pd\n",
"\n",
"\n",
"def remove_outliers_label_percentile(label: pd.Series, lower_percentile: float = 0.01, upper_percentile: float = 0.99,\n",
" log=True):\n",
" if not (0 <= lower_percentile < upper_percentile <= 1):\n",
" raise ValueError(\"Percentile values must satisfy 0 <= lower_percentile < upper_percentile <= 1.\")\n",
"\n",
" # Calculate lower and upper bounds based on percentiles\n",
" lower_bound = label.quantile(lower_percentile)\n",
" upper_bound = label.quantile(upper_percentile)\n",
"\n",
" # Filter out values outside the bounds\n",
" filtered_label = label[(label >= lower_bound) & (label <= upper_bound)]\n",
"\n",
" # Print the number of removed outliers\n",
" if log:\n",
" print(f\"Removed {len(label) - len(filtered_label)} outliers.\")\n",
" return filtered_label\n",
"\n",
"\n",
"def calculate_risk_adjusted_target(df, days=5):\n",
" df = df.sort_values(by=['ts_code', 'trade_date'])\n",
"\n",
" df['future_close'] = df.groupby('ts_code')['close'].shift(-days)\n",
" df['future_open'] = df.groupby('ts_code')['open'].shift(-1)\n",
" df['future_return'] = (df['future_close'] - df['future_open']) / df['future_open']\n",
"\n",
" df['future_volatility'] = df.groupby('ts_code')['future_return'].rolling(days, min_periods=1).std().reset_index(\n",
" level=0, drop=True)\n",
" sharpe_ratio = df['future_return'] * df['future_volatility']\n",
" sharpe_ratio.replace([np.inf, -np.inf], np.nan, inplace=True)\n",
"\n",
" return sharpe_ratio\n",
"\n",
"\n",
"def calculate_score(df, days=5, lambda_param=1.0):\n",
" def calculate_max_drawdown(prices):\n",
" peak = prices.iloc[0] # 初始化峰值\n",
" max_drawdown = 0 # 初始化最大回撤\n",
"\n",
" for price in prices:\n",
" if price > peak:\n",
" peak = price # 更新峰值\n",
" else:\n",
" drawdown = (peak - price) / peak # 计算当前回撤\n",
" max_drawdown = max(max_drawdown, drawdown) # 更新最大回撤\n",
"\n",
" return max_drawdown\n",
"\n",
" def compute_stock_score(stock_df):\n",
" stock_df = stock_df.sort_values(by=['trade_date'])\n",
" future_return = stock_df['future_return']\n",
" # 使用已有的 pct_chg 字段计算波动率\n",
" volatility = stock_df['pct_chg'].rolling(days).std().shift(-days)\n",
" max_drawdown = stock_df['close'].rolling(days).apply(calculate_max_drawdown, raw=False).shift(-days)\n",
" score = future_return - lambda_param * max_drawdown\n",
" return score\n",
"\n",
" # # 确保 DataFrame 按照股票代码和交易日期排序\n",
" # df = df.sort_values(by=['ts_code', 'trade_date'])\n",
"\n",
" # 对每个股票分别计算 score\n",
" df['score'] = df.groupby('ts_code').apply(compute_stock_score).reset_index(level=0, drop=True)\n",
"\n",
" return df['score']\n",
"\n",
"\n",
"def remove_highly_correlated_features(df, feature_columns, threshold=0.9):\n",
" numeric_features = df[feature_columns].select_dtypes(include=[np.number]).columns.tolist()\n",
" if not numeric_features:\n",
" raise ValueError(\"No numeric features found in the provided data.\")\n",
"\n",
" corr_matrix = df[numeric_features].corr().abs()\n",
" upper = corr_matrix.where(np.triu(np.ones(corr_matrix.shape), k=1).astype(bool))\n",
" to_drop = [column for column in upper.columns if any(upper[column] > threshold)]\n",
" remaining_features = [col for col in feature_columns if col not in to_drop\n",
" or 'act' in col or 'af' in col]\n",
" return remaining_features\n",
"\n",
"\n",
"def cross_sectional_standardization(df, features):\n",
" df_sorted = df.sort_values(by='trade_date') # 按时间排序\n",
" df_standardized = df_sorted.copy()\n",
"\n",
" for date in df_sorted['trade_date'].unique():\n",
" # 获取当前时间点的数据\n",
" current_data = df_standardized[df_standardized['trade_date'] == date]\n",
"\n",
" # 只对指定特征进行标准化\n",
" scaler = StandardScaler()\n",
" standardized_values = scaler.fit_transform(current_data[features])\n",
"\n",
" # 将标准化结果重新赋值回去\n",
" df_standardized.loc[df_standardized['trade_date'] == date, features] = standardized_values\n",
"\n",
" return df_standardized\n",
"\n",
"\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"\n",
"def neutralize_manual_revised(df: pd.DataFrame, features: list, industry_col: str, mkt_cap_col: str) -> pd.DataFrame:\n",
" \"\"\"\n",
" 手动实现简单回归以提升速度,通过构建 Series 确保索引对齐。\n",
" 对特征在行业内部进行市值中性化。\n",
"\n",
" Args:\n",
" df: 输入的 DataFrame包含特征、行业分类和市值列。\n",
" features: 需要进行中性化的特征列名列表。\n",
" industry_col: 行业分类列的列名。\n",
" mkt_cap_col: 市值列的列名。\n",
"\n",
" Returns:\n",
" 中性化后的 DataFrame。\n",
" \"\"\"\n",
"\n",
" df[mkt_cap_col] = pd.to_numeric(df[mkt_cap_col], errors='coerce')\n",
" df_cleaned = df.dropna(subset=[mkt_cap_col]).copy()\n",
" df_cleaned = df_cleaned[df_cleaned[mkt_cap_col] > 0].copy()\n",
"\n",
" if df_cleaned.empty:\n",
" print(\"警告: 清理市值异常值后 DataFrame 为空。\")\n",
" return df # 返回原始或空df取决于清理前的状态\n",
"\n",
" processed_df = df\n",
"\n",
" for col in features:\n",
" if col not in df_cleaned.columns:\n",
" print(f\"警告: 特征列 '{col}' 不存在于清理后的 DataFrame 中,已跳过。\")\n",
" # 对于原始 df 中该列不存在的,在结果 df 中也保持原样可能全是NaN\n",
" processed_df[col] = df[col] if col in df.columns else np.nan\n",
" continue\n",
"\n",
" # 跳过对控制变量本身进行中性化\n",
" if col == mkt_cap_col or col == industry_col:\n",
" print(f\"警告: 特征列 '{col}' 是控制变量或内部使用的列,跳过中性化。\")\n",
" # 在结果 df 中也保持原样\n",
" processed_df[col] = df[col] if col in df.columns else np.nan\n",
" continue\n",
"\n",
" residual_series = pd.Series(index=df_cleaned.index, dtype=float)\n",
"\n",
" # 在分组前处理特征列的 NaN只对有因子值的行进行回归计算\n",
" df_subset_factor = df_cleaned.dropna(subset=[col]).copy()\n",
"\n",
" if not df_subset_factor.empty:\n",
" for industry, group in df_subset_factor.groupby(industry_col):\n",
" x = group[mkt_cap_col] # 市值对数\n",
" y = group[col] # 因子值\n",
"\n",
" # 确保有足够的数据点 (>1) 且市值对数有方差 (>0) 进行回归计算\n",
" # 检查 np.var > 一个很小的正数,避免浮点数误差导致的零方差判断问题\n",
" if len(group) > 1 and np.var(x) > 1e-9:\n",
" try:\n",
" beta = np.cov(y, x)[0, 1] / np.var(x)\n",
" alpha = np.mean(y) - beta * np.mean(x)\n",
"\n",
" # 计算残差\n",
" resid = y - (alpha + beta * x)\n",
"\n",
" # 将计算出的残差存储到 residual_series 中,通过索引自动对齐\n",
" residual_series.loc[resid.index] = resid\n",
"\n",
" except Exception as e:\n",
" # 捕获可能的计算异常例如np.cov或np.var因为极端数据报错\n",
" print(f\"警告: 在行业 {industry} 计算回归时发生错误: {e}。该组残差将设为原始值或 NaN。\")\n",
" # 此时该组的残差会保持 residual_series 初始化时的 NaN 或后续处理\n",
" # 也可以选择保留原始值residual_series.loc[group.index] = group[col]\n",
"\n",
" else:\n",
" residual_series.loc[group.index] = group[col] # 保留原始因子值\n",
" processed_df.loc[residual_series.index, col] = residual_series\n",
"\n",
"\n",
" else:\n",
" processed_df[col] = np.nan # 或 df[col] if col in df.columns else np.nan\n",
"\n",
" return processed_df\n",
"\n",
"\n",
"import gc\n",
"\n",
"gc.collect()\n",
"\n",
"\n",
"def mad_filter(df, features, n=3):\n",
" for col in features:\n",
" median = df[col].median()\n",
" mad = np.median(np.abs(df[col] - median))\n",
" upper = median + n * mad\n",
" lower = median - n * mad\n",
" df[col] = np.clip(df[col], lower, upper) # 截断极值\n",
" return df\n",
"\n",
"\n",
"def percentile_filter(df, features, lower_percentile=0.01, upper_percentile=0.99):\n",
" for col in features:\n",
" # 按日期分组计算上下百分位数\n",
" lower_bound = df.groupby('trade_date')[col].transform(\n",
" lambda x: x.quantile(lower_percentile)\n",
" )\n",
" upper_bound = df.groupby('trade_date')[col].transform(\n",
" lambda x: x.quantile(upper_percentile)\n",
" )\n",
" # 截断超出范围的值\n",
" df[col] = np.clip(df[col], lower_bound, upper_bound)\n",
" return df\n",
"\n",
"\n",
"from scipy.stats import iqr\n",
"\n",
"\n",
"def iqr_filter(df, features):\n",
" for col in features:\n",
" df[col] = df.groupby('trade_date')[col].transform(\n",
" lambda x: (x - x.median()) / iqr(x) if iqr(x) != 0 else x\n",
" )\n",
" return df\n",
"\n",
"\n",
"def quantile_filter(df, features, lower_quantile=0.01, upper_quantile=0.99, window=60):\n",
" df = df.copy()\n",
" for col in features:\n",
" # 计算 rolling 统计量,需要按日期进行 groupby\n",
" rolling_lower = df.groupby('trade_date')[col].transform(lambda x: x.rolling(window=min(len(x), window)).quantile(lower_quantile))\n",
" rolling_upper = df.groupby('trade_date')[col].transform(lambda x: x.rolling(window=min(len(x), window)).quantile(upper_quantile))\n",
"\n",
" # 对数据进行裁剪\n",
" df[col] = np.clip(df[col], rolling_lower, rolling_upper)\n",
" \n",
" return df\n",
"\n",
"def select_top_features_by_rankic(df: pd.DataFrame, feature_columns: list, n: int, target_column: str = 'future_return') -> list:\n",
" \"\"\"\n",
" 计算给定特征与目标列的 RankIC并返回 RankIC 绝对值最高的 n 个特征。\n",
"\n",
" Args:\n",
" df: 包含特征列和目标列的 Pandas DataFrame。\n",
" feature_columns: 包含所有待评估特征列名的列表。\n",
" n: 希望选取的 RankIC 绝对值最高的特征数量。\n",
" target_column: 目标列的名称,用于计算 RankIC。默认为 'future_return'。\n",
"\n",
" Returns:\n",
" 包含 RankIC 绝对值最高的 n 个特征列名的列表。\n",
" \"\"\"\n",
" numeric_columns = df.select_dtypes(include=['float64', 'int64']).columns\n",
" numeric_columns = [col for col in numeric_columns if col in feature_columns]\n",
" if target_column not in df.columns:\n",
" raise ValueError(f\"目标列 '{target_column}' 不存在于 DataFrame 中。\")\n",
"\n",
" rankic_scores = {}\n",
" for feature in numeric_columns:\n",
" if feature not in df.columns:\n",
" print(f\"警告: 特征列 '{feature}' 不存在于 DataFrame 中,已跳过。\")\n",
" continue\n",
"\n",
" # 计算特征与目标列的 RankIC (斯皮尔曼相关系数)\n",
" # dropna() 是为了处理缺失值,确保相关性计算不失败\n",
" valid_data = df[[feature, target_column]].dropna()\n",
" if len(valid_data) > 1: # 确保有足够的数据点进行相关性计算\n",
" # 计算斯皮尔曼相关性\n",
" correlation = valid_data[feature].corr(valid_data[target_column], method='spearman')\n",
" rankic_scores[feature] = abs(correlation) # 使用绝对值来衡量相关性强度\n",
" else:\n",
" rankic_scores[feature] = 0 # 数据不足RankIC设为0或跳过\n",
"\n",
" # 将 RankIC 分数转换为 Series 便于排序\n",
" rankic_series = pd.Series(rankic_scores)\n",
"\n",
" # 按 RankIC 绝对值降序排序,选取前 n 个特征\n",
" # handle case where n might be larger than available features\n",
" n_actual = min(n, len(rankic_series))\n",
" top_features = rankic_series.sort_values(ascending=False).head(n_actual).index.tolist()\n",
" top_features = [col for col in feature_columns if col in top_features or col not in numeric_columns]\n",
" return top_features\n",
"\n",
"def create_deviation_within_dates(df, feature_columns):\n",
" groupby_col = 'cat_l2_code' # 使用 trade_date 进行分组\n",
" new_columns = {}\n",
" ret_feature_columns = feature_columns[:]\n",
"\n",
" # 自动选择所有数值型特征\n",
" num_features = [col for col in feature_columns if 'cat' not in col and 'index' not in col]\n",
"\n",
" # num_features = ['vol', 'pct_chg', 'turnover_rate', 'volume_ratio', 'cat_vol_spike', 'obv', 'maobv_6', 'return_5', 'return_10', 'return_20', 'std_return_5', 'std_return_15', 'std_return_90', 'std_return_90_2', 'act_factor1', 'act_factor2', 'act_factor3', 'act_factor4', 'act_factor5', 'act_factor6', 'rank_act_factor1', 'rank_act_factor2', 'rank_act_factor3', 'active_buy_volume_large', 'active_buy_volume_big', 'active_buy_volume_small', 'alpha_022', 'alpha_003', 'alpha_007', 'alpha_013']\n",
" num_features = [col for col in num_features if 'cat' not in col and 'industry' not in col]\n",
" num_features = [col for col in num_features if 'limit' not in col]\n",
" num_features = [col for col in num_features if 'cyq' not in col]\n",
"\n",
" # 遍历所有数值型特征\n",
" for feature in num_features:\n",
" if feature == 'trade_date': # 不需要对 'trade_date' 计算偏差\n",
" continue\n",
"\n",
" # grouped_mean = df.groupby(['trade_date'])[feature].transform('mean')\n",
" # deviation_col_name = f'deviation_mean_{feature}'\n",
" # new_columns[deviation_col_name] = df[feature] - grouped_mean\n",
" # ret_feature_columns.append(deviation_col_name)\n",
"\n",
" grouped_mean = df.groupby(['trade_date', groupby_col])[feature].transform('mean')\n",
" deviation_col_name = f'deviation_mean_{feature}'\n",
" new_columns[deviation_col_name] = df[feature] - grouped_mean\n",
" ret_feature_columns.append(deviation_col_name)\n",
"\n",
" # 将新计算的偏差特征与原始 DataFrame 合并\n",
" df = pd.concat([df, pd.DataFrame(new_columns)], axis=1)\n",
"\n",
" # for feature in ['obv', 'return_20', 'act_factor1', 'act_factor2', 'act_factor3', 'act_factor4']:\n",
" # df[f'deviation_industry_{feature}'] = df[feature] - df[f'industry_{feature}']\n",
"\n",
" return df, ret_feature_columns\n"
]
},
{
"cell_type": "code",
2025-06-01 15:59:29 +08:00
"execution_count": 13,
2025-05-26 21:34:36 +08:00
"id": "47c12bb34062ae7a",
"metadata": {
"ExecuteTime": {
"end_time": "2025-04-03T14:57:50.841165Z",
"start_time": "2025-04-03T14:49:25.889057Z"
}
},
"outputs": [],
"source": [
"days = 5\n",
"validation_days = 120\n",
"\n",
"import gc\n",
"\n",
"gc.collect()\n",
"\n",
"df = df.sort_values(by=['ts_code', 'trade_date'])\n",
2025-06-04 13:50:02 +08:00
"# df['future_return'] = df.groupby('ts_code', group_keys=False)['close'].apply(lambda x: x.shift(-days) / x - 1)\n",
"df['future_return'] = df.groupby('ts_code', group_keys=False)['close'].apply(lambda x: x.shift(-days) / x - 1) + df.groupby('ts_code', group_keys=False)['close'].apply(lambda x: x.shift(-2 * days) / x - 1)\n",
"\n",
2025-05-26 21:34:36 +08:00
"# df['future_return'] = (df.groupby('ts_code')['close'].shift(-days) - df.groupby('ts_code')['open'].shift(-1)) / \\\n",
"# df.groupby('ts_code')['open'].shift(-1)\n",
"\n",
2025-05-28 14:16:04 +08:00
"# df['cat_up_limit'] = df['pct_chg'] > 5\n",
2025-05-26 21:34:36 +08:00
"# df['label'] = (df.groupby('ts_code')['cat_up_limit']\n",
"# .rolling(window=5, min_periods=1).sum()\n",
"# .groupby('ts_code') # 再次按 ts_code 分组\n",
"# .shift(-5)\n",
"# .fillna(0) # 填充每个股票组最后的 NaN\n",
"# .astype(int)\n",
"# .reset_index(level=0, drop=True))\n",
2025-06-04 13:50:02 +08:00
"# df['label'] = df.groupby('trade_date', group_keys=False)['future_return'].transform(\n",
"# lambda x: pd.qcut(x, q=50, labels=False, duplicates='drop')\n",
"# )\n",
2025-05-29 20:41:18 +08:00
"# filter_index = df['future_return'].between(df['future_return'].quantile(0.01), df['future_return'].quantile(0.99))\n",
"filter_index = df['future_return'].between(df['future_return'].quantile(0.001), 0.6)\n",
2025-05-26 21:34:36 +08:00
"\n",
"# for col in [col for col in df.columns]:\n",
"# train_data[col] = train_data[col].astype('str')\n",
"# test_data[col] = test_data[col].astype('str')"
]
},
{
"cell_type": "code",
2025-06-01 15:59:29 +08:00
"execution_count": 14,
2025-05-26 21:34:36 +08:00
"id": "29221dde",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2025-06-01 15:59:29 +08:00
"200\n"
2025-05-26 21:34:36 +08:00
]
}
],
"source": [
"feature_columns = [col for col in df.head(10)\n",
" .merge(industry_df, on=['cat_l2_code', 'trade_date'], how='left')\n",
" .merge(index_data, on='trade_date', how='left')\n",
" .columns\n",
" ]\n",
"feature_columns = [col for col in feature_columns if col not in ['trade_date',\n",
" 'ts_code',\n",
" 'label']]\n",
"feature_columns = [col for col in feature_columns if 'future' not in col]\n",
"feature_columns = [col for col in feature_columns if 'label' not in col]\n",
"feature_columns = [col for col in feature_columns if 'score' not in col]\n",
"feature_columns = [col for col in feature_columns if 'gen' not in col]\n",
"feature_columns = [col for col in feature_columns if 'is_st' not in col]\n",
"feature_columns = [col for col in feature_columns if 'pe_ttm' not in col]\n",
"# feature_columns = [col for col in feature_columns if 'volatility' not in col]\n",
"# feature_columns = [col for col in feature_columns if 'circ_mv' not in col]\n",
"feature_columns = [col for col in feature_columns if 'code' not in col]\n",
"feature_columns = [col for col in feature_columns if col not in origin_columns]\n",
"feature_columns = [col for col in feature_columns if not col.startswith('_')]\n",
"# feature_columns = [col for col in feature_columns if col not in ['ts_code', 'trade_date', 'vol_std_5', 'cov', 'delta_cov', 'alpha_22_improved', 'alpha_007', 'consecutive_up_limit', 'mv_volatility', 'volume_growth', 'mv_growth', 'arbr']]\n",
"feature_columns = [col for col in feature_columns if col not in ['intraday_lg_flow_corr_20', \n",
" 'cap_neutral_cost_metric', \n",
" 'hurst_net_mf_vol_60', \n",
" 'complex_factor_deap_1', \n",
" 'lg_buy_consolidation_20',\n",
" 'cs_rank_ind_cap_neutral_pe',\n",
" 'cs_rank_opening_gap',\n",
" 'cs_rank_ind_adj_lg_flow']]\n",
"feature_columns = [col for col in feature_columns if col not in ['cat_reason', 'cat_is_on_top_list']]\n",
"print(len(feature_columns))"
]
},
{
"cell_type": "code",
2025-06-01 15:59:29 +08:00
"execution_count": 15,
2025-05-26 21:34:36 +08:00
"id": "03ee5daf",
"metadata": {},
"outputs": [],
"source": [
"# df = fill_nan_with_daily_median(df, feature_columns)\n",
"for feature_col in [col for col in feature_columns if col in df.columns]:\n",
" # median_val = df[feature_col].median()\n",
" df[feature_col].fillna(0, inplace=True)"
]
},
{
"cell_type": "code",
2025-06-01 15:59:29 +08:00
"execution_count": 16,
2025-05-26 21:34:36 +08:00
"id": "b76ea08a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" ts_code trade_date log_circ_mv\n",
"0 000001.SZ 2019-01-02 16.574219\n",
"1 000001.SZ 2019-01-03 16.583965\n",
2025-06-04 13:50:02 +08:00
"2 000001.SZ 2019-01-04 16.633371\n",
2025-06-01 15:59:29 +08:00
"['vol', 'pct_chg', 'turnover_rate', 'volume_ratio', 'winner_rate', 'cat_senti_mom_vol_spike', 'cat_senti_pre_breakout', 'ts_turnover_rate_acceleration_5_20', 'ts_vol_sustain_10_30', 'cs_amount_outlier_10', 'ts_ff_to_total_turnover_ratio', 'ts_price_volume_trend_coherence_5_20', 'ts_ff_turnover_rate_surge_10', 'undist_profit_ps', 'ocfps', 'AR', 'BR', 'AR_BR', 'log_circ_mv', 'cashflow_to_ev_factor', 'book_to_price_ratio', 'turnover_rate_mean_5', 'variance_20', 'bbi_ratio_factor', 'daily_deviation', 'lg_elg_net_buy_vol', 'flow_lg_elg_intensity', 'sm_net_buy_vol', 'total_buy_vol', 'lg_elg_buy_prop', 'flow_struct_buy_change', 'lg_elg_net_buy_vol_change', 'flow_lg_elg_accel', 'chip_concentration_range', 'chip_skewness', 'floating_chip_proxy', 'cost_support_15pct_change', 'cat_winner_price_zone', 'flow_chip_consistency', 'profit_taking_vs_absorb', 'cat_is_positive', 'upside_vol', 'downside_vol', 'vol_ratio', 'return_skew', 'return_kurtosis', 'volume_change_rate', 'cat_volume_breakout', 'turnover_deviation', 'cat_turnover_spike', 'avg_volume_ratio', 'cat_volume_ratio_breakout', 'vol_spike', 'vol_std_5', 'atr_14', 'atr_6', 'obv', 'maobv_6', 'rsi_3', 'return_5', 'return_20', 'std_return_5', 'std_return_90', 'std_return_90_2', 'act_factor1', 'act_factor2', 'act_factor3', 'act_factor4', 'rank_act_factor1', 'rank_act_factor2', 'rank_act_factor3', 'cov', 'delta_cov', 'alpha_22_improved', 'alpha_003', 'alpha_007', 'alpha_013', 'vol_break', 'weight_roc5', 'smallcap_concentration', 'cost_stability', 'high_cost_break_days', 'liquidity_risk', 'turnover_std', 'mv_volatility', 'volume_growth', 'mv_growth', 'momentum_factor', 'resonance_factor', 'log_close', 'cat_vol_spike', 'up', 'down', 'obv_maobv_6', 'std_return_5_over_std_return_90', 'std_return_90_minus_std_return_90_2', 'cat_af2', 'cat_af3', 'cat_af4', 'act_factor5', 'act_factor6', 'active_buy_volume_large', 'active_buy_volume_big', 'active_buy_volume_small', 'buy_lg_vol_minus_sell_lg_vol', 'buy_elg_vol_minus_sell_elg_vol', 
'ctrl_strength', 'low_cost_dev', 'asymmetry', 'lock_factor', 'cat_vol_break', 'cost_atr_adj', 'cat_golden_resonance', 'mv_turnover_ratio', 'mv_adjusted_volume', 'mv_weighted_turnover', 'nonlinear_mv_volume', 'mv_volume_ratio', 'mv_momentum', 'senti_strong_inflow', 'lg_flow_mom_corr_20_60', 'lg_flow_accel', 'profit_pressure', 'underwater_resistance', 'cost_conc_std_20', 'profit_decay_20', 'vol_amp_loss_20', 'vol_drop_profit_cnt_5', 'lg_flow_vol_interact_20', 'cost_break_confirm_cnt_5', 'atr_norm_channel_pos_14', 'turnover_diff_skew_20', 'lg_sm_flow_diverge_20', 'pullback_strong_20_20', 'vol_wgt_hist_pos_20', 'vol_adj_roc_20', 'cs_rank_net_lg_flow_val', 'cs_rank_elg_buy_ratio', 'cs_rank_rel_profit_margin', 'cs_rank_cost_breadth', 'cs_rank_dist_to_upper_cost', 'cs_rank_winner_rate', 'cs_rank_intraday_range', 'cs_rank_close_pos_in_range', 'cs_rank_pos_in_hist_range', 'cs_rank_vol_x_profit_margin', 'cs_rank_lg_flow_price_concordance', 'cs_rank_turnover_per_winner', 'cs_rank_volume_ratio', 'cs_rank_elg_buy_sell_sm_ratio', 'cs_rank_cost_dist_vol_ratio', 'cs_rank_size', 'industry_obv', 'industry_return_5', 'industry_return_20', 'industry__ema_5', 'industry__ema_13', 'industry__ema_20', 'industry__ema_60', 'industry_act_factor1', 'industry_act_factor2', 'industry_act_factor3', 'industry_act_factor4', 'industry_act_factor5', 'industry_act_factor6', 'industry_rank_act_factor1', 'industry_rank_act_factor2', 'industry_rank_act_factor3', 'industry_return_5_percentile', 'industry_return_20_percentile', '000852.SH_MACD', '000905.SH_MACD', '399006.SZ_MACD', '000852.SH_MACD_hist', '000905.SH_MACD_hist', '399006.SZ_MACD_hist', '000852.SH_RSI', '000905.SH_RSI', '399006.SZ_RSI', '000852.SH_Signal_line', '000905.SH_Signal_line', '399006.SZ_Signal_line', '000852.SH_amount_change_rate', '000905.SH_amount_change_rate', '399006.SZ_amount_change_rate', '000852.SH_amount_mean', '000905.SH_amount_mean', '399006.SZ_amount_mean', '000852.SH_daily_return', '000905.SH_daily_return', 
'399006.SZ_daily_return', '000852.SH_up_ratio_20d', '000905.SH_up_ratio_20d', '399006.SZ_up_ratio_20d', '00085
2025-05-26 21:34:36 +08:00
"去除极值\n",
"开始截面 MAD 去极值处理 (k=3.0)...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
2025-06-04 13:50:02 +08:00
"MAD Filtering: 100%|██████████| 139/139 [00:06<00:00, 21.25it/s]\n"
2025-05-26 21:34:36 +08:00
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"截面 MAD 去极值处理完成。\n",
"开始截面 MAD 去极值处理 (k=3.0)...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
2025-06-04 13:50:02 +08:00
"MAD Filtering: 100%|██████████| 139/139 [00:05<00:00, 26.68it/s]\n"
2025-05-26 21:34:36 +08:00
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"截面 MAD 去极值处理完成。\n",
"开始截面 MAD 去极值处理 (k=3.0)...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"MAD Filtering: 0it [00:00, ?it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"截面 MAD 去极值处理完成。\n",
"开始截面 MAD 去极值处理 (k=3.0)...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"MAD Filtering: 0it [00:00, ?it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"截面 MAD 去极值处理完成。\n",
2025-06-01 15:59:29 +08:00
"feature_columns: ['vol', 'pct_chg', 'turnover_rate', 'volume_ratio', 'winner_rate', 'cat_senti_mom_vol_spike', 'cat_senti_pre_breakout', 'ts_turnover_rate_acceleration_5_20', 'ts_vol_sustain_10_30', 'cs_amount_outlier_10', 'ts_ff_to_total_turnover_ratio', 'ts_price_volume_trend_coherence_5_20', 'ts_ff_turnover_rate_surge_10', 'undist_profit_ps', 'ocfps', 'AR', 'BR', 'AR_BR', 'log_circ_mv', 'cashflow_to_ev_factor', 'book_to_price_ratio', 'turnover_rate_mean_5', 'variance_20', 'bbi_ratio_factor', 'daily_deviation', 'lg_elg_net_buy_vol', 'flow_lg_elg_intensity', 'sm_net_buy_vol', 'total_buy_vol', 'lg_elg_buy_prop', 'flow_struct_buy_change', 'lg_elg_net_buy_vol_change', 'flow_lg_elg_accel', 'chip_concentration_range', 'chip_skewness', 'floating_chip_proxy', 'cost_support_15pct_change', 'cat_winner_price_zone', 'flow_chip_consistency', 'profit_taking_vs_absorb', 'cat_is_positive', 'upside_vol', 'downside_vol', 'vol_ratio', 'return_skew', 'return_kurtosis', 'volume_change_rate', 'cat_volume_breakout', 'turnover_deviation', 'cat_turnover_spike', 'avg_volume_ratio', 'cat_volume_ratio_breakout', 'vol_spike', 'vol_std_5', 'atr_14', 'atr_6', 'obv', 'maobv_6', 'rsi_3', 'return_5', 'return_20', 'std_return_5', 'std_return_90', 'std_return_90_2', 'act_factor1', 'act_factor2', 'act_factor3', 'act_factor4', 'rank_act_factor1', 'rank_act_factor2', 'rank_act_factor3', 'cov', 'delta_cov', 'alpha_22_improved', 'alpha_003', 'alpha_007', 'alpha_013', 'vol_break', 'weight_roc5', 'smallcap_concentration', 'cost_stability', 'high_cost_break_days', 'liquidity_risk', 'turnover_std', 'mv_volatility', 'volume_growth', 'mv_growth', 'momentum_factor', 'resonance_factor', 'log_close', 'cat_vol_spike', 'up', 'down', 'obv_maobv_6', 'std_return_5_over_std_return_90', 'std_return_90_minus_std_return_90_2', 'cat_af2', 'cat_af3', 'cat_af4', 'act_factor5', 'act_factor6', 'active_buy_volume_large', 'active_buy_volume_big', 'active_buy_volume_small', 'buy_lg_vol_minus_sell_lg_vol', 
'buy_elg_vol_minus_sell_elg_vol', 'ctrl_strength', 'low_cost_dev', 'asymmetry', 'lock_factor', 'cat_vol_break', 'cost_atr_adj', 'cat_golden_resonance', 'mv_turnover_ratio', 'mv_adjusted_volume', 'mv_weighted_turnover', 'nonlinear_mv_volume', 'mv_volume_ratio', 'mv_momentum', 'senti_strong_inflow', 'lg_flow_mom_corr_20_60', 'lg_flow_accel', 'profit_pressure', 'underwater_resistance', 'cost_conc_std_20', 'profit_decay_20', 'vol_amp_loss_20', 'vol_drop_profit_cnt_5', 'lg_flow_vol_interact_20', 'cost_break_confirm_cnt_5', 'atr_norm_channel_pos_14', 'turnover_diff_skew_20', 'lg_sm_flow_diverge_20', 'pullback_strong_20_20', 'vol_wgt_hist_pos_20', 'vol_adj_roc_20', 'cs_rank_net_lg_flow_val', 'cs_rank_elg_buy_ratio', 'cs_rank_rel_profit_margin', 'cs_rank_cost_breadth', 'cs_rank_dist_to_upper_cost', 'cs_rank_winner_rate', 'cs_rank_intraday_range', 'cs_rank_close_pos_in_range', 'cs_rank_pos_in_hist_range', 'cs_rank_vol_x_profit_margin', 'cs_rank_lg_flow_price_concordance', 'cs_rank_turnover_per_winner', 'cs_rank_volume_ratio', 'cs_rank_elg_buy_sell_sm_ratio', 'cs_rank_cost_dist_vol_ratio', 'cs_rank_size', 'industry_obv', 'industry_return_5', 'industry_return_20', 'industry__ema_5', 'industry__ema_13', 'industry__ema_20', 'industry__ema_60', 'industry_act_factor1', 'industry_act_factor2', 'industry_act_factor3', 'industry_act_factor4', 'industry_act_factor5', 'industry_act_factor6', 'industry_rank_act_factor1', 'industry_rank_act_factor2', 'industry_rank_act_factor3', 'industry_return_5_percentile', 'industry_return_20_percentile', '000852.SH_MACD', '000905.SH_MACD', '399006.SZ_MACD', '000852.SH_MACD_hist', '000905.SH_MACD_hist', '399006.SZ_MACD_hist', '000852.SH_RSI', '000905.SH_RSI', '399006.SZ_RSI', '000852.SH_Signal_line', '000905.SH_Signal_line', '399006.SZ_Signal_line', '000852.SH_amount_change_rate', '000905.SH_amount_change_rate', '399006.SZ_amount_change_rate', '000852.SH_amount_mean', '000905.SH_amount_mean', '399006.SZ_amount_mean', '000852.SH_daily_return', 
'000905.SH_daily_return', '399006.SZ_daily_return', '000852.SH_up_ratio_20d', '000905.SH_up_ratio_20d', '399006.SZ_up_r
2025-05-26 21:34:36 +08:00
"df最小日期: 2019-01-02\n",
2025-06-04 13:50:02 +08:00
"df最大日期: 2025-05-30\n",
"1091062\n",
2025-05-26 21:34:36 +08:00
"train_data最小日期: 2020-01-02\n",
"train_data最大日期: 2022-12-30\n",
2025-06-04 13:50:02 +08:00
"869968\n",
2025-05-26 21:34:36 +08:00
"test_data最小日期: 2023-01-03\n",
2025-06-04 13:50:02 +08:00
"test_data最大日期: 2025-05-30\n",
2025-05-26 21:34:36 +08:00
" ts_code trade_date log_circ_mv\n",
"0 000001.SZ 2019-01-02 16.574219\n",
"1 000001.SZ 2019-01-03 16.583965\n",
"2 000001.SZ 2019-01-04 16.633371\n"
]
}
],
"source": [
"split_date = '2023-01-01'\n",
2025-06-04 13:50:02 +08:00
"train_data = df[filter_index & (df['trade_date'] <= split_date) & (df['trade_date'] >= '2020-01-01')].groupby(\n",
" 'trade_date', group_keys=False).apply(lambda x: x.nsmallest(1500, 'total_mv'))\n",
"test_data = df[(df['trade_date'] >= split_date)].groupby(\n",
" 'trade_date', group_keys=False).apply(lambda x: x.nsmallest(1500, 'total_mv'))\n",
2025-05-26 21:34:36 +08:00
"\n",
"print(df[['ts_code', 'trade_date', 'log_circ_mv']].head(3))\n",
"\n",
"industry_df = industry_df.sort_values(by=['trade_date'])\n",
"index_data = index_data.sort_values(by=['trade_date'])\n",
"\n",
"# train_data = train_data.merge(industry_df, on=['cat_l2_code', 'trade_date'], how='left')\n",
"# train_data = train_data.merge(index_data, on='trade_date', how='left')\n",
"# test_data = test_data.merge(industry_df, on=['cat_l2_code', 'trade_date'], how='left')\n",
"# test_data = test_data.merge(index_data, on='trade_date', how='left')\n",
"\n",
"train_data, test_data = train_data.replace([np.inf, -np.inf], np.nan), test_data.replace([np.inf, -np.inf], np.nan)\n",
"\n",
2025-06-04 13:50:02 +08:00
"train_data['label'] = train_data.groupby('trade_date', group_keys=False)['future_return'].transform(\n",
" lambda x: pd.qcut(x, q=100, labels=False, duplicates='drop')\n",
")\n",
"test_data['label'] = test_data.groupby('trade_date', group_keys=False)['future_return'].transform(\n",
" lambda x: pd.qcut(x, q=100, labels=False, duplicates='drop')\n",
")\n",
"\n",
2025-05-26 21:34:36 +08:00
"# feature_columns_new = feature_columns[:]\n",
"# train_data, _ = create_deviation_within_dates(train_data, [col for col in feature_columns if col in train_data.columns])\n",
"# test_data, _ = create_deviation_within_dates(test_data, [col for col in feature_columns if col in train_data.columns])\n",
"\n",
"# feature_columns = [\n",
"# 'undist_profit_ps', \n",
"# 'AR_BR',\n",
"# 'pe_ttm',\n",
"# 'alpha_22_improved', \n",
"# 'alpha_003', \n",
"# 'alpha_007', \n",
"# 'alpha_013', \n",
"# 'cat_up_limit', \n",
"# 'cat_down_limit', \n",
"# 'up_limit_count_10d', \n",
"# 'down_limit_count_10d', \n",
"# 'consecutive_up_limit', \n",
"# 'vol_break', \n",
"# 'weight_roc5', \n",
"# 'price_cost_divergence', \n",
"# 'smallcap_concentration', \n",
"# 'cost_stability', \n",
"# 'high_cost_break_days', \n",
"# 'liquidity_risk', \n",
"# 'turnover_std', \n",
"# 'mv_volatility', \n",
"# 'volume_growth', \n",
"# 'mv_growth', \n",
"# 'lg_flow_mom_corr_20_60', \n",
"# 'lg_flow_accel', \n",
"# 'profit_pressure', \n",
"# 'underwater_resistance', \n",
"# 'cost_conc_std_20', \n",
"# 'profit_decay_20', \n",
"# 'vol_amp_loss_20', \n",
"# 'vol_drop_profit_cnt_5', \n",
"# 'lg_flow_vol_interact_20', \n",
"# 'cost_break_confirm_cnt_5', \n",
"# 'atr_norm_channel_pos_14', \n",
"# 'turnover_diff_skew_20', \n",
"# 'lg_sm_flow_diverge_20', \n",
"# 'pullback_strong_20_20', \n",
"# 'vol_wgt_hist_pos_20', \n",
"# 'vol_adj_roc_20',\n",
"# 'cashflow_to_ev_factor',\n",
"# 'ocfps',\n",
"# 'book_to_price_ratio',\n",
"# 'turnover_rate_mean_5',\n",
"# 'variance_20',\n",
"# 'bbi_ratio_factor'\n",
"# ]\n",
"# feature_columns = [col for col in feature_columns if col in train_data.columns]\n",
"# feature_columns = [col for col in feature_columns if not col.startswith('_')]\n",
"\n",
"numeric_columns = df.select_dtypes(include=['float64', 'int64']).columns\n",
"numeric_columns = [col for col in numeric_columns if col in feature_columns]\n",
"# feature_columns = select_top_features_by_rankic(df, numeric_columns, n=10)\n",
"print(feature_columns)\n",
"\n",
"# train_data = fill_nan_with_daily_median(train_data, feature_columns)\n",
"# test_data = fill_nan_with_daily_median(test_data, feature_columns)\n",
"\n",
"train_data = train_data.dropna(subset=[col for col in feature_columns if col in train_data.columns])\n",
"train_data = train_data.dropna(subset=['label'])\n",
"train_data = train_data.reset_index(drop=True)\n",
"# print(test_data.tail())\n",
"test_data = test_data.dropna(subset=[col for col in feature_columns if col in train_data.columns])\n",
"# test_data = test_data.dropna(subset=['label'])\n",
"test_data = test_data.reset_index(drop=True)\n",
"\n",
"transform_feature_columns = feature_columns\n",
"transform_feature_columns = [col for col in transform_feature_columns if col in feature_columns and not col.startswith('cat') and col in train_data.columns]\n",
"# transform_feature_columns.remove('undist_profit_ps')\n",
"print('去除极值')\n",
"cs_mad_filter(train_data, transform_feature_columns)\n",
"# print('中性化')\n",
"# cs_neutralize_industry_cap(train_data, transform_feature_columns)\n",
"# print('标准化')\n",
"# cs_zscore_standardize(train_data, transform_feature_columns)\n",
"\n",
"cs_mad_filter(test_data, transform_feature_columns)\n",
"# cs_neutralize_industry_cap(test_data, transform_feature_columns)\n",
"# cs_zscore_standardize(test_data, transform_feature_columns)\n",
"\n",
"mad_filter_feature_columns = [col for col in feature_columns if col not in transform_feature_columns and not col.startswith('cat') and col in train_data.columns]\n",
"cs_mad_filter(train_data, mad_filter_feature_columns)\n",
"cs_mad_filter(test_data, mad_filter_feature_columns)\n",
"\n",
"\n",
"print(f'feature_columns: {feature_columns}')\n",
"\n",
"\n",
"print(f\"df最小日期: {df['trade_date'].min().strftime('%Y-%m-%d')}\")\n",
"print(f\"df最大日期: {df['trade_date'].max().strftime('%Y-%m-%d')}\")\n",
"print(len(train_data))\n",
"print(f\"train_data最小日期: {train_data['trade_date'].min().strftime('%Y-%m-%d')}\")\n",
"print(f\"train_data最大日期: {train_data['trade_date'].max().strftime('%Y-%m-%d')}\")\n",
"print(len(test_data))\n",
"print(f\"test_data最小日期: {test_data['trade_date'].min().strftime('%Y-%m-%d')}\")\n",
"print(f\"test_data最大日期: {test_data['trade_date'].max().strftime('%Y-%m-%d')}\")\n",
"\n",
"cat_columns = [col for col in feature_columns if col.startswith('cat')]\n",
"for col in cat_columns:\n",
" train_data[col] = train_data[col].astype('category')\n",
" test_data[col] = test_data[col].astype('category')\n",
"\n",
"print(df[['ts_code', 'trade_date', 'log_circ_mv']].head(3))\n"
]
},
{
"cell_type": "code",
2025-06-04 13:50:02 +08:00
"execution_count": 106,
2025-05-26 21:34:36 +08:00
"id": "3ff2d1c5",
"metadata": {},
"outputs": [],
"source": [
"from sklearn.preprocessing import StandardScaler\n",
"from sklearn.linear_model import LogisticRegression\n",
"import matplotlib.pyplot as plt # 保持 matplotlib 导入尽管LightGBM的绘图功能已移除\n",
"from sklearn.decomposition import PCA\n",
"import pandas as pd\n",
"import numpy as np\n",
"import datetime # 用于日期计算\n",
"from catboost import CatBoostClassifier, CatBoostRanker, CatBoostRegressor\n",
"from catboost import Pool\n",
"import lightgbm as lgb\n",
"from lightgbm import LGBMRanker, LGBMRegressor\n",
"\n",
"def train_model(train_data_df, feature_columns,\n",
" print_info=True, # 调整参数名,更通用\n",
" validation_days=180, use_pca=False, split_date=None,\n",
" target_column='label', type='light'): # 增加目标列参数\n",
"\n",
" print('train data size: ', len(train_data_df))\n",
" print(train_data_df[['ts_code', 'trade_date', 'log_circ_mv']])\n",
" # 确保数据按时间排序\n",
" train_data_df = train_data_df.sort_values(by='trade_date')\n",
"\n",
" # 识别数值型特征列\n",
" numeric_feature_columns = train_data_df[feature_columns].select_dtypes(include=['float64', 'int64']).columns.tolist()\n",
"\n",
" # 去除标签为空的样本\n",
" initial_len = len(train_data_df)\n",
" train_data_df = train_data_df.dropna(subset=[target_column])\n",
"\n",
" if print_info:\n",
" print(f'原始样本数: {initial_len}, 去除标签为空后样本数: {len(train_data_df)}')\n",
"\n",
" # 提取特征和标签,只取数值型特征用于线性回归\n",
" \n",
" if split_date is None:\n",
" all_dates = train_data_df['trade_date'].unique() # 获取所有唯一的 trade_date\n",
" split_date = all_dates[-validation_days] # 划分点为倒数第 validation_days 天\n",
" train_data_split = train_data_df[train_data_df['trade_date'] < split_date] # 训练集\n",
" val_data_split = train_data_df[train_data_df['trade_date'] >= split_date] # 验证集\n",
"\n",
" train_data_split = train_data_split.sort_values('trade_date')\n",
" val_data_split = val_data_split.sort_values('trade_date')\n",
"\n",
" \n",
" X_train = train_data_split[feature_columns]\n",
" y_train = train_data_split[target_column]\n",
" \n",
" X_val = val_data_split[feature_columns]\n",
" y_val = val_data_split[target_column]\n",
"\n",
"\n",
" # # 标准化数值特征 (使用 StandardScaler 对训练集fit并transform, 对验证集只transform)\n",
" scaler = StandardScaler()\n",
" # X_train = scaler.fit_transform(X_train)\n",
"\n",
" # 训练线性回归模型\n",
" # model = LogisticRegression(random_state=42)\n",
" \n",
" # # 使用处理后的特征和样本权重进行训练\n",
" # model.fit(X_train, y_train)\n",
"\n",
"\n",
" if type == 'cat':\n",
" params = {\n",
2025-06-04 13:50:02 +08:00
" 'loss_function': 'QueryRMSE', # 适用于二分类\n",
2025-05-26 21:34:36 +08:00
" 'eval_metric': 'NDCG', # 评估指标\n",
" 'iterations': 1500,\n",
2025-05-29 20:41:18 +08:00
" 'learning_rate': 0.03,\n",
2025-06-01 15:59:29 +08:00
" 'depth': 8, # 控制模型复杂度\n",
" 'l2_leaf_reg': 1, # L2 正则化\n",
2025-05-26 21:34:36 +08:00
" 'verbose': 5000,\n",
" 'early_stopping_rounds': 300,\n",
" 'one_hot_max_size': 50,\n",
" # 'class_weights': [0.6, 1.2],\n",
" 'task_type': 'GPU',\n",
" 'has_time': True,\n",
" 'random_seed': 7\n",
" }\n",
" cat_features = [i for i, col in enumerate(feature_columns) if col.startswith('cat')]\n",
" group_train = train_data_split['trade_date'].factorize()[0]\n",
" group_val = val_data_split['trade_date'].factorize()[0]\n",
" train_pool = Pool(\n",
" data=X_train,\n",
" label=y_train,\n",
" group_id=group_train,\n",
" cat_features=cat_features\n",
" )\n",
" val_pool = Pool(\n",
" data=X_val,\n",
" label=y_val,\n",
" group_id=group_val,\n",
" cat_features=cat_features\n",
" )\n",
"\n",
"\n",
" model = CatBoostRanker(**params)\n",
" model.fit(train_pool,\n",
" eval_set=val_pool, \n",
" plot=True, \n",
" use_best_model=True\n",
" )\n",
" elif type == 'light':\n",
" label_gain = list(range(len(train_data_split['label'].unique())))\n",
2025-06-04 13:50:02 +08:00
" \n",
2025-05-26 21:34:36 +08:00
" params = {\n",
" 'label_gain': [gain * gain for gain in label_gain],\n",
" 'objective': 'lambdarank',\n",
2025-06-04 13:50:02 +08:00
" 'metric': 'ndcg',\n",
2025-05-29 20:41:18 +08:00
" 'learning_rate': 0.05,\n",
2025-06-04 13:50:02 +08:00
" 'num_leaves': 1024,\n",
" 'min_data_in_leaf': 256,\n",
" # 'max_depth': 10,\n",
" 'max_bin': 1024,\n",
" 'feature_fraction': 0.5,\n",
" 'bagging_fraction': 0.5,\n",
2025-05-26 21:34:36 +08:00
" 'bagging_freq': 5,\n",
2025-06-04 13:50:02 +08:00
" 'lambda_l1': 5,\n",
" 'lambda_l2': 50,\n",
2025-05-26 21:34:36 +08:00
" 'boosting': 'gbdt',\n",
" 'verbosity': -1,\n",
" 'extra_trees': True,\n",
" # 'max_position': 5,\n",
2025-06-04 13:50:02 +08:00
" 'ndcg_at': '5',\n",
2025-05-26 21:34:36 +08:00
" 'quant_train_renew_leaf': True,\n",
2025-06-04 13:50:02 +08:00
" 'lambdarank_truncation_level': 10,\n",
2025-05-29 20:41:18 +08:00
" 'lambdarank_position_bias_regularization': 1,\n",
2025-05-26 21:34:36 +08:00
" 'seed': 7\n",
" }\n",
2025-06-04 13:50:02 +08:00
" feature_contri = [2 if feat.startswith('act_factor') or 'buy' in feat or 'sell' in feat else 1 for feat in feature_columns]\n",
" params['feature_contri'] = feature_contri\n",
"\n",
2025-05-26 21:34:36 +08:00
" train_groups = train_data_split.groupby('trade_date').size().tolist()\n",
" val_groups = val_data_split.groupby('trade_date').size().tolist()\n",
"\n",
" categorical_feature = [col for col in feature_columns if 'cat' in col]\n",
" train_dataset = lgb.Dataset(\n",
" X_train, label=y_train, \n",
" group=train_groups,\n",
" categorical_feature=categorical_feature\n",
" )\n",
" val_dataset = lgb.Dataset(\n",
" X_val, label=y_val, \n",
" group=val_groups,\n",
" categorical_feature=categorical_feature\n",
" )\n",
"\n",
" evals = {}\n",
" callbacks = [lgb.log_evaluation(period=1000),\n",
" lgb.callback.record_evaluation(evals),\n",
2025-06-04 13:50:02 +08:00
" # lgb.early_stopping(300, first_metric_only=False)\n",
2025-05-26 21:34:36 +08:00
" ]\n",
2025-05-28 14:16:04 +08:00
" # 训练模型\n",
" model = lgb.train(\n",
2025-06-04 13:50:02 +08:00
" params, train_dataset, num_boost_round=1000,\n",
2025-05-28 14:16:04 +08:00
" valid_sets=[train_dataset, val_dataset], valid_names=['train', 'valid'],\n",
" callbacks=callbacks\n",
" )\n",
"\n",
" # 打印特征重要性(如果需要)\n",
" if True:\n",
" lgb.plot_metric(evals)\n",
" lgb.plot_importance(model, importance_type='split', max_num_features=20)\n",
" plt.show()\n",
"\n",
" # from flaml import AutoML\n",
" # from sklearn.datasets import fetch_california_housing\n",
"\n",
" # # Initialize an AutoML instance\n",
" # model = AutoML()\n",
" # # Specify automl goal and constraint\n",
" # automl_settings = {\n",
" # \"time_budget\": 600, # in seconds\n",
" # \"metric\": \"ndcg@1\",\n",
" # \"task\": \"rank\",\n",
" # \"estimator_list\": [\n",
" # \"catboost\",\n",
" # \"lgbm\",\n",
" # \"xgboost\"\n",
" # ], \n",
" # \"ensemble\": {\n",
" # \"final_estimator\": LGBMRanker(),\n",
" # \"passthrough\": False,\n",
" # },\n",
" # }\n",
" # model.fit(X_train=X_train, y_train=y_train, groups=train_groups,\n",
" # X_val=X_val, y_val=y_val,groups_val=val_groups,\n",
" # mlflow_logging=False, **automl_settings)\n",
2025-05-26 21:34:36 +08:00
"\n",
"\n",
" return model, scaler, None # 返回训练好的模型、scaler 和 pca 对象"
]
},
{
"cell_type": "code",
2025-06-04 13:50:02 +08:00
"execution_count": 107,
2025-05-26 21:34:36 +08:00
"id": "c6eb5cd4-e714-420a-ac48-39af3e11ee81",
"metadata": {
"ExecuteTime": {
"end_time": "2025-04-03T15:03:18.426481Z",
"start_time": "2025-04-03T15:02:19.926352Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2025-06-04 13:50:02 +08:00
"train data size: 1091062\n",
2025-06-01 15:59:29 +08:00
" ts_code trade_date log_circ_mv\n",
"0 600306.SH 2020-01-02 11.552040\n",
"1 603269.SH 2020-01-02 11.324801\n",
"2 002633.SZ 2020-01-02 11.759023\n",
"3 603991.SH 2020-01-02 11.181150\n",
"4 000691.SZ 2020-01-02 11.677910\n",
"... ... ... ...\n",
2025-06-04 13:50:02 +08:00
"1091057 603533.SH 2022-12-30 13.362893\n",
"1091058 603416.SH 2022-12-30 13.364553\n",
"1091059 002277.SZ 2022-12-30 13.364740\n",
"1091060 002140.SZ 2022-12-30 13.086924\n",
"1091061 002374.SZ 2022-12-30 13.347147\n",
2025-05-26 21:34:36 +08:00
"\n",
2025-06-04 13:50:02 +08:00
"[1091062 rows x 3 columns]\n",
"原始样本数: 1091062, 去除标签为空后样本数: 1091062\n",
"[1000]\ttrain's ndcg@5: 0.667175\tvalid's ndcg@5: 0.357008\n"
2025-05-26 21:34:36 +08:00
]
2025-05-28 14:16:04 +08:00
},
{
"data": {
2025-06-04 13:50:02 +08:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAkgAAAHHCAYAAABEEKc/AAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjMsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvZiW1igAAAAlwSFlzAAAPYQAAD2EBqD+naQAAb2xJREFUeJzt3Xd4U9X/B/B3kmZ070XpgrI3LRsZUmSJDAcCslQcgIqIA1EQUHB9EVzgQvwpCKKIyrSUJRsKZZUNHUB36V5pcn9/HJo2HdCWtknb9+t5eGjuSD7JaZN3zj3nXpkkSRKIiIiIyEBu6gKIiIiIzA0DEhEREVEJDEhEREREJTAgEREREZXAgERERERUAgMSERERUQkMSEREREQlMCARERERlcCARERERFQCAxIRYfXq1ZDJZIiMjKyxx3jvvfcgk8nqzP2aWmRkJGQyGVavXl2l/WUyGd57771qrYmoIWFAIqpFhUFEJpNh//79pdZLkgRvb2/IZDI8/PDDVXqMr7/+usofqlQ5a9euxbJly0xdBhHVAAYkIhPQaDRYu3ZtqeV79+7FjRs3oFarq3zfVQlIEyZMQE5ODnx9fav8uKbyzjvvICcnxySPXZMBydfXFzk5OZgwYUKV9s/JycE777xTzVURNRwMSEQmMHToUGzYsAEFBQVGy9euXYvAwEB4eHjUSh1ZWVkAAIVCAY1GU6cOVRXWbmFhAY1GY+Jq7i03Nxd6vb7C28tkMmg0GigUiio9nkajgYWFRZX2JSIGJCKTGDt2LJKTkxESEmJYlp+fj99//x3jxo0rcx+9Xo9ly5ahTZs20Gg0cHd3x/PPP4/bt28btvHz88O5c+ewd+9ew6G8fv36ASg6vLd3715MmzYNbm5uaNy4sdG6kmOQtm3bhr59+8LW1hZ2dnbo0qVLmT1fJe3fvx9dunSBRqNB06ZN8c0335Ta5m5jbEqOnykcZxQREYFx48bB0dERvXv3NlpXcv8ZM2Zg06ZNaNu2LdRqNdq0aYPt27eXeqw9e/YgKCjIqNaKjGvq168ftmzZgqioKMNr7efnZ7hPmUyGdevW4Z133oGXlxesrKyQnp6OlJQUzJ49G+3atYONjQ3s7OwwZMgQnDp16p6vz+TJk2FjY4ObN29i5MiRsLGxgaurK2bPng2dTleh1/DKlSuYPHkyHBwcYG9vjylTpiA7O9to35ycHLz88stwcXGBra0tHnnkEdy8eZPjmqhB4dcLIhPw8/NDjx498Ouvv2LIkCEARBhJS0vDk08+ic8//7zUPs8//zxWr16NKVOm4OWXX8b169fx5Zdf4uTJkzhw4ACUSiWWLVuGl156CTY2Npg7dy4AwN3d3eh+pk2bBldXV8ybN8/QC1OW1atX4+mnn0abNm0wZ84cODg44OTJk9i+fXu5IQ4Azpw5g4ceegiurq547733UFBQgPnz55eqoyoef/xxNGvWDIsXL4YkSXfddv/+/di4cSOmTZsGW1tbfP7553j00UcRHR0NZ2dnAMDJkycxePBgeHp6YsGCBdDpdFi4cCFcXV3vWcvcuXORlpaGGzdu4LPPPgMA2NjYGG2zaNEiqFQqzJ49G3l5eVCpVIiIiMCmTZvw+OOPw9/fH/Hx8fjmm2/Qt29fREREoFGjRnd9XJ1Oh0GDBqFbt2749NNPsXPnTvzvf/9D06ZN8eKLL96z7ieeeAL+/v5YsmQJTpw4ge+//x5ubm746KOPDNtMnjwZv/32GyZMmIDu3btj7969GDZs2D3vm6hekYio1vz4448SAOnYsWPSl19+Kdna2krZ2dmSJEnS448/LvXv31+SJEny9fWVhg0bZtjvv//+kwBIa9asMbq/7du3l1repk0bqW/fvuU+du/evaWCgoIy112/fl2SJElKTU2VbG1tpW7dukk5OTlG2+r1+rs+x5EjR0oajUaKiooyLIuIiJAUCoVU/C3n+vXrEgDpxx9/LHUfAKT58+
cbbs+fP18CII0dO7bUtoXrSu6vUqmkK1euGJadOnVKAiB98cUXhmXDhw+XrKyspJs3bxqWXb58WbKwsCh1n2UZNmyY5OvrW2r57t27JQBSkyZNDO1bKDc3V9LpdEbLrl+/LqnVamnhwoVGy0q+PpMmTZIAGG0nSZLUqVMnKTAwsNRrUNZr+PTTTxttN2rUKMnZ2dlwOywsTAIgzZw502i7yZMnl7pPovqMh9iITOSJJ55ATk4ONm/ejIyMDGzevLncnpkNGzbA3t4eAwcORFJSkuFfYGAgbGxssHv37go/7tSpU+85riUkJAQZGRl46623So3vuduhJ51Ohx07dmDkyJHw8fExLG/VqhUGDRpU4RrL88ILL1R42+DgYDRt2tRwu3379rCzs8O1a9cMte7cuRMjR4406rUJCAgw9Ordr0mTJsHS0tJomVqthlwuN9SQnJwMGxsbtGjRAidOnKjQ/ZZ8HR544AHD86rKvsnJyUhPTwcAw2HIadOmGW330ksvVej+ieoLHmIjMhFXV1cEBwdj7dq1yM7Ohk6nw2OPPVbmtpcvX0ZaWhrc3NzKXJ+QkFDhx/X397/nNlevXgUAtG3btsL3CwCJiYnIyclBs2bNSq1r0aIFtm7dWqn7K6kitRcqHtAKOTo6GsZsJSQkICcnBwEBAaW2K2tZVZRVr16vx/Lly/H111/j+vXrRmOHCg/93Y1Goyl1CLD487qXkq+Lo6MjAOD27duws7NDVFQU5HJ5qdqr6zUhqisYkIhMaNy4cZg6dSri4uIwZMgQODg4lLmdXq+Hm5sb1qxZU+b6ioyZKVSyR8NUyuuJKjnYuLjK1F5eL5l0j7FL1amsehcvXox3330XTz/9NBYtWgQnJyfI5XLMnDmzQrPcqjqr7V771+brQlQXMCARmdCoUaPw/PPP4/Dhw1i/fn252zVt2hQ7d+5Er1697hkSqmOqfuGhqbNnz1aq58DV1RWWlpa4fPlyqXUXL140ul3Yc5Gammq0PCoqqpLVVo2bmxs0Gg2uXLlSal1Zy8pSldf6999/R//+/fHDDz8YLU9NTYWLi0ul76+6+fr6Qq/X4/r160Y9gRV9TYjqC45BIjIhGxsbrFixAu+99x6GDx9e7nZPPPEEdDodFi1aVGpdQUGBUciwtrYuFToq66GHHoKtrS2WLFmC3Nxco3V362lQKBQYNGgQNm3ahOjoaMPy8+fPY8eOHUbb2tnZwcXFBfv27TNa/vXXX99X7RWlUCgQHByMTZs24datW4blV65cwbZt2yp0H9bW1khLS6v045Z8DTds2ICbN29W6n5qSuFYsZLt8MUXX5iiHCKTYQ8SkYlNmjTpntv07dsXzz//PJYsWYLw8HA89NBDUCqVuHz5MjZs2IDly5cbxi8FBgZixYoVeP/99xEQEAA3Nzc8+OCDlarJzs4On332GZ599ll06dLFcO6hU6dOITs7Gz/99FO5+y5YsADbt2/HAw88gGnTpqGgoABffPEF2rRpg9OnTxtt++yzz+LDDz/Es88+i6CgIOzbtw+XLl2qVK3347333sO///6LXr164cUXX4ROp8OXX36Jtm3bIjw8/J77BwYGYv369Zg1axa6dOkCGxubuwZdAHj44YexcOFCTJkyBT179sSZM2ewZs0aNGnSpJqe1f0JDAzEo48+imXLliE5Odkwzb+wXerSyUSJ7gcDElEdsXLlSgQGBuKbb77B22+/DQsLC/j5+eGpp55Cr169DNvNmzcPUVFR+Pjjj5GRkYG+fftWOiABwDPPPAM3Nzd8+OGHWLRoEZRKJVq2bIlXX331rvu1b98eO3bswKxZszBv3jw0btwYCxYsQGxsbKmANG/ePCQmJuL333/Hb7/9hiFDhmDbtm3lDkavboGBgdi2bRtmz56Nd999F97e3li4cCHOnz+PCxcu3HP/adOmITw8HD/++CM+++wz+Pr63jMgvf3228jKysLatWuxfv16dO7cGVu2bMFbb7
1VXU/rvv3f//0fPDw88Ouvv+LPP/9EcHAw1q9fjxYtWtSJs5YTVQeZxJF5RERGRo4ciXPnzpU5lqqhCg8PR6dOnfD
2025-05-28 14:16:04 +08:00
"text/plain": [
2025-06-04 13:50:02 +08:00
"<Figure size 640x480 with 1 Axes>"
2025-05-28 14:16:04 +08:00
]
},
"metadata": {},
"output_type": "display_data"
},
{
2025-06-04 13:50:02 +08:00
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAvAAAAHHCAYAAADZMWzyAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjMsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvZiW1igAAAAlwSFlzAAAPYQAAD2EBqD+naQAA+ktJREFUeJzs3XlcTun/+PHX3b5TYSoiS3YS2bcQKjKWGesgYxlLRozEGFRj3xnLzBhkD2OEkSXGMmPPYOyGkRgZy2dIMmm5f3/063zdWlRK7tv7+XjcD53rXOc67/dd6rqvc53rqNRqtRohhBBCCCGEVtAr7ACEEEIIIYQQOScdeCGEEEIIIbSIdOCFEEIIIYTQItKBF0IIIYQQQotIB14IIYQQQggtIh14IYQQQgghtIh04IUQQgghhNAi0oEXQgghhBBCi0gHXgghhBBCCC0iHXghhBCiEIWGhqJSqYiOji7sUIQQWkI68EIIId6q9A5rZq+xY8cWyDmPHj1KUFAQjx8/LpD232cJCQkEBQVx8ODBwg5FiPeGQWEHIIQQ4v0UEhJC2bJlNcqqV69eIOc6evQowcHB+Pr6UrRo0QI5R1717t2b7t27Y2xsXNih5ElCQgLBwcEAuLu7F24wQrwnpAMvhBCiUHh5eeHm5lbYYbyRZ8+eYW5u/kZt6Ovro6+vn08RvT2pqam8ePGisMMQ4r0kU2iEEEK8k3bt2kXTpk0xNzfH0tKSdu3acfHiRY06f/zxB76+vpQrVw4TExPs7Oz49NNPefTokVInKCiIgIAAAMqWLatM14mOjiY6OhqVSkVoaGiG86tUKoKCgjTaUalUXLp0iZ49e2JtbU2TJk2U/WvXrqVOnTqYmppiY2ND9+7duX379mvzzGwOvJOTE+3bt+fgwYO4ublhampKjRo1lGkqP/30EzVq1MDExIQ6depw5swZjTZ9fX2xsLDgr7/+om3btpibm+Pg4EBISAhqtVqj7rNnz/jiiy9wdHTE2NiYSpUqMXv27Az1VCoVfn5+rFu3jmrVqmFsbMy3335L8eLFAQgODlbe2/T3LSffn5ff2+vXrytXSYoUKUK/fv1ISEjI8J6tXbuWevXqYWZmhrW1Nc2aNWPv3r0adXLy8yOEtpIReCGEEIXiyZMnPHz4UKOsWLFiAKxZs4a+ffvStm1bZsyYQUJCAkuXLqVJkyacOXMGJycnACIjI/nrr7/o168fdnZ2XLx4ke+//56LFy9y/PhxVCoVnTt35tq1a2zYsIF58+Yp5yhevDgPHjzIddwff/wxzs7OTJ06VenkTpkyhQkTJtC1a1cGDBjAgwcP+Oabb2jWrBlnzpzJ07Sd69ev07NnTz777DM++eQTZs+ejY+PD99++y1ffvklQ4cOBWDatGl07dqVq1evoqf3f+NyKSkpeHp60qBBA2bOnMnu3buZNGkSycnJhISEAKBWq+nQoQMHDhygf//+1KpViz179hAQEMDff//NvHnzNGL65Zdf2LRpE35+fhQrVgwXFxeWLl3KkCFD6NSpE507dwagZs2aQM6+Py/r2rUrZcuWZdq0afz+++/88MMPlChRghkzZih1goODCQoKolGjRoSEhGBkZMSJEyf45ZdfaNOmDZDznx8htJZaCCGEeItWrlypBjJ9qdVq9dOnT9VFixZVDxw4UOO4e/fuqYsUKaJRnpCQkKH9DRs2qAH14cOHlbJZs2apAfXNmzc16t68eVMNqFeuXJmhHUA9adIkZXvSpElqQN2jRw+NetHR0Wp9fX31lClTNMrPnz+vNjAwyFCe1fvxcmxlypRRA+qjR48qZXv27FEDalNTU/WtW7eU8u+++04NqA8cOKCU9e3bVw2ohw8frpSlpqaq27VrpzYyMlI/ePBArVar1eHh4WpAPXnyZI2YPvroI7VKpVJfv35d4/3Q09NTX7x4UaPugwcPMrxX6XL6/Ul/bz/99FONup06dVLb2toq23/++a
daT09P3alTJ3VKSopG3dTUVLVanbufHyG0lUyhEUIIUSgWL15MZGSkxgvSRm0fP35Mjx49ePjwofLS19enfv36HDhwQGnD1NRU+fq///7j4cOHNGjQAIDff/+9QOIePHiwxvZPP/1EamoqXbt21YjXzs4OZ2dnjXhzo2rVqjRs2FDZrl+/PgAtW7akdOnSGcr/+uuvDG34+fkpX6dPgXnx4gX79u0DICIiAn19fT7//HON47744gvUajW7du3SKG/evDlVq1bNcQ65/f68+t42bdqUR48eERcXB0B4eDipqalMnDhR42pDen6Qu58fIbSVTKERQghRKOrVq5fpTax//vknkNZRzYyVlZXy9f/+9z+Cg4MJCwvj/v37GvWePHmSj9H+n1dXzvnzzz9Rq9U4OztnWt/Q0DBP53m5kw5QpEgRABwdHTMt//fffzXK9fT0KFeunEZZxYoVAZT59rdu3cLBwQFLS0uNelWqVFH2v+zV3F8nt9+fV3O2trYG0nKzsrLixo0b6OnpZfshIjc/P0JoK+nACyGEeKekpqYCafOY7ezsMuw3MPi/P11du3bl6NGjBAQEUKtWLSwsLEhNTcXT01NpJzuvzsFOl5KSkuUxL48qp8erUqnYtWtXpqvJWFhYvDaOzGS1Mk1W5epXbjotCK/m/jq5/f7kR265+fkRQlvJT7EQQoh3Svny5QEoUaIEHh4eWdb7999/2b9/P8HBwUycOFEpTx+BfVlWHfX0Ed5XH/D06sjz6+JVq9WULVtWGeF+F6SmpvLXX39pxHTt2jUA5SbOMmXKsG/fPp4+faoxCn/lyhVl/+tk9d7m5vuTU+XLlyc1NZVLly5Rq1atLOvA639+hNBmMgdeCCHEO6Vt27ZYWVkxdepUkpKSMuxPXzkmfbT21dHZ+fPnZzgmfa32VzvqVlZWFCtWjMOHD2uUL1myJMfxdu7cGX19fYKDgzPEolarMyyZ+DYtWrRII5ZFixZhaGhIq1atAPD29iYlJUWjHsC8efNQqVR4eXm99hxmZmZAxvc2N9+fnOrYsSN6enqEhIRkGMFPP09Of36E0GYyAi+EEOKdYmVlxdKlS+nduze1a9eme/fuFC9enJiYGHbu3Enjxo1ZtGgRVlZWNGvWjJkzZ5KUlETJkiXZu3cvN2/ezNBmnTp1ABg/fjzdu3fH0NAQHx8fzM3NGTBgANOnT2fAgAG4ublx+PBhZaQ6J8qXL8/kyZMZN24c0dHRdOzYEUtLS27evMnWrVsZNGgQo0ePzrf3J6dMTEzYvXs3ffv2pX79+uzatYudO3fy5ZdfKmu3+/j40KJFC8aPH090dDQuLi7s3buXbdu24e/vr4xmZ8fU1JSqVauyceNGKlasiI2NDdWrV6d69eo5/v7kVIUKFRg/fjxff/01TZs2pXPnzhgbG3Pq1CkcHByYNm1ajn9+hNBqhbT6jRBCiPdU+rKJp06dyrbegQMH1G3btlUXKVJEbWJioi5fvrza19dXHRUVpdS5c+eOulOnTuqiRYuqixQpov7444/Vd+/ezXRZw6+//lpdsmRJtZ6ensayjQkJCer+/furixQpora0tFR37dpVff/+/SyXkUxfgvFVW7ZsUTdp0kRtbm6uNjc3V1euXFk9bNgw9dWrV3P0fry6jGS7du0y1AXUw4YN0yhLXwpz1qxZSlnfvn3V5ubm6hs3bqjbtGmjNjMzU3/wwQfqSZMmZVh+8enTp+qRI0eqHRwc1IaGhmpnZ2f1rFmzlGUZszt3uqNHj6rr1KmjNjIy0njfcvr9yeq9zey9UavV6hUrVqhdXV3VxsbGamtra3Xz5s3VkZGRGnVy8vMjhLZSqdVv4a4XIYQQQrw1vr6+/Pjjj8THxxd2KEKIAiBz4IUQQgghhNAi0oEXQgghhBBCi0gHXgghhBBCCC0ic+CFEEIIIYTQIjICL4QQQgghhBaRDrwQQgghhBBaRB7kJIQOSk1N5e7du1haWmb5mHMhhB
BCvFvUajVPnz7FwcEBPb2sx9mlAy+EDrp79y6Ojo6FHYYQQggh8uD27duUKlUqy/3SgRdCB1laWgJw8+ZNbGxsCjm
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
2025-05-26 21:34:36 +08:00
}
],
"source": [
"\n",
"gc.collect()\n",
"\n",
"use_pca = False\n",
2025-06-04 13:50:02 +08:00
"type = 'light'\n",
2025-05-26 21:34:36 +08:00
"# feature_contri = [2 if feat.startswith('act_factor') or 'buy' in feat or 'sell' in feat else 1 for feat in feature_columns]\n",
"# light_params['feature_contri'] = feature_contri\n",
"# print(f'feature_contri: {feature_contri}')\n",
"model, scaler, pca = train_model(train_data\n",
" .dropna(subset=['label']).groupby('trade_date', group_keys=False)\n",
2025-06-04 13:50:02 +08:00
" .apply(lambda x: x.nsmallest(3000, 'total_mv'))\n",
2025-05-26 21:34:36 +08:00
" .merge(industry_df, on=['cat_l2_code', 'trade_date'], how='left')\n",
" .merge(index_data, on='trade_date', how='left'), \n",
" feature_columns, type=type, target_column='label')\n"
]
},
{
"cell_type": "code",
2025-06-01 15:59:29 +08:00
"execution_count": null,
2025-05-26 21:34:36 +08:00
"id": "5d1522a7538db91b",
"metadata": {
"ExecuteTime": {
"end_time": "2025-04-03T15:04:39.656944Z",
"start_time": "2025-04-03T15:04:39.298483Z"
}
},
"outputs": [],
"source": [
2025-06-04 13:50:02 +08:00
"score_df = test_data.groupby('trade_date', group_keys=False).apply(lambda x: x.nsmallest(3000, 'total_mv'))\n",
2025-05-26 21:34:36 +08:00
"# score_df = fill_nan_with_daily_median(score_df, ['pe_ttm'])\n",
"# score_df = score_df[score_df['pe_ttm'] > 0]\n",
"score_df = score_df.merge(industry_df, on=['cat_l2_code', 'trade_date'], how='left')\n",
"score_df = score_df.merge(index_data, on='trade_date', how='left')\n",
"# score_df = score_df.groupby('trade_date', group_keys=False).apply(lambda x: x.nsmallest(50, 'total_mv')).reset_index()\n",
"numeric_columns = score_df.select_dtypes(include=['float64', 'int64']).columns\n",
"numeric_columns = [col for col in feature_columns if col in numeric_columns]\n",
"\n",
"if type in ('cat', 'light'):\n",
"    # CatBoostRanker and LightGBM boosters expose the same predict() interface,\n",
"    # so the two previously duplicated branches are merged into one.\n",
"    score_df['score'] = model.predict(score_df[feature_columns])\n",
"score_df['score_ranks'] = score_df.groupby('trade_date')['score'].rank(ascending=True)\n",
"\n",
"score_df = score_df.groupby('trade_date', group_keys=False).apply(\n",
" lambda x: \n",
" x[\n",
" # (x['score'] <= x['score'].quantile(0.99)) & \n",
" (x['score'] >= x['score'].quantile(0.90))\n",
" ] # 计算90%分位数作为阈值,筛选分数>=阈值的行\n",
").reset_index(drop=True) # drop=True 避免添加旧索引列\n",
"# df_to_drop = score_df.loc[score_df.groupby('trade_date')['score'].idxmax()]\n",
"# score_df = score_df.drop(df_to_drop.index)\n",
2025-06-01 15:59:29 +08:00
"save_df = score_df.groupby('trade_date', group_keys=False).apply(lambda x: x.nlargest(5, 'score')).reset_index()\n",
2025-05-26 21:34:36 +08:00
"# save_df = score_df.groupby('trade_date', group_keys=False).apply(lambda x: x.nsmallest(2, 'total_mv')).reset_index()\n",
"save_df = save_df.sort_values(['trade_date', 'score'])\n",
2025-06-04 13:50:02 +08:00
"save_df[['trade_date', 'score', 'ts_code']].to_csv('predictions_test.tsv', index=False)\n"
2025-06-01 15:59:29 +08:00
]
},
{
"cell_type": "code",
2025-06-04 13:50:02 +08:00
"execution_count": null,
2025-06-01 15:59:29 +08:00
"id": "1f3c1331",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"成功连接到 Redis 服务器: 140.143.91.66:6389数据库 0\n",
"DataFrame 已使用 Pickle 序列化并写入 Redis键为 'save_df'\n",
"从 Redis 读取到的 Pickle 序列化数据 (前 20 字节):\n",
2025-06-04 13:50:02 +08:00
"b'\\x80\\x04\\x95\\xbf\\x04\\x01\\x00\\x00\\x00\\x00\\x00\\x8c\\x11pandas.'\n",
2025-06-01 15:59:29 +08:00
"\n",
"从 Redis 加载的 DataFrame (使用 Pickle):\n",
" index ts_code trade_date open close high low vol \\\n",
2025-06-04 13:50:02 +08:00
"4 25 002802.SZ 2023-01-03 21.53 22.14 22.35 21.31 15974.43 \n",
"3 20 600235.SH 2023-01-03 12.23 12.41 12.50 12.23 35791.00 \n",
"2 13 002691.SZ 2023-01-03 9.13 9.32 9.37 9.13 35566.80 \n",
"1 35 603779.SH 2023-01-03 9.74 9.88 10.16 9.69 78641.45 \n",
"0 131 600228.SH 2023-01-03 15.68 15.75 15.79 15.60 17076.00 \n",
2025-06-01 15:59:29 +08:00
"... ... ... ... ... ... ... ... ... \n",
2025-06-04 13:50:02 +08:00
"2904 87098 000931.SZ 2025-05-30 8.02 7.96 8.13 7.92 170963.38 \n",
"2903 87001 605122.SH 2025-05-30 16.86 16.73 17.11 16.68 20146.40 \n",
"2902 87041 002942.SZ 2025-05-30 24.10 23.95 24.34 23.84 10821.00 \n",
"2901 87122 603170.SH 2025-05-30 14.69 14.34 14.83 14.32 26536.80 \n",
"2900 86987 603177.SH 2025-05-30 8.72 8.55 8.77 8.49 25875.00 \n",
2025-06-01 15:59:29 +08:00
"\n",
" pct_chg amount ... 000905.SH_up_ratio_20d 399006.SZ_up_ratio_20d \\\n",
2025-06-04 13:50:02 +08:00
"4 3.12 23729.134 ... 0.3 0.40 \n",
"3 0.73 20011.747 ... 0.3 0.40 \n",
"2 2.08 19174.296 ... 0.3 0.40 \n",
"1 1.65 53043.254 ... 0.3 0.40 \n",
"0 1.42 12366.664 ... 0.3 0.40 \n",
2025-06-01 15:59:29 +08:00
"... ... ... ... ... ... \n",
2025-06-04 13:50:02 +08:00
"2904 -0.38 89186.130 ... 0.6 0.45 \n",
"2903 -1.18 23723.010 ... 0.6 0.45 \n",
"2902 -1.07 18204.020 ... 0.6 0.45 \n",
"2901 -2.45 36665.356 ... 0.6 0.45 \n",
"2900 -1.38 21916.553 ... 0.6 0.45 \n",
2025-06-01 15:59:29 +08:00
"\n",
" 000852.SH_volatility 000905.SH_volatility 399006.SZ_volatility \\\n",
"4 1.036997 0.828596 0.935322 \n",
"3 1.036997 0.828596 0.935322 \n",
"2 1.036997 0.828596 0.935322 \n",
"1 1.036997 0.828596 0.935322 \n",
"0 1.036997 0.828596 0.935322 \n",
"... ... ... ... \n",
2025-06-04 13:50:02 +08:00
"2904 1.089861 0.850444 1.195355 \n",
"2903 1.089861 0.850444 1.195355 \n",
"2902 1.089861 0.850444 1.195355 \n",
"2901 1.089861 0.850444 1.195355 \n",
"2900 1.089861 0.850444 1.195355 \n",
2025-06-01 15:59:29 +08:00
"\n",
" 000852.SH_volume_change_rate 000905.SH_volume_change_rate \\\n",
"4 5.203088 -0.750721 \n",
"3 5.203088 -0.750721 \n",
"2 5.203088 -0.750721 \n",
"1 5.203088 -0.750721 \n",
"0 5.203088 -0.750721 \n",
"... ... ... \n",
2025-06-04 13:50:02 +08:00
"2904 -2.039466 -12.002493 \n",
"2903 -2.039466 -12.002493 \n",
"2902 -2.039466 -12.002493 \n",
"2901 -2.039466 -12.002493 \n",
"2900 -2.039466 -12.002493 \n",
2025-06-01 15:59:29 +08:00
"\n",
" 399006.SZ_volume_change_rate score score_ranks \n",
2025-06-04 13:50:02 +08:00
"4 8.827360 0.391351 1496.0 \n",
"3 8.827360 0.410662 1497.0 \n",
"2 8.827360 0.470817 1498.0 \n",
"1 8.827360 0.567032 1499.0 \n",
"0 8.827360 0.596280 1500.0 \n",
2025-06-01 15:59:29 +08:00
"... ... ... ... \n",
2025-06-04 13:50:02 +08:00
"2904 5.078672 0.707379 1490.0 \n",
"2903 5.078672 0.769014 1491.0 \n",
"2902 5.078672 0.874366 1492.0 \n",
"2901 5.078672 0.975826 1493.0 \n",
"2900 5.078672 1.072646 1494.0 \n",
2025-06-01 15:59:29 +08:00
"\n",
2025-06-04 13:50:02 +08:00
"[2905 rows x 241 columns]\n",
2025-06-01 15:59:29 +08:00
"\n",
"验证成功:原始 DataFrame 和从 Redis 加载的 DataFrame 一致。\n",
"\n",
"清理了 Redis 中键 'save_df' 的数据。\n"
]
}
],
"source": [
"import os\n",
"import redis\n",
"import pickle\n",
"\n",
"# NOTE(review): host/port/password were hardcoded in the notebook (a committed\n",
"# credential). They now default to the old values but can be overridden via\n",
"# environment variables, so the secret can be removed from source control.\n",
"redis_host = os.environ.get('REDIS_HOST', '140.143.91.66')\n",
"redis_port = int(os.environ.get('REDIS_PORT', 6389))\n",
"redis_db = 0\n",
"redis_key = 'save_df'\n",
"\n",
"try:\n",
"    # 1. Connect to the Redis server\n",
"    r = redis.Redis(host=redis_host, port=redis_port, db=redis_db,\n",
"                    password=os.environ.get('REDIS_PASSWORD', 'Redis520102'))\n",
" r.ping()\n",
" print(f\"\\n成功连接到 Redis 服务器: {redis_host}:{redis_port},数据库 {redis_db}\")\n",
"\n",
" # 2. 将 DataFrame 写入 Redis (使用 Pickle 序列化)\n",
" df_serialized = pickle.dumps(save_df)\n",
" r.set(redis_key, df_serialized)\n",
" print(f\"DataFrame 已使用 Pickle 序列化并写入 Redis键为 '{redis_key}'\")\n",
"\n",
" # 3. 从 Redis 读取数据 (获取 Pickle 序列化的字节流)\n",
" retrieved_serialized = r.get(redis_key)\n",
"\n",
" if retrieved_serialized:\n",
" print(f\"从 Redis 读取到的 Pickle 序列化数据 (前 20 字节):\")\n",
" print(retrieved_serialized[:20])\n",
"\n",
" # 4. 使用 Pickle 反序列化回 Pandas DataFrame\n",
" loaded_df = pickle.loads(retrieved_serialized)\n",
" print(\"\\n从 Redis 加载的 DataFrame (使用 Pickle):\")\n",
" print(loaded_df)\n",
"\n",
" # 5. 验证原始 DataFrame 和加载的 DataFrame 是否一致\n",
" if save_df.equals(loaded_df):\n",
" print(\"\\n验证成功原始 DataFrame 和从 Redis 加载的 DataFrame 一致。\")\n",
" else:\n",
" print(\"\\n验证失败原始 DataFrame 和从 Redis 加载的 DataFrame 不一致!\")\n",
"\n",
" else:\n",
" print(f\"错误:无法从 Redis 获取键 '{redis_key}' 的值。\")\n",
"\n",
" # 6. 清理测试数据 (可选)\n",
" r.delete(redis_key)\n",
" print(f\"\\n清理了 Redis 中键 '{redis_key}' 的数据。\")\n",
"\n",
"except redis.exceptions.ConnectionError as e:\n",
" print(f\"无法连接到 Redis 服务器: {e}\")\n",
" print(\"请确保您的 Redis 服务器已启动并且主机和端口配置正确。\")\n",
"except redis.exceptions.TimeoutError as e:\n",
" print(f\"连接 Redis 服务器超时: {e}\")\n",
" print(\"请检查您的网络连接和 Redis 服务器状态。\")\n",
"except Exception as e:\n",
"    # Report unexpected errors once (the original accidentally printed this line twice).\n",
"    print(f\"测试 Redis 时发生未知错误: {e}\")"
2025-05-26 21:34:36 +08:00
]
},
{
"cell_type": "code",
2025-06-04 13:50:02 +08:00
"execution_count": null,
2025-05-26 21:34:36 +08:00
"id": "09b1799e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2025-06-01 15:59:29 +08:00
"200\n",
"['vol', 'pct_chg', 'turnover_rate', 'volume_ratio', 'winner_rate', 'cat_senti_mom_vol_spike', 'cat_senti_pre_breakout', 'ts_turnover_rate_acceleration_5_20', 'ts_vol_sustain_10_30', 'cs_amount_outlier_10', 'ts_ff_to_total_turnover_ratio', 'ts_price_volume_trend_coherence_5_20', 'ts_ff_turnover_rate_surge_10', 'undist_profit_ps', 'ocfps', 'AR', 'BR', 'AR_BR', 'log_circ_mv', 'cashflow_to_ev_factor', 'book_to_price_ratio', 'turnover_rate_mean_5', 'variance_20', 'bbi_ratio_factor', 'daily_deviation', 'lg_elg_net_buy_vol', 'flow_lg_elg_intensity', 'sm_net_buy_vol', 'total_buy_vol', 'lg_elg_buy_prop', 'flow_struct_buy_change', 'lg_elg_net_buy_vol_change', 'flow_lg_elg_accel', 'chip_concentration_range', 'chip_skewness', 'floating_chip_proxy', 'cost_support_15pct_change', 'cat_winner_price_zone', 'flow_chip_consistency', 'profit_taking_vs_absorb', 'cat_is_positive', 'upside_vol', 'downside_vol', 'vol_ratio', 'return_skew', 'return_kurtosis', 'volume_change_rate', 'cat_volume_breakout', 'turnover_deviation', 'cat_turnover_spike', 'avg_volume_ratio', 'cat_volume_ratio_breakout', 'vol_spike', 'vol_std_5', 'atr_14', 'atr_6', 'obv', 'maobv_6', 'rsi_3', 'return_5', 'return_20', 'std_return_5', 'std_return_90', 'std_return_90_2', 'act_factor1', 'act_factor2', 'act_factor3', 'act_factor4', 'rank_act_factor1', 'rank_act_factor2', 'rank_act_factor3', 'cov', 'delta_cov', 'alpha_22_improved', 'alpha_003', 'alpha_007', 'alpha_013', 'vol_break', 'weight_roc5', 'smallcap_concentration', 'cost_stability', 'high_cost_break_days', 'liquidity_risk', 'turnover_std', 'mv_volatility', 'volume_growth', 'mv_growth', 'momentum_factor', 'resonance_factor', 'log_close', 'cat_vol_spike', 'up', 'down', 'obv_maobv_6', 'std_return_5_over_std_return_90', 'std_return_90_minus_std_return_90_2', 'cat_af2', 'cat_af3', 'cat_af4', 'act_factor5', 'act_factor6', 'active_buy_volume_large', 'active_buy_volume_big', 'active_buy_volume_small', 'buy_lg_vol_minus_sell_lg_vol', 'buy_elg_vol_minus_sell_elg_vol', 
'ctrl_strength', 'low_cost_dev', 'asymmetry', 'lock_factor', 'cat_vol_break', 'cost_atr_adj', 'cat_golden_resonance', 'mv_turnover_ratio', 'mv_adjusted_volume', 'mv_weighted_turnover', 'nonlinear_mv_volume', 'mv_volume_ratio', 'mv_momentum', 'senti_strong_inflow', 'lg_flow_mom_corr_20_60', 'lg_flow_accel', 'profit_pressure', 'underwater_resistance', 'cost_conc_std_20', 'profit_decay_20', 'vol_amp_loss_20', 'vol_drop_profit_cnt_5', 'lg_flow_vol_interact_20', 'cost_break_confirm_cnt_5', 'atr_norm_channel_pos_14', 'turnover_diff_skew_20', 'lg_sm_flow_diverge_20', 'pullback_strong_20_20', 'vol_wgt_hist_pos_20', 'vol_adj_roc_20', 'cs_rank_net_lg_flow_val', 'cs_rank_elg_buy_ratio', 'cs_rank_rel_profit_margin', 'cs_rank_cost_breadth', 'cs_rank_dist_to_upper_cost', 'cs_rank_winner_rate', 'cs_rank_intraday_range', 'cs_rank_close_pos_in_range', 'cs_rank_pos_in_hist_range', 'cs_rank_vol_x_profit_margin', 'cs_rank_lg_flow_price_concordance', 'cs_rank_turnover_per_winner', 'cs_rank_volume_ratio', 'cs_rank_elg_buy_sell_sm_ratio', 'cs_rank_cost_dist_vol_ratio', 'cs_rank_size', 'industry_obv', 'industry_return_5', 'industry_return_20', 'industry__ema_5', 'industry__ema_13', 'industry__ema_20', 'industry__ema_60', 'industry_act_factor1', 'industry_act_factor2', 'industry_act_factor3', 'industry_act_factor4', 'industry_act_factor5', 'industry_act_factor6', 'industry_rank_act_factor1', 'industry_rank_act_factor2', 'industry_rank_act_factor3', 'industry_return_5_percentile', 'industry_return_20_percentile', '000852.SH_MACD', '000905.SH_MACD', '399006.SZ_MACD', '000852.SH_MACD_hist', '000905.SH_MACD_hist', '399006.SZ_MACD_hist', '000852.SH_RSI', '000905.SH_RSI', '399006.SZ_RSI', '000852.SH_Signal_line', '000905.SH_Signal_line', '399006.SZ_Signal_line', '000852.SH_amount_change_rate', '000905.SH_amount_change_rate', '399006.SZ_amount_change_rate', '000852.SH_amount_mean', '000905.SH_amount_mean', '399006.SZ_amount_mean', '000852.SH_daily_return', '000905.SH_daily_return', 
'399006.SZ_daily_return', '000852.SH_up_ratio_20d', '000905.SH_up_ratio_20d', '399006.SZ_up_ratio_20d', '00085
2025-05-26 21:34:36 +08:00
]
}
],
"source": [
"print(len(feature_columns))\n",
"print(feature_columns)"
]
},
{
"cell_type": "code",
2025-06-04 13:50:02 +08:00
"execution_count": null,
2025-05-26 21:34:36 +08:00
"id": "bceabd1f",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2025-06-04 13:50:02 +08:00
"警告: DataFrame 中没有 'group_id' 列。假设整个 DataFrame 是一个需要排序的组。\n"
]
},
{
"ename": "AttributeError",
"evalue": "`np.asfarray` was removed in the NumPy 2.0 release. Use `np.asarray` with a proper dtype instead.",
"output_type": "error",
"traceback": [
"\u001b[31m---------------------------------------------------------------------------\u001b[39m",
"\u001b[31mAttributeError\u001b[39m Traceback (most recent call last)",
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[105]\u001b[39m\u001b[32m, line 52\u001b[39m\n\u001b[32m 48\u001b[39m avg_ndcg = {k: np.mean(v) \u001b[38;5;28;01mif\u001b[39;00m v \u001b[38;5;28;01melse\u001b[39;00m np.nan \u001b[38;5;28;01mfor\u001b[39;00m k, v \u001b[38;5;129;01min\u001b[39;00m ndcg_scores.items()}\n\u001b[32m 49\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m avg_ndcg\n\u001b[32m---> \u001b[39m\u001b[32m52\u001b[39m ndcg_results_single_group = \u001b[43mcalculate_ndcg\u001b[49m\u001b[43m(\u001b[49m\u001b[43mscore_df\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mscore_col\u001b[49m\u001b[43m=\u001b[49m\u001b[33;43m'\u001b[39;49m\u001b[33;43mscore\u001b[39;49m\u001b[33;43m'\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlabel_col\u001b[49m\u001b[43m=\u001b[49m\u001b[33;43m'\u001b[39;49m\u001b[33;43mlabel\u001b[39;49m\u001b[33;43m'\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mk_values\u001b[49m\u001b[43m=\u001b[49m\u001b[43m[\u001b[49m\u001b[32;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[32;43m3\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[32;43m5\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgroup_id\u001b[49m\u001b[43m=\u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m)\u001b[49m\n\u001b[32m 53\u001b[39m \u001b[38;5;28mprint\u001b[39m(\u001b[33m\"\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[33mNDCG 结果\u001b[39m\u001b[33m\"\u001b[39m)\n\u001b[32m 54\u001b[39m \u001b[38;5;28mprint\u001b[39m(ndcg_results_single_group)\n",
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[105]\u001b[39m\u001b[32m, line 40\u001b[39m, in \u001b[36mcalculate_ndcg\u001b[39m\u001b[34m(df, score_col, label_col, group_id, k_values)\u001b[39m\n\u001b[32m 38\u001b[39m relevant_labels = group_df[label_col].values\n\u001b[32m 39\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m k \u001b[38;5;129;01min\u001b[39;00m k_values:\n\u001b[32m---> \u001b[39m\u001b[32m40\u001b[39m ndcg_scores[\u001b[33mf\u001b[39m\u001b[33m'\u001b[39m\u001b[33mndcg@\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mk\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m'\u001b[39m].append(\u001b[43mndcg_at_k\u001b[49m\u001b[43m(\u001b[49m\u001b[43mrelevant_labels\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mk\u001b[49m\u001b[43m)\u001b[49m)\n\u001b[32m 41\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m 42\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m _, group_df \u001b[38;5;129;01min\u001b[39;00m df.groupby(group_id):\n",
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[105]\u001b[39m\u001b[32m, line 27\u001b[39m, in \u001b[36mcalculate_ndcg.<locals>.ndcg_at_k\u001b[39m\u001b[34m(r, k)\u001b[39m\n\u001b[32m 26\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mndcg_at_k\u001b[39m(r, k):\n\u001b[32m---> \u001b[39m\u001b[32m27\u001b[39m dcg_max = \u001b[43mdcg_at_k\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43msorted\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mr\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mreverse\u001b[49m\u001b[43m=\u001b[49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mk\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 28\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m dcg_max:\n\u001b[32m 29\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[32m0.\u001b[39m\n",
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[105]\u001b[39m\u001b[32m, line 23\u001b[39m, in \u001b[36mcalculate_ndcg.<locals>.dcg_at_k\u001b[39m\u001b[34m(r, k)\u001b[39m\n\u001b[32m 22\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mdcg_at_k\u001b[39m(r, k):\n\u001b[32m---> \u001b[39m\u001b[32m23\u001b[39m r = \u001b[43mnp\u001b[49m\u001b[43m.\u001b[49m\u001b[43masfarray\u001b[49m(r)[:k] \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(r) > \u001b[32m0\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m np.zeros(k)\n\u001b[32m 24\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m np.sum(r / np.log2(np.arange(\u001b[32m2\u001b[39m, r.size + \u001b[32m2\u001b[39m)))\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/miniconda3/envs/stock/lib/python3.13/site-packages/numpy/__init__.py:400\u001b[39m, in \u001b[36m__getattr__\u001b[39m\u001b[34m(attr)\u001b[39m\n\u001b[32m 397\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mAttributeError\u001b[39;00m(__former_attrs__[attr], name=\u001b[38;5;28;01mNone\u001b[39;00m)\n\u001b[32m 399\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m attr \u001b[38;5;129;01min\u001b[39;00m __expired_attributes__:\n\u001b[32m--> \u001b[39m\u001b[32m400\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mAttributeError\u001b[39;00m(\n\u001b[32m 401\u001b[39m \u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33m`np.\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mattr\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m` was removed in the NumPy 2.0 release. \u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m 402\u001b[39m \u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m__expired_attributes__[attr]\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m\"\u001b[39m,\n\u001b[32m 403\u001b[39m name=\u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[32m 404\u001b[39m )\n\u001b[32m 406\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m attr == \u001b[33m\"\u001b[39m\u001b[33mchararray\u001b[39m\u001b[33m\"\u001b[39m:\n\u001b[32m 407\u001b[39m warnings.warn(\n\u001b[32m 408\u001b[39m \u001b[33m\"\u001b[39m\u001b[33m`np.chararray` is deprecated and will be removed from \u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m 409\u001b[39m \u001b[33m\"\u001b[39m\u001b[33mthe main namespace in the future. Use an array with a string \u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m 410\u001b[39m \u001b[33m\"\u001b[39m\u001b[33mor bytes dtype instead.\u001b[39m\u001b[33m\"\u001b[39m, \u001b[38;5;167;01mDeprecationWarning\u001b[39;00m, stacklevel=\u001b[32m2\u001b[39m)\n",
"\u001b[31mAttributeError\u001b[39m: `np.asfarray` was removed in the NumPy 2.0 release. Use `np.asarray` with a proper dtype instead."
2025-05-26 21:34:36 +08:00
]
}
],
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"\n",
"def calculate_ndcg(df: pd.DataFrame, score_col: str, label_col: str, group_id: str = 'trade_date', k_values: list = [1, 3, 5, 10]):\n",
" \"\"\"\n",
" 计算 DataFrame 中 score 列和 label 列的 NDCG 值。\n",
"\n",
" Args:\n",
" df (pd.DataFrame): 包含 score (排序学习预测分数) 和 label (相关性标签) 的 DataFrame。\n",
" 假设每个需要排序的组(例如,每天的股票)在 DataFrame 中是连续的。\n",
" score_col (str): 包含模型预测分数的列名。\n",
" label_col (str): 包含相关性标签的列名。标签值越高表示相关性越高。\n",
" k_values (list): 一个整数列表,表示计算 NDCG 的 top-k 值。\n",
" 例如,[1, 3, 5] 将计算 NDCG@1, NDCG@3 和 NDCG@5。\n",
"\n",
" Returns:\n",
" dict: 一个字典,包含每个 k 值对应的平均 NDCG 值。\n",
" 例如: {'ndcg@1': 0.85, 'ndcg@3': 0.78, 'ndcg@5': 0.72, 'ndcg@10': 0.65}\n",
" \"\"\"\n",
" ndcg_scores = {f'ndcg@{k}': [] for k in k_values}\n",
"\n",
"    def dcg_at_k(r, k):\n",
"        # np.asfarray was removed in NumPy 2.0; np.asarray with an explicit float\n",
"        # dtype is the documented replacement (this cell's traceback shows the AttributeError).\n",
"        r = np.asarray(r, dtype=float)[:k] if len(r) > 0 else np.zeros(k)\n",
"        return np.sum(r / np.log2(np.arange(2, r.size + 2)))\n",
"\n",
" def ndcg_at_k(r, k):\n",
" dcg_max = dcg_at_k(sorted(r, reverse=True), k)\n",
" if not dcg_max:\n",
" return 0.\n",
" return dcg_at_k(r, k) / dcg_max\n",
"\n",
" # 假设 DataFrame 已经按照需要排序的组(例如,'trade_date')进行了分组,\n",
" # 并且每个组内的顺序不重要,我们只需要计算每个组的 NDCG。\n",
" # 如果需要按特定组计算 NDCG请先对 DataFrame 进行分组。\n",
" if group_id not in df.columns:\n",
" print(\"警告: DataFrame 中没有 'group_id' 列。假设整个 DataFrame 是一个需要排序的组。\")\n",
" group_df = df.sort_values(by=score_col, ascending=False)\n",
" relevant_labels = group_df[label_col].values\n",
" for k in k_values:\n",
" ndcg_scores[f'ndcg@{k}'].append(ndcg_at_k(relevant_labels, k))\n",
" else:\n",
" for _, group_df in df.groupby(group_id):\n",
" group_df_sorted = group_df.sort_values(by=score_col, ascending=False)\n",
" relevant_labels = group_df_sorted[label_col].values\n",
" for k in k_values:\n",
" ndcg_scores[f'ndcg@{k}'].append(ndcg_at_k(relevant_labels, k))\n",
"\n",
" avg_ndcg = {k: np.mean(v) if v else np.nan for k, v in ndcg_scores.items()}\n",
" return avg_ndcg\n",
"\n",
"\n",
"ndcg_results_single_group = calculate_ndcg(score_df, score_col='score', label_col='label', k_values=[1, 3, 5], group_id=None)\n",
"print(\"\\nNDCG 结果\")\n",
"print(ndcg_results_single_group)\n"
]
2025-05-29 20:41:18 +08:00
},
{
"cell_type": "code",
2025-06-04 13:50:02 +08:00
"execution_count": null,
2025-05-29 20:41:18 +08:00
"id": "44f64679",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" ts_code trade_date open close high low vol pct_chg \\\n",
"1626778 002652.SZ 2019-01-02 19.59 19.64 19.89 19.28 20196.79 1.03 \n",
"1626779 002652.SZ 2019-01-03 19.74 19.44 19.84 19.33 15731.99 -1.02 \n",
"1626780 002652.SZ 2019-01-04 19.33 19.94 19.99 19.08 21099.93 2.57 \n",
"1626781 002652.SZ 2019-01-07 20.04 21.95 21.95 20.04 83534.19 10.08 \n",
"1626782 002652.SZ 2019-01-08 23.21 21.65 23.87 21.65 149377.97 -1.37 \n",
"... ... ... ... ... ... ... ... ... \n",
"1628321 002652.SZ 2025-05-19 15.05 15.05 15.21 14.80 142648.00 1.69 \n",
"1628322 002652.SZ 2025-05-20 15.11 15.31 15.36 14.85 131075.23 1.73 \n",
"1628323 002652.SZ 2025-05-21 15.51 15.26 15.56 15.11 147109.00 -0.33 \n",
"1628324 002652.SZ 2025-05-22 15.11 14.95 15.46 14.75 129187.20 -2.03 \n",
"1628325 002652.SZ 2025-05-23 14.95 14.70 15.11 14.70 139145.40 -1.67 \n",
"\n",
" amount turnover_rate ... cs_rank_vol_x_profit_margin \\\n",
"1626778 7867.047 0.3964 ... 0.608839 \n",
"1626779 6121.460 0.3088 ... 0.586710 \n",
"1626780 8245.083 0.4141 ... 0.682847 \n",
"1626781 35514.117 1.6394 ... 0.987591 \n",
"1626782 67160.354 2.9317 ... 0.765693 \n",
"... ... ... ... ... \n",
"1628321 42651.655 2.7857 ... 0.758644 \n",
"1628322 39438.290 2.5597 ... 0.834661 \n",
"1628323 44703.816 2.8729 ... 0.365327 \n",
"1628324 38679.608 2.5229 ... 0.810362 \n",
"1628325 41151.946 2.7173 ... 0.738293 \n",
"\n",
" cs_rank_lg_flow_price_concordance cs_rank_turnover_per_winner \\\n",
"1626778 0.203142 0.864865 \n",
"1626779 0.156684 0.763417 \n",
"1626780 0.184009 0.660949 \n",
"1626781 0.734940 0.700000 \n",
"1626782 0.874042 0.914234 \n",
"... ... ... \n",
"1628321 0.106051 0.544548 \n",
"1628322 0.202523 0.478420 \n",
"1628323 0.580870 0.520757 \n",
"1628324 0.808369 0.476918 \n",
"1628325 0.617735 0.404517 \n",
"\n",
" cs_rank_ind_cap_neutral_pe cs_rank_volume_ratio \\\n",
"1626778 NaN 0.646930 \n",
"1626779 NaN 0.251279 \n",
"1626780 NaN 0.311724 \n",
"1626781 NaN 0.988313 \n",
"1626782 NaN 0.990142 \n",
"... ... ... \n",
"1628321 NaN 0.695645 \n",
"1628322 NaN 0.542497 \n",
"1628323 NaN 0.678180 \n",
"1628324 NaN 0.524743 \n",
"1628325 NaN 0.585852 \n",
"\n",
" cs_rank_elg_buy_sell_sm_ratio cs_rank_cost_dist_vol_ratio \\\n",
"1626778 0.341855 0.678941 \n",
"1626779 0.318912 0.402916 \n",
"1626780 0.260036 0.460713 \n",
"1626781 0.796350 0.988501 \n",
"1626782 0.598905 0.991571 \n",
"... ... ... \n",
"1628321 0.287899 0.788896 \n",
"1628322 0.116534 0.705843 \n",
"1628323 0.492860 0.783793 \n",
"1628324 0.130521 0.696446 \n",
"1628325 0.134175 0.735636 \n",
"\n",
" cs_rank_size future_return label \n",
"1626778 0.258948 0.092159 45.0 \n",
"1626779 0.258123 0.075103 41.0 \n",
"1626780 0.257664 0.058175 41.0 \n",
"1626781 0.290146 -0.034169 4.0 \n",
"1626782 0.282482 -0.023095 4.0 \n",
"... ... ... ... \n",
"1628321 0.032912 NaN NaN \n",
"1628322 0.034861 NaN NaN \n",
"1628323 0.035204 NaN NaN \n",
"1628324 0.034208 NaN NaN \n",
"1628325 0.032547 NaN NaN \n",
"\n",
2025-06-01 15:59:29 +08:00
"[1548 rows x 190 columns]\n"
2025-05-29 20:41:18 +08:00
]
}
],
"source": [
"print(df[df['ts_code'] == '002652.SZ'])"
]
2025-05-26 21:34:36 +08:00
}
],
"metadata": {
"kernelspec": {
2025-06-04 13:50:02 +08:00
"display_name": "stock",
2025-05-26 21:34:36 +08:00
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
2025-06-04 13:50:02 +08:00
"version": "3.13.2"
2025-05-26 21:34:36 +08:00
}
},
"nbformat": 4,
"nbformat_minor": 5
}