"""
Unified factor computation module.

Provides a single entry point for computing every type of factor
(technical, momentum, money-flow, chip-distribution, sentiment,
industry/cross-sectional, financial and special factors).
"""

import polars as pl
from typing import Any, Dict, List, Optional, Tuple

from main.factor.operator_framework import FactorGraph
from main.factor import (
    # Technical indicator factors
    SMAFactor,
    EMAFactor,
    ATRFactor,
    OBVFactor,
    MACDFactor,
    RSI_Factor,
    CrossSectionalRankFactor,
    # Momentum factors
    ReturnFactor,
    VolatilityFactor,
    MomentumFactor,
    MomentumAcceleration,
    TrendEfficiency,
    # Money-flow factors
    LGFlowFactor,
    FlowIntensityFactor,
    FlowDivergenceFactor,
    FlowStructureFactor,
    FlowAccelerationFactor,
    # Chip-distribution factors
    ChipConcentrationFactor,
    ChipSkewnessFactor,
    FloatingChipFactor,
    CostSupportFactor,
    WinnerPriceZoneFactor,
    CostSqueeze,
    HighCostSelling,
    LowCostAccumulation,
    InstNetAccum,
    ChipLockin,
    RetailOutInstIn,
    AccumAccel,
    # Market-sentiment factors
    SentimentPanicGreedFactor,
    SentimentBreadthFactor,
    SentimentReversalFactor,
    PriceDeductionFactor,
    PriceDeductionRatioFactor,
    IndustryMomentumLeadership,
    LeadershipPersistenceScore,
    DynamicIndustryLeadership,
    # Industry / cross-sectional factors
    IndustryMomentumFactor,
    MarketBreadthFactor,
    SectorRotationFactor,
    # Financial factors
    CashflowToEVFactor,
    BookToPriceFactor,
    DebtToEquityFactor,
    ProfitMarginFactor,
    BMFactor,
    # Special factors
    LimitFactor,
    VolumeRatioFactor,
    BBI_RATIO_FACTOR,
    VolatilitySlopeFactor,
    PriceVolumeTrendFactor,
)


def _default_stock_factor_configs() -> List[Dict[str, Any]]:
    """Return a fresh list of the default stock (per-security) factor configs."""
    return [
        # Technical indicators
        {"class": SMAFactor, "params": {"window": 5}},
        {"class": SMAFactor, "params": {"window": 20}},
        {"class": EMAFactor, "params": {"window": 12}},
        {"class": EMAFactor, "params": {"window": 26}},
        {"class": ATRFactor, "params": {"window": 14}},
        {"class": OBVFactor, "params": {}},
        {
            "class": MACDFactor,
            "params": {"fast_period": 12, "slow_period": 26, "signal_period": 9},
        },
        {"class": RSI_Factor, "params": {"window": 14}},
        # Money-flow factors
        {"class": LGFlowFactor, "params": {}},
        {"class": FlowIntensityFactor, "params": {}},
        {"class": FlowDivergenceFactor, "params": {}},
        {"class": FlowStructureFactor, "params": {}},
        {"class": FlowAccelerationFactor, "params": {}},
        {"class": InstNetAccum, "params": {}},
        {"class": ChipLockin, "params": {}},
        {"class": RetailOutInstIn, "params": {}},
        {"class": AccumAccel, "params": {}},
        # Chip-distribution factors
        {"class": ChipConcentrationFactor, "params": {}},
        {"class": ChipSkewnessFactor, "params": {}},
        {"class": FloatingChipFactor, "params": {}},
        {"class": CostSupportFactor, "params": {}},
        {"class": WinnerPriceZoneFactor, "params": {}},
        {"class": LowCostAccumulation, "params": {}},
        {"class": HighCostSelling, "params": {}},
        {"class": CostSqueeze, "params": {}},
        # Market-sentiment factors
        {
            "class": SentimentPanicGreedFactor,
            "params": {"window_atr": 14, "window_smooth": 5},
        },
        {
            "class": SentimentBreadthFactor,
            "params": {"window_vol": 20, "window_smooth": 3},
        },
        {
            "class": SentimentReversalFactor,
            "params": {"window_ret": 5, "window_vol": 5},
        },
        {"class": PriceDeductionFactor, "params": {"n": 10}},
        {"class": PriceDeductionRatioFactor, "params": {"n": 10}},
        {"class": IndustryMomentumLeadership, "params": {}},
        {"class": LeadershipPersistenceScore, "params": {}},
        # Disabled; uncomment to enable.
        # {"class": DynamicIndustryLeadership, "params": {}},
        # Financial factors (most currently disabled).
        # NOTE: ROEFactor is not imported by this module; import it before enabling.
        # {"class": CashflowToEVFactor, "params": {}},
        # {"class": BookToPriceFactor, "params": {}},
        # {"class": ROEFactor, "params": {}},
        # {"class": DebtToEquityFactor, "params": {}},
        # {"class": ProfitMarginFactor, "params": {}},
        {"class": BMFactor, "params": {}},
        # Special factors
        {"class": LimitFactor, "params": {}},
        {"class": VolumeRatioFactor, "params": {}},
        {"class": BBI_RATIO_FACTOR, "params": {}},
        {
            "class": VolatilitySlopeFactor,
            "params": {"window_vol": 20, "window_slope": 5},
        },
        {"class": PriceVolumeTrendFactor, "params": {}},
        # Momentum factors (20- and 5-period returns, volatility, trend)
        {"class": ReturnFactor, "params": {"period": 20}},
        {"class": ReturnFactor, "params": {"period": 5}},
        {"class": VolatilityFactor, "params": {"period": 10}},
        {
            "class": MomentumAcceleration,
            "params": {"short_period": 5, "long_period": 60},
        },
        {"class": TrendEfficiency, "params": {"period": 10}},
        {
            "class": CrossSectionalRankFactor,
            "params": {"column": "circ_mv", "name": "size_rank"},
        },
    ]


def _default_date_factor_configs() -> List[Dict[str, Any]]:
    """Return a fresh list of the default date (cross-sectional) factor configs."""
    return [
        # The original default listed the return_5 rank twice; it is registered once.
        {"class": CrossSectionalRankFactor, "params": {"column": "return_5"}},
        {"class": CrossSectionalRankFactor, "params": {"column": "return_20"}},
        {"class": CrossSectionalRankFactor, "params": {"column": "volatility_10"}},
        {"class": CrossSectionalRankFactor, "params": {"column": "circ_mv"}},
        # Disabled; uncomment to enable.
        # {
        #     "class": CrossSectionalRankFactor,
        #     "params": {"factor_name": "momentum_10"},
        # },
    ]


def _build_factors(
    factor_graph: FactorGraph,
    configs: List[Dict[str, Any]],
    kind: str,
) -> List[Any]:
    """Instantiate each configured factor and register it on ``factor_graph``.

    Construction is best-effort. If building or registering one factor raises,
    an error line is printed and that factor is skipped, so one bad config does
    not abort the whole run.

    Parameters:
        factor_graph: graph the created factors are added to.
        configs: list of ``{"class": FactorClass, "params": {...}}`` dicts.
            A missing ``"params"`` key is treated as "no parameters".
        kind: label spliced into the error message ("股票" for stock factors,
            "日期" for date factors).

    Returns:
        The successfully created factor objects, in config order.
    """
    factors = []
    for config in configs:
        factor_class = config["class"]
        params = config.get("params", {})
        try:
            factor = factor_class(**params)
            factor_graph.add_factor(factor)
        except Exception as e:
            # Deliberate best-effort: report and continue with the remaining factors.
            print(f"创建{kind}因子 {factor_class.__name__} 时出错: {e}")
            continue
        factors.append(factor)
    return factors


def calculate_all_factors(
    df: pl.DataFrame,
    stock_factor_configs: Optional[List[Dict[str, Any]]] = None,
    date_factor_configs: Optional[List[Dict[str, Any]]] = None,
) -> Tuple[pl.DataFrame, List[Any]]:
    """
    Compute all configured factors on ``df`` with one shared FactorGraph.

    Stock (per-security) factors are computed first and date
    (cross-sectional) factors second. The default date factors rank columns
    such as ``return_5`` / ``volatility_10``, which are presumably produced by
    the stock-factor pass.

    Parameters:
        df (pl.DataFrame): input stock data table.
        stock_factor_configs (List[Dict] | None): stock factor configs, each
            ``{"class": FactorClass, "params": {...}}``. ``None`` selects the
            built-in defaults.
        date_factor_configs (List[Dict] | None): date (cross-sectional) factor
            configs in the same format. ``None`` selects the built-in defaults.

    Returns:
        Tuple[pl.DataFrame, List]: ``(result_df, factor_ids)``.
            ``result_df`` is the frame returned by ``FactorGraph.compute``.
            ``factor_ids`` lists the ``get_factor_id()`` values of every
            successfully created factor: stock factors first, then date
            factors, with duplicates removed (first occurrence kept).

    Factors that fail to construct or register are printed and skipped
    rather than raising.
    """
    factor_graph = FactorGraph()

    if stock_factor_configs is None:
        stock_factor_configs = _default_stock_factor_configs()
    if date_factor_configs is None:
        date_factor_configs = _default_date_factor_configs()

    stock_factors = _build_factors(factor_graph, stock_factor_configs, "股票")
    date_factors = _build_factors(factor_graph, date_factor_configs, "日期")

    stock_factor_ids = [factor.get_factor_id() for factor in stock_factors]
    date_factor_ids = [factor.get_factor_id() for factor in date_factors]

    # Compute on a clone of the input rather than on the caller's object.
    result_df = df.clone()

    # Stock factors first, so the date pass can see their output columns.
    if stock_factor_ids:
        result_df = factor_graph.compute(result_df, stock_factor_ids)
    if date_factor_ids:
        result_df = factor_graph.compute(result_df, date_factor_ids)

    # Order-preserving de-duplication. The list is small, so a linear
    # membership check is fine, and it does not require ids to be hashable.
    all_ids: List[Any] = []
    for factor_id in stock_factor_ids + date_factor_ids:
        if factor_id not in all_ids:
            all_ids.append(factor_id)

    return result_df, all_ids


# Simplified unified entry point, kept for compatibility with legacy call sites.
def compute_factors(df: pl.DataFrame) -> Tuple[pl.DataFrame, List[Any]]:
    """
    Simplified factor computation interface using the default configs.

    Parameters:
        df (pl.DataFrame): input stock data table.

    Returns:
        Tuple[pl.DataFrame, List]: the same ``(result_df, factor_ids)`` tuple
        returned by ``calculate_all_factors``.
    """
    return calculate_all_factors(df)