Source code for easy_rec.python.layers.embed_input_layer
# -*- encoding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import tensorflow as tf
from easy_rec.python.feature_column.feature_group import FeatureGroup
[docs]class EmbedInputLayer(object):
[docs] def __init__(self, feature_groups_config, dump_dir=None):
self._feature_groups = {
x.group_name: FeatureGroup(x) for x in feature_groups_config
}
self._dump_dir = dump_dir
def __call__(self, features, group_name):
assert group_name in self._feature_groups, 'invalid group_name[%s], list: %s' % \
','.join([x for x in self._feature_groups])
feature_group = self._feature_groups[group_name]
group_features = []
for feature_name in feature_group.feature_names:
tmp_fea = features[feature_name]
group_features.append(tmp_fea)
return tf.concat(group_features, axis=1), group_features