yizhilll committed
Commit 7a45a98 · 1 Parent(s): 3bb10de

Update README.md

Files changed (1)
  1. README.md +15 -3
README.md CHANGED
@@ -65,21 +65,33 @@ from datasets import load_dataset
  dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
  dataset = dataset.sort("id")
  sampling_rate = dataset.features["audio"].sampling_rate
- processor = Wav2Vec2Processor.from_pretrained("facebook/hubert-large-ls960-ft")
+ processor = Wav2Vec2Processor.from_pretrained("m-a-p/MERT-v1-95M")
+
+ resample_rate = processor.feature_extractor.sampling_rate
+ # make sure the sampling rate is aligned
+ if resample_rate != sampling_rate:
+     resampler = T.Resample(sampling_rate, resample_rate)
+ else:
+     resampler = None

  # loading our model weights
  commit_hash='bccff5376fc07235d88954b43e5cd739fbc0796b' # this is recommended for security reasons, the hash might be updated
  model = AutoModel.from_pretrained("m-a-p/MERT-v1-95M", trust_remote_code=True, revision=commit_hash)

  # audio file is decoded on the fly
- inputs = processor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt")
+ if resampler is None:
+     input_audio = dataset[0]["audio"]["array"]
+ else:
+     input_audio = resampler(torch.from_numpy(dataset[0]["audio"]["array"]).float())
+
+ inputs = processor(input_audio, sampling_rate=resample_rate, return_tensors="pt")
  with torch.no_grad():
      outputs = model(**inputs, output_hidden_states=True)

  # take a look at the output shape, there are 13 layers of representation
  # each layer performs differently in different downstream tasks, you should choose empirically
  all_layer_hidden_states = torch.stack(outputs.hidden_states).squeeze()
- print(all_layer_hidden_states.shape) # [13 layer, 292 timestep, 768 feature_dim]
+ print(all_layer_hidden_states.shape) # [13 layer, Time steps, 768 feature_dim]

  # for utterance level classification tasks, you can simply reduce the representation in time
  time_reduced_hidden_states = all_layer_hidden_states.mean(-2)
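The comments in the diff note that each layer performs differently across downstream tasks. As a rough illustration of one common way to consume the stacked representations, here is a minimal sketch that assumes `all_layer_hidden_states` (shape `[13, time_steps, 768]`) from the snippet above is already computed; the `MERTProbe` class and the class count are hypothetical placeholders, and learnable layer weighting is just one possible probing setup, not something prescribed by this README.

```python
import torch
import torch.nn as nn

class MERTProbe(nn.Module):
    """Illustrative probe: learnable softmax weights over the 13 layers,
    mean-pooling over time, then a linear classification head."""
    def __init__(self, num_layers=13, hidden_dim=768, num_classes=10):
        super().__init__()
        self.layer_weights = nn.Parameter(torch.zeros(num_layers))
        self.head = nn.Linear(hidden_dim, num_classes)

    def forward(self, layer_states):
        # layer_states: [num_layers, time_steps, hidden_dim]
        weights = torch.softmax(self.layer_weights, dim=0)
        # weighted sum over layers, then mean-pool over time -> [hidden_dim]
        pooled = (weights[:, None, None] * layer_states).sum(dim=0).mean(dim=0)
        return self.head(pooled)

probe = MERTProbe(num_classes=10)        # 10 classes is a placeholder
logits = probe(all_layer_hidden_states)  # tensor computed in the snippet above
print(logits.shape)                      # torch.Size([10])
```

In practice such a probe would be trained on batches (add a leading batch dimension and pool over the time axis accordingly); the single-clip call here only shows the tensor shapes involved.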