Files
ProStock/README.md
liaozhaorun 9f95be56a0 feat(models): 实现机器学习模型训练框架
- 添加核心抽象:Processor、Model、Splitter、Metric 基类
- 实现阶段感知机制(TRAIN/TEST/ALL),防止数据泄露
- 内置 8 个数据处理器和 3 种时序划分策略
- 支持 LightGBM、CatBoost 模型
- PluginRegistry 装饰器注册,插件式架构
- 22 个单元测试
2026-02-23 01:37:34 +08:00

301 lines
9.4 KiB
Markdown
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# ProStock
A股量化投资框架 - 从数据获取到模型训练的完整解决方案
## 功能特性
### 1. 数据层 (src/data/)
- **多源数据接入**: Tushare API 集成,支持日线、股票基础信息、交易日历
- **DuckDB 存储**: 高性能嵌入式数据库,支持 SQL 查询下推
- **智能同步**: 增量/全量同步策略,自动检测数据更新需求
- **速率控制**: 令牌桶算法实现 API 限流
- **并发优化**: ThreadPoolExecutor 多线程数据获取
### 2. 因子层 (src/factors/)
- **类型安全**: 严格的截面因子 vs 时序因子区分
- **防泄露机制**: 框架层面防止未来数据和跨股票数据泄露
- **因子组合**: 支持因子加减乘除和标量运算
- **高性能计算**: Polars 向量化操作,零拷贝数据导出
- **灵活扩展**: 基类抽象便于自定义因子
### 3. 模型层 (src/models/)
- **插件架构**: 装饰器注册机制,新模型即插即用
- **阶段感知**: 训练/测试阶段区分,防止数据泄露
- **多模型支持**: LightGBM、CatBoost 等模型统一接口
- **数据处理**: 缺失值处理、缩尾、标准化、中性化等
- **时序划分**: WalkForward、ExpandingWindow 等时间序列划分策略
## 项目结构
```
ProStock/
├── src/
│ ├── config/ # 配置管理
│ │ ├── settings.py # pydantic-settings 配置
│ │ └── __init__.py
│ │
│ ├── data/ # 数据获取与存储
│ │ ├── api_wrappers/ # Tushare API 封装
│ │ │ ├── api_daily.py # 日线数据接口
│ │ │ ├── api_stock_basic.py # 股票基础信息
│ │ │ └── api_trade_cal.py # 交易日历
│ │ ├── client.py # Tushare 客户端(含限流)
│ │ ├── config.py # 数据模块配置
│ │ ├── db_manager.py # DuckDB 表管理和同步
│ │ ├── db_inspector.py # 数据库信息查看工具
│ │ ├── rate_limiter.py # 令牌桶限流器
│ │ ├── storage.py # DuckDB 存储核心
│ │ ├── sync.py # 数据同步主逻辑
│ │ └── __init__.py
│ │
│ ├── factors/ # 因子计算框架
│ │ ├── base.py # 因子基类(截面/时序)
│ │ ├── composite.py # 组合因子和标量运算
│ │ ├── data_loader.py # DuckDB 数据加载器
│ │ ├── data_spec.py # 数据规格定义
│ │ ├── engine.py # 因子执行引擎
│ │ └── __init__.py
│ │
│ ├── models/ # 模型训练框架
│ │ ├── core/ # 核心抽象
│ │ │ ├── base.py # 处理器/模型/划分基类
│ │ │ └── splitter.py # 时间序列划分策略
│ │ ├── models/ # 模型实现
│ │ │ └── models.py # LightGBM、CatBoost
│ │ ├── processors/ # 数据处理器
│ │ │ └── processors.py # 标准化、缩尾、中性化等
│ │ ├── pipeline.py # 处理流水线
│ │ ├── registry.py # 插件注册中心
│ │ └── __init__.py
│ │
│ └── __init__.py
├── docs/ # 文档
│ ├── factor_framework_design.md # 因子框架设计
│ ├── ml_framework_design.md # 模型框架设计
│ ├── db_sync_guide.md # 数据同步指南
│ └── ...
├── data/ # 数据存储DuckDB
│ ├── prostock.db # 主数据库文件
│ └── stock_basic.csv # 股票基础信息缓存
├── config/ # 配置文件
│ └── .env.local # 环境变量API Token等
└── tests/ # 测试文件
├── test_sync.py
└── factors/
```
## 快速开始
### 1. 安装依赖
**⚠️ 本项目强制使用 uv 作为 Python 包管理器**
```bash
# 安装 uv (如果尚未安装)
pip install uv
# 安装项目依赖
uv pip install -e .
```
### 2. 配置环境变量
创建 `config/.env.local` 文件:
```bash
TUSHARE_TOKEN=your_tushare_token_here
DATA_PATH=data
RATE_LIMIT=100
THREADS=10
```
### 3. 数据同步
```bash
# 首次同步 - 全量同步从20180101开始
uv run python -c "from src.data.sync import sync_all; sync_all(force_full=True)"
# 日常同步 - 增量同步(自动从最新日期开始)
uv run python -c "from src.data.sync import sync_all; sync_all()"
# 预览同步(检查需要同步的数据量)
uv run python -c "from src.data.sync import preview_sync; preview_sync()"
# 自定义线程数
uv run python -c "from src.data.sync import sync_all; sync_all(max_workers=20)"
```
### 4. 查看数据库状态
```bash
uv run python -c "from src.data.db_inspector import get_db_info; get_db_info()"
```
## 使用示例
### 因子计算
```python
from src.factors import FactorEngine, DataLoader, DataSpec
from src.factors.base import CrossSectionalFactor, TimeSeriesFactor
import polars as pl
# 自定义截面因子PE排名
class PERankFactor(CrossSectionalFactor):
name = "pe_rank"
data_specs = [DataSpec("daily", ["ts_code", "trade_date", "pe"], lookback_days=1)]
def compute(self, data) -> pl.Series:
cs = data.get_cross_section()
return cs["pe"].rank()
# 自定义时序因子20日移动平均
class MA20Factor(TimeSeriesFactor):
name = "ma20"
data_specs = [DataSpec("daily", ["ts_code", "trade_date", "close"], lookback_days=20)]
def compute(self, data) -> pl.Series:
return data.get_column("close").rolling_mean(window_size=20)
# 执行计算
loader = DataLoader(data_dir="data")
engine = FactorEngine(loader)
# 计算截面因子
pe_rank = PERankFactor()
result1 = engine.compute(pe_rank, start_date="20240101", end_date="20240131")
# 计算时序因子
ma20 = MA20Factor()
result2 = engine.compute(ma20, stock_codes=["000001.SZ"],
start_date="20240101", end_date="20240131")
# 因子组合
combined = 0.5 * pe_rank + 0.3 * ma20
```
### 模型训练
```python
from src.models import PluginRegistry, ProcessingPipeline
from src.models.core import PipelineStage
import polars as pl
# 创建处理流水线
pipeline = ProcessingPipeline([
PluginRegistry.get_processor("dropna")(),
PluginRegistry.get_processor("winsorizer")(lower=0.01, upper=0.99),
PluginRegistry.get_processor("standard_scaler")(),
])
# 准备数据
data = pl.read_csv("features.csv") # 包含特征和标签
# 划分训练/测试集
from src.models.core import WalkForwardSplit
splitter = WalkForwardSplit(train_window=252, test_window=21)
# 获取 LightGBM 模型
ModelClass = PluginRegistry.get_model("lightgbm")
model = ModelClass(task_type="regression", params={"n_estimators": 100})
# 训练循环
for train_idx, test_idx in splitter.split(data):
train_data = data[train_idx]
test_data = data[test_idx]
# 数据处理
X_train = pipeline.fit_transform(train_data.drop("target"))
X_test = pipeline.transform(test_data.drop("target"))
y_train = train_data["target"]
y_test = test_data["target"]
# 训练模型
model.fit(X_train, y_train)
predictions = model.predict(X_test)
```
## 核心设计
### 1. 数据防泄露机制
**截面因子 (CrossSectionalFactor)**:
- 防止日期泄露:每天只传入 `[T-lookback+1, T]` 数据
- 允许股票间比较:传入当天所有股票数据
- 典型应用PE排名、市值分位数、当日收益率排名
**时序因子 (TimeSeriesFactor)**:
- 防止股票泄露:每只股票单独计算
- 允许历史数据访问:传入完整时间序列
- 典型应用移动平均线、RSI、历史波动率
### 2. 插件注册机制
```python
from src.models.registry import PluginRegistry
# 注册自定义处理器
@PluginRegistry.register_processor("my_processor")
class MyProcessor(BaseProcessor):
stage = PipelineStage.TRAIN
def fit(self, data):
# 学习参数
return self
def transform(self, data):
# 转换数据
return data
# 使用
processor_class = PluginRegistry.get_processor("my_processor")
processor = processor_class()
```
### 3. 数据同步策略
**智能增量同步**:
```python
from src.data.db_manager import SyncManager
manager = SyncManager()
result = manager.sync(
table_name="daily",
fetch_func=get_daily,
start_date="20240101",
end_date="20240131"
)
# 自动检测:表不存在→全量,表存在→增量
```
## 文档
- [因子框架设计](docs/factor_framework_design.md) - 因子计算架构详解
- [模型框架设计](docs/ml_framework_design.md) - 模型训练架构详解
- [数据同步指南](docs/db_sync_guide.md) - DuckDB 数据同步 API 说明
- [代码审查报告](docs/code_review_factors_20260222.md) - 因子框架代码审查
## 开发规范
- **Python 版本**: 3.10+
- **代码风格**: Google 风格文档字符串
- **类型提示**: 强制类型注解
- **测试**: pytest 框架
- **包管理**: uv (禁止直接使用 pip/python)
## 技术栈
- **数据处理**: Polars, Pandas, NumPy
- **数据存储**: DuckDB (嵌入式 OLAP 数据库)
- **API 接口**: Tushare Pro
- **机器学习**: LightGBM, CatBoost, scikit-learn
- **配置管理**: pydantic-settings
## 许可证
MIT License