2025-05-26 21:34:36 +08:00
{
"cells": [
{
"cell_type": "code",
2025-06-10 15:22:25 +08:00
"execution_count": 1,
2025-05-26 21:34:36 +08:00
"id": "79a7758178bafdd3",
"metadata": {
"ExecuteTime": {
"end_time": "2025-04-03T12:46:06.987506Z",
"start_time": "2025-04-03T12:46:06.259551Z"
},
"jupyter": {
"source_hidden": true
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2025-06-04 13:50:02 +08:00
"/mnt/d/PyProject/NewStock\n"
2025-05-26 21:34:36 +08:00
]
}
],
"source": [
"%load_ext autoreload\n",
"%autoreload 2\n",
"\n",
"import gc\n",
"import os\n",
"import sys\n",
2025-06-04 13:50:02 +08:00
"sys.path.append('/mnt/d/PyProject/NewStock/')\n",
2025-05-26 21:34:36 +08:00
"print(os.getcwd())\n",
"import pandas as pd\n",
"from main.factor.factor import get_rolling_factor, get_simple_factor\n",
"from main.utils.factor import read_industry_data\n",
"from main.utils.factor_processor import calculate_score\n",
"from main.utils.utils import read_and_merge_h5_data, merge_with_industry_data\n",
"\n",
"import warnings\n",
"\n",
"warnings.filterwarnings(\"ignore\")"
]
},
{
"cell_type": "code",
2025-06-10 15:22:25 +08:00
"execution_count": 2,
2025-05-26 21:34:36 +08:00
"id": "a79cafb06a7e0e43",
"metadata": {
"ExecuteTime": {
"end_time": "2025-04-03T12:47:00.212859Z",
"start_time": "2025-04-03T12:46:06.998047Z"
},
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2025-06-10 15:22:25 +08:00
"daily data\n",
"daily basic\n",
"inner merge on ['ts_code', 'trade_date']\n",
"stk limit\n",
"left merge on ['ts_code', 'trade_date']\n",
"money flow\n",
"left merge on ['ts_code', 'trade_date']\n",
"cyq perf\n",
"left merge on ['ts_code', 'trade_date']\n",
"<class 'pandas.core.frame.DataFrame'>\n",
"RangeIndex: 8713571 entries, 0 to 8713570\n",
"Data columns (total 33 columns):\n",
" # Column Dtype \n",
"--- ------ ----- \n",
" 0 ts_code object \n",
" 1 trade_date datetime64[ns]\n",
" 2 open float64 \n",
" 3 close float64 \n",
" 4 high float64 \n",
" 5 low float64 \n",
" 6 vol float64 \n",
" 7 pct_chg float64 \n",
" 8 amount float64 \n",
" 9 turnover_rate float64 \n",
" 10 pe_ttm float64 \n",
" 11 circ_mv float64 \n",
" 12 total_mv float64 \n",
" 13 volume_ratio float64 \n",
" 14 is_st bool \n",
" 15 up_limit float64 \n",
" 16 down_limit float64 \n",
" 17 buy_sm_vol float64 \n",
" 18 sell_sm_vol float64 \n",
" 19 buy_lg_vol float64 \n",
" 20 sell_lg_vol float64 \n",
" 21 buy_elg_vol float64 \n",
" 22 sell_elg_vol float64 \n",
" 23 net_mf_vol float64 \n",
" 24 his_low float64 \n",
" 25 his_high float64 \n",
" 26 cost_5pct float64 \n",
" 27 cost_15pct float64 \n",
" 28 cost_50pct float64 \n",
" 29 cost_85pct float64 \n",
" 30 cost_95pct float64 \n",
" 31 weight_avg float64 \n",
" 32 winner_rate float64 \n",
"dtypes: bool(1), datetime64[ns](1), float64(30), object(1)\n",
"memory usage: 2.1+ GB\n",
"None\n"
2025-05-26 21:34:36 +08:00
]
}
],
"source": [
"from main.utils.utils import read_and_merge_h5_data\n",
"\n",
"print('daily data')\n",
2025-06-04 13:50:02 +08:00
"df = read_and_merge_h5_data('/mnt/d/PyProject/NewStock/data/daily_data.h5', key='daily_data',\n",
2025-05-29 20:41:18 +08:00
" columns=['ts_code', 'trade_date', 'open', 'close', 'high', 'low', 'vol', 'pct_chg', 'amount'],\n",
2025-05-26 21:34:36 +08:00
" df=None)\n",
"\n",
"print('daily basic')\n",
2025-06-04 13:50:02 +08:00
"df = read_and_merge_h5_data('/mnt/d/PyProject/NewStock/data/daily_basic.h5', key='daily_basic',\n",
2025-05-26 21:34:36 +08:00
" columns=['ts_code', 'trade_date', 'turnover_rate', 'pe_ttm', 'circ_mv', 'total_mv', 'volume_ratio',\n",
" 'is_st'], df=df, join='inner')\n",
"\n",
"print('stk limit')\n",
2025-06-04 13:50:02 +08:00
"df = read_and_merge_h5_data('/mnt/d/PyProject/NewStock/data/stk_limit.h5', key='stk_limit',\n",
2025-05-26 21:34:36 +08:00
" columns=['ts_code', 'trade_date', 'pre_close', 'up_limit', 'down_limit'],\n",
" df=df)\n",
"print('money flow')\n",
2025-06-04 13:50:02 +08:00
"df = read_and_merge_h5_data('/mnt/d/PyProject/NewStock/data/money_flow.h5', key='money_flow',\n",
2025-05-26 21:34:36 +08:00
" columns=['ts_code', 'trade_date', 'buy_sm_vol', 'sell_sm_vol', 'buy_lg_vol', 'sell_lg_vol',\n",
" 'buy_elg_vol', 'sell_elg_vol', 'net_mf_vol'],\n",
" df=df)\n",
"print('cyq perf')\n",
2025-06-04 13:50:02 +08:00
"df = read_and_merge_h5_data('/mnt/d/PyProject/NewStock/data/cyq_perf.h5', key='cyq_perf',\n",
2025-05-26 21:34:36 +08:00
" columns=['ts_code', 'trade_date', 'his_low', 'his_high', 'cost_5pct', 'cost_15pct',\n",
" 'cost_50pct',\n",
" 'cost_85pct', 'cost_95pct', 'weight_avg', 'winner_rate'],\n",
" df=df)\n",
"print(df.info())"
]
},
{
"cell_type": "code",
2025-06-10 15:22:25 +08:00
"execution_count": 3,
2025-05-26 21:34:36 +08:00
"id": "cac01788dac10678",
"metadata": {
"ExecuteTime": {
"end_time": "2025-04-03T12:47:10.527104Z",
"start_time": "2025-04-03T12:47:00.488715Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"industry\n"
]
}
],
"source": [
"print('industry')\n",
2025-06-04 13:50:02 +08:00
"industry_df = read_and_merge_h5_data('/mnt/d/PyProject/NewStock/data/industry_data.h5', key='industry_data',\n",
2025-05-26 21:34:36 +08:00
" columns=['ts_code', 'l2_code', 'in_date'],\n",
" df=None, on=['ts_code'], join='left')\n",
"\n",
"\n",
"def merge_with_industry_data(df, industry_df):\n",
" # 确保日期字段是 datetime 类型\n",
" df['trade_date'] = pd.to_datetime(df['trade_date'])\n",
" industry_df['in_date'] = pd.to_datetime(industry_df['in_date'])\n",
"\n",
" # 对 industry_df 按 ts_code 和 in_date 排序\n",
" industry_df_sorted = industry_df.sort_values(['in_date', 'ts_code'])\n",
"\n",
" # 对原始 df 按 ts_code 和 trade_date 排序\n",
" df_sorted = df.sort_values(['trade_date', 'ts_code'])\n",
"\n",
" # 使用 merge_asof 进行向后合并\n",
" merged = pd.merge_asof(\n",
" df_sorted,\n",
" industry_df_sorted,\n",
" by='ts_code', # 按 ts_code 分组\n",
" left_on='trade_date',\n",
" right_on='in_date',\n",
" direction='backward'\n",
" )\n",
"\n",
" # 获取每个 ts_code 的最早 in_date 记录\n",
" min_in_date_per_ts = (industry_df_sorted\n",
" .groupby('ts_code')\n",
" .first()\n",
" .reset_index()[['ts_code', 'l2_code']])\n",
"\n",
" # 填充未匹配到的记录( trade_date 早于所有 in_date 的情况)\n",
" merged['l2_code'] = merged['l2_code'].fillna(\n",
" merged['ts_code'].map(min_in_date_per_ts.set_index('ts_code')['l2_code'])\n",
" )\n",
"\n",
" # 保留需要的列并重置索引\n",
" result = merged.reset_index(drop=True)\n",
" return result\n",
"\n",
"\n",
"# 使用示例\n",
"df = merge_with_industry_data(df, industry_df)\n",
"# print(mdf[mdf['ts_code'] == '600751.SH'][['ts_code', 'trade_date', 'l2_code']])"
]
},
{
"cell_type": "code",
2025-06-10 15:22:25 +08:00
"execution_count": 4,
2025-05-26 21:34:36 +08:00
"id": "c4e9e1d31da6dba6",
"metadata": {
"ExecuteTime": {
"end_time": "2025-04-03T12:47:10.719252Z",
"start_time": "2025-04-03T12:47:10.541247Z"
},
"jupyter": {
"source_hidden": true
}
},
"outputs": [],
"source": [
"from main.factor.factor import *\n",
"\n",
"def calculate_indicators(df):\n",
" \"\"\"\n",
" 计算四个指标: 当日涨跌幅、5日移动平均、RSI、MACD。\n",
" \"\"\"\n",
" df = df.sort_values('trade_date')\n",
" df['daily_return'] = (df['close'] - df['pre_close']) / df['pre_close'] * 100\n",
" # df['5_day_ma'] = df['close'].rolling(window=5).mean()\n",
" delta = df['close'].diff()\n",
" gain = delta.where(delta > 0, 0)\n",
" loss = -delta.where(delta < 0, 0)\n",
" avg_gain = gain.rolling(window=14).mean()\n",
" avg_loss = loss.rolling(window=14).mean()\n",
" rs = avg_gain / avg_loss\n",
" df['RSI'] = 100 - (100 / (1 + rs))\n",
"\n",
" # 计算MACD\n",
" ema12 = df['close'].ewm(span=12, adjust=False).mean()\n",
" ema26 = df['close'].ewm(span=26, adjust=False).mean()\n",
" df['MACD'] = ema12 - ema26\n",
" df['Signal_line'] = df['MACD'].ewm(span=9, adjust=False).mean()\n",
" df['MACD_hist'] = df['MACD'] - df['Signal_line']\n",
"\n",
" # 4. 情绪因子1: 市场上涨比例( Up Ratio) \n",
" df['up_ratio'] = df['daily_return'].apply(lambda x: 1 if x > 0 else 0)\n",
" df['up_ratio_20d'] = df['up_ratio'].rolling(window=20).mean() # 过去20天上涨比例\n",
"\n",
" # 5. 情绪因子2: 成交量变化率( Volume Change Rate) \n",
" df['volume_mean'] = df['vol'].rolling(window=20).mean() # 过去20天的平均成交量\n",
" df['volume_change_rate'] = (df['vol'] - df['volume_mean']) / df['volume_mean'] * 100 # 成交量变化率\n",
"\n",
" # 6. 情绪因子3: 波动率( Volatility) \n",
" df['volatility'] = df['daily_return'].rolling(window=20).std() # 过去20天的日收益率标准差\n",
"\n",
" # 7. 情绪因子4: 成交额变化率( Amount Change Rate) \n",
" df['amount_mean'] = df['amount'].rolling(window=20).mean() # 过去20天的平均成交额\n",
" df['amount_change_rate'] = (df['amount'] - df['amount_mean']) / df['amount_mean'] * 100 # 成交额变化率\n",
"\n",
" # df = sentiment_panic_greed_index(df)\n",
" # df = sentiment_market_breadth_proxy(df)\n",
" # df = sentiment_reversal_indicator(df)\n",
"\n",
" return df\n",
"\n",
"\n",
"def generate_index_indicators(h5_filename):\n",
" df = pd.read_hdf(h5_filename, key='index_data')\n",
" df['trade_date'] = pd.to_datetime(df['trade_date'], format='%Y%m%d')\n",
" df = df.sort_values('trade_date')\n",
"\n",
" # 计算每个ts_code的相关指标\n",
" df_indicators = []\n",
" for ts_code in df['ts_code'].unique():\n",
" df_index = df[df['ts_code'] == ts_code].copy()\n",
" df_index = calculate_indicators(df_index)\n",
" df_indicators.append(df_index)\n",
"\n",
" # 合并所有指数的结果\n",
" df_all_indicators = pd.concat(df_indicators, ignore_index=True)\n",
"\n",
" # 保留trade_date列, 并将同一天的数据按ts_code合并成一行\n",
" df_final = df_all_indicators.pivot_table(\n",
" index='trade_date',\n",
" columns='ts_code',\n",
" values=['daily_return', \n",
" 'RSI', 'MACD', 'Signal_line', 'MACD_hist', \n",
" # 'sentiment_panic_greed_index',\n",
" 'up_ratio_20d', 'volume_change_rate', 'volatility',\n",
" 'amount_change_rate', 'amount_mean'],\n",
" aggfunc='last'\n",
" )\n",
"\n",
" df_final.columns = [f\"{col[1]}_{col[0]}\" for col in df_final.columns]\n",
" df_final = df_final.reset_index()\n",
"\n",
" return df_final\n",
"\n",
"\n",
"# 使用函数\n",
2025-06-04 13:50:02 +08:00
"h5_filename = '/mnt/d/PyProject/NewStock/data/index_data.h5'\n",
2025-05-26 21:34:36 +08:00
"index_data = generate_index_indicators(h5_filename)\n",
"index_data = index_data.dropna()\n"
]
},
{
"cell_type": "code",
2025-06-10 15:22:25 +08:00
"execution_count": 5,
2025-05-26 21:34:36 +08:00
"id": "a735bc02ceb4d872",
"metadata": {
"ExecuteTime": {
"end_time": "2025-04-03T12:47:10.821169Z",
"start_time": "2025-04-03T12:47:10.751831Z"
}
},
"outputs": [],
"source": [
"import talib\n",
"import numpy as np"
]
},
{
"cell_type": "code",
2025-06-10 15:22:25 +08:00
"execution_count": 6,
2025-05-26 21:34:36 +08:00
"id": "53f86ddc0677a6d7",
"metadata": {
"ExecuteTime": {
"end_time": "2025-04-03T12:47:15.944254Z",
"start_time": "2025-04-03T12:47:10.826179Z"
},
"jupyter": {
"source_hidden": true
},
"scrolled": true
},
"outputs": [],
"source": [
"from main.utils.factor import get_act_factor\n",
"\n",
"\n",
"def read_industry_data(h5_filename):\n",
" # 读取 H5 文件中所有的行业数据\n",
" industry_data = pd.read_hdf(h5_filename, key='sw_daily', columns=[\n",
" 'ts_code', 'trade_date', 'open', 'close', 'high', 'low', 'pe', 'pb', 'vol'\n",
" ]) # 假设 H5 文件的键是 'industry_data'\n",
" industry_data = industry_data.sort_values(by=['ts_code', 'trade_date'])\n",
" industry_data = industry_data.reindex()\n",
" industry_data['trade_date'] = pd.to_datetime(industry_data['trade_date'], format='%Y%m%d')\n",
"\n",
" grouped = industry_data.groupby('ts_code', group_keys=False)\n",
" industry_data['obv'] = grouped.apply(\n",
" lambda x: pd.Series(talib.OBV(x['close'].values, x['vol'].values), index=x.index)\n",
" )\n",
" industry_data['return_5'] = grouped['close'].apply(lambda x: x / x.shift(5) - 1)\n",
" industry_data['return_20'] = grouped['close'].apply(lambda x: x / x.shift(20) - 1)\n",
"\n",
" industry_data = get_act_factor(industry_data, cat=False)\n",
" industry_data = industry_data.sort_values(by=['trade_date', 'ts_code'])\n",
"\n",
" # # 计算每天每个 ts_code 的因子和当天所有 ts_code 的中位数的偏差\n",
" # factor_columns = ['obv', 'return_5', 'return_20', 'act_factor1', 'act_factor2', 'act_factor3', 'act_factor4'] # 因子列\n",
" # \n",
" # for factor in factor_columns:\n",
" # if factor in industry_data.columns:\n",
" # # 计算每天每个 ts_code 的因子值与当天所有 ts_code 的中位数的偏差\n",
" # industry_data[f'{factor}_deviation'] = industry_data.groupby('trade_date')[factor].transform(\n",
" # lambda x: x - x.mean())\n",
"\n",
" industry_data['return_5_percentile'] = industry_data.groupby('trade_date')['return_5'].transform(\n",
" lambda x: x.rank(pct=True))\n",
" industry_data['return_20_percentile'] = industry_data.groupby('trade_date')['return_20'].transform(\n",
" lambda x: x.rank(pct=True))\n",
"\n",
" # cs_rank_intraday_range(industry_data)\n",
" # cs_rank_close_pos_in_range(industry_data)\n",
"\n",
" industry_data = industry_data.drop(columns=['open', 'close', 'high', 'low', 'pe', 'pb', 'vol'])\n",
"\n",
" industry_data = industry_data.rename(\n",
" columns={col: f'industry_{col}' for col in industry_data.columns if col not in ['ts_code', 'trade_date']})\n",
"\n",
" industry_data = industry_data.rename(columns={'ts_code': 'cat_l2_code'})\n",
" return industry_data\n",
"\n",
"\n",
2025-06-04 13:50:02 +08:00
"industry_df = read_industry_data('/mnt/d/PyProject/NewStock/data/sw_daily.h5')\n"
2025-05-26 21:34:36 +08:00
]
},
{
"cell_type": "code",
2025-06-10 15:22:25 +08:00
"execution_count": 7,
2025-05-26 21:34:36 +08:00
"id": "dbe2fd8021b9417f",
"metadata": {
"ExecuteTime": {
"end_time": "2025-04-03T12:47:15.969344Z",
"start_time": "2025-04-03T12:47:15.963327Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2025-05-29 20:41:18 +08:00
"['ts_code', 'open', 'close', 'high', 'low', 'amount', 'circ_mv', 'total_mv', 'is_st', 'up_limit', 'down_limit', 'buy_sm_vol', 'sell_sm_vol', 'buy_lg_vol', 'sell_lg_vol', 'buy_elg_vol', 'sell_elg_vol', 'net_mf_vol', 'his_low', 'his_high', 'cost_5pct', 'cost_15pct', 'cost_50pct', 'cost_85pct', 'cost_95pct', 'weight_avg', 'in_date']\n"
2025-05-26 21:34:36 +08:00
]
}
],
"source": [
"origin_columns = df.columns.tolist()\n",
"origin_columns = [col for col in origin_columns if\n",
" col not in ['turnover_rate', 'pe_ttm', 'volume_ratio', 'vol', 'pct_chg', 'l2_code', 'winner_rate']]\n",
"origin_columns = [col for col in origin_columns if col not in index_data.columns]\n",
"origin_columns = [col for col in origin_columns if 'cyq' not in col]\n",
"print(origin_columns)"
]
},
{
"cell_type": "code",
2025-06-10 15:22:25 +08:00
"execution_count": 8,
2025-05-26 21:34:36 +08:00
"id": "85c3e3d0235ffffa",
"metadata": {
"ExecuteTime": {
"end_time": "2025-04-03T12:47:16.089879Z",
"start_time": "2025-04-03T12:47:15.990101Z"
}
},
"outputs": [],
"source": [
2025-06-04 13:50:02 +08:00
"fina_indicator_df = read_and_merge_h5_data('/mnt/d/PyProject/NewStock/data/fina_indicator.h5', key='fina_indicator',\n",
2025-05-26 21:34:36 +08:00
" columns=['ts_code', 'ann_date', 'undist_profit_ps', 'ocfps', 'bps'],\n",
" df=None)\n",
2025-06-04 13:50:02 +08:00
"cashflow_df = read_and_merge_h5_data('/mnt/d/PyProject/NewStock/data/cashflow.h5', key='cashflow',\n",
2025-05-26 21:34:36 +08:00
" columns=['ts_code', 'ann_date', 'n_cashflow_act'],\n",
" df=None)\n",
2025-06-04 13:50:02 +08:00
"balancesheet_df = read_and_merge_h5_data('/mnt/d/PyProject/NewStock/data/balancesheet.h5', key='balancesheet',\n",
2025-05-26 21:34:36 +08:00
" columns=['ts_code', 'ann_date', 'money_cap', 'total_liab'],\n",
" df=None)\n",
2025-06-04 13:50:02 +08:00
"top_list_df = read_and_merge_h5_data('/mnt/d/PyProject/NewStock/data/top_list.h5', key='top_list',\n",
2025-05-26 21:34:36 +08:00
" columns=['ts_code', 'trade_date', 'reason'],\n",
" df=None)\n",
"\n",
2025-06-04 13:50:02 +08:00
"top_list_df = top_list_df.sort_values(by='trade_date', ascending=False).drop_duplicates(subset=['ts_code', 'trade_date'], keep='first').sort_values(by='trade_date')\n",
"\n",
"stk_holdertrade_df = read_and_merge_h5_data('/mnt/d/PyProject/NewStock/data/stk_holdertrade.h5', key='stk_holdertrade',\n",
" columns=['ts_code', 'ann_date', 'in_de', 'change_ratio'],\n",
" df=None)"
2025-05-26 21:34:36 +08:00
]
},
{
"cell_type": "code",
2025-06-10 15:22:25 +08:00
"execution_count": 9,
2025-06-06 17:04:01 +08:00
"id": "ba5935c8",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"✅ 成功从 Redis Hash 'concept_stocks_daily_lists_pickle' 读取 1794 条每日概念股票数据。\n"
]
}
],
"source": [
"import redis\n",
"import pickle\n",
"from datetime import date, datetime\n",
"\n",
"# --- 配置 Redis 连接 ---\n",
"REDIS_HOST = '140.143.91.66'\n",
"REDIS_PORT = 6389\n",
"REDIS_DB = 0\n",
"\n",
"# --- 定义 Redis 键名 ---\n",
"HASH_KEY = \"concept_stocks_daily_lists_pickle\" # 区分之前的 JSON 版本\n",
"MAX_DATE_KEY = \"concept_stocks_max_date_pickle\" # 区分之前的 JSON 版本\n",
"\n",
"concept_dict = {}\n",
"\n",
"# --- 连接 Redis ---\n",
"try:\n",
" r = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, password='Redis520102')\n",
"\n",
" all_data_from_redis = r.hgetall(HASH_KEY) # 返回的是字典,键是字节,值是字节\n",
" \n",
" if all_data_from_redis:\n",
" for date_bytes, stocks_bytes in all_data_from_redis.items(): # 将变量名改为 date_bytes 更清晰\n",
" try:\n",
" # *** 修正点:将日期字节解码为字符串 ***\n",
" date_str = date_bytes.decode('utf-8') \n",
" date_obj = datetime.strptime(date_str, '%Y%m%d').date()\n",
" \n",
" stocks_list = pickle.loads(stocks_bytes)\n",
" concept_dict[date_obj] = stocks_list\n",
" except (ValueError, pickle.UnpicklingError) as e:\n",
" print(f\"⚠️ 警告: 解析 Redis 数据时出错 (日期键: '{date_bytes.decode('utf-8', errors='ignore')}'),跳过此条数据: {e}\") # 打印警告时也解码一下\n",
" print(f\"✅ 成功从 Redis Hash '{HASH_KEY}' 读取 {len(concept_dict)} 条每日概念股票数据。\")\n",
" else:\n",
" print(f\"ℹ ️ Redis Hash '{HASH_KEY}' 中没有找到任何数据。\")\n",
"\n",
"except redis.exceptions.ConnectionError as e:\n",
" print(f\"❌ 错误: 无法连接到 Redis 服务器,请检查 Redis 是否正在运行或连接配置: {e}\")\n",
"except Exception as e:\n",
" print(f\"❌ 从 Redis 读取数据时发生未知错误: {e}\")"
]
},
{
"cell_type": "code",
2025-06-10 15:22:25 +08:00
"execution_count": 10,
2025-05-26 21:34:36 +08:00
"id": "92d84ce15a562ec6",
"metadata": {
"ExecuteTime": {
"end_time": "2025-04-03T13:08:01.612695Z",
"start_time": "2025-04-03T12:47:16.121802Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2025-06-10 15:22:25 +08:00
"4566757\n",
2025-06-06 17:04:01 +08:00
"开始生成概念相关因子...\n",
"开始计算概念内截面排序因子,基于: ['pct_chg', 'turnover_rate', 'volume_ratio']\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
2025-06-10 15:22:25 +08:00
"Ranking Features in Concepts: 100%|██████████| 3/3 [00:00<00:00, 15.60it/s]\n"
2025-06-06 17:04:01 +08:00
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"概念相关因子生成完毕。\n",
2025-06-10 15:22:25 +08:00
"4566757\n",
2025-06-04 20:34:17 +08:00
"开始计算股东增减持因子...\n",
"警告: 'in_de' 列中存在未映射的值,可能导致 _direction 列出现NaN。\n",
"股东增减持因子计算完成。\n",
"Calculating cat_senti_mom_vol_spike...\n",
"Finished cat_senti_mom_vol_spike.\n",
"Calculating cat_senti_pre_breakout...\n",
"Calculating atr_10 as it's missing...\n",
"Calculating atr_40 as it's missing...\n",
"Finished cat_senti_pre_breakout.\n",
"计算因子 ts_turnover_rate_acceleration_5_20\n",
"计算因子 ts_vol_sustain_10_30\n",
"计算因子 cs_amount_outlier_10\n",
"计算因子 ts_ff_to_total_turnover_ratio\n",
"计算因子 ts_price_volume_trend_coherence_5_20\n",
"计算因子 ts_ff_turnover_rate_surge_10\n",
"使用 'ann_date' 作为财务数据生效日期。\n",
"警告: 从 financial_data_subset 中移除了 366 行,因为其 'ts_code' 或 'ann_date' 列存在空值。\n",
"使用 'ann_date' 作为财务数据生效日期。\n",
"警告: 从 financial_data_subset 中移除了 366 行,因为其 'ts_code' 或 'ann_date' 列存在空值。\n",
"开始计算因子: AR, BR (原地修改)...\n",
"因子 AR, BR 计算成功。\n",
"因子 AR, BR 计算流程结束。\n",
"使用 'ann_date' 作为财务数据生效日期。\n",
"使用 'ann_date' 作为财务数据生效日期。\n",
"使用 'ann_date' 作为财务数据生效日期。\n",
"使用 'ann_date' 作为财务数据生效日期。\n",
"警告: 从 financial_data_subset 中移除了 366 行,因为其 'ts_code' 或 'ann_date' 列存在空值。\n",
"计算 BBI...\n",
"--- 计算日级别偏离度 (使用 pct_chg) ---\n",
"--- 计算日级别动量基准 (使用 pct_chg) ---\n",
"日级别动量基准计算完成 (使用 pct_chg)。\n",
"日级别偏离度计算完成 (使用 pct_chg)。\n",
"--- 计算日级别行业偏离度 (使用 pct_chg 和行业基准) ---\n",
"--- 计算日级别行业动量基准 (使用 pct_chg 和 cat_l2_code) ---\n",
"错误: 计算日级别行业动量基准需要以下列: ['pct_chg', 'cat_l2_code', 'trade_date', 'ts_code']。\n",
"错误: 计算日级别行业偏离度需要以下列: ['pct_chg', 'daily_industry_positive_benchmark', 'daily_industry_negative_benchmark']。请先运行 daily_industry_momentum_benchmark(df)。\n",
"Index(['ts_code', 'trade_date', 'open', 'close', 'high', 'low', 'vol',\n",
2025-06-06 17:04:01 +08:00
" 'pct_chg', 'amount', 'turnover_rate',\n",
" ...\n",
2025-06-04 20:34:17 +08:00
" 'cat_volume_breakout', 'turnover_deviation', 'cat_turnover_spike',\n",
" 'avg_volume_ratio', 'cat_volume_ratio_breakout', 'vol_spike',\n",
" 'vol_std_5', 'atr_14', 'atr_6', 'obv'],\n",
2025-06-06 17:04:01 +08:00
" dtype='object', length=104)\n",
2025-06-04 20:34:17 +08:00
"Calculating senti_strong_inflow...\n",
"Finished senti_strong_inflow.\n",
"Calculating lg_flow_mom_corr_20_60...\n",
"Finished lg_flow_mom_corr_20_60.\n",
"Calculating lg_flow_accel...\n",
"Finished lg_flow_accel.\n",
"Calculating profit_pressure...\n",
"Finished profit_pressure.\n",
"Calculating underwater_resistance...\n",
"Finished underwater_resistance.\n",
"Calculating cost_conc_std_20...\n",
"Finished cost_conc_std_20.\n",
"Calculating profit_decay_20...\n",
"Finished profit_decay_20.\n",
"Calculating vol_amp_loss_20...\n",
"Finished vol_amp_loss_20.\n",
"Calculating vol_drop_profit_cnt_5...\n",
"Finished vol_drop_profit_cnt_5.\n",
"Calculating lg_flow_vol_interact_20...\n",
"Finished lg_flow_vol_interact_20.\n",
"Calculating cost_break_confirm_cnt_5...\n",
"Finished cost_break_confirm_cnt_5.\n",
"Calculating atr_norm_channel_pos_14...\n",
"Finished atr_norm_channel_pos_14.\n",
"Calculating turnover_diff_skew_20...\n",
"Finished turnover_diff_skew_20.\n",
"Calculating lg_sm_flow_diverge_20...\n",
"Finished lg_sm_flow_diverge_20.\n",
"Calculating pullback_strong_20_20...\n",
"Finished pullback_strong_20_20.\n",
"Calculating vol_wgt_hist_pos_20...\n",
"Finished vol_wgt_hist_pos_20.\n",
"Calculating vol_adj_roc_20...\n",
"Finished vol_adj_roc_20.\n",
"Calculating cs_rank_net_lg_flow_val...\n",
"Finished cs_rank_net_lg_flow_val.\n",
"Calculating cs_rank_flow_divergence...\n",
"Finished cs_rank_flow_divergence.\n",
"Calculating cs_rank_ind_adj_lg_flow...\n",
"Finished cs_rank_ind_adj_lg_flow.\n",
"Calculating cs_rank_elg_buy_ratio...\n",
"Finished cs_rank_elg_buy_ratio.\n",
"Calculating cs_rank_rel_profit_margin...\n",
"Finished cs_rank_rel_profit_margin.\n",
"Calculating cs_rank_cost_breadth...\n",
"Finished cs_rank_cost_breadth.\n",
"Calculating cs_rank_dist_to_upper_cost...\n",
"Finished cs_rank_dist_to_upper_cost.\n",
"Calculating cs_rank_winner_rate...\n",
"Finished cs_rank_winner_rate.\n",
"Calculating cs_rank_intraday_range...\n",
"Finished cs_rank_intraday_range.\n",
"Calculating cs_rank_close_pos_in_range...\n",
"Finished cs_rank_close_pos_in_range.\n",
"Calculating cs_rank_opening_gap...\n",
"Error calculating cs_rank_opening_gap: Missing 'pre_close' column. Assigning NaN.\n",
"Calculating cs_rank_pos_in_hist_range...\n",
"Finished cs_rank_pos_in_hist_range.\n",
"Calculating cs_rank_vol_x_profit_margin...\n",
"Finished cs_rank_vol_x_profit_margin.\n",
"Calculating cs_rank_lg_flow_price_concordance...\n",
"Finished cs_rank_lg_flow_price_concordance.\n",
"Calculating cs_rank_turnover_per_winner...\n",
"Finished cs_rank_turnover_per_winner.\n",
"Calculating cs_rank_ind_cap_neutral_pe (Placeholder - requires statsmodels)...\n",
"Finished cs_rank_ind_cap_neutral_pe (Placeholder).\n",
"Calculating cs_rank_volume_ratio...\n",
"Finished cs_rank_volume_ratio.\n",
"Calculating cs_rank_elg_buy_sell_sm_ratio...\n",
"Finished cs_rank_elg_buy_sell_sm_ratio.\n",
"Calculating cs_rank_cost_dist_vol_ratio...\n",
"Finished cs_rank_cost_dist_vol_ratio.\n",
"<class 'pandas.core.frame.DataFrame'>\n",
2025-06-10 15:22:25 +08:00
"RangeIndex: 4566757 entries, 0 to 4566756\n",
"Columns: 197 entries, ts_code to cs_rank_cost_dist_vol_ratio\n",
"dtypes: bool(10), datetime64[ns](1), float64(175), int64(6), int8(1), object(4)\n",
2025-06-06 17:04:01 +08:00
"memory usage: 6.4+ GB\n",
2025-06-04 20:34:17 +08:00
"None\n",
2025-06-10 15:22:25 +08:00
"['ts_code', 'trade_date', 'open', 'close', 'high', 'low', 'vol', 'pct_chg', 'amount', 'turnover_rate', 'pe_ttm', 'circ_mv', 'total_mv', 'volume_ratio', 'is_st', 'up_limit', 'down_limit', 'buy_sm_vol', 'sell_sm_vol', 'buy_lg_vol', 'sell_lg_vol', 'buy_elg_vol', 'sell_elg_vol', 'net_mf_vol', 'his_low', 'his_high', 'cost_5pct', 'cost_15pct', 'cost_50pct', 'cost_85pct', 'cost_95pct', 'weight_avg', 'winner_rate', 'cat_l2_code', 'cat_hot_concept_stock', 'concept_rank_pct_chg', 'concept_rank_turnover_rate', 'concept_rank_volume_ratio', 'holder_net_change_sum_10d', 'holder_increase_days_10d', 'holder_decrease_days_10d', 'holder_any_increase_flag_10d', 'holder_any_decrease_flag_10d', 'holder_direction_score_10d', 'cat_senti_mom_vol_spike', 'cat_senti_pre_breakout', 'ts_turnover_rate_acceleration_5_20', 'ts_vol_sustain_10_30', 'cs_amount_outlier_10', 'ts_ff_to_total_turnover_ratio', 'ts_price_volume_trend_coherence_5_20', 'ts_ff_turnover_rate_surge_10', 'undist_profit_ps', 'ocfps', 'AR', 'BR', 'AR_BR', 'log_circ_mv', 'cashflow_to_ev_factor', 'book_to_price_ratio', 'turnover_rate_mean_5', 'variance_20', 'bbi_ratio_factor', 'daily_deviation', 'lg_elg_net_buy_vol', 'flow_lg_elg_intensity', 'sm_net_buy_vol', 'flow_divergence_diff', 'flow_divergence_ratio', 'total_buy_vol', 'lg_elg_buy_prop', 'flow_struct_buy_change', 'lg_elg_net_buy_vol_change', 'flow_lg_elg_accel', 'chip_concentration_range', 'chip_skewness', 'floating_chip_proxy', 'cost_support_15pct_change', 'cat_winner_price_zone', 'flow_chip_consistency', 'profit_taking_vs_absorb', 'cat_is_positive', 'upside_vol', 'downside_vol', 'vol_ratio', 'return_skew', 'return_kurtosis', 'volume_change_rate', 'cat_volume_breakout', 'turnover_deviation', 'cat_turnover_spike', 'avg_volume_ratio', 'cat_volume_ratio_breakout', 'vol_spike', 'vol_std_5', 'atr_14', 'atr_6', 'obv', 'maobv_6', 'rsi_3', 'return_5', 'return_20', 'std_return_5', 'std_return_90', 'std_return_90_2', 'act_factor1', 'act_factor2', 'act_factor3', 'act_factor4', 'rank_act_factor1', 'rank_act_factor2', 'rank_act_factor3', 'cov', 'delta_cov', 'alpha_22_improved', 'alpha_003', 'alpha_007', 'alpha_013', 'vol_break', 'weight_roc5', 'price_cost_divergence', 'smallcap_concentration', 'cost_stability', 'high_cost_break_days', 'liquidity_risk', 'turnover_std', 'mv_volatility', 'volume_growth', 'mv_growth', 'momentum_factor', 'resonance_factor', 'log_close', 'cat_vol_spike', 'up', 'down', 'obv_maobv_6', 'std_return_5_over_std_return_90', 'std_return_90_minus_std_return_90_2', 'cat_af2', 'cat_af3', 'cat_af4', 'act_factor5', 'act_factor6', 'active_buy_volume_large', 'active_buy_volume_big', 'active_buy_volume_small', 'buy_lg_vol_minus_sell_lg_vol', 'buy_elg_vol_minus_sell_elg_vol', 'ctrl_strength', 'low_cost_dev', 'asymmetry', 'lock_factor', 'cat_vol_break', 'cost_atr_adj', 'cat_golden_resonance', 'mv_turnover_ratio', 'mv_adjusted_volume', 'mv_weighted_turnover', 'nonlinear_mv_volume', 'mv_volume_ratio', 'mv_momentum', 'senti_strong_inflow', 'lg_flow_mom_corr_20_60', 'lg_flow_accel', 'profit_pressure', 'underwater_resistance', 'cost_conc_std_20', 'profit_decay_20', 'vol_amp_loss_20', 'vol_drop_profit_cnt_5', 'lg_flow_vol_interact_20', 'cost_break_confirm_cnt_5', 'atr_norm_channel_pos_14', 'turnover_diff_skew_20', 'lg_sm_flow_diverge_20', 'pullback_strong_20_20', 'vol_wgt_hist_pos_20', 'vol_adj_roc_20', 'cs_rank_net_lg_flow_val', 'cs_rank_flow_divergence', 'cs_rank_ind_adj_lg_flow', 'cs_rank_elg_buy_ratio', 'cs_rank_rel_profit_margin', 'cs_rank_cost_breadth', 'cs_rank_dist_to_upper_cost', 'cs_rank_winner_rate', 'cs_rank_intraday_range', 'cs_rank_close_pos_in_range', 'cs_rank_opening_gap', 'cs_rank_pos_in_hist_range', 'cs_rank_vol_x_profit_margin', 'cs_rank_lg_flow_price_concordance', 'cs_rank_turnover_per_winner', 'cs_rank_ind_cap_neutral_pe', 'cs_rank_volume_ratio', 'cs_rank_elg_buy_sell_sm_ratio', 'cs_rank_cost_dist_vol_ratio']\n"
2025-05-26 21:34:36 +08:00
]
}
],
"source": [
"import numpy as np\n",
"from main.factor.factor import *\n",
2025-06-04 13:50:02 +08:00
"from main.factor.money_factor import * \n",
2025-06-06 17:04:01 +08:00
"from main.factor.concept_factor import * \n",
2025-05-26 21:34:36 +08:00
"\n",
2025-05-29 20:41:18 +08:00
"\n",
2025-05-26 21:34:36 +08:00
"def filter_data(df):\n",
" # df = df.groupby('trade_date').apply(lambda x: x.nlargest(1000, 'act_factor1'))\n",
2025-05-29 20:41:18 +08:00
" df = df[~df[\"is_st\"]]\n",
" df = df[~df[\"ts_code\"].str.endswith(\"BJ\")]\n",
" df = df[~df[\"ts_code\"].str.startswith(\"30\")]\n",
" df = df[~df[\"ts_code\"].str.startswith(\"68\")]\n",
" df = df[~df[\"ts_code\"].str.startswith(\"8\")]\n",
" df = df[df[\"trade_date\"] >= \"2019-01-01\"]\n",
" if \"in_date\" in df.columns:\n",
" df = df.drop(columns=[\"in_date\"])\n",
2025-05-26 21:34:36 +08:00
" df = df.reset_index(drop=True)\n",
" return df\n",
"\n",
2025-05-29 20:41:18 +08:00
"\n",
2025-05-26 21:34:36 +08:00
"gc.collect()\n",
"\n",
"df = filter_data(df)\n",
2025-05-29 20:41:18 +08:00
"df = df.sort_values(by=[\"ts_code\", \"trade_date\"])\n",
2025-05-26 21:34:36 +08:00
"\n",
"# df = price_minus_deduction_price(df, n=120)\n",
"# df = price_deduction_price_diff_ratio_to_sma(df, n=120)\n",
"# df = cat_price_vs_sma_vs_deduction_price(df, n=120)\n",
"# df = cat_reason(df, top_list_df)\n",
"# df = cat_is_on_top_list(df, top_list_df)\n",
2025-06-06 17:04:01 +08:00
"print(len(df))\n",
"df = generate_concept_factors(df, concept_dict)\n",
"print(len(df))\n",
"\n",
2025-06-04 13:50:02 +08:00
"df = holder_trade_factors(df, stk_holdertrade_df)\n",
2025-05-26 21:34:36 +08:00
"\n",
2025-05-29 20:41:18 +08:00
"df = cat_senti_mom_vol_spike(\n",
" df,\n",
" return_period=3,\n",
" return_threshold=0.03, # 近3日涨幅超3%\n",
" volume_ratio_threshold=1.3,\n",
" current_pct_chg_min=0.0, # 当日必须收红\n",
2025-06-01 15:59:29 +08:00
" current_pct_chg_max=0.05,\n",
2025-05-29 20:41:18 +08:00
") # 当日涨幅不宜过大\n",
"\n",
"df = cat_senti_pre_breakout(\n",
" df,\n",
" atr_short_N=10,\n",
" atr_long_M=40,\n",
" vol_atrophy_N=10,\n",
" vol_atrophy_M=40,\n",
" price_stab_N=5,\n",
" price_stab_threshold=0.06,\n",
" current_pct_chg_min_signal=0.002,\n",
" current_pct_chg_max_signal=0.05,\n",
" volume_ratio_signal_threshold=1.1,\n",
")\n",
"\n",
2025-05-28 14:16:04 +08:00
"df = ts_turnover_rate_acceleration_5_20(df)\n",
"df = ts_vol_sustain_10_30(df)\n",
2025-05-29 20:41:18 +08:00
"# df = cs_turnover_rate_relative_strength_20(df)\n",
2025-05-28 14:16:04 +08:00
"df = cs_amount_outlier_10(df)\n",
"df = ts_ff_to_total_turnover_ratio(df)\n",
"df = ts_price_volume_trend_coherence_5_20(df)\n",
2025-05-29 20:41:18 +08:00
"# df = ts_turnover_rate_trend_strength_5(df)\n",
2025-05-28 14:16:04 +08:00
"df = ts_ff_turnover_rate_surge_10(df)\n",
"\n",
2025-05-29 20:41:18 +08:00
"df = add_financial_factor(df, fina_indicator_df, factor_value_col=\"undist_profit_ps\")\n",
"df = add_financial_factor(df, fina_indicator_df, factor_value_col=\"ocfps\")\n",
2025-05-26 21:34:36 +08:00
"calculate_arbr(df, N=26)\n",
2025-05-29 20:41:18 +08:00
"df[\"log_circ_mv\"] = np.log(df[\"circ_mv\"])\n",
2025-05-26 21:34:36 +08:00
"df = calculate_cashflow_to_ev_factor(df, cashflow_df, balancesheet_df)\n",
"df = caculate_book_to_price_ratio(df, fina_indicator_df)\n",
"df = turnover_rate_n(df, n=5)\n",
"df = variance_n(df, n=20)\n",
"df = bbi_ratio_factor(df)\n",
"df = daily_deviation(df)\n",
"df = daily_industry_deviation(df)\n",
"df, _ = get_rolling_factor(df)\n",
"df, _ = get_simple_factor(df)\n",
"\n",
2025-06-01 15:59:29 +08:00
"df = calculate_strong_inflow_signal(df)\n",
"\n",
2025-05-29 20:41:18 +08:00
"df = df.rename(columns={\"l1_code\": \"cat_l1_code\"})\n",
"df = df.rename(columns={\"l2_code\": \"cat_l2_code\"})\n",
2025-05-26 21:34:36 +08:00
"\n",
"lg_flow_mom_corr(df, N=20, M=60)\n",
"lg_flow_accel(df)\n",
"profit_pressure(df)\n",
"underwater_resistance(df)\n",
"cost_conc_std(df, N=20)\n",
"profit_decay(df, N=20)\n",
"vol_amp_loss(df, N=20)\n",
"vol_drop_profit_cnt(df, N=20, M=5)\n",
"lg_flow_vol_interact(df, N=20)\n",
"cost_break_confirm_cnt(df, M=5)\n",
"atr_norm_channel_pos(df, N=14)\n",
"turnover_diff_skew(df, N=20)\n",
"lg_sm_flow_diverge(df, N=20)\n",
"pullback_strong(df, N=20, M=20)\n",
"vol_wgt_hist_pos(df, N=20)\n",
"vol_adj_roc(df, N=20)\n",
"\n",
"cs_rank_net_lg_flow_val(df)\n",
"cs_rank_flow_divergence(df)\n",
2025-05-29 20:41:18 +08:00
"cs_rank_industry_adj_lg_flow(df) # Needs cat_l2_code\n",
2025-05-26 21:34:36 +08:00
"cs_rank_elg_buy_ratio(df)\n",
"cs_rank_rel_profit_margin(df)\n",
"cs_rank_cost_breadth(df)\n",
"cs_rank_dist_to_upper_cost(df)\n",
"cs_rank_winner_rate(df)\n",
"cs_rank_intraday_range(df)\n",
"cs_rank_close_pos_in_range(df)\n",
2025-05-29 20:41:18 +08:00
"cs_rank_opening_gap(df) # Needs pre_close\n",
"cs_rank_pos_in_hist_range(df) # Needs his_low, his_high\n",
2025-05-26 21:34:36 +08:00
"cs_rank_vol_x_profit_margin(df)\n",
"cs_rank_lg_flow_price_concordance(df)\n",
"cs_rank_turnover_per_winner(df)\n",
2025-05-29 20:41:18 +08:00
"cs_rank_ind_cap_neutral_pe(df) # Placeholder - needs external libraries\n",
"cs_rank_volume_ratio(df) # Needs volume_ratio\n",
2025-05-26 21:34:36 +08:00
"cs_rank_elg_buy_sell_sm_ratio(df)\n",
2025-05-29 20:41:18 +08:00
"cs_rank_cost_dist_vol_ratio(df) # Needs volume_ratio\n",
2025-06-10 15:22:25 +08:00
"# cs_rank_size(df) # Needs circ_mv\n",
2025-05-26 21:34:36 +08:00
"\n",
"\n",
"# df = df.merge(index_data, on='trade_date', how='left')\n",
"\n",
"print(df.info())\n",
"print(df.columns.tolist())"
]
},
{
"cell_type": "code",
2025-06-10 15:22:25 +08:00
"execution_count": 11,
2025-05-26 21:34:36 +08:00
"id": "b87b938028afa206",
"metadata": {
"ExecuteTime": {
"end_time": "2025-04-03T13:08:03.658725Z",
"start_time": "2025-04-03T13:08:02.469611Z"
}
},
"outputs": [],
"source": [
"from scipy.stats import ks_2samp, wasserstein_distance\n",
"\n",
"\n",
"def remove_shifted_features(train_data, test_data, feature_columns, ks_threshold=0.05, wasserstein_threshold=0.1,\n",
" importance_threshold=0.05):\n",
" dropped_features = []\n",
"\n",
" # **统计数据漂移**\n",
" numeric_columns = train_data.select_dtypes(include=['float64', 'int64']).columns\n",
" numeric_columns = [col for col in numeric_columns if col in feature_columns]\n",
" for feature in numeric_columns:\n",
" ks_stat, p_value = ks_2samp(train_data[feature], test_data[feature])\n",
" wasserstein_dist = wasserstein_distance(train_data[feature], test_data[feature])\n",
"\n",
" if p_value < ks_threshold or wasserstein_dist > wasserstein_threshold:\n",
" dropped_features.append(feature)\n",
"\n",
" print(f\"检测到 {len(dropped_features)} 个可能漂移的特征: {dropped_features}\")\n",
"\n",
" # **应用阈值进行最终筛选**\n",
" filtered_features = [f for f in feature_columns if f not in dropped_features]\n",
"\n",
" return filtered_features, dropped_features\n",
"\n"
]
},
{
"cell_type": "code",
2025-06-10 15:22:25 +08:00
"execution_count": 12,
2025-05-26 21:34:36 +08:00
"id": "47c12bb34062ae7a",
"metadata": {
"ExecuteTime": {
"end_time": "2025-04-03T14:57:50.841165Z",
"start_time": "2025-04-03T14:49:25.889057Z"
}
},
"outputs": [],
"source": [
"days = 5\n",
"validation_days = 120\n",
"\n",
"import gc\n",
"\n",
"gc.collect()\n",
"\n",
"df = df.sort_values(by=['ts_code', 'trade_date'])\n",
2025-06-04 13:50:02 +08:00
"# df['future_return'] = df.groupby('ts_code', group_keys=False)['close'].apply(lambda x: x.shift(-days) / x - 1)\n",
"df['future_return'] = df.groupby('ts_code', group_keys=False)['close'].apply(lambda x: x.shift(-days) / x - 1) + df.groupby('ts_code', group_keys=False)['close'].apply(lambda x: x.shift(-2 * days) / x - 1)\n",
"\n",
2025-05-26 21:34:36 +08:00
"# df['future_return'] = (df.groupby('ts_code')['close'].shift(-days) - df.groupby('ts_code')['open'].shift(-1)) / \\\n",
"# df.groupby('ts_code')['open'].shift(-1)\n",
"\n",
2025-05-28 14:16:04 +08:00
"# df['cat_up_limit'] = df['pct_chg'] > 5\n",
2025-05-26 21:34:36 +08:00
"# df['label'] = (df.groupby('ts_code')['cat_up_limit']\n",
"# .rolling(window=5, min_periods=1).sum()\n",
"# .groupby('ts_code') # 再次按 ts_code 分组\n",
"# .shift(-5)\n",
"# .fillna(0) # 填充每个股票组最后的 NaN\n",
"# .astype(int)\n",
"# .reset_index(level=0, drop=True))\n",
2025-06-04 20:34:17 +08:00
"df['label'] = df.groupby('trade_date', group_keys=False)['future_return'].transform(\n",
" lambda x: pd.qcut(x, q=50, labels=False, duplicates='drop')\n",
")\n",
2025-05-29 20:41:18 +08:00
"# filter_index = df['future_return'].between(df['future_return'].quantile(0.01), df['future_return'].quantile(0.99))\n",
"filter_index = df['future_return'].between(df['future_return'].quantile(0.001), 0.6)\n",
2025-05-26 21:34:36 +08:00
"\n",
"# for col in [col for col in df.columns]:\n",
"# train_data[col] = train_data[col].astype('str')\n",
"# test_data[col] = test_data[col].astype('str')"
]
},
{
"cell_type": "code",
2025-06-10 15:22:25 +08:00
"execution_count": 13,
2025-05-26 21:34:36 +08:00
"id": "29221dde",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2025-06-10 15:22:25 +08:00
"207\n"
2025-05-26 21:34:36 +08:00
]
}
],
"source": [
"feature_columns = [col for col in df.head(10)\n",
" .merge(industry_df, on=['cat_l2_code', 'trade_date'], how='left')\n",
" .merge(index_data, on='trade_date', how='left')\n",
" .columns\n",
" ]\n",
"feature_columns = [col for col in feature_columns if col not in ['trade_date',\n",
" 'ts_code',\n",
" 'label']]\n",
"feature_columns = [col for col in feature_columns if 'future' not in col]\n",
"feature_columns = [col for col in feature_columns if 'label' not in col]\n",
"feature_columns = [col for col in feature_columns if 'score' not in col]\n",
"feature_columns = [col for col in feature_columns if 'gen' not in col]\n",
"feature_columns = [col for col in feature_columns if 'is_st' not in col]\n",
"feature_columns = [col for col in feature_columns if 'pe_ttm' not in col]\n",
"# feature_columns = [col for col in feature_columns if 'volatility' not in col]\n",
"# feature_columns = [col for col in feature_columns if 'circ_mv' not in col]\n",
2025-06-06 17:04:01 +08:00
"feature_columns = [col for col in feature_columns if 'circ_mv' not in col]\n",
"\n",
2025-05-26 21:34:36 +08:00
"feature_columns = [col for col in feature_columns if 'code' not in col]\n",
"feature_columns = [col for col in feature_columns if col not in origin_columns]\n",
"feature_columns = [col for col in feature_columns if not col.startswith('_')]\n",
"# feature_columns = [col for col in feature_columns if col not in ['ts_code', 'trade_date', 'vol_std_5', 'cov', 'delta_cov', 'alpha_22_improved', 'alpha_007', 'consecutive_up_limit', 'mv_volatility', 'volume_growth', 'mv_growth', 'arbr']]\n",
"feature_columns = [col for col in feature_columns if col not in ['intraday_lg_flow_corr_20', \n",
" 'cap_neutral_cost_metric', \n",
" 'hurst_net_mf_vol_60', \n",
" 'complex_factor_deap_1', \n",
" 'lg_buy_consolidation_20',\n",
" 'cs_rank_ind_cap_neutral_pe',\n",
" 'cs_rank_opening_gap',\n",
" 'cs_rank_ind_adj_lg_flow']]\n",
"feature_columns = [col for col in feature_columns if col not in ['cat_reason', 'cat_is_on_top_list']]\n",
"print(len(feature_columns))"
]
},
{
"cell_type": "code",
2025-06-10 15:22:25 +08:00
"execution_count": 14,
2025-05-26 21:34:36 +08:00
"id": "03ee5daf",
"metadata": {},
"outputs": [],
"source": [
"# df = fill_nan_with_daily_median(df, feature_columns)\n",
"for feature_col in [col for col in feature_columns if col in df.columns]:\n",
" # median_val = df[feature_col].median()\n",
" df[feature_col].fillna(0, inplace=True)"
]
},
{
"cell_type": "code",
2025-06-10 15:22:25 +08:00
"execution_count": 15,
2025-05-26 21:34:36 +08:00
"id": "b76ea08a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" ts_code trade_date log_circ_mv\n",
"0 000001.SZ 2019-01-02 16.574219\n",
"1 000001.SZ 2019-01-03 16.583965\n",
2025-06-04 13:50:02 +08:00
"2 000001.SZ 2019-01-04 16.633371\n",
2025-06-10 15:22:25 +08:00
"['vol', 'pct_chg', 'turnover_rate', 'volume_ratio', 'winner_rate', 'cat_hot_concept_stock', 'concept_rank_pct_chg', 'concept_rank_turnover_rate', 'concept_rank_volume_ratio', 'holder_net_change_sum_10d', 'holder_increase_days_10d', 'holder_decrease_days_10d', 'holder_any_increase_flag_10d', 'holder_any_decrease_flag_10d', 'cat_senti_mom_vol_spike', 'cat_senti_pre_breakout', 'ts_turnover_rate_acceleration_5_20', 'ts_vol_sustain_10_30', 'cs_amount_outlier_10', 'ts_ff_to_total_turnover_ratio', 'ts_price_volume_trend_coherence_5_20', 'ts_ff_turnover_rate_surge_10', 'undist_profit_ps', 'ocfps', 'AR', 'BR', 'AR_BR', 'cashflow_to_ev_factor', 'book_to_price_ratio', 'turnover_rate_mean_5', 'variance_20', 'bbi_ratio_factor', 'daily_deviation', 'lg_elg_net_buy_vol', 'flow_lg_elg_intensity', 'sm_net_buy_vol', 'total_buy_vol', 'lg_elg_buy_prop', 'flow_struct_buy_change', 'lg_elg_net_buy_vol_change', 'flow_lg_elg_accel', 'chip_concentration_range', 'chip_skewness', 'floating_chip_proxy', 'cost_support_15pct_change', 'cat_winner_price_zone', 'flow_chip_consistency', 'profit_taking_vs_absorb', 'cat_is_positive', 'upside_vol', 'downside_vol', 'vol_ratio', 'return_skew', 'return_kurtosis', 'volume_change_rate', 'cat_volume_breakout', 'turnover_deviation', 'cat_turnover_spike', 'avg_volume_ratio', 'cat_volume_ratio_breakout', 'vol_spike', 'vol_std_5', 'atr_14', 'atr_6', 'obv', 'maobv_6', 'rsi_3', 'return_5', 'return_20', 'std_return_5', 'std_return_90', 'std_return_90_2', 'act_factor1', 'act_factor2', 'act_factor3', 'act_factor4', 'rank_act_factor1', 'rank_act_factor2', 'rank_act_factor3', 'cov', 'delta_cov', 'alpha_22_improved', 'alpha_003', 'alpha_007', 'alpha_013', 'vol_break', 'weight_roc5', 'smallcap_concentration', 'cost_stability', 'high_cost_break_days', 'liquidity_risk', 'turnover_std', 'mv_volatility', 'volume_growth', 'mv_growth', 'momentum_factor', 'resonance_factor', 'log_close', 'cat_vol_spike', 'up', 'down', 'obv_maobv_6', 'std_return_5_over_std_return_90', 'std_return_90_minus_std_return_90_2', 'cat_af2', 'cat_af3', 'cat_af4', 'act_factor5', 'act_factor6', 'active_buy_volume_large', 'active_buy_volume_big', 'active_buy_volume_small', 'buy_lg_vol_minus_sell_lg_vol', 'buy_elg_vol_minus_sell_elg_vol', 'ctrl_strength', 'low_cost_dev', 'asymmetry', 'lock_factor', 'cat_vol_break', 'cost_atr_adj', 'cat_golden_resonance', 'mv_turnover_ratio', 'mv_adjusted_volume', 'mv_weighted_turnover', 'nonlinear_mv_volume', 'mv_volume_ratio', 'mv_momentum', 'senti_strong_inflow', 'lg_flow_mom_corr_20_60', 'lg_flow_accel', 'profit_pressure', 'underwater_resistance', 'cost_conc_std_20', 'profit_decay_20', 'vol_amp_loss_20', 'vol_drop_profit_cnt_5', 'lg_flow_vol_interact_20', 'cost_break_confirm_cnt_5', 'atr_norm_channel_pos_14', 'turnover_diff_skew_20', 'lg_sm_flow_diverge_20', 'pullback_strong_20_20', 'vol_wgt_hist_pos_20', 'vol_adj_roc_20', 'cs_rank_net_lg_flow_val', 'cs_rank_elg_buy_ratio', 'cs_rank_rel_profit_margin', 'cs_rank_cost_breadth', 'cs_rank_dist_to_upper_cost', 'cs_rank_winner_rate', 'cs_rank_intraday_range', 'cs_rank_close_pos_in_range', 'cs_rank_pos_in_hist_range', 'cs_rank_vol_x_profit_margin', 'cs_rank_lg_flow_price_concordance', 'cs_rank_turnover_per_winner', 'cs_rank_volume_ratio', 'cs_rank_elg_buy_sell_sm_ratio', 'cs_rank_cost_dist_vol_ratio', 'industry_obv', 'industry_return_5', 'industry_return_20', 'industry__ema_5', 'industry__ema_13', 'industry__ema_20', 'industry__ema_60', 'industry_act_factor1', 'industry_act_factor2', 'industry_act_factor3', 'industry_act_factor4', 'industry_act_factor5', 'industry_act_factor6', 'industry_rank_act_factor1', 'industry_rank_act_factor2', 'industry_rank_act_factor3', 'industry_return_5_percentile', 'industry_return_20_percentile', '000852.SH_MACD', '000905.SH_MACD', '399006.SZ_MACD', '000852.SH_MACD_hist', '000905.SH_MACD_hist', '399006.SZ_MACD_hist', '000852.SH_RSI', '000905.SH_RSI', '399006.SZ_RSI', '000852.SH_Signal_line', '000905.SH_Signal_line', '399006.SZ_Signal_line', '000852.SH_amount_change_rate', '000905.SH_amount_change_rate', '399006.SZ_amount_change_rate', '000852.SH_
2025-05-26 21:34:36 +08:00
"去除极值\n",
"开始截面 MAD 去极值处理 (k=3.0)...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
2025-06-10 15:22:25 +08:00
"MAD Filtering: 100%|██████████| 145/145 [00:07<00:00, 19.05it/s]\n"
2025-05-26 21:34:36 +08:00
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"截面 MAD 去极值处理完成。\n",
2025-06-04 20:34:17 +08:00
"标准化\n",
"开始截面 Z-Score 标准化...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
2025-06-10 15:22:25 +08:00
"Standardizing: 100%|██████████| 145/145 [00:02<00:00, 61.44it/s]\n"
2025-06-04 20:34:17 +08:00
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"截面 Z-Score 标准化完成。\n",
2025-05-26 21:34:36 +08:00
"开始截面 MAD 去极值处理 (k=3.0)...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
2025-06-10 15:22:25 +08:00
"MAD Filtering: 100%|██████████| 145/145 [00:05<00:00, 27.24it/s]\n"
2025-05-26 21:34:36 +08:00
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"截面 MAD 去极值处理完成。\n",
2025-06-04 20:34:17 +08:00
"开始截面 Z-Score 标准化...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
2025-06-10 15:22:25 +08:00
"Standardizing: 100%|██████████| 145/145 [00:01<00:00, 85.11it/s]\n"
2025-06-04 20:34:17 +08:00
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"截面 Z-Score 标准化完成。\n",
2025-05-26 21:34:36 +08:00
"开始截面 MAD 去极值处理 (k=3.0)...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"MAD Filtering: 0it [00:00, ?it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"截面 MAD 去极值处理完成。\n",
"开始截面 MAD 去极值处理 (k=3.0)...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"MAD Filtering: 0it [00:00, ?it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"截面 MAD 去极值处理完成。\n",
2025-06-10 15:22:25 +08:00
"feature_columns: ['vol', 'pct_chg', 'turnover_rate', 'volume_ratio', 'winner_rate', 'cat_hot_concept_stock', 'concept_rank_pct_chg', 'concept_rank_turnover_rate', 'concept_rank_volume_ratio', 'holder_net_change_sum_10d', 'holder_increase_days_10d', 'holder_decrease_days_10d', 'holder_any_increase_flag_10d', 'holder_any_decrease_flag_10d', 'cat_senti_mom_vol_spike', 'cat_senti_pre_breakout', 'ts_turnover_rate_acceleration_5_20', 'ts_vol_sustain_10_30', 'cs_amount_outlier_10', 'ts_ff_to_total_turnover_ratio', 'ts_price_volume_trend_coherence_5_20', 'ts_ff_turnover_rate_surge_10', 'undist_profit_ps', 'ocfps', 'AR', 'BR', 'AR_BR', 'cashflow_to_ev_factor', 'book_to_price_ratio', 'turnover_rate_mean_5', 'variance_20', 'bbi_ratio_factor', 'daily_deviation', 'lg_elg_net_buy_vol', 'flow_lg_elg_intensity', 'sm_net_buy_vol', 'total_buy_vol', 'lg_elg_buy_prop', 'flow_struct_buy_change', 'lg_elg_net_buy_vol_change', 'flow_lg_elg_accel', 'chip_concentration_range', 'chip_skewness', 'floating_chip_proxy', 'cost_support_15pct_change', 'cat_winner_price_zone', 'flow_chip_consistency', 'profit_taking_vs_absorb', 'cat_is_positive', 'upside_vol', 'downside_vol', 'vol_ratio', 'return_skew', 'return_kurtosis', 'volume_change_rate', 'cat_volume_breakout', 'turnover_deviation', 'cat_turnover_spike', 'avg_volume_ratio', 'cat_volume_ratio_breakout', 'vol_spike', 'vol_std_5', 'atr_14', 'atr_6', 'obv', 'maobv_6', 'rsi_3', 'return_5', 'return_20', 'std_return_5', 'std_return_90', 'std_return_90_2', 'act_factor1', 'act_factor2', 'act_factor3', 'act_factor4', 'rank_act_factor1', 'rank_act_factor2', 'rank_act_factor3', 'cov', 'delta_cov', 'alpha_22_improved', 'alpha_003', 'alpha_007', 'alpha_013', 'vol_break', 'weight_roc5', 'smallcap_concentration', 'cost_stability', 'high_cost_break_days', 'liquidity_risk', 'turnover_std', 'mv_volatility', 'volume_growth', 'mv_growth', 'momentum_factor', 'resonance_factor', 'log_close', 'cat_vol_spike', 'up', 'down', 'obv_maobv_6', 'std_return_5_over_std_return_90', 'std_return_90_minus_std_return_90_2', 'cat_af2', 'cat_af3', 'cat_af4', 'act_factor5', 'act_factor6', 'active_buy_volume_large', 'active_buy_volume_big', 'active_buy_volume_small', 'buy_lg_vol_minus_sell_lg_vol', 'buy_elg_vol_minus_sell_elg_vol', 'ctrl_strength', 'low_cost_dev', 'asymmetry', 'lock_factor', 'cat_vol_break', 'cost_atr_adj', 'cat_golden_resonance', 'mv_turnover_ratio', 'mv_adjusted_volume', 'mv_weighted_turnover', 'nonlinear_mv_volume', 'mv_volume_ratio', 'mv_momentum', 'senti_strong_inflow', 'lg_flow_mom_corr_20_60', 'lg_flow_accel', 'profit_pressure', 'underwater_resistance', 'cost_conc_std_20', 'profit_decay_20', 'vol_amp_loss_20', 'vol_drop_profit_cnt_5', 'lg_flow_vol_interact_20', 'cost_break_confirm_cnt_5', 'atr_norm_channel_pos_14', 'turnover_diff_skew_20', 'lg_sm_flow_diverge_20', 'pullback_strong_20_20', 'vol_wgt_hist_pos_20', 'vol_adj_roc_20', 'cs_rank_net_lg_flow_val', 'cs_rank_elg_buy_ratio', 'cs_rank_rel_profit_margin', 'cs_rank_cost_breadth', 'cs_rank_dist_to_upper_cost', 'cs_rank_winner_rate', 'cs_rank_intraday_range', 'cs_rank_close_pos_in_range', 'cs_rank_pos_in_hist_range', 'cs_rank_vol_x_profit_margin', 'cs_rank_lg_flow_price_concordance', 'cs_rank_turnover_per_winner', 'cs_rank_volume_ratio', 'cs_rank_elg_buy_sell_sm_ratio', 'cs_rank_cost_dist_vol_ratio', 'industry_obv', 'industry_return_5', 'industry_return_20', 'industry__ema_5', 'industry__ema_13', 'industry__ema_20', 'industry__ema_60', 'industry_act_factor1', 'industry_act_factor2', 'industry_act_factor3', 'industry_act_factor4', 'industry_act_factor5', 'industry_act_factor6', 'industry_rank_act_factor1', 'industry_rank_act_factor2', 'industry_rank_act_factor3', 'industry_return_5_percentile', 'industry_return_20_percentile', '000852.SH_MACD', '000905.SH_MACD', '399006.SZ_MACD', '000852.SH_MACD_hist', '000905.SH_MACD_hist', '399006.SZ_MACD_hist', '000852.SH_RSI', '000905.SH_RSI', '399006.SZ_RSI', '000852.SH_Signal_line', '000905.SH_Signal_line', '399006.SZ_Signal_line', '000852.SH_amount_change_rate', '000905.SH_amount_change_rate', '399006.SZ_amount_change_r
2025-05-26 21:34:36 +08:00
"df最小日期: 2019-01-02\n",
2025-06-10 15:22:25 +08:00
"df最大日期: 2025-06-06\n",
2025-06-04 13:50:02 +08:00
"1091062\n",
2025-05-26 21:34:36 +08:00
"train_data最小日期: 2020-01-02\n",
"train_data最大日期: 2022-12-30\n",
2025-06-10 15:22:25 +08:00
"875950\n",
2025-05-26 21:34:36 +08:00
"test_data最小日期: 2023-01-03\n",
2025-06-10 15:22:25 +08:00
"test_data最大日期: 2025-06-06\n",
2025-05-26 21:34:36 +08:00
" ts_code trade_date log_circ_mv\n",
"0 000001.SZ 2019-01-02 16.574219\n",
"1 000001.SZ 2019-01-03 16.583965\n",
"2 000001.SZ 2019-01-04 16.633371\n"
]
}
],
"source": [
2025-06-06 17:04:01 +08:00
"from main.utils.data_process import *\n",
"\n",
2025-05-26 21:34:36 +08:00
"split_date = '2023-01-01'\n",
2025-06-04 13:50:02 +08:00
"train_data = df[filter_index & (df['trade_date'] <= split_date) & (df['trade_date'] >= '2020-01-01')].groupby(\n",
" 'trade_date', group_keys=False).apply(lambda x: x.nsmallest(1500, 'total_mv'))\n",
"test_data = df[(df['trade_date'] >= split_date)].groupby(\n",
" 'trade_date', group_keys=False).apply(lambda x: x.nsmallest(1500, 'total_mv'))\n",
2025-05-26 21:34:36 +08:00
"\n",
"print(df[['ts_code', 'trade_date', 'log_circ_mv']].head(3))\n",
"\n",
"industry_df = industry_df.sort_values(by=['trade_date'])\n",
"index_data = index_data.sort_values(by=['trade_date'])\n",
"\n",
"# train_data = train_data.merge(industry_df, on=['cat_l2_code', 'trade_date'], how='left')\n",
"# train_data = train_data.merge(index_data, on='trade_date', how='left')\n",
"# test_data = test_data.merge(industry_df, on=['cat_l2_code', 'trade_date'], how='left')\n",
"# test_data = test_data.merge(index_data, on='trade_date', how='left')\n",
"\n",
"train_data, test_data = train_data.replace([np.inf, -np.inf], np.nan), test_data.replace([np.inf, -np.inf], np.nan)\n",
"\n",
2025-06-04 20:34:17 +08:00
"# train_data['label'] = train_data.groupby('trade_date', group_keys=False)['future_return'].transform(\n",
"# lambda x: pd.qcut(x, q=100, labels=False, duplicates='drop')\n",
"# )\n",
"# test_data['label'] = test_data.groupby('trade_date', group_keys=False)['future_return'].transform(\n",
"# lambda x: pd.qcut(x, q=100, labels=False, duplicates='drop')\n",
"# )\n",
2025-06-04 13:50:02 +08:00
"\n",
2025-05-26 21:34:36 +08:00
"# feature_columns_new = feature_columns[:]\n",
"# train_data, _ = create_deviation_within_dates(train_data, [col for col in feature_columns if col in train_data.columns])\n",
"# test_data, _ = create_deviation_within_dates(test_data, [col for col in feature_columns if col in train_data.columns])\n",
"\n",
"numeric_columns = df.select_dtypes(include=['float64', 'int64']).columns\n",
"numeric_columns = [col for col in numeric_columns if col in feature_columns]\n",
"# feature_columns = select_top_features_by_rankic(df, numeric_columns, n=10)\n",
"print(feature_columns)\n",
"\n",
"# train_data = fill_nan_with_daily_median(train_data, feature_columns)\n",
"# test_data = fill_nan_with_daily_median(test_data, feature_columns)\n",
"\n",
"train_data = train_data.dropna(subset=[col for col in feature_columns if col in train_data.columns])\n",
"train_data = train_data.dropna(subset=['label'])\n",
"train_data = train_data.reset_index(drop=True)\n",
"# print(test_data.tail())\n",
"test_data = test_data.dropna(subset=[col for col in feature_columns if col in train_data.columns])\n",
"# test_data = test_data.dropna(subset=['label'])\n",
"test_data = test_data.reset_index(drop=True)\n",
"\n",
"transform_feature_columns = feature_columns\n",
"transform_feature_columns = [col for col in transform_feature_columns if col in feature_columns and not col.startswith('cat') and col in train_data.columns]\n",
"# transform_feature_columns.remove('undist_profit_ps')\n",
"print('去除极值')\n",
"cs_mad_filter(train_data, transform_feature_columns)\n",
"# print('中性化')\n",
"# cs_neutralize_industry_cap(train_data, transform_feature_columns)\n",
2025-06-04 20:34:17 +08:00
"print('标准化')\n",
"cs_zscore_standardize(train_data, transform_feature_columns)\n",
2025-05-26 21:34:36 +08:00
"\n",
"cs_mad_filter(test_data, transform_feature_columns)\n",
"# cs_neutralize_industry_cap(test_data, transform_feature_columns)\n",
2025-06-04 20:34:17 +08:00
"cs_zscore_standardize(test_data, transform_feature_columns)\n",
2025-05-26 21:34:36 +08:00
"\n",
"mad_filter_feature_columns = [col for col in feature_columns if col not in transform_feature_columns and not col.startswith('cat') and col in train_data.columns]\n",
"cs_mad_filter(train_data, mad_filter_feature_columns)\n",
"cs_mad_filter(test_data, mad_filter_feature_columns)\n",
"\n",
"\n",
"print(f'feature_columns: {feature_columns}')\n",
"\n",
"\n",
"print(f\"df最小日期: {df['trade_date'].min().strftime('%Y-%m-%d')}\")\n",
"print(f\"df最大日期: {df['trade_date'].max().strftime('%Y-%m-%d')}\")\n",
"print(len(train_data))\n",
"print(f\"train_data最小日期: {train_data['trade_date'].min().strftime('%Y-%m-%d')}\")\n",
"print(f\"train_data最大日期: {train_data['trade_date'].max().strftime('%Y-%m-%d')}\")\n",
"print(len(test_data))\n",
"print(f\"test_data最小日期: {test_data['trade_date'].min().strftime('%Y-%m-%d')}\")\n",
"print(f\"test_data最大日期: {test_data['trade_date'].max().strftime('%Y-%m-%d')}\")\n",
"\n",
"cat_columns = [col for col in feature_columns if col.startswith('cat')]\n",
"for col in cat_columns:\n",
" train_data[col] = train_data[col].astype('category')\n",
" test_data[col] = test_data[col].astype('category')\n",
"\n",
"print(df[['ts_code', 'trade_date', 'log_circ_mv']].head(3))\n"
]
},
{
"cell_type": "code",
2025-06-10 15:22:25 +08:00
"execution_count": 126,
2025-05-26 21:34:36 +08:00
"id": "3ff2d1c5",
"metadata": {},
"outputs": [],
"source": [
"from sklearn.preprocessing import StandardScaler\n",
"from sklearn.linear_model import LogisticRegression\n",
"import matplotlib.pyplot as plt # 保持 matplotlib 导入, 尽管LightGBM的绘图功能已移除\n",
"from sklearn.decomposition import PCA\n",
"import pandas as pd\n",
"import numpy as np\n",
"import datetime # 用于日期计算\n",
"from catboost import CatBoostClassifier, CatBoostRanker, CatBoostRegressor\n",
"from catboost import Pool\n",
"import lightgbm as lgb\n",
"from lightgbm import LGBMRanker, LGBMRegressor\n",
"\n",
"def train_model(train_data_df, feature_columns,\n",
" print_info=True, # 调整参数名,更通用\n",
" validation_days=180, use_pca=False, split_date=None,\n",
" target_column='label', type='light'): # 增加目标列参数\n",
"\n",
" print('train data size: ', len(train_data_df))\n",
" print(train_data_df[['ts_code', 'trade_date', 'log_circ_mv']])\n",
" # 确保数据按时间排序\n",
" train_data_df = train_data_df.sort_values(by='trade_date')\n",
"\n",
" # 识别数值型特征列\n",
" numeric_feature_columns = train_data_df[feature_columns].select_dtypes(include=['float64', 'int64']).columns.tolist()\n",
"\n",
" # 去除标签为空的样本\n",
" initial_len = len(train_data_df)\n",
" train_data_df = train_data_df.dropna(subset=[target_column])\n",
"\n",
" if print_info:\n",
" print(f'原始样本数: {initial_len}, 去除标签为空后样本数: {len(train_data_df)}')\n",
"\n",
" # 提取特征和标签,只取数值型特征用于线性回归\n",
" \n",
" if split_date is None:\n",
" all_dates = train_data_df['trade_date'].unique() # 获取所有唯一的 trade_date\n",
" split_date = all_dates[-validation_days] # 划分点为倒数第 validation_days 天\n",
" train_data_split = train_data_df[train_data_df['trade_date'] < split_date] # 训练集\n",
" val_data_split = train_data_df[train_data_df['trade_date'] >= split_date] # 验证集\n",
"\n",
" train_data_split = train_data_split.sort_values('trade_date')\n",
" val_data_split = val_data_split.sort_values('trade_date')\n",
"\n",
" \n",
" X_train = train_data_split[feature_columns]\n",
" y_train = train_data_split[target_column]\n",
" \n",
" X_val = val_data_split[feature_columns]\n",
" y_val = val_data_split[target_column]\n",
"\n",
"\n",
" # # 标准化数值特征 (使用 StandardScaler 对训练集fit并transform, 对验证集只transform)\n",
" scaler = StandardScaler()\n",
" # X_train = scaler.fit_transform(X_train)\n",
"\n",
" # 训练线性回归模型\n",
" # model = LogisticRegression(random_state=42)\n",
" \n",
" # # 使用处理后的特征和样本权重进行训练\n",
" # model.fit(X_train, y_train)\n",
"\n",
"\n",
" if type == 'cat':\n",
" params = {\n",
2025-06-04 13:50:02 +08:00
" 'loss_function': 'QueryRMSE', # 适用于二分类\n",
2025-05-26 21:34:36 +08:00
" 'eval_metric': 'NDCG', # 评估指标\n",
" 'iterations': 1500,\n",
2025-05-29 20:41:18 +08:00
" 'learning_rate': 0.03,\n",
2025-06-01 15:59:29 +08:00
" 'depth': 8, # 控制模型复杂度\n",
" 'l2_leaf_reg': 1, # L2 正则化\n",
2025-05-26 21:34:36 +08:00
" 'verbose': 5000,\n",
" 'early_stopping_rounds': 300,\n",
" 'one_hot_max_size': 50,\n",
" # 'class_weights': [0.6, 1.2],\n",
" 'task_type': 'GPU',\n",
" 'has_time': True,\n",
" 'random_seed': 7\n",
" }\n",
" cat_features = [i for i, col in enumerate(feature_columns) if col.startswith('cat')]\n",
" group_train = train_data_split['trade_date'].factorize()[0]\n",
" group_val = val_data_split['trade_date'].factorize()[0]\n",
" train_pool = Pool(\n",
" data=X_train,\n",
" label=y_train,\n",
" group_id=group_train,\n",
" cat_features=cat_features\n",
" )\n",
" val_pool = Pool(\n",
" data=X_val,\n",
" label=y_val,\n",
" group_id=group_val,\n",
" cat_features=cat_features\n",
" )\n",
"\n",
"\n",
" model = CatBoostRanker(**params)\n",
" model.fit(train_pool,\n",
" eval_set=val_pool, \n",
" plot=True, \n",
" use_best_model=True\n",
" )\n",
" elif type == 'light':\n",
" label_gain = list(range(len(train_data_split['label'].unique())))\n",
2025-06-04 13:50:02 +08:00
" \n",
2025-05-26 21:34:36 +08:00
" params = {\n",
2025-06-10 15:22:25 +08:00
" 'label_gain': [gain for gain in label_gain],\n",
2025-05-26 21:34:36 +08:00
" 'objective': 'lambdarank',\n",
2025-06-04 13:50:02 +08:00
" 'metric': 'ndcg',\n",
2025-06-04 20:34:17 +08:00
" 'learning_rate': 0.01,\n",
2025-06-10 15:22:25 +08:00
" 'num_leaves': 1024,\n",
" 'min_data_in_leaf': 256,\n",
2025-06-04 13:50:02 +08:00
" # 'max_depth': 10,\n",
2025-06-10 15:22:25 +08:00
" 'max_bin': 1024,\n",
2025-06-04 13:50:02 +08:00
" 'feature_fraction': 0.5,\n",
" 'bagging_fraction': 0.5,\n",
2025-05-26 21:34:36 +08:00
" 'bagging_freq': 5,\n",
2025-06-10 15:22:25 +08:00
" 'lambda_l1': 0.1,\n",
" 'lambda_l2': 10,\n",
2025-05-26 21:34:36 +08:00
" 'boosting': 'gbdt',\n",
" 'verbosity': -1,\n",
" 'extra_trees': True,\n",
" # 'max_position': 5,\n",
2025-06-04 13:50:02 +08:00
" 'ndcg_at': '5',\n",
2025-05-26 21:34:36 +08:00
" 'quant_train_renew_leaf': True,\n",
2025-06-04 13:50:02 +08:00
" 'lambdarank_truncation_level': 10,\n",
2025-06-04 20:34:17 +08:00
" # 'lambdarank_position_bias_regularization': 1,\n",
2025-05-26 21:34:36 +08:00
" 'seed': 7\n",
" }\n",
2025-06-06 17:04:01 +08:00
" feature_contri = [2 if 'concept' in feat else 1 for feat in feature_columns]\n",
" params['feature_contri'] = feature_contri\n",
2025-06-04 13:50:02 +08:00
"\n",
2025-05-26 21:34:36 +08:00
" train_groups = train_data_split.groupby('trade_date').size().tolist()\n",
" val_groups = val_data_split.groupby('trade_date').size().tolist()\n",
"\n",
" categorical_feature = [col for col in feature_columns if 'cat' in col]\n",
" train_dataset = lgb.Dataset(\n",
" X_train, label=y_train, \n",
" group=train_groups,\n",
" categorical_feature=categorical_feature\n",
" )\n",
" val_dataset = lgb.Dataset(\n",
" X_val, label=y_val, \n",
" group=val_groups,\n",
" categorical_feature=categorical_feature\n",
" )\n",
"\n",
" evals = {}\n",
2025-06-10 15:22:25 +08:00
" callbacks = [lgb.log_evaluation(period=3000),\n",
2025-05-26 21:34:36 +08:00
" lgb.callback.record_evaluation(evals),\n",
2025-06-10 15:22:25 +08:00
" lgb.early_stopping(300, first_metric_only=True)\n",
2025-05-26 21:34:36 +08:00
" ]\n",
2025-05-28 14:16:04 +08:00
" # 训练模型\n",
" model = lgb.train(\n",
2025-06-04 13:50:02 +08:00
" params, train_dataset, num_boost_round=1000,\n",
2025-05-28 14:16:04 +08:00
" valid_sets=[train_dataset, val_dataset], valid_names=['train', 'valid'],\n",
" callbacks=callbacks\n",
" )\n",
"\n",
" # 打印特征重要性(如果需要)\n",
" if True:\n",
" lgb.plot_metric(evals)\n",
" lgb.plot_importance(model, importance_type='split', max_num_features=20)\n",
" plt.show()\n",
"\n",
" # from flaml import AutoML\n",
" # from sklearn.datasets import fetch_california_housing\n",
"\n",
" # # Initialize an AutoML instance\n",
" # model = AutoML()\n",
" # # Specify automl goal and constraint\n",
" # automl_settings = {\n",
" # \"time_budget\": 600, # in seconds\n",
" # \"metric\": \"ndcg@1\",\n",
" # \"task\": \"rank\",\n",
" # \"estimator_list\": [\n",
" # \"catboost\",\n",
" # \"lgbm\",\n",
" # \"xgboost\"\n",
" # ], \n",
" # \"ensemble\": {\n",
" # \"final_estimator\": LGBMRanker(),\n",
" # \"passthrough\": False,\n",
" # },\n",
" # }\n",
" # model.fit(X_train=X_train, y_train=y_train, groups=train_groups,\n",
" # X_val=X_val, y_val=y_val,groups_val=val_groups,\n",
" # mlflow_logging=False, **automl_settings)\n",
2025-05-26 21:34:36 +08:00
"\n",
"\n",
" return model, scaler, None # 返回训练好的模型、scaler 和 pca 对象"
]
},
{
"cell_type": "code",
2025-06-10 15:22:25 +08:00
"execution_count": 127,
2025-05-26 21:34:36 +08:00
"id": "c6eb5cd4-e714-420a-ac48-39af3e11ee81",
"metadata": {
"ExecuteTime": {
"end_time": "2025-04-03T15:03:18.426481Z",
"start_time": "2025-04-03T15:02:19.926352Z"
}
},
"outputs": [
{
2025-06-10 15:22:25 +08:00
"name": "stdout",
"output_type": "stream",
"text": [
"train data size: 364000\n",
" ts_code trade_date log_circ_mv\n",
"0 600306.SH 2020-01-02 11.552040\n",
"1 603269.SH 2020-01-02 11.324801\n",
"2 002633.SZ 2020-01-02 11.759023\n",
"3 603991.SH 2020-01-02 11.181150\n",
"4 000691.SZ 2020-01-02 11.677910\n",
"... ... ... ...\n",
"363995 605218.SH 2022-12-30 11.710093\n",
"363996 603519.SH 2022-12-30 12.592329\n",
"363997 600293.SH 2022-12-30 12.593635\n",
"363998 603182.SH 2022-12-30 11.207510\n",
"363999 600749.SH 2022-12-30 12.594148\n",
"\n",
"[364000 rows x 3 columns]\n",
"原始样本数: 364000, 去除标签为空后样本数: 364000\n",
"Training until validation scores don't improve for 300 rounds\n",
"Did not meet early stopping. Best iteration is:\n",
"[728]\ttrain's ndcg@5: 0.764617\tvalid's ndcg@5: 0.577452\n",
"Evaluated only: ndcg@5\n"
2025-05-26 21:34:36 +08:00
]
2025-06-10 15:22:25 +08:00
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAlEAAAHHCAYAAACfqw0dAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjMsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvZiW1igAAAAlwSFlzAAAPYQAAD2EBqD+naQAAd81JREFUeJzt3Xd4U1UDBvA3STO696a0hZZRNgVq2UjZIuBkyPITFEHFfqggshX8HIADwQGCA0EcoLIsZShSKLJnoXQxukv3SpP7/XFpSuigDW2Ttu/vefrYnDtyco+Ql3POPVciCIIAIiIiIqoRqbErQERERNQQMUQRERERGYAhioiIiMgADFFEREREBmCIIiIiIjIAQxQRERGRARiiiIiIiAzAEEVERERkAIYoIiIiIgMwRBFRtWzcuBESiQRxcXF19h6LFy+GRCJpMOc1tri4OEgkEmzcuNGg4yUSCRYvXlyrdSJqShiiiExMaViRSCQ4fPhwue2CIMDLywsSiQSPPPKIQe/x2WefGfzFSzWzefNmrF692tjVIKI6wBBFZKJUKhU2b95crvzQoUO4ceMGlEqlwec2JERNnDgRBQUF8Pb2Nvh9jeWtt95CQUGBUd67LkOUt7c3CgoKMHHiRIOOLygowFtvvVXLtSJqOhiiiEzU8OHDsW3bNpSUlOiVb968GYGBgXBzc6uXeuTl5QEAZDIZVCpVgxoWK627mZkZVCqVkWtzf4WFhdBqtdXeXyKRQKVSQSaTGfR+KpUKZmZmBh1LRAxRRCZr3LhxSE9PR1hYmK6suLgYP/30E8aPH1/hMVqtFqtXr0a7du2gUqng6uqK559/Hrdv39bt4+PjgwsXLuDQoUO6YcP+/fsDKBtKPHToEF588UW4uLigWbNmetvunRO1e/du9OvXD9bW1rCxsUH37t0r7EG71+HDh9G9e3eoVCq0bNkSn3/+ebl9qprzc+98ntJ5TxcvXsT48eNhb2+P3r1762279/hZs2Zh+/btaN++PZRKJdq1a4c9e/aUe6+DBw+iW7duenWtzjyr/v37Y+fOnYiPj9ddax8fH905JRIJtmzZgrfeeguenp6wsLBAdnY2MjIyMGfOHHTo0AFWVlawsbHBsGHDcObMmftenylTpsDKygo3b97E6NGjYWVlBWdnZ8yZMwcajaZa1zA6OhpTpkyBnZ0dbG1tMXXqVOTn5+sdW1BQgJdffhlOTk6wtrbGo48+ips3b3KeFTUp/CcIkYny8fFBcHAwfvjhBwwbNgyAGFiysrIwduxYfPzxx+WOef7557Fx40ZMnToVL7/8MmJjY/Hpp5/i1KlT+OeffyCXy7F69Wq89NJLsLKywvz58wEArq6ueud58cUX4ezsjIULF+p6cyqyceNGPPvss2jXrh3mzZsHOzs7nDp1Cnv27Kk06AHAuXPnMHjwYDg7O2Px4sUoKSnBokWLytXDEE8++ST8/f2xfPlyCIJQ5b6HDx/GL7/8ghdffBHW1tb4+OOP8fjjjyMhIQGOjo4AgFOnTmHo0KFwd3fHkiVLoNFosHTpUjg7O9+3LvPnz0dWVhZu3LiBVatWAQCsrKz09lm2bBkUCgXmzJmDoqIiKBQKXLx4Edu3b8eTTz4JX19fJCcn4/PPP0e/fv1w8eJFeHh4VPm+Go0GQ4YMQVBQED744APs27cPH374IVq2bIkZM2bct95PPfUUfH19sWLFCpw8eRJfffUVXFxc8L///U+3z5QpU/Djjz9i4sSJeOihh3Do0CGMGDHivucmalQEIjIpX3/9tQBAOH78uPDpp58K1tbWQn5+viAIgvDkk08KAwYMEARBELy9vYURI0bojvv7778FAML333+vd749e/aUK2/Xrp3Qr1+/St+7d+/eQklJSYXbYmNjBUEQhMzMTMHa2loICgoSCgoK9PbVarVVfsbRo0cLKpVKiI+P15VdvHhRkMlkwt1/LcXGxgoAhK+//rrcOQAIixYt0r1etGiRAEAYN25cuX1Lt917vEKhEKKjo3VlZ86cEQAIn3zyia5s5MiRgoWFhXDz5k1d2dWrVwUzM7Ny56zIiBEjBG9v73LlBw4cEAAILVq00LVvqcLCQkGj0eiVxcbGCkqlUli6dKle2b3XZ/LkyQIAvf0EQRC6dOkiBAYGlrsGFV3DZ599Vm+/MWPGCI6OjrrXJ06cEAAIs2fP1ttvypQp5c5J1JhxOI/IhD311FMoKCjAH3/8gZycHPzxxx+V9vBs27YNtra2GDRoENLS0nQ/gYGBsLKywoEDB6r9vtOmTbvvPJuwsDDk5ORg7ty55eYbVTXMpdFosHfvXowePRrNmzfXlbdt2xZDhgypdh0r88ILL1R735CQELRs2VL3umPHjrCxsUFMTIyurvv27cPo0aP1en/8/Px0vYMPavLkyTA3N9crUyqVkEqlujqkp6fDysoKrVu3xsmTJ6t13nuvQ58+fXSfy5Bj09PTkZ2dDQC6Ic8XX3xRb7+XXnqpWucnaiw4nEdkwpydnRESEoLNmzcjPz8fGo0GTzzxRIX7Xr16FVlZWXBxcalwe0pKSrXf19fX9777XLt2DQDQvn37ap8XAFJTU1FQUAB/f/9y21q3bo1du3bV6Hz3qk7dS90d4krZ29vr5pClpKSgoKAAfn5+5farqMwQFdVXq9Xio48+wmeffYbY2Fi9uUylw4xVUalU5YYb7/5c93PvdbG3twcA3L59GzY2NoiPj4dUKi1X99q6JkQNBUMUkYkbP348pk2bhqSkJAwbNgx2dnYV7qfVauHi4oLvv/++wu3VmcNT6t6eEWOprEfr3gnSd6tJ3SvrbRPuM5eqNlVU3+XLl2PBggV49tlnsWzZMjg4OEAqlWL27NnVunvP0Lv17nd8fV4XooaAIYrIxI0ZMwbPP/88jh49iq1bt1a6X8uWLbFv3z706tXrvkGiNpYpKB0GO3/+fI16IJydnWFubo6rV6+W2xYVFaX3urQHJDMzU688Pj6+hrU1jIuLC1QqFaKjo8ttq6isIoZc659++gkDBgzA+vXr9cozMzPh5ORU4/PVNm9vb2i1WsTGxur1KFb3mhA1FpwTRWTirKyssHbtWixevBgjR46sdL+nnnoKGo0Gy5YtK7etpKREL4hYWlqWCyY1NXjwYFhbW2PFihUoLCzU21ZVj4VMJsOQIUOwfft2JCQk6MovXbqEvXv36u1rY2MDJycn/PXXX3rln3322QPVvbpkMhlCQkKwfft23Lp1S1ceHR2N3bt3V+sclpaWyMrKqvH73nsNt23bhps3b9boPHWldO7ave3wySefGKM6REbDniiiBmDy5Mn33adfv354/vnnsWLFCpw+fRqDBw+GXC7H1atXsW3bNnz00Ue6+VSBgYFYu3Yt3n77bfj5+cHFxQUPP/xwjepkY2ODVatW4bnnnkP37t11azOdOXMG+fn52LRpU6XHLlmyBHv27EGfPn3w4osvoqSkBJ988gnatWuHs2fP6u373HPP4d1338Vzzz2Hbt264a+//sKVK1dqVNcHsXjxYvz555/o1asXZsyYAY1Gg08//RTt27fH6dOn73t8YGAgtm7ditDQUHTv3h1WVlZVhmEAeOSRR7B06VJMnToVPXv2xLlz5/D999+jRYsWtfSpHkxgYCAef/xxrF69Gunp6bolDkrbpSEtyEr0IBiiiBqRdevWITAwEJ9//jnefPNNmJmZwcfHB8888wx69eql22/hwoWIj4/He++9h5ycHPTr16/GIQoA/vOf/8DFxQXvvvsuli1bBrlcjjZt2uDVV1+t8riOHTti7969CA0NxcKFC9GsWTMsWbIEiYmJ5ULUwoULkZqaip9++gk//vgjhg0bht27d1c6gb62BQYGYvfu3ZgzZw4WLFgALy8vLF26FJcuXcLly5fve/yLL76I06dP4+uvv8aqVavg7e193xD15ptvIi8vD5s3b8bWrVvRtWtX7Ny5E3Pnzq2tj/XAvvnmG7i5ueGHH37Ar7/+ipCQEGzduhWtW7duEKvDE9UGicCZgkRENTZ69GhcuHChwrldTdXp06fRpUsXfPfdd5g
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAuwAAAHHCAYAAADkow2UAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjMsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvZiW1igAAAAlwSFlzAAAPYQAAD2EBqD+naQABAABJREFUeJzs3Xtczvf/+PHH1bl0FaWkiUKSY835zESRsLHatDnMmTbHJIeUwxCR42YONYctZ5uJKZNtjpkxZ3NI+WDYkGSpruv3h1/X16WjhnR53m+366br9X6936/n832lXr3er/frrVCr1WqEEEIIIYQQryW9kg5ACCGEEEIIkT/psAshhBBCCPEakw67EEIIIYQQrzHpsAshhBBCCPEakw67EEIIIYQQrzHpsAshhBBCCPEakw67EEIIIYQQrzHpsAshhBBCCPEakw67EEIIIYQQrzHpsAshhBCvUHR0NAqFgqSkpJIORQhRSkiHXQghxEuV00HN6zV+/PiX0uaBAwcIDQ3l3r17L+X4b7L09HRCQ0NJSEgo6VCEeGMYlHQAQggh3gxTp07FyclJq6xOnTovpa0DBw4QFhZG3759KVu27Etpo7g+/vhjPvjgA4yNjUs6lGJJT08nLCwMgLZt25ZsMEK8IaTDLoQQ4pXo1KkTDRs2LOkw/pOHDx9SpkyZ/3QMfX199PX1X1BEr45KpeLx48clHYYQbySZEiOEEOK1sHPnTlq1akWZMmVQKpV4e3tz+vRprTp//PEHffv2pWrVqpiYmGBnZ8cnn3zC33//rakTGhpKYGAgAE5OTprpN0lJSSQlJaFQKIiOjs7VvkKhIDQ0VOs4CoWCM2fO0KtXL8qVK0fLli0129euXUuDBg0wNTXFysqKDz74gJSUlELzzGsOu6OjI126dCEhIYGGDRtiampK3bp1NdNOtmzZQt26dTExMaFBgwb8/vvvWsfs27cv5ubmXL58GU9PT8qUKYO9vT1Tp05FrVZr1X348CFjxozBwcEBY2NjXFxcmDt3bq56CoWCgIAA1q1bR+3atTE2NubLL7/ExsYGgLCwMM25zTlvRfl8nj63Fy9e1FwFsbS0pF+/fqSnp+c6Z2vXrqVx48aYmZlRrlw5Wrduze7du7XqFOX7R4jSSkbYhRBCvBL379/nzp07WmXly5cHYM2aNfTp0wdPT09mz55Neno6X3zxBS1btuT333/H0dERgLi4OC5fvky/fv2ws7Pj9OnTfPXVV5w+fZpDhw6hUCh47733uHDhAt9++y3z58/XtGFjY8Pt27efO+73338fZ2dnPv/8c02ndsaMGUyePBlfX18GDBjA7du3WbRoEa1bt+b3338v1jScixcv0qtXLwYPHsxHH33E3Llz8fHx4csvv2TChAkMGzYMgJkzZ+Lr68v58+fR0/u/cbfs7Gy8vLxo2rQp4eHh7Nq1iylTppCVlcXUqVMBUKvVdO3alb1799K/f3/c3Nz48ccfCQwM5H//+x/z58/Xiumnn35iw4YNBAQEUL58eerXr88XX3zB0KFDeffdd3nvvfcAqFevHlC0z+dpvr6+ODk5MXPmTI4dO8aKFSuwtbVl9uzZmjphYWGEhobSvHlzpk6dipGREYcPH+ann36iY8eOQNG/f4QotdRCCCHESxQVFaUG8nyp1Wr1gwcP1GXLllUPHDhQa7+bN2+qLS0ttcrT09NzHf/bb79VA+qff/5ZUzZnzhw1oL5y5YpW3StXrqgBdVRUVK7jAOopU6Zo3k+ZMkUNqD/88EOteklJSWp9fX31jBkztMpPnjypNjAwyFWe3/l4OrYqVaqoAfWBAwc0ZT/++KMaUJuamqqvXr2qKV+2bJkaUO/du1dT1qdPHzWg/vTTTzVlKpVK7e3trTYyMlLfvn1brVar1du2bVMD6unTp2vF1LNnT7VCoVBfvHhR63zo6empT58+rVX39u3buc5VjqJ+Pjnn9pNPPtGq++6776qtra017//880+1np6e+t1331VnZ2dr1VWpVGq1+vm+f4QorWRKjBBCiFdiyZIlxMXFab3gyajsvXv3+PDDD7lz547mpa+vT5MmTdi7d6/mGKamppqv//33X+7cuUPTpk0BOHbs2EuJe8iQIVrvt2zZgkqlwtfXVyteOzs7nJ2dteJ9HrVq1aJZs2aa902aNAHgnXfeoXLlyrnKL1++nOsYAQEBmq9zprQ8fvyY+Ph4AGJjY9HX1+ezzz7T2m/MmDGo1Wp27typVd6mTRtq1apV5Bye9/N59ty2atWKv//+m9TUVAC2bduGSqUiJCRE62pCTn7wfN8/QpRWMiVGCCHEK9G4ceM8bzr9888/gScd07xYWFhovv7nn38ICwsjJiaGW7duadW7f//+C4z2/zy7ss2ff/6JWq3G2dk5z/qGhobFaufpTjmApaUlAA4ODnmW3717V6tcT0+PqlWrapXVqFEDQDNf/urVq9jb26NUKrXqubq6arY/7dncC/O8n8+zOZcrVw54kpuFhQWXLl1CT0+vwD8anuf7R4jSSjrsQgghSpRKpQKezEO2s7PLtd3A4P9+Vfn6+nLgwAECAwNxc3PD3NwclUqFl5eX5jgFeXYOdY7s7Ox893l61DgnXoVCwc6dO/Nc7cXc3LzQOPKS38ox+ZWrn7lJ9GV4NvfCPO/n8yJye57vHyFKK/kuFkIIUaKqVasGgK2tLR4eHvnWu3v3Lnv27CEsLIyQkBBNec4I69Py65jnjOA++0ClZ0eWC4tXrVbj5OSkGcF+HahUKi5fvqwV04ULFwA0N11WqVKF+Ph4Hjx4oDXKfu7cOc32wuR3bp/n8ymqatWqoVKpOHPmDG5ubvnWgcK/f4QozWQOuxBCiBLl6emJhYUFn3/+OZmZmbm256zskjMa++zoa2RkZK59ctZKf7ZjbmFhQfny5fn555+1ypcuXVrkeN977z309fUJCwvLFYtarc61hOGrtHjxYq1YFi9ejKGhIe3btwegc+fOZGdna9UDmD9/PgqFgk6dOhXahpmZGZD73D7P51NU3bt3R09Pj6lTp+Yaoc9pp6jfP0KUZjLCLoQQokRZWFjwxRdf8PHHH/P222/zwQcfYGNjQ3JyMjt27KBFixYsXrwYCwsLWrduTXh4OJmZmbz11lvs3r2bK1eu5DpmgwYNAJg4cSIffPABhoaG+Pj4UKZMGQYMGMCsWbMYMGAADRs25Oeff9aMRBdFtWrVmD59OsHBwSQlJdG9e3eUSiVXrlxh69atDBo0iLFjx76w81NUJiYm7Nq1iz59+tCkSRN27tzJjh07mDBhgmbtdB8fH9q1a8fEiRNJSkqifv367N69m++++46RI0dqRqsLYmpqSq1atVi/fj01atTAysqKOnXqUKdOnSJ/PkVVvXp1Jk6cyLRp02jVqhXvvfcexsbGJCYmYm9vz8yZM4v8/SNEqVZCq9MIIYR4Q+QsY5iYmFhgvb1796o9PT3VlpaWahMTE3W1atXUffv2VR89elRT59q1a+p3331XXbZsWbWlpaX6/fffV1+/fj3PZQanTZumfuutt9R6enpayyimp6er+/fvr7a0tFQrlUq1r6+v+tatW/ku65izJOKzNm/erG7ZsqW6TJky6jJlyqhr1qypHj58uPr8+fNFOh/PLuvo7e2dqy6gHj58uFZZztKUc+bM0ZT16dNHXaZMGfWlS5fUHTt2VJuZmakrVKignjJlSq7lEB88eKAeNWqU2t7eXm1oaKh2dnZWz5kzR7NMYkFt5zhw4IC6QYMGaiMjI63zVtTPJ79zm9e5UavV6lWrVqnd3d3VxsbG6nLlyqnbtGmjjouL06pTlO8fIUorhVr9Cu5aEUIIIcRL07dvXzZt2kRaWlpJhyKEeAlkDrsQQgghhBCvMemwCyGEEEII8RqTDrsQQgghhBCvMZnDLoQQQgghxGtMRtiFEEIIIYR4jUmHXQghhBBCiNeYPDhJCB2gUqm4fv06SqUy38eGCyGEEOL1olarefDgAfb29ujp5T+OLh12IXTA9evXcXBwKOkwhBBCCFEMKSkpVKpUKd/t0mEXQgcolUoArly
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
2025-05-26 21:34:36 +08:00
}
],
"source": [
"\n",
"gc.collect()\n",
"\n",
"use_pca = False\n",
2025-06-04 13:50:02 +08:00
"type = 'light'\n",
2025-06-04 20:34:17 +08:00
"\n",
2025-06-10 15:22:25 +08:00
"train_data['label2'] = train_data.groupby('trade_date', group_keys=False).apply(lambda x: x.nsmallest(500, 'total_mv'))['future_return'].transform(\n",
" lambda x: pd.qcut(x, q=20, labels=False, duplicates='drop')\n",
2025-06-04 20:34:17 +08:00
")\n",
2025-06-10 15:22:25 +08:00
"test_data['label2'] = test_data.groupby('trade_date', group_keys=False).apply(lambda x: x.nsmallest(500, 'total_mv'))['future_return'].transform(\n",
" lambda x: pd.qcut(x, q=20, labels=False, duplicates='drop')\n",
2025-06-04 20:34:17 +08:00
")\n",
"\n",
2025-05-26 21:34:36 +08:00
"# feature_contri = [2 if feat.startswith('act_factor') or 'buy' in feat or 'sell' in feat else 1 for feat in feature_columns]\n",
"# light_params['feature_contri'] = feature_contri\n",
"# print(f'feature_contri: {feature_contri}')\n",
"model, scaler, pca = train_model(train_data\n",
2025-06-10 15:22:25 +08:00
" # .dropna(subset=['label'])\n",
" .groupby('trade_date', group_keys=False)\n",
" .apply(lambda x: x.nsmallest(500, 'total_mv'))\n",
2025-05-26 21:34:36 +08:00
" .merge(industry_df, on=['cat_l2_code', 'trade_date'], how='left')\n",
" .merge(index_data, on='trade_date', how='left'), \n",
2025-06-04 20:34:17 +08:00
" feature_columns, type=type, target_column='label2')\n"
2025-05-26 21:34:36 +08:00
]
},
{
"cell_type": "code",
2025-06-10 15:22:25 +08:00
"execution_count": 128,
2025-05-26 21:34:36 +08:00
"id": "5d1522a7538db91b",
"metadata": {
"ExecuteTime": {
"end_time": "2025-04-03T15:04:39.656944Z",
"start_time": "2025-04-03T15:04:39.298483Z"
}
},
"outputs": [],
"source": [
2025-06-10 15:22:25 +08:00
"score_df = test_data.groupby('trade_date', group_keys=False).apply(lambda x: x.nsmallest(500, 'total_mv'))\n",
2025-05-26 21:34:36 +08:00
"# score_df = fill_nan_with_daily_median(score_df, ['pe_ttm'])\n",
"# score_df = score_df[score_df['pe_ttm'] > 0]\n",
"score_df = score_df.merge(industry_df, on=['cat_l2_code', 'trade_date'], how='left')\n",
"score_df = score_df.merge(index_data, on='trade_date', how='left')\n",
"# score_df = score_df.groupby('trade_date', group_keys=False).apply(lambda x: x.nsmallest(50, 'total_mv')).reset_index()\n",
"numeric_columns = score_df.select_dtypes(include=['float64', 'int64']).columns\n",
"numeric_columns = [col for col in feature_columns if col in numeric_columns]\n",
"\n",
"if type == 'cat':\n",
" score_df['score'] = model.predict(score_df[feature_columns])\n",
"elif type == 'light':\n",
" score_df['score'] = model.predict(score_df[feature_columns])\n",
"score_df['score_ranks'] = score_df.groupby('trade_date')['score'].rank(ascending=True)\n",
"\n",
"score_df = score_df.groupby('trade_date', group_keys=False).apply(\n",
" lambda x: \n",
" x[\n",
" # (x['score'] <= x['score'].quantile(0.99)) & \n",
" (x['score'] >= x['score'].quantile(0.90))\n",
" ] # 计算90%分位数作为阈值,筛选分数>=阈值的行\n",
").reset_index(drop=True) # drop=True 避免添加旧索引列\n",
"# df_to_drop = score_df.loc[score_df.groupby('trade_date')['score'].idxmax()]\n",
"# score_df = score_df.drop(df_to_drop.index)\n",
2025-06-10 15:22:25 +08:00
"save_df = score_df.groupby('trade_date', group_keys=False).apply(lambda x: x.nlargest(2, 'score')).reset_index()\n",
2025-05-26 21:34:36 +08:00
"# save_df = score_df.groupby('trade_date', group_keys=False).apply(lambda x: x.nsmallest(2, 'total_mv')).reset_index()\n",
"save_df = save_df.sort_values(['trade_date', 'score'])\n",
2025-06-04 13:50:02 +08:00
"save_df[['trade_date', 'score', 'ts_code']].to_csv('predictions_test.tsv', index=False)\n"
2025-06-01 15:59:29 +08:00
]
},
{
"cell_type": "code",
2025-06-10 15:22:25 +08:00
"execution_count": 129,
2025-06-04 20:34:17 +08:00
"id": "fed2d6c3",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2023-01-03 00:00:00\n"
]
}
],
"source": [
"print(test_data['trade_date'].min())"
]
},
{
"cell_type": "code",
2025-06-10 15:22:25 +08:00
"execution_count": 130,
2025-06-01 15:59:29 +08:00
"id": "1f3c1331",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"成功连接到 Redis 服务器: 140.143.91.66:6389, 数据库 0\n",
"DataFrame 已使用 Pickle 序列化并写入 Redis, 键为 'save_df'\n",
"从 Redis 读取到的 Pickle 序列化数据 (前 20 字节):\n",
2025-06-10 15:22:25 +08:00
"b'\\x80\\x04\\x95\\x16u\\x00\\x00\\x00\\x00\\x00\\x00\\x8c\\x11pandas.'\n",
2025-06-01 15:59:29 +08:00
"\n",
"从 Redis 加载的 DataFrame (使用 Pickle):\n",
2025-06-04 20:34:17 +08:00
" index ts_code trade_date open close high low vol \\\n",
2025-06-10 15:22:25 +08:00
"1 35 603133.SH 2023-01-03 12.16 12.15 12.31 11.92 -0.734431 \n",
"0 30 603321.SH 2023-01-03 7.25 7.51 7.52 7.20 -0.849729 \n",
"3 99 603090.SH 2023-01-04 30.06 30.15 30.23 29.53 -1.002912 \n",
"2 79 603321.SH 2023-01-04 7.57 7.56 7.59 7.49 -0.889013 \n",
"5 125 002963.SZ 2023-01-05 14.91 14.65 14.91 14.53 -1.006545 \n",
2025-06-04 20:34:17 +08:00
"... ... ... ... ... ... ... ... ... \n",
2025-06-10 15:22:25 +08:00
"1164 29102 603177.SH 2025-06-04 8.61 8.67 8.69 8.54 -0.947619 \n",
"1167 29151 001211.SZ 2025-06-05 23.48 23.32 23.68 23.20 -1.106178 \n",
"1166 29152 603177.SH 2025-06-05 8.63 8.67 8.73 8.58 -0.941981 \n",
"1169 29202 603177.SH 2025-06-06 8.73 8.80 8.81 8.61 -0.895431 \n",
"1168 29207 605567.SH 2025-06-06 10.13 10.18 10.22 10.01 -0.825091 \n",
2025-06-04 20:34:17 +08:00
"\n",
" pct_chg amount ... 000905.SH_up_ratio_20d \\\n",
2025-06-10 15:22:25 +08:00
"1 -1.037757 36429.924 ... 0.30 \n",
"0 0.854849 15236.182 ... 0.30 \n",
"3 -0.214829 23633.991 ... 0.30 \n",
"2 -0.005288 12436.182 ... 0.30 \n",
"5 -0.614732 12390.384 ... 0.35 \n",
2025-06-04 20:34:17 +08:00
"... ... ... ... ... \n",
2025-06-10 15:22:25 +08:00
"1164 0.303152 20382.162 ... 0.60 \n",
"1167 -0.095098 16985.707 ... 0.60 \n",
"1166 0.143034 22574.106 ... 0.60 \n",
"1169 0.707494 25728.599 ... 0.55 \n",
"1168 0.407593 37040.344 ... 0.55 \n",
2025-06-01 15:59:29 +08:00
"\n",
2025-06-04 20:34:17 +08:00
" 399006.SZ_up_ratio_20d 000852.SH_volatility 000905.SH_volatility \\\n",
"1 0.40 1.036997 0.828596 \n",
"0 0.40 1.036997 0.828596 \n",
2025-06-10 15:22:25 +08:00
"3 0.35 1.037707 0.828639 \n",
"2 0.35 1.037707 0.828639 \n",
"5 0.35 1.071637 0.869955 \n",
2025-06-04 20:34:17 +08:00
"... ... ... ... \n",
2025-06-10 15:22:25 +08:00
"1164 0.45 0.942048 0.748797 \n",
"1167 0.45 0.954604 0.757642 \n",
"1166 0.45 0.954604 0.757642 \n",
"1169 0.40 0.941305 0.752701 \n",
"1168 0.40 0.941305 0.752701 \n",
2025-06-01 15:59:29 +08:00
"\n",
2025-06-04 20:34:17 +08:00
" 399006.SZ_volatility 000852.SH_volume_change_rate \\\n",
"1 0.935322 5.203088 \n",
"0 0.935322 5.203088 \n",
2025-06-10 15:22:25 +08:00
"3 0.938230 4.492401 \n",
"2 0.938230 4.492401 \n",
"5 1.120001 -1.639926 \n",
2025-06-04 20:34:17 +08:00
"... ... ... \n",
2025-06-10 15:22:25 +08:00
"1164 1.132207 -1.062074 \n",
"1167 1.154128 9.866900 \n",
"1166 1.154128 9.866900 \n",
"1169 1.103436 -4.268643 \n",
"1168 1.103436 -4.268643 \n",
2025-06-01 15:59:29 +08:00
"\n",
2025-06-04 20:34:17 +08:00
" 000905.SH_volume_change_rate 399006.SZ_volume_change_rate score \\\n",
2025-06-10 15:22:25 +08:00
"1 -0.750721 8.827360 0.527536 \n",
"0 -0.750721 8.827360 0.675714 \n",
"3 -0.552539 5.415982 0.524954 \n",
"2 -0.552539 5.415982 0.628866 \n",
"5 1.034360 1.155365 0.566081 \n",
2025-06-04 20:34:17 +08:00
"... ... ... ... \n",
2025-06-10 15:22:25 +08:00
"1164 -4.589323 1.251375 0.741102 \n",
"1167 4.793307 14.960862 0.776248 \n",
"1166 4.793307 14.960862 0.805704 \n",
"1169 -8.367502 -0.972266 0.705210 \n",
"1168 -8.367502 -0.972266 0.914980 \n",
2025-06-01 15:59:29 +08:00
"\n",
2025-06-04 20:34:17 +08:00
" score_ranks \n",
2025-06-10 15:22:25 +08:00
"1 499.0 \n",
"0 500.0 \n",
"3 499.0 \n",
"2 500.0 \n",
"5 499.0 \n",
2025-06-04 20:34:17 +08:00
"... ... \n",
2025-06-10 15:22:25 +08:00
"1164 500.0 \n",
"1167 499.0 \n",
"1166 500.0 \n",
"1169 499.0 \n",
"1168 500.0 \n",
2025-06-01 15:59:29 +08:00
"\n",
2025-06-10 15:22:25 +08:00
"[1170 rows x 251 columns]\n",
2025-06-01 15:59:29 +08:00
"\n",
"验证成功:原始 DataFrame 和从 Redis 加载的 DataFrame 一致。\n",
"\n",
"清理了 Redis 中键 'save_df' 的数据。\n"
]
}
],
"source": [
"import redis\n",
"import pickle\n",
"\n",
"redis_host = '140.143.91.66'\n",
"redis_port = 6389\n",
"redis_db = 0\n",
"redis_key = 'save_df'\n",
"\n",
"try:\n",
" # 1. 连接到 Redis 服务器\n",
" r = redis.Redis(host=redis_host, port=redis_port, db=redis_db, password='Redis520102')\n",
" r.ping()\n",
" print(f\"\\n成功连接到 Redis 服务器: {redis_host}:{redis_port},数据库 {redis_db}\")\n",
"\n",
" # 2. 将 DataFrame 写入 Redis (使用 Pickle 序列化)\n",
" df_serialized = pickle.dumps(save_df)\n",
" r.set(redis_key, df_serialized)\n",
" print(f\"DataFrame 已使用 Pickle 序列化并写入 Redis, 键为 '{redis_key}'\")\n",
"\n",
" # 3. 从 Redis 读取数据 (获取 Pickle 序列化的字节流)\n",
" retrieved_serialized = r.get(redis_key)\n",
"\n",
" if retrieved_serialized:\n",
" print(f\"从 Redis 读取到的 Pickle 序列化数据 (前 20 字节):\")\n",
" print(retrieved_serialized[:20])\n",
"\n",
" # 4. 使用 Pickle 反序列化回 Pandas DataFrame\n",
" loaded_df = pickle.loads(retrieved_serialized)\n",
" print(\"\\n从 Redis 加载的 DataFrame (使用 Pickle):\")\n",
" print(loaded_df)\n",
"\n",
" # 5. 验证原始 DataFrame 和加载的 DataFrame 是否一致\n",
" if save_df.equals(loaded_df):\n",
" print(\"\\n验证成功: 原始 DataFrame 和从 Redis 加载的 DataFrame 一致。\")\n",
" else:\n",
" print(\"\\n验证失败: 原始 DataFrame 和从 Redis 加载的 DataFrame 不一致!\")\n",
"\n",
" else:\n",
" print(f\"错误:无法从 Redis 获取键 '{redis_key}' 的值。\")\n",
"\n",
" # 6. 清理测试数据 (可选)\n",
" r.delete(redis_key)\n",
" print(f\"\\n清理了 Redis 中键 '{redis_key}' 的数据。\")\n",
"\n",
"except redis.exceptions.ConnectionError as e:\n",
" print(f\"无法连接到 Redis 服务器: {e}\")\n",
" print(\"请确保您的 Redis 服务器已启动并且主机和端口配置正确。\")\n",
"except redis.exceptions.TimeoutError as e:\n",
" print(f\"连接 Redis 服务器超时: {e}\")\n",
" print(\"请检查您的网络连接和 Redis 服务器状态。\")\n",
"except Exception as e:\n",
" print(f\"测试 Redis 时发生未知错误: {e}\")\n",
" print(f\"测试 Redis 时发生未知错误: {e}\")"
2025-05-26 21:34:36 +08:00
]
},
{
"cell_type": "code",
2025-06-10 15:22:25 +08:00
"execution_count": 131,
2025-05-26 21:34:36 +08:00
"id": "09b1799e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2025-06-10 15:22:25 +08:00
"207\n",
"['vol', 'pct_chg', 'turnover_rate', 'volume_ratio', 'winner_rate', 'cat_hot_concept_stock', 'concept_rank_pct_chg', 'concept_rank_turnover_rate', 'concept_rank_volume_ratio', 'holder_net_change_sum_10d', 'holder_increase_days_10d', 'holder_decrease_days_10d', 'holder_any_increase_flag_10d', 'holder_any_decrease_flag_10d', 'cat_senti_mom_vol_spike', 'cat_senti_pre_breakout', 'ts_turnover_rate_acceleration_5_20', 'ts_vol_sustain_10_30', 'cs_amount_outlier_10', 'ts_ff_to_total_turnover_ratio', 'ts_price_volume_trend_coherence_5_20', 'ts_ff_turnover_rate_surge_10', 'undist_profit_ps', 'ocfps', 'AR', 'BR', 'AR_BR', 'cashflow_to_ev_factor', 'book_to_price_ratio', 'turnover_rate_mean_5', 'variance_20', 'bbi_ratio_factor', 'daily_deviation', 'lg_elg_net_buy_vol', 'flow_lg_elg_intensity', 'sm_net_buy_vol', 'total_buy_vol', 'lg_elg_buy_prop', 'flow_struct_buy_change', 'lg_elg_net_buy_vol_change', 'flow_lg_elg_accel', 'chip_concentration_range', 'chip_skewness', 'floating_chip_proxy', 'cost_support_15pct_change', 'cat_winner_price_zone', 'flow_chip_consistency', 'profit_taking_vs_absorb', 'cat_is_positive', 'upside_vol', 'downside_vol', 'vol_ratio', 'return_skew', 'return_kurtosis', 'volume_change_rate', 'cat_volume_breakout', 'turnover_deviation', 'cat_turnover_spike', 'avg_volume_ratio', 'cat_volume_ratio_breakout', 'vol_spike', 'vol_std_5', 'atr_14', 'atr_6', 'obv', 'maobv_6', 'rsi_3', 'return_5', 'return_20', 'std_return_5', 'std_return_90', 'std_return_90_2', 'act_factor1', 'act_factor2', 'act_factor3', 'act_factor4', 'rank_act_factor1', 'rank_act_factor2', 'rank_act_factor3', 'cov', 'delta_cov', 'alpha_22_improved', 'alpha_003', 'alpha_007', 'alpha_013', 'vol_break', 'weight_roc5', 'smallcap_concentration', 'cost_stability', 'high_cost_break_days', 'liquidity_risk', 'turnover_std', 'mv_volatility', 'volume_growth', 'mv_growth', 'momentum_factor', 'resonance_factor', 'log_close', 'cat_vol_spike', 'up', 'down', 'obv_maobv_6', 'std_return_5_over_std_return_90', 'std_return_90_minus_std_return_90_2', 'cat_af2', 'cat_af3', 'cat_af4', 'act_factor5', 'act_factor6', 'active_buy_volume_large', 'active_buy_volume_big', 'active_buy_volume_small', 'buy_lg_vol_minus_sell_lg_vol', 'buy_elg_vol_minus_sell_elg_vol', 'ctrl_strength', 'low_cost_dev', 'asymmetry', 'lock_factor', 'cat_vol_break', 'cost_atr_adj', 'cat_golden_resonance', 'mv_turnover_ratio', 'mv_adjusted_volume', 'mv_weighted_turnover', 'nonlinear_mv_volume', 'mv_volume_ratio', 'mv_momentum', 'senti_strong_inflow', 'lg_flow_mom_corr_20_60', 'lg_flow_accel', 'profit_pressure', 'underwater_resistance', 'cost_conc_std_20', 'profit_decay_20', 'vol_amp_loss_20', 'vol_drop_profit_cnt_5', 'lg_flow_vol_interact_20', 'cost_break_confirm_cnt_5', 'atr_norm_channel_pos_14', 'turnover_diff_skew_20', 'lg_sm_flow_diverge_20', 'pullback_strong_20_20', 'vol_wgt_hist_pos_20', 'vol_adj_roc_20', 'cs_rank_net_lg_flow_val', 'cs_rank_elg_buy_ratio', 'cs_rank_rel_profit_margin', 'cs_rank_cost_breadth', 'cs_rank_dist_to_upper_cost', 'cs_rank_winner_rate', 'cs_rank_intraday_range', 'cs_rank_close_pos_in_range', 'cs_rank_pos_in_hist_range', 'cs_rank_vol_x_profit_margin', 'cs_rank_lg_flow_price_concordance', 'cs_rank_turnover_per_winner', 'cs_rank_volume_ratio', 'cs_rank_elg_buy_sell_sm_ratio', 'cs_rank_cost_dist_vol_ratio', 'industry_obv', 'industry_return_5', 'industry_return_20', 'industry__ema_5', 'industry__ema_13', 'industry__ema_20', 'industry__ema_60', 'industry_act_factor1', 'industry_act_factor2', 'industry_act_factor3', 'industry_act_factor4', 'industry_act_factor5', 'industry_act_factor6', 'industry_rank_act_factor1', 'industry_rank_act_factor2', 'industry_rank_act_factor3', 'industry_return_5_percentile', 'industry_return_20_percentile', '000852.SH_MACD', '000905.SH_MACD', '399006.SZ_MACD', '000852.SH_MACD_hist', '000905.SH_MACD_hist', '399006.SZ_MACD_hist', '000852.SH_RSI', '000905.SH_RSI', '399006.SZ_RSI', '000852.SH_Signal_line', '000905.SH_Signal_line', '399006.SZ_Signal_line', '000852.SH_amount_change_rate', '000905.SH_amount_change_rate', '399006.SZ_amount_change_rate', '000852.SH_
2025-05-26 21:34:36 +08:00
]
}
],
"source": [
"print(len(feature_columns))\n",
"print(feature_columns)"
]
},
{
"cell_type": "code",
2025-06-10 15:22:25 +08:00
"execution_count": 132,
2025-05-26 21:34:36 +08:00
"id": "bceabd1f",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2025-06-04 20:34:17 +08:00
"警告: DataFrame 中没有 'group_id' 列。假设整个 DataFrame 是一个需要排序的组。\n",
"\n",
"NDCG 结果\n",
2025-06-10 15:22:25 +08:00
"{'ndcg@1': np.float64(0.5102040816326531), 'ndcg@3': np.float64(0.6258109632386283), 'ndcg@5': np.float64(0.6760105470779576)}\n"
2025-05-26 21:34:36 +08:00
]
}
],
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"\n",
"def calculate_ndcg(df: pd.DataFrame, score_col: str, label_col: str, group_id: str = 'trade_date', k_values: list = [1, 3, 5, 10]):\n",
" \"\"\"\n",
" 计算 DataFrame 中 score 列和 label 列的 NDCG 值。\n",
"\n",
" Args:\n",
" df (pd.DataFrame): 包含 score (排序学习预测分数) 和 label (相关性标签) 的 DataFrame。\n",
" 假设每个需要排序的组(例如,每天的股票)在 DataFrame 中是连续的。\n",
" score_col (str): 包含模型预测分数的列名。\n",
" label_col (str): 包含相关性标签的列名。标签值越高表示相关性越高。\n",
" k_values (list): 一个整数列表,表示计算 NDCG 的 top-k 值。\n",
" 例如,[1, 3, 5] 将计算 NDCG@1, NDCG@3 和 NDCG@5。\n",
"\n",
" Returns:\n",
" dict: 一个字典,包含每个 k 值对应的平均 NDCG 值。\n",
" 例如: {'ndcg@1': 0.85, 'ndcg@3': 0.78, 'ndcg@5': 0.72, 'ndcg@10': 0.65}\n",
" \"\"\"\n",
" ndcg_scores = {f'ndcg@{k}': [] for k in k_values}\n",
"\n",
" def dcg_at_k(r, k):\n",
2025-06-04 20:34:17 +08:00
" r = np.asarray(r)[:k] if len(r) > 0 else np.zeros(k)\n",
2025-05-26 21:34:36 +08:00
" return np.sum(r / np.log2(np.arange(2, r.size + 2)))\n",
"\n",
" def ndcg_at_k(r, k):\n",
" dcg_max = dcg_at_k(sorted(r, reverse=True), k)\n",
" if not dcg_max:\n",
" return 0.\n",
" return dcg_at_k(r, k) / dcg_max\n",
"\n",
" # 假设 DataFrame 已经按照需要排序的组(例如,'trade_date')进行了分组,\n",
" # 并且每个组内的顺序不重要,我们只需要计算每个组的 NDCG。\n",
" # 如果需要按特定组计算 NDCG, 请先对 DataFrame 进行分组。\n",
" if group_id not in df.columns:\n",
" print(\"警告: DataFrame 中没有 'group_id' 列。假设整个 DataFrame 是一个需要排序的组。\")\n",
" group_df = df.sort_values(by=score_col, ascending=False)\n",
" relevant_labels = group_df[label_col].values\n",
" for k in k_values:\n",
" ndcg_scores[f'ndcg@{k}'].append(ndcg_at_k(relevant_labels, k))\n",
" else:\n",
" for _, group_df in df.groupby(group_id):\n",
" group_df_sorted = group_df.sort_values(by=score_col, ascending=False)\n",
" relevant_labels = group_df_sorted[label_col].values\n",
" for k in k_values:\n",
" ndcg_scores[f'ndcg@{k}'].append(ndcg_at_k(relevant_labels, k))\n",
"\n",
" avg_ndcg = {k: np.mean(v) if v else np.nan for k, v in ndcg_scores.items()}\n",
" return avg_ndcg\n",
"\n",
"\n",
"ndcg_results_single_group = calculate_ndcg(score_df, score_col='score', label_col='label', k_values=[1, 3, 5], group_id=None)\n",
"print(\"\\nNDCG 结果\")\n",
"print(ndcg_results_single_group)\n"
]
2025-05-29 20:41:18 +08:00
},
{
"cell_type": "code",
2025-06-10 15:22:25 +08:00
"execution_count": 133,
2025-05-29 20:41:18 +08:00
"id": "44f64679",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" ts_code trade_date open close high low vol pct_chg \\\n",
2025-06-10 15:22:25 +08:00
"1636230 002652.SZ 2019-01-02 19.59 19.64 19.89 19.28 20196.79 1.03 \n",
"1636231 002652.SZ 2019-01-03 19.74 19.44 19.84 19.33 15731.99 -1.02 \n",
"1636232 002652.SZ 2019-01-04 19.33 19.94 19.99 19.08 21099.93 2.57 \n",
"1636233 002652.SZ 2019-01-07 20.04 21.95 21.95 20.04 83534.19 10.08 \n",
"1636234 002652.SZ 2019-01-08 23.21 21.65 23.87 21.65 149377.97 -1.37 \n",
2025-05-29 20:41:18 +08:00
"... ... ... ... ... ... ... ... ... \n",
2025-06-10 15:22:25 +08:00
"1637782 002652.SZ 2025-05-30 15.36 15.11 15.41 14.95 107732.00 -1.63 \n",
"1637783 002652.SZ 2025-06-03 15.11 15.41 15.71 14.90 163459.00 1.99 \n",
"1637784 002652.SZ 2025-06-04 15.41 15.71 15.71 15.36 140521.00 1.95 \n",
"1637785 002652.SZ 2025-06-05 15.76 15.61 16.51 15.51 246994.40 -0.64 \n",
"1637786 002652.SZ 2025-06-06 15.71 16.01 16.36 15.66 228370.40 2.56 \n",
2025-05-29 20:41:18 +08:00
"\n",
2025-06-10 15:22:25 +08:00
" amount turnover_rate ... cs_rank_pos_in_hist_range \\\n",
"1636230 7867.047 0.3964 ... 0.730643 \n",
"1636231 6121.460 0.3088 ... 0.732202 \n",
"1636232 8245.083 0.4141 ... 0.727920 \n",
"1636233 35514.117 1.6394 ... 0.725182 \n",
"1636234 67160.354 2.9317 ... 0.726095 \n",
"... ... ... ... ... \n",
"1637782 32385.927 2.1039 ... 0.657143 \n",
"1637783 50114.396 3.1922 ... 0.657133 \n",
"1637784 43515.970 2.7442 ... 0.653207 \n",
"1637785 77669.905 4.8235 ... 0.652427 \n",
"1637786 72598.629 4.4598 ... 0.653092 \n",
2025-05-29 20:41:18 +08:00
"\n",
2025-06-10 15:22:25 +08:00
" cs_rank_vol_x_profit_margin cs_rank_lg_flow_price_concordance \\\n",
"1636230 0.608839 0.203142 \n",
"1636231 0.586710 0.156684 \n",
"1636232 0.682847 0.184009 \n",
"1636233 0.987591 0.734940 \n",
"1636234 0.765693 0.874042 \n",
"... ... ... \n",
"1637782 0.702990 0.705316 \n",
"1637783 0.842368 0.333222 \n",
"1637784 0.851113 0.101695 \n",
"1637785 0.490691 0.137965 \n",
"1637786 0.916556 0.923205 \n",
2025-05-29 20:41:18 +08:00
"\n",
2025-06-10 15:22:25 +08:00
" cs_rank_turnover_per_winner cs_rank_ind_cap_neutral_pe \\\n",
"1636230 0.864865 NaN \n",
"1636231 0.763417 NaN \n",
"1636232 0.660949 NaN \n",
"1636233 0.700000 NaN \n",
"1636234 0.914234 NaN \n",
"... ... ... \n",
"1637782 0.419934 NaN \n",
"1637783 0.466578 NaN \n",
"1637784 0.440678 NaN \n",
"1637785 0.686170 NaN \n",
"1637786 0.648604 NaN \n",
2025-05-29 20:41:18 +08:00
"\n",
2025-06-10 15:22:25 +08:00
" cs_rank_volume_ratio cs_rank_elg_buy_sell_sm_ratio \\\n",
"1636230 0.646930 0.341855 \n",
"1636231 0.251279 0.318912 \n",
"1636232 0.311724 0.260036 \n",
"1636233 0.988313 0.796350 \n",
"1636234 0.990142 0.598905 \n",
"... ... ... \n",
"1637782 0.537542 0.133056 \n",
"1637783 0.852843 0.129697 \n",
"1637784 0.726653 0.740113 \n",
"1637785 0.932846 0.645279 \n",
"1637786 0.863531 0.724069 \n",
2025-05-29 20:41:18 +08:00
"\n",
2025-06-10 15:22:25 +08:00
" cs_rank_cost_dist_vol_ratio future_return label \n",
"1636230 0.678941 0.158859 40.0 \n",
"1636231 0.402916 0.136831 37.0 \n",
"1636232 0.460713 0.106319 39.0 \n",
"1636233 0.988501 -0.072893 4.0 \n",
"1636234 0.991571 -0.057737 5.0 \n",
"... ... ... ... \n",
"1637782 0.703987 NaN NaN \n",
"1637783 0.895910 NaN NaN \n",
"1637784 0.820871 NaN NaN \n",
"1637785 0.958112 NaN NaN \n",
"1637786 0.912566 NaN NaN \n",
2025-05-29 20:41:18 +08:00
"\n",
2025-06-10 15:22:25 +08:00
"[1557 rows x 199 columns]\n"
2025-05-29 20:41:18 +08:00
]
}
],
"source": [
"print(df[df['ts_code'] == '002652.SZ'])"
]
2025-05-26 21:34:36 +08:00
}
],
"metadata": {
"kernelspec": {
2025-06-04 13:50:02 +08:00
"display_name": "stock",
2025-05-26 21:34:36 +08:00
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
2025-06-04 13:50:02 +08:00
"version": "3.13.2"
2025-05-26 21:34:36 +08:00
}
},
"nbformat": 4,
"nbformat_minor": 5
}