# Source code for easy_rec.python.layers.mmoe

# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import logging

import tensorflow as tf

from easy_rec.python.layers import dnn

# Switch to the TF1 compatibility API when running under TensorFlow 2.x or
# later. The major version is compared numerically because a plain string
# comparison is lexicographic and would misorder e.g. '10.0' against '2.0'.
if int(tf.__version__.split('.')[0]) >= 2:
  tf = tf.compat.v1

class MMOE:
  """Multi-gate mixture-of-experts layer.

  Builds one expert DNN per expert config and, for every task, a softmax gate
  over the experts. Each task's output is the gate-weighted sum of the expert
  outputs.
  """

  def __init__(self,
               expert_dnn_config,
               l2_reg,
               num_task,
               num_expert=None,
               name='mmoe',
               is_training=False):
    """Initializes a `MMOE` Layer.

    Args:
      expert_dnn_config: an instance or a list of
        easy_rec.python.protos.dnn_pb2.DNN. If it is a list of configs, the
        param `num_expert` is ignored and the number of experts is the list
        length; if it is a single config, the number of experts is given by
        `num_expert` and every expert shares that same config object.
      l2_reg: l2 regularizer, passed to the expert DNNs and the gate layers.
      num_task: number of tasks (one gate is built per task).
      num_expert: number of experts; required (and must be > 0) only when
        `expert_dnn_config` is a single config.
      name: scope prefix of the layer, so that its parameters are kept
        separate from other dnns.
      is_training: train phase or not; forwarded to the expert DNNs
        (impacts batchnorm and dropout).

    Raises:
      AssertionError: if `expert_dnn_config` is not a list and `num_expert`
        is None or not greater than zero.
    """
    if isinstance(expert_dnn_config, list):
      self._expert_dnn_configs = expert_dnn_config
      self._num_expert = len(expert_dnn_config)
    else:
      assert num_expert is not None and num_expert > 0, \
          'param `num_expert` must be large than zero, when expert_dnn_config is not a list'
      self._expert_dnn_configs = [expert_dnn_config] * num_expert
      self._num_expert = num_expert
    logging.info('num_expert: %s', self._num_expert)

    self._num_task = num_task
    self._l2_reg = l2_reg
    self._name = name
    self._is_training = is_training

  @property
  def num_expert(self):
    """Number of experts in this layer."""
    return self._num_expert

  def gate(self, unit, deep_fea, name):
    """Builds a softmax gate on top of `deep_fea`.

    Args:
      unit: number of output units of the dense layer (`__call__` passes the
        number of experts).
      deep_fea: input tensor of the gate.
      name: scope prefix; the dense layer is named '<name>/dnn'.

    Returns:
      The dense-layer output normalized with softmax along axis 1.
    """
    fea = tf.layers.dense(
        inputs=deep_fea,
        units=unit,
        kernel_regularizer=self._l2_reg,
        name='%s/dnn' % name)
    fea = tf.nn.softmax(fea, axis=1)
    return fea

  def __call__(self, deep_fea):
    """Applies the experts and the per-task gates to `deep_fea`.

    Args:
      deep_fea: input tensor fed to every expert and every gate.
        NOTE(review): presumably of shape (batch, feature_dim) -- confirm
        against callers.

    Returns:
      A list with `num_task` tensors, one gate-weighted mixture of the
      expert outputs per task.
    """
    expert_fea_list = []
    for expert_id, expert_dnn_config in enumerate(self._expert_dnn_configs):
      expert_dnn = dnn.DNN(
          expert_dnn_config,
          self._l2_reg,
          name='%s/expert_%d' % (self._name, expert_id),
          is_training=self._is_training)
      expert_fea_list.append(expert_dnn(deep_fea))
    # Experts are stacked along a new axis 1.
    # NOTE(review): tf.stack presumably needs all experts to produce the same
    # output shape -- confirm when the expert configs differ.
    experts_fea = tf.stack(expert_fea_list, axis=1)

    task_input_list = []
    for task_id in range(self._num_task):
      gate = self.gate(
          self._num_expert, deep_fea, name='%s/gate_%d' % (self._name, task_id))
      # Add a trailing axis so the per-expert weights broadcast over the
      # experts' feature axis.
      gate = tf.expand_dims(gate, -1)
      task_input = tf.multiply(experts_fea, gate)
      # Sum over the expert axis to get the weighted mixture for this task.
      task_input = tf.reduce_sum(task_input, axis=1)
      task_input_list.append(task_input)
    return task_input_list