andy-wyx commited on
Commit
505fe72
·
1 Parent(s): f6436fb

change map location to CPU

Browse files
Files changed (1) hide show
  1. inference_sam.py +12 -1
inference_sam.py CHANGED
@@ -33,11 +33,22 @@ if not os.path.exists('model'):
33
  print("warning! A read token in env variables is needed for authentication.")
34
  snapshot_download(repo_id=REPO_ID, token=token,repo_type='model',local_dir='model')
35
 
 
 
 
 
 
 
 
 
 
36
  model_path = os.path.join('model', 'sam_02-06_dice_mse_0.pth')
37
- sam = sam_model_registry["default"](model_path,map_location=device)
38
  sam.to(device) #sam.cuda()
39
  predictor = SamPredictor(sam)
40
 
 
 
41
 
42
  from torch.nn import functional as F
43
 
 
33
  print("warning! A read token in env variables is needed for authentication.")
34
  snapshot_download(repo_id=REPO_ID, token=token,repo_type='model',local_dir='model')
35
 
36
+
37
+ original_torch_load = torch.load
38
+
39
+ def patched_torch_load(*args, **kwargs):
40
+ kwargs['map_location'] = device
41
+ return original_torch_load(*args, **kwargs)
42
+
43
+ torch.load = patched_torch_load
44
+
45
  model_path = os.path.join('model', 'sam_02-06_dice_mse_0.pth')
46
+ sam = sam_model_registry["default"](model_path)
47
  sam.to(device) #sam.cuda()
48
  predictor = SamPredictor(sam)
49
 
50
+ torch.load = original_torch_load
51
+
52
 
53
  from torch.nn import functional as F
54