fix(factor-engine): 修复多因子计算时数据规格字段合并的 bug
- 修复 FactorEngine.compute() 中相同表的字段未正确合并的问题
- 将简单去重改为字段集合合并,确保所有因子依赖的字段都被获取
- 解决 high_low_ratio 等需要 high/low 字段的因子计算失败问题
This commit is contained in:
@@ -29,12 +29,6 @@ Example:
|
|||||||
>>> stk_limit = get_stk_limit(trade_date='20240101')
|
>>> stk_limit = get_stk_limit(trade_date='20240101')
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from src.data.api_wrappers.api_daily import (
|
|
||||||
get_daily,
|
|
||||||
sync_daily,
|
|
||||||
preview_daily_sync,
|
|
||||||
DailySync,
|
|
||||||
)
|
|
||||||
from src.data.api_wrappers.api_daily_basic import (
|
from src.data.api_wrappers.api_daily_basic import (
|
||||||
get_daily_basic,
|
get_daily_basic,
|
||||||
sync_daily_basic,
|
sync_daily_basic,
|
||||||
|
|||||||
File diff suppressed because one or more lines are too long
@@ -302,14 +302,14 @@ SELECTED_FACTORS = [
|
|||||||
"return_5_rank",
|
"return_5_rank",
|
||||||
"EP_rank",
|
"EP_rank",
|
||||||
"pe_expansion_trend",
|
"pe_expansion_trend",
|
||||||
# "value_price_divergence",
|
"value_price_divergence",
|
||||||
"active_market_cap",
|
"active_market_cap",
|
||||||
# "ebit_rank",
|
"ebit_rank",
|
||||||
]
|
]
|
||||||
|
|
||||||
# 因子定义字典(完整因子库)
|
# 因子定义字典(完整因子库)
|
||||||
FACTOR_DEFINITIONS = {
|
FACTOR_DEFINITIONS = {
|
||||||
# "turnover_rate_volatility": "ts_std(log(turnover_rate), 20)"
|
"turnover_rate_volatility": "ts_std(log(turnover_rate), 20)"
|
||||||
}
|
}
|
||||||
|
|
||||||
# Label 因子定义(不参与训练,用于计算目标)
|
# Label 因子定义(不参与训练,用于计算目标)
|
||||||
|
|||||||
@@ -335,13 +335,28 @@ class FactorEngine:
|
|||||||
for plan in plans:
|
for plan in plans:
|
||||||
all_specs.extend(plan.data_specs)
|
all_specs.extend(plan.data_specs)
|
||||||
|
|
||||||
# 去重数据规格(基于表名)
|
# 合并相同表的字段(而不是简单地去重)
|
||||||
seen_tables: set = set()
|
table_to_columns: Dict[str, Set[str]] = {}
|
||||||
unique_specs: List[DataSpec] = []
|
table_to_spec: Dict[str, DataSpec] = {}
|
||||||
for spec in all_specs:
|
for spec in all_specs:
|
||||||
if spec.table not in seen_tables:
|
if spec.table not in table_to_columns:
|
||||||
seen_tables.add(spec.table)
|
table_to_columns[spec.table] = set()
|
||||||
unique_specs.append(spec)
|
table_to_spec[spec.table] = spec
|
||||||
|
table_to_columns[spec.table].update(spec.columns)
|
||||||
|
|
||||||
|
# 创建合并后的数据规格
|
||||||
|
unique_specs: List[DataSpec] = []
|
||||||
|
for table_name, columns in table_to_columns.items():
|
||||||
|
original_spec = table_to_spec[table_name]
|
||||||
|
unique_specs.append(
|
||||||
|
DataSpec(
|
||||||
|
table=table_name,
|
||||||
|
columns=list(columns),
|
||||||
|
join_type=original_spec.join_type,
|
||||||
|
left_on=original_spec.left_on,
|
||||||
|
right_on=original_spec.right_on,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# 4. 从路由器获取核心宽表
|
# 4. 从路由器获取核心宽表
|
||||||
core_data = self.router.fetch_data(
|
core_data = self.router.fetch_data(
|
||||||
|
|||||||
Reference in New Issue
Block a user