yizhilll commited on
Commit
8fd246b
1 Parent(s): 55fa29e

Update README.md

Browse files

fix the demo code

Files changed (1) hide show
  1. README.md +15 -10
README.md CHANGED
@@ -55,34 +55,39 @@ More details will be written in our coming-soon paper.
55
  # Model Usage
56
 
57
  ```python
58
- from transformers import Wav2Vec2Processor
 
59
  from transformers import AutoModel
60
  import torch
61
  from torch import nn
 
62
  from datasets import load_dataset
63
 
 
 
 
 
 
 
 
64
  # load demo audio and set processor
65
  dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
66
  dataset = dataset.sort("id")
67
  sampling_rate = dataset.features["audio"].sampling_rate
68
- processor = Wav2Vec2Processor.from_pretrained("m-a-p/MERT-v1-95M")
69
 
70
- resample_rate = processor.feature_extractor.sampling_rate
71
  # make sure the sample_rate is aligned
72
  if resample_rate != sampling_rate:
73
- resampler = T.Resample(sample_rate, resample_rate)
 
74
  else:
75
- resampler = None
76
-
77
- # loading our model weights
78
- commit_hash='bccff5376fc07235d88954b43e5cd739fbc0796b' # this is recommended for security reasons; the hash might be updated
79
- model = AutoModel.from_pretrained("m-a-p/MERT-v1-95M", trust_remote_code=True, revision=commit_hash)
80
 
81
  # audio file is decoded on the fly
82
  if resampler is None:
83
  input_audio = dataset[0]["audio"]["array"]
84
  else:
85
- input_audio = resampler(dataset[0]["audio"]["array"])
86
 
87
  inputs = processor(input_audio, sampling_rate=resample_rate, return_tensors="pt")
88
  with torch.no_grad():
 
55
  # Model Usage
56
 
57
  ```python
58
+ # from transformers import Wav2Vec2Processor
59
+ from transformers import Wav2Vec2FeatureExtractor
60
  from transformers import AutoModel
61
  import torch
62
  from torch import nn
63
+ import torchaudio.transforms as T
64
  from datasets import load_dataset
65
 
66
+
67
+ commit_hash='55fa29e5522049926c03d2ff9ae54d22c20e668f' # this is recommended for security reasons; the hash might be updated
68
+ # loading our model weights
69
+ model = AutoModel.from_pretrained("m-a-p/MERT-v1-95M", trust_remote_code=True, revision=commit_hash)
70
+ # loading the corresponding preprocessor config
71
+ processor = Wav2Vec2FeatureExtractor.from_pretrained("m-a-p/MERT-v1-95M", trust_remote_code=True, revision=commit_hash)
72
+
73
  # load demo audio and set processor
74
  dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
75
  dataset = dataset.sort("id")
76
  sampling_rate = dataset.features["audio"].sampling_rate
 
77
 
78
+ resample_rate = processor.sampling_rate
79
  # make sure the sample_rate is aligned
80
  if resample_rate != sampling_rate:
81
+ print(f'setting rate from {sampling_rate} to {resample_rate}')
82
+ resampler = T.Resample(sampling_rate, resample_rate)
83
  else:
84
+ resampler = None
 
 
 
 
85
 
86
  # audio file is decoded on the fly
87
  if resampler is None:
88
  input_audio = dataset[0]["audio"]["array"]
89
  else:
90
+ input_audio = resampler(torch.from_numpy(dataset[0]["audio"]["array"]))
91
 
92
  inputs = processor(input_audio, sampling_rate=resample_rate, return_tensors="pt")
93
  with torch.no_grad():