Fabrice-TIERCELIN commited on
Commit
b96ccaf
·
verified ·
1 Parent(s): e37a0f8

Do not use LLAVA_CLIP_PATH

Browse files
llava/model/multimodal_encoder/clip_encoder.py CHANGED
@@ -1,84 +1,83 @@
1
- import torch
2
- import torch.nn as nn
3
-
4
- from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig
5
- from CKPT_PTH import LLAVA_CLIP_PATH
6
-
7
-
8
- class CLIPVisionTower(nn.Module):
9
- def __init__(self, vision_tower, args, delay_load=False):
10
- super().__init__()
11
-
12
- self.is_loaded = False
13
-
14
- self.vision_tower_name = vision_tower
15
- print(f'Loading vision tower: {self.vision_tower_name}')
16
- self.select_layer = args.mm_vision_select_layer
17
- self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
18
-
19
- if not delay_load:
20
- self.load_model()
21
- else:
22
- # self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)
23
- self.cfg_only = CLIPVisionConfig.from_pretrained(
24
- self.vision_tower_name if LLAVA_CLIP_PATH is None else LLAVA_CLIP_PATH)
25
-
26
- def load_model(self):
27
- self.image_processor = CLIPImageProcessor.from_pretrained(
28
- self.vision_tower_name if LLAVA_CLIP_PATH is None else LLAVA_CLIP_PATH)
29
- self.vision_tower = CLIPVisionModel.from_pretrained(
30
- self.vision_tower_name if LLAVA_CLIP_PATH is None else LLAVA_CLIP_PATH)
31
- self.vision_tower.requires_grad_(False)
32
-
33
- self.is_loaded = True
34
-
35
- def feature_select(self, image_forward_outs):
36
- image_features = image_forward_outs.hidden_states[self.select_layer]
37
- if self.select_feature == 'patch':
38
- image_features = image_features[:, 1:]
39
- elif self.select_feature == 'cls_patch':
40
- image_features = image_features
41
- else:
42
- raise ValueError(f'Unexpected select feature: {self.select_feature}')
43
- return image_features
44
-
45
- @torch.no_grad()
46
- def forward(self, images):
47
- if type(images) is list:
48
- image_features = []
49
- for image in images:
50
- image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
51
- image_feature = self.feature_select(image_forward_out).to(image.dtype)
52
- image_features.append(image_feature)
53
- else:
54
- image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
55
- image_features = self.feature_select(image_forward_outs).to(images.dtype)
56
-
57
- return image_features
58
-
59
- @property
60
- def dummy_feature(self):
61
- return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
62
-
63
- @property
64
- def dtype(self):
65
- return self.vision_tower.dtype
66
-
67
- @property
68
- def device(self):
69
- return self.vision_tower.device
70
-
71
- @property
72
- def config(self):
73
- if self.is_loaded:
74
- return self.vision_tower.config
75
- else:
76
- return self.cfg_only
77
-
78
- @property
79
- def hidden_size(self):
80
- return self.config.hidden_size
81
-
82
- @property
83
- def num_patches(self):
84
- return (self.config.image_size // self.config.patch_size) ** 2
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig
5
+
6
+
7
+ class CLIPVisionTower(nn.Module):
8
+ def __init__(self, vision_tower, args, delay_load=False):
9
+ super().__init__()
10
+
11
+ self.is_loaded = False
12
+
13
+ self.vision_tower_name = vision_tower
14
+ print(f'Loading vision tower: {self.vision_tower_name}')
15
+ self.select_layer = args.mm_vision_select_layer
16
+ self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
17
+
18
+ if not delay_load:
19
+ self.load_model()
20
+ else:
21
+ # self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)
22
+ self.cfg_only = CLIPVisionConfig.from_pretrained(
23
+ self.vision_tower_name)
24
+
25
+ def load_model(self):
26
+ self.image_processor = CLIPImageProcessor.from_pretrained(
27
+ self.vision_tower_name)
28
+ self.vision_tower = CLIPVisionModel.from_pretrained(
29
+ self.vision_tower_name)
30
+ self.vision_tower.requires_grad_(False)
31
+
32
+ self.is_loaded = True
33
+
34
+ def feature_select(self, image_forward_outs):
35
+ image_features = image_forward_outs.hidden_states[self.select_layer]
36
+ if self.select_feature == 'patch':
37
+ image_features = image_features[:, 1:]
38
+ elif self.select_feature == 'cls_patch':
39
+ image_features = image_features
40
+ else:
41
+ raise ValueError(f'Unexpected select feature: {self.select_feature}')
42
+ return image_features
43
+
44
+ @torch.no_grad()
45
+ def forward(self, images):
46
+ if type(images) is list:
47
+ image_features = []
48
+ for image in images:
49
+ image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
50
+ image_feature = self.feature_select(image_forward_out).to(image.dtype)
51
+ image_features.append(image_feature)
52
+ else:
53
+ image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
54
+ image_features = self.feature_select(image_forward_outs).to(images.dtype)
55
+
56
+ return image_features
57
+
58
+ @property
59
+ def dummy_feature(self):
60
+ return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
61
+
62
+ @property
63
+ def dtype(self):
64
+ return self.vision_tower.dtype
65
+
66
+ @property
67
+ def device(self):
68
+ return self.vision_tower.device
69
+
70
+ @property
71
+ def config(self):
72
+ if self.is_loaded:
73
+ return self.vision_tower.config
74
+ else:
75
+ return self.cfg_only
76
+
77
+ @property
78
+ def hidden_size(self):
79
+ return self.config.hidden_size
80
+
81
+ @property
82
+ def num_patches(self):
83
+ return (self.config.image_size // self.config.patch_size) ** 2