Update app.py
Browse files
app.py
CHANGED
@@ -20,10 +20,10 @@ import models_vit
|
|
20 |
|
21 |
def prepare_model(chkpt_dir, arch='vit_large_patch14'):
|
22 |
# build model
|
23 |
-
model = getattr(models_vit, arch)(global_pool=
|
24 |
# load model
|
25 |
checkpoint = torch.load(chkpt_dir, map_location='cpu')
|
26 |
-
msg = model.load_state_dict(checkpoint['model'], strict=
|
27 |
print(msg)
|
28 |
return model
|
29 |
|
|
|
20 |
|
21 |
def prepare_model(chkpt_dir, arch='vit_large_patch14'):
|
22 |
# build model
|
23 |
+
model = getattr(models_vit, arch)(global_pool=True)
|
24 |
# load model
|
25 |
checkpoint = torch.load(chkpt_dir, map_location='cpu')
|
26 |
+
msg = model.load_state_dict(checkpoint['model'], strict=True)
|
27 |
print(msg)
|
28 |
return model
|
29 |
|