"""Inference-only MERaLiON AudioLLM model compatible with HuggingFace weights."""
from functools import lru_cache
from typing import Iterable, List, Mapping, Optional, Tuple, TypedDict, Union

import librosa
import numpy as np
import torch
import torch.nn as nn

from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
                         InputContext, token_inputs)
from vllm.logger import init_logger
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.models.gemma2 import Gemma2Model
from vllm.model_executor.models.interfaces import (SupportsMultiModal,
                                                   SupportsPP)
from vllm.model_executor.models.utils import maybe_prefix
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal.utils import consecutive_placeholder_ranges
from vllm.sequence import IntermediateTensors, SequenceData

from .modeling_meralion import MERaLiONSpeechEncoder

logger = init_logger(__name__)


_KEYS_TO_MODIFY_MAPPING = {
    "text_decoder.model": "text_decoder",
}


class MERaLiONInputs(TypedDict):
    input_features: torch.Tensor
    """Shape: `(num_audios, num_mel_bins, 3000)`"""

    feature_attention_mask: torch.Tensor
    """Shape: `(num_audios, 3000)`"""


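# NOTE: the adapter below groups every `speech_mlp_scale_factor` (= 15)
# consecutive speech-encoder frames into a single vector and then projects it
# to the text hidden size. Assuming a Whisper-style encoder that halves the
# padded 3000-frame mel input to 1500 frames, this yields 1500 / 15 = 100
# audio tokens per clip, matching `get_max_meralion_audio_tokens` below.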
class MERaLiONSpeechAudioAdaper(nn.Module):

    def __init__(self, audio_hidden_size: int, text_hidden_size: int):
        super().__init__()
        speech_mlp_scale_factor = 15

        self.speech_mlp_scale_factor = speech_mlp_scale_factor
        self.mlp_adapter = nn.Sequential(
            nn.Linear(
                in_features=audio_hidden_size * speech_mlp_scale_factor,
                out_features=audio_hidden_size,
            ),
            nn.SiLU(),
            nn.Dropout(0.1),
        )

        self.speech_llm_proj = nn.Sequential(
            nn.Linear(audio_hidden_size, audio_hidden_size * 4),
            nn.SiLU(),
            nn.Dropout(0.1),
            nn.Linear(audio_hidden_size * 4, text_hidden_size),
        )

    def forward(self, speech_embeds, **kwargs):
        B, T, C = speech_embeds.shape
        speech_embeds = self.mlp_adapter(
            speech_embeds.reshape(
                B,
                T // self.speech_mlp_scale_factor,
                C * self.speech_mlp_scale_factor,
            ))
        return self.speech_llm_proj(speech_embeds)


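# NOTE: the dummy data below is used by vLLM for memory profiling only. It
# reserves `max_llm_audio_tokens` speech placeholder positions in the token
# sequence and pairs each audio item with an all-zero 16 kHz waveform.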
def dummy_data_for_meralion(ctx: InputContext, seq_len: int,
                            mm_counts: Mapping[str, int]):
    num_audios = mm_counts["audio"]
    max_tokens_per_audio = get_max_meralion_audio_tokens(ctx)
    max_llm_audio_tokens = max_tokens_per_audio * num_audios
    if seq_len - max_llm_audio_tokens - 2 < 0:
        raise RuntimeError(
            f"MERaLiON-AudioLLM cannot process {num_audios} audios in a "
            "prompt; please increase max_model_len or reduce the audio limit "
            "via --limit-mm-per-prompt.")

    speech_token_index = ctx.model_config.hf_config.speech_token_index

    dummy_seqdata = SequenceData.from_prompt_token_counts(
        (speech_token_index, max_llm_audio_tokens),
        (0, seq_len - max_llm_audio_tokens),
    )
    dummy_audio = np.full((max_llm_audio_tokens * 2 * 2 * 160, ), 0.)
    return DummyData(
        dummy_seqdata, {"audio": [(dummy_audio, 16000)] * num_audios}, {
            "audio":
            consecutive_placeholder_ranges(num_items=num_audios,
                                           item_size=max_tokens_per_audio)
        })


def get_processor(
    processor_name: str,
    *args,
    trust_remote_code: bool = True,
    **kwargs,
):
    """Gets a processor for the given model name via HuggingFace.

    Derived from `vllm.transformers_utils.image_processor.get_image_processor`.
    """
    from transformers import AutoProcessor

    try:
        processor = AutoProcessor.from_pretrained(
            processor_name,
            *args,
            trust_remote_code=trust_remote_code,
            **kwargs)
    except ValueError as e:
        # A custom processor class that is not shipped with transformers can
        # only be loaded with trust_remote_code enabled, so point the user at
        # that option instead of surfacing the raw error.
        if not trust_remote_code:
            err_msg = (
                "Failed to load the processor. If the processor is "
                "a custom processor not yet available in the HuggingFace "
                "transformers library, consider setting "
                "`trust_remote_code=True` in LLM or using the "
                "`--trust-remote-code` flag in the CLI.")
            raise RuntimeError(err_msg) from e
        else:
            raise e

    return processor


cached_get_processor = lru_cache(get_processor)


def get_max_meralion_audio_tokens(ctx: InputContext) -> int:
    """The maximum number of audio tokens produced by the speech audio adapter
    for a single audio input."""
    return 100


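# NOTE: the input processor below expands each single `speech_token_index`
# placeholder in the prompt into `get_max_meralion_audio_tokens(ctx)` copies,
# so the decoder sequence reserves one position for every audio embedding
# produced by the speech audio adapter.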
def input_processor_for_meralion(
        ctx: InputContext, inputs: DecoderOnlyInputs) -> DecoderOnlyInputs:
    multi_modal_data = inputs.get("multi_modal_data")
    if multi_modal_data is None or "audio" not in multi_modal_data:
        return inputs

    audios = multi_modal_data["audio"]
    if not isinstance(audios, list):
        audios = [audios]

    if len(audios) == 0:
        return inputs

    processor = cached_get_processor(ctx.model_config.model)
    resampled_audios = [
        librosa.resample(audio,
                         orig_sr=sampling_rate,
                         target_sr=processor.feature_extractor.sampling_rate)
        for audio, sampling_rate in audios
    ]

    audio_input_lengths = np.array(
        [min(3000, audio.shape[0] // 160 + 1) for audio in resampled_audios])

    audio_output_length = get_max_meralion_audio_tokens(ctx)
    speech_token_index = ctx.model_config.hf_config.speech_token_index

    input_ids = inputs['prompt_token_ids']

    new_input_ids = []
    audio_num = input_ids.count(speech_token_index)
    assert len(audio_input_lengths) == audio_num, \
        (f'The text input contains {audio_num} audio tokens, '
         f'but {len(audio_input_lengths)} audios were provided.')
    start = 0
    for _ in range(audio_num):
        end = input_ids.index(speech_token_index, start)
        new_input_ids.extend(input_ids[start:end])

        new_input_ids.extend([speech_token_index] * audio_output_length)
        start = end + 1
    new_input_ids.extend(input_ids[start:])

    return token_inputs(
        prompt_token_ids=new_input_ids,
        prompt=inputs['prompt'],
        multi_modal_data=multi_modal_data,
    )


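# NOTE: the input mapper below resamples each clip to the feature extractor's
# sampling rate, pads the extracted features to the extractor's fixed maximum
# length, and renames the returned attention mask to `feature_attention_mask`
# so it can be routed to the speech encoder.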
def input_mapper_for_meralion(
    ctx: InputContext,
    multi_modal_data: Union[np.ndarray, List[np.ndarray]],
) -> MultiModalKwargs:
    """Input mapper for MERaLiON-AudioLLM."""
    if not isinstance(multi_modal_data, list):
        multi_modal_data = [multi_modal_data]

    if len(multi_modal_data) == 0:
        return MultiModalKwargs()

    processor = cached_get_processor(ctx.model_config.model)
    audio_feature_extractor = processor.feature_extractor
    if audio_feature_extractor is None:
        raise RuntimeError(
            "No HuggingFace audio_feature_extractor is available "
            "to process the audio object")

    try:
        resampled_audios = [
            librosa.resample(
                audio,
                orig_sr=sampling_rate,
                target_sr=processor.feature_extractor.sampling_rate)
            for audio, sampling_rate in multi_modal_data
        ]
        batch_data = audio_feature_extractor(resampled_audios,
                                             sampling_rate=16000,
                                             return_attention_mask=True,
                                             padding="max_length",
                                             return_tensors="pt").data
        batch_data["feature_attention_mask"] = batch_data.pop("attention_mask")
    except Exception:
        logger.error("Failed to process audio (%s)", multi_modal_data)
        raise

    return MultiModalKwargs(batch_data)


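# The registry decorators below wire this model into vLLM's input pipeline:
# dummy data for profiling, the prompt-token expansion processor, the audio
# feature mapper, and the per-audio token budget.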
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_meralion)
@INPUT_REGISTRY.register_input_processor(input_processor_for_meralion)
@MULTIMODAL_REGISTRY.register_input_mapper("audio",
                                           input_mapper_for_meralion)
@MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
    "audio", get_max_meralion_audio_tokens)
class MERaLiONForConditionalGeneration(nn.Module, SupportsMultiModal,
                                       SupportsPP):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config
        self.config = config
        self.multimodal_config = multimodal_config

        self.speech_encoder = MERaLiONSpeechEncoder(config.speech_config)
        self.ln_speech = nn.LayerNorm(config.speech_config.d_model)
        self.speech_audio_adapter = MERaLiONSpeechAudioAdaper(
            config.speech_config.d_model, config.text_config.hidden_size)

        self.quant_config = quant_config

        self.text_decoder = Gemma2Model(
            vllm_config=vllm_config.with_hf_config(config.text_config),
            prefix=maybe_prefix(prefix, "model"))
        self.unpadded_vocab_size = config.text_config.vocab_size
        if config.text_config.tie_word_embeddings:
            self.lm_head = self.text_decoder.embed_tokens
        else:
            self.lm_head = ParallelLMHead(config.text_config.vocab_size,
                                          config.text_config.hidden_size,
                                          quant_config=quant_config)
        logit_scale = getattr(config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.text_config.vocab_size,
                                                logit_scale)
        self.sampler = get_sampler()

        self.make_empty_intermediate_tensors = (
            self.text_decoder.make_empty_intermediate_tensors)

    def _validate_and_reshape_mm_tensor(self,
                                        mm_input: Union[torch.Tensor,
                                                        List[torch.Tensor]],
                                        name: str) -> torch.Tensor:
        if not isinstance(mm_input, (torch.Tensor, list)):
            raise ValueError(f"Incorrect type of {name}. "
                             f"Got type: {type(mm_input)}")
        if isinstance(mm_input, torch.Tensor):
            return torch.concat(list(mm_input))
        else:
            return torch.concat(mm_input)

    def _parse_and_validate_audio_input(
            self, **kwargs: object) -> Optional[MERaLiONInputs]:
        input_features = kwargs.pop('input_features', None)
        feature_attention_mask = kwargs.pop('feature_attention_mask', None)
        if input_features is None:
            return None
        input_features = self._validate_and_reshape_mm_tensor(
            input_features, 'input_features')
        feature_attention_mask = self._validate_and_reshape_mm_tensor(
            feature_attention_mask, 'feature_attention_mask')
        if not isinstance(input_features, (torch.Tensor, list)):
            raise ValueError("Incorrect type of audio input features. "
                             f"Got type: {type(input_features)}")
        return MERaLiONInputs(input_features=input_features,
                              feature_attention_mask=feature_attention_mask)

    def _process_audio_input(self,
                             audio_input: MERaLiONInputs) -> torch.Tensor:
        input_features = audio_input["input_features"].to(
            self.speech_encoder.dtype)
        feature_attention_mask = audio_input["feature_attention_mask"]

        audio_outputs = self.speech_encoder(
            input_features, attention_mask=feature_attention_mask)
        audio_features = audio_outputs.last_hidden_state
        audio_features = self.ln_speech(audio_features)
        audio_features = self.speech_audio_adapter(audio_features)
        audio_features = audio_features.view(-1, audio_features.size(-1))

        return audio_features

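    # NOTE: when audio features are present, `forward` embeds the token ids,
    # overwrites the embeddings at every `speech_token_index` position with
    # the adapter outputs via a boolean mask, and then runs the Gemma2 decoder
    # on `inputs_embeds` alone (`input_ids` is set to None).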
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        if intermediate_tensors is not None:
            input_ids = None
            inputs_embeds = None
        else:
            audio_input = self._parse_and_validate_audio_input(**kwargs)

            if audio_input is None:
                inputs_embeds = None
            else:
                inputs_embeds = self.text_decoder.embed_tokens(input_ids)
                processed_audio_features = self._process_audio_input(
                    audio_input)

                mask = (input_ids == self.config.speech_token_index)
                inputs_embeds[mask, :] = processed_audio_features

                input_ids = None

        hidden_states = self.text_decoder(
            input_ids=input_ids,
            positions=positions,
            kv_caches=kv_caches,
            attn_metadata=attn_metadata,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        # (param_name, weight_name, shard_id)
        stacked_params_mapping = [
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters(remove_duplicate=False))

        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if (self.config.text_config.tie_word_embeddings
                    and "lm_head.weight" in name):
                continue
            for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
                if key_to_modify in name:
                    name = name.replace(key_to_modify, new_key)
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name or 'speech_encoder' in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue

                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)