feat(factorminer): 支持多股票池因子评估
- 新增 StockPoolRegistry 与默认股票池定义(all/small_cap 等) - RalphLoop 与 ValidationPipeline 支持按多股票池计算 IC - Factor 序列化新增 pool_metrics 字段 - LocalFactorEvaluator 增加 evaluate_returns_by_pool 方法 - 主入口集成股票池配置与 ST 预过滤
This commit is contained in:
1165
docs/plans/2026-04-11-factorminer-multi-stock-pool.md
Normal file
1165
docs/plans/2026-04-11-factorminer-multi-stock-pool.md
Normal file
File diff suppressed because it is too large
Load Diff
@@ -42,6 +42,7 @@ class Factor:
|
|||||||
batch_number: int # Which mining batch admitted this factor
|
batch_number: int # Which mining batch admitted this factor
|
||||||
admission_date: str = ""
|
admission_date: str = ""
|
||||||
signals: Optional[np.ndarray] = field(default=None, repr=False) # (M, T)
|
signals: Optional[np.ndarray] = field(default=None, repr=False) # (M, T)
|
||||||
|
pool_metrics: Dict[str, dict] = field(default_factory=dict)
|
||||||
research_metrics: dict = field(default_factory=dict)
|
research_metrics: dict = field(default_factory=dict)
|
||||||
provenance: dict = field(default_factory=dict)
|
provenance: dict = field(default_factory=dict)
|
||||||
metadata: dict = field(default_factory=dict)
|
metadata: dict = field(default_factory=dict)
|
||||||
@@ -50,6 +51,20 @@ class Factor:
|
|||||||
if not self.admission_date:
|
if not self.admission_date:
|
||||||
self.admission_date = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
self.admission_date = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _sanitize_pool_metrics(pool_metrics: Dict[str, dict]) -> Dict[str, dict]:
|
||||||
|
sanitized: Dict[str, dict] = {}
|
||||||
|
for pool_name, stats in pool_metrics.items():
|
||||||
|
sanitized[pool_name] = {}
|
||||||
|
for key, value in stats.items():
|
||||||
|
if key == "ic_series":
|
||||||
|
continue
|
||||||
|
if isinstance(value, np.ndarray):
|
||||||
|
sanitized[pool_name][key] = value.tolist()
|
||||||
|
else:
|
||||||
|
sanitized[pool_name][key] = value
|
||||||
|
return sanitized
|
||||||
|
|
||||||
def to_dict(self) -> dict:
|
def to_dict(self) -> dict:
|
||||||
"""Serialize to a JSON-compatible dictionary (excludes signals)."""
|
"""Serialize to a JSON-compatible dictionary (excludes signals)."""
|
||||||
return {
|
return {
|
||||||
@@ -63,6 +78,7 @@ class Factor:
|
|||||||
"max_correlation": self.max_correlation,
|
"max_correlation": self.max_correlation,
|
||||||
"batch_number": self.batch_number,
|
"batch_number": self.batch_number,
|
||||||
"admission_date": self.admission_date,
|
"admission_date": self.admission_date,
|
||||||
|
"pool_metrics": self._sanitize_pool_metrics(self.pool_metrics),
|
||||||
"research_metrics": self.research_metrics,
|
"research_metrics": self.research_metrics,
|
||||||
"provenance": self.provenance,
|
"provenance": self.provenance,
|
||||||
"metadata": self.metadata,
|
"metadata": self.metadata,
|
||||||
@@ -82,6 +98,7 @@ class Factor:
|
|||||||
max_correlation=d["max_correlation"],
|
max_correlation=d["max_correlation"],
|
||||||
batch_number=d["batch_number"],
|
batch_number=d["batch_number"],
|
||||||
admission_date=d.get("admission_date", ""),
|
admission_date=d.get("admission_date", ""),
|
||||||
|
pool_metrics=d.get("pool_metrics", {}),
|
||||||
research_metrics=d.get("research_metrics", {}),
|
research_metrics=d.get("research_metrics", {}),
|
||||||
provenance=d.get("provenance", {}),
|
provenance=d.get("provenance", {}),
|
||||||
metadata=d.get("metadata", {}),
|
metadata=d.get("metadata", {}),
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ import logging
|
|||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, List, Optional, Set, Tuple
|
from typing import Any, Dict, List, Optional, Set, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@@ -229,7 +229,7 @@ class HelixLoop(RalphLoop):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: Any,
|
config: Any,
|
||||||
returns: np.ndarray,
|
returns: Union[np.ndarray, Dict[str, np.ndarray]],
|
||||||
llm_provider: Optional[LLMProvider] = None,
|
llm_provider: Optional[LLMProvider] = None,
|
||||||
memory: Optional[ExperienceMemory] = None,
|
memory: Optional[ExperienceMemory] = None,
|
||||||
library: Optional[FactorLibrary] = None,
|
library: Optional[FactorLibrary] = None,
|
||||||
|
|||||||
@@ -257,9 +257,13 @@ class ValidationPipeline:
|
|||||||
evaluator: Any = None,
|
evaluator: Any = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.data_tensor = data_tensor # (M, T, F)
|
self.data_tensor = data_tensor # (M, T, F)
|
||||||
|
if isinstance(returns, dict):
|
||||||
|
self.returns = returns.get("all", next(iter(returns.values())))
|
||||||
|
self.target_panels = returns
|
||||||
|
else:
|
||||||
self.returns = returns # (M, T)
|
self.returns = returns # (M, T)
|
||||||
self.evaluator = evaluator
|
|
||||||
self.target_panels = target_panels or {"paper": returns}
|
self.target_panels = target_panels or {"paper": returns}
|
||||||
|
self.evaluator = evaluator
|
||||||
self.target_horizons = target_horizons or {"paper": 1}
|
self.target_horizons = target_horizons or {"paper": 1}
|
||||||
self.library = library or FactorLibrary(
|
self.library = library or FactorLibrary(
|
||||||
correlation_threshold=0.5,
|
correlation_threshold=0.5,
|
||||||
@@ -353,20 +357,18 @@ class ValidationPipeline:
|
|||||||
result.stage_passed = 0
|
result.stage_passed = 0
|
||||||
return result
|
return result
|
||||||
|
|
||||||
# Full IC statistics on all assets
|
# Full IC statistics on all target panels (stock pools)
|
||||||
stats = compute_factor_stats(signals, self.returns)
|
all_stats: Dict[str, dict] = {}
|
||||||
result.ic_mean = stats["ic_abs_mean"]
|
for panel_name, panel_returns in self.target_panels.items():
|
||||||
result.icir = stats["icir"]
|
all_stats[panel_name] = compute_factor_stats(signals, panel_returns)
|
||||||
result.ic_win_rate = stats["ic_win_rate"]
|
|
||||||
result.target_stats = {"paper": stats}
|
|
||||||
|
|
||||||
if self.target_panels:
|
best_panel_name = max(all_stats, key=lambda k: all_stats[k]["ic_abs_mean"])
|
||||||
for target_name, target_returns in self.target_panels.items():
|
best_stats = all_stats[best_panel_name]
|
||||||
if target_name == "paper":
|
|
||||||
continue
|
result.ic_mean = best_stats["ic_abs_mean"]
|
||||||
result.target_stats[target_name] = compute_factor_stats(
|
result.icir = best_stats["icir"]
|
||||||
signals, target_returns
|
result.ic_win_rate = best_stats["ic_win_rate"]
|
||||||
)
|
result.target_stats = all_stats
|
||||||
|
|
||||||
score_vector_obj = None
|
score_vector_obj = None
|
||||||
if self._research_enabled():
|
if self._research_enabled():
|
||||||
@@ -737,7 +739,7 @@ class RalphLoop:
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: Any,
|
config: Any,
|
||||||
returns: np.ndarray,
|
returns: Union[np.ndarray, Dict[str, np.ndarray]],
|
||||||
llm_provider: Optional[LLMProvider] = None,
|
llm_provider: Optional[LLMProvider] = None,
|
||||||
memory: Optional[ExperienceMemory] = None,
|
memory: Optional[ExperienceMemory] = None,
|
||||||
library: Optional[FactorLibrary] = None,
|
library: Optional[FactorLibrary] = None,
|
||||||
@@ -767,7 +769,12 @@ class RalphLoop:
|
|||||||
the legacy data_tensor path.
|
the legacy data_tensor path.
|
||||||
"""
|
"""
|
||||||
self.config = config
|
self.config = config
|
||||||
|
if isinstance(returns, dict):
|
||||||
|
self.returns = returns.get("all", next(iter(returns.values())))
|
||||||
|
self.target_panels = returns
|
||||||
|
else:
|
||||||
self.returns = returns
|
self.returns = returns
|
||||||
|
self.target_panels = getattr(config, "target_panels", None)
|
||||||
self.evaluator = evaluator
|
self.evaluator = evaluator
|
||||||
self.checkpoint_interval = checkpoint_interval
|
self.checkpoint_interval = checkpoint_interval
|
||||||
|
|
||||||
@@ -783,9 +790,9 @@ class RalphLoop:
|
|||||||
prompt_builder=PromptBuilder(),
|
prompt_builder=PromptBuilder(),
|
||||||
)
|
)
|
||||||
self.pipeline = ValidationPipeline(
|
self.pipeline = ValidationPipeline(
|
||||||
data_tensor=np.empty((returns.shape[0], returns.shape[1], 1)),
|
data_tensor=np.empty((self.returns.shape[0], self.returns.shape[1], 1)),
|
||||||
returns=returns,
|
returns=self.returns,
|
||||||
target_panels=getattr(config, "target_panels", None),
|
target_panels=self.target_panels,
|
||||||
target_horizons=getattr(config, "target_horizons", None),
|
target_horizons=getattr(config, "target_horizons", None),
|
||||||
library=self.library,
|
library=self.library,
|
||||||
ic_threshold=getattr(config, "ic_threshold", 0.04),
|
ic_threshold=getattr(config, "ic_threshold", 0.04),
|
||||||
@@ -1082,6 +1089,14 @@ class RalphLoop:
|
|||||||
# Handle replacement
|
# Handle replacement
|
||||||
if result.replaced is not None:
|
if result.replaced is not None:
|
||||||
old_id = result.replaced
|
old_id = result.replaced
|
||||||
|
if result.target_stats:
|
||||||
|
target_stats_clean = {}
|
||||||
|
for pool_name, stats in result.target_stats.items():
|
||||||
|
target_stats_clean[pool_name] = {
|
||||||
|
k: v for k, v in stats.items() if k != "ic_series"
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
target_stats_clean = {}
|
||||||
new_factor = Factor(
|
new_factor = Factor(
|
||||||
id=0, # Will be reassigned by library
|
id=0, # Will be reassigned by library
|
||||||
name=result.factor_name,
|
name=result.factor_name,
|
||||||
@@ -1093,6 +1108,7 @@ class RalphLoop:
|
|||||||
max_correlation=result.max_correlation,
|
max_correlation=result.max_correlation,
|
||||||
batch_number=self.iteration,
|
batch_number=self.iteration,
|
||||||
signals=result.signals,
|
signals=result.signals,
|
||||||
|
pool_metrics=target_stats_clean,
|
||||||
research_metrics=result.score_vector or {},
|
research_metrics=result.score_vector or {},
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
@@ -1110,6 +1126,14 @@ class RalphLoop:
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Direct admission
|
# Direct admission
|
||||||
|
if result.target_stats:
|
||||||
|
target_stats_clean = {}
|
||||||
|
for pool_name, stats in result.target_stats.items():
|
||||||
|
target_stats_clean[pool_name] = {
|
||||||
|
k: v for k, v in stats.items() if k != "ic_series"
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
target_stats_clean = {}
|
||||||
factor = Factor(
|
factor = Factor(
|
||||||
id=0, # Will be reassigned
|
id=0, # Will be reassigned
|
||||||
name=result.factor_name,
|
name=result.factor_name,
|
||||||
@@ -1121,6 +1145,7 @@ class RalphLoop:
|
|||||||
max_correlation=result.max_correlation,
|
max_correlation=result.max_correlation,
|
||||||
batch_number=self.iteration,
|
batch_number=self.iteration,
|
||||||
signals=result.signals,
|
signals=result.signals,
|
||||||
|
pool_metrics=target_stats_clean,
|
||||||
research_metrics=result.score_vector or {},
|
research_metrics=result.score_vector or {},
|
||||||
)
|
)
|
||||||
self.library.admit_factor(factor)
|
self.library.admit_factor(factor)
|
||||||
@@ -1435,7 +1460,7 @@ class RalphLoop:
|
|||||||
cls,
|
cls,
|
||||||
checkpoint_path: str,
|
checkpoint_path: str,
|
||||||
config: Any,
|
config: Any,
|
||||||
returns: np.ndarray,
|
returns: Union[np.ndarray, Dict[str, np.ndarray]],
|
||||||
llm_provider: Optional[LLMProvider] = None,
|
llm_provider: Optional[LLMProvider] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> "RalphLoop":
|
) -> "RalphLoop":
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ import numpy as np
|
|||||||
import polars as pl
|
import polars as pl
|
||||||
|
|
||||||
from src.factors import FactorEngine
|
from src.factors import FactorEngine
|
||||||
|
from src.factorminer.evaluation.stock_pool_registry import StockPoolRegistry
|
||||||
|
|
||||||
|
|
||||||
class LocalFactorEvaluator:
|
class LocalFactorEvaluator:
|
||||||
@@ -40,6 +41,7 @@ class LocalFactorEvaluator:
|
|||||||
start_date: str,
|
start_date: str,
|
||||||
end_date: str,
|
end_date: str,
|
||||||
stock_codes: Optional[List[str]] = None,
|
stock_codes: Optional[List[str]] = None,
|
||||||
|
stock_pool_registry: Optional[StockPoolRegistry] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""初始化评估器。
|
"""初始化评估器。
|
||||||
|
|
||||||
@@ -47,11 +49,14 @@ class LocalFactorEvaluator:
|
|||||||
start_date: 计算开始日期,YYYYMMDD 格式
|
start_date: 计算开始日期,YYYYMMDD 格式
|
||||||
end_date: 计算结束日期,YYYYMMDD 格式
|
end_date: 计算结束日期,YYYYMMDD 格式
|
||||||
stock_codes: 可选的股票代码列表,None 表示全量
|
stock_codes: 可选的股票代码列表,None 表示全量
|
||||||
|
stock_pool_registry: 股票池注册表,用于多股票池评估
|
||||||
"""
|
"""
|
||||||
self.start_date = start_date
|
self.start_date = start_date
|
||||||
self.end_date = end_date
|
self.end_date = end_date
|
||||||
self.stock_codes = stock_codes
|
self.stock_codes = stock_codes
|
||||||
self.engine = FactorEngine()
|
self.engine = FactorEngine()
|
||||||
|
self.stock_pool_registry = stock_pool_registry
|
||||||
|
self._asset_codes: Optional[List[str]] = None
|
||||||
|
|
||||||
def evaluate(
|
def evaluate(
|
||||||
self,
|
self,
|
||||||
@@ -166,6 +171,76 @@ class LocalFactorEvaluator:
|
|||||||
print(f"[ERROR] 计算收益率矩阵失败: {e}")
|
print(f"[ERROR] 计算收益率矩阵失败: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
def get_asset_codes(self) -> List[str]:
    """Return the asset code list cached on this evaluator.

    Returns:
        The list stored in ``self._asset_codes`` (one stock code per matrix
        row).

    Raises:
        RuntimeError: If no asset codes have been cached yet, i.e. neither
            ``evaluate()`` nor ``evaluate_returns()`` has been run.
    """
    # NOTE(review): the cache appears to be filled only once (while it is
    # still None) -- confirm it cannot go stale across evaluations.
    cached = self._asset_codes
    if cached is not None:
        return cached
    raise RuntimeError("请先调用 evaluate() 或 evaluate_returns()")
|
||||||
|
|
||||||
|
def _get_metadata_df(self, columns: List[str]) -> Optional[pl.DataFrame]:
    """Load the metadata needed to build stock pools (``end_date`` snapshot).

    The ``daily_basic`` table is queried for the single day ``self.end_date``
    and reduced to one row per ``ts_code``.

    Args:
        columns: Extra column names required by the pool filters.

    Returns:
        A DataFrame with ``ts_code`` followed by ``columns``, or ``None``
        when ``columns`` is empty or loading fails (a warning is printed;
        this is a deliberate best-effort fallback).
    """
    if not columns:
        return None
    try:
        # NOTE(review): start and end are both ``end_date``; if that day has
        # no rows (e.g. a non-trading day) the result is presumably empty and
        # every pool that needs metadata ends up empty -- confirm.
        df = self.engine.router._load_table(
            table_name="daily_basic",
            columns=columns,
            start_date=self.end_date,
            end_date=self.end_date,
            stock_codes=self.stock_codes,
        )
        if "trade_date" in df.columns:
            df = df.sort("trade_date", descending=True)
        # keep="first" makes "newest row per stock" explicit; the default
        # keep="any" gives no guarantee about which duplicate survives.
        df = df.unique(subset=["ts_code"], keep="first", maintain_order=True)
        # Avoid selecting "ts_code" twice if the caller listed it as well.
        wanted = ["ts_code"] + [col for col in columns if col != "ts_code"]
        return df.select(wanted)
    except Exception as exc:
        print(f"[WARN] 拉取股票池元数据失败: {exc}")
        return None
|
||||||
|
|
||||||
|
def evaluate_returns_by_pool(self, periods: int = 1) -> Dict[str, np.ndarray]:
    """Compute one forward-return matrix per registered stock pool.

    The full-universe return matrix is computed first; every registered pool
    then gets a copy in which the rows of assets outside the pool are NaN for
    all timestamps.

    Args:
        periods: Forward horizon passed to ``evaluate_returns``; default 1.

    Returns:
        ``{pool_name: (M, T) ndarray}``. The key ``"all"`` always maps to the
        unmasked full-universe matrix; without a registry it is the only key.
    """
    # NOTE(review): a registered pool named "all" is skipped below, so the
    # "all" panel is not subject to the registry's pre-filters -- confirm
    # this is intended.
    # NOTE(review): masks come from an ``end_date`` metadata snapshot and are
    # applied to every timestamp -- confirm the look-ahead is acceptable.
    full_panel = self.evaluate_returns(periods=periods)
    panels: Dict[str, np.ndarray] = {"all": full_panel}

    if self.stock_pool_registry is None:
        return panels

    asset_codes = self.get_asset_codes()
    needed_columns = self.stock_pool_registry.get_required_columns()
    metadata = self._get_metadata_df(needed_columns)
    self.stock_pool_registry.build_masks(asset_codes, metadata_df=metadata)

    for pool_name in self.stock_pool_registry.get_pool_names():
        if pool_name != "all":
            members = self.stock_pool_registry.masks[pool_name]
            # Start from an all-NaN float64 panel and copy member rows in.
            masked = np.full_like(full_panel, np.nan, dtype=np.float64)
            masked[members, :] = full_panel[members, :]
            panels[pool_name] = masked

    return panels
|
||||||
|
|
||||||
def _pivot_to_matrix(
|
def _pivot_to_matrix(
|
||||||
self,
|
self,
|
||||||
df: pl.DataFrame,
|
df: pl.DataFrame,
|
||||||
@@ -191,6 +266,8 @@ class LocalFactorEvaluator:
|
|||||||
# 获取时间戳和股票代码的唯一值(已排序)
|
# 获取时间戳和股票代码的唯一值(已排序)
|
||||||
timestamps = df["trade_date"].unique().sort()
|
timestamps = df["trade_date"].unique().sort()
|
||||||
asset_codes = df["ts_code"].unique().sort()
|
asset_codes = df["ts_code"].unique().sort()
|
||||||
|
if self._asset_codes is None:
|
||||||
|
self._asset_codes = asset_codes.to_list()
|
||||||
|
|
||||||
n_assets = len(asset_codes)
|
n_assets = len(asset_codes)
|
||||||
n_times = len(timestamps)
|
n_times = len(timestamps)
|
||||||
|
|||||||
150
src/factorminer/evaluation/stock_pool_registry.py
Normal file
150
src/factorminer/evaluation/stock_pool_registry.py
Normal file
@@ -0,0 +1,150 @@
|
|||||||
|
"""股票池注册表:支持配置化股票池筛选与掩码生成。"""
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Callable, Dict, List, Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import polars as pl
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class StockPoolDefinition:
    """Declarative description of one stock pool.

    Follows the filter pattern referenced from ``experiment/common.py``.
    """

    # Pool name, e.g. "small_cap".
    name: str
    # Callable that receives a DataFrame and returns a ``pl.Series`` or
    # ``pl.Expr`` marking the rows that belong to the pool.
    filter_func: Callable[[pl.DataFrame], pl.Series | pl.Expr]
    # Extra column names ``filter_func`` needs beyond ``ts_code``.
    required_columns: List[str] = field(default_factory=list)
    # Free-text, human-readable description of the pool.
    description: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
class StockPoolRegistry:
    """Hold named stock-pool filters and turn them into boolean asset masks.

    Every pool is a callable that receives a polars DataFrame (one row per
    asset: ``ts_code`` plus any joined metadata columns) and returns a
    ``pl.Series`` or ``pl.Expr``. ``build_masks`` evaluates all pools against
    an ordered asset list and stores one ``(M,)`` numpy bool array per pool in
    ``self.masks``; the optional pre-filters are AND-ed into every pool mask.
    """

    def __init__(
        self,
        definitions: Optional[List[StockPoolDefinition]] = None,
        pre_filters: Optional[
            List[Callable[[pl.DataFrame], pl.Series | pl.Expr]]
        ] = None,
    ) -> None:
        """Create a registry.

        Args:
            definitions: Pool definitions to register immediately.
            pre_filters: Filters applied to every pool before its own filter.
        """
        self.pre_filters: List[Callable[[pl.DataFrame], pl.Series | pl.Expr]] = (
            pre_filters or []
        )
        self.pools: Dict[str, dict] = {}
        self.masks: Dict[str, np.ndarray] = {}
        self._resolved = False
        if definitions:
            for d in definitions:
                self.add_pool(d.name, d.filter_func, d.required_columns)

    @classmethod
    def from_definitions(
        cls,
        definitions: List[StockPoolDefinition],
        pre_filters: Optional[
            List[Callable[[pl.DataFrame], pl.Series | pl.Expr]]
        ] = None,
    ) -> "StockPoolRegistry":
        """Build a registry directly from a list of pool definitions."""
        return cls(definitions=definitions, pre_filters=pre_filters)

    def add_pool(
        self,
        name: str,
        filter_func: Callable[[pl.DataFrame], pl.Series | pl.Expr],
        required_columns: Optional[List[str]] = None,
    ) -> None:
        """Register (or overwrite) a stock pool.

        Args:
            name: Pool name, e.g. ``"small_cap"``.
            filter_func: Receives a DataFrame containing ``ts_code`` and the
                required columns; returns a boolean Series or Expr.
            required_columns: Extra column names ``filter_func`` needs.
        """
        self.pools[name] = {
            "filter_func": filter_func,
            "required_columns": required_columns or [],
        }
        # Previously built masks no longer reflect the registered pools.
        self._resolved = False

    @staticmethod
    def _eval_filter_result(
        df: pl.DataFrame,
        result: pl.Series | pl.Expr,
        context: str,
    ) -> np.ndarray:
        """Normalize a filter's return value into a numpy bool array.

        Args:
            df: The frame the filter was evaluated on.
            result: The filter's return value.
            context: Label used in error messages.

        Returns:
            A bool array with exactly ``df.height`` entries.

        Raises:
            TypeError: If ``result`` is neither a Series nor an Expr.
            ValueError: If the mask length does not match ``df`` and is not a
                single broadcastable value.
        """
        if isinstance(result, pl.Series):
            mask_series = result
        elif isinstance(result, pl.Expr):
            mask_series = df.select(result.alias("_mask")).to_series()
        else:
            raise TypeError(
                f"{context} 的 filter_func 必须返回 pl.Series 或 pl.Expr,"
                f"实际返回 {type(result)}"
            )
        if mask_series.dtype == pl.Boolean:
            # A null answer means "unknown"; treat it as "not in the pool".
            mask_series = mask_series.fill_null(False)
        mask = mask_series.to_numpy().astype(bool)

        n_rows = df.height
        if mask.shape[0] == n_rows:
            return mask
        if mask.shape[0] == 1:
            # Scalar results (e.g. pl.lit(True)) apply to every asset.
            return np.full(n_rows, bool(mask[0]), dtype=bool)
        raise ValueError(
            f"{context} 的 filter_func 返回的掩码长度为 {mask.shape[0]},"
            f"期望 {n_rows}"
        )

    def build_masks(
        self,
        asset_codes: List[str],
        metadata_df: Optional[pl.DataFrame] = None,
    ) -> Dict[str, np.ndarray]:
        """Build one boolean mask per registered pool.

        Args:
            asset_codes: Stock codes in matrix row order, shape ``(M,)``.
            metadata_df: Optional frame with ``ts_code`` and extra columns.

        Returns:
            ``{pool_name: bool array of shape (M,)}``; also kept in
            ``self.masks``.
        """
        n_assets = len(asset_codes)
        base_df = pl.DataFrame(
            {"ts_code": pl.Series("ts_code", list(asset_codes), dtype=pl.Utf8)}
        )
        df = base_df
        if metadata_df is not None:
            order_col = "__asset_order__"
            # One metadata row per stock; duplicates would multiply rows in
            # the join and break the 1:1 mapping between rows and assets.
            meta = metadata_df.unique(
                subset=["ts_code"], keep="first", maintain_order=True
            )
            # Masks are positional, so restore the caller's order explicitly
            # instead of relying on the join's output order.
            df = (
                base_df.with_columns(pl.Series(order_col, np.arange(n_assets)))
                .join(meta, on="ts_code", how="left")
                .sort(order_col)
                .drop(order_col)
            )

        base_mask = np.ones(n_assets, dtype=bool)
        for pre_filter in self.pre_filters:
            result = pre_filter(df)
            base_mask &= self._eval_filter_result(df, result, "pre_filter")

        self.masks = {}
        for name, cfg in self.pools.items():
            result = cfg["filter_func"](df)
            pool_mask = self._eval_filter_result(df, result, f"股票池 '{name}'")
            self.masks[name] = base_mask & pool_mask

        self._resolved = True
        return self.masks

    def filter_signals(self, signals: np.ndarray, pool_name: str) -> np.ndarray:
        """Select the rows of a signal matrix that belong to a pool.

        Args:
            signals: ``(M, T)`` signal matrix.
            pool_name: Name of a pool whose mask has been built.

        Returns:
            The ``(M_pool, T)`` sub-matrix.

        Raises:
            RuntimeError: If ``build_masks`` has not been called since the
                last pool registration.
            KeyError: If ``pool_name`` has no mask.
        """
        if not self._resolved:
            raise RuntimeError("请先调用 build_masks 构建掩码")
        mask = self.masks.get(pool_name)
        if mask is None:
            raise KeyError(f"未知的股票池: {pool_name}")
        return signals[mask, :]

    def get_pool_names(self) -> List[str]:
        """Return the registered pool names in registration order."""
        return list(self.pools.keys())

    def get_required_columns(self) -> List[str]:
        """Return the sorted union of columns needed by all filters.

        Pre-filters contribute through an optional ``required_columns``
        attribute on the callable; pools through their registered list.
        """
        cols: set[str] = set()
        for pre_filter in self.pre_filters:
            cols.update(getattr(pre_filter, "required_columns", []))
        for cfg in self.pools.values():
            cols.update(cfg["required_columns"])
        return sorted(cols)
|
||||||
@@ -11,6 +11,8 @@ from typing import Any, Optional
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
import polars as pl
|
||||||
|
|
||||||
from src.config.settings import get_settings
|
from src.config.settings import get_settings
|
||||||
from src.factorminer.agent.llm_interface import create_provider, AnthropicProvider
|
from src.factorminer.agent.llm_interface import create_provider, AnthropicProvider
|
||||||
from src.factorminer.core.config import MiningConfig as CoreMiningConfig
|
from src.factorminer.core.config import MiningConfig as CoreMiningConfig
|
||||||
@@ -19,7 +21,10 @@ from src.factorminer.core.ralph_loop import RalphLoop
|
|||||||
from src.factorminer.core.helix_loop import HelixLoop
|
from src.factorminer.core.helix_loop import HelixLoop
|
||||||
from src.factorminer.evaluation.local_engine import LocalFactorEvaluator
|
from src.factorminer.evaluation.local_engine import LocalFactorEvaluator
|
||||||
from src.factorminer.evaluation.significance import SignificanceConfig
|
from src.factorminer.evaluation.significance import SignificanceConfig
|
||||||
|
from src.factorminer.evaluation.stock_pool_registry import StockPoolRegistry
|
||||||
|
from src.factorminer.pool_definitions import get_default_stock_pool_definitions
|
||||||
from src.factorminer.utils.config import load_config
|
from src.factorminer.utils.config import load_config
|
||||||
|
from src.training.components.filters import STFilter
|
||||||
|
|
||||||
RUN_CONFIG: dict = {
|
RUN_CONFIG: dict = {
|
||||||
# 全局开关
|
# 全局开关
|
||||||
@@ -35,9 +40,16 @@ RUN_CONFIG: dict = {
|
|||||||
"stock_codes": None, # 可选股票列表,None 表示全量
|
"stock_codes": None, # 可选股票列表,None 表示全量
|
||||||
# 种子库
|
# 种子库
|
||||||
"seed_paper_library": True, # 是否预加载 110 Paper Factors 作为种子库
|
"seed_paper_library": True, # 是否预加载 110 Paper Factors 作为种子库
|
||||||
|
# 股票池配置(支持多股票池评估)
|
||||||
|
"stock_pools": {
|
||||||
|
"enabled": True, # 是否启用多股票池
|
||||||
|
"provider": "default", # "default" 使用 pool_definitions.py 中的默认定义
|
||||||
|
"use_st_filter": True, # 是否在构建股票池前过滤 ST 股票
|
||||||
|
"pools": ["all", "small_cap"], # 要启用的股票池名称列表,None 表示全部启用
|
||||||
|
},
|
||||||
# 挖掘配置(覆盖 default.yaml 中的同名项)
|
# 挖掘配置(覆盖 default.yaml 中的同名项)
|
||||||
"mining": {
|
"mining": {
|
||||||
"max_iterations": 3,
|
"max_iterations": 1,
|
||||||
"target_library_size": 10,
|
"target_library_size": 10,
|
||||||
"correlation_threshold": 0.50,
|
"correlation_threshold": 0.50,
|
||||||
"ic_threshold": 0.04,
|
"ic_threshold": 0.04,
|
||||||
@@ -139,6 +151,54 @@ def _build_core_mining_config(run_cfg: dict) -> CoreMiningConfig:
|
|||||||
return cfg
|
return cfg
|
||||||
|
|
||||||
|
|
||||||
|
def _build_st_pre_filter(st_filter: STFilter, end_date: str):
    """Wrap an ``STFilter`` as a pre-filter callable for ``StockPoolRegistry``.

    Args:
        st_filter: Object whose ``filter(df)`` returns the rows to keep.
        end_date: Value used for a synthetic ``trade_date`` column when the
            incoming frame has none.

    Returns:
        A function mapping a DataFrame to a boolean ``pl.Series`` that is
        True for every ``ts_code`` kept by ``st_filter``.
    """

    def _pre_filter(df: pl.DataFrame) -> pl.Series:
        try:
            if "trade_date" not in df.columns:
                df = df.with_columns(pl.lit(end_date).alias("trade_date"))
            filtered = st_filter.filter(df)
            # Pass a plain list: Series-in-Series ``is_in`` is deprecated in
            # recent polars, while a list is accepted by every version.
            kept_codes = filtered["ts_code"].to_list()
            return df["ts_code"].is_in(kept_codes)
        except Exception as exc:
            # Deliberate best-effort: if ST filtering fails, keep every asset
            # rather than abort pool construction.
            print(f"[WARN] ST pre-filter 失败: {exc}")
            return pl.Series([True] * len(df))

    return _pre_filter
|
||||||
|
|
||||||
|
|
||||||
|
def _build_stock_pool_registry(
    run_cfg: dict,
    st_filter: Optional[STFilter] = None,
    end_date: str = "20231231",
) -> Optional[StockPoolRegistry]:
    """Create the stock-pool registry described by ``run_cfg["stock_pools"]``.

    Args:
        run_cfg: Run configuration; the ``"stock_pools"`` section is read.
        st_filter: Optional ST filter, used only when ``use_st_filter`` is on.
        end_date: Date forwarded to the ST pre-filter wrapper.

    Returns:
        A registry, or ``None`` when multi-pool evaluation is disabled.

    Raises:
        ValueError: If the provider is unsupported, or ``pools`` matches none
            of the default definitions.
    """
    pool_cfg = run_cfg.get("stock_pools", {})
    enabled = pool_cfg.get("enabled", False)
    if not enabled:
        return None

    wants_st_filter = pool_cfg.get("use_st_filter", False)
    pre_filters = (
        [_build_st_pre_filter(st_filter, end_date)]
        if wants_st_filter and st_filter is not None
        else []
    )

    provider = pool_cfg.get("provider", "default")
    if provider != "default":
        raise ValueError(f"不支持的股票池 provider: {provider}")

    definitions = get_default_stock_pool_definitions()
    pool_names = pool_cfg.get("pools")
    if pool_names is not None:
        # Keep only the requested pools, preserving the default order.
        definitions = [d for d in definitions if d.name in pool_names]
        if not definitions:
            raise ValueError(
                f"股票池配置错误: pools={pool_names} 未匹配到任何默认股票池定义"
            )
        known = {d.name for d in definitions}
        missing = set(pool_names).difference(known)
        if missing:
            print(f"[WARN] 以下股票池未在默认定义中找到,已忽略: {sorted(missing)}")
    return StockPoolRegistry.from_definitions(definitions, pre_filters=pre_filters)
|
||||||
|
|
||||||
|
|
||||||
def _build_helix_kwargs(run_cfg: dict) -> dict:
|
def _build_helix_kwargs(run_cfg: dict) -> dict:
|
||||||
"""从 RUN_CONFIG 构建 HelixLoop 需要的 Phase 2 扩展配置。"""
|
"""从 RUN_CONFIG 构建 HelixLoop 需要的 Phase 2 扩展配置。"""
|
||||||
helix = run_cfg.get("helix", {})
|
helix = run_cfg.get("helix", {})
|
||||||
@@ -215,14 +275,33 @@ def main(config: dict | None = None) -> None:
|
|||||||
end_date = run_cfg.get("end_date", "20201231")
|
end_date = run_cfg.get("end_date", "20201231")
|
||||||
stock_codes = run_cfg.get("stock_codes")
|
stock_codes = run_cfg.get("stock_codes")
|
||||||
|
|
||||||
|
# 4.1 创建 evaluator(先实例化才能拿到 router)
|
||||||
evaluator = LocalFactorEvaluator(
|
evaluator = LocalFactorEvaluator(
|
||||||
start_date=start_date,
|
start_date=start_date,
|
||||||
end_date=end_date,
|
end_date=end_date,
|
||||||
stock_codes=stock_codes,
|
stock_codes=stock_codes,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 4.2 构建 STFilter 和股票池注册表
|
||||||
|
st_filter = STFilter(data_router=evaluator.engine.router)
|
||||||
|
stock_pool_registry = _build_stock_pool_registry(
|
||||||
|
run_cfg, st_filter=st_filter, end_date=end_date
|
||||||
|
)
|
||||||
|
evaluator.stock_pool_registry = stock_pool_registry
|
||||||
|
|
||||||
|
if stock_pool_registry is not None:
|
||||||
|
returns = evaluator.evaluate_returns_by_pool(periods=1)
|
||||||
|
for pool_name, pool_ret in returns.items():
|
||||||
|
valid_assets = int((~np.isnan(pool_ret).all(axis=1)).sum())
|
||||||
|
print(
|
||||||
|
f"[main] 股票池 {pool_name}: {valid_assets} 只资产, "
|
||||||
|
f"returns shape: {pool_ret.shape}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
returns = evaluator.evaluate_returns(periods=1)
|
returns = evaluator.evaluate_returns(periods=1)
|
||||||
print(
|
print(
|
||||||
f"[main] 本地数据范围: {start_date} ~ {end_date}, returns shape: {returns.shape}"
|
f"[main] 本地数据范围: {start_date} ~ {end_date}, "
|
||||||
|
f"returns shape: {returns.shape}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
|
|||||||
171
src/factorminer/pool_definitions.py
Normal file
171
src/factorminer/pool_definitions.py
Normal file
@@ -0,0 +1,171 @@
|
|||||||
|
"""用户可配置的股票池定义文件。
|
||||||
|
|
||||||
|
参考 `src/experiment/common.py` 中的 `stock_pool_filter` 模式,
|
||||||
|
支持通过编写 Polars filter 函数来自定义股票池。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Callable, List
|
||||||
|
|
||||||
|
import polars as pl
|
||||||
|
|
||||||
|
from src.factorminer.evaluation.stock_pool_registry import StockPoolDefinition, StockPoolRegistry
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# 股票池大小配置
|
||||||
|
# =============================================================================
|
||||||
|
SMALL_CAP_N = 1000 # 小微盘股票池默认选取的股票数量
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# 股票池筛选函数(与 experiment/common.py 风格一致)
|
||||||
|
# =============================================================================
|
||||||
|
def all_market_filter(df: pl.DataFrame) -> pl.Series:
    """Whole-market pool: select every row of ``df``."""
    row_count = df.height
    return pl.Series([True] * row_count)
|
||||||
|
|
||||||
|
|
||||||
|
def growth_board_filter(df: pl.DataFrame) -> pl.Series:
    """ChiNext (growth board) pool: codes starting with ``30``.

    The prefix ``30`` (rather than ``300``) matches the exclusion used by
    ``common_small_cap_filter`` in this module and also covers the
    301-prefixed listings.

    Args:
        df: Frame with a ``ts_code`` column.

    Returns:
        Boolean Series, True for growth-board stocks.
    """
    return df["ts_code"].str.starts_with("30")
|
||||||
|
|
||||||
|
|
||||||
|
def star_board_filter(df: pl.DataFrame) -> pl.Series:
    """STAR Market pool: codes starting with ``68``.

    The prefix ``68`` (rather than ``688``) matches the exclusion used by
    ``common_small_cap_filter`` in this module, so the two filters partition
    the universe consistently.

    Args:
        df: Frame with a ``ts_code`` column.

    Returns:
        Boolean Series, True for STAR Market stocks.
    """
    return df["ts_code"].str.starts_with("68")
|
||||||
|
|
||||||
|
|
||||||
|
def bse_board_filter(df: pl.DataFrame) -> pl.Series:
    """Beijing Stock Exchange pool: codes starting with ``8`` or ``4``."""
    # NOTE(review): common_small_cap_filter also treats the "9" prefix as
    # BSE -- confirm whether it should be included here too.
    codes = df["ts_code"]
    starts_with_8 = codes.str.starts_with("8")
    starts_with_4 = codes.str.starts_with("4")
    return starts_with_8 | starts_with_4
|
||||||
|
|
||||||
|
|
||||||
|
def main_board_filter(df: pl.DataFrame) -> pl.Series:
    """Main-board pool: everything except ChiNext, STAR and BSE codes.

    Uses the same exclusion prefixes as ``common_small_cap_filter`` in this
    module (``30``, ``68``, ``8``, ``9``, ``4``) so that, for example,
    301-prefixed ChiNext codes are not counted as main board.

    Args:
        df: Frame with a ``ts_code`` column.

    Returns:
        Boolean Series, True for main-board stocks.
    """
    excluded_prefixes = ("30", "68", "8", "9", "4")
    codes = df["ts_code"]
    keep = ~codes.str.starts_with(excluded_prefixes[0])
    for prefix in excluded_prefixes[1:]:
        keep = keep & ~codes.str.starts_with(prefix)
    return keep
|
||||||
|
|
||||||
|
|
||||||
|
def small_cap_filter(df: pl.DataFrame, n_stocks: int = SMALL_CAP_N) -> pl.Series:
    """Small/micro-cap pool: the ``n_stocks`` stocks with the smallest ``circ_mv``.

    Rows whose ``circ_mv`` is null are never selected. Without this guard the
    ascending sort would put nulls first and pick exactly the assets that
    have no market-cap data.

    Args:
        df: Frame with ``ts_code`` and ``circ_mv`` columns.
        n_stocks: Number of stocks to select; defaults to ``SMALL_CAP_N``.

    Returns:
        Boolean Series aligned with ``df``; all False when ``circ_mv`` is
        missing or no row has a market-cap value.
    """
    none_selected = pl.Series([False] * len(df), dtype=pl.Boolean)
    if "circ_mv" not in df.columns:
        # Safe degradation: without market-cap data nobody qualifies.
        return none_selected
    ranked = df.drop_nulls(subset=["circ_mv"]).sort("circ_mv")
    n = max(0, min(n_stocks, ranked.height))
    if n == 0:
        return none_selected
    small_codes = ranked.head(n)["ts_code"]
    return df["ts_code"].is_in(small_codes.implode())
|
||||||
|
|
||||||
|
|
||||||
|
def main_board_small_cap_filter(
    df: pl.DataFrame, n_stocks: int = SMALL_CAP_N
) -> pl.Series:
    """Main-board small-cap pool: smallest ``circ_mv`` among main-board stocks.

    Rows whose ``circ_mv`` is null are never selected (an ascending sort
    would otherwise rank them first).

    Args:
        df: Frame with ``ts_code`` and ``circ_mv`` columns.
        n_stocks: Number of stocks to select; defaults to ``SMALL_CAP_N``.

    Returns:
        Boolean Series aligned with ``df``; all False when ``circ_mv`` is
        missing, no main-board stock exists, or none has a market-cap value.
    """
    none_selected = pl.Series([False] * len(df), dtype=pl.Boolean)
    main_board = main_board_filter(df)
    main_df = df.filter(main_board)
    if "circ_mv" not in df.columns or len(main_df) == 0:
        return none_selected
    ranked = main_df.drop_nulls(subset=["circ_mv"]).sort("circ_mv")
    n = max(0, min(n_stocks, ranked.height))
    if n == 0:
        return none_selected
    small_codes = ranked.head(n)["ts_code"]
    return df["ts_code"].is_in(small_codes.implode())
|
||||||
|
|
||||||
|
|
||||||
|
def common_small_cap_filter(df: pl.DataFrame, n_stocks: int = SMALL_CAP_N) -> pl.Series:
    """Small-cap screen intended to match the one in experiment/common.py.

    Selection steps:
        1. Drop ChiNext stocks (codes starting with 30).
        2. Drop STAR Market stocks (codes starting with 68).
        3. Drop Beijing Stock Exchange stocks (codes starting with 8, 9 or 4).
        4. Among the rest, keep the ``n_stocks`` rows with the smallest
           ``circ_mv``.

    Args:
        df: Frame that must contain ``ts_code`` and ``circ_mv`` columns.
        n_stocks: Number of stocks to pick; defaults to ``SMALL_CAP_N``.

    Returns:
        Boolean Series marking which rows of ``df`` were selected.
    """
    codes = df["ts_code"]
    # 30 -> ChiNext, 68 -> STAR Market, 8 / 9 / 4 -> Beijing Stock Exchange.
    excluded_prefixes = ("30", "68", "8", "9", "4")
    keep = ~codes.str.starts_with(excluded_prefixes[0])
    for prefix in excluded_prefixes[1:]:
        keep = keep & ~codes.str.starts_with(prefix)
    eligible = df.filter(keep)
    take = min(n_stocks, len(eligible))
    smallest = eligible.sort("circ_mv").head(take)
    return codes.is_in(smallest["ts_code"].implode())
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# 默认股票池定义列表
|
||||||
|
# =============================================================================
|
||||||
|
def get_default_stock_pool_definitions() -> List[StockPoolDefinition]:
    """Build the list of default stock pool definitions.

    Returns:
        A freshly built list of ``StockPoolDefinition`` objects in the order
        all, growth_board, star_board, bse_board, main_board, small_cap,
        main_board_small_cap.
    """
    # One row per pool: (name, filter function, required columns, description).
    specs = (
        ("all", all_market_filter, (), "全市场"),
        ("growth_board", growth_board_filter, (), "创业板"),
        ("star_board", star_board_filter, (), "科创板"),
        ("bse_board", bse_board_filter, (), "北交所"),
        ("main_board", main_board_filter, (), "主板"),
        ("small_cap", small_cap_filter, ("circ_mv",), "小微盘"),
        (
            "main_board_small_cap",
            main_board_small_cap_filter,
            ("circ_mv",),
            "主板小微盘",
        ),
    )
    return [
        StockPoolDefinition(
            name=pool_name,
            filter_func=pool_filter,
            # Each definition gets its own list object.
            required_columns=list(columns),
            description=text,
        )
        for pool_name, pool_filter, columns, text in specs
    ]
|
||||||
|
|
||||||
|
|
||||||
|
def get_default_stock_pools(
    pre_filters: List[Callable[[pl.DataFrame], pl.Series | pl.Expr]] | None = None,
) -> StockPoolRegistry:
    """Create a registry populated with the default stock pools.

    Args:
        pre_filters: Optional filter callables forwarded unchanged to
            ``StockPoolRegistry.from_definitions``.

    Returns:
        The ``StockPoolRegistry`` built from
        ``get_default_stock_pool_definitions()``.
    """
    definitions = get_default_stock_pool_definitions()
    return StockPoolRegistry.from_definitions(definitions, pre_filters=pre_filters)
|
||||||
@@ -21,6 +21,7 @@ from src.factorminer.evaluation.metrics import (
|
|||||||
# Helpers
|
# Helpers
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def rng():
|
def rng():
|
||||||
return np.random.default_rng(123)
|
return np.random.default_rng(123)
|
||||||
@@ -58,6 +59,7 @@ def known_quintile_signal(rng):
|
|||||||
# IC computation
|
# IC computation
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
class TestIC:
|
class TestIC:
|
||||||
"""Test Information Coefficient computation."""
|
"""Test Information Coefficient computation."""
|
||||||
|
|
||||||
@@ -106,6 +108,7 @@ class TestIC:
|
|||||||
# ICIR computation
|
# ICIR computation
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
class TestICIR:
|
class TestICIR:
|
||||||
"""Test ICIR = mean(IC) / std(IC)."""
|
"""Test ICIR = mean(IC) / std(IC)."""
|
||||||
|
|
||||||
@@ -142,6 +145,7 @@ class TestICIR:
|
|||||||
# IC-derived statistics
|
# IC-derived statistics
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
class TestICStats:
|
class TestICStats:
|
||||||
"""Test IC mean and win rate."""
|
"""Test IC mean and win rate."""
|
||||||
|
|
||||||
@@ -170,6 +174,7 @@ class TestICStats:
|
|||||||
# Pairwise correlation
|
# Pairwise correlation
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
class TestPairwiseCorrelation:
|
class TestPairwiseCorrelation:
|
||||||
"""Test pairwise cross-sectional correlation."""
|
"""Test pairwise cross-sectional correlation."""
|
||||||
|
|
||||||
@@ -207,6 +212,7 @@ class TestPairwiseCorrelation:
|
|||||||
# Quintile returns
|
# Quintile returns
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
class TestQuintileReturns:
|
class TestQuintileReturns:
|
||||||
"""Test quintile return computation."""
|
"""Test quintile return computation."""
|
||||||
|
|
||||||
@@ -243,6 +249,7 @@ class TestQuintileReturns:
|
|||||||
# Turnover
|
# Turnover
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
class TestTurnover:
|
class TestTurnover:
|
||||||
"""Test portfolio turnover computation."""
|
"""Test portfolio turnover computation."""
|
||||||
|
|
||||||
@@ -263,6 +270,7 @@ class TestTurnover:
|
|||||||
# Comprehensive factor stats
|
# Comprehensive factor stats
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
class TestFactorStats:
|
class TestFactorStats:
|
||||||
"""Test the compute_factor_stats wrapper."""
|
"""Test the compute_factor_stats wrapper."""
|
||||||
|
|
||||||
@@ -285,3 +293,18 @@ class TestFactorStats:
|
|||||||
returns = rng.normal(0, 0.01, (M, T))
|
returns = rng.normal(0, 0.01, (M, T))
|
||||||
stats = compute_factor_stats(signals, returns)
|
stats = compute_factor_stats(signals, returns)
|
||||||
assert stats["ic_series"].shape == (T,)
|
assert stats["ic_series"].shape == (T,)
|
||||||
|
|
||||||
|
|
||||||
|
class TestFactorStatsPools:
    """Test compute_factor_stats under multi-stock-pool NaN scenarios."""

    def test_pool_with_nans(self, rng):
        """Stats stay usable when a block of assets is NaN for every period."""
        n_assets, n_periods = 40, 30
        shape = (n_assets, n_periods)
        signals = rng.normal(0, 1, shape)
        returns = rng.normal(0, 0.01, shape)
        # Simulate a sub-pool: the last ten assets are NaN at every time step.
        signals[30:] = np.nan
        returns[30:] = np.nan
        stats = compute_factor_stats(signals, returns)
        assert stats["n_periods"] == n_periods
        assert np.isfinite(stats["ic_mean"])
|
||||||
|
|||||||
72
src/factorminer/tests/test_stock_pool.py
Normal file
72
src/factorminer/tests/test_stock_pool.py
Normal file
@@ -0,0 +1,72 @@
|
|||||||
|
import numpy as np
|
||||||
|
import polars as pl
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from src.factorminer.evaluation.stock_pool_registry import StockPoolRegistry
|
||||||
|
|
||||||
|
|
||||||
|
class TestStockPoolRegistry:
    """Tests for pool registration, mask building and signal filtering."""

    # Asset universe shared by most tests: one 000-, one 300- and one 688-code.
    CODES = ["000001.SZ", "300001.SZ", "688001.SH"]

    @staticmethod
    def _everything(df):
        """Pool filter that accepts every row."""
        return pl.Series([True] * len(df))

    @staticmethod
    def _growth(df):
        """Pool filter that accepts codes starting with 300."""
        return df["ts_code"].str.starts_with("300")

    def test_add_and_get_pool_names(self):
        """Pool names are reported in registration order."""
        registry = StockPoolRegistry()
        registry.add_pool("all", self._everything)
        registry.add_pool("growth", self._growth)
        assert registry.get_pool_names() == ["all", "growth"]

    def test_build_masks(self):
        """Each pool yields a mask that flags exactly its member assets."""
        registry = StockPoolRegistry()
        registry.add_pool("all", self._everything)
        registry.add_pool("growth", self._growth)
        masks = registry.build_masks(list(self.CODES))
        growth_mask = masks["growth"]
        assert masks["all"].sum() == 3
        assert growth_mask.sum() == 1
        assert growth_mask[1]

    def test_filter_signals(self):
        """filter_signals keeps only the signal rows of assets in the pool."""
        registry = StockPoolRegistry()
        registry.add_pool("growth", self._growth)
        registry.build_masks(list(self.CODES))
        # Rows are [1, 2], [3, 4], [5, 6]; only the middle asset is in the pool.
        signals = np.arange(1, 7, dtype=np.float64).reshape(3, 2)
        filtered = registry.filter_signals(signals, "growth")
        assert filtered.shape == (1, 2)
        np.testing.assert_array_equal(filtered, [[3, 4]])

    def test_build_masks_with_metadata(self):
        """A pool may rank assets using a column supplied via metadata_df."""

        def two_smallest_by_total_mv(df):
            smallest = df.sort("total_mv").head(2)
            return df["ts_code"].is_in(smallest["ts_code"].implode())

        registry = StockPoolRegistry()
        registry.add_pool(
            "small_cap",
            two_smallest_by_total_mv,
            required_columns=["total_mv"],
        )
        codes = ["A", "B", "C"]
        metadata = pl.DataFrame(
            {
                "ts_code": ["A", "B", "C"],
                "total_mv": [300.0, 100.0, 200.0],
            }
        )
        masks = registry.build_masks(codes, metadata_df=metadata)
        small_cap_mask = masks["small_cap"]
        # B (100.0) and C (200.0) are the two smallest; A (300.0) is left out.
        assert small_cap_mask.sum() == 2
        assert small_cap_mask[1] and small_cap_mask[2]

    def test_expr_fallback(self):
        """A filter that returns a polars expression is also accepted."""

        def growth_expr(df):
            return pl.col("ts_code").str.starts_with("300")

        registry = StockPoolRegistry()
        registry.add_pool("expr_pool", growth_expr)
        masks = registry.build_masks(["000001.SZ", "300001.SZ"])
        assert masks["expr_pool"].sum() == 1

    def test_pre_filters(self):
        """Pre-filters narrow the universe before the pool filters run."""

        def starts_with_zero(df):
            return df["ts_code"].str.starts_with("0")

        registry = StockPoolRegistry(pre_filters=[starts_with_zero])
        registry.add_pool("all", self._everything)
        masks = registry.build_masks(list(self.CODES))
        all_mask = masks["all"]
        assert all_mask.sum() == 1
        assert all_mask[0]
|
||||||
Reference in New Issue
Block a user