feat(training): 添加训练模块基础架构
实现 Commit 1:训练模块基础架构

新增文件:
- src/training/__init__.py - 主模块导出
- src/training/components/__init__.py - components 子模块导出
- src/training/components/base.py - BaseModel/BaseProcessor 抽象基类
- src/training/registry.py - 模型和处理器注册中心
- tests/training/test_base.py - 基础架构单元测试

功能特性:
- BaseModel: 提供 fit, predict, feature_importance, save/load 接口
- BaseProcessor: 提供 fit, transform, fit_transform 接口
- ModelRegistry/ProcessorRegistry: 支持装饰器风格组件注册
- 支持即插即用的组件扩展机制

Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-opencode)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
This commit is contained in:
12
src/training/components/__init__.py
Normal file
12
src/training/components/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
||||
"""Training components subpackage.

Contains components such as models, processors, splitters and selectors.
"""

# Base abstract classes re-exported at the subpackage level.
from src.training.components.base import BaseModel, BaseProcessor

__all__ = ["BaseModel", "BaseProcessor"]
|
||||
141
src/training/components/base.py
Normal file
141
src/training/components/base.py
Normal file
@@ -0,0 +1,141 @@
|
||||
"""基础抽象类定义
|
||||
|
||||
定义 BaseModel 和 BaseProcessor 抽象基类,
|
||||
为所有训练组件提供统一的接口。
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
import pickle
|
||||
|
||||
import polars as pl
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
|
||||
class BaseModel(ABC):
    """Abstract base class for machine-learning models.

    Subclasses must implement :meth:`fit` and :meth:`predict`. The base class
    supplies a default feature-importance hook (returns ``None``) and
    pickle-based persistence via :meth:`save` / :meth:`load`.

    Attributes:
        name: Model name. Defaults to the empty string; subclasses are
            expected to override it.
    """

    name: str = ""  # model name, overridden by subclasses

    @abstractmethod
    def fit(self, X: pl.DataFrame, y: pl.Series) -> "BaseModel":
        """Train the model.

        Args:
            X: Feature matrix (Polars DataFrame).
            y: Target variable (Polars Series).

        Returns:
            ``self``, so that calls can be chained.
        """
        raise NotImplementedError

    @abstractmethod
    def predict(self, X: pl.DataFrame) -> np.ndarray:
        """Predict targets for the given features.

        Args:
            X: Feature matrix (Polars DataFrame).

        Returns:
            Predictions as a numpy ndarray.
        """
        raise NotImplementedError

    def feature_importance(self) -> Optional[pd.Series]:
        """Return feature importances.

        Returns:
            A series of feature importances, or ``None`` when the model does
            not support them. The base implementation always returns ``None``.
        """
        return None

    def save(self, path: str) -> None:
        """Serialize this model to ``path``.

        The default implementation pickles the whole instance; subclasses may
        override it to use a more efficient format.

        No trained-state check is performed here: the instance is written
        as-is, whether or not ``fit`` has been called.

        Args:
            path: Destination file path.
        """
        with open(path, "wb") as f:
            pickle.dump(self, f)

    @classmethod
    def load(cls, path: str) -> "BaseModel":
        """Load a model previously written by :meth:`save`.

        Args:
            path: Path of the model file.

        Returns:
            The loaded model instance (an instance of ``cls``).

        Raises:
            TypeError: If the file does not contain an instance of ``cls``.
        """
        # SECURITY: pickle.load can execute arbitrary code while
        # deserializing. Only load files that come from a trusted source.
        with open(path, "rb") as f:
            model = pickle.load(f)

        # Enforce the declared return type: without this check a file holding
        # any other pickled object would be handed back as if it were a model.
        if not isinstance(model, cls):
            raise TypeError(
                f"{path!r} contains {type(model).__name__}, "
                f"expected an instance of {cls.__name__}"
            )
        return model
|
||||
|
||||
|
||||
class BaseProcessor(ABC):
    """Abstract base class for data processors.

    A processor is used differently depending on the stage:

    - Training: ``fit_transform`` (learn parameters, then apply them).
    - Validation / test: ``transform`` (apply the parameters learned during
      training).

    The intent is that a processor instance is kept after training and
    reused to transform validation and test data.

    Attributes:
        name: Processor name. Defaults to the empty string; subclasses are
            expected to override it.
    """

    name: str = ""

    def fit(self, X: pl.DataFrame) -> "BaseProcessor":
        """Learn parameters from training data.

        The base implementation learns nothing and simply returns ``self``.
        Subclasses should override it to learn statistics such as means or
        standard deviations.

        Args:
            X: Training data (Polars DataFrame).

        Returns:
            ``self``, so that calls can be chained.
        """
        return self

    @abstractmethod
    def transform(self, X: pl.DataFrame) -> pl.DataFrame:
        """Transform the data.

        Args:
            X: Input data (Polars DataFrame).

        Returns:
            The transformed data (Polars DataFrame).
        """
        raise NotImplementedError()

    def fit_transform(self, X: pl.DataFrame) -> pl.DataFrame:
        """Fit on ``X`` and then transform it (training stage).

        Args:
            X: Training data (Polars DataFrame).

        Returns:
            The transformed data (Polars DataFrame).
        """
        # transform is invoked on whatever fit() returns, not on self directly.
        fitted = self.fit(X)
        return fitted.transform(X)
|
||||
Reference in New Issue
Block a user