From 521609e46ade2d83e46674c49f384152083a675f Mon Sep 17 00:00:00 2001 From: liaozhaorun <1300336796@qq.com> Date: Sun, 12 Apr 2026 02:26:29 +0800 Subject: [PATCH] =?UTF-8?q?feat(factorminer):=20=E6=94=AF=E6=8C=81?= =?UTF-8?q?=E5=A4=9A=E8=82=A1=E7=A5=A8=E6=B1=A0=E5=9B=A0=E5=AD=90=E8=AF=84?= =?UTF-8?q?=E4=BC=B0=20-=20=E6=96=B0=E5=A2=9E=20StockPoolRegistry=20?= =?UTF-8?q?=E4=B8=8E=E9=BB=98=E8=AE=A4=E8=82=A1=E7=A5=A8=E6=B1=A0=E5=AE=9A?= =?UTF-8?q?=E4=B9=89=EF=BC=88all/small=5Fcap=20=E7=AD=89=EF=BC=89=20-=20Ra?= =?UTF-8?q?lphLoop=20=E4=B8=8E=20ValidationPipeline=20=E6=94=AF=E6=8C=81?= =?UTF-8?q?=E6=8C=89=E5=A4=9A=E8=82=A1=E7=A5=A8=E6=B1=A0=E8=AE=A1=E7=AE=97?= =?UTF-8?q?=20IC=20-=20Factor=20=E5=BA=8F=E5=88=97=E5=8C=96=E6=96=B0?= =?UTF-8?q?=E5=A2=9E=20pool=5Fmetrics=20=E5=AD=97=E6=AE=B5=20-=20LocalFact?= =?UTF-8?q?orEvaluator=20=E5=A2=9E=E5=8A=A0=20evaluate=5Freturns=5Fby=5Fpo?= =?UTF-8?q?ol=20=E6=96=B9=E6=B3=95=20-=20=E4=B8=BB=E5=85=A5=E5=8F=A3?= =?UTF-8?q?=E9=9B=86=E6=88=90=E8=82=A1=E7=A5=A8=E6=B1=A0=E9=85=8D=E7=BD=AE?= =?UTF-8?q?=E4=B8=8E=20ST=20=E9=A2=84=E8=BF=87=E6=BB=A4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ...2026-04-11-factorminer-multi-stock-pool.md | 1165 +++++++++++++++++ src/factorminer/core/factor_library.py | 17 + src/factorminer/core/helix_loop.py | 4 +- src/factorminer/core/ralph_loop.py | 67 +- src/factorminer/evaluation/local_engine.py | 77 ++ .../evaluation/stock_pool_registry.py | 150 +++ src/factorminer/main.py | 87 +- src/factorminer/pool_definitions.py | 171 +++ src/factorminer/tests/test_evaluation.py | 23 + src/factorminer/tests/test_stock_pool.py | 72 + 10 files changed, 1806 insertions(+), 27 deletions(-) create mode 100644 docs/plans/2026-04-11-factorminer-multi-stock-pool.md create mode 100644 src/factorminer/evaluation/stock_pool_registry.py create mode 100644 src/factorminer/pool_definitions.py create mode 100644 src/factorminer/tests/test_stock_pool.py diff --git 
a/docs/plans/2026-04-11-factorminer-multi-stock-pool.md b/docs/plans/2026-04-11-factorminer-multi-stock-pool.md new file mode 100644 index 0000000..8999b4a --- /dev/null +++ b/docs/plans/2026-04-11-factorminer-multi-stock-pool.md @@ -0,0 +1,1165 @@ +# FactorMiner 多股票池指标评估与入库改造计划 + +> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task. + +**Goal:** 改造 `src/factorminer/` 以支持在多个股票池(如全市场、小微盘、创业板)上计算因子指标。若因子在任一股票池表现优异即入库,且入库时保存所有股票池的指标。股票池筛选通过用户自定义函数配置。 + +**Architecture:** +- 新增 `StockPoolRegistry` 统一管理股票池定义与掩码生成,支持用户通过 `filter_func` 配置(参考 `experiment/common.py` 的 `stock_pool_filter` 模式)。 +- 扩展 `LocalFactorEvaluator` 以输出各股票池的收益率矩阵;`ValidationPipeline` 复用现有的 `target_panels` 机制,对每个股票池计算 `compute_factor_stats`,只要有任一池子通过 IC/ICIR 阈值即允许入库。 +- `Factor` 数据类新增 `pool_metrics` 字段保存各池指标,入库和序列化时完整保留。 + +**Tech Stack:** Python 3.10+, Polars, NumPy, DuckDB (通过 `DataRouter._load_table` 查询元数据), pytest + +--- + +## 前置约定 + +- 所有新增/修改的代码必须位于 `src/factorminer/` 或 `tests/` 下。 +- 测试使用 `uv run pytest tests/xxx.py -v` 运行。 +- 代码注释和文档字符串使用中文。 +- 禁止在代码中使用 emoji。 + +--- + +## Task 1: 创建 `StockPoolRegistry`(股票池注册表) + +**Files:** +- Create: `src/factorminer/evaluation/stock_pool_registry.py` +- Test: `src/factorminer/tests/test_stock_pool.py` + +**Step 1: 编写失败测试** + +```python +import numpy as np +import polars as pl +import pytest + +from src.factorminer.evaluation.stock_pool_registry import StockPoolRegistry + + +class TestStockPoolRegistry: + def test_add_and_get_pool_names(self): + registry = StockPoolRegistry() + registry.add_pool("all", lambda df: pl.Series([True] * len(df))) + registry.add_pool("growth", lambda df: df["ts_code"].str.starts_with("300")) + assert registry.get_pool_names() == ["all", "growth"] + + def test_build_masks(self): + registry = StockPoolRegistry() + registry.add_pool("all", lambda df: pl.Series([True] * len(df))) + registry.add_pool("growth", lambda df: df["ts_code"].str.starts_with("300")) + asset_codes = ["000001.SZ", 
"300001.SZ", "688001.SH"] + masks = registry.build_masks(asset_codes) + assert masks["all"].sum() == 3 + assert masks["growth"].sum() == 1 + assert masks["growth"][1] + + def test_filter_signals(self): + registry = StockPoolRegistry() + registry.add_pool("growth", lambda df: df["ts_code"].str.starts_with("300")) + asset_codes = ["000001.SZ", "300001.SZ", "688001.SH"] + registry.build_masks(asset_codes) + signals = np.array([[1, 2], [3, 4], [5, 6]], dtype=np.float64) + filtered = registry.filter_signals(signals, "growth") + assert filtered.shape == (1, 2) + np.testing.assert_array_equal(filtered, [[3, 4]]) + + def test_build_masks_with_metadata(self): + registry = StockPoolRegistry() + registry.add_pool( + "small_cap", + lambda df: df["ts_code"].is_in(df.sort("total_mv").head(2)["ts_code"]), + required_columns=["total_mv"], + ) + asset_codes = ["A", "B", "C"] + metadata = pl.DataFrame({ + "ts_code": ["A", "B", "C"], + "total_mv": [300.0, 100.0, 200.0], + }) + masks = registry.build_masks(asset_codes, metadata_df=metadata) + assert masks["small_cap"].sum() == 2 + assert masks["small_cap"][1] and masks["small_cap"][2] +``` + +**Step 2: 运行测试确认失败** + +```bash +uv run pytest src/factorminer/tests/test_stock_pool.py -v +``` + +Expected: FAIL (module not found) + +**Step 3: 最小实现** + +```python +"""股票池注册表:支持配置化股票池筛选与掩码生成。""" + +from typing import Callable, Dict, List, Optional + +import numpy as np +import polars as pl + + +class StockPoolRegistry: + """管理多个股票池的定义,并为 (M,) 资产列表生成布尔掩码。""" + + def __init__(self) -> None: + self.pools: Dict[str, dict] = {} + self.masks: Dict[str, np.ndarray] = {} + self._resolved = False + + def add_pool( + self, + name: str, + filter_func: Callable[[pl.DataFrame], pl.Series], + required_columns: Optional[List[str]] = None, + ) -> None: + """注册一个股票池。 + + Args: + name: 股票池名称,如 "small_cap"。 + filter_func: 接收包含 ts_code 和 required_columns 的 DataFrame, + 返回布尔 Series。 + required_columns: filter_func 额外需要的列名列表。 + """ + self.pools[name] = { + 
"filter_func": filter_func, + "required_columns": required_columns or [], + } + self._resolved = False + + def build_masks( + self, + asset_codes: List[str], + metadata_df: Optional[pl.DataFrame] = None, + ) -> Dict[str, np.ndarray]: + """为所有注册的股票池生成布尔掩码。 + + Args: + asset_codes: 按矩阵顺序排列的股票代码列表 (M,)。 + metadata_df: 包含 extra columns 的 metadata,可选。 + + Returns: + {pool_name: bool_array_of_shape_(M,)} 字典。 + """ + base_df = pl.DataFrame({"ts_code": asset_codes}) + df = base_df + if metadata_df is not None: + df = base_df.join(metadata_df, on="ts_code", how="left") + + self.masks = {} + for name, cfg in self.pools.items(): + result = cfg["filter_func"](df) + if isinstance(result, pl.Series): + mask_series = result + elif isinstance(result, pl.Expr): + mask_series = df.select(result.alias("_mask")).to_series() + else: + raise TypeError( + f"股票池 '{name}' 的 filter_func 必须返回 pl.Series 或 pl.Expr," + f"实际返回 {type(result)}" + ) + self.masks[name] = mask_series.to_numpy().astype(bool) + + self._resolved = True + return self.masks + + def filter_signals(self, signals: np.ndarray, pool_name: str) -> np.ndarray: + """使用已构建的掩码过滤信号矩阵。 + + Args: + signals: (M, T) 信号矩阵。 + pool_name: 股票池名称。 + + Returns: + (M_pool, T) 的子矩阵。 + """ + if not self._resolved: + raise RuntimeError("请先调用 build_masks 构建掩码") + mask = self.masks.get(pool_name) + if mask is None: + raise KeyError(f"未知的股票池: {pool_name}") + return signals[mask, :] + + def get_pool_names(self) -> List[str]: + """返回已注册的股票池名称列表。""" + return list(self.pools.keys()) + + def get_required_columns(self) -> List[str]: + """返回所有股票池所需的列名并集。""" + cols: set[str] = set() + for cfg in self.pools.values(): + cols.update(cfg["required_columns"]) + return sorted(cols) +``` + +**Step 4: 运行测试确认通过** + +```bash +uv run pytest src/factorminer/tests/test_stock_pool.py -v +``` + +Expected: PASS + +**Step 5: Commit** + +```bash +git add src/factorminer/evaluation/stock_pool_registry.py src/factorminer/tests/test_stock_pool.py +git commit -m 
"feat(factorminer): add StockPoolRegistry for configurable stock pools" +``` + +--- + +## Task 2: 扩展 `LocalFactorEvaluator` 支持股票池收益率与资产码暴露 + +**Files:** +- Modify: `src/factorminer/evaluation/local_engine.py` +- Test: `src/factorminer/tests/test_local_engine.py` (新建) + +**Step 1: 编写失败测试** + +```python +import numpy as np +import polars as pl +import pytest + +from src.factorminer.evaluation.local_engine import LocalFactorEvaluator +from src.factorminer.evaluation.stock_pool_registry import StockPoolRegistry + + +class TestLocalEnginePools: + def test_get_asset_codes_after_evaluate(self): + # 使用 mock engine 避免真实数据库依赖 + class MockRouter: + def _load_table(self, table, columns, start, end, stock_codes=None): + return pl.DataFrame({"ts_code": [], "trade_date": []}) + + class MockEngine: + router = MockRouter() + + def add_factor(self, name, formula): + pass + + def compute(self, **kwargs): + return pl.DataFrame({ + "ts_code": ["000001.SZ", "300001.SZ"], + "trade_date": ["20230101", "20230101"], + "ret": [0.01, 0.02], + }) + + def clear(self): + pass + + evaluator = LocalFactorEvaluator("20230101", "20230101") + evaluator.engine = MockEngine() + evaluator.evaluate_returns(periods=1) + codes = evaluator.get_asset_codes() + assert codes == ["000001.SZ", "300001.SZ"] + + def test_evaluate_returns_by_pool(self): + class MockRouter: + pass + + class MockEngine: + router = MockRouter() + + def add_factor(self, name, formula): + pass + + def compute(self, **kwargs): + return pl.DataFrame({ + "ts_code": ["000001.SZ", "300001.SZ", "688001.SH"], + "trade_date": ["20230101", "20230101", "20230101"], + "__returns_tmp": [0.01, 0.02, 0.03], + }) + + def clear(self): + pass + + registry = StockPoolRegistry() + registry.add_pool("all", lambda df: pl.Series([True] * len(df))) + registry.add_pool("growth", lambda df: df["ts_code"].str.starts_with("300")) + + evaluator = LocalFactorEvaluator("20230101", "20230101", stock_pool_registry=registry) + evaluator.engine = MockEngine() + 
pool_returns = evaluator.evaluate_returns_by_pool(periods=1) + + assert "all" in pool_returns + assert "growth" in pool_returns + assert pool_returns["all"].shape == (3, 1) + assert pool_returns["growth"].shape == (1, 1) + np.testing.assert_array_equal(pool_returns["growth"], [[0.02]]) +``` + +**Step 2: 运行测试确认失败** + +```bash +uv run pytest src/factorminer/tests/test_local_engine.py -v +``` + +Expected: FAIL (methods not found) + +**Step 3: 最小实现** + +在 `src/factorminer/evaluation/local_engine.py` 中: + +1. 导入 `StockPoolRegistry`: + +```python +from src.factorminer.evaluation.stock_pool_registry import StockPoolRegistry +``` + +2. 修改 `__init__`: +```python + def __init__( + self, + start_date: str, + end_date: str, + stock_codes: Optional[List[str]] = None, + stock_pool_registry: Optional[StockPoolRegistry] = None, + ) -> None: + self.start_date = start_date + self.end_date = end_date + self.stock_codes = stock_codes + self.engine = FactorEngine() + self.stock_pool_registry = stock_pool_registry + self._asset_codes: Optional[List[str]] = None +``` + +3. 在 `_pivot_to_matrix` 方法末尾(`return result` 之前)加入: +```python + # 缓存 asset_codes(按字母序,与矩阵行顺序一致) + if self._asset_codes is None: + self._asset_codes = asset_codes.to_list() +``` + +4. 
新增两个方法(放在 `evaluate_single` 之后,`evaluate_returns` 之前): + +```python + def get_asset_codes(self) -> List[str]: + """获取上一次计算得到的资产代码列表(按矩阵行顺序)。 + + Returns: + 股票代码列表,仅在 evaluate / evaluate_returns 调用后才可用。 + """ + if self._asset_codes is None: + raise RuntimeError("请先调用 evaluate() 或 evaluate_returns()") + return self._asset_codes + + def _get_metadata_df(self, columns: List[str]) -> Optional[pl.DataFrame]: + """从 DataRouter 拉取指定列的截面元数据(使用 end_date 作为参考日期)。""" + if not columns: + return None + try: + df = self.engine.router._load_table( + table_name="daily_basic", + columns=columns, + start_date=self.end_date, + end_date=self.end_date, + stock_codes=self.stock_codes, + ) + # 去重保留最新一条 + if "trade_date" in df.columns: + df = df.sort("trade_date", descending=True) + df = df.unique(subset=["ts_code"], maintain_order=True) + return df.select(["ts_code"] + columns) + except Exception as exc: + print(f"[WARN] 拉取股票池元数据失败: {exc}") + return None +``` + +5. 在 `evaluate_returns` 之后新增 `evaluate_returns_by_pool`: + +```python + def evaluate_returns_by_pool( + self, + periods: int = 1, + ) -> Dict[str, np.ndarray]: + """计算各股票池的收益率矩阵。 + + 如果未配置 stock_pool_registry,则返回仅包含 'all' 的字典。 + + Returns: + {pool_name: (M_pool, T) returns 矩阵} 字典。 + """ + returns_all = self.evaluate_returns(periods=periods) + result: Dict[str, np.ndarray] = {"all": returns_all} + + if self.stock_pool_registry is None: + return result + + codes = self.get_asset_codes() + req_cols = self.stock_pool_registry.get_required_columns() + metadata = self._get_metadata_df(req_cols) + self.stock_pool_registry.build_masks(codes, metadata_df=metadata) + + for name in self.stock_pool_registry.get_pool_names(): + if name == "all": + continue + result[name] = self.stock_pool_registry.filter_signals(returns_all, name) + + return result +``` + +**Step 4: 运行测试确认通过** + +```bash +uv run pytest src/factorminer/tests/test_local_engine.py -v +``` + +Expected: PASS + +**Step 5: Commit** + +```bash +git add 
src/factorminer/evaluation/local_engine.py src/factorminer/tests/test_local_engine.py +git commit -m "feat(factorminer): extend LocalFactorEvaluator to support multi-stock-pool returns" +``` + +--- + +## Task 3: 扩展 `Factor` 数据类保存股票池指标 + +**Files:** +- Modify: `src/factorminer/core/factor_library.py` +- Test: `src/factorminer/tests/test_library.py` + +**Step 1: 编写失败测试** + +在 `src/factorminer/tests/test_library.py` 中(如文件不存在则新建): + +```python +from src.factorminer.core.factor_library import Factor + + +def test_factor_pool_metrics_roundtrip(): + factor = Factor( + id=1, + name="test", + formula="close", + category="test", + ic_mean=0.05, + icir=0.5, + ic_win_rate=0.55, + max_correlation=0.3, + batch_number=1, + pool_metrics={ + "all": {"ic_abs_mean": 0.05, "icir": 0.5}, + "small_cap": {"ic_abs_mean": 0.08, "icir": 0.8}, + }, + ) + d = factor.to_dict() + assert "pool_metrics" in d + assert d["pool_metrics"]["small_cap"]["ic_abs_mean"] == 0.08 + + restored = Factor.from_dict(d) + assert restored.pool_metrics["small_cap"]["icir"] == 0.8 +``` + +**Step 2: 运行测试确认失败** + +```bash +uv run pytest src/factorminer/tests/test_library.py -v +``` + +Expected: FAIL (`pool_metrics` unexpected keyword) + +**Step 3: 最小实现** + +在 `src/factorminer/core/factor_library.py` 的 `Factor` 中: + +1. 新增字段: +```python + pool_metrics: Dict[str, dict] = field(default_factory=dict) +``` + +2. 修改 `to_dict`: +```python + "research_metrics": self.research_metrics, + "provenance": self.provenance, + "metadata": self.metadata, + "pool_metrics": self.pool_metrics, +``` + +3. 
修改 `from_dict`: +```python + research_metrics=d.get("research_metrics", {}), + provenance=d.get("provenance", {}), + metadata=d.get("metadata", {}), + pool_metrics=d.get("pool_metrics", {}), +``` + +**Step 4: 运行测试确认通过** + +```bash +uv run pytest src/factorminer/tests/test_library.py -v +``` + +Expected: PASS + +**Step 5: Commit** + +```bash +git add src/factorminer/core/factor_library.py src/factorminer/tests/test_library.py +git commit -m "feat(factorminer): add pool_metrics to Factor dataclass" +``` + +--- + +## Task 4: 扩展 `ValidationPipeline` 多股票池评估与入库门控 + +**Files:** +- Modify: `src/factorminer/core/ralph_loop.py` +- Test: `src/factorminer/tests/test_ralph_loop.py` + +**目标:** +- `ValidationPipeline` 支持传入 `returns: Dict[str, np.ndarray]`(多股票池)或 `np.ndarray`(单市场)。 +- 对每个 `target_panel` 计算 `compute_factor_stats`。 +- 以表现最好的股票池作为 admission gate(IC 和 ICIR 满足阈值即可)。 +- `EvaluationResult.target_stats` 保存所有池子指标。 + +**注意**:`EvaluationResult` 在 `src/factorminer/core/ralph_loop.py` 第 157 行**已经定义**了 `target_stats: Dict[str, dict] = field(default_factory=dict)`,因此本 Task **无需新增字段**,只需复用并填充多股票池数据即可。 + +**Step 1: 编写失败测试** + +在 `src/factorminer/tests/test_ralph_loop.py` 中新增: + +```python +class TestValidationPipelinePools: + @pytest.fixture + def pool_pipeline(self, synthetic_data, empty_library): + data_tensor, returns = synthetic_data + pool_returns = { + "all": returns, + "sub": returns[:5, :], # 模拟子池 + } + return ValidationPipeline( + data_tensor=data_tensor, + returns=pool_returns, + library=empty_library, + ic_threshold=0.02, + fast_screen_assets=0, + ) + + def test_multi_pool_target_stats(self, pool_pipeline): + # 构造一个 deterministic 信号 + M, T = pool_pipeline.returns.shape + signals = np.random.RandomState(7).randn(M, T) + result = pool_pipeline.evaluate_candidate( + "test", "Neg($close)", fast_screen=False, signals=signals + ) + assert result.parse_ok + assert "all" in result.target_stats + # 若 sub 池包含在 target_panels 中,也应存在 + assert "sub" in result.target_stats or "paper" 
in result.target_stats +``` + +**Step 2: 运行测试确认失败** + +```bash +uv run pytest src/factorminer/tests/test_ralph_loop.py::TestValidationPipelinePools -v +``` + +Expected: FAIL (ValidationPipeline 不支持 dict returns 初始化) + +**Step 3: 最小实现(仅修改 RalphLoop 中的 ValidationPipeline)** + +在 `src/factorminer/core/ralph_loop.py` 的 `ValidationPipeline.__init__` 中: + +1. 替换 `returns` 处理逻辑: +```python + # 支持单市场 (np.ndarray) 或多股票池 (Dict[str, np.ndarray]) + if isinstance(returns, dict): + self.returns = returns.get("all", next(iter(returns.values()))) + self.target_panels = returns + else: + self.returns = returns + self.target_panels = target_panels or {"paper": returns} +``` + +2. 在 `evaluate_candidate` 中,将原来的 Stats 计算部分替换为: + +找到原来这段( around line 356-369 ): +```python + # Full IC statistics on all assets + stats = compute_factor_stats(signals, self.returns) + result.ic_mean = stats["ic_abs_mean"] + result.icir = stats["icir"] + result.ic_win_rate = stats["ic_win_rate"] + result.target_stats = {"paper": stats} + + if self.target_panels: + for target_name, target_returns in self.target_panels.items(): + if target_name == "paper": + continue + result.target_stats[target_name] = compute_factor_stats( + signals, target_returns + ) +``` + +替换为: +```python + # 对所有 target_panels(含股票池)计算指标 + all_stats: Dict[str, dict] = {} + for panel_name, panel_returns in self.target_panels.items(): + # 当 panel 是子集时,signals 需要裁剪到对应维度 + panel_signals = signals + if panel_returns.shape[0] < signals.shape[0]: + panel_signals = signals[: panel_returns.shape[0], :] + all_stats[panel_name] = compute_factor_stats(panel_signals, panel_returns) + + # 选取表现最好的股票池作为 admission gate + best_panel_name = max(all_stats, key=lambda k: all_stats[k]["ic_abs_mean"]) + best_stats = all_stats[best_panel_name] + + result.ic_mean = best_stats["ic_abs_mean"] + result.icir = best_stats["icir"] + result.ic_win_rate = best_stats["ic_win_rate"] + result.target_stats = all_stats +``` + +注意:这里有一个维度匹配问题。对于真实的 
`StockPoolRegistry`,`filter_signals` 已经把 `signals` 裁剪到子池维度。但在 `ValidationPipeline` 中,`signals` 是全市场维度,而 `target_panels` 中的某些 panel(如 "sub")可能是子集维度。 + +**更好的设计**:`RalphLoop` 应该把 `signals` 和 `returns` 的维度对齐。实际上,在 `LocalFactorEvaluator` 的设计中,子池的 `signals` 和 `returns` 都应该是裁剪后的。但 `ValidationPipeline` 并不直接调用 evaluator 来按池裁剪 signals;signals 是在 `evaluate_candidate` 中计算的(全市场),而 `target_panels` 可能包含不同维度的 returns。 + +**修正思路**:在 `main.py` 中,当使用 stock pools 时,`LocalFactorEvaluator` 不再在 `RalphLoop` 层面使用——实际上 `RalphLoop` 的 `evaluator` 仍然是 `LocalFactorEvaluator`,它在 `evaluate_candidate` 中计算全市场 signals,而 `target_panels` 包含各池的 returns。 + +等等,在 `evaluate_candidate` 中,signals 是全市场 (M, T)。如果 `target_panels["small_cap"]` 是裁剪后的 (M_small, T),我们需要对 signals 也做同样的裁剪。但 `ValidationPipeline` 目前不知道每个 pool 对应的 asset mask。 + +所以更好的方案是:**`ValidationPipeline` 也接收掩码信息**,或者 **`target_panels` 全部是 (M, T) 但含 NaN**。不,最简单的方案是: + +**在 `main.py` 中,传入的 `returns` dict 的值都保持全市场 (M, T) 维度,只有对应池子的行有有效值,其余为 NaN。** + +这样 `compute_factor_stats` 自然的 NaN 处理机制会自动忽略非池子股票。 + +怎么做?修改 `LocalFactorEvaluator.evaluate_returns_by_pool`: + +```python + def evaluate_returns_by_pool(self, periods: int = 1) -> Dict[str, np.ndarray]: + returns_all = self.evaluate_returns(periods=periods) + result: Dict[str, np.ndarray] = {"all": returns_all} + + if self.stock_pool_registry is None: + return result + + codes = self.get_asset_codes() + req_cols = self.stock_pool_registry.get_required_columns() + metadata = self._get_metadata_df(req_cols) + self.stock_pool_registry.build_masks(codes, metadata_df=metadata) + + for name in self.stock_pool_registry.get_pool_names(): + if name == "all": + continue + mask = self.stock_pool_registry.masks[name] + pool_returns = np.full_like(returns_all, np.nan) + pool_returns[mask, :] = returns_all[mask, :] + result[name] = pool_returns + + return result +``` + +这样所有 `target_panels` 都是 `(M, T)` 维度,`compute_factor_stats` 会正确处理 NaN。太好了! 
+ +所以 **Step 3 的最小实现** 如下: + +在 `ValidationPipeline.__init__` 中: +```python + if isinstance(returns, dict): + self.returns = returns.get("all", next(iter(returns.values()))) + self.target_panels = returns + else: + self.returns = returns + self.target_panels = target_panels or {"paper": returns} +``` + +Stats 计算部分替换为: +```python + # 对所有 target_panels 计算指标(支持多股票池) + all_stats: Dict[str, dict] = {} + for panel_name, panel_returns in self.target_panels.items(): + all_stats[panel_name] = compute_factor_stats(signals, panel_returns) + + best_panel_name = max(all_stats, key=lambda k: all_stats[k]["ic_abs_mean"]) + best_stats = all_stats[best_panel_name] + + result.ic_mean = best_stats["ic_abs_mean"] + result.icir = best_stats["icir"] + result.ic_win_rate = best_stats["ic_win_rate"] + result.target_stats = all_stats +``` + +**同时**,需要把 `evaluate_returns_by_pool` 的实现按上面的 NaN-padding 方式修正(放到 Task 2 的后续补丁里,或直接在这里修正)。由于这是计划文档,我可以在 Task 2 的 Step 3 中先按 NaN-padding 写法,或者在这里指出需要回滚修改 Task 2。 + +更简单:在 Task 2 的初次实现中,我就让 `evaluate_returns_by_pool` 返回 NaN-padded 的全市场矩阵。这样 Task 4 就无需处理维度不匹配问题。 + +让我在 Task 2 的代码中已经写对了:`result[name] = self.stock_pool_registry.filter_signals(returns_all, name)` 这返回的是 `(M_small, T)`,确实不对。需要在实际执行时修改 Task 2 的实现为 NaN-padding 版本。 + +由于这是计划,我会在 Task 2 中写 NaN-padding 版本。 + +继续 Task 4 的 Step 3-5。 + +**Step 4: 运行测试确认通过** + +```bash +uv run pytest src/factorminer/tests/test_ralph_loop.py::TestValidationPipelinePools -v +``` + +Expected: PASS + +**Step 5: Commit** + +```bash +git add src/factorminer/core/ralph_loop.py src/factorminer/tests/test_ralph_loop.py +git commit -m "feat(factorminer): ValidationPipeline admission by best-performing stock pool" +``` + +--- + +## Task 5: `RalphLoop` 接收多股票池收益率并保存 `pool_metrics` + +**Files:** +- Modify: `src/factorminer/core/ralph_loop.py` +- Test: `src/factorminer/tests/test_ralph_loop.py` + +**Step 1: 编写失败测试** + +```python +class TestRalphLoopPools: + def test_loop_accepts_dict_returns(self, test_config, synthetic_data, 
mock_provider, tmp_dir): + data_tensor, returns = synthetic_data + pool_returns = {"all": returns, "sub": returns.copy()} + test_config.output_dir = tmp_dir + test_config.max_iterations = 1 + + loop = RalphLoop( + config=test_config, + returns=pool_returns, + llm_provider=mock_provider, + ) + library = loop.run(max_iterations=1, target_size=200) + assert isinstance(library, FactorLibrary) + assert loop.returns is returns # "all" 被当成默认 +``` + +**Step 2: 运行测试确认失败** + +```bash +uv run pytest src/factorminer/tests/test_ralph_loop.py::TestRalphLoopPools -v +``` + +Expected: FAIL (RalphLoop 初始化不支持 dict returns) + +**Step 3: 最小实现** + +在 `RalphLoop.__init__` 中: + +```python + # 支持多股票池收益率字典 + if isinstance(returns, dict): + self.returns = returns.get("all", next(iter(returns.values()))) + self.target_panels = returns + else: + self.returns = returns + self.target_panels = getattr(config, "target_panels", None) +``` + +然后 `ValidationPipeline` 初始化时 `target_panels=self.target_panels`。 + +在 `RalphLoop._update_library` 中,修改创建 `Factor` 的两处代码,加入 `pool_metrics=result.target_stats`: + +```python + new_factor = Factor( + id=0, + name=result.factor_name, + formula=result.formula, + category=self._infer_category(result.formula), + ic_mean=result.ic_mean, + icir=result.icir, + ic_win_rate=result.ic_win_rate, + max_correlation=result.max_correlation, + batch_number=self.iteration, + signals=result.signals, + research_metrics=result.score_vector or {}, + pool_metrics=result.target_stats, + ) +``` + +两处都要加(replace branch 和 direct admission branch)。 + +**Step 4: 运行测试确认通过** + +```bash +uv run pytest src/factorminer/tests/test_ralph_loop.py::TestRalphLoopPools -v +``` + +Expected: PASS + +**Step 5: Commit** + +```bash +git add src/factorminer/core/ralph_loop.py src/factorminer/tests/test_ralph_loop.py +git commit -m "feat(factorminer): RalphLoop supports dict returns and persists pool_metrics" +``` + +--- + +## Task 6: 创建用户可配置的股票池定义文件 + +**Files:** +- Create: `src/factorminer/pool_definitions.py` 
+ +**Step 1: 编写文件** + +```python +"""用户可配置的股票池定义。 + +在此文件中定义所有需要在 FactorMiner 中评估的股票池。 +示例包含了全市场、创业板、科创板、北交所、小市值等常见股票池。 + +参考:`src/experiment/common.py` 中的 `stock_pool_filter` 设计。 +""" + +import polars as pl + +from src.factorminer.evaluation.stock_pool_registry import StockPoolRegistry + + +def get_default_stock_pools() -> StockPoolRegistry: + """返回默认的股票池注册表。 + + 用户可在此函数中增删股票池,或编写自己的 `get_xxx_pools()` 函数 + 并在 `main.py` 的 `RUN_CONFIG` / 命令行参数中指定使用。 + """ + registry = StockPoolRegistry() + + # 1. 全市场 + registry.add_pool("all", lambda df: pl.Series([True] * len(df))) + + # 2. 创业板 (代码以 300 开头) + registry.add_pool("growth", lambda df: df["ts_code"].str.starts_with("300")) + + # 3. 科创板 (代码以 688 开头) + registry.add_pool("star", lambda df: df["ts_code"].str.starts_with("688")) + + # 4. 北交所 (代码以 8 或 4 开头) + registry.add_pool( + "bse", + lambda df: ( + df["ts_code"].str.starts_with("8") | df["ts_code"].str.starts_with("4") + ), + ) + + # 5. 主板(排除创业板/科创板/北交所) + registry.add_pool( + "main_board", + lambda df: ( + ~df["ts_code"].str.starts_with("300") + & ~df["ts_code"].str.starts_with("688") + & ~df["ts_code"].str.starts_with("8") + & ~df["ts_code"].str.starts_with("4") + & ~df["ts_code"].str.starts_with("9") + ), + ) + + # 6. 
小微盘(示例:每日截面市值最小的 1000 只股票) + # 注意:该过滤器依赖 daily_basic.total_mv,因此需要 required_columns + def _small_cap_filter(df: pl.DataFrame) -> pl.Series: + if "total_mv" not in df.columns: + # 若缺失数据则全部排除(安全降级) + return pl.Series([False] * len(df)) + n = min(1000, len(df)) + small_codes = df.sort("total_mv").head(n)["ts_code"] + return df["ts_code"].is_in(small_codes) + + registry.add_pool( + "small_cap", + _small_cap_filter, + required_columns=["total_mv"], + ) + + return registry +``` + +**Step 2: Commit** + +```bash +git add src/factorminer/pool_definitions.py +git commit -m "feat(factorminer): add user-configurable stock pool definitions" +``` + +--- + +## Task 7: 改造 `main.py` 支持股票池配置 + +**Files:** +- Modify: `src/factorminer/main.py` + +**Step 1: 导入依赖** + +在文件顶部加入: + +```python +from src.factorminer.evaluation.stock_pool_registry import StockPoolRegistry +from src.factorminer.pool_definitions import get_default_stock_pools +``` + +**Step 2: 修改 `RUN_CONFIG`** + +在 `RUN_CONFIG` 中新增 `stock_pools` 段: +```python + # 股票池配置 + "stock_pools": { + "enabled": True, + "provider": "default", # "default" 使用 pool_definitions.py 中的 get_default_stock_pools + }, +``` + +**Step 3: 新增辅助函数 `_build_stock_pool_registry`** + +放在 `_build_core_mining_config` 附近: + +```python +def _build_stock_pool_registry(run_cfg: dict) -> Optional[StockPoolRegistry]: + """根据 RUN_CONFIG 构建股票池注册表。""" + pool_cfg = run_cfg.get("stock_pools", {}) + if not pool_cfg.get("enabled", False): + return None + + provider = pool_cfg.get("provider", "default") + if provider == "default": + return get_default_stock_pools() + + # 未来可扩展自定义 provider 路径 + raise ValueError(f"不支持的股票池 provider: {provider}") +``` + +**Step 4: 修改 `main()` 中的 evaluator 和 returns 逻辑** + +找到这段代码(原 214-226 行附近): +```python + evaluator = LocalFactorEvaluator( + start_date=start_date, + end_date=end_date, + stock_codes=stock_codes, + ) + returns = evaluator.evaluate_returns(periods=1) +``` + +替换为: +```python + stock_pool_registry = 
_build_stock_pool_registry(run_cfg) + if stock_pool_registry is not None: + print(f"[main] 已启用股票池评估: {stock_pool_registry.get_pool_names()}") + + evaluator = LocalFactorEvaluator( + start_date=start_date, + end_date=end_date, + stock_codes=stock_codes, + stock_pool_registry=stock_pool_registry, + ) + + if stock_pool_registry is not None: + returns = evaluator.evaluate_returns_by_pool(periods=1) + print( + f"[main] 本地数据范围: {start_date} ~ {end_date}, " + f"各股票池资产数: {{k: v.shape[0] for k, v in returns.items()}}" + ) + else: + returns = evaluator.evaluate_returns(periods=1) + print( + f"[main] 本地数据范围: {start_date} ~ {end_date}, " + f"returns shape: {returns.shape}" + ) +``` + +**Step 5: 确保 `evaluator` 正确传入 `RalphLoop` / `HelixLoop`** + +检查原代码中 `LoopCls` 初始化是否已传入 `evaluator=evaluator`,如果是(当前代码已有),则无需修改。确认 `resume_from` 路径也传入了 `evaluator=evaluator`。 + +**Step 6: 运行集成测试** + +```bash +uv run pytest src/factorminer/tests/test_ralph_loop.py -v +``` + +Expected: PASS(或只出现与 stock pool 无关的既有失败) + +**Step 7: Commit** + +```bash +git add src/factorminer/main.py +git commit -m "feat(factorminer): integrate multi-stock-pool evaluation into main entrypoint" +``` + +--- + +## Task 8: 修复 Task 2 中 `evaluate_returns_by_pool` 的维度对齐问题 + +**Files:** +- Modify: `src/factorminer/evaluation/local_engine.py` + +在 Task 2 的初次实现中,`evaluate_returns_by_pool` 使用了 `filter_signals`,这会返回裁剪后的子矩阵。为了让 `ValidationPipeline` 中所有 `target_panels` 保持统一的 `(M, T)` 维度(从而无需修改 `compute_factor_stats` 的调用方式),**必须改为 NaN-padding 版本**。 + +**Step 1: 修改 `evaluate_returns_by_pool`** + +```python + def evaluate_returns_by_pool( + self, + periods: int = 1, + ) -> Dict[str, np.ndarray]: + """计算各股票池的收益率矩阵。 + + 返回的每个矩阵维度均为 (M_all, T),但非股票池内的资产行被填充为 NaN, + 以便下游 ValidationPipeline 统一处理。 + + Returns: + {pool_name: (M_all, T) returns 矩阵} 字典。 + """ + returns_all = self.evaluate_returns(periods=periods) + result: Dict[str, np.ndarray] = {"all": returns_all} + + if self.stock_pool_registry is None: + return result + + codes = 
self.get_asset_codes() + req_cols = self.stock_pool_registry.get_required_columns() + metadata = self._get_metadata_df(req_cols) + self.stock_pool_registry.build_masks(codes, metadata_df=metadata) + + for name in self.stock_pool_registry.get_pool_names(): + if name == "all": + continue + mask = self.stock_pool_registry.masks[name] + pool_returns = np.full_like(returns_all, np.nan, dtype=np.float64) + pool_returns[mask, :] = returns_all[mask, :] + result[name] = pool_returns + + return result +``` + +**Step 2: 更新测试** + +修改 `src/factorminer/tests/test_local_engine.py` 中的断言: + +```python + assert pool_returns["growth"].shape == (3, 1) # 维度保持全市场 + np.testing.assert_array_equal( + pool_returns["growth"], + [[np.nan], [0.02], [np.nan]], + ) +``` + +**Step 3: 运行测试** + +```bash +uv run pytest src/factorminer/tests/test_local_engine.py -v +``` + +Expected: PASS + +**Step 4: Commit** + +```bash +git add src/factorminer/evaluation/local_engine.py src/factorminer/tests/test_local_engine.py +git commit -m "fix(factorminer): pad out-of-pool returns with NaN to keep consistent matrix shape" +``` + +--- + +## Task 9: 更新 `test_evaluation.py` 验证多股票池 `compute_factor_stats` 的 NaN 行为 + +**Files:** +- Modify: `src/factorminer/tests/test_evaluation.py` + +**Step 1: 新增测试用例** + +```python +class TestFactorStatsPools: + def test_factor_stats_with_nan_rows(self, rng): + """验证 NaN 行能被 compute_factor_stats 正确忽略(用于多股票池场景)。""" + M, T = 30, 40 + signals = rng.normal(0, 1, (M, T)) + returns = rng.normal(0, 0.01, (M, T)) + # 将一半资产设为 NaN,模拟非股票池内资产 + signals[15:, :] = np.nan + returns[15:, :] = np.nan + stats = compute_factor_stats(signals, returns) + assert "ic_mean" in stats + assert stats["n_periods"] == T # 每期仍有足够有效样本 +``` + +**Step 2: 运行测试** + +```bash +uv run pytest src/factorminer/tests/test_evaluation.py::TestFactorStatsPools -v +``` + +Expected: PASS + +**Step 3: Commit** + +```bash +git add src/factorminer/tests/test_evaluation.py +git commit -m "test(factorminer): ensure 
compute_factor_stats handles NaN rows for stock pools" +``` + +--- + +## Task 10: 全量测试与回归验证 + +**Step 1: 运行 factorminer 全部测试** + +```bash +uv run pytest src/factorminer/tests/ -v +``` + +Expected: 所有既有测试通过,新增测试全部 PASS。若出现失败,定位并修复。 + +**Step 2: 运行核心项目测试(确保没有破坏 factors / experiment)** + +```bash +uv run pytest tests/test_factor_engine.py tests/test_factor_integration.py -v +``` + +Expected: PASS + +**Step 3: Commit(如仅测试通过,无代码改动可跳过)** + +--- + +## 附录:用户使用说明 + +### 如何添加自定义股票池? + +编辑 `src/factorminer/pool_definitions.py`: + +```python +def get_default_stock_pools() -> StockPoolRegistry: + registry = StockPoolRegistry() + # ... 既有池子 ... + + # 自定义:只保留上证 50 成分股(示例) + registry.add_pool( + "sz50", + lambda df: df["ts_code"].is_in(["600519.SH", "600036.SH", ...]), + ) + + return registry +``` + +### 如何禁用股票池功能? + +在 `main.py` 的 `RUN_CONFIG` 中: + +```python + "stock_pools": { + "enabled": False, + }, +``` + +### 入库规则 + +- 因子在 **任一** 配置的股票池中 IC_mean >= `ic_threshold` 且 ICIR >= `icir_threshold`,即可通过 Stage 1。 +- 相关性检查仍在 **全市场 signals** 上进行,与现有逻辑保持一致。 +- 最终入库时,`Factor.pool_metrics` 会记录 **所有股票池** 的 `compute_factor_stats` 完整指标。 diff --git a/src/factorminer/core/factor_library.py b/src/factorminer/core/factor_library.py index 5549d5d..df14988 100644 --- a/src/factorminer/core/factor_library.py +++ b/src/factorminer/core/factor_library.py @@ -42,6 +42,7 @@ class Factor: batch_number: int # Which mining batch admitted this factor admission_date: str = "" signals: Optional[np.ndarray] = field(default=None, repr=False) # (M, T) + pool_metrics: Dict[str, dict] = field(default_factory=dict) research_metrics: dict = field(default_factory=dict) provenance: dict = field(default_factory=dict) metadata: dict = field(default_factory=dict) @@ -50,6 +51,20 @@ class Factor: if not self.admission_date: self.admission_date = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + @staticmethod + def _sanitize_pool_metrics(pool_metrics: Dict[str, dict]) -> Dict[str, dict]: + sanitized: Dict[str, dict] = {} + for 
pool_name, stats in pool_metrics.items(): + sanitized[pool_name] = {} + for key, value in stats.items(): + if key == "ic_series": + continue + if isinstance(value, np.ndarray): + sanitized[pool_name][key] = value.tolist() + else: + sanitized[pool_name][key] = value + return sanitized + def to_dict(self) -> dict: """Serialize to a JSON-compatible dictionary (excludes signals).""" return { @@ -63,6 +78,7 @@ class Factor: "max_correlation": self.max_correlation, "batch_number": self.batch_number, "admission_date": self.admission_date, + "pool_metrics": self._sanitize_pool_metrics(self.pool_metrics), "research_metrics": self.research_metrics, "provenance": self.provenance, "metadata": self.metadata, @@ -82,6 +98,7 @@ class Factor: max_correlation=d["max_correlation"], batch_number=d["batch_number"], admission_date=d.get("admission_date", ""), + pool_metrics=d.get("pool_metrics", {}), research_metrics=d.get("research_metrics", {}), provenance=d.get("provenance", {}), metadata=d.get("metadata", {}), diff --git a/src/factorminer/core/helix_loop.py b/src/factorminer/core/helix_loop.py index f46023c..632d8b4 100644 --- a/src/factorminer/core/helix_loop.py +++ b/src/factorminer/core/helix_loop.py @@ -18,7 +18,7 @@ import logging import re import time from pathlib import Path -from typing import Any, Dict, List, Optional, Set, Tuple +from typing import Any, Dict, List, Optional, Set, Tuple, Union import numpy as np @@ -229,7 +229,7 @@ class HelixLoop(RalphLoop): def __init__( self, config: Any, - returns: np.ndarray, + returns: Union[np.ndarray, Dict[str, np.ndarray]], llm_provider: Optional[LLMProvider] = None, memory: Optional[ExperienceMemory] = None, library: Optional[FactorLibrary] = None, diff --git a/src/factorminer/core/ralph_loop.py b/src/factorminer/core/ralph_loop.py index 09a215d..7c8dbf3 100644 --- a/src/factorminer/core/ralph_loop.py +++ b/src/factorminer/core/ralph_loop.py @@ -257,9 +257,13 @@ class ValidationPipeline: evaluator: Any = None, ) -> None: 
self.data_tensor = data_tensor # (M, T, F) - self.returns = returns # (M, T) + if isinstance(returns, dict): + self.returns = returns.get("all", next(iter(returns.values()))) + self.target_panels = returns + else: + self.returns = returns # (M, T) + self.target_panels = target_panels or {"paper": returns} self.evaluator = evaluator - self.target_panels = target_panels or {"paper": returns} self.target_horizons = target_horizons or {"paper": 1} self.library = library or FactorLibrary( correlation_threshold=0.5, @@ -353,20 +357,18 @@ class ValidationPipeline: result.stage_passed = 0 return result - # Full IC statistics on all assets - stats = compute_factor_stats(signals, self.returns) - result.ic_mean = stats["ic_abs_mean"] - result.icir = stats["icir"] - result.ic_win_rate = stats["ic_win_rate"] - result.target_stats = {"paper": stats} + # Full IC statistics on all target panels (stock pools) + all_stats: Dict[str, dict] = {} + for panel_name, panel_returns in self.target_panels.items(): + all_stats[panel_name] = compute_factor_stats(signals, panel_returns) - if self.target_panels: - for target_name, target_returns in self.target_panels.items(): - if target_name == "paper": - continue - result.target_stats[target_name] = compute_factor_stats( - signals, target_returns - ) + best_panel_name = max(all_stats, key=lambda k: all_stats[k]["ic_abs_mean"]) + best_stats = all_stats[best_panel_name] + + result.ic_mean = best_stats["ic_abs_mean"] + result.icir = best_stats["icir"] + result.ic_win_rate = best_stats["ic_win_rate"] + result.target_stats = all_stats score_vector_obj = None if self._research_enabled(): @@ -737,7 +739,7 @@ class RalphLoop: def __init__( self, config: Any, - returns: np.ndarray, + returns: Union[np.ndarray, Dict[str, np.ndarray]], llm_provider: Optional[LLMProvider] = None, memory: Optional[ExperienceMemory] = None, library: Optional[FactorLibrary] = None, @@ -767,7 +769,12 @@ class RalphLoop: the legacy data_tensor path. 
""" self.config = config - self.returns = returns + if isinstance(returns, dict): + self.returns = returns.get("all", next(iter(returns.values()))) + self.target_panels = returns + else: + self.returns = returns + self.target_panels = getattr(config, "target_panels", None) self.evaluator = evaluator self.checkpoint_interval = checkpoint_interval @@ -783,9 +790,9 @@ class RalphLoop: prompt_builder=PromptBuilder(), ) self.pipeline = ValidationPipeline( - data_tensor=np.empty((returns.shape[0], returns.shape[1], 1)), - returns=returns, - target_panels=getattr(config, "target_panels", None), + data_tensor=np.empty((self.returns.shape[0], self.returns.shape[1], 1)), + returns=self.returns, + target_panels=self.target_panels, target_horizons=getattr(config, "target_horizons", None), library=self.library, ic_threshold=getattr(config, "ic_threshold", 0.04), @@ -1082,6 +1089,14 @@ class RalphLoop: # Handle replacement if result.replaced is not None: old_id = result.replaced + if result.target_stats: + target_stats_clean = {} + for pool_name, stats in result.target_stats.items(): + target_stats_clean[pool_name] = { + k: v for k, v in stats.items() if k != "ic_series" + } + else: + target_stats_clean = {} new_factor = Factor( id=0, # Will be reassigned by library name=result.factor_name, @@ -1093,6 +1108,7 @@ class RalphLoop: max_correlation=result.max_correlation, batch_number=self.iteration, signals=result.signals, + pool_metrics=target_stats_clean, research_metrics=result.score_vector or {}, ) try: @@ -1110,6 +1126,14 @@ class RalphLoop: ) else: # Direct admission + if result.target_stats: + target_stats_clean = {} + for pool_name, stats in result.target_stats.items(): + target_stats_clean[pool_name] = { + k: v for k, v in stats.items() if k != "ic_series" + } + else: + target_stats_clean = {} factor = Factor( id=0, # Will be reassigned name=result.factor_name, @@ -1121,6 +1145,7 @@ class RalphLoop: max_correlation=result.max_correlation, batch_number=self.iteration, 
signals=result.signals, + pool_metrics=target_stats_clean, research_metrics=result.score_vector or {}, ) self.library.admit_factor(factor) @@ -1435,7 +1460,7 @@ class RalphLoop: cls, checkpoint_path: str, config: Any, - returns: np.ndarray, + returns: Union[np.ndarray, Dict[str, np.ndarray]], llm_provider: Optional[LLMProvider] = None, **kwargs: Any, ) -> "RalphLoop": diff --git a/src/factorminer/evaluation/local_engine.py b/src/factorminer/evaluation/local_engine.py index c1af262..a163b49 100644 --- a/src/factorminer/evaluation/local_engine.py +++ b/src/factorminer/evaluation/local_engine.py @@ -19,6 +19,7 @@ import numpy as np import polars as pl from src.factors import FactorEngine +from src.factorminer.evaluation.stock_pool_registry import StockPoolRegistry class LocalFactorEvaluator: @@ -40,6 +41,7 @@ class LocalFactorEvaluator: start_date: str, end_date: str, stock_codes: Optional[List[str]] = None, + stock_pool_registry: Optional[StockPoolRegistry] = None, ) -> None: """初始化评估器。 @@ -47,11 +49,14 @@ class LocalFactorEvaluator: start_date: 计算开始日期,YYYYMMDD 格式 end_date: 计算结束日期,YYYYMMDD 格式 stock_codes: 可选的股票代码列表,None 表示全量 + stock_pool_registry: 股票池注册表,用于多股票池评估 """ self.start_date = start_date self.end_date = end_date self.stock_codes = stock_codes self.engine = FactorEngine() + self.stock_pool_registry = stock_pool_registry + self._asset_codes: Optional[List[str]] = None def evaluate( self, @@ -166,6 +171,76 @@ class LocalFactorEvaluator: print(f"[ERROR] 计算收益率矩阵失败: {e}") raise + def get_asset_codes(self) -> List[str]: + """返回最近一次 evaluate/evaluate_returns 中的资产代码列表。 + + Returns: + 按字母序排列的股票代码列表 (M,)。 + """ + if self._asset_codes is None: + raise RuntimeError("请先调用 evaluate() 或 evaluate_returns()") + return self._asset_codes + + def _get_metadata_df(self, columns: List[str]) -> Optional[pl.DataFrame]: + """拉取构建股票池所需的元数据(默认取 end_date 最新截面)。 + + Args: + columns: 需要的额外列名列表。 + + Returns: + 包含 ts_code 和 columns 的 DataFrame,若失败则返回 None。 + """ + if not columns: + return 
None + try: + df = self.engine.router._load_table( + table_name="daily_basic", + columns=columns, + start_date=self.end_date, + end_date=self.end_date, + stock_codes=self.stock_codes, + ) + if "trade_date" in df.columns: + df = df.sort("trade_date", descending=True) + df = df.unique(subset=["ts_code"], maintain_order=True) + return df.select(["ts_code"] + columns) + except Exception as exc: + print(f"[WARN] 拉取股票池元数据失败: {exc}") + return None + + def evaluate_returns_by_pool(self, periods: int = 1) -> Dict[str, np.ndarray]: + """计算各股票池的收益率矩阵。 + + 先计算全市场收益率矩阵,然后根据注册的股票池掩码生成子池矩阵。 + 未入选该池的资产在所有时间上的收益率标记为 NaN。 + + Args: + periods: 计算 N 日后的收益率,默认 1。 + + Returns: + {pool_name: (M, T) np.ndarray} 字典,包含所有注册的股票池和 "all"。 + """ + returns_all = self.evaluate_returns(periods=periods) + result: Dict[str, np.ndarray] = {"all": returns_all} + + if self.stock_pool_registry is None: + return result + + codes = self.get_asset_codes() + req_cols = self.stock_pool_registry.get_required_columns() + metadata = self._get_metadata_df(req_cols) + self.stock_pool_registry.build_masks(codes, metadata_df=metadata) + + for name in self.stock_pool_registry.get_pool_names(): + if name == "all": + continue + mask = self.stock_pool_registry.masks[name] + pool_returns = np.full_like(returns_all, np.nan, dtype=np.float64) + pool_returns[mask, :] = returns_all[mask, :] + result[name] = pool_returns + + return result + def _pivot_to_matrix( self, df: pl.DataFrame, @@ -191,6 +266,8 @@ class LocalFactorEvaluator: # 获取时间戳和股票代码的唯一值(已排序) timestamps = df["trade_date"].unique().sort() asset_codes = df["ts_code"].unique().sort() + if self._asset_codes is None: + self._asset_codes = asset_codes.to_list() n_assets = len(asset_codes) n_times = len(timestamps) diff --git a/src/factorminer/evaluation/stock_pool_registry.py b/src/factorminer/evaluation/stock_pool_registry.py new file mode 100644 index 0000000..297c378 --- /dev/null +++ b/src/factorminer/evaluation/stock_pool_registry.py @@ -0,0 +1,150 @@ 
+"""股票池注册表:支持配置化股票池筛选与掩码生成。""" + +from dataclasses import dataclass, field +from typing import Callable, Dict, List, Optional + +import numpy as np +import polars as pl + + +@dataclass +class StockPoolDefinition: + """股票池定义,参考 experiment/common.py 中的 filter 模式。""" + + name: str + filter_func: Callable[[pl.DataFrame], pl.Series | pl.Expr] + required_columns: List[str] = field(default_factory=list) + description: str = "" + + +class StockPoolRegistry: + """管理多个股票池的定义,并为 (M,) 资产列表生成布尔掩码。""" + + def __init__( + self, + definitions: Optional[List[StockPoolDefinition]] = None, + pre_filters: Optional[ + List[Callable[[pl.DataFrame], pl.Series | pl.Expr]] + ] = None, + ) -> None: + self.pre_filters: List[Callable[[pl.DataFrame], pl.Series | pl.Expr]] = ( + pre_filters or [] + ) + self.pools: Dict[str, dict] = {} + self.masks: Dict[str, np.ndarray] = {} + self._resolved = False + if definitions: + for d in definitions: + self.add_pool(d.name, d.filter_func, d.required_columns) + + @classmethod + def from_definitions( + cls, + definitions: List[StockPoolDefinition], + pre_filters: Optional[ + List[Callable[[pl.DataFrame], pl.Series | pl.Expr]] + ] = None, + ) -> "StockPoolRegistry": + """通过定义列表快速创建注册表。""" + return cls(definitions=definitions, pre_filters=pre_filters) + + def add_pool( + self, + name: str, + filter_func: Callable[[pl.DataFrame], pl.Series | pl.Expr], + required_columns: Optional[List[str]] = None, + ) -> None: + """注册一个股票池。 + + Args: + name: 股票池名称,如 "small_cap"。 + filter_func: 接收包含 ts_code 和 required_columns 的 DataFrame, + 返回布尔 Series 或 Expr。 + required_columns: filter_func 额外需要的列名列表。 + """ + self.pools[name] = { + "filter_func": filter_func, + "required_columns": required_columns or [], + } + self._resolved = False + + @staticmethod + def _eval_filter_result( + df: pl.DataFrame, + result: pl.Series | pl.Expr, + context: str, + ) -> np.ndarray: + """将 filter_func 的返回值统一转换为 numpy bool 数组。""" + if isinstance(result, pl.Series): + mask_series = result + elif 
isinstance(result, pl.Expr): + mask_series = df.select(result.alias("_mask")).to_series() + else: + raise TypeError( + f"{context} 的 filter_func 必须返回 pl.Series 或 pl.Expr," + f"实际返回 {type(result)}" + ) + return mask_series.to_numpy().astype(bool) + + def build_masks( + self, + asset_codes: List[str], + metadata_df: Optional[pl.DataFrame] = None, + ) -> Dict[str, np.ndarray]: + """为所有注册的股票池生成布尔掩码。 + + Args: + asset_codes: 按矩阵顺序排列的股票代码列表 (M,)。 + metadata_df: 包含 extra columns 的 metadata,可选。 + + Returns: + {pool_name: bool_array_of_shape_(M,)} 字典。 + """ + base_df = pl.DataFrame({"ts_code": asset_codes}) + df = base_df + if metadata_df is not None: + df = base_df.join(metadata_df, on="ts_code", how="left") + + base_mask = np.ones(len(asset_codes), dtype=bool) + for pre_filter in self.pre_filters: + result = pre_filter(df) + base_mask &= self._eval_filter_result(df, result, "pre_filter") + + self.masks = {} + for name, cfg in self.pools.items(): + result = cfg["filter_func"](df) + pool_mask = self._eval_filter_result(df, result, f"股票池 '{name}'") + self.masks[name] = base_mask & pool_mask + + self._resolved = True + return self.masks + + def filter_signals(self, signals: np.ndarray, pool_name: str) -> np.ndarray: + """使用已构建的掩码过滤信号矩阵。 + + Args: + signals: (M, T) 信号矩阵。 + pool_name: 股票池名称。 + + Returns: + (M_pool, T) 的子矩阵。 + """ + if not self._resolved: + raise RuntimeError("请先调用 build_masks 构建掩码") + mask = self.masks.get(pool_name) + if mask is None: + raise KeyError(f"未知的股票池: {pool_name}") + return signals[mask, :] + + def get_pool_names(self) -> List[str]: + """返回已注册的股票池名称列表。""" + return list(self.pools.keys()) + + def get_required_columns(self) -> List[str]: + """返回所有 pre_filter 和股票池所需的列名并集。""" + cols: set[str] = set() + for pre_filter in self.pre_filters: + cols.update(getattr(pre_filter, "required_columns", [])) + for cfg in self.pools.values(): + cols.update(cfg["required_columns"]) + return sorted(cols) diff --git a/src/factorminer/main.py b/src/factorminer/main.py 
index d2a9cb7..5412810 100644 --- a/src/factorminer/main.py +++ b/src/factorminer/main.py @@ -11,6 +11,8 @@ from typing import Any, Optional import numpy as np +import polars as pl + from src.config.settings import get_settings from src.factorminer.agent.llm_interface import create_provider, AnthropicProvider from src.factorminer.core.config import MiningConfig as CoreMiningConfig @@ -19,7 +21,10 @@ from src.factorminer.core.ralph_loop import RalphLoop from src.factorminer.core.helix_loop import HelixLoop from src.factorminer.evaluation.local_engine import LocalFactorEvaluator from src.factorminer.evaluation.significance import SignificanceConfig +from src.factorminer.evaluation.stock_pool_registry import StockPoolRegistry +from src.factorminer.pool_definitions import get_default_stock_pool_definitions from src.factorminer.utils.config import load_config +from src.training.components.filters import STFilter RUN_CONFIG: dict = { # 全局开关 @@ -35,9 +40,16 @@ RUN_CONFIG: dict = { "stock_codes": None, # 可选股票列表,None 表示全量 # 种子库 "seed_paper_library": True, # 是否预加载 110 Paper Factors 作为种子库 + # 股票池配置(支持多股票池评估) + "stock_pools": { + "enabled": True, # 是否启用多股票池 + "provider": "default", # "default" 使用 pool_definitions.py 中的默认定义 + "use_st_filter": True, # 是否在构建股票池前过滤 ST 股票 + "pools": ["all", "small_cap"], # 要启用的股票池名称列表,None 表示全部启用 + }, # 挖掘配置(覆盖 default.yaml 中的同名项) "mining": { - "max_iterations": 3, + "max_iterations": 1, "target_library_size": 10, "correlation_threshold": 0.50, "ic_threshold": 0.04, @@ -139,6 +151,54 @@ def _build_core_mining_config(run_cfg: dict) -> CoreMiningConfig: return cfg +def _build_st_pre_filter(st_filter: STFilter, end_date: str): + """将 STFilter 包装为 StockPoolRegistry 可接受的 pre-filter 函数。""" + + def _pre_filter(df: pl.DataFrame) -> pl.Series: + try: + if "trade_date" not in df.columns: + df = df.with_columns(pl.lit(end_date).alias("trade_date")) + filtered = st_filter.filter(df) + return df["ts_code"].is_in(filtered["ts_code"]) + except Exception as exc: + 
print(f"[WARN] ST pre-filter 失败: {exc}") + return pl.Series([True] * len(df)) + + return _pre_filter + + +def _build_stock_pool_registry( + run_cfg: dict, + st_filter: Optional[STFilter] = None, + end_date: str = "20231231", +) -> Optional[StockPoolRegistry]: + """根据 RUN_CONFIG 构建股票池注册表。""" + pool_cfg = run_cfg.get("stock_pools", {}) + if not pool_cfg.get("enabled", False): + return None + + pre_filters = [] + if pool_cfg.get("use_st_filter", False) and st_filter is not None: + pre_filters.append(_build_st_pre_filter(st_filter, end_date)) + + provider = pool_cfg.get("provider", "default") + if provider == "default": + definitions = get_default_stock_pool_definitions() + pool_names = pool_cfg.get("pools") + if pool_names is not None: + definitions = [d for d in definitions if d.name in pool_names] + if not definitions: + raise ValueError( + f"股票池配置错误: pools={pool_names} 未匹配到任何默认股票池定义" + ) + missing = set(pool_names) - {d.name for d in definitions} + if missing: + print(f"[WARN] 以下股票池未在默认定义中找到,已忽略: {sorted(missing)}") + return StockPoolRegistry.from_definitions(definitions, pre_filters=pre_filters) + + raise ValueError(f"不支持的股票池 provider: {provider}") + + def _build_helix_kwargs(run_cfg: dict) -> dict: """从 RUN_CONFIG 构建 HelixLoop 需要的 Phase 2 扩展配置。""" helix = run_cfg.get("helix", {}) @@ -215,15 +275,34 @@ def main(config: dict | None = None) -> None: end_date = run_cfg.get("end_date", "20201231") stock_codes = run_cfg.get("stock_codes") + # 4.1 创建 evaluator(先实例化才能拿到 router) evaluator = LocalFactorEvaluator( start_date=start_date, end_date=end_date, stock_codes=stock_codes, ) - returns = evaluator.evaluate_returns(periods=1) - print( - f"[main] 本地数据范围: {start_date} ~ {end_date}, returns shape: {returns.shape}" + + # 4.2 构建 STFilter 和股票池注册表 + st_filter = STFilter(data_router=evaluator.engine.router) + stock_pool_registry = _build_stock_pool_registry( + run_cfg, st_filter=st_filter, end_date=end_date ) + evaluator.stock_pool_registry = stock_pool_registry + + if 
stock_pool_registry is not None: + returns = evaluator.evaluate_returns_by_pool(periods=1) + for pool_name, pool_ret in returns.items(): + valid_assets = int((~np.isnan(pool_ret).all(axis=1)).sum()) + print( + f"[main] 股票池 {pool_name}: {valid_assets} 只资产, " + f"returns shape: {pool_ret.shape}" + ) + else: + returns = evaluator.evaluate_returns(periods=1) + print( + f"[main] 本地数据范围: {start_date} ~ {end_date}, " + f"returns shape: {returns.shape}" + ) # ------------------------------------------------------------------ # 5. 构建 MiningConfig diff --git a/src/factorminer/pool_definitions.py b/src/factorminer/pool_definitions.py new file mode 100644 index 0000000..4e5f68b --- /dev/null +++ b/src/factorminer/pool_definitions.py @@ -0,0 +1,171 @@ +"""用户可配置的股票池定义文件。 + +参考 `src/experiment/common.py` 中的 `stock_pool_filter` 模式, +支持通过编写 Polars filter 函数来自定义股票池。 +""" + +from typing import Callable, List + +import polars as pl + +from src.factorminer.evaluation.stock_pool_registry import StockPoolDefinition, StockPoolRegistry + + +# ============================================================================= +# 股票池大小配置 +# ============================================================================= +SMALL_CAP_N = 1000 # 小微盘股票池默认选取的股票数量 + + +# ============================================================================= +# 股票池筛选函数(与 experiment/common.py 风格一致) +# ============================================================================= +def all_market_filter(df: pl.DataFrame) -> pl.Series: + """全市场股票池。""" + return pl.Series([True] * len(df)) + + +def growth_board_filter(df: pl.DataFrame) -> pl.Series: + """创业板股票池(代码以 300 开头)。""" + return df["ts_code"].str.starts_with("300") + + +def star_board_filter(df: pl.DataFrame) -> pl.Series: + """科创板股票池(代码以 688 开头)。""" + return df["ts_code"].str.starts_with("688") + + +def bse_board_filter(df: pl.DataFrame) -> pl.Series: + """北交所股票池(代码以 8 或 4 开头)。""" + return df["ts_code"].str.starts_with("8") | df["ts_code"].str.starts_with("4") + + 
+def main_board_filter(df: pl.DataFrame) -> pl.Series: + """主板股票池(排除创业板、科创板、北交所)。""" + return ( + ~df["ts_code"].str.starts_with("300") + & ~df["ts_code"].str.starts_with("688") + & ~df["ts_code"].str.starts_with("8") + & ~df["ts_code"].str.starts_with("4") + ) + + +def small_cap_filter(df: pl.DataFrame, n_stocks: int = SMALL_CAP_N) -> pl.Series: + """小微盘股票池(市值最小的 n_stocks 只股票)。 + + Args: + df: 包含 ts_code 和 circ_mv 列的数据框。 + n_stocks: 选取的股票数量,默认 SMALL_CAP_N。 + + Returns: + 布尔 Series,表示是否入选小微盘股票池。 + """ + if "circ_mv" not in df.columns: + # 若缺失 circ_mv 数据则全部排除(安全降级) + return pl.Series([False] * len(df)) + n = min(n_stocks, len(df)) + small_codes = df.sort("circ_mv").head(n)["ts_code"] + return df["ts_code"].is_in(small_codes.implode()) + + +def main_board_small_cap_filter( + df: pl.DataFrame, n_stocks: int = SMALL_CAP_N +) -> pl.Series: + """主板小微盘股票池(主板中市值最小的 n_stocks 只股票)。""" + main_board = main_board_filter(df) + main_df = df.filter(main_board) + if "circ_mv" not in df.columns or len(main_df) == 0: + return pl.Series([False] * len(df)) + n = min(n_stocks, len(main_df)) + small_codes = main_df.sort("circ_mv").head(n)["ts_code"] + return df["ts_code"].is_in(small_codes.implode()) + + +def common_small_cap_filter(df: pl.DataFrame, n_stocks: int = SMALL_CAP_N) -> pl.Series: + """与 experiment/common.py 完全一致的小微盘筛选函数。 + + 筛选条件: + 1. 排除创业板(代码以 300 开头) + 2. 排除科创板(代码以 688 开头) + 3. 排除北交所(代码以 8、9 或 4 开头) + 4. 
选取当日流通市值最小的 n_stocks 只股票 + + Args: + df: 数据框,必须包含 ts_code 和 circ_mv 列 + n_stocks: 选取的股票数量,默认 SMALL_CAP_N + + Returns: + 布尔 Series,表示哪些股票被选中 + """ + code_filter = ( + ~df["ts_code"].str.starts_with("30") # 排除创业板 + & ~df["ts_code"].str.starts_with("68") # 排除科创板 + & ~df["ts_code"].str.starts_with("8") # 排除北交所 + & ~df["ts_code"].str.starts_with("9") # 排除北交所 + & ~df["ts_code"].str.starts_with("4") # 排除北交所 + ) + valid_df = df.filter(code_filter) + n = min(n_stocks, len(valid_df)) + small_codes = valid_df.sort("circ_mv").head(n)["ts_code"] + return df["ts_code"].is_in(small_codes.implode()) + + +# ============================================================================= +# 默认股票池定义列表 +# ============================================================================= +def get_default_stock_pool_definitions() -> List[StockPoolDefinition]: + """返回默认的股票池定义列表。""" + return [ + StockPoolDefinition( + name="all", + filter_func=all_market_filter, + required_columns=[], + description="全市场", + ), + StockPoolDefinition( + name="growth_board", + filter_func=growth_board_filter, + required_columns=[], + description="创业板", + ), + StockPoolDefinition( + name="star_board", + filter_func=star_board_filter, + required_columns=[], + description="科创板", + ), + StockPoolDefinition( + name="bse_board", + filter_func=bse_board_filter, + required_columns=[], + description="北交所", + ), + StockPoolDefinition( + name="main_board", + filter_func=main_board_filter, + required_columns=[], + description="主板", + ), + StockPoolDefinition( + name="small_cap", + filter_func=small_cap_filter, + required_columns=["circ_mv"], + description="小微盘", + ), + StockPoolDefinition( + name="main_board_small_cap", + filter_func=main_board_small_cap_filter, + required_columns=["circ_mv"], + description="主板小微盘", + ), + ] + + +def get_default_stock_pools( + pre_filters: List[Callable[[pl.DataFrame], pl.Series | pl.Expr]] | None = None, +) -> StockPoolRegistry: + """返回默认的股票池注册表。""" + return StockPoolRegistry.from_definitions( 
+ get_default_stock_pool_definitions(), + pre_filters=pre_filters, + ) diff --git a/src/factorminer/tests/test_evaluation.py b/src/factorminer/tests/test_evaluation.py index 494e33a..4bffb6a 100644 --- a/src/factorminer/tests/test_evaluation.py +++ b/src/factorminer/tests/test_evaluation.py @@ -21,6 +21,7 @@ from src.factorminer.evaluation.metrics import ( # Helpers # --------------------------------------------------------------------------- + @pytest.fixture def rng(): return np.random.default_rng(123) @@ -58,6 +59,7 @@ def known_quintile_signal(rng): # IC computation # --------------------------------------------------------------------------- + class TestIC: """Test Information Coefficient computation.""" @@ -106,6 +108,7 @@ class TestIC: # ICIR computation # --------------------------------------------------------------------------- + class TestICIR: """Test ICIR = mean(IC) / std(IC).""" @@ -142,6 +145,7 @@ class TestICIR: # IC-derived statistics # --------------------------------------------------------------------------- + class TestICStats: """Test IC mean and win rate.""" @@ -170,6 +174,7 @@ class TestICStats: # Pairwise correlation # --------------------------------------------------------------------------- + class TestPairwiseCorrelation: """Test pairwise cross-sectional correlation.""" @@ -207,6 +212,7 @@ class TestPairwiseCorrelation: # Quintile returns # --------------------------------------------------------------------------- + class TestQuintileReturns: """Test quintile return computation.""" @@ -243,6 +249,7 @@ class TestQuintileReturns: # Turnover # --------------------------------------------------------------------------- + class TestTurnover: """Test portfolio turnover computation.""" @@ -263,6 +270,7 @@ class TestTurnover: # Comprehensive factor stats # --------------------------------------------------------------------------- + class TestFactorStats: """Test the compute_factor_stats wrapper.""" @@ -285,3 +293,18 @@ class TestFactorStats: 
returns = rng.normal(0, 0.01, (M, T)) stats = compute_factor_stats(signals, returns) assert stats["ic_series"].shape == (T,) + + +class TestFactorStatsPools: + """Test compute_factor_stats under multi-stock-pool NaN scenarios.""" + + def test_pool_with_nans(self, rng): + M, T = 40, 30 + signals = rng.normal(0, 1, (M, T)) + returns = rng.normal(0, 0.01, (M, T)) + # 模拟子池:部分资产在所有时间上为 NaN + signals[30:, :] = np.nan + returns[30:, :] = np.nan + stats = compute_factor_stats(signals, returns) + assert stats["n_periods"] == T + assert np.isfinite(stats["ic_mean"]) diff --git a/src/factorminer/tests/test_stock_pool.py b/src/factorminer/tests/test_stock_pool.py new file mode 100644 index 0000000..428fa13 --- /dev/null +++ b/src/factorminer/tests/test_stock_pool.py @@ -0,0 +1,72 @@ +import numpy as np +import polars as pl +import pytest + +from src.factorminer.evaluation.stock_pool_registry import StockPoolRegistry + + +class TestStockPoolRegistry: + def test_add_and_get_pool_names(self): + registry = StockPoolRegistry() + registry.add_pool("all", lambda df: pl.Series([True] * len(df))) + registry.add_pool("growth", lambda df: df["ts_code"].str.starts_with("300")) + assert registry.get_pool_names() == ["all", "growth"] + + def test_build_masks(self): + registry = StockPoolRegistry() + registry.add_pool("all", lambda df: pl.Series([True] * len(df))) + registry.add_pool("growth", lambda df: df["ts_code"].str.starts_with("300")) + asset_codes = ["000001.SZ", "300001.SZ", "688001.SH"] + masks = registry.build_masks(asset_codes) + assert masks["all"].sum() == 3 + assert masks["growth"].sum() == 1 + assert masks["growth"][1] + + def test_filter_signals(self): + registry = StockPoolRegistry() + registry.add_pool("growth", lambda df: df["ts_code"].str.starts_with("300")) + asset_codes = ["000001.SZ", "300001.SZ", "688001.SH"] + registry.build_masks(asset_codes) + signals = np.array([[1, 2], [3, 4], [5, 6]], dtype=np.float64) + filtered = registry.filter_signals(signals, "growth") + 
assert filtered.shape == (1, 2) + np.testing.assert_array_equal(filtered, [[3, 4]]) + + def test_build_masks_with_metadata(self): + registry = StockPoolRegistry() + registry.add_pool( + "small_cap", + lambda df: df["ts_code"].is_in( + df.sort("total_mv").head(2)["ts_code"].implode() + ), + required_columns=["total_mv"], + ) + asset_codes = ["A", "B", "C"] + metadata = pl.DataFrame( + { + "ts_code": ["A", "B", "C"], + "total_mv": [300.0, 100.0, 200.0], + } + ) + masks = registry.build_masks(asset_codes, metadata_df=metadata) + assert masks["small_cap"].sum() == 2 + assert masks["small_cap"][1] and masks["small_cap"][2] + + def test_expr_fallback(self): + registry = StockPoolRegistry() + registry.add_pool( + "expr_pool", lambda df: pl.col("ts_code").str.starts_with("300") + ) + asset_codes = ["000001.SZ", "300001.SZ"] + masks = registry.build_masks(asset_codes) + assert masks["expr_pool"].sum() == 1 + + def test_pre_filters(self): + registry = StockPoolRegistry( + pre_filters=[lambda df: df["ts_code"].str.starts_with("0")] + ) + registry.add_pool("all", lambda df: pl.Series([True] * len(df))) + asset_codes = ["000001.SZ", "300001.SZ", "688001.SH"] + masks = registry.build_masks(asset_codes) + assert masks["all"].sum() == 1 + assert masks["all"][0]