Files
ProStock/src/training/registry.py

185 lines
5.0 KiB
Python
Raw Normal View History

"""组件注册中心
提供装饰器风格的组件注册机制支持即插即用
"""
from typing import Dict, Type, Callable, Any
from src.training.components.base import BaseModel, BaseProcessor
class ModelRegistry:
"""模型注册中心
管理所有可用的模型类支持通过名称获取模型类
Example:
>>> @register_model("lightgbm")
... class LightGBMModel(BaseModel):
... pass
>>>
>>> model_class = ModelRegistry.get_model("lightgbm")
>>> model = model_class(**params)
"""
_registry: Dict[str, Type[BaseModel]] = {}
@classmethod
def register(cls, name: str, model_class: Type[BaseModel]) -> None:
"""注册模型类
Args:
name: 模型名称
model_class: 模型类必须继承 BaseModel
Raises:
ValueError: 名称已被注册或类不继承 BaseModel
"""
if name in cls._registry:
raise ValueError(f"模型 '{name}' 已被注册")
if not issubclass(model_class, BaseModel):
raise ValueError(f"模型类必须继承 BaseModel")
cls._registry[name] = model_class
@classmethod
def get_model(cls, name: str) -> Type[BaseModel]:
"""获取模型类
Args:
name: 模型名称
Returns:
模型类
Raises:
KeyError: 未找到该名称的模型
"""
if name not in cls._registry:
available = ", ".join(cls._registry.keys())
raise KeyError(f"未知模型 '{name}',可用模型: {available}")
return cls._registry[name]
@classmethod
def list_models(cls) -> list[str]:
"""列出所有已注册的模型名称"""
return list(cls._registry.keys())
@classmethod
def clear(cls) -> None:
"""清空注册表(主要用于测试)"""
cls._registry.clear()
class ProcessorRegistry:
"""处理器注册中心
管理所有可用的数据处理器类支持通过名称获取处理器类
Example:
>>> @register_processor("standard_scaler")
... class StandardScaler(BaseProcessor):
... pass
>>>
>>> processor_class = ProcessorRegistry.get_processor("standard_scaler")
>>> processor = processor_class(**params)
"""
_registry: Dict[str, Type[BaseProcessor]] = {}
@classmethod
def register(cls, name: str, processor_class: Type[BaseProcessor]) -> None:
"""注册处理器类
Args:
name: 处理器名称
processor_class: 处理器类必须继承 BaseProcessor
Raises:
ValueError: 名称已被注册或类不继承 BaseProcessor
"""
if name in cls._registry:
raise ValueError(f"处理器 '{name}' 已被注册")
if not issubclass(processor_class, BaseProcessor):
raise ValueError(f"处理器类必须继承 BaseProcessor")
cls._registry[name] = processor_class
@classmethod
def get_processor(cls, name: str) -> Type[BaseProcessor]:
"""获取处理器类
Args:
name: 处理器名称
Returns:
处理器类
Raises:
KeyError: 未找到该名称的处理器
"""
if name not in cls._registry:
available = ", ".join(cls._registry.keys())
raise KeyError(f"未知处理器 '{name}',可用处理器: {available}")
return cls._registry[name]
@classmethod
def list_processors(cls) -> list[str]:
"""列出所有已注册的处理器名称"""
return list(cls._registry.keys())
@classmethod
def clear(cls) -> None:
"""清空注册表(主要用于测试)"""
cls._registry.clear()
def register_model(name: str) -> Callable[[Type[BaseModel]], Type[BaseModel]]:
"""模型注册装饰器
用于装饰继承 BaseModel 的类将其注册到 ModelRegistry
Args:
name: 模型名称
Returns:
装饰器函数
Example:
>>> @register_model("lightgbm")
... class LightGBMModel(BaseModel):
... name = "lightgbm"
... def fit(self, X, y): ...
... def predict(self, X): ...
"""
def decorator(cls: Type[BaseModel]) -> Type[BaseModel]:
ModelRegistry.register(name, cls)
return cls
return decorator
def register_processor(
name: str,
) -> Callable[[Type[BaseProcessor]], Type[BaseProcessor]]:
"""处理器注册装饰器
用于装饰继承 BaseProcessor 的类将其注册到 ProcessorRegistry
Args:
name: 处理器名称
Returns:
装饰器函数
Example:
>>> @register_processor("standard_scaler")
... class StandardScaler(BaseProcessor):
... name = "standard_scaler"
... def transform(self, X): ...
"""
def decorator(cls: Type[BaseProcessor]) -> Type[BaseProcessor]:
ProcessorRegistry.register(name, cls)
return cls
return decorator