feat(factorminer): 支持多股票池因子评估

- 新增 StockPoolRegistry 与默认股票池定义(all/small_cap 等)
- RalphLoop 与 ValidationPipeline 支持按多股票池计算 IC
- Factor 序列化新增 pool_metrics 字段
- LocalFactorEvaluator 增加 evaluate_returns_by_pool 方法
- 主入口集成股票池配置与 ST 预过滤
This commit is contained in:
2026-04-12 02:26:29 +08:00
parent 613223edd6
commit 521609e46a
10 changed files with 1806 additions and 27 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -42,6 +42,7 @@ class Factor:
batch_number: int # Which mining batch admitted this factor
admission_date: str = ""
signals: Optional[np.ndarray] = field(default=None, repr=False) # (M, T)
pool_metrics: Dict[str, dict] = field(default_factory=dict)
research_metrics: dict = field(default_factory=dict)
provenance: dict = field(default_factory=dict)
metadata: dict = field(default_factory=dict)
@@ -50,6 +51,20 @@ class Factor:
if not self.admission_date:
self.admission_date = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
@staticmethod
def _sanitize_pool_metrics(pool_metrics: Dict[str, dict]) -> Dict[str, dict]:
    """Return a JSON-serializable copy of per-pool metric dicts.

    Drops the bulky per-period "ic_series" entries and converts numpy
    arrays into plain Python lists; every other value passes through
    unchanged.
    """
    cleaned: Dict[str, dict] = {}
    for pool_name, stats in pool_metrics.items():
        cleaned[pool_name] = {
            key: (value.tolist() if isinstance(value, np.ndarray) else value)
            for key, value in stats.items()
            if key != "ic_series"
        }
    return cleaned
def to_dict(self) -> dict:
"""Serialize to a JSON-compatible dictionary (excludes signals)."""
return {
@@ -63,6 +78,7 @@ class Factor:
"max_correlation": self.max_correlation,
"batch_number": self.batch_number,
"admission_date": self.admission_date,
"pool_metrics": self._sanitize_pool_metrics(self.pool_metrics),
"research_metrics": self.research_metrics,
"provenance": self.provenance,
"metadata": self.metadata,
@@ -82,6 +98,7 @@ class Factor:
max_correlation=d["max_correlation"],
batch_number=d["batch_number"],
admission_date=d.get("admission_date", ""),
pool_metrics=d.get("pool_metrics", {}),
research_metrics=d.get("research_metrics", {}),
provenance=d.get("provenance", {}),
metadata=d.get("metadata", {}),

View File

@@ -18,7 +18,7 @@ import logging
import re
import time
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Tuple
from typing import Any, Dict, List, Optional, Set, Tuple, Union
import numpy as np
@@ -229,7 +229,7 @@ class HelixLoop(RalphLoop):
def __init__(
self,
config: Any,
returns: np.ndarray,
returns: Union[np.ndarray, Dict[str, np.ndarray]],
llm_provider: Optional[LLMProvider] = None,
memory: Optional[ExperienceMemory] = None,
library: Optional[FactorLibrary] = None,

View File

@@ -257,9 +257,13 @@ class ValidationPipeline:
evaluator: Any = None,
) -> None:
self.data_tensor = data_tensor # (M, T, F)
self.returns = returns # (M, T)
if isinstance(returns, dict):
self.returns = returns.get("all", next(iter(returns.values())))
self.target_panels = returns
else:
self.returns = returns # (M, T)
self.target_panels = target_panels or {"paper": returns}
self.evaluator = evaluator
self.target_panels = target_panels or {"paper": returns}
self.target_horizons = target_horizons or {"paper": 1}
self.library = library or FactorLibrary(
correlation_threshold=0.5,
@@ -353,20 +357,18 @@ class ValidationPipeline:
result.stage_passed = 0
return result
# Full IC statistics on all assets
stats = compute_factor_stats(signals, self.returns)
result.ic_mean = stats["ic_abs_mean"]
result.icir = stats["icir"]
result.ic_win_rate = stats["ic_win_rate"]
result.target_stats = {"paper": stats}
# Full IC statistics on all target panels (stock pools)
all_stats: Dict[str, dict] = {}
for panel_name, panel_returns in self.target_panels.items():
all_stats[panel_name] = compute_factor_stats(signals, panel_returns)
if self.target_panels:
for target_name, target_returns in self.target_panels.items():
if target_name == "paper":
continue
result.target_stats[target_name] = compute_factor_stats(
signals, target_returns
)
best_panel_name = max(all_stats, key=lambda k: all_stats[k]["ic_abs_mean"])
best_stats = all_stats[best_panel_name]
result.ic_mean = best_stats["ic_abs_mean"]
result.icir = best_stats["icir"]
result.ic_win_rate = best_stats["ic_win_rate"]
result.target_stats = all_stats
score_vector_obj = None
if self._research_enabled():
@@ -737,7 +739,7 @@ class RalphLoop:
def __init__(
self,
config: Any,
returns: np.ndarray,
returns: Union[np.ndarray, Dict[str, np.ndarray]],
llm_provider: Optional[LLMProvider] = None,
memory: Optional[ExperienceMemory] = None,
library: Optional[FactorLibrary] = None,
@@ -767,7 +769,12 @@ class RalphLoop:
the legacy data_tensor path.
"""
self.config = config
self.returns = returns
if isinstance(returns, dict):
self.returns = returns.get("all", next(iter(returns.values())))
self.target_panels = returns
else:
self.returns = returns
self.target_panels = getattr(config, "target_panels", None)
self.evaluator = evaluator
self.checkpoint_interval = checkpoint_interval
@@ -783,9 +790,9 @@ class RalphLoop:
prompt_builder=PromptBuilder(),
)
self.pipeline = ValidationPipeline(
data_tensor=np.empty((returns.shape[0], returns.shape[1], 1)),
returns=returns,
target_panels=getattr(config, "target_panels", None),
data_tensor=np.empty((self.returns.shape[0], self.returns.shape[1], 1)),
returns=self.returns,
target_panels=self.target_panels,
target_horizons=getattr(config, "target_horizons", None),
library=self.library,
ic_threshold=getattr(config, "ic_threshold", 0.04),
@@ -1082,6 +1089,14 @@ class RalphLoop:
# Handle replacement
if result.replaced is not None:
old_id = result.replaced
if result.target_stats:
target_stats_clean = {}
for pool_name, stats in result.target_stats.items():
target_stats_clean[pool_name] = {
k: v for k, v in stats.items() if k != "ic_series"
}
else:
target_stats_clean = {}
new_factor = Factor(
id=0, # Will be reassigned by library
name=result.factor_name,
@@ -1093,6 +1108,7 @@ class RalphLoop:
max_correlation=result.max_correlation,
batch_number=self.iteration,
signals=result.signals,
pool_metrics=target_stats_clean,
research_metrics=result.score_vector or {},
)
try:
@@ -1110,6 +1126,14 @@ class RalphLoop:
)
else:
# Direct admission
if result.target_stats:
target_stats_clean = {}
for pool_name, stats in result.target_stats.items():
target_stats_clean[pool_name] = {
k: v for k, v in stats.items() if k != "ic_series"
}
else:
target_stats_clean = {}
factor = Factor(
id=0, # Will be reassigned
name=result.factor_name,
@@ -1121,6 +1145,7 @@ class RalphLoop:
max_correlation=result.max_correlation,
batch_number=self.iteration,
signals=result.signals,
pool_metrics=target_stats_clean,
research_metrics=result.score_vector or {},
)
self.library.admit_factor(factor)
@@ -1435,7 +1460,7 @@ class RalphLoop:
cls,
checkpoint_path: str,
config: Any,
returns: np.ndarray,
returns: Union[np.ndarray, Dict[str, np.ndarray]],
llm_provider: Optional[LLMProvider] = None,
**kwargs: Any,
) -> "RalphLoop":

View File

@@ -19,6 +19,7 @@ import numpy as np
import polars as pl
from src.factors import FactorEngine
from src.factorminer.evaluation.stock_pool_registry import StockPoolRegistry
class LocalFactorEvaluator:
@@ -40,6 +41,7 @@ class LocalFactorEvaluator:
start_date: str,
end_date: str,
stock_codes: Optional[List[str]] = None,
stock_pool_registry: Optional[StockPoolRegistry] = None,
) -> None:
"""初始化评估器。
@@ -47,11 +49,14 @@ class LocalFactorEvaluator:
start_date: 计算开始日期YYYYMMDD 格式
end_date: 计算结束日期YYYYMMDD 格式
stock_codes: 可选的股票代码列表None 表示全量
stock_pool_registry: 股票池注册表,用于多股票池评估
"""
self.start_date = start_date
self.end_date = end_date
self.stock_codes = stock_codes
self.engine = FactorEngine()
self.stock_pool_registry = stock_pool_registry
self._asset_codes: Optional[List[str]] = None
def evaluate(
self,
@@ -166,6 +171,76 @@ class LocalFactorEvaluator:
print(f"[ERROR] 计算收益率矩阵失败: {e}")
raise
def get_asset_codes(self) -> List[str]:
    """Return the asset codes captured by the latest evaluate/evaluate_returns call.

    Returns:
        Alphabetically sorted list of stock codes (length M).

    Raises:
        RuntimeError: if no evaluation has populated the codes yet.
    """
    codes = self._asset_codes
    if codes is None:
        raise RuntimeError("请先调用 evaluate() 或 evaluate_returns()")
    return codes
def _get_metadata_df(self, columns: List[str]) -> Optional[pl.DataFrame]:
    """Fetch the metadata needed to build stock pools (latest cross-section at end_date).

    Args:
        columns: extra column names required by the pool filters.

    Returns:
        DataFrame with ts_code plus `columns`, or None when nothing is
        requested or loading fails.
    """
    if not columns:
        # Nothing to fetch; callers treat None as "no metadata available".
        return None
    try:
        # Single-day slice (start == end == end_date) of the daily_basic table.
        df = self.engine.router._load_table(
            table_name="daily_basic",
            columns=columns,
            start_date=self.end_date,
            end_date=self.end_date,
            stock_codes=self.stock_codes,
        )
        if "trade_date" in df.columns:
            # Newest first so the dedup below keeps the latest row per code.
            df = df.sort("trade_date", descending=True)
        # One row per stock; maintain_order preserves the sorted precedence.
        df = df.unique(subset=["ts_code"], maintain_order=True)
        return df.select(["ts_code"] + columns)
    except Exception as exc:
        # Best-effort: pool building degrades gracefully without metadata.
        print(f"[WARN] 拉取股票池元数据失败: {exc}")
        return None
def evaluate_returns_by_pool(self, periods: int = 1) -> Dict[str, np.ndarray]:
    """Compute the return matrix for every registered stock pool.

    Computes the whole-market return matrix once, then derives each
    sub-pool matrix by masking: assets outside a pool have their returns
    set to NaN across all timestamps.

    Args:
        periods: forward-return horizon in days (default 1).

    Returns:
        {pool_name: (M, T) np.ndarray} including the "all" pool.
    """
    full_returns = self.evaluate_returns(periods=periods)
    pooled: Dict[str, np.ndarray] = {"all": full_returns}
    registry = self.stock_pool_registry
    if registry is None:
        # No registry configured: only the whole-market panel is produced.
        return pooled
    registry.build_masks(
        self.get_asset_codes(),
        metadata_df=self._get_metadata_df(registry.get_required_columns()),
    )
    for pool_name in registry.get_pool_names():
        if pool_name == "all":
            continue  # already present as the unmasked matrix
        member_mask = registry.masks[pool_name]
        masked = np.full_like(full_returns, np.nan, dtype=np.float64)
        masked[member_mask, :] = full_returns[member_mask, :]
        pooled[pool_name] = masked
    return pooled
def _pivot_to_matrix(
self,
df: pl.DataFrame,
@@ -191,6 +266,8 @@ class LocalFactorEvaluator:
# 获取时间戳和股票代码的唯一值(已排序)
timestamps = df["trade_date"].unique().sort()
asset_codes = df["ts_code"].unique().sort()
if self._asset_codes is None:
self._asset_codes = asset_codes.to_list()
n_assets = len(asset_codes)
n_times = len(timestamps)

View File

@@ -0,0 +1,150 @@
"""股票池注册表:支持配置化股票池筛选与掩码生成。"""
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Optional
import numpy as np
import polars as pl
@dataclass
class StockPoolDefinition:
    """Stock pool definition, following the filter pattern of experiment/common.py."""

    name: str  # pool identifier, e.g. "small_cap"
    # Predicate over a DataFrame holding ts_code (+ required_columns);
    # returns a boolean Series or Expr marking pool membership.
    filter_func: Callable[[pl.DataFrame], pl.Series | pl.Expr]
    required_columns: List[str] = field(default_factory=list)  # extra metadata columns the filter needs
    description: str = ""  # human-readable label
class StockPoolRegistry:
    """Manage multiple stock pool definitions and build boolean masks over an (M,) asset list."""

    def __init__(
        self,
        definitions: Optional[List[StockPoolDefinition]] = None,
        pre_filters: Optional[
            List[Callable[[pl.DataFrame], pl.Series | pl.Expr]]
        ] = None,
    ) -> None:
        # Pre-filters are ANDed into every pool mask (e.g. an ST-stock screen).
        self.pre_filters: List[Callable[[pl.DataFrame], pl.Series | pl.Expr]] = (
            pre_filters or []
        )
        # Pool name -> {"filter_func": ..., "required_columns": [...]}.
        self.pools: Dict[str, dict] = {}
        # Pool name -> (M,) bool mask; populated by build_masks().
        self.masks: Dict[str, np.ndarray] = {}
        self._resolved = False  # True once build_masks() has run successfully
        if definitions:
            for d in definitions:
                self.add_pool(d.name, d.filter_func, d.required_columns)

    @classmethod
    def from_definitions(
        cls,
        definitions: List[StockPoolDefinition],
        pre_filters: Optional[
            List[Callable[[pl.DataFrame], pl.Series | pl.Expr]]
        ] = None,
    ) -> "StockPoolRegistry":
        """Convenience constructor: build a registry from a definition list."""
        return cls(definitions=definitions, pre_filters=pre_filters)

    def add_pool(
        self,
        name: str,
        filter_func: Callable[[pl.DataFrame], pl.Series | pl.Expr],
        required_columns: Optional[List[str]] = None,
    ) -> None:
        """Register one stock pool.

        Args:
            name: pool name, e.g. "small_cap".
            filter_func: receives a DataFrame containing ts_code and
                required_columns; returns a boolean Series or Expr.
            required_columns: extra column names filter_func needs.
        """
        self.pools[name] = {
            "filter_func": filter_func,
            "required_columns": required_columns or [],
        }
        self._resolved = False  # existing masks are stale until rebuilt

    @staticmethod
    def _eval_filter_result(
        df: pl.DataFrame,
        result: pl.Series | pl.Expr,
        context: str,
    ) -> np.ndarray:
        """Normalize a filter_func return value into a numpy bool array."""
        if isinstance(result, pl.Series):
            mask_series = result
        elif isinstance(result, pl.Expr):
            # Expressions are evaluated against the metadata frame.
            mask_series = df.select(result.alias("_mask")).to_series()
        else:
            raise TypeError(
                f"{context} 的 filter_func 必须返回 pl.Series 或 pl.Expr"
                f"实际返回 {type(result)}"
            )
        return mask_series.to_numpy().astype(bool)

    def build_masks(
        self,
        asset_codes: List[str],
        metadata_df: Optional[pl.DataFrame] = None,
    ) -> Dict[str, np.ndarray]:
        """Build boolean masks for every registered pool.

        Args:
            asset_codes: stock codes in matrix row order, length M.
            metadata_df: optional frame with the extra columns, keyed by ts_code.

        Returns:
            {pool_name: bool array of shape (M,)}.
        """
        base_df = pl.DataFrame({"ts_code": asset_codes})
        df = base_df
        if metadata_df is not None:
            # Left join preserves matrix row order; missing metadata becomes null.
            df = base_df.join(metadata_df, on="ts_code", how="left")
        base_mask = np.ones(len(asset_codes), dtype=bool)
        for pre_filter in self.pre_filters:
            result = pre_filter(df)
            base_mask &= self._eval_filter_result(df, result, "pre_filter")
        self.masks = {}
        for name, cfg in self.pools.items():
            result = cfg["filter_func"](df)
            pool_mask = self._eval_filter_result(df, result, f"股票池 '{name}'")
            # Every pool mask is intersected with the shared pre-filter mask.
            self.masks[name] = base_mask & pool_mask
        self._resolved = True
        return self.masks

    def filter_signals(self, signals: np.ndarray, pool_name: str) -> np.ndarray:
        """Slice a signal matrix down to one pool's members.

        Args:
            signals: (M, T) signal matrix.
            pool_name: name of a registered pool.

        Returns:
            (M_pool, T) sub-matrix.

        Raises:
            RuntimeError: if build_masks() has not been called yet.
            KeyError: if pool_name is unknown.
        """
        if not self._resolved:
            raise RuntimeError("请先调用 build_masks 构建掩码")
        mask = self.masks.get(pool_name)
        if mask is None:
            raise KeyError(f"未知的股票池: {pool_name}")
        return signals[mask, :]

    def get_pool_names(self) -> List[str]:
        """Return the registered pool names (insertion order)."""
        return list(self.pools.keys())

    def get_required_columns(self) -> List[str]:
        """Return the sorted union of columns needed by pre-filters and pools."""
        cols: set[str] = set()
        for pre_filter in self.pre_filters:
            # Plain functions usually lack this attribute; default to empty.
            cols.update(getattr(pre_filter, "required_columns", []))
        for cfg in self.pools.values():
            cols.update(cfg["required_columns"])
        return sorted(cols)

View File

@@ -11,6 +11,8 @@ from typing import Any, Optional
import numpy as np
import polars as pl
from src.config.settings import get_settings
from src.factorminer.agent.llm_interface import create_provider, AnthropicProvider
from src.factorminer.core.config import MiningConfig as CoreMiningConfig
@@ -19,7 +21,10 @@ from src.factorminer.core.ralph_loop import RalphLoop
from src.factorminer.core.helix_loop import HelixLoop
from src.factorminer.evaluation.local_engine import LocalFactorEvaluator
from src.factorminer.evaluation.significance import SignificanceConfig
from src.factorminer.evaluation.stock_pool_registry import StockPoolRegistry
from src.factorminer.pool_definitions import get_default_stock_pool_definitions
from src.factorminer.utils.config import load_config
from src.training.components.filters import STFilter
RUN_CONFIG: dict = {
# 全局开关
@@ -35,9 +40,16 @@ RUN_CONFIG: dict = {
"stock_codes": None, # 可选股票列表None 表示全量
# 种子库
"seed_paper_library": True, # 是否预加载 110 Paper Factors 作为种子库
# 股票池配置(支持多股票池评估)
"stock_pools": {
"enabled": True, # 是否启用多股票池
"provider": "default", # "default" 使用 pool_definitions.py 中的默认定义
"use_st_filter": True, # 是否在构建股票池前过滤 ST 股票
"pools": ["all", "small_cap"], # 要启用的股票池名称列表None 表示全部启用
},
# 挖掘配置(覆盖 default.yaml 中的同名项)
"mining": {
"max_iterations": 3,
"max_iterations": 1,
"target_library_size": 10,
"correlation_threshold": 0.50,
"ic_threshold": 0.04,
@@ -139,6 +151,54 @@ def _build_core_mining_config(run_cfg: dict) -> CoreMiningConfig:
return cfg
def _build_st_pre_filter(st_filter: STFilter, end_date: str):
    """Wrap an STFilter as a pre-filter callable accepted by StockPoolRegistry."""

    def _pre_filter(df: pl.DataFrame) -> pl.Series:
        try:
            # STFilter expects a trade_date column; synthesize one if absent.
            if "trade_date" not in df.columns:
                df = df.with_columns(pl.lit(end_date).alias("trade_date"))
            filtered = st_filter.filter(df)
            # Keep the rows whose ts_code survived the ST screen.
            return df["ts_code"].is_in(filtered["ts_code"])
        except Exception as exc:
            # Best-effort: on failure keep all stocks rather than break the run.
            print(f"[WARN] ST pre-filter 失败: {exc}")
            return pl.Series([True] * len(df))

    return _pre_filter
def _build_stock_pool_registry(
    run_cfg: dict,
    st_filter: Optional[STFilter] = None,
    end_date: str = "20231231",
) -> Optional[StockPoolRegistry]:
    """Build the stock pool registry from RUN_CONFIG.

    Returns None when the "stock_pools" section is disabled; otherwise a
    registry holding the (optionally filtered) default pool definitions,
    with an ST pre-filter attached when configured.
    """
    pool_cfg = run_cfg.get("stock_pools", {})
    if not pool_cfg.get("enabled", False):
        return None

    pre_filters = []
    if st_filter is not None and pool_cfg.get("use_st_filter", False):
        pre_filters.append(_build_st_pre_filter(st_filter, end_date))

    provider = pool_cfg.get("provider", "default")
    if provider != "default":
        raise ValueError(f"不支持的股票池 provider: {provider}")

    definitions = get_default_stock_pool_definitions()
    wanted = pool_cfg.get("pools")
    if wanted is not None:
        # Keep only the requested pools; warn about names with no definition.
        definitions = [d for d in definitions if d.name in wanted]
        if not definitions:
            raise ValueError(
                f"股票池配置错误: pools={wanted} 未匹配到任何默认股票池定义"
            )
        unknown = set(wanted) - {d.name for d in definitions}
        if unknown:
            print(f"[WARN] 以下股票池未在默认定义中找到,已忽略: {sorted(unknown)}")
    return StockPoolRegistry.from_definitions(definitions, pre_filters=pre_filters)
def _build_helix_kwargs(run_cfg: dict) -> dict:
"""从 RUN_CONFIG 构建 HelixLoop 需要的 Phase 2 扩展配置。"""
helix = run_cfg.get("helix", {})
@@ -215,15 +275,34 @@ def main(config: dict | None = None) -> None:
end_date = run_cfg.get("end_date", "20201231")
stock_codes = run_cfg.get("stock_codes")
# 4.1 创建 evaluator先实例化才能拿到 router
evaluator = LocalFactorEvaluator(
start_date=start_date,
end_date=end_date,
stock_codes=stock_codes,
)
returns = evaluator.evaluate_returns(periods=1)
print(
f"[main] 本地数据范围: {start_date} ~ {end_date}, returns shape: {returns.shape}"
# 4.2 构建 STFilter 和股票池注册表
st_filter = STFilter(data_router=evaluator.engine.router)
stock_pool_registry = _build_stock_pool_registry(
run_cfg, st_filter=st_filter, end_date=end_date
)
evaluator.stock_pool_registry = stock_pool_registry
if stock_pool_registry is not None:
returns = evaluator.evaluate_returns_by_pool(periods=1)
for pool_name, pool_ret in returns.items():
valid_assets = int((~np.isnan(pool_ret).all(axis=1)).sum())
print(
f"[main] 股票池 {pool_name}: {valid_assets} 只资产, "
f"returns shape: {pool_ret.shape}"
)
else:
returns = evaluator.evaluate_returns(periods=1)
print(
f"[main] 本地数据范围: {start_date} ~ {end_date}, "
f"returns shape: {returns.shape}"
)
# ------------------------------------------------------------------
# 5. 构建 MiningConfig

View File

@@ -0,0 +1,171 @@
"""用户可配置的股票池定义文件。
参考 `src/experiment/common.py` 中的 `stock_pool_filter` 模式,
支持通过编写 Polars filter 函数来自定义股票池。
"""
from typing import Callable, List
import polars as pl
from src.factorminer.evaluation.stock_pool_registry import StockPoolDefinition, StockPoolRegistry
# =============================================================================
# 股票池大小配置
# =============================================================================
SMALL_CAP_N = 1000 # 小微盘股票池默认选取的股票数量
# =============================================================================
# 股票池筛选函数(与 experiment/common.py 风格一致)
# =============================================================================
def all_market_filter(df: pl.DataFrame) -> pl.Series:
    """Whole-market pool: every asset is selected."""
    return pl.Series([True] * df.height)
def growth_board_filter(df: pl.DataFrame) -> pl.Series:
    """ChiNext pool: codes prefixed with "300"."""
    codes = df["ts_code"]
    return codes.str.starts_with("300")
def star_board_filter(df: pl.DataFrame) -> pl.Series:
    """STAR Market pool: codes prefixed with "688"."""
    codes = df["ts_code"]
    return codes.str.starts_with("688")
def bse_board_filter(df: pl.DataFrame) -> pl.Series:
    """BSE pool: codes prefixed with "8" or "4"."""
    codes = df["ts_code"]
    starts_8 = codes.str.starts_with("8")
    starts_4 = codes.str.starts_with("4")
    return starts_8 | starts_4
def main_board_filter(df: pl.DataFrame) -> pl.Series:
    """Main-board pool: everything not on the ChiNext, STAR, or BSE boards."""
    codes = df["ts_code"]
    # De Morgan: NOT on any excluded board == NOT (union of excluded prefixes).
    non_main = (
        codes.str.starts_with("300")
        | codes.str.starts_with("688")
        | codes.str.starts_with("8")
        | codes.str.starts_with("4")
    )
    return ~non_main
def small_cap_filter(df: pl.DataFrame, n_stocks: int = SMALL_CAP_N) -> pl.Series:
    """Small/micro-cap pool: the n_stocks names with the lowest float market cap.

    Args:
        df: frame with ts_code and circ_mv columns.
        n_stocks: number of names to pick, default SMALL_CAP_N.

    Returns:
        Boolean Series marking membership in the small-cap pool.
    """
    if "circ_mv" not in df.columns:
        # No market-cap data: degrade safely by selecting nothing.
        return pl.Series([False] * len(df))
    picked = df.sort("circ_mv").head(min(n_stocks, len(df)))["ts_code"]
    return df["ts_code"].is_in(picked.implode())
def main_board_small_cap_filter(
    df: pl.DataFrame, n_stocks: int = SMALL_CAP_N
) -> pl.Series:
    """Main-board small-cap pool: the n_stocks smallest-cap main-board names."""
    board_df = df.filter(main_board_filter(df))
    if "circ_mv" not in df.columns or len(board_df) == 0:
        # Missing market-cap data or no main-board names: select nothing.
        return pl.Series([False] * len(df))
    picked = board_df.sort("circ_mv").head(min(n_stocks, len(board_df)))["ts_code"]
    return df["ts_code"].is_in(picked.implode())
def common_small_cap_filter(df: pl.DataFrame, n_stocks: int = SMALL_CAP_N) -> pl.Series:
    """Small/micro-cap screen mirroring `experiment/common.py`.

    Selection:
        1. exclude ChiNext (codes starting with 30)
        2. exclude STAR Market (codes starting with 68)
        3. exclude BSE (codes starting with 8, 9 or 4)
        4. keep the n_stocks names with the smallest float market cap

    Args:
        df: frame that must contain ts_code and circ_mv columns.
        n_stocks: number of names to pick, default SMALL_CAP_N.

    Returns:
        Boolean Series marking which stocks are selected.
    """
    if "circ_mv" not in df.columns:
        # Degrade safely when market-cap data is unavailable, consistent
        # with small_cap_filter (previously this raised in sort()).
        return pl.Series([False] * len(df))
    code_filter = (
        ~df["ts_code"].str.starts_with("30")  # exclude ChiNext
        & ~df["ts_code"].str.starts_with("68")  # exclude STAR Market
        & ~df["ts_code"].str.starts_with("8")  # exclude BSE
        & ~df["ts_code"].str.starts_with("9")  # exclude BSE
        & ~df["ts_code"].str.starts_with("4")  # exclude BSE
    )
    valid_df = df.filter(code_filter)
    n = min(n_stocks, len(valid_df))
    small_codes = valid_df.sort("circ_mv").head(n)["ts_code"]
    return df["ts_code"].is_in(small_codes.implode())
# =============================================================================
# 默认股票池定义列表
# =============================================================================
def get_default_stock_pool_definitions() -> List[StockPoolDefinition]:
    """Return the built-in stock pool definitions.

    Covers the whole market, each individual board, and two small-cap
    variants that additionally require the float market cap column.
    """
    # (name, filter function, extra columns, description)
    specs = [
        ("all", all_market_filter, [], "全市场"),
        ("growth_board", growth_board_filter, [], "创业板"),
        ("star_board", star_board_filter, [], "科创板"),
        ("bse_board", bse_board_filter, [], "北交所"),
        ("main_board", main_board_filter, [], "主板"),
        ("small_cap", small_cap_filter, ["circ_mv"], "小微盘"),
        ("main_board_small_cap", main_board_small_cap_filter, ["circ_mv"], "主板小微盘"),
    ]
    return [
        StockPoolDefinition(
            name=name,
            filter_func=func,
            required_columns=cols,
            description=desc,
        )
        for name, func, cols, desc in specs
    ]
def get_default_stock_pools(
    pre_filters: List[Callable[[pl.DataFrame], pl.Series | pl.Expr]] | None = None,
) -> StockPoolRegistry:
    """Return a registry pre-loaded with the default stock pool definitions."""
    definitions = get_default_stock_pool_definitions()
    return StockPoolRegistry.from_definitions(definitions, pre_filters=pre_filters)

View File

@@ -21,6 +21,7 @@ from src.factorminer.evaluation.metrics import (
# Helpers
# ---------------------------------------------------------------------------
@pytest.fixture
def rng():
return np.random.default_rng(123)
@@ -58,6 +59,7 @@ def known_quintile_signal(rng):
# IC computation
# ---------------------------------------------------------------------------
class TestIC:
"""Test Information Coefficient computation."""
@@ -106,6 +108,7 @@ class TestIC:
# ICIR computation
# ---------------------------------------------------------------------------
class TestICIR:
"""Test ICIR = mean(IC) / std(IC)."""
@@ -142,6 +145,7 @@ class TestICIR:
# IC-derived statistics
# ---------------------------------------------------------------------------
class TestICStats:
"""Test IC mean and win rate."""
@@ -170,6 +174,7 @@ class TestICStats:
# Pairwise correlation
# ---------------------------------------------------------------------------
class TestPairwiseCorrelation:
"""Test pairwise cross-sectional correlation."""
@@ -207,6 +212,7 @@ class TestPairwiseCorrelation:
# Quintile returns
# ---------------------------------------------------------------------------
class TestQuintileReturns:
"""Test quintile return computation."""
@@ -243,6 +249,7 @@ class TestQuintileReturns:
# Turnover
# ---------------------------------------------------------------------------
class TestTurnover:
"""Test portfolio turnover computation."""
@@ -263,6 +270,7 @@ class TestTurnover:
# Comprehensive factor stats
# ---------------------------------------------------------------------------
class TestFactorStats:
"""Test the compute_factor_stats wrapper."""
@@ -285,3 +293,18 @@ class TestFactorStats:
returns = rng.normal(0, 0.01, (M, T))
stats = compute_factor_stats(signals, returns)
assert stats["ic_series"].shape == (T,)
class TestFactorStatsPools:
    """Test compute_factor_stats under multi-stock-pool NaN scenarios."""

    def test_pool_with_nans(self, rng):
        """All-NaN rows (assets outside a pool) must not break the IC stats."""
        M, T = 40, 30
        signals = rng.normal(0, 1, (M, T))
        returns = rng.normal(0, 0.01, (M, T))
        # Simulate a sub-pool: some assets are NaN across all timestamps.
        signals[30:, :] = np.nan
        returns[30:, :] = np.nan
        stats = compute_factor_stats(signals, returns)
        assert stats["n_periods"] == T
        assert np.isfinite(stats["ic_mean"])

View File

@@ -0,0 +1,72 @@
import numpy as np
import polars as pl
import pytest
from src.factorminer.evaluation.stock_pool_registry import StockPoolRegistry
class TestStockPoolRegistry:
    """Unit tests for StockPoolRegistry mask building and signal filtering."""

    def test_add_and_get_pool_names(self):
        """Pools are listed in registration order."""
        registry = StockPoolRegistry()
        registry.add_pool("all", lambda df: pl.Series([True] * len(df)))
        registry.add_pool("growth", lambda df: df["ts_code"].str.starts_with("300"))
        assert registry.get_pool_names() == ["all", "growth"]

    def test_build_masks(self):
        """Each mask reflects its pool's filter over the asset list."""
        registry = StockPoolRegistry()
        registry.add_pool("all", lambda df: pl.Series([True] * len(df)))
        registry.add_pool("growth", lambda df: df["ts_code"].str.starts_with("300"))
        asset_codes = ["000001.SZ", "300001.SZ", "688001.SH"]
        masks = registry.build_masks(asset_codes)
        assert masks["all"].sum() == 3
        assert masks["growth"].sum() == 1
        assert masks["growth"][1]  # only the 300* code is in the growth pool

    def test_filter_signals(self):
        """filter_signals keeps only the rows selected by the pool mask."""
        registry = StockPoolRegistry()
        registry.add_pool("growth", lambda df: df["ts_code"].str.starts_with("300"))
        asset_codes = ["000001.SZ", "300001.SZ", "688001.SH"]
        registry.build_masks(asset_codes)
        signals = np.array([[1, 2], [3, 4], [5, 6]], dtype=np.float64)
        filtered = registry.filter_signals(signals, "growth")
        assert filtered.shape == (1, 2)
        np.testing.assert_array_equal(filtered, [[3, 4]])

    def test_build_masks_with_metadata(self):
        """Metadata columns are joined in so filters can use them."""
        registry = StockPoolRegistry()
        registry.add_pool(
            "small_cap",
            lambda df: df["ts_code"].is_in(
                df.sort("total_mv").head(2)["ts_code"].implode()
            ),
            required_columns=["total_mv"],
        )
        asset_codes = ["A", "B", "C"]
        metadata = pl.DataFrame(
            {
                "ts_code": ["A", "B", "C"],
                "total_mv": [300.0, 100.0, 200.0],
            }
        )
        masks = registry.build_masks(asset_codes, metadata_df=metadata)
        assert masks["small_cap"].sum() == 2
        assert masks["small_cap"][1] and masks["small_cap"][2]  # the two smallest caps

    def test_expr_fallback(self):
        """A filter may return a pl.Expr instead of a Series."""
        registry = StockPoolRegistry()
        registry.add_pool(
            "expr_pool", lambda df: pl.col("ts_code").str.starts_with("300")
        )
        asset_codes = ["000001.SZ", "300001.SZ"]
        masks = registry.build_masks(asset_codes)
        assert masks["expr_pool"].sum() == 1

    def test_pre_filters(self):
        """Pre-filters are ANDed into every pool mask."""
        registry = StockPoolRegistry(
            pre_filters=[lambda df: df["ts_code"].str.starts_with("0")]
        )
        registry.add_pool("all", lambda df: pl.Series([True] * len(df)))
        asset_codes = ["000001.SZ", "300001.SZ", "688001.SH"]
        masks = registry.build_masks(asset_codes)
        assert masks["all"].sum() == 1
        assert masks["all"][0]