feat(training): 添加训练模块基础架构

实现 Commit 1:训练模块基础架构

新增文件:

- src/training/__init__.py - 主模块导出

- src/training/components/__init__.py - components 子模块导出

- src/training/components/base.py - BaseModel/BaseProcessor 抽象基类

- src/training/registry.py - 模型和处理器注册中心

- tests/training/test_base.py - 基础架构单元测试

功能特性:

- BaseModel: 提供 fit, predict, feature_importance, save/load 接口

- BaseProcessor: 提供 fit, transform, fit_transform 接口

- ModelRegistry/ProcessorRegistry: 支持装饰器风格组件注册

- 支持即插即用的组件扩展机制

Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-opencode)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
This commit is contained in:
2026-03-03 21:55:39 +08:00
parent 12ddb19b2e
commit 472b2b665a
18 changed files with 694 additions and 3997 deletions

184
src/training/registry.py Normal file
View File

@@ -0,0 +1,184 @@
"""组件注册中心
提供装饰器风格的组件注册机制,支持即插即用。
"""
from typing import Dict, Type, Callable, Any
from src.training.components.base import BaseModel, BaseProcessor
class ModelRegistry:
    """Registry of model classes, keyed by name.

    Keeps a class-level mapping from a string name to a ``BaseModel``
    subclass so that a model class can be retrieved by its name.

    Example:
        >>> @register_model("lightgbm")
        ... class LightGBMModel(BaseModel):
        ...     pass
        >>>
        >>> model_class = ModelRegistry.get_model("lightgbm")
        >>> model = model_class(**params)
    """

    # Class-level mapping: all classmethods below read and write this dict.
    _registry: Dict[str, Type[BaseModel]] = {}

    @classmethod
    def register(cls, name: str, model_class: Type[BaseModel]) -> None:
        """Register a model class under ``name``.

        Args:
            name: Name to register the model under.
            model_class: Model class (must be a subclass of ``BaseModel``).

        Raises:
            ValueError: If ``name`` is already registered, or if
                ``model_class`` is not a class derived from ``BaseModel``.
        """
        if name in cls._registry:
            raise ValueError(f"模型 '{name}' 已被注册")
        # issubclass() raises TypeError when given a non-class (an instance,
        # a function, ...). Checking isinstance(..., type) first makes every
        # invalid argument surface as the documented ValueError.
        is_model_class = isinstance(model_class, type) and issubclass(
            model_class, BaseModel
        )
        if not is_model_class:
            raise ValueError("模型类必须继承 BaseModel")
        cls._registry[name] = model_class

    @classmethod
    def get_model(cls, name: str) -> Type[BaseModel]:
        """Return the model class registered under ``name``.

        Args:
            name: Name of the model.

        Returns:
            The registered model class.

        Raises:
            KeyError: If no model is registered under ``name``; the message
                lists the names that are currently registered.
        """
        if name not in cls._registry:
            available = ", ".join(cls._registry)
            raise KeyError(f"未知模型 '{name}',可用模型: {available}")
        return cls._registry[name]

    @classmethod
    def list_models(cls) -> list[str]:
        """Return the names of all registered models as a new list."""
        return list(cls._registry)

    @classmethod
    def clear(cls) -> None:
        """Remove every entry from the registry (mainly for tests)."""
        cls._registry.clear()
class ProcessorRegistry:
    """Registry of data-processor classes, keyed by name.

    Keeps a class-level mapping from a string name to a ``BaseProcessor``
    subclass so that a processor class can be retrieved by its name.

    Example:
        >>> @register_processor("standard_scaler")
        ... class StandardScaler(BaseProcessor):
        ...     pass
        >>>
        >>> processor_class = ProcessorRegistry.get_processor("standard_scaler")
        >>> processor = processor_class(**params)
    """

    # Class-level mapping: all classmethods below read and write this dict.
    _registry: Dict[str, Type[BaseProcessor]] = {}

    @classmethod
    def register(cls, name: str, processor_class: Type[BaseProcessor]) -> None:
        """Register a processor class under ``name``.

        Args:
            name: Name to register the processor under.
            processor_class: Processor class (must be a subclass of
                ``BaseProcessor``).

        Raises:
            ValueError: If ``name`` is already registered, or if
                ``processor_class`` is not a class derived from
                ``BaseProcessor``.
        """
        if name in cls._registry:
            raise ValueError(f"处理器 '{name}' 已被注册")
        # issubclass() raises TypeError when given a non-class (an instance,
        # a function, ...). Checking isinstance(..., type) first makes every
        # invalid argument surface as the documented ValueError.
        is_processor_class = isinstance(processor_class, type) and issubclass(
            processor_class, BaseProcessor
        )
        if not is_processor_class:
            raise ValueError("处理器类必须继承 BaseProcessor")
        cls._registry[name] = processor_class

    @classmethod
    def get_processor(cls, name: str) -> Type[BaseProcessor]:
        """Return the processor class registered under ``name``.

        Args:
            name: Name of the processor.

        Returns:
            The registered processor class.

        Raises:
            KeyError: If no processor is registered under ``name``; the
                message lists the names that are currently registered.
        """
        if name not in cls._registry:
            available = ", ".join(cls._registry)
            raise KeyError(f"未知处理器 '{name}',可用处理器: {available}")
        return cls._registry[name]

    @classmethod
    def list_processors(cls) -> list[str]:
        """Return the names of all registered processors as a new list."""
        return list(cls._registry)

    @classmethod
    def clear(cls) -> None:
        """Remove every entry from the registry (mainly for tests)."""
        cls._registry.clear()
def register_model(name: str) -> Callable[[Type[BaseModel]], Type[BaseModel]]:
    """Build a class decorator that registers a model in ``ModelRegistry``.

    The returned decorator passes the decorated class to
    ``ModelRegistry.register`` under ``name`` and hands the class back
    unchanged, so the decorated name still refers to the original class.

    Args:
        name: Name to register the model under.

    Returns:
        A decorator that accepts a class and returns that same class.

    Example:
        >>> @register_model("lightgbm")
        ... class LightGBMModel(BaseModel):
        ...     name = "lightgbm"
        ...     def fit(self, X, y): ...
        ...     def predict(self, X): ...
    """

    def _register(model_class: Type[BaseModel]) -> Type[BaseModel]:
        # Registration runs when the decorator is applied; any error raised
        # by ModelRegistry.register propagates to the caller.
        ModelRegistry.register(name, model_class)
        return model_class

    return _register
def register_processor(
    name: str,
) -> Callable[[Type[BaseProcessor]], Type[BaseProcessor]]:
    """Build a class decorator that registers a processor in ``ProcessorRegistry``.

    The returned decorator passes the decorated class to
    ``ProcessorRegistry.register`` under ``name`` and hands the class back
    unchanged, so the decorated name still refers to the original class.

    Args:
        name: Name to register the processor under.

    Returns:
        A decorator that accepts a class and returns that same class.

    Example:
        >>> @register_processor("standard_scaler")
        ... class StandardScaler(BaseProcessor):
        ...     name = "standard_scaler"
        ...     def transform(self, X): ...
    """

    def _register(processor_class: Type[BaseProcessor]) -> Type[BaseProcessor]:
        # Registration runs when the decorator is applied; any error raised
        # by ProcessorRegistry.register propagates to the caller.
        ProcessorRegistry.register(name, processor_class)
        return processor_class

    return _register