kaczmarj commited on
Commit
ec041df
·
verified ·
1 Parent(s): 5a0bb2f

upload safetensors and conversion script

Browse files
Files changed (2) hide show
  1. convert_pt.py +59 -0
  2. model.safetensors +3 -0
convert_pt.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from safetensors.torch import save_file
2
+ import torch
3
+ from torchvision.models import resnet34
4
+
5
+ model_path = "RESNET_34_cancer_350px_lr_1e-2_decay_5_jitter_val6slides_harder_tcga_none_0403_0204_0.9826153355179645_16.t7"
6
+
7
+ orig_model = torch.load(model_path, map_location="cpu")
8
+ state_dict = orig_model["model"].module.state_dict()
9
+ keys_missing = [
10
+ "bn1.num_batches_tracked",
11
+ "layer1.0.bn1.num_batches_tracked",
12
+ "layer1.0.bn2.num_batches_tracked",
13
+ "layer1.1.bn1.num_batches_tracked",
14
+ "layer1.1.bn2.num_batches_tracked",
15
+ "layer1.2.bn1.num_batches_tracked",
16
+ "layer1.2.bn2.num_batches_tracked",
17
+ "layer2.0.bn1.num_batches_tracked",
18
+ "layer2.0.bn2.num_batches_tracked",
19
+ "layer2.0.downsample.1.num_batches_tracked",
20
+ "layer2.1.bn1.num_batches_tracked",
21
+ "layer2.1.bn2.num_batches_tracked",
22
+ "layer2.2.bn1.num_batches_tracked",
23
+ "layer2.2.bn2.num_batches_tracked",
24
+ "layer2.3.bn1.num_batches_tracked",
25
+ "layer2.3.bn2.num_batches_tracked",
26
+ "layer3.0.bn1.num_batches_tracked",
27
+ "layer3.0.bn2.num_batches_tracked",
28
+ "layer3.0.downsample.1.num_batches_tracked",
29
+ "layer3.1.bn1.num_batches_tracked",
30
+ "layer3.1.bn2.num_batches_tracked",
31
+ "layer3.2.bn1.num_batches_tracked",
32
+ "layer3.2.bn2.num_batches_tracked",
33
+ "layer3.3.bn1.num_batches_tracked",
34
+ "layer3.3.bn2.num_batches_tracked",
35
+ "layer3.4.bn1.num_batches_tracked",
36
+ "layer3.4.bn2.num_batches_tracked",
37
+ "layer3.5.bn1.num_batches_tracked",
38
+ "layer3.5.bn2.num_batches_tracked",
39
+ "layer4.0.bn1.num_batches_tracked",
40
+ "layer4.0.bn2.num_batches_tracked",
41
+ "layer4.0.downsample.1.num_batches_tracked",
42
+ "layer4.1.bn1.num_batches_tracked",
43
+ "layer4.1.bn2.num_batches_tracked",
44
+ "layer4.2.bn1.num_batches_tracked",
45
+ "layer4.2.bn2.num_batches_tracked",
46
+ ]
47
+ assert not any(
48
+ key in state_dict.keys() for key in keys_missing
49
+ ), "key present that should be missing"
50
+ for key in keys_missing:
51
+ state_dict[key] = torch.as_tensor(0)
52
+ torch.save(state_dict, "pytorch_model.pt")
53
+ save_file(state_dict, "model.safetensors")
54
+
55
+ model = resnet34(weights=None)
56
+ model.fc = torch.nn.Linear(model.fc.in_features, out_features=5, bias=True)
57
+ model.load_state_dict(state_dict)
58
+ model_jit = torch.jit.script(model, example_inputs=[(torch.ones(1, 3, 224, 224),)])
59
+ torch.jit.save(model_jit, "torchscript_model.bin")
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c22f2594a277db430a58896ccc1fa751435c6054e09fc825e5a5534af70f7ae3
3
+ size 85236652