Spaces:
Running
on
Zero
Running
on
Zero
File size: 1,407 Bytes
8e8cd3e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 |
import difflib
import torch
def get_layer(l_name, library=torch.nn):
    """Return the attribute of ``library`` whose name matches ``l_name`` case-insensitively.

    E.g. if ``l_name == "elu"``, returns ``torch.nn.ELU``.

    Args:
        l_name (str): Case-insensitive name of the layer to look up (e.g. ``"elu"``).
        library (module): Module (or any object) searched for an attribute
            named ``l_name``. Defaults to ``torch.nn``.

    Returns:
        object: The attribute found on ``library`` (e.g. ``torch.nn.ELU``).

    Raises:
        NotImplementedError: If no attribute of ``library`` matches ``l_name``
            (the message lists the closest names), or if several attributes
            match because their names differ only in case (the message lists
            all of them).
    """
    # Search the requested library, not a hard-coded torch.nn: the result is
    # fetched from ``library`` below, so the candidates must come from it too.
    all_names = dir(library)
    target = l_name.lower()
    match = [name for name in all_names if name.lower() == target]

    if not match:
        # difflib is case-sensitive, so compare lowercased forms on both sides,
        # then map suggestions back to the real attribute names.
        lowered = {name.lower(): name for name in all_names}
        close_matches = [
            lowered[m] for m in difflib.get_close_matches(target, list(lowered))
        ]
        raise NotImplementedError(
            "Layer with name {} not found in {}.\n Closest matches: {}".format(
                l_name, str(library), close_matches
            )
        )
    if len(match) > 1:
        # Several attributes differ only by case; report the actual ambiguous names.
        raise NotImplementedError(
            "Multiple matches for layer with name {} found in {}.\n "
            "All matches: {}".format(l_name, str(library), match)
        )
    return getattr(library, match[0])
|