"""Tests for the factor base classes: BaseFactor, CrossSectionalFactor, TimeSeriesFactor.

Test requirements (from factor_implementation_plan.md):

- BaseFactor:
    - a valid subclass passes validation
    - a missing ``name`` raises ValueError
    - an empty-string ``name`` raises ValueError
    - a missing ``factor_type`` raises ValueError
    - an invalid ``factor_type`` (neither cs nor ts) raises ValueError
    - a missing ``data_specs`` raises ValueError
    - an empty ``data_specs`` list raises ValueError
    - the abstract ``compute()`` method must be implemented by subclasses
    - parameterized initialization stores ``params`` correctly
    - ``_validate_params()`` is called
- CrossSectionalFactor:
    - ``factor_type`` is set automatically to "cross_sectional"
    - subclasses must implement ``compute()``
    - ``compute()`` returns a ``pl.Series``
- TimeSeriesFactor:
    - ``factor_type`` is set automatically to "time_series"
    - subclasses must implement ``compute()``
    - ``compute()`` returns a ``pl.Series``
"""

import pytest
import polars as pl

from src.factors import DataSpec, FactorContext, FactorData
from src.factors.base import BaseFactor, CrossSectionalFactor, TimeSeriesFactor


# ========== Test fixtures ==========


@pytest.fixture
def sample_dataspec():
    """Return a sample DataSpec on the "daily" source with a 5-day lookback."""
    return DataSpec(
        source="daily", columns=["ts_code", "trade_date", "close"], lookback_days=5
    )


@pytest.fixture
def sample_factor_data():
    """Return a sample FactorData: two stocks over two trading days (4 rows)."""
    df = pl.DataFrame(
        {
            "ts_code": ["000001.SZ", "000002.SZ", "000001.SZ", "000002.SZ"],
            "trade_date": ["20240101", "20240101", "20240102", "20240102"],
            "close": [10.0, 20.0, 11.0, 21.0],
        }
    )
    context = FactorContext(current_date="20240102")
    return FactorData(df, context)


# ========== BaseFactor tests ==========


class TestBaseFactorValidation:
    """Tests for BaseFactor subclass validation.

    The failure cases only *define* a subclass inside ``pytest.raises`` and
    never instantiate it, so they expect validation to run at class-definition
    time.
    """

    def test_valid_cross_sectional_subclass(self, sample_dataspec):
        """A well-formed cross-sectional subclass passes validation."""

        class ValidFactor(CrossSectionalFactor):
            name = "valid_cs"
            data_specs = [sample_dataspec]

            def compute(self, data):
                return pl.Series([1.0])

        # Instantiation must succeed
        factor = ValidFactor()
        assert factor.name == "valid_cs"
        assert factor.factor_type == "cross_sectional"

    def test_valid_time_series_subclass(self, sample_dataspec):
        """A well-formed time-series subclass passes validation."""

        class ValidFactor(TimeSeriesFactor):
            name = "valid_ts"
            data_specs = [sample_dataspec]

            def compute(self, data):
                return pl.Series([1.0])

        factor = ValidFactor()
        assert factor.name == "valid_ts"
        assert factor.factor_type == "time_series"

    def test_missing_name(self, sample_dataspec):
        """A subclass without ``name`` raises ValueError."""
        with pytest.raises(ValueError, match="must define 'name'"):

            class BadFactor(CrossSectionalFactor):
                # `name` is intentionally not defined
                data_specs = [sample_dataspec]

                def compute(self, data):
                    return pl.Series([1.0])

    def test_empty_name(self, sample_dataspec):
        """A subclass whose ``name`` is an empty string raises ValueError."""
        with pytest.raises(ValueError, match="must define 'name'"):

            class BadFactor(CrossSectionalFactor):
                name = ""  # empty string
                data_specs = [sample_dataspec]

                def compute(self, data):
                    return pl.Series([1.0])

    def test_missing_factor_type(self, sample_dataspec):
        """A direct BaseFactor subclass without ``factor_type`` raises ValueError."""
        with pytest.raises(ValueError, match="must define 'factor_type'"):

            class BadFactor(BaseFactor):
                name = "bad_factor"
                # `factor_type` is intentionally not defined
                data_specs = [sample_dataspec]

                def compute(self, data):
                    return pl.Series([1.0])

    def test_invalid_factor_type(self, sample_dataspec):
        """A ``factor_type`` other than cs/ts raises ValueError."""
        with pytest.raises(
            ValueError, match="must be 'cross_sectional' or 'time_series'"
        ):

            class BadFactor(BaseFactor):
                name = "bad_factor"
                factor_type = "invalid_type"
                data_specs = [sample_dataspec]

                def compute(self, data):
                    return pl.Series([1.0])

    def test_missing_data_specs(self):
        """A subclass without ``data_specs`` raises ValueError."""
        with pytest.raises(ValueError, match="must define 'data_specs'"):

            class BadFactor(BaseFactor):
                name = "bad_factor"
                factor_type = "cross_sectional"
                # `data_specs` is intentionally not defined

                def compute(self, data):
                    return pl.Series([1.0])

    def test_empty_data_specs(self):
        """A subclass whose ``data_specs`` is an empty list raises ValueError."""
        with pytest.raises(ValueError, match="cannot be empty"):

            class BadFactor(CrossSectionalFactor):
                name = "bad_factor"
                data_specs = []  # empty list

                def compute(self, data):
                    return pl.Series([1.0])


class TestBaseFactorCompute:
    """Tests that ``compute()`` is enforced as an abstract method."""

    def test_compute_must_be_implemented_cs(self, sample_dataspec):
        """A cross-sectional subclass without compute() cannot be instantiated."""

        class BadFactor(CrossSectionalFactor):
            name = "bad_cs"
            data_specs = [sample_dataspec]
            # compute() is intentionally not implemented

        # Only the instantiation sits inside pytest.raises: defining the class
        # must succeed, so an error raised while creating the class cannot be
        # mistaken for the abstract-method TypeError.
        with pytest.raises(TypeError, match="Can't instantiate abstract class"):
            BadFactor()

    def test_compute_must_be_implemented_ts(self, sample_dataspec):
        """A time-series subclass without compute() cannot be instantiated."""

        class BadFactor(TimeSeriesFactor):
            name = "bad_ts"
            data_specs = [sample_dataspec]
            # compute() is intentionally not implemented

        # See test_compute_must_be_implemented_cs for why the class definition
        # is kept outside the raises block.
        with pytest.raises(TypeError, match="Can't instantiate abstract class"):
            BadFactor()


class TestBaseFactorParams:
    """Tests for parameterized initialization."""

    def test_params_stored_correctly(self, sample_dataspec):
        """Keyword arguments passed to super().__init__ are stored in ``params``."""

        class ParamFactor(CrossSectionalFactor):
            name = "param_factor"
            data_specs = [sample_dataspec]

            def __init__(self, period: int = 20, weight: float = 1.0):
                super().__init__(period=period, weight=weight)

            def compute(self, data):
                return pl.Series([1.0])

        factor = ParamFactor(period=10, weight=0.5)
        assert factor.params["period"] == 10
        assert factor.params["weight"] == 0.5

    def test_validate_params_called(self, sample_dataspec):
        """``_validate_params()`` is called exactly once during construction."""
        validated = []

        class ValidatedFactor(CrossSectionalFactor):
            name = "validated_factor"
            data_specs = [sample_dataspec]

            def __init__(self, period: int = 20):
                super().__init__(period=period)

            def _validate_params(self):
                validated.append(True)
                if self.params.get("period", 0) <= 0:
                    raise ValueError("period must be positive")

            def compute(self, data):
                return pl.Series([1.0])

        # Creating the instance should invoke _validate_params
        factor = ValidatedFactor(period=10)
        assert len(validated) == 1
        assert factor.params["period"] == 10

    def test_validate_params_raises(self, sample_dataspec):
        """An exception raised by ``_validate_params()`` propagates to the caller."""

        class BadParamFactor(CrossSectionalFactor):
            name = "bad_param_factor"
            data_specs = [sample_dataspec]

            def __init__(self, period: int = 20):
                super().__init__(period=period)

            def _validate_params(self):
                if self.params.get("period", 0) <= 0:
                    raise ValueError("period must be positive")

            def compute(self, data):
                return pl.Series([1.0])

        with pytest.raises(ValueError, match="period must be positive"):
            BadParamFactor(period=-5)


class TestBaseFactorRepr:
    """Tests for ``__repr__``."""

    def test_repr(self, sample_dataspec):
        """repr() mentions the class name, the factor name and the factor type."""

        class TestFactor(CrossSectionalFactor):
            name = "test_factor"
            data_specs = [sample_dataspec]

            def compute(self, data):
                return pl.Series([1.0])

        factor = TestFactor()
        repr_str = repr(factor)
        assert "TestFactor" in repr_str
        assert "test_factor" in repr_str
        assert "cross_sectional" in repr_str


# ========== CrossSectionalFactor tests ==========


class TestCrossSectionalFactor:
    """Tests for CrossSectionalFactor."""

    def test_factor_type_auto_set(self, sample_dataspec):
        """``factor_type`` is set automatically to 'cross_sectional'."""

        class CSFactor(CrossSectionalFactor):
            name = "cs_factor"
            data_specs = [sample_dataspec]

            def compute(self, data):
                return pl.Series([1.0])

        factor = CSFactor()
        assert factor.factor_type == "cross_sectional"

    def test_compute_returns_series(self, sample_factor_data, sample_dataspec):
        """``compute()`` returns a pl.Series."""

        class CSFactor(CrossSectionalFactor):
            name = "cs_factor"
            data_specs = [sample_dataspec]

            def compute(self, data):
                # Return a simple constant Series
                return pl.Series([1.0, 2.0])

        factor = CSFactor()
        result = factor.compute(sample_factor_data)
        assert isinstance(result, pl.Series)

    def test_compute_with_cross_section(self, sample_dataspec):
        """``compute()`` can rank values obtained via get_cross_section()."""
        df = pl.DataFrame(
            {
                "ts_code": ["000001.SZ", "000002.SZ"],
                "trade_date": ["20240101", "20240101"],
                "close": [10.0, 20.0],
            }
        )
        context = FactorContext(current_date="20240101")
        data = FactorData(df, context)

        class RankFactor(CrossSectionalFactor):
            name = "rank_factor"
            data_specs = [sample_dataspec]

            def compute(self, data):
                cs = data.get_cross_section()
                return cs["close"].rank()

        factor = RankFactor()
        result = factor.compute(data)
        assert isinstance(result, pl.Series)
        assert len(result) == 2
        # close 10.0 < 20.0, so the ranks are 1 and 2
        assert result.to_list() == [1.0, 2.0]


# ========== TimeSeriesFactor tests ==========


class TestTimeSeriesFactor:
    """Tests for TimeSeriesFactor."""

    def test_factor_type_auto_set(self, sample_dataspec):
        """``factor_type`` is set automatically to 'time_series'."""

        class TSFactor(TimeSeriesFactor):
            name = "ts_factor"
            data_specs = [sample_dataspec]

            def compute(self, data):
                return pl.Series([1.0])

        factor = TSFactor()
        assert factor.factor_type == "time_series"

    def test_compute_returns_series(self, sample_factor_data, sample_dataspec):
        """``compute()`` returns a pl.Series with one value per input row."""

        class TSFactor(TimeSeriesFactor):
            name = "ts_factor"
            data_specs = [sample_dataspec]

            def compute(self, data):
                return data.get_column("close") * 2

        factor = TSFactor()
        result = factor.compute(sample_factor_data)
        assert isinstance(result, pl.Series)
        assert len(result) == 4  # the fixture has 4 rows

    def test_compute_with_rolling(self, sample_dataspec):
        """``compute()`` can use rolling operations and honours ``period``."""
        df = pl.DataFrame(
            {
                "ts_code": ["000001.SZ"] * 5,
                "trade_date": [
                    "20240101",
                    "20240102",
                    "20240103",
                    "20240104",
                    "20240105",
                ],
                "close": [10.0, 11.0, 12.0, 13.0, 14.0],
            }
        )
        context = FactorContext(current_stock="000001.SZ")
        data = FactorData(df, context)

        class MAFactor(TimeSeriesFactor):
            name = "ma_factor"
            data_specs = [sample_dataspec]

            def __init__(self, period: int = 3):
                super().__init__(period=period)

            def compute(self, data):
                return data.get_column("close").rolling_mean(self.params["period"])

        factor = MAFactor(period=3)
        result = factor.compute(data)

        assert isinstance(result, pl.Series)
        assert len(result) == 5
        # The first 2 values should be null (because period=3)
        assert result[0] is None
        assert result[1] is None
        # Full windows: mean(10, 11, 12), mean(11, 12, 13), mean(12, 13, 14);
        # checking the values proves the window really has length `period`.
        assert result[2:].to_list() == pytest.approx([11.0, 12.0, 13.0])