- 存储层重构: HDF5 → DuckDB(UPSERT模式、线程安全存储) - Sync类迁移: DataSync从sync.py迁移到api_daily.py(职责分离) - 模型模块重构: src/models → src/pipeline(更清晰的命名) - 新增因子模块: factors/momentum (MA、收益率排名)、factors/financial - 新增API接口: api_namechange、api_bak_basic - 新增训练入口: training模块(main.py、pipeline配置) - 工具函数统一: get_today_date等移至utils.py - 文档更新: AGENTS.md添加架构变更历史
211 lines
6.5 KiB
Python
211 lines
6.5 KiB
Python
"""内置机器学习模型
|
|
|
|
提供 LightGBM、CatBoost 等模型的统一接口包装器。
|
|
"""
|
|
|
|
from typing import Optional, Dict, Any
|
|
import polars as pl
|
|
import numpy as np
|
|
|
|
from src.pipeline.core import BaseModel, TaskType
|
|
from src.pipeline.registry import PluginRegistry
|
|
|
|
|
|
@PluginRegistry.register_model("lightgbm")
class LightGBMModel(BaseModel):
    """LightGBM model wrapper.

    Supports three task types: classification, regression and ranking.
    """

    def __init__(
        self,
        task_type: TaskType,
        params: Optional[Dict[str, Any]] = None,
        name: Optional[str] = None,
    ):
        """Create an unfitted wrapper.

        Args:
            task_type: Task to solve; selects the default objective and metric.
            params: LightGBM training parameters that override the defaults
                built in ``fit``.
            name: Optional model name, forwarded to ``BaseModel``.
        """
        super().__init__(task_type, params, name)
        self._model = None
        # Column names of the frame passed to the last successful fit().
        # Training uses plain numpy arrays, so the booster itself never sees
        # the real names; they are kept here for get_feature_importance().
        self._feature_names = None

    @staticmethod
    def _as_array(values):
        """Return ``values`` as a numpy array when it offers ``to_numpy``."""
        if values is not None and hasattr(values, "to_numpy"):
            return values.to_numpy()
        return values

    def fit(
        self,
        X: pl.DataFrame,
        y: pl.Series,
        X_val: Optional[pl.DataFrame] = None,
        y_val: Optional[pl.Series] = None,
        **fit_params,
    ) -> "LightGBMModel":
        """Train the model.

        Args:
            X: Training features.
            y: Training labels.
            X_val: Optional validation features.
            y_val: Optional validation labels. Validation (and therefore
                early stopping) is only used when both ``X_val`` and
                ``y_val`` are given.
            **fit_params: Optional training controls:
                ``num_boost_round`` (default 100),
                ``early_stopping_rounds`` (default 10; a falsy value
                disables early stopping),
                ``group`` / ``group_val`` (query group sizes for the
                training / validation set, needed by ranking objectives).

        Returns:
            This instance, to allow call chaining.

        Raises:
            ImportError: If lightgbm is not installed.
        """
        try:
            import lightgbm as lgb
        except ImportError as exc:
            raise ImportError(
                "lightgbm is required. Install with: uv pip install lightgbm"
            ) from exc

        train_data = lgb.Dataset(
            X.to_numpy(),
            label=y.to_numpy(),
            group=self._as_array(fit_params.get("group")),
        )
        # The training set is always evaluated; the validation set is added
        # below when it is available.
        valid_sets = [train_data]
        valid_names = ["train"]

        has_validation = X_val is not None and y_val is not None
        if has_validation:
            valid_data = lgb.Dataset(
                X_val.to_numpy(),
                label=y_val.to_numpy(),
                group=self._as_array(fit_params.get("group_val")),
            )
            valid_sets.append(valid_data)
            valid_names.append("valid")

        train_params = {
            "objective": self._get_objective(),
            "metric": self._get_metric(),
            "boosting_type": "gbdt",
            "num_leaves": 31,
            "learning_rate": 0.05,
            "feature_fraction": 0.9,
            "bagging_fraction": 0.8,
            "bagging_freq": 5,
            "verbose": -1,
        }
        # User-supplied parameters win over the defaults above.
        train_params.update(self.params or {})

        callbacks = []
        stopping_rounds = fit_params.get("early_stopping_rounds", 10)
        # Early stopping needs a held-out set; skip it otherwise.
        if has_validation and stopping_rounds:
            callbacks.append(
                lgb.early_stopping(stopping_rounds=stopping_rounds, verbose=False)
            )

        self._model = lgb.train(
            train_params,
            train_data,
            num_boost_round=fit_params.get("num_boost_round", 100),
            valid_sets=valid_sets,
            valid_names=valid_names,
            callbacks=callbacks,
        )
        self._feature_names = list(X.columns)
        self._is_fitted = True
        return self

    def predict(self, X: pl.DataFrame) -> np.ndarray:
        """Predict with the trained booster.

        Raises:
            RuntimeError: If the model has not been fitted yet.
        """
        if not self._is_fitted:
            raise RuntimeError("Model not fitted yet")
        return self._model.predict(X.to_numpy())

    def predict_proba(self, X: pl.DataFrame) -> np.ndarray:
        """Predict class probabilities (classification tasks only).

        A one-dimensional prediction is treated as the positive-class
        probability and expanded to two columns ``[1 - p, p]``; predictions
        that already have more than one dimension are returned unchanged.

        Raises:
            ValueError: If the task type is not classification.
            RuntimeError: If the model has not been fitted yet.
        """
        if self.task_type != "classification":
            raise ValueError("predict_proba only for classification")
        probs = self.predict(X)
        if probs.ndim == 1:
            return np.column_stack([1 - probs, probs])
        return probs

    def get_feature_importance(self) -> Optional[pl.DataFrame]:
        """Return gain-based feature importance, sorted in descending order.

        Returns:
            A frame with ``feature`` and ``importance`` columns, or ``None``
            when no model has been trained.
        """
        if self._model is None:
            return None
        importance = self._model.feature_importance(importance_type="gain")

        # Prefer the column names remembered at fit time; fall back to what
        # the booster reports, then to positional placeholders.
        feature_names = getattr(self, "_feature_names", None)
        if feature_names is None or len(feature_names) != len(importance):
            feature_names = getattr(
                self._model,
                "feature_name",
                lambda: [f"feature_{i}" for i in range(len(importance))],
            )()

        return pl.DataFrame({"feature": feature_names, "importance": importance}).sort(
            "importance", descending=True
        )

    def _get_objective(self) -> str:
        """Map the task type to a default LightGBM objective."""
        objectives = {
            "classification": "binary",
            "regression": "regression",
            "ranking": "lambdarank",
        }
        return objectives.get(self.task_type, "regression")

    def _get_metric(self) -> str:
        """Map the task type to a default LightGBM evaluation metric."""
        metrics = {"classification": "auc", "regression": "rmse", "ranking": "ndcg"}
        return metrics.get(self.task_type, "rmse")
|
|
|
|
|
|
@PluginRegistry.register_model("catboost")
class CatBoostModel(BaseModel):
    """CatBoost model wrapper."""

    def __init__(
        self,
        task_type: TaskType,
        params: Optional[Dict[str, Any]] = None,
        name: Optional[str] = None,
    ):
        """Create an unfitted wrapper.

        Args:
            task_type: Task to solve; selects the estimator class and the
                default loss function.
            params: CatBoost constructor parameters that override the
                defaults built in ``fit``.
            name: Optional model name, forwarded to ``BaseModel``.
        """
        super().__init__(task_type, params, name)
        self._model = None

    def fit(
        self,
        X: pl.DataFrame,
        y: pl.Series,
        X_val: Optional[pl.DataFrame] = None,
        y_val: Optional[pl.Series] = None,
        **fit_params,
    ) -> "CatBoostModel":
        """Train the model.

        Args:
            X: Training features.
            y: Training labels.
            X_val: Optional validation features.
            y_val: Optional validation labels. An evaluation set is only
                passed to CatBoost when both ``X_val`` and ``y_val`` are
                given.
            **fit_params: Optional training controls; only
                ``early_stopping_rounds`` (default 10) is read.

        Returns:
            This instance, to allow call chaining.

        Raises:
            ImportError: If catboost is not installed.
        """
        try:
            from catboost import CatBoostClassifier, CatBoostRegressor
        except ImportError as exc:
            raise ImportError(
                "catboost is required. Install with: uv pip install catboost"
            ) from exc

        if self.task_type == "classification":
            model_class = CatBoostClassifier
            model_params = {"loss_function": "Logloss", "eval_metric": "AUC"}
        elif self.task_type == "regression":
            model_class = CatBoostRegressor
            model_params = {"loss_function": "RMSE"}
        else:
            # NOTE(review): QueryRMSE is a group-wise loss, but no group
            # identifiers are passed to fit() below - confirm this path
            # works for ranking data.
            model_class = CatBoostRegressor
            model_params = {"loss_function": "QueryRMSE"}

        # User-supplied parameters win over the defaults, except verbosity,
        # which is always silenced.
        model_params.update(self.params or {})
        model_params["verbose"] = False

        # Train a local estimator and publish it only after fit() succeeds,
        # so a failed fit never leaves an untrained model on the instance.
        model = model_class(**model_params)

        eval_set = None
        if X_val is not None and y_val is not None:
            eval_set = (X_val.to_pandas(), y_val.to_pandas())

        model.fit(
            X.to_pandas(),
            y.to_pandas(),
            eval_set=eval_set,
            early_stopping_rounds=fit_params.get("early_stopping_rounds", 10),
            verbose=False,
        )
        self._model = model
        self._is_fitted = True
        return self

    def predict(self, X: pl.DataFrame) -> np.ndarray:
        """Predict with the trained model.

        Raises:
            RuntimeError: If the model has not been fitted yet.
        """
        if not self._is_fitted:
            raise RuntimeError("Model not fitted yet")
        return self._model.predict(X.to_pandas())

    def predict_proba(self, X: pl.DataFrame) -> np.ndarray:
        """Predict class probabilities (classification tasks only).

        Raises:
            ValueError: If the task type is not classification.
            RuntimeError: If the model has not been fitted yet.
        """
        if self.task_type != "classification":
            raise ValueError("predict_proba only for classification")
        # Same guard as predict(): fail with a clear message instead of an
        # AttributeError on a missing model.
        if not self._is_fitted:
            raise RuntimeError("Model not fitted yet")
        return self._model.predict_proba(X.to_pandas())

    def get_feature_importance(self) -> Optional[pl.DataFrame]:
        """Return feature importance, sorted in descending order.

        Returns:
            A frame with ``feature`` and ``importance`` columns, or ``None``
            when no model has been trained.
        """
        if self._model is None:
            return None
        return pl.DataFrame(
            {
                "feature": self._model.feature_names_,
                "importance": self._model.feature_importances_,
            }
        ).sort("importance", descending=True)
|
|
|
|
|
|
# Public names exported by this module.
__all__ = [
    "LightGBMModel",
    "CatBoostModel",
]
|