Spaces:
Sleeping
Sleeping
change map location to CPU
Browse files- inference_sam.py +12 -1
inference_sam.py
CHANGED
@@ -33,11 +33,22 @@ if not os.path.exists('model'):
|
|
33 |
print("warning! A read token in env variables is needed for authentication.")
|
34 |
snapshot_download(repo_id=REPO_ID, token=token,repo_type='model',local_dir='model')
|
35 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
model_path = os.path.join('model', 'sam_02-06_dice_mse_0.pth')
|
37 |
-
sam = sam_model_registry["default"](model_path
|
38 |
sam.to(device) #sam.cuda()
|
39 |
predictor = SamPredictor(sam)
|
40 |
|
|
|
|
|
41 |
|
42 |
from torch.nn import functional as F
|
43 |
|
|
|
33 |
print("warning! A read token in env variables is needed for authentication.")
|
34 |
snapshot_download(repo_id=REPO_ID, token=token,repo_type='model',local_dir='model')
|
35 |
|
36 |
+
|
37 |
+
original_torch_load = torch.load
|
38 |
+
|
39 |
+
def patched_torch_load(*args, **kwargs):
|
40 |
+
kwargs['map_location'] = device
|
41 |
+
return original_torch_load(*args, **kwargs)
|
42 |
+
|
43 |
+
torch.load = patched_torch_load
|
44 |
+
|
45 |
model_path = os.path.join('model', 'sam_02-06_dice_mse_0.pth')
|
46 |
+
sam = sam_model_registry["default"](model_path)
|
47 |
sam.to(device) #sam.cuda()
|
48 |
predictor = SamPredictor(sam)
|
49 |
|
50 |
+
torch.load = original_torch_load
|
51 |
+
|
52 |
|
53 |
from torch.nn import functional as F
|
54 |
|