refactor: 存储层迁移DuckDB + 模块重构

- 存储层重构: HDF5 → DuckDB(UPSERT模式、线程安全存储)
- Sync类迁移: DataSync从sync.py迁移到api_daily.py(职责分离)
- 模型模块重构: src/models → src/pipeline(更清晰的命名)
- 新增因子模块: factors/momentum (MA、收益率排名)、factors/financial
- 新增API接口: api_namechange、api_bak_basic
- 新增训练入口: training模块(main.py、pipeline配置)
- 工具函数统一: get_today_date等移至utils.py
- 文档更新: AGENTS.md添加架构变更历史
This commit is contained in:
2026-02-23 16:23:53 +08:00
parent 9f95be56a0
commit 593ec99466
32 changed files with 4181 additions and 1395 deletions

View File

@@ -1,4 +1,4 @@
"""模型框架核心测试
"""Pipeline 组件库核心测试
测试核心抽象类、插件注册中心、处理器、模型和划分策略
"""
@@ -9,7 +9,7 @@ import numpy as np
from typing import List, Optional
# 确保导入时注册所有组件
from src.models import (
from src.pipeline import (
PluginRegistry,
PipelineStage,
BaseProcessor,
@@ -17,7 +17,7 @@ from src.models import (
BaseSplitter,
ProcessingPipeline,
)
from src.models.core import TaskType
from src.pipeline.core import TaskType
# ========== 测试核心抽象类 ==========
@@ -232,7 +232,7 @@ class TestBuiltInProcessors:
def test_dropna_processor(self):
"""测试缺失值删除处理器"""
from src.models.processors import DropNAProcessor
from src.pipeline.processors import DropNAProcessor
processor = DropNAProcessor(columns=["a", "b"])
df = pl.DataFrame({"a": [1, None, 3], "b": [4, 5, None], "c": [7, 8, 9]})
@@ -246,7 +246,7 @@ class TestBuiltInProcessors:
def test_fillna_processor(self):
"""测试缺失值填充处理器"""
from src.models.processors import FillNAProcessor
from src.pipeline.processors import FillNAProcessor
processor = FillNAProcessor(columns=["a"], method="mean")
df = pl.DataFrame({"a": [1.0, 2.0, None, 4.0]})
@@ -258,7 +258,7 @@ class TestBuiltInProcessors:
def test_standard_scaler(self):
"""测试标准化处理器"""
from src.models.processors import StandardScaler
from src.pipeline.processors import StandardScaler
processor = StandardScaler(columns=["value"])
df = pl.DataFrame({"value": [1.0, 2.0, 3.0, 4.0, 5.0]})
@@ -271,7 +271,7 @@ class TestBuiltInProcessors:
def test_winsorizer(self):
"""测试缩尾处理器"""
from src.models.processors import Winsorizer
from src.pipeline.processors import Winsorizer
processor = Winsorizer(columns=["value"], lower=0.1, upper=0.9)
df = pl.DataFrame(
@@ -288,7 +288,7 @@ class TestBuiltInProcessors:
def test_rank_transformer(self):
"""测试排名转换处理器"""
from src.models.processors import RankTransformer
from src.pipeline.processors import RankTransformer
processor = RankTransformer(columns=["value"])
df = pl.DataFrame(
@@ -302,7 +302,7 @@ class TestBuiltInProcessors:
def test_neutralizer(self):
"""测试中性化处理器"""
from src.models.processors import Neutralizer
from src.pipeline.processors import Neutralizer
processor = Neutralizer(columns=["value"], group_col="industry")
df = pl.DataFrame(
@@ -331,7 +331,7 @@ class TestProcessingPipeline:
def test_pipeline_fit_transform(self):
"""测试流水线的 fit_transform"""
from src.models.processors import StandardScaler
from src.pipeline.processors import StandardScaler
scaler1 = StandardScaler(columns=["a"])
scaler2 = StandardScaler(columns=["b"])
@@ -348,7 +348,7 @@ class TestProcessingPipeline:
def test_pipeline_transform_uses_fitted_params(self):
"""测试 transform 使用已 fit 的参数"""
from src.models.processors import StandardScaler
from src.pipeline.processors import StandardScaler
scaler = StandardScaler(columns=["value"])
pipeline = ProcessingPipeline([scaler])
@@ -383,7 +383,7 @@ class TestSplitters:
def test_time_series_split(self):
"""测试时间序列划分"""
from src.models.core import TimeSeriesSplit
from src.pipeline.core import TimeSeriesSplit
splitter = TimeSeriesSplit(n_splits=2, gap=1, min_train_size=3)
@@ -406,7 +406,7 @@ class TestSplitters:
def test_walk_forward_split(self):
"""测试滚动前向划分"""
from src.models.core import WalkForwardSplit
from src.pipeline.core import WalkForwardSplit
splitter = WalkForwardSplit(train_window=5, test_window=2, gap=1)
@@ -426,7 +426,7 @@ class TestSplitters:
def test_expanding_window_split(self):
"""测试扩展窗口划分"""
from src.models.core import ExpandingWindowSplit
from src.pipeline.core import ExpandingWindowSplit
splitter = ExpandingWindowSplit(initial_train_size=3, test_window=2, gap=1)
@@ -455,7 +455,7 @@ class TestModels:
@pytest.mark.skip(reason="需要安装 lightgbm")
def test_lightgbm_model(self):
"""测试 LightGBM 模型"""
from src.models.models import LightGBMModel
from src.pipeline.models import LightGBMModel
model = LightGBMModel(task_type="regression", params={"n_estimators": 10})

View File

@@ -1,277 +1,163 @@
"""Tests for data synchronization module.
"""Sync 接口测试规范与实现。
Tests the sync module's full/incremental sync logic for daily data:
- Full sync when local data doesn't exist (from 20180101)
- Incremental sync when local data exists (from last_date + 1)
- Data integrity validation
【测试规范】
1. 所有 sync 测试只使用 2018-01-01 到 2018-04-01 的数据
2. 只测试接口是否能正常返回数据,不测试落库逻辑
3. 对于按股票查询的接口,只测试 000001.SZ、000002.SZ 两支股票
4. 使用真实 API 调用,确保接口可用性
【测试范围】
- get_daily: 日线数据接口(按股票)
- sync_all_stocks: 股票基础信息接口
- sync_trade_cal_cache: 交易日历接口
- sync_namechange: 名称变更接口
- sync_bak_basic: 备用股票基础信息接口
"""
import pytest
import pandas as pd
from unittest.mock import Mock, patch, MagicMock
from datetime import datetime, timedelta
from datetime import datetime
from src.data.sync import (
DataSync,
sync_all,
get_today_date,
get_next_date,
DEFAULT_START_DATE,
)
from src.data.storage import ThreadSafeStorage
from src.data.client import TushareClient
# 测试用常量
TEST_START_DATE = "20180101"
TEST_END_DATE = "20180401"
TEST_STOCK_CODES = ["000001.SZ", "000002.SZ"]
@pytest.fixture
def mock_storage():
"""Create a mock storage instance."""
storage = Mock(spec=ThreadSafeStorage)
storage.exists = Mock(return_value=False)
storage.load = Mock(return_value=pd.DataFrame())
storage.save = Mock(return_value={"status": "success", "rows": 0})
return storage
class TestGetDaily:
"""测试日线数据 get 接口(按股票查询)."""
def test_get_daily_single_stock(self):
"""测试 get_daily 获取单只股票数据."""
from src.data.api_wrappers.api_daily import get_daily
@pytest.fixture
def mock_client():
"""Create a mock client instance."""
return Mock(spec=TushareClient)
result = get_daily(
ts_code=TEST_STOCK_CODES[0],
start_date=TEST_START_DATE,
end_date=TEST_END_DATE,
)
# 验证返回了数据
assert isinstance(result, pd.DataFrame), "get_daily 应返回 DataFrame"
assert not result.empty, "get_daily 应返回非空数据"
class TestDateUtilities:
"""Test date utility functions."""
def test_get_daily_has_required_columns(self):
"""测试 get_daily 返回的数据包含必要字段."""
from src.data.api_wrappers.api_daily import get_daily
def test_get_today_date_format(self):
"""Test today date is in YYYYMMDD format."""
result = get_today_date()
assert len(result) == 8
assert result.isdigit()
result = get_daily(
ts_code=TEST_STOCK_CODES[0],
start_date=TEST_START_DATE,
end_date=TEST_END_DATE,
)
def test_get_next_date(self):
"""Test getting next date."""
result = get_next_date("20240101")
assert result == "20240102"
# 验证必要的列存在
required_columns = ["ts_code", "trade_date", "open", "high", "low", "close"]
for col in required_columns:
assert col in result.columns, f"get_daily 返回应包含 {col}"
def test_get_next_date_year_end(self):
"""Test getting next date across year boundary."""
result = get_next_date("20241231")
assert result == "20250101"
def test_get_daily_multiple_stocks(self):
"""测试 get_daily 获取多只股票数据."""
from src.data.api_wrappers.api_daily import get_daily
def test_get_next_date_month_end(self):
"""Test getting next date across month boundary."""
result = get_next_date("20240131")
assert result == "20240201"
class TestDataSync:
"""Test DataSync class functionality."""
def test_get_all_stock_codes_from_daily(self, mock_storage):
"""Test getting stock codes from daily data."""
with patch("src.data.sync.ThreadSafeStorage", return_value=mock_storage):
sync = DataSync()
sync.storage = mock_storage
mock_storage.load.return_value = pd.DataFrame(
{
"ts_code": ["000001.SZ", "000001.SZ", "600000.SH"],
}
results = {}
for code in TEST_STOCK_CODES:
result = get_daily(
ts_code=code,
start_date=TEST_START_DATE,
end_date=TEST_END_DATE,
)
codes = sync.get_all_stock_codes()
assert len(codes) == 2
assert "000001.SZ" in codes
assert "600000.SH" in codes
def test_get_all_stock_codes_fallback(self, mock_storage):
"""Test fallback to stock_basic when daily is empty."""
with patch("src.data.sync.ThreadSafeStorage", return_value=mock_storage):
sync = DataSync()
sync.storage = mock_storage
# First call (daily) returns empty, second call (stock_basic) returns data
mock_storage.load.side_effect = [
pd.DataFrame(), # daily empty
pd.DataFrame({"ts_code": ["000001.SZ", "600000.SH"]}), # stock_basic
]
codes = sync.get_all_stock_codes()
assert len(codes) == 2
def test_get_global_last_date(self, mock_storage):
"""Test getting global last date."""
with patch("src.data.sync.ThreadSafeStorage", return_value=mock_storage):
sync = DataSync()
sync.storage = mock_storage
mock_storage.load.return_value = pd.DataFrame(
{
"ts_code": ["000001.SZ", "600000.SH"],
"trade_date": ["20240102", "20240103"],
}
results[code] = result
assert isinstance(result, pd.DataFrame), (
f"get_daily({code}) 应返回 DataFrame"
)
last_date = sync.get_global_last_date()
assert last_date == "20240103"
def test_get_global_last_date_empty(self, mock_storage):
"""Test getting last date from empty storage."""
with patch("src.data.sync.ThreadSafeStorage", return_value=mock_storage):
sync = DataSync()
sync.storage = mock_storage
mock_storage.load.return_value = pd.DataFrame()
last_date = sync.get_global_last_date()
assert last_date is None
def test_sync_single_stock(self, mock_storage):
"""Test syncing a single stock."""
with patch("src.data.sync.ThreadSafeStorage", return_value=mock_storage):
with patch(
"src.data.sync.get_daily",
return_value=pd.DataFrame(
{
"ts_code": ["000001.SZ"],
"trade_date": ["20240102"],
}
),
):
sync = DataSync()
sync.storage = mock_storage
result = sync.sync_single_stock(
ts_code="000001.SZ",
start_date="20240101",
end_date="20240102",
)
assert isinstance(result, pd.DataFrame)
assert len(result) == 1
def test_sync_single_stock_empty(self, mock_storage):
"""Test syncing a stock with no data."""
with patch("src.data.sync.ThreadSafeStorage", return_value=mock_storage):
with patch("src.data.sync.get_daily", return_value=pd.DataFrame()):
sync = DataSync()
sync.storage = mock_storage
result = sync.sync_single_stock(
ts_code="INVALID.SZ",
start_date="20240101",
end_date="20240102",
)
assert result.empty
assert not result.empty, f"get_daily({code}) 应返回非空数据"
class TestSyncAll:
"""Test sync_all function."""
class TestSyncStockBasic:
"""测试股票基础信息 sync 接口."""
def test_full_sync_mode(self, mock_storage):
"""Test full sync mode when force_full=True."""
with patch("src.data.sync.ThreadSafeStorage", return_value=mock_storage):
with patch("src.data.sync.get_daily", return_value=pd.DataFrame()):
sync = DataSync()
sync.storage = mock_storage
sync.sync_single_stock = Mock(return_value=pd.DataFrame())
def test_sync_all_stocks_returns_data(self):
"""测试 sync_all_stocks 是否能正常返回数据."""
from src.data.api_wrappers.api_stock_basic import sync_all_stocks
mock_storage.load.return_value = pd.DataFrame(
{
"ts_code": ["000001.SZ"],
}
)
result = sync_all_stocks()
result = sync.sync_all(force_full=True)
# 验证返回了数据
assert isinstance(result, pd.DataFrame), "sync_all_stocks 应返回 DataFrame"
assert not result.empty, "sync_all_stocks 应返回非空数据"
# Verify sync_single_stock was called with default start date
sync.sync_single_stock.assert_called_once()
call_args = sync.sync_single_stock.call_args
assert call_args[1]["start_date"] == DEFAULT_START_DATE
def test_sync_all_stocks_has_required_columns(self):
"""测试 sync_all_stocks 返回的数据包含必要字段."""
from src.data.api_wrappers.api_stock_basic import sync_all_stocks
def test_incremental_sync_mode(self, mock_storage):
"""Test incremental sync mode when data exists."""
with patch("src.data.sync.ThreadSafeStorage", return_value=mock_storage):
sync = DataSync()
sync.storage = mock_storage
sync.sync_single_stock = Mock(return_value=pd.DataFrame())
result = sync_all_stocks()
# Mock existing data with last date
mock_storage.load.side_effect = [
pd.DataFrame(
{
"ts_code": ["000001.SZ"],
"trade_date": ["20240102"],
}
), # get_all_stock_codes
pd.DataFrame(
{
"ts_code": ["000001.SZ"],
"trade_date": ["20240102"],
}
), # get_global_last_date
]
result = sync.sync_all(force_full=False)
# Verify sync_single_stock was called with next date
sync.sync_single_stock.assert_called_once()
call_args = sync.sync_single_stock.call_args
assert call_args[1]["start_date"] == "20240103"
def test_manual_start_date(self, mock_storage):
"""Test sync with manual start date."""
with patch("src.data.sync.ThreadSafeStorage", return_value=mock_storage):
sync = DataSync()
sync.storage = mock_storage
sync.sync_single_stock = Mock(return_value=pd.DataFrame())
mock_storage.load.return_value = pd.DataFrame(
{
"ts_code": ["000001.SZ"],
}
)
result = sync.sync_all(force_full=False, start_date="20230601")
sync.sync_single_stock.assert_called_once()
call_args = sync.sync_single_stock.call_args
assert call_args[1]["start_date"] == "20230601"
def test_no_stocks_found(self, mock_storage):
"""Test sync when no stocks are found."""
with patch("src.data.sync.ThreadSafeStorage", return_value=mock_storage):
sync = DataSync()
sync.storage = mock_storage
mock_storage.load.return_value = pd.DataFrame()
result = sync.sync_all()
assert result == {}
# 验证必要的列存在
required_columns = ["ts_code"]
for col in required_columns:
assert col in result.columns, f"sync_all_stocks 返回应包含 {col}"
class TestSyncAllConvenienceFunction:
"""Test sync_all convenience function."""
class TestSyncTradeCal:
"""测试交易日历 sync 接口."""
def test_sync_all_function(self):
"""Test sync_all convenience function."""
with patch("src.data.sync.DataSync") as MockSync:
mock_instance = Mock()
mock_instance.sync_all.return_value = {}
MockSync.return_value = mock_instance
def test_sync_trade_cal_cache_returns_data(self):
"""测试 sync_trade_cal_cache 是否能正常返回数据."""
from src.data.api_wrappers.api_trade_cal import sync_trade_cal_cache
result = sync_all(force_full=True)
result = sync_trade_cal_cache(
start_date=TEST_START_DATE,
end_date=TEST_END_DATE,
)
MockSync.assert_called_once()
mock_instance.sync_all.assert_called_once_with(
force_full=True,
start_date=None,
end_date=None,
dry_run=False,
)
# 验证返回了数据
assert isinstance(result, pd.DataFrame), "sync_trade_cal_cache 应返回 DataFrame"
assert not result.empty, "sync_trade_cal_cache 应返回非空数据"
def test_sync_trade_cal_cache_has_required_columns(self):
"""测试 sync_trade_cal_cache 返回的数据包含必要字段."""
from src.data.api_wrappers.api_trade_cal import sync_trade_cal_cache
result = sync_trade_cal_cache(
start_date=TEST_START_DATE,
end_date=TEST_END_DATE,
)
# 验证必要的列存在
required_columns = ["cal_date", "is_open"]
for col in required_columns:
assert col in result.columns, f"sync_trade_cal_cache 返回应包含 {col}"
class TestSyncNamechange:
"""测试名称变更 sync 接口."""
def test_sync_namechange_returns_data(self):
"""测试 sync_namechange 是否能正常返回数据."""
from src.data.api_wrappers.api_namechange import sync_namechange
result = sync_namechange()
# 验证返回了数据(可能是空 DataFrame,因为是历史变更)
assert isinstance(result, pd.DataFrame), "sync_namechange 应返回 DataFrame"
class TestSyncBakBasic:
"""测试备用股票基础信息 sync 接口."""
def test_sync_bak_basic_returns_data(self):
"""测试 sync_bak_basic 是否能正常返回数据."""
from src.data.api_wrappers.api_bak_basic import sync_bak_basic
result = sync_bak_basic(
start_date=TEST_START_DATE,
end_date=TEST_END_DATE,
)
# 验证返回了数据
assert isinstance(result, pd.DataFrame), "sync_bak_basic 应返回 DataFrame"
# 注意:bak_basic 可能返回空数据,这是正常的
if __name__ == "__main__":

View File

@@ -1,256 +0,0 @@
"""Tests for data sync with REAL data (read-only).
Tests verify:
1. get_global_last_date() correctly reads local data's max date
2. Incremental sync date calculation (local_last_date + 1)
3. Full sync date calculation (20180101)
4. Multi-stock scenario with real data
⚠️ IMPORTANT: These tests ONLY read data, no write operations.
- NO sync_all() calls (writes daily.h5)
- NO check_sync_needed() calls (writes trade_cal.h5)
"""
import pytest
import pandas as pd
from pathlib import Path
from src.data.sync import (
DataSync,
get_next_date,
DEFAULT_START_DATE,
)
from src.data.storage import Storage
class TestDataSyncReadOnly:
    """Read-only tests for data sync - verify date calculation logic.

    These tests only READ local storage (no sync_all / no writes); they are
    skipped when the local ``daily`` dataset does not exist yet.
    """

    @pytest.fixture
    def storage(self):
        """Create storage instance."""
        return Storage()

    @pytest.fixture
    def data_sync(self):
        """Create DataSync instance."""
        return DataSync()

    @pytest.fixture
    def daily_exists(self, storage):
        """Check if daily.h5 exists."""
        return storage.exists("daily")

    def test_daily_h5_exists(self, storage):
        """Verify daily.h5 data file exists before running tests."""
        assert storage.exists("daily"), (
            "daily.h5 not found. Please run full sync first: "
            "uv run python -c 'from src.data.sync import sync_all; sync_all(force_full=True)'"
        )

    def test_get_global_last_date(self, data_sync, daily_exists):
        """Test get_global_last_date returns correct max date from local data."""
        if not daily_exists:
            pytest.skip("daily.h5 not found")
        last_date = data_sync.get_global_last_date()
        # Verify it's a valid date string (YYYYMMDD)
        assert last_date is not None, "get_global_last_date returned None"
        assert isinstance(last_date, str), f"Expected str, got {type(last_date)}"
        assert len(last_date) == 8, f"Expected 8-digit date, got {last_date}"
        assert last_date.isdigit(), f"Expected numeric date, got {last_date}"
        # Cross-check against the raw storage: the reported last date must be
        # the true max of the trade_date column.
        daily_data = data_sync.storage.load("daily")
        expected_max = str(daily_data["trade_date"].max())
        assert last_date == expected_max, (
            f"get_global_last_date returned {last_date}, "
            f"but actual max date is {expected_max}"
        )
        print(f"[TEST] Local data last date: {last_date}")

    def test_incremental_sync_date_calculation(self, data_sync, daily_exists):
        """Test incremental sync: start_date = local_last_date + 1 day.

        This verifies that when local data exists, incremental sync should
        fetch data from (local_last_date + 1 day), not from 20180101.
        """
        if not daily_exists:
            pytest.skip("daily.h5 not found")
        # Get local last date
        local_last_date = data_sync.get_global_last_date()
        assert local_last_date is not None, "No local data found"
        # Calculate expected incremental start date
        expected_start_date = get_next_date(local_last_date)
        # FIX: compare against real calendar arithmetic. The old check used
        # int(date) + 1, which breaks whenever the local data ends on a
        # month/year boundary (e.g. 20240131 -> 20240201, not 20240132).
        from datetime import datetime, timedelta

        expected_next = (
            datetime.strptime(local_last_date, "%Y%m%d") + timedelta(days=1)
        ).strftime("%Y%m%d")
        assert expected_start_date == expected_next, (
            f"Incremental start date calculation error: "
            f"expected {expected_next}, got {expected_start_date}"
        )
        print(
            f"[TEST] Incremental sync: local_last={local_last_date}, "
            f"start_date should be {expected_start_date}"
        )
        # Verify this is NOT 20180101 (would be full sync)
        assert expected_start_date != DEFAULT_START_DATE, (
            f"Incremental sync should NOT start from {DEFAULT_START_DATE}"
        )

    def test_full_sync_date_calculation(self):
        """Test full sync: start_date = 20180101 when force_full=True.

        This verifies that force_full=True always starts from 20180101.
        """
        # Full sync should always use DEFAULT_START_DATE
        full_sync_start = DEFAULT_START_DATE
        assert full_sync_start == "20180101", (
            f"Full sync should start from 20180101, got {full_sync_start}"
        )
        print(f"[TEST] Full sync start date: {full_sync_start}")

    def test_date_comparison_logic(self, data_sync, daily_exists):
        """Test date comparison: incremental vs full sync selection logic.

        Verify that:
        - If local_last_date < today: incremental sync needed
        - If local_last_date >= today: no sync needed
        """
        if not daily_exists:
            pytest.skip("daily.h5 not found")
        from datetime import datetime

        local_last_date = data_sync.get_global_last_date()
        today = datetime.now().strftime("%Y%m%d")
        # Fixed-width YYYYMMDD strings order correctly as integers, so plain
        # int comparison is a valid ordering test here (unlike subtraction).
        local_last_int = int(local_last_date)
        today_int = int(today)
        # Log the comparison
        print(
            f"[TEST] Date comparison: local_last={local_last_date} ({local_last_int}), "
            f"today={today} ({today_int})"
        )
        # This test just verifies the comparison logic works
        if local_last_int < today_int:
            print("[TEST] Local data is older than today - sync needed")
            # Incremental sync should fetch from local_last_date + 1 day
            sync_start = get_next_date(local_last_date)
            assert int(sync_start) > local_last_int, (
                "Sync start should be after local last"
            )
        else:
            print("[TEST] Local data is up-to-date - no sync needed")

    def test_get_all_stock_codes_real_data(self, data_sync, daily_exists):
        """Test get_all_stock_codes returns multiple real stock codes."""
        if not daily_exists:
            pytest.skip("daily.h5 not found")
        codes = data_sync.get_all_stock_codes()
        # Verify it's a list
        assert isinstance(codes, list), f"Expected list, got {type(codes)}"
        assert len(codes) > 0, "No stock codes found"
        # Verify multiple stocks
        assert len(codes) >= 10, (
            f"Expected at least 10 stocks for multi-stock test, got {len(codes)}"
        )
        # Verify format (should be like 000001.SZ, 600000.SH)
        # NOTE(review): only SZ/SH suffixes are accepted — confirm whether
        # BSE (.BJ) codes can appear in local data before relying on this.
        sample_codes = codes[:5]
        for code in sample_codes:
            assert "." in code, f"Invalid stock code format: {code}"
            suffix = code.split(".")[-1]
            assert suffix in ["SZ", "SH"], f"Invalid exchange suffix: {suffix}"
        print(f"[TEST] Found {len(codes)} stock codes (sample: {sample_codes})")

    def test_multi_stock_date_range(self, data_sync, daily_exists):
        """Test that multiple stocks share the same date range in local data.

        This verifies that local data has consistent date coverage across stocks.
        """
        if not daily_exists:
            pytest.skip("daily.h5 not found")
        daily_data = data_sync.storage.load("daily")
        # Get date range for each stock
        stock_dates = daily_data.groupby("ts_code")["trade_date"].agg(["min", "max"])
        # Get global min and max
        global_min = str(daily_data["trade_date"].min())
        global_max = str(daily_data["trade_date"].max())
        print(f"[TEST] Global date range: {global_min} to {global_max}")
        print(f"[TEST] Total stocks: {len(stock_dates)}")
        # Verify we have data for multiple stocks
        assert len(stock_dates) >= 10, (
            f"Expected at least 10 stocks, got {len(stock_dates)}"
        )
        # FIX: measure the span in real calendar days. The old code subtracted
        # YYYYMMDD integers (one year ~ 10000, not 365), so the reported
        # "days" figure and the threshold were both meaningless.
        from datetime import datetime

        fmt = "%Y%m%d"
        days_span = (
            datetime.strptime(global_max, fmt) - datetime.strptime(global_min, fmt)
        ).days
        # Require at least ~100 calendar days of coverage
        assert days_span > 100, (
            f"Date range too small: {days_span} days. "
            f"Expected at least 100 days of data."
        )
        print(f"[TEST] Date span: {days_span} days")
class TestDateUtilities:
    """Test date utility functions."""

    def test_get_next_date(self):
        """Test get_next_date correctly calculates next day."""
        # Table-driven: ordinary day plus the two rollover edge cases.
        cases = [
            ("20240101", "20240102"),  # ordinary day
            ("20240131", "20240201"),  # month boundary
            ("20241231", "20250101"),  # year boundary
        ]
        for given, expected in cases:
            assert get_next_date(given) == expected

    def test_incremental_vs_full_sync_logic(self):
        """Test the logic difference between incremental and full sync.

        Incremental: start_date = local_last_date + 1
        Full: start_date = 20180101
        """
        # Scenario 1: local data exists — resume from the day after it ends.
        last_local = "20240115"
        resume_from = get_next_date(last_local)
        assert resume_from == "20240116"
        assert resume_from != DEFAULT_START_DATE
        # Scenario 2: forced full sync — always restart from the default date.
        restart_from = DEFAULT_START_DATE  # "20180101"
        assert restart_from == "20180101"
        assert resume_from != restart_from
        print("[TEST] Incremental vs Full sync logic verified")
# Allow running this module directly (outside the pytest CLI):
# -v for verbose test names, -s so the diagnostic print() output is shown.
if __name__ == "__main__":
    pytest.main([__file__, "-v", "-s"])