feat(data): 添加每日指标接口并优化因子引擎

- 新增 api_daily_basic.py 封装 Tushare 每日指标接口
- 因子引擎移除 lookback_days,支持 daily_basic 表字段路由
- 将每日指标纳入自动同步流程
- 删除废弃的 training/main.py
This commit is contained in:
2026-03-03 17:09:39 +08:00
parent 780284af7f
commit 53225b9443
12 changed files with 1132 additions and 433 deletions

View File

@@ -5,6 +5,7 @@ All wrapper files follow the naming convention: api_{data_type}.py
Available APIs:
- api_daily: Daily market data (日线行情)
- api_daily_basic: Daily basic indicators (每日指标换手率、PE、PB、市值等)
- api_pro_bar: Pro Bar universal market data (通用行情,后复权)
- api_stock_basic: Stock basic information (股票基本信息)
- api_trade_cal: Trading calendar (交易日历)
@@ -13,9 +14,10 @@ Available APIs:
Example:
>>> from src.data.api_wrappers import get_daily, get_stock_basic, get_trade_cal, get_bak_basic
>>> from src.data.api_wrappers import get_pro_bar, sync_pro_bar
>>> from src.data.api_wrappers import get_pro_bar, sync_pro_bar, get_daily_basic, sync_daily_basic
>>> data = get_daily('000001.SZ', start_date='20240101', end_date='20240131')
>>> pro_data = get_pro_bar('000001.SZ', start_date='20240101', end_date='20240131')
>>> daily_basic = get_daily_basic(trade_date='20240101')
>>> stocks = get_stock_basic()
>>> calendar = get_trade_cal('20240101', '20240131')
>>> bak_basic = get_bak_basic(trade_date='20240101')
@@ -27,6 +29,12 @@ from src.data.api_wrappers.api_daily import (
preview_daily_sync,
DailySync,
)
from src.data.api_wrappers.api_daily_basic import (
get_daily_basic,
sync_daily_basic,
preview_daily_basic_sync,
DailyBasicSync,
)
from src.data.api_wrappers.api_pro_bar import (
get_pro_bar,
sync_pro_bar,
@@ -55,6 +63,11 @@ __all__ = [
"sync_daily",
"preview_daily_sync",
"DailySync",
# Daily basic indicators
"get_daily_basic",
"sync_daily_basic",
"preview_daily_basic_sync",
"DailyBasicSync",
# Pro Bar (universal market data)
"get_pro_bar",
"sync_pro_bar",

View File

@@ -495,4 +495,74 @@ df = ts.pro_bar(ts_code='000001.SZ', start_date='20180101', end_date='20181011',
例如:
df = ts.pro_bar(ts_code='000001.SH', asset='I', start_date='20180101', end_date='20181011')
df = ts.pro_bar(ts_code='000001.SH', asset='I', start_date='20180101', end_date='20181011')
每日指标
接口daily_basic可以通过数据工具调试和查看数据。
更新时间交易日每日15点17点之间
描述获取全部股票每日重要的基本面指标可用于选股分析、报表展示等。单次请求最大返回6000条数据可按日线循环提取全部历史。
积分至少2000积分才可以调取5000积分无总量限制具体请参阅积分获取办法
输入参数
名称 类型 必选 描述
ts_code str Y 股票代码(二选一)
trade_date str N 交易日期 (二选一)
start_date str N 开始日期(YYYYMMDD)
end_date str N 结束日期(YYYYMMDD)
日期都填YYYYMMDD格式比如20181010
输出参数
名称 类型 描述
ts_code str TS股票代码
trade_date str 交易日期
close float 当日收盘价
turnover_rate float 换手率(%
turnover_rate_f float 换手率(自由流通股)
volume_ratio float 量比
pe float 市盈率(总市值/净利润, 亏损的PE为空
pe_ttm float 市盈率TTM亏损的PE为空
pb float 市净率(总市值/净资产)
ps float 市销率
ps_ttm float 市销率TTM
dv_ratio float 股息率 %
dv_ttm float 股息率TTM%
total_share float 总股本 (万股)
float_share float 流通股本 (万股)
free_share float 自由流通股本 (万)
total_mv float 总市值 (万元)
circ_mv float 流通市值(万元)
接口用法
pro = ts.pro_api()
df = pro.daily_basic(ts_code='', trade_date='20180726', fields='ts_code,trade_date,turnover_rate,volume_ratio,pe,pb')
或者
df = pro.query('daily_basic', ts_code='', trade_date='20180726',fields='ts_code,trade_date,turnover_rate,volume_ratio,pe,pb')
数据样例
ts_code trade_date turnover_rate volume_ratio pe pb
0 600230.SH 20180726 2.4584 0.72 8.6928 3.7203
1 600237.SH 20180726 1.4737 0.88 166.4001 1.8868
2 002465.SZ 20180726 0.7489 0.72 71.8943 2.6391
3 300732.SZ 20180726 6.7083 0.77 21.8101 3.2513
4 600007.SH 20180726 0.0381 0.61 23.7696 2.3774
5 300068.SZ 20180726 1.4583 0.52 27.8166 1.7549
6 300552.SZ 20180726 2.0728 0.95 56.8004 2.9279
7 601369.SH 20180726 0.2088 0.95 44.1163 1.8001
8 002518.SZ 20180726 0.5814 0.76 15.1004 2.5626
9 002913.SZ 20180726 12.1096 1.03 33.1279 2.9217
10 601818.SH 20180726 0.1893 0.86 6.3064 0.7209
11 600926.SH 20180726 0.6065 0.46 9.1772 0.9808
12 002166.SZ 20180726 0.7582 0.82 16.9868 3.3452
13 600841.SH 20180726 0.3754 1.02 66.2647 2.2302
14 300634.SZ 20180726 23.1127 1.26 120.3053 14.3168
15 300126.SZ 20180726 1.2304 1.11 348.4306 1.5171
16 300718.SZ 20180726 17.6612 0.92 32.0239 3.8661
17 000708.SZ 20180726 0.5575 0.70 10.3674 1.0276
18 002626.SZ 20180726 0.6187 0.83 22.7580 4.2446
19 600816.SH 20180726 0.6745 0.65 11.0778 3.2214

View File

@@ -0,0 +1,252 @@
"""每日指标数据接口。
获取全部股票每日重要的基本面指标,包括换手率、市盈率、市净率、
总市值、流通市值等,可用于选股分析、报表展示等。
"""
from typing import Optional, Dict, Any
import pandas as pd
from src.data.client import TushareClient
from src.data.api_wrappers.base_sync import DateBasedSync
def get_daily_basic(
trade_date: Optional[str] = None,
ts_code: Optional[str] = None,
start_date: Optional[str] = None,
end_date: Optional[str] = None,
client: Optional[TushareClient] = None,
) -> pd.DataFrame:
"""Fetch daily basic indicators from Tushare.
This interface retrieves important daily fundamental indicators for all stocks,
including turnover rate, PE, PB, market value, etc. It can be used for stock
selection analysis and report display.
Note: At least one of trade_date or ts_code must be provided. The recommended
approach is to use trade_date to fetch data for all stocks on a specific date,
which is more efficient than fetching by individual stock codes.
Args:
trade_date: Specific trade date (YYYYMMDD format). Use this to get all
stocks' data for a single date. More efficient than ts_code.
ts_code: Stock code (e.g., '000001.SZ', '600000.SH'). Optional if
trade_date is provided.
start_date: Start date (YYYYMMDD format). Use with end_date for date range.
end_date: End date (YYYYMMDD format). Use with start_date for date range.
client: Optional TushareClient instance for shared rate limiting.
If None, creates a new client. For concurrent sync operations,
pass a shared client to ensure proper rate limiting.
Returns:
pd.DataFrame with columns:
- ts_code: TS stock code
- trade_date: Trade date (YYYYMMDD)
- close: Closing price
- turnover_rate: Turnover rate (%)
- turnover_rate_f: Turnover rate (free float shares)
- volume_ratio: Volume ratio
- pe: Price-to-earnings ratio (total market cap / net profit)
- pe_ttm: PE ratio (TTM)
- pb: Price-to-book ratio (total market cap / net assets)
- ps: Price-to-sales ratio
- ps_ttm: PS ratio (TTM)
- dv_ratio: Dividend yield (%)
- dv_ttm: Dividend yield (TTM) (%)
- total_share: Total shares (10k shares)
- float_share: Float shares (10k shares)
- free_share: Free float shares (10k shares)
- total_mv: Total market value (10k CNY)
- circ_mv: Circulating market value (10k CNY)
Example:
>>> # Get all stocks for a single date (recommended)
>>> data = get_daily_basic(trade_date='20240101')
>>>
>>> # Get specific stock data
>>> data = get_daily_basic(ts_code='000001.SZ', trade_date='20240101')
>>>
>>> # Get date range data for a specific stock
>>> data = get_daily_basic(
... ts_code='000001.SZ',
... start_date='20240101',
... end_date='20240131'
... )
"""
client = client or TushareClient()
# Build parameters
params = {}
if trade_date:
params["trade_date"] = trade_date
if ts_code:
params["ts_code"] = ts_code
if start_date:
params["start_date"] = start_date
if end_date:
params["end_date"] = end_date
# Fetch data using daily_basic API
data = client.query("daily_basic", **params)
# Rename date column if needed
if "date" in data.columns:
data = data.rename(columns={"date": "trade_date"})
return data
class DailyBasicSync(DateBasedSync):
"""每日指标数据批量同步管理器,支持全量/增量同步。
继承自 DateBasedSync按日期顺序获取数据。
每日指标数据适合按日期获取,一次 API 调用即可获取全市场数据。
Example:
>>> sync = DailyBasicSync()
>>> results = sync.sync_all() # 增量同步
>>> results = sync.sync_all(force_full=True) # 全量同步
>>> preview = sync.preview_sync() # 预览
"""
table_name = "daily_basic"
default_start_date = "20180101"
# 表结构定义
TABLE_SCHEMA = {
"ts_code": "VARCHAR(16) NOT NULL",
"trade_date": "DATE NOT NULL",
"close": "DOUBLE",
"turnover_rate": "DOUBLE",
"turnover_rate_f": "DOUBLE",
"volume_ratio": "DOUBLE",
"pe": "DOUBLE",
"pe_ttm": "DOUBLE",
"pb": "DOUBLE",
"ps": "DOUBLE",
"ps_ttm": "DOUBLE",
"dv_ratio": "DOUBLE",
"dv_ttm": "DOUBLE",
"total_share": "DOUBLE",
"float_share": "DOUBLE",
"free_share": "DOUBLE",
"total_mv": "DOUBLE",
"circ_mv": "DOUBLE",
}
# 索引定义
TABLE_INDEXES = [
("idx_daily_basic_date_code", ["trade_date", "ts_code"]),
]
# 主键定义
PRIMARY_KEY = ("ts_code", "trade_date")
def fetch_single_date(self, trade_date: str) -> pd.DataFrame:
"""获取单日的每日指标数据。
Args:
trade_date: 交易日期YYYYMMDD
Returns:
包含当日所有股票指标的 DataFrame
"""
# 使用 get_daily_basic 获取数据(传递共享 client
data = get_daily_basic(
trade_date=trade_date,
client=self.client, # 传递共享客户端以确保限流
)
return data
def sync_daily_basic(
force_full: bool = False,
start_date: Optional[str] = None,
end_date: Optional[str] = None,
dry_run: bool = False,
) -> pd.DataFrame:
"""同步所有股票的每日指标数据。
这是每日指标数据同步的主要入口点。
Args:
force_full: 若为 True强制从 20180101 完整重载
start_date: 手动指定起始日期YYYYMMDD
end_date: 手动指定结束日期(默认为今天)
dry_run: 若为 True仅预览将要同步的内容不写入数据
Returns:
同步的数据 DataFrame
Example:
>>> # 首次同步(从 20180101 全量加载)
>>> result = sync_daily_basic()
>>>
>>> # 后续同步(增量 - 仅新数据)
>>> result = sync_daily_basic()
>>>
>>> # 强制完整重载
>>> result = sync_daily_basic(force_full=True)
>>>
>>> # 手动指定日期范围
>>> result = sync_daily_basic(start_date='20240101', end_date='20240131')
>>>
>>> # Dry run仅预览
>>> result = sync_daily_basic(dry_run=True)
"""
sync_manager = DailyBasicSync()
return sync_manager.sync_all(
force_full=force_full,
start_date=start_date,
end_date=end_date,
dry_run=dry_run,
)
def preview_daily_basic_sync(
force_full: bool = False,
start_date: Optional[str] = None,
end_date: Optional[str] = None,
sample_size: int = 3,
) -> Dict[str, Any]:
"""预览每日指标同步数据量和样本(不实际同步)。
这是推荐的方式,可在实际同步前检查将要同步的内容。
Args:
force_full: 若为 True预览全量同步从 20180101
start_date: 手动指定起始日期(覆盖自动检测)
end_date: 手动指定结束日期(默认为今天)
sample_size: 预览用样本天数(默认: 3
Returns:
包含预览信息的字典:
{
'sync_needed': bool,
'date_count': int,
'start_date': str,
'end_date': str,
'estimated_records': int,
'sample_data': pd.DataFrame,
'mode': str, # 'full', 'incremental', 或 'none'
}
Example:
>>> # 预览将要同步的内容
>>> preview = preview_daily_basic_sync()
>>>
>>> # 预览全量同步
>>> preview = preview_daily_basic_sync(force_full=True)
>>>
>>> # 预览更多样本
>>> preview = preview_daily_basic_sync(sample_size=5)
"""
sync_manager = DailyBasicSync()
return sync_manager.preview_sync(
force_full=force_full,
start_date=start_date,
end_date=end_date,
sample_size=sample_size,
)

View File

@@ -7,6 +7,7 @@
✅ 本模块包含的同步逻辑(每日更新):
- api_daily.py: 日线数据同步 (DailySync 类)
- api_daily_basic.py: 每日指标数据同步 (DailyBasicSync 类)
- api_bak_basic.py: 历史股票列表同步 (BakBasicSync 类)
- api_pro_bar.py: Pro Bar 数据同步 (ProBarSync 类)
- api_stock_basic.py: 股票基本信息同步
@@ -44,6 +45,7 @@ from src.data.api_wrappers import sync_all_stocks
from src.data.api_wrappers.api_daily import sync_daily, preview_daily_sync
from src.data.api_wrappers.api_pro_bar import sync_pro_bar
from src.data.api_wrappers.api_bak_basic import sync_bak_basic
from src.data.api_wrappers.api_daily_basic import sync_daily_basic
def preview_sync(
@@ -157,7 +159,8 @@ def sync_all_data(
2. 股票基本信息 (sync_all_stocks)
3. 日线数据 (sync_daily)
4. Pro Bar 数据 (sync_pro_bar)
5. 历史股票列表 (sync_bak_basic)
5. 每日指标数据 (sync_daily_basic)
6. 历史股票列表 (sync_bak_basic)
【不包含的同步(需单独调用)】
- 财务数据: 利润表、资产负债表、现金流量表(季度更新)
@@ -238,7 +241,7 @@ def sync_all_data(
results["daily"] = pd.DataFrame()
# 4. Sync Pro Bar data
print("\n[4/5] Syncing Pro Bar data (with adj, tor, vr)...")
print("\n[4/6] Syncing Pro Bar data (with adj, tor, vr)...")
try:
# 确保表存在
from src.data.api_wrappers.api_pro_bar import ProBarSync
@@ -255,14 +258,31 @@ def sync_all_data(
sum(len(df) for df in pro_bar_result.values()) if pro_bar_result else 0
)
print(
f"[4/5] Pro Bar data: OK ({total_pro_bar_records} records from {len(pro_bar_result)} stocks)"
f"[4/6] Pro Bar data: OK ({total_pro_bar_records} records from {len(pro_bar_result)} stocks)"
)
except Exception as e:
print(f"[4/5] Pro Bar data: FAILED - {e}")
print(f"[4/6] Pro Bar data: FAILED - {e}")
results["pro_bar"] = pd.DataFrame()
# 5. Sync stock historical list (bak_basic)
print("\n[5/5] Syncing stock historical list (bak_basic)...")
# 5. Sync daily basic indicators
print(
"\n[5/6] Syncing daily basic indicators (PE, PB, turnover rate, market value)..."
)
try:
# 确保表存在
from src.data.api_wrappers.api_daily_basic import DailyBasicSync
DailyBasicSync().ensure_table_exists()
daily_basic_result = sync_daily_basic(force_full=force_full, dry_run=dry_run)
results["daily_basic"] = daily_basic_result
print(f"[5/6] Daily basic: OK ({len(daily_basic_result)} records)")
except Exception as e:
print(f"[5/6] Daily basic: FAILED - {e}")
results["daily_basic"] = pd.DataFrame()
# 6. Sync stock historical list (bak_basic)
print("\n[6/6] Syncing stock historical list (bak_basic)...")
try:
# 确保表存在
from src.data.api_wrappers.api_bak_basic import BakBasicSync
@@ -271,9 +291,9 @@ def sync_all_data(
bak_basic_result = sync_bak_basic(force_full=force_full)
results["bak_basic"] = bak_basic_result
print(f"[5/5] Bak basic: OK ({len(bak_basic_result)} records)")
print(f"[6/6] Bak basic: OK ({len(bak_basic_result)} records)")
except Exception as e:
print(f"[5/5] Bak basic: FAILED - {e}")
print(f"[6/6] Bak basic: FAILED - {e}")
results["bak_basic"] = pd.DataFrame()
# Summary
@@ -286,7 +306,7 @@ def sync_all_data(
total_records = sum(len(df) for df in data.values())
print(f" {data_type}: {len(data)} stocks, {total_records} total records")
else:
# bak_basic 返回的是 DataFrame
# daily_basic 和 bak_basic 返回的是 DataFrame
print(f" {data_type}: {len(data)} records")
print("=" * 60)
print("\nNote: namechange is NOT in auto-sync. To sync manually:")
@@ -308,7 +328,7 @@ if __name__ == "__main__":
print("")
print(" # Or sync individual data types:")
print(" from src.data.sync import sync_all, preview_sync")
print(" from src.data.sync import sync_bak_basic")
print(" from src.data.api_wrappers import sync_daily_basic, sync_bak_basic")
print("")
print(" # Preview before sync (recommended)")
print(" preview = preview_sync()")

View File

@@ -69,16 +69,11 @@ class DataRouter:
# 收集所有需要的表和字段
required_tables: Dict[str, Set[str]] = {}
max_lookback = 0
for spec in data_specs:
if spec.table not in required_tables:
required_tables[spec.table] = set()
required_tables[spec.table].update(spec.columns)
max_lookback = max(max_lookback, spec.lookback_days)
# 调整日期范围以包含回看期
adjusted_start = self._adjust_start_date(start_date, max_lookback)
# 从数据源获取各表数据
table_data = {}
@@ -86,7 +81,7 @@ class DataRouter:
df = self._load_table(
table_name=table_name,
columns=list(columns),
start_date=adjusted_start,
start_date=start_date,
end_date=end_date,
stock_codes=stock_codes,
)
@@ -95,11 +90,6 @@ class DataRouter:
# 组装核心宽表
core_table = self._assemble_wide_table(table_data, required_tables)
# 过滤到实际请求日期范围
core_table = core_table.filter(
(pl.col("trade_date") >= start_date) & (pl.col("trade_date") <= end_date)
)
return core_table
def _load_table(
@@ -265,34 +255,6 @@ class DataRouter:
return result
def _adjust_start_date(self, start_date: str, lookback_days: int) -> str:
"""根据回看天数调整开始日期。
Args:
start_date: 原始开始日期 (YYYYMMDD)
lookback_days: 需要回看的交易日数
Returns:
调整后的开始日期
"""
# 简化的日期调整假设每月30天向前推移
# 实际应用中应该使用交易日历
year = int(start_date[:4])
month = int(start_date[4:6])
day = int(start_date[6:8])
total_days = lookback_days + 30 # 额外缓冲
day -= total_days
while day <= 0:
month -= 1
if month <= 0:
month = 12
year -= 1
day += 30
return f"{year:04d}{month:02d}{day:02d}"
def clear_cache(self) -> None:
"""清除数据缓存。"""
with self._lock:

View File

@@ -18,12 +18,10 @@ class DataSpec:
Attributes:
table: 数据表名称
columns: 需要的字段列表
lookback_days: 回看天数(用于时序计算)
"""
table: str
columns: List[str]
lookback_days: int = 1
@dataclass

View File

@@ -73,9 +73,9 @@ class ExecutionPlanner:
) -> List[DataSpec]:
"""从依赖推导数据规格。
根据表达式中的函数类型推断回看天数需求。
基础行情字段open, high, low, close, vol, amount, pre_close, change, pct_chg
默认从 pro_bar 表获取。
每日指标字段total_mv, circ_mv, pe, pb 等)从 daily_basic 表获取。
Args:
dependencies: 依赖的字段集合
@@ -84,10 +84,6 @@ class ExecutionPlanner:
Returns:
数据规格列表
"""
# 计算最大回看窗口
max_window = self._extract_max_window(expression)
lookback_days = max(1, max_window)
# 基础行情字段集合(这些字段从 pro_bar 表获取)
pro_bar_fields = {
"open",
@@ -103,9 +99,27 @@ class ExecutionPlanner:
"volume_ratio",
}
# 将依赖分为 pro_bar 字段和其他字段
# 每日指标字段集合(这些字段从 daily_basic 表获取)
daily_basic_fields = {
"turnover_rate_f",
"pe",
"pe_ttm",
"pb",
"ps",
"ps_ttm",
"dv_ratio",
"dv_ttm",
"total_share",
"float_share",
"free_share",
"total_mv",
"circ_mv",
}
# 将依赖分为不同表的字段
pro_bar_deps = dependencies & pro_bar_fields
other_deps = dependencies - pro_bar_fields
daily_basic_deps = dependencies & daily_basic_fields
other_deps = dependencies - pro_bar_fields - daily_basic_fields
data_specs = []
@@ -115,7 +129,15 @@ class ExecutionPlanner:
DataSpec(
table="pro_bar",
columns=sorted(pro_bar_deps),
lookback_days=lookback_days,
)
)
# daily_basic 表的数据规格
if daily_basic_deps:
data_specs.append(
DataSpec(
table="daily_basic",
columns=sorted(daily_basic_deps),
)
)
@@ -125,46 +147,7 @@ class ExecutionPlanner:
DataSpec(
table="daily",
columns=sorted(other_deps),
lookback_days=lookback_days,
)
)
return data_specs
def _extract_max_window(self, node: Node) -> int:
"""从表达式中提取最大窗口大小。
Args:
node: AST 节点
Returns:
最大窗口大小,无时序函数返回 1
"""
if isinstance(node, FunctionNode):
window = 1
# 检查函数参数中的窗口大小
for arg in node.args:
if (
isinstance(arg, Constant)
and isinstance(arg.value, int)
and arg.value > window
):
window = arg.value
# 递归检查子表达式
for arg in node.args:
if isinstance(arg, Node) and not isinstance(arg, Constant):
window = max(window, self._extract_max_window(arg))
return window
elif isinstance(node, BinaryOpNode):
return max(
self._extract_max_window(node.left),
self._extract_max_window(node.right),
)
elif isinstance(node, UnaryOpNode):
return self._extract_max_window(node.operand)
return 1

View File

@@ -1,302 +0,0 @@
"""训练流程入口脚本
运行方式:
uv run python -m src.training.main
或:
uv run python src/training/main.py
本脚本提供两种运行方式:
1. run_full_pipeline(): 完整训练流程(数据准备 -> 训练 -> 预测)
2. prepare_data_and_train() + train_and_predict(): 分步执行,便于调试和调整
因子配置示例:
from src.factors import MovingAverageFactor, ReturnRankFactor
# 直接传入因子实例列表 - 最简单的方式
factors = [
MovingAverageFactor(period=5),
MovingAverageFactor(period=10),
MovingAverageFactor(period=20),
ReturnRankFactor(period=5),
ReturnRankFactor(period=10),
]
# 运行完整流程
result = run_full_pipeline(factors=factors)
"""
from pathlib import Path
from typing import Optional, List
import polars as pl
from src.factors import BaseFactor
from src.training.pipeline import (
FactorConfig,
predict_top_stocks,
prepare_data,
save_top_stocks,
train_model,
)
def prepare_data_and_train(
factors: Optional[List[BaseFactor]] = None,
data_dir: str = "data",
train_start: str = "20190101",
train_end: str = "20231231",
val_start: str = "20240102",
val_end: str = "20240531",
test_start: str = "20240602",
test_end: str = "20241231",
) -> tuple[pl.DataFrame, pl.DataFrame, pl.DataFrame, FactorConfig, str]:
"""第一步:数据处理
加载原始数据,计算因子和标签,拆分训练/验证/测试集。
Args:
factors: 因子实例列表,默认为 None使用 MA5, MA10, ReturnRank5
data_dir: 数据目录
train_start: 训练集开始日期
train_end: 训练集结束日期
val_start: 验证集开始日期
val_end: 验证集结束日期
test_start: 测试集开始日期
test_end: 测试集结束日期
Returns:
tuple: (train_data, val_data, test_data, factor_config, label_col)
"""
print("=" * 50)
print("[Step 1] 数据处理")
print("=" * 50)
print(f"训练集: {train_start} -> {train_end}")
print(f"验证集: {val_start} -> {val_end}")
print(f"测试集: {test_start} -> {test_end}")
print()
# 1. 准备数据
train_data, val_data, test_data, factor_config = prepare_data(
factors=factors,
data_dir=data_dir,
train_start=train_start,
train_end=train_end,
val_start=val_start,
val_end=val_end,
test_start=test_start,
test_end=test_end,
)
print(f"训练集样本数: {len(train_data)}")
print(f"验证集样本数: {len(val_data)}")
print(f"测试集样本数: {len(test_data)}")
print()
# 打印少量数据样本展示
print("=" * 50)
print("[数据预览] 训练集前3行:")
print(train_data.head(3))
print()
print("[数据预览] 验证集前3行:")
print(val_data.head(3))
print()
print("[数据预览] 测试集前3行:")
print(test_data.head(3))
print()
# 2. 获取特征列名
feature_cols = factor_config.get_feature_names()
label_col = "label"
print(f"特征列: {feature_cols}")
print(f"标签列: {label_col}")
print()
return train_data, val_data, test_data, factor_config, label_col
def train_and_predict(
train_data: pl.DataFrame,
val_data: pl.DataFrame,
test_data: pl.DataFrame,
factor_config: FactorConfig,
label_col: str = "label",
top_n: int = 5,
output_path: str = "output/top_stocks.tsv",
) -> pl.DataFrame:
"""第二步:训练和预测
使用处理好的数据训练模型,进行测试集预测并保存结果。
Args:
train_data: 训练数据
val_data: 验证数据
test_data: 测试数据
factor_config: 因子配置对象
label_col: 标签列名
top_n: 每日选股数量
output_path: 输出文件路径
Returns:
选股结果DataFrame
"""
print("=" * 50)
print("[Step 2] 模型训练与预测")
print("=" * 50)
print()
# 获取特征列名
feature_cols = factor_config.get_feature_names()
print(f"使用特征: {feature_cols}")
print()
# 3. 训练模型
print("[Training] Training model...")
model, pipeline = train_model(
train_data=train_data,
val_data=val_data,
feature_cols=feature_cols,
label_col=label_col,
)
print()
# 4. 测试集预测
print("[Predict] Predicting on test set...")
top_stocks = predict_top_stocks(
model=model,
pipeline=pipeline,
test_data=test_data,
feature_cols=feature_cols,
top_n=top_n,
)
print()
# 5. 保存结果
print(f"[Saving] Saving results to {output_path}...")
Path(output_path).parent.mkdir(parents=True, exist_ok=True)
save_top_stocks(top_stocks, output_path)
print()
return top_stocks
def run_full_pipeline(
factors: Optional[List[BaseFactor]] = None,
train_start: str = "20190101",
train_end: str = "20231231",
val_start: str = "20240102",
val_end: str = "20240531",
test_start: str = "20240602",
test_end: str = "20241231",
top_n: int = 5,
output_path: str = "output/top_stocks.tsv",
) -> pl.DataFrame:
"""运行完整训练流程
相当于依次调用 prepare_data_and_train 和 train_and_predict。
Args:
factors: 因子实例列表,默认为 None使用 MA5, MA10, ReturnRank5
train_start: 训练集开始日期
train_end: 训练集结束日期
val_start: 验证集开始日期
val_end: 验证集结束日期
test_start: 测试集开始日期
test_end: 测试集结束日期
top_n: 每日选股数量
output_path: 输出文件路径
Returns:
选股结果DataFrame
"""
# 第一步:数据处理
train_data, val_data, test_data, factor_config, label_col = prepare_data_and_train(
factors=factors,
train_start=train_start,
train_end=train_end,
val_start=val_start,
val_end=val_end,
test_start=test_start,
test_end=test_end,
)
# 第二步:训练和预测
result = train_and_predict(
train_data=train_data,
val_data=val_data,
test_data=test_data,
factor_config=factor_config,
label_col=label_col,
top_n=top_n,
output_path=output_path,
)
print("=" * 50)
print("[Done] 训练流程完成!")
print("=" * 50)
return result
if __name__ == "__main__":
from src.factors import MovingAverageFactor, ReturnRankFactor
# ========== 因子配置 ==========
# 直接传入因子实例列表 - 简单直观
factors = [
MovingAverageFactor(period=5), # 5日移动平均线
MovingAverageFactor(period=10), # 10日移动平均线
MovingAverageFactor(period=20), # 20日移动平均线
ReturnRankFactor(period=5), # 5日收益率排名
ReturnRankFactor(period=10), # 10日收益率排名
]
# ========== 运行方式 ==========
# 方式一:完整流程(一次性执行)
# result = run_full_pipeline(
# factors=factors,
# train_start="20190101",
# train_end="20231231",
# val_start="20240102",
# val_end="20240531",
# test_start="20240602",
# test_end="20241231",
# top_n=5,
# output_path="output/top_stocks.tsv",
# )
# 方式二:分步执行(便于调试)
# 第一步:数据处理
train_data, val_data, test_data, factor_config, label_col = prepare_data_and_train(
factors=factors,
train_start="20190101",
train_end="20231231",
val_start="20240102",
val_end="20240531",
test_start="20240602",
test_end="20241231",
)
# 可在此处添加自定义逻辑,例如:
# - 查看数据分布
# - 调整特征
# - 保存中间结果
print("\n[Info] 因子配置详情:")
print(f" 因子列表: {factor_config.get_feature_names()}")
print(f" 最大回溯天数: {factor_config.get_max_lookback()}")
# 第二步:训练和预测
# result = train_and_predict(
# train_data=train_data,
# val_data=val_data,
# test_data=test_data,
# factor_config=factor_config,
# label_col=label_col,
# top_n=5,
# output_path="output/top_stocks.tsv",
# )
#
# print("\n[Result] Top stocks selection:")
# print(result)