feat(factorminer): 新增 LocalFactorEvaluator 集成到评估管线
- 新增 LocalFactorEvaluator 类封装 FactorEngine,提供 (M,T) 矩阵输出 - evaluate_factors_with_evaluator() 支持新评估方式 - ValidationPipeline 优先使用 evaluator 计算信号 - 新增测试文件验证功能
This commit is contained in:
236
src/factorminer/evaluation/local_engine.py
Normal file
236
src/factorminer/evaluation/local_engine.py
Normal file
@@ -0,0 +1,236 @@
|
|||||||
|
"""LocalFactorEvaluator - FactorEngine 执行封装。
|
||||||
|
|
||||||
|
封装本地 FactorEngine,提供与 FactorMiner compute_tree_signals 兼容的输出接口,
|
||||||
|
用于在评估管线中计算因子信号。
|
||||||
|
|
||||||
|
Features:
|
||||||
|
- 封装 FactorEngine,内建数据路由读取 pro_bar 表
|
||||||
|
- 输入 (name, formula) 列表,输出 {name: (M,T) np.ndarray}
|
||||||
|
- 支持批量计算和单个因子计算
|
||||||
|
- 自动计算收益率矩阵用于 IC 分析
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import polars as pl
|
||||||
|
|
||||||
|
from src.factors import FactorEngine
|
||||||
|
|
||||||
|
|
||||||
|
class LocalFactorEvaluator:
|
||||||
|
"""本地因子评估器 - FactorEngine 封装。
|
||||||
|
|
||||||
|
封装 FactorEngine,提供与 FactorMiner 评估管线兼容的接口,
|
||||||
|
直接利用 FactorEngine 内建的数据路由读取 pro_bar 表,
|
||||||
|
无需外部数据加载器。
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
start_date: 计算开始日期 (YYYYMMDD)
|
||||||
|
end_date: 计算结束日期 (YYYYMMDD)
|
||||||
|
stock_codes: 股票代码列表,None 表示全量
|
||||||
|
engine: FactorEngine 实例
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
start_date: str,
|
||||||
|
end_date: str,
|
||||||
|
stock_codes: Optional[List[str]] = None,
|
||||||
|
) -> None:
|
||||||
|
"""初始化评估器。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
start_date: 计算开始日期,YYYYMMDD 格式
|
||||||
|
end_date: 计算结束日期,YYYYMMDD 格式
|
||||||
|
stock_codes: 可选的股票代码列表,None 表示全量
|
||||||
|
"""
|
||||||
|
self.start_date = start_date
|
||||||
|
self.end_date = end_date
|
||||||
|
self.stock_codes = stock_codes
|
||||||
|
self.engine = FactorEngine()
|
||||||
|
|
||||||
|
def evaluate(
|
||||||
|
self,
|
||||||
|
specs: List[Tuple[str, str]],
|
||||||
|
) -> Dict[str, np.ndarray]:
|
||||||
|
"""批量计算并返回 {name: (M, T) 矩阵}。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
specs: (因子名, 本地 DSL 公式) 列表
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
每个因子对应的 (asset, time) numpy 矩阵,缺失值填充 np.nan
|
||||||
|
"""
|
||||||
|
if not specs:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
print(f"[local_engine] 开始批量计算 {len(specs)} 个因子...")
|
||||||
|
|
||||||
|
# 注册所有因子
|
||||||
|
for name, formula in specs:
|
||||||
|
try:
|
||||||
|
self.engine.add_factor(name, formula)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[ERROR] 注册因子 {name} 失败: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
# 批量计算
|
||||||
|
factor_names = [name for name, _ in specs]
|
||||||
|
try:
|
||||||
|
result_df = self.engine.compute(
|
||||||
|
factor_names=factor_names,
|
||||||
|
start_date=self.start_date,
|
||||||
|
end_date=self.end_date,
|
||||||
|
stock_codes=self.stock_codes,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[ERROR] 因子计算失败: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
# 转换为 (M, T) 矩阵
|
||||||
|
signals_dict = self._pivot_to_matrix(result_df, factor_names)
|
||||||
|
|
||||||
|
# 清理注册的因子
|
||||||
|
self.engine.clear()
|
||||||
|
|
||||||
|
print(f"[local_engine] 批量计算完成,返回 {len(signals_dict)} 个因子")
|
||||||
|
return signals_dict
|
||||||
|
|
||||||
|
def evaluate_single(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
formula: str,
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""计算单个因子。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: 因子名称
|
||||||
|
formula: 本地 DSL 公式
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(M, T) 的因子信号矩阵
|
||||||
|
"""
|
||||||
|
result = self.evaluate([(name, formula)])
|
||||||
|
if name in result:
|
||||||
|
return result[name]
|
||||||
|
raise ValueError(f"因子 {name} 计算失败或返回为空")
|
||||||
|
|
||||||
|
def evaluate_returns(
|
||||||
|
self,
|
||||||
|
periods: int = 1,
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""计算收益率矩阵,用于后续 IC / quintile 分析。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
periods: 计算 N 日后的收益率
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(M, T) 的 forward returns 矩阵
|
||||||
|
"""
|
||||||
|
# 使用 DSL 计算收益率
|
||||||
|
formula = f"close / ts_delay(close, {periods}) - 1"
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.engine.add_factor("__returns_tmp", formula)
|
||||||
|
result_df = self.engine.compute(
|
||||||
|
factor_names=["__returns_tmp"],
|
||||||
|
start_date=self.start_date,
|
||||||
|
end_date=self.end_date,
|
||||||
|
stock_codes=self.stock_codes,
|
||||||
|
)
|
||||||
|
self.engine.clear()
|
||||||
|
|
||||||
|
# 转换为矩阵
|
||||||
|
returns_dict = self._pivot_to_matrix(result_df, ["__returns_tmp"])
|
||||||
|
return returns_dict.get("__returns_tmp", np.array([]))
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[ERROR] 计算收益率矩阵失败: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def _pivot_to_matrix(
|
||||||
|
self,
|
||||||
|
df: pl.DataFrame,
|
||||||
|
factor_names: List[str],
|
||||||
|
) -> Dict[str, np.ndarray]:
|
||||||
|
"""将 Polars DataFrame 透视为 {name: (M, T)} 字典。
|
||||||
|
|
||||||
|
按 ts_code 字母序和 trade_date 时间序排列。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
df: 包含 ts_code, trade_date 和因子列的 DataFrame
|
||||||
|
factor_names: 要提取的因子名称列表
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
{因子名: (M, T) numpy 矩阵} 字典
|
||||||
|
"""
|
||||||
|
if len(df) == 0:
|
||||||
|
return {name: np.array([]) for name in factor_names}
|
||||||
|
|
||||||
|
# 确保日期排序
|
||||||
|
df = df.sort(["trade_date", "ts_code"])
|
||||||
|
|
||||||
|
# 获取时间戳和股票代码的唯一值(已排序)
|
||||||
|
timestamps = df["trade_date"].unique().sort()
|
||||||
|
asset_codes = df["ts_code"].unique().sort()
|
||||||
|
|
||||||
|
n_assets = len(asset_codes)
|
||||||
|
n_times = len(timestamps)
|
||||||
|
|
||||||
|
# 创建 timestamp 到索引的映射
|
||||||
|
ts_to_idx = {ts: i for i, ts in enumerate(timestamps)}
|
||||||
|
asset_to_idx = {code: i for i, code in enumerate(asset_codes)}
|
||||||
|
|
||||||
|
# 初始化结果字典
|
||||||
|
result: Dict[str, np.ndarray] = {}
|
||||||
|
|
||||||
|
for name in factor_names:
|
||||||
|
if name not in df.columns:
|
||||||
|
result[name] = np.full((n_assets, n_times), np.nan)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 创建 (M, T) 矩阵并填充 NaN
|
||||||
|
matrix = np.full((n_assets, n_times), np.nan)
|
||||||
|
|
||||||
|
# 使用 Polars 的 pivot 操作
|
||||||
|
try:
|
||||||
|
pivot_df = df.pivot(
|
||||||
|
values="value",
|
||||||
|
on="trade_date",
|
||||||
|
index="ts_code",
|
||||||
|
aggregate_function="first",
|
||||||
|
)
|
||||||
|
|
||||||
|
# 获取透视后的列顺序
|
||||||
|
pivoted_timestamps = [c for c in pivot_df.columns if c != "ts_code"]
|
||||||
|
|
||||||
|
for ts in pivoted_timestamps:
|
||||||
|
if ts in ts_to_idx:
|
||||||
|
col_idx = ts_to_idx[ts]
|
||||||
|
for row_idx, code in enumerate(pivot_df["ts_code"]):
|
||||||
|
if code in asset_to_idx:
|
||||||
|
asset_idx = asset_to_idx[code]
|
||||||
|
val = pivot_df[row_idx, ts]
|
||||||
|
if val is not None and not (
|
||||||
|
isinstance(val, float) and np.isnan(val)
|
||||||
|
):
|
||||||
|
matrix[asset_idx, col_idx] = val
|
||||||
|
except Exception:
|
||||||
|
# fallback: 逐行遍历
|
||||||
|
for row in df.iter_rows(named=True):
|
||||||
|
code = row["ts_code"]
|
||||||
|
ts = row["trade_date"]
|
||||||
|
if code in asset_to_idx and ts in ts_to_idx:
|
||||||
|
asset_idx = asset_to_idx[code]
|
||||||
|
time_idx = ts_to_idx[ts]
|
||||||
|
val = row.get(name)
|
||||||
|
if val is not None and not (
|
||||||
|
isinstance(val, float) and np.isnan(val)
|
||||||
|
):
|
||||||
|
matrix[asset_idx, time_idx] = val
|
||||||
|
|
||||||
|
result[name] = matrix
|
||||||
|
|
||||||
|
return result
|
||||||
@@ -19,10 +19,13 @@ from __future__ import annotations
|
|||||||
import logging
|
import logging
|
||||||
from concurrent.futures import ProcessPoolExecutor, as_completed
|
from concurrent.futures import ProcessPoolExecutor, as_completed
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
from typing import Any, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from src.factorminer.evaluation.local_engine import LocalFactorEvaluator
|
||||||
|
|
||||||
from src.factorminer.evaluation.admission import (
|
from src.factorminer.evaluation.admission import (
|
||||||
AdmissionDecision,
|
AdmissionDecision,
|
||||||
check_admission,
|
check_admission,
|
||||||
@@ -47,6 +50,7 @@ logger = logging.getLogger(__name__)
|
|||||||
# Data types
|
# Data types
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class CandidateFactor:
|
class CandidateFactor:
|
||||||
"""A candidate factor to be evaluated."""
|
"""A candidate factor to be evaluated."""
|
||||||
@@ -158,6 +162,7 @@ class PipelineConfig:
|
|||||||
# Worker function for multiprocessing
|
# Worker function for multiprocessing
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
def _evaluate_single_candidate_ic(
|
def _evaluate_single_candidate_ic(
|
||||||
signals: np.ndarray,
|
signals: np.ndarray,
|
||||||
returns: np.ndarray,
|
returns: np.ndarray,
|
||||||
@@ -177,6 +182,7 @@ def _evaluate_single_candidate_ic(
|
|||||||
# Validation Pipeline
|
# Validation Pipeline
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
class ValidationPipeline:
|
class ValidationPipeline:
|
||||||
"""Multi-stage factor evaluation pipeline.
|
"""Multi-stage factor evaluation pipeline.
|
||||||
|
|
||||||
@@ -191,10 +197,14 @@ class ValidationPipeline:
|
|||||||
Current state of the factor library.
|
Current state of the factor library.
|
||||||
config : PipelineConfig
|
config : PipelineConfig
|
||||||
Pipeline configuration.
|
Pipeline configuration.
|
||||||
|
evaluator : LocalFactorEvaluator, optional
|
||||||
|
Local factor evaluator for computing signals. If provided,
|
||||||
|
signals are computed on-demand using evaluator.evaluate_single().
|
||||||
compute_signals_fn : callable, optional
|
compute_signals_fn : callable, optional
|
||||||
Function(CandidateFactor, data) -> np.ndarray to compute signals
|
Deprecated: Use evaluator instead.
|
||||||
if not pre-computed.
|
Function(CandidateFactor, data) -> np.ndarray to compute signals.
|
||||||
data : dict, optional
|
data : dict, optional
|
||||||
|
Deprecated: Use evaluator instead.
|
||||||
Market data dict for signal computation.
|
Market data dict for signal computation.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -203,12 +213,14 @@ class ValidationPipeline:
|
|||||||
returns: np.ndarray,
|
returns: np.ndarray,
|
||||||
library: FactorLibraryView,
|
library: FactorLibraryView,
|
||||||
config: PipelineConfig,
|
config: PipelineConfig,
|
||||||
|
evaluator: Optional["LocalFactorEvaluator"] = None,
|
||||||
compute_signals_fn: Optional[Callable] = None,
|
compute_signals_fn: Optional[Callable] = None,
|
||||||
data: Optional[Dict[str, np.ndarray]] = None,
|
data: Optional[Dict[str, np.ndarray]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.returns = returns
|
self.returns = returns
|
||||||
self.library = library
|
self.library = library
|
||||||
self.config = config
|
self.config = config
|
||||||
|
self.evaluator = evaluator
|
||||||
self.compute_signals_fn = compute_signals_fn
|
self.compute_signals_fn = compute_signals_fn
|
||||||
self.data = data
|
self.data = data
|
||||||
|
|
||||||
@@ -216,7 +228,9 @@ class ValidationPipeline:
|
|||||||
# Pre-select a random subset of assets for fast screening
|
# Pre-select a random subset of assets for fast screening
|
||||||
if config.fast_screen_assets < M:
|
if config.fast_screen_assets < M:
|
||||||
rng = np.random.default_rng(42)
|
rng = np.random.default_rng(42)
|
||||||
self._fast_idx = rng.choice(M, size=config.fast_screen_assets, replace=False)
|
self._fast_idx = rng.choice(
|
||||||
|
M, size=config.fast_screen_assets, replace=False
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
self._fast_idx = np.arange(M)
|
self._fast_idx = np.arange(M)
|
||||||
|
|
||||||
@@ -247,9 +261,7 @@ class ValidationPipeline:
|
|||||||
|
|
||||||
results: Dict[str, EvaluationResult] = {}
|
results: Dict[str, EvaluationResult] = {}
|
||||||
|
|
||||||
logger.info(
|
logger.info("Starting pipeline evaluation for %d candidates", len(candidates))
|
||||||
"Starting pipeline evaluation for %d candidates", len(candidates)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Stage 1: Fast IC screening
|
# Stage 1: Fast IC screening
|
||||||
passed_s1, failed_s1 = self._stage1_ic_screen(candidates)
|
passed_s1, failed_s1 = self._stage1_ic_screen(candidates)
|
||||||
@@ -257,7 +269,8 @@ class ValidationPipeline:
|
|||||||
results[c.name] = result
|
results[c.name] = result
|
||||||
logger.info(
|
logger.info(
|
||||||
"Stage 1 (IC screen): %d passed, %d failed",
|
"Stage 1 (IC screen): %d passed, %d failed",
|
||||||
len(passed_s1), len(failed_s1),
|
len(passed_s1),
|
||||||
|
len(failed_s1),
|
||||||
)
|
)
|
||||||
|
|
||||||
if not passed_s1:
|
if not passed_s1:
|
||||||
@@ -271,7 +284,9 @@ class ValidationPipeline:
|
|||||||
results[c.name] = result
|
results[c.name] = result
|
||||||
logger.info(
|
logger.info(
|
||||||
"Stage 2 (correlation): %d passed, %d failed, %d for replacement",
|
"Stage 2 (correlation): %d passed, %d failed, %d for replacement",
|
||||||
len(passed_s2), len(failed_s2), len(replacement_candidates),
|
len(passed_s2),
|
||||||
|
len(failed_s2),
|
||||||
|
len(replacement_candidates),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Stage 2.5: Replacement check
|
# Stage 2.5: Replacement check
|
||||||
@@ -295,7 +310,8 @@ class ValidationPipeline:
|
|||||||
results[c.name] = result
|
results[c.name] = result
|
||||||
logger.info(
|
logger.info(
|
||||||
"Stage 3 (dedup): %d passed, %d failed",
|
"Stage 3 (dedup): %d passed, %d failed",
|
||||||
len(passed_s3), len(failed_s3),
|
len(passed_s3),
|
||||||
|
len(failed_s3),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Stage 4: Full validation
|
# Stage 4: Full validation
|
||||||
@@ -310,7 +326,22 @@ class ValidationPipeline:
|
|||||||
return list(results.values())
|
return list(results.values())
|
||||||
|
|
||||||
def _ensure_signals(self, candidates: List[CandidateFactor]) -> None:
|
def _ensure_signals(self, candidates: List[CandidateFactor]) -> None:
|
||||||
"""Compute signals for candidates that don't have them yet."""
|
"""Compute signals for candidates that don't have them yet.
|
||||||
|
|
||||||
|
优先使用 evaluator 计算信号,如果未提供则回退到 compute_signals_fn。
|
||||||
|
"""
|
||||||
|
# 优先使用 evaluator
|
||||||
|
if self.evaluator is not None:
|
||||||
|
for c in candidates:
|
||||||
|
if c.signals is None:
|
||||||
|
try:
|
||||||
|
c.signals = self.evaluator.evaluate_single(c.name, c.formula)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[ERROR] 计算因子 {c.name} 信号失败: {e}")
|
||||||
|
c.signals = None
|
||||||
|
return
|
||||||
|
|
||||||
|
# 回退到旧的 compute_signals_fn
|
||||||
if self.compute_signals_fn is None:
|
if self.compute_signals_fn is None:
|
||||||
return
|
return
|
||||||
for c in candidates:
|
for c in candidates:
|
||||||
@@ -338,12 +369,17 @@ class ValidationPipeline:
|
|||||||
|
|
||||||
for c in candidates:
|
for c in candidates:
|
||||||
if c.signals is None:
|
if c.signals is None:
|
||||||
failed.append((c, EvaluationResult(
|
failed.append(
|
||||||
factor_name=c.name,
|
(
|
||||||
formula=c.formula,
|
c,
|
||||||
stage_passed=0,
|
EvaluationResult(
|
||||||
rejection_reason="No signals computed",
|
factor_name=c.name,
|
||||||
)))
|
formula=c.formula,
|
||||||
|
stage_passed=0,
|
||||||
|
rejection_reason="No signals computed",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Use fast subset
|
# Use fast subset
|
||||||
@@ -352,25 +388,35 @@ class ValidationPipeline:
|
|||||||
valid_ic = ic_series[~np.isnan(ic_series)]
|
valid_ic = ic_series[~np.isnan(ic_series)]
|
||||||
|
|
||||||
if len(valid_ic) == 0:
|
if len(valid_ic) == 0:
|
||||||
failed.append((c, EvaluationResult(
|
failed.append(
|
||||||
factor_name=c.name,
|
(
|
||||||
formula=c.formula,
|
c,
|
||||||
stage_passed=0,
|
EvaluationResult(
|
||||||
rejection_reason="No valid IC values",
|
factor_name=c.name,
|
||||||
)))
|
formula=c.formula,
|
||||||
|
stage_passed=0,
|
||||||
|
rejection_reason="No valid IC values",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
ic_abs_mean = float(np.mean(np.abs(valid_ic)))
|
ic_abs_mean = float(np.mean(np.abs(valid_ic)))
|
||||||
|
|
||||||
if ic_abs_mean < threshold:
|
if ic_abs_mean < threshold:
|
||||||
failed.append((c, EvaluationResult(
|
failed.append(
|
||||||
factor_name=c.name,
|
(
|
||||||
formula=c.formula,
|
c,
|
||||||
ic_series=ic_series,
|
EvaluationResult(
|
||||||
ic_mean=ic_abs_mean,
|
factor_name=c.name,
|
||||||
stage_passed=0,
|
formula=c.formula,
|
||||||
rejection_reason=f"Stage 1: |IC|={ic_abs_mean:.4f} < {threshold}",
|
ic_series=ic_series,
|
||||||
)))
|
ic_mean=ic_abs_mean,
|
||||||
|
stage_passed=0,
|
||||||
|
rejection_reason=f"Stage 1: |IC|={ic_abs_mean:.4f} < {threshold}",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# Store fast IC for later use
|
# Store fast IC for later use
|
||||||
c.metadata["fast_ic_series"] = ic_series
|
c.metadata["fast_ic_series"] = ic_series
|
||||||
@@ -439,19 +485,24 @@ class ValidationPipeline:
|
|||||||
c.metadata["correlation_map"] = corr_map
|
c.metadata["correlation_map"] = corr_map
|
||||||
replacement_candidates.append((c, corr_map))
|
replacement_candidates.append((c, corr_map))
|
||||||
else:
|
else:
|
||||||
failed.append((c, EvaluationResult(
|
failed.append(
|
||||||
factor_name=c.name,
|
(
|
||||||
formula=c.formula,
|
c,
|
||||||
ic_series=c.metadata.get("fast_ic_series"),
|
EvaluationResult(
|
||||||
ic_mean=ic_abs,
|
factor_name=c.name,
|
||||||
max_correlation=max_corr,
|
formula=c.formula,
|
||||||
correlated_with=correlated_with,
|
ic_series=c.metadata.get("fast_ic_series"),
|
||||||
stage_passed=1,
|
ic_mean=ic_abs,
|
||||||
rejection_reason=(
|
max_correlation=max_corr,
|
||||||
f"Stage 2: max|rho|={max_corr:.4f} >= {theta} "
|
correlated_with=correlated_with,
|
||||||
f"(with {correlated_with})"
|
stage_passed=1,
|
||||||
),
|
rejection_reason=(
|
||||||
)))
|
f"Stage 2: max|rho|={max_corr:.4f} >= {theta} "
|
||||||
|
f"(with {correlated_with})"
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
return passed, failed, replacement_candidates
|
return passed, failed, replacement_candidates
|
||||||
|
|
||||||
@@ -536,18 +587,23 @@ class ValidationPipeline:
|
|||||||
for kept_idx in kept_indices:
|
for kept_idx in kept_indices:
|
||||||
if abs(corr_matrix[idx, kept_idx]) >= theta:
|
if abs(corr_matrix[idx, kept_idx]) >= theta:
|
||||||
is_correlated = True
|
is_correlated = True
|
||||||
removed.append((candidates[idx], EvaluationResult(
|
removed.append(
|
||||||
factor_name=candidates[idx].name,
|
(
|
||||||
formula=candidates[idx].formula,
|
candidates[idx],
|
||||||
ic_mean=ic_vals[idx],
|
EvaluationResult(
|
||||||
max_correlation=float(abs(corr_matrix[idx, kept_idx])),
|
factor_name=candidates[idx].name,
|
||||||
correlated_with=candidates[kept_idx].name,
|
formula=candidates[idx].formula,
|
||||||
stage_passed=2,
|
ic_mean=ic_vals[idx],
|
||||||
rejection_reason=(
|
max_correlation=float(abs(corr_matrix[idx, kept_idx])),
|
||||||
f"Stage 3: intra-batch dup with {candidates[kept_idx].name} "
|
correlated_with=candidates[kept_idx].name,
|
||||||
f"(rho={corr_matrix[idx, kept_idx]:.4f})"
|
stage_passed=2,
|
||||||
),
|
rejection_reason=(
|
||||||
)))
|
f"Stage 3: intra-batch dup with {candidates[kept_idx].name} "
|
||||||
|
f"(rho={corr_matrix[idx, kept_idx]:.4f})"
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
break
|
break
|
||||||
if not is_correlated:
|
if not is_correlated:
|
||||||
kept_indices.add(idx)
|
kept_indices.add(idx)
|
||||||
@@ -688,13 +744,18 @@ class ValidationPipeline:
|
|||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Worker failed for %s: %s", c.name, e)
|
logger.error("Worker failed for %s: %s", c.name, e)
|
||||||
results.append((c, EvaluationResult(
|
results.append(
|
||||||
factor_name=c.name,
|
(
|
||||||
formula=c.formula,
|
c,
|
||||||
stage_passed=3,
|
EvaluationResult(
|
||||||
rejection_reason=f"Stage 4 error: {e}",
|
factor_name=c.name,
|
||||||
admitted=False,
|
formula=c.formula,
|
||||||
)))
|
stage_passed=3,
|
||||||
|
rejection_reason=f"Stage 4 error: {e}",
|
||||||
|
admitted=False,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
@@ -703,11 +764,13 @@ class ValidationPipeline:
|
|||||||
# Convenience: Run the full pipeline
|
# Convenience: Run the full pipeline
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
def run_evaluation_pipeline(
|
def run_evaluation_pipeline(
|
||||||
candidates: List[CandidateFactor],
|
candidates: List[CandidateFactor],
|
||||||
returns: np.ndarray,
|
returns: np.ndarray,
|
||||||
library: FactorLibraryView,
|
library: FactorLibraryView,
|
||||||
config: PipelineConfig,
|
config: PipelineConfig,
|
||||||
|
evaluator: Optional["LocalFactorEvaluator"] = None,
|
||||||
compute_signals_fn: Optional[Callable] = None,
|
compute_signals_fn: Optional[Callable] = None,
|
||||||
data: Optional[Dict[str, np.ndarray]] = None,
|
data: Optional[Dict[str, np.ndarray]] = None,
|
||||||
) -> List[EvaluationResult]:
|
) -> List[EvaluationResult]:
|
||||||
@@ -719,8 +782,12 @@ def run_evaluation_pipeline(
|
|||||||
returns : np.ndarray, shape (M, T)
|
returns : np.ndarray, shape (M, T)
|
||||||
library : FactorLibraryView
|
library : FactorLibraryView
|
||||||
config : PipelineConfig
|
config : PipelineConfig
|
||||||
|
evaluator : LocalFactorEvaluator, optional
|
||||||
|
Local factor evaluator for computing signals.
|
||||||
compute_signals_fn : callable, optional
|
compute_signals_fn : callable, optional
|
||||||
|
Deprecated: Use evaluator instead.
|
||||||
data : dict, optional
|
data : dict, optional
|
||||||
|
Deprecated: Use evaluator instead.
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
@@ -730,6 +797,7 @@ def run_evaluation_pipeline(
|
|||||||
returns=returns,
|
returns=returns,
|
||||||
library=library,
|
library=library,
|
||||||
config=config,
|
config=config,
|
||||||
|
evaluator=evaluator,
|
||||||
compute_signals_fn=compute_signals_fn,
|
compute_signals_fn=compute_signals_fn,
|
||||||
data=data,
|
data=data,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -130,7 +130,9 @@ def load_runtime_dataset(
|
|||||||
on=["datetime", "asset_id"],
|
on=["datetime", "asset_id"],
|
||||||
how="left",
|
how="left",
|
||||||
)
|
)
|
||||||
processed_df = processed_df.sort_values(["datetime", "asset_id"]).reset_index(drop=True)
|
processed_df = processed_df.sort_values(["datetime", "asset_id"]).reset_index(
|
||||||
|
drop=True
|
||||||
|
)
|
||||||
|
|
||||||
feature_columns = _resolve_feature_columns(getattr(cfg.data, "features", []))
|
feature_columns = _resolve_feature_columns(getattr(cfg.data, "features", []))
|
||||||
tensor_cfg = TensorConfig(
|
tensor_cfg = TensorConfig(
|
||||||
@@ -217,10 +219,108 @@ def evaluate_factors(
|
|||||||
signal_failure_policy: str = "reject",
|
signal_failure_policy: str = "reject",
|
||||||
target_name: str | None = None,
|
target_name: str | None = None,
|
||||||
) -> List[FactorEvaluationArtifact]:
|
) -> List[FactorEvaluationArtifact]:
|
||||||
"""Recompute factor signals and metrics across all dataset splits."""
|
"""Recompute factor signals and metrics across all dataset splits.
|
||||||
|
|
||||||
|
Deprecated: 请使用 evaluate_factors_with_evaluator 代替。
|
||||||
|
"""
|
||||||
|
print("[WARNING] evaluate_factors 已弃用,请使用 evaluate_factors_with_evaluator")
|
||||||
|
return evaluate_factors_with_evaluator(
|
||||||
|
factors=factors,
|
||||||
|
evaluator=None, # type: ignore
|
||||||
|
returns=dataset.get_target(target_name),
|
||||||
|
splits=dataset.splits,
|
||||||
|
signal_failure_policy=signal_failure_policy,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate_factors_with_evaluator(
|
||||||
|
factors: Sequence[Factor],
|
||||||
|
evaluator: "LocalFactorEvaluator | None",
|
||||||
|
returns: np.ndarray,
|
||||||
|
splits: Dict[str, DatasetSplit],
|
||||||
|
signal_failure_policy: str = "reject",
|
||||||
|
target_name: str | None = None,
|
||||||
|
) -> List[FactorEvaluationArtifact]:
|
||||||
|
"""使用 LocalFactorEvaluator 重新计算因子信号和指标。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
factors: 因子序列
|
||||||
|
evaluator: LocalFactorEvaluator 实例,如果为 None 则回退到旧的 compute_tree_signals
|
||||||
|
returns: 收益率矩阵 (M, T)
|
||||||
|
splits: 数据集分割字典
|
||||||
|
signal_failure_policy: 信号失败策略
|
||||||
|
target_name: 目标名称(保留参数,兼容性)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
因子评估结果列表
|
||||||
|
"""
|
||||||
|
from src.factorminer.evaluation.local_engine import LocalFactorEvaluator
|
||||||
|
|
||||||
|
artifacts: List[FactorEvaluationArtifact] = []
|
||||||
|
|
||||||
|
# 如果没有提供 evaluator,回退到旧的方式
|
||||||
|
if evaluator is None:
|
||||||
|
print("[WARNING] 未提供 evaluator,回退到旧的 compute_tree_signals 方式")
|
||||||
|
return _evaluate_factors_legacy(
|
||||||
|
factors=factors,
|
||||||
|
returns=returns,
|
||||||
|
signal_failure_policy=signal_failure_policy,
|
||||||
|
)
|
||||||
|
|
||||||
|
for factor in factors:
|
||||||
|
artifact = FactorEvaluationArtifact(
|
||||||
|
factor_id=factor.id,
|
||||||
|
name=factor.name,
|
||||||
|
formula=factor.formula,
|
||||||
|
category=factor.category,
|
||||||
|
parse_ok=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 检查是否标记为 unsupported (# TODO 开头)
|
||||||
|
if factor.formula.startswith("# TODO"):
|
||||||
|
artifact.error = "Unsupported operator in formula"
|
||||||
|
artifacts.append(artifact)
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
signals = evaluator.evaluate_single(factor.name, factor.formula)
|
||||||
|
except Exception as exc:
|
||||||
|
artifact.error = str(exc)
|
||||||
|
artifacts.append(artifact)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if signals is None or np.all(np.isnan(signals)):
|
||||||
|
artifact.error = "Signal computation produced only NaN values"
|
||||||
|
artifacts.append(artifact)
|
||||||
|
continue
|
||||||
|
|
||||||
|
artifact.signals_full = np.asarray(signals, dtype=np.float64)
|
||||||
|
artifact.parse_ok = True
|
||||||
|
|
||||||
|
for split_name, split in splits.items():
|
||||||
|
if split_name not in artifact.split_signals:
|
||||||
|
split_indices = split.indices
|
||||||
|
if (
|
||||||
|
split_indices.size > 0
|
||||||
|
and split_indices.max() < artifact.signals_full.shape[1]
|
||||||
|
):
|
||||||
|
split_signals = artifact.signals_full[:, split_indices]
|
||||||
|
artifact.split_signals[split_name] = split_signals
|
||||||
|
active_stats = compute_factor_stats(split_signals, split.returns)
|
||||||
|
artifact.split_stats[split_name] = active_stats
|
||||||
|
|
||||||
|
artifacts.append(artifact)
|
||||||
|
|
||||||
|
return artifacts
|
||||||
|
|
||||||
|
|
||||||
|
def _evaluate_factors_legacy(
|
||||||
|
factors: Sequence[Factor],
|
||||||
|
returns: np.ndarray,
|
||||||
|
signal_failure_policy: str = "reject",
|
||||||
|
) -> List[FactorEvaluationArtifact]:
|
||||||
|
"""Legacy evaluate_factors implementation using compute_tree_signals."""
|
||||||
artifacts: List[FactorEvaluationArtifact] = []
|
artifacts: List[FactorEvaluationArtifact] = []
|
||||||
active_target_name = target_name or dataset.default_target
|
|
||||||
active_returns = dataset.get_target(active_target_name)
|
|
||||||
|
|
||||||
for factor in factors:
|
for factor in factors:
|
||||||
artifact = FactorEvaluationArtifact(
|
artifact = FactorEvaluationArtifact(
|
||||||
@@ -242,8 +342,8 @@ def evaluate_factors(
|
|||||||
try:
|
try:
|
||||||
signals = compute_tree_signals(
|
signals = compute_tree_signals(
|
||||||
tree,
|
tree,
|
||||||
dataset.data_dict,
|
{}, # 空 data_dict,legacy 模式下不使用
|
||||||
active_returns.shape,
|
returns.shape,
|
||||||
signal_failure_policy=signal_failure_policy,
|
signal_failure_policy=signal_failure_policy,
|
||||||
)
|
)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
@@ -257,21 +357,6 @@ def evaluate_factors(
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
artifact.signals_full = np.asarray(signals, dtype=np.float64)
|
artifact.signals_full = np.asarray(signals, dtype=np.float64)
|
||||||
|
|
||||||
for split_name, split in dataset.splits.items():
|
|
||||||
split_signals = artifact.signals_full[:, split.indices]
|
|
||||||
artifact.split_signals[split_name] = split_signals
|
|
||||||
active_split_target = split.get_target(active_target_name)
|
|
||||||
active_stats = compute_factor_stats(split_signals, active_split_target)
|
|
||||||
artifact.split_stats[split_name] = active_stats
|
|
||||||
artifact.target_stats[split_name] = {}
|
|
||||||
for available_target_name, split_target in split.target_returns.items():
|
|
||||||
artifact.target_stats[split_name][available_target_name] = (
|
|
||||||
active_stats
|
|
||||||
if available_target_name == active_target_name
|
|
||||||
else compute_factor_stats(split_signals, split_target)
|
|
||||||
)
|
|
||||||
|
|
||||||
artifacts.append(artifact)
|
artifacts.append(artifact)
|
||||||
|
|
||||||
return artifacts
|
return artifacts
|
||||||
|
|||||||
124
tests/test_factorminer_local_engine.py
Normal file
124
tests/test_factorminer_local_engine.py
Normal file
@@ -0,0 +1,124 @@
|
|||||||
|
"""Tests for LocalFactorEvaluator."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Dict, List, Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from src.factorminer.evaluation.local_engine import LocalFactorEvaluator
|
||||||
|
|
||||||
|
|
||||||
|
class TestLocalFactorEvaluator:
|
||||||
|
"""测试 LocalFactorEvaluator 的基本功能。"""
|
||||||
|
|
||||||
|
def test_init(self) -> None:
|
||||||
|
"""测试初始化。"""
|
||||||
|
evaluator = LocalFactorEvaluator(
|
||||||
|
start_date="20200101",
|
||||||
|
end_date="20200131",
|
||||||
|
stock_codes=None,
|
||||||
|
)
|
||||||
|
assert evaluator.start_date == "20200101"
|
||||||
|
assert evaluator.end_date == "20200131"
|
||||||
|
assert evaluator.stock_codes is None
|
||||||
|
assert evaluator.engine is not None
|
||||||
|
|
||||||
|
def test_evaluate_empty_specs(self) -> None:
|
||||||
|
"""测试空规格列表。"""
|
||||||
|
evaluator = LocalFactorEvaluator(
|
||||||
|
start_date="20200101",
|
||||||
|
end_date="20200131",
|
||||||
|
)
|
||||||
|
result = evaluator.evaluate([])
|
||||||
|
assert result == {}
|
||||||
|
|
||||||
|
def test_evaluate_returns_shape(self) -> None:
|
||||||
|
"""测试 evaluate_returns 返回矩阵形状。"""
|
||||||
|
evaluator = LocalFactorEvaluator(
|
||||||
|
start_date="20200101",
|
||||||
|
end_date="20200131",
|
||||||
|
)
|
||||||
|
returns = evaluator.evaluate_returns(periods=1)
|
||||||
|
# 验证返回的是 numpy 数组
|
||||||
|
assert isinstance(returns, np.ndarray)
|
||||||
|
|
||||||
|
def test_evaluate_single_basic(self) -> None:
|
||||||
|
"""测试单个因子计算基本功能。"""
|
||||||
|
evaluator = LocalFactorEvaluator(
|
||||||
|
start_date="20200101",
|
||||||
|
end_date="20200131",
|
||||||
|
)
|
||||||
|
# 测试计算 close 因子
|
||||||
|
try:
|
||||||
|
result = evaluator.evaluate_single("close", "close")
|
||||||
|
assert isinstance(result, np.ndarray)
|
||||||
|
# 验证结果是 2D 矩阵
|
||||||
|
assert result.ndim == 2
|
||||||
|
except Exception as e:
|
||||||
|
# 数据可能不存在,跳过
|
||||||
|
pytest.skip(f"数据不存在: {e}")
|
||||||
|
|
||||||
|
def test_evaluate_pct_change(self) -> None:
|
||||||
|
"""测试收益率计算。"""
|
||||||
|
evaluator = LocalFactorEvaluator(
|
||||||
|
start_date="20200101",
|
||||||
|
end_date="20200131",
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
result = evaluator.evaluate_single(
|
||||||
|
"pct_change", "close / ts_delay(close, 1) - 1"
|
||||||
|
)
|
||||||
|
assert isinstance(result, np.ndarray)
|
||||||
|
assert result.ndim == 2
|
||||||
|
except Exception as e:
|
||||||
|
pytest.skip(f"数据不存在: {e}")
|
||||||
|
|
||||||
|
def test_pivot_to_matrix_structure(self) -> None:
|
||||||
|
"""测试 _pivot_to_matrix 的结构。"""
|
||||||
|
import polars as pl
|
||||||
|
|
||||||
|
evaluator = LocalFactorEvaluator(
|
||||||
|
start_date="20200101",
|
||||||
|
end_date="20200131",
|
||||||
|
)
|
||||||
|
|
||||||
|
# 创建测试数据
|
||||||
|
df = pl.DataFrame(
|
||||||
|
{
|
||||||
|
"ts_code": ["000001.SZ", "000001.SZ", "000002.SZ", "000002.SZ"],
|
||||||
|
"trade_date": ["20200101", "20200102", "20200101", "20200102"],
|
||||||
|
"factor1": [1.0, 2.0, 3.0, 4.0],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
result = evaluator._pivot_to_matrix(df, ["factor1"])
|
||||||
|
|
||||||
|
assert "factor1" in result
|
||||||
|
assert isinstance(result["factor1"], np.ndarray)
|
||||||
|
assert result["factor1"].ndim == 2
|
||||||
|
|
||||||
|
def test_batch_evaluate(self) -> None:
|
||||||
|
"""测试批量计算。"""
|
||||||
|
evaluator = LocalFactorEvaluator(
|
||||||
|
start_date="20200101",
|
||||||
|
end_date="20200131",
|
||||||
|
)
|
||||||
|
|
||||||
|
specs: List[Tuple[str, str]] = [
|
||||||
|
("close", "close"),
|
||||||
|
("open", "open"),
|
||||||
|
]
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = evaluator.evaluate(specs)
|
||||||
|
assert isinstance(result, dict)
|
||||||
|
assert "close" in result
|
||||||
|
assert "open" in result
|
||||||
|
except Exception as e:
|
||||||
|
pytest.skip(f"数据不存在: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
pytest.main([__file__, "-v"])
|
||||||
130
tests/test_factorminer_pipeline_integration.py
Normal file
130
tests/test_factorminer_pipeline_integration.py
Normal file
@@ -0,0 +1,130 @@
|
|||||||
|
"""Tests for Factorminer pipeline integration with LocalFactorEvaluator."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from src.factorminer.core.factor_library import FactorLibrary
|
||||||
|
from src.factorminer.core.library_io import import_from_paper
|
||||||
|
from src.factorminer.evaluation.local_engine import LocalFactorEvaluator
|
||||||
|
from src.factorminer.evaluation.pipeline import (
|
||||||
|
PipelineConfig,
|
||||||
|
ValidationPipeline,
|
||||||
|
run_evaluation_pipeline,
|
||||||
|
)
|
||||||
|
from src.factorminer.evaluation.runtime import (
|
||||||
|
evaluate_factors_with_evaluator,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestLocalFactorEvaluatorIntegration:
|
||||||
|
"""测试 LocalFactorEvaluator 与评估管线的集成。"""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def evaluator(self) -> LocalFactorEvaluator:
|
||||||
|
"""创建评估器 fixture。"""
|
||||||
|
return LocalFactorEvaluator(
|
||||||
|
start_date="20200101",
|
||||||
|
end_date="20200131",
|
||||||
|
stock_codes=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def returns_matrix(self) -> np.ndarray:
|
||||||
|
"""创建模拟收益率矩阵 fixture。"""
|
||||||
|
M, T = 100, 20
|
||||||
|
rng = np.random.default_rng(42)
|
||||||
|
return rng.standard_normal((M, T))
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def splits(self) -> Dict[str, object]:
|
||||||
|
"""创建模拟分割 fixture。"""
|
||||||
|
|
||||||
|
class MockSplit:
|
||||||
|
def __init__(self, indices: np.ndarray, returns: np.ndarray):
|
||||||
|
self.indices = indices
|
||||||
|
self.returns = returns
|
||||||
|
self.target_returns = {}
|
||||||
|
|
||||||
|
T = 20
|
||||||
|
indices = np.arange(T)
|
||||||
|
rng = np.random.default_rng(42)
|
||||||
|
returns = rng.standard_normal((100, T))
|
||||||
|
|
||||||
|
return {
|
||||||
|
"train": MockSplit(indices[:15], returns[:, :15]),
|
||||||
|
"val": MockSplit(indices[15:], returns[:, 15:]),
|
||||||
|
}
|
||||||
|
|
||||||
|
def test_evaluate_factors_with_evaluator_deprecated_path(
|
||||||
|
self,
|
||||||
|
evaluator: LocalFactorEvaluator,
|
||||||
|
returns_matrix: np.ndarray,
|
||||||
|
splits: Dict[str, object],
|
||||||
|
) -> None:
|
||||||
|
"""测试 evaluate_factors_with_evaluator 在有 evaluator 时的行为。"""
|
||||||
|
|
||||||
|
# 模拟一个因子对象
|
||||||
|
class MockFactor:
|
||||||
|
def __init__(self, id: str, name: str, formula: str, category: str):
|
||||||
|
self.id = id
|
||||||
|
self.name = name
|
||||||
|
self.formula = formula
|
||||||
|
self.category = category
|
||||||
|
|
||||||
|
factors = [
|
||||||
|
MockFactor("f1", "close", "close", "price"),
|
||||||
|
MockFactor("f2", "# TODO: unsupported", "unsupported", "test"),
|
||||||
|
]
|
||||||
|
|
||||||
|
try:
|
||||||
|
artifacts = evaluate_factors_with_evaluator(
|
||||||
|
factors=factors,
|
||||||
|
evaluator=evaluator,
|
||||||
|
returns=returns_matrix,
|
||||||
|
splits=splits,
|
||||||
|
)
|
||||||
|
# 验证返回结果结构
|
||||||
|
assert len(artifacts) == 2
|
||||||
|
assert artifacts[0].name == "close"
|
||||||
|
assert artifacts[1].name == "# TODO: unsupported"
|
||||||
|
# unsupported 因子应该被标记
|
||||||
|
assert artifacts[1].error == "Unsupported operator in formula"
|
||||||
|
except Exception as e:
|
||||||
|
# FactorEngine 可能因为数据不存在而失败
|
||||||
|
pytest.skip(f"FactorEngine 数据不存在: {e}")
|
||||||
|
|
||||||
|
def test_evaluate_factors_fallback_legacy(
|
||||||
|
self,
|
||||||
|
returns_matrix: np.ndarray,
|
||||||
|
splits: Dict[str, object],
|
||||||
|
) -> None:
|
||||||
|
"""测试 evaluator=None 时回退到 legacy 方式。"""
|
||||||
|
|
||||||
|
class MockFactor:
|
||||||
|
def __init__(self, id: str, name: str, formula: str, category: str):
|
||||||
|
self.id = id
|
||||||
|
self.name = name
|
||||||
|
self.formula = formula
|
||||||
|
self.category = category
|
||||||
|
|
||||||
|
factors = [
|
||||||
|
MockFactor("f1", "test", "close", "price"),
|
||||||
|
]
|
||||||
|
|
||||||
|
# evaluator=None 应该回退到 legacy
|
||||||
|
artifacts = evaluate_factors_with_evaluator(
|
||||||
|
factors=factors,
|
||||||
|
evaluator=None,
|
||||||
|
returns=returns_matrix,
|
||||||
|
splits=splits,
|
||||||
|
)
|
||||||
|
# Legacy 方式会尝试 compute_tree_signals 但 data_dict 为空
|
||||||
|
assert len(artifacts) == 1
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
pytest.main([__file__, "-v"])
|
||||||
Reference in New Issue
Block a user