DestaModel
custom_code
kehanlu commited on
Commit
8ab1e14
·
verified ·
1 Parent(s): 8d682ff

Upload folder using huggingface_hub

Browse files
Files changed (4) hide show
  1. __init__.py +0 -0
  2. config.json +154 -0
  3. modeling_desta.py +216 -0
  4. qformer_connector.pth +3 -0
__init__.py ADDED
File without changes
config.json ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "llama_config": {
3
+ "_name_or_path": "meta-llama/Meta-Llama-3-8B-Instruct",
4
+ "architectures": [
5
+ "LlamaForCausalLM"
6
+ ],
7
+ "bos_token_id": 128000,
8
+ "eos_token_id": 128009,
9
+ "intermediate_size": 14336,
10
+ "max_position_embeddings": 8192,
11
+ "model_type": "llama",
12
+ "num_key_value_heads": 8,
13
+ "rms_norm_eps": 1e-05,
14
+ "rope_theta": 500000.0,
15
+ "torch_dtype": "bfloat16",
16
+ "vocab_size": 128256
17
+ },
18
+ "auto_map": {
19
+ "AutoConfig": "modeling_desta.Desta2Config",
20
+ "AutoModel": "modeling_desta.DestaModel"
21
+ },
22
+ "llama_model_id": "meta-llama/Meta-Llama-3-8B-Instruct",
23
+ "model_type": "DestaModel",
24
+ "prompt_size": 64,
25
+ "transformers_version": "4.41.1",
26
+ "whisper_config": {
27
+ "_name_or_path": "openai/whisper-small",
28
+ "architectures": [
29
+ "WhisperForConditionalGeneration"
30
+ ],
31
+ "begin_suppress_tokens": [
32
+ 220,
33
+ 50257
34
+ ],
35
+ "bos_token_id": 50257,
36
+ "d_model": 768,
37
+ "decoder_attention_heads": 12,
38
+ "decoder_ffn_dim": 3072,
39
+ "decoder_layers": 12,
40
+ "decoder_start_token_id": 50258,
41
+ "encoder_attention_heads": 12,
42
+ "encoder_ffn_dim": 3072,
43
+ "encoder_layers": 12,
44
+ "eos_token_id": 50257,
45
+ "forced_decoder_ids": [
46
+ [
47
+ 1,
48
+ 50259
49
+ ],
50
+ [
51
+ 2,
52
+ 50359
53
+ ],
54
+ [
55
+ 3,
56
+ 50363
57
+ ]
58
+ ],
59
+ "max_length": 448,
60
+ "model_type": "whisper",
61
+ "num_hidden_layers": 12,
62
+ "pad_token_id": 50257,
63
+ "suppress_tokens": [
64
+ 1,
65
+ 2,
66
+ 7,
67
+ 8,
68
+ 9,
69
+ 10,
70
+ 14,
71
+ 25,
72
+ 26,
73
+ 27,
74
+ 28,
75
+ 29,
76
+ 31,
77
+ 58,
78
+ 59,
79
+ 60,
80
+ 61,
81
+ 62,
82
+ 63,
83
+ 90,
84
+ 91,
85
+ 92,
86
+ 93,
87
+ 359,
88
+ 503,
89
+ 522,
90
+ 542,
91
+ 873,
92
+ 893,
93
+ 902,
94
+ 918,
95
+ 922,
96
+ 931,
97
+ 1350,
98
+ 1853,
99
+ 1982,
100
+ 2460,
101
+ 2627,
102
+ 3246,
103
+ 3253,
104
+ 3268,
105
+ 3536,
106
+ 3846,
107
+ 3961,
108
+ 4183,
109
+ 4667,
110
+ 6585,
111
+ 6647,
112
+ 7273,
113
+ 9061,
114
+ 9383,
115
+ 10428,
116
+ 10929,
117
+ 11938,
118
+ 12033,
119
+ 12331,
120
+ 12562,
121
+ 13793,
122
+ 14157,
123
+ 14635,
124
+ 15265,
125
+ 15618,
126
+ 16553,
127
+ 16604,
128
+ 18362,
129
+ 18956,
130
+ 20075,
131
+ 21675,
132
+ 22520,
133
+ 26130,
134
+ 26161,
135
+ 26435,
136
+ 28279,
137
+ 29464,
138
+ 31650,
139
+ 32302,
140
+ 32470,
141
+ 36865,
142
+ 42863,
143
+ 47425,
144
+ 49870,
145
+ 50254,
146
+ 50258,
147
+ 50360,
148
+ 50361,
149
+ 50362
150
+ ],
151
+ "torch_dtype": "float32"
152
+ },
153
+ "whisper_model_id": "openai/whisper-small"
154
+ }
modeling_desta.py ADDED
@@ -0,0 +1,216 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, WhisperForConditionalGeneration, PretrainedConfig, PreTrainedModel, BertConfig, AutoProcessor
2
+ from transformers.models.bert.modeling_bert import BertEncoder
3
+ from torch import nn
4
+ import torch
5
+ import os
6
+
7
+
8
+ class Desta2Config(PretrainedConfig):
9
+ model_type = "DestaModel"
10
+
11
+ def __init__(
12
+ self,
13
+ llama_model_id="meta-llama/Meta-Llama-3-8B-Instruct",
14
+ whisper_model_id="openai/whisper-small",
15
+ prompt_size=64,
16
+ **kwargs
17
+ ):
18
+ super().__init__(**kwargs)
19
+ self.llama_model_id = llama_model_id
20
+ self.whisper_model_id = whisper_model_id
21
+ self.prompt_size = prompt_size
22
+
23
+ self.whisper_config = AutoConfig.from_pretrained(self.whisper_model_id)
24
+ self.llama_config = AutoConfig.from_pretrained(self.llama_model_id)
25
+
26
+ class QformerConnector(PreTrainedModel):
27
+ def __init__(self, cfg):
28
+ super().__init__(cfg)
29
+ self.cfg = cfg
30
+
31
+
32
+ if self.cfg.whisper_model_id == "openai/whisper-medium":
33
+ self.target_layer_ids = [5, 11, 17, 23]
34
+ elif self.cfg.whisper_model_id == "openai/whisper-small":
35
+ self.target_layer_ids = [2, 5, 8, 11]
36
+ elif self.cfg.whisper_model_id == "openai/whisper-tiny":
37
+ self.target_layer_ids = [0,1,2,3]
38
+ elif self.cfg.whisper_model_id == "openai/whisper-large-v3":
39
+ self.target_layer_ids = [3, 7, 11, 15, 19, 23, 27, 31]
40
+ else:
41
+ raise NotImplementedError(f"model_id {self.cfg.whisper_model_id} not implemented")
42
+
43
+
44
+ self.layer_prompts = nn.ParameterList([
45
+ nn.Parameter(torch.randn(1, self.cfg.prompt_size, self.cfg.whisper_config.d_model)) for _ in range(len(self.target_layer_ids))]
46
+ )
47
+
48
+
49
+ # (prompt_size, target_layers)
50
+ self.layer_weights = nn.Parameter(torch.zeros(self.cfg.prompt_size, len(self.target_layer_ids), dtype=torch.float))
51
+
52
+ qformer_config = BertConfig()
53
+ qformer_config.num_hidden_layers = 2
54
+ qformer_config.num_attention_heads = self.cfg.whisper_config.encoder_attention_heads
55
+ qformer_config.hidden_size = self.cfg.whisper_config.d_model
56
+ qformer_config.add_cross_attention = True
57
+ qformer_config.is_decoder = True
58
+
59
+ self.qformer = BertEncoder(qformer_config)
60
+ self.proj = nn.Sequential(
61
+ nn.LayerNorm(self.cfg.whisper_config.d_model),
62
+ nn.Linear(self.cfg.whisper_config.d_model, self.cfg.llama_config.hidden_size) # project to llama hidden size
63
+ )
64
+
65
+ def forward(self, encoder_hidden_states):
66
+ layer_prompt_outputs = []
67
+ for idx, encoder_hidden_state in enumerate(encoder_hidden_states):
68
+ if idx in self.target_layer_ids:
69
+ layer_prompt = self.layer_prompts[self.target_layer_ids.index(idx)].expand(encoder_hidden_state.size(0), -1, -1)
70
+ qformer_output = self.qformer(
71
+ hidden_states=layer_prompt,
72
+ encoder_hidden_states=encoder_hidden_state,
73
+ )
74
+ layer_prompt_output = qformer_output.last_hidden_state
75
+ layer_prompt_outputs.append(layer_prompt_output)
76
+
77
+ layer_prompt_outputs = torch.stack(layer_prompt_outputs, dim=0)
78
+ layer_prompt_outputs = layer_prompt_outputs.permute(1, 2, 0, 3)
79
+
80
+ self.norm_weights = torch.nn.functional.softmax(self.layer_weights, dim=-1).unsqueeze(-1)
81
+
82
+ output = (layer_prompt_outputs * self.norm_weights).sum(dim=2) # (b, prompt_size, d_model)
83
+
84
+ output = self.proj(output)
85
+
86
+ return output
87
+
88
+ class SpeechPerception(PreTrainedModel):
89
+ def __init__(self, cfg):
90
+ super().__init__(cfg)
91
+ self.cfg = cfg
92
+
93
+ self.whisper = WhisperForConditionalGeneration.from_pretrained(cfg.whisper_model_id)
94
+ self.processor = AutoProcessor.from_pretrained(cfg.whisper_model_id)
95
+
96
+ self.connector = QformerConnector(cfg)
97
+
98
+ def generate(self, input_features):
99
+ input_features = input_features.to(self.whisper.device)
100
+
101
+ outputs = self.whisper.generate(inputs=input_features, return_dict_in_generate=True, output_hidden_states=True) # here we use default generate config for whisper
102
+
103
+ transcriptions = self.processor.batch_decode(outputs.sequences, skip_special_tokens=True)[0]
104
+ speech_features = self.connector(outputs.encoder_hidden_states)
105
+
106
+ return transcriptions, speech_features
107
+
108
+
109
+ class DestaModel(PreTrainedModel):
110
+ config_class = Desta2Config
111
+
112
+ def __init__(self, config):
113
+ super().__init__(config)
114
+
115
+ self.speech_perception = SpeechPerception(config)
116
+ self.llama = AutoModelForCausalLM.from_pretrained(config.llama_model_id, torch_dtype=torch.bfloat16)
117
+ self.tokenizer = AutoTokenizer.from_pretrained(config.llama_model_id)
118
+
119
+
120
+ def chat(self, messages, max_new_tokens=128, do_sample=True, temperature=0.6, top_p=0.9):
121
+ """
122
+ messages: list of dicts with keys "role" and "content"
123
+ ```
124
+ [
125
+ {"role": "system", "content": "You are a helpful voice assistant."},
126
+ {"role": "audio", "content": "<path_to_audio_file>"},
127
+ {"role": "user", "content": "Describe the audio."}
128
+ ]
129
+ ```
130
+ """
131
+
132
+ audio_path, input_features = self.load_audio(messages)
133
+ transcription, audio_features = self.speech_perception.generate(input_features)
134
+ inputs, audio_position = self.process_text(messages, audio_path, transcription)
135
+
136
+ inputs_embeds, attention_mask = self.prepare_llm_input(
137
+ input_ids=inputs.input_ids,
138
+ attention_mask=inputs.attention_mask,
139
+ audio_position=audio_position,
140
+ audio_features=audio_features
141
+ )
142
+
143
+ outputs = self.llama.generate(
144
+ inputs_embeds=inputs_embeds,
145
+ attention_mask=attention_mask,
146
+ pad_token_id=self.tokenizer.eos_token_id,
147
+ max_new_tokens=max_new_tokens,
148
+ do_sample=do_sample,
149
+ temperature=temperature,
150
+ top_p=top_p,
151
+ )
152
+ return outputs
153
+
154
+ def process_text(self, messages, audio_path, transcription):
155
+ context = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
156
+ left_text, right_text = context.split(audio_path)
157
+ right_text = transcription + right_text #
158
+
159
+ audio_position = len(self.tokenizer.tokenize(left_text))
160
+ context = left_text + right_text
161
+
162
+ inputs = self.tokenizer(context, return_tensors="pt")
163
+
164
+ return inputs, audio_position
165
+
166
+
167
+ def prepare_llm_input(self, input_ids, attention_mask, audio_position, audio_features):
168
+ input_ids = input_ids.to(self.llama.device)
169
+ attention_mask = attention_mask.to(self.llama.device)
170
+ audio_features = audio_features.to(self.llama.device)
171
+ audio_feature_length = audio_features.size(1)
172
+
173
+ inputs_embeds = self.llama.model.embed_tokens(input_ids) # [bs, seq_len, hidden_size]
174
+
175
+
176
+ inputs_embeds = torch.cat([inputs_embeds[0, :audio_position], audio_features[0, :], inputs_embeds[0, audio_position:]], dim=0)
177
+ attention_mask = torch.cat([attention_mask[0, :audio_position], torch.ones([ audio_feature_length], dtype=torch.long, device=self.llama.device), attention_mask[0, audio_position:]], dim=0)
178
+
179
+ inputs_embeds = inputs_embeds.to(self.llama.dtype)
180
+ attention_mask = attention_mask.to(self.llama.dtype)
181
+ return inputs_embeds.unsqueeze(0), attention_mask.unsqueeze(0)
182
+
183
+
184
+ def load_audio(self, messages):
185
+ audio_path = None
186
+ for message in messages:
187
+ if message["role"] == "audio" and audio_path is not None:
188
+ raise ValueError("Multiple audio file paths found in messages. We only support one audio file per message at this moment.")
189
+ if message["role"] == "audio":
190
+ audio_path = message["content"]
191
+ if audio_path is None:
192
+ raise ValueError("No audio file path found in messages")
193
+ audio, ori_sr = librosa.load(audio_path)
194
+ audio = librosa.resample(audio, orig_sr=ori_sr, target_sr=16000)
195
+ input_features = self.speech_perception.processor(audio, sampling_rate=16000, return_tensors="pt").input_features
196
+
197
+ return audio_path, input_features
198
+
199
+ @classmethod
200
+ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, config=None, cache_dir=None,**kwargs):
201
+ config = cls.config_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
202
+ model = cls(config)
203
+
204
+ if os.path.isdir(pretrained_model_name_or_path):
205
+ model.speech_perception.connector.load_state_dict(
206
+ torch.load(os.path.join(pretrained_model_name_or_path, "qformer_connector.pth"))
207
+ )
208
+ else:
209
+ from huggingface_hub import hf_hub_download
210
+ path = hf_hub_download(repo_id=pretrained_model_name_or_path, filename="qformer_connector.pth")
211
+ model.speech_perception.connector.load_state_dict(
212
+ torch.load(path)
213
+ )
214
+
215
+ return model
216
+
qformer_connector.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:25ec091d3a5e51f2d7e86d43b3867829c97f781017b21362d247b3985c74ad8f
3
+ size 89031593