"""测试股票池选择器配置 验证 StockFilterConfig 和 MarketCapSelectorConfig 功能。 """ import pytest from src.training.components.selectors import ( MarketCapSelectorConfig, StockFilterConfig, ) class TestStockFilterConfig: """StockFilterConfig 测试类""" def test_default_values(self): """测试默认值""" config = StockFilterConfig() assert config.exclude_cyb is True assert config.exclude_kcb is True assert config.exclude_bj is True assert config.exclude_st is True def test_custom_values(self): """测试自定义值""" config = StockFilterConfig( exclude_cyb=False, exclude_kcb=False, exclude_bj=False, exclude_st=False, ) assert config.exclude_cyb is False assert config.exclude_kcb is False assert config.exclude_bj is False assert config.exclude_st is False def test_filter_codes_exclude_all(self): """测试排除所有类型""" config = StockFilterConfig( exclude_cyb=True, exclude_kcb=True, exclude_bj=True, exclude_st=True, ) codes = [ "000001.SZ", # 主板 - 保留 "300001.SZ", # 创业板 - 排除 "688001.SH", # 科创板 - 排除 "830001.BJ", # 北交所(8开头)- 排除 "430001.BJ", # 北交所(4开头)- 排除 ] result = config.filter_codes(codes) assert result == ["000001.SZ"] def test_filter_codes_allow_cyb(self): """测试允许创业板""" config = StockFilterConfig( exclude_cyb=False, exclude_kcb=True, exclude_bj=True, exclude_st=True, ) codes = [ "000001.SZ", "300001.SZ", "688001.SH", ] result = config.filter_codes(codes) assert result == ["000001.SZ", "300001.SZ"] def test_filter_codes_allow_kcb(self): """测试允许科创板""" config = StockFilterConfig( exclude_cyb=True, exclude_kcb=False, exclude_bj=True, exclude_st=True, ) codes = [ "000001.SZ", "300001.SZ", "688001.SH", ] result = config.filter_codes(codes) assert result == ["000001.SZ", "688001.SH"] def test_filter_codes_allow_bj(self): """测试允许北交所""" config = StockFilterConfig( exclude_cyb=True, exclude_kcb=True, exclude_bj=False, exclude_st=True, ) codes = [ "000001.SZ", "300001.SZ", "830001.BJ", "430001.BJ", ] result = config.filter_codes(codes) assert result == ["000001.SZ", "830001.BJ", "430001.BJ"] def test_filter_codes_allow_all(self): """测试允许所有类型""" config = StockFilterConfig( exclude_cyb=False, exclude_kcb=False, exclude_bj=False, exclude_st=False, ) codes = [ "000001.SZ", "300001.SZ", "688001.SH", "830001.BJ", "430001.BJ", ] result = config.filter_codes(codes) assert result == codes def test_filter_codes_empty_list(self): """测试空列表""" config = StockFilterConfig() result = config.filter_codes([]) assert result == [] def test_filter_codes_no_matching(self): """测试全部排除""" config = StockFilterConfig() codes = ["300001.SZ", "688001.SH", "830001.BJ"] result = config.filter_codes(codes) assert result == [] class TestMarketCapSelectorConfig: """MarketCapSelectorConfig 测试类""" def test_default_values(self): """测试默认值""" config = MarketCapSelectorConfig() assert config.enabled is True assert config.n == 100 assert config.ascending is False assert config.market_cap_col == "total_mv" def test_custom_values(self): """测试自定义值""" config = MarketCapSelectorConfig( enabled=False, n=50, ascending=True, market_cap_col="circ_mv", ) assert config.enabled is False assert config.n == 50 assert config.ascending is True assert config.market_cap_col == "circ_mv" def test_invalid_n_zero(self): """测试无效的 n=0""" with pytest.raises(ValueError, match="n 必须是正整数"): MarketCapSelectorConfig(n=0) def test_invalid_n_negative(self): """测试无效的负数 n""" with pytest.raises(ValueError, match="n 必须是正整数"): MarketCapSelectorConfig(n=-1) def test_invalid_empty_market_cap_col(self): """测试空的市值列名""" with pytest.raises(ValueError, match="market_cap_col 不能为空"): MarketCapSelectorConfig(market_cap_col="") def test_large_n(self): """测试大的 n 值""" config = MarketCapSelectorConfig(n=5000) assert config.n == 5000 if __name__ == "__main__": pytest.main([__file__, "-v"])