Model Card for MusiLingo-short-v1
Model Details
Model Description
The model consists of a music encoder MERT-v1-330M, a natural language decoder vicuna-7b-delta-v0, and a linear projection layer between the two.
This checkpoint of MusiLingo is developed on the MusicInstruct (MI)-short data and can answer short instructions about raw music audio, such as queries about tempo, emotion, genre, and tags. You can use the MI dataset for the following demo.
Model Sources [optional]
Getting Started
from tqdm.auto import tqdm
import torch
from torch.utils.data import DataLoader
from transformers import AutoModel
from transformers import Wav2Vec2FeatureExtractor
from transformers import StoppingCriteria, StoppingCriteriaList
class StoppingCriteriaSub(StoppingCriteria):
    """Stop generation once the sequence ends with one of the given token patterns.

    Only the first sequence of the batch (``input_ids[0]``) is inspected.

    Args:
        stops: Iterable of 1-D token-id sequences (e.g. ``torch.LongTensor``).
            Generation stops when the tail of the generated sequence equals
            any one of them. ``None`` means "no stop patterns".
        encounters: Accepted for backward compatibility; it is not used.
    """

    def __init__(self, stops=None, encounters=1):
        super().__init__()
        # Avoid a shared mutable default: every instance gets its own list
        # when the caller supplies nothing.
        self.stops = [] if stops is None else stops

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs):
        """Return True if the first sequence currently ends with a stop pattern."""
        sequence = input_ids[0]
        for stop in self.stops:
            stop_len = len(stop)
            # A pattern longer than what has been generated so far cannot
            # match; comparing anyway would broadcast or raise on the
            # mismatched shapes. Empty patterns are ignored for the same reason.
            if stop_len == 0 or stop_len > len(sequence):
                continue
            # Compare on the device of the generated ids so that stop patterns
            # created on a different device do not raise.
            pattern = torch.as_tensor(stop, device=sequence.device)
            if torch.all(pattern == sequence[-stop_len:]).item():
                return True
        return False
def get_musilingo_pred(model, text, audio_path, stopping, length_penalty=1, temperature=0.1,
                       max_new_tokens=300, num_beams=1, min_length=1, top_p=0.5, repetition_penalty=1.0):
    """Generate MusiLingo's text response to an instruction about one audio file.

    Args:
        model: Object exposing ``encode_audio``, ``prompt_template``,
            ``instruction_prompt_wrap``, ``llama_tokenizer`` and ``llama_model``
            (the demo passes ``musilingo.model``).
        text: Instruction/question appended after the audio placeholder.
        audio_path: Path handed to ``load_audio``.
        stopping: Passed to ``generate`` as ``stopping_criteria``.
        length_penalty, temperature, max_new_tokens, num_beams, min_length,
        top_p, repetition_penalty: Forwarded unchanged to
            ``model.llama_model.generate``; sampling is always enabled.

    Returns:
        The decoded output, cut at the first ``'###'``, keeping only the text
        after the last ``'Assistant:'``, with surrounding whitespace stripped.
    """
    # see https://huggingface.co/m-a-p/MusiLingo-musicqa-v1 for load_audio function definition
    # NOTE(review): the crop length is computed with 16000 while target_sr is
    # 24000 (i.e. about 20 s of audio) -- confirm this matches training.
    audio = load_audio(audio_path, target_sr=24000,
                       is_mono=True,
                       is_normalize=False,
                       crop_to_length_in_sample_points=int(30*16000)+1,
                       crop_randomly=True,
                       pad=False).cuda()
    # NOTE: the feature extractor is loaded on every call.
    processor = Wav2Vec2FeatureExtractor.from_pretrained("m-a-p/MERT-v1-330M", trust_remote_code=True)
    audio = processor(audio,
                      sampling_rate=24000,
                      return_tensors="pt")['input_values'][0].cuda()

    # Pure inference: the result is a string, so no autograd graph is needed.
    with torch.no_grad():
        audio_embeds, atts_audio = model.encode_audio(audio)

        prompt = '<Audio><AudioHere></Audio> ' + text
        instruction_prompt = [model.prompt_template.format(prompt)]
        audio_embeds, atts_audio = model.instruction_prompt_wrap(audio_embeds, atts_audio, instruction_prompt)

        model.llama_tokenizer.padding_side = "right"

        batch_size = audio_embeds.shape[0]
        # Build the BOS ids on the same device as the embeddings they are
        # concatenated with, instead of assuming 'cuda'.
        bos = torch.ones([batch_size, 1],
                         dtype=torch.long,
                         device=audio_embeds.device) * model.llama_tokenizer.bos_token_id
        bos_embeds = model.llama_model.model.embed_tokens(bos)
        inputs_embeds = torch.cat([bos_embeds, audio_embeds], dim=1)

        outputs = model.llama_model.generate(
            inputs_embeds=inputs_embeds,
            max_new_tokens=max_new_tokens,
            stopping_criteria=stopping,
            num_beams=num_beams,
            do_sample=True,
            min_length=min_length,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            length_penalty=length_penalty,
            temperature=temperature,
        )

    output_token = outputs[0]
    # The model might output an unknown token <unk> (id 0) at the beginning,
    # possibly followed by a start token <s> (id 1); remove them in that order.
    # The length check keeps an empty generation from raising IndexError.
    for leading_id in (0, 1):
        if len(output_token) > 0 and output_token[0] == leading_id:
            output_token = output_token[1:]

    output_text = model.llama_tokenizer.decode(output_token, add_special_tokens=False)
    output_text = output_text.split('###')[0]  # remove the stop sign '###'
    output_text = output_text.split('Assistant:')[-1].strip()
    return output_text
# Load the MusiLingo checkpoint, move it to the GPU and switch to eval mode.
musilingo = AutoModel.from_pretrained("m-a-p/MusiLingo-short-v1", trust_remote_code=True)
musilingo.to("cuda")
musilingo.eval()

prompt = "this is the task instruction and input question for MusiLingo model"
# Named `audio_path` so that it matches the argument passed to
# get_musilingo_pred below (the name was previously undefined at the call).
audio_path = "/path/to/the/audio"
# Stop patterns given as token ids; presumably the tokenizer's encodings of
# the '###' stop sign -- verify against the tokenizer.
stopping = StoppingCriteriaList([StoppingCriteriaSub([torch.tensor([835]).cuda(),
                                                      torch.tensor([2277, 29937]).cuda()])])
response = get_musilingo_pred(musilingo.model, prompt, audio_path, stopping, length_penalty=100, temperature=0.1)
Citing This Work
If you find the work useful for your research, please consider citing it using the following BibTeX entry:
@inproceedings{deng2024musilingo,
title={MusiLingo: Bridging Music and Text with Pre-trained Language Models for Music Captioning and Query Response},
author={Deng, Zihao and Ma, Yinghao and Liu, Yudong and Guo, Rongchen and Zhang, Ge and Chen, Wenhu and Huang, Wenhao and Benetos, Emmanouil},
booktitle={Proceedings of the 2024 Annual Conference of the North American Chapter of the Association for Computational Linguistics (NAACL 2024)},
year={2024},
organization={Association for Computational Linguistics}
}