# Source code for easy_rec.python.feature_column.feature_group

# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import re

from easy_rec.python.protos.feature_config_pb2 import FeatureGroupConfig
from easy_rec.python.protos.feature_config_pb2 import WideOrDeep


class FeatureGroup(object):
    """Wrapper around a ``FeatureGroupConfig`` proto.

    On construction, range-style feature names in the config (for example
    ``'c[1-3]'``) are expanded in place into individual names
    (``'c1', 'c2', 'c3'``).
    """

    # A range-style feature name: a prefix of letters/underscores followed by
    # an inclusive ``[start-end]`` integer range, and nothing after it. The
    # ``\Z`` anchor matters: a bare ``re.match`` only anchors at the start, so
    # a name like 'c[1-3]_emb' used to be expanded to c1..c3 while silently
    # dropping the '_emb' suffix. Such names are now kept verbatim.
    _RANGE_PATTERN = re.compile(r'([a-zA-Z_]+)\[([0-9]+)-([0-9]+)\]\Z')

    def __init__(self, feature_group_config):
        """Wrap and validate a group config, then expand range feature names.

        Args:
          feature_group_config: a ``FeatureGroupConfig`` proto. It is mutated
            in place: its ``feature_names`` are rewritten by
            ``_auto_expand_feature_name``.

        Raises:
          AssertionError: if the config is not a ``FeatureGroupConfig`` or its
            ``wide_deep`` is neither ``WIDE`` nor ``DEEP`` (note: asserts are
            skipped when Python runs with ``-O``).
        """
        self._config = feature_group_config
        assert isinstance(self._config, FeatureGroupConfig)
        assert self._config.wide_deep in [WideOrDeep.WIDE, WideOrDeep.DEEP]
        self._auto_expand_feature_name()

    @property
    def group_name(self):
        """The ``group_name`` field of the wrapped config."""
        return self._config.group_name

    @property
    def wide_and_deep_dict(self):
        """Map every feature name in this group to the group's ``wide_deep`` value."""
        wide_deep = self._config.wide_deep
        return {name: wide_deep for name in self._config.feature_names}

    @property
    def feature_names(self):
        """The (already range-expanded) ``feature_names`` of the wrapped config."""
        return self._config.feature_names

    def select_columns(self, fc):
        """Look up this group's feature columns in ``fc``.

        Args:
          fc: an object exposing ``wide_columns``, ``deep_columns`` and
            ``sequence_columns`` mappings keyed by feature name (presumably
            plain dicts -- confirm against the caller).

        Returns:
          A 2-tuple. For a WIDE group: ``(wide_columns, [])``. For a DEEP
          group: ``(deep_columns, sequence_columns)``, where names present in
          ``fc.sequence_columns`` go to the second list and all other names are
          looked up in ``fc.deep_columns``. Order follows ``feature_names``.
        """
        if self._config.wide_deep == WideOrDeep.WIDE:
            wide_columns = [fc.wide_columns[x] for x in self._config.feature_names]
            return wide_columns, []

        sequence_columns = []
        deep_columns = []
        for x in self._config.feature_names:
            if x in fc.sequence_columns:
                sequence_columns.append(fc.sequence_columns[x])
            else:
                deep_columns.append(fc.deep_columns[x])
        return deep_columns, sequence_columns

    def _auto_expand_feature_name(self):
        """Rewrite ``feature_names`` in place, expanding ``prefix[a-b]`` entries.

        ``'c[1-3]'`` becomes ``'c1', 'c2', 'c3'`` (both ends inclusive). Names
        that are not exactly of that form are kept unchanged. The relative order
        of entries is preserved. Leading zeros in the range bounds are not kept
        (``'c[01-02]'`` -> ``'c1', 'c2'``), and a reversed range such as
        ``'c[3-1]'`` expands to nothing.
        """
        features = list(self._config.feature_names)
        # Empty the existing container in place rather than rebinding the
        # attribute (proto repeated fields cannot be assigned to).
        while self._config.feature_names:
            self._config.feature_names.pop()

        for feature in features:
            match_obj = self._RANGE_PATTERN.match(feature)
            if match_obj is None:
                self._config.feature_names.append(feature)
                continue
            prefix = match_obj.group(1)
            sid = int(match_obj.group(2))
            eid = int(match_obj.group(3))
            for tid in range(sid, eid + 1):
                self._config.feature_names.append('%s%d' % (prefix, tid))