Source code for easy_rec.python.model.multi_task_model

# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import logging

import tensorflow as tf

from easy_rec.python.builders import loss_builder
from easy_rec.python.model.rank_model import RankModel
from easy_rec.python.protos import tower_pb2

# Switch to the TF1-compatible API surface when running under TensorFlow 2+.
# The major version is compared numerically: a plain string comparison is
# lexicographic and would misorder multi-digit majors ('10.0' < '2.0').
if int(tf.__version__.split('.')[0]) >= 2:
  tf = tf.compat.v1


class MultiTaskModel(RankModel):
  """Rank model base class that manages several task towers.

  Every task tower config is paired with one label name, and the prediction,
  metric, loss and output helpers inherited from RankModel are invoked once
  per tower with a ``_<tower_name>`` suffix.
  """

  def __init__(self,
               model_config,
               feature_configs,
               features,
               labels=None,
               is_training=False):
    """Forward all arguments to RankModel and reset the tower bookkeeping.

    Args:
      model_config: forwarded unchanged to ``RankModel.__init__``.
      feature_configs: forwarded unchanged to ``RankModel.__init__``.
      features: forwarded unchanged to ``RankModel.__init__``.
      labels: forwarded unchanged to ``RankModel.__init__``; defaults to None.
      is_training: forwarded unchanged to ``RankModel.__init__``.
    """
    super(MultiTaskModel, self).__init__(model_config, feature_configs,
                                         features, labels, is_training)
    # Populated later by _init_towers().
    self._task_towers = []
    self._task_num = None
    # Maps tower_name -> name of the label that tower is trained against.
    self._label_name_dict = {}

  def _init_towers(self, task_tower_configs):
    """Init task towers.

    Stores the tower configs and their count and, when ``self._labels`` is
    not None, records which label each tower uses in
    ``self._label_name_dict``. A tower with ``label_name`` set uses that
    label; otherwise the label at the tower's position in
    ``self._labels.keys()`` is used.

    Args:
      task_tower_configs: sequence of ``tower_pb2.TaskTower`` or
        ``tower_pb2.BayesTaskTower`` messages.

    Raises:
      AssertionError: if a config is of another type, or the resolved label
        name is not contained in ``self._labels``.
      IndexError: if a tower has no ``label_name`` and there is no label at
        the tower's position.
    """
    self._task_towers = task_tower_configs
    self._task_num = len(task_tower_configs)

    valid_tower_types = (tower_pb2.TaskTower, tower_pb2.BayesTaskTower)
    # Built lazily: only needed when some tower is matched to a label by order.
    ordered_label_names = None

    for i, task_tower_config in enumerate(task_tower_configs):
      assert isinstance(task_tower_config, valid_tower_types), \
          'task_tower_config must be a instance of tower_pb2.TaskTower or tower_pb2.BayesTaskTower'
      tower_name = task_tower_config.tower_name

      # For label backward compatibility with list
      if self._labels is None:
        continue

      if task_tower_config.HasField('label_name'):
        label_name = task_tower_config.label_name
      else:
        # If label name is not specified, task_tower and label will be
        # matched by order
        if ordered_label_names is None:
          ordered_label_names = list(self._labels.keys())
        if i >= len(ordered_label_names):
          # Same exception type as the bare list index, but self-explanatory.
          raise IndexError(
              'Task Tower [%s] has no label_name and no label exists at '
              'position %d (only %d labels available)' %
              (tower_name, i, len(ordered_label_names)))
        label_name = ordered_label_names[i]
        logging.info('Task Tower [%s] use label [%s]', tower_name, label_name)

      assert label_name in self._labels, 'label [%s] must exists in labels' % label_name
      self._label_name_dict[tower_name] = label_name

  def _add_to_prediction_dict(self, output):
    """Merge the predictions of every task tower into ``_prediction_dict``.

    Args:
      output: mapping indexed by tower name; each entry is handed to
        ``_output_to_prediction_impl`` together with the tower's loss type,
        class count and ``_<tower_name>`` suffix.
    """
    for task_tower_cfg in self._task_towers:
      tower_name = task_tower_cfg.tower_name
      suffix = '_%s' % tower_name
      tower_predictions = self._output_to_prediction_impl(
          output[tower_name],
          loss_type=task_tower_cfg.loss_type,
          num_class=task_tower_cfg.num_class,
          suffix=suffix)
      self._prediction_dict.update(tower_predictions)

  def build_metric_graph(self, eval_config):
    """Build metric graph for multi task model.

    Args:
      eval_config: not read by this implementation.

    Returns:
      dict merging the results of ``_build_metric_impl`` for every metric in
      every tower's ``metrics_set``.
    """
    metric_dict = {}
    for task_tower_cfg in self._task_towers:
      tower_name = task_tower_cfg.tower_name
      suffix = '_%s' % tower_name
      label_name = self._label_name_dict[tower_name]
      for metric in task_tower_cfg.metrics_set:
        metric_dict.update(
            self._build_metric_impl(
                metric,
                loss_type=task_tower_cfg.loss_type,
                label_name=label_name,
                num_class=task_tower_cfg.num_class,
                suffix=suffix))
    return metric_dict

  def build_loss_graph(self):
    """Build loss graph for multi task model.

    For each tower the loss weight is ``tower.weight * self._sample_weight``.
    When the tower config has ``task_space_indicator_label`` set, the weight
    is further scaled per sample by ``in_task_space_weight`` where that label
    is > 0 and by ``out_task_space_weight`` elsewhere. After all towers, the
    result of ``loss_builder.build_kd_loss`` is merged in as well.

    Returns:
      ``self._loss_dict`` after it has been updated with every tower loss and
      the knowledge-distillation losses.
    """
    for task_tower_cfg in self._task_towers:
      tower_name = task_tower_cfg.tower_name
      loss_weight = task_tower_cfg.weight * self._sample_weight

      # hasattr guard: presumably not every tower message type declares this
      # field -- confirm against tower_pb2.
      if hasattr(task_tower_cfg, 'task_space_indicator_label') and \
          task_tower_cfg.HasField('task_space_indicator_label'):
        indicator = self._labels[task_tower_cfg.task_space_indicator_label]
        # tf.cast(..., tf.float32) is the non-deprecated form of tf.to_float.
        in_task_space = tf.cast(indicator > 0, tf.float32)
        loss_weight = loss_weight * (
            task_tower_cfg.in_task_space_weight * in_task_space +
            task_tower_cfg.out_task_space_weight * (1 - in_task_space))

      self._loss_dict.update(
          self._build_loss_impl(
              task_tower_cfg.loss_type,
              label_name=self._label_name_dict[tower_name],
              loss_weight=loss_weight,
              num_class=task_tower_cfg.num_class,
              suffix='_%s' % tower_name))

    kd_loss_dict = loss_builder.build_kd_loss(self.kd, self._prediction_dict,
                                              self._labels)
    self._loss_dict.update(kd_loss_dict)

    return self._loss_dict

  def get_outputs(self):
    """Collect the outputs of every task tower.

    Returns:
      list concatenating, in tower order, what ``_get_outputs_impl`` returns
      for each tower's loss type, class count and ``_<tower_name>`` suffix.
    """
    outputs = []
    for task_tower_cfg in self._task_towers:
      tower_name = task_tower_cfg.tower_name
      outputs.extend(
          self._get_outputs_impl(
              task_tower_cfg.loss_type,
              task_tower_cfg.num_class,
              suffix='_%s' % tower_name))
    return outputs