L0SG commited on
Commit
2cefcfb
1 Parent(s): 78885f4
.gitignore CHANGED
@@ -1,6 +1,137 @@
1
- *.pyc
2
- __pycache__/
3
- */__pycache__/
4
- alias_free_cuda/build/
5
  exp/
6
- tmp/
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # BigVGAN
2
+ alias_free_activation/cuda/build/
 
 
3
  exp/
4
+ tmp/
5
+
6
+ # VSCode configs
7
+ .vscode/
8
+
9
+ # Byte-compiled / optimized / DLL files
10
+ __pycache__/
11
+ *.py[cod]
12
+ *$py.class
13
+
14
+ # C extensions
15
+ *.so
16
+
17
+ # Distribution / packaging
18
+ .Python
19
+ build/
20
+ develop-eggs/
21
+ dist/
22
+ downloads/
23
+ eggs/
24
+ .eggs/
25
+ lib/
26
+ lib64/
27
+ parts/
28
+ sdist/
29
+ var/
30
+ wheels/
31
+ share/python-wheels/
32
+ *.egg-info/
33
+ .installed.cfg
34
+ *.egg
35
+ MANIFEST
36
+
37
+ # PyInstaller
38
+ # Usually these files are written by a python script from a template
39
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
40
+ *.manifest
41
+ *.spec
42
+
43
+ # Installer logs
44
+ pip-log.txt
45
+ pip-delete-this-directory.txt
46
+
47
+ # Unit test / coverage reports
48
+ htmlcov/
49
+ .tox/
50
+ .nox/
51
+ .coverage
52
+ .coverage.*
53
+ .cache
54
+ nosetests.xml
55
+ coverage.xml
56
+ *.cover
57
+ *.py,cover
58
+ .hypothesis/
59
+ .pytest_cache/
60
+ cover/
61
+
62
+ # Translations
63
+ *.mo
64
+ *.pot
65
+
66
+ # Django stuff:
67
+ *.log
68
+ local_settings.py
69
+ db.sqlite3
70
+ db.sqlite3-journal
71
+
72
+ # Flask stuff:
73
+ instance/
74
+ .webassets-cache
75
+
76
+ # Scrapy stuff:
77
+ .scrapy
78
+
79
+ # Sphinx documentation
80
+ docs/_build/
81
+
82
+ # PyBuilder
83
+ .pybuilder/
84
+ target/
85
+
86
+ # Jupyter Notebook
87
+ .ipynb_checkpoints
88
+
89
+ # IPython
90
+ profile_default/
91
+ ipython_config.py
92
+
93
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
94
+ __pypackages__/
95
+
96
+ # Celery stuff
97
+ celerybeat-schedule
98
+ celerybeat.pid
99
+
100
+ # SageMath parsed files
101
+ *.sage.py
102
+
103
+ # Environments
104
+ .env
105
+ .venv
106
+ env/
107
+ venv/
108
+ ENV/
109
+ env.bak/
110
+ venv.bak/
111
+
112
+ # Spyder project settings
113
+ .spyderproject
114
+ .spyproject
115
+
116
+ # Rope project settings
117
+ .ropeproject
118
+
119
+ # mkdocs documentation
120
+ /site
121
+
122
+ # mypy
123
+ .mypy_cache/
124
+ .dmypy.json
125
+ dmypy.json
126
+
127
+ # Pyre type checker
128
+ .pyre/
129
+
130
+ # pytype static type analyzer
131
+ .pytype/
132
+
133
+ # Cython debug symbols
134
+ cython_debug/
135
+
136
+ # PyCharm
137
+ .idea/
README_model.md CHANGED
@@ -1,37 +1,96 @@
1
  ## BigVGAN: A Universal Neural Vocoder with Large-Scale Training
 
2
  #### Sang-gil Lee, Wei Ping, Boris Ginsburg, Bryan Catanzaro, Sungroh Yoon
3
 
4
- <center><img src="https://user-images.githubusercontent.com/15963413/218609148-881e39df-33af-4af9-ab95-1427c4ebf062.png" width="800"></center>
5
 
 
6
 
7
- ### [Paper](https://arxiv.org/abs/2206.04658) &emsp; [Project page](https://research.nvidia.com/labs/adlr/projects/bigvgan/) &emsp; [Audio demo](https://bigvgan-demo.github.io/)
8
 
9
  ## News
10
- [Jul 2024] We release BigVGAN-v2 along with pretrained checkpoints. Below are the highlights:
11
- * Custom CUDA kernel for inference: we provide a fused upsampling + activation kernel written in CUDA for accelerated inference speed. Our test shows 1.5 - 3x faster speed on a single A100 GPU.
12
- * Improved discriminator and loss: BigVGAN-v2 is trained using a [multi-scale sub-band CQT discriminator](https://arxiv.org/abs/2311.14957) and a [multi-scale mel spectrogram loss](https://arxiv.org/abs/2306.06546).
13
- * Larger training data: BigVGAN-v2 is trained using datasets containing diverse audio types, including speech in multiple languages, environmental sounds, and instruments.
14
- * We provide pretrained checkpoints of BigVGAN-v2 using diverse audio configurations, supporting up to 44 kHz sampling rate and 512x upsampling ratio.
 
 
 
 
 
 
 
 
15
 
16
  ## Installation
 
17
  The codebase has been tested on Python `3.10` and PyTorch `2.3.1` conda packages with either `pytorch-cuda=12.1` or `pytorch-cuda=11.8`. Below is an example command to create the conda environment:
 
18
  ```shell
19
  conda create -n bigvgan python=3.10 pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia
20
  conda activate bigvgan
21
  ```
22
 
23
  Clone the repository and install dependencies:
 
24
  ```shell
25
  git clone https://github.com/NVIDIA/BigVGAN
26
  cd BigVGAN
27
  pip install -r requirements.txt
28
  ```
29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
 
 
 
 
 
 
 
 
31
 
32
  Create symbolic link to the root of the dataset. The codebase uses filelist with the relative path from the dataset. Below are the example commands for LibriTTS dataset:
33
- ``` shell
34
- cd LibriTTS && \
 
35
  ln -s /path/to/your/LibriTTS/train-clean-100 train-clean-100 && \
36
  ln -s /path/to/your/LibriTTS/train-clean-360 train-clean-360 && \
37
  ln -s /path/to/your/LibriTTS/train-other-500 train-other-500 && \
@@ -39,29 +98,30 @@ ln -s /path/to/your/LibriTTS/dev-clean dev-clean && \
39
  ln -s /path/to/your/LibriTTS/dev-other dev-other && \
40
  ln -s /path/to/your/LibriTTS/test-clean test-clean && \
41
  ln -s /path/to/your/LibriTTS/test-other test-other && \
42
- cd ..
43
  ```
44
 
45
- ## Training
46
  Train BigVGAN model. Below is an example command for training BigVGAN-v2 using LibriTTS dataset at 24kHz with a full 100-band mel spectrogram as input:
 
47
  ```shell
48
  python train.py \
49
  --config configs/bigvgan_v2_24khz_100band_256x.json \
50
- --input_wavs_dir LibriTTS \
51
- --input_training_file LibriTTS/train-full.txt \
52
- --input_validation_file LibriTTS/val-full.txt \
53
- --list_input_unseen_wavs_dir LibriTTS LibriTTS \
54
- --list_input_unseen_validation_file LibriTTS/dev-clean.txt LibriTTS/dev-other.txt \
55
  --checkpoint_path exp/bigvgan_v2_24khz_100band_256x
56
  ```
57
 
58
-
59
  ## Synthesis
 
60
  Synthesize from BigVGAN model. Below is an example command for generating audio from the model.
61
  It computes mel spectrograms using wav files from `--input_wavs_dir` and saves the generated audio to `--output_dir`.
 
62
  ```shell
63
  python inference.py \
64
- --checkpoint_file exp/bigvgan_v2_24khz_100band_256x/g_03000000 \
65
  --input_wavs_dir /path/to/your/input_wav \
66
  --output_dir /path/to/your/output_wav
67
  ```
@@ -70,14 +130,16 @@ python inference.py \
70
  It loads mel spectrograms from `--input_mels_dir` and saves the generated audio to `--output_dir`.
71
 
72
  Make sure that the STFT hyperparameters for mel spectrogram are the same as the model, which are defined in `config.json` of the corresponding model.
 
73
  ```shell
74
  python inference_e2e.py \
75
- --checkpoint_file exp/bigvgan_v2_24khz_100band_256x/g_03000000 \
76
  --input_mels_dir /path/to/your/input_mel \
77
  --output_dir /path/to/your/output_wav
78
  ```
79
 
80
  ## Using Custom CUDA Kernel for Synthesis
 
81
  You can apply the fast CUDA inference kernel by using a parameter `use_cuda_kernel` when instantiating BigVGAN:
82
 
83
  ```python
@@ -86,15 +148,15 @@ generator = BigVGAN(h, use_cuda_kernel=True)
86
 
87
  You can also pass `--use_cuda_kernel` to `inference.py` and `inference_e2e.py` to enable this feature.
88
 
89
- When applied for the first time, it builds the kernel using `nvcc` and `ninja`. If the build succeeds, the kernel is saved to `alias_free_cuda/build` and the model automatically loads the kernel. The codebase has been tested using CUDA `12.1`.
90
 
91
  Please make sure that both are installed in your system and `nvcc` installed in your system matches the version your PyTorch build is using.
92
 
93
  We recommend running `test_cuda_vs_torch_model.py` first to build and check the correctness of the CUDA kernel. See below example command and its output, where it returns `[Success] test CUDA fused vs. plain torch BigVGAN inference`:
94
 
95
  ```python
96
- python test_cuda_vs_torch_model.py \
97
- --checkpoint_file /path/to/your/bigvgan/g_03000000
98
  ```
99
 
100
  ```shell
@@ -102,12 +164,12 @@ loading plain Pytorch BigVGAN
102
  ...
103
  loading CUDA kernel BigVGAN with auto-build
104
  Detected CUDA files, patching ldflags
105
- Emitting ninja build file /path/to/your/BigVGAN/alias_free_cuda/build/build.ninja...
106
  Building extension module anti_alias_activation_cuda...
107
  ...
108
  Loading extension module anti_alias_activation_cuda...
109
  ...
110
- Loading '/path/to/your/bigvgan/g_03000000'
111
  ...
112
  [Success] test CUDA fused vs. plain torch BigVGAN inference
113
  > mean_difference=0.0007238413265440613
@@ -116,30 +178,34 @@ Loading '/path/to/your/bigvgan/g_03000000'
116
 
117
  If you see `[Fail] test CUDA fused vs. plain torch BigVGAN inference`, it means that the CUDA kernel inference is incorrect. Please check if `nvcc` installed in your system is compatible with your PyTorch version.
118
 
119
-
120
  ## Pretrained Models
121
- We provide the [pretrained models](https://drive.google.com/drive/folders/1L2RDeJMBE7QAI8qV51n0QAf4mkSgUUeE?usp=sharing).
122
- One can download the checkpoints of the generator weight (e.g., `g_(training_steps)`) and its discriminator/optimizer states (e.g., `do_(training_steps)`) within the listed folders.
123
-
124
- |Folder Name|Sampling Rate|Mel band|fmax|Upsampling Ratio|Params.|Dataset|Fine-Tuned|
125
- |------|---|---|---|---|---|------|---|
126
- |bigvgan_v2_44khz_128band_512x|44 kHz|128|22050|512|122M|Large-scale Compilation|No|
127
- |bigvgan_v2_44khz_128band_256x|44 kHz|128|22050|256|112M|Large-scale Compilation|No|
128
- |bigvgan_v2_24khz_100band_256x|24 kHz|100|12000|256|112M|Large-scale Compilation|No|
129
- |bigvgan_v2_22khz_80band_256x|22 kHz|80|11025|256|112M|Large-scale Compilation|No|
130
- |bigvgan_v2_22khz_80band_fmax8k_256x|22 kHz|80|8000|256|112M|Large-scale Compilation|No|
131
- |bigvgan_24khz_100band|24 kHz|100|12000|256|112M|LibriTTS|No|
132
- |bigvgan_base_24khz_100band|24 kHz|100|12000|256|14M|LibriTTS|No|
133
- |bigvgan_22khz_80band|22 kHz|80|8000|256|112M|LibriTTS + VCTK + LJSpeech|No|
134
- |bigvgan_base_22khz_80band|22 kHz|80|8000|256|14M|LibriTTS + VCTK + LJSpeech|No|
 
135
 
136
  The paper results are based on the original 24kHz BigVGAN models (`bigvgan_24khz_100band` and `bigvgan_base_24khz_100band`) trained on LibriTTS dataset.
137
  We also provide 22kHz BigVGAN models with band-limited setup (i.e., fmax=8000) for TTS applications.
138
- Note that the checkpoints use ``snakebeta`` activation with log scale parameterization, which have the best overall quality.
 
 
139
 
140
- You can fine-tune the models by downloading the checkpoints (both the generator weight and its discrimiantor/optimizer states) and resuming training using your audio dataset.
 
141
 
142
  ## Training Details of BigVGAN-v2
 
143
  Comapred to the original BigVGAN, the pretrained checkpoints of BigVGAN-v2 used `batch_size=32` with a longer `segment_size=65536` and are trained using 8 A100 GPUs.
144
 
145
  Note that the BigVGAN-v2 `json` config files in `./configs` use `batch_size=4` as default to fit in a single A100 GPU for training. You can fine-tune the models adjusting `batch_size` depending on your GPUs.
@@ -147,23 +213,50 @@ Note that the BigVGAN-v2 `json` config files in `./configs` use `batch_size=4` a
147
  When training BigVGAN-v2 from scratch with small batch size, it can potentially encounter the early divergence problem mentioned in the paper. In such case, we recommend lowering the `clip_grad_norm` value (e.g. `100`) for the early training iterations (e.g. 20000 steps) and increase the value to the default `500`.
148
 
149
  ## Evaluation Results of BigVGAN-v2
 
150
  Below are the objective results of the 24kHz model (`bigvgan_v2_24khz_100band_256x`) obtained from the LibriTTS `dev` sets. BigVGAN-v2 shows noticeable improvements of the metrics. The model also exhibits reduced perceptual artifacts, especially for non-speech audio.
151
 
152
- |Model|Dataset|Steps|PESQ(↑)|M-STFT(↓)|MCD(↓)|Periodicity(↓)|V/UV F1(↑)|
153
- |-------|-----|-----|-----|-----|-----|-----|-----|
154
- |BigVGAN|LibriTTS|1M|4.027|0.7997|0.3745|0.1018|0.9598|
155
- |BigVGAN|LibriTTS|5M|4.256|0.7409|0.2988|0.0809|0.9698|
156
- |BigVGAN-v2|Large-scale Compilation|3M|**4.359**|**0.7134**|0.3060|**0.0621**|**0.9777**|
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
157
 
158
  ## Acknowledgements
 
159
  We thank Vijay Anand Korthikanti and Kevin J. Shih for their generous support in implementing the CUDA kernel for inference.
160
 
161
  ## References
162
- * [HiFi-GAN](https://github.com/jik876/hifi-gan) (for generator and multi-period discriminator)
163
- * [Snake](https://github.com/EdwardDixon/snake) (for periodic activation)
164
- * [Alias-free-torch](https://github.com/junjun3518/alias-free-torch) (for anti-aliasing)
165
- * [Julius](https://github.com/adefossez/julius) (for low-pass filter)
166
- * [UnivNet](https://github.com/mindslab-ai/univnet) (for multi-resolution discriminator)
167
- * [descript-audio-codec](https://github.com/descriptinc/descript-audio-codec) and [vocos](https://github.com/gemelo-ai/vocos) (for multi-band multi-scale STFT discriminator and multi-scale mel spectrogram loss)
168
- * [Amphion](https://github.com/open-mmlab/Amphion) (for multi-scale sub-band CQT discriminator)
169
 
 
 
 
 
 
 
 
 
1
  ## BigVGAN: A Universal Neural Vocoder with Large-Scale Training
2
+
3
  #### Sang-gil Lee, Wei Ping, Boris Ginsburg, Bryan Catanzaro, Sungroh Yoon
4
 
5
+ [[Paper]](https://arxiv.org/abs/2206.04658) - [[Code]](https://github.com/NVIDIA/BigVGAN) - [[Showcase]](https://bigvgan-demo.github.io/) - [[Project Page]](https://research.nvidia.com/labs/adlr/projects/bigvgan/) - [[Weights]](https://huggingface.co/collections/nvidia/bigvgan-66959df3d97fd7d98d97dc9a) - [[Demo]](https://huggingface.co/spaces/nvidia/BigVGAN)
6
 
7
+ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/bigvgan-a-universal-neural-vocoder-with-large/speech-synthesis-on-libritts)](https://paperswithcode.com/sota/speech-synthesis-on-libritts?p=bigvgan-a-universal-neural-vocoder-with-large)
8
 
9
+ <center><img src="https://user-images.githubusercontent.com/15963413/218609148-881e39df-33af-4af9-ab95-1427c4ebf062.png" width="800"></center>
10
 
11
  ## News
12
+ - **Jul 2024 (v2.3):**
13
+ - General refactor and code improvements for improved readability.
14
+ - Fully fused CUDA kernel of anti-alised activation (upsampling + activation + downsampling) with inference speed benchmark.
15
+
16
+ - **Jul 2024 (v2.2):** The repository now includes an interactive local demo using gradio.
17
+
18
+ - **Jul 2024 (v2.1):** BigVGAN is now integrated with 🤗 Hugging Face Hub with easy access to inference using pretrained checkpoints. We also provide an interactive demo on Hugging Face Spaces.
19
+
20
+ - **Jul 2024 (v2):** We release BigVGAN-v2 along with pretrained checkpoints. Below are the highlights:
21
+ - Custom CUDA kernel for inference: we provide a fused upsampling + activation kernel written in CUDA for accelerated inference speed. Our test shows 1.5 - 3x faster speed on a single A100 GPU.
22
+ - Improved discriminator and loss: BigVGAN-v2 is trained using a [multi-scale sub-band CQT discriminator](https://arxiv.org/abs/2311.14957) and a [multi-scale mel spectrogram loss](https://arxiv.org/abs/2306.06546).
23
+ - Larger training data: BigVGAN-v2 is trained using datasets containing diverse audio types, including speech in multiple languages, environmental sounds, and instruments.
24
+ - We provide pretrained checkpoints of BigVGAN-v2 using diverse audio configurations, supporting up to 44 kHz sampling rate and 512x upsampling ratio.
25
 
26
  ## Installation
27
+
28
  The codebase has been tested on Python `3.10` and PyTorch `2.3.1` conda packages with either `pytorch-cuda=12.1` or `pytorch-cuda=11.8`. Below is an example command to create the conda environment:
29
+
30
  ```shell
31
  conda create -n bigvgan python=3.10 pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia
32
  conda activate bigvgan
33
  ```
34
 
35
  Clone the repository and install dependencies:
36
+
37
  ```shell
38
  git clone https://github.com/NVIDIA/BigVGAN
39
  cd BigVGAN
40
  pip install -r requirements.txt
41
  ```
42
 
43
+ ## Inference Quickstart using 🤗 Hugging Face Hub
44
+
45
+ Below example describes how you can use BigVGAN: load the pretrained BigVGAN generator from Hugging Face Hub, compute mel spectrogram from input waveform, and generate synthesized waveform using the mel spectrogram as the model's input.
46
+
47
+ ```python
48
+ device = 'cuda'
49
+
50
+ import torch
51
+ import bigvgan
52
+ import librosa
53
+ from meldataset import get_mel_spectrogram
54
+
55
+ # instantiate the model. You can optionally set use_cuda_kernel=True for faster inference.
56
+ model = bigvgan.BigVGAN.from_pretrained('nvidia/bigvgan_v2_24khz_100band_256x', use_cuda_kernel=False)
57
+
58
+ # remove weight norm in the model and set to eval mode
59
+ model.remove_weight_norm()
60
+ model = model.eval().to(device)
61
+
62
+ # load wav file and compute mel spectrogram
63
+ wav_path = '/path/to/your/audio.wav'
64
+ wav, sr = librosa.load(wav_path, sr=model.h.sampling_rate, mono=True) # wav is np.ndarray with shape [T_time] and values in [-1, 1]
65
+ wav = torch.FloatTensor(wav).unsqueeze(0) # wav is FloatTensor with shape [B(1), T_time]
66
+
67
+ # compute mel spectrogram from the ground truth audio
68
+ mel = get_mel_spectrogram(wav, model.h).to(device) # mel is FloatTensor with shape [B(1), C_mel, T_frame]
69
+
70
+ # generate waveform from mel
71
+ with torch.inference_mode():
72
+ wav_gen = model(mel) # wav_gen is FloatTensor with shape [B(1), 1, T_time] and values in [-1, 1]
73
+ wav_gen_float = wav_gen.squeeze(0).cpu() # wav_gen is FloatTensor with shape [1, T_time]
74
+
75
+ # you can convert the generated waveform to 16 bit linear PCM
76
+ wav_gen_int16 = (wav_gen_float * 32767.0).numpy().astype('int16') # wav_gen is now np.ndarray with shape [1, T_time] and int16 dtype
77
+ ```
78
+
79
+ ## Local gradio demo <a href='https://github.com/gradio-app/gradio'><img src='https://img.shields.io/github/stars/gradio-app/gradio'></a>
80
 
81
+ You can run a local gradio demo using below command:
82
+
83
+ ```python
84
+ pip install -r demo/requirements.txt
85
+ python demo/app.py
86
+ ```
87
+
88
+ ## Training
89
 
90
  Create symbolic link to the root of the dataset. The codebase uses filelist with the relative path from the dataset. Below are the example commands for LibriTTS dataset:
91
+
92
+ ```shell
93
+ cd filelists/LibriTTS && \
94
  ln -s /path/to/your/LibriTTS/train-clean-100 train-clean-100 && \
95
  ln -s /path/to/your/LibriTTS/train-clean-360 train-clean-360 && \
96
  ln -s /path/to/your/LibriTTS/train-other-500 train-other-500 && \
 
98
  ln -s /path/to/your/LibriTTS/dev-other dev-other && \
99
  ln -s /path/to/your/LibriTTS/test-clean test-clean && \
100
  ln -s /path/to/your/LibriTTS/test-other test-other && \
101
+ cd ../..
102
  ```
103
 
 
104
  Train BigVGAN model. Below is an example command for training BigVGAN-v2 using LibriTTS dataset at 24kHz with a full 100-band mel spectrogram as input:
105
+
106
  ```shell
107
  python train.py \
108
  --config configs/bigvgan_v2_24khz_100band_256x.json \
109
+ --input_wavs_dir filelists/LibriTTS \
110
+ --input_training_file filelists/LibriTTS/train-full.txt \
111
+ --input_validation_file filelists/LibriTTS/val-full.txt \
112
+ --list_input_unseen_wavs_dir filelists/LibriTTS filelists/LibriTTS \
113
+ --list_input_unseen_validation_file filelists/LibriTTS/dev-clean.txt filelists/LibriTTS/dev-other.txt \
114
  --checkpoint_path exp/bigvgan_v2_24khz_100band_256x
115
  ```
116
 
 
117
  ## Synthesis
118
+
119
  Synthesize from BigVGAN model. Below is an example command for generating audio from the model.
120
  It computes mel spectrograms using wav files from `--input_wavs_dir` and saves the generated audio to `--output_dir`.
121
+
122
  ```shell
123
  python inference.py \
124
+ --checkpoint_file /path/to/your/bigvgan_v2_24khz_100band_256x/bigvgan_generator.pt \
125
  --input_wavs_dir /path/to/your/input_wav \
126
  --output_dir /path/to/your/output_wav
127
  ```
 
130
  It loads mel spectrograms from `--input_mels_dir` and saves the generated audio to `--output_dir`.
131
 
132
  Make sure that the STFT hyperparameters for mel spectrogram are the same as the model, which are defined in `config.json` of the corresponding model.
133
+
134
  ```shell
135
  python inference_e2e.py \
136
+ --checkpoint_file /path/to/your/bigvgan_v2_24khz_100band_256x/bigvgan_generator.pt \
137
  --input_mels_dir /path/to/your/input_mel \
138
  --output_dir /path/to/your/output_wav
139
  ```
140
 
141
  ## Using Custom CUDA Kernel for Synthesis
142
+
143
  You can apply the fast CUDA inference kernel by using a parameter `use_cuda_kernel` when instantiating BigVGAN:
144
 
145
  ```python
 
148
 
149
  You can also pass `--use_cuda_kernel` to `inference.py` and `inference_e2e.py` to enable this feature.
150
 
151
+ When applied for the first time, it builds the kernel using `nvcc` and `ninja`. If the build succeeds, the kernel is saved to `alias_free_activation/cuda/build` and the model automatically loads the kernel. The codebase has been tested using CUDA `12.1`.
152
 
153
  Please make sure that both are installed in your system and `nvcc` installed in your system matches the version your PyTorch build is using.
154
 
155
  We recommend running `test_cuda_vs_torch_model.py` first to build and check the correctness of the CUDA kernel. See below example command and its output, where it returns `[Success] test CUDA fused vs. plain torch BigVGAN inference`:
156
 
157
  ```python
158
+ python tests/test_cuda_vs_torch_model.py \
159
+ --checkpoint_file /path/to/your/bigvgan_generator.pt
160
  ```
161
 
162
  ```shell
 
164
  ...
165
  loading CUDA kernel BigVGAN with auto-build
166
  Detected CUDA files, patching ldflags
167
+ Emitting ninja build file /path/to/your/BigVGAN/alias_free_activation/cuda/build/build.ninja..
168
  Building extension module anti_alias_activation_cuda...
169
  ...
170
  Loading extension module anti_alias_activation_cuda...
171
  ...
172
+ Loading '/path/to/your/bigvgan_generator.pt'
173
  ...
174
  [Success] test CUDA fused vs. plain torch BigVGAN inference
175
  > mean_difference=0.0007238413265440613
 
178
 
179
  If you see `[Fail] test CUDA fused vs. plain torch BigVGAN inference`, it means that the CUDA kernel inference is incorrect. Please check if `nvcc` installed in your system is compatible with your PyTorch version.
180
 
 
181
  ## Pretrained Models
182
+
183
+ We provide the [pretrained models on Hugging Face Collections](https://huggingface.co/collections/nvidia/bigvgan-66959df3d97fd7d98d97dc9a).
184
+ One can download the checkpoints of the generator weight (named `bigvgan_generator.pt`) and its discriminator/optimizer states (named `bigvgan_discriminator_optimizer.pt`) within the listed model repositories.
185
+
186
+ | Model Name | Sampling Rate | Mel band | fmax | Upsampling Ratio | Params | Dataset | Steps | Fine-Tuned |
187
+ |:--------------------------------------------------------------------------------------------------------:|:-------------:|:--------:|:-----:|:----------------:|:------:|:--------------------------:|:-----:|:----------:|
188
+ | [bigvgan_v2_44khz_128band_512x](https://huggingface.co/nvidia/bigvgan_v2_44khz_128band_512x) | 44 kHz | 128 | 22050 | 512 | 122M | Large-scale Compilation | 3M | No |
189
+ | [bigvgan_v2_44khz_128band_256x](https://huggingface.co/nvidia/bigvgan_v2_44khz_128band_256x) | 44 kHz | 128 | 22050 | 256 | 112M | Large-scale Compilation | 3M | No |
190
+ | [bigvgan_v2_24khz_100band_256x](https://huggingface.co/nvidia/bigvgan_v2_24khz_100band_256x) | 24 kHz | 100 | 12000 | 256 | 112M | Large-scale Compilation | 3M | No |
191
+ | [bigvgan_v2_22khz_80band_256x](https://huggingface.co/nvidia/bigvgan_v2_22khz_80band_256x) | 22 kHz | 80 | 11025 | 256 | 112M | Large-scale Compilation | 3M | No |
192
+ | [bigvgan_v2_22khz_80band_fmax8k_256x](https://huggingface.co/nvidia/bigvgan_v2_22khz_80band_fmax8k_256x) | 22 kHz | 80 | 8000 | 256 | 112M | Large-scale Compilation | 3M | No |
193
+ | [bigvgan_24khz_100band](https://huggingface.co/nvidia/bigvgan_24khz_100band) | 24 kHz | 100 | 12000 | 256 | 112M | LibriTTS | 5M | No |
194
+ | [bigvgan_base_24khz_100band](https://huggingface.co/nvidia/bigvgan_base_24khz_100band) | 24 kHz | 100 | 12000 | 256 | 14M | LibriTTS | 5M | No |
195
+ | [bigvgan_22khz_80band](https://huggingface.co/nvidia/bigvgan_22khz_80band) | 22 kHz | 80 | 8000 | 256 | 112M | LibriTTS + VCTK + LJSpeech | 5M | No |
196
+ | [bigvgan_base_22khz_80band](https://huggingface.co/nvidia/bigvgan_base_22khz_80band) | 22 kHz | 80 | 8000 | 256 | 14M | LibriTTS + VCTK + LJSpeech | 5M | No |
197
 
198
  The paper results are based on the original 24kHz BigVGAN models (`bigvgan_24khz_100band` and `bigvgan_base_24khz_100band`) trained on LibriTTS dataset.
199
  We also provide 22kHz BigVGAN models with band-limited setup (i.e., fmax=8000) for TTS applications.
200
+ Note that the checkpoints use `snakebeta` activation with log scale parameterization, which have the best overall quality.
201
+
202
+ You can fine-tune the models by:
203
 
204
+ 1. downloading the checkpoints (both the generator weight and its discriminator/optimizer states)
205
+ 2. resuming training using your audio dataset by specifying `--checkpoint_path` that includes the checkpoints when launching `train.py`
206
 
207
  ## Training Details of BigVGAN-v2
208
+
209
  Comapred to the original BigVGAN, the pretrained checkpoints of BigVGAN-v2 used `batch_size=32` with a longer `segment_size=65536` and are trained using 8 A100 GPUs.
210
 
211
  Note that the BigVGAN-v2 `json` config files in `./configs` use `batch_size=4` as default to fit in a single A100 GPU for training. You can fine-tune the models adjusting `batch_size` depending on your GPUs.
 
213
  When training BigVGAN-v2 from scratch with small batch size, it can potentially encounter the early divergence problem mentioned in the paper. In such case, we recommend lowering the `clip_grad_norm` value (e.g. `100`) for the early training iterations (e.g. 20000 steps) and increase the value to the default `500`.
214
 
215
  ## Evaluation Results of BigVGAN-v2
216
+
217
  Below are the objective results of the 24kHz model (`bigvgan_v2_24khz_100band_256x`) obtained from the LibriTTS `dev` sets. BigVGAN-v2 shows noticeable improvements of the metrics. The model also exhibits reduced perceptual artifacts, especially for non-speech audio.
218
 
219
+ | Model | Dataset | Steps | PESQ(↑) | M-STFT(↓) | MCD(↓) | Periodicity(↓) | V/UV F1(↑) |
220
+ |:----------:|:-----------------------:|:-----:|:---------:|:----------:|:------:|:--------------:|:----------:|
221
+ | BigVGAN | LibriTTS | 1M | 4.027 | 0.7997 | 0.3745 | 0.1018 | 0.9598 |
222
+ | BigVGAN | LibriTTS | 5M | 4.256 | 0.7409 | 0.2988 | 0.0809 | 0.9698 |
223
+ | BigVGAN-v2 | Large-scale Compilation | 3M | **4.359** | **0.7134** | 0.3060 | **0.0621** | **0.9777** |
224
+
225
+ ## Speed Benchmark
226
+
227
+ Below are the speed and VRAM usage benchmark results of BigVGAN from `tests/test_cuda_vs_torch_model.py`, using `bigvgan_v2_24khz_100band_256x` as a reference model.
228
+
229
+ | GPU | num_mel_frame | use_cuda_kernel | Speed (kHz) | Real-time Factor | VRAM (GB) |
230
+ |:--------------------------:|:-------------:|:---------------:|:-----------:|:----------------:|:---------:|
231
+ | NVIDIA A100 | 256 | False | 1672.1 | 69.7x | 1.3 |
232
+ | | | True | 3916.5 | 163.2x | 1.3 |
233
+ | | 2048 | False | 1899.6 | 79.2x | 1.7 |
234
+ | | | True | 5330.1 | 222.1x | 1.7 |
235
+ | | 16384 | False | 1973.8 | 82.2x | 5.0 |
236
+ | | | True | 5761.7 | 240.1x | 4.4 |
237
+ | NVIDIA GeForce RTX 3080 | 256 | False | 841.1 | 35.0x | 1.3 |
238
+ | | | True | 1598.1 | 66.6x | 1.3 |
239
+ | | 2048 | False | 929.9 | 38.7x | 1.7 |
240
+ | | | True | 1971.3 | 82.1x | 1.6 |
241
+ | | 16384 | False | 943.4 | 39.3x | 5.0 |
242
+ | | | True | 2026.5 | 84.4x | 3.9 |
243
+ | NVIDIA GeForce RTX 2080 Ti | 256 | False | 515.6 | 21.5x | 1.3 |
244
+ | | | True | 811.3 | 33.8x | 1.3 |
245
+ | | 2048 | False | 576.5 | 24.0x | 1.7 |
246
+ | | | True | 1023.0 | 42.6x | 1.5 |
247
+ | | 16384 | False | 589.4 | 24.6x | 5.0 |
248
+ | | | True | 1068.1 | 44.5x | 3.2 |
249
 
250
  ## Acknowledgements
251
+
252
  We thank Vijay Anand Korthikanti and Kevin J. Shih for their generous support in implementing the CUDA kernel for inference.
253
 
254
  ## References
 
 
 
 
 
 
 
255
 
256
+ - [HiFi-GAN](https://github.com/jik876/hifi-gan) (for generator and multi-period discriminator)
257
+ - [Snake](https://github.com/EdwardDixon/snake) (for periodic activation)
258
+ - [Alias-free-torch](https://github.com/junjun3518/alias-free-torch) (for anti-aliasing)
259
+ - [Julius](https://github.com/adefossez/julius) (for low-pass filter)
260
+ - [UnivNet](https://github.com/mindslab-ai/univnet) (for multi-resolution discriminator)
261
+ - [descript-audio-codec](https://github.com/descriptinc/descript-audio-codec) and [vocos](https://github.com/gemelo-ai/vocos) (for multi-band multi-scale STFT discriminator and multi-scale mel spectrogram loss)
262
+ - [Amphion](https://github.com/open-mmlab/Amphion) (for multi-scale sub-band CQT discriminator)
{alias_free_cuda → alias_free_activation/cuda}/__init__.py RENAMED
File without changes
{alias_free_cuda → alias_free_activation/cuda}/activation1d.py RENAMED
@@ -3,36 +3,45 @@
3
 
4
  import torch
5
  import torch.nn as nn
6
- from alias_free_torch.resample import UpSample1d, DownSample1d
 
7
  # load fused CUDA kernel: this enables importing anti_alias_activation_cuda
8
- from alias_free_cuda import load
9
- load.load()
 
 
10
 
11
  class FusedAntiAliasActivation(torch.autograd.Function):
12
  """
13
- Assumes filter size 12, replication padding on upsampling, and logscale alpha/beta parameters as inputs
 
 
14
  """
 
15
  @staticmethod
16
- def forward(ctx, inputs, ftr, alpha, beta):
17
- import anti_alias_activation_cuda
18
- activation_results = anti_alias_activation_cuda.forward(inputs, ftr, alpha, beta)
 
 
19
  return activation_results
20
 
21
  @staticmethod
22
  def backward(ctx, output_grads):
23
- # TODO: implement bwd pass
24
  raise NotImplementedError
25
  return output_grads, None, None
26
 
 
27
  class Activation1d(nn.Module):
28
- def __init__(self,
29
- activation,
30
- up_ratio: int = 2,
31
- down_ratio: int = 2,
32
- up_kernel_size: int = 12,
33
- down_kernel_size: int = 12,
34
- fused: bool = True
35
- ):
 
36
  super().__init__()
37
  self.up_ratio = up_ratio
38
  self.down_ratio = down_ratio
@@ -40,8 +49,7 @@ class Activation1d(nn.Module):
40
  self.upsample = UpSample1d(up_ratio, up_kernel_size)
41
  self.downsample = DownSample1d(down_ratio, down_kernel_size)
42
 
43
- self.fused = fused # whether to use fused CUDA kernel or not
44
-
45
 
46
  def forward(self, x):
47
  if not self.fused:
@@ -51,13 +59,19 @@ class Activation1d(nn.Module):
51
  return x
52
  else:
53
  if self.act.__class__.__name__ == "Snake":
54
- beta = self.act.alpha.data # snake uses same params for alpha and beta
55
  else:
56
- beta = self.act.beta.data # snakebeta uses different params for alpha and beta
 
 
57
  alpha = self.act.alpha.data
58
- if not self.act.alpha_logscale: # exp baked into cuda kernel, cancel it out with a log
 
 
59
  alpha = torch.log(alpha)
60
  beta = torch.log(beta)
61
- x = FusedAntiAliasActivation.apply(x, self.upsample.filter, alpha, beta)
62
- x = self.downsample(x)
 
 
63
  return x
 
3
 
4
  import torch
5
  import torch.nn as nn
6
+ from alias_free_activation.torch.resample import UpSample1d, DownSample1d
7
+
8
  # load fused CUDA kernel: this enables importing anti_alias_activation_cuda
9
+ from alias_free_activation.cuda import load
10
+
11
+ anti_alias_activation_cuda = load.load()
12
+
13
 
14
  class FusedAntiAliasActivation(torch.autograd.Function):
15
  """
16
+ Assumes filter size 12, replication padding on upsampling/downsampling, and logscale alpha/beta parameters as inputs.
17
+ The hyperparameters are hard-coded in the kernel to maximize speed.
18
+ NOTE: The fused kenrel is incorrect for Activation1d with different hyperparameters.
19
  """
20
+
21
  @staticmethod
22
+ def forward(ctx, inputs, up_ftr, down_ftr, alpha, beta):
23
+ activation_results = anti_alias_activation_cuda.forward(
24
+ inputs, up_ftr, down_ftr, alpha, beta
25
+ )
26
+
27
  return activation_results
28
 
29
  @staticmethod
30
  def backward(ctx, output_grads):
 
31
  raise NotImplementedError
32
  return output_grads, None, None
33
 
34
+
35
  class Activation1d(nn.Module):
36
+ def __init__(
37
+ self,
38
+ activation,
39
+ up_ratio: int = 2,
40
+ down_ratio: int = 2,
41
+ up_kernel_size: int = 12,
42
+ down_kernel_size: int = 12,
43
+ fused: bool = True,
44
+ ):
45
  super().__init__()
46
  self.up_ratio = up_ratio
47
  self.down_ratio = down_ratio
 
49
  self.upsample = UpSample1d(up_ratio, up_kernel_size)
50
  self.downsample = DownSample1d(down_ratio, down_kernel_size)
51
 
52
+ self.fused = fused # Whether to use fused CUDA kernel or not
 
53
 
54
  def forward(self, x):
55
  if not self.fused:
 
59
  return x
60
  else:
61
  if self.act.__class__.__name__ == "Snake":
62
+ beta = self.act.alpha.data # Snake uses same params for alpha and beta
63
  else:
64
+ beta = (
65
+ self.act.beta.data
66
+ ) # Snakebeta uses different params for alpha and beta
67
  alpha = self.act.alpha.data
68
+ if (
69
+ not self.act.alpha_logscale
70
+ ): # Exp baked into cuda kernel, cancel it out with a log
71
  alpha = torch.log(alpha)
72
  beta = torch.log(beta)
73
+
74
+ x = FusedAntiAliasActivation.apply(
75
+ x, self.upsample.filter, self.downsample.lowpass.filter, alpha, beta
76
+ )
77
  return x
{alias_free_cuda → alias_free_activation/cuda}/anti_alias_activation.cpp RENAMED
@@ -14,35 +14,10 @@
14
  * limitations under the License.
15
  */
16
 
17
- #include <cuda_fp16.h>
18
- #include <torch/extension.h>
19
- #include <vector>
20
 
21
- namespace anti_alias_activation {
22
-
23
- torch::Tensor fwd_cuda(torch::Tensor const& input,
24
- torch::Tensor const& filter,
25
- torch::Tensor const& alpha,
26
- torch::Tensor const& beta
27
- );
28
-
29
- torch::Tensor fwd(torch::Tensor const& input,
30
- torch::Tensor const& filter,
31
- torch::Tensor const& alpha,
32
- torch::Tensor const& beta
33
- ) {
34
- AT_ASSERTM(input.dim() == 3, "expected 3D tensor");
35
- //AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
36
- // (input.scalar_type() == at::ScalarType::BFloat16),
37
- // "Only fp16 and bf16 are supported");
38
-
39
- return fwd_cuda(input, filter, alpha, beta);
40
- }
41
-
42
- } // end namespace anti_alias_activation
43
 
44
  PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
45
- m.def("forward",
46
- &anti_alias_activation::fwd,
47
- "Anti Alias Activation -- Forward.");
48
- }
 
14
  * limitations under the License.
15
  */
16
 
17
+ #include <torch/extension.h>
 
 
18
 
19
+ extern "C" torch::Tensor fwd_cuda(torch::Tensor const &input, torch::Tensor const &up_filter, torch::Tensor const &down_filter, torch::Tensor const &alpha, torch::Tensor const &beta);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
  PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
22
+ m.def("forward", &fwd_cuda, "Anti-Alias Activation forward (CUDA)");
23
+ }
 
 
alias_free_activation/cuda/anti_alias_activation_cuda.cu ADDED
@@ -0,0 +1,246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /* coding=utf-8
2
+ * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
3
+ *
4
+ * Licensed under the Apache License, Version 2.0 (the "License");
5
+ * you may not use this file except in compliance with the License.
6
+ * You may obtain a copy of the License at
7
+ *
8
+ * http://www.apache.org/licenses/LICENSE-2.0
9
+ *
10
+ * Unless required by applicable law or agreed to in writing, software
11
+ * distributed under the License is distributed on an "AS IS" BASIS,
12
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ * See the License for the specific language governing permissions and
14
+ * limitations under the License.
15
+ */
16
+
17
+ #include <ATen/ATen.h>
18
+ #include <cuda.h>
19
+ #include <cuda_runtime.h>
20
+ #include <cuda_fp16.h>
21
+ #include <cuda_profiler_api.h>
22
+ #include <ATen/cuda/CUDAContext.h>
23
+ #include <torch/extension.h>
24
+ #include "type_shim.h"
25
+ #include <assert.h>
26
+ #include <cfloat>
27
+ #include <limits>
28
+ #include <stdint.h>
29
+ #include <c10/macros/Macros.h>
30
+
31
+ namespace
32
+ {
33
+ // Hard-coded hyperparameters
34
+ // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
35
+ constexpr int ELEMENTS_PER_LDG_STG = 1; //(WARP_ITERATIONS < 4) ? 1 : 4;
36
+ constexpr int BUFFER_SIZE = 32;
37
+ constexpr int FILTER_SIZE = 12;
38
+ constexpr int HALF_FILTER_SIZE = 6;
39
+ constexpr int UPSAMPLE_REPLICATION_PAD = 5; // 5 on each side, matching torch impl
40
+ constexpr int DOWNSAMPLE_REPLICATION_PAD_LEFT = 5; // matching torch impl
41
+ constexpr int DOWNSAMPLE_REPLICATION_PAD_RIGHT = 6; // matching torch impl
42
+
43
+ template <typename input_t, typename output_t, typename acc_t>
44
+ __global__ void anti_alias_activation_forward(
45
+ output_t *dst,
46
+ const input_t *src,
47
+ const input_t *up_ftr,
48
+ const input_t *down_ftr,
49
+ const input_t *alpha,
50
+ const input_t *beta,
51
+ int batch_size,
52
+ int channels,
53
+ int seq_len)
54
+ {
55
+ // Up and downsample filters
56
+ input_t up_filter[FILTER_SIZE];
57
+ input_t down_filter[FILTER_SIZE];
58
+
59
+ // Load data from global memory including extra indices reserved for replication paddings
60
+ input_t elements[2 * FILTER_SIZE + 2 * BUFFER_SIZE + 2 * UPSAMPLE_REPLICATION_PAD] = {0};
61
+ input_t intermediates[2 * FILTER_SIZE + 2 * BUFFER_SIZE + DOWNSAMPLE_REPLICATION_PAD_LEFT + DOWNSAMPLE_REPLICATION_PAD_RIGHT] = {0};
62
+
63
+ // Output stores downsampled output before writing to dst
64
+ output_t output[BUFFER_SIZE];
65
+
66
+ // blockDim/threadIdx = (128, 1, 1)
67
+ // gridDim/blockIdx = (seq_blocks, channels, batches)
68
+ int block_offset = (blockIdx.x * 128 * BUFFER_SIZE + seq_len * (blockIdx.y + gridDim.y * blockIdx.z));
69
+ int local_offset = threadIdx.x * BUFFER_SIZE;
70
+ int seq_offset = blockIdx.x * 128 * BUFFER_SIZE + local_offset;
71
+
72
+ // intermediate have double the seq_len
73
+ int intermediate_local_offset = threadIdx.x * BUFFER_SIZE * 2;
74
+ int intermediate_seq_offset = blockIdx.x * 128 * BUFFER_SIZE * 2 + intermediate_local_offset;
75
+
76
+ // Get values needed for replication padding before moving pointer
77
+ const input_t *right_most_pntr = src + (seq_len * (blockIdx.y + gridDim.y * blockIdx.z));
78
+ input_t seq_left_most_value = right_most_pntr[0];
79
+ input_t seq_right_most_value = right_most_pntr[seq_len - 1];
80
+
81
+ // Move src and dst pointers
82
+ src += block_offset + local_offset;
83
+ dst += block_offset + local_offset;
84
+
85
+ // Alpha and beta values for snake activatons. Applies exp by default
86
+ alpha = alpha + blockIdx.y;
87
+ input_t alpha_val = expf(alpha[0]);
88
+ beta = beta + blockIdx.y;
89
+ input_t beta_val = expf(beta[0]);
90
+
91
+ #pragma unroll
92
+ for (int it = 0; it < FILTER_SIZE; it += 1)
93
+ {
94
+ up_filter[it] = up_ftr[it];
95
+ down_filter[it] = down_ftr[it];
96
+ }
97
+
98
+ // Apply replication padding for upsampling, matching torch impl
99
+ #pragma unroll
100
+ for (int it = -HALF_FILTER_SIZE; it < BUFFER_SIZE + HALF_FILTER_SIZE; it += 1)
101
+ {
102
+ int element_index = seq_offset + it; // index for element
103
+ if ((element_index < 0) && (element_index >= -UPSAMPLE_REPLICATION_PAD))
104
+ {
105
+ elements[2 * (HALF_FILTER_SIZE + it)] = 2 * seq_left_most_value;
106
+ }
107
+ if ((element_index >= seq_len) && (element_index < seq_len + UPSAMPLE_REPLICATION_PAD))
108
+ {
109
+ elements[2 * (HALF_FILTER_SIZE + it)] = 2 * seq_right_most_value;
110
+ }
111
+ if ((element_index >= 0) && (element_index < seq_len))
112
+ {
113
+ elements[2 * (HALF_FILTER_SIZE + it)] = 2 * src[it];
114
+ }
115
+ }
116
+
117
+ // Apply upsampling strided convolution and write to intermediates. It reserves DOWNSAMPLE_REPLICATION_PAD_LEFT for replication padding of the downsampilng conv later
118
+ #pragma unroll
119
+ for (int it = 0; it < (2 * BUFFER_SIZE + 2 * FILTER_SIZE); it += 1)
120
+ {
121
+ input_t acc = 0.0;
122
+ int element_index = intermediate_seq_offset + it; // index for intermediate
123
+ #pragma unroll
124
+ for (int f_idx = 0; f_idx < FILTER_SIZE; f_idx += 1)
125
+ {
126
+ if ((element_index + f_idx) >= 0)
127
+ {
128
+ acc += up_filter[f_idx] * elements[it + f_idx];
129
+ }
130
+ }
131
+ intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] = acc;
132
+ }
133
+
134
+ // Apply activation function. It reserves DOWNSAMPLE_REPLICATION_PAD_LEFT and DOWNSAMPLE_REPLICATION_PAD_RIGHT for replication padding of the downsampilng conv later
135
+ double no_div_by_zero = 0.000000001;
136
+ #pragma unroll
137
+ for (int it = 0; it < 2 * BUFFER_SIZE + 2 * FILTER_SIZE; it += 1)
138
+ {
139
+ intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] += (1.0 / (beta_val + no_div_by_zero)) * sinf(intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] * alpha_val) * sinf(intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] * alpha_val);
140
+ }
141
+
142
+ // Apply replication padding before downsampling conv from intermediates
143
+ #pragma unroll
144
+ for (int it = 0; it < DOWNSAMPLE_REPLICATION_PAD_LEFT; it += 1)
145
+ {
146
+ intermediates[it] = intermediates[DOWNSAMPLE_REPLICATION_PAD_LEFT];
147
+ }
148
+ #pragma unroll
149
+ for (int it = DOWNSAMPLE_REPLICATION_PAD_LEFT + 2 * BUFFER_SIZE + 2 * FILTER_SIZE; it < DOWNSAMPLE_REPLICATION_PAD_LEFT + 2 * BUFFER_SIZE + 2 * FILTER_SIZE + DOWNSAMPLE_REPLICATION_PAD_RIGHT; it += 1)
150
+ {
151
+ intermediates[it] = intermediates[DOWNSAMPLE_REPLICATION_PAD_LEFT + 2 * BUFFER_SIZE + 2 * FILTER_SIZE - 1];
152
+ }
153
+
154
+ // Apply downsample strided convolution (assuming stride=2) from intermediates
155
+ #pragma unroll
156
+ for (int it = 0; it < BUFFER_SIZE; it += 1)
157
+ {
158
+ input_t acc = 0.0;
159
+ #pragma unroll
160
+ for (int f_idx = 0; f_idx < FILTER_SIZE; f_idx += 1)
161
+ {
162
+ // Add constant DOWNSAMPLE_REPLICATION_PAD_RIGHT to match torch implementation
163
+ acc += down_filter[f_idx] * intermediates[it * 2 + f_idx + DOWNSAMPLE_REPLICATION_PAD_RIGHT];
164
+ }
165
+ output[it] = acc;
166
+ }
167
+
168
+ // Write output to dst
169
+ #pragma unroll
170
+ for (int it = 0; it < BUFFER_SIZE; it += ELEMENTS_PER_LDG_STG)
171
+ {
172
+ int element_index = seq_offset + it;
173
+ if (element_index < seq_len)
174
+ {
175
+ dst[it] = output[it];
176
+ }
177
+ }
178
+
179
+ }
180
+
181
+ template <typename input_t, typename output_t, typename acc_t>
182
+ void dispatch_anti_alias_activation_forward(
183
+ output_t *dst,
184
+ const input_t *src,
185
+ const input_t *up_ftr,
186
+ const input_t *down_ftr,
187
+ const input_t *alpha,
188
+ const input_t *beta,
189
+ int batch_size,
190
+ int channels,
191
+ int seq_len)
192
+ {
193
+ if (seq_len == 0)
194
+ {
195
+ return;
196
+ }
197
+ else
198
+ {
199
+ // Use 128 threads per block to maximimize gpu utilization
200
+ constexpr int threads_per_block = 128;
201
+ constexpr int seq_len_per_block = 4096;
202
+ int blocks_per_seq_len = (seq_len + seq_len_per_block - 1) / seq_len_per_block;
203
+ dim3 blocks(blocks_per_seq_len, channels, batch_size);
204
+ dim3 threads(threads_per_block, 1, 1);
205
+
206
+ anti_alias_activation_forward<input_t, output_t, acc_t>
207
+ <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, up_ftr, down_ftr, alpha, beta, batch_size, channels, seq_len);
208
+ }
209
+ }
210
+ }
211
+
212
+ extern "C" torch::Tensor fwd_cuda(torch::Tensor const &input, torch::Tensor const &up_filter, torch::Tensor const &down_filter, torch::Tensor const &alpha, torch::Tensor const &beta)
213
+ {
214
+ // Input is a 3d tensor with dimensions [batches, channels, seq_len]
215
+ const int batches = input.size(0);
216
+ const int channels = input.size(1);
217
+ const int seq_len = input.size(2);
218
+
219
+ // Output
220
+ auto act_options = input.options().requires_grad(false);
221
+
222
+ torch::Tensor anti_alias_activation_results =
223
+ torch::empty({batches, channels, seq_len}, act_options);
224
+
225
+ void *input_ptr = static_cast<void *>(input.data_ptr());
226
+ void *up_filter_ptr = static_cast<void *>(up_filter.data_ptr());
227
+ void *down_filter_ptr = static_cast<void *>(down_filter.data_ptr());
228
+ void *alpha_ptr = static_cast<void *>(alpha.data_ptr());
229
+ void *beta_ptr = static_cast<void *>(beta.data_ptr());
230
+ void *anti_alias_activation_results_ptr = static_cast<void *>(anti_alias_activation_results.data_ptr());
231
+
232
+ DISPATCH_FLOAT_HALF_AND_BFLOAT(
233
+ input.scalar_type(),
234
+ "dispatch anti alias activation_forward",
235
+ dispatch_anti_alias_activation_forward<scalar_t, scalar_t, float>(
236
+ reinterpret_cast<scalar_t *>(anti_alias_activation_results_ptr),
237
+ reinterpret_cast<const scalar_t *>(input_ptr),
238
+ reinterpret_cast<const scalar_t *>(up_filter_ptr),
239
+ reinterpret_cast<const scalar_t *>(down_filter_ptr),
240
+ reinterpret_cast<const scalar_t *>(alpha_ptr),
241
+ reinterpret_cast<const scalar_t *>(beta_ptr),
242
+ batches,
243
+ channels,
244
+ seq_len););
245
+ return anti_alias_activation_results;
246
+ }
{alias_free_cuda → alias_free_activation/cuda}/compat.h RENAMED
@@ -18,8 +18,6 @@
18
  * https://github.com/NVIDIA/apex
19
  * with minor changes. */
20
 
21
-
22
-
23
  #ifndef TORCH_CHECK
24
  #define TORCH_CHECK AT_CHECK
25
  #endif
 
18
  * https://github.com/NVIDIA/apex
19
  * with minor changes. */
20
 
 
 
21
  #ifndef TORCH_CHECK
22
  #define TORCH_CHECK AT_CHECK
23
  #endif
{alias_free_cuda → alias_free_activation/cuda}/load.py RENAMED
@@ -7,26 +7,24 @@ import subprocess
7
 
8
  from torch.utils import cpp_extension
9
 
10
- # Setting this param to a list has a problem of generating different
11
- # compilation commands (with diferent order of architectures) and
12
- # leading to recompilation of fused kernels. Set it to empty string
13
- # to avoid recompilation and assign arch flags explicity in
14
- # extra_cuda_cflags below
15
  os.environ["TORCH_CUDA_ARCH_LIST"] = ""
16
 
17
 
18
  def load():
19
  # Check if cuda 11 is installed for compute capability 8.0
20
  cc_flag = []
21
- _, bare_metal_major, _ = _get_cuda_bare_metal_version(
22
- cpp_extension.CUDA_HOME)
23
  if int(bare_metal_major) >= 11:
24
- cc_flag.append('-gencode')
25
- cc_flag.append('arch=compute_80,code=sm_80')
26
 
27
  # Build path
28
  srcpath = pathlib.Path(__file__).parent.absolute()
29
- buildpath = srcpath / 'build'
30
  _create_build_dir(buildpath)
31
 
32
  # Helper function to build the kernels.
@@ -35,26 +33,42 @@ def load():
35
  name=name,
36
  sources=sources,
37
  build_directory=buildpath,
38
- extra_cflags=['-O3',],
39
- extra_cuda_cflags=['-O3',
40
- '-gencode', 'arch=compute_70,code=sm_70',
41
- '--use_fast_math'] + extra_cuda_flags + cc_flag,
42
- verbose=True
 
 
 
 
 
 
 
43
  )
44
 
45
- extra_cuda_flags = ['-U__CUDA_NO_HALF_OPERATORS__',
46
- '-U__CUDA_NO_HALF_CONVERSIONS__',
47
- '--expt-relaxed-constexpr',
48
- '--expt-extended-lambda']
49
-
50
- sources=[srcpath / 'anti_alias_activation.cpp',
51
- srcpath / 'anti_alias_activation_cuda.cu']
 
 
 
 
52
  anti_alias_activation_cuda = _cpp_extention_load_helper(
53
- "anti_alias_activation_cuda", sources, extra_cuda_flags)
 
 
 
 
54
 
55
  def _get_cuda_bare_metal_version(cuda_dir):
56
- raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"],
57
- universal_newlines=True)
 
58
  output = raw_output.split()
59
  release_idx = output.index("release") + 1
60
  release = output[release_idx].split(".")
@@ -69,4 +83,4 @@ def _create_build_dir(buildpath):
69
  os.mkdir(buildpath)
70
  except OSError:
71
  if not os.path.isdir(buildpath):
72
- print(f"Creation of the build directory {buildpath} failed")
 
7
 
8
  from torch.utils import cpp_extension
9
 
10
+ """
11
+ Setting this param to a list has a problem of generating different compilation commands (with diferent order of architectures) and leading to recompilation of fused kernels.
12
+ Set it to empty stringo avoid recompilation and assign arch flags explicity in extra_cuda_cflags below
13
+ """
 
14
  os.environ["TORCH_CUDA_ARCH_LIST"] = ""
15
 
16
 
17
  def load():
18
  # Check if cuda 11 is installed for compute capability 8.0
19
  cc_flag = []
20
+ _, bare_metal_major, _ = _get_cuda_bare_metal_version(cpp_extension.CUDA_HOME)
 
21
  if int(bare_metal_major) >= 11:
22
+ cc_flag.append("-gencode")
23
+ cc_flag.append("arch=compute_80,code=sm_80")
24
 
25
  # Build path
26
  srcpath = pathlib.Path(__file__).parent.absolute()
27
+ buildpath = srcpath / "build"
28
  _create_build_dir(buildpath)
29
 
30
  # Helper function to build the kernels.
 
33
  name=name,
34
  sources=sources,
35
  build_directory=buildpath,
36
+ extra_cflags=[
37
+ "-O3",
38
+ ],
39
+ extra_cuda_cflags=[
40
+ "-O3",
41
+ "-gencode",
42
+ "arch=compute_70,code=sm_70",
43
+ "--use_fast_math",
44
+ ]
45
+ + extra_cuda_flags
46
+ + cc_flag,
47
+ verbose=True,
48
  )
49
 
50
+ extra_cuda_flags = [
51
+ "-U__CUDA_NO_HALF_OPERATORS__",
52
+ "-U__CUDA_NO_HALF_CONVERSIONS__",
53
+ "--expt-relaxed-constexpr",
54
+ "--expt-extended-lambda",
55
+ ]
56
+
57
+ sources = [
58
+ srcpath / "anti_alias_activation.cpp",
59
+ srcpath / "anti_alias_activation_cuda.cu",
60
+ ]
61
  anti_alias_activation_cuda = _cpp_extention_load_helper(
62
+ "anti_alias_activation_cuda", sources, extra_cuda_flags
63
+ )
64
+
65
+ return anti_alias_activation_cuda
66
+
67
 
68
  def _get_cuda_bare_metal_version(cuda_dir):
69
+ raw_output = subprocess.check_output(
70
+ [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True
71
+ )
72
  output = raw_output.split()
73
  release_idx = output.index("release") + 1
74
  release = output[release_idx].split(".")
 
83
  os.mkdir(buildpath)
84
  except OSError:
85
  if not os.path.isdir(buildpath):
86
+ print(f"Creation of the build directory {buildpath} failed")
alias_free_activation/cuda/type_shim.h ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /* coding=utf-8
2
+ * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3
+ *
4
+ * Licensed under the Apache License, Version 2.0 (the "License");
5
+ * you may not use this file except in compliance with the License.
6
+ * You may obtain a copy of the License at
7
+ *
8
+ * http://www.apache.org/licenses/LICENSE-2.0
9
+ *
10
+ * Unless required by applicable law or agreed to in writing, software
11
+ * distributed under the License is distributed on an "AS IS" BASIS,
12
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ * See the License for the specific language governing permissions and
14
+ * limitations under the License.
15
+ */
16
+
17
+ #include <ATen/ATen.h>
18
+ #include "compat.h"
19
+
20
+ #define DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, NAME, ...) \
21
+ switch (TYPE) \
22
+ { \
23
+ case at::ScalarType::Float: \
24
+ { \
25
+ using scalar_t = float; \
26
+ __VA_ARGS__; \
27
+ break; \
28
+ } \
29
+ case at::ScalarType::Half: \
30
+ { \
31
+ using scalar_t = at::Half; \
32
+ __VA_ARGS__; \
33
+ break; \
34
+ } \
35
+ case at::ScalarType::BFloat16: \
36
+ { \
37
+ using scalar_t = at::BFloat16; \
38
+ __VA_ARGS__; \
39
+ break; \
40
+ } \
41
+ default: \
42
+ AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
43
+ }
44
+
45
+ #define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
46
+ switch (TYPEIN) \
47
+ { \
48
+ case at::ScalarType::Float: \
49
+ { \
50
+ using scalar_t_in = float; \
51
+ switch (TYPEOUT) \
52
+ { \
53
+ case at::ScalarType::Float: \
54
+ { \
55
+ using scalar_t_out = float; \
56
+ __VA_ARGS__; \
57
+ break; \
58
+ } \
59
+ case at::ScalarType::Half: \
60
+ { \
61
+ using scalar_t_out = at::Half; \
62
+ __VA_ARGS__; \
63
+ break; \
64
+ } \
65
+ case at::ScalarType::BFloat16: \
66
+ { \
67
+ using scalar_t_out = at::BFloat16; \
68
+ __VA_ARGS__; \
69
+ break; \
70
+ } \
71
+ default: \
72
+ AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \
73
+ } \
74
+ break; \
75
+ } \
76
+ case at::ScalarType::Half: \
77
+ { \
78
+ using scalar_t_in = at::Half; \
79
+ using scalar_t_out = at::Half; \
80
+ __VA_ARGS__; \
81
+ break; \
82
+ } \
83
+ case at::ScalarType::BFloat16: \
84
+ { \
85
+ using scalar_t_in = at::BFloat16; \
86
+ using scalar_t_out = at::BFloat16; \
87
+ __VA_ARGS__; \
88
+ break; \
89
+ } \
90
+ default: \
91
+ AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \
92
+ }
{alias_free_torch → alias_free_activation/torch}/__init__.py RENAMED
@@ -3,4 +3,4 @@
3
 
4
  from .filter import *
5
  from .resample import *
6
- from .act import *
 
3
 
4
  from .filter import *
5
  from .resample import *
6
+ from .act import *
{alias_free_torch → alias_free_activation/torch}/act.py RENAMED
@@ -2,16 +2,18 @@
2
  # LICENSE is in incl_licenses directory.
3
 
4
  import torch.nn as nn
5
- from .resample import UpSample1d, DownSample1d
6
 
7
 
8
  class Activation1d(nn.Module):
9
- def __init__(self,
10
- activation,
11
- up_ratio: int = 2,
12
- down_ratio: int = 2,
13
- up_kernel_size: int = 12,
14
- down_kernel_size: int = 12):
 
 
15
  super().__init__()
16
  self.up_ratio = up_ratio
17
  self.down_ratio = down_ratio
@@ -25,4 +27,4 @@ class Activation1d(nn.Module):
25
  x = self.act(x)
26
  x = self.downsample(x)
27
 
28
- return x
 
2
  # LICENSE is in incl_licenses directory.
3
 
4
  import torch.nn as nn
5
+ from alias_free_activation.torch.resample import UpSample1d, DownSample1d
6
 
7
 
8
  class Activation1d(nn.Module):
9
+ def __init__(
10
+ self,
11
+ activation,
12
+ up_ratio: int = 2,
13
+ down_ratio: int = 2,
14
+ up_kernel_size: int = 12,
15
+ down_kernel_size: int = 12,
16
+ ):
17
  super().__init__()
18
  self.up_ratio = up_ratio
19
  self.down_ratio = down_ratio
 
27
  x = self.act(x)
28
  x = self.downsample(x)
29
 
30
+ return x
{alias_free_torch → alias_free_activation/torch}/filter.py RENAMED
@@ -6,7 +6,7 @@ import torch.nn as nn
6
  import torch.nn.functional as F
7
  import math
8
 
9
- if 'sinc' in dir(torch):
10
  sinc = torch.sinc
11
  else:
12
  # This code is adopted from adefossez's julius.core.sinc under the MIT License
@@ -17,40 +17,45 @@ else:
17
  Implementation of sinc, i.e. sin(pi * x) / (pi * x)
18
  __Warning__: Different to julius.sinc, the input is multiplied by `pi`!
19
  """
20
- return torch.where(x == 0,
21
- torch.tensor(1., device=x.device, dtype=x.dtype),
22
- torch.sin(math.pi * x) / math.pi / x)
 
 
23
 
24
 
25
  # This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
26
  # https://adefossez.github.io/julius/julius/lowpass.html
27
  # LICENSE is in incl_licenses directory.
28
- def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): # return filter [1,1,kernel_size]
29
- even = (kernel_size % 2 == 0)
 
 
30
  half_size = kernel_size // 2
31
 
32
- #For kaiser window
33
  delta_f = 4 * half_width
34
  A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
35
- if A > 50.:
36
  beta = 0.1102 * (A - 8.7)
37
- elif A >= 21.:
38
- beta = 0.5842 * (A - 21)**0.4 + 0.07886 * (A - 21.)
39
  else:
40
- beta = 0.
41
  window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
42
 
43
  # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
44
  if even:
45
- time = (torch.arange(-half_size, half_size) + 0.5)
46
  else:
47
  time = torch.arange(kernel_size) - half_size
48
  if cutoff == 0:
49
  filter_ = torch.zeros_like(time)
50
  else:
51
  filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
52
- # Normalize filter to have sum = 1, otherwise we will have a small leakage
53
- # of the constant component in the input signal.
 
54
  filter_ /= filter_.sum()
55
  filter = filter_.view(1, 1, kernel_size)
56
 
@@ -58,22 +63,25 @@ def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): # return filter [1,1,
58
 
59
 
60
  class LowPassFilter1d(nn.Module):
61
- def __init__(self,
62
- cutoff=0.5,
63
- half_width=0.6,
64
- stride: int = 1,
65
- padding: bool = True,
66
- padding_mode: str = 'replicate',
67
- kernel_size: int = 12):
68
- # kernel_size should be even number for stylegan3 setup,
69
- # in this implementation, odd number is also possible.
 
 
 
70
  super().__init__()
71
- if cutoff < -0.:
72
  raise ValueError("Minimum cutoff must be larger than zero.")
73
  if cutoff > 0.5:
74
  raise ValueError("A cutoff above 0.5 does not make sense.")
75
  self.kernel_size = kernel_size
76
- self.even = (kernel_size % 2 == 0)
77
  self.pad_left = kernel_size // 2 - int(self.even)
78
  self.pad_right = kernel_size // 2
79
  self.stride = stride
@@ -82,14 +90,12 @@ class LowPassFilter1d(nn.Module):
82
  filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
83
  self.register_buffer("filter", filter)
84
 
85
- #input [B, C, T]
86
  def forward(self, x):
87
  _, C, _ = x.shape
88
 
89
  if self.padding:
90
- x = F.pad(x, (self.pad_left, self.pad_right),
91
- mode=self.padding_mode)
92
- out = F.conv1d(x, self.filter.expand(C, -1, -1),
93
- stride=self.stride, groups=C)
94
 
95
- return out
 
6
  import torch.nn.functional as F
7
  import math
8
 
9
+ if "sinc" in dir(torch):
10
  sinc = torch.sinc
11
  else:
12
  # This code is adopted from adefossez's julius.core.sinc under the MIT License
 
17
  Implementation of sinc, i.e. sin(pi * x) / (pi * x)
18
  __Warning__: Different to julius.sinc, the input is multiplied by `pi`!
19
  """
20
+ return torch.where(
21
+ x == 0,
22
+ torch.tensor(1.0, device=x.device, dtype=x.dtype),
23
+ torch.sin(math.pi * x) / math.pi / x,
24
+ )
25
 
26
 
27
  # This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
28
  # https://adefossez.github.io/julius/julius/lowpass.html
29
  # LICENSE is in incl_licenses directory.
30
+ def kaiser_sinc_filter1d(
31
+ cutoff, half_width, kernel_size
32
+ ): # return filter [1,1,kernel_size]
33
+ even = kernel_size % 2 == 0
34
  half_size = kernel_size // 2
35
 
36
+ # For kaiser window
37
  delta_f = 4 * half_width
38
  A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
39
+ if A > 50.0:
40
  beta = 0.1102 * (A - 8.7)
41
+ elif A >= 21.0:
42
+ beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.0)
43
  else:
44
+ beta = 0.0
45
  window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
46
 
47
  # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
48
  if even:
49
+ time = torch.arange(-half_size, half_size) + 0.5
50
  else:
51
  time = torch.arange(kernel_size) - half_size
52
  if cutoff == 0:
53
  filter_ = torch.zeros_like(time)
54
  else:
55
  filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
56
+ """
57
+ Normalize filter to have sum = 1, otherwise we will have a small leakage of the constant component in the input signal.
58
+ """
59
  filter_ /= filter_.sum()
60
  filter = filter_.view(1, 1, kernel_size)
61
 
 
63
 
64
 
65
  class LowPassFilter1d(nn.Module):
66
+ def __init__(
67
+ self,
68
+ cutoff=0.5,
69
+ half_width=0.6,
70
+ stride: int = 1,
71
+ padding: bool = True,
72
+ padding_mode: str = "replicate",
73
+ kernel_size: int = 12,
74
+ ):
75
+ """
76
+ kernel_size should be even number for stylegan3 setup, in this implementation, odd number is also possible.
77
+ """
78
  super().__init__()
79
+ if cutoff < -0.0:
80
  raise ValueError("Minimum cutoff must be larger than zero.")
81
  if cutoff > 0.5:
82
  raise ValueError("A cutoff above 0.5 does not make sense.")
83
  self.kernel_size = kernel_size
84
+ self.even = kernel_size % 2 == 0
85
  self.pad_left = kernel_size // 2 - int(self.even)
86
  self.pad_right = kernel_size // 2
87
  self.stride = stride
 
90
  filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
91
  self.register_buffer("filter", filter)
92
 
93
+ # Input [B, C, T]
94
  def forward(self, x):
95
  _, C, _ = x.shape
96
 
97
  if self.padding:
98
+ x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)
99
+ out = F.conv1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
 
 
100
 
101
+ return out
{alias_free_torch → alias_free_activation/torch}/resample.py RENAMED
@@ -3,32 +3,37 @@
3
 
4
  import torch.nn as nn
5
  from torch.nn import functional as F
6
- from .filter import LowPassFilter1d
7
- from .filter import kaiser_sinc_filter1d
8
 
9
 
10
  class UpSample1d(nn.Module):
11
  def __init__(self, ratio=2, kernel_size=None):
12
  super().__init__()
13
  self.ratio = ratio
14
- self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
 
 
15
  self.stride = ratio
16
  self.pad = self.kernel_size // ratio - 1
17
  self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
18
- self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
19
- filter = kaiser_sinc_filter1d(cutoff=0.5 / ratio,
20
- half_width=0.6 / ratio,
21
- kernel_size=self.kernel_size)
 
 
22
  self.register_buffer("filter", filter)
23
 
24
  # x: [B, C, T]
25
  def forward(self, x):
26
  _, C, _ = x.shape
27
 
28
- x = F.pad(x, (self.pad, self.pad), mode='replicate')
29
  x = self.ratio * F.conv_transpose1d(
30
- x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
31
- x = x[..., self.pad_left:-self.pad_right]
 
32
 
33
  return x
34
 
@@ -37,13 +42,17 @@ class DownSample1d(nn.Module):
37
  def __init__(self, ratio=2, kernel_size=None):
38
  super().__init__()
39
  self.ratio = ratio
40
- self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
41
- self.lowpass = LowPassFilter1d(cutoff=0.5 / ratio,
42
- half_width=0.6 / ratio,
43
- stride=ratio,
44
- kernel_size=self.kernel_size)
 
 
 
 
45
 
46
  def forward(self, x):
47
  xx = self.lowpass(x)
48
 
49
- return xx
 
3
 
4
  import torch.nn as nn
5
  from torch.nn import functional as F
6
+ from alias_free_activation.torch.filter import LowPassFilter1d
7
+ from alias_free_activation.torch.filter import kaiser_sinc_filter1d
8
 
9
 
10
  class UpSample1d(nn.Module):
11
  def __init__(self, ratio=2, kernel_size=None):
12
  super().__init__()
13
  self.ratio = ratio
14
+ self.kernel_size = (
15
+ int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
16
+ )
17
  self.stride = ratio
18
  self.pad = self.kernel_size // ratio - 1
19
  self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
20
+ self.pad_right = (
21
+ self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
22
+ )
23
+ filter = kaiser_sinc_filter1d(
24
+ cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size
25
+ )
26
  self.register_buffer("filter", filter)
27
 
28
  # x: [B, C, T]
29
  def forward(self, x):
30
  _, C, _ = x.shape
31
 
32
+ x = F.pad(x, (self.pad, self.pad), mode="replicate")
33
  x = self.ratio * F.conv_transpose1d(
34
+ x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C
35
+ )
36
+ x = x[..., self.pad_left : -self.pad_right]
37
 
38
  return x
39
 
 
42
  def __init__(self, ratio=2, kernel_size=None):
43
  super().__init__()
44
  self.ratio = ratio
45
+ self.kernel_size = (
46
+ int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
47
+ )
48
+ self.lowpass = LowPassFilter1d(
49
+ cutoff=0.5 / ratio,
50
+ half_width=0.6 / ratio,
51
+ stride=ratio,
52
+ kernel_size=self.kernel_size,
53
+ )
54
 
55
  def forward(self, x):
56
  xx = self.lowpass(x)
57
 
58
+ return xx
alias_free_cuda/anti_alias_activation_cuda.cu DELETED
@@ -1,314 +0,0 @@
1
- /* coding=utf-8
2
- * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
3
- *
4
- * Licensed under the Apache License, Version 2.0 (the "License");
5
- * you may not use this file except in compliance with the License.
6
- * You may obtain a copy of the License at
7
- *
8
- * http://www.apache.org/licenses/LICENSE-2.0
9
- *
10
- * Unless required by applicable law or agreed to in writing, software
11
- * distributed under the License is distributed on an "AS IS" BASIS,
12
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- * See the License for the specific language governing permissions and
14
- * limitations under the License.
15
- */
16
-
17
- #include <ATen/ATen.h>
18
- #include <cuda.h>
19
- #include <cuda_runtime.h>
20
- #include <cuda_fp16.h>
21
- #include <cuda_profiler_api.h>
22
- #include <ATen/cuda/CUDAContext.h>
23
- #include <torch/extension.h>
24
- #include "type_shim.h"
25
- #include <assert.h>
26
- #include <cfloat>
27
- #include <limits>
28
- #include <stdint.h>
29
- #include <c10/macros/Macros.h>
30
-
31
- namespace {
32
-
33
- /*
34
- template <typename Datatype, int ELEMENTS_PER_LDG>
35
- __device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src);
36
-
37
- template <>
38
- __device__ __inline__ void copy_vector<c10::BFloat16, 1>(c10::BFloat16 *dst, const c10::BFloat16 *src) { *dst = *src; }
39
-
40
- template <>
41
- __device__ __inline__ void copy_vector<c10::BFloat16, 4>(c10::BFloat16 *dst, const c10::BFloat16 *src) { *((float2*) dst) = *((float2*) src); }
42
-
43
- template <>
44
- __device__ __inline__ void copy_vector<c10::Half, 1>(c10::Half *dst, const c10::Half *src) { *dst = *src; }
45
-
46
- template <>
47
- __device__ __inline__ void copy_vector<c10::Half, 4>(c10::Half *dst, const c10::Half *src) { *((float2*) dst) = *((float2*) src); }
48
-
49
- template <>
50
- __device__ __inline__ void copy_vector<uint8_t, 1>(uint8_t *dst, const uint8_t *src) { *dst = *src; }
51
-
52
- template <>
53
- __device__ __inline__ void copy_vector<uint8_t, 4>(uint8_t *dst, const uint8_t *src) {*((half2*) dst) = *((half2*) src); }
54
-
55
- int log2_ceil(int value) {
56
- int log2_value = 0;
57
- while ((1 << log2_value) < value) ++log2_value;
58
- return log2_value;
59
- }
60
-
61
- template<typename T>
62
- struct Add {
63
- __device__ __forceinline__ T operator()(T a, T b) const {
64
- return a + b;
65
- }
66
- };
67
-
68
- template<typename T>
69
- struct Max {
70
- __device__ __forceinline__ T operator()(T a, T b) const {
71
- return a < b ? b : a;
72
- }
73
- };
74
-
75
- template <typename T>
76
- __device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff)
77
- {
78
- #if CUDA_VERSION >= 9000
79
- return __shfl_xor_sync(mask, value, laneMask, width);
80
- #else
81
- return __shfl_xor(value, laneMask, width);
82
- #endif
83
- }
84
-
85
- template <typename acc_t, int WARP_BATCH, int WARP_SIZE, template<typename> class ReduceOp>
86
- __device__ __forceinline__ void warp_reduce(acc_t* sum) {
87
- ReduceOp<acc_t> r;
88
- #pragma unroll
89
- for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
90
- #pragma unroll
91
- for (int i = 0; i < WARP_BATCH; ++i) {
92
- acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE);
93
- sum[i] = r(sum[i], b);
94
- }
95
- }
96
- }
97
- */
98
-
99
- template <typename input_t, typename output_t, typename acc_t>
100
- __global__ void anti_alias_activation_forward(
101
- output_t *dst,
102
- const input_t *src,
103
- const input_t *ftr,
104
- const input_t *alpha,
105
- const input_t *beta,
106
- int batch_size,
107
- int channels,
108
- int seq_len)
109
- {
110
- // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
111
- constexpr int ELEMENTS_PER_LDG_STG = 1; //(WARP_ITERATIONS < 4) ? 1 : 4;
112
- constexpr int BUFFER_SIZE = 32;
113
- constexpr int FILTER_SIZE = 12;
114
- constexpr int HALF_FILTER_SIZE = 6;
115
- constexpr int REPLICATION_PAD = 5; // 5 on each side
116
-
117
- // blockDim/threadIdx = (128, 1, 1)
118
- // gridDim/blockIdx = (seq_blocks, channels, batches)
119
- int block_offset = (blockIdx.x * 128 * BUFFER_SIZE + seq_len * (blockIdx.y + gridDim.y * blockIdx.z));
120
- int local_offset = threadIdx.x * BUFFER_SIZE;
121
- int seq_offset = blockIdx.x * 128 * BUFFER_SIZE + local_offset;
122
-
123
-
124
- //int intermediate_seq_len = seq_len * 2 - 1 + 4 * REPLICATION_PAD;
125
- //int intermediate_block_offset = (blockIdx.x * 128 * BUFFER_SIZE * 2 + intermediate_seq_len * (blockIdx.y + gridDim.y * blockIdx.z));
126
- //int intermediate_local_offset = threadIdx.x * BUFFER_SIZE * 2;
127
-
128
- int output_seq_len = seq_len * 2 ; //
129
- int output_block_offset = (blockIdx.x * 128 * BUFFER_SIZE * 2 + output_seq_len * (blockIdx.y + gridDim.y * blockIdx.z));
130
- int output_local_offset = threadIdx.x * BUFFER_SIZE * 2;
131
- int output_seq_offset = blockIdx.x * 128 * BUFFER_SIZE *2 + output_local_offset;
132
- // get values needed for replication padding before moving pointer
133
- const input_t *right_most_pntr = src + (seq_len * (blockIdx.y + gridDim.y * blockIdx.z));
134
- input_t seq_left_most_value = right_most_pntr[0];
135
- input_t seq_right_most_value = right_most_pntr[seq_len - 1];
136
-
137
- src += block_offset + local_offset;
138
- dst += output_block_offset + output_local_offset ;
139
- alpha = alpha + blockIdx.y;
140
- input_t alpha_val = expf(alpha[0]);
141
- beta = beta + blockIdx.y;
142
- input_t beta_val = expf(beta[0]);
143
- // load data from global memory
144
- input_t elements[2*FILTER_SIZE+2*BUFFER_SIZE] = {0};
145
- input_t intermediates[2*FILTER_SIZE+2*BUFFER_SIZE] = {0};
146
- //output_t output[2*BUFFER_SIZE];
147
- input_t filter[FILTER_SIZE];
148
- //input_t temp_data[ELEMENTS_PER_LDG_STG];
149
- //uint8_t temp_mask[ELEMENTS_PER_LDG_STG];
150
-
151
- #pragma unroll
152
- for (int it = 0; it < FILTER_SIZE; it+=1) {
153
- filter[it] = ftr[it];
154
- }
155
-
156
-
157
- #pragma unroll
158
- for (int it = -HALF_FILTER_SIZE; it < BUFFER_SIZE + HALF_FILTER_SIZE ; it+=1) {
159
- int element_index = seq_offset + it;
160
- if ((element_index < 0) && (element_index >= -REPLICATION_PAD)) {
161
- elements[2*(HALF_FILTER_SIZE+it)] = 2*seq_left_most_value;
162
- }
163
- if ((element_index >= seq_len) && (element_index < seq_len + REPLICATION_PAD)) {
164
- elements[2*(HALF_FILTER_SIZE+it)] = 2*seq_right_most_value;
165
- }
166
- if ((element_index >= 0) && (element_index < seq_len)) {
167
- elements[2*(HALF_FILTER_SIZE+it)] = 2*src[it];
168
- }
169
- }
170
-
171
-
172
-
173
- // apply filter
174
- #pragma unroll
175
- for (int it = 0; it < (2 * BUFFER_SIZE + 2*FILTER_SIZE); it+=1) {
176
- input_t acc = 0.0;
177
-
178
- int element_index = output_seq_offset + it; // index for output
179
- #pragma unroll
180
- for (int f_idx = 0; f_idx < FILTER_SIZE; f_idx+=1){
181
- if ((element_index + f_idx) >= 0){
182
- acc += filter[f_idx] * elements[it+f_idx];
183
- }
184
- }
185
- intermediates[it] = acc;
186
- }
187
-
188
- double no_div_by_zero = 0.000000001;
189
- #pragma unroll
190
- for (int it = 0; it < 12 + 2 * BUFFER_SIZE; it++) {
191
- intermediates[it] += (1.0/(beta_val + no_div_by_zero)) * sinf(intermediates[it] * alpha_val) * sinf(intermediates[it] * alpha_val);
192
- }
193
-
194
-
195
- // now copy to output
196
- #pragma unroll
197
- for (int it = 0; it < 2*BUFFER_SIZE; it+=1){
198
- int element_index = output_seq_offset + it;
199
- if (element_index < output_seq_len) {
200
- dst[it] = intermediates[it+6];
201
- }
202
- }
203
-
204
-
205
-
206
- // for (int it = 0; it < BUFFER_SIZE; it+=ELEMENTS_PER_LDG_STG) {
207
- // int element_index = seq_offset + it;
208
- // if (element_index < seq_len) {
209
- // dst[it] = output[it];
210
- // }
211
- // }
212
-
213
-
214
- // // Upsample convolution
215
- // for (int it = 0; it < 2 * BUFFER_SIZE + 12; it+=1) {
216
- // input_t acc = 0.0;
217
-
218
- // for (int f_idx = 0; f_idx < FILTER_SIZE; f_idx+=1){
219
- // acc += filter[f_idx] * elements[it+f_idx];
220
- // }
221
- // intermediates[it] = acc;
222
- // }
223
-
224
- // // correct the corners of intermediates
225
- // if (seq_offset == 0) {
226
- // for (int it = 0; it < 6; it+=1)
227
- // intermediates[it] = 0;
228
- // }
229
-
230
- // if (seq_offset + 32 >= seq_len) {
231
- // int offset = seq_len % 32 == 0 ? 32 : seq_len % 32;
232
-
233
- // for (int it = 0; it < 6; it++) {
234
- // intermediates[6+2*offset+it] = 0;
235
- // }
236
- // }
237
-
238
-
239
-
240
-
241
- // for (int it = 0; it < BUFFER_SIZE; it+=ELEMENTS_PER_LDG_STG) {
242
- // int element_index = seq_offset + it;
243
- // if (element_index < seq_len) {
244
- // dst[it] = output[it];
245
- // }
246
- // }
247
- }
248
-
249
- template<typename input_t, typename output_t, typename acc_t>
250
- void dispatch_anti_alias_activation_forward(
251
- output_t *dst,
252
- const input_t *src,
253
- const input_t *ftr,
254
- const input_t *alpha,
255
- const input_t *beta,
256
- int batch_size,
257
- int channels,
258
- int seq_len)
259
- {
260
- if (seq_len == 0) {
261
- return;
262
- } else {
263
- // use 128 threads per block to maximimize gpu utilization
264
- constexpr int threads_per_block = 128;
265
- constexpr int seq_len_per_block = 4096;
266
- int blocks_per_seq_len = (seq_len + seq_len_per_block - 1) / seq_len_per_block;
267
- dim3 blocks(blocks_per_seq_len, channels, batch_size);
268
- dim3 threads(threads_per_block, 1, 1);
269
-
270
- anti_alias_activation_forward<input_t, output_t, acc_t>
271
- <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, ftr, alpha, beta, batch_size, channels, seq_len);
272
- }
273
- }
274
- }
275
-
276
- namespace anti_alias_activation {
277
-
278
- torch::Tensor fwd_cuda(torch::Tensor const& input, torch::Tensor const& filter, torch::Tensor const& alpha, torch::Tensor const& beta)
279
- {
280
- // input is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len]
281
- const int batches = input.size(0);
282
- const int channels = input.size(1);
283
- const int seq_len = input.size(2);
284
-
285
- // Output
286
- auto act_options = input.options().requires_grad(false);
287
- int output_seq_len = seq_len*2; // we'll be dilating between each element by interspersing with zeros
288
-
289
- torch::Tensor anti_alias_activation_results =
290
- torch::empty({batches, channels, output_seq_len}, act_options);
291
-
292
- // Softmax Intermediate Result Ptr
293
- void* input_ptr = static_cast<void*>(input.data_ptr());
294
- void* filter_ptr = static_cast<void*>(filter.data_ptr());
295
- void* alpha_ptr = static_cast<void*>(alpha.data_ptr());
296
- void* beta_ptr = static_cast<void*>(beta.data_ptr());
297
- void* anti_alias_activation_results_ptr = static_cast<void*>(anti_alias_activation_results.data_ptr());
298
-
299
- DISPATCH_FLOAT_HALF_AND_BFLOAT(
300
- input.scalar_type(),
301
- "dispatch anti alias activation_forward",
302
- dispatch_anti_alias_activation_forward<scalar_t, scalar_t, float>(
303
- reinterpret_cast<scalar_t*>(anti_alias_activation_results_ptr),
304
- reinterpret_cast<const scalar_t*>(input_ptr),
305
- reinterpret_cast<const scalar_t*>(filter_ptr),
306
- reinterpret_cast<const scalar_t*>(alpha_ptr),
307
- reinterpret_cast<const scalar_t*>(beta_ptr),
308
- batches,
309
- channels,
310
- seq_len);
311
- );
312
- return anti_alias_activation_results;
313
- }
314
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
alias_free_cuda/test_activation.py DELETED
@@ -1,55 +0,0 @@
1
- # Copyright (c) 2024 NVIDIA CORPORATION.
2
- # Licensed under the MIT license.
3
-
4
- import math
5
- import torch
6
- import alias_free_cuda
7
- from alias_free_cuda import activation1d
8
- from activations import Snake, SnakeBeta
9
-
10
- def test_load_fused_kernels():
11
- try:
12
- import alias_free_cuda
13
- import torch
14
- print("[Success] load_fused_kernels")
15
- except ImportError as e:
16
- print("[Fail] load_fused_kernels")
17
- raise e
18
-
19
- def test_anti_alias_activation():
20
- data = torch.rand((10, 10, 50000), device='cuda')
21
-
22
- # check activations.Snake cuda vs. torch
23
- fused_anti_alias_activation = activation1d.Activation1d(activation=Snake(10), fused=True).cuda()
24
- fused_activation_output = fused_anti_alias_activation(data)
25
-
26
- torch_anti_alias_activation = activation1d.Activation1d(activation=Snake(10), fused=False).cuda()
27
- torch_activation_output = torch_anti_alias_activation(data)
28
-
29
- test_result = (fused_activation_output - torch_activation_output).abs()
30
-
31
- while test_result.dim() != 1:
32
- test_result = test_result.mean(dim=-1)
33
-
34
- diff = test_result.mean(dim=-1)
35
-
36
- if diff <= 1e-3:
37
- print(
38
- f"\n[Success] test_fused_anti_alias_activation"
39
- f"\n > mean_difference={diff}"
40
- f"\n > fused_values={fused_activation_output[-1][-1][-100:].tolist()}"
41
- f"\n > torch_values={torch_activation_output[-1][-1][-100:].tolist()}"
42
- )
43
- else:
44
- print(
45
- f"\n[Fail] test_fused_anti_alias_activation"
46
- f"\n > mean_difference={diff}, "
47
- f"\n > fused_values={fused_activation_output[-1][-1][-30:].tolist()}, "
48
- f"\n > torch_values={torch_activation_output[-1][-1][-30:].tolist()}"
49
- )
50
-
51
- if __name__ == "__main__":
52
- from alias_free_cuda import load
53
- load.load()
54
- test_load_fused_kernels()
55
- test_anti_alias_activation()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
alias_free_cuda/test_activation_snake_beta.py DELETED
@@ -1,55 +0,0 @@
1
- # Copyright (c) 2024 NVIDIA CORPORATION.
2
- # Licensed under the MIT license.
3
-
4
- import math
5
- import torch
6
- import alias_free_cuda
7
- from alias_free_cuda import activation1d
8
- from activations import Snake, SnakeBeta
9
-
10
- def test_load_fused_kernels():
11
- try:
12
- import alias_free_cuda
13
- import torch
14
- print("[Success] load_fused_kernels")
15
- except ImportError as e:
16
- print("[Fail] load_fused_kernels")
17
- raise e
18
-
19
- def test_anti_alias_activation():
20
- data = torch.rand((10, 10, 50000), device='cuda')
21
-
22
- # check activations.Snake cuda vs. torch
23
- fused_anti_alias_activation = activation1d.Activation1d(activation=SnakeBeta(10), fused=True).cuda()
24
- fused_activation_output = fused_anti_alias_activation(data)
25
-
26
- torch_anti_alias_activation = activation1d.Activation1d(activation=SnakeBeta(10), fused=False).cuda()
27
- torch_activation_output = torch_anti_alias_activation(data)
28
-
29
- test_result = (fused_activation_output - torch_activation_output).abs()
30
-
31
- while test_result.dim() != 1:
32
- test_result = test_result.mean(dim=-1)
33
-
34
- diff = test_result.mean(dim=-1)
35
-
36
- if diff <= 1e-3:
37
- print(
38
- f"\n[Success] test_fused_anti_alias_activation"
39
- f"\n > mean_difference={diff}"
40
- f"\n > fused_values={fused_activation_output[-1][-1][-100:].tolist()}"
41
- f"\n > torch_values={torch_activation_output[-1][-1][-100:].tolist()}"
42
- )
43
- else:
44
- print(
45
- f"\n[Fail] test_fused_anti_alias_activation"
46
- f"\n > mean_difference={diff}, "
47
- f"\n > fused_values={fused_activation_output[-1][-1][-30:].tolist()}, "
48
- f"\n > torch_values={torch_activation_output[-1][-1][-30:].tolist()}"
49
- )
50
-
51
- if __name__ == "__main__":
52
- from alias_free_cuda import load
53
- load.load()
54
- test_load_fused_kernels()
55
- test_anti_alias_activation()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
alias_free_cuda/type_shim.h DELETED
@@ -1,97 +0,0 @@
1
- /* coding=utf-8
2
- * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3
- *
4
- * Licensed under the Apache License, Version 2.0 (the "License");
5
- * you may not use this file except in compliance with the License.
6
- * You may obtain a copy of the License at
7
- *
8
- * http://www.apache.org/licenses/LICENSE-2.0
9
- *
10
- * Unless required by applicable law or agreed to in writing, software
11
- * distributed under the License is distributed on an "AS IS" BASIS,
12
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- * See the License for the specific language governing permissions and
14
- * limitations under the License.
15
- */
16
-
17
-
18
- #include <ATen/ATen.h>
19
- #include "compat.h"
20
-
21
-
22
- #define DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, NAME, ...) \
23
- switch(TYPE) \
24
- { \
25
- case at::ScalarType::Float: \
26
- { \
27
- using scalar_t = float; \
28
- __VA_ARGS__; \
29
- break; \
30
- } \
31
- case at::ScalarType::Half: \
32
- { \
33
- using scalar_t = at::Half; \
34
- __VA_ARGS__; \
35
- break; \
36
- } \
37
- case at::ScalarType::BFloat16: \
38
- { \
39
- using scalar_t = at::BFloat16; \
40
- __VA_ARGS__; \
41
- break; \
42
- } \
43
- default: \
44
- AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
45
- }
46
-
47
-
48
-
49
- #define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
50
- switch(TYPEIN) \
51
- { \
52
- case at::ScalarType::Float: \
53
- { \
54
- using scalar_t_in = float; \
55
- switch(TYPEOUT) \
56
- { \
57
- case at::ScalarType::Float: \
58
- { \
59
- using scalar_t_out = float; \
60
- __VA_ARGS__; \
61
- break; \
62
- } \
63
- case at::ScalarType::Half: \
64
- { \
65
- using scalar_t_out = at::Half; \
66
- __VA_ARGS__; \
67
- break; \
68
- } \
69
- case at::ScalarType::BFloat16: \
70
- { \
71
- using scalar_t_out = at::BFloat16; \
72
- __VA_ARGS__; \
73
- break; \
74
- } \
75
- default: \
76
- AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \
77
- } \
78
- break; \
79
- } \
80
- case at::ScalarType::Half: \
81
- { \
82
- using scalar_t_in = at::Half; \
83
- using scalar_t_out = at::Half; \
84
- __VA_ARGS__; \
85
- break; \
86
- } \
87
- case at::ScalarType::BFloat16: \
88
- { \
89
- using scalar_t_in = at::BFloat16; \
90
- using scalar_t_out = at::BFloat16; \
91
- __VA_ARGS__; \
92
- break; \
93
- } \
94
- default: \
95
- AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \
96
- }
97
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app.py CHANGED
@@ -273,7 +273,7 @@ with iface:
273
  <h3>News</h3>
274
  <p>[Jul 2024] We release BigVGAN-v2 along with pretrained checkpoints. Below are the highlights:</p>
275
  <ul>
276
- <li>Custom CUDA kernel for inference: we provide a fused upsampling + activation kernel written in CUDA for accelerated inference speed. Our test shows 1.5 - 3x faster speed on a single A100 GPU.</li>
277
  <li>Improved discriminator and loss: BigVGAN-v2 is trained using a <a href="https://arxiv.org/abs/2311.14957" target="_blank">multi-scale sub-band CQT discriminator</a> and a <a href="https://arxiv.org/abs/2306.06546" target="_blank">multi-scale mel spectrogram loss</a>.</li>
278
  <li>Larger training data: BigVGAN-v2 is trained using datasets containing diverse audio types, including speech in multiple languages, environmental sounds, and instruments.</li>
279
  <li>We provide pretrained checkpoints of BigVGAN-v2 using diverse audio configurations, supporting up to 44 kHz sampling rate and 512x upsampling ratio. See the table below for the link.</li>
 
273
  <h3>News</h3>
274
  <p>[Jul 2024] We release BigVGAN-v2 along with pretrained checkpoints. Below are the highlights:</p>
275
  <ul>
276
+ <li>Custom CUDA kernel for inference: we provide a fused anti-aliased activation kernel written in CUDA for accelerated inference speed. Our test shows 1.5 - 3x faster speed on a single A100 GPU.</li>
277
  <li>Improved discriminator and loss: BigVGAN-v2 is trained using a <a href="https://arxiv.org/abs/2311.14957" target="_blank">multi-scale sub-band CQT discriminator</a> and a <a href="https://arxiv.org/abs/2306.06546" target="_blank">multi-scale mel spectrogram loss</a>.</li>
278
  <li>Larger training data: BigVGAN-v2 is trained using datasets containing diverse audio types, including speech in multiple languages, environmental sounds, and instruments.</li>
279
  <li>We provide pretrained checkpoints of BigVGAN-v2 using diverse audio configurations, supporting up to 44 kHz sampling rate and 512x upsampling ratio. See the table below for the link.</li>
bigvgan.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright (c) 2024 NVIDIA CORPORATION.
2
  # Licensed under the MIT license.
3
 
4
  # Adapted from https://github.com/jik876/hifi-gan under the MIT license.
@@ -7,78 +7,127 @@
7
  import os
8
  import json
9
  from pathlib import Path
10
-
11
- from collections import namedtuple
12
- from typing import Optional, List, Union, Dict
13
 
14
  import torch
15
- import torch.nn.functional as F
16
  import torch.nn as nn
17
  from torch.nn import Conv1d, ConvTranspose1d
18
  from torch.nn.utils import weight_norm, remove_weight_norm
19
 
20
  import activations
21
  from utils import init_weights, get_padding
22
- from alias_free_torch.act import Activation1d as TorchActivation1d
23
  from env import AttrDict
24
 
25
  from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
26
 
 
27
  def load_hparams_from_json(path) -> AttrDict:
28
  with open(path) as f:
29
  data = f.read()
30
- h = json.loads(data)
31
- return AttrDict(h)
32
 
33
  class AMPBlock1(torch.nn.Module):
34
- def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5), activation=None):
35
- super(AMPBlock1, self).__init__()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  self.h = h
37
 
38
- self.convs1 = nn.ModuleList([
39
- weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
40
- padding=get_padding(kernel_size, dilation[0]))),
41
- weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
42
- padding=get_padding(kernel_size, dilation[1]))),
43
- weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
44
- padding=get_padding(kernel_size, dilation[2])))
45
- ])
 
 
 
 
 
 
 
46
  self.convs1.apply(init_weights)
47
 
48
- self.convs2 = nn.ModuleList([
49
- weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
50
- padding=get_padding(kernel_size, 1))),
51
- weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
52
- padding=get_padding(kernel_size, 1))),
53
- weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
54
- padding=get_padding(kernel_size, 1)))
55
- ])
 
 
 
 
 
 
 
56
  self.convs2.apply(init_weights)
57
 
58
- self.num_layers = len(self.convs1) + len(self.convs2) # total number of conv layers
 
 
59
 
60
- # select which Activation1d, lazy-load cuda version to ensure backward compatibility
61
  if self.h.get("use_cuda_kernel", False):
62
- # faster CUDA kernel implementation of Activation1d
63
- from alias_free_cuda.activation1d import Activation1d as CudaActivation1d
 
 
64
  Activation1d = CudaActivation1d
65
  else:
66
  Activation1d = TorchActivation1d
67
 
68
- if activation == 'snake': # periodic nonlinearity with snake function and anti-aliasing
69
- self.activations = nn.ModuleList([
70
- Activation1d(
71
- activation=activations.Snake(channels, alpha_logscale=h.snake_logscale))
72
- for _ in range(self.num_layers)
73
- ])
74
- elif activation == 'snakebeta': # periodic nonlinearity with snakebeta function and anti-aliasing
75
- self.activations = nn.ModuleList([
76
- Activation1d(
77
- activation=activations.SnakeBeta(channels, alpha_logscale=h.snake_logscale))
78
- for _ in range(self.num_layers)
79
- ])
 
 
 
 
 
 
 
 
 
 
 
80
  else:
81
- raise NotImplementedError("activation incorrectly specified. check the config file and look for 'activation'.")
 
 
82
 
83
  def forward(self, x):
84
  acts1, acts2 = self.activations[::2], self.activations[1::2]
@@ -99,51 +148,93 @@ class AMPBlock1(torch.nn.Module):
99
 
100
 
101
  class AMPBlock2(torch.nn.Module):
102
- def __init__(self, h, channels, kernel_size=3, dilation=(1, 3), activation=None):
103
- super(AMPBlock2, self).__init__()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
  self.h = h
105
 
106
- self.convs = nn.ModuleList([
107
- weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
108
- padding=get_padding(kernel_size, dilation[0]))),
109
- weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
110
- padding=get_padding(kernel_size, dilation[1])))
111
- ])
 
 
 
 
 
 
 
 
 
112
  self.convs.apply(init_weights)
113
 
114
- self.num_layers = len(self.convs) # total number of conv layers
115
 
116
- # select which Activation1d, lazy-load cuda version to ensure backward compatibility
117
  if self.h.get("use_cuda_kernel", False):
118
- # faster CUDA kernel implementation of Activation1d
119
- from alias_free_cuda.activation1d import Activation1d as CudaActivation1d
 
 
120
  Activation1d = CudaActivation1d
121
  else:
122
  Activation1d = TorchActivation1d
123
 
124
- if activation == 'snake': # periodic nonlinearity with snake function and anti-aliasing
125
- self.activations = nn.ModuleList([
126
- Activation1d(
127
- activation=activations.Snake(channels, alpha_logscale=h.snake_logscale))
128
- for _ in range(self.num_layers)
129
- ])
130
- elif activation == 'snakebeta': # periodic nonlinearity with snakebeta function and anti-aliasing
131
- self.activations = nn.ModuleList([
132
- Activation1d(
133
- activation=activations.SnakeBeta(channels, alpha_logscale=h.snake_logscale))
134
- for _ in range(self.num_layers)
135
- ])
 
 
 
 
 
 
 
 
 
 
 
136
  else:
137
- raise NotImplementedError("activation incorrectly specified. check the config file and look for 'activation'.")
 
 
138
 
139
  def forward(self, x):
140
- for c, a in zip (self.convs, self.activations):
141
  xt = a(x)
142
  xt = c(xt)
143
  x = xt + x
144
 
145
- return x
146
-
147
  def remove_weight_norm(self):
148
  for l in self.convs:
149
  remove_weight_norm(l)
@@ -157,83 +248,121 @@ class BigVGAN(
157
  docs_url="https://github.com/NVIDIA/BigVGAN/blob/main/README.md",
158
  pipeline_tag="audio-to-audio",
159
  license="mit",
160
- tags=["neural-vocoder", "audio-generation", "arxiv:2206.04658"]
161
  ):
162
- # this is our main BigVGAN model. Applies anti-aliased periodic activation for resblocks.
163
- # New in v2: if use_cuda_kernel is set to True, it loads optimized CUDA kernels for AMP.
164
- # NOTE: use_cuda_kernel=True should be used for inference only (training is not supported).
165
- def __init__(
166
- self,
167
- h,
168
- use_cuda_kernel: bool=False
169
- ):
170
- super(BigVGAN, self).__init__()
 
 
 
 
 
 
171
  self.h = h
172
- self.h["use_cuda_kernel"] = use_cuda_kernel # add it to global hyperparameters (h)
 
 
 
 
 
 
 
 
 
 
173
 
174
  self.num_kernels = len(h.resblock_kernel_sizes)
175
  self.num_upsamples = len(h.upsample_rates)
176
 
177
- # pre conv
178
- self.conv_pre = weight_norm(Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3))
 
 
179
 
180
- # define which AMPBlock to use. BigVGAN uses AMPBlock1 as default
181
- resblock = AMPBlock1 if h.resblock == '1' else AMPBlock2
 
 
 
 
 
 
 
182
 
183
- # transposed conv-based upsamplers. does not apply anti-aliasing
184
  self.ups = nn.ModuleList()
185
  for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
186
- self.ups.append(nn.ModuleList([
187
- weight_norm(ConvTranspose1d(h.upsample_initial_channel // (2 ** i),
188
- h.upsample_initial_channel // (2 ** (i + 1)),
189
- k, u, padding=(k - u) // 2))
190
- ]))
 
 
 
 
 
 
 
 
 
 
191
 
192
- # residual blocks using anti-aliased multi-periodicity composition modules (AMP)
193
  self.resblocks = nn.ModuleList()
194
  for i in range(len(self.ups)):
195
  ch = h.upsample_initial_channel // (2 ** (i + 1))
196
- for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
197
- self.resblocks.append(resblock(h, ch, k, d, activation=h.activation))
198
-
199
- # select which Activation1d, lazy-load cuda version to ensure backward compatibility
200
- if self.h.get("use_cuda_kernel", False):
201
- # faster CUDA kernel implementation of Activation1d
202
- from alias_free_cuda.activation1d import Activation1d as CudaActivation1d
203
- Activation1d = CudaActivation1d
204
- else:
205
- Activation1d = TorchActivation1d
206
 
207
- # post conv
208
- if h.activation == "snake": # periodic nonlinearity with snake function and anti-aliasing
209
- activation_post = activations.Snake(ch, alpha_logscale=h.snake_logscale)
210
- self.activation_post = Activation1d(activation=activation_post)
211
- elif h.activation == "snakebeta": # periodic nonlinearity with snakebeta function and anti-aliasing
212
- activation_post = activations.SnakeBeta(ch, alpha_logscale=h.snake_logscale)
213
- self.activation_post = Activation1d(activation=activation_post)
214
- else:
215
- raise NotImplementedError("activation incorrectly specified. check the config file and look for 'activation'.")
216
-
217
- # whether to use bias for the final conv_post. Defaults to True for backward compatibility
 
 
 
 
 
 
 
218
  self.use_bias_at_final = h.get("use_bias_at_final", True)
219
- self.conv_post = weight_norm(Conv1d(
220
- ch, 1, 7, 1, padding=3, bias=self.use_bias_at_final
221
- ))
222
 
223
- # weight initialization
224
  for i in range(len(self.ups)):
225
  self.ups[i].apply(init_weights)
226
  self.conv_post.apply(init_weights)
227
-
228
- # final tanh activation. Defaults to True for backward compatibility
229
  self.use_tanh_at_final = h.get("use_tanh_at_final", True)
230
 
231
  def forward(self, x):
232
- # pre conv
233
  x = self.conv_pre(x)
234
 
235
  for i in range(self.num_upsamples):
236
- # upsampling
237
  for i_up in range(len(self.ups[i])):
238
  x = self.ups[i][i_up](x)
239
  # AMP blocks
@@ -245,20 +374,20 @@ class BigVGAN(
245
  xs += self.resblocks[i * self.num_kernels + j](x)
246
  x = xs / self.num_kernels
247
 
248
- # post conv
249
  x = self.activation_post(x)
250
  x = self.conv_post(x)
251
- # final tanh activation
252
  if self.use_tanh_at_final:
253
  x = torch.tanh(x)
254
  else:
255
- x = torch.clamp(x, min=-1., max=1.) # bound the output to [-1, 1]
256
 
257
  return x
258
 
259
  def remove_weight_norm(self):
260
  try:
261
- print('Removing weight norm...')
262
  for l in self.ups:
263
  for l_i in l:
264
  remove_weight_norm(l_i)
@@ -267,23 +396,18 @@ class BigVGAN(
267
  remove_weight_norm(self.conv_pre)
268
  remove_weight_norm(self.conv_post)
269
  except ValueError:
270
- print('[INFO] Model already removed weight norm. Skipping!')
271
  pass
272
 
273
- ##################################################################
274
- # additional methods for huggingface_hub support
275
- ##################################################################
276
  def _save_pretrained(self, save_directory: Path) -> None:
277
  """Save weights and config.json from a Pytorch model to a local directory."""
278
 
279
- model_path = save_directory / 'bigvgan_generator.pt'
280
- torch.save(
281
- {'generator': self.state_dict()},
282
- model_path
283
- )
284
 
285
- config_path = save_directory / 'config.json'
286
- with open(config_path, 'w') as config_file:
287
  json.dump(self.h, config_file, indent=4)
288
 
289
  @classmethod
@@ -298,23 +422,21 @@ class BigVGAN(
298
  resume_download: bool,
299
  local_files_only: bool,
300
  token: Union[str, bool, None],
301
- map_location: str = "cpu", # additional argument
302
- strict: bool = False, # additional argument
303
  use_cuda_kernel: bool = False,
304
  **model_kwargs,
305
  ):
306
  """Load Pytorch pretrained weights and return the loaded model."""
307
 
308
- ##################################################################
309
- # download and load hyperparameters (h) used by BigVGAN
310
- ##################################################################
311
  if os.path.isdir(model_id):
312
  print("Loading config.json from local directory")
313
- config_file = os.path.join(model_id, 'config.json')
314
  else:
315
  config_file = hf_hub_download(
316
  repo_id=model_id,
317
- filename='config.json',
318
  revision=revision,
319
  cache_dir=cache_dir,
320
  force_download=force_download,
@@ -325,26 +447,28 @@ class BigVGAN(
325
  )
326
  h = load_hparams_from_json(config_file)
327
 
328
- ##################################################################
329
  # instantiate BigVGAN using h
330
- ##################################################################
331
  if use_cuda_kernel:
332
- print(f"[WARNING] You have specified use_cuda_kernel=True during BigVGAN.from_pretrained(). Only inference is supported (training is not implemented)!")
333
- print(f"[WARNING] You need nvcc and ninja installed in your system that matches your PyTorch build is using to build the kernel. If not, the model will fail to initialize or generate incorrect waveform!")
334
- print(f"[WARNING] For detail, see the official GitHub repository: https://github.com/NVIDIA/BigVGAN?tab=readme-ov-file#using-custom-cuda-kernel-for-synthesis")
 
 
 
 
 
 
335
  model = cls(h, use_cuda_kernel=use_cuda_kernel)
336
 
337
- ##################################################################
338
- # download and load pretrained generator weight
339
- ##################################################################
340
  if os.path.isdir(model_id):
341
  print("Loading weights from local directory")
342
- model_file = os.path.join(model_id, 'bigvgan_generator.pt')
343
  else:
344
  print(f"Loading weights from {model_id}")
345
  model_file = hf_hub_download(
346
  repo_id=model_id,
347
- filename='bigvgan_generator.pt',
348
  revision=revision,
349
  cache_dir=cache_dir,
350
  force_download=force_download,
@@ -352,15 +476,17 @@ class BigVGAN(
352
  resume_download=resume_download,
353
  token=token,
354
  local_files_only=local_files_only,
355
- )
356
-
357
  checkpoint_dict = torch.load(model_file, map_location=map_location)
358
 
359
  try:
360
- model.load_state_dict(checkpoint_dict['generator'])
361
  except RuntimeError:
362
- print(f"[INFO] the pretrained checkpoint does not contain weight norm. Loading the checkpoint after removing weight norm!")
 
 
363
  model.remove_weight_norm()
364
- model.load_state_dict(checkpoint_dict['generator'])
365
 
366
- return model
 
1
+ # Copyright (c) 2024 NVIDIA CORPORATION.
2
  # Licensed under the MIT license.
3
 
4
  # Adapted from https://github.com/jik876/hifi-gan under the MIT license.
 
7
  import os
8
  import json
9
  from pathlib import Path
10
+ from typing import Optional, Union, Dict
 
 
11
 
12
  import torch
 
13
  import torch.nn as nn
14
  from torch.nn import Conv1d, ConvTranspose1d
15
  from torch.nn.utils import weight_norm, remove_weight_norm
16
 
17
  import activations
18
  from utils import init_weights, get_padding
19
+ from alias_free_activation.torch.act import Activation1d as TorchActivation1d
20
  from env import AttrDict
21
 
22
  from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
23
 
24
+
25
  def load_hparams_from_json(path) -> AttrDict:
26
  with open(path) as f:
27
  data = f.read()
28
+ return AttrDict(json.loads(data))
29
+
30
 
31
  class AMPBlock1(torch.nn.Module):
32
+ """
33
+ AMPBlock applies Snake / SnakeBeta activation functions with trainable parameters that control periodicity, defined for each layer.
34
+ AMPBlock1 has additional self.convs2 that contains additional Conv1d layers with a fixed dilation=1 followed by each layer in self.convs1
35
+
36
+ Args:
37
+ h (AttrDict): Hyperparameters.
38
+ channels (int): Number of convolution channels.
39
+ kernel_size (int): Size of the convolution kernel. Default is 3.
40
+ dilation (tuple): Dilation rates for the convolutions. Each dilation layer has two convolutions. Default is (1, 3, 5).
41
+ activation (str): Activation function type. Should be either 'snake' or 'snakebeta'. Default is None.
42
+ """
43
+
44
+ def __init__(
45
+ self,
46
+ h: AttrDict,
47
+ channels: int,
48
+ kernel_size: int = 3,
49
+ dilation: tuple = (1, 3, 5),
50
+ activation: str = None,
51
+ ):
52
+ super().__init__()
53
+
54
  self.h = h
55
 
56
+ self.convs1 = nn.ModuleList(
57
+ [
58
+ weight_norm(
59
+ Conv1d(
60
+ channels,
61
+ channels,
62
+ kernel_size,
63
+ stride=1,
64
+ dilation=d,
65
+ padding=get_padding(kernel_size, d),
66
+ )
67
+ )
68
+ for d in dilation
69
+ ]
70
+ )
71
  self.convs1.apply(init_weights)
72
 
73
+ self.convs2 = nn.ModuleList(
74
+ [
75
+ weight_norm(
76
+ Conv1d(
77
+ channels,
78
+ channels,
79
+ kernel_size,
80
+ stride=1,
81
+ dilation=1,
82
+ padding=get_padding(kernel_size, 1),
83
+ )
84
+ )
85
+ for _ in range(len(dilation))
86
+ ]
87
+ )
88
  self.convs2.apply(init_weights)
89
 
90
+ self.num_layers = len(self.convs1) + len(
91
+ self.convs2
92
+ ) # Total number of conv layers
93
 
94
+ # Select which Activation1d, lazy-load cuda version to ensure backward compatibility
95
  if self.h.get("use_cuda_kernel", False):
96
+ from alias_free_activation.cuda.activation1d import (
97
+ Activation1d as CudaActivation1d,
98
+ )
99
+
100
  Activation1d = CudaActivation1d
101
  else:
102
  Activation1d = TorchActivation1d
103
 
104
+ # Activation functions
105
+ if activation == "snake":
106
+ self.activations = nn.ModuleList(
107
+ [
108
+ Activation1d(
109
+ activation=activations.Snake(
110
+ channels, alpha_logscale=h.snake_logscale
111
+ )
112
+ )
113
+ for _ in range(self.num_layers)
114
+ ]
115
+ )
116
+ elif activation == "snakebeta":
117
+ self.activations = nn.ModuleList(
118
+ [
119
+ Activation1d(
120
+ activation=activations.SnakeBeta(
121
+ channels, alpha_logscale=h.snake_logscale
122
+ )
123
+ )
124
+ for _ in range(self.num_layers)
125
+ ]
126
+ )
127
  else:
128
+ raise NotImplementedError(
129
+ "activation incorrectly specified. check the config file and look for 'activation'."
130
+ )
131
 
132
  def forward(self, x):
133
  acts1, acts2 = self.activations[::2], self.activations[1::2]
 
148
 
149
 
150
  class AMPBlock2(torch.nn.Module):
151
+ """
152
+ AMPBlock applies Snake / SnakeBeta activation functions with trainable parameters that control periodicity, defined for each layer.
153
+ Unlike AMPBlock1, AMPBlock2 does not contain extra Conv1d layers with fixed dilation=1
154
+
155
+ Args:
156
+ h (AttrDict): Hyperparameters.
157
+ channels (int): Number of convolution channels.
158
+ kernel_size (int): Size of the convolution kernel. Default is 3.
159
+ dilation (tuple): Dilation rates for the convolutions. Each dilation layer has two convolutions. Default is (1, 3, 5).
160
+ activation (str): Activation function type. Should be either 'snake' or 'snakebeta'. Default is None.
161
+ """
162
+
163
+ def __init__(
164
+ self,
165
+ h: AttrDict,
166
+ channels: int,
167
+ kernel_size: int = 3,
168
+ dilation: tuple = (1, 3, 5),
169
+ activation: str = None,
170
+ ):
171
+ super().__init__()
172
+
173
  self.h = h
174
 
175
+ self.convs = nn.ModuleList(
176
+ [
177
+ weight_norm(
178
+ Conv1d(
179
+ channels,
180
+ channels,
181
+ kernel_size,
182
+ stride=1,
183
+ dilation=d,
184
+ padding=get_padding(kernel_size, d),
185
+ )
186
+ )
187
+ for d in dilation
188
+ ]
189
+ )
190
  self.convs.apply(init_weights)
191
 
192
+ self.num_layers = len(self.convs) # Total number of conv layers
193
 
194
+ # Select which Activation1d, lazy-load cuda version to ensure backward compatibility
195
  if self.h.get("use_cuda_kernel", False):
196
+ from alias_free_activation.cuda.activation1d import (
197
+ Activation1d as CudaActivation1d,
198
+ )
199
+
200
  Activation1d = CudaActivation1d
201
  else:
202
  Activation1d = TorchActivation1d
203
 
204
+ # Activation functions
205
+ if activation == "snake":
206
+ self.activations = nn.ModuleList(
207
+ [
208
+ Activation1d(
209
+ activation=activations.Snake(
210
+ channels, alpha_logscale=h.snake_logscale
211
+ )
212
+ )
213
+ for _ in range(self.num_layers)
214
+ ]
215
+ )
216
+ elif activation == "snakebeta":
217
+ self.activations = nn.ModuleList(
218
+ [
219
+ Activation1d(
220
+ activation=activations.SnakeBeta(
221
+ channels, alpha_logscale=h.snake_logscale
222
+ )
223
+ )
224
+ for _ in range(self.num_layers)
225
+ ]
226
+ )
227
  else:
228
+ raise NotImplementedError(
229
+ "activation incorrectly specified. check the config file and look for 'activation'."
230
+ )
231
 
232
  def forward(self, x):
233
+ for c, a in zip(self.convs, self.activations):
234
  xt = a(x)
235
  xt = c(xt)
236
  x = xt + x
237
 
 
 
238
  def remove_weight_norm(self):
239
  for l in self.convs:
240
  remove_weight_norm(l)
 
248
  docs_url="https://github.com/NVIDIA/BigVGAN/blob/main/README.md",
249
  pipeline_tag="audio-to-audio",
250
  license="mit",
251
+ tags=["neural-vocoder", "audio-generation", "arxiv:2206.04658"],
252
  ):
253
+ """
254
+ BigVGAN is a neural vocoder model that applies anti-aliased periodic activation for residual blocks (resblocks).
255
+ New in BigVGAN-v2: it can optionally use optimized CUDA kernels for AMP (anti-aliased multi-periodicity) blocks.
256
+
257
+ Args:
258
+ h (AttrDict): Hyperparameters.
259
+ use_cuda_kernel (bool): If set to True, loads optimized CUDA kernels for AMP. This should be used for inference only, as training is not supported with CUDA kernels.
260
+
261
+ Note:
262
+ - The `use_cuda_kernel` parameter should be used for inference only, as training with CUDA kernels is not supported.
263
+ - Ensure that the activation function is correctly specified in the hyperparameters (h.activation).
264
+ """
265
+
266
+ def __init__(self, h: AttrDict, use_cuda_kernel: bool = False):
267
+ super().__init__()
268
  self.h = h
269
+ self.h["use_cuda_kernel"] = use_cuda_kernel
270
+
271
+ # Select which Activation1d, lazy-load cuda version to ensure backward compatibility
272
+ if self.h.get("use_cuda_kernel", False):
273
+ from alias_free_activation.cuda.activation1d import (
274
+ Activation1d as CudaActivation1d,
275
+ )
276
+
277
+ Activation1d = CudaActivation1d
278
+ else:
279
+ Activation1d = TorchActivation1d
280
 
281
  self.num_kernels = len(h.resblock_kernel_sizes)
282
  self.num_upsamples = len(h.upsample_rates)
283
 
284
+ # Pre-conv
285
+ self.conv_pre = weight_norm(
286
+ Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3)
287
+ )
288
 
289
+ # Define which AMPBlock to use. BigVGAN uses AMPBlock1 as default
290
+ if h.resblock == "1":
291
+ resblock_class = AMPBlock1
292
+ elif h.resblock == "2":
293
+ resblock_class = AMPBlock2
294
+ else:
295
+ raise ValueError(
296
+ f"Incorrect resblock class specified in hyperparameters. Got {h.resblock}"
297
+ )
298
 
299
+ # Transposed conv-based upsamplers. does not apply anti-aliasing
300
  self.ups = nn.ModuleList()
301
  for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
302
+ self.ups.append(
303
+ nn.ModuleList(
304
+ [
305
+ weight_norm(
306
+ ConvTranspose1d(
307
+ h.upsample_initial_channel // (2**i),
308
+ h.upsample_initial_channel // (2 ** (i + 1)),
309
+ k,
310
+ u,
311
+ padding=(k - u) // 2,
312
+ )
313
+ )
314
+ ]
315
+ )
316
+ )
317
 
318
+ # Residual blocks using anti-aliased multi-periodicity composition modules (AMP)
319
  self.resblocks = nn.ModuleList()
320
  for i in range(len(self.ups)):
321
  ch = h.upsample_initial_channel // (2 ** (i + 1))
322
+ for j, (k, d) in enumerate(
323
+ zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)
324
+ ):
325
+ self.resblocks.append(
326
+ resblock_class(h, ch, k, d, activation=h.activation)
327
+ )
 
 
 
 
328
 
329
+ # Post-conv
330
+ activation_post = (
331
+ activations.Snake(ch, alpha_logscale=h.snake_logscale)
332
+ if h.activation == "snake"
333
+ else (
334
+ activations.SnakeBeta(ch, alpha_logscale=h.snake_logscale)
335
+ if h.activation == "snakebeta"
336
+ else None
337
+ )
338
+ )
339
+ if activation_post is None:
340
+ raise NotImplementedError(
341
+ "activation incorrectly specified. check the config file and look for 'activation'."
342
+ )
343
+
344
+ self.activation_post = Activation1d(activation=activation_post)
345
+
346
+ # Whether to use bias for the final conv_post. Default to True for backward compatibility
347
  self.use_bias_at_final = h.get("use_bias_at_final", True)
348
+ self.conv_post = weight_norm(
349
+ Conv1d(ch, 1, 7, 1, padding=3, bias=self.use_bias_at_final)
350
+ )
351
 
352
+ # Weight initialization
353
  for i in range(len(self.ups)):
354
  self.ups[i].apply(init_weights)
355
  self.conv_post.apply(init_weights)
356
+
357
+ # Final tanh activation. Defaults to True for backward compatibility
358
  self.use_tanh_at_final = h.get("use_tanh_at_final", True)
359
 
360
  def forward(self, x):
361
+ # Pre-conv
362
  x = self.conv_pre(x)
363
 
364
  for i in range(self.num_upsamples):
365
+ # Upsampling
366
  for i_up in range(len(self.ups[i])):
367
  x = self.ups[i][i_up](x)
368
  # AMP blocks
 
374
  xs += self.resblocks[i * self.num_kernels + j](x)
375
  x = xs / self.num_kernels
376
 
377
+ # Post-conv
378
  x = self.activation_post(x)
379
  x = self.conv_post(x)
380
+ # Final tanh activation
381
  if self.use_tanh_at_final:
382
  x = torch.tanh(x)
383
  else:
384
+ x = torch.clamp(x, min=-1.0, max=1.0) # Bound the output to [-1, 1]
385
 
386
  return x
387
 
388
  def remove_weight_norm(self):
389
  try:
390
+ print("Removing weight norm...")
391
  for l in self.ups:
392
  for l_i in l:
393
  remove_weight_norm(l_i)
 
396
  remove_weight_norm(self.conv_pre)
397
  remove_weight_norm(self.conv_post)
398
  except ValueError:
399
+ print("[INFO] Model already removed weight norm. Skipping!")
400
  pass
401
 
402
+ # Additional methods for huggingface_hub support
 
 
403
  def _save_pretrained(self, save_directory: Path) -> None:
404
  """Save weights and config.json from a Pytorch model to a local directory."""
405
 
406
+ model_path = save_directory / "bigvgan_generator.pt"
407
+ torch.save({"generator": self.state_dict()}, model_path)
 
 
 
408
 
409
+ config_path = save_directory / "config.json"
410
+ with open(config_path, "w") as config_file:
411
  json.dump(self.h, config_file, indent=4)
412
 
413
  @classmethod
 
422
  resume_download: bool,
423
  local_files_only: bool,
424
  token: Union[str, bool, None],
425
+ map_location: str = "cpu", # Additional argument
426
+ strict: bool = False, # Additional argument
427
  use_cuda_kernel: bool = False,
428
  **model_kwargs,
429
  ):
430
  """Load Pytorch pretrained weights and return the loaded model."""
431
 
432
+ # Download and load hyperparameters (h) used by BigVGAN
 
 
433
  if os.path.isdir(model_id):
434
  print("Loading config.json from local directory")
435
+ config_file = os.path.join(model_id, "config.json")
436
  else:
437
  config_file = hf_hub_download(
438
  repo_id=model_id,
439
+ filename="config.json",
440
  revision=revision,
441
  cache_dir=cache_dir,
442
  force_download=force_download,
 
447
  )
448
  h = load_hparams_from_json(config_file)
449
 
 
450
  # instantiate BigVGAN using h
 
451
  if use_cuda_kernel:
452
+ print(
453
+ f"[WARNING] You have specified use_cuda_kernel=True during BigVGAN.from_pretrained(). Only inference is supported (training is not implemented)!"
454
+ )
455
+ print(
456
+ f"[WARNING] You need nvcc and ninja installed in your system that matches your PyTorch build is using to build the kernel. If not, the model will fail to initialize or generate incorrect waveform!"
457
+ )
458
+ print(
459
+ f"[WARNING] For detail, see the official GitHub repository: https://github.com/NVIDIA/BigVGAN?tab=readme-ov-file#using-custom-cuda-kernel-for-synthesis"
460
+ )
461
  model = cls(h, use_cuda_kernel=use_cuda_kernel)
462
 
463
+ # Download and load pretrained generator weight
 
 
464
  if os.path.isdir(model_id):
465
  print("Loading weights from local directory")
466
+ model_file = os.path.join(model_id, "bigvgan_generator.pt")
467
  else:
468
  print(f"Loading weights from {model_id}")
469
  model_file = hf_hub_download(
470
  repo_id=model_id,
471
+ filename="bigvgan_generator.pt",
472
  revision=revision,
473
  cache_dir=cache_dir,
474
  force_download=force_download,
 
476
  resume_download=resume_download,
477
  token=token,
478
  local_files_only=local_files_only,
479
+ )
480
+
481
  checkpoint_dict = torch.load(model_file, map_location=map_location)
482
 
483
  try:
484
+ model.load_state_dict(checkpoint_dict["generator"])
485
  except RuntimeError:
486
+ print(
487
+ f"[INFO] the pretrained checkpoint does not contain weight norm. Loading the checkpoint after removing weight norm!"
488
+ )
489
  model.remove_weight_norm()
490
+ model.load_state_dict(checkpoint_dict["generator"])
491
 
492
+ return model
meldataset.py CHANGED
@@ -1,66 +1,354 @@
1
- # Copyright (c) 2024 NVIDIA CORPORATION.
2
  # Licensed under the MIT license.
3
 
4
  # Adapted from https://github.com/jik876/hifi-gan under the MIT license.
5
  # LICENSE is in incl_licenses directory.
6
 
 
 
 
7
  import torch
8
  import torch.utils.data
9
  import numpy as np
 
10
  from scipy.io.wavfile import read
11
  from librosa.filters import mel as librosa_mel_fn
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
- MAX_WAV_VALUE = 32767.0 # NOTE: 32768.0 -1 to prevent int16 overflow (results in popping sound in corner cases)
14
 
15
  def dynamic_range_compression(x, C=1, clip_val=1e-5):
16
  return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
17
 
 
18
  def dynamic_range_decompression(x, C=1):
19
  return np.exp(x) / C
20
 
 
21
  def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
22
  return torch.log(torch.clamp(x, min=clip_val) * C)
23
 
 
24
  def dynamic_range_decompression_torch(x, C=1):
25
  return torch.exp(x) / C
26
 
 
27
  def spectral_normalize_torch(magnitudes):
28
- output = dynamic_range_compression_torch(magnitudes)
29
- return output
30
 
31
  def spectral_de_normalize_torch(magnitudes):
32
- output = dynamic_range_decompression_torch(magnitudes)
33
- return output
 
 
 
34
 
35
- mel_basis = {}
36
- hann_window = {}
37
 
38
- def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
39
- if torch.min(y) < -1.:
40
- print('min value is ', torch.min(y))
41
- if torch.max(y) > 1.:
42
- print('max value is ', torch.max(y))
 
 
 
 
 
 
 
 
 
43
 
44
- global mel_basis, hann_window
45
- if fmax not in mel_basis:
46
- mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
47
- str_key_mel_basis = str(fmax)+'_'+str(y.device)
48
- mel_basis[str_key_mel_basis] = torch.from_numpy(mel).float().to(y.device)
49
- hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)
 
 
 
 
50
 
51
- y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
52
- y = y.squeeze(1)
 
 
 
 
 
53
 
54
- # complex tensor as default, then use view_as_real for future pytorch compatibility
55
- spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)],
56
- center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=True)
57
- spec = torch.view_as_real(spec)
58
- spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9))
59
 
60
- spec = torch.matmul(mel_basis[str_key_mel_basis], spec)
61
- spec = spectral_normalize_torch(spec)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
 
63
- return spec
64
 
65
  def get_mel_spectrogram(wav, h):
66
- return mel_spectrogram(wav, h.n_fft, h.num_mels, h.sampling_rate, h.hop_size, h.win_size, h.fmin, h.fmax)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 NVIDIA CORPORATION.
2
  # Licensed under the MIT license.
3
 
4
  # Adapted from https://github.com/jik876/hifi-gan under the MIT license.
5
  # LICENSE is in incl_licenses directory.
6
 
7
+ import math
8
+ import os
9
+ import random
10
  import torch
11
  import torch.utils.data
12
  import numpy as np
13
+ from librosa.util import normalize
14
  from scipy.io.wavfile import read
15
  from librosa.filters import mel as librosa_mel_fn
16
+ import pathlib
17
+ from tqdm import tqdm
18
+
19
+ MAX_WAV_VALUE = 32767.0 # NOTE: 32768.0 -1 to prevent int16 overflow (results in popping sound in corner cases)
20
+
21
+
22
+ def load_wav(full_path, sr_target):
23
+ sampling_rate, data = read(full_path)
24
+ if sampling_rate != sr_target:
25
+ raise RuntimeError(
26
+ f"Sampling rate of the file {full_path} is {sampling_rate} Hz, but the model requires {sr_target} Hz"
27
+ )
28
+ return data, sampling_rate
29
 
 
30
 
31
  def dynamic_range_compression(x, C=1, clip_val=1e-5):
32
  return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
33
 
34
+
35
  def dynamic_range_decompression(x, C=1):
36
  return np.exp(x) / C
37
 
38
+
39
  def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
40
  return torch.log(torch.clamp(x, min=clip_val) * C)
41
 
42
+
43
  def dynamic_range_decompression_torch(x, C=1):
44
  return torch.exp(x) / C
45
 
46
+
47
  def spectral_normalize_torch(magnitudes):
48
+ return dynamic_range_compression_torch(magnitudes)
49
+
50
 
51
  def spectral_de_normalize_torch(magnitudes):
52
+ return dynamic_range_decompression_torch(magnitudes)
53
+
54
+
55
+ mel_basis_cache = {}
56
+ hann_window_cache = {}
57
 
 
 
58
 
59
+ def mel_spectrogram(
60
+ y: torch.Tensor,
61
+ n_fft: int,
62
+ num_mels: int,
63
+ sampling_rate: int,
64
+ hop_size: int,
65
+ win_size: int,
66
+ fmin: int,
67
+ fmax: int = None,
68
+ center: bool = False,
69
+ ) -> torch.Tensor:
70
+ """
71
+ Calculate the mel spectrogram of an input signal.
72
+ This function uses slaney norm for the librosa mel filterbank (using librosa.filters.mel) and uses Hann window for STFT (using torch.stft).
73
 
74
+ Args:
75
+ y (torch.Tensor): Input signal.
76
+ n_fft (int): FFT size.
77
+ num_mels (int): Number of mel bins.
78
+ sampling_rate (int): Sampling rate of the input signal.
79
+ hop_size (int): Hop size for STFT.
80
+ win_size (int): Window size for STFT.
81
+ fmin (int): Minimum frequency for mel filterbank.
82
+ fmax (int): Maximum frequency for mel filterbank. If None, defaults to half the sampling rate (fmax = sr / 2.0) inside librosa_mel_fn
83
+ center (bool): Whether to pad the input to center the frames. Default is False.
84
 
85
+ Returns:
86
+ torch.Tensor: Mel spectrogram.
87
+ """
88
+ if torch.min(y) < -1.0:
89
+ print(f"[WARNING] Min value of input waveform signal is {torch.min(y)}")
90
+ if torch.max(y) > 1.0:
91
+ print(f"[WARNING] Max value of input waveform signal is {torch.max(y)}")
92
 
93
+ device = y.device
94
+ key = f"{n_fft}_{num_mels}_{sampling_rate}_{hop_size}_{win_size}_{fmin}_{fmax}_{device}"
 
 
 
95
 
96
+ if key not in mel_basis_cache:
97
+ mel = librosa_mel_fn(
98
+ sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
99
+ )
100
+ mel_basis_cache[key] = torch.from_numpy(mel).float().to(device)
101
+ hann_window_cache[key] = torch.hann_window(win_size).to(device)
102
+
103
+ mel_basis = mel_basis_cache[key]
104
+ hann_window = hann_window_cache[key]
105
+
106
+ padding = (n_fft - hop_size) // 2
107
+ y = torch.nn.functional.pad(
108
+ y.unsqueeze(1), (padding, padding), mode="reflect"
109
+ ).squeeze(1)
110
+
111
+ spec = torch.stft(
112
+ y,
113
+ n_fft,
114
+ hop_length=hop_size,
115
+ win_length=win_size,
116
+ window=hann_window,
117
+ center=center,
118
+ pad_mode="reflect",
119
+ normalized=False,
120
+ onesided=True,
121
+ return_complex=True,
122
+ )
123
+ spec = torch.sqrt(torch.view_as_real(spec).pow(2).sum(-1) + 1e-9)
124
+
125
+ mel_spec = torch.matmul(mel_basis, spec)
126
+ mel_spec = spectral_normalize_torch(mel_spec)
127
+
128
+ return mel_spec
129
 
 
130
 
131
  def get_mel_spectrogram(wav, h):
132
+ """
133
+ Generate mel spectrogram from a waveform using given hyperparameters.
134
+
135
+ Args:
136
+ wav (torch.Tensor): Input waveform.
137
+ h: Hyperparameters object with attributes n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax.
138
+
139
+ Returns:
140
+ torch.Tensor: Mel spectrogram.
141
+ """
142
+ return mel_spectrogram(
143
+ wav,
144
+ h.n_fft,
145
+ h.num_mels,
146
+ h.sampling_rate,
147
+ h.hop_size,
148
+ h.win_size,
149
+ h.fmin,
150
+ h.fmax,
151
+ )
152
+
153
+
154
+ def get_dataset_filelist(a):
155
+ training_files = []
156
+ validation_files = []
157
+ list_unseen_validation_files = []
158
+
159
+ with open(a.input_training_file, "r", encoding="utf-8") as fi:
160
+ training_files = [
161
+ os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav")
162
+ for x in fi.read().split("\n")
163
+ if len(x) > 0
164
+ ]
165
+ print(f"first training file: {training_files[0]}")
166
+
167
+ with open(a.input_validation_file, "r", encoding="utf-8") as fi:
168
+ validation_files = [
169
+ os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav")
170
+ for x in fi.read().split("\n")
171
+ if len(x) > 0
172
+ ]
173
+ print(f"first validation file: {validation_files[0]}")
174
+
175
+ for i in range(len(a.list_input_unseen_validation_file)):
176
+ with open(a.list_input_unseen_validation_file[i], "r", encoding="utf-8") as fi:
177
+ unseen_validation_files = [
178
+ os.path.join(a.list_input_unseen_wavs_dir[i], x.split("|")[0] + ".wav")
179
+ for x in fi.read().split("\n")
180
+ if len(x) > 0
181
+ ]
182
+ print(
183
+ f"first unseen {i}th validation fileset: {unseen_validation_files[0]}"
184
+ )
185
+ list_unseen_validation_files.append(unseen_validation_files)
186
+
187
+ return training_files, validation_files, list_unseen_validation_files
188
+
189
+
190
+ class MelDataset(torch.utils.data.Dataset):
191
+ def __init__(
192
+ self,
193
+ training_files,
194
+ hparams,
195
+ segment_size,
196
+ n_fft,
197
+ num_mels,
198
+ hop_size,
199
+ win_size,
200
+ sampling_rate,
201
+ fmin,
202
+ fmax,
203
+ split=True,
204
+ shuffle=True,
205
+ n_cache_reuse=1,
206
+ device=None,
207
+ fmax_loss=None,
208
+ fine_tuning=False,
209
+ base_mels_path=None,
210
+ is_seen=True,
211
+ ):
212
+ self.audio_files = training_files
213
+ random.seed(1234)
214
+ if shuffle:
215
+ random.shuffle(self.audio_files)
216
+ self.hparams = hparams
217
+ self.is_seen = is_seen
218
+ if self.is_seen:
219
+ self.name = pathlib.Path(self.audio_files[0]).parts[0]
220
+ else:
221
+ self.name = "-".join(pathlib.Path(self.audio_files[0]).parts[:2]).strip("/")
222
+
223
+ self.segment_size = segment_size
224
+ self.sampling_rate = sampling_rate
225
+ self.split = split
226
+ self.n_fft = n_fft
227
+ self.num_mels = num_mels
228
+ self.hop_size = hop_size
229
+ self.win_size = win_size
230
+ self.fmin = fmin
231
+ self.fmax = fmax
232
+ self.fmax_loss = fmax_loss
233
+ self.cached_wav = None
234
+ self.n_cache_reuse = n_cache_reuse
235
+ self._cache_ref_count = 0
236
+ self.device = device
237
+ self.fine_tuning = fine_tuning
238
+ self.base_mels_path = base_mels_path
239
+
240
+ print("[INFO] checking dataset integrity...")
241
+ for i in tqdm(range(len(self.audio_files))):
242
+ assert os.path.exists(
243
+ self.audio_files[i]
244
+ ), f"{self.audio_files[i]} not found"
245
+
246
+ def __getitem__(self, index):
247
+ filename = self.audio_files[index]
248
+ if self._cache_ref_count == 0:
249
+ audio, sampling_rate = load_wav(filename, self.sampling_rate)
250
+ audio = audio / MAX_WAV_VALUE
251
+ if not self.fine_tuning:
252
+ audio = normalize(audio) * 0.95
253
+ self.cached_wav = audio
254
+ if sampling_rate != self.sampling_rate:
255
+ raise ValueError(
256
+ f"{sampling_rate} SR doesn't match target {self.sampling_rate} SR"
257
+ )
258
+ self._cache_ref_count = self.n_cache_reuse
259
+ else:
260
+ audio = self.cached_wav
261
+ self._cache_ref_count -= 1
262
+
263
+ audio = torch.FloatTensor(audio)
264
+ audio = audio.unsqueeze(0)
265
+
266
+ if not self.fine_tuning:
267
+ if self.split:
268
+ if audio.size(1) >= self.segment_size:
269
+ max_audio_start = audio.size(1) - self.segment_size
270
+ audio_start = random.randint(0, max_audio_start)
271
+ audio = audio[:, audio_start : audio_start + self.segment_size]
272
+ else:
273
+ audio = torch.nn.functional.pad(
274
+ audio, (0, self.segment_size - audio.size(1)), "constant"
275
+ )
276
+
277
+ mel = mel_spectrogram(
278
+ audio,
279
+ self.n_fft,
280
+ self.num_mels,
281
+ self.sampling_rate,
282
+ self.hop_size,
283
+ self.win_size,
284
+ self.fmin,
285
+ self.fmax,
286
+ center=False,
287
+ )
288
+ else: # Validation step
289
+ # Match audio length to self.hop_size * n for evaluation
290
+ if (audio.size(1) % self.hop_size) != 0:
291
+ audio = audio[:, : -(audio.size(1) % self.hop_size)]
292
+ mel = mel_spectrogram(
293
+ audio,
294
+ self.n_fft,
295
+ self.num_mels,
296
+ self.sampling_rate,
297
+ self.hop_size,
298
+ self.win_size,
299
+ self.fmin,
300
+ self.fmax,
301
+ center=False,
302
+ )
303
+ assert (
304
+ audio.shape[1] == mel.shape[2] * self.hop_size
305
+ ), f"audio shape {audio.shape} mel shape {mel.shape}"
306
+
307
+ else:
308
+ mel = np.load(
309
+ os.path.join(
310
+ self.base_mels_path,
311
+ os.path.splitext(os.path.split(filename)[-1])[0] + ".npy",
312
+ )
313
+ )
314
+ mel = torch.from_numpy(mel)
315
+
316
+ if len(mel.shape) < 3:
317
+ mel = mel.unsqueeze(0)
318
+
319
+ if self.split:
320
+ frames_per_seg = math.ceil(self.segment_size / self.hop_size)
321
+
322
+ if audio.size(1) >= self.segment_size:
323
+ mel_start = random.randint(0, mel.size(2) - frames_per_seg - 1)
324
+ mel = mel[:, :, mel_start : mel_start + frames_per_seg]
325
+ audio = audio[
326
+ :,
327
+ mel_start
328
+ * self.hop_size : (mel_start + frames_per_seg)
329
+ * self.hop_size,
330
+ ]
331
+ else:
332
+ mel = torch.nn.functional.pad(
333
+ mel, (0, frames_per_seg - mel.size(2)), "constant"
334
+ )
335
+ audio = torch.nn.functional.pad(
336
+ audio, (0, self.segment_size - audio.size(1)), "constant"
337
+ )
338
+
339
+ mel_loss = mel_spectrogram(
340
+ audio,
341
+ self.n_fft,
342
+ self.num_mels,
343
+ self.sampling_rate,
344
+ self.hop_size,
345
+ self.win_size,
346
+ self.fmin,
347
+ self.fmax_loss,
348
+ center=False,
349
+ )
350
+
351
+ return (mel.squeeze(), audio.squeeze(0), filename, mel_loss.squeeze())
352
+
353
+ def __len__(self):
354
+ return len(self.audio_files)
utils.py CHANGED
@@ -6,6 +6,7 @@ import os
6
  import matplotlib
7
  import torch
8
  from torch.nn.utils import weight_norm
 
9
  matplotlib.use("Agg")
10
  import matplotlib.pylab as plt
11
  from meldataset import MAX_WAV_VALUE
@@ -14,8 +15,7 @@ from scipy.io.wavfile import write
14
 
15
  def plot_spectrogram(spectrogram):
16
  fig, ax = plt.subplots(figsize=(10, 2))
17
- im = ax.imshow(spectrogram, aspect="auto", origin="lower",
18
- interpolation='none')
19
  plt.colorbar(im, ax=ax)
20
 
21
  fig.canvas.draw()
@@ -24,10 +24,16 @@ def plot_spectrogram(spectrogram):
24
  return fig
25
 
26
 
27
- def plot_spectrogram_clipped(spectrogram, clip_max=2.):
28
  fig, ax = plt.subplots(figsize=(10, 2))
29
- im = ax.imshow(spectrogram, aspect="auto", origin="lower",
30
- interpolation='none', vmin=1e-6, vmax=clip_max)
 
 
 
 
 
 
31
  plt.colorbar(im, ax=ax)
32
 
33
  fig.canvas.draw()
@@ -49,32 +55,45 @@ def apply_weight_norm(m):
49
 
50
 
51
  def get_padding(kernel_size, dilation=1):
52
- return int((kernel_size*dilation - dilation)/2)
53
 
54
 
55
  def load_checkpoint(filepath, device):
56
  assert os.path.isfile(filepath)
57
- print("Loading '{}'".format(filepath))
58
  checkpoint_dict = torch.load(filepath, map_location=device)
59
  print("Complete.")
60
  return checkpoint_dict
61
 
62
 
63
  def save_checkpoint(filepath, obj):
64
- print("Saving checkpoint to {}".format(filepath))
65
  torch.save(obj, filepath)
66
  print("Complete.")
67
 
68
 
69
- def scan_checkpoint(cp_dir, prefix):
70
- pattern = os.path.join(cp_dir, prefix + '????????')
 
71
  cp_list = glob.glob(pattern)
72
- if len(cp_list) == 0:
73
- return None
74
- return sorted(cp_list)[-1]
 
 
 
 
 
 
 
 
 
 
 
 
75
 
76
  def save_audio(audio, path, sr):
77
  # wav: torch with 1d shape
78
  audio = audio * MAX_WAV_VALUE
79
- audio = audio.cpu().numpy().astype('int16')
80
- write(path, sr, audio)
 
6
  import matplotlib
7
  import torch
8
  from torch.nn.utils import weight_norm
9
+
10
  matplotlib.use("Agg")
11
  import matplotlib.pylab as plt
12
  from meldataset import MAX_WAV_VALUE
 
15
 
16
  def plot_spectrogram(spectrogram):
17
  fig, ax = plt.subplots(figsize=(10, 2))
18
+ im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
 
19
  plt.colorbar(im, ax=ax)
20
 
21
  fig.canvas.draw()
 
24
  return fig
25
 
26
 
27
+ def plot_spectrogram_clipped(spectrogram, clip_max=2.0):
28
  fig, ax = plt.subplots(figsize=(10, 2))
29
+ im = ax.imshow(
30
+ spectrogram,
31
+ aspect="auto",
32
+ origin="lower",
33
+ interpolation="none",
34
+ vmin=1e-6,
35
+ vmax=clip_max,
36
+ )
37
  plt.colorbar(im, ax=ax)
38
 
39
  fig.canvas.draw()
 
55
 
56
 
57
  def get_padding(kernel_size, dilation=1):
58
+ return int((kernel_size * dilation - dilation) / 2)
59
 
60
 
61
  def load_checkpoint(filepath, device):
62
  assert os.path.isfile(filepath)
63
+ print(f"Loading '{filepath}'")
64
  checkpoint_dict = torch.load(filepath, map_location=device)
65
  print("Complete.")
66
  return checkpoint_dict
67
 
68
 
69
  def save_checkpoint(filepath, obj):
70
+ print(f"Saving checkpoint to {filepath}")
71
  torch.save(obj, filepath)
72
  print("Complete.")
73
 
74
 
75
+ def scan_checkpoint(cp_dir, prefix, renamed_file=None):
76
+ # Fallback to original scanning logic first
77
+ pattern = os.path.join(cp_dir, prefix + "????????")
78
  cp_list = glob.glob(pattern)
79
+
80
+ if len(cp_list) > 0:
81
+ last_checkpoint_path = sorted(cp_list)[-1]
82
+ print(f"[INFO] Resuming from checkpoint: '{last_checkpoint_path}'")
83
+ return last_checkpoint_path
84
+
85
+ # If no pattern-based checkpoints are found, check for renamed file
86
+ if renamed_file:
87
+ renamed_path = os.path.join(cp_dir, renamed_file)
88
+ if os.path.isfile(renamed_path):
89
+ print(f"[INFO] Resuming from renamed checkpoint: '{renamed_file}'")
90
+ return renamed_path
91
+
92
+ return None
93
+
94
 
95
  def save_audio(audio, path, sr):
96
  # wav: torch with 1d shape
97
  audio = audio * MAX_WAV_VALUE
98
+ audio = audio.cpu().numpy().astype("int16")
99
+ write(path, sr, audio)