# Source code for easy_rec.python.utils.odps_util
# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
"""Common functions used for odps input."""
from tensorflow.python.framework import dtypes
from easy_rec.python.protos.dataset_pb2 import DatasetConfig
def is_type_compatiable(odps_type, input_type):
    """Tell whether an odps column type can feed a given input field type.

    The odps type name is first translated to a DatasetConfig field type.
    The two types are compatible when they are equal, when both are
    floating point (FLOAT / DOUBLE), or when both are integers
    (INT32 / INT64).

    Args:
      odps_type: odps type name, one of 'bigint', 'string', 'double'.
      input_type: a DatasetConfig field type value.

    Returns:
      True if the types are compatible, False otherwise.

    Raises:
      KeyError: if odps_type is not one of the three supported names.
    """
    odps_to_field_type = {
        'bigint': DatasetConfig.INT64,
        'string': DatasetConfig.STRING,
        'double': DatasetConfig.DOUBLE,
    }
    mapped_type = odps_to_field_type[odps_type]
    if mapped_type == input_type:
        return True

    # Types of different width inside the same numeric family are
    # interchangeable.
    type_families = (
        (DatasetConfig.FLOAT, DatasetConfig.DOUBLE),
        (DatasetConfig.INT32, DatasetConfig.INT64),
    )
    for family in type_families:
        if mapped_type in family and input_type in family:
            return True
    return False
def odps_type_to_input_type(odps_type):
    """Translate an odps type name into the matching DatasetConfig field type.

    Args:
      odps_type: odps type name, one of 'bigint', 'string', 'double'.

    Returns:
      DatasetConfig.INT64, DatasetConfig.STRING or DatasetConfig.DOUBLE.

    Raises:
      AssertionError: if odps_type is not one of the supported names.
    """
    supported_types = dict(
        bigint=DatasetConfig.INT64,
        string=DatasetConfig.STRING,
        double=DatasetConfig.DOUBLE,
    )
    assert odps_type in supported_types, 'only support [bigint, string, double]'
    return supported_types[odps_type]
def check_input_field_and_types(data_config):
    """Validate data_config.input_fields against the selected table columns.

    Every input field name must appear in data_config.selected_cols, and,
    when data_config.selected_col_types is set, the odps type listed for
    that column must be compatible with the field's input_type.
    Nothing is checked when selected_cols is empty.

    Args:
      data_config: instance of DatasetConfig; selected_cols and
        selected_col_types are comma separated strings.

    Raises:
      AssertionError: if an input field is not among the selected columns,
        or its odps type is incompatible with its input_type.
    """
    fields = [(field.input_name, field.input_type)
              for field in data_config.input_fields]
    cols_spec = data_config.selected_cols
    types_spec = data_config.selected_col_types

    if not cols_spec:
        return

    table_cols = cols_spec.split(',')
    for name, _ in fields:
        assert name in table_cols, 'column %s is not in table' % name

    if not types_spec:
        return

    # Columns and types are paired by position.
    col_to_odps_type = dict(zip(table_cols, types_spec.split(',')))
    for name, field_type in fields:
        odps_type = col_to_odps_type[name]
        assert is_type_compatiable(odps_type, field_type), \
            'feature[%s] type error: odps %s is not compatible with input_type %s' % (
                name, odps_type, DatasetConfig.FieldType.Name(field_type))
def odps_type_2_tf_type(odps_type):
    """Return the tensorflow dtype used for an odps column type.

    Args:
      odps_type: odps type name.

    Returns:
      dtypes.int64 for 'bigint', dtypes.float32 for 'double' and 'float',
      and dtypes.string for 'string' or any other type name.
    """
    if odps_type == 'bigint':
        return dtypes.int64
    if odps_type in ('double', 'float'):
        return dtypes.float32
    # 'string' and every unrecognized type name are read as strings.
    return dtypes.string