# Source code for easy_rec.python.utils.restore_filter

# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
"""Define filters for restore."""

from abc import ABCMeta
from abc import abstractmethod
from enum import Enum


class Logical(Enum):
  """Logical operator used to combine the verdicts of several filters.

  AND: a variable is kept only if every filter keeps it.
  OR: a variable is kept if at least one filter keeps it.
  """

  AND = 1
  OR = 2
class Filter:
  """Interface for deciding which variables are restored from a checkpoint.

  Subclasses override `keep` to accept or reject a variable by its name.

  NOTE: the `__metaclass__` attribute is only honoured by Python 2. Under
  Python 3 it is an ordinary class attribute, so `@abstractmethod` is not
  enforced there and `Filter()` itself can still be instantiated.
  """
  __metaclass__ = ABCMeta

  def __init__(self):
    """No state to initialise."""

  @abstractmethod
  def keep(self, var_name):
    """Decide whether a variable should be kept.

    Args:
      var_name: name of the variable to check.

    Returns:
      True if the variable should be kept, else False. This default
      implementation keeps every variable.
    """
    del var_name  # unused by the default implementation
    return True
class KeywordFilter(Filter):
  """Keep or drop variables depending on whether their name contains a keyword."""

  def __init__(self, pattern, exclusive=False):
    """Init KeywordFilter.

    Args:
      pattern: substring to look for in the variable name.
      exclusive: if False (default), only variables whose name contains
        `pattern` are kept; if True, variables whose name contains
        `pattern` are dropped and all other variables are kept.
    """
    self._pattern = pattern
    self._exclusive = exclusive

  def keep(self, var_name):
    """Return whether `var_name` passes this filter.

    Args:
      var_name: name of the variable to check.

    Returns:
      True if the variable should be kept, else False.
    """
    matched = self._pattern in var_name
    if self._exclusive:
      # Exclusive mode: a match means the variable is excluded.
      return not matched
    return matched
class CombineFilter(Filter):
  """Combine several filters into one with a logical AND or OR."""

  def __init__(self, filters, logical=Logical.AND):
    """Init CombineFilter.

    Args:
      filters: filters to be combined. The collection is iterated on every
        `keep` call, so it should be re-iterable (e.g. a list), not a
        one-shot generator.
      logical: Logical.AND keeps a variable only if every filter keeps it;
        Logical.OR keeps it if at least one filter keeps it.
    """
    self._filters = filters
    self._logical = logical

  def keep(self, var_name):
    """Return whether `var_name` passes the combined filter.

    With no filters, Logical.AND keeps everything and Logical.OR keeps
    nothing (the usual empty all/any convention).

    Args:
      var_name: name of the variable to check.

    Returns:
      True if the variable should be kept, else False.

    Raises:
      ValueError: if `logical` is neither Logical.AND nor Logical.OR.
    """
    # all()/any() short-circuit exactly like an explicit early-return loop.
    if self._logical == Logical.AND:
      return all(one_filter.keep(var_name) for one_filter in self._filters)
    if self._logical == Logical.OR:
      return any(one_filter.keep(var_name) for one_filter in self._filters)
    # Previously this fell through and returned None, which is falsy and
    # silently dropped every variable; fail loudly instead.
    raise ValueError('Unsupported logical combination: %r' % (self._logical,))
class ScopeDrop:
  """Drop a scope from variable names when restoring from a checkpoint."""

  def __init__(self, scope_name):
    """Init ScopeDrop.

    Args:
      scope_name: scope to remove, with or without a trailing '/'. An empty
        string disables removal.
    """
    self._scope_name = scope_name
    # Normalise to 'scope/' so only whole scope components are matched.
    # The emptiness check matters: indexing '' with [-1] would raise
    # IndexError, and an empty scope must leave names untouched.
    if self._scope_name and not self._scope_name.endswith('/'):
      self._scope_name += '/'

  def update(self, var_name):
    """Return `var_name` with the scope removed.

    Args:
      var_name: name of the variable.

    Returns:
      `var_name` with every occurrence of '<scope_name>/' removed; unchanged
      when the scope is empty.
    """
    # NOTE(review): str.replace removes every occurrence, not only a leading
    # prefix, so a name that merely contains '<scope_name>/' (e.g.
    # 'user_tower/w' with scope 'tower' becomes 'user_w') is rewritten too.
    # Kept to preserve existing behaviour; confirm whether prefix-only
    # stripping is intended.
    return var_name.replace(self._scope_name, '')