diff --git a/src/training/main.py b/src/training/main.py index 4876dad..204811c 100644 --- a/src/training/main.py +++ b/src/training/main.py @@ -5,26 +5,298 @@ 或: uv run python src/training/main.py + +本脚本提供两种运行方式: +1. run_full_pipeline(): 完整训练流程(数据准备 -> 训练 -> 预测) +2. prepare_data_and_train() + train_and_predict(): 分步执行,便于调试和调整 + +因子配置示例: + from src.factors import MovingAverageFactor, ReturnRankFactor + + # 直接传入因子实例列表 - 最简单的方式 + factors = [ + MovingAverageFactor(period=5), + MovingAverageFactor(period=10), + MovingAverageFactor(period=20), + ReturnRankFactor(period=5), + ReturnRankFactor(period=10), + ] + + # 运行完整流程 + result = run_full_pipeline(factors=factors) """ -from src.training.pipeline import run_training +from pathlib import Path +from typing import Optional, List + +import polars as pl + +from src.factors import BaseFactor +from src.training.pipeline import ( + FactorConfig, + predict_top_stocks, + prepare_data, + save_top_stocks, + train_model, +) + + +def prepare_data_and_train( + factors: Optional[List[BaseFactor]] = None, + data_dir: str = "data", + train_start: str = "20190101", + train_end: str = "20231231", + val_start: str = "20240102", + val_end: str = "20240531", + test_start: str = "20240602", + test_end: str = "20241231", +) -> tuple[pl.DataFrame, pl.DataFrame, pl.DataFrame, FactorConfig, str]: + """第一步:数据处理 + + 加载原始数据,计算因子和标签,拆分训练/验证/测试集。 + + Args: + factors: 因子实例列表,默认为 None(使用 MA5, MA10, ReturnRank5) + data_dir: 数据目录 + train_start: 训练集开始日期 + train_end: 训练集结束日期 + val_start: 验证集开始日期 + val_end: 验证集结束日期 + test_start: 测试集开始日期 + test_end: 测试集结束日期 + + Returns: + tuple: (train_data, val_data, test_data, factor_config, label_col) + """ + print("=" * 50) + print("[Step 1] 数据处理") + print("=" * 50) + print(f"训练集: {train_start} -> {train_end}") + print(f"验证集: {val_start} -> {val_end}") + print(f"测试集: {test_start} -> {test_end}") + print() + + # 1. 准备数据 + train_data, val_data, test_data, factor_config = prepare_data( + factors=factors, + data_dir=data_dir, + train_start=train_start, + train_end=train_end, + val_start=val_start, + val_end=val_end, + test_start=test_start, + test_end=test_end, + ) + + print(f"训练集样本数: {len(train_data)}") + print(f"验证集样本数: {len(val_data)}") + print(f"测试集样本数: {len(test_data)}") + print() + + # 打印少量数据样本展示 + print("=" * 50) + print("[数据预览] 训练集前3行:") + print(train_data.head(3)) + print() + print("[数据预览] 验证集前3行:") + print(val_data.head(3)) + print() + print("[数据预览] 测试集前3行:") + print(test_data.head(3)) + print() + + # 2. 获取特征列名 + feature_cols = factor_config.get_feature_names() + label_col = "label" + + print(f"特征列: {feature_cols}") + print(f"标签列: {label_col}") + print() + + return train_data, val_data, test_data, factor_config, label_col + + +def train_and_predict( + train_data: pl.DataFrame, + val_data: pl.DataFrame, + test_data: pl.DataFrame, + factor_config: FactorConfig, + label_col: str = "label", + top_n: int = 5, + output_path: str = "output/top_stocks.tsv", +) -> pl.DataFrame: + """第二步:训练和预测 + + 使用处理好的数据训练模型,进行测试集预测并保存结果。 + + Args: + train_data: 训练数据 + val_data: 验证数据 + test_data: 测试数据 + factor_config: 因子配置对象 + label_col: 标签列名 + top_n: 每日选股数量 + output_path: 输出文件路径 + + Returns: + 选股结果DataFrame + """ + print("=" * 50) + print("[Step 2] 模型训练与预测") + print("=" * 50) + print() + + # 获取特征列名 + feature_cols = factor_config.get_feature_names() + print(f"使用特征: {feature_cols}") + print() + + # 3. 训练模型 + print("[Training] Training model...") + model, pipeline = train_model( + train_data=train_data, + val_data=val_data, + feature_cols=feature_cols, + label_col=label_col, + ) + print() + + # 4. 测试集预测 + print("[Predict] Predicting on test set...") + top_stocks = predict_top_stocks( + model=model, + pipeline=pipeline, + test_data=test_data, + feature_cols=feature_cols, + top_n=top_n, + ) + print() + + # 5. 保存结果 + print(f"[Saving] Saving results to {output_path}...") + Path(output_path).parent.mkdir(parents=True, exist_ok=True) + save_top_stocks(top_stocks, output_path) + print() + + return top_stocks + + +def run_full_pipeline( + factors: Optional[List[BaseFactor]] = None, + train_start: str = "20190101", + train_end: str = "20231231", + val_start: str = "20240102", + val_end: str = "20240531", + test_start: str = "20240602", + test_end: str = "20241231", + top_n: int = 5, + output_path: str = "output/top_stocks.tsv", +) -> pl.DataFrame: + """运行完整训练流程 + + 相当于依次调用 prepare_data_and_train 和 train_and_predict。 + + Args: + factors: 因子实例列表,默认为 None(使用 MA5, MA10, ReturnRank5) + train_start: 训练集开始日期 + train_end: 训练集结束日期 + val_start: 验证集开始日期 + val_end: 验证集结束日期 + test_start: 测试集开始日期 + test_end: 测试集结束日期 + top_n: 每日选股数量 + output_path: 输出文件路径 + + Returns: + 选股结果DataFrame + """ + # 第一步:数据处理 + train_data, val_data, test_data, factor_config, label_col = prepare_data_and_train( + factors=factors, + train_start=train_start, + train_end=train_end, + val_start=val_start, + val_end=val_end, + test_start=test_start, + test_end=test_end, + ) + + # 第二步:训练和预测 + result = train_and_predict( + train_data=train_data, + val_data=val_data, + test_data=test_data, + factor_config=factor_config, + label_col=label_col, + top_n=top_n, + output_path=output_path, + ) + + print("=" * 50) + print("[Done] 训练流程完成!") + print("=" * 50) + + return result if __name__ == "__main__": - # 运行完整训练流程 - # 训练集:20190101 - 20231231 - # 验证集:20240102 - 20240531 (与训练集间隔1天,避免数据泄露) - # 测试集:20240602 - 20241231 (与验证集间隔1天,避免数据泄露) - result = run_training( + from src.factors import MovingAverageFactor, ReturnRankFactor + + # ========== 因子配置 ========== + # 直接传入因子实例列表 - 简单直观 + factors = [ + MovingAverageFactor(period=5), # 5日移动平均线 + MovingAverageFactor(period=10), # 10日移动平均线 + MovingAverageFactor(period=20), # 20日移动平均线 + ReturnRankFactor(period=5), # 5日收益率排名 + ReturnRankFactor(period=10), # 10日收益率排名 + ] + + # ========== 运行方式 ========== + + # 方式一:完整流程(一次性执行) + # result = run_full_pipeline( + # factors=factors, + # train_start="20190101", + # train_end="20231231", + # val_start="20240102", + # val_end="20240531", + # test_start="20240602", + # test_end="20241231", + # top_n=5, + # output_path="output/top_stocks.tsv", + # ) + + # 方式二:分步执行(便于调试) + # 第一步:数据处理 + train_data, val_data, test_data, factor_config, label_col = prepare_data_and_train( + factors=factors, train_start="20190101", train_end="20231231", val_start="20240102", val_end="20240531", test_start="20240602", test_end="20241231", - top_n=5, - output_path="output/top_stocks.tsv", ) - print("\n[Result] Top stocks selection:") - print(result) + # 可在此处添加自定义逻辑,例如: + # - 查看数据分布 + # - 调整特征 + # - 保存中间结果 + print("\n[Info] 因子配置详情:") + print(f" 因子列表: {factor_config.get_feature_names()}") + print(f" 最大回溯天数: {factor_config.get_max_lookback()}") + + # 第二步:训练和预测 + # result = train_and_predict( + # train_data=train_data, + # val_data=val_data, + # test_data=test_data, + # factor_config=factor_config, + # label_col=label_col, + # top_n=5, + # output_path="output/top_stocks.tsv", + # ) + # + # print("\n[Result] Top stocks selection:") + # print(result) diff --git a/src/training/pipeline.py b/src/training/pipeline.py index cf5b859..483916c 100644 --- a/src/training/pipeline.py +++ b/src/training/pipeline.py @@ -1,21 +1,48 @@ """训练管道 - 包含数据处理、模型训练和预测功能 本模块提供: -1. 数据准备:从因子计算结果中准备训练/测试数据 +1. 数据准备:使用 FactorEngine 从因子计算结果中准备训练/测试数据 2. 数据处理:Fillna(0) -> Dropna 3. 模型训练:使用LightGBM训练分类模型 4. 预测和选股:输出每日top5股票池 + +注意:本模块使用 src.factors 框架进行因子计算,确保防泄露机制生效。 + +因子配置示例: + from src.factors import MovingAverageFactor, ReturnRankFactor + from src.training.pipeline import prepare_data + + # 直接传入因子实例列表 - 简单直观 + factors = [ + MovingAverageFactor(period=5), + MovingAverageFactor(period=10), + ReturnRankFactor(period=5), + ] + + train_data, val_data, test_data, factor_config = prepare_data( + factors=factors, + ... + ) + + # 或者使用 FactorConfig 包装(支持链式添加) + from src.training.pipeline import FactorConfig + + config = FactorConfig() + .add(MovingAverageFactor(period=5)) + .add(MovingAverageFactor(period=10)) + .add(ReturnRankFactor(period=5)) """ from datetime import datetime from pathlib import Path -from typing import List, Optional, Tuple +from typing import List, Optional import numpy as np import polars as pl -from src.factors import DataLoader, FactorEngine +from src.factors import DataLoader, FactorEngine, BaseFactor from src.factors.data_spec import DataSpec +from src.factors.momentum import MovingAverageFactor, ReturnRankFactor from src.pipeline import ( DropNAProcessor, FillNAProcessor, @@ -26,7 +53,98 @@ from src.pipeline import ( ) +# ========== 因子配置类 ========== + + +class FactorConfig: + """因子配置类 - 管理因子列表 + + 用于包装因子实例列表,支持链式添加。 + + 示例: + # 方式1:初始化时传入列表 + config = FactorConfig([ + MovingAverageFactor(period=5), + ReturnRankFactor(period=5), + ]) + + # 方式2:链式添加 + config = FactorConfig() + .add(MovingAverageFactor(period=5)) + .add(ReturnRankFactor(period=5)) + + # 获取因子实例列表 + factors = config.get_factors() + + # 获取特征列名 + feature_cols = config.get_feature_names() + """ + + def __init__(self, factors: Optional[List[BaseFactor]] = None): + """初始化因子配置 + + Args: + factors: 因子实例列表 + """ + self._factors: List[BaseFactor] = factors or [] + + def add(self, factor: BaseFactor) -> "FactorConfig": + """添加因子到配置 + + 支持链式调用: + config = FactorConfig() + .add(MovingAverageFactor(period=5)) + .add(ReturnRankFactor(period=5)) + + Args: + factor: 因子实例 + + Returns: + self,支持链式调用 + """ + if not isinstance(factor, BaseFactor): + raise ValueError(f"必须是 BaseFactor 实例, got {type(factor)}") + self._factors.append(factor) + return self + + def get_factors(self) -> List[BaseFactor]: + """获取因子实例列表 + + Returns: + 因子实例列表 + """ + return self._factors + + def get_feature_names(self) -> List[str]: + """获取所有因子的特征列名 + + Returns: + 特征列名列表 + """ + return [f.name for f in self._factors] + + def get_max_lookback(self) -> int: + """获取所有因子中最大的 lookback 天数 + + Returns: + 最大 lookback 天数 + """ + max_lookback = 0 + for factor in self._factors: + for spec in factor.data_specs: + max_lookback = max(max_lookback, spec.lookback_days) + return max_lookback + + def __len__(self) -> int: + return len(self._factors) + + def __repr__(self) -> str: + names = [f.name for f in self._factors] + return f"FactorConfig({names})" + + def prepare_data( + factors: Optional[List[BaseFactor]] = None, data_dir: str = "data", train_start: str = "20180101", train_end: str = "20230101", @@ -34,12 +152,13 @@ def prepare_data( val_end: str = "20230601", test_start: str = "20230601", test_end: str = "20240101", -) -> tuple[pl.DataFrame, pl.DataFrame, pl.DataFrame]: +) -> tuple[pl.DataFrame, pl.DataFrame, pl.DataFrame, FactorConfig]: """准备训练、验证和测试数据 - 从DuckDB加载原始日线数据,计算所需因子并生成标签。 + 使用 FactorEngine 计算因子,确保防泄露机制生效。 Args: + factors: 因子实例列表,默认为 None(使用 MA5, MA10, ReturnRank5) data_dir: 数据目录 train_start: 训练集开始日期 train_end: 训练集结束日期 @@ -49,44 +168,107 @@ def prepare_data( test_end: 测试集结束日期 Returns: - (train_data, val_data, test_data): 训练集、验证集和测试集的DataFrame + (train_data, val_data, test_data, factor_config): + 训练集、验证集、测试集的DataFrame,以及使用的因子配置 """ from src.data.storage import Storage storage = Storage() - # 加载日线数据(需要更多历史数据用于计算因子) - # 训练集需要更多历史数据(用于计算因子lookback) - lookback_days = 20 # 足够计算MA10和5日收益率 + # 1. 处理因子配置 + if factors is None: + # 默认因子配置 + factor_config = FactorConfig( + [ + MovingAverageFactor(period=5), + MovingAverageFactor(period=10), + ReturnRankFactor(period=5), + ] + ) + elif isinstance(factors, FactorConfig): + factor_config = factors + elif isinstance(factors, list): + # 转换为 FactorConfig + factor_config = FactorConfig(factors) + else: + raise ValueError( + f"factors 必须是 List[BaseFactor] 或 FactorConfig, got {type(factors)}" + ) + + factors_list = factor_config.get_factors() + feature_cols = factor_config.get_feature_names() + + if not factors_list: + raise ValueError("至少需要提供一个因子") + + print(f"[PrepareData] 使用因子: {feature_cols}") + + # 2. 初始化 FactorEngine + loader = DataLoader(data_dir=data_dir) + engine = FactorEngine(loader) + + # 3. 计算需要的回溯天数 + max_lookback = factor_config.get_max_lookback() start_with_lookback = str(int(train_start) - 10000) # 往前取一年 - # 查询全部数据(包含train、val、test),然后再拆分 - # 注意:DuckDB 中 trade_date 是 DATE 类型,需要转换 + # 获取所有股票列表 start_dt = f"{start_with_lookback[:4]}-{start_with_lookback[4:6]}-{start_with_lookback[6:8]}" end_dt = f"{test_end[:4]}-{test_end[4:6]}-{test_end[6:8]}" - - all_query = f""" - SELECT ts_code, trade_date, close, pre_close - FROM daily + all_stocks_query = f""" + SELECT DISTINCT ts_code FROM daily WHERE trade_date >= '{start_dt}' AND trade_date <= '{end_dt}' - ORDER BY ts_code, trade_date """ - all_raw = storage._connection.sql(all_query).pl() - # 转换 trade_date 为字符串格式 - all_raw = all_raw.with_columns( - pl.col("trade_date").dt.strftime("%Y-%m-%d").alias("trade_date") - ) + all_stocks_df = storage._connection.sql(all_stocks_query).pl() + all_stocks = all_stocks_df["ts_code"].to_list() - # 过滤不符合条件的股票 - all_raw = _filter_invalid_stocks(all_raw) - print(f"[PrepareData] After filtering: total={len(all_raw)}") + print(f"[PrepareData] 股票数量: {len(all_stocks)}") - # 计算因子和标签(需要全局数据一次性计算) - all_data = _compute_features_and_label( - all_raw, - start_date=train_start, - end_date=test_end - ) + # 4. 计算所有因子并合并 + all_features = None + + for factor in factors_list: + print(f"[PrepareData] 计算因子: {factor.name} ({factor.factor_type})") + + if factor.factor_type == "time_series": + # 时序因子需要传入股票列表 + result = engine.compute( + factor, + stock_codes=all_stocks, + start_date=start_with_lookback, + end_date=test_end, + ) + else: + # 截面因子不需要股票列表 + result = engine.compute( + factor, + start_date=start_with_lookback, + end_date=test_end, + ) + + # 合并结果 + if all_features is None: + all_features = result + else: + # 确保没有重复的 _right 列 + result = result.select([ + c for c in result.columns if not c.endswith("_right") + ]) + all_features = all_features.select([ + c for c in all_features.columns if not c.endswith("_right") + ]) + all_features = all_features.join( + result, on=["trade_date", "ts_code"], how="outer" + ) + + if all_features is None: + raise ValueError("没有计算任何因子") + + # 5. 计算标签(未来5日收益率) + all_data = _compute_label(all_features, start_date=train_start, end_date=test_end) + + # 6. 过滤不符合条件的股票 + all_data = _filter_invalid_stocks(all_data) + print(f"[PrepareData] After filtering: total={len(all_data)}") # 转换日期格式用于比较 train_start_fmt = f"{train_start[:4]}-{train_start[4:6]}-{train_start[6:8]}" @@ -98,18 +280,22 @@ def prepare_data( # 拆分数据 train_data = all_data.filter( - (pl.col("trade_date") >= train_start_fmt) & (pl.col("trade_date") <= train_end_fmt) + (pl.col("trade_date") >= train_start_fmt) + & (pl.col("trade_date") <= train_end_fmt) ) val_data = all_data.filter( (pl.col("trade_date") >= val_start_fmt) & (pl.col("trade_date") <= val_end_fmt) ) test_data = all_data.filter( - (pl.col("trade_date") >= test_start_fmt) & (pl.col("trade_date") <= test_end_fmt) + (pl.col("trade_date") >= test_start_fmt) + & (pl.col("trade_date") <= test_end_fmt) ) - print(f"[PrepareData] Split result: train={len(train_data)}, val={len(val_data)}, test={len(test_data)}") + print( + f"[PrepareData] Split result: train={len(train_data)}, val={len(val_data)}, test={len(test_data)}" + ) - return train_data, val_data, test_data + return train_data, val_data, test_data, factor_config def _filter_invalid_stocks(df: pl.DataFrame) -> pl.DataFrame: @@ -137,59 +323,54 @@ def _filter_invalid_stocks(df: pl.DataFrame) -> pl.DataFrame: ) -def _compute_features_and_label( - raw_data: pl.DataFrame, +def _compute_label( + features_df: pl.DataFrame, start_date: str, end_date: str, ) -> pl.DataFrame: - """计算因子和标签 + """计算标签(未来5日收益率) - 因子: - 1. return_5_rank: 5日收益率截面排名 - 2. ma_5: 5日移动平均 - 3. ma_10: 10日移动平均 - - 标签:未来5日收益率大于0为1,否则为0 + 标签定义:未来5日收益率大于0为1,否则为0 Args: - raw_data: 原始日线数据 + features_df: 包含因子的DataFrame start_date: 开始日期 end_date: 结束日期 Returns: 包含因子和标签的DataFrame """ - # 确保按日期排序 - raw_data = raw_data.sort(["ts_code", "trade_date"]) + from src.data.storage import Storage - # 计算收益率(未来5日) - raw_data = raw_data.with_columns( - [ - # 当日收益率 - ((pl.col("close") - pl.col("pre_close")) / pl.col("pre_close")).alias( - "daily_return" - ), - ] + storage = Storage() + + # 从数据库获取收盘价数据用于计算标签 + start_dt = f"{start_date[:4]}-{start_date[4:6]}-{start_date[6:8]}" + end_dt = f"{end_date[:4]}-{end_date[4:6]}-{end_date[6:8]}" + + # 需要多取5天数据来计算未来收益率 + end_dt_extended = f"{end_date[:4]}-{end_date[4:6]}-{int(end_date[6:8]) + 5}" + + price_query = f""" + SELECT ts_code, trade_date, close + FROM daily + WHERE trade_date >= '{start_dt}' AND trade_date <= '{end_dt_extended}' + ORDER BY ts_code, trade_date + """ + price_data = storage._connection.sql(price_query).pl() + price_data = price_data.with_columns( + pl.col("trade_date").dt.strftime("%Y-%m-%d").alias("trade_date") ) - # 按股票分组计算 + # 按股票计算未来5日收益率 result_list = [] + for ts_code in price_data["ts_code"].unique(): + stock_data = price_data.filter(pl.col("ts_code") == ts_code).sort("trade_date") - for ts_code in raw_data["ts_code"].unique(): - stock_data = raw_data.filter(pl.col("ts_code") == ts_code).sort("trade_date") - - if len(stock_data) < 20: + if len(stock_data) < 6: continue - # 计算MA5和MA10 - stock_data = stock_data.with_columns( - [ - pl.col("close").rolling_mean(5).alias("ma_5"), - pl.col("close").rolling_mean(10).alias("ma_10"), - ] - ) - - # 计算未来5日收益率(用于标签) + # 计算未来5日收益率 future_return = stock_data["close"].shift(-5) - stock_data["close"] future_return_pct = future_return / stock_data["close"] stock_data = stock_data.with_columns( @@ -205,46 +386,21 @@ def _compute_features_and_label( ] ) - result_list.append(stock_data) + result_list.append(stock_data.select(["trade_date", "ts_code", "label"])) if not result_list: return pl.DataFrame() - result = pl.concat(result_list) + label_df = pl.concat(result_list) - # 转换日期格式:YYYYMMDD -> YYYY-MM-DD - start_date_formatted = f"{start_date[:4]}-{start_date[4:6]}-{start_date[6:8]}" - end_date_formatted = f"{end_date[:4]}-{end_date[4:6]}-{end_date[6:8]}" + # 将标签合并到因子数据 + result = features_df.join(label_df, on=["trade_date", "ts_code"], how="inner") # 过滤有效日期范围 result = result.filter( - (pl.col("trade_date") >= start_date_formatted) & (pl.col("trade_date") <= end_date_formatted) + (pl.col("trade_date") >= start_dt) & (pl.col("trade_date") <= end_dt) ) - # 计算5日收益率排名(截面) - result = result.with_columns( - [ - pl.col("daily_return") - .rank(method="average") - .over("trade_date") - .alias("return_5_rank") - ] - ) - - # 归一化排名到0-1 - result = result.with_columns( - [ - ( - pl.col("return_5_rank") - / pl.col("return_5_rank").max().over("trade_date") - ).alias("return_5_rank") - ] - ) - - # 选择需要的列 - feature_cols = ["trade_date", "ts_code", "return_5_rank", "ma_5", "ma_10", "label"] - result = result.select(feature_cols) - return result @@ -271,7 +427,7 @@ def train_model( feature_cols: List[str], label_col: str = "label", model_params: Optional[dict] = None, -) -> Tuple[LightGBMModel, ProcessingPipeline]: +) -> tuple[LightGBMModel, ProcessingPipeline]: """训练LightGBM分类模型 Args: @@ -301,8 +457,12 @@ def train_model( valid_mask = y_train.is_in([0, 1]) X_train_processed = X_train_processed.filter(valid_mask) y_train = y_train.filter(valid_mask) - print(f"[TrainModel] After filtering valid labels: {len(X_train_processed)} samples") - print(f"[TrainModel] Train label distribution: {dict(y_train.value_counts().sort('label').iter_rows())}") + print( + f"[TrainModel] After filtering valid labels: {len(X_train_processed)} samples" + ) + print( + f"[TrainModel] Train label distribution: {dict(y_train.value_counts().sort('label').iter_rows())}" + ) # 准备验证集 X_val_processed = None @@ -311,16 +471,18 @@ def train_model( X_val = val_data.select(feature_cols) y_val = val_data[label_col] print(f"[TrainModel] Val samples: {len(X_val)}") - + # 处理验证集数据(使用训练集的参数) X_val_processed = pipeline.transform(X_val, stage=PipelineStage.TEST) - + # 过滤验证集有效标签 val_valid_mask = y_val.is_in([0, 1]) X_val_processed = X_val_processed.filter(val_valid_mask) y_val = y_val.filter(val_valid_mask) print(f"[TrainModel] Val after filtering: {len(X_val_processed)} samples") - print(f"[TrainModel] Val label distribution: {dict(y_val.value_counts().sort('label').iter_rows())}") + print( + f"[TrainModel] Val label distribution: {dict(y_val.value_counts().sort('label').iter_rows())}" + ) # 创建模型 params = model_params or { @@ -384,7 +546,10 @@ def predict_top_stocks( # 使用 key_data 添加预测结果,保持行数一致 result = key_data.with_columns( pl.Series( - name="pred_prob", values=probs[:, 1] if len(probs.shape) > 1 and probs.shape[1] > 1 else probs.flatten() + name="pred_prob", + values=probs[:, 1] + if len(probs.shape) > 1 and probs.shape[1] > 1 + else probs.flatten(), ), ) @@ -396,7 +561,11 @@ def predict_top_stocks( # 按概率降序排序,选出top N day_top = day_data.sort("pred_prob", descending=True).head(top_n) - top_stocks.append(day_top.select(["trade_date", "pred_prob", "ts_code"]).rename({"pred_prob": "score"})) + top_stocks.append( + day_top.select(["trade_date", "pred_prob", "ts_code"]).rename( + {"pred_prob": "score"} + ) + ) return pl.concat(top_stocks) @@ -415,6 +584,7 @@ def save_top_stocks(top_stocks: pl.DataFrame, output_path: str) -> None: def run_training( + factors: Optional[List[BaseFactor]] = None, data_dir: str = "data", output_path: str = "output/top_stocks.tsv", train_start: str = "20180101", @@ -428,6 +598,7 @@ def run_training( """运行完整训练流程 Args: + factors: 因子实例列表,默认为 None(使用 MA5, MA10, ReturnRank5) data_dir: 数据目录 output_path: 输出文件路径 train_start: 训练集开始日期 @@ -448,7 +619,8 @@ def run_training( # 1. 准备数据 print("[Training] Preparing data...") - train_data, val_data, test_data = prepare_data( + train_data, val_data, test_data, factor_config = prepare_data( + factors=factors, data_dir=data_dir, train_start=train_start, train_end=train_end, @@ -461,9 +633,10 @@ def run_training( print(f"[Training] Val samples: {len(val_data)}") print(f"[Training] Test samples: {len(test_data)}") - # 2. 定义特征列 - feature_cols = ["return_5_rank", "ma_5", "ma_10"] + # 2. 获取特征列名 + feature_cols = factor_config.get_feature_names() label_col = "label" + print(f"[Training] Feature columns: {feature_cols}") # 3. 训练模型 print("[Training] Training model...")