English
qtnx commited on
Commit
029d226
·
verified ·
1 Parent(s): 8b7c4a1
Files changed (6) hide show
  1. README.md +43 -0
  2. __main__.py +204 -0
  3. assets/demo-1.jpg +0 -0
  4. assets/demo-2.jpg +0 -0
  5. assets/demo-3.jpg +0 -0
  6. mm_projector.bin +3 -0
README.md ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ language:
3
+ - en
4
+ ---
5
+
6
+ # llama3-vision-alpha
7
+
8
+ projection module trained to add vision capabilties to Llama 3 using SigLIP. built by [@yeswondwerr](https://x.com/yeswondwerr) and [@qtnx_](https://x.com/qtnx_)
9
+
10
+ **usage**
11
+
12
+ ```
13
+ pip install torch transformers bitsandbytes accelerate
14
+ ```
15
+
16
+ ```
17
+ python __main__.py -i image
18
+ ```
19
+
20
+ **examples**
21
+
22
+ | Image | Examples |
23
+ | --- | --- |
24
+ | ![](assets/demo-1.jpg) | **What is the title of this book? answer briefly**<br>The title of the book is "The Little Book of Deep Learning".<br><br>**Where is the person standing? answer briefly**<br> The person is standing on the balcony. |
25
+ | ![](assets/demo-2.jpg) | **What type of food is the girl holding? answer briefly**<br>A hamburger!<br><br>**What color is the woman's hair? answer briefly**<br>It's white! |
26
+
27
+ ```
28
+ .x+=:.
29
+ z` ^% .uef^"
30
+ .u . . <k .u . :d88E
31
+ .u@u .d88B :@8c .u .@8Ned8" .u u .d88B :@8c . `888E
32
+ .zWF8888bx ="8888f8888r ud8888. .@^%8888" ud8888. us888u. ="8888f8888r .udR88N 888E .z8k
33
+ .888 9888 4888>'88" :888'8888. x88: `)8b. :888'8888. .@88 "8888" 4888>'88" <888'888k 888E~?888L
34
+ I888 9888 4888> ' d888 '88%" 8888N=*8888 d888 '88%" 9888 9888 4888> ' 9888 'Y" 888E 888E
35
+ I888 9888 4888> 8888.+" %8" R88 8888.+" 9888 9888 4888> 9888 888E 888E
36
+ I888 9888 .d888L .+ 8888L @8Wou 9% 8888L 9888 9888 .d888L .+ 9888 888E 888E
37
+ `888Nx?888 ^"8888*" '8888c. .+ .888888P` '8888c. .+ 9888 9888 ^"8888*" ?8888u../ 888E 888E
38
+ "88" '888 "Y" "88888% ` ^"F "88888% "888*""888" "Y" "8888P' m888N= 888>
39
+ 88E "YP' "YP' ^Y" ^Y' "P' `Y" 888
40
+ 98> J88"
41
+ '8 @%
42
+ ` :"
43
+ ```
__main__.py ADDED
@@ -0,0 +1,204 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn as nn
5
+ from PIL import Image
6
+ from transformers import (
7
+ AutoModel,
8
+ AutoProcessor,
9
+ AutoTokenizer,
10
+ BitsAndBytesConfig,
11
+ LlamaForCausalLM, SiglipImageProcessor, SiglipVisionModel
12
+
13
+ )
14
+ from transformers import TextStreamer
15
+
16
+
17
+
18
+ def tokenizer_image_token(prompt, tokenizer, image_token_index=-200, return_tensors=None):
19
+ prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]
20
+
21
+ def insert_separator(X, sep):
22
+ return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1]
23
+
24
+ input_ids = []
25
+ offset = 0
26
+ if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
27
+ offset = 1
28
+ input_ids.append(prompt_chunks[0][0])
29
+
30
+ for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
31
+ input_ids.extend(x[offset:])
32
+
33
+ return torch.tensor(input_ids, dtype=torch.long)
34
+
35
+
36
+ def process_tensors(input_ids, image_features, embedding_layer):
37
+ # Find the index of -200 in input_ids
38
+ split_index = (input_ids == -200).nonzero(as_tuple=True)[1][0]
39
+
40
+ # Split the input_ids at the index found, excluding -200
41
+ input_ids_1 = input_ids[:, :split_index]
42
+ input_ids_2 = input_ids[:, split_index + 1:]
43
+
44
+ # Convert input_ids to embeddings
45
+ embeddings_1 = embedding_layer(input_ids_1)
46
+ embeddings_2 = embedding_layer(input_ids_2)
47
+
48
+ device = image_features.device
49
+ token_embeddings_part1 = embeddings_1.to(device)
50
+ token_embeddings_part2 = embeddings_2.to(device)
51
+
52
+ # Concatenate the token embeddings and image features
53
+ concatenated_embeddings = torch.cat(
54
+ [token_embeddings_part1, image_features, token_embeddings_part2], dim=1
55
+ )
56
+
57
+ # Create the corrected attention mask
58
+ attention_mask = torch.ones(concatenated_embeddings.shape[:2], dtype=torch.long, device=device)
59
+ return concatenated_embeddings, attention_mask
60
+
61
+
62
+ def initialize_models():
63
+ bnb_config = BitsAndBytesConfig(
64
+ load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16
65
+ )
66
+
67
+ tokenizer = AutoTokenizer.from_pretrained("unsloth/llama-3-8b-Instruct", use_fast=True)
68
+ model = LlamaForCausalLM.from_pretrained(
69
+ "unsloth/llama-3-8b-Instruct",
70
+ torch_dtype=torch.float16,
71
+ device_map="auto",
72
+ quantization_config=bnb_config,
73
+ )
74
+
75
+ for param in model.base_model.parameters():
76
+ param.requires_grad = False
77
+
78
+ model_name = "google/siglip-so400m-patch14-384"
79
+ vision_model = SiglipVisionModel.from_pretrained(model_name, torch_dtype=torch.float16)
80
+ processor = SiglipImageProcessor.from_pretrained(model_name)
81
+
82
+ vision_model = vision_model.to("cuda")
83
+
84
+ return tokenizer, model, vision_model, processor
85
+
86
+
87
+ class ProjectionModule(nn.Module):
88
+ def __init__(self, mm_hidden_size, hidden_size):
89
+ super(ProjectionModule, self).__init__()
90
+
91
+ # Directly set up the sequential model
92
+ self.model = nn.Sequential(
93
+ nn.Linear(mm_hidden_size, hidden_size),
94
+ nn.GELU(),
95
+ nn.Linear(hidden_size, hidden_size)
96
+ )
97
+
98
+ def forward(self, x):
99
+ return self.model(x)
100
+
101
+ def load_projection_module(mm_hidden_size=1152, hidden_size=4096, device='cuda'):
102
+ projection_module = ProjectionModule(mm_hidden_size, hidden_size)
103
+ checkpoint = torch.load("./checkpoints/llama-3/checkpoint-2400/mm_projector.bin")
104
+ checkpoint = {k.replace("mm_projector.", ""): v for k, v in checkpoint.items()}
105
+ projection_module.load_state_dict(checkpoint)
106
+ projection_module = projection_module.to(device).half()
107
+ return projection_module
108
+
109
+
110
+ def answer_question(
111
+ image_path, tokenizer, model, vision_model, processor, projection_module
112
+ ):
113
+ image = Image.open(image_path).convert('RGB')
114
+
115
+ tokenizer.bos_token_id = None
116
+ tokenizer.eos_token = "<|eot_id|>"
117
+
118
+ try:
119
+ inp = input('user: ')
120
+ except EOFError:
121
+ inp = ""
122
+ if not inp:
123
+ print("exit...")
124
+
125
+ question = '<image>' + inp
126
+
127
+ prompt = f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{question}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
128
+
129
+ input_ids = tokenizer_image_token(prompt, tokenizer, -200, return_tensors='pt').unsqueeze(0).to(
130
+ model.device)
131
+
132
+ streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
133
+
134
+ with torch.inference_mode():
135
+ image_inputs = processor(images=[image], return_tensors="pt", do_resize=True,
136
+ size={"height": 384, "width": 384}).to("cuda")
137
+
138
+ image_inputs = image_inputs['pixel_values'].squeeze(0)
139
+
140
+ image_forward_outs = vision_model(image_inputs.to(device='cuda', dtype=torch.float16).unsqueeze(0),
141
+ output_hidden_states=True)
142
+
143
+ image_features = image_forward_outs.hidden_states[-2]
144
+
145
+ image_features2 = image_features[:, 1:]
146
+
147
+ projected_embeddings = projection_module(image_features2).to("cuda")
148
+
149
+ embedding_layer = model.get_input_embeddings()
150
+ #text_embeddings = embedding_layer(input_ids)
151
+
152
+ new_embeds, attn_mask = process_tensors(input_ids, projected_embeddings, embedding_layer)
153
+ device = model.device
154
+ attn_mask = attn_mask.to(device)
155
+ new_embeds = new_embeds.to(device)
156
+
157
+ model_kwargs = {
158
+ 'do_sample': True,
159
+ 'temperature': 0.2,
160
+ 'max_new_tokens': 2000,
161
+ 'use_cache': True,
162
+ 'streamer': streamer
163
+ }
164
+
165
+ while True:
166
+ generated_ids = model.generate(
167
+ inputs_embeds=new_embeds,
168
+ attention_mask=attn_mask,
169
+ **model_kwargs
170
+
171
+ )[0]
172
+
173
+ generated_text = tokenizer.decode(generated_ids, skip_special_tokens=False)
174
+ try:
175
+ inp = input('user: ')
176
+ except EOFError:
177
+ inp = ""
178
+ if not inp:
179
+ print("exit...")
180
+
181
+ new_text = generated_text + "<|start_header_id|>user<|end_header_id|>\n\n" + inp + "<|start_header_id|>assistant<|end_header_id|>\n\n"
182
+ new_input_ids = tokenizer(new_text, return_tensors='pt').input_ids.to(device)
183
+ new_embeddings = embedding_layer(new_input_ids)
184
+
185
+ new_embeds = torch.cat([new_embeds, new_embeddings], dim=1)
186
+ attn_mask = torch.ones(new_embeds.shape[:2], device=device)
187
+
188
+
189
+ if __name__ == "__main__":
190
+ parser = argparse.ArgumentParser(description="Answer questions based on an image")
191
+ parser.add_argument("-i", "--image", required=True, help="Path to the image file")
192
+ args = parser.parse_args()
193
+
194
+ tokenizer, model, vision_model, processor = initialize_models()
195
+ projection_module = load_projection_module()
196
+
197
+ answer_question(
198
+ args.image,
199
+ tokenizer,
200
+ model,
201
+ vision_model,
202
+ processor,
203
+ projection_module,
204
+ )
assets/demo-1.jpg ADDED
assets/demo-2.jpg ADDED
assets/demo-3.jpg ADDED
mm_projector.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4c67486e883bf7f02b9756850c6f1914e7146936b49805bd3ca8583a71c4d40f
3
+ size 43009661