Source code for easy_rec.python.input.odps_input

# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import tensorflow as tf

from easy_rec.python.input.input import Input
from easy_rec.python.utils import odps_util

# `pai` is only available inside the PAI-TF runtime; fail silently elsewhere
# so that this module can still be imported.
try:
  import pai
except Exception:
  pass


class OdpsInput(Input):

  def __init__(self,
               data_config,
               feature_config,
               input_path,
               task_index=0,
               task_num=1,
               check_mode=False,
               pipeline_config=None):
    super(OdpsInput, self).__init__(data_config, feature_config, input_path,
                                    task_index, task_num, check_mode,
                                    pipeline_config)

  def _build(self, mode, params):
    # check that data_config is consistent with the odps tables
    odps_util.check_input_field_and_types(self._data_config)

    selected_cols = ','.join(self._input_fields)
    if self._data_config.chief_redundant and \
        mode == tf.estimator.ModeKeys.TRAIN:
      # chief reads a redundant slice, so workers 1..n-1 partition
      # the tables among themselves
      reader = tf.TableRecordReader(
          csv_delimiter=self._data_config.separator,
          selected_cols=selected_cols,
          slice_count=max(self._task_num - 1, 1),
          slice_id=max(self._task_index - 1, 0))
    else:
      reader = tf.TableRecordReader(
          csv_delimiter=self._data_config.separator,
          selected_cols=selected_cols,
          slice_count=self._task_num,
          slice_id=self._task_index)

    if not isinstance(self._input_path, list):
      self._input_path = self._input_path.split(',')
    assert len(
        self._input_path) > 0, 'match no files with %s' % self._input_path

    if mode == tf.estimator.ModeKeys.TRAIN:
      if self._data_config.pai_worker_queue:
        # dynamic work queue: table slices are handed out to workers on demand
        work_queue = pai.data.WorkQueue(
            self._input_path,
            num_epochs=self.num_epochs,
            shuffle=self._data_config.shuffle,
            num_slices=self._data_config.pai_worker_slice_num * self._task_num)
        work_queue.add_summary()
        file_queue = work_queue.input_producer()
        reader = tf.TableRecordReader()
      else:
        file_queue = tf.train.string_input_producer(
            self._input_path,
            num_epochs=self.num_epochs,
            capacity=1000,
            shuffle=self._data_config.shuffle)
    else:
      # evaluation / prediction: a single ordered pass over the tables
      file_queue = tf.train.string_input_producer(
          self._input_path, num_epochs=1, capacity=1000, shuffle=False)

    key, value = reader.read_up_to(file_queue, self._batch_size)

    record_defaults = [
        self.get_type_defaults(t, v)
        for t, v in zip(self._input_field_types, self._input_field_defaults)
    ]
    fields = tf.decode_csv(
        value,
        record_defaults=record_defaults,
        field_delim=self._data_config.separator,
        name='decode_csv')

    inputs = {self._input_fields[x]: fields[x] for x in self._effective_fids}
    for x in self._label_fids:
      inputs[self._input_fields[x]] = fields[x]
    fields = self._preprocess(inputs)

    features = self._get_features(fields)
    if mode != tf.estimator.ModeKeys.PREDICT:
      labels = self._get_labels(fields)
      # features, labels = pai.data.prefetch(
      #     features=(features, labels),
      #     capacity=self._prefetch_size, num_threads=2,
      #     closed_exception_types=(tuple([tf.errors.InternalError])))
      return features, labels
    else:
      # features = pai.data.prefetch(
      #     features=(features,),
      #     capacity=self._prefetch_size, num_threads=2,
      #     closed_exception_types=(tuple([tf.errors.InternalError])))
      return features
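
# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original source): shows one way
# OdpsInput could be wired into a tf.estimator input_fn. The `data_config`
# and `feature_config` protos and the table path below are hypothetical
# placeholders; actually running this requires the PAI-TF runtime, where
# tf.TableRecordReader and the `pai` package are available.
# ---------------------------------------------------------------------------
def _example_train_input_fn(data_config, feature_config):
  """Hypothetical builder returning an input_fn for tf.estimator.Estimator."""
  odps_input = OdpsInput(
      data_config,
      feature_config,
      'odps://my_project/tables/train_table',  # placeholder ODPS table path
      task_index=0,
      task_num=1)
  # In TRAIN mode _build returns a (features, labels) tensor tuple.
  return lambda: odps_input._build(tf.estimator.ModeKeys.TRAIN, params={})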