# easy_rec.python.utils.estimator_utils
# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import json
import logging
import os
import re
import time
from distutils.version import LooseVersion

import numpy as np
import six
import tensorflow as tf
from tensorflow.core.framework.summary_pb2 import Summary
from tensorflow.python.framework import meta_graph
from tensorflow.python.training.summary_io import SummaryWriterCache

from easy_rec.python.utils import shape_utils

if tf.__version__ >= '2.0':
  tf = tf.compat.v1
  SessionRunHook = tf.estimator.SessionRunHook
  CheckpointSaverHook = tf.estimator.CheckpointSaverHook
else:
  SessionRunHook = tf.train.SessionRunHook
  CheckpointSaverHook = tf.train.CheckpointSaverHook


class ExitBarrierHook(SessionRunHook):
  """ExitBarrier to make sure master and workers exit at the same time.

  After training finishes, the master still has to run evaluation and model
  export, so the master exits a little later than the workers.
  """

  def __init__(self, num_worker, is_chief, model_dir):
    self._num_worker = num_worker
    self._is_chief = is_chief
    self._queue = None
    self._signal_que = None
    self._que_size = None
    self._enque = None
    self._deque = None
    self._model_dir = model_dir
    self._send = None
    self._recv = None

  def begin(self):
    """Count the number of workers and masters, and set up the barrier queue."""
    tf.logging.info('number workers(including master) = %d' % self._num_worker)
    with tf.device(
        tf.DeviceSpec(job='ps', task=0, device_type='CPU', device_index=0)):
      self._queue = tf.FIFOQueue(
          capacity=self._num_worker,
          dtypes=[tf.float32],
          shapes=[()],
          name='exit_counter',
          shared_name='exit_counter')
      self._signal_que = tf.FIFOQueue(
          capacity=self._num_worker,
          dtypes=[tf.string],
          shapes=[()],
          name='exit_counter_signal',
          shared_name='exit_counter_signal')
    self._enque = self._queue.enqueue(1.0)
    self._que_size = self._queue.size()
    self._deque = self._queue.dequeue()
    if self._is_chief:
      self._flag_file = os.path.join(self._model_dir,
                                     'atexit_sync_' + str(int(time.time())))
      self._send = self._signal_que.enqueue([self._flag_file])
    else:
      self._recv = self._signal_que.dequeue()
      self._flag_file = None

  def after_create_session(self, session, coord):
    """Clean up the queue after the session is created.

    Sometimes the ps does not exit, so elements enqueued by the previous run
    would remain in the queue.
    """
    if self._is_chief:
      # clear the queue
      que_size = session.run(self._que_size)
      while que_size > 0:
        session.run(self._deque)
        que_size = session.run(self._que_size)
      logging.info('exit counter cleared: %d' % que_size)

  def end(self, session):
    """Ensure that every worker and the master enqueue an element before exiting."""
    session.run(self._enque)
    que_size = session.run(self._que_size)
    while que_size < self._num_worker:
      que_size = session.run(self._que_size)
      time.sleep(5)
      tf.logging.info(
          'waiting for other workers to exit, finished %d, total %d' %
          (que_size, self._num_worker))
    # prepare on_exit synchronization based on self._flag_file
    if self._is_chief:
      for i in range(self._num_worker - 1):
        session.run(self._send)
    else:
      self._flag_file = session.run(self._recv)

    def _check_flag_file(is_chief, flag_file):
      logging.info('_check_flag_file: is_chief = %d flag_file=%s' %
                   (is_chief, flag_file))
      if is_chief:
        with tf.gfile.GFile(flag_file, 'w') as fout:
          fout.write('atexit time: %d' % int(time.time()))
      else:
        while not tf.gfile.Exists(flag_file):
          time.sleep(1)

    from atexit import register
    register(
        _check_flag_file, is_chief=self._is_chief, flag_file=self._flag_file)
    logging.info('ExitBarrier passed')

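# Usage sketch (illustrative only, not executed by this module): the barrier is
# typically appended to the Estimator training hooks of every task in a
# distributed run; `estimator`, `train_input_fn` and `num_workers_total` are
# assumed placeholders.
#
#   barrier = ExitBarrierHook(
#       num_worker=num_workers_total,  # plain workers + chief/master
#       is_chief=is_chief(),
#       model_dir=estimator.model_dir)
#   estimator.train(train_input_fn, hooks=[barrier])
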

class EvaluateExitBarrierHook(SessionRunHook):
  """ExitBarrier that also collects evaluation metrics on the chief.

  After training finishes, the master still has to run evaluation and model
  export, so the master exits a little later than the workers; on the chief,
  the configured metric ops are evaluated before exiting.
  """

  def __init__(self, num_worker, is_chief, model_dir, metric_ops=None):
    self._num_worker = num_worker
    self._is_chief = is_chief
    self._queue = None
    self._signal_que = None
    self._que_size = None
    self._enque = None
    self._deque = None
    self._model_dir = model_dir
    self._send = None
    self._recv = None
    self.metric_ops = metric_ops
    self.eval_result = None

  def begin(self):
    """Count the number of workers and masters, and set up the barrier queue."""
    tf.logging.info('number workers(including master) = %d' % self._num_worker)
    with tf.device(
        tf.DeviceSpec(job='ps', task=0, device_type='CPU', device_index=0)):
      self._queue = tf.FIFOQueue(
          capacity=self._num_worker,
          dtypes=[tf.float32],
          shapes=[()],
          name='exit_counter',
          shared_name='exit_counter')
      self._signal_que = tf.FIFOQueue(
          capacity=self._num_worker,
          dtypes=[tf.string],
          shapes=[()],
          name='exit_counter_signal',
          shared_name='exit_counter_signal')
    self._enque = self._queue.enqueue(1.0)
    self._que_size = self._queue.size()
    self._deque = self._queue.dequeue()
    if self._is_chief:
      self._flag_file = os.path.join(self._model_dir,
                                     'atexit_sync_' + str(int(time.time())))
      self._send = self._signal_que.enqueue([self._flag_file])
    else:
      self._recv = self._signal_que.dequeue()
      self._flag_file = None

  def after_create_session(self, session, coord):
    """Clean up the queue after the session is created.

    Sometimes the ps does not exit, so elements enqueued by the previous run
    would remain in the queue.
    """
    if self._is_chief:
      # clear the queue
      que_size = session.run(self._que_size)
      while que_size > 0:
        session.run(self._deque)
        que_size = session.run(self._que_size)
      logging.info('exit counter cleared: %d' % que_size)

  def end(self, session):
    """Ensure that every worker and the master enqueue an element before exiting."""
    session.run(self._enque)
    que_size = session.run(self._que_size)
    while que_size < self._num_worker:
      que_size = session.run(self._que_size)
      time.sleep(5)
      tf.logging.info(
          'waiting for other workers to exit, finished %d, total %d' %
          (que_size, self._num_worker))
    # prepare on_exit synchronization based on self._flag_file
    if self._is_chief:
      self.eval_result = session.run(self.metric_ops)
      for i in range(self._num_worker - 1):
        session.run(self._send)
    else:
      self._flag_file = session.run(self._recv)

    def _check_flag_file(is_chief, flag_file):
      logging.info('_check_flag_file: is_chief = %d flag_file=%s' %
                   (is_chief, flag_file))
      if is_chief:
        with tf.gfile.GFile(flag_file, 'w') as fout:
          fout.write('atexit time: %d' % int(time.time()))
      else:
        while not tf.gfile.Exists(flag_file):
          time.sleep(1)

    from atexit import register
    register(
        _check_flag_file, is_chief=self._is_chief, flag_file=self._flag_file)
    session.run(self.metric_ops)
    logging.info('ExitBarrier passed')


class ProgressHook(SessionRunHook):
  def __init__(self, num_steps, filename, is_chief):
    """Initializes a `ProgressHook`.

    Args:
      num_steps: total train steps
      filename: progress file name
      is_chief: is chief worker or not
    """
    self._num_steps = num_steps
    self._is_chief = is_chief
    if self._is_chief:
      self._progress_file = tf.gfile.GFile(filename, 'w')
      self._progress_file.write('0.00\n')
      self._progress_interval = 0.01  # 1%
      self._last_progress_cnt = 0

  def before_run(self, run_context):
    if self._is_chief:
      return tf.train.SessionRunArgs([tf.train.get_global_step()])

  def after_run(
      self,
      run_context,  # pylint: disable=unused-argument
      run_values):
    if self._is_chief:
      global_step = run_values.results[0]
      curr_progress = global_step / self._num_steps
      curr_progress_cnt = int(curr_progress / self._progress_interval)
      if curr_progress_cnt >= self._last_progress_cnt + 1:
        self._progress_file.write('%.2f\n' % curr_progress)
        self._progress_file.flush()
        self._last_progress_cnt = curr_progress_cnt
        logging.info('Training Progress: %.2f' % curr_progress)

  def end(self, session):
    if self._is_chief:
      if self._last_progress_cnt < 1 / self._progress_interval:
        self._progress_file.write('1.00\n')
      self._progress_file.close()

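# Usage sketch (illustrative only): ProgressHook appends the training progress
# as a fraction in [0, 1] to `filename`, one line per additional percent of
# `num_steps` completed; `estimator` and `train_input_fn` are assumed
# placeholders.
#
#   progress = ProgressHook(
#       num_steps=10000,
#       filename=os.path.join(estimator.model_dir, 'train_progress.txt'),
#       is_chief=is_chief())
#   estimator.train(train_input_fn, hooks=[progress], max_steps=10000)
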

class CheckpointSaverHook(CheckpointSaverHook):
  """Saves checkpoints every N steps or seconds."""

  def __init__(self,
               checkpoint_dir,
               save_secs=None,
               save_steps=None,
               saver=None,
               checkpoint_basename='model.ckpt',
               scaffold=None,
               listeners=None,
               write_graph=True):
    """Initializes a `CheckpointSaverHook`.

    Args:
      checkpoint_dir: `str`, base directory for the checkpoint files.
      save_secs: `int`, save every N secs.
      save_steps: `int`, save every N steps.
      saver: `Saver` object, used for saving.
      checkpoint_basename: `str`, base name for the checkpoint files.
      scaffold: `Scaffold`, used to get the saver object.
      listeners: List of `CheckpointSaverListener` subclass instances. Used for
        callbacks that run immediately before or after this hook saves the
        checkpoint.
      write_graph: whether to save graph.pbtxt.

    Raises:
      ValueError: One of `save_steps` or `save_secs` should be set.
      ValueError: At most one of `saver` or `scaffold` should be set.
    """
    super(CheckpointSaverHook, self).__init__(
        checkpoint_dir,
        save_secs=save_secs,
        save_steps=save_steps,
        saver=saver,
        checkpoint_basename=checkpoint_basename,
        scaffold=scaffold,
        listeners=listeners)
    self._write_graph = write_graph

  def after_create_session(self, session, coord):
    global_step = session.run(self._global_step_tensor)
    if self._write_graph:
      # We write the graph and saver_def here, at the first opportunity after
      # session creation. We cannot do this in begin, since other hooks may
      # still change the graph and add variables in begin; the graph is
      # finalized after all begin calls.
      tf.train.write_graph(
          tf.get_default_graph().as_graph_def(add_shapes=True),
          self._checkpoint_dir, 'graph.pbtxt')
      saver_def = self._get_saver().saver_def if self._get_saver() else None
      graph = tf.get_default_graph()
      meta_graph_def = meta_graph.create_meta_graph_def(
          graph_def=graph.as_graph_def(add_shapes=True), saver_def=saver_def)
      self._summary_writer.add_graph(graph)
      self._summary_writer.add_meta_graph(meta_graph_def)
    # when tf version >= 1.10.0, the default training strategy saves a ckpt at
    # the first train step
    if LooseVersion(tf.__version__) >= LooseVersion('1.10.0'):
      # The checkpoint saved here is the state at step "global_step".
      self._save(session, global_step)
    self._timer.update_last_triggered_step(global_step)

  def before_run(self, run_context):  # pylint: disable=unused-argument
    return tf.train.SessionRunArgs(self._global_step_tensor)

  def _save(self, session, step):
    """Saves the latest checkpoint, returns should_stop."""
    logging.info('Saving checkpoints for %d into %s.', step, self._save_path)
    for l in self._listeners:  # noqa: E741
      l.before_save(session, step)
    self._get_saver().save(
        session,
        self._save_path,
        global_step=step,
        write_meta_graph=self._write_graph)
    save_dir, save_name = os.path.split(self._save_path)
    self._summary_writer.add_session_log(
        tf.SessionLog(
            status=tf.SessionLog.CHECKPOINT, checkpoint_path=self._save_path),
        step)
    should_stop = False
    for l in self._listeners:  # noqa: E741
      if l.after_save(session, step):
        logging.info(
            'A CheckpointSaverListener requested that training be stopped. '
            'listener: {}'.format(l))
        should_stop = True
    return should_stop

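# Usage sketch (illustrative only): this subclass behaves like
# tf.train.CheckpointSaverHook, but write_graph=False skips writing
# graph.pbtxt, the meta graph summary and the checkpoint .meta files, which
# keeps large graphs out of the model directory; `model_dir` is an assumed
# placeholder.
#
#   saver_hook = CheckpointSaverHook(
#       checkpoint_dir=model_dir, save_steps=1000, write_graph=False)
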

class NumpyCheckpointRestoreHook(SessionRunHook):
  """Restore variables from a numpy checkpoint."""

  def __init__(self, ckpt_path, name2var_map):
    """Initializes a `NumpyCheckpointRestoreHook`.

    Args:
      ckpt_path: numpy checkpoint path to restore from
      name2var_map: var name in numpy ckpt to variable map
    """
    self._ckpt_path = ckpt_path
    self._name2var_map = name2var_map
    self._restore_op = None

  def begin(self):
    ckpt_data = np.load(self._ckpt_path)
    vars_not_inited = {}

    assign_ops = []
    has_shape_unmatch = False
    with tf.variable_scope('', reuse=True):
      for var_name, var in six.iteritems(self._name2var_map):
        var_shape = var.get_shape().as_list()
        if var_name in ckpt_data.keys():
          var_data = ckpt_data[var_name]
          if list(var_data.shape) == var_shape:
            assign_ops.append(var.assign(var_data))
          else:
            logging.error('variable [%s] shape not match %r vs %r' %
                          (var.name.split(':')[0], var_shape,
                           list(var_data.shape)))
            has_shape_unmatch = True
        elif 'Momentum' not in var_name and 'global_step' not in var_name:
          logging.error('variable [%s] not found in ckpt' % var_name)
          vars_not_inited[var_name] = ','.join([str(s) for s in var_shape])
    self._restore_op = tf.group(assign_ops)

    with tf.gfile.GFile(self._ckpt_path[:-4] + '_not_inited.txt', 'w') as f:
      for var_name in sorted(vars_not_inited.keys()):
        f.write('%s:%s\n' % (var_name, vars_not_inited[var_name]))
    assert not has_shape_unmatch, \
        'some variable shapes do not match, restore failed'
    assert len(vars_not_inited.keys()) == 0, \
        'some variables are not initialized from the ckpt, restore failed'

  def after_create_session(self, session, coord):
    assert self._restore_op is not None
    logging.info('running numpy checkpoint restore_op')
    session.run(self._restore_op)

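# Usage sketch (illustrative only): name2var_map maps names in the numpy
# archive to variables in the current graph; building it from
# tf.global_variables() is one possibility. 'pretrain_weights.npz' is an
# assumed placeholder.
#
#   name2var_map = {
#       re.sub(':[0-9]+$', '', v.name): v for v in tf.global_variables()
#   }
#   restore_hook = NumpyCheckpointRestoreHook('pretrain_weights.npz',
#                                             name2var_map)
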

class IncompatibleShapeRestoreHook(SessionRunHook):
  """Restore variables whose shapes are incompatible with the checkpoint."""

  def __init__(self, incompatible_shape_var_map):
    """Initializes an `IncompatibleShapeRestoreHook`.

    Args:
      incompatible_shape_var_map: a map from real variables to temp variables
        with incompatible shapes; the real variable is the one used in the
        model, the temp variable is the one restored from the checkpoint.
    """
    self._incompatible_shape_var_map = incompatible_shape_var_map
    self._restore_op = None

  def begin(self):
    assign_ops = []
    for var, var_tmp in six.iteritems(self._incompatible_shape_var_map):
      assign_ops.append(
          var.assign(
              shape_utils.pad_or_clip_nd(var_tmp,
                                         var.get_shape().as_list())))
      logging.info(
          'Assign variable[%s] from shape%s to shape%s' %
          (var.name, var_tmp.get_shape().as_list(), var.get_shape().as_list()))
    self._restore_op = tf.group(assign_ops)

  def after_create_session(self, session, coord):
    assert self._restore_op is not None
    logging.info('running incompatible shape variable restore_op')
    session.run(self._restore_op)

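# Usage sketch (illustrative only): when a checkpoint variable has a different
# shape from the graph variable (e.g. an embedding table grown from 100 to 200
# rows), restore the checkpoint value into a temp variable and let the hook pad
# or clip it into the real one. `real_embed_var` and `tmp_embed_var` are
# assumed placeholders.
#
#   restore_hook = IncompatibleShapeRestoreHook({real_embed_var: tmp_embed_var})
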

class MultipleCheckpointsRestoreHook(SessionRunHook):
  """Restore variables from multiple checkpoints."""

  SEP = ';'

  def __init__(self, ckpt_paths):
    """Initializes a `MultipleCheckpointsRestoreHook`.

    Args:
      ckpt_paths: multiple checkpoint paths, separated by ';'
    """
    self._ckpt_path_list = ckpt_paths.split(self.SEP)
    self._saver_list = []

  def begin(self):
    global_variables = tf.global_variables()
    var_names = [re.sub(':[0-9]$', '', var.name) for var in global_variables]
    restore_status = {var_name: False for var_name in var_names}
    for ckpt_path in self._ckpt_path_list:
      logging.info('read variable from %s' % ckpt_path)
      ckpt_reader = tf.train.NewCheckpointReader(ckpt_path)
      ckpt_var2shape_map = ckpt_reader.get_variable_to_shape_map()
      # ckpt_var2shape_map.pop(tf.GraphKeys.GLOBAL_STEP, None)
      name2var = {}
      for var in global_variables:
        var_name = re.sub(':[0-9]$', '', var.name)
        if var_name in ckpt_var2shape_map:
          if restore_status[var_name]:
            logging.warning(
                'variable %s found in more than one checkpoint, skipped %s' %
                (var_name, ckpt_path))
            continue
          name2var[var_name] = var
          restore_status[var_name] = True
      saver = tf.train.Saver(name2var)
      self._saver_list.append(saver)
    restore_check = True
    for var_name, stat in six.iteritems(restore_status):
      if not stat:
        logging.error('var %s not found in checkpoints' % var_name)
        restore_check = False
    assert restore_check, 'failed to find all variables in the checkpoints provided'

  def after_create_session(self, session, coord):
    logging.info('running multiple checkpoint restore hook')
    for saver, ckpt_path in zip(self._saver_list, self._ckpt_path_list):
      logging.info('restore checkpoint from %s' % ckpt_path)
      saver.restore(session, ckpt_path)

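# Usage sketch (illustrative only): checkpoint paths are joined with
# MultipleCheckpointsRestoreHook.SEP (';'); every global variable must be found
# in at least one checkpoint, and a variable present in several checkpoints is
# restored from the first one that contains it. The paths are assumed
# placeholders.
#
#   restore_hook = MultipleCheckpointsRestoreHook(
#       'exp1/model.ckpt-1000;exp2/model.ckpt-2000')
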

class OnlineEvaluationHook(SessionRunHook):
  def __init__(self, metric_dict, output_dir):
    self._metric_dict = metric_dict
    self._output_dir = output_dir
    self._summary_writer = SummaryWriterCache.get(self._output_dir)

  def end(self, session):
    metric_tensor_dict = {k: v[0] for k, v in self._metric_dict.items()}
    metric_value_dict = session.run(metric_tensor_dict)
    tf.logging.info('Eval metric: %s' % metric_value_dict)

    global_step_tensor = tf.train.get_or_create_global_step()
    global_step = session.run(global_step_tensor)

    summary = Summary()
    for k, v in metric_value_dict.items():
      summary.value.add(tag=k, simple_value=v)
    self._summary_writer.add_summary(summary, global_step=global_step)
    self._summary_writer.flush()

    eval_result_file = os.path.join(self._output_dir,
                                    'online_eval_result.txt-%s' % global_step)
    logging.info('Saving online eval result to file %s' % eval_result_file)
    with tf.gfile.GFile(eval_result_file, 'w') as ofile:
      result_to_write = {}
      for key in sorted(metric_value_dict):
        # convert numpy float to python float
        result_to_write[key] = metric_value_dict[key].item()
      ofile.write(json.dumps(result_to_write, indent=2))

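# Usage sketch (illustrative only): metric_dict follows the tf.metrics
# convention of (value_tensor, update_op) pairs; at end() the hook logs the
# metric values, writes them as summaries and dumps them to a JSON file under
# output_dir. `labels`, `probs` and `eval_dir` are assumed placeholders.
#
#   metric_dict = {'auc': tf.metrics.auc(labels, probs)}
#   eval_hook = OnlineEvaluationHook(metric_dict, output_dir=eval_dir)
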

def parse_tf_config():
  tf_config_str = os.environ.get('TF_CONFIG', '')
  if 'TF_CONFIG' in os.environ:
    tf_config = json.loads(tf_config_str)
    cluster = tf_config['cluster']
    task = tf_config['task']
    task_type = task['type']
    task_index = task['index']
  else:
    cluster = {}
    task_type = 'master'
    task_index = 0
  return cluster, task_type, task_index

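# Example (illustrative only) of the TF_CONFIG layout parse_tf_config expects;
# with the value below it returns (cluster_dict, 'worker', 0), and when
# TF_CONFIG is unset it falls back to ({}, 'master', 0).
#
#   os.environ['TF_CONFIG'] = json.dumps({
#       'cluster': {
#           'chief': ['host0:2222'],
#           'worker': ['host1:2222', 'host2:2222'],
#           'ps': ['host3:2222']
#       },
#       'task': {'type': 'worker', 'index': 0}
#   })
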

def get_task_index_and_num():
  cluster, task_type, task_index = parse_tf_config()
  if 'worker' not in cluster:
    return 0, 1
  if task_type == 'evaluator':
    return 0, 1
  task_num = len(cluster['worker'])
  if 'chief' in cluster or 'master' in cluster:
    task_num += 1
    if task_type not in ['chief', 'master']:
      task_index += 1
  return task_index, task_num

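# Example (illustrative only): with the TF_CONFIG shown above, the chief gets
# (0, 3) and worker index 0 gets (1, 3), because the chief/master occupies
# slot 0 and plain workers are shifted by one; an evaluator always gets (0, 1).
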

def get_ckpt_version(ckpt_path):
  """Get checkpoint version from ckpt_path.

  Args:
    ckpt_path: such as xx/model.ckpt-2000 or xx/model.ckpt-2000.meta

  Return:
    ckpt_version: such as 2000
  """
  _, ckpt_name = os.path.split(ckpt_path)
  ckpt_name, ext = os.path.splitext(ckpt_name)
  if ext.startswith('.ckpt-'):
    ckpt_name = ext
  toks = ckpt_name.split('-')
  return int(toks[-1])

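# Examples (illustrative only): both the checkpoint prefix and its .meta file
# resolve to the same version.
#
#   get_ckpt_version('oss://bucket/exp/model.ckpt-2000')       # -> 2000
#   get_ckpt_version('oss://bucket/exp/model.ckpt-2000.meta')  # -> 2000
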

def latest_checkpoint(model_dir):
  """Find the latest checkpoint under a directory.

  Args:
    model_dir: model directory

  Return:
    model_path: xx/model.ckpt-2000
  """
  ckpt_metas = tf.gfile.Glob(os.path.join(model_dir, 'model.ckpt-*.meta'))
  if len(ckpt_metas) == 0:
    return None
  if len(ckpt_metas) > 1:
    ckpt_metas.sort(key=lambda x: get_ckpt_version(x))
  ckpt_path = os.path.splitext(ckpt_metas[-1])[0]
  return ckpt_path

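# Example (illustrative only): with model.ckpt-1000.meta and
# model.ckpt-2000.meta under model_dir, latest_checkpoint(model_dir) returns
# '<model_dir>/model.ckpt-2000'; it returns None when no checkpoint meta file
# is found.
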

def master_to_chief():
  if 'TF_CONFIG' in os.environ:
    tf_config = json.loads(os.environ['TF_CONFIG'])
    # change master to chief
    if 'master' in tf_config['cluster']:
      tf_config['cluster']['chief'] = tf_config['cluster']['master']
      del tf_config['cluster']['master']
      if tf_config['task']['type'] == 'master':
        tf_config['task']['type'] = 'chief'
    os.environ['TF_CONFIG'] = json.dumps(tf_config)
    return tf_config
  else:
    return None


def chief_to_master():
  if 'TF_CONFIG' in os.environ:
    tf_config = json.loads(os.environ['TF_CONFIG'])
    # change chief to master
    if 'chief' in tf_config['cluster']:
      tf_config['cluster']['master'] = tf_config['cluster']['chief']
      del tf_config['cluster']['chief']
      if tf_config['task']['type'] == 'chief':
        tf_config['task']['type'] = 'master'
    os.environ['TF_CONFIG'] = json.dumps(tf_config)
    return tf_config
  else:
    return None

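# Example (illustrative only): the two helpers rewrite TF_CONFIG in place so
# the same cluster description can be consumed by APIs expecting a 'chief'
# role or the legacy 'master' role.
#
#   os.environ['TF_CONFIG'] = json.dumps({
#       'cluster': {'master': ['host0:2222'], 'worker': ['host1:2222']},
#       'task': {'type': 'master', 'index': 0}
#   })
#   master_to_chief()  # cluster now uses 'chief', task type becomes 'chief'
#   chief_to_master()  # and back to 'master' again
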

def is_chief():
  if 'TF_CONFIG' in os.environ:
    tf_config = json.loads(os.environ['TF_CONFIG'])
    if 'task' in tf_config:
      return tf_config['task']['type'] in ['chief', 'master']
  return True


def is_master():
  if 'TF_CONFIG' in os.environ:
    tf_config = json.loads(os.environ['TF_CONFIG'])
    if 'task' in tf_config:
      return tf_config['task']['type'] == 'master'
  return True


def is_evaluator():
  if 'TF_CONFIG' in os.environ:
    tf_config = json.loads(os.environ['TF_CONFIG'])
    if 'task' in tf_config:
      return tf_config['task']['type'] == 'evaluator'
  return False

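# Note (illustrative only): is_chief() treats both the 'chief' and the legacy
# 'master' task types as the chief, while is_master() only matches 'master';
# both default to True when TF_CONFIG is unset (single-process runs), and
# is_evaluator() defaults to False.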