Spaces:
Runtime error
Runtime error
Set map_location in torch.load call
Browse files- shakespeare_demo.py +2 -1
shakespeare_demo.py
CHANGED
@@ -22,7 +22,8 @@ with open('config.yaml', 'r') as f:
|
|
22 |
#%%
|
23 |
with open('model_state_dict.pt') as f:
|
24 |
state_dict = t.load(
|
25 |
-
'model_state_dict.pt'
|
|
|
26 |
)
|
27 |
#%%
|
28 |
base_config = transformer_replication.TransformerConfig(
|
|
|
22 |
#%%
|
23 |
with open('model_state_dict.pt') as f:
|
24 |
state_dict = t.load(
|
25 |
+
'model_state_dict.pt',
|
26 |
+
map_location=device,
|
27 |
)
|
28 |
#%%
|
29 |
base_config = transformer_replication.TransformerConfig(
|