a22bc2d282461c11957aa7409ebe3c238f19f3ac
- 删除 src/data/api_wrappers/api_daily.py (240行) - 更新 6 个文档文件,将 daily 表引用替换为 pro_bar - 同步 README.md 中的因子框架和训练模块示例 BREAKING CHANGE: api_daily 模块已移除,请使用 api_pro_bar 替代
ProStock
A股量化投资框架 - 从数据获取到模型训练的完整解决方案
功能特性
1. 数据层 (src/data/)
- 多源数据接入: Tushare API 集成,支持日线、股票基础信息、交易日历
- DuckDB 存储: 高性能嵌入式数据库,支持 SQL 查询下推
- 智能同步: 增量/全量同步策略,自动检测数据更新需求
- 速率控制: 令牌桶算法实现 API 限流
- 并发优化: ThreadPoolExecutor 多线程数据获取
2. 因子层 (src/factors/)
- 类型安全: 严格的截面因子 vs 时序因子区分
- 防泄露机制: 框架层面防止未来数据和跨股票数据泄露
- 因子组合: 支持因子加减乘除和标量运算
- 高性能计算: Polars 向量化操作,零拷贝数据导出
- 灵活扩展: 基类抽象便于自定义因子
3. 模型层 (src/models/)
- 插件架构: 装饰器注册机制,新模型即插即用
- 阶段感知: 训练/测试阶段区分,防止数据泄露
- 多模型支持: LightGBM、CatBoost 等模型统一接口
- 数据处理: 缺失值处理、缩尾、标准化、中性化等
- 时序划分: WalkForward、ExpandingWindow 等时间序列划分策略
项目结构
ProStock/
├── src/
│ ├── config/ # 配置管理
│ │ ├── settings.py # pydantic-settings 配置
│ │ └── __init__.py
│ │
│ ├── data/ # 数据获取与存储
│ │ ├── api_wrappers/ # Tushare API 封装
│ │ │ ├── api_pro_bar.py # Pro Bar行情数据接口(主用)
│ │ │ ├── api_stock_basic.py # 股票基础信息接口
│ │ │ ├── api_trade_cal.py # 交易日历接口
│ │ │ ├── api_bak_basic.py # 历史股票列表接口
│ │ │ ├── api_namechange.py # 股票名称变更接口
│ │ │ ├── api_stock_st.py # ST股票信息接口
│ │ │ ├── api_daily_basic.py # 每日指标接口
│ │ │ ├── api_stk_limit.py # 涨跌停价格接口
│ │ │ ├── financial_data/ # 财务数据接口
│ │ │ │ ├── api_income.py # 利润表接口
│ │ │ │ ├── api_balance.py # 资产负债表接口
│ │ │ │ ├── api_cashflow.py # 现金流量表接口
│ │ │ │ ├── api_fina_indicator.py # 财务指标接口
│ │ │ │ └── api_financial_sync.py # 财务数据同步调度中心
│ │ │ └── __init__.py
│ │ ├── client.py # Tushare 客户端(含限流)
│ │ ├── config.py # 数据模块配置
│ │ ├── db_manager.py # DuckDB 表管理和同步
│ │ ├── db_inspector.py # 数据库信息查看工具
│ │ ├── rate_limiter.py # 令牌桶限流器
│ │ ├── storage.py # DuckDB 存储核心
│ │ ├── sync.py # 数据同步主逻辑
│ │ └── __init__.py
│ │
│ ├── factors/ # 因子计算框架
│ │ ├── base.py # 因子基类(截面/时序)
│ │ ├── composite.py # 组合因子和标量运算
│ │ ├── data_loader.py # DuckDB 数据加载器
│ │ ├── data_spec.py # 数据规格定义
│ │ ├── engine.py # 因子执行引擎
│ │ └── __init__.py
│ │
│ ├── models/ # 模型训练框架
│ │ ├── core/ # 核心抽象
│ │ │ ├── base.py # 处理器/模型/划分基类
│ │ │ └── splitter.py # 时间序列划分策略
│ │ ├── models/ # 模型实现
│ │ │ └── models.py # LightGBM、CatBoost
│ │ ├── processors/ # 数据处理器
│ │ │ └── processors.py # 标准化、缩尾、中性化等
│ │ ├── pipeline.py # 处理流水线
│ │ ├── registry.py # 插件注册中心
│ │ └── __init__.py
│ │
│ └── __init__.py
│
├── docs/ # 文档
│ ├── factor_framework_design.md # 因子框架设计
│ ├── ml_framework_design.md # 模型框架设计
│ ├── db_sync_guide.md # 数据同步指南
│ └── ...
│
├── data/ # 数据存储(DuckDB)
│ ├── prostock.db # 主数据库文件
│ └── stock_basic.csv # 股票基础信息缓存
│
├── config/ # 配置文件
│ └── .env.local # 环境变量(API Token等)
│
└── tests/ # 测试文件
├── test_sync.py
└── factors/
快速开始
1. 安装依赖
⚠️ 本项目强制使用 uv 作为 Python 包管理器
# 安装 uv (如果尚未安装)
pip install uv
# 安装项目依赖
uv pip install -e .
2. 配置环境变量
创建 config/.env.local 文件:
TUSHARE_TOKEN=your_tushare_token_here
DATA_PATH=data
RATE_LIMIT=100
THREADS=10
3. 数据同步
# 首次同步 - 全量同步(从20180101开始)
uv run python -c "from src.data.sync import sync_all; sync_all(force_full=True)"
# 日常同步 - 增量同步(自动从最新日期开始)
uv run python -c "from src.data.sync import sync_all; sync_all()"
# 预览同步(检查需要同步的数据量)
uv run python -c "from src.data.sync import preview_sync; preview_sync()"
# 自定义线程数
uv run python -c "from src.data.sync import sync_all; sync_all(max_workers=20)"
4. 查看数据库状态
uv run python -c "from src.data.db_inspector import get_db_info; get_db_info()"
使用示例
因子计算
from src.factors import FactorEngine
from src.factors.api import close, ts_mean, cs_rank
import polars as pl
# 初始化引擎
engine = FactorEngine()
# 方式1:使用 DSL 表达式注册
engine.register("ma20", ts_mean(close, 20))
engine.register("price_rank", cs_rank(close))
# 方式2:使用字符串表达式(推荐)
engine.add_factor("ma20", "ts_mean(close, 20)")
engine.add_factor("alpha", "cs_rank(ts_mean(close, 5) - ts_mean(close, 20))")
# 方式3:从 metadata 查询(需先在 metadata 中定义)
engine.add_factor("mom_5d")
# 计算因子
result = engine.compute(
factor_names=["ma20", "price_rank"],
start_date="20240101",
end_date="20240131"
)
# 查看执行计划
plan = engine.preview_plan("ma20")
模型训练
from src.training import (
Trainer,
LightGBMModel,
DateSplitter,
StockPoolManager,
NullFiller,
Winsorizer,
StandardScaler,
STFilter,
check_data_quality,
)
from src.factors import FactorEngine
import polars as pl
# 1. 创建模型
model = LightGBMModel(params={
"objective": "regression",
"metric": "mae",
"num_leaves": 20,
"learning_rate": 0.01,
"n_estimators": 1000,
})
# 2. 准备因子数据
engine = FactorEngine()
engine.add_factor("ma5", "ts_mean(close, 5)")
engine.add_factor("ma20", "ts_mean(close, 20)")
# 计算全市场因子
data = engine.compute(
factor_names=["ma5", "ma20", "future_return_5"],
start_date="20200101",
end_date="20231231"
)
# 3. 创建数据处理器
processors = [
NullFiller(feature_cols=["ma5", "ma20"], strategy="mean"),
Winsorizer(feature_cols=["ma5", "ma20"], lower=0.01, upper=0.99),
StandardScaler(feature_cols=["ma5", "ma20"]),
]
# 4. 创建股票池筛选函数
def stock_pool_filter(df: pl.DataFrame) -> pl.Series:
"""筛选小市值股票"""
code_filter = (
~df["ts_code"].str.starts_with("300") & # 排除创业板
~df["ts_code"].str.starts_with("688") # 排除科创板
)
return code_filter
pool_manager = StockPoolManager(
filter_func=stock_pool_filter,
required_columns=["total_mv"],
)
# 5. 创建过滤器
st_filter = STFilter(data_router=engine.router)
# 6. 创建数据划分器
splitter = DateSplitter(
train_start="20200101",
train_end="20221231",
val_start="20230101",
val_end="20230630",
test_start="20230701",
test_end="20231231",
)
# 7. 创建训练器
trainer = Trainer(
model=model,
pool_manager=pool_manager,
processors=processors,
filters=[st_filter],
splitter=splitter,
target_col="future_return_5",
feature_cols=["ma5", "ma20"],
)
# 8. 执行训练
results = trainer.train(data)
# 9. 获取预测结果
predictions = trainer.get_results()
核心设计
1. 数据防泄露机制
截面因子 (CrossSectionalFactor):
- 防止日期泄露:每天只传入
[T-lookback+1, T]数据 - 允许股票间比较:传入当天所有股票数据
- 典型应用:PE排名、市值分位数、当日收益率排名
时序因子 (TimeSeriesFactor):
- 防止股票泄露:每只股票单独计算
- 允许历史数据访问:传入完整时间序列
- 典型应用:移动平均线、RSI、历史波动率
2. 插件注册机制
from src.models.registry import PluginRegistry
# 注册自定义处理器
@PluginRegistry.register_processor("my_processor")
class MyProcessor(BaseProcessor):
stage = PipelineStage.TRAIN
def fit(self, data):
# 学习参数
return self
def transform(self, data):
# 转换数据
return data
# 使用
processor_class = PluginRegistry.get_processor("my_processor")
processor = processor_class()
3. 数据同步策略
智能增量同步:
from src.data.db_manager import SyncManager
manager = SyncManager()
result = manager.sync(
table_name="daily",
fetch_func=get_daily,
start_date="20240101",
end_date="20240131"
)
# 自动检测:表不存在→全量,表存在→增量
文档
开发规范
- Python 版本: 3.10+
- 代码风格: Google 风格文档字符串
- 类型提示: 强制类型注解
- 测试: pytest 框架
- 包管理: uv (禁止直接使用 pip/python)
技术栈
- 数据处理: Polars, Pandas, NumPy
- 数据存储: DuckDB (嵌入式 OLAP 数据库)
- API 接口: Tushare Pro
- 机器学习: LightGBM, CatBoost, scikit-learn
- 配置管理: pydantic-settings
许可证
MIT License
Description
Languages
Python
88.9%
Jupyter Notebook
11.1%