skar0 commited on
Commit
eafeaef
·
1 Parent(s): b8753fd

Set map_location in torch.load call

Browse files
Files changed (1) hide show
  1. shakespeare_demo.py +2 -1
shakespeare_demo.py CHANGED
@@ -22,7 +22,8 @@ with open('config.yaml', 'r') as f:
22
  #%%
23
  with open('model_state_dict.pt') as f:
24
  state_dict = t.load(
25
- 'model_state_dict.pt'
 
26
  )
27
  #%%
28
  base_config = transformer_replication.TransformerConfig(
 
22
  #%%
23
  with open('model_state_dict.pt') as f:
24
  state_dict = t.load(
25
+ 'model_state_dict.pt',
26
+ map_location=device,
27
  )
28
  #%%
29
  base_config = transformer_replication.TransformerConfig(