Files
ProStock/src/training/registry.py
liaozhaorun 472b2b665a feat(training): 添加训练模块基础架构
实现 Commit 1:训练模块基础架构

新增文件:

- src/training/__init__.py - 主模块导出

- src/training/components/__init__.py - components 子模块导出

- src/training/components/base.py - BaseModel/BaseProcessor 抽象基类

- src/training/registry.py - 模型和处理器注册中心

- tests/training/test_base.py - 基础架构单元测试

功能特性:

- BaseModel: 提供 fit, predict, feature_importance, save/load 接口

- BaseProcessor: 提供 fit, transform, fit_transform 接口

- ModelRegistry/ProcessorRegistry: 支持装饰器风格组件注册

- 支持即插即用的组件扩展机制

Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-opencode)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-03-03 21:55:39 +08:00

185 lines
5.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""组件注册中心
提供装饰器风格的组件注册机制,支持即插即用。
"""
from typing import Dict, Type, Callable, Any
from src.training.components.base import BaseModel, BaseProcessor
class ModelRegistry:
"""模型注册中心
管理所有可用的模型类,支持通过名称获取模型类。
Example:
>>> @register_model("lightgbm")
... class LightGBMModel(BaseModel):
... pass
>>>
>>> model_class = ModelRegistry.get_model("lightgbm")
>>> model = model_class(**params)
"""
_registry: Dict[str, Type[BaseModel]] = {}
@classmethod
def register(cls, name: str, model_class: Type[BaseModel]) -> None:
"""注册模型类
Args:
name: 模型名称
model_class: 模型类(必须继承 BaseModel
Raises:
ValueError: 名称已被注册或类不继承 BaseModel
"""
if name in cls._registry:
raise ValueError(f"模型 '{name}' 已被注册")
if not issubclass(model_class, BaseModel):
raise ValueError(f"模型类必须继承 BaseModel")
cls._registry[name] = model_class
@classmethod
def get_model(cls, name: str) -> Type[BaseModel]:
"""获取模型类
Args:
name: 模型名称
Returns:
模型类
Raises:
KeyError: 未找到该名称的模型
"""
if name not in cls._registry:
available = ", ".join(cls._registry.keys())
raise KeyError(f"未知模型 '{name}',可用模型: {available}")
return cls._registry[name]
@classmethod
def list_models(cls) -> list[str]:
"""列出所有已注册的模型名称"""
return list(cls._registry.keys())
@classmethod
def clear(cls) -> None:
"""清空注册表(主要用于测试)"""
cls._registry.clear()
class ProcessorRegistry:
"""处理器注册中心
管理所有可用的数据处理器类,支持通过名称获取处理器类。
Example:
>>> @register_processor("standard_scaler")
... class StandardScaler(BaseProcessor):
... pass
>>>
>>> processor_class = ProcessorRegistry.get_processor("standard_scaler")
>>> processor = processor_class(**params)
"""
_registry: Dict[str, Type[BaseProcessor]] = {}
@classmethod
def register(cls, name: str, processor_class: Type[BaseProcessor]) -> None:
"""注册处理器类
Args:
name: 处理器名称
processor_class: 处理器类(必须继承 BaseProcessor
Raises:
ValueError: 名称已被注册或类不继承 BaseProcessor
"""
if name in cls._registry:
raise ValueError(f"处理器 '{name}' 已被注册")
if not issubclass(processor_class, BaseProcessor):
raise ValueError(f"处理器类必须继承 BaseProcessor")
cls._registry[name] = processor_class
@classmethod
def get_processor(cls, name: str) -> Type[BaseProcessor]:
"""获取处理器类
Args:
name: 处理器名称
Returns:
处理器类
Raises:
KeyError: 未找到该名称的处理器
"""
if name not in cls._registry:
available = ", ".join(cls._registry.keys())
raise KeyError(f"未知处理器 '{name}',可用处理器: {available}")
return cls._registry[name]
@classmethod
def list_processors(cls) -> list[str]:
"""列出所有已注册的处理器名称"""
return list(cls._registry.keys())
@classmethod
def clear(cls) -> None:
"""清空注册表(主要用于测试)"""
cls._registry.clear()
def register_model(name: str) -> Callable[[Type[BaseModel]], Type[BaseModel]]:
"""模型注册装饰器
用于装饰继承 BaseModel 的类,将其注册到 ModelRegistry。
Args:
name: 模型名称
Returns:
装饰器函数
Example:
>>> @register_model("lightgbm")
... class LightGBMModel(BaseModel):
... name = "lightgbm"
... def fit(self, X, y): ...
... def predict(self, X): ...
"""
def decorator(cls: Type[BaseModel]) -> Type[BaseModel]:
ModelRegistry.register(name, cls)
return cls
return decorator
def register_processor(
name: str,
) -> Callable[[Type[BaseProcessor]], Type[BaseProcessor]]:
"""处理器注册装饰器
用于装饰继承 BaseProcessor 的类,将其注册到 ProcessorRegistry。
Args:
name: 处理器名称
Returns:
装饰器函数
Example:
>>> @register_processor("standard_scaler")
... class StandardScaler(BaseProcessor):
... name = "standard_scaler"
... def transform(self, X): ...
"""
def decorator(cls: Type[BaseProcessor]) -> Type[BaseProcessor]:
ProcessorRegistry.register(name, cls)
return cls
return decorator