# Source code for easy_rec.python.model.multi_task_model

# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import logging
from collections import OrderedDict

import tensorflow as tf

from easy_rec.python.builders import loss_builder
from easy_rec.python.layers.dnn import DNN
from easy_rec.python.model.rank_model import RankModel
from easy_rec.python.protos import tower_pb2
from easy_rec.python.protos.easy_rec_model_pb2 import EasyRecModel
from easy_rec.python.protos.loss_pb2 import LossType

# Use the TF1-compatible API surface when running under TensorFlow 2.x.
# The major version is compared numerically: a plain string comparison
# (`tf.__version__ >= '2.0'`) would misorder versions such as '10.0'.
if int(tf.__version__.split('.')[0]) >= 2:
  tf = tf.compat.v1


class MultiTaskModel(RankModel):
  """Base class for multi-task ranking models.

  Every task is described by a task-tower config (``tower_pb2.TaskTower`` or
  ``tower_pb2.BayesTaskTower``). Predictions, losses, metrics and exported
  outputs are produced per tower and suffixed with ``_<tower_name>``.
  """

  def __init__(self,
               model_config,
               feature_configs,
               features,
               labels=None,
               is_training=False):
    super(MultiTaskModel, self).__init__(model_config, feature_configs,
                                         features, labels, is_training)
    # Task tower configs; populated by _init_towers().
    self._task_towers = []
    # Number of task towers; None until _init_towers() runs.
    self._task_num = None
    # tower_name -> name of the label used by that tower.
    self._label_name_dict = {}

  def build_predict_graph(self):
    """Build one output head per task tower on top of the backbone.

    Only supported when a backbone network exists; models without a backbone
    must override this method.

    If the backbone returns a list/tuple, its i-th element feeds the i-th
    task tower; otherwise the single backbone output is shared by all towers.
    Each tower optionally applies its own `dnn`, then (bayes network) an
    optional `relation_dnn` over its feature concatenated with the relation
    features of the towers named in `relation_tower_names`, and finally a
    dense layer with `num_class` units.

    Returns:
      The prediction dict, updated with every tower's predictions.

    Raises:
      NotImplementedError: if no backbone network exists.
      ValueError: if the backbone returns a list/tuple whose length differs
        from the number of task towers.
    """
    if not self.has_backbone:
      raise NotImplementedError(
          'method `build_predict_graph` must be implemented when backbone network do not exists'
      )
    model = self._model_config.WhichOneof('model')
    assert model == 'model_params', '`model_params` must be configured'
    config = self._model_config.model_params
    self._init_towers(config.task_towers)

    backbone = self.backbone
    if isinstance(backbone, (list, tuple)):
      if len(backbone) != len(config.task_towers):
        raise ValueError(
            'The number of backbone outputs and task towers must be equal')
      task_input_list = backbone
    else:
      task_input_list = [backbone] * len(config.task_towers)

    tower_features = {}
    for task_tower_cfg, task_input in zip(config.task_towers, task_input_list):
      tower_name = task_tower_cfg.tower_name
      if task_tower_cfg.HasField('dnn'):
        tower_dnn = DNN(
            task_tower_cfg.dnn,
            self._l2_reg,
            name=tower_name,
            is_training=self._is_training)
        tower_features[tower_name] = tower_dnn(task_input)
      else:
        tower_features[tower_name] = task_input

    tower_outputs = {}
    relation_features = {}
    # bayes network: a tower may consume the relation features of towers
    # listed in `relation_tower_names`; those towers must come earlier in
    # `task_towers` because features are produced in config order.
    for task_tower_cfg in config.task_towers:
      tower_name = task_tower_cfg.tower_name
      if task_tower_cfg.HasField('relation_dnn'):
        relation_dnn = DNN(
            task_tower_cfg.relation_dnn,
            self._l2_reg,
            name=tower_name + '/relation_dnn',
            is_training=self._is_training)
        tower_inputs = [tower_features[tower_name]]
        for relation_tower_name in task_tower_cfg.relation_tower_names:
          tower_inputs.append(relation_features[relation_tower_name])
        relation_input = tf.concat(
            tower_inputs, axis=-1, name=tower_name + '/relation_input')
        relation_fea = relation_dnn(relation_input)
      else:
        relation_fea = tower_features[tower_name]
      # Record the feature of every tower (not only those with a
      # relation_dnn) so later towers can depend on plain towers as well.
      relation_features[tower_name] = relation_fea

      output_logits = tf.layers.dense(
          relation_fea,
          task_tower_cfg.num_class,
          kernel_regularizer=self._l2_reg,
          name=tower_name + '/output')
      tower_outputs[tower_name] = output_logits

    self._add_to_prediction_dict(tower_outputs)
    return self._prediction_dict

  @staticmethod
  def _tower_loss_types(task_tower_cfg):
    """Return the tower's loss types: one per `losses` entry, else [loss_type]."""
    if len(task_tower_cfg.losses) == 0:
      return [task_tower_cfg.loss_type]
    return [loss.loss_type for loss in task_tower_cfg.losses]

  def _init_towers(self, task_tower_configs):
    """Init task towers.

    Stores the tower configs and resolves the label used by each tower.

    Args:
      task_tower_configs: sequence of tower_pb2.TaskTower or
        tower_pb2.BayesTaskTower messages.
    """
    self._task_towers = task_tower_configs
    self._task_num = len(task_tower_configs)
    label_names = None
    if self._labels is not None:
      label_names = list(self._labels.keys())
    for i, task_tower_config in enumerate(task_tower_configs):
      assert isinstance(task_tower_config,
                        (tower_pb2.TaskTower, tower_pb2.BayesTaskTower)), \
          'task_tower_config must be a instance of tower_pb2.TaskTower or tower_pb2.BayesTaskTower'
      tower_name = task_tower_config.tower_name
      # For label backward compatibility with list
      if label_names is not None:
        if task_tower_config.HasField('label_name'):
          label_name = task_tower_config.label_name
        else:
          # If label name is not specified, task_tower and label will be
          # matched by order.
          label_name = label_names[i]
        logging.info('Task Tower [%s] use label [%s]', tower_name, label_name)
        assert label_name in self._labels, \
            'label [%s] must exists in labels' % label_name
        self._label_name_dict[tower_name] = label_name

  def _add_to_prediction_dict(self, output):
    """Convert each tower's logits into predictions, once per loss type.

    Args:
      output: dict tower_name -> output tensor of that tower.
    """
    for task_tower_cfg in self._task_towers:
      tower_name = task_tower_cfg.tower_name
      for loss_type in self._tower_loss_types(task_tower_cfg):
        self._prediction_dict.update(
            self._output_to_prediction_impl(
                output[tower_name],
                loss_type=loss_type,
                num_class=task_tower_cfg.num_class,
                suffix='_%s' % tower_name))

  def build_metric_graph(self, eval_config):
    """Build metric graph for multi task model.

    Every metric in a tower's `metrics_set` is built against that tower's
    label and the set of its loss types.

    Returns:
      The metric dict, updated with every tower's metrics.
    """
    for task_tower_cfg in self._task_towers:
      tower_name = task_tower_cfg.tower_name
      # Loop-invariant across the tower's metrics.
      loss_types = set(self._tower_loss_types(task_tower_cfg))
      for metric in task_tower_cfg.metrics_set:
        self._metric_dict.update(
            self._build_metric_impl(
                metric,
                loss_type=loss_types,
                label_name=self._label_name_dict[tower_name],
                num_class=task_tower_cfg.num_class,
                suffix='_%s' % tower_name))
    return self._metric_dict

  def build_loss_weight(self):
    """Build the per-loss weights of every tower.

    Returns:
      OrderedDict tower_name -> per-loss weights, aligned with the tower's
      `losses` (a single weight of 1.0 when `losses` is empty). By default
      these are the configured `loss.weight` values; under the `Random`
      loss weight strategy they are instead consecutive slices of a softmax
      over normally distributed random values, one entry per loss overall.
    """
    loss_weights = OrderedDict()
    num_loss = 0
    for task_tower_cfg in self._task_towers:
      tower_name = task_tower_cfg.tower_name
      losses = task_tower_cfg.losses
      n = len(losses)
      if n > 0:
        loss_weights[tower_name] = [loss.weight for loss in losses]
        num_loss += n
      else:
        loss_weights[tower_name] = [1.0]
        num_loss += 1

    strategy = self._base_model_config.loss_weight_strategy
    if strategy == self._base_model_config.Random:
      weights = tf.random_normal([num_loss])
      weights = tf.nn.softmax(weights)
      i = 0
      for k, v in loss_weights.items():
        n = len(v)
        loss_weights[k] = weights[i:i + n]
        i += n
    return loss_weights

  def get_learnt_loss(self, loss_type, name, value):
    """Weight a loss by a learnt uncertainty variable.

    Only the `Uncertainty` loss weight strategy is supported. With the
    learnt scalar `s` the result is `exp(-s) * value + 0.5 * s`, and the
    first term is additionally halved for L2 / sigmoid-L2 losses.

    Raises:
      ValueError: if the loss weight strategy is not `Uncertainty`.
    """
    strategy = self._base_model_config.loss_weight_strategy
    if strategy == self._base_model_config.Uncertainty:
      uncertainty = tf.Variable(
          0, name='%s_loss_weight' % name, dtype=tf.float32)
      tf.summary.scalar('loss/%s_uncertainty' % name, uncertainty)
      if loss_type in {LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS}:
        return 0.5 * tf.exp(-uncertainty) * value + 0.5 * uncertainty
      else:
        return tf.exp(-uncertainty) * value + 0.5 * uncertainty
    else:
      strategy_name = EasyRecModel.LossWeightStrategy.Name(strategy)
      raise ValueError('Unsupported loss weight strategy: ' + strategy_name)

  def _tower_loss_weight(self, task_tower_cfg):
    """Compute the loss weight passed to the loss builder for one tower.

    Starts from the tower's `weight`, multiplies by the sample weight when
    `use_sample_weight` is set and, when `task_space_indicator_label` is
    configured, scales samples whose indicator label is > 0 by
    `in_task_space_weight` and all other samples by `out_task_space_weight`.
    """
    loss_weight = task_tower_cfg.weight
    if task_tower_cfg.use_sample_weight:
      loss_weight *= self._sample_weight

    if hasattr(task_tower_cfg, 'task_space_indicator_label') and \
        task_tower_cfg.HasField('task_space_indicator_label'):
      in_task_space = tf.to_float(
          self._labels[task_tower_cfg.task_space_indicator_label] > 0)
      loss_weight = loss_weight * (
          task_tower_cfg.in_task_space_weight * in_task_space +
          task_tower_cfg.out_task_space_weight * (1 - in_task_space))
    return loss_weight

  def build_loss_graph(self):
    """Build loss graph for multi task model.

    Each tower's losses are built with the tower's loss weight (see
    `_tower_loss_weight`) and then either replaced by an uncertainty-learnt
    loss (`learn_loss_weight`) or scaled by that loss's entry from
    `build_loss_weight`. Knowledge-distillation losses are added last.

    Returns:
      The loss dict.
    """
    task_loss_weights = self.build_loss_weight()
    for task_tower_cfg in self._task_towers:
      tower_name = task_tower_cfg.tower_name
      loss_weight = self._tower_loss_weight(task_tower_cfg)
      task_loss_weight = task_loss_weights[tower_name]

      loss_dict = {}
      losses = task_tower_cfg.losses
      if len(losses) == 0:
        loss_dict = self._build_loss_impl(
            task_tower_cfg.loss_type,
            label_name=self._label_name_dict[tower_name],
            loss_weight=loss_weight,
            num_class=task_tower_cfg.num_class,
            suffix='_%s' % tower_name)
        for loss_name in loss_dict.keys():
          loss_dict[loss_name] = loss_dict[loss_name] * task_loss_weight[0]
      else:
        for loss_idx, loss in enumerate(losses):
          loss_param = loss.WhichOneof('loss_param')
          if loss_param is not None:
            loss_param = getattr(loss, loss_param)
          loss_ops = self._build_loss_impl(
              loss.loss_type,
              label_name=self._label_name_dict[tower_name],
              loss_weight=loss_weight,
              num_class=task_tower_cfg.num_class,
              suffix='_%s' % tower_name,
              loss_name=loss.loss_name,
              loss_param=loss_param)
          for loss_name, loss_value in loss_ops.items():
            if loss.learn_loss_weight:
              loss_dict[loss_name] = self.get_learnt_loss(
                  loss.loss_type, loss_name, loss_value)
            else:
              # Use the weight of *this* loss within the tower (indexed by
              # its position in `losses`), not the op's position in loss_ops.
              loss_dict[loss_name] = loss_value * task_loss_weight[loss_idx]

      self._loss_dict.update(loss_dict)

    kd_loss_dict = loss_builder.build_kd_loss(self.kd, self._prediction_dict,
                                              self._labels)
    self._loss_dict.update(kd_loss_dict)

    return self._loss_dict

  def get_outputs(self):
    """Return the de-duplicated names of all exported outputs.

    Names are ordered by tower, then by first appearance, so the result is
    stable across runs.
    """
    outputs = []
    for task_tower_cfg in self._task_towers:
      suffix = '_%s' % task_tower_cfg.tower_name
      for loss_type in self._tower_loss_types(task_tower_cfg):
        outputs.extend(
            self._get_outputs_impl(
                loss_type, task_tower_cfg.num_class, suffix=suffix))
    return list(OrderedDict.fromkeys(outputs))