mins commited on
Commit
b443c25
β€’
1 Parent(s): 4a650e1

initial commit

Browse files
.gitignore ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Python
2
+ __pycache__
3
+ *.pyc
4
+ *.egg-info
5
+ dist
6
+
7
+ # Log
8
+ *.log
9
+ *.log.*
10
+ # *.json
11
+ *.jsonl
12
+ images/*
13
+
14
+ # Editor
15
+ .idea
16
+ *.swp
17
+ .github
18
+ .vscode
19
+
20
+ # Other
21
+ .DS_Store
22
+ wandb
23
+ output
app.py ADDED
@@ -0,0 +1,306 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+ # import copy
4
+ import torch
5
+ # import random
6
+ import spaces
7
+
8
+ from eagle import conversation as conversation_lib
9
+ from eagle.constants import DEFAULT_IMAGE_TOKEN
10
+
11
+ from eagle.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
12
+ from eagle.conversation import conv_templates, SeparatorStyle
13
+ from eagle.model.builder import load_pretrained_model
14
+ from eagle.utils import disable_torch_init
15
+ from eagle.mm_utils import tokenizer_image_token, get_model_name_from_path, process_images
16
+
17
+ from PIL import Image
18
+ import argparse
19
+
20
+ from transformers import TextIteratorStreamer
21
+ from threading import Thread
22
+
23
+ # os.environ['GRADIO_TEMP_DIR'] = './gradio_tmp'
24
+ no_change_btn = gr.Button()
25
+ enable_btn = gr.Button(interactive=True)
26
+ disable_btn = gr.Button(interactive=False)
27
+
28
+ argparser = argparse.ArgumentParser()
29
+ argparser.add_argument("--server_name", default="0.0.0.0", type=str)
30
+ argparser.add_argument("--port", default="6324", type=str)
31
+ argparser.add_argument("--model-path", default="NVEagle/Eagle-X5-13B", type=str)
32
+ argparser.add_argument("--model-base", type=str, default=None)
33
+ argparser.add_argument("--num-gpus", type=int, default=1)
34
+ argparser.add_argument("--conv-mode", type=str, default="vicuna_v1")
35
+ argparser.add_argument("--temperature", type=float, default=0.2)
36
+ argparser.add_argument("--max-new-tokens", type=int, default=512)
37
+ argparser.add_argument("--num_frames", type=int, default=16)
38
+ argparser.add_argument("--load-8bit", action="store_true")
39
+ argparser.add_argument("--load-4bit", action="store_true")
40
+ argparser.add_argument("--debug", action="store_true")
41
+
42
+ args = argparser.parse_args()
43
+ model_path = args.model_path
44
+ conv_mode = args.conv_mode
45
+ filt_invalid="cut"
46
+ model_name = get_model_name_from_path(args.model_path)
47
+ tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit)
48
+ our_chatbot = None
49
+
50
+ def upvote_last_response(state):
51
+ return ("",) + (disable_btn,) * 3
52
+
53
+
54
+ def downvote_last_response(state):
55
+ return ("",) + (disable_btn,) * 3
56
+
57
+
58
+ def flag_last_response(state):
59
+ return ("",) + (disable_btn,) * 3
60
+
61
+ def clear_history():
62
+ state =conv_templates[conv_mode].copy()
63
+ return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
64
+
65
+ def add_text(state, imagebox, textbox, image_process_mode):
66
+ if state is None:
67
+ state = conv_templates[conv_mode].copy()
68
+
69
+ if imagebox is not None:
70
+ textbox = DEFAULT_IMAGE_TOKEN + '\n' + textbox
71
+ image = Image.open(imagebox).convert('RGB')
72
+
73
+ if imagebox is not None:
74
+ textbox = (textbox, image, image_process_mode)
75
+
76
+ state.append_message(state.roles[0], textbox)
77
+ state.append_message(state.roles[1], None)
78
+
79
+ yield (state, state.to_gradio_chatbot(), "", None) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
80
+
81
+ def delete_text(state, image_process_mode):
82
+ state.messages[-1][-1] = None
83
+ prev_human_msg = state.messages[-2]
84
+ if type(prev_human_msg[1]) in (tuple, list):
85
+ prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode)
86
+ yield (state, state.to_gradio_chatbot(), "", None) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
87
+
88
+ def regenerate(state, image_process_mode):
89
+ state.messages[-1][-1] = None
90
+ prev_human_msg = state.messages[-2]
91
+ if type(prev_human_msg[1]) in (tuple, list):
92
+ prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode)
93
+ state.skip_next = False
94
+ return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
95
+
96
+ @spaces.GPU
97
+ def generate(state, imagebox, textbox, image_process_mode, temperature, top_p, max_output_tokens):
98
+ prompt = state.get_prompt()
99
+ images = state.get_images(return_pil=True)
100
+ #prompt, image_args = process_image(prompt, images)
101
+
102
+ ori_prompt = prompt
103
+ num_image_tokens = 0
104
+
105
+ if images is not None and len(images) > 0:
106
+ if len(images) > 0:
107
+ if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN):
108
+ raise ValueError("Number of images does not match number of <image> tokens in prompt")
109
+
110
+ #images = [load_image_from_base64(image) for image in images]
111
+ image_sizes = [image.size for image in images]
112
+ images = process_images(images, image_processor, model.config)
113
+
114
+ if type(images) is list:
115
+ images = [image.to(model.device, dtype=torch.float16) for image in images]
116
+ else:
117
+ images = images.to(model.device, dtype=torch.float16)
118
+ else:
119
+ images = None
120
+ image_sizes = None
121
+ image_args = {"images": images, "image_sizes": image_sizes}
122
+ else:
123
+ images = None
124
+ image_args = {}
125
+
126
+ max_context_length = getattr(model.config, 'max_position_embeddings', 2048)
127
+ max_new_tokens = 512
128
+ do_sample = True if temperature > 0.001 else False
129
+ stop_str = state.sep if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT] else state.sep2
130
+
131
+ input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
132
+ streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15)
133
+
134
+ max_new_tokens = min(max_new_tokens, max_context_length - input_ids.shape[-1] - num_image_tokens)
135
+
136
+ if max_new_tokens < 1:
137
+ # yield json.dumps({"text": ori_prompt + "Exceeds max token length. Please start a new conversation, thanks.", "error_code": 0}).encode() + b"\0"
138
+ return
139
+
140
+ thread = Thread(target=model.generate, kwargs=dict(
141
+ inputs=input_ids,
142
+ do_sample=do_sample,
143
+ temperature=temperature,
144
+ top_p=top_p,
145
+ max_new_tokens=max_new_tokens,
146
+ streamer=streamer,
147
+ use_cache=True,
148
+ pad_token_id=tokenizer.eos_token_id,
149
+ **image_args
150
+ ))
151
+ thread.start()
152
+ generated_text = ''
153
+ for new_text in streamer:
154
+ generated_text += new_text
155
+ if generated_text.endswith(stop_str):
156
+ generated_text = generated_text[:-len(stop_str)]
157
+ state.messages[-1][-1] = generated_text
158
+ yield (state, state.to_gradio_chatbot(), "", None) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
159
+
160
+ yield (state, state.to_gradio_chatbot(), "", None) + (enable_btn,) * 5
161
+
162
+ torch.cuda.empty_cache()
163
+
164
+ txt = gr.Textbox(
165
+ scale=4,
166
+ show_label=False,
167
+ placeholder="Enter text and press enter.",
168
+ container=False,
169
+ )
170
+
171
+
172
+ title_markdown = ("""
173
+ # Eagle: Exploring The Design Space for Multimodal LLMs with Mixture of Encoders
174
+ [[Project Page](TODO)] [[Code](TODO)] [[Model](TODO)] | πŸ“š [[Arxiv](TODO)]]
175
+ """)
176
+
177
+ tos_markdown = ("""
178
+ ### Terms of use
179
+ By using this service, users are required to agree to the following terms:
180
+ The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
181
+ Please click the "Flag" button if you get any inappropriate answer! We will collect those to keep improving our moderator.
182
+ For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.
183
+ """)
184
+
185
+
186
+ learn_more_markdown = ("""
187
+ ### License
188
+ The service is a research preview intended for non-commercial use only, subject to the. Please contact us if you find any potential violation.
189
+ """)
190
+
191
+ block_css = """
192
+ #buttons button {
193
+ min-width: min(120px,100%);
194
+ }
195
+ """
196
+
197
+ textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", container=False)
198
+ with gr.Blocks(title="Eagle", theme=gr.themes.Default(), css=block_css) as demo:
199
+ state = gr.State()
200
+
201
+ gr.Markdown(title_markdown)
202
+
203
+ with gr.Row():
204
+ with gr.Column(scale=3):
205
+ imagebox = gr.Image(label="Input Image", type="filepath")
206
+ image_process_mode = gr.Radio(
207
+ ["Crop", "Resize", "Pad", "Default"],
208
+ value="Default",
209
+ label="Preprocess for non-square image", visible=False)
210
+
211
+ cur_dir = os.path.dirname(os.path.abspath(__file__))
212
+ gr.Examples(examples=[
213
+ [f"{cur_dir}/assets/health-insurance.png", "Under which circumstances do I need to be enrolled in mandatory health insurance if I am an international student?"],
214
+ [f"{cur_dir}/assets/leasing-apartment.png", "I don't have any 3rd party renter's insurance now. Do I need to get one for myself?"],
215
+ [f"{cur_dir}/assets/nvidia.jpeg", "Who is the person in the middle?"],
216
+ [f"{cur_dir}/assets/animal-compare.png", "Are these two pictures showing the same kind of animal?"],
217
+ [f"{cur_dir}/assets/georgia-tech.jpeg", "Where is this photo taken?"]
218
+ ], inputs=[imagebox, textbox], cache_examples=False)
219
+
220
+ with gr.Accordion("Parameters", open=False) as parameter_row:
221
+ temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.1, interactive=True, label="Temperature",)
222
+ top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, interactive=True, label="Top P",)
223
+ max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True, label="Max output tokens",)
224
+
225
+ with gr.Column(scale=8):
226
+ chatbot = gr.Chatbot(
227
+ elem_id="chatbot",
228
+ label="Eagle Chatbot",
229
+ height=650,
230
+ layout="panel",
231
+ )
232
+ with gr.Row():
233
+ with gr.Column(scale=8):
234
+ textbox.render()
235
+ with gr.Column(scale=1, min_width=50):
236
+ submit_btn = gr.Button(value="Send", variant="primary")
237
+ with gr.Row(elem_id="buttons") as button_row:
238
+ upvote_btn = gr.Button(value="πŸ‘ Upvote", interactive=False)
239
+ downvote_btn = gr.Button(value="πŸ‘Ž Downvote", interactive=False)
240
+ flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
241
+ #stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False)
242
+ regenerate_btn = gr.Button(value="πŸ”„ Regenerate", interactive=False)
243
+ clear_btn = gr.Button(value="πŸ—‘οΈ Clear", interactive=False)
244
+
245
+ gr.Markdown(tos_markdown)
246
+ gr.Markdown(learn_more_markdown)
247
+ url_params = gr.JSON(visible=False)
248
+
249
+ # Register listeners
250
+ btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
251
+ upvote_btn.click(
252
+ upvote_last_response,
253
+ [state],
254
+ [textbox, upvote_btn, downvote_btn, flag_btn]
255
+ )
256
+ downvote_btn.click(
257
+ downvote_last_response,
258
+ [state],
259
+ [textbox, upvote_btn, downvote_btn, flag_btn]
260
+ )
261
+ flag_btn.click(
262
+ flag_last_response,
263
+ [state],
264
+ [textbox, upvote_btn, downvote_btn, flag_btn]
265
+ )
266
+
267
+ clear_btn.click(
268
+ clear_history,
269
+ None,
270
+ [state, chatbot, textbox, imagebox] + btn_list,
271
+ queue=False
272
+ )
273
+
274
+ regenerate_btn.click(
275
+ delete_text,
276
+ [state, image_process_mode],
277
+ [state, chatbot, textbox, imagebox] + btn_list,
278
+ ).then(
279
+ generate,
280
+ [state, imagebox, textbox, image_process_mode, temperature, top_p, max_output_tokens],
281
+ [state, chatbot, textbox, imagebox] + btn_list,
282
+ )
283
+ textbox.submit(
284
+ add_text,
285
+ [state, imagebox, textbox, image_process_mode],
286
+ [state, chatbot, textbox, imagebox] + btn_list,
287
+ ).then(
288
+ generate,
289
+ [state, imagebox, textbox, image_process_mode, temperature, top_p, max_output_tokens],
290
+ [state, chatbot, textbox, imagebox] + btn_list,
291
+ )
292
+
293
+ submit_btn.click(
294
+ add_text,
295
+ [state, imagebox, textbox, image_process_mode],
296
+ [state, chatbot, textbox, imagebox] + btn_list,
297
+ ).then(
298
+ generate,
299
+ [state, imagebox, textbox, image_process_mode, temperature, top_p, max_output_tokens],
300
+ [state, chatbot, textbox, imagebox] + btn_list,
301
+ )
302
+
303
+ demo.queue(
304
+ status_update_rate=10,
305
+ api_open=False
306
+ ).launch()
assets/animal-compare.png ADDED
assets/georgia-tech.jpeg ADDED
assets/health-insurance.png ADDED
assets/leasing-apartment.png ADDED
assets/nvidia.jpeg ADDED
eagle/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .model import EagleLlamaForCausalLM
eagle/constants.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ CONTROLLER_HEART_BEAT_EXPIRATION = 30
2
+ WORKER_HEART_BEAT_INTERVAL = 15
3
+
4
+ LOGDIR = "."
5
+
6
+ # Model Constants
7
+ IGNORE_INDEX = -100
8
+ IMAGE_TOKEN_INDEX = -200
9
+ DEFAULT_IMAGE_TOKEN = "<image>"
10
+ DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
11
+ DEFAULT_IM_START_TOKEN = "<im_start>"
12
+ DEFAULT_IM_END_TOKEN = "<im_end>"
13
+ IMAGE_PLACEHOLDER = "<image-placeholder>"
eagle/conversation.py ADDED
@@ -0,0 +1,396 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import dataclasses
2
+ from enum import auto, Enum
3
+ from typing import List, Tuple
4
+ import base64
5
+ from io import BytesIO
6
+ from PIL import Image
7
+
8
+
9
+ class SeparatorStyle(Enum):
10
+ """Different separator style."""
11
+ SINGLE = auto()
12
+ TWO = auto()
13
+ MPT = auto()
14
+ PLAIN = auto()
15
+ LLAMA_2 = auto()
16
+
17
+
18
+ @dataclasses.dataclass
19
+ class Conversation:
20
+ """A class that keeps all conversation history."""
21
+ system: str
22
+ roles: List[str]
23
+ messages: List[List[str]]
24
+ offset: int
25
+ sep_style: SeparatorStyle = SeparatorStyle.SINGLE
26
+ sep: str = "###"
27
+ sep2: str = None
28
+ version: str = "Unknown"
29
+
30
+ skip_next: bool = False
31
+
32
+ def get_prompt(self):
33
+ messages = self.messages
34
+ if len(messages) > 0 and type(messages[0][1]) is tuple:
35
+ messages = self.messages.copy()
36
+ init_role, init_msg = messages[0].copy()
37
+ init_msg = init_msg[0].replace("<image>", "").strip()
38
+ if 'mmtag' in self.version:
39
+ messages[0] = (init_role, init_msg)
40
+ messages.insert(0, (self.roles[0], "<Image><image></Image>"))
41
+ messages.insert(1, (self.roles[1], "Received."))
42
+ else:
43
+ messages[0] = (init_role, "<image>\n" + init_msg)
44
+
45
+ if self.sep_style == SeparatorStyle.SINGLE:
46
+ ret = self.system + self.sep
47
+ for role, message in messages:
48
+ if message:
49
+ if type(message) is tuple:
50
+ message, _, _ = message
51
+ ret += role + ": " + message + self.sep
52
+ else:
53
+ ret += role + ":"
54
+ elif self.sep_style == SeparatorStyle.TWO:
55
+ seps = [self.sep, self.sep2]
56
+ ret = self.system + seps[0]
57
+ for i, (role, message) in enumerate(messages):
58
+ if message:
59
+ if type(message) is tuple:
60
+ message, _, _ = message
61
+ ret += role + ": " + message + seps[i % 2]
62
+ else:
63
+ ret += role + ":"
64
+ elif self.sep_style == SeparatorStyle.MPT:
65
+ ret = self.system + self.sep
66
+ for role, message in messages:
67
+ if message:
68
+ if type(message) is tuple:
69
+ message, _, _ = message
70
+ ret += role + message + self.sep
71
+ else:
72
+ ret += role
73
+ elif self.sep_style == SeparatorStyle.LLAMA_2:
74
+ wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n" if len(msg) > 0 else msg
75
+ wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
76
+ ret = ""
77
+
78
+ for i, (role, message) in enumerate(messages):
79
+ if i == 0:
80
+ assert message, "first message should not be none"
81
+ assert role == self.roles[0], "first message should come from user"
82
+ if message:
83
+ if type(message) is tuple:
84
+ message, _, _ = message
85
+ if i == 0: message = wrap_sys(self.system) + message
86
+ if i % 2 == 0:
87
+ message = wrap_inst(message)
88
+ ret += self.sep + message
89
+ else:
90
+ ret += " " + message + " " + self.sep2
91
+ else:
92
+ ret += ""
93
+ ret = ret.lstrip(self.sep)
94
+ elif self.sep_style == SeparatorStyle.PLAIN:
95
+ seps = [self.sep, self.sep2]
96
+ ret = self.system
97
+ for i, (role, message) in enumerate(messages):
98
+ if message:
99
+ if type(message) is tuple:
100
+ message, _, _ = message
101
+ ret += message + seps[i % 2]
102
+ else:
103
+ ret += ""
104
+ else:
105
+ raise ValueError(f"Invalid style: {self.sep_style}")
106
+
107
+ return ret
108
+
109
+ def append_message(self, role, message):
110
+ self.messages.append([role, message])
111
+
112
+ def process_image(self, image, image_process_mode, return_pil=False, image_format='PNG', max_len=1344, min_len=672):
113
+ if image_process_mode == "Pad":
114
+ def expand2square(pil_img, background_color=(122, 116, 104)):
115
+ width, height = pil_img.size
116
+ if width == height:
117
+ return pil_img
118
+ elif width > height:
119
+ result = Image.new(pil_img.mode, (width, width), background_color)
120
+ result.paste(pil_img, (0, (width - height) // 2))
121
+ return result
122
+ else:
123
+ result = Image.new(pil_img.mode, (height, height), background_color)
124
+ result.paste(pil_img, ((height - width) // 2, 0))
125
+ return result
126
+ image = expand2square(image)
127
+ elif image_process_mode in ["Default", "Crop"]:
128
+ pass
129
+ elif image_process_mode == "Resize":
130
+ image = image.resize((336, 336))
131
+ else:
132
+ raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
133
+ if max(image.size) > max_len:
134
+ max_hw, min_hw = max(image.size), min(image.size)
135
+ aspect_ratio = max_hw / min_hw
136
+ shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
137
+ longest_edge = int(shortest_edge * aspect_ratio)
138
+ W, H = image.size
139
+ if H > W:
140
+ H, W = longest_edge, shortest_edge
141
+ else:
142
+ H, W = shortest_edge, longest_edge
143
+ image = image.resize((W, H))
144
+ if return_pil:
145
+ return image
146
+ else:
147
+ buffered = BytesIO()
148
+ image.save(buffered, format=image_format)
149
+ img_b64_str = base64.b64encode(buffered.getvalue()).decode()
150
+ return img_b64_str
151
+
152
+ def get_images(self, return_pil=False):
153
+ images = []
154
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
155
+ if i % 2 == 0:
156
+ if type(msg) is tuple:
157
+ msg, image, image_process_mode = msg
158
+ image = self.process_image(image, image_process_mode, return_pil=return_pil)
159
+ images.append(image)
160
+ return images
161
+
162
+ def to_gradio_chatbot(self):
163
+ ret = []
164
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
165
+ if i % 2 == 0:
166
+ if type(msg) is tuple:
167
+ msg, image, image_process_mode = msg
168
+ img_b64_str = self.process_image(
169
+ image, "Default", return_pil=False,
170
+ image_format='JPEG')
171
+ img_str = f'<img src="data:image/jpeg;base64,{img_b64_str}" alt="user upload image" />'
172
+ msg = img_str + msg.replace('<image>', '').strip()
173
+ ret.append([msg, None])
174
+ else:
175
+ ret.append([msg, None])
176
+ else:
177
+ ret[-1][-1] = msg
178
+ return ret
179
+
180
+ def copy(self):
181
+ return Conversation(
182
+ system=self.system,
183
+ roles=self.roles,
184
+ messages=[[x, y] for x, y in self.messages],
185
+ offset=self.offset,
186
+ sep_style=self.sep_style,
187
+ sep=self.sep,
188
+ sep2=self.sep2,
189
+ version=self.version)
190
+
191
+ def dict(self):
192
+ if len(self.get_images()) > 0:
193
+ return {
194
+ "system": self.system,
195
+ "roles": self.roles,
196
+ "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
197
+ "offset": self.offset,
198
+ "sep": self.sep,
199
+ "sep2": self.sep2,
200
+ }
201
+ return {
202
+ "system": self.system,
203
+ "roles": self.roles,
204
+ "messages": self.messages,
205
+ "offset": self.offset,
206
+ "sep": self.sep,
207
+ "sep2": self.sep2,
208
+ }
209
+
210
+
211
+ conv_vicuna_v0 = Conversation(
212
+ system="A chat between a curious human and an artificial intelligence assistant. "
213
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
214
+ roles=("Human", "Assistant"),
215
+ messages=(
216
+ ("Human", "What are the key differences between renewable and non-renewable energy sources?"),
217
+ ("Assistant",
218
+ "Renewable energy sources are those that can be replenished naturally in a relatively "
219
+ "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
220
+ "Non-renewable energy sources, on the other hand, are finite and will eventually be "
221
+ "depleted, such as coal, oil, and natural gas. Here are some key differences between "
222
+ "renewable and non-renewable energy sources:\n"
223
+ "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
224
+ "energy sources are finite and will eventually run out.\n"
225
+ "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
226
+ "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
227
+ "and other negative effects.\n"
228
+ "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
229
+ "have lower operational costs than non-renewable sources.\n"
230
+ "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
231
+ "locations than non-renewable sources.\n"
232
+ "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
233
+ "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
234
+ "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
235
+ "non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
236
+ ),
237
+ offset=2,
238
+ sep_style=SeparatorStyle.SINGLE,
239
+ sep="###",
240
+ )
241
+
242
+ conv_vicuna_v1 = Conversation(
243
+ system="A chat between a curious user and an artificial intelligence assistant. "
244
+ "The assistant gives helpful, detailed, and polite answers to the user's questions.",
245
+ roles=("USER", "ASSISTANT"),
246
+ version="v1",
247
+ messages=(),
248
+ offset=0,
249
+ sep_style=SeparatorStyle.TWO,
250
+ sep=" ",
251
+ sep2="</s>",
252
+ )
253
+
254
+ conv_llama_2 = Conversation(
255
+ system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
256
+
257
+ If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""",
258
+ roles=("USER", "ASSISTANT"),
259
+ version="llama_v2",
260
+ messages=(),
261
+ offset=0,
262
+ sep_style=SeparatorStyle.LLAMA_2,
263
+ sep="<s>",
264
+ sep2="</s>",
265
+ )
266
+
267
+ conv_llava_llama_2 = Conversation(
268
+ system="You are a helpful language and vision assistant. "
269
+ "You are able to understand the visual content that the user provides, "
270
+ "and assist the user with a variety of tasks using natural language.",
271
+ roles=("USER", "ASSISTANT"),
272
+ version="llama_v2",
273
+ messages=(),
274
+ offset=0,
275
+ sep_style=SeparatorStyle.LLAMA_2,
276
+ sep="<s>",
277
+ sep2="</s>",
278
+ )
279
+
280
+ conv_mpt = Conversation(
281
+ system="""<|im_start|>system
282
+ A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""",
283
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
284
+ version="mpt",
285
+ messages=(),
286
+ offset=0,
287
+ sep_style=SeparatorStyle.MPT,
288
+ sep="<|im_end|>",
289
+ )
290
+
291
+ conv_llava_plain = Conversation(
292
+ system="",
293
+ roles=("", ""),
294
+ messages=(
295
+ ),
296
+ offset=0,
297
+ sep_style=SeparatorStyle.PLAIN,
298
+ sep="\n",
299
+ )
300
+
301
+ conv_llava_v0 = Conversation(
302
+ system="A chat between a curious human and an artificial intelligence assistant. "
303
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
304
+ roles=("Human", "Assistant"),
305
+ messages=(
306
+ ),
307
+ offset=0,
308
+ sep_style=SeparatorStyle.SINGLE,
309
+ sep="###",
310
+ )
311
+
312
+ conv_llava_v0_mmtag = Conversation(
313
+ system="A chat between a curious user and an artificial intelligence assistant. "
314
+ "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
315
+ "The visual content will be provided with the following format: <Image>visual content</Image>.",
316
+ roles=("Human", "Assistant"),
317
+ messages=(
318
+ ),
319
+ offset=0,
320
+ sep_style=SeparatorStyle.SINGLE,
321
+ sep="###",
322
+ version="v0_mmtag",
323
+ )
324
+
325
+ conv_llava_v1 = Conversation(
326
+ system="A chat between a curious human and an artificial intelligence assistant. "
327
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
328
+ roles=("USER", "ASSISTANT"),
329
+ version="v1",
330
+ messages=(),
331
+ offset=0,
332
+ sep_style=SeparatorStyle.TWO,
333
+ sep=" ",
334
+ sep2="</s>",
335
+ )
336
+
337
+ conv_llava_v1_mmtag = Conversation(
338
+ system="A chat between a curious user and an artificial intelligence assistant. "
339
+ "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
340
+ "The visual content will be provided with the following format: <Image>visual content</Image>.",
341
+ roles=("USER", "ASSISTANT"),
342
+ messages=(),
343
+ offset=0,
344
+ sep_style=SeparatorStyle.TWO,
345
+ sep=" ",
346
+ sep2="</s>",
347
+ version="v1_mmtag",
348
+ )
349
+
350
+ conv_mistral_instruct = Conversation(
351
+ system="",
352
+ roles=("USER", "ASSISTANT"),
353
+ version="llama_v2",
354
+ messages=(),
355
+ offset=0,
356
+ sep_style=SeparatorStyle.LLAMA_2,
357
+ sep="",
358
+ sep2="</s>",
359
+ )
360
+
361
+ conv_chatml_direct = Conversation(
362
+ system="""<|im_start|>system
363
+ Answer the questions.""",
364
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
365
+ version="mpt",
366
+ messages=(),
367
+ offset=0,
368
+ sep_style=SeparatorStyle.MPT,
369
+ sep="<|im_end|>",
370
+ )
371
+
372
+ default_conversation = conv_vicuna_v1
373
+ conv_templates = {
374
+ "default": conv_vicuna_v0,
375
+ "v0": conv_vicuna_v0,
376
+ "v1": conv_vicuna_v1,
377
+ "vicuna_v1": conv_vicuna_v1,
378
+ "llama_2": conv_llama_2,
379
+ "mistral_instruct": conv_mistral_instruct,
380
+ "chatml_direct": conv_chatml_direct,
381
+ "mistral_direct": conv_chatml_direct,
382
+
383
+ "plain": conv_llava_plain,
384
+ "v0_plain": conv_llava_plain,
385
+ "llava_v0": conv_llava_v0,
386
+ "v0_mmtag": conv_llava_v0_mmtag,
387
+ "llava_v1": conv_llava_v1,
388
+ "v1_mmtag": conv_llava_v1_mmtag,
389
+ "llava_llama_2": conv_llava_llama_2,
390
+
391
+ "mpt": conv_mpt,
392
+ }
393
+
394
+
395
+ if __name__ == "__main__":
396
+ print(default_conversation.get_prompt())
eagle/mm_utils.py ADDED
@@ -0,0 +1,247 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ from io import BytesIO
3
+ import base64
4
+ import torch
5
+ import math
6
+ import ast
7
+
8
+ from transformers import StoppingCriteria
9
+ from eagle.constants import IMAGE_TOKEN_INDEX
10
+
11
+
12
+ def select_best_resolution(original_size, possible_resolutions):
13
+ """
14
+ Selects the best resolution from a list of possible resolutions based on the original size.
15
+
16
+ Args:
17
+ original_size (tuple): The original size of the image in the format (width, height).
18
+ possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].
19
+
20
+ Returns:
21
+ tuple: The best fit resolution in the format (width, height).
22
+ """
23
+ original_width, original_height = original_size
24
+ best_fit = None
25
+ max_effective_resolution = 0
26
+ min_wasted_resolution = float('inf')
27
+
28
+ for width, height in possible_resolutions:
29
+ scale = min(width / original_width, height / original_height)
30
+ downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
31
+ effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
32
+ wasted_resolution = (width * height) - effective_resolution
33
+
34
+ if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution):
35
+ max_effective_resolution = effective_resolution
36
+ min_wasted_resolution = wasted_resolution
37
+ best_fit = (width, height)
38
+
39
+ return best_fit
40
+
41
+
42
+ def resize_and_pad_image(image, target_resolution):
43
+ """
44
+ Resize and pad an image to a target resolution while maintaining aspect ratio.
45
+
46
+ Args:
47
+ image (PIL.Image.Image): The input image.
48
+ target_resolution (tuple): The target resolution (width, height) of the image.
49
+
50
+ Returns:
51
+ PIL.Image.Image: The resized and padded image.
52
+ """
53
+ original_width, original_height = image.size
54
+ target_width, target_height = target_resolution
55
+
56
+ scale_w = target_width / original_width
57
+ scale_h = target_height / original_height
58
+
59
+ if scale_w < scale_h:
60
+ new_width = target_width
61
+ new_height = min(math.ceil(original_height * scale_w), target_height)
62
+ else:
63
+ new_height = target_height
64
+ new_width = min(math.ceil(original_width * scale_h), target_width)
65
+
66
+ # Resize the image
67
+ resized_image = image.resize((new_width, new_height))
68
+
69
+ new_image = Image.new('RGB', (target_width, target_height), (0, 0, 0))
70
+ paste_x = (target_width - new_width) // 2
71
+ paste_y = (target_height - new_height) // 2
72
+ new_image.paste(resized_image, (paste_x, paste_y))
73
+
74
+ return new_image
75
+
76
+
77
+ def divide_to_patches(image, patch_size):
78
+ """
79
+ Divides an image into patches of a specified size.
80
+
81
+ Args:
82
+ image (PIL.Image.Image): The input image.
83
+ patch_size (int): The size of each patch.
84
+
85
+ Returns:
86
+ list: A list of PIL.Image.Image objects representing the patches.
87
+ """
88
+ patches = []
89
+ width, height = image.size
90
+ for i in range(0, height, patch_size):
91
+ for j in range(0, width, patch_size):
92
+ box = (j, i, j + patch_size, i + patch_size)
93
+ patch = image.crop(box)
94
+ patches.append(patch)
95
+
96
+ return patches
97
+
98
+
99
+ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
100
+ """
101
+ Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
102
+
103
+ Args:
104
+ image_size (tuple): The size of the input image in the format (width, height).
105
+ grid_pinpoints (str): A string representation of a list of possible resolutions.
106
+ patch_size (int): The size of each image patch.
107
+
108
+ Returns:
109
+ tuple: The shape of the image patch grid in the format (width, height).
110
+ """
111
+ if type(grid_pinpoints) is list:
112
+ possible_resolutions = grid_pinpoints
113
+ else:
114
+ possible_resolutions = ast.literal_eval(grid_pinpoints)
115
+ width, height = select_best_resolution(image_size, possible_resolutions)
116
+ return width // patch_size, height // patch_size
117
+
118
+
119
+ def process_anyres_image(image, processor, grid_pinpoints):
120
+ """
121
+ Process an image with variable resolutions.
122
+
123
+ Args:
124
+ image (PIL.Image.Image): The input image to be processed.
125
+ processor: The image processor object.
126
+ grid_pinpoints (str): A string representation of a list of possible resolutions.
127
+
128
+ Returns:
129
+ torch.Tensor: A tensor containing the processed image patches.
130
+ """
131
+ if type(grid_pinpoints) is list:
132
+ possible_resolutions = grid_pinpoints
133
+ else:
134
+ possible_resolutions = ast.literal_eval(grid_pinpoints)
135
+ best_resolution = select_best_resolution(image.size, possible_resolutions)
136
+ image_padded = resize_and_pad_image(image, best_resolution)
137
+
138
+ patches = divide_to_patches(image_padded, processor.crop_size['height'])
139
+
140
+ image_original_resize = image.resize((processor.size['shortest_edge'], processor.size['shortest_edge']))
141
+
142
+ image_patches = [image_original_resize] + patches
143
+ image_patches = [processor.preprocess(image_patch, return_tensors='pt')['pixel_values'][0]
144
+ for image_patch in image_patches]
145
+ return torch.stack(image_patches, dim=0)
146
+
147
+
148
+ def load_image_from_base64(image):
149
+ return Image.open(BytesIO(base64.b64decode(image)))
150
+
151
+
152
+ def expand2square(pil_img, background_color):
153
+ width, height = pil_img.size
154
+ if width == height:
155
+ return pil_img
156
+ elif width > height:
157
+ result = Image.new(pil_img.mode, (width, width), background_color)
158
+ result.paste(pil_img, (0, (width - height) // 2))
159
+ return result
160
+ else:
161
+ result = Image.new(pil_img.mode, (height, height), background_color)
162
+ result.paste(pil_img, ((height - width) // 2, 0))
163
+ return result
164
+
165
+
166
+ def process_images(images, image_processor, model_cfg):
167
+ image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
168
+ new_images = []
169
+ if image_aspect_ratio == 'pad':
170
+ for image in images:
171
+ image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean))
172
+ image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
173
+ new_images.append(image)
174
+ elif image_aspect_ratio == "anyres":
175
+ for image in images:
176
+ image = process_anyres_image(image, image_processor, model_cfg.image_grid_pinpoints)
177
+ new_images.append(image)
178
+ else:
179
+ return image_processor(images, return_tensors='pt')['pixel_values']
180
+ if all(x.shape == new_images[0].shape for x in new_images):
181
+ new_images = torch.stack(new_images, dim=0)
182
+ return new_images
183
+
184
+
185
+ def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
186
+ prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]
187
+
188
+ def insert_separator(X, sep):
189
+ return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
190
+
191
+ input_ids = []
192
+ offset = 0
193
+ if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
194
+ offset = 1
195
+ input_ids.append(prompt_chunks[0][0])
196
+
197
+ for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
198
+ input_ids.extend(x[offset:])
199
+
200
+ if return_tensors is not None:
201
+ if return_tensors == 'pt':
202
+ return torch.tensor(input_ids, dtype=torch.long)
203
+ raise ValueError(f'Unsupported tensor type: {return_tensors}')
204
+ return input_ids
205
+
206
+
207
+ def get_model_name_from_path(model_path):
208
+ model_path = model_path.strip("/")
209
+ model_paths = model_path.split("/")
210
+ if model_paths[-1].startswith('checkpoint-'):
211
+ return model_paths[-2] + "_" + model_paths[-1]
212
+ else:
213
+ return model_paths[-1]
214
+
215
+ class KeywordsStoppingCriteria(StoppingCriteria):
216
+ def __init__(self, keywords, tokenizer, input_ids):
217
+ self.keywords = keywords
218
+ self.keyword_ids = []
219
+ self.max_keyword_len = 0
220
+ for keyword in keywords:
221
+ cur_keyword_ids = tokenizer(keyword).input_ids
222
+ if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
223
+ cur_keyword_ids = cur_keyword_ids[1:]
224
+ if len(cur_keyword_ids) > self.max_keyword_len:
225
+ self.max_keyword_len = len(cur_keyword_ids)
226
+ self.keyword_ids.append(torch.tensor(cur_keyword_ids))
227
+ self.tokenizer = tokenizer
228
+ self.start_len = input_ids.shape[1]
229
+
230
+ def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
231
+ offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
232
+ self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
233
+ for keyword_id in self.keyword_ids:
234
+ truncated_output_ids = output_ids[0, -keyword_id.shape[0]:]
235
+ if torch.equal(truncated_output_ids, keyword_id):
236
+ return True
237
+ outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
238
+ for keyword in self.keywords:
239
+ if keyword in outputs:
240
+ return True
241
+ return False
242
+
243
+ def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
244
+ outputs = []
245
+ for i in range(output_ids.shape[0]):
246
+ outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores))
247
+ return all(outputs)
eagle/model/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .language_model.eagle_llama import EagleLlamaForCausalLM, EagleConfig
eagle/model/builder.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 Haotian Liu
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ import os
17
+ import warnings
18
+ import shutil
19
+
20
+ from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
21
+ import torch
22
+ from eagle.model import *
23
+ from eagle.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
24
+
25
+
26
+ def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cuda", use_flash_attn=False, **kwargs):
27
+ kwargs = {"device_map": device_map, **kwargs}
28
+
29
+ if device != "cuda":
30
+ kwargs['device_map'] = {"": device}
31
+
32
+ if load_8bit:
33
+ kwargs['load_in_8bit'] = True
34
+ elif load_4bit:
35
+ kwargs['load_in_4bit'] = True
36
+ kwargs['quantization_config'] = BitsAndBytesConfig(
37
+ load_in_4bit=True,
38
+ bnb_4bit_compute_dtype=torch.float16,
39
+ bnb_4bit_use_double_quant=True,
40
+ bnb_4bit_quant_type='nf4'
41
+ )
42
+ else:
43
+ kwargs['torch_dtype'] = torch.float16
44
+
45
+ if use_flash_attn:
46
+ kwargs['attn_implementation'] = 'flash_attention_2'
47
+
48
+ if 'eagle' in model_name.lower():
49
+ if 'lora' in model_name.lower() and model_base is None:
50
+ warnings.warn('There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument. Detailed instruction: https://github.com/haotian-liu/LLaVA#launch-a-model-worker-lora-weights-unmerged.')
51
+ if 'lora' in model_name.lower() and model_base is not None:
52
+ from eagle.model.language_model.eagle_llama import eagleConfig
53
+ lora_cfg_pretrained = eagleConfig.from_pretrained(model_path)
54
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
55
+ print('Loading eagle from base model...')
56
+ model = EagleLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs)
57
+ token_num, tokem_dim = model.lm_head.out_features, model.lm_head.in_features
58
+ if model.lm_head.weight.shape[0] != token_num:
59
+ model.lm_head.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype))
60
+ model.model.embed_tokens.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype))
61
+
62
+ print('Loading additional Eagle weights...')
63
+ if os.path.exists(os.path.join(model_path, 'non_lora_trainables.bin')):
64
+ non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_trainables.bin'), map_location='cpu')
65
+ else:
66
+ # this is probably from HF Hub
67
+ from huggingface_hub import hf_hub_download
68
+ def load_from_hf(repo_id, filename, subfolder=None):
69
+ cache_file = hf_hub_download(
70
+ repo_id=repo_id,
71
+ filename=filename,
72
+ subfolder=subfolder)
73
+ return torch.load(cache_file, map_location='cpu')
74
+ non_lora_trainables = load_from_hf(model_path, 'non_lora_trainables.bin')
75
+ non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in non_lora_trainables.items()}
76
+ if any(k.startswith('model.model.') for k in non_lora_trainables):
77
+ non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in non_lora_trainables.items()}
78
+ model.load_state_dict(non_lora_trainables, strict=False)
79
+
80
+ from peft import PeftModel
81
+ print('Loading LoRA weights...')
82
+ model = PeftModel.from_pretrained(model, model_path)
83
+ print('Merging LoRA weights...')
84
+ model = model.merge_and_unload()
85
+ print('Model is loaded...')
86
+ elif model_base is not None:
87
+ # this may be mm projector only
88
+ print('Loading Eagle from base model...')
89
+
90
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
91
+ cfg_pretrained = AutoConfig.from_pretrained(model_path)
92
+ model = EagleLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)
93
+
94
+ mm_projector_weights = torch.load(os.path.join(model_path, 'mm_projector.bin'), map_location='cpu')
95
+ mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()}
96
+ model.load_state_dict(mm_projector_weights, strict=False)
97
+ else:
98
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
99
+ model = EagleLlamaForCausalLM.from_pretrained(
100
+ model_path,
101
+ low_cpu_mem_usage=True,
102
+ **kwargs
103
+ )
104
+ else:
105
+ # Load language model
106
+ if model_base is not None:
107
+ # PEFT model
108
+ from peft import PeftModel
109
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
110
+ model = AutoModelForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, **kwargs)
111
+ print(f"Loading LoRA weights from {model_path}")
112
+ model = PeftModel.from_pretrained(model, model_path)
113
+ print(f"Merging weights")
114
+ model = model.merge_and_unload()
115
+ print('Convert to FP16...')
116
+ model.to(torch.float16)
117
+ else:
118
+ use_fast = False
119
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
120
+ model = EagleLlamaForCausalLM.from_pretrained(
121
+ model_path,
122
+ low_cpu_mem_usage=True,
123
+ **kwargs
124
+ )
125
+ # Always load the weight into a EagleLLaMA model
126
+ # tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
127
+ # model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
128
+
129
+ image_processor = None
130
+
131
+ # if 'eagle' in model_name.lower():
132
+ mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
133
+ mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
134
+ if mm_use_im_patch_token:
135
+ tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
136
+ if mm_use_im_start_end:
137
+ tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
138
+ model.resize_token_embeddings(len(tokenizer))
139
+
140
+ vision_tower = model.get_vision_tower()
141
+ if not vision_tower.is_loaded:
142
+ vision_tower.load_model(device_map=device_map)
143
+ if device_map != 'auto':
144
+ vision_tower.to(device=device_map, dtype=torch.float16)
145
+ image_processor = vision_tower.image_processor
146
+
147
+ if hasattr(model.config, "max_sequence_length"):
148
+ context_len = model.config.max_sequence_length
149
+ else:
150
+ context_len = 2048
151
+
152
+ return tokenizer, model, image_processor, context_len
eagle/model/consolidate.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Usage:
3
+ python3 -m eagle.model.consolidate --src ~/model_weights/eagle-7b --dst ~/model_weights/eagle-7b_consolidate
4
+ """
5
+ import argparse
6
+
7
+ import torch
8
+ from transformers import AutoTokenizer, AutoModelForCausalLM
9
+ from eagle.model import *
10
+ from eagle.model.utils import auto_upgrade
11
+
12
+
13
+ def consolidate_ckpt(src_path, dst_path):
14
+ print("Loading model")
15
+ auto_upgrade(src_path)
16
+ src_model = AutoModelForCausalLM.from_pretrained(src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
17
+ src_tokenizer = AutoTokenizer.from_pretrained(src_path, use_fast=False)
18
+ src_model.save_pretrained(dst_path)
19
+ src_tokenizer.save_pretrained(dst_path)
20
+
21
+
22
+ if __name__ == "__main__":
23
+ parser = argparse.ArgumentParser()
24
+ parser.add_argument("--src", type=str, required=True)
25
+ parser.add_argument("--dst", type=str, required=True)
26
+
27
+ args = parser.parse_args()
28
+
29
+ consolidate_ckpt(args.src, args.dst)
eagle/model/eagle_arch.py ADDED
@@ -0,0 +1,372 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 Haotian Liu
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from abc import ABC, abstractmethod
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+
21
+ from .multimodal_encoder.builder import build_vision_tower
22
+ from .multimodal_projector.builder import build_vision_projector
23
+
24
+ from eagle.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
25
+
26
+ from eagle.mm_utils import get_anyres_image_grid_shape
27
+
28
+
29
+ class EagleMetaModel:
30
+
31
+ def __init__(self, config):
32
+ super(EagleMetaModel, self).__init__(config)
33
+
34
+ if hasattr(config, "mm_vision_tower"):
35
+ self.vision_tower = build_vision_tower(config, delay_load=True)
36
+ fpn_input_dim = [] if not hasattr(self.vision_tower, "fpn_input_dim") else self.vision_tower.fpn_input_dim
37
+ self.mm_projector = build_vision_projector(config, fpn_input_dim=fpn_input_dim)
38
+
39
+ if 'unpad' in getattr(config, 'mm_patch_merge_type', ''):
40
+ self.image_newline = nn.Parameter(
41
+ torch.empty(config.hidden_size, dtype=self.dtype)
42
+ )
43
+
44
+ def get_vision_tower(self):
45
+ vision_tower = getattr(self, 'vision_tower', None)
46
+ if type(vision_tower) is list:
47
+ vision_tower = vision_tower[0]
48
+ return vision_tower
49
+
50
+ def initialize_vision_modules(self, model_args, fsdp=None):
51
+ vision_tower = model_args.vision_tower
52
+ mm_vision_select_layer = model_args.mm_vision_select_layer
53
+ mm_vision_select_feature = model_args.mm_vision_select_feature
54
+ pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter
55
+ mm_patch_merge_type = model_args.mm_patch_merge_type
56
+
57
+ self.config.mm_vision_tower = vision_tower
58
+
59
+ if self.get_vision_tower() is None:
60
+ vision_tower = build_vision_tower(model_args)
61
+
62
+ if fsdp is not None and len(fsdp) > 0:
63
+ self.vision_tower = [vision_tower]
64
+ else:
65
+ self.vision_tower = vision_tower
66
+ else:
67
+ if fsdp is not None and len(fsdp) > 0:
68
+ vision_tower = self.vision_tower[0]
69
+ else:
70
+ vision_tower = self.vision_tower
71
+ vision_tower.load_model()
72
+
73
+ self.config.use_mm_proj = True
74
+ self.config.mm_projector_type = getattr(model_args, 'mm_projector_type', 'linear')
75
+ self.config.mm_hidden_size = vision_tower.hidden_size
76
+ self.config.mm_vision_select_layer = mm_vision_select_layer
77
+ self.config.mm_vision_select_feature = mm_vision_select_feature
78
+ self.config.mm_patch_merge_type = mm_patch_merge_type
79
+ # record config for resampler
80
+ self.config.mm_projector_query_number = model_args.mm_projector_query_number
81
+
82
+ if getattr(self, 'mm_projector', None) is None:
83
+ fpn_input_dim = [] if not hasattr(self.vision_tower, "fpn_input_dim") else self.vision_tower.fpn_input_dim
84
+ self.mm_projector = build_vision_projector(self.config, fpn_input_dim=fpn_input_dim)
85
+
86
+ if 'unpad' in mm_patch_merge_type:
87
+ embed_std = 1 / torch.sqrt(torch.tensor(self.config.hidden_size, dtype=self.dtype))
88
+ self.image_newline = nn.Parameter(
89
+ torch.randn(self.config.hidden_size, dtype=self.dtype) * embed_std
90
+ )
91
+ else:
92
+ # In case it is frozen by LoRA
93
+ for p in self.mm_projector.parameters():
94
+ p.requires_grad = True
95
+
96
+ if pretrain_mm_mlp_adapter is not None:
97
+ mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
98
+ def get_w(weights, keyword):
99
+ return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k}
100
+
101
+ self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector'))
102
+
103
+
104
+ def unpad_image(tensor, original_size):
105
+ """
106
+ Unpads a PyTorch tensor of a padded and resized image.
107
+
108
+ Args:
109
+ tensor (torch.Tensor): The image tensor, assumed to be in CxHxW format.
110
+ original_size (tuple): The original size of the image (height, width).
111
+
112
+ Returns:
113
+ torch.Tensor: The unpadded image tensor.
114
+ """
115
+ original_width, original_height = original_size
116
+ current_height, current_width = tensor.shape[1:]
117
+
118
+ original_aspect_ratio = original_width / original_height
119
+ current_aspect_ratio = current_width / current_height
120
+
121
+ if original_aspect_ratio > current_aspect_ratio:
122
+ scale_factor = current_width / original_width
123
+ new_height = int(original_height * scale_factor)
124
+ padding = (current_height - new_height) // 2
125
+ unpadded_tensor = tensor[:, padding:current_height - padding, :]
126
+ else:
127
+ scale_factor = current_height / original_height
128
+ new_width = int(original_width * scale_factor)
129
+ padding = (current_width - new_width) // 2
130
+ unpadded_tensor = tensor[:, :, padding:current_width - padding]
131
+
132
+ return unpadded_tensor
133
+
134
+
135
+ class EagleMetaForCausalLM(ABC):
136
+
137
+ @abstractmethod
138
+ def get_model(self):
139
+ pass
140
+
141
+ def get_vision_tower(self):
142
+ return self.get_model().get_vision_tower()
143
+
144
+ def encode_images(self, images):
145
+ image_features = self.get_model().get_vision_tower()(images)
146
+ image_features = self.get_model().mm_projector(image_features)
147
+ return image_features
148
+
149
+ def prepare_inputs_labels_for_multimodal(
150
+ self, input_ids, position_ids, attention_mask, past_key_values, labels,
151
+ images, image_sizes=None
152
+ ):
153
+ vision_tower = self.get_vision_tower()
154
+ if vision_tower is None or images is None or input_ids.shape[1] == 1:
155
+ return input_ids, position_ids, attention_mask, past_key_values, None, labels
156
+
157
+ if type(images) is list or images.ndim == 5:
158
+ if type(images) is list:
159
+ images = [x.unsqueeze(0) if x.ndim == 3 else x for x in images]
160
+ concat_images = torch.cat([image for image in images], dim=0)
161
+ image_features = self.encode_images(concat_images)
162
+ split_sizes = [image.shape[0] for image in images]
163
+ image_features = torch.split(image_features, split_sizes, dim=0)
164
+ mm_patch_merge_type = getattr(self.config, 'mm_patch_merge_type', 'flat')
165
+ image_aspect_ratio = getattr(self.config, 'image_aspect_ratio', 'square')
166
+ if mm_patch_merge_type == 'flat':
167
+ image_features = [x.flatten(0, 1) for x in image_features]
168
+ elif mm_patch_merge_type.startswith('spatial'):
169
+ new_image_features = []
170
+ for image_idx, image_feature in enumerate(image_features):
171
+ if image_feature.shape[0] > 1:
172
+ base_image_feature = image_feature[0]
173
+ image_feature = image_feature[1:]
174
+ height = width = self.get_vision_tower().num_patches_per_side
175
+ assert height * width == base_image_feature.shape[0]
176
+ if image_aspect_ratio == 'anyres':
177
+ num_patch_width, num_patch_height = get_anyres_image_grid_shape(image_sizes[image_idx], self.config.image_grid_pinpoints, self.get_vision_tower().config.image_size)
178
+ image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
179
+ else:
180
+ raise NotImplementedError
181
+ if 'unpad' in mm_patch_merge_type:
182
+ image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
183
+ image_feature = image_feature.flatten(1, 2).flatten(2, 3)
184
+ image_feature = unpad_image(image_feature, image_sizes[image_idx])
185
+ image_feature = torch.cat((
186
+ image_feature,
187
+ self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)
188
+ ), dim=-1)
189
+ image_feature = image_feature.flatten(1, 2).transpose(0, 1)
190
+ else:
191
+ image_feature = image_feature.permute(0, 2, 1, 3, 4).contiguous()
192
+ image_feature = image_feature.flatten(0, 3)
193
+ image_feature = torch.cat((base_image_feature, image_feature), dim=0)
194
+ else:
195
+ image_feature = image_feature[0]
196
+ if 'unpad' in mm_patch_merge_type:
197
+ image_feature = torch.cat((
198
+ image_feature,
199
+ self.model.image_newline[None].to(image_feature.device)
200
+ ), dim=0)
201
+ new_image_features.append(image_feature)
202
+ image_features = new_image_features
203
+ else:
204
+ raise ValueError(f"Unexpected mm_patch_merge_type: {self.config.mm_patch_merge_type}")
205
+ else:
206
+ image_features = self.encode_images(images)
207
+
208
+ # TODO: image start / end is not implemented here to support pretraining.
209
+ if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
210
+ raise NotImplementedError
211
+
212
+ # Let's just add dummy tensors if they do not exist,
213
+ # it is a headache to deal with None all the time.
214
+ # But it is not ideal, and if you have a better idea,
215
+ # please open an issue / submit a PR, thanks.
216
+ _labels = labels
217
+ _position_ids = position_ids
218
+ _attention_mask = attention_mask
219
+ if attention_mask is None:
220
+ attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
221
+ else:
222
+ attention_mask = attention_mask.bool()
223
+ if position_ids is None:
224
+ position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
225
+ if labels is None:
226
+ labels = torch.full_like(input_ids, IGNORE_INDEX)
227
+
228
+ # remove the padding using attention_mask -- FIXME
229
+ _input_ids = input_ids
230
+ input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)]
231
+ labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]
232
+
233
+ new_input_embeds = []
234
+ new_labels = []
235
+ cur_image_idx = 0
236
+ for batch_idx, cur_input_ids in enumerate(input_ids):
237
+ num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
238
+ if num_images == 0:
239
+ cur_image_features = image_features[cur_image_idx]
240
+ cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
241
+ cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0)
242
+ new_input_embeds.append(cur_input_embeds)
243
+ new_labels.append(labels[batch_idx])
244
+ cur_image_idx += 1
245
+ continue
246
+
247
+ image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
248
+ cur_input_ids_noim = []
249
+ cur_labels = labels[batch_idx]
250
+ cur_labels_noim = []
251
+ for i in range(len(image_token_indices) - 1):
252
+ cur_input_ids_noim.append(cur_input_ids[image_token_indices[i]+1:image_token_indices[i+1]])
253
+ cur_labels_noim.append(cur_labels[image_token_indices[i]+1:image_token_indices[i+1]])
254
+ split_sizes = [x.shape[0] for x in cur_labels_noim]
255
+ cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
256
+ cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
257
+ cur_new_input_embeds = []
258
+ cur_new_labels = []
259
+
260
+ for i in range(num_images + 1):
261
+ cur_new_input_embeds.append(cur_input_embeds_no_im[i])
262
+ cur_new_labels.append(cur_labels_noim[i])
263
+ if i < num_images:
264
+ cur_image_features = image_features[cur_image_idx]
265
+ cur_image_idx += 1
266
+ cur_new_input_embeds.append(cur_image_features)
267
+ cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))
268
+
269
+ cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]
270
+
271
+ cur_new_input_embeds = torch.cat(cur_new_input_embeds)
272
+ cur_new_labels = torch.cat(cur_new_labels)
273
+
274
+ new_input_embeds.append(cur_new_input_embeds)
275
+ new_labels.append(cur_new_labels)
276
+
277
+ # Truncate sequences to max length as image embeddings can make the sequence longer
278
+ tokenizer_model_max_length = getattr(self.config, 'tokenizer_model_max_length', None)
279
+ if tokenizer_model_max_length is not None:
280
+ new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
281
+ new_labels = [x[:tokenizer_model_max_length] for x in new_labels]
282
+
283
+ # Combine them
284
+ max_len = max(x.shape[0] for x in new_input_embeds)
285
+ batch_size = len(new_input_embeds)
286
+
287
+ new_input_embeds_padded = []
288
+ new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device)
289
+ attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
290
+ position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)
291
+
292
+ for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
293
+ cur_len = cur_new_embed.shape[0]
294
+ if getattr(self.config, 'tokenizer_padding_side', 'right') == "left":
295
+ new_input_embeds_padded.append(torch.cat((
296
+ torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device),
297
+ cur_new_embed
298
+ ), dim=0))
299
+ if cur_len > 0:
300
+ new_labels_padded[i, -cur_len:] = cur_new_labels
301
+ attention_mask[i, -cur_len:] = True
302
+ position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
303
+ else:
304
+ new_input_embeds_padded.append(torch.cat((
305
+ cur_new_embed,
306
+ torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)
307
+ ), dim=0))
308
+ if cur_len > 0:
309
+ new_labels_padded[i, :cur_len] = cur_new_labels
310
+ attention_mask[i, :cur_len] = True
311
+ position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
312
+
313
+ new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
314
+
315
+ if _labels is None:
316
+ new_labels = None
317
+ else:
318
+ new_labels = new_labels_padded
319
+
320
+ if _attention_mask is None:
321
+ attention_mask = None
322
+ else:
323
+ attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
324
+
325
+ if _position_ids is None:
326
+ position_ids = None
327
+
328
+ return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels
329
+
330
+ def initialize_vision_tokenizer(self, model_args, tokenizer):
331
+ if model_args.mm_use_im_patch_token:
332
+ tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
333
+ self.resize_token_embeddings(len(tokenizer))
334
+
335
+ if model_args.mm_use_im_start_end:
336
+ num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
337
+ self.resize_token_embeddings(len(tokenizer))
338
+
339
+ if num_new_tokens > 0:
340
+ input_embeddings = self.get_input_embeddings().weight.data
341
+ output_embeddings = self.get_output_embeddings().weight.data
342
+
343
+ input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
344
+ dim=0, keepdim=True)
345
+ output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
346
+ dim=0, keepdim=True)
347
+
348
+ input_embeddings[-num_new_tokens:] = input_embeddings_avg
349
+ output_embeddings[-num_new_tokens:] = output_embeddings_avg
350
+
351
+ if model_args.tune_mm_mlp_adapter:
352
+ for p in self.get_input_embeddings().parameters():
353
+ p.requires_grad = True
354
+ for p in self.get_output_embeddings().parameters():
355
+ p.requires_grad = False
356
+
357
+ if model_args.pretrain_mm_mlp_adapter:
358
+ mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location='cpu')
359
+ embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight']
360
+ assert num_new_tokens == 2
361
+ if input_embeddings.shape == embed_tokens_weight.shape:
362
+ input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
363
+ elif embed_tokens_weight.shape[0] == num_new_tokens:
364
+ input_embeddings[-num_new_tokens:] = embed_tokens_weight
365
+ else:
366
+ raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}.")
367
+ elif model_args.mm_use_im_patch_token:
368
+ if model_args.tune_mm_mlp_adapter:
369
+ for p in self.get_input_embeddings().parameters():
370
+ p.requires_grad = False
371
+ for p in self.get_output_embeddings().parameters():
372
+ p.requires_grad = False
eagle/model/language_model/eagle_llama.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 Haotian Liu
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from typing import List, Optional, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+
21
+ from transformers import AutoConfig, AutoModelForCausalLM, \
22
+ LlamaConfig, LlamaModel, LlamaForCausalLM
23
+
24
+ from transformers.modeling_outputs import CausalLMOutputWithPast
25
+ from transformers.generation.utils import GenerateOutput
26
+
27
+ from ..eagle_arch import EagleMetaModel, EagleMetaForCausalLM
28
+
29
+
30
+ class EagleConfig(LlamaConfig):
31
+ model_type = "eagle_llama"
32
+
33
+
34
+ class EagleLlamaModel(EagleMetaModel, LlamaModel):
35
+ config_class = EagleConfig
36
+
37
+ def __init__(self, config: LlamaConfig):
38
+ super(EagleLlamaModel, self).__init__(config)
39
+
40
+
41
+ class EagleLlamaForCausalLM(LlamaForCausalLM, EagleMetaForCausalLM):
42
+ config_class = EagleConfig
43
+
44
+ def __init__(self, config):
45
+ super(LlamaForCausalLM, self).__init__(config)
46
+ self.model = EagleLlamaModel(config)
47
+ self.pretraining_tp = config.pretraining_tp
48
+ self.vocab_size = config.vocab_size
49
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
50
+
51
+ # Initialize weights and apply final processing
52
+ self.post_init()
53
+
54
+ def get_model(self):
55
+ return self.model
56
+
57
+ def forward(
58
+ self,
59
+ input_ids: torch.LongTensor = None,
60
+ attention_mask: Optional[torch.Tensor] = None,
61
+ position_ids: Optional[torch.LongTensor] = None,
62
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
63
+ inputs_embeds: Optional[torch.FloatTensor] = None,
64
+ labels: Optional[torch.LongTensor] = None,
65
+ use_cache: Optional[bool] = None,
66
+ output_attentions: Optional[bool] = None,
67
+ output_hidden_states: Optional[bool] = None,
68
+ images: Optional[torch.FloatTensor] = None,
69
+ image_sizes: Optional[List[List[int]]] = None,
70
+ return_dict: Optional[bool] = None,
71
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
72
+
73
+ if inputs_embeds is None:
74
+ (
75
+ input_ids,
76
+ position_ids,
77
+ attention_mask,
78
+ past_key_values,
79
+ inputs_embeds,
80
+ labels
81
+ ) = self.prepare_inputs_labels_for_multimodal(
82
+ input_ids,
83
+ position_ids,
84
+ attention_mask,
85
+ past_key_values,
86
+ labels,
87
+ images,
88
+ image_sizes
89
+ )
90
+
91
+ return super().forward(
92
+ input_ids=input_ids,
93
+ attention_mask=attention_mask,
94
+ position_ids=position_ids,
95
+ past_key_values=past_key_values,
96
+ inputs_embeds=inputs_embeds,
97
+ labels=labels,
98
+ use_cache=use_cache,
99
+ output_attentions=output_attentions,
100
+ output_hidden_states=output_hidden_states,
101
+ return_dict=return_dict
102
+ )
103
+
104
+ @torch.no_grad()
105
+ def generate(
106
+ self,
107
+ inputs: Optional[torch.Tensor] = None,
108
+ images: Optional[torch.Tensor] = None,
109
+ image_sizes: Optional[torch.Tensor] = None,
110
+ **kwargs,
111
+ ) -> Union[GenerateOutput, torch.LongTensor]:
112
+ position_ids = kwargs.pop("position_ids", None)
113
+ attention_mask = kwargs.pop("attention_mask", None)
114
+ if "inputs_embeds" in kwargs:
115
+ raise NotImplementedError("`inputs_embeds` is not supported")
116
+
117
+ if images is not None:
118
+ (
119
+ inputs,
120
+ position_ids,
121
+ attention_mask,
122
+ _,
123
+ inputs_embeds,
124
+ _
125
+ ) = self.prepare_inputs_labels_for_multimodal(
126
+ inputs,
127
+ position_ids,
128
+ attention_mask,
129
+ None,
130
+ None,
131
+ images,
132
+ image_sizes=image_sizes
133
+ )
134
+ else:
135
+ inputs_embeds = self.get_model().embed_tokens(inputs)
136
+
137
+ return super().generate(
138
+ position_ids=position_ids,
139
+ attention_mask=attention_mask,
140
+ inputs_embeds=inputs_embeds,
141
+ **kwargs
142
+ )
143
+
144
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
145
+ inputs_embeds=None, **kwargs):
146
+ images = kwargs.pop("images", None)
147
+ image_sizes = kwargs.pop("image_sizes", None)
148
+ inputs = super().prepare_inputs_for_generation(
149
+ input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
150
+ )
151
+ if images is not None:
152
+ inputs['images'] = images
153
+ if image_sizes is not None:
154
+ inputs['image_sizes'] = image_sizes
155
+ return inputs
156
+
157
+ AutoConfig.register("eagle_llama", EagleConfig)
158
+ AutoModelForCausalLM.register(EagleConfig, EagleLlamaForCausalLM)
eagle/model/multimodal_encoder/builder.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from .clip_encoder import CLIPVisionTower
3
+ from .multi_backbone_channel_concatenation_encoder import MultiBackboneChannelConcatenationVisionTower
4
+
5
+ def build_vision_tower(vision_tower_cfg, **kwargs):
6
+ vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None))
7
+
8
+ if "clip" in vision_tower and vision_tower.startswith("openai"):
9
+ is_absolute_path_exists = os.path.exists(vision_tower)
10
+ if is_absolute_path_exists or vision_tower.startswith("openai") or vision_tower.startswith("laion") or "ShareGPT4V" in vision_tower:
11
+ return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
12
+ raise ValueError(f'Unknown vision tower: {vision_tower}')
13
+
14
+ elif ";" in vision_tower:
15
+ return MultiBackboneChannelConcatenationVisionTower(vision_tower, args=vision_tower_cfg)
16
+
17
+ raise ValueError(f'Unknown vision tower: {vision_tower}')
eagle/model/multimodal_encoder/clip_encoder.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig
5
+
6
+
7
+ class CLIPVisionTower(nn.Module):
8
+ def __init__(self, vision_tower, args, delay_load=False):
9
+ super().__init__()
10
+
11
+ self.is_loaded = False
12
+
13
+ self.vision_tower_name = vision_tower
14
+ self.select_layer = args.mm_vision_select_layer
15
+ self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
16
+
17
+ if not delay_load:
18
+ self.load_model()
19
+ elif getattr(args, 'unfreeze_mm_vision_tower', False):
20
+ self.load_model()
21
+ else:
22
+ self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)
23
+
24
+ def load_model(self, device_map=None):
25
+ if self.is_loaded:
26
+ print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name))
27
+ return
28
+
29
+ self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
30
+ self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map)
31
+ self.vision_tower.requires_grad_(False)
32
+
33
+ self.is_loaded = True
34
+
35
+ def feature_select(self, image_forward_outs):
36
+ image_features = image_forward_outs.hidden_states[self.select_layer]
37
+ if self.select_feature == 'patch':
38
+ image_features = image_features[:, 1:]
39
+ elif self.select_feature == 'cls_patch':
40
+ image_features = image_features
41
+ else:
42
+ raise ValueError(f'Unexpected select feature: {self.select_feature}')
43
+ return image_features
44
+
45
+ @torch.no_grad()
46
+ def forward(self, images):
47
+ if type(images) is list:
48
+ image_features = []
49
+ for image in images:
50
+ image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
51
+ image_feature = self.feature_select(image_forward_out).to(image.dtype)
52
+ image_features.append(image_feature)
53
+ else:
54
+ image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
55
+ image_features = self.feature_select(image_forward_outs).to(images.dtype)
56
+
57
+ return image_features
58
+
59
+ @property
60
+ def dummy_feature(self):
61
+ return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
62
+
63
+ @property
64
+ def dtype(self):
65
+ return self.vision_tower.dtype
66
+
67
+ @property
68
+ def device(self):
69
+ return self.vision_tower.device
70
+
71
+ @property
72
+ def config(self):
73
+ if self.is_loaded:
74
+ return self.vision_tower.config
75
+ else:
76
+ return self.cfg_only
77
+
78
+ @property
79
+ def hidden_size(self):
80
+ return self.config.hidden_size
81
+
82
+ @property
83
+ def num_patches_per_side(self):
84
+ return self.config.image_size // self.config.patch_size
85
+
86
+ @property
87
+ def num_patches(self):
88
+ return (self.config.image_size // self.config.patch_size) ** 2
eagle/model/multimodal_encoder/convnext_encoder.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from transformers import CLIPImageProcessor
4
+ from .vision_models.convnext import convnext_xxlarge
5
+ from torch.utils.checkpoint import checkpoint
6
+
7
+ cfg={
8
+ "crop_size": 256,
9
+ "do_center_crop": True,
10
+ "do_normalize": True,
11
+ "do_resize": True,
12
+ "feature_extractor_type": "CLIPFeatureExtractor",
13
+ "image_mean": [
14
+ 0.48145466,
15
+ 0.4578275,
16
+ 0.40821073
17
+ ],
18
+ "image_std": [
19
+ 0.26862954,
20
+ 0.26130258,
21
+ 0.27577711
22
+ ],
23
+ "resample": 3,
24
+ "size": 256
25
+ }
26
+
27
+ class ConvNextVisionTower(nn.Module):
28
+ def __init__(self, vision_tower, args, delay_load=False):
29
+ super().__init__()
30
+
31
+ self.is_loaded = False
32
+ self.freeze_vision=args.freeze_vision
33
+ self.input_image_size=args.input_image_size
34
+ self.vision_tower_name = vision_tower
35
+ self.select_layer = -1 # hardcode
36
+ self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
37
+
38
+ self.load_model()
39
+
40
+ def load_model(self):
41
+ self.image_processor = CLIPImageProcessor(**cfg)
42
+ if 'xxlarge' in self.vision_tower_name:
43
+ self.vision_tower = convnext_xxlarge(self.vision_tower_name)
44
+ setattr(self.vision_tower, 'hidden_size', 3072)
45
+ else:
46
+ raise NotImplementedError
47
+
48
+ if self.freeze_vision:
49
+ self.vision_tower.requires_grad_(False)
50
+
51
+ # Hardcode
52
+ for s in self.vision_tower.stages:
53
+ s.grad_checkpointing = True
54
+
55
+ if self.input_image_size is not None:
56
+ self.image_processor.size=self.input_image_size
57
+ self.image_processor.crop_size={
58
+ 'height':self.input_image_size,
59
+ 'width': self.input_image_size
60
+ }
61
+
62
+ self.is_loaded = True
63
+
64
+ def feature_select(self, image_forward_outs):
65
+ image_features = image_forward_outs[self.select_layer]
66
+ return image_features
67
+
68
+ def forward_features(self, x):
69
+ x = self.vision_tower.stem(x)
70
+ image_forward_out=[]
71
+ for blk in self.vision_tower.stages:
72
+ x = blk(x)
73
+ b,c,h,w=x.shape
74
+ image_forward_out.append(x.view(b,c,-1).transpose(1,2))
75
+ return image_forward_out
76
+
77
+ def forward(self, images):
78
+ if self.freeze_vision:
79
+ with torch.no_grad():
80
+ image_features = self._forward_images(images)
81
+ else:
82
+ image_features = self._forward_images(images)
83
+
84
+ return image_features
85
+
86
+ def _forward_images(self, images):
87
+
88
+ image_forward_outs = self.forward_features(images.to(device=self.device, dtype=self.dtype))
89
+ image_features = self.feature_select(image_forward_outs)
90
+
91
+ return image_features
92
+
93
+ @property
94
+ def dummy_feature(self):
95
+ return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
96
+
97
+ @property
98
+ def dtype(self):
99
+ return next(self.vision_tower.parameters()).dtype
100
+
101
+ @property
102
+ def device(self):
103
+ return next(self.vision_tower.parameters()).device
104
+
105
+ @property
106
+ def config(self):
107
+ assert NotImplementedError
108
+ pass
109
+
110
+ @property
111
+ def num_attention_heads(self):
112
+ # as constant
113
+ return 16
114
+ @property
115
+ def num_layers(self):
116
+ # as constant
117
+ return 4
118
+ @property
119
+ def hidden_size(self):
120
+ return self.vision_tower.hidden_size
121
+
122
+ @property
123
+ def num_patches(self):
124
+ return (cfg['image_size'] // self.patch_embed.patch_size[0]) ** 2
eagle/model/multimodal_encoder/hr_clip_encoder.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Mostly copy-paste from LLaVA-HR
3
+ https://github.com/luogen1996/LLaVA-HR
4
+ """
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from torch.utils.checkpoint import checkpoint
9
+
10
+ from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig
11
+
12
+ import math
13
+ import torch
14
+ import torch.nn.functional as F
15
+ from typing import List, Optional
16
+
17
+
18
+ def forward_embeddings(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
19
+ batch_size = pixel_values.shape[0]
20
+ target_dtype = self.patch_embedding.weight.dtype
21
+ patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid]
22
+ patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
23
+
24
+ class_embeds = self.class_embedding.expand(batch_size, 1, -1)
25
+ embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
26
+ position_embeddings = self.position_embedding(self.position_ids)
27
+
28
+ if position_embeddings.shape[1]!=embeddings.shape[1]:
29
+ position_embeddings=resample_pos_embed(position_embeddings,embeddings.shape[1])
30
+
31
+ embeddings = embeddings + position_embeddings
32
+ return embeddings
33
+
34
+
35
+ def resample_pos_embed(
36
+ posemb,
37
+ new_size: int,
38
+ num_prefix_tokens: int = 1,
39
+ interpolation: str = 'bicubic',
40
+ antialias: bool = True,
41
+ verbose: bool = False,
42
+ ):
43
+ new_size=[int(math.sqrt(new_size-num_prefix_tokens)),int(math.sqrt(new_size-num_prefix_tokens))]
44
+ num_pos_tokens = posemb.shape[1] - num_prefix_tokens
45
+ old_size = int(math.sqrt(num_pos_tokens))
46
+ bs=posemb.shape[0]
47
+
48
+ if num_prefix_tokens:
49
+ posemb_prefix, posemb = posemb[:,:num_prefix_tokens], posemb[:,num_prefix_tokens:]
50
+ else:
51
+ posemb_prefix, posemb = None, posemb
52
+
53
+ # do the interpolation
54
+ embed_dim = posemb.shape[-1]
55
+ orig_dtype = posemb.dtype
56
+ posemb = posemb.float() # interpolate needs float32
57
+ posemb = posemb.reshape(bs, old_size, old_size, -1).permute(0, 3, 1, 2)
58
+ posemb = F.interpolate(posemb, size=new_size, mode=interpolation, antialias=antialias)
59
+ posemb = posemb.permute(0, 2, 3, 1).reshape(bs, -1, embed_dim)
60
+ posemb = posemb.to(dtype=orig_dtype)
61
+
62
+ # add back extra (class, etc) prefix tokens
63
+ if posemb_prefix is not None:
64
+ posemb = torch.cat([posemb_prefix, posemb],1)
65
+
66
+ if not torch.jit.is_scripting() and verbose:
67
+ print(f'Resized position embedding: {old_size} to {new_size}.')
68
+
69
+ return posemb
70
+
71
+ class HRCLIPVisionTower(nn.Module):
72
+ def __init__(self, vision_tower, args, delay_load=False):
73
+ super().__init__()
74
+
75
+ self.is_loaded = False
76
+ self.freeze_vision=args.freeze_vision
77
+ self.input_image_size=args.input_image_size
78
+ self.vision_tower_name = vision_tower
79
+ self.select_layer = args.mm_vision_select_layer
80
+ self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
81
+
82
+ if not delay_load:
83
+ self.load_model()
84
+ else:
85
+ self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)
86
+
87
+
88
+ def load_model(self):
89
+ self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
90
+ self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name)
91
+ # checkpointing for clip
92
+ self.vision_tower.vision_model.encoder.gradient_checkpointing =True
93
+
94
+ if self.freeze_vision:
95
+ self.vision_tower.requires_grad_(False)
96
+
97
+ cls_=self.vision_tower.vision_model.embeddings
98
+ bound_method = forward_embeddings.__get__(cls_, cls_.__class__)
99
+ setattr(cls_, 'forward', bound_method)
100
+
101
+ if self.input_image_size is not None:
102
+ self.image_processor.size=self.input_image_size
103
+ self.image_processor.crop_size={
104
+ 'height':self.input_image_size,
105
+ 'width': self.input_image_size
106
+ }
107
+
108
+ self.is_loaded = True
109
+
110
+ def forward(self, x):
111
+ # 448 image input
112
+ blks = self.vision_tower.vision_model.encoder.layers
113
+ x = self.vision_tower.vision_model.embeddings(x)
114
+ x = self.vision_tower.vision_model.pre_layrnorm(x[:, 1:])
115
+
116
+ # inference of fast branch
117
+ for blk in blks:
118
+ if self.training:
119
+ x=checkpoint(
120
+ blk.__call__,
121
+ x,
122
+ None,
123
+ None
124
+ )[0]
125
+ else:
126
+ x = blk(x, None, None)[0]
127
+
128
+ return x
129
+
130
+ @property
131
+ def dummy_feature(self):
132
+ return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
133
+
134
+ @property
135
+ def dtype(self):
136
+ return self.vision_tower.dtype
137
+
138
+ @property
139
+ def device(self):
140
+ return self.vision_tower.device
141
+
142
+
143
+ @property
144
+ def num_attention_heads(self):
145
+ return self.config.num_attention_heads
146
+ @property
147
+ def num_layers(self):
148
+ return self.config.num_hidden_layers
149
+ @property
150
+ def config(self):
151
+ if self.is_loaded:
152
+ return self.vision_tower.config
153
+ else:
154
+ return self.cfg_only
155
+
156
+ @property
157
+ def hidden_size(self):
158
+ return self.config.hidden_size
159
+
160
+ @property
161
+ def num_patches(self):
162
+ return (self.config.image_size // self.config.patch_size) ** 2
eagle/model/multimodal_encoder/multi_backbone_channel_concatenation_encoder.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.utils.checkpoint import checkpoint
4
+ from .convnext_encoder import ConvNextVisionTower
5
+ from .hr_clip_encoder import HRCLIPVisionTower
6
+ from .vision_models.eva_vit import EVAVITVisionTower
7
+ from .sam_encoder import SAMVisionTower
8
+ from .pix2struct_encoder import Pix2StructLargeVisionTower
9
+ import torch.nn.functional as F
10
+ from torch.nn.init import trunc_normal_
11
+ from copy import deepcopy
12
+ import random
13
+ import math
14
+
15
+ class MultiBackboneChannelConcatenationVisionTower(nn.Module):
16
+ def __init__(self,
17
+ vision_tower,
18
+ args,
19
+ grid_size=32):
20
+
21
+ super().__init__()
22
+
23
+ self.is_loaded = False
24
+ self.grid_size = grid_size
25
+ self.num_tokens = self.grid_size ** 2
26
+
27
+ vision_tower_name_list = vision_tower.split(";")
28
+ self.input_image_size = 1024 # hardcode
29
+ self.load_vision_towers(vision_tower_name_list, args)
30
+
31
+
32
+ def load_vision_towers(self, vision_tower_name_list, args):
33
+ self.vision_towers = nn.ModuleList()
34
+ for name in vision_tower_name_list:
35
+ if name == 'det-1024':
36
+ det_args = deepcopy(args)
37
+ det_args.input_image_size = 1024
38
+ det_args.freeze_vision = False
39
+ det_args.vision_tower_pretrained_from = '/lustre/fsw/portfolios/llmservice/users/fuxiaol/eva02_L_coco_det_sys_o365.pth'
40
+ det_vision_tower = EVAVITVisionTower("eva02-l-16", det_args)
41
+ det_vision_tower.load_model()
42
+ self.vision_towers.append(det_vision_tower)
43
+
44
+ elif name == 'convnext-1024':
45
+ ## ConvNeXt
46
+ convnext_args = deepcopy(args)
47
+ convnext_args.freeze_vision = False
48
+ convnext_args.input_image_size = 1024
49
+ convnext_vision_tower = "convnext_xxlarge.clip_laion2b_soup" # hardcode
50
+ convnext_vision_tower = ConvNextVisionTower(convnext_vision_tower,
51
+ convnext_args)
52
+ convnext_vision_tower.load_model()
53
+ self.vision_towers.append(convnext_vision_tower)
54
+
55
+ elif name == "sam-1024":
56
+ sam_args = deepcopy(args)
57
+ sam_args.freeze_vision = False
58
+ sam_args.input_image_size = 1024
59
+ sam_args.add_pixel_shuffle = True
60
+ sam_vision_tower = SAMVisionTower("SAM-L", sam_args)
61
+ sam_vision_tower.load_model()
62
+ self.vision_towers.append(sam_vision_tower)
63
+
64
+ elif name == 'pix2struct-1024':
65
+ pix_args = deepcopy(args)
66
+ #pix_args.freeze_vision = True
67
+ pix_args.input_image_size = 1024
68
+ pix_args.freeze_vision = False
69
+ pix_args.do_resize = True
70
+ pix_args.de_normalize = True
71
+ pix_vision_tower = Pix2StructLargeVisionTower("pix2struct-large", pix_args)
72
+ pix_vision_tower.load_model()
73
+ self.vision_towers.append(pix_vision_tower)
74
+
75
+ elif name == 'clip-448':
76
+ clip_args = deepcopy(args)
77
+ clip_args.input_image_size = 336 # actually 448, will have no effect
78
+ clip_args.freeze_vision = False
79
+ clip_vision_tower = HRCLIPVisionTower("openai/clip-vit-large-patch14-336", clip_args)
80
+ clip_vision_tower.load_model()
81
+ self.vision_towers.append(clip_vision_tower)
82
+
83
+ # a hardcode here, so we always use convnext in the vision encoder mixture
84
+ self.image_processor = convnext_vision_tower.image_processor
85
+ self.is_loaded = True
86
+
87
+ def load_model(self):
88
+ assert self.is_loaded, "All the vision encoders should be loaded during initialization!"
89
+
90
+ def forward(self, x):
91
+ features = []
92
+ for vision_tower in self.vision_towers:
93
+ if vision_tower.input_image_size != self.input_image_size:
94
+ resized_x = F.interpolate(x.float(),
95
+ size=(vision_tower.input_image_size, vision_tower.input_image_size),
96
+ mode='bilinear',
97
+ align_corners=True).to(dtype=x.dtype)
98
+ else:
99
+ resized_x = x
100
+ feature = vision_tower(resized_x)
101
+ if len(feature.shape) == 3: # b, n, c
102
+ b, n, c = feature.shape
103
+ if n == self.num_tokens:
104
+ features.append(feature)
105
+ continue
106
+
107
+ w = h = int(n**0.5)
108
+ feature = feature.transpose(1,2).reshape(b, c, h, w)
109
+ else:
110
+ b, c, h, w = feature.shape
111
+
112
+ if w != self.grid_size:
113
+ feature = F.interpolate(feature.float(), size=(self.grid_size, self.grid_size), mode='bilinear', align_corners=True).to(dtype=x.dtype)
114
+ features.append(feature.flatten(2,3).transpose(1,2))
115
+
116
+ features = torch.cat(features, dim=-1)
117
+
118
+ return features
119
+
120
+ @property
121
+ def dummy_feature(self):
122
+ return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
123
+
124
+ @property
125
+ def dtype(self):
126
+ return next(self.clip_vision_tower.parameters()).dtype
127
+
128
+ @property
129
+ def device(self):
130
+ return next(self.clip_vision_tower.parameters()).device
131
+
132
+ @property
133
+ def config(self):
134
+ assert NotImplementedError
135
+ pass
136
+
137
+ @property
138
+ def hidden_size(self):
139
+ return sum([_.hidden_size for _ in self.vision_towers])
140
+
141
+ @property
142
+ def num_patches(self):
143
+ return self.num_tokens
eagle/model/multimodal_encoder/pix2struct_encoder.py ADDED
@@ -0,0 +1,267 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ from PIL import Image
3
+ import torch
4
+ import torch.nn as nn
5
+ from transformers import AutoModel, CLIPImageProcessor
6
+ from PIL import Image
7
+ import requests
8
+ import torch.nn.functional as F
9
+ from transformers import AutoProcessor, Pix2StructVisionModel, Pix2StructProcessor, Pix2StructForConditionalGeneration
10
+
11
+ cfg={
12
+ "crop_size": 256,
13
+ "do_center_crop": True,
14
+ "do_normalize": True,
15
+ "do_resize": True,
16
+ "feature_extractor_type": "CLIPFeatureExtractor",
17
+ "image_mean": [
18
+ 0.48145466,
19
+ 0.4578275,
20
+ 0.40821073
21
+ ],
22
+ "image_std": [
23
+ 0.26862954,
24
+ 0.26130258,
25
+ 0.27577711
26
+ ],
27
+ "resample": 3,
28
+ "size": 256
29
+ }
30
+
31
+ '''
32
+ Pixel2Struct-Large Model (pretrained version)
33
+ '''
34
+ class Pix2StructLargeVisionTower(nn.Module):
35
+ def __init__(self, vision_tower, args, delay_load=False):
36
+ super().__init__()
37
+
38
+ self.is_loaded = False
39
+ self.vision_tower_name = vision_tower
40
+ self.do_resize = args.do_resize
41
+ self.de_normalize = args.de_normalize # de-normalize the input image and perform preprocessing with pix2struct processor
42
+ self.select_layer = args.mm_vision_select_layer # NOTE: not implemented yet, this parameter has no effect
43
+ self.input_image_size = args.input_image_size
44
+ self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
45
+ self.freeze_vision = args.freeze_vision
46
+
47
+ self.args = args
48
+ if not self.is_loaded:
49
+ self.load_model()
50
+
51
+ def load_model(self):
52
+ if self.is_loaded:
53
+ return
54
+ whole_model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-large")
55
+ self.vision_tower = whole_model.encoder
56
+ self.pix2struct_processor = AutoProcessor.from_pretrained("google/pix2struct-large")
57
+ self.pix2struct_processor.image_processor.is_vqa = False
58
+
59
+ self.image_processor = CLIPImageProcessor(**cfg)
60
+ if self.input_image_size is not None:
61
+ self.image_processor.size=self.input_image_size
62
+ self.image_processor.crop_size={
63
+ 'height':self.input_image_size,
64
+ 'width': self.input_image_size
65
+ }
66
+
67
+ if self.freeze_vision:
68
+ self.vision_tower.requires_grad_(False)
69
+
70
+ self.image_mean = torch.tensor(self.image_processor.image_mean).view(1, 3, 1, 1)
71
+ self.image_std = torch.tensor(self.image_processor.image_std).view(1, 3, 1, 1)
72
+
73
+ self.is_loaded = True
74
+
75
+ def feature_select(self, image_forward_outs):
76
+ image_features = image_forward_outs.hidden_states[self.select_layer] # [bs, n, c], cls at idx=0
77
+ if self.select_feature == 'patch':
78
+ image_features = image_features[:, 1:]
79
+ elif self.select_feature == 'cls_patch':
80
+ image_features = image_features
81
+ else:
82
+ raise ValueError(f'Unexpected select feature: {self.select_feature}')
83
+ return image_features
84
+
85
+ # @torch.no_grad()
86
+ def forward(self, images):
87
+
88
+ if self.de_normalize:
89
+ mean = self.image_mean.clone().view(1, 3, 1, 1).to(dtype=images.dtype, device=images.device)
90
+ std = self.image_std.clone().view(1, 3, 1, 1).to(dtype=images.dtype, device=images.device)
91
+ x = (images * std + mean) * 255.0
92
+ x = self.pix2struct_processor(images=x.float(), return_tensors="pt")
93
+
94
+ image_features = self.vision_tower(**(x.to(device=self.device, dtype=self.dtype))).last_hidden_state
95
+ bs, n, c = image_features.shape
96
+ image_features = image_features[:, :2025, :] # HARD CODE
97
+
98
+ if self.do_resize:
99
+ image_features = image_features.transpose(1,2).reshape(bs, c, 45, 45) # HARD CODE
100
+ image_features = F.interpolate(image_features.float(), size=(32, 32), mode='bilinear', align_corners=True).to(dtype=image_features.dtype) # HARD CODE
101
+ return image_features
102
+ else:
103
+ return image_features
104
+
105
+
106
+ @property
107
+ def dummy_feature(self):
108
+ return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
109
+
110
+ @property
111
+ def dtype(self):
112
+ return next(self.vision_tower.parameters()).dtype
113
+
114
+ @property
115
+ def device(self):
116
+ return next(self.vision_tower.parameters()).device
117
+
118
+ @property
119
+ def config(self):
120
+ return self.vision_tower.config
121
+
122
+ @property
123
+ def hidden_size(self):
124
+ #return self.config.hidden_size
125
+ hidden_dim = 1536
126
+ return hidden_dim
127
+
128
+ @property
129
+ def num_patches(self):
130
+ # return (self.config.image_size // self.config.patch_size) ** 2
131
+ return self.config['num_patches']
132
+
133
+
134
+ #main
135
+ if __name__ == "__main__":
136
+
137
+ '''
138
+ print('hello')
139
+ from PIL import Image
140
+ import requests
141
+ from transformers import AutoProcessor, Pix2StructVisionModel
142
+
143
+ model = Pix2StructVisionModel.from_pretrained("google/pix2struct-textcaps-base")
144
+ processor = AutoProcessor.from_pretrained("google/pix2struct-textcaps-base")
145
+ url = "http://images.cocodataset.org/val2017/000000039769.jpg"
146
+ image = Image.open("/lustre/fsw/portfolios/llmservice/users/fuxiaol/me.jpg")
147
+
148
+ for name, param in model.named_parameters():
149
+ param.requires_grad = False
150
+
151
+
152
+ #inputs = processor(images=image, return_tensors="pt")
153
+
154
+ image_processor = CLIPImageProcessor.from_pretrained('OpenGVLab/InternViT-6B-448px-V1-5')
155
+ pixel_values = image_processor(images=image, return_tensors='pt').pixel_values
156
+ pixel_values = torch.cat([pixel_values, pixel_values], dim=0)
157
+ #inputs = pixel_values.to(torch.bfloat16)
158
+
159
+ print('pixel_values:', pixel_values.size())
160
+
161
+ inputs = processor(images=pixel_values, max_patches=1024, return_tensors='pt')['flattened_patches']
162
+ print(inputs.size())
163
+ print(inputs.size())
164
+
165
+ outputs = model(inputs)
166
+
167
+ print(outputs.last_hidden_state.size())
168
+ '''
169
+
170
+ cfg={
171
+ "crop_size": 1024,
172
+ "do_center_crop": True,
173
+ "do_normalize": True,
174
+ "do_resize": True,
175
+ "feature_extractor_type": "CLIPFeatureExtractor",
176
+ "image_mean": [
177
+ 0.48145466,
178
+ 0.4578275,
179
+ 0.40821073
180
+ ],
181
+ "image_std": [
182
+ 0.26862954,
183
+ 0.26130258,
184
+ 0.27577711
185
+ ],
186
+ "resample": 3,
187
+ "size": 1024
188
+ }
189
+
190
+ from PIL import Image
191
+ import requests
192
+ from transformers import AutoProcessor, Pix2StructForConditionalGeneration
193
+ from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig
194
+ import torchvision.transforms as T
195
+
196
+ processor = AutoProcessor.from_pretrained("google/pix2struct-textcaps-large")
197
+ model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-textcaps-large")
198
+
199
+ #url = "https://www.ilankelman.org/stopsigns/australia.jpg"
200
+ #image = Image.open(requests.get(url, stream=True).raw)
201
+ image = Image.open("/lustre/fsw/portfolios/llmservice/users/fuxiaol/sample2.jpg")
202
+
203
+ image_processor= CLIPImageProcessor(**cfg)
204
+ pixel_values = image_processor(images=image, return_tensors='pt').pixel_values
205
+ print(pixel_values.size())
206
+
207
+
208
+ mean = [0.48145466, 0.4578275, 0.40821073]
209
+ std = [0.26862954, 0.26130258, 0.27577711]
210
+ mean = torch.tensor(mean).view(1, 3, 1, 1)
211
+ std = torch.tensor(std).view(1, 3, 1, 1)
212
+ pixel_values = pixel_values * std + mean
213
+ print(pixel_values.size())
214
+
215
+
216
+ #pixel_values.save('pix2image.jpg')
217
+ transform = T.ToPILImage()
218
+ img = transform(pixel_values.squeeze(0))
219
+ img.save('pix2image.jpg')
220
+
221
+
222
+
223
+
224
+ inputs = processor(images=pixel_values, max_patches=1024,return_tensors="pt")['flattened_patches']
225
+
226
+ # autoregressive generation
227
+ generated_ids = model.generate(inputs, max_new_tokens=50)
228
+ generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
229
+ print(generated_text)
230
+ #A stop sign is on a street corner.
231
+ #A stop sign is on a street corner.
232
+
233
+
234
+
235
+ '''
236
+ from PIL import Image
237
+ import requests
238
+ from transformers import AutoProcessor, CLIPModel
239
+
240
+ from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig
241
+
242
+
243
+
244
+ processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14-336")
245
+ model = CLIPVisionModel.from_pretrained('openai/clip-vit-large-patch14-336')
246
+
247
+ url = "http://images.cocodataset.org/val2017/000000039769.jpg"
248
+ image = Image.open(requests.get(url, stream=True).raw)
249
+ print(image)
250
+
251
+ inputs = processor(images=image, return_tensors="pt")
252
+
253
+ #image_features = model.get_image_features(**inputs)
254
+ outputs = model(**inputs,output_hidden_states=True)
255
+ print(outputs.hidden_states[-1].size())
256
+ print(outputs.hidden_states[-2].size())
257
+ print(outputs.hidden_states[-3].size())
258
+ '''
259
+
260
+
261
+
262
+
263
+
264
+ #sequence = processor.batch_decode(outputs, skip_special_tokens=True)[0]
265
+ #sequence = processor.post_process_generation(sequence, fix_markdown=False)
266
+ # note: we're using repr here such for the sake of printing the \n characters, feel free to just print the sequence
267
+ #print(repr(sequence))
eagle/model/multimodal_encoder/sam_encoder.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn.functional as F
4
+ import torch.utils.checkpoint
5
+ from torch import Tensor, nn
6
+
7
+ import transformers
8
+ from transformers import SamProcessor
9
+ from transformers import SamModel, SamVisionConfig, SamVisionConfig
10
+ from transformers import SamImageProcessor
11
+ from PIL import Image
12
+
13
+
14
+ # Copied from transformers.models.convnext.modeling_convnext.ConvNextLayerNorm with ConvNext->Sam
15
+ class SamLayerNorm(nn.Module):
16
+ r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
17
+ The ordering of the dimensions in the inputs. channels_last corresponds to inputs with shape (batch_size, height,
18
+ width, channels) while channels_first corresponds to inputs with shape (batch_size, channels, height, width).
19
+ """
20
+
21
+ def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
22
+ super().__init__()
23
+ self.weight = nn.Parameter(torch.ones(normalized_shape))
24
+ self.bias = nn.Parameter(torch.zeros(normalized_shape))
25
+ self.eps = eps
26
+ self.data_format = data_format
27
+ if self.data_format not in ["channels_last", "channels_first"]:
28
+ raise NotImplementedError(f"Unsupported data format: {self.data_format}")
29
+ self.normalized_shape = (normalized_shape,)
30
+
31
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
32
+ if self.data_format == "channels_last":
33
+ x = torch.nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
34
+ elif self.data_format == "channels_first":
35
+ input_dtype = x.dtype
36
+ x = x.float()
37
+ u = x.mean(1, keepdim=True)
38
+ s = (x - u).pow(2).mean(1, keepdim=True)
39
+ x = (x - u) / torch.sqrt(s + self.eps)
40
+ x = x.to(dtype=input_dtype)
41
+ x = self.weight[:, None, None] * x + self.bias[:, None, None]
42
+ return x
43
+
44
+
45
+
46
+ class ShortSamVisionNeck(nn.Module):
47
+ def __init__(self, config: SamVisionConfig):
48
+ super().__init__()
49
+ self.config = config
50
+
51
+ self.conv1 = nn.Conv2d(config.hidden_size, config.output_channels, kernel_size=1, bias=False)
52
+ self.layer_norm1 = SamLayerNorm(config.output_channels, data_format="channels_first")
53
+
54
+ def forward(self, hidden_states):
55
+ hidden_states = hidden_states.permute(0, 3, 1, 2)
56
+ hidden_states = self.conv1(hidden_states)
57
+ hidden_states = self.layer_norm1(hidden_states)
58
+ hidden_states = hidden_states.permute(0,2,3,1)
59
+ return hidden_states
60
+
61
+
62
+ class SAMVisionTower(nn.Module):
63
+ def __init__(self, vision_tower, args):
64
+ super().__init__()
65
+
66
+ self.args = args
67
+ self.is_loaded = False
68
+ self.vision_tower_name = vision_tower
69
+ self.input_image_size = args.input_image_size
70
+ self.pixel_shuffle = getattr(args, 'add_pixel_shuffle', False)
71
+
72
+ self.freeze = args.freeze_vision
73
+
74
+ self.load_model()
75
+
76
+ def load_model(self):
77
+ if self.is_loaded:
78
+ return
79
+
80
+ self.image_processor= SamProcessor.from_pretrained("facebook/sam-vit-large")
81
+ sam_model = SamModel.from_pretrained("facebook/sam-vit-large").vision_encoder
82
+ sam_model.neck = ShortSamVisionNeck(sam_model.config)
83
+ self.image_processor.preprocess = self.image_processor.__call__
84
+ self.image_processor.image_mean = [0.485,0.456,0.406]
85
+ self.vision_tower = sam_model
86
+
87
+ if self.freeze:
88
+ self.vision_tower.requires_grad_(False)
89
+
90
+ self.is_loaded = True
91
+
92
+
93
+ def forward(self, images):
94
+ if type(images) is list:
95
+ image_features = []
96
+ for image in images:
97
+ image_feature = self.vision_tower(image.to(device=self.device).unsqueeze(0))
98
+ image_features.append(image_feature)
99
+ else:
100
+ image_features = self.vision_tower(images.to(device=self.device)).last_hidden_state.flatten(start_dim=1, end_dim=2).to(device=self.device)
101
+
102
+ if self.pixel_shuffle:
103
+ b, n, c = image_features.shape
104
+ h = w = int(n ** 0.5)
105
+ image_features = image_features.transpose(1,2).reshape(b, c, h, w)
106
+ image_features = nn.functional.pixel_unshuffle(image_features, 2)
107
+
108
+ return image_features
109
+ @property
110
+ def dummy_feature(self):
111
+ return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
112
+
113
+ @property
114
+ def dtype(self):
115
+ return next(self.vision_tower.parameters()).dtype
116
+
117
+ @property
118
+ def device(self):
119
+ return next(self.vision_tower.parameters()).device
120
+
121
+ @property
122
+ def config(self):
123
+ # if self.is_loaded:
124
+ # return self.vision_tower.config
125
+ # else:
126
+ # return self.cfg_only
127
+ config_info = SamVisionConfig()
128
+ return SamVisionConfig()
129
+
130
+ @property
131
+ def hidden_size(self):
132
+ #return self.config.hidden_size
133
+ if self.pixel_shuffle:
134
+ hidden_size = 256 * 4
135
+ else:
136
+ hidden_size = 256
137
+ return hidden_size
138
+
139
+ @property
140
+ def num_patches(self):
141
+ # return (self.config.image_size // self.config.patch_size) ** 2
142
+ return self.config.num_patches
143
+
144
+
145
+ #main
146
+ if __name__ == "__main__":
147
+ sam_model = SamModel.from_pretrained("facebook/sam-vit-large").vision_encoder
148
+ #sam_model = SamModel.from_pretrained("facebook/sam-vit-large")
149
+ sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-large")
150
+ for name, param in sam_model.named_parameters():
151
+ param.requires_grad = False
152
+
153
+ #raw_image = torch.rand(1, 3, 224, 224).to('cuda')
154
+ raw_image = Image.open('/lustre/fsw/portfolios/llmservice/users/fuxiaol/image/me.jpg').convert("RGB")
155
+ inputs = sam_processor(raw_image, return_tensors="pt")
156
+ #print(inputs)
157
+ #print(inputs['pixel_values'])
158
+ out = sam_model(inputs['pixel_values'])
159
+
160
+ print(out[0].size())
161
+ #vision_config = SamVisionConfig()
162
+ #print('=============')
163
+ #print(vision_config.hidden_size)
164
+ #print('=============')
165
+ #print(out)
166
+
167
+
168
+ #print(out)
169
+ #print(out)
170
+ #config_vision
171
+ #vision_config = SamVisionConfig()
172
+ #print(sam_model.layers)
173
+ #print(vision_config)
eagle/model/multimodal_encoder/vision_models/__init__.py ADDED
File without changes
eagle/model/multimodal_encoder/vision_models/convnext.py ADDED
@@ -0,0 +1,1110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ ConvNeXt
2
+
3
+ Papers:
4
+ * `A ConvNet for the 2020s` - https://arxiv.org/pdf/2201.03545.pdf
5
+ @Article{liu2022convnet,
6
+ author = {Zhuang Liu and Hanzi Mao and Chao-Yuan Wu and Christoph Feichtenhofer and Trevor Darrell and Saining Xie},
7
+ title = {A ConvNet for the 2020s},
8
+ journal = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
9
+ year = {2022},
10
+ }
11
+
12
+ * `ConvNeXt-V2 - Co-designing and Scaling ConvNets with Masked Autoencoders` - https://arxiv.org/abs/2301.00808
13
+ @article{Woo2023ConvNeXtV2,
14
+ title={ConvNeXt V2: Co-designing and Scaling ConvNets with Masked Autoencoders},
15
+ author={Sanghyun Woo, Shoubhik Debnath, Ronghang Hu, Xinlei Chen, Zhuang Liu, In So Kweon and Saining Xie},
16
+ year={2023},
17
+ journal={arXiv preprint arXiv:2301.00808},
18
+ }
19
+
20
+ Original code and weights from:
21
+ * https://github.com/facebookresearch/ConvNeXt, original copyright below
22
+ * https://github.com/facebookresearch/ConvNeXt-V2, original copyright below
23
+
24
+ Model defs atto, femto, pico, nano and _ols / _hnf variants are timm originals.
25
+
26
+ Modifications and additions for timm hacked together by / Copyright 2022, Ross Wightman
27
+ """
28
+ # ConvNeXt
29
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
30
+ # All rights reserved.
31
+ # This source code is licensed under the MIT license
32
+
33
+ # ConvNeXt-V2
34
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
35
+ # All rights reserved.
36
+ # This source code is licensed under the license found in the
37
+ # LICENSE file in the root directory of this source tree (Attribution-NonCommercial 4.0 International (CC BY-NC 4.0))
38
+ # No code was used directly from ConvNeXt-V2, however the weights are CC BY-NC 4.0 so beware if using commercially.
39
+
40
+ from collections import OrderedDict
41
+ from functools import partial
42
+ from typing import Callable, Optional, Tuple, Union
43
+
44
+ import torch
45
+ import torch.nn as nn
46
+ # hack for huggingface spaces
47
+ torch.jit.script = lambda f: f
48
+
49
+ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
50
+ from timm.layers import trunc_normal_, AvgPool2dSame, DropPath, Mlp, GlobalResponseNormMlp, \
51
+ LayerNorm2d, LayerNorm, create_conv2d, get_act_layer, make_divisible, to_ntuple
52
+ from timm.layers import NormMlpClassifierHead, ClassifierHead
53
+ from timm.models._builder import build_model_with_cfg
54
+ from timm.models._manipulate import named_apply, checkpoint_seq
55
+ from timm.models._registry import generate_default_cfgs, register_model, register_model_deprecations
56
+
57
+ __all__ = ['ConvNeXt'] # model_registry will add each entrypoint fn to this
58
+
59
+
60
+ class Downsample(nn.Module):
61
+
62
+ def __init__(self, in_chs, out_chs, stride=1, dilation=1):
63
+ super().__init__()
64
+ avg_stride = stride if dilation == 1 else 1
65
+ if stride > 1 or dilation > 1:
66
+ avg_pool_fn = AvgPool2dSame if avg_stride == 1 and dilation > 1 else nn.AvgPool2d
67
+ self.pool = avg_pool_fn(2, avg_stride, ceil_mode=True, count_include_pad=False)
68
+ else:
69
+ self.pool = nn.Identity()
70
+
71
+ if in_chs != out_chs:
72
+ self.conv = create_conv2d(in_chs, out_chs, 1, stride=1)
73
+ else:
74
+ self.conv = nn.Identity()
75
+
76
+ def forward(self, x):
77
+ x = self.pool(x)
78
+ x = self.conv(x)
79
+ return x
80
+
81
+
82
+ class ConvNeXtBlock(nn.Module):
83
+ """ ConvNeXt Block
84
+ There are two equivalent implementations:
85
+ (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
86
+ (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
87
+
88
+ Unlike the official impl, this one allows choice of 1 or 2, 1x1 conv can be faster with appropriate
89
+ choice of LayerNorm impl, however as model size increases the tradeoffs appear to change and nn.Linear
90
+ is a better choice. This was observed with PyTorch 1.10 on 3090 GPU, it could change over time & w/ different HW.
91
+ """
92
+
93
+ def __init__(
94
+ self,
95
+ in_chs: int,
96
+ out_chs: Optional[int] = None,
97
+ kernel_size: int = 7,
98
+ stride: int = 1,
99
+ dilation: Union[int, Tuple[int, int]] = (1, 1),
100
+ mlp_ratio: float = 4,
101
+ conv_mlp: bool = False,
102
+ conv_bias: bool = True,
103
+ use_grn: bool = False,
104
+ ls_init_value: Optional[float] = 1e-6,
105
+ act_layer: Union[str, Callable] = 'gelu',
106
+ norm_layer: Optional[Callable] = None,
107
+ drop_path: float = 0.,
108
+ ):
109
+ """
110
+
111
+ Args:
112
+ in_chs: Block input channels.
113
+ out_chs: Block output channels (same as in_chs if None).
114
+ kernel_size: Depthwise convolution kernel size.
115
+ stride: Stride of depthwise convolution.
116
+ dilation: Tuple specifying input and output dilation of block.
117
+ mlp_ratio: MLP expansion ratio.
118
+ conv_mlp: Use 1x1 convolutions for MLP and a NCHW compatible norm layer if True.
119
+ conv_bias: Apply bias for all convolution (linear) layers.
120
+ use_grn: Use GlobalResponseNorm in MLP (from ConvNeXt-V2)
121
+ ls_init_value: Layer-scale init values, layer-scale applied if not None.
122
+ act_layer: Activation layer.
123
+ norm_layer: Normalization layer (defaults to LN if not specified).
124
+ drop_path: Stochastic depth probability.
125
+ """
126
+ super().__init__()
127
+ out_chs = out_chs or in_chs
128
+ dilation = to_ntuple(2)(dilation)
129
+ act_layer = get_act_layer(act_layer)
130
+ if not norm_layer:
131
+ norm_layer = LayerNorm2d if conv_mlp else LayerNorm
132
+ mlp_layer = partial(GlobalResponseNormMlp if use_grn else Mlp, use_conv=conv_mlp)
133
+ self.use_conv_mlp = conv_mlp
134
+ self.conv_dw = create_conv2d(
135
+ in_chs,
136
+ out_chs,
137
+ kernel_size=kernel_size,
138
+ stride=stride,
139
+ dilation=dilation[0],
140
+ depthwise=True,
141
+ bias=conv_bias,
142
+ )
143
+ self.norm = norm_layer(out_chs)
144
+ self.mlp = mlp_layer(out_chs, int(mlp_ratio * out_chs), act_layer=act_layer)
145
+ self.weight = nn.Parameter(ls_init_value * torch.ones(out_chs)) if ls_init_value is not None else None
146
+ if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]:
147
+ self.shortcut = Downsample(in_chs, out_chs, stride=stride, dilation=dilation[0])
148
+ else:
149
+ self.shortcut = nn.Identity()
150
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
151
+
152
+ def forward(self, x):
153
+ shortcut = x
154
+ x = self.conv_dw(x)
155
+ if self.use_conv_mlp:
156
+ x = self.norm(x)
157
+ x = self.mlp(x)
158
+ else:
159
+ x = x.permute(0, 2, 3, 1)
160
+ x = self.norm(x)
161
+ x = self.mlp(x)
162
+ x = x.permute(0, 3, 1, 2)
163
+ if self.weight is not None:
164
+ x = x.mul(self.weight.reshape(1, -1, 1, 1))
165
+
166
+ x = self.drop_path(x) + self.shortcut(shortcut)
167
+ return x
168
+
169
+
170
+ class ConvNeXtStage(nn.Module):
171
+
172
+ def __init__(
173
+ self,
174
+ in_chs,
175
+ out_chs,
176
+ kernel_size=7,
177
+ stride=2,
178
+ depth=2,
179
+ dilation=(1, 1),
180
+ drop_path_rates=None,
181
+ ls_init_value=1.0,
182
+ conv_mlp=False,
183
+ conv_bias=True,
184
+ use_grn=False,
185
+ act_layer='gelu',
186
+ norm_layer=None,
187
+ norm_layer_cl=None
188
+ ):
189
+ super().__init__()
190
+ self.grad_checkpointing = False
191
+
192
+ if in_chs != out_chs or stride > 1 or dilation[0] != dilation[1]:
193
+ ds_ks = 2 if stride > 1 or dilation[0] != dilation[1] else 1
194
+ pad = 'same' if dilation[1] > 1 else 0 # same padding needed if dilation used
195
+ self.downsample = nn.Sequential(
196
+ norm_layer(in_chs),
197
+ create_conv2d(
198
+ in_chs,
199
+ out_chs,
200
+ kernel_size=ds_ks,
201
+ stride=stride,
202
+ dilation=dilation[0],
203
+ padding=pad,
204
+ bias=conv_bias,
205
+ ),
206
+ )
207
+ in_chs = out_chs
208
+ else:
209
+ self.downsample = nn.Identity()
210
+
211
+ drop_path_rates = drop_path_rates or [0.] * depth
212
+ stage_blocks = []
213
+ for i in range(depth):
214
+ stage_blocks.append(ConvNeXtBlock(
215
+ in_chs=in_chs,
216
+ out_chs=out_chs,
217
+ kernel_size=kernel_size,
218
+ dilation=dilation[1],
219
+ drop_path=drop_path_rates[i],
220
+ ls_init_value=ls_init_value,
221
+ conv_mlp=conv_mlp,
222
+ conv_bias=conv_bias,
223
+ use_grn=use_grn,
224
+ act_layer=act_layer,
225
+ norm_layer=norm_layer if conv_mlp else norm_layer_cl,
226
+ ))
227
+ in_chs = out_chs
228
+ self.blocks = nn.Sequential(*stage_blocks)
229
+
230
+ def forward(self, x):
231
+ x = self.downsample(x)
232
+ if self.grad_checkpointing and not torch.jit.is_scripting():
233
+ x = checkpoint_seq(self.blocks, x)
234
+ else:
235
+ x = self.blocks(x)
236
+ return x
237
+
238
+
239
+ class ConvNeXt(nn.Module):
240
+ r""" ConvNeXt
241
+ A PyTorch impl of : `A ConvNet for the 2020s` - https://arxiv.org/pdf/2201.03545.pdf
242
+ """
243
+
244
+ def __init__(
245
+ self,
246
+ in_chans: int = 3,
247
+ num_classes: int = 1000,
248
+ global_pool: str = 'avg',
249
+ output_stride: int = 32,
250
+ depths: Tuple[int, ...] = (3, 3, 9, 3),
251
+ dims: Tuple[int, ...] = (96, 192, 384, 768),
252
+ kernel_sizes: Union[int, Tuple[int, ...]] = 7,
253
+ ls_init_value: Optional[float] = 1e-6,
254
+ stem_type: str = 'patch',
255
+ patch_size: int = 4,
256
+ head_init_scale: float = 1.,
257
+ head_norm_first: bool = False,
258
+ head_hidden_size: Optional[int] = None,
259
+ conv_mlp: bool = False,
260
+ conv_bias: bool = True,
261
+ use_grn: bool = False,
262
+ act_layer: Union[str, Callable] = 'gelu',
263
+ norm_layer: Optional[Union[str, Callable]] = None,
264
+ norm_eps: Optional[float] = None,
265
+ drop_rate: float = 0.,
266
+ drop_path_rate: float = 0.,
267
+ ):
268
+ """
269
+ Args:
270
+ in_chans: Number of input image channels.
271
+ num_classes: Number of classes for classification head.
272
+ global_pool: Global pooling type.
273
+ output_stride: Output stride of network, one of (8, 16, 32).
274
+ depths: Number of blocks at each stage.
275
+ dims: Feature dimension at each stage.
276
+ kernel_sizes: Depthwise convolution kernel-sizes for each stage.
277
+ ls_init_value: Init value for Layer Scale, disabled if None.
278
+ stem_type: Type of stem.
279
+ patch_size: Stem patch size for patch stem.
280
+ head_init_scale: Init scaling value for classifier weights and biases.
281
+ head_norm_first: Apply normalization before global pool + head.
282
+ head_hidden_size: Size of MLP hidden layer in head if not None and head_norm_first == False.
283
+ conv_mlp: Use 1x1 conv in MLP, improves speed for small networks w/ chan last.
284
+ conv_bias: Use bias layers w/ all convolutions.
285
+ use_grn: Use Global Response Norm (ConvNeXt-V2) in MLP.
286
+ act_layer: Activation layer type.
287
+ norm_layer: Normalization layer type.
288
+ drop_rate: Head pre-classifier dropout rate.
289
+ drop_path_rate: Stochastic depth drop rate.
290
+ """
291
+ super().__init__()
292
+ assert output_stride in (8, 16, 32)
293
+ kernel_sizes = to_ntuple(4)(kernel_sizes)
294
+ if norm_layer is None:
295
+ norm_layer = LayerNorm2d
296
+ norm_layer_cl = norm_layer if conv_mlp else LayerNorm
297
+ if norm_eps is not None:
298
+ norm_layer = partial(norm_layer, eps=norm_eps)
299
+ norm_layer_cl = partial(norm_layer_cl, eps=norm_eps)
300
+ else:
301
+ assert conv_mlp,\
302
+ 'If a norm_layer is specified, conv MLP must be used so all norm expect rank-4, channels-first input'
303
+ norm_layer_cl = norm_layer
304
+ if norm_eps is not None:
305
+ norm_layer_cl = partial(norm_layer_cl, eps=norm_eps)
306
+
307
+ self.num_classes = num_classes
308
+ self.drop_rate = drop_rate
309
+ self.feature_info = []
310
+
311
+ assert stem_type in ('patch', 'overlap', 'overlap_tiered')
312
+ if stem_type == 'patch':
313
+ # NOTE: this stem is a minimal form of ViT PatchEmbed, as used in SwinTransformer w/ patch_size = 4
314
+ self.stem = nn.Sequential(
315
+ nn.Conv2d(in_chans, dims[0], kernel_size=patch_size, stride=patch_size, bias=conv_bias),
316
+ norm_layer(dims[0]),
317
+ )
318
+ stem_stride = patch_size
319
+ else:
320
+ mid_chs = make_divisible(dims[0] // 2) if 'tiered' in stem_type else dims[0]
321
+ self.stem = nn.Sequential(
322
+ nn.Conv2d(in_chans, mid_chs, kernel_size=3, stride=2, padding=1, bias=conv_bias),
323
+ nn.Conv2d(mid_chs, dims[0], kernel_size=3, stride=2, padding=1, bias=conv_bias),
324
+ norm_layer(dims[0]),
325
+ )
326
+ stem_stride = 4
327
+
328
+ self.stages = nn.Sequential()
329
+ dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
330
+ stages = []
331
+ prev_chs = dims[0]
332
+ curr_stride = stem_stride
333
+ dilation = 1
334
+ # 4 feature resolution stages, each consisting of multiple residual blocks
335
+ for i in range(4):
336
+ stride = 2 if curr_stride == 2 or i > 0 else 1
337
+ if curr_stride >= output_stride and stride > 1:
338
+ dilation *= stride
339
+ stride = 1
340
+ curr_stride *= stride
341
+ first_dilation = 1 if dilation in (1, 2) else 2
342
+ out_chs = dims[i]
343
+ stages.append(ConvNeXtStage(
344
+ prev_chs,
345
+ out_chs,
346
+ kernel_size=kernel_sizes[i],
347
+ stride=stride,
348
+ dilation=(first_dilation, dilation),
349
+ depth=depths[i],
350
+ drop_path_rates=dp_rates[i],
351
+ ls_init_value=ls_init_value,
352
+ conv_mlp=conv_mlp,
353
+ conv_bias=conv_bias,
354
+ use_grn=use_grn,
355
+ act_layer=act_layer,
356
+ norm_layer=norm_layer,
357
+ norm_layer_cl=norm_layer_cl,
358
+ ))
359
+ prev_chs = out_chs
360
+ # NOTE feature_info use currently assumes stage 0 == stride 1, rest are stride 2
361
+ self.feature_info += [dict(num_chs=prev_chs, reduction=curr_stride, module=f'stages.{i}')]
362
+ self.stages = nn.Sequential(*stages)
363
+ self.num_features = prev_chs
364
+
365
+ # if head_norm_first == true, norm -> global pool -> fc ordering, like most other nets
366
+ # otherwise pool -> norm -> fc, the default ConvNeXt ordering (pretrained FB weights)
367
+ if head_norm_first:
368
+ assert not head_hidden_size
369
+ self.norm_pre = norm_layer(self.num_features)
370
+ self.head = ClassifierHead(
371
+ self.num_features,
372
+ num_classes,
373
+ pool_type=global_pool,
374
+ drop_rate=self.drop_rate,
375
+ )
376
+ else:
377
+ self.norm_pre = nn.Identity()
378
+ self.head = NormMlpClassifierHead(
379
+ self.num_features,
380
+ num_classes,
381
+ hidden_size=head_hidden_size,
382
+ pool_type=global_pool,
383
+ drop_rate=self.drop_rate,
384
+ norm_layer=norm_layer,
385
+ act_layer='gelu',
386
+ )
387
+ named_apply(partial(_init_weights, head_init_scale=head_init_scale), self)
388
+
389
+ @torch.jit.ignore
390
+ def group_matcher(self, coarse=False):
391
+ return dict(
392
+ stem=r'^stem',
393
+ blocks=r'^stages\.(\d+)' if coarse else [
394
+ (r'^stages\.(\d+)\.downsample', (0,)), # blocks
395
+ (r'^stages\.(\d+)\.blocks\.(\d+)', None),
396
+ (r'^norm_pre', (99999,))
397
+ ]
398
+ )
399
+
400
+ @torch.jit.ignore
401
+ def set_grad_checkpointing(self, enable=True):
402
+ for s in self.stages:
403
+ s.grad_checkpointing = enable
404
+
405
+ @torch.jit.ignore
406
+ def get_classifier(self):
407
+ return self.head.fc
408
+
409
+ def reset_classifier(self, num_classes=0, global_pool=None):
410
+ self.head.reset(num_classes, global_pool)
411
+
412
+ def forward_features(self, x):
413
+ x = self.stem(x)
414
+ x = self.stages(x)
415
+ x = self.norm_pre(x)
416
+ return x
417
+
418
+ def forward_head(self, x, pre_logits: bool = False):
419
+ return self.head(x, pre_logits=True) if pre_logits else self.head(x)
420
+
421
+ def forward(self, x):
422
+ x = self.forward_features(x)
423
+ x = self.forward_head(x)
424
+ return x
425
+
426
+
427
+ def _init_weights(module, name=None, head_init_scale=1.0):
428
+ if isinstance(module, nn.Conv2d):
429
+ trunc_normal_(module.weight, std=.02)
430
+ if module.bias is not None:
431
+ nn.init.zeros_(module.bias)
432
+ elif isinstance(module, nn.Linear):
433
+ trunc_normal_(module.weight, std=.02)
434
+ nn.init.zeros_(module.bias)
435
+ if name and 'head.' in name:
436
+ module.weight.data.mul_(head_init_scale)
437
+ module.bias.data.mul_(head_init_scale)
438
+
439
+
440
+ def checkpoint_filter_fn(state_dict, model):
441
+ """ Remap FB checkpoints -> timm """
442
+ if 'head.norm.weight' in state_dict or 'norm_pre.weight' in state_dict:
443
+ out_dict={}
444
+ out_dict = {k.replace('gamma', 'weight'): v for k, v in state_dict.items()}
445
+ return out_dict # non-FB checkpoint
446
+ if 'model' in state_dict:
447
+ state_dict = state_dict['model']
448
+
449
+ out_dict = {}
450
+ if 'visual.trunk.stem.0.weight' in state_dict:
451
+ out_dict = {k.replace('visual.trunk.', '').replace('gamma', 'weight'): v for k, v in state_dict.items() if
452
+ k.startswith('visual.trunk.')}
453
+
454
+ if 'visual.head.proj.weight' in state_dict:
455
+ out_dict['head.fc.weight'] = state_dict['visual.head.proj.weight']
456
+ out_dict['head.fc.bias'] = torch.zeros(state_dict['visual.head.proj.weight'].shape[0])
457
+ elif 'visual.head.mlp.fc1.weight' in state_dict:
458
+ out_dict['head.pre_logits.fc.weight'] = state_dict['visual.head.mlp.fc1.weight']
459
+ out_dict['head.pre_logits.fc.bias'] = state_dict['visual.head.mlp.fc1.bias']
460
+ out_dict['head.fc.weight'] = state_dict['visual.head.mlp.fc2.weight']
461
+ out_dict['head.fc.bias'] = torch.zeros(state_dict['visual.head.mlp.fc2.weight'].shape[0])
462
+ return out_dict
463
+
464
+ import re
465
+ for k, v in state_dict.items():
466
+ k = k.replace('downsample_layers.0.', 'stem.')
467
+ k = re.sub(r'stages.([0-9]+).([0-9]+)', r'stages.\1.blocks.\2', k)
468
+ k = re.sub(r'downsample_layers.([0-9]+).([0-9]+)', r'stages.\1.downsample.\2', k)
469
+ k = k.replace('dwconv', 'conv_dw')
470
+ k = k.replace('pwconv', 'mlp.fc')
471
+ if 'grn' in k:
472
+ k = k.replace('grn.beta', 'mlp.grn.bias')
473
+ k = k.replace('grn.gamma', 'mlp.grn.weight')
474
+ v = v.reshape(v.shape[-1])
475
+ k = k.replace('head.', 'head.fc.')
476
+ if k.startswith('norm.'):
477
+ k = k.replace('norm', 'head.norm')
478
+ if v.ndim == 2 and 'head' not in k:
479
+ model_shape = model.state_dict()[k].shape
480
+ v = v.reshape(model_shape)
481
+ k=k.replace('gamma','weight')
482
+ out_dict[k] = v
483
+
484
+ return out_dict
485
+
486
+
487
+ def _create_convnext(variant, pretrained=False, **kwargs):
488
+ if kwargs.get('pretrained_cfg', '') == 'fcmae':
489
+ # NOTE fcmae pretrained weights have no classifier or final norm-layer (`head.norm`)
490
+ # This is workaround loading with num_classes=0 w/o removing norm-layer.
491
+ kwargs.setdefault('pretrained_strict', False)
492
+
493
+ model = build_model_with_cfg(
494
+ ConvNeXt, variant, pretrained,
495
+ pretrained_filter_fn=checkpoint_filter_fn,
496
+ feature_cfg=dict(out_indices=(0, 1, 2, 3), flatten_sequential=True),
497
+ **kwargs)
498
+ return model
499
+
500
+
501
+ def _cfg(url='', **kwargs):
502
+ return {
503
+ 'url': url,
504
+ 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
505
+ 'crop_pct': 0.875, 'interpolation': 'bicubic',
506
+ 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
507
+ 'first_conv': 'stem.0', 'classifier': 'head.fc',
508
+ **kwargs
509
+ }
510
+
511
+
512
+ def _cfgv2(url='', **kwargs):
513
+ return {
514
+ 'url': url,
515
+ 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
516
+ 'crop_pct': 0.875, 'interpolation': 'bicubic',
517
+ 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
518
+ 'first_conv': 'stem.0', 'classifier': 'head.fc',
519
+ 'license': 'cc-by-nc-4.0', 'paper_ids': 'arXiv:2301.00808',
520
+ 'paper_name': 'ConvNeXt-V2: Co-designing and Scaling ConvNets with Masked Autoencoders',
521
+ 'origin_url': 'https://github.com/facebookresearch/ConvNeXt-V2',
522
+ **kwargs
523
+ }
524
+
525
+
526
+ default_cfgs = generate_default_cfgs({
527
+ # timm specific variants
528
+ 'convnext_tiny.in12k_ft_in1k': _cfg(
529
+ hf_hub_id='timm/',
530
+ crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
531
+ 'convnext_small.in12k_ft_in1k': _cfg(
532
+ hf_hub_id='timm/',
533
+ crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
534
+
535
+ 'convnext_atto.d2_in1k': _cfg(
536
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_atto_d2-01bb0f51.pth',
537
+ hf_hub_id='timm/',
538
+ test_input_size=(3, 288, 288), test_crop_pct=0.95),
539
+ 'convnext_atto_ols.a2_in1k': _cfg(
540
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_atto_ols_a2-78d1c8f3.pth',
541
+ hf_hub_id='timm/',
542
+ test_input_size=(3, 288, 288), test_crop_pct=0.95),
543
+ 'convnext_femto.d1_in1k': _cfg(
544
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_femto_d1-d71d5b4c.pth',
545
+ hf_hub_id='timm/',
546
+ test_input_size=(3, 288, 288), test_crop_pct=0.95),
547
+ 'convnext_femto_ols.d1_in1k': _cfg(
548
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_femto_ols_d1-246bf2ed.pth',
549
+ hf_hub_id='timm/',
550
+ test_input_size=(3, 288, 288), test_crop_pct=0.95),
551
+ 'convnext_pico.d1_in1k': _cfg(
552
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_pico_d1-10ad7f0d.pth',
553
+ hf_hub_id='timm/',
554
+ test_input_size=(3, 288, 288), test_crop_pct=0.95),
555
+ 'convnext_pico_ols.d1_in1k': _cfg(
556
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_pico_ols_d1-611f0ca7.pth',
557
+ hf_hub_id='timm/',
558
+ crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
559
+ 'convnext_nano.in12k_ft_in1k': _cfg(
560
+ hf_hub_id='timm/',
561
+ crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
562
+ 'convnext_nano.d1h_in1k': _cfg(
563
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_nano_d1h-7eb4bdea.pth',
564
+ hf_hub_id='timm/',
565
+ crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
566
+ 'convnext_nano_ols.d1h_in1k': _cfg(
567
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_nano_ols_d1h-ae424a9a.pth',
568
+ hf_hub_id='timm/',
569
+ crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
570
+ 'convnext_tiny_hnf.a2h_in1k': _cfg(
571
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_tiny_hnf_a2h-ab7e9df2.pth',
572
+ hf_hub_id='timm/',
573
+ crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
574
+
575
+ 'convnext_tiny.in12k_ft_in1k_384': _cfg(
576
+ hf_hub_id='timm/',
577
+ input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
578
+ 'convnext_small.in12k_ft_in1k_384': _cfg(
579
+ hf_hub_id='timm/',
580
+ input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
581
+
582
+ 'convnext_nano.in12k': _cfg(
583
+ hf_hub_id='timm/',
584
+ crop_pct=0.95, num_classes=11821),
585
+ 'convnext_tiny.in12k': _cfg(
586
+ hf_hub_id='timm/',
587
+ crop_pct=0.95, num_classes=11821),
588
+ 'convnext_small.in12k': _cfg(
589
+ hf_hub_id='timm/',
590
+ crop_pct=0.95, num_classes=11821),
591
+
592
+ 'convnext_tiny.fb_in22k_ft_in1k': _cfg(
593
+ url='https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_1k_224.pth',
594
+ hf_hub_id='timm/',
595
+ test_input_size=(3, 288, 288), test_crop_pct=1.0),
596
+ 'convnext_small.fb_in22k_ft_in1k': _cfg(
597
+ url='https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_1k_224.pth',
598
+ hf_hub_id='timm/',
599
+ test_input_size=(3, 288, 288), test_crop_pct=1.0),
600
+ 'convnext_base.fb_in22k_ft_in1k': _cfg(
601
+ url='https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_224.pth',
602
+ hf_hub_id='timm/',
603
+ test_input_size=(3, 288, 288), test_crop_pct=1.0),
604
+ 'convnext_large.fb_in22k_ft_in1k': _cfg(
605
+ url='https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_1k_224.pth',
606
+ hf_hub_id='timm/',
607
+ test_input_size=(3, 288, 288), test_crop_pct=1.0),
608
+ 'convnext_xlarge.fb_in22k_ft_in1k': _cfg(
609
+ url='https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_1k_224_ema.pth',
610
+ hf_hub_id='timm/',
611
+ test_input_size=(3, 288, 288), test_crop_pct=1.0),
612
+
613
+ 'convnext_tiny.fb_in1k': _cfg(
614
+ url="https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth",
615
+ hf_hub_id='timm/',
616
+ test_input_size=(3, 288, 288), test_crop_pct=1.0),
617
+ 'convnext_small.fb_in1k': _cfg(
618
+ url="https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth",
619
+ hf_hub_id='timm/',
620
+ test_input_size=(3, 288, 288), test_crop_pct=1.0),
621
+ 'convnext_base.fb_in1k': _cfg(
622
+ url="https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth",
623
+ hf_hub_id='timm/',
624
+ test_input_size=(3, 288, 288), test_crop_pct=1.0),
625
+ 'convnext_large.fb_in1k': _cfg(
626
+ url="https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth",
627
+ hf_hub_id='timm/',
628
+ test_input_size=(3, 288, 288), test_crop_pct=1.0),
629
+
630
+ 'convnext_tiny.fb_in22k_ft_in1k_384': _cfg(
631
+ url='https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_1k_384.pth',
632
+ hf_hub_id='timm/',
633
+ input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
634
+ 'convnext_small.fb_in22k_ft_in1k_384': _cfg(
635
+ url='https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_1k_384.pth',
636
+ hf_hub_id='timm/',
637
+ input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
638
+ 'convnext_base.fb_in22k_ft_in1k_384': _cfg(
639
+ url='https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_384.pth',
640
+ hf_hub_id='timm/',
641
+ input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
642
+ 'convnext_large.fb_in22k_ft_in1k_384': _cfg(
643
+ url='https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_1k_384.pth',
644
+ hf_hub_id='timm/',
645
+ input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
646
+ 'convnext_xlarge.fb_in22k_ft_in1k_384': _cfg(
647
+ url='https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_1k_384_ema.pth',
648
+ hf_hub_id='timm/',
649
+ input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
650
+
651
+ 'convnext_tiny.fb_in22k': _cfg(
652
+ url="https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_224.pth",
653
+ hf_hub_id='timm/',
654
+ num_classes=21841),
655
+ 'convnext_small.fb_in22k': _cfg(
656
+ url="https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_224.pth",
657
+ hf_hub_id='timm/',
658
+ num_classes=21841),
659
+ 'convnext_base.fb_in22k': _cfg(
660
+ url="https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth",
661
+ hf_hub_id='timm/',
662
+ num_classes=21841),
663
+ 'convnext_large.fb_in22k': _cfg(
664
+ url="https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth",
665
+ hf_hub_id='timm/',
666
+ num_classes=21841),
667
+ 'convnext_xlarge.fb_in22k': _cfg(
668
+ url="https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth",
669
+ hf_hub_id='timm/',
670
+ num_classes=21841),
671
+
672
+ 'convnextv2_nano.fcmae_ft_in22k_in1k': _cfgv2(
673
+ url='https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_nano_22k_224_ema.pt',
674
+ hf_hub_id='timm/',
675
+ test_input_size=(3, 288, 288), test_crop_pct=1.0),
676
+ 'convnextv2_nano.fcmae_ft_in22k_in1k_384': _cfgv2(
677
+ url='https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_nano_22k_384_ema.pt',
678
+ hf_hub_id='timm/',
679
+ input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
680
+ 'convnextv2_tiny.fcmae_ft_in22k_in1k': _cfgv2(
681
+ url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_tiny_22k_224_ema.pt",
682
+ hf_hub_id='timm/',
683
+ test_input_size=(3, 288, 288), test_crop_pct=1.0),
684
+ 'convnextv2_tiny.fcmae_ft_in22k_in1k_384': _cfgv2(
685
+ url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_tiny_22k_384_ema.pt",
686
+ hf_hub_id='timm/',
687
+ input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
688
+ 'convnextv2_base.fcmae_ft_in22k_in1k': _cfgv2(
689
+ url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_base_22k_224_ema.pt",
690
+ hf_hub_id='timm/',
691
+ test_input_size=(3, 288, 288), test_crop_pct=1.0),
692
+ 'convnextv2_base.fcmae_ft_in22k_in1k_384': _cfgv2(
693
+ url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_base_22k_384_ema.pt",
694
+ hf_hub_id='timm/',
695
+ input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
696
+ 'convnextv2_large.fcmae_ft_in22k_in1k': _cfgv2(
697
+ url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_large_22k_224_ema.pt",
698
+ hf_hub_id='timm/',
699
+ test_input_size=(3, 288, 288), test_crop_pct=1.0),
700
+ 'convnextv2_large.fcmae_ft_in22k_in1k_384': _cfgv2(
701
+ url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_large_22k_384_ema.pt",
702
+ hf_hub_id='timm/',
703
+ input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
704
+ 'convnextv2_huge.fcmae_ft_in22k_in1k_384': _cfgv2(
705
+ url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_huge_22k_384_ema.pt",
706
+ hf_hub_id='timm/',
707
+ input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
708
+ 'convnextv2_huge.fcmae_ft_in22k_in1k_512': _cfgv2(
709
+ url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_huge_22k_512_ema.pt",
710
+ hf_hub_id='timm/',
711
+ input_size=(3, 512, 512), pool_size=(15, 15), crop_pct=1.0, crop_mode='squash'),
712
+
713
+ 'convnextv2_atto.fcmae_ft_in1k': _cfgv2(
714
+ url='https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_atto_1k_224_ema.pt',
715
+ hf_hub_id='timm/',
716
+ test_input_size=(3, 288, 288), test_crop_pct=0.95),
717
+ 'convnextv2_femto.fcmae_ft_in1k': _cfgv2(
718
+ url='https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_femto_1k_224_ema.pt',
719
+ hf_hub_id='timm/',
720
+ test_input_size=(3, 288, 288), test_crop_pct=0.95),
721
+ 'convnextv2_pico.fcmae_ft_in1k': _cfgv2(
722
+ url='https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_pico_1k_224_ema.pt',
723
+ hf_hub_id='timm/',
724
+ test_input_size=(3, 288, 288), test_crop_pct=0.95),
725
+ 'convnextv2_nano.fcmae_ft_in1k': _cfgv2(
726
+ url='https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_nano_1k_224_ema.pt',
727
+ hf_hub_id='timm/',
728
+ test_input_size=(3, 288, 288), test_crop_pct=1.0),
729
+ 'convnextv2_tiny.fcmae_ft_in1k': _cfgv2(
730
+ url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_tiny_1k_224_ema.pt",
731
+ hf_hub_id='timm/',
732
+ test_input_size=(3, 288, 288), test_crop_pct=1.0),
733
+ 'convnextv2_base.fcmae_ft_in1k': _cfgv2(
734
+ url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_base_1k_224_ema.pt",
735
+ hf_hub_id='timm/',
736
+ test_input_size=(3, 288, 288), test_crop_pct=1.0),
737
+ 'convnextv2_large.fcmae_ft_in1k': _cfgv2(
738
+ url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_large_1k_224_ema.pt",
739
+ hf_hub_id='timm/',
740
+ test_input_size=(3, 288, 288), test_crop_pct=1.0),
741
+ 'convnextv2_huge.fcmae_ft_in1k': _cfgv2(
742
+ url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_huge_1k_224_ema.pt",
743
+ hf_hub_id='timm/',
744
+ test_input_size=(3, 288, 288), test_crop_pct=1.0),
745
+
746
+ 'convnextv2_atto.fcmae': _cfgv2(
747
+ url='https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_atto_1k_224_fcmae.pt',
748
+ hf_hub_id='timm/',
749
+ num_classes=0),
750
+ 'convnextv2_femto.fcmae': _cfgv2(
751
+ url='https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_femto_1k_224_fcmae.pt',
752
+ hf_hub_id='timm/',
753
+ num_classes=0),
754
+ 'convnextv2_pico.fcmae': _cfgv2(
755
+ url='https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_pico_1k_224_fcmae.pt',
756
+ hf_hub_id='timm/',
757
+ num_classes=0),
758
+ 'convnextv2_nano.fcmae': _cfgv2(
759
+ url='https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_nano_1k_224_fcmae.pt',
760
+ hf_hub_id='timm/',
761
+ num_classes=0),
762
+ 'convnextv2_tiny.fcmae': _cfgv2(
763
+ url="https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_tiny_1k_224_fcmae.pt",
764
+ hf_hub_id='timm/',
765
+ num_classes=0),
766
+ 'convnextv2_base.fcmae': _cfgv2(
767
+ url="https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_base_1k_224_fcmae.pt",
768
+ hf_hub_id='timm/',
769
+ num_classes=0),
770
+ 'convnextv2_large.fcmae': _cfgv2(
771
+ url="https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_large_1k_224_fcmae.pt",
772
+ hf_hub_id='timm/',
773
+ num_classes=0),
774
+ 'convnextv2_huge.fcmae': _cfgv2(
775
+ url="https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_huge_1k_224_fcmae.pt",
776
+ hf_hub_id='timm/',
777
+ num_classes=0),
778
+
779
+ 'convnextv2_small.untrained': _cfg(),
780
+
781
+ # CLIP weights, fine-tuned on in1k or in12k + in1k
782
+ 'convnext_base.clip_laion2b_augreg_ft_in12k_in1k': _cfg(
783
+ hf_hub_id='timm/',
784
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
785
+ input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0),
786
+ 'convnext_base.clip_laion2b_augreg_ft_in12k_in1k_384': _cfg(
787
+ hf_hub_id='timm/',
788
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
789
+ input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
790
+ 'convnext_large_mlp.clip_laion2b_soup_ft_in12k_in1k_320': _cfg(
791
+ hf_hub_id='timm/',
792
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
793
+ input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0),
794
+ 'convnext_large_mlp.clip_laion2b_soup_ft_in12k_in1k_384': _cfg(
795
+ hf_hub_id='timm/',
796
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
797
+ input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
798
+
799
+ 'convnext_base.clip_laion2b_augreg_ft_in1k': _cfg(
800
+ hf_hub_id='timm/',
801
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
802
+ input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0),
803
+ 'convnext_base.clip_laiona_augreg_ft_in1k_384': _cfg(
804
+ hf_hub_id='timm/',
805
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
806
+ input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
807
+ 'convnext_large_mlp.clip_laion2b_augreg_ft_in1k': _cfg(
808
+ hf_hub_id='timm/',
809
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
810
+ input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0
811
+ ),
812
+ 'convnext_large_mlp.clip_laion2b_augreg_ft_in1k_384': _cfg(
813
+ hf_hub_id='timm/',
814
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
815
+ input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'
816
+ ),
817
+ 'convnext_xxlarge.clip_laion2b_soup_ft_in1k': _cfg(
818
+ hf_hub_id='timm/',
819
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
820
+ input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0),
821
+
822
+ 'convnext_base.clip_laion2b_augreg_ft_in12k': _cfg(
823
+ hf_hub_id='timm/',
824
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821,
825
+ input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0),
826
+ 'convnext_large_mlp.clip_laion2b_soup_ft_in12k_320': _cfg(
827
+ hf_hub_id='timm/',
828
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821,
829
+ input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0),
830
+ 'convnext_large_mlp.clip_laion2b_augreg_ft_in12k_384': _cfg(
831
+ hf_hub_id='timm/',
832
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821,
833
+ input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
834
+ 'convnext_large_mlp.clip_laion2b_soup_ft_in12k_384': _cfg(
835
+ hf_hub_id='timm/',
836
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821,
837
+ input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
838
+ 'convnext_xxlarge.clip_laion2b_soup_ft_in12k': _cfg(
839
+ hf_hub_id='timm/',
840
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821,
841
+ input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0),
842
+
843
+ # CLIP original image tower weights
844
+ 'convnext_base.clip_laion2b': _cfg(
845
+ hf_hub_id='laion/CLIP-convnext_base_w-laion2B-s13B-b82K',
846
+ hf_hub_filename='open_clip_pytorch_model.bin',
847
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
848
+ input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, num_classes=640),
849
+ 'convnext_base.clip_laion2b_augreg': _cfg(
850
+ hf_hub_id='laion/CLIP-convnext_base_w-laion2B-s13B-b82K-augreg',
851
+ hf_hub_filename='open_clip_pytorch_model.bin',
852
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
853
+ input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, num_classes=640),
854
+ 'convnext_base.clip_laiona': _cfg(
855
+ hf_hub_id='laion/CLIP-convnext_base_w-laion_aesthetic-s13B-b82K',
856
+ hf_hub_filename='open_clip_pytorch_model.bin',
857
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
858
+ input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, num_classes=640),
859
+ 'convnext_base.clip_laiona_320': _cfg(
860
+ hf_hub_id='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K',
861
+ hf_hub_filename='open_clip_pytorch_model.bin',
862
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
863
+ input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0, num_classes=640),
864
+ 'convnext_base.clip_laiona_augreg_320': _cfg(
865
+ hf_hub_id='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K-augreg',
866
+ hf_hub_filename='open_clip_pytorch_model.bin',
867
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
868
+ input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0, num_classes=640),
869
+ 'convnext_large_mlp.clip_laion2b_augreg': _cfg(
870
+ hf_hub_id='laion/CLIP-convnext_large_d.laion2B-s26B-b102K-augreg',
871
+ hf_hub_filename='open_clip_pytorch_model.bin',
872
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
873
+ input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, num_classes=768),
874
+ 'convnext_large_mlp.clip_laion2b_ft_320': _cfg(
875
+ hf_hub_id='laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft',
876
+ hf_hub_filename='open_clip_pytorch_model.bin',
877
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
878
+ input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0, num_classes=768),
879
+ 'convnext_large_mlp.clip_laion2b_ft_soup_320': _cfg(
880
+ hf_hub_id='laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft-soup',
881
+ hf_hub_filename='open_clip_pytorch_model.bin',
882
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
883
+ input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0, num_classes=768),
884
+ 'convnext_xxlarge.clip_laion2b_soup': _cfg(
885
+ hf_hub_id='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-soup',
886
+ hf_hub_filename='open_clip_pytorch_model.bin',
887
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
888
+ input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, num_classes=1024),
889
+ 'convnext_xxlarge.clip_laion2b_rewind': _cfg(
890
+ hf_hub_id='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-rewind',
891
+ hf_hub_filename='open_clip_pytorch_model.bin',
892
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
893
+ input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, num_classes=1024),
894
+ })
895
+
896
+
897
+ # @register_model
898
+ # def convnext_atto(pretrained=False, **kwargs) -> ConvNeXt:
899
+ # # timm femto variant (NOTE: still tweaking depths, will vary between 3-4M param, current is 3.7M
900
+ # model_args = dict(depths=(2, 2, 6, 2), dims=(40, 80, 160, 320), conv_mlp=True)
901
+ # model = _create_convnext('convnext_atto', pretrained=pretrained, **dict(model_args, **kwargs))
902
+ # return model
903
+
904
+
905
+ # @register_model
906
+ # def convnext_atto_ols(pretrained=False, **kwargs) -> ConvNeXt:
907
+ # # timm femto variant with overlapping 3x3 conv stem, wider than non-ols femto above, current param count 3.7M
908
+ # model_args = dict(depths=(2, 2, 6, 2), dims=(40, 80, 160, 320), conv_mlp=True, stem_type='overlap_tiered')
909
+ # model = _create_convnext('convnext_atto_ols', pretrained=pretrained, **dict(model_args, **kwargs))
910
+ # return model
911
+
912
+
913
+ # @register_model
914
+ # def convnext_femto(pretrained=False, **kwargs) -> ConvNeXt:
915
+ # # timm femto variant
916
+ # model_args = dict(depths=(2, 2, 6, 2), dims=(48, 96, 192, 384), conv_mlp=True)
917
+ # model = _create_convnext('convnext_femto', pretrained=pretrained, **dict(model_args, **kwargs))
918
+ # return model
919
+
920
+
921
+ # @register_model
922
+ # def convnext_femto_ols(pretrained=False, **kwargs) -> ConvNeXt:
923
+ # # timm femto variant
924
+ # model_args = dict(depths=(2, 2, 6, 2), dims=(48, 96, 192, 384), conv_mlp=True, stem_type='overlap_tiered')
925
+ # model = _create_convnext('convnext_femto_ols', pretrained=pretrained, **dict(model_args, **kwargs))
926
+ # return model
927
+
928
+
929
+ # @register_model
930
+ # def convnext_pico(pretrained=False, **kwargs) -> ConvNeXt:
931
+ # # timm pico variant
932
+ # model_args = dict(depths=(2, 2, 6, 2), dims=(64, 128, 256, 512), conv_mlp=True)
933
+ # model = _create_convnext('convnext_pico', pretrained=pretrained, **dict(model_args, **kwargs))
934
+ # return model
935
+
936
+
937
+ # @register_model
938
+ # def convnext_pico_ols(pretrained=False, **kwargs) -> ConvNeXt:
939
+ # # timm nano variant with overlapping 3x3 conv stem
940
+ # model_args = dict(depths=(2, 2, 6, 2), dims=(64, 128, 256, 512), conv_mlp=True, stem_type='overlap_tiered')
941
+ # model = _create_convnext('convnext_pico_ols', pretrained=pretrained, **dict(model_args, **kwargs))
942
+ # return model
943
+
944
+
945
+ # @register_model
946
+ # def convnext_nano(pretrained=False, **kwargs) -> ConvNeXt:
947
+ # # timm nano variant with standard stem and head
948
+ # model_args = dict(depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), conv_mlp=True)
949
+ # model = _create_convnext('convnext_nano', pretrained=pretrained, **dict(model_args, **kwargs))
950
+ # return model
951
+
952
+
953
+ # @register_model
954
+ # def convnext_nano_ols(pretrained=False, **kwargs) -> ConvNeXt:
955
+ # # experimental nano variant with overlapping conv stem
956
+ # model_args = dict(depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), conv_mlp=True, stem_type='overlap')
957
+ # model = _create_convnext('convnext_nano_ols', pretrained=pretrained, **dict(model_args, **kwargs))
958
+ # return model
959
+
960
+
961
+ # @register_model
962
+ # def convnext_tiny_hnf(pretrained=False, **kwargs) -> ConvNeXt:
963
+ # # experimental tiny variant with norm before pooling in head (head norm first)
964
+ # model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), head_norm_first=True, conv_mlp=True)
965
+ # model = _create_convnext('convnext_tiny_hnf', pretrained=pretrained, **dict(model_args, **kwargs))
966
+ # return model
967
+
968
+
969
+ # @register_model
970
+ # def convnext_tiny(pretrained=False, **kwargs) -> ConvNeXt:
971
+ # model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768))
972
+ # model = _create_convnext('convnext_tiny', pretrained=pretrained, **dict(model_args, **kwargs))
973
+ # return model
974
+
975
+
976
+ # @register_model
977
+ # def convnext_small(pretrained=False, **kwargs) -> ConvNeXt:
978
+ # model_args = dict(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768])
979
+ # model = _create_convnext('convnext_small', pretrained=pretrained, **dict(model_args, **kwargs))
980
+ # return model
981
+
982
+ # @register_model
983
+ # def convnext_base_clip(pretrained='', **kwargs) -> ConvNeXt:
984
+ # model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024])
985
+ # model = _create_convnext(pretrained, pretrained=True, **dict(model_args, **kwargs))
986
+ # return model
987
+
988
+ # @register_model
989
+ # def convnext_base(pretrained=False, **kwargs) -> ConvNeXt:
990
+ # model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024])
991
+ # model = _create_convnext('convnext_base', pretrained=pretrained, **dict(model_args, **kwargs))
992
+ # return model
993
+
994
+
995
+ # @register_model
996
+ # def convnext_large(pretrained=False, **kwargs) -> ConvNeXt:
997
+ # model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536])
998
+ # model = _create_convnext('convnext_large', pretrained=pretrained, **dict(model_args, **kwargs))
999
+ # return model
1000
+
1001
+
1002
+ # @register_model
1003
+ # def convnext_large_mlp(pretrained=False, **kwargs) -> ConvNeXt:
1004
+ # model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], head_hidden_size=1536)
1005
+ # model = _create_convnext('convnext_large_mlp', pretrained=pretrained, **dict(model_args, **kwargs))
1006
+ # return model
1007
+
1008
+
1009
+ # @register_model
1010
+ # def convnext_xlarge(pretrained=False, **kwargs) -> ConvNeXt:
1011
+ # model_args = dict(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048])
1012
+ # model = _create_convnext('convnext_xlarge', pretrained=pretrained, **dict(model_args, **kwargs))
1013
+ # return model
1014
+
1015
+
1016
+ # @register_model
1017
+ def convnext_xxlarge(pretrained=False, **kwargs) -> ConvNeXt:
1018
+ model_args = dict(depths=[3, 4, 30, 3], dims=[384, 768, 1536, 3072], norm_eps=kwargs.pop('norm_eps', 1e-5))
1019
+ model = _create_convnext('convnext_xxlarge', pretrained=pretrained, **dict(model_args, **kwargs))
1020
+ return model
1021
+
1022
+
1023
+ # @register_model
1024
+ # def convnextv2_atto(pretrained=False, **kwargs) -> ConvNeXt:
1025
+ # # timm femto variant (NOTE: still tweaking depths, will vary between 3-4M param, current is 3.7M
1026
+ # model_args = dict(
1027
+ # depths=(2, 2, 6, 2), dims=(40, 80, 160, 320), use_grn=True, ls_init_value=None, conv_mlp=True)
1028
+ # model = _create_convnext('convnextv2_atto', pretrained=pretrained, **dict(model_args, **kwargs))
1029
+ # return model
1030
+
1031
+
1032
+ # @register_model
1033
+ # def convnextv2_femto(pretrained=False, **kwargs) -> ConvNeXt:
1034
+ # # timm femto variant
1035
+ # model_args = dict(
1036
+ # depths=(2, 2, 6, 2), dims=(48, 96, 192, 384), use_grn=True, ls_init_value=None, conv_mlp=True)
1037
+ # model = _create_convnext('convnextv2_femto', pretrained=pretrained, **dict(model_args, **kwargs))
1038
+ # return model
1039
+
1040
+
1041
+ # @register_model
1042
+ # def convnextv2_pico(pretrained=False, **kwargs) -> ConvNeXt:
1043
+ # # timm pico variant
1044
+ # model_args = dict(
1045
+ # depths=(2, 2, 6, 2), dims=(64, 128, 256, 512), use_grn=True, ls_init_value=None, conv_mlp=True)
1046
+ # model = _create_convnext('convnextv2_pico', pretrained=pretrained, **dict(model_args, **kwargs))
1047
+ # return model
1048
+
1049
+
1050
+ # @register_model
1051
+ # def convnextv2_nano(pretrained=False, **kwargs) -> ConvNeXt:
1052
+ # # timm nano variant with standard stem and head
1053
+ # model_args = dict(
1054
+ # depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), use_grn=True, ls_init_value=None, conv_mlp=True)
1055
+ # model = _create_convnext('convnextv2_nano', pretrained=pretrained, **dict(model_args, **kwargs))
1056
+ # return model
1057
+
1058
+
1059
+ # @register_model
1060
+ # def convnextv2_tiny(pretrained=False, **kwargs) -> ConvNeXt:
1061
+ # model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), use_grn=True, ls_init_value=None)
1062
+ # model = _create_convnext('convnextv2_tiny', pretrained=pretrained, **dict(model_args, **kwargs))
1063
+ # return model
1064
+
1065
+
1066
+ # @register_model
1067
+ # def convnextv2_small(pretrained=False, **kwargs) -> ConvNeXt:
1068
+ # model_args = dict(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], use_grn=True, ls_init_value=None)
1069
+ # model = _create_convnext('convnextv2_small', pretrained=pretrained, **dict(model_args, **kwargs))
1070
+ # return model
1071
+
1072
+
1073
+ # @register_model
1074
+ # def convnextv2_base(pretrained=False, **kwargs) -> ConvNeXt:
1075
+ # model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], use_grn=True, ls_init_value=None)
1076
+ # model = _create_convnext('convnextv2_base', pretrained=pretrained, **dict(model_args, **kwargs))
1077
+ # return model
1078
+
1079
+
1080
+ # @register_model
1081
+ # def convnextv2_large(pretrained=False, **kwargs) -> ConvNeXt:
1082
+ # model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], use_grn=True, ls_init_value=None)
1083
+ # model = _create_convnext('convnextv2_large', pretrained=pretrained, **dict(model_args, **kwargs))
1084
+ # return model
1085
+
1086
+
1087
+ # @register_model
1088
+ # def convnextv2_huge(pretrained=False, **kwargs) -> ConvNeXt:
1089
+ # model_args = dict(depths=[3, 3, 27, 3], dims=[352, 704, 1408, 2816], use_grn=True, ls_init_value=None)
1090
+ # model = _create_convnext('convnextv2_huge', pretrained=pretrained, **dict(model_args, **kwargs))
1091
+ # return model
1092
+
1093
+
1094
+ # register_model_deprecations(__name__, {
1095
+ # 'convnext_tiny_in22ft1k': 'convnext_tiny.fb_in22k_ft_in1k',
1096
+ # 'convnext_small_in22ft1k': 'convnext_small.fb_in22k_ft_in1k',
1097
+ # 'convnext_base_in22ft1k': 'convnext_base.fb_in22k_ft_in1k',
1098
+ # 'convnext_large_in22ft1k': 'convnext_large.fb_in22k_ft_in1k',
1099
+ # 'convnext_xlarge_in22ft1k': 'convnext_xlarge.fb_in22k_ft_in1k',
1100
+ # 'convnext_tiny_384_in22ft1k': 'convnext_tiny.fb_in22k_ft_in1k_384',
1101
+ # 'convnext_small_384_in22ft1k': 'convnext_small.fb_in22k_ft_in1k_384',
1102
+ # 'convnext_base_384_in22ft1k': 'convnext_base.fb_in22k_ft_in1k_384',
1103
+ # 'convnext_large_384_in22ft1k': 'convnext_large.fb_in22k_ft_in1k_384',
1104
+ # 'convnext_xlarge_384_in22ft1k': 'convnext_xlarge.fb_in22k_ft_in1k_384',
1105
+ # 'convnext_tiny_in22k': 'convnext_tiny.fb_in22k',
1106
+ # 'convnext_small_in22k': 'convnext_small.fb_in22k',
1107
+ # 'convnext_base_in22k': 'convnext_base.fb_in22k',
1108
+ # 'convnext_large_in22k': 'convnext_large.fb_in22k',
1109
+ # 'convnext_xlarge_in22k': 'convnext_xlarge.fb_in22k',
1110
+ # })
eagle/model/multimodal_encoder/vision_models/eva_vit.py ADDED
@@ -0,0 +1,1244 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import fvcore.nn.weight_init as weight_init
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import math
6
+ import numpy as np
7
+ import logging
8
+ from functools import partial
9
+ from scipy import interpolate
10
+ from math import pi
11
+ from einops import rearrange, repeat
12
+ import warnings
13
+ from PIL import Image
14
+ import torch.utils.checkpoint as cp
15
+ from transformers import CLIPImageProcessor
16
+ # from ..utils.attention import FlashAttention, FlashMHA
17
+ # try:
18
+ # import xformers.ops as xops
19
+ # except:
20
+ # pass
21
+
22
+ logger = logging.getLogger(__name__)
23
+ BatchNorm2d = torch.nn.BatchNorm2d
24
+
25
+ class Conv2d(torch.nn.Conv2d):
26
+ """
27
+ A wrapper around :class:`torch.nn.Conv2d` to support empty inputs and more features.
28
+ """
29
+
30
+ def __init__(self, *args, **kwargs):
31
+ """
32
+ Extra keyword arguments supported in addition to those in `torch.nn.Conv2d`:
33
+ Args:
34
+ norm (nn.Module, optional): a normalization layer
35
+ activation (callable(Tensor) -> Tensor): a callable activation function
36
+ It assumes that norm layer is used before activation.
37
+ """
38
+ norm = kwargs.pop("norm", None)
39
+ activation = kwargs.pop("activation", None)
40
+ super().__init__(*args, **kwargs)
41
+
42
+ self.norm = norm
43
+ self.activation = activation
44
+
45
+ def forward(self, x):
46
+ # torchscript does not support SyncBatchNorm yet
47
+ # https://github.com/pytorch/pytorch/issues/40507
48
+ # and we skip these codes in torchscript since:
49
+ # 1. currently we only support torchscript in evaluation mode
50
+ # 2. features needed by exporting module to torchscript are added in PyTorch 1.6 or
51
+ # later version, `Conv2d` in these PyTorch versions has already supported empty inputs.
52
+ if not torch.jit.is_scripting():
53
+ with warnings.catch_warnings(record=True):
54
+ if x.numel() == 0 and self.training:
55
+ # https://github.com/pytorch/pytorch/issues/12013
56
+ assert not isinstance(
57
+ self.norm, torch.nn.SyncBatchNorm
58
+ ), "SyncBatchNorm does not support empty inputs!"
59
+
60
+ x = F.conv2d(
61
+ x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups
62
+ )
63
+ if self.norm is not None:
64
+ x = self.norm(x)
65
+ if self.activation is not None:
66
+ x = self.activation(x)
67
+ return x
68
+
69
+
70
+ def window_partition(x, window_size):
71
+ """
72
+ Partition into non-overlapping windows with padding if needed.
73
+ Args:
74
+ x (tensor): input tokens with [B, H, W, C].
75
+ window_size (int): window size.
76
+ Returns:
77
+ windows: windows after partition with [B * num_windows, window_size, window_size, C].
78
+ (Hp, Wp): padded height and width before partition
79
+ """
80
+ B, H, W, C = x.shape
81
+
82
+ pad_h = (window_size - H % window_size) % window_size
83
+ pad_w = (window_size - W % window_size) % window_size
84
+ if pad_h > 0 or pad_w > 0:
85
+ x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
86
+ Hp, Wp = H + pad_h, W + pad_w
87
+
88
+ x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
89
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
90
+ return windows, (Hp, Wp)
91
+
92
+
93
+ def window_unpartition(windows, window_size, pad_hw, hw):
94
+ """
95
+ Window unpartition into original sequences and removing padding.
96
+ Args:
97
+ x (tensor): input tokens with [B * num_windows, window_size, window_size, C].
98
+ window_size (int): window size.
99
+ pad_hw (Tuple): padded height and width (Hp, Wp).
100
+ hw (Tuple): original height and width (H, W) before padding.
101
+ Returns:
102
+ x: unpartitioned sequences with [B, H, W, C].
103
+ """
104
+ Hp, Wp = pad_hw
105
+ H, W = hw
106
+ B = windows.shape[0] // (Hp * Wp // window_size // window_size)
107
+ x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
108
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
109
+
110
+ if Hp > H or Wp > W:
111
+ x = x[:, :H, :W, :].contiguous()
112
+ return x
113
+
114
+
115
+ def get_rel_pos(q_size, k_size, rel_pos):
116
+ """
117
+ Get relative positional embeddings according to the relative positions of
118
+ query and key sizes.
119
+ Args:
120
+ q_size (int): size of query q.
121
+ k_size (int): size of key k.
122
+ rel_pos (Tensor): relative position embeddings (L, C).
123
+ Returns:
124
+ Extracted positional embeddings according to relative positions.
125
+ """
126
+ max_rel_dist = int(2 * max(q_size, k_size) - 1)
127
+ use_log_interpolation = True
128
+
129
+ # Interpolate rel pos if needed.
130
+ if rel_pos.shape[0] != max_rel_dist:
131
+ if not use_log_interpolation:
132
+ # Interpolate rel pos.
133
+ rel_pos_resized = F.interpolate(
134
+ rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
135
+ size=max_rel_dist,
136
+ mode="linear",
137
+ )
138
+ rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
139
+ else:
140
+ src_size = rel_pos.shape[0]
141
+ dst_size = max_rel_dist
142
+
143
+ # q = 1.13492
144
+ q = 1.0903078
145
+ dis = []
146
+
147
+ cur = 1
148
+ for i in range(src_size // 2):
149
+ dis.append(cur)
150
+ cur += q ** (i + 1)
151
+
152
+ r_ids = [-_ for _ in reversed(dis)]
153
+ x = r_ids + [0] + dis
154
+ t = dst_size // 2.0
155
+ dx = np.arange(-t, t + 0.1, 1.0)
156
+ all_rel_pos_bias = []
157
+ for i in range(rel_pos.shape[1]):
158
+ z = rel_pos[:, i].view(src_size).cpu().float().numpy()
159
+ f = interpolate.interp1d(x, z, kind='cubic', fill_value="extrapolate")
160
+ all_rel_pos_bias.append(
161
+ torch.Tensor(f(dx)).contiguous().view(-1, 1).to(rel_pos.device))
162
+ rel_pos_resized = torch.cat(all_rel_pos_bias, dim=-1)
163
+ else:
164
+ rel_pos_resized = rel_pos
165
+
166
+ # Scale the coords with short length if shapes for q and k are different.
167
+ q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
168
+ k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
169
+ relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
170
+
171
+ return rel_pos_resized[relative_coords.long()]
172
+
173
+
174
+ def add_decomposed_rel_pos(attn, q, rel_pos_h, rel_pos_w, q_size, k_size):
175
+ """
176
+ Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
177
+ https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950
178
+ Args:
179
+ attn (Tensor): attention map.
180
+ q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
181
+ rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
182
+ rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
183
+ q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
184
+ k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
185
+ Returns:
186
+ attn (Tensor): attention map with added relative positional embeddings.
187
+ """
188
+ q_h, q_w = q_size
189
+ k_h, k_w = k_size
190
+ Rh = get_rel_pos(q_h, k_h, rel_pos_h)
191
+ Rw = get_rel_pos(q_w, k_w, rel_pos_w)
192
+
193
+ B, _, dim = q.shape
194
+ r_q = q.reshape(B, q_h, q_w, dim)
195
+ rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
196
+ rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
197
+
198
+ attn = (
199
+ attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
200
+ ).view(B, q_h * q_w, k_h * k_w)
201
+
202
+ return attn
203
+
204
+
205
+ def get_abs_pos(abs_pos, has_cls_token, hw):
206
+ """
207
+ Calculate absolute positional embeddings. If needed, resize embeddings and remove cls_token
208
+ dimension for the original embeddings.
209
+ Args:
210
+ abs_pos (Tensor): absolute positional embeddings with (1, num_position, C).
211
+ has_cls_token (bool): If true, has 1 embedding in abs_pos for cls token.
212
+ hw (Tuple): size of input image tokens.
213
+ Returns:
214
+ Absolute positional embeddings after processing with shape (1, H, W, C)
215
+ """
216
+ h, w = hw
217
+ if has_cls_token:
218
+ abs_pos = abs_pos[:, 1:]
219
+ xy_num = abs_pos.shape[1]
220
+ size = int(math.sqrt(xy_num))
221
+ assert size * size == xy_num
222
+
223
+ if size != h or size != w:
224
+ original_datatype = abs_pos.dtype
225
+ new_abs_pos = F.interpolate(
226
+ abs_pos.reshape(1, size, size, -1).permute(0, 3, 1, 2).float(), # bf16 is not implemented
227
+ size=(h, w),
228
+ mode="bicubic",
229
+ align_corners=False,
230
+ ).to(original_datatype)
231
+
232
+ return new_abs_pos.permute(0, 2, 3, 1)
233
+ else:
234
+ return abs_pos.reshape(1, h, w, -1)
235
+
236
+
237
+ class PatchEmbed(nn.Module):
238
+ """
239
+ Image to Patch Embedding.
240
+ """
241
+
242
+ def __init__(
243
+ self, kernel_size=(16, 16), stride=(16, 16), padding=(0, 0), in_chans=3, embed_dim=768
244
+ ):
245
+ """
246
+ Args:
247
+ kernel_size (Tuple): kernel size of the projection layer.
248
+ stride (Tuple): stride of the projection layer.
249
+ padding (Tuple): padding size of the projection layer.
250
+ in_chans (int): Number of input image channels.
251
+ embed_dim (int): embed_dim (int): Patch embedding dimension.
252
+ """
253
+ super().__init__()
254
+
255
+ self.proj = nn.Conv2d(
256
+ in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
257
+ )
258
+
259
+ def forward(self, x):
260
+ x = self.proj(x)
261
+ # B C H W -> B H W C
262
+ x = x.permute(0, 2, 3, 1)
263
+ return x
264
+
265
+
266
+ def broadcat(tensors, dim = -1):
267
+ num_tensors = len(tensors)
268
+ shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
269
+ assert len(shape_lens) == 1, 'tensors must all have the same number of dimensions'
270
+ shape_len = list(shape_lens)[0]
271
+ dim = (dim + shape_len) if dim < 0 else dim
272
+ dims = list(zip(*map(lambda t: list(t.shape), tensors)))
273
+ expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
274
+ assert all([*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]), 'invalid dimensions for broadcastable concatentation'
275
+ max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
276
+ expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))
277
+ expanded_dims.insert(dim, (dim, dims[dim]))
278
+ expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
279
+ tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
280
+ return torch.cat(tensors, dim = dim)
281
+
282
+
283
+
284
+ def rotate_half(x):
285
+ x = rearrange(x, '... (d r) -> ... d r', r = 2)
286
+ x1, x2 = x.unbind(dim = -1)
287
+ x = torch.stack((-x2, x1), dim = -1)
288
+ return rearrange(x, '... d r -> ... (d r)')
289
+
290
+
291
+
292
+ class VisionRotaryEmbedding(nn.Module):
293
+ def __init__(
294
+ self,
295
+ dim,
296
+ pt_seq_len,
297
+ ft_seq_len=None,
298
+ custom_freqs = None,
299
+ freqs_for = 'lang',
300
+ theta = 10000,
301
+ max_freq = 10,
302
+ num_freqs = 1,
303
+ ):
304
+ super().__init__()
305
+ if custom_freqs:
306
+ freqs = custom_freqs
307
+ elif freqs_for == 'lang':
308
+ freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
309
+ elif freqs_for == 'pixel':
310
+ freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
311
+ elif freqs_for == 'constant':
312
+ freqs = torch.ones(num_freqs).float()
313
+ else:
314
+ raise ValueError(f'unknown modality {freqs_for}')
315
+
316
+ if ft_seq_len is None: ft_seq_len = pt_seq_len
317
+ t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
318
+
319
+ freqs_h = torch.einsum('..., f -> ... f', t, freqs)
320
+ freqs_h = repeat(freqs_h, '... n -> ... (n r)', r = 2)
321
+
322
+ freqs_w = torch.einsum('..., f -> ... f', t, freqs)
323
+ freqs_w = repeat(freqs_w, '... n -> ... (n r)', r = 2)
324
+
325
+ freqs = broadcat((freqs_h[:, None, :], freqs_w[None, :, :]), dim = -1)
326
+
327
+ self.register_buffer("freqs_cos", freqs.cos())
328
+ self.register_buffer("freqs_sin", freqs.sin())
329
+
330
+ # print('======== shape of rope freq', self.freqs_cos.shape, '========')
331
+
332
+ def forward(self, t, start_index = 0):
333
+ rot_dim = self.freqs_cos.shape[-1]
334
+ end_index = start_index + rot_dim
335
+ assert rot_dim <= t.shape[-1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}'
336
+ t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:]
337
+ t = (t * self.freqs_cos) + (rotate_half(t) * self.freqs_sin)
338
+ return torch.cat((t_left, t, t_right), dim = -1)
339
+
340
+
341
+
342
+
343
+ class VisionRotaryEmbeddingFast(nn.Module):
344
+ def __init__(
345
+ self,
346
+ dim,
347
+ pt_seq_len=16,
348
+ ft_seq_len=None,
349
+ custom_freqs = None,
350
+ freqs_for = 'lang',
351
+ theta = 10000,
352
+ max_freq = 10,
353
+ num_freqs = 1,
354
+ ):
355
+ super().__init__()
356
+ if custom_freqs:
357
+ freqs = custom_freqs
358
+ elif freqs_for == 'lang':
359
+ freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
360
+ elif freqs_for == 'pixel':
361
+ freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
362
+ elif freqs_for == 'constant':
363
+ freqs = torch.ones(num_freqs).float()
364
+ else:
365
+ raise ValueError(f'unknown modality {freqs_for}')
366
+
367
+ if ft_seq_len is None: ft_seq_len = pt_seq_len
368
+ t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
369
+
370
+ freqs = torch.einsum('..., f -> ... f', t, freqs)
371
+ freqs = repeat(freqs, '... n -> ... (n r)', r = 2)
372
+ freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim = -1)
373
+
374
+ freqs_cos = freqs.cos().view(-1, freqs.shape[-1])
375
+ freqs_sin = freqs.sin().view(-1, freqs.shape[-1])
376
+
377
+ self.register_buffer("freqs_cos", freqs_cos)
378
+ self.register_buffer("freqs_sin", freqs_sin)
379
+
380
+ # print('======== shape of rope freq', self.freqs_cos.shape, '========')
381
+
382
+ def forward(self, t): return t * self.freqs_cos + rotate_half(t) * self.freqs_sin
383
+
384
+
385
+ class FrozenBatchNorm2d(nn.Module):
386
+ """
387
+ BatchNorm2d where the batch statistics and the affine parameters are fixed.
388
+ It contains non-trainable buffers called
389
+ "weight" and "bias", "running_mean", "running_var",
390
+ initialized to perform identity transformation.
391
+ The pre-trained backbone models from Caffe2 only contain "weight" and "bias",
392
+ which are computed from the original four parameters of BN.
393
+ The affine transform `x * weight + bias` will perform the equivalent
394
+ computation of `(x - running_mean) / sqrt(running_var) * weight + bias`.
395
+ When loading a backbone model from Caffe2, "running_mean" and "running_var"
396
+ will be left unchanged as identity transformation.
397
+ Other pre-trained backbone models may contain all 4 parameters.
398
+ The forward is implemented by `F.batch_norm(..., training=False)`.
399
+ """
400
+
401
+ _version = 3
402
+
403
+ def __init__(self, num_features, eps=1e-5):
404
+ super().__init__()
405
+ self.num_features = num_features
406
+ self.eps = eps
407
+ self.register_buffer("weight", torch.ones(num_features))
408
+ self.register_buffer("bias", torch.zeros(num_features))
409
+ self.register_buffer("running_mean", torch.zeros(num_features))
410
+ self.register_buffer("running_var", torch.ones(num_features) - eps)
411
+
412
+ def forward(self, x):
413
+ if x.requires_grad:
414
+ # When gradients are needed, F.batch_norm will use extra memory
415
+ # because its backward op computes gradients for weight/bias as well.
416
+ scale = self.weight * (self.running_var + self.eps).rsqrt()
417
+ bias = self.bias - self.running_mean * scale
418
+ scale = scale.reshape(1, -1, 1, 1)
419
+ bias = bias.reshape(1, -1, 1, 1)
420
+ out_dtype = x.dtype # may be half
421
+ return x * scale.to(out_dtype) + bias.to(out_dtype)
422
+ else:
423
+ # When gradients are not needed, F.batch_norm is a single fused op
424
+ # and provide more optimization opportunities.
425
+ return F.batch_norm(
426
+ x,
427
+ self.running_mean,
428
+ self.running_var,
429
+ self.weight,
430
+ self.bias,
431
+ training=False,
432
+ eps=self.eps,
433
+ )
434
+
435
+ def _load_from_state_dict(
436
+ self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
437
+ ):
438
+ version = local_metadata.get("version", None)
439
+
440
+ if version is None or version < 2:
441
+ # No running_mean/var in early versions
442
+ # This will silent the warnings
443
+ if prefix + "running_mean" not in state_dict:
444
+ state_dict[prefix + "running_mean"] = torch.zeros_like(self.running_mean)
445
+ if prefix + "running_var" not in state_dict:
446
+ state_dict[prefix + "running_var"] = torch.ones_like(self.running_var)
447
+
448
+ super()._load_from_state_dict(
449
+ state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
450
+ )
451
+
452
+ def __repr__(self):
453
+ return "FrozenBatchNorm2d(num_features={}, eps={})".format(self.num_features, self.eps)
454
+
455
+ @classmethod
456
+ def convert_frozen_batchnorm(cls, module):
457
+ """
458
+ Convert all BatchNorm/SyncBatchNorm in module into FrozenBatchNorm.
459
+ Args:
460
+ module (torch.nn.Module):
461
+ Returns:
462
+ If module is BatchNorm/SyncBatchNorm, returns a new module.
463
+ Otherwise, in-place convert module and return it.
464
+ Similar to convert_sync_batchnorm in
465
+ https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/batchnorm.py
466
+ """
467
+ bn_module = nn.modules.batchnorm
468
+ bn_module = (bn_module.BatchNorm2d, bn_module.SyncBatchNorm)
469
+ res = module
470
+ if isinstance(module, bn_module):
471
+ res = cls(module.num_features)
472
+ if module.affine:
473
+ res.weight.data = module.weight.data.clone().detach()
474
+ res.bias.data = module.bias.data.clone().detach()
475
+ res.running_mean.data = module.running_mean.data
476
+ res.running_var.data = module.running_var.data
477
+ res.eps = module.eps
478
+ else:
479
+ for name, child in module.named_children():
480
+ new_child = cls.convert_frozen_batchnorm(child)
481
+ if new_child is not child:
482
+ res.add_module(name, new_child)
483
+ return res
484
+
485
+ class LayerNorm(nn.Module):
486
+ """
487
+ A LayerNorm variant, popularized by Transformers, that performs point-wise mean and
488
+ variance normalization over the channel dimension for inputs that have shape
489
+ (batch_size, channels, height, width).
490
+ https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa B950
491
+ """
492
+
493
+ def __init__(self, normalized_shape, eps=1e-6):
494
+ super().__init__()
495
+ self.weight = nn.Parameter(torch.ones(normalized_shape))
496
+ self.bias = nn.Parameter(torch.zeros(normalized_shape))
497
+ self.eps = eps
498
+ self.normalized_shape = (normalized_shape,)
499
+
500
+ def forward(self, x):
501
+ u = x.mean(1, keepdim=True)
502
+ s = (x - u).pow(2).mean(1, keepdim=True)
503
+ x = (x - u) / torch.sqrt(s + self.eps)
504
+ x = self.weight[:, None, None] * x + self.bias[:, None, None]
505
+ return x
506
+
507
+
508
+ class CNNBlockBase(nn.Module):
509
+ """
510
+ A CNN block is assumed to have input channels, output channels and a stride.
511
+ The input and output of `forward()` method must be NCHW tensors.
512
+ The method can perform arbitrary computation but must match the given
513
+ channels and stride specification.
514
+ Attribute:
515
+ in_channels (int):
516
+ out_channels (int):
517
+ stride (int):
518
+ """
519
+
520
+ def __init__(self, in_channels, out_channels, stride):
521
+ """
522
+ The `__init__` method of any subclass should also contain these arguments.
523
+ Args:
524
+ in_channels (int):
525
+ out_channels (int):
526
+ stride (int):
527
+ """
528
+ super().__init__()
529
+ self.in_channels = in_channels
530
+ self.out_channels = out_channels
531
+ self.stride = stride
532
+
533
+ def freeze(self):
534
+ """
535
+ Make this block not trainable.
536
+ This method sets all parameters to `requires_grad=False`,
537
+ and convert all BatchNorm layers to FrozenBatchNorm
538
+ Returns:
539
+ the block itself
540
+ """
541
+ for p in self.parameters():
542
+ p.requires_grad = False
543
+ FrozenBatchNorm2d.convert_frozen_batchnorm(self)
544
+ return self
545
+
546
+ def get_norm(norm, out_channels):
547
+ """
548
+ Args:
549
+ norm (str or callable): either one of BN, SyncBN, FrozenBN, GN;
550
+ or a callable that takes a channel number and returns
551
+ the normalization layer as a nn.Module.
552
+ Returns:
553
+ nn.Module or None: the normalization layer
554
+ """
555
+ if norm is None:
556
+ return None
557
+ if isinstance(norm, str):
558
+ if len(norm) == 0:
559
+ return None
560
+ norm = {
561
+ "BN": BatchNorm2d,
562
+ # Fixed in https://github.com/pytorch/pytorch/pull/36382
563
+ "SyncBN": nn.SyncBatchNorm,
564
+ "FrozenBN": FrozenBatchNorm2d,
565
+ "GN": lambda channels: nn.GroupNorm(32, channels),
566
+ # for debugging:
567
+ "nnSyncBN": nn.SyncBatchNorm,
568
+ "LN": lambda channels: LayerNorm(channels)
569
+ }[norm]
570
+ return norm(out_channels)
571
+
572
+ class DropPath(nn.Module):
573
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
574
+ """
575
+
576
+ def __init__(self, drop_prob=None):
577
+ super(DropPath, self).__init__()
578
+ self.drop_prob = drop_prob
579
+
580
+ def forward(self, x):
581
+ if self.drop_prob == 0. or not self.training:
582
+ return x
583
+ keep_prob = 1 - self.drop_prob
584
+ # work with diff dim tensors, not just 2D ConvNets
585
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1)
586
+ random_tensor = keep_prob + \
587
+ torch.rand(shape, dtype=x.dtype, device=x.device)
588
+ random_tensor.floor_() # binarize
589
+ output = x.div(keep_prob) * random_tensor
590
+ return output
591
+
592
+
593
+
594
+ class SwiGLU(nn.Module):
595
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.SiLU, drop=0.,
596
+ norm_layer=nn.LayerNorm, subln=False
597
+ ):
598
+ super().__init__()
599
+ out_features = out_features or in_features
600
+ hidden_features = hidden_features or in_features
601
+
602
+ self.w1 = nn.Linear(in_features, hidden_features)
603
+ self.w2 = nn.Linear(in_features, hidden_features)
604
+
605
+ self.act = act_layer()
606
+ self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity()
607
+ self.w3 = nn.Linear(hidden_features, out_features)
608
+
609
+ self.drop = nn.Dropout(drop)
610
+
611
+ def forward(self, x):
612
+ x1 = self.w1(x)
613
+ x2 = self.w2(x)
614
+ hidden = self.act(x1) * x2
615
+ x = self.ffn_ln(hidden)
616
+ x = self.w3(x)
617
+ x = self.drop(x)
618
+ return x
619
+
620
+
621
+ class Attention(nn.Module):
622
+ def __init__(
623
+ self,
624
+ dim,
625
+ num_heads=8,
626
+ qkv_bias=True,
627
+ qk_scale=None,
628
+ attn_head_dim=None,
629
+ norm_layer=nn.LayerNorm,
630
+ rope=None,
631
+ xattn=True,
632
+ subln=False
633
+ ):
634
+ super().__init__()
635
+ self.num_heads = num_heads
636
+ head_dim = dim // num_heads
637
+ if attn_head_dim is not None:
638
+ head_dim = attn_head_dim
639
+ all_head_dim = head_dim * self.num_heads
640
+ self.scale = qk_scale or head_dim ** -0.5
641
+
642
+ self.subln = subln
643
+ self.q_proj = nn.Linear(dim, all_head_dim, bias=False)
644
+ self.k_proj = nn.Linear(dim, all_head_dim, bias=False)
645
+ self.v_proj = nn.Linear(dim, all_head_dim, bias=False)
646
+
647
+ if qkv_bias:
648
+ self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
649
+ self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
650
+ else:
651
+ self.q_bias = None
652
+ self.v_bias = None
653
+
654
+ self.rope = rope
655
+ self.xattn = xattn
656
+ self.proj = nn.Linear(all_head_dim, dim)
657
+ self.inner_attn_ln = norm_layer(all_head_dim) if subln else nn.Identity()
658
+
659
+ if self.xattn:
660
+ factory_kwargs = {'device': 'cuda', 'dtype': torch.float16}
661
+ self.inner_attn = FlashAttention(attention_dropout=0.0, **factory_kwargs)
662
+
663
+ def forward(self, x):
664
+ B, H, W, C = x.shape
665
+ x = x.view(B, -1, C)
666
+ N = H * W
667
+
668
+ q = F.linear(input=x, weight=self.q_proj.weight, bias=self.q_bias)
669
+ k = F.linear(input=x, weight=self.k_proj.weight, bias=None)
670
+ v = F.linear(input=x, weight=self.v_proj.weight, bias=self.v_bias)
671
+
672
+ q = q.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3) # B, num_heads, N, C
673
+ k = k.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
674
+ v = v.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
675
+
676
+ ## rope
677
+ q = self.rope(q).type_as(v)
678
+ k = self.rope(k).type_as(v)
679
+
680
+ if self.xattn:
681
+ q = q.permute(0, 2, 1, 3) # B, num_heads, N, C -> B, N, num_heads, C
682
+ k = k.permute(0, 2, 1, 3)
683
+ v = v.permute(0, 2, 1, 3)
684
+
685
+ kv = torch.stack([k, v], dim=2)
686
+ x, attn_weights = self.inner_attn(q, kv, key_padding_mask=None, causal=False)
687
+ # x = xops.memory_efficient_attention(q, k, v)
688
+ x = x.reshape(B, N, -1)
689
+ x = self.inner_attn_ln(x)
690
+ else:
691
+ q = q * self.scale
692
+ attn = (q @ k.transpose(-2, -1))
693
+ attn = attn.softmax(dim=-1).type_as(x)
694
+ x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
695
+ x = self.inner_attn_ln(x)
696
+
697
+ x = self.proj(x)
698
+ x = x.view(B, H, W, C)
699
+
700
+ return x
701
+
702
+
703
+ class ResBottleneckBlock(CNNBlockBase):
704
+ """
705
+ The standard bottleneck residual block without the last activation layer.
706
+ It contains 3 conv layers with kernels 1x1, 3x3, 1x1.
707
+ """
708
+
709
+ def __init__(
710
+ self,
711
+ in_channels,
712
+ out_channels,
713
+ bottleneck_channels,
714
+ norm="LN",
715
+ act_layer=nn.GELU,
716
+ ):
717
+ """
718
+ Args:
719
+ in_channels (int): Number of input channels.
720
+ out_channels (int): Number of output channels.
721
+ bottleneck_channels (int): number of output channels for the 3x3
722
+ "bottleneck" conv layers.
723
+ norm (str or callable): normalization for all conv layers.
724
+ See :func:`layers.get_norm` for supported format.
725
+ act_layer (callable): activation for all conv layers.
726
+ """
727
+ super().__init__(in_channels, out_channels, 1)
728
+
729
+ self.conv1 = Conv2d(in_channels, bottleneck_channels, 1, bias=False)
730
+ self.norm1 = get_norm(norm, bottleneck_channels)
731
+ self.act1 = act_layer()
732
+
733
+ self.conv2 = Conv2d(
734
+ bottleneck_channels,
735
+ bottleneck_channels,
736
+ 3,
737
+ padding=1,
738
+ bias=False,
739
+ )
740
+ self.norm2 = get_norm(norm, bottleneck_channels)
741
+ self.act2 = act_layer()
742
+
743
+ self.conv3 = Conv2d(bottleneck_channels, out_channels, 1, bias=False)
744
+ self.norm3 = get_norm(norm, out_channels)
745
+
746
+ for layer in [self.conv1, self.conv2, self.conv3]:
747
+ weight_init.c2_msra_fill(layer)
748
+ for layer in [self.norm1, self.norm2]:
749
+ layer.weight.data.fill_(1.0)
750
+ layer.bias.data.zero_()
751
+ # zero init last norm layer.
752
+ self.norm3.weight.data.zero_()
753
+ self.norm3.bias.data.zero_()
754
+
755
+ def forward(self, x):
756
+ out = x
757
+ for layer in self.children():
758
+ out = layer(out)
759
+
760
+ out = x + out
761
+ return out
762
+
763
+
764
+ class Block(nn.Module):
765
+ """Transformer blocks with support of window attention and residual propagation blocks"""
766
+
767
+ def __init__(
768
+ self,
769
+ dim,
770
+ num_heads,
771
+ mlp_ratio=4*2/3,
772
+ qkv_bias=True,
773
+ drop_path=0.0,
774
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
775
+ window_size=0,
776
+ use_residual_block=False,
777
+ rope=None,
778
+ xattn=True,
779
+ subln=False,
780
+ # with_cp=True,
781
+ ):
782
+ """
783
+ Args:
784
+ dim (int): Number of input channels.
785
+ num_heads (int): Number of attention heads in each ViT block.
786
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
787
+ qkv_bias (bool): If True, add a learnable bias to query, key, value.
788
+ drop_path (float): Stochastic depth rate.
789
+ norm_layer (nn.Module): Normalization layer.
790
+ act_layer (nn.Module): Activation layer.
791
+ use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
792
+ rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
793
+ window_size (int): Window size for window attention blocks. If it equals 0, then not
794
+ use window attention.
795
+ use_residual_block (bool): If True, use a residual block after the MLP block.
796
+ input_size (int or None): Input resolution for calculating the relative positional
797
+ parameter size.
798
+ """
799
+ super().__init__()
800
+ self.norm1 = norm_layer(dim)
801
+ self.attn = Attention(
802
+ dim,
803
+ num_heads=num_heads,
804
+ qkv_bias=qkv_bias,
805
+ rope=rope,
806
+ xattn=xattn,
807
+ subln=subln
808
+ )
809
+
810
+
811
+ # self.with_cp = with_cp
812
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
813
+ self.norm2 = norm_layer(dim)
814
+ self.mlp = SwiGLU(
815
+ in_features=dim,
816
+ hidden_features=int(dim * mlp_ratio),
817
+ subln=True,
818
+ norm_layer=norm_layer,
819
+ )
820
+
821
+ self.window_size = window_size
822
+
823
+ self.use_residual_block = use_residual_block
824
+ if use_residual_block:
825
+ # Use a residual block with bottleneck channel as dim // 2
826
+ self.residual = ResBottleneckBlock(
827
+ in_channels=dim,
828
+ out_channels=dim,
829
+ bottleneck_channels=dim // 2,
830
+ norm="LN",
831
+ )
832
+
833
+ def _forward(self, x):
834
+ shortcut = x
835
+ x = self.norm1(x)
836
+
837
+ # Window partition
838
+ if self.window_size > 0:
839
+ H, W = x.shape[1], x.shape[2]
840
+ x, pad_hw = window_partition(x, self.window_size)
841
+
842
+ x = self.attn(x)
843
+
844
+ # Reverse window partition
845
+ if self.window_size > 0:
846
+ x = window_unpartition(x, self.window_size, pad_hw, (H, W))
847
+
848
+ x = shortcut + self.drop_path(x)
849
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
850
+
851
+ if self.use_residual_block:
852
+ x = self.residual(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
853
+
854
+ return x
855
+
856
+ def forward(self, x, with_cp=False):
857
+ # if self.with_cp and self.training:
858
+ if with_cp:
859
+ x = cp.checkpoint(self._forward, x)
860
+ else:
861
+ x = self._forward(x)
862
+ return x
863
+
864
+ #@BACKBONES.register_module()
865
+ class EVAViT(nn.Module):
866
+ """
867
+ This module implements Vision Transformer (ViT) backbone in :paper:`vitdet`.
868
+ "Exploring Plain Vision Transformer Backbones for Object Detection",
869
+ https://arxiv.org/abs/2203.16527
870
+ """
871
+
872
+ def __init__(
873
+ self,
874
+ img_size=1024,
875
+ patch_size=16,
876
+ in_chans=3,
877
+ embed_dim=768,
878
+ depth=12,
879
+ num_heads=12,
880
+ mlp_ratio=4*2/3,
881
+ qkv_bias=True,
882
+ drop_path_rate=0.0,
883
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
884
+ act_layer=nn.GELU,
885
+ use_abs_pos=True,
886
+ use_rel_pos=False,
887
+ # sim_fpn=None,
888
+ rope=True,
889
+ pt_hw_seq_len=16,
890
+ intp_freq=True,
891
+ window_size=0,
892
+ global_window_size=0,
893
+ window_block_indexes=(),
894
+ residual_block_indexes=(),
895
+ pretrain_img_size=224,
896
+ pretrain_use_cls_token=True,
897
+ out_feature="last_feat",
898
+ subln=False,
899
+ xattn=True,
900
+ # with_cp=True,
901
+ frozen=False,
902
+ ):
903
+ """
904
+ Args:
905
+ img_size (int): Input image size.
906
+ patch_size (int): Patch size.
907
+ in_chans (int): Number of input image channels.
908
+ embed_dim (int): Patch embedding dimension.
909
+ depth (int): Depth of ViT.
910
+ num_heads (int): Number of attention heads in each ViT block.
911
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
912
+ qkv_bias (bool): If True, add a learnable bias to query, key, value.
913
+ drop_path_rate (float): Stochastic depth rate.
914
+ norm_layer (nn.Module): Normalization layer.
915
+ act_layer (nn.Module): Activation layer.
916
+ use_abs_pos (bool): If True, use absolute positional embeddings.
917
+ use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
918
+ rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
919
+ window_size (int): Window size for window attention blocks.
920
+ window_block_indexes (list): Indexes for blocks using window attention.
921
+ residual_block_indexes (list): Indexes for blocks using conv propagation.
922
+ use_act_checkpoint (bool): If True, use activation checkpointing.
923
+ pretrain_img_size (int): input image size for pretraining models.
924
+ pretrain_use_cls_token (bool): If True, pretrainig models use class token.
925
+ out_feature (str): name of the feature from the last block.
926
+ """
927
+ super().__init__()
928
+ self.pretrain_use_cls_token = pretrain_use_cls_token
929
+ self.patch_embed = PatchEmbed(
930
+ kernel_size=(patch_size, patch_size),
931
+ stride=(patch_size, patch_size),
932
+ in_chans=in_chans,
933
+ embed_dim=embed_dim,
934
+ )
935
+ self.frozen = frozen
936
+ self.gradient_checkpointing = False
937
+
938
+ if use_abs_pos:
939
+ # Initialize absolute positional embedding with pretrain image size.
940
+ num_patches = (pretrain_img_size // patch_size) * (pretrain_img_size // patch_size)
941
+ num_positions = (num_patches + 1) if pretrain_use_cls_token else num_patches
942
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_positions, embed_dim))
943
+ else:
944
+ self.pos_embed = None
945
+
946
+ half_head_dim = embed_dim // num_heads // 2
947
+ hw_seq_len = img_size // patch_size
948
+
949
+ self.rope_win = VisionRotaryEmbeddingFast(
950
+ dim=half_head_dim,
951
+ pt_seq_len=pt_hw_seq_len,
952
+ ft_seq_len=window_size if intp_freq else None,
953
+ )
954
+ self.rope_glb = VisionRotaryEmbeddingFast(
955
+ dim=half_head_dim,
956
+ pt_seq_len=pt_hw_seq_len,
957
+ ft_seq_len=hw_seq_len if intp_freq else None,
958
+ )
959
+
960
+ # stochastic depth decay rule
961
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
962
+
963
+ self.blocks = nn.ModuleList()
964
+ for i in range(depth):
965
+ block = Block(
966
+ dim=embed_dim,
967
+ num_heads=num_heads,
968
+ mlp_ratio=mlp_ratio,
969
+ qkv_bias=qkv_bias,
970
+ drop_path=dpr[i],
971
+ norm_layer=norm_layer,
972
+ window_size=window_size if i in window_block_indexes else global_window_size,
973
+ use_residual_block=i in residual_block_indexes,
974
+ rope=self.rope_win if i in window_block_indexes else self.rope_glb,
975
+ xattn=xattn,
976
+ subln=subln,
977
+ # with_cp=with_cp,
978
+ )
979
+
980
+ self.blocks.append(block)
981
+
982
+ self._out_feature_channels = {out_feature: embed_dim}
983
+ self._out_feature_strides = {out_feature: patch_size}
984
+ self._out_features = [out_feature]
985
+
986
+ # if self.pos_embed is not None:
987
+ # nn.init.trunc_normal_(self.pos_embed, std=0.02)
988
+ if self.pos_embed is not None:
989
+ nn.init.normal_(self.pos_embed, std=0.02)
990
+
991
+ # MIN SHI: I disable the weight initialization since they will be automatically loaded
992
+ # **However, they will cause problems (deepspeed + bf16)**
993
+ # self.apply(self._init_weights)
994
+ self._freeze_stages()
995
+
996
+ # def _init_weights(self, m):
997
+ # if isinstance(m, nn.Linear):
998
+ # nn.init.trunc_normal_(m.weight, std=0.02)
999
+ # if isinstance(m, nn.Linear) and m.bias is not None:
1000
+ # nn.init.constant_(m.bias, 0)
1001
+ # elif isinstance(m, nn.LayerNorm):
1002
+ # nn.init.constant_(m.bias, 0)
1003
+ # nn.init.constant_(m.weight, 1.0)
1004
+
1005
+ def _freeze_stages(self):
1006
+ if self.frozen:
1007
+ self.eval()
1008
+ for m in self.parameters():
1009
+ m.requires_grad = False
1010
+
1011
+ def forward(self, x):
1012
+ x = self.patch_embed(x)
1013
+ if self.pos_embed is not None:
1014
+ x = x + get_abs_pos(
1015
+ self.pos_embed, self.pretrain_use_cls_token, (x.shape[1], x.shape[2])
1016
+ )
1017
+
1018
+ for blk in self.blocks:
1019
+ x = blk(x, with_cp=self.gradient_checkpointing) # b, h, w, c
1020
+ x = x.permute(0, 3, 1, 2) # b, c, h, w
1021
+
1022
+ # if self.adapter is not None:
1023
+ # outputs = self.adapter(x)
1024
+ # else:
1025
+ # outputs = [x, ]
1026
+
1027
+ # return outputs
1028
+ return x
1029
+
1030
+ '''
1031
+ EVA VIT vision encoder for LLaVA
1032
+ '''
1033
+ class EVAVITVisionTower(nn.Module):
1034
+ def __init__(self, vision_tower, args, delay_load=False):
1035
+ super().__init__()
1036
+
1037
+ self.is_loaded = False
1038
+ self.vision_tower_name = vision_tower
1039
+ self.select_layer = args.mm_vision_select_layer # NOTE: not implemented yet, this parameter has no effect
1040
+ self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
1041
+
1042
+ self.args = args
1043
+ self.vision_tower, vision_tower_config = build_eva_vit(args=args,
1044
+ model_name=vision_tower,
1045
+ image_size=args.input_image_size
1046
+ )
1047
+ self.input_image_size=args.input_image_size
1048
+ self.vision_tower.config = vision_tower_config
1049
+ self.freeze_vision = args.freeze_vision
1050
+
1051
+ if not self.is_loaded:
1052
+ self.load_model()
1053
+ # if not delay_load:
1054
+ # self.load_model()
1055
+ # else:
1056
+ # self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)
1057
+
1058
+ def load_model(self):
1059
+ if self.is_loaded:
1060
+ return
1061
+
1062
+ # self.args.vision_tower_input_size = 224 # hardcode
1063
+ self.image_processor = CLIPImageProcessor(crop_size={"height": self.args.input_image_size, "width": self.args.input_image_size},
1064
+ size={'shortest_edge': self.args.input_image_size},
1065
+ image_mean=[0.48145466, 0.4578275, 0.40821073],
1066
+ image_std=[0.26862954, 0.26130258, 0.27577711])
1067
+
1068
+ # load weights
1069
+ if self.args.vision_tower_pretrained_from is None:
1070
+ self.args.vision_tower_pretrained_from = "/lustre/fsw/portfolios/llmservice/users/fuxiaol/eva02_L_coco_det_sys_o365.pth"
1071
+
1072
+ # pretrained_params = torch.load(self.args.vision_tower_pretrained_from)
1073
+ # if 'ema_state' in pretrained_params:
1074
+ # pretrained_params = pretrained_params['ema_state']
1075
+ # elif 'module' in pretrained_params:
1076
+ # pretrained_params = pretrained_params['module']
1077
+
1078
+ # from collections import OrderedDict
1079
+ # new_params = OrderedDict()
1080
+
1081
+ # kw = ""
1082
+ # if "det" in self.args.vision_tower_pretrained_from.lower():
1083
+ # kw = "backbone.net."
1084
+ # elif "clip" in self.args.vision_tower_pretrained_from.lower():
1085
+ # kw = "visual."
1086
+
1087
+ # for k, v in pretrained_params.items():
1088
+ # if len(kw) > 0:
1089
+ # if kw in k and ("rope" not in k):
1090
+ # new_params[k.replace(kw, "")] = v
1091
+ # else:
1092
+ # if "rope" not in k:
1093
+ # new_params[k] = v
1094
+
1095
+ # incompatiblekeys = self.vision_tower.load_state_dict(new_params, strict=False)
1096
+ # for k in incompatiblekeys[0]:
1097
+ # if "rope" not in k:
1098
+ # warnings.warn(f"Find incompatible keys {k} in state dict.")
1099
+
1100
+ # print(f"EVA-02 ckpt loaded from {self.args.vision_tower_pretrained_from}")
1101
+
1102
+ if self.freeze_vision:
1103
+ self.vision_tower.requires_grad_(False)
1104
+
1105
+ self.is_loaded = True
1106
+
1107
+
1108
+ # @torch.no_grad()
1109
+ def forward(self, images):
1110
+ if type(images) is list:
1111
+ image_features = []
1112
+ for image in images:
1113
+ image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0))
1114
+ image_feature = image_forward_out.flatten(2,3).transpose(1,2) # b, n, c
1115
+ image_features.append(image_feature)
1116
+ else:
1117
+ image_forward_out = self.vision_tower(images.to(device=self.device, dtype=self.dtype))
1118
+
1119
+ return image_forward_out
1120
+
1121
+ @property
1122
+ def dummy_feature(self):
1123
+ return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
1124
+
1125
+ @property
1126
+ def dtype(self):
1127
+ return next(self.vision_tower.parameters()).dtype
1128
+
1129
+ @property
1130
+ def device(self):
1131
+ return next(self.vision_tower.parameters()).device
1132
+
1133
+ @property
1134
+ def config(self):
1135
+ # if self.is_loaded:
1136
+ # return self.vision_tower.config
1137
+ # else:
1138
+ # return self.cfg_only
1139
+ # TODO
1140
+ return self.vision_tower.config
1141
+
1142
+ @property
1143
+ def hidden_size(self):
1144
+ #return self.config.hidden_size
1145
+ return self.config['hidden_dim']
1146
+
1147
+ @property
1148
+ def num_patches(self):
1149
+ # return (self.config.image_size // self.config.patch_size) ** 2
1150
+ return self.config['num_patches']
1151
+
1152
+
1153
+ def build_eva_vit(args,
1154
+ model_name=None,
1155
+ image_size=224,
1156
+ window_attn=True
1157
+ ):
1158
+
1159
+ if "336" in args.vision_tower_pretrained_from:
1160
+ pretrained_image_size = 336
1161
+ else:
1162
+ pretrained_image_size = 224
1163
+
1164
+ if "clip" in args.vision_tower_pretrained_from.lower():
1165
+ subln = True
1166
+ else:
1167
+ subln = False
1168
+
1169
+ if model_name == 'eva02-l-16':
1170
+ # shilong said that use this: https://huggingface.co/Yuxin-CV/EVA-02/blob/main/eva02/det/eva02_L_coco_det_sys_o365.pth
1171
+ if window_attn:
1172
+ window_block_indexes = (list(range(0, 2)) + list(range(3, 5)) + list(range(6, 8)) + list(range(9, 11)) + list(range(12, 14)) + list(range(15, 17)) + list(range(18, 20)) + list(range(21, 23)))
1173
+ else:
1174
+ window_block_indexes = ()
1175
+
1176
+ model = EVAViT(
1177
+ img_size=image_size,
1178
+ patch_size=16,
1179
+ window_size=16,
1180
+ in_chans=3,
1181
+ embed_dim=1024,
1182
+ depth=24,
1183
+ num_heads=16,
1184
+ mlp_ratio=4*2/3,
1185
+ window_block_indexes = window_block_indexes,
1186
+ qkv_bias=True,
1187
+ drop_path_rate=0.0,
1188
+ xattn=False,
1189
+ # with_cp=False,
1190
+ # frozen=True,
1191
+ )
1192
+ # image_size = 224 # HARDCODE
1193
+ eva_config = dict(image_size=image_size,
1194
+ patch_size=16,
1195
+ window_size=16,
1196
+ hidden_dim=1024,
1197
+ depth=24,
1198
+ num_heads=16,
1199
+ window_block_indexes=window_block_indexes,
1200
+ num_patches=image_size ** 2 // 16 ** 2,
1201
+ pretrained_from=args.vision_tower_pretrained_from
1202
+ )
1203
+
1204
+ elif model_name == 'eva02-l-14':
1205
+ # shilong said that use this: https://huggingface.co/Yuxin-CV/EVA-02/blob/main/eva02/det/eva02_L_coco_det_sys_o365.pth
1206
+ if window_attn:
1207
+ window_block_indexes = (list(range(0, 2)) + list(range(3, 5)) + list(range(6, 8)) + list(range(9, 11)) + list(range(12, 14)) + list(range(15, 17)) + list(range(18, 20)) + list(range(21, 23)))
1208
+ else:
1209
+ window_block_indexes = ()
1210
+
1211
+ model = EVAViT(
1212
+ img_size=image_size,
1213
+ pretrain_img_size=pretrained_image_size,
1214
+ patch_size=14,
1215
+ window_size=16,
1216
+ in_chans=3,
1217
+ embed_dim=1024,
1218
+ depth=24,
1219
+ num_heads=16,
1220
+ mlp_ratio=4*2/3,
1221
+ window_block_indexes = window_block_indexes,
1222
+ qkv_bias=True,
1223
+ drop_path_rate=0.0,
1224
+ xattn=False,
1225
+ # with_cp=False,
1226
+ subln=subln,
1227
+ # frozen=True,
1228
+ )
1229
+ # image_size = 224 # HARDCODE
1230
+ eva_config = dict(image_size=image_size,
1231
+ patch_size=14,
1232
+ window_size=16,
1233
+ hidden_dim=1024,
1234
+ depth=24,
1235
+ num_heads=16,
1236
+ window_block_indexes=window_block_indexes,
1237
+ num_patches=image_size ** 2 // 14 ** 2,
1238
+ pretrained_from=args.vision_tower_pretrained_from
1239
+ )
1240
+
1241
+ else:
1242
+ raise NotImplementedError
1243
+
1244
+ return model, eva_config
eagle/model/multimodal_projector/builder.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import re
4
+ # from llava.model.multimodal_projector.deformable_resampler import DeformableResampler
5
+
6
+ class IdentityMap(nn.Module):
7
+ def __init__(self):
8
+ super().__init__()
9
+
10
+ def forward(self, x, *args, **kwargs):
11
+ return x
12
+
13
+ @property
14
+ def config(self):
15
+ return {"mm_projector_type": 'identity'}
16
+
17
+
18
+ class SimpleResBlock(nn.Module):
19
+ def __init__(self, channels):
20
+ super().__init__()
21
+ self.pre_norm = nn.LayerNorm(channels)
22
+
23
+ self.proj = nn.Sequential(
24
+ nn.Linear(channels, channels),
25
+ nn.GELU(),
26
+ nn.Linear(channels, channels)
27
+ )
28
+ def forward(self, x):
29
+ x = self.pre_norm(x)
30
+ return x + self.proj(x)
31
+
32
+
33
+ def build_vision_projector(config, delay_load=False, fpn_input_dim=[], **kwargs):
34
+ projector_type = getattr(config, 'mm_projector_type', 'linear')
35
+
36
+ if projector_type == 'linear':
37
+ return nn.Linear(config.mm_hidden_size, config.hidden_size)
38
+
39
+ mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
40
+ if mlp_gelu_match:
41
+ mlp_depth = int(mlp_gelu_match.group(1))
42
+ modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
43
+ for _ in range(1, mlp_depth):
44
+ modules.append(nn.GELU())
45
+ modules.append(nn.Linear(config.hidden_size, config.hidden_size))
46
+ return nn.Sequential(*modules)
47
+
48
+ # resampler_match = re.match(r'^deformable-resampler-l(\d+)d(\d+)p(\d+)', projector_type)
49
+ # if resampler_match:
50
+ # use_fpn = "fpn" in projector_type or len(fpn_input_dim) > 0
51
+ # layer_num = int(resampler_match.group(1))
52
+ # embed_dim = int(resampler_match.group(2))
53
+ # sample_point = int(resampler_match.group(3))
54
+ # if len(fpn_input_dim) > 0:
55
+ # fpn_type = 'multi-level'
56
+ # else:
57
+ # fpn_type = 'simple'
58
+
59
+ # return DeformableResampler(input_dimension=config.mm_hidden_size,
60
+ # output_dimension=config.hidden_size,
61
+ # query_number=config.mm_projector_query_number,
62
+ # num_layers=layer_num,
63
+ # num_heads=8,
64
+ # feedforward_dims=2048,
65
+ # embed_dims=embed_dim,
66
+ # num_points=sample_point,
67
+ # direct_projection=True,
68
+ # use_fpn=use_fpn,
69
+ # fpn_config=dict(
70
+ # fpn_type=fpn_type,
71
+ # in_channels=fpn_input_dim))
72
+
73
+ if projector_type == 'identity':
74
+ return IdentityMap()
75
+
76
+ raise ValueError(f'Unknown projector type: {projector_type}')
eagle/utils.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datetime
2
+ import logging
3
+ import logging.handlers
4
+ import os
5
+ import sys
6
+
7
+ import requests
8
+
9
+ from eagle.constants import LOGDIR
10
+
11
+ server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
12
+ moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."
13
+
14
+ handler = None
15
+
16
+
17
+ def build_logger(logger_name, logger_filename):
18
+ global handler
19
+
20
+ formatter = logging.Formatter(
21
+ fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
22
+ datefmt="%Y-%m-%d %H:%M:%S",
23
+ )
24
+
25
+ # Set the format of root handlers
26
+ if not logging.getLogger().handlers:
27
+ logging.basicConfig(level=logging.INFO)
28
+ logging.getLogger().handlers[0].setFormatter(formatter)
29
+
30
+ # Redirect stdout and stderr to loggers
31
+ stdout_logger = logging.getLogger("stdout")
32
+ stdout_logger.setLevel(logging.INFO)
33
+ sl = StreamToLogger(stdout_logger, logging.INFO)
34
+ sys.stdout = sl
35
+
36
+ stderr_logger = logging.getLogger("stderr")
37
+ stderr_logger.setLevel(logging.ERROR)
38
+ sl = StreamToLogger(stderr_logger, logging.ERROR)
39
+ sys.stderr = sl
40
+
41
+ # Get logger
42
+ logger = logging.getLogger(logger_name)
43
+ logger.setLevel(logging.INFO)
44
+
45
+ # Add a file handler for all loggers
46
+ if handler is None:
47
+ os.makedirs(LOGDIR, exist_ok=True)
48
+ filename = os.path.join(LOGDIR, logger_filename)
49
+ handler = logging.handlers.TimedRotatingFileHandler(
50
+ filename, when='D', utc=True, encoding='UTF-8')
51
+ handler.setFormatter(formatter)
52
+
53
+ for name, item in logging.root.manager.loggerDict.items():
54
+ if isinstance(item, logging.Logger):
55
+ item.addHandler(handler)
56
+
57
+ return logger
58
+
59
+
60
+ class StreamToLogger(object):
61
+ """
62
+ Fake file-like stream object that redirects writes to a logger instance.
63
+ """
64
+ def __init__(self, logger, log_level=logging.INFO):
65
+ self.terminal = sys.stdout
66
+ self.logger = logger
67
+ self.log_level = log_level
68
+ self.linebuf = ''
69
+
70
+ def __getattr__(self, attr):
71
+ return getattr(self.terminal, attr)
72
+
73
+ def write(self, buf):
74
+ temp_linebuf = self.linebuf + buf
75
+ self.linebuf = ''
76
+ for line in temp_linebuf.splitlines(True):
77
+ # From the io.TextIOWrapper docs:
78
+ # On output, if newline is None, any '\n' characters written
79
+ # are translated to the system default line separator.
80
+ # By default sys.stdout.write() expects '\n' newlines and then
81
+ # translates them so this is still cross platform.
82
+ if line[-1] == '\n':
83
+ self.logger.log(self.log_level, line.rstrip())
84
+ else:
85
+ self.linebuf += line
86
+
87
+ def flush(self):
88
+ if self.linebuf != '':
89
+ self.logger.log(self.log_level, self.linebuf.rstrip())
90
+ self.linebuf = ''
91
+
92
+
93
+ def disable_torch_init():
94
+ """
95
+ Disable the redundant torch default initialization to accelerate model creation.
96
+ """
97
+ import torch
98
+ setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
99
+ setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
100
+
101
+
102
+ def violates_moderation(text):
103
+ """
104
+ Check whether the text violates OpenAI moderation API.
105
+ """
106
+ url = "https://api.openai.com/v1/moderations"
107
+ headers = {"Content-Type": "application/json",
108
+ "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]}
109
+ text = text.replace("\n", "")
110
+ data = "{" + '"input": ' + f'"{text}"' + "}"
111
+ data = data.encode("utf-8")
112
+ try:
113
+ ret = requests.post(url, headers=headers, data=data, timeout=5)
114
+ flagged = ret.json()["results"][0]["flagged"]
115
+ except requests.exceptions.RequestException as e:
116
+ flagged = False
117
+ except KeyError as e:
118
+ flagged = False
119
+
120
+ return flagged
121
+
122
+
123
+ def pretty_print_semaphore(semaphore):
124
+ if semaphore is None:
125
+ return "None"
126
+ return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"
requirements.txt ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch==2.1.2
2
+ torchvision==0.16.2
3
+ transformers==4.37.2
4
+ tokenizers==0.15.1
5
+ sentencepiece==0.1.99
6
+ shortuuid
7
+ accelerate==0.21.0
8
+ peft
9
+ bitsandbytes
10
+ pydantic
11
+ markdown2[all]
12
+ numpy
13
+ scikit-learn==1.2.2
14
+ #gradio==4.38.1
15
+ #gradio_client==1.1.0
16
+ gradio==4.16.0
17
+ gradio_client==0.8.1
18
+ requests
19
+ httpx==0.27.0
20
+ uvicorn
21
+ fastapi
22
+ einops==0.6.1
23
+ einops-exts==0.0.4
24
+ timm==0.9.11
25
+ opencv-python
26
+ fvcore