feat(training): 添加训练模块基础架构

实现 Commit 1:训练模块基础架构

新增文件:

- src/training/__init__.py - 主模块导出

- src/training/components/__init__.py - components 子模块导出

- src/training/components/base.py - BaseModel/BaseProcessor 抽象基类

- src/training/registry.py - 模型和处理器注册中心

- tests/training/test_base.py - 基础架构单元测试

功能特性:

- BaseModel: 提供 fit, predict, feature_importance, save/load 接口

- BaseProcessor: 提供 fit, transform, fit_transform 接口

- ModelRegistry/ProcessorRegistry: 支持装饰器风格组件注册

- 支持即插即用的组件扩展机制

Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-opencode)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
This commit is contained in:
2026-03-03 21:55:39 +08:00
parent 12ddb19b2e
commit 472b2b665a
18 changed files with 694 additions and 3997 deletions

View File

@@ -1,46 +1,26 @@
"""ProStock 训练流程模块
"""训练模块 - ProStock 量化投资框架
本模块提供完整的模型训练流程
1. 数据处理Fillna(0) -> Dropna
2. 模型训练LightGBM分类模型
3. 预测选股每日top5股票池
使用示例:
from src.training import run_training
# 运行完整训练流程
result = run_training(
train_start="20180101",
train_end="20230101",
test_start="20230101",
test_end="20240101",
top_n=5,
output_path="output/top_stocks.tsv"
)
因子使用:
from src.factors import MovingAverageFactor, ReturnRankFactor
ma5 = MovingAverageFactor(period=5) # 5日移动平均
ma10 = MovingAverageFactor(period=10) # 10日移动平均
ret5 = ReturnRankFactor(period=5) # 5日收益率排名
提供模型训练、数据处理和评估的完整流程
"""
from src.training.pipeline import (
create_pipeline,
predict_top_stocks,
prepare_data,
run_training,
save_top_stocks,
train_model,
# 基础抽象类
from src.training.components.base import BaseModel, BaseProcessor
# 注册中心
from src.training.registry import (
ModelRegistry,
ProcessorRegistry,
register_model,
register_processor,
)
__all__ = [
# 管道函数
"prepare_data",
"create_pipeline",
"train_model",
"predict_top_stocks",
"save_top_stocks",
"run_training",
# 基础抽象类
"BaseModel",
"BaseProcessor",
# 注册中心
"ModelRegistry",
"ProcessorRegistry",
"register_model",
"register_processor",
]

View File

@@ -0,0 +1,12 @@
"""训练组件子模块
包含模型、处理器、划分器、选择器等组件。
"""
# 基础抽象类
from src.training.components.base import BaseModel, BaseProcessor
__all__ = [
"BaseModel",
"BaseProcessor",
]

View File

@@ -0,0 +1,141 @@
"""基础抽象类定义
定义 BaseModel 和 BaseProcessor 抽象基类,
为所有训练组件提供统一的接口。
"""
from abc import ABC, abstractmethod
from typing import Optional
import pickle
import polars as pl
import numpy as np
import pandas as pd
class BaseModel(ABC):
"""模型基类
所有机器学习模型必须继承此类并实现抽象方法。
提供统一的训练、预测、特征重要性和持久化接口。
Attributes:
name: 模型名称,子类必须定义
"""
name: str = "" # 模型名称
@abstractmethod
def fit(self, X: pl.DataFrame, y: pl.Series) -> "BaseModel":
"""训练模型
Args:
X: 特征矩阵 (Polars DataFrame)
y: 目标变量 (Polars Series)
Returns:
self (支持链式调用)
"""
raise NotImplementedError
@abstractmethod
def predict(self, X: pl.DataFrame) -> np.ndarray:
"""预测
Args:
X: 特征矩阵 (Polars DataFrame)
Returns:
预测结果 (numpy ndarray)
"""
raise NotImplementedError
def feature_importance(self) -> Optional[pd.Series]:
"""特征重要性
Returns:
特征重要性序列,如果不支持则返回 None
"""
return None
def save(self, path: str) -> None:
"""保存模型到文件
默认实现使用 pickle 序列化,子类可覆盖以使用更高效的格式。
Args:
path: 保存路径
Raises:
RuntimeError: 模型未训练时调用
"""
with open(path, "wb") as f:
pickle.dump(self, f)
@classmethod
def load(cls, path: str) -> "BaseModel":
"""从文件加载模型
Args:
path: 模型文件路径
Returns:
加载的模型实例
"""
with open(path, "rb") as f:
return pickle.load(f)
class BaseProcessor(ABC):
"""数据处理器基类
重要Processor 在不同阶段行为不同:
- 训练阶段fit_transform学习参数并应用
- 验证/测试阶段transform使用训练阶段学到的参数
这意味着 Processor 实例会在训练后被保存,
用于后续的验证和测试数据转换。
Attributes:
name: 处理器名称,子类必须定义
"""
name: str = ""
def fit(self, X: pl.DataFrame) -> "BaseProcessor":
"""学习参数(仅在训练阶段调用)
子类应覆盖此方法以学习统计参数(如均值、标准差等)。
Args:
X: 训练数据 (Polars DataFrame)
Returns:
self (支持链式调用)
"""
return self
@abstractmethod
def transform(self, X: pl.DataFrame) -> pl.DataFrame:
"""转换数据
Args:
X: 输入数据 (Polars DataFrame)
Returns:
转换后的数据 (Polars DataFrame)
"""
raise NotImplementedError
def fit_transform(self, X: pl.DataFrame) -> pl.DataFrame:
"""拟合并转换(训练阶段使用)
先调用 fit 学习参数,然后调用 transform 应用转换。
Args:
X: 训练数据 (Polars DataFrame)
Returns:
转换后的数据 (Polars DataFrame)
"""
return self.fit(X).transform(X)

File diff suppressed because it is too large Load Diff

View File

@@ -1,667 +0,0 @@
"""训练管道 - 包含数据处理、模型训练和预测功能
本模块提供:
1. 数据准备:使用 FactorEngine 从因子计算结果中准备训练/测试数据
2. 数据处理Fillna(0) -> Dropna
3. 模型训练使用LightGBM训练分类模型
4. 预测和选股输出每日top5股票池
注意:本模块使用 src.factors 框架进行因子计算,确保防泄露机制生效。
因子配置示例:
from src.factors import MovingAverageFactor, ReturnRankFactor
from src.training.pipeline import prepare_data
# 直接传入因子实例列表 - 简单直观
factors = [
MovingAverageFactor(period=5),
MovingAverageFactor(period=10),
ReturnRankFactor(period=5),
]
train_data, val_data, test_data, factor_config = prepare_data(
factors=factors,
...
)
# 或者使用 FactorConfig 包装(支持链式添加)
from src.training.pipeline import FactorConfig
config = FactorConfig()
.add(MovingAverageFactor(period=5))
.add(MovingAverageFactor(period=10))
.add(ReturnRankFactor(period=5))
"""
from datetime import datetime
from pathlib import Path
from typing import List, Optional
import numpy as np
import polars as pl
from src.factors import DataLoader, FactorEngine, BaseFactor
from src.factors.data_spec import DataSpec
from src.factors.momentum import MovingAverageFactor, ReturnRankFactor
from src.pipeline import (
DropNAProcessor,
FillNAProcessor,
LightGBMModel,
PipelineStage,
ProcessingPipeline,
TaskType,
)
# ========== 因子配置类 ==========
class FactorConfig:
"""因子配置类 - 管理因子列表
用于包装因子实例列表,支持链式添加。
示例:
# 方式1初始化时传入列表
config = FactorConfig([
MovingAverageFactor(period=5),
ReturnRankFactor(period=5),
])
# 方式2链式添加
config = FactorConfig()
.add(MovingAverageFactor(period=5))
.add(ReturnRankFactor(period=5))
# 获取因子实例列表
factors = config.get_factors()
# 获取特征列名
feature_cols = config.get_feature_names()
"""
def __init__(self, factors: Optional[List[BaseFactor]] = None):
"""初始化因子配置
Args:
factors: 因子实例列表
"""
self._factors: List[BaseFactor] = factors or []
def add(self, factor: BaseFactor) -> "FactorConfig":
"""添加因子到配置
支持链式调用:
config = FactorConfig()
.add(MovingAverageFactor(period=5))
.add(ReturnRankFactor(period=5))
Args:
factor: 因子实例
Returns:
self支持链式调用
"""
if not isinstance(factor, BaseFactor):
raise ValueError(f"必须是 BaseFactor 实例, got {type(factor)}")
self._factors.append(factor)
return self
def get_factors(self) -> List[BaseFactor]:
"""获取因子实例列表
Returns:
因子实例列表
"""
return self._factors
def get_feature_names(self) -> List[str]:
"""获取所有因子的特征列名
Returns:
特征列名列表
"""
return [f.name for f in self._factors]
def get_max_lookback(self) -> int:
"""获取所有因子中最大的 lookback 天数
Returns:
最大 lookback 天数
"""
max_lookback = 0
for factor in self._factors:
for spec in factor.data_specs:
max_lookback = max(max_lookback, spec.lookback_days)
return max_lookback
def __len__(self) -> int:
return len(self._factors)
def __repr__(self) -> str:
names = [f.name for f in self._factors]
return f"FactorConfig({names})"
def prepare_data(
factors: Optional[List[BaseFactor]] = None,
data_dir: str = "data",
train_start: str = "20180101",
train_end: str = "20230101",
val_start: str = "20230101",
val_end: str = "20230601",
test_start: str = "20230601",
test_end: str = "20240101",
) -> tuple[pl.DataFrame, pl.DataFrame, pl.DataFrame, FactorConfig]:
"""准备训练、验证和测试数据
使用 FactorEngine 计算因子,确保防泄露机制生效。
Args:
factors: 因子实例列表,默认为 None使用 MA5, MA10, ReturnRank5
data_dir: 数据目录
train_start: 训练集开始日期
train_end: 训练集结束日期
val_start: 验证集开始日期
val_end: 验证集结束日期
test_start: 测试集开始日期
test_end: 测试集结束日期
Returns:
(train_data, val_data, test_data, factor_config):
训练集、验证集、测试集的DataFrame以及使用的因子配置
"""
from src.data.storage import Storage
storage = Storage()
# 1. 处理因子配置
if factors is None:
# 默认因子配置
factor_config = FactorConfig(
[
MovingAverageFactor(period=5),
MovingAverageFactor(period=10),
ReturnRankFactor(period=5),
]
)
elif isinstance(factors, FactorConfig):
factor_config = factors
elif isinstance(factors, list):
# 转换为 FactorConfig
factor_config = FactorConfig(factors)
else:
raise ValueError(
f"factors 必须是 List[BaseFactor] 或 FactorConfig got {type(factors)}"
)
factors_list = factor_config.get_factors()
feature_cols = factor_config.get_feature_names()
if not factors_list:
raise ValueError("至少需要提供一个因子")
print(f"[PrepareData] 使用因子: {feature_cols}")
# 2. 初始化 FactorEngine
loader = DataLoader(data_dir=data_dir)
engine = FactorEngine(loader)
# 3. 计算需要的回溯天数
max_lookback = factor_config.get_max_lookback()
start_with_lookback = str(int(train_start) - 10000) # 往前取一年
# 获取所有股票列表
start_dt = f"{start_with_lookback[:4]}-{start_with_lookback[4:6]}-{start_with_lookback[6:8]}"
end_dt = f"{test_end[:4]}-{test_end[4:6]}-{test_end[6:8]}"
all_stocks_query = f"""
SELECT DISTINCT ts_code FROM daily
WHERE trade_date >= '{start_dt}' AND trade_date <= '{end_dt}'
"""
all_stocks_df = storage._connection.sql(all_stocks_query).pl()
all_stocks = all_stocks_df["ts_code"].to_list()
print(f"[PrepareData] 股票数量: {len(all_stocks)}")
# 4. 计算所有因子并合并
all_features = None
for factor in factors_list:
print(f"[PrepareData] 计算因子: {factor.name} ({factor.factor_type})")
if factor.factor_type == "time_series":
# 时序因子需要传入股票列表
result = engine.compute(
factor,
stock_codes=all_stocks,
start_date=start_with_lookback,
end_date=test_end,
)
else:
# 截面因子不需要股票列表
result = engine.compute(
factor,
start_date=start_with_lookback,
end_date=test_end,
)
# 合并结果
if all_features is None:
all_features = result
else:
# 确保没有重复的 _right 列
result = result.select([
c for c in result.columns if not c.endswith("_right")
])
all_features = all_features.select([
c for c in all_features.columns if not c.endswith("_right")
])
all_features = all_features.join(
result, on=["trade_date", "ts_code"], how="outer"
)
if all_features is None:
raise ValueError("没有计算任何因子")
# 5. 计算标签未来5日收益率
all_data = _compute_label(all_features, start_date=train_start, end_date=test_end)
# 6. 过滤不符合条件的股票
all_data = _filter_invalid_stocks(all_data)
print(f"[PrepareData] After filtering: total={len(all_data)}")
# 转换日期格式用于比较
train_start_fmt = f"{train_start[:4]}-{train_start[4:6]}-{train_start[6:8]}"
train_end_fmt = f"{train_end[:4]}-{train_end[4:6]}-{train_end[6:8]}"
val_start_fmt = f"{val_start[:4]}-{val_start[4:6]}-{val_start[6:8]}"
val_end_fmt = f"{val_end[:4]}-{val_end[4:6]}-{val_end[6:8]}"
test_start_fmt = f"{test_start[:4]}-{test_start[4:6]}-{test_start[6:8]}"
test_end_fmt = f"{test_end[:4]}-{test_end[4:6]}-{test_end[6:8]}"
# 拆分数据
train_data = all_data.filter(
(pl.col("trade_date") >= train_start_fmt)
& (pl.col("trade_date") <= train_end_fmt)
)
val_data = all_data.filter(
(pl.col("trade_date") >= val_start_fmt) & (pl.col("trade_date") <= val_end_fmt)
)
test_data = all_data.filter(
(pl.col("trade_date") >= test_start_fmt)
& (pl.col("trade_date") <= test_end_fmt)
)
print(
f"[PrepareData] Split result: train={len(train_data)}, val={len(val_data)}, test={len(test_data)}"
)
return train_data, val_data, test_data, factor_config
def _filter_invalid_stocks(df: pl.DataFrame) -> pl.DataFrame:
"""过滤不符合条件的股票
过滤规则:
1. 过滤北交所股票ts_code 以 BJ 结尾)
2. 过滤创业板股票ts_code 以 30 开头)
3. 过滤科创板股票ts_code 以 68 开头)
4. 过滤退市/风险股票ts_code 以 8 开头)
Args:
df: 原始数据
Returns:
过滤后的数据
"""
ts_code_col = pl.col("ts_code")
return df.filter(
~ts_code_col.str.ends_with("BJ")
& ~ts_code_col.str.starts_with("30")
& ~ts_code_col.str.starts_with("68")
& ~ts_code_col.str.starts_with("8")
)
def _compute_label(
features_df: pl.DataFrame,
start_date: str,
end_date: str,
) -> pl.DataFrame:
"""计算标签未来5日收益率
标签定义未来5日收益率大于0为1否则为0
Args:
features_df: 包含因子的DataFrame
start_date: 开始日期
end_date: 结束日期
Returns:
包含因子和标签的DataFrame
"""
from src.data.storage import Storage
storage = Storage()
# 从数据库获取收盘价数据用于计算标签
start_dt = f"{start_date[:4]}-{start_date[4:6]}-{start_date[6:8]}"
end_dt = f"{end_date[:4]}-{end_date[4:6]}-{end_date[6:8]}"
# 需要多取5天数据来计算未来收益率
end_dt_extended = f"{end_date[:4]}-{end_date[4:6]}-{int(end_date[6:8]) + 5}"
price_query = f"""
SELECT ts_code, trade_date, close
FROM daily
WHERE trade_date >= '{start_dt}' AND trade_date <= '{end_dt_extended}'
ORDER BY ts_code, trade_date
"""
price_data = storage._connection.sql(price_query).pl()
price_data = price_data.with_columns(
pl.col("trade_date").dt.strftime("%Y-%m-%d").alias("trade_date")
)
# 按股票计算未来5日收益率
result_list = []
for ts_code in price_data["ts_code"].unique():
stock_data = price_data.filter(pl.col("ts_code") == ts_code).sort("trade_date")
if len(stock_data) < 6:
continue
# 计算未来5日收益率
future_return = stock_data["close"].shift(-5) - stock_data["close"]
future_return_pct = future_return / stock_data["close"]
stock_data = stock_data.with_columns(
[
future_return_pct.alias("future_return_5"),
]
)
# 生成标签:收益率>0为1否则为0
stock_data = stock_data.with_columns(
[
(pl.col("future_return_5") > 0).cast(pl.Int8).alias("label"),
]
)
result_list.append(stock_data.select(["trade_date", "ts_code", "label"]))
if not result_list:
return pl.DataFrame()
label_df = pl.concat(result_list)
# 将标签合并到因子数据
result = features_df.join(label_df, on=["trade_date", "ts_code"], how="inner")
# 过滤有效日期范围
result = result.filter(
(pl.col("trade_date") >= start_dt) & (pl.col("trade_date") <= end_dt)
)
return result
def create_pipeline() -> ProcessingPipeline:
"""创建数据处理流水线
处理流程:
1. FillNA(0): 将缺失值填充为0
注意:不使用 Dropna因为会导致训练和预测时的行数不匹配
Returns:
配置好的ProcessingPipeline
"""
processors = [
FillNAProcessor(method="zero"), # 缺失值填充为0
]
return ProcessingPipeline(processors)
def train_model(
train_data: pl.DataFrame,
val_data: Optional[pl.DataFrame],
feature_cols: List[str],
label_col: str = "label",
model_params: Optional[dict] = None,
) -> tuple[LightGBMModel, ProcessingPipeline]:
"""训练LightGBM分类模型
Args:
train_data: 训练数据
val_data: 验证数据(用于早停)
feature_cols: 特征列名列表
label_col: 标签列名
model_params: 模型参数字典
Returns:
(训练好的模型, 处理流水线)
"""
# 创建处理流水线
pipeline = create_pipeline()
print("[TrainModel] Pipeline created: FillNA(0)")
# 准备训练特征和标签
X_train = train_data.select(feature_cols)
y_train = train_data[label_col]
print(f"[TrainModel] Train samples: {len(X_train)}, features: {feature_cols}")
# 处理训练数据
X_train_processed = pipeline.fit_transform(X_train, stage=PipelineStage.TRAIN)
print(f"[TrainModel] After processing: {len(X_train_processed)} samples")
# 过滤训练集有效标签(排除-1等无效值
valid_mask = y_train.is_in([0, 1])
X_train_processed = X_train_processed.filter(valid_mask)
y_train = y_train.filter(valid_mask)
print(
f"[TrainModel] After filtering valid labels: {len(X_train_processed)} samples"
)
print(
f"[TrainModel] Train label distribution: {dict(y_train.value_counts().sort('label').iter_rows())}"
)
# 准备验证集
X_val_processed = None
y_val = None
if val_data is not None and len(val_data) > 0:
X_val = val_data.select(feature_cols)
y_val = val_data[label_col]
print(f"[TrainModel] Val samples: {len(X_val)}")
# 处理验证集数据(使用训练集的参数)
X_val_processed = pipeline.transform(X_val, stage=PipelineStage.TEST)
# 过滤验证集有效标签
val_valid_mask = y_val.is_in([0, 1])
X_val_processed = X_val_processed.filter(val_valid_mask)
y_val = y_val.filter(val_valid_mask)
print(f"[TrainModel] Val after filtering: {len(X_val_processed)} samples")
print(
f"[TrainModel] Val label distribution: {dict(y_val.value_counts().sort('label').iter_rows())}"
)
# 创建模型
params = model_params or {
"n_estimators": 100,
"learning_rate": 0.05,
"max_depth": 5,
"num_leaves": 31,
}
print(f"[TrainModel] Model params: {params}")
model = LightGBMModel(
task_type="classification",
params=params,
)
# 训练模型(使用验证集早停)
print("[TrainModel] Training LightGBM...")
if X_val_processed is not None and y_val is not None:
print("[TrainModel] Using validation set for early stopping")
model.fit(X_train_processed, y_train, X_val_processed, y_val)
else:
model.fit(X_train_processed, y_train)
print("[TrainModel] Training completed!")
return model, pipeline
def predict_top_stocks(
model: LightGBMModel,
pipeline: ProcessingPipeline,
test_data: pl.DataFrame,
feature_cols: List[str],
top_n: int = 5,
) -> pl.DataFrame:
"""预测并选出每日top N股票
Args:
model: 训练好的模型
pipeline: 数据处理流水线
test_data: 测试数据
feature_cols: 特征列名
top_n: 每日选出的股票数量
Returns:
包含日期和股票代码的DataFrame
"""
# 准备特征和必要列
X_test = test_data.select(feature_cols)
key_cols = ["trade_date", "ts_code"]
key_data = test_data.select(key_cols)
print(f"[Predict] Test samples: {len(X_test)}, top_n: {top_n}")
# 处理数据(使用训练阶段的参数)
X_test_processed = pipeline.transform(X_test, stage=PipelineStage.TEST)
print(f"[Predict] Data processed, shape: {X_test_processed.shape}")
# 预测概率
probs = model.predict_proba(X_test_processed)
print(f"[Predict] Predictions generated, probability shape: {probs.shape}")
# 使用 key_data 添加预测结果,保持行数一致
result = key_data.with_columns(
pl.Series(
name="pred_prob",
values=probs[:, 1]
if len(probs.shape) > 1 and probs.shape[1] > 1
else probs.flatten(),
),
)
# 每日选出top N
top_stocks = []
for date in result["trade_date"].unique().sort():
day_data = result.filter(pl.col("trade_date") == date)
# 按概率降序排序选出top N
day_top = day_data.sort("pred_prob", descending=True).head(top_n)
top_stocks.append(
day_top.select(["trade_date", "pred_prob", "ts_code"]).rename(
{"pred_prob": "score"}
)
)
return pl.concat(top_stocks)
def save_top_stocks(top_stocks: pl.DataFrame, output_path: str) -> None:
"""保存选股结果到TSV文件
Args:
top_stocks: 选股结果
output_path: 输出文件路径
"""
# 转换为pandas并保存为TSV
df = top_stocks.to_pandas()
df.to_csv(output_path, sep="\t", index=False)
print(f"[Training] Top stocks saved to: {output_path}")
def run_training(
factors: Optional[List[BaseFactor]] = None,
data_dir: str = "data",
output_path: str = "output/top_stocks.tsv",
train_start: str = "20180101",
train_end: str = "20230101",
val_start: str = "20230101",
val_end: str = "20230601",
test_start: str = "20230601",
test_end: str = "20240101",
top_n: int = 5,
) -> pl.DataFrame:
"""运行完整训练流程
Args:
factors: 因子实例列表,默认为 None使用 MA5, MA10, ReturnRank5
data_dir: 数据目录
output_path: 输出文件路径
train_start: 训练集开始日期
train_end: 训练集结束日期
val_start: 验证集开始日期
val_end: 验证集结束日期
test_start: 测试集开始日期
test_end: 测试集结束日期
top_n: 每日选股数量
Returns:
选股结果DataFrame
"""
print(f"[Training] Starting training pipeline...")
print(f"[Training] Train period: {train_start} -> {train_end}")
print(f"[Training] Val period: {val_start} -> {val_end}")
print(f"[Training] Test period: {test_start} -> {test_end}")
# 1. 准备数据
print("[Training] Preparing data...")
train_data, val_data, test_data, factor_config = prepare_data(
factors=factors,
data_dir=data_dir,
train_start=train_start,
train_end=train_end,
val_start=val_start,
val_end=val_end,
test_start=test_start,
test_end=test_end,
)
print(f"[Training] Train samples: {len(train_data)}")
print(f"[Training] Val samples: {len(val_data)}")
print(f"[Training] Test samples: {len(test_data)}")
# 2. 获取特征列名
feature_cols = factor_config.get_feature_names()
label_col = "label"
print(f"[Training] Feature columns: {feature_cols}")
# 3. 训练模型
print("[Training] Training model...")
model, pipeline = train_model(
train_data=train_data,
val_data=val_data,
feature_cols=feature_cols,
label_col=label_col,
)
# 4. 测试集预测
print("[Training] Predicting on test set...")
top_stocks = predict_top_stocks(
model=model,
pipeline=pipeline,
test_data=test_data,
feature_cols=feature_cols,
top_n=top_n,
)
# 5. 保存结果
print(f"[Training] Saving results to {output_path}...")
Path(output_path).parent.mkdir(parents=True, exist_ok=True)
save_top_stocks(top_stocks, output_path)
print("[Training] Training completed!")
return top_stocks

184
src/training/registry.py Normal file
View File

@@ -0,0 +1,184 @@
"""组件注册中心
提供装饰器风格的组件注册机制,支持即插即用。
"""
from typing import Dict, Type, Callable, Any
from src.training.components.base import BaseModel, BaseProcessor
class ModelRegistry:
"""模型注册中心
管理所有可用的模型类,支持通过名称获取模型类。
Example:
>>> @register_model("lightgbm")
... class LightGBMModel(BaseModel):
... pass
>>>
>>> model_class = ModelRegistry.get_model("lightgbm")
>>> model = model_class(**params)
"""
_registry: Dict[str, Type[BaseModel]] = {}
@classmethod
def register(cls, name: str, model_class: Type[BaseModel]) -> None:
"""注册模型类
Args:
name: 模型名称
model_class: 模型类(必须继承 BaseModel
Raises:
ValueError: 名称已被注册或类不继承 BaseModel
"""
if name in cls._registry:
raise ValueError(f"模型 '{name}' 已被注册")
if not issubclass(model_class, BaseModel):
raise ValueError(f"模型类必须继承 BaseModel")
cls._registry[name] = model_class
@classmethod
def get_model(cls, name: str) -> Type[BaseModel]:
"""获取模型类
Args:
name: 模型名称
Returns:
模型类
Raises:
KeyError: 未找到该名称的模型
"""
if name not in cls._registry:
available = ", ".join(cls._registry.keys())
raise KeyError(f"未知模型 '{name}',可用模型: {available}")
return cls._registry[name]
@classmethod
def list_models(cls) -> list[str]:
"""列出所有已注册的模型名称"""
return list(cls._registry.keys())
@classmethod
def clear(cls) -> None:
"""清空注册表(主要用于测试)"""
cls._registry.clear()
class ProcessorRegistry:
"""处理器注册中心
管理所有可用的数据处理器类,支持通过名称获取处理器类。
Example:
>>> @register_processor("standard_scaler")
... class StandardScaler(BaseProcessor):
... pass
>>>
>>> processor_class = ProcessorRegistry.get_processor("standard_scaler")
>>> processor = processor_class(**params)
"""
_registry: Dict[str, Type[BaseProcessor]] = {}
@classmethod
def register(cls, name: str, processor_class: Type[BaseProcessor]) -> None:
"""注册处理器类
Args:
name: 处理器名称
processor_class: 处理器类(必须继承 BaseProcessor
Raises:
ValueError: 名称已被注册或类不继承 BaseProcessor
"""
if name in cls._registry:
raise ValueError(f"处理器 '{name}' 已被注册")
if not issubclass(processor_class, BaseProcessor):
raise ValueError(f"处理器类必须继承 BaseProcessor")
cls._registry[name] = processor_class
@classmethod
def get_processor(cls, name: str) -> Type[BaseProcessor]:
"""获取处理器类
Args:
name: 处理器名称
Returns:
处理器类
Raises:
KeyError: 未找到该名称的处理器
"""
if name not in cls._registry:
available = ", ".join(cls._registry.keys())
raise KeyError(f"未知处理器 '{name}',可用处理器: {available}")
return cls._registry[name]
@classmethod
def list_processors(cls) -> list[str]:
"""列出所有已注册的处理器名称"""
return list(cls._registry.keys())
@classmethod
def clear(cls) -> None:
"""清空注册表(主要用于测试)"""
cls._registry.clear()
def register_model(name: str) -> Callable[[Type[BaseModel]], Type[BaseModel]]:
"""模型注册装饰器
用于装饰继承 BaseModel 的类,将其注册到 ModelRegistry。
Args:
name: 模型名称
Returns:
装饰器函数
Example:
>>> @register_model("lightgbm")
... class LightGBMModel(BaseModel):
... name = "lightgbm"
... def fit(self, X, y): ...
... def predict(self, X): ...
"""
def decorator(cls: Type[BaseModel]) -> Type[BaseModel]:
ModelRegistry.register(name, cls)
return cls
return decorator
def register_processor(
name: str,
) -> Callable[[Type[BaseProcessor]], Type[BaseProcessor]]:
"""处理器注册装饰器
用于装饰继承 BaseProcessor 的类,将其注册到 ProcessorRegistry。
Args:
name: 处理器名称
Returns:
装饰器函数
Example:
>>> @register_processor("standard_scaler")
... class StandardScaler(BaseProcessor):
... name = "standard_scaler"
... def transform(self, X): ...
"""
def decorator(cls: Type[BaseProcessor]) -> Type[BaseProcessor]:
ProcessorRegistry.register(name, cls)
return cls
return decorator