# Source code for easy_rec.python.input.csv_input_v2
# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import tensorflow as tf
from easy_rec.python.input.input import Input
class CSVInputV2(Input):
  """Input that reads delimited text files into batched feature dicts.

  Non-HDFS paths are read with `tf.data.experimental.make_csv_dataset`;
  `hdfs://` paths are read line by line and parsed with `tf.io.decode_csv`,
  so both branches hand the same element structure (a dict of batched
  columns keyed by input field name) to `_preprocess`.
  """

  def __init__(self,
               data_config,
               feature_config,
               input_path,
               task_index=0,
               task_num=1,
               check_mode=False):
    super(CSVInputV2, self).__init__(data_config, feature_config, input_path,
                                     task_index, task_num, check_mode)

  def _parse_csv_lines(self, lines, record_defaults):
    """Split a batch of text lines into a dict keyed by input field name.

    Args:
      lines: string tensor holding one batch of raw CSV lines.
      record_defaults: per-column default values, in input field order.

    Returns:
      dict mapping each name in `self._input_fields` to its parsed column.
    """
    columns = tf.io.decode_csv(
        lines,
        record_defaults=record_defaults,
        field_delim=self._data_config.separator)
    return dict(zip(self._input_fields, columns))

  def _build_hdfs_dataset(self, num_epochs, record_defaults):
    """Read `hdfs://` CSV files into batches of parsed column dicts.

    Mirrors the repeat -> batch -> parse structure of the non-HDFS branch so
    the downstream pipeline receives the same kind of elements.

    Args:
      num_epochs: passed to `Dataset.repeat` (None repeats indefinitely).
      record_defaults: per-column default values, in input field order.

    Returns:
      A `tf.data.Dataset` of dicts of batched columns.
    """
    # NOTE(review): data_config.shuffle is not applied to HDFS input here.
    dataset = tf.data.TextLineDataset(self._input_path)
    dataset = dataset.repeat(num_epochs)
    dataset = dataset.batch(self._data_config.batch_size)
    return dataset.map(
        lambda lines: self._parse_csv_lines(lines, record_defaults),
        num_parallel_calls=8)

  def _build(self, mode, params):
    """Build the input `tf.data.Dataset` for the given estimator mode.

    Args:
      mode: a `tf.estimator.ModeKeys` value. TRAIN uses `self.num_epochs`,
        configured shuffling, sloppy reads and per-worker sharding; other
        modes read the data once, unsharded.
      params: unused here.

    Returns:
      A dataset of `(features, labels)` pairs, or of `features` alone in
      PREDICT mode.
    """
    if not isinstance(self._input_path, list):
      # Comma-separated path string; drop empty pieces (e.g. from a trailing
      # comma) so they are not handed to the reader as file patterns.
      self._input_path = [
          p.strip() for p in self._input_path.split(',') if p.strip()
      ]
    assert len(
        self._input_path) > 0, 'match no files with %s' % self._input_path

    is_train = mode == tf.estimator.ModeKeys.TRAIN
    num_epochs = self.num_epochs if is_train else 1
    record_defaults = [
        self.get_type_defaults(x, v)
        for x, v in zip(self._input_field_types, self._input_field_defaults)
    ]

    if self._input_path[0].startswith('hdfs://'):
      # support hdfs input
      dataset = self._build_hdfs_dataset(num_epochs, record_defaults)
    else:
      dataset = tf.data.experimental.make_csv_dataset(
          self._input_path,
          self._data_config.batch_size,
          column_names=self._input_fields,
          field_delim=self._data_config.separator,
          column_defaults=record_defaults,
          header=False,
          num_epochs=num_epochs,
          shuffle=is_train and self._data_config.shuffle,
          num_parallel_reads=8,
          sloppy=is_train)

    if is_train:
      # NOTE(review): shard() runs on batches produced after an unseeded
      # shuffle and a sloppy interleave, so workers may not receive disjoint
      # data when shuffling is on; consider file-level sharding or a shared
      # shuffle seed. TODO confirm.
      if self._data_config.chief_redundant:
        # Tasks 0 and 1 both map to shard 0, so only task_num - 1 distinct
        # shards are created and the chief duplicates task 1's data.
        dataset = dataset.shard(
            max(self._task_num - 1, 1), max(self._task_index - 1, 0))
      else:
        dataset = dataset.shard(self._task_num, self._task_index)

    dataset = dataset.prefetch(buffer_size=self._prefetch_size)
    dataset = dataset.map(map_func=self._preprocess, num_parallel_calls=8)
    dataset = dataset.prefetch(buffer_size=self._prefetch_size)
    if mode != tf.estimator.ModeKeys.PREDICT:
      dataset = dataset.map(lambda x:
                            (self._get_features(x), self._get_labels(x)))
    else:
      dataset = dataset.map(lambda x: self._get_features(x))
    return dataset