timm
PyTorch
medical
Image Feature Extraction
Ege Oezsoy commited on
Commit
1c42c71
·
1 Parent(s): 74033b8

Adjustments

Browse files
Files changed (3) hide show
  1. endovit_demo.py +21 -8
  2. endovit_online.py +43 -0
  3. requirements.txt +2 -1
endovit_demo.py CHANGED
@@ -5,8 +5,9 @@ from pathlib import Path
5
  from timm.models.vision_transformer import VisionTransformer
6
  from functools import partial
7
  from torch import nn
 
 
8
 
9
- # requires: pytorch 2.0.1, timm 0.9.16
10
  def process_single_image(image_path, input_size=224, dataset_mean=[0.3464, 0.2280, 0.2228], dataset_std=[0.2520, 0.2128, 0.2093]):
11
  # Define the transformations
12
  transform = T.Compose([
@@ -22,18 +23,30 @@ def process_single_image(image_path, input_size=224, dataset_mean=[0.3464, 0.228
22
  processed_image = transform(image)
23
 
24
  return processed_image
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
 
26
 
27
- image_paths = sorted(Path('demo_images').glob('*.png'))
 
28
  images = torch.stack([process_single_image(image_path) for image_path in image_paths])
29
 
30
  device = "cuda"
31
  dtype = torch.float16
32
-
33
- model_weights = torch.load('endovit_seg.pth')['model']
34
-
35
- model = VisionTransformer(patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6)).to(device, dtype).eval()
36
- loading = model.load_state_dict(model_weights, strict=False)
37
- print(loading)
38
  output = model.forward_features(images.to(device, dtype))
39
  print(output.shape)
 
5
  from timm.models.vision_transformer import VisionTransformer
6
  from functools import partial
7
  from torch import nn
8
+ from huggingface_hub import snapshot_download
9
+
10
 
 
11
  def process_single_image(image_path, input_size=224, dataset_mean=[0.3464, 0.2280, 0.2228], dataset_std=[0.2520, 0.2128, 0.2093]):
12
  # Define the transformations
13
  transform = T.Compose([
 
23
  processed_image = transform(image)
24
 
25
  return processed_image
26
+ def load_model_from_huggingface(repo_id, model_filename):
27
+ # Download model files
28
+ model_path = snapshot_download(repo_id=repo_id, revision="main")
29
+ model_weights_path = Path(model_path) / model_filename
30
+
31
+ # Load model weights
32
+ model_weights = torch.load(model_weights_path)['model']
33
+
34
+ # Define the model (ensure this matches your model's architecture)
35
+ model = VisionTransformer(patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6)).eval()
36
+
37
+ # Load the weights into the model
38
+ loading = model.load_state_dict(model_weights, strict=False)
39
 
40
+ return model, loading
41
 
42
+
43
+ image_paths = sorted(Path('demo_images').glob('*.png')) # TODO replace with image pass
44
  images = torch.stack([process_single_image(image_path) for image_path in image_paths])
45
 
46
  device = "cuda"
47
  dtype = torch.float16
48
+ model, loading_info = load_model_from_huggingface("egeozsoy/EndoViT", "endovit.pth")
49
+ model = model.to(device, dtype)
50
+ print(loading_info)
 
 
 
51
  output = model.forward_features(images.to(device, dtype))
52
  print(output.shape)
endovit_online.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from pathlib import Path
3
+ from timm.models.vision_transformer import VisionTransformer
4
+ from functools import partial
5
+ from torch import nn
6
+ from huggingface_hub import snapshot_download
7
+
8
+ def load_model_from_huggingface(repo_id, model_filename):
9
+ # Download model files
10
+ model_path = snapshot_download(repo_id=repo_id, revision="main")
11
+ model_weights_path = Path(model_path) / model_filename
12
+
13
+ # Load model weights
14
+ model_weights = torch.load(model_weights_path)['model']
15
+
16
+ # Define the model (ensure this matches your model's architecture)
17
+ model = VisionTransformer(patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6)).eval()
18
+
19
+ # Load the weights into the model
20
+ loading = model.load_state_dict(model_weights, strict=False)
21
+
22
+ return model, loading
23
+ def process_single_image(image_path, input_size=224, dataset_mean=[0.3464, 0.2280, 0.2228], dataset_std=[0.2520, 0.2128, 0.2093]):
24
+ # Define the transformations
25
+ transform = T.Compose([
26
+ T.Resize((input_size, input_size)),
27
+ T.ToTensor(),
28
+ T.Normalize(mean=dataset_mean, std=dataset_std)
29
+ ])
30
+
31
+ # Open the image
32
+ image = Image.open(image_path).convert('RGB')
33
+
34
+ # Apply the transformations
35
+ processed_image = transform(image)
36
+
37
+ return processed_image
38
+
39
+ device = "cuda"
40
+ dtype = torch.float16
41
+ model, loading_info = load_model_from_huggingface("egeozsoy/EndoViT", "endovit.pth")
42
+ model = model.to(device, dtype)
43
+ print(loading_info)
requirements.txt CHANGED
@@ -1,2 +1,3 @@
1
  torch==2.0.1
2
- timm==0.9.16
 
 
1
  torch==2.0.1
2
+ timm==0.9.16
3
+ huggingface-hub==0.22.2