refactor(training): 重构股票池管理 API 并更新训练流程
- 移除 StockFilterConfig/MarketCapSelectorConfig,改用 StockPoolManager + filter_func - Trainer 支持 train/val/test 三分法划分 - 更新 regression.ipynb 适配新 API - 删除已弃用的 test_selectors.py,后续补充 StockPoolManager 测试
This commit is contained in:
@@ -1,183 +0,0 @@
|
||||
"""测试股票池选择器配置
|
||||
|
||||
验证 StockFilterConfig 和 MarketCapSelectorConfig 功能。
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from src.training.components.selectors import (
|
||||
MarketCapSelectorConfig,
|
||||
StockFilterConfig,
|
||||
)
|
||||
|
||||
|
||||
class TestStockFilterConfig:
|
||||
"""StockFilterConfig 测试类"""
|
||||
|
||||
def test_default_values(self):
|
||||
"""测试默认值"""
|
||||
config = StockFilterConfig()
|
||||
assert config.exclude_cyb is True
|
||||
assert config.exclude_kcb is True
|
||||
assert config.exclude_bj is True
|
||||
assert config.exclude_st is True
|
||||
|
||||
def test_custom_values(self):
|
||||
"""测试自定义值"""
|
||||
config = StockFilterConfig(
|
||||
exclude_cyb=False,
|
||||
exclude_kcb=False,
|
||||
exclude_bj=False,
|
||||
exclude_st=False,
|
||||
)
|
||||
assert config.exclude_cyb is False
|
||||
assert config.exclude_kcb is False
|
||||
assert config.exclude_bj is False
|
||||
assert config.exclude_st is False
|
||||
|
||||
def test_filter_codes_exclude_all(self):
|
||||
"""测试排除所有类型"""
|
||||
config = StockFilterConfig(
|
||||
exclude_cyb=True,
|
||||
exclude_kcb=True,
|
||||
exclude_bj=True,
|
||||
exclude_st=True,
|
||||
)
|
||||
codes = [
|
||||
"000001.SZ", # 主板 - 保留
|
||||
"300001.SZ", # 创业板 - 排除
|
||||
"688001.SH", # 科创板 - 排除
|
||||
"830001.BJ", # 北交所(8开头)- 排除
|
||||
"430001.BJ", # 北交所(4开头)- 排除
|
||||
]
|
||||
result = config.filter_codes(codes)
|
||||
assert result == ["000001.SZ"]
|
||||
|
||||
def test_filter_codes_allow_cyb(self):
|
||||
"""测试允许创业板"""
|
||||
config = StockFilterConfig(
|
||||
exclude_cyb=False,
|
||||
exclude_kcb=True,
|
||||
exclude_bj=True,
|
||||
exclude_st=True,
|
||||
)
|
||||
codes = [
|
||||
"000001.SZ",
|
||||
"300001.SZ",
|
||||
"688001.SH",
|
||||
]
|
||||
result = config.filter_codes(codes)
|
||||
assert result == ["000001.SZ", "300001.SZ"]
|
||||
|
||||
def test_filter_codes_allow_kcb(self):
|
||||
"""测试允许科创板"""
|
||||
config = StockFilterConfig(
|
||||
exclude_cyb=True,
|
||||
exclude_kcb=False,
|
||||
exclude_bj=True,
|
||||
exclude_st=True,
|
||||
)
|
||||
codes = [
|
||||
"000001.SZ",
|
||||
"300001.SZ",
|
||||
"688001.SH",
|
||||
]
|
||||
result = config.filter_codes(codes)
|
||||
assert result == ["000001.SZ", "688001.SH"]
|
||||
|
||||
def test_filter_codes_allow_bj(self):
|
||||
"""测试允许北交所"""
|
||||
config = StockFilterConfig(
|
||||
exclude_cyb=True,
|
||||
exclude_kcb=True,
|
||||
exclude_bj=False,
|
||||
exclude_st=True,
|
||||
)
|
||||
codes = [
|
||||
"000001.SZ",
|
||||
"300001.SZ",
|
||||
"830001.BJ",
|
||||
"430001.BJ",
|
||||
]
|
||||
result = config.filter_codes(codes)
|
||||
assert result == ["000001.SZ", "830001.BJ", "430001.BJ"]
|
||||
|
||||
def test_filter_codes_allow_all(self):
|
||||
"""测试允许所有类型"""
|
||||
config = StockFilterConfig(
|
||||
exclude_cyb=False,
|
||||
exclude_kcb=False,
|
||||
exclude_bj=False,
|
||||
exclude_st=False,
|
||||
)
|
||||
codes = [
|
||||
"000001.SZ",
|
||||
"300001.SZ",
|
||||
"688001.SH",
|
||||
"830001.BJ",
|
||||
"430001.BJ",
|
||||
]
|
||||
result = config.filter_codes(codes)
|
||||
assert result == codes
|
||||
|
||||
def test_filter_codes_empty_list(self):
|
||||
"""测试空列表"""
|
||||
config = StockFilterConfig()
|
||||
result = config.filter_codes([])
|
||||
assert result == []
|
||||
|
||||
def test_filter_codes_no_matching(self):
|
||||
"""测试全部排除"""
|
||||
config = StockFilterConfig()
|
||||
codes = ["300001.SZ", "688001.SH", "830001.BJ"]
|
||||
result = config.filter_codes(codes)
|
||||
assert result == []
|
||||
|
||||
|
||||
class TestMarketCapSelectorConfig:
|
||||
"""MarketCapSelectorConfig 测试类"""
|
||||
|
||||
def test_default_values(self):
|
||||
"""测试默认值"""
|
||||
config = MarketCapSelectorConfig()
|
||||
assert config.enabled is True
|
||||
assert config.n == 100
|
||||
assert config.ascending is False
|
||||
assert config.market_cap_col == "total_mv"
|
||||
|
||||
def test_custom_values(self):
|
||||
"""测试自定义值"""
|
||||
config = MarketCapSelectorConfig(
|
||||
enabled=False,
|
||||
n=50,
|
||||
ascending=True,
|
||||
market_cap_col="circ_mv",
|
||||
)
|
||||
assert config.enabled is False
|
||||
assert config.n == 50
|
||||
assert config.ascending is True
|
||||
assert config.market_cap_col == "circ_mv"
|
||||
|
||||
def test_invalid_n_zero(self):
|
||||
"""测试无效的 n=0"""
|
||||
with pytest.raises(ValueError, match="n 必须是正整数"):
|
||||
MarketCapSelectorConfig(n=0)
|
||||
|
||||
def test_invalid_n_negative(self):
|
||||
"""测试无效的负数 n"""
|
||||
with pytest.raises(ValueError, match="n 必须是正整数"):
|
||||
MarketCapSelectorConfig(n=-1)
|
||||
|
||||
def test_invalid_empty_market_cap_col(self):
|
||||
"""测试空的市值列名"""
|
||||
with pytest.raises(ValueError, match="market_cap_col 不能为空"):
|
||||
MarketCapSelectorConfig(market_cap_col="")
|
||||
|
||||
def test_large_n(self):
|
||||
"""测试大的 n 值"""
|
||||
config = MarketCapSelectorConfig(n=5000)
|
||||
assert config.n == 5000
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
452
tests/training/test_stock_pool_manager.py
Normal file
452
tests/training/test_stock_pool_manager.py
Normal file
@@ -0,0 +1,452 @@
|
||||
"""测试 StockPoolManager
|
||||
|
||||
验证新的自定义函数和因子筛选功能。
|
||||
重点测试:临时因子隔离(只删除新生成的因子,保留原本存在的)。
|
||||
"""
|
||||
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import polars as pl
|
||||
import pytest
|
||||
|
||||
from src.training.core.stock_pool_manager import StockPoolManager
|
||||
|
||||
|
||||
class TestStockPoolManagerBasic:
|
||||
"""StockPoolManager 基础测试类"""
|
||||
|
||||
def test_basic_filter_with_columns(self):
|
||||
"""测试使用基础列进行筛选"""
|
||||
|
||||
def filter_func(df: pl.DataFrame) -> pl.Series:
|
||||
return df["total_mv"] > 50
|
||||
|
||||
# 创建模拟 data_router
|
||||
mock_router = Mock()
|
||||
|
||||
manager = StockPoolManager(
|
||||
filter_func=filter_func,
|
||||
required_columns=["total_mv"],
|
||||
data_router=mock_router,
|
||||
)
|
||||
|
||||
# 创建测试数据
|
||||
data = pl.DataFrame(
|
||||
{
|
||||
"ts_code": ["000001.SZ", "000002.SZ", "000003.SZ", "000004.SZ"],
|
||||
"trade_date": ["20240101"] * 4,
|
||||
"close": [10.0, 20.0, 30.0, 40.0],
|
||||
"total_mv": [100.0, 30.0, 80.0, 20.0],
|
||||
}
|
||||
)
|
||||
|
||||
# 执行筛选(无需 mock,因为 total_mv 已在数据中)
|
||||
result = manager.filter_and_select_daily(data)
|
||||
|
||||
# 验证返回数据列与输入一致
|
||||
assert result.columns == data.columns
|
||||
# 验证筛选生效(保留市值 > 50 的股票)
|
||||
assert len(result) == 2
|
||||
assert "000001.SZ" in result["ts_code"].to_list()
|
||||
assert "000003.SZ" in result["ts_code"].to_list()
|
||||
|
||||
def test_filter_without_required_columns(self):
|
||||
"""测试不使用额外列,仅使用输入数据中已有的列"""
|
||||
|
||||
def filter_func(df: pl.DataFrame) -> pl.Series:
|
||||
return df["close"] > 25
|
||||
|
||||
manager = StockPoolManager(filter_func=filter_func)
|
||||
|
||||
data = pl.DataFrame(
|
||||
{
|
||||
"ts_code": ["000001.SZ", "000002.SZ", "000003.SZ"],
|
||||
"trade_date": ["20240101"] * 3,
|
||||
"close": [10.0, 30.0, 20.0],
|
||||
}
|
||||
)
|
||||
|
||||
result = manager.filter_and_select_daily(data)
|
||||
|
||||
# 验证只保留 close > 25 的股票
|
||||
assert len(result) == 1
|
||||
assert result["ts_code"][0] == "000002.SZ"
|
||||
assert result.columns == data.columns
|
||||
|
||||
def test_empty_result(self):
|
||||
"""测试筛选结果为空的情况"""
|
||||
|
||||
def filter_func(df: pl.DataFrame) -> pl.Series:
|
||||
return df["close"] > 9999 # 不可能满足的条件
|
||||
|
||||
manager = StockPoolManager(filter_func=filter_func)
|
||||
|
||||
data = pl.DataFrame(
|
||||
{
|
||||
"ts_code": ["000001.SZ", "000002.SZ"],
|
||||
"trade_date": ["20240101"] * 2,
|
||||
"close": [10.0, 20.0],
|
||||
}
|
||||
)
|
||||
|
||||
result = manager.filter_and_select_daily(data)
|
||||
|
||||
assert len(result) == 0
|
||||
assert result.columns == data.columns # 即使为空,列结构保持一致
|
||||
|
||||
|
||||
class TestStockPoolManagerDailyIndependence:
|
||||
"""每日独立筛选测试类"""
|
||||
|
||||
def test_daily_independence(self):
|
||||
"""测试每日独立进行筛选"""
|
||||
|
||||
def filter_func(df: pl.DataFrame) -> pl.Series:
|
||||
# 每日选收盘价前 50%
|
||||
median = df["close"].median()
|
||||
return df["close"] >= median
|
||||
|
||||
manager = StockPoolManager(filter_func=filter_func)
|
||||
|
||||
# 创建多日期数据
|
||||
data = pl.DataFrame(
|
||||
{
|
||||
"ts_code": [
|
||||
"000001.SZ",
|
||||
"000002.SZ",
|
||||
"000003.SZ",
|
||||
"000004.SZ",
|
||||
# 日期 2
|
||||
"000001.SZ",
|
||||
"000002.SZ",
|
||||
"000003.SZ",
|
||||
"000004.SZ",
|
||||
],
|
||||
"trade_date": [
|
||||
"20240101",
|
||||
"20240101",
|
||||
"20240101",
|
||||
"20240101",
|
||||
"20240102",
|
||||
"20240102",
|
||||
"20240102",
|
||||
"20240102",
|
||||
],
|
||||
"close": [
|
||||
10.0,
|
||||
20.0,
|
||||
30.0,
|
||||
40.0, # 日期1:选 30, 40
|
||||
5.0,
|
||||
15.0,
|
||||
25.0,
|
||||
35.0, # 日期2:选 25, 35
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
result = manager.filter_and_select_daily(data)
|
||||
|
||||
# 验证每个日期独立筛选
|
||||
day1 = result.filter(pl.col("trade_date") == "20240101")
|
||||
day2 = result.filter(pl.col("trade_date") == "20240102")
|
||||
|
||||
# 日期1:收盘价 >= 25(中位数)- 30 和 40
|
||||
assert len(day1) == 2
|
||||
assert set(day1["ts_code"].to_list()) == {"000003.SZ", "000004.SZ"}
|
||||
|
||||
# 日期2:收盘价 >= 20(中位数)- 25 和 35
|
||||
assert len(day2) == 2
|
||||
assert set(day2["ts_code"].to_list()) == {"000003.SZ", "000004.SZ"}
|
||||
|
||||
def test_uneven_daily_distribution(self):
|
||||
"""测试每日股票数量不均的情况"""
|
||||
|
||||
def filter_func(df: pl.DataFrame) -> pl.Series:
|
||||
return df["close"] > 15
|
||||
|
||||
manager = StockPoolManager(filter_func=filter_func)
|
||||
|
||||
data = pl.DataFrame(
|
||||
{
|
||||
"ts_code": [
|
||||
"000001.SZ",
|
||||
"000002.SZ", # 日期1:2只股票
|
||||
"000001.SZ",
|
||||
"000002.SZ",
|
||||
"000003.SZ",
|
||||
"000004.SZ", # 日期2:4只股票
|
||||
],
|
||||
"trade_date": [
|
||||
"20240101",
|
||||
"20240101",
|
||||
"20240102",
|
||||
"20240102",
|
||||
"20240102",
|
||||
"20240102",
|
||||
],
|
||||
"close": [10.0, 20.0, 5.0, 15.0, 25.0, 35.0],
|
||||
}
|
||||
)
|
||||
|
||||
result = manager.filter_and_select_daily(data)
|
||||
|
||||
# 日期1:只有 000002.SZ (20 > 15)
|
||||
day1 = result.filter(pl.col("trade_date") == "20240101")
|
||||
assert len(day1) == 1
|
||||
|
||||
# 日期2:000003.SZ (25) 和 000004.SZ (35)
|
||||
day2 = result.filter(pl.col("trade_date") == "20240102")
|
||||
assert len(day2) == 2
|
||||
|
||||
|
||||
class TestStockPoolManagerFactorIsolation:
|
||||
"""因子隔离测试类 - 核心测试"""
|
||||
|
||||
@patch.object(StockPoolManager, "_compute_factors")
|
||||
def test_filter_with_factors(self, mock_compute):
|
||||
"""测试使用因子表达式进行筛选,验证临时因子被删除"""
|
||||
|
||||
# 设置 mock 返回值(包含计算后的因子)
|
||||
mock_compute.return_value = pl.DataFrame(
|
||||
{
|
||||
"ts_code": ["000001.SZ", "000002.SZ", "000003.SZ"],
|
||||
"trade_date": ["20240101"] * 3,
|
||||
"close": [11.0, 9.5, 10.8],
|
||||
"momentum_20": [0.1, -0.05, 0.08], # 只有第一个 > 0.05
|
||||
}
|
||||
)
|
||||
|
||||
# 创建 Manager
|
||||
def filter_func(df: pl.DataFrame) -> pl.Series:
|
||||
return df["momentum_20"] > 0.05
|
||||
|
||||
manager = StockPoolManager(
|
||||
filter_func=filter_func,
|
||||
required_factors={"momentum_20": "(close / ts_delay(close, 20)) - 1"},
|
||||
)
|
||||
|
||||
# 输入数据不含 momentum_20
|
||||
data = pl.DataFrame(
|
||||
{
|
||||
"ts_code": ["000001.SZ", "000002.SZ", "000003.SZ"],
|
||||
"trade_date": ["20240101"] * 3,
|
||||
"close": [11.0, 9.5, 10.8],
|
||||
}
|
||||
)
|
||||
|
||||
result = manager.filter_and_select_daily(data)
|
||||
|
||||
# 验证返回数据列与输入一致(momentum_20 被删除)
|
||||
assert result.columns == data.columns
|
||||
assert "momentum_20" not in result.columns
|
||||
|
||||
# 验证筛选生效
|
||||
# momentum_20 > 0.05: 000001.SZ (0.1), 000003.SZ (0.08)
|
||||
assert len(result) == 2
|
||||
assert "000001.SZ" in result["ts_code"].to_list()
|
||||
assert "000003.SZ" in result["ts_code"].to_list()
|
||||
|
||||
# 验证 _compute_factors 被调用
|
||||
mock_compute.assert_called_once()
|
||||
|
||||
@patch.object(StockPoolManager, "_compute_factors")
|
||||
def test_preserve_existing_factors(self, mock_compute):
|
||||
"""测试输入中已存在的因子不会被删除(核心测试)"""
|
||||
|
||||
# 设置 mock 返回值(包含 roe,但 momentum_20 已在输入中)
|
||||
mock_compute.return_value = pl.DataFrame(
|
||||
{
|
||||
"ts_code": ["000001.SZ", "000002.SZ", "000003.SZ"],
|
||||
"trade_date": ["20240101"] * 3,
|
||||
"close": [11.0, 9.5, 10.8],
|
||||
"momentum_20": [0.1, -0.05, 0.08], # 原本就存在
|
||||
"roe": [0.12, 0.08, 0.15], # 本次生成,第二个 < 0.1
|
||||
}
|
||||
)
|
||||
|
||||
def filter_func(df: pl.DataFrame) -> pl.Series:
|
||||
return (df["momentum_20"] > 0.05) & (df["roe"] > 0.1)
|
||||
|
||||
manager = StockPoolManager(
|
||||
filter_func=filter_func,
|
||||
# 只声明 roe 为本次生成,momentum_20 已在输入中
|
||||
required_factors={"roe": "n_income / equity"},
|
||||
)
|
||||
|
||||
# 输入数据已包含 momentum_20
|
||||
data = pl.DataFrame(
|
||||
{
|
||||
"ts_code": ["000001.SZ", "000002.SZ", "000003.SZ"],
|
||||
"trade_date": ["20240101"] * 3,
|
||||
"close": [11.0, 9.5, 10.8],
|
||||
"momentum_20": [0.1, -0.05, 0.08], # 原本就存在的因子
|
||||
}
|
||||
)
|
||||
|
||||
result = manager.filter_and_select_daily(data)
|
||||
|
||||
# 关键断言:
|
||||
# 1. momentum_20 保留(原本存在)
|
||||
assert "momentum_20" in result.columns
|
||||
# 2. roe 删除(本次生成)
|
||||
assert "roe" not in result.columns
|
||||
# 3. 列与输入完全一致
|
||||
assert result.columns == data.columns
|
||||
|
||||
# 验证筛选正确执行
|
||||
# momentum_20 > 0.05: 000001.SZ (0.1), 000003.SZ (0.08)
|
||||
# roe > 0.1: 000001.SZ (0.12), 000003.SZ (0.15)
|
||||
# 交集:000001.SZ, 000003.SZ
|
||||
assert len(result) == 2
|
||||
|
||||
# 验证 _compute_factors 被调用(因为 roe 不存在)
|
||||
mock_compute.assert_called_once()
|
||||
|
||||
def test_no_factor_computation_when_all_exist(self):
|
||||
"""测试所有因子都已存在时,不调用 FactorEngine"""
|
||||
|
||||
def filter_func(df: pl.DataFrame) -> pl.Series:
|
||||
return (df["factor_a"] > 0.5) & (df["factor_b"] < 0.3)
|
||||
|
||||
manager = StockPoolManager(
|
||||
filter_func=filter_func,
|
||||
required_factors={
|
||||
"factor_a": "some_expr_a",
|
||||
"factor_b": "some_expr_b",
|
||||
},
|
||||
)
|
||||
|
||||
# 输入数据已包含所有因子
|
||||
data = pl.DataFrame(
|
||||
{
|
||||
"ts_code": ["000001.SZ", "000002.SZ"],
|
||||
"trade_date": ["20240101"] * 2,
|
||||
"close": [10.0, 20.0],
|
||||
"factor_a": [0.6, 0.4], # 原本存在
|
||||
"factor_b": [0.2, 0.5], # 原本存在
|
||||
}
|
||||
)
|
||||
|
||||
with patch("src.factors.engine.factor_engine.FactorEngine") as mock_engine:
|
||||
result = manager.filter_and_select_daily(data)
|
||||
|
||||
# FactorEngine 不应被调用(所有因子都已存在)
|
||||
mock_engine.assert_not_called()
|
||||
|
||||
# 验证结果正确
|
||||
assert len(result) == 1
|
||||
assert result["ts_code"][0] == "000001.SZ"
|
||||
assert result.columns == data.columns
|
||||
|
||||
|
||||
class TestStockPoolManagerEdgeCases:
|
||||
"""边界情况测试类"""
|
||||
|
||||
def test_single_date(self):
|
||||
"""测试单日数据"""
|
||||
|
||||
def filter_func(df: pl.DataFrame) -> pl.Series:
|
||||
return df["close"] > 15
|
||||
|
||||
manager = StockPoolManager(filter_func=filter_func)
|
||||
|
||||
data = pl.DataFrame(
|
||||
{
|
||||
"ts_code": ["000001.SZ", "000002.SZ", "000003.SZ"],
|
||||
"trade_date": ["20240101"] * 3,
|
||||
"close": [10.0, 20.0, 30.0],
|
||||
}
|
||||
)
|
||||
|
||||
result = manager.filter_and_select_daily(data)
|
||||
|
||||
assert len(result) == 2
|
||||
assert result.columns == data.columns
|
||||
|
||||
def test_single_stock_per_day(self):
|
||||
"""测试每天只有一只股票"""
|
||||
|
||||
def filter_func(df: pl.DataFrame) -> pl.Series:
|
||||
return df["close"] > 0 # 都保留
|
||||
|
||||
manager = StockPoolManager(filter_func=filter_func)
|
||||
|
||||
data = pl.DataFrame(
|
||||
{
|
||||
"ts_code": ["000001.SZ", "000002.SZ", "000003.SZ"],
|
||||
"trade_date": ["20240101", "20240102", "20240103"],
|
||||
"close": [10.0, 20.0, 30.0],
|
||||
}
|
||||
)
|
||||
|
||||
result = manager.filter_and_select_daily(data)
|
||||
|
||||
assert len(result) == 3
|
||||
assert result.columns == data.columns
|
||||
|
||||
def test_filter_all_out_one_day(self):
|
||||
"""测试某天全部过滤掉"""
|
||||
|
||||
def filter_func(df: pl.DataFrame) -> pl.Series:
|
||||
return df["close"] > 100 # 很高的阈值
|
||||
|
||||
manager = StockPoolManager(filter_func=filter_func)
|
||||
|
||||
data = pl.DataFrame(
|
||||
{
|
||||
"ts_code": [
|
||||
"000001.SZ",
|
||||
"000002.SZ",
|
||||
"000001.SZ",
|
||||
"000002.SZ",
|
||||
],
|
||||
"trade_date": [
|
||||
"20240101",
|
||||
"20240101",
|
||||
"20240102",
|
||||
"20240102",
|
||||
],
|
||||
"close": [
|
||||
10.0,
|
||||
20.0, # 日期1:都过滤掉
|
||||
150.0,
|
||||
200.0, # 日期2:都保留
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
result = manager.filter_and_select_daily(data)
|
||||
|
||||
# 只有日期2的数据
|
||||
assert len(result) == 2
|
||||
assert all(result["trade_date"] == "20240102")
|
||||
|
||||
def test_column_order_preserved(self):
|
||||
"""测试列顺序保持不变"""
|
||||
|
||||
def filter_func(df: pl.DataFrame) -> pl.Series:
|
||||
return df["close"] > 15
|
||||
|
||||
manager = StockPoolManager(filter_func=filter_func)
|
||||
|
||||
data = pl.DataFrame(
|
||||
{
|
||||
"ts_code": ["000001.SZ", "000002.SZ", "000003.SZ"],
|
||||
"trade_date": ["20240101"] * 3,
|
||||
"open": [9.0, 19.0, 29.0],
|
||||
"high": [11.0, 21.0, 31.0],
|
||||
"low": [8.0, 18.0, 28.0],
|
||||
"close": [10.0, 20.0, 30.0],
|
||||
"volume": [1000, 2000, 3000],
|
||||
}
|
||||
)
|
||||
|
||||
result = manager.filter_and_select_daily(data)
|
||||
|
||||
# 验证列顺序完全一致
|
||||
assert list(result.columns) == list(data.columns)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
Reference in New Issue
Block a user