Files
NewStock/main/factor/industry_factors.py
2025-11-29 00:23:12 +08:00

91 lines
3.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
行业/横截面因子模块
包含基于日期截面的行业/横截面因子实现
"""
import numpy as np
import polars as pl
from main.factor.operator_framework import DateWiseFactor
class IndustryMomentumFactor(DateWiseFactor):
    """Industry momentum factor.

    For one date cross-section, yields each row's value of ``factor_name``
    minus the mean of that factor over all rows sharing the same
    ``cat_l2_code``.
    """

    def __init__(self, factor_name: str) -> None:
        """Register the factor with the framework.

        Args:
            factor_name: Name of the input factor column whose deviation from
                the industry mean is computed.
        """
        # NOTE(review): unlike SectorRotationFactor, the registered name does
        # not embed ``factor_name``; it is kept unchanged here so existing
        # factor ids are not affected.
        super().__init__(
            name="industry_momentum",
            parameters={"factor_name": factor_name},
            required_factor_ids=[factor_name, "cat_l2_code"],
        )
        self.factor_name = factor_name

    def calc_factor(self, group_df: pl.DataFrame) -> pl.Series:
        """Compute the deviation from the industry mean for one cross-section.

        Args:
            group_df: Rows of a single date group. Expected to contain the
                ``factor_name`` column and ``cat_l2_code``.

        Returns:
            A Series named ``self.factor_id`` with one value per input row, in
            the same row order as ``group_df``. Rows whose ``cat_l2_code`` is
            null get a null result. If a required column is missing, an
            all-null Float64 Series is returned.
        """
        available = group_df.columns
        if self.factor_name not in available or "cat_l2_code" not in available:
            # An explicit float dtype keeps the fallback compatible with the
            # numeric output of the normal path (an untyped all-None Series
            # would get the Null dtype).
            return pl.Series(
                name=self.factor_id,
                values=[None] * group_df.height,
                dtype=pl.Float64,
            )

        value = pl.col(self.factor_name)
        industry_mean = value.mean().over("cat_l2_code")

        # A window expression is evaluated row-by-row on the input frame, so
        # the result is aligned with group_df by construction. A join followed
        # by column subtraction depends on the join preserving row order and
        # on no pre-existing "industry_mean" column being present.
        deviation = (
            pl.when(pl.col("cat_l2_code").is_not_null())
            # Rows without an industry code stay null, matching a left join
            # in which null keys do not match.
            .then(value - industry_mean)
            .alias(self.factor_id)
        )
        return group_df.select(deviation).to_series()
class MarketBreadthFactor(DateWiseFactor):
    """Market breadth factor.

    For one date cross-section, the number of rows with ``pct_chg > 0``
    divided by the total number of rows; the same value is assigned to every
    row of the group.
    """

    def __init__(self):
        """Register the factor; it takes no parameters and needs ``pct_chg``."""
        super().__init__(
            name="market_breadth",
            parameters={},
            required_factor_ids=["pct_chg"],
        )

    def calc_factor(self, group_df: pl.DataFrame) -> pl.Series:
        """Return the rising-stock ratio broadcast over the group's rows.

        Args:
            group_df: Rows of a single date group containing ``pct_chg``.

        Returns:
            A Series named ``self.factor_id`` with ``len(group_df)`` entries,
            all equal to the breadth ratio.
        """
        n_rows = len(group_df)
        n_rising = (group_df["pct_chg"] > 0).sum()

        # The small epsilon in the denominator guards against division by
        # zero when the group is empty.
        ratio = n_rising / (n_rows + 1e-8)

        broadcast = [ratio for _ in range(n_rows)]
        return pl.Series(broadcast).alias(self.factor_id)
class SectorRotationFactor(DateWiseFactor):
    """Sector rotation factor.

    For one date cross-section, yields each row's value of ``sector_factor``
    minus the mean of that factor over all rows sharing the same
    ``cat_l2_code``.
    """

    def __init__(self, sector_factor: str) -> None:
        """Register the factor with the framework.

        Args:
            sector_factor: Name of the input factor column whose deviation
                from the sector mean is computed; it is also embedded in the
                registered factor name.
        """
        super().__init__(
            name=f"sector_rotation_{sector_factor}",
            parameters={"sector_factor": sector_factor},
            required_factor_ids=[sector_factor, "cat_l2_code"],
        )
        self.sector_factor = sector_factor

    def calc_factor(self, group_df: pl.DataFrame) -> pl.Series:
        """Compute the deviation from the sector mean for one cross-section.

        Args:
            group_df: Rows of a single date group. Expected to contain the
                ``sector_factor`` column and ``cat_l2_code``.

        Returns:
            A Series named ``self.factor_id`` with one value per input row, in
            the same row order as ``group_df``. Rows whose ``cat_l2_code`` is
            null get a null result. If a required column is missing, an
            all-null Float64 Series is returned.
        """
        available = group_df.columns
        if self.sector_factor not in available or "cat_l2_code" not in available:
            # An explicit float dtype keeps the fallback compatible with the
            # numeric output of the normal path (an untyped all-None Series
            # would get the Null dtype).
            return pl.Series(
                name=self.factor_id,
                values=[None] * group_df.height,
                dtype=pl.Float64,
            )

        value = pl.col(self.sector_factor)
        sector_mean = value.mean().over("cat_l2_code")

        # A window expression is evaluated row-by-row on the input frame, so
        # the result is aligned with group_df by construction. A join followed
        # by column subtraction depends on the join preserving row order and
        # on no pre-existing "sector_mean" column being present.
        deviation = (
            pl.when(pl.col("cat_l2_code").is_not_null())
            # Rows without a sector code stay null, matching a left join in
            # which null keys do not match.
            .then(value - sector_mean)
            .alias(self.factor_id)
        )
        return group_df.select(deviation).to_series()