Source code for easy_rec.python.model.simple_multi_task

# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import tensorflow as tf

from easy_rec.python.layers import dnn
from easy_rec.python.model.multi_task_model import MultiTaskModel

from easy_rec.python.protos.simple_multi_task_pb2 import SimpleMultiTask as SimpleMultiTaskConfig  # NOQA

if tf.__version__ >= '2.0':
  tf = tf.compat.v1


[docs]class SimpleMultiTask(MultiTaskModel):
[docs] def __init__(self, model_config, feature_configs, features, labels=None, is_training=False): super(SimpleMultiTask, self).__init__(model_config, feature_configs, features, labels, is_training) assert self._model_config.WhichOneof('model') == 'simple_multi_task', \ 'invalid model config: %s' % self._model_config.WhichOneof('model') self._model_config = self._model_config.simple_multi_task assert isinstance(self._model_config, SimpleMultiTaskConfig) if self.has_backbone: self._features = self.backbone else: self._features, _ = self._input_layer(self._feature_dict, 'all') self._init_towers(self._model_config.task_towers)
[docs] def build_predict_graph(self): tower_outputs = {} for i, task_tower_cfg in enumerate(self._task_towers): tower_name = task_tower_cfg.tower_name task_dnn = dnn.DNN( task_tower_cfg.dnn, self._l2_reg, name=tower_name, is_training=self._is_training) task_fea = task_dnn(self._features) task_output = tf.layers.dense( inputs=task_fea, units=task_tower_cfg.num_class, kernel_regularizer=self._l2_reg, name='dnn_output_%d' % i) tower_outputs[tower_name] = task_output self._add_to_prediction_dict(tower_outputs) return self._prediction_dict