Source code for easy_rec.python.input.kafka_input

# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import json
import logging
import traceback

import six
import tensorflow as tf
from tensorflow.python.platform import gfile

from easy_rec.python.input.input import Input
from easy_rec.python.input.kafka_dataset import KafkaDataset
from easy_rec.python.utils.config_util import parse_time

try:
  from kafka import KafkaConsumer, TopicPartition
except ImportError:
  logging.warning(
      'kafka-python is not installed[%s]. You can install it by: pip install kafka-python'
      % traceback.format_exc())

if tf.__version__ >= '2.0':
  ignore_errors = tf.data.experimental.ignore_errors()
  tf = tf.compat.v1
else:
  ignore_errors = tf.contrib.data.ignore_errors()


class KafkaInput(Input):

  DATA_OFFSET = 'DATA_OFFSET'
  def __init__(self,
               data_config,
               feature_config,
               kafka_config,
               task_index=0,
               task_num=1,
               check_mode=False,
               pipeline_config=None):
    super(KafkaInput, self).__init__(data_config, feature_config, '',
                                     task_index, task_num, check_mode,
                                     pipeline_config)
    self._kafka = kafka_config
    self._offset_dict = {}
    if self._kafka is not None:
      consumer = KafkaConsumer(
          group_id='kafka_dataset_consumer',
          bootstrap_servers=[self._kafka.server],
          api_version_auto_timeout_ms=60000)  # in milliseconds
      partitions = consumer.partitions_for_topic(self._kafka.topic)
      self._num_partition = len(partitions)
      logging.info('all partitions[%d]: %s' %
                   (self._num_partition, partitions))

      # determine kafka offsets for each partition
      offset_type = self._kafka.WhichOneof('offset')
      if offset_type is not None:
        if offset_type == 'offset_time':
          ts = parse_time(self._kafka.offset_time)
          input_map = {
              TopicPartition(partition=part_id, topic=self._kafka.topic):
              ts * 1000 for part_id in partitions
          }
          part_offsets = consumer.offsets_for_times(input_map)
          # part_offsets is a dictionary:
          # {
          #    TopicPartition(topic=u'kafka_data_20220408', partition=0):
          #        OffsetAndTimestamp(offset=2, timestamp=1650611437895)
          # }
          for part in part_offsets:
            self._offset_dict[part.partition] = part_offsets[part].offset
            logging.info(
                'Find offset by time, topic[%s], partition[%d], timestamp[%ss], offset[%d], offset_timestamp[%dms]'
                % (self._kafka.topic, part.partition, ts,
                   part_offsets[part].offset, part_offsets[part].timestamp))
        elif offset_type == 'offset_info':
          offset_dict = json.loads(self._kafka.offset_info)
          for part in offset_dict:
            part_id = int(part)
            self._offset_dict[part_id] = offset_dict[part]
        else:
          assert False, 'invalid offset_type: %s' % offset_type
    self._task_offset_dict = {}
  def _preprocess(self, field_dict):
    output_dict = super(KafkaInput, self)._preprocess(field_dict)

    # append offset fields
    if Input.DATA_OFFSET in field_dict:
      output_dict[Input.DATA_OFFSET] = field_dict[Input.DATA_OFFSET]

    # for _get_features to include DATA_OFFSET
    if Input.DATA_OFFSET not in self._appended_fields:
      self._appended_fields.append(Input.DATA_OFFSET)

    return output_dict

  def _parse_csv(self, line, message_key, message_offset):
    record_defaults = [
        self.get_type_defaults(t, v)
        for t, v in zip(self._input_field_types, self._input_field_defaults)
    ]

    fields = tf.decode_csv(
        line,
        use_quote_delim=False,
        field_delim=self._data_config.separator,
        record_defaults=record_defaults,
        name='decode_csv')

    inputs = {self._input_fields[x]: fields[x] for x in self._effective_fids}
    for x in self._label_fids:
      inputs[self._input_fields[x]] = fields[x]

    # record current offset
    def _parse_offset(message_offset):
      for kv in message_offset:
        if six.PY3:
          kv = kv.decode('utf-8')
        k, v = kv.split(':')
        k = int(k)
        v = int(v)
        if k not in self._task_offset_dict or v > self._task_offset_dict[k]:
          self._task_offset_dict[k] = v
      return json.dumps(self._task_offset_dict)

    inputs[Input.DATA_OFFSET] = tf.py_func(_parse_offset, [message_offset],
                                           tf.string)
    return inputs
  def restore(self, checkpoint_path):
    if checkpoint_path is None:
      return

    offset_path = checkpoint_path + '.offset'
    if not gfile.Exists(offset_path):
      return

    logging.info('will restore kafka offset from %s' % offset_path)
    with gfile.GFile(offset_path, 'r') as fin:
      offset_dict = json.load(fin)
      self._offset_dict = {}
      for k in offset_dict:
        v = offset_dict[k]
        k = int(k)
        if k not in self._offset_dict or v > self._offset_dict[k]:
          self._offset_dict[k] = v
  def _get_topics(self):
    task_num = self._task_num
    task_index = self._task_index
    if self._data_config.chief_redundant and \
        self._mode == tf.estimator.ModeKeys.TRAIN:
      task_index = max(task_index - 1, 0)
      task_num = max(task_num - 1, 1)

    topics = []
    self._task_offset_dict = {}
    for part_id in range(self._num_partition):
      if (part_id % task_num) == task_index:
        offset = self._offset_dict.get(part_id, 0)
        topics.append('%s:%d:%d' % (self._kafka.topic, part_id, offset))
        self._task_offset_dict[part_id] = offset
    logging.info('assigned topic partitions: %s' % (','.join(topics)))
    assert len(topics) > 0, \
        'no partitions are assigned for this task(%d/%d)' % (
            self._task_index, self._task_num)
    return topics

  def _build(self, mode, params):
    num_parallel_calls = self._data_config.num_parallel_calls
    task_topics = self._get_topics()
    if mode == tf.estimator.ModeKeys.TRAIN:
      assert self._kafka is not None, 'kafka_train_input is not set.'
      train_kafka = self._kafka
      logging.info(
          'train kafka server: %s topic: %s task_num: %d task_index: %d topics: %s'
          % (train_kafka.server, train_kafka.topic, self._task_num,
             self._task_index, task_topics))
      dataset = KafkaDataset(
          task_topics,
          servers=train_kafka.server,
          group=train_kafka.group,
          eof=False,
          config_global=list(self._kafka.config_global),
          config_topic=list(self._kafka.config_topic),
          message_key=True,
          message_offset=True)

      if self._data_config.shuffle:
        dataset = dataset.shuffle(
            self._data_config.shuffle_buffer_size,
            seed=2020,
            reshuffle_each_iteration=True)
    else:
      eval_kafka = self._kafka
      assert self._kafka is not None, 'kafka_eval_input is not set.'
      logging.info(
          'eval kafka server: %s topic: %s task_num: %d task_index: %d topics: %s'
          % (eval_kafka.server, eval_kafka.topic, self._task_num,
             self._task_index, task_topics))
      dataset = KafkaDataset(
          task_topics,
          servers=self._kafka.server,
          group=eval_kafka.group,
          eof=False,
          config_global=list(self._kafka.config_global),
          config_topic=list(self._kafka.config_topic),
          message_key=True,
          message_offset=True)

    dataset = dataset.batch(self._data_config.batch_size)
    dataset = dataset.map(
        self._parse_csv, num_parallel_calls=num_parallel_calls)
    if self._data_config.ignore_error:
      dataset = dataset.apply(ignore_errors)
    dataset = dataset.prefetch(buffer_size=self._prefetch_size)
    dataset = dataset.map(
        map_func=self._preprocess, num_parallel_calls=num_parallel_calls)
    dataset = dataset.prefetch(buffer_size=self._prefetch_size)

    if mode != tf.estimator.ModeKeys.PREDICT:
      dataset = dataset.map(lambda x:
                            (self._get_features(x), self._get_labels(x)))
    else:
      dataset = dataset.map(lambda x: (self._get_features(x)))
    return dataset
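
The class tracks consumed offsets per partition through the DATA_OFFSET field and reads them back from a `<checkpoint_path>.offset` JSON file in restore(), while _get_topics() turns them into `topic:partition:offset` subscription strings. The following is a minimal standalone sketch of those two conventions, not part of EasyRec; the helper names `save_offsets` and `subscription_strings` are hypothetical, and only the file suffix, the JSON layout, and the subscription string format come from the code above.

import json


def save_offsets(checkpoint_path, task_offset_dict):
  # task_offset_dict maps partition id -> next offset, e.g. {0: 128, 3: 96},
  # as accumulated by _parse_offset from the 'partition:offset' strings
  # attached to each message. restore() converts the JSON string keys
  # back to ints.
  with open(checkpoint_path + '.offset', 'w') as fout:
    json.dump(task_offset_dict, fout)


def subscription_strings(topic, offset_dict, num_partition):
  # Mirrors _get_topics(): one 'topic:partition:offset' entry per partition,
  # with offset 0 for partitions that have no recorded offset yet.
  return [
      '%s:%d:%d' % (topic, part_id, offset_dict.get(part_id, 0))
      for part_id in range(num_partition)
  ]


print(subscription_strings('kafka_data_20220408', {0: 128, 3: 96}, 4))
# ['kafka_data_20220408:0:128', 'kafka_data_20220408:1:0',
#  'kafka_data_20220408:2:0', 'kafka_data_20220408:3:96']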