diff --git a/docs/plan/training_module_plan.md b/docs/plan/training_module_plan.md new file mode 100644 index 0000000..2f31aa7 --- /dev/null +++ b/docs/plan/training_module_plan.md @@ -0,0 +1,1089 @@ +# 训练模块实现计划 + +## 1. 概述 + +本计划描述 ProStock 训练模块的完整实现方案,支持标准回归模型训练并输出预测结果。 + +### 1.1 目标 +- 提供简洁的训练流程 +- 支持标准回归模型(LightGBM、CatBoost) +- 输出可解释的预测结果 +- 与现有因子引擎无缝集成 + +### 1.2 设计原则 +- **职责分离**:训练流程与建模组件分离 +- **注册机制**:使用装饰器实现即插即用 +- **配置驱动**:所有参数通过配置类管理 +- **阶段感知**:Processor 在训练和测试阶段行为不同 +- **每日筛选**:股票池每日独立筛选(市值动态变化) +- **模型持久化**:支持保存和加载训练好的模型 + +## 2. 模块结构 + +``` +src/ +├── training/ # 训练模块(流程 + 组件) +│ ├── __init__.py # 导出核心类 +│ ├── core/ # 训练流程核心 +│ │ ├── __init__.py +│ │ ├── trainer.py # Trainer 主类 +│ │ └── stock_pool_manager.py # 股票池管理器(每日独立筛选) +│ ├── components/ # 建模组件 +│ │ ├── __init__.py +│ │ ├── base.py # BaseModel, BaseProcessor 抽象基类 +│ │ ├── splitters.py # 时间序列划分策略(一次性划分) +│ │ ├── selectors.py # 股票池选择器配置 +│ │ ├── models/ # 模型实现 +│ │ │ ├── __init__.py +│ │ │ ├── lightgbm.py # LightGBM 回归模型 +│ │ │ └── catboost.py # CatBoost 回归模型 +│ │ ├── processors/ # 数据处理器 +│ │ │ ├── __init__.py +│ │ │ └── transforms.py # 标准化(截面/时序)、缩尾 +│ │ └── metrics/ # 评估指标 +│ │ ├── __init__.py +│ │ └── metrics.py # IC, RankIC, MSE, MAE +│ ├── config/ # 配置管理 +│ │ ├── __init__.py +│ │ └── config.py # TrainingConfig (pydantic) +│ └── registry.py # 组件注册中心 +│ +└── experiment/ # 实验管理(预留结构,暂不实现) + └── __init__.py +``` + +## 3. 
核心组件设计 + +### 3.1 抽象基类 (components/base.py) + +#### BaseModel +```python +class BaseModel(ABC): + """模型基类""" + + name: str = "" # 模型名称 + + def fit(self, X: pl.DataFrame, y: pl.Series) -> "BaseModel": + """训练模型""" + raise NotImplementedError + + def predict(self, X: pl.DataFrame) -> np.ndarray: + """预测""" + raise NotImplementedError + + def feature_importance(self) -> Optional[pd.Series]: + """特征重要性""" + return None + + def save(self, path: str) -> None: + """保存模型到文件 + + 默认实现使用 pickle,子类可覆盖 + """ + import pickle + with open(path, 'wb') as f: + pickle.dump(self, f) + + @classmethod + def load(cls, path: str) -> "BaseModel": + """从文件加载模型""" + import pickle + with open(path, 'rb') as f: + return pickle.load(f) +``` + +#### BaseProcessor +```python +class BaseProcessor(ABC): + """数据处理器基类 + + 重要:Processor 在不同阶段行为不同: + - 训练阶段:fit_transform(学习参数并应用) + - 验证/测试阶段:transform(使用训练阶段学到的参数) + + 这意味着 Processor 实例会在训练后被保存, + 用于后续的验证和测试数据转换。 + """ + + name: str = "" + + def fit(self, X: pl.DataFrame) -> "BaseProcessor": + """学习参数(仅在训练阶段调用)""" + return self + + def transform(self, X: pl.DataFrame) -> pl.DataFrame: + """转换数据""" + raise NotImplementedError + + def fit_transform(self, X: pl.DataFrame) -> pl.DataFrame: + """拟合并转换(训练阶段使用)""" + return self.fit(X).transform(X) +``` + +### 3.2 时间序列划分 (components/splitters.py) + +**设计说明**:暂不实现滚动训练,采用一次性训练/测试划分。 + +#### DateSplitter +```python +class DateSplitter: + """基于日期范围的一次性划分 + + 将数据按日期划分为训练集和测试集,不滚动。 + + 示例: + train_start: "20200101", train_end: "20221231" (训练集:3年) + test_start: "20230101", test_end: "20231231" (测试集:1年) + + 特点: + - 一次性划分,不滚动 + - 训练集和测试集互不重叠 + - 基于实际日期范围,而非行数 + """ + + def __init__( + self, + train_start: str, # 训练期开始日期 "YYYYMMDD" + train_end: str, # 训练期结束日期 "YYYYMMDD" + test_start: str, # 测试期开始日期 "YYYYMMDD" + test_end: str, # 测试期结束日期 "YYYYMMDD" + ): + self.train_start = train_start + self.train_end = train_end + self.test_start = test_start + self.test_end = test_end + + def split(self, data: pl.DataFrame) -> 
Tuple[pl.DataFrame, pl.DataFrame]: + """划分数据为训练集和测试集""" + train_data = data.filter( + (pl.col("trade_date") >= self.train_start) & + (pl.col("trade_date") <= self.train_end) + ) + test_data = data.filter( + (pl.col("trade_date") >= self.test_start) & + (pl.col("trade_date") <= self.test_end) + ) + return train_data, test_data +``` + +### 3.3 股票池选择器配置 (components/selectors.py) + +**设计说明**:股票池每日独立筛选,市值选择需要配合 StockPoolManager 使用。 + +#### StockFilterConfig +```python +@dataclass +class StockFilterConfig: + """股票过滤器配置 + + 用于过滤掉不需要的股票(如创业板、科创板等)。 + 基于股票代码进行过滤,不依赖外部数据。 + """ + + exclude_cyb: bool = True # 是否排除创业板(300xxx) + exclude_kcb: bool = True # 是否排除科创板(688xxx) + exclude_bj: bool = True # 是否排除北交所(8xxxxx, 4xxxxx) + exclude_st: bool = True # 是否排除ST股票 + + def filter_codes(self, codes: List[str]) -> List[str]: + """应用过滤条件,返回过滤后的股票代码列表""" + result = [] + for code in codes: + if self.exclude_cyb and code.startswith("300"): + continue + if self.exclude_kcb and code.startswith("688"): + continue + if self.exclude_bj and (code.startswith("8") or code.startswith("4")): + continue + # ST 股票过滤需要额外数据,在 StockPoolManager 中处理 + result.append(code) + return result +``` + +#### MarketCapSelectorConfig +```python +@dataclass +class MarketCapSelectorConfig: + """市值选择器配置 + + 每日独立选择市值最大或最小的 n 只股票。 + 市值数据从 daily_basic 表独立获取,仅用于筛选。 + """ + + enabled: bool = True # 是否启用选择 + n: int = 100 # 选择前 n 只 + ascending: bool = False # False=最大市值, True=最小市值 + market_cap_col: str = "total_mv" # 市值列名(来自 daily_basic) +``` + +### 3.4 股票池管理器 (core/stock_pool_manager.py) + +**设计说明**:每日独立筛选股票池,市值数据从 daily_basic 表独立获取。 + +```python +class StockPoolManager: + """股票池管理器 - 每日独立筛选 + + 重要约束: + 1. 市值数据仅从 daily_basic 表获取,仅用于筛选 + 2. 市值数据绝不混入特征矩阵 + 3. 
每日独立筛选(市值是动态变化的) + + 处理流程(每日): + 当日所有股票 + ↓ + 代码过滤(创业板、ST等) + ↓ + 查询 daily_basic 获取当日市值 + ↓ + 市值选择(前N只) + ↓ + 返回当日选中股票列表 + """ + + def __init__( + self, + filter_config: StockFilterConfig, + selector_config: Optional[MarketCapSelectorConfig], + data_router: DataRouter, # 用于获取 daily_basic 数据 + code_col: str = "ts_code", + date_col: str = "trade_date", + ): + self.filter_config = filter_config + self.selector_config = selector_config + self.data_router = data_router + self.code_col = code_col + self.date_col = date_col + + def filter_and_select_daily(self, data: pl.DataFrame) -> pl.DataFrame: + """每日独立筛选股票池 + + Args: + data: 因子计算后的全市场数据,必须包含 trade_date 和 ts_code 列 + + Returns: + 筛选后的数据,仅包含每日选中的股票 + + 注意: + - 按日期分组处理 + - 市值数据从 daily_basic 独立获取 + - 保持市值数据与特征数据隔离 + """ + dates = data.select(self.date_col).unique().sort(self.date_col) + + result_frames = [] + for date in dates.to_series(): + # 获取当日数据 + daily_data = data.filter(pl.col(self.date_col) == date) + daily_codes = daily_data.select(self.code_col).to_series().to_list() + + # 1. 代码过滤 + filtered_codes = self.filter_config.filter_codes(daily_codes) + + # 2. 市值选择(如果启用) + if self.selector_config and self.selector_config.enabled: + # 从 daily_basic 获取当日市值 + market_caps = self._get_market_caps_for_date(filtered_codes, date) + selected_codes = self._select_by_market_cap(filtered_codes, market_caps) + else: + selected_codes = filtered_codes + + # 3. 
保留当日选中的股票数据 + daily_selected = daily_data.filter( + pl.col(self.code_col).is_in(selected_codes) + ) + result_frames.append(daily_selected) + + return pl.concat(result_frames) + + def _get_market_caps_for_date( + self, + codes: List[str], + date: str + ) -> Dict[str, float]: + """从 daily_basic 表获取指定日期的市值数据 + + Args: + codes: 股票代码列表 + date: 日期 "YYYYMMDD" + + Returns: + {股票代码: 市值} 的字典 + """ + # 通过 data_router 查询 daily_basic 表 + pass + + def _select_by_market_cap( + self, + codes: List[str], + market_caps: Dict[str, float] + ) -> List[str]: + """根据市值选择股票""" + if not market_caps: + return codes + + # 按市值排序并选择前N只 + sorted_codes = sorted( + codes, + key=lambda c: market_caps.get(c, 0), + reverse=not self.selector_config.ascending + ) + return sorted_codes[:self.selector_config.n] +``` + +### 3.5 模型实现 (components/models/) + +#### LightGBMModel +```python +@register_model("lightgbm") +class LightGBMModel(BaseModel): + """LightGBM 回归模型""" + + name = "lightgbm" + + def __init__(self, + objective: str = "regression", + metric: str = "rmse", + num_leaves: int = 31, + learning_rate: float = 0.05, + n_estimators: int = 100, + **kwargs): + self.params = { + "objective": objective, + "metric": metric, + "num_leaves": num_leaves, + "learning_rate": learning_rate, + "n_estimators": n_estimators, + **kwargs + } + self.model = None + + def fit(self, X: pl.DataFrame, y: pl.Series) -> "LightGBMModel": + """训练模型""" + import lightgbm as lgb + + # 转换为 numpy + X_np = X.to_numpy() + y_np = y.to_numpy() + + # 创建数据集 + train_data = lgb.Dataset(X_np, label=y_np) + + # 训练 + self.model = lgb.train( + self.params, + train_data, + num_boost_round=self.params.get("n_estimators", 100) + ) + return self + + def predict(self, X: pl.DataFrame) -> np.ndarray: + """预测""" + if self.model is None: + raise RuntimeError("Model not fitted yet") + X_np = X.to_numpy() + return self.model.predict(X_np) + + def feature_importance(self) -> pd.Series: + """返回特征重要性""" + if self.model is None: + return None + importance 
= self.model.feature_importance(importance_type="gain") + return pd.Series(importance, index=self.feature_names_) + + def save(self, path: str) -> None: + """保存模型(使用 LightGBM 原生格式)""" + if self.model is None: + raise RuntimeError("Model not fitted yet") + self.model.save_model(path) + + @classmethod + def load(cls, path: str) -> "LightGBMModel": + """加载模型""" + import lightgbm as lgb + instance = cls() + instance.model = lgb.Booster(model_file=path) + return instance +``` + +### 3.6 数据处理器 (components/processors/) + +#### StandardScaler +```python +@register_processor("standard_scaler") +class StandardScaler(BaseProcessor): + """标准化处理器(时序标准化) + + 在整个训练集上学习均值和标准差, + 然后应用到训练集和测试集。 + """ + + name = "standard_scaler" + + def __init__(self, exclude_cols: List[str] = None): + self.exclude_cols = exclude_cols or ["ts_code", "trade_date"] + self.mean_ = {} + self.std_ = {} + + def fit(self, X: pl.DataFrame) -> "StandardScaler": + """计算均值和标准差(仅在训练集上)""" + numeric_cols = [ + c for c in X.columns + if c not in self.exclude_cols and X[c].dtype.is_numeric() + ] + + for col in numeric_cols: + self.mean_[col] = X[col].mean() + self.std_[col] = X[col].std() + + return self + + def transform(self, X: pl.DataFrame) -> pl.DataFrame: + """标准化(使用训练集学到的参数)""" + expressions = [] + for col in X.columns: + if col in self.mean_: + expr = ((pl.col(col) - self.mean_[col]) / self.std_[col]).alias(col) + expressions.append(expr) + else: + expressions.append(pl.col(col)) + + return X.select(expressions) +``` + +#### CrossSectionalStandardScaler +```python +@register_processor("cs_standard_scaler") +class CrossSectionalStandardScaler(BaseProcessor): + """截面标准化处理器 + + 每天独立进行标准化:对当天所有股票的某一因子进行标准化。 + + 特点: + - 不需要 fit,每天独立计算当天的均值和标准差 + - 适用于截面因子,消除市值等行业差异 + - 公式:z = (x - mean_today) / std_today + """ + + name = "cs_standard_scaler" + + def __init__(self, exclude_cols: List[str] = None, date_col: str = "trade_date"): + self.exclude_cols = exclude_cols or ["ts_code", "trade_date"] + self.date_col = 
date_col + + def transform(self, X: pl.DataFrame) -> pl.DataFrame: + """截面标准化 + + 按日期分组,每天独立计算均值和标准差并标准化。 + 不需要 fit,因为每天使用当天的统计量。 + """ + numeric_cols = [ + c for c in X.columns + if c not in self.exclude_cols and X[c].dtype.is_numeric() + ] + + # 按日期分组标准化 + result = X.with_columns([ + pl.col(col).mean().over(self.date_col).alias(f"{col}_mean") + for col in numeric_cols + ] + [ + pl.col(col).std().over(self.date_col).alias(f"{col}_std") + for col in numeric_cols + ]) + + # 计算标准化值 + for col in numeric_cols: + result = result.with_columns([ + ((pl.col(col) - pl.col(f"{col}_mean")) / pl.col(f"{col}_std")).alias(col) + ]) + # 删除中间列 + result = result.drop([f"{col}_mean", f"{col}_std"]) + + return result +``` + +#### Winsorizer +```python +@register_processor("winsorizer") +class Winsorizer(BaseProcessor): + """缩尾处理器 + + 对每一列的极端值进行截断处理。 + 可以全局截断(在整个训练集上学习分位数), + 也可以截面截断(每天独立处理)。 + """ + + name = "winsorizer" + + def __init__( + self, + lower: float = 0.01, + upper: float = 0.99, + by_date: bool = False, # True=每天独立缩尾, False=全局缩尾 + date_col: str = "trade_date" + ): + self.lower = lower + self.upper = upper + self.by_date = by_date + self.date_col = date_col + self.bounds_ = {} # 存储分位数边界(全局模式) + + def fit(self, X: pl.DataFrame) -> "Winsorizer": + """学习分位数边界(仅在全局模式下)""" + if not self.by_date: + numeric_cols = [c for c in X.columns if X[c].dtype.is_numeric()] + for col in numeric_cols: + self.bounds_[col] = { + "lower": X[col].quantile(self.lower), + "upper": X[col].quantile(self.upper) + } + return self + + def transform(self, X: pl.DataFrame) -> pl.DataFrame: + """缩尾处理""" + if self.by_date: + # 每天独立缩尾 + return self._transform_by_date(X) + else: + # 全局缩尾 + return self._transform_global(X) + + def _transform_global(self, X: pl.DataFrame) -> pl.DataFrame: + """全局缩尾(使用训练集学到的边界)""" + expressions = [] + for col in X.columns: + if col in self.bounds_: + lower = self.bounds_[col]["lower"] + upper = self.bounds_[col]["upper"] + expr = pl.col(col).clip(lower, upper).alias(col) + 
expressions.append(expr) + else: + expressions.append(pl.col(col)) + return X.select(expressions) + + def _transform_by_date(self, X: pl.DataFrame) -> pl.DataFrame: + """每日独立缩尾""" + # 按日期分组计算分位数并截断 + # Polars 实现... + pass +``` + +### 3.7 评估指标 (components/metrics/) + +```python +def ic_score(y_true: np.ndarray, y_pred: np.ndarray) -> float: + """信息系数 (Pearson 相关系数)""" + from scipy.stats import pearsonr + corr, _ = pearsonr(y_true, y_pred) + return corr + +def rank_ic_score(y_true: np.ndarray, y_pred: np.ndarray) -> float: + """秩相关系数 (Spearman 相关系数)""" + from scipy.stats import spearmanr + corr, _ = spearmanr(y_true, y_pred) + return corr + +def mse_score(y_true: np.ndarray, y_pred: np.ndarray) -> float: + """均方误差""" + return np.mean((y_true - y_pred) ** 2) + +def mae_score(y_true: np.ndarray, y_pred: np.ndarray) -> float: + """平均绝对误差""" + return np.mean(np.abs(y_true - y_pred)) +``` + +## 4. 训练流程设计 + +### 4.1 Trainer 主类 (core/trainer.py) + +```python +class Trainer: + """训练器主类 + + 整合数据处理、模型训练、评估的完整流程。 + + 关键设计: + 1. 因子先计算(全市场),再筛选股票池(每日独立) + 2. Processor 分阶段行为:训练集 fit_transform,测试集 transform + 3. 一次性训练,不滚动 + 4. 
支持模型持久化 + """ + + def __init__( + self, + model: BaseModel, + pool_manager: Optional[StockPoolManager] = None, + processors: List[BaseProcessor] = None, + splitter: DateSplitter = None, + metrics: List[str] = None, + target_col: str = "target", + feature_cols: List[str] = None, + persist_model: bool = False, + model_save_path: Optional[str] = None, + ): + self.model = model + self.pool_manager = pool_manager + self.processors = processors or [] + self.splitter = splitter + self.metrics = metrics or ["ic", "rank_ic", "mse", "mae"] + self.target_col = target_col + self.feature_cols = feature_cols + self.persist_model = persist_model + self.model_save_path = model_save_path + + # 存储训练后的处理器 + self.fitted_processors: List[BaseProcessor] = [] + self.results: pl.DataFrame = None + self.metrics_results: Dict[str, float] = {} + + def train(self, data: pl.DataFrame) -> "Trainer": + """执行训练流程 + + 流程: + 1. 股票池每日筛选(如果配置了 pool_manager) + 2. 按日期划分训练集/测试集 + 3. 训练集:processors fit_transform + 4. 训练模型 + 5. 测试集:processors transform(使用训练集学到的参数) + 6. 预测并评估 + 7. 持久化模型(如果启用) + + Args: + data: 因子计算后的全市场数据 + 必须包含 ts_code 和 trade_date 列 + + Returns: + self (支持链式调用) + """ + # 1. 股票池筛选(每日独立) + if self.pool_manager: + print("[筛选] 每日独立筛选股票池...") + data = self.pool_manager.filter_and_select_daily(data) + + # 2. 划分训练/测试集 + print("[划分] 划分训练集和测试集...") + train_data, test_data = self.splitter.split(data) + + # 3. 训练集:processors fit_transform + print("[处理] 处理训练集...") + for processor in self.processors: + train_data = processor.fit_transform(train_data) + self.fitted_processors.append(processor) + + # 4. 训练模型 + print("[训练] 训练模型...") + X_train = train_data.select(self.feature_cols) + y_train = train_data.select(self.target_col).to_series() + self.model.fit(X_train, y_train) + + # 5. 测试集:processors transform + print("[处理] 处理测试集...") + for processor in self.fitted_processors: + test_data = processor.transform(test_data) + + # 6. 
预测 + print("[预测] 生成预测...") + X_test = test_data.select(self.feature_cols) + predictions = self.model.predict(X_test) + + # 7. 评估 + print("[评估] 计算指标...") + y_test = test_data.select(self.target_col).to_series().to_numpy() + self._evaluate(y_test, predictions) + + # 8. 保存结果 + self.results = test_data.with_columns([ + pl.Series("prediction", predictions) + ]) + + # 9. 持久化模型 + if self.persist_model and self.model_save_path: + print(f"[保存] 保存模型到 {self.model_save_path}...") + self.save_model(self.model_save_path) + + return self + + def _evaluate(self, y_true: np.ndarray, y_pred: np.ndarray): + """计算评估指标""" + from .components.metrics import ic_score, rank_ic_score, mse_score, mae_score + + metric_funcs = { + "ic": ic_score, + "rank_ic": rank_ic_score, + "mse": mse_score, + "mae": mae_score, + } + + for metric in self.metrics: + if metric in metric_funcs: + self.metrics_results[metric] = metric_funcs[metric](y_true, y_pred) + + def predict(self, data: pl.DataFrame) -> pl.DataFrame: + """对新数据进行预测 + + 注意:新数据需要先经过股票池筛选, + 然后使用训练好的 processors 和 model 进行预测。 + """ + # 应用 processors + for processor in self.fitted_processors: + data = processor.transform(data) + + # 预测 + X = data.select(self.feature_cols) + predictions = self.model.predict(X) + + return data.with_columns([pl.Series("prediction", predictions)]) + + def get_results(self) -> pl.DataFrame: + """获取所有预测结果""" + return self.results + + def get_metrics(self) -> Dict[str, float]: + """获取评估指标""" + return self.metrics_results + + def save_results(self, path: str): + """保存预测结果到文件""" + if self.results is not None: + self.results.write_csv(path) + + def save_model(self, path: str): + """保存模型""" + self.model.save(path) +``` + +### 4.2 配置类 (config/config.py) + +```python +from pydantic import BaseSettings, Field +from typing import List, Dict, Optional +from dataclasses import dataclass + +@dataclass +class StockFilterConfig: + """股票过滤器配置""" + exclude_cyb: bool = True # 排除创业板 + exclude_kcb: bool = True # 排除科创板 + exclude_bj: bool 
= True # 排除北交所 + exclude_st: bool = True # 排除ST股票 + +@dataclass +class MarketCapSelectorConfig: + """市值选择器配置""" + enabled: bool = True # 是否启用 + n: int = 100 # 选择前 n 只 + ascending: bool = False # False=最大市值, True=最小市值 + market_cap_col: str = "total_mv" # 市值列名(来自 daily_basic) + +@dataclass +class ProcessorConfig: + """处理器配置""" + name: str + params: Dict = Field(default_factory=dict) + +class TrainingConfig(BaseSettings): + """训练配置类""" + + # === 数据配置(必填)=== + feature_cols: List[str] = Field(..., min_items=1) # 特征列名,至少一个 + target_col: str = "target" # 目标变量列名 + date_col: str = "trade_date" # 日期列名 + code_col: str = "ts_code" # 股票代码列名 + + # === 日期划分(必填)=== + train_start: str = Field(..., description="训练期开始 YYYYMMDD") + train_end: str = Field(..., description="训练期结束 YYYYMMDD") + test_start: str = Field(..., description="测试期开始 YYYYMMDD") + test_end: str = Field(..., description="测试期结束 YYYYMMDD") + + # === 股票池配置 === + stock_filter: StockFilterConfig = Field( + default_factory=lambda: StockFilterConfig( + exclude_cyb=True, + exclude_kcb=True, + exclude_bj=True, + exclude_st=True, + ) + ) + stock_selector: Optional[MarketCapSelectorConfig] = Field( + default_factory=lambda: MarketCapSelectorConfig( + enabled=True, + n=100, + ascending=False, + market_cap_col="total_mv", + ) + ) + # 注意:如果 stock_selector = None,则跳过市值选择 + + # === 模型配置 === + model_type: str = "lightgbm" + model_params: Dict = Field(default_factory=dict) + + # === 处理器配置 === + processors: List[ProcessorConfig] = Field( + default_factory=lambda: [ + ProcessorConfig(name="winsorizer", params={"lower": 0.01, "upper": 0.99}), + ProcessorConfig(name="cs_standard_scaler", params={}), + ] + ) + + # === 评估指标 === + metrics: List[str] = ["ic", "rank_ic", "mse", "mae"] + + # === 持久化配置 === + persist_model: bool = False # 默认不持久化 + model_save_path: Optional[str] = None # 持久化路径 + + # === 输出配置 === + output_dir: str = "output" + save_predictions: bool = True +``` + +## 5. 
使用示例 + +### 5.1 基本用法 + +```python +from src.training import Trainer, LightGBMModel, DateSplitter +from src.training.config import StockFilterConfig, MarketCapSelectorConfig +from src.training.core.stock_pool_manager import StockPoolManager +from src.training.components.processors import Winsorizer, CrossSectionalStandardScaler +from src.factors import FactorEngine +import polars as pl + +# 1. 因子计算(全市场数据) +engine = FactorEngine() +all_data = engine.compute( + factor_names=["factor1", "factor2", "factor3", "future_return_5d"], + start_date="20200101", + end_date="20231231", + # 不指定 stock_codes,获取全市场数据 +) + +# 2. 创建股票池管理器(每日独立筛选) +pool_manager = StockPoolManager( + filter_config=StockFilterConfig( + exclude_cyb=True, + exclude_kcb=True, + exclude_bj=True, + exclude_st=True, + ), + selector_config=MarketCapSelectorConfig( + enabled=True, + n=100, + ascending=False, # 最大市值 + ), + data_router=data_router, # 用于获取 daily_basic 数据 +) + +# 3. 创建模型 +model = LightGBMModel( + objective="regression", + n_estimators=100, + learning_rate=0.05 +) + +# 4. 创建处理器 +processors = [ + Winsorizer(lower=0.01, upper=0.99, by_date=False), # 全局缩尾 + CrossSectionalStandardScaler(), # 截面标准化(每天独立) +] + +# 5. 创建划分器 +splitter = DateSplitter( + train_start="20200101", + train_end="20221231", + test_start="20230101", + test_end="20231231" +) + +# 6. 创建训练器 +trainer = Trainer( + model=model, + pool_manager=pool_manager, # 每日筛选股票池 + processors=processors, + splitter=splitter, + target_col="future_return_5d", + feature_cols=["factor1", "factor2", "factor3"], + persist_model=True, # 启用持久化 + model_save_path="output/model.pkl", +) + +# 7. 执行训练(传入全市场数据) +trainer.train(all_data) + +# 8. 获取结果 +results = trainer.get_results() # 包含预测值 +metrics = trainer.get_metrics() # IC, RankIC, etc. +print("评估指标:", metrics) + +# 9. 保存结果 +trainer.save_results("output/predictions.csv") + +# 10. 
加载模型并预测新数据 +loaded_model = LightGBMModel.load("output/model.pkl") +new_predictions = loaded_model.predict(new_data) +``` + +### 5.2 使用配置驱动 + +```python +from src.training.config import TrainingConfig, StockFilterConfig +from src.training import Trainer, DateSplitter +from src.training.registry import ModelRegistry, ProcessorRegistry +from src.training.core.stock_pool_manager import StockPoolManager +from src.factors import FactorEngine + +# 1. 配置(必填字段校验) +config = TrainingConfig( + feature_cols=["factor1", "factor2", "factor3"], + target_col="future_return_5d", + train_start="20200101", + train_end="20221231", + test_start="20230101", + test_end="20231231", + stock_filter=StockFilterConfig( + exclude_cyb=True, + exclude_kcb=True, + ), + stock_selector=None, # 跳过市值选择 + persist_model=False, # 不持久化 +) + +# 2. 因子计算 +engine = FactorEngine() +all_data = engine.compute( + factor_names=config.feature_cols + [config.target_col], + start_date=config.train_start, + end_date=config.test_end, +) + +# 3. 从配置创建组件 +model = ModelRegistry.get_model(config.model_type)(**config.model_params) +processors = [ + ProcessorRegistry.get_processor(p.name)(**p.params) + for p in config.processors +] + +# 4. 创建股票池管理器 +pool_manager = StockPoolManager( + filter_config=config.stock_filter, + selector_config=config.stock_selector, + data_router=data_router, +) + +# 5. 创建并运行训练器 +trainer = Trainer( + model=model, + pool_manager=pool_manager, + processors=processors, + splitter=DateSplitter( + train_start=config.train_start, + train_end=config.train_end, + test_start=config.test_start, + test_end=config.test_end, + ), + target_col=config.target_col, + feature_cols=config.feature_cols, +) + +trainer.train(all_data) +results = trainer.get_results() +metrics = trainer.get_metrics() +``` + +## 6. 
实现顺序 + +按以下顺序实现和提交: + +### Commit 1: 基础架构 +- `training/__init__.py` +- `training/components/__init__.py` +- `training/components/base.py`(BaseModel, BaseProcessor,含 save/load) +- `training/registry.py`(组件注册中心) + +### Commit 2: 数据划分 +- `training/components/splitters.py`(DateSplitter,仅一次性划分) + +### Commit 3: 股票池选择器配置 +- `training/components/selectors.py`(StockFilterConfig, MarketCapSelectorConfig) + +### Commit 4: 数据处理器 +- `training/components/processors/__init__.py` +- `training/components/processors/transforms.py` + - Winsorizer + - StandardScaler + - CrossSectionalStandardScaler + +### Commit 5: LightGBM 模型 +- `training/components/models/__init__.py` +- `training/components/models/lightgbm.py`(含 save/load) + +### Commit 6: 评估指标 +- `training/components/metrics/__init__.py` +- `training/components/metrics/metrics.py`(IC, RankIC, MSE, MAE) + +### Commit 7: 股票池管理器 +- `training/core/__init__.py` +- `training/core/stock_pool_manager.py`(每日独立筛选) + +### Commit 8: Trainer 训练器 +- `training/core/trainer.py` + +### Commit 9: 配置管理 +- `training/config/__init__.py` +- `training/config/config.py`(TrainingConfig,含必填校验) + +### Commit 10: 预留实验模块 +- `experiment/__init__.py` + +## 7. 注意事项 + +### 7.1 股票池处理顺序(每日) + +``` +当日所有股票数据 + ↓ +代码过滤(创业板、ST等) + ↓ +查询 daily_basic 获取当日市值 + ↓ +市值选择(前N只) + ↓ +返回当日选中股票 +``` + +- 每日独立筛选,市值动态变化 +- 市值数据仅从 daily_basic 获取 +- 市值数据绝不混入特征矩阵 + +### 7.2 Processor 阶段行为 + +| Processor | 训练集 | 测试集 | +|-----------|--------|--------| +| StandardScaler | fit_transform | transform(使用训练集参数) | +| CrossSectionalStandardScaler | transform | transform(每天独立) | +| Winsorizer (global) | fit_transform | transform(使用训练集参数) | +| Winsorizer (by_date) | transform | transform(每天独立) | + +### 7.3 依赖关系 + +- 使用 Polars 进行数据处理 +- LightGBM 用于模型训练 +- Pydantic 用于配置管理 +- 市值数据来自 daily_basic 表(独立数据源) + +### 7.4 删除的功能 + +以下原计划中的功能在本次实现中删除: + +1. **特征选择**(processors/selectors.py) +2. **滚动训练**(WalkForward, ExpandingWindow) +3. **结果分析工具**(复杂分析功能) +4. 
**validator.py, evaluator.py**(简化为 metrics.py) + +### 7.5 新增功能 + +1. **StockPoolManager**(每日独立筛选) +2. **模型持久化**(save/load,默认关闭) +3. **配置必填校验**(feature_cols, 日期范围) diff --git a/src/data/api_wrappers/__init__.py b/src/data/api_wrappers/__init__.py index 035ed4e..d1b29fc 100644 --- a/src/data/api_wrappers/__init__.py +++ b/src/data/api_wrappers/__init__.py @@ -11,16 +11,19 @@ Available APIs: - api_trade_cal: Trading calendar (交易日历) - api_namechange: Stock name change history (股票曾用名) - api_bak_basic: Stock historical list (股票历史列表) + - api_stock_st: ST stock list (ST股票列表) Example: >>> from src.data.api_wrappers import get_daily, get_stock_basic, get_trade_cal, get_bak_basic >>> from src.data.api_wrappers import get_pro_bar, sync_pro_bar, get_daily_basic, sync_daily_basic + >>> from src.data.api_wrappers import get_stock_st, sync_stock_st >>> data = get_daily('000001.SZ', start_date='20240101', end_date='20240131') >>> pro_data = get_pro_bar('000001.SZ', start_date='20240101', end_date='20240131') >>> daily_basic = get_daily_basic(trade_date='20240101') >>> stocks = get_stock_basic() >>> calendar = get_trade_cal('20240101', '20240131') >>> bak_basic = get_bak_basic(trade_date='20240101') + >>> stock_st = get_stock_st(trade_date='20240101') """ from src.data.api_wrappers.api_daily import ( @@ -49,6 +52,11 @@ from src.data.api_wrappers.financial_data.api_income import ( from src.data.api_wrappers.api_bak_basic import get_bak_basic, sync_bak_basic from src.data.api_wrappers.api_namechange import get_namechange, sync_namechange from src.data.api_wrappers.api_stock_basic import get_stock_basic, sync_all_stocks +from src.data.api_wrappers.api_stock_st import ( + get_stock_st, + sync_stock_st, + StockSTSync, +) from src.data.api_wrappers.api_trade_cal import ( get_trade_cal, get_trading_days, @@ -92,4 +100,8 @@ __all__ = [ "get_first_trading_day", "get_last_trading_day", "sync_trade_cal_cache", + # ST stock list + "get_stock_st", + "sync_stock_st", + "StockSTSync", ] diff --git 
a/src/data/api_wrappers/api.md b/src/data/api_wrappers/api.md index 71383b5..6d21a29 100644 --- a/src/data/api_wrappers/api.md +++ b/src/data/api_wrappers/api.md @@ -565,4 +565,57 @@ df = pro.query('daily_basic', ts_code='', trade_date='20180726',fields='ts_code, 16 300718.SZ 20180726 17.6612 0.92 32.0239 3.8661 17 000708.SZ 20180726 0.5575 0.70 10.3674 1.0276 18 002626.SZ 20180726 0.6187 0.83 22.7580 4.2446 -19 600816.SH 20180726 0.6745 0.65 11.0778 3.2214 \ No newline at end of file +19 600816.SH 20180726 0.6745 0.65 11.0778 3.2214 + + +ST股票列表 +接口:stock_st,可以通过数据工具调试和查看数据。 +描述:获取ST股票列表,可根据交易日期获取历史上每天的ST列表 +权限:3000积分起 +提示:每天上午9:20更新,单次请求最大返回1000行数据,可循环提取,本接口数据从20160101开始,更早的历史数据无法补齐 + + + +输入参数 + +名称 类型 必选 描述 +ts_code str N 股票代码 +trade_date str N 交易日期(格式:YYYYMMDD,下同) +start_date str N 开始时间 +end_date str N 结束时间 + + +输出参数 + +名称 类型 默认显示 描述 +ts_code str Y 股票代码 +name str Y 股票名称 +trade_date str Y 交易日期 +type str Y 类型 +type_name str Y 类型名称 + + +接口用法 + + +pro = ts.pro_api() + +#获取20250813日所有的ST股票 +df = pro.stock_st(trade_date='20250813') + + + +数据样例 + + ts_code name trade_date type type_name +0 300313.SZ *ST天山 20250813 ST 风险警示板 +1 605081.SH *ST太和 20250813 ST 风险警示板 +2 300391.SZ *ST长药 20250813 ST 风险警示板 +3 300343.SZ ST联创 20250813 ST 风险警示板 +4 300044.SZ ST赛为 20250813 ST 风险警示板 +.. ... ... ... ... ... 
+170 300175.SZ ST朗源 20250813 ST 风险警示板 +171 603721.SH *ST天择 20250813 ST 风险警示板 +172 600289.SH ST信通 20250813 ST 风险警示板 +173 000929.SZ *ST兰黄 20250813 ST 风险警示板 +174 000638.SZ *ST万方 20250813 ST 风险警示板 \ No newline at end of file diff --git a/src/data/api_wrappers/api_stock_st.py b/src/data/api_wrappers/api_stock_st.py new file mode 100644 index 0000000..f2dc02c --- /dev/null +++ b/src/data/api_wrappers/api_stock_st.py @@ -0,0 +1,147 @@ +"""ST股票列表接口。 + +获取ST股票列表数据,可根据交易日期获取历史上每天的ST列表。 +数据从20160101开始可用,每天上午9:20更新。 +""" + +import pandas as pd +from typing import Optional + +from src.data.client import TushareClient +from src.data.api_wrappers.base_sync import DateBasedSync + + +def get_stock_st( + trade_date: Optional[str] = None, + start_date: Optional[str] = None, + end_date: Optional[str] = None, + ts_code: Optional[str] = None, +) -> pd.DataFrame: + """Fetch ST stock list from Tushare. + + This interface retrieves the daily ST stock list including stock codes, + names, and ST type information. Data is available from 20160101 onwards. + Updates at 9:20 AM daily. 
+ + Args: + trade_date: Specific trade date in YYYYMMDD format + start_date: Start date for date range query (YYYYMMDD format) + end_date: End date for date range query (YYYYMMDD format) + ts_code: Stock code filter (optional, e.g., '000001.SZ') + + Returns: + pd.DataFrame with columns: + - ts_code: Stock code + - name: Stock name + - trade_date: Trade date (YYYYMMDD) + - type: Type code + - type_name: Type name (风险警示板) + + Example: + >>> # Get all ST stocks for a single date + >>> data = get_stock_st(trade_date='20240101') + >>> + >>> # Get date range data + >>> data = get_stock_st(start_date='20240101', end_date='20240131') + >>> + >>> # Get specific stock ST history + >>> data = get_stock_st(ts_code='000001.SZ') + """ + client = TushareClient() + + # Build parameters + params = {} + if trade_date: + params["trade_date"] = trade_date + if start_date: + params["start_date"] = start_date + if end_date: + params["end_date"] = end_date + if ts_code: + params["ts_code"] = ts_code + + # Fetch data + data = client.query("stock_st", **params) + + return data + + +class StockSTSync(DateBasedSync): + """ST股票列表批量同步管理器,支持全量/增量同步。 + + 继承自 DateBasedSync,按日期顺序获取数据。 + 数据从 2016 年开始可用,单次请求最大返回1000行数据。 + + Example: + >>> sync = StockSTSync() + >>> results = sync.sync_all() # 增量同步 + >>> results = sync.sync_all(force_full=True) # 全量同步 + >>> preview = sync.preview_sync() # 预览 + """ + + table_name = "stock_st" + default_start_date = "20160101" + + # 表结构定义 + TABLE_SCHEMA = { + "ts_code": "VARCHAR(16) NOT NULL", + "name": "VARCHAR(50)", + "trade_date": "DATE NOT NULL", + "type": "VARCHAR(10)", + "type_name": "VARCHAR(50)", + } + + # 索引定义 + TABLE_INDEXES = [ + ("idx_stock_st_date_code", ["trade_date", "ts_code"]), + ] + + # 主键定义 + PRIMARY_KEY = ("trade_date", "ts_code") + + def fetch_single_date(self, trade_date: str) -> pd.DataFrame: + """获取单日的ST股票列表数据。 + + Args: + trade_date: 交易日期(YYYYMMDD) + + Returns: + 包含当日ST股票列表的 DataFrame + """ + return get_stock_st(trade_date=trade_date) + + +def 
sync_stock_st( + start_date: Optional[str] = None, + end_date: Optional[str] = None, + force_full: bool = False, +) -> pd.DataFrame: + """Sync ST stock list to DuckDB with intelligent incremental sync. + + Logic: + - If table doesn't exist: create table + composite index (trade_date, ts_code) + full sync + - If table exists: incremental sync from last_date + 1 + + Args: + start_date: Start date for sync (YYYYMMDD format, default: 20160101 for full, last_date+1 for incremental) + end_date: End date for sync (YYYYMMDD format, default: today) + force_full: If True, force full reload from 20160101 + + Returns: + pd.DataFrame with synced data + """ + sync_manager = StockSTSync() + return sync_manager.sync_all( + start_date=start_date, + end_date=end_date, + force_full=force_full, + ) + + +if __name__ == "__main__": + # Test sync + result = sync_stock_st(end_date="20240102") + print(f"Synced {len(result)} records") + if not result.empty: + print("\nSample data:") + print(result.head()) diff --git a/src/data/api_wrappers/api_trade_cal.py b/src/data/api_wrappers/api_trade_cal.py index 6250d89..9839a2a 100644 --- a/src/data/api_wrappers/api_trade_cal.py +++ b/src/data/api_wrappers/api_trade_cal.py @@ -14,7 +14,6 @@ from src.config.settings import get_settings _cache_synced = False - # Trading calendar cache file path def _get_cache_path() -> Path: """Get the cache file path for trade calendar.""" diff --git a/src/data/api_wrappers/base_sync.py b/src/data/api_wrappers/base_sync.py index 118c4ad..4568f03 100644 --- a/src/data/api_wrappers/base_sync.py +++ b/src/data/api_wrappers/base_sync.py @@ -63,15 +63,15 @@ class BaseDataSync(ABC): table_name: str = "" # 子类必须覆盖 DEFAULT_START_DATE = "20180101" DEFAULT_MAX_WORKERS = get_settings().threads - + # 表结构定义(子类可覆盖) # 格式: {"column_name": "SQL_TYPE", ...} TABLE_SCHEMA: Dict[str, str] = {} - + # 索引定义(子类可覆盖) # 格式: [("index_name", ["col1", "col2"]), ...] 
TABLE_INDEXES: List[tuple] = [] - + # 主键定义(子类可覆盖) # 格式: ("col1", "col2") PRIMARY_KEY: tuple = () @@ -325,7 +325,9 @@ class BaseDataSync(ABC): try: print(f"[{class_name}] Probe: {probe_description}") - print(f"[{class_name}] Probe: Inserting {len(probe_data)} sample records...") + print( + f"[{class_name}] Probe: Inserting {len(probe_data)} sample records..." + ) # 插入样本数据 storage.save(self.table_name, probe_data, mode="append") @@ -344,18 +346,20 @@ class BaseDataSync(ABC): # 清空表(truncate) print(f"[{class_name}] Probe: Cleaning up sample data...") storage._connection.execute(f'DELETE FROM "{self.table_name}"') - + # 验证表已清空 count_result = storage._connection.execute( f'SELECT COUNT(*) FROM "{self.table_name}"' ).fetchone() remaining = count_result[0] if count_result else -1 - + if remaining == 0: print(f"[{class_name}] Probe: SUCCESS - Table verified and cleaned") return True else: - print(f"[{class_name}] Probe: WARNING - {remaining} rows remaining after cleanup") + print( + f"[{class_name}] Probe: WARNING - {remaining} rows remaining after cleanup" + ) return True # 仍然继续,因为主要目的是验证结构 except Exception as e: @@ -395,44 +399,50 @@ class BaseDataSync(ABC): 子类可以覆盖此方法以自定义建表逻辑。 """ storage = Storage() - + if storage.exists(self.table_name): return - + if not self.TABLE_SCHEMA: - print(f"[{self.__class__.__name__}] TABLE_SCHEMA not defined, skipping table creation") + print( + f"[{self.__class__.__name__}] TABLE_SCHEMA not defined, skipping table creation" + ) return - + # 构建列定义 columns_def = [] for col_name, col_type in self.TABLE_SCHEMA.items(): columns_def.append(f'"{col_name}" {col_type}') - + # 添加主键约束 if self.PRIMARY_KEY: - pk_cols = ', '.join(f'"{col}"' for col in self.PRIMARY_KEY) + pk_cols = ", ".join(f'"{col}"' for col in self.PRIMARY_KEY) columns_def.append(f"PRIMARY KEY ({pk_cols})") - + columns_sql = ", ".join(columns_def) create_sql = f'CREATE TABLE IF NOT EXISTS "{self.table_name}" ({columns_sql})' - + try: storage._connection.execute(create_sql) 
print(f"[{self.__class__.__name__}] Created table '{self.table_name}'") except Exception as e: print(f"[{self.__class__.__name__}] Error creating table: {e}") raise - + # 创建索引 for idx_name, idx_cols in self.TABLE_INDEXES: try: - idx_cols_sql = ', '.join(f'"{col}"' for col in idx_cols) + idx_cols_sql = ", ".join(f'"{col}"' for col in idx_cols) storage._connection.execute( f'CREATE INDEX IF NOT EXISTS "{idx_name}" ON "{self.table_name}"({idx_cols_sql})' ) - print(f"[{self.__class__.__name__}] Created index '{idx_name}' on {idx_cols}") + print( + f"[{self.__class__.__name__}] Created index '{idx_name}' on {idx_cols}" + ) except Exception as e: - print(f"[{self.__class__.__name__}] Error creating index {idx_name}: {e}") + print( + f"[{self.__class__.__name__}] Error creating index {idx_name}: {e}" + ) @abstractmethod def preview_sync( @@ -863,28 +873,30 @@ class StockBasedSync(BaseDataSync): # 首次同步探测:验证表结构是否正常 if self._should_probe_table(): - print(f"[{class_name}] Table '{self.table_name}' is empty or doesn't exist, probing...") + print( + f"[{class_name}] Table '{self.table_name}' is empty or doesn't exist, probing..." + ) # 使用第一只股票的完整日期范围数据进行探测 probe_stock = stock_codes[0] - probe_data = self.fetch_single_stock( - probe_stock, sync_start_date, end_date - ) + probe_data = self.fetch_single_stock(probe_stock, sync_start_date, end_date) probe_desc = f"stock={probe_stock}, range={sync_start_date} to {end_date}" probe_success = self._probe_table_and_cleanup(probe_data, probe_desc) - + if not probe_success: - print(f"[{class_name}] Probe failed! Stopping sync to prevent data corruption.") + print( + f"[{class_name}] Probe failed! Stopping sync to prevent data corruption." + ) raise RuntimeError( f"Table '{self.table_name}' probe failed. " "Please check database schema and column mappings." 
) if self._should_probe_table(): - print(f"[{class_name}] Table '{self.table_name}' is empty or doesn't exist, probing...") + print( + f"[{class_name}] Table '{self.table_name}' is empty or doesn't exist, probing..." + ) # 使用第一只股票的完整日期范围数据进行探测 probe_stock = stock_codes[0] - probe_data = self.fetch_single_stock( - probe_stock, sync_start_date, end_date - ) + probe_data = self.fetch_single_stock(probe_stock, sync_start_date, end_date) probe_desc = f"stock={probe_stock}, range={sync_start_date} to {end_date}" self._probe_table_and_cleanup(probe_data, probe_desc) @@ -1301,7 +1313,7 @@ class DateBasedSync(BaseDataSync): else: print(f"[{class_name}] Cannot create table: no sample data available") return pd.DataFrame() - + # 首次同步探测:验证表结构是否正常 if self._should_probe_table(): print(f"[{class_name}] Table '{self.table_name}' is empty, probing...") @@ -1335,10 +1347,8 @@ class DateBasedSync(BaseDataSync): if self._should_probe_table(): print(f"[{class_name}] Table '{self.table_name}' is empty, probing...") # 使用最近一个交易日的完整数据进行探测 - from src.data.api_wrappers.api_trade_cal import get_last_n_trading_days - last_days = get_last_n_trading_days(1, sync_end) - if last_days: - probe_date = last_days[0] + probe_date = get_last_trading_day(sync_start, sync_end) + if probe_date: probe_data = self.fetch_single_date(probe_date) probe_desc = f"date={probe_date}, all stocks" self._probe_table_and_cleanup(probe_data, probe_desc) diff --git a/src/data/sync.py b/src/data/sync.py index 822a1cf..42c082c 100644 --- a/src/data/sync.py +++ b/src/data/sync.py @@ -46,6 +46,7 @@ from src.data.api_wrappers.api_daily import sync_daily, preview_daily_sync from src.data.api_wrappers.api_pro_bar import sync_pro_bar from src.data.api_wrappers.api_bak_basic import sync_bak_basic from src.data.api_wrappers.api_daily_basic import sync_daily_basic +from src.data.api_wrappers.api_stock_st import sync_stock_st def preview_sync( @@ -161,6 +162,7 @@ def sync_all_data( 4. Pro Bar 数据 (sync_pro_bar) 5. 
每日指标数据 (sync_daily_basic) 6. 历史股票列表 (sync_bak_basic) + 7. ST股票列表 (sync_stock_st) 【不包含的同步(需单独调用)】 - 财务数据: 利润表、资产负债表、现金流量表(季度更新) @@ -195,53 +197,53 @@ def sync_all_data( print("=" * 60) # 1. Sync trade calendar (always needed first) - print("\n[1/5] Syncing trade calendar cache...") + print("\n[1/7] Syncing trade calendar cache...") try: from src.data.api_wrappers import sync_trade_cal_cache sync_trade_cal_cache() results["trade_cal"] = pd.DataFrame() - print("[1/5] Trade calendar: OK") + print("[1/7] Trade calendar: OK") except Exception as e: - print(f"[1/5] Trade calendar: FAILED - {e}") + print(f"[1/7] Trade calendar: FAILED - {e}") results["trade_cal"] = pd.DataFrame() # 2. Sync stock basic info - print("\n[2/5] Syncing stock basic info...") + print("\n[2/7] Syncing stock basic info...") try: sync_all_stocks() results["stock_basic"] = pd.DataFrame() - print("[2/5] Stock basic: OK") + print("[2/7] Stock basic: OK") except Exception as e: - print(f"[2/5] Stock basic: FAILED - {e}") + print(f"[2/7] Stock basic: FAILED - {e}") results["stock_basic"] = pd.DataFrame() # 3. 
Sync daily market data - print("\n[3/5] Syncing daily market data...") - try: - # 确保表存在 - from src.data.api_wrappers.api_daily import DailySync - - DailySync().ensure_table_exists() - - daily_result = sync_daily( - force_full=force_full, - max_workers=max_workers, - dry_run=dry_run, - ) - results["daily"] = daily_result - total_daily_records = ( - sum(len(df) for df in daily_result.values()) if daily_result else 0 - ) - print( - f"[3/5] Daily data: OK ({total_daily_records} records from {len(daily_result)} stocks)" - ) - except Exception as e: - print(f"[3/5] Daily data: FAILED - {e}") - results["daily"] = pd.DataFrame() + # print("\n[3/7] Syncing daily market data...") + # try: + # # 确保表存在 + # from src.data.api_wrappers.api_daily import DailySync + # + # DailySync().ensure_table_exists() + # + # daily_result = sync_daily( + # force_full=force_full, + # max_workers=max_workers, + # dry_run=dry_run, + # ) + # results["daily"] = daily_result + # total_daily_records = ( + # sum(len(df) for df in daily_result.values()) if daily_result else 0 + # ) + # print( + # f"[3/7] Daily data: OK ({total_daily_records} records from {len(daily_result)} stocks)" + # ) + # except Exception as e: + # print(f"[3/7] Daily data: FAILED - {e}") + # results["daily"] = pd.DataFrame() # 4. Sync Pro Bar data - print("\n[4/6] Syncing Pro Bar data (with adj, tor, vr)...") + print("\n[4/7] Syncing Pro Bar data (with adj, tor, vr)...") try: # 确保表存在 from src.data.api_wrappers.api_pro_bar import ProBarSync @@ -258,15 +260,15 @@ def sync_all_data( sum(len(df) for df in pro_bar_result.values()) if pro_bar_result else 0 ) print( - f"[4/6] Pro Bar data: OK ({total_pro_bar_records} records from {len(pro_bar_result)} stocks)" + f"[4/7] Pro Bar data: OK ({total_pro_bar_records} records from {len(pro_bar_result)} stocks)" ) except Exception as e: - print(f"[4/6] Pro Bar data: FAILED - {e}") + print(f"[4/7] Pro Bar data: FAILED - {e}") results["pro_bar"] = pd.DataFrame() # 5. 
Sync daily basic indicators print( - "\n[5/6] Syncing daily basic indicators (PE, PB, turnover rate, market value)..." + "\n[5/7] Syncing daily basic indicators (PE, PB, turnover rate, market value)..." ) try: # 确保表存在 @@ -276,13 +278,13 @@ def sync_all_data( daily_basic_result = sync_daily_basic(force_full=force_full, dry_run=dry_run) results["daily_basic"] = daily_basic_result - print(f"[5/6] Daily basic: OK ({len(daily_basic_result)} records)") + print(f"[5/7] Daily basic: OK ({len(daily_basic_result)} records)") except Exception as e: - print(f"[5/6] Daily basic: FAILED - {e}") + print(f"[5/7] Daily basic: FAILED - {e}") results["daily_basic"] = pd.DataFrame() # 6. Sync stock historical list (bak_basic) - print("\n[6/6] Syncing stock historical list (bak_basic)...") + print("\n[6/7] Syncing stock historical list (bak_basic)...") try: # 确保表存在 from src.data.api_wrappers.api_bak_basic import BakBasicSync @@ -291,11 +293,26 @@ def sync_all_data( bak_basic_result = sync_bak_basic(force_full=force_full) results["bak_basic"] = bak_basic_result - print(f"[6/6] Bak basic: OK ({len(bak_basic_result)} records)") + print(f"[6/7] Bak basic: OK ({len(bak_basic_result)} records)") except Exception as e: - print(f"[6/6] Bak basic: FAILED - {e}") + print(f"[6/7] Bak basic: FAILED - {e}") results["bak_basic"] = pd.DataFrame() + # 7. 
Sync ST stock list
+    print("\n[7/7] Syncing ST stock list...")
+    try:
+        # Ensure the table exists
+        from src.data.api_wrappers.api_stock_st import StockSTSync
+
+        StockSTSync().ensure_table_exists()
+
+        stock_st_result = sync_stock_st(force_full=force_full)
+        results["stock_st"] = stock_st_result
+        print(f"[7/7] ST stock list: OK ({len(stock_st_result)} records)")
+    except Exception as e:
+        print(f"[7/7] ST stock list: FAILED - {e}")
+        results["stock_st"] = pd.DataFrame()
+
     # Summary
     print("\n" + "=" * 60)
     print("[sync_all_data] Sync Summary")
diff --git a/tests/test_stock_st.py b/tests/test_stock_st.py
new file mode 100644
index 0000000..7e5b4c8
--- /dev/null
+++ b/tests/test_stock_st.py
@@ -0,0 +1,157 @@
+"""Test suite for stock_st API wrapper."""
+
+import pytest
+import pandas as pd
+from unittest.mock import patch, MagicMock
+
+from src.data.api_wrappers.api_stock_st import get_stock_st, sync_stock_st, StockSTSync
+
+
+class TestStockST:
+    """Tests for get_stock_st with the Tushare client mocked out."""
+
+    @patch("src.data.api_wrappers.api_stock_st.TushareClient")
+    def test_get_by_date(self, mock_client_class):
+        """Fetching by trade_date returns the client's frame unchanged."""
+        # Setup mock
+        mock_client = MagicMock()
+        mock_client_class.return_value = mock_client
+        mock_client.query.return_value = pd.DataFrame(
+            {
+                "ts_code": ["300313.SZ", "605081.SH", "300391.SZ"],
+                "name": ["*ST天山", "*ST太和", "*ST长药"],
+                "trade_date": ["20240101", "20240101", "20240101"],
+                "type": ["ST", "ST", "ST"],
+                "type_name": ["风险警示板", "风险警示板", "风险警示板"],
+            }
+        )
+
+        # Test
+        result = get_stock_st(trade_date="20240101")
+
+        # Assert
+        assert not result.empty
+        assert len(result) == 3
+        assert "ts_code" in result.columns
+        assert "name" in result.columns
+        assert "trade_date" in result.columns
+        assert "type" in result.columns
+        assert "type_name" in result.columns
+        # Only the supplied filter is forwarded to the "stock_st" endpoint
+        mock_client.query.assert_called_once_with("stock_st", trade_date="20240101")
+
+    @patch("src.data.api_wrappers.api_stock_st.TushareClient")
+    def test_get_by_stock(self, mock_client_class):
+        """Fetching one stock's ST history forwards ts_code and the date range."""
+        # Setup mock
+        mock_client = MagicMock()
+        mock_client_class.return_value = mock_client
+        mock_client.query.return_value = pd.DataFrame(
+            {
+                "ts_code": ["300313.SZ", "300313.SZ"],
+                "name": ["*ST天山", "*ST天山"],
+                "trade_date": ["20240101", "20240102"],
+                "type": ["ST", "ST"],
+                "type_name": ["风险警示板", "风险警示板"],
+            }
+        )
+
+        # Test
+        result = get_stock_st(
+            ts_code="300313.SZ", start_date="20240101", end_date="20240102"
+        )
+
+        # Assert
+        assert not result.empty
+        assert len(result) == 2
+        # trade_date was omitted, so it must not appear in the query
+        mock_client.query.assert_called_once_with(
+            "stock_st",
+            start_date="20240101",
+            end_date="20240102",
+            ts_code="300313.SZ",
+        )
+
+    @patch("src.data.api_wrappers.api_stock_st.TushareClient")
+    def test_empty_response(self, mock_client_class):
+        """An empty frame from the client is passed through as empty."""
+        mock_client = MagicMock()
+        mock_client_class.return_value = mock_client
+        mock_client.query.return_value = pd.DataFrame()
+
+        result = get_stock_st(trade_date="20240101")
+        assert result.empty
+        # The query is still issued exactly once, with the requested date
+        mock_client.query.assert_called_once_with("stock_st", trade_date="20240101")
+
+    @patch("src.data.api_wrappers.api_stock_st.TushareClient")
+    def test_get_by_date_range(self, mock_client_class):
+        """Fetching by date range forwards start_date and end_date only."""
+        # Setup mock
+        mock_client = MagicMock()
+        mock_client_class.return_value = mock_client
+        mock_client.query.return_value = pd.DataFrame(
+            {
+                "ts_code": ["300313.SZ"],
+                "name": ["*ST天山"],
+                "trade_date": ["20240101"],
+                "type": ["ST"],
+                "type_name": ["风险警示板"],
+            }
+        )
+
+        # Test
+        result = get_stock_st(start_date="20240101", end_date="20240131")
+
+        # Assert
+        assert not result.empty
+        # Only the two range bounds are forwarded
+        mock_client.query.assert_called_once_with(
+            "stock_st", start_date="20240101", end_date="20240131"
+        )
+
+
+class TestStockSTSync:
+    """Tests for the StockSTSync class."""
+
+    def test_sync_class_attributes(self):
+        """The sync class exposes the expected table name, schema and key."""
+        sync = StockSTSync()
+        assert sync.table_name == "stock_st"
+        assert sync.default_start_date == "20160101"
+        assert "ts_code" in sync.TABLE_SCHEMA
+        assert "trade_date" in sync.TABLE_SCHEMA
+        assert "name" in sync.TABLE_SCHEMA
+        assert "type" in sync.TABLE_SCHEMA
+        assert "type_name" in sync.TABLE_SCHEMA
+        assert sync.PRIMARY_KEY == ("trade_date", "ts_code")
+
+    @patch("src.data.api_wrappers.api_stock_st.TushareClient")
+    def test_fetch_single_date(self, mock_client_class):
+        """fetch_single_date returns the frame for the requested trade date."""
+        # Setup mock
+        mock_client = MagicMock()
+        mock_client_class.return_value = mock_client
+        mock_client.query.return_value = pd.DataFrame(
+            {
+                "ts_code": ["300313.SZ"],
+                "name": ["*ST天山"],
+                "trade_date": ["20240101"],
+                "type": ["ST"],
+                "type_name": ["风险警示板"],
+            }
+        )
+
+        # Test
+        sync = StockSTSync()
+        result = sync.fetch_single_date("20240101")
+
+        # Assert
+        assert not result.empty
+        assert len(result) == 1
+        # fetch_single_date delegates to get_stock_st(trade_date=...)
+        mock_client.query.assert_called_once_with("stock_st", trade_date="20240101")
+
+
+if __name__ == "__main__":
+    pytest.main([__file__, "-v"])