fix(factors): 修复 ts_corr/ts_cov 实现并添加 abs 函数支持
- 修复 ts_corr 和 ts_cov 使用 pl.rolling_corr/pl.rolling_cov 模块级函数 - 添加 abs 函数处理器到 translator - 扩展 notebook 中的因子定义(24 -> 49 个) - 更新 AGENTS.md 文档结构和 Training 模块说明
This commit is contained in:
457
AGENTS.md
457
AGENTS.md
@@ -6,9 +6,9 @@ A股量化投资框架 - Python 项目,用于量化股票投资分析。
|
||||
|
||||
**⚠️ 强制要求:所有沟通和思考过程必须使用中文。**
|
||||
|
||||
所有与 AI Agent 的交流必须使用中文
|
||||
代码中的注释和文档字符串使用中文
|
||||
禁止使用英文进行思考或沟通
|
||||
- 所有与 AI Agent 的交流必须使用中文
|
||||
- 代码中的注释和文档字符串使用中文
|
||||
- 禁止使用英文进行思考或沟通
|
||||
|
||||
|
||||
## 构建/检查/测试命令
|
||||
@@ -44,7 +44,7 @@ uv run pytest --cov=src --cov-report=term-missing
|
||||
|
||||
```bash
|
||||
# 禁止直接使用 python
|
||||
python -c "..." # 禁止!
|
||||
python -c "…" # 禁止!
|
||||
python script.py # 禁止!
|
||||
python -m pytest # 禁止!
|
||||
python -m pip install # 禁止!
|
||||
@@ -59,7 +59,7 @@ pip list # 禁止!
|
||||
|
||||
```bash
|
||||
# 运行 Python 代码
|
||||
uv run python -c "..." # ✅ 正确
|
||||
uv run python -c "…" # ✅ 正确
|
||||
uv run python script.py # ✅ 正确
|
||||
|
||||
# 安装依赖
|
||||
@@ -82,65 +82,79 @@ ProStock/
|
||||
│ │
|
||||
│ ├── data/ # 数据获取与存储
|
||||
│ │ ├── api_wrappers/ # Tushare API 封装
|
||||
│ │ │ ├── base_sync.py # 同步基础抽象类(BaseDataSync/StockBasedSync/DateBasedSync)
|
||||
│ │ │ ├── api_daily.py # 日线数据接口(DailySync)
|
||||
│ │ │ ├── api_pro_bar.py # Pro Bar 数据接口(ProBarSync)
|
||||
│ │ │ ├── base_sync.py # 同步基础抽象类
|
||||
│ │ │ ├── api_daily.py # 日线数据接口
|
||||
│ │ │ ├── api_pro_bar.py # Pro Bar 数据接口
|
||||
│ │ │ ├── api_stock_basic.py # 股票基础信息接口
|
||||
│ │ │ ├── api_trade_cal.py # 交易日历接口
|
||||
│ │ │ ├── api_bak_basic.py # 历史股票列表接口(BakBasicSync)
|
||||
│ │ │ ├── api_bak_basic.py # 历史股票列表接口
|
||||
│ │ │ ├── api_namechange.py # 股票名称变更接口
|
||||
│ │ │ ├── api_stock_st.py # ST股票信息接口
|
||||
│ │ │ ├── api_daily_basic.py # 每日指标接口
|
||||
│ │ │ ├── api_stk_limit.py # 涨跌停价格接口
|
||||
│ │ │ ├── financial_data/ # 财务数据接口
|
||||
│ │ │ │ ├── api_income.py # 利润表接口
|
||||
│ │ │ │ └── api_financial_sync.py # 财务数据同步
|
||||
│ │ │ │ ├── api_balance.py # 资产负债表接口
|
||||
│ │ │ │ ├── api_cashflow.py # 现金流量表接口
|
||||
│ │ │ │ ├── api_fina_indicator.py # 财务指标接口
|
||||
│ │ │ │ └── api_financial_sync.py # 财务数据同步调度中心
|
||||
│ │ │ └── __init__.py
|
||||
│ │ ├── __init__.py
|
||||
│ │ ├── client.py # Tushare API 客户端(带速率限制)
|
||||
│ │ ├── config.py # 数据模块配置
|
||||
│ │ ├── data_router.py # 数据路由器(factors/engine 专用)
|
||||
│ │ ├── db_inspector.py # 数据库信息查看工具
|
||||
│ │ ├── db_manager.py # DuckDB 表管理和同步
|
||||
│ │ ├── rate_limiter.py # 令牌桶速率限制器
|
||||
│ │ ├── storage.py # 数据存储核心
|
||||
│ │ ├── db_manager.py # DuckDB 表管理和同步
|
||||
│ │ ├── db_inspector.py # 数据库信息查看工具
|
||||
│ │ ├── sync.py # 数据同步调度中心
|
||||
│ │ └── utils.py # 数据模块工具函数
|
||||
│ │ ├── sync_registry.py # 同步器注册表
|
||||
│ │ ├── rate_limiter.py # 令牌桶速率限制器
|
||||
│ │ ├── catalog.py # 数据目录管理
|
||||
│ │ ├── config.py # 数据模块配置
|
||||
│ │ ├── utils.py # 数据模块工具函数
|
||||
│ │ └── financial_loader.py # 财务数据加载器
|
||||
│ │
|
||||
│ ├── factors/ # 因子计算框架(DSL 表达式驱动)
|
||||
│ │ ├── engine/ # 执行引擎子模块
|
||||
│ │ │ ├── __init__.py # 导出引擎组件
|
||||
│ │ │ ├── data_spec.py # 数据规格定义(DataSpec, ExecutionPlan)
|
||||
│ │ │ ├── data_spec.py # 数据规格定义
|
||||
│ │ │ ├── data_router.py # 数据路由器
|
||||
│ │ │ ├── planner.py # 执行计划生成器(ExecutionPlanner)
|
||||
│ │ │ ├── compute_engine.py # 计算引擎(ComputeEngine)
|
||||
│ │ │ └── factor_engine.py # 因子引擎统一入口(FactorEngine)
|
||||
│ │ │ ├── planner.py # 执行计划生成器
|
||||
│ │ │ ├── compute_engine.py # 计算引擎
|
||||
│ │ │ ├── schema_cache.py # 表结构缓存
|
||||
│ │ │ └── factor_engine.py # 因子引擎统一入口
|
||||
│ │ ├── __init__.py # 导出所有公开 API
|
||||
│ │ ├── dsl.py # DSL 表达式层 - 节点定义和运算符重载
|
||||
│ │ ├── api.py # API 层 - 常用符号(close/open等)和函数(ts_mean/cs_rank等)
|
||||
│ │ ├── api.py # API 层 - 常用符号和函数
|
||||
│ │ ├── compiler.py # AST 编译器 - 依赖提取
|
||||
│ │ ├── translator.py # Polars 表达式翻译器
|
||||
│ │ ├── parser.py # 字符串公式解析器(FormulaParser)
|
||||
│ │ ├── registry.py # 函数注册表(FunctionRegistry)
|
||||
│ │ └── exceptions.py # 异常定义(FormulaParseError等)
|
||||
│ │ ├── parser.py # 字符串公式解析器
|
||||
│ │ ├── registry.py # 函数注册表
|
||||
│ │ ├── decorators.py # 装饰器工具
|
||||
│ │ └── exceptions.py # 异常定义
|
||||
│ │
|
||||
│ ├── pipeline/ # 模型训练管道
|
||||
│ │ ├── __init__.py
|
||||
│ │ ├── pipeline.py # 处理流水线(ProcessingPipeline)
|
||||
│ │ ├── registry.py # 插件注册中心(PluginRegistry)
|
||||
│ │ ├── core/ # 核心抽象
|
||||
│ ├── training/ # 训练模块
|
||||
│ │ ├── core/ # 训练核心组件
|
||||
│ │ │ ├── __init__.py
|
||||
│ │ │ ├── base.py # 基类定义(BaseProcessor/BaseModel/BaseSplitter等)
|
||||
│ │ │ └── splitter.py # 时间序列划分策略
|
||||
│ │ ├── models/ # 模型实现
|
||||
│ │ │ ├── trainer.py # 训练器主类
|
||||
│ │ │ └── stock_pool_manager.py # 股票池管理器
|
||||
│ │ ├── components/ # 组件
|
||||
│ │ │ ├── base.py # 基础抽象类
|
||||
│ │ │ ├── splitters.py # 数据划分器
|
||||
│ │ │ ├── selectors.py # 股票选择器
|
||||
│ │ │ ├── filters.py # 数据过滤器
|
||||
│ │ │ ├── models/ # 模型实现
|
||||
│ │ │ │ ├── __init__.py
|
||||
│ │ │ │ └── lightgbm.py # LightGBM 模型
|
||||
│ │ │ └── processors/ # 数据处理器
|
||||
│ │ │ ├── __init__.py
|
||||
│ │ │ └── transforms.py # 变换处理器
|
||||
│ │ ├── config/ # 配置
|
||||
│ │ │ ├── __init__.py
|
||||
│ │ │ └── models.py # LightGBM、CatBoost 等
|
||||
│ │ └── processors/ # 数据处理器
|
||||
│ │ ├── __init__.py
|
||||
│ │ └── processors.py # 标准化、缩尾、中性化等
|
||||
│ │ │ └── config.py # 训练配置
|
||||
│ │ ├── registry.py # 组件注册中心
|
||||
│ │ └── __init__.py # 导出所有组件
|
||||
│ │
|
||||
│ └── training/ # 训练入口
|
||||
│ ├── __init__.py
|
||||
│ ├── pipeline.py # 训练流程配置
|
||||
│ └── output/ # 训练输出
|
||||
│ └── top_stocks.tsv # 推荐股票结果
|
||||
│ └── experiment/ # 实验代码
|
||||
│ └── regression.ipynb # 完整训练流程示例
|
||||
│
|
||||
├── tests/ # 测试文件
|
||||
│ ├── test_sync.py
|
||||
@@ -148,8 +162,6 @@ ProStock/
|
||||
│ ├── test_factor_engine.py
|
||||
│ ├── test_factor_integration.py
|
||||
│ ├── test_pro_bar.py
|
||||
│ ├── test_601117_factors.py
|
||||
│ ├── test_two_stocks_string_factors.py
|
||||
│ ├── test_db_manager.py
|
||||
│ ├── test_daily_storage.py
|
||||
│ ├── test_tushare_api.py
|
||||
@@ -183,6 +195,7 @@ import threading
|
||||
# 第三方包
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import polars as pl
|
||||
from tqdm import tqdm
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
@@ -250,7 +263,7 @@ except Exception as e:
|
||||
### 配置
|
||||
- 对所有配置使用 **pydantic-settings**
|
||||
- 从 `config/.env.local` 文件加载
|
||||
- 环境变量自动转换:`tushare_token` → `TUSHARE_TOKEN`
|
||||
- 环境变量自动转换:`tushare_token` -> `TUSHARE_TOKEN`
|
||||
- 对配置单例使用 `@lru_cache()`
|
||||
|
||||
### 数据存储
|
||||
@@ -291,7 +304,6 @@ except Exception as e:
|
||||
- `pydantic>=2.0.0`、`pydantic-settings>=2.0.0` - 配置
|
||||
- `tqdm>=4.65.0` - 进度条
|
||||
- `lightgbm>=4.0.0` - 机器学习模型
|
||||
- `catboost>=1.2.0` - 机器学习模型
|
||||
- `pytest` - 测试(开发)
|
||||
|
||||
### 环境变量
|
||||
@@ -315,124 +327,13 @@ uv run python -c "from src.data.sync import sync_all; sync_all(force_full=True)"
|
||||
# 自定义线程数
|
||||
uv run python -c "from src.data.sync import sync_all; sync_all(max_workers=20)"
|
||||
|
||||
# 同步财务数据
|
||||
uv run python -c "from src.data.api_wrappers.financial_data import sync_financial; sync_financial()"
|
||||
|
||||
# 运行因子计算测试
|
||||
uv run pytest tests/test_factor_engine.py -v
|
||||
```
|
||||
|
||||
## 架构变更历史
|
||||
|
||||
### v2.2 (2026-03-01) - 因子框架 DSL 化重构
|
||||
|
||||
#### 因子计算框架重构
|
||||
**变更**: 从基类继承方式迁移到 DSL 表达式方式
|
||||
**原因**:
|
||||
- 提供更直观的数学公式表达方式
|
||||
- 支持因子表达式的组合和嵌套
|
||||
- 更好的类型安全和编译期检查
|
||||
**架构变化**:
|
||||
- 新增 `dsl.py`: 表达式节点基类和运算符重载(Symbol、FunctionNode等)
|
||||
- 新增 `api.py`: 常用符号(close/open/volume等)和函数(ts_mean/cs_rank等)
|
||||
- 新增 `compiler.py`: AST 编译器,提取表达式依赖
|
||||
- 新增 `translator.py`: 将 DSL 表达式翻译为 Polars 表达式
|
||||
- 新增 `parser.py`: 字符串公式解析器(FormulaParser),支持从字符串解析 DSL 表达式
|
||||
- 新增 `registry.py`: 函数注册表(FunctionRegistry),管理字符串函数名到 Python 函数的映射
|
||||
- 新增 `exceptions.py`: 公式解析异常定义(FormulaParseError、UnknownFunctionError等)
|
||||
- 重构 `engine/` 子模块:
|
||||
- `factor_engine.py`: 因子引擎统一入口(FactorEngine)
|
||||
- `data_spec.py`: 数据规格定义(DataSpec, ExecutionPlan)
|
||||
- `data_router.py`: 数据路由器
|
||||
- `planner.py`: 执行计划生成器(ExecutionPlanner)
|
||||
- `compute_engine.py`: 计算引擎(ComputeEngine)
|
||||
- 移除: `base.py`、`composite.py`、`data_loader.py`、根目录的 `data_spec.py`
|
||||
- 移除: `factors/momentum/` 和 `factors/financial/` 子目录
|
||||
**使用方式对比**:
|
||||
```python
|
||||
# 旧方式(基类继承)
|
||||
class MA20Factor(TimeSeriesFactor):
|
||||
name = "ma20"
|
||||
data_specs = [DataSpec("daily", ["close"], 20)]
|
||||
def compute(self, data):
|
||||
return data.get_column("close").rolling_mean(20)
|
||||
|
||||
# 新方式(DSL 表达式)
|
||||
from src.factors.api import close, ts_mean
|
||||
ma20 = ts_mean(close, 20) # 直接编写数学表达式
|
||||
|
||||
engine = FactorEngine()
|
||||
engine.register("ma20", ma20)
|
||||
result = engine.compute(["ma20"], "20240101", "20240131")
|
||||
|
||||
# 字符串公式解析(Phase 1 新增)
|
||||
from src.factors import FormulaParser, FunctionRegistry
|
||||
parser = FormulaParser(FunctionRegistry())
|
||||
expr = parser.parse("ts_mean(close, 20) / close") # 从字符串解析
|
||||
```
|
||||
|
||||
#### data 模块补充完善
|
||||
**新增文件**:
|
||||
- `api_wrappers/base_sync.py`: 数据同步基础抽象类(BaseDataSync、StockBasedSync、DateBasedSync)
|
||||
- `data_router.py`: 数据路由器(已集成到 factors/engine.py 中的 DataRouter)
|
||||
- `utils.py`: 日期工具函数(get_today_date、get_next_date、is_quarter_end等)
|
||||
**影响**: 数据同步逻辑更加规范化,支持按股票和按日期两种同步模式
|
||||
|
||||
### v2.1 (2026-02-28) - 同步模块规范更新
|
||||
**变更**: 明确 `sync.py` 只包含每日更新的数据同步
|
||||
**原因**: 区分高频(每日)和低频(季度/年度)数据,避免不必要的 API 调用
|
||||
**规范**:
|
||||
- `sync.py` / `sync_all_data()`: **仅包含每日更新的数据**
|
||||
- 日线数据 (`api_daily`)
|
||||
- Pro Bar 数据 (`api_pro_bar`)
|
||||
- 交易日历 (`api_trade_cal`)
|
||||
- 股票基本信息 (`api_stock_basic`)
|
||||
- 历史股票列表 (`api_bak_basic`)
|
||||
|
||||
- **不应放入 `sync.py` 的季度/低频数据**:
|
||||
- 财务数据 (`financial_data/` 目录): 利润表、资产负债表、现金流量表等
|
||||
- 名称变更 (`api_namechange`): 已移除自动同步,建议手动定期同步
|
||||
|
||||
- **季度数据同步方式**:
|
||||
```python
|
||||
# 财务数据单独同步(不在 sync_all_data 中)
|
||||
from src.data.api_wrappers.financial_data.api_financial_sync import sync_financial
|
||||
sync_financial() # 增量同步利润表
|
||||
|
||||
# 名称变更手动同步
|
||||
from src.data.api_wrappers import sync_namechange
|
||||
sync_namechange(force=True)
|
||||
```
|
||||
|
||||
### v2.0 (2026-02-23) - 重要更新
|
||||
|
||||
#### 存储层重构
|
||||
**变更**: 从 HDF5 迁移到 DuckDB
|
||||
**原因**: DuckDB 提供更好的查询性能、SQL 下推能力、并发支持
|
||||
**影响**: 所有数据表现在使用 DuckDB 存储,旧 HDF5 文件可手动迁移
|
||||
|
||||
#### Sync 类迁移
|
||||
**变更**: `DataSync` 类从 `sync.py` 迁移到 `api_daily.py`
|
||||
**原因**: 实现代码职责分离,每个 API 文件包含自己的同步逻辑
|
||||
**影响**:
|
||||
- `sync.py` 保留为调度中心
|
||||
- `api_daily.py` 包含 `DailySync` 类和 `sync_daily` 函数
|
||||
|
||||
#### 新增模块
|
||||
**pipeline 模块**: 机器学习流水线组件(处理器、模型、划分策略)
|
||||
**training 模块**: 训练入口程序
|
||||
**factors/momentum**: 动量因子(MA、收益率排名)
|
||||
**factors/financial**: 财务因子框架
|
||||
**data/utils.py**: 日期工具函数集中管理
|
||||
|
||||
#### 新增 API 接口
|
||||
`api_namechange.py`: 股票曾用名接口(手动同步)
|
||||
`api_bak_basic.py`: 历史股票列表接口
|
||||
|
||||
#### 工具函数统一
|
||||
`get_today_date()`、`get_next_date()`、`DEFAULT_START_DATE` 等函数统一在 `src/data/utils.py` 中管理
|
||||
其他模块应从 `utils.py` 导入这些函数,避免重复定义
|
||||
|
||||
其他模块应从 `utils.py` 导入这些函数,避免重复定义
|
||||
|
||||
|
||||
## Factors 框架设计说明
|
||||
|
||||
### 架构层次
|
||||
@@ -468,58 +369,28 @@ Engine (engine/) <- 执行引擎
|
||||
数据层 (data_router.py + DuckDB) <- 数据获取和存储
|
||||
```
|
||||
|
||||
### 使用方式
|
||||
|
||||
#### 1. 基础表达式(DSL 方式)
|
||||
|
||||
```python
|
||||
from src.factors.api import close, open, ts_mean, cs_rank
|
||||
|
||||
# 定义因子表达式(惰性计算)
|
||||
ma20 = ts_mean(close, 20) # 20日移动平均
|
||||
price_rank = cs_rank(close) # 收盘价截面排名
|
||||
|
||||
# 组合运算
|
||||
alpha = ma20 * 0.6 + price_rank * 0.4
|
||||
```
|
||||
|
||||
#### 2. 字符串公式解析(Phase 1 新增)
|
||||
|
||||
```python
|
||||
from src.factors import FormulaParser, FunctionRegistry
|
||||
|
||||
# 创建解析器(自动加载 api.py 中的所有函数)
|
||||
parser = FormulaParser(FunctionRegistry())
|
||||
|
||||
# 从字符串解析公式
|
||||
expr = parser.parse("ts_mean(close, 20) / close")
|
||||
complex_expr = parser.parse("cs_rank(ts_mean(close, 5) - ts_mean(close, 20))")
|
||||
|
||||
# 支持完整运算符和函数调用
|
||||
expr2 = parser.parse("(close - open) / open * 100") # 涨跌幅
|
||||
expr3 = parser.parse("ts_corr(close, volume, 20)") # 量价相关性
|
||||
```
|
||||
|
||||
#### 3. 注册和执行
|
||||
### FactorEngine 核心 API
|
||||
|
||||
```python
|
||||
from src.factors import FactorEngine
|
||||
|
||||
# 初始化引擎
|
||||
engine = FactorEngine()
|
||||
|
||||
# 注册 DSL 表达式
|
||||
engine.register("ma20", ma20)
|
||||
engine.register("price_rank", price_rank)
|
||||
# 方式1: 使用 DSL 表达式
|
||||
from src.factors.api import close, ts_mean, cs_rank
|
||||
engine.register("ma20", ts_mean(close, 20))
|
||||
engine.register("price_rank", cs_rank(close))
|
||||
|
||||
# 或注册字符串解析的表达式
|
||||
engine.register("alpha", parser.parse("ma20 * 0.6 + price_rank * 0.4"))
|
||||
# 方式2: 使用字符串表达式(推荐)
|
||||
engine.add_factor("ma20", "ts_mean(close, 20)")
|
||||
engine.add_factor("alpha", "cs_rank(ts_mean(close, 5) - ts_mean(close, 20))")
|
||||
|
||||
# 执行计算
|
||||
result = engine.compute(
|
||||
factor_names=["ma20", "price_rank"],
|
||||
start_date="20240101",
|
||||
end_date="20240131",
|
||||
)
|
||||
# 计算因子
|
||||
result = engine.compute(["ma20", "price_rank"], "20240101", "20240131")
|
||||
|
||||
# 查看已注册因子
|
||||
print(engine.list_registered())
|
||||
```
|
||||
|
||||
### 支持的函数
|
||||
@@ -584,7 +455,6 @@ expr4 = ts_mean(cs_rank(close), 20) # 排名后的20日平滑
|
||||
- `EmptyExpressionError` - 空表达式错误
|
||||
- `DuplicateFunctionError` - 函数重复注册错误
|
||||
|
||||
示例:
|
||||
```python
|
||||
from src.factors import FormulaParser, FunctionRegistry, UnknownFunctionError
|
||||
|
||||
@@ -596,7 +466,183 @@ except UnknownFunctionError as e:
|
||||
```
|
||||
|
||||
|
||||
## AI 行为准则
|
||||
## Training 模块设计说明
|
||||
|
||||
### 架构概述
|
||||
|
||||
Training 模块位于 `src/training/` 目录,负责从因子数据到模型训练、预测的完整流程。采用组件化设计,支持数据处理器、模型、过滤器、股票池管理器的灵活组合。
|
||||
|
||||
```
|
||||
src/training/
|
||||
├── core/
|
||||
│ ├── trainer.py # Trainer 主类
|
||||
│ └── stock_pool_manager.py # 股票池管理器
|
||||
├── components/
|
||||
│ ├── base.py # BaseModel、BaseProcessor 抽象基类
|
||||
│ ├── splitters.py # DateSplitter 日期划分器
|
||||
│ ├── filters.py # STFilter 等过滤器
|
||||
│ ├── models/
|
||||
│ │ └── lightgbm.py # LightGBMModel
|
||||
│ └── processors/
|
||||
│ └── transforms.py # 数据处理器实现
|
||||
├── config/
|
||||
│ └── config.py # TrainingConfig
|
||||
└── registry.py # 组件注册中心
|
||||
```
|
||||
|
||||
### Trainer 核心流程
|
||||
|
||||
```python
|
||||
from src.training import Trainer, DateSplitter, StockPoolManager
|
||||
from src.training.components.models import LightGBMModel
|
||||
from src.training.components.processors import Winsorizer, StandardScaler
|
||||
from src.training.components.filters import STFilter
|
||||
import polars as pl
|
||||
|
||||
# 1. 创建模型
|
||||
model = LightGBMModel(params={
|
||||
"objective": "regression",
|
||||
"metric": "mae",
|
||||
"num_leaves": 20,
|
||||
"learning_rate": 0.01,
|
||||
"n_estimators": 1000,
|
||||
})
|
||||
|
||||
# 2. 创建数据划分器(正确的 train/val/test 三分法)
|
||||
splitter = DateSplitter(
|
||||
train_start="20200101",
|
||||
train_end="20231231",
|
||||
val_start="20240101",
|
||||
val_end="20241231",
|
||||
test_start="20250101",
|
||||
test_end="20261231",
|
||||
)
|
||||
|
||||
# 3. 创建数据处理器
|
||||
processors = [
|
||||
NullFiller(strategy="mean"),
|
||||
Winsorizer(lower=0.01, upper=0.99),
|
||||
StandardScaler(exclude_cols=["ts_code", "trade_date", "target"]),
|
||||
]
|
||||
|
||||
# 4. 创建股票池筛选函数
|
||||
def stock_pool_filter(df: pl.DataFrame) -> pl.Series:
|
||||
"""筛选小市值股票,排除创业板/科创板/北交所"""
|
||||
code_filter = (
|
||||
~df["ts_code"].str.starts_with("300") & # 排除创业板
|
||||
~df["ts_code"].str.starts_with("688") & # 排除科创板
|
||||
~df["ts_code"].str.starts_with("8") & # 排除北交所
|
||||
~df["ts_code"].str.starts_with("9") &
|
||||
~df["ts_code"].str.starts_with("4")
|
||||
)
|
||||
valid_df = df.filter(code_filter)
|
||||
n = min(1000, len(valid_df))
|
||||
small_cap_codes = valid_df.sort("total_mv").head(n)["ts_code"]
|
||||
return df["ts_code"].is_in(small_cap_codes)
|
||||
|
||||
pool_manager = StockPoolManager(
|
||||
filter_func=stock_pool_filter,
|
||||
required_columns=["total_mv"],
|
||||
data_router=engine.router,
|
||||
)
|
||||
|
||||
# 5. 创建 ST 过滤器
|
||||
st_filter = STFilter(data_router=engine.router)
|
||||
|
||||
# 6. 创建训练器
|
||||
trainer = Trainer(
|
||||
model=model,
|
||||
pool_manager=pool_manager,
|
||||
processors=processors,
|
||||
filters=[st_filter],
|
||||
splitter=splitter,
|
||||
target_col="future_return_5",
|
||||
feature_cols=["ma_5", "ma_20", "volume_ratio", "roe"],
|
||||
)
|
||||
|
||||
# 7. 执行训练
|
||||
trainer.train(data)
|
||||
|
||||
# 8. 获取结果
|
||||
results = trainer.get_results()
|
||||
```
|
||||
|
||||
### 数据处理器
|
||||
|
||||
**NullFiller** - 缺失值填充:
|
||||
```python
|
||||
from src.training.components.processors import NullFiller
|
||||
|
||||
# 使用 0 填充
|
||||
filler = NullFiller(strategy="zero")
|
||||
|
||||
# 使用均值填充(每天独立计算截面均值)
|
||||
filler = NullFiller(strategy="mean", by_date=True)
|
||||
|
||||
# 使用指定值填充
|
||||
filler = NullFiller(strategy="value", fill_value=-999)
|
||||
```
|
||||
|
||||
**Winsorizer** - 缩尾处理:
|
||||
```python
|
||||
from src.training.components.processors import Winsorizer
|
||||
|
||||
# 全局缩尾(默认)
|
||||
winsorizer = Winsorizer(lower=0.01, upper=0.99, by_date=False)
|
||||
|
||||
# 每天独立缩尾
|
||||
winsorizer = Winsorizer(lower=0.01, upper=0.99, by_date=True)
|
||||
```
|
||||
|
||||
**StandardScaler** - 标准化:
|
||||
```python
|
||||
from src.training.components.processors import StandardScaler
|
||||
|
||||
# 全局标准化(学习训练集的均值和标准差)
|
||||
scaler = StandardScaler(exclude_cols=["ts_code", "trade_date", "target"])
|
||||
```
|
||||
|
||||
**CrossSectionalStandardScaler** - 截面标准化:
|
||||
```python
|
||||
from src.training.components.processors import CrossSectionalStandardScaler
|
||||
|
||||
# 每天独立标准化(不需要 fit)
|
||||
cs_scaler = CrossSectionalStandardScaler(
|
||||
exclude_cols=["ts_code", "trade_date", "target"],
|
||||
date_col="trade_date",
|
||||
)
|
||||
```
|
||||
|
||||
### 组件注册机制
|
||||
|
||||
```python
|
||||
from src.training.registry import register_model, register_processor
|
||||
from src.training.components.base import BaseModel, BaseProcessor
|
||||
|
||||
# 注册自定义模型
|
||||
@register_model("custom_model")
|
||||
class CustomModel(BaseModel):
|
||||
name = "custom_model"
|
||||
|
||||
def fit(self, X, y):
|
||||
# 训练逻辑
|
||||
return self
|
||||
|
||||
def predict(self, X):
|
||||
# 预测逻辑
|
||||
return predictions
|
||||
|
||||
# 注册自定义处理器
|
||||
@register_processor("custom_processor")
|
||||
class CustomProcessor(BaseProcessor):
|
||||
name = "custom_processor"
|
||||
|
||||
def transform(self, X):
|
||||
# 转换逻辑
|
||||
return X
|
||||
```
|
||||
|
||||
|
||||
## AI 行为准则
|
||||
|
||||
### LSP 检测报错处理
|
||||
@@ -649,3 +695,4 @@ LSP 报错:Syntax error on line 45
|
||||
4. **检查方法**
|
||||
- 使用正则表达式搜索 emoji:`[\U0001F600-\U0001F64F\U0001F300-\U0001F5FF\U0001F680-\U0001F6FF\U0001F1E0-\U0001F1FF\u2600-\u26FF\u2700-\u27BF]`
|
||||
- 提交前自查,确保无 emoji 混入代码
|
||||
|
||||
|
||||
Reference in New Issue
Block a user