# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
"""Define filters for restore."""
from abc import ABCMeta
from abc import abstractmethod
from enum import Enum
class Logical(Enum):
    """Logical operator used when several filters are combined."""

    # All combined filters must agree to keep the variable.
    AND = 1
    # At least one combined filter must agree to keep the variable.
    OR = 2
class Filter:
    """Base interface for deciding whether a variable is kept on restore.

    Subclasses override `keep` to implement the actual decision.
    """

    # NOTE: `__metaclass__` is only honoured by Python 2. Under Python 3 it is
    # a plain class attribute, so the abstract method below is not enforced
    # at instantiation time there.
    __metaclass__ = ABCMeta

    def __init__(self):
        pass

    @abstractmethod
    def keep(self, var_name):
        """Keep the var or not.

        Args:
          var_name: input name of the var

        Returns:
          True if the var will be kept, else False
        """
        return True
class KeywordFilter(Filter):
    """Filter that decides by substring match against the variable name."""

    def __init__(self, pattern, exclusive=False):
        """Init KeywordFilter.

        Args:
          pattern: keyword to be matched
          exclusive: if False, var_name is kept only when it includes the
            pattern; if True, var_name is kept only when it does not
            include the pattern
        """
        self._pattern = pattern
        self._exclusive = exclusive

    def keep(self, var_name):
        """Keep the var or not, based on whether it contains the pattern.

        Args:
          var_name: input name of the var

        Returns:
          In inclusive mode, True when the pattern occurs in var_name;
          in exclusive mode, True when it does not.
        """
        matched = self._pattern in var_name
        if self._exclusive:
            return not matched
        return matched
class CombineFilter(Filter):
    """Filter that merges the decisions of several filters."""

    def __init__(self, filters, logical=Logical.AND):
        """Init CombineFilter.

        Args:
          filters: a set of filters to be combined
          logical: logical and/or combination of the filters
        """
        self._filters = filters
        self._logical = logical

    def keep(self, var_name):
        """Keep the var or not, by combining the decisions of all filters.

        Args:
          var_name: input name of the var

        Returns:
          For Logical.AND, True only if every filter keeps the var (True
          when there are no filters). For Logical.OR, True if any filter
          keeps the var (False when there are no filters). None when the
          logical operator is neither of the two.
        """
        # Lazy generator: all()/any() stop at the first deciding filter,
        # so later filters are not evaluated unnecessarily.
        decisions = (one_filter.keep(var_name) for one_filter in self._filters)
        if self._logical == Logical.AND:
            return all(decisions)
        if self._logical == Logical.OR:
            return any(decisions)
        # Unsupported operator: no decision is made, result is None (falsy).
        return None
class ScopeDrop:
    """For drop out scope prefix when restore variables from checkpoint."""

    def __init__(self, scope_name):
        """Init ScopeDrop.

        Args:
          scope_name: scope name to be removed from variable names; a
            trailing '/' is appended when it is missing. An empty string
            is left as is, which makes `update` a no-op.
        """
        self._scope_name = scope_name
        # Guard on non-emptiness before normalizing: indexing the last
        # character of an empty string would raise IndexError.
        if self._scope_name and not self._scope_name.endswith('/'):
            self._scope_name += '/'

    def update(self, var_name):
        """Remove the scope name from a variable name.

        Args:
          var_name: input name of the var

        Returns:
          var_name with every occurrence of the normalized scope name
          (including its trailing '/') removed.
        """
        return var_name.replace(self._scope_name, '')