Compare commits
3 Commits
f1687dadf3
...
3b42093100
| Author | SHA1 | Date | |
|---|---|---|---|
| 3b42093100 | |||
| 620696c842 | |||
| af5c96cd53 |
@@ -15,6 +15,14 @@ from src.data.db_manager import (
|
|||||||
get_table_info,
|
get_table_info,
|
||||||
sync_table,
|
sync_table,
|
||||||
)
|
)
|
||||||
|
from src.data.catalog import (
|
||||||
|
DatabaseCatalog,
|
||||||
|
SQLQueryBuilder,
|
||||||
|
build_context_lazyframe,
|
||||||
|
TableMetadata,
|
||||||
|
FieldMapping,
|
||||||
|
TableFrequency,
|
||||||
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
# Configuration
|
# Configuration
|
||||||
@@ -36,4 +44,11 @@ __all__ = [
|
|||||||
"ensure_table",
|
"ensure_table",
|
||||||
"get_table_info",
|
"get_table_info",
|
||||||
"sync_table",
|
"sync_table",
|
||||||
|
# Data catalog
|
||||||
|
"DatabaseCatalog",
|
||||||
|
"SQLQueryBuilder",
|
||||||
|
"build_context_lazyframe",
|
||||||
|
"TableMetadata",
|
||||||
|
"FieldMapping",
|
||||||
|
"TableFrequency",
|
||||||
]
|
]
|
||||||
|
|||||||
181
src/data/financial_loader.py
Normal file
181
src/data/financial_loader.py
Normal file
@@ -0,0 +1,181 @@
|
|||||||
|
"""财务数据加载与清洗模块。
|
||||||
|
|
||||||
|
提供财务数据的加载、清洗和与行情数据拼接功能。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from typing import Optional, List
|
||||||
|
|
||||||
|
import polars as pl
|
||||||
|
|
||||||
|
from src.data.storage import Storage
|
||||||
|
|
||||||
|
|
||||||
|
class FinancialLoader:
    """Financial-statement data loader.

    Handles loading, cleaning and de-duplication of financial data, and
    as-of joining it onto daily price data.

    Attributes:
        storage: DuckDB storage instance used for raw SQL queries.
    """

    def __init__(self):
        # Storage owns the DuckDB connection consumed by _load_from_db.
        self.storage = Storage()

    def load_financial_data(
        self,
        table_name: str,
        columns: List[str],
        start_date: str,
        end_date: str,
        ts_code: Optional[str] = None,
    ) -> pl.DataFrame:
        """Load and clean financial data.

        Cleaning pipeline:
        1. Keep only consolidated statements (report_type == '1').
        2. De-duplicate by (ts_code, f_ann_date), keeping the row with the
           highest update_flag (the latest revision).
        3. Validate that f_ann_date is a Date column, then sort by
           (ts_code, f_ann_date) so the result is join_asof-ready.

        Args:
            table_name: Financial table name (e.g. 'financial_income').
            columns: Requested columns (core fields are added automatically).
            start_date: Inclusive start date (YYYYMMDD).
            end_date: Inclusive end date (YYYYMMDD).
            ts_code: Optional single stock-code filter.

        Returns:
            Cleaned Polars DataFrame sorted by (ts_code, f_ann_date), with
            f_ann_date of dtype pl.Date.

        Raises:
            TypeError: If f_ann_date is not a Date column.
        """
        # Always fetch the fields the cleaning steps depend on.
        required_cols = {"ts_code", "f_ann_date", "report_type", "update_flag"}
        query_cols = list(set(columns) | required_cols)

        # Load raw rows from the database.
        df = self._load_from_db(table_name, query_cols, start_date, end_date, ts_code)

        if df.is_empty():
            return df

        # Step 1: keep consolidated statements only.
        # BUGFIX: report_type may be stored as a string or an integer (the
        # original comment said as much), but the previous integer-only
        # comparison (== 1) would drop or fail on string-typed columns.
        # Casting to Utf8 handles both representations.
        df = df.filter(pl.col("report_type").cast(pl.Utf8) == "1")

        # Step 2: sort so the highest update_flag comes first per group.
        df = df.with_columns(
            [pl.col("update_flag").cast(pl.Int32).alias("update_flag_int")]
        )
        df = df.sort(
            ["ts_code", "f_ann_date", "update_flag_int"],
            descending=[False, False, True],
        )

        # De-duplicate: keep the first row (highest update_flag) per
        # (ts_code, f_ann_date).
        # BUGFIX: without maintain_order=True polars does not guarantee
        # that "first" respects the sort just applied.
        df = df.unique(subset=["ts_code", "f_ann_date"], keep="first", maintain_order=True)

        # Drop the temporary sort key.
        df = df.drop("update_flag_int")

        # Step 3: join_asof requires a real Date column; fail loudly if
        # the table schema returned something else.
        if df["f_ann_date"].dtype != pl.Date:
            raise TypeError(
                f"f_ann_date 必须是 Date 类型,实际类型为 {df['f_ann_date'].dtype}. "
                f"请检查数据库表结构,确保日期字段为 DATE 类型。"
            )

        # Final sort (required by join_asof).
        df = df.sort(["ts_code", "f_ann_date"])

        return df

    def merge_financial_with_price(
        self,
        df_price: pl.DataFrame,
        df_financial: pl.DataFrame,
        financial_cols: List[str],
    ) -> pl.DataFrame:
        """Attach financial data to price data via a backward as-of join.

        For each trade date, the most recent historical announcement
        (f_ann_date <= trade_date) is matched.

        Note: df_price.trade_date must already be pl.Date and sorted.

        Args:
            df_price: Price frame; must contain ts_code and trade_date (Date).
            df_financial: Financial frame cleaned by load_financial_data
                (f_ann_date as Date, sorted).
            financial_cols: Financial columns to carry over.

        Returns:
            Joined DataFrame (trade_date still of Date type).
        """
        if df_financial.is_empty():
            # No financial data: return the price frame with the requested
            # financial columns present but null.
            for col in financial_cols:
                if col not in df_price.columns:
                    df_price = df_price.with_columns([pl.lit(None).alias(col)])
            return df_price

        # Backward as-of join: for every trade_date pick the latest row
        # with f_ann_date <= trade_date, per ts_code.
        merged = df_price.join_asof(
            df_financial.select(["ts_code", "f_ann_date"] + financial_cols),
            left_on="trade_date",
            right_on="f_ann_date",
            by="ts_code",
            strategy="backward",
        )

        return merged

    def _load_from_db(
        self,
        table_name: str,
        columns: List[str],
        start_date: str,
        end_date: str,
        ts_code: Optional[str] = None,
    ) -> pl.DataFrame:
        """Load financial rows from the database (internal helper)."""
        # NOTE(review): reaches into Storage's private connection — consider
        # exposing a public query method on Storage instead.
        conn = self.storage._connection

        cols_str = ", ".join(f'"{c}"' for c in columns)

        start_dt = datetime.strptime(start_date, "%Y%m%d").date()
        end_dt = datetime.strptime(end_date, "%Y%m%d").date()

        # SECURITY NOTE: table_name/ts_code are interpolated directly into
        # SQL. Callers are internal today, but this is injectable if ever
        # fed untrusted input — prefer parameterized queries via
        # conn.execute(query, params).
        conditions = [f"f_ann_date BETWEEN '{start_dt}' AND '{end_dt}'"]
        if ts_code:
            conditions.append(f"ts_code = '{ts_code}'")

        where_clause = " AND ".join(conditions)
        query = f"SELECT {cols_str} FROM {table_name} WHERE {where_clause} ORDER BY ts_code, f_ann_date"

        try:
            df = conn.sql(query).pl()
            return df
        except Exception as e:
            # Best-effort: log and return an empty frame so callers can
            # short-circuit on is_empty().
            print(f"[FinancialLoader] 加载 {table_name} 失败: {e}")
            return pl.DataFrame()

    def get_date_range_with_lookback(
        self,
        start_date: str,
        end_date: str,
        lookback_years: int = 1,
    ) -> tuple[str, str]:
        """Compute a date range extended backwards by a lookback period.

        Args:
            start_date: Original start date (YYYYMMDD).
            end_date: Original end date (YYYYMMDD).
            lookback_years: Years to look back (default 1).

        Returns:
            (extended start date, unchanged end date).
        """
        start_dt = datetime.strptime(start_date, "%Y%m%d")
        # NOTE: 365 * years ignores leap days — acceptable for a coarse
        # lookback window.
        adjusted_start = start_dt - timedelta(days=365 * lookback_years)
        return adjusted_start.strftime("%Y%m%d"), end_date
|
||||||
@@ -13,6 +13,7 @@ from src.factors import FactorEngine
|
|||||||
from src.training import (
|
from src.training import (
|
||||||
DateSplitter,
|
DateSplitter,
|
||||||
LightGBMModel,
|
LightGBMModel,
|
||||||
|
STFilter,
|
||||||
StandardScaler,
|
StandardScaler,
|
||||||
StockFilterConfig,
|
StockFilterConfig,
|
||||||
StockPoolManager,
|
StockPoolManager,
|
||||||
@@ -60,6 +61,8 @@ def create_factors_with_strings(engine: FactorEngine) -> List[str]:
|
|||||||
"market_cap_rank": "cs_rank(total_mv)",
|
"market_cap_rank": "cs_rank(total_mv)",
|
||||||
# 7. 价格位置因子
|
# 7. 价格位置因子
|
||||||
"high_low_ratio": "(close - ts_min(low, 20)) / (ts_max(high, 20) - ts_min(low, 20) + 1e-8)",
|
"high_low_ratio": "(close - ts_min(low, 20)) / (ts_max(high, 20) - ts_min(low, 20) + 1e-8)",
|
||||||
|
"n_income": "n_income"
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
# Label 因子(单独定义,不参与训练)
|
# Label 因子(单独定义,不参与训练)
|
||||||
@@ -223,11 +226,17 @@ def train_regression_model():
|
|||||||
data_router=engine.router, # 从 FactorEngine 获取数据路由器
|
data_router=engine.router, # 从 FactorEngine 获取数据路由器
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 8.5 创建 ST 股票过滤器(在股票池筛选之前执行)
|
||||||
|
st_filter = STFilter(
|
||||||
|
data_router=engine.router,
|
||||||
|
)
|
||||||
|
|
||||||
# 9. 创建训练器
|
# 9. 创建训练器
|
||||||
trainer = Trainer(
|
trainer = Trainer(
|
||||||
model=model,
|
model=model,
|
||||||
pool_manager=pool_manager,
|
pool_manager=pool_manager,
|
||||||
processors=processors,
|
processors=processors,
|
||||||
|
filters=[st_filter], # 在股票池筛选之前过滤 ST 股票
|
||||||
splitter=splitter,
|
splitter=splitter,
|
||||||
target_col=target_col,
|
target_col=target_col,
|
||||||
feature_cols=feature_cols,
|
feature_cols=feature_cols,
|
||||||
|
|||||||
@@ -3,6 +3,7 @@
|
|||||||
按需取数、组装核心宽表。
|
按需取数、组装核心宽表。
|
||||||
负责根据数据规格从数据源拉取数据,并组装成统一的宽表格式。
|
负责根据数据规格从数据源拉取数据,并组装成统一的宽表格式。
|
||||||
支持内存数据源(用于测试)和真实数据库连接。
|
支持内存数据源(用于测试)和真实数据库连接。
|
||||||
|
支持标准等值匹配和 asof_backward(财务数据)两种拼接模式。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Any, Dict, List, Optional, Set, Union
|
from typing import Any, Dict, List, Optional, Set, Union
|
||||||
@@ -12,6 +13,7 @@ import polars as pl
|
|||||||
|
|
||||||
from src.factors.engine.data_spec import DataSpec
|
from src.factors.engine.data_spec import DataSpec
|
||||||
from src.data.storage import Storage
|
from src.data.storage import Storage
|
||||||
|
from src.data.financial_loader import FinancialLoader
|
||||||
|
|
||||||
|
|
||||||
class DataRouter:
|
class DataRouter:
|
||||||
@@ -37,11 +39,13 @@ class DataRouter:
|
|||||||
self._cache: Dict[str, pl.DataFrame] = {}
|
self._cache: Dict[str, pl.DataFrame] = {}
|
||||||
self._lock = threading.Lock()
|
self._lock = threading.Lock()
|
||||||
|
|
||||||
# 数据库模式下初始化 Storage
|
# 数据库模式下初始化 Storage 和 FinancialLoader
|
||||||
if not self.is_memory_mode:
|
if not self.is_memory_mode:
|
||||||
self._storage = Storage()
|
self._storage = Storage()
|
||||||
|
self._financial_loader = FinancialLoader()
|
||||||
else:
|
else:
|
||||||
self._storage = None
|
self._storage = None
|
||||||
|
self._financial_loader = None
|
||||||
|
|
||||||
def fetch_data(
|
def fetch_data(
|
||||||
self,
|
self,
|
||||||
@@ -75,23 +79,122 @@ class DataRouter:
|
|||||||
required_tables[spec.table] = set()
|
required_tables[spec.table] = set()
|
||||||
required_tables[spec.table].update(spec.columns)
|
required_tables[spec.table].update(spec.columns)
|
||||||
|
|
||||||
# 从数据源获取各表数据
|
# 从数据源获取各表数据(使用合并后的 required_tables,避免重复加载)
|
||||||
table_data = {}
|
table_data = {}
|
||||||
for table_name, columns in required_tables.items():
|
for table_name, columns in required_tables.items():
|
||||||
df = self._load_table(
|
# 判断是标准表还是财务表
|
||||||
table_name=table_name,
|
is_financial = any(
|
||||||
columns=list(columns),
|
s.table == table_name and s.join_type == "asof_backward"
|
||||||
|
for s in data_specs
|
||||||
|
)
|
||||||
|
|
||||||
|
if is_financial:
|
||||||
|
# 财务表:找到对应的 spec 获取 join 配置
|
||||||
|
financial_spec = next(
|
||||||
|
s
|
||||||
|
for s in data_specs
|
||||||
|
if s.table == table_name and s.join_type == "asof_backward"
|
||||||
|
)
|
||||||
|
spec = DataSpec(
|
||||||
|
table=table_name,
|
||||||
|
columns=list(columns),
|
||||||
|
join_type="asof_backward",
|
||||||
|
left_on=financial_spec.left_on,
|
||||||
|
right_on=financial_spec.right_on,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# 标准表
|
||||||
|
spec = DataSpec(
|
||||||
|
table=table_name,
|
||||||
|
columns=list(columns),
|
||||||
|
join_type="standard",
|
||||||
|
)
|
||||||
|
|
||||||
|
df = self._load_table_from_spec(
|
||||||
|
spec=spec,
|
||||||
start_date=start_date,
|
start_date=start_date,
|
||||||
end_date=end_date,
|
end_date=end_date,
|
||||||
stock_codes=stock_codes,
|
stock_codes=stock_codes,
|
||||||
)
|
)
|
||||||
table_data[table_name] = df
|
table_data[table_name] = df
|
||||||
|
|
||||||
# 组装核心宽表
|
# 组装核心宽表(支持多种 join 类型)
|
||||||
core_table = self._assemble_wide_table(table_data, required_tables)
|
core_table = self._assemble_wide_table_with_specs(
|
||||||
|
table_data, data_specs, start_date, end_date
|
||||||
|
)
|
||||||
|
|
||||||
return core_table
|
return core_table
|
||||||
|
|
||||||
|
def _load_table_from_spec(
    self,
    spec: DataSpec,
    start_date: str,
    end_date: str,
    stock_codes: Optional[List[str]] = None,
) -> pl.DataFrame:
    """Load a single table according to its data spec.

    Dispatches on spec.join_type:
    - "standard": delegate to the original _load_table (trade_date based).
    - "asof_backward": use FinancialLoader (f_ann_date based) with an
      extended lookback window.

    Args:
        spec: Data specification for the table.
        start_date: Query start date.
        end_date: Query end date.
        stock_codes: Optional stock-code filter.

    Returns:
        The filtered DataFrame (served from cache when available).
    """
    key = (
        f"{spec.table}_{spec.join_type}_{start_date}_{end_date}_{stock_codes}"
    )

    # Fast path: serve a previously loaded frame.
    with self._lock:
        if key in self._cache:
            return self._cache[key]

    if spec.join_type == "asof_backward":
        # Financial tables go through FinancialLoader.
        if self._financial_loader is None:
            raise RuntimeError("FinancialLoader 未初始化")

        # Extend the window backwards (1-year lookback) so early trade
        # dates still have an announcement to match against.
        lookback_start, _ = self._financial_loader.get_date_range_with_lookback(
            start_date, end_date
        )

        # A single-element code filter can be pushed down into SQL.
        single_code = stock_codes[0] if stock_codes and len(stock_codes) == 1 else None

        frame = self._financial_loader.load_financial_data(
            table_name=spec.table,
            columns=spec.columns,
            start_date=lookback_start,
            end_date=end_date,
            ts_code=single_code,
        )

        # Multi-code filters are applied in memory instead.
        if stock_codes is not None and len(stock_codes) > 1:
            frame = frame.filter(pl.col("ts_code").is_in(stock_codes))

    else:
        # Standard quote tables keep the original loading path.
        frame = self._load_table(
            table_name=spec.table,
            columns=spec.columns,
            start_date=start_date,
            end_date=end_date,
            stock_codes=stock_codes,
        )

    with self._lock:
        self._cache[key] = frame

    return frame
|
||||||
|
|
||||||
def _load_table(
|
def _load_table(
|
||||||
self,
|
self,
|
||||||
table_name: str,
|
table_name: str,
|
||||||
@@ -255,6 +358,119 @@ class DataRouter:
|
|||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
def _assemble_wide_table_with_specs(
    self,
    table_data: Dict[str, pl.DataFrame],
    data_specs: List[DataSpec],
    start_date: str,
    end_date: str,
) -> pl.DataFrame:
    """Assemble the per-table frames into one core wide table.

    Supports both standard equi-joins and asof_backward joins.

    Performance notes:
    - trade_date is converted to pl.Date exactly once up front,
    - every asof join runs on pl.Date,
    - the column is converted back to a string before returning.

    Args:
        table_data: Mapping from table name to its loaded DataFrame.
        data_specs: Data specifications (source of join types and columns).
        start_date: Query start date.
        end_date: Query end date.

    Returns:
        The assembled wide table.
    """
    if not table_data:
        raise ValueError("没有数据可组装")

    # The first spec seen for a table decides its join type.
    join_types: Dict[str, str] = {}
    for spec in data_specs:
        if spec.table not in join_types:
            join_types[spec.table] = spec.join_type

    # Partition loaded tables by join type (table_data keys are unique,
    # so nothing is processed twice).
    standard_tables = [
        name
        for name in table_data.keys()
        if join_types.get(name, "standard") == "standard"
    ]
    asof_tables = [
        name for name in table_data.keys() if join_types.get(name) == "asof_backward"
    ]

    # Merge all standard tables on (ts_code, trade_date).
    base_df = None
    for name in standard_tables:
        frame = table_data[name]
        if base_df is None:
            base_df = frame
            continue
        # Per the dynamic-routing contract, only ts_code/trade_date may be
        # shared between tables; duplicate columns here would indicate a
        # SchemaCache field-mapping problem.
        base_df = base_df.join(
            frame,
            on=["ts_code", "trade_date"],
            how="left",
        )

    if base_df is None:
        raise ValueError("至少需要一张标准行情表作为基础")

    if asof_tables:
        # One-time string -> Date conversion; join_asof also requires
        # sorted keys.
        base_df = base_df.with_columns(
            [
                pl.col("trade_date")
                .str.strptime(pl.Date, "%Y%m%d")
                .alias("trade_date")
            ]
        )
        base_df = base_df.sort(["ts_code", "trade_date"])

    # Merge each asof table in turn (all joins run on Date typed keys).
    for name in asof_tables:
        df_financial = table_data[name]

        # Union of every column requested from this table across specs.
        requested = set()
        for spec in data_specs:
            if spec.table == name:
                requested.update(spec.columns)

        # Drop join keys and metadata fields before the merge.
        financial_cols = [
            c
            for c in requested
            if c
            not in [
                "ts_code",
                "f_ann_date",
                "report_type",
                "update_flag",
                "end_date",
            ]
        ]

        if self._financial_loader is None:
            raise RuntimeError("FinancialLoader 未初始化")

        base_df = self._financial_loader.merge_financial_with_price(
            base_df, df_financial, financial_cols
        )

    if asof_tables:
        # Convert trade_date back to the canonical string format.
        base_df = base_df.with_columns(
            [pl.col("trade_date").dt.strftime("%Y%m%d").alias("trade_date")]
        )

    return base_df
|
||||||
|
|
||||||
def clear_cache(self) -> None:
|
def clear_cache(self) -> None:
|
||||||
"""清除数据缓存。"""
|
"""清除数据缓存。"""
|
||||||
with self._lock:
|
with self._lock:
|
||||||
|
|||||||
@@ -4,24 +4,38 @@
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any, Dict, List, Optional, Set, Union
|
from typing import Any, Dict, List, Literal, Optional, Set, Union
|
||||||
|
|
||||||
import polars as pl
|
import polars as pl
|
||||||
|
|
||||||
|
|
||||||
@dataclass
class DataSpec:
    """Data specification (multiple join types supported).

    Describes the table and columns a factor needs, plus how the table is
    joined onto the core wide table.

    Attributes:
        table: Table name.
        columns: Required column names.
        join_type: Join mode --
            "standard": equi-join (default);
            "asof_backward": match the most recent historical row
            (used for financial-statement data).
        left_on: Left-side join key (required in asof mode).
        right_on: Right-side join key (required in asof mode).
    """

    table: str
    columns: List[str]
    join_type: Literal["standard", "asof_backward"] = "standard"
    left_on: Optional[str] = None   # quote-table date column
    right_on: Optional[str] = None  # financial-table date column

    def __post_init__(self):
        """Validate that asof_backward mode carries both join keys."""
        if self.join_type != "asof_backward":
            return
        if self.left_on and self.right_on:
            return
        raise ValueError("asof_backward 模式必须指定 left_on 和 right_on")
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|||||||
@@ -72,9 +72,10 @@ class ExecutionPlanner:
|
|||||||
dependencies: Set[str],
|
dependencies: Set[str],
|
||||||
expression: Node,
|
expression: Node,
|
||||||
) -> List[DataSpec]:
|
) -> List[DataSpec]:
|
||||||
"""从依赖推导数据规格。
|
"""从依赖推导数据规格(支持财务数据自动识别)。
|
||||||
|
|
||||||
使用 SchemaCache 动态扫描数据库表结构,自动匹配字段到对应的表。
|
使用 SchemaCache 动态扫描数据库表结构,自动匹配字段到对应的表。
|
||||||
|
自动识别财务数据表并配置 asof_backward 模式。
|
||||||
表结构只扫描一次并缓存在内存中。
|
表结构只扫描一次并缓存在内存中。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -90,11 +91,21 @@ class ExecutionPlanner:
|
|||||||
|
|
||||||
data_specs = []
|
data_specs = []
|
||||||
for table_name, columns in table_to_fields.items():
|
for table_name, columns in table_to_fields.items():
|
||||||
data_specs.append(
|
if schema_cache.is_financial_table(table_name):
|
||||||
DataSpec(
|
# 财务表使用 asof_backward 模式
|
||||||
|
spec = DataSpec(
|
||||||
|
table=table_name,
|
||||||
|
columns=columns,
|
||||||
|
join_type="asof_backward",
|
||||||
|
left_on="trade_date",
|
||||||
|
right_on="f_ann_date",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# 标准表使用默认模式
|
||||||
|
spec = DataSpec(
|
||||||
table=table_name,
|
table=table_name,
|
||||||
columns=columns,
|
columns=columns,
|
||||||
)
|
)
|
||||||
)
|
data_specs.append(spec)
|
||||||
|
|
||||||
return data_specs
|
return data_specs
|
||||||
|
|||||||
@@ -115,7 +115,7 @@ class SchemaCache:
|
|||||||
field_to_tables[field] = []
|
field_to_tables[field] = []
|
||||||
field_to_tables[field].append(table)
|
field_to_tables[field].append(table)
|
||||||
|
|
||||||
# 优先选择最常用的表(pro_bar > daily_basic > daily)
|
# 优先选择最常用的表(pro_bar > daily_basic > daily > financial)
|
||||||
priority_order = {"pro_bar": 1, "daily_basic": 2, "daily": 3}
|
priority_order = {"pro_bar": 1, "daily_basic": 2, "daily": 3}
|
||||||
|
|
||||||
self._field_to_table_map = {}
|
self._field_to_table_map = {}
|
||||||
@@ -124,6 +124,18 @@ class SchemaCache:
|
|||||||
sorted_tables = sorted(tables, key=lambda t: priority_order.get(t, 999))
|
sorted_tables = sorted(tables, key=lambda t: priority_order.get(t, 999))
|
||||||
self._field_to_table_map[field] = sorted_tables[0]
|
self._field_to_table_map[field] = sorted_tables[0]
|
||||||
|
|
||||||
|
def is_financial_table(self, table_name: str) -> bool:
    """Return True if *table_name* looks like a financial-statement table.

    Args:
        table_name: Table name to classify.

    Returns:
        Whether the table holds financial-report data.
    """
    # Financial tables are recognised purely by name prefix
    # (case-insensitive).
    prefixes = ("financial_", "income", "balance", "cashflow")
    lowered = table_name.lower()
    return any(lowered.startswith(prefix) for prefix in prefixes)
|
||||||
|
|
||||||
def get_table_fields(self, table_name: str) -> List[str]:
|
def get_table_fields(self, table_name: str) -> List[str]:
|
||||||
"""获取指定表的字段列表。
|
"""获取指定表的字段列表。
|
||||||
|
|
||||||
|
|||||||
@@ -33,6 +33,9 @@ from src.training.components.processors import (
|
|||||||
# 模型
|
# 模型
|
||||||
from src.training.components.models import LightGBMModel
|
from src.training.components.models import LightGBMModel
|
||||||
|
|
||||||
|
# 数据过滤器
|
||||||
|
from src.training.components.filters import BaseFilter, STFilter
|
||||||
|
|
||||||
# 训练核心
|
# 训练核心
|
||||||
from src.training.core import StockPoolManager, Trainer
|
from src.training.core import StockPoolManager, Trainer
|
||||||
|
|
||||||
@@ -57,6 +60,9 @@ __all__ = [
|
|||||||
"StandardScaler",
|
"StandardScaler",
|
||||||
"CrossSectionalStandardScaler",
|
"CrossSectionalStandardScaler",
|
||||||
"Winsorizer",
|
"Winsorizer",
|
||||||
|
# 数据过滤器
|
||||||
|
"BaseFilter",
|
||||||
|
"STFilter",
|
||||||
# 模型
|
# 模型
|
||||||
"LightGBMModel",
|
"LightGBMModel",
|
||||||
# 训练核心
|
# 训练核心
|
||||||
|
|||||||
142
src/training/components/filters.py
Normal file
142
src/training/components/filters.py
Normal file
@@ -0,0 +1,142 @@
|
|||||||
|
"""数据过滤器组件
|
||||||
|
|
||||||
|
提供股票数据过滤功能,在因子计算后、市值筛选前执行。
|
||||||
|
与 Processor 不同,Filter 是无状态的筛选操作。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import TYPE_CHECKING, Set
|
||||||
|
|
||||||
|
import polars as pl
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from src.factors.engine.data_router import DataRouter
|
||||||
|
|
||||||
|
|
||||||
|
class BaseFilter(ABC):
    """Base class for data filters.

    A Filter removes rows (stocks) that fail some condition. Unlike a
    Processor:
    - it is stateless, so there is no fit step,
    - it drops whole rows rather than transforming column values,
    - it is applied independently for each trading day.
    """

    # Human-readable identifier; concrete subclasses override this.
    name: str = ""

    @abstractmethod
    def filter(self, data: pl.DataFrame) -> pl.DataFrame:
        """Apply the filter.

        Args:
            data: Input frame; must contain ts_code and trade_date columns.

        Returns:
            The filtered frame.
        """
        raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class STFilter(BaseFilter):
    """ST stock filter.

    Removes each day's ST stocks (ST, *ST, S*ST, SST, ...) using the
    per-day ST list read from the stock_st table.

    Attributes:
        data_router: Data router used to query the stock_st table.
        code_col: Stock-code column name.
        date_col: Trade-date column name.
    """

    name = "st_filter"

    def __init__(
        self,
        data_router: "DataRouter",
        code_col: str = "ts_code",
        date_col: str = "trade_date",
    ):
        """Initialise the ST filter.

        Args:
            data_router: Data router used to query the stock_st table.
            code_col: Stock-code column name.
            date_col: Trade-date column name.
        """
        self.data_router = data_router
        self.code_col = code_col
        self.date_col = date_col
        # Per-date cache: {date: set(stock_codes)}
        self._st_cache: dict = {}

    def filter(self, data: pl.DataFrame) -> pl.DataFrame:
        """Filter out ST stocks.

        Groups by date and filters each day independently against that
        day's ST list.

        Args:
            data: Post-factor data containing ts_code and trade_date.

        Returns:
            The data with ST stocks removed.
        """
        # BUGFIX: pl.concat([]) raises on an empty frame list, so an empty
        # input must short-circuit before the per-date loop.
        if data.is_empty():
            return data

        dates = data.select(self.date_col).unique().sort(self.date_col)

        result_frames = []
        for date in dates.to_series():
            # Slice out the current day's rows.
            daily_data = data.filter(pl.col(self.date_col) == date)
            daily_codes = daily_data.select(self.code_col).to_series().to_list()

            # Fetch (or reuse cached) ST codes for this day.
            st_codes = self._get_st_codes_for_date(date)

            # Drop ST stocks.
            daily_filtered = daily_data.filter(~pl.col(self.code_col).is_in(st_codes))

            result_frames.append(daily_filtered)

            # Report how many rows were removed.
            n_removed = len(daily_codes) - len(daily_filtered)
            if n_removed > 0:
                print(f" [{date}] 过滤 {n_removed} 只 ST 股票")

        return pl.concat(result_frames)

    def _get_st_codes_for_date(self, date: str) -> Set[str]:
        """Fetch the ST stock codes for a given date from stock_st.

        Args:
            date: Date string "YYYYMMDD".

        Returns:
            Set of ST stock codes (empty on lookup failure).
        """
        # Serve from the per-date cache when possible.
        if date in self._st_cache:
            return self._st_cache[date]

        try:
            from src.factors.engine.data_spec import DataSpec

            # Query stock_st for all ST stocks on this single day.
            data_specs = [DataSpec("stock_st", [self.code_col])]
            df = self.data_router.fetch_data(
                data_specs=data_specs,
                start_date=date,
                end_date=date,
                stock_codes=None,  # fetch the full ST list for the day
            )

            # Extract the codes (empty set when nothing matched).
            st_codes = set(df[self.code_col].to_list()) if len(df) > 0 else set()

            self._st_cache[date] = st_codes
            return st_codes

        except Exception as e:
            # Best-effort: a failed lookup disables filtering for this day
            # rather than aborting training.
            print(f"[警告] 获取 {date} ST 股票列表失败: {e}")
            return set()
|
||||||
@@ -3,7 +3,7 @@
|
|||||||
整合数据处理、模型训练、预测的完整流程。
|
整合数据处理、模型训练、预测的完整流程。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import List, Optional
|
from typing import TYPE_CHECKING, List, Optional
|
||||||
|
|
||||||
import polars as pl
|
import polars as pl
|
||||||
|
|
||||||
@@ -11,6 +11,9 @@ from src.training.components.base import BaseModel, BaseProcessor
|
|||||||
from src.training.components.splitters import DateSplitter
|
from src.training.components.splitters import DateSplitter
|
||||||
from src.training.core.stock_pool_manager import StockPoolManager
|
from src.training.core.stock_pool_manager import StockPoolManager
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from src.training.components.filters import BaseFilter
|
||||||
|
|
||||||
|
|
||||||
class Trainer:
|
class Trainer:
|
||||||
"""训练器主类
|
"""训练器主类
|
||||||
@@ -29,6 +32,7 @@ class Trainer:
|
|||||||
model: BaseModel,
|
model: BaseModel,
|
||||||
pool_manager: Optional[StockPoolManager] = None,
|
pool_manager: Optional[StockPoolManager] = None,
|
||||||
processors: Optional[List[BaseProcessor]] = None,
|
processors: Optional[List[BaseProcessor]] = None,
|
||||||
|
filters: Optional[List["BaseFilter"]] = None,
|
||||||
splitter: Optional[DateSplitter] = None,
|
splitter: Optional[DateSplitter] = None,
|
||||||
target_col: str = "target",
|
target_col: str = "target",
|
||||||
feature_cols: Optional[List[str]] = None,
|
feature_cols: Optional[List[str]] = None,
|
||||||
@@ -41,6 +45,7 @@ class Trainer:
|
|||||||
model: 模型实例
|
model: 模型实例
|
||||||
pool_manager: 股票池管理器,None 表示不筛选
|
pool_manager: 股票池管理器,None 表示不筛选
|
||||||
processors: 数据处理器列表
|
processors: 数据处理器列表
|
||||||
|
filters: 数据过滤器列表(在股票池筛选之前执行)
|
||||||
splitter: 数据划分器
|
splitter: 数据划分器
|
||||||
target_col: 目标变量列名
|
target_col: 目标变量列名
|
||||||
feature_cols: 特征列名列表
|
feature_cols: 特征列名列表
|
||||||
@@ -50,6 +55,7 @@ class Trainer:
|
|||||||
self.model = model
|
self.model = model
|
||||||
self.pool_manager = pool_manager
|
self.pool_manager = pool_manager
|
||||||
self.processors = processors or []
|
self.processors = processors or []
|
||||||
|
self.filters = filters or []
|
||||||
self.splitter = splitter
|
self.splitter = splitter
|
||||||
self.target_col = target_col
|
self.target_col = target_col
|
||||||
self.feature_cols = feature_cols or []
|
self.feature_cols = feature_cols or []
|
||||||
@@ -80,6 +86,12 @@ class Trainer:
|
|||||||
Returns:
|
Returns:
|
||||||
self (支持链式调用)
|
self (支持链式调用)
|
||||||
"""
|
"""
|
||||||
|
# 0. 数据过滤(在股票池筛选之前)
|
||||||
|
if self.filters:
|
||||||
|
print("[过滤] 应用数据过滤器...")
|
||||||
|
for filter_ in self.filters:
|
||||||
|
data = filter_.filter(data)
|
||||||
|
|
||||||
# 1. 股票池筛选(每日独立)
|
# 1. 股票池筛选(每日独立)
|
||||||
if self.pool_manager:
|
if self.pool_manager:
|
||||||
print("[筛选] 每日独立筛选股票池...")
|
print("[筛选] 每日独立筛选股票池...")
|
||||||
|
|||||||
@@ -1,122 +0,0 @@
|
|||||||
"""Test for daily market data API.
|
|
||||||
|
|
||||||
Tests the daily interface implementation against api.md requirements:
|
|
||||||
- A股日线行情所有输出字段
|
|
||||||
- tor 换手率
|
|
||||||
- vr 量比
|
|
||||||
"""
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
import pandas as pd
|
|
||||||
from src.data.api_wrappers import get_daily
|
|
||||||
|
|
||||||
|
|
||||||
# Expected output fields according to api.md
|
|
||||||
EXPECTED_BASE_FIELDS = [
|
|
||||||
"ts_code", # 股票代码
|
|
||||||
"trade_date", # 交易日期
|
|
||||||
"open", # 开盘价
|
|
||||||
"high", # 最高价
|
|
||||||
"low", # 最低价
|
|
||||||
"close", # 收盘价
|
|
||||||
"pre_close", # 昨收价
|
|
||||||
"change", # 涨跌额
|
|
||||||
"pct_chg", # 涨跌幅
|
|
||||||
"vol", # 成交量
|
|
||||||
"amount", # 成交额
|
|
||||||
]
|
|
||||||
|
|
||||||
EXPECTED_FACTOR_FIELDS = [
|
|
||||||
"turnover_rate", # 换手率 (tor)
|
|
||||||
"volume_ratio", # 量比 (vr)
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
class TestGetDaily:
|
|
||||||
"""Test cases for get_daily function with real API calls."""
|
|
||||||
|
|
||||||
def test_fetch_basic(self):
|
|
||||||
"""Test basic daily data fetch with real API."""
|
|
||||||
result = get_daily("000001.SZ", start_date="20240101", end_date="20240131")
|
|
||||||
|
|
||||||
assert isinstance(result, pd.DataFrame)
|
|
||||||
assert len(result) >= 1
|
|
||||||
assert result["ts_code"].iloc[0] == "000001.SZ"
|
|
||||||
|
|
||||||
def test_fetch_with_factors(self):
|
|
||||||
"""Test fetch with tor and vr factors."""
|
|
||||||
result = get_daily(
|
|
||||||
"000001.SZ",
|
|
||||||
start_date="20240101",
|
|
||||||
end_date="20240131",
|
|
||||||
factors=["tor", "vr"],
|
|
||||||
)
|
|
||||||
|
|
||||||
assert isinstance(result, pd.DataFrame)
|
|
||||||
# Check all base fields are present
|
|
||||||
for field in EXPECTED_BASE_FIELDS:
|
|
||||||
assert field in result.columns, f"Missing base field: {field}"
|
|
||||||
# Check factor fields are present
|
|
||||||
for field in EXPECTED_FACTOR_FIELDS:
|
|
||||||
assert field in result.columns, f"Missing factor field: {field}"
|
|
||||||
|
|
||||||
def test_output_fields_completeness(self):
|
|
||||||
"""Verify all required output fields are returned."""
|
|
||||||
result = get_daily("600000.SH")
|
|
||||||
|
|
||||||
# Verify all base fields are present
|
|
||||||
assert set(EXPECTED_BASE_FIELDS).issubset(result.columns.tolist()), (
|
|
||||||
f"Missing fields: {set(EXPECTED_BASE_FIELDS) - set(result.columns)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_empty_result(self):
|
|
||||||
"""Test handling of empty results."""
|
|
||||||
# 使用真实 API 测试无效股票代码的空结果
|
|
||||||
result = get_daily("INVALID.SZ")
|
|
||||||
assert isinstance(result, pd.DataFrame)
|
|
||||||
assert result.empty
|
|
||||||
|
|
||||||
def test_date_range_query(self):
|
|
||||||
"""Test query with date range."""
|
|
||||||
result = get_daily(
|
|
||||||
"000001.SZ",
|
|
||||||
start_date="20240101",
|
|
||||||
end_date="20240131",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert isinstance(result, pd.DataFrame)
|
|
||||||
assert len(result) >= 1
|
|
||||||
|
|
||||||
def test_with_adj(self):
|
|
||||||
"""Test fetch with adjustment type."""
|
|
||||||
result = get_daily("000001.SZ", adj="qfq")
|
|
||||||
|
|
||||||
assert isinstance(result, pd.DataFrame)
|
|
||||||
|
|
||||||
|
|
||||||
def test_integration():
|
|
||||||
"""Integration test with real Tushare API (requires valid token)."""
|
|
||||||
import os
|
|
||||||
|
|
||||||
token = os.environ.get("TUSHARE_TOKEN")
|
|
||||||
if not token:
|
|
||||||
pytest.skip("TUSHARE_TOKEN not configured")
|
|
||||||
|
|
||||||
result = get_daily(
|
|
||||||
"000001.SZ", start_date="20240101", end_date="20240131", factors=["tor", "vr"]
|
|
||||||
)
|
|
||||||
|
|
||||||
# Verify structure
|
|
||||||
assert isinstance(result, pd.DataFrame)
|
|
||||||
if not result.empty:
|
|
||||||
# Check base fields
|
|
||||||
for field in EXPECTED_BASE_FIELDS:
|
|
||||||
assert field in result.columns, f"Missing base field: {field}"
|
|
||||||
# Check factor fields
|
|
||||||
for field in EXPECTED_FACTOR_FIELDS:
|
|
||||||
assert field in result.columns, f"Missing factor field: {field}"
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
# 运行 pytest 单元测试(真实API调用)
|
|
||||||
pytest.main([__file__, "-v"])
|
|
||||||
@@ -1,242 +0,0 @@
|
|||||||
"""Tests for DuckDB storage validation.
|
|
||||||
|
|
||||||
Validates two key points:
|
|
||||||
1. All stocks from stock_basic.csv are saved in daily table
|
|
||||||
2. No abnormal data with very few data points (< 10 rows per stock)
|
|
||||||
|
|
||||||
使用 3 个月的真实数据进行测试 (2024年1月-3月)
|
|
||||||
"""
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
import pandas as pd
|
|
||||||
from datetime import datetime, timedelta
|
|
||||||
from src.data.storage import Storage
|
|
||||||
from src.data.api_wrappers.api_stock_basic import _get_csv_path
|
|
||||||
|
|
||||||
|
|
||||||
class TestDailyStorageValidation:
|
|
||||||
"""Test daily table storage integrity and completeness."""
|
|
||||||
|
|
||||||
# 测试数据时间范围:3个月
|
|
||||||
TEST_START_DATE = "20240101"
|
|
||||||
TEST_END_DATE = "20240331"
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def storage(self):
|
|
||||||
"""Create storage instance."""
|
|
||||||
return Storage()
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def stock_basic_df(self):
|
|
||||||
"""Load stock basic data from CSV."""
|
|
||||||
csv_path = _get_csv_path()
|
|
||||||
if not csv_path.exists():
|
|
||||||
pytest.skip(f"stock_basic.csv not found at {csv_path}")
|
|
||||||
return pd.read_csv(csv_path)
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def daily_df(self, storage):
|
|
||||||
"""Load daily data from DuckDB (3 months)."""
|
|
||||||
if not storage.exists("daily"):
|
|
||||||
pytest.skip("daily table not found in DuckDB")
|
|
||||||
|
|
||||||
# 从 DuckDB 加载 3 个月数据
|
|
||||||
df = storage.load(
|
|
||||||
"daily", start_date=self.TEST_START_DATE, end_date=self.TEST_END_DATE
|
|
||||||
)
|
|
||||||
|
|
||||||
if df.empty:
|
|
||||||
pytest.skip(
|
|
||||||
f"No data found for period {self.TEST_START_DATE} to {self.TEST_END_DATE}"
|
|
||||||
)
|
|
||||||
|
|
||||||
return df
|
|
||||||
|
|
||||||
def test_duckdb_connection(self, storage):
|
|
||||||
"""Test DuckDB connection and basic operations."""
|
|
||||||
assert storage.exists("daily") or True # 至少连接成功
|
|
||||||
print(f"[TEST] DuckDB connection successful")
|
|
||||||
|
|
||||||
def test_load_3months_data(self, storage):
|
|
||||||
"""Test loading 3 months of data from DuckDB."""
|
|
||||||
df = storage.load(
|
|
||||||
"daily", start_date=self.TEST_START_DATE, end_date=self.TEST_END_DATE
|
|
||||||
)
|
|
||||||
|
|
||||||
if df.empty:
|
|
||||||
pytest.skip("No data available for testing period")
|
|
||||||
|
|
||||||
# 验证数据覆盖范围
|
|
||||||
dates = df["trade_date"].astype(str)
|
|
||||||
min_date = dates.min()
|
|
||||||
max_date = dates.max()
|
|
||||||
|
|
||||||
print(f"[TEST] Loaded {len(df)} rows from {min_date} to {max_date}")
|
|
||||||
assert len(df) > 0, "Should have data in the 3-month period"
|
|
||||||
|
|
||||||
def test_all_stocks_saved(self, storage, stock_basic_df, daily_df):
|
|
||||||
"""Verify all stocks from stock_basic are saved in daily table.
|
|
||||||
|
|
||||||
This test ensures data completeness - every stock in stock_basic
|
|
||||||
should have corresponding data in daily table.
|
|
||||||
"""
|
|
||||||
if daily_df.empty:
|
|
||||||
pytest.fail("daily table is empty for test period")
|
|
||||||
|
|
||||||
# Get unique stock codes from both sources
|
|
||||||
expected_codes = set(stock_basic_df["ts_code"].dropna().unique())
|
|
||||||
actual_codes = set(daily_df["ts_code"].dropna().unique())
|
|
||||||
|
|
||||||
# Check for missing stocks
|
|
||||||
missing_codes = expected_codes - actual_codes
|
|
||||||
|
|
||||||
if missing_codes:
|
|
||||||
missing_list = sorted(missing_codes)
|
|
||||||
# Show first 20 missing stocks as sample
|
|
||||||
sample = missing_list[:20]
|
|
||||||
msg = f"Found {len(missing_codes)} stocks missing from daily table:\n"
|
|
||||||
msg += f"Sample missing: {sample}\n"
|
|
||||||
if len(missing_list) > 20:
|
|
||||||
msg += f"... and {len(missing_list) - 20} more"
|
|
||||||
# 对于3个月数据,允许部分股票缺失(可能是新股或未上市)
|
|
||||||
print(f"[WARNING] {msg}")
|
|
||||||
# 只验证至少有80%的股票存在
|
|
||||||
coverage = len(actual_codes) / len(expected_codes) * 100
|
|
||||||
assert coverage >= 80, (
|
|
||||||
f"Stock coverage {coverage:.1f}% is below 80% threshold"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
print(
|
|
||||||
f"[TEST] All {len(expected_codes)} stocks from stock_basic are present in daily table"
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_no_stock_with_insufficient_data(self, storage, daily_df):
|
|
||||||
"""Verify no stock has abnormally few data points (< 5 rows in 3 months).
|
|
||||||
|
|
||||||
Stocks with very few data points may indicate sync failures,
|
|
||||||
delisted stocks not properly handled, or data corruption.
|
|
||||||
"""
|
|
||||||
if daily_df.empty:
|
|
||||||
pytest.fail("daily table is empty for test period")
|
|
||||||
|
|
||||||
# Count rows per stock
|
|
||||||
stock_counts = daily_df.groupby("ts_code").size()
|
|
||||||
|
|
||||||
# Find stocks with less than 5 data points in 3 months
|
|
||||||
insufficient_stocks = stock_counts[stock_counts < 5]
|
|
||||||
|
|
||||||
if not insufficient_stocks.empty:
|
|
||||||
# Separate into categories for better reporting
|
|
||||||
empty_stocks = stock_counts[stock_counts == 0]
|
|
||||||
very_few_stocks = stock_counts[(stock_counts > 0) & (stock_counts < 5)]
|
|
||||||
|
|
||||||
msg = f"Found {len(insufficient_stocks)} stocks with insufficient data (< 5 rows in 3 months):\n"
|
|
||||||
|
|
||||||
if not empty_stocks.empty:
|
|
||||||
msg += f"\nEmpty stocks (0 rows): {len(empty_stocks)}\n"
|
|
||||||
sample = sorted(empty_stocks.index[:10].tolist())
|
|
||||||
msg += f"Sample: {sample}"
|
|
||||||
|
|
||||||
if not very_few_stocks.empty:
|
|
||||||
msg += f"\nVery few data points (1-4 rows): {len(very_few_stocks)}\n"
|
|
||||||
# Show counts for these stocks
|
|
||||||
sample = very_few_stocks.sort_values().head(20)
|
|
||||||
msg += "Sample (ts_code: count):\n"
|
|
||||||
for code, count in sample.items():
|
|
||||||
msg += f" {code}: {count} rows\n"
|
|
||||||
|
|
||||||
# 对于3个月数据,允许少量异常,但比例不能超过5%
|
|
||||||
if len(insufficient_stocks) / len(stock_counts) > 0.05:
|
|
||||||
pytest.fail(msg)
|
|
||||||
else:
|
|
||||||
print(f"[WARNING] {msg}")
|
|
||||||
|
|
||||||
print(f"[TEST] All stocks have sufficient data (>= 5 rows in 3 months)")
|
|
||||||
|
|
||||||
def test_data_integrity_basic(self, storage, daily_df):
|
|
||||||
"""Basic data integrity checks for daily table."""
|
|
||||||
if daily_df.empty:
|
|
||||||
pytest.fail("daily table is empty for test period")
|
|
||||||
|
|
||||||
# Check required columns exist
|
|
||||||
required_columns = ["ts_code", "trade_date"]
|
|
||||||
missing_columns = [
|
|
||||||
col for col in required_columns if col not in daily_df.columns
|
|
||||||
]
|
|
||||||
|
|
||||||
if missing_columns:
|
|
||||||
pytest.fail(f"Missing required columns: {missing_columns}")
|
|
||||||
|
|
||||||
# Check for null values in key columns
|
|
||||||
null_ts_code = daily_df["ts_code"].isna().sum()
|
|
||||||
null_trade_date = daily_df["trade_date"].isna().sum()
|
|
||||||
|
|
||||||
if null_ts_code > 0:
|
|
||||||
pytest.fail(f"Found {null_ts_code} rows with null ts_code")
|
|
||||||
if null_trade_date > 0:
|
|
||||||
pytest.fail(f"Found {null_trade_date} rows with null trade_date")
|
|
||||||
|
|
||||||
print(f"[TEST] Data integrity check passed for 3-month period")
|
|
||||||
|
|
||||||
def test_polars_export(self, storage):
|
|
||||||
"""Test Polars export functionality."""
|
|
||||||
if not storage.exists("daily"):
|
|
||||||
pytest.skip("daily table not found")
|
|
||||||
|
|
||||||
import polars as pl
|
|
||||||
|
|
||||||
# 测试 load_polars 方法
|
|
||||||
df = storage.load_polars(
|
|
||||||
"daily", start_date=self.TEST_START_DATE, end_date=self.TEST_END_DATE
|
|
||||||
)
|
|
||||||
|
|
||||||
assert isinstance(df, pl.DataFrame), "Should return Polars DataFrame"
|
|
||||||
print(f"[TEST] Polars export successful: {len(df)} rows")
|
|
||||||
|
|
||||||
def test_stock_data_coverage_report(self, storage, daily_df):
|
|
||||||
"""Generate a summary report of stock data coverage.
|
|
||||||
|
|
||||||
This test provides visibility into data distribution without failing.
|
|
||||||
"""
|
|
||||||
if daily_df.empty:
|
|
||||||
pytest.skip("daily table is empty - cannot generate report")
|
|
||||||
|
|
||||||
stock_counts = daily_df.groupby("ts_code").size()
|
|
||||||
|
|
||||||
# Calculate statistics
|
|
||||||
total_stocks = len(stock_counts)
|
|
||||||
min_count = stock_counts.min()
|
|
||||||
max_count = stock_counts.max()
|
|
||||||
median_count = stock_counts.median()
|
|
||||||
mean_count = stock_counts.mean()
|
|
||||||
|
|
||||||
# Distribution buckets (adjusted for 3-month period, ~60 trading days)
|
|
||||||
very_low = (stock_counts < 5).sum()
|
|
||||||
low = ((stock_counts >= 5) & (stock_counts < 20)).sum()
|
|
||||||
medium = ((stock_counts >= 20) & (stock_counts < 40)).sum()
|
|
||||||
high = (stock_counts >= 40).sum()
|
|
||||||
|
|
||||||
report = f"""
|
|
||||||
=== Stock Data Coverage Report (3 months: {self.TEST_START_DATE} to {self.TEST_END_DATE}) ===
|
|
||||||
Total stocks: {total_stocks}
|
|
||||||
Data points per stock:
|
|
||||||
Min: {min_count}
|
|
||||||
Max: {max_count}
|
|
||||||
Median: {median_count:.0f}
|
|
||||||
Mean: {mean_count:.1f}
|
|
||||||
|
|
||||||
Distribution:
|
|
||||||
< 5 rows: {very_low} stocks ({very_low / total_stocks * 100:.1f}%)
|
|
||||||
5-19: {low} stocks ({low / total_stocks * 100:.1f}%)
|
|
||||||
20-39: {medium} stocks ({medium / total_stocks * 100:.1f}%)
|
|
||||||
>= 40: {high} stocks ({high / total_stocks * 100:.1f}%)
|
|
||||||
"""
|
|
||||||
print(report)
|
|
||||||
|
|
||||||
# This is an informational test - it should not fail
|
|
||||||
# But we assert to mark it as passed
|
|
||||||
assert total_stocks > 0
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
pytest.main([__file__, "-v", "-s"])
|
|
||||||
@@ -15,7 +15,7 @@ from datetime import datetime
|
|||||||
|
|
||||||
import polars as pl
|
import polars as pl
|
||||||
|
|
||||||
from src.data.data_router import DatabaseCatalog
|
from src.data.catalog import DatabaseCatalog
|
||||||
from src.factors.engine import FactorEngine
|
from src.factors.engine import FactorEngine
|
||||||
from src.factors.api import close, open, ts_mean, cs_rank
|
from src.factors.api import close, open, ts_mean, cs_rank
|
||||||
|
|
||||||
@@ -215,7 +215,7 @@ def run_factor_integration_test():
|
|||||||
print("=" * 80)
|
print("=" * 80)
|
||||||
|
|
||||||
print("\n[4.1] 重新构建 Context LazyFrame 并打印前 5 行...")
|
print("\n[4.1] 重新构建 Context LazyFrame 并打印前 5 行...")
|
||||||
from src.data.data_router import build_context_lazyframe
|
from src.data.catalog import build_context_lazyframe
|
||||||
|
|
||||||
context_lf = build_context_lazyframe(
|
context_lf = build_context_lazyframe(
|
||||||
required_fields=["close", "open"],
|
required_fields=["close", "open"],
|
||||||
|
|||||||
244
tests/test_financial_price_merge.py
Normal file
244
tests/test_financial_price_merge.py
Normal file
@@ -0,0 +1,244 @@
|
|||||||
|
"""财务数据与行情数据拼接测试。
|
||||||
|
|
||||||
|
测试场景:
|
||||||
|
1. 普通财务数据:正常公告,之后无修改
|
||||||
|
2. 隔日修改:公告后几天发布修正版
|
||||||
|
3. 当日修改:同一天发布多版,取 update_flag=1 的
|
||||||
|
4. 边界条件:财务数据缺失、行情数据早于最早财务数据
|
||||||
|
"""
|
||||||
|
|
||||||
|
import polars as pl
|
||||||
|
from datetime import date
|
||||||
|
from src.data.financial_loader import FinancialLoader
|
||||||
|
|
||||||
|
|
||||||
|
def create_mock_price_data() -> pl.DataFrame:
|
||||||
|
"""创建模拟行情数据。"""
|
||||||
|
return pl.DataFrame(
|
||||||
|
{
|
||||||
|
"ts_code": ["000001.SZ"] * 10,
|
||||||
|
"trade_date": [
|
||||||
|
"20240101",
|
||||||
|
"20240102",
|
||||||
|
"20240103",
|
||||||
|
"20240104",
|
||||||
|
"20240105",
|
||||||
|
"20240108",
|
||||||
|
"20240109",
|
||||||
|
"20240110",
|
||||||
|
"20240111",
|
||||||
|
"20240112",
|
||||||
|
],
|
||||||
|
"close": [10.0, 10.2, 10.3, 10.1, 10.5, 10.6, 10.4, 10.7, 10.8, 10.9],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def create_mock_financial_data() -> pl.DataFrame:
|
||||||
|
"""创建模拟财务数据(覆盖多种场景)。
|
||||||
|
|
||||||
|
注意:f_ann_date 必须是 Date 类型(与数据库保持一致)。
|
||||||
|
"""
|
||||||
|
return pl.DataFrame(
|
||||||
|
{
|
||||||
|
"ts_code": ["000001.SZ", "000001.SZ", "000001.SZ", "000001.SZ"],
|
||||||
|
# 场景1: 2023Q3 报告,正常公告
|
||||||
|
# 场景2: 同日多版(update_flag 区分)
|
||||||
|
# 场景3: 隔日修改
|
||||||
|
"f_ann_date": [
|
||||||
|
date(2024, 1, 2),
|
||||||
|
date(2024, 1, 2),
|
||||||
|
date(2024, 1, 5),
|
||||||
|
date(2024, 1, 10),
|
||||||
|
],
|
||||||
|
"end_date": ["20230930", "20230930", "20230930", "20231231"],
|
||||||
|
"report_type": [1, 1, 1, 1], # 整数类型(与数据库一致)
|
||||||
|
"update_flag": [0, 1, 1, 1], # 整数类型(与数据库一致)
|
||||||
|
"net_profit": [1000000.0, 1100000.0, 1100000.0, 1200000.0],
|
||||||
|
"revenue": [5000000.0, 5200000.0, 5200000.0, 6000000.0],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_financial_data_cleaning():
|
||||||
|
"""测试财务数据清洗逻辑。"""
|
||||||
|
print("=== 测试 1: 财务数据清洗 ===")
|
||||||
|
|
||||||
|
df_finance = create_mock_financial_data()
|
||||||
|
print("原始财务数据:")
|
||||||
|
print(df_finance)
|
||||||
|
|
||||||
|
loader = FinancialLoader()
|
||||||
|
|
||||||
|
# 手动执行清洗(模拟 load_financial_data 的逻辑)
|
||||||
|
# 步骤1: 仅保留合并报表
|
||||||
|
df = df_finance.filter(pl.col("report_type") == 1)
|
||||||
|
|
||||||
|
# 步骤2: 按 update_flag 降序排列后去重
|
||||||
|
df = df.with_columns(
|
||||||
|
[pl.col("update_flag").cast(pl.Int32).alias("update_flag_int")]
|
||||||
|
)
|
||||||
|
|
||||||
|
df = df.sort(
|
||||||
|
["ts_code", "f_ann_date", "update_flag_int"], descending=[False, False, True]
|
||||||
|
)
|
||||||
|
|
||||||
|
df = df.unique(subset=["ts_code", "f_ann_date"], keep="first")
|
||||||
|
df = df.drop("update_flag_int")
|
||||||
|
|
||||||
|
# 步骤3: 排序(f_ann_date 已经是 Date 类型)
|
||||||
|
df = df.sort(["ts_code", "f_ann_date"])
|
||||||
|
|
||||||
|
print("\n清洗后的财务数据:")
|
||||||
|
print(df)
|
||||||
|
|
||||||
|
# 验证:应该有3条记录(第1-2行去重为1条,第3行,第4行)
|
||||||
|
assert len(df) == 3, f"清洗后应该有3条记录,实际有 {len(df)} 条"
|
||||||
|
|
||||||
|
# 验证:2024-01-02 的 update_flag 应该是 1
|
||||||
|
row_jan02 = df.filter(pl.col("f_ann_date") == date(2024, 1, 2))
|
||||||
|
assert len(row_jan02) == 1, "应该有1条 2024-01-02 的记录"
|
||||||
|
assert row_jan02["update_flag"][0] == 1, "update_flag 应该为 1"
|
||||||
|
assert row_jan02["net_profit"][0] == 1100000.0, "net_profit 应该为 1100000"
|
||||||
|
|
||||||
|
print("\n[通过] 财务数据清洗测试通过!")
|
||||||
|
return df
|
||||||
|
|
||||||
|
|
||||||
|
def test_financial_price_merge():
|
||||||
|
"""测试财务数据拼接逻辑(无未来函数验证)。"""
|
||||||
|
print("\n=== 测试 2: 财务数据与行情数据拼接 ===")
|
||||||
|
|
||||||
|
df_price = create_mock_price_data()
|
||||||
|
df_finance_raw = create_mock_financial_data()
|
||||||
|
|
||||||
|
loader = FinancialLoader()
|
||||||
|
|
||||||
|
# 步骤1: 清洗财务数据(手动执行)
|
||||||
|
# 注意:f_ann_date 已经是 Date 类型,不需要转换
|
||||||
|
df_finance = df_finance_raw.filter(pl.col("report_type") == 1)
|
||||||
|
df_finance = df_finance.with_columns(
|
||||||
|
[pl.col("update_flag").cast(pl.Int32).alias("update_flag_int")]
|
||||||
|
)
|
||||||
|
df_finance = df_finance.sort(
|
||||||
|
["ts_code", "f_ann_date", "update_flag_int"], descending=[False, False, True]
|
||||||
|
)
|
||||||
|
df_finance = df_finance.unique(subset=["ts_code", "f_ann_date"], keep="first")
|
||||||
|
df_finance = df_finance.drop("update_flag_int")
|
||||||
|
df_finance = df_finance.sort(["ts_code", "f_ann_date"])
|
||||||
|
|
||||||
|
print("清洗后的财务数据:")
|
||||||
|
print(df_finance)
|
||||||
|
|
||||||
|
# 步骤2: 转换行情数据日期为 Date 类型
|
||||||
|
df_price = df_price.with_columns(
|
||||||
|
[pl.col("trade_date").str.strptime(pl.Date, "%Y%m%d").alias("trade_date")]
|
||||||
|
)
|
||||||
|
df_price = df_price.sort(["ts_code", "trade_date"])
|
||||||
|
|
||||||
|
# 步骤3: 拼接
|
||||||
|
financial_cols = ["net_profit", "revenue"]
|
||||||
|
merged = loader.merge_financial_with_price(df_price, df_finance, financial_cols)
|
||||||
|
|
||||||
|
# 步骤4: 转回字符串格式
|
||||||
|
merged = merged.with_columns(
|
||||||
|
[pl.col("trade_date").dt.strftime("%Y%m%d").alias("trade_date")]
|
||||||
|
)
|
||||||
|
|
||||||
|
print("\n拼接结果:")
|
||||||
|
print(merged)
|
||||||
|
|
||||||
|
# 验证无未来函数:
|
||||||
|
# 20240101 之前不应有 2023Q3 数据(因为 20240102 才公告)
|
||||||
|
jan01 = merged.filter(pl.col("trade_date") == "20240101")
|
||||||
|
assert jan01["net_profit"].is_null().all(), (
|
||||||
|
"2024-01-01 不应有 2023Q3 数据(尚未公告)"
|
||||||
|
)
|
||||||
|
print("[验证 1] 2024-01-01 net_profit 为 null - 正确(公告前无数据)")
|
||||||
|
|
||||||
|
# 20240102 及之后应该看到 net_profit=1100000(update_flag=1 的版本)
|
||||||
|
jan02 = merged.filter(pl.col("trade_date") == "20240102")
|
||||||
|
assert jan02["net_profit"][0] == 1100000.0, "2024-01-02 应使用 update_flag=1 的数据"
|
||||||
|
print("[验证 2] 2024-01-02 net_profit=1100000 - 正确(使用 update_flag=1)")
|
||||||
|
|
||||||
|
# 20240104 应延续使用 2023Q3 数据
|
||||||
|
jan04 = merged.filter(pl.col("trade_date") == "20240104")
|
||||||
|
assert jan04["net_profit"][0] == 1100000.0, "2024-01-04 应延续使用 2023Q3 数据"
|
||||||
|
print("[验证 3] 2024-01-04 net_profit=1100000 - 正确(延续使用)")
|
||||||
|
|
||||||
|
# 20240110 应切换到 2023Q4 数据(新公告)
|
||||||
|
jan10 = merged.filter(pl.col("trade_date") == "20240110")
|
||||||
|
assert jan10["net_profit"][0] == 1200000.0, "2024-01-10 应切换到 2023Q4 数据"
|
||||||
|
print("[验证 4] 2024-01-10 net_profit=1200000 - 正确(新财报公告)")
|
||||||
|
|
||||||
|
# 20240112 应继续延续使用 2023Q4 数据
|
||||||
|
jan12 = merged.filter(pl.col("trade_date") == "20240112")
|
||||||
|
assert jan12["net_profit"][0] == 1200000.0, "2024-01-12 应继续使用 2023Q4 数据"
|
||||||
|
print("[验证 5] 2024-01-12 net_profit=1200000 - 正确(延续使用)")
|
||||||
|
|
||||||
|
print("\n[通过] 所有验证通过,无未来函数!")
|
||||||
|
return merged
|
||||||
|
|
||||||
|
|
||||||
|
def test_empty_financial_data():
|
||||||
|
"""测试财务数据为空的情况。"""
|
||||||
|
print("\n=== 测试 3: 空财务数据场景 ===")
|
||||||
|
|
||||||
|
df_price = create_mock_price_data()
|
||||||
|
df_empty = pl.DataFrame()
|
||||||
|
|
||||||
|
loader = FinancialLoader()
|
||||||
|
|
||||||
|
# 转换行情数据日期为 Date 类型
|
||||||
|
df_price = df_price.with_columns(
|
||||||
|
[pl.col("trade_date").str.strptime(pl.Date, "%Y%m%d").alias("trade_date")]
|
||||||
|
)
|
||||||
|
df_price = df_price.sort(["ts_code", "trade_date"])
|
||||||
|
|
||||||
|
# 拼接空财务数据
|
||||||
|
merged = loader.merge_financial_with_price(df_price, df_empty, ["net_profit"])
|
||||||
|
|
||||||
|
# 转回字符串格式
|
||||||
|
merged = merged.with_columns(
|
||||||
|
[pl.col("trade_date").dt.strftime("%Y%m%d").alias("trade_date")]
|
||||||
|
)
|
||||||
|
|
||||||
|
# 验证财务列为空
|
||||||
|
assert merged["net_profit"].is_null().all(), (
|
||||||
|
"财务数据为空时,net_profit 应全为 null"
|
||||||
|
)
|
||||||
|
|
||||||
|
print("空财务数据拼接结果:")
|
||||||
|
print(merged)
|
||||||
|
print("\n[通过] 空财务数据场景测试通过!")
|
||||||
|
|
||||||
|
|
||||||
|
def run_all_tests():
|
||||||
|
"""运行所有测试。"""
|
||||||
|
print("开始运行财务数据拼接功能测试...\n")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 测试 1: 数据清洗
|
||||||
|
test_financial_data_cleaning()
|
||||||
|
|
||||||
|
# 测试 2: 数据拼接
|
||||||
|
test_financial_price_merge()
|
||||||
|
|
||||||
|
# 测试 3: 空数据场景
|
||||||
|
test_empty_financial_data()
|
||||||
|
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("所有测试通过!")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
except AssertionError as e:
|
||||||
|
print(f"\n[失败] 测试断言失败: {e}")
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
print(f"\n[错误] 测试执行出错: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
run_all_tests()
|
||||||
Reference in New Issue
Block a user