# Source code for easy_rec.python.input.odps_input_v2
# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import logging
import tensorflow as tf
from easy_rec.python.input.input import Input
from easy_rec.python.utils import odps_util
try:
import pai
except Exception:
pass
class OdpsInputV2(Input):
  """Read input samples from ODPS tables via `tf.data.TableRecordDataset`.

  Table columns are selected by name in the order of `self._input_fields`.
  Each batch is converted into a dict keyed by field name, passed through
  `self._preprocess`, and finally split into features (plus labels in every
  mode except PREDICT).
  """

  def __init__(self,
               data_config,
               feature_config,
               input_path,
               task_index=0,
               task_num=1,
               check_mode=False,
               pipeline_config=None):
    """Forward all arguments unchanged to `Input.__init__`.

    Args:
      data_config: data config. `_build` reads batch_size, shuffle,
        shuffle_buffer_size, num_parallel_calls, pai_worker_queue,
        pai_worker_slice_num and chief_redundant from it.
      feature_config: feature config, passed through to the base class.
      input_path: comma-separated string of ODPS table paths, or a
        list/tuple of them.
      task_index: index of this task. `_build` uses it (as stored by the
        base class) to pick the table slice to read.
      task_num: total number of tasks. Used together with task_index.
      check_mode: passed through to the base class.
      pipeline_config: passed through to the base class.
    """
    super(OdpsInputV2, self).__init__(data_config, feature_config, input_path,
                                      task_index, task_num, check_mode,
                                      pipeline_config)

  def _parse_table(self, *fields):
    """Map one batch of selected table columns to a dict keyed by field name.

    Args:
      *fields: one tensor per entry of `self._input_fields`, in that same
        order. The order is fixed by `selected_cols` in `_build`.

    Returns:
      dict from input field name to tensor. It contains the fields whose
      indices are in `self._effective_fids` plus the label fields whose
      indices are in `self._label_fids`.
    """
    inputs = {self._input_fields[fid]: fields[fid] for fid in self._effective_fids}
    for fid in self._label_fids:
      inputs[self._input_fields[fid]] = fields[fid]
    return inputs

  def _build(self, mode, params):
    """Build the `tf.data.Dataset` for `mode`.

    Side effect: normalizes `self._input_path` into a list of non-blank
    table paths.

    Args:
      mode: a `tf.estimator.ModeKeys` value.
      params: not used by this implementation.

    Returns:
      A dataset whose elements are `(features, labels)` in every mode other
      than PREDICT, and `features` only in PREDICT mode.

    Raises:
      AssertionError: if no non-blank table path is given.
    """
    raw_input_path = self._input_path
    if isinstance(raw_input_path, (list, tuple)):
      paths = list(raw_input_path)
    else:
      paths = raw_input_path.split(',')
    # Drop blank entries. ''.split(',') yields [''] and a trailing comma
    # yields a trailing ''; neither is a table, and without this filter the
    # emptiness check below could never fire for an empty string.
    paths = [p.strip() for p in paths if p.strip()]
    assert len(paths) > 0, 'match no files with %s' % (raw_input_path,)
    self._input_path = paths

    # check data_config are consistent with odps tables
    odps_util.check_input_field_and_types(self._data_config)

    selected_cols = ','.join(self._input_fields)
    record_defaults = [
        self.get_type_defaults(x, v)
        for x, v in zip(self._input_field_types, self._input_field_defaults)
    ]

    data_config = self._data_config
    is_train = mode == tf.estimator.ModeKeys.TRAIN
    if is_train and data_config.pai_worker_queue:
      logging.info('pai_worker_slice_num = %d',
                   data_config.pai_worker_slice_num)
      # Tables are cut into pai_worker_slice_num * task_num slices. No
      # static slice_id/slice_count is passed to TableRecordDataset here;
      # presumably the work queue assigns slices to workers — confirm.
      # NOTE(review): num_epochs is also given to the queue, and
      # dataset.repeat(self.num_epochs) below still applies in TRAIN mode.
      # Confirm this does not multiply the number of epochs.
      work_queue = pai.data.WorkQueue(
          paths,
          num_epochs=self.num_epochs,
          shuffle=data_config.shuffle,
          num_slices=data_config.pai_worker_slice_num * self._task_num)
      dataset = tf.data.TableRecordDataset(
          work_queue.input_dataset(),
          record_defaults=record_defaults,
          selected_cols=selected_cols)
    else:
      if is_train and data_config.chief_redundant:
        # Task 0 (the chief) reads the same slice as task 1. The data is
        # therefore split into task_num - 1 slices (at least 1).
        slice_id = max(self._task_index - 1, 0)
        slice_count = max(self._task_num - 1, 1)
      else:
        slice_id = self._task_index
        slice_count = self._task_num
      dataset = tf.data.TableRecordDataset(
          paths,
          record_defaults=record_defaults,
          selected_cols=selected_cols,
          slice_id=slice_id,
          slice_count=slice_count)

    if is_train:
      # Shuffle individual records (before batching); the fixed seed keeps
      # runs reproducible.
      if data_config.shuffle:
        dataset = dataset.shuffle(
            data_config.shuffle_buffer_size,
            seed=2020,
            reshuffle_each_iteration=True)
      dataset = dataset.repeat(self.num_epochs)

    dataset = dataset.batch(batch_size=data_config.batch_size)
    dataset = dataset.map(
        self._parse_table, num_parallel_calls=data_config.num_parallel_calls)
    # preprocess is necessary to transform data
    # so that they could be fed into FeatureColumns
    dataset = dataset.map(
        map_func=self._preprocess,
        num_parallel_calls=data_config.num_parallel_calls)
    if mode != tf.estimator.ModeKeys.PREDICT:
      dataset = dataset.map(lambda x:
                            (self._get_features(x), self._get_labels(x)))
    else:
      # Dict elements are passed to the map function as a single argument.
      dataset = dataset.map(self._get_features)
    # Prefetch as the final stage so the whole pipeline overlaps with the
    # consumer.
    return dataset.prefetch(buffer_size=self._prefetch_size)