# Spaces: Running on Zero
import difflib | |
import torch | |
def get_layer(l_name, library=torch.nn):
    """Return the layer object handler named ``l_name`` from ``library``.

    The lookup is case-insensitive, e.g. ``get_layer("elu")`` returns
    ``torch.nn.ELU``.

    Args:
        l_name (string): Case-insensitive name of the layer in ``library``
            (e.g. ``'elu'``).
        library (module): Module (or any object) whose attributes are searched
            for ``l_name``. Defaults to ``torch.nn``.

    Returns:
        layer_handler (object): The attribute of ``library`` whose name equals
            ``l_name`` case-insensitively (e.g. ``torch.nn.ELU``).

    Raises:
        NotImplementedError: If no attribute of ``library`` matches ``l_name``
            (the message lists the closest lowercase names found by difflib),
            or if several attributes match it case-insensitively (the message
            lists all of them).
    """
    # Search the library that was actually requested, not always torch.nn,
    # so the matched name is guaranteed to exist on ``library`` for getattr.
    all_layers = dir(library)
    target = l_name.lower()
    match = [x for x in all_layers if x.lower() == target]
    if not match:
        # Compare lowercase against lowercase so the query's casing does not
        # skew difflib's similarity ratio.
        close_matches = difflib.get_close_matches(
            target, [x.lower() for x in all_layers]
        )
        raise NotImplementedError(
            "Layer with name {} not found in {}.\n Closest matches: {}".format(
                l_name, str(library), close_matches
            )
        )
    if len(match) > 1:
        # Names differing only by case are ambiguous for a case-insensitive
        # lookup; report every conflicting name.
        raise NotImplementedError(
            "Multiple matches for layer with name {} found in {}.\n "
            "All matches: {}".format(l_name, str(library), match)
        )
    layer_handler = getattr(library, match[0])
    return layer_handler