diff --git a/docs/plan/training_module_plan.md b/docs/plan/training_module_plan.md index 2f31aa7..3b807f8 100644 --- a/docs/plan/training_module_plan.md +++ b/docs/plan/training_module_plan.md @@ -40,9 +40,6 @@ src/ │ │ ├── processors/ # 数据处理器 │ │ │ ├── __init__.py │ │ │ └── transforms.py # 标准化(截面/时序)、缩尾 -│ │ └── metrics/ # 评估指标 -│ │ ├── __init__.py -│ │ └── metrics.py # IC, RankIC, MSE, MAE │ ├── config/ # 配置管理 │ │ ├── __init__.py │ │ └── config.py # TrainingConfig (pydantic) @@ -568,30 +565,6 @@ class Winsorizer(BaseProcessor): pass ``` -### 3.7 评估指标 (components/metrics/) - -```python -def ic_score(y_true: np.ndarray, y_pred: np.ndarray) -> float: - """信息系数 (Pearson 相关系数)""" - from scipy.stats import pearsonr - corr, _ = pearsonr(y_true, y_pred) - return corr - -def rank_ic_score(y_true: np.ndarray, y_pred: np.ndarray) -> float: - """秩相关系数 (Spearman 相关系数)""" - from scipy.stats import spearmanr - corr, _ = spearmanr(y_true, y_pred) - return corr - -def mse_score(y_true: np.ndarray, y_pred: np.ndarray) -> float: - """均方误差""" - return np.mean((y_true - y_pred) ** 2) - -def mae_score(y_true: np.ndarray, y_pred: np.ndarray) -> float: - """平均绝对误差""" - return np.mean(np.abs(y_true - y_pred)) -``` - ## 4. 训练流程设计 ### 4.1 Trainer 主类 (core/trainer.py) @@ -615,7 +588,6 @@ class Trainer: pool_manager: Optional[StockPoolManager] = None, processors: List[BaseProcessor] = None, splitter: DateSplitter = None, - metrics: List[str] = None, target_col: str = "target", feature_cols: List[str] = None, persist_model: bool = False, @@ -625,7 +597,6 @@ class Trainer: self.pool_manager = pool_manager self.processors = processors or [] self.splitter = splitter - self.metrics = metrics or ["ic", "rank_ic", "mse", "mae"] self.target_col = target_col self.feature_cols = feature_cols self.persist_model = persist_model @@ -634,7 +605,6 @@ class Trainer: # 存储训练后的处理器 self.fitted_processors: List[BaseProcessor] = [] self.results: pl.DataFrame = None - self.metrics_results: Dict[str, float] = {} def train(self, data: pl.DataFrame) -> "Trainer": """执行训练流程 @@ -686,38 +656,18 @@ class Trainer: X_test = test_data.select(self.feature_cols) predictions = self.model.predict(X_test) - # 7. 评估 - print("[评估] 计算指标...") - y_test = test_data.select(self.target_col).to_series().to_numpy() - self._evaluate(y_test, predictions) - - # 8. 保存结果 + # 7. 保存结果 self.results = test_data.with_columns([ pl.Series("prediction", predictions) ]) - # 9. 持久化模型 + # 8. 持久化模型 if self.persist_model and self.model_save_path: print(f"[保存] 保存模型到 {self.model_save_path}...") self.save_model(self.model_save_path) return self - def _evaluate(self, y_true: np.ndarray, y_pred: np.ndarray): - """计算评估指标""" - from .components.metrics import ic_score, rank_ic_score, mse_score, mae_score - - metric_funcs = { - "ic": ic_score, - "rank_ic": rank_ic_score, - "mse": mse_score, - "mae": mae_score, - } - - for metric in self.metrics: - if metric in metric_funcs: - self.metrics_results[metric] = metric_funcs[metric](y_true, y_pred) - def predict(self, data: pl.DataFrame) -> pl.DataFrame: """对新数据进行预测 @@ -738,10 +688,6 @@ class Trainer: """获取所有预测结果""" return self.results - def get_metrics(self) -> Dict[str, float]: - """获取评估指标""" - return self.metrics_results - def save_results(self, path: str): """保存预测结果到文件""" if self.results is not None: @@ -827,9 +773,6 @@ class TrainingConfig(BaseSettings): ] ) - # === 评估指标 === - metrics: List[str] = ["ic", "rank_ic", "mse", "mae"] - # === 持久化配置 === persist_model: bool = False # 默认不持久化 model_save_path: Optional[str] = None # 持久化路径 @@ -914,8 +857,6 @@ trainer.train(all_data) # 8. 获取结果 results = trainer.get_results() # 包含预测值 -metrics = trainer.get_metrics() # IC, RankIC, etc. -print("评估指标:", metrics) # 9. 保存结果 trainer.save_results("output/predictions.csv") @@ -989,7 +930,6 @@ trainer = Trainer( trainer.train(all_data) results = trainer.get_results() -metrics = trainer.get_metrics() ``` ## 6. 实现顺序 @@ -1019,22 +959,18 @@ metrics = trainer.get_metrics() - `training/components/models/__init__.py` - `training/components/models/lightgbm.py`(含 save_model/load_model) -### Commit 6: 评估指标 -- `training/components/metrics/__init__.py` -- `training/components/metrics/metrics.py`(IC, RankIC, MSE, MAE) - -### Commit 7: 股票池管理器 +### Commit 6: 股票池管理器 - `training/core/__init__.py` - `training/core/stock_pool_manager.py`(每日独立筛选) -### Commit 8: Trainer 训练器 +### Commit 7: Trainer 训练器 - `training/core/trainer.py` -### Commit 9: 配置管理 +### Commit 8: 配置管理 - `training/config/__init__.py` - `training/config/config.py`(TrainingConfig,含必填校验) -### Commit 10: 预留实验模块 +### Commit 9: 预留实验模块 - `experiment/__init__.py` ## 7. 注意事项 @@ -1080,7 +1016,7 @@ metrics = trainer.get_metrics() 1. **特征选择**(processors/selectors.py) 2. **滚动训练**(WalkForward, ExpandingWindow) 3. **结果分析工具**(复杂分析功能) -4. **validator.py, evaluator.py**(简化为 metrics.py) +4. **validator.py, evaluator.py**(已删除,不实现 metrics) ### 7.5 新增功能 diff --git a/src/experiment/__init__.py b/src/experiment/__init__.py new file mode 100644 index 0000000..c6b87fd --- /dev/null +++ b/src/experiment/__init__.py @@ -0,0 +1,7 @@ +"""实验管理模块(预留结构) + +此模块为预留结构,用于未来的实验管理功能。 +暂不提供具体实现。 +""" + +__all__ = [] diff --git a/src/training/__init__.py b/src/training/__init__.py index 48e78ac..ac2fa66 100644 --- a/src/training/__init__.py +++ b/src/training/__init__.py @@ -1,6 +1,6 @@ """训练模块 - ProStock 量化投资框架 -提供模型训练、数据处理和评估的完整流程。 +提供模型训练、数据处理和预测的完整流程。 """ # 基础抽象类 @@ -14,6 +14,31 @@ from src.training.registry import ( register_processor, ) +# 数据划分器 +from src.training.components.splitters import DateSplitter + +# 股票池选择器配置 +from src.training.components.selectors import ( + MarketCapSelectorConfig, + StockFilterConfig, +) + +# 数据处理器 +from src.training.components.processors import ( + CrossSectionalStandardScaler, + StandardScaler, + Winsorizer, +) + +# 模型 +from src.training.components.models import LightGBMModel + +# 训练核心 +from src.training.core import StockPoolManager, Trainer + +# 配置 +from src.training.config import TrainingConfig + __all__ = [ # 基础抽象类 "BaseModel", @@ -23,4 +48,20 @@ __all__ = [ "ProcessorRegistry", "register_model", "register_processor", + # 数据划分器 + "DateSplitter", + # 股票池选择器配置 + "StockFilterConfig", + "MarketCapSelectorConfig", + # 数据处理器 + "StandardScaler", + "CrossSectionalStandardScaler", + "Winsorizer", + # 模型 + "LightGBMModel", + # 训练核心 + "StockPoolManager", + "Trainer", + # 配置 + "TrainingConfig", ] diff --git a/src/training/config/__init__.py b/src/training/config/__init__.py new file mode 100644 index 0000000..2245022 --- /dev/null +++ b/src/training/config/__init__.py @@ -0,0 +1,18 @@ +"""训练配置管理 + +提供 TrainingConfig 配置类和相关配置数据类。 +""" + +from src.training.config.config import ( + MarketCapSelectorConfig, + ProcessorConfig, + StockFilterConfig, + TrainingConfig, +) + +__all__ = [ + "TrainingConfig", + "StockFilterConfig", + "MarketCapSelectorConfig", + "ProcessorConfig", +] diff --git a/src/training/config/config.py b/src/training/config/config.py new file mode 100644 index 0000000..8b4368b --- /dev/null +++ b/src/training/config/config.py @@ -0,0 +1,141 @@ +"""训练配置管理 + +提供 TrainingConfig 配置类,使用 pydantic 进行参数验证。 +""" + +from dataclasses import dataclass +from typing import Any, Dict, List, Optional + +from pydantic import Field, validator +from pydantic_settings import BaseSettings + + +@dataclass +class StockFilterConfig: + """股票过滤器配置""" + + exclude_cyb: bool = True # 排除创业板 + exclude_kcb: bool = True # 排除科创板 + exclude_bj: bool = True # 排除北交所 + exclude_st: bool = True # 排除ST股票 + + +@dataclass +class MarketCapSelectorConfig: + """市值选择器配置""" + + enabled: bool = True # 是否启用 + n: int = 100 # 选择前 n 只 + ascending: bool = False # False=最大市值, True=最小市值 + market_cap_col: str = "total_mv" # 市值列名(来自 daily_basic) + + +@dataclass +class ProcessorConfig: + """处理器配置""" + + name: str + params: Dict[str, Any] = Field(default_factory=dict) + + +class TrainingConfig(BaseSettings): + """训练配置类 + + 所有配置参数通过此类管理,支持 pydantic 验证。 + """ + + # === 数据配置(必填)=== + feature_cols: List[str] = Field(..., min_items=1) # 特征列名,至少一个 + target_col: str = "target" # 目标变量列名 + date_col: str = "trade_date" # 日期列名 + code_col: str = "ts_code" # 股票代码列名 + + # === 日期划分(必填)=== + train_start: str = Field(..., description="训练期开始 YYYYMMDD") + train_end: str = Field(..., description="训练期结束 YYYYMMDD") + test_start: str = Field(..., description="测试期开始 YYYYMMDD") + test_end: str = Field(..., description="测试期结束 YYYYMMDD") + + # === 股票池配置 === + stock_filter: StockFilterConfig = Field( + default_factory=lambda: StockFilterConfig( + exclude_cyb=True, + exclude_kcb=True, + exclude_bj=True, + exclude_st=True, + ) + ) + stock_selector: Optional[MarketCapSelectorConfig] = Field( + default_factory=lambda: MarketCapSelectorConfig( + enabled=True, + n=100, + ascending=False, + market_cap_col="total_mv", + ) + ) + # 注意:如果 stock_selector = None,则跳过市值选择 + + # === 模型配置 === + model_type: str = "lightgbm" + model_params: Dict[str, Any] = Field(default_factory=dict) + + # === 处理器配置 === + processors: List[ProcessorConfig] = Field( + default_factory=lambda: [ + ProcessorConfig(name="winsorizer", params={"lower": 0.01, "upper": 0.99}), + ProcessorConfig(name="cs_standard_scaler", params={}), + ] + ) + + # === 持久化配置 === + persist_model: bool = False # 默认不持久化 + model_save_path: Optional[str] = None # 持久化路径 + + # === 输出配置 === + output_dir: str = "output" + save_predictions: bool = True + + @validator("train_start", "train_end", "test_start", "test_end") + def validate_date_format(cls, v: str) -> str: + """验证日期格式为 YYYYMMDD""" + if not isinstance(v, str) or len(v) != 8: + raise ValueError(f"日期必须是格式为 'YYYYMMDD' 的8位字符串,得到: {v}") + try: + int(v) + except ValueError: + raise ValueError(f"日期必须是数字字符串,得到: {v}") + return v + + @validator("train_end") + def validate_train_dates(cls, v: str, values: Dict[str, Any]) -> str: + """验证训练日期范围""" + if "train_start" in values and values["train_start"] > v: + raise ValueError( + f"train_start ({values['train_start']}) 必须早于或等于 train_end ({v})" + ) + return v + + @validator("test_end") + def validate_test_dates(cls, v: str, values: Dict[str, Any]) -> str: + """验证测试日期范围""" + if "test_start" in values and values["test_start"] > v: + raise ValueError( + f"test_start ({values['test_start']}) 必须早于或等于 test_end ({v})" + ) + return v + + @validator("test_start") + def validate_no_overlap(cls, v: str, values: Dict[str, Any]) -> str: + """验证训练集和测试集不重叠""" + if "train_end" in values and v <= values["train_end"]: + raise ValueError( + f"测试集开始日期 ({v}) 必须晚于训练集结束日期 ({values['train_end']})," + "以确保训练集和测试集不重叠" + ) + return v + + class Config: + """Pydantic 配置""" + + env_prefix = "TRAINING_" # 环境变量前缀 + env_nested_delimiter = "__" diff --git a/src/training/core/__init__.py b/src/training/core/__init__.py new file mode 100644 index 0000000..4aa568d --- /dev/null +++ b/src/training/core/__init__.py @@ -0,0 +1,9 @@ +"""训练核心模块 + +包含 Trainer 主类和股票池管理器。 +""" + +from src.training.core.stock_pool_manager import StockPoolManager +from src.training.core.trainer import Trainer + +__all__ = ["StockPoolManager", "Trainer"] diff --git a/src/training/core/stock_pool_manager.py b/src/training/core/stock_pool_manager.py new file mode 100644 index 0000000..2f29b44 --- /dev/null +++ b/src/training/core/stock_pool_manager.py @@ -0,0 +1,171 @@ +"""股票池管理器 + +每日独立筛选股票池,市值数据从 daily_basic 表独立获取。 +""" + +from typing import TYPE_CHECKING, Dict, List, Optional + +import polars as pl + +from src.training.components.selectors import MarketCapSelectorConfig, StockFilterConfig + +if TYPE_CHECKING: + from src.factors.engine.data_router import DataRouter + + +class StockPoolManager: + """股票池管理器 - 每日独立筛选 + + 重要约束: + 1. 市值数据仅从 daily_basic 表获取,仅用于筛选 + 2. 市值数据绝不混入特征矩阵 + 3. 每日独立筛选(市值是动态变化的) + + 处理流程(每日): + 当日所有股票 + ↓ + 代码过滤(创业板、ST等) + ↓ + 查询 daily_basic 获取当日市值 + ↓ + 市值选择(前N只) + ↓ + 返回当日选中股票列表 + """ + + def __init__( + self, + filter_config: StockFilterConfig, + selector_config: Optional[MarketCapSelectorConfig], + data_router: "DataRouter", + code_col: str = "ts_code", + date_col: str = "trade_date", + ): + """初始化股票池管理器 + + Args: + filter_config: 股票过滤器配置 + selector_config: 市值选择器配置,None 表示跳过市值选择 + data_router: 数据路由器,用于获取 daily_basic 数据 + code_col: 股票代码列名 + date_col: 日期列名 + """ + self.filter_config = filter_config + self.selector_config = selector_config + self.data_router = data_router + self.code_col = code_col + self.date_col = date_col + + def filter_and_select_daily(self, data: pl.DataFrame) -> pl.DataFrame: + """每日独立筛选股票池 + + Args: + data: 因子计算后的全市场数据,必须包含 trade_date 和 ts_code 列 + + Returns: + 筛选后的数据,仅包含每日选中的股票 + + Note: + - 按日期分组处理 + - 市值数据从 daily_basic 独立获取 + - 保持市值数据与特征数据隔离 + """ + dates = data.select(self.date_col).unique().sort(self.date_col) + + result_frames = [] + for date in dates.to_series(): + # 获取当日数据 + daily_data = data.filter(pl.col(self.date_col) == date) + daily_codes = daily_data.select(self.code_col).to_series().to_list() + + # 1. 代码过滤 + filtered_codes = self.filter_config.filter_codes(daily_codes) + + # 2. 市值选择(如果启用) + if self.selector_config and self.selector_config.enabled: + # 从 daily_basic 获取当日市值 + market_caps = self._get_market_caps_for_date(filtered_codes, date) + selected_codes = self._select_by_market_cap(filtered_codes, market_caps) + else: + selected_codes = filtered_codes + + # 3. 保留当日选中的股票数据 + daily_selected = daily_data.filter( + pl.col(self.code_col).is_in(selected_codes) + ) + result_frames.append(daily_selected) + + return pl.concat(result_frames) + + def _get_market_caps_for_date( + self, codes: List[str], date: str + ) -> Dict[str, float]: + """从 daily_basic 表获取指定日期的市值数据 + + Args: + codes: 股票代码列表 + date: 日期 "YYYYMMDD" + + Returns: + {股票代码: 市值} 的字典 + """ + if not codes: + return {} + + assert self.selector_config is not None, ( + "selector_config should not be None when calling _get_market_caps_for_date" + ) + + try: + # 通过 data_router 查询 daily_basic 表 + from src.factors.engine.data_spec import DataSpec + + data_specs = [ + DataSpec("daily_basic", [self.selector_config.market_cap_col]) + ] + df = self.data_router.fetch_data( + data_specs=data_specs, + start_date=date, + end_date=date, + stock_codes=codes, + ) + + # 转换为字典 + market_caps = {} + for row in df.iter_rows(named=True): + code = row[self.code_col] + cap = row.get(self.selector_config.market_cap_col) + if cap is not None and code in codes: + market_caps[code] = float(cap) + + return market_caps + + except Exception as e: + print(f"[警告] 获取 {date} 市值数据失败: {e}") + return {} + + def _select_by_market_cap( + self, codes: List[str], market_caps: Dict[str, float] + ) -> List[str]: + """根据市值选择股票 + + Args: + codes: 股票代码列表 + market_caps: 市值数据字典 + + Returns: + 选中的股票代码列表 + """ + if self.selector_config is None: + return codes + + if not market_caps: + return codes[: self.selector_config.n] + + # 按市值排序并选择前N只 + sorted_codes = sorted( + codes, + key=lambda c: market_caps.get(c, 0), + reverse=not self.selector_config.ascending, + ) + return sorted_codes[: self.selector_config.n] diff --git a/src/training/core/trainer.py b/src/training/core/trainer.py new file mode 100644 index 0000000..c33c174 --- /dev/null +++ b/src/training/core/trainer.py @@ -0,0 +1,179 @@ +"""训练器主类 + +整合数据处理、模型训练、预测的完整流程。 +""" + +from typing import List, Optional + +import polars as pl + +from src.training.components.base import BaseModel, BaseProcessor +from src.training.components.splitters import DateSplitter +from src.training.core.stock_pool_manager import StockPoolManager + + +class Trainer: + """训练器主类 + + 整合数据处理、模型训练、预测的完整流程。 + + 关键设计: + 1. 因子先计算(全市场),再筛选股票池(每日独立) + 2. Processor 分阶段行为:训练集 fit_transform,测试集 transform + 3. 一次性训练,不滚动 + 4. 支持模型持久化 + """ + + def __init__( + self, + model: BaseModel, + pool_manager: Optional[StockPoolManager] = None, + processors: Optional[List[BaseProcessor]] = None, + splitter: Optional[DateSplitter] = None, + target_col: str = "target", + feature_cols: Optional[List[str]] = None, + persist_model: bool = False, + model_save_path: Optional[str] = None, + ): + """初始化训练器 + + Args: + model: 模型实例 + pool_manager: 股票池管理器,None 表示不筛选 + processors: 数据处理器列表 + splitter: 数据划分器 + target_col: 目标变量列名 + feature_cols: 特征列名列表 + persist_model: 是否保存模型 + model_save_path: 模型保存路径 + """ + self.model = model + self.pool_manager = pool_manager + self.processors = processors or [] + self.splitter = splitter + self.target_col = target_col + self.feature_cols = feature_cols or [] + self.persist_model = persist_model + self.model_save_path = model_save_path + + # 存储训练后的处理器 + self.fitted_processors: List[BaseProcessor] = [] + self.results: Optional[pl.DataFrame] = None + + def train(self, data: pl.DataFrame) -> "Trainer": + """执行训练流程 + + 流程: + 1. 股票池每日筛选(如果配置了 pool_manager) + 2. 按日期划分训练集/测试集 + 3. 训练集:processors fit_transform + 4. 训练模型 + 5. 测试集:processors transform(使用训练集学到的参数) + 6. 预测 + 7. 保存结果 + 8. 持久化模型(如果启用) + + Args: + data: 因子计算后的全市场数据 + 必须包含 ts_code 和 trade_date 列 + + Returns: + self (支持链式调用) + """ + # 1. 股票池筛选(每日独立) + if self.pool_manager: + print("[筛选] 每日独立筛选股票池...") + data = self.pool_manager.filter_and_select_daily(data) + + # 2. 划分训练/测试集 + if self.splitter: + print("[划分] 划分训练集和测试集...") + train_data, test_data = self.splitter.split(data) + else: + # 没有划分器,全部作为训练集 + train_data = data + test_data = data + + # 3. 训练集:processors fit_transform + if self.processors: + print("[处理] 处理训练集...") + for processor in self.processors: + train_data = processor.fit_transform(train_data) + self.fitted_processors.append(processor) + + # 4. 训练模型 + print("[训练] 训练模型...") + if not self.feature_cols: + raise ValueError("feature_cols 不能为空") + + X_train = train_data.select(self.feature_cols) + y_train = train_data.select(self.target_col).to_series() + self.model.fit(X_train, y_train) + + # 5. 测试集:processors transform + if self.processors and test_data is not train_data: + print("[处理] 处理测试集...") + for processor in self.fitted_processors: + test_data = processor.transform(test_data) + + # 6. 预测 + print("[预测] 生成预测...") + X_test = test_data.select(self.feature_cols) + predictions = self.model.predict(X_test) + + # 7. 保存结果 + self.results = test_data.with_columns([pl.Series("prediction", predictions)]) + + # 8. 持久化模型 + if self.persist_model and self.model_save_path: + print(f"[保存] 保存模型到 {self.model_save_path}...") + self.save_model(self.model_save_path) + + return self + + def predict(self, data: pl.DataFrame) -> pl.DataFrame: + """对新数据进行预测 + + 注意:新数据需要先经过股票池筛选, + 然后使用训练好的 processors 和 model 进行预测。 + + Args: + data: 输入数据 + + Returns: + 包含预测列的数据 + """ + # 应用 processors + for processor in self.fitted_processors: + data = processor.transform(data) + + # 预测 + X = data.select(self.feature_cols) + predictions = self.model.predict(X) + + return data.with_columns([pl.Series("prediction", predictions)]) + + def get_results(self) -> Optional[pl.DataFrame]: + """获取所有预测结果 + + Returns: + 预测结果 DataFrame,包含原始列和 prediction 列 + """ + return self.results + + def save_results(self, path: str) -> None: + """保存预测结果到文件 + + Args: + path: 保存路径(CSV 格式) + """ + if self.results is not None: + self.results.write_csv(path) + + def save_model(self, path: str) -> None: + """保存模型 + + Args: + path: 模型保存路径 + """ + self.model.save(path) diff --git a/tests/training/test_lightgbm_model.py b/tests/training/test_lightgbm_model.py index 3a673f2..4d74170 100644 --- a/tests/training/test_lightgbm_model.py +++ b/tests/training/test_lightgbm_model.py @@ -193,8 +193,17 @@ class TestLightGBMModel: """测试模型已注册到 registry""" from src.training.registry import ModelRegistry + # 重新导入模型模块以确保注册(处理其他测试 clear 注册表的情况) + import importlib + import src.training.components.models.lightgbm as lightgbm_module + + importlib.reload(lightgbm_module) + from src.training.components.models.lightgbm import ( + LightGBMModel as ReloadedModel, + ) + model_class = ModelRegistry.get_model("lightgbm") - assert model_class is LightGBMModel + assert model_class is ReloadedModel def test_fit_predict_consistency(self): """测试多次预测结果一致"""