refactor(factor): 完全重构因子计算框架 - 引入DSL表达式系统
- 删除旧因子框架:移除 base.py、composite.py、data_loader.py、data_spec.py 及所有子模块(momentum、financial、quality、sentiment等) - 新增DSL表达式系统:实现 factor DSL 编译器和翻译器 - dsl.py: 领域特定语言定义 - compiler.py: AST编译与优化 - translator.py: Polars表达式翻译 - api.py: 统一API接口 - 新增数据路由层:data_router.py 实现字段到表的动态路由 - 新增API封装:api_pro_bar.py 提供pro_bar数据接口 - 更新执行引擎:engine.py 适配新的DSL架构 - 重构测试体系:删除旧测试,新增 test_dsl_promotion.py、 test_factor_integration.py、test_pro_bar.py - 清理文档:删除8个过时文档(factor_design、db_sync_guide等)
This commit is contained in:
@@ -3,3 +3,17 @@
|
||||
提供股票数据分析和交易策略等功能
|
||||
"""
|
||||
__version__ = "1.0.0"
|
||||
|
||||
|
||||
import warnings
|
||||
from pandas.errors import SettingWithCopyWarning
|
||||
|
||||
# 忽略 tushare 库的 FutureWarning(fillna method 参数已弃用)
|
||||
warnings.filterwarnings(
|
||||
"ignore",
|
||||
category=FutureWarning,
|
||||
message=".*fillna with 'method' is deprecated.*",
|
||||
)
|
||||
|
||||
# 忽略 SettingWithCopyWarning(常见于 pandas 链式赋值)
|
||||
warnings.filterwarnings("ignore", category=SettingWithCopyWarning)
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
从环境变量加载应用配置,使用pydantic-settings进行类型验证
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from pydantic_settings import BaseSettings
|
||||
@@ -15,20 +16,33 @@ CONFIG_DIR = PROJECT_ROOT / "config"
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
"""应用配置类,从环境变量加载"""
|
||||
"""应用配置类,从环境变量加载
|
||||
|
||||
# 数据库配置
|
||||
所有配置项都会自动从环境变量读取(小写转大写)
|
||||
例如:tushare_token 会读取 TUSHARE_TOKEN 环境变量
|
||||
"""
|
||||
|
||||
# Tushare API 配置
|
||||
tushare_token: str = ""
|
||||
|
||||
# 数据存储配置
|
||||
root_path: str = "" # 项目根路径,默认自动检测
|
||||
data_path: str = "data" # 数据存储路径,相对于 root_path
|
||||
|
||||
# API 速率限制(每分钟请求数)
|
||||
rate_limit: int = 300
|
||||
|
||||
# 同步工作线程数
|
||||
threads: int = 10
|
||||
|
||||
# 数据库配置(可选,用于未来扩展)
|
||||
database_host: str = "localhost"
|
||||
database_port: int = 5432
|
||||
database_name: str = "prostock"
|
||||
database_user: str
|
||||
database_password: str
|
||||
database_user: Optional[str] = None
|
||||
database_password: Optional[str] = None
|
||||
|
||||
# API密钥配置
|
||||
api_key: str
|
||||
secret_key: str
|
||||
|
||||
# Redis配置
|
||||
# Redis配置(可选,用于未来扩展)
|
||||
redis_host: str = "localhost"
|
||||
redis_port: int = 6379
|
||||
redis_password: Optional[str] = None
|
||||
@@ -38,11 +52,27 @@ class Settings(BaseSettings):
|
||||
app_debug: bool = False
|
||||
app_port: int = 8000
|
||||
|
||||
@property
|
||||
def project_root(self) -> Path:
|
||||
"""获取项目根路径。"""
|
||||
if self.root_path:
|
||||
return Path(self.root_path)
|
||||
return PROJECT_ROOT
|
||||
|
||||
@property
|
||||
def data_path_resolved(self) -> Path:
|
||||
"""获取解析后的数据路径(绝对路径)。"""
|
||||
path = Path(self.data_path)
|
||||
if path.is_absolute():
|
||||
return path
|
||||
return self.project_root / path
|
||||
|
||||
class Config:
|
||||
# 从 config/ 目录读取 .env.local 文件
|
||||
env_file = str(CONFIG_DIR / ".env.local")
|
||||
env_file_encoding = "utf-8"
|
||||
case_sensitive = False
|
||||
extra = "ignore" # 忽略 .env.local 中的额外变量
|
||||
|
||||
|
||||
@lru_cache()
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
Provides simplified interfaces for fetching and storing Tushare data.
|
||||
"""
|
||||
|
||||
from src.data.config import Config, get_config
|
||||
from src.config.settings import Settings, get_settings, settings
|
||||
from src.data.client import TushareClient
|
||||
from src.data.storage import Storage, ThreadSafeStorage, DEFAULT_TYPE_MAPPING
|
||||
from src.data.api_wrappers import get_stock_basic, sync_all_stocks
|
||||
|
||||
@@ -169,6 +169,120 @@ if "date" in data.columns:
|
||||
|
||||
### 4.5 令牌桶限速要求
|
||||
|
||||
所有 API 调用必须通过 `TushareClient`,自动满足令牌桶限速要求。
|
||||
|
||||
#### 4.5.1 基本用法(单线程场景)
|
||||
|
||||
```python
|
||||
from src.data.client import TushareClient
|
||||
|
||||
def get_{data_type}(...) -> pd.DataFrame:
|
||||
client = TushareClient()
|
||||
|
||||
# Build parameters
|
||||
params = {}
|
||||
if trade_date:
|
||||
params["trade_date"] = trade_date
|
||||
if ts_code:
|
||||
params["ts_code"] = ts_code
|
||||
# ...
|
||||
|
||||
# Fetch data (rate limiting handled automatically)
|
||||
data = client.query("{api_name}", **params)
|
||||
|
||||
return data
|
||||
```
|
||||
|
||||
**注意**: `TushareClient` 自动处理:
|
||||
- 令牌桶速率限制
|
||||
- API 重试逻辑(指数退避)
|
||||
- 配置加载
|
||||
|
||||
#### 4.5.2 多线程/并发场景(重要)
|
||||
|
||||
**问题**: 多线程并发调用时,如果每个线程创建独立的 `TushareClient` 实例,每个实例会有独立的限流器,导致实际并发请求数 = 线程数 × 单个限流器速率,**限流失效**。
|
||||
|
||||
**解决方案**: 数据获取函数必须接受可选的 `client` 参数,Sync 类传递共享的客户端实例。
|
||||
|
||||
**数据获取函数签名**(必须支持 client 参数):
|
||||
|
||||
```python
|
||||
from src.data.client import TushareClient
|
||||
from typing import Optional
|
||||
|
||||
def get_{data_type}(
|
||||
ts_code: str,
|
||||
start_date: Optional[str] = None,
|
||||
end_date: Optional[str] = None,
|
||||
client: Optional[TushareClient] = None, # 新增:可选客户端参数
|
||||
) -> pd.DataFrame:
|
||||
"""Fetch {数据描述} from Tushare.
|
||||
|
||||
Args:
|
||||
ts_code: Stock code
|
||||
start_date: Start date (YYYYMMDD)
|
||||
end_date: End date (YYYYMMDD)
|
||||
client: Optional TushareClient instance for shared rate limiting.
|
||||
If None, creates a new client. For concurrent sync operations,
|
||||
pass a shared client to ensure proper rate limiting.
|
||||
|
||||
Returns:
|
||||
pd.DataFrame with data
|
||||
"""
|
||||
client = client or TushareClient() # 如果没有提供则创建新实例
|
||||
|
||||
params = {"ts_code": ts_code}
|
||||
if start_date:
|
||||
params["start_date"] = start_date
|
||||
if end_date:
|
||||
params["end_date"] = end_date
|
||||
|
||||
data = client.query("{api_name}", **params)
|
||||
return data
|
||||
```
|
||||
|
||||
**Sync 类实现**(必须传递共享 client):
|
||||
|
||||
```python
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from src.data.client import TushareClient
|
||||
from src.data.storage import ThreadSafeStorage
|
||||
|
||||
class {DataType}Sync:
|
||||
def __init__(self, max_workers: Optional[int] = None):
|
||||
self.storage = ThreadSafeStorage()
|
||||
self.client = TushareClient() # 共享客户端实例
|
||||
self.max_workers = max_workers or 10
|
||||
|
||||
def sync_single_stock(
|
||||
self,
|
||||
ts_code: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
) -> pd.DataFrame:
|
||||
"""同步单只股票的数据。"""
|
||||
# 传递共享 client 以确保多线程下的限流生效
|
||||
data = get_{data_type}(
|
||||
ts_code=ts_code,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
client=self.client, # 关键:传递共享客户端
|
||||
)
|
||||
return data
|
||||
|
||||
def sync_all(self, ...):
|
||||
# 使用 ThreadPoolExecutor 并发执行
|
||||
with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
|
||||
# 所有线程共享 self.client,限流器正常工作
|
||||
...
|
||||
```
|
||||
|
||||
**关键规则**:
|
||||
1. 所有按股票获取的接口必须接受 `client: Optional[TushareClient] = None` 参数
|
||||
2. Sync 类在 `__init__` 中创建 `self.client = TushareClient()`
|
||||
3. Sync 类的同步方法必须将 `self.client` 传递给数据获取函数
|
||||
4. 数据获取函数使用 `client = client or TushareClient()` 模式
|
||||
|
||||
所有 API 调用必须通过 `TushareClient`,自动满足令牌桶限速要求:
|
||||
|
||||
```python
|
||||
@@ -198,6 +312,26 @@ def get_{data_type}(...) -> pd.DataFrame:
|
||||
|
||||
## 5. DuckDB 存储规范
|
||||
|
||||
### 5.0 强制落库要求(关键)
|
||||
|
||||
**所有封装的 API 接口必须将数据落库到 DuckDB。**
|
||||
|
||||
这是数据同步的核心原则,确保:
|
||||
- 数据持久化:避免重复调用 API,节省 token
|
||||
- 增量更新:基于本地数据状态进行智能同步
|
||||
- 数据一致性:所有数据都有统一的存储和访问方式
|
||||
- 离线可用:数据可以在没有网络的情况下查询
|
||||
|
||||
**落库检查清单**:
|
||||
- [ ] 在 `storage.py` 的 `_init_db()` 方法中创建对应的表
|
||||
- [ ] 表结构必须包含 `ts_code` 和 `trade_date` 作为主键
|
||||
- [ ] 实现 `sync_{data_type}()` 函数,使用 `Storage` 或 `ThreadSafeStorage` 保存数据
|
||||
- [ ] 确保同步逻辑正确处理增量更新
|
||||
|
||||
**反例警示**:`api_pro_bar.py` 早期版本虽然实现了 `sync_pro_bar()` 函数,但忘记在 `storage.py` 中创建 `pro_bar` 表,导致同步的数据无法落库,造成 token 浪费和数据丢失。
|
||||
|
||||
### 5.1 存储架构
|
||||
|
||||
### 5.1 存储架构
|
||||
|
||||
项目使用 **DuckDB** 作为持久化存储:
|
||||
|
||||
@@ -5,6 +5,7 @@ All wrapper files follow the naming convention: api_{data_type}.py
|
||||
|
||||
Available APIs:
|
||||
- api_daily: Daily market data (日线行情)
|
||||
- api_pro_bar: Pro Bar universal market data (通用行情,后复权)
|
||||
- api_stock_basic: Stock basic information (股票基本信息)
|
||||
- api_trade_cal: Trading calendar (交易日历)
|
||||
- api_namechange: Stock name change history (股票曾用名)
|
||||
@@ -12,15 +13,31 @@ Available APIs:
|
||||
|
||||
Example:
|
||||
>>> from src.data.api_wrappers import get_daily, get_stock_basic, get_trade_cal, get_bak_basic
|
||||
>>> from src.data.api_wrappers import get_bak_basic, sync_bak_basic
|
||||
>>> from src.data.api_wrappers import get_pro_bar, sync_pro_bar
|
||||
>>> data = get_daily('000001.SZ', start_date='20240101', end_date='20240131')
|
||||
>>> pro_data = get_pro_bar('000001.SZ', start_date='20240101', end_date='20240131')
|
||||
>>> stocks = get_stock_basic()
|
||||
>>> calendar = get_trade_cal('20240101', '20240131')
|
||||
>>> bak_basic = get_bak_basic(trade_date='20240101')
|
||||
"""
|
||||
|
||||
from src.data.api_wrappers.api_daily import get_daily, sync_daily, preview_daily_sync, DailySync
|
||||
from src.data.api_wrappers.financial_data.api_income import get_income, sync_income, IncomeSync
|
||||
from src.data.api_wrappers.api_daily import (
|
||||
get_daily,
|
||||
sync_daily,
|
||||
preview_daily_sync,
|
||||
DailySync,
|
||||
)
|
||||
from src.data.api_wrappers.api_pro_bar import (
|
||||
get_pro_bar,
|
||||
sync_pro_bar,
|
||||
preview_pro_bar_sync,
|
||||
ProBarSync,
|
||||
)
|
||||
from src.data.api_wrappers.financial_data.api_income import (
|
||||
get_income,
|
||||
sync_income,
|
||||
IncomeSync,
|
||||
)
|
||||
from src.data.api_wrappers.api_bak_basic import get_bak_basic, sync_bak_basic
|
||||
from src.data.api_wrappers.api_namechange import get_namechange, sync_namechange
|
||||
from src.data.api_wrappers.api_stock_basic import get_stock_basic, sync_all_stocks
|
||||
@@ -38,6 +55,11 @@ __all__ = [
|
||||
"sync_daily",
|
||||
"preview_daily_sync",
|
||||
"DailySync",
|
||||
# Pro Bar (universal market data)
|
||||
"get_pro_bar",
|
||||
"sync_pro_bar",
|
||||
"preview_pro_bar_sync",
|
||||
"ProBarSync",
|
||||
# Income statement
|
||||
"get_income",
|
||||
"sync_income",
|
||||
|
||||
@@ -345,4 +345,154 @@ df = pro.bak_basic(trade_date='20211012', fields='trade_date,ts_code,name,indust
|
||||
4530 20211012 688255.SH 凯尔达 机械基件 0.0000
|
||||
4531 20211012 688211.SH 中科微至 专用机械 0.0000
|
||||
4532 20211012 605567.SH 春雪食品 食品 0.0000
|
||||
4533 20211012 605566.SH 福莱蒽特 染料涂料 0.0000
|
||||
4533 20211012 605566.SH 福莱蒽特 染料涂料 0.0000
|
||||
|
||||
|
||||
通用行情接口
|
||||
接口名称:pro_bar,本接口是集成开发接口,部分指标是现用现算
|
||||
更新时间:股票和指数通常在15点~17点之间,数字货币实时更新,具体请参考各接口文档明细。
|
||||
描述:目前整合了股票(未复权、前复权、后复权)、指数、数字货币、ETF基金、期货、期权的行情数据,未来还将整合包括外汇在内的所有交易行情数据,同时提供分钟数据。不同数据对应不同的积分要求,具体请参阅每类数据的文档说明。
|
||||
其它:由于本接口是集成接口,在SDK层做了一些逻辑处理,目前暂时没法用http的方式调取通用行情接口。用户可以访问Tushare的Github,查看源代码完成类似功能。
|
||||
|
||||
输入参数
|
||||
|
||||
名称 类型 必选 描述
|
||||
ts_code str Y 证券代码,不支持多值输入,多值输入获取结果会有重复记录
|
||||
start_date str N 开始日期 (日线格式:YYYYMMDD,提取分钟数据请用2019-09-01 09:00:00这种格式)
|
||||
end_date str N 结束日期 (日线格式:YYYYMMDD)
|
||||
asset str Y 资产类别:E股票 I沪深指数 C数字货币 FT期货 FD基金 O期权 CB可转债(v1.2.39),默认E
|
||||
adj str N 复权类型(只针对股票):None未复权 qfq前复权 hfq后复权 , 默认None,目前只支持日线复权,同时复权机制是根据设定的end_date参数动态复权,采用分红再投模式,具体请参考常见问题列表里的说明。
|
||||
freq str Y 数据频度 :支持分钟(min)/日(D)/周(W)/月(M)K线,其中1min表示1分钟(类推1/5/15/30/60分钟) ,默认D。对于分钟数据有600积分用户可以试用(请求2次),正式权限可以参考权限列表说明 ,使用方法请参考股票分钟使用方法。
|
||||
ma list N 均线,支持任意合理int数值。注:均线是动态计算,要设置一定时间范围才能获得相应的均线,比如5日均线,开始和结束日期参数跨度必须要超过5日。目前只支持单一个股票提取均线,即需要输入ts_code参数。e.g: ma_5表示5日均价,ma_v_5表示5日均量
|
||||
factors list N 股票因子(asset='E'有效)支持 tor换手率 vr量比
|
||||
adjfactor str N 复权因子,在复权数据时,如果此参数为True,返回的数据中则带复权因子,默认为False。 该功能从1.2.33版本开始生效
|
||||
|
||||
输出指标
|
||||
|
||||
具体输出的数据指标可参考各行情具体指标:
|
||||
|
||||
股票Daily:https://tushare.pro/document/2?doc_id=27
|
||||
(内容如下:A股日线行情
|
||||
接口:daily,可以通过数据工具调试和查看数据
|
||||
数据说明:交易日每天15点~16点之间入库。本接口是未复权行情,停牌期间不提供数据
|
||||
调取说明:基础积分每分钟内可调取500次,每次6000条数据,一次请求相当于提取一个股票23年历史
|
||||
描述:获取股票行情数据,或通过通用行情接口获取数据,包含了前后复权数据
|
||||
|
||||
输入参数
|
||||
|
||||
名称 类型 必选 描述
|
||||
ts_code str N 股票代码(支持多个股票同时提取,逗号分隔)
|
||||
trade_date str N 交易日期(YYYYMMDD)
|
||||
start_date str N 开始日期(YYYYMMDD)
|
||||
end_date str N 结束日期(YYYYMMDD)
|
||||
注:日期都填YYYYMMDD格式,比如20181010
|
||||
|
||||
输出参数
|
||||
|
||||
名称 类型 描述
|
||||
ts_code str 股票代码
|
||||
trade_date str 交易日期
|
||||
open float 开盘价
|
||||
high float 最高价
|
||||
low float 最低价
|
||||
close float 收盘价
|
||||
pre_close float 昨收价【除权价】
|
||||
change float 涨跌额
|
||||
pct_chg float 涨跌幅 【基于除权后的昨收计算的涨跌幅:(今收-除权昨收)/除权昨收 】
|
||||
vol float 成交量 (手)
|
||||
amount float 成交额 (千元)
|
||||
接口示例
|
||||
|
||||
pro = ts.pro_api()
|
||||
|
||||
df = pro.daily(ts_code='000001.SZ', start_date='20180701', end_date='20180718')
|
||||
|
||||
#多个股票
|
||||
df = pro.daily(ts_code='000001.SZ,600000.SH', start_date='20180701', end_date='20180718')
|
||||
或者
|
||||
|
||||
df = pro.query('daily', ts_code='000001.SZ', start_date='20180701', end_date='20180718')
|
||||
也可以通过日期取历史某一天的全部历史
|
||||
|
||||
df = pro.daily(trade_date='20180810')
|
||||
数据样例
|
||||
|
||||
ts_code trade_date open high low close pre_close change pct_chg vol amount
|
||||
0 000001.SZ 20180718 8.75 8.85 8.69 8.70 8.72 -0.02 -0.23 525152.77 460697.377
|
||||
1 000001.SZ 20180717 8.74 8.75 8.66 8.72 8.73 -0.01 -0.11 375356.33 326396.994
|
||||
2 000001.SZ 20180716 8.85 8.90 8.69 8.73 8.88 -0.15 -1.69 689845.58 603427.713
|
||||
3 000001.SZ 20180713 8.92 8.94 8.82 8.88 8.88 0.00 0.00 603378.21 535401.175
|
||||
4 000001.SZ 20180712 8.60 8.97 8.58 8.88 8.64 0.24 2.78 1140492.31 1008658.828
|
||||
5 000001.SZ 20180711 8.76 8.83 8.68 8.78 8.98 -0.20 -2.23 851296.70 744765.824
|
||||
6 000001.SZ 20180710 9.02 9.02 8.89 8.98 9.03 -0.05 -0.55 896862.02 803038.965
|
||||
7 000001.SZ 20180709 8.69 9.03 8.68 9.03 8.66 0.37 4.27 1409954.60 1255007.609
|
||||
8 000001.SZ 20180706 8.61 8.78 8.45 8.66 8.60 0.06 0.70 988282.69 852071.526
|
||||
9 000001.SZ 20180705 8.62 8.73 8.55 8.60 8.61 -0.01 -0.12 835768.77 722169.579)
|
||||
|
||||
基金Daily:https://tushare.pro/document/2?doc_id=127
|
||||
|
||||
期货Daily:https://tushare.pro/document/2?doc_id=138
|
||||
|
||||
期权Daily:https://tushare.pro/document/2?doc_id=159
|
||||
|
||||
指数Daily:https://tushare.pro/document/2?doc_id=95
|
||||
|
||||
接口用例
|
||||
|
||||
|
||||
#取000001的前复权行情
|
||||
df = ts.pro_bar(ts_code='000001.SZ', adj='qfq', start_date='20180101', end_date='20181011')
|
||||
|
||||
ts_code trade_date open high low close \
|
||||
trade_date
|
||||
20181011 000001.SZ 20181011 1085.71 1097.59 1047.90 1065.19
|
||||
20181010 000001.SZ 20181010 1138.65 1151.61 1121.36 1128.92
|
||||
20181009 000001.SZ 20181009 1130.00 1155.93 1122.44 1140.81
|
||||
20181008 000001.SZ 20181008 1155.93 1165.65 1128.92 1128.92
|
||||
20180928 000001.SZ 20180928 1164.57 1217.51 1164.57 1193.74
|
||||
|
||||
|
||||
|
||||
#取上证指数行情数据
|
||||
|
||||
df = ts.pro_bar(ts_code='000001.SH', asset='I', start_date='20180101', end_date='20181011')
|
||||
|
||||
In [10]: df.head()
|
||||
Out[10]:
|
||||
ts_code trade_date close open high low \
|
||||
0 000001.SH 20181011 2583.4575 2643.0740 2661.2859 2560.3164
|
||||
1 000001.SH 20181010 2725.8367 2723.7242 2743.5480 2703.0626
|
||||
2 000001.SH 20181009 2721.0130 2713.7319 2734.3142 2711.1971
|
||||
3 000001.SH 20181008 2716.5104 2768.2075 2771.9384 2710.1781
|
||||
4 000001.SH 20180928 2821.3501 2794.2644 2821.7553 2791.8363
|
||||
|
||||
pre_close change pct_chg vol amount
|
||||
0 2725.8367 -142.3792 -5.2233 197150702.0 170057762.5
|
||||
1 2721.0130 4.8237 0.1773 113485736.0 111312455.3
|
||||
2 2716.5104 4.5026 0.1657 116771899.0 110292457.8
|
||||
3 2821.3501 -104.8397 -3.7159 149501388.0 141531551.8
|
||||
4 2791.7748 29.5753 1.0594 134290456.0 125369989.4
|
||||
|
||||
|
||||
|
||||
#均线
|
||||
|
||||
df = ts.pro_bar(ts_code='000001.SZ', start_date='20180101', end_date='20181011', ma=[5, 20, 50])
|
||||
注:Tushare pro_bar接口的均价和均量数据是动态计算,想要获取某个时间段的均线,必须要设置start_date日期大于最大均线的日期数,然后自行截取想要日期段。例如,想要获取20190801开始的3日均线,必须设置start_date='20190729',然后剔除20190801之前的日期记录。
|
||||
|
||||
|
||||
|
||||
|
||||
#换手率tor,量比vr
|
||||
|
||||
df = ts.pro_bar(ts_code='000001.SZ', start_date='20180101', end_date='20181011', factors=['tor', 'vr'])
|
||||
|
||||
|
||||
说明
|
||||
|
||||
对于pro_api参数,如果在一开始就通过 ts.set_token('xxxx') 设置过token的情况,这个参数就不是必需的。
|
||||
|
||||
例如:
|
||||
|
||||
|
||||
df = ts.pro_bar(ts_code='000001.SH', asset='I', start_date='20180101', end_date='20181011')
|
||||
@@ -129,7 +129,9 @@ def sync_bak_basic(
|
||||
columns = []
|
||||
for col in sample.columns:
|
||||
dtype = str(sample[col].dtype)
|
||||
if "int" in dtype:
|
||||
if col == "trade_date":
|
||||
col_type = "DATE"
|
||||
elif "int" in dtype:
|
||||
col_type = "INTEGER"
|
||||
elif "float" in dtype:
|
||||
col_type = "DOUBLE"
|
||||
@@ -223,10 +225,16 @@ def sync_bak_basic(
|
||||
|
||||
# Combine and save
|
||||
combined = pd.concat(all_data, ignore_index=True)
|
||||
|
||||
# Convert trade_date to datetime for proper DATE type storage
|
||||
combined["trade_date"] = pd.to_datetime(combined["trade_date"], format="%Y%m%d")
|
||||
|
||||
print(f"[sync_bak_basic] Total records: {len(combined)}")
|
||||
|
||||
# Delete existing data for the date range and append new data
|
||||
storage._connection.execute(f'DELETE FROM "{TABLE_NAME}" WHERE "trade_date" >= ?', [sync_start])
|
||||
# Convert sync_start to date format for comparison with DATE column
|
||||
sync_start_date = pd.to_datetime(sync_start, format="%Y%m%d").date()
|
||||
storage._connection.execute(f'DELETE FROM "{TABLE_NAME}" WHERE "trade_date" >= ?', [sync_start_date])
|
||||
thread_storage.queue_save(TABLE_NAME, combined)
|
||||
thread_storage.flush()
|
||||
|
||||
|
||||
@@ -17,6 +17,7 @@ import threading
|
||||
from src.data.client import TushareClient
|
||||
from src.data.storage import ThreadSafeStorage, Storage
|
||||
from src.data.utils import get_today_date, get_next_date, DEFAULT_START_DATE
|
||||
from src.config.settings import get_settings
|
||||
from src.data.api_wrappers.api_trade_cal import (
|
||||
get_first_trading_day,
|
||||
get_last_trading_day,
|
||||
@@ -105,16 +106,15 @@ class DailySync:
|
||||
- 预览模式(预览同步数据量,不实际写入)
|
||||
"""
|
||||
|
||||
# 默认工作线程数
|
||||
DEFAULT_MAX_WORKERS = 10
|
||||
# 默认工作线程数(从配置读取,默认10)
|
||||
DEFAULT_MAX_WORKERS = get_settings().threads
|
||||
|
||||
def __init__(self, max_workers: Optional[int] = None):
|
||||
"""初始化同步管理器。
|
||||
|
||||
Args:
|
||||
max_workers: 工作线程数(默认: 10)
|
||||
max_workers: 工作线程数(默认从配置读取)
|
||||
"""
|
||||
self.storage = ThreadSafeStorage()
|
||||
self.client = TushareClient()
|
||||
self.max_workers = max_workers or self.DEFAULT_MAX_WORKERS
|
||||
self._stop_flag = threading.Event()
|
||||
|
||||
@@ -8,13 +8,13 @@ import pandas as pd
|
||||
from pathlib import Path
|
||||
from typing import Optional, List
|
||||
from src.data.client import TushareClient
|
||||
from src.data.config import get_config
|
||||
from src.config.settings import get_settings
|
||||
|
||||
|
||||
# CSV file path for namechange data
|
||||
def _get_csv_path() -> Path:
|
||||
"""Get the CSV file path for namechange data."""
|
||||
cfg = get_config()
|
||||
cfg = get_settings()
|
||||
return cfg.data_path_resolved / "namechange.csv"
|
||||
|
||||
|
||||
|
||||
@@ -9,13 +9,13 @@ import pandas as pd
|
||||
from pathlib import Path
|
||||
from typing import Optional, Literal, List
|
||||
from src.data.client import TushareClient
|
||||
from src.data.config import get_config
|
||||
from src.config.settings import get_settings
|
||||
|
||||
|
||||
# CSV file path for stock basic data
|
||||
def _get_csv_path() -> Path:
|
||||
"""Get the CSV file path for stock basic data."""
|
||||
cfg = get_config()
|
||||
cfg = get_settings()
|
||||
return cfg.data_path_resolved / "stock_basic.csv"
|
||||
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ import pandas as pd
|
||||
from typing import Optional, Literal
|
||||
from pathlib import Path
|
||||
from src.data.client import TushareClient
|
||||
from src.data.config import get_config
|
||||
from src.config.settings import get_settings
|
||||
|
||||
# Module-level flag to track if cache has been synced in this session
|
||||
_cache_synced = False
|
||||
@@ -18,7 +18,7 @@ _cache_synced = False
|
||||
# Trading calendar cache file path
|
||||
def _get_cache_path() -> Path:
|
||||
"""Get the cache file path for trade calendar."""
|
||||
cfg = get_config()
|
||||
cfg = get_settings()
|
||||
return cfg.data_path_resolved / "trade_cal.h5"
|
||||
|
||||
|
||||
@@ -296,8 +296,8 @@ def get_first_trading_day(
|
||||
trading_days = get_trading_days(start_date, end_date, exchange)
|
||||
if not trading_days:
|
||||
return None
|
||||
# Trading days are sorted in descending order (newest first) from cache
|
||||
return trading_days[-1]
|
||||
# Return the earliest trading day
|
||||
return min(trading_days)
|
||||
|
||||
|
||||
def get_last_trading_day(
|
||||
@@ -318,8 +318,8 @@ def get_last_trading_day(
|
||||
trading_days = get_trading_days(start_date, end_date, exchange)
|
||||
if not trading_days:
|
||||
return None
|
||||
# Trading days are sorted in descending order (newest first) from cache
|
||||
return trading_days[0]
|
||||
# Return the latest trading day
|
||||
return max(trading_days)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -1,21 +1,25 @@
|
||||
"""Simplified Tushare client with rate limiting and retry logic."""
|
||||
|
||||
import time
|
||||
import pandas as pd
|
||||
from typing import Optional
|
||||
from src.data.config import get_config
|
||||
from src.data.rate_limiter import TokenBucketRateLimiter
|
||||
from src.config.settings import get_settings
|
||||
|
||||
|
||||
class TushareClient:
|
||||
"""Tushare API client with rate limiting and retry."""
|
||||
|
||||
# 类级别共享限流器(确保所有实例共享同一个限流器)
|
||||
_shared_limiter: Optional[TokenBucketRateLimiter] = None
|
||||
|
||||
def __init__(self, token: Optional[str] = None):
|
||||
"""Initialize client.
|
||||
|
||||
Args:
|
||||
token: Tushare API token (auto-loaded from config if not provided)
|
||||
"""
|
||||
cfg = get_config()
|
||||
cfg = get_settings()
|
||||
token = token or cfg.tushare_token
|
||||
|
||||
if not token:
|
||||
@@ -24,12 +28,21 @@ class TushareClient:
|
||||
self.token = token
|
||||
self.config = cfg
|
||||
|
||||
# Initialize rate limiter: capacity = rate_limit, refill_rate = rate_limit/60 per second
|
||||
# 初始化共享限流器(确保所有 TushareClient 实例共享同一个限流器)
|
||||
rate_per_second = cfg.rate_limit / 60.0
|
||||
self.rate_limiter = TokenBucketRateLimiter(
|
||||
capacity=cfg.rate_limit,
|
||||
refill_rate_per_second=rate_per_second,
|
||||
)
|
||||
capacity = cfg.rate_limit
|
||||
|
||||
if TushareClient._shared_limiter is None:
|
||||
# 首次创建:初始化共享限流器
|
||||
TushareClient._shared_limiter = TokenBucketRateLimiter(
|
||||
capacity=capacity,
|
||||
refill_rate_per_second=rate_per_second,
|
||||
)
|
||||
print(
|
||||
f"[TushareClient] Initialized shared rate limiter: capacity={capacity}, window=60s"
|
||||
)
|
||||
# 复用共享限流器
|
||||
self.rate_limiter = TushareClient._shared_limiter
|
||||
|
||||
self._api = None
|
||||
|
||||
@@ -37,6 +50,7 @@ class TushareClient:
|
||||
"""Get Tushare API instance."""
|
||||
if self._api is None:
|
||||
import tushare as ts
|
||||
|
||||
self._api = ts.pro_api(self.token)
|
||||
return self._api
|
||||
|
||||
@@ -52,7 +66,7 @@ class TushareClient:
|
||||
DataFrame with query results
|
||||
"""
|
||||
# Acquire rate limit token (None = wait indefinitely)
|
||||
timeout = timeout if timeout is not None else float('inf')
|
||||
timeout = timeout if timeout is not None else float("inf")
|
||||
success, wait_time = self.rate_limiter.acquire(timeout=timeout)
|
||||
|
||||
if not success:
|
||||
@@ -72,14 +86,21 @@ class TushareClient:
|
||||
# pro_bar uses ts.pro_bar() instead of api.query()
|
||||
if api_name == "pro_bar":
|
||||
# pro_bar parameters: ts_code, start_date, end_date, adj, freq, factors, ma, adjfactor
|
||||
data = ts.pro_bar(ts_code=params.get("ts_code"),
|
||||
start_date=params.get("start_date"),
|
||||
end_date=params.get("end_date"),
|
||||
adj=params.get("adj"),
|
||||
freq=params.get("freq", "D"),
|
||||
factors=params.get("factors"), # factors should be a list like ['tor', 'vr']
|
||||
ma=params.get("ma"),
|
||||
adjfactor=params.get("adjfactor"))
|
||||
data = ts.pro_bar(
|
||||
ts_code=params.get("ts_code"),
|
||||
start_date=params.get("start_date"),
|
||||
end_date=params.get("end_date"),
|
||||
adj=params.get("adj"),
|
||||
freq=params.get("freq", "D"),
|
||||
factors=params.get(
|
||||
"factors"
|
||||
), # factors should be a list like ['tor', 'vr']
|
||||
ma=params.get("ma"),
|
||||
adjfactor=params.get("adjfactor"),
|
||||
)
|
||||
# Handle None response (e.g., delisted stock)
|
||||
if data is None:
|
||||
data = pd.DataFrame()
|
||||
else:
|
||||
api = self._get_api()
|
||||
data = api.query(api_name, **params)
|
||||
@@ -89,10 +110,14 @@ class TushareClient:
|
||||
except Exception as e:
|
||||
if attempt < max_retries - 1:
|
||||
delay = retry_delays[attempt]
|
||||
print(f"[Retry] {api_name} failed (attempt {attempt + 1}): {e}, retry in {delay}s")
|
||||
print(
|
||||
f"[Retry] {api_name} failed (attempt {attempt + 1}): {e}, retry in {delay}s"
|
||||
)
|
||||
time.sleep(delay)
|
||||
else:
|
||||
raise RuntimeError(f"API call failed after {max_retries} attempts: {e}")
|
||||
raise RuntimeError(
|
||||
f"API call failed after {max_retries} attempts: {e}"
|
||||
)
|
||||
|
||||
return pd.DataFrame()
|
||||
|
||||
|
||||
@@ -1,80 +0,0 @@
|
||||
"""Configuration management for data collection module."""
|
||||
import os
|
||||
from pathlib import Path
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
# Config directory path - used for loading .env.local
|
||||
# Static detection for pydantic-settings to find .env.local
|
||||
CONFIG_DIR = Path(__file__).parent.parent.parent / "config"
|
||||
|
||||
|
||||
def _get_project_root() -> Path:
|
||||
"""Get project root path from ROOT_PATH env var or auto-detect."""
|
||||
# Try to read from environment variable first
|
||||
root_path = os.environ.get("ROOT_PATH") or os.environ.get("DATA_ROOT")
|
||||
if root_path:
|
||||
return Path(root_path)
|
||||
# Fallback to auto-detection
|
||||
return Path(__file__).parent.parent.parent
|
||||
|
||||
|
||||
class Config(BaseSettings):
|
||||
"""Application configuration loaded from environment variables."""
|
||||
|
||||
# Tushare API token
|
||||
tushare_token: str = ""
|
||||
|
||||
# Root path - loaded from environment variable ROOT_PATH
|
||||
# If not set, uses auto-detected path
|
||||
root_path: str = ""
|
||||
|
||||
# Data storage path - can be set via DATA_PATH environment variable
|
||||
# If relative path, it will be resolved relative to root_path
|
||||
data_path: str = "data"
|
||||
|
||||
# Rate limit: requests per minute
|
||||
rate_limit: int = 100
|
||||
|
||||
# Thread pool size
|
||||
threads: int = 2
|
||||
|
||||
@property
|
||||
def project_root(self) -> Path:
|
||||
"""Get project root path."""
|
||||
if self.root_path:
|
||||
return Path(self.root_path)
|
||||
return _get_project_root()
|
||||
|
||||
@property
|
||||
def data_path_resolved(self) -> Path:
|
||||
"""Get resolved data path (absolute)."""
|
||||
path = Path(self.data_path)
|
||||
if path.is_absolute():
|
||||
return path
|
||||
return self.project_root / path
|
||||
|
||||
class Config:
|
||||
# 从 config/ 目录读取 .env.local 文件
|
||||
env_file = str(CONFIG_DIR / ".env.local")
|
||||
env_file_encoding = "utf-8"
|
||||
case_sensitive = False
|
||||
extra = "ignore" # 忽略 .env.local 中的额外变量
|
||||
# pydantic-settings 默认会将字段名转换为大写作为环境变量名
|
||||
# 所以 tushare_token 会映射到 TUSHARE_TOKEN
|
||||
# root_path 会映射到 ROOT_PATH
|
||||
# data_path 会映射到 DATA_PATH
|
||||
|
||||
|
||||
# Global config instance
|
||||
config = Config()
|
||||
|
||||
|
||||
def get_config() -> Config:
|
||||
"""Get configuration instance."""
|
||||
return config
|
||||
|
||||
|
||||
def get_project_root() -> Path:
|
||||
"""Get project root path (convenience function)."""
|
||||
return get_config().project_root
|
||||
@@ -32,9 +32,12 @@ def get_db_info(db_path: Optional[Path] = None):
|
||||
|
||||
# Get database path
|
||||
if db_path is None:
|
||||
from src.data.config import get_config
|
||||
from src.config.settings import get_settings
|
||||
|
||||
cfg = get_config()
|
||||
cfg = get_settings()
|
||||
db_path = cfg.data_path_resolved / "prostock.db"
|
||||
|
||||
cfg = get_settings()
|
||||
db_path = cfg.data_path_resolved / "prostock.db"
|
||||
else:
|
||||
db_path = Path(db_path)
|
||||
@@ -231,9 +234,12 @@ def get_table_sample(table_name: str, limit: int = 5, db_path: Optional[Path] =
|
||||
db_path: Path to database file
|
||||
"""
|
||||
if db_path is None:
|
||||
from src.data.config import get_config
|
||||
from src.config.settings import get_settings
|
||||
|
||||
cfg = get_config()
|
||||
cfg = get_settings()
|
||||
db_path = cfg.data_path_resolved / "prostock.db"
|
||||
|
||||
cfg = get_settings()
|
||||
db_path = cfg.data_path_resolved / "prostock.db"
|
||||
else:
|
||||
db_path = Path(db_path)
|
||||
|
||||
@@ -1,35 +1,35 @@
|
||||
"""Token bucket rate limiter implementation.
|
||||
"""API 速率限制器实现。
|
||||
|
||||
This module provides a thread-safe token bucket algorithm for rate limiting.
|
||||
提供基于固定时间窗口的速率限制,适合 Tushare 等按分钟计费的 API。
|
||||
"""
|
||||
|
||||
import time
|
||||
import threading
|
||||
from typing import Optional
|
||||
from dataclasses import dataclass, field
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class RateLimiterStats:
|
||||
"""Statistics for rate limiter."""
|
||||
"""速率限制器统计信息。"""
|
||||
|
||||
total_requests: int = 0
|
||||
successful_requests: int = 0
|
||||
denied_requests: int = 0
|
||||
total_wait_time: float = 0.0
|
||||
current_tokens: Optional[float] = None
|
||||
current_window_requests: int = 0
|
||||
window_start_time: float = 0.0
|
||||
|
||||
|
||||
class TokenBucketRateLimiter:
|
||||
"""Thread-safe token bucket rate limiter.
|
||||
"""基于固定时间窗口的速率限制器。
|
||||
|
||||
Implements a token bucket algorithm for controlling request rate.
|
||||
Tokens are added at a fixed rate up to the bucket capacity.
|
||||
适合 Tushare 等按时间窗口(如每分钟)限制请求数的 API 场景。
|
||||
在窗口期内,请求数达到上限后将阻塞或等待下一个窗口。
|
||||
|
||||
Attributes:
|
||||
capacity: Maximum number of tokens in the bucket
|
||||
refill_rate: Number of tokens added per second
|
||||
initial_tokens: Initial number of tokens (default: capacity)
|
||||
capacity: 每个时间窗口内允许的最大请求数
|
||||
window_seconds: 时间窗口长度(秒)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -38,155 +38,157 @@ class TokenBucketRateLimiter:
|
||||
refill_rate_per_second: float = 1.67,
|
||||
initial_tokens: Optional[int] = None,
|
||||
) -> None:
|
||||
"""Initialize the token bucket rate limiter.
|
||||
"""初始化速率限制器。
|
||||
|
||||
Args:
|
||||
capacity: Maximum token capacity
|
||||
refill_rate_per_second: Token refill rate per second
|
||||
initial_tokens: Initial token count (default: capacity)
|
||||
capacity: 每个时间窗口内允许的最大请求数
|
||||
refill_rate_per_second: 保留参数(向后兼容),实际使用 window_seconds=60
|
||||
initial_tokens: 保留参数(向后兼容)
|
||||
"""
|
||||
self.capacity = capacity
|
||||
self.refill_rate = refill_rate_per_second
|
||||
self.tokens = float(initial_tokens if initial_tokens is not None else capacity)
|
||||
self.last_refill_time = time.monotonic()
|
||||
# Tushare 通常按分钟限制,所以固定使用 60 秒窗口
|
||||
self.window_seconds = 60.0
|
||||
|
||||
self._requests_in_window = 0
|
||||
self._window_start = time.monotonic()
|
||||
self._lock = threading.RLock()
|
||||
self._stats = RateLimiterStats()
|
||||
self._stats.current_tokens = self.tokens
|
||||
self._stats.window_start_time = self._window_start
|
||||
|
||||
def _is_new_window(self) -> bool:
|
||||
"""检查是否已进入新的时间窗口。"""
|
||||
current_time = time.monotonic()
|
||||
elapsed = current_time - self._window_start
|
||||
return elapsed >= self.window_seconds
|
||||
|
||||
def _reset_window(self) -> None:
|
||||
"""重置时间窗口。"""
|
||||
self._window_start = time.monotonic()
|
||||
self._requests_in_window = 0
|
||||
self._stats.window_start_time = self._window_start
|
||||
|
||||
def acquire(self, timeout: float = float("inf")) -> tuple[bool, float]:
|
||||
"""Acquire a token from the bucket.
|
||||
"""获取请求许可。
|
||||
|
||||
Blocks until a token is available or timeout expires.
|
||||
如果在当前窗口内请求数已达上限,则等待到下一个窗口。
|
||||
|
||||
Args:
|
||||
timeout: Maximum time to wait for a token in seconds (default: inf)
|
||||
timeout: 最大等待时间(秒),默认无限等待
|
||||
|
||||
Returns:
|
||||
Tuple of (success, wait_time):
|
||||
- success: True if token was acquired, False if timed out
|
||||
- wait_time: Time spent waiting for token
|
||||
(success, wait_time): 是否成功获取许可,以及等待时间
|
||||
"""
|
||||
start_time = time.monotonic()
|
||||
wait_time = 0.0
|
||||
|
||||
with self._lock:
|
||||
self._refill()
|
||||
# 检查是否需要进入新窗口
|
||||
if self._is_new_window():
|
||||
self._reset_window()
|
||||
|
||||
if self.tokens >= 1:
|
||||
self.tokens -= 1
|
||||
# 如果当前窗口还有余量,直接通过
|
||||
if self._requests_in_window < self.capacity:
|
||||
self._requests_in_window += 1
|
||||
self._stats.total_requests += 1
|
||||
self._stats.successful_requests += 1
|
||||
self._stats.current_tokens = self.tokens
|
||||
self._stats.current_window_requests = self._requests_in_window
|
||||
return True, 0.0
|
||||
|
||||
# Calculate time to wait for next token
|
||||
tokens_needed = 1 - self.tokens
|
||||
time_to_refill = tokens_needed / self.refill_rate
|
||||
# 当前窗口已满,计算需要等待的时间
|
||||
current_time = time.monotonic()
|
||||
time_to_next_window = self.window_seconds - (
|
||||
current_time - self._window_start
|
||||
)
|
||||
|
||||
# Check if we can wait for the token within timeout
|
||||
# Handle infinite timeout specially
|
||||
is_infinite_timeout = timeout == float("inf")
|
||||
if not is_infinite_timeout and time_to_refill > timeout:
|
||||
if time_to_next_window <= 0:
|
||||
# 刚好进入新窗口
|
||||
self._reset_window()
|
||||
self._requests_in_window = 1
|
||||
self._stats.total_requests += 1
|
||||
self._stats.successful_requests += 1
|
||||
self._stats.current_window_requests = 1
|
||||
return True, 0.0
|
||||
|
||||
# 检查是否能在超时时间内等待
|
||||
if timeout != float("inf") and time_to_next_window > timeout:
|
||||
self._stats.total_requests += 1
|
||||
self._stats.denied_requests += 1
|
||||
return False, timeout
|
||||
|
||||
# Wait for tokens - loop until we get one or timeout
|
||||
while True:
|
||||
# Calculate remaining time we can wait
|
||||
elapsed = time.monotonic() - start_time
|
||||
remaining_timeout = (
|
||||
timeout - elapsed if not is_infinite_timeout else float("inf")
|
||||
)
|
||||
# 需要等待到下一个窗口
|
||||
if timeout != float("inf"):
|
||||
time_to_wait = min(time_to_next_window, timeout)
|
||||
else:
|
||||
time_to_wait = time_to_next_window
|
||||
|
||||
# Check if we've exceeded timeout
|
||||
if not is_infinite_timeout and remaining_timeout <= 0:
|
||||
self._stats.total_requests += 1
|
||||
self._stats.denied_requests += 1
|
||||
return False, elapsed
|
||||
time.sleep(time_to_wait)
|
||||
|
||||
# Calculate wait time for next token
|
||||
tokens_needed = max(0, 1 - self.tokens)
|
||||
time_to_wait = (
|
||||
tokens_needed / self.refill_rate if tokens_needed > 0 else 0.1
|
||||
)
|
||||
|
||||
# If we can't wait long enough, fail
|
||||
if not is_infinite_timeout and time_to_wait > remaining_timeout:
|
||||
self._stats.total_requests += 1
|
||||
self._stats.denied_requests += 1
|
||||
return False, elapsed
|
||||
|
||||
# Wait outside the lock to allow other threads to refill
|
||||
self._lock.release()
|
||||
time.sleep(
|
||||
min(time_to_wait, 0.1)
|
||||
) # Cap wait to 100ms to check frequently
|
||||
self._lock.acquire()
|
||||
|
||||
# Refill and check again
|
||||
self._refill()
|
||||
if self.tokens >= 1:
|
||||
self.tokens -= 1
|
||||
wait_time = time.monotonic() - start_time
|
||||
self._stats.total_requests += 1
|
||||
self._stats.successful_requests += 1
|
||||
self._stats.total_wait_time += wait_time
|
||||
self._stats.current_tokens = self.tokens
|
||||
return True, wait_time
|
||||
|
||||
def acquire_nonblocking(self) -> tuple[bool, float]:
|
||||
"""Try to acquire a token without blocking.
|
||||
|
||||
Returns:
|
||||
Tuple of (success, wait_time):
|
||||
- success: True if token was acquired, False otherwise
|
||||
- wait_time: 0 for non-blocking, or required wait time if failed
|
||||
"""
|
||||
# 重新尝试获取许可
|
||||
with self._lock:
|
||||
self._refill()
|
||||
# 再次检查窗口状态(可能其他线程已经重置了窗口)
|
||||
if self._is_new_window():
|
||||
self._reset_window()
|
||||
|
||||
if self.tokens >= 1:
|
||||
self.tokens -= 1
|
||||
if self._requests_in_window < self.capacity:
|
||||
self._requests_in_window += 1
|
||||
wait_time = time.monotonic() - start_time
|
||||
self._stats.total_requests += 1
|
||||
self._stats.successful_requests += 1
|
||||
self._stats.current_tokens = self.tokens
|
||||
self._stats.total_wait_time += wait_time
|
||||
self._stats.current_window_requests = self._requests_in_window
|
||||
return True, wait_time
|
||||
else:
|
||||
# 在极端情况下,等待后仍然无法获取(其他线程抢先)
|
||||
wait_time = time.monotonic() - start_time
|
||||
self._stats.total_requests += 1
|
||||
self._stats.denied_requests += 1
|
||||
return False, wait_time
|
||||
|
||||
def acquire_nonblocking(self) -> tuple[bool, float]:
|
||||
"""尝试非阻塞地获取请求许可。
|
||||
|
||||
Returns:
|
||||
(success, wait_time): 是否成功获取许可,以及需要等待的时间
|
||||
"""
|
||||
with self._lock:
|
||||
# 检查是否需要进入新窗口
|
||||
if self._is_new_window():
|
||||
self._reset_window()
|
||||
|
||||
# 如果当前窗口还有余量,直接通过
|
||||
if self._requests_in_window < self.capacity:
|
||||
self._requests_in_window += 1
|
||||
self._stats.total_requests += 1
|
||||
self._stats.successful_requests += 1
|
||||
self._stats.current_window_requests = self._requests_in_window
|
||||
return True, 0.0
|
||||
|
||||
# Calculate time needed
|
||||
tokens_needed = 1 - self.tokens
|
||||
time_to_refill = tokens_needed / self.refill_rate
|
||||
# 当前窗口已满,计算需要等待的时间
|
||||
current_time = time.monotonic()
|
||||
time_to_next_window = self.window_seconds - (
|
||||
current_time - self._window_start
|
||||
)
|
||||
|
||||
self._stats.total_requests += 1
|
||||
self._stats.denied_requests += 1
|
||||
return False, time_to_refill
|
||||
|
||||
def _refill(self) -> None:
|
||||
"""Refill tokens based on elapsed time."""
|
||||
current_time = time.monotonic()
|
||||
elapsed = current_time - self.last_refill_time
|
||||
self.last_refill_time = current_time
|
||||
|
||||
tokens_to_add = elapsed * self.refill_rate
|
||||
self.tokens = min(self.capacity, self.tokens + tokens_to_add)
|
||||
return False, max(0.0, time_to_next_window)
|
||||
|
||||
def get_available_tokens(self) -> float:
|
||||
"""Get the current number of available tokens.
|
||||
"""获取当前窗口剩余可用请求数。
|
||||
|
||||
Returns:
|
||||
Current token count
|
||||
当前窗口剩余可用请求数
|
||||
"""
|
||||
with self._lock:
|
||||
self._refill()
|
||||
return self.tokens
|
||||
if self._is_new_window():
|
||||
return float(self.capacity)
|
||||
return float(self.capacity - self._requests_in_window)
|
||||
|
||||
def get_stats(self) -> RateLimiterStats:
|
||||
"""Get rate limiter statistics.
|
||||
"""获取速率限制器统计信息。
|
||||
|
||||
Returns:
|
||||
RateLimiterStats instance
|
||||
RateLimiterStats 实例
|
||||
"""
|
||||
with self._lock:
|
||||
self._refill()
|
||||
self._stats.current_tokens = self.tokens
|
||||
self._stats.current_window_requests = self._requests_in_window
|
||||
return self._stats
|
||||
|
||||
@@ -6,7 +6,7 @@ from pathlib import Path
|
||||
from typing import Optional, List, Dict, Any, Tuple
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from src.data.config import get_config
|
||||
from src.config.settings import get_settings
|
||||
|
||||
|
||||
# Default column type mapping for automatic schema inference
|
||||
@@ -53,7 +53,7 @@ class Storage:
|
||||
if hasattr(self, "_initialized"):
|
||||
return
|
||||
|
||||
cfg = get_config()
|
||||
cfg = get_settings()
|
||||
self.base_path = path or cfg.data_path_resolved
|
||||
self.base_path.mkdir(parents=True, exist_ok=True)
|
||||
self.db_path = self.base_path / "prostock.db"
|
||||
@@ -190,6 +190,26 @@ class Storage:
|
||||
update_flag VARCHAR(1),
|
||||
PRIMARY KEY (ts_code, end_date)
|
||||
)
|
||||
|
||||
# Create pro_bar table for pro bar data (with adj, tor, vr)
|
||||
self._connection.execute("""
|
||||
CREATE TABLE IF NOT EXISTS pro_bar (
|
||||
ts_code VARCHAR(16) NOT NULL,
|
||||
trade_date DATE NOT NULL,
|
||||
open DOUBLE,
|
||||
high DOUBLE,
|
||||
low DOUBLE,
|
||||
close DOUBLE,
|
||||
pre_close DOUBLE,
|
||||
change DOUBLE,
|
||||
pct_chg DOUBLE,
|
||||
vol DOUBLE,
|
||||
amount DOUBLE,
|
||||
tor DOUBLE,
|
||||
vr DOUBLE,
|
||||
adj_factor DOUBLE,
|
||||
PRIMARY KEY (ts_code, trade_date)
|
||||
)
|
||||
""")
|
||||
|
||||
# Create index for financial_income
|
||||
|
||||
@@ -29,6 +29,7 @@ import pandas as pd
|
||||
|
||||
from src.data.api_wrappers import sync_all_stocks
|
||||
from src.data.api_wrappers.api_daily import sync_daily, preview_daily_sync
|
||||
from src.data.api_wrappers.api_pro_bar import sync_pro_bar
|
||||
|
||||
|
||||
def preview_sync(
|
||||
@@ -134,7 +135,6 @@ def sync_all_data(
|
||||
dry_run: bool = False,
|
||||
) -> Dict[str, pd.DataFrame]:
|
||||
"""同步所有数据类型(每日同步)。
|
||||
|
||||
该函数按顺序同步所有可用的数据类型:
|
||||
1. 交易日历 (sync_trade_cal_cache)
|
||||
2. 股票基本信息 (sync_all_stocks)
|
||||
@@ -146,13 +146,12 @@ def sync_all_data(
|
||||
Args:
|
||||
force_full: 若为 True,强制所有数据类型完整重载
|
||||
max_workers: 日线数据同步的工作线程数(默认: 10)
|
||||
dry_run: 若为 True,仅显示将要同步的内容 Returns:
|
||||
映射数据类型,不写入数据
|
||||
dry_run: 若为 True,仅显示将要同步的内容,不写入数据
|
||||
|
||||
到同步结果的字典
|
||||
Returns:
|
||||
映射数据类型到同步结果的字典
|
||||
|
||||
Example:
|
||||
>>> # 同步所有数据(增量)
|
||||
>>> result = sync_all_data()
|
||||
>>>
|
||||
>>> # 强制完整重载
|
||||
@@ -167,6 +166,92 @@ def sync_all_data(
|
||||
print("[sync_all_data] Starting full data synchronization...")
|
||||
print("=" * 60)
|
||||
|
||||
# 1. Sync trade calendar (always needed first)
|
||||
print("\n[1/6] Syncing trade calendar cache...")
|
||||
try:
|
||||
from src.data.api_wrappers import sync_trade_cal_cache
|
||||
|
||||
sync_trade_cal_cache()
|
||||
results["trade_cal"] = pd.DataFrame()
|
||||
print("[1/6] Trade calendar: OK")
|
||||
except Exception as e:
|
||||
print(f"[1/6] Trade calendar: FAILED - {e}")
|
||||
results["trade_cal"] = pd.DataFrame()
|
||||
|
||||
# 2. Sync stock basic info
|
||||
print("\n[2/6] Syncing stock basic info...")
|
||||
try:
|
||||
sync_all_stocks()
|
||||
results["stock_basic"] = pd.DataFrame()
|
||||
print("[2/6] Stock basic: OK")
|
||||
except Exception as e:
|
||||
print(f"[2/6] Stock basic: FAILED - {e}")
|
||||
results["stock_basic"] = pd.DataFrame()
|
||||
|
||||
# # 3. Sync daily market data
|
||||
# print("\n[3/6] Syncing daily market data...")
|
||||
# try:
|
||||
# daily_result = sync_daily(
|
||||
# force_full=force_full,
|
||||
# max_workers=max_workers,
|
||||
# dry_run=dry_run,
|
||||
# )
|
||||
# results["daily"] = (
|
||||
# pd.concat(daily_result.values(), ignore_index=True)
|
||||
# if daily_result
|
||||
# else pd.DataFrame()
|
||||
# )
|
||||
# print("[3/6] Daily data: OK")
|
||||
# except Exception as e:
|
||||
# print(f"[3/6] Daily data: FAILED - {e}")
|
||||
# results["daily"] = pd.DataFrame()
|
||||
|
||||
# 4. Sync Pro Bar data
|
||||
print("\n[4/6] Syncing Pro Bar data (with adj, tor, vr)...")
|
||||
try:
|
||||
pro_bar_result = sync_pro_bar(
|
||||
force_full=force_full,
|
||||
max_workers=max_workers,
|
||||
dry_run=dry_run,
|
||||
)
|
||||
results["pro_bar"] = (
|
||||
pd.concat(pro_bar_result.values(), ignore_index=True)
|
||||
if pro_bar_result
|
||||
else pd.DataFrame()
|
||||
)
|
||||
print(f"[4/6] Pro Bar data: OK ({len(results['pro_bar'])} records)")
|
||||
except Exception as e:
|
||||
print(f"[4/6] Pro Bar data: FAILED - {e}")
|
||||
results["pro_bar"] = pd.DataFrame()
|
||||
|
||||
# 5. Sync stock historical list (bak_basic)
|
||||
print("\n[5/6] Syncing stock historical list (bak_basic)...")
|
||||
try:
|
||||
bak_basic_result = sync_bak_basic(force_full=force_full)
|
||||
results["bak_basic"] = bak_basic_result
|
||||
print(f"[5/6] Bak basic: OK ({len(bak_basic_result)} records)")
|
||||
except Exception as e:
|
||||
print(f"[5/6] Bak basic: FAILED - {e}")
|
||||
results["bak_basic"] = pd.DataFrame()
|
||||
|
||||
# Summary
|
||||
print("\n" + "=" * 60)
|
||||
print("[sync_all_data] Sync Summary")
|
||||
print("=" * 60)
|
||||
for data_type, df in results.items():
|
||||
print(f" {data_type}: {len(df)} records")
|
||||
print("=" * 60)
|
||||
print("\nNote: namechange is NOT in auto-sync. To sync manually:")
|
||||
print(" from src.data.api_wrappers import sync_namechange")
|
||||
print(" sync_namechange(force=True)")
|
||||
|
||||
return results
|
||||
results: Dict[str, pd.DataFrame] = {}
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("[sync_all_data] Starting full data synchronization...")
|
||||
print("=" * 60)
|
||||
|
||||
# 1. Sync trade calendar (always needed first)
|
||||
print("\n[1/5] Syncing trade calendar cache...")
|
||||
try:
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,118 +0,0 @@
|
||||
"""ProStock 因子计算框架
|
||||
|
||||
因子框架提供以下核心功能:
|
||||
1. 类型安全的因子定义(截面因子、时序因子)
|
||||
2. 数据泄露防护机制
|
||||
3. 因子组合和运算
|
||||
4. 高效的数据加载和计算引擎
|
||||
|
||||
基础数据类型(Phase 1):
|
||||
- DataSpec: 数据需求规格
|
||||
- FactorContext: 计算上下文
|
||||
- FactorData: 数据容器
|
||||
|
||||
因子基类(Phase 2):
|
||||
- BaseFactor: 抽象基类
|
||||
- CrossSectionalFactor: 日期截面因子基类
|
||||
- TimeSeriesFactor: 时间序列因子基类
|
||||
- CompositeFactor: 组合因子
|
||||
- ScalarFactor: 标量运算因子
|
||||
|
||||
因子分类目录:
|
||||
- momentum/: 动量因子(MA、收益率排名等)
|
||||
- financial/: 财务因子(EPS、ROE等)
|
||||
- valuation/: 估值因子(PE、PB、PS等)
|
||||
- technical/: 技术指标因子(RSI、MACD、布林带等)
|
||||
- quality/: 质量因子(盈利能力、稳定性等)
|
||||
- sentiment/: 情绪因子(换手率、资金流向等)
|
||||
- volume/: 成交量因子(OBV、成交量比率等)
|
||||
- volatility/: 波动率因子(历史波动率、GARCH等)
|
||||
|
||||
数据加载和执行(Phase 3-4):
|
||||
- DataLoader: 数据加载器
|
||||
- FactorEngine: 因子执行引擎
|
||||
|
||||
使用示例:
|
||||
# 使用通用因子(参数化)
|
||||
from src.factors import MovingAverageFactor, ReturnRankFactor
|
||||
from src.factors import DataLoader, FactorEngine
|
||||
|
||||
ma5 = MovingAverageFactor(period=5) # 5日MA
|
||||
ma10 = MovingAverageFactor(period=10) # 10日MA
|
||||
ret5 = ReturnRankFactor(period=5) # 5日收益率排名
|
||||
|
||||
loader = DataLoader(data_dir="data")
|
||||
engine = FactorEngine(loader)
|
||||
result = engine.compute(ma5, stock_codes=["000001.SZ"], start_date="20240101", end_date="20240131")
|
||||
"""
|
||||
|
||||
因子框架提供以下核心功能:
|
||||
1. 类型安全的因子定义(截面因子、时序因子)
|
||||
2. 数据泄露防护机制
|
||||
3. 因子组合和运算
|
||||
4. 高效的数据加载和计算引擎
|
||||
|
||||
基础数据类型(Phase 1):
|
||||
- DataSpec: 数据需求规格
|
||||
- FactorContext: 计算上下文
|
||||
- FactorData: 数据容器
|
||||
|
||||
因子基类(Phase 2):
|
||||
- BaseFactor: 抽象基类
|
||||
- CrossSectionalFactor: 日期截面因子基类
|
||||
- TimeSeriesFactor: 时间序列因子基类
|
||||
- CompositeFactor: 组合因子
|
||||
- ScalarFactor: 标量运算因子
|
||||
|
||||
动量因子(momentum/):
|
||||
- MovingAverageFactor: 移动平均线(时序因子)
|
||||
- ReturnRankFactor: 收益率排名(截面因子)
|
||||
|
||||
财务因子(financial/):
|
||||
- (待添加)
|
||||
|
||||
数据加载和执行(Phase 3-4):
|
||||
- DataLoader: 数据加载器
|
||||
- FactorEngine: 因子执行引擎
|
||||
|
||||
使用示例:
|
||||
# 使用通用因子(参数化)
|
||||
from src.factors import MovingAverageFactor, ReturnRankFactor
|
||||
from src.factors import DataLoader, FactorEngine
|
||||
|
||||
ma5 = MovingAverageFactor(period=5) # 5日MA
|
||||
ma10 = MovingAverageFactor(period=10) # 10日MA
|
||||
ret5 = ReturnRankFactor(period=5) # 5日收益率排名
|
||||
|
||||
loader = DataLoader(data_dir="data")
|
||||
engine = FactorEngine(loader)
|
||||
result = engine.compute(ma5, stock_codes=["000001.SZ"], start_date="20240101", end_date="20240131")
|
||||
"""
|
||||
|
||||
from src.factors.data_spec import DataSpec, FactorContext, FactorData
|
||||
from src.factors.base import BaseFactor, CrossSectionalFactor, TimeSeriesFactor
|
||||
from src.factors.composite import CompositeFactor, ScalarFactor
|
||||
from src.factors.data_loader import DataLoader
|
||||
from src.factors.engine import FactorEngine
|
||||
|
||||
# 动量因子
|
||||
from src.factors.momentum import MovingAverageFactor, ReturnRankFactor
|
||||
|
||||
__all__ = [
|
||||
# Phase 1: 数据类型定义
|
||||
"DataSpec",
|
||||
"FactorContext",
|
||||
"FactorData",
|
||||
# Phase 2: 因子基类
|
||||
"BaseFactor",
|
||||
"CrossSectionalFactor",
|
||||
"TimeSeriesFactor",
|
||||
"CompositeFactor",
|
||||
"ScalarFactor",
|
||||
# Phase 3-4: 数据加载和执行引擎
|
||||
"DataLoader",
|
||||
"FactorEngine",
|
||||
# 动量因子
|
||||
"MovingAverageFactor",
|
||||
"ReturnRankFactor",
|
||||
]
|
||||
@@ -1,274 +0,0 @@
|
||||
"""因子基类 - Phase 2 核心抽象类
|
||||
|
||||
本模块定义了因子框架的基类:
|
||||
- BaseFactor: 抽象基类,定义通用接口和验证逻辑
|
||||
- CrossSectionalFactor: 日期截面因子基类(防止日期泄露)
|
||||
- TimeSeriesFactor: 时间序列因子基类(防止股票泄露)
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import field
|
||||
from typing import List
|
||||
|
||||
import polars as pl
|
||||
|
||||
from src.factors.data_spec import DataSpec, FactorData
|
||||
|
||||
|
||||
class BaseFactor(ABC):
|
||||
"""因子基类 - 定义通用接口
|
||||
|
||||
所有因子必须继承此类,并声明以下类属性:
|
||||
- name: 因子唯一标识(snake_case)
|
||||
- factor_type: "cross_sectional" 或 "time_series"
|
||||
- data_specs: List[DataSpec] 数据需求列表
|
||||
|
||||
可选声明:
|
||||
- category: 因子分类(默认 "default")
|
||||
- description: 因子描述
|
||||
|
||||
示例:
|
||||
>>> class MyFactor(CrossSectionalFactor):
|
||||
... name = "my_factor"
|
||||
... data_specs = [DataSpec("daily", ["close"], lookback_days=5)]
|
||||
...
|
||||
... def compute(self, data: FactorData) -> pl.Series:
|
||||
... return data.get_column("close").rank()
|
||||
"""
|
||||
|
||||
# 必须声明的类属性
|
||||
name: str = ""
|
||||
factor_type: str = "" # "cross_sectional" | "time_series"
|
||||
data_specs: List[DataSpec] = field(default_factory=list)
|
||||
|
||||
# 可选声明的类属性
|
||||
category: str = "default"
|
||||
description: str = ""
|
||||
|
||||
def __init_subclass__(cls, **kwargs):
|
||||
"""子类创建时验证必须属性
|
||||
|
||||
验证项:
|
||||
1. name 必须是非空字符串
|
||||
2. factor_type 必须是 "cross_sectional" 或 "time_series"
|
||||
3. data_specs 必须是非空列表
|
||||
"""
|
||||
super().__init_subclass__(**kwargs)
|
||||
|
||||
# 跳过抽象基类和特殊因子类的验证
|
||||
if cls.__name__ in (
|
||||
"CrossSectionalFactor",
|
||||
"TimeSeriesFactor",
|
||||
"CompositeFactor",
|
||||
"ScalarFactor",
|
||||
):
|
||||
return
|
||||
|
||||
# 验证 name - 必须直接定义在类中(不能继承)
|
||||
if "name" not in cls.__dict__ or not cls.name:
|
||||
raise ValueError(f"Factor {cls.__name__} must define 'name'")
|
||||
if not isinstance(cls.name, str):
|
||||
raise ValueError(f"Factor {cls.__name__}.name must be a string")
|
||||
|
||||
# 验证 factor_type - 必须有值(可以是继承的)
|
||||
if not cls.factor_type:
|
||||
raise ValueError(f"Factor {cls.__name__} must define 'factor_type'")
|
||||
if cls.factor_type not in ("cross_sectional", "time_series"):
|
||||
raise ValueError(
|
||||
f"Factor {cls.__name__}.factor_type must be 'cross_sectional' "
|
||||
f"or 'time_series', got '{cls.factor_type}'"
|
||||
)
|
||||
|
||||
# 验证 data_specs
|
||||
# 情况1: 完全没有定义 data_specs(继承的空列表)
|
||||
if "data_specs" not in cls.__dict__:
|
||||
raise ValueError(f"Factor {cls.__name__} must define 'data_specs'")
|
||||
# 情况2: 定义了但为空列表
|
||||
if not cls.data_specs or len(cls.data_specs) == 0:
|
||||
raise ValueError(f"Factor {cls.__name__}.data_specs cannot be empty")
|
||||
if not isinstance(cls.data_specs, list):
|
||||
raise ValueError(f"Factor {cls.__name__}.data_specs must be a list")
|
||||
|
||||
def __init__(self, **params):
|
||||
"""初始化因子参数
|
||||
|
||||
子类可通过 __init__ 接收参数化配置,如 MA(period=20)
|
||||
|
||||
注意:data_specs 必须在类级别定义(类属性),
|
||||
而非在 __init__ 中设置。data_specs 的验证在
|
||||
__init_subclass__ 中完成(类创建时)。
|
||||
|
||||
Args:
|
||||
**params: 因子参数,存储在 self.params 中
|
||||
"""
|
||||
self.params = params
|
||||
|
||||
def _validate_params(self):
|
||||
"""验证参数有效性
|
||||
|
||||
子类可覆盖此方法进行自定义验证(需自行在子类 __init__ 中调用)。
|
||||
基类实现为空,表示不执行任何验证。
|
||||
|
||||
注意:由于 data_specs 在类创建时通过 __init_subclass__ 验证,
|
||||
不应在实例级别修改。如需动态 data_specs,请使用参数化模式:
|
||||
|
||||
>>> class ParamFactor(TimeSeriesFactor):
|
||||
... name = "param_factor"
|
||||
... data_specs = [] # 类级别定义
|
||||
...
|
||||
... def __init__(self, period: int = 20):
|
||||
... super().__init__(period=period)
|
||||
... # 通过参数化改变计算逻辑,而非 data_specs
|
||||
...
|
||||
... def compute(self, data: FactorData) -> pl.Series:
|
||||
... return data.get_column("close").rolling_mean(self.params["period"])
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def compute(self, data: FactorData) -> pl.Series:
|
||||
"""核心计算逻辑 - 子类必须实现
|
||||
|
||||
Args:
|
||||
data: 安全的数据容器,已根据因子类型裁剪
|
||||
|
||||
Returns:
|
||||
计算得到的因子值 Series
|
||||
"""
|
||||
pass
|
||||
|
||||
# ========== 因子组合运算符 ==========
|
||||
|
||||
def __add__(self, other: "BaseFactor") -> "CompositeFactor":
|
||||
"""因子相加:f1 + f2(要求同类型)"""
|
||||
from src.factors.composite import CompositeFactor
|
||||
|
||||
return CompositeFactor(self, other, "+")
|
||||
|
||||
def __sub__(self, other: "BaseFactor") -> "CompositeFactor":
|
||||
"""因子相减:f1 - f2(要求同类型)"""
|
||||
from src.factors.composite import CompositeFactor
|
||||
|
||||
return CompositeFactor(self, other, "-")
|
||||
|
||||
def __mul__(self, other):
|
||||
"""因子相乘:f1 * f2 或 f1 * scalar"""
|
||||
if isinstance(other, (int, float)):
|
||||
from src.factors.composite import ScalarFactor
|
||||
|
||||
return ScalarFactor(self, float(other), "*")
|
||||
elif isinstance(other, BaseFactor):
|
||||
from src.factors.composite import CompositeFactor
|
||||
|
||||
return CompositeFactor(self, other, "*")
|
||||
return NotImplemented
|
||||
|
||||
def __truediv__(self, other: "BaseFactor") -> "CompositeFactor":
|
||||
"""因子相除:f1 / f2(要求同类型)"""
|
||||
from src.factors.composite import CompositeFactor
|
||||
|
||||
return CompositeFactor(self, other, "/")
|
||||
|
||||
def __rmul__(self, scalar: float) -> "ScalarFactor":
|
||||
"""标量乘法:0.5 * f1"""
|
||||
from src.factors.composite import ScalarFactor
|
||||
|
||||
return ScalarFactor(self, scalar, "*")
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""返回因子的字符串表示"""
|
||||
return (
|
||||
f"{self.__class__.__name__}(name='{self.name}', type='{self.factor_type}')"
|
||||
)
|
||||
|
||||
|
||||
class CrossSectionalFactor(BaseFactor):
|
||||
"""日期截面因子基类
|
||||
|
||||
计算逻辑:在每个交易日,对所有股票进行横向计算
|
||||
|
||||
防泄露边界:
|
||||
- ❌ 禁止访问未来日期的数据(日期泄露)
|
||||
- ✅ 允许访问当前日期的所有股票数据
|
||||
|
||||
数据传入:
|
||||
- compute() 接收的是 [T-lookback+1, T] 的数据
|
||||
- 包含 lookback_days 的历史数据(用于时序计算后再截面)
|
||||
|
||||
示例:
|
||||
>>> class PERankFactor(CrossSectionalFactor):
|
||||
... name = "pe_rank"
|
||||
... data_specs = [DataSpec("daily", ["pe"], lookback_days=1)]
|
||||
...
|
||||
... def compute(self, data: FactorData) -> pl.Series:
|
||||
... cs = data.get_cross_section()
|
||||
... return cs["pe"].rank()
|
||||
"""
|
||||
|
||||
factor_type: str = "cross_sectional"
|
||||
|
||||
@abstractmethod
|
||||
def compute(self, data: FactorData) -> pl.Series:
|
||||
"""计算截面因子值
|
||||
|
||||
Args:
|
||||
data: FactorData,包含 [T-lookback+1, T] 的截面数据
|
||||
格式:DataFrame[ts_code, trade_date, col1, col2, ...]
|
||||
|
||||
Returns:
|
||||
pl.Series: 当前日期所有股票的因子值(长度 = 该日股票数量)
|
||||
|
||||
示例:
|
||||
>>> def compute(self, data):
|
||||
... # 获取当前日期截面
|
||||
... cs = data.get_cross_section()
|
||||
... # 计算市值排名
|
||||
... return cs['market_cap'].rank()
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class TimeSeriesFactor(BaseFactor):
|
||||
"""时间序列因子基类(股票截面)
|
||||
|
||||
计算逻辑:对每只股票,在其时间序列上进行纵向计算
|
||||
|
||||
防泄露边界:
|
||||
- ❌ 禁止访问其他股票的数据(股票泄露)
|
||||
- ✅ 允许访问该股票的完整历史数据
|
||||
|
||||
数据传入:
|
||||
- compute() 接收的是单只股票的完整时间序列
|
||||
- 包含该股票在 [start_date, end_date] 范围内的所有数据
|
||||
|
||||
示例:
|
||||
>>> class MovingAverageFactor(TimeSeriesFactor):
|
||||
... name = "ma"
|
||||
...
|
||||
... def __init__(self, period: int = 20):
|
||||
... super().__init__(period=period)
|
||||
... self.data_specs = [DataSpec("daily", ["close"], lookback_days=period)]
|
||||
...
|
||||
... def compute(self, data: FactorData) -> pl.Series:
|
||||
... return data.get_column("close").rolling_mean(self.params["period"])
|
||||
"""
|
||||
|
||||
factor_type: str = "time_series"
|
||||
|
||||
@abstractmethod
|
||||
def compute(self, data: FactorData) -> pl.Series:
|
||||
"""计算时间序列因子值
|
||||
|
||||
Args:
|
||||
data: FactorData,包含单只股票的完整时间序列
|
||||
格式:DataFrame[ts_code, trade_date, col1, col2, ...]
|
||||
|
||||
Returns:
|
||||
pl.Series: 该股票在各日期的因子值(长度 = 日期数量)
|
||||
|
||||
示例:
|
||||
>>> def compute(self, data):
|
||||
... series = data.get_column("close")
|
||||
... return series.rolling_mean(window_size=self.params['period'])
|
||||
"""
|
||||
pass
|
||||
@@ -1,201 +0,0 @@
|
||||
"""组合因子 - Phase 2 因子组合和标量运算
|
||||
|
||||
本模块定义了因子组合相关的类:
|
||||
- CompositeFactor: 组合因子,用于实现因子间的数学运算
|
||||
- ScalarFactor: 标量运算因子,支持因子与标量的运算
|
||||
"""
|
||||
|
||||
from typing import List
|
||||
|
||||
import polars as pl
|
||||
|
||||
from src.factors.data_spec import DataSpec, FactorData
|
||||
from src.factors.base import BaseFactor
|
||||
|
||||
|
||||
class CompositeFactor(BaseFactor):
|
||||
"""组合因子 - 用于实现因子间的数学运算
|
||||
|
||||
约束:左右因子必须是同类型(同为截面或同为时序)
|
||||
|
||||
支持的运算符:'+', '-', '*', '/'
|
||||
|
||||
示例:
|
||||
>>> f1 = SomeCrossSectionalFactor()
|
||||
>>> f2 = AnotherCrossSectionalFactor()
|
||||
>>> combined = f1 + f2 # 创建 CompositeFactor
|
||||
"""
|
||||
|
||||
def __init__(self, left: BaseFactor, right: BaseFactor, op: str):
|
||||
"""创建组合因子
|
||||
|
||||
Args:
|
||||
left: 左操作数因子
|
||||
right: 右操作数因子
|
||||
op: 运算符,支持 '+', '-', '*', '/'
|
||||
|
||||
Raises:
|
||||
ValueError: 左右因子类型不一致
|
||||
ValueError: 不支持的运算符
|
||||
"""
|
||||
# 验证类型一致性
|
||||
if left.factor_type != right.factor_type:
|
||||
raise ValueError(
|
||||
f"Cannot combine factors of different types: "
|
||||
f"'{left.factor_type}' vs '{right.factor_type}'"
|
||||
)
|
||||
|
||||
# 验证运算符
|
||||
if op not in ("+", "-", "*", "/"):
|
||||
raise ValueError(f"Unsupported operator: '{op}'")
|
||||
|
||||
self.left = left
|
||||
self.right = right
|
||||
self.op = op
|
||||
|
||||
# 设置类属性
|
||||
self.factor_type = left.factor_type
|
||||
self.name = f"({left.name}_{op}_{right.name})"
|
||||
self.data_specs = self._merge_data_specs()
|
||||
self.category = "composite"
|
||||
self.description = f"Composite factor: {left.name} {op} {right.name}"
|
||||
|
||||
# 注意:不调用 super().__init__(),因为 CompositeFactor 是特殊因子
|
||||
self.params = {
|
||||
"left": left,
|
||||
"right": right,
|
||||
"op": op,
|
||||
}
|
||||
|
||||
def _merge_data_specs(self) -> List[DataSpec]:
|
||||
"""合并左右因子的数据需求
|
||||
|
||||
策略:
|
||||
1. 相同 source 和 columns 的 DataSpec 合并
|
||||
2. lookback_days 取最大值
|
||||
|
||||
Returns:
|
||||
合并后的 DataSpec 列表
|
||||
"""
|
||||
merged = []
|
||||
|
||||
# 收集所有 specs
|
||||
all_specs = list(self.left.data_specs) + list(self.right.data_specs)
|
||||
|
||||
# 按 (source, columns_tuple) 分组
|
||||
spec_groups = {}
|
||||
for spec in all_specs:
|
||||
key = (spec.source, tuple(sorted(spec.columns)))
|
||||
if key not in spec_groups:
|
||||
spec_groups[key] = []
|
||||
spec_groups[key].append(spec)
|
||||
|
||||
# 合并每组,取最大 lookback_days
|
||||
for (source, columns_tuple), specs in spec_groups.items():
|
||||
max_lookback = max(spec.lookback_days for spec in specs)
|
||||
merged.append(
|
||||
DataSpec(
|
||||
source=source,
|
||||
columns=list(columns_tuple),
|
||||
lookback_days=max_lookback,
|
||||
)
|
||||
)
|
||||
|
||||
return merged
|
||||
|
||||
def compute(self, data: FactorData) -> pl.Series:
|
||||
"""执行组合运算
|
||||
|
||||
流程:
|
||||
1. 分别计算 left 和 right 的值
|
||||
2. 根据 op 执行运算
|
||||
3. 返回结果
|
||||
|
||||
Args:
|
||||
data: 包含左右因子所需数据的 FactorData
|
||||
|
||||
Returns:
|
||||
组合运算后的因子值 Series
|
||||
"""
|
||||
left_values = self.left.compute(data)
|
||||
right_values = self.right.compute(data)
|
||||
|
||||
ops = {
|
||||
"+": lambda a, b: a + b,
|
||||
"-": lambda a, b: a - b,
|
||||
"*": lambda a, b: a * b,
|
||||
"/": lambda a, b: a / b,
|
||||
}
|
||||
|
||||
return ops[self.op](left_values, right_values)
|
||||
|
||||
def _validate_params(self):
|
||||
"""CompositeFactor 不需要额外验证"""
|
||||
pass
|
||||
|
||||
|
||||
class ScalarFactor(BaseFactor):
|
||||
"""标量运算因子
|
||||
|
||||
支持:scalar * factor, factor * scalar(通过 __rmul__)
|
||||
|
||||
示例:
|
||||
>>> factor = SomeFactor()
|
||||
>>> scaled = 0.5 * factor # 创建 ScalarFactor
|
||||
"""
|
||||
|
||||
def __init__(self, factor: BaseFactor, scalar: float, op: str):
|
||||
"""创建标量运算因子
|
||||
|
||||
Args:
|
||||
factor: 基础因子
|
||||
scalar: 标量值
|
||||
op: 运算符,支持 '*', '+'
|
||||
|
||||
Raises:
|
||||
ValueError: 不支持的运算符
|
||||
"""
|
||||
# 验证运算符
|
||||
if op not in ("*", "+"):
|
||||
raise ValueError(f"ScalarFactor only supports '*' and '+', got '{op}'")
|
||||
|
||||
self.factor = factor
|
||||
self.scalar = scalar
|
||||
self.op = op
|
||||
|
||||
# 设置类属性
|
||||
self.factor_type = factor.factor_type
|
||||
self.name = f"({scalar}_{op}_{factor.name})"
|
||||
self.data_specs = factor.data_specs
|
||||
self.category = "scalar"
|
||||
self.description = f"Scalar factor: {scalar} {op} {factor.name}"
|
||||
|
||||
# 注意:不调用 super().__init__(),因为 ScalarFactor 是特殊因子
|
||||
self.params = {
|
||||
"factor": factor,
|
||||
"scalar": scalar,
|
||||
"op": op,
|
||||
}
|
||||
|
||||
def compute(self, data: FactorData) -> pl.Series:
|
||||
"""执行标量运算
|
||||
|
||||
Args:
|
||||
data: 包含基础因子所需数据的 FactorData
|
||||
|
||||
Returns:
|
||||
标量运算后的因子值 Series
|
||||
"""
|
||||
values = self.factor.compute(data)
|
||||
|
||||
if self.op == "*":
|
||||
return values * self.scalar
|
||||
elif self.op == "+":
|
||||
return values + self.scalar
|
||||
else:
|
||||
# 不应该执行到这里,因为 __init__ 已经验证了 op
|
||||
raise ValueError(f"Unsupported operation: '{self.op}'")
|
||||
|
||||
def _validate_params(self):
|
||||
"""ScalarFactor 不需要额外验证"""
|
||||
pass
|
||||
@@ -1,213 +0,0 @@
|
||||
"""数据加载器 - Phase 3 数据加载模块
|
||||
|
||||
本模块负责从 DuckDB 安全加载数据:
|
||||
- DataLoader: 数据加载器,支持多文件聚合、列选择、缓存
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import pandas as pd
|
||||
import polars as pl
|
||||
|
||||
from src.factors.data_spec import DataSpec
|
||||
|
||||
|
||||
class DataLoader:
|
||||
"""数据加载器 - 负责从 DuckDB 安全加载数据
|
||||
|
||||
功能:
|
||||
1. 多文件聚合:合并多个表的数据
|
||||
2. 列选择:只加载需要的列
|
||||
3. 原始数据缓存:避免重复读取
|
||||
4. 查询下推:利用 DuckDB SQL 过滤,只加载必要数据
|
||||
|
||||
示例:
|
||||
>>> loader = DataLoader(data_dir="data")
|
||||
>>> specs = [DataSpec("daily", ["ts_code", "trade_date", "close"], lookback_days=20)]
|
||||
>>> df = loader.load(specs, date_range=("20240101", "20240131"))
|
||||
"""
|
||||
|
||||
def __init__(self, data_dir: str):
|
||||
"""初始化 DataLoader
|
||||
|
||||
Args:
|
||||
data_dir: DuckDB 数据库文件所在目录
|
||||
"""
|
||||
self.data_dir = Path(data_dir)
|
||||
self._cache: Dict[str, pl.DataFrame] = {}
|
||||
|
||||
def load(
|
||||
self,
|
||||
specs: List[DataSpec],
|
||||
date_range: Optional[Tuple[str, str]] = None,
|
||||
) -> pl.DataFrame:
|
||||
"""加载并聚合多个 H5 文件的数据
|
||||
|
||||
流程:
|
||||
1. 对每个 DataSpec:
|
||||
a. 检查缓存,命中则直接使用
|
||||
b. 未命中则读取 HDF5(通过 pandas)
|
||||
c. 转换为 Polars DataFrame
|
||||
d. 按 date_range 过滤
|
||||
e. 存入缓存
|
||||
2. 合并多个 DataFrame(按 trade_date 和 ts_code join)
|
||||
|
||||
Args:
|
||||
specs: 数据需求规格列表
|
||||
date_range: 日期范围限制 (start_date, end_date),可选
|
||||
|
||||
Returns:
|
||||
合并后的 Polars DataFrame
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: H5 文件不存在
|
||||
KeyError: 列不存在于文件中
|
||||
"""
|
||||
dataframes = []
|
||||
|
||||
for spec in specs:
|
||||
# 检查缓存
|
||||
cache_key = f"{spec.source}_{','.join(sorted(spec.columns))}"
|
||||
if cache_key in self._cache:
|
||||
df = self._cache[cache_key]
|
||||
else:
|
||||
# 读取 H5 文件(传入日期范围以支持过滤)
|
||||
df = self._read_h5(spec.source, date_range=date_range)
|
||||
|
||||
# 列选择 - 只保留需要的列
|
||||
missing_cols = set(spec.columns) - set(df.columns)
|
||||
if missing_cols:
|
||||
raise KeyError(
|
||||
f"Columns {missing_cols} not found in {spec.source}.h5. "
|
||||
f"Available columns: {df.columns}"
|
||||
)
|
||||
df = df.select(spec.columns)
|
||||
|
||||
# 存入缓存
|
||||
self._cache[cache_key] = df
|
||||
|
||||
# 按 date_range 过滤
|
||||
if date_range:
|
||||
start_date, end_date = date_range
|
||||
df = df.filter(
|
||||
(pl.col("trade_date") >= start_date)
|
||||
& (pl.col("trade_date") <= end_date)
|
||||
)
|
||||
|
||||
dataframes.append(df)
|
||||
|
||||
# 合并多个 DataFrame
|
||||
if len(dataframes) == 1:
|
||||
return dataframes[0]
|
||||
else:
|
||||
return self._merge_dataframes(dataframes)
|
||||
|
||||
def clear_cache(self):
    """Drop all cached frames so the next load() re-reads from disk."""
    self._cache.clear()
|
||||
|
||||
def _read_h5(
    self,
    source: str,
    date_range: Optional[Tuple[str, str]] = None,
) -> pl.DataFrame:
    """Read one source table from DuckDB as a Polars DataFrame.

    Migration note: the method keeps the historical name ``_read_h5``
    for compatibility with existing callers, but it actually reads from
    DuckDB via ``Storage.load_polars()`` (zero-copy export, faster than
    the old HDF5 -> pandas -> Polars path).

    Args:
        source: Table name in DuckDB (e.g. "daily").
        date_range: Optional (start_date, end_date) bounds, YYYYMMDD.
            NOTE(review): only honored for "financial_income"; other
            sources are loaded in full and filtered by the caller.

    Returns:
        Polars DataFrame.

    Raises:
        Exception: Propagated database/query errors.
    """
    # Local imports avoid a circular dependency between the factor
    # framework and the data layer at module import time.
    from src.data.storage import Storage
    from src.data.api_wrappers.api_trade_cal import get_trading_days
    from src.data.utils import get_today_date
    from src.factors.financial.utils import expand_period_to_trading_days

    storage = Storage()

    # Special case: financial statements are stored at report-period
    # granularity and must be expanded (forward-filled) to trading days.
    if source == "financial_income":
        # Resolve the date window; defaults assume 2018-01-01 history.
        start_date = date_range[0] if date_range else "20180101"
        end_date = date_range[1] if date_range else get_today_date()

        # 1. Load the raw financial data (report-period granularity),
        #    filtered by the window. For this table the report period is
        #    carried in the "end_date" field.
        df = storage.load_polars(
            "financial_income",
            start_date=start_date,
            end_date=end_date,
        )

        if len(df) == 0:
            return pl.DataFrame()

        # 2. Build the trading calendar. Start no later than 2018-01-01
        #    so the forward fill has enough history before start_date.
        trade_start = "20180101" if start_date > "20180101" else start_date
        trade_dates = get_trading_days(trade_start, get_today_date())

        # 3. Expand report-period rows onto trading days (forward fill).
        return expand_period_to_trading_days(df, trade_dates)

    # All other sources: load the full table; date filtering (if any)
    # is applied by the caller.
    return storage.load_polars(source)
|
||||
|
||||
def _merge_dataframes(self, dataframes: List[pl.DataFrame]) -> pl.DataFrame:
    """Merge several frames on (trade_date, ts_code).

    Strategy: left-to-right pairwise full outer joins so that rows
    present in only one source are preserved.

    Args:
        dataframes: Frames to merge; each must contain both join keys.

    Returns:
        The merged DataFrame.

    Raises:
        KeyError: A join key is missing from one of the frames.
    """
    result = dataframes[0]

    for df in dataframes[1:]:
        join_keys = ["trade_date", "ts_code"]

        # Both sides must carry the join keys.
        for key in join_keys:
            if key not in result.columns or key not in df.columns:
                raise KeyError(f"Join key '{key}' not found in DataFrames")

        # Only bring in columns we do not already have.
        new_cols = [c for c in df.columns if c not in result.columns]

        if new_cols:
            df_to_join = df.select(join_keys + new_cols)

            # BUGFIX: coalesce=True merges the left/right key columns.
            # Without it a how="full" join emits trade_date_right /
            # ts_code_right and leaves nulls in the keys for rows that
            # exist only on the right, breaking later joins and filters.
            result = result.join(
                df_to_join, on=join_keys, how="full", coalesce=True
            )

    return result
|
||||
|
||||
def get_cache_info(self) -> Dict[str, int]:
    """Report cache statistics.

    Returns:
        Dict with "entries" (number of cached frames) and
        "total_rows" (sum of their row counts).
    """
    return {
        "entries": len(self._cache),
        "total_rows": sum(len(frame) for frame in self._cache.values()),
    }
|
||||
@@ -1,242 +0,0 @@
|
||||
"""数据类型定义 - Phase 1 核心数据模型
|
||||
|
||||
本模块定义了因子框架的基础数据类型:
|
||||
- DataSpec: 数据需求规格,声明因子所需的数据源、列和回看窗口
|
||||
- FactorContext: 计算上下文,由引擎自动注入,提供计算点信息
|
||||
- FactorData: 数据容器,封装底层 Polars DataFrame,提供安全的数据访问
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Optional
|
||||
import polars as pl
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class DataSpec:
    """Immutable specification of the data a factor requires.

    Declares the source table, the columns, and the lookback window a
    factor needs. Instances are frozen and cannot be modified.

    Args:
        source: Table name (e.g. "daily", "fundamental").
        columns: Required columns; must include "ts_code" and
            "trade_date".
        lookback_days: Days to look back, inclusive of the current day
            (1 = today only, 5 = [T-4, T], 20 = [T-19, T]).

    Raises:
        ValueError: If any constraint is violated.

    Examples:
        >>> spec = DataSpec(
        ...     source="daily",
        ...     columns=["ts_code", "trade_date", "close"],
        ...     lookback_days=20
        ... )
    """

    source: str
    columns: List[str]
    lookback_days: int = 1

    def __post_init__(self):
        """Validate the spec.

        Checks, in order: lookback_days >= 1, non-empty source, and
        presence of the mandatory "ts_code"/"trade_date" columns.
        Because frozen=True the instance is never mutated here — we
        only validate and raise ValueError on the first violation.
        """
        if self.lookback_days < 1:
            raise ValueError(f"lookback_days must be >= 1, got {self.lookback_days}")

        if not self.source:
            raise ValueError("source cannot be empty string")

        mandatory = {"ts_code", "trade_date"}
        absent = mandatory - set(self.columns)
        if absent:
            raise ValueError(
                f"columns must contain {mandatory}, missing: {absent}"
            )
|
||||
|
||||
|
||||
@dataclass
class FactorContext:
    """Computation context injected by the FactorEngine.

    Factor authors reach it through ``data.context``. Which fields are
    populated depends on the factor type:
    - cross-sectional factors get ``current_date``;
    - time-series factors get ``current_stock``.

    Attributes:
        current_date: Date being computed, YYYYMMDD (cross-sectional).
        current_stock: Stock code being computed (time-series).
        trade_dates: Optional trading calendar used for alignment.

    Examples:
        >>> ctx = FactorContext(current_date="20240101")
        >>> ctx.current_date
        '20240101'
    """

    current_date: Optional[str] = None
    current_stock: Optional[str] = None
    trade_dates: Optional[List[str]] = None
|
||||
|
||||
|
||||
class FactorData:
    """Data container handed to factors.

    Wraps the underlying Polars DataFrame behind safe accessors.
    Depending on the factor type it holds either a cross-section (all
    stocks over the lookback window) or one stock's full time series.

    Args:
        df: Underlying Polars DataFrame.
        context: Computation context injected by the engine.

    Examples:
        >>> df = pl.DataFrame({
        ...     "ts_code": ["000001.SZ"],
        ...     "trade_date": ["20240101"],
        ...     "close": [10.0]
        ... })
        >>> data = FactorData(df, FactorContext(current_date="20240101"))
    """

    def __init__(self, df: pl.DataFrame, context: FactorContext):
        self._df = df
        self._context = context

    def get_column(self, col: str) -> pl.Series:
        """Return one column as a Polars Series.

        Works for both factor types: all stocks' values for the day
        (cross-sectional) or one stock's series (time-series).

        Args:
            col: Column name.

        Raises:
            KeyError: If *col* is not present in the data.
        """
        if col in self._df.columns:
            return self._df[col]
        raise KeyError(
            f"Column '{col}' not found in data. Available columns: {self._df.columns}"
        )

    def filter_by_date(self, date: str) -> "FactorData":
        """Return a new FactorData restricted to one trade date.

        The original data is untouched. Future dates were already
        trimmed by the engine, so they cannot be reached from here.

        Args:
            date: YYYYMMDD date string.
        """
        subset = self._df.filter(pl.col("trade_date") == date)
        return FactorData(subset, self._context)

    def get_cross_section(self) -> pl.DataFrame:
        """Return all stocks' rows for the context's current date.

        Only meaningful for cross-sectional factors.

        Raises:
            ValueError: If current_date is unset in the context.
        """
        today = self._context.current_date
        if today is None:
            raise ValueError(
                "current_date is not set in context. "
                "get_cross_section() is only applicable for cross-sectional factors."
            )
        return self._df.filter(pl.col("trade_date") == today)

    def to_polars(self) -> pl.DataFrame:
        """Return the raw underlying DataFrame (advanced use).

        Operating on the raw frame bypasses the framework's
        leakage protections — use with care.
        """
        return self._df

    @property
    def context(self) -> FactorContext:
        """The FactorContext this data was built with."""
        return self._context

    def __len__(self) -> int:
        """Row count of the underlying DataFrame."""
        return len(self._df)

    def __repr__(self) -> str:
        """Compact summary: row/column counts plus context info."""
        tags = []
        if self._context.current_date:
            tags.append(f"date={self._context.current_date}")
        if self._context.current_stock:
            tags.append(f"stock={self._context.current_stock}")
        ctx_part = ", ".join(tags) if tags else "no context"
        n_cols = len(self._df.columns)
        return f"FactorData(rows={len(self)}, cols={n_cols}, {ctx_part})"
|
||||
@@ -1,20 +0,0 @@
|
||||
"""财务因子模块
|
||||
|
||||
本模块提供财务类型的因子:
|
||||
|
||||
因子分类:
|
||||
- financial: 财务因子
|
||||
- EPSFactor: 每股收益排名因子
|
||||
|
||||
已添加因子:
|
||||
- EPSFactor: 每股收益排名(基于basic_eps)
|
||||
|
||||
待添加因子:
|
||||
- PERankFactor: 市盈率排名
|
||||
- PBFactor: 市净率因子
|
||||
- DividendFactor: 股息率因子
|
||||
"""
|
||||
|
||||
from src.factors.financial.eps_factor import EPSFactor
|
||||
|
||||
__all__ = ["EPSFactor"]
|
||||
@@ -1,66 +0,0 @@
|
||||
"""EPS因子
|
||||
|
||||
每股收益(EPS)排名因子实现
|
||||
"""
|
||||
|
||||
from typing import List
|
||||
import polars as pl
|
||||
|
||||
from src.factors.base import CrossSectionalFactor
|
||||
from src.factors.data_spec import DataSpec, FactorData
|
||||
|
||||
|
||||
class EPSFactor(CrossSectionalFactor):
    """Cross-sectional EPS (earnings per share) rank factor.

    Each trading day, ranks all stocks by the most recent reported
    basic_eps and normalizes the rank to (0, 1].

    Attributes:
        name: Factor name "eps_rank".
        category: Factor category "financial".
        data_specs: Data requirements (financial_income.basic_eps).

    Example:
        >>> from src.factors import FactorEngine, DataLoader
        >>> from src.factors.financial.eps_factor import EPSFactor
        >>> engine = FactorEngine(DataLoader('data'))
        >>> result = engine.compute(EPSFactor(), start_date='20210101', end_date='20210131')
    """

    name: str = "eps_rank"
    category: str = "financial"
    description: str = "每股收益截面排名因子"
    data_specs: List[DataSpec] = [
        DataSpec(
            "financial_income", ["ts_code", "trade_date", "basic_eps"], lookback_days=1
        )
    ]

    def compute(self, data: FactorData) -> pl.Series:
        """Compute normalized EPS ranks for the current cross-section.

        Args:
            data: FactorData holding the current date's cross-section.

        Returns:
            Series named ``self.name`` with rank values in (0, 1];
            empty when there is no data for the day.
        """
        cs = data.get_cross_section()

        if len(cs) == 0:
            return pl.Series(name=self.name, values=[])

        # Missing EPS counts as 0 so those stocks rank at the bottom.
        eps = cs["basic_eps"].fill_null(0)

        if len(eps) > 1 and eps.max() != eps.min():
            # BUGFIX: name the series consistently with the other
            # branches — previously it kept the source column name
            # "basic_eps" instead of self.name.
            ranks = (eps.rank(method="average") / len(eps)).rename(self.name)
        else:
            # Degenerate cross-section (one row or all-equal values):
            # everyone gets the neutral rank 0.5.
            ranks = pl.Series(name=self.name, values=[0.5] * len(eps))

        return ranks
|
||||
@@ -1,82 +0,0 @@
|
||||
"""财务因子工具函数
|
||||
|
||||
提供财务数据处理的工具函数:
|
||||
- expand_period_to_trading_days: 将报告期数据展开到每个交易日(前向填充)
|
||||
"""
|
||||
|
||||
from typing import List
|
||||
import polars as pl
|
||||
|
||||
|
||||
def expand_period_to_trading_days(
    financial_df: pl.DataFrame,
    trade_dates: List[str],
) -> pl.DataFrame:
    """Expand report-period financial data onto trading days (forward fill).

    For every trading day, pick the latest report-period row that was
    already announced by that day. E.g. the FY2020 report (end_date
    20201231) announced on 20210428 is used for every trading day from
    2021-04-28 until the Q1-2021 report is announced.

    NOTE(review): "latest" is chosen by max end_date among announced
    rows, which assumes later report periods are always announced
    later — confirm this holds for restatements/corrections.

    NOTE(review): complexity is O(stocks x trading days) Polars filters;
    a per-stock ``join_asof`` on ann_date would likely be much faster.

    Args:
        financial_df: Financial data containing at least ts_code,
            ann_date and end_date columns.
        trade_dates: Sorted trading-day list, YYYYMMDD strings.

    Returns:
        DataFrame with one row per (trade_date, ts_code) carrying the
        applicable report's fields plus a "trade_date" column; empty
        DataFrame when there is no input or nothing is applicable.

    Example:
        >>> financial_df = pl.DataFrame({
        ...     'ts_code': ['000001.SZ'],
        ...     'ann_date': ['20210428'],
        ...     'end_date': ['20210331'],
        ...     'basic_eps': [0.5]
        ... })
        >>> expand_period_to_trading_days(
        ...     financial_df, ['20210428', '20210429', '20210430']
        ... ).shape
        (3, 5)
    """
    if len(financial_df) == 0:
        return pl.DataFrame()

    results = []

    # Process one stock at a time.
    for ts_code in financial_df["ts_code"].unique():
        stock_data = financial_df.filter(pl.col("ts_code") == ts_code)

        # Sort by report period ascending so tail(1) is the latest.
        stock_data = stock_data.sort("end_date")

        rows = []
        for trade_date in trade_dates:
            # A report is applicable on trade_date when:
            #   1. end_date <= trade_date (period not in the future), and
            #   2. ann_date <= trade_date (already announced — avoids
            #      look-ahead bias).
            applicable = stock_data.filter(
                (pl.col("end_date") <= trade_date) & (pl.col("ann_date") <= trade_date)
            )

            if len(applicable) > 0:
                # Take the most recent period and stamp the trading day.
                latest = applicable.tail(1).with_columns(
                    [pl.lit(trade_date).alias("trade_date")]
                )
                rows.append(latest)

        if rows:
            results.append(pl.concat(rows))

    if results:
        return pl.concat(results)
    return pl.DataFrame()
|
||||
@@ -1,19 +0,0 @@
|
||||
"""动量因子模块
|
||||
|
||||
本模块提供动量类型的因子:
|
||||
- MovingAverageFactor: 移动平均线(时序因子)
|
||||
- ReturnRankFactor: 收益率排名(截面因子)
|
||||
|
||||
因子分类:
|
||||
- momentum: 动量因子
|
||||
- ma: 移动平均线
|
||||
- return_rank: 收益率排名
|
||||
"""
|
||||
|
||||
from src.factors.momentum.ma import MovingAverageFactor
|
||||
from src.factors.momentum.return_rank import ReturnRankFactor
|
||||
|
||||
__all__ = [
|
||||
"MovingAverageFactor",
|
||||
"ReturnRankFactor",
|
||||
]
|
||||
@@ -1,78 +0,0 @@
|
||||
"""动量因子 - 移动平均线
|
||||
|
||||
本模块提供通用移动平均线因子,支持参数化配置:
|
||||
- MovingAverageFactor: 移动平均线(时序因子)
|
||||
|
||||
使用示例:
|
||||
>>> from src.factors.momentum import MovingAverageFactor
|
||||
>>> ma5 = MovingAverageFactor(period=5) # 5日MA
|
||||
>>> ma10 = MovingAverageFactor(period=10) # 10日MA
|
||||
>>> ma20 = MovingAverageFactor(period=20) # 20日MA
|
||||
"""
|
||||
|
||||
from typing import List
|
||||
|
||||
import polars as pl
|
||||
|
||||
from src.factors.base import TimeSeriesFactor
|
||||
from src.factors.data_spec import DataSpec, FactorData
|
||||
|
||||
|
||||
class MovingAverageFactor(TimeSeriesFactor):
    """Moving-average factor over the close price.

    For each stock, computes the mean of its last *period* closing
    prices. As a time-series factor each stock is processed on its own,
    which prevents cross-stock data leakage. The window is a parameter,
    so MA5/MA10/MA20 are all instances of this one class.

    Attributes:
        period: MA window length in days (default 5).

    Example:
        >>> ma5 = MovingAverageFactor(period=5)
    """

    name: str = "ma"
    factor_type: str = "time_series"
    category: str = "momentum"
    description: str = "移动平均线因子,计算过去n日收盘价的均值"
    data_specs: List[DataSpec] = [
        DataSpec("daily", ["ts_code", "trade_date", "close"], lookback_days=5)
    ]

    def __init__(self, period: int = 5):
        """Initialize with a window length.

        Args:
            period: MA window in days (default 5).
        """
        super().__init__(period=period)
        # DataSpec is frozen, so build a fresh one carrying the actual
        # lookback instead of mutating the class-level default.
        self.data_specs = [
            DataSpec(
                "daily",
                ["ts_code", "trade_date", "close"],
                lookback_days=period,
            )
        ]
        self.name = f"ma_{period}"

    def compute(self, data: FactorData) -> pl.Series:
        """Compute the moving average for one stock's series.

        Args:
            data: FactorData with the stock's full time series.

        Returns:
            Rolling-mean series (nulls for the warm-up rows).
        """
        window = self.params["period"]
        return data.get_column("close").rolling_mean(window_size=window)
|
||||
@@ -1,100 +0,0 @@
|
||||
"""动量因子 - 收益率排名
|
||||
|
||||
本模块提供收益率排名因子:
|
||||
- ReturnRankFactor: 过去n日收益率的rank因子(截面因子)
|
||||
|
||||
使用示例:
|
||||
>>> from src.factors.momentum import ReturnRankFactor
|
||||
>>> ret5 = ReturnRankFactor(period=5) # 5日收益率排名
|
||||
>>> ret10 = ReturnRankFactor(period=10) # 10日收益率排名
|
||||
"""
|
||||
|
||||
from typing import List
|
||||
|
||||
import polars as pl
|
||||
|
||||
from src.factors.base import CrossSectionalFactor
|
||||
from src.factors.data_spec import DataSpec, FactorData
|
||||
|
||||
|
||||
class ReturnRankFactor(CrossSectionalFactor):
    """Cross-sectional rank of the trailing n-day return.

    Each trading day, computes every stock's return over the past
    *period* days and ranks the values across the cross-section,
    normalizing to (0, 1]. Ranking per day prevents date leakage.

    Attributes:
        period: Return window in days (default 5).

    Example:
        >>> ret5 = ReturnRankFactor(period=5)
    """

    name: str = "return_rank"
    factor_type: str = "cross_sectional"
    category: str = "momentum"
    description: str = "过去n日收益率的截面排名因子"
    data_specs: List[DataSpec] = [
        DataSpec("daily", ["ts_code", "trade_date", "close"], lookback_days=5)
    ]

    def __init__(self, period: int = 5):
        """Initialize with a return window.

        Args:
            period: Return window in days.
        """
        super().__init__(period=period)
        # DataSpec is frozen, so build a fresh one; we need period+1
        # days of prices to compute a period-day return.
        self.data_specs = [
            DataSpec(
                "daily",
                ["ts_code", "trade_date", "close"],
                lookback_days=period + 1,
            )
        ]
        self.name = f"return_{period}_rank"

    def compute(self, data: FactorData) -> pl.Series:
        """Compute the trailing-return cross-sectional ranks.

        Args:
            data: FactorData spanning the last period+1 days.

        Returns:
            Series of rank values in (0, 1]; empty when there is not
            enough history.
        """
        period = self.params["period"]
        cs = data.to_polars()

        # Distinct trading dates, ascending.
        trade_dates = cs["trade_date"].unique().sort()

        # BUGFIX: we need period+1 distinct dates below; the old
        # "< 2" guard let trade_dates[-(period + 1)] raise IndexError
        # whenever history was shorter than the window.
        if len(trade_dates) < period + 1:
            return pl.Series(name=self.name, values=[])

        # Latest day's cross-section vs. the one period days earlier.
        latest_date = trade_dates[-1]
        current_data = cs.filter(pl.col("trade_date") == latest_date)

        n_days_ago = trade_dates[-(period + 1)]
        past_data = cs.filter(pl.col("trade_date") == n_days_ago)

        # Pair today's close with the past close via ts_code.
        merged = current_data.select(["ts_code", "close"]).join(
            past_data.select(["ts_code", "close"]).rename({"close": "close_past"}),
            on="ts_code",
            how="inner",
        )

        # Simple return over the window.
        returns = (merged["close"] - merged["close_past"]) / merged["close_past"]

        # Normalized average rank in (0, 1].
        return returns.rank(method="average") / len(returns)
|
||||
@@ -1,20 +0,0 @@
|
||||
"""质量因子模块
|
||||
|
||||
本模块提供质量类因子:
|
||||
- 盈利能力:ROE、ROA、毛利率、净利率
|
||||
- 盈利稳定性:盈利波动率、盈利持续性
|
||||
- 财务健康度:资产负债率、流动比率等
|
||||
|
||||
使用示例:
|
||||
>>> from src.factors.quality import ROEFactor
|
||||
>>> factor = ROEFactor()
|
||||
"""
|
||||
|
||||
# 在此处导入具体的质量因子
|
||||
# from .roe import ROEFactor
|
||||
# from .roa import ROAFactor
|
||||
# from .profit_stability import ProfitStabilityFactor
|
||||
|
||||
__all__ = [
|
||||
# 添加你的质量因子
|
||||
]
|
||||
@@ -1,20 +0,0 @@
|
||||
"""情绪因子模块
|
||||
|
||||
本模块提供市场情绪类因子:
|
||||
- 换手率、换手率变化率
|
||||
- 资金流向、主力净流入
|
||||
- 波动率、振幅等
|
||||
|
||||
使用示例:
|
||||
>>> from src.factors.sentiment import TurnoverFactor
|
||||
>>> factor = TurnoverFactor(period=20)
|
||||
"""
|
||||
|
||||
# 在此处导入具体的情绪因子
|
||||
# from .turnover import TurnoverFactor
|
||||
# from .money_flow import MoneyFlowFactor
|
||||
# from .amplitude import AmplitudeFactor
|
||||
|
||||
__all__ = [
|
||||
# 添加你的情绪因子
|
||||
]
|
||||
@@ -1,20 +0,0 @@
|
||||
"""技术指标因子模块
|
||||
|
||||
本模块提供技术分析类因子:
|
||||
- 移动平均线(MA)、指数移动平均(EMA)
|
||||
- 相对强弱指标(RSI)、MACD、KDJ
|
||||
- 布林带(Bollinger Bands)等
|
||||
|
||||
使用示例:
|
||||
>>> from src.factors.technical import RSIFactor
|
||||
>>> factor = RSIFactor(period=14)
|
||||
"""
|
||||
|
||||
# 在此处导入具体的技术指标因子
|
||||
# from .rsi import RSIFactor
|
||||
# from .macd import MACDFactor
|
||||
# from .bollinger import BollingerFactor
|
||||
|
||||
__all__ = [
|
||||
# 添加你的技术指标因子
|
||||
]
|
||||
@@ -1,18 +0,0 @@
|
||||
"""估值因子模块
|
||||
|
||||
本模块提供估值类因子:
|
||||
- 市盈率(PE)、市净率(PB)、市销率(PS)等估值指标
|
||||
- 估值排名、估值分位数等衍生因子
|
||||
|
||||
使用示例:
|
||||
>>> from src.factors.valuation import PERankFactor
|
||||
>>> factor = PERankFactor()
|
||||
"""
|
||||
|
||||
# 在此处导入具体的估值因子
|
||||
# from .pe_rank import PERankFactor
|
||||
# from .pb_rank import PBRankFactor
|
||||
|
||||
__all__ = [
|
||||
# 添加你的估值因子
|
||||
]
|
||||
@@ -1,21 +0,0 @@
|
||||
"""波动率因子模块
|
||||
|
||||
本模块提供波动率相关因子:
|
||||
- 历史波动率(Historical Volatility)
|
||||
- 实现波动率(Realized Volatility)
|
||||
- GARCH类波动率预测
|
||||
- 波动率风险指标等
|
||||
|
||||
使用示例:
|
||||
>>> from src.factors.volatility import HistoricalVolFactor
|
||||
>>> factor = HistoricalVolFactor(period=20)
|
||||
"""
|
||||
|
||||
# 在此处导入具体的波动率因子
|
||||
# from .historical_vol import HistoricalVolFactor
|
||||
# from .realized_vol import RealizedVolFactor
|
||||
# from .garch_vol import GARCHVolFactor
|
||||
|
||||
__all__ = [
|
||||
# 添加你的波动率因子
|
||||
]
|
||||
@@ -1,20 +0,0 @@
|
||||
"""成交量因子模块
|
||||
|
||||
本模块提供成交量相关因子:
|
||||
- 成交量移动平均
|
||||
- 成交量比率(VR)、能量潮(OBV)
|
||||
- 量价配合指标等
|
||||
|
||||
使用示例:
|
||||
>>> from src.factors.volume import OBVFactor
|
||||
>>> factor = OBVFactor()
|
||||
"""
|
||||
|
||||
# 在此处导入具体的成交量因子
|
||||
# from .obv import OBVFactor
|
||||
# from .volume_ratio import VolumeRatioFactor
|
||||
# from .volume_ma import VolumeMAFactor
|
||||
|
||||
__all__ = [
|
||||
# 添加你的成交量因子
|
||||
]
|
||||
Reference in New Issue
Block a user