# Training Module Implementation Plan

## 1. Overview

This plan describes the complete implementation of the ProStock training module, which trains standard regression models and outputs their predictions.

### 1.1 Goals

- Provide a concise training workflow
- Support standard regression models (LightGBM, CatBoost)
- Output interpretable predictions
- Integrate seamlessly with the existing factor engine

### 1.2 Design Principles

- **Separation of concerns**: the training workflow is kept separate from the modeling components
- **Registry mechanism**: decorators make components plug-and-play
- **Configuration-driven**: all parameters are managed through configuration classes
- **Stage-aware**: a Processor behaves differently in the training and test stages
- **Daily screening**: the stock pool is screened independently each day (market cap changes over time)
- **Model persistence**: trained models can be saved and loaded

## 2. Module Structure

```
src/
├── training/                          # Training module (workflow + components)
│   ├── __init__.py                    # Exports the core classes
│   ├── core/                          # Training workflow core
│   │   ├── __init__.py
│   │   ├── trainer.py                 # Trainer main class
│   │   └── stock_pool_manager.py      # Stock pool manager (independent daily screening)
│   ├── components/                    # Modeling components
│   │   ├── __init__.py
│   │   ├── base.py                    # BaseModel, BaseProcessor abstract base classes
│   │   ├── splitters.py               # Time-series split strategy (one-off split)
│   │   ├── selectors.py               # Stock pool selector configuration
│   │   ├── models/                    # Model implementations
│   │   │   ├── __init__.py
│   │   │   ├── lightgbm.py            # LightGBM regression model
│   │   │   └── catboost.py            # CatBoost regression model
│   │   └── processors/                # Data processors
│   │       ├── __init__.py
│   │       └── transforms.py          # Standardization (cross-sectional / time-series), winsorization
│   ├── config/                        # Configuration management
│   │   ├── __init__.py
│   │   └── config.py                  # TrainingConfig (pydantic)
│   └── registry.py                    # Component registry
│
└── experiment/                        # Experiment management (placeholder, not implemented yet)
    └── __init__.py
```

## 3. Core Component Design

### 3.1 Abstract Base Classes (components/base.py)

#### BaseModel

```python
import pickle
from abc import ABC, abstractmethod
from typing import Optional

import numpy as np
import pandas as pd
import polars as pl


class BaseModel(ABC):
    """Base class for models."""

    name: str = ""  # Model name

    @abstractmethod
    def fit(self, X: pl.DataFrame, y: pl.Series) -> "BaseModel":
        """Train the model."""

    @abstractmethod
    def predict(self, X: pl.DataFrame) -> np.ndarray:
        """Predict."""

    def feature_importance(self) -> Optional[pd.Series]:
        """Feature importance, or None if the model does not provide it."""
        return None

    def save(self, path: str) -> None:
        """Save the model to a file.

        The default implementation uses pickle; subclasses may override it.
        """
        with open(path, "wb") as f:
            pickle.dump(self, f)

    @classmethod
    def load(cls, path: str) -> "BaseModel":
        """Load a model from a file (only unpickle files you trust)."""
        with open(path, "rb") as f:
            return pickle.load(f)
```

#### BaseProcessor

```python
class BaseProcessor(ABC):
    """Base class for data processors.

    Important: a Processor behaves differently depending on the stage:
    - Training stage: fit_transform (learn parameters and apply them)
    - Validation/test stage: transform (reuse the parameters learned in training)

    The Processor instance is therefore kept after training and reused
    to transform validation and test data.
    """

    name: str = ""

    def fit(self, X: pl.DataFrame) -> "BaseProcessor":
        """Learn parameters (called only in the training stage)."""
        return self

    @abstractmethod
    def transform(self, X: pl.DataFrame) -> pl.DataFrame:
        """Transform the data."""

    def fit_transform(self, X: pl.DataFrame) -> pl.DataFrame:
        """Fit and transform (used in the training stage)."""
        return self.fit(X).transform(X)
```

### 3.2 Time-Series Split (components/splitters.py)

**Design note**: rolling training is not implemented for now; a one-off train/test split is used.

#### DateSplitter

```python
from typing import Tuple

import polars as pl


class DateSplitter:
    """One-off split based on date ranges.

    Splits the data into a training set and a test set by date, without rolling.

    Example:
        train_start: "20200101", train_end: "20221231"  (training set: 3 years)
        test_start:  "20230101", test_end:  "20231231"  (test set: 1 year)

    Characteristics:
    - One-off split, no rolling
    - Training and test sets do not overlap (the caller must pass non-overlapping ranges)
    - Based on actual date ranges rather than row counts
    """

    def __init__(
        self,
        train_start: str,  # Training period start, "YYYYMMDD"
        train_end: str,    # Training period end, "YYYYMMDD"
        test_start: str,   # Test period start, "YYYYMMDD"
        test_end: str,     # Test period end, "YYYYMMDD"
    ):
        self.train_start = train_start
        self.train_end = train_end
        self.test_start = test_start
        self.test_end = test_end

    def split(self, data: pl.DataFrame) -> Tuple[pl.DataFrame, pl.DataFrame]:
        """Split the data into a training set and a test set.

        The string comparison is chronological only because trade_date is
        stored as fixed-width "YYYYMMDD" strings.
        """
        train_data = data.filter(
            (pl.col("trade_date") >= self.train_start)
            & (pl.col("trade_date") <= self.train_end)
        )
        test_data = data.filter(
            (pl.col("trade_date") >= self.test_start)
            & (pl.col("trade_date") <= self.test_end)
        )
        return train_data, test_data
```

### 3.3 Stock Pool Selector Configuration (components/selectors.py)

**Design note**: the stock pool is screened independently each day; market-cap selection must be used together with StockPoolManager.

#### StockFilterConfig

```python
from dataclasses import dataclass
from typing import List


@dataclass
class StockFilterConfig:
    """Stock filter configuration.

    Filters out unwanted stocks (ChiNext, STAR Market, and so on).
    Filtering is based on the stock code alone and needs no external data.
    """

    exclude_cyb: bool = True  # Exclude ChiNext (300xxx)
    exclude_kcb: bool = True  # Exclude STAR Market (688xxx)
    exclude_bj: bool = True   # Exclude Beijing Stock Exchange (8xxxxx, 4xxxxx)
    exclude_st: bool = True   # Exclude ST stocks

    def filter_codes(self, codes: List[str]) -> List[str]:
        """Apply the filters and return the remaining stock codes."""
        result = []
        for code in codes:
            if self.exclude_cyb and code.startswith("300"):
                continue
            if self.exclude_kcb and code.startswith("688"):
                continue
            if self.exclude_bj and code.startswith(("8", "4")):
                continue
            # ST filtering needs extra data and is handled in StockPoolManager
            result.append(code)
        return result
```

#### MarketCapSelectorConfig

```python
@dataclass
class MarketCapSelectorConfig:
    """Market-cap selector configuration.

    Independently selects the n largest or smallest stocks by market cap each day.
    Market-cap data is fetched separately from the daily_basic table and is
    used for screening only.
    """

    enabled: bool = True              # Whether selection is enabled
    n: int = 100                      # Select the top n stocks
    ascending: bool = False           # False = largest market cap, True = smallest
    market_cap_col: str = "total_mv"  # Market-cap column name (from daily_basic)
```

### 3.4 Stock Pool Manager (core/stock_pool_manager.py)

**Design note**: the stock pool is screened independently each day, and market-cap data is fetched separately from the daily_basic table.

```python
from typing import Dict, List, Optional

import polars as pl

# StockFilterConfig and MarketCapSelectorConfig come from components/selectors.py;
# DataRouter comes from the existing data layer (import omitted here).


class StockPoolManager:
    """Stock pool manager - independent daily screening.

    Hard constraints:
    1. Market-cap data comes only from the daily_basic table and is used only for screening
    2. Market-cap data is never mixed into the feature matrix
    3. Screening is done independently each day (market cap changes over time)

    Daily flow:
        all stocks for the day
            ↓
        code filter (ChiNext, ST, ...)
            ↓
        query daily_basic for the day's market cap
            ↓
        market-cap selection (top N)
            ↓
        return the stocks selected for the day
    """

    def __init__(
        self,
        filter_config: StockFilterConfig,
        selector_config: Optional[MarketCapSelectorConfig],
        data_router: "DataRouter",  # Used to fetch daily_basic data
        code_col: str = "ts_code",
        date_col: str = "trade_date",
    ):
        self.filter_config = filter_config
        self.selector_config = selector_config
        self.data_router = data_router
        self.code_col = code_col
        self.date_col = date_col

    def filter_and_select_daily(self, data: pl.DataFrame) -> pl.DataFrame:
        """Screen the stock pool independently for each day.

        Args:
            data: full-market data after factor computation; must contain
                the trade_date and ts_code columns

        Returns:
            The filtered data, containing only the stocks selected each day

        Notes:
            - Processed per date
            - Market-cap data is fetched separately from daily_basic
            - Market-cap data stays isolated from the feature data
        """
        dates = data.select(self.date_col).unique().sort(self.date_col)

        result_frames = []
        for date in dates.to_series():
            # Data for the current day
            daily_data = data.filter(pl.col(self.date_col) == date)
            daily_codes = daily_data.select(self.code_col).to_series().to_list()

            # 1. Code filter
            filtered_codes = self.filter_config.filter_codes(daily_codes)

            # 2. Market-cap selection (if enabled)
            if self.selector_config and self.selector_config.enabled:
                # Fetch the day's market caps from daily_basic
                market_caps = self._get_market_caps_for_date(filtered_codes, date)
                selected_codes = self._select_by_market_cap(filtered_codes, market_caps)
            else:
                selected_codes = filtered_codes

            # 3. Keep the rows of the stocks selected for the day
            daily_selected = daily_data.filter(
                pl.col(self.code_col).is_in(selected_codes)
            )
            result_frames.append(daily_selected)

        if not result_frames:
            # pl.concat rejects an empty list; return an empty frame with the same schema
            return data.clear()
        return pl.concat(result_frames)

    def _get_market_caps_for_date(
        self, codes: List[str], date: str
    ) -> Dict[str, float]:
        """Fetch market caps for one date from the daily_basic table.

        Args:
            codes: list of stock codes
            date: date as "YYYYMMDD"

        Returns:
            A {stock code: market cap} dictionary
        """
        # TODO: query the daily_basic table through data_router
        raise NotImplementedError

    def _select_by_market_cap(
        self, codes: List[str], market_caps: Dict[str, float]
    ) -> List[str]:
        """Select stocks by market cap."""
        if not market_caps:
            return codes

        # Rank only codes that have a market cap, so that missing data
        # cannot be picked as "smallest" when ascending=True
        ranked = sorted(
            (c for c in codes if c in market_caps),
            key=market_caps.__getitem__,
            reverse=not self.selector_config.ascending,
        )
        return ranked[: self.selector_config.n]
```

### 3.5 Model Implementation (components/models/)

#### LightGBMModel

```python
from typing import List, Optional

import numpy as np
import pandas as pd
import polars as pl


@register_model("lightgbm")
class LightGBMModel(BaseModel):
    """LightGBM regression model."""

    name = "lightgbm"

    def __init__(
        self,
        objective: str = "regression",
        metric: str = "rmse",
        num_leaves: int = 31,
        learning_rate: float = 0.05,
        n_estimators: int = 100,
        **kwargs,
    ):
        self.params = {
            "objective": objective,
            "metric": metric,
            "num_leaves": num_leaves,
            "learning_rate": learning_rate,
            "n_estimators": n_estimators,
            **kwargs,
        }
        self.model = None
        self.feature_names_: List[str] = []

    def fit(self, X: pl.DataFrame, y: pl.Series) -> "LightGBMModel":
        """Train the model."""
        import lightgbm as lgb

        # Remember the column names; the NumPy conversion drops them
        self.feature_names_ = list(X.columns)

        # Convert to NumPy
        X_np = X.to_numpy()
        y_np = y.to_numpy()

        # Build the dataset
        train_data = lgb.Dataset(X_np, label=y_np, feature_name=self.feature_names_)

        # Train
        self.model = lgb.train(
            self.params,
            train_data,
            num_boost_round=self.params.get("n_estimators", 100),
        )
        return self

    def predict(self, X: pl.DataFrame) -> np.ndarray:
        """Predict."""
        if self.model is None:
            raise RuntimeError("Model not fitted yet")
        X_np = X.to_numpy()
        return self.model.predict(X_np)

    def feature_importance(self) -> Optional[pd.Series]:
        """Return the gain-based feature importance, or None before fitting."""
        if self.model is None:
            return None
        importance = self.model.feature_importance(importance_type="gain")
        return pd.Series(importance, index=self.feature_names_)

    def save(self, path: str) -> None:
        """Save the model (LightGBM native format)."""
        if self.model is None:
            raise RuntimeError("Model not fitted yet")
        self.model.save_model(path)

    @classmethod
    def load(cls, path: str) -> "LightGBMModel":
        """Load a model saved with save()."""
        import lightgbm as lgb

        instance = cls()
        instance.model = lgb.Booster(model_file=path)
        instance.feature_names_ = instance.model.feature_name()
        return instance
```

### 3.6 Data Processors (components/processors/)

#### StandardScaler

```python
from typing import List, Optional

import polars as pl


@register_processor("standard_scaler")
class StandardScaler(BaseProcessor):
    """Standardization processor (time-series standardization).

    Learns the mean and standard deviation on the whole training set,
    then applies them to both the training set and the test set.
    """

    name = "standard_scaler"

    def __init__(self, exclude_cols: Optional[List[str]] = None):
        self.exclude_cols = exclude_cols or ["ts_code", "trade_date"]
        self.mean_ = {}
        self.std_ = {}

    def fit(self, X: pl.DataFrame) -> "StandardScaler":
        """Compute the mean and standard deviation (training set only)."""
        numeric_cols = [
            c for c in X.columns
            if c not in self.exclude_cols and X[c].dtype.is_numeric()
        ]

        for col in numeric_cols:
            self.mean_[col] = X[col].mean()
            self.std_[col] = X[col].std()
        return self

    def transform(self, X: pl.DataFrame) -> pl.DataFrame:
        """Standardize using the parameters learned on the training set."""
        expressions = []
        for col in X.columns:
            if col in self.mean_:
                expr = ((pl.col(col) - self.mean_[col]) / self.std_[col]).alias(col)
                expressions.append(expr)
            else:
                expressions.append(pl.col(col))
        return X.select(expressions)
```

#### CrossSectionalStandardScaler

```python
@register_processor("cs_standard_scaler")
class CrossSectionalStandardScaler(BaseProcessor):
    """Cross-sectional standardization processor.

    Standardizes each day independently: for a given factor, the values of
    all stocks on that day are standardized together.

    Characteristics:
    - No fit needed; each day uses its own mean and standard deviation
    - Suited to cross-sectional factors; removes scale differences such as
      those from market cap or industry
    - Formula: z = (x - mean_today) / std_today
    """

    name = "cs_standard_scaler"

    def __init__(
        self,
        exclude_cols: Optional[List[str]] = None,
        date_col: str = "trade_date",
    ):
        self.exclude_cols = exclude_cols or ["ts_code", "trade_date"]
        self.date_col = date_col

    def transform(self, X: pl.DataFrame) -> pl.DataFrame:
        """Cross-sectional standardization.

        Computes the mean and standard deviation per date and standardizes
        each day independently. No fit is needed because each day uses its
        own statistics.
        """
        numeric_cols = [
            c for c in X.columns
            if c not in self.exclude_cols and X[c].dtype.is_numeric()
        ]

        # Per-date window statistics; overwrites each numeric column in place
        return X.with_columns([
            (
                (pl.col(col) - pl.col(col).mean().over(self.date_col))
                / pl.col(col).std().over(self.date_col)
            ).alias(col)
            for col in numeric_cols
        ])
```

#### Winsorizer

```python
@register_processor("winsorizer")
class Winsorizer(BaseProcessor):
    """Winsorization processor.

    Clips the extreme values of each column.
    Clipping can be global (quantiles learned on the whole training set)
    or cross-sectional (each day handled independently).
    """

    name = "winsorizer"

    def __init__(
        self,
        lower: float = 0.01,
        upper: float = 0.99,
        by_date: bool = False,  # True = winsorize each day independently, False = global
        date_col: str = "trade_date",
    ):
        self.lower = lower
        self.upper = upper
        self.by_date = by_date
        self.date_col = date_col
        self.bounds_ = {}  # Quantile bounds (global mode)

    def fit(self, X: pl.DataFrame) -> "Winsorizer":
        """Learn the quantile bounds (global mode only)."""
        if not self.by_date:
            numeric_cols = [c for c in X.columns if X[c].dtype.is_numeric()]
            for col in numeric_cols:
                self.bounds_[col] = {
                    "lower": X[col].quantile(self.lower),
                    "upper": X[col].quantile(self.upper),
                }
        return self

    def transform(self, X: pl.DataFrame) -> pl.DataFrame:
        """Winsorize."""
        if self.by_date:
            # Each day independently
            return self._transform_by_date(X)
        # Global
        return self._transform_global(X)

    def _transform_global(self, X: pl.DataFrame) -> pl.DataFrame:
        """Global winsorization (uses the bounds learned on the training set)."""
        expressions = []
        for col in X.columns:
            if col in self.bounds_:
                lower = self.bounds_[col]["lower"]
                upper = self.bounds_[col]["upper"]
                expr = pl.col(col).clip(lower, upper).alias(col)
                expressions.append(expr)
            else:
                expressions.append(pl.col(col))
        return X.select(expressions)

    def _transform_by_date(self, X: pl.DataFrame) -> pl.DataFrame:
        """Per-date winsorization."""
        # TODO: compute the quantiles per date and clip
        # Polars implementation...
        raise NotImplementedError
```

## 4. Training Workflow Design

### 4.1 Trainer Main Class (core/trainer.py)

```python
from typing import List, Optional

import polars as pl


class Trainer:
    """Trainer main class.

    Ties data processing, model training, and prediction into one workflow.

    Key design points:
    1. Factors are computed first (full market), then the stock pool is screened (daily)
    2. Processors are stage-aware: fit_transform on the training set, transform on the test set
    3. One-off training, no rolling
    4. Model persistence is supported
    """

    def __init__(
        self,
        model: BaseModel,
        pool_manager: Optional[StockPoolManager] = None,
        processors: Optional[List[BaseProcessor]] = None,
        splitter: Optional[DateSplitter] = None,
        target_col: str = "target",
        feature_cols: Optional[List[str]] = None,
        persist_model: bool = False,
        model_save_path: Optional[str] = None,
    ):
        self.model = model
        self.pool_manager = pool_manager
        self.processors = processors or []
        self.splitter = splitter
        self.target_col = target_col
        self.feature_cols = feature_cols
        self.persist_model = persist_model
        self.model_save_path = model_save_path

        # Processors fitted during training
        self.fitted_processors: List[BaseProcessor] = []
        self.results: Optional[pl.DataFrame] = None

    def train(self, data: pl.DataFrame) -> "Trainer":
        """Run the training workflow.

        Steps:
        1. Screen the stock pool daily (if pool_manager is configured)
        2. Split into training and test sets by date
        3. Training set: processors fit_transform
        4. Train the model
        5. Test set: processors transform (with the parameters learned on the training set)
        6. Predict
        7. Store the results
        8. Persist the model (if enabled)

        Args:
            data: full-market data after factor computation;
                must contain the ts_code and trade_date columns

        Returns:
            self (supports method chaining)
        """
        # Reset so that calling train() twice does not accumulate processors
        self.fitted_processors = []

        # 1. Stock pool screening (independent per day)
        if self.pool_manager:
            print("[filter] Screening the stock pool day by day...")
            data = self.pool_manager.filter_and_select_daily(data)

        # 2. Train/test split
        print("[split] Splitting into training and test sets...")
        train_data, test_data = self.splitter.split(data)

        # 3. Training set: processors fit_transform
        print("[process] Processing the training set...")
        for processor in self.processors:
            train_data = processor.fit_transform(train_data)
            self.fitted_processors.append(processor)

        # 4. Train the model
        print("[train] Training the model...")
        X_train = train_data.select(self.feature_cols)
        y_train = train_data.select(self.target_col).to_series()
        self.model.fit(X_train, y_train)

        # 5. Test set: processors transform
        print("[process] Processing the test set...")
        for processor in self.fitted_processors:
            test_data = processor.transform(test_data)

        # 6. Predict
        print("[predict] Generating predictions...")
        X_test = test_data.select(self.feature_cols)
        predictions = self.model.predict(X_test)

        # 7. Store the results
        self.results = test_data.with_columns([
            pl.Series("prediction", predictions)
        ])

        # 8. Persist the model
        if self.persist_model and self.model_save_path:
            print(f"[save] Saving the model to {self.model_save_path}...")
            self.save_model(self.model_save_path)

        return self

    def predict(self, data: pl.DataFrame) -> pl.DataFrame:
        """Predict on new data.

        Note: new data must go through stock pool screening first; it is then
        passed through the fitted processors and the trained model.
        """
        # Apply the processors
        for processor in self.fitted_processors:
            data = processor.transform(data)

        # Predict
        X = data.select(self.feature_cols)
        predictions = self.model.predict(X)

        return data.with_columns([pl.Series("prediction", predictions)])

    def get_results(self) -> Optional[pl.DataFrame]:
        """Return all predictions (None before train() has run)."""
        return self.results

    def save_results(self, path: str) -> None:
        """Save the predictions to a file."""
        if self.results is not None:
            self.results.write_csv(path)

    def save_model(self, path: str) -> None:
        """Save the model."""
        self.model.save(path)
```

### 4.2 Configuration Class (config/config.py)

```python
# Written against the Pydantic v1 API. In Pydantic v2, BaseSettings lives in the
# separate pydantic-settings package and min_items was renamed min_length.
from dataclasses import dataclass, field
from typing import Dict, List, Optional

from pydantic import BaseSettings, Field

# Reuse the selector configs from section 3.3 rather than redefining them,
# so that StockFilterConfig keeps its filter_codes() method.
from src.training.components.selectors import (
    MarketCapSelectorConfig,
    StockFilterConfig,
)


@dataclass
class ProcessorConfig:
    """Processor configuration."""
    name: str
    params: Dict = field(default_factory=dict)


class TrainingConfig(BaseSettings):
    """Training configuration."""

    # === Data (required) ===
    feature_cols: List[str] = Field(..., min_items=1)  # Feature columns, at least one
    target_col: str = "target"    # Target column
    date_col: str = "trade_date"  # Date column
    code_col: str = "ts_code"     # Stock code column

    # === Date split (required) ===
    train_start: str = Field(..., description="Training period start, YYYYMMDD")
    train_end: str = Field(..., description="Training period end, YYYYMMDD")
    test_start: str = Field(..., description="Test period start, YYYYMMDD")
    test_end: str = Field(..., description="Test period end, YYYYMMDD")

    # === Stock pool ===
    stock_filter: StockFilterConfig = Field(
        default_factory=lambda: StockFilterConfig(
            exclude_cyb=True,
            exclude_kcb=True,
            exclude_bj=True,
            exclude_st=True,
        )
    )
    # Note: stock_selector = None skips market-cap selection
    stock_selector: Optional[MarketCapSelectorConfig] = Field(
        default_factory=lambda: MarketCapSelectorConfig(
            enabled=True,
            n=100,
            ascending=False,
            market_cap_col="total_mv",
        )
    )

    # === Model ===
    model_type: str = "lightgbm"
    model_params: Dict = Field(default_factory=dict)

    # === Processors ===
    processors: List[ProcessorConfig] = Field(
        default_factory=lambda: [
            ProcessorConfig(name="winsorizer", params={"lower": 0.01, "upper": 0.99}),
            ProcessorConfig(name="cs_standard_scaler", params={}),
        ]
    )

    # === Persistence ===
    persist_model: bool = False            # Off by default
    model_save_path: Optional[str] = None  # Where to persist the model

    # === Output ===
    output_dir: str = "output"
    save_predictions: bool = True
```

## 5. Usage Examples

### 5.1 Basic Usage

```python
from src.training import Trainer, LightGBMModel, DateSplitter
from src.training.config import StockFilterConfig, MarketCapSelectorConfig
from src.training.core.stock_pool_manager import StockPoolManager
from src.training.components.processors import Winsorizer, CrossSectionalStandardScaler
from src.factors import FactorEngine
import polars as pl

# 1. Factor computation (full-market data)
engine = FactorEngine()
all_data = engine.compute(
    factor_names=["factor1", "factor2", "factor3", "future_return_5d"],
    start_date="20200101",
    end_date="20231231",
    # stock_codes is not given, so the full market is returned
)

# 2. Stock pool manager (independent daily screening)
pool_manager = StockPoolManager(
    filter_config=StockFilterConfig(
        exclude_cyb=True,
        exclude_kcb=True,
        exclude_bj=True,
        exclude_st=True,
    ),
    selector_config=MarketCapSelectorConfig(
        enabled=True,
        n=100,
        ascending=False,  # Largest market cap
    ),
    data_router=data_router,  # Existing DataRouter instance, used to fetch daily_basic data
)

# 3. Model
model = LightGBMModel(
    objective="regression",
    n_estimators=100,
    learning_rate=0.05,
)

# 4. Processors
processors = [
    Winsorizer(lower=0.01, upper=0.99, by_date=False),  # Global winsorization
    CrossSectionalStandardScaler(),                     # Cross-sectional standardization (per day)
]

# 5. Splitter
splitter = DateSplitter(
    train_start="20200101",
    train_end="20221231",
    test_start="20230101",
    test_end="20231231",
)

# 6. Trainer
trainer = Trainer(
    model=model,
    pool_manager=pool_manager,  # Daily stock pool screening
    processors=processors,
    splitter=splitter,
    target_col="future_return_5d",
    feature_cols=["factor1", "factor2", "factor3"],
    persist_model=True,  # Enable persistence
    model_save_path="output/model.pkl",
)

# 7. Train (pass the full-market data)
trainer.train(all_data)

# 8. Get the results
results = trainer.get_results()  # Includes the predictions

# 9. Save the results
trainer.save_results("output/predictions.csv")

# 10. Load the model and predict on new data.
# new_data must hold only the feature columns, already screened and transformed
# by the fitted processors; the loaded model does not apply them itself.
loaded_model = LightGBMModel.load("output/model.pkl")
new_predictions = loaded_model.predict(new_data)
```

### 5.2 Configuration-Driven Usage

```python
from src.training.config import TrainingConfig, StockFilterConfig
from src.training import Trainer, DateSplitter
from src.training.registry import ModelRegistry, ProcessorRegistry
from src.training.core.stock_pool_manager import StockPoolManager
from src.factors import FactorEngine

# 1. Configuration (required fields are validated)
config = TrainingConfig(
    feature_cols=["factor1", "factor2", "factor3"],
    target_col="future_return_5d",
    train_start="20200101",
    train_end="20221231",
    test_start="20230101",
    test_end="20231231",
    stock_filter=StockFilterConfig(
        exclude_cyb=True,
        exclude_kcb=True,
    ),
    stock_selector=None,  # Skip market-cap selection
    persist_model=False,  # Do not persist
)

# 2. Factor computation
engine = FactorEngine()
all_data = engine.compute(
    factor_names=config.feature_cols + [config.target_col],
    start_date=config.train_start,
    end_date=config.test_end,
)

# 3. Build the components from the configuration
model = ModelRegistry.get_model(config.model_type)(**config.model_params)
processors = [
    ProcessorRegistry.get_processor(p.name)(**p.params)
    for p in config.processors
]

# 4. Stock pool manager
pool_manager = StockPoolManager(
    filter_config=config.stock_filter,
    selector_config=config.stock_selector,
    data_router=data_router,  # Existing DataRouter instance
)

# 5. Create and run the trainer
trainer = Trainer(
    model=model,
    pool_manager=pool_manager,
    processors=processors,
    splitter=DateSplitter(
        train_start=config.train_start,
        train_end=config.train_end,
        test_start=config.test_start,
        test_end=config.test_end,
    ),
    target_col=config.target_col,
    feature_cols=config.feature_cols,
)

trainer.train(all_data)
results = trainer.get_results()
```

## 6. Implementation Order

Implement and commit in the following order:

### Commit 1: Base architecture
- `training/__init__.py`
- `training/components/__init__.py`
- `training/components/base.py` (BaseModel, BaseProcessor, including save/load)
- `training/registry.py` (component registry)

### Commit 2: Data split
- `training/components/splitters.py` (DateSplitter, one-off split only)

### Commit 3: Stock pool selector configuration
- `training/components/selectors.py` (StockFilterConfig, MarketCapSelectorConfig)

### Commit 4: Data processors
- `training/components/processors/__init__.py`
- `training/components/processors/transforms.py`
  - Winsorizer
  - StandardScaler
  - CrossSectionalStandardScaler

### Commit 5: LightGBM model
- `training/components/models/__init__.py`
- `training/components/models/lightgbm.py` (including save/load)

### Commit 6: Stock pool manager
- `training/core/__init__.py`
- `training/core/stock_pool_manager.py` (independent daily screening)

### Commit 7: Trainer
- `training/core/trainer.py`

### Commit 8: Configuration management
- `training/config/__init__.py`
- `training/config/config.py` (TrainingConfig, including required-field validation)

### Commit 9: Placeholder experiment module
- `experiment/__init__.py`

## 7. Notes

### 7.1 Stock Pool Processing Order (Daily)

```
all stock data for the day
    ↓
code filter (ChiNext, ST, ...)
    ↓
query daily_basic for the day's market cap
    ↓
market-cap selection (top N)
    ↓
return the stocks selected for the day
```

- Screening is independent each day because market cap changes over time
- Market-cap data comes only from daily_basic
- Market-cap data is never mixed into the feature matrix

### 7.2 Processor Behavior by Stage

| Processor | Training set | Test set |
|-----------|--------------|----------|
| StandardScaler | fit_transform | transform (training-set parameters) |
| CrossSectionalStandardScaler | transform | transform (each day independently) |
| Winsorizer (global) | fit_transform | transform (training-set parameters) |
| Winsorizer (by_date) | transform | transform (each day independently) |

### 7.3 Dependencies

- Polars for data processing
- LightGBM for model training
- Pydantic for configuration management
- Market-cap data from the daily_basic table (separate data source)

### 7.4 Removed Features

The following items from the original plan are dropped from this implementation:

1. **Feature selection** (processors/selectors.py)
2. **Rolling training** (WalkForward, ExpandingWindow)
3. **Result analysis tools** (complex analysis features)
4. **validator.py, evaluator.py** (removed; metrics are not implemented)

### 7.5 New Features

1. **StockPoolManager** (independent daily screening)
2. **Model persistence** (save/load, off by default)
3. **Required-field validation in the configuration** (feature_cols, date ranges)
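
The plan lists `training/registry.py` and uses `@register_model`, `@register_processor`, `ModelRegistry.get_model` and `ProcessorRegistry.get_processor`, but does not show the registry itself. The following is a minimal sketch of how it could look, assuming a plain name-to-class dictionary behind each decorator. The `_Registry` helper, the duplicate-name check, the error messages and the `DummyModel` class are illustrative assumptions, not part of the plan.

```python
from typing import Callable, Dict


class _Registry:
    """Hypothetical helper: maps a string name to a registered class."""

    def __init__(self, kind: str) -> None:
        self.kind = kind
        self._items: Dict[str, type] = {}

    def register(self, name: str) -> Callable[[type], type]:
        """Return a class decorator that stores the class under `name`."""
        def decorator(cls: type) -> type:
            if name in self._items:
                raise ValueError(f"{self.kind} '{name}' is already registered")
            self._items[name] = cls
            return cls  # The class itself is returned unchanged
        return decorator

    def get(self, name: str) -> type:
        """Look up a registered class, listing known names on failure."""
        try:
            return self._items[name]
        except KeyError:
            known = ", ".join(sorted(self._items)) or "<none>"
            raise KeyError(
                f"unknown {self.kind} '{name}' (registered: {known})"
            ) from None


_models = _Registry("model")
_processors = _Registry("processor")

# Decorator names as used in sections 3.5 and 3.6
register_model = _models.register
register_processor = _processors.register


class ModelRegistry:
    @staticmethod
    def get_model(name: str) -> type:
        return _models.get(name)


class ProcessorRegistry:
    @staticmethod
    def get_processor(name: str) -> type:
        return _processors.get(name)


# Illustration only: a stand-in class registered under a made-up name
@register_model("dummy")
class DummyModel:
    def __init__(self, scale: float = 1.0) -> None:
        self.scale = scale
```

With a registry of this shape, `ModelRegistry.get_model(config.model_type)(**config.model_params)` resolves the class at run time. Decorator-based registration only happens when the decorated module is imported, so the model and processor modules would need to be imported (for example from `training/__init__.py`) before any lookup.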