Ehsa commited on
Commit
f071537
·
1 Parent(s): 06f1701

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -20,10 +20,10 @@ import models_vit
20
 
21
  def prepare_model(chkpt_dir, arch='vit_large_patch14'):
22
  # build model
23
- model = getattr(models_vit, arch)(global_pool=False)
24
  # load model
25
  checkpoint = torch.load(chkpt_dir, map_location='cpu')
26
- msg = model.load_state_dict(checkpoint['model'], strict=False)
27
  print(msg)
28
  return model
29
 
 
20
 
21
  def prepare_model(chkpt_dir, arch='vit_large_patch14'):
22
  # build model
23
+ model = getattr(models_vit, arch)(global_pool=True)
24
  # load model
25
  checkpoint = torch.load(chkpt_dir, map_location='cpu')
26
+ msg = model.load_state_dict(checkpoint['model'], strict=True)
27
  print(msg)
28
  return model
29