feat(training): 实现训练模块核心组件(commits 6-9)

- StockPoolManager:每日独立筛选股票池,支持代码过滤和市值选择
- Trainer:整合训练完整流程,支持 processor 分阶段行为和模型持久化
- TrainingConfig:pydantic 配置管理,含必填字段和日期验证
- experiment 模块:预留结构
- 从计划中移除 metrics 组件
- 调整 commit 序号(7-10 → 6-9)
- 更新 training/__init__.py 导出所有公开 API
This commit is contained in:
2026-03-03 22:57:01 +08:00
parent f35a6a76a6
commit 192718095f
9 changed files with 584 additions and 73 deletions

View File

@@ -40,9 +40,6 @@ src/
│ │ ├── processors/ # 数据处理器 │ │ ├── processors/ # 数据处理器
│ │ │ ├── __init__.py │ │ │ ├── __init__.py
│ │ │ └── transforms.py # 标准化(截面/时序)、缩尾 │ │ │ └── transforms.py # 标准化(截面/时序)、缩尾
│ │ └── metrics/ # 评估指标
│ │ ├── __init__.py
│ │ └── metrics.py # IC, RankIC, MSE, MAE
│ ├── config/ # 配置管理 │ ├── config/ # 配置管理
│ │ ├── __init__.py │ │ ├── __init__.py
│ │ └── config.py # TrainingConfig (pydantic) │ │ └── config.py # TrainingConfig (pydantic)
@@ -568,30 +565,6 @@ class Winsorizer(BaseProcessor):
pass pass
``` ```
### 3.7 评估指标 (components/metrics/)
```python
def ic_score(y_true: np.ndarray, y_pred: np.ndarray) -> float:
"""信息系数 (Pearson 相关系数)"""
from scipy.stats import pearsonr
corr, _ = pearsonr(y_true, y_pred)
return corr
def rank_ic_score(y_true: np.ndarray, y_pred: np.ndarray) -> float:
"""秩相关系数 (Spearman 相关系数)"""
from scipy.stats import spearmanr
corr, _ = spearmanr(y_true, y_pred)
return corr
def mse_score(y_true: np.ndarray, y_pred: np.ndarray) -> float:
"""均方误差"""
return np.mean((y_true - y_pred) ** 2)
def mae_score(y_true: np.ndarray, y_pred: np.ndarray) -> float:
"""平均绝对误差"""
return np.mean(np.abs(y_true - y_pred))
```
## 4. 训练流程设计 ## 4. 训练流程设计
### 4.1 Trainer 主类 (core/trainer.py) ### 4.1 Trainer 主类 (core/trainer.py)
@@ -615,7 +588,6 @@ class Trainer:
pool_manager: Optional[StockPoolManager] = None, pool_manager: Optional[StockPoolManager] = None,
processors: List[BaseProcessor] = None, processors: List[BaseProcessor] = None,
splitter: DateSplitter = None, splitter: DateSplitter = None,
metrics: List[str] = None,
target_col: str = "target", target_col: str = "target",
feature_cols: List[str] = None, feature_cols: List[str] = None,
persist_model: bool = False, persist_model: bool = False,
@@ -625,7 +597,6 @@ class Trainer:
self.pool_manager = pool_manager self.pool_manager = pool_manager
self.processors = processors or [] self.processors = processors or []
self.splitter = splitter self.splitter = splitter
self.metrics = metrics or ["ic", "rank_ic", "mse", "mae"]
self.target_col = target_col self.target_col = target_col
self.feature_cols = feature_cols self.feature_cols = feature_cols
self.persist_model = persist_model self.persist_model = persist_model
@@ -634,7 +605,6 @@ class Trainer:
# 存储训练后的处理器 # 存储训练后的处理器
self.fitted_processors: List[BaseProcessor] = [] self.fitted_processors: List[BaseProcessor] = []
self.results: pl.DataFrame = None self.results: pl.DataFrame = None
self.metrics_results: Dict[str, float] = {}
def train(self, data: pl.DataFrame) -> "Trainer": def train(self, data: pl.DataFrame) -> "Trainer":
"""执行训练流程 """执行训练流程
@@ -686,38 +656,18 @@ class Trainer:
X_test = test_data.select(self.feature_cols) X_test = test_data.select(self.feature_cols)
predictions = self.model.predict(X_test) predictions = self.model.predict(X_test)
# 7. 评估 # 7. 保存结果
print("[评估] 计算指标...")
y_test = test_data.select(self.target_col).to_series().to_numpy()
self._evaluate(y_test, predictions)
# 8. 保存结果
self.results = test_data.with_columns([ self.results = test_data.with_columns([
pl.Series("prediction", predictions) pl.Series("prediction", predictions)
]) ])
# 9. 持久化模型 # 8. 持久化模型
if self.persist_model and self.model_save_path: if self.persist_model and self.model_save_path:
print(f"[保存] 保存模型到 {self.model_save_path}...") print(f"[保存] 保存模型到 {self.model_save_path}...")
self.save_model(self.model_save_path) self.save_model(self.model_save_path)
return self return self
def _evaluate(self, y_true: np.ndarray, y_pred: np.ndarray):
"""计算评估指标"""
from .components.metrics import ic_score, rank_ic_score, mse_score, mae_score
metric_funcs = {
"ic": ic_score,
"rank_ic": rank_ic_score,
"mse": mse_score,
"mae": mae_score,
}
for metric in self.metrics:
if metric in metric_funcs:
self.metrics_results[metric] = metric_funcs[metric](y_true, y_pred)
def predict(self, data: pl.DataFrame) -> pl.DataFrame: def predict(self, data: pl.DataFrame) -> pl.DataFrame:
"""对新数据进行预测 """对新数据进行预测
@@ -738,10 +688,6 @@ class Trainer:
"""获取所有预测结果""" """获取所有预测结果"""
return self.results return self.results
def get_metrics(self) -> Dict[str, float]:
"""获取评估指标"""
return self.metrics_results
def save_results(self, path: str): def save_results(self, path: str):
"""保存预测结果到文件""" """保存预测结果到文件"""
if self.results is not None: if self.results is not None:
@@ -827,9 +773,6 @@ class TrainingConfig(BaseSettings):
] ]
) )
# === 评估指标 ===
metrics: List[str] = ["ic", "rank_ic", "mse", "mae"]
# === 持久化配置 === # === 持久化配置 ===
persist_model: bool = False # 默认不持久化 persist_model: bool = False # 默认不持久化
model_save_path: Optional[str] = None # 持久化路径 model_save_path: Optional[str] = None # 持久化路径
@@ -914,8 +857,6 @@ trainer.train(all_data)
# 8. 获取结果 # 8. 获取结果
results = trainer.get_results() # 包含预测值 results = trainer.get_results() # 包含预测值
metrics = trainer.get_metrics() # IC, RankIC, etc.
print("评估指标:", metrics)
# 9. 保存结果 # 9. 保存结果
trainer.save_results("output/predictions.csv") trainer.save_results("output/predictions.csv")
@@ -989,7 +930,6 @@ trainer = Trainer(
trainer.train(all_data) trainer.train(all_data)
results = trainer.get_results() results = trainer.get_results()
metrics = trainer.get_metrics()
``` ```
## 6. 实现顺序 ## 6. 实现顺序
@@ -1019,22 +959,18 @@ metrics = trainer.get_metrics()
- `training/components/models/__init__.py` - `training/components/models/__init__.py`
- `training/components/models/lightgbm.py`(含 save_model/load_model - `training/components/models/lightgbm.py`(含 save_model/load_model
### Commit 6: 评估指标 ### Commit 6: 股票池管理器
- `training/components/metrics/__init__.py`
- `training/components/metrics/metrics.py`IC, RankIC, MSE, MAE
### Commit 7: 股票池管理器
- `training/core/__init__.py` - `training/core/__init__.py`
- `training/core/stock_pool_manager.py`(每日独立筛选) - `training/core/stock_pool_manager.py`(每日独立筛选)
### Commit 8: Trainer 训练器 ### Commit 7: Trainer 训练器
- `training/core/trainer.py` - `training/core/trainer.py`
### Commit 9: 配置管理 ### Commit 8: 配置管理
- `training/config/__init__.py` - `training/config/__init__.py`
- `training/config/config.py`TrainingConfig含必填校验 - `training/config/config.py`TrainingConfig含必填校验
### Commit 10: 预留实验模块 ### Commit 9: 预留实验模块
- `experiment/__init__.py` - `experiment/__init__.py`
## 7. 注意事项 ## 7. 注意事项
@@ -1080,7 +1016,7 @@ metrics = trainer.get_metrics()
1. **特征选择**processors/selectors.py 1. **特征选择**processors/selectors.py
2. **滚动训练**WalkForward, ExpandingWindow 2. **滚动训练**WalkForward, ExpandingWindow
3. **结果分析工具**(复杂分析功能) 3. **结果分析工具**(复杂分析功能)
4. **validator.py, evaluator.py**简化为 metrics.py 4. **validator.py, evaluator.py**已删除,不实现 metrics
### 7.5 新增功能 ### 7.5 新增功能

View File

@@ -0,0 +1,7 @@
"""实验管理模块(预留结构)
此模块为预留结构,用于未来的实验管理功能。
暂不提供具体实现。
"""
__all__ = []

View File

@@ -1,6 +1,6 @@
"""训练模块 - ProStock 量化投资框架 """训练模块 - ProStock 量化投资框架
提供模型训练、数据处理和评估的完整流程。 提供模型训练、数据处理和预测的完整流程。
""" """
# 基础抽象类 # 基础抽象类
@@ -14,6 +14,31 @@ from src.training.registry import (
register_processor, register_processor,
) )
# 数据划分器
from src.training.components.splitters import DateSplitter
# 股票池选择器配置
from src.training.components.selectors import (
MarketCapSelectorConfig,
StockFilterConfig,
)
# 数据处理器
from src.training.components.processors import (
CrossSectionalStandardScaler,
StandardScaler,
Winsorizer,
)
# 模型
from src.training.components.models import LightGBMModel
# 训练核心
from src.training.core import StockPoolManager, Trainer
# 配置
from src.training.config import TrainingConfig
__all__ = [ __all__ = [
# 基础抽象类 # 基础抽象类
"BaseModel", "BaseModel",
@@ -23,4 +48,20 @@ __all__ = [
"ProcessorRegistry", "ProcessorRegistry",
"register_model", "register_model",
"register_processor", "register_processor",
# 数据划分器
"DateSplitter",
# 股票池选择器配置
"StockFilterConfig",
"MarketCapSelectorConfig",
# 数据处理器
"StandardScaler",
"CrossSectionalStandardScaler",
"Winsorizer",
# 模型
"LightGBMModel",
# 训练核心
"StockPoolManager",
"Trainer",
# 配置
"TrainingConfig",
] ]

View File

@@ -0,0 +1,18 @@
"""训练配置管理
提供 TrainingConfig 配置类和相关配置数据类。
"""
from src.training.config.config import (
MarketCapSelectorConfig,
ProcessorConfig,
StockFilterConfig,
TrainingConfig,
)
__all__ = [
"TrainingConfig",
"StockFilterConfig",
"MarketCapSelectorConfig",
"ProcessorConfig",
]

View File

@@ -0,0 +1,141 @@
"""训练配置管理
提供 TrainingConfig 配置类,使用 pydantic 进行参数验证。
"""
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
from pydantic import Field, validator
from pydantic_settings import BaseSettings
@dataclass
class StockFilterConfig:
"""股票过滤器配置"""
exclude_cyb: bool = True # 排除创业板
exclude_kcb: bool = True # 排除科创板
exclude_bj: bool = True # 排除北交所
exclude_st: bool = True # 排除ST股票
@dataclass
class MarketCapSelectorConfig:
"""市值选择器配置"""
enabled: bool = True # 是否启用
n: int = 100 # 选择前 n 只
ascending: bool = False # False=最大市值, True=最小市值
market_cap_col: str = "total_mv" # 市值列名(来自 daily_basic
@dataclass
class ProcessorConfig:
"""处理器配置"""
name: str
params: Dict[str, Any] = Field(default_factory=dict)
class TrainingConfig(BaseSettings):
"""训练配置类
所有配置参数通过此类管理,支持 pydantic 验证。
"""
# === 数据配置(必填)===
feature_cols: List[str] = Field(..., min_items=1) # 特征列名,至少一个
target_col: str = "target" # 目标变量列名
date_col: str = "trade_date" # 日期列名
code_col: str = "ts_code" # 股票代码列名
# === 日期划分(必填)===
train_start: str = Field(..., description="训练期开始 YYYYMMDD")
train_end: str = Field(..., description="训练期结束 YYYYMMDD")
test_start: str = Field(..., description="测试期开始 YYYYMMDD")
test_end: str = Field(..., description="测试期结束 YYYYMMDD")
# === 股票池配置 ===
stock_filter: StockFilterConfig = Field(
default_factory=lambda: StockFilterConfig(
exclude_cyb=True,
exclude_kcb=True,
exclude_bj=True,
exclude_st=True,
)
)
stock_selector: Optional[MarketCapSelectorConfig] = Field(
default_factory=lambda: MarketCapSelectorConfig(
enabled=True,
n=100,
ascending=False,
market_cap_col="total_mv",
)
)
# 注意:如果 stock_selector = None则跳过市值选择
# === 模型配置 ===
model_type: str = "lightgbm"
model_params: Dict[str, Any] = Field(default_factory=dict)
# === 处理器配置 ===
processors: List[ProcessorConfig] = Field(
default_factory=lambda: [
ProcessorConfig(name="winsorizer", params={"lower": 0.01, "upper": 0.99}),
ProcessorConfig(name="cs_standard_scaler", params={}),
]
)
# === 持久化配置 ===
persist_model: bool = False # 默认不持久化
model_save_path: Optional[str] = None # 持久化路径
# === 输出配置 ===
output_dir: str = "output"
save_predictions: bool = True
@validator("train_start", "train_end", "test_start", "test_end")
def validate_date_format(cls, v: str) -> str:
"""验证日期格式为 YYYYMMDD"""
if not isinstance(v, str) or len(v) != 8:
raise ValueError(f"日期必须是格式为 'YYYYMMDD' 的8位字符串得到: {v}")
try:
int(v)
except ValueError:
raise ValueError(f"日期必须是数字字符串,得到: {v}")
return v
@validator("train_end")
def validate_train_dates(cls, v: str, values: Dict[str, Any]) -> str:
"""验证训练日期范围"""
if "train_start" in values and values["train_start"] > v:
raise ValueError(
f"train_start ({values['train_start']}) 必须早于或等于 train_end ({v})"
)
return v
@validator("test_end")
def validate_test_dates(cls, v: str, values: Dict[str, Any]) -> str:
"""验证测试日期范围"""
if "test_start" in values and values["test_start"] > v:
raise ValueError(
f"test_start ({values['test_start']}) 必须早于或等于 test_end ({v})"
)
return v
@validator("test_start")
def validate_no_overlap(cls, v: str, values: Dict[str, Any]) -> str:
"""验证训练集和测试集不重叠"""
if "train_end" in values and v <= values["train_end"]:
raise ValueError(
f"测试集开始日期 ({v}) 必须晚于训练集结束日期 ({values['train_end']})"
"以确保训练集和测试集不重叠"
)
return v
class Config:
"""Pydantic 配置"""
env_prefix = "TRAINING_" # 环境变量前缀
env_nested_delimiter = "__"

View File

@@ -0,0 +1,9 @@
"""训练核心模块
包含 Trainer 主类和股票池管理器。
"""
from src.training.core.stock_pool_manager import StockPoolManager
from src.training.core.trainer import Trainer
__all__ = ["StockPoolManager", "Trainer"]

View File

@@ -0,0 +1,171 @@
"""股票池管理器
每日独立筛选股票池,市值数据从 daily_basic 表独立获取。
"""
from typing import TYPE_CHECKING, Dict, List, Optional
import polars as pl
from src.training.components.selectors import MarketCapSelectorConfig, StockFilterConfig
if TYPE_CHECKING:
from src.factors.engine.data_router import DataRouter
class StockPoolManager:
"""股票池管理器 - 每日独立筛选
重要约束:
1. 市值数据仅从 daily_basic 表获取,仅用于筛选
2. 市值数据绝不混入特征矩阵
3. 每日独立筛选(市值是动态变化的)
处理流程(每日):
当日所有股票
代码过滤创业板、ST等
查询 daily_basic 获取当日市值
市值选择前N只
返回当日选中股票列表
"""
def __init__(
self,
filter_config: StockFilterConfig,
selector_config: Optional[MarketCapSelectorConfig],
data_router: "DataRouter",
code_col: str = "ts_code",
date_col: str = "trade_date",
):
"""初始化股票池管理器
Args:
filter_config: 股票过滤器配置
selector_config: 市值选择器配置None 表示跳过市值选择
data_router: 数据路由器,用于获取 daily_basic 数据
code_col: 股票代码列名
date_col: 日期列名
"""
self.filter_config = filter_config
self.selector_config = selector_config
self.data_router = data_router
self.code_col = code_col
self.date_col = date_col
def filter_and_select_daily(self, data: pl.DataFrame) -> pl.DataFrame:
"""每日独立筛选股票池
Args:
data: 因子计算后的全市场数据,必须包含 trade_date 和 ts_code 列
Returns:
筛选后的数据,仅包含每日选中的股票
Note:
- 按日期分组处理
- 市值数据从 daily_basic 独立获取
- 保持市值数据与特征数据隔离
"""
dates = data.select(self.date_col).unique().sort(self.date_col)
result_frames = []
for date in dates.to_series():
# 获取当日数据
daily_data = data.filter(pl.col(self.date_col) == date)
daily_codes = daily_data.select(self.code_col).to_series().to_list()
# 1. 代码过滤
filtered_codes = self.filter_config.filter_codes(daily_codes)
# 2. 市值选择(如果启用)
if self.selector_config and self.selector_config.enabled:
# 从 daily_basic 获取当日市值
market_caps = self._get_market_caps_for_date(filtered_codes, date)
selected_codes = self._select_by_market_cap(filtered_codes, market_caps)
else:
selected_codes = filtered_codes
# 3. 保留当日选中的股票数据
daily_selected = daily_data.filter(
pl.col(self.code_col).is_in(selected_codes)
)
result_frames.append(daily_selected)
return pl.concat(result_frames)
def _get_market_caps_for_date(
self, codes: List[str], date: str
) -> Dict[str, float]:
"""从 daily_basic 表获取指定日期的市值数据
Args:
codes: 股票代码列表
date: 日期 "YYYYMMDD"
Returns:
{股票代码: 市值} 的字典
"""
if not codes:
return {}
assert self.selector_config is not None, (
"selector_config should not be None when calling _get_market_caps_for_date"
)
try:
# 通过 data_router 查询 daily_basic 表
from src.factors.engine.data_spec import DataSpec
data_specs = [
DataSpec("daily_basic", [self.selector_config.market_cap_col])
]
df = self.data_router.fetch_data(
data_specs=data_specs,
start_date=date,
end_date=date,
stock_codes=codes,
)
# 转换为字典
market_caps = {}
for row in df.iter_rows(named=True):
code = row[self.code_col]
cap = row.get(self.selector_config.market_cap_col)
if cap is not None and code in codes:
market_caps[code] = float(cap)
return market_caps
except Exception as e:
print(f"[警告] 获取 {date} 市值数据失败: {e}")
return {}
def _select_by_market_cap(
self, codes: List[str], market_caps: Dict[str, float]
) -> List[str]:
"""根据市值选择股票
Args:
codes: 股票代码列表
market_caps: 市值数据字典
Returns:
选中的股票代码列表
"""
if self.selector_config is None:
return codes
if not market_caps:
return codes[: self.selector_config.n]
# 按市值排序并选择前N只
sorted_codes = sorted(
codes,
key=lambda c: market_caps.get(c, 0),
reverse=not self.selector_config.ascending,
)
return sorted_codes[: self.selector_config.n]

View File

@@ -0,0 +1,179 @@
"""训练器主类
整合数据处理、模型训练、预测的完整流程。
"""
from typing import List, Optional
import polars as pl
from src.training.components.base import BaseModel, BaseProcessor
from src.training.components.splitters import DateSplitter
from src.training.core.stock_pool_manager import StockPoolManager
class Trainer:
"""训练器主类
整合数据处理、模型训练、预测的完整流程。
关键设计:
1. 因子先计算(全市场),再筛选股票池(每日独立)
2. Processor 分阶段行为:训练集 fit_transform测试集 transform
3. 一次性训练,不滚动
4. 支持模型持久化
"""
def __init__(
self,
model: BaseModel,
pool_manager: Optional[StockPoolManager] = None,
processors: Optional[List[BaseProcessor]] = None,
splitter: Optional[DateSplitter] = None,
target_col: str = "target",
feature_cols: Optional[List[str]] = None,
persist_model: bool = False,
model_save_path: Optional[str] = None,
):
"""初始化训练器
Args:
model: 模型实例
pool_manager: 股票池管理器None 表示不筛选
processors: 数据处理器列表
splitter: 数据划分器
target_col: 目标变量列名
feature_cols: 特征列名列表
persist_model: 是否保存模型
model_save_path: 模型保存路径
"""
self.model = model
self.pool_manager = pool_manager
self.processors = processors or []
self.splitter = splitter
self.target_col = target_col
self.feature_cols = feature_cols or []
self.persist_model = persist_model
self.model_save_path = model_save_path
# 存储训练后的处理器
self.fitted_processors: List[BaseProcessor] = []
self.results: Optional[pl.DataFrame] = None
def train(self, data: pl.DataFrame) -> "Trainer":
"""执行训练流程
流程:
1. 股票池每日筛选(如果配置了 pool_manager
2. 按日期划分训练集/测试集
3. 训练集processors fit_transform
4. 训练模型
5. 测试集processors transform使用训练集学到的参数
6. 预测
7. 保存结果
8. 持久化模型(如果启用)
Args:
data: 因子计算后的全市场数据
必须包含 ts_code 和 trade_date 列
Returns:
self (支持链式调用)
"""
# 1. 股票池筛选(每日独立)
if self.pool_manager:
print("[筛选] 每日独立筛选股票池...")
data = self.pool_manager.filter_and_select_daily(data)
# 2. 划分训练/测试集
if self.splitter:
print("[划分] 划分训练集和测试集...")
train_data, test_data = self.splitter.split(data)
else:
# 没有划分器,全部作为训练集
train_data = data
test_data = data
# 3. 训练集processors fit_transform
if self.processors:
print("[处理] 处理训练集...")
for processor in self.processors:
train_data = processor.fit_transform(train_data)
self.fitted_processors.append(processor)
# 4. 训练模型
print("[训练] 训练模型...")
if not self.feature_cols:
raise ValueError("feature_cols 不能为空")
X_train = train_data.select(self.feature_cols)
y_train = train_data.select(self.target_col).to_series()
self.model.fit(X_train, y_train)
# 5. 测试集processors transform
if self.processors and test_data is not train_data:
print("[处理] 处理测试集...")
for processor in self.fitted_processors:
test_data = processor.transform(test_data)
# 6. 预测
print("[预测] 生成预测...")
X_test = test_data.select(self.feature_cols)
predictions = self.model.predict(X_test)
# 7. 保存结果
self.results = test_data.with_columns([pl.Series("prediction", predictions)])
# 8. 持久化模型
if self.persist_model and self.model_save_path:
print(f"[保存] 保存模型到 {self.model_save_path}...")
self.save_model(self.model_save_path)
return self
def predict(self, data: pl.DataFrame) -> pl.DataFrame:
"""对新数据进行预测
注意:新数据需要先经过股票池筛选,
然后使用训练好的 processors 和 model 进行预测。
Args:
data: 输入数据
Returns:
包含预测列的数据
"""
# 应用 processors
for processor in self.fitted_processors:
data = processor.transform(data)
# 预测
X = data.select(self.feature_cols)
predictions = self.model.predict(X)
return data.with_columns([pl.Series("prediction", predictions)])
def get_results(self) -> Optional[pl.DataFrame]:
"""获取所有预测结果
Returns:
预测结果 DataFrame包含原始列和 prediction 列
"""
return self.results
def save_results(self, path: str) -> None:
"""保存预测结果到文件
Args:
path: 保存路径CSV 格式)
"""
if self.results is not None:
self.results.write_csv(path)
def save_model(self, path: str) -> None:
"""保存模型
Args:
path: 模型保存路径
"""
self.model.save(path)

View File

@@ -193,8 +193,17 @@ class TestLightGBMModel:
"""测试模型已注册到 registry""" """测试模型已注册到 registry"""
from src.training.registry import ModelRegistry from src.training.registry import ModelRegistry
# 重新导入模型模块以确保注册(处理其他测试 clear 注册表的情况)
import importlib
import src.training.components.models.lightgbm as lightgbm_module
importlib.reload(lightgbm_module)
from src.training.components.models.lightgbm import (
LightGBMModel as ReloadedModel,
)
model_class = ModelRegistry.get_model("lightgbm") model_class = ModelRegistry.get_model("lightgbm")
assert model_class is LightGBMModel assert model_class is ReloadedModel
def test_fit_predict_consistency(self): def test_fit_predict_consistency(self):
"""测试多次预测结果一致""" """测试多次预测结果一致"""