实现 Commit 1:训练模块基础架构 新增文件: - src/training/__init__.py - 主模块导出 - src/training/components/__init__.py - components 子模块导出 - src/training/components/base.py - BaseModel/BaseProcessor 抽象基类 - src/training/registry.py - 模型和处理器注册中心 - tests/training/test_base.py - 基础架构单元测试 功能特性: - BaseModel: 提供 fit, predict, feature_importance, save/load 接口 - BaseProcessor: 提供 fit, transform, fit_transform 接口 - ModelRegistry/ProcessorRegistry: 支持装饰器风格组件注册 - 支持即插即用的组件扩展机制 Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-opencode) Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
185 lines
5.0 KiB
Python
185 lines
5.0 KiB
Python
"""组件注册中心

提供装饰器风格的组件注册机制,支持即插即用。
"""
|
||
|
||
from typing import Dict, Type, Callable, Any
|
||
from src.training.components.base import BaseModel, BaseProcessor
|
||
|
||
|
||
class ModelRegistry:
    """Registry of model classes, looked up by name.

    Registrations are stored in a class-level dict, so every caller that
    goes through ``ModelRegistry`` sees the same set of models.

    Example:
        >>> @register_model("lightgbm")
        ... class LightGBMModel(BaseModel):
        ...     pass
        >>>
        >>> model_class = ModelRegistry.get_model("lightgbm")
        >>> model = model_class(**params)
    """

    # name -> model class. Lives on the class (not an instance), so it is
    # shared process-wide; use ``clear()`` to reset it between tests.
    _registry: Dict[str, Type[BaseModel]] = {}

    @classmethod
    def register(cls, name: str, model_class: Type[BaseModel]) -> None:
        """Register ``model_class`` under ``name``.

        Args:
            name: Name used to look the model class up later.
            model_class: The model class; must be a subclass of ``BaseModel``.

        Raises:
            ValueError: ``name`` is already registered, or ``model_class`` is
                not a class derived from ``BaseModel``.
        """
        if name in cls._registry:
            raise ValueError(f"模型 '{name}' 已被注册")
        # issubclass() itself raises TypeError for non-class arguments
        # (instances, functions, ...). Check for a class first so callers
        # always get the documented ValueError.
        if not (isinstance(model_class, type) and issubclass(model_class, BaseModel)):
            raise ValueError("模型类必须继承 BaseModel")
        cls._registry[name] = model_class

    @classmethod
    def get_model(cls, name: str) -> Type[BaseModel]:
        """Return the model class registered under ``name``.

        Args:
            name: Registered model name.

        Returns:
            The model class.

        Raises:
            KeyError: No model is registered under ``name``. The message
                lists the names that are registered.
        """
        try:
            return cls._registry[name]
        except KeyError:
            # Show a placeholder instead of an empty list when nothing is
            # registered, so the message does not end in a dangling colon.
            available = ", ".join(cls._registry) or "(无)"
            # ``from None`` hides the internal dict-lookup KeyError; only the
            # descriptive error is shown to the caller.
            raise KeyError(f"未知模型 '{name}',可用模型: {available}") from None

    @classmethod
    def list_models(cls) -> list[str]:
        """Return the registered model names in registration order."""
        return list(cls._registry)

    @classmethod
    def clear(cls) -> None:
        """Remove every registration (mainly intended for tests)."""
        cls._registry.clear()
|
||
|
||
|
||
class ProcessorRegistry:
    """Registry of data-processor classes, looked up by name.

    Registrations are stored in a class-level dict, so every caller that
    goes through ``ProcessorRegistry`` sees the same set of processors.

    Example:
        >>> @register_processor("standard_scaler")
        ... class StandardScaler(BaseProcessor):
        ...     pass
        >>>
        >>> processor_class = ProcessorRegistry.get_processor("standard_scaler")
        >>> processor = processor_class(**params)
    """

    # name -> processor class. Lives on the class (not an instance), so it is
    # shared process-wide; use ``clear()`` to reset it between tests.
    _registry: Dict[str, Type[BaseProcessor]] = {}

    @classmethod
    def register(cls, name: str, processor_class: Type[BaseProcessor]) -> None:
        """Register ``processor_class`` under ``name``.

        Args:
            name: Name used to look the processor class up later.
            processor_class: The processor class; must be a subclass of
                ``BaseProcessor``.

        Raises:
            ValueError: ``name`` is already registered, or ``processor_class``
                is not a class derived from ``BaseProcessor``.
        """
        if name in cls._registry:
            raise ValueError(f"处理器 '{name}' 已被注册")
        # issubclass() itself raises TypeError for non-class arguments
        # (instances, functions, ...). Check for a class first so callers
        # always get the documented ValueError.
        if not (
            isinstance(processor_class, type)
            and issubclass(processor_class, BaseProcessor)
        ):
            raise ValueError("处理器类必须继承 BaseProcessor")
        cls._registry[name] = processor_class

    @classmethod
    def get_processor(cls, name: str) -> Type[BaseProcessor]:
        """Return the processor class registered under ``name``.

        Args:
            name: Registered processor name.

        Returns:
            The processor class.

        Raises:
            KeyError: No processor is registered under ``name``. The message
                lists the names that are registered.
        """
        try:
            return cls._registry[name]
        except KeyError:
            # Show a placeholder instead of an empty list when nothing is
            # registered, so the message does not end in a dangling colon.
            available = ", ".join(cls._registry) or "(无)"
            # ``from None`` hides the internal dict-lookup KeyError; only the
            # descriptive error is shown to the caller.
            raise KeyError(f"未知处理器 '{name}',可用处理器: {available}") from None

    @classmethod
    def list_processors(cls) -> list[str]:
        """Return the registered processor names in registration order."""
        return list(cls._registry)

    @classmethod
    def clear(cls) -> None:
        """Remove every registration (mainly intended for tests)."""
        cls._registry.clear()
|
||
|
||
|
||
def register_model(name: str) -> Callable[[Type[BaseModel]], Type[BaseModel]]:
    """Build a class decorator that registers a model class under ``name``.

    Applying the returned decorator passes the class to
    ``ModelRegistry.register`` and then hands the very same class back, so
    the decorated name in the defining module is unaffected.

    Args:
        name: Name to register the model class under.

    Returns:
        A decorator that registers its argument and returns it unchanged.
        Any exception raised by ``ModelRegistry.register`` (for example a
        ``ValueError`` for an already-used name) propagates when the
        decorator is applied.

    Example:
        >>> @register_model("lightgbm")
        ... class LightGBMModel(BaseModel):
        ...     name = "lightgbm"
        ...     def fit(self, X, y): ...
        ...     def predict(self, X): ...
    """

    def _attach(model_cls: Type[BaseModel]) -> Type[BaseModel]:
        # Register before returning: a failed registration aborts the
        # decoration instead of silently yielding an unregistered class.
        ModelRegistry.register(name, model_cls)
        return model_cls

    return _attach
|
||
|
||
|
||
def register_processor(name: str) -> Callable[[Type[BaseProcessor]], Type[BaseProcessor]]:
    """Build a class decorator that registers a processor class under ``name``.

    Applying the returned decorator passes the class to
    ``ProcessorRegistry.register`` and then hands the very same class back,
    so the decorated name in the defining module is unaffected.

    Args:
        name: Name to register the processor class under.

    Returns:
        A decorator that registers its argument and returns it unchanged.
        Any exception raised by ``ProcessorRegistry.register`` (for example
        a ``ValueError`` for an already-used name) propagates when the
        decorator is applied.

    Example:
        >>> @register_processor("standard_scaler")
        ... class StandardScaler(BaseProcessor):
        ...     name = "standard_scaler"
        ...     def transform(self, X): ...
    """

    def _attach(processor_cls: Type[BaseProcessor]) -> Type[BaseProcessor]:
        # Register before returning: a failed registration aborts the
        # decoration instead of silently yielding an unregistered class.
        ProcessorRegistry.register(name, processor_cls)
        return processor_cls

    return _attach
|