# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
"""Load_class.py tools for loading classes."""
import inspect
import logging
import os
import pkgutil
import pydoc
from abc import ABCMeta
import six
import tensorflow as tf
import easy_rec
from easy_rec.python.utils import compat
def python_file_to_module(python_file):
  """Convert a slash-separated python file path into a dotted module name.

  Leading/trailing '/' characters are dropped and a trailing '.py' suffix is
  removed, e.g. 'easy_rec/python/main.py' -> 'easy_rec.python.main'.

  Args:
    python_file: relative path of a python file (or module directory).

  Returns:
    the dotted module name as a string.
  """
  suffix = '.py'
  dotted = '.'.join(python_file.strip('/').split('/'))
  if dotted.endswith(suffix):
    return dotted[:-len(suffix)]
  return dotted
def load_by_path(path):
  """Load functions or modules or classes by their dotted path.

  Args:
    path: dotted path to a module, function or class, such as 'tf.nn.relu';
      a leading 'tf' component is expanded to 'tensorflow'. Surrounding
      whitespace is ignored.

  Return:
    the located module/function/class, or None when path is None or blank,
    when pydoc cannot locate it, or when importing it raises
    pydoc.ErrorDuringImport (which is logged).
  """
  # Check None before strip(): calling strip() on None raised AttributeError.
  if path is None:
    return None
  path = path.strip()
  if path == '':
    return None
  components = path.split('.')
  if components[0] == 'tf':
    components[0] = 'tensorflow'
  path = '.'.join(components)
  try:
    return pydoc.locate(path)
  except pydoc.ErrorDuringImport:
    logging.error('load %s failed', path)
    return None
def _get_methods(aClass):
  """Collect the methods of a class that take part in interface checks.

  Only public attributes (names not starting with '_') plus '__init__' are
  considered; they are visited in case-insensitive name order.

  Args:
    aClass: the class to inspect.

  Returns:
    dict mapping attribute name to the function (Python 3) or unbound method
    (Python 2) found on the class.
  """
  # Looking an instance method up on a class yields a plain function in
  # Python 3 but an unbound method in Python 2, hence the two predicates.
  is_method = inspect.isfunction if compat.in_python3() else inspect.ismethod
  methods = {}
  for attr_name in sorted(dir(aClass), key=str.lower):
    if attr_name != '__init__' and attr_name[0] == '_':
      continue
    attr = getattr(aClass, attr_name)
    if is_method(attr):
      methods[attr_name] = attr
  return methods
def _get_method_declare(aMethod):
  """Return a comparable declaration string for a method.

  Args:
    aMethod: function or method object to describe.

  Returns:
    On Python 3 the signature text, e.g. '(self, x, y=1)' (without the
    method name); on Python 2 the method name followed by its formatted
    argspec. When the signature cannot be introspected, '<name>(cls, ...)'.
  """
  # Bound before the try so the fallback below can always use it.
  name = aMethod.__name__
  try:
    if compat.in_python3():
      return str(inspect.signature(aMethod))
    spec = inspect.getargspec(aMethod)
    args = inspect.formatargspec(spec.args, spec.varargs, spec.keywords,
                                 spec.defaults)
    return '%s%s' % (name, args)
  except (TypeError, ValueError):
    # inspect.signature raises ValueError when no signature can be provided
    # and TypeError for unsupported objects; fall back to a generic form.
    return '%s(cls, ...)' % name
def check_class(cls, impl_cls, function_names=None):
  """Check that an implemented class is valid according to a template class.

  Every method collected from cls must exist on impl_cls with the same
  declaration string; otherwise an exception is raised.

  Args:
    cls: class which declares functions that need users to implement
    impl_cls: user implemented class
    function_names: if not None, only these functions and their signatures
      are checked

  Raises:
    Exception: some methods are missing or their signatures differ; the
      message lists each offending method.
  """
  missing = {}
  ours = _get_methods(cls)
  theirs = _get_methods(impl_cls)
  for name, method in six.iteritems(ours):
    if function_names is not None and name not in function_names:
      continue
    if name not in theirs:
      missing[name + '()'] = 'not implemented'
      continue
    if _get_method_declare(method) != _get_method_declare(theirs[name]):
      missing[name + '()'] = 'method signature differs'
  if missing:
    # impl_cls is a class, so impl_cls.__class__ is its metaclass (e.g. type);
    # report the implementation's own name instead. Fall back to the type
    # name in case an instance was passed.
    impl_name = getattr(impl_cls, '__name__', type(impl_cls).__name__)
    raise Exception('incompatible Implementation-implementation %s: %s' %
                    (impl_name, missing))
def _load_module_from_finder(finder, mod_name, package_path):
  """Load top-level module `mod_name` through `finder` into sys.modules.

  Args:
    finder: the path-entry finder taken from a pkgutil.iter_modules entry.
    mod_name: module name; it is registered under this plain name.
    package_path: directory of the module, only used in error messages.

  Raises:
    Exception: the finder cannot locate a loadable module.
  """
  find_spec = getattr(finder, 'find_spec', None)
  if find_spec is not None:
    # Python 3: find_module()/load_module() are deprecated, and
    # FileFinder.find_module was removed in Python 3.12.
    import importlib.util
    import sys
    spec = find_spec(mod_name)
    if spec is None or spec.loader is None:
      raise Exception('import module %s failed' %
                      os.path.join(package_path, mod_name))
    module = importlib.util.module_from_spec(spec)
    sys.modules[mod_name] = module
    try:
      spec.loader.exec_module(module)
    except BaseException:
      # mirror load_module(): do not leave a half-initialized module behind
      sys.modules.pop(mod_name, None)
      raise
  else:
    # Python 2 importers only provide find_module()/load_module()
    loader = finder.find_module(mod_name)
    if loader is None:
      raise Exception('import module %s failed' %
                      os.path.join(package_path, mod_name))
    loader.load_module(mod_name)


def import_pkg(pkg_info, prefix_to_remove=None):
  """Import one module discovered by pkgutil.iter_modules.

  Modules whose name ends with '_test' are skipped.

  Args:
    pkg_info: pkgutil.ModuleInfo object (module_finder, name, ispkg)
    prefix_to_remove: the package path prefix to be removed; when the
      finder's path starts with it, the remainder is used as a relative
      package path for a dotted import

  Raises:
    Exception: the module lives under an absolute path and cannot be found
      by its finder.
    ValueError: importing the module by its dotted path fails.
  """
  finder = pkg_info[0]
  mod_name = pkg_info[1]
  # skip those test files in easyrec, before any lookup is done
  if mod_name.endswith('_test'):
    return
  package_path = finder.path
  # strip only a leading prefix; str.replace would also remove occurrences
  # in the middle of the path
  if prefix_to_remove and package_path.startswith(prefix_to_remove):
    package_path = package_path[len(prefix_to_remove):]
  if package_path.startswith('/'):
    # absolute path file, load it directly through its finder
    _load_module_from_finder(finder, mod_name, package_path)
  else:
    # use similar import methods as the import keyword
    module_path = os.path.join(package_path, mod_name).replace('/', '.')
    try:
      __import__(module_path)
    except Exception as e:
      import traceback
      logging.error(traceback.format_exc())
      raise ValueError('import module %s failed: %s' % (module_path, str(e)))
def auto_import(user_path=None):
  """Auto import python files so that register_xxx decorators take effect.

  The predefined easy_rec model and input directories are always scanned
  (non-recursively). If user_path is given, its directory is scanned too,
  together with every sub directory below it.

  Args:
    user_path: directory or file holding user-defined python code; for a
      file, the directory containing it is used. By default only the
      predefined directories are searched.
  """
  # each entry is (directory, import_recursively)
  search_dirs = [
      ('easy_rec/python/model', False),
      ('easy_rec/python/input', False),
  ]
  parent_dir = easy_rec.parent_dir
  prefix_to_remove = None
  # easy_rec installed in site-packages: make the predefined directories
  # absolute, and strip the parent directory again on import so that the
  # imported module names start with easy_rec
  if parent_dir != '':
    search_dirs = [(os.path.join(parent_dir, rel_dir), recursive)
                   for rel_dir, recursive in search_dirs]
    prefix_to_remove = parent_dir + '/'
  if user_path is not None:
    user_dir = (user_path if tf.gfile.IsDirectory(user_path) else
                os.path.split(user_path)[0])
    search_dirs.append((user_dir, True))
  for dir_path, recursive in search_dirs:
    for module_info in pkgutil.iter_modules([dir_path]):
      import_pkg(module_info, prefix_to_remove)
    if not recursive:
      continue
    for walk_root, sub_dirs, _ in os.walk(dir_path):
      for sub_dir in sub_dirs:
        sub_path = os.path.join(walk_root, sub_dir)
        for module_info in pkgutil.iter_modules([sub_path]):
          import_pkg(module_info, prefix_to_remove)
def register_class(class_map, class_name, cls):
  """Register cls in class_map under class_name.

  Registering the same class under the same name again is allowed; binding
  an already registered name to a different class is an error.

  Args:
    class_map: dict mapping class names to classes, updated in place.
    class_name: name to register the class under.
    cls: the class to register.

  Raises:
    AssertionError: class_name is already registered to a different class.
  """
  # Explicit check instead of `assert`, so the conflict is still detected
  # when Python runs with -O (which strips assert statements).
  if class_name in class_map and class_map[class_name] != cls:
    raise AssertionError(
        'confilict class %s , %s is already register to be %s' %
        (cls, class_name, str(class_map[class_name])))
  logging.debug('register class %s' % class_name)
  class_map[class_name] = cls