Files
ProStock/src/data/utils.py
liaozhaorun a9e4746239 refactor: 代码审查修复 - 日期过滤、性能优化、数据泄露防护
- 修复 data_loader.py 财务数据日期过滤,支持按范围加载
- 优化 MADClipper 使用窗口函数替代 join,提升性能
- 修复训练日期边界问题,添加1天间隔避免数据泄露
- 新增 .gitignore 规则忽略训练输出目录
2026-02-25 21:11:19 +08:00

147 lines
3.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Data module utility functions.
集中管理数据模块中常用的工具函数,避免重复定义。
"""
from datetime import datetime, timedelta
from typing import Optional, List
# 默认全量同步开始日期
DEFAULT_START_DATE = "20180101"
# 今日日期 (YYYYMMDD 格式)
TODAY: str = datetime.now().strftime("%Y%m%d")
def get_today_date() -> str:
"""获取今日日期YYYYMMDD 格式)。
Returns:
今日日期字符串,格式为 YYYYMMDD
"""
return TODAY
def get_next_date(date_str: str) -> str:
"""获取给定日期的下一天。
Args:
date_str: YYYYMMDD 格式的日期
Returns:
YYYYMMDD 格式的下一天日期
"""
dt = datetime.strptime(date_str, "%Y%m%d")
next_dt = dt + timedelta(days=1)
return next_dt.strftime("%Y%m%d")
def get_prev_date(date_str: str) -> str:
"""获取给定日期的前一天。
Args:
date_str: YYYYMMDD 格式的日期
Returns:
YYYYMMDD 格式的前一天日期
"""
dt = datetime.strptime(date_str, "%Y%m%d")
prev_dt = dt - timedelta(days=1)
return prev_dt.strftime("%Y%m%d")
def parse_date(date_str: str) -> datetime:
"""解析 YYYYMMDD 格式的日期字符串。
Args:
date_str: YYYYMMDD 格式的日期
Returns:
datetime 对象
"""
return datetime.strptime(date_str, "%Y%m%d")
def format_date(dt: datetime) -> str:
"""将 datetime 对象格式化为 YYYYMMDD 字符串。
Args:
dt: datetime 对象
Returns:
YYYYMMDD 格式的日期字符串
"""
return dt.strftime("%Y%m%d")
def is_quarter_end(date_str: str) -> bool:
"""判断是否为季度最后一天。
Args:
date_str: YYYYMMDD 格式的日期
Returns:
是否为季度最后一天
"""
month_day = date_str[4:]
return month_day in ("0331", "0630", "0930", "1231")
def date_to_quarter(date_str: str) -> str:
"""将日期转换为对应季度的最后一天。
Args:
date_str: YYYYMMDD 格式的日期
Returns:
季度最后一天,格式为 YYYYMMDD
例如: 20230115 -> 20230331
"""
year = date_str[:4]
month = int(date_str[4:6])
if month <= 3:
return year + "0331"
elif month <= 6:
return year + "0630"
elif month <= 9:
return year + "0930"
else:
return year + "1231"
def get_quarters_in_range(start_date: str, end_date: str) -> List[str]:
"""获取日期范围内的所有季度列表。
Args:
start_date: 开始日期 YYYYMMDD
end_date: 结束日期 YYYYMMDD
Returns:
季度列表,格式为 YYYYMMDD按时间倒序排列
例如: ['20231231', '20230930', '20230630', '20230331']
"""
quarters = []
# 将开始日期和结束日期都转换为季度
start_quarter = date_to_quarter(start_date)
end_quarter = date_to_quarter(end_date)
# 解析年份
start_year = int(start_date[:4])
end_year = int(end_date[:4])
# 遍历所有年份和季度
for year in range(end_year, start_year - 1, -1):
year_str = str(year)
# 季度顺序: Q4, Q3, Q2, Q1 (倒序)
for quarter in ["1231", "0930", "0630", "0331"]:
quarter_date = year_str + quarter
# 只包含在范围内的季度
if quarter_date >= start_quarter and quarter_date <= end_quarter:
quarters.append(quarter_date)
return quarters