diff --git a/AGENTS.md b/AGENTS.md index f1c7550..22db931 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -85,8 +85,7 @@ ProStock/ │ ├── data/ # 数据获取与存储 │ │ ├── api_wrappers/ # Tushare API 封装 │ │ │ ├── base_sync.py # 同步基础抽象类 -│ │ │ ├── api_daily.py # 日线数据接口 -│ │ │ ├── api_pro_bar.py # Pro Bar 数据接口 +│ │ │ ├── api_pro_bar.py # Pro Bar 行情数据接口(主用) │ │ │ ├── api_stock_basic.py # 股票基础信息接口 │ │ │ ├── api_trade_cal.py # 交易日历接口 │ │ │ ├── api_bak_basic.py # 历史股票列表接口 diff --git a/README.md b/README.md index 74022d3..0b62f5e 100644 --- a/README.md +++ b/README.md @@ -36,9 +36,21 @@ ProStock/ │ │ │ ├── data/ # 数据获取与存储 │ │ ├── api_wrappers/ # Tushare API 封装 -│ │ │ ├── api_daily.py # 日线数据接口 -│ │ │ ├── api_stock_basic.py # 股票基础信息 -│ │ │ └── api_trade_cal.py # 交易日历 +│ │ │ ├── api_pro_bar.py # Pro Bar行情数据接口(主用) +│ │ │ ├── api_stock_basic.py # 股票基础信息接口 +│ │ │ ├── api_trade_cal.py # 交易日历接口 +│ │ │ ├── api_bak_basic.py # 历史股票列表接口 +│ │ │ ├── api_namechange.py # 股票名称变更接口 +│ │ │ ├── api_stock_st.py # ST股票信息接口 +│ │ │ ├── api_daily_basic.py # 每日指标接口 +│ │ │ ├── api_stk_limit.py # 涨跌停价格接口 +│ │ │ ├── financial_data/ # 财务数据接口 +│ │ │ │ ├── api_income.py # 利润表接口 +│ │ │ │ ├── api_balance.py # 资产负债表接口 +│ │ │ │ ├── api_cashflow.py # 现金流量表接口 +│ │ │ │ ├── api_fina_indicator.py # 财务指标接口 +│ │ │ │ └── api_financial_sync.py # 财务数据同步调度中心 +│ │ │ └── __init__.py │ │ ├── client.py # Tushare 客户端(含限流) │ │ ├── config.py # 数据模块配置 │ │ ├── db_manager.py # DuckDB 表管理和同步 @@ -140,83 +152,123 @@ uv run python -c "from src.data.db_inspector import get_db_info; get_db_info()" ### 因子计算 ```python -from src.factors import FactorEngine, DataLoader, DataSpec -from src.factors.base import CrossSectionalFactor, TimeSeriesFactor +from src.factors import FactorEngine +from src.factors.api import close, ts_mean, cs_rank import polars as pl -# 自定义截面因子:PE排名 -class PERankFactor(CrossSectionalFactor): - name = "pe_rank" - data_specs = [DataSpec("daily", ["ts_code", "trade_date", "pe"], lookback_days=1)] - - def compute(self, data) -> pl.Series: - cs = data.get_cross_section() - return cs["pe"].rank() +# 初始化引擎 +engine = FactorEngine() -# 自定义时序因子:20日移动平均 -class MA20Factor(TimeSeriesFactor): - name = "ma20" - data_specs = [DataSpec("daily", ["ts_code", "trade_date", "close"], lookback_days=20)] - - def compute(self, data) -> pl.Series: - return data.get_column("close").rolling_mean(window_size=20) +# 方式1:使用 DSL 表达式注册 +engine.register("ma20", ts_mean(close, 20)) +engine.register("price_rank", cs_rank(close)) -# 执行计算 -loader = DataLoader(data_dir="data") -engine = FactorEngine(loader) +# 方式2:使用字符串表达式(推荐) +engine.add_factor("ma20", "ts_mean(close, 20)") +engine.add_factor("alpha", "cs_rank(ts_mean(close, 5) - ts_mean(close, 20))") -# 计算截面因子 -pe_rank = PERankFactor() -result1 = engine.compute(pe_rank, start_date="20240101", end_date="20240131") +# 方式3:从 metadata 查询(需先在 metadata 中定义) +engine.add_factor("mom_5d") -# 计算时序因子 -ma20 = MA20Factor() -result2 = engine.compute(ma20, stock_codes=["000001.SZ"], - start_date="20240101", end_date="20240131") +# 计算因子 +result = engine.compute( + factor_names=["ma20", "price_rank"], + start_date="20240101", + end_date="20240131" +) -# 因子组合 -combined = 0.5 * pe_rank + 0.3 * ma20 +# 查看执行计划 +plan = engine.preview_plan("ma20") ``` ### 模型训练 ```python -from src.models import PluginRegistry, ProcessingPipeline -from src.models.core import PipelineStage +from src.training import ( + Trainer, + LightGBMModel, + DateSplitter, + StockPoolManager, + NullFiller, + Winsorizer, + StandardScaler, + STFilter, + check_data_quality, +) +from src.factors import FactorEngine import polars as pl -# 创建处理流水线 -pipeline = ProcessingPipeline([ - PluginRegistry.get_processor("dropna")(), - PluginRegistry.get_processor("winsorizer")(lower=0.01, upper=0.99), - PluginRegistry.get_processor("standard_scaler")(), -]) +# 1. 创建模型 +model = LightGBMModel(params={ + "objective": "regression", + "metric": "mae", + "num_leaves": 20, + "learning_rate": 0.01, + "n_estimators": 1000, +}) -# 准备数据 -data = pl.read_csv("features.csv") # 包含特征和标签 +# 2. 准备因子数据 +engine = FactorEngine() +engine.add_factor("ma5", "ts_mean(close, 5)") +engine.add_factor("ma20", "ts_mean(close, 20)") -# 划分训练/测试集 -from src.models.core import WalkForwardSplit -splitter = WalkForwardSplit(train_window=252, test_window=21) +# 计算全市场因子 +data = engine.compute( + factor_names=["ma5", "ma20", "future_return_5"], + start_date="20200101", + end_date="20231231" +) -# 获取 LightGBM 模型 -ModelClass = PluginRegistry.get_model("lightgbm") -model = ModelClass(task_type="regression", params={"n_estimators": 100}) +# 3. 创建数据处理器 +processors = [ + NullFiller(feature_cols=["ma5", "ma20"], strategy="mean"), + Winsorizer(feature_cols=["ma5", "ma20"], lower=0.01, upper=0.99), + StandardScaler(feature_cols=["ma5", "ma20"]), +] -# 训练循环 -for train_idx, test_idx in splitter.split(data): - train_data = data[train_idx] - test_data = data[test_idx] - - # 数据处理 - X_train = pipeline.fit_transform(train_data.drop("target")) - X_test = pipeline.transform(test_data.drop("target")) - y_train = train_data["target"] - y_test = test_data["target"] - - # 训练模型 - model.fit(X_train, y_train) - predictions = model.predict(X_test) +# 4. 创建股票池筛选函数 +def stock_pool_filter(df: pl.DataFrame) -> pl.Series: + """筛选小市值股票""" + code_filter = ( + ~df["ts_code"].str.starts_with("300") & # 排除创业板 + ~df["ts_code"].str.starts_with("688") # 排除科创板 + ) + return code_filter + +pool_manager = StockPoolManager( + filter_func=stock_pool_filter, + required_columns=["total_mv"], +) + +# 5. 创建过滤器 +st_filter = STFilter(data_router=engine.router) + +# 6. 创建数据划分器 +splitter = DateSplitter( + train_start="20200101", + train_end="20221231", + val_start="20230101", + val_end="20230630", + test_start="20230701", + test_end="20231231", +) + +# 7. 创建训练器 +trainer = Trainer( + model=model, + pool_manager=pool_manager, + processors=processors, + filters=[st_filter], + splitter=splitter, + target_col="future_return_5", + feature_cols=["ma5", "ma20"], +) + +# 8. 执行训练 +results = trainer.train(data) + +# 9. 获取预测结果 +predictions = trainer.get_results() ``` ## 核心设计 diff --git a/docs/api/API_INTERFACE_SPEC.md b/docs/api/API_INTERFACE_SPEC.md index 9fd8627..fb0681c 100644 --- a/docs/api/API_INTERFACE_SPEC.md +++ b/docs/api/API_INTERFACE_SPEC.md @@ -776,9 +776,9 @@ Skill 会自动: - [ ] 测试覆盖正常和异常情况 ## 11. 示例参考 -### 11.1 完整示例:api_daily.py +### 11.1 完整示例:api_pro_bar.py -参见 `src/data/api_wrappers/api_daily.py` - 按股票获取日线数据的完整实现。 +参见 `src/data/api_wrappers/api_pro_bar.py` - 按股票获取 Pro Bar 行情数据的完整实现(主力行情表)。 ### 11.2 完整示例:api_trade_cal.py diff --git a/docs/factor_calculation_flow.md b/docs/factor_calculation_flow.md index 8b80b94..3578e44 100644 --- a/docs/factor_calculation_flow.md +++ b/docs/factor_calculation_flow.md @@ -222,7 +222,7 @@ def _infer_data_specs(self, node, dependencies): ``` **DataSpec 说明**: -- `table`: 数据表名(pro_bar 或 daily) +- `table`: 数据表名(pro_bar 为主力行情表) - `columns`: 需要的字段列表 **注意**:数据获取使用用户传入的日期范围,不做自动扩展。时序因子(如 `ts_delay`、`ts_mean`)在数据不足时会返回 null,这是符合预期的行为。 @@ -377,19 +377,19 @@ def execute(self, plan, data): ### 7.1 用户代码 ```python -from src.factors import FactorEngine, FormulaParser, FunctionRegistry +from src.factors import FactorEngine # 1. 创建引擎 engine = FactorEngine() -# 2. 解析字符串表达式 -parser = FormulaParser(FunctionRegistry()) -expr = parser.parse("(close / ts_delay(close, 5)) - 1") +# 2. 使用字符串表达式注册因子(推荐) +engine.add_factor("returns_5d", "(close / ts_delay(close, 5)) - 1") -# 3. 注册因子 -engine.register("returns_5d", expr) +# 或者使用 DSL 表达式 +from src.factors.api import close, ts_delay +engine.register("returns_5d", (close / ts_delay(close, 5)) - 1) -# 4. 执行计算 +# 3. 执行计算 result = engine.compute( factor_names=["returns_5d"], start_date="20240101", @@ -400,23 +400,27 @@ result = engine.compute( ### 7.2 内部调用链 ``` +FactorEngine.add_factor() / register() + │ + └── 创建并缓存 ExecutionPlan + └── ExecutionPlanner.create_plan() + ├── DependencyExtractor.extract_dependencies() → {'close'} + ├── _infer_data_specs() → [DataSpec('pro_bar', ['close'], 5)] + └── PolarsTranslator.translate() → pl.col('close').shift(5).over('ts_code')... + FactorEngine.compute() │ - ├── 1. 创建 ExecutionPlan - │ └── ExecutionPlanner.create_plan() - │ ├── DependencyExtractor.extract_dependencies() → {'close'} - │ ├── _infer_data_specs() → [DataSpec('pro_bar', ['close'], 5)] - │ └── PolarsTranslator.translate() → pl.col('close').shift(5).over('ts_code')... - │ - ├── 2. 获取数据 - │ └── DataRouter.fetch_data([plan.data_specs]) - │ ├── _load_table('pro_bar', ['close'], start_date-5d, end_date) + ├── 1. 获取所有缓存的执行计划 + ├── 2. 合并数据规格 + │ └── _merge_data_specs() + ├── 3. 获取数据 + │ └── DataRouter.fetch_data(merged_specs) + │ ├── _load_table('pro_bar', ['close'], start_date, end_date) │ │ └── Storage.load_polars() → 查询 DuckDB │ └── _assemble_wide_table() → Polars DataFrame - │ - └── 3. 执行计算 - └── ComputeEngine.execute(plan, data) - └── data.with_columns([polars_expr.alias('returns_5d')]) + └── 4. 执行计算 + └── ComputeEngine.execute_plans(plans, data) + └── data.with_columns([polars_exprs...]) └── Polars 执行表达式计算 ``` diff --git a/docs/factor_implementation_analysis.md b/docs/factor_implementation_analysis.md index 86c7756..48df969 100644 --- a/docs/factor_implementation_analysis.md +++ b/docs/factor_implementation_analysis.md @@ -92,17 +92,17 @@ | 字段名 | 状态 | 数据来源 | 所属类别 | |--------|------|----------|----------| -| `close` | 可用 | daily/pro_bar 表 | 价格 | -| `open` | 可用 | daily/pro_bar 表 | 价格 | -| `high` | 可用 | daily/pro_bar 表 | 价格 | -| `low` | 可用 | daily/pro_bar 表 | 价格 | -| `vol` | 可用 | daily/pro_bar 表 | 成交量 | -| `amount` | 可用 | daily/pro_bar 表 | 成交额 | -| `pre_close` | 可用 | daily/pro_bar 表 | 价格 | -| `change` | 可用 | daily/pro_bar 表 | 价格变化 | -| `pct_chg` | 可用 | daily/pro_bar 表 | 涨跌幅 | -| `turnover_rate` | 可用 | daily/pro_bar 表 | 换手率 | -| `volume_ratio` | 可用 | daily/pro_bar 表 | 量比 | +| `close` | 可用 | pro_bar 表 | 价格 | +| `open` | 可用 | pro_bar 表 | 价格 | +| `high` | 可用 | pro_bar 表 | 价格 | +| `low` | 可用 | pro_bar 表 | 价格 | +| `vol` | 可用 | pro_bar 表 | 成交量 | +| `amount` | 可用 | pro_bar 表 | 成交额 | +| `pre_close` | 可用 | pro_bar 表 | 价格 | +| `change` | 可用 | pro_bar 表 | 价格变化 | +| `pct_chg` | 可用 | pro_bar 表 | 涨跌幅 | +| `turnover_rate` | 可用 | pro_bar 表 | 换手率 | +| `volume_ratio` | 可用 | pro_bar 表 | 量比 | ### 1.8 支持的运算符 @@ -482,7 +482,7 @@ spec = DataSpec( | 数据源 | 依赖因子数 | 实现难度 | 优先级 | |--------|------------|----------|--------| -| daily/pro_bar (已有) | ~40 | 低 | 高 | +| pro_bar (主力行情表) | ~40 | 低 | 高 | | 纯技术指标 (ts_*) | ~30 | 中 | 高 | | 筹码分布 (cyq) | ~50 | 中 | 中 | | 资金流向 (moneyflow) | ~30 | 中 | 中 | diff --git a/docs/n_income_factor_lifecycle.md b/docs/n_income_factor_lifecycle.md index 92f33b3..216b392 100644 --- a/docs/n_income_factor_lifecycle.md +++ b/docs/n_income_factor_lifecycle.md @@ -524,7 +524,7 @@ def prepare_data(...) -> pl.DataFrame: ```python # 系统自动识别 n_income → financial_income 表 (PIT) -close → daily 表 (DAILY) +close → pro_bar 表 (主力行情表) ``` ### 3. 财务数据清洗 @@ -584,10 +584,10 @@ CREATE TABLE financial_income ( ); ``` -### daily(日线行情) +### pro_bar(主力行情表) ```sql -CREATE TABLE daily ( +CREATE TABLE pro_bar ( ts_code VARCHAR, -- 股票代码 trade_date DATE, -- 交易日期 open DOUBLE, -- 开盘价 @@ -595,6 +595,10 @@ CREATE TABLE daily ( low DOUBLE, -- 最低价 close DOUBLE, -- 收盘价 vol BIGINT, -- 成交量 + turnover_rate DOUBLE, -- 换手率 + volume_ratio DOUBLE, -- 量比 ... -- 其他行情字段 ); ``` + +**说明**: pro_bar 表通过 Tushare Pro Bar 接口获取,包含后复权数据和换手率、量比等指标,是主力行情数据表。 diff --git a/src/data/api_wrappers/api_daily.py b/src/data/api_wrappers/api_daily.py deleted file mode 100644 index 1810278..0000000 --- a/src/data/api_wrappers/api_daily.py +++ /dev/null @@ -1,240 +0,0 @@ -"""Simplified daily market data interface. - -A single function to fetch A股日线行情 data from Tushare. -Supports all output fields including tor (换手率) and vr (量比). - -This module provides both single-stock fetching (get_daily) and -batch synchronization (DailySync class) for daily market data. -""" - -import pandas as pd -from typing import Optional, List, Literal, Dict - -from src.data.client import TushareClient -from src.data.api_wrappers.base_sync import StockBasedSync - - -def get_daily( - ts_code: str, - start_date: Optional[str] = None, - end_date: Optional[str] = None, - trade_date: Optional[str] = None, - adj: Literal[None, "qfq", "hfq"] = None, - factors: Optional[List[Literal["tor", "vr"]]] = None, - adjfactor: bool = False, -) -> pd.DataFrame: - """Fetch daily market data for A-share stocks. - - This is a simplified interface that combines rate limiting, API calls, - and error handling into a single function. - - Args: - ts_code: Stock code (e.g., '000001.SZ', '600000.SH') - start_date: Start date in YYYYMMDD format - end_date: End date in YYYYMMDD format - trade_date: Specific trade date in YYYYMMDD format - adj: Adjustment type - None, 'qfq' (forward), 'hfq' (backward) - factors: List of factors to include - 'tor' (turnover rate), 'vr' (volume ratio) - adjfactor: Whether to include adjustment factor - - Returns: - pd.DataFrame with daily market data containing: - - Base fields: ts_code, trade_date, open, high, low, close, pre_close, - change, pct_chg, vol, amount - - Factor fields (if requested): tor, vr - - Adjustment factor (if adjfactor=True): adjfactor - - Example: - >>> data = get_daily('000001.SZ', start_date='20240101', end_date='20240131') - >>> data = get_daily('600000.SH', factors=['tor', 'vr']) - """ - # Initialize client - client = TushareClient() - - # Build parameters - params = {"ts_code": ts_code} - - if start_date: - params["start_date"] = start_date - if end_date: - params["end_date"] = end_date - if trade_date: - params["trade_date"] = trade_date - if adj: - params["adj"] = adj - if factors: - # Tushare expects factors as comma-separated string, not list - if isinstance(factors, list): - factors_str = ",".join(factors) - else: - factors_str = factors - params["factors"] = factors_str - if adjfactor: - params["adjfactor"] = "True" - - # Fetch data using pro_bar (supports factors like tor, vr) - data = client.query("pro_bar", **params) - - return data - - -class DailySync(StockBasedSync): - """日线数据批量同步管理器,支持全量/增量同步。 - - 继承自 StockBasedSync,使用多线程按股票并发获取数据。 - - Example: - >>> sync = DailySync() - >>> results = sync.sync_all() # 增量同步 - >>> results = sync.sync_all(force_full=True) # 全量同步 - >>> preview = sync.preview_sync() # 预览 - """ - - table_name = "daily" - - # 表结构定义 - TABLE_SCHEMA = { - "ts_code": "VARCHAR(16) NOT NULL", - "trade_date": "DATE NOT NULL", - "open": "DOUBLE", - "high": "DOUBLE", - "low": "DOUBLE", - "close": "DOUBLE", - "pre_close": "DOUBLE", - "change": "DOUBLE", - "pct_chg": "DOUBLE", - "vol": "DOUBLE", - "amount": "DOUBLE", - "turnover_rate": "DOUBLE", - "volume_ratio": "DOUBLE", - } - - # 索引定义 - TABLE_INDEXES = [ - ("idx_daily_date_code", ["trade_date", "ts_code"]), - ] - - # 主键定义 - PRIMARY_KEY = ("ts_code", "trade_date") - - def fetch_single_stock( - self, - ts_code: str, - start_date: str, - end_date: str, - ) -> pd.DataFrame: - """获取单只股票的日线数据。 - - Args: - ts_code: 股票代码 - start_date: 起始日期(YYYYMMDD) - end_date: 结束日期(YYYYMMDD) - - Returns: - 包含日线数据的 DataFrame - """ - # 使用共享客户端进行跨线程速率限制 - data = self.client.query( - "pro_bar", - ts_code=ts_code, - start_date=start_date, - end_date=end_date, - factors="tor,vr", - ) - return data - - -def sync_daily( - force_full: bool = False, - start_date: Optional[str] = None, - end_date: Optional[str] = None, - max_workers: Optional[int] = None, - dry_run: bool = False, -) -> Dict[str, pd.DataFrame]: - """同步所有股票的日线数据。 - - 这是日线数据同步的主要入口点。 - - Args: - force_full: 若为 True,强制从 20180101 完整重载 - start_date: 手动指定起始日期(YYYYMMDD) - end_date: 手动指定结束日期(默认为今天) - max_workers: 工作线程数(默认: 10) - dry_run: 若为 True,仅预览将要同步的内容,不写入数据 - - Returns: - 映射 ts_code 到 DataFrame 的字典 - - Example: - >>> # 首次同步(从 20180101 全量加载) - >>> result = sync_daily() - >>> - >>> # 后续同步(增量 - 仅新数据) - >>> result = sync_daily() - >>> - >>> # 强制完整重载 - >>> result = sync_daily(force_full=True) - >>> - >>> # 手动指定日期范围 - >>> result = sync_daily(start_date='20240101', end_date='20240131') - >>> - >>> # 自定义线程数 - >>> result = sync_daily(max_workers=20) - >>> - >>> # Dry run(仅预览) - >>> result = sync_daily(dry_run=True) - """ - sync_manager = DailySync(max_workers=max_workers) - return sync_manager.sync_all( - force_full=force_full, - start_date=start_date, - end_date=end_date, - dry_run=dry_run, - ) - - -def preview_daily_sync( - force_full: bool = False, - start_date: Optional[str] = None, - end_date: Optional[str] = None, - sample_size: int = 3, -) -> dict: - """预览日线同步数据量和样本(不实际同步)。 - - 这是推荐的方式,可在实际同步前检查将要同步的内容。 - - Args: - force_full: 若为 True,预览全量同步(从 20180101) - start_date: 手动指定起始日期(覆盖自动检测) - end_date: 手动指定结束日期(默认为今天) - sample_size: 预览用样本股票数量(默认: 3) - - Returns: - 包含预览信息的字典: - { - 'sync_needed': bool, - 'stock_count': int, - 'start_date': str, - 'end_date': str, - 'estimated_records': int, - 'sample_data': pd.DataFrame, - 'mode': str, # 'full', 'incremental', 'partial', 或 'none' - } - - Example: - >>> # 预览将要同步的内容 - >>> preview = preview_daily_sync() - >>> - >>> # 预览全量同步 - >>> preview = preview_daily_sync(force_full=True) - >>> - >>> # 预览更多样本 - >>> preview = preview_daily_sync(sample_size=5) - """ - sync_manager = DailySync() - return sync_manager.preview_sync( - force_full=force_full, - start_date=start_date, - end_date=end_date, - sample_size=sample_size, - )