feat(training): 添加训练模块基础架构
实现 Commit 1:训练模块基础架构 新增文件: - src/training/__init__.py - 主模块导出 - src/training/components/__init__.py - components 子模块导出 - src/training/components/base.py - BaseModel/BaseProcessor 抽象基类 - src/training/registry.py - 模型和处理器注册中心 - tests/training/test_base.py - 基础架构单元测试 功能特性: - BaseModel: 提供 fit, predict, feature_importance, save/load 接口 - BaseProcessor: 提供 fit, transform, fit_transform 接口 - ModelRegistry/ProcessorRegistry: 支持装饰器风格组件注册 - 支持即插即用的组件扩展机制 Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-opencode) Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
This commit is contained in:
184
src/training/registry.py
Normal file
184
src/training/registry.py
Normal file
@@ -0,0 +1,184 @@
|
||||
"""组件注册中心
|
||||
|
||||
提供装饰器风格的组件注册机制,支持即插即用。
|
||||
"""
|
||||
|
||||
from typing import Dict, Type, Callable, Any
|
||||
from src.training.components.base import BaseModel, BaseProcessor
|
||||
|
||||
|
||||
class ModelRegistry:
    """Registry of model classes, keyed by name.

    Keeps a class-level mapping from a string name to a ``BaseModel``
    subclass so that a model class can be looked up by its name.

    Example:
        >>> @register_model("lightgbm")
        ... class LightGBMModel(BaseModel):
        ...     pass
        >>>
        >>> model_class = ModelRegistry.get_model("lightgbm")
        >>> model = model_class(**params)
    """

    # Class-level mapping: registered name -> model class.
    _registry: Dict[str, Type[BaseModel]] = {}

    @classmethod
    def register(cls, name: str, model_class: Type[BaseModel]) -> None:
        """Register a model class under ``name``.

        Args:
            name: Name to register the model under.
            model_class: Model class; must be a subclass of ``BaseModel``.

        Raises:
            ValueError: If ``name`` is already registered, or if
                ``model_class`` is not a class derived from ``BaseModel``.
        """
        if name in cls._registry:
            raise ValueError(f"模型 '{name}' 已被注册")
        # issubclass() raises TypeError when its first argument is not a
        # class; test isinstance(..., type) first so that non-class input
        # produces the documented ValueError instead.
        is_model_class = isinstance(model_class, type) and issubclass(
            model_class, BaseModel
        )
        if not is_model_class:
            raise ValueError("模型类必须继承 BaseModel")
        cls._registry[name] = model_class

    @classmethod
    def get_model(cls, name: str) -> Type[BaseModel]:
        """Return the model class registered under ``name``.

        Args:
            name: Name of the model to look up.

        Returns:
            The registered model class.

        Raises:
            KeyError: If no model is registered under ``name``; the message
                lists the names that are currently registered.
        """
        if name not in cls._registry:
            available = ", ".join(cls._registry)
            raise KeyError(f"未知模型 '{name}',可用模型: {available}")
        return cls._registry[name]

    @classmethod
    def list_models(cls) -> list[str]:
        """Return the names of all registered models."""
        return list(cls._registry)

    @classmethod
    def clear(cls) -> None:
        """Remove every entry from the registry (mainly useful in tests)."""
        cls._registry.clear()
|
||||
|
||||
|
||||
class ProcessorRegistry:
    """Registry of data-processor classes, keyed by name.

    Keeps a class-level mapping from a string name to a ``BaseProcessor``
    subclass so that a processor class can be looked up by its name.

    Example:
        >>> @register_processor("standard_scaler")
        ... class StandardScaler(BaseProcessor):
        ...     pass
        >>>
        >>> processor_class = ProcessorRegistry.get_processor("standard_scaler")
        >>> processor = processor_class(**params)
    """

    # Class-level mapping: registered name -> processor class.
    _registry: Dict[str, Type[BaseProcessor]] = {}

    @classmethod
    def register(cls, name: str, processor_class: Type[BaseProcessor]) -> None:
        """Register a processor class under ``name``.

        Args:
            name: Name to register the processor under.
            processor_class: Processor class; must be a subclass of
                ``BaseProcessor``.

        Raises:
            ValueError: If ``name`` is already registered, or if
                ``processor_class`` is not a class derived from
                ``BaseProcessor``.
        """
        if name in cls._registry:
            raise ValueError(f"处理器 '{name}' 已被注册")
        # issubclass() raises TypeError when its first argument is not a
        # class; test isinstance(..., type) first so that non-class input
        # produces the documented ValueError instead.
        is_processor_class = isinstance(processor_class, type) and issubclass(
            processor_class, BaseProcessor
        )
        if not is_processor_class:
            raise ValueError("处理器类必须继承 BaseProcessor")
        cls._registry[name] = processor_class

    @classmethod
    def get_processor(cls, name: str) -> Type[BaseProcessor]:
        """Return the processor class registered under ``name``.

        Args:
            name: Name of the processor to look up.

        Returns:
            The registered processor class.

        Raises:
            KeyError: If no processor is registered under ``name``; the
                message lists the names that are currently registered.
        """
        if name not in cls._registry:
            available = ", ".join(cls._registry)
            raise KeyError(f"未知处理器 '{name}',可用处理器: {available}")
        return cls._registry[name]

    @classmethod
    def list_processors(cls) -> list[str]:
        """Return the names of all registered processors."""
        return list(cls._registry)

    @classmethod
    def clear(cls) -> None:
        """Remove every entry from the registry (mainly useful in tests)."""
        cls._registry.clear()
|
||||
|
||||
|
||||
def register_model(name: str) -> Callable[[Type[BaseModel]], Type[BaseModel]]:
    """Build a class decorator that registers a model in ``ModelRegistry``.

    Intended for classes that inherit from ``BaseModel``. The decorated
    class is passed to ``ModelRegistry.register`` and returned unchanged.

    Args:
        name: Name to register the model under.

    Returns:
        A decorator that registers the class it is applied to and returns
        that same class.

    Example:
        >>> @register_model("lightgbm")
        ... class LightGBMModel(BaseModel):
        ...     name = "lightgbm"
        ...     def fit(self, X, y): ...
        ...     def predict(self, X): ...
    """

    def _register(model_cls: Type[BaseModel]) -> Type[BaseModel]:
        # Registration happens when the decorator is applied, i.e. at
        # class-definition time of the decorated class.
        ModelRegistry.register(name, model_cls)
        return model_cls

    return _register
|
||||
|
||||
|
||||
def register_processor(
    name: str,
) -> Callable[[Type[BaseProcessor]], Type[BaseProcessor]]:
    """Build a class decorator that registers a processor in ``ProcessorRegistry``.

    Intended for classes that inherit from ``BaseProcessor``. The decorated
    class is passed to ``ProcessorRegistry.register`` and returned unchanged.

    Args:
        name: Name to register the processor under.

    Returns:
        A decorator that registers the class it is applied to and returns
        that same class.

    Example:
        >>> @register_processor("standard_scaler")
        ... class StandardScaler(BaseProcessor):
        ...     name = "standard_scaler"
        ...     def transform(self, X): ...
    """

    def _register(processor_cls: Type[BaseProcessor]) -> Type[BaseProcessor]:
        # Registration happens when the decorator is applied, i.e. at
        # class-definition time of the decorated class.
        ProcessorRegistry.register(name, processor_cls)
        return processor_cls

    return _register
|
||||
Reference in New Issue
Block a user