feat(training): 添加训练模块基础架构

实现 Commit 1:训练模块基础架构

新增文件:

- src/training/__init__.py - 主模块导出

- src/training/components/__init__.py - components 子模块导出

- src/training/components/base.py - BaseModel/BaseProcessor 抽象基类

- src/training/registry.py - 模型和处理器注册中心

- tests/training/test_base.py - 基础架构单元测试

功能特性:

- BaseModel: 提供 fit, predict, feature_importance, save/load 接口

- BaseProcessor: 提供 fit, transform, fit_transform 接口

- ModelRegistry/ProcessorRegistry: 支持装饰器风格组件注册

- 支持即插即用的组件扩展机制

Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-opencode)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
This commit is contained in:
2026-03-03 21:55:39 +08:00
parent 12ddb19b2e
commit 472b2b665a
18 changed files with 694 additions and 3997 deletions

View File

@@ -0,0 +1,12 @@
"""训练组件子模块
包含模型、处理器、划分器、选择器等组件。
"""
# 基础抽象类
from src.training.components.base import BaseModel, BaseProcessor
__all__ = [
"BaseModel",
"BaseProcessor",
]

View File

@@ -0,0 +1,141 @@
"""基础抽象类定义
定义 BaseModel 和 BaseProcessor 抽象基类,
为所有训练组件提供统一的接口。
"""
from abc import ABC, abstractmethod
from typing import Optional
import pickle
import polars as pl
import numpy as np
import pandas as pd
class BaseModel(ABC):
"""模型基类
所有机器学习模型必须继承此类并实现抽象方法。
提供统一的训练、预测、特征重要性和持久化接口。
Attributes:
name: 模型名称,子类必须定义
"""
name: str = "" # 模型名称
@abstractmethod
def fit(self, X: pl.DataFrame, y: pl.Series) -> "BaseModel":
"""训练模型
Args:
X: 特征矩阵 (Polars DataFrame)
y: 目标变量 (Polars Series)
Returns:
self (支持链式调用)
"""
raise NotImplementedError
@abstractmethod
def predict(self, X: pl.DataFrame) -> np.ndarray:
"""预测
Args:
X: 特征矩阵 (Polars DataFrame)
Returns:
预测结果 (numpy ndarray)
"""
raise NotImplementedError
def feature_importance(self) -> Optional[pd.Series]:
"""特征重要性
Returns:
特征重要性序列,如果不支持则返回 None
"""
return None
def save(self, path: str) -> None:
"""保存模型到文件
默认实现使用 pickle 序列化,子类可覆盖以使用更高效的格式。
Args:
path: 保存路径
Raises:
RuntimeError: 模型未训练时调用
"""
with open(path, "wb") as f:
pickle.dump(self, f)
@classmethod
def load(cls, path: str) -> "BaseModel":
"""从文件加载模型
Args:
path: 模型文件路径
Returns:
加载的模型实例
"""
with open(path, "rb") as f:
return pickle.load(f)
class BaseProcessor(ABC):
"""数据处理器基类
重要Processor 在不同阶段行为不同:
- 训练阶段fit_transform学习参数并应用
- 验证/测试阶段transform使用训练阶段学到的参数
这意味着 Processor 实例会在训练后被保存,
用于后续的验证和测试数据转换。
Attributes:
name: 处理器名称,子类必须定义
"""
name: str = ""
def fit(self, X: pl.DataFrame) -> "BaseProcessor":
"""学习参数(仅在训练阶段调用)
子类应覆盖此方法以学习统计参数(如均值、标准差等)。
Args:
X: 训练数据 (Polars DataFrame)
Returns:
self (支持链式调用)
"""
return self
@abstractmethod
def transform(self, X: pl.DataFrame) -> pl.DataFrame:
"""转换数据
Args:
X: 输入数据 (Polars DataFrame)
Returns:
转换后的数据 (Polars DataFrame)
"""
raise NotImplementedError
def fit_transform(self, X: pl.DataFrame) -> pl.DataFrame:
"""拟合并转换(训练阶段使用)
先调用 fit 学习参数,然后调用 transform 应用转换。
Args:
X: 训练数据 (Polars DataFrame)
Returns:
转换后的数据 (Polars DataFrame)
"""
return self.fit(X).transform(X)