googlefan commited on
Commit
eac7d2b
·
verified ·
1 Parent(s): 1589c32

Create ultravox_processing.py

Browse files
Files changed (1) hide show
  1. ultravox_processing.py +210 -0
ultravox_processing.py ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Union
2
+
3
+ import numpy as np
4
+ import torch
5
+ import transformers
6
+
7
+ from .ultravox_config import UltravoxConfig
8
+
9
+
10
+ class UltravoxProcessor(transformers.ProcessorMixin):
11
+ """
12
+ Constructs an Ultravox processor which wraps an audio processor and a tokenizer into a single processor.
13
+ Args:
14
+ audio_processor: The audio processor for the audio encoder.
15
+ tokenizer: The tokenizer for the language model.
16
+ """
17
+
18
+ attributes = ["audio_processor", "tokenizer"]
19
+ audio_processor_class = (
20
+ "Wav2Vec2Processor",
21
+ "SeamlessM4TFeatureExtractor",
22
+ "WhisperProcessor",
23
+ )
24
+ tokenizer_class = (
25
+ "PreTrainedTokenizer",
26
+ "PreTrainedTokenizerFast",
27
+ )
28
+
29
+ tokenizer: transformers.PreTrainedTokenizerBase
30
+ audio_processor: transformers.ProcessorMixin
31
+
32
+ def __init__(
33
+ self,
34
+ audio_processor=None,
35
+ tokenizer=None,
36
+ audio_padding: str = "longest",
37
+ encoder_ds_factor: int = 320,
38
+ stack_factor: int = 8,
39
+ audio_placeholder: str = "<|audio|>",
40
+ ):
41
+ """
42
+ Args:
43
+ audio_processor: The audio processor for the audio encoder.
44
+ tokenizer: The tokenizer for the language model.
45
+ audio_padding: The padding strategy for the audio encoder.
46
+ encoder_ds_factor: The downsample factor of the audio encoder.
47
+ stack_factor: The factor by which the audio encoder output is stacked in the multimodal projector.
48
+ audio_placeholder: The placeholder for the audio in the text.
49
+ """
50
+ self.audio_padding = audio_padding
51
+ self.encoder_ds_factor = encoder_ds_factor
52
+ self.stack_factor = stack_factor
53
+ self.audio_placeholder = audio_placeholder
54
+ self.audio_token_replacement = tokenizer.eos_token
55
+ assert (
56
+ self.audio_token_replacement is not None
57
+ ), "The tokenizer has no EOS token. Cannot recover."
58
+ if tokenizer.pad_token_id is None:
59
+ tokenizer.pad_token_id = tokenizer.eos_token_id
60
+
61
+ super().__init__(audio_processor=audio_processor, tokenizer=tokenizer)
62
+
63
+ @classmethod
64
+ def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
65
+ config: UltravoxConfig = transformers.AutoConfig.from_pretrained(
66
+ pretrained_model_name_or_path, **kwargs
67
+ )
68
+ audio_processor = transformers.AutoProcessor.from_pretrained(
69
+ config.audio_model_id
70
+ or config.audio_config._name_or_path
71
+ or "facebook/wav2vec2-base-960h"
72
+ )
73
+
74
+ tokenizer = transformers.AutoTokenizer.from_pretrained(
75
+ pretrained_model_name_or_path, **kwargs
76
+ )
77
+ tokenizer.padding_side = "left"
78
+ tokenizer.pad_token = tokenizer.eos_token
79
+
80
+ return cls(
81
+ audio_processor=audio_processor,
82
+ tokenizer=tokenizer,
83
+ stack_factor=config.stack_factor,
84
+ )
85
+
86
+ def __call__(
87
+ self,
88
+ text: Optional[str] = None,
89
+ audio: Optional[Union[np.ndarray, torch.Tensor]] = None,
90
+ sampling_rate: Optional[int] = None,
91
+ return_tensors: Optional[
92
+ Union[str, transformers.TensorType]
93
+ ] = transformers.TensorType.PYTORCH,
94
+ **kwargs,
95
+ ) -> transformers.BatchFeature:
96
+ """
97
+ Main method to prepare for the model one text sequence and audio. This method forwards the `text`
98
+ and `kwargs` arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizerFast.__call__`] if `text` is not `None` to encode
99
+ the text. To prepare the audio(s), this method forwards the `audio`, `sampling_rate` and `kwargs` arguments to
100
+ audio processor's [`~Wav2Vec2Processor.__call__`] if `audio` is not `None`. Please refer to the docstring
101
+ of the above two methods for more information.
102
+ Args:
103
+ text (`str`, `List[str]`):
104
+ The sequence to be encoded. Sequence can be a string or (pretokenized string).
105
+ audio (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`):
106
+ The audio to be prepared. Audio can be NumPy array or PyTorch tensor. In case of a
107
+ NumPy array/PyTorch tensor, each audio should be of shape (C, T), where C is a number of channels, and T the
108
+ sample length of the audio.
109
+ sampling_rate (`int`, *optional*, defaults to 16000):
110
+ Sampling rate of the input audio. We expect 16kHz audio. Don't change this value unless you know what
111
+ you are doing.
112
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
113
+ If set, will return tensors of a particular framework. Acceptable values are:
114
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
115
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
116
+ - `'np'`: Return NumPy `np.ndarray` objects.
117
+ - `'jax'`: Return JAX `jnp.ndarray` objects.
118
+ Returns:
119
+ [`BatchFeature`]: A [`BatchFeature`] with the following fields:
120
+ - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
121
+ - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
122
+ `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
123
+ `None`).
124
+ - **audio_values** -- Processed audio values to be fed to a model. Returned when `audio` is not `None`.
125
+ - **audio_token_len** -- Predicted number of audio frames: this value is guaranteed to be a close upper bound.
126
+ Returned when `audio` is not `None`.
127
+ - **audio_token_start_idx** -- The index in the tokenized text where the audio starts. Returned when `audio` is not `None`.
128
+ """
129
+ # TODO: Add support for multiple audio and text inputs.
130
+ data = {}
131
+ audio_embed_frames = 0
132
+ if audio is not None and len(audio) > 0:
133
+ if self.audio_padding == "max_length":
134
+ # 30 seconds is the expected length for Whisper
135
+ assert sampling_rate is not None, "Sampling rate must be provided."
136
+ audio_len = 30 * sampling_rate
137
+ else:
138
+ audio_len = audio.shape[-1]
139
+ # It's guaranteed that the number of frames is less than or equal to this amount.
140
+ # For Whisper this is exact AFAICT, but for Wav2Vec2 it's an upper bound.
141
+ # Currently, StackAudioFrames makes sure an over-estimation won't cause issues by padding the audio embeddings.
142
+ nb_encoder_frames = int(round(audio_len / self.encoder_ds_factor + 1e-4))
143
+ audio_embed_frames = int(np.ceil(nb_encoder_frames / self.stack_factor))
144
+ data["audio_token_len"] = [audio_embed_frames]
145
+
146
+ # Main audio processing. The processor is model-specific.
147
+ x = self.audio_processor(
148
+ audio,
149
+ sampling_rate=sampling_rate,
150
+ padding="longest",
151
+ max_length=audio_len,
152
+ return_attention_mask=True,
153
+ **kwargs,
154
+ )
155
+ if "input_features" in x:
156
+ data["audio_values"] = x.input_features
157
+ else:
158
+ data["audio_values"] = x.input_values
159
+ if self.audio_padding == "max_length":
160
+ data["audio_len"] = x.attention_mask.sum(-1) - 1
161
+ else:
162
+ data["audio_len"] = [data["audio_values"].shape[-1]]
163
+
164
+ if text is not None:
165
+ assert isinstance(
166
+ text, str
167
+ ), "Text must be a string. Batch mode not supported yet."
168
+ if self.audio_placeholder in text:
169
+ if "audio_token_len" not in data:
170
+ raise ValueError(
171
+ f"audio must be provided when using audio placeholder ({self.audio_placeholder}) in text."
172
+ )
173
+
174
+ start_idx = len(
175
+ self.tokenizer.encode(
176
+ text[: text.index(self.audio_placeholder)],
177
+ add_special_tokens=False,
178
+ )
179
+ )
180
+ data["audio_token_start_idx"] = [start_idx]
181
+
182
+ # Replace the audio placeholder with the audio token.
183
+ # e.g. "Transcribe\n<|audio|>" -> "Transcribe </s></s></s></s></s></s></s></s>"
184
+ # where the number of </s> is the number of audio frames.
185
+ text = text.replace(
186
+ self.audio_placeholder,
187
+ self.audio_token_replacement * audio_embed_frames,
188
+ )
189
+
190
+ # Special tokens like BOS should already have been added by the caller.
191
+ data.update(self.tokenizer([text], add_special_tokens=False, **kwargs))
192
+
193
+ return transformers.BatchFeature(data=data, tensor_type=return_tensors)
194
+
195
+ def batch_decode(self, *args, **kwargs):
196
+ return self.tokenizer.batch_decode(*args, **kwargs)
197
+
198
+ def decode(self, *args, **kwargs):
199
+ return self.tokenizer.decode(*args, **kwargs)
200
+
201
+ @property
202
+ def model_input_names(self):
203
+ tokenizer_input_names = self.tokenizer.model_input_names
204
+ audio_processor_input_names = self.audio_processor.model_input_names
205
+ return list(set(tokenizer_input_names + audio_processor_input_names))
206
+
207
+
208
+ UltravoxProcessor.register_for_auto_class()
209
+
210
+ transformers.AutoProcessor.register(UltravoxConfig, UltravoxProcessor)