Files
ProStock/src/experiment/common.py

449 lines
11 KiB
Python
Raw Normal View History

"""实验脚本的共用配置和辅助函数。
此模块包含 regression.py learn_to_rank.py 共用的代码
避免重复维护两份相同的配置和函数
"""
from datetime import datetime
from typing import List
import polars as pl
from src.factors import FactorEngine
# =============================================================================
# 日期范围配置(正确的 train/val/test 三分法)
# =============================================================================
TRAIN_START = "20200101"
TRAIN_END = "20231231"
VAL_START = "20240101"
VAL_END = "20241231"
TEST_START = "20250101"
TEST_END = "20261231"
# =============================================================================
# 因子配置
# =============================================================================
# 当前选择的因子列表(从 FACTOR_DEFINITIONS 中选择要使用的因子)
SELECTED_FACTORS = [
"ma_5",
"ma_20",
"ma_ratio_5_20",
"bias_10",
"high_low_ratio",
"bbi_ratio",
"return_5",
"return_20",
"kaufman_ER_20",
"mom_acceleration_10_20",
"drawdown_from_high_60",
"up_days_ratio_20",
"volatility_5",
"volatility_20",
"volatility_ratio",
"std_return_20",
"sharpe_ratio_20",
"min_ret_20",
"volatility_squeeze_5_60",
"overnight_intraday_diff",
"upper_shadow_ratio",
"capital_retention_20",
"max_ret_20",
"volume_ratio_5_20",
"turnover_rate_mean_5",
"turnover_deviation",
"amihud_illiq_20",
"turnover_cv_20",
"pv_corr_20",
"close_vwap_deviation",
"roe",
"roa",
"profit_margin",
"debt_to_equity",
"current_ratio",
"net_profit_yoy",
"revenue_yoy",
"healthy_expansion_velocity",
"EP",
"BP",
"CP",
"market_cap_rank",
"turnover_rank",
"return_5_rank",
"EP_rank",
"pe_expansion_trend",
"value_price_divergence",
"active_market_cap",
"ebit_rank",
"GTJA_alpha001",
"GTJA_alpha002",
"GTJA_alpha003",
"GTJA_alpha004",
"GTJA_alpha005",
"GTJA_alpha006",
"GTJA_alpha007",
"GTJA_alpha008",
"GTJA_alpha009",
"GTJA_alpha010",
"GTJA_alpha011",
"GTJA_alpha012",
"GTJA_alpha013",
"GTJA_alpha014",
"GTJA_alpha015",
"GTJA_alpha016",
"GTJA_alpha017",
"GTJA_alpha018",
"GTJA_alpha019",
"GTJA_alpha020",
"GTJA_alpha022",
"GTJA_alpha023",
"GTJA_alpha024",
"GTJA_alpha025",
"GTJA_alpha026",
"GTJA_alpha027",
"GTJA_alpha028",
"GTJA_alpha029",
"GTJA_alpha031",
"GTJA_alpha032",
"GTJA_alpha033",
"GTJA_alpha034",
"GTJA_alpha035",
"GTJA_alpha036",
"GTJA_alpha037",
# "GTJA_alpha038",
"GTJA_alpha039",
"GTJA_alpha040",
"GTJA_alpha041",
"GTJA_alpha042",
"GTJA_alpha043",
"GTJA_alpha044",
"GTJA_alpha045",
"GTJA_alpha046",
"GTJA_alpha047",
"GTJA_alpha048",
"GTJA_alpha049",
"GTJA_alpha050",
"GTJA_alpha051",
"GTJA_alpha052",
"GTJA_alpha053",
"GTJA_alpha054",
"GTJA_alpha056",
"GTJA_alpha057",
"GTJA_alpha058",
"GTJA_alpha059",
"GTJA_alpha060",
"GTJA_alpha061",
"GTJA_alpha062",
"GTJA_alpha063",
"GTJA_alpha064",
"GTJA_alpha065",
"GTJA_alpha066",
"GTJA_alpha067",
"GTJA_alpha068",
"GTJA_alpha070",
"GTJA_alpha071",
"GTJA_alpha072",
"GTJA_alpha073",
"GTJA_alpha074",
"GTJA_alpha076",
"GTJA_alpha077",
"GTJA_alpha078",
"GTJA_alpha079",
"GTJA_alpha080",
"GTJA_alpha081",
"GTJA_alpha082",
"GTJA_alpha083",
"GTJA_alpha084",
"GTJA_alpha085",
"GTJA_alpha086",
"GTJA_alpha087",
"GTJA_alpha088",
"GTJA_alpha089",
"GTJA_alpha090",
"GTJA_alpha091",
"GTJA_alpha092",
"GTJA_alpha093",
"GTJA_alpha094",
"GTJA_alpha095",
"GTJA_alpha096",
"GTJA_alpha097",
"GTJA_alpha098",
"GTJA_alpha099",
"GTJA_alpha100",
"GTJA_alpha101",
"GTJA_alpha102",
"GTJA_alpha103",
"GTJA_alpha104",
"GTJA_alpha105",
"GTJA_alpha106",
"GTJA_alpha107",
"GTJA_alpha108",
"GTJA_alpha109",
"GTJA_alpha110",
"GTJA_alpha111",
"GTJA_alpha112",
"GTJA_alpha113",
"GTJA_alpha114",
"GTJA_alpha115",
"GTJA_alpha117",
"GTJA_alpha118",
"GTJA_alpha119",
"GTJA_alpha120",
"GTJA_alpha121",
"GTJA_alpha122",
"GTJA_alpha123",
"GTJA_alpha124",
"GTJA_alpha125",
"GTJA_alpha126",
"GTJA_alpha127",
"GTJA_alpha128",
"GTJA_alpha129",
"GTJA_alpha130",
"GTJA_alpha131",
"GTJA_alpha132",
"GTJA_alpha133",
"GTJA_alpha134",
"GTJA_alpha135",
"GTJA_alpha136",
"GTJA_alpha138",
"GTJA_alpha139",
"GTJA_alpha140",
"GTJA_alpha141",
"GTJA_alpha142",
"GTJA_alpha145",
"GTJA_alpha146",
"GTJA_alpha148",
"GTJA_alpha150",
"GTJA_alpha151",
"GTJA_alpha152",
"GTJA_alpha153",
"GTJA_alpha154",
"GTJA_alpha155",
"GTJA_alpha156",
"GTJA_alpha157",
"GTJA_alpha158",
"GTJA_alpha159",
"GTJA_alpha160",
"GTJA_alpha161",
"GTJA_alpha162",
"GTJA_alpha163",
"GTJA_alpha164",
"GTJA_alpha165",
"GTJA_alpha166",
"GTJA_alpha167",
"GTJA_alpha168",
"GTJA_alpha169",
"GTJA_alpha170",
"GTJA_alpha171",
"GTJA_alpha173",
"GTJA_alpha174",
"GTJA_alpha175",
"GTJA_alpha176",
"GTJA_alpha177",
"GTJA_alpha178",
"GTJA_alpha179",
"GTJA_alpha180",
"GTJA_alpha183",
"GTJA_alpha184",
"GTJA_alpha185",
"GTJA_alpha187",
"GTJA_alpha188",
"GTJA_alpha189",
"GTJA_alpha191",
]
# 因子定义字典完整因子库用于存放尚未注册到metadata的因子
FACTOR_DEFINITIONS = {
}
def get_label_factor(label_name: str) -> dict:
"""获取Label因子定义字典。
Args:
label_name: label因子名称
Returns:
Label因子定义字典
"""
return {
label_name: "(ts_delay(close, -5) / ts_delay(open, -1)) - 1",
}
# =============================================================================
# 辅助函数
# =============================================================================
def register_factors(
engine: FactorEngine,
selected_factors: List[str],
factor_definitions: dict,
label_factor: dict,
) -> List[str]:
"""注册因子。
selected_factors metadata 查询factor_definitions DSL 表达式注册
Args:
engine: FactorEngine实例
selected_factors: 从metadata中选择的因子名称列表
factor_definitions: 通过表达式定义的因子字典
label_factor: label因子定义字典
Returns:
特征列名称列表
"""
print("=" * 80)
print("注册因子")
print("=" * 80)
# 注册 SELECTED_FACTORS 中的因子(已在 metadata 中)
print("\n注册特征因子(从 metadata:")
for name in selected_factors:
engine.add_factor(name)
print(f" - {name}")
# 注册 FACTOR_DEFINITIONS 中的因子(通过表达式,尚未在 metadata 中)
print("\n注册特征因子(表达式):")
for name, expr in factor_definitions.items():
engine.add_factor(name, expr)
print(f" - {name}: {expr}")
# 注册 label 因子(通过表达式)
print("\n注册 Label 因子(表达式):")
for name, expr in label_factor.items():
engine.add_factor(name, expr)
print(f" - {name}: {expr}")
# 特征列 = SELECTED_FACTORS + FACTOR_DEFINITIONS 的 keys
feature_cols = selected_factors + list(factor_definitions.keys())
print(f"\n特征因子数: {len(feature_cols)}")
print(f" - 来自 metadata: {len(selected_factors)}")
print(f" - 来自表达式: {len(factor_definitions)}")
print(f"Label: {list(label_factor.keys())[0]}")
print(f"已注册因子总数: {len(engine.list_registered())}")
return feature_cols
def prepare_data(
engine: FactorEngine,
feature_cols: List[str],
start_date: str,
end_date: str,
label_name: str,
) -> pl.DataFrame:
"""准备数据。
计算因子并返回包含特征和label的数据框
Args:
engine: FactorEngine实例
feature_cols: 特征列名称列表
start_date: 开始日期 (YYYYMMDD)
end_date: 结束日期 (YYYYMMDD)
label_name: label列名称
Returns:
包含因子计算结果的数据框
"""
print("\n" + "=" * 80)
print("准备数据")
print("=" * 80)
# 计算因子(全市场数据)
print(f"\n计算因子: {start_date} - {end_date}")
factor_names = feature_cols + [label_name] # 包含 label
data = engine.compute(
factor_names=factor_names,
start_date=start_date,
end_date=end_date,
)
print(f"数据形状: {data.shape}")
print(f"数据列: {data.columns}")
print(f"\n前5行预览:")
print(data.head())
return data
# =============================================================================
# 股票池筛选配置
# =============================================================================
def stock_pool_filter(df: pl.DataFrame) -> pl.Series:
"""股票池筛选函数(单日数据)。
筛选条件
1. 排除创业板代码以 300 开头
2. 排除科创板代码以 688 开头
3. 排除北交所代码以 89 4 开头
4. 选取当日市值最小的500只股票
Args:
df: 单日数据框
Returns:
布尔Series表示哪些股票被选中
"""
# 代码筛选(排除创业板、科创板、北交所)
code_filter = (
~df["ts_code"].str.starts_with("30") # 排除创业板
& ~df["ts_code"].str.starts_with("68") # 排除科创板
& ~df["ts_code"].str.starts_with("8") # 排除北交所
& ~df["ts_code"].str.starts_with("9") # 排除北交所
& ~df["ts_code"].str.starts_with("4") # 排除北交所
)
# 在已筛选的股票中选取市值最小的500只
valid_df = df.filter(code_filter)
n = min(500, len(valid_df))
small_cap_codes = valid_df.sort("total_mv").head(n)["ts_code"]
# 返回布尔 Series是否在被选中的股票中
return df["ts_code"].is_in(small_cap_codes)
# 定义筛选所需的基础列
STOCK_FILTER_REQUIRED_COLUMNS = ["total_mv"]
# =============================================================================
# 输出配置
# =============================================================================
OUTPUT_DIR = "output"
SAVE_PREDICTIONS = True
PERSIST_MODEL = False
# Top N 配置:每日推荐股票数量
TOP_N = 5 # 可调整为 10, 20 等
def get_output_path(model_type: str, test_start: str, test_end: str) -> str:
"""生成输出文件路径。
Args:
model_type: 模型类型"regression" "rank"
test_start: 测试开始日期
test_end: 测试结束日期
Returns:
输出文件路径
"""
import os
# 确保输出目录存在
os.makedirs(OUTPUT_DIR, exist_ok=True)
# 生成文件名
start_dt = datetime.strptime(test_start, "%Y%m%d")
end_dt = datetime.strptime(test_end, "%Y%m%d")
date_str = f"{start_dt.strftime('%Y%m%d')}_{end_dt.strftime('%Y%m%d')}"
filename = f"{model_type}_output.csv"
return os.path.join(OUTPUT_DIR, filename)