refactor(factor): 完全重构因子计算框架 - 引入DSL表达式系统

- 删除旧因子框架:移除 base.py、composite.py、data_loader.py、data_spec.py
  及所有子模块(momentum、financial、quality、sentiment等)
- 新增DSL表达式系统:实现 factor DSL 编译器和翻译器
  - dsl.py: 领域特定语言定义
  - compiler.py: AST编译与优化
  - translator.py: Polars表达式翻译
  - api.py: 统一API接口
- 新增数据路由层:data_router.py 实现字段到表的动态路由
- 新增API封装:api_pro_bar.py 提供pro_bar数据接口
- 更新执行引擎:engine.py 适配新的DSL架构
- 重构测试体系:删除旧测试,新增 test_dsl_promotion.py、
  test_factor_integration.py、test_pro_bar.py
- 清理文档:删除8个过时文档(factor_design、db_sync_guide等)
This commit is contained in:
2026-02-27 22:22:23 +08:00
parent c3c20ed7ea
commit a56433e440
51 changed files with 667 additions and 11287 deletions

View File

@@ -3,3 +3,17 @@
提供股票数据分析和交易策略等功能
"""
__version__ = "1.0.0"
import warnings
from pandas.errors import SettingWithCopyWarning
# 忽略 tushare 库的 FutureWarningfillna method 参数已弃用)
warnings.filterwarnings(
"ignore",
category=FutureWarning,
message=".*fillna with 'method' is deprecated.*",
)
# 忽略 SettingWithCopyWarning常见于 pandas 链式赋值)
warnings.filterwarnings("ignore", category=SettingWithCopyWarning)

View File

@@ -2,6 +2,7 @@
从环境变量加载应用配置,使用 pydantic-settings 进行类型验证
"""
import os
from pathlib import Path
from pydantic_settings import BaseSettings
@@ -15,20 +16,33 @@ CONFIG_DIR = PROJECT_ROOT / "config"
class Settings(BaseSettings):
"""应用配置类,从环境变量加载"""
"""应用配置类,从环境变量加载
# 数据库配置
所有配置项都会自动从环境变量读取(小写转大写)
例如tushare_token 会读取 TUSHARE_TOKEN 环境变量
"""
# Tushare API 配置
tushare_token: str = ""
# 数据存储配置
root_path: str = "" # 项目根路径,默认自动检测
data_path: str = "data" # 数据存储路径,相对于 root_path
# API 速率限制(每分钟请求数)
rate_limit: int = 300
# 同步工作线程数
threads: int = 10
# 数据库配置(可选,用于未来扩展)
database_host: str = "localhost"
database_port: int = 5432
database_name: str = "prostock"
database_user: str
database_password: str
database_user: Optional[str] = None
database_password: Optional[str] = None
# API密钥配置
api_key: str
secret_key: str
# Redis配置
# Redis配置可选用于未来扩展
redis_host: str = "localhost"
redis_port: int = 6379
redis_password: Optional[str] = None
@@ -38,11 +52,27 @@ class Settings(BaseSettings):
app_debug: bool = False
app_port: int = 8000
@property
def project_root(self) -> Path:
"""获取项目根路径。"""
if self.root_path:
return Path(self.root_path)
return PROJECT_ROOT
@property
def data_path_resolved(self) -> Path:
"""获取解析后的数据路径(绝对路径)。"""
path = Path(self.data_path)
if path.is_absolute():
return path
return self.project_root / path
class Config:
# 从 config/ 目录读取 .env.local 文件
env_file = str(CONFIG_DIR / ".env.local")
env_file_encoding = "utf-8"
case_sensitive = False
extra = "ignore" # 忽略 .env.local 中的额外变量
@lru_cache()

View File

@@ -3,7 +3,7 @@
Provides simplified interfaces for fetching and storing Tushare data.
"""
from src.data.config import Config, get_config
from src.config.settings import Settings, get_settings, settings
from src.data.client import TushareClient
from src.data.storage import Storage, ThreadSafeStorage, DEFAULT_TYPE_MAPPING
from src.data.api_wrappers import get_stock_basic, sync_all_stocks

View File

@@ -169,6 +169,120 @@ if "date" in data.columns:
### 4.5 令牌桶限速要求
所有 API 调用必须通过 `TushareClient`,自动满足令牌桶限速要求。
#### 4.5.1 基本用法(单线程场景)
```python
from src.data.client import TushareClient
def get_{data_type}(...) -> pd.DataFrame:
client = TushareClient()
# Build parameters
params = {}
if trade_date:
params["trade_date"] = trade_date
if ts_code:
params["ts_code"] = ts_code
# ...
# Fetch data (rate limiting handled automatically)
data = client.query("{api_name}", **params)
return data
```
**注意**: `TushareClient` 自动处理:
- 令牌桶速率限制
- API 重试逻辑(指数退避)
- 配置加载
#### 4.5.2 多线程/并发场景(重要)
**问题**: 多线程并发调用时,如果每个线程创建独立的 `TushareClient` 实例,每个实例会有独立的限流器,导致实际并发请求数 = 线程数 × 单个限流器速率,**限流失效**。
**解决方案**: 数据获取函数必须接受可选的 `client` 参数,由 Sync 类传递共享的客户端实例。
**数据获取函数签名**(必须支持 client 参数):
```python
from src.data.client import TushareClient
from typing import Optional
def get_{data_type}(
ts_code: str,
start_date: Optional[str] = None,
end_date: Optional[str] = None,
client: Optional[TushareClient] = None, # 新增:可选客户端参数
) -> pd.DataFrame:
"""Fetch {数据描述} from Tushare.
Args:
ts_code: Stock code
start_date: Start date (YYYYMMDD)
end_date: End date (YYYYMMDD)
client: Optional TushareClient instance for shared rate limiting.
If None, creates a new client. For concurrent sync operations,
pass a shared client to ensure proper rate limiting.
Returns:
pd.DataFrame with data
"""
client = client or TushareClient() # 如果没有提供则创建新实例
params = {"ts_code": ts_code}
if start_date:
params["start_date"] = start_date
if end_date:
params["end_date"] = end_date
data = client.query("{api_name}", **params)
return data
```
**Sync 类实现**(必须传递共享 client:
```python
from concurrent.futures import ThreadPoolExecutor
from src.data.client import TushareClient
from src.data.storage import ThreadSafeStorage
class {DataType}Sync:
def __init__(self, max_workers: Optional[int] = None):
self.storage = ThreadSafeStorage()
self.client = TushareClient() # 共享客户端实例
self.max_workers = max_workers or 10
def sync_single_stock(
self,
ts_code: str,
start_date: str,
end_date: str,
) -> pd.DataFrame:
"""同步单只股票的数据。"""
# 传递共享 client 以确保多线程下的限流生效
data = get_{data_type}(
ts_code=ts_code,
start_date=start_date,
end_date=end_date,
client=self.client, # 关键:传递共享客户端
)
return data
def sync_all(self, ...):
# 使用 ThreadPoolExecutor 并发执行
with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
# 所有线程共享 self.client限流器正常工作
...
```
**关键规则**:
1. 所有按股票获取的接口必须接受 `client: Optional[TushareClient] = None` 参数
2. Sync 类在 `__init__` 中创建 `self.client = TushareClient()`
3. Sync 类的同步方法必须将 `self.client` 传递给数据获取函数
4. 数据获取函数使用 `client = client or TushareClient()` 模式
所有 API 调用必须通过 `TushareClient`,自动满足令牌桶限速要求:
```python
@@ -198,6 +312,26 @@ def get_{data_type}(...) -> pd.DataFrame:
## 5. DuckDB 存储规范
### 5.0 强制落库要求(关键)
**所有封装的 API 接口必须将数据落库到 DuckDB。**
这是数据同步的核心原则,确保:
- 数据持久化:避免重复调用 API节省 token
- 增量更新:基于本地数据状态进行智能同步
- 数据一致性:所有数据都有统一的存储和访问方式
- 离线可用:数据可以在没有网络的情况下查询
**落库检查清单**
- [ ]`storage.py``_init_db()` 方法中创建对应的表
- [ ] 表结构必须包含 `ts_code``trade_date` 作为主键
- [ ] 实现 `sync_{data_type}()` 函数,使用 `Storage``ThreadSafeStorage` 保存数据
- [ ] 确保同步逻辑正确处理增量更新
**反例警示**`api_pro_bar.py` 早期版本虽然实现了 `sync_pro_bar()` 函数,但忘记在 `storage.py` 中创建 `pro_bar` 表,导致同步的数据无法落库,造成 token 浪费和数据丢失。
### 5.1 存储架构
项目使用 **DuckDB** 作为持久化存储:

View File

@@ -5,6 +5,7 @@ All wrapper files follow the naming convention: api_{data_type}.py
Available APIs:
- api_daily: Daily market data (日线行情)
- api_pro_bar: Pro Bar universal market data (通用行情,后复权)
- api_stock_basic: Stock basic information (股票基本信息)
- api_trade_cal: Trading calendar (交易日历)
- api_namechange: Stock name change history (股票曾用名)
@@ -12,15 +13,31 @@ Available APIs:
Example:
>>> from src.data.api_wrappers import get_daily, get_stock_basic, get_trade_cal, get_bak_basic
>>> from src.data.api_wrappers import get_bak_basic, sync_bak_basic
>>> from src.data.api_wrappers import get_pro_bar, sync_pro_bar
>>> data = get_daily('000001.SZ', start_date='20240101', end_date='20240131')
>>> pro_data = get_pro_bar('000001.SZ', start_date='20240101', end_date='20240131')
>>> stocks = get_stock_basic()
>>> calendar = get_trade_cal('20240101', '20240131')
>>> bak_basic = get_bak_basic(trade_date='20240101')
"""
from src.data.api_wrappers.api_daily import get_daily, sync_daily, preview_daily_sync, DailySync
from src.data.api_wrappers.financial_data.api_income import get_income, sync_income, IncomeSync
from src.data.api_wrappers.api_daily import (
get_daily,
sync_daily,
preview_daily_sync,
DailySync,
)
from src.data.api_wrappers.api_pro_bar import (
get_pro_bar,
sync_pro_bar,
preview_pro_bar_sync,
ProBarSync,
)
from src.data.api_wrappers.financial_data.api_income import (
get_income,
sync_income,
IncomeSync,
)
from src.data.api_wrappers.api_bak_basic import get_bak_basic, sync_bak_basic
from src.data.api_wrappers.api_namechange import get_namechange, sync_namechange
from src.data.api_wrappers.api_stock_basic import get_stock_basic, sync_all_stocks
@@ -38,6 +55,11 @@ __all__ = [
"sync_daily",
"preview_daily_sync",
"DailySync",
# Pro Bar (universal market data)
"get_pro_bar",
"sync_pro_bar",
"preview_pro_bar_sync",
"ProBarSync",
# Income statement
"get_income",
"sync_income",

View File

@@ -345,4 +345,154 @@ df = pro.bak_basic(trade_date='20211012', fields='trade_date,ts_code,name,indust
4530 20211012 688255.SH 凯尔达 机械基件 0.0000
4531 20211012 688211.SH 中科微至 专用机械 0.0000
4532 20211012 605567.SH 春雪食品 食品 0.0000
4533 20211012 605566.SH 福莱蒽特 染料涂料 0.0000
4533 20211012 605566.SH 福莱蒽特 染料涂料 0.0000
通用行情接口
接口名称pro_bar本接口是集成开发接口,部分指标是现用现算
更新时间股票和指数通常在15点17点之间数字货币实时更新具体请参考各接口文档明细。
描述目前整合了股票未复权、前复权、后复权、指数、数字货币、ETF基金、期货、期权的行情数据未来还将整合包括外汇在内的所有交易行情数据同时提供分钟数据。不同数据对应不同的积分要求具体请参阅每类数据的文档说明。
其它由于本接口是集成接口在SDK层做了一些逻辑处理目前暂时没法用http的方式调取通用行情接口。用户可以访问Tushare的Github查看源代码完成类似功能。
输入参数
名称 类型 必选 描述
ts_code str Y 证券代码,不支持多值输入,多值输入获取结果会有重复记录
start_date str N 开始日期 (日线格式YYYYMMDD提取分钟数据请用2019-09-01 09:00:00这种格式)
end_date str N 结束日期 (日线格式YYYYMMDD)
asset str Y 资产类别E股票 I沪深指数 C数字货币 FT期货 FD基金 O期权 CB可转债v1.2.39默认E
adj str N 复权类型(只针对股票)None未复权 qfq前复权 hfq后复权 , 默认None目前只支持日线复权同时复权机制是根据设定的end_date参数动态复权采用分红再投模式具体请参考常见问题列表里的说明。
freq str Y 数据频度 :支持分钟(min)/日(D)/周(W)/月(M)K线其中1min表示1分钟类推1/5/15/30/60分钟 默认D。对于分钟数据有600积分用户可以试用请求2次正式权限可以参考权限列表说明 ,使用方法请参考股票分钟使用方法。
ma list N 均线支持任意合理int数值。注均线是动态计算要设置一定时间范围才能获得相应的均线比如5日均线开始和结束日期参数跨度必须要超过5日。目前只支持单一个股票提取均线即需要输入ts_code参数。e.g: ma_5表示5日均价ma_v_5表示5日均量
factors list N 股票因子asset='E'有效)支持 tor换手率 vr量比
adjfactor str N 复权因子在复权数据时如果此参数为True返回的数据中则带复权因子默认为False。 该功能从1.2.33版本开始生效
输出指标
具体输出的数据指标可参考各行情具体指标:
股票Dailyhttps://tushare.pro/document/2?doc_id=27
内容如下A股日线行情
接口daily可以通过数据工具调试和查看数据
数据说明交易日每天15点16点之间入库。本接口是未复权行情停牌期间不提供数据
调取说明基础积分每分钟内可调取500次每次6000条数据一次请求相当于提取一个股票23年历史
描述:获取股票行情数据,或通过通用行情接口获取数据,包含了前后复权数据
输入参数
名称 类型 必选 描述
ts_code str N 股票代码(支持多个股票同时提取,逗号分隔)
trade_date str N 交易日期YYYYMMDD
start_date str N 开始日期(YYYYMMDD)
end_date str N 结束日期(YYYYMMDD)
日期都填YYYYMMDD格式比如20181010
输出参数
名称 类型 描述
ts_code str 股票代码
trade_date str 交易日期
open float 开盘价
high float 最高价
low float 最低价
close float 收盘价
pre_close float 昨收价【除权价】
change float 涨跌额
pct_chg float 涨跌幅 【基于除权后的昨收计算的涨跌幅:(今收-除权昨收)/除权昨收 】
vol float 成交量 (手)
amount float 成交额 (千元)
接口示例
pro = ts.pro_api()
df = pro.daily(ts_code='000001.SZ', start_date='20180701', end_date='20180718')
#多个股票
df = pro.daily(ts_code='000001.SZ,600000.SH', start_date='20180701', end_date='20180718')
或者
df = pro.query('daily', ts_code='000001.SZ', start_date='20180701', end_date='20180718')
也可以通过日期取历史某一天的全部历史
df = pro.daily(trade_date='20180810')
数据样例
ts_code trade_date open high low close pre_close change pct_chg vol amount
0 000001.SZ 20180718 8.75 8.85 8.69 8.70 8.72 -0.02 -0.23 525152.77 460697.377
1 000001.SZ 20180717 8.74 8.75 8.66 8.72 8.73 -0.01 -0.11 375356.33 326396.994
2 000001.SZ 20180716 8.85 8.90 8.69 8.73 8.88 -0.15 -1.69 689845.58 603427.713
3 000001.SZ 20180713 8.92 8.94 8.82 8.88 8.88 0.00 0.00 603378.21 535401.175
4 000001.SZ 20180712 8.60 8.97 8.58 8.88 8.64 0.24 2.78 1140492.31 1008658.828
5 000001.SZ 20180711 8.76 8.83 8.68 8.78 8.98 -0.20 -2.23 851296.70 744765.824
6 000001.SZ 20180710 9.02 9.02 8.89 8.98 9.03 -0.05 -0.55 896862.02 803038.965
7 000001.SZ 20180709 8.69 9.03 8.68 9.03 8.66 0.37 4.27 1409954.60 1255007.609
8 000001.SZ 20180706 8.61 8.78 8.45 8.66 8.60 0.06 0.70 988282.69 852071.526
9 000001.SZ 20180705 8.62 8.73 8.55 8.60 8.61 -0.01 -0.12 835768.77 722169.579
基金Dailyhttps://tushare.pro/document/2?doc_id=127
期货Dailyhttps://tushare.pro/document/2?doc_id=138
期权Dailyhttps://tushare.pro/document/2?doc_id=159
指数Dailyhttps://tushare.pro/document/2?doc_id=95
接口用例
#取000001的前复权行情
df = ts.pro_bar(ts_code='000001.SZ', adj='qfq', start_date='20180101', end_date='20181011')
ts_code trade_date open high low close \
trade_date
20181011 000001.SZ 20181011 1085.71 1097.59 1047.90 1065.19
20181010 000001.SZ 20181010 1138.65 1151.61 1121.36 1128.92
20181009 000001.SZ 20181009 1130.00 1155.93 1122.44 1140.81
20181008 000001.SZ 20181008 1155.93 1165.65 1128.92 1128.92
20180928 000001.SZ 20180928 1164.57 1217.51 1164.57 1193.74
#取上证指数行情数据
df = ts.pro_bar(ts_code='000001.SH', asset='I', start_date='20180101', end_date='20181011')
In [10]: df.head()
Out[10]:
ts_code trade_date close open high low \
0 000001.SH 20181011 2583.4575 2643.0740 2661.2859 2560.3164
1 000001.SH 20181010 2725.8367 2723.7242 2743.5480 2703.0626
2 000001.SH 20181009 2721.0130 2713.7319 2734.3142 2711.1971
3 000001.SH 20181008 2716.5104 2768.2075 2771.9384 2710.1781
4 000001.SH 20180928 2821.3501 2794.2644 2821.7553 2791.8363
pre_close change pct_chg vol amount
0 2725.8367 -142.3792 -5.2233 197150702.0 170057762.5
1 2721.0130 4.8237 0.1773 113485736.0 111312455.3
2 2716.5104 4.5026 0.1657 116771899.0 110292457.8
3 2821.3501 -104.8397 -3.7159 149501388.0 141531551.8
4 2791.7748 29.5753 1.0594 134290456.0 125369989.4
#均线
df = ts.pro_bar(ts_code='000001.SZ', start_date='20180101', end_date='20181011', ma=[5, 20, 50])
Tushare pro_bar接口的均价和均量数据是动态计算想要获取某个时间段的均线必须要设置start_date日期大于最大均线的日期数然后自行截取想要日期段。例如想要获取20190801开始的3日均线必须设置start_date='20190729'然后剔除20190801之前的日期记录。
#换手率tor量比vr
df = ts.pro_bar(ts_code='000001.SZ', start_date='20180101', end_date='20181011', factors=['tor', 'vr'])
说明
对于pro_api参数如果在一开始就通过 ts.set_token('xxxx') 设置过token的情况这个参数就不是必需的。
例如:
df = ts.pro_bar(ts_code='000001.SH', asset='I', start_date='20180101', end_date='20181011')

View File

@@ -129,7 +129,9 @@ def sync_bak_basic(
columns = []
for col in sample.columns:
dtype = str(sample[col].dtype)
if "int" in dtype:
if col == "trade_date":
col_type = "DATE"
elif "int" in dtype:
col_type = "INTEGER"
elif "float" in dtype:
col_type = "DOUBLE"
@@ -223,10 +225,16 @@ def sync_bak_basic(
# Combine and save
combined = pd.concat(all_data, ignore_index=True)
# Convert trade_date to datetime for proper DATE type storage
combined["trade_date"] = pd.to_datetime(combined["trade_date"], format="%Y%m%d")
print(f"[sync_bak_basic] Total records: {len(combined)}")
# Delete existing data for the date range and append new data
storage._connection.execute(f'DELETE FROM "{TABLE_NAME}" WHERE "trade_date" >= ?', [sync_start])
# Convert sync_start to date format for comparison with DATE column
sync_start_date = pd.to_datetime(sync_start, format="%Y%m%d").date()
storage._connection.execute(f'DELETE FROM "{TABLE_NAME}" WHERE "trade_date" >= ?', [sync_start_date])
thread_storage.queue_save(TABLE_NAME, combined)
thread_storage.flush()

View File

@@ -17,6 +17,7 @@ import threading
from src.data.client import TushareClient
from src.data.storage import ThreadSafeStorage, Storage
from src.data.utils import get_today_date, get_next_date, DEFAULT_START_DATE
from src.config.settings import get_settings
from src.data.api_wrappers.api_trade_cal import (
get_first_trading_day,
get_last_trading_day,
@@ -105,16 +106,15 @@ class DailySync:
- 预览模式(预览同步数据量,不实际写入)
"""
# 默认工作线程数
DEFAULT_MAX_WORKERS = 10
# 默认工作线程数从配置读取默认10
DEFAULT_MAX_WORKERS = get_settings().threads
def __init__(self, max_workers: Optional[int] = None):
"""初始化同步管理器。
Args:
max_workers: 工作线程数(默认: 10
max_workers: 工作线程数(默认从配置读取
"""
self.storage = ThreadSafeStorage()
self.client = TushareClient()
self.max_workers = max_workers or self.DEFAULT_MAX_WORKERS
self._stop_flag = threading.Event()

View File

@@ -8,13 +8,13 @@ import pandas as pd
from pathlib import Path
from typing import Optional, List
from src.data.client import TushareClient
from src.data.config import get_config
from src.config.settings import get_settings
# CSV file path for namechange data
def _get_csv_path() -> Path:
"""Get the CSV file path for namechange data."""
cfg = get_config()
cfg = get_settings()
return cfg.data_path_resolved / "namechange.csv"

View File

@@ -9,13 +9,13 @@ import pandas as pd
from pathlib import Path
from typing import Optional, Literal, List
from src.data.client import TushareClient
from src.data.config import get_config
from src.config.settings import get_settings
# CSV file path for stock basic data
def _get_csv_path() -> Path:
"""Get the CSV file path for stock basic data."""
cfg = get_config()
cfg = get_settings()
return cfg.data_path_resolved / "stock_basic.csv"

View File

@@ -8,7 +8,7 @@ import pandas as pd
from typing import Optional, Literal
from pathlib import Path
from src.data.client import TushareClient
from src.data.config import get_config
from src.config.settings import get_settings
# Module-level flag to track if cache has been synced in this session
_cache_synced = False
@@ -18,7 +18,7 @@ _cache_synced = False
# Trading calendar cache file path
def _get_cache_path() -> Path:
"""Get the cache file path for trade calendar."""
cfg = get_config()
cfg = get_settings()
return cfg.data_path_resolved / "trade_cal.h5"
@@ -296,8 +296,8 @@ def get_first_trading_day(
trading_days = get_trading_days(start_date, end_date, exchange)
if not trading_days:
return None
# Trading days are sorted in descending order (newest first) from cache
return trading_days[-1]
# Return the earliest trading day
return min(trading_days)
def get_last_trading_day(
@@ -318,8 +318,8 @@ def get_last_trading_day(
trading_days = get_trading_days(start_date, end_date, exchange)
if not trading_days:
return None
# Trading days are sorted in descending order (newest first) from cache
return trading_days[0]
# Return the latest trading day
return max(trading_days)
if __name__ == "__main__":

View File

@@ -1,21 +1,25 @@
"""Simplified Tushare client with rate limiting and retry logic."""
import time
import pandas as pd
from typing import Optional
from src.data.config import get_config
from src.data.rate_limiter import TokenBucketRateLimiter
from src.config.settings import get_settings
class TushareClient:
"""Tushare API client with rate limiting and retry."""
# 类级别共享限流器(确保所有实例共享同一个限流器)
_shared_limiter: Optional[TokenBucketRateLimiter] = None
def __init__(self, token: Optional[str] = None):
"""Initialize client.
Args:
token: Tushare API token (auto-loaded from config if not provided)
"""
cfg = get_config()
cfg = get_settings()
token = token or cfg.tushare_token
if not token:
@@ -24,12 +28,21 @@ class TushareClient:
self.token = token
self.config = cfg
# Initialize rate limiter: capacity = rate_limit, refill_rate = rate_limit/60 per second
# 初始化共享限流器(确保所有 TushareClient 实例共享同一个限流器)
rate_per_second = cfg.rate_limit / 60.0
self.rate_limiter = TokenBucketRateLimiter(
capacity=cfg.rate_limit,
refill_rate_per_second=rate_per_second,
)
capacity = cfg.rate_limit
if TushareClient._shared_limiter is None:
# 首次创建:初始化共享限流器
TushareClient._shared_limiter = TokenBucketRateLimiter(
capacity=capacity,
refill_rate_per_second=rate_per_second,
)
print(
f"[TushareClient] Initialized shared rate limiter: capacity={capacity}, window=60s"
)
# 复用共享限流器
self.rate_limiter = TushareClient._shared_limiter
self._api = None
@@ -37,6 +50,7 @@ class TushareClient:
"""Get Tushare API instance."""
if self._api is None:
import tushare as ts
self._api = ts.pro_api(self.token)
return self._api
@@ -52,7 +66,7 @@ class TushareClient:
DataFrame with query results
"""
# Acquire rate limit token (None = wait indefinitely)
timeout = timeout if timeout is not None else float('inf')
timeout = timeout if timeout is not None else float("inf")
success, wait_time = self.rate_limiter.acquire(timeout=timeout)
if not success:
@@ -72,14 +86,21 @@ class TushareClient:
# pro_bar uses ts.pro_bar() instead of api.query()
if api_name == "pro_bar":
# pro_bar parameters: ts_code, start_date, end_date, adj, freq, factors, ma, adjfactor
data = ts.pro_bar(ts_code=params.get("ts_code"),
start_date=params.get("start_date"),
end_date=params.get("end_date"),
adj=params.get("adj"),
freq=params.get("freq", "D"),
factors=params.get("factors"), # factors should be a list like ['tor', 'vr']
ma=params.get("ma"),
adjfactor=params.get("adjfactor"))
data = ts.pro_bar(
ts_code=params.get("ts_code"),
start_date=params.get("start_date"),
end_date=params.get("end_date"),
adj=params.get("adj"),
freq=params.get("freq", "D"),
factors=params.get(
"factors"
), # factors should be a list like ['tor', 'vr']
ma=params.get("ma"),
adjfactor=params.get("adjfactor"),
)
# Handle None response (e.g., delisted stock)
if data is None:
data = pd.DataFrame()
else:
api = self._get_api()
data = api.query(api_name, **params)
@@ -89,10 +110,14 @@ class TushareClient:
except Exception as e:
if attempt < max_retries - 1:
delay = retry_delays[attempt]
print(f"[Retry] {api_name} failed (attempt {attempt + 1}): {e}, retry in {delay}s")
print(
f"[Retry] {api_name} failed (attempt {attempt + 1}): {e}, retry in {delay}s"
)
time.sleep(delay)
else:
raise RuntimeError(f"API call failed after {max_retries} attempts: {e}")
raise RuntimeError(
f"API call failed after {max_retries} attempts: {e}"
)
return pd.DataFrame()

View File

@@ -1,80 +0,0 @@
"""Configuration management for data collection module."""
import os
from pathlib import Path
from pydantic_settings import BaseSettings
# Config directory path - used for loading .env.local
# Static detection for pydantic-settings to find .env.local
CONFIG_DIR = Path(__file__).parent.parent.parent / "config"
def _get_project_root() -> Path:
"""Get project root path from ROOT_PATH env var or auto-detect."""
# Try to read from environment variable first
root_path = os.environ.get("ROOT_PATH") or os.environ.get("DATA_ROOT")
if root_path:
return Path(root_path)
# Fallback to auto-detection
return Path(__file__).parent.parent.parent
class Config(BaseSettings):
    """Application configuration loaded from environment variables.

    pydantic-settings maps each field to its upper-cased environment
    variable (e.g. ``tushare_token`` -> ``TUSHARE_TOKEN``); values may
    also come from ``config/.env.local``.
    """

    # Tushare API token (env: TUSHARE_TOKEN)
    tushare_token: str = ""

    # Root path - loaded from environment variable ROOT_PATH
    # If not set, uses auto-detected path
    root_path: str = ""

    # Data storage path - can be set via DATA_PATH environment variable
    # If relative path, it will be resolved relative to root_path
    data_path: str = "data"

    # Rate limit: requests per minute
    rate_limit: int = 100

    # Thread pool size
    threads: int = 2

    @property
    def project_root(self) -> Path:
        """Get project root path.

        Prefers the explicitly configured ``root_path``; otherwise falls
        back to auto-detection via ``_get_project_root()``.
        """
        if self.root_path:
            return Path(self.root_path)
        return _get_project_root()

    @property
    def data_path_resolved(self) -> Path:
        """Get resolved data path (absolute).

        A relative ``data_path`` is resolved against ``project_root``.
        """
        path = Path(self.data_path)
        if path.is_absolute():
            return path
        return self.project_root / path

    class Config:
        # Read the .env.local file from the config/ directory.
        env_file = str(CONFIG_DIR / ".env.local")
        env_file_encoding = "utf-8"
        case_sensitive = False
        extra = "ignore"  # ignore extra variables found in .env.local
        # pydantic-settings upper-cases field names for env lookup by default,
        # so tushare_token maps to TUSHARE_TOKEN,
        # root_path maps to ROOT_PATH,
        # and data_path maps to DATA_PATH.
# Global config instance
config = Config()
def get_config() -> Config:
    """Return the module-level singleton ``Config`` instance."""
    return config
def get_project_root() -> Path:
    """Get project root path (convenience wrapper over the config singleton)."""
    return get_config().project_root

View File

@@ -32,9 +32,12 @@ def get_db_info(db_path: Optional[Path] = None):
# Get database path
if db_path is None:
from src.data.config import get_config
from src.config.settings import get_settings
cfg = get_config()
cfg = get_settings()
db_path = cfg.data_path_resolved / "prostock.db"
cfg = get_settings()
db_path = cfg.data_path_resolved / "prostock.db"
else:
db_path = Path(db_path)
@@ -231,9 +234,12 @@ def get_table_sample(table_name: str, limit: int = 5, db_path: Optional[Path] =
db_path: Path to database file
"""
if db_path is None:
from src.data.config import get_config
from src.config.settings import get_settings
cfg = get_config()
cfg = get_settings()
db_path = cfg.data_path_resolved / "prostock.db"
cfg = get_settings()
db_path = cfg.data_path_resolved / "prostock.db"
else:
db_path = Path(db_path)

View File

@@ -1,35 +1,35 @@
"""Token bucket rate limiter implementation.
"""API 速率限制器实现。
This module provides a thread-safe token bucket algorithm for rate limiting.
提供基于固定时间窗口的速率限制,适合 Tushare 等按分钟计费的 API。
"""
import time
import threading
from typing import Optional
from dataclasses import dataclass, field
from dataclasses import dataclass
@dataclass
class RateLimiterStats:
"""Statistics for rate limiter."""
"""速率限制器统计信息。"""
total_requests: int = 0
successful_requests: int = 0
denied_requests: int = 0
total_wait_time: float = 0.0
current_tokens: Optional[float] = None
current_window_requests: int = 0
window_start_time: float = 0.0
class TokenBucketRateLimiter:
"""Thread-safe token bucket rate limiter.
"""基于固定时间窗口的速率限制器。
Implements a token bucket algorithm for controlling request rate.
Tokens are added at a fixed rate up to the bucket capacity.
适合 Tushare 等按时间窗口(如每分钟)限制请求数的 API 场景。
在窗口期内,请求数达到上限后将阻塞或等待下一个窗口。
Attributes:
capacity: Maximum number of tokens in the bucket
refill_rate: Number of tokens added per second
initial_tokens: Initial number of tokens (default: capacity)
capacity: 每个时间窗口内允许的最大请求数
window_seconds: 时间窗口长度(秒)
"""
def __init__(
@@ -38,155 +38,157 @@ class TokenBucketRateLimiter:
refill_rate_per_second: float = 1.67,
initial_tokens: Optional[int] = None,
) -> None:
"""Initialize the token bucket rate limiter.
"""初始化速率限制器。
Args:
capacity: Maximum token capacity
refill_rate_per_second: Token refill rate per second
initial_tokens: Initial token count (default: capacity)
capacity: 每个时间窗口内允许的最大请求数
refill_rate_per_second: 保留参数(向后兼容),实际使用 window_seconds=60
initial_tokens: 保留参数(向后兼容)
"""
self.capacity = capacity
self.refill_rate = refill_rate_per_second
self.tokens = float(initial_tokens if initial_tokens is not None else capacity)
self.last_refill_time = time.monotonic()
# Tushare 通常按分钟限制,所以固定使用 60 秒窗口
self.window_seconds = 60.0
self._requests_in_window = 0
self._window_start = time.monotonic()
self._lock = threading.RLock()
self._stats = RateLimiterStats()
self._stats.current_tokens = self.tokens
self._stats.window_start_time = self._window_start
    def _is_new_window(self) -> bool:
        """Return True once the current fixed time window has fully elapsed."""
        current_time = time.monotonic()
        elapsed = current_time - self._window_start
        return elapsed >= self.window_seconds
    def _reset_window(self) -> None:
        """Start a fresh time window and clear the in-window request count."""
        self._window_start = time.monotonic()
        self._requests_in_window = 0
        # Keep the stats snapshot in sync with the new window start.
        self._stats.window_start_time = self._window_start
def acquire(self, timeout: float = float("inf")) -> tuple[bool, float]:
"""Acquire a token from the bucket.
"""获取请求许可。
Blocks until a token is available or timeout expires.
如果在当前窗口内请求数已达上限,则等待到下一个窗口。
Args:
timeout: Maximum time to wait for a token in seconds (default: inf)
timeout: 最大等待时间(秒),默认无限等待
Returns:
Tuple of (success, wait_time):
- success: True if token was acquired, False if timed out
- wait_time: Time spent waiting for token
(success, wait_time): 是否成功获取许可,以及等待时间
"""
start_time = time.monotonic()
wait_time = 0.0
with self._lock:
self._refill()
# 检查是否需要进入新窗口
if self._is_new_window():
self._reset_window()
if self.tokens >= 1:
self.tokens -= 1
# 如果当前窗口还有余量,直接通过
if self._requests_in_window < self.capacity:
self._requests_in_window += 1
self._stats.total_requests += 1
self._stats.successful_requests += 1
self._stats.current_tokens = self.tokens
self._stats.current_window_requests = self._requests_in_window
return True, 0.0
# Calculate time to wait for next token
tokens_needed = 1 - self.tokens
time_to_refill = tokens_needed / self.refill_rate
# 当前窗口已满,计算需要等待的时间
current_time = time.monotonic()
time_to_next_window = self.window_seconds - (
current_time - self._window_start
)
# Check if we can wait for the token within timeout
# Handle infinite timeout specially
is_infinite_timeout = timeout == float("inf")
if not is_infinite_timeout and time_to_refill > timeout:
if time_to_next_window <= 0:
# 刚好进入新窗口
self._reset_window()
self._requests_in_window = 1
self._stats.total_requests += 1
self._stats.successful_requests += 1
self._stats.current_window_requests = 1
return True, 0.0
# 检查是否能在超时时间内等待
if timeout != float("inf") and time_to_next_window > timeout:
self._stats.total_requests += 1
self._stats.denied_requests += 1
return False, timeout
# Wait for tokens - loop until we get one or timeout
while True:
# Calculate remaining time we can wait
elapsed = time.monotonic() - start_time
remaining_timeout = (
timeout - elapsed if not is_infinite_timeout else float("inf")
)
# 需要等待到下一个窗口
if timeout != float("inf"):
time_to_wait = min(time_to_next_window, timeout)
else:
time_to_wait = time_to_next_window
# Check if we've exceeded timeout
if not is_infinite_timeout and remaining_timeout <= 0:
self._stats.total_requests += 1
self._stats.denied_requests += 1
return False, elapsed
time.sleep(time_to_wait)
# Calculate wait time for next token
tokens_needed = max(0, 1 - self.tokens)
time_to_wait = (
tokens_needed / self.refill_rate if tokens_needed > 0 else 0.1
)
# If we can't wait long enough, fail
if not is_infinite_timeout and time_to_wait > remaining_timeout:
self._stats.total_requests += 1
self._stats.denied_requests += 1
return False, elapsed
# Wait outside the lock to allow other threads to refill
self._lock.release()
time.sleep(
min(time_to_wait, 0.1)
) # Cap wait to 100ms to check frequently
self._lock.acquire()
# Refill and check again
self._refill()
if self.tokens >= 1:
self.tokens -= 1
wait_time = time.monotonic() - start_time
self._stats.total_requests += 1
self._stats.successful_requests += 1
self._stats.total_wait_time += wait_time
self._stats.current_tokens = self.tokens
return True, wait_time
def acquire_nonblocking(self) -> tuple[bool, float]:
"""Try to acquire a token without blocking.
Returns:
Tuple of (success, wait_time):
- success: True if token was acquired, False otherwise
- wait_time: 0 for non-blocking, or required wait time if failed
"""
# 重新尝试获取许可
with self._lock:
self._refill()
# 再次检查窗口状态(可能其他线程已经重置了窗口)
if self._is_new_window():
self._reset_window()
if self.tokens >= 1:
self.tokens -= 1
if self._requests_in_window < self.capacity:
self._requests_in_window += 1
wait_time = time.monotonic() - start_time
self._stats.total_requests += 1
self._stats.successful_requests += 1
self._stats.current_tokens = self.tokens
self._stats.total_wait_time += wait_time
self._stats.current_window_requests = self._requests_in_window
return True, wait_time
else:
# 在极端情况下,等待后仍然无法获取(其他线程抢先)
wait_time = time.monotonic() - start_time
self._stats.total_requests += 1
self._stats.denied_requests += 1
return False, wait_time
def acquire_nonblocking(self) -> tuple[bool, float]:
"""尝试非阻塞地获取请求许可。
Returns:
(success, wait_time): 是否成功获取许可,以及需要等待的时间
"""
with self._lock:
# 检查是否需要进入新窗口
if self._is_new_window():
self._reset_window()
# 如果当前窗口还有余量,直接通过
if self._requests_in_window < self.capacity:
self._requests_in_window += 1
self._stats.total_requests += 1
self._stats.successful_requests += 1
self._stats.current_window_requests = self._requests_in_window
return True, 0.0
# Calculate time needed
tokens_needed = 1 - self.tokens
time_to_refill = tokens_needed / self.refill_rate
# 当前窗口已满,计算需要等待的时间
current_time = time.monotonic()
time_to_next_window = self.window_seconds - (
current_time - self._window_start
)
self._stats.total_requests += 1
self._stats.denied_requests += 1
return False, time_to_refill
def _refill(self) -> None:
"""Refill tokens based on elapsed time."""
current_time = time.monotonic()
elapsed = current_time - self.last_refill_time
self.last_refill_time = current_time
tokens_to_add = elapsed * self.refill_rate
self.tokens = min(self.capacity, self.tokens + tokens_to_add)
return False, max(0.0, time_to_next_window)
def get_available_tokens(self) -> float:
"""Get the current number of available tokens.
"""获取当前窗口剩余可用请求数。
Returns:
Current token count
当前窗口剩余可用请求数
"""
with self._lock:
self._refill()
return self.tokens
if self._is_new_window():
return float(self.capacity)
return float(self.capacity - self._requests_in_window)
def get_stats(self) -> RateLimiterStats:
"""Get rate limiter statistics.
"""获取速率限制器统计信息。
Returns:
RateLimiterStats instance
RateLimiterStats 实例
"""
with self._lock:
self._refill()
self._stats.current_tokens = self.tokens
self._stats.current_window_requests = self._requests_in_window
return self._stats

View File

@@ -6,7 +6,7 @@ from pathlib import Path
from typing import Optional, List, Dict, Any, Tuple
from collections import defaultdict
from datetime import datetime
from src.data.config import get_config
from src.config.settings import get_settings
# Default column type mapping for automatic schema inference
@@ -53,7 +53,7 @@ class Storage:
if hasattr(self, "_initialized"):
return
cfg = get_config()
cfg = get_settings()
self.base_path = path or cfg.data_path_resolved
self.base_path.mkdir(parents=True, exist_ok=True)
self.db_path = self.base_path / "prostock.db"
@@ -190,6 +190,26 @@ class Storage:
update_flag VARCHAR(1),
PRIMARY KEY (ts_code, end_date)
)
# Create pro_bar table for pro bar data (with adj, tor, vr)
self._connection.execute("""
CREATE TABLE IF NOT EXISTS pro_bar (
ts_code VARCHAR(16) NOT NULL,
trade_date DATE NOT NULL,
open DOUBLE,
high DOUBLE,
low DOUBLE,
close DOUBLE,
pre_close DOUBLE,
change DOUBLE,
pct_chg DOUBLE,
vol DOUBLE,
amount DOUBLE,
tor DOUBLE,
vr DOUBLE,
adj_factor DOUBLE,
PRIMARY KEY (ts_code, trade_date)
)
""")
# Create index for financial_income

View File

@@ -29,6 +29,7 @@ import pandas as pd
from src.data.api_wrappers import sync_all_stocks
from src.data.api_wrappers.api_daily import sync_daily, preview_daily_sync
from src.data.api_wrappers.api_pro_bar import sync_pro_bar
def preview_sync(
@@ -134,7 +135,6 @@ def sync_all_data(
dry_run: bool = False,
) -> Dict[str, pd.DataFrame]:
"""同步所有数据类型(每日同步)。
该函数按顺序同步所有可用的数据类型:
1. 交易日历 (sync_trade_cal_cache)
2. 股票基本信息 (sync_all_stocks)
@@ -146,13 +146,12 @@ def sync_all_data(
Args:
force_full: 若为 True强制所有数据类型完整重载
max_workers: 日线数据同步的工作线程数(默认: 10
dry_run: 若为 True仅显示将要同步的内容 Returns:
映射数据类型,不写入数据
dry_run: 若为 True仅显示将要同步的内容,不写入数据
到同步结果的字典
Returns:
映射数据类型到同步结果的字典
Example:
>>> # 同步所有数据(增量)
>>> result = sync_all_data()
>>>
>>> # 强制完整重载
@@ -167,6 +166,92 @@ def sync_all_data(
print("[sync_all_data] Starting full data synchronization...")
print("=" * 60)
# 1. Sync trade calendar (always needed first)
print("\n[1/6] Syncing trade calendar cache...")
try:
from src.data.api_wrappers import sync_trade_cal_cache
sync_trade_cal_cache()
results["trade_cal"] = pd.DataFrame()
print("[1/6] Trade calendar: OK")
except Exception as e:
print(f"[1/6] Trade calendar: FAILED - {e}")
results["trade_cal"] = pd.DataFrame()
# 2. Sync stock basic info
print("\n[2/6] Syncing stock basic info...")
try:
sync_all_stocks()
results["stock_basic"] = pd.DataFrame()
print("[2/6] Stock basic: OK")
except Exception as e:
print(f"[2/6] Stock basic: FAILED - {e}")
results["stock_basic"] = pd.DataFrame()
# # 3. Sync daily market data
# print("\n[3/6] Syncing daily market data...")
# try:
# daily_result = sync_daily(
# force_full=force_full,
# max_workers=max_workers,
# dry_run=dry_run,
# )
# results["daily"] = (
# pd.concat(daily_result.values(), ignore_index=True)
# if daily_result
# else pd.DataFrame()
# )
# print("[3/6] Daily data: OK")
# except Exception as e:
# print(f"[3/6] Daily data: FAILED - {e}")
# results["daily"] = pd.DataFrame()
# 4. Sync Pro Bar data
print("\n[4/6] Syncing Pro Bar data (with adj, tor, vr)...")
try:
pro_bar_result = sync_pro_bar(
force_full=force_full,
max_workers=max_workers,
dry_run=dry_run,
)
results["pro_bar"] = (
pd.concat(pro_bar_result.values(), ignore_index=True)
if pro_bar_result
else pd.DataFrame()
)
print(f"[4/6] Pro Bar data: OK ({len(results['pro_bar'])} records)")
except Exception as e:
print(f"[4/6] Pro Bar data: FAILED - {e}")
results["pro_bar"] = pd.DataFrame()
# 5. Sync stock historical list (bak_basic)
print("\n[5/6] Syncing stock historical list (bak_basic)...")
try:
bak_basic_result = sync_bak_basic(force_full=force_full)
results["bak_basic"] = bak_basic_result
print(f"[5/6] Bak basic: OK ({len(bak_basic_result)} records)")
except Exception as e:
print(f"[5/6] Bak basic: FAILED - {e}")
results["bak_basic"] = pd.DataFrame()
# Summary
print("\n" + "=" * 60)
print("[sync_all_data] Sync Summary")
print("=" * 60)
for data_type, df in results.items():
print(f" {data_type}: {len(df)} records")
print("=" * 60)
print("\nNote: namechange is NOT in auto-sync. To sync manually:")
print(" from src.data.api_wrappers import sync_namechange")
print(" sync_namechange(force=True)")
return results
results: Dict[str, pd.DataFrame] = {}
print("\n" + "=" * 60)
print("[sync_all_data] Starting full data synchronization...")
print("=" * 60)
# 1. Sync trade calendar (always needed first)
print("\n[1/5] Syncing trade calendar cache...")
try:

File diff suppressed because it is too large Load Diff

View File

@@ -1,118 +0,0 @@
"""ProStock 因子计算框架
因子框架提供以下核心功能:
1. 类型安全的因子定义(截面因子、时序因子)
2. 数据泄露防护机制
3. 因子组合和运算
4. 高效的数据加载和计算引擎
基础数据类型Phase 1
- DataSpec: 数据需求规格
- FactorContext: 计算上下文
- FactorData: 数据容器
因子基类Phase 2
- BaseFactor: 抽象基类
- CrossSectionalFactor: 日期截面因子基类
- TimeSeriesFactor: 时间序列因子基类
- CompositeFactor: 组合因子
- ScalarFactor: 标量运算因子
因子分类目录:
- momentum/: 动量因子MA、收益率排名等
- financial/: 财务因子EPS、ROE等
- valuation/: 估值因子PE、PB、PS等
- technical/: 技术指标因子RSI、MACD、布林带等
- quality/: 质量因子(盈利能力、稳定性等)
- sentiment/: 情绪因子(换手率、资金流向等)
- volume/: 成交量因子OBV、成交量比率等
- volatility/: 波动率因子历史波动率、GARCH等
数据加载和执行Phase 3-4
- DataLoader: 数据加载器
- FactorEngine: 因子执行引擎
使用示例:
# 使用通用因子(参数化)
from src.factors import MovingAverageFactor, ReturnRankFactor
from src.factors import DataLoader, FactorEngine
ma5 = MovingAverageFactor(period=5) # 5日MA
ma10 = MovingAverageFactor(period=10) # 10日MA
ret5 = ReturnRankFactor(period=5) # 5日收益率排名
loader = DataLoader(data_dir="data")
engine = FactorEngine(loader)
result = engine.compute(ma5, stock_codes=["000001.SZ"], start_date="20240101", end_date="20240131")
"""
因子框架提供以下核心功能
1. 类型安全的因子定义截面因子时序因子
2. 数据泄露防护机制
3. 因子组合和运算
4. 高效的数据加载和计算引擎
基础数据类型Phase 1
- DataSpec: 数据需求规格
- FactorContext: 计算上下文
- FactorData: 数据容器
因子基类Phase 2
- BaseFactor: 抽象基类
- CrossSectionalFactor: 日期截面因子基类
- TimeSeriesFactor: 时间序列因子基类
- CompositeFactor: 组合因子
- ScalarFactor: 标量运算因子
动量因子momentum/
- MovingAverageFactor: 移动平均线时序因子
- ReturnRankFactor: 收益率排名截面因子
财务因子financial/
- 待添加
数据加载和执行Phase 3-4
- DataLoader: 数据加载器
- FactorEngine: 因子执行引擎
使用示例
# 使用通用因子(参数化)
from src.factors import MovingAverageFactor, ReturnRankFactor
from src.factors import DataLoader, FactorEngine
ma5 = MovingAverageFactor(period=5) # 5日MA
ma10 = MovingAverageFactor(period=10) # 10日MA
ret5 = ReturnRankFactor(period=5) # 5日收益率排名
loader = DataLoader(data_dir="data")
engine = FactorEngine(loader)
result = engine.compute(ma5, stock_codes=["000001.SZ"], start_date="20240101", end_date="20240131")
"""
from src.factors.data_spec import DataSpec, FactorContext, FactorData
from src.factors.base import BaseFactor, CrossSectionalFactor, TimeSeriesFactor
from src.factors.composite import CompositeFactor, ScalarFactor
from src.factors.data_loader import DataLoader
from src.factors.engine import FactorEngine
# 动量因子
from src.factors.momentum import MovingAverageFactor, ReturnRankFactor
__all__ = [
# Phase 1: 数据类型定义
"DataSpec",
"FactorContext",
"FactorData",
# Phase 2: 因子基类
"BaseFactor",
"CrossSectionalFactor",
"TimeSeriesFactor",
"CompositeFactor",
"ScalarFactor",
# Phase 3-4: 数据加载和执行引擎
"DataLoader",
"FactorEngine",
# 动量因子
"MovingAverageFactor",
"ReturnRankFactor",
]

View File

@@ -1,274 +0,0 @@
"""因子基类 - Phase 2 核心抽象类
本模块定义了因子框架的基类:
- BaseFactor: 抽象基类,定义通用接口和验证逻辑
- CrossSectionalFactor: 日期截面因子基类(防止日期泄露)
- TimeSeriesFactor: 时间序列因子基类(防止股票泄露)
"""
from abc import ABC, abstractmethod
from dataclasses import field
from typing import List
import polars as pl
from src.factors.data_spec import DataSpec, FactorData
class BaseFactor(ABC):
    """Abstract base class defining the common factor interface.

    Every concrete factor must declare these class attributes:
        name: Unique factor identifier (snake_case).
        factor_type: Either "cross_sectional" or "time_series".
        data_specs: Non-empty list of DataSpec describing data needs.

    Optional class attributes:
        category: Factor category (defaults to "default").
        description: Human-readable description.

    Example:
        >>> class MyFactor(CrossSectionalFactor):
        ...     name = "my_factor"
        ...     data_specs = [DataSpec("daily", ["close"], lookback_days=5)]
        ...
        ...     def compute(self, data: FactorData) -> pl.Series:
        ...         return data.get_column("close").rank()
    """

    # Required class attributes (validated in __init_subclass__).
    name: str = ""
    factor_type: str = ""  # "cross_sectional" | "time_series"
    # NOTE(review): field(default_factory=list) outside a @dataclass stores a
    # Field object rather than a list; harmless here only because every valid
    # subclass must override data_specs — a plain [] would be clearer.
    data_specs: List[DataSpec] = field(default_factory=list)
    # Optional class attributes.
    category: str = "default"
    description: str = ""

    def __init_subclass__(cls, **kwargs):
        """Validate required attributes at subclass-creation time.

        Checks:
            1. ``name`` is a non-empty string defined directly on the class.
            2. ``factor_type`` is "cross_sectional" or "time_series".
            3. ``data_specs`` is a non-empty list defined on the class.

        Raises:
            ValueError: If any of the declarations is missing or invalid.
        """
        super().__init_subclass__(**kwargs)
        # Skip validation for the abstract intermediates and the special
        # combinator classes, which are configured per instance instead.
        if cls.__name__ in (
            "CrossSectionalFactor",
            "TimeSeriesFactor",
            "CompositeFactor",
            "ScalarFactor",
        ):
            return
        # name: must be defined directly on the class (not inherited).
        if "name" not in cls.__dict__ or not cls.name:
            raise ValueError(f"Factor {cls.__name__} must define 'name'")
        if not isinstance(cls.name, str):
            raise ValueError(f"Factor {cls.__name__}.name must be a string")
        # factor_type: must be set (inheriting it is allowed).
        if not cls.factor_type:
            raise ValueError(f"Factor {cls.__name__} must define 'factor_type'")
        if cls.factor_type not in ("cross_sectional", "time_series"):
            raise ValueError(
                f"Factor {cls.__name__}.factor_type must be 'cross_sectional' "
                f"or 'time_series', got '{cls.factor_type}'"
            )
        # data_specs case 1: not defined at all (inherited empty default).
        if "data_specs" not in cls.__dict__:
            raise ValueError(f"Factor {cls.__name__} must define 'data_specs'")
        # data_specs case 2: defined but empty.
        if not cls.data_specs or len(cls.data_specs) == 0:
            raise ValueError(f"Factor {cls.__name__}.data_specs cannot be empty")
        if not isinstance(cls.data_specs, list):
            raise ValueError(f"Factor {cls.__name__}.data_specs must be a list")

    def __init__(self, **params):
        """Store factor parameters.

        Subclasses take parameterized configuration here, e.g. MA(period=20).
        Note: data_specs must be declared at class level; it is validated in
        __init_subclass__ when the class is created, not per instance.

        Args:
            **params: Arbitrary factor parameters, kept in ``self.params``.
        """
        self.params = params

    def _validate_params(self):
        """Hook for subclass parameter validation.

        The base implementation is a no-op. Subclasses that need custom
        checks override this and call it from their own __init__.
        Since data_specs is validated at class-creation time, it should
        not be mutated per instance; use parameterization instead.
        """
        pass

    @abstractmethod
    def compute(self, data: FactorData) -> pl.Series:
        """Core computation — subclasses must implement.

        Args:
            data: Safe data container, already trimmed per factor type.

        Returns:
            The computed factor values as a Series.
        """
        pass

    # ========== Factor composition operators ==========
    def __add__(self, other: "BaseFactor") -> "CompositeFactor":
        """f1 + f2 (both operands must share the same factor_type)."""
        from src.factors.composite import CompositeFactor

        return CompositeFactor(self, other, "+")

    def __sub__(self, other: "BaseFactor") -> "CompositeFactor":
        """f1 - f2 (both operands must share the same factor_type)."""
        from src.factors.composite import CompositeFactor

        return CompositeFactor(self, other, "-")

    def __mul__(self, other):
        """f1 * f2 or f1 * scalar; returns NotImplemented otherwise."""
        if isinstance(other, (int, float)):
            from src.factors.composite import ScalarFactor

            return ScalarFactor(self, float(other), "*")
        elif isinstance(other, BaseFactor):
            from src.factors.composite import CompositeFactor

            return CompositeFactor(self, other, "*")
        return NotImplemented

    def __truediv__(self, other: "BaseFactor") -> "CompositeFactor":
        """f1 / f2 (both operands must share the same factor_type)."""
        from src.factors.composite import CompositeFactor

        return CompositeFactor(self, other, "/")

    def __rmul__(self, scalar: float) -> "ScalarFactor":
        """scalar * f1 (e.g. 0.5 * factor)."""
        from src.factors.composite import ScalarFactor

        return ScalarFactor(self, scalar, "*")

    def __repr__(self) -> str:
        """Return a concise debug representation of the factor."""
        return (
            f"{self.__class__.__name__}(name='{self.name}', type='{self.factor_type}')"
        )
class CrossSectionalFactor(BaseFactor):
    """Base class for cross-sectional (per-date) factors.

    Computation model: on each trading day, compute across all stocks.

    Leakage boundary:
        - Forbidden: data from future dates (date leakage).
        - Allowed: all stocks' data on the current date.

    Data passed in: compute() receives [T - lookback + 1, T], i.e. the
    lookback window of history plus the current date's cross-section
    (history is available so a time-series step can precede the ranking).

    Example:
        >>> class PERankFactor(CrossSectionalFactor):
        ...     name = "pe_rank"
        ...     data_specs = [DataSpec("daily", ["pe"], lookback_days=1)]
        ...
        ...     def compute(self, data: FactorData) -> pl.Series:
        ...         cs = data.get_cross_section()
        ...         return cs["pe"].rank()
    """

    factor_type: str = "cross_sectional"

    @abstractmethod
    def compute(self, data: FactorData) -> pl.Series:
        """Compute cross-sectional factor values.

        Args:
            data: FactorData covering [T - lookback + 1, T], shaped as
                DataFrame[ts_code, trade_date, col1, col2, ...].

        Returns:
            pl.Series: factor value for each stock on the current date
            (length == number of stocks that day).

        Example:
            >>> def compute(self, data):
            ...     cs = data.get_cross_section()
            ...     return cs['market_cap'].rank()
        """
        pass
class TimeSeriesFactor(BaseFactor):
    """Base class for time-series (per-stock) factors.

    Computation model: for each stock, compute along its own time series.

    Leakage boundary:
        - Forbidden: data from other stocks (stock leakage).
        - Allowed: the stock's own full history.

    Data passed in: compute() receives a single stock's complete time
    series over [start_date, end_date].

    Example:
        >>> class MovingAverageFactor(TimeSeriesFactor):
        ...     name = "ma"
        ...
        ...     def __init__(self, period: int = 20):
        ...         super().__init__(period=period)
        ...         self.data_specs = [DataSpec("daily", ["close"], lookback_days=period)]
        ...
        ...     def compute(self, data: FactorData) -> pl.Series:
        ...         return data.get_column("close").rolling_mean(self.params["period"])
    """

    factor_type: str = "time_series"

    @abstractmethod
    def compute(self, data: FactorData) -> pl.Series:
        """Compute time-series factor values for one stock.

        Args:
            data: FactorData with a single stock's full time series,
                shaped as DataFrame[ts_code, trade_date, col1, col2, ...].

        Returns:
            pl.Series: factor value per date (length == number of dates).

        Example:
            >>> def compute(self, data):
            ...     series = data.get_column("close")
            ...     return series.rolling_mean(window_size=self.params['period'])
        """
        pass

View File

@@ -1,201 +0,0 @@
"""组合因子 - Phase 2 因子组合和标量运算
本模块定义了因子组合相关的类:
- CompositeFactor: 组合因子,用于实现因子间的数学运算
- ScalarFactor: 标量运算因子,支持因子与标量的运算
"""
from typing import List
import polars as pl
from src.factors.data_spec import DataSpec, FactorData
from src.factors.base import BaseFactor
class CompositeFactor(BaseFactor):
    """Combinator factor implementing arithmetic between two factors.

    Constraint: both operands must share the same factor_type (both
    cross-sectional or both time-series).
    Supported operators: '+', '-', '*', '/'.

    Example:
        >>> f1 = SomeCrossSectionalFactor()
        >>> f2 = AnotherCrossSectionalFactor()
        >>> combined = f1 + f2  # creates a CompositeFactor
    """

    def __init__(self, left: BaseFactor, right: BaseFactor, op: str):
        """Create a composite factor.

        Args:
            left: Left operand factor.
            right: Right operand factor.
            op: Operator, one of '+', '-', '*', '/'.

        Raises:
            ValueError: If the operands' factor types differ.
            ValueError: If the operator is unsupported.
        """
        # Operands must be of the same factor type.
        if left.factor_type != right.factor_type:
            raise ValueError(
                f"Cannot combine factors of different types: "
                f"'{left.factor_type}' vs '{right.factor_type}'"
            )
        # Operator whitelist.
        if op not in ("+", "-", "*", "/"):
            raise ValueError(f"Unsupported operator: '{op}'")
        self.left = left
        self.right = right
        self.op = op
        # Instance-level attributes shadow the class-level declarations.
        self.factor_type = left.factor_type
        self.name = f"({left.name}_{op}_{right.name})"
        self.data_specs = self._merge_data_specs()
        self.category = "composite"
        self.description = f"Composite factor: {left.name} {op} {right.name}"
        # super().__init__() is intentionally not called: CompositeFactor is
        # a special combinator that manages its own params dict.
        self.params = {
            "left": left,
            "right": right,
            "op": op,
        }

    def _merge_data_specs(self) -> List[DataSpec]:
        """Merge the data requirements of both operands.

        Strategy:
            1. DataSpecs with the same source and column set are merged.
            2. lookback_days takes the maximum across the group.

        Returns:
            The merged list of DataSpec.
        """
        merged = []
        # Collect all specs from both operands.
        all_specs = list(self.left.data_specs) + list(self.right.data_specs)
        # Group by (source, sorted columns tuple).
        spec_groups = {}
        for spec in all_specs:
            key = (spec.source, tuple(sorted(spec.columns)))
            if key not in spec_groups:
                spec_groups[key] = []
            spec_groups[key].append(spec)
        # Merge each group, keeping the largest lookback window.
        for (source, columns_tuple), specs in spec_groups.items():
            max_lookback = max(spec.lookback_days for spec in specs)
            merged.append(
                DataSpec(
                    source=source,
                    columns=list(columns_tuple),
                    lookback_days=max_lookback,
                )
            )
        return merged

    def compute(self, data: FactorData) -> pl.Series:
        """Evaluate both operands, then apply the operator element-wise.

        Args:
            data: FactorData carrying the data both operands need.

        Returns:
            The combined factor values as a Series.
        """
        left_values = self.left.compute(data)
        right_values = self.right.compute(data)
        ops = {
            "+": lambda a, b: a + b,
            "-": lambda a, b: a - b,
            "*": lambda a, b: a * b,
            "/": lambda a, b: a / b,
        }
        return ops[self.op](left_values, right_values)

    def _validate_params(self):
        """No extra parameter validation needed for CompositeFactor."""
        pass
class ScalarFactor(BaseFactor):
    """Factor combined with a scalar value.

    Supports scalar * factor and factor * scalar (via __rmul__/__mul__).

    Example:
        >>> factor = SomeFactor()
        >>> scaled = 0.5 * factor  # creates a ScalarFactor
    """

    def __init__(self, factor: BaseFactor, scalar: float, op: str):
        """Create a scalar-combined factor.

        Args:
            factor: Underlying factor.
            scalar: Scalar operand.
            op: Operator, one of '*', '+'.

        Raises:
            ValueError: If the operator is unsupported.
        """
        # Operator whitelist (only '*' and '+' are meaningful here).
        if op not in ("*", "+"):
            raise ValueError(f"ScalarFactor only supports '*' and '+', got '{op}'")
        self.factor = factor
        self.scalar = scalar
        self.op = op
        # Instance-level attributes shadow the class-level declarations.
        self.factor_type = factor.factor_type
        self.name = f"({scalar}_{op}_{factor.name})"
        self.data_specs = factor.data_specs
        self.category = "scalar"
        self.description = f"Scalar factor: {scalar} {op} {factor.name}"
        # super().__init__() is intentionally not called: ScalarFactor is a
        # special combinator that manages its own params dict.
        self.params = {
            "factor": factor,
            "scalar": scalar,
            "op": op,
        }

    def compute(self, data: FactorData) -> pl.Series:
        """Evaluate the underlying factor, then apply the scalar operation.

        Args:
            data: FactorData carrying the data the underlying factor needs.

        Returns:
            The scaled/shifted factor values as a Series.
        """
        values = self.factor.compute(data)
        if self.op == "*":
            return values * self.scalar
        elif self.op == "+":
            return values + self.scalar
        else:
            # Unreachable: __init__ already rejected other operators.
            raise ValueError(f"Unsupported operation: '{self.op}'")

    def _validate_params(self):
        """No extra parameter validation needed for ScalarFactor."""
        pass

View File

@@ -1,213 +0,0 @@
"""数据加载器 - Phase 3 数据加载模块
本模块负责从 DuckDB 安全加载数据:
- DataLoader: 数据加载器,支持多文件聚合、列选择、缓存
"""
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import pandas as pd
import polars as pl
from src.factors.data_spec import DataSpec
class DataLoader:
    """Data loader — safely loads factor data from DuckDB.

    Capabilities:
        1. Multi-table aggregation: joins data from several tables.
        2. Column projection: loads only the requested columns.
        3. Raw-data caching: avoids repeated reads.
        4. Query pushdown: filters in DuckDB SQL where possible.

    Example:
        >>> loader = DataLoader(data_dir="data")
        >>> specs = [DataSpec("daily", ["ts_code", "trade_date", "close"], lookback_days=20)]
        >>> df = loader.load(specs, date_range=("20240101", "20240131"))
    """

    def __init__(self, data_dir: str):
        """Initialize the DataLoader.

        Args:
            data_dir: Directory containing the DuckDB database file.
        """
        self.data_dir = Path(data_dir)
        self._cache: Dict[str, pl.DataFrame] = {}

    def load(
        self,
        specs: List[DataSpec],
        date_range: Optional[Tuple[str, str]] = None,
    ) -> pl.DataFrame:
        """Load and aggregate data for a list of DataSpecs.

        Flow:
            1. For each DataSpec:
               a. Use the cache when possible.
               b. Otherwise read from storage, project to the needed
                  columns, and cache the result.
               c. Filter by date_range.
            2. Join the per-spec DataFrames on trade_date and ts_code.

        Args:
            specs: Data requirement specifications.
            date_range: Optional (start_date, end_date) bound.

        Returns:
            The merged Polars DataFrame.

        Raises:
            FileNotFoundError: Underlying data file is missing.
            KeyError: A requested column does not exist in the source.
        """
        dataframes = []
        for spec in specs:
            # Cache key ignores date_range.
            # NOTE(review): for "financial_income" the cached frame was read
            # with a specific date_range (see _read_h5), so a later call with
            # a wider range may get truncated data from the cache — verify.
            cache_key = f"{spec.source}_{','.join(sorted(spec.columns))}"
            if cache_key in self._cache:
                df = self._cache[cache_key]
            else:
                # Read from storage (date_range enables pushdown filtering).
                df = self._read_h5(spec.source, date_range=date_range)
                # Column projection — keep only the requested columns.
                missing_cols = set(spec.columns) - set(df.columns)
                if missing_cols:
                    raise KeyError(
                        f"Columns {missing_cols} not found in {spec.source}.h5. "
                        f"Available columns: {df.columns}"
                    )
                df = df.select(spec.columns)
                # Cache the projected frame.
                self._cache[cache_key] = df
            # Apply the date_range filter (also on cache hits).
            if date_range:
                start_date, end_date = date_range
                df = df.filter(
                    (pl.col("trade_date") >= start_date)
                    & (pl.col("trade_date") <= end_date)
                )
            dataframes.append(df)
        # Merge the per-spec DataFrames.
        if len(dataframes) == 1:
            return dataframes[0]
        else:
            return self._merge_dataframes(dataframes)

    def clear_cache(self):
        """Drop every cached DataFrame."""
        self._cache.clear()

    def _read_h5(
        self,
        source: str,
        date_range: Optional[Tuple[str, str]] = None,
    ) -> pl.DataFrame:
        """Read a table from DuckDB as a Polars DataFrame.

        Migration note:
            - The name _read_h5 is kept for backward compatibility; the
              data actually comes from DuckDB via Storage.load_polars(),
              which exports zero-copy and beats the old HDF5+pandas path.

        Args:
            source: Table name in DuckDB (e.g. "daily").
            date_range: Optional (start_date, end_date) bound.

        Returns:
            Polars DataFrame.

        Raises:
            Exception: Database query errors.
        """
        from src.data.storage import Storage
        from src.data.api_wrappers.api_trade_cal import get_trading_days
        from src.data.utils import get_today_date
        from src.factors.financial.utils import expand_period_to_trading_days

        storage = Storage()
        # Special case for financial data: expand reporting periods onto
        # trading days (forward fill) so it aligns with daily data.
        if source == "financial_income":
            # Resolve the date window.
            start_date = date_range[0] if date_range else "20180101"
            end_date = date_range[1] if date_range else get_today_date()
            # 1. Load raw financials (period granularity) within the window.
            #    financial_income uses end_date as the reporting period.
            df = storage.load_polars(
                "financial_income",
                start_date=start_date,
                end_date=end_date,
            )
            if len(df) == 0:
                return pl.DataFrame()
            # 2. Trading calendar from 2018 (or earlier, if requested) to
            #    today, so forward filling has enough history.
            trade_start = "20180101" if start_date > "20180101" else start_date
            trade_dates = get_trading_days(trade_start, get_today_date())
            # 3. Expand to trading days (forward fill).
            return expand_period_to_trading_days(df, trade_dates)
        # All other sources: plain table load.
        return storage.load_polars(source)

    def _merge_dataframes(self, dataframes: List[pl.DataFrame]) -> pl.DataFrame:
        """Join multiple DataFrames on (trade_date, ts_code).

        Uses a full outer join so no rows are dropped; columns already
        present in the accumulated result are skipped.

        Args:
            dataframes: DataFrames to merge.

        Returns:
            The merged DataFrame.

        Raises:
            KeyError: A join key is missing from one of the frames.
        """
        result = dataframes[0]
        for df in dataframes[1:]:
            # Join keys are fixed by the data model.
            join_keys = ["trade_date", "ts_code"]
            # Both sides must carry the join keys.
            for key in join_keys:
                if key not in result.columns or key not in df.columns:
                    raise KeyError(f"Join key '{key}' not found in DataFrames")
            # Only bring in columns the result does not have yet.
            new_cols = [c for c in df.columns if c not in result.columns]
            if new_cols:
                # Select just the keys plus the new columns for the join.
                df_to_join = df.select(join_keys + new_cols)
                # Full outer join keeps rows from both sides.
                result = result.join(df_to_join, on=join_keys, how="full")
        return result

    def get_cache_info(self) -> Dict[str, int]:
        """Describe the cache.

        Returns:
            Dict with the number of cached entries and total cached rows.
        """
        total_rows = sum(len(df) for df in self._cache.values())
        return {
            "entries": len(self._cache),
            "total_rows": total_rows,
        }

View File

@@ -1,242 +0,0 @@
"""数据类型定义 - Phase 1 核心数据模型
本模块定义了因子框架的基础数据类型:
- DataSpec: 数据需求规格,声明因子所需的数据源、列和回看窗口
- FactorContext: 计算上下文,由引擎自动注入,提供计算点信息
- FactorData: 数据容器,封装底层 Polars DataFrame提供安全的数据访问
"""
from dataclasses import dataclass, field
from typing import List, Optional
import polars as pl
@dataclass(frozen=True)
class DataSpec:
    """Immutable specification of a factor's data requirements.

    Declares the source table a factor reads, the columns it needs, and
    how many days of history (current day included) must be loaded.
    Instances are frozen and cannot be modified after creation.

    Attributes:
        source: Source table / file name (e.g. "daily", "fundamental").
        columns: Required column names; must include "ts_code" and
            "trade_date".
        lookback_days: Days of lookback including the current day:
            1 means only day T; 5 means [T-4, T]; 20 means [T-19, T].

    Raises:
        ValueError: When a constraint is violated at construction time.

    Examples:
        >>> spec = DataSpec(
        ...     source="daily",
        ...     columns=["ts_code", "trade_date", "close"],
        ...     lookback_days=20
        ... )
    """

    source: str
    columns: List[str]
    lookback_days: int = 1

    def __post_init__(self):
        """Validate all constraints.

        Checked in order:
            1. lookback_days >= 1 (the current day is always included).
            2. source is a non-empty string.
            3. columns contains both ts_code and trade_date.

        The class is frozen, so this method only validates; it never
        assigns fields (no object.__setattr__ needed).
        """
        if self.lookback_days < 1:
            raise ValueError(f"lookback_days must be >= 1, got {self.lookback_days}")
        if not self.source:
            raise ValueError("source cannot be empty string")
        mandatory = {"ts_code", "trade_date"}
        absent = mandatory - set(self.columns)
        if absent:
            raise ValueError(
                f"columns must contain {mandatory}, missing: {absent}"
            )
@dataclass
class FactorContext:
    """Factor computation context.

    Injected automatically by the FactorEngine; factor code reaches it
    via ``data.context``. The populated fields depend on the factor type:
        - CrossSectionalFactor: current_date is the date being computed.
        - TimeSeriesFactor: current_stock is the stock being computed.

    Attributes:
        current_date: Date being computed, YYYYMMDD (cross-sectional).
        current_stock: Stock code being computed (time-series).
        trade_dates: Optional trading calendar, used for alignment.

    Examples:
        >>> context = FactorContext(current_date="20240101")
        >>> context.current_date
        '20240101'
    """

    current_date: Optional[str] = None
    current_stock: Optional[str] = None
    trade_dates: Optional[List[str]] = None
class FactorData:
    """Data container handed to factors.

    Wraps the underlying Polars DataFrame and exposes safe access
    methods. Contents depend on the factor type:
        - CrossSectionalFactor: current date plus lookback history for
          all stocks.
        - TimeSeriesFactor: a single stock's full time series.

    Args:
        df: Underlying Polars DataFrame.
        context: Computation context.

    Examples:
        >>> df = pl.DataFrame({
        ...     "ts_code": ["000001.SZ"],
        ...     "trade_date": ["20240101"],
        ...     "close": [10.0]
        ... })
        >>> context = FactorContext(current_date="20240101")
        >>> data = FactorData(df, context)
    """

    def __init__(self, df: pl.DataFrame, context: FactorContext):
        self._df = df
        self._context = context

    def get_column(self, col: str) -> pl.Series:
        """Return one column of the data.

        Works for both factor types: all stocks' values on the day
        (cross-sectional) or one stock's values over time (time-series).

        Args:
            col: Column name.

        Returns:
            Polars Series.

        Raises:
            KeyError: The column is not present.
        """
        if col not in self._df.columns:
            raise KeyError(
                f"Column '{col}' not found in data. Available columns: {self._df.columns}"
            )
        return self._df[col]

    def filter_by_date(self, date: str) -> "FactorData":
        """Return a new FactorData restricted to one date.

        Mainly used by cross-sectional factors. Future dates are never
        reachable — the engine has already trimmed them away.

        Args:
            date: Date in YYYYMMDD format.

        Returns:
            A new FactorData; the original is left untouched.
        """
        filtered = self._df.filter(pl.col("trade_date") == date)
        return FactorData(filtered, self._context)

    def get_cross_section(self) -> pl.DataFrame:
        """Return the cross-section for the context's current date.

        Cross-sectional factors only; requires context.current_date.

        Returns:
            DataFrame with every stock on the current date.

        Raises:
            ValueError: current_date is unset (non-cross-sectional use).
        """
        if self._context.current_date is None:
            raise ValueError(
                "current_date is not set in context. "
                "get_cross_section() is only applicable for cross-sectional factors."
            )
        return self._df.filter(pl.col("trade_date") == self._context.current_date)

    def to_polars(self) -> pl.DataFrame:
        """Return the raw underlying DataFrame (advanced use).

        Direct access can bypass the framework's leakage protection —
        use with care.

        Returns:
            The underlying Polars DataFrame.
        """
        return self._df

    @property
    def context(self) -> FactorContext:
        """The current FactorContext instance."""
        return self._context

    def __len__(self) -> int:
        """Number of rows in the underlying DataFrame."""
        return len(self._df)

    def __repr__(self) -> str:
        """Debug representation: row/column counts plus context info."""
        cols = self._df.columns
        context_info = []
        if self._context.current_date:
            context_info.append(f"date={self._context.current_date}")
        if self._context.current_stock:
            context_info.append(f"stock={self._context.current_stock}")
        context_str = ", ".join(context_info) if context_info else "no context"
        return f"FactorData(rows={len(self)}, cols={len(cols)}, {context_str})"

View File

@@ -1,20 +0,0 @@
"""财务因子模块
本模块提供财务类型的因子:
因子分类:
- financial: 财务因子
- EPSFactor: 每股收益排名因子
已添加因子:
- EPSFactor: 每股收益排名基于basic_eps
待添加因子:
- PERankFactor: 市盈率排名
- PBFactor: 市净率因子
- DividendFactor: 股息率因子
"""
from src.factors.financial.eps_factor import EPSFactor
__all__ = ["EPSFactor"]

View File

@@ -1,66 +0,0 @@
"""EPS因子
每股收益(EPS)排名因子实现
"""
from typing import List
import polars as pl
from src.factors.base import CrossSectionalFactor
from src.factors.data_spec import DataSpec, FactorData
class EPSFactor(CrossSectionalFactor):
    """Earnings-per-share (EPS) cross-sectional rank factor.

    Logic: using the latest reporting period's basic_eps, rank all
    stocks cross-sectionally on each trading day.

    Attributes:
        name: Factor name, "eps_rank".
        category: Factor category, "financial".
        data_specs: Data requirement specification.

    Example:
        >>> from src.factors import FactorEngine, DataLoader
        >>> from src.factors.financial.eps_factor import EPSFactor
        >>> loader = DataLoader('data')
        >>> engine = FactorEngine(loader)
        >>> eps_factor = EPSFactor()
        >>> result = engine.compute(eps_factor, start_date='20210101', end_date='20210131')
    """

    name: str = "eps_rank"
    category: str = "financial"
    description: str = "每股收益截面排名因子"
    data_specs: List[DataSpec] = [
        DataSpec(
            "financial_income", ["ts_code", "trade_date", "basic_eps"], lookback_days=1
        )
    ]

    def compute(self, data: FactorData) -> pl.Series:
        """Compute the EPS rank.

        Args:
            data: FactorData with the current date's cross-section.

        Returns:
            EPS rank normalized to [0, 1]; 0.5 for degenerate inputs.
        """
        # Current date's cross-section.
        cs = data.get_cross_section()
        if len(cs) == 0:
            return pl.Series(name=self.name, values=[])
        # Treat missing EPS as 0 before ranking.
        eps = cs["basic_eps"].fill_null(0)
        # Average rank normalized by the universe size.
        if len(eps) > 1 and eps.max() != eps.min():
            ranks = eps.rank(method="average") / len(eps)
        else:
            # Too few stocks or all values identical: neutral 0.5.
            ranks = pl.Series(name=self.name, values=[0.5] * len(eps))
        return ranks

View File

@@ -1,82 +0,0 @@
"""财务因子工具函数
提供财务数据处理的工具函数:
- expand_period_to_trading_days: 将报告期数据展开到每个交易日(前向填充)
"""
from typing import List
import polars as pl
def expand_period_to_trading_days(
    financial_df: pl.DataFrame,
    trade_dates: List[str],
) -> pl.DataFrame:
    """Expand period-granular financial data to trading days (forward fill).

    Core rule: for each trading day, use the most recent reporting period
    that has already been announced. For example, the FY2020 report
    (end_date 20201231) announced on 20210428 applies to every trading
    day from 2021-04-28 until the Q1-2021 report is announced.

    Args:
        financial_df: Financial data with ts_code, ann_date, end_date, ...
        trade_dates: Sorted list of trading days in YYYYMMDD format.

    Returns:
        DataFrame with trade_date, ts_code and every financial column;
        empty DataFrame when the input is empty or nothing is applicable.

    Example:
        >>> financial_df = pl.DataFrame({
        ...     'ts_code': ['000001.SZ'],
        ...     'ann_date': ['20210428'],
        ...     'end_date': ['20210331'],
        ...     'basic_eps': [0.5]
        ... })
        >>> trade_dates = ['20210428', '20210429', '20210430']
        >>> result = expand_period_to_trading_days(financial_df, trade_dates)
        >>> # Each of the three trading days carries the 20210331 report row.
    """
    if len(financial_df) == 0:
        return pl.DataFrame()
    results = []
    # Process per stock.
    for ts_code in financial_df["ts_code"].unique():
        stock_data = financial_df.filter(pl.col("ts_code") == ts_code)
        # Sort by reporting period ascending so tail(1) is the latest.
        stock_data = stock_data.sort("end_date")
        rows = []
        for trade_date in trade_dates:
            # Applicable reports for this trading day:
            #   1. end_date <= trade_date (period not later than the day)
            #   2. ann_date <= trade_date (already announced)
            # NOTE(review): O(len(trade_dates)) filters per stock — fine for
            # modest calendars, but a join_asof would scale better; verify
            # before using on long histories.
            applicable = stock_data.filter(
                (pl.col("end_date") <= trade_date) & (pl.col("ann_date") <= trade_date)
            )
            if len(applicable) > 0:
                # Take the latest report (max end_date) and stamp the day.
                latest = applicable.tail(1).with_columns(
                    [pl.lit(trade_date).alias("trade_date")]
                )
                rows.append(latest)
        if rows:
            results.append(pl.concat(rows))
    if results:
        return pl.concat(results)
    return pl.DataFrame()

View File

@@ -1,19 +0,0 @@
"""动量因子模块
本模块提供动量类型的因子:
- MovingAverageFactor: 移动平均线(时序因子)
- ReturnRankFactor: 收益率排名(截面因子)
因子分类:
- momentum: 动量因子
- ma: 移动平均线
- return_rank: 收益率排名
"""
from src.factors.momentum.ma import MovingAverageFactor
from src.factors.momentum.return_rank import ReturnRankFactor
__all__ = [
"MovingAverageFactor",
"ReturnRankFactor",
]

View File

@@ -1,78 +0,0 @@
"""动量因子 - 移动平均线
本模块提供通用移动平均线因子,支持参数化配置:
- MovingAverageFactor: 移动平均线(时序因子)
使用示例:
>>> from src.factors.momentum import MovingAverageFactor
>>> ma5 = MovingAverageFactor(period=5) # 5日MA
>>> ma10 = MovingAverageFactor(period=10) # 10日MA
>>> ma20 = MovingAverageFactor(period=20) # 20日MA
"""
from typing import List
import polars as pl
from src.factors.base import TimeSeriesFactor
from src.factors.data_spec import DataSpec, FactorData
class MovingAverageFactor(TimeSeriesFactor):
    """Moving-average factor.

    Logic: for each stock, the mean of its closing prices over the
    trailing ``period`` days.

    Characteristics:
        - Parameterized: the window is chosen via the period argument.
        - Time-series factor: each stock is computed independently,
          preventing cross-stock leakage.

    Attributes:
        period: MA window in days (default 5).

    Example:
        >>> ma5 = MovingAverageFactor(period=5)
        >>> # mean of the last 5 closing prices
    """

    name: str = "ma"
    factor_type: str = "time_series"
    category: str = "momentum"
    description: str = "移动平均线因子计算过去n日收盘价的均值"
    # Class-level default matching period=5; __init__ rebuilds it anyway.
    data_specs: List[DataSpec] = [
        DataSpec("daily", ["ts_code", "trade_date", "close"], lookback_days=5)
    ]

    def __init__(self, period: int = 5):
        """Initialize the factor.

        Args:
            period: MA window in days (default 5).
        """
        super().__init__(period=period)
        # Rebuild the DataSpec with the requested lookback — DataSpec is
        # frozen, so a new instance is required.
        self.data_specs = [
            DataSpec(
                "daily",
                ["ts_code", "trade_date", "close"],
                lookback_days=period,
            )
        ]
        self.name = f"ma_{period}"

    def compute(self, data: FactorData) -> pl.Series:
        """Compute the moving average.

        Args:
            data: FactorData with a single stock's full time series.

        Returns:
            Series of rolling means (nulls for the warm-up rows).
        """
        # Closing price series for this stock.
        close_prices = data.get_column("close")
        # Rolling mean over the configured window.
        ma = close_prices.rolling_mean(window_size=self.params["period"])
        return ma

View File

@@ -1,100 +0,0 @@
"""动量因子 - 收益率排名
本模块提供收益率排名因子:
- ReturnRankFactor: 过去n日收益率的rank因子截面因子
使用示例:
>>> from src.factors.momentum import ReturnRankFactor
>>> ret5 = ReturnRankFactor(period=5) # 5日收益率排名
>>> ret10 = ReturnRankFactor(period=10) # 10日收益率排名
"""
from typing import List
import polars as pl
from src.factors.base import CrossSectionalFactor
from src.factors.data_spec import DataSpec, FactorData
class ReturnRankFactor(CrossSectionalFactor):
    """Cross-sectional rank of the past-n-day return.

    Computation: on each trading day, compute every stock's return over
    the past n days, then rank those returns cross-sectionally.

    Characteristics:
        - Parameterized: the lookback window is chosen via ``period``.
        - Cross-sectional factor: stocks are ranked against each other
          on the same day, which prevents leakage across dates.

    Attributes:
        period: return lookback window in days (default 5).

    Example:
        >>> ret5 = ReturnRankFactor(period=5)
        >>> # per trading day: rank of each stock's 5-day return
    """

    name: str = "return_rank"
    factor_type: str = "cross_sectional"
    category: str = "momentum"
    description: str = "过去n日收益率的截面排名因子"
    data_specs: List[DataSpec] = [
        DataSpec("daily", ["ts_code", "trade_date", "close"], lookback_days=5)
    ]

    def __init__(self, period: int = 5):
        """Initialize the factor.

        Args:
            period: return lookback window in trading days.
        """
        super().__init__(period=period)
        # DataSpec is frozen, so rebuild it with the correct lookback:
        # period + 1 days of data are needed to compute a period-day return.
        self.data_specs = [
            DataSpec(
                "daily",
                ["ts_code", "trade_date", "close"],
                lookback_days=period + 1,
            )
        ]
        self.name = f"return_{period}_rank"

    def compute(self, data: FactorData) -> pl.Series:
        """Compute the cross-sectional rank of the past-n-day return.

        Args:
            data: FactorData holding the last n+1 days of cross-sectional data.

        Returns:
            Ranks of the past-n-day returns, normalized to (0, 1];
            an empty series when there is not enough history.
        """
        period = self.params["period"]
        cs = data.to_polars()

        # All trading dates present, in ascending order.
        trade_dates = cs["trade_date"].unique().sort()

        # BUGFIX: the original guard only required 2 dates, so with
        # 2..period distinct dates the negative index below went out of
        # bounds and raised. A period-day return needs period + 1 dates.
        if len(trade_dates) < period + 1:
            return pl.Series(name=self.name, values=[])

        # Latest cross-section and the one period trading days earlier.
        latest_date = trade_dates[-1]
        current_data = cs.filter(pl.col("trade_date") == latest_date)
        n_days_ago = trade_dates[-(period + 1)]
        past_data = cs.filter(pl.col("trade_date") == n_days_ago)

        # Inner join on ts_code: only stocks present on both dates get a
        # return (newly listed / suspended stocks drop out).
        merged = current_data.select(["ts_code", "close"]).join(
            past_data.select(["ts_code", "close"]).rename({"close": "close_past"}),
            on="ts_code",
            how="inner",
        )

        returns = (merged["close"] - merged["close_past"]) / merged["close_past"]
        # Average-method rank normalized by count -> values in (0, 1].
        return returns.rank(method="average") / len(returns)

View File

@@ -1,20 +0,0 @@
"""质量因子模块
本模块提供质量类因子:
- 盈利能力ROE、ROA、毛利率、净利率
- 盈利稳定性:盈利波动率、盈利持续性
- 财务健康度:资产负债率、流动比率等
使用示例:
>>> from src.factors.quality import ROEFactor
>>> factor = ROEFactor()
"""
# 在此处导入具体的质量因子
# from .roe import ROEFactor
# from .roa import ROAFactor
# from .profit_stability import ProfitStabilityFactor
__all__ = [
# 添加你的质量因子
]

View File

@@ -1,20 +0,0 @@
"""情绪因子模块
本模块提供市场情绪类因子:
- 换手率、换手率变化率
- 资金流向、主力净流入
- 波动率、振幅等
使用示例:
>>> from src.factors.sentiment import TurnoverFactor
>>> factor = TurnoverFactor(period=20)
"""
# 在此处导入具体的情绪因子
# from .turnover import TurnoverFactor
# from .money_flow import MoneyFlowFactor
# from .amplitude import AmplitudeFactor
__all__ = [
# 添加你的情绪因子
]

View File

@@ -1,20 +0,0 @@
"""技术指标因子模块
本模块提供技术分析类因子:
- 移动平均线(MA)、指数移动平均(EMA)
- 相对强弱指标(RSI)、MACD、KDJ
- 布林带(Bollinger Bands)等
使用示例:
>>> from src.factors.technical import RSIFactor
>>> factor = RSIFactor(period=14)
"""
# 在此处导入具体的技术指标因子
# from .rsi import RSIFactor
# from .macd import MACDFactor
# from .bollinger import BollingerFactor
__all__ = [
# 添加你的技术指标因子
]

View File

@@ -1,18 +0,0 @@
"""估值因子模块
本模块提供估值类因子:
- 市盈率(PE)、市净率(PB)、市销率(PS)等估值指标
- 估值排名、估值分位数等衍生因子
使用示例:
>>> from src.factors.valuation import PERankFactor
>>> factor = PERankFactor()
"""
# 在此处导入具体的估值因子
# from .pe_rank import PERankFactor
# from .pb_rank import PBRankFactor
__all__ = [
# 添加你的估值因子
]

View File

@@ -1,21 +0,0 @@
"""波动率因子模块
本模块提供波动率相关因子:
- 历史波动率(Historical Volatility)
- 实现波动率(Realized Volatility)
- GARCH类波动率预测
- 波动率风险指标等
使用示例:
>>> from src.factors.volatility import HistoricalVolFactor
>>> factor = HistoricalVolFactor(period=20)
"""
# 在此处导入具体的波动率因子
# from .historical_vol import HistoricalVolFactor
# from .realized_vol import RealizedVolFactor
# from .garch_vol import GARCHVolFactor
__all__ = [
# 添加你的波动率因子
]

View File

@@ -1,20 +0,0 @@
"""成交量因子模块
本模块提供成交量相关因子:
- 成交量移动平均
- 成交量比率(VR)、能量潮(OBV)
- 量价配合指标等
使用示例:
>>> from src.factors.volume import OBVFactor
>>> factor = OBVFactor()
"""
# 在此处导入具体的成交量因子
# from .obv import OBVFactor
# from .volume_ratio import VolumeRatioFactor
# from .volume_ma import VolumeMAFactor
__all__ = [
# 添加你的成交量因子
]