From 9e7d4241c6c4b0417b9b581bf37c049d16c7c523 Mon Sep 17 00:00:00 2001 From: liaozhaorun <1300336796@qq.com> Date: Fri, 3 Apr 2026 23:57:47 +0800 Subject: [PATCH] =?UTF-8?q?feat(data):=20=E6=B7=BB=E5=8A=A0=E4=B8=AA?= =?UTF-8?q?=E8=82=A1=E8=B5=84=E9=87=91=E6=B5=81=E5=90=91=E6=8E=A5=E5=8F=A3?= =?UTF-8?q?=E5=B9=B6=E9=87=8D=E6=9E=84=E9=80=9F=E7=8E=87=E9=99=90=E5=88=B6?= =?UTF-8?q?=E9=85=8D=E7=BD=AE=20-=20=E6=96=B0=E5=A2=9E=20moneyflow=20?= =?UTF-8?q?=E8=B5=84=E9=87=91=E6=B5=81=E5=90=91=E6=95=B0=E6=8D=AE=E5=90=8C?= =?UTF-8?q?=E6=AD=A5=E6=A8=A1=E5=9D=97=20-=20=E5=AE=9E=E7=8E=B0=E6=8E=A5?= =?UTF-8?q?=E5=8F=A3=E7=BA=A7=E9=80=9F=E7=8E=87=E9=99=90=E5=88=B6=E9=85=8D?= =?UTF-8?q?=E7=BD=AE=EF=BC=88sync=5Fconfig.py=EF=BC=89=20-=20=E6=9B=B4?= =?UTF-8?q?=E6=96=B0=E6=B5=81=E5=8A=A8=E6=80=A7=E7=9B=B8=E5=85=B3=E5=9B=A0?= =?UTF-8?q?=E5=AD=90=E5=AE=9A=E4=B9=89=20-=20=E6=B7=BB=E5=8A=A0=E9=9D=9E?= =?UTF-8?q?=E5=AF=B9=E7=A7=B0=E9=87=8F=E5=8C=96=E6=8D=9F=E5=A4=B1=E5=87=BD?= =?UTF-8?q?=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config/sync_config.json | 12 + docs/api/api.md | 112 ++++- src/config/settings.py | 6 - src/data/api_wrappers/__init__.py | 25 + src/data/api_wrappers/api_moneyflow.py | 228 +++++++++ src/data/api_wrappers/base_sync.py | 12 +- .../financial_data/api_financial_sync.py | 6 + src/data/client.py | 79 +++- src/data/sync.py | 2 + src/data/sync_config.py | 431 ++++++++++++++++++ src/data/sync_registry.py | 7 + src/experiment/common.py | 314 +++++++------ src/experiment/learn_to_rank.py | 46 +- src/experiment/regression.py | 79 ++-- src/scripts/register_factors.py | 127 +++--- .../components/models/ensemble_quant_loss.py | 62 +++ src/training/components/models/tabm_model.py | 9 +- tests/test_moneyflow.py | 250 ++++++++++ 18 files changed, 1473 insertions(+), 334 deletions(-) create mode 100644 config/sync_config.json create mode 100644 src/data/api_wrappers/api_moneyflow.py create mode 100644 src/data/sync_config.py create mode 100644 tests/test_moneyflow.py diff --git a/config/sync_config.json b/config/sync_config.json new file mode 100644 index 0000000..0267983 --- /dev/null +++ b/config/sync_config.json @@ -0,0 +1,12 @@ +{ + "default": { + "rate_limit": 150, + "threads": 1 + }, + "interfaces": { + "pro_bar": { + "rate_limit": 500, + "threads": 8 + } + } +} \ No newline at end of file diff --git a/docs/api/api.md b/docs/api/api.md index d34f294..76cdff4 100644 --- a/docs/api/api.md +++ b/docs/api/api.md @@ -749,4 +749,114 @@ df = pro.cyq_perf(ts_code='600000.SH', start_date='20220101', end_date='20220429 73 600000.SH 20220107 0.72 12.16 8.60 11.36 9.89 7.59 74 600000.SH 20220106 0.72 12.16 8.60 11.36 9.89 3.92 75 600000.SH 20220105 0.72 12.16 8.60 11.36 9.89 5.65 -76 600000.SH 20220104 0.72 12.16 8.60 11.36 9.89 3.93 \ No newline at end of file +76 600000.SH 20220104 0.72 12.16 8.60 11.36 9.89 3.93 + + +个股资金流向 +接口:moneyflow,可以通过数据工具调试和查看数据。 +描述:获取沪深A股票资金流向数据,分析大单小单成交情况,用于判别资金动向,数据开始于2010年。 +限量:单次最大提取6000行记录,总量不限制 +积分:用户需要至少2000积分才可以调取,基础积分有流量控制,积分越多权限越大,请自行提高积分,具体请参阅积分获取办法 + + + +输入参数 + +名称 类型 必选 描述 +ts_code str N 股票代码 (股票和时间参数至少输入一个) +trade_date str N 交易日期 +start_date str N 开始日期 +end_date str N 结束日期 + + +输出参数 + +名称 类型 默认显示 描述 +ts_code str Y TS代码 +trade_date str Y 交易日期 +buy_sm_vol int Y 小单买入量(手) +buy_sm_amount float Y 小单买入金额(万元) +sell_sm_vol int Y 小单卖出量(手) +sell_sm_amount float Y 小单卖出金额(万元) +buy_md_vol int Y 中单买入量(手) +buy_md_amount float Y 中单买入金额(万元) +sell_md_vol int Y 中单卖出量(手) +sell_md_amount float Y 中单卖出金额(万元) +buy_lg_vol int Y 大单买入量(手) +buy_lg_amount float Y 大单买入金额(万元) +sell_lg_vol int Y 大单卖出量(手) +sell_lg_amount float Y 大单卖出金额(万元) +buy_elg_vol int Y 特大单买入量(手) +buy_elg_amount float Y 特大单买入金额(万元) +sell_elg_vol int Y 特大单卖出量(手) +sell_elg_amount float Y 特大单卖出金额(万元) +net_mf_vol int Y 净流入量(手) +net_mf_amount float Y 净流入额(万元) + +各类别统计规则如下: +小单:5万以下 中单:5万~20万 大单:20万~100万 特大单:成交额>=100万 ,数据基于主动买卖单统计 + + + +接口示例 + + +pro = ts.pro_api('your token') + +#获取单日全部股票数据 +df = pro.moneyflow(trade_date='20190315') + +#获取单个股票数据 +df = pro.moneyflow(ts_code='002149.SZ', start_date='20190115', end_date='20190315') + + + +数据示例 + + ts_code trade_date buy_sm_vol buy_sm_amount sell_sm_vol \ +0 000779.SZ 20190315 11377 1150.17 11100 +1 000933.SZ 20190315 94220 4803.22 105924 +2 002270.SZ 20190315 43979 2330.96 45893 +3 002319.SZ 20190315 21502 2952.88 17155 +4 002604.SZ 20190315 31944 607.35 58667 +5 300065.SZ 20190315 16048 2294.71 16425 +6 600062.SH 20190315 55439 7432.13 65765 +7 002735.SZ 20190315 3220 797.10 4598 +8 300196.SZ 20190315 12534 1286.02 8340 +9 300350.SZ 20190315 15346 1120.12 18853 +10 600193.SH 20190315 12183 503.73 19576 +11 002866.SZ 20190315 16932 2213.68 16037 +12 300481.SZ 20190315 21386 4275.33 21863 +13 600527.SH 20190315 115462 2975.44 79272 +14 603980.SH 20190315 13957 1924.69 11718 +15 600658.SH 20190315 71767 4826.73 69535 +16 600812.SH 20190315 26140 1247.47 34923 +17 002013.SZ 20190315 170234 12286.02 148509 +18 600789.SH 20190315 211012 21644.56 150598 +19 601636.SH 20190315 70737 3117.43 68073 +20 000807.SZ 20190315 129668 6361.06 122077 + +... + + sell_sm_amount buy_md_vol buy_md_amount sell_md_vol sell_md_amount \ +0 1122.97 13012 1316.72 14812 1498.90 +1 5411.72 135976 6935.40 154023 7863.00 +2 2435.98 57679 3059.15 47279 2507.55 +3 2358.68 27245 3742.52 26708 3670.05 +4 1114.40 69897 1327.41 41108 781.19 +5 2353.34 31232 4472.05 26771 3834.95 +6 8817.75 86617 11615.40 79551 10676.99 +7 1140.61 4602 1141.61 2730 676.72 +8 855.45 9401 963.72 10478 1074.32 +9 1380.31 24224 1770.90 21588 1577.92 +10 812.58 28696 1185.17 31087 1286.11 +11 2100.70 19197 2511.62 20269 2650.56 +12 4379.14 31692 6345.72 32873 6578.36 +13 2046.54 107103 2763.00 84883 2191.24 +14 1619.33 14621 2019.41 14528 2005.69 +15 4691.29 92788 6232.80 93273 6280.13 +16 1669.97 38812 1855.78 39211 1874.05 +17 10726.22 154979 11190.69 164090 11855.76 +18 15479.08 269470 27660.18 236958 24338.36 +19 3000.73 90416 3984.68 115162 5075.50 +20 5999.66 175692 8627.77 178044 8751.08 \ No newline at end of file diff --git a/src/config/settings.py b/src/config/settings.py index 047dbf5..cbead37 100644 --- a/src/config/settings.py +++ b/src/config/settings.py @@ -29,12 +29,6 @@ class Settings(BaseSettings): root_path: str = "" # 项目根路径,默认自动检测 data_path: str = "data" # 数据存储路径,相对于 root_path - # API 速率限制(每分钟请求数) - rate_limit: int = 300 - - # 同步工作线程数 - threads: int = 10 - # 数据库配置(可选,用于未来扩展) database_host: str = "localhost" database_port: int = 5432 diff --git a/src/data/api_wrappers/__init__.py b/src/data/api_wrappers/__init__.py index 1e0c02d..b59b7d4 100644 --- a/src/data/api_wrappers/__init__.py +++ b/src/data/api_wrappers/__init__.py @@ -14,6 +14,7 @@ Available APIs: - api_stock_st: ST stock list (ST股票列表) - api_stk_limit: Stock limit price (每日涨跌停价格) - api_cyq_perf: CYQ performance (每日筹码及胜率) + - api_moneyflow: Money flow (个股资金流向) Example: >>> from src.data.api_wrappers import get_daily, get_stock_basic, get_trade_cal, get_bak_basic @@ -21,6 +22,7 @@ Example: >>> from src.data.api_wrappers import get_stock_st, sync_stock_st >>> from src.data.api_wrappers import get_stk_limit, sync_stk_limit >>> from src.data.api_wrappers import get_cyq_perf, sync_cyq_perf + >>> from src.data.api_wrappers import get_moneyflow, sync_moneyflow >>> data = get_daily('000001.SZ', start_date='20240101', end_date='20240131') >>> pro_data = get_pro_bar('000001.SZ', start_date='20240101', end_date='20240131') >>> daily_basic = get_daily_basic(trade_date='20240101') @@ -30,6 +32,7 @@ Example: >>> stock_st = get_stock_st(trade_date='20240101') >>> stk_limit = get_stk_limit(trade_date='20240101') >>> cyq_perf = get_cyq_perf(trade_date='20240115') + >>> moneyflow = get_moneyflow(trade_date='20240115') """ from src.data.api_wrappers.api_daily_basic import ( @@ -77,6 +80,12 @@ from src.data.api_wrappers.api_cyq_perf import ( preview_cyq_perf_sync, CyqPerfSync, ) +from src.data.api_wrappers.api_moneyflow import ( + get_moneyflow, + sync_moneyflow, + preview_moneyflow_sync, + MoneyflowSync, +) __all__ = [ # Daily market data @@ -129,6 +138,11 @@ __all__ = [ "sync_cyq_perf", "preview_cyq_perf_sync", "CyqPerfSync", + # Moneyflow (个股资金流向) + "get_moneyflow", + "sync_moneyflow", + "preview_moneyflow_sync", + "MoneyflowSync", ] # ============================================================================= @@ -223,6 +237,17 @@ try: order=60, ) + # 9. Moneyflow - 个股资金流向 + from src.data.api_wrappers.api_moneyflow import MoneyflowSync + + sync_registry.register_class( + name="moneyflow", + sync_class=MoneyflowSync, + display_name="个股资金流向", + description="沪深A股资金流向数据,分析大单小单成交情况(2010年开始)", + order=70, + ) + except ImportError: # sync_registry 可能不存在(首次导入),忽略 pass diff --git a/src/data/api_wrappers/api_moneyflow.py b/src/data/api_wrappers/api_moneyflow.py new file mode 100644 index 0000000..794f400 --- /dev/null +++ b/src/data/api_wrappers/api_moneyflow.py @@ -0,0 +1,228 @@ +"""个股资金流向 (Moneyflow) interface. + +Fetch A-share stock money flow data from Tushare. +This interface retrieves fund flow data analyzing large and small order transactions. +Data starts from 2010. +""" + +import pandas as pd +from typing import Optional + +from src.data.client import TushareClient +from src.data.api_wrappers.base_sync import DateBasedSync + + +def get_moneyflow( + trade_date: Optional[str] = None, + ts_code: Optional[str] = None, + start_date: Optional[str] = None, + end_date: Optional[str] = None, + client: Optional[TushareClient] = None, +) -> pd.DataFrame: + """Fetch individual stock money flow data from Tushare. + + This interface retrieves fund flow data analyzing large and small order + transactions for A-share stocks. Data starts from 2010. + + Order size classification: + - Small orders (小单): < 50,000 yuan + - Medium orders (中单): 50,000 - 200,000 yuan + - Large orders (大单): 200,000 - 1,000,000 yuan + - Extra large orders (特大单): >= 1,000,000 yuan + + Args: + trade_date: Specific trade date in YYYYMMDD format + ts_code: Stock code filter (optional, e.g., '000001.SZ') + start_date: Start date for date range query (YYYYMMDD format) + end_date: End date for date range query (YYYYMMDD format) + client: Optional TushareClient instance for shared rate limiting. + If None, creates a new client. For concurrent sync operations, + pass a shared client to ensure proper rate limiting. + + Returns: + pd.DataFrame with columns: + - ts_code: Stock code + - trade_date: Trade date (YYYYMMDD) + - buy_sm_vol: Small order buy volume (hands) + - buy_sm_amount: Small order buy amount (10k yuan) + - sell_sm_vol: Small order sell volume (hands) + - sell_sm_amount: Small order sell amount (10k yuan) + - buy_md_vol: Medium order buy volume (hands) + - buy_md_amount: Medium order buy amount (10k yuan) + - sell_md_vol: Medium order sell volume (hands) + - sell_md_amount: Medium order sell amount (10k yuan) + - buy_lg_vol: Large order buy volume (hands) + - buy_lg_amount: Large order buy amount (10k yuan) + - sell_lg_vol: Large order sell volume (hands) + - sell_lg_amount: Large order sell amount (10k yuan) + - buy_elg_vol: Extra large order buy volume (hands) + - buy_elg_amount: Extra large order buy amount (10k yuan) + - sell_elg_vol: Extra large order sell volume (hands) + - sell_elg_amount: Extra large order sell amount (10k yuan) + - net_mf_vol: Net money flow volume (hands) + - net_mf_amount: Net money flow amount (10k yuan) + + Example: + >>> # Get all stocks' money flow for a single date + >>> data = get_moneyflow(trade_date='20240115') + >>> + >>> # Get date range data for a specific stock + >>> data = get_moneyflow(ts_code='000001.SZ', start_date='20240101', end_date='20240131') + >>> + >>> # Get specific stock on specific date + >>> data = get_moneyflow(ts_code='000001.SZ', trade_date='20240115') + """ + client = client or TushareClient() + + # Build parameters + params = {} + if trade_date: + params["trade_date"] = trade_date + if ts_code: + params["ts_code"] = ts_code + if start_date: + params["start_date"] = start_date + if end_date: + params["end_date"] = end_date + + # Fetch data using moneyflow API + data = client.query("moneyflow", **params) + + # Rename date column if needed + if "date" in data.columns: + data = data.rename(columns={"date": "trade_date"}) + + return data + + +class MoneyflowSync(DateBasedSync): + """个股资金流向数据批量同步管理器,支持全量/增量同步。 + + 继承自 DateBasedSync,使用按日期并发获取数据。 + 数据始于 2010 年。 + + Example: + >>> sync = MoneyflowSync() + >>> results = sync.sync_all() # 增量同步 + >>> results = sync.sync_all(force_full=True) # 全量同步 + >>> preview = sync.preview_sync() # 预览 + """ + + table_name = "moneyflow" + default_start_date = "20100101" + + # 表结构定义 - 使用 Tushare API 原始字段名 + TABLE_SCHEMA = { + "ts_code": "VARCHAR(16) NOT NULL", + "trade_date": "DATE NOT NULL", + "buy_sm_vol": "INTEGER", # 小单买入量(手) + "buy_sm_amount": "DOUBLE", # 小单买入金额(万元) + "sell_sm_vol": "INTEGER", # 小单卖出量(手) + "sell_sm_amount": "DOUBLE", # 小单卖出金额(万元) + "buy_md_vol": "INTEGER", # 中单买入量(手) + "buy_md_amount": "DOUBLE", # 中单买入金额(万元) + "sell_md_vol": "INTEGER", # 中单卖出量(手) + "sell_md_amount": "DOUBLE", # 中单卖出金额(万元) + "buy_lg_vol": "INTEGER", # 大单买入量(手) + "buy_lg_amount": "DOUBLE", # 大单买入金额(万元) + "sell_lg_vol": "INTEGER", # 大单卖出量(手) + "sell_lg_amount": "DOUBLE", # 大单卖出金额(万元) + "buy_elg_vol": "INTEGER", # 特大单买入量(手) + "buy_elg_amount": "DOUBLE", # 特大单买入金额(万元) + "sell_elg_vol": "INTEGER", # 特大单卖出量(手) + "sell_elg_amount": "DOUBLE", # 特大单卖出金额(万元) + "net_mf_vol": "INTEGER", # 净流入量(手) + "net_mf_amount": "DOUBLE", # 净流入额(万元) + } + + # 索引定义 + TABLE_INDEXES = [ + ("idx_moneyflow_date_code", ["trade_date", "ts_code"]), + ] + + # 主键定义 + PRIMARY_KEY = ("ts_code", "trade_date") + + def fetch_single_date(self, trade_date: str) -> pd.DataFrame: + """获取单日所有股票的资金流向数据。 + + Args: + trade_date: 交易日期(YYYYMMDD) + + Returns: + 包含当日所有股票资金流向数据的 DataFrame + """ + return get_moneyflow(trade_date=trade_date, client=self.client) + + +def sync_moneyflow( + start_date: Optional[str] = None, + end_date: Optional[str] = None, + force_full: bool = False, +) -> pd.DataFrame: + """同步个股资金流向数据到 DuckDB,支持智能增量同步。 + + 逻辑: + - 若表不存在:创建表 + 复合索引 (trade_date, ts_code) + 全量同步 + - 若表存在:从 last_date + 1 开始增量同步 + + Args: + start_date: 起始日期(YYYYMMDD 格式,默认全量从 20100101,增量从 last_date+1) + end_date: 结束日期(YYYYMMDD 格式,默认为今天) + force_full: 若为 True,强制从 20100101 完整重载 + + Returns: + 包含同步数据的 pd.DataFrame + + Example: + >>> # 首次同步(从 20100101 全量加载) + >>> result = sync_moneyflow() + >>> + >>> # 后续同步(增量 - 仅新数据) + >>> result = sync_moneyflow() + >>> + >>> # 强制完整重载 + >>> result = sync_moneyflow(force_full=True) + >>> + >>> # 手动指定日期范围 + >>> result = sync_moneyflow(start_date='20240101', end_date='20240131') + """ + sync_manager = MoneyflowSync() + return sync_manager.sync_all( + start_date=start_date, + end_date=end_date, + force_full=force_full, + ) + + +def preview_moneyflow_sync( + start_date: Optional[str] = None, + end_date: Optional[str] = None, + force_full: bool = False, + sample_size: int = 3, +) -> dict: + """预览个股资金流向数据同步数据量和样本(不实际同步)。 + + Args: + start_date: 手动指定起始日期(覆盖自动检测) + end_date: 手动指定结束日期(默认为今天) + force_full: 若为 True,预览全量同步(从 20100101) + sample_size: 预览天数(默认: 3) + + Returns: + 包含预览信息的字典 + + Example: + >>> # 预览将要同步的内容 + >>> preview = preview_moneyflow_sync() + >>> + >>> # 预览全量同步 + >>> preview = preview_moneyflow_sync(force_full=True) + """ + sync_manager = MoneyflowSync() + return sync_manager.preview_sync( + start_date=start_date, + end_date=end_date, + force_full=force_full, + sample_size=sample_size, + ) diff --git a/src/data/api_wrappers/base_sync.py b/src/data/api_wrappers/base_sync.py index 59bd0b5..d15a346 100644 --- a/src/data/api_wrappers/base_sync.py +++ b/src/data/api_wrappers/base_sync.py @@ -35,8 +35,8 @@ from tqdm import tqdm from src.data.client import TushareClient from src.data.storage import ThreadSafeStorage, Storage from src.data.sync_logger import SyncLogManager +from src.data.sync_config import get_threads from src.data.utils import get_today_date, get_next_date -from src.config.settings import get_settings from src.data.api_wrappers.api_trade_cal import ( get_first_trading_day, get_last_trading_day, @@ -63,7 +63,6 @@ class BaseDataSync(ABC): table_name: str = "" # 子类必须覆盖 DEFAULT_START_DATE = "20180101" - DEFAULT_MAX_WORKERS = get_settings().threads # 表结构定义(子类可覆盖) # 格式: {"column_name": "SQL_TYPE", ...} @@ -81,11 +80,14 @@ class BaseDataSync(ABC): """初始化同步管理器。 Args: - max_workers: 工作线程数(默认从配置读取) + max_workers: 工作线程数(默认从配置读取,根据 table_name 获取) """ self.storage = ThreadSafeStorage() - self.client = TushareClient() - self.max_workers = max_workers or self.DEFAULT_MAX_WORKERS + # 使用 table_name 作为接口名称初始化客户端,获取特定的速率限制 + self.client = TushareClient(interface_name=self.table_name or "default") + # 根据表名获取线程数配置 + default_workers = get_threads(self.table_name or "default") + self.max_workers = max_workers or default_workers self._stop_flag = threading.Event() self._stop_flag.set() # 初始为未停止状态 self._cached_data: Optional[pd.DataFrame] = None diff --git a/src/data/api_wrappers/financial_data/api_financial_sync.py b/src/data/api_wrappers/financial_data/api_financial_sync.py index 7f9145f..9f254ca 100644 --- a/src/data/api_wrappers/financial_data/api_financial_sync.py +++ b/src/data/api_wrappers/financial_data/api_financial_sync.py @@ -37,6 +37,7 @@ from typing import List, Optional from src.data.sync_logger import SyncLogManager +from src.data.sync_config import get_rate_limit, get_threads from src.data.api_wrappers.financial_data.api_income import ( IncomeQuarterSync, sync_income, @@ -151,7 +152,12 @@ def sync_financial( sync_func = config["sync_func"] display_name = config["display_name"] + # 获取当前接口的 rate limit + rate_limit = get_rate_limit(data_type) + threads = get_threads(data_type) + print(f"\n[{display_name}] 开始同步...") + print(f" Rate Limit: {rate_limit}/min, Threads: {threads}") try: result = sync_func(force_full=force_full, dry_run=dry_run) diff --git a/src/data/client.py b/src/data/client.py index 36125d8..d2be47f 100644 --- a/src/data/client.py +++ b/src/data/client.py @@ -2,23 +2,29 @@ import time import pandas as pd +import tushare as ts + from typing import Optional from src.data.rate_limiter import TokenBucketRateLimiter +from src.data.sync_config import get_rate_limit from src.config.settings import get_settings class TushareClient: """Tushare API client with rate limiting and retry.""" - # 类级别共享限流器(确保所有实例共享同一个限流器) - _shared_limiter: Optional[TokenBucketRateLimiter] = None - _cached_rate_limit: int = 0 # 缓存上次使用的 rate_limit + # 类级别限流器缓存(按接口名称存储) + _limiter_cache: dict[str, TokenBucketRateLimiter] = {} + _cached_rate_limits: dict[str, int] = {} - def __init__(self, token: Optional[str] = None): + def __init__( + self, token: Optional[str] = None, interface_name: Optional[str] = None + ): """Initialize client. Args: token: Tushare API token (auto-loaded from config if not provided) + interface_name: 接口名称,用于获取特定的速率限制配置 """ cfg = get_settings() token = token or cfg.tushare_token @@ -28,32 +34,57 @@ class TushareClient: self.token = token self.config = cfg + self.interface_name = interface_name or "default" - # 初始化共享限流器(确保所有 TushareClient 实例共享同一个限流器) - # 检查是否需要重新创建限流器(配置发生变化时) - if ( - TushareClient._shared_limiter is None - or TushareClient._cached_rate_limit != cfg.rate_limit - ): - # 首次创建或配置变更:重新初始化共享限流器 - TushareClient._shared_limiter = TokenBucketRateLimiter( - rate_limit=cfg.rate_limit, - ) - TushareClient._cached_rate_limit = cfg.rate_limit - min_interval = 60.0 / cfg.rate_limit - print( - f"[TushareClient] Initialized shared rate limiter: rate={cfg.rate_limit}/min, interval={min_interval:.2f}s" - ) - # 复用共享限流器 - self.rate_limiter = TushareClient._shared_limiter + # 获取或创建限流器 + self.rate_limiter = self._get_or_create_limiter(self.interface_name) self._api = None + def _get_or_create_limiter(self, interface_name: str) -> TokenBucketRateLimiter: + """获取或创建指定接口的限流器。 + + 如果接口的 rate_limit 配置发生变化,会重新创建限流器。 + + Args: + interface_name: 接口名称 + + Returns: + TokenBucketRateLimiter 实例 + """ + # 获取当前配置的 rate_limit + current_rate_limit = get_rate_limit(interface_name) + + # 检查是否需要创建新的限流器 + if ( + interface_name not in TushareClient._limiter_cache + or TushareClient._cached_rate_limits.get(interface_name) + != current_rate_limit + ): + # 创建新的限流器 + TushareClient._limiter_cache[interface_name] = TokenBucketRateLimiter( + rate_limit=current_rate_limit, + ) + TushareClient._cached_rate_limits[interface_name] = current_rate_limit + min_interval = 60.0 / current_rate_limit + print( + f"[TushareClient] Initialized rate limiter for '{interface_name}': " + f"rate={current_rate_limit}/min, interval={min_interval:.2f}s" + ) + + return TushareClient._limiter_cache[interface_name] + + def get_current_rate_limit(self) -> int: + """获取当前接口的速率限制。 + + Returns: + 每分钟请求数 + """ + return get_rate_limit(self.interface_name) + def _get_api(self): """Get Tushare API instance.""" if self._api is None: - import tushare as ts - self._api = ts.pro_api(self.token) return self._api @@ -80,8 +111,6 @@ class TushareClient: raise RuntimeError(f"Rate limit exceeded after {timeout}s timeout") try: - import tushare as ts - # pro_bar uses ts.pro_bar() instead of api.query() if api_name == "pro_bar": # pro_bar parameters: ts_code, start_date, end_date, adj, freq, factors, ma, adjfactor diff --git a/src/data/sync.py b/src/data/sync.py index 62708ab..13cb1a4 100644 --- a/src/data/sync.py +++ b/src/data/sync.py @@ -13,6 +13,7 @@ - api_stock_st.py: ST股票列表同步 (StockSTSync 类) - api_stk_limit.py: 涨跌停价格同步 (StkLimitSync 类) - api_cyq_perf.py: 筹码分布数据同步 (CyqPerfSync 类) + - api_moneyflow.py: 个股资金流向同步 (MoneyflowSync 类) - api_stock_basic.py: 股票基本信息同步 - api_trade_cal.py: 交易日历同步 @@ -82,6 +83,7 @@ def sync_all_data( 6. stock_st: ST股票列表 7. stk_limit: 每日涨跌停价格 8. cyq_perf: 每日筹码及胜率 + 9. moneyflow: 个股资金流向 新增接口时,只需在 api_wrappers/__init__.py 中添加注册代码, 无需修改本函数。 diff --git a/src/data/sync_config.py b/src/data/sync_config.py new file mode 100644 index 0000000..5b50645 --- /dev/null +++ b/src/data/sync_config.py @@ -0,0 +1,431 @@ +"""同步配置管理模块。 + +该模块提供统一的同步配置管理,支持: +- 默认的线程数和速率限制配置 +- 每个接口单独的速率限制配置 +- 从 JSON 配置文件加载配置 + +配置文件路径: config/sync_config.json + +配置示例: + { + "default": { + "threads": 10, + "rate_limit": 300 + }, + "interfaces": { + "pro_bar": { + "rate_limit": 200 + }, + "daily_basic": { + "rate_limit": 150 + }, + "income": { + "rate_limit": 100 + } + } + } +""" + +import json +import threading +from dataclasses import dataclass, field, asdict +from pathlib import Path +from typing import Optional, Dict, Any +from functools import lru_cache + + +@dataclass +class InterfaceConfig: + """单个接口的配置。""" + + rate_limit: Optional[int] = None + threads: Optional[int] = None + + +@dataclass +class SyncConfig: + """同步配置类。""" + + default: InterfaceConfig = field( + default_factory=lambda: InterfaceConfig( + rate_limit=300, + threads=10, + ) + ) + interfaces: Dict[str, InterfaceConfig] = field(default_factory=dict) + + def get_interface_config(self, interface_name: str) -> InterfaceConfig: + """获取指定接口的配置。 + + 如果接口有特定配置则返回,否则返回默认配置。 + + Args: + interface_name: 接口名称 + + Returns: + 接口配置对象 + """ + if interface_name in self.interfaces: + interface_config = self.interfaces[interface_name] + # 合并默认值(特定配置覆盖默认) + return InterfaceConfig( + rate_limit=interface_config.rate_limit + if interface_config.rate_limit is not None + else self.default.rate_limit, + threads=interface_config.threads + if interface_config.threads is not None + else self.default.threads, + ) + return self.default + + def get_rate_limit(self, interface_name: str) -> int: + """获取指定接口的速率限制。 + + Args: + interface_name: 接口名称 + + Returns: + 每分钟请求数 + """ + config = self.get_interface_config(interface_name) + return config.rate_limit or self.default.rate_limit or 300 + + def get_threads(self, interface_name: str) -> int: + """获取指定接口的线程数。 + + Args: + interface_name: 接口名称 + + Returns: + 线程数 + """ + config = self.get_interface_config(interface_name) + return config.threads or self.default.threads or 10 + + def to_dict(self) -> Dict[str, Any]: + """转换为字典格式。""" + return { + "default": { + "rate_limit": self.default.rate_limit, + "threads": self.default.threads, + }, + "interfaces": { + name: { + "rate_limit": config.rate_limit, + "threads": config.threads, + } + for name, config in self.interfaces.items() + }, + } + + +class SyncConfigManager: + """同步配置管理器。 + + 负责加载、保存和管理同步配置。 + 支持从 JSON 文件加载配置,提供线程安全的单例访问。 + """ + + _instance: Optional["SyncConfigManager"] = None + _lock = threading.Lock() + + DEFAULT_CONFIG = { + "default": { + "threads": 10, + "rate_limit": 300, + }, + "interfaces": { + # 高频接口(每秒可请求多次) + "trade_cal": {"rate_limit": 500}, + "stock_basic": {"rate_limit": 500}, + "bak_basic": {"rate_limit": 200}, + "stock_st": {"rate_limit": 200}, + "stk_limit": {"rate_limit": 200}, + "cyq_perf": {"rate_limit": 150}, + "moneyflow": {"rate_limit": 150}, + "daily_basic": {"rate_limit": 150}, + "pro_bar": {"rate_limit": 200}, + # 低频财务接口 + "income": {"rate_limit": 100}, + "balance": {"rate_limit": 100}, + "cashflow": {"rate_limit": 100}, + "fina_indicator": {"rate_limit": 100}, + "namechange": {"rate_limit": 200}, + }, + } + + def __new__(cls, config_path: Optional[str] = None) -> "SyncConfigManager": + """单例模式确保全局只有一个配置管理器实例。""" + if cls._instance is None: + with cls._lock: + if cls._instance is None: + cls._instance = super().__new__(cls) + cls._instance._initialized = False + return cls._instance + + def __init__(self, config_path: Optional[str] = None): + """初始化配置管理器。 + + Args: + config_path: 配置文件路径,默认使用项目根目录下的 config/sync_config.json + """ + # 避免重复初始化 + if self._initialized: + return + + if config_path is None: + # 默认路径:项目根目录/config/sync_config.json + project_root = Path(__file__).parent.parent.parent + self.config_path = project_root / "config" / "sync_config.json" + else: + self.config_path = Path(config_path) + + self._config: Optional[SyncConfig] = None + self._config_lock = threading.RLock() + self._initialized = True + + def _load_config(self) -> SyncConfig: + """从文件加载配置。 + + 如果配置文件不存在,则创建默认配置。 + + Returns: + 同步配置对象 + """ + with self._config_lock: + if self._config is not None: + return self._config + + if self.config_path.exists(): + try: + with open(self.config_path, "r", encoding="utf-8") as f: + data = json.load(f) + + default_data = data.get("default", {}) + default_config = InterfaceConfig( + rate_limit=default_data.get("rate_limit", 300), + threads=default_data.get("threads", 10), + ) + + interfaces = {} + for name, config_data in data.get("interfaces", {}).items(): + interfaces[name] = InterfaceConfig( + rate_limit=config_data.get("rate_limit"), + threads=config_data.get("threads"), + ) + + self._config = SyncConfig( + default=default_config, + interfaces=interfaces, + ) + print(f"[SyncConfig] Loaded config from {self.config_path}") + except Exception as e: + print(f"[SyncConfig] Error loading config: {e}, using defaults") + self._config = self._create_default_config() + else: + self._config = self._create_default_config() + self.save_config() + + return self._config + + def _create_default_config(self) -> SyncConfig: + """创建默认配置。""" + default = InterfaceConfig( + rate_limit=self.DEFAULT_CONFIG["default"]["rate_limit"], + threads=self.DEFAULT_CONFIG["default"]["threads"], + ) + + interfaces = {} + for name, config_data in self.DEFAULT_CONFIG["interfaces"].items(): + interfaces[name] = InterfaceConfig( + rate_limit=config_data.get("rate_limit"), + threads=config_data.get("threads"), + ) + + return SyncConfig(default=default, interfaces=interfaces) + + def save_config(self) -> None: + """保存当前配置到文件。""" + with self._config_lock: + if self._config is None: + self._config = self._create_default_config() + + try: + self.config_path.parent.mkdir(parents=True, exist_ok=True) + with open(self.config_path, "w", encoding="utf-8") as f: + json.dump(self._config.to_dict(), f, indent=2, ensure_ascii=False) + print(f"[SyncConfig] Saved config to {self.config_path}") + except Exception as e: + print(f"[SyncConfig] Error saving config: {e}") + + def get_config(self) -> SyncConfig: + """获取当前配置。 + + Returns: + 同步配置对象 + """ + with self._config_lock: + if self._config is None: + return self._load_config() + return self._config + + def get_rate_limit(self, interface_name: str) -> int: + """获取指定接口的速率限制。 + + Args: + interface_name: 接口名称 + + Returns: + 每分钟请求数 + """ + return self.get_config().get_rate_limit(interface_name) + + def get_threads(self, interface_name: str) -> int: + """获取指定接口的线程数。 + + Args: + interface_name: 接口名称 + + Returns: + 线程数 + """ + return self.get_config().get_threads(interface_name) + + def update_interface_config( + self, + interface_name: str, + rate_limit: Optional[int] = None, + threads: Optional[int] = None, + ) -> None: + """更新指定接口的配置。 + + Args: + interface_name: 接口名称 + rate_limit: 新的速率限制(None 表示不修改) + threads: 新的线程数(None 表示不修改) + """ + with self._config_lock: + config = self.get_config() + + if interface_name not in config.interfaces: + config.interfaces[interface_name] = InterfaceConfig() + + interface_config = config.interfaces[interface_name] + if rate_limit is not None: + interface_config.rate_limit = rate_limit + if threads is not None: + interface_config.threads = threads + + self.save_config() + + def reset_to_defaults(self) -> None: + """重置为默认配置。""" + with self._config_lock: + self._config = self._create_default_config() + self.save_config() + print("[SyncConfig] Reset to default configuration") + + def list_interfaces(self) -> Dict[str, Dict[str, Any]]: + """列出所有接口的配置。 + + Returns: + 接口配置字典 {接口名: {rate_limit, threads}} + """ + config = self.get_config() + result = {} + + # 添加默认配置 + result["__default__"] = { + "rate_limit": config.default.rate_limit, + "threads": config.default.threads, + } + + # 添加各接口配置 + all_interfaces = set(config.interfaces.keys()) | set( + self.DEFAULT_CONFIG["interfaces"].keys() + ) + for name in sorted(all_interfaces): + iface_config = config.get_interface_config(name) + result[name] = { + "rate_limit": iface_config.rate_limit, + "threads": iface_config.threads, + } + + return result + + +# 全局配置管理器实例 +_config_manager: Optional[SyncConfigManager] = None +_config_manager_lock = threading.Lock() + + +def get_sync_config_manager() -> SyncConfigManager: + """获取同步配置管理器单例。 + + Returns: + SyncConfigManager 实例 + """ + global _config_manager + if _config_manager is None: + with _config_manager_lock: + if _config_manager is None: + _config_manager = SyncConfigManager() + return _config_manager + + +def get_sync_config() -> SyncConfig: + """获取同步配置。 + + Returns: + SyncConfig 实例 + """ + return get_sync_config_manager().get_config() + + +def get_rate_limit(interface_name: str) -> int: + """获取指定接口的速率限制。 + + Args: + interface_name: 接口名称 + + Returns: + 每分钟请求数 + """ + return get_sync_config_manager().get_rate_limit(interface_name) + + +def get_threads(interface_name: str) -> int: + """获取指定接口的线程数。 + + Args: + interface_name: 接口名称 + + Returns: + 线程数 + """ + return get_sync_config_manager().get_threads(interface_name) + + +def print_sync_config() -> None: + """打印当前同步配置(用于调试)。""" + manager = get_sync_config_manager() + configs = manager.list_interfaces() + + print("\n" + "=" * 60) + print("[SyncConfig] Current Configuration") + print("=" * 60) + + default_config = configs.pop("__default__", {}) + print( + f"Default: rate_limit={default_config.get('rate_limit')}/min, threads={default_config.get('threads')}" + ) + print("-" * 60) + + for name, config in sorted(configs.items()): + marker = "*" if name in manager.DEFAULT_CONFIG["interfaces"] else " " + print( + f" [{marker}] {name:20s}: rate_limit={config['rate_limit']:3d}/min, threads={config['threads']:2d}" + ) + + print("=" * 60) diff --git a/src/data/sync_registry.py b/src/data/sync_registry.py index ead58a5..f5784a9 100644 --- a/src/data/sync_registry.py +++ b/src/data/sync_registry.py @@ -28,6 +28,8 @@ from collections import OrderedDict import pandas as pd +from src.data.sync_config import get_rate_limit, get_threads + @dataclass class SyncTask: @@ -221,7 +223,12 @@ class SyncRegistry: print("=" * 60) for idx, task in enumerate(tasks, 1): + # 获取当前接口的配置 + rate_limit = get_rate_limit(task.name) + threads = get_threads(task.name) + print(f"\n[{idx}/{total}] Syncing {task.display_name}...") + print(f" Rate Limit: {rate_limit}/min, Threads: {threads}") if task.description: print(f" Description: {task.description}") diff --git a/src/experiment/common.py b/src/experiment/common.py index b6fdc23..559061e 100644 --- a/src/experiment/common.py +++ b/src/experiment/common.py @@ -123,133 +123,133 @@ SELECTED_FACTORS = [ "GTJA_alpha048", "GTJA_alpha049", "GTJA_alpha050", - # "GTJA_alpha051", - # "GTJA_alpha052", - # "GTJA_alpha053", - # "GTJA_alpha054", - # "GTJA_alpha056", - # "GTJA_alpha057", - # "GTJA_alpha058", - # "GTJA_alpha059", - # "GTJA_alpha060", - # "GTJA_alpha061", - # "GTJA_alpha062", - # "GTJA_alpha063", - # "GTJA_alpha064", - # "GTJA_alpha065", - # "GTJA_alpha066", - # "GTJA_alpha067", - # "GTJA_alpha068", - # "GTJA_alpha070", - # "GTJA_alpha071", - # "GTJA_alpha072", - # "GTJA_alpha073", - # "GTJA_alpha074", - # "GTJA_alpha076", - # "GTJA_alpha077", - # "GTJA_alpha078", - # "GTJA_alpha079", - # "GTJA_alpha080", - # "GTJA_alpha081", - # "GTJA_alpha082", - # "GTJA_alpha083", - # "GTJA_alpha084", - # "GTJA_alpha085", - # "GTJA_alpha086", - # "GTJA_alpha087", - # "GTJA_alpha088", - # "GTJA_alpha089", - # "GTJA_alpha090", - # "GTJA_alpha091", - # "GTJA_alpha092", - # "GTJA_alpha093", - # "GTJA_alpha094", - # "GTJA_alpha095", - # "GTJA_alpha096", - # "GTJA_alpha097", - # "GTJA_alpha098", - # "GTJA_alpha099", - # "GTJA_alpha100", - # "GTJA_alpha101", - # "GTJA_alpha102", - # "GTJA_alpha103", - # "GTJA_alpha104", - # "GTJA_alpha105", - # "GTJA_alpha106", - # "GTJA_alpha107", - # "GTJA_alpha108", - # "GTJA_alpha109", - # "GTJA_alpha110", - # "GTJA_alpha111", - # "GTJA_alpha112", - # # "GTJA_alpha113", - # "GTJA_alpha114", - # "GTJA_alpha115", - # "GTJA_alpha117", - # "GTJA_alpha118", - # "GTJA_alpha119", - # "GTJA_alpha120", - # # "GTJA_alpha121", - # "GTJA_alpha122", - # "GTJA_alpha123", - # "GTJA_alpha124", - # "GTJA_alpha125", - # "GTJA_alpha126", - # "GTJA_alpha127", - # "GTJA_alpha128", - # "GTJA_alpha129", - # "GTJA_alpha130", - # "GTJA_alpha131", - # "GTJA_alpha132", - # "GTJA_alpha133", - # "GTJA_alpha134", - # "GTJA_alpha135", - # "GTJA_alpha136", - # # "GTJA_alpha138", - # "GTJA_alpha139", - # # "GTJA_alpha140", - # "GTJA_alpha141", - # "GTJA_alpha142", - # "GTJA_alpha145", - # # "GTJA_alpha146", - # "GTJA_alpha148", - # "GTJA_alpha150", - # "GTJA_alpha151", - # "GTJA_alpha152", - # "GTJA_alpha153", - # "GTJA_alpha154", - # "GTJA_alpha155", - # "GTJA_alpha156", - # "GTJA_alpha157", - # "GTJA_alpha158", - # "GTJA_alpha159", - # "GTJA_alpha160", - # "GTJA_alpha161", - # # "GTJA_alpha162", - # "GTJA_alpha163", - # "GTJA_alpha164", - # # "GTJA_alpha165", - # "GTJA_alpha166", - # "GTJA_alpha167", - # "GTJA_alpha168", - # "GTJA_alpha169", - # "GTJA_alpha170", - # "GTJA_alpha171", - # "GTJA_alpha173", - # "GTJA_alpha174", - # "GTJA_alpha175", - # "GTJA_alpha176", - # "GTJA_alpha177", - # "GTJA_alpha178", - # "GTJA_alpha179", - # "GTJA_alpha180", - # # "GTJA_alpha183", - # "GTJA_alpha184", - # "GTJA_alpha185", - # "GTJA_alpha187", - # "GTJA_alpha188", - # "GTJA_alpha189", - # "GTJA_alpha191", + "GTJA_alpha051", + "GTJA_alpha052", + "GTJA_alpha053", + "GTJA_alpha054", + "GTJA_alpha056", + "GTJA_alpha057", + "GTJA_alpha058", + "GTJA_alpha059", + "GTJA_alpha060", + "GTJA_alpha061", + "GTJA_alpha062", + "GTJA_alpha063", + "GTJA_alpha064", + "GTJA_alpha065", + "GTJA_alpha066", + "GTJA_alpha067", + "GTJA_alpha068", + "GTJA_alpha070", + "GTJA_alpha071", + "GTJA_alpha072", + "GTJA_alpha073", + "GTJA_alpha074", + "GTJA_alpha076", + "GTJA_alpha077", + "GTJA_alpha078", + "GTJA_alpha079", + "GTJA_alpha080", + "GTJA_alpha081", + "GTJA_alpha082", + "GTJA_alpha083", + "GTJA_alpha084", + "GTJA_alpha085", + "GTJA_alpha086", + "GTJA_alpha087", + "GTJA_alpha088", + "GTJA_alpha089", + "GTJA_alpha090", + "GTJA_alpha091", + "GTJA_alpha092", + "GTJA_alpha093", + "GTJA_alpha094", + "GTJA_alpha095", + "GTJA_alpha096", + "GTJA_alpha097", + "GTJA_alpha098", + "GTJA_alpha099", + "GTJA_alpha100", + "GTJA_alpha101", + "GTJA_alpha102", + "GTJA_alpha103", + "GTJA_alpha104", + "GTJA_alpha105", + "GTJA_alpha106", + "GTJA_alpha107", + "GTJA_alpha108", + "GTJA_alpha109", + "GTJA_alpha110", + "GTJA_alpha111", + "GTJA_alpha112", + # "GTJA_alpha113", + "GTJA_alpha114", + "GTJA_alpha115", + "GTJA_alpha117", + "GTJA_alpha118", + "GTJA_alpha119", + "GTJA_alpha120", + # "GTJA_alpha121", + "GTJA_alpha122", + "GTJA_alpha123", + "GTJA_alpha124", + "GTJA_alpha125", + "GTJA_alpha126", + "GTJA_alpha127", + "GTJA_alpha128", + "GTJA_alpha129", + "GTJA_alpha130", + "GTJA_alpha131", + "GTJA_alpha132", + "GTJA_alpha133", + "GTJA_alpha134", + "GTJA_alpha135", + "GTJA_alpha136", + # "GTJA_alpha138", + "GTJA_alpha139", + # "GTJA_alpha140", + "GTJA_alpha141", + "GTJA_alpha142", + "GTJA_alpha145", + # "GTJA_alpha146", + "GTJA_alpha148", + "GTJA_alpha150", + "GTJA_alpha151", + "GTJA_alpha152", + "GTJA_alpha153", + "GTJA_alpha154", + "GTJA_alpha155", + "GTJA_alpha156", + "GTJA_alpha157", + "GTJA_alpha158", + "GTJA_alpha159", + "GTJA_alpha160", + "GTJA_alpha161", + # "GTJA_alpha162", + "GTJA_alpha163", + "GTJA_alpha164", + # "GTJA_alpha165", + "GTJA_alpha166", + "GTJA_alpha167", + "GTJA_alpha168", + "GTJA_alpha169", + "GTJA_alpha170", + "GTJA_alpha171", + "GTJA_alpha173", + "GTJA_alpha174", + "GTJA_alpha175", + "GTJA_alpha176", + "GTJA_alpha177", + "GTJA_alpha178", + "GTJA_alpha179", + "GTJA_alpha180", + # "GTJA_alpha183", + "GTJA_alpha184", + "GTJA_alpha185", + "GTJA_alpha187", + "GTJA_alpha188", + "GTJA_alpha189", + "GTJA_alpha191", "chip_dispersion_90", "chip_dispersion_70", "cost_skewness", @@ -270,12 +270,27 @@ SELECTED_FACTORS = [ "bottom_cost_stability", "pivot_reversion", "chip_transition", + # "amivest_liq_20", + # "atr_price_impact", + # "hui_heubel_ratio", + # "corwin_schultz_spread_20", + # "roll_spread_20", + # "gibbs_effective_spread", + # "overnight_illiq_20", + # "illiq_volatility_20", + # "amount_cv_20", + # "amount_skewness_20", + # "low_vol_days_20", + # "liquidity_shock_momentum", + # "downside_illiq_20", + # "upside_illiq_20", + # "illiq_asymmetry_20", + # "pastor_stambaugh_proxy" ] # 因子定义字典(完整因子库,用于存放尚未注册到metadata的因子) FACTOR_DEFINITIONS = {"cs_rank_circ_mv": "cs_rank(circ_mv)"} - # ============================================================================= # Label 配置(统一绑定 label_name 和 label_dsl) # ============================================================================= @@ -308,11 +323,11 @@ def get_label_factor(label_name: str) -> dict: # 辅助函数 # ============================================================================= def register_factors( - engine: FactorEngine, - selected_factors: List[str], - factor_definitions: dict, - label_factor: dict, - excluded_factors: Optional[List[str]] = None, + engine: FactorEngine, + selected_factors: List[str], + factor_definitions: dict, + label_factor: dict, + excluded_factors: Optional[List[str]] = None, ) -> List[str]: """注册因子。 @@ -393,11 +408,11 @@ def register_factors( def prepare_data( - engine: FactorEngine, - feature_cols: List[str], - start_date: str, - end_date: str, - label_name: str, + engine: FactorEngine, + feature_cols: List[str], + start_date: str, + end_date: str, + label_name: str, ) -> pl.DataFrame: """准备数据。 @@ -455,11 +470,11 @@ def stock_pool_filter(df: pl.DataFrame) -> pl.Series: """ # 代码筛选(排除创业板、科创板、北交所) code_filter = ( - ~df["ts_code"].str.starts_with("30") # 排除创业板 - & ~df["ts_code"].str.starts_with("68") # 排除科创板 - & ~df["ts_code"].str.starts_with("8") # 排除北交所 - & ~df["ts_code"].str.starts_with("9") # 排除北交所 - & ~df["ts_code"].str.starts_with("4") # 排除北交所 + ~df["ts_code"].str.starts_with("30") # 排除创业板 + & ~df["ts_code"].str.starts_with("68") # 排除科创板 + & ~df["ts_code"].str.starts_with("8") # 排除北交所 + & ~df["ts_code"].str.starts_with("9") # 排除北交所 + & ~df["ts_code"].str.starts_with("4") # 排除北交所 ) # 在已筛选的股票中,选取流通市值最小的500只 @@ -474,7 +489,6 @@ def stock_pool_filter(df: pl.DataFrame) -> pl.Series: # 定义筛选所需的基础列 STOCK_FILTER_REQUIRED_COLUMNS = ["circ_mv"] - # ============================================================================= # 输出配置 # ============================================================================= @@ -518,7 +532,7 @@ def get_output_path(model_type: str, test_start: str, test_end: str) -> str: def get_model_save_path( - model_type: str, + model_type: str, ) -> Optional[str]: """生成模型保存路径。 @@ -544,11 +558,11 @@ def get_model_save_path( def save_model_with_factors( - model, - model_path: str, - selected_factors: list[str], - factor_definitions: dict, - fitted_processors: list | None = None, + model, + model_path: str, + selected_factors: list[str], + factor_definitions: dict, + fitted_processors: list | None = None, ) -> str: """保存模型及关联的因子信息和处理器。 diff --git a/src/experiment/learn_to_rank.py b/src/experiment/learn_to_rank.py index 5eca623..7e34e2c 100644 --- a/src/experiment/learn_to_rank.py +++ b/src/experiment/learn_to_rank.py @@ -54,31 +54,27 @@ N_QUANTILES = 20 # 排除的因子列表 EXCLUDED_FACTORS = [ - "GTJA_alpha010", - "GTJA_alpha005", - "GTJA_alpha002", - "GTJA_alpha027", - "GTJA_alpha051", - "GTJA_alpha044", - "GTJA_alpha041", - "GTJA_alpha131", - "GTJA_alpha103", - "GTJA_alpha087", - "GTJA_alpha093", - "GTJA_alpha092", - "GTJA_alpha073", - "GTJA_alpha127", - "GTJA_alpha117", - "GTJA_alpha124", - "GTJA_alpha162", - "GTJA_alpha177", - "GTJA_alpha188", - "smart_money_accumulation", - "GTJA_alpha014", - "GTJA_alpha056", - "GTJA_alpha085", - "GTJA_alpha154", - "GTJA_alpha141", + 'active_market_cap', + 'close_vwap_deviation', + 'sharpe_ratio_20', + 'upper_shadow_ratio', + 'volume_ratio_5_20', + 'GTJA_alpha090', + 'GTJA_alpha084', + 'GTJA_alpha066', + 'GTJA_alpha150', + 'GTJA_alpha148', + 'GTJA_alpha106', + 'GTJA_alpha109', + 'GTJA_alpha108', + 'GTJA_alpha176', + 'GTJA_alpha169', + 'GTJA_alpha156', + 'chip_dispersion_70', + 'winner_rate_cs_rank', + 'atr_price_impact', + 'low_vol_days_20', + 'liquidity_shock_momentum', ] # LambdaRank 模型参数配置 diff --git a/src/experiment/regression.py b/src/experiment/regression.py index 045debc..be204ae 100644 --- a/src/experiment/regression.py +++ b/src/experiment/regression.py @@ -52,55 +52,36 @@ TRAINING_TYPE = "regression" # 排除的因子列表 EXCLUDED_FACTORS = [ - "GTJA_alpha036", - "GTJA_alpha032", - "GTJA_alpha010", - "GTJA_alpha005", - "CP", - "BP", - "debt_to_equity", - "current_ratio", - "GTJA_alpha002", - "GTJA_alpha027", - "GTJA_alpha064", - "GTJA_alpha062", - "GTJA_alpha043", - "GTJA_alpha044", - "GTJA_alpha120", - "GTJA_alpha117", - "GTJA_alpha103", - "GTJA_alpha104", - "GTJA_alpha105", - "GTJA_alpha073", - "GTJA_alpha077", - "GTJA_alpha085", - "GTJA_alpha090", - "GTJA_alpha087", - "GTJA_alpha083", - "GTJA_alpha092", - "GTJA_alpha133", - "GTJA_alpha131", - "GTJA_alpha126", - "GTJA_alpha124", - "GTJA_alpha162", - "GTJA_alpha164", - "GTJA_alpha157", - "GTJA_alpha177", - "price_to_avg_cost", - "cost_skewness", - "GTJA_alpha191", - "GTJA_alpha180", - "history_position", - "bottom_profit", - "mean_median_dev", - "smart_money_accumulation", - "GTJA_alpha013", - "GTJA_alpha099", - "GTJA_alpha107", - "GTJA_alpha119", - "GTJA_alpha141", - "GTJA_alpha130", - "GTJA_alpha173", + 'GTJA_alpha016', + 'volatility_20', + 'current_ratio', + 'GTJA_alpha001', + 'GTJA_alpha141', + 'GTJA_alpha129', + 'GTJA_alpha164', + 'amivest_liq_20', + 'GTJA_alpha012', + 'debt_to_equity', + 'turnover_deviation', + 'GTJA_alpha073', + 'GTJA_alpha043', + 'GTJA_alpha032', + 'GTJA_alpha028', + 'GTJA_alpha090', + 'GTJA_alpha108', + 'GTJA_alpha105', + 'GTJA_alpha091', + 'GTJA_alpha119', + 'GTJA_alpha104', + 'GTJA_alpha163', + 'GTJA_alpha157', + 'cost_skewness', + 'GTJA_alpha176', + 'chip_transition', + 'amount_skewness_20', + 'GTJA_alpha148', + 'mean_median_dev', + 'downside_illiq_20', ] # 模型参数配置 diff --git a/src/scripts/register_factors.py b/src/scripts/register_factors.py index 76108e2..102655b 100644 --- a/src/scripts/register_factors.py +++ b/src/scripts/register_factors.py @@ -34,112 +34,99 @@ from src.config.settings import get_settings # ============================================================================ FACTORS: List[Dict[str, Any]] =[ - # ==================== 第一类:筹码集中度与离散度因子 ==================== { - "name": "chip_dispersion_90", - "desc": "90%筹码离散度:衡量市场90%持仓筹码的宽度,值越小表示筹码越高度集中(单峰密集),往往是洗盘结束的前兆", - "dsl": "(cost_95pct - cost_5pct) / (cost_95pct + cost_5pct)", + "name": "amihud_illiq_20", + "desc": "Amihud非流动性指标(20日):绝对收益率/成交额。该值越大,说明少量的资金就能砸盘或拉升,流动性极差,隐含极高的风险补偿预期", + "dsl": "ts_mean(abs(pct_chg) / (amount + 1), 20)", }, { - "name": "chip_dispersion_70", - "desc": "70%核心筹码离散度:剔除极端的底部死筹和高位套牢盘,反映中间70%主流资金的成本集中度", - "dsl": "(cost_85pct - cost_15pct) / (cost_85pct + cost_15pct)", + "name": "amivest_liq_20", + "desc": "Amivest流动性指标(20日):Amihud的倒数变种,衡量推动1%价格变化需要的资金量。值越低,流动性溢价越高", + "dsl": "ts_mean(amount / (abs(pct_chg) + 0.001), 20)", }, { - "name": "cost_skewness", - "desc": "筹码偏度:反映筹码分布的不对称性。大于1说明上方套牢盘拖尾严重,小于1说明下方获利盘雄厚", - "dsl": "(cost_95pct - cost_50pct) / (cost_50pct - cost_5pct)", + "name": "atr_price_impact", + "desc": "真实波幅价格冲击(20日):以ATR(真实波幅)代替绝对收益,剔除跳空影响后的真实交易冲击", + "dsl": "ts_atr(high, low, close, 20) / close / (ts_mean(amount, 20) + 1)", }, { - "name": "dispersion_change_20", - "desc": "筹码集中度近期变化率:过去20天筹码宽度的变化比例,持续下降说明主力正在暗中吸筹", - "dsl": "ts_pct_change((cost_95pct - cost_5pct) / cost_50pct, 20)", + "name": "hui_heubel_ratio", + "desc": "Hui-Heubel流动性比率:利用高低价区间占均价的比例,除以成交额占比,捕捉阶段性区间的深度非流动性", + "dsl": "((ts_max(high, 20) - ts_min(low, 20)) / ts_min(low, 20)) / (ts_mean(amount, 20) + 1)", }, - # ==================== 第二类:筹码相对位置与压力/支撑因子 ==================== + # --- 维度 2: 交易摩擦与隐形价差类 (Spread Proxies) --- + # 逻辑:A股没有Tick级买卖价差数据,通常用日线的高低价或自协方差来推导隐形买卖价差。价差越大,交易成本越高,持有需高收益补偿。 { - "name": "price_to_avg_cost", - "desc": "整体浮盈比例:当前价格相对加权平均成本的溢价率。高溢价有均值回归压力,负溢价代表超跌", - "dsl": "(close - weight_avg) / weight_avg", + "name": "corwin_schultz_spread_20", + "desc": "Corwin-Schultz买卖价差代理:利用每日(最高价-最低价)/收盘价的均值衡量交易摩擦,摩擦越大的股票往往带有小盘股超额收益", + "dsl": "ts_mean((high - low) / close, 20)", }, { - "name": "price_to_median_cost", - "desc": "中位数成本偏离度:价格相对于50%分位点(绝对半数人持仓价)的偏离,向上突破通常是右侧买点", - "dsl": "(close - cost_50pct) / cost_50pct", + "name": "roll_spread_20", + "desc": "Roll买卖价差代理(20日):经典微观结构模型,计算相邻两日收益率的负协方差的平方根,反映做市商的隐形报价跳跃", + "dsl": "sqrt(max_(-ts_cov(change, ts_delay(change, 1), 20), 0))", }, { - "name": "mean_median_dev", - "desc": "均值中位数背离:均值显著大于中位数说明高位筹码堆积,上涨阻力大", - "dsl": "(weight_avg - cost_50pct) / cost_50pct", + "name": "gibbs_effective_spread", + "desc": "有效价差代理:使用日内振幅减去隔夜跳空幅度后的纯日内摩擦成本", + "dsl": "ts_mean(((high - low) - abs(open - ts_delay(close, 1))) / close, 20)", }, { - "name": "trap_pressure", - "desc": "高位套牢盘压力指数:当前价格距离上方95%高位套牢成本的距离。距离越大,反弹的真空期阻力越小", - "dsl": "(cost_95pct - close) / close", - }, - { - "name": "bottom_profit", - "desc": "底部支撑底仓利润率:当前价格距离底部5%筹码的利润空间。暴跌时大于0说明底仓极度稳定", - "dsl": "(close - cost_5pct) / cost_5pct", - }, - { - "name": "history_position", - "desc": "历史区间分位点:当前价格在个股上市以来历史最高点和最低点之间的相对位置", - "dsl": "(close - his_low) / (his_high - his_low)", + "name": "overnight_illiq_20", + "desc": "隔夜非流动性:开盘价相对于昨日收盘的跳空幅度,除以昨日成交额。隔夜极易跳空的股票具有夜间流动性溢价", + "dsl": "ts_mean(abs(open - ts_delay(close, 1)) / (ts_delay(amount, 1) + 1), 20)", }, - # ==================== 第三类:胜率相关的动量与反转因子 ==================== + # --- 维度 3: 流动性风险与枯竭类 (Liquidity Risk & Depletion) --- + # 逻辑:投资者不仅讨厌“平时流动性差”,更讨厌“流动性极其不稳定”。 { - "name": "winner_rate_surge_5", - "desc": "获利盘短期爆发力:胜率在过去5天内的变化值,急剧上升是极强的动量做多信号", - "dsl": "ts_delta(winner_rate, 5)", + "name": "illiq_volatility_20", + "desc": "Amihud非流动性的波动率(20日):衡量价格冲击的不确定性(即流动性风险本身)。波动越大的股越容易踩踏", + "dsl": "ts_std(abs(pct_chg) / (amount + 1), 20)", }, { - "name": "winner_rate_cs_rank", - "desc": "获利盘高位反转信号:全市场胜率截面排名,极端高胜率往往面临多头踩踏的获利了结压力(反转因子)", - "dsl": "cs_rank(winner_rate)", + "name": "amount_cv_20", + "desc": "成交额变异系数(20日):成交额的波动率除以均值。反映股票被市场关注的极度不稳定性", + "dsl": "ts_std(amount, 20) / (ts_mean(amount, 20) + 1)", }, { - "name": "winner_rate_dev_20", - "desc": "获利盘均线偏离:当前胜率相对过去20天平均胜率的偏离程度,捕捉筹码情绪的边际超买/超卖", - "dsl": "winner_rate - ts_mean(winner_rate, 20)", + "name": "amount_skewness_20", + "desc": "成交额偏度(20日):正偏度意味着平时成交极度清淡,偶尔脉冲式放量。这种“死鱼”状态是经典的非流动性特征", + "dsl": "ts_skew(amount, 20)", }, { - "name": "winner_rate_volatility", - "desc": "获利盘波动率:过去20天胜率的波动率。波动率低且胜率高说明单边上涨极度稳健", - "dsl": "ts_std(winner_rate, 20)", + "name": "low_vol_days_20", + "desc": "流动性枯竭天数:过去20天内,成交额低于长期(60日)均值一半的极端缩量天数", + "dsl": "ts_count(amount < ts_mean(amount, 60) * 0.5, 20)", }, { - "name": "smart_money_accumulation", - "desc": "潜在主力吸筹隐蔽指标:胜率的60日时序分位数减去价格的时序分位数。值越大说明‘价平而获利盘增’,底部吸筹明显", - "dsl": "ts_rank(winner_rate, 60) - ts_rank(close, 60)", + "name": "liquidity_shock_momentum", + "desc": "流动性恶化动量:近期(5日)非流动性相较于长期的变化。正值代表流动性正在迅速干涸", + "dsl": "ts_mean(abs(pct_chg) / (amount + 1), 5) - ts_mean(abs(pct_chg) / (amount + 1), 20)", }, - # ==================== 第四类:量价与筹码交乘因子 ==================== + # --- 维度 4: 非对称流动性类 (Asymmetric / Downside Illiquidity) --- + # 逻辑:买入时没问题,但“大跌时卖不出去(流动性丧失)”是最致命的风险,这部分风险被定价的权重最高。 { - "name": "winner_vol_corr_20", - "desc": "放量突破筹码密集区:胜率与成交量的20日时序相关性,正相关说明增量资金在主动解套上方筹码", - "dsl": "ts_corr(winner_rate, vol, 20)", + "name": "downside_illiq_20", + "desc": "下行非流动性:仅在股价下跌日计算的价格冲击。捕捉‘跌时没人接盘’的极端流动性折价", + "dsl": "ts_sum(where(change < 0, abs(change) / (amount + 1), 0), 20) / (ts_count(change < 0, 20) + 1)", }, { - "name": "cost_base_momentum", - "desc": "成本重心上移换手率:过去20天加权平均成本的变化幅度,快速上移说明高位换手极其充分", - "dsl": "ts_pct_change(weight_avg, 20)", + "name": "upside_illiq_20", + "desc": "上行非流动性:仅在股价上涨日计算的价格冲击。捕捉‘涨时抛压极轻’的状态", + "dsl": "ts_sum(where(change > 0, abs(change) / (amount + 1), 0), 20) / (ts_count(change > 0, 20) + 1)", }, { - "name": "bottom_cost_stability", - "desc": "底部坚如磐石因子:底部5%成本的60天波动率相对于中位数的比值,波动越小说明死筹越稳固", - "dsl": "ts_std(cost_5pct, 60) / cost_50pct", + "name": "illiq_asymmetry_20", + "desc": "非对称流动性比率:下行流动性恶化程度除以加上行流动性恶化程度。该值远大于1说明下跌时发生严重踩踏,股票本身必须折价(预期收益率补偿极高)", + "dsl": "(ts_sum(where(change < 0, abs(change)/(amount+1), 0), 20) + 1) / (ts_sum(where(change > 0, abs(change)/(amount+1), 0), 20) + 1)", }, { - "name": "pivot_reversion", - "desc": "盈亏分界线乖离修复:价格偏离50%分位点除以近20日价格标准差,用于寻找超跌后的均值回归买点", - "dsl": "(close - cost_50pct) / ts_std(close, 20)", - }, - { - "name": "chip_transition", - "desc": "强弱筹码切换度:上方厚度与下方厚度差值的20日变化量。由正变负说明筹码彻底完成了自上而下的转移(洗盘结束)", - "dsl": "ts_delta((cost_85pct - cost_50pct) - (cost_50pct - cost_15pct), 20)", + "name": "pastor_stambaugh_proxy", + "desc": "Pastor-Stambaugh流动性贝塔代理:收益率与滞后一期带有符号(涨跌)成交额的相关性。反映市场由于流动性短缺导致的价格过度反转现象", + "dsl": "ts_corr(change, sign(ts_delay(change, 1)) * ts_delay(amount, 1), 20)", }, ] diff --git a/src/training/components/models/ensemble_quant_loss.py b/src/training/components/models/ensemble_quant_loss.py index a7fe564..843e2ed 100644 --- a/src/training/components/models/ensemble_quant_loss.py +++ b/src/training/components/models/ensemble_quant_loss.py @@ -84,3 +84,65 @@ class EnsembleQuantLoss(nn.Module): total_loss += self.alpha * h_loss + (1.0 - self.alpha) * ic_loss return total_loss / self.ensemble_size + + +class AsymmetricQuantLoss(nn.Module): + """Asymmetric Quant Loss (非对称 Huber + IC) + + 保留全截面计算以维持稳定梯度的前提下,对多头关注的错误进行加权惩罚: + 1. 过度高估烂股票 (买入陷阱) -> 加重惩罚 + 2. 严重低估好股票 (错失金股) -> 加重惩罚 + """ + + def __init__(self, alpha: float = 0.5, ensemble_size: int = 32): + super().__init__() + self.alpha = alpha + self.ensemble_size = ensemble_size + # 我们不使用内置的 HuberLoss,而是手动写以支持 element-wise 加权 + self.delta = 1.0 # Huber threshold (可以设得比较小,比如对于收益率设为 0.05) + + def forward(self, preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + batch_size = preds.shape[0] + total_loss = 0.0 + + if batch_size < 10: + # 极小 batch 退回简单 MSE + return ((preds - target.unsqueeze(1)) ** 2).mean() + + target_mean = target.mean() + target_std = target.std(unbiased=False) + 1e-8 + target_norm = (target - target_mean) / target_std + + for i in range(self.ensemble_size): + pred_i = preds[:, i] + + # --- 1. 非对称 Huber 损失 --- + error = pred_i - target + abs_error = torch.abs(error) + + # 标准 Huber 计算 + quadratic = torch.clamp(abs_error, max=self.delta) + linear = abs_error - quadratic + base_huber = 0.5 * quadratic**2 + self.delta * linear + + # 【核心逻辑】:非对称权重 + # 如果 target > 0 且 pred < target (错过了涨的好票) -> 权重 1.5 + # 如果 target < 0 且 pred > 0 (把跌的票预测成了涨的,容易实盘买入踩雷) -> 权重 2.0 + # 其他情况 (比如烂票预测得更烂) -> 权重 1.0 正常学习 + weights = torch.ones_like(target) + weights[(target > 0) & (error < 0)] = 1.5 + weights[(target < 0) & (pred_i > 0)] = 2.0 + + weighted_huber = (base_huber * weights).mean() + + # --- 2. IC 损失 (维持全截面排序能力) --- + pred_mean = pred_i.mean() + pred_std = pred_i.std(unbiased=False) + 1e-8 + pred_norm = (pred_i - pred_mean) / pred_std + + ic = (pred_norm * target_norm).mean() + ic_loss = 1.0 - ic + + total_loss += self.alpha * weighted_huber + (1.0 - self.alpha) * ic_loss + + return total_loss / self.ensemble_size diff --git a/src/training/components/models/tabm_model.py b/src/training/components/models/tabm_model.py index f8063a5..cece5a0 100644 --- a/src/training/components/models/tabm_model.py +++ b/src/training/components/models/tabm_model.py @@ -19,7 +19,10 @@ from tabm import TabM from src.training.components.base import BaseModel from src.training.components.models.cross_section_sampler import CrossSectionSampler -from src.training.components.models.ensemble_quant_loss import EnsembleQuantLoss +from src.training.components.models.ensemble_quant_loss import ( + EnsembleQuantLoss, + AsymmetricQuantLoss, +) from src.training.registry import register_model @@ -235,8 +238,8 @@ class TabMModel(BaseModel): optimizer, T_max=epochs, eta_min=1e-6 ) - # 使用 EnsembleQuantLoss 替代 MSE - self.criterion = EnsembleQuantLoss( + # 使用 AsymmetricQuantLoss (非对称 Huber + IC) + self.criterion = AsymmetricQuantLoss( alpha=self.params.get("loss_alpha", 0.5), ensemble_size=ensemble_size ) diff --git a/tests/test_moneyflow.py b/tests/test_moneyflow.py new file mode 100644 index 0000000..a711842 --- /dev/null +++ b/tests/test_moneyflow.py @@ -0,0 +1,250 @@ +"""Tests for api_moneyflow module. + +Tests the moneyflow (个股资金流向) API wrapper. +""" + +import pytest +import pandas as pd +from unittest.mock import patch, MagicMock + +from src.data.api_wrappers.api_moneyflow import ( + get_moneyflow, + sync_moneyflow, + preview_moneyflow_sync, + MoneyflowSync, +) + + +class TestMoneyflowAPI: + """Test suite for moneyflow API wrapper.""" + + @patch("src.data.api_wrappers.api_moneyflow.TushareClient") + def test_get_moneyflow_by_date(self, mock_client_class): + """Test fetching moneyflow data by date.""" + # Setup mock + mock_client = MagicMock() + mock_client_class.return_value = mock_client + mock_client.query.return_value = pd.DataFrame( + { + "ts_code": ["000001.SZ", "000002.SZ"], + "trade_date": ["20240115", "20240115"], + "buy_sm_vol": [10000, 20000], + "buy_sm_amount": [100.5, 200.5], + "sell_sm_vol": [8000, 15000], + "sell_sm_amount": [80.5, 150.5], + "buy_md_vol": [5000, 10000], + "buy_md_amount": [500.5, 1000.5], + "sell_md_vol": [4000, 8000], + "sell_md_amount": [400.5, 800.5], + "buy_lg_vol": [2000, 5000], + "buy_lg_amount": [2000.5, 5000.5], + "sell_lg_vol": [1500, 4000], + "sell_lg_amount": [1500.5, 4000.5], + "buy_elg_vol": [1000, 3000], + "buy_elg_amount": [5000.5, 15000.5], + "sell_elg_vol": [800, 2500], + "sell_elg_amount": [4000.5, 12500.5], + "net_mf_vol": [2700, 8000], + "net_mf_amount": [220.0, 550.0], + } + ) + + # Test + result = get_moneyflow(trade_date="20240115") + + # Assert + assert not result.empty + assert len(result) == 2 + assert "ts_code" in result.columns + assert "trade_date" in result.columns + assert "buy_sm_vol" in result.columns + assert "net_mf_amount" in result.columns + mock_client.query.assert_called_once_with("moneyflow", trade_date="20240115") + + @patch("src.data.api_wrappers.api_moneyflow.TushareClient") + def test_get_moneyflow_by_stock(self, mock_client_class): + """Test fetching moneyflow data by stock code.""" + # Setup mock + mock_client = MagicMock() + mock_client_class.return_value = mock_client + mock_client.query.return_value = pd.DataFrame( + { + "ts_code": ["000001.SZ", "000001.SZ"], + "trade_date": ["20240114", "20240115"], + "buy_sm_vol": [10000, 11000], + "buy_sm_amount": [100.5, 110.5], + "sell_sm_vol": [8000, 8500], + "sell_sm_amount": [80.5, 85.5], + "buy_md_vol": [5000, 5500], + "buy_md_amount": [500.5, 550.5], + "sell_md_vol": [4000, 4200], + "sell_md_amount": [400.5, 420.5], + "buy_lg_vol": [2000, 2200], + "buy_lg_amount": [2000.5, 2200.5], + "sell_lg_vol": [1500, 1600], + "sell_lg_amount": [1500.5, 1600.5], + "buy_elg_vol": [1000, 1100], + "buy_elg_amount": [5000.5, 5500.5], + "sell_elg_vol": [800, 850], + "sell_elg_amount": [4000.5, 4250.5], + "net_mf_vol": [2700, 2950], + "net_mf_amount": [220.0, 245.0], + } + ) + + # Test + result = get_moneyflow( + ts_code="000001.SZ", start_date="20240101", end_date="20240131" + ) + + # Assert + assert not result.empty + assert len(result) == 2 + mock_client.query.assert_called_once_with( + "moneyflow", ts_code="000001.SZ", start_date="20240101", end_date="20240131" + ) + + @patch("src.data.api_wrappers.api_moneyflow.TushareClient") + def test_get_moneyflow_empty_response(self, mock_client_class): + """Test handling empty response.""" + # Setup mock + mock_client = MagicMock() + mock_client_class.return_value = mock_client + mock_client.query.return_value = pd.DataFrame() + + # Test + result = get_moneyflow(trade_date="20240101") + + # Assert + assert result.empty + + @patch("src.data.api_wrappers.api_moneyflow.TushareClient") + def test_get_moneyflow_with_shared_client(self, mock_client_class): + """Test fetching with shared client for rate limiting.""" + # Setup mock shared client + mock_shared_client = MagicMock() + mock_shared_client.query.return_value = pd.DataFrame( + { + "ts_code": ["000001.SZ"], + "trade_date": ["20240115"], + "buy_sm_vol": [10000], + "net_mf_amount": [220.0], + } + ) + + # Test with shared client + result = get_moneyflow(trade_date="20240115", client=mock_shared_client) + + # Assert + assert not result.empty + mock_shared_client.query.assert_called_once_with( + "moneyflow", trade_date="20240115" + ) + mock_client_class.assert_not_called() # Should not create new client + + @patch("src.data.api_wrappers.api_moneyflow.TushareClient") + def test_get_moneyflow_date_column_rename(self, mock_client_class): + """Test that 'date' column is renamed to 'trade_date'.""" + # Setup mock with 'date' column (should be renamed) + mock_client = MagicMock() + mock_client_class.return_value = mock_client + mock_client.query.return_value = pd.DataFrame( + { + "ts_code": ["000001.SZ"], + "date": ["20240115"], # Using 'date' instead of 'trade_date' + "buy_sm_vol": [10000], + "net_mf_amount": [220.0], + } + ) + + # Test + result = get_moneyflow(trade_date="20240115") + + # Assert - date column should be renamed to trade_date + assert "trade_date" in result.columns + assert "date" not in result.columns + assert result["trade_date"].iloc[0] == "20240115" + + +class TestMoneyflowSync: + """Test suite for MoneyflowSync class.""" + + def test_moneyflow_sync_class_attributes(self): + """Test MoneyflowSync class has correct attributes.""" + assert MoneyflowSync.table_name == "moneyflow" + assert MoneyflowSync.default_start_date == "20100101" + assert "ts_code" in MoneyflowSync.TABLE_SCHEMA + assert "trade_date" in MoneyflowSync.TABLE_SCHEMA + assert "buy_sm_vol" in MoneyflowSync.TABLE_SCHEMA + assert "net_mf_amount" in MoneyflowSync.TABLE_SCHEMA + assert MoneyflowSync.PRIMARY_KEY == ("ts_code", "trade_date") + + @patch("src.data.api_wrappers.api_moneyflow.get_moneyflow") + @patch("src.data.api_wrappers.api_moneyflow.TushareClient") + def test_fetch_single_date(self, mock_client_class, mock_get_moneyflow): + """Test fetch_single_date method.""" + # Setup mock + mock_get_moneyflow.return_value = pd.DataFrame( + { + "ts_code": ["000001.SZ"], + "trade_date": ["20240115"], + "buy_sm_vol": [10000], + "net_mf_amount": [220.0], + } + ) + + # Create sync instance + sync = MoneyflowSync() + + # Test + result = sync.fetch_single_date("20240115") + + # Assert + assert not result.empty + mock_get_moneyflow.assert_called_once_with( + trade_date="20240115", client=sync.client + ) + + @patch("src.data.api_wrappers.api_moneyflow.MoneyflowSync.sync_all") + def test_sync_moneyflow_function(self, mock_sync_all): + """Test sync_moneyflow convenience function.""" + # Setup mock + mock_sync_all.return_value = pd.DataFrame( + { + "ts_code": ["000001.SZ"], + "trade_date": ["20240115"], + } + ) + + # Test + result = sync_moneyflow(start_date="20240101", end_date="20240131") + + # Assert + assert not result.empty + mock_sync_all.assert_called_once_with( + start_date="20240101", end_date="20240131", force_full=False + ) + + @patch("src.data.api_wrappers.api_moneyflow.MoneyflowSync.preview_sync") + def test_preview_moneyflow_sync_function(self, mock_preview_sync): + """Test preview_moneyflow_sync convenience function.""" + # Setup mock + mock_preview_sync.return_value = { + "sync_needed": True, + "date_count": 31, + "start_date": "20240101", + "end_date": "20240131", + "estimated_records": 10000, + "sample_data": pd.DataFrame(), + "mode": "incremental", + } + + # Test + result = preview_moneyflow_sync(start_date="20240101", end_date="20240131") + + # Assert + assert result["sync_needed"] is True + assert result["date_count"] == 31 + mock_preview_sync.assert_called_once_with( + start_date="20240101", end_date="20240131", force_full=False, sample_size=3 + )