feat(data): 添加个股资金流向接口并重构速率限制配置

- 新增 moneyflow 资金流向数据同步模块
- 实现接口级速率限制配置(sync_config.py)
- 更新流动性相关因子定义
- 添加非对称量化损失函数
This commit is contained in:
2026-04-03 23:57:47 +08:00
parent c143815443
commit 9e7d4241c6
18 changed files with 1473 additions and 334 deletions

12
config/sync_config.json Normal file
View File

@@ -0,0 +1,12 @@
{
"default": {
"rate_limit": 150,
"threads": 1
},
"interfaces": {
"pro_bar": {
"rate_limit": 500,
"threads": 8
}
}
}

View File

@@ -749,4 +749,114 @@ df = pro.cyq_perf(ts_code='600000.SH', start_date='20220101', end_date='20220429
73 600000.SH 20220107 0.72 12.16 8.60 11.36 9.89 7.59
74 600000.SH 20220106 0.72 12.16 8.60 11.36 9.89 3.92
75 600000.SH 20220105 0.72 12.16 8.60 11.36 9.89 5.65
76 600000.SH 20220104 0.72 12.16 8.60 11.36 9.89 3.93
76 600000.SH 20220104 0.72 12.16 8.60 11.36 9.89 3.93
个股资金流向
接口moneyflow可以通过数据工具调试和查看数据。
描述获取沪深A股票资金流向数据分析大单小单成交情况用于判别资金动向数据开始于2010年。
限量单次最大提取6000行记录总量不限制
积分用户需要至少2000积分才可以调取基础积分有流量控制积分越多权限越大请自行提高积分具体请参阅积分获取办法
输入参数
名称 类型 必选 描述
ts_code str N 股票代码 (股票和时间参数至少输入一个)
trade_date str N 交易日期
start_date str N 开始日期
end_date str N 结束日期
输出参数
名称 类型 默认显示 描述
ts_code str Y TS代码
trade_date str Y 交易日期
buy_sm_vol int Y 小单买入量(手)
buy_sm_amount float Y 小单买入金额(万元)
sell_sm_vol int Y 小单卖出量(手)
sell_sm_amount float Y 小单卖出金额(万元)
buy_md_vol int Y 中单买入量(手)
buy_md_amount float Y 中单买入金额(万元)
sell_md_vol int Y 中单卖出量(手)
sell_md_amount float Y 中单卖出金额(万元)
buy_lg_vol int Y 大单买入量(手)
buy_lg_amount float Y 大单买入金额(万元)
sell_lg_vol int Y 大单卖出量(手)
sell_lg_amount float Y 大单卖出金额(万元)
buy_elg_vol int Y 特大单买入量(手)
buy_elg_amount float Y 特大单买入金额(万元)
sell_elg_vol int Y 特大单卖出量(手)
sell_elg_amount float Y 特大单卖出金额(万元)
net_mf_vol int Y 净流入量(手)
net_mf_amount float Y 净流入额(万元)
各类别统计规则如下:
小单5万以下 中单5万20万 大单20万100万 特大单:成交额>=100万 ,数据基于主动买卖单统计
接口示例
pro = ts.pro_api('your token')
#获取单日全部股票数据
df = pro.moneyflow(trade_date='20190315')
#获取单个股票数据
df = pro.moneyflow(ts_code='002149.SZ', start_date='20190115', end_date='20190315')
数据示例
ts_code trade_date buy_sm_vol buy_sm_amount sell_sm_vol \
0 000779.SZ 20190315 11377 1150.17 11100
1 000933.SZ 20190315 94220 4803.22 105924
2 002270.SZ 20190315 43979 2330.96 45893
3 002319.SZ 20190315 21502 2952.88 17155
4 002604.SZ 20190315 31944 607.35 58667
5 300065.SZ 20190315 16048 2294.71 16425
6 600062.SH 20190315 55439 7432.13 65765
7 002735.SZ 20190315 3220 797.10 4598
8 300196.SZ 20190315 12534 1286.02 8340
9 300350.SZ 20190315 15346 1120.12 18853
10 600193.SH 20190315 12183 503.73 19576
11 002866.SZ 20190315 16932 2213.68 16037
12 300481.SZ 20190315 21386 4275.33 21863
13 600527.SH 20190315 115462 2975.44 79272
14 603980.SH 20190315 13957 1924.69 11718
15 600658.SH 20190315 71767 4826.73 69535
16 600812.SH 20190315 26140 1247.47 34923
17 002013.SZ 20190315 170234 12286.02 148509
18 600789.SH 20190315 211012 21644.56 150598
19 601636.SH 20190315 70737 3117.43 68073
20 000807.SZ 20190315 129668 6361.06 122077
...
sell_sm_amount buy_md_vol buy_md_amount sell_md_vol sell_md_amount \
0 1122.97 13012 1316.72 14812 1498.90
1 5411.72 135976 6935.40 154023 7863.00
2 2435.98 57679 3059.15 47279 2507.55
3 2358.68 27245 3742.52 26708 3670.05
4 1114.40 69897 1327.41 41108 781.19
5 2353.34 31232 4472.05 26771 3834.95
6 8817.75 86617 11615.40 79551 10676.99
7 1140.61 4602 1141.61 2730 676.72
8 855.45 9401 963.72 10478 1074.32
9 1380.31 24224 1770.90 21588 1577.92
10 812.58 28696 1185.17 31087 1286.11
11 2100.70 19197 2511.62 20269 2650.56
12 4379.14 31692 6345.72 32873 6578.36
13 2046.54 107103 2763.00 84883 2191.24
14 1619.33 14621 2019.41 14528 2005.69
15 4691.29 92788 6232.80 93273 6280.13
16 1669.97 38812 1855.78 39211 1874.05
17 10726.22 154979 11190.69 164090 11855.76
18 15479.08 269470 27660.18 236958 24338.36
19 3000.73 90416 3984.68 115162 5075.50
20 5999.66 175692 8627.77 178044 8751.08

View File

@@ -29,12 +29,6 @@ class Settings(BaseSettings):
root_path: str = "" # 项目根路径,默认自动检测
data_path: str = "data" # 数据存储路径,相对于 root_path
# API 速率限制(每分钟请求数)
rate_limit: int = 300
# 同步工作线程数
threads: int = 10
# 数据库配置(可选,用于未来扩展)
database_host: str = "localhost"
database_port: int = 5432

View File

@@ -14,6 +14,7 @@ Available APIs:
- api_stock_st: ST stock list (ST股票列表)
- api_stk_limit: Stock limit price (每日涨跌停价格)
- api_cyq_perf: CYQ performance (每日筹码及胜率)
- api_moneyflow: Money flow (个股资金流向)
Example:
>>> from src.data.api_wrappers import get_daily, get_stock_basic, get_trade_cal, get_bak_basic
@@ -21,6 +22,7 @@ Example:
>>> from src.data.api_wrappers import get_stock_st, sync_stock_st
>>> from src.data.api_wrappers import get_stk_limit, sync_stk_limit
>>> from src.data.api_wrappers import get_cyq_perf, sync_cyq_perf
>>> from src.data.api_wrappers import get_moneyflow, sync_moneyflow
>>> data = get_daily('000001.SZ', start_date='20240101', end_date='20240131')
>>> pro_data = get_pro_bar('000001.SZ', start_date='20240101', end_date='20240131')
>>> daily_basic = get_daily_basic(trade_date='20240101')
@@ -30,6 +32,7 @@ Example:
>>> stock_st = get_stock_st(trade_date='20240101')
>>> stk_limit = get_stk_limit(trade_date='20240101')
>>> cyq_perf = get_cyq_perf(trade_date='20240115')
>>> moneyflow = get_moneyflow(trade_date='20240115')
"""
from src.data.api_wrappers.api_daily_basic import (
@@ -77,6 +80,12 @@ from src.data.api_wrappers.api_cyq_perf import (
preview_cyq_perf_sync,
CyqPerfSync,
)
from src.data.api_wrappers.api_moneyflow import (
get_moneyflow,
sync_moneyflow,
preview_moneyflow_sync,
MoneyflowSync,
)
__all__ = [
# Daily market data
@@ -129,6 +138,11 @@ __all__ = [
"sync_cyq_perf",
"preview_cyq_perf_sync",
"CyqPerfSync",
# Moneyflow (个股资金流向)
"get_moneyflow",
"sync_moneyflow",
"preview_moneyflow_sync",
"MoneyflowSync",
]
# =============================================================================
@@ -223,6 +237,17 @@ try:
order=60,
)
# 9. Moneyflow - 个股资金流向
from src.data.api_wrappers.api_moneyflow import MoneyflowSync
sync_registry.register_class(
name="moneyflow",
sync_class=MoneyflowSync,
display_name="个股资金流向",
description="沪深A股资金流向数据分析大单小单成交情况2010年开始",
order=70,
)
except ImportError:
# sync_registry 可能不存在(首次导入),忽略
pass

View File

@@ -0,0 +1,228 @@
"""个股资金流向 (Moneyflow) interface.
Fetch A-share stock money flow data from Tushare.
This interface retrieves fund flow data analyzing large and small order transactions.
Data starts from 2010.
"""
import pandas as pd
from typing import Optional
from src.data.client import TushareClient
from src.data.api_wrappers.base_sync import DateBasedSync
def get_moneyflow(
trade_date: Optional[str] = None,
ts_code: Optional[str] = None,
start_date: Optional[str] = None,
end_date: Optional[str] = None,
client: Optional[TushareClient] = None,
) -> pd.DataFrame:
"""Fetch individual stock money flow data from Tushare.
This interface retrieves fund flow data analyzing large and small order
transactions for A-share stocks. Data starts from 2010.
Order size classification:
- Small orders (小单): < 50,000 yuan
- Medium orders (中单): 50,000 - 200,000 yuan
- Large orders (大单): 200,000 - 1,000,000 yuan
- Extra large orders (特大单): >= 1,000,000 yuan
Args:
trade_date: Specific trade date in YYYYMMDD format
ts_code: Stock code filter (optional, e.g., '000001.SZ')
start_date: Start date for date range query (YYYYMMDD format)
end_date: End date for date range query (YYYYMMDD format)
client: Optional TushareClient instance for shared rate limiting.
If None, creates a new client. For concurrent sync operations,
pass a shared client to ensure proper rate limiting.
Returns:
pd.DataFrame with columns:
- ts_code: Stock code
- trade_date: Trade date (YYYYMMDD)
- buy_sm_vol: Small order buy volume (hands)
- buy_sm_amount: Small order buy amount (10k yuan)
- sell_sm_vol: Small order sell volume (hands)
- sell_sm_amount: Small order sell amount (10k yuan)
- buy_md_vol: Medium order buy volume (hands)
- buy_md_amount: Medium order buy amount (10k yuan)
- sell_md_vol: Medium order sell volume (hands)
- sell_md_amount: Medium order sell amount (10k yuan)
- buy_lg_vol: Large order buy volume (hands)
- buy_lg_amount: Large order buy amount (10k yuan)
- sell_lg_vol: Large order sell volume (hands)
- sell_lg_amount: Large order sell amount (10k yuan)
- buy_elg_vol: Extra large order buy volume (hands)
- buy_elg_amount: Extra large order buy amount (10k yuan)
- sell_elg_vol: Extra large order sell volume (hands)
- sell_elg_amount: Extra large order sell amount (10k yuan)
- net_mf_vol: Net money flow volume (hands)
- net_mf_amount: Net money flow amount (10k yuan)
Example:
>>> # Get all stocks' money flow for a single date
>>> data = get_moneyflow(trade_date='20240115')
>>>
>>> # Get date range data for a specific stock
>>> data = get_moneyflow(ts_code='000001.SZ', start_date='20240101', end_date='20240131')
>>>
>>> # Get specific stock on specific date
>>> data = get_moneyflow(ts_code='000001.SZ', trade_date='20240115')
"""
client = client or TushareClient()
# Build parameters
params = {}
if trade_date:
params["trade_date"] = trade_date
if ts_code:
params["ts_code"] = ts_code
if start_date:
params["start_date"] = start_date
if end_date:
params["end_date"] = end_date
# Fetch data using moneyflow API
data = client.query("moneyflow", **params)
# Rename date column if needed
if "date" in data.columns:
data = data.rename(columns={"date": "trade_date"})
return data
class MoneyflowSync(DateBasedSync):
"""个股资金流向数据批量同步管理器,支持全量/增量同步。
继承自 DateBasedSync使用按日期并发获取数据。
数据始于 2010 年。
Example:
>>> sync = MoneyflowSync()
>>> results = sync.sync_all() # 增量同步
>>> results = sync.sync_all(force_full=True) # 全量同步
>>> preview = sync.preview_sync() # 预览
"""
table_name = "moneyflow"
default_start_date = "20100101"
# 表结构定义 - 使用 Tushare API 原始字段名
TABLE_SCHEMA = {
"ts_code": "VARCHAR(16) NOT NULL",
"trade_date": "DATE NOT NULL",
"buy_sm_vol": "INTEGER", # 小单买入量(手)
"buy_sm_amount": "DOUBLE", # 小单买入金额(万元)
"sell_sm_vol": "INTEGER", # 小单卖出量(手)
"sell_sm_amount": "DOUBLE", # 小单卖出金额(万元)
"buy_md_vol": "INTEGER", # 中单买入量(手)
"buy_md_amount": "DOUBLE", # 中单买入金额(万元)
"sell_md_vol": "INTEGER", # 中单卖出量(手)
"sell_md_amount": "DOUBLE", # 中单卖出金额(万元)
"buy_lg_vol": "INTEGER", # 大单买入量(手)
"buy_lg_amount": "DOUBLE", # 大单买入金额(万元)
"sell_lg_vol": "INTEGER", # 大单卖出量(手)
"sell_lg_amount": "DOUBLE", # 大单卖出金额(万元)
"buy_elg_vol": "INTEGER", # 特大单买入量(手)
"buy_elg_amount": "DOUBLE", # 特大单买入金额(万元)
"sell_elg_vol": "INTEGER", # 特大单卖出量(手)
"sell_elg_amount": "DOUBLE", # 特大单卖出金额(万元)
"net_mf_vol": "INTEGER", # 净流入量(手)
"net_mf_amount": "DOUBLE", # 净流入额(万元)
}
# 索引定义
TABLE_INDEXES = [
("idx_moneyflow_date_code", ["trade_date", "ts_code"]),
]
# 主键定义
PRIMARY_KEY = ("ts_code", "trade_date")
def fetch_single_date(self, trade_date: str) -> pd.DataFrame:
"""获取单日所有股票的资金流向数据。
Args:
trade_date: 交易日期YYYYMMDD
Returns:
包含当日所有股票资金流向数据的 DataFrame
"""
return get_moneyflow(trade_date=trade_date, client=self.client)
def sync_moneyflow(
start_date: Optional[str] = None,
end_date: Optional[str] = None,
force_full: bool = False,
) -> pd.DataFrame:
"""同步个股资金流向数据到 DuckDB支持智能增量同步。
逻辑:
- 若表不存在:创建表 + 复合索引 (trade_date, ts_code) + 全量同步
- 若表存在:从 last_date + 1 开始增量同步
Args:
start_date: 起始日期YYYYMMDD 格式,默认全量从 20100101增量从 last_date+1
end_date: 结束日期YYYYMMDD 格式,默认为今天)
force_full: 若为 True强制从 20100101 完整重载
Returns:
包含同步数据的 pd.DataFrame
Example:
>>> # 首次同步(从 20100101 全量加载)
>>> result = sync_moneyflow()
>>>
>>> # 后续同步(增量 - 仅新数据)
>>> result = sync_moneyflow()
>>>
>>> # 强制完整重载
>>> result = sync_moneyflow(force_full=True)
>>>
>>> # 手动指定日期范围
>>> result = sync_moneyflow(start_date='20240101', end_date='20240131')
"""
sync_manager = MoneyflowSync()
return sync_manager.sync_all(
start_date=start_date,
end_date=end_date,
force_full=force_full,
)
def preview_moneyflow_sync(
start_date: Optional[str] = None,
end_date: Optional[str] = None,
force_full: bool = False,
sample_size: int = 3,
) -> dict:
"""预览个股资金流向数据同步数据量和样本(不实际同步)。
Args:
start_date: 手动指定起始日期(覆盖自动检测)
end_date: 手动指定结束日期(默认为今天)
force_full: 若为 True预览全量同步从 20100101
sample_size: 预览天数(默认: 3
Returns:
包含预览信息的字典
Example:
>>> # 预览将要同步的内容
>>> preview = preview_moneyflow_sync()
>>>
>>> # 预览全量同步
>>> preview = preview_moneyflow_sync(force_full=True)
"""
sync_manager = MoneyflowSync()
return sync_manager.preview_sync(
start_date=start_date,
end_date=end_date,
force_full=force_full,
sample_size=sample_size,
)

View File

@@ -35,8 +35,8 @@ from tqdm import tqdm
from src.data.client import TushareClient
from src.data.storage import ThreadSafeStorage, Storage
from src.data.sync_logger import SyncLogManager
from src.data.sync_config import get_threads
from src.data.utils import get_today_date, get_next_date
from src.config.settings import get_settings
from src.data.api_wrappers.api_trade_cal import (
get_first_trading_day,
get_last_trading_day,
@@ -63,7 +63,6 @@ class BaseDataSync(ABC):
table_name: str = "" # 子类必须覆盖
DEFAULT_START_DATE = "20180101"
DEFAULT_MAX_WORKERS = get_settings().threads
# 表结构定义(子类可覆盖)
# 格式: {"column_name": "SQL_TYPE", ...}
@@ -81,11 +80,14 @@ class BaseDataSync(ABC):
"""初始化同步管理器。
Args:
max_workers: 工作线程数(默认从配置读取)
max_workers: 工作线程数(默认从配置读取,根据 table_name 获取
"""
self.storage = ThreadSafeStorage()
self.client = TushareClient()
self.max_workers = max_workers or self.DEFAULT_MAX_WORKERS
# 使用 table_name 作为接口名称初始化客户端,获取特定的速率限制
self.client = TushareClient(interface_name=self.table_name or "default")
# 根据表名获取线程数配置
default_workers = get_threads(self.table_name or "default")
self.max_workers = max_workers or default_workers
self._stop_flag = threading.Event()
self._stop_flag.set() # 初始为未停止状态
self._cached_data: Optional[pd.DataFrame] = None

View File

@@ -37,6 +37,7 @@
from typing import List, Optional
from src.data.sync_logger import SyncLogManager
from src.data.sync_config import get_rate_limit, get_threads
from src.data.api_wrappers.financial_data.api_income import (
IncomeQuarterSync,
sync_income,
@@ -151,7 +152,12 @@ def sync_financial(
sync_func = config["sync_func"]
display_name = config["display_name"]
# 获取当前接口的 rate limit
rate_limit = get_rate_limit(data_type)
threads = get_threads(data_type)
print(f"\n[{display_name}] 开始同步...")
print(f" Rate Limit: {rate_limit}/min, Threads: {threads}")
try:
result = sync_func(force_full=force_full, dry_run=dry_run)

View File

@@ -2,23 +2,29 @@
import time
import pandas as pd
import tushare as ts
from typing import Optional
from src.data.rate_limiter import TokenBucketRateLimiter
from src.data.sync_config import get_rate_limit
from src.config.settings import get_settings
class TushareClient:
"""Tushare API client with rate limiting and retry."""
# 类级别共享限流器(确保所有实例共享同一个限流器
_shared_limiter: Optional[TokenBucketRateLimiter] = None
_cached_rate_limit: int = 0 # 缓存上次使用的 rate_limit
# 类级别限流器缓存(按接口名称存储
_limiter_cache: dict[str, TokenBucketRateLimiter] = {}
_cached_rate_limits: dict[str, int] = {}
def __init__(self, token: Optional[str] = None):
def __init__(
self, token: Optional[str] = None, interface_name: Optional[str] = None
):
"""Initialize client.
Args:
token: Tushare API token (auto-loaded from config if not provided)
interface_name: 接口名称,用于获取特定的速率限制配置
"""
cfg = get_settings()
token = token or cfg.tushare_token
@@ -28,32 +34,57 @@ class TushareClient:
self.token = token
self.config = cfg
self.interface_name = interface_name or "default"
# 初始化共享限流器(确保所有 TushareClient 实例共享同一个限流器)
# 检查是否需要重新创建限流器(配置发生变化时)
if (
TushareClient._shared_limiter is None
or TushareClient._cached_rate_limit != cfg.rate_limit
):
# 首次创建或配置变更:重新初始化共享限流器
TushareClient._shared_limiter = TokenBucketRateLimiter(
rate_limit=cfg.rate_limit,
)
TushareClient._cached_rate_limit = cfg.rate_limit
min_interval = 60.0 / cfg.rate_limit
print(
f"[TushareClient] Initialized shared rate limiter: rate={cfg.rate_limit}/min, interval={min_interval:.2f}s"
)
# 复用共享限流器
self.rate_limiter = TushareClient._shared_limiter
# 获取或创建限流器
self.rate_limiter = self._get_or_create_limiter(self.interface_name)
self._api = None
def _get_or_create_limiter(self, interface_name: str) -> TokenBucketRateLimiter:
"""获取或创建指定接口的限流器。
如果接口的 rate_limit 配置发生变化,会重新创建限流器。
Args:
interface_name: 接口名称
Returns:
TokenBucketRateLimiter 实例
"""
# 获取当前配置的 rate_limit
current_rate_limit = get_rate_limit(interface_name)
# 检查是否需要创建新的限流器
if (
interface_name not in TushareClient._limiter_cache
or TushareClient._cached_rate_limits.get(interface_name)
!= current_rate_limit
):
# 创建新的限流器
TushareClient._limiter_cache[interface_name] = TokenBucketRateLimiter(
rate_limit=current_rate_limit,
)
TushareClient._cached_rate_limits[interface_name] = current_rate_limit
min_interval = 60.0 / current_rate_limit
print(
f"[TushareClient] Initialized rate limiter for '{interface_name}': "
f"rate={current_rate_limit}/min, interval={min_interval:.2f}s"
)
return TushareClient._limiter_cache[interface_name]
def get_current_rate_limit(self) -> int:
"""获取当前接口的速率限制。
Returns:
每分钟请求数
"""
return get_rate_limit(self.interface_name)
def _get_api(self):
"""Get Tushare API instance."""
if self._api is None:
import tushare as ts
self._api = ts.pro_api(self.token)
return self._api
@@ -80,8 +111,6 @@ class TushareClient:
raise RuntimeError(f"Rate limit exceeded after {timeout}s timeout")
try:
import tushare as ts
# pro_bar uses ts.pro_bar() instead of api.query()
if api_name == "pro_bar":
# pro_bar parameters: ts_code, start_date, end_date, adj, freq, factors, ma, adjfactor

View File

@@ -13,6 +13,7 @@
- api_stock_st.py: ST股票列表同步 (StockSTSync 类)
- api_stk_limit.py: 涨跌停价格同步 (StkLimitSync 类)
- api_cyq_perf.py: 筹码分布数据同步 (CyqPerfSync 类)
- api_moneyflow.py: 个股资金流向同步 (MoneyflowSync 类)
- api_stock_basic.py: 股票基本信息同步
- api_trade_cal.py: 交易日历同步
@@ -82,6 +83,7 @@ def sync_all_data(
6. stock_st: ST股票列表
7. stk_limit: 每日涨跌停价格
8. cyq_perf: 每日筹码及胜率
9. moneyflow: 个股资金流向
新增接口时,只需在 api_wrappers/__init__.py 中添加注册代码,
无需修改本函数。

431
src/data/sync_config.py Normal file
View File

@@ -0,0 +1,431 @@
"""同步配置管理模块。
该模块提供统一的同步配置管理,支持:
- 默认的线程数和速率限制配置
- 每个接口单独的速率限制配置
- 从 JSON 配置文件加载配置
配置文件路径: config/sync_config.json
配置示例:
{
"default": {
"threads": 10,
"rate_limit": 300
},
"interfaces": {
"pro_bar": {
"rate_limit": 200
},
"daily_basic": {
"rate_limit": 150
},
"income": {
"rate_limit": 100
}
}
}
"""
import json
import threading
from dataclasses import dataclass, field, asdict
from pathlib import Path
from typing import Optional, Dict, Any
from functools import lru_cache
@dataclass
class InterfaceConfig:
"""单个接口的配置。"""
rate_limit: Optional[int] = None
threads: Optional[int] = None
@dataclass
class SyncConfig:
"""同步配置类。"""
default: InterfaceConfig = field(
default_factory=lambda: InterfaceConfig(
rate_limit=300,
threads=10,
)
)
interfaces: Dict[str, InterfaceConfig] = field(default_factory=dict)
def get_interface_config(self, interface_name: str) -> InterfaceConfig:
"""获取指定接口的配置。
如果接口有特定配置则返回,否则返回默认配置。
Args:
interface_name: 接口名称
Returns:
接口配置对象
"""
if interface_name in self.interfaces:
interface_config = self.interfaces[interface_name]
# 合并默认值(特定配置覆盖默认)
return InterfaceConfig(
rate_limit=interface_config.rate_limit
if interface_config.rate_limit is not None
else self.default.rate_limit,
threads=interface_config.threads
if interface_config.threads is not None
else self.default.threads,
)
return self.default
def get_rate_limit(self, interface_name: str) -> int:
"""获取指定接口的速率限制。
Args:
interface_name: 接口名称
Returns:
每分钟请求数
"""
config = self.get_interface_config(interface_name)
return config.rate_limit or self.default.rate_limit or 300
def get_threads(self, interface_name: str) -> int:
"""获取指定接口的线程数。
Args:
interface_name: 接口名称
Returns:
线程数
"""
config = self.get_interface_config(interface_name)
return config.threads or self.default.threads or 10
def to_dict(self) -> Dict[str, Any]:
"""转换为字典格式。"""
return {
"default": {
"rate_limit": self.default.rate_limit,
"threads": self.default.threads,
},
"interfaces": {
name: {
"rate_limit": config.rate_limit,
"threads": config.threads,
}
for name, config in self.interfaces.items()
},
}
class SyncConfigManager:
"""同步配置管理器。
负责加载、保存和管理同步配置。
支持从 JSON 文件加载配置,提供线程安全的单例访问。
"""
_instance: Optional["SyncConfigManager"] = None
_lock = threading.Lock()
DEFAULT_CONFIG = {
"default": {
"threads": 10,
"rate_limit": 300,
},
"interfaces": {
# 高频接口(每秒可请求多次)
"trade_cal": {"rate_limit": 500},
"stock_basic": {"rate_limit": 500},
"bak_basic": {"rate_limit": 200},
"stock_st": {"rate_limit": 200},
"stk_limit": {"rate_limit": 200},
"cyq_perf": {"rate_limit": 150},
"moneyflow": {"rate_limit": 150},
"daily_basic": {"rate_limit": 150},
"pro_bar": {"rate_limit": 200},
# 低频财务接口
"income": {"rate_limit": 100},
"balance": {"rate_limit": 100},
"cashflow": {"rate_limit": 100},
"fina_indicator": {"rate_limit": 100},
"namechange": {"rate_limit": 200},
},
}
def __new__(cls, config_path: Optional[str] = None) -> "SyncConfigManager":
"""单例模式确保全局只有一个配置管理器实例。"""
if cls._instance is None:
with cls._lock:
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance._initialized = False
return cls._instance
def __init__(self, config_path: Optional[str] = None):
"""初始化配置管理器。
Args:
config_path: 配置文件路径,默认使用项目根目录下的 config/sync_config.json
"""
# 避免重复初始化
if self._initialized:
return
if config_path is None:
# 默认路径:项目根目录/config/sync_config.json
project_root = Path(__file__).parent.parent.parent
self.config_path = project_root / "config" / "sync_config.json"
else:
self.config_path = Path(config_path)
self._config: Optional[SyncConfig] = None
self._config_lock = threading.RLock()
self._initialized = True
def _load_config(self) -> SyncConfig:
"""从文件加载配置。
如果配置文件不存在,则创建默认配置。
Returns:
同步配置对象
"""
with self._config_lock:
if self._config is not None:
return self._config
if self.config_path.exists():
try:
with open(self.config_path, "r", encoding="utf-8") as f:
data = json.load(f)
default_data = data.get("default", {})
default_config = InterfaceConfig(
rate_limit=default_data.get("rate_limit", 300),
threads=default_data.get("threads", 10),
)
interfaces = {}
for name, config_data in data.get("interfaces", {}).items():
interfaces[name] = InterfaceConfig(
rate_limit=config_data.get("rate_limit"),
threads=config_data.get("threads"),
)
self._config = SyncConfig(
default=default_config,
interfaces=interfaces,
)
print(f"[SyncConfig] Loaded config from {self.config_path}")
except Exception as e:
print(f"[SyncConfig] Error loading config: {e}, using defaults")
self._config = self._create_default_config()
else:
self._config = self._create_default_config()
self.save_config()
return self._config
def _create_default_config(self) -> SyncConfig:
"""创建默认配置。"""
default = InterfaceConfig(
rate_limit=self.DEFAULT_CONFIG["default"]["rate_limit"],
threads=self.DEFAULT_CONFIG["default"]["threads"],
)
interfaces = {}
for name, config_data in self.DEFAULT_CONFIG["interfaces"].items():
interfaces[name] = InterfaceConfig(
rate_limit=config_data.get("rate_limit"),
threads=config_data.get("threads"),
)
return SyncConfig(default=default, interfaces=interfaces)
def save_config(self) -> None:
"""保存当前配置到文件。"""
with self._config_lock:
if self._config is None:
self._config = self._create_default_config()
try:
self.config_path.parent.mkdir(parents=True, exist_ok=True)
with open(self.config_path, "w", encoding="utf-8") as f:
json.dump(self._config.to_dict(), f, indent=2, ensure_ascii=False)
print(f"[SyncConfig] Saved config to {self.config_path}")
except Exception as e:
print(f"[SyncConfig] Error saving config: {e}")
def get_config(self) -> SyncConfig:
"""获取当前配置。
Returns:
同步配置对象
"""
with self._config_lock:
if self._config is None:
return self._load_config()
return self._config
def get_rate_limit(self, interface_name: str) -> int:
"""获取指定接口的速率限制。
Args:
interface_name: 接口名称
Returns:
每分钟请求数
"""
return self.get_config().get_rate_limit(interface_name)
def get_threads(self, interface_name: str) -> int:
"""获取指定接口的线程数。
Args:
interface_name: 接口名称
Returns:
线程数
"""
return self.get_config().get_threads(interface_name)
def update_interface_config(
self,
interface_name: str,
rate_limit: Optional[int] = None,
threads: Optional[int] = None,
) -> None:
"""更新指定接口的配置。
Args:
interface_name: 接口名称
rate_limit: 新的速率限制None 表示不修改)
threads: 新的线程数None 表示不修改)
"""
with self._config_lock:
config = self.get_config()
if interface_name not in config.interfaces:
config.interfaces[interface_name] = InterfaceConfig()
interface_config = config.interfaces[interface_name]
if rate_limit is not None:
interface_config.rate_limit = rate_limit
if threads is not None:
interface_config.threads = threads
self.save_config()
def reset_to_defaults(self) -> None:
"""重置为默认配置。"""
with self._config_lock:
self._config = self._create_default_config()
self.save_config()
print("[SyncConfig] Reset to default configuration")
def list_interfaces(self) -> Dict[str, Dict[str, Any]]:
"""列出所有接口的配置。
Returns:
接口配置字典 {接口名: {rate_limit, threads}}
"""
config = self.get_config()
result = {}
# 添加默认配置
result["__default__"] = {
"rate_limit": config.default.rate_limit,
"threads": config.default.threads,
}
# 添加各接口配置
all_interfaces = set(config.interfaces.keys()) | set(
self.DEFAULT_CONFIG["interfaces"].keys()
)
for name in sorted(all_interfaces):
iface_config = config.get_interface_config(name)
result[name] = {
"rate_limit": iface_config.rate_limit,
"threads": iface_config.threads,
}
return result
# 全局配置管理器实例
_config_manager: Optional[SyncConfigManager] = None
_config_manager_lock = threading.Lock()
def get_sync_config_manager() -> SyncConfigManager:
"""获取同步配置管理器单例。
Returns:
SyncConfigManager 实例
"""
global _config_manager
if _config_manager is None:
with _config_manager_lock:
if _config_manager is None:
_config_manager = SyncConfigManager()
return _config_manager
def get_sync_config() -> SyncConfig:
"""获取同步配置。
Returns:
SyncConfig 实例
"""
return get_sync_config_manager().get_config()
def get_rate_limit(interface_name: str) -> int:
"""获取指定接口的速率限制。
Args:
interface_name: 接口名称
Returns:
每分钟请求数
"""
return get_sync_config_manager().get_rate_limit(interface_name)
def get_threads(interface_name: str) -> int:
"""获取指定接口的线程数。
Args:
interface_name: 接口名称
Returns:
线程数
"""
return get_sync_config_manager().get_threads(interface_name)
def print_sync_config() -> None:
"""打印当前同步配置(用于调试)。"""
manager = get_sync_config_manager()
configs = manager.list_interfaces()
print("\n" + "=" * 60)
print("[SyncConfig] Current Configuration")
print("=" * 60)
default_config = configs.pop("__default__", {})
print(
f"Default: rate_limit={default_config.get('rate_limit')}/min, threads={default_config.get('threads')}"
)
print("-" * 60)
for name, config in sorted(configs.items()):
marker = "*" if name in manager.DEFAULT_CONFIG["interfaces"] else " "
print(
f" [{marker}] {name:20s}: rate_limit={config['rate_limit']:3d}/min, threads={config['threads']:2d}"
)
print("=" * 60)

View File

@@ -28,6 +28,8 @@ from collections import OrderedDict
import pandas as pd
from src.data.sync_config import get_rate_limit, get_threads
@dataclass
class SyncTask:
@@ -221,7 +223,12 @@ class SyncRegistry:
print("=" * 60)
for idx, task in enumerate(tasks, 1):
# 获取当前接口的配置
rate_limit = get_rate_limit(task.name)
threads = get_threads(task.name)
print(f"\n[{idx}/{total}] Syncing {task.display_name}...")
print(f" Rate Limit: {rate_limit}/min, Threads: {threads}")
if task.description:
print(f" Description: {task.description}")

View File

@@ -123,133 +123,133 @@ SELECTED_FACTORS = [
"GTJA_alpha048",
"GTJA_alpha049",
"GTJA_alpha050",
# "GTJA_alpha051",
# "GTJA_alpha052",
# "GTJA_alpha053",
# "GTJA_alpha054",
# "GTJA_alpha056",
# "GTJA_alpha057",
# "GTJA_alpha058",
# "GTJA_alpha059",
# "GTJA_alpha060",
# "GTJA_alpha061",
# "GTJA_alpha062",
# "GTJA_alpha063",
# "GTJA_alpha064",
# "GTJA_alpha065",
# "GTJA_alpha066",
# "GTJA_alpha067",
# "GTJA_alpha068",
# "GTJA_alpha070",
# "GTJA_alpha071",
# "GTJA_alpha072",
# "GTJA_alpha073",
# "GTJA_alpha074",
# "GTJA_alpha076",
# "GTJA_alpha077",
# "GTJA_alpha078",
# "GTJA_alpha079",
# "GTJA_alpha080",
# "GTJA_alpha081",
# "GTJA_alpha082",
# "GTJA_alpha083",
# "GTJA_alpha084",
# "GTJA_alpha085",
# "GTJA_alpha086",
# "GTJA_alpha087",
# "GTJA_alpha088",
# "GTJA_alpha089",
# "GTJA_alpha090",
# "GTJA_alpha091",
# "GTJA_alpha092",
# "GTJA_alpha093",
# "GTJA_alpha094",
# "GTJA_alpha095",
# "GTJA_alpha096",
# "GTJA_alpha097",
# "GTJA_alpha098",
# "GTJA_alpha099",
# "GTJA_alpha100",
# "GTJA_alpha101",
# "GTJA_alpha102",
# "GTJA_alpha103",
# "GTJA_alpha104",
# "GTJA_alpha105",
# "GTJA_alpha106",
# "GTJA_alpha107",
# "GTJA_alpha108",
# "GTJA_alpha109",
# "GTJA_alpha110",
# "GTJA_alpha111",
# "GTJA_alpha112",
# # "GTJA_alpha113",
# "GTJA_alpha114",
# "GTJA_alpha115",
# "GTJA_alpha117",
# "GTJA_alpha118",
# "GTJA_alpha119",
# "GTJA_alpha120",
# # "GTJA_alpha121",
# "GTJA_alpha122",
# "GTJA_alpha123",
# "GTJA_alpha124",
# "GTJA_alpha125",
# "GTJA_alpha126",
# "GTJA_alpha127",
# "GTJA_alpha128",
# "GTJA_alpha129",
# "GTJA_alpha130",
# "GTJA_alpha131",
# "GTJA_alpha132",
# "GTJA_alpha133",
# "GTJA_alpha134",
# "GTJA_alpha135",
# "GTJA_alpha136",
# # "GTJA_alpha138",
# "GTJA_alpha139",
# # "GTJA_alpha140",
# "GTJA_alpha141",
# "GTJA_alpha142",
# "GTJA_alpha145",
# # "GTJA_alpha146",
# "GTJA_alpha148",
# "GTJA_alpha150",
# "GTJA_alpha151",
# "GTJA_alpha152",
# "GTJA_alpha153",
# "GTJA_alpha154",
# "GTJA_alpha155",
# "GTJA_alpha156",
# "GTJA_alpha157",
# "GTJA_alpha158",
# "GTJA_alpha159",
# "GTJA_alpha160",
# "GTJA_alpha161",
# # "GTJA_alpha162",
# "GTJA_alpha163",
# "GTJA_alpha164",
# # "GTJA_alpha165",
# "GTJA_alpha166",
# "GTJA_alpha167",
# "GTJA_alpha168",
# "GTJA_alpha169",
# "GTJA_alpha170",
# "GTJA_alpha171",
# "GTJA_alpha173",
# "GTJA_alpha174",
# "GTJA_alpha175",
# "GTJA_alpha176",
# "GTJA_alpha177",
# "GTJA_alpha178",
# "GTJA_alpha179",
# "GTJA_alpha180",
# # "GTJA_alpha183",
# "GTJA_alpha184",
# "GTJA_alpha185",
# "GTJA_alpha187",
# "GTJA_alpha188",
# "GTJA_alpha189",
# "GTJA_alpha191",
"GTJA_alpha051",
"GTJA_alpha052",
"GTJA_alpha053",
"GTJA_alpha054",
"GTJA_alpha056",
"GTJA_alpha057",
"GTJA_alpha058",
"GTJA_alpha059",
"GTJA_alpha060",
"GTJA_alpha061",
"GTJA_alpha062",
"GTJA_alpha063",
"GTJA_alpha064",
"GTJA_alpha065",
"GTJA_alpha066",
"GTJA_alpha067",
"GTJA_alpha068",
"GTJA_alpha070",
"GTJA_alpha071",
"GTJA_alpha072",
"GTJA_alpha073",
"GTJA_alpha074",
"GTJA_alpha076",
"GTJA_alpha077",
"GTJA_alpha078",
"GTJA_alpha079",
"GTJA_alpha080",
"GTJA_alpha081",
"GTJA_alpha082",
"GTJA_alpha083",
"GTJA_alpha084",
"GTJA_alpha085",
"GTJA_alpha086",
"GTJA_alpha087",
"GTJA_alpha088",
"GTJA_alpha089",
"GTJA_alpha090",
"GTJA_alpha091",
"GTJA_alpha092",
"GTJA_alpha093",
"GTJA_alpha094",
"GTJA_alpha095",
"GTJA_alpha096",
"GTJA_alpha097",
"GTJA_alpha098",
"GTJA_alpha099",
"GTJA_alpha100",
"GTJA_alpha101",
"GTJA_alpha102",
"GTJA_alpha103",
"GTJA_alpha104",
"GTJA_alpha105",
"GTJA_alpha106",
"GTJA_alpha107",
"GTJA_alpha108",
"GTJA_alpha109",
"GTJA_alpha110",
"GTJA_alpha111",
"GTJA_alpha112",
# "GTJA_alpha113",
"GTJA_alpha114",
"GTJA_alpha115",
"GTJA_alpha117",
"GTJA_alpha118",
"GTJA_alpha119",
"GTJA_alpha120",
# "GTJA_alpha121",
"GTJA_alpha122",
"GTJA_alpha123",
"GTJA_alpha124",
"GTJA_alpha125",
"GTJA_alpha126",
"GTJA_alpha127",
"GTJA_alpha128",
"GTJA_alpha129",
"GTJA_alpha130",
"GTJA_alpha131",
"GTJA_alpha132",
"GTJA_alpha133",
"GTJA_alpha134",
"GTJA_alpha135",
"GTJA_alpha136",
# "GTJA_alpha138",
"GTJA_alpha139",
# "GTJA_alpha140",
"GTJA_alpha141",
"GTJA_alpha142",
"GTJA_alpha145",
# "GTJA_alpha146",
"GTJA_alpha148",
"GTJA_alpha150",
"GTJA_alpha151",
"GTJA_alpha152",
"GTJA_alpha153",
"GTJA_alpha154",
"GTJA_alpha155",
"GTJA_alpha156",
"GTJA_alpha157",
"GTJA_alpha158",
"GTJA_alpha159",
"GTJA_alpha160",
"GTJA_alpha161",
# "GTJA_alpha162",
"GTJA_alpha163",
"GTJA_alpha164",
# "GTJA_alpha165",
"GTJA_alpha166",
"GTJA_alpha167",
"GTJA_alpha168",
"GTJA_alpha169",
"GTJA_alpha170",
"GTJA_alpha171",
"GTJA_alpha173",
"GTJA_alpha174",
"GTJA_alpha175",
"GTJA_alpha176",
"GTJA_alpha177",
"GTJA_alpha178",
"GTJA_alpha179",
"GTJA_alpha180",
# "GTJA_alpha183",
"GTJA_alpha184",
"GTJA_alpha185",
"GTJA_alpha187",
"GTJA_alpha188",
"GTJA_alpha189",
"GTJA_alpha191",
"chip_dispersion_90",
"chip_dispersion_70",
"cost_skewness",
@@ -270,12 +270,27 @@ SELECTED_FACTORS = [
"bottom_cost_stability",
"pivot_reversion",
"chip_transition",
# "amivest_liq_20",
# "atr_price_impact",
# "hui_heubel_ratio",
# "corwin_schultz_spread_20",
# "roll_spread_20",
# "gibbs_effective_spread",
# "overnight_illiq_20",
# "illiq_volatility_20",
# "amount_cv_20",
# "amount_skewness_20",
# "low_vol_days_20",
# "liquidity_shock_momentum",
# "downside_illiq_20",
# "upside_illiq_20",
# "illiq_asymmetry_20",
# "pastor_stambaugh_proxy"
]
# 因子定义字典完整因子库用于存放尚未注册到metadata的因子
FACTOR_DEFINITIONS = {"cs_rank_circ_mv": "cs_rank(circ_mv)"}
# =============================================================================
# Label 配置(统一绑定 label_name 和 label_dsl
# =============================================================================
@@ -308,11 +323,11 @@ def get_label_factor(label_name: str) -> dict:
# 辅助函数
# =============================================================================
def register_factors(
engine: FactorEngine,
selected_factors: List[str],
factor_definitions: dict,
label_factor: dict,
excluded_factors: Optional[List[str]] = None,
engine: FactorEngine,
selected_factors: List[str],
factor_definitions: dict,
label_factor: dict,
excluded_factors: Optional[List[str]] = None,
) -> List[str]:
"""注册因子。
@@ -393,11 +408,11 @@ def register_factors(
def prepare_data(
engine: FactorEngine,
feature_cols: List[str],
start_date: str,
end_date: str,
label_name: str,
engine: FactorEngine,
feature_cols: List[str],
start_date: str,
end_date: str,
label_name: str,
) -> pl.DataFrame:
"""准备数据。
@@ -455,11 +470,11 @@ def stock_pool_filter(df: pl.DataFrame) -> pl.Series:
"""
# 代码筛选(排除创业板、科创板、北交所)
code_filter = (
~df["ts_code"].str.starts_with("30") # 排除创业板
& ~df["ts_code"].str.starts_with("68") # 排除科创板
& ~df["ts_code"].str.starts_with("8") # 排除北交所
& ~df["ts_code"].str.starts_with("9") # 排除北交所
& ~df["ts_code"].str.starts_with("4") # 排除北交所
~df["ts_code"].str.starts_with("30") # 排除创业板
& ~df["ts_code"].str.starts_with("68") # 排除科创板
& ~df["ts_code"].str.starts_with("8") # 排除北交所
& ~df["ts_code"].str.starts_with("9") # 排除北交所
& ~df["ts_code"].str.starts_with("4") # 排除北交所
)
# 在已筛选的股票中选取流通市值最小的500只
@@ -474,7 +489,6 @@ def stock_pool_filter(df: pl.DataFrame) -> pl.Series:
# 定义筛选所需的基础列
STOCK_FILTER_REQUIRED_COLUMNS = ["circ_mv"]
# =============================================================================
# 输出配置
# =============================================================================
@@ -518,7 +532,7 @@ def get_output_path(model_type: str, test_start: str, test_end: str) -> str:
def get_model_save_path(
model_type: str,
model_type: str,
) -> Optional[str]:
"""生成模型保存路径。
@@ -544,11 +558,11 @@ def get_model_save_path(
def save_model_with_factors(
model,
model_path: str,
selected_factors: list[str],
factor_definitions: dict,
fitted_processors: list | None = None,
model,
model_path: str,
selected_factors: list[str],
factor_definitions: dict,
fitted_processors: list | None = None,
) -> str:
"""保存模型及关联的因子信息和处理器。

View File

@@ -54,31 +54,27 @@ N_QUANTILES = 20
# 排除的因子列表
EXCLUDED_FACTORS = [
"GTJA_alpha010",
"GTJA_alpha005",
"GTJA_alpha002",
"GTJA_alpha027",
"GTJA_alpha051",
"GTJA_alpha044",
"GTJA_alpha041",
"GTJA_alpha131",
"GTJA_alpha103",
"GTJA_alpha087",
"GTJA_alpha093",
"GTJA_alpha092",
"GTJA_alpha073",
"GTJA_alpha127",
"GTJA_alpha117",
"GTJA_alpha124",
"GTJA_alpha162",
"GTJA_alpha177",
"GTJA_alpha188",
"smart_money_accumulation",
"GTJA_alpha014",
"GTJA_alpha056",
"GTJA_alpha085",
"GTJA_alpha154",
"GTJA_alpha141",
'active_market_cap',
'close_vwap_deviation',
'sharpe_ratio_20',
'upper_shadow_ratio',
'volume_ratio_5_20',
'GTJA_alpha090',
'GTJA_alpha084',
'GTJA_alpha066',
'GTJA_alpha150',
'GTJA_alpha148',
'GTJA_alpha106',
'GTJA_alpha109',
'GTJA_alpha108',
'GTJA_alpha176',
'GTJA_alpha169',
'GTJA_alpha156',
'chip_dispersion_70',
'winner_rate_cs_rank',
'atr_price_impact',
'low_vol_days_20',
'liquidity_shock_momentum',
]
# LambdaRank 模型参数配置

View File

@@ -52,55 +52,36 @@ TRAINING_TYPE = "regression"
# 排除的因子列表
EXCLUDED_FACTORS = [
"GTJA_alpha036",
"GTJA_alpha032",
"GTJA_alpha010",
"GTJA_alpha005",
"CP",
"BP",
"debt_to_equity",
"current_ratio",
"GTJA_alpha002",
"GTJA_alpha027",
"GTJA_alpha064",
"GTJA_alpha062",
"GTJA_alpha043",
"GTJA_alpha044",
"GTJA_alpha120",
"GTJA_alpha117",
"GTJA_alpha103",
"GTJA_alpha104",
"GTJA_alpha105",
"GTJA_alpha073",
"GTJA_alpha077",
"GTJA_alpha085",
"GTJA_alpha090",
"GTJA_alpha087",
"GTJA_alpha083",
"GTJA_alpha092",
"GTJA_alpha133",
"GTJA_alpha131",
"GTJA_alpha126",
"GTJA_alpha124",
"GTJA_alpha162",
"GTJA_alpha164",
"GTJA_alpha157",
"GTJA_alpha177",
"price_to_avg_cost",
"cost_skewness",
"GTJA_alpha191",
"GTJA_alpha180",
"history_position",
"bottom_profit",
"mean_median_dev",
"smart_money_accumulation",
"GTJA_alpha013",
"GTJA_alpha099",
"GTJA_alpha107",
"GTJA_alpha119",
"GTJA_alpha141",
"GTJA_alpha130",
"GTJA_alpha173",
'GTJA_alpha016',
'volatility_20',
'current_ratio',
'GTJA_alpha001',
'GTJA_alpha141',
'GTJA_alpha129',
'GTJA_alpha164',
'amivest_liq_20',
'GTJA_alpha012',
'debt_to_equity',
'turnover_deviation',
'GTJA_alpha073',
'GTJA_alpha043',
'GTJA_alpha032',
'GTJA_alpha028',
'GTJA_alpha090',
'GTJA_alpha108',
'GTJA_alpha105',
'GTJA_alpha091',
'GTJA_alpha119',
'GTJA_alpha104',
'GTJA_alpha163',
'GTJA_alpha157',
'cost_skewness',
'GTJA_alpha176',
'chip_transition',
'amount_skewness_20',
'GTJA_alpha148',
'mean_median_dev',
'downside_illiq_20',
]
# 模型参数配置

View File

@@ -34,112 +34,99 @@ from src.config.settings import get_settings
# ============================================================================
FACTORS: List[Dict[str, Any]] =[
# ==================== 第一类:筹码集中度与离散度因子 ====================
{
"name": "chip_dispersion_90",
"desc": "90%筹码离散度衡量市场90%持仓筹码的宽度,值越小表示筹码越高度集中(单峰密集),往往是洗盘结束的前兆",
"dsl": "(cost_95pct - cost_5pct) / (cost_95pct + cost_5pct)",
"name": "amihud_illiq_20",
"desc": "Amihud非流动性指标(20日):绝对收益率/成交额。该值越大,说明少量的资金就能砸盘或拉升,流动性极差,隐含极高的风险补偿预期",
"dsl": "ts_mean(abs(pct_chg) / (amount + 1), 20)",
},
{
"name": "chip_dispersion_70",
"desc": "70%核心筹码离散度剔除极端的底部死筹和高位套牢盘反映中间70%主流资金的成本集中度",
"dsl": "(cost_85pct - cost_15pct) / (cost_85pct + cost_15pct)",
"name": "amivest_liq_20",
"desc": "Amivest流动性指标(20日)Amihud的倒数变种衡量推动1%价格变化需要的资金量。值越低,流动性溢价越高",
"dsl": "ts_mean(amount / (abs(pct_chg) + 0.001), 20)",
},
{
"name": "cost_skewness",
"desc": "筹码偏度反映筹码分布的不对称性。大于1说明上方套牢盘拖尾严重小于1说明下方获利盘雄厚",
"dsl": "(cost_95pct - cost_50pct) / (cost_50pct - cost_5pct)",
"name": "atr_price_impact",
"desc": "真实波幅价格冲击(20日)以ATR(真实波幅)代替绝对收益,剔除跳空影响后的真实交易冲击",
"dsl": "ts_atr(high, low, close, 20) / close / (ts_mean(amount, 20) + 1)",
},
{
"name": "dispersion_change_20",
"desc": "筹码集中度近期变化率过去20天筹码宽度的变化比例持续下降说明主力正在暗中吸筹",
"dsl": "ts_pct_change((cost_95pct - cost_5pct) / cost_50pct, 20)",
"name": "hui_heubel_ratio",
"desc": "Hui-Heubel流动性比率利用高低价区间占均价的比例除以成交额占比捕捉阶段性区间的深度非流动性",
"dsl": "((ts_max(high, 20) - ts_min(low, 20)) / ts_min(low, 20)) / (ts_mean(amount, 20) + 1)",
},
# ==================== 第二类:筹码相对位置与压力/支撑因子 ====================
# --- 维度 2: 交易摩擦与隐形价差类 (Spread Proxies) ---
# 逻辑A股没有Tick级买卖价差数据通常用日线的高低价或自协方差来推导隐形买卖价差。价差越大交易成本越高持有需高收益补偿。
{
"name": "price_to_avg_cost",
"desc": "整体浮盈比例:当前价格相对加权平均成本的溢价率。高溢价有均值回归压力,负溢价代表超跌",
"dsl": "(close - weight_avg) / weight_avg",
"name": "corwin_schultz_spread_20",
"desc": "Corwin-Schultz买卖价差代理利用每日(最高价-最低价)/收盘价的均值衡量交易摩擦,摩擦越大的股票往往带有小盘股超额收益",
"dsl": "ts_mean((high - low) / close, 20)",
},
{
"name": "price_to_median_cost",
"desc": "中位数成本偏离度价格相对于50%分位点(绝对半数人持仓价)的偏离,向上突破通常是右侧买点",
"dsl": "(close - cost_50pct) / cost_50pct",
"name": "roll_spread_20",
"desc": "Roll买卖价差代理(20日):经典微观结构模型,计算相邻两日收益率的负协方差的平方根,反映做市商的隐形报价跳跃",
"dsl": "sqrt(max_(-ts_cov(change, ts_delay(change, 1), 20), 0))",
},
{
"name": "mean_median_dev",
"desc": "均值中位数背离:均值显著大于中位数说明高位筹码堆积,上涨阻力大",
"dsl": "(weight_avg - cost_50pct) / cost_50pct",
"name": "gibbs_effective_spread",
"desc": "有效价差代理:使用日内振幅减去隔夜跳空幅度后的纯日内摩擦成本",
"dsl": "ts_mean(((high - low) - abs(open - ts_delay(close, 1))) / close, 20)",
},
{
"name": "trap_pressure",
"desc": "高位套牢盘压力指数当前价格距离上方95%高位套牢成本的距离。距离越大,反弹的真空期阻力越小",
"dsl": "(cost_95pct - close) / close",
},
{
"name": "bottom_profit",
"desc": "底部支撑底仓利润率当前价格距离底部5%筹码的利润空间。暴跌时大于0说明底仓极度稳定",
"dsl": "(close - cost_5pct) / cost_5pct",
},
{
"name": "history_position",
"desc": "历史区间分位点:当前价格在个股上市以来历史最高点和最低点之间的相对位置",
"dsl": "(close - his_low) / (his_high - his_low)",
"name": "overnight_illiq_20",
"desc": "隔夜非流动性:开盘价相对于昨日收盘的跳空幅度,除以昨日成交额。隔夜极易跳空的股票具有夜间流动性溢价",
"dsl": "ts_mean(abs(open - ts_delay(close, 1)) / (ts_delay(amount, 1) + 1), 20)",
},
# ==================== 第三类:胜率相关的动量与反转因子 ====================
# --- 维度 3: 流动性风险与枯竭类 (Liquidity Risk & Depletion) ---
# 逻辑:投资者不仅讨厌“平时流动性差”,更讨厌“流动性极其不稳定”。
{
"name": "winner_rate_surge_5",
"desc": "获利盘短期爆发力胜率在过去5天内的变化值急剧上升是极强的动量做多信号",
"dsl": "ts_delta(winner_rate, 5)",
"name": "illiq_volatility_20",
"desc": "Amihud非流动性的波动率(20日):衡量价格冲击的不确定性(即流动性风险本身)。波动越大的股越容易踩踏",
"dsl": "ts_std(abs(pct_chg) / (amount + 1), 20)",
},
{
"name": "winner_rate_cs_rank",
"desc": "获利盘高位反转信号:全市场胜率截面排名,极端高胜率往往面临多头踩踏的获利了结压力(反转因子)",
"dsl": "cs_rank(winner_rate)",
"name": "amount_cv_20",
"desc": "成交额变异系数(20日):成交额的波动率除以均值。反映股票被市场关注的极度不稳定性",
"dsl": "ts_std(amount, 20) / (ts_mean(amount, 20) + 1)",
},
{
"name": "winner_rate_dev_20",
"desc": "获利盘均线偏离当前胜率相对过去20天平均胜率的偏离程度捕捉筹码情绪的边际超买/超卖",
"dsl": "winner_rate - ts_mean(winner_rate, 20)",
"name": "amount_skewness_20",
"desc": "成交额偏度(20日):正偏度意味着平时成交极度清淡,偶尔脉冲式放量。这种“死鱼”状态是经典的非流动性特征",
"dsl": "ts_skew(amount, 20)",
},
{
"name": "winner_rate_volatility",
"desc": "获利盘波动率过去20天胜率的波动率。波动率低且胜率高说明单边上涨极度稳健",
"dsl": "ts_std(winner_rate, 20)",
"name": "low_vol_days_20",
"desc": "流动性枯竭天数过去20天内成交额低于长期(60日)均值一半的极端缩量天数",
"dsl": "ts_count(amount < ts_mean(amount, 60) * 0.5, 20)",
},
{
"name": "smart_money_accumulation",
"desc": "潜在主力吸筹隐蔽指标胜率的60日时序分位数减去价格的时序分位数。值越大说明价平而获利盘增底部吸筹明显",
"dsl": "ts_rank(winner_rate, 60) - ts_rank(close, 60)",
"name": "liquidity_shock_momentum",
"desc": "流动性恶化动量:近期(5日)非流动性相较于长期的变化。正值代表流动性正在迅速干涸",
"dsl": "ts_mean(abs(pct_chg) / (amount + 1), 5) - ts_mean(abs(pct_chg) / (amount + 1), 20)",
},
# ==================== 第四类:量价与筹码交乘因子 ====================
# --- 维度 4: 非对称流动性类 (Asymmetric / Downside Illiquidity) ---
# 逻辑:买入时没问题,但“大跌时卖不出去(流动性丧失)”是最致命的风险,这部分风险被定价的权重最高。
{
"name": "winner_vol_corr_20",
"desc": "放量突破筹码密集区胜率与成交量的20日时序相关性正相关说明增量资金在主动解套上方筹码",
"dsl": "ts_corr(winner_rate, vol, 20)",
"name": "downside_illiq_20",
"desc": "下行非流动性:仅在股价下跌日计算的价格冲击。捕捉‘跌时没人接盘’的极端流动性折价",
"dsl": "ts_sum(where(change < 0, abs(change) / (amount + 1), 0), 20) / (ts_count(change < 0, 20) + 1)",
},
{
"name": "cost_base_momentum",
"desc": "成本重心上移换手率过去20天加权平均成本的变化幅度快速上移说明高位换手极其充分",
"dsl": "ts_pct_change(weight_avg, 20)",
"name": "upside_illiq_20",
"desc": "上行非流动性:仅在股价上涨日计算的价格冲击。捕捉‘涨时抛压极轻’的状态",
"dsl": "ts_sum(where(change > 0, abs(change) / (amount + 1), 0), 20) / (ts_count(change > 0, 20) + 1)",
},
{
"name": "bottom_cost_stability",
"desc": "底部坚如磐石因子底部5%成本的60天波动率相对于中位数的比值波动越小说明死筹越稳固",
"dsl": "ts_std(cost_5pct, 60) / cost_50pct",
"name": "illiq_asymmetry_20",
"desc": "非对称流动性比率下行流动性恶化程度除以加上行流动性恶化程度。该值远大于1说明下跌时发生严重踩踏股票本身必须折价预期收益率补偿极高",
"dsl": "(ts_sum(where(change < 0, abs(change)/(amount+1), 0), 20) + 1) / (ts_sum(where(change > 0, abs(change)/(amount+1), 0), 20) + 1)",
},
{
"name": "pivot_reversion",
"desc": "盈亏分界线乖离修复价格偏离50%分位点除以近20日价格标准差用于寻找超跌后的均值回归买点",
"dsl": "(close - cost_50pct) / ts_std(close, 20)",
},
{
"name": "chip_transition",
"desc": "强弱筹码切换度上方厚度与下方厚度差值的20日变化量。由正变负说明筹码彻底完成了自上而下的转移洗盘结束",
"dsl": "ts_delta((cost_85pct - cost_50pct) - (cost_50pct - cost_15pct), 20)",
"name": "pastor_stambaugh_proxy",
"desc": "Pastor-Stambaugh流动性贝塔代理收益率与滞后一期带有符号(涨跌)成交额的相关性。反映市场由于流动性短缺导致的价格过度反转现象",
"dsl": "ts_corr(change, sign(ts_delay(change, 1)) * ts_delay(amount, 1), 20)",
},
]

View File

@@ -84,3 +84,65 @@ class EnsembleQuantLoss(nn.Module):
total_loss += self.alpha * h_loss + (1.0 - self.alpha) * ic_loss
return total_loss / self.ensemble_size
class AsymmetricQuantLoss(nn.Module):
"""Asymmetric Quant Loss (非对称 Huber + IC)
保留全截面计算以维持稳定梯度的前提下,对多头关注的错误进行加权惩罚:
1. 过度高估烂股票 (买入陷阱) -> 加重惩罚
2. 严重低估好股票 (错失金股) -> 加重惩罚
"""
def __init__(self, alpha: float = 0.5, ensemble_size: int = 32):
super().__init__()
self.alpha = alpha
self.ensemble_size = ensemble_size
# 我们不使用内置的 HuberLoss而是手动写以支持 element-wise 加权
self.delta = 1.0 # Huber threshold (可以设得比较小,比如对于收益率设为 0.05)
def forward(self, preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
batch_size = preds.shape[0]
total_loss = 0.0
if batch_size < 10:
# 极小 batch 退回简单 MSE
return ((preds - target.unsqueeze(1)) ** 2).mean()
target_mean = target.mean()
target_std = target.std(unbiased=False) + 1e-8
target_norm = (target - target_mean) / target_std
for i in range(self.ensemble_size):
pred_i = preds[:, i]
# --- 1. 非对称 Huber 损失 ---
error = pred_i - target
abs_error = torch.abs(error)
# 标准 Huber 计算
quadratic = torch.clamp(abs_error, max=self.delta)
linear = abs_error - quadratic
base_huber = 0.5 * quadratic**2 + self.delta * linear
# 【核心逻辑】:非对称权重
# 如果 target > 0 且 pred < target (错过了涨的好票) -> 权重 1.5
# 如果 target < 0 且 pred > 0 (把跌的票预测成了涨的,容易实盘买入踩雷) -> 权重 2.0
# 其他情况 (比如烂票预测得更烂) -> 权重 1.0 正常学习
weights = torch.ones_like(target)
weights[(target > 0) & (error < 0)] = 1.5
weights[(target < 0) & (pred_i > 0)] = 2.0
weighted_huber = (base_huber * weights).mean()
# --- 2. IC 损失 (维持全截面排序能力) ---
pred_mean = pred_i.mean()
pred_std = pred_i.std(unbiased=False) + 1e-8
pred_norm = (pred_i - pred_mean) / pred_std
ic = (pred_norm * target_norm).mean()
ic_loss = 1.0 - ic
total_loss += self.alpha * weighted_huber + (1.0 - self.alpha) * ic_loss
return total_loss / self.ensemble_size

View File

@@ -19,7 +19,10 @@ from tabm import TabM
from src.training.components.base import BaseModel
from src.training.components.models.cross_section_sampler import CrossSectionSampler
from src.training.components.models.ensemble_quant_loss import EnsembleQuantLoss
from src.training.components.models.ensemble_quant_loss import (
EnsembleQuantLoss,
AsymmetricQuantLoss,
)
from src.training.registry import register_model
@@ -235,8 +238,8 @@ class TabMModel(BaseModel):
optimizer, T_max=epochs, eta_min=1e-6
)
# 使用 EnsembleQuantLoss 替代 MSE
self.criterion = EnsembleQuantLoss(
# 使用 AsymmetricQuantLoss (非对称 Huber + IC)
self.criterion = AsymmetricQuantLoss(
alpha=self.params.get("loss_alpha", 0.5), ensemble_size=ensemble_size
)

250
tests/test_moneyflow.py Normal file
View File

@@ -0,0 +1,250 @@
"""Tests for api_moneyflow module.
Tests the moneyflow (个股资金流向) API wrapper.
"""
import pytest
import pandas as pd
from unittest.mock import patch, MagicMock
from src.data.api_wrappers.api_moneyflow import (
get_moneyflow,
sync_moneyflow,
preview_moneyflow_sync,
MoneyflowSync,
)
class TestMoneyflowAPI:
"""Test suite for moneyflow API wrapper."""
@patch("src.data.api_wrappers.api_moneyflow.TushareClient")
def test_get_moneyflow_by_date(self, mock_client_class):
"""Test fetching moneyflow data by date."""
# Setup mock
mock_client = MagicMock()
mock_client_class.return_value = mock_client
mock_client.query.return_value = pd.DataFrame(
{
"ts_code": ["000001.SZ", "000002.SZ"],
"trade_date": ["20240115", "20240115"],
"buy_sm_vol": [10000, 20000],
"buy_sm_amount": [100.5, 200.5],
"sell_sm_vol": [8000, 15000],
"sell_sm_amount": [80.5, 150.5],
"buy_md_vol": [5000, 10000],
"buy_md_amount": [500.5, 1000.5],
"sell_md_vol": [4000, 8000],
"sell_md_amount": [400.5, 800.5],
"buy_lg_vol": [2000, 5000],
"buy_lg_amount": [2000.5, 5000.5],
"sell_lg_vol": [1500, 4000],
"sell_lg_amount": [1500.5, 4000.5],
"buy_elg_vol": [1000, 3000],
"buy_elg_amount": [5000.5, 15000.5],
"sell_elg_vol": [800, 2500],
"sell_elg_amount": [4000.5, 12500.5],
"net_mf_vol": [2700, 8000],
"net_mf_amount": [220.0, 550.0],
}
)
# Test
result = get_moneyflow(trade_date="20240115")
# Assert
assert not result.empty
assert len(result) == 2
assert "ts_code" in result.columns
assert "trade_date" in result.columns
assert "buy_sm_vol" in result.columns
assert "net_mf_amount" in result.columns
mock_client.query.assert_called_once_with("moneyflow", trade_date="20240115")
@patch("src.data.api_wrappers.api_moneyflow.TushareClient")
def test_get_moneyflow_by_stock(self, mock_client_class):
"""Test fetching moneyflow data by stock code."""
# Setup mock
mock_client = MagicMock()
mock_client_class.return_value = mock_client
mock_client.query.return_value = pd.DataFrame(
{
"ts_code": ["000001.SZ", "000001.SZ"],
"trade_date": ["20240114", "20240115"],
"buy_sm_vol": [10000, 11000],
"buy_sm_amount": [100.5, 110.5],
"sell_sm_vol": [8000, 8500],
"sell_sm_amount": [80.5, 85.5],
"buy_md_vol": [5000, 5500],
"buy_md_amount": [500.5, 550.5],
"sell_md_vol": [4000, 4200],
"sell_md_amount": [400.5, 420.5],
"buy_lg_vol": [2000, 2200],
"buy_lg_amount": [2000.5, 2200.5],
"sell_lg_vol": [1500, 1600],
"sell_lg_amount": [1500.5, 1600.5],
"buy_elg_vol": [1000, 1100],
"buy_elg_amount": [5000.5, 5500.5],
"sell_elg_vol": [800, 850],
"sell_elg_amount": [4000.5, 4250.5],
"net_mf_vol": [2700, 2950],
"net_mf_amount": [220.0, 245.0],
}
)
# Test
result = get_moneyflow(
ts_code="000001.SZ", start_date="20240101", end_date="20240131"
)
# Assert
assert not result.empty
assert len(result) == 2
mock_client.query.assert_called_once_with(
"moneyflow", ts_code="000001.SZ", start_date="20240101", end_date="20240131"
)
@patch("src.data.api_wrappers.api_moneyflow.TushareClient")
def test_get_moneyflow_empty_response(self, mock_client_class):
"""Test handling empty response."""
# Setup mock
mock_client = MagicMock()
mock_client_class.return_value = mock_client
mock_client.query.return_value = pd.DataFrame()
# Test
result = get_moneyflow(trade_date="20240101")
# Assert
assert result.empty
@patch("src.data.api_wrappers.api_moneyflow.TushareClient")
def test_get_moneyflow_with_shared_client(self, mock_client_class):
"""Test fetching with shared client for rate limiting."""
# Setup mock shared client
mock_shared_client = MagicMock()
mock_shared_client.query.return_value = pd.DataFrame(
{
"ts_code": ["000001.SZ"],
"trade_date": ["20240115"],
"buy_sm_vol": [10000],
"net_mf_amount": [220.0],
}
)
# Test with shared client
result = get_moneyflow(trade_date="20240115", client=mock_shared_client)
# Assert
assert not result.empty
mock_shared_client.query.assert_called_once_with(
"moneyflow", trade_date="20240115"
)
mock_client_class.assert_not_called() # Should not create new client
@patch("src.data.api_wrappers.api_moneyflow.TushareClient")
def test_get_moneyflow_date_column_rename(self, mock_client_class):
"""Test that 'date' column is renamed to 'trade_date'."""
# Setup mock with 'date' column (should be renamed)
mock_client = MagicMock()
mock_client_class.return_value = mock_client
mock_client.query.return_value = pd.DataFrame(
{
"ts_code": ["000001.SZ"],
"date": ["20240115"], # Using 'date' instead of 'trade_date'
"buy_sm_vol": [10000],
"net_mf_amount": [220.0],
}
)
# Test
result = get_moneyflow(trade_date="20240115")
# Assert - date column should be renamed to trade_date
assert "trade_date" in result.columns
assert "date" not in result.columns
assert result["trade_date"].iloc[0] == "20240115"
class TestMoneyflowSync:
"""Test suite for MoneyflowSync class."""
def test_moneyflow_sync_class_attributes(self):
"""Test MoneyflowSync class has correct attributes."""
assert MoneyflowSync.table_name == "moneyflow"
assert MoneyflowSync.default_start_date == "20100101"
assert "ts_code" in MoneyflowSync.TABLE_SCHEMA
assert "trade_date" in MoneyflowSync.TABLE_SCHEMA
assert "buy_sm_vol" in MoneyflowSync.TABLE_SCHEMA
assert "net_mf_amount" in MoneyflowSync.TABLE_SCHEMA
assert MoneyflowSync.PRIMARY_KEY == ("ts_code", "trade_date")
@patch("src.data.api_wrappers.api_moneyflow.get_moneyflow")
@patch("src.data.api_wrappers.api_moneyflow.TushareClient")
def test_fetch_single_date(self, mock_client_class, mock_get_moneyflow):
"""Test fetch_single_date method."""
# Setup mock
mock_get_moneyflow.return_value = pd.DataFrame(
{
"ts_code": ["000001.SZ"],
"trade_date": ["20240115"],
"buy_sm_vol": [10000],
"net_mf_amount": [220.0],
}
)
# Create sync instance
sync = MoneyflowSync()
# Test
result = sync.fetch_single_date("20240115")
# Assert
assert not result.empty
mock_get_moneyflow.assert_called_once_with(
trade_date="20240115", client=sync.client
)
@patch("src.data.api_wrappers.api_moneyflow.MoneyflowSync.sync_all")
def test_sync_moneyflow_function(self, mock_sync_all):
"""Test sync_moneyflow convenience function."""
# Setup mock
mock_sync_all.return_value = pd.DataFrame(
{
"ts_code": ["000001.SZ"],
"trade_date": ["20240115"],
}
)
# Test
result = sync_moneyflow(start_date="20240101", end_date="20240131")
# Assert
assert not result.empty
mock_sync_all.assert_called_once_with(
start_date="20240101", end_date="20240131", force_full=False
)
@patch("src.data.api_wrappers.api_moneyflow.MoneyflowSync.preview_sync")
def test_preview_moneyflow_sync_function(self, mock_preview_sync):
"""Test preview_moneyflow_sync convenience function."""
# Setup mock
mock_preview_sync.return_value = {
"sync_needed": True,
"date_count": 31,
"start_date": "20240101",
"end_date": "20240131",
"estimated_records": 10000,
"sample_data": pd.DataFrame(),
"mode": "incremental",
}
# Test
result = preview_moneyflow_sync(start_date="20240101", end_date="20240131")
# Assert
assert result["sync_needed"] is True
assert result["date_count"] == 31
mock_preview_sync.assert_called_once_with(
start_date="20240101", end_date="20240131", force_full=False, sample_size=3
)