1、策略更新
2、新增qmt
This commit is contained in:
90
main/factor/industry_factors.py
Normal file
90
main/factor/industry_factors.py
Normal file
@@ -0,0 +1,90 @@
|
||||
"""
|
||||
行业/横截面因子模块
|
||||
包含基于日期截面的行业/横截面因子实现
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import polars as pl
|
||||
from main.factor.operator_framework import DateWiseFactor
|
||||
|
||||
|
||||
class IndustryMomentumFactor(DateWiseFactor):
|
||||
"""行业动量因子"""
|
||||
|
||||
def __init__(self, factor_name: str):
|
||||
super().__init__(
|
||||
name=f"industry_momentum",
|
||||
parameters={"factor_name": factor_name},
|
||||
required_factor_ids=[factor_name, "cat_l2_code"]
|
||||
)
|
||||
self.factor_name = factor_name
|
||||
|
||||
def calc_factor(self, group_df: pl.DataFrame) -> pl.Series:
|
||||
# 计算行业动量基准
|
||||
# 这里需要先计算每个行业的平均值,然后与个股比较
|
||||
if self.factor_name in group_df.columns and "cat_l2_code" in group_df.columns:
|
||||
# 按行业计算平均值
|
||||
industry_means = group_df.group_by("cat_l2_code").agg([
|
||||
pl.col(self.factor_name).mean().alias("industry_mean")
|
||||
])
|
||||
|
||||
# 将行业均值合并回原数据
|
||||
result_df = group_df.join(industry_means, on="cat_l2_code", how="left")
|
||||
|
||||
# 计算与行业均值的偏差
|
||||
deviation = result_df[self.factor_name] - result_df["industry_mean"]
|
||||
return deviation.alias(self.factor_id)
|
||||
else:
|
||||
# 如果缺少必要列,返回全NaN
|
||||
return pl.Series([None] * len(group_df)).alias(self.factor_id)
|
||||
|
||||
|
||||
class MarketBreadthFactor(DateWiseFactor):
|
||||
"""市场宽度因子"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
name="market_breadth",
|
||||
parameters={},
|
||||
required_factor_ids=["pct_chg"]
|
||||
)
|
||||
|
||||
def calc_factor(self, group_df: pl.DataFrame) -> pl.Series:
|
||||
# 计算市场宽度:上涨股票数 / 总股票数
|
||||
pct_chg = group_df["pct_chg"]
|
||||
positive_count = (pct_chg > 0).sum()
|
||||
total_count = len(group_df)
|
||||
|
||||
# 避免除零
|
||||
breadth = positive_count / (total_count + 1e-8)
|
||||
return pl.Series([breadth] * len(group_df)).alias(self.factor_id)
|
||||
|
||||
|
||||
class SectorRotationFactor(DateWiseFactor):
|
||||
"""板块轮动因子"""
|
||||
|
||||
def __init__(self, sector_factor: str):
|
||||
super().__init__(
|
||||
name=f"sector_rotation_{sector_factor}",
|
||||
parameters={"sector_factor": sector_factor},
|
||||
required_factor_ids=[sector_factor, "cat_l2_code"]
|
||||
)
|
||||
self.sector_factor = sector_factor
|
||||
|
||||
def calc_factor(self, group_df: pl.DataFrame) -> pl.Series:
|
||||
# 计算板块轮动因子
|
||||
if self.sector_factor in group_df.columns and "cat_l2_code" in group_df.columns:
|
||||
# 计算每个板块的因子均值
|
||||
sector_means = group_df.group_by("cat_l2_code").agg([
|
||||
pl.col(self.sector_factor).mean().alias("sector_mean")
|
||||
])
|
||||
|
||||
# 将板块均值合并回原数据
|
||||
result_df = group_df.join(sector_means, on="cat_l2_code", how="left")
|
||||
|
||||
# 计算个股与板块均值的偏差
|
||||
deviation = result_df[self.sector_factor] - result_df["sector_mean"]
|
||||
return deviation.alias(self.factor_id)
|
||||
else:
|
||||
# 如果缺少必要列,返回全NaN
|
||||
return pl.Series([None] * len(group_df)).alias(self.factor_id)
|
||||
Reference in New Issue
Block a user