Source code for easy_rec.python.input.kafka_input

# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import logging
import sys

import tensorflow as tf

from easy_rec.python.input.input import Input

if tf.__version__ >= '2.0':
  tf = tf.compat.v1


class KafkaInput(Input):

  def __init__(self,
               data_config,
               feature_config,
               kafka_config,
               task_index=0,
               task_num=1):
    super(KafkaInput, self).__init__(data_config, feature_config, '',
                                     task_index, task_num)
    self._kafka = kafka_config
  def _parse_csv(self, line):
    record_defaults = [
        self.get_type_defaults(t, v)
        for t, v in zip(self._input_field_types, self._input_field_defaults)
    ]

    def _check_data(line):
      # inside py_func the strings arrive as bytes, so make sure the
      # separator is bytes as well before splitting
      sep = self._data_config.separator
      if not isinstance(sep, bytes):
        sep = sep.encode('utf-8')
      field_num = len(line[0].split(sep))
      assert field_num == len(record_defaults), \
          'sep[%s] maybe invalid: field_num=%d, required_num=%d' % \
          (sep, field_num, len(record_defaults))
      return True

    check_op = tf.py_func(_check_data, [line], Tout=tf.bool)
    with tf.control_dependencies([check_op]):
      fields = tf.decode_csv(
          line,
          field_delim=self._data_config.separator,
          record_defaults=record_defaults,
          name='decode_csv')

    inputs = {self._input_fields[x]: fields[x] for x in self._effective_fids}
    for x in self._label_fids:
      inputs[self._input_fields[x]] = fields[x]
    return inputs

  def _build(self, mode, params):
    try:
      import tensorflow_io.kafka as kafka_io
    except ImportError:
      logging.error(
          'Please install tensorflow-io, version compatibility can refer to '
          'https://github.com/tensorflow/io#tensorflow-version-compatibility')
      raise

    num_parallel_calls = self._data_config.num_parallel_calls
    if mode == tf.estimator.ModeKeys.TRAIN:
      train = self._kafka
      topics = []
      i = self._task_index
      assert len(train.offset) == 1 or len(train.offset) == train.partitions, \
          'number of train.offset must be 1 or train.partitions'
      # each worker reads the partitions whose index is congruent to its
      # task_index modulo task_num; a topic spec has the form
      # topic:partition:offset:length, where length=-1 means read to the end
      while i < train.partitions:
        offset_i = train.offset[i] if i < len(train.offset) else train.offset[-1]
        topics.append(train.topic + ':' + str(i) + ':' + str(offset_i) + ':-1')
        i = i + self._task_num
      logging.info(
          'train kafka server: %s topic: %s task_num: %d task_index: %d topics: %s'
          % (train.server, train.topic, self._task_num, self._task_index,
             topics))
      if len(topics) == 0:
        logging.error('train kafka topic is empty')
        sys.exit(1)
      dataset = kafka_io.KafkaDataset(
          topics, servers=train.server, group=train.group, eof=False)
      dataset = dataset.repeat(1)
    else:
      eval = self._kafka
      topics = []
      i = 0
      assert len(eval.offset) == 1 or len(eval.offset) == eval.partitions, \
          'number of eval.offset must be 1 or eval.partitions'
      # evaluation reads every partition on a single worker
      while i < eval.partitions:
        offset_i = eval.offset[i] if i < len(eval.offset) else eval.offset[-1]
        topics.append(eval.topic + ':' + str(i) + ':' + str(offset_i) + ':-1')
        i = i + 1
      logging.info(
          'eval kafka server: %s topic: %s task_num: %d task_index: %d topics: %s'
          % (eval.server, eval.topic, self._task_num, self._task_index, topics))
      if len(topics) == 0:
        logging.error('eval kafka topic is empty')
        sys.exit(1)
      dataset = kafka_io.KafkaDataset(
          topics, servers=eval.server, group=eval.group, eof=False)
      dataset = dataset.repeat(1)

    dataset = dataset.batch(self._data_config.batch_size)
    dataset = dataset.map(self._parse_csv, num_parallel_calls=num_parallel_calls)
    dataset = dataset.prefetch(buffer_size=self._prefetch_size)
    dataset = dataset.map(
        map_func=self._preprocess, num_parallel_calls=num_parallel_calls)
    dataset = dataset.prefetch(buffer_size=self._prefetch_size)

    if mode != tf.estimator.ModeKeys.PREDICT:
      dataset = dataset.map(
          lambda x: (self._get_features(x), self._get_labels(x)))
    else:
      dataset = dataset.map(lambda x: (self._get_features(x)))
    return dataset
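
The partition-sharding scheme in _build is easiest to see in isolation. The
following is a minimal, self-contained sketch of the same logic; the
build_topic_specs helper is hypothetical, written only for illustration, and
is not part of EasyRec.

def build_topic_specs(topic, partitions, offsets, task_index, task_num):
  """Build tensorflow_io topic specs of the form topic:partition:offset:length.

  Each worker takes the partitions whose index is congruent to its task_index
  modulo task_num; length=-1 means read to the end of the partition.
  """
  assert len(offsets) in (1, partitions), \
      'number of offsets must be 1 or equal to partitions'
  specs = []
  for p in range(task_index, partitions, task_num):
    offset_p = offsets[p] if p < len(offsets) else offsets[-1]
    specs.append('%s:%d:%d:-1' % (topic, p, offset_p))
  return specs

# Worker 0 of 2 over a 4-partition topic with a single shared offset:
print(build_topic_specs('easy_rec_train', 4, [0], task_index=0, task_num=2))
# ['easy_rec_train:0:0:-1', 'easy_rec_train:2:0:-1']

Because the partition sets assigned to different workers are disjoint, every
record is consumed exactly once across the training job, without relying on
Kafka consumer-group rebalancing.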