2474 lines
149 KiB
Plaintext
2474 lines
149 KiB
Plaintext
|
|
{
|
|||
|
|
"cells": [
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"execution_count": 1,
|
|||
|
|
"id": "79a7758178bafdd3",
|
|||
|
|
"metadata": {
|
|||
|
|
"ExecuteTime": {
|
|||
|
|
"end_time": "2025-04-03T12:46:06.987506Z",
|
|||
|
|
"start_time": "2025-04-03T12:46:06.259551Z"
|
|||
|
|
},
|
|||
|
|
"jupyter": {
|
|||
|
|
"source_hidden": true
|
|||
|
|
}
|
|||
|
|
},
|
|||
|
|
"outputs": [
|
|||
|
|
{
|
|||
|
|
"name": "stdout",
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"text": [
|
|||
|
|
"/mnt/d/PyProject/NewStock\n"
|
|||
|
|
]
|
|||
|
|
}
|
|||
|
|
],
|
|||
|
|
"source": [
|
|||
|
|
"%load_ext autoreload\n",
|
|||
|
|
"%autoreload 2\n",
|
|||
|
|
"\n",
|
|||
|
|
"import gc\n",
|
|||
|
|
"import os\n",
|
|||
|
|
"import sys\n",
|
|||
|
|
"sys.path.append('/mnt/d/PyProject/NewStock/')\n",
|
|||
|
|
"print(os.getcwd())\n",
|
|||
|
|
"import pandas as pd\n",
|
|||
|
|
"from main.factor.factor import get_rolling_factor, get_simple_factor\n",
|
|||
|
|
"from main.utils.factor import read_industry_data\n",
|
|||
|
|
"from main.utils.factor_processor import calculate_score\n",
|
|||
|
|
"from main.utils.utils import read_and_merge_h5_data, merge_with_industry_data\n",
|
|||
|
|
"\n",
|
|||
|
|
"import warnings\n",
|
|||
|
|
"\n",
|
|||
|
|
"warnings.filterwarnings(\"ignore\")"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"execution_count": 2,
|
|||
|
|
"id": "a79cafb06a7e0e43",
|
|||
|
|
"metadata": {
|
|||
|
|
"ExecuteTime": {
|
|||
|
|
"end_time": "2025-04-03T12:47:00.212859Z",
|
|||
|
|
"start_time": "2025-04-03T12:46:06.998047Z"
|
|||
|
|
},
|
|||
|
|
"scrolled": true
|
|||
|
|
},
|
|||
|
|
"outputs": [
|
|||
|
|
{
|
|||
|
|
"name": "stdout",
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"text": [
|
|||
|
|
"daily data\n",
|
|||
|
|
"daily basic\n",
|
|||
|
|
"inner merge on ['ts_code', 'trade_date']\n",
|
|||
|
|
"stk limit\n",
|
|||
|
|
"left merge on ['ts_code', 'trade_date']\n",
|
|||
|
|
"money flow\n",
|
|||
|
|
"left merge on ['ts_code', 'trade_date']\n",
|
|||
|
|
"cyq perf\n",
|
|||
|
|
"left merge on ['ts_code', 'trade_date']\n",
|
|||
|
|
"<class 'pandas.core.frame.DataFrame'>\n",
|
|||
|
|
"RangeIndex: 8692146 entries, 0 to 8692145\n",
|
|||
|
|
"Data columns (total 33 columns):\n",
|
|||
|
|
" # Column Dtype \n",
|
|||
|
|
"--- ------ ----- \n",
|
|||
|
|
" 0 ts_code object \n",
|
|||
|
|
" 1 trade_date datetime64[ns]\n",
|
|||
|
|
" 2 open float64 \n",
|
|||
|
|
" 3 close float64 \n",
|
|||
|
|
" 4 high float64 \n",
|
|||
|
|
" 5 low float64 \n",
|
|||
|
|
" 6 vol float64 \n",
|
|||
|
|
" 7 pct_chg float64 \n",
|
|||
|
|
" 8 amount float64 \n",
|
|||
|
|
" 9 turnover_rate float64 \n",
|
|||
|
|
" 10 pe_ttm float64 \n",
|
|||
|
|
" 11 circ_mv float64 \n",
|
|||
|
|
" 12 total_mv float64 \n",
|
|||
|
|
" 13 volume_ratio float64 \n",
|
|||
|
|
" 14 is_st bool \n",
|
|||
|
|
" 15 up_limit float64 \n",
|
|||
|
|
" 16 down_limit float64 \n",
|
|||
|
|
" 17 buy_sm_vol float64 \n",
|
|||
|
|
" 18 sell_sm_vol float64 \n",
|
|||
|
|
" 19 buy_lg_vol float64 \n",
|
|||
|
|
" 20 sell_lg_vol float64 \n",
|
|||
|
|
" 21 buy_elg_vol float64 \n",
|
|||
|
|
" 22 sell_elg_vol float64 \n",
|
|||
|
|
" 23 net_mf_vol float64 \n",
|
|||
|
|
" 24 his_low float64 \n",
|
|||
|
|
" 25 his_high float64 \n",
|
|||
|
|
" 26 cost_5pct float64 \n",
|
|||
|
|
" 27 cost_15pct float64 \n",
|
|||
|
|
" 28 cost_50pct float64 \n",
|
|||
|
|
" 29 cost_85pct float64 \n",
|
|||
|
|
" 30 cost_95pct float64 \n",
|
|||
|
|
" 31 weight_avg float64 \n",
|
|||
|
|
" 32 winner_rate float64 \n",
|
|||
|
|
"dtypes: bool(1), datetime64[ns](1), float64(30), object(1)\n",
|
|||
|
|
"memory usage: 2.1+ GB\n",
|
|||
|
|
"None\n"
|
|||
|
|
]
|
|||
|
|
}
|
|||
|
|
],
|
|||
|
|
"source": [
|
|||
|
|
"from main.utils.utils import read_and_merge_h5_data\n",
"\n",
"DATA_DIR = '/mnt/d/PyProject/NewStock/data'\n",
"\n",
"# Daily sources merged on ['ts_code', 'trade_date'], in this order.\n",
"# Each entry: (progress label, file stem == HDF key, columns to load, extra merge kwargs).\n",
"# Only 'daily_basic' is inner-joined; the other sources use the reader's default join.\n",
"daily_sources = [\n",
"    ('daily data', 'daily_data',\n",
"     ['ts_code', 'trade_date', 'open', 'close', 'high', 'low', 'vol', 'pct_chg', 'amount'], {}),\n",
"    ('daily basic', 'daily_basic',\n",
"     ['ts_code', 'trade_date', 'turnover_rate', 'pe_ttm', 'circ_mv', 'total_mv', 'volume_ratio', 'is_st'],\n",
"     {'join': 'inner'}),\n",
"    ('stk limit', 'stk_limit',\n",
"     ['ts_code', 'trade_date', 'pre_close', 'up_limit', 'down_limit'], {}),\n",
"    ('money flow', 'money_flow',\n",
"     ['ts_code', 'trade_date', 'buy_sm_vol', 'sell_sm_vol', 'buy_lg_vol', 'sell_lg_vol',\n",
"      'buy_elg_vol', 'sell_elg_vol', 'net_mf_vol'], {}),\n",
"    ('cyq perf', 'cyq_perf',\n",
"     ['ts_code', 'trade_date', 'his_low', 'his_high', 'cost_5pct', 'cost_15pct', 'cost_50pct',\n",
"      'cost_85pct', 'cost_95pct', 'weight_avg', 'winner_rate'], {}),\n",
"]\n",
"\n",
"df = None\n",
"for src_label, src_name, src_columns, src_kwargs in daily_sources:\n",
"    print(src_label)\n",
"    df = read_and_merge_h5_data(f'{DATA_DIR}/{src_name}.h5', key=src_name,\n",
"                                columns=src_columns, df=df, **src_kwargs)\n",
"\n",
"# A requested column can go missing without an error: in the saved run 'pre_close'\n",
"# (requested from stk_limit) is absent from df.info(), and later factors report it\n",
"# missing. Surface such gaps right after loading.\n",
"requested_columns = {col for _, _, cols, _ in daily_sources for col in cols}\n",
"missing_columns = sorted(requested_columns - set(df.columns))\n",
"if missing_columns:\n",
"    print(f'WARNING: requested columns missing after merge: {missing_columns}')\n",
"\n",
"# DataFrame.info() prints its own report and returns None, so print(df.info())\n",
"# only added a stray 'None' line.\n",
"df.info()"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"execution_count": 3,
|
|||
|
|
"id": "cac01788dac10678",
|
|||
|
|
"metadata": {
|
|||
|
|
"ExecuteTime": {
|
|||
|
|
"end_time": "2025-04-03T12:47:10.527104Z",
|
|||
|
|
"start_time": "2025-04-03T12:47:00.488715Z"
|
|||
|
|
}
|
|||
|
|
},
|
|||
|
|
"outputs": [
|
|||
|
|
{
|
|||
|
|
"name": "stdout",
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"text": [
|
|||
|
|
"industry\n"
|
|||
|
|
]
|
|||
|
|
}
|
|||
|
|
],
|
|||
|
|
"source": [
|
|||
|
|
"print('industry')\n",
"industry_df = read_and_merge_h5_data('/mnt/d/PyProject/NewStock/data/industry_data.h5', key='industry_data',\n",
"                                     columns=['ts_code', 'l2_code', 'in_date'],\n",
"                                     df=None, on=['ts_code'], join='left')\n",
"\n",
"\n",
"def merge_with_industry_data(df, industry_df):\n",
"    \"\"\"Attach a point-in-time industry code (l2_code) to every (ts_code, trade_date) row.\n",
"\n",
"    Each row takes the industry record of the same ts_code with the latest\n",
"    in_date <= trade_date (pd.merge_asof, direction='backward'). Rows dated before a\n",
"    stock's first in_date fall back to that stock's earliest l2_code.\n",
"\n",
"    Note: this redefines the merge_with_industry_data imported from main.utils.utils\n",
"    in the first cell.\n",
"\n",
"    Args:\n",
"        df: frame with 'ts_code' and 'trade_date'. Its 'trade_date' column is converted\n",
"            to datetime in place (kept in place to avoid one more copy of the large frame).\n",
"        industry_df: frame with 'ts_code', 'l2_code' and 'in_date'; it is not modified.\n",
"            Records with a null in_date are ignored.\n",
"\n",
"    Returns:\n",
"        A new frame sorted by (trade_date, ts_code) with a fresh RangeIndex and the\n",
"        industry table's other columns (l2_code, in_date) appended.\n",
"    \"\"\"\n",
"    df['trade_date'] = pd.to_datetime(df['trade_date'])\n",
"\n",
"    # Convert on a copy so the caller's industry table is left untouched, and drop records\n",
"    # without an in_date: merge_asof raises on null keys in the right frame.\n",
"    industry_sorted = (industry_df\n",
"                       .assign(in_date=pd.to_datetime(industry_df['in_date']))\n",
"                       .dropna(subset=['in_date'])\n",
"                       .sort_values(['in_date', 'ts_code']))\n",
"\n",
"    # merge_asof requires both frames to be sorted by their asof key.\n",
"    df_sorted = df.sort_values(['trade_date', 'ts_code'])\n",
"\n",
"    merged = pd.merge_asof(\n",
"        df_sorted,\n",
"        industry_sorted,\n",
"        by='ts_code',\n",
"        left_on='trade_date',\n",
"        right_on='in_date',\n",
"        direction='backward',\n",
"    )\n",
"\n",
"    # industry_sorted is ordered by in_date, so first() yields each stock's earliest\n",
"    # non-null l2_code; use it for rows that precede every in_date of that stock.\n",
"    earliest_l2_code = industry_sorted.groupby('ts_code')['l2_code'].first()\n",
"    merged['l2_code'] = merged['l2_code'].fillna(merged['ts_code'].map(earliest_l2_code))\n",
"\n",
"    return merged.reset_index(drop=True)\n",
"\n",
"\n",
"df = merge_with_industry_data(df, industry_df)"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"execution_count": 4,
|
|||
|
|
"id": "c4e9e1d31da6dba6",
|
|||
|
|
"metadata": {
|
|||
|
|
"ExecuteTime": {
|
|||
|
|
"end_time": "2025-04-03T12:47:10.719252Z",
|
|||
|
|
"start_time": "2025-04-03T12:47:10.541247Z"
|
|||
|
|
},
|
|||
|
|
"jupyter": {
|
|||
|
|
"source_hidden": true
|
|||
|
|
}
|
|||
|
|
},
|
|||
|
|
"outputs": [],
|
|||
|
|
"source": [
|
|||
|
|
"from main.factor.factor import *\n",
"\n",
"\n",
"def calculate_indicators(df):\n",
"    \"\"\"Compute technical and sentiment indicators for the bars of a single index.\n",
"\n",
"    Expects one ts_code's rows with 'trade_date', 'close', 'pre_close', 'vol' and 'amount'.\n",
"    Returns a copy sorted by trade_date with these columns added:\n",
"      - daily_return: change of close versus pre_close, in percent\n",
"      - RSI: 14-day RSI from simple rolling means of gains and losses\n",
"      - MACD, Signal_line, MACD_hist: EMA(12) - EMA(26), its EMA(9), and their difference\n",
"      - up_ratio, up_ratio_20d: up-day flag and its 20-day mean\n",
"      - volume_mean, volume_change_rate: 20-day mean volume and % deviation from it\n",
"      - volatility: 20-day standard deviation of daily_return\n",
"      - amount_mean, amount_change_rate: 20-day mean amount and % deviation from it\n",
"    \"\"\"\n",
"    df = df.sort_values('trade_date')\n",
"    df['daily_return'] = (df['close'] - df['pre_close']) / df['pre_close'] * 100\n",
"\n",
"    # RSI(14) using simple rolling means of gains and losses.\n",
"    delta = df['close'].diff()\n",
"    gain = delta.where(delta > 0, 0)\n",
"    loss = -delta.where(delta < 0, 0)\n",
"    avg_gain = gain.rolling(window=14).mean()\n",
"    avg_loss = loss.rolling(window=14).mean()\n",
"    rs = avg_gain / avg_loss\n",
"    df['RSI'] = 100 - (100 / (1 + rs))\n",
"\n",
"    # MACD: EMA(12) - EMA(26), with an EMA(9) signal line.\n",
"    ema12 = df['close'].ewm(span=12, adjust=False).mean()\n",
"    ema26 = df['close'].ewm(span=26, adjust=False).mean()\n",
"    df['MACD'] = ema12 - ema26\n",
"    df['Signal_line'] = df['MACD'].ewm(span=9, adjust=False).mean()\n",
"    df['MACD_hist'] = df['MACD'] - df['Signal_line']\n",
"\n",
"    # Sentiment 1: share of up days over the last 20 days (a NaN return counts as not up).\n",
"    df['up_ratio'] = (df['daily_return'] > 0).astype(int)\n",
"    df['up_ratio_20d'] = df['up_ratio'].rolling(window=20).mean()\n",
"\n",
"    # Sentiment 2: volume relative to its 20-day mean, in percent.\n",
"    df['volume_mean'] = df['vol'].rolling(window=20).mean()\n",
"    df['volume_change_rate'] = (df['vol'] - df['volume_mean']) / df['volume_mean'] * 100\n",
"\n",
"    # Sentiment 3: 20-day standard deviation of the daily return.\n",
"    df['volatility'] = df['daily_return'].rolling(window=20).std()\n",
"\n",
"    # Sentiment 4: traded amount relative to its 20-day mean, in percent.\n",
"    df['amount_mean'] = df['amount'].rolling(window=20).mean()\n",
"    df['amount_change_rate'] = (df['amount'] - df['amount_mean']) / df['amount_mean'] * 100\n",
"\n",
"    return df\n",
"\n",
"\n",
"def generate_index_indicators(h5_filename, key='index_data'):\n",
"    \"\"\"Build one wide row per trade_date holding every index's indicators.\n",
"\n",
"    Args:\n",
"        h5_filename: path of the HDF5 file with the index bars.\n",
"        key: HDF key to read (default 'index_data', the previously hard-coded key).\n",
"\n",
"    Returns:\n",
"        DataFrame with a 'trade_date' column plus one '<ts_code>_<indicator>' column\n",
"        per index and selected indicator from calculate_indicators.\n",
"    \"\"\"\n",
"    df = pd.read_hdf(h5_filename, key=key)\n",
"    df['trade_date'] = pd.to_datetime(df['trade_date'], format='%Y%m%d')\n",
"    df = df.sort_values('trade_date')\n",
"\n",
"    # One groupby pass instead of a full boolean mask per index code;\n",
"    # sort=False keeps the codes in order of first appearance, as unique() did.\n",
"    per_index = [calculate_indicators(group) for _, group in df.groupby('ts_code', sort=False)]\n",
"    df_all_indicators = pd.concat(per_index, ignore_index=True)\n",
"\n",
"    # Pivot to one row per trade_date with (indicator, ts_code) columns.\n",
"    df_final = df_all_indicators.pivot_table(\n",
"        index='trade_date',\n",
"        columns='ts_code',\n",
"        values=['daily_return',\n",
"                'RSI', 'MACD', 'Signal_line', 'MACD_hist',\n",
"                'up_ratio_20d', 'volume_change_rate', 'volatility',\n",
"                'amount_change_rate', 'amount_mean'],\n",
"        aggfunc='last'\n",
"    )\n",
"\n",
"    df_final.columns = [f'{ts_code}_{indicator}' for indicator, ts_code in df_final.columns]\n",
"    return df_final.reset_index()\n",
"\n",
"\n",
"h5_filename = '/mnt/d/PyProject/NewStock/data/index_data.h5'\n",
"index_data = generate_index_indicators(h5_filename)\n",
"# Drops rolling-window warm-up rows and any date on which some index column is NaN.\n",
"index_data = index_data.dropna()\n"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"execution_count": 5,
|
|||
|
|
"id": "a735bc02ceb4d872",
|
|||
|
|
"metadata": {
|
|||
|
|
"ExecuteTime": {
|
|||
|
|
"end_time": "2025-04-03T12:47:10.821169Z",
|
|||
|
|
"start_time": "2025-04-03T12:47:10.751831Z"
|
|||
|
|
}
|
|||
|
|
},
|
|||
|
|
"outputs": [],
|
|||
|
|
"source": [
|
|||
|
|
"import talib\n",
|
|||
|
|
"import numpy as np"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"execution_count": 6,
|
|||
|
|
"id": "53f86ddc0677a6d7",
|
|||
|
|
"metadata": {
|
|||
|
|
"ExecuteTime": {
|
|||
|
|
"end_time": "2025-04-03T12:47:15.944254Z",
|
|||
|
|
"start_time": "2025-04-03T12:47:10.826179Z"
|
|||
|
|
},
|
|||
|
|
"jupyter": {
|
|||
|
|
"source_hidden": true
|
|||
|
|
},
|
|||
|
|
"scrolled": true
|
|||
|
|
},
|
|||
|
|
"outputs": [],
|
|||
|
|
"source": [
|
|||
|
|
"from main.utils.factor import get_act_factor\n",
"\n",
"\n",
"def read_industry_data(h5_filename):\n",
"    \"\"\"Load industry daily bars (HDF key 'sw_daily') and derive industry-level factors.\n",
"\n",
"    For each industry code (ts_code) this adds:\n",
"      - OBV (talib.OBV);\n",
"      - 5- and 20-day close-to-close returns;\n",
"      - the factors from get_act_factor(cat=False);\n",
"      - each return's cross-sectional percentile rank per trade_date.\n",
"    The raw price/valuation columns are then dropped, every remaining factor column is\n",
"    prefixed with 'industry_', and ts_code is renamed to cat_l2_code. The result is\n",
"    sorted by (trade_date, cat_l2_code).\n",
"\n",
"    Note: this redefines the read_industry_data imported from main.utils.factor in the\n",
"    first cell.\n",
"    \"\"\"\n",
"    industry_data = pd.read_hdf(h5_filename, key='sw_daily', columns=[\n",
"        'ts_code', 'trade_date', 'open', 'close', 'high', 'low', 'pe', 'pb', 'vol'\n",
"    ])\n",
"    industry_data['trade_date'] = pd.to_datetime(industry_data['trade_date'], format='%Y%m%d')\n",
"    # Each industry's rows must be in date order for the time-series factors below.\n",
"    # The previous bare reindex() call was a no-op. Resetting the index guarantees unique\n",
"    # labels, so the group-wise results assigned back below align row for row.\n",
"    industry_data = industry_data.sort_values(by=['ts_code', 'trade_date']).reset_index(drop=True)\n",
"\n",
"    grouped = industry_data.groupby('ts_code', group_keys=False)\n",
"    industry_data['obv'] = grouped.apply(\n",
"        lambda x: pd.Series(talib.OBV(x['close'].values, x['vol'].values), index=x.index)\n",
"    )\n",
"    # Vectorised per-group shift: same values as x / x.shift(n) - 1 within each group.\n",
"    industry_data['return_5'] = industry_data['close'] / grouped['close'].shift(5) - 1\n",
"    industry_data['return_20'] = industry_data['close'] / grouped['close'].shift(20) - 1\n",
"\n",
"    industry_data = get_act_factor(industry_data, cat=False)\n",
"    industry_data = industry_data.sort_values(by=['trade_date', 'ts_code'])\n",
"\n",
"    # Cross-sectional percentile rank of each return among all industries on the same day.\n",
"    by_date = industry_data.groupby('trade_date')\n",
"    industry_data['return_5_percentile'] = by_date['return_5'].rank(pct=True)\n",
"    industry_data['return_20_percentile'] = by_date['return_20'].rank(pct=True)\n",
"\n",
"    industry_data = industry_data.drop(columns=['open', 'close', 'high', 'low', 'pe', 'pb', 'vol'])\n",
"\n",
"    renames = {col: f'industry_{col}' for col in industry_data.columns if col not in ['ts_code', 'trade_date']}\n",
"    renames['ts_code'] = 'cat_l2_code'\n",
"    return industry_data.rename(columns=renames)\n",
"\n",
"\n",
"industry_df = read_industry_data('/mnt/d/PyProject/NewStock/data/sw_daily.h5')\n"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"execution_count": 7,
|
|||
|
|
"id": "dbe2fd8021b9417f",
|
|||
|
|
"metadata": {
|
|||
|
|
"ExecuteTime": {
|
|||
|
|
"end_time": "2025-04-03T12:47:15.969344Z",
|
|||
|
|
"start_time": "2025-04-03T12:47:15.963327Z"
|
|||
|
|
}
|
|||
|
|
},
|
|||
|
|
"outputs": [
|
|||
|
|
{
|
|||
|
|
"name": "stdout",
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"text": [
|
|||
|
|
"['ts_code', 'open', 'close', 'high', 'low', 'amount', 'circ_mv', 'total_mv', 'is_st', 'up_limit', 'down_limit', 'buy_sm_vol', 'sell_sm_vol', 'buy_lg_vol', 'sell_lg_vol', 'buy_elg_vol', 'sell_elg_vol', 'net_mf_vol', 'his_low', 'his_high', 'cost_5pct', 'cost_15pct', 'cost_50pct', 'cost_85pct', 'cost_95pct', 'weight_avg', 'in_date']\n"
|
|||
|
|
]
|
|||
|
|
}
|
|||
|
|
],
|
|||
|
|
"source": [
|
|||
|
|
"# Raw columns to keep: every column of df except a fixed exclusion list, columns\n",
"# that also exist in index_data, and any column whose name contains 'cyq'.\n",
"origin_columns = [\n",
"    col for col in df.columns.tolist()\n",
"    if col not in {'turnover_rate', 'pe_ttm', 'volume_ratio', 'vol', 'pct_chg', 'l2_code', 'winner_rate'}\n",
"    and col not in index_data.columns\n",
"    and 'cyq' not in col\n",
"]\n",
"print(origin_columns)"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"execution_count": 8,
|
|||
|
|
"id": "85c3e3d0235ffffa",
|
|||
|
|
"metadata": {
|
|||
|
|
"ExecuteTime": {
|
|||
|
|
"end_time": "2025-04-03T12:47:16.089879Z",
|
|||
|
|
"start_time": "2025-04-03T12:47:15.990101Z"
|
|||
|
|
}
|
|||
|
|
},
|
|||
|
|
"outputs": [],
|
|||
|
|
"source": [
|
|||
|
|
"fina_indicator_df = read_and_merge_h5_data('/mnt/d/PyProject/NewStock/data/fina_indicator.h5', key='fina_indicator',\n",
"                                           columns=['ts_code', 'ann_date', 'undist_profit_ps', 'ocfps', 'bps'],\n",
"                                           df=None)\n",
"cashflow_df = read_and_merge_h5_data('/mnt/d/PyProject/NewStock/data/cashflow.h5', key='cashflow',\n",
"                                     columns=['ts_code', 'ann_date', 'n_cashflow_act'],\n",
"                                     df=None)\n",
"balancesheet_df = read_and_merge_h5_data('/mnt/d/PyProject/NewStock/data/balancesheet.h5', key='balancesheet',\n",
"                                         columns=['ts_code', 'ann_date', 'money_cap', 'total_liab'],\n",
"                                         df=None)\n",
"top_list_df = read_and_merge_h5_data('/mnt/d/PyProject/NewStock/data/top_list.h5', key='top_list',\n",
"                                     columns=['ts_code', 'trade_date', 'reason'],\n",
"                                     df=None)\n",
"\n",
"# Keep one row per (ts_code, trade_date): the first reason in the order returned by the\n",
"# reader. Then order the rows by date.\n",
"# Sorting by trade_date before de-duplicating (as before) never decided which reason\n",
"# survived: duplicates share the same trade_date, and the default sort is not stable, so\n",
"# the kept row was arbitrary. A stable sort keeps same-day rows in their original order.\n",
"top_list_df = (top_list_df\n",
"               .drop_duplicates(subset=['ts_code', 'trade_date'], keep='first')\n",
"               .sort_values(by='trade_date', kind='stable'))\n",
"\n",
"stk_holdertrade_df = read_and_merge_h5_data('/mnt/d/PyProject/NewStock/data/stk_holdertrade.h5', key='stk_holdertrade',\n",
"                                            columns=['ts_code', 'ann_date', 'in_de', 'change_ratio'],\n",
"                                            df=None)"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"execution_count": 9,
|
|||
|
|
"id": "92d84ce15a562ec6",
|
|||
|
|
"metadata": {
|
|||
|
|
"ExecuteTime": {
|
|||
|
|
"end_time": "2025-04-03T13:08:01.612695Z",
|
|||
|
|
"start_time": "2025-04-03T12:47:16.121802Z"
|
|||
|
|
}
|
|||
|
|
},
|
|||
|
|
"outputs": [
|
|||
|
|
{
|
|||
|
|
"name": "stdout",
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"text": [
|
|||
|
|
"开始计算股东增减持因子...\n",
|
|||
|
|
"警告: 'in_de' 列中存在未映射的值,可能导致 _direction 列出现NaN。\n",
|
|||
|
|
"股东增减持因子计算完成。\n",
|
|||
|
|
"Calculating cat_senti_mom_vol_spike...\n",
|
|||
|
|
"Finished cat_senti_mom_vol_spike.\n",
|
|||
|
|
"Calculating cat_senti_pre_breakout...\n",
|
|||
|
|
"Calculating atr_10 as it's missing...\n",
|
|||
|
|
"Calculating atr_40 as it's missing...\n",
|
|||
|
|
"Finished cat_senti_pre_breakout.\n",
|
|||
|
|
"计算因子 ts_turnover_rate_acceleration_5_20\n",
|
|||
|
|
"计算因子 ts_vol_sustain_10_30\n",
|
|||
|
|
"计算因子 cs_amount_outlier_10\n",
|
|||
|
|
"计算因子 ts_ff_to_total_turnover_ratio\n",
|
|||
|
|
"计算因子 ts_price_volume_trend_coherence_5_20\n",
|
|||
|
|
"计算因子 ts_ff_turnover_rate_surge_10\n",
|
|||
|
|
"使用 'ann_date' 作为财务数据生效日期。\n",
|
|||
|
|
"警告: 从 financial_data_subset 中移除了 366 行,因为其 'ts_code' 或 'ann_date' 列存在空值。\n",
|
|||
|
|
"使用 'ann_date' 作为财务数据生效日期。\n",
|
|||
|
|
"警告: 从 financial_data_subset 中移除了 366 行,因为其 'ts_code' 或 'ann_date' 列存在空值。\n",
|
|||
|
|
"开始计算因子: AR, BR (原地修改)...\n",
|
|||
|
|
"因子 AR, BR 计算成功。\n",
|
|||
|
|
"因子 AR, BR 计算流程结束。\n",
|
|||
|
|
"使用 'ann_date' 作为财务数据生效日期。\n",
|
|||
|
|
"使用 'ann_date' 作为财务数据生效日期。\n",
|
|||
|
|
"使用 'ann_date' 作为财务数据生效日期。\n",
|
|||
|
|
"使用 'ann_date' 作为财务数据生效日期。\n",
|
|||
|
|
"警告: 从 financial_data_subset 中移除了 366 行,因为其 'ts_code' 或 'ann_date' 列存在空值。\n",
|
|||
|
|
"计算 BBI...\n",
|
|||
|
|
"--- 计算日级别偏离度 (使用 pct_chg) ---\n",
|
|||
|
|
"--- 计算日级别动量基准 (使用 pct_chg) ---\n",
|
|||
|
|
"日级别动量基准计算完成 (使用 pct_chg)。\n",
|
|||
|
|
"日级别偏离度计算完成 (使用 pct_chg)。\n",
|
|||
|
|
"--- 计算日级别行业偏离度 (使用 pct_chg 和行业基准) ---\n",
|
|||
|
|
"--- 计算日级别行业动量基准 (使用 pct_chg 和 cat_l2_code) ---\n",
|
|||
|
|
"错误: 计算日级别行业动量基准需要以下列: ['pct_chg', 'cat_l2_code', 'trade_date', 'ts_code']。\n",
|
|||
|
|
"错误: 计算日级别行业偏离度需要以下列: ['pct_chg', 'daily_industry_positive_benchmark', 'daily_industry_negative_benchmark']。请先运行 daily_industry_momentum_benchmark(df)。\n",
|
|||
|
|
"Index(['ts_code', 'trade_date', 'open', 'close', 'high', 'low', 'vol',\n",
|
|||
|
|
" 'pct_chg', 'amount', 'turnover_rate', 'pe_ttm', 'circ_mv', 'total_mv',\n",
|
|||
|
|
" 'volume_ratio', 'is_st', 'up_limit', 'down_limit', 'buy_sm_vol',\n",
|
|||
|
|
" 'sell_sm_vol', 'buy_lg_vol', 'sell_lg_vol', 'buy_elg_vol',\n",
|
|||
|
|
" 'sell_elg_vol', 'net_mf_vol', 'his_low', 'his_high', 'cost_5pct',\n",
|
|||
|
|
" 'cost_15pct', 'cost_50pct', 'cost_85pct', 'cost_95pct', 'weight_avg',\n",
|
|||
|
|
" 'winner_rate', 'l2_code', 'holder_net_change_sum_10d',\n",
|
|||
|
|
" 'holder_increase_days_10d', 'holder_decrease_days_10d',\n",
|
|||
|
|
" 'holder_any_increase_flag_10d', 'holder_any_decrease_flag_10d',\n",
|
|||
|
|
" 'holder_direction_score_10d', 'cat_senti_mom_vol_spike',\n",
|
|||
|
|
" 'cat_senti_pre_breakout', 'ts_turnover_rate_acceleration_5_20',\n",
|
|||
|
|
" 'ts_vol_sustain_10_30', 'cs_amount_outlier_10',\n",
|
|||
|
|
" 'ts_ff_to_total_turnover_ratio', 'ts_price_volume_trend_coherence_5_20',\n",
|
|||
|
|
" 'ts_ff_turnover_rate_surge_10', 'undist_profit_ps', 'ocfps', 'AR', 'BR',\n",
|
|||
|
|
" 'AR_BR', 'log_circ_mv', 'cashflow_to_ev_factor', 'book_to_price_ratio',\n",
|
|||
|
|
" 'turnover_rate_mean_5', 'variance_20', 'bbi_ratio_factor',\n",
|
|||
|
|
" 'daily_deviation', 'lg_elg_net_buy_vol', 'flow_lg_elg_intensity',\n",
|
|||
|
|
" 'sm_net_buy_vol', 'flow_divergence_diff', 'flow_divergence_ratio',\n",
|
|||
|
|
" 'total_buy_vol', 'lg_elg_buy_prop', 'flow_struct_buy_change',\n",
|
|||
|
|
" 'lg_elg_net_buy_vol_change', 'flow_lg_elg_accel',\n",
|
|||
|
|
" 'chip_concentration_range', 'chip_skewness', 'floating_chip_proxy',\n",
|
|||
|
|
" 'cost_support_15pct_change', 'cat_winner_price_zone',\n",
|
|||
|
|
" 'flow_chip_consistency', 'profit_taking_vs_absorb', '_is_positive',\n",
|
|||
|
|
" '_is_negative', 'cat_is_positive', '_pos_returns', '_neg_returns',\n",
|
|||
|
|
" '_pos_returns_sq', '_neg_returns_sq', 'upside_vol', 'downside_vol',\n",
|
|||
|
|
" 'vol_ratio', 'return_skew', 'return_kurtosis', 'volume_change_rate',\n",
|
|||
|
|
" 'cat_volume_breakout', 'turnover_deviation', 'cat_turnover_spike',\n",
|
|||
|
|
" 'avg_volume_ratio', 'cat_volume_ratio_breakout', 'vol_spike',\n",
|
|||
|
|
" 'vol_std_5', 'atr_24', 'atr_6', 'obv'],\n",
|
|||
|
|
" dtype='object')\n",
|
|||
|
|
"Calculating senti_strong_inflow...\n",
|
|||
|
|
"Finished senti_strong_inflow.\n",
|
|||
|
|
"Calculating lg_flow_mom_corr_20_60...\n",
|
|||
|
|
"Finished lg_flow_mom_corr_20_60.\n",
|
|||
|
|
"Calculating lg_flow_accel...\n",
|
|||
|
|
"Finished lg_flow_accel.\n",
|
|||
|
|
"Calculating profit_pressure...\n",
|
|||
|
|
"Finished profit_pressure.\n",
|
|||
|
|
"Calculating underwater_resistance...\n",
|
|||
|
|
"Finished underwater_resistance.\n",
|
|||
|
|
"Calculating cost_conc_std_20...\n",
|
|||
|
|
"Finished cost_conc_std_20.\n",
|
|||
|
|
"Calculating profit_decay_20...\n",
|
|||
|
|
"Finished profit_decay_20.\n",
|
|||
|
|
"Calculating vol_amp_loss_20...\n",
|
|||
|
|
"Finished vol_amp_loss_20.\n",
|
|||
|
|
"Calculating vol_drop_profit_cnt_5...\n",
|
|||
|
|
"Finished vol_drop_profit_cnt_5.\n",
|
|||
|
|
"Calculating lg_flow_vol_interact_20...\n",
|
|||
|
|
"Finished lg_flow_vol_interact_20.\n",
|
|||
|
|
"Calculating cost_break_confirm_cnt_5...\n",
|
|||
|
|
"Finished cost_break_confirm_cnt_5.\n",
|
|||
|
|
"Calculating atr_norm_channel_pos_14...\n",
|
|||
|
|
"Finished atr_norm_channel_pos_14.\n",
|
|||
|
|
"Calculating turnover_diff_skew_20...\n",
|
|||
|
|
"Finished turnover_diff_skew_20.\n",
|
|||
|
|
"Calculating lg_sm_flow_diverge_20...\n",
|
|||
|
|
"Finished lg_sm_flow_diverge_20.\n",
|
|||
|
|
"Calculating pullback_strong_20_20...\n",
|
|||
|
|
"Finished pullback_strong_20_20.\n",
|
|||
|
|
"Calculating vol_wgt_hist_pos_20...\n",
|
|||
|
|
"Finished vol_wgt_hist_pos_20.\n",
|
|||
|
|
"Calculating vol_adj_roc_20...\n",
|
|||
|
|
"Finished vol_adj_roc_20.\n",
|
|||
|
|
"Calculating cs_rank_net_lg_flow_val...\n",
|
|||
|
|
"Finished cs_rank_net_lg_flow_val.\n",
|
|||
|
|
"Calculating cs_rank_flow_divergence...\n",
|
|||
|
|
"Finished cs_rank_flow_divergence.\n",
|
|||
|
|
"Calculating cs_rank_ind_adj_lg_flow...\n",
|
|||
|
|
"Finished cs_rank_ind_adj_lg_flow.\n",
|
|||
|
|
"Calculating cs_rank_elg_buy_ratio...\n",
|
|||
|
|
"Finished cs_rank_elg_buy_ratio.\n",
|
|||
|
|
"Calculating cs_rank_rel_profit_margin...\n",
|
|||
|
|
"Finished cs_rank_rel_profit_margin.\n",
|
|||
|
|
"Calculating cs_rank_cost_breadth...\n",
|
|||
|
|
"Finished cs_rank_cost_breadth.\n",
|
|||
|
|
"Calculating cs_rank_dist_to_upper_cost...\n",
|
|||
|
|
"Finished cs_rank_dist_to_upper_cost.\n",
|
|||
|
|
"Calculating cs_rank_winner_rate...\n",
|
|||
|
|
"Finished cs_rank_winner_rate.\n",
|
|||
|
|
"Calculating cs_rank_intraday_range...\n",
|
|||
|
|
"Finished cs_rank_intraday_range.\n",
|
|||
|
|
"Calculating cs_rank_close_pos_in_range...\n",
|
|||
|
|
"Finished cs_rank_close_pos_in_range.\n",
|
|||
|
|
"Calculating cs_rank_opening_gap...\n",
|
|||
|
|
"Error calculating cs_rank_opening_gap: Missing 'pre_close' column. Assigning NaN.\n",
|
|||
|
|
"Calculating cs_rank_pos_in_hist_range...\n",
|
|||
|
|
"Finished cs_rank_pos_in_hist_range.\n",
|
|||
|
|
"Calculating cs_rank_vol_x_profit_margin...\n",
|
|||
|
|
"Finished cs_rank_vol_x_profit_margin.\n",
|
|||
|
|
"Calculating cs_rank_lg_flow_price_concordance...\n",
|
|||
|
|
"Finished cs_rank_lg_flow_price_concordance.\n",
|
|||
|
|
"Calculating cs_rank_turnover_per_winner...\n",
|
|||
|
|
"Finished cs_rank_turnover_per_winner.\n",
|
|||
|
|
"Calculating cs_rank_ind_cap_neutral_pe (Placeholder - requires statsmodels)...\n",
|
|||
|
|
"Finished cs_rank_ind_cap_neutral_pe (Placeholder).\n",
|
|||
|
|
"Calculating cs_rank_volume_ratio...\n",
|
|||
|
|
"Finished cs_rank_volume_ratio.\n",
|
|||
|
|
"Calculating cs_rank_elg_buy_sell_sm_ratio...\n",
|
|||
|
|
"Finished cs_rank_elg_buy_sell_sm_ratio.\n",
|
|||
|
|
"Calculating cs_rank_cost_dist_vol_ratio...\n",
|
|||
|
|
"Finished cs_rank_cost_dist_vol_ratio.\n",
|
|||
|
|
"Calculating cs_rank_size...\n",
|
|||
|
|
"Finished cs_rank_size.\n",
|
|||
|
|
"<class 'pandas.core.frame.DataFrame'>\n",
|
|||
|
|
"RangeIndex: 4554725 entries, 0 to 4554724\n",
|
|||
|
|
"Columns: 194 entries, ts_code to cs_rank_size\n",
|
|||
|
|
"dtypes: bool(10), datetime64[ns](1), float64(173), int64(6), object(4)\n",
|
|||
|
|
"memory usage: 6.3+ GB\n",
|
|||
|
|
"None\n",
|
|||
|
|
"['ts_code', 'trade_date', 'open', 'close', 'high', 'low', 'vol', 'pct_chg', 'amount', 'turnover_rate', 'pe_ttm', 'circ_mv', 'total_mv', 'volume_ratio', 'is_st', 'up_limit', 'down_limit', 'buy_sm_vol', 'sell_sm_vol', 'buy_lg_vol', 'sell_lg_vol', 'buy_elg_vol', 'sell_elg_vol', 'net_mf_vol', 'his_low', 'his_high', 'cost_5pct', 'cost_15pct', 'cost_50pct', 'cost_85pct', 'cost_95pct', 'weight_avg', 'winner_rate', 'cat_l2_code', 'holder_net_change_sum_10d', 'holder_increase_days_10d', 'holder_decrease_days_10d', 'holder_any_increase_flag_10d', 'holder_any_decrease_flag_10d', 'holder_direction_score_10d', 'cat_senti_mom_vol_spike', 'cat_senti_pre_breakout', 'ts_turnover_rate_acceleration_5_20', 'ts_vol_sustain_10_30', 'cs_amount_outlier_10', 'ts_ff_to_total_turnover_ratio', 'ts_price_volume_trend_coherence_5_20', 'ts_ff_turnover_rate_surge_10', 'undist_profit_ps', 'ocfps', 'AR', 'BR', 'AR_BR', 'log_circ_mv', 'cashflow_to_ev_factor', 'book_to_price_ratio', 'turnover_rate_mean_5', 'variance_20', 'bbi_ratio_factor', 'daily_deviation', 'lg_elg_net_buy_vol', 'flow_lg_elg_intensity', 'sm_net_buy_vol', 'flow_divergence_diff', 'flow_divergence_ratio', 'total_buy_vol', 'lg_elg_buy_prop', 'flow_struct_buy_change', 'lg_elg_net_buy_vol_change', 'flow_lg_elg_accel', 'chip_concentration_range', 'chip_skewness', 'floating_chip_proxy', 'cost_support_15pct_change', 'cat_winner_price_zone', 'flow_chip_consistency', 'profit_taking_vs_absorb', 'cat_is_positive', 'upside_vol', 'downside_vol', 'vol_ratio', 'return_skew', 'return_kurtosis', 'volume_change_rate', 'cat_volume_breakout', 'turnover_deviation', 'cat_turnover_spike', 'avg_volume_ratio', 'cat_volume_ratio_breakout', 'vol_spike', 'vol_std_5', 'atr_24', 'atr_6', 'obv', 'maobv_6', 'rsi_3', 'return_5', 'return_20', 'std_return_5', 'std_return_90', 'std_return_90_2', 'act_factor1', 'act_factor2', 'act_factor3', 'act_factor4', 'rank_act_factor1', 'rank_act_factor2', 'rank_act_factor3', 'cov', 'delta_cov', 'alpha_22_improved', 'alpha_003', 
'alpha_007', 'alpha_013', 'vol_break', 'weight_roc5', 'price_cost_divergence', 'smallcap_concentration', 'cost_stability', 'high_cost_break_days', 'liquidity_risk', 'turnover_std', 'mv_volatility', 'volume_growth', 'mv_growth', 'momentum_factor', 'resonance_factor', 'log_close', 'cat_vol_spike', 'up', 'down', 'obv_maobv_6', 'std_return_5_over_std_return_90', 'std_return_90_minus_std_return_90_2', 'cat_af2', 'cat_af3', 'cat_af4', 'act_factor5', 'act_factor6', 'active_buy_volume_large', 'active_buy_volume_big', 'active_buy_volume_small', 'buy_lg_vol_minus_sell_lg_vol', 'buy_elg_vol_minus_sell_elg_vol', 'ctrl_strength', 'low_cost_dev', 'asymmetry', 'lock_factor', 'cat_vol_break', 'cost_atr_adj', 'cat_golden_resonance', 'mv_turnover_ratio', 'mv_adjusted_volume', 'mv_weighted_turnover', 'nonlinear_mv_volume', 'mv_volume_ratio', 'mv_momentum', 'senti_strong_inflow', 'lg_flow_mom_corr_20_60', 'lg_flow_accel', 'profit_pressure', 'underwater_resistance', 'cost_conc_std_20', 'profit_decay_20', 'vol_amp_loss_20', 'vol_drop_profit_cnt_5', 'lg_flow_vol_interact_20', 'cost_break_confirm_cnt_5', 'atr_norm_channel_pos_14', 'turnover_diff_skew_20', 'lg_sm_flow_diverge_20', 'pullback_strong_20_20', 'vol_wgt_hist_pos_20', 'vol_adj_roc_20', 'cs_rank_net_lg_flow_val', 'cs_rank_flow_divergence', 'cs_rank_ind_adj_lg_flow', 'cs_rank_elg_buy_ratio', 'cs_rank_rel_profit_margin', 'cs_rank_cost_breadth', 'cs_rank_dist_to_upper_cost', 'cs_rank_winner_rate', 'cs_rank_intraday_range', 'cs_rank_close_pos_in_range', 'cs_rank_opening_gap', 'cs_rank_pos_in_hist_range', 'cs_rank_vol_x_profit_margin', 'cs_rank_lg_flow_price_concordance', 'cs_rank_turnover_per_winner', 'cs_rank_ind_cap_neutral_pe', 'cs_rank_volume_ratio', 'cs_rank_elg_buy_sell_sm_ratio', 'cs_rank_cost_dist_vol_ratio', 'cs_rank_size']\n"
|
|||
|
|
]
|
|||
|
|
}
|
|||
|
|
],
|
|||
|
|
"source": [
|
|||
|
|
"import numpy as np\n",
|
|||
|
|
"from main.factor.factor import *\n",
|
|||
|
|
"from main.factor.money_factor import * \n",
|
|||
|
|
"\n",
|
|||
|
|
"\n",
|
|||
|
|
"def filter_data(df):\n",
"    \"\"\"Restrict df to the stock universe used by this notebook.\n",
"\n",
"    Drops rows where ``is_st`` is True, codes ending in \"BJ\", codes starting\n",
"    with \"30\", \"68\" or \"8\", and rows whose trade_date is before \"2019-01-01\".\n",
"    Removes the ``in_date`` column when present and returns a new frame with a\n",
"    fresh RangeIndex.\n",
"    \"\"\"\n",
"    codes = df[\"ts_code\"]\n",
"    keep = (\n",
"        ~df[\"is_st\"]\n",
"        & ~codes.str.endswith(\"BJ\")\n",
"        & ~codes.str.startswith(\"30\")\n",
"        & ~codes.str.startswith(\"68\")\n",
"        & ~codes.str.startswith(\"8\")\n",
"        & (df[\"trade_date\"] >= \"2019-01-01\")\n",
"    )\n",
"    out = df.loc[keep]\n",
"    if \"in_date\" in out.columns:\n",
"        out = out.drop(columns=[\"in_date\"])\n",
"    return out.reset_index(drop=True)\n",
|
|||
|
|
"\n",
|
|||
|
|
"\n",
|
|||
|
|
"gc.collect()\n",
"\n",
"# Universe filter, then order rows by stock and date.\n",
"df = filter_data(df).sort_values(by=[\"ts_code\", \"trade_date\"])\n",
"\n",
"# Disabled factors, kept for reference:\n",
"#   price_minus_deduction_price(df, n=120)\n",
"#   price_deduction_price_diff_ratio_to_sma(df, n=120)\n",
"#   cat_price_vs_sma_vs_deduction_price(df, n=120)\n",
"#   cat_reason(df, top_list_df), cat_is_on_top_list(df, top_list_df)\n",
"#   cs_turnover_rate_relative_strength_20(df), ts_turnover_rate_trend_strength_5(df)\n",
"#   df.merge(index_data, on='trade_date', how='left')\n",
"\n",
"# --- Factors whose result is reassigned to df ---\n",
"df = holder_trade_factors(df, stk_holdertrade_df)\n",
"df = cat_senti_mom_vol_spike(\n",
"    df,\n",
"    return_period=3,\n",
"    return_threshold=0.03,  # 3-day gain above 3%\n",
"    volume_ratio_threshold=1.3,\n",
"    current_pct_chg_min=0.0,  # the day itself must close up\n",
"    current_pct_chg_max=0.05,  # ...but not by too much\n",
")\n",
"df = cat_senti_pre_breakout(\n",
"    df,\n",
"    atr_short_N=10,\n",
"    atr_long_M=40,\n",
"    vol_atrophy_N=10,\n",
"    vol_atrophy_M=40,\n",
"    price_stab_N=5,\n",
"    price_stab_threshold=0.06,\n",
"    current_pct_chg_min_signal=0.002,\n",
"    current_pct_chg_max_signal=0.05,\n",
"    volume_ratio_signal_threshold=1.1,\n",
")\n",
"for _frame_fn in (\n",
"    ts_turnover_rate_acceleration_5_20,\n",
"    ts_vol_sustain_10_30,\n",
"    cs_amount_outlier_10,\n",
"    ts_ff_to_total_turnover_ratio,\n",
"    ts_price_volume_trend_coherence_5_20,\n",
"    ts_ff_turnover_rate_surge_10,\n",
"):\n",
"    df = _frame_fn(df)\n",
"\n",
"for _value_col in (\"undist_profit_ps\", \"ocfps\"):\n",
"    df = add_financial_factor(df, fina_indicator_df, factor_value_col=_value_col)\n",
"calculate_arbr(df, N=26)  # return value unused\n",
"df[\"log_circ_mv\"] = np.log(df[\"circ_mv\"])\n",
"df = calculate_cashflow_to_ev_factor(df, cashflow_df, balancesheet_df)\n",
"df = caculate_book_to_price_ratio(df, fina_indicator_df)\n",
"df = turnover_rate_n(df, n=5)\n",
"df = variance_n(df, n=20)\n",
"df = bbi_ratio_factor(df)\n",
"df = daily_deviation(df)\n",
"df = daily_industry_deviation(df)\n",
"df, _ = get_rolling_factor(df)\n",
"df, _ = get_simple_factor(df)\n",
"df = calculate_strong_inflow_signal(df)\n",
"\n",
"# Rename industry codes to their cat_ names (cs_rank_industry_adj_lg_flow reads cat_l2_code).\n",
"df = df.rename(columns={\"l1_code\": \"cat_l1_code\", \"l2_code\": \"cat_l2_code\"})\n",
"\n",
"# --- Factors called for their effect on df (return values are not used) ---\n",
"for _inplace_fn, _kwargs in (\n",
"    (lg_flow_mom_corr, {\"N\": 20, \"M\": 60}),\n",
"    (lg_flow_accel, {}),\n",
"    (profit_pressure, {}),\n",
"    (underwater_resistance, {}),\n",
"    (cost_conc_std, {\"N\": 20}),\n",
"    (profit_decay, {\"N\": 20}),\n",
"    (vol_amp_loss, {\"N\": 20}),\n",
"    (vol_drop_profit_cnt, {\"N\": 20, \"M\": 5}),\n",
"    (lg_flow_vol_interact, {\"N\": 20}),\n",
"    (cost_break_confirm_cnt, {\"M\": 5}),\n",
"    (atr_norm_channel_pos, {\"N\": 14}),\n",
"    (turnover_diff_skew, {\"N\": 20}),\n",
"    (lg_sm_flow_diverge, {\"N\": 20}),\n",
"    (pullback_strong, {\"N\": 20, \"M\": 20}),\n",
"    (vol_wgt_hist_pos, {\"N\": 20}),\n",
"    (vol_adj_roc, {\"N\": 20}),\n",
"):\n",
"    _inplace_fn(df, **_kwargs)\n",
"\n",
"# Cross-sectional rank factors, also called for their effect on df.\n",
"for _rank_fn in (\n",
"    cs_rank_net_lg_flow_val,\n",
"    cs_rank_flow_divergence,\n",
"    cs_rank_industry_adj_lg_flow,  # needs cat_l2_code\n",
"    cs_rank_elg_buy_ratio,\n",
"    cs_rank_rel_profit_margin,\n",
"    cs_rank_cost_breadth,\n",
"    cs_rank_dist_to_upper_cost,\n",
"    cs_rank_winner_rate,\n",
"    cs_rank_intraday_range,\n",
"    cs_rank_close_pos_in_range,\n",
"    cs_rank_opening_gap,  # needs pre_close\n",
"    cs_rank_pos_in_hist_range,  # needs his_low, his_high\n",
"    cs_rank_vol_x_profit_margin,\n",
"    cs_rank_lg_flow_price_concordance,\n",
"    cs_rank_turnover_per_winner,\n",
"    cs_rank_ind_cap_neutral_pe,  # placeholder - needs external libraries\n",
"    cs_rank_volume_ratio,  # needs volume_ratio\n",
"    cs_rank_elg_buy_sell_sm_ratio,\n",
"    cs_rank_cost_dist_vol_ratio,  # needs volume_ratio\n",
"    cs_rank_size,  # needs circ_mv\n",
"):\n",
"    _rank_fn(df)\n",
"# Drop the loop variables so the notebook namespace matches the unrolled version.\n",
"del _frame_fn, _value_col, _inplace_fn, _kwargs, _rank_fn\n",
"\n",
"print(df.info())\n",
"print(df.columns.tolist())"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"execution_count": 10,
|
|||
|
|
"id": "b87b938028afa206",
|
|||
|
|
"metadata": {
|
|||
|
|
"ExecuteTime": {
|
|||
|
|
"end_time": "2025-04-03T13:08:03.658725Z",
|
|||
|
|
"start_time": "2025-04-03T13:08:02.469611Z"
|
|||
|
|
}
|
|||
|
|
},
|
|||
|
|
"outputs": [],
|
|||
|
|
"source": [
|
|||
|
|
"from scipy.stats import ks_2samp, wasserstein_distance\n",
|
|||
|
|
"\n",
|
|||
|
|
"\n",
|
|||
|
|
"def remove_shifted_features(train_data, test_data, feature_columns, ks_threshold=0.05, wasserstein_threshold=0.1,\n",
"                            importance_threshold=0.05):\n",
"    \"\"\"Detect numeric features whose train/test distributions differ and drop them.\n",
"\n",
"    A feature is flagged when the two-sample KS test rejects equality\n",
"    (p-value < ks_threshold) or the Wasserstein distance exceeds\n",
"    wasserstein_threshold. NaNs are removed per feature before testing: with\n",
"    NaNs present both statistics come back NaN and the feature would never be\n",
"    flagged. Features with no observations on either side are skipped.\n",
"\n",
"    Args:\n",
"        train_data, test_data: DataFrames containing the feature columns.\n",
"        feature_columns: candidate feature names.\n",
"        ks_threshold: KS-test p-value cut-off.\n",
"        wasserstein_threshold: Wasserstein-distance cut-off, in the feature's own units.\n",
"        importance_threshold: unused; kept for backward compatibility.\n",
"\n",
"    Returns:\n",
"        (filtered_features, dropped_features). filtered_features keeps the order\n",
"        of feature_columns.\n",
"    \"\"\"\n",
"    dropped_features = []\n",
"\n",
"    # Statistical drift check. \"number\" covers every numeric dtype (e.g. float32),\n",
"    # not only float64/int64.\n",
"    numeric_columns = train_data.select_dtypes(include=\"number\").columns\n",
"    numeric_columns = [col for col in numeric_columns if col in feature_columns]\n",
"    for feature in numeric_columns:\n",
"        train_values = train_data[feature].dropna()\n",
"        test_values = test_data[feature].dropna()\n",
"        if train_values.empty or test_values.empty:\n",
"            continue\n",
"        ks_stat, p_value = ks_2samp(train_values, test_values)\n",
"        wasserstein_dist = wasserstein_distance(train_values, test_values)\n",
"\n",
"        if p_value < ks_threshold or wasserstein_dist > wasserstein_threshold:\n",
"            dropped_features.append(feature)\n",
"\n",
"    print(f\"检测到 {len(dropped_features)} 个可能漂移的特征: {dropped_features}\")\n",
"\n",
"    # Final selection (set lookup instead of a list scan per feature).\n",
"    dropped_set = set(dropped_features)\n",
"    filtered_features = [f for f in feature_columns if f not in dropped_set]\n",
"\n",
"    return filtered_features, dropped_features\n",
|
|||
|
|
"\n"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"execution_count": 11,
|
|||
|
|
"id": "f4f16d63ad18d1bc",
|
|||
|
|
"metadata": {
|
|||
|
|
"ExecuteTime": {
|
|||
|
|
"end_time": "2025-04-03T13:08:03.670700Z",
|
|||
|
|
"start_time": "2025-04-03T13:08:03.665739Z"
|
|||
|
|
}
|
|||
|
|
},
|
|||
|
|
"outputs": [],
|
|||
|
|
"source": [
|
|||
|
|
"import polars as pl\n",
|
|||
|
|
"import numpy as np\n",
|
|||
|
|
"# from tqdm import tqdm # Polars 通常处理速度快,可能不需要 tqdm,但如果需要可以保留\n",
|
|||
|
|
"\n",
|
|||
|
|
"def cs_mad_filter_polars(df: pl.DataFrame,\n",
"                         features: list[str],\n",
"                         k: float = 3.0,\n",
"                         scale_factor: float = 1.4826) -> pl.DataFrame:\n",
"    \"\"\"\n",
"    Cross-sectional MAD winsorization of feature columns (Polars version).\n",
"\n",
"    For each trade_date, compute the median and the MAD (median absolute\n",
"    deviation from the median), then clip values to\n",
"    [median - k * scale_factor * MAD, median + k * scale_factor * MAD].\n",
"    scale_factor=1.4826 makes the MAD comparable to a normal standard deviation.\n",
"    Dates where the MAD is null or 0 (all values equal) are left unchanged.\n",
"\n",
"    Args:\n",
"        df (pl.DataFrame): input frame with 'trade_date' and the feature columns.\n",
"        features (list): columns to process. Missing or non-numeric columns are\n",
"            skipped with a warning.\n",
"        k (float): MAD multiple defining the bounds. Default 3.0.\n",
"        scale_factor (float): MAD scale factor. Default 1.4826.\n",
"\n",
"    Returns:\n",
"        pl.DataFrame: a new DataFrame; the input is not modified.\n",
"    \"\"\"\n",
"    print(f\"开始截面 MAD 去极值处理 (k={k})...\")\n",
"\n",
"    existing_features = [col for col in features if col in df.columns]\n",
"    missing_features = [col for col in features if col not in df.columns]\n",
"\n",
"    if missing_features:\n",
"        print(f\"警告: DataFrame 中缺少以下特征列: {missing_features}。这些列将跳过去极值处理。\")\n",
"\n",
"    # pl.Expr carries no dtype, so read it from the frame schema. A non-numeric\n",
"    # column would otherwise make with_columns fail for every column at once.\n",
"    # Expressions are evaluated lazily, so a try/except around building them\n",
"    # never catches anything.\n",
"    schema = df.schema\n",
"    numeric_features = []\n",
"    for col in existing_features:\n",
"        if schema[col].is_numeric():\n",
"            numeric_features.append(col)\n",
"        else:\n",
"            print(f\"警告: 列 '{col}' 不是数值类型 ({schema[col]}),跳过此列的 MAD 处理。\")\n",
"\n",
"    if not numeric_features:\n",
"        print(\"没有找到需要处理的特征列。跳过去极值处理。\")\n",
"        return df  # the original DataFrame\n",
"\n",
"    expressions = []\n",
"    for col in numeric_features:\n",
"        col_expr = pl.col(col)\n",
"        median_val = col_expr.median().over('trade_date')\n",
"        mad_val = (col_expr - median_val).abs().median().over('trade_date')\n",
"        half_width = k * scale_factor * mad_val\n",
"        expressions.append(\n",
"            # A null median implies a null MAD, so checking the MAD covers both.\n",
"            # MAD == 0 means all values are equal: keep them rather than clip.\n",
"            pl.when(mad_val.is_not_null() & (mad_val != 0))\n",
"            .then(col_expr.clip(median_val - half_width, median_val + half_width))\n",
"            .otherwise(col_expr)\n",
"            .alias(col)  # keep the column name\n",
"        )\n",
"\n",
"    result_df = df.with_columns(expressions)\n",
"\n",
"    print(\"截面 MAD 去极值处理完成。\")\n",
"    return result_df\n",
|
|||
|
|
"\n",
|
|||
|
|
"\n",
|
|||
|
|
"# NOTE: cs_zscore_standardize_polars is defined again later in this cell; the\n",
"# later definition is the one in effect once the cell has run.\n",
"def cs_zscore_standardize_polars(df: pl.DataFrame, features: list[str], epsilon: float = 1e-10) -> pl.DataFrame:\n",
"    \"\"\"\n",
"    Cross-sectional Z-score per trade_date: (x - mean) / (std + epsilon).\n",
"\n",
"    Dates whose std is null or 0 map to 0 instead of Inf/NaN. Missing and\n",
"    non-numeric columns are skipped with a warning. Returns a new DataFrame;\n",
"    the input is not modified.\n",
"    \"\"\"\n",
"    print(\"开始截面 Z-Score 标准化...\")\n",
"\n",
"    existing_features = [col for col in features if col in df.columns]\n",
"    missing_features = [col for col in features if col not in df.columns]\n",
"\n",
"    if missing_features:\n",
"        print(f\"警告: DataFrame 中缺少以下特征列: {missing_features}。这些列将跳过标准化处理。\")\n",
"\n",
"    if not existing_features:\n",
"        print(\"没有找到需要标准化的特征列。跳过标准化处理。\")\n",
"        return df\n",
"\n",
"    schema = df.schema  # pl.Expr has no .dtype; the schema does\n",
"    expressions = []\n",
"    for col in existing_features:\n",
"        if not schema[col].is_numeric():\n",
"            print(f\"警告: 列 '{col}' 不是数值类型 ({schema[col]}),跳过此列的标准化处理。\")\n",
"            continue  # leave the column untouched\n",
"\n",
"        col_expr = pl.col(col)\n",
"        mean_val = col_expr.mean().over('trade_date')\n",
"        std_val = col_expr.std().over('trade_date')\n",
"\n",
"        # .otherwise(...) must be chained here: without it every date with a\n",
"        # usable std would produce null for every value.\n",
"        z_score_expr = (\n",
"            pl.when(std_val.is_null() | (std_val == 0))\n",
"            .then(0)\n",
"            .otherwise((col_expr - mean_val) / (std_val + epsilon))\n",
"            .alias(col)\n",
"        )\n",
"        expressions.append(z_score_expr)\n",
"\n",
"    result_df = df.with_columns(expressions)\n",
"    print(\"截面 Z-Score 标准化完成。\")\n",
"    return result_df\n",
|
|||
|
|
"\n",
|
|||
|
|
"\n",
|
|||
|
|
"# --- 2. 行业市值中性化 ---\n",
|
|||
|
|
"def cs_neutralize_industry_cap(df: pd.DataFrame,\n",
"                               features: list,\n",
"                               industry_col: str = 'cat_l2_code',\n",
"                               market_cap_col: str = 'circ_mv'):\n",
"    \"\"\"\n",
"    Cross-sectional industry + log-market-cap neutralization (modifies df in place).\n",
"\n",
"    For each trade_date and each feature, fit the OLS regression\n",
"        feature ~ 1 + log1p(market_cap) + industry dummies (drop_first=True)\n",
"    and write the residuals back into the feature column. Values that cannot be\n",
"    neutralized (too few valid rows, regression failure, missing regressors)\n",
"    are set to NaN.\n",
"\n",
"    Args:\n",
"        df (pd.DataFrame): must contain 'trade_date', the features,\n",
"            industry_col and market_cap_col.\n",
"        features (list): feature columns to neutralize.\n",
"        industry_col (str): industry classification column.\n",
"        market_cap_col (str): market-cap column; log1p is applied here.\n",
"\n",
"    Returns:\n",
"        None. The feature columns of df are overwritten in place.\n",
"\n",
"    Requires statsmodels. Runs one regression per (date, feature), so it is slow.\n",
"    \"\"\"\n",
"    import statsmodels.api as sm  # hard dependency; imported here so the cell runs standalone\n",
"    try:\n",
"        from tqdm.auto import tqdm\n",
"    except ImportError:  # the progress bar is optional\n",
"        def tqdm(iterable, **_kwargs):\n",
"            return iterable\n",
"\n",
"    print(\"开始截面行业市值中性化...\")\n",
"    required_cols = features + ['trade_date', industry_col, market_cap_col]\n",
"    if not all(col in df.columns for col in required_cols):\n",
"        missing = [col for col in required_cols if col not in df.columns]\n",
"        print(f\"错误: DataFrame 中缺少必需列: {missing}。无法进行中性化。\")\n",
"        return\n",
"\n",
"    log_cap_col = '_log_market_cap'\n",
"    df[log_cap_col] = np.log1p(df[market_cap_col])  # log1p tolerates a market cap of 0\n",
"\n",
"    try:\n",
"        all_residuals = []  # per-date residual frames\n",
"        # One groupby pass instead of a boolean scan of the whole frame per date.\n",
"        grouped = df[features + [log_cap_col, industry_col, 'trade_date']].groupby('trade_date', sort=False)\n",
"        for date, daily_data in tqdm(grouped, total=grouped.ngroups, desc=\"Neutralizing\"):\n",
"            # Regressors: constant + log market cap + industry dummies.\n",
"            X = sm.add_constant(daily_data[[log_cap_col]], prepend=True)\n",
"            industry_dummies = pd.get_dummies(daily_data[industry_col], prefix=industry_col, drop_first=True)\n",
"            X = pd.concat([X, industry_dummies.astype(int)], axis=1)\n",
"            x_is_valid = X.notna().all(axis=1)\n",
"\n",
"            daily_residuals = daily_data[features].copy()\n",
"            for col in features:\n",
"                Y = daily_data[col]\n",
"                valid_mask = Y.notna() & x_is_valid\n",
"                if valid_mask.sum() < (X.shape[1] + 1):  # not enough points to fit the model\n",
"                    print(f\"警告: 日期 {date}, 特征 {col} 有效数据不足 ({valid_mask.sum()}个),无法中性化,填充 NaN。\")\n",
"                    daily_residuals[col] = np.nan\n",
"                    continue\n",
"\n",
"                try:\n",
"                    results = sm.OLS(Y[valid_mask].to_numpy(), X[valid_mask].to_numpy()).fit()\n",
"                    daily_residuals.loc[valid_mask, col] = results.resid\n",
"                    daily_residuals.loc[~valid_mask, col] = np.nan\n",
"                except Exception as e:\n",
"                    print(f\"警告: 日期 {date}, 特征 {col} 回归失败: {e},填充 NaN。\")\n",
"                    daily_residuals[col] = np.nan\n",
"                    # Go on to the next feature: breaking out here would leave the\n",
"                    # remaining features of this date un-neutralized.\n",
"\n",
"            all_residuals.append(daily_residuals)\n",
"\n",
"        if all_residuals:\n",
"            residuals_df = pd.concat(all_residuals)\n",
"            # Assign via .loc rather than DataFrame.update: update() skips NaN, so\n",
"            # features that could not be neutralized would keep their raw values.\n",
"            df.loc[residuals_df.index, features] = residuals_df[features]\n",
"        else:\n",
"            print(\"没有有效的残差结果可以合并。\")\n",
"    finally:\n",
"        df.drop(columns=[log_cap_col], inplace=True)  # remove the helper column even on error/interrupt\n",
"    print(\"截面行业市值中性化完成。\")\n",
|
|||
|
|
"\n",
|
|||
|
|
"\n",
|
|||
|
|
"# --- 3. Z-Score 标准化 ---\n",
|
|||
|
|
"\n",
|
|||
|
|
"import polars as pl\n",
|
|||
|
|
"import numpy as np # 用于 np.float64, np.nan 等\n",
|
|||
|
|
"\n",
|
|||
|
|
"def cs_zscore_standardize_polars(df: pl.DataFrame,\n",
"                                 features: list,\n",
"                                 epsilon: float = 1e-10) -> pl.DataFrame:\n",
"    \"\"\"\n",
"    Cross-sectional Z-score standardization of feature columns (Polars version).\n",
"\n",
"    Method: Z = (value - cross_sectional_mean) / (cross_sectional_std + epsilon),\n",
"    with mean/std taken per trade_date. Dates whose std is null or 0 map to 0\n",
"    so that no Inf/NaN is produced.\n",
"\n",
"    Args:\n",
"        df (pl.DataFrame): input frame with 'trade_date' and the feature columns.\n",
"        features (list): columns to standardize. Missing or non-numeric columns\n",
"            are skipped with a warning.\n",
"        epsilon (float): small constant guarding the division.\n",
"\n",
"    Returns:\n",
"        pl.DataFrame: a new DataFrame; the input is not modified.\n",
"    \"\"\"\n",
"    print(\"开始截面 Z-Score 标准化...\")\n",
"\n",
"    existing_features = [col for col in features if col in df.columns]\n",
"    missing_features = [col for col in features if col not in df.columns]\n",
"\n",
"    if missing_features:\n",
"        print(f\"警告: DataFrame 中缺少以下特征列: {missing_features}。这些列将跳过标准化处理。\")\n",
"\n",
"    if not existing_features:\n",
"        print(\"没有找到需要标准化的特征列。跳过标准化处理。\")\n",
"        return df  # the original DataFrame\n",
"\n",
"    # pl.Expr has no .dtype attribute (accessing it raised AttributeError), so\n",
"    # take column dtypes from the frame schema.\n",
"    schema = df.schema\n",
"    expressions = []\n",
"    for col in existing_features:\n",
"        if not schema[col].is_numeric():\n",
"            print(f\"警告: 列 '{col}' 不是数值类型 ({schema[col]}),跳过此列的标准化处理。\")\n",
"            continue  # leave the column untouched\n",
"\n",
"        col_expr = pl.col(col)\n",
"        mean_val = col_expr.mean().over('trade_date')\n",
"        std_val = col_expr.std().over('trade_date')\n",
"\n",
"        # A std of 0 or null (e.g. a date with identical or too few values) maps to 0.\n",
"        z_score_expr = (\n",
"            pl.when(std_val.is_null() | (std_val == 0))\n",
"            .then(0)\n",
"            .otherwise((col_expr - mean_val) / (std_val + epsilon))\n",
"            .alias(col)\n",
"        )\n",
"        expressions.append(z_score_expr)\n",
"\n",
"    result_df = df.with_columns(expressions)\n",
"\n",
"    print(\"截面 Z-Score 标准化完成。\")\n",
"    return result_df\n",
|
|||
|
|
"\n",
|
|||
|
|
"\n",
|
|||
|
|
"def fill_nan_with_daily_median(df: pd.DataFrame, feature_columns: list[str]) -> pd.DataFrame:\n",
"    \"\"\"\n",
"    Fill NaNs in feature_columns with the same trade_date's cross-sectional median.\n",
"\n",
"    Args:\n",
"        df (pd.DataFrame): frame with 'trade_date' and the feature columns; not modified.\n",
"        feature_columns (list[str]): columns to fill; names absent from df are skipped.\n",
"\n",
"    Returns:\n",
"        pd.DataFrame: a copy with trade_date converted to datetime, NaNs filled,\n",
"        and rows ordered by trade_date (stable; the original index is kept).\n",
"        A value stays NaN when its whole day is NaN for that feature or its\n",
"        trade_date is missing.\n",
"    \"\"\"\n",
"    processed_df = df.copy()  # work on a copy, keep the caller's data intact\n",
"    processed_df['trade_date'] = pd.to_datetime(processed_df['trade_date'])  # group by real dates\n",
"\n",
"    cols = list(dict.fromkeys(c for c in feature_columns if c in processed_df.columns))\n",
"    if cols:\n",
"        # Vectorized per-day median. Filling via group[col].fillna(..., inplace=True)\n",
"        # is chained assignment: it warns in pandas 2.x and has no effect under Copy-on-Write.\n",
"        daily_median = processed_df.groupby('trade_date')[cols].transform('median')\n",
"        processed_df[cols] = processed_df[cols].fillna(daily_median)\n",
"\n",
"    # Same row order as the groupby('trade_date').apply(...) result this replaces.\n",
"    return processed_df.sort_values('trade_date', kind='stable')"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"execution_count": 12,
|
|||
|
|
"id": "40e6b68a91b30c79",
|
|||
|
|
"metadata": {
|
|||
|
|
"ExecuteTime": {
|
|||
|
|
"end_time": "2025-04-03T13:08:04.694262Z",
|
|||
|
|
"start_time": "2025-04-03T13:08:03.694904Z"
|
|||
|
|
}
|
|||
|
|
},
|
|||
|
|
"outputs": [],
|
|||
|
|
"source": [
|
|||
|
|
"import pandas as pd\n",
|
|||
|
|
"\n",
|
|||
|
|
"\n",
|
|||
|
|
"def remove_outliers_label_percentile(label: pd.Series, lower_percentile: float = 0.01, upper_percentile: float = 0.99,\n",
"                                     log=True):\n",
"    \"\"\"Keep only label values inside the [lower_percentile, upper_percentile] quantile range.\n",
"\n",
"    Bounds are inclusive. NaN labels fail the range test, so they are dropped\n",
"    and counted as removed. Prints the removed count when `log` is true.\n",
"\n",
"    Raises:\n",
"        ValueError: unless 0 <= lower_percentile < upper_percentile <= 1.\n",
"    \"\"\"\n",
"    percentiles_ok = 0 <= lower_percentile < upper_percentile <= 1\n",
"    if not percentiles_ok:\n",
"        raise ValueError(\"Percentile values must satisfy 0 <= lower_percentile < upper_percentile <= 1.\")\n",
"\n",
"    low = label.quantile(lower_percentile)\n",
"    high = label.quantile(upper_percentile)\n",
"    kept = label[label.between(low, high)]  # inclusive on both ends\n",
"\n",
"    if log:\n",
"        print(f\"Removed {len(label) - len(kept)} outliers.\")\n",
"    return kept\n",
|
|||
|
|
"\n",
|
|||
|
|
"\n",
|
|||
|
|
"def calculate_risk_adjusted_target(df, days=5):\n",
"    \"\"\"Risk-adjusted forward return per row (the caller's df is not modified).\n",
"\n",
"    Per ts_code, with rows ordered by trade_date:\n",
"      future_return     = (close `days` rows ahead - next row's open) / next row's open\n",
"      future_volatility = rolling std of future_return (window=days, min_periods=1)\n",
"    The target is future_return / future_volatility. +/-inf (zero volatility)\n",
"    becomes NaN, and rows with undefined volatility are NaN as well.\n",
"\n",
"    Returns:\n",
"        Series aligned to df.index.\n",
"    \"\"\"\n",
"    df = df.sort_values(by=['ts_code', 'trade_date'])  # new frame; caller's df untouched\n",
"\n",
"    df['future_close'] = df.groupby('ts_code')['close'].shift(-days)\n",
"    df['future_open'] = df.groupby('ts_code')['open'].shift(-1)\n",
"    df['future_return'] = (df['future_close'] - df['future_open']) / df['future_open']\n",
"\n",
"    df['future_volatility'] = df.groupby('ts_code')['future_return'].rolling(days, min_periods=1).std().reset_index(\n",
"        level=0, drop=True)\n",
"    # Sharpe-style adjustment divides by volatility; multiplying would reward\n",
"    # volatile names instead of penalizing them.\n",
"    sharpe_ratio = df['future_return'] / df['future_volatility']\n",
"    return sharpe_ratio.replace([np.inf, -np.inf], np.nan)\n",
|
|||
|
|
"\n",
|
|||
|
|
"\n",
|
|||
|
|
"def calculate_score(df, days=5, lambda_param=1.0):\n",
"    \"\"\"Score = future_return - lambda_param * (max drawdown of close over the next `days` rows).\n",
"\n",
"    Requires 'ts_code', 'trade_date', 'close' and 'future_return'. Within each\n",
"    ts_code (ordered by trade_date), the drawdown for row t uses closes\n",
"    t+1 .. t+days; rows without a full future window get NaN.\n",
"\n",
"    Side effect: writes the result into df['score'].\n",
"    NOTE: this shadows the calculate_score imported from main.utils.factor_processor.\n",
"\n",
"    Returns:\n",
"        The 'score' Series, aligned to df.index.\n",
"    \"\"\"\n",
"    def calculate_max_drawdown(prices):\n",
"        # Largest decline relative to the running peak (0 when prices never fall).\n",
"        running_peak = np.maximum.accumulate(prices)\n",
"        return float(np.max((running_peak - prices) / running_peak))\n",
"\n",
"    def compute_stock_score(stock_df):\n",
"        stock_df = stock_df.sort_values(by=['trade_date'])\n",
"        # rolling(days) ending at t+days covers t+1..t+days; shift(-days) moves it to t.\n",
"        # raw=True hands ndarrays to the callback, which is much faster than Series.\n",
"        max_drawdown = stock_df['close'].rolling(days).apply(calculate_max_drawdown, raw=True).shift(-days)\n",
"        return stock_df['future_return'] - lambda_param * max_drawdown\n",
"\n",
"    # Concatenate per-stock results explicitly: groupby.apply can return a\n",
"    # DataFrame instead of a Series (e.g. with a single ts_code), which breaks\n",
"    # the index-aligned assignment below.\n",
"    per_stock = [compute_stock_score(group) for _, group in df.groupby('ts_code')]\n",
"    df['score'] = pd.concat(per_stock) if per_stock else pd.Series(np.nan, index=df.index)\n",
"\n",
"    return df['score']\n",
|
|||
|
|
"\n",
|
|||
|
|
"\n",
|
|||
|
|
"def remove_highly_correlated_features(df, feature_columns, threshold=0.9):\n",
"    \"\"\"Drop numeric features that are highly correlated with an earlier one.\n",
"\n",
"    For each pair of numeric features with |corr| > threshold, the one that\n",
"    appears later in the correlation matrix is dropped. Names containing 'act'\n",
"    or 'af' are always kept, and non-numeric names pass through unchanged.\n",
"\n",
"    Raises:\n",
"        ValueError: if none of feature_columns is numeric.\n",
"    \"\"\"\n",
"    numeric_features = df[feature_columns].select_dtypes(include=[np.number]).columns.tolist()\n",
"    if len(numeric_features) == 0:\n",
"        raise ValueError(\"No numeric features found in the provided data.\")\n",
"\n",
"    abs_corr = df[numeric_features].corr().abs()\n",
"    strictly_upper = np.triu(np.ones(abs_corr.shape), k=1).astype(bool)\n",
"    upper_triangle = abs_corr.where(strictly_upper)\n",
"    to_drop = {name for name in upper_triangle.columns if any(upper_triangle[name] > threshold)}\n",
"\n",
"    def _is_kept(name):\n",
"        return name not in to_drop or 'act' in name or 'af' in name\n",
"\n",
"    return [name for name in feature_columns if _is_kept(name)]\n",
|
|||
|
|
"\n",
|
|||
|
|
"\n",
|
|||
|
|
"def cross_sectional_standardization(df, features):\n",
"    \"\"\"Z-score `features` within each trade_date (population std, ddof=0).\n",
"\n",
"    Mirrors fitting sklearn's StandardScaler separately on every date: NaNs are\n",
"    ignored by the statistics and stay NaN, and zero-variance features are only\n",
"    centered (scale 1), giving 0. Returns a copy sorted by trade_date; df itself\n",
"    is not modified.\n",
"    \"\"\"\n",
"    df_standardized = df.sort_values(by='trade_date').copy()  # time-ordered copy, as before\n",
"\n",
"    # Vectorized per-date statistics: needs no StandardScaler import and avoids\n",
"    # a full-frame boolean scan for every date.\n",
"    by_date = df_standardized.groupby('trade_date')[features]\n",
"    means = by_date.transform('mean')\n",
"    stds = by_date.transform('std', ddof=0)\n",
"    stds = stds.mask(stds == 0, 1.0)  # constant features: scale 1, like StandardScaler\n",
"\n",
"    df_standardized[features] = (df_standardized[features] - means) / stds\n",
"    return df_standardized\n",
|
|||
|
|
"\n",
|
|||
|
|
"\n",
|
|||
|
|
"import numpy as np\n",
|
|||
|
|
"import pandas as pd\n",
|
|||
|
|
"\n",
|
|||
|
|
"\n",
|
|||
|
|
"def neutralize_manual_revised(df: pd.DataFrame, features: list, industry_col: str, mkt_cap_col: str) -> pd.DataFrame:\n",
"    \"\"\"\n",
"    Neutralize features against market cap within each industry, using a\n",
"    hand-rolled univariate OLS for speed.\n",
"\n",
"    For every industry group (pooled over all rows; there is no per-date split),\n",
"    the residual of feature ~ alpha + beta * mkt_cap replaces the feature value.\n",
"    Groups with a single row or (near-)zero market-cap variance keep their raw\n",
"    values. mkt_cap_col is used as given: pass an already log-transformed column\n",
"    if log-cap neutralization is intended.\n",
"\n",
"    NOTE: df itself is modified (mkt_cap_col is coerced to numeric and the\n",
"    feature columns are overwritten) and is also returned.\n",
"\n",
"    Args:\n",
"        df: DataFrame with the features, industry_col and mkt_cap_col.\n",
"        features: feature columns to neutralize.\n",
"        industry_col: industry classification column.\n",
"        mkt_cap_col: market-cap regressor column.\n",
"\n",
"    Returns:\n",
"        The neutralized DataFrame. Rows whose market cap is missing or <= 0 keep\n",
"        their original feature values.\n",
"    \"\"\"\n",
"\n",
"    df[mkt_cap_col] = pd.to_numeric(df[mkt_cap_col], errors='coerce')\n",
"    df_cleaned = df.dropna(subset=[mkt_cap_col]).copy()\n",
"    df_cleaned = df_cleaned[df_cleaned[mkt_cap_col] > 0].copy()\n",
"\n",
"    if df_cleaned.empty:\n",
"        print(\"警告: 清理市值异常值后 DataFrame 为空。\")\n",
"        return df\n",
"\n",
"    processed_df = df  # modified in place\n",
"\n",
"    for col in features:\n",
"        if col not in df_cleaned.columns:\n",
"            print(f\"警告: 特征列 '{col}' 不存在于清理后的 DataFrame 中,已跳过。\")\n",
"            processed_df[col] = df[col] if col in df.columns else np.nan\n",
"            continue\n",
"\n",
"        # Never neutralize the control variables themselves.\n",
"        if col == mkt_cap_col or col == industry_col:\n",
"            print(f\"警告: 特征列 '{col}' 是控制变量或内部使用的列,跳过中性化。\")\n",
"            processed_df[col] = df[col] if col in df.columns else np.nan\n",
"            continue\n",
"\n",
"        # Only rows with a factor value take part in the regressions.\n",
"        df_subset_factor = df_cleaned.dropna(subset=[col])\n",
"        if df_subset_factor.empty:\n",
"            processed_df[col] = np.nan\n",
"            continue\n",
"\n",
"        residual_series = pd.Series(index=df_cleaned.index, dtype=float)\n",
"        for industry, group in df_subset_factor.groupby(industry_col):\n",
"            x = group[mkt_cap_col]  # regressor: market cap as provided\n",
"            y = group[col]  # factor values\n",
"\n",
"            # Need > 1 point and non-negligible regressor variance for a slope.\n",
"            if len(group) > 1 and np.var(x) > 1e-9:\n",
"                try:\n",
"                    # Use population (ddof=0) moments in both terms: np.cov defaults\n",
"                    # to ddof=1 while np.var uses ddof=0, which inflated beta by n/(n-1)\n",
"                    # and left residuals correlated with market cap.\n",
"                    beta = np.cov(y, x, ddof=0)[0, 1] / np.var(x)\n",
"                    alpha = np.mean(y) - beta * np.mean(x)\n",
"                    resid = y - (alpha + beta * x)\n",
"                    residual_series.loc[resid.index] = resid\n",
"                except Exception as e:\n",
"                    # This group's residuals stay NaN.\n",
"                    print(f\"警告: 在行业 {industry} 计算回归时发生错误: {e}。该组残差将设为原始值或 NaN。\")\n",
"            else:\n",
"                residual_series.loc[group.index] = group[col]  # keep raw factor values\n",
"\n",
"        # Write back once per feature rather than once per industry.\n",
"        processed_df.loc[residual_series.index, col] = residual_series\n",
"\n",
"    return processed_df\n",
|
|||
|
|
"\n",
|
|||
|
|
"\n",
|
|||
|
|
"import gc\n",
"\n",
"gc.collect()\n",
"\n",
"\n",
"def mad_filter(df, features, n=3, scale_factor=1.0):\n",
"    \"\"\"Clip each feature to median +/- n * scale_factor * MAD over the whole frame (in place).\n",
"\n",
"    This is a pooled filter: statistics are not computed per trade_date. NaNs\n",
"    are ignored when estimating the median/MAD and remain NaN. Columns whose\n",
"    MAD is 0 or NaN are left unchanged rather than collapsed onto the median,\n",
"    matching cs_mad_filter_polars.\n",
"\n",
"    Args:\n",
"        df: DataFrame; modified in place and returned.\n",
"        features: columns to clip.\n",
"        n: MAD multiple for the bounds.\n",
"        scale_factor: optional MAD scale (1.4826 ~ normal std). The default 1.0\n",
"            keeps the original bounds.\n",
"\n",
"    Returns:\n",
"        df.\n",
"    \"\"\"\n",
"    for col in features:\n",
"        median = df[col].median()\n",
"        # pandas median skips NaN; np.median over the raw column returned NaN\n",
"        # for any missing value and silently disabled clipping.\n",
"        mad = (df[col] - median).abs().median()\n",
"        if pd.isna(mad) or mad == 0:\n",
"            continue\n",
"        half_width = n * scale_factor * mad\n",
"        df[col] = df[col].clip(median - half_width, median + half_width)  # winsorize\n",
"    return df\n",
|
|||
|
|
"\n",
|
|||
|
|
"\n",
|
|||
|
|
"def percentile_filter(df, features, lower_percentile=0.01, upper_percentile=0.99):\n",
"    \"\"\"Clip each feature to its per-trade_date [lower, upper] quantile band, in place.\n",
"\n",
"    Returns df (the same object).\n",
"    \"\"\"\n",
"    for feature in features:\n",
"        per_day = df.groupby('trade_date')[feature]\n",
"        # Per-row daily quantile bounds, aligned to df.index.\n",
"        low = per_day.transform(lambda s: s.quantile(lower_percentile))\n",
"        high = per_day.transform(lambda s: s.quantile(upper_percentile))\n",
"        df[feature] = np.clip(df[feature], low, high)\n",
"    return df\n",
|
|||
|
|
"\n",
|
|||
|
|
"\n",
|
|||
|
|
"from scipy.stats import iqr\n",
|
|||
|
|
"\n",
|
|||
|
|
"\n",
|
|||
|
|
"def iqr_filter(df, features):\n",
"    \"\"\"Robust per-trade_date scaling, in place: (x - median) / IQR.\n",
"\n",
"    Despite its name this rescales rather than clips. NaNs are ignored when\n",
"    computing the IQR, so one missing value no longer turns the whole day into\n",
"    NaN. Days with IQR == 0 are left unchanged. Returns df.\n",
"    \"\"\"\n",
"    def _robust_scale(values):\n",
"        spread = iqr(values, nan_policy='omit')  # computed once per group\n",
"        if spread == 0:\n",
"            return values\n",
"        return (values - values.median()) / spread\n",
"\n",
"    for col in features:\n",
"        df[col] = df.groupby('trade_date')[col].transform(_robust_scale)\n",
"    return df\n",
|
|||
|
|
"\n",
|
|||
|
|
"\n",
|
|||
|
|
"def quantile_filter(df, features, lower_quantile=0.01, upper_quantile=0.99, window=60):\n",
"    \"\"\"Clip features to rolling-quantile bounds computed inside each trade_date group.\n",
"\n",
"    Works on a copy; df itself is not changed.\n",
"\n",
"    NOTE(review): the rolling window runs across the rows of one trade_date in\n",
"    their current order (i.e. across stocks, not across time). The first\n",
"    min(len(day), window) - 1 rows of each day get NaN bounds, which recent\n",
"    pandas clip treats as \"no bound\". Confirm this is the intended behavior.\n",
"    \"\"\"\n",
"    out = df.copy()\n",
"    for feature in features:\n",
"        grouped = out.groupby('trade_date')[feature]\n",
"        low, high = (\n",
"            grouped.transform(lambda s, q=q: s.rolling(window=min(len(s), window)).quantile(q))\n",
"            for q in (lower_quantile, upper_quantile)\n",
"        )\n",
"        out[feature] = np.clip(out[feature], low, high)\n",
"\n",
"    return out\n",
|
|||
|
|
"\n",
|
|||
|
|
"def select_top_features_by_rankic(df: pd.DataFrame, feature_columns: list, n: int, target_column: str = 'future_return') -> list:\n",
"    \"\"\"\n",
"    Rank numeric features by |RankIC| against ``target_column`` and keep the top ``n``.\n",
"\n",
"    RankIC is the Spearman correlation between a feature and ``target_column``,\n",
"    computed over all rows of ``df`` pooled together (not per trade_date), after\n",
"    dropping rows where either value is missing.\n",
"\n",
"    Args:\n",
"        df: DataFrame holding the feature columns and the target column.\n",
"        feature_columns: Candidate feature names.\n",
"        n: Number of numeric features to keep (capped at the number available;\n",
"            values <= 0 keep none of them).\n",
"        target_column: Column to correlate against. Defaults to 'future_return'.\n",
"\n",
"    Returns:\n",
"        The entries of ``feature_columns``, in their original order, that are either\n",
"        among the ``n`` numeric features with the highest |RankIC| or are not numeric\n",
"        columns of ``df`` (those are never ranked and are always kept). A numeric\n",
"        ``target_column`` listed in ``feature_columns`` is dropped, not scored.\n",
"\n",
"    Raises:\n",
"        ValueError: If ``target_column`` is not a column of ``df``.\n",
"    \"\"\"\n",
"    if target_column not in df.columns:\n",
"        raise ValueError(f\"目标列 '{target_column}' 不存在于 DataFrame 中。\")\n",
"\n",
"    # include='number' covers every numeric width (float32, int32, int8, ...).\n",
"    # Restricting to float64/int64 let e.g. float32 features bypass ranking and\n",
"    # always be returned.\n",
"    candidates = set(feature_columns)\n",
"    numeric_columns = [col for col in df.select_dtypes(include='number').columns if col in candidates]\n",
"\n",
"    rankic_scores = {}\n",
"    for feature in numeric_columns:\n",
"        if feature == target_column:\n",
"            # Correlating the target with itself is meaningless (and df[[t, t]] has\n",
"            # duplicate columns); leave it unscored so it is not returned.\n",
"            continue\n",
"        valid_data = df[[feature, target_column]].dropna()\n",
"        if len(valid_data) > 1:  # a correlation needs at least two points\n",
"            correlation = valid_data[feature].corr(valid_data[target_column], method='spearman')\n",
"            rankic_scores[feature] = abs(correlation)  # strength only, sign ignored\n",
"        else:\n",
"            rankic_scores[feature] = 0  # not enough data: treat as no signal\n",
"\n",
"    rankic_series = pd.Series(rankic_scores, dtype='float64')\n",
"    # Clamp so that n larger than the candidate count, or n <= 0, behaves sensibly\n",
"    # (a negative head() would otherwise drop from the end).\n",
"    n_actual = max(0, min(n, len(rankic_series)))\n",
"    top_features = set(rankic_series.sort_values(ascending=False).head(n_actual).index)\n",
"    numeric_set = set(numeric_columns)\n",
"    return [col for col in feature_columns if col in top_features or col not in numeric_set]\n",
|
|||
|
|
"\n",
|
|||
|
|
"def create_deviation_within_dates(df, feature_columns, groupby_col='cat_l2_code'):\n",
"    \"\"\"Add each feature's deviation from its mean within (trade_date, groupby_col).\n",
"\n",
"    For every eligible feature ``f`` a column ``deviation_mean_{f}`` equal to ``f``\n",
"    minus the mean of ``f`` over the rows sharing the same 'trade_date' and\n",
"    ``groupby_col`` value is added. Features whose name contains 'cat', 'index',\n",
"    'industry', 'limit' or 'cyq', and the grouping keys themselves, are skipped.\n",
"\n",
"    Args:\n",
"        df: pandas DataFrame containing 'trade_date', ``groupby_col`` and the\n",
"            feature columns.\n",
"        feature_columns: Candidate feature names.\n",
"        groupby_col: Second grouping key next to 'trade_date'. Defaults to\n",
"            'cat_l2_code', the previously hard-coded value.\n",
"\n",
"    Returns:\n",
"        Tuple ``(df_with_deviations, ret_feature_columns)``: a new DataFrame with the\n",
"        deviation columns appended (existing columns of the same name are replaced\n",
"        rather than duplicated), and ``feature_columns`` extended with the new\n",
"        names (names already present are not repeated).\n",
"    \"\"\"\n",
"    excluded_tokens = ('cat', 'index', 'industry', 'limit', 'cyq')\n",
"    num_features = [\n",
"        col for col in feature_columns\n",
"        if not any(token in col for token in excluded_tokens)\n",
"        and col not in ('trade_date', groupby_col)\n",
"    ]\n",
"\n",
"    grouped = df.groupby(['trade_date', groupby_col])\n",
"    new_columns = {\n",
"        f'deviation_mean_{feature}': df[feature] - grouped[feature].transform('mean')\n",
"        for feature in num_features\n",
"    }\n",
"\n",
"    ret_feature_columns = list(feature_columns)\n",
"    for name in new_columns:\n",
"        if name not in ret_feature_columns:\n",
"            ret_feature_columns.append(name)\n",
"\n",
"    # Re-running the cell on a frame that already holds these columns used to\n",
"    # create duplicate column labels; drop the old copies before appending.\n",
"    stale = [name for name in new_columns if name in df.columns]\n",
"    df = pd.concat([df.drop(columns=stale), pd.DataFrame(new_columns)], axis=1)\n",
"\n",
"    return df, ret_feature_columns\n"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"execution_count": 13,
|
|||
|
|
"id": "47c12bb34062ae7a",
|
|||
|
|
"metadata": {
|
|||
|
|
"ExecuteTime": {
|
|||
|
|
"end_time": "2025-04-03T14:57:50.841165Z",
|
|||
|
|
"start_time": "2025-04-03T14:49:25.889057Z"
|
|||
|
|
}
|
|||
|
|
},
|
|||
|
|
"outputs": [],
|
|||
|
|
"source": [
|
|||
|
|
"import pandas as pd\n",
|
|||
|
|
"import polars as pl\n",
|
|||
|
|
"import numpy as np\n",
|
|||
|
|
"import gc\n",
|
|||
|
|
"\n",
|
|||
|
|
"# --- Configuration ---\n",
|
|||
|
|
"days = 5  # look-ahead step in rows per ts_code, used by shift(-days) below\n",
|
|||
|
|
"validation_days = 120 # not used in this cell; kept for reference\n",
|
|||
|
|
"\n",
|
|||
|
|
"\n",
|
|||
|
|
"# --- Core logic: convert pandas -> polars and process (converting back is commented out at the end) ---\n",
|
|||
|
|
"\n",
|
|||
|
|
"# 1. Convert the incoming DataFrame to polars\n",
|
|||
|
|
"gc.collect() # free memory before building the polars copy\n",
|
|||
|
|
"df = pl.DataFrame(df)\n",
|
|||
|
|
"\n",
|
|||
|
|
"# 2. Polars processing\n",
|
|||
|
|
"# Sort by (ts_code, trade_date) so that shift(-k).over('ts_code') looks k rows ahead in time\n",
|
|||
|
|
"df = df.sort(['ts_code', 'trade_date'])\n",
|
|||
|
|
"\n",
|
|||
|
|
"# future_return = return over the next `days` rows + return over the next 2 * `days` rows\n",
"# (per ts_code; rows are trading records, not calendar days). It is null wherever the\n",
"# 2 * days look-ahead runs past a stock's last row.\n",
|
|||
|
|
"df = df.with_columns(\n",
|
|||
|
|
" (\n",
|
|||
|
|
" (pl.col('close').shift(-days).over('ts_code') / pl.col('close')) - 1\n",
|
|||
|
|
" ).alias('future_return_1'),\n",
|
|||
|
|
" (\n",
|
|||
|
|
" (pl.col('close').shift(-2 * days).over('ts_code') / pl.col('close')) - 1\n",
|
|||
|
|
" ).alias('future_return_2')\n",
|
|||
|
|
").with_columns(\n",
|
|||
|
|
" (pl.col('future_return_1') + pl.col('future_return_2')).alias('future_return')\n",
|
|||
|
|
").drop(['future_return_1', 'future_return_2']) # drop the intermediate columns\n",
|
|||
|
|
"\n",
|
|||
|
|
"# --- label: 50 quantile buckets of future_return, computed separately for each trade_date ---\n",
|
|||
|
|
"# allow_duplicates=True drops repeated bucket edges instead of raising on ties.\n",
|
|||
|
|
"# NOTE(review): polars qcut yields a Categorical of interval labels (not integer bucket ids) and no\n",
|
|||
|
|
"# explicit null handling is done here; a ranking objective likely needs integer labels -- TODO confirm.\n",
|
|||
|
|
"df = df.with_columns(\n",
|
|||
|
|
" pl.col(\"future_return\").qcut(50, allow_duplicates=True).over(\"trade_date\").alias(\"label\")\n",
|
|||
|
|
")\n",
|
|||
|
|
"\n",
|
|||
|
|
"# Bounds for filtering future_return: a 0.1% lower quantile and a fixed upper cap.\n",
"# NOTE(review): the quantile is taken over every date in df, including any later test period,\n",
"# which carries test-period information into a training filter -- TODO confirm.\n",
|
|||
|
|
"lower_bound_quantile = df['future_return'].quantile(0.001)\n",
|
|||
|
|
"upper_bound_fixed = 0.6 # fixed upper cap\n",
|
|||
|
|
"\n",
|
|||
|
|
"# 3. Convert back to pandas (disabled, so df remains a polars DataFrame)\n",
|
|||
|
|
"# df_final_pd = df.to_pandas()"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"execution_count": 14,
|
|||
|
|
"id": "29221dde",
|
|||
|
|
"metadata": {},
|
|||
|
|
"outputs": [
|
|||
|
|
{
|
|||
|
|
"name": "stdout",
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"text": [
|
|||
|
|
"205\n"
|
|||
|
|
]
|
|||
|
|
}
|
|||
|
|
],
|
|||
|
|
"source": [
|
|||
|
|
"# Make sure the auxiliary tables are polars frames sorted by trade_date.\n",
"industry_df = pl.DataFrame(industry_df).sort('trade_date')\n",
"index_data = pl.DataFrame(index_data).sort('trade_date')\n",
"\n",
"# Only the column names of the joined frame matter, so a 10-row sample is joined.\n",
"_joined_columns = (\n",
"    df.head(10)\n",
"    .join(industry_df, on=[\"cat_l2_code\", \"trade_date\"], how=\"left\")\n",
"    .join(index_data, on=\"trade_date\", how=\"left\")\n",
"    .columns\n",
")\n",
"\n",
"# Columns dropped by exact name: identifiers, the label and hand-picked factors.\n",
"_EXCLUDED_NAMES = {\n",
"    'trade_date', 'ts_code', 'label',\n",
"    'intraday_lg_flow_corr_20', 'cap_neutral_cost_metric', 'hurst_net_mf_vol_60',\n",
"    'complex_factor_deap_1', 'lg_buy_consolidation_20', 'cs_rank_ind_cap_neutral_pe',\n",
"    'cs_rank_opening_gap', 'cs_rank_ind_adj_lg_flow',\n",
"    'cat_reason', 'cat_is_on_top_list',\n",
"}\n",
"# Columns dropped when their name contains any of these substrings.\n",
"_EXCLUDED_SUBSTRINGS = ('future', 'label', 'score', 'gen', 'is_st', 'pe_ttm', 'code')\n",
"\n",
"\n",
"def _is_model_feature(name):\n",
"    \"\"\"Return True when a joined column should be kept as a model feature.\"\"\"\n",
"    if name in _EXCLUDED_NAMES or name.startswith('_'):\n",
"        return False\n",
"    if any(token in name for token in _EXCLUDED_SUBSTRINGS):\n",
"        return False\n",
"    return name not in origin_columns\n",
"\n",
"\n",
"feature_columns = [name for name in _joined_columns if _is_model_feature(name)]\n",
"print(len(feature_columns))"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"execution_count": 15,
|
|||
|
|
"id": "03ee5daf",
|
|||
|
|
"metadata": {},
|
|||
|
|
"outputs": [],
|
|||
|
|
"source": [
|
|||
|
|
"# Fill nulls in the numeric feature columns with 0.\n",
"# A bare `pl.col(c).fill_null(0)` only builds an expression; it must be applied\n",
"# through `with_columns`, otherwise `df` is left unchanged (the old loop was a no-op).\n",
"# Only numeric dtypes are filled so that string/categorical features are not coerced.\n",
"# NOTE: fill_null replaces nulls only; float NaN values are left as they are.\n",
"_fill_null_columns = [\n",
"    col for col, dtype in df.schema.items()\n",
"    if col in feature_columns and dtype in [pl.Float64, pl.Float32, pl.Int64, pl.Int32]\n",
"]\n",
"df = df.with_columns([pl.col(col).fill_null(0) for col in _fill_null_columns])"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"execution_count": 16,
|
|||
|
|
"id": "b76ea08a",
|
|||
|
|
"metadata": {},
|
|||
|
|
"outputs": [
|
|||
|
|
{
|
|||
|
|
"name": "stdout",
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"text": [
|
|||
|
|
"shape: (3, 3)\n",
|
|||
|
|
"┌───────────┬─────────────────────┬─────────────┐\n",
|
|||
|
|
"│ ts_code ┆ trade_date ┆ log_circ_mv │\n",
|
|||
|
|
"│ --- ┆ --- ┆ --- │\n",
|
|||
|
|
"│ str ┆ datetime[ns] ┆ f64 │\n",
|
|||
|
|
"╞═══════════╪═════════════════════╪═════════════╡\n",
|
|||
|
|
"│ 000001.SZ ┆ 2019-01-02 00:00:00 ┆ 16.574219 │\n",
|
|||
|
|
"│ 000001.SZ ┆ 2019-01-03 00:00:00 ┆ 16.583965 │\n",
|
|||
|
|
"│ 000001.SZ ┆ 2019-01-04 00:00:00 ┆ 16.633371 │\n",
|
|||
|
|
"└───────────┴─────────────────────┴─────────────┘\n",
|
|||
|
|
"['vol', 'pct_chg', 'turnover_rate', 'volume_ratio', 'winner_rate', 'holder_net_change_sum_10d', 'holder_increase_days_10d', 'holder_decrease_days_10d', 'holder_any_increase_flag_10d', 'holder_any_decrease_flag_10d', 'cat_senti_mom_vol_spike', 'cat_senti_pre_breakout', 'ts_turnover_rate_acceleration_5_20', 'ts_vol_sustain_10_30', 'cs_amount_outlier_10', 'ts_ff_to_total_turnover_ratio', 'ts_price_volume_trend_coherence_5_20', 'ts_ff_turnover_rate_surge_10', 'undist_profit_ps', 'ocfps', 'AR', 'BR', 'AR_BR', 'log_circ_mv', 'cashflow_to_ev_factor', 'book_to_price_ratio', 'turnover_rate_mean_5', 'variance_20', 'bbi_ratio_factor', 'daily_deviation', 'lg_elg_net_buy_vol', 'flow_lg_elg_intensity', 'sm_net_buy_vol', 'total_buy_vol', 'lg_elg_buy_prop', 'flow_struct_buy_change', 'lg_elg_net_buy_vol_change', 'flow_lg_elg_accel', 'chip_concentration_range', 'chip_skewness', 'floating_chip_proxy', 'cost_support_15pct_change', 'cat_winner_price_zone', 'flow_chip_consistency', 'profit_taking_vs_absorb', 'cat_is_positive', 'upside_vol', 'downside_vol', 'vol_ratio', 'return_skew', 'return_kurtosis', 'volume_change_rate', 'cat_volume_breakout', 'turnover_deviation', 'cat_turnover_spike', 'avg_volume_ratio', 'cat_volume_ratio_breakout', 'vol_spike', 'vol_std_5', 'atr_24', 'atr_6', 'obv', 'maobv_6', 'rsi_3', 'return_5', 'return_20', 'std_return_5', 'std_return_90', 'std_return_90_2', 'act_factor1', 'act_factor2', 'act_factor3', 'act_factor4', 'rank_act_factor1', 'rank_act_factor2', 'rank_act_factor3', 'cov', 'delta_cov', 'alpha_22_improved', 'alpha_003', 'alpha_007', 'alpha_013', 'vol_break', 'weight_roc5', 'smallcap_concentration', 'cost_stability', 'high_cost_break_days', 'liquidity_risk', 'turnover_std', 'mv_volatility', 'volume_growth', 'mv_growth', 'momentum_factor', 'resonance_factor', 'log_close', 'cat_vol_spike', 'up', 'down', 'obv_maobv_6', 'std_return_5_over_std_return_90', 'std_return_90_minus_std_return_90_2', 'cat_af2', 'cat_af3', 'cat_af4', 'act_factor5', 'act_factor6', 
'active_buy_volume_large', 'active_buy_volume_big', 'active_buy_volume_small', 'buy_lg_vol_minus_sell_lg_vol', 'buy_elg_vol_minus_sell_elg_vol', 'ctrl_strength', 'low_cost_dev', 'asymmetry', 'lock_factor', 'cat_vol_break', 'cost_atr_adj', 'cat_golden_resonance', 'mv_turnover_ratio', 'mv_adjusted_volume', 'mv_weighted_turnover', 'nonlinear_mv_volume', 'mv_volume_ratio', 'mv_momentum', 'senti_strong_inflow', 'lg_flow_mom_corr_20_60', 'lg_flow_accel', 'profit_pressure', 'underwater_resistance', 'cost_conc_std_20', 'profit_decay_20', 'vol_amp_loss_20', 'vol_drop_profit_cnt_5', 'lg_flow_vol_interact_20', 'cost_break_confirm_cnt_5', 'atr_norm_channel_pos_14', 'turnover_diff_skew_20', 'lg_sm_flow_diverge_20', 'pullback_strong_20_20', 'vol_wgt_hist_pos_20', 'vol_adj_roc_20', 'cs_rank_net_lg_flow_val', 'cs_rank_elg_buy_ratio', 'cs_rank_rel_profit_margin', 'cs_rank_cost_breadth', 'cs_rank_dist_to_upper_cost', 'cs_rank_winner_rate', 'cs_rank_intraday_range', 'cs_rank_close_pos_in_range', 'cs_rank_pos_in_hist_range', 'cs_rank_vol_x_profit_margin', 'cs_rank_lg_flow_price_concordance', 'cs_rank_turnover_per_winner', 'cs_rank_volume_ratio', 'cs_rank_elg_buy_sell_sm_ratio', 'cs_rank_cost_dist_vol_ratio', 'cs_rank_size', 'industry_obv', 'industry_return_5', 'industry_return_20', 'industry__ema_5', 'industry__ema_13', 'industry__ema_20', 'industry__ema_60', 'industry_act_factor1', 'industry_act_factor2', 'industry_act_factor3', 'industry_act_factor4', 'industry_act_factor5', 'industry_act_factor6', 'industry_rank_act_factor1', 'industry_rank_act_factor2', 'industry_rank_act_factor3', 'industry_return_5_percentile', 'industry_return_20_percentile', '000852.SH_MACD', '000905.SH_MACD', '399006.SZ_MACD', '000852.SH_MACD_hist', '000905.SH_MACD_hist', '399006.SZ_MACD_hist', '000852.SH_RSI', '000905.SH_RSI', '399006.SZ_RSI', '000852.SH_Signal_line', '000905.SH_Signal_line', '399006.SZ_Signal_line', '000852.SH_amount_change_rate', '000905.SH_amount_change_rate', 
'399006.SZ_amount_change_rate', '000852.SH_amount_mean', '000905.SH_amount_mean', '399006.SZ_amount_mean', '000852.SH_da
|
|||
|
|
"去除极值\n",
|
|||
|
|
"开始截面 MAD 去极值处理 (k=3.0)...\n"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"name": "stderr",
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"text": [
|
|||
|
|
"MAD Filtering: 100%|██████████| 144/144 [00:00<00:00, 170.19it/s]\n"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"name": "stdout",
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"text": [
|
|||
|
|
"截面 MAD 去极值处理完成。\n",
|
|||
|
|
"标准化\n",
|
|||
|
|
"开始截面 Z-Score 标准化...\n"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"name": "stderr",
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"text": [
|
|||
|
|
"Standardizing: 100%|██████████| 144/144 [00:00<00:00, 767.01it/s]\n"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"name": "stdout",
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"text": [
|
|||
|
|
"截面 Z-Score 标准化完成。\n",
|
|||
|
|
"开始截面 MAD 去极值处理 (k=3.0)...\n"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"name": "stderr",
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"text": [
|
|||
|
|
"MAD Filtering: 100%|██████████| 144/144 [00:26<00:00, 5.49it/s]\n"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"name": "stdout",
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"text": [
|
|||
|
|
"截面 MAD 去极值处理完成。\n",
|
|||
|
|
"开始截面 Z-Score 标准化...\n"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"name": "stderr",
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"text": [
|
|||
|
|
"Standardizing: 100%|██████████| 144/144 [00:03<00:00, 37.17it/s]\n"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"name": "stdout",
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"text": [
|
|||
|
|
"截面 Z-Score 标准化完成。\n",
|
|||
|
|
"开始截面 MAD 去极值处理 (k=3.0)...\n"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"name": "stderr",
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"text": [
|
|||
|
|
"MAD Filtering: 0it [00:00, ?it/s]\n"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"name": "stdout",
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"text": [
|
|||
|
|
"截面 MAD 去极值处理完成。\n",
|
|||
|
|
"开始截面 MAD 去极值处理 (k=3.0)...\n"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"name": "stderr",
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"text": [
|
|||
|
|
"MAD Filtering: 0it [00:00, ?it/s]\n"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"name": "stdout",
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"text": [
|
|||
|
|
"截面 MAD 去极值处理完成。\n",
|
|||
|
|
"feature_columns: ['vol', 'pct_chg', 'turnover_rate', 'volume_ratio', 'winner_rate', 'holder_net_change_sum_10d', 'holder_increase_days_10d', 'holder_decrease_days_10d', 'holder_any_increase_flag_10d', 'holder_any_decrease_flag_10d', 'cat_senti_mom_vol_spike', 'cat_senti_pre_breakout', 'ts_turnover_rate_acceleration_5_20', 'ts_vol_sustain_10_30', 'cs_amount_outlier_10', 'ts_ff_to_total_turnover_ratio', 'ts_price_volume_trend_coherence_5_20', 'ts_ff_turnover_rate_surge_10', 'undist_profit_ps', 'ocfps', 'AR', 'BR', 'AR_BR', 'log_circ_mv', 'cashflow_to_ev_factor', 'book_to_price_ratio', 'turnover_rate_mean_5', 'variance_20', 'bbi_ratio_factor', 'daily_deviation', 'lg_elg_net_buy_vol', 'flow_lg_elg_intensity', 'sm_net_buy_vol', 'total_buy_vol', 'lg_elg_buy_prop', 'flow_struct_buy_change', 'lg_elg_net_buy_vol_change', 'flow_lg_elg_accel', 'chip_concentration_range', 'chip_skewness', 'floating_chip_proxy', 'cost_support_15pct_change', 'cat_winner_price_zone', 'flow_chip_consistency', 'profit_taking_vs_absorb', 'cat_is_positive', 'upside_vol', 'downside_vol', 'vol_ratio', 'return_skew', 'return_kurtosis', 'volume_change_rate', 'cat_volume_breakout', 'turnover_deviation', 'cat_turnover_spike', 'avg_volume_ratio', 'cat_volume_ratio_breakout', 'vol_spike', 'vol_std_5', 'atr_24', 'atr_6', 'obv', 'maobv_6', 'rsi_3', 'return_5', 'return_20', 'std_return_5', 'std_return_90', 'std_return_90_2', 'act_factor1', 'act_factor2', 'act_factor3', 'act_factor4', 'rank_act_factor1', 'rank_act_factor2', 'rank_act_factor3', 'cov', 'delta_cov', 'alpha_22_improved', 'alpha_003', 'alpha_007', 'alpha_013', 'vol_break', 'weight_roc5', 'smallcap_concentration', 'cost_stability', 'high_cost_break_days', 'liquidity_risk', 'turnover_std', 'mv_volatility', 'volume_growth', 'mv_growth', 'momentum_factor', 'resonance_factor', 'log_close', 'cat_vol_spike', 'up', 'down', 'obv_maobv_6', 'std_return_5_over_std_return_90', 'std_return_90_minus_std_return_90_2', 'cat_af2', 'cat_af3', 'cat_af4', 'act_factor5', 
'act_factor6', 'active_buy_volume_large', 'active_buy_volume_big', 'active_buy_volume_small', 'buy_lg_vol_minus_sell_lg_vol', 'buy_elg_vol_minus_sell_elg_vol', 'ctrl_strength', 'low_cost_dev', 'asymmetry', 'lock_factor', 'cat_vol_break', 'cost_atr_adj', 'cat_golden_resonance', 'mv_turnover_ratio', 'mv_adjusted_volume', 'mv_weighted_turnover', 'nonlinear_mv_volume', 'mv_volume_ratio', 'mv_momentum', 'senti_strong_inflow', 'lg_flow_mom_corr_20_60', 'lg_flow_accel', 'profit_pressure', 'underwater_resistance', 'cost_conc_std_20', 'profit_decay_20', 'vol_amp_loss_20', 'vol_drop_profit_cnt_5', 'lg_flow_vol_interact_20', 'cost_break_confirm_cnt_5', 'atr_norm_channel_pos_14', 'turnover_diff_skew_20', 'lg_sm_flow_diverge_20', 'pullback_strong_20_20', 'vol_wgt_hist_pos_20', 'vol_adj_roc_20', 'cs_rank_net_lg_flow_val', 'cs_rank_elg_buy_ratio', 'cs_rank_rel_profit_margin', 'cs_rank_cost_breadth', 'cs_rank_dist_to_upper_cost', 'cs_rank_winner_rate', 'cs_rank_intraday_range', 'cs_rank_close_pos_in_range', 'cs_rank_pos_in_hist_range', 'cs_rank_vol_x_profit_margin', 'cs_rank_lg_flow_price_concordance', 'cs_rank_turnover_per_winner', 'cs_rank_volume_ratio', 'cs_rank_elg_buy_sell_sm_ratio', 'cs_rank_cost_dist_vol_ratio', 'cs_rank_size', 'industry_obv', 'industry_return_5', 'industry_return_20', 'industry__ema_5', 'industry__ema_13', 'industry__ema_20', 'industry__ema_60', 'industry_act_factor1', 'industry_act_factor2', 'industry_act_factor3', 'industry_act_factor4', 'industry_act_factor5', 'industry_act_factor6', 'industry_rank_act_factor1', 'industry_rank_act_factor2', 'industry_rank_act_factor3', 'industry_return_5_percentile', 'industry_return_20_percentile', '000852.SH_MACD', '000905.SH_MACD', '399006.SZ_MACD', '000852.SH_MACD_hist', '000905.SH_MACD_hist', '399006.SZ_MACD_hist', '000852.SH_RSI', '000905.SH_RSI', '399006.SZ_RSI', '000852.SH_Signal_line', '000905.SH_Signal_line', '399006.SZ_Signal_line', '000852.SH_amount_change_rate', '000905.SH_amount_change_rate', 
'399006.SZ_amount_change_rate', '000852.SH_amount_mean', '000905.SH_amount_mean', '399006.SZ_amount_mea
|
|||
|
|
"df最小日期: 2019-01-02\n",
|
|||
|
|
"df最大日期: 2025-05-30\n",
|
|||
|
|
"148609\n",
|
|||
|
|
"train_data最小日期: 2020-01-02\n",
|
|||
|
|
"train_data最大日期: 2020-03-31\n",
|
|||
|
|
"3591195\n",
|
|||
|
|
"test_data最小日期: 2020-03-31\n",
|
|||
|
|
"test_data最大日期: 2025-05-30\n",
|
|||
|
|
"shape: (3, 3)\n",
|
|||
|
|
"┌───────────┬─────────────────────┬─────────────┐\n",
|
|||
|
|
"│ ts_code ┆ trade_date ┆ log_circ_mv │\n",
|
|||
|
|
"│ --- ┆ --- ┆ --- │\n",
|
|||
|
|
"│ str ┆ datetime[ns] ┆ f64 │\n",
|
|||
|
|
"╞═══════════╪═════════════════════╪═════════════╡\n",
|
|||
|
|
"│ 000001.SZ ┆ 2019-01-02 00:00:00 ┆ 16.574219 │\n",
|
|||
|
|
"│ 000001.SZ ┆ 2019-01-03 00:00:00 ┆ 16.583965 │\n",
|
|||
|
|
"│ 000001.SZ ┆ 2019-01-04 00:00:00 ┆ 16.633371 │\n",
|
|||
|
|
"└───────────┴─────────────────────┴─────────────┘\n"
|
|||
|
|
]
|
|||
|
|
}
|
|||
|
|
],
|
|||
|
|
"source": [
|
|||
|
|
"# NOTE(review): star import kept because cs_mad_filter / cs_zscore_standardize below (and\n",
"# possibly later cells) presumably come from it.\n",
"from main.utils.data_process import * \n",
"\n",
"# --- Train / test split ---\n",
"# Train: 2020-01-01 <= trade_date <= split_date with future_return inside\n",
"# [lower_bound_quantile, upper_bound_fixed]. Test: trade_date strictly after\n",
"# split_date, so the split day is not in both sets.\n",
"# NOTE(review): future_return looks 2 * days rows ahead, so labels of the last\n",
"# training days are built from test-period prices; an embargo gap between the\n",
"# two sets may be needed -- TODO confirm.\n",
"split_date = pd.to_datetime('2020-03-31')\n",
"\n",
"train_data = df.filter(\n",
"    (pl.col('trade_date') <= split_date) &\n",
"    (pl.col('trade_date') >= pd.to_datetime('2020-01-01')) &\n",
"    (pl.col('future_return') >= lower_bound_quantile) &\n",
"    (pl.col('future_return') <= upper_bound_fixed)\n",
")\n",
"test_data = df.filter(pl.col('trade_date') > split_date)\n",
"\n",
"print(df[['ts_code', 'trade_date', 'log_circ_mv']].head(3))\n",
"\n",
"numeric_columns = [\n",
"    col for col, dtype in zip(df.columns, df.dtypes)\n",
"    if dtype in [pl.Float64, pl.Float32, pl.Int64, pl.Int32]\n",
"]\n",
"numeric_columns = [col for col in numeric_columns if col in feature_columns]\n",
"print(feature_columns)\n",
"\n",
"\n",
"def _inf_to_null(frame):\n",
"    \"\"\"Replace +/-inf in the numeric feature columns of a polars frame with null.\n",
"\n",
"    Null (not NaN) is required: polars drop_nulls ignores NaN, so mapping inf to\n",
"    np.nan let those rows survive the drop below.\n",
"    \"\"\"\n",
"    return frame.with_columns([\n",
"        pl.when(pl.col(col).is_infinite())\n",
"        .then(pl.lit(None))\n",
"        .otherwise(pl.col(col))\n",
"        .alias(col)\n",
"        for col in numeric_columns\n",
"    ])\n",
"\n",
"\n",
"train_data = _inf_to_null(train_data)\n",
"test_data = _inf_to_null(test_data)\n",
"\n",
"train_data = train_data.drop_nulls(subset=[col for col in feature_columns if col in train_data.columns])\n",
"test_data = test_data.drop_nulls(subset=[col for col in feature_columns if col in test_data.columns])\n",
"\n",
"# Every non-categorical feature present in train_data gets MAD clipping and z-scoring.\n",
"transform_feature_columns = [\n",
"    col for col in feature_columns\n",
"    if not col.startswith('cat') and col in train_data.columns\n",
"]\n",
"print('去除极值')\n",
"train_data = train_data.to_pandas()\n",
"test_data = test_data.to_pandas()\n",
"gc.collect()\n",
"# NOTE(review): the return values below are discarded, so these helpers must modify\n",
"# the frame in place for the transforms to take effect -- confirm in data_process.\n",
"cs_mad_filter(train_data, transform_feature_columns)\n",
"print('标准化')\n",
"cs_zscore_standardize(train_data, transform_feature_columns)\n",
"\n",
"cs_mad_filter(test_data, transform_feature_columns)\n",
"cs_zscore_standardize(test_data, transform_feature_columns)\n",
"\n",
"# Non-categorical features in train_data that were not transformed above. As written\n",
"# this list is always empty, so the two calls below do no filtering.\n",
"mad_filter_feature_columns = [col for col in feature_columns if col not in transform_feature_columns and not col.startswith('cat') and col in train_data.columns]\n",
"cs_mad_filter(train_data, mad_filter_feature_columns)\n",
"cs_mad_filter(test_data, mad_filter_feature_columns)\n",
"\n",
"\n",
"print(f'feature_columns: {feature_columns}')\n",
"\n",
"\n",
"print(f\"df最小日期: {df['trade_date'].min().strftime('%Y-%m-%d')}\")\n",
"print(f\"df最大日期: {df['trade_date'].max().strftime('%Y-%m-%d')}\")\n",
"print(len(train_data))\n",
"print(f\"train_data最小日期: {train_data['trade_date'].min().strftime('%Y-%m-%d')}\")\n",
"print(f\"train_data最大日期: {train_data['trade_date'].max().strftime('%Y-%m-%d')}\")\n",
"print(len(test_data))\n",
"print(f\"test_data最小日期: {test_data['trade_date'].min().strftime('%Y-%m-%d')}\")\n",
"print(f\"test_data最大日期: {test_data['trade_date'].max().strftime('%Y-%m-%d')}\")\n",
"\n",
"# Cast categorical features to pandas 'category' dtype.\n",
"cat_columns = [col for col in feature_columns if col.startswith('cat')]\n",
"for col in cat_columns:\n",
"    train_data[col] = train_data[col].astype('category')\n",
"    test_data[col] = test_data[col].astype('category')\n",
"\n",
"print(df[['ts_code', 'trade_date', 'log_circ_mv']].head(3))\n"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"execution_count": 17,
|
|||
|
|
"id": "3ff2d1c5",
|
|||
|
|
"metadata": {},
|
|||
|
|
"outputs": [],
|
|||
|
|
"source": [
|
|||
|
|
"from sklearn.preprocessing import StandardScaler\n",
|
|||
|
|
"from sklearn.linear_model import LogisticRegression\n",
|
|||
|
|
"import matplotlib.pyplot as plt # 保持 matplotlib 导入,尽管LightGBM的绘图功能已移除\n",
|
|||
|
|
"from sklearn.decomposition import PCA\n",
|
|||
|
|
"import pandas as pd\n",
|
|||
|
|
"import numpy as np\n",
|
|||
|
|
"import datetime # 用于日期计算\n",
|
|||
|
|
"from catboost import CatBoostClassifier, CatBoostRanker, CatBoostRegressor\n",
|
|||
|
|
"from catboost import Pool\n",
|
|||
|
|
"import lightgbm as lgb\n",
|
|||
|
|
"from lightgbm import LGBMRanker, LGBMRegressor\n",
|
|||
|
|
"\n",
|
|||
|
|
"def train_model(train_data_df, feature_columns,\n",
|
|||
|
|
" print_info=True, # 调整参数名,更通用\n",
|
|||
|
|
" validation_days=180, use_pca=False, split_date=None,\n",
|
|||
|
|
" target_column='label', type='light'): # 增加目标列参数\n",
|
|||
|
|
"\n",
|
|||
|
|
" print('train data size: ', len(train_data_df))\n",
|
|||
|
|
" print(train_data_df[['ts_code', 'trade_date', 'log_circ_mv']])\n",
|
|||
|
|
" # 确保数据按时间排序\n",
|
|||
|
|
" train_data_df = train_data_df.sort_values(by='trade_date')\n",
|
|||
|
|
"\n",
|
|||
|
|
" # 去除标签为空的样本\n",
|
|||
|
|
" initial_len = len(train_data_df)\n",
|
|||
|
|
" train_data_df = train_data_df.dropna(subset=[target_column])\n",
|
|||
|
|
"\n",
|
|||
|
|
" if print_info:\n",
|
|||
|
|
" print(f'原始样本数: {initial_len}, 去除标签为空后样本数: {len(train_data_df)}')\n",
|
|||
|
|
"\n",
|
|||
|
|
" # 提取特征和标签,只取数值型特征用于线性回归\n",
|
|||
|
|
" \n",
|
|||
|
|
" if split_date is None:\n",
|
|||
|
|
" all_dates = train_data_df['trade_date'].unique() # 获取所有唯一的 trade_date\n",
|
|||
|
|
" split_date = all_dates[-validation_days] # 划分点为倒数第 validation_days 天\n",
|
|||
|
|
" train_data_split = train_data_df[train_data_df['trade_date'] < split_date] # 训练集\n",
|
|||
|
|
" val_data_split = train_data_df[train_data_df['trade_date'] >= split_date] # 验证集\n",
|
|||
|
|
"\n",
|
|||
|
|
" train_data_split = train_data_split.sort_values('trade_date')\n",
|
|||
|
|
" val_data_split = val_data_split.sort_values('trade_date')\n",
|
|||
|
|
"\n",
|
|||
|
|
" \n",
|
|||
|
|
" X_train = train_data_split[feature_columns]\n",
|
|||
|
|
" y_train = train_data_split[target_column]\n",
|
|||
|
|
" \n",
|
|||
|
|
" X_val = val_data_split[feature_columns]\n",
|
|||
|
|
" y_val = val_data_split[target_column]\n",
|
|||
|
|
"\n",
|
|||
|
|
"\n",
|
|||
|
|
" # # 标准化数值特征 (使用 StandardScaler 对训练集fit并transform, 对验证集只transform)\n",
|
|||
|
|
" scaler = StandardScaler()\n",
|
|||
|
|
" # X_train = scaler.fit_transform(X_train)\n",
|
|||
|
|
"\n",
|
|||
|
|
" # 训练线性回归模型\n",
|
|||
|
|
" # model = LogisticRegression(random_state=42)\n",
|
|||
|
|
" \n",
|
|||
|
|
" # # 使用处理后的特征和样本权重进行训练\n",
|
|||
|
|
" # model.fit(X_train, y_train)\n",
|
|||
|
|
"\n",
|
|||
|
|
"\n",
|
|||
|
|
" if type == 'cat':\n",
|
|||
|
|
" params = {\n",
|
|||
|
|
" 'loss_function': 'QueryRMSE', # 适用于二分类\n",
|
|||
|
|
" 'eval_metric': 'NDCG', # 评估指标\n",
|
|||
|
|
" 'iterations': 1500,\n",
|
|||
|
|
" 'learning_rate': 0.03,\n",
|
|||
|
|
" 'depth': 8, # 控制模型复杂度\n",
|
|||
|
|
" 'l2_leaf_reg': 1, # L2 正则化\n",
|
|||
|
|
" 'verbose': 5000,\n",
|
|||
|
|
" 'early_stopping_rounds': 300,\n",
|
|||
|
|
" 'one_hot_max_size': 50,\n",
|
|||
|
|
" # 'class_weights': [0.6, 1.2],\n",
|
|||
|
|
" 'task_type': 'GPU',\n",
|
|||
|
|
" 'has_time': True,\n",
|
|||
|
|
" 'random_seed': 7\n",
|
|||
|
|
" }\n",
|
|||
|
|
" cat_features = [i for i, col in enumerate(feature_columns) if col.startswith('cat')]\n",
|
|||
|
|
" group_train = train_data_split['trade_date'].factorize()[0]\n",
|
|||
|
|
" group_val = val_data_split['trade_date'].factorize()[0]\n",
|
|||
|
|
" train_pool = Pool(\n",
|
|||
|
|
" data=X_train,\n",
|
|||
|
|
" label=y_train,\n",
|
|||
|
|
" group_id=group_train,\n",
|
|||
|
|
" cat_features=cat_features\n",
|
|||
|
|
" )\n",
|
|||
|
|
" val_pool = Pool(\n",
|
|||
|
|
" data=X_val,\n",
|
|||
|
|
" label=y_val,\n",
|
|||
|
|
" group_id=group_val,\n",
|
|||
|
|
" cat_features=cat_features\n",
|
|||
|
|
" )\n",
|
|||
|
|
"\n",
|
|||
|
|
"\n",
|
|||
|
|
" model = CatBoostRanker(**params)\n",
|
|||
|
|
" model.fit(train_pool,\n",
|
|||
|
|
" eval_set=val_pool, \n",
|
|||
|
|
" plot=True, \n",
|
|||
|
|
" use_best_model=True\n",
|
|||
|
|
" )\n",
|
|||
|
|
" elif type == 'light':\n",
|
|||
|
|
" label_gain = list(range(len(train_data_split[target_column].unique())))\n",
|
|||
|
|
" \n",
|
|||
|
|
" params = {\n",
|
|||
|
|
" 'label_gain': [gain * gain for gain in label_gain],\n",
|
|||
|
|
" 'objective': 'lambdarank',\n",
|
|||
|
|
" 'metric': 'ndcg',\n",
|
|||
|
|
" 'learning_rate': 0.01,\n",
|
|||
|
|
" # 'num_leaves': 1024,\n",
|
|||
|
|
" # 'min_data_in_leaf': 256,\n",
|
|||
|
|
" # 'max_depth': 10,\n",
|
|||
|
|
" # 'max_bin': 1024,\n",
|
|||
|
|
" 'feature_fraction': 0.5,\n",
|
|||
|
|
" 'bagging_fraction': 0.5,\n",
|
|||
|
|
" 'bagging_freq': 5,\n",
|
|||
|
|
" # 'lambda_l1': 1,\n",
|
|||
|
|
" 'lambda_l2': 50,\n",
|
|||
|
|
" 'boosting': 'gbdt',\n",
|
|||
|
|
" 'verbosity': -1,\n",
|
|||
|
|
" 'extra_trees': True,\n",
|
|||
|
|
" # 'max_position': 5,\n",
|
|||
|
|
" 'ndcg_at': '5',\n",
|
|||
|
|
" 'quant_train_renew_leaf': True,\n",
|
|||
|
|
" 'lambdarank_truncation_level': 10,\n",
|
|||
|
|
" # 'lambdarank_position_bias_regularization': 1,\n",
|
|||
|
|
" 'seed': 7\n",
|
|||
|
|
" }\n",
|
|||
|
|
" # feature_contri = [2 if feat.startswith('act_factor') or 'buy' in feat or 'sell' in feat else 1 for feat in feature_columns]\n",
|
|||
|
|
" # params['feature_contri'] = feature_contri\n",
|
|||
|
|
"\n",
|
|||
|
|
" train_groups = train_data_split.groupby('trade_date').size().tolist()\n",
|
|||
|
|
" val_groups = val_data_split.groupby('trade_date').size().tolist()\n",
|
|||
|
|
"\n",
|
|||
|
|
" categorical_feature = [col for col in feature_columns if 'cat' in col]\n",
|
|||
|
|
" train_dataset = lgb.Dataset(\n",
|
|||
|
|
" X_train, label=y_train, \n",
|
|||
|
|
" group=train_groups,\n",
|
|||
|
|
" categorical_feature=categorical_feature\n",
|
|||
|
|
" )\n",
|
|||
|
|
" val_dataset = lgb.Dataset(\n",
|
|||
|
|
" X_val, label=y_val, \n",
|
|||
|
|
" group=val_groups,\n",
|
|||
|
|
" categorical_feature=categorical_feature\n",
|
|||
|
|
" )\n",
|
|||
|
|
"\n",
|
|||
|
|
" evals = {}\n",
|
|||
|
|
" callbacks = [lgb.log_evaluation(period=1000),\n",
|
|||
|
|
" lgb.callback.record_evaluation(evals),\n",
|
|||
|
|
" lgb.early_stopping(300, first_metric_only=False)\n",
|
|||
|
|
" ]\n",
|
|||
|
|
" # 训练模型\n",
|
|||
|
|
" model = lgb.train(\n",
|
|||
|
|
" params, train_dataset, num_boost_round=1000,\n",
|
|||
|
|
" valid_sets=[train_dataset, val_dataset], valid_names=['train', 'valid'],\n",
|
|||
|
|
" callbacks=callbacks\n",
|
|||
|
|
" )\n",
|
|||
|
|
"\n",
|
|||
|
|
" # 打印特征重要性(如果需要)\n",
|
|||
|
|
" if True:\n",
|
|||
|
|
" lgb.plot_metric(evals)\n",
|
|||
|
|
" lgb.plot_importance(model, importance_type='split', max_num_features=20)\n",
|
|||
|
|
" plt.show()\n",
|
|||
|
|
"\n",
|
|||
|
|
" # from flaml import AutoML\n",
|
|||
|
|
" # from sklearn.datasets import fetch_california_housing\n",
|
|||
|
|
"\n",
|
|||
|
|
" # # Initialize an AutoML instance\n",
|
|||
|
|
" # model = AutoML()\n",
|
|||
|
|
" # # Specify automl goal and constraint\n",
|
|||
|
|
" # automl_settings = {\n",
|
|||
|
|
" # \"time_budget\": 600, # in seconds\n",
|
|||
|
|
" # \"metric\": \"ndcg@1\",\n",
|
|||
|
|
" # \"task\": \"rank\",\n",
|
|||
|
|
" # \"estimator_list\": [\n",
|
|||
|
|
" # \"catboost\",\n",
|
|||
|
|
" # \"lgbm\",\n",
|
|||
|
|
" # \"xgboost\"\n",
|
|||
|
|
" # ], \n",
|
|||
|
|
" # \"ensemble\": {\n",
|
|||
|
|
" # \"final_estimator\": LGBMRanker(),\n",
|
|||
|
|
" # \"passthrough\": False,\n",
|
|||
|
|
" # },\n",
|
|||
|
|
" # }\n",
|
|||
|
|
" # model.fit(X_train=X_train, y_train=y_train, groups=train_groups,\n",
|
|||
|
|
" # X_val=X_val, y_val=y_val,groups_val=val_groups,\n",
|
|||
|
|
" # mlflow_logging=False, **automl_settings)\n",
|
|||
|
|
"\n",
|
|||
|
|
"\n",
|
|||
|
|
" return model, scaler, None # 返回训练好的模型、scaler 和 pca 对象"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
"cell_type": "code",
"execution_count": null,
"id": "c6eb5cd4-e714-420a-ac48-39af3e11ee81",
"metadata": {
"ExecuteTime": {
"end_time": "2025-04-03T15:03:18.426481Z",
"start_time": "2025-04-03T15:02:19.926352Z"
}
},
"outputs": [],
"source": [
"gc.collect()\n",
"\n",
"use_pca = False\n",
"# NOTE: `type` shadows the builtin; the name is kept because later cells read it.\n",
"type = 'light'\n",
"\n",
"# Per trading day, keep only the 1000 smallest stocks by total_mv. The grouped nsmallest\n",
"# is slow on the full history, so each universe is computed once and reused below.\n",
"train_universe = train_data.groupby('trade_date', group_keys=False).apply(lambda x: x.nsmallest(1000, 'total_mv'))\n",
"test_universe = test_data.groupby('trade_date', group_keys=False).apply(lambda x: x.nsmallest(1000, 'total_mv'))\n",
"\n",
"\n",
"def daily_quantile_label(values, q=50):\n",
"    \"\"\"Bucket one trading day's values into at most `q` quantile bins (0 = lowest).\n",
"\n",
"    Returns an all-NaN label when fewer than two values are non-null (presumably the\n",
"    latest dates, whose future return is not known yet), where binning is meaningless.\n",
"    \"\"\"\n",
"    if values.notna().sum() < 2:\n",
"        return pd.Series(float('nan'), index=values.index)\n",
"    return pd.qcut(values, q=q, labels=False, duplicates='drop')\n",
"\n",
"\n",
"# The ranking label is binned within each trade_date (the ranker's query group).\n",
"# A bare Series.transform would instead run a single qcut over the whole period.\n",
"train_universe['label2'] = train_universe.groupby('trade_date')['future_return'].transform(daily_quantile_label)\n",
"test_universe['label2'] = test_universe.groupby('trade_date')['future_return'].transform(daily_quantile_label)\n",
"\n",
"# Keep label2 on the full frames as well (NaN outside the small-cap universe).\n",
"train_data['label2'] = train_universe['label2']\n",
"test_data['label2'] = test_universe['label2']\n",
"\n",
"model, scaler, pca = train_model(\n",
"    train_universe\n",
"    .merge(industry_df, on=['cat_l2_code', 'trade_date'], how='left')\n",
"    .merge(index_data, on='trade_date', how='left'),\n",
"    feature_columns, type=type, target_column='label2'\n",
")\n"
]
},
|
|||
|
|
{
"cell_type": "code",
"execution_count": null,
"id": "5d1522a7538db91b",
"metadata": {
"ExecuteTime": {
"end_time": "2025-04-03T15:04:39.656944Z",
"start_time": "2025-04-03T15:04:39.298483Z"
}
},
"outputs": [],
"source": [
"# Score the test universe: per trading day, the 1000 smallest stocks by total_mv,\n",
"# enriched with the same industry and index features used for training.\n",
"score_df = test_data.groupby('trade_date', group_keys=False).apply(lambda x: x.nsmallest(1000, 'total_mv'))\n",
"score_df = score_df.merge(industry_df, on=['cat_l2_code', 'trade_date'], how='left')\n",
"score_df = score_df.merge(index_data, on='trade_date', how='left')\n",
"# Not used in this cell; kept in case later cells read it.\n",
"numeric_columns = score_df.select_dtypes(include=['float64', 'int64']).columns\n",
"numeric_columns = [col for col in feature_columns if col in numeric_columns]\n",
"\n",
"# Both supported model types are scored with predict(features); fail loudly on anything\n",
"# else instead of leaving 'score' unset and hitting a KeyError further down.\n",
"if type not in ('cat', 'light'):\n",
"    raise ValueError(f'unsupported model type: {type!r}')\n",
"score_df['score'] = model.predict(score_df[feature_columns])\n",
"score_df['score_ranks'] = score_df.groupby('trade_date')['score'].rank(ascending=True)\n",
"\n",
"# Keep rows whose score is at or above that day's 90th percentile.\n",
"score_df = score_df.groupby('trade_date', group_keys=False).apply(\n",
"    lambda x: x[x['score'] >= x['score'].quantile(0.90)]\n",
").reset_index(drop=True)\n",
"\n",
"# Top 5 per day by score; reset_index() keeps the old position as an 'index' column.\n",
"save_df = score_df.groupby('trade_date', group_keys=False).apply(lambda x: x.nlargest(5, 'score')).reset_index()\n",
"save_df = save_df.sort_values(['trade_date', 'score'])\n",
"# NOTE(review): to_csv writes comma-separated output despite the .tsv extension;\n",
"# left unchanged in case readers of this file expect commas.\n",
"save_df[['trade_date', 'score', 'ts_code']].to_csv('predictions_test.tsv', index=False)\n"
]
},
|
|||
|
|
{
"cell_type": "code",
"execution_count": null,
"id": "fed2d6c3",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2023-01-03 00:00:00\n"
]
}
],
"source": [
"# Sanity check: earliest trade_date present in the test split.\n",
"print(test_data['trade_date'].agg('min'))"
]
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"execution_count": null,
|
|||
|
|
"id": "1f3c1331",
|
|||
|
|
"metadata": {},
|
|||
|
|
"outputs": [
|
|||
|
|
{
|
|||
|
|
"name": "stdout",
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"text": [
|
|||
|
|
"\n",
|
|||
|
|
"成功连接到 Redis 服务器: 140.143.91.66:6389,数据库 0\n",
|
|||
|
|
"DataFrame 已使用 Pickle 序列化并写入 Redis,键为 'save_df'\n",
|
|||
|
|
"从 Redis 读取到的 Pickle 序列化数据 (前 20 字节):\n",
|
|||
|
|
"b'\\x80\\x04\\x95\\xbf\\x04\\x01\\x00\\x00\\x00\\x00\\x00\\x8c\\x11pandas.'\n",
|
|||
|
|
"\n",
|
|||
|
|
"从 Redis 加载的 DataFrame (使用 Pickle):\n",
|
|||
|
|
" index ts_code trade_date open close high low vol \\\n",
|
|||
|
|
"4 36 002247.SZ 2023-01-03 16.15 16.80 16.87 16.09 0.514578 \n",
|
|||
|
|
"3 26 002513.SZ 2023-01-03 15.41 15.77 15.95 15.41 -0.499029 \n",
|
|||
|
|
"2 14 002629.SZ 2023-01-03 15.17 15.54 15.58 15.13 0.631716 \n",
|
|||
|
|
"1 5 603030.SH 2023-01-03 8.08 8.32 8.32 8.05 -0.033641 \n",
|
|||
|
|
"0 3 000691.SZ 2023-01-03 13.87 14.24 14.36 13.63 -0.030740 \n",
|
|||
|
|
"... ... ... ... ... ... ... ... ... \n",
|
|||
|
|
"2904 58031 002524.SZ 2025-05-30 19.49 19.40 20.27 19.31 0.455783 \n",
|
|||
|
|
"2903 58034 002084.SZ 2025-05-30 11.88 11.68 11.88 11.58 0.610122 \n",
|
|||
|
|
"2902 58029 600159.SH 2025-05-30 13.66 13.55 13.86 13.50 -0.170606 \n",
|
|||
|
|
"2901 58035 002775.SZ 2025-05-30 17.84 17.55 17.84 17.41 -0.705858 \n",
|
|||
|
|
"2900 58019 600408.SH 2025-05-30 8.64 8.51 8.64 8.43 0.684604 \n",
|
|||
|
|
"\n",
|
|||
|
|
" pct_chg amount ... 000905.SH_up_ratio_20d \\\n",
|
|||
|
|
"4 0.715898 31834.487 ... 0.3 \n",
|
|||
|
|
"3 0.121763 25452.447 ... 0.3 \n",
|
|||
|
|
"2 0.299045 55379.071 ... 0.3 \n",
|
|||
|
|
"1 0.241548 22271.706 ... 0.3 \n",
|
|||
|
|
"0 0.279880 38602.205 ... 0.3 \n",
|
|||
|
|
"... ... ... ... ... \n",
|
|||
|
|
"2904 0.552905 76113.818 ... 0.6 \n",
|
|||
|
|
"2903 -0.070682 67595.352 ... 0.6 \n",
|
|||
|
|
"2902 0.005989 28573.069 ... 0.6 \n",
|
|||
|
|
"2901 0.241111 19025.635 ... 0.6 \n",
|
|||
|
|
"2900 0.788027 40827.689 ... 0.6 \n",
|
|||
|
|
"\n",
|
|||
|
|
" 399006.SZ_up_ratio_20d 000852.SH_volatility 000905.SH_volatility \\\n",
|
|||
|
|
"4 0.40 1.036997 0.828596 \n",
|
|||
|
|
"3 0.40 1.036997 0.828596 \n",
|
|||
|
|
"2 0.40 1.036997 0.828596 \n",
|
|||
|
|
"1 0.40 1.036997 0.828596 \n",
|
|||
|
|
"0 0.40 1.036997 0.828596 \n",
|
|||
|
|
"... ... ... ... \n",
|
|||
|
|
"2904 0.45 1.089861 0.850444 \n",
|
|||
|
|
"2903 0.45 1.089861 0.850444 \n",
|
|||
|
|
"2902 0.45 1.089861 0.850444 \n",
|
|||
|
|
"2901 0.45 1.089861 0.850444 \n",
|
|||
|
|
"2900 0.45 1.089861 0.850444 \n",
|
|||
|
|
"\n",
|
|||
|
|
" 399006.SZ_volatility 000852.SH_volume_change_rate \\\n",
|
|||
|
|
"4 0.935322 5.203088 \n",
|
|||
|
|
"3 0.935322 5.203088 \n",
|
|||
|
|
"2 0.935322 5.203088 \n",
|
|||
|
|
"1 0.935322 5.203088 \n",
|
|||
|
|
"0 0.935322 5.203088 \n",
|
|||
|
|
"... ... ... \n",
|
|||
|
|
"2904 1.195355 -2.039466 \n",
|
|||
|
|
"2903 1.195355 -2.039466 \n",
|
|||
|
|
"2902 1.195355 -2.039466 \n",
|
|||
|
|
"2901 1.195355 -2.039466 \n",
|
|||
|
|
"2900 1.195355 -2.039466 \n",
|
|||
|
|
"\n",
|
|||
|
|
" 000905.SH_volume_change_rate 399006.SZ_volume_change_rate score \\\n",
|
|||
|
|
"4 -0.750721 8.827360 0.042645 \n",
|
|||
|
|
"3 -0.750721 8.827360 0.047474 \n",
|
|||
|
|
"2 -0.750721 8.827360 0.055433 \n",
|
|||
|
|
"1 -0.750721 8.827360 0.061053 \n",
|
|||
|
|
"0 -0.750721 8.827360 0.068349 \n",
|
|||
|
|
"... ... ... ... \n",
|
|||
|
|
"2904 -12.002493 5.078672 0.064251 \n",
|
|||
|
|
"2903 -12.002493 5.078672 0.064517 \n",
|
|||
|
|
"2902 -12.002493 5.078672 0.065030 \n",
|
|||
|
|
"2901 -12.002493 5.078672 0.068848 \n",
|
|||
|
|
"2900 -12.002493 5.078672 0.074780 \n",
|
|||
|
|
"\n",
|
|||
|
|
" score_ranks \n",
|
|||
|
|
"4 996.0 \n",
|
|||
|
|
"3 997.0 \n",
|
|||
|
|
"2 998.0 \n",
|
|||
|
|
"1 999.0 \n",
|
|||
|
|
"0 1000.0 \n",
|
|||
|
|
"... ... \n",
|
|||
|
|
"2904 996.0 \n",
|
|||
|
|
"2903 997.0 \n",
|
|||
|
|
"2902 998.0 \n",
|
|||
|
|
"2901 999.0 \n",
|
|||
|
|
"2900 1000.0 \n",
|
|||
|
|
"\n",
|
|||
|
|
"[2905 rows x 248 columns]\n",
|
|||
|
|
"\n",
|
|||
|
|
"验证成功:原始 DataFrame 和从 Redis 加载的 DataFrame 一致。\n",
|
|||
|
|
"\n",
|
|||
|
|
"清理了 Redis 中键 'save_df' 的数据。\n"
|
|||
|
|
]
|
|||
|
|
}
|
|||
|
|
],
|
|||
|
|
"source": [
|
|||
|
|
"import redis\n",
|
|||
|
|
"import pickle\n",
|
|||
|
|
"\n",
|
|||
|
|
"redis_host = '140.143.91.66'\n",
|
|||
|
|
"redis_port = 6389\n",
|
|||
|
|
"redis_db = 0\n",
|
|||
|
|
"redis_key = 'save_df'\n",
|
|||
|
|
"\n",
|
|||
|
|
"try:\n",
|
|||
|
|
" # 1. 连接到 Redis 服务器\n",
|
|||
|
|
" r = redis.Redis(host=redis_host, port=redis_port, db=redis_db, password='Redis520102')\n",
|
|||
|
|
" r.ping()\n",
|
|||
|
|
" print(f\"\\n成功连接到 Redis 服务器: {redis_host}:{redis_port},数据库 {redis_db}\")\n",
|
|||
|
|
"\n",
|
|||
|
|
" # 2. 将 DataFrame 写入 Redis (使用 Pickle 序列化)\n",
|
|||
|
|
" df_serialized = pickle.dumps(save_df)\n",
|
|||
|
|
" r.set(redis_key, df_serialized)\n",
|
|||
|
|
" print(f\"DataFrame 已使用 Pickle 序列化并写入 Redis,键为 '{redis_key}'\")\n",
|
|||
|
|
"\n",
|
|||
|
|
" # 3. 从 Redis 读取数据 (获取 Pickle 序列化的字节流)\n",
|
|||
|
|
" retrieved_serialized = r.get(redis_key)\n",
|
|||
|
|
"\n",
|
|||
|
|
" if retrieved_serialized:\n",
|
|||
|
|
" print(f\"从 Redis 读取到的 Pickle 序列化数据 (前 20 字节):\")\n",
|
|||
|
|
" print(retrieved_serialized[:20])\n",
|
|||
|
|
"\n",
|
|||
|
|
" # 4. 使用 Pickle 反序列化回 Pandas DataFrame\n",
|
|||
|
|
" loaded_df = pickle.loads(retrieved_serialized)\n",
|
|||
|
|
" print(\"\\n从 Redis 加载的 DataFrame (使用 Pickle):\")\n",
|
|||
|
|
" print(loaded_df)\n",
|
|||
|
|
"\n",
|
|||
|
|
" # 5. 验证原始 DataFrame 和加载的 DataFrame 是否一致\n",
|
|||
|
|
" if save_df.equals(loaded_df):\n",
|
|||
|
|
" print(\"\\n验证成功:原始 DataFrame 和从 Redis 加载的 DataFrame 一致。\")\n",
|
|||
|
|
" else:\n",
|
|||
|
|
" print(\"\\n验证失败:原始 DataFrame 和从 Redis 加载的 DataFrame 不一致!\")\n",
|
|||
|
|
"\n",
|
|||
|
|
" else:\n",
|
|||
|
|
" print(f\"错误:无法从 Redis 获取键 '{redis_key}' 的值。\")\n",
|
|||
|
|
"\n",
|
|||
|
|
" # 6. 清理测试数据 (可选)\n",
|
|||
|
|
" r.delete(redis_key)\n",
|
|||
|
|
" print(f\"\\n清理了 Redis 中键 '{redis_key}' 的数据。\")\n",
|
|||
|
|
"\n",
|
|||
|
|
"except redis.exceptions.ConnectionError as e:\n",
|
|||
|
|
" print(f\"无法连接到 Redis 服务器: {e}\")\n",
|
|||
|
|
" print(\"请确保您的 Redis 服务器已启动并且主机和端口配置正确。\")\n",
|
|||
|
|
"except redis.exceptions.TimeoutError as e:\n",
|
|||
|
|
" print(f\"连接 Redis 服务器超时: {e}\")\n",
|
|||
|
|
" print(\"请检查您的网络连接和 Redis 服务器状态。\")\n",
|
|||
|
|
"except Exception as e:\n",
|
|||
|
|
" print(f\"测试 Redis 时发生未知错误: {e}\")\n",
|
|||
|
|
" print(f\"测试 Redis 时发生未知错误: {e}\")"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"execution_count": null,
|
|||
|
|
"id": "09b1799e",
|
|||
|
|
"metadata": {},
|
|||
|
|
"outputs": [
|
|||
|
|
{
|
|||
|
|
"name": "stdout",
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"text": [
|
|||
|
|
"205\n",
|
|||
|
|
"['vol', 'pct_chg', 'turnover_rate', 'volume_ratio', 'winner_rate', 'holder_net_change_sum_10d', 'holder_increase_days_10d', 'holder_decrease_days_10d', 'holder_any_increase_flag_10d', 'holder_any_decrease_flag_10d', 'cat_senti_mom_vol_spike', 'cat_senti_pre_breakout', 'ts_turnover_rate_acceleration_5_20', 'ts_vol_sustain_10_30', 'cs_amount_outlier_10', 'ts_ff_to_total_turnover_ratio', 'ts_price_volume_trend_coherence_5_20', 'ts_ff_turnover_rate_surge_10', 'undist_profit_ps', 'ocfps', 'AR', 'BR', 'AR_BR', 'log_circ_mv', 'cashflow_to_ev_factor', 'book_to_price_ratio', 'turnover_rate_mean_5', 'variance_20', 'bbi_ratio_factor', 'daily_deviation', 'lg_elg_net_buy_vol', 'flow_lg_elg_intensity', 'sm_net_buy_vol', 'total_buy_vol', 'lg_elg_buy_prop', 'flow_struct_buy_change', 'lg_elg_net_buy_vol_change', 'flow_lg_elg_accel', 'chip_concentration_range', 'chip_skewness', 'floating_chip_proxy', 'cost_support_15pct_change', 'cat_winner_price_zone', 'flow_chip_consistency', 'profit_taking_vs_absorb', 'cat_is_positive', 'upside_vol', 'downside_vol', 'vol_ratio', 'return_skew', 'return_kurtosis', 'volume_change_rate', 'cat_volume_breakout', 'turnover_deviation', 'cat_turnover_spike', 'avg_volume_ratio', 'cat_volume_ratio_breakout', 'vol_spike', 'vol_std_5', 'atr_14', 'atr_6', 'obv', 'maobv_6', 'rsi_3', 'return_5', 'return_20', 'std_return_5', 'std_return_90', 'std_return_90_2', 'act_factor1', 'act_factor2', 'act_factor3', 'act_factor4', 'rank_act_factor1', 'rank_act_factor2', 'rank_act_factor3', 'cov', 'delta_cov', 'alpha_22_improved', 'alpha_003', 'alpha_007', 'alpha_013', 'vol_break', 'weight_roc5', 'smallcap_concentration', 'cost_stability', 'high_cost_break_days', 'liquidity_risk', 'turnover_std', 'mv_volatility', 'volume_growth', 'mv_growth', 'momentum_factor', 'resonance_factor', 'log_close', 'cat_vol_spike', 'up', 'down', 'obv_maobv_6', 'std_return_5_over_std_return_90', 'std_return_90_minus_std_return_90_2', 'cat_af2', 'cat_af3', 'cat_af4', 'act_factor5', 'act_factor6', 
'active_buy_volume_large', 'active_buy_volume_big', 'active_buy_volume_small', 'buy_lg_vol_minus_sell_lg_vol', 'buy_elg_vol_minus_sell_elg_vol', 'ctrl_strength', 'low_cost_dev', 'asymmetry', 'lock_factor', 'cat_vol_break', 'cost_atr_adj', 'cat_golden_resonance', 'mv_turnover_ratio', 'mv_adjusted_volume', 'mv_weighted_turnover', 'nonlinear_mv_volume', 'mv_volume_ratio', 'mv_momentum', 'senti_strong_inflow', 'lg_flow_mom_corr_20_60', 'lg_flow_accel', 'profit_pressure', 'underwater_resistance', 'cost_conc_std_20', 'profit_decay_20', 'vol_amp_loss_20', 'vol_drop_profit_cnt_5', 'lg_flow_vol_interact_20', 'cost_break_confirm_cnt_5', 'atr_norm_channel_pos_14', 'turnover_diff_skew_20', 'lg_sm_flow_diverge_20', 'pullback_strong_20_20', 'vol_wgt_hist_pos_20', 'vol_adj_roc_20', 'cs_rank_net_lg_flow_val', 'cs_rank_elg_buy_ratio', 'cs_rank_rel_profit_margin', 'cs_rank_cost_breadth', 'cs_rank_dist_to_upper_cost', 'cs_rank_winner_rate', 'cs_rank_intraday_range', 'cs_rank_close_pos_in_range', 'cs_rank_pos_in_hist_range', 'cs_rank_vol_x_profit_margin', 'cs_rank_lg_flow_price_concordance', 'cs_rank_turnover_per_winner', 'cs_rank_volume_ratio', 'cs_rank_elg_buy_sell_sm_ratio', 'cs_rank_cost_dist_vol_ratio', 'cs_rank_size', 'industry_obv', 'industry_return_5', 'industry_return_20', 'industry__ema_5', 'industry__ema_13', 'industry__ema_20', 'industry__ema_60', 'industry_act_factor1', 'industry_act_factor2', 'industry_act_factor3', 'industry_act_factor4', 'industry_act_factor5', 'industry_act_factor6', 'industry_rank_act_factor1', 'industry_rank_act_factor2', 'industry_rank_act_factor3', 'industry_return_5_percentile', 'industry_return_20_percentile', '000852.SH_MACD', '000905.SH_MACD', '399006.SZ_MACD', '000852.SH_MACD_hist', '000905.SH_MACD_hist', '399006.SZ_MACD_hist', '000852.SH_RSI', '000905.SH_RSI', '399006.SZ_RSI', '000852.SH_Signal_line', '000905.SH_Signal_line', '399006.SZ_Signal_line', '000852.SH_amount_change_rate', '000905.SH_amount_change_rate', 
'399006.SZ_amount_change_rate', '000852.SH_amount_mean', '000905.SH_amount_mean', '399006.SZ_amount_mean', '000852.SH_da
|
|||
|
|
]
|
|||
|
|
}
|
|||
|
|
],
|
|||
|
|
"source": [
|
|||
|
|
"print(len(feature_columns))\n",
|
|||
|
|
"print(feature_columns)"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"execution_count": null,
|
|||
|
|
"id": "bceabd1f",
|
|||
|
|
"metadata": {},
|
|||
|
|
"outputs": [
|
|||
|
|
{
|
|||
|
|
"name": "stdout",
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"text": [
|
|||
|
|
"警告: DataFrame 中没有 'group_id' 列。假设整个 DataFrame 是一个需要排序的组。\n",
|
|||
|
|
"\n",
|
|||
|
|
"NDCG 结果\n",
|
|||
|
|
"{'ndcg@1': np.float64(0.4489795918367347), 'ndcg@3': np.float64(0.40668217598446815), 'ndcg@5': np.float64(0.45584495629735)}\n"
|
|||
|
|
]
|
|||
|
|
}
|
|||
|
|
],
|
|||
|
|
"source": [
|
|||
|
|
"import pandas as pd\n",
|
|||
|
|
"import numpy as np\n",
|
|||
|
|
"\n",
|
|||
|
|
"def calculate_ndcg(df: pd.DataFrame, score_col: str, label_col: str, group_id: str = 'trade_date', k_values: list = [1, 3, 5, 10]):\n",
|
|||
|
|
" \"\"\"\n",
|
|||
|
|
" 计算 DataFrame 中 score 列和 label 列的 NDCG 值。\n",
|
|||
|
|
"\n",
|
|||
|
|
" Args:\n",
|
|||
|
|
" df (pd.DataFrame): 包含 score (排序学习预测分数) 和 label (相关性标签) 的 DataFrame。\n",
|
|||
|
|
" 假设每个需要排序的组(例如,每天的股票)在 DataFrame 中是连续的。\n",
|
|||
|
|
" score_col (str): 包含模型预测分数的列名。\n",
|
|||
|
|
" label_col (str): 包含相关性标签的列名。标签值越高表示相关性越高。\n",
|
|||
|
|
" k_values (list): 一个整数列表,表示计算 NDCG 的 top-k 值。\n",
|
|||
|
|
" 例如,[1, 3, 5] 将计算 NDCG@1, NDCG@3 和 NDCG@5。\n",
|
|||
|
|
"\n",
|
|||
|
|
" Returns:\n",
|
|||
|
|
" dict: 一个字典,包含每个 k 值对应的平均 NDCG 值。\n",
|
|||
|
|
" 例如: {'ndcg@1': 0.85, 'ndcg@3': 0.78, 'ndcg@5': 0.72, 'ndcg@10': 0.65}\n",
|
|||
|
|
" \"\"\"\n",
|
|||
|
|
" ndcg_scores = {f'ndcg@{k}': [] for k in k_values}\n",
|
|||
|
|
"\n",
|
|||
|
|
" def dcg_at_k(r, k):\n",
|
|||
|
|
" r = np.asarray(r)[:k] if len(r) > 0 else np.zeros(k)\n",
|
|||
|
|
" return np.sum(r / np.log2(np.arange(2, r.size + 2)))\n",
|
|||
|
|
"\n",
|
|||
|
|
" def ndcg_at_k(r, k):\n",
|
|||
|
|
" dcg_max = dcg_at_k(sorted(r, reverse=True), k)\n",
|
|||
|
|
" if not dcg_max:\n",
|
|||
|
|
" return 0.\n",
|
|||
|
|
" return dcg_at_k(r, k) / dcg_max\n",
|
|||
|
|
"\n",
|
|||
|
|
" # 假设 DataFrame 已经按照需要排序的组(例如,'trade_date')进行了分组,\n",
|
|||
|
|
" # 并且每个组内的顺序不重要,我们只需要计算每个组的 NDCG。\n",
|
|||
|
|
" # 如果需要按特定组计算 NDCG,请先对 DataFrame 进行分组。\n",
|
|||
|
|
" if group_id not in df.columns:\n",
|
|||
|
|
" print(\"警告: DataFrame 中没有 'group_id' 列。假设整个 DataFrame 是一个需要排序的组。\")\n",
|
|||
|
|
" group_df = df.sort_values(by=score_col, ascending=False)\n",
|
|||
|
|
" relevant_labels = group_df[label_col].values\n",
|
|||
|
|
" for k in k_values:\n",
|
|||
|
|
" ndcg_scores[f'ndcg@{k}'].append(ndcg_at_k(relevant_labels, k))\n",
|
|||
|
|
" else:\n",
|
|||
|
|
" for _, group_df in df.groupby(group_id):\n",
|
|||
|
|
" group_df_sorted = group_df.sort_values(by=score_col, ascending=False)\n",
|
|||
|
|
" relevant_labels = group_df_sorted[label_col].values\n",
|
|||
|
|
" for k in k_values:\n",
|
|||
|
|
" ndcg_scores[f'ndcg@{k}'].append(ndcg_at_k(relevant_labels, k))\n",
|
|||
|
|
"\n",
|
|||
|
|
" avg_ndcg = {k: np.mean(v) if v else np.nan for k, v in ndcg_scores.items()}\n",
|
|||
|
|
" return avg_ndcg\n",
|
|||
|
|
"\n",
|
|||
|
|
"\n",
|
|||
|
|
"ndcg_results_single_group = calculate_ndcg(score_df, score_col='score', label_col='label', k_values=[1, 3, 5], group_id=None)\n",
|
|||
|
|
"print(\"\\nNDCG 结果\")\n",
|
|||
|
|
"print(ndcg_results_single_group)\n"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"execution_count": null,
|
|||
|
|
"id": "44f64679",
|
|||
|
|
"metadata": {},
|
|||
|
|
"outputs": [
|
|||
|
|
{
|
|||
|
|
"name": "stdout",
|
|||
|
|
"output_type": "stream",
|
|||
|
|
"text": [
|
|||
|
|
" ts_code trade_date open close high low vol pct_chg \\\n",
|
|||
|
|
"1632028 002652.SZ 2019-01-02 19.59 19.64 19.89 19.28 20196.79 1.03 \n",
|
|||
|
|
"1632029 002652.SZ 2019-01-03 19.74 19.44 19.84 19.33 15731.99 -1.02 \n",
|
|||
|
|
"1632030 002652.SZ 2019-01-04 19.33 19.94 19.99 19.08 21099.93 2.57 \n",
|
|||
|
|
"1632031 002652.SZ 2019-01-07 20.04 21.95 21.95 20.04 83534.19 10.08 \n",
|
|||
|
|
"1632032 002652.SZ 2019-01-08 23.21 21.65 23.87 21.65 149377.97 -1.37 \n",
|
|||
|
|
"... ... ... ... ... ... ... ... ... \n",
|
|||
|
|
"1633576 002652.SZ 2025-05-26 14.75 14.85 15.11 14.55 99560.80 1.02 \n",
|
|||
|
|
"1633577 002652.SZ 2025-05-27 14.90 15.00 15.11 14.70 101184.00 1.01 \n",
|
|||
|
|
"1633578 002652.SZ 2025-05-28 15.11 14.85 15.16 14.80 75859.20 -1.00 \n",
|
|||
|
|
"1633579 002652.SZ 2025-05-29 15.00 15.36 15.36 14.85 126044.40 3.43 \n",
|
|||
|
|
"1633580 002652.SZ 2025-05-30 15.36 15.11 15.41 14.95 107732.00 -1.63 \n",
|
|||
|
|
"\n",
|
|||
|
|
" amount turnover_rate ... cs_rank_vol_x_profit_margin \\\n",
|
|||
|
|
"1632028 7867.047 0.3964 ... 0.608839 \n",
|
|||
|
|
"1632029 6121.460 0.3088 ... 0.586710 \n",
|
|||
|
|
"1632030 8245.083 0.4141 ... 0.682847 \n",
|
|||
|
|
"1632031 35514.117 1.6394 ... 0.987591 \n",
|
|||
|
|
"1632032 67160.354 2.9317 ... 0.765693 \n",
|
|||
|
|
"... ... ... ... ... \n",
|
|||
|
|
"1633576 29428.560 1.9443 ... 0.652159 \n",
|
|||
|
|
"1633577 30112.801 1.9760 ... 0.657694 \n",
|
|||
|
|
"1633578 22507.876 1.4814 ... 0.664673 \n",
|
|||
|
|
"1633579 38068.857 2.4615 ... 0.921236 \n",
|
|||
|
|
"1633580 32385.927 2.1039 ... 0.702990 \n",
|
|||
|
|
"\n",
|
|||
|
|
" cs_rank_lg_flow_price_concordance cs_rank_turnover_per_winner \\\n",
|
|||
|
|
"1632028 0.203142 0.864865 \n",
|
|||
|
|
"1632029 0.156684 0.763417 \n",
|
|||
|
|
"1632030 0.184009 0.660949 \n",
|
|||
|
|
"1632031 0.734940 0.700000 \n",
|
|||
|
|
"1632032 0.874042 0.914234 \n",
|
|||
|
|
"... ... ... \n",
|
|||
|
|
"1633576 0.122259 0.394684 \n",
|
|||
|
|
"1633577 0.092722 0.414756 \n",
|
|||
|
|
"1633578 0.684945 0.323363 \n",
|
|||
|
|
"1633579 0.295779 0.390828 \n",
|
|||
|
|
"1633580 0.705316 0.419934 \n",
|
|||
|
|
"\n",
|
|||
|
|
" cs_rank_ind_cap_neutral_pe cs_rank_volume_ratio \\\n",
|
|||
|
|
"1632028 NaN 0.646930 \n",
|
|||
|
|
"1632029 NaN 0.251279 \n",
|
|||
|
|
"1632030 NaN 0.311724 \n",
|
|||
|
|
"1632031 NaN 0.988313 \n",
|
|||
|
|
"1632032 NaN 0.990142 \n",
|
|||
|
|
"... ... ... \n",
|
|||
|
|
"1633576 NaN 0.400997 \n",
|
|||
|
|
"1633577 NaN 0.450150 \n",
|
|||
|
|
"1633578 NaN 0.199236 \n",
|
|||
|
|
"1633579 NaN 0.640744 \n",
|
|||
|
|
"1633580 NaN 0.537542 \n",
|
|||
|
|
"\n",
|
|||
|
|
" cs_rank_elg_buy_sell_sm_ratio cs_rank_cost_dist_vol_ratio \\\n",
|
|||
|
|
"1632028 0.341855 0.678941 \n",
|
|||
|
|
"1632029 0.318912 0.402916 \n",
|
|||
|
|
"1632030 0.260036 0.460713 \n",
|
|||
|
|
"1632031 0.796350 0.988501 \n",
|
|||
|
|
"1632032 0.598905 0.991571 \n",
|
|||
|
|
"... ... ... \n",
|
|||
|
|
"1633576 0.153987 0.620930 \n",
|
|||
|
|
"1633577 0.156198 0.643403 \n",
|
|||
|
|
"1633578 0.153373 0.484546 \n",
|
|||
|
|
"1633579 0.623795 0.764374 \n",
|
|||
|
|
"1633580 0.133056 0.703987 \n",
|
|||
|
|
"\n",
|
|||
|
|
" cs_rank_size future_return label \n",
|
|||
|
|
"1632028 0.258948 0.158859 40.0 \n",
|
|||
|
|
"1632029 0.258123 0.136831 37.0 \n",
|
|||
|
|
"1632030 0.257664 0.106319 39.0 \n",
|
|||
|
|
"1632031 0.290146 -0.072893 4.0 \n",
|
|||
|
|
"1632032 0.282482 -0.057737 5.0 \n",
|
|||
|
|
"... ... ... ... \n",
|
|||
|
|
"1633576 0.032226 NaN NaN \n",
|
|||
|
|
"1633577 0.032901 NaN NaN \n",
|
|||
|
|
"1633578 0.032237 NaN NaN \n",
|
|||
|
|
"1633579 0.034231 NaN NaN \n",
|
|||
|
|
"1633580 0.033887 NaN NaN \n",
|
|||
|
|
"\n",
|
|||
|
|
"[1553 rows x 196 columns]\n"
|
|||
|
|
]
|
|||
|
|
}
|
|||
|
|
],
|
|||
|
|
"source": [
|
|||
|
|
"print(df[df['ts_code'] == '002652.SZ'])"
|
|||
|
|
]
|
|||
|
|
}
|
|||
|
|
],
|
|||
|
|
"metadata": {
|
|||
|
|
"kernelspec": {
|
|||
|
|
"display_name": "stock",
|
|||
|
|
"language": "python",
|
|||
|
|
"name": "python3"
|
|||
|
|
},
|
|||
|
|
"language_info": {
|
|||
|
|
"codemirror_mode": {
|
|||
|
|
"name": "ipython",
|
|||
|
|
"version": 3
|
|||
|
|
},
|
|||
|
|
"file_extension": ".py",
|
|||
|
|
"mimetype": "text/x-python",
|
|||
|
|
"name": "python",
|
|||
|
|
"nbconvert_exporter": "python",
|
|||
|
|
"pygments_lexer": "ipython3",
|
|||
|
|
"version": "3.13.2"
|
|||
|
|
}
|
|||
|
|
},
|
|||
|
|
"nbformat": 4,
|
|||
|
|
"nbformat_minor": 5
|
|||
|
|
}
|