# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import abc
import logging
import math
import os
import time
import numpy as np
import six
import tensorflow as tf
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.python.platform import gfile
from tensorflow.python.saved_model import constants
from tensorflow.python.saved_model import signature_constants
from easy_rec.python.utils import pai_util
from easy_rec.python.utils.config_util import get_configs_from_pipeline_file
from easy_rec.python.utils.input_utils import get_type_defaults
from easy_rec.python.utils.load_class import get_register_class_meta
# TF 2.x ships a v1 compatibility layer; alias it so the rest of this
# module can keep using the 1.x style (Session / placeholder) API.
if tf.__version__ >= '2.0':
  tf = tf.compat.v1

# Input name used by models exported with a single placeholder, where all
# feature columns are concatenated into one string tensor.
SINGLE_PLACEHOLDER_FEATURE_KEY = 'features'

# Registry of predictor classes; populated by the register metaclass below
# so predictors can be looked up by class name.
_PREDICTOR_CLASS_MAP = {}
_register_abc_meta = get_register_class_meta(
    _PREDICTOR_CLASS_MAP, have_abstract_class=True)
class PredictorInterface(six.with_metaclass(_register_abc_meta, object)):
  """Abstract interface that all predictors implement.

  Subclasses are automatically registered into _PREDICTOR_CLASS_MAP by the
  register metaclass, so they can be looked up by class name.
  """

  # interface version, bump on incompatible changes
  version = 1

  def __init__(self, model_path, model_config=None):
    """Init tensorflow session and load tf model.

    Args:
      model_path: init model from this directory.
      model_config: config string for model to init, in json format.
    """
    pass

  @abc.abstractmethod
  def predict(self, input_data, batch_size):
    """Using session run predict a number of samples using batch_size.

    Args:
      input_data: a list of numpy array, each array is a sample to be
        predicted.
      batch_size: batch_size passed by the caller, you can also ignore this
        param and use a fixed number if you do not want to adjust batch_size
        in runtime.

    Returns:
      result: a list of dict, each dict is the prediction result of one
        sample, e.g. {"output1": value1, "output2": value2}; the value type
        can be python int, str, float, or numpy array.
    """
    pass

  def get_output_type(self):
    """Get output types of prediction.

    In this function user should return a type dict, which indicates
    which type of data should the output of predictor be converted to.

    * type json, data will be serialized to json str
    * type image, data will be converted to encoded image binary and written
      to an oss file whose name is
      output_dir/${key}/${input_filename}_${idx}.jpg, where input_filename
      is extracted from url, key corresponds to the key in the dict of
      output_type; if the data indexed by key is a list, idx is the index
      of the element in the list, otherwise ${idx} will be empty
    * type video, data will be converted to encoded video binary and written
      to an oss file

    e.g. return {
      'image': 'image',
      'feature': 'json'
    }
    indicating that the image data in the output dict will be saved to an
    image file and the feature in the output dict will be converted to json.
    """
    return {}
class PredictorImpl(object):
  """Implementation backend that loads a saved_model and runs predictions.

  Owns the tf graph, the session, and the input/output tensor bindings
  parsed from the model's serving signature.
  """

  def __init__(self, model_path, profiling_file=None):
    """Impl class for predictor.

    Args:
      model_path: saved_model directory or frozenpb file path.
      profiling_file: profiling result file, default None.
        if not None, predict function will use Timeline to profiling
        prediction time, and the result json will be saved to profiling_file.
    """
    # input/output name -> tensor mappings, filled in by _build_model
    self._inputs_map = {}
    self._outputs_map = {}
    self._is_saved_model = False
    self._profiling_file = profiling_file
    self._model_path = model_path
    self._input_names = []
    # True when the model exports one placeholder per input column,
    # False when all features are packed into a single placeholder
    self._is_multi_placeholder = True
    self._build_model()

  @property
  def input_names(self):
    # names are sorted by input id so they match csv column order
    return self._input_names

  @property
  def output_names(self):
    return list(self._outputs_map.keys())

  def __del__(self):
    """Destroy predictor resources."""
    # NOTE(review): if _build_model raised before the session was created,
    # this would raise AttributeError in the destructor — confirm acceptable.
    self._session.close()

  def search_pb(self, directory):
    """Search pb file recursively in model directory.

    If multiple pb files exist, exception will be raised.

    Args:
      directory: model directory.

    Returns:
      directory that contains the pb file.

    Raises:
      ValueError: if no pb file, or more than one, is found under directory.
    """
    dir_list = []
    for root, dirs, files in gfile.Walk(directory):
      for f in files:
        _, ext = os.path.splitext(f)
        if ext == '.pb':
          dir_list.append(root)
    if len(dir_list) == 0:
      raise ValueError('savedmodel is not found in directory %s' % directory)
    elif len(dir_list) > 1:
      raise ValueError('multiple saved model found in directory %s' % directory)
    return dir_list[0]

  def _get_input_fields_from_pipeline_config(self, model_path):
    """Load input field info from the pipeline.config exported with the model.

    Args:
      model_path: saved_model directory that contains assets/pipeline.config.

    Returns:
      a tuple of:
        input_fields_info: dict of input_name -> (input_type, default_val).
        input_fields_list: input names in the order they are declared.
    """
    pipeline_path = os.path.join(model_path, 'assets/pipeline.config')
    assert gfile.Exists(pipeline_path), '%s not exists.' % pipeline_path
    pipeline_config = get_configs_from_pipeline_file(pipeline_path)
    input_fields = pipeline_config.data_config.input_fields
    input_fields_info = {
        input_field.input_name:
        (input_field.input_type, input_field.default_val)
        for input_field in input_fields
    }
    input_fields_list = [input_field.input_name for input_field in input_fields]
    return input_fields_info, input_fields_list

  def _build_model(self):
    """Load graph from model_path and create session for this graph."""
    model_path = self._model_path
    self._graph = tf.Graph()
    gpu_options = tf.GPUOptions(allow_growth=True)
    session_config = tf.ConfigProto(
        gpu_options=gpu_options,
        allow_soft_placement=True,
        log_device_placement=(self._profiling_file is not None))
    self._session = tf.Session(config=session_config, graph=self._graph)
    with self._graph.as_default():
      with self._session.as_default():
        # load model
        _, ext = os.path.splitext(model_path)
        tf.logging.info('loading model from %s' % model_path)
        if gfile.IsDirectory(model_path):
          model_path = self.search_pb(model_path)
          logging.info('model find in %s' % model_path)
          self._input_fields_info, self._input_fields_list = self._get_input_fields_from_pipeline_config(
              model_path)
          assert tf.saved_model.loader.maybe_saved_model_directory(model_path), \
              'saved model does not exists in %s' % model_path
          self._is_saved_model = True
          meta_graph_def = tf.saved_model.loader.load(
              self._session, [tf.saved_model.tag_constants.SERVING], model_path)
          # parse signature
          signature_def = meta_graph_def.signature_def[
              signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
          inputs = signature_def.inputs
          # each input_info is a tuple of input_id, name, data_type
          input_info = []
          self._is_multi_placeholder = len(inputs.items()) > 1
          if self._is_multi_placeholder:
            for gid, item in enumerate(inputs.items()):
              name, tensor = item
              logging.info('Load input binding: %s -> %s' % (name, tensor.name))
              input_name = tensor.name
              input_name, _ = input_name.split(':')
              # easy_rec exports placeholders with a trailing numeric id in
              # the tensor name; that id fixes the csv column order
              try:
                input_id = input_name.split('_')[-1]
                input_id = int(input_id)
              except Exception:
                # support for models that are not exported by easy_rec
                # in which case, the order of inputs may not be the
                # same as they are defined, thereforce, list input
                # could not be supported, only dict input could be supported
                logging.warning(
                    'could not determine input_id from input_name: %s' %
                    input_name)
                input_id = gid
              input_info.append((input_id, name, tensor.dtype))
              self._inputs_map[name] = self._graph.get_tensor_by_name(
                  tensor.name)
          else:
            # only one input, all features concatenate together
            for name, tensor in inputs.items():
              logging.info('Load input binding: %s -> %s' % (name, tensor.name))
              input_info.append((0, name, tensor.dtype))
              self._inputs_map[name] = self._graph.get_tensor_by_name(
                  tensor.name)
          # sort inputs by input_ids so as to match the order of csv data
          input_info.sort(key=lambda t: t[0])
          self._input_names = [t[1] for t in input_info]
          outputs = signature_def.outputs
          for name, tensor in outputs.items():
            logging.info('Load output binding: %s -> %s' % (name, tensor.name))
            self._outputs_map[name] = self._graph.get_tensor_by_name(
                tensor.name)
          # get assets: files (e.g. vocabularies) bundled with the
          # saved_model; missing assets make the model unusable
          self._assets = {}
          asset_files = tf.get_collection(constants.ASSETS_KEY)
          for any_proto in asset_files:
            asset_file = meta_graph_pb2.AssetFileDef()
            any_proto.Unpack(asset_file)
            type_name = asset_file.tensor_info.name.split(':')[0]
            asset_path = os.path.join(model_path, constants.ASSETS_DIRECTORY,
                                      asset_file.filename)
            assert gfile.Exists(
                asset_path), '%s is missing in saved model' % asset_path
            self._assets[type_name] = asset_path
          logging.info(self._assets)
          # get export config
          self._export_config = {}
          # export_config_collection = tf.get_collection(fields.EVGraphKeys.export_config)
          # if len(export_config_collection) > 0:
          #   self._export_config = json.loads(export_config_collection[0])
          #   logging.info('load export config info %s' % export_config_collection[0])
        else:
          raise ValueError('currently only savedmodel is supported')

  def predict(self, input_data_dict, output_names=None):
    """Predict input data with loaded model.

    Args:
      input_data_dict: a dict containing all input data, key is the input
        name, value is the corresponding value.
      output_names: if not None, will fetch certain outputs, if set None, will
        return all the output info according to the output info in model
        signature.

    Returns:
      a dict of outputs, key is the output name, value is the corresponding
      value.
    """
    feed_dict = {}
    for input_name, tensor in six.iteritems(self._inputs_map):
      assert input_name in input_data_dict, 'input data %s is missing' % input_name
      tensor_shape = tensor.get_shape().as_list()
      input_shape = input_data_dict[input_name].shape
      # exported models may have a fixed batch dimension; when they do,
      # the fed batch size must match it exactly
      assert tensor_shape[0] is None or (tensor_shape[0] == input_shape[0]), \
          'input %s batchsize %d is not the same as the exported batch_size %d' % \
          (input_name, input_shape[0], tensor_shape[0])
      feed_dict[tensor] = input_data_dict[input_name]
    fetch_dict = {}
    if output_names is not None:
      for output_name in output_names:
        assert output_name in self._outputs_map, \
            'invalid output name %s' % output_name
        fetch_dict[output_name] = self._outputs_map[output_name]
    else:
      fetch_dict = self._outputs_map
    with self._graph.as_default():
      with self._session.as_default():
        if self._profiling_file is None:
          return self._session.run(fetch_dict, feed_dict)
        else:
          # profile this run and dump a chrome-trace json to
          # self._profiling_file
          run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
          run_metadata = tf.RunMetadata()
          results = self._session.run(
              fetch_dict,
              feed_dict,
              options=run_options,
              run_metadata=run_metadata)
          # Create the Timeline object, and write it to a json
          from tensorflow.python.client import timeline
          tl = timeline.Timeline(run_metadata.step_stats)
          ctf = tl.generate_chrome_trace_format()
          with gfile.GFile(self._profiling_file, 'w') as f:
            f.write(ctf)
          return results
class Predictor(PredictorInterface):
  """Predictor for saved_model exported by easy_rec."""

  def __init__(self, model_path, profiling_file=None):
    """Initialize a `Predictor`.

    Args:
      model_path: saved_model directory or frozenpb file path.
      profiling_file: profiling result file, default None.
        if not None, predict function will use Timeline to profiling
        prediction time, and the result json will be saved to profiling_file.
    """
    self._predictor_impl = PredictorImpl(model_path, profiling_file)
    self._inputs_map = self._predictor_impl._inputs_map
    self._outputs_map = self._predictor_impl._outputs_map
    self._profiling_file = profiling_file
    self._export_config = self._predictor_impl._export_config
    self._input_fields_info = self._predictor_impl._input_fields_info
    self._is_multi_placeholder = self._predictor_impl._is_multi_placeholder
    self._input_fields = self._predictor_impl._input_fields_list

  @property
  def input_names(self):
    """Input names of the model.

    Returns:
      a list, which containing the name of input nodes available in model
    """
    return list(self._inputs_map.keys())

  @property
  def output_names(self):
    """Output names of the model.

    Returns:
      a list, which containing the name of outputs nodes available in model
    """
    return list(self._outputs_map.keys())

  def _parse_output_cols(self, output_cols):
    """Parse "name type" pairs such as "score float,embedding string".

    Args:
      output_cols: comma separated "name type" pairs.

    Returns:
      list of output names with the type part dropped.
    """
    tmp_cols = []
    for x in output_cols.split(','):
      if x.strip() == '':
        continue
      tmp_keys = x.split(' ')
      tmp_cols.append(tmp_keys[0].strip())
    return tmp_cols

  def predict_impl(self,
                   input_table,
                   output_table,
                   all_cols='',
                   all_col_types='',
                   selected_cols='',
                   reserved_cols='',
                   output_cols=None,
                   batch_size=1024,
                   slice_id=0,
                   slice_num=1,
                   input_sep=',',
                   output_sep=chr(1)):
    """Predict table input with loaded model.

    Dispatches to predict_table on PAI, predict_csv elsewhere.

    Args:
      input_table: table/file_path to read.
      output_table: table/file_path to write.
      all_cols: union of columns.
      all_col_types: data types of the columns.
      selected_cols: included column names, comma separated, such as "a,b,c".
      reserved_cols: columns to be copied to output_table, comma separated,
        such as "a,b".
      output_cols: output columns, comma separated, such as
        "y float, embedding string"; the output names [y, embedding] must be
        in saved_model output_names.
      batch_size: predict batch size.
      slice_id: when multiple workers write the same table, each worker
        should be assigned a different slice_id, which is usually the
        worker id.
      slice_num: table slice number.
      input_sep: separator of input file.
      output_sep: separator of predict result file.
    """
    if pai_util.is_on_pai():
      self.predict_table(
          input_table,
          output_table,
          all_cols=all_cols,
          all_col_types=all_col_types,
          selected_cols=selected_cols,
          reserved_cols=reserved_cols,
          output_cols=output_cols,
          batch_size=batch_size,
          slice_id=slice_id,
          slice_num=slice_num)
    else:
      self.predict_csv(
          input_table,
          output_table,
          reserved_cols=reserved_cols,
          output_cols=output_cols,
          batch_size=batch_size,
          slice_id=slice_id,
          slice_num=slice_num,
          input_sep=input_sep,
          output_sep=output_sep)

  def predict_csv(self, input_path, output_path, reserved_cols, output_cols,
                  batch_size, slice_id, slice_num, input_sep, output_sep):
    """Predict csv file input with the loaded model.

    Args:
      input_path: input file path(s), comma separated, globs supported.
      output_path: output directory; this worker writes slice_%d.csv there.
      reserved_cols: columns copied to output, comma separated, or
        'ALL_COLUMNS' for every input field.
      output_cols: output columns such as "score float,embedding string",
        or None / 'ALL_COLUMNS' for all model outputs (sorted).
      batch_size: predict batch size.
      slice_id: shard id of this worker.
      slice_num: total number of shards.
      input_sep: field separator of the input file.
      output_sep: field separator of the result file.
    """
    record_defaults = [
        self._input_fields_info[col_name][1] for col_name in self._input_fields
    ]
    if reserved_cols == 'ALL_COLUMNS':
      reserved_cols = self._input_fields
    else:
      reserved_cols = [x.strip() for x in reserved_cols.split(',') if x != '']
    if output_cols is None or output_cols == 'ALL_COLUMNS':
      output_cols = sorted(self._predictor_impl.output_names)
      logging.info('predict output cols: %s' % output_cols)
    else:
      # specified as "score float,embedding string": keep only the names
      output_cols = self._parse_output_cols(output_cols)
    with tf.Graph().as_default(), tf.Session() as sess:
      num_parallel_calls = 8
      file_paths = []
      for x in input_path.split(','):
        file_paths.extend(gfile.Glob(x))
      assert len(file_paths) > 0, 'match no files with %s' % input_path
      dataset = tf.data.Dataset.from_tensor_slices(file_paths)
      parallel_num = min(num_parallel_calls, len(file_paths))
      dataset = dataset.interleave(
          tf.data.TextLineDataset,
          cycle_length=parallel_num,
          num_parallel_calls=parallel_num)
      # each worker reads only its own shard of the lines
      dataset = dataset.shard(slice_num, slice_id)
      logging.info('batch_size = %d' % batch_size)
      dataset = dataset.batch(batch_size)
      dataset = dataset.prefetch(buffer_size=64)

      def _parse_csv(line):

        def _check_data(line):
          sep = input_sep
          # py_func receives lines as bytes; make sure the separator is
          # bytes, too. (BUGFIX: the original tested
          # `type(sep) != type(str)`, which compares against the metaclass
          # `type` and is always True, crashing for bytes separators.)
          if not isinstance(sep, bytes):
            sep = sep.encode('utf-8')
          field_num = len(line[0].split(sep))
          assert field_num == len(record_defaults), 'sep[%s] maybe invalid: field_num=%d, required_num=%d' \
              % (sep, field_num, len(record_defaults))
          return True

        check_op = tf.py_func(_check_data, [line], Tout=tf.bool)
        with tf.control_dependencies([check_op]):
          # BUGFIX: decode with the caller-supplied separator; the original
          # hard-coded field_delim=',' which broke non-comma inputs even
          # though the field-count check above honored input_sep
          fields = tf.decode_csv(
              line,
              field_delim=input_sep,
              record_defaults=record_defaults,
              name='decode_csv')
        inputs = {self._input_fields[x]: fields[x] for x in range(len(fields))}
        return inputs

      dataset = dataset.map(_parse_csv, num_parallel_calls=num_parallel_calls)
      iterator = dataset.make_one_shot_iterator()
      all_dict = iterator.get_next()
      if not gfile.Exists(output_path):
        gfile.MakeDirs(output_path)
      res_path = os.path.join(output_path, 'slice_%d.csv' % slice_id)
      table_writer = gfile.GFile(res_path, 'w')
      input_names = self._predictor_impl.input_names
      progress = 0
      sum_t0, sum_t1, sum_t2 = 0, 0, 0
      pred_cnt = 0
      # header line
      table_writer.write(output_sep.join(output_cols + reserved_cols) + '\n')
      while True:
        try:
          ts0 = time.time()
          all_vals = sess.run(all_dict)
          ts1 = time.time()
          input_vals = {k: all_vals[k] for k in input_names}
          outputs = self._predictor_impl.predict(input_vals, output_cols)
          # decode bytes into utf-8 strings before writing
          # (np.object was removed in numpy>=1.24; the builtin object
          # compares identically against a dtype)
          for x in output_cols:
            if outputs[x].dtype == object:
              outputs[x] = [val.decode('utf-8') for val in outputs[x]]
          for k in reserved_cols:
            if all_vals[k].dtype == object:
              all_vals[k] = [val.decode('utf-8') for val in all_vals[k]]
          ts2 = time.time()
          reserve_vals = [outputs[x] for x in output_cols] + \
              [all_vals[k] for k in reserved_cols]
          # transpose column-major values into per-sample rows
          outputs = [x for x in zip(*reserve_vals)]
          pred_cnt += len(outputs)
          outputs = '\n'.join(
              [output_sep.join([str(i) for i in output]) for output in outputs])
          table_writer.write(outputs + '\n')
          ts3 = time.time()
          progress += 1
          sum_t0 += (ts1 - ts0)
          sum_t1 += (ts2 - ts1)
          sum_t2 += (ts3 - ts2)
        except tf.errors.OutOfRangeError:
          break
        if progress % 100 == 0:
          logging.info('progress: batch_num=%d sample_num=%d' %
                       (progress, progress * batch_size))
          logging.info('time_stats: read: %.2f predict: %.2f write: %.2f' %
                       (sum_t0, sum_t1, sum_t2))
      logging.info('Final_time_stats: read: %.2f predict: %.2f write: %.2f' %
                   (sum_t0, sum_t1, sum_t2))
      table_writer.close()
      logging.info('Predict %s done.' % input_path)
      logging.info('Predict size: %d.' % pred_cnt)

  def predict_table(self,
                    input_table,
                    output_table,
                    all_cols,
                    all_col_types,
                    selected_cols,
                    reserved_cols,
                    output_cols=None,
                    batch_size=1024,
                    slice_id=0,
                    slice_num=1):
    """Predict odps table input with the loaded model (PAI only).

    Args:
      input_table: input table name(s), comma separated.
      output_table: output table name.
      all_cols: union of columns, comma separated.
      all_col_types: data types of the columns, comma separated, each one
        of string/double/bigint.
      selected_cols: included column names, comma separated (kept for
        interface compatibility).
      reserved_cols: columns to be copied to output_table, comma separated.
      output_cols: output columns such as "y float, embedding string",
        or None for all model outputs.
      batch_size: predict batch size.
      slice_id: slice id of this worker.
      slice_num: total number of table slices.
    """

    def _get_defaults(col_name, col_type):
      # prefer the default declared in pipeline.config; otherwise fall back
      # to a zero value of the odps column type
      if col_name in self._input_fields_info:
        col_type, default_val = self._input_fields_info[col_name]
        default_val = get_type_defaults(col_type, default_val)
        logging.info('col_name: %s, default_val: %s' % (col_name, default_val))
      else:
        logging.info('col_name: %s is not used in predict.' % col_name)
        defaults = {'string': '', 'double': 0.0, 'bigint': 0}
        assert col_type in defaults, 'invalid col_type: %s, col_type: %s' % (
            col_name, col_type)
        default_val = defaults[col_type]
      return default_val

    all_cols = [x.strip() for x in all_cols.split(',') if x != '']
    all_col_types = [x.strip() for x in all_col_types.split(',') if x != '']
    reserved_cols = [x.strip() for x in reserved_cols.split(',') if x != '']
    if output_cols is None:
      output_cols = self._predictor_impl.output_names
    else:
      # specified as "score float,embedding string": keep only the names
      output_cols = self._parse_output_cols(output_cols)
    record_defaults = [
        _get_defaults(col_name, col_type)
        for col_name, col_type in zip(all_cols, all_col_types)
    ]
    with tf.Graph().as_default(), tf.Session() as sess:
      num_parallel_calls = 8
      input_table = input_table.split(',')
      dataset = tf.data.TableRecordDataset([input_table],
                                           record_defaults=record_defaults,
                                           slice_id=slice_id,
                                           slice_count=slice_num,
                                           selected_cols=','.join(all_cols))
      logging.info('batch_size = %d' % batch_size)
      dataset = dataset.batch(batch_size)
      dataset = dataset.prefetch(buffer_size=64)

      def _parse_table(*fields):
        fields = list(fields)
        field_dict = {all_cols[i]: fields[i] for i in range(len(fields))}
        return field_dict

      dataset = dataset.map(_parse_table, num_parallel_calls=num_parallel_calls)
      iterator = dataset.make_one_shot_iterator()
      all_dict = iterator.get_next()
      # common_io is only available on PAI, so import it lazily
      import common_io
      table_writer = common_io.table.TableWriter(
          output_table, slice_id=slice_id)
      input_names = self._predictor_impl.input_names

      def _parse_value(all_vals):
        # models exported with a single 'features' placeholder pack all
        # columns into one string separated by '\002'; split them back
        # into per-input arrays so they match input_names
        if self._is_multi_placeholder:
          if SINGLE_PLACEHOLDER_FEATURE_KEY in all_vals:
            feature_vals = all_vals[SINGLE_PLACEHOLDER_FEATURE_KEY]
            split_index = []
            split_vals = {}
            for i, k in enumerate(input_names):
              split_index.append(k)
              split_vals[k] = []
            for record in feature_vals:
              split_records = record.split('\002')
              for i, r in enumerate(split_records):
                split_vals[split_index[i]].append(r)
            return {k: np.array(split_vals[k]) for k in input_names}
        return {k: all_vals[k] for k in input_names}

      progress = 0
      sum_t0, sum_t1, sum_t2 = 0, 0, 0
      while True:
        try:
          ts0 = time.time()
          all_vals = sess.run(all_dict)
          ts1 = time.time()
          input_vals = _parse_value(all_vals)
          outputs = self._predictor_impl.predict(input_vals, output_cols)
          ts2 = time.time()
          reserve_vals = [all_vals[k] for k in reserved_cols
                          ] + [outputs[x] for x in output_cols]
          indices = list(range(0, len(reserve_vals)))
          outputs = [x for x in zip(*reserve_vals)]
          table_writer.write(outputs, indices, allow_type_cast=False)
          ts3 = time.time()
          progress += 1
          sum_t0 += (ts1 - ts0)
          sum_t1 += (ts2 - ts1)
          sum_t2 += (ts3 - ts2)
        except tf.python_io.OutOfRangeException:
          # raised by TableRecordDataset on PAI when the slice is exhausted
          break
        except tf.errors.OutOfRangeError:
          break
        if progress % 100 == 0:
          logging.info('progress: batch_num=%d sample_num=%d' %
                       (progress, progress * batch_size))
          logging.info('time_stats: read: %.2f predict: %.2f write: %.2f' %
                       (sum_t0, sum_t1, sum_t2))
      logging.info('Final_time_stats: read: %.2f predict: %.2f write: %.2f' %
                   (sum_t0, sum_t1, sum_t2))
      table_writer.close()
      logging.info('Predict %s done.' % input_table)

  def predict(self, input_data_dict_list, output_names=None, batch_size=1):
    """Predict input data with loaded model.

    Args:
      input_data_dict_list: list of dict/list/str, each element is one
        sample to be predicted.
      output_names: if not None, will fetch certain outputs; if set None,
        will fetch all outputs defined in the model signature.
      batch_size: batch_size used to predict, -1 indicates to use the real
        batch_size (all samples in a single batch).

    Returns:
      a list of dict, each dict contain a key-value pair for output_name,
      output_value
    """
    num_example = len(input_data_dict_list)
    assert num_example > 0, 'input data should not be an empty list'
    assert isinstance(input_data_dict_list[0], (dict, list, str)), \
        'input is not a list or dict or str'
    if batch_size > 0:
      num_batches = int(math.ceil(float(num_example) / batch_size))
    else:
      # non-positive batch_size: run everything in one batch
      num_batches = 1
      batch_size = len(input_data_dict_list)
    outputs_list = []
    for batch_idx in range(num_batches):
      batch_data_list = input_data_dict_list[batch_idx *
                                             batch_size:(batch_idx + 1) *
                                             batch_size]
      feed_dict = self.batch(batch_data_list)
      outputs = self._predictor_impl.predict(feed_dict, output_names)
      # unpack batched outputs into one result dict per sample
      for idx in range(len(batch_data_list)):
        single_result = {}
        for key, batch_value in six.iteritems(outputs):
          single_result[key] = batch_value[idx]
        outputs_list.append(single_result)
    return outputs_list

  def batch(self, data_list):
    """Batching the data into one feed dict of numpy arrays.

    Args:
      data_list: list of samples; each sample is a dict keyed by input
        name, a list aligned with input_names, or a single string for
        single-placeholder models.

    Returns:
      dict of input_name -> numpy array of batched values.
    """
    batch_input = {key: [] for key in self._predictor_impl.input_names}
    for data in data_list:
      if isinstance(data, dict):
        for key in data:
          batch_input[key].append(data[key])
      elif isinstance(data, list):
        assert len(self._predictor_impl.input_names) == len(data), \
            'input fields number incorrect, should be %d, but %d' \
            % (len(self._predictor_impl.input_names), len(data))
        for key, v in zip(self._predictor_impl.input_names, data):
          if key != '':
            batch_input[key].append(v)
      elif isinstance(data, str):
        # a single string feeds the first (usually only) placeholder
        batch_input[self._predictor_impl.input_names[0]].append(data)
    for key in batch_input:
      batch_input[key] = np.array(batch_input[key])
    return batch_input