fix(factor-engine): 修复多因子计算时数据规格字段合并的 bug
- 修复 FactorEngine.compute() 中相同表的字段未正确合并的问题 - 将简单去重改为字段集合合并,确保所有因子依赖的字段都被获取 - 解决 high_low_ratio 等需要 high/low 字段的因子计算失败问题
This commit is contained in:
@@ -29,12 +29,6 @@ Example:
|
||||
>>> stk_limit = get_stk_limit(trade_date='20240101')
|
||||
"""
|
||||
|
||||
from src.data.api_wrappers.api_daily import (
|
||||
get_daily,
|
||||
sync_daily,
|
||||
preview_daily_sync,
|
||||
DailySync,
|
||||
)
|
||||
from src.data.api_wrappers.api_daily_basic import (
|
||||
get_daily_basic,
|
||||
sync_daily_basic,
|
||||
|
||||
File diff suppressed because one or more lines are too long
@@ -302,14 +302,14 @@ SELECTED_FACTORS = [
|
||||
"return_5_rank",
|
||||
"EP_rank",
|
||||
"pe_expansion_trend",
|
||||
# "value_price_divergence",
|
||||
"value_price_divergence",
|
||||
"active_market_cap",
|
||||
# "ebit_rank",
|
||||
"ebit_rank",
|
||||
]
|
||||
|
||||
# 因子定义字典(完整因子库)
|
||||
FACTOR_DEFINITIONS = {
|
||||
# "turnover_rate_volatility": "ts_std(log(turnover_rate), 20)"
|
||||
"turnover_rate_volatility": "ts_std(log(turnover_rate), 20)"
|
||||
}
|
||||
|
||||
# Label 因子定义(不参与训练,用于计算目标)
|
||||
|
||||
@@ -335,13 +335,28 @@ class FactorEngine:
|
||||
for plan in plans:
|
||||
all_specs.extend(plan.data_specs)
|
||||
|
||||
# 去重数据规格(基于表名)
|
||||
seen_tables: set = set()
|
||||
unique_specs: List[DataSpec] = []
|
||||
# 合并相同表的字段(而不是简单地去重)
|
||||
table_to_columns: Dict[str, Set[str]] = {}
|
||||
table_to_spec: Dict[str, DataSpec] = {}
|
||||
for spec in all_specs:
|
||||
if spec.table not in seen_tables:
|
||||
seen_tables.add(spec.table)
|
||||
unique_specs.append(spec)
|
||||
if spec.table not in table_to_columns:
|
||||
table_to_columns[spec.table] = set()
|
||||
table_to_spec[spec.table] = spec
|
||||
table_to_columns[spec.table].update(spec.columns)
|
||||
|
||||
# 创建合并后的数据规格
|
||||
unique_specs: List[DataSpec] = []
|
||||
for table_name, columns in table_to_columns.items():
|
||||
original_spec = table_to_spec[table_name]
|
||||
unique_specs.append(
|
||||
DataSpec(
|
||||
table=table_name,
|
||||
columns=list(columns),
|
||||
join_type=original_spec.join_type,
|
||||
left_on=original_spec.left_on,
|
||||
right_on=original_spec.right_on,
|
||||
)
|
||||
)
|
||||
|
||||
# 4. 从路由器获取核心宽表
|
||||
core_data = self.router.fetch_data(
|
||||
|
||||
Reference in New Issue
Block a user