feat(data): 添加个股资金流向接口并重构速率限制配置
- 新增 moneyflow 资金流向数据同步模块 - 实现接口级速率限制配置(sync_config.py) - 更新流动性相关因子定义 - 添加非对称量化损失函数
This commit is contained in:
12
config/sync_config.json
Normal file
12
config/sync_config.json
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
{
|
||||||
|
"default": {
|
||||||
|
"rate_limit": 150,
|
||||||
|
"threads": 1
|
||||||
|
},
|
||||||
|
"interfaces": {
|
||||||
|
"pro_bar": {
|
||||||
|
"rate_limit": 500,
|
||||||
|
"threads": 8
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
112
docs/api/api.md
112
docs/api/api.md
@@ -749,4 +749,114 @@ df = pro.cyq_perf(ts_code='600000.SH', start_date='20220101', end_date='20220429
|
|||||||
73 600000.SH 20220107 0.72 12.16 8.60 11.36 9.89 7.59
|
73 600000.SH 20220107 0.72 12.16 8.60 11.36 9.89 7.59
|
||||||
74 600000.SH 20220106 0.72 12.16 8.60 11.36 9.89 3.92
|
74 600000.SH 20220106 0.72 12.16 8.60 11.36 9.89 3.92
|
||||||
75 600000.SH 20220105 0.72 12.16 8.60 11.36 9.89 5.65
|
75 600000.SH 20220105 0.72 12.16 8.60 11.36 9.89 5.65
|
||||||
76 600000.SH 20220104 0.72 12.16 8.60 11.36 9.89 3.93
|
76 600000.SH 20220104 0.72 12.16 8.60 11.36 9.89 3.93
|
||||||
|
|
||||||
|
|
||||||
|
个股资金流向
|
||||||
|
接口:moneyflow,可以通过数据工具调试和查看数据。
|
||||||
|
描述:获取沪深A股票资金流向数据,分析大单小单成交情况,用于判别资金动向,数据开始于2010年。
|
||||||
|
限量:单次最大提取6000行记录,总量不限制
|
||||||
|
积分:用户需要至少2000积分才可以调取,基础积分有流量控制,积分越多权限越大,请自行提高积分,具体请参阅积分获取办法
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
输入参数
|
||||||
|
|
||||||
|
名称 类型 必选 描述
|
||||||
|
ts_code str N 股票代码 (股票和时间参数至少输入一个)
|
||||||
|
trade_date str N 交易日期
|
||||||
|
start_date str N 开始日期
|
||||||
|
end_date str N 结束日期
|
||||||
|
|
||||||
|
|
||||||
|
输出参数
|
||||||
|
|
||||||
|
名称 类型 默认显示 描述
|
||||||
|
ts_code str Y TS代码
|
||||||
|
trade_date str Y 交易日期
|
||||||
|
buy_sm_vol int Y 小单买入量(手)
|
||||||
|
buy_sm_amount float Y 小单买入金额(万元)
|
||||||
|
sell_sm_vol int Y 小单卖出量(手)
|
||||||
|
sell_sm_amount float Y 小单卖出金额(万元)
|
||||||
|
buy_md_vol int Y 中单买入量(手)
|
||||||
|
buy_md_amount float Y 中单买入金额(万元)
|
||||||
|
sell_md_vol int Y 中单卖出量(手)
|
||||||
|
sell_md_amount float Y 中单卖出金额(万元)
|
||||||
|
buy_lg_vol int Y 大单买入量(手)
|
||||||
|
buy_lg_amount float Y 大单买入金额(万元)
|
||||||
|
sell_lg_vol int Y 大单卖出量(手)
|
||||||
|
sell_lg_amount float Y 大单卖出金额(万元)
|
||||||
|
buy_elg_vol int Y 特大单买入量(手)
|
||||||
|
buy_elg_amount float Y 特大单买入金额(万元)
|
||||||
|
sell_elg_vol int Y 特大单卖出量(手)
|
||||||
|
sell_elg_amount float Y 特大单卖出金额(万元)
|
||||||
|
net_mf_vol int Y 净流入量(手)
|
||||||
|
net_mf_amount float Y 净流入额(万元)
|
||||||
|
|
||||||
|
各类别统计规则如下:
|
||||||
|
小单:5万以下;中单:5万~20万;大单:20万~100万;特大单:成交额>=100万。数据基于主动买卖单统计。
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
接口示例
|
||||||
|
|
||||||
|
|
||||||
|
pro = ts.pro_api('your token')
|
||||||
|
|
||||||
|
#获取单日全部股票数据
|
||||||
|
df = pro.moneyflow(trade_date='20190315')
|
||||||
|
|
||||||
|
#获取单个股票数据
|
||||||
|
df = pro.moneyflow(ts_code='002149.SZ', start_date='20190115', end_date='20190315')
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
数据示例
|
||||||
|
|
||||||
|
ts_code trade_date buy_sm_vol buy_sm_amount sell_sm_vol \
|
||||||
|
0 000779.SZ 20190315 11377 1150.17 11100
|
||||||
|
1 000933.SZ 20190315 94220 4803.22 105924
|
||||||
|
2 002270.SZ 20190315 43979 2330.96 45893
|
||||||
|
3 002319.SZ 20190315 21502 2952.88 17155
|
||||||
|
4 002604.SZ 20190315 31944 607.35 58667
|
||||||
|
5 300065.SZ 20190315 16048 2294.71 16425
|
||||||
|
6 600062.SH 20190315 55439 7432.13 65765
|
||||||
|
7 002735.SZ 20190315 3220 797.10 4598
|
||||||
|
8 300196.SZ 20190315 12534 1286.02 8340
|
||||||
|
9 300350.SZ 20190315 15346 1120.12 18853
|
||||||
|
10 600193.SH 20190315 12183 503.73 19576
|
||||||
|
11 002866.SZ 20190315 16932 2213.68 16037
|
||||||
|
12 300481.SZ 20190315 21386 4275.33 21863
|
||||||
|
13 600527.SH 20190315 115462 2975.44 79272
|
||||||
|
14 603980.SH 20190315 13957 1924.69 11718
|
||||||
|
15 600658.SH 20190315 71767 4826.73 69535
|
||||||
|
16 600812.SH 20190315 26140 1247.47 34923
|
||||||
|
17 002013.SZ 20190315 170234 12286.02 148509
|
||||||
|
18 600789.SH 20190315 211012 21644.56 150598
|
||||||
|
19 601636.SH 20190315 70737 3117.43 68073
|
||||||
|
20 000807.SZ 20190315 129668 6361.06 122077
|
||||||
|
|
||||||
|
...
|
||||||
|
|
||||||
|
sell_sm_amount buy_md_vol buy_md_amount sell_md_vol sell_md_amount \
|
||||||
|
0 1122.97 13012 1316.72 14812 1498.90
|
||||||
|
1 5411.72 135976 6935.40 154023 7863.00
|
||||||
|
2 2435.98 57679 3059.15 47279 2507.55
|
||||||
|
3 2358.68 27245 3742.52 26708 3670.05
|
||||||
|
4 1114.40 69897 1327.41 41108 781.19
|
||||||
|
5 2353.34 31232 4472.05 26771 3834.95
|
||||||
|
6 8817.75 86617 11615.40 79551 10676.99
|
||||||
|
7 1140.61 4602 1141.61 2730 676.72
|
||||||
|
8 855.45 9401 963.72 10478 1074.32
|
||||||
|
9 1380.31 24224 1770.90 21588 1577.92
|
||||||
|
10 812.58 28696 1185.17 31087 1286.11
|
||||||
|
11 2100.70 19197 2511.62 20269 2650.56
|
||||||
|
12 4379.14 31692 6345.72 32873 6578.36
|
||||||
|
13 2046.54 107103 2763.00 84883 2191.24
|
||||||
|
14 1619.33 14621 2019.41 14528 2005.69
|
||||||
|
15 4691.29 92788 6232.80 93273 6280.13
|
||||||
|
16 1669.97 38812 1855.78 39211 1874.05
|
||||||
|
17 10726.22 154979 11190.69 164090 11855.76
|
||||||
|
18 15479.08 269470 27660.18 236958 24338.36
|
||||||
|
19 3000.73 90416 3984.68 115162 5075.50
|
||||||
|
20 5999.66 175692 8627.77 178044 8751.08
|
||||||
@@ -29,12 +29,6 @@ class Settings(BaseSettings):
|
|||||||
root_path: str = "" # 项目根路径,默认自动检测
|
root_path: str = "" # 项目根路径,默认自动检测
|
||||||
data_path: str = "data" # 数据存储路径,相对于 root_path
|
data_path: str = "data" # 数据存储路径,相对于 root_path
|
||||||
|
|
||||||
# API 速率限制(每分钟请求数)
|
|
||||||
rate_limit: int = 300
|
|
||||||
|
|
||||||
# 同步工作线程数
|
|
||||||
threads: int = 10
|
|
||||||
|
|
||||||
# 数据库配置(可选,用于未来扩展)
|
# 数据库配置(可选,用于未来扩展)
|
||||||
database_host: str = "localhost"
|
database_host: str = "localhost"
|
||||||
database_port: int = 5432
|
database_port: int = 5432
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ Available APIs:
|
|||||||
- api_stock_st: ST stock list (ST股票列表)
|
- api_stock_st: ST stock list (ST股票列表)
|
||||||
- api_stk_limit: Stock limit price (每日涨跌停价格)
|
- api_stk_limit: Stock limit price (每日涨跌停价格)
|
||||||
- api_cyq_perf: CYQ performance (每日筹码及胜率)
|
- api_cyq_perf: CYQ performance (每日筹码及胜率)
|
||||||
|
- api_moneyflow: Money flow (个股资金流向)
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
>>> from src.data.api_wrappers import get_daily, get_stock_basic, get_trade_cal, get_bak_basic
|
>>> from src.data.api_wrappers import get_daily, get_stock_basic, get_trade_cal, get_bak_basic
|
||||||
@@ -21,6 +22,7 @@ Example:
|
|||||||
>>> from src.data.api_wrappers import get_stock_st, sync_stock_st
|
>>> from src.data.api_wrappers import get_stock_st, sync_stock_st
|
||||||
>>> from src.data.api_wrappers import get_stk_limit, sync_stk_limit
|
>>> from src.data.api_wrappers import get_stk_limit, sync_stk_limit
|
||||||
>>> from src.data.api_wrappers import get_cyq_perf, sync_cyq_perf
|
>>> from src.data.api_wrappers import get_cyq_perf, sync_cyq_perf
|
||||||
|
>>> from src.data.api_wrappers import get_moneyflow, sync_moneyflow
|
||||||
>>> data = get_daily('000001.SZ', start_date='20240101', end_date='20240131')
|
>>> data = get_daily('000001.SZ', start_date='20240101', end_date='20240131')
|
||||||
>>> pro_data = get_pro_bar('000001.SZ', start_date='20240101', end_date='20240131')
|
>>> pro_data = get_pro_bar('000001.SZ', start_date='20240101', end_date='20240131')
|
||||||
>>> daily_basic = get_daily_basic(trade_date='20240101')
|
>>> daily_basic = get_daily_basic(trade_date='20240101')
|
||||||
@@ -30,6 +32,7 @@ Example:
|
|||||||
>>> stock_st = get_stock_st(trade_date='20240101')
|
>>> stock_st = get_stock_st(trade_date='20240101')
|
||||||
>>> stk_limit = get_stk_limit(trade_date='20240101')
|
>>> stk_limit = get_stk_limit(trade_date='20240101')
|
||||||
>>> cyq_perf = get_cyq_perf(trade_date='20240115')
|
>>> cyq_perf = get_cyq_perf(trade_date='20240115')
|
||||||
|
>>> moneyflow = get_moneyflow(trade_date='20240115')
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from src.data.api_wrappers.api_daily_basic import (
|
from src.data.api_wrappers.api_daily_basic import (
|
||||||
@@ -77,6 +80,12 @@ from src.data.api_wrappers.api_cyq_perf import (
|
|||||||
preview_cyq_perf_sync,
|
preview_cyq_perf_sync,
|
||||||
CyqPerfSync,
|
CyqPerfSync,
|
||||||
)
|
)
|
||||||
|
from src.data.api_wrappers.api_moneyflow import (
|
||||||
|
get_moneyflow,
|
||||||
|
sync_moneyflow,
|
||||||
|
preview_moneyflow_sync,
|
||||||
|
MoneyflowSync,
|
||||||
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
# Daily market data
|
# Daily market data
|
||||||
@@ -129,6 +138,11 @@ __all__ = [
|
|||||||
"sync_cyq_perf",
|
"sync_cyq_perf",
|
||||||
"preview_cyq_perf_sync",
|
"preview_cyq_perf_sync",
|
||||||
"CyqPerfSync",
|
"CyqPerfSync",
|
||||||
|
# Moneyflow (个股资金流向)
|
||||||
|
"get_moneyflow",
|
||||||
|
"sync_moneyflow",
|
||||||
|
"preview_moneyflow_sync",
|
||||||
|
"MoneyflowSync",
|
||||||
]
|
]
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
@@ -223,6 +237,17 @@ try:
|
|||||||
order=60,
|
order=60,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 9. Moneyflow - 个股资金流向
|
||||||
|
from src.data.api_wrappers.api_moneyflow import MoneyflowSync
|
||||||
|
|
||||||
|
sync_registry.register_class(
|
||||||
|
name="moneyflow",
|
||||||
|
sync_class=MoneyflowSync,
|
||||||
|
display_name="个股资金流向",
|
||||||
|
description="沪深A股资金流向数据,分析大单小单成交情况(2010年开始)",
|
||||||
|
order=70,
|
||||||
|
)
|
||||||
|
|
||||||
except ImportError:
|
except ImportError:
|
||||||
# sync_registry 可能不存在(首次导入),忽略
|
# sync_registry 可能不存在(首次导入),忽略
|
||||||
pass
|
pass
|
||||||
|
|||||||
228
src/data/api_wrappers/api_moneyflow.py
Normal file
228
src/data/api_wrappers/api_moneyflow.py
Normal file
@@ -0,0 +1,228 @@
|
|||||||
|
"""个股资金流向 (Moneyflow) interface.
|
||||||
|
|
||||||
|
Fetch A-share stock money flow data from Tushare.
|
||||||
|
This interface retrieves fund flow data analyzing large and small order transactions.
|
||||||
|
Data starts from 2010.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from src.data.client import TushareClient
|
||||||
|
from src.data.api_wrappers.base_sync import DateBasedSync
|
||||||
|
|
||||||
|
|
||||||
|
def get_moneyflow(
    trade_date: Optional[str] = None,
    ts_code: Optional[str] = None,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
    client: Optional[TushareClient] = None,
) -> pd.DataFrame:
    """Fetch individual-stock money flow data via the Tushare ``moneyflow`` API.

    The interface reports buy/sell volume and amount split by order size for
    A-share stocks. According to the upstream API notes, data starts in 2010
    and at least one of the stock / date filters should be supplied.

    Order size classification (per the upstream API notes):
        - Small orders: < 50,000 yuan
        - Medium orders: 50,000 - 200,000 yuan
        - Large orders: 200,000 - 1,000,000 yuan
        - Extra large orders: >= 1,000,000 yuan

    Args:
        trade_date: Specific trade date in YYYYMMDD format.
        ts_code: Stock code filter, e.g. ``'000001.SZ'``.
        start_date: Start of a date-range query (YYYYMMDD).
        end_date: End of a date-range query (YYYYMMDD).
        client: Client used to issue the query. When omitted, a new
            ``TushareClient()`` is created; pass a shared client when several
            calls should go through the same client instance.

    Returns:
        The DataFrame returned by ``client.query("moneyflow", ...)``, with a
        ``date`` column renamed to ``trade_date`` if one is present. Per the
        upstream API notes the columns are ``ts_code``, ``trade_date``, then
        ``buy_*`` / ``sell_*`` pairs of ``_vol`` (hands) and ``_amount``
        (10k yuan) for the ``sm``, ``md``, ``lg`` and ``elg`` order sizes,
        plus ``net_mf_vol`` and ``net_mf_amount``.

    Example:
        >>> # All stocks for a single date
        >>> data = get_moneyflow(trade_date='20240115')
        >>> # One stock over a date range
        >>> data = get_moneyflow(ts_code='000001.SZ', start_date='20240101', end_date='20240131')
        >>> # One stock on one date
        >>> data = get_moneyflow(ts_code='000001.SZ', trade_date='20240115')
    """
    client = client or TushareClient()

    # Forward only the filters the caller actually supplied; empty strings and
    # None are dropped so they are never sent to the API as explicit values.
    candidate_filters = {
        "trade_date": trade_date,
        "ts_code": ts_code,
        "start_date": start_date,
        "end_date": end_date,
    }
    params = {name: value for name, value in candidate_filters.items() if value}

    data = client.query("moneyflow", **params)

    # Normalise the date column name in case the response uses "date".
    if "date" in data.columns:
        data = data.rename(columns={"date": "trade_date"})

    return data
|
||||||
|
|
||||||
|
|
||||||
|
class MoneyflowSync(DateBasedSync):
    """Sync manager for individual-stock money flow data.

    Subclasses ``DateBasedSync`` and supplies the table metadata plus the
    per-date fetch hook. Full / incremental sync and the date iteration
    itself live in the base class (not shown here) -- NOTE(review): the
    original author states dates are fetched concurrently; confirm against
    ``DateBasedSync``.

    Upstream data starts in 2010.

    Example:
        >>> sync = MoneyflowSync()
        >>> results = sync.sync_all()                 # incremental sync
        >>> results = sync.sync_all(force_full=True)  # full sync
        >>> preview = sync.preview_sync()             # preview only
    """

    table_name = "moneyflow"
    default_start_date = "20100101"

    # Table schema -- column names follow the raw Tushare API field names.
    # Volumes are in hands, amounts in 10k yuan (per the upstream API notes).
    TABLE_SCHEMA = {
        "ts_code": "VARCHAR(16) NOT NULL",
        "trade_date": "DATE NOT NULL",
        "buy_sm_vol": "INTEGER",  # small order buy volume
        "buy_sm_amount": "DOUBLE",  # small order buy amount
        "sell_sm_vol": "INTEGER",  # small order sell volume
        "sell_sm_amount": "DOUBLE",  # small order sell amount
        "buy_md_vol": "INTEGER",  # medium order buy volume
        "buy_md_amount": "DOUBLE",  # medium order buy amount
        "sell_md_vol": "INTEGER",  # medium order sell volume
        "sell_md_amount": "DOUBLE",  # medium order sell amount
        "buy_lg_vol": "INTEGER",  # large order buy volume
        "buy_lg_amount": "DOUBLE",  # large order buy amount
        "sell_lg_vol": "INTEGER",  # large order sell volume
        "sell_lg_amount": "DOUBLE",  # large order sell amount
        "buy_elg_vol": "INTEGER",  # extra large order buy volume
        "buy_elg_amount": "DOUBLE",  # extra large order buy amount
        "sell_elg_vol": "INTEGER",  # extra large order sell volume
        "sell_elg_amount": "DOUBLE",  # extra large order sell amount
        "net_mf_vol": "INTEGER",  # net inflow volume
        "net_mf_amount": "DOUBLE",  # net inflow amount
    }

    # Index definitions: (index name, indexed columns).
    TABLE_INDEXES = [
        ("idx_moneyflow_date_code", ["trade_date", "ts_code"]),
    ]

    # Primary key columns.
    PRIMARY_KEY = ("ts_code", "trade_date")

    def fetch_single_date(self, trade_date: str) -> pd.DataFrame:
        """Fetch money flow rows for every stock on one trade date.

        Args:
            trade_date: Trade date in YYYYMMDD format.

        Returns:
            The DataFrame returned by ``get_moneyflow`` for that date, queried
            through this manager's own client.
        """
        return get_moneyflow(trade_date=trade_date, client=self.client)
|
||||||
|
|
||||||
|
|
||||||
|
def sync_moneyflow(
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
    force_full: bool = False,
) -> pd.DataFrame:
    """Sync individual-stock money flow data using ``MoneyflowSync``.

    Thin convenience wrapper: builds a ``MoneyflowSync`` and forwards all
    arguments to its ``sync_all`` method. The sync strategy itself is
    implemented by the base class and is not visible here; the original
    author describes it as follows (NOTE(review): confirm against
    ``DateBasedSync.sync_all``):
        - table missing: create table + composite index
          (trade_date, ts_code), then full sync
        - table present: incremental sync starting at last_date + 1

    Args:
        start_date: Start date (YYYYMMDD). Presumably defaults to 20100101
            for a full sync and last_date + 1 for an incremental sync.
        end_date: End date (YYYYMMDD). Presumably defaults to today.
        force_full: When True, request a complete reload.

    Returns:
        Whatever ``MoneyflowSync.sync_all`` returns (annotated as a
        DataFrame of the synced data).

    Example:
        >>> result = sync_moneyflow()                  # first / incremental
        >>> result = sync_moneyflow(force_full=True)   # forced full reload
        >>> result = sync_moneyflow(start_date='20240101', end_date='20240131')
    """
    sync_manager = MoneyflowSync()
    return sync_manager.sync_all(
        start_date=start_date,
        end_date=end_date,
        force_full=force_full,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def preview_moneyflow_sync(
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
    force_full: bool = False,
    sample_size: int = 3,
) -> dict:
    """Preview a money flow sync without performing it.

    Thin convenience wrapper: builds a ``MoneyflowSync`` and forwards all
    arguments to its ``preview_sync`` method (implemented by the base class,
    not visible here).

    Args:
        start_date: Manually chosen start date; overrides auto-detection.
        end_date: Manually chosen end date (presumably defaults to today).
        force_full: When True, preview a full sync instead of an
            incremental one.
        sample_size: Number of days to sample for the preview (default 3).

    Returns:
        The preview information dict produced by ``preview_sync``.

    Example:
        >>> preview = preview_moneyflow_sync()
        >>> preview = preview_moneyflow_sync(force_full=True)
    """
    sync_manager = MoneyflowSync()
    return sync_manager.preview_sync(
        start_date=start_date,
        end_date=end_date,
        force_full=force_full,
        sample_size=sample_size,
    )
|
||||||
@@ -35,8 +35,8 @@ from tqdm import tqdm
|
|||||||
from src.data.client import TushareClient
|
from src.data.client import TushareClient
|
||||||
from src.data.storage import ThreadSafeStorage, Storage
|
from src.data.storage import ThreadSafeStorage, Storage
|
||||||
from src.data.sync_logger import SyncLogManager
|
from src.data.sync_logger import SyncLogManager
|
||||||
|
from src.data.sync_config import get_threads
|
||||||
from src.data.utils import get_today_date, get_next_date
|
from src.data.utils import get_today_date, get_next_date
|
||||||
from src.config.settings import get_settings
|
|
||||||
from src.data.api_wrappers.api_trade_cal import (
|
from src.data.api_wrappers.api_trade_cal import (
|
||||||
get_first_trading_day,
|
get_first_trading_day,
|
||||||
get_last_trading_day,
|
get_last_trading_day,
|
||||||
@@ -63,7 +63,6 @@ class BaseDataSync(ABC):
|
|||||||
|
|
||||||
table_name: str = "" # 子类必须覆盖
|
table_name: str = "" # 子类必须覆盖
|
||||||
DEFAULT_START_DATE = "20180101"
|
DEFAULT_START_DATE = "20180101"
|
||||||
DEFAULT_MAX_WORKERS = get_settings().threads
|
|
||||||
|
|
||||||
# 表结构定义(子类可覆盖)
|
# 表结构定义(子类可覆盖)
|
||||||
# 格式: {"column_name": "SQL_TYPE", ...}
|
# 格式: {"column_name": "SQL_TYPE", ...}
|
||||||
@@ -81,11 +80,14 @@ class BaseDataSync(ABC):
|
|||||||
"""初始化同步管理器。
|
"""初始化同步管理器。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
max_workers: 工作线程数(默认从配置读取)
|
max_workers: 工作线程数(默认从配置读取,根据 table_name 获取)
|
||||||
"""
|
"""
|
||||||
self.storage = ThreadSafeStorage()
|
self.storage = ThreadSafeStorage()
|
||||||
self.client = TushareClient()
|
# 使用 table_name 作为接口名称初始化客户端,获取特定的速率限制
|
||||||
self.max_workers = max_workers or self.DEFAULT_MAX_WORKERS
|
self.client = TushareClient(interface_name=self.table_name or "default")
|
||||||
|
# 根据表名获取线程数配置
|
||||||
|
default_workers = get_threads(self.table_name or "default")
|
||||||
|
self.max_workers = max_workers or default_workers
|
||||||
self._stop_flag = threading.Event()
|
self._stop_flag = threading.Event()
|
||||||
self._stop_flag.set() # 初始为未停止状态
|
self._stop_flag.set() # 初始为未停止状态
|
||||||
self._cached_data: Optional[pd.DataFrame] = None
|
self._cached_data: Optional[pd.DataFrame] = None
|
||||||
|
|||||||
@@ -37,6 +37,7 @@
|
|||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from src.data.sync_logger import SyncLogManager
|
from src.data.sync_logger import SyncLogManager
|
||||||
|
from src.data.sync_config import get_rate_limit, get_threads
|
||||||
from src.data.api_wrappers.financial_data.api_income import (
|
from src.data.api_wrappers.financial_data.api_income import (
|
||||||
IncomeQuarterSync,
|
IncomeQuarterSync,
|
||||||
sync_income,
|
sync_income,
|
||||||
@@ -151,7 +152,12 @@ def sync_financial(
|
|||||||
sync_func = config["sync_func"]
|
sync_func = config["sync_func"]
|
||||||
display_name = config["display_name"]
|
display_name = config["display_name"]
|
||||||
|
|
||||||
|
# 获取当前接口的 rate limit
|
||||||
|
rate_limit = get_rate_limit(data_type)
|
||||||
|
threads = get_threads(data_type)
|
||||||
|
|
||||||
print(f"\n[{display_name}] 开始同步...")
|
print(f"\n[{display_name}] 开始同步...")
|
||||||
|
print(f" Rate Limit: {rate_limit}/min, Threads: {threads}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = sync_func(force_full=force_full, dry_run=dry_run)
|
result = sync_func(force_full=force_full, dry_run=dry_run)
|
||||||
|
|||||||
@@ -2,23 +2,29 @@
|
|||||||
|
|
||||||
import time
|
import time
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
import tushare as ts
|
||||||
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from src.data.rate_limiter import TokenBucketRateLimiter
|
from src.data.rate_limiter import TokenBucketRateLimiter
|
||||||
|
from src.data.sync_config import get_rate_limit
|
||||||
from src.config.settings import get_settings
|
from src.config.settings import get_settings
|
||||||
|
|
||||||
|
|
||||||
class TushareClient:
|
class TushareClient:
|
||||||
"""Tushare API client with rate limiting and retry."""
|
"""Tushare API client with rate limiting and retry."""
|
||||||
|
|
||||||
# 类级别共享限流器(确保所有实例共享同一个限流器)
|
# 类级别限流器缓存(按接口名称存储)
|
||||||
_shared_limiter: Optional[TokenBucketRateLimiter] = None
|
_limiter_cache: dict[str, TokenBucketRateLimiter] = {}
|
||||||
_cached_rate_limit: int = 0 # 缓存上次使用的 rate_limit
|
_cached_rate_limits: dict[str, int] = {}
|
||||||
|
|
||||||
def __init__(self, token: Optional[str] = None):
|
def __init__(
|
||||||
|
self, token: Optional[str] = None, interface_name: Optional[str] = None
|
||||||
|
):
|
||||||
"""Initialize client.
|
"""Initialize client.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
token: Tushare API token (auto-loaded from config if not provided)
|
token: Tushare API token (auto-loaded from config if not provided)
|
||||||
|
interface_name: 接口名称,用于获取特定的速率限制配置
|
||||||
"""
|
"""
|
||||||
cfg = get_settings()
|
cfg = get_settings()
|
||||||
token = token or cfg.tushare_token
|
token = token or cfg.tushare_token
|
||||||
@@ -28,32 +34,57 @@ class TushareClient:
|
|||||||
|
|
||||||
self.token = token
|
self.token = token
|
||||||
self.config = cfg
|
self.config = cfg
|
||||||
|
self.interface_name = interface_name or "default"
|
||||||
|
|
||||||
# 初始化共享限流器(确保所有 TushareClient 实例共享同一个限流器)
|
# 获取或创建限流器
|
||||||
# 检查是否需要重新创建限流器(配置发生变化时)
|
self.rate_limiter = self._get_or_create_limiter(self.interface_name)
|
||||||
if (
|
|
||||||
TushareClient._shared_limiter is None
|
|
||||||
or TushareClient._cached_rate_limit != cfg.rate_limit
|
|
||||||
):
|
|
||||||
# 首次创建或配置变更:重新初始化共享限流器
|
|
||||||
TushareClient._shared_limiter = TokenBucketRateLimiter(
|
|
||||||
rate_limit=cfg.rate_limit,
|
|
||||||
)
|
|
||||||
TushareClient._cached_rate_limit = cfg.rate_limit
|
|
||||||
min_interval = 60.0 / cfg.rate_limit
|
|
||||||
print(
|
|
||||||
f"[TushareClient] Initialized shared rate limiter: rate={cfg.rate_limit}/min, interval={min_interval:.2f}s"
|
|
||||||
)
|
|
||||||
# 复用共享限流器
|
|
||||||
self.rate_limiter = TushareClient._shared_limiter
|
|
||||||
|
|
||||||
self._api = None
|
self._api = None
|
||||||
|
|
||||||
|
def _get_or_create_limiter(self, interface_name: str) -> TokenBucketRateLimiter:
    """Return the rate limiter for ``interface_name``, creating it if needed.

    Limiters are cached on the class, keyed by interface name, so every
    client created for the same interface gets the same limiter object. A
    new limiter replaces the cached one when the configured rate limit for
    that interface differs from the value recorded at creation time.

    NOTE(review): clients constructed before such a replacement keep their
    reference to the old limiter; only later clients see the new one.

    Args:
        interface_name: Interface name used to look up the rate limit.

    Returns:
        The cached (or newly created) ``TokenBucketRateLimiter``.
    """
    # Rate limit currently configured for this interface.
    current_rate_limit = get_rate_limit(interface_name)

    # Rebuild when nothing is cached yet or the configured value has changed.
    needs_new_limiter = (
        interface_name not in TushareClient._limiter_cache
        or TushareClient._cached_rate_limits.get(interface_name) != current_rate_limit
    )
    if needs_new_limiter:
        TushareClient._limiter_cache[interface_name] = TokenBucketRateLimiter(
            rate_limit=current_rate_limit,
        )
        TushareClient._cached_rate_limits[interface_name] = current_rate_limit
        # Average spacing between requests, reported for diagnostics only.
        min_interval = 60.0 / current_rate_limit
        print(
            f"[TushareClient] Initialized rate limiter for '{interface_name}': "
            f"rate={current_rate_limit}/min, interval={min_interval:.2f}s"
        )

    return TushareClient._limiter_cache[interface_name]
|
||||||
|
|
||||||
|
def get_current_rate_limit(self) -> int:
    """Return the rate limit currently configured for this client's interface.

    The value is looked up via ``get_rate_limit`` on every call rather than
    read from the limiter this client holds.

    Returns:
        Requests per minute.
    """
    return get_rate_limit(self.interface_name)
|
||||||
|
|
||||||
def _get_api(self):
|
def _get_api(self):
|
||||||
"""Get Tushare API instance."""
|
"""Get Tushare API instance."""
|
||||||
if self._api is None:
|
if self._api is None:
|
||||||
import tushare as ts
|
|
||||||
|
|
||||||
self._api = ts.pro_api(self.token)
|
self._api = ts.pro_api(self.token)
|
||||||
return self._api
|
return self._api
|
||||||
|
|
||||||
@@ -80,8 +111,6 @@ class TushareClient:
|
|||||||
raise RuntimeError(f"Rate limit exceeded after {timeout}s timeout")
|
raise RuntimeError(f"Rate limit exceeded after {timeout}s timeout")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import tushare as ts
|
|
||||||
|
|
||||||
# pro_bar uses ts.pro_bar() instead of api.query()
|
# pro_bar uses ts.pro_bar() instead of api.query()
|
||||||
if api_name == "pro_bar":
|
if api_name == "pro_bar":
|
||||||
# pro_bar parameters: ts_code, start_date, end_date, adj, freq, factors, ma, adjfactor
|
# pro_bar parameters: ts_code, start_date, end_date, adj, freq, factors, ma, adjfactor
|
||||||
|
|||||||
@@ -13,6 +13,7 @@
|
|||||||
- api_stock_st.py: ST股票列表同步 (StockSTSync 类)
|
- api_stock_st.py: ST股票列表同步 (StockSTSync 类)
|
||||||
- api_stk_limit.py: 涨跌停价格同步 (StkLimitSync 类)
|
- api_stk_limit.py: 涨跌停价格同步 (StkLimitSync 类)
|
||||||
- api_cyq_perf.py: 筹码分布数据同步 (CyqPerfSync 类)
|
- api_cyq_perf.py: 筹码分布数据同步 (CyqPerfSync 类)
|
||||||
|
- api_moneyflow.py: 个股资金流向同步 (MoneyflowSync 类)
|
||||||
- api_stock_basic.py: 股票基本信息同步
|
- api_stock_basic.py: 股票基本信息同步
|
||||||
- api_trade_cal.py: 交易日历同步
|
- api_trade_cal.py: 交易日历同步
|
||||||
|
|
||||||
@@ -82,6 +83,7 @@ def sync_all_data(
|
|||||||
6. stock_st: ST股票列表
|
6. stock_st: ST股票列表
|
||||||
7. stk_limit: 每日涨跌停价格
|
7. stk_limit: 每日涨跌停价格
|
||||||
8. cyq_perf: 每日筹码及胜率
|
8. cyq_perf: 每日筹码及胜率
|
||||||
|
9. moneyflow: 个股资金流向
|
||||||
|
|
||||||
新增接口时,只需在 api_wrappers/__init__.py 中添加注册代码,
|
新增接口时,只需在 api_wrappers/__init__.py 中添加注册代码,
|
||||||
无需修改本函数。
|
无需修改本函数。
|
||||||
|
|||||||
431
src/data/sync_config.py
Normal file
431
src/data/sync_config.py
Normal file
@@ -0,0 +1,431 @@
|
|||||||
|
"""同步配置管理模块。
|
||||||
|
|
||||||
|
该模块提供统一的同步配置管理,支持:
|
||||||
|
- 默认的线程数和速率限制配置
|
||||||
|
- 每个接口单独的速率限制配置
|
||||||
|
- 从 JSON 配置文件加载配置
|
||||||
|
|
||||||
|
配置文件路径: config/sync_config.json
|
||||||
|
|
||||||
|
配置示例:
|
||||||
|
{
|
||||||
|
"default": {
|
||||||
|
"threads": 10,
|
||||||
|
"rate_limit": 300
|
||||||
|
},
|
||||||
|
"interfaces": {
|
||||||
|
"pro_bar": {
|
||||||
|
"rate_limit": 200
|
||||||
|
},
|
||||||
|
"daily_basic": {
|
||||||
|
"rate_limit": 150
|
||||||
|
},
|
||||||
|
"income": {
|
||||||
|
"rate_limit": 100
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import threading
|
||||||
|
from dataclasses import dataclass, field, asdict
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional, Dict, Any
|
||||||
|
from functools import lru_cache
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class InterfaceConfig:
    """Settings for a single interface.

    A field left as ``None`` means "not set here"; ``SyncConfig`` fills such
    fields from its default entry.
    """

    rate_limit: Optional[int] = None  # requests per minute
    threads: Optional[int] = None  # worker thread count


@dataclass
class SyncConfig:
    """Sync configuration: one default entry plus per-interface overrides."""

    default: InterfaceConfig = field(
        default_factory=lambda: InterfaceConfig(rate_limit=300, threads=10)
    )
    interfaces: Dict[str, InterfaceConfig] = field(default_factory=dict)

    def get_interface_config(self, interface_name: str) -> InterfaceConfig:
        """Return the effective configuration for ``interface_name``.

        Fields set on the interface-specific entry win; unset (``None``)
        fields, and interfaces with no entry at all, fall back to the
        default entry. The result is always a new object, so mutating it
        never alters ``self.default`` or the stored overrides.

        Args:
            interface_name: Interface name.

        Returns:
            A fresh ``InterfaceConfig`` with the merged values.
        """
        override = self.interfaces.get(interface_name)
        if override is None:
            # No specific entry: behave as if every field were unset.
            override = InterfaceConfig()

        return InterfaceConfig(
            rate_limit=(
                override.rate_limit
                if override.rate_limit is not None
                else self.default.rate_limit
            ),
            threads=(
                override.threads
                if override.threads is not None
                else self.default.threads
            ),
        )

    def get_rate_limit(self, interface_name: str) -> int:
        """Return the rate limit for ``interface_name``.

        Falls back to the default entry and finally to 300 when the value is
        unset or 0.

        Args:
            interface_name: Interface name.

        Returns:
            Requests per minute.
        """
        config = self.get_interface_config(interface_name)
        return config.rate_limit or self.default.rate_limit or 300

    def get_threads(self, interface_name: str) -> int:
        """Return the worker thread count for ``interface_name``.

        Falls back to the default entry and finally to 10 when the value is
        unset or 0.

        Args:
            interface_name: Interface name.

        Returns:
            Thread count.
        """
        config = self.get_interface_config(interface_name)
        return config.threads or self.default.threads or 10

    def to_dict(self) -> Dict[str, Any]:
        """Return the configuration as plain nested dicts.

        The result has the keys ``"default"`` and ``"interfaces"``; every
        entry is a dict with ``"rate_limit"`` and ``"threads"``.
        """
        # asdict recurses into the nested dataclass and the dict of
        # dataclasses, producing the same layout as a hand-written copy.
        return asdict(self)
|
||||||
|
|
||||||
|
|
||||||
|
class SyncConfigManager:
    """Process-wide manager for the sync configuration.

    Loads the configuration from a JSON file, caches it in memory, and
    writes it back when it changes. The class is a singleton: every
    construction returns the same instance, and only the first
    construction's ``config_path`` takes effect.
    """

    _instance: Optional["SyncConfigManager"] = None
    _lock = threading.Lock()

    DEFAULT_CONFIG = {
        "default": {
            "threads": 10,
            "rate_limit": 300,
        },
        "interfaces": {
            # High-frequency interfaces (can be requested several times per second)
            "trade_cal": {"rate_limit": 500},
            "stock_basic": {"rate_limit": 500},
            "bak_basic": {"rate_limit": 200},
            "stock_st": {"rate_limit": 200},
            "stk_limit": {"rate_limit": 200},
            "cyq_perf": {"rate_limit": 150},
            "moneyflow": {"rate_limit": 150},
            "daily_basic": {"rate_limit": 150},
            "pro_bar": {"rate_limit": 200},
            # Low-frequency financial-statement interfaces
            "income": {"rate_limit": 100},
            "balance": {"rate_limit": 100},
            "cashflow": {"rate_limit": 100},
            "fina_indicator": {"rate_limit": 100},
            "namechange": {"rate_limit": 200},
        },
    }

    def __new__(cls, config_path: Optional[str] = None) -> "SyncConfigManager":
        """Return the shared instance, creating it on first use."""
        if cls._instance is None:
            with cls._lock:
                # Re-check under the lock: another thread may have won the race.
                if cls._instance is None:
                    instance = super().__new__(cls)
                    # Set the flag BEFORE publishing the instance, so a thread
                    # taking the lock-free fast path above never sees an
                    # instance without ``_initialized``.
                    instance._initialized = False
                    cls._instance = instance
        return cls._instance

    def __init__(self, config_path: Optional[str] = None):
        """Initialize the manager once; later constructions are no-ops.

        Args:
            config_path: Path of the JSON configuration file. When omitted,
                ``config/sync_config.json`` under the project root is used
                (the root is resolved as ``Path(__file__).parent.parent.parent``).
        """
        # ``__init__`` runs on every ``SyncConfigManager(...)`` call. Guard the
        # check-and-set with the class lock so two concurrent first
        # constructions cannot both initialize and reset each other's state.
        with self._lock:
            if self._initialized:
                return

            if config_path is None:
                project_root = Path(__file__).parent.parent.parent
                self.config_path = project_root / "config" / "sync_config.json"
            else:
                self.config_path = Path(config_path)

            self._config: Optional["SyncConfig"] = None
            # Re-entrant because public methods call each other while holding it.
            self._config_lock = threading.RLock()
            self._initialized = True

    @staticmethod
    def _build_interfaces(raw_interfaces: Dict[str, Any]) -> Dict[str, Any]:
        """Turn a ``{name: {rate_limit, threads}}`` mapping into InterfaceConfig objects."""
        return {
            name: InterfaceConfig(
                rate_limit=entry.get("rate_limit"),
                threads=entry.get("threads"),
            )
            for name, entry in raw_interfaces.items()
        }

    def _load_config(self) -> "SyncConfig":
        """Load the configuration from disk, caching the result.

        If the file does not exist, the built-in defaults are used and
        written to the file. If the file exists but cannot be read or
        parsed, the built-in defaults are used without overwriting it.

        Returns:
            The loaded (or default) configuration object.
        """
        with self._config_lock:
            if self._config is not None:
                return self._config

            if not self.config_path.exists():
                self._config = self._create_default_config()
                self.save_config()
                return self._config

            try:
                with open(self.config_path, "r", encoding="utf-8") as f:
                    data = json.load(f)

                fallback = self.DEFAULT_CONFIG["default"]
                default_data = data.get("default", {})
                default_config = InterfaceConfig(
                    rate_limit=default_data.get("rate_limit", fallback["rate_limit"]),
                    threads=default_data.get("threads", fallback["threads"]),
                )
                interfaces = self._build_interfaces(data.get("interfaces", {}))

                self._config = SyncConfig(
                    default=default_config,
                    interfaces=interfaces,
                )
                print(f"[SyncConfig] Loaded config from {self.config_path}")
            except Exception as e:
                # Deliberately broad: any unreadable or malformed file falls
                # back to defaults instead of aborting the caller.
                print(f"[SyncConfig] Error loading config: {e}, using defaults")
                self._config = self._create_default_config()

            return self._config

    def _create_default_config(self) -> "SyncConfig":
        """Build a configuration object from ``DEFAULT_CONFIG``."""
        builtin_default = self.DEFAULT_CONFIG["default"]
        default = InterfaceConfig(
            rate_limit=builtin_default["rate_limit"],
            threads=builtin_default["threads"],
        )
        interfaces = self._build_interfaces(self.DEFAULT_CONFIG["interfaces"])
        return SyncConfig(default=default, interfaces=interfaces)

    def save_config(self) -> None:
        """Write the current configuration to the config file.

        Best-effort: failures are reported on stdout and not raised. If no
        configuration has been loaded yet, the built-in defaults are saved.
        """
        with self._config_lock:
            if self._config is None:
                self._config = self._create_default_config()

            try:
                self.config_path.parent.mkdir(parents=True, exist_ok=True)
                with open(self.config_path, "w", encoding="utf-8") as f:
                    json.dump(self._config.to_dict(), f, indent=2, ensure_ascii=False)
                print(f"[SyncConfig] Saved config to {self.config_path}")
            except Exception as e:
                print(f"[SyncConfig] Error saving config: {e}")

    def get_config(self) -> "SyncConfig":
        """Return the current configuration, loading it on first access.

        Returns:
            The cached configuration object.
        """
        with self._config_lock:
            if self._config is None:
                return self._load_config()
            return self._config

    def get_rate_limit(self, interface_name: str) -> int:
        """Return the rate limit configured for an interface.

        Args:
            interface_name: Name of the interface to look up.

        Returns:
            Requests per minute.
        """
        return self.get_config().get_rate_limit(interface_name)

    def get_threads(self, interface_name: str) -> int:
        """Return the thread count configured for an interface.

        Args:
            interface_name: Name of the interface to look up.

        Returns:
            Number of threads.
        """
        return self.get_config().get_threads(interface_name)

    def update_interface_config(
        self,
        interface_name: str,
        rate_limit: Optional[int] = None,
        threads: Optional[int] = None,
    ) -> None:
        """Update one interface's settings and persist the configuration.

        An entry is created for ``interface_name`` if none exists yet.

        Args:
            interface_name: Name of the interface to update.
            rate_limit: New rate limit; ``None`` leaves it unchanged.
            threads: New thread count; ``None`` leaves it unchanged.
        """
        with self._config_lock:
            config = self.get_config()

            if interface_name not in config.interfaces:
                config.interfaces[interface_name] = InterfaceConfig()

            interface_config = config.interfaces[interface_name]
            if rate_limit is not None:
                interface_config.rate_limit = rate_limit
            if threads is not None:
                interface_config.threads = threads

            self.save_config()

    def reset_to_defaults(self) -> None:
        """Replace the configuration with the built-in defaults and save it."""
        with self._config_lock:
            self._config = self._create_default_config()
            self.save_config()
            print("[SyncConfig] Reset to default configuration")

    def list_interfaces(self) -> Dict[str, Dict[str, Any]]:
        """List the configuration of every known interface.

        Covers both the interfaces in the loaded configuration and those in
        ``DEFAULT_CONFIG``, sorted by name, plus a ``"__default__"`` entry
        holding the default values.

        Returns:
            Mapping ``{interface_name: {"rate_limit": ..., "threads": ...}}``.
            Values are reported as returned by ``get_interface_config`` and
            are not resolved against the defaults here.
            NOTE(review): they can presumably be ``None`` for interfaces
            that do not set a field -- confirm against ``get_interface_config``.
        """
        config = self.get_config()
        result = {
            "__default__": {
                "rate_limit": config.default.rate_limit,
                "threads": config.default.threads,
            }
        }

        all_interfaces = set(config.interfaces.keys()) | set(
            self.DEFAULT_CONFIG["interfaces"].keys()
        )
        for name in sorted(all_interfaces):
            iface_config = config.get_interface_config(name)
            result[name] = {
                "rate_limit": iface_config.rate_limit,
                "threads": iface_config.threads,
            }

        return result
|
||||||
|
|
||||||
|
|
||||||
|
# Module-level handle to the shared configuration manager (created lazily)
_config_manager: Optional[SyncConfigManager] = None
_config_manager_lock = threading.Lock()


def get_sync_config_manager() -> SyncConfigManager:
    """Return the shared SyncConfigManager, creating it on first call.

    Returns:
        The SyncConfigManager instance.
    """
    global _config_manager
    if _config_manager is not None:
        return _config_manager
    with _config_manager_lock:
        # Re-check under the lock in case another thread created it meanwhile.
        if _config_manager is None:
            _config_manager = SyncConfigManager()
        return _config_manager
|
||||||
|
|
||||||
|
|
||||||
|
def get_sync_config() -> SyncConfig:
    """Return the current sync configuration from the shared manager.

    Returns:
        The SyncConfig instance.
    """
    manager = get_sync_config_manager()
    return manager.get_config()
|
||||||
|
|
||||||
|
|
||||||
|
def get_rate_limit(interface_name: str) -> int:
    """Return the rate limit for an interface, via the shared manager.

    Args:
        interface_name: Name of the interface to look up.

    Returns:
        Requests per minute.
    """
    manager = get_sync_config_manager()
    return manager.get_rate_limit(interface_name)
|
||||||
|
|
||||||
|
|
||||||
|
def get_threads(interface_name: str) -> int:
    """Return the thread count for an interface, via the shared manager.

    Args:
        interface_name: Name of the interface to look up.

    Returns:
        Number of threads.
    """
    manager = get_sync_config_manager()
    return manager.get_threads(interface_name)
|
||||||
|
|
||||||
|
|
||||||
|
def print_sync_config() -> None:
    """Print the current sync configuration as a table (debugging aid).

    Interfaces that have an entry in the manager's ``DEFAULT_CONFIG`` are
    marked with ``*``. A per-interface value that is unset (``None``) is
    shown as the effective value the manager resolves for that interface,
    because ``None`` cannot be rendered with the integer format specs below.
    """
    manager = get_sync_config_manager()
    configs = manager.list_interfaces()

    print("\n" + "=" * 60)
    print("[SyncConfig] Current Configuration")
    print("=" * 60)

    default_config = configs.pop("__default__", {})
    print(
        f"Default: rate_limit={default_config.get('rate_limit')}/min, threads={default_config.get('threads')}"
    )
    print("-" * 60)

    for name, config in sorted(configs.items()):
        marker = "*" if name in manager.DEFAULT_CONFIG["interfaces"] else " "

        # Built-in entries define only ``rate_limit``, so ``threads`` can be
        # None here; formatting None with ``:2d`` / ``:3d`` raises TypeError.
        rate_limit = config["rate_limit"]
        if rate_limit is None:
            rate_limit = manager.get_rate_limit(name)
        threads = config["threads"]
        if threads is None:
            threads = manager.get_threads(name)

        print(
            f" [{marker}] {name:20s}: rate_limit={rate_limit:3d}/min, threads={threads:2d}"
        )

    print("=" * 60)
|
||||||
@@ -28,6 +28,8 @@ from collections import OrderedDict
|
|||||||
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
|
||||||
|
from src.data.sync_config import get_rate_limit, get_threads
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class SyncTask:
|
class SyncTask:
|
||||||
@@ -221,7 +223,12 @@ class SyncRegistry:
|
|||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
|
|
||||||
for idx, task in enumerate(tasks, 1):
|
for idx, task in enumerate(tasks, 1):
|
||||||
|
# 获取当前接口的配置
|
||||||
|
rate_limit = get_rate_limit(task.name)
|
||||||
|
threads = get_threads(task.name)
|
||||||
|
|
||||||
print(f"\n[{idx}/{total}] Syncing {task.display_name}...")
|
print(f"\n[{idx}/{total}] Syncing {task.display_name}...")
|
||||||
|
print(f" Rate Limit: {rate_limit}/min, Threads: {threads}")
|
||||||
if task.description:
|
if task.description:
|
||||||
print(f" Description: {task.description}")
|
print(f" Description: {task.description}")
|
||||||
|
|
||||||
|
|||||||
@@ -123,133 +123,133 @@ SELECTED_FACTORS = [
|
|||||||
"GTJA_alpha048",
|
"GTJA_alpha048",
|
||||||
"GTJA_alpha049",
|
"GTJA_alpha049",
|
||||||
"GTJA_alpha050",
|
"GTJA_alpha050",
|
||||||
# "GTJA_alpha051",
|
"GTJA_alpha051",
|
||||||
# "GTJA_alpha052",
|
"GTJA_alpha052",
|
||||||
# "GTJA_alpha053",
|
"GTJA_alpha053",
|
||||||
# "GTJA_alpha054",
|
"GTJA_alpha054",
|
||||||
# "GTJA_alpha056",
|
"GTJA_alpha056",
|
||||||
# "GTJA_alpha057",
|
"GTJA_alpha057",
|
||||||
# "GTJA_alpha058",
|
"GTJA_alpha058",
|
||||||
# "GTJA_alpha059",
|
"GTJA_alpha059",
|
||||||
# "GTJA_alpha060",
|
"GTJA_alpha060",
|
||||||
# "GTJA_alpha061",
|
"GTJA_alpha061",
|
||||||
# "GTJA_alpha062",
|
"GTJA_alpha062",
|
||||||
# "GTJA_alpha063",
|
"GTJA_alpha063",
|
||||||
# "GTJA_alpha064",
|
"GTJA_alpha064",
|
||||||
# "GTJA_alpha065",
|
"GTJA_alpha065",
|
||||||
# "GTJA_alpha066",
|
"GTJA_alpha066",
|
||||||
# "GTJA_alpha067",
|
"GTJA_alpha067",
|
||||||
# "GTJA_alpha068",
|
"GTJA_alpha068",
|
||||||
# "GTJA_alpha070",
|
"GTJA_alpha070",
|
||||||
# "GTJA_alpha071",
|
"GTJA_alpha071",
|
||||||
# "GTJA_alpha072",
|
"GTJA_alpha072",
|
||||||
# "GTJA_alpha073",
|
"GTJA_alpha073",
|
||||||
# "GTJA_alpha074",
|
"GTJA_alpha074",
|
||||||
# "GTJA_alpha076",
|
"GTJA_alpha076",
|
||||||
# "GTJA_alpha077",
|
"GTJA_alpha077",
|
||||||
# "GTJA_alpha078",
|
"GTJA_alpha078",
|
||||||
# "GTJA_alpha079",
|
"GTJA_alpha079",
|
||||||
# "GTJA_alpha080",
|
"GTJA_alpha080",
|
||||||
# "GTJA_alpha081",
|
"GTJA_alpha081",
|
||||||
# "GTJA_alpha082",
|
"GTJA_alpha082",
|
||||||
# "GTJA_alpha083",
|
"GTJA_alpha083",
|
||||||
# "GTJA_alpha084",
|
"GTJA_alpha084",
|
||||||
# "GTJA_alpha085",
|
"GTJA_alpha085",
|
||||||
# "GTJA_alpha086",
|
"GTJA_alpha086",
|
||||||
# "GTJA_alpha087",
|
"GTJA_alpha087",
|
||||||
# "GTJA_alpha088",
|
"GTJA_alpha088",
|
||||||
# "GTJA_alpha089",
|
"GTJA_alpha089",
|
||||||
# "GTJA_alpha090",
|
"GTJA_alpha090",
|
||||||
# "GTJA_alpha091",
|
"GTJA_alpha091",
|
||||||
# "GTJA_alpha092",
|
"GTJA_alpha092",
|
||||||
# "GTJA_alpha093",
|
"GTJA_alpha093",
|
||||||
# "GTJA_alpha094",
|
"GTJA_alpha094",
|
||||||
# "GTJA_alpha095",
|
"GTJA_alpha095",
|
||||||
# "GTJA_alpha096",
|
"GTJA_alpha096",
|
||||||
# "GTJA_alpha097",
|
"GTJA_alpha097",
|
||||||
# "GTJA_alpha098",
|
"GTJA_alpha098",
|
||||||
# "GTJA_alpha099",
|
"GTJA_alpha099",
|
||||||
# "GTJA_alpha100",
|
"GTJA_alpha100",
|
||||||
# "GTJA_alpha101",
|
"GTJA_alpha101",
|
||||||
# "GTJA_alpha102",
|
"GTJA_alpha102",
|
||||||
# "GTJA_alpha103",
|
"GTJA_alpha103",
|
||||||
# "GTJA_alpha104",
|
"GTJA_alpha104",
|
||||||
# "GTJA_alpha105",
|
"GTJA_alpha105",
|
||||||
# "GTJA_alpha106",
|
"GTJA_alpha106",
|
||||||
# "GTJA_alpha107",
|
"GTJA_alpha107",
|
||||||
# "GTJA_alpha108",
|
"GTJA_alpha108",
|
||||||
# "GTJA_alpha109",
|
"GTJA_alpha109",
|
||||||
# "GTJA_alpha110",
|
"GTJA_alpha110",
|
||||||
# "GTJA_alpha111",
|
"GTJA_alpha111",
|
||||||
# "GTJA_alpha112",
|
"GTJA_alpha112",
|
||||||
# # "GTJA_alpha113",
|
# "GTJA_alpha113",
|
||||||
# "GTJA_alpha114",
|
"GTJA_alpha114",
|
||||||
# "GTJA_alpha115",
|
"GTJA_alpha115",
|
||||||
# "GTJA_alpha117",
|
"GTJA_alpha117",
|
||||||
# "GTJA_alpha118",
|
"GTJA_alpha118",
|
||||||
# "GTJA_alpha119",
|
"GTJA_alpha119",
|
||||||
# "GTJA_alpha120",
|
"GTJA_alpha120",
|
||||||
# # "GTJA_alpha121",
|
# "GTJA_alpha121",
|
||||||
# "GTJA_alpha122",
|
"GTJA_alpha122",
|
||||||
# "GTJA_alpha123",
|
"GTJA_alpha123",
|
||||||
# "GTJA_alpha124",
|
"GTJA_alpha124",
|
||||||
# "GTJA_alpha125",
|
"GTJA_alpha125",
|
||||||
# "GTJA_alpha126",
|
"GTJA_alpha126",
|
||||||
# "GTJA_alpha127",
|
"GTJA_alpha127",
|
||||||
# "GTJA_alpha128",
|
"GTJA_alpha128",
|
||||||
# "GTJA_alpha129",
|
"GTJA_alpha129",
|
||||||
# "GTJA_alpha130",
|
"GTJA_alpha130",
|
||||||
# "GTJA_alpha131",
|
"GTJA_alpha131",
|
||||||
# "GTJA_alpha132",
|
"GTJA_alpha132",
|
||||||
# "GTJA_alpha133",
|
"GTJA_alpha133",
|
||||||
# "GTJA_alpha134",
|
"GTJA_alpha134",
|
||||||
# "GTJA_alpha135",
|
"GTJA_alpha135",
|
||||||
# "GTJA_alpha136",
|
"GTJA_alpha136",
|
||||||
# # "GTJA_alpha138",
|
# "GTJA_alpha138",
|
||||||
# "GTJA_alpha139",
|
"GTJA_alpha139",
|
||||||
# # "GTJA_alpha140",
|
# "GTJA_alpha140",
|
||||||
# "GTJA_alpha141",
|
"GTJA_alpha141",
|
||||||
# "GTJA_alpha142",
|
"GTJA_alpha142",
|
||||||
# "GTJA_alpha145",
|
"GTJA_alpha145",
|
||||||
# # "GTJA_alpha146",
|
# "GTJA_alpha146",
|
||||||
# "GTJA_alpha148",
|
"GTJA_alpha148",
|
||||||
# "GTJA_alpha150",
|
"GTJA_alpha150",
|
||||||
# "GTJA_alpha151",
|
"GTJA_alpha151",
|
||||||
# "GTJA_alpha152",
|
"GTJA_alpha152",
|
||||||
# "GTJA_alpha153",
|
"GTJA_alpha153",
|
||||||
# "GTJA_alpha154",
|
"GTJA_alpha154",
|
||||||
# "GTJA_alpha155",
|
"GTJA_alpha155",
|
||||||
# "GTJA_alpha156",
|
"GTJA_alpha156",
|
||||||
# "GTJA_alpha157",
|
"GTJA_alpha157",
|
||||||
# "GTJA_alpha158",
|
"GTJA_alpha158",
|
||||||
# "GTJA_alpha159",
|
"GTJA_alpha159",
|
||||||
# "GTJA_alpha160",
|
"GTJA_alpha160",
|
||||||
# "GTJA_alpha161",
|
"GTJA_alpha161",
|
||||||
# # "GTJA_alpha162",
|
# "GTJA_alpha162",
|
||||||
# "GTJA_alpha163",
|
"GTJA_alpha163",
|
||||||
# "GTJA_alpha164",
|
"GTJA_alpha164",
|
||||||
# # "GTJA_alpha165",
|
# "GTJA_alpha165",
|
||||||
# "GTJA_alpha166",
|
"GTJA_alpha166",
|
||||||
# "GTJA_alpha167",
|
"GTJA_alpha167",
|
||||||
# "GTJA_alpha168",
|
"GTJA_alpha168",
|
||||||
# "GTJA_alpha169",
|
"GTJA_alpha169",
|
||||||
# "GTJA_alpha170",
|
"GTJA_alpha170",
|
||||||
# "GTJA_alpha171",
|
"GTJA_alpha171",
|
||||||
# "GTJA_alpha173",
|
"GTJA_alpha173",
|
||||||
# "GTJA_alpha174",
|
"GTJA_alpha174",
|
||||||
# "GTJA_alpha175",
|
"GTJA_alpha175",
|
||||||
# "GTJA_alpha176",
|
"GTJA_alpha176",
|
||||||
# "GTJA_alpha177",
|
"GTJA_alpha177",
|
||||||
# "GTJA_alpha178",
|
"GTJA_alpha178",
|
||||||
# "GTJA_alpha179",
|
"GTJA_alpha179",
|
||||||
# "GTJA_alpha180",
|
"GTJA_alpha180",
|
||||||
# # "GTJA_alpha183",
|
# "GTJA_alpha183",
|
||||||
# "GTJA_alpha184",
|
"GTJA_alpha184",
|
||||||
# "GTJA_alpha185",
|
"GTJA_alpha185",
|
||||||
# "GTJA_alpha187",
|
"GTJA_alpha187",
|
||||||
# "GTJA_alpha188",
|
"GTJA_alpha188",
|
||||||
# "GTJA_alpha189",
|
"GTJA_alpha189",
|
||||||
# "GTJA_alpha191",
|
"GTJA_alpha191",
|
||||||
"chip_dispersion_90",
|
"chip_dispersion_90",
|
||||||
"chip_dispersion_70",
|
"chip_dispersion_70",
|
||||||
"cost_skewness",
|
"cost_skewness",
|
||||||
@@ -270,12 +270,27 @@ SELECTED_FACTORS = [
|
|||||||
"bottom_cost_stability",
|
"bottom_cost_stability",
|
||||||
"pivot_reversion",
|
"pivot_reversion",
|
||||||
"chip_transition",
|
"chip_transition",
|
||||||
|
# "amivest_liq_20",
|
||||||
|
# "atr_price_impact",
|
||||||
|
# "hui_heubel_ratio",
|
||||||
|
# "corwin_schultz_spread_20",
|
||||||
|
# "roll_spread_20",
|
||||||
|
# "gibbs_effective_spread",
|
||||||
|
# "overnight_illiq_20",
|
||||||
|
# "illiq_volatility_20",
|
||||||
|
# "amount_cv_20",
|
||||||
|
# "amount_skewness_20",
|
||||||
|
# "low_vol_days_20",
|
||||||
|
# "liquidity_shock_momentum",
|
||||||
|
# "downside_illiq_20",
|
||||||
|
# "upside_illiq_20",
|
||||||
|
# "illiq_asymmetry_20",
|
||||||
|
# "pastor_stambaugh_proxy"
|
||||||
]
|
]
|
||||||
|
|
||||||
# 因子定义字典(完整因子库,用于存放尚未注册到metadata的因子)
|
# 因子定义字典(完整因子库,用于存放尚未注册到metadata的因子)
|
||||||
FACTOR_DEFINITIONS = {"cs_rank_circ_mv": "cs_rank(circ_mv)"}
|
FACTOR_DEFINITIONS = {"cs_rank_circ_mv": "cs_rank(circ_mv)"}
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# Label 配置(统一绑定 label_name 和 label_dsl)
|
# Label 配置(统一绑定 label_name 和 label_dsl)
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
@@ -308,11 +323,11 @@ def get_label_factor(label_name: str) -> dict:
|
|||||||
# 辅助函数
|
# 辅助函数
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
def register_factors(
|
def register_factors(
|
||||||
engine: FactorEngine,
|
engine: FactorEngine,
|
||||||
selected_factors: List[str],
|
selected_factors: List[str],
|
||||||
factor_definitions: dict,
|
factor_definitions: dict,
|
||||||
label_factor: dict,
|
label_factor: dict,
|
||||||
excluded_factors: Optional[List[str]] = None,
|
excluded_factors: Optional[List[str]] = None,
|
||||||
) -> List[str]:
|
) -> List[str]:
|
||||||
"""注册因子。
|
"""注册因子。
|
||||||
|
|
||||||
@@ -393,11 +408,11 @@ def register_factors(
|
|||||||
|
|
||||||
|
|
||||||
def prepare_data(
|
def prepare_data(
|
||||||
engine: FactorEngine,
|
engine: FactorEngine,
|
||||||
feature_cols: List[str],
|
feature_cols: List[str],
|
||||||
start_date: str,
|
start_date: str,
|
||||||
end_date: str,
|
end_date: str,
|
||||||
label_name: str,
|
label_name: str,
|
||||||
) -> pl.DataFrame:
|
) -> pl.DataFrame:
|
||||||
"""准备数据。
|
"""准备数据。
|
||||||
|
|
||||||
@@ -455,11 +470,11 @@ def stock_pool_filter(df: pl.DataFrame) -> pl.Series:
|
|||||||
"""
|
"""
|
||||||
# 代码筛选(排除创业板、科创板、北交所)
|
# 代码筛选(排除创业板、科创板、北交所)
|
||||||
code_filter = (
|
code_filter = (
|
||||||
~df["ts_code"].str.starts_with("30") # 排除创业板
|
~df["ts_code"].str.starts_with("30") # 排除创业板
|
||||||
& ~df["ts_code"].str.starts_with("68") # 排除科创板
|
& ~df["ts_code"].str.starts_with("68") # 排除科创板
|
||||||
& ~df["ts_code"].str.starts_with("8") # 排除北交所
|
& ~df["ts_code"].str.starts_with("8") # 排除北交所
|
||||||
& ~df["ts_code"].str.starts_with("9") # 排除北交所
|
& ~df["ts_code"].str.starts_with("9") # 排除北交所
|
||||||
& ~df["ts_code"].str.starts_with("4") # 排除北交所
|
& ~df["ts_code"].str.starts_with("4") # 排除北交所
|
||||||
)
|
)
|
||||||
|
|
||||||
# 在已筛选的股票中,选取流通市值最小的500只
|
# 在已筛选的股票中,选取流通市值最小的500只
|
||||||
@@ -474,7 +489,6 @@ def stock_pool_filter(df: pl.DataFrame) -> pl.Series:
|
|||||||
# 定义筛选所需的基础列
|
# 定义筛选所需的基础列
|
||||||
STOCK_FILTER_REQUIRED_COLUMNS = ["circ_mv"]
|
STOCK_FILTER_REQUIRED_COLUMNS = ["circ_mv"]
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# 输出配置
|
# 输出配置
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
@@ -518,7 +532,7 @@ def get_output_path(model_type: str, test_start: str, test_end: str) -> str:
|
|||||||
|
|
||||||
|
|
||||||
def get_model_save_path(
|
def get_model_save_path(
|
||||||
model_type: str,
|
model_type: str,
|
||||||
) -> Optional[str]:
|
) -> Optional[str]:
|
||||||
"""生成模型保存路径。
|
"""生成模型保存路径。
|
||||||
|
|
||||||
@@ -544,11 +558,11 @@ def get_model_save_path(
|
|||||||
|
|
||||||
|
|
||||||
def save_model_with_factors(
|
def save_model_with_factors(
|
||||||
model,
|
model,
|
||||||
model_path: str,
|
model_path: str,
|
||||||
selected_factors: list[str],
|
selected_factors: list[str],
|
||||||
factor_definitions: dict,
|
factor_definitions: dict,
|
||||||
fitted_processors: list | None = None,
|
fitted_processors: list | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""保存模型及关联的因子信息和处理器。
|
"""保存模型及关联的因子信息和处理器。
|
||||||
|
|
||||||
|
|||||||
@@ -54,31 +54,27 @@ N_QUANTILES = 20
|
|||||||
|
|
||||||
# 排除的因子列表
|
# 排除的因子列表
|
||||||
EXCLUDED_FACTORS = [
|
EXCLUDED_FACTORS = [
|
||||||
"GTJA_alpha010",
|
'active_market_cap',
|
||||||
"GTJA_alpha005",
|
'close_vwap_deviation',
|
||||||
"GTJA_alpha002",
|
'sharpe_ratio_20',
|
||||||
"GTJA_alpha027",
|
'upper_shadow_ratio',
|
||||||
"GTJA_alpha051",
|
'volume_ratio_5_20',
|
||||||
"GTJA_alpha044",
|
'GTJA_alpha090',
|
||||||
"GTJA_alpha041",
|
'GTJA_alpha084',
|
||||||
"GTJA_alpha131",
|
'GTJA_alpha066',
|
||||||
"GTJA_alpha103",
|
'GTJA_alpha150',
|
||||||
"GTJA_alpha087",
|
'GTJA_alpha148',
|
||||||
"GTJA_alpha093",
|
'GTJA_alpha106',
|
||||||
"GTJA_alpha092",
|
'GTJA_alpha109',
|
||||||
"GTJA_alpha073",
|
'GTJA_alpha108',
|
||||||
"GTJA_alpha127",
|
'GTJA_alpha176',
|
||||||
"GTJA_alpha117",
|
'GTJA_alpha169',
|
||||||
"GTJA_alpha124",
|
'GTJA_alpha156',
|
||||||
"GTJA_alpha162",
|
'chip_dispersion_70',
|
||||||
"GTJA_alpha177",
|
'winner_rate_cs_rank',
|
||||||
"GTJA_alpha188",
|
'atr_price_impact',
|
||||||
"smart_money_accumulation",
|
'low_vol_days_20',
|
||||||
"GTJA_alpha014",
|
'liquidity_shock_momentum',
|
||||||
"GTJA_alpha056",
|
|
||||||
"GTJA_alpha085",
|
|
||||||
"GTJA_alpha154",
|
|
||||||
"GTJA_alpha141",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
# LambdaRank 模型参数配置
|
# LambdaRank 模型参数配置
|
||||||
|
|||||||
@@ -52,55 +52,36 @@ TRAINING_TYPE = "regression"
|
|||||||
|
|
||||||
# 排除的因子列表
|
# 排除的因子列表
|
||||||
EXCLUDED_FACTORS = [
|
EXCLUDED_FACTORS = [
|
||||||
"GTJA_alpha036",
|
'GTJA_alpha016',
|
||||||
"GTJA_alpha032",
|
'volatility_20',
|
||||||
"GTJA_alpha010",
|
'current_ratio',
|
||||||
"GTJA_alpha005",
|
'GTJA_alpha001',
|
||||||
"CP",
|
'GTJA_alpha141',
|
||||||
"BP",
|
'GTJA_alpha129',
|
||||||
"debt_to_equity",
|
'GTJA_alpha164',
|
||||||
"current_ratio",
|
'amivest_liq_20',
|
||||||
"GTJA_alpha002",
|
'GTJA_alpha012',
|
||||||
"GTJA_alpha027",
|
'debt_to_equity',
|
||||||
"GTJA_alpha064",
|
'turnover_deviation',
|
||||||
"GTJA_alpha062",
|
'GTJA_alpha073',
|
||||||
"GTJA_alpha043",
|
'GTJA_alpha043',
|
||||||
"GTJA_alpha044",
|
'GTJA_alpha032',
|
||||||
"GTJA_alpha120",
|
'GTJA_alpha028',
|
||||||
"GTJA_alpha117",
|
'GTJA_alpha090',
|
||||||
"GTJA_alpha103",
|
'GTJA_alpha108',
|
||||||
"GTJA_alpha104",
|
'GTJA_alpha105',
|
||||||
"GTJA_alpha105",
|
'GTJA_alpha091',
|
||||||
"GTJA_alpha073",
|
'GTJA_alpha119',
|
||||||
"GTJA_alpha077",
|
'GTJA_alpha104',
|
||||||
"GTJA_alpha085",
|
'GTJA_alpha163',
|
||||||
"GTJA_alpha090",
|
'GTJA_alpha157',
|
||||||
"GTJA_alpha087",
|
'cost_skewness',
|
||||||
"GTJA_alpha083",
|
'GTJA_alpha176',
|
||||||
"GTJA_alpha092",
|
'chip_transition',
|
||||||
"GTJA_alpha133",
|
'amount_skewness_20',
|
||||||
"GTJA_alpha131",
|
'GTJA_alpha148',
|
||||||
"GTJA_alpha126",
|
'mean_median_dev',
|
||||||
"GTJA_alpha124",
|
'downside_illiq_20',
|
||||||
"GTJA_alpha162",
|
|
||||||
"GTJA_alpha164",
|
|
||||||
"GTJA_alpha157",
|
|
||||||
"GTJA_alpha177",
|
|
||||||
"price_to_avg_cost",
|
|
||||||
"cost_skewness",
|
|
||||||
"GTJA_alpha191",
|
|
||||||
"GTJA_alpha180",
|
|
||||||
"history_position",
|
|
||||||
"bottom_profit",
|
|
||||||
"mean_median_dev",
|
|
||||||
"smart_money_accumulation",
|
|
||||||
"GTJA_alpha013",
|
|
||||||
"GTJA_alpha099",
|
|
||||||
"GTJA_alpha107",
|
|
||||||
"GTJA_alpha119",
|
|
||||||
"GTJA_alpha141",
|
|
||||||
"GTJA_alpha130",
|
|
||||||
"GTJA_alpha173",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
# 模型参数配置
|
# 模型参数配置
|
||||||
|
|||||||
@@ -34,112 +34,99 @@ from src.config.settings import get_settings
|
|||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
FACTORS: List[Dict[str, Any]] =[
|
FACTORS: List[Dict[str, Any]] =[
|
||||||
# ==================== 第一类:筹码集中度与离散度因子 ====================
|
|
||||||
{
|
{
|
||||||
"name": "chip_dispersion_90",
|
"name": "amihud_illiq_20",
|
||||||
"desc": "90%筹码离散度:衡量市场90%持仓筹码的宽度,值越小表示筹码越高度集中(单峰密集),往往是洗盘结束的前兆",
|
"desc": "Amihud非流动性指标(20日):绝对收益率/成交额。该值越大,说明少量的资金就能砸盘或拉升,流动性极差,隐含极高的风险补偿预期",
|
||||||
"dsl": "(cost_95pct - cost_5pct) / (cost_95pct + cost_5pct)",
|
"dsl": "ts_mean(abs(pct_chg) / (amount + 1), 20)",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "chip_dispersion_70",
|
"name": "amivest_liq_20",
|
||||||
"desc": "70%核心筹码离散度:剔除极端的底部死筹和高位套牢盘,反映中间70%主流资金的成本集中度",
|
"desc": "Amivest流动性指标(20日):Amihud的倒数变种,衡量推动1%价格变化需要的资金量。值越低,流动性溢价越高",
|
||||||
"dsl": "(cost_85pct - cost_15pct) / (cost_85pct + cost_15pct)",
|
"dsl": "ts_mean(amount / (abs(pct_chg) + 0.001), 20)",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "cost_skewness",
|
"name": "atr_price_impact",
|
||||||
"desc": "筹码偏度:反映筹码分布的不对称性。大于1说明上方套牢盘拖尾严重,小于1说明下方获利盘雄厚",
|
"desc": "真实波幅价格冲击(20日):以ATR(真实波幅)代替绝对收益,剔除跳空影响后的真实交易冲击",
|
||||||
"dsl": "(cost_95pct - cost_50pct) / (cost_50pct - cost_5pct)",
|
"dsl": "ts_atr(high, low, close, 20) / close / (ts_mean(amount, 20) + 1)",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "dispersion_change_20",
|
"name": "hui_heubel_ratio",
|
||||||
"desc": "筹码集中度近期变化率:过去20天筹码宽度的变化比例,持续下降说明主力正在暗中吸筹",
|
"desc": "Hui-Heubel流动性比率:利用高低价区间占均价的比例,除以成交额占比,捕捉阶段性区间的深度非流动性",
|
||||||
"dsl": "ts_pct_change((cost_95pct - cost_5pct) / cost_50pct, 20)",
|
"dsl": "((ts_max(high, 20) - ts_min(low, 20)) / ts_min(low, 20)) / (ts_mean(amount, 20) + 1)",
|
||||||
},
|
},
|
||||||
|
|
||||||
# ==================== 第二类:筹码相对位置与压力/支撑因子 ====================
|
# --- 维度 2: 交易摩擦与隐形价差类 (Spread Proxies) ---
|
||||||
|
# 逻辑:A股没有Tick级买卖价差数据,通常用日线的高低价或自协方差来推导隐形买卖价差。价差越大,交易成本越高,持有需高收益补偿。
|
||||||
{
|
{
|
||||||
"name": "price_to_avg_cost",
|
"name": "corwin_schultz_spread_20",
|
||||||
"desc": "整体浮盈比例:当前价格相对加权平均成本的溢价率。高溢价有均值回归压力,负溢价代表超跌",
|
"desc": "Corwin-Schultz买卖价差代理:利用每日(最高价-最低价)/收盘价的均值衡量交易摩擦,摩擦越大的股票往往带有小盘股超额收益",
|
||||||
"dsl": "(close - weight_avg) / weight_avg",
|
"dsl": "ts_mean((high - low) / close, 20)",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "price_to_median_cost",
|
"name": "roll_spread_20",
|
||||||
"desc": "中位数成本偏离度:价格相对于50%分位点(绝对半数人持仓价)的偏离,向上突破通常是右侧买点",
|
"desc": "Roll买卖价差代理(20日):经典微观结构模型,计算相邻两日收益率的负协方差的平方根,反映做市商的隐形报价跳跃",
|
||||||
"dsl": "(close - cost_50pct) / cost_50pct",
|
"dsl": "sqrt(max_(-ts_cov(change, ts_delay(change, 1), 20), 0))",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "mean_median_dev",
|
"name": "gibbs_effective_spread",
|
||||||
"desc": "均值中位数背离:均值显著大于中位数说明高位筹码堆积,上涨阻力大",
|
"desc": "有效价差代理:使用日内振幅减去隔夜跳空幅度后的纯日内摩擦成本",
|
||||||
"dsl": "(weight_avg - cost_50pct) / cost_50pct",
|
"dsl": "ts_mean(((high - low) - abs(open - ts_delay(close, 1))) / close, 20)",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "trap_pressure",
|
"name": "overnight_illiq_20",
|
||||||
"desc": "高位套牢盘压力指数:当前价格距离上方95%高位套牢成本的距离。距离越大,反弹的真空期阻力越小",
|
"desc": "隔夜非流动性:开盘价相对于昨日收盘的跳空幅度,除以昨日成交额。隔夜极易跳空的股票具有夜间流动性溢价",
|
||||||
"dsl": "(cost_95pct - close) / close",
|
"dsl": "ts_mean(abs(open - ts_delay(close, 1)) / (ts_delay(amount, 1) + 1), 20)",
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "bottom_profit",
|
|
||||||
"desc": "底部支撑底仓利润率:当前价格距离底部5%筹码的利润空间。暴跌时大于0说明底仓极度稳定",
|
|
||||||
"dsl": "(close - cost_5pct) / cost_5pct",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "history_position",
|
|
||||||
"desc": "历史区间分位点:当前价格在个股上市以来历史最高点和最低点之间的相对位置",
|
|
||||||
"dsl": "(close - his_low) / (his_high - his_low)",
|
|
||||||
},
|
},
|
||||||
|
|
||||||
# ==================== 第三类:胜率相关的动量与反转因子 ====================
|
# --- 维度 3: 流动性风险与枯竭类 (Liquidity Risk & Depletion) ---
|
||||||
|
# 逻辑:投资者不仅讨厌“平时流动性差”,更讨厌“流动性极其不稳定”。
|
||||||
{
|
{
|
||||||
"name": "winner_rate_surge_5",
|
"name": "illiq_volatility_20",
|
||||||
"desc": "获利盘短期爆发力:胜率在过去5天内的变化值,急剧上升是极强的动量做多信号",
|
"desc": "Amihud非流动性的波动率(20日):衡量价格冲击的不确定性(即流动性风险本身)。波动越大的股越容易踩踏",
|
||||||
"dsl": "ts_delta(winner_rate, 5)",
|
"dsl": "ts_std(abs(pct_chg) / (amount + 1), 20)",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "winner_rate_cs_rank",
|
"name": "amount_cv_20",
|
||||||
"desc": "获利盘高位反转信号:全市场胜率截面排名,极端高胜率往往面临多头踩踏的获利了结压力(反转因子)",
|
"desc": "成交额变异系数(20日):成交额的波动率除以均值。反映股票被市场关注的极度不稳定性",
|
||||||
"dsl": "cs_rank(winner_rate)",
|
"dsl": "ts_std(amount, 20) / (ts_mean(amount, 20) + 1)",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "winner_rate_dev_20",
|
"name": "amount_skewness_20",
|
||||||
"desc": "获利盘均线偏离:当前胜率相对过去20天平均胜率的偏离程度,捕捉筹码情绪的边际超买/超卖",
|
"desc": "成交额偏度(20日):正偏度意味着平时成交极度清淡,偶尔脉冲式放量。这种“死鱼”状态是经典的非流动性特征",
|
||||||
"dsl": "winner_rate - ts_mean(winner_rate, 20)",
|
"dsl": "ts_skew(amount, 20)",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "winner_rate_volatility",
|
"name": "low_vol_days_20",
|
||||||
"desc": "获利盘波动率:过去20天胜率的波动率。波动率低且胜率高说明单边上涨极度稳健",
|
"desc": "流动性枯竭天数:过去20天内,成交额低于长期(60日)均值一半的极端缩量天数",
|
||||||
"dsl": "ts_std(winner_rate, 20)",
|
"dsl": "ts_count(amount < ts_mean(amount, 60) * 0.5, 20)",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "smart_money_accumulation",
|
"name": "liquidity_shock_momentum",
|
||||||
"desc": "潜在主力吸筹隐蔽指标:胜率的60日时序分位数减去价格的时序分位数。值越大说明‘价平而获利盘增’,底部吸筹明显",
|
"desc": "流动性恶化动量:近期(5日)非流动性相较于长期的变化。正值代表流动性正在迅速干涸",
|
||||||
"dsl": "ts_rank(winner_rate, 60) - ts_rank(close, 60)",
|
"dsl": "ts_mean(abs(pct_chg) / (amount + 1), 5) - ts_mean(abs(pct_chg) / (amount + 1), 20)",
|
||||||
},
|
},
|
||||||
|
|
||||||
# ==================== 第四类:量价与筹码交乘因子 ====================
|
# --- 维度 4: 非对称流动性类 (Asymmetric / Downside Illiquidity) ---
|
||||||
|
# 逻辑:买入时没问题,但“大跌时卖不出去(流动性丧失)”是最致命的风险,这部分风险被定价的权重最高。
|
||||||
{
|
{
|
||||||
"name": "winner_vol_corr_20",
|
"name": "downside_illiq_20",
|
||||||
"desc": "放量突破筹码密集区:胜率与成交量的20日时序相关性,正相关说明增量资金在主动解套上方筹码",
|
"desc": "下行非流动性:仅在股价下跌日计算的价格冲击。捕捉‘跌时没人接盘’的极端流动性折价",
|
||||||
"dsl": "ts_corr(winner_rate, vol, 20)",
|
"dsl": "ts_sum(where(change < 0, abs(change) / (amount + 1), 0), 20) / (ts_count(change < 0, 20) + 1)",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "cost_base_momentum",
|
"name": "upside_illiq_20",
|
||||||
"desc": "成本重心上移换手率:过去20天加权平均成本的变化幅度,快速上移说明高位换手极其充分",
|
"desc": "上行非流动性:仅在股价上涨日计算的价格冲击。捕捉‘涨时抛压极轻’的状态",
|
||||||
"dsl": "ts_pct_change(weight_avg, 20)",
|
"dsl": "ts_sum(where(change > 0, abs(change) / (amount + 1), 0), 20) / (ts_count(change > 0, 20) + 1)",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "bottom_cost_stability",
|
"name": "illiq_asymmetry_20",
|
||||||
"desc": "底部坚如磐石因子:底部5%成本的60天波动率相对于中位数的比值,波动越小说明死筹越稳固",
|
"desc": "非对称流动性比率:下行流动性恶化程度除以加上行流动性恶化程度。该值远大于1说明下跌时发生严重踩踏,股票本身必须折价(预期收益率补偿极高)",
|
||||||
"dsl": "ts_std(cost_5pct, 60) / cost_50pct",
|
"dsl": "(ts_sum(where(change < 0, abs(change)/(amount+1), 0), 20) + 1) / (ts_sum(where(change > 0, abs(change)/(amount+1), 0), 20) + 1)",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "pivot_reversion",
|
"name": "pastor_stambaugh_proxy",
|
||||||
"desc": "盈亏分界线乖离修复:价格偏离50%分位点除以近20日价格标准差,用于寻找超跌后的均值回归买点",
|
"desc": "Pastor-Stambaugh流动性贝塔代理:收益率与滞后一期带有符号(涨跌)成交额的相关性。反映市场由于流动性短缺导致的价格过度反转现象",
|
||||||
"dsl": "(close - cost_50pct) / ts_std(close, 20)",
|
"dsl": "ts_corr(change, sign(ts_delay(change, 1)) * ts_delay(amount, 1), 20)",
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "chip_transition",
|
|
||||||
"desc": "强弱筹码切换度:上方厚度与下方厚度差值的20日变化量。由正变负说明筹码彻底完成了自上而下的转移(洗盘结束)",
|
|
||||||
"dsl": "ts_delta((cost_85pct - cost_50pct) - (cost_50pct - cost_15pct), 20)",
|
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
@@ -84,3 +84,65 @@ class EnsembleQuantLoss(nn.Module):
|
|||||||
total_loss += self.alpha * h_loss + (1.0 - self.alpha) * ic_loss
|
total_loss += self.alpha * h_loss + (1.0 - self.alpha) * ic_loss
|
||||||
|
|
||||||
return total_loss / self.ensemble_size
|
return total_loss / self.ensemble_size
|
||||||
|
|
||||||
|
|
||||||
|
class AsymmetricQuantLoss(nn.Module):
    """Asymmetric Huber + IC loss for an ensemble of cross-sectional predictors.

    The loss is computed over the whole cross-section (every sample in the
    batch) so gradients stay stable, while the two mistakes that hurt a
    long-only book the most are penalised more heavily:

    1. Over-rating a bad stock (predicting a gain for a stock that falls,
       i.e. a "buy trap") -> weight 2.0.
    2. Under-rating a good stock (predicting less than the realised positive
       return, i.e. a missed winner) -> weight 1.5.

    Every other sample keeps weight 1.0.

    Args:
        alpha: Mixing factor; the per-member loss is
            ``alpha * weighted_huber + (1 - alpha) * (1 - IC)``.
        ensemble_size: Number of ensemble members (columns of ``preds``) that
            are scored and averaged.
        delta: Huber threshold separating the quadratic and linear regions.
            Defaults to 1.0; a much smaller value (e.g. 0.05) may suit
            targets on a raw-return scale.

    Raises:
        ValueError: If ``delta`` is not strictly positive.
    """

    def __init__(
        self,
        alpha: float = 0.5,
        ensemble_size: int = 32,
        delta: float = 1.0,
    ):
        super().__init__()
        if delta <= 0:
            raise ValueError(f"delta must be positive, got {delta}")
        self.alpha = alpha
        self.ensemble_size = ensemble_size
        # Huber is implemented by hand (not nn.HuberLoss) so that the
        # per-element asymmetric weights can be applied before reduction.
        self.delta = delta

    def forward(self, preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Return the ensemble-averaged asymmetric Huber + IC loss.

        Args:
            preds: Predictions indexed as ``preds[:, i]`` for member ``i``,
                i.e. a 2-D tensor with the batch on dim 0 and at least
                ``ensemble_size`` columns.
            target: Realised values, one per sample. A ``(batch, 1)`` column
                vector is accepted and flattened to ``(batch,)``.

        Returns:
            A scalar loss tensor.
        """
        # A (batch, 1) target would broadcast against a (batch,) prediction
        # into a (batch, batch) matrix and silently corrupt the loss.
        if target.dim() == 2 and target.shape[1] == 1:
            target = target.squeeze(1)

        if preds.shape[0] < 10:
            # Very small batch: fall back to plain MSE over all members.
            return ((preds - target.unsqueeze(1)) ** 2).mean()

        # Standardised target, shared by every member's IC term.
        target_norm = (target - target.mean()) / (target.std(unbiased=False) + 1e-8)

        # The sign masks depend on the target only, so build them once.
        is_winner = target > 0
        is_loser = target < 0

        total_loss = 0.0
        for i in range(self.ensemble_size):
            pred_i = preds[:, i]

            # --- 1. Asymmetric Huber term ---
            error = pred_i - target
            abs_error = torch.abs(error)

            # Standard Huber: quadratic up to delta, linear beyond it.
            quadratic = torch.clamp(abs_error, max=self.delta)
            linear = abs_error - quadratic
            base_huber = 0.5 * quadratic**2 + self.delta * linear

            # Asymmetric weights (the two masks are mutually exclusive):
            #   target > 0 and pred < target -> missed winner       -> 1.5
            #   target < 0 and pred > 0      -> loser called a gain -> 2.0
            #   anything else                                       -> 1.0
            weights = torch.ones_like(target)
            weights[is_winner & (error < 0)] = 1.5
            weights[is_loser & (pred_i > 0)] = 2.0

            weighted_huber = (base_huber * weights).mean()

            # --- 2. IC term (keeps cross-sectional ranking ability) ---
            pred_norm = (pred_i - pred_i.mean()) / (pred_i.std(unbiased=False) + 1e-8)
            ic = (pred_norm * target_norm).mean()
            ic_loss = 1.0 - ic

            total_loss += self.alpha * weighted_huber + (1.0 - self.alpha) * ic_loss

        return total_loss / self.ensemble_size
|
||||||
|
|||||||
@@ -19,7 +19,10 @@ from tabm import TabM
|
|||||||
|
|
||||||
from src.training.components.base import BaseModel
|
from src.training.components.base import BaseModel
|
||||||
from src.training.components.models.cross_section_sampler import CrossSectionSampler
|
from src.training.components.models.cross_section_sampler import CrossSectionSampler
|
||||||
from src.training.components.models.ensemble_quant_loss import EnsembleQuantLoss
|
from src.training.components.models.ensemble_quant_loss import (
|
||||||
|
EnsembleQuantLoss,
|
||||||
|
AsymmetricQuantLoss,
|
||||||
|
)
|
||||||
from src.training.registry import register_model
|
from src.training.registry import register_model
|
||||||
|
|
||||||
|
|
||||||
@@ -235,8 +238,8 @@ class TabMModel(BaseModel):
|
|||||||
optimizer, T_max=epochs, eta_min=1e-6
|
optimizer, T_max=epochs, eta_min=1e-6
|
||||||
)
|
)
|
||||||
|
|
||||||
# 使用 EnsembleQuantLoss 替代 MSE
|
# 使用 AsymmetricQuantLoss (非对称 Huber + IC)
|
||||||
self.criterion = EnsembleQuantLoss(
|
self.criterion = AsymmetricQuantLoss(
|
||||||
alpha=self.params.get("loss_alpha", 0.5), ensemble_size=ensemble_size
|
alpha=self.params.get("loss_alpha", 0.5), ensemble_size=ensemble_size
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
250
tests/test_moneyflow.py
Normal file
250
tests/test_moneyflow.py
Normal file
@@ -0,0 +1,250 @@
|
|||||||
|
"""Tests for api_moneyflow module.
|
||||||
|
|
||||||
|
Tests the moneyflow (个股资金流向) API wrapper.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import pandas as pd
|
||||||
|
from unittest.mock import patch, MagicMock
|
||||||
|
|
||||||
|
from src.data.api_wrappers.api_moneyflow import (
|
||||||
|
get_moneyflow,
|
||||||
|
sync_moneyflow,
|
||||||
|
preview_moneyflow_sync,
|
||||||
|
MoneyflowSync,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestMoneyflowAPI:
    """Tests for ``get_moneyflow`` with ``TushareClient`` patched out."""

    @patch("src.data.api_wrappers.api_moneyflow.TushareClient")
    def test_get_moneyflow_by_date(self, mock_client_class):
        """Query by trade date: two rows come back and the call is forwarded."""
        payload = {
            "ts_code": ["000001.SZ", "000002.SZ"],
            "trade_date": ["20240115", "20240115"],
            "buy_sm_vol": [10000, 20000],
            "buy_sm_amount": [100.5, 200.5],
            "sell_sm_vol": [8000, 15000],
            "sell_sm_amount": [80.5, 150.5],
            "buy_md_vol": [5000, 10000],
            "buy_md_amount": [500.5, 1000.5],
            "sell_md_vol": [4000, 8000],
            "sell_md_amount": [400.5, 800.5],
            "buy_lg_vol": [2000, 5000],
            "buy_lg_amount": [2000.5, 5000.5],
            "sell_lg_vol": [1500, 4000],
            "sell_lg_amount": [1500.5, 4000.5],
            "buy_elg_vol": [1000, 3000],
            "buy_elg_amount": [5000.5, 15000.5],
            "sell_elg_vol": [800, 2500],
            "sell_elg_amount": [4000.5, 12500.5],
            "net_mf_vol": [2700, 8000],
            "net_mf_amount": [220.0, 550.0],
        }
        fake_client = MagicMock()
        fake_client.query.return_value = pd.DataFrame(payload)
        mock_client_class.return_value = fake_client

        frame = get_moneyflow(trade_date="20240115")

        assert not frame.empty
        assert len(frame) == 2
        for column in ("ts_code", "trade_date", "buy_sm_vol", "net_mf_amount"):
            assert column in frame.columns
        fake_client.query.assert_called_once_with("moneyflow", trade_date="20240115")

    @patch("src.data.api_wrappers.api_moneyflow.TushareClient")
    def test_get_moneyflow_by_stock(self, mock_client_class):
        """Query by stock code and date range: kwargs are forwarded as given."""
        payload = {
            "ts_code": ["000001.SZ", "000001.SZ"],
            "trade_date": ["20240114", "20240115"],
            "buy_sm_vol": [10000, 11000],
            "buy_sm_amount": [100.5, 110.5],
            "sell_sm_vol": [8000, 8500],
            "sell_sm_amount": [80.5, 85.5],
            "buy_md_vol": [5000, 5500],
            "buy_md_amount": [500.5, 550.5],
            "sell_md_vol": [4000, 4200],
            "sell_md_amount": [400.5, 420.5],
            "buy_lg_vol": [2000, 2200],
            "buy_lg_amount": [2000.5, 2200.5],
            "sell_lg_vol": [1500, 1600],
            "sell_lg_amount": [1500.5, 1600.5],
            "buy_elg_vol": [1000, 1100],
            "buy_elg_amount": [5000.5, 5500.5],
            "sell_elg_vol": [800, 850],
            "sell_elg_amount": [4000.5, 4250.5],
            "net_mf_vol": [2700, 2950],
            "net_mf_amount": [220.0, 245.0],
        }
        fake_client = MagicMock()
        fake_client.query.return_value = pd.DataFrame(payload)
        mock_client_class.return_value = fake_client

        query_kwargs = {
            "ts_code": "000001.SZ",
            "start_date": "20240101",
            "end_date": "20240131",
        }
        frame = get_moneyflow(**query_kwargs)

        assert not frame.empty
        assert len(frame) == 2
        fake_client.query.assert_called_once_with("moneyflow", **query_kwargs)

    @patch("src.data.api_wrappers.api_moneyflow.TushareClient")
    def test_get_moneyflow_empty_response(self, mock_client_class):
        """An empty frame from the client yields an empty result."""
        fake_client = MagicMock()
        fake_client.query.return_value = pd.DataFrame()
        mock_client_class.return_value = fake_client

        frame = get_moneyflow(trade_date="20240101")

        assert frame.empty

    @patch("src.data.api_wrappers.api_moneyflow.TushareClient")
    def test_get_moneyflow_with_shared_client(self, mock_client_class):
        """A caller-supplied client is used and no new client is constructed."""
        shared = MagicMock()
        shared.query.return_value = pd.DataFrame(
            {
                "ts_code": ["000001.SZ"],
                "trade_date": ["20240115"],
                "buy_sm_vol": [10000],
                "net_mf_amount": [220.0],
            }
        )

        frame = get_moneyflow(trade_date="20240115", client=shared)

        assert not frame.empty
        shared.query.assert_called_once_with("moneyflow", trade_date="20240115")
        # The shared client must be reused instead of instantiating another.
        mock_client_class.assert_not_called()

    @patch("src.data.api_wrappers.api_moneyflow.TushareClient")
    def test_get_moneyflow_date_column_rename(self, mock_client_class):
        """A ``date`` column in the response is exposed as ``trade_date``."""
        # The response deliberately uses 'date' rather than 'trade_date'.
        payload = {
            "ts_code": ["000001.SZ"],
            "date": ["20240115"],
            "buy_sm_vol": [10000],
            "net_mf_amount": [220.0],
        }
        fake_client = MagicMock()
        fake_client.query.return_value = pd.DataFrame(payload)
        mock_client_class.return_value = fake_client

        frame = get_moneyflow(trade_date="20240115")

        assert "trade_date" in frame.columns
        assert "date" not in frame.columns
        assert frame["trade_date"].iloc[0] == "20240115"
|
||||||
|
|
||||||
|
|
||||||
|
class TestMoneyflowSync:
    """Tests for ``MoneyflowSync`` and its module-level convenience wrappers."""

    def test_moneyflow_sync_class_attributes(self):
        """Class-level table name, start date, schema keys and primary key."""
        assert MoneyflowSync.table_name == "moneyflow"
        assert MoneyflowSync.default_start_date == "20100101"
        for column in ("ts_code", "trade_date", "buy_sm_vol", "net_mf_amount"):
            assert column in MoneyflowSync.TABLE_SCHEMA
        assert MoneyflowSync.PRIMARY_KEY == ("ts_code", "trade_date")

    @patch("src.data.api_wrappers.api_moneyflow.get_moneyflow")
    @patch("src.data.api_wrappers.api_moneyflow.TushareClient")
    def test_fetch_single_date(self, mock_client_class, mock_get_moneyflow):
        """``fetch_single_date`` delegates to ``get_moneyflow`` with its client."""
        one_row = {
            "ts_code": ["000001.SZ"],
            "trade_date": ["20240115"],
            "buy_sm_vol": [10000],
            "net_mf_amount": [220.0],
        }
        mock_get_moneyflow.return_value = pd.DataFrame(one_row)

        syncer = MoneyflowSync()
        frame = syncer.fetch_single_date("20240115")

        assert not frame.empty
        mock_get_moneyflow.assert_called_once_with(
            trade_date="20240115", client=syncer.client
        )

    @patch("src.data.api_wrappers.api_moneyflow.MoneyflowSync.sync_all")
    def test_sync_moneyflow_function(self, mock_sync_all):
        """``sync_moneyflow`` forwards the range to ``sync_all`` with force_full=False."""
        mock_sync_all.return_value = pd.DataFrame(
            {"ts_code": ["000001.SZ"], "trade_date": ["20240115"]}
        )

        frame = sync_moneyflow(start_date="20240101", end_date="20240131")

        assert not frame.empty
        mock_sync_all.assert_called_once_with(
            start_date="20240101", end_date="20240131", force_full=False
        )

    @patch("src.data.api_wrappers.api_moneyflow.MoneyflowSync.preview_sync")
    def test_preview_moneyflow_sync_function(self, mock_preview_sync):
        """``preview_moneyflow_sync`` forwards to ``preview_sync`` with defaults."""
        preview = {
            "sync_needed": True,
            "date_count": 31,
            "start_date": "20240101",
            "end_date": "20240131",
            "estimated_records": 10000,
            "sample_data": pd.DataFrame(),
            "mode": "incremental",
        }
        mock_preview_sync.return_value = preview

        summary = preview_moneyflow_sync(start_date="20240101", end_date="20240131")

        assert summary["sync_needed"] is True
        assert summary["date_count"] == 31
        mock_preview_sync.assert_called_once_with(
            start_date="20240101", end_date="20240131", force_full=False, sample_size=3
        )
|
||||||
Reference in New Issue
Block a user