"""组件注册中心 提供装饰器风格的组件注册机制,支持即插即用。 """ from typing import Dict, Type, Callable, Any from src.training.components.base import BaseModel, BaseProcessor class ModelRegistry: """模型注册中心 管理所有可用的模型类,支持通过名称获取模型类。 Example: >>> @register_model("lightgbm") ... class LightGBMModel(BaseModel): ... pass >>> >>> model_class = ModelRegistry.get_model("lightgbm") >>> model = model_class(**params) """ _registry: Dict[str, Type[BaseModel]] = {} @classmethod def register(cls, name: str, model_class: Type[BaseModel]) -> None: """注册模型类 Args: name: 模型名称 model_class: 模型类(必须继承 BaseModel) Raises: ValueError: 名称已被注册或类不继承 BaseModel """ if name in cls._registry: raise ValueError(f"模型 '{name}' 已被注册") if not issubclass(model_class, BaseModel): raise ValueError(f"模型类必须继承 BaseModel") cls._registry[name] = model_class @classmethod def get_model(cls, name: str) -> Type[BaseModel]: """获取模型类 Args: name: 模型名称 Returns: 模型类 Raises: KeyError: 未找到该名称的模型 """ if name not in cls._registry: available = ", ".join(cls._registry.keys()) raise KeyError(f"未知模型 '{name}',可用模型: {available}") return cls._registry[name] @classmethod def list_models(cls) -> list[str]: """列出所有已注册的模型名称""" return list(cls._registry.keys()) @classmethod def clear(cls) -> None: """清空注册表(主要用于测试)""" cls._registry.clear() class ProcessorRegistry: """处理器注册中心 管理所有可用的数据处理器类,支持通过名称获取处理器类。 Example: >>> @register_processor("standard_scaler") ... class StandardScaler(BaseProcessor): ... pass >>> >>> processor_class = ProcessorRegistry.get_processor("standard_scaler") >>> processor = processor_class(**params) """ _registry: Dict[str, Type[BaseProcessor]] = {} @classmethod def register(cls, name: str, processor_class: Type[BaseProcessor]) -> None: """注册处理器类 Args: name: 处理器名称 processor_class: 处理器类(必须继承 BaseProcessor) Raises: ValueError: 名称已被注册或类不继承 BaseProcessor """ if name in cls._registry: raise ValueError(f"处理器 '{name}' 已被注册") if not issubclass(processor_class, BaseProcessor): raise ValueError(f"处理器类必须继承 BaseProcessor") cls._registry[name] = processor_class @classmethod def get_processor(cls, name: str) -> Type[BaseProcessor]: """获取处理器类 Args: name: 处理器名称 Returns: 处理器类 Raises: KeyError: 未找到该名称的处理器 """ if name not in cls._registry: available = ", ".join(cls._registry.keys()) raise KeyError(f"未知处理器 '{name}',可用处理器: {available}") return cls._registry[name] @classmethod def list_processors(cls) -> list[str]: """列出所有已注册的处理器名称""" return list(cls._registry.keys()) @classmethod def clear(cls) -> None: """清空注册表(主要用于测试)""" cls._registry.clear() def register_model(name: str) -> Callable[[Type[BaseModel]], Type[BaseModel]]: """模型注册装饰器 用于装饰继承 BaseModel 的类,将其注册到 ModelRegistry。 Args: name: 模型名称 Returns: 装饰器函数 Example: >>> @register_model("lightgbm") ... class LightGBMModel(BaseModel): ... name = "lightgbm" ... def fit(self, X, y): ... ... def predict(self, X): ... """ def decorator(cls: Type[BaseModel]) -> Type[BaseModel]: ModelRegistry.register(name, cls) return cls return decorator def register_processor( name: str, ) -> Callable[[Type[BaseProcessor]], Type[BaseProcessor]]: """处理器注册装饰器 用于装饰继承 BaseProcessor 的类,将其注册到 ProcessorRegistry。 Args: name: 处理器名称 Returns: 装饰器函数 Example: >>> @register_processor("standard_scaler") ... class StandardScaler(BaseProcessor): ... name = "standard_scaler" ... def transform(self, X): ... """ def decorator(cls: Type[BaseProcessor]) -> Type[BaseProcessor]: ProcessorRegistry.register(name, cls) return cls return decorator