Source code for easy_rec.python.model.autoint

# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import logging

import tensorflow as tf

from easy_rec.python.layers import multihead_attention
from easy_rec.python.model.rank_model import RankModel

from easy_rec.python.protos.autoint_pb2 import AutoInt as AutoIntConfig  # NOQA

if tf.__version__ >= '2.0':
  tf = tf.compat.v1


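# AutoInt learns high-order feature interactions by applying multi-head
# self-attention over the per-feature embeddings; see Song et al., 'AutoInt:
# Automatic Feature Interaction Learning via Self-Attentive Neural Networks'
# (CIKM 2019).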
class AutoInt(RankModel):

  def __init__(self,
               model_config,
               feature_configs,
               features,
               labels=None,
               is_training=False):
    super(AutoInt, self).__init__(model_config, feature_configs, features,
                                  labels, is_training)
    assert self._model_config.WhichOneof('model') == 'autoint', \
        'invalid model config: %s' % self._model_config.WhichOneof('model')

    self._features, _ = self._input_layer(self._feature_dict, 'all')
    # Count the input feature slots; each sequence feature contributes one
    # slot per history sequence plus one per attention key.
    self._feature_num = len(self._model_config.feature_groups[0].feature_names)
    self._seq_key_num = 0
    if len(self._model_config.feature_groups[0].sequence_features) > 0:
      for seq_fea in self._model_config.feature_groups[0].sequence_features:
        for seq_att in seq_fea.seq_att_map:
          self._feature_num += len(seq_att.hist_seq)
          self._seq_key_num += len(seq_att.key)
    self._model_config = self._model_config.autoint
    assert isinstance(self._model_config, AutoIntConfig)

    # All features must share a single embedding dimension, which becomes
    # the attention model dimension d_model.
    fea_emb_dim_list = []
    for feature_config in feature_configs:
      fea_emb_dim_list.append(feature_config.embedding_dim)
    assert len(set(fea_emb_dim_list)) == 1 and \
        len(fea_emb_dim_list) == self._feature_num, \
        'AutoInt requires all features to have the same embedding dimension.'

    self._d_model = fea_emb_dim_list[0]
    self._head_num = self._model_config.multi_head_num
    self._head_size = self._model_config.multi_head_size
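
  # A hypothetical sketch of a matching feature group in the model config
  # (feature names are placeholders; every feature in the group must use the
  # same embedding_dim so that d_model is well defined, and the group must be
  # named 'all' to match the input layer lookup above):
  #
  #   feature_groups {
  #     group_name: 'all'
  #     feature_names: 'user_id'
  #     feature_names: 'item_id'
  #     feature_names: 'category'
  #     wide_deep: DEEP
  #   }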

  def build_predict_graph(self):
    logging.info('feature_num: {0}'.format(self._feature_num))
    # [batch_size, feature_num + seq_key_num, d_model]
    attention_fea = tf.reshape(
        self._features,
        shape=[-1, self._feature_num + self._seq_key_num, self._d_model])

    # Stack interacting layers: multi-head self-attention over the feature
    # axis with residual connections, preserving the input shape.
    for i in range(self._model_config.interacting_layer_num):
      attention_layer = multihead_attention.MultiHeadAttention(
          head_num=self._head_num,
          head_size=self._head_size,
          l2_reg=self._l2_reg,
          use_res=True,
          name='multi_head_self_attention_layer_%d' % i)
      attention_fea = attention_layer(attention_fea)

    # Flatten the attended features and project to the output logits.
    attention_fea = tf.reshape(
        attention_fea,
        shape=[-1, attention_fea.shape[1] * attention_fea.shape[2]])
    final = tf.layers.dense(attention_fea, self._num_class, name='output')

    self._add_to_prediction_dict(final)
    return self._prediction_dict
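
For reference, below is a minimal NumPy sketch of what one interacting layer computes, assuming MultiHeadAttention implements standard scaled dot-product self-attention over the feature axis with a residual connection (the names here are illustrative; head_num * head_size == d_model is assumed so shapes line up without an extra projection, and whether an activation follows the residual depends on the layer implementation):

  # Minimal sketch of one interacting layer (assumptions as stated above).
  import numpy as np

  def interacting_layer(x, w_q, w_k, w_v, head_num):
    """x: [batch, feature_num, d_model]; w_q/w_k/w_v: [d_model, d_model]."""
    batch, feat_num, d_model = x.shape
    head_size = d_model // head_num

    def split_heads(t):
      # [batch, feature_num, d_model] -> [batch, head_num, feature_num, head_size]
      return t.reshape(batch, feat_num, head_num, head_size).transpose(0, 2, 1, 3)

    q, k, v = split_heads(x @ w_q), split_heads(x @ w_k), split_heads(x @ w_v)
    # Attention weights between every pair of features, per head.
    scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(head_size)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)  # softmax over features
    out = weights @ v  # [batch, head_num, feature_num, head_size]
    out = out.transpose(0, 2, 1, 3).reshape(batch, feat_num, d_model)
    return np.maximum(out + x, 0.0)  # residual connection + ReLU

  rng = np.random.default_rng(0)
  x = rng.standard_normal((2, 5, 8))        # 2 samples, 5 features, d_model=8
  w = [rng.standard_normal((8, 8)) for _ in range(3)]
  y = interacting_layer(x, *w, head_num=2)  # shape preserved: (2, 5, 8)

Because the output shape equals the input shape, such layers can be stacked, which is exactly what build_predict_graph does interacting_layer_num times before flattening and projecting to the logits.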