2472 lines
470 KiB
Plaintext
2472 lines
470 KiB
Plaintext
|
|
{
|
|||
|
|
"cells": [
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"execution_count": 1,
|
|||
|
|
"id": "79a7758178bafdd3",
|
|||
|
|
"metadata": {
|
|||
|
|
"ExecuteTime": {
|
|||
|
|
"end_time": "2025-04-03T12:46:06.987506Z",
|
|||
|
|
"start_time": "2025-04-03T12:46:06.259551Z"
|
|||
|
|
},
|
|||
|
|
"jupyter": {
|
|||
|
|
"source_hidden": true
|
|||
|
|
}
|
|||
|
|
},
|
|||
|
|
"outputs": [
|
|||
|
|
{
|
|||
|
|
"name": "stdout",
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"text": [
|
|||
|
|
"e:\\PyProject\\NewStock\\main\\train\n"
|
|||
|
|
]
|
|||
|
|
}
|
|||
|
|
],
|
|||
|
|
"source": [
|
|||
|
|
"%load_ext autoreload\n",
|
|||
|
|
"%autoreload 2\n",
|
|||
|
|
"\n",
|
|||
|
|
"import gc\n",
|
|||
|
|
"import os\n",
|
|||
|
|
"import sys\n",
|
|||
|
|
"sys.path.append('../../')\n",
|
|||
|
|
"print(os.getcwd())\n",
|
|||
|
|
"import pandas as pd\n",
|
|||
|
|
"from main.factor.factor import get_rolling_factor, get_simple_factor\n",
|
|||
|
|
"from main.utils.factor import read_industry_data\n",
|
|||
|
|
"from main.utils.factor_processor import calculate_score\n",
|
|||
|
|
"from main.utils.utils import read_and_merge_h5_data, merge_with_industry_data\n",
|
|||
|
|
"\n",
|
|||
|
|
"import warnings\n",
|
|||
|
|
"\n",
|
|||
|
|
"warnings.filterwarnings(\"ignore\")"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"execution_count": 2,
|
|||
|
|
"id": "a79cafb06a7e0e43",
|
|||
|
|
"metadata": {
|
|||
|
|
"ExecuteTime": {
|
|||
|
|
"end_time": "2025-04-03T12:47:00.212859Z",
|
|||
|
|
"start_time": "2025-04-03T12:46:06.998047Z"
|
|||
|
|
},
|
|||
|
|
"scrolled": true
|
|||
|
|
},
|
|||
|
|
"outputs": [
|
|||
|
|
{
|
|||
|
|
"name": "stdout",
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"text": [
|
|||
|
|
"daily data\n",
|
|||
|
|
"daily basic\n",
|
|||
|
|
"inner merge on ['ts_code', 'trade_date']\n",
|
|||
|
|
"stk limit\n",
|
|||
|
|
"left merge on ['ts_code', 'trade_date']\n",
|
|||
|
|
"money flow\n",
|
|||
|
|
"left merge on ['ts_code', 'trade_date']\n",
|
|||
|
|
"cyq perf\n",
|
|||
|
|
"left merge on ['ts_code', 'trade_date']\n",
|
|||
|
|
"<class 'pandas.core.frame.DataFrame'>\n",
|
|||
|
|
"RangeIndex: 8606490 entries, 0 to 8606489\n",
|
|||
|
|
"Data columns (total 32 columns):\n",
|
|||
|
|
" # Column Dtype \n",
|
|||
|
|
"--- ------ ----- \n",
|
|||
|
|
" 0 ts_code object \n",
|
|||
|
|
" 1 trade_date datetime64[ns]\n",
|
|||
|
|
" 2 open float64 \n",
|
|||
|
|
" 3 close float64 \n",
|
|||
|
|
" 4 high float64 \n",
|
|||
|
|
" 5 low float64 \n",
|
|||
|
|
" 6 vol float64 \n",
|
|||
|
|
" 7 pct_chg float64 \n",
|
|||
|
|
" 8 turnover_rate float64 \n",
|
|||
|
|
" 9 pe_ttm float64 \n",
|
|||
|
|
" 10 circ_mv float64 \n",
|
|||
|
|
" 11 total_mv float64 \n",
|
|||
|
|
" 12 volume_ratio float64 \n",
|
|||
|
|
" 13 is_st bool \n",
|
|||
|
|
" 14 up_limit float64 \n",
|
|||
|
|
" 15 down_limit float64 \n",
|
|||
|
|
" 16 buy_sm_vol float64 \n",
|
|||
|
|
" 17 sell_sm_vol float64 \n",
|
|||
|
|
" 18 buy_lg_vol float64 \n",
|
|||
|
|
" 19 sell_lg_vol float64 \n",
|
|||
|
|
" 20 buy_elg_vol float64 \n",
|
|||
|
|
" 21 sell_elg_vol float64 \n",
|
|||
|
|
" 22 net_mf_vol float64 \n",
|
|||
|
|
" 23 his_low float64 \n",
|
|||
|
|
" 24 his_high float64 \n",
|
|||
|
|
" 25 cost_5pct float64 \n",
|
|||
|
|
" 26 cost_15pct float64 \n",
|
|||
|
|
" 27 cost_50pct float64 \n",
|
|||
|
|
" 28 cost_85pct float64 \n",
|
|||
|
|
" 29 cost_95pct float64 \n",
|
|||
|
|
" 30 weight_avg float64 \n",
|
|||
|
|
" 31 winner_rate float64 \n",
|
|||
|
|
"dtypes: bool(1), datetime64[ns](1), float64(29), object(1)\n",
|
|||
|
|
"memory usage: 2.0+ GB\n",
|
|||
|
|
"None\n"
|
|||
|
|
]
|
|||
|
|
}
|
|||
|
|
],
|
|||
|
|
"source": [
|
|||
|
|
"from main.utils.utils import read_and_merge_h5_data\n",
|
|||
|
|
"\n",
|
|||
|
|
"print('daily data')\n",
|
|||
|
|
"df = read_and_merge_h5_data('../../data/daily_data.h5', key='daily_data',\n",
|
|||
|
|
" columns=['ts_code', 'trade_date', 'open', 'close', 'high', 'low', 'vol', 'pct_chg'],\n",
|
|||
|
|
" df=None)\n",
|
|||
|
|
"\n",
|
|||
|
|
"print('daily basic')\n",
|
|||
|
|
"df = read_and_merge_h5_data('../../data/daily_basic.h5', key='daily_basic',\n",
|
|||
|
|
" columns=['ts_code', 'trade_date', 'turnover_rate', 'pe_ttm', 'circ_mv', 'total_mv', 'volume_ratio',\n",
|
|||
|
|
" 'is_st'], df=df, join='inner')\n",
|
|||
|
|
"\n",
|
|||
|
|
"print('stk limit')\n",
|
|||
|
|
"df = read_and_merge_h5_data('../../data/stk_limit.h5', key='stk_limit',\n",
|
|||
|
|
" columns=['ts_code', 'trade_date', 'pre_close', 'up_limit', 'down_limit'],\n",
|
|||
|
|
" df=df)\n",
|
|||
|
|
"print('money flow')\n",
|
|||
|
|
"df = read_and_merge_h5_data('../../data/money_flow.h5', key='money_flow',\n",
|
|||
|
|
" columns=['ts_code', 'trade_date', 'buy_sm_vol', 'sell_sm_vol', 'buy_lg_vol', 'sell_lg_vol',\n",
|
|||
|
|
" 'buy_elg_vol', 'sell_elg_vol', 'net_mf_vol'],\n",
|
|||
|
|
" df=df)\n",
|
|||
|
|
"print('cyq perf')\n",
|
|||
|
|
"df = read_and_merge_h5_data('../../data/cyq_perf.h5', key='cyq_perf',\n",
|
|||
|
|
" columns=['ts_code', 'trade_date', 'his_low', 'his_high', 'cost_5pct', 'cost_15pct',\n",
|
|||
|
|
" 'cost_50pct',\n",
|
|||
|
|
" 'cost_85pct', 'cost_95pct', 'weight_avg', 'winner_rate'],\n",
|
|||
|
|
" df=df)\n",
|
|||
|
|
"print(df.info())"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"execution_count": 3,
|
|||
|
|
"id": "cac01788dac10678",
|
|||
|
|
"metadata": {
|
|||
|
|
"ExecuteTime": {
|
|||
|
|
"end_time": "2025-04-03T12:47:10.527104Z",
|
|||
|
|
"start_time": "2025-04-03T12:47:00.488715Z"
|
|||
|
|
}
|
|||
|
|
},
|
|||
|
|
"outputs": [
|
|||
|
|
{
|
|||
|
|
"name": "stdout",
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"text": [
|
|||
|
|
"industry\n"
|
|||
|
|
]
|
|||
|
|
}
|
|||
|
|
],
|
|||
|
|
"source": [
|
|||
|
|
"print('industry')\n",
|
|||
|
|
"industry_df = read_and_merge_h5_data('../../data/industry_data.h5', key='industry_data',\n",
|
|||
|
|
" columns=['ts_code', 'l2_code', 'in_date'],\n",
|
|||
|
|
" df=None, on=['ts_code'], join='left')\n",
|
|||
|
|
"\n",
|
|||
|
|
"\n",
|
|||
|
|
"def merge_with_industry_data(df, industry_df):\n",
|
|||
|
|
" # 确保日期字段是 datetime 类型\n",
|
|||
|
|
" df['trade_date'] = pd.to_datetime(df['trade_date'])\n",
|
|||
|
|
" industry_df['in_date'] = pd.to_datetime(industry_df['in_date'])\n",
|
|||
|
|
"\n",
|
|||
|
|
" # 对 industry_df 按 ts_code 和 in_date 排序\n",
|
|||
|
|
" industry_df_sorted = industry_df.sort_values(['in_date', 'ts_code'])\n",
|
|||
|
|
"\n",
|
|||
|
|
" # 对原始 df 按 ts_code 和 trade_date 排序\n",
|
|||
|
|
" df_sorted = df.sort_values(['trade_date', 'ts_code'])\n",
|
|||
|
|
"\n",
|
|||
|
|
" # 使用 merge_asof 进行向后合并\n",
|
|||
|
|
" merged = pd.merge_asof(\n",
|
|||
|
|
" df_sorted,\n",
|
|||
|
|
" industry_df_sorted,\n",
|
|||
|
|
" by='ts_code', # 按 ts_code 分组\n",
|
|||
|
|
" left_on='trade_date',\n",
|
|||
|
|
" right_on='in_date',\n",
|
|||
|
|
" direction='backward'\n",
|
|||
|
|
" )\n",
|
|||
|
|
"\n",
|
|||
|
|
" # 获取每个 ts_code 的最早 in_date 记录\n",
|
|||
|
|
" min_in_date_per_ts = (industry_df_sorted\n",
|
|||
|
|
" .groupby('ts_code')\n",
|
|||
|
|
" .first()\n",
|
|||
|
|
" .reset_index()[['ts_code', 'l2_code']])\n",
|
|||
|
|
"\n",
|
|||
|
|
" # 填充未匹配到的记录(trade_date 早于所有 in_date 的情况)\n",
|
|||
|
|
" merged['l2_code'] = merged['l2_code'].fillna(\n",
|
|||
|
|
" merged['ts_code'].map(min_in_date_per_ts.set_index('ts_code')['l2_code'])\n",
|
|||
|
|
" )\n",
|
|||
|
|
"\n",
|
|||
|
|
" # 保留需要的列并重置索引\n",
|
|||
|
|
" result = merged.reset_index(drop=True)\n",
|
|||
|
|
" return result\n",
|
|||
|
|
"\n",
|
|||
|
|
"\n",
|
|||
|
|
"# 使用示例\n",
|
|||
|
|
"df = merge_with_industry_data(df, industry_df)\n",
|
|||
|
|
"# print(mdf[mdf['ts_code'] == '600751.SH'][['ts_code', 'trade_date', 'l2_code']])"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"execution_count": 4,
|
|||
|
|
"id": "c4e9e1d31da6dba6",
|
|||
|
|
"metadata": {
|
|||
|
|
"ExecuteTime": {
|
|||
|
|
"end_time": "2025-04-03T12:47:10.719252Z",
|
|||
|
|
"start_time": "2025-04-03T12:47:10.541247Z"
|
|||
|
|
},
|
|||
|
|
"jupyter": {
|
|||
|
|
"source_hidden": true
|
|||
|
|
}
|
|||
|
|
},
|
|||
|
|
"outputs": [],
|
|||
|
|
"source": [
|
|||
|
|
"def calculate_indicators(df):\n",
|
|||
|
|
" \"\"\"\n",
|
|||
|
|
" 计算四个指标:当日涨跌幅、5日移动平均、RSI、MACD。\n",
|
|||
|
|
" \"\"\"\n",
|
|||
|
|
" df = df.sort_values('trade_date')\n",
|
|||
|
|
" df['daily_return'] = (df['close'] - df['pre_close']) / df['pre_close'] * 100\n",
|
|||
|
|
" # df['5_day_ma'] = df['close'].rolling(window=5).mean()\n",
|
|||
|
|
" delta = df['close'].diff()\n",
|
|||
|
|
" gain = delta.where(delta > 0, 0)\n",
|
|||
|
|
" loss = -delta.where(delta < 0, 0)\n",
|
|||
|
|
" avg_gain = gain.rolling(window=14).mean()\n",
|
|||
|
|
" avg_loss = loss.rolling(window=14).mean()\n",
|
|||
|
|
" rs = avg_gain / avg_loss\n",
|
|||
|
|
" df['RSI'] = 100 - (100 / (1 + rs))\n",
|
|||
|
|
"\n",
|
|||
|
|
" # 计算MACD\n",
|
|||
|
|
" ema12 = df['close'].ewm(span=12, adjust=False).mean()\n",
|
|||
|
|
" ema26 = df['close'].ewm(span=26, adjust=False).mean()\n",
|
|||
|
|
" df['MACD'] = ema12 - ema26\n",
|
|||
|
|
" df['Signal_line'] = df['MACD'].ewm(span=9, adjust=False).mean()\n",
|
|||
|
|
" df['MACD_hist'] = df['MACD'] - df['Signal_line']\n",
|
|||
|
|
"\n",
|
|||
|
|
" # 4. 情绪因子1:市场上涨比例(Up Ratio)\n",
|
|||
|
|
" df['up_ratio'] = df['daily_return'].apply(lambda x: 1 if x > 0 else 0)\n",
|
|||
|
|
" df['up_ratio_20d'] = df['up_ratio'].rolling(window=20).mean() # 过去20天上涨比例\n",
|
|||
|
|
"\n",
|
|||
|
|
" # 5. 情绪因子2:成交量变化率(Volume Change Rate)\n",
|
|||
|
|
" df['volume_mean'] = df['vol'].rolling(window=20).mean() # 过去20天的平均成交量\n",
|
|||
|
|
" df['volume_change_rate'] = (df['vol'] - df['volume_mean']) / df['volume_mean'] * 100 # 成交量变化率\n",
|
|||
|
|
"\n",
|
|||
|
|
" # 6. 情绪因子3:波动率(Volatility)\n",
|
|||
|
|
" df['volatility'] = df['daily_return'].rolling(window=20).std() # 过去20天的日收益率标准差\n",
|
|||
|
|
"\n",
|
|||
|
|
" # 7. 情绪因子4:成交额变化率(Amount Change Rate)\n",
|
|||
|
|
" df['amount_mean'] = df['amount'].rolling(window=20).mean() # 过去20天的平均成交额\n",
|
|||
|
|
" df['amount_change_rate'] = (df['amount'] - df['amount_mean']) / df['amount_mean'] * 100 # 成交额变化率\n",
|
|||
|
|
"\n",
|
|||
|
|
" return df\n",
|
|||
|
|
"\n",
|
|||
|
|
"\n",
|
|||
|
|
"def generate_index_indicators(h5_filename):\n",
|
|||
|
|
" df = pd.read_hdf(h5_filename, key='index_data')\n",
|
|||
|
|
" df['trade_date'] = pd.to_datetime(df['trade_date'], format='%Y%m%d')\n",
|
|||
|
|
" df = df.sort_values('trade_date')\n",
|
|||
|
|
"\n",
|
|||
|
|
" # 计算每个ts_code的相关指标\n",
|
|||
|
|
" df_indicators = []\n",
|
|||
|
|
" for ts_code in df['ts_code'].unique():\n",
|
|||
|
|
" df_index = df[df['ts_code'] == ts_code].copy()\n",
|
|||
|
|
" df_index = calculate_indicators(df_index)\n",
|
|||
|
|
" df_indicators.append(df_index)\n",
|
|||
|
|
"\n",
|
|||
|
|
" # 合并所有指数的结果\n",
|
|||
|
|
" df_all_indicators = pd.concat(df_indicators, ignore_index=True)\n",
|
|||
|
|
"\n",
|
|||
|
|
" # 保留trade_date列,并将同一天的数据按ts_code合并成一行\n",
|
|||
|
|
" df_final = df_all_indicators.pivot_table(\n",
|
|||
|
|
" index='trade_date',\n",
|
|||
|
|
" columns='ts_code',\n",
|
|||
|
|
" values=['daily_return', 'RSI', 'MACD', 'Signal_line',\n",
|
|||
|
|
" 'MACD_hist', 'up_ratio_20d', 'volume_change_rate', 'volatility',\n",
|
|||
|
|
" 'amount_change_rate', 'amount_mean'],\n",
|
|||
|
|
" aggfunc='last'\n",
|
|||
|
|
" )\n",
|
|||
|
|
"\n",
|
|||
|
|
" df_final.columns = [f\"{col[1]}_{col[0]}\" for col in df_final.columns]\n",
|
|||
|
|
" df_final = df_final.reset_index()\n",
|
|||
|
|
"\n",
|
|||
|
|
" return df_final\n",
|
|||
|
|
"\n",
|
|||
|
|
"\n",
|
|||
|
|
"# 使用函数\n",
|
|||
|
|
"h5_filename = '../../data/index_data.h5'\n",
|
|||
|
|
"index_data = generate_index_indicators(h5_filename)\n",
|
|||
|
|
"index_data = index_data.dropna()\n"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"execution_count": 5,
|
|||
|
|
"id": "a735bc02ceb4d872",
|
|||
|
|
"metadata": {
|
|||
|
|
"ExecuteTime": {
|
|||
|
|
"end_time": "2025-04-03T12:47:10.821169Z",
|
|||
|
|
"start_time": "2025-04-03T12:47:10.751831Z"
|
|||
|
|
}
|
|||
|
|
},
|
|||
|
|
"outputs": [],
|
|||
|
|
"source": [
|
|||
|
|
"import talib\n",
|
|||
|
|
"import numpy as np"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"execution_count": 6,
|
|||
|
|
"id": "53f86ddc0677a6d7",
|
|||
|
|
"metadata": {
|
|||
|
|
"ExecuteTime": {
|
|||
|
|
"end_time": "2025-04-03T12:47:15.944254Z",
|
|||
|
|
"start_time": "2025-04-03T12:47:10.826179Z"
|
|||
|
|
},
|
|||
|
|
"jupyter": {
|
|||
|
|
"source_hidden": true
|
|||
|
|
},
|
|||
|
|
"scrolled": true
|
|||
|
|
},
|
|||
|
|
"outputs": [],
|
|||
|
|
"source": [
|
|||
|
|
"from main.utils.factor import get_act_factor\n",
|
|||
|
|
"\n",
|
|||
|
|
"\n",
|
|||
|
|
"def read_industry_data(h5_filename):\n",
|
|||
|
|
" # 读取 H5 文件中所有的行业数据\n",
|
|||
|
|
" industry_data = pd.read_hdf(h5_filename, key='sw_daily', columns=[\n",
|
|||
|
|
" 'ts_code', 'trade_date', 'open', 'close', 'high', 'low', 'pe', 'pb', 'vol'\n",
|
|||
|
|
" ]) # 假设 H5 文件的键是 'industry_data'\n",
|
|||
|
|
" industry_data = industry_data.sort_values(by=['ts_code', 'trade_date'])\n",
|
|||
|
|
" industry_data = industry_data.reindex()\n",
|
|||
|
|
" industry_data['trade_date'] = pd.to_datetime(industry_data['trade_date'], format='%Y%m%d')\n",
|
|||
|
|
"\n",
|
|||
|
|
" grouped = industry_data.groupby('ts_code', group_keys=False)\n",
|
|||
|
|
" industry_data['obv'] = grouped.apply(\n",
|
|||
|
|
" lambda x: pd.Series(talib.OBV(x['close'].values, x['vol'].values), index=x.index)\n",
|
|||
|
|
" )\n",
|
|||
|
|
" industry_data['return_5'] = grouped['close'].apply(lambda x: x / x.shift(5) - 1)\n",
|
|||
|
|
" industry_data['return_20'] = grouped['close'].apply(lambda x: x / x.shift(20) - 1)\n",
|
|||
|
|
"\n",
|
|||
|
|
" industry_data = get_act_factor(industry_data, cat=False)\n",
|
|||
|
|
" industry_data = industry_data.sort_values(by=['trade_date', 'ts_code'])\n",
|
|||
|
|
"\n",
|
|||
|
|
" # # 计算每天每个 ts_code 的因子和当天所有 ts_code 的中位数的偏差\n",
|
|||
|
|
" # factor_columns = ['obv', 'return_5', 'return_20', 'act_factor1', 'act_factor2', 'act_factor3', 'act_factor4'] # 因子列\n",
|
|||
|
|
" # \n",
|
|||
|
|
" # for factor in factor_columns:\n",
|
|||
|
|
" # if factor in industry_data.columns:\n",
|
|||
|
|
" # # 计算每天每个 ts_code 的因子值与当天所有 ts_code 的中位数的偏差\n",
|
|||
|
|
" # industry_data[f'{factor}_deviation'] = industry_data.groupby('trade_date')[factor].transform(\n",
|
|||
|
|
" # lambda x: x - x.mean())\n",
|
|||
|
|
"\n",
|
|||
|
|
" industry_data['return_5_percentile'] = industry_data.groupby('trade_date')['return_5'].transform(\n",
|
|||
|
|
" lambda x: x.rank(pct=True))\n",
|
|||
|
|
" industry_data['return_20_percentile'] = industry_data.groupby('trade_date')['return_20'].transform(\n",
|
|||
|
|
" lambda x: x.rank(pct=True))\n",
|
|||
|
|
" industry_data = industry_data.drop(columns=['open', 'close', 'high', 'low', 'pe', 'pb', 'vol'])\n",
|
|||
|
|
"\n",
|
|||
|
|
" industry_data = industry_data.rename(\n",
|
|||
|
|
" columns={col: f'industry_{col}' for col in industry_data.columns if col not in ['ts_code', 'trade_date']})\n",
|
|||
|
|
"\n",
|
|||
|
|
" industry_data = industry_data.rename(columns={'ts_code': 'cat_l2_code'})\n",
|
|||
|
|
" return industry_data\n",
|
|||
|
|
"\n",
|
|||
|
|
"\n",
|
|||
|
|
"industry_df = read_industry_data('../../data/sw_daily.h5')\n"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"execution_count": 7,
|
|||
|
|
"id": "dbe2fd8021b9417f",
|
|||
|
|
"metadata": {
|
|||
|
|
"ExecuteTime": {
|
|||
|
|
"end_time": "2025-04-03T12:47:15.969344Z",
|
|||
|
|
"start_time": "2025-04-03T12:47:15.963327Z"
|
|||
|
|
}
|
|||
|
|
},
|
|||
|
|
"outputs": [
|
|||
|
|
{
|
|||
|
|
"name": "stdout",
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"text": [
|
|||
|
|
"['ts_code', 'open', 'close', 'high', 'low', 'circ_mv', 'total_mv', 'is_st', 'up_limit', 'down_limit', 'buy_sm_vol', 'sell_sm_vol', 'buy_lg_vol', 'sell_lg_vol', 'buy_elg_vol', 'sell_elg_vol', 'net_mf_vol', 'his_low', 'his_high', 'cost_5pct', 'cost_15pct', 'cost_50pct', 'cost_85pct', 'cost_95pct', 'weight_avg', 'in_date']\n"
|
|||
|
|
]
|
|||
|
|
}
|
|||
|
|
],
|
|||
|
|
"source": [
|
|||
|
|
"origin_columns = df.columns.tolist()\n",
|
|||
|
|
"origin_columns = [col for col in origin_columns if\n",
|
|||
|
|
" col not in ['turnover_rate', 'pe_ttm', 'volume_ratio', 'vol', 'pct_chg', 'l2_code', 'winner_rate']]\n",
|
|||
|
|
"origin_columns = [col for col in origin_columns if col not in index_data.columns]\n",
|
|||
|
|
"origin_columns = [col for col in origin_columns if 'cyq' not in col]\n",
|
|||
|
|
"print(origin_columns)"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"execution_count": 8,
|
|||
|
|
"id": "85c3e3d0235ffffa",
|
|||
|
|
"metadata": {
|
|||
|
|
"ExecuteTime": {
|
|||
|
|
"end_time": "2025-04-03T12:47:16.089879Z",
|
|||
|
|
"start_time": "2025-04-03T12:47:15.990101Z"
|
|||
|
|
}
|
|||
|
|
},
|
|||
|
|
"outputs": [],
|
|||
|
|
"source": [
|
|||
|
|
"fina_indicator_df = read_and_merge_h5_data('../../data/fina_indicator.h5', key='fina_indicator',\n",
|
|||
|
|
" columns=['ts_code', 'ann_date', 'undist_profit_ps', 'ocfps', 'bps'],\n",
|
|||
|
|
" df=None)\n",
|
|||
|
|
"cashflow_df = read_and_merge_h5_data('../../data/cashflow.h5', key='cashflow',\n",
|
|||
|
|
" columns=['ts_code', 'ann_date', 'n_cashflow_act'],\n",
|
|||
|
|
" df=None)\n",
|
|||
|
|
"balancesheet_df = read_and_merge_h5_data('../../data/balancesheet.h5', key='balancesheet',\n",
|
|||
|
|
" columns=['ts_code', 'ann_date', 'money_cap', 'total_liab'],\n",
|
|||
|
|
" df=None)"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"execution_count": 9,
|
|||
|
|
"id": "92d84ce15a562ec6",
|
|||
|
|
"metadata": {
|
|||
|
|
"ExecuteTime": {
|
|||
|
|
"end_time": "2025-04-03T13:08:01.612695Z",
|
|||
|
|
"start_time": "2025-04-03T12:47:16.121802Z"
|
|||
|
|
}
|
|||
|
|
},
|
|||
|
|
"outputs": [
|
|||
|
|
{
|
|||
|
|
"name": "stdout",
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"text": [
|
|||
|
|
"使用 'ann_date' 作为财务数据生效日期。\n",
|
|||
|
|
"警告: 从 financial_data_subset 中移除了 366 行,因为其 'ts_code' 或 'ann_date' 列存在空值。\n",
|
|||
|
|
"使用 'ann_date' 作为财务数据生效日期。\n",
|
|||
|
|
"警告: 从 financial_data_subset 中移除了 366 行,因为其 'ts_code' 或 'ann_date' 列存在空值。\n",
|
|||
|
|
"开始计算因子: AR, BR (原地修改)...\n",
|
|||
|
|
"因子 AR, BR 计算成功。\n",
|
|||
|
|
"因子 AR, BR 计算流程结束。\n",
|
|||
|
|
"使用 'ann_date' 作为财务数据生效日期。\n"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"name": "stdout",
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"text": [
|
|||
|
|
"使用 'ann_date' 作为财务数据生效日期。\n",
|
|||
|
|
"使用 'ann_date' 作为财务数据生效日期。\n",
|
|||
|
|
"使用 'ann_date' 作为财务数据生效日期。\n",
|
|||
|
|
"警告: 从 financial_data_subset 中移除了 366 行,因为其 'ts_code' 或 'ann_date' 列存在空值。\n",
|
|||
|
|
"计算 BBI...\n",
|
|||
|
|
"--- 计算日级别偏离度 (使用 pct_chg) ---\n",
|
|||
|
|
"--- 计算日级别动量基准 (使用 pct_chg) ---\n",
|
|||
|
|
"日级别动量基准计算完成 (使用 pct_chg)。\n",
|
|||
|
|
"日级别偏离度计算完成 (使用 pct_chg)。\n",
|
|||
|
|
"--- 计算日级别行业偏离度 (使用 pct_chg 和行业基准) ---\n",
|
|||
|
|
"--- 计算日级别行业动量基准 (使用 pct_chg 和 cat_l2_code) ---\n",
|
|||
|
|
"错误: 计算日级别行业动量基准需要以下列: ['pct_chg', 'cat_l2_code', 'trade_date', 'ts_code']。\n",
|
|||
|
|
"错误: 计算日级别行业偏离度需要以下列: ['pct_chg', 'daily_industry_positive_benchmark', 'daily_industry_negative_benchmark']。请先运行 daily_industry_momentum_benchmark(df)。\n",
|
|||
|
|
"Index(['ts_code', 'trade_date', 'open', 'close', 'high', 'low', 'vol',\n",
|
|||
|
|
" 'pct_chg', 'turnover_rate', 'pe_ttm', 'circ_mv', 'total_mv',\n",
|
|||
|
|
" 'volume_ratio', 'is_st', 'up_limit', 'down_limit', 'buy_sm_vol',\n",
|
|||
|
|
" 'sell_sm_vol', 'buy_lg_vol', 'sell_lg_vol', 'buy_elg_vol',\n",
|
|||
|
|
" 'sell_elg_vol', 'net_mf_vol', 'his_low', 'his_high', 'cost_5pct',\n",
|
|||
|
|
" 'cost_15pct', 'cost_50pct', 'cost_85pct', 'cost_95pct', 'weight_avg',\n",
|
|||
|
|
" 'winner_rate', 'l2_code', 'undist_profit_ps', 'ocfps', 'AR', 'BR',\n",
|
|||
|
|
" 'AR_BR', 'log_circ_mv', 'cashflow_to_ev_factor', 'book_to_price_ratio',\n",
|
|||
|
|
" 'turnover_rate_mean_5', 'variance_20', 'bbi_ratio_factor',\n",
|
|||
|
|
" 'daily_deviation', 'lg_elg_net_buy_vol', 'flow_lg_elg_intensity',\n",
|
|||
|
|
" 'sm_net_buy_vol', 'flow_divergence_diff', 'flow_divergence_ratio',\n",
|
|||
|
|
" 'total_buy_vol', 'lg_elg_buy_prop', 'flow_struct_buy_change',\n",
|
|||
|
|
" 'lg_elg_net_buy_vol_change', 'flow_lg_elg_accel',\n",
|
|||
|
|
" 'chip_concentration_range', 'chip_skewness', 'floating_chip_proxy',\n",
|
|||
|
|
" 'cost_support_15pct_change', 'cat_winner_price_zone',\n",
|
|||
|
|
" 'flow_chip_consistency', 'profit_taking_vs_absorb', '_is_positive',\n",
|
|||
|
|
" '_is_negative', 'cat_is_positive', '_pos_returns', '_neg_returns',\n",
|
|||
|
|
" '_pos_returns_sq', '_neg_returns_sq', 'upside_vol', 'downside_vol',\n",
|
|||
|
|
" 'vol_ratio', 'return_skew', 'return_kurtosis', 'volume_change_rate',\n",
|
|||
|
|
" 'cat_volume_breakout', 'turnover_deviation', 'cat_turnover_spike',\n",
|
|||
|
|
" 'avg_volume_ratio', 'cat_volume_ratio_breakout', 'vol_spike',\n",
|
|||
|
|
" 'vol_std_5', 'atr_14', 'atr_6', 'obv'],\n",
|
|||
|
|
" dtype='object')\n",
|
|||
|
|
"Calculating lg_flow_mom_corr_20_60...\n",
|
|||
|
|
"Finished lg_flow_mom_corr_20_60.\n",
|
|||
|
|
"Calculating lg_flow_accel...\n",
|
|||
|
|
"Finished lg_flow_accel.\n",
|
|||
|
|
"Calculating profit_pressure...\n",
|
|||
|
|
"Finished profit_pressure.\n",
|
|||
|
|
"Calculating underwater_resistance...\n",
|
|||
|
|
"Finished underwater_resistance.\n",
|
|||
|
|
"Calculating cost_conc_std_20...\n",
|
|||
|
|
"Finished cost_conc_std_20.\n",
|
|||
|
|
"Calculating profit_decay_20...\n",
|
|||
|
|
"Finished profit_decay_20.\n",
|
|||
|
|
"Calculating vol_amp_loss_20...\n",
|
|||
|
|
"Finished vol_amp_loss_20.\n",
|
|||
|
|
"Calculating vol_drop_profit_cnt_5...\n",
|
|||
|
|
"Finished vol_drop_profit_cnt_5.\n",
|
|||
|
|
"Calculating lg_flow_vol_interact_20...\n",
|
|||
|
|
"Finished lg_flow_vol_interact_20.\n",
|
|||
|
|
"Calculating cost_break_confirm_cnt_5...\n",
|
|||
|
|
"Finished cost_break_confirm_cnt_5.\n",
|
|||
|
|
"Calculating atr_norm_channel_pos_14...\n",
|
|||
|
|
"Finished atr_norm_channel_pos_14.\n",
|
|||
|
|
"Calculating turnover_diff_skew_20...\n",
|
|||
|
|
"Finished turnover_diff_skew_20.\n",
|
|||
|
|
"Calculating lg_sm_flow_diverge_20...\n",
|
|||
|
|
"Finished lg_sm_flow_diverge_20.\n",
|
|||
|
|
"Calculating pullback_strong_20_20...\n",
|
|||
|
|
"Finished pullback_strong_20_20.\n",
|
|||
|
|
"Calculating vol_wgt_hist_pos_20...\n",
|
|||
|
|
"Finished vol_wgt_hist_pos_20.\n",
|
|||
|
|
"Calculating vol_adj_roc_20...\n",
|
|||
|
|
"Finished vol_adj_roc_20.\n",
|
|||
|
|
"Calculating cs_rank_net_lg_flow_val...\n",
|
|||
|
|
"Finished cs_rank_net_lg_flow_val.\n",
|
|||
|
|
"Calculating cs_rank_flow_divergence...\n",
|
|||
|
|
"Finished cs_rank_flow_divergence.\n",
|
|||
|
|
"Calculating cs_rank_ind_adj_lg_flow...\n",
|
|||
|
|
"Error calculating cs_rank_ind_adj_lg_flow: Missing 'cat_l2_code' column. Assigning NaN.\n",
|
|||
|
|
"Calculating cs_rank_elg_buy_ratio...\n",
|
|||
|
|
"Finished cs_rank_elg_buy_ratio.\n",
|
|||
|
|
"Calculating cs_rank_rel_profit_margin...\n",
|
|||
|
|
"Finished cs_rank_rel_profit_margin.\n",
|
|||
|
|
"Calculating cs_rank_cost_breadth...\n",
|
|||
|
|
"Finished cs_rank_cost_breadth.\n",
|
|||
|
|
"Calculating cs_rank_dist_to_upper_cost...\n",
|
|||
|
|
"Finished cs_rank_dist_to_upper_cost.\n",
|
|||
|
|
"Calculating cs_rank_winner_rate...\n",
|
|||
|
|
"Finished cs_rank_winner_rate.\n",
|
|||
|
|
"Calculating cs_rank_intraday_range...\n",
|
|||
|
|
"Finished cs_rank_intraday_range.\n",
|
|||
|
|
"Calculating cs_rank_close_pos_in_range...\n",
|
|||
|
|
"Finished cs_rank_close_pos_in_range.\n",
|
|||
|
|
"Calculating cs_rank_opening_gap...\n",
|
|||
|
|
"Error calculating cs_rank_opening_gap: Missing 'pre_close' column. Assigning NaN.\n",
|
|||
|
|
"Calculating cs_rank_pos_in_hist_range...\n",
|
|||
|
|
"Finished cs_rank_pos_in_hist_range.\n",
|
|||
|
|
"Calculating cs_rank_vol_x_profit_margin...\n",
|
|||
|
|
"Finished cs_rank_vol_x_profit_margin.\n",
|
|||
|
|
"Calculating cs_rank_lg_flow_price_concordance...\n",
|
|||
|
|
"Finished cs_rank_lg_flow_price_concordance.\n",
|
|||
|
|
"Calculating cs_rank_turnover_per_winner...\n",
|
|||
|
|
"Finished cs_rank_turnover_per_winner.\n",
|
|||
|
|
"Calculating cs_rank_ind_cap_neutral_pe (Placeholder - requires statsmodels)...\n",
|
|||
|
|
"Finished cs_rank_ind_cap_neutral_pe (Placeholder).\n",
|
|||
|
|
"Calculating cs_rank_volume_ratio...\n",
|
|||
|
|
"Finished cs_rank_volume_ratio.\n",
|
|||
|
|
"Calculating cs_rank_elg_buy_sell_sm_ratio...\n",
|
|||
|
|
"Finished cs_rank_elg_buy_sell_sm_ratio.\n",
|
|||
|
|
"Calculating cs_rank_cost_dist_vol_ratio...\n",
|
|||
|
|
"Finished cs_rank_cost_dist_vol_ratio.\n",
|
|||
|
|
"Calculating cs_rank_size...\n",
|
|||
|
|
"Finished cs_rank_size.\n",
|
|||
|
|
"<class 'pandas.core.frame.DataFrame'>\n",
|
|||
|
|
"RangeIndex: 4506576 entries, 0 to 4506575\n",
|
|||
|
|
"Columns: 178 entries, ts_code to cs_rank_size\n",
|
|||
|
|
"dtypes: bool(10), datetime64[ns](1), float64(162), int32(3), object(2)\n",
|
|||
|
|
"memory usage: 5.6+ GB\n",
|
|||
|
|
"None\n",
|
|||
|
|
"['ts_code', 'trade_date', 'open', 'close', 'high', 'low', 'vol', 'pct_chg', 'turnover_rate', 'pe_ttm', 'circ_mv', 'total_mv', 'volume_ratio', 'is_st', 'up_limit', 'down_limit', 'buy_sm_vol', 'sell_sm_vol', 'buy_lg_vol', 'sell_lg_vol', 'buy_elg_vol', 'sell_elg_vol', 'net_mf_vol', 'his_low', 'his_high', 'cost_5pct', 'cost_15pct', 'cost_50pct', 'cost_85pct', 'cost_95pct', 'weight_avg', 'winner_rate', 'cat_l2_code', 'undist_profit_ps', 'ocfps', 'AR', 'BR', 'AR_BR', 'log_circ_mv', 'cashflow_to_ev_factor', 'book_to_price_ratio', 'turnover_rate_mean_5', 'variance_20', 'bbi_ratio_factor', 'daily_deviation', 'lg_elg_net_buy_vol', 'flow_lg_elg_intensity', 'sm_net_buy_vol', 'flow_divergence_diff', 'flow_divergence_ratio', 'total_buy_vol', 'lg_elg_buy_prop', 'flow_struct_buy_change', 'lg_elg_net_buy_vol_change', 'flow_lg_elg_accel', 'chip_concentration_range', 'chip_skewness', 'floating_chip_proxy', 'cost_support_15pct_change', 'cat_winner_price_zone', 'flow_chip_consistency', 'profit_taking_vs_absorb', 'cat_is_positive', 'upside_vol', 'downside_vol', 'vol_ratio', 'return_skew', 'return_kurtosis', 'volume_change_rate', 'cat_volume_breakout', 'turnover_deviation', 'cat_turnover_spike', 'avg_volume_ratio', 'cat_volume_ratio_breakout', 'vol_spike', 'vol_std_5', 'atr_14', 'atr_6', 'obv', 'maobv_6', 'rsi_3', 'return_5', 'return_20', 'std_return_5', 'std_return_90', 'std_return_90_2', 'act_factor1', 'act_factor2', 'act_factor3', 'act_factor4', 'rank_act_factor1', 'rank_act_factor2', 'rank_act_factor3', 'cov', 'delta_cov', 'alpha_22_improved', 'alpha_003', 'alpha_007', 'alpha_013', 'vol_break', 'weight_roc5', 'price_cost_divergence', 'smallcap_concentration', 'cost_stability', 'high_cost_break_days', 'liquidity_risk', 'turnover_std', 'mv_volatility', 'volume_growth', 'mv_growth', 'momentum_factor', 'resonance_factor', 'log_close', 'cat_vol_spike', 'up', 'down', 'obv_maobv_6', 'std_return_5_over_std_return_90', 'std_return_90_minus_std_return_90_2', 'cat_af2', 'cat_af3', 'cat_af4', 'act_factor5', 'act_factor6', 'active_buy_volume_large', 'active_buy_volume_big', 'active_buy_volume_small', 'buy_lg_vol_minus_sell_lg_vol', 'buy_elg_vol_minus_sell_elg_vol', 'ctrl_strength', 'low_cost_dev', 'asymmetry', 'lock_factor', 'cat_vol_break', 'cost_atr_adj', 'cat_golden_resonance', 'mv_turnover_ratio', 'mv_adjusted_volume', 'mv_weighted_turnover', 'nonlinear_mv_volume', 'mv_volume_ratio', 'mv_momentum', 'lg_flow_mom_corr_20_60', 'lg_flow_accel', 'profit_pressure', 'underwater_resistance', 'cost_conc_std_20', 'profit_decay_20', 'vol_amp_loss_20', 'vol_drop_profit_cnt_5', 'lg_flow_vol_interact_20', 'cost_break_confirm_cnt_5', 'atr_norm_channel_pos_14', 'turnover_diff_skew_20', 'lg_sm_flow_diverge_20', 'pullback_strong_20_20', 'vol_wgt_hist_pos_20', 'vol_adj_roc_20', 'cs_rank_net_lg_flow_val', 'cs_rank_flow_divergence', 'cs_rank_ind_adj_lg_flow', 'cs_rank_elg_buy_ratio', 'cs_rank_rel_profit_margin', 'cs_rank_cost_breadth', 'cs_rank_dist_to_upper_cost', 'cs_rank_winner_rate', 'cs_rank_intraday_range', 'cs_rank_close_pos_in_range', 'cs_rank_opening_gap', 'cs_rank_pos_in_hist_range', 'cs_rank_vol_x_profit_margin', 'cs_rank_lg_flow_price_concordance', 'cs_rank_turnover_per_winner', 'cs_rank_ind_cap_neutral_pe', 'cs_rank_volume_ratio', 'cs_rank_elg_buy_sell_sm_ratio', 'cs_rank_cost_dist_vol_ratio', 'cs_rank_size']\n"
|
|||
|
|
]
|
|||
|
|
}
|
|||
|
|
],
|
|||
|
|
"source": [
|
|||
|
|
"\n",
|
|||
|
|
"import numpy as np\n",
|
|||
|
|
"from main.factor.factor import *\n",
|
|||
|
|
"\n",
|
|||
|
|
"def filter_data(df):\n",
|
|||
|
|
" # df = df.groupby('trade_date').apply(lambda x: x.nlargest(1000, 'act_factor1'))\n",
|
|||
|
|
" df = df[~df['is_st']]\n",
|
|||
|
|
" df = df[~df['ts_code'].str.endswith('BJ')]\n",
|
|||
|
|
" df = df[~df['ts_code'].str.startswith('30')]\n",
|
|||
|
|
" df = df[~df['ts_code'].str.startswith('68')]\n",
|
|||
|
|
" df = df[~df['ts_code'].str.startswith('8')]\n",
|
|||
|
|
" df = df[df['trade_date'] >= '2019-01-01']\n",
|
|||
|
|
" if 'in_date' in df.columns:\n",
|
|||
|
|
" df = df.drop(columns=['in_date'])\n",
|
|||
|
|
" df = df.reset_index(drop=True)\n",
|
|||
|
|
" return df\n",
|
|||
|
|
"\n",
|
|||
|
|
"gc.collect()\n",
|
|||
|
|
"\n",
|
|||
|
|
"df = filter_data(df)\n",
|
|||
|
|
"df = df.sort_values(by=['ts_code', 'trade_date'])\n",
|
|||
|
|
"df = add_financial_factor(df, fina_indicator_df, factor_value_col='undist_profit_ps')\n",
|
|||
|
|
"df = add_financial_factor(df, fina_indicator_df, factor_value_col='ocfps')\n",
|
|||
|
|
"calculate_arbr(df, N=26)\n",
|
|||
|
|
"df['log_circ_mv'] = np.log(df['circ_mv'])\n",
|
|||
|
|
"df = calculate_cashflow_to_ev_factor(df, cashflow_df, balancesheet_df)\n",
|
|||
|
|
"df = caculate_book_to_price_ratio(df, fina_indicator_df)\n",
|
|||
|
|
"df = turnover_rate_n(df, n=5)\n",
|
|||
|
|
"df = variance_n(df, n=20)\n",
|
|||
|
|
"df = bbi_ratio_factor(df)\n",
|
|||
|
|
"df = daily_deviation(df)\n",
|
|||
|
|
"df = daily_industry_deviation(df)\n",
|
|||
|
|
"df, _ = get_rolling_factor(df)\n",
|
|||
|
|
"df, _ = get_simple_factor(df)\n",
|
|||
|
|
"\n",
|
|||
|
|
"lg_flow_mom_corr(df, N=20, M=60)\n",
|
|||
|
|
"lg_flow_accel(df)\n",
|
|||
|
|
"profit_pressure(df)\n",
|
|||
|
|
"underwater_resistance(df)\n",
|
|||
|
|
"cost_conc_std(df, N=20)\n",
|
|||
|
|
"profit_decay(df, N=20)\n",
|
|||
|
|
"vol_amp_loss(df, N=20)\n",
|
|||
|
|
"vol_drop_profit_cnt(df, N=20, M=5)\n",
|
|||
|
|
"lg_flow_vol_interact(df, N=20)\n",
|
|||
|
|
"cost_break_confirm_cnt(df, M=5)\n",
|
|||
|
|
"atr_norm_channel_pos(df, N=14)\n",
|
|||
|
|
"turnover_diff_skew(df, N=20)\n",
|
|||
|
|
"lg_sm_flow_diverge(df, N=20)\n",
|
|||
|
|
"pullback_strong(df, N=20, M=20)\n",
|
|||
|
|
"vol_wgt_hist_pos(df, N=20)\n",
|
|||
|
|
"vol_adj_roc(df, N=20)\n",
|
|||
|
|
"\n",
|
|||
|
|
"cs_rank_net_lg_flow_val(df)\n",
|
|||
|
|
"cs_rank_flow_divergence(df)\n",
|
|||
|
|
"cs_rank_industry_adj_lg_flow(df) # Needs cat_l2_code\n",
|
|||
|
|
"cs_rank_elg_buy_ratio(df)\n",
|
|||
|
|
"cs_rank_rel_profit_margin(df)\n",
|
|||
|
|
"cs_rank_cost_breadth(df)\n",
|
|||
|
|
"cs_rank_dist_to_upper_cost(df)\n",
|
|||
|
|
"cs_rank_winner_rate(df)\n",
|
|||
|
|
"cs_rank_intraday_range(df)\n",
|
|||
|
|
"cs_rank_close_pos_in_range(df)\n",
|
|||
|
|
"cs_rank_opening_gap(df) # Needs pre_close\n",
|
|||
|
|
"cs_rank_pos_in_hist_range(df) # Needs his_low, his_high\n",
|
|||
|
|
"cs_rank_vol_x_profit_margin(df)\n",
|
|||
|
|
"cs_rank_lg_flow_price_concordance(df)\n",
|
|||
|
|
"cs_rank_turnover_per_winner(df)\n",
|
|||
|
|
"cs_rank_ind_cap_neutral_pe(df) # Placeholder - needs external libraries\n",
|
|||
|
|
"cs_rank_volume_ratio(df) # Needs volume_ratio\n",
|
|||
|
|
"cs_rank_elg_buy_sell_sm_ratio(df)\n",
|
|||
|
|
"cs_rank_cost_dist_vol_ratio(df) # Needs volume_ratio\n",
|
|||
|
|
"cs_rank_size(df) # Needs circ_mv\n",
|
|||
|
|
"\n",
|
|||
|
|
"df = df.rename(columns={'l1_code': 'cat_l1_code'})\n",
|
|||
|
|
"df = df.rename(columns={'l2_code': 'cat_l2_code'})\n",
|
|||
|
|
"\n",
|
|||
|
|
"# df = df.merge(index_data, on='trade_date', how='left')\n",
|
|||
|
|
"\n",
|
|||
|
|
"print(df.info())\n",
|
|||
|
|
"print(df.columns.tolist())"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"execution_count": 10,
|
|||
|
|
"id": "b87b938028afa206",
|
|||
|
|
"metadata": {
|
|||
|
|
"ExecuteTime": {
|
|||
|
|
"end_time": "2025-04-03T13:08:03.658725Z",
|
|||
|
|
"start_time": "2025-04-03T13:08:02.469611Z"
|
|||
|
|
}
|
|||
|
|
},
|
|||
|
|
"outputs": [],
|
|||
|
|
"source": [
|
|||
|
|
"from scipy.stats import ks_2samp, wasserstein_distance\n",
|
|||
|
|
"\n",
|
|||
|
|
"\n",
|
|||
|
|
"def remove_shifted_features(train_data, test_data, feature_columns, ks_threshold=0.05, wasserstein_threshold=0.1,\n",
|
|||
|
|
" importance_threshold=0.05):\n",
|
|||
|
|
" dropped_features = []\n",
|
|||
|
|
"\n",
|
|||
|
|
" # **统计数据漂移**\n",
|
|||
|
|
" numeric_columns = train_data.select_dtypes(include=['float64', 'int64']).columns\n",
|
|||
|
|
" numeric_columns = [col for col in numeric_columns if col in feature_columns]\n",
|
|||
|
|
" for feature in numeric_columns:\n",
|
|||
|
|
" ks_stat, p_value = ks_2samp(train_data[feature], test_data[feature])\n",
|
|||
|
|
" wasserstein_dist = wasserstein_distance(train_data[feature], test_data[feature])\n",
|
|||
|
|
"\n",
|
|||
|
|
" if p_value < ks_threshold or wasserstein_dist > wasserstein_threshold:\n",
|
|||
|
|
" dropped_features.append(feature)\n",
|
|||
|
|
"\n",
|
|||
|
|
" print(f\"检测到 {len(dropped_features)} 个可能漂移的特征: {dropped_features}\")\n",
|
|||
|
|
"\n",
|
|||
|
|
" # **应用阈值进行最终筛选**\n",
|
|||
|
|
" filtered_features = [f for f in feature_columns if f not in dropped_features]\n",
|
|||
|
|
"\n",
|
|||
|
|
" return filtered_features, dropped_features\n",
|
|||
|
|
"\n"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"execution_count": 11,
|
|||
|
|
"id": "f4f16d63ad18d1bc",
|
|||
|
|
"metadata": {
|
|||
|
|
"ExecuteTime": {
|
|||
|
|
"end_time": "2025-04-03T13:08:03.670700Z",
|
|||
|
|
"start_time": "2025-04-03T13:08:03.665739Z"
|
|||
|
|
}
|
|||
|
|
},
|
|||
|
|
"outputs": [],
|
|||
|
|
"source": [
|
|||
|
|
"import pandas as pd\n",
|
|||
|
|
"import numpy as np\n",
|
|||
|
|
"import statsmodels.api as sm # 用于中性化回归\n",
|
|||
|
|
"from tqdm import tqdm # 可选,用于显示进度条\n",
|
|||
|
|
"\n",
|
|||
|
|
"# --- 常量 ---\n",
|
|||
|
|
"epsilon = 1e-10 # 防止除零\n",
|
|||
|
|
"\n",
|
|||
|
|
"# --- 1. 中位数去极值 (MAD) ---\n",
|
|||
|
|
"\n",
|
|||
|
|
"def cs_mad_filter(df: pd.DataFrame,\n",
|
|||
|
|
" features: list,\n",
|
|||
|
|
" k: float = 3.0,\n",
|
|||
|
|
" scale_factor: float = 1.4826):\n",
|
|||
|
|
" \"\"\"\n",
|
|||
|
|
" 对指定特征列进行截面 MAD 去极值处理 (原地修改)。\n",
|
|||
|
|
"\n",
|
|||
|
|
" 方法: 对每日截面数据,计算 median 和 MAD,\n",
|
|||
|
|
" 将超出 [median - k * scale * MAD, median + k * scale * MAD] 范围的值\n",
|
|||
|
|
" 替换为边界值 (Winsorization)。\n",
|
|||
|
|
" scale_factor=1.4826 使得 MAD 约等于正态分布的标准差。\n",
|
|||
|
|
"\n",
|
|||
|
|
" Args:\n",
|
|||
|
|
" df (pd.DataFrame): 输入 DataFrame,需包含 'trade_date' 和 features 列。\n",
|
|||
|
|
" features (list): 需要处理的特征列名列表。\n",
|
|||
|
|
" k (float): MAD 的倍数,用于确定边界。默认为 3.0。\n",
|
|||
|
|
" scale_factor (float): MAD 的缩放因子。默认为 1.4826。\n",
|
|||
|
|
"\n",
|
|||
|
|
" WARNING: 此函数会原地修改输入的 DataFrame 'df'。\n",
|
|||
|
|
" \"\"\"\n",
|
|||
|
|
" print(f\"开始截面 MAD 去极值处理 (k={k})...\")\n",
|
|||
|
|
" if not all(col in df.columns for col in features):\n",
|
|||
|
|
" missing = [col for col in features if col not in df.columns]\n",
|
|||
|
|
" print(f\"错误: DataFrame 中缺少以下特征列: {missing}。跳过去极值处理。\")\n",
|
|||
|
|
" return\n",
|
|||
|
|
"\n",
|
|||
|
|
" grouped = df.groupby('trade_date')\n",
|
|||
|
|
"\n",
|
|||
|
|
" for col in tqdm(features, desc=\"MAD Filtering\"):\n",
|
|||
|
|
" try:\n",
|
|||
|
|
" # 计算截面中位数\n",
|
|||
|
|
" median = grouped[col].transform('median')\n",
|
|||
|
|
" # 计算截面 MAD (Median Absolute Deviation from Median)\n",
|
|||
|
|
" mad = (df[col] - median).abs().groupby(df['trade_date']).transform('median')\n",
|
|||
|
|
"\n",
|
|||
|
|
" # 计算上下边界\n",
|
|||
|
|
" lower_bound = median - k * scale_factor * mad\n",
|
|||
|
|
" upper_bound = median + k * scale_factor * mad\n",
|
|||
|
|
"\n",
|
|||
|
|
" # 原地应用 clip\n",
|
|||
|
|
" df[col] = np.clip(df[col], lower_bound, upper_bound)\n",
|
|||
|
|
"\n",
|
|||
|
|
" except KeyError:\n",
|
|||
|
|
" print(f\"警告: 列 '{col}' 可能不存在或在分组中出错,跳过此列的 MAD 处理。\")\n",
|
|||
|
|
" except Exception as e:\n",
|
|||
|
|
" print(f\"警告: 处理列 '{col}' 时发生错误: {e},跳过此列的 MAD 处理。\")\n",
|
|||
|
|
"\n",
|
|||
|
|
" print(\"截面 MAD 去极值处理完成。\")\n",
|
|||
|
|
"\n",
|
|||
|
|
"\n",
|
|||
|
|
"# --- 2. 行业市值中性化 ---\n",
|
|||
|
|
"\n",
|
|||
|
|
"def cs_neutralize_industry_cap(df: pd.DataFrame,\n",
|
|||
|
|
" features: list,\n",
|
|||
|
|
" industry_col: str = 'cat_l2_code',\n",
|
|||
|
|
" market_cap_col: str = 'circ_mv'):\n",
|
|||
|
|
" \"\"\"\n",
|
|||
|
|
" 对指定特征列进行截面行业和对数市值中性化 (原地修改)。\n",
|
|||
|
|
" 使用 OLS 回归: feature ~ 1 + log(market_cap) + C(industry)\n",
|
|||
|
|
" 将回归残差写回原特征列。\n",
|
|||
|
|
"\n",
|
|||
|
|
" Args:\n",
|
|||
|
|
" df (pd.DataFrame): 输入 DataFrame,需包含 'trade_date', features 列,\n",
|
|||
|
|
" industry_col, market_cap_col。\n",
|
|||
|
|
" features (list): 需要处理的特征列名列表。\n",
|
|||
|
|
" industry_col (str): 行业分类列名。\n",
|
|||
|
|
" market_cap_col (str): 流通市值列名。\n",
|
|||
|
|
"\n",
|
|||
|
|
" WARNING: 此函数会原地修改输入的 DataFrame 'df' 的 features 列。\n",
|
|||
|
|
" 计算量较大,可能耗时较长。\n",
|
|||
|
|
" 需要安装 statsmodels 库 (pip install statsmodels)。\n",
|
|||
|
|
" \"\"\"\n",
|
|||
|
|
" print(\"开始截面行业市值中性化...\")\n",
|
|||
|
|
" required_cols = features + ['trade_date', industry_col, market_cap_col]\n",
|
|||
|
|
" if not all(col in df.columns for col in required_cols):\n",
|
|||
|
|
" missing = [col for col in required_cols if col not in df.columns]\n",
|
|||
|
|
" print(f\"错误: DataFrame 中缺少必需列: {missing}。无法进行中性化。\")\n",
|
|||
|
|
" return\n",
|
|||
|
|
"\n",
|
|||
|
|
" # 预处理:计算 log 市值,处理 industry code 可能的 NaN\n",
|
|||
|
|
" log_cap_col = '_log_market_cap'\n",
|
|||
|
|
" df[log_cap_col] = np.log1p(df[market_cap_col]) # log1p 处理 0 值\n",
|
|||
|
|
" # df[industry_col] = df[industry_col].cat.add_categories('UnknownIndustry')\n",
|
|||
|
|
" # df[industry_col] = df[industry_col].fillna('UnknownIndustry') # 填充行业 NaN\n",
|
|||
|
|
" # df[industry_col] = df[industry_col].astype('category') # 转为类别,ols 会自动处理\n",
|
|||
|
|
"\n",
|
|||
|
|
" dates = df['trade_date'].unique()\n",
|
|||
|
|
" all_residuals = [] # 用于收集所有日期的残差\n",
|
|||
|
|
"\n",
|
|||
|
|
" for date in tqdm(dates, desc=\"Neutralizing\"):\n",
|
|||
|
|
" daily_data = df.loc[df['trade_date'] == date, features + [log_cap_col, industry_col]].copy() # 使用 .loc 获取副本\n",
|
|||
|
|
"\n",
|
|||
|
|
" # 准备自变量 X (常数项 + log市值 + 行业哑变量)\n",
|
|||
|
|
" X = daily_data[[log_cap_col]]\n",
|
|||
|
|
" X = sm.add_constant(X, prepend=True) # 添加常数项\n",
|
|||
|
|
" # 创建行业哑变量 (drop_first=True 避免共线性)\n",
|
|||
|
|
" industry_dummies = pd.get_dummies(daily_data[industry_col], prefix=industry_col, drop_first=True)\n",
|
|||
|
|
" industry_dummies = industry_dummies.astype(int)\n",
|
|||
|
|
" X = pd.concat([X, industry_dummies], axis=1)\n",
|
|||
|
|
"\n",
|
|||
|
|
" daily_residuals = daily_data[[col for col in features]].copy() # 创建用于存储残差的df\n",
|
|||
|
|
"\n",
|
|||
|
|
" for col in features:\n",
|
|||
|
|
" Y = daily_data[col]\n",
|
|||
|
|
"\n",
|
|||
|
|
" # 处理 NaN 值,确保 X 和 Y 在相同位置有有效值\n",
|
|||
|
|
" valid_mask = Y.notna() & X.notna().all(axis=1)\n",
|
|||
|
|
" if valid_mask.sum() < (X.shape[1] + 1): # 数据点不足以估计模型\n",
|
|||
|
|
" print(f\"警告: 日期 {date}, 特征 {col} 有效数据不足 ({valid_mask.sum()}个),无法中性化,填充 NaN。\")\n",
|
|||
|
|
" daily_residuals[col] = np.nan\n",
|
|||
|
|
" continue\n",
|
|||
|
|
"\n",
|
|||
|
|
" Y_valid = Y[valid_mask]\n",
|
|||
|
|
" X_valid = X[valid_mask]\n",
|
|||
|
|
"\n",
|
|||
|
|
" # 执行 OLS 回归\n",
|
|||
|
|
" try:\n",
|
|||
|
|
" model = sm.OLS(Y_valid.to_numpy(), X_valid.to_numpy())\n",
|
|||
|
|
" results = model.fit()\n",
|
|||
|
|
" # 将残差填回对应位置\n",
|
|||
|
|
" daily_residuals.loc[valid_mask, col] = results.resid\n",
|
|||
|
|
" daily_residuals.loc[~valid_mask, col] = np.nan # 原本无效的位置填充 NaN\n",
|
|||
|
|
" except Exception as e:\n",
|
|||
|
|
" print(f\"警告: 日期 {date}, 特征 {col} 回归失败: {e},填充 NaN。\")\n",
|
|||
|
|
" daily_residuals[col] = np.nan\n",
|
|||
|
|
" break\n",
|
|||
|
|
"\n",
|
|||
|
|
" all_residuals.append(daily_residuals)\n",
|
|||
|
|
"\n",
|
|||
|
|
" # 合并所有日期的残差结果\n",
|
|||
|
|
" if all_residuals:\n",
|
|||
|
|
" residuals_df = pd.concat(all_residuals)\n",
|
|||
|
|
" # 将残差结果更新回原始 df (原地修改)\n",
|
|||
|
|
" # 使用 update 比 merge 更适合基于索引的原地更新\n",
|
|||
|
|
" # 确保 residuals_df 的索引与 df 中对应部分一致\n",
|
|||
|
|
" df.update(residuals_df)\n",
|
|||
|
|
" else:\n",
|
|||
|
|
" print(\"没有有效的残差结果可以合并。\")\n",
|
|||
|
|
"\n",
|
|||
|
|
"\n",
|
|||
|
|
" # 清理临时列\n",
|
|||
|
|
" df.drop(columns=[log_cap_col], inplace=True)\n",
|
|||
|
|
" print(\"截面行业市值中性化完成。\")\n",
|
|||
|
|
"\n",
|
|||
|
|
"\n",
|
|||
|
|
"# --- 3. Z-Score 标准化 ---\n",
|
|||
|
|
"\n",
|
|||
|
|
"def cs_zscore_standardize(df: pd.DataFrame, features: list, epsilon: float = 1e-10):\n",
|
|||
|
|
" \"\"\"\n",
|
|||
|
|
" 对指定特征列进行截面 Z-Score 标准化 (原地修改)。\n",
|
|||
|
|
" 方法: Z = (value - cross_sectional_mean) / (cross_sectional_std + epsilon)\n",
|
|||
|
|
"\n",
|
|||
|
|
" Args:\n",
|
|||
|
|
" df (pd.DataFrame): 输入 DataFrame,需包含 'trade_date' 和 features 列。\n",
|
|||
|
|
" features (list): 需要处理的特征列名列表。\n",
|
|||
|
|
" epsilon (float): 防止除以零的小常数。\n",
|
|||
|
|
"\n",
|
|||
|
|
" WARNING: 此函数会原地修改输入的 DataFrame 'df'。\n",
|
|||
|
|
" \"\"\"\n",
|
|||
|
|
" print(\"开始截面 Z-Score 标准化...\")\n",
|
|||
|
|
" if not all(col in df.columns for col in features):\n",
|
|||
|
|
" missing = [col for col in features if col not in df.columns]\n",
|
|||
|
|
" print(f\"错误: DataFrame 中缺少以下特征列: {missing}。跳过标准化处理。\")\n",
|
|||
|
|
" return\n",
|
|||
|
|
"\n",
|
|||
|
|
" grouped = df.groupby('trade_date')\n",
|
|||
|
|
"\n",
|
|||
|
|
" for col in tqdm(features, desc=\"Standardizing\"):\n",
|
|||
|
|
" try:\n",
|
|||
|
|
" # 使用 transform 计算截面均值和标准差\n",
|
|||
|
|
" mean = grouped[col].transform('mean')\n",
|
|||
|
|
" std = grouped[col].transform('std')\n",
|
|||
|
|
"\n",
|
|||
|
|
" # 计算 Z-Score 并原地赋值\n",
|
|||
|
|
" df[col] = (df[col] - mean) / (std + epsilon)\n",
|
|||
|
|
"\n",
|
|||
|
|
" except KeyError:\n",
|
|||
|
|
" print(f\"警告: 列 '{col}' 可能不存在或在分组中出错,跳过此列的标准化处理。\")\n",
|
|||
|
|
" except Exception as e:\n",
|
|||
|
|
" print(f\"警告: 处理列 '{col}' 时发生错误: {e},跳过此列的标准化处理。\")\n",
|
|||
|
|
"\n",
|
|||
|
|
" print(\"截面 Z-Score 标准化完成。\")\n",
|
|||
|
|
"\n",
|
|||
|
|
"def fill_nan_with_daily_median(df: pd.DataFrame, feature_columns: list[str]) -> pd.DataFrame:\n",
|
|||
|
|
" \"\"\"\n",
|
|||
|
|
" 对指定特征列进行每日截面中位数填充缺失值 (NaN)。\n",
|
|||
|
|
"\n",
|
|||
|
|
" 参数:\n",
|
|||
|
|
" df (pd.DataFrame): 包含多日数据的DataFrame,需要包含 'trade_date' 和 feature_columns 中的列。\n",
|
|||
|
|
" feature_columns (list[str]): 需要进行缺失值填充的特征列名称列表。\n",
|
|||
|
|
"\n",
|
|||
|
|
" 返回:\n",
|
|||
|
|
" pd.DataFrame: 包含缺失值填充后特征列的DataFrame。在输入DataFrame的副本上操作。\n",
|
|||
|
|
" \"\"\"\n",
|
|||
|
|
" processed_df = df.copy() # 在副本上操作,保留原始数据\n",
|
|||
|
|
"\n",
|
|||
|
|
" # 确保 trade_date 是 datetime 类型以便正确分组\n",
|
|||
|
|
" processed_df['trade_date'] = pd.to_datetime(processed_df['trade_date'])\n",
|
|||
|
|
"\n",
|
|||
|
|
" def _fill_daily_nan(group):\n",
|
|||
|
|
" # group 是某一个交易日的 DataFrame\n",
|
|||
|
|
"\n",
|
|||
|
|
" # 遍历指定的特征列\n",
|
|||
|
|
" for feature_col in feature_columns:\n",
|
|||
|
|
" # 检查列是否存在于当前分组中\n",
|
|||
|
|
" if feature_col in group.columns:\n",
|
|||
|
|
" # 计算当日该特征的中位数\n",
|
|||
|
|
" median_val = group[feature_col].median()\n",
|
|||
|
|
"\n",
|
|||
|
|
" # 使用当日中位数填充该特征列的 NaN 值\n",
|
|||
|
|
" # inplace=True 会直接修改 group DataFrame\n",
|
|||
|
|
" group[feature_col].fillna(median_val, inplace=True)\n",
|
|||
|
|
" # else:\n",
|
|||
|
|
" # print(f\"Warning: Feature column '{feature_col}' not found in daily group for {group['trade_date'].iloc[0]}. Skipping.\")\n",
|
|||
|
|
"\n",
|
|||
|
|
" return group\n",
|
|||
|
|
"\n",
|
|||
|
|
" # 按交易日期分组,并应用每日填充函数\n",
|
|||
|
|
" # group_keys=False 避免将分组键添加到结果索引中\n",
|
|||
|
|
" filled_df = processed_df.groupby('trade_date', group_keys=False).apply(_fill_daily_nan)\n",
|
|||
|
|
"\n",
|
|||
|
|
" return filled_df"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"execution_count": 12,
|
|||
|
|
"id": "40e6b68a91b30c79",
|
|||
|
|
"metadata": {
|
|||
|
|
"ExecuteTime": {
|
|||
|
|
"end_time": "2025-04-03T13:08:04.694262Z",
|
|||
|
|
"start_time": "2025-04-03T13:08:03.694904Z"
|
|||
|
|
}
|
|||
|
|
},
|
|||
|
|
"outputs": [],
|
|||
|
|
"source": [
|
|||
|
|
"import pandas as pd\n",
|
|||
|
|
"\n",
|
|||
|
|
"\n",
|
|||
|
|
"def remove_outliers_label_percentile(label: pd.Series, lower_percentile: float = 0.01, upper_percentile: float = 0.99,\n",
|
|||
|
|
" log=True):\n",
|
|||
|
|
" if not (0 <= lower_percentile < upper_percentile <= 1):\n",
|
|||
|
|
" raise ValueError(\"Percentile values must satisfy 0 <= lower_percentile < upper_percentile <= 1.\")\n",
|
|||
|
|
"\n",
|
|||
|
|
" # Calculate lower and upper bounds based on percentiles\n",
|
|||
|
|
" lower_bound = label.quantile(lower_percentile)\n",
|
|||
|
|
" upper_bound = label.quantile(upper_percentile)\n",
|
|||
|
|
"\n",
|
|||
|
|
" # Filter out values outside the bounds\n",
|
|||
|
|
" filtered_label = label[(label >= lower_bound) & (label <= upper_bound)]\n",
|
|||
|
|
"\n",
|
|||
|
|
" # Print the number of removed outliers\n",
|
|||
|
|
" if log:\n",
|
|||
|
|
" print(f\"Removed {len(label) - len(filtered_label)} outliers.\")\n",
|
|||
|
|
" return filtered_label\n",
|
|||
|
|
"\n",
|
|||
|
|
"\n",
|
|||
|
|
"def calculate_risk_adjusted_target(df, days=5):\n",
|
|||
|
|
" df = df.sort_values(by=['ts_code', 'trade_date'])\n",
|
|||
|
|
"\n",
|
|||
|
|
" df['future_close'] = df.groupby('ts_code')['close'].shift(-days)\n",
|
|||
|
|
" df['future_open'] = df.groupby('ts_code')['open'].shift(-1)\n",
|
|||
|
|
" df['future_return'] = (df['future_close'] - df['future_open']) / df['future_open']\n",
|
|||
|
|
"\n",
|
|||
|
|
" df['future_volatility'] = df.groupby('ts_code')['future_return'].rolling(days, min_periods=1).std().reset_index(\n",
|
|||
|
|
" level=0, drop=True)\n",
|
|||
|
|
" sharpe_ratio = df['future_return'] * df['future_volatility']\n",
|
|||
|
|
" sharpe_ratio.replace([np.inf, -np.inf], np.nan, inplace=True)\n",
|
|||
|
|
"\n",
|
|||
|
|
" return sharpe_ratio\n",
|
|||
|
|
"\n",
|
|||
|
|
"\n",
|
|||
|
|
"def calculate_score(df, days=5, lambda_param=1.0):\n",
|
|||
|
|
" def calculate_max_drawdown(prices):\n",
|
|||
|
|
" peak = prices.iloc[0] # 初始化峰值\n",
|
|||
|
|
" max_drawdown = 0 # 初始化最大回撤\n",
|
|||
|
|
"\n",
|
|||
|
|
" for price in prices:\n",
|
|||
|
|
" if price > peak:\n",
|
|||
|
|
" peak = price # 更新峰值\n",
|
|||
|
|
" else:\n",
|
|||
|
|
" drawdown = (peak - price) / peak # 计算当前回撤\n",
|
|||
|
|
" max_drawdown = max(max_drawdown, drawdown) # 更新最大回撤\n",
|
|||
|
|
"\n",
|
|||
|
|
" return max_drawdown\n",
|
|||
|
|
"\n",
|
|||
|
|
" def compute_stock_score(stock_df):\n",
|
|||
|
|
" stock_df = stock_df.sort_values(by=['trade_date'])\n",
|
|||
|
|
" future_return = stock_df['future_return']\n",
|
|||
|
|
" # 使用已有的 pct_chg 字段计算波动率\n",
|
|||
|
|
" volatility = stock_df['pct_chg'].rolling(days).std().shift(-days)\n",
|
|||
|
|
" max_drawdown = stock_df['close'].rolling(days).apply(calculate_max_drawdown, raw=False).shift(-days)\n",
|
|||
|
|
" score = future_return - lambda_param * max_drawdown\n",
|
|||
|
|
" return score\n",
|
|||
|
|
"\n",
|
|||
|
|
" # # 确保 DataFrame 按照股票代码和交易日期排序\n",
|
|||
|
|
" # df = df.sort_values(by=['ts_code', 'trade_date'])\n",
|
|||
|
|
"\n",
|
|||
|
|
" # 对每个股票分别计算 score\n",
|
|||
|
|
" df['score'] = df.groupby('ts_code').apply(compute_stock_score).reset_index(level=0, drop=True)\n",
|
|||
|
|
"\n",
|
|||
|
|
" return df['score']\n",
|
|||
|
|
"\n",
|
|||
|
|
"\n",
|
|||
|
|
"def remove_highly_correlated_features(df, feature_columns, threshold=0.9):\n",
|
|||
|
|
" numeric_features = df[feature_columns].select_dtypes(include=[np.number]).columns.tolist()\n",
|
|||
|
|
" if not numeric_features:\n",
|
|||
|
|
" raise ValueError(\"No numeric features found in the provided data.\")\n",
|
|||
|
|
"\n",
|
|||
|
|
" corr_matrix = df[numeric_features].corr().abs()\n",
|
|||
|
|
" upper = corr_matrix.where(np.triu(np.ones(corr_matrix.shape), k=1).astype(bool))\n",
|
|||
|
|
" to_drop = [column for column in upper.columns if any(upper[column] > threshold)]\n",
|
|||
|
|
" remaining_features = [col for col in feature_columns if col not in to_drop\n",
|
|||
|
|
" or 'act' in col or 'af' in col]\n",
|
|||
|
|
" return remaining_features\n",
|
|||
|
|
"\n",
|
|||
|
|
"\n",
|
|||
|
|
"def cross_sectional_standardization(df, features):\n",
|
|||
|
|
" df_sorted = df.sort_values(by='trade_date') # 按时间排序\n",
|
|||
|
|
" df_standardized = df_sorted.copy()\n",
|
|||
|
|
"\n",
|
|||
|
|
" for date in df_sorted['trade_date'].unique():\n",
|
|||
|
|
" # 获取当前时间点的数据\n",
|
|||
|
|
" current_data = df_standardized[df_standardized['trade_date'] == date]\n",
|
|||
|
|
"\n",
|
|||
|
|
" # 只对指定特征进行标准化\n",
|
|||
|
|
" scaler = StandardScaler()\n",
|
|||
|
|
" standardized_values = scaler.fit_transform(current_data[features])\n",
|
|||
|
|
"\n",
|
|||
|
|
" # 将标准化结果重新赋值回去\n",
|
|||
|
|
" df_standardized.loc[df_standardized['trade_date'] == date, features] = standardized_values\n",
|
|||
|
|
"\n",
|
|||
|
|
" return df_standardized\n",
|
|||
|
|
"\n",
|
|||
|
|
"\n",
|
|||
|
|
"import numpy as np\n",
|
|||
|
|
"import pandas as pd\n",
|
|||
|
|
"\n",
|
|||
|
|
"\n",
|
|||
|
|
"def neutralize_manual_revised(df: pd.DataFrame, features: list, industry_col: str, mkt_cap_col: str) -> pd.DataFrame:\n",
|
|||
|
|
" \"\"\"\n",
|
|||
|
|
" 手动实现简单回归以提升速度,通过构建 Series 确保索引对齐。\n",
|
|||
|
|
" 对特征在行业内部进行市值中性化。\n",
|
|||
|
|
"\n",
|
|||
|
|
" Args:\n",
|
|||
|
|
" df: 输入的 DataFrame,包含特征、行业分类和市值列。\n",
|
|||
|
|
" features: 需要进行中性化的特征列名列表。\n",
|
|||
|
|
" industry_col: 行业分类列的列名。\n",
|
|||
|
|
" mkt_cap_col: 市值列的列名。\n",
|
|||
|
|
"\n",
|
|||
|
|
" Returns:\n",
|
|||
|
|
" 中性化后的 DataFrame。\n",
|
|||
|
|
" \"\"\"\n",
|
|||
|
|
"\n",
|
|||
|
|
" df[mkt_cap_col] = pd.to_numeric(df[mkt_cap_col], errors='coerce')\n",
|
|||
|
|
" df_cleaned = df.dropna(subset=[mkt_cap_col]).copy()\n",
|
|||
|
|
" df_cleaned = df_cleaned[df_cleaned[mkt_cap_col] > 0].copy()\n",
|
|||
|
|
"\n",
|
|||
|
|
" if df_cleaned.empty:\n",
|
|||
|
|
" print(\"警告: 清理市值异常值后 DataFrame 为空。\")\n",
|
|||
|
|
" return df # 返回原始或空df,取决于清理前的状态\n",
|
|||
|
|
"\n",
|
|||
|
|
" processed_df = df\n",
|
|||
|
|
"\n",
|
|||
|
|
" for col in features:\n",
|
|||
|
|
" if col not in df_cleaned.columns:\n",
|
|||
|
|
" print(f\"警告: 特征列 '{col}' 不存在于清理后的 DataFrame 中,已跳过。\")\n",
|
|||
|
|
" # 对于原始 df 中该列不存在的,在结果 df 中也保持原样(可能全是NaN)\n",
|
|||
|
|
" processed_df[col] = df[col] if col in df.columns else np.nan\n",
|
|||
|
|
" continue\n",
|
|||
|
|
"\n",
|
|||
|
|
" # 跳过对控制变量本身进行中性化\n",
|
|||
|
|
" if col == mkt_cap_col or col == industry_col:\n",
|
|||
|
|
" print(f\"警告: 特征列 '{col}' 是控制变量或内部使用的列,跳过中性化。\")\n",
|
|||
|
|
" # 在结果 df 中也保持原样\n",
|
|||
|
|
" processed_df[col] = df[col] if col in df.columns else np.nan\n",
|
|||
|
|
" continue\n",
|
|||
|
|
"\n",
|
|||
|
|
" residual_series = pd.Series(index=df_cleaned.index, dtype=float)\n",
|
|||
|
|
"\n",
|
|||
|
|
" # 在分组前处理特征列的 NaN,只对有因子值的行进行回归计算\n",
|
|||
|
|
" df_subset_factor = df_cleaned.dropna(subset=[col]).copy()\n",
|
|||
|
|
"\n",
|
|||
|
|
" if not df_subset_factor.empty:\n",
|
|||
|
|
" for industry, group in df_subset_factor.groupby(industry_col):\n",
|
|||
|
|
" x = group[mkt_cap_col] # 市值对数\n",
|
|||
|
|
" y = group[col] # 因子值\n",
|
|||
|
|
"\n",
|
|||
|
|
" # 确保有足够的数据点 (>1) 且市值对数有方差 (>0) 进行回归计算\n",
|
|||
|
|
" # 检查 np.var > 一个很小的正数,避免浮点数误差导致的零方差判断问题\n",
|
|||
|
|
" if len(group) > 1 and np.var(x) > 1e-9:\n",
|
|||
|
|
" try:\n",
|
|||
|
|
" beta = np.cov(y, x)[0, 1] / np.var(x)\n",
|
|||
|
|
" alpha = np.mean(y) - beta * np.mean(x)\n",
|
|||
|
|
"\n",
|
|||
|
|
" # 计算残差\n",
|
|||
|
|
" resid = y - (alpha + beta * x)\n",
|
|||
|
|
"\n",
|
|||
|
|
" # 将计算出的残差存储到 residual_series 中,通过索引自动对齐\n",
|
|||
|
|
" residual_series.loc[resid.index] = resid\n",
|
|||
|
|
"\n",
|
|||
|
|
" except Exception as e:\n",
|
|||
|
|
" # 捕获可能的计算异常,例如np.cov或np.var因为极端数据报错\n",
|
|||
|
|
" print(f\"警告: 在行业 {industry} 计算回归时发生错误: {e}。该组残差将设为原始值或 NaN。\")\n",
|
|||
|
|
" # 此时该组的残差会保持 residual_series 初始化时的 NaN 或后续处理\n",
|
|||
|
|
" # 也可以选择保留原始值:residual_series.loc[group.index] = group[col]\n",
|
|||
|
|
"\n",
|
|||
|
|
" else:\n",
|
|||
|
|
" residual_series.loc[group.index] = group[col] # 保留原始因子值\n",
|
|||
|
|
" processed_df.loc[residual_series.index, col] = residual_series\n",
|
|||
|
|
"\n",
|
|||
|
|
"\n",
|
|||
|
|
" else:\n",
|
|||
|
|
" processed_df[col] = np.nan # 或 df[col] if col in df.columns else np.nan\n",
|
|||
|
|
"\n",
|
|||
|
|
" return processed_df\n",
|
|||
|
|
"\n",
|
|||
|
|
"\n",
|
|||
|
|
"import gc\n",
|
|||
|
|
"\n",
|
|||
|
|
"gc.collect()\n",
|
|||
|
|
"\n",
|
|||
|
|
"\n",
|
|||
|
|
"def mad_filter(df, features, n=3):\n",
|
|||
|
|
" for col in features:\n",
|
|||
|
|
" median = df[col].median()\n",
|
|||
|
|
" mad = np.median(np.abs(df[col] - median))\n",
|
|||
|
|
" upper = median + n * mad\n",
|
|||
|
|
" lower = median - n * mad\n",
|
|||
|
|
" df[col] = np.clip(df[col], lower, upper) # 截断极值\n",
|
|||
|
|
" return df\n",
|
|||
|
|
"\n",
|
|||
|
|
"\n",
|
|||
|
|
"def percentile_filter(df, features, lower_percentile=0.01, upper_percentile=0.99):\n",
|
|||
|
|
" for col in features:\n",
|
|||
|
|
" # 按日期分组计算上下百分位数\n",
|
|||
|
|
" lower_bound = df.groupby('trade_date')[col].transform(\n",
|
|||
|
|
" lambda x: x.quantile(lower_percentile)\n",
|
|||
|
|
" )\n",
|
|||
|
|
" upper_bound = df.groupby('trade_date')[col].transform(\n",
|
|||
|
|
" lambda x: x.quantile(upper_percentile)\n",
|
|||
|
|
" )\n",
|
|||
|
|
" # 截断超出范围的值\n",
|
|||
|
|
" df[col] = np.clip(df[col], lower_bound, upper_bound)\n",
|
|||
|
|
" return df\n",
|
|||
|
|
"\n",
|
|||
|
|
"\n",
|
|||
|
|
"from scipy.stats import iqr\n",
|
|||
|
|
"\n",
|
|||
|
|
"\n",
|
|||
|
|
"def iqr_filter(df, features):\n",
|
|||
|
|
" for col in features:\n",
|
|||
|
|
" df[col] = df.groupby('trade_date')[col].transform(\n",
|
|||
|
|
" lambda x: (x - x.median()) / iqr(x) if iqr(x) != 0 else x\n",
|
|||
|
|
" )\n",
|
|||
|
|
" return df\n",
|
|||
|
|
"\n",
|
|||
|
|
"\n",
|
|||
|
|
"def quantile_filter(df, features, lower_quantile=0.01, upper_quantile=0.99, window=60):\n",
|
|||
|
|
" df = df.copy()\n",
|
|||
|
|
" for col in features:\n",
|
|||
|
|
" # 计算 rolling 统计量,需要按日期进行 groupby\n",
|
|||
|
|
" rolling_lower = df.groupby('trade_date')[col].transform(lambda x: x.rolling(window=min(len(x), window)).quantile(lower_quantile))\n",
|
|||
|
|
" rolling_upper = df.groupby('trade_date')[col].transform(lambda x: x.rolling(window=min(len(x), window)).quantile(upper_quantile))\n",
|
|||
|
|
"\n",
|
|||
|
|
" # 对数据进行裁剪\n",
|
|||
|
|
" df[col] = np.clip(df[col], rolling_lower, rolling_upper)\n",
|
|||
|
|
" \n",
|
|||
|
|
" return df\n",
|
|||
|
|
"\n",
|
|||
|
|
"def select_top_features_by_rankic(df: pd.DataFrame, feature_columns: list, n: int, target_column: str = 'future_return') -> list:\n",
|
|||
|
|
" \"\"\"\n",
|
|||
|
|
" 计算给定特征与目标列的 RankIC,并返回 RankIC 绝对值最高的 n 个特征。\n",
|
|||
|
|
"\n",
|
|||
|
|
" Args:\n",
|
|||
|
|
" df: 包含特征列和目标列的 Pandas DataFrame。\n",
|
|||
|
|
" feature_columns: 包含所有待评估特征列名的列表。\n",
|
|||
|
|
" n: 希望选取的 RankIC 绝对值最高的特征数量。\n",
|
|||
|
|
" target_column: 目标列的名称,用于计算 RankIC。默认为 'future_return'。\n",
|
|||
|
|
"\n",
|
|||
|
|
" Returns:\n",
|
|||
|
|
" 包含 RankIC 绝对值最高的 n 个特征列名的列表。\n",
|
|||
|
|
" \"\"\"\n",
|
|||
|
|
" numeric_columns = df.select_dtypes(include=['float64', 'int64']).columns\n",
|
|||
|
|
" numeric_columns = [col for col in numeric_columns if col in feature_columns]\n",
|
|||
|
|
" if target_column not in df.columns:\n",
|
|||
|
|
" raise ValueError(f\"目标列 '{target_column}' 不存在于 DataFrame 中。\")\n",
|
|||
|
|
"\n",
|
|||
|
|
" rankic_scores = {}\n",
|
|||
|
|
" for feature in numeric_columns:\n",
|
|||
|
|
" if feature not in df.columns:\n",
|
|||
|
|
" print(f\"警告: 特征列 '{feature}' 不存在于 DataFrame 中,已跳过。\")\n",
|
|||
|
|
" continue\n",
|
|||
|
|
"\n",
|
|||
|
|
" # 计算特征与目标列的 RankIC (斯皮尔曼相关系数)\n",
|
|||
|
|
" # dropna() 是为了处理缺失值,确保相关性计算不失败\n",
|
|||
|
|
" valid_data = df[[feature, target_column]].dropna()\n",
|
|||
|
|
" if len(valid_data) > 1: # 确保有足够的数据点进行相关性计算\n",
|
|||
|
|
" # 计算斯皮尔曼相关性\n",
|
|||
|
|
" correlation = valid_data[feature].corr(valid_data[target_column], method='spearman')\n",
|
|||
|
|
" rankic_scores[feature] = abs(correlation) # 使用绝对值来衡量相关性强度\n",
|
|||
|
|
" else:\n",
|
|||
|
|
" rankic_scores[feature] = 0 # 数据不足,RankIC设为0或跳过\n",
|
|||
|
|
"\n",
|
|||
|
|
" # 将 RankIC 分数转换为 Series 便于排序\n",
|
|||
|
|
" rankic_series = pd.Series(rankic_scores)\n",
|
|||
|
|
"\n",
|
|||
|
|
" # 按 RankIC 绝对值降序排序,选取前 n 个特征\n",
|
|||
|
|
" # handle case where n might be larger than available features\n",
|
|||
|
|
" n_actual = min(n, len(rankic_series))\n",
|
|||
|
|
" top_features = rankic_series.sort_values(ascending=False).head(n_actual).index.tolist()\n",
|
|||
|
|
" top_features = [col for col in feature_columns if col in top_features or col not in numeric_columns]\n",
|
|||
|
|
" return top_features\n",
|
|||
|
|
"\n",
|
|||
|
|
"def create_deviation_within_dates(df, feature_columns):\n",
|
|||
|
|
" groupby_col = 'cat_l2_code' # 使用 trade_date 进行分组\n",
|
|||
|
|
" new_columns = {}\n",
|
|||
|
|
" ret_feature_columns = feature_columns[:]\n",
|
|||
|
|
"\n",
|
|||
|
|
" # 自动选择所有数值型特征\n",
|
|||
|
|
" num_features = [col for col in feature_columns if 'cat' not in col and 'index' not in col]\n",
|
|||
|
|
"\n",
|
|||
|
|
" # num_features = ['vol', 'pct_chg', 'turnover_rate', 'volume_ratio', 'cat_vol_spike', 'obv', 'maobv_6', 'return_5', 'return_10', 'return_20', 'std_return_5', 'std_return_15', 'std_return_90', 'std_return_90_2', 'act_factor1', 'act_factor2', 'act_factor3', 'act_factor4', 'act_factor5', 'act_factor6', 'rank_act_factor1', 'rank_act_factor2', 'rank_act_factor3', 'active_buy_volume_large', 'active_buy_volume_big', 'active_buy_volume_small', 'alpha_022', 'alpha_003', 'alpha_007', 'alpha_013']\n",
|
|||
|
|
" num_features = [col for col in num_features if 'cat' not in col and 'industry' not in col]\n",
|
|||
|
|
" num_features = [col for col in num_features if 'limit' not in col]\n",
|
|||
|
|
" num_features = [col for col in num_features if 'cyq' not in col]\n",
|
|||
|
|
"\n",
|
|||
|
|
" # 遍历所有数值型特征\n",
|
|||
|
|
" for feature in num_features:\n",
|
|||
|
|
" if feature == 'trade_date': # 不需要对 'trade_date' 计算偏差\n",
|
|||
|
|
" continue\n",
|
|||
|
|
"\n",
|
|||
|
|
" # grouped_mean = df.groupby(['trade_date'])[feature].transform('mean')\n",
|
|||
|
|
" # deviation_col_name = f'deviation_mean_{feature}'\n",
|
|||
|
|
" # new_columns[deviation_col_name] = df[feature] - grouped_mean\n",
|
|||
|
|
" # ret_feature_columns.append(deviation_col_name)\n",
|
|||
|
|
"\n",
|
|||
|
|
" grouped_mean = df.groupby(['trade_date', groupby_col])[feature].transform('mean')\n",
|
|||
|
|
" deviation_col_name = f'deviation_mean_{feature}'\n",
|
|||
|
|
" new_columns[deviation_col_name] = df[feature] - grouped_mean\n",
|
|||
|
|
" ret_feature_columns.append(deviation_col_name)\n",
|
|||
|
|
"\n",
|
|||
|
|
" # 将新计算的偏差特征与原始 DataFrame 合并\n",
|
|||
|
|
" df = pd.concat([df, pd.DataFrame(new_columns)], axis=1)\n",
|
|||
|
|
"\n",
|
|||
|
|
" # for feature in ['obv', 'return_20', 'act_factor1', 'act_factor2', 'act_factor3', 'act_factor4']:\n",
|
|||
|
|
" # df[f'deviation_industry_{feature}'] = df[feature] - df[f'industry_{feature}']\n",
|
|||
|
|
"\n",
|
|||
|
|
" return df, ret_feature_columns\n"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"execution_count": 13,
|
|||
|
|
"id": "47c12bb34062ae7a",
|
|||
|
|
"metadata": {
|
|||
|
|
"ExecuteTime": {
|
|||
|
|
"end_time": "2025-04-03T14:57:50.841165Z",
|
|||
|
|
"start_time": "2025-04-03T14:49:25.889057Z"
|
|||
|
|
}
|
|||
|
|
},
|
|||
|
|
"outputs": [],
|
|||
|
|
"source": [
|
|||
|
|
"days = 5\n",
|
|||
|
|
"validation_days = 120\n",
|
|||
|
|
"\n",
|
|||
|
|
"import gc\n",
|
|||
|
|
"\n",
|
|||
|
|
"gc.collect()\n",
|
|||
|
|
"\n",
|
|||
|
|
"df = df.sort_values(by=['ts_code', 'trade_date'])\n",
|
|||
|
|
"df['future_return'] = df.groupby('ts_code', group_keys=False)['close'].apply(lambda x: x.shift(-days) / x - 1)\n",
|
|||
|
|
"# df['future_return'] = (df.groupby('ts_code')['close'].shift(-days) - df.groupby('ts_code')['open'].shift(-1)) / \\\n",
|
|||
|
|
"# df.groupby('ts_code')['open'].shift(-1)\n",
|
|||
|
|
"\n",
|
|||
|
|
"df['cat_up_limit'] = df['pct_chg'] > 5\n",
|
|||
|
|
"df['label'] = df.groupby('ts_code')['cat_up_limit'].rolling(window=5, min_periods=1).max().shift(-5).fillna(0).astype(int).reset_index(level=0, drop=True)\n",
|
|||
|
|
"\n",
|
|||
|
|
"filter_index = df['future_return'].between(df['future_return'].quantile(0.01), df['future_return'].quantile(0.99))\n",
|
|||
|
|
"\n",
|
|||
|
|
"# for col in [col for col in df.columns]:\n",
|
|||
|
|
"# train_data[col] = train_data[col].astype('str')\n",
|
|||
|
|
"# test_data[col] = test_data[col].astype('str')"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"execution_count": 14,
|
|||
|
|
"id": "29221dde",
|
|||
|
|
"metadata": {},
|
|||
|
|
"outputs": [],
|
|||
|
|
"source": [
|
|||
|
|
"feature_columns = [col for col in df.head(10).merge(industry_df, on=['cat_l2_code', 'trade_date'], how='left').merge(index_data, on='trade_date', how='left').columns]\n",
|
|||
|
|
"feature_columns = [col for col in feature_columns if col not in ['trade_date',\n",
|
|||
|
|
" 'ts_code',\n",
|
|||
|
|
" 'label']]\n",
|
|||
|
|
"feature_columns = [col for col in feature_columns if 'future' not in col]\n",
|
|||
|
|
"feature_columns = [col for col in feature_columns if 'label' not in col]\n",
|
|||
|
|
"feature_columns = [col for col in feature_columns if 'score' not in col]\n",
|
|||
|
|
"feature_columns = [col for col in feature_columns if 'gen' not in col]\n",
|
|||
|
|
"feature_columns = [col for col in feature_columns if 'is_st' not in col]\n",
|
|||
|
|
"feature_columns = [col for col in feature_columns if 'pe_ttm' not in col]\n",
|
|||
|
|
"# feature_columns = [col for col in feature_columns if 'volatility' not in col]\n",
|
|||
|
|
"feature_columns = [col for col in feature_columns if 'circ_mv' not in col]\n",
|
|||
|
|
"feature_columns = [col for col in feature_columns if 'code' not in col]\n",
|
|||
|
|
"feature_columns = [col for col in feature_columns if col not in origin_columns]\n",
|
|||
|
|
"feature_columns = [col for col in feature_columns if not col.startswith('_')]\n",
|
|||
|
|
"# feature_columns = [col for col in feature_columns if col not in ['ts_code', 'trade_date', 'vol_std_5', 'cov', 'delta_cov', 'alpha_22_improved', 'alpha_007', 'consecutive_up_limit', 'mv_volatility', 'volume_growth', 'mv_growth', 'arbr']]\n",
|
|||
|
|
"feature_columns = [col for col in feature_columns if col not in ['intraday_lg_flow_corr_20', \n",
|
|||
|
|
" 'cap_neutral_cost_metric', \n",
|
|||
|
|
" 'hurst_net_mf_vol_60', \n",
|
|||
|
|
" 'complex_factor_deap_1', \n",
|
|||
|
|
" 'lg_buy_consolidation_20',\n",
|
|||
|
|
" 'cs_rank_ind_cap_neutral_pe',\n",
|
|||
|
|
" 'cs_rank_opening_gap',\n",
|
|||
|
|
" 'cs_rank_ind_adj_lg_flow']]\n",
|
|||
|
|
"\n",
|
|||
|
|
"# df = fill_nan_with_daily_median(df, feature_columns)\n",
|
|||
|
|
"for feature_col in [col for col in feature_columns if col in df.columns]:\n",
|
|||
|
|
" median_val = df[feature_col].median()\n",
|
|||
|
|
" df[feature_col].fillna(0, inplace=True)"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"execution_count": 15,
|
|||
|
|
"id": "b76ea08a",
|
|||
|
|
"metadata": {},
|
|||
|
|
"outputs": [
|
|||
|
|
{
|
|||
|
|
"name": "stdout",
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"text": [
|
|||
|
|
" ts_code trade_date log_circ_mv\n",
|
|||
|
|
"0 000001.SZ 2019-01-02 16.574219\n",
|
|||
|
|
"1 000001.SZ 2019-01-03 16.583965\n",
|
|||
|
|
"2 000001.SZ 2019-01-04 16.633371\n"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"name": "stdout",
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"text": [
|
|||
|
|
"['vol', 'pct_chg', 'turnover_rate', 'volume_ratio', 'winner_rate', 'undist_profit_ps', 'ocfps', 'AR', 'BR', 'AR_BR', 'cashflow_to_ev_factor', 'book_to_price_ratio', 'turnover_rate_mean_5', 'variance_20', 'bbi_ratio_factor', 'daily_deviation', 'lg_elg_net_buy_vol', 'flow_lg_elg_intensity', 'sm_net_buy_vol', 'total_buy_vol', 'lg_elg_buy_prop', 'flow_struct_buy_change', 'lg_elg_net_buy_vol_change', 'flow_lg_elg_accel', 'chip_concentration_range', 'chip_skewness', 'floating_chip_proxy', 'cost_support_15pct_change', 'cat_winner_price_zone', 'flow_chip_consistency', 'profit_taking_vs_absorb', 'cat_is_positive', 'upside_vol', 'downside_vol', 'vol_ratio', 'return_skew', 'return_kurtosis', 'volume_change_rate', 'cat_volume_breakout', 'turnover_deviation', 'cat_turnover_spike', 'avg_volume_ratio', 'cat_volume_ratio_breakout', 'vol_spike', 'vol_std_5', 'atr_14', 'atr_6', 'obv', 'maobv_6', 'rsi_3', 'return_5', 'return_20', 'std_return_5', 'std_return_90', 'std_return_90_2', 'act_factor1', 'act_factor2', 'act_factor3', 'act_factor4', 'rank_act_factor1', 'rank_act_factor2', 'rank_act_factor3', 'cov', 'delta_cov', 'alpha_22_improved', 'alpha_003', 'alpha_007', 'alpha_013', 'vol_break', 'weight_roc5', 'smallcap_concentration', 'cost_stability', 'high_cost_break_days', 'liquidity_risk', 'turnover_std', 'mv_volatility', 'volume_growth', 'mv_growth', 'momentum_factor', 'resonance_factor', 'log_close', 'cat_vol_spike', 'up', 'down', 'obv_maobv_6', 'std_return_5_over_std_return_90', 'std_return_90_minus_std_return_90_2', 'cat_af2', 'cat_af3', 'cat_af4', 'act_factor5', 'act_factor6', 'active_buy_volume_large', 'active_buy_volume_big', 'active_buy_volume_small', 'buy_lg_vol_minus_sell_lg_vol', 'buy_elg_vol_minus_sell_elg_vol', 'ctrl_strength', 'low_cost_dev', 'asymmetry', 'lock_factor', 'cat_vol_break', 'cost_atr_adj', 'cat_golden_resonance', 'mv_turnover_ratio', 'mv_adjusted_volume', 'mv_weighted_turnover', 'nonlinear_mv_volume', 'mv_volume_ratio', 'mv_momentum', 'lg_flow_mom_corr_20_60', 'lg_flow_accel', 'profit_pressure', 'underwater_resistance', 'cost_conc_std_20', 'profit_decay_20', 'vol_amp_loss_20', 'vol_drop_profit_cnt_5', 'lg_flow_vol_interact_20', 'cost_break_confirm_cnt_5', 'atr_norm_channel_pos_14', 'turnover_diff_skew_20', 'lg_sm_flow_diverge_20', 'pullback_strong_20_20', 'vol_wgt_hist_pos_20', 'vol_adj_roc_20', 'cs_rank_net_lg_flow_val', 'cs_rank_elg_buy_ratio', 'cs_rank_rel_profit_margin', 'cs_rank_cost_breadth', 'cs_rank_dist_to_upper_cost', 'cs_rank_winner_rate', 'cs_rank_intraday_range', 'cs_rank_close_pos_in_range', 'cs_rank_pos_in_hist_range', 'cs_rank_vol_x_profit_margin', 'cs_rank_lg_flow_price_concordance', 'cs_rank_turnover_per_winner', 'cs_rank_volume_ratio', 'cs_rank_elg_buy_sell_sm_ratio', 'cs_rank_cost_dist_vol_ratio', 'cs_rank_size', 'cat_up_limit', 'industry_obv', 'industry_return_5', 'industry_return_20', 'industry__ema_5', 'industry__ema_13', 'industry__ema_20', 'industry__ema_60', 'industry_act_factor1', 'industry_act_factor2', 'industry_act_factor3', 'industry_act_factor4', 'industry_act_factor5', 'industry_act_factor6', 'industry_rank_act_factor1', 'industry_rank_act_factor2', 'industry_rank_act_factor3', 'industry_return_5_percentile', 'industry_return_20_percentile', '000852.SH_MACD', '000905.SH_MACD', '399006.SZ_MACD', '000852.SH_MACD_hist', '000905.SH_MACD_hist', '399006.SZ_MACD_hist', '000852.SH_RSI', '000905.SH_RSI', '399006.SZ_RSI', '000852.SH_Signal_line', '000905.SH_Signal_line', '399006.SZ_Signal_line', '000852.SH_amount_change_rate', '000905.SH_amount_change_rate', '399006.SZ_amount_change_rate', '000852.SH_amount_mean', '000905.SH_amount_mean', '399006.SZ_amount_mean', '000852.SH_daily_return', '000905.SH_daily_return', '399006.SZ_daily_return', '000852.SH_up_ratio_20d', '000905.SH_up_ratio_20d', '399006.SZ_up_ratio_20d', '000852.SH_volatility', '000905.SH_volatility', '399006.SZ_volatility', '000852.SH_volume_change_rate', '000905.SH_volume_change_rate', '399006.SZ_volume_change_rate']\n",
|
|||
|
|
"去除极值\n",
|
|||
|
|
"开始截面 MAD 去极值处理 (k=3.0)...\n"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"name": "stderr",
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"text": [
|
|||
|
|
"MAD Filtering: 100%|██████████| 131/131 [00:28<00:00, 4.67it/s]\n"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"name": "stdout",
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"text": [
|
|||
|
|
"截面 MAD 去极值处理完成。\n",
|
|||
|
|
"开始截面 MAD 去极值处理 (k=3.0)...\n"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"name": "stderr",
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"text": [
|
|||
|
|
"MAD Filtering: 100%|██████████| 131/131 [00:23<00:00, 5.67it/s]\n"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"name": "stdout",
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"text": [
|
|||
|
|
"截面 MAD 去极值处理完成。\n",
|
|||
|
|
"开始截面 MAD 去极值处理 (k=3.0)...\n"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"name": "stderr",
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"text": [
|
|||
|
|
"MAD Filtering: 0it [00:00, ?it/s]\n"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"name": "stdout",
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"text": [
|
|||
|
|
"截面 MAD 去极值处理完成。\n",
|
|||
|
|
"开始截面 MAD 去极值处理 (k=3.0)...\n"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"name": "stderr",
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"text": [
|
|||
|
|
"MAD Filtering: 0it [00:00, ?it/s]\n"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"name": "stdout",
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"text": [
|
|||
|
|
"截面 MAD 去极值处理完成。\n",
|
|||
|
|
"feature_columns: ['vol', 'pct_chg', 'turnover_rate', 'volume_ratio', 'winner_rate', 'undist_profit_ps', 'ocfps', 'AR', 'BR', 'AR_BR', 'cashflow_to_ev_factor', 'book_to_price_ratio', 'turnover_rate_mean_5', 'variance_20', 'bbi_ratio_factor', 'daily_deviation', 'lg_elg_net_buy_vol', 'flow_lg_elg_intensity', 'sm_net_buy_vol', 'total_buy_vol', 'lg_elg_buy_prop', 'flow_struct_buy_change', 'lg_elg_net_buy_vol_change', 'flow_lg_elg_accel', 'chip_concentration_range', 'chip_skewness', 'floating_chip_proxy', 'cost_support_15pct_change', 'cat_winner_price_zone', 'flow_chip_consistency', 'profit_taking_vs_absorb', 'cat_is_positive', 'upside_vol', 'downside_vol', 'vol_ratio', 'return_skew', 'return_kurtosis', 'volume_change_rate', 'cat_volume_breakout', 'turnover_deviation', 'cat_turnover_spike', 'avg_volume_ratio', 'cat_volume_ratio_breakout', 'vol_spike', 'vol_std_5', 'atr_14', 'atr_6', 'obv', 'maobv_6', 'rsi_3', 'return_5', 'return_20', 'std_return_5', 'std_return_90', 'std_return_90_2', 'act_factor1', 'act_factor2', 'act_factor3', 'act_factor4', 'rank_act_factor1', 'rank_act_factor2', 'rank_act_factor3', 'cov', 'delta_cov', 'alpha_22_improved', 'alpha_003', 'alpha_007', 'alpha_013', 'vol_break', 'weight_roc5', 'smallcap_concentration', 'cost_stability', 'high_cost_break_days', 'liquidity_risk', 'turnover_std', 'mv_volatility', 'volume_growth', 'mv_growth', 'momentum_factor', 'resonance_factor', 'log_close', 'cat_vol_spike', 'up', 'down', 'obv_maobv_6', 'std_return_5_over_std_return_90', 'std_return_90_minus_std_return_90_2', 'cat_af2', 'cat_af3', 'cat_af4', 'act_factor5', 'act_factor6', 'active_buy_volume_large', 'active_buy_volume_big', 'active_buy_volume_small', 'buy_lg_vol_minus_sell_lg_vol', 'buy_elg_vol_minus_sell_elg_vol', 'ctrl_strength', 'low_cost_dev', 'asymmetry', 'lock_factor', 'cat_vol_break', 'cost_atr_adj', 'cat_golden_resonance', 'mv_turnover_ratio', 'mv_adjusted_volume', 'mv_weighted_turnover', 'nonlinear_mv_volume', 'mv_volume_ratio', 'mv_momentum', 'lg_flow_mom_corr_20_60', 'lg_flow_accel', 'profit_pressure', 'underwater_resistance', 'cost_conc_std_20', 'profit_decay_20', 'vol_amp_loss_20', 'vol_drop_profit_cnt_5', 'lg_flow_vol_interact_20', 'cost_break_confirm_cnt_5', 'atr_norm_channel_pos_14', 'turnover_diff_skew_20', 'lg_sm_flow_diverge_20', 'pullback_strong_20_20', 'vol_wgt_hist_pos_20', 'vol_adj_roc_20', 'cs_rank_net_lg_flow_val', 'cs_rank_elg_buy_ratio', 'cs_rank_rel_profit_margin', 'cs_rank_cost_breadth', 'cs_rank_dist_to_upper_cost', 'cs_rank_winner_rate', 'cs_rank_intraday_range', 'cs_rank_close_pos_in_range', 'cs_rank_pos_in_hist_range', 'cs_rank_vol_x_profit_margin', 'cs_rank_lg_flow_price_concordance', 'cs_rank_turnover_per_winner', 'cs_rank_volume_ratio', 'cs_rank_elg_buy_sell_sm_ratio', 'cs_rank_cost_dist_vol_ratio', 'cs_rank_size', 'cat_up_limit', 'industry_obv', 'industry_return_5', 'industry_return_20', 'industry__ema_5', 'industry__ema_13', 'industry__ema_20', 'industry__ema_60', 'industry_act_factor1', 'industry_act_factor2', 'industry_act_factor3', 'industry_act_factor4', 'industry_act_factor5', 'industry_act_factor6', 'industry_rank_act_factor1', 'industry_rank_act_factor2', 'industry_rank_act_factor3', 'industry_return_5_percentile', 'industry_return_20_percentile', '000852.SH_MACD', '000905.SH_MACD', '399006.SZ_MACD', '000852.SH_MACD_hist', '000905.SH_MACD_hist', '399006.SZ_MACD_hist', '000852.SH_RSI', '000905.SH_RSI', '399006.SZ_RSI', '000852.SH_Signal_line', '000905.SH_Signal_line', '399006.SZ_Signal_line', '000852.SH_amount_change_rate', '000905.SH_amount_change_rate', '399006.SZ_amount_change_rate', '000852.SH_amount_mean', '000905.SH_amount_mean', '399006.SZ_amount_mean', '000852.SH_daily_return', '000905.SH_daily_return', '399006.SZ_daily_return', '000852.SH_up_ratio_20d', '000905.SH_up_ratio_20d', '399006.SZ_up_ratio_20d', '000852.SH_volatility', '000905.SH_volatility', '399006.SZ_volatility', '000852.SH_volume_change_rate', '000905.SH_volume_change_rate', '399006.SZ_volume_change_rate']\n",
|
|||
|
|
"df最小日期: 2019-01-02\n",
|
|||
|
|
"df最大日期: 2025-05-08\n",
|
|||
|
|
"2057680\n",
|
|||
|
|
"train_data最小日期: 2020-01-02\n",
|
|||
|
|
"train_data最大日期: 2022-12-30\n",
|
|||
|
|
"1733637\n",
|
|||
|
|
"test_data最小日期: 2023-01-03\n",
|
|||
|
|
"test_data最大日期: 2025-05-08\n",
|
|||
|
|
" ts_code trade_date log_circ_mv\n",
|
|||
|
|
"0 000001.SZ 2019-01-02 16.574219\n",
|
|||
|
|
"1 000001.SZ 2019-01-03 16.583965\n",
|
|||
|
|
"2 000001.SZ 2019-01-04 16.633371\n"
|
|||
|
|
]
|
|||
|
|
}
|
|||
|
|
],
|
|||
|
|
"source": [
|
|||
|
|
"train_data = df[filter_index & (df['trade_date'] <= '2023-01-01') & (df['trade_date'] >= '2020-01-01')]\n",
|
|||
|
|
"test_data = df[(df['trade_date'] >= '2023-01-01')]\n",
|
|||
|
|
"\n",
|
|||
|
|
"print(df[['ts_code', 'trade_date', 'log_circ_mv']].head(3))\n",
|
|||
|
|
"\n",
|
|||
|
|
"industry_df = industry_df.sort_values(by=['trade_date'])\n",
|
|||
|
|
"index_data = index_data.sort_values(by=['trade_date'])\n",
|
|||
|
|
"\n",
|
|||
|
|
"# train_data = train_data.merge(industry_df, on=['cat_l2_code', 'trade_date'], how='left')\n",
|
|||
|
|
"# train_data = train_data.merge(index_data, on='trade_date', how='left')\n",
|
|||
|
|
"# test_data = test_data.merge(industry_df, on=['cat_l2_code', 'trade_date'], how='left')\n",
|
|||
|
|
"# test_data = test_data.merge(index_data, on='trade_date', how='left')\n",
|
|||
|
|
"\n",
|
|||
|
|
"train_data, test_data = train_data.replace([np.inf, -np.inf], np.nan), test_data.replace([np.inf, -np.inf], np.nan)\n",
|
|||
|
|
"\n",
|
|||
|
|
"# feature_columns_new = feature_columns[:]\n",
|
|||
|
|
"# train_data, _ = create_deviation_within_dates(train_data, [col for col in feature_columns if col in train_data.columns])\n",
|
|||
|
|
"# test_data, _ = create_deviation_within_dates(test_data, [col for col in feature_columns if col in train_data.columns])\n",
|
|||
|
|
"\n",
|
|||
|
|
"# feature_columns = [\n",
|
|||
|
|
"# 'undist_profit_ps', \n",
|
|||
|
|
"# 'AR_BR',\n",
|
|||
|
|
"# 'pe_ttm',\n",
|
|||
|
|
"# 'alpha_22_improved', \n",
|
|||
|
|
"# 'alpha_003', \n",
|
|||
|
|
"# 'alpha_007', \n",
|
|||
|
|
"# 'alpha_013', \n",
|
|||
|
|
"# 'cat_up_limit', \n",
|
|||
|
|
"# 'cat_down_limit', \n",
|
|||
|
|
"# 'up_limit_count_10d', \n",
|
|||
|
|
"# 'down_limit_count_10d', \n",
|
|||
|
|
"# 'consecutive_up_limit', \n",
|
|||
|
|
"# 'vol_break', \n",
|
|||
|
|
"# 'weight_roc5', \n",
|
|||
|
|
"# 'price_cost_divergence', \n",
|
|||
|
|
"# 'smallcap_concentration', \n",
|
|||
|
|
"# 'cost_stability', \n",
|
|||
|
|
"# 'high_cost_break_days', \n",
|
|||
|
|
"# 'liquidity_risk', \n",
|
|||
|
|
"# 'turnover_std', \n",
|
|||
|
|
"# 'mv_volatility', \n",
|
|||
|
|
"# 'volume_growth', \n",
|
|||
|
|
"# 'mv_growth', \n",
|
|||
|
|
"# 'lg_flow_mom_corr_20_60', \n",
|
|||
|
|
"# 'lg_flow_accel', \n",
|
|||
|
|
"# 'profit_pressure', \n",
|
|||
|
|
"# 'underwater_resistance', \n",
|
|||
|
|
"# 'cost_conc_std_20', \n",
|
|||
|
|
"# 'profit_decay_20', \n",
|
|||
|
|
"# 'vol_amp_loss_20', \n",
|
|||
|
|
"# 'vol_drop_profit_cnt_5', \n",
|
|||
|
|
"# 'lg_flow_vol_interact_20', \n",
|
|||
|
|
"# 'cost_break_confirm_cnt_5', \n",
|
|||
|
|
"# 'atr_norm_channel_pos_14', \n",
|
|||
|
|
"# 'turnover_diff_skew_20', \n",
|
|||
|
|
"# 'lg_sm_flow_diverge_20', \n",
|
|||
|
|
"# 'pullback_strong_20_20', \n",
|
|||
|
|
"# 'vol_wgt_hist_pos_20', \n",
|
|||
|
|
"# 'vol_adj_roc_20',\n",
|
|||
|
|
"# 'cashflow_to_ev_factor',\n",
|
|||
|
|
"# 'ocfps',\n",
|
|||
|
|
"# 'book_to_price_ratio',\n",
|
|||
|
|
"# 'turnover_rate_mean_5',\n",
|
|||
|
|
"# 'variance_20',\n",
|
|||
|
|
"# 'bbi_ratio_factor'\n",
|
|||
|
|
"# ]\n",
|
|||
|
|
"# feature_columns = [col for col in feature_columns if col in train_data.columns]\n",
|
|||
|
|
"# feature_columns = [col for col in feature_columns if not col.startswith('_')]\n",
|
|||
|
|
"\n",
|
|||
|
|
"numeric_columns = df.select_dtypes(include=['float64', 'int64']).columns\n",
|
|||
|
|
"numeric_columns = [col for col in numeric_columns if col in feature_columns]\n",
|
|||
|
|
"# feature_columns = select_top_features_by_rankic(df, numeric_columns, n=10)\n",
|
|||
|
|
"print(feature_columns)\n",
|
|||
|
|
"\n",
|
|||
|
|
"# train_data = fill_nan_with_daily_median(train_data, feature_columns)\n",
|
|||
|
|
"# test_data = fill_nan_with_daily_median(test_data, feature_columns)\n",
|
|||
|
|
"\n",
|
|||
|
|
"train_data = train_data.dropna(subset=[col for col in feature_columns if col in train_data.columns])\n",
|
|||
|
|
"train_data = train_data.dropna(subset=['label'])\n",
|
|||
|
|
"train_data = train_data.reset_index(drop=True)\n",
|
|||
|
|
"# print(test_data.tail())\n",
|
|||
|
|
"test_data = test_data.dropna(subset=[col for col in feature_columns if col in train_data.columns])\n",
|
|||
|
|
"# test_data = test_data.dropna(subset=['label'])\n",
|
|||
|
|
"test_data = test_data.reset_index(drop=True)\n",
|
|||
|
|
"\n",
|
|||
|
|
"transform_feature_columns = feature_columns\n",
|
|||
|
|
"transform_feature_columns = [col for col in transform_feature_columns if col in feature_columns and not col.startswith('cat') and col in train_data.columns]\n",
|
|||
|
|
"# transform_feature_columns.remove('undist_profit_ps')\n",
|
|||
|
|
"print('去除极值')\n",
|
|||
|
|
"cs_mad_filter(train_data, transform_feature_columns)\n",
|
|||
|
|
"# print('中性化')\n",
|
|||
|
|
"# cs_neutralize_industry_cap(train_data, transform_feature_columns)\n",
|
|||
|
|
"# print('标准化')\n",
|
|||
|
|
"# cs_zscore_standardize(train_data, transform_feature_columns)\n",
|
|||
|
|
"\n",
|
|||
|
|
"cs_mad_filter(test_data, transform_feature_columns)\n",
|
|||
|
|
"# cs_neutralize_industry_cap(test_data, transform_feature_columns)\n",
|
|||
|
|
"# cs_zscore_standardize(test_data, transform_feature_columns)\n",
|
|||
|
|
"\n",
|
|||
|
|
"mad_filter_feature_columns = [col for col in feature_columns if col not in transform_feature_columns and not col.startswith('cat') and col in train_data.columns]\n",
|
|||
|
|
"cs_mad_filter(train_data, mad_filter_feature_columns)\n",
|
|||
|
|
"cs_mad_filter(test_data, mad_filter_feature_columns)\n",
|
|||
|
|
"\n",
|
|||
|
|
"\n",
|
|||
|
|
"print(f'feature_columns: {feature_columns}')\n",
|
|||
|
|
"\n",
|
|||
|
|
"\n",
|
|||
|
|
"print(f\"df最小日期: {df['trade_date'].min().strftime('%Y-%m-%d')}\")\n",
|
|||
|
|
"print(f\"df最大日期: {df['trade_date'].max().strftime('%Y-%m-%d')}\")\n",
|
|||
|
|
"print(len(train_data))\n",
|
|||
|
|
"print(f\"train_data最小日期: {train_data['trade_date'].min().strftime('%Y-%m-%d')}\")\n",
|
|||
|
|
"print(f\"train_data最大日期: {train_data['trade_date'].max().strftime('%Y-%m-%d')}\")\n",
|
|||
|
|
"print(len(test_data))\n",
|
|||
|
|
"print(f\"test_data最小日期: {test_data['trade_date'].min().strftime('%Y-%m-%d')}\")\n",
|
|||
|
|
"print(f\"test_data最大日期: {test_data['trade_date'].max().strftime('%Y-%m-%d')}\")\n",
|
|||
|
|
"\n",
|
|||
|
|
"cat_columns = [col for col in feature_columns if col.startswith('cat')]\n",
|
|||
|
|
"for col in cat_columns:\n",
|
|||
|
|
" train_data[col] = train_data[col].astype('category')\n",
|
|||
|
|
" test_data[col] = test_data[col].astype('category')\n",
|
|||
|
|
"\n",
|
|||
|
|
"print(df[['ts_code', 'trade_date', 'log_circ_mv']].head(3))\n"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"execution_count": 16,
|
|||
|
|
"id": "3ff2d1c5",
|
|||
|
|
"metadata": {},
|
|||
|
|
"outputs": [],
|
|||
|
|
"source": [
|
|||
|
|
"from sklearn.preprocessing import StandardScaler\n",
|
|||
|
|
"from sklearn.linear_model import LogisticRegression\n",
|
|||
|
|
"import matplotlib.pyplot as plt # 保持 matplotlib 导入,尽管LightGBM的绘图功能已移除\n",
|
|||
|
|
"from sklearn.decomposition import PCA\n",
|
|||
|
|
"import pandas as pd\n",
|
|||
|
|
"import numpy as np\n",
|
|||
|
|
"import datetime # 用于日期计算\n",
|
|||
|
|
"from catboost import CatBoostClassifier\n",
|
|||
|
|
"from catboost import Pool\n",
|
|||
|
|
"import lightgbm as lgb\n",
|
|||
|
|
"\n",
|
|||
|
|
"def train_model(train_data_df, feature_columns,\n",
|
|||
|
|
" print_info=True, # 调整参数名,更通用\n",
|
|||
|
|
" validation_days=180, use_pca=False, split_date=None,\n",
|
|||
|
|
" target_column='label', type='light'): # 增加目标列参数\n",
|
|||
|
|
"\n",
|
|||
|
|
" print('train data size: ', len(train_data_df))\n",
|
|||
|
|
" print(train_data_df[['ts_code', 'trade_date', 'log_circ_mv']])\n",
|
|||
|
|
" # 确保数据按时间排序\n",
|
|||
|
|
" train_data_df = train_data_df.sort_values(by='trade_date')\n",
|
|||
|
|
"\n",
|
|||
|
|
" # 识别数值型特征列\n",
|
|||
|
|
" numeric_feature_columns = train_data_df[feature_columns].select_dtypes(include=['float64', 'int64']).columns.tolist()\n",
|
|||
|
|
"\n",
|
|||
|
|
" # 去除标签为空的样本\n",
|
|||
|
|
" initial_len = len(train_data_df)\n",
|
|||
|
|
" train_data_df = train_data_df.dropna(subset=[target_column])\n",
|
|||
|
|
"\n",
|
|||
|
|
" if print_info:\n",
|
|||
|
|
" print(f'原始样本数: {initial_len}, 去除标签为空后样本数: {len(train_data_df)}')\n",
|
|||
|
|
"\n",
|
|||
|
|
" # 提取特征和标签,只取数值型特征用于线性回归\n",
|
|||
|
|
" \n",
|
|||
|
|
" if split_date is None:\n",
|
|||
|
|
" all_dates = train_data_df['trade_date'].unique() # 获取所有唯一的 trade_date\n",
|
|||
|
|
" split_date = all_dates[-validation_days] # 划分点为倒数第 validation_days 天\n",
|
|||
|
|
" train_data_split = train_data_df[train_data_df['trade_date'] < split_date] # 训练集\n",
|
|||
|
|
" val_data_split = train_data_df[train_data_df['trade_date'] >= split_date] # 验证集\n",
|
|||
|
|
" \n",
|
|||
|
|
" X_train = train_data_split[feature_columns]\n",
|
|||
|
|
" y_train = train_data_split[target_column]\n",
|
|||
|
|
" \n",
|
|||
|
|
" X_val = val_data_split[feature_columns]\n",
|
|||
|
|
" y_val = val_data_split['label']\n",
|
|||
|
|
"\n",
|
|||
|
|
"\n",
|
|||
|
|
" # # 标准化数值特征 (使用 StandardScaler 对训练集fit并transform, 对验证集只transform)\n",
|
|||
|
|
" scaler = StandardScaler()\n",
|
|||
|
|
" # X_train = scaler.fit_transform(X_train)\n",
|
|||
|
|
"\n",
|
|||
|
|
" # 训练线性回归模型\n",
|
|||
|
|
" # model = LogisticRegression(random_state=42)\n",
|
|||
|
|
" \n",
|
|||
|
|
" # # 使用处理后的特征和样本权重进行训练\n",
|
|||
|
|
" # model.fit(X_train, y_train)\n",
|
|||
|
|
"\n",
|
|||
|
|
"\n",
|
|||
|
|
" if type == 'cat':\n",
|
|||
|
|
" params = {\n",
|
|||
|
|
" 'loss_function': 'Logloss', # 适用于二分类\n",
|
|||
|
|
" 'eval_metric': 'Logloss', # 评估指标\n",
|
|||
|
|
" 'iterations': 1500,\n",
|
|||
|
|
" 'learning_rate': 0.01,\n",
|
|||
|
|
" 'depth': 10, # 控制模型复杂度\n",
|
|||
|
|
" 'l2_leaf_reg': 50, # L2 正则化\n",
|
|||
|
|
" 'verbose': 5000,\n",
|
|||
|
|
" 'early_stopping_rounds': 3000,\n",
|
|||
|
|
" 'one_hot_max_size': 50,\n",
|
|||
|
|
" 'class_weights': [0.6, 1.2],\n",
|
|||
|
|
" 'task_type': 'GPU',\n",
|
|||
|
|
" 'has_time': True,\n",
|
|||
|
|
" 'random_seed': 7\n",
|
|||
|
|
" }\n",
|
|||
|
|
" cat_features = [i for i, col in enumerate(feature_columns) if col.startswith('cat')]\n",
|
|||
|
|
" train_pool = Pool(data=X_train, label=y_train, cat_features=cat_features)\n",
|
|||
|
|
" val_pool = Pool(data=X_val, label=y_val, cat_features=cat_features)\n",
|
|||
|
|
"\n",
|
|||
|
|
"\n",
|
|||
|
|
" model = CatBoostClassifier(**params)\n",
|
|||
|
|
" model.fit(train_pool,\n",
|
|||
|
|
" eval_set=val_pool, \n",
|
|||
|
|
" plot=True, \n",
|
|||
|
|
" use_best_model=True\n",
|
|||
|
|
" )\n",
|
|||
|
|
" elif type == 'light':\n",
|
|||
|
|
" params = {\n",
|
|||
|
|
" 'objective': 'binary',\n",
|
|||
|
|
" 'metric': 'average_precision',\n",
|
|||
|
|
" 'learning_rate': 0.01,\n",
|
|||
|
|
" 'is_unbalance': True,\n",
|
|||
|
|
" 'num_leaves': 2048,\n",
|
|||
|
|
" 'min_data_in_leaf': 1024,\n",
|
|||
|
|
" 'max_depth': 32,\n",
|
|||
|
|
" 'max_bin': 1024,\n",
|
|||
|
|
" 'feature_fraction': 0.5,\n",
|
|||
|
|
" 'bagging_fraction': 0.5,\n",
|
|||
|
|
" 'bagging_freq': 1,\n",
|
|||
|
|
" 'lambda_l1': 50,\n",
|
|||
|
|
" 'lambda_l2': 50,\n",
|
|||
|
|
" 'verbosity': -1,\n",
|
|||
|
|
" 'num_threads' : 8\n",
|
|||
|
|
" }\n",
|
|||
|
|
" categorical_feature = [col for col in feature_columns if 'cat' in col]\n",
|
|||
|
|
" train_dataset = lgb.Dataset(\n",
|
|||
|
|
" X_train, label=y_train,\n",
|
|||
|
|
" categorical_feature=categorical_feature\n",
|
|||
|
|
" )\n",
|
|||
|
|
" val_dataset = lgb.Dataset(\n",
|
|||
|
|
" X_val, label=y_val,\n",
|
|||
|
|
" categorical_feature=categorical_feature\n",
|
|||
|
|
" )\n",
|
|||
|
|
"\n",
|
|||
|
|
" evals = {}\n",
|
|||
|
|
" callbacks = [lgb.log_evaluation(period=1000),\n",
|
|||
|
|
" lgb.callback.record_evaluation(evals),\n",
|
|||
|
|
" lgb.early_stopping(100, first_metric_only=True)\n",
|
|||
|
|
" ]\n",
|
|||
|
|
" # 训练模型\n",
|
|||
|
|
" model = lgb.train(\n",
|
|||
|
|
" params, train_dataset, num_boost_round=1000,\n",
|
|||
|
|
" valid_sets=[train_dataset, val_dataset], valid_names=['train', 'valid'],\n",
|
|||
|
|
" callbacks=callbacks\n",
|
|||
|
|
" )\n",
|
|||
|
|
"\n",
|
|||
|
|
" # 打印特征重要性(如果需要)\n",
|
|||
|
|
" if True:\n",
|
|||
|
|
" lgb.plot_metric(evals)\n",
|
|||
|
|
" lgb.plot_importance(model, importance_type='split', max_num_features=20)\n",
|
|||
|
|
" plt.show()\n",
|
|||
|
|
"\n",
|
|||
|
|
"\n",
|
|||
|
|
" return model, scaler, None # 返回训练好的模型、scaler 和 pca 对象"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"execution_count": 47,
|
|||
|
|
"id": "c6eb5cd4-e714-420a-ac48-39af3e11ee81",
|
|||
|
|
"metadata": {
|
|||
|
|
"ExecuteTime": {
|
|||
|
|
"end_time": "2025-04-03T15:03:18.426481Z",
|
|||
|
|
"start_time": "2025-04-03T15:02:19.926352Z"
|
|||
|
|
}
|
|||
|
|
},
|
|||
|
|
"outputs": [
|
|||
|
|
{
|
|||
|
|
"name": "stdout",
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"text": [
|
|||
|
|
"train data size: 218400\n",
|
|||
|
|
" ts_code trade_date log_circ_mv\n",
|
|||
|
|
"0 600306.SH 2020-01-02 11.552040\n",
|
|||
|
|
"1 603269.SH 2020-01-02 11.324801\n",
|
|||
|
|
"2 002633.SZ 2020-01-02 11.759023\n",
|
|||
|
|
"3 603991.SH 2020-01-02 11.181150\n",
|
|||
|
|
"4 000691.SZ 2020-01-02 11.677910\n",
|
|||
|
|
"... ... ... ...\n",
|
|||
|
|
"218395 001207.SZ 2022-12-30 11.385045\n",
|
|||
|
|
"218396 002377.SZ 2022-12-30 12.425814\n",
|
|||
|
|
"218397 600714.SH 2022-12-30 12.427457\n",
|
|||
|
|
"218398 002521.SZ 2022-12-30 12.223073\n",
|
|||
|
|
"218399 600322.SH 2022-12-30 12.428769\n",
|
|||
|
|
"\n",
|
|||
|
|
"[218400 rows x 3 columns]\n",
|
|||
|
|
"原始样本数: 218400, 去除标签为空后样本数: 218400\n",
|
|||
|
|
"Training until validation scores don't improve for 100 rounds\n",
|
|||
|
|
"Early stopping, best iteration is:\n",
|
|||
|
|
"[620]\ttrain's average_precision: 0.379234\tvalid's average_precision: 0.304848\n",
|
|||
|
|
"Evaluated only: average_precision\n"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"data": {
|
|||
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAkMAAAHGCAYAAAB3rI9tAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjEsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvc2/+5QAAAAlwSFlzAAAPYQAAD2EBqD+naQAAdG5JREFUeJzt3Xd8U+X+B/BPdroX3aVAd4GWlmnZSxkiIF4FRVygIjJ+qIjXxVWZXkXvBVRwgluRoXAVZYgDhAKFltIWCi1QunfaNM06vz8CgRhGmwZSms/79eIFOSvP+RKaD+c853lEgiAIICIiInJSYkc3gIiIiMiRGIaIiIjIqTEMERERkVNjGCIiIiKnxjBERERETo1hiIiIiJwawxARERE5NYYhIiIicmoMQ0REROTUGIaIWqENGzYgNjYW06dPNy/76KOPEBsbi+eee+6Gt2ffvn2IjY29Ye83ZcoUrFixwu7H3bBhA4YOHWr347Y2BQUFiI2NRUFBgc3HeO655xzyWSNyBKmjG0BEV5aTk2P+c3Z2tk3H2L59OwBg+PDhNrejS5cuWL9+vc37txZDhgxBly5dHNqGTz75BH369EF8fPx1e4+AgACsX78eAQEBNh9j5syZdmwRUevGMETUihUWFqK2thaenp4Wwag57BGG3N3dkZCQYPP+rYWPjw98fHwc2oZ169bB09PzuoYhuVze4r+vsLAwO7WGqPXjbTKiVio0NBT+/v7IycmBTqfDyZMnkZiY6OhmERG1OQxDRK1YbGwssrOzcerUKYjFYnTs2NFi/bZt23DHHXcgMTER48ePx969e83rhg4ditjYWGzcuBEbN25EbGwsYmNjsW/fPvM2F/oCGY1GvP/++7jtttvwzjvvWLXjan2GcnNz8fDDDyMxMRGDBw/G22+/Db1e36zzXLNmDfr164fu3bvj9ddfhyAI5nUrVqzAlClTLLb/e9+fC6+1Wi3eeOMNDB48GJs2bbJ6nyv1GbrQP2bnzp0YNWoUkpKSMG3aNFRWVpq3KSwsxNSpU5GcnIy7774bq1atwuDBg/Hpp59e8/wu9OGJjY3FuXPn8M9//tP8+lIXzrW2thYvv/wy+vXrh/3795vXGwwGLF++HAMHDkRycjImTZqEzMzMK77f3/sMXTj/U6dOYfLkyejWrRvuuOMOZGRkXLEmfzd06FBs2LABq1evRr9+/dCrVy+89tprFn9n+/fvxx133IHu3btj9uzZWLhwIfr06WPzrV6i641hiKgVi4mJQU5ODrKzsxEZGQmJRGJet2/fPsyZMwfDhw/Hhx9+iISEBDz66KM4efIkAODdd9/F+vXrMWTIEAwZMgTr16/H+vXrL9tn5rXXXsOWLVswadIk9O/fv8ntKykpweTJkyEIAt577z3MmTMHa9euxbvvvtvkY2zduhXLly/HpEmTsGLFCmRkZCAtLa3J+19q1qxZSEtLw8MPP9zs20RHjx7Fv/71L0yfPh2LFy9GWloa3n//ffP6F154ASKRCKtXr0ZERATWrl2LFStWNKlD9oU+POvXr4e/vz9mzpxpfv13Op0ODz74IEpLS/HEE0+gQ4cO5nVr1qzBunXr8NRTT2HNmjXw9/fHnDlzmnWe9fX1mDZtGgYMGGAOvgsWLGjWMdauXYsdO3Zg0aJFmDZtGj777DP8+uuvAIC6ujrMnDkTAwcOxDvvvIOCggLk5OTg/fff5603arXYZ4ioFYuJicHnn38OT09PxMTEWKxbuXIlhgwZYv4y7NGjB3755Rds3boVs2fPNl918Pb2BoCrhoOsrCx8/fXXUCqVzWrfF198AUEQsHLlSri7uwMAVCoVysrKmnyMtWvXYuDAgZg1axYAIC4uDoMHD25WOwDg3Llz6NKlCz799FOIxc3/f15ubi6+/fZbc51SU1MtrmQcPnwYb7/9Nnr37g0/Pz9s2rQJgYGBTeqkfGkfHrlcjtDQ0Cv+faSlpeGBBx7ACy+8YLWuc+fOWLFiBQYMGAAAKC8vx88//4yKigr4+fk16Tyrq6sxffp0PPzwwwCAGTNm4JlnnmnSvhdUVFRg27ZtcHNzw+DBg/H9998jOzsbQ4YMQV5eHmpqajBr1iwolUr84x//wNq1a3mLl1o1hiGiViwmJga5ubnw8PDAgAEDcPz4cfO648ePo7q62upWy+nTp5v9Ps8//3yzgxAAHDt2DLGxseYgBAAPPPBAs46Rn59vsY+fnx8iIiKuuo/RaLRaJpPJ8Pzzz9sUhAAgKSnJIqD4+voiNzfX/LpTp0747bff0Lt3b+zatQve3t5o166dTe91NT4+Ppg7d+5l1w0YMABbt27FvHnzkJaWhnPnzgEAGhoamnx8sViMe++91/za19e32bc1J0yYADc3t8seIywsDHK5HDt37sSgQYPw559/IioqqlnHJ7rRGIaIWrGoqChotVrs2bMHU6dOtQhDAHDffffhnnvusVjm4eHR7Pex9X/tl/YTuaC8vBz5+fno3r17k4KJ0Wi0uP0HwOr13xUXF1stCwgIQHBw8DXf70rat29/1fXx8fH47rvv8Nlnn8HDwwP//ve/bQ5eVxMdHQ1XV9fLrvu///s/HDhwAJMmTcKoUaMQHByM8ePHN+v4AQEBNgXfS12tVi4uLoiKisKzzz4LnU6HTp06Yc2aNS16P6LrjWGIqBVTKBQIDw9HXl6e1W2y6OholJWVWTyivXLlSvj4+GDy5MnmZXK5HGq1+rq0r3Pnzvjqq69QV1dnvjr0xRdf4PPPP7foqH014eHhOHr0qPl1dXU1Tp06hSFDhgAwBaNLr3wYjUZs27bNjmcB8/tcSWZmJnbu3Il9+/ahsLAQISEhUCgUNr2PXC6HwWBo9n4qlQrbtm3D4sWLcddddwEAfvvtt2Yf51pBs6XH+PLLL9G+fXusW7cOVVVVCAsLuy6hkcie+AklauViYmLg4+Nj1TflySefxPbt2/HWW28hNTUVa9aswapVq+Dv72+xXbdu3bB371789ttv2Lt3L7755hu7te2+++4DYBqg788//8TGjRvx2WefYeLEiU0+xuTJk7Fz5068++672Lt3L55++mloNBrz+ri4OGRnZyM9PR1qtRoLFy5ERUWF3c6hKSQSCWpqavDll1+isrISeXl5NrehW7du2Lx5Mw4cOICdO3eaOx5fi0KhgIuLC7Zv344DBw7g448/xlNPPQUANoWr60UsFuPQoUP47bffUFVVhZMnT6Kurs7RzSK6KoYholYuJibG6qoQAKSkpGD58uXYsWMHHnnkEWzcuBGLFi3CbbfdZrHdnXfeiZEjR+KZZ57BY489hvT0dLu1LSgoCJ999hkA4IknnsCKFSvwwAMPmDtDN8WECRMwb948fPvtt5gxYwbCwsLQvXt38/ohQ4Zg4sSJmDp1Km699VZIJBLMnj3bbufQFFFRUejatSvee+89PPzwwxg3bhz69u2LBx54wCK4NcUzzzwDV1dXTJ06Fc8++2yTO5vL5XK8+eabyMvLw0MPPYTNmzfjlVdegVQqxcGDB205reti5MiR0Gg0WLRoESZPnowxY8agV69eeP311x3dNKIrEgmXu+lPRERmb7/9Nv744w88/fTTcHNzg06nw4EDB7B8+XJs3rwZcXFxjm5iq3HfffehQ4cOuOuuuyCXy1FfX4/169fj999/txgziag1YRgiIrqGU6dO4fXXX0d6ejpqa2shk8kQERGB8ePHWw0I6ex27dqF1atXIzc3F2q1Gh4eHujSpQumTp2Kfv36Obp5RJfFMEREREROjX2GiIiIyKkxDBEREZFTYxgiIiIip8YwRERERE6NI1A3QVpaGgRBgEwmc3RTiIiIqIl0Oh1EIhGSk5Ovuh2vDDWBIAjmX2Sqh1arZT0uwZpYY02ssSbWWBNrrIk1W2vS1O9uXhlqAplMBq1Wi6ioqCtOoOhM1Go1srKyWI9LsCbWWBNrrIk11sQaa2LN1ppkZGQ0aTteGSIiIiKnxjBERERETo1hiIiIiJwawxARERE5NXagJiIisiODwQCdTmfz/o2NjebfxWJeswAuXxOZTAaJRGKX4zM
|
|||
|
|
"text/plain": [
|
|||
|
|
"<Figure size 640x480 with 1 Axes>"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
"metadata": {},
|
|||
|
|
"output_type": "display_data"
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"data": {
|
|||
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAwMAAAHGCAYAAAAomFcNAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjEsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvc2/+5QAAAAlwSFlzAAAPYQAAD2EBqD+naQABAABJREFUeJzs3Xlcjen/+PFXh5YTIiORtWISChNmrEP2Lfsy9rGGMBFjXyvLWKJs2aexk8QYDGYYn0G2IWSXpci+pKOTOr8/+nV/nYlxyiHL+/l4nIdzrvu+r/u63+V0v+/7uq7bRKfT6RBCCCGEEEJ8dlRZ3QAhhBBCCCFE1pBkQAghhBBCiM+UJANCCCGEEEJ8piQZEEIIIYQQ4jMlyYAQQgghhBCfKUkGhBBCCCGE+ExJMiCEEEIIIcRnSpIBIYQQQgghPlOSDAghhBBCCPGZkmRACCHEJ+Xw4cM4OTm98hUYGJjVzcsSN2/exMnJiZs3b2Z1U4QQH5jsWd0AIYQQ4l2YMWMGxYsX1yvLnz+/UfcRFRXF4cOH6d69u1HrNbb8+fOzceNGox+/MT158oSVK1fSrVs3rKyssro5Qnw25M6AEEKIT1KJEiVwcXHRe9na2hp1H1FRUfz8889GrfNdMDMzw8XFBTMzs6xuyms9efKEoKAgnjx5ktVNEeKzIsmAEEIIIYQQnylJBoQQQnyWVq9eTb169ShfvjzfffcdZ8+e1Vv+119/0apVK8qVK4e7uzsrVqxQlo0YMQInJydGjhxJTEzMK8ckuLu7Exoaqldnly5d9NZJ+3zx4kU8PT2pWLGi3vqPHj3ixx9/pFKlSlSpUoVx48aRkJCQ4WN93ZgBJycnFixYQO3atalZsyb79u2jadOmVK5cmb179yrbLVu2jAYNGvDVV1/Rr18/4uLi9OrZt28fzZo1o2zZsnh4eLB//3695YGBgXTp0oUnT54wbtw4qlWrRkREBAChoaE4OTlRp04dAOrUqYOTkxNdunTRq+O/fh5p40Tu3r2Lp6cn5cuXp169enrt0Ol0LFq0iG+//ZYKFSrQuXNnTp8+rbePiIgI2rZti6urK40aNWLbtm0ZjrUQHxtJBoQQQnx2Nm/ejL+/P506dSI4OBhra2u6devGw4cPAbhx4wb9+/enTJkyLF26lF69ejFt2jSOHj0KgJeXFxs3bsTLywsbGxs2btzIxo0badeuXYbbcv36dbp27UrBggX54Ycf9JYNHDiQM2fO8NNPPzFhwgT27t3L+PHj3/r4X7Zt2zYmT57Mixcv+OGHH5TjXrt2rbJOUFAQvXr1YsaMGVy9epV+/fqh0+kAOHjwIJ6enlSoUIHFixdTvnx5PD09OXz4sN5+kpKS6NatG3fu3KFfv34UK1YMgNq1a7Nx40YWLFgAwIIFC9i4cSMTJ05Utn3TzyNN7969sbe3Z8GCBdjZ2fHjjz+SkpICwKxZs1iwYAE9e/Zk4cKF5M6dm549eyo/86tXr9KjRw+cnZ1ZunQpDRo0YOjQoRw8eNCo8RbiQyMDiIUQQnySWrRoofd548aNuLi4AKlXqjt06KAM/HVycqJKlSrs3buX1q1bk5KSwrhx4/Dw8MDc3JxSpUoxf/58Tpw4QcWKFSlcuDCFCxfm4sWLSn/8zAoPD2fevHnUrVtXrzwiIoKIiAg2b95M6dKlAYiLi2P69On4+fkZrf+/p6cn1atXx8HBAXt7exo3bsylS5c4cuSIsk7Pnj1p27YtAJaWlnTr1o3jx4/j5ubGvHnz+Oqrr5g0aRIAVapU4cqVKwQFBfH1118rdZw4cYKuXbsyevRovf1bW1tjbW2t3LX48ssvKVy4sN46b/p5pKlYsSI//vgjADly5KBt27bcvXuXnDlzsmLFCvr160fXrl0BcHBwYPLkydy6dQtra2uCg4NxdHRUjqNSpUrs27ePsLAwqlSpYpRYC/EhkmRACCHEJ2n27NnY29srn9Pex8fHExMTQ0hICCEhIXrbXLt2DYBixYrx7Nkz5syZw9GjRzl79izJyck8f/78rdqUdpX6ZTVr1kyXCACcP38egJYtW6ZbFhMTo3dsbyNthiETExO99y/76quvlPdpic/169dxc3Pj9OnT9OzZU2/9KlWqsGTJEr0ya2trvL29M9VGQ38enTp1Ut7nzZsXSL0jcfnyZbRaLW5ubspyGxsb5s6dq3w+f/48586dw8nJSa9Oc3PzTLVZiI+FJANCCCE+Sfb29jg7O792+Q8//ECtWrX0ytJOIHfv3s2gQYNo0qQJHTt2xNXV1Sjdc27dupWuzNXV9bXrZ8uWjY0bN6Y7Obezs3vrtmRWWvegtDalfX7demlKliyJpaVlpvZp6M+jaNGiBrUlzalTp8ibN69yJ6Ju3bp4eXnprWNhYZGpNgvxsZAxA0IIIT4rOXPmxM7OjkePHuHs7Ky8NmzYwLFjx4DUMQVfffUVP/30Ey1atKBQoULExMSkq8vc3JwXL168cj/ZsmVDo9Eon0+dOvXKOl6nZMmSJCcno1KplDaqVCqWLl3K48ePM3jUbydtsC/AyZMnAZRnOLi4uKQbH3Do0KEMd51KuwKfnJycbpmhP49s2bK9sm5HR0dMTU31xhgkJibSqVMn/vjjDyA13rdu3dL7nTh+/Di//vprho5DiI+N3BkQQgjx2RkwYAATJ07ExsaG8uXL8/vvv7Nu3TqaN28OpHZpOXbsGPv27SMxMZElS5YQExOT7sS/TJky3L9/nw0bNuDg4MDx48fp2bMnKpWKUqVKsW3bNlq2bMnt27cZOXIkefLkMbiN33zzDZUqVcLHx4fBgwdjYWHB7NmzefHiBTY2NsYMxxutXLmSggUL8sUXXzBt2jRcXV0pX748kBrLHj16MGHCBBo0aMDOnTs5duwYy5cvz9A+bGxssLOz4+eff6Zp06ZER0fj4uJCiRIlDP55vE7OnDnp2rUrixcvJleuXHz55ZesWbMGS0tLGjRoAECfPn1o3rw5Y8eOpWnTply9epVp06ZlumuTEB8LSQaEEEJ8dtq0acPz589ZsWIFc+fOxdHRkXnz5lGuXDkABg8eTFxcHIMHDyZnzpy0aNECGxsbjh8/rldP8eLFmThxIkFBQdy9exd7e3ul//ywYcMYPnw41apVo2DBggwZMoSVK1dmqJ1z585lypQpjBw5EhMTE6pVq6a8f5+GDBnC8uXLuXnzJpUrV1YG2ULq+ICFCxcyY8YMNm7ciL29PQsXLtQbPGyo2bNnM2HCBNatW0fevHlZtGgRYPjP47/4+PiQK1cuFi9eTHx8PK6urqxYsUIZJ+Ho6MjixYuZOXMmmzdvJn/+/AwcOJDvv/8+w8chxMfERPe6jnRCCCGE+KzdvHmTOnXq6M3EJIT4tMiYASGEEEIIIT5TcmdACCGEEEKIz5TcGRBCCCGEEOIzJcmAEEIIIYQQnylJBoQQQgghhPhMSTIghBBCCCHEZ0qeMyCEeK0TJ06g0+kwNTXN6qYIIYQQwkBJSUmYmJhQoUKFN64rdwaEEK+l0+mUl3h7Op0OrVYr8TQSiadxSTyNT2JqXBJPw2Xkb7fcGRBCvJapqSlarZYSJUpgaWmZ1c356CUkJBAVFSXxNBKJp3FJPI1PYmpcEk/DRUZGGryu3BkQQgghhBDiMyXJgBBCCCGEEJ8pSQaEEEIIIYT4TEkyIIQQQgghxGdKkgEhhBBCCCE+U5IMCCGEEEII8ZmSZEAIIYQQQojPlCQDQgghhBBCfKYkGRBCCCGEEOIzJU8gFkIIIYQQwgCHDh2iW7du6crPnDlD9uzZuXv3Lr/88gtXr16lUKFCdOrUicKFC6db//z587Rq1YqZM2fSsGHD99H015JkQAghhBBCCAOcOXOGsmXLMmHCBL3y7Nmz8+DBA9q0aUPRokX56quv+PPPPwkNDWX79u188cUXyro6nY7x48dTsWLFLE8EQLoJiXfg8OHDODk5pSvv0qULgYGBWdCid+fevXv07NmT8uXL89VXXxESEqK3PDAwkC5dumRR64QQQghhTGfPnsXV1RUXFxe9F8C8efMoWrQoK1euxNvbmzV
|
|||
|
|
"text/plain": [
|
|||
|
|
"<Figure size 640x480 with 1 Axes>"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
"metadata": {},
|
|||
|
|
"output_type": "display_data"
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"name": "stdout",
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"text": [
|
|||
|
|
"train data size: 218400\n",
|
|||
|
|
" ts_code trade_date log_circ_mv\n",
|
|||
|
|
"0 600306.SH 2020-01-02 11.552040\n",
|
|||
|
|
"1 603269.SH 2020-01-02 11.324801\n",
|
|||
|
|
"2 002633.SZ 2020-01-02 11.759023\n",
|
|||
|
|
"3 603991.SH 2020-01-02 11.181150\n",
|
|||
|
|
"4 000691.SZ 2020-01-02 11.677910\n",
|
|||
|
|
"... ... ... ...\n",
|
|||
|
|
"218395 001207.SZ 2022-12-30 11.385045\n",
|
|||
|
|
"218396 002377.SZ 2022-12-30 12.425814\n",
|
|||
|
|
"218397 600714.SH 2022-12-30 12.427457\n",
|
|||
|
|
"218398 002521.SZ 2022-12-30 12.223073\n",
|
|||
|
|
"218399 600322.SH 2022-12-30 12.428769\n",
|
|||
|
|
"\n",
|
|||
|
|
"[218400 rows x 3 columns]\n",
|
|||
|
|
"原始样本数: 218400, 去除标签为空后样本数: 218400\n"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"data": {
|
|||
|
|
"application/vnd.jupyter.widget-view+json": {
|
|||
|
|
"model_id": "cc452ab5227045d1944aa8d8ed650f4a",
|
|||
|
|
"version_major": 2,
|
|||
|
|
"version_minor": 0
|
|||
|
|
},
|
|||
|
|
"text/plain": [
|
|||
|
|
"MetricVisualizer(layout=Layout(align_self='stretch', height='500px'))"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
"metadata": {},
|
|||
|
|
"output_type": "display_data"
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"name": "stdout",
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"text": [
|
|||
|
|
"0:\tlearn: 0.6890710\ttest: 0.6898147\tbest: 0.6898147 (0)\ttotal: 54.1ms\tremaining: 1m 21s\n",
|
|||
|
|
"1499:\tlearn: 0.3800534\ttest: 0.5434508\tbest: 0.5419956 (767)\ttotal: 57.4s\tremaining: 0us\n",
|
|||
|
|
"bestTest = 0.5419955563\n",
|
|||
|
|
"bestIteration = 767\n",
|
|||
|
|
"Shrink model to first 768 iterations.\n"
|
|||
|
|
]
|
|||
|
|
}
|
|||
|
|
],
|
|||
|
|
"source": [
|
|||
|
|
"\n",
|
|||
|
|
"gc.collect()\n",
|
|||
|
|
"feature_columns.remove('score1')\n",
|
|||
|
|
"feature_columns.remove('score2')\n",
|
|||
|
|
"\n",
|
|||
|
|
"use_pca = False\n",
|
|||
|
|
"# feature_contri = [2 if feat.startswith('act_factor') or 'buy' in feat or 'sell' in feat else 1 for feat in feature_columns]\n",
|
|||
|
|
"# light_params['feature_contri'] = feature_contri\n",
|
|||
|
|
"# print(f'feature_contri: {feature_contri}')\n",
|
|||
|
|
"model1, scaler, pca = train_model(train_data\n",
|
|||
|
|
" .dropna(subset=['label']).groupby('trade_date', group_keys=False)\n",
|
|||
|
|
" .apply(lambda x: x.nsmallest(300, 'total_mv'))\n",
|
|||
|
|
" .merge(industry_df, on=['cat_l2_code', 'trade_date'], how='left')\n",
|
|||
|
|
" .merge(index_data, on='trade_date', how='left'), feature_columns)\n",
|
|||
|
|
"\n",
|
|||
|
|
"model2, scaler, pca = train_model(train_data\n",
|
|||
|
|
" .dropna(subset=['label']).groupby('trade_date', group_keys=False)\n",
|
|||
|
|
" .apply(lambda x: x.nsmallest(300, 'total_mv'))\n",
|
|||
|
|
" .merge(industry_df, on=['cat_l2_code', 'trade_date'], how='left')\n",
|
|||
|
|
" .merge(index_data, on='trade_date', how='left'), feature_columns, type='cat')\n",
|
|||
|
|
"\n"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"execution_count": 52,
|
|||
|
|
"id": "e82213f0",
|
|||
|
|
"metadata": {},
|
|||
|
|
"outputs": [
|
|||
|
|
{
|
|||
|
|
"name": "stdout",
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"text": [
|
|||
|
|
"train data size: 218400\n",
|
|||
|
|
" ts_code trade_date log_circ_mv\n",
|
|||
|
|
"0 600306.SH 2020-01-02 11.552040\n",
|
|||
|
|
"1 603269.SH 2020-01-02 11.324801\n",
|
|||
|
|
"2 002633.SZ 2020-01-02 11.759023\n",
|
|||
|
|
"3 603991.SH 2020-01-02 11.181150\n",
|
|||
|
|
"4 000691.SZ 2020-01-02 11.677910\n",
|
|||
|
|
"... ... ... ...\n",
|
|||
|
|
"218395 001207.SZ 2022-12-30 11.385045\n",
|
|||
|
|
"218396 002377.SZ 2022-12-30 12.425814\n",
|
|||
|
|
"218397 600714.SH 2022-12-30 12.427457\n",
|
|||
|
|
"218398 002521.SZ 2022-12-30 12.223073\n",
|
|||
|
|
"218399 600322.SH 2022-12-30 12.428769\n",
|
|||
|
|
"\n",
|
|||
|
|
"[218400 rows x 3 columns]\n",
|
|||
|
|
"原始样本数: 218400, 去除标签为空后样本数: 218400\n",
|
|||
|
|
"Training until validation scores don't improve for 100 rounds\n",
|
|||
|
|
"Early stopping, best iteration is:\n",
|
|||
|
|
"[125]\ttrain's average_precision: 0.5023\tvalid's average_precision: 0.291157\n",
|
|||
|
|
"Evaluated only: average_precision\n"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"data": {
|
|||
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAkcAAAHGCAYAAAB+Ry8XAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjEsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvc2/+5QAAAAlwSFlzAAAPYQAAD2EBqD+naQAAcGZJREFUeJzt3XlcVOX+B/DPmX3YF0FBXEFwF1xTc2/TzK3NMsvSyszllpXeLNuz+nXNm5pJq5otVzL16i3LpVVS3BUBUURFUNmZYfaZ8/tjYBrCBYaBYeDzfr14wZyZc+b7cBjmM895znMEURRFEBEREREAQOLpAoiIiIgaE4YjIiIiIicMR0REREROGI6IiIiInDAcERERETlhOCIiIiJywnBERERE5IThiIiIiMgJwxERERGRE4YjIg/ZuHEj4uLiMHPmTMeyTz/9FHFxcVi4cGGD17N3717ExcU12PNNnToVy5cvd/t2N27ciJEjR7p9u41NTk4O4uLikJOT4/I2Fi5c6JG/NaLGTubpAoiau4yMDMfP6enpLm1jx44dAICbbrrJ5Tq6deuGpKQkl9dvLEaMGIFu3bp5tIbPP/8cAwYMQJcuXertOcLDw5GUlITw8HCXtzF79mw3VkTUdDAcEXlYbm4uysrKEBAQUCUo1YY7wpGfnx969Ojh8vqNRXBwMIKDgz1aw9q1axEQEFCv4UihUNR5f0VFRbmpGqKmhYfViDyodevWCAsLQ0ZGBsxmM06fPo2ePXt6uiwiomaN4YjIw+Li4pCeno6srCxIJBK0b9++yv3bt2/HHXfcgZ49e2LChAlITk523Ddy5EjExcXhu+++w3fffYe4uDjExcVh7969jsdUjiWy2Wz46KOPcMstt+CDDz6oVse1xhydOnUKDz/8MHr27Inhw4dj2bJlsFgstWpnYmIiBg8ejN69e+Odd96BKIqO+5YvX46pU6dWefzfxw5V3jaZTHj33XcxfPhwbNq0qdrzXG3MUeX4ml27dmH06NGIj4/HjBkzUFRU5HhMbm4upk+fjoSEBNx9991YuXIlhg8fjnXr1l23fZVjgOLi4nDhwgX885//dNx2VtnWsrIyLF68GIMHD8a+ffsc91utVixduhRDhw5FQkICJk+ejNTU1Ks+39/HHFW2PysrC1OmTEGvXr1wxx134NixY1f9nfzdyJEjsXHjRqxevRqDBw9Gv3798Nprr1XZZ/v27cMdd9yB3r17Y+7cuXj99dcxYMAAlw8NEzUmDEdEHhYbG4uMjAykp6cjOjoaUqnUcd/evXsxb9483HTTTfjkk0/Qo0cPPProozh9+jQAYNWqVUhKSsKIESMwYsQIJCUlISkp6Ypjbl577TVs3boVkydPxo033ljj+i5duoQpU6ZAFEV8+OGHmDdvHtasWYNVq1bVeBvbtm3D0qVLMXnyZCxfvhzHjh3DoUOHary+szlz5uDQoUN4+OGHa31Y6fjx43j55Zcxc+ZMvPnmmzh06BA++ugjx/2LFi2CIAhYvXo1OnbsiDVr1mD58uU1GuBdOQYoKSkJYWFhmD17tuP235nNZjz00EO4fPkynnjiCbRr185xX2JiItauXYunn34aiYmJCAsLw7x582rVzvLycsyYMQNDhgxxBOGXXnqpVttYs2YNdu7ciTfeeAMzZszAF198gZ9//hkAoNVqMXv2bAwdOhQffPABcnJykJGRgY8++oiH6qhJ4JgjIg+LjY3F+vXrERAQgNjY2Cr3rVixAiNGjHC8Ofbp0wc//fQTtm3bhrlz5zp6JYKCggDgmmEhLS0N33zzDVQqVa3q+/LLLyGKIlasWAE/Pz8AgEajQX5+fo23sWbNGgwdOhRz5swBAHTu3BnDhw+vVR0AcOHCBXTr1g3r1q2DRFL7z3anTp3Chg0bHL+nlJSUKj0dhw8fxrJly9C/f3+EhoZi06ZNaNmyZY0GPTuPAVIoFGjduvVV98ehQ4fw4IMPYtGiRdXu69q1K5YvX44hQ4YAAAoKCvDjjz+isLAQoaGhNWpnSUkJZs6ciYcffhgAMGvWLDzzzDM1WrdSYWEhtm/fDl9fXwwfPhxbtmxBeno6RowYgTNnzqC0tBRz5syBSqXCXXfdhTVr1vCQMDUZDEdEHhYbG4tTp07B398fQ4YMwcmTJx33nTx5EiUlJdUOzZw9e7bWz/P888/XOhgBwIkTJxAXF+cIRgDw4IMP1mob2dnZVdYJDQ1Fx44dr7mOzWartkwul+P55593KRgBQHx8fJXAEhISglOnTjlud+jQAb/++iv69++P3bt3IygoCC1atHDpua4lODgYTz311BXvGzJkCLZt24Znn30Whw4dwoULFwAAer2+xtuXSCS47777HLdDQkJqfRh00qRJ8PX1veI2oqKioFAosGvXLgwbNgx//PEHYmJiarV9osaM4YjIw2JiYmAymbBnzx5Mnz69SjgCgPvvvx/33HNPlWX+/v61fh5XP9U7jzOpVFBQgOzsbPTu3btGQcVms1U5XAig2u2/u3jxYrVl4eHhiIiIuO7zXU2bNm2ueX+XLl3w7bff4osvvoC/vz/+7//+z+Ugdi2dOnWCj4/PFe/7xz/+gf3792Py5MkYPXo0IiIiMGHChFptPzw83KUg7Oxavyu1Wo2YmBg899xzMJvN6NChAxITE+v0fESNCcMRkYcplUq0bdsWZ86cqXZYrVOnTsjPz69ySviKFSsQHByMKVOmOJYpFArodLp6qa9r1674+uuvodVqHb1HX375JdavX19l4Pe1tG3bFsePH3fcLikpQVZWFkaMGAHAHpSce0ZsNhu2b9/uxlbA8TxXk5qail27dmHv3r3Izc1FZGQklEqlS8+jUChgtVprvZ5Go8H27dvx5ptv4s477wQA/Prrr7XezvWCZ1238dVXX6FNmzZYu3YtiouLERUVVS8hkshT+NdM1AjExsYiODi42tiWJ598Ejt27MB7772HlJQUJCYmYuXKlQgLC6vyuF69eiE5ORm//vorkpOT8Z///Mdttd1///0A7BMG/vHHH/juu+/wxRdf4N57763xNqZMmYJdu3Zh1apVSE5Oxvz582EwGBz3d+7cGenp6Th69Ch0Oh1ef/11FBYWuq0NNSGVSlFaWoqvvvoKRUVFOHPmjMs19OrVC5s3b8b+/fuxa9cux0Dm61EqlVCr1dixYwf279+Pzz77DE8//TQAuBS26otEIsHBgwfx66+/ori4GKdPn4ZWq/V0WURuw3BE1AjExsZW6zUCgIEDB2Lp0qXYuXMnHnnkEXz33Xd44403cMstt1R53MSJE3HbbbfhmWeewWOPPYajR4+6rbZWrVrhiy++AAA88cQTWL58OR588EHH4OqamDRpEp599lls2LABs2bNQlRUFHr37u24f8SIEbj33nsxffp03HzzzZBKpZg7d67b2lATMTEx6N69Oz788EM8/PDDGD9+PAYNGoQHH3ywSpCriWeeeQY+Pj6YPn06nnvuuRoPXlcoFPjXv/6FM2fOYNq0adi8eTNeeeUVyGQyHDhwwJVm1YvbbrsNBoMBb7zxBqZMmYKxY8eiX79+eOeddzxdGpFbCOKVBhQQETUzy5Ytw++//4758+fD19cXZrMZ+/fvx9KlS7F582Z07tzZ0yU2Gvfffz/atWuHO++8EwqFAuXl5UhKSsJvv/1WZc4mIm/FcEREBCArKwvvvPMOjh49irKyMsjlcnTs2BETJkyoNkFlc7d7926sXr0ap06dgk6ng7+/P7p164bp06dj8ODBni6PqM4YjoiIiIiccMwRERERkROGIyIiIiInDEdEREREThiOiIiIiJxwhmzYLwIpiiLkcrmnSyEiIqIaMpvNEAQBCQkJbt0ue45gv3ZU5VdTI4oiTCYT2+Zl2Dbv1JTbBjTt9rFt3qm+3rvZcwT7lb5NJhNiYmKuejFIb6XT6ZCWlsa2eRm2zTs15bYBTbt9bJt3Onr0KARBcPt22XNERERE5IThiIiIiMgJwxERERGRE4YjIiIiIicckE1ERFRLVqsVZrPZ02XUiNFodHyXSLynT0Qul0MqlXrkuRm
|
|||
|
|
"text/plain": [
|
|||
|
|
"<Figure size 640x480 with 1 Axes>"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
"metadata": {},
|
|||
|
|
"output_type": "display_data"
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"data": {
|
|||
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAtQAAAHGCAYAAABU2xz9AAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjEsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvc2/+5QAAAAlwSFlzAAAPYQAAD2EBqD+naQAA5bdJREFUeJzs3Xd8jff7+PHXOZmH2Ds0RJBBrARFqb0l9ohSakZtMWJH7E1IUWrVVmK0aI1SX2pEVRAjJGQYsUWOJJLz+yO/3B+nCSLjRLiej0ceSe7xvq/7OhnXed/v+32rdDqdDiGEEEIIIUSaqLM6ACGEEEIIIbIzKaiFEEIIIYRIBymohRBCCCGESAcpqIUQQgghhEgHKaiFEEIIIYRIBymohRBCCCGESAcpqIUQQgghhEgHKaiFEEIIIYRIBymohRBCCCGESAcpqIUQQqTo9OnT2Nrapvjh4+OT1eFlibCwMGxtbQkLC8vqUIQQHxHjrA5ACCHEx23evHmUKlVKb1nhwoUz9BiBgYGcPn2anj17Zmi7Ga1w4cLs2LEjw88/Iz1//px169bx7bffkjt37qwOR4jPgvRQCyGEeKcyZcrg6Oio91GkSJEMPUZgYCDr16/P0DYzg6mpKY6OjpiammZ1KG/1/Plzli5dyvPnz7M6FCE+G1JQCyGEEEIIkQ5SUAshhEiXTZs20bhxYypXrkzXrl25cuWK3vq//vqLdu3aUalSJRo0aMDatWuVdWPHjsXW1hZPT0/Cw8NTHKPdoEEDdu7cqddm9+7d9bZJ+v7GjRsMGDAAZ2dnve2fPn3KmDFjqFatGjVr1mTSpElER0d/8Lm+bQy1ra0tP/zwA/Xr16du3bocO3aMVq1aUb16dY4cOaLs99NPP9G0aVOqVq2Ku7s79+/f12vn2LFjtG7dmgoVKuDi4sLx48f11vv4+NC9e3eeP3/OpEmTqF27NmfOnAFg586d2Nra0rBhQwAaNmyIra0t3bt312vjXa9H0rj5yMhIBgwYQOXKlWncuLFeHDqdjhUrVvD1119TpUoVvvnmGy5duqR3jDNnztCxY0cqVqxI8+bN2bdv3wfnWojsRApqIYQQabZr1y5mzJhBt27dWLlyJfny5ePbb7/lyZMnAISGhjJw4EDKly/P6tWr6dOnD7Nnz+bcuXMADBo0iB07djBo0CAKFSrEjh072LFjB506dfrgWO7cuUOPHj0oVqwYw4YN01s3ePBgLl++zNy5c5kyZQpHjhxh8uTJ6T7/N+3btw9vb29ev37NsGHDlPPesmWLss3SpUvp06cP8+bNIzg4GHd3d3Q6HQCnTp1iwIABVKlShR9//JHKlSszYMAATp8+rXecuLg4vv32Wx48eIC7uzslS5YEoH79+uzYsYMffvgBgB9++IEdO3bg5eWl7Pu+1yNJ3759sba25ocffsDS0pIxY8aQkJAAwIIFC/jhhx/o3bs3y5cvJ0+ePPTu3Vt5zYODg/nuu++wt7dn9erVNG3alJEjR3Lq1KkMzbcQHxO5KVEIIcQ7tWnTRu/7HTt24OjoCCT2mHbp0kW5mdDW1paaNWty5MgR2rdvT0JCApMmTcLFxQUzMzPs7Ozw9fXln3/+wdnZmRIlSlCiRAlu3LihjE9Oqz179rBs2TIaNWqkt/zMmTOcOXOGXbt24eDgAMD9+/eZM2cO06dPz7Dx0AMGDOCrr76idOnSWFtb06JFC4KCgjh79qyyTe/evenYsSMAOXLk4Ntvv+X8+fM4OTmxbNkyqlatytSpUwGoWbMmt27dYunSpdSoUUNp459//qFHjx6MHz9e7/j58uUjX758Su95uXLlKFGihN4273s9kjg7OzNmzBgAcubMSceOHYmMjMTCwoK1a9fi7u5Ojx49AChdujTe3t7cvXuXfPnysXLlSmxsbJTzqFatGseOHcPPz4+aNWtmSK6F+NhIQS2EEOKdFi5ciLW1tfJ90tdRUVGEh4ezYcMGNmzYoLfP7du3AShZsiQvX75k8eLFnDt3jitXrhAfH8+rV6/SFVNSb+mb6tatm6yYBrh27RoAbdu2TbYuPDxc79zSI2nmD5VKpff1m6pWrap8nfTm4c6dOzg5OXHp0iV69+6tt33NmjVZtWqV3rJ8+fIxfPjwNMWY2tejW7duytf58+cHEnvGb968SWxsLE5OTsr6QoUKsWTJEuX7a9eucfXqVWxtbfXaNDMzS1PMQmQHUlALIYR4J2tra+zt7d+6ftiwYdSrV09vWVIRdujQIYYMGULLli1xc3OjYsWKGTLU4u7du8mWVaxY8a3bGxkZsWPHjmQFrqWlZbpjSaukoR5JMSV9/7btkpQtW5YcOXKk6ZipfT2srKxSFUuSixcvkj9/fqVHvFGjRgwaNEhvG3Nz8zTFLER2IGOohRBCpImFhQWWlpY8ffoUe3t75WP79u34+/sDiWOsq1atyty5c2nTpg3FixcnPDw8WVtmZma8fv06xeMYGRmh1WqV7y9evJhiG29TtmxZ4uPjUavVSoxqtZrVq1fz7NmzDzzr9Em6gRDg33//BVDm+HZ0dEw2Xvrvv//+4GEwST3B8fHxydal9vUwMjJKsW0bGxtMTEz0xlzHxMTQrVs3jh49CiTm++7du3o/E+fPn+fXX3/9oPMQIjuRHmohhBBp9v333+Pl5UWhQoWoXLkyf/zxB1u3bsXV1RVIHJ7g7+/PsWPHiImJYdWqVYSHhycrnsuXL8+jR4/Yvn07pUuX5vz58/Tu3Ru1Wo2dnR379u2jbdu23Lt3D09PT/LmzZvqGL/88kuqVauGh4cHQ4cOxdzcnIULF/L69WsKFSqUkel4r3Xr1lGsWDEKFCjA7NmzqVixIpUrVwYSc/ndd98xZcoUmjZtysGDB/H392fNmjUfdIxChQphaWnJ+vXradWqFSEhITg6OlKmTJlUvx5vY2FhQY8ePfjxxx/JlSsX5cqVY/PmzeTIkYOmTZsC0K9fP1xdXZk4cSKtWrUiODiY2bNnp3mYihDZgRTUQggh0qxDhw68evWKtWvXsmTJEmxsbFi2bBmVKlUCYOjQody/f5+hQ4diYWFBmzZtKFSoEOfPn9drp1SpUnh5ebF06VIiIyOxtrZWxhOPGjWK0aNHU7t2bYoVK8aIESNYt27dB8W5ZMkSZs6ciaenJyqVitq1aytfG9KIESNYs2YNYWFhVK9eXblxDxLHSy9fvpx58+axY8cOrK2tWb58ud4Niam1cOFCpkyZwtatW8mfPz8rVqwAUv96vIuHhwe5cuXixx9/JCoqiooVK7J27Vpl3LiNjQ0//vgj8+fPZ9euXRQuXJjBgwfTq1evDz4PIbILle5tA6KEEEIIkSHCwsJo2LCh3gwpQohPh4yhFkIIIYQQIh2kh1oIIYQQQoh0kB5qIYQQQggh0kEKaiGEEEIIIdJBCmohhBBCCCHSQQpqIYQQQggh0kHmoRbCAP755x90Oh0mJiZZHYoQQgghUikuLg6VSkWVKlXeuZ30UAthADqdTvkQmUen0xEbGyt5zmSSZ8OQPBuG5NkwsmueU/u/W3qohTAAExMTYmNjKVOmDDly5MjqcD5Z0dHRBAYGSp4zmeTZMCTPhiF5NozsmueAgIBUbSc91EIIIYQQQqSDFNRCCCGEEEKkgxTUQgghhBBCpIMU1EIIIYQQQqSDFNRCCCGEEEKkgxTUQgghhBBCpIMU1EIIIYQQQqSDFNRCCCGEEEKkgxTUQgghhBBCpIMU1EIIIYQQwqD+/vtvbG1tk328fv1aWd+mTRuqVq3KkCFDePbsWYrteHh44OPjY8jQUyQFtRBCCCGEMKjLly9ToUIFduzYofdhbGzMjRs36N+/P6VKlWLJkiXExMTg4eGRrI2VK1eyd+/eLIg+OeOsDkCIj1VYWBiTJk3i3LlzqNVqmjdvzuTJkzE3N8/q0IQQQohs7cqVK1SsWBFHR8dk65YvX46VlRULFixArVZTqVIl6taty8WLF6lYsSIACxcuZOvWrRQvXtzQoadIeqiFSEFCQgLu7u4UKVKEAwcOsH79eo4dO8bKlSvT1a5KpcqgCEVKVCo
|
|||
|
|
"text/plain": [
|
|||
|
|
"<Figure size 640x480 with 1 Axes>"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
"metadata": {},
|
|||
|
|
"output_type": "display_data"
|
|||
|
|
}
|
|||
|
|
],
|
|||
|
|
"source": [
|
|||
|
|
"score_df = train_data.dropna(subset=['label']).groupby('trade_date', group_keys=False).apply(lambda x: x.nsmallest(300, 'total_mv')).merge(industry_df, on=['cat_l2_code', 'trade_date'], how='left').merge(index_data, on='trade_date', how='left')\n",
|
|||
|
|
"score_df['score1'] = model1.predict(score_df[feature_columns])\n",
|
|||
|
|
"score_df['score2'] = model2.predict_proba(score_df[feature_columns])[:, 1]\n",
|
|||
|
|
"\n",
|
|||
|
|
"if 'score1' not in feature_columns:\n",
|
|||
|
|
" feature_columns.append('score1')\n",
|
|||
|
|
"if 'score2' not in feature_columns:\n",
|
|||
|
|
" feature_columns.append('score2')\n",
|
|||
|
|
"\n",
|
|||
|
|
"model3, scaler, pca = train_model(score_df, feature_columns)"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"execution_count": 54,
|
|||
|
|
"id": "5d1522a7538db91b",
|
|||
|
|
"metadata": {
|
|||
|
|
"ExecuteTime": {
|
|||
|
|
"end_time": "2025-04-03T15:04:39.656944Z",
|
|||
|
|
"start_time": "2025-04-03T15:04:39.298483Z"
|
|||
|
|
}
|
|||
|
|
},
|
|||
|
|
"outputs": [],
|
|||
|
|
"source": [
|
|||
|
|
"# train_data = train_data.sort_values(by='trade_date')\n",
|
|||
|
|
"# all_dates = train_data['trade_date'].unique() # 获取所有唯一的 trade_date\n",
|
|||
|
|
"# split_date = all_dates[-120] # 划分点为倒数第 validation_days 天\n",
|
|||
|
|
"# print(split_date)\n",
|
|||
|
|
"# print(all_dates)\n",
|
|||
|
|
"# val_data_split = train_data[train_data['trade_date'] >= split_date] # 验证集\n",
|
|||
|
|
"\n",
|
|||
|
|
"feature_columns.remove('score1')\n",
|
|||
|
|
"feature_columns.remove('score2')\n",
|
|||
|
|
"\n",
|
|||
|
|
"score_df = test_data.groupby('trade_date', group_keys=False).apply(lambda x: x.nsmallest(500, 'total_mv'))\n",
|
|||
|
|
"# score_df = score_df[score_df['pe_ttm'] > 0]\n",
|
|||
|
|
"score_df = score_df.merge(industry_df, on=['cat_l2_code', 'trade_date'], how='left')\n",
|
|||
|
|
"score_df = score_df.merge(index_data, on='trade_date', how='left')\n",
|
|||
|
|
"numeric_columns = score_df.select_dtypes(include=['float64', 'int64']).columns\n",
|
|||
|
|
"numeric_columns = [col for col in feature_columns if col in numeric_columns]\n",
|
|||
|
|
"score_df['score1'] = model1.predict(score_df[feature_columns])\n",
|
|||
|
|
"score_df['score2'] = model2.predict_proba(score_df[feature_columns])[:, 1]\n",
|
|||
|
|
"\n",
|
|||
|
|
"# score_df['score1'] = score_df.groupby('trade_date', group_keys=False)['score1'].rank(\n",
|
|||
|
|
"# ascending=True,\n",
|
|||
|
|
"# na_option='keep'\n",
|
|||
|
|
"# )\n",
|
|||
|
|
"# score_df['score2'] = score_df.groupby('trade_date', group_keys=False)['score2'].rank(\n",
|
|||
|
|
"# ascending=True,\n",
|
|||
|
|
"# na_option='keep'\n",
|
|||
|
|
"# )\n",
|
|||
|
|
"# score_df['score'] = score_df['score1'] + score_df['score2'] * 1.15\n",
|
|||
|
|
"if 'score1' not in feature_columns:\n",
|
|||
|
|
" feature_columns.append('score1')\n",
|
|||
|
|
"if 'score2' not in feature_columns:\n",
|
|||
|
|
" feature_columns.append('score2')\n",
|
|||
|
|
"score_df['score'] = model3.predict(score_df[feature_columns])\n",
|
|||
|
|
"\n",
|
|||
|
|
"score_df = score_df.groupby('trade_date', group_keys=False).apply(\n",
|
|||
|
|
" lambda x: x[x['score'] >= x['score'].quantile(0.90)] # 计算90%分位数作为阈值,筛选分数>=阈值的行\n",
|
|||
|
|
").reset_index(drop=True) # drop=True 避免添加旧索引列\n",
|
|||
|
|
"# save_df = score_df.groupby('trade_date', group_keys=False).apply(lambda x: x.nlargest(1, 'score')).reset_index()\n",
|
|||
|
|
"save_df = score_df.groupby('trade_date', group_keys=False).apply(lambda x: x.nsmallest(1, 'total_mv')).reset_index()\n",
|
|||
|
|
"save_df = save_df.sort_values(['trade_date', 'score'])\n",
|
|||
|
|
"save_df[['trade_date', 'score', 'ts_code']].to_csv('predictions_test.tsv', index=False)\n"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"execution_count": 55,
|
|||
|
|
"id": "09b1799e",
|
|||
|
|
"metadata": {},
|
|||
|
|
"outputs": [
|
|||
|
|
{
|
|||
|
|
"name": "stdout",
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"text": [
|
|||
|
|
"193\n",
|
|||
|
|
"['vol', 'pct_chg', 'turnover_rate', 'volume_ratio', 'winner_rate', 'undist_profit_ps', 'ocfps', 'AR', 'BR', 'AR_BR', 'cashflow_to_ev_factor', 'book_to_price_ratio', 'turnover_rate_mean_5', 'variance_20', 'bbi_ratio_factor', 'daily_deviation', 'lg_elg_net_buy_vol', 'flow_lg_elg_intensity', 'sm_net_buy_vol', 'total_buy_vol', 'lg_elg_buy_prop', 'flow_struct_buy_change', 'lg_elg_net_buy_vol_change', 'flow_lg_elg_accel', 'chip_concentration_range', 'chip_skewness', 'floating_chip_proxy', 'cost_support_15pct_change', 'cat_winner_price_zone', 'flow_chip_consistency', 'profit_taking_vs_absorb', 'cat_is_positive', 'upside_vol', 'downside_vol', 'vol_ratio', 'return_skew', 'return_kurtosis', 'volume_change_rate', 'cat_volume_breakout', 'turnover_deviation', 'cat_turnover_spike', 'avg_volume_ratio', 'cat_volume_ratio_breakout', 'vol_spike', 'vol_std_5', 'atr_14', 'atr_6', 'obv', 'maobv_6', 'rsi_3', 'return_5', 'return_20', 'std_return_5', 'std_return_90', 'std_return_90_2', 'act_factor1', 'act_factor2', 'act_factor3', 'act_factor4', 'rank_act_factor1', 'rank_act_factor2', 'rank_act_factor3', 'cov', 'delta_cov', 'alpha_22_improved', 'alpha_003', 'alpha_007', 'alpha_013', 'vol_break', 'weight_roc5', 'smallcap_concentration', 'cost_stability', 'high_cost_break_days', 'liquidity_risk', 'turnover_std', 'mv_volatility', 'volume_growth', 'mv_growth', 'momentum_factor', 'resonance_factor', 'log_close', 'cat_vol_spike', 'up', 'down', 'obv_maobv_6', 'std_return_5_over_std_return_90', 'std_return_90_minus_std_return_90_2', 'cat_af2', 'cat_af3', 'cat_af4', 'act_factor5', 'act_factor6', 'active_buy_volume_large', 'active_buy_volume_big', 'active_buy_volume_small', 'buy_lg_vol_minus_sell_lg_vol', 'buy_elg_vol_minus_sell_elg_vol', 'ctrl_strength', 'low_cost_dev', 'asymmetry', 'lock_factor', 'cat_vol_break', 'cost_atr_adj', 'cat_golden_resonance', 'mv_turnover_ratio', 'mv_adjusted_volume', 'mv_weighted_turnover', 'nonlinear_mv_volume', 'mv_volume_ratio', 'mv_momentum', 'lg_flow_mom_corr_20_60', 'lg_flow_accel', 'profit_pressure', 'underwater_resistance', 'cost_conc_std_20', 'profit_decay_20', 'vol_amp_loss_20', 'vol_drop_profit_cnt_5', 'lg_flow_vol_interact_20', 'cost_break_confirm_cnt_5', 'atr_norm_channel_pos_14', 'turnover_diff_skew_20', 'lg_sm_flow_diverge_20', 'pullback_strong_20_20', 'vol_wgt_hist_pos_20', 'vol_adj_roc_20', 'cs_rank_net_lg_flow_val', 'cs_rank_elg_buy_ratio', 'cs_rank_rel_profit_margin', 'cs_rank_cost_breadth', 'cs_rank_dist_to_upper_cost', 'cs_rank_winner_rate', 'cs_rank_intraday_range', 'cs_rank_close_pos_in_range', 'cs_rank_pos_in_hist_range', 'cs_rank_vol_x_profit_margin', 'cs_rank_lg_flow_price_concordance', 'cs_rank_turnover_per_winner', 'cs_rank_volume_ratio', 'cs_rank_elg_buy_sell_sm_ratio', 'cs_rank_cost_dist_vol_ratio', 'cs_rank_size', 'cat_up_limit', 'industry_obv', 'industry_return_5', 'industry_return_20', 'industry__ema_5', 'industry__ema_13', 'industry__ema_20', 'industry__ema_60', 'industry_act_factor1', 'industry_act_factor2', 'industry_act_factor3', 'industry_act_factor4', 'industry_act_factor5', 'industry_act_factor6', 'industry_rank_act_factor1', 'industry_rank_act_factor2', 'industry_rank_act_factor3', 'industry_return_5_percentile', 'industry_return_20_percentile', '000852.SH_MACD', '000905.SH_MACD', '399006.SZ_MACD', '000852.SH_MACD_hist', '000905.SH_MACD_hist', '399006.SZ_MACD_hist', '000852.SH_RSI', '000905.SH_RSI', '399006.SZ_RSI', '000852.SH_Signal_line', '000905.SH_Signal_line', '399006.SZ_Signal_line', '000852.SH_amount_change_rate', '000905.SH_amount_change_rate', '399006.SZ_amount_change_rate', '000852.SH_amount_mean', '000905.SH_amount_mean', '399006.SZ_amount_mean', '000852.SH_daily_return', '000905.SH_daily_return', '399006.SZ_daily_return', '000852.SH_up_ratio_20d', '000905.SH_up_ratio_20d', '399006.SZ_up_ratio_20d', '000852.SH_volatility', '000905.SH_volatility', '399006.SZ_volatility', '000852.SH_volume_change_rate', '000905.SH_volume_change_rate', '399006.SZ_volume_change_rate', 'score1', 'score2']\n"
|
|||
|
|
]
|
|||
|
|
}
|
|||
|
|
],
|
|||
|
|
"source": [
|
|||
|
|
"print(len(feature_columns))\n",
|
|||
|
|
"print(feature_columns)"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"execution_count": 56,
|
|||
|
|
"id": "7e9023cc",
|
|||
|
|
"metadata": {},
|
|||
|
|
"outputs": [],
|
|||
|
|
"source": [
|
|||
|
|
"def analyze_factors(\n",
|
|||
|
|
" df: pd.DataFrame,\n",
|
|||
|
|
" feature_columns: list[str],\n",
|
|||
|
|
" target_column: str = 'target', # 假设目标列默认为 'target'\n",
|
|||
|
|
" trade_date_col: str = 'trade_date', # 假设日期列默认为 'trade_date'\n",
|
|||
|
|
" mcap_col: str = 'total_mv', # 新增: 市值列名称\n",
|
|||
|
|
" mcap_bins: int = 5 # 新增: 市值分位数的数量 (例如 5 表示五分位数)\n",
|
|||
|
|
") -> pd.DataFrame:\n",
|
|||
|
|
" \"\"\"\n",
|
|||
|
|
" 分析DataFrame中指定特征列的各种指标,包括基本统计、相关性、日间IC、ICIR以及在不同市值分位数上的IC。\n",
|
|||
|
|
"\n",
|
|||
|
|
" Args:\n",
|
|||
|
|
" df (pd.DataFrame): 包含日期、目标列、特征列和市值列的DataFrame。\n",
|
|||
|
|
" 需要包含 trade_date_col, target_column, feature_columns 和 mcap_col 中的所有列。\n",
|
|||
|
|
" feature_columns (list[str]): 需要分析的特征列名称列表。\n",
|
|||
|
|
" target_column (str): 目标变量列的名称。\n",
|
|||
|
|
" trade_date_col (str): 交易日期列的名称。\n",
|
|||
|
|
" mcap_col (str): 市值列的名称。\n",
|
|||
|
|
" mcap_bins (int): 市值分位数的数量 (例如 5 表示五分位数)。\n",
|
|||
|
|
"\n",
|
|||
|
|
" Returns:\n",
|
|||
|
|
" pd.DataFrame: 包含各个因子分析指标的汇总DataFrame。\n",
|
|||
|
|
" 同时打印因子在不同市值分位数上的平均IC表格。\n",
|
|||
|
|
" 如果输入数据或列有问题,可能返回空或包含NaN的DataFrame。\n",
|
|||
|
|
" \"\"\"\n",
|
|||
|
|
"\n",
|
|||
|
|
" # --- 数据校验 ---\n",
|
|||
|
|
" required_cols = [trade_date_col, target_column, mcap_col] + feature_columns\n",
|
|||
|
|
" if not all(col in df.columns for col in required_cols):\n",
|
|||
|
|
" missing = [col for col in required_cols if col not in df.columns]\n",
|
|||
|
|
" print(f\"错误: 输入DataFrame缺少必需的列: {missing}\")\n",
|
|||
|
|
" return pd.DataFrame() # 返回空DataFrame\n",
|
|||
|
|
"\n",
|
|||
|
|
" # 确保日期列是 datetime 类型\n",
|
|||
|
|
" df = df.copy() # 在副本上操作\n",
|
|||
|
|
" df[trade_date_col] = pd.to_datetime(df[trade_date_col], errors='coerce')\n",
|
|||
|
|
" df.dropna(subset=[trade_date_col], inplace=True) # 移除日期转换失败的行\n",
|
|||
|
|
"\n",
|
|||
|
|
" # 过滤掉那些在 feature_columns, target_column, mcap_col 上有 NaN 的行,以确保后续计算是在完整数据上\n",
|
|||
|
|
" # 直接在 df 副本上进行清洗\n",
|
|||
|
|
" initial_rows_before_clean = len(df)\n",
|
|||
|
|
" df.dropna(subset=feature_columns + [target_column, mcap_col], inplace=True)\n",
|
|||
|
|
" rows_dropped_clean = initial_rows_before_clean - len(df)\n",
|
|||
|
|
" if rows_dropped_clean > 0:\n",
|
|||
|
|
" print(f\"警告: 移除了 {rows_dropped_clean} 行,因为其特征、目标或市值列存在空值。\")\n",
|
|||
|
|
"\n",
|
|||
|
|
" if df.empty:\n",
|
|||
|
|
" print(\"错误: 清理缺失值后数据为空,无法进行因子分析。\")\n",
|
|||
|
|
" return pd.DataFrame() # 返回空DataFrame\n",
|
|||
|
|
"\n",
|
|||
|
|
"\n",
|
|||
|
|
" print(f\"开始分析 {len(feature_columns)} 个因子指标...\")\n",
|
|||
|
|
"\n",
|
|||
|
|
" # --- 1. 基本因子统计量 ---\n",
|
|||
|
|
" basic_stats = df[feature_columns].describe().T\n",
|
|||
|
|
"\n",
|
|||
|
|
" print(\"\\n--- 基本因子统计量 ---\")\n",
|
|||
|
|
" print(basic_stats)\n",
|
|||
|
|
"\n",
|
|||
|
|
" # --- 2. 因子与目标变量的整体相关性 ---\n",
|
|||
|
|
" overall_correlation = {}\n",
|
|||
|
|
" for feature in feature_columns:\n",
|
|||
|
|
" # 在清理后的 df 上计算相关性\n",
|
|||
|
|
" if df[[feature, target_column]].dropna().shape[0] > 1: # 确保至少有两个有效数据点\n",
|
|||
|
|
" overall_correlation[feature] = {\n",
|
|||
|
|
" 'Pearson_Correlation_with_Target': df[feature].corr(df[target_column], method='pearson'),\n",
|
|||
|
|
" 'Spearman_Correlation_with_Target': df[feature].corr(df[target_column], method='spearman')\n",
|
|||
|
|
" }\n",
|
|||
|
|
" else:\n",
|
|||
|
|
" overall_correlation[feature] = {\n",
|
|||
|
|
" 'Pearson_Correlation_with_Target': np.nan,\n",
|
|||
|
|
" 'Spearman_Correlation_with_Target': np.nan\n",
|
|||
|
|
" }\n",
|
|||
|
|
" overall_corr_df = pd.DataFrame.from_dict(overall_correlation, orient='index')\n",
|
|||
|
|
"\n",
|
|||
|
|
" print(\"\\n--- 因子与目标变量的整体相关性 ---\")\n",
|
|||
|
|
" print(overall_corr_df)\n",
|
|||
|
|
"\n",
|
|||
|
|
" # --- 3. 因子之间的相关性矩阵 ---\n",
|
|||
|
|
" # 在清理后的 df 上计算相关性\n",
|
|||
|
|
" factor_correlation_matrix = df[feature_columns].corr(method='spearman') # 改回 Spearman\n",
|
|||
|
|
"\n",
|
|||
|
|
" print(\"\\n--- 因子之间的相关性矩阵 (Spearman) ---\") # 修正打印信息\n",
|
|||
|
|
" print(factor_correlation_matrix)\n",
|
|||
|
|
"\n",
|
|||
|
|
" # --- 4. 日间 IC 和 ICIR ---\n",
|
|||
|
|
" print(\"\\n--- 计算日间 IC (Spearman 相关性) 和 ICIR ---\")\n",
|
|||
|
|
"\n",
|
|||
|
|
" # 直接在清理后的 df 上计算每日 IC\n",
|
|||
|
|
" if df.empty: # 理论上上面已经检查过,这里再检查一次更安全\n",
|
|||
|
|
" daily_ic_series = pd.Series(dtype=float) # 空 Series\n",
|
|||
|
|
" ic_stats = pd.DataFrame({\n",
|
|||
|
|
" 'Mean_IC (Spearman)': np.nan, 'Std_Dev_IC': np.nan, 'ICIR': np.nan\n",
|
|||
|
|
" }, index=feature_columns)\n",
|
|||
|
|
" else:\n",
|
|||
|
|
" daily_ic_series = df.groupby(trade_date_col).apply(\n",
|
|||
|
|
" lambda day_group: {\n",
|
|||
|
|
" feature: day_group[feature].corr(day_group[target_column], method='spearman')\n",
|
|||
|
|
" for feature in feature_columns if day_group.shape[0] > 1 # 确保每日数据点多于1才能计算相关性\n",
|
|||
|
|
" }\n",
|
|||
|
|
" ).apply(pd.Series) # 将字典结果转换为 DataFrame\n",
|
|||
|
|
"\n",
|
|||
|
|
" # 计算 IC 的统计量\n",
|
|||
|
|
" if not daily_ic_series.empty:\n",
|
|||
|
|
" ic_mean = daily_ic_series.mean()\n",
|
|||
|
|
" ic_std = daily_ic_series.std()\n",
|
|||
|
|
" # 避免除以零\n",
|
|||
|
|
" ic_ir = ic_mean / ic_std.replace(0, np.nan) # 使用 replace 0 为 NaN\n",
|
|||
|
|
"\n",
|
|||
|
|
" ic_stats = pd.DataFrame({\n",
|
|||
|
|
" 'Mean_IC (Spearman)': ic_mean,\n",
|
|||
|
|
" 'Std_Dev_IC': ic_std,\n",
|
|||
|
|
" 'ICIR': ic_ir\n",
|
|||
|
|
" })\n",
|
|||
|
|
" print(\"\\n--- 日间 IC 和 ICIR (Spearman) ---\")\n",
|
|||
|
|
" print(ic_stats)\n",
|
|||
|
|
" else:\n",
|
|||
|
|
" ic_stats = pd.DataFrame({\n",
|
|||
|
|
" 'Mean_IC (Spearman)': np.nan, 'Std_Dev_IC': np.nan, 'ICIR': np.nan\n",
|
|||
|
|
" }, index=feature_columns)\n",
|
|||
|
|
"\n",
|
|||
|
|
"\n",
|
|||
|
|
" # --- 5. 因子在不同市值分位数上的平均 IC ---\n",
|
|||
|
|
" print(f\"\\n--- 计算因子在 {mcap_bins} 个市值分位数上的平均 IC (Spearman) ---\")\n",
|
|||
|
|
"\n",
|
|||
|
|
" # 在清理后的 df 上计算每日市值分位数,直接添加到 df 中\n",
|
|||
|
|
" # 使用 transform() 和 qcut() 在每个日期分组内计算分位数\n",
|
|||
|
|
" # labels=False 返回整数 0 to mcap_bins-1\n",
|
|||
|
|
" # duplicates='drop' 处理在某些日期股票数量少于 bins 导致分位数边缘重复的情况,会返回 NaN\n",
|
|||
|
|
" # 添加一个临时列来存储分位数\n",
|
|||
|
|
" mcap_bin_col_name = f'_mcap_bin_{mcap_bins}'\n",
|
|||
|
|
" df[mcap_bin_col_name] = df.groupby(trade_date_col)[mcap_col].transform(\n",
|
|||
|
|
" lambda x: pd.qcut(x, q=mcap_bins, labels=False, duplicates='drop') if len(x) >= mcap_bins else np.nan # 确保股票数量足够进行分位数划分\n",
|
|||
|
|
" )\n",
|
|||
|
|
"\n",
|
|||
|
|
" # 过滤掉无法划分分位数 (NaN) 的行,进行分位数 IC 计算\n",
|
|||
|
|
" # 创建一个临时 DataFrame df_binned_analysis\n",
|
|||
|
|
" df_binned_analysis = df.dropna(subset=[mcap_bin_col_name]).copy()\n",
|
|||
|
|
"\n",
|
|||
|
|
" if df_binned_analysis.empty:\n",
|
|||
|
|
" print(\"错误: 划分市值分位数后数据为空,无法计算分位数上的 IC。\")\n",
|
|||
|
|
" avg_ic_by_bin = pd.DataFrame(index=range(mcap_bins), columns=feature_columns) # Placeholder\n",
|
|||
|
|
" else:\n",
|
|||
|
|
" # 按日期和市值分位数分组,计算每个分组内的因子与目标变量的截面相关性 (分位数IC)\n",
|
|||
|
|
" binned_ic_by_day = df_binned_analysis.groupby([trade_date_col, mcap_bin_col_name]).apply(\n",
|
|||
|
|
" lambda group: {\n",
|
|||
|
|
" feature: group[feature].corr(group[target_column], method='spearman')\n",
|
|||
|
|
" for feature in feature_columns if group.shape[0] > 1 # 确保分位数组内数据点多于1\n",
|
|||
|
|
" }\n",
|
|||
|
|
" ).apply(pd.Series) # 将嵌套结果转为 DataFrame\n",
|
|||
|
|
"\n",
|
|||
|
|
" # 对每个分位数组的每日 IC 求平均\n",
|
|||
|
|
" # unstack(level=mcap_bin_col_name) 将 mcap_bin 作为列\n",
|
|||
|
|
" # mean(axis=0) 对日期索引求平均\n",
|
|||
|
|
" avg_ic_by_bin = binned_ic_by_day.unstack(level=mcap_bin_col_name).mean(axis=0).unstack()\n",
|
|||
|
|
"\n",
|
|||
|
|
" # 重命名索引和列,使表格更清晰\n",
|
|||
|
|
" if not avg_ic_by_bin.empty:\n",
|
|||
|
|
" # Index name will be the original column name used for grouping ('_mcap_bin_X')\n",
|
|||
|
|
" # Rename the index name explicitly\n",
|
|||
|
|
" avg_ic_by_bin.index.name = 'MarketCap_Bin'\n",
|
|||
|
|
" avg_ic_by_bin.columns.name = 'Feature'\n",
|
|||
|
|
" # 可以根据需要对分位数 bin 索引进行排序 (虽然 pd.qcut labels=False usually sorts)\n",
|
|||
|
|
" avg_ic_by_bin = avg_ic_by_bin.sort_index()\n",
|
|||
|
|
"\n",
|
|||
|
|
" print(avg_ic_by_bin)\n",
|
|||
|
|
"\n",
|
|||
|
|
"\n",
|
|||
|
|
" # --- 6. 汇总所有指标 ---\n",
|
|||
|
|
" # 将基本统计、整体相关性、IC/ICIR 合并到一个 DataFrame\n",
|
|||
|
|
" # 注意:合并时需要根据索引进行对齐 (因子名称)\n",
|
|||
|
|
" summary_df = basic_stats\n",
|
|||
|
|
" summary_df = summary_df.merge(overall_corr_df, left_index=True, right_index=True, how='left')\n",
|
|||
|
|
" summary_df = summary_df.merge(ic_stats, left_index=True, right_index=True, how='left')\n",
|
|||
|
|
"\n",
|
|||
|
|
" # print(\"\\n--- 因子分析汇总报告 ---\")\n",
|
|||
|
|
" # print(summary_df)\n",
|
|||
|
|
"\n",
|
|||
|
|
" # --- 清理临时列 'mcap_bin' ---\n",
|
|||
|
|
" # 修正:在函数结束时从我们一直在操作的 df 副本中删除临时列\n",
|
|||
|
|
" if mcap_bin_col_name in df.columns:\n",
|
|||
|
|
" df.drop(columns=[mcap_bin_col_name], inplace=True)\n",
|
|||
|
|
"\n",
|
|||
|
|
"\n",
|
|||
|
|
" return summary_df # 主要返回汇总报告,分位数IC单独打印\n",
|
|||
|
|
"\n",
|
|||
|
|
"# # 运行分析函数\n",
|
|||
|
|
"# factor_analysis_report = analyze_factors(test_data.copy(), feature_columns, 'future_return')\n",
|
|||
|
|
"\n",
|
|||
|
|
"# print(\"\\n--- 最终汇总报告 DataFrame ---\")\n",
|
|||
|
|
"# print(factor_analysis_report)"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"execution_count": 57,
|
|||
|
|
"id": "a0000d75",
|
|||
|
|
"metadata": {},
|
|||
|
|
"outputs": [
|
|||
|
|
{
|
|||
|
|
"name": "stdout",
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"text": [
|
|||
|
|
"开始分析 'score' 在 'circ_mv' 和 'future_return' 下的表现...\n",
|
|||
|
|
"准备数据,处理 NaN 值...\n",
|
|||
|
|
"原始数据 28312 行,移除 NaN 后剩余 27929 行用于分析。\n",
|
|||
|
|
"对 'circ_mv' 和 'future_return' 进行 100 分位数分箱...\n",
|
|||
|
|
"按二维分箱分组计算 Spearman Rank IC...\n",
|
|||
|
|
"整理结果用于绘图...\n",
|
|||
|
|
"circ_mv_bin 0 1 2 3 4 5 6 7 8 9 ... 90 91 92 \\\n",
|
|||
|
|
"future_return_bin ... \n",
|
|||
|
|
"0 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN NaN NaN \n",
|
|||
|
|
"1 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN NaN NaN \n",
|
|||
|
|
"2 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN NaN NaN \n",
|
|||
|
|
"3 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN NaN NaN \n",
|
|||
|
|
"4 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN NaN NaN \n",
|
|||
|
|
"... .. .. .. .. .. .. .. .. .. .. ... .. .. .. \n",
|
|||
|
|
"95 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN NaN NaN \n",
|
|||
|
|
"96 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN NaN NaN \n",
|
|||
|
|
"97 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN NaN NaN \n",
|
|||
|
|
"98 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN NaN NaN \n",
|
|||
|
|
"99 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN ... NaN NaN NaN \n",
|
|||
|
|
"\n",
|
|||
|
|
"circ_mv_bin 93 94 95 96 97 98 99 \n",
|
|||
|
|
"future_return_bin \n",
|
|||
|
|
"0 NaN NaN NaN NaN NaN NaN NaN \n",
|
|||
|
|
"1 NaN NaN NaN NaN NaN NaN NaN \n",
|
|||
|
|
"2 NaN NaN NaN NaN NaN NaN NaN \n",
|
|||
|
|
"3 NaN NaN NaN NaN NaN NaN NaN \n",
|
|||
|
|
"4 NaN NaN NaN NaN NaN NaN NaN \n",
|
|||
|
|
"... .. .. .. .. .. .. .. \n",
|
|||
|
|
"95 NaN NaN NaN NaN NaN NaN NaN \n",
|
|||
|
|
"96 NaN NaN NaN NaN NaN NaN NaN \n",
|
|||
|
|
"97 NaN NaN NaN NaN NaN NaN NaN \n",
|
|||
|
|
"98 NaN NaN NaN NaN NaN NaN NaN \n",
|
|||
|
|
"99 NaN NaN NaN NaN NaN NaN NaN \n",
|
|||
|
|
"\n",
|
|||
|
|
"[100 rows x 100 columns]\n",
|
|||
|
|
"生成热力图...\n",
|
|||
|
|
"分析完成。\n"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"data": {
|
|||
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAABdEAAASgCAYAAAAXXAHaAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjEsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvc2/+5QAAAAlwSFlzAAAPYQAAD2EBqD+naQABAABJREFUeJzs3Xd0VFX79vErvdC7UqQL0YggUkRBaSJRiiigYhAUpUtRFBBQUUEUAQFRBBEEeWjSVASkiApIFzIQioAYihRpIQNJmJz3D97Mj5BMSIbJnJnM97NW1vPjzLVn3yd3nHe9Ozv7+BmGYQgAAAAAAAAAAKTjb3YBAAAAAAAAAAB4KhbRAQAAAAAAAABwgEV0AAAAAAAAAAAcYBEdAAAAAAAAAAAHWEQHAAAAAAAAAMABFtEBAAAAAAAAAHCARXQAAAAAAAAAABxgER0AAAAAAAAAAAdYRAcAAF6lTZs2atasmVJSUswuxaELFy5o2bJlMgzDfm3Xrl3q0qWLfvnll2y/3+rVqzVw4EBt27Yty2MMw9CePXuyPRcAAAAAIK1AswsAAADIjv/++0/Jycny98/6XoDExEQlJycrPDw8zbirV6/qypUryps3r/r16yeLxXLT91qwYIEKFCiQaWbevHkaPXq09uzZo9dff12SVKBAAW3YsEGXLl3SI488kuXaJenw4cNatGiRoqKiMnw9OTlZp0+f1pEjR3To0CHt3LlTmzZt0r///qtPP/1Ujz32WLbmy4p///1Xc+fO1auvvio/Pz+Xvz9ghmnTpunBBx9UlSpVzC4FAAAAHoRFdAAA4FHi4uL0ww8/qECBAgoKClJAQECa15OSkpSSkqKFCxemG5uSkqLExEQVLFhQjz/+uP360qVLNWTIkAznK1q0qNavX6+LFy/q3Llz6tChQ4a5DRs2aNeuXQoKCsq0fqvVqhkzZig0NFTR0dH262XLltWTTz6pBQsWaNGiRXryySczfZ/rhYWFpfnfVGfPnlXz5s114cIF+653f39/lShRQmXLllX9+vUVFxcnm82W7vt4K+Lj49W5c2cFBATopZdeUt68eV323oBZrl69qrVr12ratGmaP3++br/9drNLAgAAgIdgER0AAHiU48ePa9y4cTfNDRo0yOFr9957b5pF9CpVqqhnz54KCgrSzz//rL/++kvdunVTUlKSgoODJUmBgYHKnz+/+vXrl+F7Wq1W7dq1SyEhIZnWNW3aNJ0+fVovv/yySpQokea1vn37avny5RoxYoTuv/9+lSlT5qb3Kcm+e/7GHd8BAQE6f/68ateurR49eqhUqVK67bbb7PeUUwYMGKCkpCTNnj3bvoB+487dwMBAFS1aVPXq1VPXrl1Vrly5HK1JkjZt2qSOHTuqV69e6t27d47N06hRIx07dkyrV69W6dKlHeaWLFmiGTNm6ODBg8qbN68aN26svn37qnDhwjlWm7tFR0dr8+bN2rdv3y2/V5UqVXTvvfdq3rx5LqgsY0ePHlXjxo315JNP6sMPP0zzWmBgoD7//HP7z9CcOXNu+kszAAAA+AYW0QEAgEepVq2afv31VxUoUEDBwcHpjm1p2rSpbDab1qxZk25sSkqKrFarEhIS0r1ntWrVJElHjhzRP//8ox49eqTJ3LhTe/78+QoKClLDhg3THN+S2Y7uffv26YsvvlDhwoXVtWvXdK8XK1ZMb7zxhoYNG6aePXvq22+/Vb58+Ry+V1JSkoKCgnTmzBlJ137BsH//fiUmJqps2bL2xfJSpUrpgQcecFiXKy1evFi//vqr5s2bl+6XBJLUrVs3SdK5c+e0fft2LVy4UCtWrNDs2bNVtWpVt9ToCcaMGaPJkyerZMmSateunY4dO6Z58+Zp8+bNmj9/vsO+w1x58+bVZ599pscff1xTp05V9+7dzS4JAAAAHoBFdAAA4FHCwsIUHx+v+fPnKzQ0NN2idUJCgmw2W7rjXGw2m5KSkhQWFqY2bdrcch2zZs3S3r17tXHjxizlk5OT9eabbyo5OVl9+/Z1uEjavn17rV+/XitWrFCXLl30+eefZ7gzedCgQdq9e3eaawMGDLD/31OmTFHdunWzcUe3LikpSZ988ok6deqkyMjIDDPX7+RPSUnRkCFD9N1332ncuHH64osv3FWqqbZs2aLJkyerfPnymjdvnvLnzy/pWs9Gjx6tzz77TAMHDjS5Sjhy++2367XXXtPo0aP1zDPPqFChQmaXBAAAAJOxiA4AADzO8ePH9f7772eacXScS+XKlV2yiG6z2VS4cOEsH70xbtw4xcbG6qGHHlK7du0yzY4aNUrnzp3T5s2b1bZtW40ZM0b33ntvmszgwYN1+fJlhYSE6JNPPtGff/6pd955R5UrV9bly5d11113uf2oiVWrVuncuXN6+eWXs5T39/dX9+7d9d1332nHjh05XJ3nmDJliiSpT58+9gV0SXruuec0YcIE/fjjjyyie7j27dtr4sSJWrhwoV566SWzywEAAIDJWEQHAAAeJzIyUn/++aeCg4PT7UR3dJyLYRhKTk5WcnKyS2qw2WwqVqxYlrLz5s3T1KlTVbRoUY0aNSrd2eU3CgsL05QpU9S3b1+tXbtWzz77rF544QV1797dvuh6//33S5IuXbqk2NhYSdd+QZB6PZW/v7/279+vL7/8MsO5rl69qitXrqh///5ZupfMrFmzRg899FC2duYWKVJEknTlypVbnt8bJCYmasOGDfL391f9+vXTvJYnTx6VKFFC//zzjy5cuJDmmCB4lsDAQEVFRWn16tUsogMAAIBFdAAAPMG///6rCRMmaOPGjTpz5oyKFi2qBg0aqHfv3vZFyFQ2m00zZ87UggULdOTIERUpUkTVqlVTnz59VLFixXTvvXTpUs2YMUMHDhxQWFiYHnjgAb366quqUKFCmtyND2VctmyZ5syZo3379mnGjBnpzrPes2ePJk2apC1btighIUHly5dXx44d1bZtW6e/DwcPHlRwcLACAwMdLkTbbDbZbDb9+++/Dl8/efKkkpKS7DXv2bNHc+bMUXBwsCwWixITE/XBBx8oKSlJderUUVRUVLr3SUlJydIi+rx58/T2228rKChIbdu21bvvvquQkBAFBATcdDG9atWqqlSpkqZMmaJp06YpJCREffv2TZNZvHixEhMT7f/+559/VKZMGft7G4ahgwcPasaMGfbMhQsXZLPZlD9/fiUmJiopKckli+gWi+Wmu+xvtGfPHklK9xDVhIQETZs2TT/++KNOnDihggULqkaNGurXr5/Kli1rzy1cuFCDBg3SyJEjVa1aNX388cfatm2bAgMD9eCDD+qtt97K9K8FUlJSNGjQIC1evFjdu3dP9/11tYMHDyo5OVklS5a0P3T1ekOGDNH58+ez/fBXm82mhx9+WH5+flq3bl26ZwU0atRIVqtVv/32m/0vFK5cuaJp06bphx9+0PHjx5UnTx7dc8896tOnjyIiIpy/yWzau3evJk2apB07dig+Pl4lS5ZUq1at1Llz5wy/D+vXr9fo0aP1119/qUiRImrdurV69OiRLpucnKyvv/5aS5cu1ZEjRxQeHq6HHnpI/fv3V6lSpW657lq1amnhwoUyDOOm/y0DAAAgd2MRHQAAk8XHx6tDhw46duyYGjVqpHLlyuno0aOaO3eudu3apQULFtgXzGw2m3r27Km1a9eqXLlyat++vc6ePauVK1fql19+0cyZM9McCzJixAjNmDFDxYsX15NPPqmzZ89qxYoVWrdunaZOnaqaNWtmWNOQIUM0f/58lS5dWmXLllVoaGia19etW6devXopPDxcTZs2VWhoqH755RcNGTJE//77r3r37u3U96Jjx472h2jezMMPP3zTzL59+yRdOx5m4cKFCgoKUnJysq5evaqFCxcqKSlJefLkyXAR/ezZs6pevXqm72+1WjVnzhz5+fnpo48+0oULF/T5559nqX5JeuSRRzR58mTdd999mjdvnl599dU0rxuGoZkzZ9r//dNPP+l///uf3n33XbVt21Y2m02GYahly5Z677337Ll27drpzJkzGT589VacPHk
|
|||
|
|
"text/plain": [
|
|||
|
|
"<Figure size 1600x1200 with 2 Axes>"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
"metadata": {},
|
|||
|
|
"output_type": "display_data"
|
|||
|
|
}
|
|||
|
|
],
|
|||
|
|
"source": [
|
|||
|
|
"import pandas as pd\n",
|
|||
|
|
"import numpy as np\n",
|
|||
|
|
"import matplotlib.pyplot as plt\n",
|
|||
|
|
"import seaborn as sns\n",
|
|||
|
|
"from scipy.stats import spearmanr\n",
|
|||
|
|
"from tqdm import tqdm # 用于显示进度条 (可选)\n",
|
|||
|
|
"\n",
|
|||
|
|
"# 设置 Matplotlib/Seaborn 样式 (可选)\n",
|
|||
|
|
"sns.set_theme(style=\"whitegrid\")\n",
|
|||
|
|
"plt.rcParams['font.sans-serif'] = ['SimHei'] # 或者其他支持中文的字体\n",
|
|||
|
|
"plt.rcParams['axes.unicode_minus'] = False # 解决负号显示问题\n",
|
|||
|
|
"\n",
|
|||
|
|
"def analyze_score_performance_2d(score_df: pd.DataFrame,\n",
|
|||
|
|
" score_col: str = 'score',\n",
|
|||
|
|
" label_col: str = 'label',\n",
|
|||
|
|
" condition1_col: str = 'circ_mv',\n",
|
|||
|
|
" condition2_col: str = 'future_return',\n",
|
|||
|
|
" n_bins: int = 100,\n",
|
|||
|
|
" min_samples_per_bin: int = 30): # 每个格子最少样本数\n",
|
|||
|
|
" \"\"\"\n",
|
|||
|
|
" 分析 score 在两个条件下 (如市值、未来收益) 的二维分箱表现。\n",
|
|||
|
|
"\n",
|
|||
|
|
" Args:\n",
|
|||
|
|
" score_df (pd.DataFrame): 包含分数、标签和条件列的 DataFrame。\n",
|
|||
|
|
" score_col (str): 预测分数所在的列名。\n",
|
|||
|
|
" label_col (str): 目标标签所在的列名 (应为数值或可排序类别)。\n",
|
|||
|
|
" condition1_col (str): 第一个条件列名 (例如 'circ_mv')。\n",
|
|||
|
|
" condition2_col (str): 第二个条件列名 (例如 'future_return')。\n",
|
|||
|
|
" n_bins (int): 每个条件划分的箱数 (分位数数量)。\n",
|
|||
|
|
" min_samples_per_bin (int): 计算指标所需的最小样本数,小于此数目的格子结果将被屏蔽。\n",
|
|||
|
|
"\n",
|
|||
|
|
" Returns:\n",
|
|||
|
|
" tuple: 包含 (performance_pivot, count_pivot, fig)\n",
|
|||
|
|
" performance_pivot: 以二维分箱为索引/列的 Spearman 相关系数矩阵。\n",
|
|||
|
|
" count_pivot: 每个二维分箱的样本数量矩阵。\n",
|
|||
|
|
" fig: 生成的热力图 Matplotlib Figure 对象。\n",
|
|||
|
|
" \"\"\"\n",
|
|||
|
|
" print(f\"开始分析 '{score_col}' 在 '{condition1_col}' 和 '{condition2_col}' 下的表现...\")\n",
|
|||
|
|
"\n",
|
|||
|
|
" required_cols = [score_col, label_col, condition1_col, condition2_col]\n",
|
|||
|
|
" if not all(col in score_df.columns for col in required_cols):\n",
|
|||
|
|
" missing = [col for col in required_cols if col not in score_df.columns]\n",
|
|||
|
|
" raise ValueError(f\"输入 DataFrame 缺少必需列: {missing}\")\n",
|
|||
|
|
"\n",
|
|||
|
|
" # --- 1. 数据准备和清洗 ---\n",
|
|||
|
|
" print(\"准备数据,处理 NaN 值...\")\n",
|
|||
|
|
" # 只保留需要的列,并移除包含 NaN 的行,避免影响分箱和计算\n",
|
|||
|
|
" analysis_df = score_df[required_cols].dropna().copy()\n",
|
|||
|
|
" n_original = len(score_df)\n",
|
|||
|
|
" n_after_drop = len(analysis_df)\n",
|
|||
|
|
" print(f\"原始数据 {n_original} 行,移除 NaN 后剩余 {n_after_drop} 行用于分析。\")\n",
|
|||
|
|
"\n",
|
|||
|
|
" if n_after_drop < min_samples_per_bin * n_bins: # 检查数据量是否过少\n",
|
|||
|
|
" print(f\"警告: 清理 NaN 后数据量 ({n_after_drop}) 可能不足以支持 {n_bins}x{n_bins} 的精细分箱分析。\")\n",
|
|||
|
|
" if n_after_drop < min_samples_per_bin:\n",
|
|||
|
|
" print(\"错误: 有效数据过少,无法进行分析。\")\n",
|
|||
|
|
" return None, None, None\n",
|
|||
|
|
"\n",
|
|||
|
|
" # --- 2. 二维分箱 ---\n",
|
|||
|
|
" print(f\"对 '{condition1_col}' 和 '{condition2_col}' 进行 {n_bins} 分位数分箱...\")\n",
|
|||
|
|
" bin1_col = f'{condition1_col}_bin'\n",
|
|||
|
|
" bin2_col = f'{condition2_col}_bin'\n",
|
|||
|
|
"\n",
|
|||
|
|
" try:\n",
|
|||
|
|
" # 使用 qcut 进行分位数分箱,labels=False 返回 0 到 n_bins-1 的整数标签\n",
|
|||
|
|
" # duplicates='drop' 会丢弃导致边界不唯一的重复值所在的箱子,可能导致某些箱号缺失\n",
|
|||
|
|
" # 对于可视化,这通常可以接受,但如果需要严格的等分,需先 rank\n",
|
|||
|
|
" analysis_df[bin1_col] = pd.qcut(analysis_df[condition1_col], q=n_bins, labels=False, duplicates='drop')\n",
|
|||
|
|
" analysis_df[bin2_col] = pd.qcut(analysis_df[condition2_col], q=n_bins, labels=False, duplicates='drop')\n",
|
|||
|
|
" except Exception as e:\n",
|
|||
|
|
" print(f\"错误: 分箱失败,请检查数据分布或减少 n_bins。错误信息: {e}\")\n",
|
|||
|
|
" # 可以尝试先 rank 再 qcut\n",
|
|||
|
|
" # analysis_df[bin1_col] = pd.qcut(analysis_df[condition1_col].rank(method='first'), q=n_bins, labels=False, duplicates='raise')\n",
|
|||
|
|
" # analysis_df[bin2_col] = pd.qcut(analysis_df[condition2_col].rank(method='first'), q=n_bins, labels=False, duplicates='raise')\n",
|
|||
|
|
" return None, None, None\n",
|
|||
|
|
"\n",
|
|||
|
|
" # --- 3. 分组计算表现指标 (Spearman Rank IC) ---\n",
|
|||
|
|
" print(\"按二维分箱分组计算 Spearman Rank IC...\")\n",
|
|||
|
|
"\n",
|
|||
|
|
" def safe_spearmanr(x, y):\n",
|
|||
|
|
" \"\"\"安全计算 Spearman 相关性,处理数据量过少的情况\"\"\"\n",
|
|||
|
|
" if len(x) < max(2, min_samples_per_bin): # 要求至少有 min_samples_per_bin 个点才计算\n",
|
|||
|
|
" return np.nan\n",
|
|||
|
|
" corr, p_value = spearmanr(x, y)\n",
|
|||
|
|
" return corr if not np.isnan(corr) else np.nan # 确保返回 NaN 而不是 None 或其他\n",
|
|||
|
|
"\n",
|
|||
|
|
" # 按两个分箱列分组\n",
|
|||
|
|
" grouped = analysis_df.groupby([bin1_col, bin2_col])\n",
|
|||
|
|
"\n",
|
|||
|
|
" # 计算每个格子的 Spearman 相关系数\n",
|
|||
|
|
" # apply 可能较慢,但计算相关性通常需要 apply\n",
|
|||
|
|
" performance_series = grouped.apply(lambda sub: safe_spearmanr(sub[score_col], sub[label_col]))\n",
|
|||
|
|
"\n",
|
|||
|
|
" # 计算每个格子的样本数量\n",
|
|||
|
|
" count_series = grouped.size()\n",
|
|||
|
|
"\n",
|
|||
|
|
" # --- 4. 结果整理成 Pivot Table (用于绘图) ---\n",
|
|||
|
|
" print(\"整理结果用于绘图...\")\n",
|
|||
|
|
" try:\n",
|
|||
|
|
" # 将 performance_series 转换成二维矩阵\n",
|
|||
|
|
" # index 为 condition1_bin, columns 为 condition2_bin\n",
|
|||
|
|
" performance_pivot = performance_series.unstack(level=0) # level=0 对应第一个 groupby key (bin1_col)\n",
|
|||
|
|
" count_pivot = count_series.unstack(level=0)\n",
|
|||
|
|
"\n",
|
|||
|
|
" # 可选:按列和索引排序,确保顺序正确\n",
|
|||
|
|
" performance_pivot = performance_pivot.sort_index(axis=0).sort_index(axis=1)\n",
|
|||
|
|
" count_pivot = count_pivot.sort_index(axis=0).sort_index(axis=1)\n",
|
|||
|
|
" \n",
|
|||
|
|
" print(performance_pivot)\n",
|
|||
|
|
"\n",
|
|||
|
|
" except Exception as e:\n",
|
|||
|
|
" print(f\"错误: 无法将结果转换为二维矩阵,可能因为分箱不均匀或数据问题: {e}\")\n",
|
|||
|
|
" return None, None, None\n",
|
|||
|
|
"\n",
|
|||
|
|
" # --- 5. 可视化:绘制热力图 ---\n",
|
|||
|
|
" print(\"生成热力图...\")\n",
|
|||
|
|
" fig, ax = plt.subplots(figsize=(16, 12)) # 调整图像大小\n",
|
|||
|
|
"\n",
|
|||
|
|
" # 使用 count_pivot 创建一个 mask,屏蔽掉样本量过小的格子\n",
|
|||
|
|
" mask = count_pivot < min_samples_per_bin\n",
|
|||
|
|
"\n",
|
|||
|
|
" # 绘制热力图\n",
|
|||
|
|
" sns.heatmap(performance_pivot,\n",
|
|||
|
|
" annot=False, # 100x100 个格子加注释会太密集\n",
|
|||
|
|
" fmt=\".2f\",\n",
|
|||
|
|
" cmap=\"viridis\", # 选择颜色映射, 'viridis', 'coolwarm', 'RdYlGn' 等都不错\n",
|
|||
|
|
" linewidths=.5,\n",
|
|||
|
|
" linecolor='lightgray',\n",
|
|||
|
|
" # mask=mask, # 应用 mask\n",
|
|||
|
|
" ax=ax,\n",
|
|||
|
|
" cbar_kws={'label': f'Spearman Rank IC ({score_col} vs {label_col})'}) # 颜色条标签\n",
|
|||
|
|
"\n",
|
|||
|
|
" # 设置标题和轴标签\n",
|
|||
|
|
" ax.set_title(f'{score_col} 表现分析 (Rank IC vs {label_col})\\n基于 {condition1_col} 和 {condition2_col} {n_bins}x{n_bins} 分箱', fontsize=16)\n",
|
|||
|
|
" ax.set_xlabel(f'{condition1_col} 分位数 (0 -> 高)', fontsize=12)\n",
|
|||
|
|
" ax.set_ylabel(f'{condition2_col} 分位数 (0 -> 高)', fontsize=12)\n",
|
|||
|
|
"\n",
|
|||
|
|
" # 可选:调整刻度标签,避免显示所有 100 个刻度\n",
|
|||
|
|
" if n_bins > 20:\n",
|
|||
|
|
" tick_interval = n_bins // 10 # 大约显示 10 个刻度\n",
|
|||
|
|
" ax.set_xticks(np.arange(0, n_bins, tick_interval) + 0.5)\n",
|
|||
|
|
" ax.set_yticks(np.arange(0, n_bins, tick_interval) + 0.5)\n",
|
|||
|
|
" ax.set_xticklabels(np.arange(0, n_bins, tick_interval))\n",
|
|||
|
|
" ax.set_yticklabels(np.arange(0, n_bins, tick_interval))\n",
|
|||
|
|
"\n",
|
|||
|
|
" plt.xticks(rotation=45, ha='right')\n",
|
|||
|
|
" plt.yticks(rotation=0)\n",
|
|||
|
|
" plt.tight_layout() # 调整布局\n",
|
|||
|
|
"\n",
|
|||
|
|
" print(\"分析完成。\")\n",
|
|||
|
|
" return performance_pivot, count_pivot, fig\n",
|
|||
|
|
"\n",
|
|||
|
|
"# --- 如何使用 ---\n",
|
|||
|
|
"# 假设你的包含预测结果和所需列的 DataFrame 是 final_predictions_df\n",
|
|||
|
|
"# 确保它包含 'score', 'label', 'circ_mv', 'future_return'\n",
|
|||
|
|
"\n",
|
|||
|
|
"# # 示例调用 (你需要有实际的 score_df)\n",
|
|||
|
|
"try:\n",
|
|||
|
|
" # 确保数据类型正确\n",
|
|||
|
|
" cols_to_numeric = ['score', 'label', 'circ_mv', 'future_return']\n",
|
|||
|
|
" for col in cols_to_numeric:\n",
|
|||
|
|
" if col in score_df.columns:\n",
|
|||
|
|
" score_df[col] = pd.to_numeric(score_df[col], errors='coerce')\n",
|
|||
|
|
"\n",
|
|||
|
|
" # 调用分析函数\n",
|
|||
|
|
" performance_matrix, count_matrix, heatmap_figure = analyze_score_performance_2d(\n",
|
|||
|
|
" score_df,\n",
|
|||
|
|
" n_bins=100, # 你要求的100分箱\n",
|
|||
|
|
" min_samples_per_bin=50 # 每个格子至少需要50个样本才显示IC,可以调整\n",
|
|||
|
|
" )\n",
|
|||
|
|
"\n",
|
|||
|
|
" # 显示图像\n",
|
|||
|
|
" if heatmap_figure:\n",
|
|||
|
|
" plt.show()\n",
|
|||
|
|
"\n",
|
|||
|
|
" # 可以查看具体的 performance_matrix 和 count_matrix\n",
|
|||
|
|
" # print(\"\\nPerformance Matrix (Spearman IC):\")\n",
|
|||
|
|
" # print(performance_matrix)\n",
|
|||
|
|
" # print(\"\\nCount Matrix:\")\n",
|
|||
|
|
" # print(count_matrix)\n",
|
|||
|
|
"\n",
|
|||
|
|
"except ValueError as ve:\n",
|
|||
|
|
" print(f\"数据错误: {ve}\")\n",
|
|||
|
|
"except Exception as e:\n",
|
|||
|
|
" print(f\"发生未知错误: {e}\")"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"execution_count": 58,
|
|||
|
|
"id": "a436dba4",
|
|||
|
|
"metadata": {},
|
|||
|
|
"outputs": [
|
|||
|
|
{
|
|||
|
|
"name": "stdout",
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"text": [
|
|||
|
|
"Empty DataFrame\n",
|
|||
|
|
"Columns: [ts_code, trade_date, is_st]\n",
|
|||
|
|
"Index: []\n"
|
|||
|
|
]
|
|||
|
|
}
|
|||
|
|
],
|
|||
|
|
"source": [
|
|||
|
|
"print(df[(df['ts_code'] == '600242.SH') & (df['trade_date'] >= '2023-06-01')][['ts_code', 'trade_date', 'is_st']])"
|
|||
|
|
]
|
|||
|
|
}
|
|||
|
|
],
|
|||
|
|
"metadata": {
|
|||
|
|
"kernelspec": {
|
|||
|
|
"display_name": "new_trader",
|
|||
|
|
"language": "python",
|
|||
|
|
"name": "python3"
|
|||
|
|
},
|
|||
|
|
"language_info": {
|
|||
|
|
"codemirror_mode": {
|
|||
|
|
"name": "ipython",
|
|||
|
|
"version": 3
|
|||
|
|
},
|
|||
|
|
"file_extension": ".py",
|
|||
|
|
"mimetype": "text/x-python",
|
|||
|
|
"name": "python",
|
|||
|
|
"nbconvert_exporter": "python",
|
|||
|
|
"pygments_lexer": "ipython3",
|
|||
|
|
"version": "3.11.11"
|
|||
|
|
}
|
|||
|
|
},
|
|||
|
|
"nbformat": 4,
|
|||
|
|
"nbformat_minor": 5
|
|||
|
|
}
|