# ProStock A股量化投资框架 - 从数据获取到模型训练的完整解决方案 ## 功能特性 ### 1. 数据层 (src/data/) - **多源数据接入**: Tushare API 集成,支持日线、股票基础信息、交易日历 - **DuckDB 存储**: 高性能嵌入式数据库,支持 SQL 查询下推 - **智能同步**: 增量/全量同步策略,自动检测数据更新需求 - **速率控制**: 令牌桶算法实现 API 限流 - **并发优化**: ThreadPoolExecutor 多线程数据获取 ### 2. 因子层 (src/factors/) - **类型安全**: 严格的截面因子 vs 时序因子区分 - **防泄露机制**: 框架层面防止未来数据和跨股票数据泄露 - **因子组合**: 支持因子加减乘除和标量运算 - **高性能计算**: Polars 向量化操作,零拷贝数据导出 - **灵活扩展**: 基类抽象便于自定义因子 ### 3. 模型层 (src/models/) - **插件架构**: 装饰器注册机制,新模型即插即用 - **阶段感知**: 训练/测试阶段区分,防止数据泄露 - **多模型支持**: LightGBM、CatBoost 等模型统一接口 - **数据处理**: 缺失值处理、缩尾、标准化、中性化等 - **时序划分**: WalkForward、ExpandingWindow 等时间序列划分策略 ## 项目结构 ``` ProStock/ ├── src/ │ ├── config/ # 配置管理 │ │ ├── settings.py # pydantic-settings 配置 │ │ └── __init__.py │ │ │ ├── data/ # 数据获取与存储 │ │ ├── api_wrappers/ # Tushare API 封装 │ │ │ ├── api_daily.py # 日线数据接口 │ │ │ ├── api_stock_basic.py # 股票基础信息 │ │ │ └── api_trade_cal.py # 交易日历 │ │ ├── client.py # Tushare 客户端(含限流) │ │ ├── config.py # 数据模块配置 │ │ ├── db_manager.py # DuckDB 表管理和同步 │ │ ├── db_inspector.py # 数据库信息查看工具 │ │ ├── rate_limiter.py # 令牌桶限流器 │ │ ├── storage.py # DuckDB 存储核心 │ │ ├── sync.py # 数据同步主逻辑 │ │ └── __init__.py │ │ │ ├── factors/ # 因子计算框架 │ │ ├── base.py # 因子基类(截面/时序) │ │ ├── composite.py # 组合因子和标量运算 │ │ ├── data_loader.py # DuckDB 数据加载器 │ │ ├── data_spec.py # 数据规格定义 │ │ ├── engine.py # 因子执行引擎 │ │ └── __init__.py │ │ │ ├── models/ # 模型训练框架 │ │ ├── core/ # 核心抽象 │ │ │ ├── base.py # 处理器/模型/划分基类 │ │ │ └── splitter.py # 时间序列划分策略 │ │ ├── models/ # 模型实现 │ │ │ └── models.py # LightGBM、CatBoost │ │ ├── processors/ # 数据处理器 │ │ │ └── processors.py # 标准化、缩尾、中性化等 │ │ ├── pipeline.py # 处理流水线 │ │ ├── registry.py # 插件注册中心 │ │ └── __init__.py │ │ │ └── __init__.py │ ├── docs/ # 文档 │ ├── factor_framework_design.md # 因子框架设计 │ ├── ml_framework_design.md # 模型框架设计 │ ├── db_sync_guide.md # 数据同步指南 │ └── ... │ ├── data/ # 数据存储(DuckDB) │ ├── prostock.db # 主数据库文件 │ └── stock_basic.csv # 股票基础信息缓存 │ ├── config/ # 配置文件 │ └── .env.local # 环境变量(API Token等) │ └── tests/ # 测试文件 ├── test_sync.py └── factors/ ``` ## 快速开始 ### 1. 安装依赖 **⚠️ 本项目强制使用 uv 作为 Python 包管理器** ```bash # 安装 uv (如果尚未安装) pip install uv # 安装项目依赖 uv pip install -e . ``` ### 2. 配置环境变量 创建 `config/.env.local` 文件: ```bash TUSHARE_TOKEN=your_tushare_token_here DATA_PATH=data RATE_LIMIT=100 THREADS=10 ``` ### 3. 数据同步 ```bash # 首次同步 - 全量同步(从20180101开始) uv run python -c "from src.data.sync import sync_all; sync_all(force_full=True)" # 日常同步 - 增量同步(自动从最新日期开始) uv run python -c "from src.data.sync import sync_all; sync_all()" # 预览同步(检查需要同步的数据量) uv run python -c "from src.data.sync import preview_sync; preview_sync()" # 自定义线程数 uv run python -c "from src.data.sync import sync_all; sync_all(max_workers=20)" ``` ### 4. 查看数据库状态 ```bash uv run python -c "from src.data.db_inspector import get_db_info; get_db_info()" ``` ## 使用示例 ### 因子计算 ```python from src.factors import FactorEngine, DataLoader, DataSpec from src.factors.base import CrossSectionalFactor, TimeSeriesFactor import polars as pl # 自定义截面因子:PE排名 class PERankFactor(CrossSectionalFactor): name = "pe_rank" data_specs = [DataSpec("daily", ["ts_code", "trade_date", "pe"], lookback_days=1)] def compute(self, data) -> pl.Series: cs = data.get_cross_section() return cs["pe"].rank() # 自定义时序因子:20日移动平均 class MA20Factor(TimeSeriesFactor): name = "ma20" data_specs = [DataSpec("daily", ["ts_code", "trade_date", "close"], lookback_days=20)] def compute(self, data) -> pl.Series: return data.get_column("close").rolling_mean(window_size=20) # 执行计算 loader = DataLoader(data_dir="data") engine = FactorEngine(loader) # 计算截面因子 pe_rank = PERankFactor() result1 = engine.compute(pe_rank, start_date="20240101", end_date="20240131") # 计算时序因子 ma20 = MA20Factor() result2 = engine.compute(ma20, stock_codes=["000001.SZ"], start_date="20240101", end_date="20240131") # 因子组合 combined = 0.5 * pe_rank + 0.3 * ma20 ``` ### 模型训练 ```python from src.models import PluginRegistry, ProcessingPipeline from src.models.core import PipelineStage import polars as pl # 创建处理流水线 pipeline = ProcessingPipeline([ PluginRegistry.get_processor("dropna")(), PluginRegistry.get_processor("winsorizer")(lower=0.01, upper=0.99), PluginRegistry.get_processor("standard_scaler")(), ]) # 准备数据 data = pl.read_csv("features.csv") # 包含特征和标签 # 划分训练/测试集 from src.models.core import WalkForwardSplit splitter = WalkForwardSplit(train_window=252, test_window=21) # 获取 LightGBM 模型 ModelClass = PluginRegistry.get_model("lightgbm") model = ModelClass(task_type="regression", params={"n_estimators": 100}) # 训练循环 for train_idx, test_idx in splitter.split(data): train_data = data[train_idx] test_data = data[test_idx] # 数据处理 X_train = pipeline.fit_transform(train_data.drop("target")) X_test = pipeline.transform(test_data.drop("target")) y_train = train_data["target"] y_test = test_data["target"] # 训练模型 model.fit(X_train, y_train) predictions = model.predict(X_test) ``` ## 核心设计 ### 1. 数据防泄露机制 **截面因子 (CrossSectionalFactor)**: - 防止日期泄露:每天只传入 `[T-lookback+1, T]` 数据 - 允许股票间比较:传入当天所有股票数据 - 典型应用:PE排名、市值分位数、当日收益率排名 **时序因子 (TimeSeriesFactor)**: - 防止股票泄露:每只股票单独计算 - 允许历史数据访问:传入完整时间序列 - 典型应用:移动平均线、RSI、历史波动率 ### 2. 插件注册机制 ```python from src.models.registry import PluginRegistry # 注册自定义处理器 @PluginRegistry.register_processor("my_processor") class MyProcessor(BaseProcessor): stage = PipelineStage.TRAIN def fit(self, data): # 学习参数 return self def transform(self, data): # 转换数据 return data # 使用 processor_class = PluginRegistry.get_processor("my_processor") processor = processor_class() ``` ### 3. 数据同步策略 **智能增量同步**: ```python from src.data.db_manager import SyncManager manager = SyncManager() result = manager.sync( table_name="daily", fetch_func=get_daily, start_date="20240101", end_date="20240131" ) # 自动检测:表不存在→全量,表存在→增量 ``` ## 文档 - [因子框架设计](docs/factor_framework_design.md) - 因子计算架构详解 - [模型框架设计](docs/ml_framework_design.md) - 模型训练架构详解 - [数据同步指南](docs/db_sync_guide.md) - DuckDB 数据同步 API 说明 - [代码审查报告](docs/code_review_factors_20260222.md) - 因子框架代码审查 ## 开发规范 - **Python 版本**: 3.10+ - **代码风格**: Google 风格文档字符串 - **类型提示**: 强制类型注解 - **测试**: pytest 框架 - **包管理**: uv (禁止直接使用 pip/python) ## 技术栈 - **数据处理**: Polars, Pandas, NumPy - **数据存储**: DuckDB (嵌入式 OLAP 数据库) - **API 接口**: Tushare Pro - **机器学习**: LightGBM, CatBoost, scikit-learn - **配置管理**: pydantic-settings ## 许可证 MIT License