refactor(training): 重构股票池管理 API 并更新训练流程
- 移除 StockFilterConfig/MarketCapSelectorConfig,改用 StockPoolManager + filter_func
- Trainer 支持 train/val/test 三分法划分
- 更新 regression.ipynb 适配新 API
- 删除已弃用的 test_selectors.py,后续补充 StockPoolManager 测试
This commit is contained in:
File diff suppressed because one or more lines are too long
@@ -17,11 +17,8 @@ from src.training.registry import (
|
|||||||
# 数据划分器
|
# 数据划分器
|
||||||
from src.training.components.splitters import DateSplitter
|
from src.training.components.splitters import DateSplitter
|
||||||
|
|
||||||
# 股票池选择器配置
|
# 股票池选择器配置(已迁移到 StockPoolManager,保留文件占位)
|
||||||
from src.training.components.selectors import (
|
# from src.training.components.selectors import ...
|
||||||
MarketCapSelectorConfig,
|
|
||||||
StockFilterConfig,
|
|
||||||
)
|
|
||||||
|
|
||||||
# 数据处理器
|
# 数据处理器
|
||||||
from src.training.components.processors import (
|
from src.training.components.processors import (
|
||||||
@@ -54,9 +51,9 @@ __all__ = [
|
|||||||
"register_processor",
|
"register_processor",
|
||||||
# 数据划分器
|
# 数据划分器
|
||||||
"DateSplitter",
|
"DateSplitter",
|
||||||
# 股票池选择器配置
|
# 股票池选择器配置(已迁移,保留注释占位)
|
||||||
"StockFilterConfig",
|
# "StockFilterConfig", # 已删除,使用 StockPoolManager + filter_func 替代
|
||||||
"MarketCapSelectorConfig",
|
# "MarketCapSelectorConfig", # 已删除,使用 StockPoolManager + required_factors 替代
|
||||||
# 数据处理器
|
# 数据处理器
|
||||||
"NullFiller",
|
"NullFiller",
|
||||||
"StandardScaler",
|
"StandardScaler",
|
||||||
|
|||||||
@@ -9,11 +9,8 @@ from src.training.components.base import BaseModel, BaseProcessor
|
|||||||
# 数据划分器
|
# 数据划分器
|
||||||
from src.training.components.splitters import DateSplitter
|
from src.training.components.splitters import DateSplitter
|
||||||
|
|
||||||
# 股票池选择器配置
|
# 股票池选择器配置(已迁移到 StockPoolManager)
|
||||||
from src.training.components.selectors import (
|
# from src.training.components.selectors import ... # 已删除
|
||||||
MarketCapSelectorConfig,
|
|
||||||
StockFilterConfig,
|
|
||||||
)
|
|
||||||
|
|
||||||
# 数据处理器
|
# 数据处理器
|
||||||
from src.training.components.processors import (
|
from src.training.components.processors import (
|
||||||
@@ -29,8 +26,8 @@ __all__ = [
|
|||||||
"BaseModel",
|
"BaseModel",
|
||||||
"BaseProcessor",
|
"BaseProcessor",
|
||||||
"DateSplitter",
|
"DateSplitter",
|
||||||
"StockFilterConfig",
|
# "StockFilterConfig", # 已删除
|
||||||
"MarketCapSelectorConfig",
|
# "MarketCapSelectorConfig", # 已删除
|
||||||
"StandardScaler",
|
"StandardScaler",
|
||||||
"CrossSectionalStandardScaler",
|
"CrossSectionalStandardScaler",
|
||||||
"Winsorizer",
|
"Winsorizer",
|
||||||
|
|||||||
@@ -1,81 +1,20 @@
|
|||||||
"""股票池选择器配置
|
"""股票池选择器配置
|
||||||
|
|
||||||
提供股票过滤和市值选择的配置类。
|
此模块目前为空,股票池筛选功能已迁移到 StockPoolManager。
|
||||||
|
所有筛选逻辑通过传入自定义函数实现。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from dataclasses import dataclass
|
# 旧配置类已删除:
|
||||||
from typing import List, Optional
|
# - StockFilterConfig (使用 filter_func 替代)
|
||||||
|
# - MarketCapSelectorConfig (使用 filter_func + required_factors 替代)
|
||||||
|
#
|
||||||
@dataclass
|
# 新的使用方式:
|
||||||
class StockFilterConfig:
|
# from src.training import StockPoolManager
|
||||||
"""股票过滤器配置
|
#
|
||||||
|
# def my_filter(df: pl.DataFrame) -> pl.Series:
|
||||||
用于过滤掉不需要的股票(如创业板、科创板等)。
|
# return df["total_mv"] > 1e9
|
||||||
基于股票代码进行过滤,不依赖外部数据。
|
#
|
||||||
|
# pool_manager = StockPoolManager(
|
||||||
Attributes:
|
# filter_func=my_filter,
|
||||||
exclude_cyb: 是否排除创业板(300xxx, 301xxx)
|
# required_columns=["total_mv"],
|
||||||
exclude_kcb: 是否排除科创板(688xxx)
|
# )
|
||||||
exclude_bj: 是否排除北交所(.BJ 后缀)
|
|
||||||
exclude_st: 是否排除ST股票(需要外部数据支持)
|
|
||||||
"""
|
|
||||||
|
|
||||||
exclude_cyb: bool = True
|
|
||||||
exclude_kcb: bool = True
|
|
||||||
exclude_bj: bool = True
|
|
||||||
exclude_st: bool = True
|
|
||||||
|
|
||||||
def filter_codes(self, codes: List[str]) -> List[str]:
|
|
||||||
"""应用过滤条件,返回过滤后的股票代码列表
|
|
||||||
|
|
||||||
Args:
|
|
||||||
codes: 原始股票代码列表
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
过滤后的股票代码列表
|
|
||||||
|
|
||||||
Note:
|
|
||||||
ST 股票过滤需要额外数据,在 StockPoolManager 中处理。
|
|
||||||
此方法仅基于代码前缀进行过滤。
|
|
||||||
"""
|
|
||||||
result = []
|
|
||||||
for code in codes:
|
|
||||||
# 排除创业板(300xxx, 301xxx)
|
|
||||||
if self.exclude_cyb and code.startswith(("300", "301")):
|
|
||||||
continue
|
|
||||||
# 排除科创板(688xxx)
|
|
||||||
if self.exclude_kcb and code.startswith("688"):
|
|
||||||
continue
|
|
||||||
# 排除北交所(.BJ 后缀)
|
|
||||||
if self.exclude_bj and code.endswith(".BJ"):
|
|
||||||
continue
|
|
||||||
result.append(code)
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class MarketCapSelectorConfig:
|
|
||||||
"""市值选择器配置
|
|
||||||
|
|
||||||
每日独立选择市值最大或最小的 n 只股票。
|
|
||||||
市值数据从 daily_basic 表独立获取,仅用于筛选。
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
enabled: 是否启用选择
|
|
||||||
n: 选择前 n 只
|
|
||||||
ascending: False=最大市值, True=最小市值
|
|
||||||
market_cap_col: 市值列名(来自 daily_basic)
|
|
||||||
"""
|
|
||||||
|
|
||||||
enabled: bool = True
|
|
||||||
n: int = 100
|
|
||||||
ascending: bool = False
|
|
||||||
market_cap_col: str = "total_mv"
|
|
||||||
|
|
||||||
def __post_init__(self):
|
|
||||||
"""验证配置参数"""
|
|
||||||
if self.n <= 0:
|
|
||||||
raise ValueError(f"n 必须是正整数,得到: {self.n}")
|
|
||||||
if not self.market_cap_col:
|
|
||||||
raise ValueError("market_cap_col 不能为空")
|
|
||||||
|
|||||||
@@ -4,15 +4,13 @@
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from src.training.config.config import (
|
from src.training.config.config import (
|
||||||
MarketCapSelectorConfig,
|
|
||||||
ProcessorConfig,
|
ProcessorConfig,
|
||||||
StockFilterConfig,
|
|
||||||
TrainingConfig,
|
TrainingConfig,
|
||||||
)
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"TrainingConfig",
|
"TrainingConfig",
|
||||||
"StockFilterConfig",
|
# "StockFilterConfig", # 已删除
|
||||||
"MarketCapSelectorConfig",
|
# "MarketCapSelectorConfig", # 已删除
|
||||||
"ProcessorConfig",
|
"ProcessorConfig",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -10,26 +10,6 @@ from pydantic import Field, validator
|
|||||||
from pydantic_settings import BaseSettings
|
from pydantic_settings import BaseSettings
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class StockFilterConfig:
|
|
||||||
"""股票过滤器配置"""
|
|
||||||
|
|
||||||
exclude_cyb: bool = True # 排除创业板
|
|
||||||
exclude_kcb: bool = True # 排除科创板
|
|
||||||
exclude_bj: bool = True # 排除北交所
|
|
||||||
exclude_st: bool = True # 排除ST股票
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class MarketCapSelectorConfig:
|
|
||||||
"""市值选择器配置"""
|
|
||||||
|
|
||||||
enabled: bool = True # 是否启用
|
|
||||||
n: int = 100 # 选择前 n 只
|
|
||||||
ascending: bool = False # False=最大市值, True=最小市值
|
|
||||||
market_cap_col: str = "total_mv" # 市值列名(来自 daily_basic)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ProcessorConfig:
|
class ProcessorConfig:
|
||||||
"""处理器配置"""
|
"""处理器配置"""
|
||||||
@@ -56,25 +36,6 @@ class TrainingConfig(BaseSettings):
|
|||||||
test_start: str = Field(..., description="测试期开始 YYYYMMDD")
|
test_start: str = Field(..., description="测试期开始 YYYYMMDD")
|
||||||
test_end: str = Field(..., description="测试期结束 YYYYMMDD")
|
test_end: str = Field(..., description="测试期结束 YYYYMMDD")
|
||||||
|
|
||||||
# === 股票池配置 ===
|
|
||||||
stock_filter: StockFilterConfig = Field(
|
|
||||||
default_factory=lambda: StockFilterConfig(
|
|
||||||
exclude_cyb=True,
|
|
||||||
exclude_kcb=True,
|
|
||||||
exclude_bj=True,
|
|
||||||
exclude_st=True,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
stock_selector: Optional[MarketCapSelectorConfig] = Field(
|
|
||||||
default_factory=lambda: MarketCapSelectorConfig(
|
|
||||||
enabled=True,
|
|
||||||
n=100,
|
|
||||||
ascending=False,
|
|
||||||
market_cap_col="total_mv",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
# 注意:如果 stock_selector = None,则跳过市值选择
|
|
||||||
|
|
||||||
# === 模型配置 ===
|
# === 模型配置 ===
|
||||||
model_type: str = "lightgbm"
|
model_type: str = "lightgbm"
|
||||||
model_params: Dict[str, Any] = Field(default_factory=dict)
|
model_params: Dict[str, Any] = Field(default_factory=dict)
|
||||||
|
|||||||
@@ -1,57 +1,63 @@
|
|||||||
"""股票池管理器
|
"""股票池管理器
|
||||||
|
|
||||||
每日独立筛选股票池,市值数据从 daily_basic 表独立获取。
|
支持使用自定义函数和因子表达式进行每日股票池筛选。
|
||||||
|
临时计算的因子仅在筛选阶段使用,绝不泄露到训练数据。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import TYPE_CHECKING, Dict, List, Optional
|
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Set, Tuple
|
||||||
|
|
||||||
import polars as pl
|
import polars as pl
|
||||||
|
|
||||||
from src.training.components.selectors import MarketCapSelectorConfig, StockFilterConfig
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from src.factors.engine.data_router import DataRouter
|
from src.factors.engine.data_router import DataRouter
|
||||||
|
|
||||||
|
|
||||||
class StockPoolManager:
|
class StockPoolManager:
|
||||||
"""股票池管理器 - 每日独立筛选
|
"""股票池管理器 - 支持自定义筛选函数和因子
|
||||||
|
|
||||||
重要约束:
|
核心特性:
|
||||||
1. 市值数据仅从 daily_basic 表获取,仅用于筛选
|
1. 支持传入自定义筛选函数
|
||||||
2. 市值数据绝不混入特征矩阵
|
2. 支持使用因子表达式进行筛选
|
||||||
3. 每日独立筛选(市值是动态变化的)
|
3. 使用 FactorEngine 计算所需因子
|
||||||
|
4. 只删除本次新生成的临时因子,保留输入中已存在的所有列
|
||||||
|
|
||||||
处理流程(每日):
|
数据流:
|
||||||
当日所有股票
|
输入数据 (含原始列,可能包含一些因子)
|
||||||
↓
|
↓
|
||||||
代码过滤(创业板、ST等)
|
[准备数据]
|
||||||
|
├─ 获取缺失的基础列 (from data_router)
|
||||||
|
└─ 计算缺失的因子 (使用 FactorEngine,标记为"本次生成")
|
||||||
↓
|
↓
|
||||||
查询 daily_basic 获取当日市值
|
[每日筛选]
|
||||||
|
├─ group_by("trade_date").map_groups(filter_func)
|
||||||
|
└─ 只保留 ts_code + trade_date (筛选结果标识)
|
||||||
↓
|
↓
|
||||||
市值选择(前N只)
|
[返回结果]
|
||||||
↓
|
└─ semi join 原始数据,保留所有原始列
|
||||||
返回当日选中股票列表
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
filter_config: StockFilterConfig,
|
filter_func: Callable[[pl.DataFrame], pl.Series],
|
||||||
selector_config: Optional[MarketCapSelectorConfig],
|
required_columns: Optional[List[str]] = None,
|
||||||
data_router: "DataRouter",
|
required_factors: Optional[Dict[str, str]] = None,
|
||||||
|
data_router: Optional["DataRouter"] = None,
|
||||||
code_col: str = "ts_code",
|
code_col: str = "ts_code",
|
||||||
date_col: str = "trade_date",
|
date_col: str = "trade_date",
|
||||||
):
|
):
|
||||||
"""初始化股票池管理器
|
"""初始化股票池管理器
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
filter_config: 股票过滤器配置
|
filter_func: 筛选函数,接收 DataFrame 返回布尔 Series
|
||||||
selector_config: 市值选择器配置,None 表示跳过市值选择
|
required_columns: 除输入数据外还需获取的基础列
|
||||||
data_router: 数据路由器,用于获取 daily_basic 数据
|
required_factors: 筛选所需的因子表达式 {因子名: DSL表达式}
|
||||||
|
data_router: 数据路由器,用于获取缺失列
|
||||||
code_col: 股票代码列名
|
code_col: 股票代码列名
|
||||||
date_col: 日期列名
|
date_col: 日期列名
|
||||||
"""
|
"""
|
||||||
self.filter_config = filter_config
|
self.filter_func = filter_func
|
||||||
self.selector_config = selector_config
|
self.required_columns = required_columns or []
|
||||||
|
self.required_factors = required_factors or {}
|
||||||
self.data_router = data_router
|
self.data_router = data_router
|
||||||
self.code_col = code_col
|
self.code_col = code_col
|
||||||
self.date_col = date_col
|
self.date_col = date_col
|
||||||
@@ -59,113 +65,191 @@ class StockPoolManager:
|
|||||||
def filter_and_select_daily(self, data: pl.DataFrame) -> pl.DataFrame:
|
def filter_and_select_daily(self, data: pl.DataFrame) -> pl.DataFrame:
|
||||||
"""每日独立筛选股票池
|
"""每日独立筛选股票池
|
||||||
|
|
||||||
|
流程:
|
||||||
|
1. 记录输入数据的原始列
|
||||||
|
2. 收集筛选所需的完整数据(基础列 + 计算因子)
|
||||||
|
3. 按日期分组应用筛选函数
|
||||||
|
4. 只返回 ts_code 和 trade_date(筛选结果标识)
|
||||||
|
5. 用标识列从原始数据筛选(保留所有原始列)
|
||||||
|
|
||||||
|
关键:返回的数据包含输入数据的所有原始列,只移除本次新生成的临时因子
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
data: 因子计算后的全市场数据,必须包含 trade_date 和 ts_code 列
|
data: 因子计算后的全市场数据,必须包含 trade_date 和 ts_code 列
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
筛选后的数据,仅包含每日选中的股票
|
筛选后的数据,列与输入数据完全一致(临时因子已移除)
|
||||||
|
|
||||||
Note:
|
|
||||||
- 按日期分组处理
|
|
||||||
- 市值数据从 daily_basic 独立获取
|
|
||||||
- 保持市值数据与特征数据隔离
|
|
||||||
"""
|
"""
|
||||||
dates = data.select(self.date_col).unique().sort(self.date_col)
|
# 1. 记录原始列,用于最后验证
|
||||||
|
original_columns = list(data.columns)
|
||||||
|
|
||||||
result_frames = []
|
# 2. 准备完整数据(用于筛选判断)
|
||||||
for date in dates.to_series():
|
# 返回的 enriched 包含临时因子,但不修改原始 data
|
||||||
# 获取当日数据
|
enriched = self._prepare_data(data)
|
||||||
daily_data = data.filter(pl.col(self.date_col) == date)
|
|
||||||
daily_codes = daily_data.select(self.code_col).to_series().to_list()
|
|
||||||
|
|
||||||
# 1. 代码过滤
|
# 3. 每日筛选,只保留标识列
|
||||||
filtered_codes = self.filter_config.filter_codes(daily_codes)
|
# 使用 group_by + map_groups 替代 apply(Polars 0.20+)
|
||||||
|
selected_ids = enriched.group_by(self.date_col).map_groups(
|
||||||
# 2. 市值选择(如果启用)
|
lambda df: df.filter(self.filter_func(df)).select(
|
||||||
if self.selector_config and self.selector_config.enabled:
|
[self.code_col, self.date_col]
|
||||||
# 从 daily_basic 获取当日市值
|
)
|
||||||
market_caps = self._get_market_caps_for_date(filtered_codes, date)
|
|
||||||
selected_codes = self._select_by_market_cap(filtered_codes, market_caps)
|
|
||||||
else:
|
|
||||||
selected_codes = filtered_codes
|
|
||||||
|
|
||||||
# 3. 保留当日选中的股票数据
|
|
||||||
daily_selected = daily_data.filter(
|
|
||||||
pl.col(self.code_col).is_in(selected_codes)
|
|
||||||
)
|
)
|
||||||
result_frames.append(daily_selected)
|
|
||||||
|
|
||||||
return pl.concat(result_frames)
|
# 4. 用 semi join 从原始数据筛选,自动只保留原始列
|
||||||
|
# semi join: 保留左侧(data)的所有列,只保留匹配的行
|
||||||
|
result = data.join(
|
||||||
|
selected_ids,
|
||||||
|
on=[self.code_col, self.date_col],
|
||||||
|
how="semi",
|
||||||
|
)
|
||||||
|
|
||||||
def _get_market_caps_for_date(
|
# 5. 验证:确保结果列与原始列完全一致
|
||||||
self, codes: List[str], date: str
|
if list(result.columns) != original_columns:
|
||||||
) -> Dict[str, float]:
|
raise RuntimeError(
|
||||||
"""从 daily_basic 表获取指定日期的市值数据
|
f"列发生变化!\n原始: {original_columns}\n结果: {list(result.columns)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def _prepare_data(self, data: pl.DataFrame) -> pl.DataFrame:
|
||||||
|
"""准备筛选所需的完整数据
|
||||||
|
|
||||||
|
步骤:
|
||||||
|
1. 获取缺失的基础列
|
||||||
|
2. 计算缺失的因子(输入中已存在的因子不再计算)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
codes: 股票代码列表
|
data: 输入数据
|
||||||
date: 日期 "YYYYMMDD"
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
{股票代码: 市值} 的字典
|
包含所有所需列和因子的数据(含临时因子)
|
||||||
"""
|
"""
|
||||||
if not codes:
|
result = data
|
||||||
return {}
|
|
||||||
|
|
||||||
assert self.selector_config is not None, (
|
# 1. 获取缺失的基础列
|
||||||
"selector_config should not be None when calling _get_market_caps_for_date"
|
if self.required_columns and self.data_router is not None:
|
||||||
)
|
result = self._fetch_required_columns(result)
|
||||||
|
|
||||||
|
# 2. 计算因子(只计算输入中不存在的)
|
||||||
|
if self.required_factors:
|
||||||
|
result = self._compute_factors(result)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def _fetch_required_columns(self, data: pl.DataFrame) -> pl.DataFrame:
|
||||||
|
"""从 data_router 获取缺失的基础列
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: 当前数据
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
补充了缺失列的数据
|
||||||
|
"""
|
||||||
|
missing_cols = set(self.required_columns) - set(data.columns)
|
||||||
|
if not missing_cols:
|
||||||
|
return data
|
||||||
|
|
||||||
|
if self.data_router is None:
|
||||||
|
raise ValueError(f"需要获取列 {missing_cols},但未提供 data_router")
|
||||||
|
|
||||||
|
# 获取日期范围
|
||||||
|
dates = data.select(self.date_col).unique().to_series().to_list()
|
||||||
|
if not dates:
|
||||||
|
return data
|
||||||
|
|
||||||
|
start_date = min(dates)
|
||||||
|
end_date = max(dates)
|
||||||
|
|
||||||
|
# 获取所有股票代码
|
||||||
|
codes = data.select(self.code_col).unique().to_series().to_list()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 通过 data_router 查询 daily_basic 表
|
|
||||||
from src.factors.engine.data_spec import DataSpec
|
from src.factors.engine.data_spec import DataSpec
|
||||||
|
|
||||||
|
# 构建 DataSpec 列表
|
||||||
data_specs = [
|
data_specs = [
|
||||||
DataSpec("daily_basic", [self.selector_config.market_cap_col])
|
DataSpec("daily", list(missing_cols)) # 假设从 daily 表获取;注意:total_mv 等市值列此前取自 daily_basic 表,TODO 确认表名是否正确
|
||||||
]
|
]
|
||||||
df = self.data_router.fetch_data(
|
|
||||||
|
# 从 data_router 获取数据
|
||||||
|
extra_data = self.data_router.fetch_data(
|
||||||
data_specs=data_specs,
|
data_specs=data_specs,
|
||||||
start_date=date,
|
start_date=start_date,
|
||||||
end_date=date,
|
end_date=end_date,
|
||||||
stock_codes=codes,
|
stock_codes=codes,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 转换为字典
|
# 合并到结果
|
||||||
market_caps = {}
|
result = data.join(
|
||||||
for row in df.iter_rows(named=True):
|
extra_data,
|
||||||
code = row[self.code_col]
|
on=[self.code_col, self.date_col],
|
||||||
cap = row.get(self.selector_config.market_cap_col)
|
how="left",
|
||||||
if cap is not None and code in codes:
|
)
|
||||||
market_caps[code] = float(cap)
|
|
||||||
|
|
||||||
return market_caps
|
return result
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"[警告] 获取 {date} 市值数据失败: {e}")
|
print(f"[警告] 获取缺失列失败: {e}")
|
||||||
return {}
|
# 如果获取失败,继续使用现有数据(筛选可能不完全)
|
||||||
|
return data
|
||||||
|
|
||||||
def _select_by_market_cap(
|
def _compute_factors(self, data: pl.DataFrame) -> pl.DataFrame:
|
||||||
self, codes: List[str], market_caps: Dict[str, float]
|
"""使用 FactorEngine 计算筛选所需的因子
|
||||||
) -> List[str]:
|
|
||||||
"""根据市值选择股票
|
只计算输入数据中不存在的因子,已存在的因子直接使用。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
codes: 股票代码列表
|
data: 当前数据
|
||||||
market_caps: 市值数据字典
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
选中的股票代码列表
|
补充了缺失因子的数据(含临时因子)
|
||||||
"""
|
"""
|
||||||
if self.selector_config is None:
|
existing_cols = set(data.columns)
|
||||||
return codes
|
|
||||||
|
|
||||||
if not market_caps:
|
# 确定需要计算的因子(输入中不存在的)
|
||||||
return codes[: self.selector_config.n]
|
factors_to_compute = {
|
||||||
|
name: expr
|
||||||
|
for name, expr in self.required_factors.items()
|
||||||
|
if name not in existing_cols
|
||||||
|
}
|
||||||
|
|
||||||
# 按市值排序并选择前N只
|
if not factors_to_compute:
|
||||||
sorted_codes = sorted(
|
# 所有因子都已存在,无需计算
|
||||||
codes,
|
return data
|
||||||
key=lambda c: market_caps.get(c, 0),
|
|
||||||
reverse=not self.selector_config.ascending,
|
try:
|
||||||
|
from src.factors import FactorEngine
|
||||||
|
|
||||||
|
# 获取日期范围
|
||||||
|
dates = data.select(self.date_col).unique().to_series().to_list()
|
||||||
|
if not dates:
|
||||||
|
return data
|
||||||
|
|
||||||
|
start_date = min(dates)
|
||||||
|
end_date = max(dates)
|
||||||
|
|
||||||
|
# 创建 FactorEngine 并注册因子
|
||||||
|
engine = FactorEngine()
|
||||||
|
for name, expr in factors_to_compute.items():
|
||||||
|
engine.add_factor(name, expr)
|
||||||
|
|
||||||
|
# 计算因子
|
||||||
|
factor_data = engine.compute(
|
||||||
|
factor_names=list(factors_to_compute.keys()),
|
||||||
|
start_date=start_date,
|
||||||
|
end_date=end_date,
|
||||||
)
|
)
|
||||||
return sorted_codes[: self.selector_config.n]
|
|
||||||
|
# 合并到数据(左连接,保留所有原始行)
|
||||||
|
result = data.join(
|
||||||
|
factor_data,
|
||||||
|
on=[self.code_col, self.date_col],
|
||||||
|
how="left",
|
||||||
|
)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[警告] 计算因子失败: {e}")
|
||||||
|
# 如果计算失败,继续使用现有数据
|
||||||
|
return data
|
||||||
|
|||||||
@@ -97,13 +97,14 @@ class Trainer:
|
|||||||
print("[筛选] 每日独立筛选股票池...")
|
print("[筛选] 每日独立筛选股票池...")
|
||||||
data = self.pool_manager.filter_and_select_daily(data)
|
data = self.pool_manager.filter_and_select_daily(data)
|
||||||
|
|
||||||
# 2. 划分训练/测试集
|
# 2. 划分训练/验证/测试集(三分法)
|
||||||
if self.splitter:
|
if self.splitter:
|
||||||
print("[划分] 划分训练集和测试集...")
|
print("[划分] 划分训练集、验证集和测试集...")
|
||||||
train_data, test_data = self.splitter.split(data)
|
train_data, val_data, test_data = self.splitter.split(data)
|
||||||
else:
|
else:
|
||||||
# 没有划分器,全部作为训练集
|
# 没有划分器,全部作为训练集
|
||||||
train_data = data
|
train_data = data
|
||||||
|
val_data = data
|
||||||
test_data = data
|
test_data = data
|
||||||
|
|
||||||
# 3. 训练集:processors fit_transform
|
# 3. 训练集:processors fit_transform
|
||||||
|
|||||||
@@ -1,183 +0,0 @@
|
|||||||
"""测试股票池选择器配置
|
|
||||||
|
|
||||||
验证 StockFilterConfig 和 MarketCapSelectorConfig 功能。
|
|
||||||
"""
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from src.training.components.selectors import (
|
|
||||||
MarketCapSelectorConfig,
|
|
||||||
StockFilterConfig,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TestStockFilterConfig:
|
|
||||||
"""StockFilterConfig 测试类"""
|
|
||||||
|
|
||||||
def test_default_values(self):
|
|
||||||
"""测试默认值"""
|
|
||||||
config = StockFilterConfig()
|
|
||||||
assert config.exclude_cyb is True
|
|
||||||
assert config.exclude_kcb is True
|
|
||||||
assert config.exclude_bj is True
|
|
||||||
assert config.exclude_st is True
|
|
||||||
|
|
||||||
def test_custom_values(self):
|
|
||||||
"""测试自定义值"""
|
|
||||||
config = StockFilterConfig(
|
|
||||||
exclude_cyb=False,
|
|
||||||
exclude_kcb=False,
|
|
||||||
exclude_bj=False,
|
|
||||||
exclude_st=False,
|
|
||||||
)
|
|
||||||
assert config.exclude_cyb is False
|
|
||||||
assert config.exclude_kcb is False
|
|
||||||
assert config.exclude_bj is False
|
|
||||||
assert config.exclude_st is False
|
|
||||||
|
|
||||||
def test_filter_codes_exclude_all(self):
|
|
||||||
"""测试排除所有类型"""
|
|
||||||
config = StockFilterConfig(
|
|
||||||
exclude_cyb=True,
|
|
||||||
exclude_kcb=True,
|
|
||||||
exclude_bj=True,
|
|
||||||
exclude_st=True,
|
|
||||||
)
|
|
||||||
codes = [
|
|
||||||
"000001.SZ", # 主板 - 保留
|
|
||||||
"300001.SZ", # 创业板 - 排除
|
|
||||||
"688001.SH", # 科创板 - 排除
|
|
||||||
"830001.BJ", # 北交所(8开头)- 排除
|
|
||||||
"430001.BJ", # 北交所(4开头)- 排除
|
|
||||||
]
|
|
||||||
result = config.filter_codes(codes)
|
|
||||||
assert result == ["000001.SZ"]
|
|
||||||
|
|
||||||
def test_filter_codes_allow_cyb(self):
|
|
||||||
"""测试允许创业板"""
|
|
||||||
config = StockFilterConfig(
|
|
||||||
exclude_cyb=False,
|
|
||||||
exclude_kcb=True,
|
|
||||||
exclude_bj=True,
|
|
||||||
exclude_st=True,
|
|
||||||
)
|
|
||||||
codes = [
|
|
||||||
"000001.SZ",
|
|
||||||
"300001.SZ",
|
|
||||||
"688001.SH",
|
|
||||||
]
|
|
||||||
result = config.filter_codes(codes)
|
|
||||||
assert result == ["000001.SZ", "300001.SZ"]
|
|
||||||
|
|
||||||
def test_filter_codes_allow_kcb(self):
|
|
||||||
"""测试允许科创板"""
|
|
||||||
config = StockFilterConfig(
|
|
||||||
exclude_cyb=True,
|
|
||||||
exclude_kcb=False,
|
|
||||||
exclude_bj=True,
|
|
||||||
exclude_st=True,
|
|
||||||
)
|
|
||||||
codes = [
|
|
||||||
"000001.SZ",
|
|
||||||
"300001.SZ",
|
|
||||||
"688001.SH",
|
|
||||||
]
|
|
||||||
result = config.filter_codes(codes)
|
|
||||||
assert result == ["000001.SZ", "688001.SH"]
|
|
||||||
|
|
||||||
def test_filter_codes_allow_bj(self):
|
|
||||||
"""测试允许北交所"""
|
|
||||||
config = StockFilterConfig(
|
|
||||||
exclude_cyb=True,
|
|
||||||
exclude_kcb=True,
|
|
||||||
exclude_bj=False,
|
|
||||||
exclude_st=True,
|
|
||||||
)
|
|
||||||
codes = [
|
|
||||||
"000001.SZ",
|
|
||||||
"300001.SZ",
|
|
||||||
"830001.BJ",
|
|
||||||
"430001.BJ",
|
|
||||||
]
|
|
||||||
result = config.filter_codes(codes)
|
|
||||||
assert result == ["000001.SZ", "830001.BJ", "430001.BJ"]
|
|
||||||
|
|
||||||
def test_filter_codes_allow_all(self):
|
|
||||||
"""测试允许所有类型"""
|
|
||||||
config = StockFilterConfig(
|
|
||||||
exclude_cyb=False,
|
|
||||||
exclude_kcb=False,
|
|
||||||
exclude_bj=False,
|
|
||||||
exclude_st=False,
|
|
||||||
)
|
|
||||||
codes = [
|
|
||||||
"000001.SZ",
|
|
||||||
"300001.SZ",
|
|
||||||
"688001.SH",
|
|
||||||
"830001.BJ",
|
|
||||||
"430001.BJ",
|
|
||||||
]
|
|
||||||
result = config.filter_codes(codes)
|
|
||||||
assert result == codes
|
|
||||||
|
|
||||||
def test_filter_codes_empty_list(self):
|
|
||||||
"""测试空列表"""
|
|
||||||
config = StockFilterConfig()
|
|
||||||
result = config.filter_codes([])
|
|
||||||
assert result == []
|
|
||||||
|
|
||||||
def test_filter_codes_no_matching(self):
|
|
||||||
"""测试全部排除"""
|
|
||||||
config = StockFilterConfig()
|
|
||||||
codes = ["300001.SZ", "688001.SH", "830001.BJ"]
|
|
||||||
result = config.filter_codes(codes)
|
|
||||||
assert result == []
|
|
||||||
|
|
||||||
|
|
||||||
class TestMarketCapSelectorConfig:
|
|
||||||
"""MarketCapSelectorConfig 测试类"""
|
|
||||||
|
|
||||||
def test_default_values(self):
|
|
||||||
"""测试默认值"""
|
|
||||||
config = MarketCapSelectorConfig()
|
|
||||||
assert config.enabled is True
|
|
||||||
assert config.n == 100
|
|
||||||
assert config.ascending is False
|
|
||||||
assert config.market_cap_col == "total_mv"
|
|
||||||
|
|
||||||
def test_custom_values(self):
|
|
||||||
"""测试自定义值"""
|
|
||||||
config = MarketCapSelectorConfig(
|
|
||||||
enabled=False,
|
|
||||||
n=50,
|
|
||||||
ascending=True,
|
|
||||||
market_cap_col="circ_mv",
|
|
||||||
)
|
|
||||||
assert config.enabled is False
|
|
||||||
assert config.n == 50
|
|
||||||
assert config.ascending is True
|
|
||||||
assert config.market_cap_col == "circ_mv"
|
|
||||||
|
|
||||||
def test_invalid_n_zero(self):
|
|
||||||
"""测试无效的 n=0"""
|
|
||||||
with pytest.raises(ValueError, match="n 必须是正整数"):
|
|
||||||
MarketCapSelectorConfig(n=0)
|
|
||||||
|
|
||||||
def test_invalid_n_negative(self):
|
|
||||||
"""测试无效的负数 n"""
|
|
||||||
with pytest.raises(ValueError, match="n 必须是正整数"):
|
|
||||||
MarketCapSelectorConfig(n=-1)
|
|
||||||
|
|
||||||
def test_invalid_empty_market_cap_col(self):
|
|
||||||
"""测试空的市值列名"""
|
|
||||||
with pytest.raises(ValueError, match="market_cap_col 不能为空"):
|
|
||||||
MarketCapSelectorConfig(market_cap_col="")
|
|
||||||
|
|
||||||
def test_large_n(self):
|
|
||||||
"""测试大的 n 值"""
|
|
||||||
config = MarketCapSelectorConfig(n=5000)
|
|
||||||
assert config.n == 5000
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
pytest.main([__file__, "-v"])
|
|
||||||
452
tests/training/test_stock_pool_manager.py
Normal file
452
tests/training/test_stock_pool_manager.py
Normal file
@@ -0,0 +1,452 @@
|
|||||||
|
"""测试 StockPoolManager
|
||||||
|
|
||||||
|
验证新的自定义函数和因子筛选功能。
|
||||||
|
重点测试:临时因子隔离(只删除新生成的因子,保留原本存在的)。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
|
import polars as pl
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from src.training.core.stock_pool_manager import StockPoolManager
|
||||||
|
|
||||||
|
|
||||||
|
class TestStockPoolManagerBasic:
|
||||||
|
"""StockPoolManager 基础测试类"""
|
||||||
|
|
||||||
|
def test_basic_filter_with_columns(self):
|
||||||
|
"""测试使用基础列进行筛选"""
|
||||||
|
|
||||||
|
def filter_func(df: pl.DataFrame) -> pl.Series:
|
||||||
|
return df["total_mv"] > 50
|
||||||
|
|
||||||
|
# 创建模拟 data_router
|
||||||
|
mock_router = Mock()
|
||||||
|
|
||||||
|
manager = StockPoolManager(
|
||||||
|
filter_func=filter_func,
|
||||||
|
required_columns=["total_mv"],
|
||||||
|
data_router=mock_router,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 创建测试数据
|
||||||
|
data = pl.DataFrame(
|
||||||
|
{
|
||||||
|
"ts_code": ["000001.SZ", "000002.SZ", "000003.SZ", "000004.SZ"],
|
||||||
|
"trade_date": ["20240101"] * 4,
|
||||||
|
"close": [10.0, 20.0, 30.0, 40.0],
|
||||||
|
"total_mv": [100.0, 30.0, 80.0, 20.0],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# 执行筛选(无需 mock,因为 total_mv 已在数据中)
|
||||||
|
result = manager.filter_and_select_daily(data)
|
||||||
|
|
||||||
|
# 验证返回数据列与输入一致
|
||||||
|
assert result.columns == data.columns
|
||||||
|
# 验证筛选生效(保留市值 > 50 的股票)
|
||||||
|
assert len(result) == 2
|
||||||
|
assert "000001.SZ" in result["ts_code"].to_list()
|
||||||
|
assert "000003.SZ" in result["ts_code"].to_list()
|
||||||
|
|
||||||
|
def test_filter_without_required_columns(self):
|
||||||
|
"""测试不使用额外列,仅使用输入数据中已有的列"""
|
||||||
|
|
||||||
|
def filter_func(df: pl.DataFrame) -> pl.Series:
|
||||||
|
return df["close"] > 25
|
||||||
|
|
||||||
|
manager = StockPoolManager(filter_func=filter_func)
|
||||||
|
|
||||||
|
data = pl.DataFrame(
|
||||||
|
{
|
||||||
|
"ts_code": ["000001.SZ", "000002.SZ", "000003.SZ"],
|
||||||
|
"trade_date": ["20240101"] * 3,
|
||||||
|
"close": [10.0, 30.0, 20.0],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
result = manager.filter_and_select_daily(data)
|
||||||
|
|
||||||
|
# 验证只保留 close > 25 的股票
|
||||||
|
assert len(result) == 1
|
||||||
|
assert result["ts_code"][0] == "000002.SZ"
|
||||||
|
assert result.columns == data.columns
|
||||||
|
|
||||||
|
def test_empty_result(self):
|
||||||
|
"""测试筛选结果为空的情况"""
|
||||||
|
|
||||||
|
def filter_func(df: pl.DataFrame) -> pl.Series:
|
||||||
|
return df["close"] > 9999 # 不可能满足的条件
|
||||||
|
|
||||||
|
manager = StockPoolManager(filter_func=filter_func)
|
||||||
|
|
||||||
|
data = pl.DataFrame(
|
||||||
|
{
|
||||||
|
"ts_code": ["000001.SZ", "000002.SZ"],
|
||||||
|
"trade_date": ["20240101"] * 2,
|
||||||
|
"close": [10.0, 20.0],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
result = manager.filter_and_select_daily(data)
|
||||||
|
|
||||||
|
assert len(result) == 0
|
||||||
|
assert result.columns == data.columns # 即使为空,列结构保持一致
|
||||||
|
|
||||||
|
|
||||||
|
class TestStockPoolManagerDailyIndependence:
|
||||||
|
"""每日独立筛选测试类"""
|
||||||
|
|
||||||
|
def test_daily_independence(self):
|
||||||
|
"""测试每日独立进行筛选"""
|
||||||
|
|
||||||
|
def filter_func(df: pl.DataFrame) -> pl.Series:
|
||||||
|
# 每日选收盘价前 50%
|
||||||
|
median = df["close"].median()
|
||||||
|
return df["close"] >= median
|
||||||
|
|
||||||
|
manager = StockPoolManager(filter_func=filter_func)
|
||||||
|
|
||||||
|
# 创建多日期数据
|
||||||
|
data = pl.DataFrame(
|
||||||
|
{
|
||||||
|
"ts_code": [
|
||||||
|
"000001.SZ",
|
||||||
|
"000002.SZ",
|
||||||
|
"000003.SZ",
|
||||||
|
"000004.SZ",
|
||||||
|
# 日期 2
|
||||||
|
"000001.SZ",
|
||||||
|
"000002.SZ",
|
||||||
|
"000003.SZ",
|
||||||
|
"000004.SZ",
|
||||||
|
],
|
||||||
|
"trade_date": [
|
||||||
|
"20240101",
|
||||||
|
"20240101",
|
||||||
|
"20240101",
|
||||||
|
"20240101",
|
||||||
|
"20240102",
|
||||||
|
"20240102",
|
||||||
|
"20240102",
|
||||||
|
"20240102",
|
||||||
|
],
|
||||||
|
"close": [
|
||||||
|
10.0,
|
||||||
|
20.0,
|
||||||
|
30.0,
|
||||||
|
40.0, # 日期1:选 30, 40
|
||||||
|
5.0,
|
||||||
|
15.0,
|
||||||
|
25.0,
|
||||||
|
35.0, # 日期2:选 25, 35
|
||||||
|
],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
result = manager.filter_and_select_daily(data)
|
||||||
|
|
||||||
|
# 验证每个日期独立筛选
|
||||||
|
day1 = result.filter(pl.col("trade_date") == "20240101")
|
||||||
|
day2 = result.filter(pl.col("trade_date") == "20240102")
|
||||||
|
|
||||||
|
# 日期1:收盘价 >= 25(中位数)- 30 和 40
|
||||||
|
assert len(day1) == 2
|
||||||
|
assert set(day1["ts_code"].to_list()) == {"000003.SZ", "000004.SZ"}
|
||||||
|
|
||||||
|
# 日期2:收盘价 >= 20(中位数)- 25 和 35
|
||||||
|
assert len(day2) == 2
|
||||||
|
assert set(day2["ts_code"].to_list()) == {"000003.SZ", "000004.SZ"}
|
||||||
|
|
||||||
|
def test_uneven_daily_distribution(self):
|
||||||
|
"""测试每日股票数量不均的情况"""
|
||||||
|
|
||||||
|
def filter_func(df: pl.DataFrame) -> pl.Series:
|
||||||
|
return df["close"] > 15
|
||||||
|
|
||||||
|
manager = StockPoolManager(filter_func=filter_func)
|
||||||
|
|
||||||
|
data = pl.DataFrame(
|
||||||
|
{
|
||||||
|
"ts_code": [
|
||||||
|
"000001.SZ",
|
||||||
|
"000002.SZ", # 日期1:2只股票
|
||||||
|
"000001.SZ",
|
||||||
|
"000002.SZ",
|
||||||
|
"000003.SZ",
|
||||||
|
"000004.SZ", # 日期2:4只股票
|
||||||
|
],
|
||||||
|
"trade_date": [
|
||||||
|
"20240101",
|
||||||
|
"20240101",
|
||||||
|
"20240102",
|
||||||
|
"20240102",
|
||||||
|
"20240102",
|
||||||
|
"20240102",
|
||||||
|
],
|
||||||
|
"close": [10.0, 20.0, 5.0, 15.0, 25.0, 35.0],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
result = manager.filter_and_select_daily(data)
|
||||||
|
|
||||||
|
# 日期1:只有 000002.SZ (20 > 15)
|
||||||
|
day1 = result.filter(pl.col("trade_date") == "20240101")
|
||||||
|
assert len(day1) == 1
|
||||||
|
|
||||||
|
# 日期2:000003.SZ (25) 和 000004.SZ (35)
|
||||||
|
day2 = result.filter(pl.col("trade_date") == "20240102")
|
||||||
|
assert len(day2) == 2
|
||||||
|
|
||||||
|
|
||||||
|
class TestStockPoolManagerFactorIsolation:
    """Factor isolation tests (core behaviour).

    Factors computed only to evaluate ``filter_func`` must not appear in the
    returned frame, while columns that were already present in the input
    must be kept untouched.
    """

    @patch.object(StockPoolManager, "_compute_factors")
    def test_filter_with_factors(self, mock_compute):
        """Filter on a computed factor and check the temporary column is dropped."""

        # Mocked output of the factor computation: input columns plus the
        # freshly computed ``momentum_20`` column.
        mock_compute.return_value = pl.DataFrame(
            {
                "ts_code": ["000001.SZ", "000002.SZ", "000003.SZ"],
                "trade_date": ["20240101"] * 3,
                "close": [11.0, 9.5, 10.8],
                # First and third values exceed 0.05; the second does not.
                "momentum_20": [0.1, -0.05, 0.08],
            }
        )

        def filter_func(df: pl.DataFrame) -> pl.Series:
            return df["momentum_20"] > 0.05

        manager = StockPoolManager(
            filter_func=filter_func,
            required_factors={"momentum_20": "(close / ts_delay(close, 20)) - 1"},
        )

        # The input frame does NOT contain momentum_20.
        data = pl.DataFrame(
            {
                "ts_code": ["000001.SZ", "000002.SZ", "000003.SZ"],
                "trade_date": ["20240101"] * 3,
                "close": [11.0, 9.5, 10.8],
            }
        )

        result = manager.filter_and_select_daily(data)

        # Returned columns match the input (momentum_20 has been removed).
        assert result.columns == data.columns
        assert "momentum_20" not in result.columns

        # The filter took effect:
        # momentum_20 > 0.05 -> 000001.SZ (0.1), 000003.SZ (0.08)
        assert len(result) == 2
        assert set(result["ts_code"].to_list()) == {"000001.SZ", "000003.SZ"}

        # The factor computation hook was invoked exactly once.
        mock_compute.assert_called_once()

    @patch.object(StockPoolManager, "_compute_factors")
    def test_preserve_existing_factors(self, mock_compute):
        """Factors already present in the input must survive (core test)."""

        # Mocked output: ``roe`` is newly computed, ``momentum_20`` was
        # already part of the input frame.
        mock_compute.return_value = pl.DataFrame(
            {
                "ts_code": ["000001.SZ", "000002.SZ", "000003.SZ"],
                "trade_date": ["20240101"] * 3,
                "close": [11.0, 9.5, 10.8],
                "momentum_20": [0.1, -0.05, 0.08],  # pre-existing
                "roe": [0.12, 0.08, 0.15],  # computed now; second value < 0.1
            }
        )

        def filter_func(df: pl.DataFrame) -> pl.Series:
            return (df["momentum_20"] > 0.05) & (df["roe"] > 0.1)

        manager = StockPoolManager(
            filter_func=filter_func,
            # Only ``roe`` is declared as required; momentum_20 comes from
            # the input data.
            required_factors={"roe": "n_income / equity"},
        )

        # The input frame already contains momentum_20.
        data = pl.DataFrame(
            {
                "ts_code": ["000001.SZ", "000002.SZ", "000003.SZ"],
                "trade_date": ["20240101"] * 3,
                "close": [11.0, 9.5, 10.8],
                "momentum_20": [0.1, -0.05, 0.08],  # pre-existing factor
            }
        )

        result = manager.filter_and_select_daily(data)

        # Key assertions:
        # 1. momentum_20 is kept (it existed before the call).
        assert "momentum_20" in result.columns
        # 2. roe is removed (it was generated for this call only).
        assert "roe" not in result.columns
        # 3. The column list equals the input's exactly.
        assert result.columns == data.columns

        # The filter was evaluated correctly:
        # momentum_20 > 0.05 -> 000001.SZ (0.1), 000003.SZ (0.08)
        # roe > 0.1          -> 000001.SZ (0.12), 000003.SZ (0.15)
        # intersection       -> 000001.SZ, 000003.SZ
        assert len(result) == 2
        assert set(result["ts_code"].to_list()) == {"000001.SZ", "000003.SZ"}

        # The factor computation hook ran because ``roe`` was missing.
        mock_compute.assert_called_once()

    def test_no_factor_computation_when_all_exist(self):
        """When every required factor is already present, FactorEngine is not used."""

        def filter_func(df: pl.DataFrame) -> pl.Series:
            return (df["factor_a"] > 0.5) & (df["factor_b"] < 0.3)

        manager = StockPoolManager(
            filter_func=filter_func,
            required_factors={
                "factor_a": "some_expr_a",
                "factor_b": "some_expr_b",
            },
        )

        # The input frame already contains every required factor.
        data = pl.DataFrame(
            {
                "ts_code": ["000001.SZ", "000002.SZ"],
                "trade_date": ["20240101"] * 2,
                "close": [10.0, 20.0],
                "factor_a": [0.6, 0.4],  # pre-existing
                "factor_b": [0.2, 0.5],  # pre-existing
            }
        )

        # NOTE(review): this patch replaces the name in the defining module.
        # It only guards the call if StockPoolManager looks FactorEngine up
        # from that module at call time; if the manager module binds the
        # class via a top-level ``from ... import FactorEngine`` the check
        # below would pass vacuously -- confirm against the implementation.
        with patch("src.factors.engine.factor_engine.FactorEngine") as mock_engine:
            result = manager.filter_and_select_daily(data)

            # FactorEngine must not be instantiated (all factors exist).
            mock_engine.assert_not_called()

        # Only the first row satisfies factor_a > 0.5 and factor_b < 0.3.
        assert len(result) == 1
        assert result["ts_code"][0] == "000001.SZ"
        assert result.columns == data.columns
|
||||||
|
|
||||||
|
|
||||||
|
class TestStockPoolManagerEdgeCases:
    """Boundary-condition tests for ``filter_and_select_daily``."""

    def test_single_date(self):
        """All rows belong to a single trade date."""

        def above_15(frame: pl.DataFrame) -> pl.Series:
            return frame["close"] > 15

        frame = pl.DataFrame(
            {
                "ts_code": ["000001.SZ", "000002.SZ", "000003.SZ"],
                "trade_date": ["20240101", "20240101", "20240101"],
                "close": [10.0, 20.0, 30.0],
            }
        )
        pool = StockPoolManager(filter_func=above_15)

        selected = pool.filter_and_select_daily(frame)

        assert len(selected) == 2
        assert selected.columns == frame.columns

    def test_single_stock_per_day(self):
        """Each trade date carries exactly one stock."""

        def keep_positive(frame: pl.DataFrame) -> pl.Series:
            # Every close in the fixture is positive, so nothing is dropped.
            return frame["close"] > 0

        frame = pl.DataFrame(
            {
                "ts_code": ["000001.SZ", "000002.SZ", "000003.SZ"],
                "trade_date": ["20240101", "20240102", "20240103"],
                "close": [10.0, 20.0, 30.0],
            }
        )
        pool = StockPoolManager(filter_func=keep_positive)

        selected = pool.filter_and_select_daily(frame)

        assert len(selected) == 3
        assert selected.columns == frame.columns

    def test_filter_all_out_one_day(self):
        """One date has every row rejected by the filter."""

        def above_100(frame: pl.DataFrame) -> pl.Series:
            # Deliberately high threshold.
            return frame["close"] > 100

        codes = ["000001.SZ", "000002.SZ", "000001.SZ", "000002.SZ"]
        dates = ["20240101", "20240101", "20240102", "20240102"]
        # Day 1 closes (10, 20) are all rejected; day 2 closes (150, 200) pass.
        closes = [10.0, 20.0, 150.0, 200.0]
        frame = pl.DataFrame(
            {
                "ts_code": codes,
                "trade_date": dates,
                "close": closes,
            }
        )
        pool = StockPoolManager(filter_func=above_100)

        selected = pool.filter_and_select_daily(frame)

        # Only day-2 rows remain.
        assert len(selected) == 2
        on_second_day = selected["trade_date"] == "20240102"
        assert all(on_second_day)

    def test_column_order_preserved(self):
        """The output keeps the input's column order."""

        def above_15(frame: pl.DataFrame) -> pl.Series:
            return frame["close"] > 15

        frame = pl.DataFrame(
            {
                "ts_code": ["000001.SZ", "000002.SZ", "000003.SZ"],
                "trade_date": ["20240101", "20240101", "20240101"],
                "open": [9.0, 19.0, 29.0],
                "high": [11.0, 21.0, 31.0],
                "low": [8.0, 18.0, 28.0],
                "close": [10.0, 20.0, 30.0],
                "volume": [1000, 2000, 3000],
            }
        )
        pool = StockPoolManager(filter_func=above_15)

        selected = pool.filter_and_select_daily(frame)

        # Column order must be identical, position by position.
        expected_order = list(frame.columns)
        actual_order = list(selected.columns)
        assert actual_order == expected_order
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Propagate pytest's exit status so that running this file as a script
    # reports test failures to the shell / CI instead of always exiting 0.
    raise SystemExit(pytest.main([__file__, "-v"]))
|
||||||
Reference in New Issue
Block a user