Update modeling_llava_qwen2.py
Browse files- modeling_llava_qwen2.py +4 -4
modeling_llava_qwen2.py
CHANGED
@@ -535,13 +535,13 @@ class SigLipVisionTower(nn.Module):
|
|
535 |
if type(images) is list:
|
536 |
image_features = []
|
537 |
for image in images:
|
538 |
-
image_forward_out = self.vision_tower(image.unsqueeze(0),
|
539 |
output_hidden_states=True)
|
540 |
image_feature = image_forward_out.hidden_states[-1].to(image.dtype)
|
541 |
assert image_features.shape[-2] == 729
|
542 |
image_features.append(image_feature)
|
543 |
else:
|
544 |
-
image_forward_outs = self.vision_tower(images,
|
545 |
output_hidden_states=True)
|
546 |
image_features = image_forward_outs.hidden_states[-1].to(images.dtype)
|
547 |
assert image_features.shape[-2] == 729
|
@@ -682,9 +682,9 @@ class LlavaMetaForCausalLM(ABC):
|
|
682 |
image_features = self.encode_images(concat_images)
|
683 |
split_sizes = [image.shape[0] for image in images]
|
684 |
image_features = torch.split(image_features, split_sizes, dim=0)
|
685 |
-
image_features = [x.flatten(0, 1) for x in image_features]
|
686 |
else:
|
687 |
-
image_features = self.encode_images(images)
|
688 |
|
689 |
# Let's just add dummy tensors if they do not exist,
|
690 |
# it is a headache to deal with None all the time.
|
|
|
535 |
if type(images) is list:
|
536 |
image_features = []
|
537 |
for image in images:
|
538 |
+
image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0),
|
539 |
output_hidden_states=True)
|
540 |
image_feature = image_forward_out.hidden_states[-1].to(image.dtype)
|
541 |
assert image_features.shape[-2] == 729
|
542 |
image_features.append(image_feature)
|
543 |
else:
|
544 |
+
image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype),
|
545 |
output_hidden_states=True)
|
546 |
image_features = image_forward_outs.hidden_states[-1].to(images.dtype)
|
547 |
assert image_features.shape[-2] == 729
|
|
|
682 |
image_features = self.encode_images(concat_images)
|
683 |
split_sizes = [image.shape[0] for image in images]
|
684 |
image_features = torch.split(image_features, split_sizes, dim=0)
|
685 |
+
image_features = [x.flatten(0, 1).to(self.device) for x in image_features]
|
686 |
else:
|
687 |
+
image_features = self.encode_images(images).to(self.device)
|
688 |
|
689 |
# Let's just add dummy tensors if they do not exist,
|
690 |
# it is a headache to deal with None all the time.
|