feat(factorminer): 支持多股票池因子评估
- 新增 StockPoolRegistry 与默认股票池定义(all/small_cap 等) - RalphLoop 与 ValidationPipeline 支持按多股票池计算 IC - Factor 序列化新增 pool_metrics 字段 - LocalFactorEvaluator 增加 evaluate_returns_by_pool 方法 - 主入口集成股票池配置与 ST 预过滤
This commit is contained in:
1165
docs/plans/2026-04-11-factorminer-multi-stock-pool.md
Normal file
1165
docs/plans/2026-04-11-factorminer-multi-stock-pool.md
Normal file
File diff suppressed because it is too large
Load Diff
@@ -42,6 +42,7 @@ class Factor:
|
||||
batch_number: int # Which mining batch admitted this factor
|
||||
admission_date: str = ""
|
||||
signals: Optional[np.ndarray] = field(default=None, repr=False) # (M, T)
|
||||
pool_metrics: Dict[str, dict] = field(default_factory=dict)
|
||||
research_metrics: dict = field(default_factory=dict)
|
||||
provenance: dict = field(default_factory=dict)
|
||||
metadata: dict = field(default_factory=dict)
|
||||
@@ -50,6 +51,20 @@ class Factor:
|
||||
if not self.admission_date:
|
||||
self.admission_date = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
@staticmethod
def _sanitize_pool_metrics(pool_metrics: Dict[str, dict]) -> Dict[str, dict]:
    """Return a JSON-friendly copy of per-pool metrics.

    The ``"ic_series"`` entry of every pool is dropped, numpy arrays are
    converted to (nested) lists and numpy scalars to the matching built-in
    Python scalar. All other values are copied through unchanged.

    Args:
        pool_metrics: Mapping ``pool name -> stats dict``. ``None`` or an
            empty mapping yields an empty result.

    Returns:
        A new ``{pool_name: {metric: value}}`` dictionary; the input is not
        modified.
    """
    sanitized: Dict[str, dict] = {}
    for pool_name, stats in (pool_metrics or {}).items():
        clean: dict = {}
        for key, value in stats.items():
            if key == "ic_series":
                # Per-period IC series are intentionally not persisted.
                continue
            if isinstance(value, np.ndarray):
                value = value.tolist()
            elif isinstance(value, np.generic):
                # np.float32 / np.int64 / np.bool_ are rejected by json.dumps,
                # so unwrap them into plain Python scalars.
                value = value.item()
            clean[key] = value
        sanitized[pool_name] = clean
    return sanitized
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""Serialize to a JSON-compatible dictionary (excludes signals)."""
|
||||
return {
|
||||
@@ -63,6 +78,7 @@ class Factor:
|
||||
"max_correlation": self.max_correlation,
|
||||
"batch_number": self.batch_number,
|
||||
"admission_date": self.admission_date,
|
||||
"pool_metrics": self._sanitize_pool_metrics(self.pool_metrics),
|
||||
"research_metrics": self.research_metrics,
|
||||
"provenance": self.provenance,
|
||||
"metadata": self.metadata,
|
||||
@@ -82,6 +98,7 @@ class Factor:
|
||||
max_correlation=d["max_correlation"],
|
||||
batch_number=d["batch_number"],
|
||||
admission_date=d.get("admission_date", ""),
|
||||
pool_metrics=d.get("pool_metrics", {}),
|
||||
research_metrics=d.get("research_metrics", {}),
|
||||
provenance=d.get("provenance", {}),
|
||||
metadata=d.get("metadata", {}),
|
||||
|
||||
@@ -18,7 +18,7 @@ import logging
|
||||
import re
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -229,7 +229,7 @@ class HelixLoop(RalphLoop):
|
||||
def __init__(
|
||||
self,
|
||||
config: Any,
|
||||
returns: np.ndarray,
|
||||
returns: Union[np.ndarray, Dict[str, np.ndarray]],
|
||||
llm_provider: Optional[LLMProvider] = None,
|
||||
memory: Optional[ExperienceMemory] = None,
|
||||
library: Optional[FactorLibrary] = None,
|
||||
|
||||
@@ -257,9 +257,13 @@ class ValidationPipeline:
|
||||
evaluator: Any = None,
|
||||
) -> None:
|
||||
self.data_tensor = data_tensor # (M, T, F)
|
||||
if isinstance(returns, dict):
|
||||
self.returns = returns.get("all", next(iter(returns.values())))
|
||||
self.target_panels = returns
|
||||
else:
|
||||
self.returns = returns # (M, T)
|
||||
self.evaluator = evaluator
|
||||
self.target_panels = target_panels or {"paper": returns}
|
||||
self.evaluator = evaluator
|
||||
self.target_horizons = target_horizons or {"paper": 1}
|
||||
self.library = library or FactorLibrary(
|
||||
correlation_threshold=0.5,
|
||||
@@ -353,20 +357,18 @@ class ValidationPipeline:
|
||||
result.stage_passed = 0
|
||||
return result
|
||||
|
||||
# Full IC statistics on all assets
|
||||
stats = compute_factor_stats(signals, self.returns)
|
||||
result.ic_mean = stats["ic_abs_mean"]
|
||||
result.icir = stats["icir"]
|
||||
result.ic_win_rate = stats["ic_win_rate"]
|
||||
result.target_stats = {"paper": stats}
|
||||
# Full IC statistics on all target panels (stock pools)
|
||||
all_stats: Dict[str, dict] = {}
|
||||
for panel_name, panel_returns in self.target_panels.items():
|
||||
all_stats[panel_name] = compute_factor_stats(signals, panel_returns)
|
||||
|
||||
if self.target_panels:
|
||||
for target_name, target_returns in self.target_panels.items():
|
||||
if target_name == "paper":
|
||||
continue
|
||||
result.target_stats[target_name] = compute_factor_stats(
|
||||
signals, target_returns
|
||||
)
|
||||
best_panel_name = max(all_stats, key=lambda k: all_stats[k]["ic_abs_mean"])
|
||||
best_stats = all_stats[best_panel_name]
|
||||
|
||||
result.ic_mean = best_stats["ic_abs_mean"]
|
||||
result.icir = best_stats["icir"]
|
||||
result.ic_win_rate = best_stats["ic_win_rate"]
|
||||
result.target_stats = all_stats
|
||||
|
||||
score_vector_obj = None
|
||||
if self._research_enabled():
|
||||
@@ -737,7 +739,7 @@ class RalphLoop:
|
||||
def __init__(
|
||||
self,
|
||||
config: Any,
|
||||
returns: np.ndarray,
|
||||
returns: Union[np.ndarray, Dict[str, np.ndarray]],
|
||||
llm_provider: Optional[LLMProvider] = None,
|
||||
memory: Optional[ExperienceMemory] = None,
|
||||
library: Optional[FactorLibrary] = None,
|
||||
@@ -767,7 +769,12 @@ class RalphLoop:
|
||||
the legacy data_tensor path.
|
||||
"""
|
||||
self.config = config
|
||||
if isinstance(returns, dict):
|
||||
self.returns = returns.get("all", next(iter(returns.values())))
|
||||
self.target_panels = returns
|
||||
else:
|
||||
self.returns = returns
|
||||
self.target_panels = getattr(config, "target_panels", None)
|
||||
self.evaluator = evaluator
|
||||
self.checkpoint_interval = checkpoint_interval
|
||||
|
||||
@@ -783,9 +790,9 @@ class RalphLoop:
|
||||
prompt_builder=PromptBuilder(),
|
||||
)
|
||||
self.pipeline = ValidationPipeline(
|
||||
data_tensor=np.empty((returns.shape[0], returns.shape[1], 1)),
|
||||
returns=returns,
|
||||
target_panels=getattr(config, "target_panels", None),
|
||||
data_tensor=np.empty((self.returns.shape[0], self.returns.shape[1], 1)),
|
||||
returns=self.returns,
|
||||
target_panels=self.target_panels,
|
||||
target_horizons=getattr(config, "target_horizons", None),
|
||||
library=self.library,
|
||||
ic_threshold=getattr(config, "ic_threshold", 0.04),
|
||||
@@ -1082,6 +1089,14 @@ class RalphLoop:
|
||||
# Handle replacement
|
||||
if result.replaced is not None:
|
||||
old_id = result.replaced
|
||||
if result.target_stats:
|
||||
target_stats_clean = {}
|
||||
for pool_name, stats in result.target_stats.items():
|
||||
target_stats_clean[pool_name] = {
|
||||
k: v for k, v in stats.items() if k != "ic_series"
|
||||
}
|
||||
else:
|
||||
target_stats_clean = {}
|
||||
new_factor = Factor(
|
||||
id=0, # Will be reassigned by library
|
||||
name=result.factor_name,
|
||||
@@ -1093,6 +1108,7 @@ class RalphLoop:
|
||||
max_correlation=result.max_correlation,
|
||||
batch_number=self.iteration,
|
||||
signals=result.signals,
|
||||
pool_metrics=target_stats_clean,
|
||||
research_metrics=result.score_vector or {},
|
||||
)
|
||||
try:
|
||||
@@ -1110,6 +1126,14 @@ class RalphLoop:
|
||||
)
|
||||
else:
|
||||
# Direct admission
|
||||
if result.target_stats:
|
||||
target_stats_clean = {}
|
||||
for pool_name, stats in result.target_stats.items():
|
||||
target_stats_clean[pool_name] = {
|
||||
k: v for k, v in stats.items() if k != "ic_series"
|
||||
}
|
||||
else:
|
||||
target_stats_clean = {}
|
||||
factor = Factor(
|
||||
id=0, # Will be reassigned
|
||||
name=result.factor_name,
|
||||
@@ -1121,6 +1145,7 @@ class RalphLoop:
|
||||
max_correlation=result.max_correlation,
|
||||
batch_number=self.iteration,
|
||||
signals=result.signals,
|
||||
pool_metrics=target_stats_clean,
|
||||
research_metrics=result.score_vector or {},
|
||||
)
|
||||
self.library.admit_factor(factor)
|
||||
@@ -1435,7 +1460,7 @@ class RalphLoop:
|
||||
cls,
|
||||
checkpoint_path: str,
|
||||
config: Any,
|
||||
returns: np.ndarray,
|
||||
returns: Union[np.ndarray, Dict[str, np.ndarray]],
|
||||
llm_provider: Optional[LLMProvider] = None,
|
||||
**kwargs: Any,
|
||||
) -> "RalphLoop":
|
||||
|
||||
@@ -19,6 +19,7 @@ import numpy as np
|
||||
import polars as pl
|
||||
|
||||
from src.factors import FactorEngine
|
||||
from src.factorminer.evaluation.stock_pool_registry import StockPoolRegistry
|
||||
|
||||
|
||||
class LocalFactorEvaluator:
|
||||
@@ -40,6 +41,7 @@ class LocalFactorEvaluator:
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
stock_codes: Optional[List[str]] = None,
|
||||
stock_pool_registry: Optional[StockPoolRegistry] = None,
|
||||
) -> None:
|
||||
"""初始化评估器。
|
||||
|
||||
@@ -47,11 +49,14 @@ class LocalFactorEvaluator:
|
||||
start_date: 计算开始日期,YYYYMMDD 格式
|
||||
end_date: 计算结束日期,YYYYMMDD 格式
|
||||
stock_codes: 可选的股票代码列表,None 表示全量
|
||||
stock_pool_registry: 股票池注册表,用于多股票池评估
|
||||
"""
|
||||
self.start_date = start_date
|
||||
self.end_date = end_date
|
||||
self.stock_codes = stock_codes
|
||||
self.engine = FactorEngine()
|
||||
self.stock_pool_registry = stock_pool_registry
|
||||
self._asset_codes: Optional[List[str]] = None
|
||||
|
||||
def evaluate(
|
||||
self,
|
||||
@@ -166,6 +171,76 @@ class LocalFactorEvaluator:
|
||||
print(f"[ERROR] 计算收益率矩阵失败: {e}")
|
||||
raise
|
||||
|
||||
def get_asset_codes(self) -> List[str]:
    """Return the asset codes cached by an earlier evaluation call.

    Returns:
        The cached list of stock codes (the same list object that is stored
        on the instance, not a copy).

    Raises:
        RuntimeError: If no codes have been cached yet.

    NOTE(review): the visible part of ``_pivot_to_matrix`` stores the sorted
    unique ``ts_code`` values only while the cache is still empty, so this
    may reflect the first evaluation rather than the latest one - confirm.
    """
    cached = self._asset_codes
    if cached is not None:
        return cached
    raise RuntimeError("请先调用 evaluate() 或 evaluate_returns()")
|
||||
|
||||
def _get_metadata_df(self, columns: List[str]) -> Optional[pl.DataFrame]:
    """Load the metadata columns needed to build stock pools.

    The ``daily_basic`` table is queried for the single date ``self.end_date``
    (both range bounds are ``end_date``), restricted to ``self.stock_codes``.

    Args:
        columns: Extra column names required by the pool filters.

    Returns:
        A DataFrame holding ``ts_code`` plus the requested columns with at
        most one row per ``ts_code``, or ``None`` when ``columns`` is empty
        or loading fails (the failure is printed as a warning).
    """
    if not columns:
        return None
    # De-duplicate the request and drop a redundant "ts_code" so the final
    # select cannot fail on a repeated column name.
    extra_columns = [c for c in dict.fromkeys(columns) if c != "ts_code"]
    try:
        df = self.engine.router._load_table(
            table_name="daily_basic",
            columns=columns,
            start_date=self.end_date,
            end_date=self.end_date,
            stock_codes=self.stock_codes,
        )
        if "trade_date" in df.columns:
            # Newest rows first, so that keep="first" below retains the
            # latest observation of every stock.
            df = df.sort("trade_date", descending=True)
        # keep="first" is required: the polars default (keep="any") makes no
        # promise about which duplicate survives, even with maintain_order.
        # One row per stock also keeps the later left join from multiplying rows.
        df = df.unique(subset=["ts_code"], keep="first", maintain_order=True)
        return df.select(["ts_code"] + extra_columns)
    except Exception as exc:
        # Best effort by design: pools that need metadata degrade instead of
        # aborting the whole run.
        print(f"[WARN] 拉取股票池元数据失败: {exc}")
        return None
|
||||
|
||||
def evaluate_returns_by_pool(self, periods: int = 1) -> Dict[str, np.ndarray]:
    """Compute one return matrix per registered stock pool.

    The full return matrix is computed first. For every registered pool other
    than ``"all"`` a float64 copy is produced in which the rows of assets
    outside the pool mask are NaN for every period.

    Args:
        periods: Forward-return horizon passed to ``evaluate_returns``.

    Returns:
        ``{pool_name: matrix}``. ``"all"`` is always present and is the
        unmasked matrix returned by ``evaluate_returns``; the registry's own
        ``"all"`` mask (and therefore its pre-filters) is not applied to it.
        Without a registry only ``"all"`` is returned.
    """
    full_panel = self.evaluate_returns(periods=periods)
    panels: Dict[str, np.ndarray] = {"all": full_panel}

    registry = self.stock_pool_registry
    if registry is None:
        return panels

    # Masks are built against the asset order of the return matrix.
    asset_codes = self.get_asset_codes()
    needed_columns = registry.get_required_columns()
    pool_metadata = self._get_metadata_df(needed_columns)
    registry.build_masks(asset_codes, metadata_df=pool_metadata)

    for pool_name in registry.get_pool_names():
        if pool_name != "all":
            member_rows = registry.masks[pool_name]
            # Start from an all-NaN panel and copy over only the member rows.
            masked_panel = np.full_like(full_panel, np.nan, dtype=np.float64)
            masked_panel[member_rows, :] = full_panel[member_rows, :]
            panels[pool_name] = masked_panel

    return panels
|
||||
|
||||
def _pivot_to_matrix(
|
||||
self,
|
||||
df: pl.DataFrame,
|
||||
@@ -191,6 +266,8 @@ class LocalFactorEvaluator:
|
||||
# 获取时间戳和股票代码的唯一值(已排序)
|
||||
timestamps = df["trade_date"].unique().sort()
|
||||
asset_codes = df["ts_code"].unique().sort()
|
||||
if self._asset_codes is None:
|
||||
self._asset_codes = asset_codes.to_list()
|
||||
|
||||
n_assets = len(asset_codes)
|
||||
n_times = len(timestamps)
|
||||
|
||||
150
src/factorminer/evaluation/stock_pool_registry.py
Normal file
150
src/factorminer/evaluation/stock_pool_registry.py
Normal file
@@ -0,0 +1,150 @@
|
||||
"""股票池注册表:支持配置化股票池筛选与掩码生成。"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Callable, Dict, List, Optional
|
||||
|
||||
import numpy as np
|
||||
import polars as pl
|
||||
|
||||
|
||||
@dataclass
class StockPoolDefinition:
    """Declarative description of one stock pool.

    Follows the filter pattern used in experiment/common.py.
    """

    # Pool name, used as the key of the generated mask.
    name: str
    # Callable that receives a DataFrame and returns a boolean Series or Expr.
    filter_func: Callable[[pl.DataFrame], pl.Series | pl.Expr]
    # Extra columns (besides ts_code) that filter_func needs in the DataFrame.
    required_columns: List[str] = field(default_factory=list)
    # Free-form human readable description.
    description: str = ""
|
||||
|
||||
|
||||
class StockPoolRegistry:
    """Manage stock-pool definitions and build boolean masks for an (M,) asset list.

    Pools are registered with a filter callable. ``build_masks`` evaluates every
    filter against one DataFrame (asset codes plus optional metadata) and stores
    one boolean numpy mask per pool, each AND-ed with the shared pre-filters.
    """

    def __init__(
        self,
        definitions: Optional[List[StockPoolDefinition]] = None,
        pre_filters: Optional[
            List[Callable[[pl.DataFrame], pl.Series | pl.Expr]]
        ] = None,
    ) -> None:
        """Create a registry.

        Args:
            definitions: Pool definitions to register immediately (optional).
            pre_filters: Filters applied to every pool before its own filter.
        """
        self.pre_filters: List[Callable[[pl.DataFrame], pl.Series | pl.Expr]] = (
            pre_filters or []
        )
        self.pools: Dict[str, dict] = {}
        self.masks: Dict[str, np.ndarray] = {}
        # True only while ``masks`` reflects the currently registered pools.
        self._resolved = False
        if definitions:
            for d in definitions:
                self.add_pool(d.name, d.filter_func, d.required_columns)

    @classmethod
    def from_definitions(
        cls,
        definitions: List[StockPoolDefinition],
        pre_filters: Optional[
            List[Callable[[pl.DataFrame], pl.Series | pl.Expr]]
        ] = None,
    ) -> "StockPoolRegistry":
        """Build a registry directly from a list of definitions."""
        return cls(definitions=definitions, pre_filters=pre_filters)

    def add_pool(
        self,
        name: str,
        filter_func: Callable[[pl.DataFrame], pl.Series | pl.Expr],
        required_columns: Optional[List[str]] = None,
    ) -> None:
        """Register (or overwrite) a stock pool.

        Args:
            name: Pool name, e.g. "small_cap".
            filter_func: Receives a DataFrame containing ``ts_code`` and the
                required columns; returns a boolean Series or Expr.
            required_columns: Extra column names ``filter_func`` needs.
        """
        self.pools[name] = {
            "filter_func": filter_func,
            "required_columns": required_columns or [],
        }
        # Previously built masks no longer match the registered pools.
        self._resolved = False

    @staticmethod
    def _eval_filter_result(
        df: pl.DataFrame,
        result: pl.Series | pl.Expr,
        context: str,
    ) -> np.ndarray:
        """Normalise the return value of a filter into a numpy bool array.

        Args:
            df: DataFrame the filter was evaluated on (used to resolve an Expr).
            result: Value returned by the filter.
            context: Label used in error messages.

        Raises:
            TypeError: If ``result`` is neither a Series nor an Expr.
            ValueError: If the mask length is neither 1 (broadcastable, e.g. a
                literal expression) nor the number of rows in ``df``.
        """
        if isinstance(result, pl.Series):
            mask_series = result
        elif isinstance(result, pl.Expr):
            mask_series = df.select(result.alias("_mask")).to_series()
        else:
            raise TypeError(
                f"{context} 的 filter_func 必须返回 pl.Series 或 pl.Expr,"
                f"实际返回 {type(result)}"
            )
        mask = mask_series.to_numpy().astype(bool)
        # Fail with a descriptive message instead of numpy's broadcast error.
        if mask.shape[0] not in (1, len(df)):
            raise ValueError(
                f"{context} 返回的掩码长度 {mask.shape[0]} 与资产数量 {len(df)} 不一致"
            )
        return mask

    def build_masks(
        self,
        asset_codes: List[str],
        metadata_df: Optional[pl.DataFrame] = None,
    ) -> Dict[str, np.ndarray]:
        """Build the boolean mask of every registered pool.

        Args:
            asset_codes: Stock codes in matrix row order, length M.
            metadata_df: Optional frame with ``ts_code`` and extra columns.

        Returns:
            ``{pool_name: bool array of shape (M,)}``; also stored in
            ``self.masks``.
        """
        base_df = pl.DataFrame({"ts_code": asset_codes})
        df = base_df
        if metadata_df is not None:
            # One metadata row per stock, otherwise the left join would
            # multiply rows and the masks would no longer have length M.
            meta = metadata_df.unique(
                subset=["ts_code"], keep="first", maintain_order=True
            )
            # Mask positions must match ``asset_codes``; restore the original
            # row order explicitly rather than relying on the join's ordering.
            order_col = "__asset_order__"
            df = (
                base_df.with_columns(
                    pl.Series(order_col, np.arange(len(asset_codes)))
                )
                .join(meta, on="ts_code", how="left")
                .sort(order_col)
                .drop(order_col)
            )

        base_mask = np.ones(len(asset_codes), dtype=bool)
        for pre_filter in self.pre_filters:
            result = pre_filter(df)
            base_mask &= self._eval_filter_result(df, result, "pre_filter")

        self.masks = {}
        for name, cfg in self.pools.items():
            result = cfg["filter_func"](df)
            pool_mask = self._eval_filter_result(df, result, f"股票池 '{name}'")
            self.masks[name] = base_mask & pool_mask

        self._resolved = True
        return self.masks

    def filter_signals(self, signals: np.ndarray, pool_name: str) -> np.ndarray:
        """Select the rows of ``signals`` that belong to a pool.

        Args:
            signals: Signal matrix whose first axis follows the asset order
                used in ``build_masks``.
            pool_name: Name of a pool with a built mask.

        Returns:
            The sub-matrix ``signals[mask, :]``.

        Raises:
            RuntimeError: If masks have not been built since the last change.
            KeyError: If ``pool_name`` has no mask.
        """
        if not self._resolved:
            raise RuntimeError("请先调用 build_masks 构建掩码")
        mask = self.masks.get(pool_name)
        if mask is None:
            raise KeyError(f"未知的股票池: {pool_name}")
        return signals[mask, :]

    def get_pool_names(self) -> List[str]:
        """Return the registered pool names in registration order."""
        return list(self.pools.keys())

    def get_required_columns(self) -> List[str]:
        """Return the sorted union of columns needed by pre-filters and pools.

        A pre-filter contributes columns only if it exposes a
        ``required_columns`` attribute.
        """
        cols: set[str] = set()
        for pre_filter in self.pre_filters:
            cols.update(getattr(pre_filter, "required_columns", []))
        for cfg in self.pools.values():
            cols.update(cfg["required_columns"])
        return sorted(cols)
|
||||
@@ -11,6 +11,8 @@ from typing import Any, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
import polars as pl
|
||||
|
||||
from src.config.settings import get_settings
|
||||
from src.factorminer.agent.llm_interface import create_provider, AnthropicProvider
|
||||
from src.factorminer.core.config import MiningConfig as CoreMiningConfig
|
||||
@@ -19,7 +21,10 @@ from src.factorminer.core.ralph_loop import RalphLoop
|
||||
from src.factorminer.core.helix_loop import HelixLoop
|
||||
from src.factorminer.evaluation.local_engine import LocalFactorEvaluator
|
||||
from src.factorminer.evaluation.significance import SignificanceConfig
|
||||
from src.factorminer.evaluation.stock_pool_registry import StockPoolRegistry
|
||||
from src.factorminer.pool_definitions import get_default_stock_pool_definitions
|
||||
from src.factorminer.utils.config import load_config
|
||||
from src.training.components.filters import STFilter
|
||||
|
||||
RUN_CONFIG: dict = {
|
||||
# 全局开关
|
||||
@@ -35,9 +40,16 @@ RUN_CONFIG: dict = {
|
||||
"stock_codes": None, # 可选股票列表,None 表示全量
|
||||
# 种子库
|
||||
"seed_paper_library": True, # 是否预加载 110 Paper Factors 作为种子库
|
||||
# 股票池配置(支持多股票池评估)
|
||||
"stock_pools": {
|
||||
"enabled": True, # 是否启用多股票池
|
||||
"provider": "default", # "default" 使用 pool_definitions.py 中的默认定义
|
||||
"use_st_filter": True, # 是否在构建股票池前过滤 ST 股票
|
||||
"pools": ["all", "small_cap"], # 要启用的股票池名称列表,None 表示全部启用
|
||||
},
|
||||
# 挖掘配置(覆盖 default.yaml 中的同名项)
|
||||
"mining": {
|
||||
"max_iterations": 3,
|
||||
"max_iterations": 1,
|
||||
"target_library_size": 10,
|
||||
"correlation_threshold": 0.50,
|
||||
"ic_threshold": 0.04,
|
||||
@@ -139,6 +151,54 @@ def _build_core_mining_config(run_cfg: dict) -> CoreMiningConfig:
|
||||
return cfg
|
||||
|
||||
|
||||
def _build_st_pre_filter(st_filter: STFilter, end_date: str):
    """Wrap an STFilter as a pre-filter callable accepted by StockPoolRegistry.

    Args:
        st_filter: Filter object whose ``filter(df)`` result exposes the
            surviving ``ts_code`` values.
        end_date: Value written into a ``trade_date`` column when the incoming
            frame does not have one.

    Returns:
        A function mapping a DataFrame to a boolean Series that marks the
        rows kept by ``st_filter``. If anything raises, a warning is printed
        and every row is kept (fail-open).
    """

    def _pre_filter(df: pl.DataFrame) -> pl.Series:
        try:
            has_date = "trade_date" in df.columns
            frame = (
                df
                if has_date
                else df.with_columns(pl.lit(end_date).alias("trade_date"))
            )
            survivors = st_filter.filter(frame)
            return frame["ts_code"].is_in(survivors["ts_code"])
        except Exception as exc:
            print(f"[WARN] ST pre-filter 失败: {exc}")
            return pl.Series([True] * len(df))

    return _pre_filter
|
||||
|
||||
|
||||
def _build_stock_pool_registry(
    run_cfg: dict,
    st_filter: Optional[STFilter] = None,
    end_date: str = "20231231",
) -> Optional[StockPoolRegistry]:
    """Build the stock-pool registry described by ``run_cfg["stock_pools"]``.

    Args:
        run_cfg: Run configuration; the ``stock_pools`` section is read.
        st_filter: Optional ST filter, used as a pre-filter when
            ``use_st_filter`` is enabled.
        end_date: Date handed to the ST pre-filter wrapper.

    Returns:
        The registry, or ``None`` when multi-pool evaluation is disabled.

    Raises:
        ValueError: If the provider is unsupported, or if ``pools`` matches
            none of the default definitions.
    """
    # Tolerate an explicit ``stock_pools: None`` in the configuration.
    pool_cfg = run_cfg.get("stock_pools") or {}
    if not pool_cfg.get("enabled", False):
        return None

    pre_filters = []
    if pool_cfg.get("use_st_filter", False) and st_filter is not None:
        pre_filters.append(_build_st_pre_filter(st_filter, end_date))

    provider = pool_cfg.get("provider", "default")
    if provider != "default":
        raise ValueError(f"不支持的股票池 provider: {provider}")

    definitions = get_default_stock_pool_definitions()
    pool_names = pool_cfg.get("pools")
    if isinstance(pool_names, str):
        # A bare string would be matched by substring ("all" in "small_cap"
        # is True) and split into characters below; treat it as one name.
        pool_names = [pool_names]
    if pool_names is not None:
        pool_names = list(pool_names)
        definitions = [d for d in definitions if d.name in pool_names]
        if not definitions:
            raise ValueError(
                f"股票池配置错误: pools={pool_names} 未匹配到任何默认股票池定义"
            )
        missing = set(pool_names) - {d.name for d in definitions}
        if missing:
            print(f"[WARN] 以下股票池未在默认定义中找到,已忽略: {sorted(missing)}")
    return StockPoolRegistry.from_definitions(definitions, pre_filters=pre_filters)
|
||||
|
||||
|
||||
def _build_helix_kwargs(run_cfg: dict) -> dict:
|
||||
"""从 RUN_CONFIG 构建 HelixLoop 需要的 Phase 2 扩展配置。"""
|
||||
helix = run_cfg.get("helix", {})
|
||||
@@ -215,14 +275,33 @@ def main(config: dict | None = None) -> None:
|
||||
end_date = run_cfg.get("end_date", "20201231")
|
||||
stock_codes = run_cfg.get("stock_codes")
|
||||
|
||||
# 4.1 创建 evaluator(先实例化才能拿到 router)
|
||||
evaluator = LocalFactorEvaluator(
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
stock_codes=stock_codes,
|
||||
)
|
||||
|
||||
# 4.2 构建 STFilter 和股票池注册表
|
||||
st_filter = STFilter(data_router=evaluator.engine.router)
|
||||
stock_pool_registry = _build_stock_pool_registry(
|
||||
run_cfg, st_filter=st_filter, end_date=end_date
|
||||
)
|
||||
evaluator.stock_pool_registry = stock_pool_registry
|
||||
|
||||
if stock_pool_registry is not None:
|
||||
returns = evaluator.evaluate_returns_by_pool(periods=1)
|
||||
for pool_name, pool_ret in returns.items():
|
||||
valid_assets = int((~np.isnan(pool_ret).all(axis=1)).sum())
|
||||
print(
|
||||
f"[main] 股票池 {pool_name}: {valid_assets} 只资产, "
|
||||
f"returns shape: {pool_ret.shape}"
|
||||
)
|
||||
else:
|
||||
returns = evaluator.evaluate_returns(periods=1)
|
||||
print(
|
||||
f"[main] 本地数据范围: {start_date} ~ {end_date}, returns shape: {returns.shape}"
|
||||
f"[main] 本地数据范围: {start_date} ~ {end_date}, "
|
||||
f"returns shape: {returns.shape}"
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
171
src/factorminer/pool_definitions.py
Normal file
171
src/factorminer/pool_definitions.py
Normal file
@@ -0,0 +1,171 @@
|
||||
"""用户可配置的股票池定义文件。
|
||||
|
||||
参考 `src/experiment/common.py` 中的 `stock_pool_filter` 模式,
|
||||
支持通过编写 Polars filter 函数来自定义股票池。
|
||||
"""
|
||||
|
||||
from typing import Callable, List
|
||||
|
||||
import polars as pl
|
||||
|
||||
from src.factorminer.evaluation.stock_pool_registry import StockPoolDefinition, StockPoolRegistry
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 股票池大小配置
|
||||
# =============================================================================
|
||||
SMALL_CAP_N = 1000  # Default number of stocks selected into the small/micro-cap pools
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 股票池筛选函数(与 experiment/common.py 风格一致)
|
||||
# =============================================================================
|
||||
def all_market_filter(df: pl.DataFrame) -> pl.Series:
    """Whole-market pool: mark every row of ``df`` as selected."""
    keep_everything = [True] * len(df)
    return pl.Series(keep_everything)
|
||||
|
||||
|
||||
def growth_board_filter(df: pl.DataFrame) -> pl.Series:
    """ChiNext pool: rows whose ``ts_code`` starts with "300".

    NOTE(review): ``common_small_cap_filter`` in this module identifies the
    same board by the broader prefix "30" - confirm which prefix is intended.
    """
    codes = df["ts_code"]
    return codes.str.starts_with("300")
|
||||
|
||||
|
||||
def star_board_filter(df: pl.DataFrame) -> pl.Series:
    """STAR Market pool: rows whose ``ts_code`` starts with "688".

    NOTE(review): ``common_small_cap_filter`` in this module identifies the
    same board by the broader prefix "68" - confirm which prefix is intended.
    """
    codes = df["ts_code"]
    return codes.str.starts_with("688")
|
||||
|
||||
|
||||
def bse_board_filter(df: pl.DataFrame) -> pl.Series:
    """Beijing Stock Exchange pool: rows whose ``ts_code`` starts with "8" or "4".

    NOTE(review): ``common_small_cap_filter`` in this module additionally
    treats a leading "9" as Beijing Stock Exchange - confirm whether that
    prefix should be matched here as well.
    """
    codes = df["ts_code"]
    starts_with_8 = codes.str.starts_with("8")
    starts_with_4 = codes.str.starts_with("4")
    return starts_with_8 | starts_with_4
|
||||
|
||||
|
||||
def main_board_filter(df: pl.DataFrame) -> pl.Series:
    """Main-board pool: every row not matched by the other board prefixes.

    A row is kept when its ``ts_code`` starts with none of "300", "688",
    "8" and "4".

    NOTE(review): ``common_small_cap_filter`` in this module excludes the
    broader prefixes "30", "68" and also "9" - confirm which set is intended.
    """
    codes = df["ts_code"]
    other_board_prefixes = ("300", "688", "8", "4")
    on_main_board = ~codes.str.starts_with(other_board_prefixes[0])
    for prefix in other_board_prefixes[1:]:
        on_main_board = on_main_board & ~codes.str.starts_with(prefix)
    return on_main_board
|
||||
|
||||
|
||||
def small_cap_filter(df: pl.DataFrame, n_stocks: int = SMALL_CAP_N) -> pl.Series:
    """Small/micro-cap pool: the ``n_stocks`` stocks with the smallest ``circ_mv``.

    Args:
        df: Frame containing ``ts_code`` and ``circ_mv``.
        n_stocks: Number of stocks to select; defaults to ``SMALL_CAP_N``.

    Returns:
        Boolean Series marking the selected rows. Every row is False when the
        ``circ_mv`` column is missing.
    """
    if "circ_mv" not in df.columns:
        # Without circ_mv nothing can be ranked: exclude everything.
        return pl.Series([False] * len(df))
    # Rows without a market value (e.g. no metadata match after a left join)
    # must not be ranked: polars sorts nulls first by default, which would
    # otherwise select exactly those rows as the "smallest".
    ranked = df.filter(pl.col("circ_mv").is_not_null()).sort("circ_mv")
    n = max(0, min(n_stocks, len(ranked)))
    small_codes = ranked.head(n)["ts_code"]
    return df["ts_code"].is_in(small_codes.implode())
|
||||
|
||||
|
||||
def main_board_small_cap_filter(
    df: pl.DataFrame, n_stocks: int = SMALL_CAP_N
) -> pl.Series:
    """Main-board small-cap pool: smallest ``circ_mv`` stocks within the main board.

    Args:
        df: Frame containing ``ts_code`` and ``circ_mv``.
        n_stocks: Number of stocks to select; defaults to ``SMALL_CAP_N``.

    Returns:
        Boolean Series marking the selected rows. Every row is False when
        ``circ_mv`` is missing or no main-board stock has a market value.
    """
    if "circ_mv" not in df.columns:
        return pl.Series([False] * len(df))
    # Rank only main-board rows that actually carry a market value: polars
    # sorts nulls first by default, so null rows would win the ranking.
    candidates = df.filter(main_board_filter(df)).filter(
        pl.col("circ_mv").is_not_null()
    )
    if len(candidates) == 0:
        return pl.Series([False] * len(df))
    n = max(0, min(n_stocks, len(candidates)))
    small_codes = candidates.sort("circ_mv").head(n)["ts_code"]
    return df["ts_code"].is_in(small_codes.implode())
|
||||
|
||||
|
||||
def common_small_cap_filter(df: pl.DataFrame, n_stocks: int = SMALL_CAP_N) -> pl.Series:
    """Small-cap selection documented as matching experiment/common.py.

    Steps:
        1. Drop codes starting with "30" (ChiNext).
        2. Drop codes starting with "68" (STAR Market).
        3. Drop codes starting with "8", "9" or "4" (Beijing Stock Exchange).
        4. Of the remaining rows, take the ``n_stocks`` smallest by ``circ_mv``.

    Args:
        df: Frame that must contain ``ts_code`` and ``circ_mv``.
        n_stocks: Number of stocks to select; defaults to ``SMALL_CAP_N``.

    Returns:
        Boolean Series marking the selected rows.

    NOTE(review): unlike the other small-cap filters in this module there is
    no guard for a missing ``circ_mv`` column, and null ``circ_mv`` values
    are ranked with polars' default null ordering - confirm this is
    acceptable for the frames passed in here.
    """
    codes = df["ts_code"]
    excluded_prefixes = ("30", "68", "8", "9", "4")
    eligible = ~codes.str.starts_with(excluded_prefixes[0])
    for prefix in excluded_prefixes[1:]:
        eligible = eligible & ~codes.str.starts_with(prefix)

    universe = df.filter(eligible)
    take = min(n_stocks, len(universe))
    smallest_codes = universe.sort("circ_mv").head(take)["ts_code"]
    return codes.is_in(smallest_codes.implode())
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 默认股票池定义列表
|
||||
# =============================================================================
|
||||
def get_default_stock_pool_definitions() -> List[StockPoolDefinition]:
    """Return a fresh list with the built-in stock pool definitions.

    The pools are, in order: all, growth_board, star_board, bse_board,
    main_board, small_cap and main_board_small_cap. Only the two small-cap
    pools require an extra column (``circ_mv``).
    """
    # (name, filter function, required columns, description)
    specs = [
        ("all", all_market_filter, [], "全市场"),
        ("growth_board", growth_board_filter, [], "创业板"),
        ("star_board", star_board_filter, [], "科创板"),
        ("bse_board", bse_board_filter, [], "北交所"),
        ("main_board", main_board_filter, [], "主板"),
        ("small_cap", small_cap_filter, ["circ_mv"], "小微盘"),
        ("main_board_small_cap", main_board_small_cap_filter, ["circ_mv"], "主板小微盘"),
    ]
    return [
        StockPoolDefinition(
            name=pool_name,
            filter_func=pool_filter,
            required_columns=needed_columns,
            description=label,
        )
        for pool_name, pool_filter, needed_columns, label in specs
    ]
|
||||
|
||||
|
||||
def get_default_stock_pools(
    pre_filters: List[Callable[[pl.DataFrame], pl.Series | pl.Expr]] | None = None,
) -> StockPoolRegistry:
    """Return a registry pre-populated with the default stock pools.

    Args:
        pre_filters: Optional filters forwarded unchanged to the registry.
    """
    default_definitions = get_default_stock_pool_definitions()
    registry = StockPoolRegistry.from_definitions(
        default_definitions,
        pre_filters=pre_filters,
    )
    return registry
|
||||
@@ -21,6 +21,7 @@ from src.factorminer.evaluation.metrics import (
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def rng():
|
||||
return np.random.default_rng(123)
|
||||
@@ -58,6 +59,7 @@ def known_quintile_signal(rng):
|
||||
# IC computation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestIC:
|
||||
"""Test Information Coefficient computation."""
|
||||
|
||||
@@ -106,6 +108,7 @@ class TestIC:
|
||||
# ICIR computation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestICIR:
|
||||
"""Test ICIR = mean(IC) / std(IC)."""
|
||||
|
||||
@@ -142,6 +145,7 @@ class TestICIR:
|
||||
# IC-derived statistics
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestICStats:
|
||||
"""Test IC mean and win rate."""
|
||||
|
||||
@@ -170,6 +174,7 @@ class TestICStats:
|
||||
# Pairwise correlation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPairwiseCorrelation:
|
||||
"""Test pairwise cross-sectional correlation."""
|
||||
|
||||
@@ -207,6 +212,7 @@ class TestPairwiseCorrelation:
|
||||
# Quintile returns
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestQuintileReturns:
|
||||
"""Test quintile return computation."""
|
||||
|
||||
@@ -243,6 +249,7 @@ class TestQuintileReturns:
|
||||
# Turnover
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTurnover:
|
||||
"""Test portfolio turnover computation."""
|
||||
|
||||
@@ -263,6 +270,7 @@ class TestTurnover:
|
||||
# Comprehensive factor stats
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFactorStats:
|
||||
"""Test the compute_factor_stats wrapper."""
|
||||
|
||||
@@ -285,3 +293,18 @@ class TestFactorStats:
|
||||
returns = rng.normal(0, 0.01, (M, T))
|
||||
stats = compute_factor_stats(signals, returns)
|
||||
assert stats["ic_series"].shape == (T,)
|
||||
|
||||
|
||||
class TestFactorStatsPools:
    """Check compute_factor_stats on NaN patterns produced by stock sub-pools."""

    def test_pool_with_nans(self, rng):
        """Assets that are NaN for every period must not break the statistics."""
        n_assets, n_periods = 40, 30
        signals = rng.normal(0, 1, (n_assets, n_periods))
        returns = rng.normal(0, 0.01, (n_assets, n_periods))
        # Emulate a sub-pool: the last ten assets are NaN at every time step.
        for panel in (signals, returns):
            panel[30:, :] = np.nan
        stats = compute_factor_stats(signals, returns)
        assert stats["n_periods"] == n_periods
        assert np.isfinite(stats["ic_mean"])
|
||||
|
||||
72
src/factorminer/tests/test_stock_pool.py
Normal file
72
src/factorminer/tests/test_stock_pool.py
Normal file
@@ -0,0 +1,72 @@
|
||||
import numpy as np
|
||||
import polars as pl
|
||||
import pytest
|
||||
|
||||
from src.factorminer.evaluation.stock_pool_registry import StockPoolRegistry
|
||||
|
||||
|
||||
class TestStockPoolRegistry:
    """Tests for pool registration, mask building and signal filtering."""

    @staticmethod
    def _select_everything(df):
        """Filter that keeps every row."""
        return pl.Series([True] * len(df))

    @staticmethod
    def _select_300_prefix(df):
        """Filter that keeps codes starting with "300"."""
        return df["ts_code"].str.starts_with("300")

    def test_add_and_get_pool_names(self):
        """Pool names are reported in registration order."""
        pools = StockPoolRegistry()
        pools.add_pool("all", self._select_everything)
        pools.add_pool("growth", self._select_300_prefix)
        assert pools.get_pool_names() == ["all", "growth"]

    def test_build_masks(self):
        """Masks follow the order of the supplied asset codes."""
        pools = StockPoolRegistry()
        pools.add_pool("all", self._select_everything)
        pools.add_pool("growth", self._select_300_prefix)
        codes = ["000001.SZ", "300001.SZ", "688001.SH"]
        built = pools.build_masks(codes)
        assert built["all"].sum() == 3
        assert built["growth"].sum() == 1
        assert built["growth"][1]

    def test_filter_signals(self):
        """filter_signals keeps only the rows selected by the pool mask."""
        pools = StockPoolRegistry()
        pools.add_pool("growth", self._select_300_prefix)
        pools.build_masks(["000001.SZ", "300001.SZ", "688001.SH"])
        matrix = np.array([[1, 2], [3, 4], [5, 6]], dtype=np.float64)
        subset = pools.filter_signals(matrix, "growth")
        assert subset.shape == (1, 2)
        np.testing.assert_array_equal(subset, [[3, 4]])

    def test_build_masks_with_metadata(self):
        """Filters can rank on columns joined in from the metadata frame."""

        def two_smallest_by_total_mv(df):
            smallest = df.sort("total_mv").head(2)["ts_code"]
            return df["ts_code"].is_in(smallest.implode())

        pools = StockPoolRegistry()
        pools.add_pool(
            "small_cap",
            two_smallest_by_total_mv,
            required_columns=["total_mv"],
        )
        codes = ["A", "B", "C"]
        meta = pl.DataFrame(
            {
                "ts_code": ["A", "B", "C"],
                "total_mv": [300.0, 100.0, 200.0],
            }
        )
        built = pools.build_masks(codes, metadata_df=meta)
        assert built["small_cap"].sum() == 2
        assert built["small_cap"][1] and built["small_cap"][2]

    def test_expr_fallback(self):
        """A filter may return a polars expression instead of a Series."""
        pools = StockPoolRegistry()
        pools.add_pool(
            "expr_pool", lambda df: pl.col("ts_code").str.starts_with("300")
        )
        built = pools.build_masks(["000001.SZ", "300001.SZ"])
        assert built["expr_pool"].sum() == 1

    def test_pre_filters(self):
        """Pre-filters are AND-ed into every pool mask."""
        pools = StockPoolRegistry(
            pre_filters=[lambda df: df["ts_code"].str.starts_with("0")]
        )
        pools.add_pool("all", self._select_everything)
        built = pools.build_masks(["000001.SZ", "300001.SZ", "688001.SH"])
        assert built["all"].sum() == 1
        assert built["all"][0]
|
||||
Reference in New Issue
Block a user