feat(training): 实现训练模块核心组件(commits 6-9)
- StockPoolManager:每日独立筛选股票池,支持代码过滤和市值选择 - Trainer:整合训练完整流程,支持 processor 分阶段行为和模型持久化 - TrainingConfig:pydantic 配置管理,含必填字段和日期验证 - experiment 模块:预留结构 - 从计划中移除 metrics 组件 - 调整 commit 序号(7-10 → 6-9) - 更新 training/__init__.py 导出所有公开 API
This commit is contained in:
7
src/experiment/__init__.py
Normal file
7
src/experiment/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
"""实验管理模块(预留结构)
|
||||
|
||||
此模块为预留结构,用于未来的实验管理功能。
|
||||
暂不提供具体实现。
|
||||
"""
|
||||
|
||||
__all__ = []
|
||||
@@ -1,6 +1,6 @@
|
||||
"""训练模块 - ProStock 量化投资框架
|
||||
|
||||
提供模型训练、数据处理和评估的完整流程。
|
||||
提供模型训练、数据处理和预测的完整流程。
|
||||
"""
|
||||
|
||||
# 基础抽象类
|
||||
@@ -14,6 +14,31 @@ from src.training.registry import (
|
||||
register_processor,
|
||||
)
|
||||
|
||||
# 数据划分器
|
||||
from src.training.components.splitters import DateSplitter
|
||||
|
||||
# 股票池选择器配置
|
||||
from src.training.components.selectors import (
|
||||
MarketCapSelectorConfig,
|
||||
StockFilterConfig,
|
||||
)
|
||||
|
||||
# 数据处理器
|
||||
from src.training.components.processors import (
|
||||
CrossSectionalStandardScaler,
|
||||
StandardScaler,
|
||||
Winsorizer,
|
||||
)
|
||||
|
||||
# 模型
|
||||
from src.training.components.models import LightGBMModel
|
||||
|
||||
# 训练核心
|
||||
from src.training.core import StockPoolManager, Trainer
|
||||
|
||||
# 配置
|
||||
from src.training.config import TrainingConfig
|
||||
|
||||
__all__ = [
|
||||
# 基础抽象类
|
||||
"BaseModel",
|
||||
@@ -23,4 +48,20 @@ __all__ = [
|
||||
"ProcessorRegistry",
|
||||
"register_model",
|
||||
"register_processor",
|
||||
# 数据划分器
|
||||
"DateSplitter",
|
||||
# 股票池选择器配置
|
||||
"StockFilterConfig",
|
||||
"MarketCapSelectorConfig",
|
||||
# 数据处理器
|
||||
"StandardScaler",
|
||||
"CrossSectionalStandardScaler",
|
||||
"Winsorizer",
|
||||
# 模型
|
||||
"LightGBMModel",
|
||||
# 训练核心
|
||||
"StockPoolManager",
|
||||
"Trainer",
|
||||
# 配置
|
||||
"TrainingConfig",
|
||||
]
|
||||
|
||||
18
src/training/config/__init__.py
Normal file
18
src/training/config/__init__.py
Normal file
@@ -0,0 +1,18 @@
|
||||
"""训练配置管理
|
||||
|
||||
提供 TrainingConfig 配置类和相关配置数据类。
|
||||
"""
|
||||
|
||||
from src.training.config.config import (
|
||||
MarketCapSelectorConfig,
|
||||
ProcessorConfig,
|
||||
StockFilterConfig,
|
||||
TrainingConfig,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"TrainingConfig",
|
||||
"StockFilterConfig",
|
||||
"MarketCapSelectorConfig",
|
||||
"ProcessorConfig",
|
||||
]
|
||||
141
src/training/config/config.py
Normal file
141
src/training/config/config.py
Normal file
@@ -0,0 +1,141 @@
|
||||
"""训练配置管理
|
||||
|
||||
提供 TrainingConfig 配置类,使用 pydantic 进行参数验证。
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import Field, validator
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
@dataclass
|
||||
class StockFilterConfig:
|
||||
"""股票过滤器配置"""
|
||||
|
||||
exclude_cyb: bool = True # 排除创业板
|
||||
exclude_kcb: bool = True # 排除科创板
|
||||
exclude_bj: bool = True # 排除北交所
|
||||
exclude_st: bool = True # 排除ST股票
|
||||
|
||||
|
||||
@dataclass
|
||||
class MarketCapSelectorConfig:
|
||||
"""市值选择器配置"""
|
||||
|
||||
enabled: bool = True # 是否启用
|
||||
n: int = 100 # 选择前 n 只
|
||||
ascending: bool = False # False=最大市值, True=最小市值
|
||||
market_cap_col: str = "total_mv" # 市值列名(来自 daily_basic)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProcessorConfig:
|
||||
"""处理器配置"""
|
||||
|
||||
name: str
|
||||
params: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class TrainingConfig(BaseSettings):
|
||||
"""训练配置类
|
||||
|
||||
所有配置参数通过此类管理,支持 pydantic 验证。
|
||||
"""
|
||||
|
||||
# === 数据配置(必填)===
|
||||
feature_cols: List[str] = Field(..., min_items=1) # 特征列名,至少一个
|
||||
target_col: str = "target" # 目标变量列名
|
||||
date_col: str = "trade_date" # 日期列名
|
||||
code_col: str = "ts_code" # 股票代码列名
|
||||
|
||||
# === 日期划分(必填)===
|
||||
train_start: str = Field(..., description="训练期开始 YYYYMMDD")
|
||||
train_end: str = Field(..., description="训练期结束 YYYYMMDD")
|
||||
test_start: str = Field(..., description="测试期开始 YYYYMMDD")
|
||||
test_end: str = Field(..., description="测试期结束 YYYYMMDD")
|
||||
|
||||
# === 股票池配置 ===
|
||||
stock_filter: StockFilterConfig = Field(
|
||||
default_factory=lambda: StockFilterConfig(
|
||||
exclude_cyb=True,
|
||||
exclude_kcb=True,
|
||||
exclude_bj=True,
|
||||
exclude_st=True,
|
||||
)
|
||||
)
|
||||
stock_selector: Optional[MarketCapSelectorConfig] = Field(
|
||||
default_factory=lambda: MarketCapSelectorConfig(
|
||||
enabled=True,
|
||||
n=100,
|
||||
ascending=False,
|
||||
market_cap_col="total_mv",
|
||||
)
|
||||
)
|
||||
# 注意:如果 stock_selector = None,则跳过市值选择
|
||||
|
||||
# === 模型配置 ===
|
||||
model_type: str = "lightgbm"
|
||||
model_params: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
# === 处理器配置 ===
|
||||
processors: List[ProcessorConfig] = Field(
|
||||
default_factory=lambda: [
|
||||
ProcessorConfig(name="winsorizer", params={"lower": 0.01, "upper": 0.99}),
|
||||
ProcessorConfig(name="cs_standard_scaler", params={}),
|
||||
]
|
||||
)
|
||||
|
||||
# === 持久化配置 ===
|
||||
persist_model: bool = False # 默认不持久化
|
||||
model_save_path: Optional[str] = None # 持久化路径
|
||||
|
||||
# === 输出配置 ===
|
||||
output_dir: str = "output"
|
||||
save_predictions: bool = True
|
||||
|
||||
@validator("train_start", "train_end", "test_start", "test_end")
|
||||
def validate_date_format(cls, v: str) -> str:
|
||||
"""验证日期格式为 YYYYMMDD"""
|
||||
if not isinstance(v, str) or len(v) != 8:
|
||||
raise ValueError(f"日期必须是格式为 'YYYYMMDD' 的8位字符串,得到: {v}")
|
||||
try:
|
||||
int(v)
|
||||
except ValueError:
|
||||
raise ValueError(f"日期必须是数字字符串,得到: {v}")
|
||||
return v
|
||||
|
||||
@validator("train_end")
|
||||
def validate_train_dates(cls, v: str, values: Dict[str, Any]) -> str:
|
||||
"""验证训练日期范围"""
|
||||
if "train_start" in values and values["train_start"] > v:
|
||||
raise ValueError(
|
||||
f"train_start ({values['train_start']}) 必须早于或等于 train_end ({v})"
|
||||
)
|
||||
return v
|
||||
|
||||
@validator("test_end")
|
||||
def validate_test_dates(cls, v: str, values: Dict[str, Any]) -> str:
|
||||
"""验证测试日期范围"""
|
||||
if "test_start" in values and values["test_start"] > v:
|
||||
raise ValueError(
|
||||
f"test_start ({values['test_start']}) 必须早于或等于 test_end ({v})"
|
||||
)
|
||||
return v
|
||||
|
||||
@validator("test_start")
|
||||
def validate_no_overlap(cls, v: str, values: Dict[str, Any]) -> str:
|
||||
"""验证训练集和测试集不重叠"""
|
||||
if "train_end" in values and v <= values["train_end"]:
|
||||
raise ValueError(
|
||||
f"测试集开始日期 ({v}) 必须晚于训练集结束日期 ({values['train_end']}),"
|
||||
"以确保训练集和测试集不重叠"
|
||||
)
|
||||
return v
|
||||
|
||||
class Config:
|
||||
"""Pydantic 配置"""
|
||||
|
||||
env_prefix = "TRAINING_" # 环境变量前缀
|
||||
env_nested_delimiter = "__"
|
||||
9
src/training/core/__init__.py
Normal file
9
src/training/core/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
||||
"""训练核心模块
|
||||
|
||||
包含 Trainer 主类和股票池管理器。
|
||||
"""
|
||||
|
||||
from src.training.core.stock_pool_manager import StockPoolManager
|
||||
from src.training.core.trainer import Trainer
|
||||
|
||||
__all__ = ["StockPoolManager", "Trainer"]
|
||||
171
src/training/core/stock_pool_manager.py
Normal file
171
src/training/core/stock_pool_manager.py
Normal file
@@ -0,0 +1,171 @@
|
||||
"""股票池管理器
|
||||
|
||||
每日独立筛选股票池,市值数据从 daily_basic 表独立获取。
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional
|
||||
|
||||
import polars as pl
|
||||
|
||||
from src.training.components.selectors import MarketCapSelectorConfig, StockFilterConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.factors.engine.data_router import DataRouter
|
||||
|
||||
|
||||
class StockPoolManager:
|
||||
"""股票池管理器 - 每日独立筛选
|
||||
|
||||
重要约束:
|
||||
1. 市值数据仅从 daily_basic 表获取,仅用于筛选
|
||||
2. 市值数据绝不混入特征矩阵
|
||||
3. 每日独立筛选(市值是动态变化的)
|
||||
|
||||
处理流程(每日):
|
||||
当日所有股票
|
||||
↓
|
||||
代码过滤(创业板、ST等)
|
||||
↓
|
||||
查询 daily_basic 获取当日市值
|
||||
↓
|
||||
市值选择(前N只)
|
||||
↓
|
||||
返回当日选中股票列表
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
filter_config: StockFilterConfig,
|
||||
selector_config: Optional[MarketCapSelectorConfig],
|
||||
data_router: "DataRouter",
|
||||
code_col: str = "ts_code",
|
||||
date_col: str = "trade_date",
|
||||
):
|
||||
"""初始化股票池管理器
|
||||
|
||||
Args:
|
||||
filter_config: 股票过滤器配置
|
||||
selector_config: 市值选择器配置,None 表示跳过市值选择
|
||||
data_router: 数据路由器,用于获取 daily_basic 数据
|
||||
code_col: 股票代码列名
|
||||
date_col: 日期列名
|
||||
"""
|
||||
self.filter_config = filter_config
|
||||
self.selector_config = selector_config
|
||||
self.data_router = data_router
|
||||
self.code_col = code_col
|
||||
self.date_col = date_col
|
||||
|
||||
def filter_and_select_daily(self, data: pl.DataFrame) -> pl.DataFrame:
|
||||
"""每日独立筛选股票池
|
||||
|
||||
Args:
|
||||
data: 因子计算后的全市场数据,必须包含 trade_date 和 ts_code 列
|
||||
|
||||
Returns:
|
||||
筛选后的数据,仅包含每日选中的股票
|
||||
|
||||
Note:
|
||||
- 按日期分组处理
|
||||
- 市值数据从 daily_basic 独立获取
|
||||
- 保持市值数据与特征数据隔离
|
||||
"""
|
||||
dates = data.select(self.date_col).unique().sort(self.date_col)
|
||||
|
||||
result_frames = []
|
||||
for date in dates.to_series():
|
||||
# 获取当日数据
|
||||
daily_data = data.filter(pl.col(self.date_col) == date)
|
||||
daily_codes = daily_data.select(self.code_col).to_series().to_list()
|
||||
|
||||
# 1. 代码过滤
|
||||
filtered_codes = self.filter_config.filter_codes(daily_codes)
|
||||
|
||||
# 2. 市值选择(如果启用)
|
||||
if self.selector_config and self.selector_config.enabled:
|
||||
# 从 daily_basic 获取当日市值
|
||||
market_caps = self._get_market_caps_for_date(filtered_codes, date)
|
||||
selected_codes = self._select_by_market_cap(filtered_codes, market_caps)
|
||||
else:
|
||||
selected_codes = filtered_codes
|
||||
|
||||
# 3. 保留当日选中的股票数据
|
||||
daily_selected = daily_data.filter(
|
||||
pl.col(self.code_col).is_in(selected_codes)
|
||||
)
|
||||
result_frames.append(daily_selected)
|
||||
|
||||
return pl.concat(result_frames)
|
||||
|
||||
def _get_market_caps_for_date(
|
||||
self, codes: List[str], date: str
|
||||
) -> Dict[str, float]:
|
||||
"""从 daily_basic 表获取指定日期的市值数据
|
||||
|
||||
Args:
|
||||
codes: 股票代码列表
|
||||
date: 日期 "YYYYMMDD"
|
||||
|
||||
Returns:
|
||||
{股票代码: 市值} 的字典
|
||||
"""
|
||||
if not codes:
|
||||
return {}
|
||||
|
||||
assert self.selector_config is not None, (
|
||||
"selector_config should not be None when calling _get_market_caps_for_date"
|
||||
)
|
||||
|
||||
try:
|
||||
# 通过 data_router 查询 daily_basic 表
|
||||
from src.factors.engine.data_spec import DataSpec
|
||||
|
||||
data_specs = [
|
||||
DataSpec("daily_basic", [self.selector_config.market_cap_col])
|
||||
]
|
||||
df = self.data_router.fetch_data(
|
||||
data_specs=data_specs,
|
||||
start_date=date,
|
||||
end_date=date,
|
||||
stock_codes=codes,
|
||||
)
|
||||
|
||||
# 转换为字典
|
||||
market_caps = {}
|
||||
for row in df.iter_rows(named=True):
|
||||
code = row[self.code_col]
|
||||
cap = row.get(self.selector_config.market_cap_col)
|
||||
if cap is not None and code in codes:
|
||||
market_caps[code] = float(cap)
|
||||
|
||||
return market_caps
|
||||
|
||||
except Exception as e:
|
||||
print(f"[警告] 获取 {date} 市值数据失败: {e}")
|
||||
return {}
|
||||
|
||||
def _select_by_market_cap(
|
||||
self, codes: List[str], market_caps: Dict[str, float]
|
||||
) -> List[str]:
|
||||
"""根据市值选择股票
|
||||
|
||||
Args:
|
||||
codes: 股票代码列表
|
||||
market_caps: 市值数据字典
|
||||
|
||||
Returns:
|
||||
选中的股票代码列表
|
||||
"""
|
||||
if self.selector_config is None:
|
||||
return codes
|
||||
|
||||
if not market_caps:
|
||||
return codes[: self.selector_config.n]
|
||||
|
||||
# 按市值排序并选择前N只
|
||||
sorted_codes = sorted(
|
||||
codes,
|
||||
key=lambda c: market_caps.get(c, 0),
|
||||
reverse=not self.selector_config.ascending,
|
||||
)
|
||||
return sorted_codes[: self.selector_config.n]
|
||||
179
src/training/core/trainer.py
Normal file
179
src/training/core/trainer.py
Normal file
@@ -0,0 +1,179 @@
|
||||
"""训练器主类
|
||||
|
||||
整合数据处理、模型训练、预测的完整流程。
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
import polars as pl
|
||||
|
||||
from src.training.components.base import BaseModel, BaseProcessor
|
||||
from src.training.components.splitters import DateSplitter
|
||||
from src.training.core.stock_pool_manager import StockPoolManager
|
||||
|
||||
|
||||
class Trainer:
|
||||
"""训练器主类
|
||||
|
||||
整合数据处理、模型训练、预测的完整流程。
|
||||
|
||||
关键设计:
|
||||
1. 因子先计算(全市场),再筛选股票池(每日独立)
|
||||
2. Processor 分阶段行为:训练集 fit_transform,测试集 transform
|
||||
3. 一次性训练,不滚动
|
||||
4. 支持模型持久化
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: BaseModel,
|
||||
pool_manager: Optional[StockPoolManager] = None,
|
||||
processors: Optional[List[BaseProcessor]] = None,
|
||||
splitter: Optional[DateSplitter] = None,
|
||||
target_col: str = "target",
|
||||
feature_cols: Optional[List[str]] = None,
|
||||
persist_model: bool = False,
|
||||
model_save_path: Optional[str] = None,
|
||||
):
|
||||
"""初始化训练器
|
||||
|
||||
Args:
|
||||
model: 模型实例
|
||||
pool_manager: 股票池管理器,None 表示不筛选
|
||||
processors: 数据处理器列表
|
||||
splitter: 数据划分器
|
||||
target_col: 目标变量列名
|
||||
feature_cols: 特征列名列表
|
||||
persist_model: 是否保存模型
|
||||
model_save_path: 模型保存路径
|
||||
"""
|
||||
self.model = model
|
||||
self.pool_manager = pool_manager
|
||||
self.processors = processors or []
|
||||
self.splitter = splitter
|
||||
self.target_col = target_col
|
||||
self.feature_cols = feature_cols or []
|
||||
self.persist_model = persist_model
|
||||
self.model_save_path = model_save_path
|
||||
|
||||
# 存储训练后的处理器
|
||||
self.fitted_processors: List[BaseProcessor] = []
|
||||
self.results: Optional[pl.DataFrame] = None
|
||||
|
||||
def train(self, data: pl.DataFrame) -> "Trainer":
|
||||
"""执行训练流程
|
||||
|
||||
流程:
|
||||
1. 股票池每日筛选(如果配置了 pool_manager)
|
||||
2. 按日期划分训练集/测试集
|
||||
3. 训练集:processors fit_transform
|
||||
4. 训练模型
|
||||
5. 测试集:processors transform(使用训练集学到的参数)
|
||||
6. 预测
|
||||
7. 保存结果
|
||||
8. 持久化模型(如果启用)
|
||||
|
||||
Args:
|
||||
data: 因子计算后的全市场数据
|
||||
必须包含 ts_code 和 trade_date 列
|
||||
|
||||
Returns:
|
||||
self (支持链式调用)
|
||||
"""
|
||||
# 1. 股票池筛选(每日独立)
|
||||
if self.pool_manager:
|
||||
print("[筛选] 每日独立筛选股票池...")
|
||||
data = self.pool_manager.filter_and_select_daily(data)
|
||||
|
||||
# 2. 划分训练/测试集
|
||||
if self.splitter:
|
||||
print("[划分] 划分训练集和测试集...")
|
||||
train_data, test_data = self.splitter.split(data)
|
||||
else:
|
||||
# 没有划分器,全部作为训练集
|
||||
train_data = data
|
||||
test_data = data
|
||||
|
||||
# 3. 训练集:processors fit_transform
|
||||
if self.processors:
|
||||
print("[处理] 处理训练集...")
|
||||
for processor in self.processors:
|
||||
train_data = processor.fit_transform(train_data)
|
||||
self.fitted_processors.append(processor)
|
||||
|
||||
# 4. 训练模型
|
||||
print("[训练] 训练模型...")
|
||||
if not self.feature_cols:
|
||||
raise ValueError("feature_cols 不能为空")
|
||||
|
||||
X_train = train_data.select(self.feature_cols)
|
||||
y_train = train_data.select(self.target_col).to_series()
|
||||
self.model.fit(X_train, y_train)
|
||||
|
||||
# 5. 测试集:processors transform
|
||||
if self.processors and test_data is not train_data:
|
||||
print("[处理] 处理测试集...")
|
||||
for processor in self.fitted_processors:
|
||||
test_data = processor.transform(test_data)
|
||||
|
||||
# 6. 预测
|
||||
print("[预测] 生成预测...")
|
||||
X_test = test_data.select(self.feature_cols)
|
||||
predictions = self.model.predict(X_test)
|
||||
|
||||
# 7. 保存结果
|
||||
self.results = test_data.with_columns([pl.Series("prediction", predictions)])
|
||||
|
||||
# 8. 持久化模型
|
||||
if self.persist_model and self.model_save_path:
|
||||
print(f"[保存] 保存模型到 {self.model_save_path}...")
|
||||
self.save_model(self.model_save_path)
|
||||
|
||||
return self
|
||||
|
||||
def predict(self, data: pl.DataFrame) -> pl.DataFrame:
|
||||
"""对新数据进行预测
|
||||
|
||||
注意:新数据需要先经过股票池筛选,
|
||||
然后使用训练好的 processors 和 model 进行预测。
|
||||
|
||||
Args:
|
||||
data: 输入数据
|
||||
|
||||
Returns:
|
||||
包含预测列的数据
|
||||
"""
|
||||
# 应用 processors
|
||||
for processor in self.fitted_processors:
|
||||
data = processor.transform(data)
|
||||
|
||||
# 预测
|
||||
X = data.select(self.feature_cols)
|
||||
predictions = self.model.predict(X)
|
||||
|
||||
return data.with_columns([pl.Series("prediction", predictions)])
|
||||
|
||||
def get_results(self) -> Optional[pl.DataFrame]:
|
||||
"""获取所有预测结果
|
||||
|
||||
Returns:
|
||||
预测结果 DataFrame,包含原始列和 prediction 列
|
||||
"""
|
||||
return self.results
|
||||
|
||||
def save_results(self, path: str) -> None:
|
||||
"""保存预测结果到文件
|
||||
|
||||
Args:
|
||||
path: 保存路径(CSV 格式)
|
||||
"""
|
||||
if self.results is not None:
|
||||
self.results.write_csv(path)
|
||||
|
||||
def save_model(self, path: str) -> None:
|
||||
"""保存模型
|
||||
|
||||
Args:
|
||||
path: 模型保存路径
|
||||
"""
|
||||
self.model.save(path)
|
||||
Reference in New Issue
Block a user