liaozhaorun f48b307ad2 feat(training): 实现 DateSplitter 数据划分器
- 新增 DateSplitter 类,支持基于日期范围的一次性训练/测试划分
- 实现日期格式验证和日期范围逻辑检查
- 支持自定义日期列名参数
- 添加完整的单元测试(12个测试用例)
- 在 components 模块导出 DateSplitter
2026-03-03 22:07:45 +08:00

ProStock

A股量化投资框架 - 从数据获取到模型训练的完整解决方案

功能特性

1. 数据层 (src/data/)

  • 多源数据接入: Tushare API 集成,支持日线、股票基础信息、交易日历
  • DuckDB 存储: 高性能嵌入式数据库,支持 SQL 查询下推
  • 智能同步: 增量/全量同步策略,自动检测数据更新需求
  • 速率控制: 令牌桶算法实现 API 限流
  • 并发优化: ThreadPoolExecutor 多线程数据获取

2. 因子层 (src/factors/)

  • 类型安全: 严格的截面因子 vs 时序因子区分
  • 防泄露机制: 框架层面防止未来数据和跨股票数据泄露
  • 因子组合: 支持因子加减乘除和标量运算
  • 高性能计算: Polars 向量化操作,零拷贝数据导出
  • 灵活扩展: 基类抽象便于自定义因子

3. 模型层 (src/models/)

  • 插件架构: 装饰器注册机制,新模型即插即用
  • 阶段感知: 训练/测试阶段区分,防止数据泄露
  • 多模型支持: LightGBM、CatBoost 等模型统一接口
  • 数据处理: 缺失值处理、缩尾、标准化、中性化等
  • 时序划分: WalkForward、ExpandingWindow 等时间序列划分策略

项目结构

ProStock/
├── src/
│   ├── config/              # 配置管理
│   │   ├── settings.py      # pydantic-settings 配置
│   │   └── __init__.py
│   │
│   ├── data/                # 数据获取与存储
│   │   ├── api_wrappers/    # Tushare API 封装
│   │   │   ├── api_daily.py         # 日线数据接口
│   │   │   ├── api_stock_basic.py   # 股票基础信息
│   │   │   └── api_trade_cal.py     # 交易日历
│   │   ├── client.py        # Tushare 客户端(含限流)
│   │   ├── config.py        # 数据模块配置
│   │   ├── db_manager.py    # DuckDB 表管理和同步
│   │   ├── db_inspector.py  # 数据库信息查看工具
│   │   ├── rate_limiter.py  # 令牌桶限流器
│   │   ├── storage.py       # DuckDB 存储核心
│   │   ├── sync.py          # 数据同步主逻辑
│   │   └── __init__.py
│   │
│   ├── factors/             # 因子计算框架
│   │   ├── base.py          # 因子基类(截面/时序)
│   │   ├── composite.py     # 组合因子和标量运算
│   │   ├── data_loader.py   # DuckDB 数据加载器
│   │   ├── data_spec.py     # 数据规格定义
│   │   ├── engine.py        # 因子执行引擎
│   │   └── __init__.py
│   │
│   ├── models/              # 模型训练框架
│   │   ├── core/            # 核心抽象
│   │   │   ├── base.py      # 处理器/模型/划分基类
│   │   │   └── splitter.py  # 时间序列划分策略
│   │   ├── models/          # 模型实现
│   │   │   └── models.py    # LightGBM、CatBoost
│   │   ├── processors/      # 数据处理器
│   │   │   └── processors.py # 标准化、缩尾、中性化等
│   │   ├── pipeline.py      # 处理流水线
│   │   ├── registry.py      # 插件注册中心
│   │   └── __init__.py
│   │
│   └── __init__.py
│
├── docs/                    # 文档
│   ├── factor_framework_design.md    # 因子框架设计
│   ├── ml_framework_design.md        # 模型框架设计
│   ├── db_sync_guide.md              # 数据同步指南
│   └── ...
│
├── data/                    # 数据存储DuckDB
│   ├── prostock.db          # 主数据库文件
│   └── stock_basic.csv      # 股票基础信息缓存
│
├── config/                  # 配置文件
│   └── .env.local           # 环境变量API Token等
│
└── tests/                   # 测试文件
    ├── test_sync.py
    └── factors/

快速开始

1. 安装依赖

⚠️ 本项目强制使用 uv 作为 Python 包管理器

# 安装 uv (如果尚未安装)
pip install uv

# 安装项目依赖
uv pip install -e .

2. 配置环境变量

创建 config/.env.local 文件:

TUSHARE_TOKEN=your_tushare_token_here
DATA_PATH=data
RATE_LIMIT=100
THREADS=10

3. 数据同步

# 首次同步 - 全量同步从20180101开始
uv run python -c "from src.data.sync import sync_all; sync_all(force_full=True)"

# 日常同步 - 增量同步(自动从最新日期开始)
uv run python -c "from src.data.sync import sync_all; sync_all()"

# 预览同步(检查需要同步的数据量)
uv run python -c "from src.data.sync import preview_sync; preview_sync()"

# 自定义线程数
uv run python -c "from src.data.sync import sync_all; sync_all(max_workers=20)"

4. 查看数据库状态

uv run python -c "from src.data.db_inspector import get_db_info; get_db_info()"

使用示例

因子计算

from src.factors import FactorEngine, DataLoader, DataSpec
from src.factors.base import CrossSectionalFactor, TimeSeriesFactor
import polars as pl

# 自定义截面因子PE排名
class PERankFactor(CrossSectionalFactor):
    name = "pe_rank"
    data_specs = [DataSpec("daily", ["ts_code", "trade_date", "pe"], lookback_days=1)]
    
    def compute(self, data) -> pl.Series:
        cs = data.get_cross_section()
        return cs["pe"].rank()

# 自定义时序因子20日移动平均
class MA20Factor(TimeSeriesFactor):
    name = "ma20"
    data_specs = [DataSpec("daily", ["ts_code", "trade_date", "close"], lookback_days=20)]
    
    def compute(self, data) -> pl.Series:
        return data.get_column("close").rolling_mean(window_size=20)

# 执行计算
loader = DataLoader(data_dir="data")
engine = FactorEngine(loader)

# 计算截面因子
pe_rank = PERankFactor()
result1 = engine.compute(pe_rank, start_date="20240101", end_date="20240131")

# 计算时序因子
ma20 = MA20Factor()
result2 = engine.compute(ma20, stock_codes=["000001.SZ"], 
                        start_date="20240101", end_date="20240131")

# 因子组合
combined = 0.5 * pe_rank + 0.3 * ma20

模型训练

from src.models import PluginRegistry, ProcessingPipeline
from src.models.core import PipelineStage
import polars as pl

# 创建处理流水线
pipeline = ProcessingPipeline([
    PluginRegistry.get_processor("dropna")(),
    PluginRegistry.get_processor("winsorizer")(lower=0.01, upper=0.99),
    PluginRegistry.get_processor("standard_scaler")(),
])

# 准备数据
data = pl.read_csv("features.csv")  # 包含特征和标签

# 划分训练/测试集
from src.models.core import WalkForwardSplit
splitter = WalkForwardSplit(train_window=252, test_window=21)

# 获取 LightGBM 模型
ModelClass = PluginRegistry.get_model("lightgbm")
model = ModelClass(task_type="regression", params={"n_estimators": 100})

# 训练循环
for train_idx, test_idx in splitter.split(data):
    train_data = data[train_idx]
    test_data = data[test_idx]
    
    # 数据处理
    X_train = pipeline.fit_transform(train_data.drop("target"))
    X_test = pipeline.transform(test_data.drop("target"))
    y_train = train_data["target"]
    y_test = test_data["target"]
    
    # 训练模型
    model.fit(X_train, y_train)
    predictions = model.predict(X_test)

核心设计

1. 数据防泄露机制

截面因子 (CrossSectionalFactor):

  • 防止日期泄露:每天只传入 [T-lookback+1, T] 数据
  • 允许股票间比较:传入当天所有股票数据
  • 典型应用PE排名、市值分位数、当日收益率排名

时序因子 (TimeSeriesFactor):

  • 防止股票泄露:每只股票单独计算
  • 允许历史数据访问:传入完整时间序列
  • 典型应用移动平均线、RSI、历史波动率

2. 插件注册机制

from src.models.registry import PluginRegistry

# 注册自定义处理器
@PluginRegistry.register_processor("my_processor")
class MyProcessor(BaseProcessor):
    stage = PipelineStage.TRAIN
    
    def fit(self, data):
        # 学习参数
        return self
    
    def transform(self, data):
        # 转换数据
        return data

# 使用
processor_class = PluginRegistry.get_processor("my_processor")
processor = processor_class()

3. 数据同步策略

智能增量同步:

from src.data.db_manager import SyncManager

manager = SyncManager()
result = manager.sync(
    table_name="daily",
    fetch_func=get_daily,
    start_date="20240101",
    end_date="20240131"
)
# 自动检测:表不存在→全量,表存在→增量

文档

开发规范

  • Python 版本: 3.10+
  • 代码风格: Google 风格文档字符串
  • 类型提示: 强制类型注解
  • 测试: pytest 框架
  • 包管理: uv (禁止直接使用 pip/python)

技术栈

  • 数据处理: Polars, Pandas, NumPy
  • 数据存储: DuckDB (嵌入式 OLAP 数据库)
  • API 接口: Tushare Pro
  • 机器学习: LightGBM, CatBoost, scikit-learn
  • 配置管理: pydantic-settings

许可证

MIT License

Description
No description provided
Readme 4.5 MiB
Languages
Python 88.9%
Jupyter Notebook 11.1%